From 7aa8818647303b567c3a21fe4220b2681988e220 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Br=C3=A1ulio=20Oliveira?= Date: Wed, 21 Jan 2026 04:40:30 -0300 Subject: [PATCH 001/831] examples : use -dev/--device and WHISPER_ARG_DEVICE (#3557) Align device selection naming with llama.cpp. --- examples/cli/cli.cpp | 8 ++++++++ examples/server/server.cpp | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 9a54742fe1d..4e84c1b2750 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -77,6 +77,7 @@ struct whisper_params { bool log_score = false; bool use_gpu = true; bool flash_attn = true; + int32_t gpu_device = 0; bool suppress_nst = false; bool carry_initial_prompt = false; @@ -129,6 +130,10 @@ static char * requires_value_error(const std::string & arg) { } static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { + if (const char * env_device = std::getenv("WHISPER_ARG_DEVICE")) { + params.gpu_device = std::stoi(env_device); + } + for (int i = 1; i < argc; i++) { std::string arg = argv[i]; @@ -195,6 +200,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; } else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } @@ -276,6 +282,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -dev N, --device N [%-7d] GPU device ID (default: 0)\n", params.gpu_device); fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); @@ -1003,6 +1010,7 @@ int main(int argc, char ** argv) { struct whisper_context_params cparams = whisper_context_default_params(); cparams.use_gpu = params.use_gpu; + cparams.gpu_device = params.gpu_device; cparams.flash_attn = params.flash_attn; if (!params.dtw.empty()) { diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 866ac4eafaa..b77d8a3ed46 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -103,6 +103,7 @@ struct whisper_params { bool no_timestamps = false; bool use_gpu = true; bool flash_attn = true; + int32_t gpu_device = 0; bool suppress_nst = false; bool no_context = true; bool no_language_probabilities = false; @@ -179,6 +180,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -dev N, --device N [%-7d] GPU device ID (default: 0)\n", params.gpu_device); fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -nlp, --no-language-probabilities [%-7s] exclude language probabilities from verbose_json output\n", params.no_language_probabilities ? "true" : "false"); @@ -198,6 +200,10 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para } bool whisper_params_parse(int argc, char ** argv, whisper_params & params, server_params & sparams) { + if (const char * env_device = std::getenv("WHISPER_ARG_DEVICE")) { + params.gpu_device = std::stoi(env_device); + } + for (int i = 1; i < argc; i++) { std::string arg = argv[i]; @@ -237,6 +243,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(argv[++i]); } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } @@ -643,6 +650,7 @@ int main(int argc, char ** argv) { struct whisper_context_params cparams = whisper_context_default_params(); cparams.use_gpu = params.use_gpu; + cparams.gpu_device = params.gpu_device; cparams.flash_attn = params.flash_attn; if (!params.dtw.empty()) { From c6a495ae5da4ccb1158d421ba9355893d005016c Mon Sep 17 00:00:00 2001 From: yulo <77381088+zhang-hui-yulo@users.noreply.github.com> Date: Tue, 13 Jan 2026 20:52:16 +0800 Subject: [PATCH 002/831] HIP: add fattn-mma-f16 for RDNA4 (llama/18481) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * finish VQ mma * flash_attn_ext_f16_iter * KQ_rowsum * correct exp * fix scale error * fix softmax scale * fix softmax scale * enable fattn on cpu side * fix random error * disable fattn-mma-f16 on rdna3 * fix wrong col for rdna * use identity mat to transpose * resolve conflicts * basic tuning for DeepSeek-R1-Distill-Qwen-1.5B * fix volta compile error * align rdna4 policy for fattn * adjust fattn policy * adjust kernel selection logic * update as the review comments * keep fattn-wmma logic * adjust kernel selection logic --------- Co-authored-by: zhang hui Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/common.cuh | 4 + ggml/src/ggml-cuda/fattn-common.cuh | 2 +- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 203 +++++++++++++++++++++++---- ggml/src/ggml-cuda/fattn.cu | 42 +++++- ggml/src/ggml-cuda/mma.cuh | 49 ++++++- ggml/src/ggml-cuda/vendors/hip.h | 2 + 6 files changed, 266 insertions(+), 36 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 9516d8ec8f9..90794ff2641 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -262,6 +262,10 @@ static const char * cu_get_error_str(CUresult err) { #define FLASH_ATTN_AVAILABLE #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) +#if defined(TURING_MMA_AVAILABLE) +#define LDMATRIX_TRANS_AVAILABLE +#endif // defined(TURING_MMA_AVAILABLE) + static bool fp16_available(const int cc) { return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 31446787287..6b55f784f34 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -914,7 +914,7 @@ void launch_fattn( const int nblocks_stream_k = max_blocks; - const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75; + const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75; blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; blocks_num.y = 1; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 856291dc3ce..e53bbc0502c 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -98,6 +98,19 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); } +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + + // TODO tune specifically for RDNA + return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); +} + static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { if (ampere_mma_available(cc)) { return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); @@ -105,6 +118,9 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c if (turing_mma_available(cc)) { return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); } + if (amd_wmma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols); + } GGML_ASSERT(volta_mma_available(cc)); return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); } @@ -116,6 +132,8 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); #elif defined(VOLTA_MMA_AVAILABLE) return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); +#elif defined(AMD_WMMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols); #else GGML_UNUSED_VARS(DKQ, DV, ncols); return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); @@ -186,6 +204,23 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg; } +static constexpr __device__ int get_cols_per_thread() { +#if defined(AMD_WMMA_AVAILABLE) + return 1; // RDNA has a single column. +#else + return 2; // This is specifically KQ columns, Volta only has a single VKQ column. +#endif // defined(AMD_WMMA_AVAILABLE) +} + +static __host__ int get_cols_per_warp(const int cc) { + if (turing_mma_available(cc) || amd_wmma_available(cc)) { + return 16; + } else { + // Volta + return 32; + } +} + // ------------------------------------------------------------------------------------------------------------------ static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) { @@ -393,10 +428,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int jt, const int kb0, const int k_VKQ_sup) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = T_B_KQ::I; - constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. + constexpr int cols_per_thread = get_cols_per_thread(); constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols); @@ -413,6 +448,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k_VKQ_0 = kb0 * nbatch_fa; #if defined(TURING_MMA_AVAILABLE) T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))]; +#elif defined(AMD_WMMA_AVAILABLE) + T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; #else // Volta T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; #endif // defined(TURING_MMA_AVAILABLE) @@ -461,8 +498,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if constexpr (cols_per_warp == 8) { mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); } else { - // Wide version of KQ_C is column-major => swap A and B. + // Wide version of KQ_C is column-major +#if defined(AMD_WMMA_AVAILABLE) + // RDNA matrix C is column-major. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); +#else + // swap A and B for CUDA. mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A); +#endif // defined(AMD_WMMA_AVAILABLE) } } } @@ -479,8 +522,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); - // Wide version of KQ_C is column-major => swap A and B. + // Wide version of KQ_C is column-major +#if defined(AMD_WMMA_AVAILABLE) + // RDNA matrix C is column-major. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); +#else + // swap A and B for CUDA. mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); +#endif // defined(AMD_WMMA_AVAILABLE) } } } @@ -532,7 +581,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { - KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET); +#if defined(AMD_WMMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else + // Turing + Volta: + const int KQ_idx = l % 2; +#endif // defined(AMD_WMMA_AVAILABLE) + KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET); } } } @@ -552,8 +607,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { - KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]); - KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; +#if defined(AMD_WMMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else + // Turing + Volta: + const int KQ_idx = l % 2; +#endif // defined(AMD_WMMA_AVAILABLE) + KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]); + KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; } else { KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f; } @@ -584,8 +645,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { +#if defined(AMD_WMMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else // Turing + Volta: - KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET); + const int KQ_idx = (l/2) % 2; +#endif // defined(AMD_WMMA_AVAILABLE) + KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET); } } } @@ -596,7 +662,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Values per KQ column are spread across 4 threads: constexpr int offset_first = 2; constexpr int offset_last = 1; -#else +#elif defined(AMD_WMMA_AVAILABLE) + // Values per KQ column are spread across 2 threads: + constexpr int offset_first = 16; + constexpr int offset_last = 16; +#else // Volta // Values per KQ column are spread across 2 threads: constexpr int offset_first = 2; constexpr int offset_last = 2; @@ -612,10 +682,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { - // Turing + Volta: if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { - KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]); - KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; +#if defined(AMD_WMMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else + // Turing + Volta: + const int KQ_idx = (l/2) % 2; +#endif // defined(AMD_WMMA_AVAILABLE) + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]); + KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; } else { KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f; } @@ -639,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #if defined(TURING_MMA_AVAILABLE) if constexpr (cols_per_warp == 8) { - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]); #pragma unroll for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll @@ -660,6 +735,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } } +#elif defined(AMD_WMMA_AVAILABLE) + const half2 KQ_max_scale_h2 = make_half2( + KQ_max_scale[0], KQ_max_scale[0]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } #else // Volta const half2 KQ_max_scale_h2 = make_half2( KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]); @@ -707,6 +792,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Therefore, iterate over V in reverse and re-use the data if possible. static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; +#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) + T_A_VKQ A_identity; + make_identity_mat(A_identity); +#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll @@ -727,7 +816,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; -#if defined(TURING_MMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; #pragma unroll for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { @@ -737,12 +826,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J; T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. +#if defined(LDMATRIX_TRANS_AVAILABLE) load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); +#else + // TODO: Try to transpose tile_V when loading gmem to smem. + // Use mma to transpose T_A_VKQ for RDNA. + T_A_VKQ A_trans; + load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); + mma(A, A_trans, A_identity); +#endif // defined(TURING_MMA_AVAILABLE) if constexpr (T_B_KQ::I == 8) { mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { - // Wide version of VKQ_C is column-major => swap A and B. + // Wide version of VKQ_C is column-major. +#if defined(AMD_WMMA_AVAILABLE) + // RDNA matrix C is column-major. + mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); +#else + // swap A and B for CUDA. mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); +#endif // defined(AMD_WMMA_AVAILABLE) } } } @@ -761,7 +864,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A); } } -#endif // defined(TURING_MMA_AVAILABLE) +#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. @@ -774,7 +877,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) } #if defined(TURING_MMA_AVAILABLE) @@ -794,6 +897,15 @@ template<> struct mma_tile_sizes<8> { using T_B_VKQ = tile< 8, 8, half2>; // column-major using T_C_VKQ = tile<16, 4, half2>; // row-major }; +#elif defined(AMD_WMMA_AVAILABLE) +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; #else // Volta template struct mma_tile_sizes { using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major @@ -828,7 +940,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jt, const int kb0_start, const int kb0_stop) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr int ncols = ncols1 * ncols2; @@ -840,7 +952,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; constexpr int cols_per_warp = T_B_KQ::I; - constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. + constexpr int cols_per_thread = get_cols_per_thread(); constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols); constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols); @@ -871,6 +983,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; #if defined(TURING_MMA_AVAILABLE) T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; +#elif defined(AMD_WMMA_AVAILABLE) + T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; #else // Volta T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; #endif // defined(TURING_MMA_AVAILABLE) @@ -1010,6 +1124,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // The partial sums are spread across 8/4 threads. constexpr int offset_first = cols_per_warp == 8 ? 16 : 2; constexpr int offset_last = cols_per_warp == 8 ? 4 : 1; +#elif defined(AMD_WMMA_AVAILABLE) + // The partial sums are spread across 2 threads. + constexpr int offset_first = 16; + constexpr int offset_last = 16; #else // Volta // The partial sums are spread across 2 threads. constexpr int offset_first = 2; @@ -1047,7 +1165,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #if defined(TURING_MMA_AVAILABLE) if constexpr (cols_per_warp == 8) { - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]); #pragma unroll for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll @@ -1068,6 +1186,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } } +#elif defined(AMD_WMMA_AVAILABLE) + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } #else // Volta const int col = (threadIdx.x / 2) % 2; const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); @@ -1119,6 +1246,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4); const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); const bool thread_should_write = threadIdx.x % 4 < cols_per_thread; +#elif defined(AMD_WMMA_AVAILABLE) + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0); + const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]); + const bool thread_should_write = threadIdx.x / 16 < cols_per_thread; #else // Volta const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2); const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]); @@ -1319,7 +1450,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) } template @@ -1346,7 +1477,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { @@ -1360,6 +1491,13 @@ static __global__ void flash_attn_ext_f16( } #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING +#if defined(AMD_WMMA_AVAILABLE) + if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) { + NO_DEVICE_CODE; + return; + } +#endif // defined(AMD_WMMA_AVAILABLE) + static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); constexpr int ncols = ncols1 * ncols2; @@ -1473,7 +1611,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))) } template @@ -1492,7 +1630,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc); const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc); - const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32); + const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc)); const int nwarps = nthreads / WARP_SIZE; constexpr bool mla = DKQ == 576; @@ -1512,29 +1650,34 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); +#if defined(GGML_USE_HIP) + using fattn_kernel_ptr_t = const void*; +#else + using fattn_kernel_ptr_t = fattn_kernel_t; +#endif // defined(GGML_USE_HIP) fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; fattn_kernel = flash_attn_ext_f16; -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#if !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); shared_memory_limit_raised[id] = true; } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#endif // !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; fattn_kernel = flash_attn_ext_f16; -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#if !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); shared_memory_limit_raised[id] = true; } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#endif // !defined(GGML_USE_MUSA) } launch_fattn diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 0155406665c..598cda7daa0 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -18,12 +18,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con } } - if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) { + if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) { + if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } @@ -230,7 +230,18 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // The effective batch size for the kernel can be increased by gqa_ratio. // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded, - const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + for (const ggml_tensor * t : {Q, K, V, mask}) { + if (t == nullptr) { + continue; + } + for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { + if (t->nb[i] % 16 != 0) { + gqa_opt_applies = false; + break; + } + } + } const int cc = ggml_cuda_info().devices[device].cc; @@ -337,6 +348,31 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_WMMA_F16; } + if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) { + if (can_use_vector_kernel) { + if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { + if (Q->ne[1] == 1) { + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_VEC; + } + } + } else { + if (Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + } + } + int gqa_ratio_eff = 1; + const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; + while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { + gqa_ratio_eff *= 2; + } + if (Q->ne[1] * gqa_ratio_eff <= 8) { + return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized. + } + return BEST_FATTN_KERNEL_MMA_F16; + } + // If there are no tensor cores available, use the generic tile kernel: if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index df9eed71172..42085d10027 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -206,10 +206,16 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 16) { - // matrix C #if defined(RDNA3) - return 2 * l + (threadIdx.x / 16); + if constexpr (std::is_same_v || std::is_same_v) { + // matrix C + return 2 * l + (threadIdx.x / 16); + } else { + // matrix A&B + return l; + } #else + // matrix C is the transposed matrix A&B on RDNA4 return ne * (threadIdx.x / 16) + l; #endif // defined(RDNA3) } else if constexpr (I == 16 && J == 8) { @@ -621,6 +627,21 @@ namespace ggml_cuda_mma { return ret; } +#elif defined(AMD_WMMA_AVAILABLE) + template + static __device__ __forceinline__ tile get_half2(const tile & tile_float) { + tile ret; +#pragma unroll + for (int l0 = 0; l0 < tile_float.ne; l0 += 2) { + ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + } + return ret; + } + + static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) { + NO_DEVICE_CODE; + return tile<8, 8, half2>{}; + } #else // Volta template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { @@ -639,6 +660,19 @@ namespace ggml_cuda_mma { } #endif // defined(TURING_MMA_AVAILABLE) + static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) { +#if defined(RDNA4) + const int row = t.get_i(0); + const int left_right = t.get_j(0) / 4; + const int up_down = row / 8; + const int idx = row % 8; + reinterpret_cast(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f; +#else + GGML_UNUSED_VARS(t); + NO_DEVICE_CODE; +#endif // defined(RDNA4) + } + template static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { #if defined(AMD_MFMA_AVAILABLE) @@ -878,6 +912,17 @@ namespace ggml_cuda_mma { : "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#elif defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) + using halfx8_t = __attribute__((ext_vector_type(8))) _Float16; + halfx8_t& acc_frag = reinterpret_cast(D.x[0]); + const halfx8_t& a_frag = reinterpret_cast(A.x[0]); + const halfx8_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(RDNA4) #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 016b04e5a0c..5cc1b54319c 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -138,6 +138,8 @@ #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess #define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor +#define cudaFuncSetAttribute hipFuncSetAttribute +#define cudaFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize #define __trap() do { abort(); __builtin_unreachable(); } while(0) #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED From 17656e56dc637d518a0aec07a21144de78e227ca Mon Sep 17 00:00:00 2001 From: Perry Naseck <4472083+DaAwesomeP@users.noreply.github.com> Date: Wed, 14 Jan 2026 02:22:25 -0500 Subject: [PATCH 003/831] ggml-metal: do not copy headers for embedded, use current binary dir for embedded (llama/18705) --- ggml/src/ggml-metal/CMakeLists.txt | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt index 63418fe1430..9c0b3db8599 100644 --- a/ggml/src/ggml-metal/CMakeLists.txt +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -23,11 +23,6 @@ if (GGML_METAL_NDEBUG) add_compile_definitions(GGML_METAL_NDEBUG) endif() -# copy metal files to bin directory -configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) -configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) -configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) - set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h") if (GGML_METAL_EMBED_LIBRARY) enable_language(ASM) @@ -37,12 +32,12 @@ if (GGML_METAL_EMBED_LIBRARY) set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h") - file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") + file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated") # merge ggml-common.h and ggml-metal.metal into a single file - set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") - set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") - set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") + set(METALLIB_EMBED_ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s") + set(METALLIB_SOURCE_EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") add_custom_command( OUTPUT "${METALLIB_EMBED_ASM}" @@ -62,6 +57,11 @@ if (GGML_METAL_EMBED_LIBRARY) target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}") else() + # copy metal files to bin directory + configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) + configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) + configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) + if (GGML_METAL_SHADER_DEBUG) # custom command to do the following: # xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air From 49762e8fb390a0832e27a31ab03c5b5d6a320a22 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Wed, 14 Jan 2026 09:41:23 +0100 Subject: [PATCH 004/831] vulkan: work around Intel fp16 bug in mmq (llama/18814) --- ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index 7f32dadf17d..9c297d1c60d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -264,7 +264,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) | (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147 - buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32); + buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales - 32)); } } @@ -334,7 +334,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); } - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm); + buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm)); } } @@ -385,7 +385,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint is = iqs_k / 4; const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy; - buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales); + buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales)); } } From 25aeb66a4a56bc20430e564d1c7815109a5cb801 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 14 Jan 2026 10:31:49 +0100 Subject: [PATCH 005/831] CUDA : fix typo in clang pragma comment [no ci] (llama/18830) --- ggml/src/ggml-cuda/fattn-vec.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 4d167b95a07..86f4dc0f7f1 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -10,7 +10,7 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { return 128; } -// Currenlty llvm with the amdgcn target dose not support unrolling loops +// Currenlty llvm with the amdgcn target does not support unrolling loops // that contain a break that can not be resolved at compile time. #ifdef __clang__ #pragma clang diagnostic push From 4b155e9bfb84a3e9f506940f253406d8c7cce377 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 14 Jan 2026 03:59:05 -0600 Subject: [PATCH 006/831] vulkan: Check maxStorageBufferRange in supports_op (llama/18709) * vulkan: Check maxStorageBufferRange in supports_op * skip maxStorageBufferRange check when shader64BitIndexing is enabled --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index deed5055d54..0fabbcec31d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -14413,13 +14413,29 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; const vk_device& device = ggml_vk_get_device(ctx->device); + const bool uses_bda = (op->op == GGML_OP_IM2COL || op->op == GGML_OP_IM2COL_3D) && + device->shader_int64 && device->buffer_device_address; + + auto const & tensor_size_supported = [&](size_t tensor_size) { + if (tensor_size > device->max_buffer_size) { + return false; + } + // For im2col shaders using BDA, maxStorageBufferRange limit doesn't apply. + // If shader64BitIndexing is enabled, maxStorageBufferRange limit doesn't apply. + if (!uses_bda && !device->shader_64b_indexing) { + if (tensor_size > device->properties.limits.maxStorageBufferRange) { + return false; + } + } + return true; + }; // reject any tensors larger than the max buffer size for (int i = 0; i < GGML_MAX_SRC; i++) { - if (op->src[i] && ggml_nbytes(op->src[i]) > device->max_buffer_size) { + if (op->src[i] && !tensor_size_supported(ggml_nbytes(op->src[i]))) { return false; } } - if (ggml_nbytes(op) > device->max_buffer_size) { + if (!tensor_size_supported(ggml_nbytes(op))) { return false; } From bc09047405c36d4f639e77739a01bf7fed7b0f24 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Thu, 15 Jan 2026 03:44:54 +0100 Subject: [PATCH 007/831] CUDA: Factor out and re-use `block_reduce` function (llama/18785) * CUDA: Refactor and expose two_stage_warp_reduce_* function * Use `two_stage_warp_reduce` also in softmax kernel, move smem out of it Moving smem out of `__device__` function to `__global__` function allows for explicit smem reuse, as either compiler or cuda rt seem to not free it afterwards (`cudaFuncSetAttribute` fails when not accounting for it once for each call to two_stage_warp_reduce) * Update ggml/src/ggml-cuda/common.cuh Co-authored-by: Aman Gupta * Use two_stage_warp_reduce in group_norm_f32 * Use two_stage_warp_reduce in rms_norm_f32 * Fix smem calculation which expects bytes * Make `two_stage_warp_reduce` accept all values warp_reduce accepts Also integrate it into norm_f32 function * Use two_stage_warp_reduce in l2_norm_f32 * Use type traits for block reduction for better legibility Also adresss other requests by @am17an such as variable renaming * Make norm tests cover all cuda paths * Mark columns % WARP_SIZE !=0 as supported for RMS_NORM_BACK Unit-tests passed locally, let's see if they pass in the CI as well * Use `enum class` for `block_reduce_method` This is more type-safe than plain enum * Rename variables as suggested in code review by @am17an * Rename two_stage_warp_reduce -> block_reduce * Fix trailing whitespace in common.cuh * Make condition of static_assert type-dependent This delays evaluation until the template is actually instantiated. Otherwise, some compilers may evaluate the assert when parsing the template, resulting in build errors as observed here: https://github.com/ggml-org/llama.cpp/actions/runs/20960323123/job/60235530068?pr=18785 * Inline definitions --------- Co-authored-by: Aman Gupta --- ggml/src/ggml-cuda/common.cuh | 80 +++++++++++++++++++++++++ ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- ggml/src/ggml-cuda/norm.cu | 94 ++++++------------------------ ggml/src/ggml-cuda/reduce_rows.cuh | 18 +----- ggml/src/ggml-cuda/softmax.cu | 89 +++------------------------- 5 files changed, 108 insertions(+), 175 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 90794ff2641..eaaf87612d2 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -530,6 +530,86 @@ static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) { #endif // FP16_AVAILABLE } +enum class block_reduce_method { + MAX, + SUM, +}; + +template +struct block_reduce_policy; + +template +inline constexpr bool is_any = (std::is_same_v || ...); + +template +inline constexpr bool ggml_cuda_dependent_false_v = false; + +template struct block_reduce_policy { + static __device__ T reduce(T val) { + if constexpr(is_any) { + return warp_reduce_sum(val); + } else { + static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce sum"); + } + } + + static __device__ T sentinel() { + if constexpr (std::is_same_v) { + return 0.0f; + } else if constexpr (std::is_same_v) { + return make_float2(0.0f, 0.0f); + } else if constexpr (std::is_same_v) { + return make_half2(0.0f, 0.0f); + } else if constexpr (std::is_same_v) { + return 0; + } else { + static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce sum"); + } + } +}; + +template struct block_reduce_policy { + static __device__ T reduce(T val) { + if constexpr (is_any) { + return warp_reduce_max(val); + } else { + static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce max"); + } + } + + static __device__ T sentinel() { + if constexpr (std::is_same_v) { + return -INFINITY; + } else if constexpr (std::is_same_v) { + return make_half2(-INFINITY, -INFINITY); + } else { + static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce max"); + } + } +}; + +template +static __device__ T block_reduce(T val, T * shared_vals) { + val = block_reduce_policy::reduce(val); + const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template; + if (block_size > WARP_SIZE) { + assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0); + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + shared_vals[warp_id] = val; + } + __syncthreads(); + val = block_reduce_policy::sentinel(); + if (lane_id < (static_cast(block_size) / WARP_SIZE)) { + val = shared_vals[lane_id]; + } + return block_reduce_policy::reduce(val); + } + + return val; +} + static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) { #ifdef FP16_AVAILABLE diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c3ee2ea0667..553623fbd42 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4551,7 +4551,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_L2_NORM: return true; case GGML_OP_RMS_NORM_BACK: - return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0; + return ggml_is_contiguous(op->src[0]); break; case GGML_OP_NONE: case GGML_OP_RESHAPE: diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 4f153c5718e..ef98f675aa7 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -25,19 +25,8 @@ static __global__ void norm_f32( } // sum up partial sums - mean_var = warp_reduce_sum(mean_var); - if constexpr (block_size > WARP_SIZE) { - static_assert(block_size == 1024, "unexpected block_size"); - __shared__ float2 s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = mean_var; - } - __syncthreads(); - mean_var = s_sum[lane_id]; - mean_var = warp_reduce_sum(mean_var); - } + extern __shared__ float2 s_sum2[]; + mean_var = block_reduce(mean_var, s_sum2); const float mean = mean_var.x / ncols; const float var = mean_var.y / ncols - mean * mean; @@ -61,19 +50,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp += x[j]; } - tmp = warp_reduce_sum(tmp); - if constexpr (block_size > WARP_SIZE) { - static_assert(block_size == 1024, "unexpected block_size"); - __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - __syncthreads(); - tmp = s_sum[lane_id]; - tmp = warp_reduce_sum(tmp); - } + extern __shared__ float s_sum[]; + tmp = block_reduce(tmp, s_sum); const float mean = tmp / group_size; tmp = 0.0f; @@ -84,18 +62,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp += xi * xi; } - tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { - __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - __syncthreads(); - tmp = s_sum[lane_id]; - tmp = warp_reduce_sum(tmp); - } + tmp = block_reduce(tmp, s_sum); const float variance = tmp / group_size; const float scale = rsqrtf(variance + eps); @@ -163,22 +130,8 @@ static __global__ void rms_norm_f32(const float * x, } // sum up partial sums - tmp = warp_reduce_sum(tmp); - if constexpr (block_size > WARP_SIZE) { - static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size"); - __shared__ float s_sum[32]; - const int warp_id = tid / WARP_SIZE; - const int lane_id = tid % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - __syncthreads(); - tmp = 0.0f; - if (lane_id < (block_size / WARP_SIZE)) { - tmp = s_sum[lane_id]; - } - tmp = warp_reduce_sum(tmp); - } + extern __shared__ float s_sum[]; + tmp = block_reduce(tmp, s_sum); const float mean = tmp / ncols; const float scale = rsqrtf(mean + eps); @@ -306,19 +259,8 @@ static __global__ void l2_norm_f32( } // sum up partial sums - tmp = warp_reduce_sum(tmp); - if constexpr (block_size > WARP_SIZE) { - static_assert(block_size == 1024, "unexpected block_size"); - __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - __syncthreads(); - tmp = s_sum[lane_id]; - tmp = warp_reduce_sum(tmp); - } + extern __shared__ float s_sum[]; + tmp = block_reduce(tmp, s_sum); // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html const float scale = rsqrtf(fmaxf(tmp, eps * eps)); @@ -337,7 +279,7 @@ static void norm_f32_cuda( norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - norm_f32<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + norm_f32<1024><< WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } @@ -348,7 +290,7 @@ static void group_norm_f32_cuda( group_norm_f32<<>>(x, dst, group_size, ne_elements, eps); } else { const dim3 block_dims(1024, 1, 1); - group_norm_f32<1024><<>>(x, dst, group_size, ne_elements, eps); + group_norm_f32<1024><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps); } } @@ -358,10 +300,10 @@ static void rms_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + rms_norm_f32<256, false><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + rms_norm_f32<1024, false><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } @@ -404,12 +346,12 @@ static void rms_norm_mul_f32_cuda(const float * x, const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true><<>>( + rms_norm_f32<256, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true><<>>( + rms_norm_f32<1024, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); } @@ -425,14 +367,14 @@ static void rms_norm_mul_f32_cuda(const float * x, const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true, true><<>>( + rms_norm_f32<256, true, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, add_nchannels_packed, add_nsamples_packed); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true, true><<>>( + rms_norm_f32<1024, true, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, @@ -460,7 +402,7 @@ static void l2_norm_f32_cuda( l2_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - l2_norm_f32<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + l2_norm_f32<1024><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh index 6bcae9e52fb..de240fd4413 100644 --- a/ggml/src/ggml-cuda/reduce_rows.cuh +++ b/ggml/src/ggml-cuda/reduce_rows.cuh @@ -28,22 +28,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r } // sum up partial sums - sum = warp_reduce_sum(sum); - if (blockDim.x > WARP_SIZE) { - assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0); - __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = sum; - } - __syncthreads(); - sum = 0.0f; - if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { - sum = s_sum[lane_id]; - } - sum = warp_reduce_sum(sum); - } + __shared__ float shared_vals[32]; + sum = block_reduce(sum, shared_vals); if (col != 0) { return; diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index 1ae84ebf630..dc06d06930e 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -75,9 +75,6 @@ static __global__ void soft_max_f32( const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1); extern __shared__ float data_soft_max_f32[]; @@ -102,21 +99,7 @@ static __global__ void soft_max_f32( } // find the max value in the block - max_val = warp_reduce_max(max_val); - if (block_size > WARP_SIZE) { - if (warp_id == 0) { - buf_iw[lane_id] = -INFINITY; - } - __syncthreads(); - - if (lane_id == 0) { - buf_iw[warp_id] = max_val; - } - __syncthreads(); - - max_val = buf_iw[lane_id]; - max_val = warp_reduce_max(max_val); - } + max_val = block_reduce(max_val, buf_iw); float tmp = 0.0f; // partial sum @@ -134,22 +117,7 @@ static __global__ void soft_max_f32( } // find the sum of exps in the block - tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { - __syncthreads(); - if (warp_id == 0) { - buf_iw[lane_id] = 0.0f; - } - __syncthreads(); - - if (lane_id == 0) { - buf_iw[warp_id] = tmp; - } - __syncthreads(); - - tmp = buf_iw[lane_id]; - tmp = warp_reduce_sum(tmp); - } + tmp = block_reduce(tmp, buf_iw); if (sinks) { tmp += expf(sinks[i02] - max_val); @@ -169,50 +137,6 @@ static __global__ void soft_max_f32( } } - -// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated -static __device__ float two_stage_warp_reduce_max(float val) { - val = warp_reduce_max(val); - if (blockDim.x > WARP_SIZE) { - assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0); - __shared__ float local_vals[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - local_vals[warp_id] = val; - } - __syncthreads(); - val = -INFINITY; - if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { - val = local_vals[lane_id]; - } - return warp_reduce_max(val); - } else { - return val; - } -} - -static __device__ float two_stage_warp_reduce_sum(float val) { - val = warp_reduce_sum(val); - if (blockDim.x > WARP_SIZE) { - assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0); - __shared__ float local_vals[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - local_vals[warp_id] = val; - } - __syncthreads(); - val = 0.0f; - if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { - val = local_vals[lane_id]; - } - return warp_reduce_sum(val); - } else { - return val; - } -} - // TODO: Template to allow keeping ncols in registers if they fit static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x, float * __restrict__ dst, @@ -230,6 +154,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY }; float local_max = -INFINITY; const int step_size = gridDim.x * blockDim.x; + __shared__ float shared_vals[32]; // Compute thread-local max for (int col = col_start; col < p.ncols;) { @@ -246,7 +171,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ } // Compute CTA-level max - local_max = two_stage_warp_reduce_max(local_max); + local_max = block_reduce(local_max, shared_vals); // Store CTA-level max to GMEM if (tid == 0) { @@ -261,7 +186,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ } else { local_max = -INFINITY; } - local_max = two_stage_warp_reduce_max(local_max); + local_max = block_reduce(local_max, shared_vals); // Compute softmax dividends, accumulate divisor float tmp_expf = 0.0f; @@ -284,7 +209,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ } // Reduce divisor within CTA - tmp_expf = two_stage_warp_reduce_sum(tmp_expf); + tmp_expf = block_reduce(tmp_expf, shared_vals); // Store CTA-level sum to GMEM if (tid == 0) { @@ -298,7 +223,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ } else { tmp_expf = 0.0f; } - tmp_expf = two_stage_warp_reduce_sum(tmp_expf); + tmp_expf = block_reduce(tmp_expf, shared_vals); // Divide dividend by global sum + store data for (int col = col_start; col < p.ncols;) { From 50b7ab3d461d80c45b389e9469ef40e654feffb7 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Fri, 30 Jan 2026 10:28:03 +0200 Subject: [PATCH 008/831] hexagon: support for OP_CPY, host buffers now optional (llama/18822) --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 93 +- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 8 +- ggml/src/ggml-hexagon/htp/act-ops.c | 107 +- ggml/src/ggml-hexagon/htp/binary-ops.c | 57 +- ggml/src/ggml-hexagon/htp/cpy-ops.c | 251 +++ ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 77 +- ggml/src/ggml-hexagon/htp/get-rows-ops.c | 10 +- .../ggml-hexagon/htp/{htp-dma.c => hex-dma.c} | 2 +- .../ggml-hexagon/htp/{htp-dma.h => hex-dma.h} | 1 - ggml/src/ggml-hexagon/htp/hex-dump.h | 77 + ggml/src/ggml-hexagon/htp/hex-fastdiv.h | 37 + ggml/src/ggml-hexagon/htp/hex-utils.h | 51 + ggml/src/ggml-hexagon/htp/htp-ctx.h | 2 +- ggml/src/ggml-hexagon/htp/htp-msg.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 12 +- ggml/src/ggml-hexagon/htp/hvx-arith.h | 457 ++++++ ggml/src/ggml-hexagon/htp/hvx-base.h | 167 ++ ggml/src/ggml-hexagon/htp/hvx-copy.h | 247 +++ ggml/src/ggml-hexagon/htp/hvx-dump.h | 132 ++ ggml/src/ggml-hexagon/htp/hvx-exp.c | 94 -- ggml/src/ggml-hexagon/htp/hvx-exp.h | 215 +++ ggml/src/ggml-hexagon/htp/hvx-floor.h | 100 ++ ggml/src/ggml-hexagon/htp/hvx-inverse.c | 72 - ggml/src/ggml-hexagon/htp/hvx-inverse.h | 176 +++ ggml/src/ggml-hexagon/htp/hvx-reduce.h | 225 +++ ggml/src/ggml-hexagon/htp/hvx-scale.h | 133 ++ ggml/src/ggml-hexagon/htp/hvx-sigmoid.c | 49 - ggml/src/ggml-hexagon/htp/hvx-sigmoid.h | 114 ++ ggml/src/ggml-hexagon/htp/hvx-sqrt.h | 60 + ggml/src/ggml-hexagon/htp/hvx-types.h | 36 + ggml/src/ggml-hexagon/htp/hvx-utils.c | 1020 ------------- ggml/src/ggml-hexagon/htp/hvx-utils.h | 1360 +---------------- ggml/src/ggml-hexagon/htp/main.c | 66 +- ggml/src/ggml-hexagon/htp/matmul-ops.c | 264 ++-- ggml/src/ggml-hexagon/htp/ops-utils.h | 149 -- ggml/src/ggml-hexagon/htp/rope-ops.c | 25 +- ggml/src/ggml-hexagon/htp/set-rows-ops.c | 16 +- ggml/src/ggml-hexagon/htp/softmax-ops.c | 45 +- ggml/src/ggml-hexagon/htp/unary-ops.c | 37 +- ggml/src/ggml-hexagon/htp/worker-pool.c | 4 - 40 files changed, 2904 insertions(+), 3145 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/cpy-ops.c rename ggml/src/ggml-hexagon/htp/{htp-dma.c => hex-dma.c} (98%) rename ggml/src/ggml-hexagon/htp/{htp-dma.h => hex-dma.h} (99%) create mode 100644 ggml/src/ggml-hexagon/htp/hex-dump.h create mode 100644 ggml/src/ggml-hexagon/htp/hex-fastdiv.h create mode 100644 ggml/src/ggml-hexagon/htp/hex-utils.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-arith.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-base.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-copy.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-dump.h delete mode 100644 ggml/src/ggml-hexagon/htp/hvx-exp.c create mode 100644 ggml/src/ggml-hexagon/htp/hvx-exp.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-floor.h delete mode 100644 ggml/src/ggml-hexagon/htp/hvx-inverse.c create mode 100644 ggml/src/ggml-hexagon/htp/hvx-inverse.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-reduce.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-scale.h delete mode 100644 ggml/src/ggml-hexagon/htp/hvx-sigmoid.c create mode 100644 ggml/src/ggml-hexagon/htp/hvx-sigmoid.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-sqrt.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-types.h delete mode 100644 ggml/src/ggml-hexagon/htp/hvx-utils.c delete mode 100644 ggml/src/ggml-hexagon/htp/ops-utils.h diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 365a24b4965..cf1eb994c3e 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -42,12 +42,12 @@ #include "htp_iface.h" static size_t opt_ndev = 1; -static size_t opt_nhvx = 0; // use all -static int opt_arch = 0; // autodetect +static size_t opt_nhvx = 0; // use all +static int opt_arch = 0; // autodetect static int opt_etm = 0; static int opt_verbose = 0; static int opt_profile = 0; -static int opt_hostbuf = 1; +static int opt_hostbuf = 1; // hostbuf ON by default static int opt_experimental = 0; // Enable all stages by default @@ -1753,6 +1753,9 @@ static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) } static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) { + if (!opt_hostbuf) { + return ggml_backend_buffer_is_hexagon(b); + } return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer; } @@ -2302,6 +2305,16 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu return n_bufs; } +static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_CPY; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { req->op = HTP_OP_GET_ROWS; @@ -2557,6 +2570,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_CPY: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); } @@ -2858,6 +2875,27 @@ static bool ggml_hexagon_supported_buffers(ggml_hexagon_session *sess, const str return true; } +static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + // for now we can do f32 -> f16 and f16 -> f32 (without reshaping) + if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false; + if ( dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) return false; + + const bool sametype = (src0->type == dst->type); + const bool transposed = ggml_is_transposed(src0) || ggml_is_transposed(dst); + const bool sameshape = !transposed && ggml_are_same_shape(src0, dst); + + // can handle any shape and any same-type (pretty slow if reshaping is required) + if (sametype) return true; + + // cannot handle re-shaping and type conversion at the same time + if (!sameshape) return false; + + return true; +} + static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { auto sess = static_cast(dev->context); @@ -2936,6 +2974,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_get_rows(sess, op); break; + case GGML_OP_CPY: + supp = ggml_hexagon_supported_cpy(sess, op); + break; + default: break; } @@ -3061,7 +3103,7 @@ static ggml_backend_dev_t ggml_backend_hexagon_reg_get_device(ggml_backend_reg_t } static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, const char * name) { - if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) { + if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0 && opt_hostbuf) { ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_hexagon_device_get_extra_buffers_type; return (void *) fct; } @@ -3078,34 +3120,31 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4, "please update hexagon_type to match ggml_type"); + const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL"); const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); - + const char * str_opmask = getenv("GGML_HEXAGON_OPMASK"); + const char * str_opsync = getenv("GGML_HEXAGON_OPSYNC"); + const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); + const char * str_etm = getenv("GGML_HEXAGON_ETM"); + const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); + const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); + const char * str_arch = getenv("GGML_HEXAGON_ARCH"); + + opt_experimental = str_experimental ? atoi(str_experimental) : 0; opt_verbose = str_verbose ? atoi(str_verbose) : 0; - opt_profile = getenv("GGML_HEXAGON_PROFILE") != nullptr; - opt_etm = getenv("GGML_HEXAGON_ETM") != nullptr; - opt_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL") != nullptr; - - const char * str_opmask = getenv("GGML_HEXAGON_OPMASK"); - if (str_opmask != nullptr) { - opt_opmask = strtoul(str_opmask, NULL, 0); - } - opt_opsync = getenv("GGML_HEXAGON_OPSYNC") != nullptr; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask; + opt_opsync = str_opsync ? atoi(str_opsync) : 0; + opt_profile = str_profile ? atoi(str_profile) : 0; + opt_etm = str_etm ? atoi(str_etm) : 0; + opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; + opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; - const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); - if (str_ndev) { - opt_ndev = strtoul(str_ndev, NULL, 0); - if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { - opt_ndev = GGML_HEXAGON_MAX_SESSIONS; - } + if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { + opt_ndev = GGML_HEXAGON_MAX_SESSIONS; } - const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); - if (str_nhvx) { - opt_nhvx = strtoul(str_nhvx, NULL, 0); - } - - const char * str_arch = getenv("GGML_HEXAGON_ARCH"); if (str_arch) { if (str_arch[0] == 'v') { str_arch++; @@ -3113,8 +3152,6 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { opt_arch = strtoul(str_arch, NULL, 0); } - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1; - reg->context = new ggml_hexagon_registry(reg); HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req), diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 6a34a215fa4..e8ef203045c 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -17,11 +17,7 @@ add_library(${HTP_LIB} SHARED main.c htp_iface_skel.c worker-pool.c - htp-dma.c - hvx-sigmoid.c - hvx-inverse.c - hvx-exp.c - hvx-utils.c + hex-dma.c matmul-ops.c binary-ops.c unary-ops.c @@ -31,10 +27,12 @@ add_library(${HTP_LIB} SHARED flash-attn-ops.c set-rows-ops.c get-rows-ops.c + cpy-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE $,HTP_DEBUG=1,NDEBUG=1> + $,FARF_HIGH=1,> FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) build_idl(htp_iface.idl ${HTP_LIB}) diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index 88bd2ddc435..c3daf5adb2e 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -2,27 +2,20 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif #include -#include #include -#include -#include -#include + #include -#include #include +#include "hex-dma.h" +#include "hvx-utils.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" #include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" #define htp_act_preamble3 \ const uint32_t ne00 = src0->ne[0]; \ @@ -76,7 +69,7 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0, +static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0, const struct htp_tensor * src1, struct htp_tensor * dst, const int32_t * op_params, @@ -124,9 +117,9 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0, data_src1 += swapped ? 0 : nc_in_bytes; } - const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); @@ -175,9 +168,9 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0, float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); //swiglu(x) = x1 * sigmoid(x0) - hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, nc); - hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, - (const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, nc); + hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, + (const uint8_t *) src1_spad_ptr, nc); } dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, @@ -203,7 +196,7 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0, +static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0, const struct htp_tensor * src1, struct htp_tensor * dst, const int32_t * op_params, @@ -249,9 +242,9 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0, data_src1 += swapped ? 0 : nc_in_bytes; } - const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); @@ -304,18 +297,18 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0, float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); // x (src0_spad_data) = std::min(src0_p[k], limit); - hvx_min_scalar_f32((const uint8_t *) src0_spad_ptr, limit, (uint8_t *) src0_spad_ptr, nc); + hvx_min_scalar_f32((uint8_t *) src0_spad_ptr, (const uint8_t *) src0_spad_ptr, limit, nc); // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit); - hvx_clamp_scalar_f32((const uint8_t *) src1_spad_ptr, -limit, limit, (uint8_t *) src1_spad_ptr, nc); + hvx_clamp_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, -limit, limit, nc); // y (src1_spad_data) = y1 + 1.f - hvx_add_scalar_f32((const uint8_t *) src1_spad_ptr, 1.0, (uint8_t *) src1_spad_ptr, nc); + hvx_add_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, 1.0, nc); // x1 (dst_spad_data) = alpha * (x) - hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, alpha, (uint8_t *) dst_spad_ptr, nc); + hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, alpha, nc); // x2 (dst_spad_data) = sigmoid(x1) = 1/(1+exp(-x1)) - hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, nc); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // out = x * sigmoid(alpha * x) * (y + 1.f) - hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, - (const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc); + hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, + (const uint8_t *) src1_spad_ptr, nc); } dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, @@ -342,7 +335,7 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0, } -static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0, +static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, struct htp_tensor * dst, const int32_t * op_params, struct htp_spad * src0_spad, @@ -358,8 +351,8 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0, const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; - const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); const uint32_t src0_nrows = ne01 * ne02 * ne03; @@ -415,9 +408,9 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0, float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); // gelu = x * sigmoid(1.702 * x) // current implementation - hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, (float) 1.702, (uint8_t *) dst_spad_ptr, ne0); - hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0); - hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0); + hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -442,15 +435,15 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0, ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void unary_gelu_fp32(unsigned int n, unsigned int i, void * data) { +static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = (struct htp_ops_context *) data; - unary_gelu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, + unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); } -static void unary_silu_fp32_per_thread(const struct htp_tensor * src0, +static void unary_silu_f32_per_thread(const struct htp_tensor * src0, struct htp_tensor * dst, const int32_t * op_params, struct htp_spad * src0_spad, @@ -466,8 +459,8 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0, const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; - const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); const uint32_t src0_nrows = ne01 * ne02 * ne03; @@ -522,8 +515,8 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0, float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); // silu = x * sigmoid(x) - hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, ne0); - hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0); + hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -548,25 +541,25 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0, ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) { +static void unary_silu_f32(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = (struct htp_ops_context *) data; - unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, + unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); } -static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) { +static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = (struct htp_ops_context *) data; - glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, + glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); } -static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) { +static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = (struct htp_ops_context *) data; - glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, + glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); } -static int execute_op_activations_fp32(struct htp_ops_context * octx) { +static int execute_op_activations_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; const struct htp_tensor * src0 = &octx->src0; @@ -583,21 +576,21 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) { switch (octx->op) { case HTP_OP_UNARY_SILU: - act_op_func = unary_silu_fp32; + act_op_func = unary_silu_f32; op_type = "silu-f32"; break; case HTP_OP_GLU_SWIGLU: - act_op_func = glu_swiglu_fp32; + act_op_func = glu_swiglu_f32; op_type = "swiglu-f32"; break; case HTP_OP_GLU_SWIGLU_OAI: - act_op_func = glu_swiglu_oai_fp32; + act_op_func = glu_swiglu_oai_f32; op_type = "swiglu-oai-f32"; break; case HTP_OP_UNARY_GELU: - act_op_func = unary_gelu_fp32; + act_op_func = unary_gelu_f32; op_type = "gelu-f32"; break; default: @@ -617,9 +610,9 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) { src1_row_size = src0_row_size; } - const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); // VTCM scratchpads for all tensors // N rows per thread, padded to HVX vector size @@ -670,7 +663,7 @@ int op_activations(struct htp_ops_context * octx) { switch (octx->src0.type) { case HTP_TYPE_F32: - err = execute_op_activations_fp32(octx); + err = execute_op_activations_f32(octx); break; default: diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c index 8ed7f67d9c8..de22afe460e 100644 --- a/ggml/src/ggml-hexagon/htp/binary-ops.c +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -2,36 +2,25 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif - #include -#include #include -#include -#include -#include + #include -#include #include +#include "hex-dma.h" +#include "hvx-utils.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" #include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" -typedef void (*hvx_elemwise_f32_func)(const uint8_t * src0, - const uint8_t * src1, - uint8_t * data_dst, - const int num_elems); +typedef void (*hvx_elemwise_f32_func)(uint8_t * data_dst, const uint8_t * src0, const uint8_t * src1, const uint32_t num_elems); static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 }; -static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt }; +static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_aa, hvx_add_f32_aa, hvx_sub_f32_aa }; #define htp_binary_preamble \ const struct htp_tensor * src0 = &octx->src0; \ @@ -98,9 +87,8 @@ static void binary_job_f32_per_thread(struct htp_ops_context * octx, int is_aligned = 1; int opt_path = 0; - if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) || - (0 == htp_is_aligned((void *) dst->data, VLEN))) { - FARF(HIGH, "binary-f32: unaligned addresses in elementwise op, possibly slower execution\n"); + if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) || + (0 == hex_is_aligned((void *) dst->data, VLEN))) { is_aligned = 0; } if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { @@ -130,24 +118,24 @@ static void binary_job_f32_per_thread(struct htp_ops_context * octx, const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size; if (ir + 1 < src0_end_row) { - htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size); + hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1); if (src1_row_size == src0_row_size) { - htp_l2fetch(src1_ptr, 1, src1_row_size, src1_row_size); + hex_l2fetch(src1_ptr, src1_row_size, src1_row_size, 1); } } const uint32_t nr0 = ne00 / ne10; if (nr0 > 1) { if ((1 == is_aligned) && (nr0 == ne00)) { - hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0); + hvx_splat_f32_a(spad_data_th, *(float *) src1_ptr, nr0); } else { for (uint32_t r = 0; r < nr0; r++) { memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11); } } - func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, (uint8_t *) dst_ptr, ne00); + func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, ne00); } else { - func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00); + func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00); } src0_ptr += src0_row_size; @@ -185,11 +173,6 @@ static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx, uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) || - (0 == htp_is_aligned((void *) dst->data, VLEN))) { - FARF(HIGH, "add-id-f32: unaligned addresses, possibly slower execution\n"); - } - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; uint8_t * restrict data_dst = (uint8_t *) dst->data; @@ -210,9 +193,9 @@ static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx, const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11); if (ir + 1 < src0_end_row) { - htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size); + hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1); if (src1_row_size == src0_row_size) { - htp_l2fetch(src1_ptr + ne10, 1, src1_row_size, src1_row_size); + hex_l2fetch(src1_ptr + ne10, src1_row_size, src1_row_size, 1); } } @@ -221,9 +204,9 @@ static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx, for (uint32_t r = 0; r < nr0; r++) { memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10); } - func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data, (uint8_t *) dst_ptr, ne00); + func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data, ne00); } else { - func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00); + func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00); } } @@ -299,9 +282,9 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { const size_t dst_row_size = dst->nb[1]; // VTCM scratchpads for all tensors - octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads; + octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; + octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; + octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c new file mode 100644 index 00000000000..559ca183789 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -0,0 +1,251 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" + +struct htp_copy_context { + struct htp_ops_context * octx; + + uint32_t src0_type_size; + uint32_t src0_block_size; + + uint32_t dst_type_size; + uint32_t dst_block_size; + + uint32_t src0_blocks_per_row; + uint32_t dst_blocks_per_row; + + uint32_t src0_nrows_per_thread; + + void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith); +}; + +#define cpy_preamble \ + struct htp_tensor *src0 = &octx->src0; \ + struct htp_tensor *dst = &octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ + const uint32_t nr = ne01; + +static void cpy_thread_sametype_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { + cpy_preamble; + + // parallelize by src0 rows + const uint32_t dr = ct->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + + // copy by rows + for (uint32_t i03 = 0; i03 < ne03; i03++) { + for (uint32_t i02 = 0; i02 < ne02; i02++) { + #pragma unroll(2) + for (uint32_t i01 = ir0; i01 < ir1; i01++) { + uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; + uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + hex_l2fetch(src0_ptr, ne00 * ct->src0_type_size, nb01, 2); + hvx_copy_uu(dst_ptr, src0_ptr, ne00, ct->src0_type_size); + } + } + } +} + +static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith) { + cpy_preamble; + + // parallelize by src0 rows + const uint32_t dr = ct->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + + // dst counters + int64_t k10 = 0; + int64_t i11 = 0; + int64_t i12 = 0; + int64_t i13 = 0; + + // number of blocks in a row + const int64_t nk00 = ct->src0_blocks_per_row; + const int64_t nk0 = ct->dst_blocks_per_row; + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + k10 += nk00 * ir0; + while (k10 >= nk0) { + k10 -= nk0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + for (int64_t i01 = ir0; i01 < ir1; i01++) { + for (int64_t k00 = 0; k00 < nk00; k00++) { + const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); + char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); + memcpy(dst_ptr, src0_ptr, ct->dst_type_size); + + if (++k10 == nk0) { + k10 = 0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } + k10 += nk00 * (ne01 - ir1); + while (k10 >= nk0) { + k10 -= nk0; + if (++i11 == ne1) { + i11 = 0; + if (++i12 == ne2) { + i12 = 0; + if (++i13 == ne3) { + i13 = 0; + } + } + } + } + } + } +} + +static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { + cpy_preamble; + + // parallelize by src0 rows + const uint32_t dr = ct->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + + // copy by rows + for (uint32_t i03 = 0; i03 < ne03; i03++) { + for (uint32_t i02 = 0; i02 < ne02; i02++) { + #pragma unroll(2) + for (uint32_t i01 = ir0; i01 < ir1; i01++) { + uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; + uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + hex_l2fetch(src0_ptr, ne00 * sizeof(float), nb01, 2); + hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00); + } + } + } +} + +static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { + cpy_preamble; + + // parallelize by src0 rows + const uint32_t dr = ct->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + + // copy by rows + for (uint32_t i03 = 0; i03 < ne03; i03++) { + for (uint32_t i02 = 0; i02 < ne02; i02++) { + #pragma unroll(2) + for (uint32_t i01 = ir0; i01 < ir1; i01++) { + uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; + uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + hex_l2fetch(src0_ptr, ne00 * sizeof(__fp16), nb01, 2); + hvx_copy_f32_f16_uu(dst_ptr, src0_ptr, ne00); + } + } + } +} + +static void cpy_work_func(unsigned int n, unsigned int i, void *data) { + struct htp_copy_context *ct = (struct htp_copy_context *) data; + ct->copy(ct, ct->octx, n, i); +} + +int op_cpy(struct htp_ops_context * octx) { + cpy_preamble; + + struct htp_copy_context ct; + ct.octx = octx; + + switch (src0->type) { + case HTP_TYPE_F32: ct.src0_type_size = 4; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break; + case HTP_TYPE_F16: ct.src0_type_size = 2; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + switch (dst->type) { + case HTP_TYPE_F32: ct.dst_type_size = 4; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break; + case HTP_TYPE_F16: ct.dst_type_size = 2; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const bool sametype = (src0->type == dst->type); + const bool transposed = (nb00 > nb01) || (nb0 > nb1); + const bool sameshape = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3); + + const uint32_t n_jobs = MIN(nr, octx->n_threads); + ct.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + + if (sametype && sameshape) { + ct.copy = cpy_thread_sametype_sameshape; + } else if (sameshape) { + /**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32) + ct.copy = cpy_thread_f16_f32_sameshape; + else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16) + ct.copy = cpy_thread_f32_f16_sameshape; + else + return HTP_STATUS_NO_SUPPORT; + } else if (sametype) { + ct.copy = cpy_thread_sametype_reshape; + } else { + return HTP_STATUS_NO_SUPPORT; + } + + worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_jobs); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 04a7b843ce5..1de47d0f3d4 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -2,25 +2,20 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif #include -#include #include -#include -#include + #include #include +#include "hex-dma.h" +#include "hvx-utils.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" #include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" // Dot product of FP32 and FP16 vectors, accumulating to float static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { @@ -70,8 +65,8 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s)); - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s)); + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); hvx_vec_store_u(r, 4, rsum); } @@ -111,8 +106,8 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s)); - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s)); + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); hvx_vec_store_u(r, 4, rsum); } @@ -124,7 +119,7 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector S = hvx_vec_splat_fp16(s); + HVX_Vector S = hvx_vec_splat_f16(s); uint32_t i = 0; #pragma unroll(4) @@ -148,7 +143,7 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict if (nloe) { HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - hvx_vec_store_u(&ptr_y[i], nloe * 4, xy); + hvx_vec_store_a(&ptr_y[i], nloe * 4, xy); } } } @@ -225,18 +220,18 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const uint32_t DV = nev0; const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2); - const size_t size_q_row_padded = htp_round_up(size_q_row, 128); + const size_t size_q_row_padded = hex_round_up(size_q_row, 128); const size_t size_k_row = DK * sizeof(__fp16); const size_t size_v_row = DV * sizeof(__fp16); const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask - const size_t size_k_row_padded = htp_round_up(size_k_row, 128); - const size_t size_v_row_padded = htp_round_up(size_v_row, 128); + const size_t size_k_row_padded = hex_round_up(size_k_row, 128); + const size_t size_v_row_padded = hex_round_up(size_v_row, 128); const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; - const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith; @@ -272,8 +267,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in float M = -INFINITY; // maximum KQ value // Clear accumulator + hvx_splat_f32_a(spad_a, 0, DV); float * VKQ32 = (float *) spad_a; - memset(VKQ32, 0, DV * sizeof(float)); const __fp16 * mp_base = NULL; if (mask) { @@ -340,30 +335,30 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in // 2. Softcap if (logit_softcap != 0.0f) { - scores = hvx_vec_tanh_fp32(scores); - scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap)); + scores = hvx_vec_tanh_f32(scores); + scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap)); scores = Q6_Vsf_equals_Vqf32(scores); } // 3. Mask if (mask) { const __fp16 * mp = m_base + ic; - HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp; + HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp; - HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00); - HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16); + HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00); + HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16); - HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair)); + HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair)); - HVX_Vector slope_vec = hvx_vec_splat_fp32(slope); - HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec); + HVX_Vector slope_vec = hvx_vec_splat_f32(slope); + HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec); scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val)); scores = Q6_Vsf_equals_Vqf32(scores); } // 4. Online Softmax Update - HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores); - float m_block = hvx_vec_get_fp32(v_max); + HVX_Vector v_max = hvx_vec_reduce_max_f32(scores); + float m_block = hvx_vec_get_f32(v_max); float M_old = M; float M_new = (m_block > M) ? m_block : M; @@ -374,12 +369,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); S = S * ms; - HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new); + HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new); HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); - HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted)); + HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); - HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P); - float p_sum = hvx_vec_get_fp32(p_sum_vec); + HVX_Vector p_sum_vec = hvx_vec_reduce_sum_f32(P); + float p_sum = hvx_vec_get_f32(p_sum_vec); S += p_sum; // 5. Accumulate V @@ -484,9 +479,9 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1; if (dst->type == HTP_TYPE_F32) { - hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV); + hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV); } else if (dst->type == HTP_TYPE_F16) { - hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV); + hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV); } } } @@ -523,16 +518,16 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { octx->src3_div3 = init_fastdiv_values(mask->ne[3]); } - size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128); - size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128); - size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128); + size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128); + size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128); + size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128); size_t size_q_block = size_q_row_padded * 1; // single row for now size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; - size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); - size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 + size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 octx->src0_spad.size_per_thread = size_q_block * 1; octx->src1_spad.size_per_thread = size_k_block * 2; diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index 54321421eb5..a657cd2dcf2 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -2,14 +2,9 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif #include -#include #include -#include -#include + #include #include @@ -19,7 +14,6 @@ #include "htp-msg.h" #include "htp-ops.h" #include "hvx-utils.h" -#include "ops-utils.h" #define get_rows_preamble \ const uint32_t ne00 = octx->src0.ne[0]; \ @@ -72,7 +66,7 @@ static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03; const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3; - hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); } return HTP_STATUS_OK; diff --git a/ggml/src/ggml-hexagon/htp/htp-dma.c b/ggml/src/ggml-hexagon/htp/hex-dma.c similarity index 98% rename from ggml/src/ggml-hexagon/htp/htp-dma.c rename to ggml/src/ggml-hexagon/htp/hex-dma.c index 880c4542a0e..44e1be40c5d 100644 --- a/ggml/src/ggml-hexagon/htp/htp-dma.c +++ b/ggml/src/ggml-hexagon/htp/hex-dma.c @@ -1,4 +1,4 @@ -#include "htp-dma.h" +#include "hex-dma.h" #include #include diff --git a/ggml/src/ggml-hexagon/htp/htp-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h similarity index 99% rename from ggml/src/ggml-hexagon/htp/htp-dma.h rename to ggml/src/ggml-hexagon/htp/hex-dma.h index 32fd06e7d46..d1ddb0ecbf0 100644 --- a/ggml/src/ggml-hexagon/htp/htp-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -2,7 +2,6 @@ #define HTP_DMA_H #include -#include #include #include #include diff --git a/ggml/src/ggml-hexagon/htp/hex-dump.h b/ggml/src/ggml-hexagon/htp/hex-dump.h new file mode 100644 index 00000000000..e3badb57f92 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hex-dump.h @@ -0,0 +1,77 @@ +#ifndef HEX_DUMP_H +#define HEX_DUMP_H + +#include + +static inline void hex_dump_int8_line(char * pref, const int8_t * x, int n) { + char str[1024], *p = str, *p_end = str + sizeof(str); + p += snprintf(p, p_end - p, "%s: ", pref); + for (int i = 0; i < n && p < p_end; i++) { + p += snprintf(p, p_end - p, "%d, ", x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void hex_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) { + char str[1024], *p = str, *p_end = str + sizeof(str); + p += snprintf(p, p_end - p, "%s: ", pref); + for (int i = 0; i < n && p < p_end; i++) { + p += snprintf(p, p_end - p, "%d, ", x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void hex_dump_int32_line(char * pref, const int32_t * x, uint32_t n) { + char str[1024], *p = str, *p_end = str + sizeof(str); + p += snprintf(p, p_end - p, "%s: ", pref); + for (int i = 0; i < n; i++) { + p += snprintf(p, p_end - p, "%d, ", (int) x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void hex_dump_f16_line(char * pref, const __fp16 * x, uint32_t n) { + char str[1024], *p = str, *p_end = str + sizeof(str); + p += snprintf(p, p_end - p, "%s: ", pref); + for (int i = 0; i < n; i++) { + p += snprintf(p, p_end - p, "%.6f, ", (float) x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void hex_dump_f32_line(char * pref, const float * x, uint32_t n) { + char str[1024], *p = str, *p_end = str + sizeof(str); + p += snprintf(p, p_end - p, "%s: ", pref); + for (int i = 0; i < n; i++) { + p += snprintf(p, p_end - p, "%.6f, ", x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void hex_dump_f32(char * pref, const float * x, uint32_t n) { + uint32_t n0 = n / 16; + uint32_t n1 = n % 16; + + uint32_t i = 0; + for (; i < n0; i++) { + hex_dump_f32_line(pref, x + (16 * i), 16); + } + if (n1) { + hex_dump_f32_line(pref, x + (16 * i), n1); + } +} + +static inline void hex_dump_f16(char * pref, const __fp16 * x, uint32_t n) { + uint32_t n0 = n / 16; + uint32_t n1 = n % 16; + + uint32_t i = 0; + for (; i < n0; i++) { + hex_dump_f16_line(pref, x + (16 * i), 16); + } + if (n1) { + hex_dump_f16_line(pref, x + (16 * i), n1); + } +} + +#endif /* HEX_DUMP_H */ diff --git a/ggml/src/ggml-hexagon/htp/hex-fastdiv.h b/ggml/src/ggml-hexagon/htp/hex-fastdiv.h new file mode 100644 index 00000000000..b7b5867593f --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hex-fastdiv.h @@ -0,0 +1,37 @@ +#ifndef HEX_FASTDIV_H +#define HEX_FASTDIV_H + +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +struct fastdiv_values { + uint32_t mp; + uint32_t l; +}; + +static inline struct fastdiv_values init_fastdiv_values(uint32_t d) { + struct fastdiv_values result = { 0, 0 }; + // compute L = ceil(log2(d)); + while (result.l < 32 && ((uint32_t) 1 << result.l) < d) { + ++(result.l); + } + + result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1); + return result; +} + +static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) { + // Compute high 32 bits of n * mp + const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp) + // add n, apply bit shift + return (hi + n) >> vals->l; +} + +static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) { + return n - fastdiv(n, vals) * d; +} + +#endif /* HEX_FASTDIV_H */ diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h new file mode 100644 index 00000000000..fb8a25a3f20 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -0,0 +1,51 @@ +#ifndef HEX_UTILS_H +#define HEX_UTILS_H + +#include +#include + +#include "hexagon_types.h" + +#include "hex-fastdiv.h" +#include "hex-dump.h" + +#ifndef MAX +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#endif + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +static inline uint64_t hex_get_cycles() { + uint64_t cycles = 0; + asm volatile(" %0 = c15:14\n" : "=r"(cycles)); + return cycles; +} + +static inline uint64_t hex_get_pktcnt() { + uint64_t pktcnt; + asm volatile(" %0 = c19:18\n" : "=r"(pktcnt)); + return pktcnt; +} + +static inline int32_t hex_is_aligned(void * addr, uint32_t align) { + return ((size_t) addr & (align - 1)) == 0; +} + +static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { + uint32_t left_off = (size_t) addr & (chunk_size - 1); + uint32_t right_off = left_off + n; + return right_off <= chunk_size; +} + +static inline uint32_t hex_round_up(uint32_t n, uint32_t m) { + return m * ((n + m - 1) / m); +} + +static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) { + const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); + Q6_l2fetch_AP((void *) p, control); +} + +#endif /* HEX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 4bd0ea7a36a..a707d98239c 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -1,7 +1,7 @@ #ifndef HTP_CTX_H #define HTP_CTX_H -#include "htp-dma.h" +#include "hex-dma.h" #include "worker-pool.h" #include diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index 846d0617843..f49e8ee4478 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -63,6 +63,7 @@ enum htp_op { HTP_OP_SET_ROWS = 15, HTP_OP_SCALE = 16, HTP_OP_GET_ROWS = 17, + HTP_OP_CPY = 18, INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 7c828ae6362..602a2775a47 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -4,11 +4,12 @@ #include "htp-ctx.h" #include "htp-msg.h" #include "worker-pool.h" -#include "ops-utils.h" #include #include +#include + // ggml-common.h must be included prior to this header struct htp_spad { @@ -74,6 +75,14 @@ struct htp_ops_context { struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10 struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11 + struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01 + struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02 + struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03 + + struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00 + struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01 + struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02 + uint32_t flags; }; @@ -88,5 +97,6 @@ int op_rope(struct htp_ops_context * octx); int op_flash_attn_ext(struct htp_ops_context * octx); int op_set_rows(struct htp_ops_context * octx); int op_get_rows(struct htp_ops_context * octx); +int op_cpy(struct htp_ops_context * octx); #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-arith.h b/ggml/src/ggml-hexagon/htp/hvx-arith.h new file mode 100644 index 00000000000..3449739a4fa --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-arith.h @@ -0,0 +1,457 @@ +#ifndef HVX_ARITH_H +#define HVX_ARITH_H + +#include +#include +#include +#include + +#include "hvx-base.h" +#include "hex-utils.h" + +// +// Binary operations (add, mul, sub) +// + +#define hvx_arith_loop_body(dst_type, src0_type, src1_type, vec_store, vec_op) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src0_type * restrict vsrc0 = (src0_type *) src0; \ + src1_type * restrict vsrc1 = (src1_type *) src1; \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = vec_op(vsrc0[i], vsrc1[i]); \ + } \ + if (nloe) { \ + HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]); \ + vec_store((void *) &vdst[i], nloe * elem_size, v); \ + } \ + } while(0) + +#if __HVX_ARCH__ < 79 +#define HVX_OP_ADD(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_SUB(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)) +#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_ADD(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_SUB(a, b) Q6_Vsf_vsub_VsfVsf(a, b) +#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + +// ADD variants + +static inline void hvx_add_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src0 % 128 == 0); + assert((unsigned long) src1 % 128 == 0); + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD); +} + +static inline void hvx_add_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src0 % 128 == 0); + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD); +} + +static inline void hvx_add_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((unsigned long) src0 % 128 == 0); + assert((unsigned long) src1 % 128 == 0); + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD); +} + +static inline void hvx_add_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD); +} + +// SUB variants + +static inline void hvx_sub_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src0 % 128 == 0); + assert((unsigned long) src1 % 128 == 0); + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB); +} + +static inline void hvx_sub_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src0 % 128 == 0); + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB); +} + +static inline void hvx_sub_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((unsigned long) src0 % 128 == 0); + assert((unsigned long) src1 % 128 == 0); + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB); +} + +static inline void hvx_sub_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB); +} + +// MUL variants + +static inline void hvx_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src0 % 128 == 0); + assert((unsigned long) src1 % 128 == 0); + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL); +} + +static inline void hvx_mul_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src0 % 128 == 0); + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL); +} + +static inline void hvx_mul_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((unsigned long) src0 % 128 == 0); + assert((unsigned long) src1 % 128 == 0); + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL); +} + +static inline void hvx_mul_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL); +} + +// Dispatchers + +static inline void hvx_add_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) { + if (hex_is_aligned((void *) src1, 128)) { + hvx_add_f32_aa(dst, src0, src1, num_elems); + } else { + hvx_add_f32_au(dst, src0, src1, num_elems); + } + } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) { + hvx_add_f32_ua(dst, src0, src1, num_elems); + } else { + hvx_add_f32_uu(dst, src0, src1, num_elems); + } +} + +static inline void hvx_sub_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) { + if (hex_is_aligned((void *) src1, 128)) { + hvx_sub_f32_aa(dst, src0, src1, num_elems); + } else { + hvx_sub_f32_au(dst, src0, src1, num_elems); + } + } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) { + hvx_sub_f32_ua(dst, src0, src1, num_elems); + } else { + hvx_sub_f32_uu(dst, src0, src1, num_elems); + } +} + +static inline void hvx_mul_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) { + if (hex_is_aligned((void *) src1, 128)) { + hvx_mul_f32_aa(dst, src0, src1, num_elems); + } else { + hvx_mul_f32_au(dst, src0, src1, num_elems); + } + } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) { + hvx_mul_f32_ua(dst, src0, src1, num_elems); + } else { + hvx_mul_f32_uu(dst, src0, src1, num_elems); + } +} + +// Mul-Mul Optimized + +static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src0 % 128 == 0); + assert((unsigned long) src1 % 128 == 0); + assert((unsigned long) src2 % 128 == 0); + + HVX_Vector * restrict vdst = (HVX_Vector *) dst; + HVX_Vector * restrict vsrc0 = (HVX_Vector *) src0; + HVX_Vector * restrict vsrc1 = (HVX_Vector *) src1; + HVX_Vector * restrict vsrc2 = (HVX_Vector *) src2; + + const uint32_t elem_size = sizeof(float); + const uint32_t epv = 128 / elem_size; + const uint32_t nvec = num_elems / epv; + const uint32_t nloe = num_elems % epv; + + uint32_t i = 0; + + _Pragma("unroll(4)") + for (; i < nvec; i++) { + HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]); + vdst[i] = HVX_OP_MUL(v1, vsrc2[i]); + } + + if (nloe) { + HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]); + HVX_Vector v2 = HVX_OP_MUL(v1, vsrc2[i]); + hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2); + } +} + +// Scalar Operations + +#define hvx_scalar_loop_body(dst_type, src_type, vec_store, scalar_op_macro) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector v = vsrc[i]; \ + vdst[i] = scalar_op_macro(v); \ + } \ + if (nloe) { \ + HVX_Vector v = vsrc[i]; \ + v = scalar_op_macro(v); \ + vec_store((void *) &vdst[i], nloe * elem_size, v); \ + } \ + } while(0) + +#define HVX_OP_ADD_SCALAR(v) \ + ({ \ + const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \ + HVX_Vector out = HVX_OP_ADD(v, val_vec); \ + Q6_V_vmux_QVV(pred_inf, inf, out); \ + }) + +#define HVX_OP_MUL_SCALAR(v) HVX_OP_MUL(v, val_vec) +#define HVX_OP_SUB_SCALAR(v) HVX_OP_SUB(v, val_vec) + +// Add Scalar Variants + +static inline void hvx_add_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD_SCALAR); +} + +static inline void hvx_add_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); + assert((unsigned long) dst % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD_SCALAR); +} + +static inline void hvx_add_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD_SCALAR); +} + +static inline void hvx_add_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + static const float kInf = INFINITY; + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD_SCALAR); +} + +// Sub Scalar Variants + +static inline void hvx_sub_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB_SCALAR); +} + +static inline void hvx_sub_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) dst % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB_SCALAR); +} + +static inline void hvx_sub_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB_SCALAR); +} + +static inline void hvx_sub_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB_SCALAR); +} + +// Mul Scalar Variants + +static inline void hvx_mul_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL_SCALAR); +} + +static inline void hvx_mul_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) dst % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL_SCALAR); +} + +static inline void hvx_mul_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL_SCALAR); +} + +static inline void hvx_mul_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL_SCALAR); +} + +static inline void hvx_add_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { + hvx_add_scalar_f32_aa(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) dst, 128)) { + hvx_add_scalar_f32_au(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) src, 128)) { + hvx_add_scalar_f32_ua(dst, src, val, num_elems); + } else { + hvx_add_scalar_f32_uu(dst, src, val, num_elems); + } +} + +static inline void hvx_mul_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { + hvx_mul_scalar_f32_aa(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) dst, 128)) { + hvx_mul_scalar_f32_au(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) src, 128)) { + hvx_mul_scalar_f32_ua(dst, src, val, num_elems); + } else { + hvx_mul_scalar_f32_uu(dst, src, val, num_elems); + } +} + +static inline void hvx_sub_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { + hvx_sub_scalar_f32_aa(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) dst, 128)) { + hvx_sub_scalar_f32_au(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) src, 128)) { + hvx_sub_scalar_f32_ua(dst, src, val, num_elems); + } else { + hvx_sub_scalar_f32_uu(dst, src, val, num_elems); + } +} + +// MIN Scalar variants + +#define HVX_OP_MIN_SCALAR(v) Q6_Vsf_vmin_VsfVsf(val_vec, v) + +static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MIN_SCALAR); +} + +static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) dst % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MIN_SCALAR); +} + +static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MIN_SCALAR); +} + +static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MIN_SCALAR); +} + +static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { + hvx_min_scalar_f32_aa(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) dst, 128)) { + hvx_min_scalar_f32_au(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) src, 128)) { + hvx_min_scalar_f32_ua(dst, src, val, num_elems); + } else { + hvx_min_scalar_f32_uu(dst, src, val, num_elems); + } +} + +// CLAMP Scalar variants + +#define HVX_OP_CLAMP_SCALAR(v) \ + ({ \ + HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(v, max_vec); \ + HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(min_vec, v); \ + HVX_Vector tmp = Q6_V_vmux_QVV(pred_cap_right, max_vec, v); \ + Q6_V_vmux_QVV(pred_cap_left, min_vec, tmp); \ + }) + +static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { + const HVX_Vector min_vec = hvx_vec_splat_f32(min); + const HVX_Vector max_vec = hvx_vec_splat_f32(max); + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); +} + +static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { + const HVX_Vector min_vec = hvx_vec_splat_f32(min); + const HVX_Vector max_vec = hvx_vec_splat_f32(max); + assert((unsigned long) dst % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); +} + +static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { + const HVX_Vector min_vec = hvx_vec_splat_f32(min); + const HVX_Vector max_vec = hvx_vec_splat_f32(max); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); +} + +static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { + const HVX_Vector min_vec = hvx_vec_splat_f32(min); + const HVX_Vector max_vec = hvx_vec_splat_f32(max); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); +} + +static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) { + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { + hvx_clamp_scalar_f32_aa(dst, src, min, max, num_elems); + } else if (hex_is_aligned((void *) dst, 128)) { + hvx_clamp_scalar_f32_au(dst, src, min, max, num_elems); + } else if (hex_is_aligned((void *) src, 128)) { + hvx_clamp_scalar_f32_ua(dst, src, min, max, num_elems); + } else { + hvx_clamp_scalar_f32_uu(dst, src, min, max, num_elems); + } +} + +#undef HVX_OP_ADD +#undef HVX_OP_SUB +#undef HVX_OP_MUL +#undef hvx_arith_loop_body +#undef HVX_OP_ADD_SCALAR +#undef HVX_OP_SUB_SCALAR +#undef HVX_OP_MUL_SCALAR +#undef hvx_scalar_loop_body +#undef HVX_OP_MIN_SCALAR +#undef HVX_OP_CLAMP_SCALAR + +#endif // HVX_ARITH_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h new file mode 100644 index 00000000000..ffa6e18e645 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -0,0 +1,167 @@ +#ifndef HVX_BASE_H +#define HVX_BASE_H + +#include +#include + +#include "hex-utils.h" +#include "hvx-types.h" + +static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) { + // Rotate as needed. + v = Q6_V_vlalign_VVR(v, v, (size_t) dst); + + uint32_t left_off = (size_t) dst & 127; + uint32_t right_off = left_off + n; + + HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) dst); + HVX_VectorPred qr = Q6_Q_vsetq2_R(right_off); + + if (right_off > 128) { + Q6_vmem_QRIV(qr, (HVX_Vector *) dst + 1, v); + // all 1's + qr = Q6_Q_vcmp_eq_VbVb(v, v); + } + + ql_not = Q6_Q_or_QQn(ql_not, qr); + Q6_vmem_QnRIV(ql_not, (HVX_Vector *) dst, v); +} + +static inline void hvx_vec_store_a(void * restrict dst, uint32_t n, HVX_Vector v) { + assert((unsigned long) dst % 128 == 0); + HVX_VectorPred m = Q6_Q_or_QQn(Q6_Q_vsetq_R((unsigned long) dst), Q6_Q_vsetq2_R(n)); + Q6_vmem_QnRIV(m, (HVX_Vector *) dst, v); +} + +static inline HVX_Vector hvx_vec_splat_f32(float v) { + union { float f; uint32_t i; } u = { .f = v }; + return Q6_V_vsplat_R(u.i); +} + +static inline HVX_Vector hvx_vec_splat_f16(float v) { + union { __fp16 f; uint16_t i; } u = { .f = v }; + return Q6_Vh_vsplat_R(u.i); +} + +static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) { + // vdelta control to replicate first 4 bytes across all elements + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + + HVX_Vector ctrl = *(HVX_Vector *) repl; + return Q6_V_vdelta_VV(v, ctrl); +} + +static inline float hvx_vec_get_f32(HVX_Vector v) { + float __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 4, v); + return x; +} + +static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) { + // abs by clearing the fp16 sign bit + HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff); + return Q6_V_vand_VV(v, mask); +} + +static inline HVX_Vector hvx_vec_neg_f16(HVX_Vector v) { + // neg by setting the fp16 sign bit + HVX_Vector mask = Q6_Vh_vsplat_R(0x8000); + return Q6_V_vxor_VV(v, mask); +} + +static inline HVX_Vector hvx_vec_abs_f32(HVX_Vector v) { + // abs by clearing the fp32 sign bit + HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff); + return Q6_V_vand_VV(v, mask); +} + +static inline HVX_Vector hvx_vec_neg_f32(HVX_Vector v) { +#if __HVX_ARCH__ > 75 + return Q6_Vsf_vfneg_Vsf(v); +#else + // neg by setting the fp32 sign bit + HVX_Vector mask = Q6_V_vsplat_R(0x80000000); + return Q6_V_vxor_VV(v, mask); +#endif // __HVX_ARCH__ > 75 +} + +static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) { + const HVX_Vector vnan_exp = Q6_Vh_vsplat_R(0x7C00); + const HVX_Vector vnan_frac = Q6_Vh_vsplat_R(0x7FFF); + + // get pred of which are NaN, i.e., exponent bits all 1s and fraction bits non 0s + HVX_VectorPred p_exp = Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_exp), vnan_exp); + HVX_VectorPred p_frac = Q6_Q_not_Q(Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_frac), vnan_exp)); + return Q6_Q_and_QQ(p_exp, p_frac); +} + +static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) { + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero); + HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero); + HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0))); + +#if __HVX_ARCH__ < 79 + // replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0) + const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY); + HVX_VectorPred nan = hvx_vec_is_nan_f16(v); + v = Q6_V_vmux_QVV(nan, neg_inf, v); +#endif + + return v; +} + +/* Q6_Vsf_equals_Vw is only available on v73+.*/ +#if __HVX_ARCH__ < 73 +static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in) +{ + HVX_Vector const vzero = Q6_V_vzero(); + HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero); + HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in); + HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift); + HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift); + HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized); + HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp)); + return ret; +} + +static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in) +{ + return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in)); +} +#endif + +static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) { + // This looks complicated. + // Ideally should just be Q6_Vh_equals_Vhf(vin) + // but that instruction does not do proper rounding. + + // convert to qf32, multiplying by 1.0 in the process. + HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00)); + + // 'in-range' values are +/32752. + // add 192K to it, convert to sf + HVX_Vector v192K = Q6_V_vsplat_R(0x48400000); + HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K)); + HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K)); + + // for in-range cases, result is {163858... 229360} so the exponent is always 144. + // if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer. + // Start by <<10 to get the final 'sign' bit in bit 15... + vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10); + vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10); + + // now round down to 16 + return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0); +} + +#endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h new file mode 100644 index 00000000000..6b617b76177 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h @@ -0,0 +1,247 @@ +#ifndef HVX_COPY_H +#define HVX_COPY_H + +#include +#include +#include + +#include "hvx-base.h" + +#define hvx_splat_loop_body(dst_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + \ + uint32_t nvec = n / (128 / elem_size); \ + uint32_t nloe = n % (128 / elem_size); \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = src; \ + } \ + if (nloe) { \ + vec_store((void *) &vdst[i], nloe * elem_size, src); \ + } \ + } while(0) + +static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { + assert((unsigned long) dst % 128 == 0); + hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { + hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) { + hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float)); +} + +static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) { + hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float)); +} + +static inline void hvx_splat_f16_a(uint8_t * restrict dst, float v, uint32_t n) { + hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); +} + +static inline void hvx_splat_f16_u(uint8_t * restrict dst, float v, uint32_t n) { + hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); +} + +#define hvx_copy_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { vdst[i] = vsrc[i]; } \ + if (nloe) { \ + vec_store((void *) &vdst[i], nloe * elem_size, vsrc[i]); \ + } \ + } while(0) + +// Generic copy routines +static inline void hvx_copy_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_copy_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_copy_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) { + assert((unsigned long) dst % 128 == 0); + hvx_copy_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_copy_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) { + assert((unsigned long) src % 128 == 0); + hvx_copy_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_copy_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) { + hvx_copy_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +// copy n fp16 elements : source and destination are aligned to HVX Vector (128) +static inline void hvx_copy_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_aa(dst, src, n, sizeof(__fp16)); +} + +// copy n fp16 elements : source is aligned, destination is potentially unaligned +static inline void hvx_copy_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_au(dst, src, n, sizeof(__fp16)); +} + +// copy n fp16 elements : source is aligned, destination is potentially unaligned +static inline void hvx_copy_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_ua(dst, src, n, sizeof(__fp16)); +} + +// copy n fp16 elements : source is aligned, destination is potentially unaligned +static inline void hvx_copy_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_uu(dst, src, n, sizeof(__fp16)); +} + +// copy n fp32 elements : source and destination are aligned to HVX Vector (128) +static inline void hvx_copy_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_aa(dst, src, n, sizeof(float)); +} + +// copy n fp32 elements : source is aligned, destination is unaligned +static inline void hvx_copy_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_ua(dst, src, n, sizeof(float)); +} + +// copy n fp32 elements : source is unaligned, destination is aligned +static inline void hvx_copy_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_au(dst, src, n, sizeof(float)); +} + +// copy n fp32 elements : source is unaligned, destination unaligned +static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_uu(dst, src, n, sizeof(float)); +} + +//// fp32 -> fp16 + +#define hvx_copy_f16_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const HVX_Vector zero = Q6_V_vsplat_R(0); \ + \ + const uint32_t elem_size = sizeof(__fp16); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]); \ + } \ + if (nloe) { \ + HVX_Vector v = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]); \ + vec_store((void *) &vdst[i], nloe * elem_size, v); \ + } \ + } while(0) + +// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is aligned +static inline void hvx_copy_f16_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned +static inline void hvx_copy_f16_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned +static inline void hvx_copy_f16_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned +static inline void hvx_copy_f16_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +//// fp16 -> fp32 + +#define hvx_copy_f32_f16_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const HVX_Vector one = hvx_vec_splat_f16(1.0); \ + \ + const uint32_t elem_size = sizeof(__fp16); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (i = 0; i < nvec; ++i) { \ + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \ + vdst[i*2] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)); \ + vdst[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)); \ + } \ + \ + if (nloe) { \ + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \ + \ + HVX_Vector vd = Q6_V_lo_W(p); \ + i = 2 * i; \ + \ + if (nloe >= 32) { \ + vdst[i] = Q6_Vsf_equals_Vqf32(vd); \ + nloe -= 32; ++i; vd = Q6_V_hi_W(p); \ + } \ + \ + if (nloe) { \ + vd = Q6_Vsf_equals_Vqf32(vd); \ + hvx_vec_store_u(&vdst[i], nloe * sizeof(float), vd); \ + } \ + } \ + } while(0) + +// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is aligned +static inline void hvx_copy_f32_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is aligned +static inline void hvx_copy_f32_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is unaligned +static inline void hvx_copy_f32_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is unaligned +static inline void hvx_copy_f32_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +#endif // HVX_COPY_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-dump.h b/ggml/src/ggml-hexagon/htp/hvx-dump.h new file mode 100644 index 00000000000..e882227893e --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-dump.h @@ -0,0 +1,132 @@ +#ifndef HVX_DUMP_H +#define HVX_DUMP_H + +#include + +#include +#include + +#include "hex-utils.h" +#include "hvx-types.h" + +static void hvx_vec_dump_f16_n(char * pref, HVX_Vector v, uint32_t n) { + HVX_VectorAlias u = { .v = v }; + + const uint32_t n0 = n / 16; + const uint32_t n1 = n % 16; + int i = 0; + for (; i < n0; i++) { + hex_dump_f16_line(pref, u.fp16 + (16 * i), 16); + } + if (n1) { + hex_dump_f16_line(pref, u.fp16 + (16 * i), n1); + } +} + +static void hvx_vec_dump_f16(char * pref, HVX_Vector v) { + hvx_vec_dump_f16_n(pref, v, 64); +} + +static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) { + union { + HVX_Vector v; + float d[32]; + } u = { .v = v }; + + const uint32_t n0 = n / 16; + const uint32_t n1 = n % 16; + int i = 0; + for (; i < n0; i++) { + hex_dump_f32_line(pref, u.d + (16 * i), 16); + } + if (n1) { + hex_dump_f32_line(pref, u.d + (16 * i), n1); + } +} + +static void hvx_vec_dump_f32_hmt(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + float d[32]; + } u = { .v = v }; + + FARF(HIGH, "%s: %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\n", pref, u.d[0], u.d[1], + u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]); +} + +static void hvx_vec_dump_f32(char * pref, HVX_Vector v) { + hvx_vec_dump_f32_n(pref, v, 32); +} + +static void hvx_vec_dump_int32(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int32_t d[32]; + } u = { .v = v }; + + for (int i = 0; i < 32 / 16; i++) { + hex_dump_int32_line(pref, u.d + (16 * i), 16); + } +} + +static void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int32_t d[32]; + } u = { .v = v }; + + FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12], + u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]); +} + +static void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int8_t d[128]; + } u = { .v = v }; + + FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60], + u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]); +} + +static void hvx_vec_dump_int8(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int8_t d[128]; + } u = { .v = v }; + + for (int i = 0; i < 128 / 16; i++) { + hex_dump_int8_line(pref, u.d + (16 * i), 16); + } +} + +static void hvx_vec_dump_uint8(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + uint8_t d[128]; + } u = { .v = v }; + + for (int i = 0; i < 128 / 16; i++) { + hex_dump_uint8_line(pref, u.d + (16 * i), 16); + } +} + +static bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) { + typedef union { + HVX_Vector v; + int8_t d[128]; + } U; + + U u0 = { .v = v0 }; + U u1 = { .v = v1 }; + + for (int i = 0; i < n; i++) { + if (u0.d[i] != u1.d[i]) { + return false; + } + } + + return true; +} + +#endif /* HVX_DUMP_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.c b/ggml/src/ggml-hexagon/htp/hvx-exp.c deleted file mode 100644 index 21bf46a542f..00000000000 --- a/ggml/src/ggml-hexagon/htp/hvx-exp.c +++ /dev/null @@ -1,94 +0,0 @@ -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-but-set-variable" - -#include -#include -#include -#include - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" -#include "htp-ctx.h" -#include "htp-dma.h" -#include "htp-msg.h" -#include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" - -static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) { - const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp); - - HVX_Vector out = hvx_vec_exp_fp32(in_vec); - - return Q6_V_vmux_QVV(pred0, inf, out); -} - -void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_exp_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - // assert((0 == unaligned_addr) || (0 == num_elems_whole)); - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_exp_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - HVX_Vector vec_out = Q6_V_vzero(); - - static const float kInf = INFINITY; - static const float kMaxExp = 88.02f; // log(INF) - - const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); - const HVX_Vector inf = hvx_vec_splat_fp32(kInf); - - if (0 == unaligned_loop) { - HVX_Vector * p_vec_in1 = (HVX_Vector *) src; - HVX_Vector * p_vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - if (true == negate) { - HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++); - *p_vec_out++ = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); - } else { - *p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++, max_exp, inf); - } - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - - if (true == negate) { - HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in); - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); - } else { - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in, max_exp, inf); - } - } - } - - if (left_over > 0) { - const float * srcf = (float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - - if (true == negate) { - HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in); - - vec_out = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); - } else { - vec_out = hvx_vec_exp_fp32_guard(in, max_exp, inf); - } - - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out); - } -} diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.h b/ggml/src/ggml-hexagon/htp/hvx-exp.h new file mode 100644 index 00000000000..44dfe232a3d --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-exp.h @@ -0,0 +1,215 @@ +#ifndef HVX_EXP_H +#define HVX_EXP_H + +#include +#include + +#include "hvx-base.h" +#include "hvx-floor.h" + +#define EXP_COEFF_5 (0x39506967) // 0.000198757 = 1/(7!) +#define EXP_COEFF_4 (0x3AB743CE) // 0.0013982 = 1/(6!) +#define EXP_COEFF_3 (0x3C088908) // 0.00833345 = 1/(5!) +#define EXP_COEFF_2 (0x3D2AA9C1) // 0.416658 = 1/(4!) +#define EXP_COEFF_1 (0x3E2AAAAA) // 0.16666667 = 1/(3!) +#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!) +#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805 +#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408 +#define EXP_ONE (0x3f800000) // 1.0 +#define EXP_RANGE_R (0x41a00000) // 20.0 +#define EXP_RANGE_L (0xc1a00000) // -20.0 + +static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { + HVX_Vector z_qf32_v; + HVX_Vector x_v; + HVX_Vector x_qf32_v; + HVX_Vector y_v; + HVX_Vector k_v; + HVX_Vector f_v; + HVX_Vector epsilon_v; + HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E); + HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2); + HVX_Vector E_const; + HVX_Vector zero_v = Q6_V_vzero(); + + // exp(x) is approximated as follows: + // f = floor(x/ln(2)) = floor(x*log2(e)) + // epsilon = x - f*ln(2) + // exp(x) = exp(epsilon+f*ln(2)) + // = exp(epsilon)*exp(f*ln(2)) + // = exp(epsilon)*2^f + // + // Since epsilon is close to zero, it can be approximated with its Taylor series: + // exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+... + // Preserving the first eight elements, we get: + // exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7 + // = 1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2 + + HVX_Vector temp_v = in_vec; + + // Clamp inputs to (-20.0, 20.0) + HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R)); + HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec); + + in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v); + in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v); + + epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec); + epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v); + + // f_v is the floating point result and k_v is the integer result + f_v = hvx_vec_floor_f32(epsilon_v); + k_v = hvx_vec_truncate_f32(f_v); + + x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v); + + // x = x - f_v * logn2; + epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2); + x_qf32_v = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v); + // normalize before every QFloat's vmpy + x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v); + + // z = x * x; + z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v); + z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v); + + x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); + + // y = E4 + E5 * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_5); + y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v); + E_const = Q6_V_vsplat_R(EXP_COEFF_4); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E3 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_3); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E2 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_2); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E1 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_1); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E0 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_0); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = x + y * z; + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = y + 1.0; + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE)); + + // insert exponents + // y = ldexpf(y, k); + // y_v += k_v; // qf32 + // modify exponent + + y_v = Q6_Vsf_equals_Vqf32(y_v); + + // add k_v to the exponent of y_v + HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1); + + y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1); + y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent); + + // exponent cannot be negative; if overflow is detected, result is set to zero + HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent); + + y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN); + + y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v); + + return y_v; +} + +static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) { + const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp); + + HVX_Vector out = hvx_vec_exp_f32(in_vec); + + return Q6_V_vmux_QVV(pred0, inf, out); +} + +static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) { + int left_over = num_elems & (VLEN_FP32 - 1); + int num_elems_whole = num_elems - left_over; + + int unaligned_addr = 0; + int unaligned_loop = 0; + if ((0 == hex_is_aligned((void *) src, VLEN)) || (0 == hex_is_aligned((void *) dst, VLEN))) { + unaligned_addr = 1; + } + // assert((0 == unaligned_addr) || (0 == num_elems_whole)); + if ((1 == unaligned_addr) && (num_elems_whole != 0)) { + unaligned_loop = 1; + } + + HVX_Vector vec_out = Q6_V_vzero(); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.02f; // log(INF) + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + if (0 == unaligned_loop) { + HVX_Vector * p_vec_in1 = (HVX_Vector *) src; + HVX_Vector * p_vec_out = (HVX_Vector *) dst; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + if (true == negate) { + HVX_Vector neg_vec_in = hvx_vec_neg_f32(*p_vec_in1++); + *p_vec_out++ = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf); + } else { + *p_vec_out++ = hvx_vec_exp_f32_guard(*p_vec_in1++, max_exp, inf); + } + } + } else { + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); + + if (true == negate) { + HVX_Vector neg_vec_in = hvx_vec_neg_f32(in); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf); + } else { + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(in, max_exp, inf); + } + } + } + + if (left_over > 0) { + const float * srcf = (float *) src + num_elems_whole; + float * dstf = (float *) dst + num_elems_whole; + + HVX_Vector in = *(HVX_UVector *) srcf; + + if (true == negate) { + HVX_Vector neg_vec_in = hvx_vec_neg_f32(in); + + vec_out = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf); + } else { + vec_out = hvx_vec_exp_f32_guard(in, max_exp, inf); + } + + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out); + } +} + +#endif /* HVX_EXP_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-floor.h b/ggml/src/ggml-hexagon/htp/hvx-floor.h new file mode 100644 index 00000000000..6a1bfde5675 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-floor.h @@ -0,0 +1,100 @@ +#ifndef HVX_FLOOR_H +#define HVX_FLOOR_H + +#include +#include + +#include "hvx-base.h" + +#define IEEE_VSF_EXPLEN (8) +#define IEEE_VSF_EXPBIAS (127) +#define IEEE_VSF_EXPMASK (0xFF) +#define IEEE_VSF_MANTLEN (23) +#define IEEE_VSF_MANTMASK (0x7FFFFF) +#define IEEE_VSF_MIMPMASK (0x800000) + +static inline HVX_Vector hvx_vec_truncate_f32(HVX_Vector in_vec) { + HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK); + HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK); + HVX_Vector const_zero_v = Q6_V_vzero(); + + HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec); + + HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN; + expval_v &= IEEE_VSF_EXPMASK; + expval_v -= IEEE_VSF_EXPBIAS; + + // negative exp == fractional value + HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v); + + HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v; // fractional bits - exp shift + + HVX_Vector mant_v = in_vec & mask_mant_v; // obtain mantissa + HVX_Vector vout = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v); // add implicit 1.0 + + vout = Q6_Vw_vasr_VwVw(vout, rshift_v); // shift to obtain truncated integer + vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout); // expval<0 -> 0 + + HVX_Vector neg_vout = -vout; + + vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout); // handle negatives + + return (vout); +} + +static inline HVX_Vector hvx_vec_floor_f32(HVX_Vector in_vec) { + HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK); + HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK); + HVX_Vector const_mnlen_v = Q6_V_vsplat_R(IEEE_VSF_MANTLEN); + HVX_Vector const_zero_v = Q6_V_vzero(); + HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000); // -1 IEEE vsf + + HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec); + + HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN; + expval_v &= IEEE_VSF_EXPMASK; + expval_v -= IEEE_VSF_EXPBIAS; + + HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v); + HVX_VectorPred q_expltmn = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v); + HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v); + HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec); + + // if expval < 0 (q_negexp) // <0, floor is 0 + // if vin > 0 + // floor = 0 + // if vin < 0 + // floor = -1 + // if expval < mant_len (q_expltmn) // >0, but fraction may exist + // get sign (q_negative) + // mask >> expval // fraction bits to mask off + // vout = ~(mask) // apply mask to remove fraction + // if (qneg) // negative floor is one less (more, sign bit for neg) + // vout += ((impl_mask) >> expval) + // if (mask && vin) + // vout = vin + // else // already an integer + // ; // no change + + // compute floor + mask_mant_v >>= expval_v; + HVX_Vector neg_addin_v = mask_impl_v >> expval_v; + HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v); + HVX_Vector vout = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec); + + HVX_Vector mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v); // chk if bits set + HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v); + + HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v); // frac bits to clear + HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v); // clear frac bits + + vout = in_vec; + vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout); // expval0 -> 0 + vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout); // expval<0 x<0 -> -1 + + return vout; +} + +#endif /* HVX_FLOOR_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-inverse.c b/ggml/src/ggml-hexagon/htp/hvx-inverse.c deleted file mode 100644 index 4d70634fcd4..00000000000 --- a/ggml/src/ggml-hexagon/htp/hvx-inverse.c +++ /dev/null @@ -1,72 +0,0 @@ -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-but-set-variable" - -#include -#include -#include -#include - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" -#include "htp-ctx.h" -#include "htp-dma.h" -#include "htp-msg.h" -#include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" - -static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) { - HVX_Vector out = hvx_vec_inverse_fp32(v_sf); - - HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask); - const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out); - - return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out); -} - -void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_inverse_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - // assert((0 == unaligned_addr) || (0 == num_elems_whole)); - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_inverse_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - static const uint32_t kNanInfMask = 0x7f800000; - const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(kNanInfMask); - - if (0 == unaligned_loop) { - HVX_Vector * p_vec_in = (HVX_Vector *) src; - HVX_Vector * p_vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - *p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++, nan_inf_mask); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in, nan_inf_mask); - } - } - - if (left_over > 0) { - const float * srcf = (float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - HVX_Vector out = hvx_vec_inverse_fp32_guard(in, nan_inf_mask); - - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out); - } -} diff --git a/ggml/src/ggml-hexagon/htp/hvx-inverse.h b/ggml/src/ggml-hexagon/htp/hvx-inverse.h new file mode 100644 index 00000000000..49f3efabbcc --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-inverse.h @@ -0,0 +1,176 @@ +#ifndef HVX_INVERSE_H +#define HVX_INVERSE_H + +#include + +#include +#include +#include +#include +#include + +#include "hvx-base.h" + +// ==================================================== +// FUNCTION: 1/(x+1) y(0) = 1, y(0.5) = 0.6667, y(1) = 0.5 +// Order:3; continuity: True; Ends forced: True +// Mode: unsigned; Result fractional bits: 14 +// Peak Error: 1.1295e-04 Rms Error: 2.8410e-05 Mean Error: 1.1370e-05 +// 32769 -32706 31252 -10589 +// 32590 -30635 22793 -4493 +// 32066 -27505 16481 -2348 +// 31205 -24054 11849 -1306 + +static inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) { + // input is 0..0xffff representing 0.0 .. 1.0 + HVX_Vector p; + p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull); + p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull); + p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull); + p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull); + return p; // signed result, 14 fractional bits +} + +// Find reciprocal of fp16. +// (1) first, convert to fp32, multiplying by 1.0; this is done to +// handle denormals. Ignoring sign and zero, result should be at +// least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000) +// (exponent in range [103,143]) +// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly +// (3) put this, along with '253-exp' (exp from (1)) together to make an qf32 +// (4) convert that to fp16 +// (5) put sign back in. Also, if the original value (w/o sign) was <0x81, replace +// the result with the max value. +static inline HVX_Vector hvx_vec_inverse_f16(HVX_Vector vals) { + HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF); + HVX_Vector avals = Q6_V_vand_VV(vals, em_mask); + HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(avals, vals); + // is too small to 1/x ? for 'standard' fp16, this would be 0x101 + HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals); + + HVX_VectorPair to_qf32 = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00)); // *1.0 + HVX_Vector to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32)); + HVX_Vector to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32)); + + // bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 14..5 of a 16-bit vector + HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9)); + // likewise extract the upper 16 from each, containing the exponents in range 103..142 + HVX_Vector exp_u16 = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0); + //Get exponent in IEEE 32-bit representation + exp_u16 = Q6_Vuh_vlsr_VuhR(exp_u16, 7); + + // so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane + // We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0) + // Use poly to transform to 1/x, with 14 fractional bits + // + HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16); + + HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm); //count leading zeros + + // Get mantissa for 16-bit represenation + HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF)); + + //Compute Reciprocal Exponent + HVX_Vector exp_recip = + Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1))); + //Convert it for 16-bit representation + exp_recip = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15)); + exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10); + + //Merge exponent and mantissa for reciprocal + HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip); + // map 'small' inputs to standard largest value 0x7bff + recip = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip); + // add sign back + recip = Q6_V_vandor_VQR(recip, is_neg, 0x80008000); + return recip; +} + +static inline HVX_Vector hvx_vec_inverse_f32(HVX_Vector v_sf) { + HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3); + HVX_Vector two_sf = hvx_vec_splat_f32(2.0); + + // First approximation + HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf); + + HVX_Vector r_qf; + + // Refine + r_qf = Q6_Vqf32_vmpy_VsfVsf( + i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf, v_sf))))); + r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32( + r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf)))); + r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32( + r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf)))); + + return Q6_Vsf_equals_Vqf32(r_qf); +} + +static inline HVX_Vector hvx_vec_inverse_f32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) { + HVX_Vector out = hvx_vec_inverse_f32(v_sf); + + HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask); + const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out); + + return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out); +} + +#define hvx_inverse_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + \ + const uint32_t nvec = n / VLEN_FP32; \ + const uint32_t nloe = n % VLEN_FP32; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask); \ + } \ + if (nloe) { \ + HVX_Vector v = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, v); \ + } \ + } while(0) + +static inline void hvx_inverse_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_inverse_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_inverse_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_inverse_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_inverse_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_inverse_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_inverse_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_inverse_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_inverse_f32(uint8_t * restrict dst, uint8_t * restrict src, const int num_elems) { + if ((unsigned long) dst % 128 == 0) { + if ((unsigned long) src % 128 == 0) { + hvx_inverse_f32_aa(dst, src, num_elems); + } else { + hvx_inverse_f32_au(dst, src, num_elems); + } + } else { + if ((unsigned long) src % 128 == 0) { + hvx_inverse_f32_ua(dst, src, num_elems); + } else { + hvx_inverse_f32_uu(dst, src, num_elems); + } + } +} + +#endif // HVX_INVERSE_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-reduce.h b/ggml/src/ggml-hexagon/htp/hvx-reduce.h new file mode 100644 index 00000000000..8845fe73ea1 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-reduce.h @@ -0,0 +1,225 @@ +#ifndef HVX_REDUCE_H +#define HVX_REDUCE_H + +#include +#include +#include +#include + +#include "hex-utils.h" +#include "hvx-base.h" +#include "hvx-types.h" + +static inline HVX_Vector hvx_vec_reduce_sum_n_i32(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // int32 + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(sum, width); // rotate right + sum = Q6_Vw_vadd_VwVw(sum_t, sum); // elementwise sum + width = width << 1; + } + return sum; +} + +static inline HVX_Vector hvx_vec_reduce_sum_i32(HVX_Vector in) { + return hvx_vec_reduce_sum_n_i32(in, 32); +} + +static inline HVX_Vector hvx_vec_reduce_sum_n_qf32(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // fp32 nbytes + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width); // rotate right + sum = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t); // elementwise sum + width = width << 1; + } + return sum; +} + +static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) { + return hvx_vec_reduce_sum_n_qf32(in, 32); +} + +static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // fp32 nbytes + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(sum, width); // rotate right + sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum + width = width << 1; + } + return sum; +} + +static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) { + return hvx_vec_reduce_sum_n_f32(in, 32); +} + +static inline HVX_Vector hvx_vec_reduce_max_f16(HVX_Vector in) { + unsigned total = 128; // total vec nbytes + unsigned width = 2; // fp16 nbytes + + HVX_Vector _max = in, _max_t; + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +static inline HVX_Vector hvx_vec_reduce_max2_f16(HVX_Vector in, HVX_Vector _max) { + unsigned total = 128; // total vec nbytes + unsigned width = 2; // fp32 nbytes + + HVX_Vector _max_t; + + _max = Q6_Vhf_vmax_VhfVhf(in, _max); + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +static inline HVX_Vector hvx_vec_reduce_max_f32(HVX_Vector in) { + unsigned total = 128; // total vec nbytes + unsigned width = 4; // fp32 nbytes + + HVX_Vector _max = in, _max_t; + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +static inline HVX_Vector hvx_vec_reduce_max2_f32(HVX_Vector in, HVX_Vector _max) { + unsigned total = 128; // total vec nbytes + unsigned width = 4; // fp32 nbytes + + HVX_Vector _max_t; + + _max = Q6_Vsf_vmax_VsfVsf(in, _max); + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +#define hvx_reduce_loop_body(src_type, init_vec, pad_vec, vec_op, reduce_op, scalar_reduce) \ + do { \ + src_type * restrict vsrc = (src_type *) src; \ + HVX_Vector acc = init_vec; \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = num_elems / epv; \ + const uint32_t nloe = num_elems % epv; \ + \ + uint32_t i = 0; \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + acc = vec_op(acc, vsrc[i]); \ + } \ + if (nloe) { \ + const float * srcf = (const float *) src + i * epv; \ + HVX_Vector in = *(HVX_UVector *) srcf; \ + HVX_Vector temp = Q6_V_valign_VVR(in, pad_vec, nloe * elem_size); \ + acc = vec_op(acc, temp); \ + } \ + HVX_Vector v = reduce_op(acc); \ + return scalar_reduce(v); \ + } while(0) + +#define HVX_REDUCE_MAX_OP(acc, val) Q6_Vsf_vmax_VsfVsf(acc, val) +#define HVX_REDUCE_SUM_OP(acc, val) Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(acc), val) +#define HVX_SUM_SQ_OP(acc, val) Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(val, val)) +#define HVX_REDUCE_MAX_SCALAR(v) hvx_vec_get_f32(v) +#define HVX_REDUCE_SUM_SCALAR(v) hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(v)) + +// Max variants + +static inline float hvx_reduce_max_f32_a(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]); + assert((unsigned long) src % 128 == 0); + hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR); +} + +static inline float hvx_reduce_max_f32_u(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]); + hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR); +} + +static inline float hvx_reduce_max_f32(const uint8_t * restrict src, const int num_elems) { + if (hex_is_aligned((void *) src, 128)) { + return hvx_reduce_max_f32_a(src, num_elems); + } else { + return hvx_reduce_max_f32_u(src, num_elems); + } +} + +// Sum variants + +static inline float hvx_reduce_sum_f32_a(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = Q6_V_vsplat_R(0); + assert((unsigned long) src % 128 == 0); + hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR); +} + +static inline float hvx_reduce_sum_f32_u(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = Q6_V_vsplat_R(0); + hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR); +} + +static inline float hvx_reduce_sum_f32(const uint8_t * restrict src, const int num_elems) { + if (hex_is_aligned((void *) src, 128)) { + return hvx_reduce_sum_f32_a(src, num_elems); + } else { + return hvx_reduce_sum_f32_u(src, num_elems); + } +} + +// Sum of squares variants + +static inline float hvx_sum_of_squares_f32_a(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = Q6_V_vsplat_R(0); + assert((uintptr_t) src % 128 == 0); + hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR); +} + +static inline float hvx_sum_of_squares_f32_u(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = Q6_V_vsplat_R(0); + hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR); +} + +static inline float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) { + if (hex_is_aligned((void *) src, 128)) { + return hvx_sum_of_squares_f32_a(src, num_elems); + } else { + return hvx_sum_of_squares_f32_u(src, num_elems); + } +} + +#undef hvx_reduce_loop_body +#undef HVX_REDUCE_MAX_OP +#undef HVX_REDUCE_SUM_OP +#undef HVX_REDUCE_MAX_SCALAR +#undef HVX_REDUCE_SUM_SCALAR +#undef HVX_SUM_SQ_OP + +#endif /* HVX_REDUCE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-scale.h b/ggml/src/ggml-hexagon/htp/hvx-scale.h new file mode 100644 index 00000000000..c65c98639dc --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-scale.h @@ -0,0 +1,133 @@ +#ifndef HVX_SCALE_H +#define HVX_SCALE_H + +#include +#include +#include + +#include "hvx-base.h" + +#define hvx_scale_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + HVX_Vector vs = hvx_vec_splat_f32(scale); \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; ++i) { \ + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); \ + vdst[i] = Q6_Vsf_equals_Vqf32(v); \ + } \ + if (nloe) { \ + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); \ + vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v)); \ + } \ + } while(0) + +static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + assert((size_t) dst % 128 == 0); + assert((size_t) src % 128 == 0); + hvx_scale_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_scale_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + assert((size_t) dst % 128 == 0); + hvx_scale_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_scale_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + assert((size_t) src % 128 == 0); + hvx_scale_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + hvx_scale_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + if (((size_t) dst & 127) == 0) { + if (((size_t) src & 127) == 0) { + hvx_scale_f32_aa(dst, src, n, scale); + } else { + hvx_scale_f32_au(dst, src, n, scale); + } + } else { + if (((size_t) src & 127) == 0) { + hvx_scale_f32_ua(dst, src, n, scale); + } else { + hvx_scale_f32_uu(dst, src, n, scale); + } + } +} + +#define hvx_scale_offset_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + HVX_Vector vs = hvx_vec_splat_f32(scale); \ + HVX_Vector vo = hvx_vec_splat_f32(offset); \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; ++i) { \ + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \ + vdst[i] = Q6_Vsf_equals_Vqf32(v); \ + } \ + if (nloe) { \ + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \ + vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v)); \ + } \ + } while(0) + +static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + assert((size_t) dst % 128 == 0); + assert((size_t) src % 128 == 0); + hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_scale_offset_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + assert((size_t) dst % 128 == 0); + hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_scale_offset_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + assert((size_t) src % 128 == 0); + hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + if (((size_t) dst & 127) == 0) { + if (((size_t) src & 127) == 0) { + hvx_scale_offset_f32_aa(dst, src, n, scale, offset); + } else { + hvx_scale_offset_f32_au(dst, src, n, scale, offset); + } + } else { + if (((size_t) src & 127) == 0) { + hvx_scale_offset_f32_ua(dst, src, n, scale, offset); + } else { + hvx_scale_offset_f32_uu(dst, src, n, scale, offset); + } + } +} + +#endif // HVX_SCALE_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c deleted file mode 100644 index 15ac64697c7..00000000000 --- a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +++ /dev/null @@ -1,49 +0,0 @@ -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-but-set-variable" - -#include -#include -#include -#include - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" -#include "htp-ctx.h" -#include "htp-dma.h" -#include "htp-msg.h" -#include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" - -#if 0 -// Reference algo used in hvx-utils -static void fast_sigmoid_f32(const float* restrict src, float* restrict dst, const int num_elems) -{ - const float c1 = 0.03138777; - const float c2 = 0.276281267; - const float c_log2f = 1.442695022; - - int32_t store_ints[32]; - float store_floats[3][32]; - - for (int i = 0; i < num_elems; i++) - { - float v = src0[i]; - - v *= c_log2f*0.5; - int intPart = (int)v; - float x = (v - intPart); - float xx = x * x; - float v1 = c_log2f + c2 * xx; - float v2 = x + xx * c1 * x; - float v3 = (v2 + v1); - *((int*)&v3) += intPart << 24; - float v4 = v2 - v1; - float v5 = v3 - v4; - float res = v3 / v5; - - dst[i] = res; - } -} -#endif diff --git a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h new file mode 100644 index 00000000000..1b4aaff0c92 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h @@ -0,0 +1,114 @@ +#ifndef HVX_SIGMOID_H +#define HVX_SIGMOID_H + +#include "hvx-base.h" + +#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022 +#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777 +#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267 +#define FAST_SIGMOID_C3 (0x3f000000) // 0.5 + +static inline HVX_Vector hvx_vec_fast_sigmoid_f32(HVX_Vector v) { + v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F)); + v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3)); + + HVX_Vector in_int = hvx_vec_truncate_f32(Q6_Vsf_equals_Vqf32(v)); + HVX_Vector x = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int)); + HVX_Vector xx = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x); + + HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2)); + v1 = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F)); + + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1)); + v2 = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx); + v2 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x); + + HVX_Vector v3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1)); + HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1); + v3_exponent = Q6_Vuw_vlsr_VuwR(v3_exponent, 24); + v3_exponent = Q6_Vw_vadd_VwVw(in_int, v3_exponent); + v3 = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24); + + HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1)); + HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4)); + + HVX_Vector res = hvx_vec_inverse_f32(v5); + res = Q6_Vqf32_vmpy_VsfVsf(v3, res); + + return Q6_Vsf_equals_Vqf32(res); +} + +static inline HVX_Vector hvx_vec_fast_sigmoid_f32_guard(HVX_Vector v, + HVX_Vector one, + HVX_Vector max_exp, + HVX_Vector min_exp) { + const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v); + const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp); + + HVX_Vector out = hvx_vec_fast_sigmoid_f32(v); + out = Q6_V_vmux_QVV(pred_max, out, one); + return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero()); +} + +static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) { + // tanh(x) = 2 * sigmoid(2x) - 1 + HVX_Vector two = hvx_vec_splat_f32(2.0f); + HVX_Vector one = hvx_vec_splat_f32(1.0f); + HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two); + + HVX_Vector max_exp = hvx_vec_splat_f32(87.f); + HVX_Vector min_exp = hvx_vec_splat_f32(-87.f); + + HVX_Vector sig2x = hvx_vec_fast_sigmoid_f32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp); + + HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two); + res = Q6_Vqf32_vsub_Vqf32Vsf(res, one); + return Q6_Vsf_equals_Vqf32(res); +} + +#define hvx_sigmoid_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const HVX_Vector one = hvx_vec_splat_f32(1.f); \ + const HVX_Vector max_exp = hvx_vec_splat_f32(87.f); \ + const HVX_Vector min_exp = hvx_vec_splat_f32(-87.f); \ + \ + const uint32_t epv = 128 / sizeof(float); \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \ + } \ + if (nloe) { \ + HVX_Vector tmp = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \ + vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \ + } \ + } while(0) + +static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_sigmoid_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_sigmoid_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_sigmoid_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_sigmoid_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_sigmoid_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +#endif /* HVX_SIGMOID_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-sqrt.h b/ggml/src/ggml-hexagon/htp/hvx-sqrt.h new file mode 100644 index 00000000000..28ee9f68d3e --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-sqrt.h @@ -0,0 +1,60 @@ +#ifndef HVX_SQRT_H +#define HVX_SQRT_H + +#include +#include + +#include "hex-utils.h" + +#include "hvx-base.h" + +#define RSQRT_CONST 0x5f3759df // Constant for fast inverse square root calculation +#define RSQRT_ONE_HALF 0x3f000000 // 0.5 +#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5 + +static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) { + //Algorithm : + // x2 = input*0.5 + // y = * (long *) &input + // y = 0x5f3759df - (y>>2) + // y = y*(threehalfs - x2*y*y) + + HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST); + HVX_Vector onehalf = Q6_V_vsplat_R(RSQRT_ONE_HALF); + HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES); + + HVX_Vector x2, y, ypower2, temp; + + x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf); + x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero()); + + y = Q6_Vw_vasr_VwR(in_vec, 1); + y = Q6_Vw_vsub_VwVw(rsqrtconst, y); + + // 1st iteration + ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y); + ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); + temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); + temp = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp)); + + // 2nd iteration + y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero()); + ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y); + ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); + temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp); + + // 3rd iteration + y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero()); + ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y); + ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); + temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp); + + return Q6_Vsf_equals_Vqf32(temp); +} + +#endif /* HVX_SQRT_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-types.h b/ggml/src/ggml-hexagon/htp/hvx-types.h new file mode 100644 index 00000000000..d495a59fbea --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-types.h @@ -0,0 +1,36 @@ +#ifndef HVX_TYPES_H +#define HVX_TYPES_H + +#include +#include + +#include + +#define SIZEOF_FP32 (4) +#define SIZEOF_FP16 (2) +#define VLEN (128) +#define VLEN_FP32 (VLEN / SIZEOF_FP32) +#define VLEN_FP16 (VLEN / SIZEOF_FP16) + +typedef union { + HVX_Vector v; + uint8_t b[VLEN]; + uint16_t h[VLEN_FP16]; + uint32_t w[VLEN_FP32]; + __fp16 fp16[VLEN_FP16]; + float fp32[VLEN_FP32]; +} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias; + +typedef struct { + HVX_Vector v[2]; +} HVX_Vector_x2; + +typedef struct { + HVX_Vector v[4]; +} HVX_Vector_x4; + +typedef struct { + HVX_Vector v[8]; +} HVX_Vector_x8; + +#endif /* HVX_TYPES_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.c b/ggml/src/ggml-hexagon/htp/hvx-utils.c deleted file mode 100644 index 29d73b8622b..00000000000 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.c +++ /dev/null @@ -1,1020 +0,0 @@ -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-but-set-variable" - -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif - -#include -#include -#include -#include -#include -#include -#include -#include - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" -#include "hvx-utils.h" - -#define htp_binary_ops_preamble \ - int step_of_4 = num_elems >> 7; \ - int step_of_2 = (num_elems - step_of_4 * VLEN_FP32 * 4) >> 6; \ - int step_of_1 = (num_elems - step_of_4 * VLEN_FP32 * 4 - step_of_2 * VLEN_FP32 * 2) >> 5; \ - int remaining = num_elems - step_of_4 * VLEN_FP32 * 4 - step_of_2 * VLEN_FP32 * 2 - step_of_1 * VLEN_FP32; \ - \ - const uint8_t * restrict src0_curr = src0; \ - const uint8_t * restrict src1_curr = src1; \ - uint8_t * restrict dst_curr = dst; - -void hvx_mul_f32(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) || - (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_mul_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_mul_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - - bool handled_leftover = false; - if (0 == unaligned_loop) { - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0; - HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, *vec_in2++); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); - } - } else { - int step_of_1 = num_elems_whole >> 5; // divby 32, because 32 float = 128 bytes per HVX vector - int leftover_size = left_over * sizeof(float); - - - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0; - HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1; - HVX_UVector * restrict vec_out = (HVX_UVector *) dst; - - HVX_Vector slinep; - HVX_Vector slinec; - HVX_Vector sline; - HVX_Vector sline2p; - HVX_Vector sline2c; - HVX_Vector sline2; - - slinep = *vec_in1++; - sline2p = *vec_in2++; - #pragma unroll(4) - for (int i = step_of_1 - 1; i > 0; i--) { - slinec = *vec_in1++; - sline2c = *vec_in2++; - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src0); - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1); - - *((HVX_UVector *) (vec_out++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, sline2)); - slinep = slinec; - sline2p = sline2c; - } - if (step_of_1 > 1) { - slinec = htp_is_aligned(vec_in1, VLEN) && left_over == 0 ? slinep : *vec_in1++; - sline2c = htp_is_aligned(vec_in2, VLEN) && left_over == 0 ? sline2p : *vec_in2++; - - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src0); - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1); - *((HVX_UVector *) (vec_out++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, sline2)); - slinep = slinec; - sline2p = sline2c; - } - if (left_over > 0) { - slinec = (is_in_one_chunk(vec_in1, leftover_size, VLEN) ? slinep : *vec_in1++); - - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src0); - sline2c = (is_in_one_chunk(vec_in2, leftover_size, VLEN) ? sline2p : *vec_in2++); - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1); - - HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(sline, sline2); - hvx_vec_store_u(vec_out, leftover_size, Q6_Vsf_equals_Vqf32(out)); - handled_leftover = true; - } - } - - - if (left_over > 0 && !handled_leftover) { - const float * src0f = (const float *) src0 + num_elems_whole; - const float * src1f = (const float *) src1 + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in1 = *(HVX_UVector *) src0f; - HVX_Vector in2 = *(HVX_UVector *) src1f; - - HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in1, in2); - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); - } -} - -void hvx_mul_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems) { - htp_binary_ops_preamble; - - for (int i = 0; i < step_of_4; i++) { - HVX_Vector v1a = *(HVX_Vector *) src0_curr; - - HVX_Vector v1b = *(HVX_Vector *) src1_curr; - - HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b); - - HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); - - HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN); - - HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); - - HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN); - - HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN); - - src0_curr += 4 * VLEN; - - HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(v3a, v3b); - - *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); - - HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN); - - *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3); - - HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v4a, v4b); - - src1_curr += 4 * VLEN; - - *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4); - - dst_curr += 4 * VLEN; - } - - for (int i = 0; i < step_of_2; i++) { - HVX_Vector v1a = *(HVX_Vector *) src0_curr; - - HVX_Vector v1b = *(HVX_Vector *) src1_curr; - - HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b); - - HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); - - src0_curr += 2 * VLEN; - - HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b); - - src1_curr += 2 * VLEN; - - *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); - - dst_curr += 2 * VLEN; - } - - for (int i = 0; i < step_of_1; i++) { - HVX_Vector va = *(HVX_Vector *) src0_curr; - - src0_curr += VLEN; - - HVX_Vector vb = *(HVX_Vector *) src1_curr; - - src1_curr += VLEN; - - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(va, vb); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v); - - dst_curr += VLEN; - } - - if (remaining > 0) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr); - hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v)); - } -} - -void hvx_mul_mul_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - const uint8_t * restrict src2, - uint8_t * restrict dst, - const int num_elems) { - const uint8_t * restrict src0_curr = src0; - const uint8_t * restrict src1_curr = src1; - const uint8_t * restrict src2_curr = src2; - uint8_t * restrict dst_curr = dst; - - int step_of_2 = num_elems >> 6; - int step_of_1 = (num_elems - step_of_2 * VLEN_FP32 * 2) >> 5; - int remaining = num_elems - step_of_2 * VLEN_FP32 * 2 - step_of_1 * VLEN_FP32; - - for (int i = 0; i < step_of_2; i++) { - HVX_Vector v1a = *(HVX_Vector *) src0_curr; - HVX_Vector v1b = *(HVX_Vector *) src1_curr; - HVX_Vector v1c = *(HVX_Vector *) src2_curr; - - HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v1_ = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b); - HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1_), v1c); - - HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); - - HVX_Vector v2c = *(HVX_Vector *) (src2_curr + VLEN); - - src0_curr += 2 * VLEN; - - HVX_Vector v2_ = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b); - HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2_), v2c); - - src1_curr += 2 * VLEN; - src2_curr += 2 * VLEN; - - *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); - - dst_curr += 2 * VLEN; - } - for (int i = 0; i < step_of_1; i++) { - HVX_Vector va = *(HVX_Vector *) src0_curr; - src0_curr += VLEN; - - HVX_Vector vb = *(HVX_Vector *) src1_curr; - src1_curr += VLEN; - - HVX_Vector vc = *(HVX_Vector *) src2_curr; - src2_curr += VLEN; - - HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(va, vb); - HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1), vc); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v2); - dst_curr += VLEN; - } - if (remaining > 0) { - HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr); - HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1), *(HVX_Vector *) src2_curr); - hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v2)); - } -} - -void hvx_add_f32(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) || - (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_add_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_add_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - if (0 == unaligned_loop) { - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0; - HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, *vec_in2++); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32); - HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32); - - HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in1, in2); - - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); - } - } - - if (left_over > 0) { - const float * src0f = (const float *) src0 + num_elems_whole; - const float * src1f = (const float *) src1 + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in1 = *(HVX_UVector *) src0f; - HVX_Vector in2 = *(HVX_UVector *) src1f; - - HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in1, in2); - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); - } -} - -void hvx_add_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems) { - htp_binary_ops_preamble; - - for (int i = 0; i < step_of_4; i++) { - HVX_Vector v1a = *(HVX_Vector *) src0_curr; - - HVX_Vector v1b = *(HVX_Vector *) src1_curr; - - HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v1 = Q6_Vqf32_vadd_VsfVsf(v1a, v1b); - - HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); - - HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN); - - HVX_Vector v2 = Q6_Vqf32_vadd_VsfVsf(v2a, v2b); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); - - HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN); - - HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN); - - src0_curr += 4 * VLEN; - - HVX_Vector v3 = Q6_Vqf32_vadd_VsfVsf(v3a, v3b); - - *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); - - HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN); - - *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3); - - HVX_Vector v4 = Q6_Vqf32_vadd_VsfVsf(v4a, v4b); - - src1_curr += 4 * VLEN; - - *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4); - - dst_curr += 4 * VLEN; - } - for (int i = 0; i < step_of_2; i++) { - HVX_Vector v1a = *(HVX_Vector *) src0_curr; - - HVX_Vector v1b = *(HVX_Vector *) src1_curr; - - HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v1 = Q6_Vqf32_vadd_VsfVsf(v1a, v1b); - - HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); - - src0_curr += 2 * VLEN; - - HVX_Vector v2 = Q6_Vqf32_vadd_VsfVsf(v2a, v2b); - - src1_curr += 2 * VLEN; - - *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); - - dst_curr += 2 * VLEN; - } - for (int i = 0; i < step_of_1; i++) { - HVX_Vector va = *(HVX_Vector *) src0_curr; - - src0_curr += VLEN; - - HVX_Vector vb = *(HVX_Vector *) src1_curr; - - src1_curr += VLEN; - - HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(va, vb); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v); - - dst_curr += VLEN; - } - if (remaining > 0) { - HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr); - hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v)); - } -} - -void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) { - size_t left_over = num_elems & (VLEN_FP32 - 1); - size_t num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_add_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_add_scalar_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - static const float kInf = INFINITY; - const HVX_Vector inf = hvx_vec_splat_fp32(kInf); - HVX_Vector val_vec = hvx_vec_splat_fp32(val); - - if (0 == unaligned_loop) { - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *vec_in1++; - const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in); - HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(in, val_vec); - v = Q6_Vsf_equals_Vqf32(v); - v = Q6_V_vmux_QVV(pred_inf, inf, v); - *vec_out++ = v; - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - - const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in); - HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec); - out = Q6_Vsf_equals_Vqf32(out); - out = Q6_V_vmux_QVV(pred_inf, inf, out); - - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = out; - } - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - - const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in); - HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec); - out = Q6_Vsf_equals_Vqf32(out); - out = Q6_V_vmux_QVV(pred_inf, inf, out); - - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out); - } -} - -void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) { - size_t left_over = num_elems & (VLEN_FP32 - 1); - size_t num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_mul_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_mul_scalar_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - HVX_Vector val_vec = hvx_vec_splat_fp32(val); - bool handled_leftover = false; - if (0 == unaligned_loop) { - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, val_vec); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); - } - } else { - int step_of_1 = num_elems >> 5; // divby 32, because 32 float = 128 bytes per HVX vector - int leftover_size = left_over * sizeof(float); - - HVX_Vector * input_v_ptr = (HVX_Vector *) src; - HVX_UVector * output_v_ptr = (HVX_UVector *) dst; - - HVX_Vector slinep; - HVX_Vector slinec; - HVX_Vector sline; - - slinep = *input_v_ptr++; - - #pragma unroll(4) - for (int i = step_of_1 - 1; i > 0; i--) { - slinec = *input_v_ptr++; - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src); - *((HVX_UVector *) (output_v_ptr++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec)); - /* Prepare slinep for next iteration */ - slinep = slinec; - } - - if (step_of_1 > 0) { - slinec = htp_is_aligned(input_v_ptr, VLEN) && left_over == 0 ? slinep : *input_v_ptr++; - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src); - *((HVX_UVector *) (output_v_ptr++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec)); - - slinep = slinec; - } - - if (leftover_size > 0) { - slinec = (is_in_one_chunk(input_v_ptr, leftover_size, VLEN) ? slinep : *input_v_ptr++); - - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src); - - HVX_Vector sout = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec)); - hvx_vec_store_u(output_v_ptr, leftover_size, sout); - handled_leftover = true; - } - } - - if (left_over > 0 && !handled_leftover) { - const float * srcf = (const float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - - HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, val_vec); - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); - } -} - -void hvx_sub_f32(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems) { - size_t left_over = num_elems & (VLEN_FP32 - 1); - size_t num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) || - (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_sub_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_sub_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - if (0 == unaligned_loop) { - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0; - HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*vec_in1++, *vec_in2++); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32); - HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32); - - HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in1, in2); - - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); - } - } - - if (left_over > 0) { - const float * src0f = (const float *) src0 + num_elems_whole; - const float * src1f = (const float *) src1 + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in1 = *(HVX_UVector *) src0f; - HVX_Vector in2 = *(HVX_UVector *) src1f; - - HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in1, in2); - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); - } -} - -void hvx_sub_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems) { - htp_binary_ops_preamble; - - for (int i = 0; i < step_of_4; i++) { - HVX_Vector v1a = *(HVX_Vector *) src0_curr; - - HVX_Vector v1b = *(HVX_Vector *) src1_curr; - - HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v1 = Q6_Vqf32_vsub_VsfVsf(v1a, v1b); - - HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); - - HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN); - - HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v2a, v2b); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); - - HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN); - - HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN); - - src0_curr += 4 * VLEN; - - HVX_Vector v3 = Q6_Vqf32_vsub_VsfVsf(v3a, v3b); - - *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); - - HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN); - - *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3); - - HVX_Vector v4 = Q6_Vqf32_vsub_VsfVsf(v4a, v4b); - - src1_curr += 4 * VLEN; - - *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4); - - dst_curr += 4 * VLEN; - } - for (int i = 0; i < step_of_2; i++) { - HVX_Vector v1a = *(HVX_Vector *) src0_curr; - - HVX_Vector v1b = *(HVX_Vector *) src1_curr; - - HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v1 = Q6_Vqf32_vsub_VsfVsf(v1a, v1b); - - HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); - - src0_curr += 2 * VLEN; - - HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v2a, v2b); - - src1_curr += 2 * VLEN; - - *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); - - dst_curr += 2 * VLEN; - } - for (int i = 0; i < step_of_1; i++) { - HVX_Vector va = *(HVX_Vector *) src0_curr; - - src0_curr += VLEN; - - HVX_Vector vb = *(HVX_Vector *) src1_curr; - - src1_curr += VLEN; - - HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(va, vb); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v); - - dst_curr += VLEN; - } - if (remaining > 0) { - HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr); - hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v)); - } -} - -void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) { - size_t left_over = num_elems & (VLEN_FP32 - 1); - size_t num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_sub_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_sub_scalar_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - HVX_Vector val_vec = hvx_vec_splat_fp32(val); - - if (0 == unaligned_loop) { - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*vec_in1++, val_vec); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - - HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in, val_vec); - - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); - } - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - - HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in, val_vec); - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); - } -} - -float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - if (0 == htp_is_aligned((void *) src, VLEN)) { - FARF(HIGH, "hvx_sum_of_squares_f32: unaligned address in hvx op, possibly slower execution\n"); - } - - assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole)); - - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src; - - HVX_Vector sum_vec_acc = Q6_V_vsplat_R(0x00000000); - HVX_Vector zero_vec = Q6_V_vsplat_R(0x00000000); - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1, *vec_in1); - sum_vec_acc = Q6_Vqf32_vadd_Vqf32Vqf32(sum_vec_acc, v); - vec_in1++; - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - - HVX_Vector vec_left = *(HVX_UVector *) srcf; - - HVX_Vector vec_left_sq = Q6_Vqf32_vmpy_VsfVsf(vec_left, vec_left); - HVX_Vector vec_tmp = Q6_V_valign_VVR(vec_left_sq, zero_vec, left_over * SIZEOF_FP32); - - sum_vec_acc = Q6_Vqf32_vadd_Vqf32Vqf32(sum_vec_acc, vec_tmp); - } - - HVX_Vector v = hvx_vec_qf32_reduce_sum(sum_vec_acc); - return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v)); -} - -float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if (0 == htp_is_aligned((void *) src, VLEN)) { - FARF(HIGH, "hvx_self_sum_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_self_sum_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000); - HVX_Vector zero_vec = Q6_V_vsplat_R(0x00000000); - - if (0 == unaligned_loop) { - HVX_Vector * vec_in = (HVX_Vector *) src; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - // sum_vec = Q6_Vqf32_vadd_Vqf32Vsf(sum_vec, *vec_in++); - sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), *vec_in++); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - - sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), in); - } - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - - HVX_Vector vec_left = *(HVX_UVector *) srcf; - HVX_Vector vec_tmp = Q6_V_valign_VVR(vec_left, zero_vec, left_over * SIZEOF_FP32); - // sum_vec = Q6_Vqf32_vadd_Vqf32Vsf(sum_vec, vec_tmp); - sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), vec_tmp); - } - - HVX_Vector v = hvx_vec_qf32_reduce_sum(sum_vec); - return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v)); -} - -float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if (0 == htp_is_aligned((void *) src, VLEN)) { - FARF(HIGH, "hvx_self_max_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_self_max_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - HVX_Vector vec_max = hvx_vec_splat_fp32(((const float *) src)[0]); - HVX_Vector vec_first = hvx_vec_splat_fp32(((const float *) src)[0]); - - if (0 == unaligned_loop) { - HVX_Vector * restrict vec_in = (HVX_Vector *) src; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, *vec_in++); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - - vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, in); - } - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - - HVX_Vector temp = Q6_V_valign_VVR(in, vec_first, left_over * SIZEOF_FP32); - vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, temp); - } - - HVX_Vector v = hvx_vec_reduce_max_fp32(vec_max); - return hvx_vec_get_fp32(v); -} - -void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) { - size_t left_over = num_elems & (VLEN_FP32 - 1); - size_t num_elems_whole = num_elems - left_over; - int unalign_address = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_min_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); - unalign_address = 1; - } - - const float * src_f = (const float *) src; - - HVX_Vector vec_min = hvx_vec_splat_fp32(val); - - if(unalign_address == 0){ - HVX_Vector * restrict vec_in = (HVX_Vector *) src; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++); - *vec_out++ = (min_clamp); - } - }else{ - HVX_UVector * restrict vec_in = (HVX_Vector *) src; - HVX_UVector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++); - *vec_out++ = (min_clamp); - } - } - - if (left_over > 0 ) { - const float * srcf = (const float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_UVector in = *(HVX_UVector *) srcf; - - HVX_UVector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, in); - - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, (min_clamp)); - } -} - -void hvx_clamp_scalar_f32(const uint8_t * restrict src, - const float limit_left, - const float limit_right, - uint8_t * restrict dst, - const int num_elems) { - size_t left_over = num_elems & (VLEN_FP32 - 1); - size_t num_elems_whole = num_elems - left_over; - - int unalign_address = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_clamp_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); - unalign_address = 1; - } - - HVX_Vector range_left = hvx_vec_splat_fp32(limit_left); - HVX_Vector range_right = hvx_vec_splat_fp32(limit_right); - - if(unalign_address == 0){ - HVX_Vector * restrict vec_in = (HVX_Vector *) src; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in_vec = *vec_in++; - HVX_Vector temp_v = in_vec; - - HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right); - HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec); - - in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v); - in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec); - - *vec_out++ = in_vec; - } - - }else{ - - HVX_UVector * restrict vec_in = (HVX_UVector *) src; - HVX_UVector * restrict vec_out = (HVX_UVector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in_vec = *vec_in++; - HVX_Vector temp_v = in_vec; - - HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right); - HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec); - - in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v); - in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec); - - *vec_out++ = in_vec; - } - - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in_vec = *(HVX_UVector *) srcf; - - HVX_Vector temp_v = in_vec; - - HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right); - HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec); - - in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v); - in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec); - - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec); - } -} - - diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 22876e6dbaa..7b79a5ea322 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -1,1353 +1,17 @@ #ifndef HVX_UTILS_H #define HVX_UTILS_H -#include "ops-utils.h" - -#include -#include - -#define SIZEOF_FP32 (4) -#define SIZEOF_FP16 (2) -#define VLEN (128) -#define VLEN_FP32 (VLEN / SIZEOF_FP32) -#define VLEN_FP16 (VLEN / SIZEOF_FP16) - -typedef union { - HVX_Vector v; - uint8_t b[VLEN]; - uint16_t h[VLEN_FP16]; - uint32_t w[VLEN_FP32]; - __fp16 fp16[VLEN_FP16]; - float fp32[VLEN_FP32]; -} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias; - -/* Q6_Vsf_equals_Vw is only available on v73+.*/ -#if __HVX_ARCH__ < 73 -static inline HVX_Vector int32_to_qfloat(HVX_Vector const in) -{ - HVX_Vector const vzero = Q6_V_vzero(); - HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero); - HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in); - HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift); - HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift); - HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized); - HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp)); - return ret; -} - -static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in) -{ - return Q6_Vsf_equals_Vqf32(int32_to_qfloat(in)); -} -#endif - -static inline HVX_Vector hvx_vec_splat_fp32(float v) { - union { - float f; - uint32_t i; - } fp32 = { .f = v }; - - return Q6_V_vsplat_R(fp32.i); -} - -static inline HVX_Vector hvx_vec_splat_fp16(float v) { - union { - __fp16 f; - uint16_t i; - } fp16 = { .f = v }; - - return Q6_Vh_vsplat_R(fp16.i); -} - -static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) { - // Rotate as needed. - v = Q6_V_vlalign_VVR(v, v, (size_t) addr); - - uint32_t left_off = (size_t) addr & 127; - uint32_t right_off = left_off + n; - - HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) addr); - HVX_VectorPred qr = Q6_Q_vsetq2_R(right_off); - - if (right_off > 128) { - Q6_vmem_QRIV(qr, (HVX_Vector *) addr + 1, v); - // all 1's - qr = Q6_Q_vcmp_eq_VbVb(v, v); - } - - ql_not = Q6_Q_or_QQn(ql_not, qr); - Q6_vmem_QnRIV(ql_not, (HVX_Vector *) addr, v); -} - -static inline void hvx_vec_store_a(void * ptr, size_t n, HVX_Vector v) { - assert((unsigned long) ptr % 128 == 0); - - HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) ptr); - HVX_VectorPred qr = Q6_Q_vsetq2_R(n); - ql_not = Q6_Q_or_QQn(ql_not, qr); - Q6_vmem_QnRIV(ql_not, (HVX_Vector *) ptr, v); -} - -static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) { - // vdelta control to replicate first 4 bytes across all elements - static const uint8_t __attribute__((aligned(128))) repl[128] = { - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - }; - - HVX_Vector ctrl = *(HVX_Vector *) repl; - return Q6_V_vdelta_VV(v, ctrl); -} - -// copy n fp16 elements : source and destination are aligned to HVX Vector (128) -static inline void hvx_copy_fp16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - HVX_Vector * restrict vsrc = (HVX_Vector *) src; - - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - - uint32_t nvec = n / 64; - uint32_t nloe = n % 64; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - HVX_Vector v = vsrc[i]; - vdst[i] = v; - } - - if (nloe) { - HVX_Vector v = vsrc[i]; - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v); - } -} - -// copy n fp16 elements : source is aligned, destination is potentially unaligned -static inline void hvx_copy_fp16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_UVector * restrict vdst = (HVX_UVector *) dst; - HVX_Vector * restrict vsrc = (HVX_Vector *) src; - - assert((unsigned long) src % 128 == 0); - - uint32_t nvec = n / 64; - uint32_t nloe = n % 64; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - HVX_Vector v = vsrc[i]; - vdst[i] = v; - } - - if (nloe) { - HVX_Vector v = vsrc[i]; - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v); - } -} - -// copy n fp16 elements : source is aligned, destination is potentially unaligned -static inline void hvx_copy_fp16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - HVX_UVector * restrict vsrc = (HVX_UVector *) src; - - assert((unsigned long) dst % 128 == 0); - - uint32_t nvec = n / 64; - uint32_t nloe = n % 64; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - HVX_Vector v = vsrc[i]; - vdst[i] = v; - } - - if (nloe) { - HVX_Vector v = vsrc[i]; - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v); - } -} - -// copy n fp32 elements : source and destination are aligned to HVX Vector (128) -static inline void hvx_copy_fp32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - HVX_Vector * restrict vsrc = (HVX_Vector *) src; - - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - - uint32_t nvec = n / 32; - uint32_t nloe = n % 32; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - HVX_Vector v = vsrc[i]; - vdst[i] = v; - } - - if (nloe) { - HVX_Vector v = vsrc[i]; - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); - } -} - -// copy n fp32 elements : source is aligned, destination is unaligned -static inline void hvx_copy_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_UVector * restrict vdst = (HVX_UVector *) dst; - HVX_Vector * restrict vsrc = (HVX_Vector *) src; - - assert((unsigned long) src % 128 == 0); - - uint32_t nvec = n / 32; - uint32_t nloe = n % 32; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - HVX_Vector v = vsrc[i]; - vdst[i] = v; - } - - if (nloe) { - HVX_Vector v = vsrc[i]; - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); - } -} - -// copy n fp32 elements : source is unaligned, destination is aligned -static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - HVX_UVector * restrict vsrc = (HVX_UVector *) src; - - assert((unsigned long) dst % 128 == 0); - - uint32_t nvec = n / 32; - uint32_t nloe = n % 32; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - HVX_Vector v = vsrc[i]; - vdst[i] = v; - } - - if (nloe) { - HVX_Vector v = vsrc[i]; - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); - } -} - -// copy n fp32 elements : source is unaligned, destination unaligned -static inline void hvx_copy_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_UVector * restrict vdst = (HVX_UVector *) dst; - HVX_UVector * restrict vsrc = (HVX_UVector *) src; - - assert((unsigned long) dst % 128 == 0); - - uint32_t nvec = n / 32; - uint32_t nloe = n % 32; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - HVX_Vector v = vsrc[i]; - vdst[i] = v; - } - - if (nloe) { - HVX_Vector v = vsrc[i]; - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); - } -} - -// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned -static inline void hvx_copy_fp16_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16 - HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32 - - const HVX_Vector zero = Q6_V_vsplat_R(0); - - uint32_t nvec = n / 64; - uint32_t nloe = n % 64; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements - HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements - HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); - vdst[i] = Q6_Vh_vdeal_Vh(s_hf); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements - HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements - HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); - } -} - -// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned -static inline void hvx_copy_fp16_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16 - HVX_Vector * restrict vsrc = (HVX_Vector *) src; // fp32 - - const HVX_Vector zero = Q6_V_vsplat_R(0); - - uint32_t nvec = n / 64; - uint32_t nloe = n % 64; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements - HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements - HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); - vdst[i] = Q6_Vh_vdeal_Vh(s_hf); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements - HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements - HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); - } -} - -// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned -static inline void hvx_copy_fp16_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_Vector * restrict vdst = (HVX_Vector *) dst; // fp16 - HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32 - - const HVX_Vector zero = Q6_V_vsplat_R(0); - - uint32_t nvec = n / 64; - uint32_t nloe = n % 64; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements - HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements - HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); - vdst[i] = Q6_Vh_vdeal_Vh(s_hf); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements - HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements - HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); - } -} - -// bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned -static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) { - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - - HVX_Vector velem = hvx_vec_splat_fp32(elem); - - assert((unsigned long) dst % 128 == 0); - - uint32_t nvec = n / 32; - uint32_t nloe = n % 32; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - vdst[i] = velem; - } - - if (nloe) { - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), velem); - } -} - - -/* Return whether 'n' elements from vector are in the one chunk of 'chunk_size'. */ -static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { - uint32_t left_off = (size_t) addr & (chunk_size - 1); - uint32_t right_off = left_off + n; - return right_off <= chunk_size; -} - -static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) { - HVX_VectorAlias u = { .v = v }; - - const uint32_t n0 = n / 16; - const uint32_t n1 = n % 16; - int i = 0; - for (; i < n0; i++) { - htp_dump_fp16_line(pref, u.fp16 + (16 * i), 16); - } - if (n1) { - htp_dump_fp16_line(pref, u.fp16 + (16 * i), n1); - } -} - -static void hvx_vec_dump_fp16(char * pref, HVX_Vector v) { - hvx_vec_dump_fp16_n(pref, v, 64); -} - -static void hvx_vec_dump_fp32_n(char * pref, HVX_Vector v, uint32_t n) { - union { - HVX_Vector v; - float d[32]; - } u = { .v = v }; - - const uint32_t n0 = n / 16; - const uint32_t n1 = n % 16; - int i = 0; - for (; i < n0; i++) { - htp_dump_fp32_line(pref, u.d + (16 * i), 16); - } - if (n1) { - htp_dump_fp32_line(pref, u.d + (16 * i), n1); - } -} - -static void hvx_vec_dump_fp32_hmt(char * pref, HVX_Vector v) { - union { - HVX_Vector v; - float d[32]; - } u = { .v = v }; - - FARF(HIGH, "%s: %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\n", pref, u.d[0], u.d[1], - u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]); -} - -static void hvx_vec_dump_fp32(char * pref, HVX_Vector v) { - hvx_vec_dump_fp32_n(pref, v, 32); -} - -static void hvx_vec_dump_int32(char * pref, HVX_Vector v) { - union { - HVX_Vector v; - int32_t d[32]; - } u = { .v = v }; - - for (int i = 0; i < 32 / 16; i++) { - htp_dump_int32_line(pref, u.d + (16 * i), 16); - } -} - -static void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) { - union { - HVX_Vector v; - int32_t d[32]; - } u = { .v = v }; - - FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12], - u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]); -} - -static void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) { - union { - HVX_Vector v; - int8_t d[128]; - } u = { .v = v }; - - FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60], - u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]); -} - -static void hvx_vec_dump_int8(char * pref, HVX_Vector v) { - union { - HVX_Vector v; - int8_t d[128]; - } u = { .v = v }; - - for (int i = 0; i < 128 / 16; i++) { - htp_dump_int8_line(pref, u.d + (16 * i), 16); - } -} - -static void hvx_vec_dump_uint8(char * pref, HVX_Vector v) { - union { - HVX_Vector v; - uint8_t d[128]; - } u = { .v = v }; - - for (int i = 0; i < 128 / 16; i++) { - htp_dump_uint8_line(pref, u.d + (16 * i), 16); - } -} - -static bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) { - typedef union { - HVX_Vector v; - int8_t d[128]; - } U; - - U u0 = { .v = v0 }; - U u1 = { .v = v1 }; - - for (int i = 0; i < n; i++) { - if (u0.d[i] != u1.d[i]) { - return false; - } - } - - return true; -} - -static inline float hvx_vec_get_fp32(HVX_Vector v) { - float __attribute__((aligned(128))) x; - hvx_vec_store_a(&x, 4, v); - return x; -} - -static inline HVX_Vector hvx_vec_int32_reduce_sum_n(HVX_Vector in, unsigned int n) { - unsigned int total = n * 4; // total vec nbytes - unsigned int width = 4; // int32 - - HVX_Vector sum = in, sum_t; - while (width < total) { - sum_t = Q6_V_vror_VR(sum, width); // rotate right - sum = Q6_Vw_vadd_VwVw(sum_t, sum); // elementwise sum - width = width << 1; - } - return sum; -} - -static inline HVX_Vector hvx_vec_int32_reduce_sum(HVX_Vector in) { - return hvx_vec_int32_reduce_sum_n(in, 32); -} - -static inline HVX_Vector hvx_vec_qf32_reduce_sum_n(HVX_Vector in, unsigned int n) { - unsigned int total = n * 4; // total vec nbytes - unsigned int width = 4; // fp32 nbytes - - HVX_Vector sum = in, sum_t; - while (width < total) { - sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width); // rotate right - sum = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t); // elementwise sum - width = width << 1; - } - return sum; -} - -static inline HVX_Vector hvx_vec_qf32_reduce_sum(HVX_Vector in) { - return hvx_vec_qf32_reduce_sum_n(in, 32); -} - -static inline HVX_Vector hvx_vec_fp32_reduce_sum_n(HVX_Vector in, unsigned int n) { - unsigned int total = n * 4; // total vec nbytes - unsigned int width = 4; // fp32 nbytes - - HVX_Vector sum = in, sum_t; - while (width < total) { - sum_t = Q6_V_vror_VR(sum, width); // rotate right - sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum - width = width << 1; - } - return sum; -} - -static inline HVX_Vector hvx_vec_fp32_reduce_sum(HVX_Vector in) { - return hvx_vec_fp32_reduce_sum_n(in, 32); -} - -static inline HVX_Vector hvx_vec_reduce_max_fp16(HVX_Vector in) { - unsigned total = 128; // total vec nbytes - unsigned width = 2; // fp16 nbytes - - HVX_Vector _max = in, _max_t; - while (width < total) { - _max_t = Q6_V_vror_VR(_max, width); // rotate right - _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max - width = width << 1; - } - - return _max; -} - -static inline HVX_Vector hvx_vec_reduce_max2_fp16(HVX_Vector in, HVX_Vector _max) { - unsigned total = 128; // total vec nbytes - unsigned width = 2; // fp32 nbytes - - HVX_Vector _max_t; - - _max = Q6_Vhf_vmax_VhfVhf(in, _max); - while (width < total) { - _max_t = Q6_V_vror_VR(_max, width); // rotate right - _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max - width = width << 1; - } - - return _max; -} - -static inline HVX_Vector hvx_vec_reduce_max_fp32(HVX_Vector in) { - unsigned total = 128; // total vec nbytes - unsigned width = 4; // fp32 nbytes - - HVX_Vector _max = in, _max_t; - while (width < total) { - _max_t = Q6_V_vror_VR(_max, width); // rotate right - _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max - width = width << 1; - } - - return _max; -} - -static inline HVX_Vector hvx_vec_reduce_max2_fp32(HVX_Vector in, HVX_Vector _max) { - unsigned total = 128; // total vec nbytes - unsigned width = 4; // fp32 nbytes - - HVX_Vector _max_t; - - _max = Q6_Vsf_vmax_VsfVsf(in, _max); - while (width < total) { - _max_t = Q6_V_vror_VR(_max, width); // rotate right - _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max - width = width << 1; - } - - return _max; -} - -static inline HVX_Vector hvx_vec_abs_fp16(HVX_Vector v) { - // abs by clearing the fp16 sign bit - HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff); - return Q6_V_vand_VV(v, mask); -} - -static inline HVX_Vector hvx_vec_neg_fp16(HVX_Vector v) { - // neg by setting the fp16 sign bit - HVX_Vector mask = Q6_Vh_vsplat_R(0x8000); - return Q6_V_vxor_VV(v, mask); -} - -static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) { - // abs by clearing the fp32 sign bit - HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff); - return Q6_V_vand_VV(v, mask); -} - -static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) { -#if __HVX_ARCH__ > 75 - return Q6_Vsf_vfneg_Vsf(v); -#else - // neg by setting the fp32 sign bit - HVX_Vector mask = Q6_V_vsplat_R(0x80000000); - return Q6_V_vxor_VV(v, mask); -#endif // __HVX_ARCH__ > 75 -} - -// ==================================================== -// FUNCTION: 1/(x+1) y(0) = 1, y(0.5) = 0.6667, y(1) = 0.5 -// Order:3; continuity: True; Ends forced: True -// Mode: unsigned; Result fractional bits: 14 -// Peak Error: 1.1295e-04 Rms Error: 2.8410e-05 Mean Error: 1.1370e-05 -// 32769 -32706 31252 -10589 -// 32590 -30635 22793 -4493 -// 32066 -27505 16481 -2348 -// 31205 -24054 11849 -1306 - -static inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) { - // input is 0..0xffff representing 0.0 .. 1.0 - HVX_Vector p; - p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull); - p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull); - p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull); - p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull); - return p; // signed result, 14 fractional bits -} - -// Find reciprocal of fp16. -// (1) first, convert to fp32, multiplying by 1.0; this is done to -// handle denormals. Ignoring sign and zero, result should be at -// least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000) -// (exponent in range [103,143]) -// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly -// (3) put this, along with '253-exp' (exp from (1)) together to make an qf32 -// (4) convert that to fp16 -// (5) put sign back in. Also, if the original value (w/o sign) was <0x81, replace -// the result with the max value. -static inline HVX_Vector hvx_vec_inverse_fp16(HVX_Vector vals) { - HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF); - HVX_Vector avals = Q6_V_vand_VV(vals, em_mask); - HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(avals, vals); - // is too small to 1/x ? for 'standard' fp16, this would be 0x101 - HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals); - - HVX_VectorPair to_qf32 = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00)); // *1.0 - HVX_Vector to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32)); - HVX_Vector to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32)); - - // bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 14..5 of a 16-bit vector - HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9)); - // likewise extract the upper 16 from each, containing the exponents in range 103..142 - HVX_Vector exp_u16 = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0); - //Get exponent in IEEE 32-bit representation - exp_u16 = Q6_Vuh_vlsr_VuhR(exp_u16, 7); - - // so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane - // We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0) - // Use poly to transform to 1/x, with 14 fractional bits - // - HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16); - - HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm); //count leading zeros - - // Get mantissa for 16-bit represenation - HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF)); - - //Compute Reciprocal Exponent - HVX_Vector exp_recip = - Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1))); - //Convert it for 16-bit representation - exp_recip = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15)); - exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10); - - //Merge exponent and mantissa for reciprocal - HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip); - // map 'small' inputs to standard largest value 0x7bff - recip = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip); - // add sign back - recip = Q6_V_vandor_VQR(recip, is_neg, 0x80008000); - return recip; -} - -#define IEEE_VSF_EXPLEN (8) -#define IEEE_VSF_EXPBIAS (127) -#define IEEE_VSF_EXPMASK (0xFF) -#define IEEE_VSF_MANTLEN (23) -#define IEEE_VSF_MANTMASK (0x7FFFFF) -#define IEEE_VSF_MIMPMASK (0x800000) - -static inline HVX_Vector hvx_vec_truncate_fp32(HVX_Vector in_vec) { - HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK); - HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK); - HVX_Vector const_zero_v = Q6_V_vzero(); - - HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec); - - HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN; - expval_v &= IEEE_VSF_EXPMASK; - expval_v -= IEEE_VSF_EXPBIAS; - - // negative exp == fractional value - HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v); - - HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v; // fractional bits - exp shift - - HVX_Vector mant_v = in_vec & mask_mant_v; // obtain mantissa - HVX_Vector vout = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v); // add implicit 1.0 - - vout = Q6_Vw_vasr_VwVw(vout, rshift_v); // shift to obtain truncated integer - vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout); // expval<0 -> 0 - - HVX_Vector neg_vout = -vout; - - vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout); // handle negatives - - return (vout); -} - -static inline HVX_Vector hvx_vec_floor_fp32(HVX_Vector in_vec) { - HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK); - HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK); - HVX_Vector const_mnlen_v = Q6_V_vsplat_R(IEEE_VSF_MANTLEN); - HVX_Vector const_zero_v = Q6_V_vzero(); - HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000); // -1 IEEE vsf - - HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec); - - HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN; - expval_v &= IEEE_VSF_EXPMASK; - expval_v -= IEEE_VSF_EXPBIAS; - - HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v); - HVX_VectorPred q_expltmn = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v); - HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v); - HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec); - - // if expval < 0 (q_negexp) // <0, floor is 0 - // if vin > 0 - // floor = 0 - // if vin < 0 - // floor = -1 - // if expval < mant_len (q_expltmn) // >0, but fraction may exist - // get sign (q_negative) - // mask >> expval // fraction bits to mask off - // vout = ~(mask) // apply mask to remove fraction - // if (qneg) // negative floor is one less (more, sign bit for neg) - // vout += ((impl_mask) >> expval) - // if (mask && vin) - // vout = vin - // else // already an integer - // ; // no change - - // compute floor - mask_mant_v >>= expval_v; - HVX_Vector neg_addin_v = mask_impl_v >> expval_v; - HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v); - HVX_Vector vout = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec); - - HVX_Vector mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v); // chk if bits set - HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v); - - HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v); // frac bits to clear - HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v); // clear frac bits - - vout = in_vec; - vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout); // expval0 -> 0 - vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout); // expval<0 x<0 -> -1 - - return vout; -} - -static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) { - // This looks complicated. - // Ideally should just be Q6_Vh_equals_Vhf(vin) - // but that instruction does not do proper rounding. - - // convert to qf32, multiplying by 1.0 in the process. - HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00)); - - // 'in-range' values are +/32752. - // add 192K to it, convert to sf - HVX_Vector v192K = Q6_V_vsplat_R(0x48400000); - HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K)); - HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K)); - - // for in-range cases, result is {163858... 229360} so the exponent is always 144. - // if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer. - // Start by <<10 to get the final 'sign' bit in bit 15... - vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10); - vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10); - - // now round down to 16 - return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0); -} - -static inline HVX_Vector hvx_vec_inverse_fp32(HVX_Vector v_sf) { - HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3); - HVX_Vector two_sf = hvx_vec_splat_fp32(2.0); - - // First approximation - HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf); - - HVX_Vector r_qf; - - // Refine - r_qf = Q6_Vqf32_vmpy_VsfVsf( - i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf, v_sf))))); - r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32( - r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf)))); - r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32( - r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf)))); - - return Q6_Vsf_equals_Vqf32(r_qf); -} - -#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022 -#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777 -#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267 -#define FAST_SIGMOID_C3 (0x3f000000) // 0.5 - -static inline HVX_Vector hvx_vec_fast_sigmoid_fp32(HVX_Vector v) { - v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F)); - v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3)); - - HVX_Vector in_int = hvx_vec_truncate_fp32(Q6_Vsf_equals_Vqf32(v)); - HVX_Vector x = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int)); - HVX_Vector xx = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x); - - HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2)); - v1 = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F)); - - HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1)); - v2 = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx); - v2 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x); - - HVX_Vector v3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1)); - HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1); - v3_exponent = Q6_Vuw_vlsr_VuwR(v3_exponent, 24); - v3_exponent = Q6_Vw_vadd_VwVw(in_int, v3_exponent); - v3 = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24); - - HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1)); - HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4)); - - HVX_Vector res = hvx_vec_inverse_fp32(v5); - res = Q6_Vqf32_vmpy_VsfVsf(v3, res); - - return Q6_Vsf_equals_Vqf32(res); -} - -#define EXP_COEFF_5 (0x39506967) // 0.000198757 = 1/(7!) -#define EXP_COEFF_4 (0x3AB743CE) // 0.0013982 = 1/(6!) -#define EXP_COEFF_3 (0x3C088908) // 0.00833345 = 1/(5!) -#define EXP_COEFF_2 (0x3D2AA9C1) // 0.416658 = 1/(4!) -#define EXP_COEFF_1 (0x3E2AAAAA) // 0.16666667 = 1/(3!) -#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!) -#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805 -#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408 -#define EXP_ONE (0x3f800000) // 1.0 -#define EXP_RANGE_R (0x41a00000) // 20.0 -#define EXP_RANGE_L (0xc1a00000) // -20.0 - -static inline HVX_Vector hvx_vec_exp_fp32(HVX_Vector in_vec) { - HVX_Vector z_qf32_v; - HVX_Vector x_v; - HVX_Vector x_qf32_v; - HVX_Vector y_v; - HVX_Vector k_v; - HVX_Vector f_v; - HVX_Vector epsilon_v; - HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E); - HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2); - HVX_Vector E_const; - HVX_Vector zero_v = Q6_V_vzero(); - - // exp(x) is approximated as follows: - // f = floor(x/ln(2)) = floor(x*log2(e)) - // epsilon = x - f*ln(2) - // exp(x) = exp(epsilon+f*ln(2)) - // = exp(epsilon)*exp(f*ln(2)) - // = exp(epsilon)*2^f - // - // Since epsilon is close to zero, it can be approximated with its Taylor series: - // exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+... - // Preserving the first eight elements, we get: - // exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7 - // = 1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2 - - HVX_Vector temp_v = in_vec; - - // Clamp inputs to (-20.0, 20.0) - HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R)); - HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec); - - in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v); - in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v); - - epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec); - epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v); - - // f_v is the floating point result and k_v is the integer result - f_v = hvx_vec_floor_fp32(epsilon_v); - k_v = hvx_vec_truncate_fp32(f_v); - - x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v); - - // x = x - f_v * logn2; - epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2); - x_qf32_v = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v); - // normalize before every QFloat's vmpy - x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v); - - // z = x * x; - z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v); - z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v); - - x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); - - // y = E4 + E5 * x; - E_const = Q6_V_vsplat_R(EXP_COEFF_5); - y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v); - E_const = Q6_V_vsplat_R(EXP_COEFF_4); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); - - // y = E3 + y * x; - E_const = Q6_V_vsplat_R(EXP_COEFF_3); - y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); - - // y = E2 + y * x; - E_const = Q6_V_vsplat_R(EXP_COEFF_2); - y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); - - // y = E1 + y * x; - E_const = Q6_V_vsplat_R(EXP_COEFF_1); - y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); - - // y = E0 + y * x; - E_const = Q6_V_vsplat_R(EXP_COEFF_0); - y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); - - // y = x + y * z; - y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v); - y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); - - // y = y + 1.0; - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE)); - - // insert exponents - // y = ldexpf(y, k); - // y_v += k_v; // qf32 - // modify exponent - - y_v = Q6_Vsf_equals_Vqf32(y_v); - - // add k_v to the exponent of y_v - HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1); - - y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1); - y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent); - - // exponent cannot be negative; if overflow is detected, result is set to zero - HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent); - - y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN); - - y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v); - - return y_v; -} - -#define RSQRT_CONST 0x5f3759df // Constant for fast inverse square root calculation -#define RSQRT_ONE_HALF 0x3f000000 // 0.5 -#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5 - -static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) { - //Algorithm : - // x2 = input*0.5 - // y = * (long *) &input - // y = 0x5f3759df - (y>>2) - // y = y*(threehalfs - x2*y*y) - - HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST); - HVX_Vector onehalf = Q6_V_vsplat_R(RSQRT_ONE_HALF); - HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES); - - HVX_Vector x2, y, ypower2, temp; - - x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf); - x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero()); - - y = Q6_Vw_vasr_VwR(in_vec, 1); - y = Q6_Vw_vsub_VwVw(rsqrtconst, y); - - // 1st iteration - ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y); - ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); - temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); - temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); - temp = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp)); - - // 2nd iteration - y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero()); - ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y); - ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); - temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); - temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); - temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp); - - // 3rd iteration - y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero()); - ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y); - ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); - temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); - temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); - temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp); - - return Q6_Vsf_equals_Vqf32(temp); -} - -static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v, - HVX_Vector one, - HVX_Vector max_exp, - HVX_Vector min_exp) { - const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v); - const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp); - - HVX_Vector out = hvx_vec_fast_sigmoid_fp32(v); - out = Q6_V_vmux_QVV(pred_max, out, one); - return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero()); -} - -static inline HVX_Vector hvx_vec_tanh_fp32(HVX_Vector x) { - // tanh(x) = 2 * sigmoid(2x) - 1 - HVX_Vector two = hvx_vec_splat_fp32(2.0f); - HVX_Vector one = hvx_vec_splat_fp32(1.0f); - HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two); - - static const float kMinExp = -87.f; // 0 - static const float kMaxExp = 87.f; // 1 - HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); - HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp); - - HVX_Vector sig2x = hvx_vec_fast_sigmoid_fp32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp); - - HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two); - res = Q6_Vqf32_vsub_Vqf32Vsf(res, one); - return Q6_Vsf_equals_Vqf32(res); -} - -static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { - int step_of_1 = num_elems >> 5; - int remaining = num_elems - step_of_1 * VLEN_FP32; - - const HVX_Vector * restrict v_src = (HVX_Vector *) src; - HVX_Vector * restrict v_dst = (HVX_Vector *) dst; - - static const float kMinExp = -87.f; // 0 - static const float kMaxExp = 87.f; // 1 - - const HVX_Vector one = hvx_vec_splat_fp32(1.f); - const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); - const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp); - - #pragma unroll(4) - for (int i = 0; i < step_of_1; i++) { - v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i], one, max_exp, min_exp); - } - - if (remaining > 0) { - const float * srcf = ((const float *) src) + step_of_1* VLEN_FP32; - float * dstf = (float *) dst + step_of_1*VLEN_FP32; - - HVX_Vector in = *(HVX_UVector *) srcf; - HVX_Vector out = hvx_vec_fast_sigmoid_fp32_guard(in, one, max_exp, min_exp); - hvx_vec_store_u((void *) dstf, remaining * SIZEOF_FP32, out); - } -} - -static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems){ - int step_of_1 = num_elems >> 5; // divby 32, because 32 float = 128 bytes per HVX vector - int leftover = num_elems - (step_of_1 * VLEN_FP32); - - int32_t leftover_size = leftover * sizeof(float); - - static const float kMinExp = -87.f; // 0 - static const float kMaxExp = 87.f; // 1 - - const HVX_Vector one = hvx_vec_splat_fp32(1.f); - const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); - const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp); - - const float *input = (float *)src; - float *output = (float *)dst; - - HVX_Vector * input_v_ptr = (HVX_Vector *) input; - HVX_UVector * output_v_ptr = (HVX_UVector *) output; - - HVX_Vector slinep; - HVX_Vector slinec; - HVX_Vector sline; - - slinep = *input_v_ptr++; - #pragma unroll(4) - for (int i = step_of_1 - 1; i > 0; i--) { - slinec = *input_v_ptr++; - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input); - *((HVX_UVector *) (output_v_ptr++)) = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp); - /* Prepare slinep for next iteration */ - slinep = slinec; - } - - if (step_of_1 > 0) { - slinec = htp_is_aligned(input_v_ptr, 128) && leftover == 0 ? slinep : *input_v_ptr++; - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input); - *((HVX_UVector *) (output_v_ptr++)) = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp); - ; - - slinep = slinec; - } - if (leftover > 0) { - slinec = (is_in_one_chunk(input_v_ptr, leftover_size, 128) ? slinep : *input_v_ptr++); - - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input); - - HVX_Vector sout = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp); - hvx_vec_store_u(output_v_ptr, leftover_size, sout); - } -} - -static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { - int nvec = n / VLEN_FP32; - int nloe = n % VLEN_FP32; - - HVX_Vector vs = hvx_vec_splat_fp32(scale); - - HVX_Vector * vsrc = (HVX_Vector *) src; - HVX_Vector * vdst = (HVX_Vector *) dst; - - uint32_t i = 0; - - #pragma unroll(4) - for (i = 0; i < nvec; ++i) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); - vdst[i] = Q6_Vsf_equals_Vqf32(v); - } - - if (nloe) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); - hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); - } -} - -static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { - int nvec = n / VLEN_FP32; - int nloe = n % VLEN_FP32; - - HVX_Vector vs = hvx_vec_splat_fp32(scale); - - HVX_UVector * vsrc = (HVX_UVector *) src; - HVX_UVector * vdst = (HVX_UVector *) dst; - - uint32_t i = 0; - - #pragma unroll(4) - for (i = 0; i < nvec; ++i) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); - vdst[i] = Q6_Vsf_equals_Vqf32(v); - } - - if (nloe) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); - hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); - } -} - -static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { - if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) { - hvx_scale_f32_aa(dst, src, n, scale); - } else { - hvx_scale_f32_uu(dst, src, n, scale); - } -} - -static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { - int nvec = n / VLEN_FP32; - int nloe = n % VLEN_FP32; - - HVX_Vector vs = hvx_vec_splat_fp32(scale); - HVX_Vector vo = hvx_vec_splat_fp32(offset); - - HVX_Vector * vsrc = (HVX_Vector *) src; - HVX_Vector * vdst = (HVX_Vector *) dst; - - uint32_t i = 0; - - #pragma unroll(4) - for (i = 0; i < nvec; ++i) { - HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); - vdst[i] = Q6_Vsf_equals_Vqf32(v); - } - - if (nloe) { - HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); - hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); - } -} - -static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { - int nvec = n / VLEN_FP32; - int nloe = n % VLEN_FP32; - - HVX_Vector vs = hvx_vec_splat_fp32(scale); - HVX_Vector vo = hvx_vec_splat_fp32(offset); - - HVX_UVector * vsrc = (HVX_UVector *) src; - HVX_UVector * vdst = (HVX_UVector *) dst; - - uint32_t i = 0; - - #pragma unroll(4) - for (i = 0; i < nvec; ++i) { - HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); - vdst[i] = Q6_Vsf_equals_Vqf32(v); - } - - if (nloe) { - HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); - hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); - } -} - -static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { - if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) { - hvx_scale_offset_f32_aa(dst, src, n, scale, offset); - } else { - hvx_scale_offset_f32_uu(dst, src, n, scale, offset); - } -} - -float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems); -void hvx_mul_f32(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems); -void hvx_mul_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems); -void hvx_mul_mul_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - const uint8_t * restrict src2, - uint8_t * restrict dst, - const int num_elems); -void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); -void hvx_add_f32(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems); -void hvx_add_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems); -void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); -void hvx_sub_f32(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems); -void hvx_sub_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems); -void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); -void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); -void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); -void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate); -float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems); -float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems); -void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); -void hvx_clamp_scalar_f32(const uint8_t * restrict src, - const float limit_left, - const float limit_right, - uint8_t * restrict dst, - const int num_elems); +#include "hex-utils.h" + +#include "hvx-types.h" +#include "hvx-copy.h" +#include "hvx-scale.h" +#include "hvx-exp.h" +#include "hvx-inverse.h" +#include "hvx-reduce.h" +#include "hvx-sigmoid.h" +#include "hvx-sqrt.h" +#include "hvx-arith.h" +#include "hvx-base.h" #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 24b3e90e4b6..e28a67a95dc 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -1,17 +1,13 @@ #pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" #pragma clang diagnostic ignored "-Wunused-function" -#define FARF_ERROR 1 -#define FARF_HIGH 1 -#define FARF_MEDIUM 0 -#define FARF_LOW 0 +#include +#include #include #include #include #include -#include #include -#include #include #include #include @@ -19,13 +15,14 @@ #include #include +#include "hex-dma.h" +#include "hex-utils.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" #include "htp-msg.h" #include "htp-ops.h" -#include "ops-utils.h" #include "worker-pool.h" AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { @@ -362,14 +359,14 @@ struct profile_data { static inline void profile_start(struct profile_data * d) { d->usecs = HAP_perf_get_qtimer_count(); - d->cycles = htp_get_cycles(); - d->pkts = htp_get_pktcnt(); + d->cycles = hex_get_cycles(); + d->pkts = hex_get_pktcnt(); } static inline void profile_stop(struct profile_data * d) { d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs); - d->cycles = htp_get_cycles() - d->cycles; - d->pkts = htp_get_pktcnt() - d->pkts; + d->cycles = hex_get_cycles() - d->cycles; + d->pkts = hex_get_pktcnt() - d->pkts; } static int send_htp_rsp(struct htp_context * c, @@ -443,6 +440,43 @@ static void proc_matmul_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_cpy(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { struct dspqueue_buffer rsp_bufs[1]; @@ -993,6 +1027,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_get_rows_req(ctx, &req, bufs); break; + case HTP_OP_CPY: + if (n_bufs != 2) { + FARF(ERROR, "Bad cpy-req buffer list"); + continue; + } + proc_cpy_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 9bb39db9fcb..1603ff2b3b6 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -3,28 +3,20 @@ #pragma clang diagnostic ignored "-Wunused-variable" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif - #include -#include #include -#include -#include -#include + #include -#include #include +#include "hex-dma.h" +#include "hvx-utils.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" #include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" #define MM_SPAD_SRC0_NROWS 16 #define MM_SPAD_SRC1_NROWS 16 @@ -36,20 +28,8 @@ struct htp_matmul_type { void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy); }; -typedef struct { - HVX_Vector v[2]; -} HVX_Vector_x2; - -typedef struct { - HVX_Vector v[4]; -} HVX_Vector_x4; - -typedef struct { - HVX_Vector v[8]; -} HVX_Vector_x8; - // vdelta control to replicate first 4x fp32 values across lanes -static const uint8_t __attribute__((aligned(128))) repl_4x_fp32[128] = { +static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = { 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, @@ -60,7 +40,7 @@ static const uint8_t __attribute__((aligned(128))) repl_4x_fp32[128] = { }; // vdelta control to replicate and interleave first 8x fp32 values across lanes -static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_fp32[128] = { +static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = { 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, @@ -71,7 +51,7 @@ static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_fp32[128] }; // vdelta control to replicate first fp32 value across all elements -static const uint8_t __attribute__((aligned(128))) repl_1x_fp32[128] = { +static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = { 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, @@ -82,7 +62,7 @@ static const uint8_t __attribute__((aligned(128))) repl_1x_fp32[128] = { }; // vdelta control to replicate first fp16 value across all elements -static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = { +static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = { 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, @@ -93,7 +73,7 @@ static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = { }; // vdelta control to replicate first fp16 value across all elements -static const uint8_t __attribute__((aligned(128))) repl_2x_fp16[128] = { +static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = { 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, @@ -129,7 +109,7 @@ static inline size_t q8x4x2_row_size(uint32_t ne) { // ensures perfect alignment of quants and full row const uint32_t qk = QK_Q8_0x4x2; const uint32_t nb = (ne + qk - 1) / qk; - return htp_round_up(ne + nb * 8 * sizeof(__fp16), 128); + return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128); } static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { @@ -389,7 +369,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * } // Reduce and convert into fp32 - r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); + r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); hvx_vec_store_u(&s[0], 4, r0_sum); } @@ -485,8 +465,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, } // Convert into fp32 and reduce - r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum)); + r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); + r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum)); HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); @@ -562,7 +542,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * } // Reduce and convert into fp32 - r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); + r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); hvx_vec_store_u(&s[0], 4, r0_sum); } @@ -658,8 +638,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, } // Convert into fp32 and reduce - r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum)); + r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); + r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum)); HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); @@ -768,7 +748,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, } // Reduce and convert into fp32 - r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); + r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); hvx_vec_store_u(&s[0], 4, r0_sum); } @@ -900,8 +880,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, } // Convert into fp32 and reduce - r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum)); + r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); + r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum)); HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); @@ -933,7 +913,7 @@ static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * res rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); hvx_vec_store_u(&s[0], 4, rsum); } @@ -977,8 +957,8 @@ static void vec_dot_f16_f16_aa_rx2(const int n, rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); } - rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum1)); + rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum1)); HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4); hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); @@ -1010,7 +990,7 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); hvx_vec_store_u(&s[0], 4, rsum); } @@ -1062,7 +1042,7 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); hvx_vec_store_u(&s[0], 4, rsum); } @@ -1359,7 +1339,7 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); } - hvx_copy_fp32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); + hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); t2 = HAP_perf_get_qtimer_count(); @@ -1411,7 +1391,7 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx const size_t src0_row_size = nb01; const size_t src1_row_size = q8x4x2_row_size(ne10); - const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM @@ -1524,7 +1504,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx const size_t src0_row_size = nb01; const size_t src1_row_size = q8x4x2_row_size(ne10); - const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); const uint32_t n_aids = src2->ne[0]; // num activated experts const uint32_t n_ids = ne02; // num experts @@ -1590,7 +1570,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx // *** dynamic quant -static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { +static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { assert((unsigned long) x % 128 == 0); assert((unsigned long) y_q % 128 == 0); @@ -1598,10 +1578,10 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri HVX_Vector zero = Q6_V_vsplat_R(0); // Use reduce max fp32 to find max(abs(e)) first - HVX_Vector vmax0_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[0])); - HVX_Vector vmax1_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[1])); - HVX_Vector vmax2_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[2])); - HVX_Vector vmax3_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[3])); + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); // Load and convert into QF32 HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements @@ -1623,7 +1603,7 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); // Replicate first fp16 scale across all lanes - HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_fp16; + HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16; vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl); vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl); @@ -1641,8 +1621,8 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf); // Divide input by the scale - HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf); - HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf); + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); @@ -1654,7 +1634,7 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri *(HVX_Vector *) y_q = vx_i8; } -static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { +static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { assert((unsigned long) x % 128 == 0); assert((unsigned long) y_q % 128 == 0); @@ -1672,11 +1652,11 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); // Compute max and scale - HVX_Vector vmax01_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf)); - HVX_Vector vmax23_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx23_hf)); + HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); + HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // Replicate first fp16 scale across all lanes - HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16; + HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16; vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl); vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl); @@ -1689,8 +1669,8 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri hvx_vec_store_u(y_d + 4, 4, vd23_hf); // Divide input by the scale - HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf); - HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf); + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); @@ -1702,7 +1682,7 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri *(HVX_Vector *) y_q = vx_i8; } -static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { +static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { assert((unsigned long) x % 128 == 0); assert((unsigned long) y_q % 128 == 0); @@ -1720,11 +1700,11 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); // Compute max and scale - HVX_Vector vmax_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf)); - vmax_hf = hvx_vec_reduce_max2_fp16(hvx_vec_abs_fp16(vx23_hf), vmax_hf); + HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); + vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // Replicate first fp16 scale across all lanes - HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16; + HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16; vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl); HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 @@ -1733,7 +1713,7 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri *(HVX_UVector *) y_d = vd_hf; // Divide input by the scale - HVX_Vector vd_inv_hf = hvx_vec_inverse_fp16(vd_hf); + HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf); vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf)); vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf)); @@ -1746,7 +1726,7 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri } // Overrides input x -static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { +static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { assert(k % 32 == 0); const uint32_t qk = QK_Q8_0x4x2; const uint32_t nb = (k + qk - 1) / qk; @@ -1764,24 +1744,24 @@ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, u for (uint32_t i = 0; i < nb; i++) { #if FP32_QUANTIZE_GROUP_SIZE == 32 - quantize_block_fp32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_fp32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); + quantize_block_f32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_f32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #elif FP32_QUANTIZE_GROUP_SIZE == 64 - quantize_block_fp32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_fp32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); + quantize_block_f32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_f32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #elif FP32_QUANTIZE_GROUP_SIZE == 128 - quantize_block_fp32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_fp32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); + quantize_block_f32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_f32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #else #error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128" #endif } // now copy the scales into final location - hvx_copy_fp16_ua(y_d, t_d, nb * 8); + hvx_copy_f16_ua(y_d, t_d, nb * 8); } -static void quantize_fp32_q8x4x2(const struct htp_tensor * src, +static void quantize_f32_q8x4x2(const struct htp_tensor * src, uint8_t * restrict dst, struct htp_spad * spad, uint32_t nth, @@ -1807,26 +1787,26 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src, uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); - const size_t src_row_size_padded = htp_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding for (uint32_t i = ir_first; i < ir_last; ++i) { - htp_l2fetch(src_data, 2, src_row_size, src_row_size); - hvx_copy_fp32_aa(tmp_data, src_data, ne0); + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); // FARF(HIGH, "quantize-q8x4-row: %u\n", i); - quantize_row_fp32_q8x4x2((float *) tmp_data, dst_data, ne0); + quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0); dst_data += dst_row_size; src_data += src_row_size; } uint64_t t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "quantize-fp32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, + FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, +static void quantize_f32_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, uint32_t nrows_per_thread, uint32_t dst_stride) { uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1848,8 +1828,8 @@ static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); for (uint32_t i = ir_first; i < ir_last; ++i) { - htp_l2fetch(src_data, 2, src_row_size, src_stride); - hvx_copy_fp16_fp32_au(dst_data, src_data, ne0); + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f16_f32_au(dst_data, src_data, ne0); dst_data += dst_stride; src_data += src_stride; @@ -1857,12 +1837,12 @@ static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict uint64_t t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "quantize-fp32-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } // TODO just a plain copy that should be done via the DMA during the Op setup -static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, +static void quantize_f16_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, uint32_t nrows_per_thread, uint32_t dst_stride) { uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1884,8 +1864,8 @@ static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); for (uint32_t i = ir_first; i < ir_last; ++i) { - htp_l2fetch(src_data, 2, src_row_size, src_stride); - hvx_copy_fp16_au(dst_data, src_data, ne0); + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f16_au(dst_data, src_data, ne0); dst_data += dst_stride; src_data += src_stride; @@ -1893,23 +1873,23 @@ static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict uint64_t t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "quantize-fp16-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_quantize_f32_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; - quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread); + quantize_f32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread); } -static void htp_quantize_fp32_fp16(unsigned int n, unsigned int i, void * data) { +static void htp_quantize_f32_f16(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; - quantize_fp32_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); + quantize_f32_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); } -static void htp_quantize_fp16_fp16(unsigned int n, unsigned int i, void * data) { +static void htp_quantize_f16_f16(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; - quantize_fp16_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); + quantize_f16_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); } // ** matmul/matvec callbacks for worker_pool @@ -2108,7 +2088,7 @@ int op_matmul(struct htp_ops_context * octx) { const size_t dst_row_size = nb1; size_t src1_row_size = nb11; - const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); size_t src1_row_size_padded; worker_callback_t quant_job_func; @@ -2118,8 +2098,8 @@ int op_matmul(struct htp_ops_context * octx) { switch (src0->type) { case HTP_TYPE_Q4_0: - op_type = "q4x4x2-fp32"; - quant_job_func = htp_quantize_fp32_q8x4x2; + op_type = "q4x4x2-f32"; + quant_job_func = htp_quantize_f32_q8x4x2; if (src1_nrows > 1) { matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2; } else { @@ -2131,12 +2111,12 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); if (octx->src0_spad.size_per_thread < src1_row_size_padded) { octx->src0_spad.size_per_thread = src1_row_size_padded; } @@ -2147,8 +2127,8 @@ int op_matmul(struct htp_ops_context * octx) { break; case HTP_TYPE_Q8_0: - op_type = "q8x4x2-fp32"; - quant_job_func = htp_quantize_fp32_q8x4x2; + op_type = "q8x4x2-f32"; + quant_job_func = htp_quantize_f32_q8x4x2; if (src1_nrows > 1) { matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2; } else { @@ -2160,12 +2140,12 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); if (octx->src0_spad.size_per_thread < src1_row_size_padded) { octx->src0_spad.size_per_thread = src1_row_size_padded; } @@ -2177,7 +2157,7 @@ int op_matmul(struct htp_ops_context * octx) { case HTP_TYPE_MXFP4: op_type = "mxfp4x4x2-f32"; - quant_job_func = htp_quantize_fp32_q8x4x2; + quant_job_func = htp_quantize_f32_q8x4x2; if (src1_nrows > 1) { matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2; } else { @@ -2189,12 +2169,12 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); if (octx->src0_spad.size_per_thread < src1_row_size_padded) { octx->src0_spad.size_per_thread = src1_row_size_padded; } @@ -2207,10 +2187,10 @@ int op_matmul(struct htp_ops_context * octx) { case HTP_TYPE_F16: { // Try optimized f16-f16 path first (src1 in VTCM) - const size_t f16_src1_row_size = htp_round_up(ne10 * 2, 128); - const size_t f16_src1_spad_size = htp_round_up(f16_src1_row_size * src1_nrows, 256); - const size_t f16_src0_spad_size = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; - const size_t f16_dst_spad_size = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; + const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128); + const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256); + const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; @@ -2222,7 +2202,7 @@ int op_matmul(struct htp_ops_context * octx) { if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { // Optimized path op_type = "f16-f16"; - quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_fp32_fp16 : htp_quantize_fp16_fp16; + quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_f32_f16 : htp_quantize_f16_f16; if (src1_nrows > 1) { matmul_job_func = htp_matmul_2d_f16_f16; } else { @@ -2231,9 +2211,9 @@ int op_matmul(struct htp_ops_context * octx) { src1_row_size = f16_src1_row_size; // row size post quantization - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); octx->src1_spad.size = octx->src1_spad.size_per_thread; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; @@ -2251,9 +2231,9 @@ int op_matmul(struct htp_ops_context * octx) { src1_row_size = nb11; // original row size in DDR - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = htp_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; @@ -2332,7 +2312,7 @@ int op_matmul_id(struct htp_ops_context * octx) { const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; - const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); const uint32_t src0_nrows = ne01; // per expert const uint32_t src1_nrows = ne11 * ne12 * ne13; @@ -2350,7 +2330,7 @@ int op_matmul_id(struct htp_ops_context * octx) { switch (src0->type) { case HTP_TYPE_Q4_0: op_type = "q4x2x2-f32"; - quant_job_func = htp_quantize_fp32_q8x4x2; + quant_job_func = htp_quantize_f32_q8x4x2; src1_row_size = q8x4x2_row_size(ne10); // row size post quantization if (src1_nrows > 1) { matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2; @@ -2360,13 +2340,13 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); if (octx->src0_spad.size_per_thread < src1_row_size_padded) { octx->src0_spad.size_per_thread = src1_row_size_padded; } @@ -2379,7 +2359,7 @@ int op_matmul_id(struct htp_ops_context * octx) { case HTP_TYPE_Q8_0: op_type = "q8x2x2-f32"; - quant_job_func = htp_quantize_fp32_q8x4x2; + quant_job_func = htp_quantize_f32_q8x4x2; src1_row_size = q8x4x2_row_size(ne10); // row size post quantization if (src1_nrows > 1) { matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2; @@ -2389,13 +2369,13 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); if (octx->src0_spad.size_per_thread < src1_row_size_padded) { octx->src0_spad.size_per_thread = src1_row_size_padded; } @@ -2408,7 +2388,7 @@ int op_matmul_id(struct htp_ops_context * octx) { case HTP_TYPE_MXFP4: op_type = "mxfp4x2x2-f32"; - quant_job_func = htp_quantize_fp32_q8x4x2; + quant_job_func = htp_quantize_f32_q8x4x2; src1_row_size = q8x4x2_row_size(ne10); // row size post quantization if (src1_nrows > 1) { matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2; @@ -2418,13 +2398,13 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); if (octx->src0_spad.size_per_thread < src1_row_size_padded) { octx->src0_spad.size_per_thread = src1_row_size_padded; } diff --git a/ggml/src/ggml-hexagon/htp/ops-utils.h b/ggml/src/ggml-hexagon/htp/ops-utils.h deleted file mode 100644 index af9c3305f61..00000000000 --- a/ggml/src/ggml-hexagon/htp/ops-utils.h +++ /dev/null @@ -1,149 +0,0 @@ -#ifndef OPS_UTILS_H -#define OPS_UTILS_H - -#include "htp-msg.h" - -#ifndef MAX -# define MAX(a, b) ((a) > (b) ? (a) : (b)) -#endif - -#ifndef MIN -# define MIN(a, b) ((a) < (b) ? (a) : (b)) -#endif - -static inline uint64_t htp_get_cycles() { - uint64_t cycles = 0; - asm volatile(" %0 = c15:14\n" : "=r"(cycles)); - return cycles; -} - -static inline uint64_t htp_get_pktcnt() { - uint64_t pktcnt; - asm volatile(" %0 = c19:18\n" : "=r"(pktcnt)); - return pktcnt; -} - -static inline int32_t htp_is_aligned(void * addr, uint32_t align) { - return ((size_t) addr & (align - 1)) == 0; -} - -static inline uint32_t htp_round_up(uint32_t n, uint32_t m) { - return m * ((n + m - 1) / m); -} - -// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. -// Precompute mp (m' in the paper) and L such that division -// can be computed using a multiply (high 32b of 64b result) -// and a shift: -// -// n/d = (mulhi(n, mp) + n) >> L; -struct fastdiv_values { - uint32_t mp; - uint32_t l; -}; - -static inline struct fastdiv_values init_fastdiv_values(uint32_t d) { - struct fastdiv_values result = { 0, 0 }; - // compute L = ceil(log2(d)); - while (result.l < 32 && ((uint32_t) 1 << result.l) < d) { - ++(result.l); - } - - result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1); - return result; -} - -static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) { - // Compute high 32 bits of n * mp - const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp) - // add n, apply bit shift - return (hi + n) >> vals->l; -} - -static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) { - return n - fastdiv(n, vals) * d; -} - -static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) { - const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); - asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control)); -} - -static inline int32_t htp_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { - uint32_t left_off = (size_t) addr & (chunk_size - 1); - uint32_t right_off = left_off + n; - return right_off <= chunk_size; -} - -static inline void htp_dump_int8_line(char * pref, const int8_t * x, int n) { - char str[1024], *p = str, *p_end = str + sizeof(str); - p += snprintf(p, p_end - p, "%s: ", pref); - for (int i = 0; i < n && p < p_end; i++) { - p += snprintf(p, p_end - p, "%d, ", x[i]); - } - FARF(HIGH, "%s\n", str); -} - -static inline void htp_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) { - char str[1024], *p = str, *p_end = str + sizeof(str); - p += snprintf(p, p_end - p, "%s: ", pref); - for (int i = 0; i < n && p < p_end; i++) { - p += snprintf(p, p_end - p, "%d, ", x[i]); - } - FARF(HIGH, "%s\n", str); -} - -static inline void htp_dump_int32_line(char * pref, const int32_t * x, uint32_t n) { - char str[1024], *p = str, *p_end = str + sizeof(str); - p += snprintf(p, p_end - p, "%s: ", pref); - for (int i = 0; i < n; i++) { - p += snprintf(p, p_end - p, "%d, ", (int) x[i]); - } - FARF(HIGH, "%s\n", str); -} - -static inline void htp_dump_fp16_line(char * pref, const __fp16 * x, uint32_t n) { - char str[1024], *p = str, *p_end = str + sizeof(str); - p += snprintf(p, p_end - p, "%s: ", pref); - for (int i = 0; i < n; i++) { - p += snprintf(p, p_end - p, "%.6f, ", (float) x[i]); - } - FARF(HIGH, "%s\n", str); -} - -static inline void htp_dump_fp32_line(char * pref, const float * x, uint32_t n) { - char str[1024], *p = str, *p_end = str + sizeof(str); - p += snprintf(p, p_end - p, "%s: ", pref); - for (int i = 0; i < n; i++) { - p += snprintf(p, p_end - p, "%.6f, ", x[i]); - } - FARF(HIGH, "%s\n", str); -} - -static inline void htp_dump_f32(char * pref, const float * x, uint32_t n) { - uint32_t n0 = n / 16; - uint32_t n1 = n % 16; - - uint32_t i = 0; - for (; i < n0; i++) { - htp_dump_fp32_line(pref, x + (16 * i), 16); - } - if (n1) { - htp_dump_fp32_line(pref, x + (16 * i), n1); - } -} - -static inline void htp_dump_f16(char * pref, const __fp16 * x, uint32_t n) { - uint32_t n0 = n / 16; - uint32_t n1 = n % 16; - - uint32_t i = 0; - for (; i < n0; i++) { - htp_dump_fp16_line(pref, x + (16 * i), 16); - } - if (n1) { - htp_dump_fp16_line(pref, x + (16 * i), n1); - } -} - -#endif /* OPS_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index a4399704fcb..943ca5c952e 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -2,27 +2,20 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif #include -#include #include -#include -#include -#include + #include -#include #include +#include "hex-dma.h" +#include "hvx-utils.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" #include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" // Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we cant include ggml.h #define HTP_ROPE_TYPE_NORMAL 0 @@ -370,8 +363,8 @@ static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int int is_aligned = 1; int opt_path = 0; - if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) || - (0 == htp_is_aligned((void *) dst->data, VLEN))) { + if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) || + (0 == hex_is_aligned((void *) dst->data, VLEN))) { FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n"); is_aligned = 0; } @@ -427,9 +420,9 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { // VTCM scratchpads for all tensors // N rows per thread, padded to HVX vector size - octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads; + octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; + octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; + octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index bdd64fcc8f7..904484da9de 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -2,24 +2,20 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif #include -#include #include -#include -#include + #include #include +#include "hex-dma.h" +#include "hvx-utils.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" #include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" #define set_rows_preamble \ const uint32_t ne00 = octx->src0.ne[0]; \ @@ -76,7 +72,7 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; // copy row - hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); } } } @@ -112,7 +108,7 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; - hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00); + hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00); } } } diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 80d249a22c6..1b6b2eba4ae 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -2,27 +2,20 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif #include -#include #include -#include -#include -#include + #include -#include #include +#include "hex-dma.h" +#include "hvx-utils.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" #include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" #define htp_softmax_preamble3 \ const uint32_t ne00 = src0->ne[0]; \ @@ -100,8 +93,8 @@ static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src, uint8_t * restrict dst_curr = dst; const uint8_t * restrict mask_curr = mask; - HVX_Vector scale_vec = hvx_vec_splat_fp32(scale); - HVX_Vector slope_vec = hvx_vec_splat_fp32(slope); + HVX_Vector scale_vec = hvx_vec_splat_f32(scale); + HVX_Vector slope_vec = hvx_vec_splat_f32(slope); int step_of_1 = num_elems >> 5; @@ -134,9 +127,9 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, HVX_Vector * restrict v_dst = (HVX_Vector *) dst; HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000); - HVX_Vector max_vec = hvx_vec_splat_fp32(((const float *) src)[0]); + HVX_Vector max_vec = hvx_vec_splat_f32(((const float *) src)[0]); HVX_Vector zero_v = Q6_V_vzero(); - HVX_Vector one_v = hvx_vec_splat_fp32(1.0); + HVX_Vector one_v = hvx_vec_splat_f32(1.0); int step_of_1 = num_elems >> 5; @@ -146,7 +139,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1); } - HVX_Vector v = hvx_vec_reduce_max_fp32(max_vec); + HVX_Vector v = hvx_vec_reduce_max_f32(max_vec); max_vec = hvx_vec_repl4(v); #pragma unroll(4) @@ -154,18 +147,18 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, HVX_Vector v1 = v_src[i]; HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, max_vec); - HVX_Vector v3 = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(v2)); + HVX_Vector v3 = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(v2)); sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), v3); v_pad[i] = v3; } - v = hvx_vec_qf32_reduce_sum(sum_vec); + v = hvx_vec_reduce_sum_qf32(sum_vec); sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v)); HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v); - HVX_Vector v4 = hvx_vec_inverse_fp32(sum_vec); + HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec); HVX_Vector scale_vec = Q6_V_vmux_QVV(pos_sum, v4, one_v); #pragma unroll(4) @@ -181,11 +174,11 @@ static float hvx_softmax_f32(const uint8_t * restrict src, uint8_t * restrict spad, const int num_elems, const float max) { - hvx_sub_scalar_f32(src, max, spad, num_elems); + hvx_sub_scalar_f32(spad, src, max, num_elems); hvx_exp_f32(spad, dst, num_elems, false); - float sum = hvx_self_sum_f32(dst, num_elems); + float sum = hvx_reduce_sum_f32(dst, num_elems); return sum; } @@ -255,7 +248,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct if (1 == opt_path) { hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); } else { - float max = hvx_self_max_f32((const uint8_t *) wp0, ne00); + float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00); float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); sum = sum > 0.0 ? (1.0 / sum) : 1; hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); @@ -290,7 +283,7 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int int is_aligned = 1; int opt_path = 0; - if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) { + if (!hex_is_aligned((void *) src0->data, VLEN) || !hex_is_aligned((void *) dst->data, VLEN)) { is_aligned = 0; FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n"); } @@ -345,9 +338,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { // VTCM scratchpads for all tensors // N rows per thread, padded to HVX vector size - octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads; + octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; + octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; + octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 8ed1e5b6619..be8be8c4e64 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -2,28 +2,20 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif - #include -#include #include -#include -#include -#include + #include -#include #include +#include "hex-dma.h" +#include "hvx-utils.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" #include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" #define htp_unary_preamble \ const uint32_t ne00 = src->ne[0]; \ @@ -55,7 +47,7 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, HVX_Vector * restrict v_dst = (HVX_Vector *) dst; HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); - HVX_Vector epsilon_v = hvx_vec_splat_fp32(epsilon); + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); int step_of_1 = num_elems >> 5; #pragma unroll(4) @@ -65,15 +57,15 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); } - HVX_Vector reduced_sum = hvx_vec_qf32_reduce_sum(sum_v); + HVX_Vector reduced_sum = hvx_vec_reduce_sum_qf32(sum_v); sum_v = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum)); - HVX_Vector t_v = hvx_vec_splat_fp32((float) num_elems); - HVX_Vector denom_v = hvx_vec_inverse_fp32(t_v); + HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); + HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v); HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v); - HVX_Vector scale_v = hvx_vec_rsqrt_fp32(Q6_Vsf_equals_Vqf32(mean_epsilon_v)); + HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v)); #pragma unroll(4) for (int i = 0; i < step_of_1; i++) { @@ -101,7 +93,7 @@ static void scale_htp_f32(const float * restrict src, float * restrict dst_local = dst + (ir * row_elems); if (ir + 1 < num_rows) { - htp_l2fetch(src_local + row_elems, 1, row_size, row_size); + hex_l2fetch(src_local + row_elems, row_size, row_size, 1); } hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); @@ -124,7 +116,7 @@ static void rms_norm_htp_f32(const float * restrict src, float * restrict dst_local = dst + (ir * row_elems); if (ir + 1 < num_rows) { - htp_l2fetch(src_local + row_elems, 1, row_size, row_size); + hex_l2fetch(src_local + row_elems, row_size, row_size, 1); } if (1 == opt_path) { @@ -168,9 +160,8 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src, int is_aligned = 1; int opt_path = 0; - if ((0 == htp_is_aligned((void *) src->data, VLEN)) || (0 == htp_is_aligned((void *) dst->data, VLEN))) { + if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) { is_aligned = 0; - FARF(HIGH, "unary-f32: unaligned addresses in unary op, possibly slower execution\n"); } if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { opt_path = 1; @@ -240,8 +231,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const size_t dst_row_size = dst->nb[1]; // VTCM scratchpads for all tensors - octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads; + octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; + octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; size_t spad_size = octx->src0_spad.size + octx->dst_spad.size; diff --git a/ggml/src/ggml-hexagon/htp/worker-pool.c b/ggml/src/ggml-hexagon/htp/worker-pool.c index cd38c2126c7..894815f46a5 100644 --- a/ggml/src/ggml-hexagon/htp/worker-pool.c +++ b/ggml/src/ggml-hexagon/htp/worker-pool.c @@ -7,10 +7,6 @@ #include #include -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif - #include "HAP_farf.h" #define WORKER_THREAD_STACK_SZ (2 * 16384) From 78a23d48302c7e458e69f7cb0491c4a0f337f4d5 Mon Sep 17 00:00:00 2001 From: shalinib-ibm Date: Thu, 15 Jan 2026 15:01:18 +0530 Subject: [PATCH 009/831] ggml-cpu: optimize ggml_vec_dot_bf16 for Power9 (llama/18837) --- ggml/src/ggml-cpu/simd-mappings.h | 31 +++++++++++++++++++++++++++++++ ggml/src/ggml-cpu/vec.cpp | 18 ++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index a7a82722052..e367f110b46 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -654,6 +654,14 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { vec_extract(x[0], 2) + \ vec_extract(x[0], 3); \ } +#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \ +{ \ + vector float v = vec_add(vec_add(s0, s1), \ + vec_add(s2, s3)); \ + v = vec_add(v, vec_sld(v, v, 8)); \ + v = vec_add(v, vec_sld(v, v, 4)); \ + res += (ggml_float) vec_extract(v, 0); \ +} #define GGML_F32_VEC GGML_F32x4 #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO @@ -690,6 +698,29 @@ static inline unsigned char ggml_endian_byte(int i) { r[i - GGML_ENDIAN_BYTE(0)]), \ 0, p - GGML_F16_EPR) +//BF16 POWER9 +#define GGML_BF16_STEP 16 +#define GGML_BF16_EPR 8 + +#define GGML_BF16x8 vector unsigned short +#define GGML_BF16x8_ZERO vec_splats((unsigned short)0) +#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p)) + +#define GGML_BF16_VEC GGML_BF16x8 +#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO +#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD +#if defined(__LITTLE_ENDIAN__) +#define GGML_BF16_TO_F32_LO(v) ((vector float) vec_mergel(GGML_BF16_VEC_ZERO, (v))) +#define GGML_BF16_TO_F32_HI(v) ((vector float) vec_mergeh(GGML_BF16_VEC_ZERO, (v))) +#else +#define GGML_BF16_TO_F32_LO(v) ((vector float) vec_mergel((v), GGML_BF16_VEC_ZERO)) +#define GGML_BF16_TO_F32_HI(v) ((vector float) vec_mergeh((v), GGML_BF16_VEC_ZERO)) +#endif +#define GGML_BF16_FMA_LO(acc, x, y) \ + (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y)) +#define GGML_BF16_FMA_HI(acc, x, y) \ + (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y)) + #elif defined(__wasm_simd128__) #define GGML_SIMD diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 427e63245b0..8708cd4e92f 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -237,6 +237,24 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * sumf += __riscv_vfmv_f_s_f32m1_f32(redsum); #endif +#if defined(__POWER9_VECTOR__) + const int np = (n & ~(GGML_BF16_STEP - 1)); + if (np > 0) { + GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO}; + for (; i < np; i += GGML_BF16_STEP) { + GGML_BF16_VEC vx0 = GGML_BF16_VEC_LOAD(x + i); + GGML_BF16_VEC vx1 = GGML_BF16_VEC_LOAD(x + i + 8); + GGML_BF16_VEC vy0 = GGML_BF16_VEC_LOAD(y + i); + GGML_BF16_VEC vy1 = GGML_BF16_VEC_LOAD(y + i + 8); + GGML_BF16_FMA_LO(sum[0], vx0, vy0); + GGML_BF16_FMA_HI(sum[1], vx0, vy0); + GGML_BF16_FMA_LO(sum[2], vx1, vy1); + GGML_BF16_FMA_HI(sum[3], vx1, vy1); + } + GGML_F32x4_REDUCE_4(sumf, sum[0], sum[1], sum[2], sum[3]); + } +#endif + for (; i < n; ++i) { sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) * GGML_BF16_TO_FP32(y[i])); From f2f0ba0384e55a880a75f6928aa4dfdbe801d320 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 15 Jan 2026 15:14:50 +0100 Subject: [PATCH 010/831] CUDA: fix allignment on register spill for FA (llama/18815) --- ggml/src/ggml-cuda/fattn-common.cuh | 4 +-- ggml/src/ggml-cuda/fattn-tile.cuh | 42 ++++++++++++++--------------- ggml/src/ggml-cuda/fattn-vec.cuh | 4 +-- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 6b55f784f34..8468ba8488d 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -59,7 +59,7 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { - half2 tmp[cpy_ne]; + __align__(16) half2 tmp[cpy_ne]; ggml_cuda_memcpy_1(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { @@ -309,7 +309,7 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_ ggml_cuda_memcpy_1(dst, (const half *) vx + i0); } else if constexpr (std::is_same_v) { static_assert(ne % 2 == 0, "bad ne"); - half2 tmp[ne/2]; + __align__(16) half2 tmp[ne/2]; ggml_cuda_memcpy_1(tmp, (const half *) vx + i0); float2 * dst_f2 = (float2 *) dst; #pragma unroll diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 7c4d6fe67fe..f055da8e2be 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -343,7 +343,7 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne; - const half2 zero[cpy_ne] = {{0.0f, 0.0f}}; + const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}}; ggml_cuda_memcpy_1( tile_KV + i*(J/2 + J_padding) + j, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); @@ -394,11 +394,11 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2); const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}}; - half2 tmp_h2[cpy_ne/2]; + __align__(16) half2 tmp_h2[cpy_ne/2]; ggml_cuda_memcpy_1( tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); - float2 tmp_f2[cpy_ne/2]; + __align__(16) float2 tmp_f2[cpy_ne/2]; #pragma unroll for (int l = 0; l < cpy_ne/2; ++l) { tmp_f2[l] = __half22float2(tmp_h2[l]); @@ -445,14 +445,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ( static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) { - half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne]; - half2 Q_k[cpw][cpy_ne]; + __align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + __align__(16) half2 Q_k[cpw][cpy_ne]; #else static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K"); #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) { - float K_k[nbatch_fa/(np*warp_size)][cpy_ne]; - float Q_k[cpw][cpy_ne]; + __align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + __align__(16) float Q_k[cpw][cpy_ne]; #endif // FAST_FP16_AVAILABLE #pragma unroll @@ -602,9 +602,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter( #pragma unroll for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) { #ifdef FAST_FP16_AVAILABLE - half tmp[nbatch_fa/(np*warp_size)][KQ_cs]; + __align__(16) half tmp[nbatch_fa/(np*warp_size)][KQ_cs]; #else - float tmp[nbatch_fa/(np*warp_size)][KQ_cs]; + __align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs]; #endif // FAST_FP16_AVAILABLE #pragma unroll @@ -664,8 +664,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter( #ifdef FAST_FP16_AVAILABLE #pragma unroll for (int k1 = 0; k1 < nbatch_V; k1 += np) { - half2 V_k[(DVp/2)/warp_size]; - half2 KQ_k[cpw]; + __align__(16) half2 V_k[(DVp/2)/warp_size]; + __align__(16) half2 KQ_k[cpw]; constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; #pragma unroll @@ -676,7 +676,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); - half tmp[KQ_cs]; + __align__(16) half tmp[KQ_cs]; ggml_cuda_memcpy_1( &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); #pragma unroll @@ -696,8 +696,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter( #else #pragma unroll for (int k1 = 0; k1 < nbatch_V; k1 += np) { - float2 V_k[(DVp/2)/warp_size]; - float KQ_k[cpw]; + __align__(16) float2 V_k[(DVp/2)/warp_size]; + __align__(16) float KQ_k[cpw]; constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; #pragma unroll @@ -821,12 +821,12 @@ static __global__ void flash_attn_tile( __shared__ half2 Q_tmp[ncols * DKQ/2]; __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV]; __shared__ half KQ[ncols * nbatch_fa]; - half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; + __align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; #else __shared__ float Q_tmp[ncols * DKQ]; __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV]; __shared__ float KQ[ncols * nbatch_fa]; - float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; + __align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; #endif // FAST_FP16_AVAILABLE float KQ_max[cpw]; @@ -849,7 +849,7 @@ static __global__ void flash_attn_tile( #pragma unroll for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) { - float tmp_f[cpy_ne_D] = {0.0f}; + __align__(16) float tmp_f[cpy_ne_D] = {0.0f}; ggml_cuda_memcpy_1 (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float)) + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); @@ -860,7 +860,7 @@ static __global__ void flash_attn_tile( } #ifdef FAST_FP16_AVAILABLE - half2 tmp_h2[cpy_ne_D/2]; + __align__(16) half2 tmp_h2[cpy_ne_D/2]; #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); @@ -959,7 +959,7 @@ static __global__ void flash_attn_tile( constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; #pragma unroll for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { - half2 tmp[cpy_ne_D]; + __align__(16) half2 tmp[cpy_ne_D]; ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]); #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { @@ -970,7 +970,7 @@ static __global__ void flash_attn_tile( constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; #pragma unroll for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { - float tmp[cpy_ne_D]; + __align__(16) float tmp[cpy_ne_D]; ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]); #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { @@ -1033,7 +1033,7 @@ static __global__ void flash_attn_tile( constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; #pragma unroll for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { - float2 tmp[cpy_ne_D]; + __align__(16) float2 tmp[cpy_ne_D]; #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]); diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 86f4dc0f7f1..3f4a78cc6e5 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -132,7 +132,7 @@ static __global__ void flash_attn_ext_vec( #ifdef V_DOT2_F32_F16_AVAILABLE half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely. #else - float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. + __align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. #endif // V_DOT2_F32_F16_AVAILABLE int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; @@ -200,7 +200,7 @@ static __global__ void flash_attn_ext_vec( for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; - float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; + __align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; if (ncols == 1 || ic0 + j < int(ne01.z)) { ggml_cuda_memcpy_1(tmp, &Q_j[i]); ggml_cuda_memcpy_1(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); From 290ff3d28da31c8325a4e2f43d166fe18a00e7c0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 15 Jan 2026 20:53:01 +0200 Subject: [PATCH 011/831] cuda : print less debug logs when disabling cuda graphs (llama/18868) --- ggml/src/ggml-cuda/ggml-cuda.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 553623fbd42..ed1021469a7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3730,8 +3730,10 @@ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) { if (cuda_ctx->cuda_graph->graph == nullptr) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { + if (!cuda_ctx->cuda_graph->disable_due_to_gpu_arch) { + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); + } cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); } } From ed6004d051b9b914d4bb94fe9595cb0f6df93aa5 Mon Sep 17 00:00:00 2001 From: shaofeiqi <109865877+shaofeiqi@users.noreply.github.com> Date: Thu, 15 Jan 2026 11:17:17 -0800 Subject: [PATCH 012/831] OpenCL: add SOLVE_TRI op support (llama/18846) --- ggml/src/ggml-opencl/CMakeLists.txt | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 92 +++++++++++++++++++++++ ggml/src/ggml-opencl/kernels/solve_tri.cl | 51 +++++++++++++ 3 files changed, 144 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/solve_tri.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index d8fa53109b7..307ec08242a 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -69,6 +69,7 @@ set(GGML_OPENCL_KERNELS get_rows glu group_norm + solve_tri im2col_f32 im2col_f16 mean diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index d925f67f065..d89d5e7242d 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -531,6 +531,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat; + cl_kernel kernel_solve_tri_f32; cl_kernel kernel_im2col_f32, kernel_im2col_f16; cl_kernel kernel_argsort_f32_i32; cl_kernel kernel_sum_rows_f32; @@ -952,6 +953,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // solve_tri_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "solve_tri.cl.h" + }; +#else + const std::string kernel_src = read_file("solve_tri.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_solve_tri_f32 = clCreateKernel(prog, "kernel_solve_tri_f32", &err), err)); + GGML_LOG_CONT("."); + CL_CHECK(clReleaseProgram(prog)); + } + // im2col_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3266,6 +3284,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } return true; } + case GGML_OP_SOLVE_TRI: + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); case GGML_OP_IM2COL: return true; case GGML_OP_ARGSORT: { @@ -9474,6 +9494,72 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_cl_solve_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel = backend_ctx->kernel_solve_tri_f32; + GGML_ASSERT(kernel != nullptr); + + const int n = src0->ne[0]; + const int k = src1->ne[0]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const cl_ulong nb10 = src1->nb[0]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; + + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &n)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &k)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong),&nb10)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong),&nb11)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong),&nb12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong),&nb13)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb3)); + + size_t global_work_size[3]= { (size_t)k, (size_t)dst->ne[2], (size_t)dst->ne[3]}; + size_t local_work_size[] = {16, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src1); @@ -10039,6 +10125,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_rope; break; + case GGML_OP_SOLVE_TRI: + if (!any_on_device) { + return false; + } + func = ggml_cl_solve_tri; + break; case GGML_OP_IM2COL: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/solve_tri.cl b/ggml/src/ggml-opencl/kernels/solve_tri.cl new file mode 100644 index 00000000000..80745fc7045 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/solve_tri.cl @@ -0,0 +1,51 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// solve_tri +//------------------------------------------------------------------------------ +kernel void kernel_solve_tri_f32( + global uchar * src0, + ulong offset0, + global uchar * src1, + ulong offset1, + global uchar * dst, + ulong offsetd, + int n, + int k, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + int col = get_global_id(0); + int i2 = get_global_id(1); + int i3 = get_global_id(2); + + global const uchar * Lb = src0 + offset0 + i2 * nb02 + i3 * nb03; + global const uchar * Bb = src1 + offset1 + i2 * nb12 + i3 * nb13; + global uchar * Xb = dst + offsetd + i2 * nb2 + i3 * nb3; + + for(int row = 0; row < n; ++row){ + global const float *pB = (global const float *)(Bb + row * nb11 + col * nb10); + + float sum = 0.0f; + for(int j = 0; j < row; ++j){ + global const float *pL = (global const float *)(Lb + row * nb01 + j * nb00); + global const float *pX = (global const float *)(Xb + j * nb1 + col * nb0); + sum += (*pL) * (*pX); + } + + global const float * pDiag = (global const float *)(Lb + row * nb01 + row *nb00); + global float * pOut = (global float *)(Xb + row * nb1 + col *nb0); + + *pOut = ((* pB) - sum) / (*pDiag); + } +} From 854274a297335c8e91df00cc14e3deb802b73367 Mon Sep 17 00:00:00 2001 From: hipudding Date: Fri, 16 Jan 2026 16:18:49 +0800 Subject: [PATCH 013/831] CANN: support gated linear attn (llama/18653) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CANN: support gated linear attn This change adds support for the GGML_OP_GATED_LINEAR_ATTN operator. The feature was implemented by YushengZhao. Because the previous submission was based on an outdated codebase, this PR was rebased to merge. Co-authored-by: YushengZhao Co-authored-by: hipudding * CANN: optimize OP gla Optimize gla for high preformance * Remove unused comments --------- Co-authored-by: 赵禹昇 <2501112001@cninfer02.localdomain> Co-authored-by: YushengZhao --- ggml/src/ggml-cann/aclnn_ops.cpp | 220 ++++++++++++++++++++----------- ggml/src/ggml-cann/aclnn_ops.h | 123 ++++++----------- ggml/src/ggml-cann/ggml-cann.cpp | 4 + 3 files changed, 186 insertions(+), 161 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 6b718e01c31..02867e4fdb5 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -58,6 +58,7 @@ #include #include #include +#include #include #include #include @@ -2338,20 +2339,21 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx, // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor. // TODO: acl_yarn_ramp_tensor use rope cache. - bool yarn_ramp_tensor_updated = false; - acl_tensor_ptr acl_yarn_ramp_tensor; + bool yarn_ramp_tensor_updated = false; + acl_tensor_ptr acl_yarn_ramp_tensor; if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) { yarn_ramp_tensor_updated = true; if (ctx.rope_cache.yarn_ramp_cache != nullptr) { ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache)); } - ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), + ACL_MEM_MALLOC_HUGE_FIRST)); // -rope_yarn_ramp // const float y = (i0 / 2 - low) / MAX(0.001f, high - low); // return MIN(1, MAX(0, y)) - 1; - acl_yarn_ramp_tensor = - ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); + acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), + theta_scale_ne, theta_scale_nb, 1); float zero_value = 0, one_value = 1; float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT); @@ -2382,8 +2384,8 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx, GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get()); } else { - acl_yarn_ramp_tensor = - ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); + acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), + theta_scale_ne, theta_scale_nb, 1); } // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale. if (ext_factor != 0) { @@ -2991,20 +2993,20 @@ void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get()); } -void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){ +void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; // stride - int64_t s0 = ((const int32_t*)(dst->op_params))[0]; + int64_t s0 = ((const int32_t *) (dst->op_params))[0]; - acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL); acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL); - acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL); // get base information of input and kernel - int64_t input_len = *(src1->ne); - int64_t dst_len = *(dst->ne); + int64_t input_len = *(src1->ne); + int64_t dst_len = *(dst->ne); int64_t kernel_size = *(src0->ne); // set the max kernel size for each conv @@ -3012,56 +3014,55 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds // compute the partition of kernel int64_t part_num = 1; - part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size; + part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size; int64_t strideVal[1]; - strideVal[0] = s0; - acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1); - int64_t paddingVal[] = {0}; - acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1); - int64_t dilationVal[] = {1}; - acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1); - bool transposed = true; - int64_t groups = 1; - int8_t cubeMathType = 0; + strideVal[0] = s0; + acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1); + int64_t paddingVal[] = { 0 }; + acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1); + int64_t dilationVal[] = { 1 }; + acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1); + bool transposed = true; + int64_t groups = 1; + int8_t cubeMathType = 0; #ifdef ASCEND_310P cubeMathType = 1; #endif auto weight_type = ggml_cann_type_mapping(src0->type); - auto dst_type = ggml_cann_type_mapping(dst->type); + auto dst_type = ggml_cann_type_mapping(dst->type); // slice the kernel to make each conv available - int64_t slice_dim = -1; + int64_t slice_dim = -1; int64_t slice_start = 0; - int64_t slice_end = max_kernel_size; - int64_t slice_step = 1; - int64_t interval = max_kernel_size; + int64_t slice_end = max_kernel_size; + int64_t slice_step = 1; + int64_t interval = max_kernel_size; - int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0]; + int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0]; int64_t right_pad_len = 0; - acl_scalar_ptr alpha = nullptr; - float alphaValue = 1.0; - alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT); + acl_scalar_ptr alpha = nullptr; + float alphaValue = 1.0; + alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT); // set zero to destination GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get()); - for(int k = 0; k < part_num; k++){ - + for (int k = 0; k < part_num; k++) { // create part kernel tensor and slice from big kernel slice_start = max_kernel_size * k; - if(k == part_num - 1){ + if (k == part_num - 1) { slice_end = kernel_size; - interval = kernel_size - max_kernel_size * k; - }else{ - slice_end = max_kernel_size * (k+1); + interval = kernel_size - max_kernel_size * k; + } else { + slice_end = max_kernel_size * (k + 1); } int64_t part_ne[4]; - for(int i = 0; i < 4; i++) { + for (int i = 0; i < 4; i++) { part_ne[i] = *(src0->ne + i); } part_ne[0] = interval; @@ -3074,16 +3075,17 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds ggml_cann_pool_alloc part_kernel_allocator; part_kernel_allocator.alloc(ctx.pool(), part_nb[3]); - void* part_kernel_buf = part_kernel_allocator.get(); + void * part_kernel_buf = part_kernel_allocator.get(); - acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type, - ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type, ggml_element_size(src0), + part_ne, part_nb, 3, ACL_FORMAT_NCL); - GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, + part_kernel.get()); // create the part conv result tensor int64_t part_dst_ne[4]; - for(int i = 0; i < 4; i++){ + for (int i = 0; i < 4; i++) { part_dst_ne[i] = *(dst->ne + i); } part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1; @@ -3095,32 +3097,33 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds } ggml_cann_pool_alloc part_dst_allocator; part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]); - void* part_dst_buf = part_dst_allocator.get(); + void * part_dst_buf = part_dst_allocator.get(); acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst), - part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL); + part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst.get()); // compute part conv transpose 1d GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(), - padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType); + padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), + cubeMathType); // compute the position of part result in final result int64_t global_start = slice_start; - int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len); + int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len); - left_pad_len = global_start; + left_pad_len = global_start; right_pad_len = dst_len - global_end; - std::vector padDataVal = {left_pad_len,right_pad_len}; - acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2); + std::vector padDataVal = { left_pad_len, right_pad_len }; + acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2); - acl_scalar_ptr pad_value = nullptr; - float pad_valueVal = 0.0; - pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT); + acl_scalar_ptr pad_value = nullptr; + float pad_valueVal = 0.0; + pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT); int64_t conv_result_ne[4]; - for(int i = 0; i < 4; i++){ + for (int i = 0; i < 4; i++) { conv_result_ne[i] = *(dst->ne + i); } @@ -3132,13 +3135,14 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds ggml_cann_pool_alloc conv_result_allocator; conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]); - void* conv_result_buf = conv_result_allocator.get(); + void * conv_result_buf = conv_result_allocator.get(); acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst), - conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL); + conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), + conv_result.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get()); } } @@ -3742,15 +3746,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // we want a view: ne_w = { nc, 1, nr } // [K, 1, C] // so that reversed dims -> [C, 1, K] which matches // [out_channels, in_channels/groups, kernel_size] - int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups] + int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups] // Layout: src1 data is [K, C] with // offset(k, c) = k*nb0 + c*nb1 // We want offset_w(k, 0, c) = k*nb0 + c*nb1, // so we can reuse nb0 and nb1, and set nb2 = nb1. - size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1 + size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1 - acl_tensor_ptr acl_w = ggml_cann_create_tensor( - src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr acl_w = ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), + ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL); // 3) Output: dst is { d_inner, n_t, n_s } (CLN) // @@ -3768,11 +3772,12 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // nb_y[0] = nr * sizeof(float); // step in L // nb_y[1] = sizeof(float); // step in C // nb_y[2] = nr * n_t * sizeof(float); // step in N - int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N] - size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), dst->nb[3] }; // [nr, 1, nr * n_t] + int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N] + size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), + dst->nb[3] }; // [nr, 1, nr * n_t] - acl_tensor_ptr acl_y = ggml_cann_create_tensor( - dst->data, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr acl_y = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL); // --- Conv1d parameters: depthwise, stride 1, no padding ("valid") --- int64_t strideVal[1] = { 1 }; @@ -3791,22 +3796,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) { cubeMathType = 1; #endif - GGML_CANN_CALL_ACLNN_OP(ctx, - Convolution, + GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_x.get(), // input: N, C, L_in = ncs acl_w.get(), // weight: [C, 1, K] with groups=nr nullptr, // bias - stride.get(), - padding.get(), - dilation.get(), - transposed, - padding.get(), // output padding (unused for non-transposed) - groups, - acl_y.get(), - cubeMathType); + stride.get(), padding.get(), dilation.get(), transposed, + padding.get(), // output padding (unused for non-transposed) + groups, acl_y.get(), cubeMathType); } - void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node) { @@ -3860,3 +3858,71 @@ void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, eps, // double type acl_yout.get(), acl_rstd.get(), acl_xout.get()); } + +void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * k = dst->src[0]; + ggml_tensor * v = dst->src[1]; + ggml_tensor * q = dst->src[2]; + ggml_tensor * g = dst->src[3]; + ggml_tensor * s = dst->src[4]; + + int64_t B = dst->src[4]->ne[1]; + int64_t T = dst->src[0]->ne[2]; + int64_t H = dst->src[0]->ne[1]; + int64_t C = dst->ne[0]; + int64_t D = C / H; + int64_t L = T / B; + + int64_t ne_qkg[2] = { 1, D }; + int64_t ne_s[2] = { D, D }; + int64_t ne_st[2] = { ne_s[1], ne_s[0] }; + int64_t ne_vo[2] = { D, 1 }; + int64_t ne_q[1] = { D }; + size_t nb_base = ggml_type_size(k->type); + size_t nb_qkg[2] = { nb_base, nb_base }; + size_t nb_s[2] = { nb_base, D * nb_base }; + size_t nb_st[2] = { nb_s[1], nb_s[0] }; + size_t nb_vo[2] = { nb_base, D * nb_base }; + size_t nb_q[1] = { nb_base }; + + const float scale = ggml_get_op_params_f32(dst, 0); + + acl_tensor_ptr acl_s = ggml_cann_create_tensor(s, s->ne, s->nb, 2, ACL_FORMAT_ND); + acl_tensor_ptr new_state = ggml_cann_create_tensor(dst, s->ne, s->nb, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base); + cann_copy(ctx, acl_s.get(), new_state.get()); + + for (int64_t b = 0; b < B; b++) { + for (int64_t h = 0; h < H; h++) { + size_t s_offset = (b * (H * D * D) + h * (D * D)) * nb_base; + // D * D + acl_tensor_ptr acl_s_new = + ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset); + acl_tensor_ptr acl_s_new_t = + ggml_cann_create_tensor(dst, ne_st, nb_st, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset); + for (int64_t l = 0; l < L; l++) { + size_t qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * (D)) * nb_base; + // D * 1 + acl_tensor_ptr acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset); + acl_tensor_ptr acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset); + // D + acl_tensor_ptr acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset); + // 1 * D + acl_tensor_ptr acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset); + // D + acl_tensor_ptr acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset); + // k ⊗ v + size_t buf_size = D * D * nb_base; + ggml_cann_pool_alloc buffer_allocator(ctx.pool(), buf_size); + acl_tensor_ptr tmp_tensor = ggml_cann_create_tensor( + buffer_allocator.get(), ggml_cann_type_mapping(k->type), nb_base, ne_s, nb_s, 2); + aclnn_mul(ctx, acl_k.get(), acl_v.get(), tmp_tensor.get()); + //s_new = g ⊗ s_old + k ⊗ v + aclnn_mul(ctx, acl_s_new.get(), acl_g.get(), nullptr); + aclnn_add(ctx, acl_s_new.get(), tmp_tensor.get(), nullptr); + // compute output + GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_new_t.get(), acl_q.get(), acl_o.get(), 1); + aclnn_muls(ctx, acl_o.get(), scale, nullptr, true); + } + } + } +} diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 08ee7b1fbdf..b76e4707ac7 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -814,67 +814,20 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst); */ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst); -/* - * @brief A generic wrapper for ACL resources with custom deleter support. - */ -using any_acl_resource = std::unique_ptr>; - /** - * @brief Trait structure used to define how to destroy a given ACL resource type. + * @brief Forward Gated Linear Attention on the CANN backend. * - * @tparam T ACL resource type. - */ -template struct acl_resource_traits; - -/** - * @brief Specialization for aclTensor, defines how to destroy an aclTensor resource. - */ -template <> struct acl_resource_traits { - static void destroy(void * p) { ACL_CHECK(aclDestroyTensor(static_cast(p))); } -}; - -/** - * @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource. - */ -template <> struct acl_resource_traits { - static void destroy(void * p) { ACL_CHECK(aclDestroyIntArray(static_cast(p))); } -}; - -/** - * @brief Specialization for aclScalar, defines how to destroy an aclScalar resource. - */ -template <> struct acl_resource_traits { - static void destroy(void * p) { ACL_CHECK(aclDestroyScalar(static_cast(p))); } -}; - -/** - * @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource. - */ -template <> struct acl_resource_traits { - static void destroy(void * p) { ACL_CHECK(aclDestroyTensorList(static_cast(p))); } -}; - -/** - * @brief Creates a generic ACL resource wrapper with proper destruction logic. + * Expects dst->src[0..4] = {k, v, q, g, s} with shape conventions: + * k, v, q, g: [D] with outer dims T x H batched as ne[2]=T, ne[1]=H + * s: initial state [B, H, D, D], where B is batch and D=C/H + * dst holds both outputs (o) and updated state; a scale factor is read from op params. * - * @tparam T ACL resource type. - * @param ptr Raw pointer to ACL resource. - * @return any_acl_resource Smart pointer that handles destruction. - */ -template any_acl_resource make_acl_resource(T * ptr) { - return any_acl_resource(static_cast(ptr), [](void * p) { acl_resource_traits::destroy(p); }); -} - -/** - * @brief Registers multiple ACL resources into a vector for lifetime management. + * The kernel updates per time step l: S_new = g ⊗ S_old + k ⊗ v, then computes o = (S_new^T q) * scale. * - * @tparam Args Variadic list of ACL resource types. - * @param vec Target vector to hold ACL resources. - * @param args Raw pointers to ACL resources. + * @param ctx Backend context providing stream/allocator utilities. + * @param dst Output tensor; src deps are k, v, q, g, s as above. */ -template void register_acl_resources(std::vector & vec, Args *... args) { - (vec.emplace_back(make_acl_resource(args)), ...); -} +void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Launches an asynchronous task using the memory allocator. @@ -894,19 +847,19 @@ template void register_acl_resources(std::vector 0) { \ - ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \ - workspaceAddr = workspace_allocator.get(); \ - } \ - ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \ - } while (0) +# define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \ + do { \ + uint64_t workspaceSize = 0; \ + aclOpExecutor * executor; \ + void * workspaceAddr = nullptr; \ + ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \ + /* workspace should alloced in main thread to keep malloc order when using vmm. */ \ + if (workspaceSize > 0) { \ + ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \ + workspaceAddr = workspace_allocator.get(); \ + } \ + ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \ + } while (0) /** * @brief Performs sparse expert-based matrix multiplication using the CANN backend. @@ -947,7 +900,9 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst); * @param rms_norm_tensor The RMS_NORM operation node, contains the gamma weights * and epsilon parameter. */ -void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node); +void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, + ggml_tensor * add_node, + ggml_tensor * rms_norm_node); /** * @brief Check whether a tensor is a weight tensor for matrix multiplication. @@ -1104,13 +1059,13 @@ void ggml_cann_op_unary_gated(std::function Date: Fri, 16 Jan 2026 16:24:04 +0800 Subject: [PATCH 014/831] CANN: fix an issue where get_env was not fully renamed (llama/18796) * CANN: fix an issue where get_env was not fully renamed * ci: add cann with acl group * ci: define use_acl_graph using GitHub Action * ci: update cann dockerfile with acl graph --- ggml/src/ggml-cann/common.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 6895349b207..70d3f2b225d 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -382,7 +382,7 @@ struct ggml_cann_graph_lru_cache { std::list cache_list; /**< List storing cached graphs as raw pointers. */ - ggml_cann_graph_lru_cache() { capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); } + ggml_cann_graph_lru_cache() { capacity = parse_integer(get_env_as_lowercase("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); } /** * @brief Push a new graph to the front of the cache. @@ -574,7 +574,7 @@ struct ggml_backend_cann_context { description = aclrtGetSocName(); #ifdef USE_ACL_GRAPH - acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on")); + acl_graph_mode = parse_bool(get_env_as_lowercase("GGML_CANN_ACL_GRAPH").value_or("on")); GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER", acl_graph_mode ? "acl graph enabled" : "acl graph disabled"); #endif From 42960b6073def996ea5b3013c9135d1d3c6b53d5 Mon Sep 17 00:00:00 2001 From: Raul Torres <138264735+rauletorresc@users.noreply.github.com> Date: Fri, 16 Jan 2026 08:34:09 +0000 Subject: [PATCH 015/831] CANN: Remove unused `ggml_cann_get_device` function (llama/18625) --- ggml/src/ggml-cann/common.h | 1 - ggml/src/ggml-cann/ggml-cann.cpp | 11 ----------- 2 files changed, 12 deletions(-) diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 70d3f2b225d..fb3e7572e2c 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -101,7 +101,6 @@ struct ggml_cann_device_info { const ggml_cann_device_info & ggml_cann_info(); void ggml_cann_set_device(int32_t device); -int32_t ggml_cann_get_device(); std::optional get_env_as_lowercase(const std::string & name); bool parse_bool(const std::string & value); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index a5ca85c4470..eba83327f13 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -93,17 +93,6 @@ void ggml_cann_set_device(const int32_t device) { g_current_cann_device = device; } -/** - * @brief Retrieves the current device ID. - * - * @return The current device ID. - */ -int32_t ggml_cann_get_device() { - int32_t id; - ACL_CHECK(aclrtGetDevice(&id)); - return id; -} - /** * @brief Get the value of the specified environment variable (name) as lowercase. * if not empty, return a std::string object From ecb4b80c35934d446d3fb9d7786e8b79d32db97f Mon Sep 17 00:00:00 2001 From: Perry Naseck <4472083+DaAwesomeP@users.noreply.github.com> Date: Fri, 16 Jan 2026 06:38:25 -0500 Subject: [PATCH 016/831] ggml-blas: hide warnings from included BLAS headers (llama/18818) * fix compile def openblas, blis for compat libs, nvpl compile def, warn if no blas vendor set * ggml-blas: hide warnings from included BLAS headers --- ggml/src/ggml-blas/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-blas/CMakeLists.txt b/ggml/src/ggml-blas/CMakeLists.txt index fb0936f47b7..c27dc174c00 100644 --- a/ggml/src/ggml-blas/CMakeLists.txt +++ b/ggml/src/ggml-blas/CMakeLists.txt @@ -93,7 +93,7 @@ if (BLAS_FOUND) endif() target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES}) - target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS}) + target_include_directories(ggml-blas SYSTEM PRIVATE ${BLAS_INCLUDE_DIRS}) else() message(FATAL_ERROR "BLAS not found, please refer to " "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" From 511ca7a1f4fe0e5341190f57c0a181c656d09b97 Mon Sep 17 00:00:00 2001 From: Thore Koritzius Date: Fri, 16 Jan 2026 15:59:56 +0100 Subject: [PATCH 017/831] ggml : extend ggml_pool_1d + metal (llama/16429) * chore: resolve conflicts * feat: ggml metal impl * fix: ggml_metal_kargs_pool_1d struct * fix: require contiguous input * chore: test pool_1d * chore: limit pool1d test cases to p0=0 and s0=k0 to conform with asserts * chore: add p0 and s0 to testing * fix: allow padding for cpu and metal * Update ggml/src/ggml-metal/ggml-metal.metal * fix: correct single-threaded loop * ggml : cleanup * tests : add ne[1] != 1 tests * fix: ne[1] handling in np * cont : fixes --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cpu/ops.cpp | 100 ++++++++++++++-------- ggml/src/ggml-metal/ggml-metal-device.cpp | 25 ++++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 4 +- ggml/src/ggml-metal/ggml-metal-impl.h | 9 ++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 52 +++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 68 +++++++++++++++ ggml/src/ggml.c | 5 ++ 9 files changed, 226 insertions(+), 39 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 3032783971d..387e2fe42c3 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7,10 +7,9 @@ #include "unary-ops.h" #include "vec.h" -#include #include +#include #include -#include // ggml_compute_forward_dup @@ -7110,12 +7109,13 @@ void ggml_compute_forward_conv_2d_dw( } } -// ggml_compute_forward_pool_1d_sk_p0 - -static void ggml_compute_forward_pool_1d_sk_p0( +// ggml_compute_forward_pool_1d_ksp +static void ggml_compute_forward_pool_1d_ksp( const ggml_compute_params * params, const ggml_op_pool op, const int k, + const int s, + const int p, ggml_tensor * dst) { const ggml_tensor * src = dst->src[0]; @@ -7126,39 +7126,56 @@ static void ggml_compute_forward_pool_1d_sk_p0( return; } - const char * cdata = (const char *)src->data; - const char * const data_end = cdata + ggml_nbytes(src); - float * drow = (float *)dst->data; + const int64_t IW = src->ne[0]; + const int64_t OW = dst->ne[0]; - const int64_t rs = dst->ne[0]; + const int64_t nr = ggml_nrows(src); - while (cdata < data_end) { - const void * srow = (const void *)cdata; - int j = 0; - for (int64_t i = 0; i < rs; ++i) { + for (int64_t ir = 0; ir < nr; ++ir) { + const char * srow_bytes = (const char *) src->data + ir * src->nb[1]; + float * drow = (float *) (( char *) dst->data + ir * dst->nb[1]); + + for (int64_t ow = 0; ow < OW; ++ow) { + float res = 0; switch (op) { - case GGML_OP_POOL_AVG: drow[i] = 0; break; - case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break; + case GGML_OP_POOL_AVG: res = 0.0f; break; + case GGML_OP_POOL_MAX: res = -FLT_MAX; break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } + + int count = 0; + const int base = (int) ow * s - p; + for (int ki = 0; ki < k; ++ki) { - const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); + const int j = base + ki; + if (j < 0 || j >= (int) IW) { + continue; + } + + float v; + if (src->type == GGML_TYPE_F32) { + v = ((const float *) srow_bytes)[j]; + } else { + v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]); + } + switch (op) { - case GGML_OP_POOL_AVG: drow[i] += srow_j; break; - case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break; - case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); + case GGML_OP_POOL_AVG: res += v; break; + case GGML_OP_POOL_MAX: res = std::max(v, res); break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } - ++j; + + ++count; } + switch (op) { - case GGML_OP_POOL_AVG: drow[i] /= k; break; - case GGML_OP_POOL_MAX: break; + case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break; + case GGML_OP_POOL_MAX: break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } - } - cdata += src->nb[1]; - drow += rs; + drow[ow] = res; + } } } @@ -7173,10 +7190,8 @@ void ggml_compute_forward_pool_1d( const int k0 = opts[1]; const int s0 = opts[2]; const int p0 = opts[3]; - GGML_ASSERT(p0 == 0); // padding not supported - GGML_ASSERT(k0 == s0); // only s = k supported - ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst); + ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst); } // ggml_compute_forward_pool_2d @@ -7194,6 +7209,7 @@ void ggml_compute_forward_pool_2d( } const int32_t * opts = (const int32_t *)dst->op_params; + ggml_op_pool op = static_cast(opts[0]); const int k0 = opts[1]; const int k1 = opts[2]; @@ -7217,11 +7233,13 @@ void ggml_compute_forward_pool_2d( while (cdata < data_end) { for (int oy = 0; oy < py; ++oy) { float * const drow = dplane + oy * px; + float * const out = drow; + for (int ox = 0; ox < px; ++ox) { - float * const out = drow + ox; + float res = 0; switch (op) { - case GGML_OP_POOL_AVG: *out = 0; break; - case GGML_OP_POOL_MAX: *out = -FLT_MAX; break; + case GGML_OP_POOL_AVG: res = 0; break; + case GGML_OP_POOL_MAX: res = -FLT_MAX; break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } @@ -7229,24 +7247,32 @@ void ggml_compute_forward_pool_2d( const int iy = offset1 + oy * s1; for (int ky = 0; ky < k1; ++ky) { - if (iy + ky < 0 || iy + ky >= src->ne[1]) continue; + if (iy + ky < 0 || iy + ky >= src->ne[1]) { + continue; + } + const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky)); for (int kx = 0; kx < k0; ++kx) { int j = ix + kx; - if (j < 0 || j >= src->ne[0]) continue; + if (j < 0 || j >= src->ne[0]) { + continue; + } + const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); switch (op) { - case GGML_OP_POOL_AVG: *out += srow_j; break; - case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break; + case GGML_OP_POOL_AVG: res += srow_j; break; + case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } } } switch (op) { - case GGML_OP_POOL_AVG: *out /= ka; break; - case GGML_OP_POOL_MAX: break; + case GGML_OP_POOL_AVG: res /= ka; break; + case GGML_OP_POOL_MAX: break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } + + out[ox] = res; } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index b0734797f19..04c6137c5a7 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -94,6 +94,31 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_l return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type); + + const char * pool_str = "undefined"; + switch (op_pool) { + case GGML_OP_POOL_AVG: pool_str = "avg"; break; + case GGML_OP_POOL_MAX: pool_str = "max"; break; + default: GGML_ASSERT(false && "not implemented"); + }; + + char base[256]; + char name[256]; + + snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type)); + snprintf(name, sizeof(name), "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { GGML_ASSERT(ggml_is_contiguous(op->src[0])); GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 9c3b0014878..3d01c56fb81 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -104,6 +104,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index ff899a81709..c418afe9c3b 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1044,10 +1044,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); - case GGML_OP_POOL_1D: - return false; case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); + case GGML_OP_POOL_1D: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_POOL_2D: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index d3b0e732ec4..59d88b01a55 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -928,6 +928,15 @@ typedef struct { int64_t np; } ggml_metal_kargs_pool_2d; +typedef struct { + int32_t k0; + int32_t s0; + int32_t p0; + int64_t IW; + int64_t OW; + int64_t np; +} ggml_metal_kargs_pool_1d; + typedef struct { int64_t ne00; uint64_t nb01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index a50b12b6f3b..680ad794de9 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -432,6 +432,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_cpy(ctx, idx); } break; + case GGML_OP_POOL_1D: + { + n_fuse = ggml_metal_op_pool_1d(ctx, idx); + } break; case GGML_OP_POOL_2D: { n_fuse = ggml_metal_op_pool_2d(ctx, idx); @@ -1622,6 +1626,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t * opts = op->op_params; + ggml_op_pool op_pool = (ggml_op_pool) opts[0]; + + const int32_t k0 = opts[1]; + const int32_t s0 = opts[2]; + const int32_t p0 = opts[3]; + + const int64_t IW = op->src[0]->ne[0]; + const int64_t OW = op->ne[0]; + + const int64_t np = ggml_nelements(op); + + ggml_metal_kargs_pool_1d args_pool_1d = { + /* .k0 = */ k0, + /* .s0 = */ s0, + /* .p0 = */ p0, + /* .IW = */ IW, + /* .OW = */ OW, + /* .np = */ np + }; + + auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np); + const int ntg = (np + nth - 1) / nth; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1); + + return 1; +} + + int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index c1025d35677..10686a334e0 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -61,6 +61,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_mul_mat (ggml_metal_op_t ctx, int idx); int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 16d17d26af8..a4e1cafe552 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -9869,6 +9869,74 @@ kernel void kernel_pool_2d_avg_f32( o_ptr[cur_oh * args.OW + cur_ow] = res; } + +kernel void kernel_pool_1d_max_f32( + constant ggml_metal_kargs_pool_1d & args, + device const float * src, + device float * dst, + uint gid [[thread_position_in_grid]] +) { + + if (gid >= args.np) { + return; + } + + const int ow = (int)gid % args.OW; + const int row = (int)gid / args.OW; + + const int base = ow * args.s0 - args.p0; + + float acc = -INFINITY; + + const int src_off = row * args.IW; + const int dst_off = row * args.OW; + + for (int ki = 0; ki < args.k0; ++ki) { + int j = base + ki; + if (j < 0 || j >= args.IW){ + continue; + } + float v = src[src_off + j]; + acc = max(acc, v); + } + + dst[dst_off + ow] = acc; +} + +kernel void kernel_pool_1d_avg_f32( + constant ggml_metal_kargs_pool_1d & args, + device const float * src, + device float * dst, + uint gid [[thread_position_in_grid]] +) { + + if (gid >= args.np) { + return; + } + + const int ow = (int)gid % args.OW; + const int row = (int)gid / args.OW; + + const int base = ow * args.s0 - args.p0; + + float acc = 0.0f; + int cnt = 0; + + const int src_off = row * args.IW; + const int dst_off = row * args.OW; + + for (int ki = 0; ki < args.k0; ++ki) { + const int j = base + ki; + if (j < 0 || j >= args.IW) { + continue; + } + acc += src[src_off + j]; + cnt += 1; + } + + dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f; +} + kernel void kernel_opt_step_adamw_f32( constant ggml_metal_kargs_opt_step_adamw & args, device float * x, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 09b8eb466d3..c75fe7d2716 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4838,6 +4838,8 @@ struct ggml_tensor * ggml_pool_1d( a->ne[2], a->ne[3], }; + GGML_ASSERT(ne[0] > 0); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); int32_t params[] = { op, k0, s0, p0 }; @@ -4868,6 +4870,9 @@ struct ggml_tensor * ggml_pool_2d( a->ne[2], a->ne[3], }; + GGML_ASSERT(ne[0] > 0); + GGML_ASSERT(ne[1] > 0); + result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; From 389dafc7c2bff8cc822fe3870b45d7c6b223da01 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 30 Jan 2026 10:32:34 +0200 Subject: [PATCH 018/831] ggml webgpu: support for backend sampling (llama/18880) --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 371 +++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 1031 +++++++++++------ ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl | 72 ++ .../src/ggml-webgpu/wgsl-shaders/argsort.wgsl | 106 ++ .../wgsl-shaders/argsort_merge.wgsl | 134 +++ .../ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl | 6 + ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl | 66 ++ ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl | 86 ++ .../wgsl-shaders/set_rows.tmpl.wgsl | 112 -- .../ggml-webgpu/wgsl-shaders/set_rows.wgsl | 52 +- .../ggml-webgpu/wgsl-shaders/sum_rows.wgsl | 55 + ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 179 +++ .../ggml-webgpu/wgsl-shaders/unary_op.wgsl | 483 -------- 13 files changed, 1758 insertions(+), 995 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 7fdb4c8c8da..84d88e81d45 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -9,12 +9,28 @@ #define GGML_WEBGPU_F16_SIZE_BYTES 2 #define GGML_WEBGPU_F32_SIZE_BYTES 4 +#define GGML_WEBGPU_I32_SIZE_BYTES 4 #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u // Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing. #define GGML_WEBGPU_KV_SEQ_PAD 256u -struct ggml_webgpu_flash_attn_shader_lib_context { +#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u + +struct ggml_webgpu_processed_shader { + std::string wgsl; + std::string variant; + void * decisions; +}; + +// Same hash combine function as in boost +template inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) { + seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +/** FlashAttention */ + +struct ggml_webgpu_flash_attn_pipeline_key { ggml_type kv_type; uint32_t head_dim_qk; uint32_t head_dim_v; @@ -22,11 +38,35 @@ struct ggml_webgpu_flash_attn_shader_lib_context { bool has_mask; bool has_sinks; bool uses_logit_softcap; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; - size_t wg_mem_limit_bytes; - uint32_t max_subgroup_size; + + bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { + return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && + kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks && + uses_logit_softcap == other.uses_logit_softcap; + } +}; + +struct ggml_webgpu_flash_attn_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.kv_type); + ggml_webgpu_hash_combine(seed, key.head_dim_qk); + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.kv_direct); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sinks); + ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); + return seed; + } +}; + +struct ggml_webgpu_flash_attn_shader_lib_context { + ggml_webgpu_flash_attn_pipeline_key key; + uint32_t sg_mat_m; + uint32_t sg_mat_n; + uint32_t sg_mat_k; + size_t wg_mem_limit_bytes; + uint32_t max_subgroup_size; }; struct ggml_webgpu_flash_attn_shader_decisions { @@ -35,12 +75,6 @@ struct ggml_webgpu_flash_attn_shader_decisions { uint32_t wg_size = 0; }; -struct ggml_webgpu_processed_shader { - std::string wgsl; - std::string variant; - ggml_webgpu_flash_attn_shader_decisions decisions; -}; - // This is exposed because it's necessary in supports_op inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t kv_tile, @@ -66,15 +100,16 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, } static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t q_tile = context.sg_mat_m; - const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + const size_t limit_bytes = context.wg_mem_limit_bytes; + const size_t q_tile = context.sg_mat_m; + const size_t base_q_bytes = + (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; size_t bytes_per_kv = 0; - if (!context.kv_direct) { - bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v); + if (!context.key.kv_direct) { + bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); } - if (context.has_mask) { + if (context.key.has_mask) { bytes_per_kv += q_tile; } bytes_per_kv += q_tile; @@ -90,7 +125,7 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( std::vector defines; std::string variant = "flash_attn"; - switch (context.kv_type) { + switch (context.key.kv_type) { case GGML_TYPE_F32: defines.push_back("KV_F32"); break; @@ -106,32 +141,31 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( default: GGML_ABORT("Unsupported KV type for flash attention shader"); } - variant += std::string("_") + ggml_type_name(context.kv_type); + variant += std::string("_") + ggml_type_name(context.key.kv_type); - if (context.has_mask) { + if (context.key.has_mask) { defines.push_back("MASK"); variant += "_mask"; } - if (context.has_sinks) { + if (context.key.has_sinks) { defines.push_back("SINKS"); variant += "_sinks"; } - if (context.uses_logit_softcap) { + if (context.key.uses_logit_softcap) { defines.push_back("LOGIT_SOFTCAP"); variant += "_lgsc"; } - if (context.kv_direct) { + if (context.key.kv_direct) { defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(context.head_dim_qk); - - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.head_dim_v); + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk); + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); // For now these are not part of the variant name defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); @@ -141,7 +175,7 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( uint32_t q_tile = context.sg_mat_m; uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (context.kv_direct) { + if (context.key.kv_direct) { GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); // Avoids having to use bounds-checks and decreasing performance for direct KV loads while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { @@ -158,11 +192,276 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - result.decisions.q_tile = q_tile; - result.decisions.kv_tile = kv_tile; - result.decisions.wg_size = wg_size; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + ggml_webgpu_flash_attn_shader_decisions * decisions = new ggml_webgpu_flash_attn_shader_decisions(); + decisions->q_tile = q_tile; + decisions->kv_tile = kv_tile; + decisions->wg_size = wg_size; + result.decisions = decisions; + return result; +} + +/** Generic **/ + +struct ggml_webgpu_generic_shader_lib_context { + int vec4; + uint32_t max_wg_size; +}; + +struct ggml_webgpu_generic_shader_decisions { + uint32_t wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_generic_shader_lib_context & context, + const std::string & base_variant) { + std::vector defines; + std::string variant = base_variant; + + if (context.vec4) { + defines.push_back("VEC4"); + variant += "_vec"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + return result; +} + +/** Pad **/ + +struct ggml_webgpu_pad_pipeline_key { + bool circular; + + bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; } +}; + +struct ggml_webgpu_pad_pipeline_key_hash { + size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.circular); + return seed; + } +}; + +struct ggml_webgpu_pad_shader_lib_context { + ggml_webgpu_pad_pipeline_key key; + uint32_t max_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_pad_shader_lib_context & context) { + std::vector defines; + std::string variant = "pad"; + + if (context.key.circular) { + defines.push_back("CIRCULAR"); + variant += "_circular"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; + return result; +} + +/** Argsort **/ + +struct ggml_webgpu_argsort_shader_lib_context { + uint32_t max_wg_size; + size_t wg_mem_limit_bytes; + int32_t order; +}; + +struct ggml_webgpu_argsort_shader_decisions { + uint32_t wg_size = 0; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_argsort_shader_lib_context & context) { + std::vector defines; + std::string variant = "argsort"; + defines.push_back(std::string("ORDER=") + std::to_string(context.order)); + variant += std::string("_order") + std::to_string(context.order); + uint32_t wg_size = 1; + while (wg_size * 2 <= context.max_wg_size && + wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) { + wg_size *= 2; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions(); + decisions->wg_size = wg_size; + result.decisions = decisions; + return result; +} + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_argsort_shader_lib_context & context) { + std::vector defines; + std::string variant = "argsort_merge"; + defines.push_back(std::string("ORDER=") + std::to_string(context.order)); + variant += std::string("_order") + std::to_string(context.order); + uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions(); + decisions->wg_size = wg_size; + result.decisions = decisions; + return result; +} + +/** Set Rows **/ + +struct ggml_webgpu_set_rows_pipeline_key { + int dst_type; + int vec4; + int i64_idx; + + bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const { + return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx; + } +}; + +struct ggml_webgpu_set_rows_pipeline_key_hash { + size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.dst_type); + ggml_webgpu_hash_combine(seed, key.vec4); + ggml_webgpu_hash_combine(seed, key.i64_idx); + return seed; + } +}; + +struct ggml_webgpu_set_rows_shader_lib_context { + ggml_webgpu_set_rows_pipeline_key key; + uint32_t max_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_set_rows_shader_lib_context & context) { + std::vector defines; + std::string variant = "set_rows"; + + switch (context.key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + variant += "_dstf32"; + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + variant += "_dstf16"; + break; + default: + GGML_ABORT("Unsupported dst type for set_rows shader"); + } + + if (context.key.vec4) { + defines.push_back("VEC4"); + variant += "_vec"; + } + if (context.key.i64_idx) { + defines.push_back("I64_IDX"); + variant += "_i64idx"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; + return result; +} + +struct ggml_webgpu_unary_pipeline_key { + int type; + int op; + bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella + bool inplace; + + bool operator==(const ggml_webgpu_unary_pipeline_key & other) const { + return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace; + } +}; + +struct ggml_webgpu_unary_pipeline_key_hash { + size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.is_unary); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + +struct ggml_webgpu_unary_shader_lib_context { + ggml_webgpu_unary_pipeline_key key; + uint32_t max_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_unary_shader_lib_context & context) { + std::vector defines; + std::string variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) : + ggml_op_name((ggml_op) context.key.op); + // Operation-specific behavior + defines.push_back(variant); + + switch (context.key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for unary shader"); + } + + if (context.key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; return result; } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5b8f7f72d57..1470378af00 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -259,7 +259,7 @@ struct webgpu_pipeline { struct webgpu_command { wgpu::CommandBuffer commands; - webgpu_pool_bufs params_bufs; + std::vector params_bufs; std::optional set_rows_error_bufs; #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs timestamp_query_bufs; @@ -267,46 +267,6 @@ struct webgpu_command { #endif }; -struct flash_attn_pipeline_key { - int q_type; - int kv_type; - int dst_type; - uint32_t head_dim_qk; - uint32_t head_dim_v; - bool kv_direct; - bool has_mask; - bool has_sinks; - bool uses_logit_softcap; - - bool operator==(const flash_attn_pipeline_key & other) const { - return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && - head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && - has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap; - } -}; - -// Same hash combine function as in boost -template inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) { - seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); -} - -struct flash_attn_pipeline_key_hash { - size_t operator()(const flash_attn_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.q_type); - ggml_webgpu_hash_combine(seed, key.kv_type); - ggml_webgpu_hash_combine(seed, key.dst_type); - ggml_webgpu_hash_combine(seed, key.head_dim_qk); - ggml_webgpu_hash_combine(seed, key.head_dim_v); - ggml_webgpu_hash_combine(seed, key.kv_direct); - ggml_webgpu_hash_combine(seed, key.has_mask); - ggml_webgpu_hash_combine(seed, key.has_sinks); - ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); - return seed; - } -}; - // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { wgpu::Instance instance; @@ -336,9 +296,17 @@ struct webgpu_context_struct { std::map>> mul_mat_vec_pipelines; // src0_type, src1_type, vectorized - std::unordered_map flash_attn_pipelines; + std::unordered_map + flash_attn_pipelines; - std::map> set_rows_pipelines; // dst_type, vectorized + std::unordered_map argmax_pipelines; // key is vec4 + std::unordered_map argsort_pipelines; // key is order (asc/desc) + std::unordered_map argsort_merge_pipelines; // key is order (asc/desc) + std::unordered_map cumsum_pipelines; // key is fixed, no variants yet + std::unordered_map sum_rows_pipelines; // key is fixed, no variants yet + + std::unordered_map + set_rows_pipelines; std::map> get_rows_pipelines; // src_type, vectorized std::map> cpy_pipelines; // src_type, dst_type @@ -352,7 +320,9 @@ struct webgpu_context_struct { std::map>> glu_pipelines; // glu_op, type, split std::map scale_pipelines; // inplace std::map>> soft_max_pipelines; // mask_type, has_sink, inplace - std::map>> unary_pipelines; // unary_op, type, inplace + std::unordered_map + unary_pipelines; + std::unordered_map pad_pipelines; size_t memset_bytes_per_thread; @@ -547,7 +517,7 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, for (const auto & command : commands) { command_buffers.push_back(command.commands); - params_bufs.push_back(command.params_bufs); + params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end()); if (command.set_rows_error_bufs) { set_rows_error_bufs.push_back(command.set_rows_error_bufs.value()); } @@ -563,7 +533,7 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); } // Free the staged buffers - ctx->param_buf_pool.free_bufs({ params_bufs }); + ctx->param_buf_pool.free_bufs(params_bufs); }); futures.push_back({ p_f }); @@ -610,42 +580,60 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, return { futures }; } -static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx, - webgpu_pipeline & pipeline, - std::vector params, - std::vector bind_group_entries, - uint32_t wg_x, - uint32_t wg_y = 1, - std::optional set_rows_error_bufs = std::nullopt) { - webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); - - ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); - uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange(); - for (size_t i = 0; i < params.size(); i++) { - _params[i] = params[i]; - }; +static webgpu_command ggml_backend_webgpu_build_multi( + webgpu_context & ctx, + const std::vector & pipelines, + const std::vector> & params_list, + const std::vector> & bind_group_entries_list, + const std::vector> & workgroups_list, + const std::optional & set_rows_error_bufs = std::nullopt) { + GGML_ASSERT(pipelines.size() == params_list.size()); + GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); + GGML_ASSERT(pipelines.size() == workgroups_list.size()); + + std::vector params_bufs_list; + std::vector bind_groups; + + for (size_t i = 0; i < pipelines.size(); i++) { + webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); + + ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, + params_bufs.host_buf.GetSize()); + uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange(); + for (size_t j = 0; j < params_list[i].size(); j++) { + _params[j] = params_list[i][j]; + } + params_bufs.host_buf.Unmap(); - params_bufs.host_buf.Unmap(); + std::vector entries = bind_group_entries_list[i]; + uint32_t params_binding_num = entries.size(); + entries.push_back({ .binding = params_binding_num, + .buffer = params_bufs.dev_buf, + .offset = 0, + .size = params_bufs.dev_buf.GetSize() }); - uint32_t params_bufs_binding_num = bind_group_entries.size(); - bind_group_entries.push_back({ .binding = params_bufs_binding_num, - .buffer = params_bufs.dev_buf, - .offset = 0, - .size = params_bufs.dev_buf.GetSize() }); + wgpu::BindGroupDescriptor bind_group_desc; + bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = entries.size(); + bind_group_desc.entries = entries.data(); + bind_group_desc.label = pipelines[i].name.c_str(); + bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc)); - wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = pipeline.pipeline.GetBindGroupLayout(0); - bind_group_desc.entryCount = bind_group_entries.size(); - bind_group_desc.entries = bind_group_entries.data(); - bind_group_desc.label = pipeline.name.c_str(); - wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); + params_bufs_list.push_back(params_bufs); + } wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); + for (const auto & params_bufs : params_bufs_list) { + encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); + } + + // If there are SET_ROWS operations in this submission, copy their error buffers to the host. + if (set_rows_error_bufs) { + encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, + set_rows_error_bufs->host_buf.GetSize()); + } #ifdef GGML_WEBGPU_GPU_PROFILE - // --- Profiling: GPU timestamp queries --- - // Allocate a timestamp query buffer (2 timestamps: start/end) webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { ts_bufs.host_buf.Unmap(); @@ -659,35 +647,45 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_context & #else wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); #endif - pass.SetPipeline(pipeline.pipeline); - pass.SetBindGroup(0, bind_group); - pass.DispatchWorkgroups(wg_x, wg_y, 1); + for (size_t i = 0; i < pipelines.size(); i++) { + pass.SetPipeline(pipelines[i].pipeline); + pass.SetBindGroup(0, bind_groups[i]); + pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1); + } pass.End(); #ifdef GGML_WEBGPU_GPU_PROFILE - // Resolve the query set into the device buffer encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); #endif - // If there are SET_ROWS operations in this submission, copy their error buffers to the host. - if (set_rows_error_bufs) { - encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, - set_rows_error_bufs->host_buf.GetSize()); - } - wgpu::CommandBuffer commands = encoder.Finish(); webgpu_command result = {}; result.commands = commands; - result.params_bufs = params_bufs; + result.params_bufs = params_bufs_list; result.set_rows_error_bufs = set_rows_error_bufs; #ifdef GGML_WEBGPU_GPU_PROFILE result.timestamp_query_bufs = ts_bufs; - result.pipeline_name = pipeline.name; + // TODO: handle multiple pipeline names + result.pipeline_name = pipelines.front().name; #endif return result; } +static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx, + webgpu_pipeline & pipeline, + std::vector params, + std::vector bind_group_entries, + uint32_t wg_x, + uint32_t wg_y = 1, + std::optional set_rows_error_bufs = std::nullopt) { + return ggml_backend_webgpu_build_multi(ctx, + { + pipeline + }, + { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs); +} + static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, wgpu::Buffer & buf, uint32_t value, @@ -823,6 +821,79 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g return ggml_backend_webgpu_build(ctx, ctx->cpy_pipelines[src->type][dst->type], params, entries, wg_x); } +static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + const bool circular = ggml_get_op_params_i32(dst, 8) != 0; + + ggml_webgpu_pad_pipeline_key pipeline_key = { .circular = circular }; + ggml_webgpu_pad_shader_lib_context shader_lib_ctx = { .key = pipeline_key, + .max_wg_size = + ctx->limits.maxComputeInvocationsPerWorkgroup }; + + webgpu_pipeline pipeline; + { + // TODO: remove guard once pipeline caches are per-thread + std::lock_guard lock(ctx->mutex); + auto it = ctx->pad_pipelines.find(pipeline_key); + if (it != ctx->pad_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx); + pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->pad_pipelines.emplace(pipeline_key, pipeline); + } + } + + ggml_webgpu_generic_shader_decisions decisions = + *static_cast(pipeline.context); + + const uint32_t ne = (uint32_t) ggml_nelements(dst); + + std::vector params = { + ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + // Strides (in elements) + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + // Shapes + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3], + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + // Pad sizes + (uint32_t) ggml_get_op_params_i32(dst, 0), + (uint32_t) ggml_get_op_params_i32(dst, 1), + (uint32_t) ggml_get_op_params_i32(dst, 2), + (uint32_t) ggml_get_op_params_i32(dst, 3), + (uint32_t) ggml_get_op_params_i32(dst, 4), + (uint32_t) ggml_get_op_params_i32(dst, 5), + (uint32_t) ggml_get_op_params_i32(dst, 6), + (uint32_t) ggml_get_op_params_i32(dst, 7), + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, @@ -832,9 +903,39 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, return std::nullopt; } - webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs(); - if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { - error_bufs.host_buf.Unmap(); + ggml_webgpu_set_rows_pipeline_key key = { .dst_type = dst->type, + .vec4 = src->ne[0] % 4 == 0, + .i64_idx = idx->type == GGML_TYPE_I64 }; + + ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = { .key = key, + .max_wg_size = + ctx->limits.maxComputeInvocationsPerWorkgroup }; + + webgpu_pipeline pipeline; + // TODO: remove guard once pipeline caches are per-thread + { + std::lock_guard lock(ctx->mutex); + auto it = ctx->set_rows_pipelines.find(key); + if (it != ctx->set_rows_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx); + pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->set_rows_pipelines.emplace(key, pipeline); + } + } + + ggml_webgpu_generic_shader_decisions decisions = + *static_cast(pipeline.context); + + std::optional error_bufs = std::nullopt; + if (key.i64_idx) { + error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs(); + if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { + error_bufs->host_buf.Unmap(); + } } std::vector params = { @@ -865,21 +966,21 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() } + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - int vectorized = src->ne[0] % 4 == 0; - webgpu_pipeline pipeline = ctx->set_rows_pipelines[0][vectorized]; - uint32_t threads; - if (vectorized) { + if (key.i64_idx) { + entries.push_back( + { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() }); + } + + uint32_t threads; + if (key.vec4) { threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); } else { threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } - - uint32_t wg_x = CEIL_DIV(threads, WEBGPU_MAX_WG_SIZE); - + uint32_t wg_x = CEIL_DIV(threads, decisions.wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs); } @@ -1112,10 +1213,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - flash_attn_pipeline_key key = { - .q_type = Q->type, + ggml_webgpu_flash_attn_pipeline_key key = { .kv_type = K->type, - .dst_type = dst->type, .head_dim_qk = (uint32_t) Q->ne[0], .head_dim_v = (uint32_t) V->ne[0], .kv_direct = kv_direct, @@ -1125,29 +1224,17 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, }; webgpu_pipeline pipeline; - ggml_webgpu_flash_attn_shader_decisions decisions = {}; - - auto it = ctx->flash_attn_pipelines.find(key); - if (it != ctx->flash_attn_pipelines.end()) { - pipeline = it->second; - decisions = *static_cast(pipeline.context); - } else { + // TODO: remove guard once pipeline caches are per-thread + { std::lock_guard lock(ctx->mutex); - it = ctx->flash_attn_pipelines.find(key); + auto it = ctx->flash_attn_pipelines.find(key); if (it != ctx->flash_attn_pipelines.end()) { pipeline = it->second; - decisions = *static_cast(pipeline.context); } else { - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type = K->type, - .head_dim_qk = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .kv_direct = kv_direct, - .has_mask = static_cast(has_mask), - .has_sinks = static_cast(has_sinks), - .uses_logit_softcap = logit_softcap != 0.0f, - .sg_mat_m = ctx->sg_mat_m, - .sg_mat_n = ctx->sg_mat_n, - .sg_mat_k = ctx->sg_mat_k, + ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .key = key, + .sg_mat_m = ctx->sg_mat_m, + .sg_mat_n = ctx->sg_mat_n, + .sg_mat_k = ctx->sg_mat_k, .wg_mem_limit_bytes = ctx->limits.maxComputeWorkgroupStorageSize, .max_subgroup_size = ctx->max_subgroup_size }; @@ -1155,59 +1242,101 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions); + pipeline.context = processed.decisions; ctx->flash_attn_pipelines.emplace(key, pipeline); - decisions = processed.decisions; } } + ggml_webgpu_flash_attn_shader_decisions decisions = + *static_cast(pipeline.context); + + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - uint32_t ne = (uint32_t) ggml_nelements(dst); - ggml_unary_op unary_op = ggml_get_unary_op(dst); - uint32_t inplace = ggml_webgpu_tensor_equal(src, dst); + bool is_unary = dst->op == GGML_OP_UNARY; + bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); + int op = is_unary ? (int) ggml_get_unary_op(dst) : dst->op; - std::vector params = { - ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - // Convert byte-strides to element-strides - (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)), - (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), - (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Logical shapes - (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0], - (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] + ggml_webgpu_unary_pipeline_key pipeline_key = { + .type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace }; + ggml_webgpu_unary_shader_lib_context shader_lib_ctx = { .key = pipeline_key, + .max_wg_size = + ctx->limits.maxComputeInvocationsPerWorkgroup }; - switch (unary_op) { - case GGML_UNARY_OP_XIELU: - { - // Get float parameters and reinterpret their bit patterns as uint32_t - // for passing through the params buffer - float alpha_n = ggml_get_op_params_f32(dst, 1); - float alpha_p = ggml_get_op_params_f32(dst, 2); - float beta = ggml_get_op_params_f32(dst, 3); - float eps = ggml_get_op_params_f32(dst, 4); - params.push_back(*reinterpret_cast(&alpha_n)); - params.push_back(*reinterpret_cast(&alpha_p)); - params.push_back(*reinterpret_cast(&beta)); - params.push_back(*reinterpret_cast(&eps)); + webgpu_pipeline pipeline; + { + // TODO: remove guard once pipeline caches are per-thread + std::lock_guard lock(ctx->mutex); + auto it = ctx->unary_pipelines.find(pipeline_key); + if (it != ctx->unary_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx); + pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->unary_pipelines.emplace(pipeline_key, pipeline); + } + } + + ggml_webgpu_generic_shader_decisions decisions = + *static_cast(pipeline.context); + + uint32_t ne = (uint32_t) ggml_nelements(dst); + + std::vector params = { ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2] }; + + ggml_tensor * effective_src = src; + if (is_unary) { + ggml_unary_op unary_op = ggml_get_unary_op(dst); + switch (unary_op) { + case GGML_UNARY_OP_XIELU: + { + // Get float parameters and reinterpret their bit patterns as uint32_t + // for passing through the params buffer + float alpha_n = ggml_get_op_params_f32(dst, 1); + float alpha_p = ggml_get_op_params_f32(dst, 2); + float beta = ggml_get_op_params_f32(dst, 3); + float eps = ggml_get_op_params_f32(dst, 4); + params.push_back(*reinterpret_cast(&alpha_n)); + params.push_back(*reinterpret_cast(&alpha_p)); + params.push_back(*reinterpret_cast(&beta)); + params.push_back(*reinterpret_cast(&eps)); + break; + } + default: break; - } - default: - break; + } + } else if (dst->op == GGML_OP_CLAMP) { + float clamp_min = ggml_get_op_params_f32(dst, 0); + float clamp_max = ggml_get_op_params_f32(dst, 1); + params.push_back(*reinterpret_cast(&clamp_min)); + params.push_back(*reinterpret_cast(&clamp_max)); + } else if (dst->op == GGML_OP_FILL) { + float fill_val = ggml_get_op_params_f32(dst, 0); + params.push_back(*reinterpret_cast(&fill_val)); + effective_src = dst; // fill simply fills dst } std::vector entries = { { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, + .buffer = ggml_webgpu_tensor_buf(effective_src), + .offset = ggml_webgpu_tensor_align_offset(ctx, effective_src), + .size = ggml_webgpu_tensor_binding_size(ctx, effective_src) }, }; if (!inplace) { entries.push_back({ .binding = 1, @@ -1216,8 +1345,8 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx, ctx->unary_pipelines[unary_op][dst->type][inplace], params, entries, wg_x); + uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, @@ -1549,6 +1678,305 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, ggml_nrows(dst)); } +static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) src->ne[0] }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { + .vec4 = src->ne[0] % 4 == 0, + .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline; + { + // TODO: remove guard once pipeline caches are per-thread + std::lock_guard lock(ctx->mutex); + auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4); + if (it != ctx->argmax_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax"); + pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline); + } + } + uint32_t wg_x = ggml_nelements(dst); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + +static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + bool is_top_k = dst->op == GGML_OP_TOP_K; + // ascending order is 0, descending order is 1 + const int32_t order = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0); + + ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = { .max_wg_size = + ctx->limits.maxComputeInvocationsPerWorkgroup, + .wg_mem_limit_bytes = + ctx->limits.maxComputeWorkgroupStorageSize, + .order = order }; + + std::lock_guard lock(ctx->mutex); + webgpu_pipeline argsort_pipeline; + auto it = ctx->argsort_pipelines.find(order); + if (it != ctx->argsort_pipelines.end()) { + argsort_pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx); + argsort_pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + argsort_pipeline.context = processed.decisions; + ctx->argsort_pipelines.emplace(order, argsort_pipeline); + } + ggml_webgpu_argsort_shader_decisions argsort_decisions = + *static_cast(argsort_pipeline.context); + + webgpu_pipeline argsort_merge_pipeline; + it = ctx->argsort_merge_pipelines.find(order); + if (it != ctx->argsort_merge_pipelines.end()) { + argsort_merge_pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx); + argsort_merge_pipeline = + ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + argsort_merge_pipeline.context = processed.decisions; + ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline); + } + + const uint32_t src_ne0 = (uint32_t) src->ne[0]; + const uint32_t nrows = (uint32_t) ggml_nrows(src); + const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions.wg_size); + const uint32_t block_size = + is_top_k ? std::min(argsort_decisions.wg_size, (uint32_t) dst->ne[0]) : argsort_decisions.wg_size; + uint32_t out_ne0 = src_ne0; + if (is_top_k) { + if (npr > 1) { + const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions.wg_size; + out_ne0 = (npr - 1) * block_size + std::min(last_tile, block_size); + } else { + out_ne0 = block_size; + } + } + + uint32_t merge_len = block_size; + uint32_t merge_passes = 0; + while (merge_len < out_ne0) { + merge_len <<= 1; + merge_passes++; + } + + const bool start_in_tmp = (merge_passes % 2) == 1; + + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t); + const size_t tmp_offset = ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->limits.minStorageBufferOffsetAlignment); + const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT); + const size_t dst_binding_size = + ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT); + + const uint32_t offset_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)); + const uint32_t offset_dst = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)); + const uint32_t offset_tmp = 0; + const uint32_t stride_src1 = (uint32_t) (src->nb[1] / ggml_type_size(src->type)); + const uint32_t stride_src2 = (uint32_t) (src->nb[2] / ggml_type_size(src->type)); + const uint32_t stride_src3 = (uint32_t) (src->nb[3] / ggml_type_size(src->type)); + const uint32_t stride_idx1 = out_ne0; + const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1]; + const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2]; + + std::vector pipelines; + std::vector> params_list; + std::vector> entries_list; + std::vector> workgroups_list; + + const uint32_t init_offset = start_in_tmp ? offset_tmp : offset_dst; + const size_t init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst); + const size_t init_binding_size = start_in_tmp ? tmp_binding_size : dst_binding_size; + + std::vector init_params = { + offset_src, init_offset, stride_src1, stride_src2, stride_src3, stride_idx1, + stride_idx2, stride_idx3, src_ne0, (uint32_t) src->ne[1], (uint32_t) src->ne[2], out_ne0, + block_size, npr, nrows + }; + + const uint32_t total_wg_init = npr * nrows; + const uint32_t max_wg = ctx->limits.maxComputeWorkgroupsPerDimension; + const uint32_t wg_x_init = std::min(total_wg_init, max_wg); + const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); + std::vector init_entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, + { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size } + }; + + pipelines.push_back(argsort_pipeline); + params_list.push_back(std::move(init_params)); + entries_list.push_back(std::move(init_entries)); + workgroups_list.push_back({ wg_x_init, wg_y_init }); + + if (merge_passes == 0) { + return ggml_backend_webgpu_build_multi(ctx, pipelines, params_list, entries_list, workgroups_list); + } + + bool in_is_tmp = start_in_tmp; + uint32_t len = block_size; + while (len < out_ne0) { + const uint32_t nm = CEIL_DIV(out_ne0, 2 * len); + + const bool out_is_tmp = !in_is_tmp; + const uint32_t offset_in = in_is_tmp ? offset_tmp : offset_dst; + const uint32_t offset_out = out_is_tmp ? offset_tmp : offset_dst; + const size_t align_in = in_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst); + const size_t align_out = out_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst); + const size_t size_in = in_is_tmp ? tmp_binding_size : dst_binding_size; + const size_t size_out = out_is_tmp ? tmp_binding_size : dst_binding_size; + const uint32_t top_k_out = (is_top_k && nm == 1) ? (uint32_t) dst->ne[0] : out_ne0; + const uint32_t stride_out1 = top_k_out; + const uint32_t stride_out2 = top_k_out * (uint32_t) dst->ne[1]; + const uint32_t stride_out3 = stride_out2 * (uint32_t) dst->ne[2]; + + std::vector merge_params = { offset_src, + offset_in, + offset_out, + stride_src1, + stride_src2, + stride_src3, + stride_idx1, + stride_idx2, + stride_idx3, + stride_out1, + stride_out2, + stride_out3, + out_ne0, + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + top_k_out, + len, + nm, + nrows }; + + std::vector merge_entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, + { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in }, + { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out } + }; + + const uint32_t total_wg_merge = nm * nrows; + const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg); + const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge); + workgroups_list.push_back({ wg_x_merge, wg_y_merge }); + pipelines.push_back(argsort_merge_pipeline); + params_list.push_back(std::move(merge_params)); + entries_list.push_back(std::move(merge_entries)); + + len <<= 1; + in_is_tmp = !in_is_tmp; + } + + return ggml_backend_webgpu_build_multi(ctx, pipelines, params_list, entries_list, workgroups_list); +} + +static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) src->ne[0] }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { + .vec4 = false, + .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup, + }; + webgpu_pipeline pipeline; + // TODO: remove guard once pipeline caches are per-thread + { + std::lock_guard lock(ctx->mutex); + auto it = ctx->cumsum_pipelines.find(1); + if (it != ctx->cumsum_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum"); + pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ctx->cumsum_pipelines.emplace(1, pipeline); + } + } + uint32_t wg_x = ggml_nrows(dst); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + +static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + bool total_sum = dst->op == GGML_OP_SUM; + std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + total_sum ? 0 : (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + total_sum ? 0 : (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + total_sum ? 0 : (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + total_sum ? static_cast(ggml_nelements(src)) : (uint32_t) src->ne[0], + total_sum ? 1 : (uint32_t) src->ne[1], + total_sum ? 1 : (uint32_t) src->ne[2] }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { + .vec4 = false, + .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline; + { + // TODO: remove guard once pipeline caches are per-thread + std::lock_guard lock(ctx->mutex); + auto it = ctx->sum_rows_pipelines.find(1); + if (it != ctx->sum_rows_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows"); + pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ctx->sum_rows_pipelines.emplace(1, pipeline); + } + } + uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + // Returns the encoded command, or std::nullopt if the operation is a no-op static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { @@ -1611,6 +2039,26 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); case GGML_OP_UNARY: return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_CLAMP: + return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_FILL: + return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_LOG: + return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_PAD: + return ggml_webgpu_pad(ctx, src0, node); + case GGML_OP_ARGMAX: + return ggml_webgpu_argmax(ctx, src0, node); + case GGML_OP_ARGSORT: + return ggml_webgpu_argsort(ctx, src0, node); + case GGML_OP_TOP_K: + // we reuse the same argsort implementation for top_k + return ggml_webgpu_argsort(ctx, src0, node); + case GGML_OP_CUMSUM: + return ggml_webgpu_cumsum(ctx, src0, node); + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + return ggml_webgpu_sum_rows(ctx, src0, node); default: return std::nullopt; } @@ -1865,6 +2313,31 @@ static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_t return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize; } +static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, + const ggml_tensor * tensor) { + ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); + size_t res = ggml_nbytes(tensor); + switch (tensor->op) { + case GGML_OP_ARGSORT: + res = ROUNDUP_POW2(res * 2 + ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment, + WEBGPU_STORAGE_BUF_BINDING_MULT); + break; + case GGML_OP_TOP_K: + { + const ggml_tensor * src0 = tensor->src[0]; + if (src0) { + const size_t full = sizeof(int32_t) * ggml_nelements(src0); + res = ROUNDUP_POW2(full * 2 + ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment, + WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + break; + default: + break; + } + return res; +} + /* End GGML Backend Buffer Type Interface */ /* GGML Backend Device Interface */ @@ -1883,7 +2356,7 @@ static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); // TODO: for now, return maxBufferSize as both free and total memory // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates. - uint64_t max_buffer_size = ctx->webgpu_ctx->limits.maxBufferSize; + uint64_t max_buffer_size = ctx->webgpu_ctx->limits.maxBufferSize; // If we're on a 32-bit system, clamp to UINTPTR_MAX #if UINTPTR_MAX < UINT64_MAX uint64_t max_ptr_size = static_cast(UINTPTR_MAX); @@ -2086,13 +2559,6 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); } -static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { - webgpu_ctx->set_rows_pipelines[0][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_set_rows_f16, "set_rows_f16", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE)); - webgpu_ctx->set_rows_pipelines[0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_set_rows_f16_vec, "set_rows_f16_vec", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE)); -} - static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); @@ -2152,6 +2618,8 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants); + webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants); webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] = @@ -2303,180 +2771,6 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); } -static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - // ABS - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f32, "abs_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f16, "abs_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f32, "abs_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f16, "abs_inplace_f16", constants); - - // SGN - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f32, "sgn_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f16, "sgn_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f32, "sgn_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f16, "sgn_inplace_f16", constants); - - // NEG - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f32, "neg_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f16, "neg_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f32, "neg_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f16, "neg_inplace_f16", constants); - - // STEP - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f32, "step_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f16, "step_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f32, "step_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f16, "step_inplace_f16", constants); - - // TANH - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f32, "tanh_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f16, "tanh_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f32, "tanh_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f16, "tanh_inplace_f16", constants); - - // ELU - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f32, "elu_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f16, "elu_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f32, "elu_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f16, "elu_inplace_f16", constants); - - // RELU - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f32, "relu_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f16, "relu_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f32, "relu_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f16, "relu_inplace_f16", constants); - - // SIGMOID - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f32, "sigmoid_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f16, "sigmoid_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f32, "sigmoid_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f16, "sigmoid_inplace_f16", constants); - - // GELU - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f32, "gelu_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f16, "gelu_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f32, "gelu_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f16, "gelu_inplace_f16", constants); - - // GELU_QUICK - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f32, "gelu_quick_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f16, "gelu_quick_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_gelu_quick_inplace_f32, "gelu_quick_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_gelu_quick_inplace_f16, "gelu_quick_inplace_f16", constants); - - // SILU - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f32, "silu_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f16, "silu_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f32, "silu_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f16, "silu_inplace_f16", constants); - - // HARDSWISH - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f32, "hardswish_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f16, "hardswish_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f32, "hardswish_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f16, "hardswish_inplace_f16", constants); - - // HARDSIGMOID - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_hardsigmoid_inplace_f32, "hardsigmoid_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_hardsigmoid_inplace_f16, "hardsigmoid_inplace_f16", constants); - - // EXP - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f32, "exp_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f16, "exp_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f32, "exp_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f16, "exp_inplace_f16", constants); - - // GELU_ERF - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f32, "gelu_erf_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f16, "gelu_erf_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f32, "gelu_erf_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f16, "gelu_erf_inplace_f16", constants); - - // XIELU - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f32, "xielu_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f16, "xielu_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f32, "xielu_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f16, "xielu_inplace_f16", constants); - - // CEIL - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f32, "ceil_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f16, "ceil_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f32, "ceil_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f16, "ceil_inplace_f16", constants); -} - static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); @@ -2552,8 +2846,8 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ NULL, // defaults to false + /* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size, + /* .is_host = */ NULL, // defaults to false }, /* .device = */ dev, @@ -2631,16 +2925,19 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; case GGML_OP_CPY: case GGML_OP_CONT: - supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && - (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) || + (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32); break; case GGML_OP_SET_ROWS: - supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64); + supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 && + (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); break; case GGML_OP_GET_ROWS: - if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 || - ggml_webgpu_supported_qtype(src0->type)) { + if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) { supports_op = (op->type == GGML_TYPE_F32); + } else if (src0->type == GGML_TYPE_I32) { + supports_op = op->type == GGML_TYPE_I32; } break; case GGML_OP_MUL_MAT: @@ -2753,9 +3050,14 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_GELU_ERF: - case GGML_UNARY_OP_XIELU: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_EXPM1: + case GGML_UNARY_OP_FLOOR: case GGML_UNARY_OP_CEIL: - supports_op = supports_op = + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_XIELU: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; default: @@ -2763,7 +3065,34 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } } break; - + case GGML_OP_CLAMP: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_FILL: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_LOG: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_PAD: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_ARGMAX: + supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_ARGSORT: + supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0); + break; + case GGML_OP_TOP_K: + supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0); + break; + case GGML_OP_CUMSUM: + supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type; + break; + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0); + break; default: break; } @@ -2984,7 +3313,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_memset_pipeline(ctx); ggml_webgpu_init_mul_mat_pipeline(ctx); - ggml_webgpu_init_set_rows_pipeline(ctx); ggml_webgpu_init_get_rows_pipeline(ctx); ggml_webgpu_init_cpy_pipeline(ctx); ggml_webgpu_init_add_pipeline(ctx); @@ -2996,7 +3324,6 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_glu_pipeline(ctx); ggml_webgpu_init_scale_pipeline(ctx); ggml_webgpu_init_soft_max_pipeline(ctx); - ggml_webgpu_init_unary_pipeline(ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl new file mode 100644 index 00000000000..ca5bfcc4d4c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl @@ -0,0 +1,72 @@ +@group(0) @binding(0) +#ifdef VEC4 +var src: array>; +#define VEC_SIZE 4 +#else +var src: array; +#define VEC_SIZE 1 +#endif + +@group(0) @binding(1) +var dst: array; + +struct Params { + offset_src: u32, // in elements + offset_dst: u32, // in elements + ne0: u32, +}; + +@group(0) @binding(2) +var params: Params; + +const FLOAT_MIN: f32 = -1.0e9; + +struct Pair { + value: f32, + index: i32 +}; + +var shared_max: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + let row_idx = params.offset_src + wid.x * params.ne0; + var local_pair = Pair(FLOAT_MIN, -1); +#ifdef VEC4 + for (var col = lid.x; col < params.ne0/VEC_SIZE; col += WG_SIZE) { + let vec_val = src[row_idx / VEC_SIZE + col]; + for (var v = 0u; v < VEC_SIZE; v++) { + let val = vec_val[v]; + if (val >= local_pair.value) { + local_pair = Pair(val, i32(col * VEC_SIZE + v)); + } + } + } +#else + for (var col = lid.x; col < params.ne0; col += WG_SIZE) { + if (src[row_idx + col] >= local_pair.value) { + local_pair = Pair(src[row_idx + col], i32(col)); + } + } +#endif + shared_max[lid.x] = local_pair; + workgroupBarrier(); + var offset: u32 = WG_SIZE >> 1; + while (offset > 0) { + if (lid.x < offset) { + let a = shared_max[lid.x]; + let b = shared_max[lid.x + offset]; + if (b.value > a.value) { + shared_max[lid.x] = b; + } else if (b.value == a.value && b.index > a.index) { + shared_max[lid.x] = b; + } + } + workgroupBarrier(); + offset >>= 1; + } + if (lid.x == 0u) { + dst[params.offset_dst + wid.x] = shared_max[0].index; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl new file mode 100644 index 00000000000..46ed19fc775 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl @@ -0,0 +1,106 @@ +@group(0) @binding(0) +var src: array; + +@group(0) @binding(1) +var dst: array; + +struct Params { + offset_src: u32, // in elements + offset_dst: u32, // in elements + + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // src/dst dimensions + src_ne0: u32, + ne1: u32, + ne2: u32, + + ne0: u32, + top_k: u32, + + npr: u32, // tiles per row + nrows: u32 +}; + +@group(0) @binding(2) +var params: Params; + +var shmem_idx: array; + +#if ORDER == 0 +#define EXTREME_VALUE 1e30 +#define SWAP_COMPARE_UP > +#define SWAP_COMPARE_DOWN < +#else +#define EXTREME_VALUE -1e30 +#define SWAP_COMPARE_UP < +#define SWAP_COMPARE_DOWN > +#endif + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(num_workgroups) num_wg: vec3, + @builtin(local_invocation_id) lid: vec3) { + let linear = wid.x + wid.y * num_wg.x; + // guard against overprovisioned workgroups + if (linear >= params.npr * params.nrows) { + return; + } + let tile = linear % params.npr; + var row = linear / params.npr; + let i3 = row / (params.ne2 * params.ne1); + row = row % (params.ne2 * params.ne1); + let i2 = row / params.ne1; + let i1 = row % params.ne1; + + let row_base = params.offset_src + + i1 * params.stride_src1 + + i2 * params.stride_src2 + + i3 * params.stride_src3; + + let tile_base = tile * WG_SIZE; + let idx = tile_base + lid.x; + shmem_idx[lid.x] = select(params.src_ne0, idx, idx < params.src_ne0); + workgroupBarrier(); + + var k = 2u; + while (k <= WG_SIZE) { + var j = k >> 1; + while (j > 0) { + let ixj = lid.x ^ j; + if (ixj > lid.x) { + let dir_up = (lid.x & k) == 0; + let a_idx = shmem_idx[lid.x]; + let b_idx = shmem_idx[ixj]; + let a_val = select(EXTREME_VALUE, src[row_base + a_idx], a_idx < params.src_ne0); + let b_val = select(EXTREME_VALUE, src[row_base + b_idx], b_idx < params.src_ne0); + let should_swap = select( + (a_val SWAP_COMPARE_DOWN b_val), + (a_val SWAP_COMPARE_UP b_val), + dir_up); + if (should_swap) { + shmem_idx[lid.x] = b_idx; + shmem_idx[ixj] = a_idx; + } + } + workgroupBarrier(); + j >>= 1; + } + k <<= 1; + } + + let out_idx = tile * params.top_k + lid.x; + if (out_idx < params.ne0 && lid.x < params.top_k) { + let row_dst = params.offset_dst + + i1 * params.stride_dst1 + + i2 * params.stride_dst2 + + i3 * params.stride_dst3; + dst[row_dst + out_idx] = i32(shmem_idx[lid.x]); + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl new file mode 100644 index 00000000000..9a77f6eca74 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl @@ -0,0 +1,134 @@ +@group(0) @binding(0) +var src: array; + +@group(0) @binding(1) +var idx_in: array; + +@group(0) @binding(2) +var idx_out: array; + +struct Params { + offset_src: u32, // in elements + offset_in: u32, // in elements + offset_out: u32, // in elements + + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_idx1: u32, + stride_idx2: u32, + stride_idx3: u32, + + stride_out1: u32, + stride_out2: u32, + stride_out3: u32, + + ne0: u32, + ne1: u32, + ne2: u32, + + top_k: u32, + + len: u32, + nm: u32, + nrows: u32 +}; + +@group(0) @binding(3) +var params: Params; + +fn take_left(a_idx: i32, b_idx: i32, row_base: u32) -> bool { + let a_val = src[row_base + u32(a_idx)]; + let b_val = src[row_base + u32(b_idx)]; +#if ORDER == 0 + return a_val <= b_val; +#else + return a_val >= b_val; +#endif +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(num_workgroups) num_wg: vec3, + @builtin(local_invocation_id) lid: vec3) { + let linear = wid.x + wid.y * num_wg.x; + // guard against overprovisioned workgroups + if (linear >= params.nm * params.nrows) { + return; + } + + let start = (linear % params.nm) * params.len * 2; + let len0 = min(params.len, params.ne0 - start); + let rem1 = select(0, params.ne0 - (start + params.len), params.ne0 > (start + params.len)); + let len1 = min(params.len, rem1); + let total = len0 + len1; + let chunk = (total + WG_SIZE - 1u) / WG_SIZE; + let k0 = lid.x * chunk; + let k1 = min(min(k0 + chunk, total), params.top_k); + // guard against overprovisioned threads + if (k0 >= params.top_k || k0 >= total) { + return; + } + + var row = linear / params.nm; + let i3 = row / (params.ne2 * params.ne1); + row = row % (params.ne2 * params.ne1); + let i2 = row / params.ne1; + let i1 = row % params.ne1; + + let row_src = params.offset_src + + i1 * params.stride_src1 + + i2 * params.stride_src2 + + i3 * params.stride_src3; + + let row_in = params.offset_in + + i1 * params.stride_idx1 + + i2 * params.stride_idx2 + + i3 * params.stride_idx3; + + let row_out = params.offset_out + + i1 * params.stride_out1 + + i2 * params.stride_out2 + + i3 * params.stride_out3; + + + var low: u32 = select(0, k0 - len1, k0 > len1); + var high: u32 = min(k0, len0); + + while (low < high) { + let mid = (low + high) >> 1; + let idx0 = idx_in[row_in + start + mid]; + let idx1 = idx_in[row_in + start + params.len + (k0 - mid - 1)]; + if (take_left(idx0, idx1, row_src)) { + low = mid + 1; + } else { + high = mid; + } + } + + var i = low; + var j = k0 - i; + var k = k0; + while (k < k1) { + var take_l = false; + if (i >= len0) { + take_l = false; + } else if (j >= len1) { + take_l = true; + } else { + let idx0 = idx_in[row_in + start + i]; + let idx1 = idx_in[row_in + start + params.len + j]; + take_l = take_left(idx0, idx1, row_src); + } + + let out_idx = select( + idx_in[row_in + start + params.len + j], + idx_in[row_in + start + i], + take_l); + idx_out[row_out + start + k] = out_idx; + i = select(i, i + 1, take_l); + j = select(j + 1, j, take_l); + k += 1; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl index db1aa34903b..b5e93b812fd 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl @@ -7,6 +7,12 @@ "DST_TYPE": "f32" } }, + { + "REPLS": { + "SRC_TYPE": "f32", + "DST_TYPE": "i32" + } + }, { "REPLS": { "SRC_TYPE": "f32", diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl new file mode 100644 index 00000000000..e622552c421 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl @@ -0,0 +1,66 @@ +@group(0) @binding(0) +var src: array; + +@group(0) @binding(1) +var dst: array; + +struct Params { + offset_src: u32, // in elements + offset_dst: u32, // in elements + ne0: u32, +}; + +@group(0) @binding(2) +var params: Params; + +var shared_sum: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + let row_idx = params.offset_src + wid.x * params.ne0; + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; + var local_sum: f32 = 0.0; + for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col ++) { + local_sum += src[row_idx + col]; + } + shared_sum[lid.x] = local_sum; + workgroupBarrier(); + + // upsweep + var offset = 1u; + while (offset < WG_SIZE) { + let idx = (lid.x + 1) * offset * 2 - 1; + if (idx < WG_SIZE) { + shared_sum[idx] = shared_sum[idx] + shared_sum[idx - offset]; + } + workgroupBarrier(); + offset <<= 1; + } + + // set last to 0 for exclusive sum + if (lid.x == 0) { + shared_sum[WG_SIZE - 1] = 0.0; + } + workgroupBarrier(); + + // downsweep + offset = WG_SIZE >> 1; + while (offset > 0) { + let idx = (lid.x + 1) * offset * 2 - 1; + if (idx < WG_SIZE) { + let t = shared_sum[idx - offset]; + shared_sum[idx - offset] = shared_sum[idx]; + shared_sum[idx] = shared_sum[idx] + t; + } + workgroupBarrier(); + offset = offset >> 1; + } + + // shared_sum[lid] is exclusive prefix sum up to this thread. + var running_sum = shared_sum[lid.x]; + for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col ++) { + running_sum += src[row_idx + col]; + dst[params.offset_dst + wid.x * params.ne0 + col] = running_sum; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl new file mode 100644 index 00000000000..ea63b9a731c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl @@ -0,0 +1,86 @@ +@group(0) @binding(0) +var src: array; + +@group(0) @binding(1) +var dst: array; + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + src_ne3: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32, + dst_ne3: u32, + + // Pad sizes (in elements) + lp0: u32, + rp0: u32, + lp1: u32, + rp1: u32, + lp2: u32, + rp2: u32, + lp3: u32, + rp3: u32, +}; + +@group(0) @binding(2) +var params: Params; + +fn wrap_around(idx: i32, n: u32) -> u32 { + return u32(idx + i32(n)) % n; +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let dst_plane = params.dst_ne2 * params.dst_ne1 * params.dst_ne0; + let i3 = i / dst_plane; + i = i % dst_plane; + let i2 = i / (params.dst_ne1 * params.dst_ne0); + i = i % (params.dst_ne1 * params.dst_ne0); + let i1 = i / params.dst_ne0; + let i0 = i % params.dst_ne0; + + var value: f32 = 0.0; + +#ifdef CIRCULAR + let ci0 = wrap_around(i32(i0) - i32(params.lp0), params.src_ne0); + let ci1 = wrap_around(i32(i1) - i32(params.lp1), params.src_ne1); + let ci2 = wrap_around(i32(i2) - i32(params.lp2), params.src_ne2); + let ci3 = wrap_around(i32(i3) - i32(params.lp3), params.src_ne3); + let circular_src_idx = ci0 * params.stride_src0 + ci1 * params.stride_src1 + + ci2 * params.stride_src2 + ci3 * params.stride_src3; + value = src[params.offset_src + circular_src_idx]; +#else + let is_src = + (i0 >= params.lp0 && i0 < params.dst_ne0 - params.rp0) && + (i1 >= params.lp1 && i1 < params.dst_ne1 - params.rp1) && + (i2 >= params.lp2 && i2 < params.dst_ne2 - params.rp2) && + (i3 >= params.lp3 && i3 < params.dst_ne3 - params.rp3); + if (is_src) { + let src_idx = (i0 - params.lp0) * params.stride_src0 + (i1 - params.lp1) * params.stride_src1 + + (i2 - params.lp2) * params.stride_src2 + (i3 - params.lp3) * params.stride_src3; + value = src[params.offset_src + src_idx]; + } +#endif + + dst[params.offset_dst + gid.x] = value; +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl deleted file mode 100644 index fca3be6bc27..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +++ /dev/null @@ -1,112 +0,0 @@ -#define(VARIANTS) - -[ - { - "SHADER_SUFFIX": "f16_vec", - "REPLS": { - "TYPE" : "vec4", - "DST_TYPE": "vec4", - "VEC_SIZE": 4 - } - }, - { - "SHADER_SUFFIX": "f16", - "REPLS": { - "TYPE" : "f32", - "DST_TYPE": "f16", - "VEC_SIZE": 1 - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -@group(0) @binding(0) -var src: array<{{TYPE}}>; - -@group(0) @binding(1) -var idx: array; - -@group(0) @binding(2) -var dst: array<{{DST_TYPE}}>; - -@group(0) @binding(3) -var error: atomic; - -struct Params { - offset_src: u32, // in elements - offset_idx: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_idx0: u32, - stride_idx1: u32, - stride_idx2: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Shape of src - ne0: u32, - n_rows: u32, - ne2: u32, - ne3: u32, - - // Shape of idx - idx1: u32, - idx2: u32, -}; - -@group(0) @binding(4) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / {{VEC_SIZE}}) { - return; - } - - // getting the row from gid - let elems_per_row = params.ne0 / {{VEC_SIZE}}; - var i = gid.x / elems_per_row; - - let i_src3 = i / (params.ne2 * params.n_rows); - - i = i % (params.ne2 * params.n_rows); - let i_src2 = i / params.n_rows; - let i_src1 = i % params.n_rows; - - let i_idx2 = i_src3 % params.idx2; - let i_idx1 = i_src2 % params.idx1; - let i_idx0 = i_src1; - - let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2; - - let idx_high_val = idx[idx_high]; - let idx_low_val = idx[idx_high + 1]; - - if (idx_low_val != 0) { - // Upper bits of index are not zero, output will be incorrect - atomicStore(&error, 1); - return; - } - - let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; - let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; - - let col_idx = (gid.x % elems_per_row); - dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]); -} - -#end(SHADER) - diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl index 3567713dc21..99e9192c71a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl @@ -1,16 +1,37 @@ enable f16; +#ifdef DST_F32 +#define DST_INNER_TYPE f32 +#else +#define DST_INNER_TYPE f16 +#endif + +#ifdef VEC4 +#define SRC_TYPE vec4 +#define DST_TYPE vec4 +#define VEC_SIZE 4 +#else +#define SRC_TYPE f32 +#define DST_TYPE DST_INNER_TYPE +#define VEC_SIZE 1 +#endif + @group(0) @binding(0) -var src: array; +var src: array; @group(0) @binding(1) var idx: array; @group(0) @binding(2) -var dst: array; +var dst: array; +#ifdef I64_IDX @group(0) @binding(3) var error: atomic; +#define PARAMS_BINDING 4 +#else +#define PARAMS_BINDING 3 +#endif struct Params { offset_src: u32, // in elements @@ -41,16 +62,19 @@ struct Params { idx2: u32, }; -@group(0) @binding(4) +@group(0) @binding(PARAMS_BINDING) var params: Params; -override wg_size: u32; -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.n_rows * params.ne2 * params.ne3) { + if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / VEC_SIZE) { return; } - var i = gid.x; + + // getting the row from gid + let elems_per_row = params.ne0 / VEC_SIZE; + var i = gid.x / elems_per_row; + let i_src3 = i / (params.ne2 * params.n_rows); i = i % (params.ne2 * params.n_rows); @@ -61,9 +85,10 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i_idx1 = i_src2 % params.idx1; let i_idx0 = i_src1; +#ifdef I64_IDX let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2; - let idx_high_val = idx[idx_high]; + let idx_val = idx[idx_high]; let idx_low_val = idx[idx_high + 1]; if (idx_low_val != 0) { @@ -71,11 +96,14 @@ fn main(@builtin(global_invocation_id) gid: vec3) { atomicStore(&error, 1); return; } +#else + let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + let idx_val = idx[idx_i]; +#endif - let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; + let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; - for (var i: u32 = 0; i < params.ne0; i++) { - dst[i_dst_row + i] = f16(src[i_src_row + i]); - } + let col_idx = (gid.x % elems_per_row); + dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl new file mode 100644 index 00000000000..6ea2de9b7c6 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl @@ -0,0 +1,55 @@ +@group(0) @binding(0) +var src: array; + +@group(0) @binding(1) +var dst: array; + +struct Params { + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + ne0: u32, + ne1: u32, + ne2: u32 +}; + +@group(0) @binding(2) +var params: Params; + +var shared_sum: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + + var i = wid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; + var local_sum: f32 = 0.0; + for (var col = lid.x; col < params.ne0; col += WG_SIZE) { + local_sum += src[i_src_row + col]; + } + shared_sum[lid.x] = local_sum; + workgroupBarrier(); + // reduce within workgroup + var offset: u32 = WG_SIZE >> 1; + while (offset > 0) { + if (lid.x < offset) { + shared_sum[lid.x] = shared_sum[lid.x] + shared_sum[lid.x + offset]; + } + workgroupBarrier(); + offset >>= 1; + } + + if (lid.x == 0) { + dst[params.offset_dst + wid.x] = shared_sum[0]; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl new file mode 100644 index 00000000000..d639d984970 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -0,0 +1,179 @@ +#ifdef TYPE_F16 +enable f16; +#define TYPE f16 +#else +#define TYPE f32 +#endif + + +@group(0) @binding(0) +var src: array; + +#ifndef INPLACE +@group(0) @binding(1) +var dst: array; +#define PARAMS_BINDING 2 +#else +#define PARAMS_BINDING 1 +#endif + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + // Logical shapes + ne0: u32, + ne1: u32, + ne2: u32, +#ifdef CLAMP + clamp_min: f32, + clamp_max: f32, +#endif +#ifdef FILL + fill_val: f32, +#endif +#ifdef XIELU + alpha_n: f32, + alpha_p: f32, + beta: f32, + eps: f32, +#endif + +}; + +@group(0) @binding(PARAMS_BINDING) +var params: Params; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + +#ifdef ABS + let res = abs(src[params.offset_src + src_idx]); +#endif +#ifdef SGN + let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0), + src[params.offset_src + src_idx] > 0.0); +#endif +#ifdef NEG + let res = -src[params.offset_src + src_idx]; +#endif +#ifdef STEP + let res = TYPE(select(0.0, 1.0, src[params.offset_src + src_idx] > 0.0)); +#endif +#ifdef TANH + let res = tanh(clamp(src[params.offset_src + src_idx], -9.010913, 9.010913)); +#endif +#ifdef RELU + let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0); +#endif +#ifdef ELU + let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx], + src[params.offset_src + src_idx] > 0.0); +#endif +#ifdef HARDSIGMOID + let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); +#endif +#ifdef SIGMOID + let res = 1.0 / (1.0 + exp(-src[params.offset_src + src_idx])); +#endif +#ifdef SILU + let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx])); +#endif +#ifdef EXP + let res = exp(src[params.offset_src + src_idx]); +#endif +#ifdef LOG + let res = TYPE(log(f32(src[params.offset_src + src_idx]))); +#endif +#ifdef CLAMP + let res = clamp(src[params.offset_src + src_idx], TYPE(params.clamp_min), TYPE(params.clamp_max)); +#endif +#ifdef FILL + let res = TYPE(params.fill_val); +#endif +#ifdef HARDSWISH + let res = src[params.offset_src + src_idx] * + min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); +#endif +#ifdef GELU + let res = 0.5 * src[params.offset_src + src_idx] * + (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * + (src[params.offset_src + src_idx] + + 0.044715 * pow(src[params.offset_src + src_idx], 3.0)), + -9.010913, 9.010913))); +#endif +#ifdef GELU_QUICK + let res = src[params.offset_src + src_idx] * 0.5 * + (1.0 + tanh(clamp(0.79788456 * + (src[params.offset_src + src_idx] + + 0.044715 * src[params.offset_src + src_idx] * + src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), + -9.010913, 9.010913))); +#endif +#ifdef GELU_ERF + let res = 0.5 * src[params.offset_src + src_idx] * + (1.0 + tanh(clamp(0.79788456 * + (src[params.offset_src + src_idx] + + 0.044715 * src[params.offset_src + src_idx] * + src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), + -9.010913, 9.010913))); +#endif +#ifdef XIELU + let res = + select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) - + src[params.offset_src + src_idx]) * + TYPE(params.alpha_n) + + TYPE(params.beta) * src[params.offset_src + src_idx], + TYPE(params.alpha_p) * src[params.offset_src + src_idx] * + src[params.offset_src + src_idx] + + TYPE(params.beta) * src[params.offset_src + src_idx], + src[params.offset_src + src_idx] > 0.0); +#endif +#ifdef SOFTPLUS + let src_f32 = f32(src[params.offset_src + src_idx]); + let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0)); +#endif +#ifdef EXPM1 + let res = exp(src[params.offset_src + src_idx]) - 1.0; +#endif +#ifdef FLOOR + let res = floor(src[params.offset_src + src_idx]); +#endif +#ifdef CEIL + let res = ceil(src[params.offset_src + src_idx]); +#endif +#ifdef ROUND + let src_f32 = f32(src[params.offset_src + src_idx]); + let result = select(ceil(src_f32 - 0.5), floor(src_f32 + 0.5), src_f32 >= 0.0); + let res = TYPE(result); +#endif +#ifdef TRUNC + let res = trunc(src[params.offset_src + src_idx]); +#endif + +#ifdef INPLACE + src[params.offset_src + src_idx] = res; +#else + dst[params.offset_dst + gid.x] = res; +#endif +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl deleted file mode 100644 index 25fe2854518..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +++ /dev/null @@ -1,483 +0,0 @@ -#define(REPL_TEMPLATES) - -{ - "XIELU_FUNC": "{{MUTATE}}[dst_i] = select(((exp(min(src[src_i], {{TYPE}}(params.eps))) - 1.0) - src[src_i]) * {{TYPE}}(params.alpha_n) + {{TYPE}}(params.beta) * src[src_i], {{TYPE}}(params.alpha_p) * src[src_i] * src[src_i] + {{TYPE}}(params.beta) * src[src_i], src[src_i] > 0.0);", - "ABS_FUNC": "{{MUTATE}}[dst_i] = abs(src[src_i]);", - "SGN_FUNC": "{{MUTATE}}[dst_i] = select({{TYPE}}(select(0.0, -1.0, src[src_i] < 0.0)), {{TYPE}}(1.0), src[src_i] > 0.0);", - "NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];", - "STEP_FUNC": "{{MUTATE}}[dst_i] = {{TYPE}}(select(0.0, 1.0, src[src_i] > 0.0));", - "TANH_FUNC": "{{MUTATE}}[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", - "RELU_FUNC": "{{MUTATE}}[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);", - "ELU_FUNC": "{{MUTATE}}[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);", - "HARDSIGMOID_FUNC": "{{MUTATE}}[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", - "SIGMOID_FUNC": "{{MUTATE}}[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));", - "SILU_FUNC": "{{MUTATE}}[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));", - "EXP_FUNC": "{{MUTATE}}[dst_i] = exp(src[src_i]);", - "HARDSWISH_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", - "GELU_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", - "GELU_QUICK_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", - "GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", - "CEIL_FUNC": "{{MUTATE}}[dst_i] = ceil(src[src_i]);" -} - -#end(REPL_TEMPLATES) - -#define(VARIANTS) - -[ - { - "SHADER_NAME": "abs_f32", - "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "abs_f16", - "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "abs_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "abs_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "sgn_f32", - "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sgn_f16", - "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sgn_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sgn_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "neg_f32", - "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "neg_f16", - "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "neg_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "neg_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "step_f32", - "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "step_f16", - "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "step_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "step_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "tanh_f32", - "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "tanh_f16", - "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "tanh_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "tanh_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "elu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "elu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "elu_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "elu_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "relu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "relu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "relu_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "relu_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "sigmoid_f32", - "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sigmoid_f16", - "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sigmoid_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sigmoid_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "silu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "silu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "silu_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "silu_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "exp_f32", - "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "exp_f16", - "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "exp_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "exp_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "hardsigmoid_f32", - "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardsigmoid_f16", - "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardsigmoid_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "hardsigmoid_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "hardswish_f32", - "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardswish_f16", - "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardswish_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "hardswish_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "gelu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "gelu_quick_f32", - "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_quick_f16", - "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_quick_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_quick_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "xielu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "xielu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "xielu_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "xielu_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_f32", - "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_f16", - "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "ceil_f32", - "REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "ceil_f16", - "REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "ceil_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "ceil_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(INPLACE) - -@group(0) @binding(1) -var params: Params; - -#enddecl(INPLACE) - -#decl(NOT_INPLACE) - -@group(0) @binding(1) -var dst: array<{{TYPE}}>; - -@group(0) @binding(2) -var params: Params; - -#enddecl(NOT_INPLACE) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -fn update(dst_i: u32, src_i: u32) { - {{FUNC}} -} - -@group(0) @binding(0) -var src: array<{{TYPE}}>; - -DECLS - -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) — may be permuted - stride_src0: u32, - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst0: u32, - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Logical shapes - src_ne0: u32, - src_ne1: u32, - src_ne2: u32, - - dst_ne0: u32, - dst_ne1: u32, - dst_ne2: u32, - - {{EXT_PARAMS}} -}; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); - i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); - let i2 = i / (params.src_ne1 * params.src_ne0); - i = i % (params.src_ne1 * params.src_ne0); - let i1 = i / params.src_ne0; - let i0 = i % params.src_ne0; - - var j = gid.x; - let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); - j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); - let j2 = j / (params.dst_ne1 * params.dst_ne0); - j = j % (params.dst_ne1 * params.dst_ne0); - let j1 = j / params.dst_ne0; - let j0 = j % params.dst_ne0; - - let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + - i2 * params.stride_src2 + i3 * params.stride_src3; - - let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + - j2 * params.stride_dst2 + j3 * params.stride_dst3; - - - update(params.offset_dst + dst_idx, params.offset_src + src_idx); -} - -#end(SHADER) - From 62a09b106d62f25f898d41c152563d5ce3458c1b Mon Sep 17 00:00:00 2001 From: lhez Date: Sat, 17 Jan 2026 13:50:32 -0800 Subject: [PATCH 019/831] opencl: fix q6_K mv for m=1 (llama/18893) --- ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl index 8a17b9aae63..819e5192e35 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl @@ -111,6 +111,10 @@ kernel void kernel_mul_mv_q6_K_f32( int row = N_SIMDGROUP * r0 + get_sub_group_id(); + if (row >= ne01) { + return; + } + int i12 = im%ne12; int i13 = im/ne12; From 47f3e3b9271491393e66caf88c765c167b7ee0bf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 19 Jan 2026 20:03:19 +0200 Subject: [PATCH 020/831] ggml : add ggml_build_forward_select (llama/18550) * ggml : add ggml_build_forward_select * cuda : adapt CUDA graph compat to new feature * vulkan : update logic to handle command buffer closing * ggml : check compute for fusion * ggml : add comment --- ggml/include/ggml.h | 46 ++++++++++++++--- ggml/src/ggml-backend.cpp | 5 +- ggml/src/ggml-blas/ggml-blas.cpp | 4 ++ ggml/src/ggml-cann/ggml-cann.cpp | 4 ++ ggml/src/ggml-cpu/ggml-cpu.c | 4 ++ ggml/src/ggml-cuda/common.cuh | 1 + ggml/src/ggml-cuda/ggml-cuda.cu | 8 +++ ggml/src/ggml-hexagon/ggml-hexagon.cpp | 4 ++ ggml/src/ggml-impl.h | 3 ++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 4 ++ ggml/src/ggml-opencl/ggml-opencl.cpp | 4 ++ ggml/src/ggml-sycl/ggml-sycl.cpp | 3 ++ ggml/src/ggml-vulkan/ggml-vulkan.cpp | 5 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 3 ++ ggml/src/ggml-zdnn/ggml-zdnn.cpp | 4 ++ ggml/src/ggml-zendnn/ggml-zendnn.cpp | 4 ++ ggml/src/ggml.c | 70 +++++++++++++++++++------- 17 files changed, 148 insertions(+), 28 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b69583dd3fd..1988d16dc42 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -630,10 +630,11 @@ extern "C" { // this tensor... enum ggml_tensor_flag { - GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph - GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph - GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters - GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) + GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph + GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph + GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters + GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) + GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed }; enum ggml_tri_type { @@ -2577,11 +2578,42 @@ extern "C" { struct ggml_tensor * grad, struct ggml_tensor * sgd_params); // alpha, weight decay + // build forward mutiple tensors and select one of them for computing + // this is useful for creating graphs that have constant topology but compute different things based on the input + // ref: https://github.com/ggml-org/llama.cpp/pull/18550 // - // automatic differentiation + // nodes: + // | - build forward into the graph but do not compute + // c - build forward into the graph and compute // + // | | ... c ... | + // | | ... c ... | + // | | ... c ... | + // [0 1 ... idx ... n-1] <-- ggml_build_forward_select(..., n, idx) + // c + // c + // + // example: + // struct ggml_tensor * curs[3]; + // + // curs[0] = compute0(...); + // curs[1] = compute1(...); + // curs[2] = compute2(...); + // + // int idx = select_branch(some_input); + // + // struct ggml_tensor * out = ggml_build_forward_select(cgraph, curs, 3, idx); + // + GGML_API struct ggml_tensor * ggml_build_forward_select( + struct ggml_cgraph * cgraph, + struct ggml_tensor ** tensors, + int n_tensors, + int idx); + + GGML_API void ggml_build_forward_expand( + struct ggml_cgraph * cgraph, + struct ggml_tensor * tensor); - GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); GGML_API void ggml_build_backward_expand( struct ggml_context * ctx, // context for gradient computation struct ggml_cgraph * cgraph, @@ -2613,7 +2645,7 @@ extern "C" { GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph); // dump the graph into a file using the dot format - GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); + GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename); // TODO these functions were sandwiched in the old optimization interface, is there a better place for them? typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 1b59924b8cb..354876574a0 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -874,9 +874,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str } if (sched->debug > 1) { ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node); - GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name, + GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name, fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node), - graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]); + graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0); for (int j = 0; j < GGML_MAX_SRC; j++) { struct ggml_tensor * src = node->src[j]; if (src == NULL) { @@ -1922,6 +1922,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, dst->view_offs = src->view_offs; } dst->op = src->op; + dst->flags = src->flags; memcpy(dst->op_params, src->op_params, sizeof(dst->op_params)); ggml_set_name(dst, src->name); diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index 84956cbb9ce..2e9ddf2240d 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -226,6 +226,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * node = cgraph->nodes[i]; + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + switch (node->op) { case GGML_OP_MUL_MAT: ggml_backend_blas_mul_mat(ctx, node); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index eba83327f13..42c6c67a40b 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2146,6 +2146,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + bool ok = ggml_cann_compute_forward(*cann_ctx, node); if (!ok) { GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index f7ba1fe317d..4c7a75e768a 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2943,6 +2943,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + ggml_compute_forward(¶ms, node); if (state->ith == 0 && cplan->abort_callback && diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index eaaf87612d2..179522d8355 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1123,6 +1123,7 @@ struct ggml_tensor_extra_gpu { struct ggml_cuda_graph_node_properties { void * node_address; ggml_op node_op; + int32_t flags; int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; void * src_address[GGML_MAX_SRC]; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ed1021469a7..cda422defbe 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2918,6 +2918,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { props->node_address = node->data; props->node_op = node->op; + props->flags = node->flags; for (int i = 0; i < GGML_MAX_DIMS; i++) { props->ne[i] = node->ne[i]; props->nb[i] = node->nb[i]; @@ -2961,6 +2962,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_ return false; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) { + return false; + } + return true; } @@ -3378,6 +3383,9 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } // start of fusion operations static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index cf1eb994c3e..5b835c11c72 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2497,6 +2497,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + uint32_t flags = 0; // skip quantizer if src1 is reused diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 80e0fd2ff8b..baadfe9a7b3 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -611,6 +611,9 @@ static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const in if (node->op != ops[i]) { return false; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return false; + } if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) { return false; } diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 680ad794de9..3d97d3dfdcb 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { GGML_ABORT("unsupported op"); } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return 1; + } + int n_fuse = 1; // check if the current node can run concurrently with other nodes before it diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index d89d5e7242d..8059240b1c4 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -3058,6 +3058,10 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) { ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); i += 2; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 8f8176b678a..bb8acc922b9 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4109,6 +4109,9 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } #ifndef NDEBUG assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device)); for (int j = 0; j < GGML_MAX_SRC; j++) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0fabbcec31d..08fd044ca03 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -12191,6 +12191,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) { return false; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return false; + } VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); ctx->semaphore_idx = 0; @@ -13645,7 +13648,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg int last_node = cgraph->n_nodes - 1; // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly - while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) { + while (last_node > 0 && (ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) { last_node -= 1; } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1470378af00..584cea7698b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1982,6 +1982,9 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, if (ggml_is_empty(node)) { return std::nullopt; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return std::nullopt; + } WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")"); ggml_tensor * src0 = node->src[0]; diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index edbeb8eef24..906d25417e4 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -58,6 +58,10 @@ static enum ggml_status ggml_zdnn_graph_compute(ggml_backend_t backend, ggml_cgr continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + bool ok = ggml_zdnn_compute_forward(ctx, node); if (!ok) { GGML_LOG_ERROR("%s: unsupported op %s (%s)\n", diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index fd07f983da7..afbecde7a5a 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -211,6 +211,10 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * node = cgraph->nodes[i]; + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + switch (node->op) { case GGML_OP_MUL_MAT: ggml_zendnn_compute_forward_mul_mat(ctx, node); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index c75fe7d2716..1725ad16545 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3441,7 +3441,8 @@ struct ggml_tensor * ggml_cast( result->op = GGML_OP_CPY; result->src[0] = a; - result->src[1] = result; + result->src[1] = result; // note: this self-reference might seem redundant, but it's actually needed by some + // backends for consistency with ggml_cpy_impl() above return result; } @@ -6725,20 +6726,35 @@ static void ggml_compute_backward( GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2])); } -static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { - // check if already visited - size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); +static size_t ggml_visit_parents_graph(struct ggml_cgraph * cgraph, struct ggml_tensor * node, bool compute) { + if (node->op != GGML_OP_NONE && compute) { + node->flags |= GGML_TENSOR_FLAG_COMPUTE; + } + + const size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL); - if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) { - // This is the first time we see this node in the current graph. - cgraph->visited_hash_set.keys[node_hash_pos] = node; - ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos); - cgraph->use_counts[node_hash_pos] = 0; - } else { + + if (ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) { // already visited + + if (compute) { + // update the compute flag regardless + for (int i = 0; i < GGML_MAX_SRC; ++i) { + struct ggml_tensor * src = node->src[i]; + if (src && ((src->flags & GGML_TENSOR_FLAG_COMPUTE) == 0)) { + ggml_visit_parents_graph(cgraph, src, true); + } + } + } + return node_hash_pos; } + // This is the first time we see this node in the current graph. + cgraph->visited_hash_set.keys[node_hash_pos] = node; + ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos); + cgraph->use_counts[node_hash_pos] = 0; + for (int i = 0; i < GGML_MAX_SRC; ++i) { const int k = (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i : @@ -6747,7 +6763,7 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor struct ggml_tensor * src = node->src[k]; if (src) { - size_t src_hash_pos = ggml_visit_parents(cgraph, src); + const size_t src_hash_pos = ggml_visit_parents_graph(cgraph, src, compute); // Update the use count for this operand. cgraph->use_counts[src_hash_pos]++; @@ -6778,17 +6794,17 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor return node_hash_pos; } -static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { +static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand, bool compute) { if (!expand) { // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand ggml_graph_clear(cgraph); } - const int n0 = cgraph->n_nodes; + const int n_old = cgraph->n_nodes; - ggml_visit_parents(cgraph, tensor); + ggml_visit_parents_graph(cgraph, tensor, compute); - const int n_new = cgraph->n_nodes - n0; + const int n_new = cgraph->n_nodes - n_old; GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); if (n_new > 0) { @@ -6797,8 +6813,22 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten } } +struct ggml_tensor * ggml_build_forward_select( + struct ggml_cgraph * cgraph, + struct ggml_tensor ** tensors, + int n_tensors, + int idx) { + GGML_ASSERT(idx >= 0 && idx < n_tensors); + + for (int i = 0; i < n_tensors; i++) { + ggml_build_forward_impl(cgraph, tensors[i], true, i == idx ? true : false); + } + + return tensors[idx]; +} + void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { - ggml_build_forward_impl(cgraph, tensor, true); + ggml_build_forward_impl(cgraph, tensor, true, true); } void ggml_build_backward_expand( @@ -7229,6 +7259,10 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph, return false; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return false; + } + if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) { continue; } @@ -7310,7 +7344,7 @@ static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, label); } -void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) { +void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename) { char color[16]; FILE * fp = ggml_fopen(filename, "w"); @@ -7331,7 +7365,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph if (node->flags & GGML_TENSOR_FLAG_PARAM) { snprintf(color, sizeof(color), "yellow"); } else if (grad) { - if (ggml_graph_find(gf, node)) { + if (ggml_graph_find(cgraph, node)) { snprintf(color, sizeof(color), "green"); } else { snprintf(color, sizeof(color), "lightblue"); From b0517d6912de72330616d66bb676ea902a1b93b6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 20 Jan 2026 12:21:28 +0200 Subject: [PATCH 021/831] metal : enable FA for MLA heads (llama/18950) --- ggml/src/ggml-metal/ggml-metal-device.m | 8 ++------ ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +- ggml/src/ggml-metal/ggml-metal.metal | 13 ++++++++----- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c418afe9c3b..eb4e2c209ce 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1078,12 +1078,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->src[0]->ne[0] != 112 && op->src[0]->ne[0] != 128 && op->src[0]->ne[0] != 192 && - op->src[0]->ne[0] != 256) { - return false; - } - if (op->src[0]->ne[0] == 576) { - // DeepSeek sizes - // TODO: disabled for now, until optmized + op->src[0]->ne[0] != 256 && + op->src[0]->ne[0] != 576) { return false; } if (op->src[1]->type != op->src[2]->type) { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 3d97d3dfdcb..7f4cfbba226 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2520,7 +2520,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { // simdgroups per threadgroup (a.k.a. warps) //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - int32_t nsg = 4; + int32_t nsg = ne00 >= 512 ? 8 : 4; const size_t smem = FATTN_SMEM(nsg); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index a4e1cafe552..17e358d1a8d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -5552,9 +5552,7 @@ void kernel_flash_attn_ext_impl( constexpr short NC = (C/8)/NSG; - // note: do not unroll for large heads - #pragma unroll (DK <= 64 ? NC : 1) - for (short cc = 0; cc < NC; ++cc) { + FOR_UNROLL (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); if (DK % 16 != 0) { @@ -5575,7 +5573,9 @@ void kernel_flash_attn_ext_impl( k8x8_t mk[2]; q8x8_t mq[2]; - FOR_UNROLL (short i = 0; i < DK8/2; ++i) { + // note: too much unroll can tank the performance for large heads + #pragma unroll (MIN(DK8/2, 4*NSG)) + for (short i = 0; i < DK8/2; ++i) { simdgroup_barrier(mem_flags::mem_none); simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); @@ -5749,7 +5749,9 @@ void kernel_flash_attn_ext_impl( pv += 8*NS20; } } else { - FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) { + constexpr short NC = (C/8)/2; + + FOR_UNROLL (short cc = 0; cc < NC; ++cc) { s8x8_t vs[2]; simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); @@ -5952,6 +5954,7 @@ kernel void kernel_flash_attn_ext( //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break; case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break; + case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break; } #undef FWD_TMPL #undef FWD_ARGS From bf71ffa6b34a04d18b5f0b5489206171f16e9790 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Tue, 20 Jan 2026 11:42:49 +0100 Subject: [PATCH 022/831] ggml : cleanup path_str() (llama/18928) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove pragmas as `std::codecvt_utf8` is not used. - Avoid implicit `strlen()`. Signed-off-by: Adrien Gallouët --- ggml/src/ggml-backend-reg.cpp | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 4181a714ad6..6bee1bc4b49 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -77,39 +77,23 @@ #include "ggml-zendnn.h" #endif -// disable C++17 deprecation warning for std::codecvt_utf8 -#if defined(__clang__) -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif - namespace fs = std::filesystem; static std::string path_str(const fs::path & path) { - std::string u8path; try { #if defined(__cpp_lib_char8_t) // C++20 and later: u8string() returns std::u8string - std::u8string u8str = path.u8string(); - u8path = std::string(reinterpret_cast(u8str.c_str())); + const std::u8string u8str = path.u8string(); + return std::string(reinterpret_cast(u8str.data()), u8str.size()); #else // C++17: u8string() returns std::string - u8path = path.u8string(); + return path.u8string(); #endif } catch (...) { + return std::string(); } - return u8path; } -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#endif - #ifdef _WIN32 using dl_handle = std::remove_pointer_t; From fdc83ee3c0367e26ffc7b00edbed75d5189df8ce Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Tue, 20 Jan 2026 13:11:01 +0100 Subject: [PATCH 023/831] CUDA: Replace init_offsets kernel with iterators in cub-based argsort (llama/18930) * CUDA: Replace `init_offsets` with iterators in argsort This is a QOL improvement, saving us the cost of materializing the iterator * Remove unnecessary include from top-k.cu --- ggml/src/ggml-cuda/argsort.cu | 22 +++++++--------------- ggml/src/ggml-cuda/top-k.cu | 1 - 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 57c8a99a286..cf7a44f7adc 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -14,12 +14,6 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr } } -static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) { - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx <= nrows) { - offsets[idx] = idx * ncols; - } -} #ifdef GGML_CUDA_USE_CUB void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, @@ -31,18 +25,15 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, cudaStream_t stream) { ggml_cuda_pool_alloc temp_indices_alloc(pool, ncols * nrows); ggml_cuda_pool_alloc temp_keys_alloc(pool, ncols * nrows); - ggml_cuda_pool_alloc offsets_alloc(pool, nrows + 1); int * temp_indices = temp_indices_alloc.get(); float * temp_keys = temp_keys_alloc.get(); - int * d_offsets = offsets_alloc.get(); static const int block_size = 256; const dim3 grid_size((ncols + block_size - 1) / block_size, nrows); init_indices<<>>(temp_indices, ncols, nrows); - const dim3 offset_grid((nrows + block_size - 1) / block_size); - init_offsets<<>>(d_offsets, ncols, nrows); + auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols); CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream)); @@ -57,7 +48,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) temp_indices, dst, // values (indices) ncols * nrows, nrows, // num items, num segments - d_offsets, d_offsets + 1, stream); + offset_iterator, offset_iterator + 1, stream); } } else { if (nrows == 1) { @@ -66,7 +57,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, ncols, 0, sizeof(float) * 8, stream); } else { DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, - dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream); + dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1, + stream); } } @@ -80,7 +72,7 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, ncols, 0, sizeof(float) * 8, stream); } else { DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, - ncols * nrows, nrows, d_offsets, d_offsets + 1, stream); + ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream); } } else { if (nrows == 1) { @@ -89,8 +81,8 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, ncols, 0, sizeof(float) * 8, stream); } else { DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, - temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, - stream); + temp_indices, dst, ncols * nrows, nrows, offset_iterator, + offset_iterator + 1, stream); } } } diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu index 318ac38691e..785a18389f2 100644 --- a/ggml/src/ggml-cuda/top-k.cu +++ b/ggml/src/ggml-cuda/top-k.cu @@ -4,7 +4,6 @@ #ifdef GGML_CUDA_USE_CUB # include # if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) -# include # define CUB_TOP_K_AVAILABLE using namespace cub; # endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2 From 924a9e292ca7b818ec93517f83732f89194793df Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Wed, 21 Jan 2026 02:34:29 +0100 Subject: [PATCH 024/831] CUDA: Fix builds for older CCCL versions by ifdefing strided_iterator (llama/18964) * CUDA: Fix builds for older CCCL versions by ifdefing strided_iterator Strided iterator was added in [CCCL 3.1](https://github.com/NVIDIA/cccl/releases/tag/v3.1.0), which is packaged into [CTK 13.1](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#id5) * Unindent as per code review request --- ggml/src/ggml-cuda/argsort.cu | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index cf7a44f7adc..4896669c32a 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -2,6 +2,9 @@ #ifdef GGML_CUDA_USE_CUB # include +# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1) +# define STRIDED_ITERATOR_AVAILABLE +# endif using namespace cub; #endif // GGML_CUDA_USE_CUB @@ -14,6 +17,14 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr } } +#ifndef STRIDED_ITERATOR_AVAILABLE +static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx <= nrows) { + offsets[idx] = idx * ncols; + } +} +#endif // STRIDED_ITERATOR_AVAILABLE #ifdef GGML_CUDA_USE_CUB void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, @@ -33,8 +44,14 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, const dim3 grid_size((ncols + block_size - 1) / block_size, nrows); init_indices<<>>(temp_indices, ncols, nrows); +#ifdef STRIDED_ITERATOR_AVAILABLE auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols); - +#else + ggml_cuda_pool_alloc offsets_alloc(pool, nrows + 1); + int * offset_iterator = offsets_alloc.get(); + const dim3 offset_grid((nrows + block_size - 1) / block_size); + init_offsets<<>>(offset_iterator, ncols, nrows); +#endif CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream)); size_t temp_storage_bytes = 0; From 660d943ff8d69398a217fde7a8adf99373556839 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 21 Jan 2026 09:22:02 -0600 Subject: [PATCH 025/831] vulkan: Use mul_mat_vec_id for small values of n (llama/18918) Change ggml_vk_mul_mat_vec_id_q_f16 to loop over the batch dimension and update the indexing calculations in get_offsets. Mat-vec is faster than mat-mat for small values of n. We don't get the same reuse of the weights as in the non-ID path, but with this the cost is linear in n rather than n>1 being far slower than n==1. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 47 ++++++++++--------- .../vulkan-shaders/mul_mat_vec_base.glsl | 34 +++++++------- 2 files changed, 43 insertions(+), 38 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 08fd044ca03..4c8bdd4e635 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -991,6 +991,8 @@ struct vk_mat_vec_id_push_constants { uint32_t fusion_flags; uint32_t nei0; uint32_t ne11; + uint32_t expert_i1; + uint32_t nbi1; }; struct vk_flash_attn_push_constants { @@ -8083,8 +8085,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t nei0 = ids->ne[0]; const uint64_t nei1 = ids->ne[1]; - - GGML_ASSERT(nei1 == 1); + const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int)); const uint64_t ne20 = dst->ne[0]; const uint64_t ne21 = dst->ne[1]; @@ -8168,7 +8169,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (quantize_y) { ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); + ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1); } vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]); @@ -8226,7 +8227,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte uint32_t stride_batch_y = ne10*ne11; if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { - stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + stride_batch_y = src1->nb[2] / ggml_type_size(src1->type); } const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; @@ -8262,23 +8263,25 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1; } - // compute - const vk_mat_vec_id_push_constants pc = { - (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, - (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21), - fusion_flags, - (uint32_t)nei0, (uint32_t)ne11, - }; - ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, - { - d_X, - d_Y, - d_D, - d_F0, - d_F1, - d_ids, - }, - pc, { groups_x, (uint32_t)nei0, groups_z }); + // Loop over the batch dimension + for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) { + const vk_mat_vec_id_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21), + fusion_flags, + (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1 + }; + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { + d_X, + d_Y, + d_D, + d_F0, + d_F1, + d_ids, + }, + pc, { groups_x, (uint32_t)nei0, groups_z }); + } if (x_non_contig) { ctx->prealloc_x_need_sync = true; @@ -8292,7 +8295,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no ggml_tensor * dst = cgraph->nodes[node_idx]; ggml_tensor * src0 = dst->src[0]; ggml_tensor * src2 = dst->src[2]; - return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)); + return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)); } static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index dfb78659362..4f2c7003065 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -29,6 +29,8 @@ layout (push_constant) uniform parameter #ifdef MUL_MAT_ID uint nei0; uint ne11; + uint expert_i1; + uint nbi1; #else uint ne02; uint ne12; @@ -43,7 +45,7 @@ uint expert_id; void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.y; + const uint expert_i0 = gl_GlobalInvocationID.y; #else const uint batch_idx = gl_GlobalInvocationID.y; #endif @@ -60,7 +62,7 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { batch_idx_a = i03 * p.ne02 + i02; } #else - expert_id = data_ids[expert_idx]; + expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1]; #endif a_offset = @@ -71,13 +73,13 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #endif b_offset = #ifdef MUL_MAT_ID - (expert_idx % p.ne11) * p.stride_b; + (expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b; #else batch_idx * p.batch_stride_b; #endif d_offset = #ifdef MUL_MAT_ID - expert_idx * p.stride_d; + expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d; #else batch_idx * p.batch_stride_d; #endif @@ -103,12 +105,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { @@ -158,12 +160,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { @@ -203,12 +205,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { From 3bbf4ced474f5cd43dbbd827b6815ffb526ac33e Mon Sep 17 00:00:00 2001 From: Masato Nakasaka Date: Thu, 22 Jan 2026 01:13:43 +0900 Subject: [PATCH 026/831] Revert "vulkan: force full subgroups for flash attention to fix intel subgroup crash (#17356)" (llama/18831) This reverts commit 980b7cd17e055c8c587f79ffda7eb4fddf405566. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 4c8bdd4e635..62a878556ba 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3180,15 +3180,15 @@ static void ggml_vk_load_shaders(vk_device& device) { if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ } \ } \ } \ From b2bc4d810b2df64b33be6613fc76c7769fecf503 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 21 Jan 2026 10:43:43 -0600 Subject: [PATCH 027/831] vulkan: support flash attention GQA/split_k with small batches (llama/18938) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 42 +++++++++++++------ .../vulkan-shaders/flash_attn.comp | 13 +++--- .../vulkan-shaders/flash_attn_base.glsl | 17 +++++--- .../vulkan-shaders/flash_attn_cm1.comp | 13 +++--- .../vulkan-shaders/flash_attn_cm2.comp | 13 +++--- .../flash_attn_split_k_reduce.comp | 17 ++++---- 6 files changed, 71 insertions(+), 44 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 62a878556ba..739361e7778 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1518,6 +1518,15 @@ struct vk_quantize_q8_1_push_constants { uint32_t num_blocks; }; +struct vk_op_flash_attn_split_k_reduce_push_constants { + uint32_t D; + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + uint32_t k_num; + uint32_t sinks; +}; + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -3982,7 +3991,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); if (device->subgroup_clustered && device->subgroup_require_full_support) { ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); @@ -8457,14 +8466,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_ASSERT(0); } - if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && + if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { // grouped query attention - make the N dimension equal to gqa_ratio, reduce // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 // and change addressing calculations to index Q's dimension 2. gqa_ratio = qk_ratio; N = gqa_ratio; - workgroups_y /= N; + workgroups_y /= gqa_ratio; } bool small_rows = N <= get_fa_num_small_rows(path); @@ -8526,6 +8535,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } assert(pipeline); + // Compile early to initialize wg_denoms. + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); uint32_t split_kv = KV; uint32_t split_k = 1; @@ -8533,22 +8544,24 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Use a placeholder core count if one isn't available. split_k is a big help for perf. const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; - // Try to use split_k when KV is large enough to be worth the overhead - if (workgroups_x == 1 && shader_core_count > 0) { + // Try to use split_k when KV is large enough to be worth the overhead. + // Must either be a single batch or be using gqa, we can't mix the two. + if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) { // Try to run two workgroups per SM. - split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); + split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z); if (split_k > 1) { // Try to evenly split KV into split_k chunks, but it needs to be a multiple // of "align", so recompute split_k based on that. split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); split_k = CEIL_DIV(KV, split_kv); - workgroups_x = split_k; } } // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. - const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; + // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3]. + // For L/M, the order is (inner to outer) [ne1, k, ne2, ne3]. + const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0; if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) { GGML_ABORT("Requested preallocation size is too large"); } @@ -8559,7 +8572,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx { // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); } @@ -8608,7 +8620,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (ctx->prealloc_split_k_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - + workgroups_x *= pipeline->wg_denoms[0]; vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf}, @@ -8616,15 +8628,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // there's no more than one tile of rows (i.e. workgroups_x would have been // one). We reuse workgroups_x to mean the number of splits, so we need to // cancel out the divide by wg_denoms[0]. - pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + pc, { split_k * workgroups_x, workgroups_y, workgroups_z }); ggml_vk_sync_buffers(ctx, subctx); - const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; + const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, {split_k_buf, sinks_buf, dst_buf}, - pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); + pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) }); ctx->prealloc_split_k_need_sync = true; } else { + if (gqa_ratio > 1) { + // When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms + workgroups_x *= pipeline->wg_denoms[0]; + } ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf}, pc, { workgroups_x, workgroups_y, workgroups_z }); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 0379e5d5024..3ce8d07be80 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -53,7 +53,7 @@ void main() { const uint32_t d_tid = gl_LocalInvocationIndex % D_split; const uint32_t col_tid = gl_LocalInvocationIndex / D_split; - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); @@ -101,9 +101,9 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif - uint32_t m_offset = 0; + uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; } [[dont_unroll]] @@ -320,7 +320,8 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { @@ -332,7 +333,7 @@ void main() { } } - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < Br; ++r) { if (r < N) { perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -378,7 +379,7 @@ void main() { } } - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index eb93903c468..29b5c7c3a41 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -165,7 +165,7 @@ ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC } uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, - iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, + gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, q_stride, k_stride, v_stride, m_stride; void init_indices() @@ -173,12 +173,19 @@ void init_indices() N = p.N; KV = p.KV; - i = gl_WorkGroupID.x; - split_k_index = 0; - if (p.k_num > 1) { i = 0; - split_k_index = gl_WorkGroupID.x; + // batch and split_k share gl_WorkGroupID.x + gqa_iq1 = gl_WorkGroupID.x / p.k_num; + split_k_index = gl_WorkGroupID.x % p.k_num; + } else if (p.gqa_ratio > 1) { + i = 0; + gqa_iq1 = gl_WorkGroupID.x; + split_k_index = 0; + } else { + i = gl_WorkGroupID.x; + gqa_iq1 = 0; + split_k_index = 0; } Tr = CEIL_DIV(N, Br); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index c995ab140ee..0eb50fe58f9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -90,7 +90,7 @@ void main() { barrier(); } - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); @@ -141,9 +141,9 @@ void main() { uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; #endif - uint32_t m_offset = 0; + uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; } [[dont_unroll]] @@ -370,7 +370,8 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { @@ -382,7 +383,7 @@ void main() { } } - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); @@ -428,7 +429,7 @@ void main() { } } - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 9a71996383d..d49a8da65fb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -111,7 +111,7 @@ void main() { coopmat Q; coopmat Qf16; - uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; + uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03; coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad)); Qf16 = coopmat(Q); @@ -138,9 +138,9 @@ void main() { coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } - uint32_t m_offset = 0; + uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; } [[dont_unroll]] @@ -272,10 +272,11 @@ void main() { if (p.k_num > 1) { coopmat O_D = coopmat(O); - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); return; @@ -325,7 +326,7 @@ void main() { [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } #endif - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; coopmat O_D = coopmat(O); if (p.gqa_ratio > 1) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index 4eaddd31a8f..68917fc0bb0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -12,7 +12,8 @@ layout (binding = 2) writeonly buffer D {float data_d[];}; layout (push_constant) uniform parameter { uint D; - uint N; + uint ne1; + uint ne2; uint ne3; uint k_num; uint sinks; @@ -24,15 +25,15 @@ void main() { // Each workgroup handles a row const uint n = gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; - const uint iq3 = gl_WorkGroupID.z; + const uint i2 = gl_WorkGroupID.z % p.ne2; + const uint i3 = gl_WorkGroupID.z / p.ne2; uint D = p.D; - uint N = p.N; uint k_num = p.k_num; - uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n; - uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n; - uint lm_stride = N * 2; + uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + n; + uint m_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + p.ne1 + n; + uint lm_stride = p.ne1 * 2; // Compute the max m value for the row float m_max = -1.0/0.0; @@ -99,7 +100,7 @@ void main() { if (d < D) { float O = 0.0; [[unroll]] for (uint k = 0; k < k_num; ++k) { - uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; + uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d; float m = data_a[m_offset + k * lm_stride]; O += exp(m - m_max) * data_a[o_offset]; } @@ -115,6 +116,6 @@ void main() { const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF); O = clamp(O, -FLT_MAX, FLT_MAX); - data_d[iq3 * D * N + D * n + d] = O; + data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O; } } From b7e323f40b915a6cc0a6a5c951290a1ccda0b2f3 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 21 Jan 2026 11:01:40 -0600 Subject: [PATCH 028/831] vulkan: Remove transfer_ctx, do everything in compute_ctx. (llama/18945) * vulkan: Remove transfer_ctx, do everything in compute_ctx. We had a bug where a set_tensor_async (using transfer_ctx) didn't get submitted before the graph_compute (using compute_ctx) that came after it. To avoid this sort of issue, just do everything in compute_ctx. Remove transfer_cmd_pool, which was already unused. * fix crash with perf logger --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 120 +++++++++++++-------------- 1 file changed, 59 insertions(+), 61 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 739361e7778..b5e5dba95fe 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1813,7 +1813,6 @@ struct ggml_backend_vk_context { bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync; vk_context_ref compute_ctx; - vk_context_ref transfer_ctx; std::vector tensor_ctxs; @@ -1823,7 +1822,6 @@ struct ggml_backend_vk_context { uint32_t pipeline_descriptor_set_requirements {}; vk_command_pool compute_cmd_pool; - vk_command_pool transfer_cmd_pool; // number of additional consecutive nodes that are being fused with the // node currently being processed @@ -5658,7 +5656,6 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->almost_ready_fence = ctx->device->device.createFence({}); ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); - ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); if (vk_perf_logger_enabled) { ctx->perf_logger = std::unique_ptr(new vk_perf_logger()); @@ -11579,7 +11576,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t free(d_chk); ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); - ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); ggml_vk_destroy_buffer(d_X); ggml_vk_destroy_buffer(d_Y); @@ -12164,7 +12160,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex ggml_vk_submit(subctx, {}); ctx->submit_pending = true; ggml_vk_synchronize(ctx); + GGML_ASSERT(ctx->compute_ctx.expired()); ggml_vk_ctx_begin(ctx->device, subctx); + ctx->compute_ctx = subctx; } if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) { @@ -12182,6 +12180,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex ggml_vk_destroy_buffer(ctx->prealloc_y); } ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y); + ctx->prealloc_y_last_tensor_used = nullptr; } if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")"); @@ -12762,7 +12761,6 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); - ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); @@ -12791,7 +12789,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")"); // discard any unsubmitted command buffers - ctx->transfer_ctx.reset(); + ctx->compute_ctx.reset(); // wait for any pending command buffers to finish ggml_vk_synchronize(ctx); @@ -12824,7 +12822,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->descriptor_sets.clear(); ctx->compute_cmd_pool.destroy(ctx->device->device); - ctx->transfer_cmd_pool.destroy(ctx->device->device); if (vk_perf_logger_enabled) { ctx->perf_logger->print_timings(true); } @@ -13096,34 +13093,34 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } vk_buffer buf = buf_ctx->dev_buffer; auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_write_async(transfer_ctx, buf, dst_offset, data, size); + bool ret = ggml_vk_buffer_write_async(compute_ctx, buf, dst_offset, data, size); if (!ret) { ggml_vk_ensure_sync_staging_buffer(ctx, size); - ggml_vk_sync_buffers(nullptr, transfer_ctx); + ggml_vk_sync_buffers(nullptr, compute_ctx); vk::BufferCopy buffer_cpy; buffer_cpy.srcOffset = 0; buffer_cpy.dstOffset = dst_offset; buffer_cpy.size = size; - transfer_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); - deferred_memcpy(ctx->sync_staging->ptr, data, size, &transfer_ctx->in_memcpys); + compute_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); + deferred_memcpy(ctx->sync_staging->ptr, data, size, &compute_ctx->in_memcpys); ggml_vk_synchronize(ctx); } } @@ -13135,34 +13132,34 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } vk_buffer buf = buf_ctx->dev_buffer; auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_read_async(transfer_ctx, buf, src_offset, data, size); + bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size); // If that failed, copy synchronously through a staging buffer if (!ret) { ggml_vk_ensure_sync_staging_buffer(ctx, size); - ggml_vk_sync_buffers(nullptr, transfer_ctx); + ggml_vk_sync_buffers(nullptr, compute_ctx); vk::BufferCopy buffer_cpy; buffer_cpy.srcOffset = src_offset; buffer_cpy.dstOffset = 0; buffer_cpy.size = size; - transfer_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy }); - deferred_memcpy(data, ctx->sync_staging->ptr, size, &transfer_ctx->out_memcpys); + compute_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy }); + deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys); ggml_vk_synchronize(ctx); } } @@ -13174,21 +13171,21 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_ ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } vk_buffer src_buf = src_buf_ctx->dev_buffer; vk_buffer dst_buf = dst_buf_ctx->dev_buffer; - ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); return true; } @@ -13198,19 +13195,19 @@ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_synchronize()"); - bool do_transfer = !ctx->transfer_ctx.expired(); + bool do_transfer = !ctx->compute_ctx.expired(); - vk_context transfer_ctx; + vk_context compute_ctx; if (do_transfer) { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); - ggml_vk_ctx_end(transfer_ctx); + ggml_vk_ctx_end(compute_ctx); - for (auto& cpy : transfer_ctx->in_memcpys) { + for (auto& cpy : compute_ctx->in_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - ggml_vk_submit(transfer_ctx, {}); + ggml_vk_submit(compute_ctx, {}); ctx->submit_pending = true; } @@ -13224,10 +13221,10 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { } if (do_transfer) { - for (auto& cpy : transfer_ctx->out_memcpys) { + for (auto& cpy : compute_ctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - ctx->transfer_ctx.reset(); + ctx->compute_ctx.reset(); } } @@ -13896,6 +13893,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_vk_submit(compute_ctx, ctx->device->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences"); ctx->device->device.resetFences({ ctx->device->fence }); + ctx->compute_ctx.reset(); // Get the results and pass them to the logger std::vector timestamps(cgraph->n_nodes + 1); @@ -14182,15 +14180,15 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; vk_event *vkev = (vk_event *)event->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } // the backend interface doesn't have an explicit reset, so reset it here @@ -14198,13 +14196,13 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev ctx->device->device.resetEvent(vkev->event); ctx->device->device.resetFences({ vkev->fence }); - ggml_vk_set_event(transfer_ctx, vkev->event); + ggml_vk_set_event(compute_ctx, vkev->event); - ggml_vk_ctx_end(transfer_ctx); + ggml_vk_ctx_end(compute_ctx); - ggml_vk_submit(transfer_ctx, {vkev->fence}); + ggml_vk_submit(compute_ctx, {vkev->fence}); ctx->submit_pending = true; - ctx->transfer_ctx.reset(); + ctx->compute_ctx.reset(); } static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { @@ -14212,20 +14210,20 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; vk_event *vkev = (vk_event *)event->context; - vk_context transfer_ctx; + vk_context compute_ctx; - if (ctx->transfer_ctx.expired()) { + if (ctx->compute_ctx.expired()) { // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); } - ggml_vk_wait_events(transfer_ctx, {vkev->event}); - ggml_vk_ctx_end(transfer_ctx); - ctx->transfer_ctx.reset(); + ggml_vk_wait_events(compute_ctx, {vkev->event}); + ggml_vk_ctx_end(compute_ctx); + ctx->compute_ctx.reset(); } // TODO: enable async and synchronize From 55927d42eff4a46e4a63db9ed201258f52f6bba2 Mon Sep 17 00:00:00 2001 From: Aleksei Nikiforov <103434461+AlekseiNikiforovIBM@users.noreply.github.com> Date: Thu, 22 Jan 2026 01:16:21 +0100 Subject: [PATCH 029/831] ggml-zdnn : mark zDNN buffers as non-host (llama/18967) While buffers reside in host memory, additional transformation is needed to use buffers with zDNN. Fixes #18848 --- ggml/src/ggml-zdnn/ggml-zdnn.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index 906d25417e4..9b6938abf7e 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -372,7 +372,8 @@ static size_t ggml_backend_zdnn_buffer_type_get_alignment(ggml_backend_buffer_ty } static bool ggml_backend_zdnn_buffer_type_is_host(ggml_backend_buffer_type_t buft) { - return true; + /* while it resides in host memory, additional transformation is needed */ + return false; GGML_UNUSED(buft); } From 167fec69d5208a80976f0ef5678a36ef4b3d1b62 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Wed, 21 Jan 2026 22:05:54 -0800 Subject: [PATCH 030/831] opencl: add TRI op support (llama/18979) --- ggml/src/ggml-opencl/CMakeLists.txt | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 65 ++++++++++++++++++++++++++++ ggml/src/ggml-opencl/kernels/tri.cl | 32 ++++++++++++++ 3 files changed, 98 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/tri.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 307ec08242a..79039c30e14 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -57,6 +57,7 @@ set(GGML_OPENCL_KERNELS add add_id argsort + tri fill clamp cpy diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 8059240b1c4..efdebe2bbaa 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -489,6 +489,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gelu_quick, kernel_gelu_quick_4; cl_kernel kernel_relu; cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16; + cl_kernel kernel_tri; cl_kernel kernel_fill; cl_kernel kernel_clamp; cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick, @@ -793,6 +794,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // tri + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "tri.cl.h" + }; +#else + const std::string kernel_src = read_file("tri.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_tri = clCreateKernel(prog, "kernel_tri_f32", &err), err)); + GGML_LOG_CONT("."); + + CL_CHECK(clReleaseProgram(prog)); + } + // fill { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3205,6 +3224,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te default: return false; } + case GGML_OP_TRI: + return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op); case GGML_OP_FILL: return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op); case GGML_OP_CLAMP: @@ -5965,6 +5986,44 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } +static void ggml_cl_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int tri_type = ggml_get_op_params_i32(dst, 0); + const int64_t n = ggml_nelements(dst); + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + cl_kernel kernel = backend_ctx->kernel_tri; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &n)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &tri_type)); + + size_t local_work_size[1] = { 256 }; + size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] }; + + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst); +} + static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(dst); GGML_ASSERT(dst->extra); @@ -10012,6 +10071,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_glu; break; + case GGML_OP_TRI: + if (!any_on_device) { + return false; + } + func = ggml_cl_tri; + break; case GGML_OP_FILL: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/tri.cl b/ggml/src/ggml-opencl/kernels/tri.cl new file mode 100644 index 00000000000..35cdd543bc5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/tri.cl @@ -0,0 +1,32 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// tri +//------------------------------------------------------------------------------ +__kernel void kernel_tri_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int n, + int ne0, + int ne1, + int tri_type +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int idx = get_global_id(0); + if (idx >= n) return; + + int i0 = idx % ne0; + int i1 = (idx / ne0) % ne1; + + int keep = 0; + if (tri_type == 0) keep = (i0 >= i1); + else if (tri_type == 1) keep = (i0 > i1); + else if (tri_type == 2) keep = (i0 <= i1); + else keep = (i0 < i1); + + dst[idx] = keep ? src0[idx] : 0.0f; +} From d4fafcfc6fa61434a40b2fb27cc025d84e3aae5b Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 22 Jan 2026 18:51:53 +0800 Subject: [PATCH 031/831] CUDA: add gqa_ratio 4 for GLM 4.7 flash (llama/18953) --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 31 ++++++++++++++----- ggml/src/ggml-cuda/fattn-tile.cuh | 12 +++++++ ggml/src/ggml-cuda/fattn.cu | 10 ++++-- ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 + ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 + ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 + ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 + .../template-instances/generate_cu_files.py | 2 +- 8 files changed, 47 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index e53bbc0502c..8cca89c2bfa 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -432,7 +432,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = T_B_KQ::I; constexpr int cols_per_thread = get_cols_per_thread(); - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols); constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols); @@ -510,7 +510,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } } else { - static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented"); #pragma unroll for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); @@ -522,14 +521,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); - // Wide version of KQ_C is column-major + if constexpr (cols_per_warp == 8) { + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); + } else { + // Wide version of KQ_C is column-major #if defined(AMD_WMMA_AVAILABLE) - // RDNA matrix C is column-major. - mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); + // RDNA matrix C is column-major. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); #else - // swap A and B for CUDA. - mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); + // swap A and B for CUDA. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); #endif // defined(AMD_WMMA_AVAILABLE) + } } } } @@ -953,7 +956,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int cols_per_warp = T_B_KQ::I; constexpr int cols_per_thread = get_cols_per_thread(); - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols); constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols); constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols); @@ -1484,6 +1487,13 @@ static __global__ void flash_attn_ext_f16( NO_DEVICE_CODE; return; } +#ifdef VOLTA_MMA_AVAILABLE + if (ncols1*ncols2 < 32) { + NO_DEVICE_CODE; + return; + } +#endif // VOLTA_MMA_AVAILABLE + #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING if (ncols1*ncols2 > 32) { NO_DEVICE_CODE; @@ -1728,3 +1738,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); + +// For GLM 4.7 Flash +extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index f055da8e2be..b6db5822818 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) return 0; @@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) return 0; @@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) @@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) @@ -1187,6 +1195,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm launch_fattn_tile_switch_ncols1(ctx, dst); return; } + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } } if constexpr (DV <= 256) { diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 598cda7daa0..80c3bfbc271 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; - GGML_ASSERT(gqa_ratio % 16 == 0); - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + GGML_ASSERT(gqa_ratio % 4 == 0); + if (gqa_ratio % 16 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + } } break; default: GGML_ABORT("fatal error"); @@ -262,7 +266,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; } - if (!gqa_opt_applies || gqa_ratio % 16 != 0) { + if (!gqa_opt_applies || gqa_ratio % 4 != 0) { return BEST_FATTN_KERNEL_NONE; } break; diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index 2074e954a32..517993cb068 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index 24c64cf000f..97b19c67ade 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index 1ada657f194..989626dfa5e 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 86d4ffae27c..173de7aac7d 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index a5602da02bb..10be71ab576 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -85,7 +85,7 @@ def get_short_name(long_quant_name): continue if head_size_kq != 576 and ncols2 == 16: continue - if head_size_kq == 576 and ncols2 != 16: + if head_size_kq == 576 and ncols2 not in (4, 16): continue head_size_v = head_size_kq if head_size_kq != 576 else 512 f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) From 0e030b852a19f2a0f3c8aeeee769c9f5ce85152e Mon Sep 17 00:00:00 2001 From: lhez Date: Thu, 22 Jan 2026 10:29:25 -0800 Subject: [PATCH 032/831] opencl: enable the general fp mm for non-cont input and as a fallback for specialized kqv kernel for adreno (llama/18970) * opencl: add `copy_to_contiguous` and utilize mm kernels * opencl: only copy to cont for f32 and f16 tensors * opencl: use cont mm for fallback when dst is large * opencl: use nb local to copy-to-cont * opencl: use local offset as well --- ggml/src/ggml-opencl/ggml-opencl.cpp | 179 +++++++++++++++++++++++++-- 1 file changed, 166 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index efdebe2bbaa..27b2761ef1e 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -398,6 +398,7 @@ struct ggml_backend_opencl_context { int adreno_wave_size; cl_bool non_uniform_workgroups; + size_t image_max_buffer_size; cl_context context; cl_command_queue queue; @@ -407,6 +408,10 @@ struct ggml_backend_opencl_context { ggml_cl_buffer prealloc_scales_trans; ggml_cl_buffer prealloc_act_trans; + // prealloc buffers for src0 and src1 + ggml_cl_buffer prealloc_src0; + ggml_cl_buffer prealloc_src1; + cl_program program_add; cl_program program_add_id; cl_program program_clamp; @@ -2658,6 +2663,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024); + clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL); + GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", backend_ctx->image_max_buffer_size); + clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL); GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size); @@ -4711,6 +4719,81 @@ static bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct gg (ne0 >= 32 && ne1 >= 32 && ne10 >= 32); } +// Copy a noncontiguous tensor to contiguous tensor. ne[] remains the same but +// nb[] is recalculated such that tensor is contiguous. +static void ggml_cl_copy_to_contiguous(ggml_backend_t backend, const ggml_tensor * src, cl_mem dst, + cl_ulong &nb0, cl_ulong &nb1, cl_ulong &nb2, cl_ulong &nb3) { + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + const int tensor_type_size = ggml_type_size(src->type); + + const int ne00 = src->ne[0]; + const int ne01 = src->ne[1]; + const int ne02 = src->ne[2]; + const int ne03 = src->ne[3]; + + const cl_ulong nb00 = src->nb[0]; + const cl_ulong nb01 = src->nb[1]; + const cl_ulong nb02 = src->nb[2]; + const cl_ulong nb03 = src->nb[3]; + + const int ne0 = src->ne[0]; + const int ne1 = src->ne[1]; + const int ne2 = src->ne[2]; + const int ne3 = src->ne[3]; + + nb0 = tensor_type_size; + nb1 = tensor_type_size*ne00; + nb2 = tensor_type_size*ne00*ne01; + nb3 = tensor_type_size*ne00*ne01*ne02; + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra; + + cl_ulong offset0 = extra->offset + src->view_offs; + cl_ulong offsetd = 0; + + cl_kernel kernel; + + switch (src->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_cpy_f32_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_cpy_f16_f16; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &dst)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3)); + + const int nth = MIN(64, ne00); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src); +} + static void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { UNUSED(backend); UNUSED(src0); @@ -7724,9 +7807,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co cl_context context = backend_ctx->context; if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){ - if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) { + if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0 && + // dst is wrapped with image1d_buffer, the size limit applies, also src0 + (ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4 <= backend_ctx->image_max_buffer_size)) { // For KQ if (ggml_is_permuted(src0) && ggml_is_permuted(src1) && + ((nb01 * ne01 / 4)/4 <= backend_ctx->image_max_buffer_size) && nb00 <= nb02 && nb02 <= nb01 && nb01 <= nb03 && @@ -7737,7 +7823,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } // For KQV - if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && + ((nb02 * ne02 / 4)/4 <= backend_ctx->image_max_buffer_size)) { ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst); return; } @@ -8043,9 +8130,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co // GEMM using local memory // Current BK = 16, so ne00 % 16 == 0 - if (ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && - src1t == GGML_TYPE_F32 && + if (src1t == GGML_TYPE_F32 && ne00 % 16 == 0 && ne11 > 1) { switch(src0t) { @@ -8057,10 +8142,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co int batch_stride_b = ne10*ne11; int batch_stride_d = ne0*ne1; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + cl_mem mem_src0 = extra0->data_device; + cl_mem mem_src1 = extra1->data_device; + + cl_ulong nb00_cont = nb00; + cl_ulong nb01_cont = nb01; + cl_ulong nb02_cont = nb02; + cl_ulong nb03_cont = nb03; + + cl_ulong nb10_cont = nb10; + cl_ulong nb11_cont = nb11; + cl_ulong nb12_cont = nb12; + cl_ulong nb13_cont = nb13; + + cl_ulong offset0_cont = offset0; + cl_ulong offset1_cont = offset1; + + if (!ggml_is_contiguous(src0)) { + backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0)); + ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer, + nb00_cont, nb01_cont, nb02_cont, nb03_cont); + mem_src0 = backend_ctx->prealloc_src0.buffer; + offset0_cont = 0; + } + + if (!ggml_is_contiguous(src1)) { + backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1)); + ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer, + nb10_cont, nb11_cont, nb12_cont, nb13_cont); + mem_src1 = backend_ctx->prealloc_src1.buffer; + offset1_cont = 0; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &mem_src0)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_cont)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &mem_src1)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1_cont)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); @@ -8092,10 +8209,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co int batch_stride_b = ne10*ne11; int batch_stride_d = ne0*ne1; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + cl_mem mem_src0 = extra0->data_device; + cl_mem mem_src1 = extra1->data_device; + + cl_ulong nb00_cont = nb00; + cl_ulong nb01_cont = nb01; + cl_ulong nb02_cont = nb02; + cl_ulong nb03_cont = nb03; + + cl_ulong nb10_cont = nb10; + cl_ulong nb11_cont = nb11; + cl_ulong nb12_cont = nb12; + cl_ulong nb13_cont = nb13; + + cl_ulong offset0_cont = offset0; + cl_ulong offset1_cont = offset1; + + if (!ggml_is_contiguous(src0)) { + backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0)); + ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer, + nb00_cont, nb01_cont, nb02_cont, nb03_cont); + mem_src0 = backend_ctx->prealloc_src0.buffer; + offset0_cont = 0; + } + + if (!ggml_is_contiguous(src1)) { + backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1)); + ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer, + nb10_cont, nb11_cont, nb12_cont, nb13_cont); + mem_src1 = backend_ctx->prealloc_src1.buffer; + offset1_cont = 0; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &mem_src0)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_cont)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &mem_src1)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1_cont)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); @@ -8123,6 +8272,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co if (ne11 < 32) { break; } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm; nth0 = 128; // calculated as (BM*BN)/(TM*TN) From f21d0cbb1ac60ecbecba7a7e4294d812c6ab0fa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 22 Jan 2026 20:39:25 +0100 Subject: [PATCH 033/831] CUDA: fix alignment check for FA (llama/19023) --- ggml/src/ggml-cuda/fattn.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 80c3bfbc271..87f07a2f938 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -46,7 +46,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con // are put into the template specialization without GQA optimizations. bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; for (const ggml_tensor * t : {Q, K, V, mask}) { - if (t == nullptr) { + if (t == nullptr || ggml_is_quantized(t->type)) { continue; } for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { @@ -236,7 +236,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded, bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; for (const ggml_tensor * t : {Q, K, V, mask}) { - if (t == nullptr) { + if (t == nullptr || ggml_is_quantized(t->type)) { continue; } for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { From 3f96a1da0e89e538f508fb641422f5de83bb4dc4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 22 Jan 2026 22:09:01 +0200 Subject: [PATCH 034/831] mla : make the V tensor a view of K (llama/18986) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * mla : pass V as a view of K to the FA op * cuda : adjust mla logic to new layout * kv-cache : fix rope shift * tests : remove comment * cuda : fix reusable_cutoff Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fattn-common.cuh | 7 +++++-- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 7 ++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 8468ba8488d..a781fb91f5b 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -778,12 +778,15 @@ void launch_fattn( ) { constexpr int ncols = ncols1 * ncols2; - const bool is_mla = DV == 512; // TODO better parameterization - const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; + // TODO: make this more generic by removing the notion of "MLA". + // for example "is V a view of K?" so we can skip loading it. + // V strides should be driven by V itself and avoid assumption of the data layout + const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K; + GGML_ASSERT(V || is_mla); const ggml_tensor * mask = dst->src[3]; diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 8cca89c2bfa..203569e3459 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -794,7 +794,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // For MLA K and V have the same data. // Therefore, iterate over V in reverse and re-use the data if possible. static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); - constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; + // constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV; + constexpr int reusable_cutoff = DV; // TODO implement properly #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) T_A_VKQ A_identity; make_identity_mat(A_identity); @@ -1552,7 +1553,7 @@ static __global__ void flash_attn_ext_f16( (const half *) (mask + nb33*(sequence % ne33)); float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; @@ -1596,7 +1597,7 @@ static __global__ void flash_attn_ext_f16( (const half *) (mask + nb33*(sequence % ne33)); float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; From e090d91f5edd636f42c0a09f155ca8a341af50d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alberto=20Cabrera=20P=C3=A9rez?= <1478977+Alcpz@users.noreply.github.com> Date: Fri, 23 Jan 2026 07:55:08 +0000 Subject: [PATCH 035/831] ggml-cpu: aarm64: q5_K repack gemm and gemv (and generic) implementations (i8mm) (llama/18860) * Boilerplate for q5_Kx8 REPACK on ARM and fallback Signed-off-by: Alberto Cabrera * Implements make_block_q5_Kx8 by extending make_block_q4_Kx8 Signed-off-by: Alberto Cabrera * q5_K repack gemm and gemv generics * Gemm and Gemv ARM implementations (i8mm) * Improved qh manipulation looking at non-repack vec_dot implementation * Full unroll * Apply Q5_K Gemv vand and vshl optimizations to gemm. Improve comments. Signed-off-by: Alberto Cabrera * Fix wrong fallback definitions of Q5_K Signed-off-by: Alberto Cabrera * Fixed comments. Reverted unnecessary formatting Signed-off-by: Alberto Cabrera * Fixed typo in generic definitions * Switching AND + Shift with Shift Insert. Better op interleaving. * Vectorize + unroll the block scales * Apply gemm optimizations to gemv * Improve bias calculation --------- Signed-off-by: Alberto Cabrera --- ggml/src/ggml-cpu/arch-fallback.h | 38 +- ggml/src/ggml-cpu/arch/arm/repack.cpp | 546 +++++++++++++++++++++++++- ggml/src/ggml-cpu/repack.cpp | 360 ++++++++++++++++- ggml/src/ggml-cpu/repack.h | 25 +- 4 files changed, 931 insertions(+), 38 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 3f8946ac701..0a85a4cff30 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -38,9 +38,10 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -48,9 +49,10 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -70,12 +72,14 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 @@ -94,9 +98,10 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -104,9 +109,10 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -126,9 +132,10 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -136,9 +143,10 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -165,18 +173,20 @@ #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -202,9 +212,10 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -212,9 +223,10 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -242,9 +254,10 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -252,9 +265,10 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index b61220a189a..883d862901b 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -25,9 +25,8 @@ #define UNUSED GGML_UNUSED #if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD)) -static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in, - int16x8_t * out_mins, - int8_t * out_scales) { +// Helper for decoding scales and mins of Q4_K and Q5_K block formats +static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) { constexpr uint32_t kmask1 = 0x3f3f3f3f; constexpr uint32_t kmask2 = 0x0f0f0f0f; constexpr uint32_t kmask3 = 0x03030303; @@ -561,7 +560,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int i = 0; i < 2; i++) { int8_t aux_q4sb[8]; const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); } @@ -701,7 +700,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, for (int i = 0; i < 2; i++) { int8_t aux_q4sb[8]; const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); } @@ -786,6 +785,293 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q5_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[ncols_interleaved / 4]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int i = 0; i < ncols_interleaved / 4; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d); + float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3 + float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7 + float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d); + float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d); + + // 2 sb each iteration + int32x4_t acc_lo[col_pairs]; + int32x4_t acc_hi[col_pairs]; + + // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block + const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8)); + int16_t bsums_arr[8]; + vst1q_s16(bsums_arr, bsums); + + // Load qh once per block and shift after each subblock + const uint8_t * qh_base = q5_ptr[b].qh; + uint8x16_t qh[col_pairs][4]; + for (int cp = 0; cp < col_pairs; cp++) { + qh[cp][0] = vld1q_u8(qh_base + 16 * cp); + qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64); + qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128); + qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192); + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + for (int i = 0; i < col_pairs; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later + int16x8_t q5sb_scales[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q5sb[8]; + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb); + q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb)); + } + + const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K; + + // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns + const int8_t * q8_base = q8_ptr[b].qs + sb * 64; + int8x16_t q8_qs[8]; + for (int i = 0; i < 8; i++) { + q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8)); + } + + // Q5s column pair loop unrolled + { + // Cols 01 + uint8x16_t qs_0 = vld1q_u8(qs_base); + uint8x16_t qs_1 = vld1q_u8(qs_base + 64); + uint8x16_t qs_2 = vld1q_u8(qs_base + 128); + uint8x16_t qs_3 = vld1q_u8(qs_base + 192); + + uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone); + uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone); + uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone); + uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone); + uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3); + uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3); + uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3); + uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3); + + qh[0][0] = vshrq_n_u8(qh[0][0], 2); + qh[0][1] = vshrq_n_u8(qh[0][1], 2); + qh[0][2] = vshrq_n_u8(qh[0][2], 2); + qh[0][3] = vshrq_n_u8(qh[0][3], 2); + + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + + // Cols 23 + qs_0 = vld1q_u8(qs_base + 16); + qs_1 = vld1q_u8(qs_base + 80); + qs_2 = vld1q_u8(qs_base + 144); + qs_3 = vld1q_u8(qs_base + 208); + + hbit_lo_0 = vandq_u8(qh[1][0], mone); + hbit_lo_1 = vandq_u8(qh[1][1], mone); + hbit_lo_2 = vandq_u8(qh[1][2], mone); + hbit_lo_3 = vandq_u8(qh[1][3], mone); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3); + + qh[1][0] = vshrq_n_u8(qh[1][0], 2); + qh[1][1] = vshrq_n_u8(qh[1][1], 2); + qh[1][2] = vshrq_n_u8(qh[1][2], 2); + qh[1][3] = vshrq_n_u8(qh[1][3], 2); + + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + + // Cols 45 + qs_0 = vld1q_u8(qs_base + 32); + qs_1 = vld1q_u8(qs_base + 96); + qs_2 = vld1q_u8(qs_base + 160); + qs_3 = vld1q_u8(qs_base + 224); + + hbit_lo_0 = vandq_u8(qh[2][0], mone); + hbit_lo_1 = vandq_u8(qh[2][1], mone); + hbit_lo_2 = vandq_u8(qh[2][2], mone); + hbit_lo_3 = vandq_u8(qh[2][3], mone); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3); + + qh[2][0] = vshrq_n_u8(qh[2][0], 2); + qh[2][1] = vshrq_n_u8(qh[2][1], 2); + qh[2][2] = vshrq_n_u8(qh[2][2], 2); + qh[2][3] = vshrq_n_u8(qh[2][3], 2); + + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + + // Cols 45 + qs_0 = vld1q_u8(qs_base + 48); + qs_1 = vld1q_u8(qs_base + 112); + qs_2 = vld1q_u8(qs_base + 176); + qs_3 = vld1q_u8(qs_base + 240); + + hbit_lo_0 = vandq_u8(qh[3][0], mone); + hbit_lo_1 = vandq_u8(qh[3][1], mone); + hbit_lo_2 = vandq_u8(qh[3][2], mone); + hbit_lo_3 = vandq_u8(qh[3][3], mone); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3); + + qh[3][0] = vshrq_n_u8(qh[3][0], 2); + qh[3][1] = vshrq_n_u8(qh[3][1], 2); + qh[3][2] = vshrq_n_u8(qh[3][2], 2); + qh[3][3] = vshrq_n_u8(qh[3][3], 2); + + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + } + + // Prepare bsum vectors for bias computation + // Each pair of subblocks share the same bsums + int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); + + // Iterates over a pair of column pairs (4 columns) to use a single 128 register + // p = 0 -> 0123 p2 -> 4567 + for (int i = 0, p = 0; p < col_pairs; i++, p += 2) { + int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]); + int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]); + int16x4_t group_mins_lo = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]); + int16x4_t group_mins_hi = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]); + float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1; + float32x4_t sb_min = p == 0 ? sb_min_0 : sb_min_1; + + // 0123 or 4567 + float32x4_t sumf_0 = + vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1]))); + acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0); + + float32x4_t sumf_1 = + vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1]))); + acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1); + + // FUSED BIAS: Compute and subtract bias immediately + // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min + int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo); + bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi); + float32x4_t bias_f32 = vcvtq_f32_s32(bias); + acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32); + } + } // for sb + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, @@ -2431,7 +2717,7 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int i = 0; i < 2; i++) { int8_t aux_q4sb[8]; const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); } @@ -2595,7 +2881,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later for (int i = 0; i < 2; i++) { const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]); } // q8_ptr[b].qs has interleaved Q8 rows (01, 23) @@ -2738,6 +3024,252 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q5_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + constexpr int q8_k_blocklen = 4; + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + // 8 accumulators: 2 row pairs × 4 col pairs + float32x4_t acc_f32[blocklen]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int i = 0; i < blocklen; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + // bsums pairs belongs to the same q8_k subblock + const int16x8_t bsums[4]{ + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[4][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } + + int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results + int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7] + int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ... + for (int i = 0; i < 8; i++) { + acc[i] = vdupq_n_s32(0); + bias_acc[i] = vdupq_n_s32(0); + } + + // Load qh once per block and shift after each subblock + const uint8_t * qh_base = q5_ptr[b].qh; + uint8x16_t qh[col_pairs][4]; + for (int cp = 0; cp < col_pairs; cp++) { + qh[cp][0] = vld1q_u8(qh_base + 16 * cp); + qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64); + qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128); + qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192); + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int8_t q5sb_scales[2][8]; + int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later + for (int i = 0; i < 2; i++) { + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]); + } + + // q8_ptr[b].qs has interleaved Q8 rows (01, 23) + const int8_t * q8_base = q8_ptr[b].qs + sb * 256; + + int8x16_t q8_qs_01[8]; + int8x16_t q8_qs_23[8]; + + // Load 32-byte per row pair, 1 subblock each time + for (int i = 0; i < 8; i++) { + const int offset = i * 32; // 16 for row 01, 16 for row 23 + q8_qs_01[i] = vld1q_s8(q8_base + offset); + q8_qs_23[i] = vld1q_s8(q8_base + offset + 16); + } + + const int8x16_t q8s[2][8] = { + { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], + q8_qs_01[7] }, + { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], + q8_qs_23[7] }, + }; + + // Q5s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < col_pairs; cp++) { + for (int i = 0; i < 4; i++) { + sb_acc[i] = vdupq_n_s32(0); + } + + uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39 + uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47 + uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55 + uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63 + + // This is the only part of the algorithm that differs with Q4_K + // Extract High bits and pack into 5 bit weights + uint8x16_t hbit_lo_0 = vandq_u8(qh[cp][0], mone); + uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3); + qh[cp][0] = vshrq_n_u8(qh[cp][0], 2); + // Same as Q4_K, i8mm to dequantize the weights. + const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4)); + int32x4_t acc_0 = sb_acc[0]; + acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]); + int32x4_t acc_2 = sb_acc[2]; + acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]); + const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0)); + int32x4_t acc_1 = sb_acc[1]; + acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]); + int32x4_t acc_3 = sb_acc[3]; + acc_3 = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]); + + // Repeat for the other 3 columns (8..15, 16..23, 24..31) + uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3); + uint8x16_t hbit_lo_1 = vandq_u8(qh[cp][1], mone); + qh[cp][1] = vshrq_n_u8(qh[cp][1], 2); + const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]); + acc_2 = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]); + const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]); + acc_3 = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]); + + uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3); + uint8x16_t hbit_lo_2 = vandq_u8(qh[cp][2], mone); + qh[cp][2] = vshrq_n_u8(qh[cp][2], 2); + const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]); + acc_2 = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]); + const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]); + acc_3 = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]); + + uint8x16_t hbit_lo_3 = vandq_u8(qh[cp][3], mone); + uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3); + qh[cp][3] = vshrq_n_u8(qh[cp][3], 2); + const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]); + sb_acc[0] = acc_0; + acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]); + sb_acc[2] = acc_2; + + // Scales[i] corresponds to column i + const int scale_offset = cp * 2; + const int32_t s0 = q5sb_scales[0][scale_offset]; + const int32_t s1 = q5sb_scales[0][scale_offset + 1]; + const int32x4_t block_scale = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1)); + acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale); + + const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]); + sb_acc[1] = acc_1; + acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]); + sb_acc[3] = acc_3; + + const int32_t s2 = q5sb_scales[1][scale_offset]; + const int32_t s3 = q5sb_scales[1][scale_offset + 1]; + const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3)); + acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale2); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2); + } + + // Multiply Acc bsum + mins + for (int q8_row = 0; q8_row < 4; q8_row++) { + // Each pair of subblocks share the same bsums + // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)). + int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]); + + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0])); + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1])); + } + } // for sb + + // Reorder of i8mm output with bias and output layout + for (int i = 0; i < 8; i++) { + int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i])); + acc[i] = vcombine_s32(aux.val[0], aux.val[1]); + } + int32x4_t reorder_acc[8] = { + vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])), + vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])), + vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])), + vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])), + vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])), + vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])), + vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])), + vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])), + }; + + for (int i = 0; i < q8_k_blocklen; i++) { + for (int j = 0; j < 2; j++) { + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]); + float32x4_t q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4))); + const float32x4_t dmins = vmulq_f32(q5_dmin, q8_d); + + float32x4_t q5_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4))); + const float32x4_t scale = vmulq_f32(q5_d, q8_d); + + acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins); + acc_f32[2 * i + j] = + vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale); + } + } + } // for b + + // With the previous reorder, the tile is already in the correct memory layout. + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index fbf7ed9432a..19e021e59aa 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -474,15 +474,8 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, assert (n % qk == 0); assert (nc % ncols_interleaved == 0); - UNUSED(s); UNUSED(bs); - UNUSED(vx); - UNUSED(vy); UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); float sumf[8]; float sum_minf[8]; @@ -616,6 +609,100 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemv_q5_K_8x8_q8_K_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + float sum_minf[8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16; + + const int qh_shift = (k / 4) * 2; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * 8 + i) % 32; + const int qh_chunk = qh_idx / 8; + const int qh_pos = qh_idx % 8; + const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1212,6 +1299,108 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } +void ggml_gemm_q5_K_8x8_q8_K_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + constexpr uint32_t kmask1 = 0x3f3f3f3f; + constexpr uint32_t kmask2 = 0x0f0f0f0f; + constexpr uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][8]; + float sum_minf[4][8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16; + + const int qh_shift = (k / 4) * 2; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * 8 + i) % 32; + const int qh_chunk = qh_idx / 8; + const int qh_pos = qh_idx % 8; + const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int m = 0; m < 4; m++) { + const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -1622,7 +1811,95 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in out.scales[i] = in[src1].scales[src2]; } return out; +} + +static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) { + block_q5_Kx8 out; + //Delta(scale) and dmin values of the eight Q5_K structures are copied onto the output interleaved structure + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < 8; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + const int end = QK_K * 4 / blck_size_interleave; + + // Interleave Q5_K quants by taking 8 bytes at a time + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + + // Repeat for low bits 8 bytes at a time as well, since + // the high bits are interleaved in Q5_K and the index is + // qh_idx = (qs_idx % 32); + // qh_val = qh[qh_idx] >> (qs_idx / 32); + for (int i = 0; i < end / 4; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t)); + memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t)); + } + + // The below logic is copied over from Q4_K + // The point is to unpack all the scales and mins for each sub block every time we load 12 bytes. + // Currently the Q5_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) + // The output Q5_Kx8 structure has 96 bytes + // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q5_K structure + // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q5_K structures + uint8_t s[8], m[8]; + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = in[j].scales[i] & 63; + m[j] = in[j].scales[i + 4] & 63; + } + + out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); + } + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15); + m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4); + } + + out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); + } + + return out; } static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { @@ -1718,6 +1995,38 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; + + block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data; + const block_q5_K * src = (const block_q5_K *) data; + block_q5_K dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q5_Kx8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_0); GGML_ASSERT(interleave_block == 8); @@ -1936,6 +2245,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size); } @@ -1973,6 +2286,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -1981,8 +2298,8 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { @@ -2013,20 +2330,24 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); -} - template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { @@ -2432,6 +2753,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_K_8x4_q8_K; static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; + // instance for Q5_K + static const ggml::cpu::repack::tensor_traits q5_K_8x8_q8_K; + // instance for Q2 static const ggml::cpu::repack::tensor_traits q2_K_8x8_q8_K; @@ -2482,6 +2806,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q2_K_8x8_q8_K; } } + } else if (cur->type == GGML_TYPE_Q5_K) { + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 8 == 0) { + return &q5_K_8x8_q8_K; + } + } } else if (cur->type == GGML_TYPE_IQ4_NL) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index af98e703442..da87103157e 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -44,6 +44,7 @@ struct block_q4_Kx8 { }; static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding"); + struct block_q2_Kx8 { ggml_half d[8]; // super-block scale for quantized scales ggml_half dmin[8]; // super-block scale for quantized mins @@ -52,6 +53,18 @@ struct block_q2_Kx8 { }; static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding"); + +struct block_q5_Kx8 { + ggml_half d[8]; // super-block scale for quantized scales + ggml_half dmin[8]; // super-block scale for quantized mins + uint8_t scales[96]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K * 8 / 8]; // high bits of 5-bit quants + uint8_t qs[QK_K * 8 / 2]; // low bits of 5-bit quants (in groups of 4) +}; + +static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5, + "wrong q5_K block size/padding"); + struct block_q8_Kx4 { float d[4]; // delta int8_t qs[QK_K * 4]; // quants @@ -82,20 +95,22 @@ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTR void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -111,17 +126,19 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); From 0d9dda5a991afb0da39df3e1dce2e27928fac469 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Fri, 23 Jan 2026 20:54:10 +0800 Subject: [PATCH 036/831] use malloc to support both iGPU and dGPU in same time (llama/18992) * use malloc to support both iGPU and dGPU in same time * support windows --------- Co-authored-by: Neo Zhang Jianyu --- ggml/src/ggml-sycl/ggml-sycl.cpp | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index bb8acc922b9..ce2f0d41c96 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1157,13 +1157,28 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_ GGML_UNUSED(buft); } +inline void * aligned_malloc_host(size_t alignment, size_t size) { +#ifdef _WIN32 + return _aligned_malloc(size, alignment); +#else + return aligned_alloc(alignment, size); +#endif +} + +inline void free_aligned_mem_host(void * memblock) { +#ifdef _WIN32 + _aligned_free(memblock); +#else + free(memblock); +#endif +} + static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { - ggml_sycl_host_free(buffer->context); + free_aligned_mem_host((void *)buffer->context); } static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - void * ptr = ggml_sycl_host_malloc(size); - + void * ptr = aligned_malloc_host(TENSOR_ALIGNMENT, size); if (ptr == nullptr) { // fallback to cpu buffer return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); From 79f1bb3d355c198e390f622555cb22225683a2bf Mon Sep 17 00:00:00 2001 From: nullname Date: Sat, 24 Jan 2026 14:02:07 +0800 Subject: [PATCH 037/831] ggml-hexagon: flash-attn opt (llama/19025) * optimize flash attention kernel by improving score computation and online softmax update * wip * Refactor online softmax update in flash attention kernel for improved performance * Optimize flash attention kernel by replacing float array with HVX_Vector for score computation * wip --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 58 +++++++++++++--------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 1de47d0f3d4..c7cb2a4e0bc 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -2,9 +2,9 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" +#include #include #include - #include #include @@ -111,7 +111,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict hvx_vec_store_u(r, 4, rsum); } -// MAD: y (F32) += x (F16) * v (float) +// MAD: y (F32) += x (F16) * s (float) static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; HVX_Vector * restrict ptr_y = (HVX_Vector *) y; @@ -318,9 +318,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in uint32_t ic = 0; // Process in blocks of 32 (VLEN_FP32) - for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) { + static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); + HVX_Vector_x4 scores_x4; + HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); + for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores - float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; + float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE]; for (int j = 0; j < VLEN_FP32; ++j) { const uint32_t cur_ic = ic + j; const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; @@ -356,36 +359,43 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in scores = Q6_Vsf_equals_Vqf32(scores); } + scores_x4.v[iv] = scores; + v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max); + } + + { // 4. Online Softmax Update - HVX_Vector v_max = hvx_vec_reduce_max_f32(scores); + v_max = hvx_vec_reduce_max_f32(v_max); float m_block = hvx_vec_get_f32(v_max); - float M_old = M; float M_new = (m_block > M) ? m_block : M; M = M_new; - float ms = expf(M_old - M_new); - + const float ms = expf(M_old - M_new); hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); - S = S * ms; HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new); - HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); - HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); - - HVX_Vector p_sum_vec = hvx_vec_reduce_sum_f32(P); - float p_sum = hvx_vec_get_f32(p_sum_vec); - S += p_sum; - - // 5. Accumulate V - float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; - *(HVX_Vector*)p_arr = P; - - for (int j = 0; j < VLEN_FP32; ++j) { - const uint32_t cur_ic = ic + j; - const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; - hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); + for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) { + HVX_Vector scores = scores_x4.v[iv]; + HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); + HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); + + p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P)); + + // 5. Accumulate V + float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; + *(HVX_Vector*)p_arr = P; + + for (int j = 0; j < VLEN_FP32; ++j) { + const uint32_t cur_ic = ic2 + j; + const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + } } + + p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec); + S = S * ms + hvx_vec_get_f32(p_sum_vec); } // Leftover From 13577a6ce4496aa3857dc6c878a4029c05ed7e69 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 24 Jan 2026 14:25:20 +0800 Subject: [PATCH 038/831] ggml-cuda: enable cuda-graphs for `n-cpu-moe` (llama/18934) * ggml-cuda: add split-wise cuda graph * add n-cpu-moe compare_llama_bench.py * fix hip/musa builds --- ggml/src/ggml-cuda/common.cuh | 38 ++++++++++++- ggml/src/ggml-cuda/ggml-cuda.cu | 95 ++++++++++++++++++++------------- ggml/src/ggml-cuda/mean.cu | 17 +++--- 3 files changed, 102 insertions(+), 48 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 179522d8355..09a491a836a 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1327,10 +1327,44 @@ struct ggml_backend_cuda_context { cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } }; cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; - std::unique_ptr cuda_graph; - int curr_stream_no = 0; +#ifdef USE_CUDA_GRAPH + // Map from first_node_ptr to cuda_graph - allows multiple graphs per context + // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe) + std::unordered_map> cuda_graphs; + + ggml_cuda_graph * cuda_graph(const void * first_node_ptr) { + auto it = cuda_graphs.find(first_node_ptr); + if (it == cuda_graphs.end()) { + cuda_graphs[first_node_ptr] = std::make_unique(); + return cuda_graphs[first_node_ptr].get(); + } + return it->second.get(); + } + + // Check if any CUDA graph is enabled for this context (used by kernels that need to know + // if graphs are in use without having access to the specific graph key) + bool any_cuda_graph_enabled() const { + for (const auto & [key, graph] : cuda_graphs) { + if (graph && graph->is_enabled()) { + return true; + } + } + return false; + } + + // Check if any CUDA graph has an instance for this context + bool any_cuda_graph_has_instance() const { + for (const auto & [key, graph] : cuda_graphs) { + if (graph && graph->instance != nullptr) { + return true; + } + } + return false; + } +#endif // USE_CUDA_GRAPH + explicit ggml_backend_cuda_context(int device) : device(device), name(GGML_CUDA_NAME + std::to_string(device)) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index cda422defbe..99f0919a514 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2969,18 +2969,25 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_ return true; } +static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) { + return cgraph->nodes[0]; +} + static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { bool res = false; - if (cuda_ctx->cuda_graph->instance == nullptr) { + const void * graph_key = ggml_cuda_graph_get_key(cgraph); + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + + if (graph->instance == nullptr) { res = true; } // Check if the graph size has changed - if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { + if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { res = true; - cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs); + graph->props.resize(cgraph->n_nodes + cgraph->n_leafs); } // Loop over nodes in GGML graph to determine if CUDA graph update is required @@ -2988,37 +2995,38 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx for (int i = 0; i < cgraph->n_nodes; i++) { bool props_match = true; if (!res) { - props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]); + props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]); } if (!props_match) { res = true; } - ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]); + ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]); } for (int i = 0; i < cgraph->n_leafs; i++) { - bool props_match= true; + bool props_match = true; if (!res) { - props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]); + props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]); } if (!props_match) { res = true; } - ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]); + ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]); } return res; } -static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) { +static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) { + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); #if CUDART_VERSION >= 12000 cudaGraphExecUpdateResultInfo result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); + cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info); #else cudaGraphNode_t errorNode; cudaGraphExecUpdateResult result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); + cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info); #endif // CUDART_VERSION >= 12000 if (stat == cudaErrorGraphExecUpdateFailure) { @@ -3029,14 +3037,14 @@ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_c // The pre-existing graph exec cannot be updated due to violated constraints // so instead clear error and re-instantiate (void)cudaGetLastError(); - CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); - cuda_ctx->cuda_graph->instance = nullptr; - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + CUDA_CHECK(cudaGraphExecDestroy(graph->instance)); + graph->instance = nullptr; + CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0)); } else { GGML_ASSERT(stat == cudaSuccess); } } -#endif +#endif // USE_CUDA_GRAPH static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope, const ggml_tensor * view, @@ -3241,7 +3249,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } -static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) { +static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) { bool graph_evaluated_or_captured = false; // flag used to determine whether it is an integrated_gpu @@ -3695,13 +3703,14 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud } #ifdef USE_CUDA_GRAPH + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture - if (cuda_ctx->cuda_graph->graph != nullptr) { - CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph)); - cuda_ctx->cuda_graph->graph = nullptr; + if (graph->graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(graph->graph)); + graph->graph = nullptr; } - CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph)); graph_evaluated_or_captured = true; // CUDA graph has been captured std::lock_guard lock(ggml_cuda_lock); @@ -3714,40 +3723,39 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud } if (use_cuda_graph) { - if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + if (graph->instance == nullptr) { // Create executable graph from captured graph. + CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0)); } if (cuda_graph_update_required) { // Update graph executable - ggml_cuda_graph_update_executable(cuda_ctx); + ggml_cuda_graph_update_executable(cuda_ctx, graph_key); } // Launch graph - CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); + CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream())); #else graph_evaluated_or_captured = true; #endif // USE_CUDA_GRAPH } } -static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) { +static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) { #ifdef USE_CUDA_GRAPH + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); - if (cuda_ctx->cuda_graph == nullptr) { - cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); - } - - if (cuda_ctx->cuda_graph->graph == nullptr) { + if (graph->graph == nullptr) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { - if (!cuda_ctx->cuda_graph->disable_due_to_gpu_arch) { + if (!graph->disable_due_to_gpu_arch) { GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); } - cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; + graph->disable_due_to_gpu_arch = true; } } - return cuda_ctx->cuda_graph->is_enabled(); + return graph->is_enabled(); #else GGML_UNUSED(cuda_ctx); + GGML_UNUSED(graph_key); return false; #endif // USE_CUDA_GRAPH } @@ -3759,15 +3767,19 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, bool use_cuda_graph = false; bool cuda_graph_update_required = false; + const void * graph_key = nullptr; #ifdef USE_CUDA_GRAPH - use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); + graph_key = ggml_cuda_graph_get_key(cgraph); + + use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); - if (cuda_ctx->cuda_graph->is_enabled()) { + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + if (graph->is_enabled()) { cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph); use_cuda_graph = ggml_cuda_graph_check_compability(cgraph); - cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required); + graph->record_update(use_cuda_graph, cuda_graph_update_required); } #endif // USE_CUDA_GRAPH @@ -3781,7 +3793,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); } - ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required); + ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key); return GGML_STATUS_SUCCESS; } @@ -3814,7 +3826,14 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; - const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); +#ifdef USE_CUDA_GRAPH + const void * graph_key = ggml_cuda_graph_get_key(cgraph); + const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); +#else + const bool use_cuda_graph = false; + GGML_UNUSED(cuda_ctx); + GGML_UNUSED(cgraph); +#endif static bool enable_graph_optimization = [] { const char * env = getenv("GGML_CUDA_GRAPH_OPT"); diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu index 60542fc19dd..49af5389957 100644 --- a/ggml/src/ggml-cuda/mean.cu +++ b/ggml/src/ggml-cuda/mean.cu @@ -31,14 +31,15 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { #endif // USE_CUDA_GRAPH if ((nrows == 1) && #ifdef USE_CUDA_GRAPH - // CUDA_GRAPHS_DISABLED - ((ncols > 65536) && - ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->is_enabled())) || - // CUDA_GRAPHS ENABLED - ((ncols > 32768) && - !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->is_enabled()))) { + // Determine if CUDA graphs are effectively disabled for this context + // (no graph instance exists and we're not capturing, OR graphs are explicitly enabled) + (((ncols > 65536) && + (((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) || + ctx.any_cuda_graph_enabled())) || + // CUDA graphs are enabled - use lower threshold + ((ncols > 32768) && + !(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) || + ctx.any_cuda_graph_enabled())))) { #else (ncols > 65536)) { #endif // USE_CUDA_GRAPH From f53eafd74557792e68719a75cd2cd1b205862f88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 24 Jan 2026 10:09:36 +0100 Subject: [PATCH 039/831] CUDA: re-use MLA K data for V in MMA FA (llama/19057) --- ggml/src/ggml-cuda/fattn-common.cuh | 76 ++++++++++++++-------------- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 63 +++++++++++------------ ggml/src/ggml-cuda/fattn.cu | 5 ++ 3 files changed, 73 insertions(+), 71 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index a781fb91f5b..40c7725784c 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -782,12 +782,7 @@ void launch_fattn( const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; - // TODO: make this more generic by removing the notion of "MLA". - // for example "is V a view of K?" so we can skip loading it. - // V strides should be driven by V itself and avoid assumption of the data layout - const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K; - - GGML_ASSERT(V || is_mla); + const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data; const ggml_tensor * mask = dst->src[3]; const ggml_tensor * sinks = dst->src[4]; @@ -797,9 +792,9 @@ void launch_fattn( GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(KQV->type == GGML_TYPE_F32); - GGML_ASSERT( Q->nb[0] == ggml_element_size(Q)); - GGML_ASSERT( K->nb[0] == ggml_element_size(K)); - GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); + GGML_ASSERT(Q->nb[0] == ggml_element_size(Q)); + GGML_ASSERT(K->nb[0] == ggml_element_size(K)); + GGML_ASSERT(V->nb[0] == ggml_element_size(V)); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); @@ -820,10 +815,10 @@ void launch_fattn( size_t nb12 = K->nb[2]; size_t nb13 = K->nb[3]; - const char * V_data = V ? (const char *) V->data : nullptr; - size_t nb21 = V ? V->nb[1] : nb11; - size_t nb22 = V ? V->nb[2] : nb12; - size_t nb23 = V ? V->nb[3] : nb13; + const char * V_data = (const char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; if (need_f16_K && K->type != GGML_TYPE_F16) { const size_t bs = ggml_blck_size(K->type); @@ -852,32 +847,39 @@ void launch_fattn( K_data = (char *) K_f16.ptr; } - if (V && need_f16_V && V->type != GGML_TYPE_F16) { - const size_t bs = ggml_blck_size(V->type); - const size_t ts = ggml_type_size(V->type); - - V_f16.alloc(ggml_nelements(V)); - if (ggml_is_contiguously_allocated(V)) { - to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); - to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); - V_data = (char *) V_f16.ptr; - - nb21 = nb21*bs*sizeof(half)/ts; - nb22 = nb22*bs*sizeof(half)/ts; - nb23 = nb23*bs*sizeof(half)/ts; + if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V_is_K_view) { + V_data = K_data; + nb21 = nb11; + nb22 = nb12; + nb23 = nb13; } else { - GGML_ASSERT(V->nb[0] == ts); - to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type); - const int64_t s01 = nb21 / ts; - const int64_t s02 = nb22 / ts; - const int64_t s03 = nb23 / ts; - to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); - - nb21 = V->ne[0] * sizeof(half); - nb22 = V->ne[1] * nb21; - nb23 = V->ne[2] * nb22; + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + V_f16.alloc(ggml_nelements(V)); + if (ggml_is_contiguously_allocated(V)) { + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } else { + GGML_ASSERT(V->nb[0] == ts); + to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type); + const int64_t s01 = nb21 / ts; + const int64_t s02 = nb22 / ts; + const int64_t s03 = nb23 / ts; + to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + + nb21 = V->ne[0] * sizeof(half); + nb22 = V->ne[1] * nb21; + nb23 = V->ne[2] * nb22; + } + V_data = (char *) V_f16.ptr; } - V_data = (char *) V_f16.ptr; } const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 203569e3459..3e7d67b40dc 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -400,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } template static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, @@ -442,8 +442,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; - static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); - constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; + constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4; const int k_VKQ_0 = kb0 * nbatch_fa; #if defined(TURING_MMA_AVAILABLE) @@ -456,7 +455,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if constexpr (nstages > 1) { static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline"); - static_assert(!mla, "multi-stage loading not implemented for MLA"); + static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading"); static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); constexpr bool use_cp_async = true; cp_async_wait_all(); @@ -471,8 +470,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } + // For MLA K and V have the same data. + // Therefore, iterate over K in reverse and later re-use the data if possible. #pragma unroll - for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) { + for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) { const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; const int k0_diff = k0_stop - k0_start; @@ -776,6 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } if constexpr (nstages > 1) { + static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading"); // Preload K tile for next iteration: constexpr bool use_cp_async = true; cp_async_wait_all(); @@ -791,11 +793,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } - // For MLA K and V have the same data. - // Therefore, iterate over V in reverse and re-use the data if possible. - static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); - // constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV; - constexpr int reusable_cutoff = DV; // TODO implement properly #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) T_A_VKQ A_identity; make_identity_mat(A_identity); @@ -803,12 +800,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll - for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { - const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; - const int i0_diff = i0_stop - i0_start; + for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) { + static_assert(DV % (2*nbatch_V2) == 0, "bad loop size"); + const int i0_stop = i0_start + 2*nbatch_V2; + const int i0_diff = i0_stop - i0_start; if constexpr (nstages <= 1) { - if (i0_start < reusable_cutoff) { + if (!V_is_K_view || i0_stop > 2*nbatch_K2) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup); @@ -818,7 +816,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( __syncthreads(); } } - const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; + const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2; #if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; @@ -921,7 +919,7 @@ template struct mma_tile_sizes { }; #endif // defined(TURING_MMA_AVAILABLE) -template +template static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -975,8 +973,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; - static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); - constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; + constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4; constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; @@ -1080,7 +1077,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = false; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1089,7 +1086,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = true; const int k_VKQ_sup = ne11 - kb0*nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1100,7 +1097,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = false; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1109,7 +1106,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = true; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1457,7 +1454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) } -template +template __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_ext_f16( const char * __restrict__ Q, @@ -1509,8 +1506,6 @@ static __global__ void flash_attn_ext_f16( } #endif // defined(AMD_WMMA_AVAILABLE) - static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); - constexpr int ncols = ncols1 * ncols2; constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols); @@ -1523,7 +1518,7 @@ static __global__ void flash_attn_ext_f16( const int stride_K = nb11 / sizeof(half2); const int stride_mask = nb31 / sizeof(half); - const int stride_V = mla ? stride_K : nb21 / sizeof(half2); + const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2); const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; @@ -1553,7 +1548,7 @@ static __global__ void flash_attn_ext_f16( (const half *) (mask + nb33*(sequence % ne33)); float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; @@ -1564,12 +1559,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); } else { constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile. - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); } @@ -1597,7 +1592,7 @@ static __global__ void flash_attn_ext_f16( (const half *) (mask + nb33*(sequence % ne33)); float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; @@ -1608,7 +1603,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile + flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); #else @@ -1644,7 +1639,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc)); const int nwarps = nthreads / WARP_SIZE; - constexpr bool mla = DKQ == 576; + constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); @@ -1669,7 +1664,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; @@ -1680,7 +1675,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml #endif // !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16; + fattn_kernel = flash_attn_ext_f16; #if !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 87f07a2f938..ba2b96bc327 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -247,6 +247,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } } + const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data; + const int cc = ggml_cuda_info().devices[device].cc; switch (K->ne[0]) { @@ -269,6 +271,9 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (!gqa_opt_applies || gqa_ratio % 4 != 0) { return BEST_FATTN_KERNEL_NONE; } + if (!V_is_K_view) { + return BEST_FATTN_KERNEL_NONE; + } break; default: return BEST_FATTN_KERNEL_NONE; From d2b51404e482a0dc42de1d17e19d164e51f2dedf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 25 Jan 2026 15:48:56 +0200 Subject: [PATCH 040/831] kv-cache : support V-less cache (llama/19067) * kv-cache : support V-less cache * cuda : better check for V_is_K_view * cuda : improve V_is_K_view check * graph : add comments * hparams : refactor --- ggml/src/ggml-cuda/fattn-common.cuh | 2 +- ggml/src/ggml-cuda/fattn.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 40c7725784c..13c5b0a4594 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -782,7 +782,7 @@ void launch_fattn( const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; - const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data; + const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src); const ggml_tensor * mask = dst->src[3]; const ggml_tensor * sinks = dst->src[4]; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index ba2b96bc327..a5e66241817 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -247,7 +247,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } } - const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data; + const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src); const int cc = ggml_cuda_info().devices[device].cc; From 1642a4fb605179844ade8e0782bda04272bd2897 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 25 Jan 2026 23:25:58 +0800 Subject: [PATCH 041/831] ggml-cpu: Use tiled FA for prompt-processing (llama/19012) * ggml-cpu: Use tiled FA for prompt-processing the FA performance is gimped on CPU on long contexts because it essentially uses a vector kernel. This PR adds a tiled FA for PP. Perf tuning for tile sizes done on a AMD EPYC single-socket 64-c machine. * fix out of bounds for mask * skip rows where there are all masks * skip tile if mask is inf * store mask in worksize * check inf tile earlier --- ggml/src/ggml-cpu/common.h | 8 + ggml/src/ggml-cpu/ggml-cpu.c | 9 +- ggml/src/ggml-cpu/ops.cpp | 290 ++++++++++++++++++++++++++++++++++- 3 files changed, 303 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cpu/common.h b/ggml/src/ggml-cpu/common.h index 6adca5437f8..1057b5bb152 100644 --- a/ggml/src/ggml-cpu/common.h +++ b/ggml/src/ggml-cpu/common.h @@ -6,6 +6,9 @@ #include "ggml-impl.h" #include "simd-mappings.h" +#define GGML_FA_TILE_Q 32 +#define GGML_FA_TILE_KV 16 + #ifdef __cplusplus #include @@ -84,4 +87,9 @@ static std::pair get_thread_range(const struct ggml_compute_pa return {ir0, ir1}; } +struct ggml_fa_tile_config { + static constexpr size_t Q = GGML_FA_TILE_Q; + static constexpr size_t KV = GGML_FA_TILE_KV; +}; + #endif diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 4c7a75e768a..b1de2ae8716 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -14,6 +14,7 @@ #include "vec.h" #include "ops.h" #include "ggml.h" +#include "common.h" #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -2866,10 +2867,12 @@ struct ggml_cplan ggml_graph_plan( } break; case GGML_OP_FLASH_ATTN_EXT: { - const int64_t ne10 = node->src[1]->ne[0]; // DK - const int64_t ne20 = node->src[2]->ne[0]; // DV + const int64_t DK = node->src[1]->ne[0]; + const int64_t DV = node->src[2]->ne[0]; - cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread) + // Tiled flash attention scratch (tile sizes defined in common.h) + // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding + cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks; } break; case GGML_OP_FLASH_ATTN_BACK: { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 387e2fe42c3..48c89643619 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8164,6 +8164,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( // online softmax / attention // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f; if (mv == -INFINITY) { @@ -8271,6 +8272,280 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( } } +static void ggml_compute_forward_flash_attn_ext_tiled( + const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, int ir1) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(ne0 == DV); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == DK); + GGML_ASSERT(nek0 == DK); + GGML_ASSERT(nev0 == DV); + + GGML_ASSERT(neq1 == N); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(k->type == v->type); + const ggml_type kv_type = k->type; + + const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type); + const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float; + const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot; + const size_t kv_type_size = ggml_type_size(kv_type); + + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + int ith = params->ith; + + static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; + static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; + + GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ"); + + int ir = ir0; + while (ir < ir1) { + // q indices for the start of this tile + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + // Number of valid rows in this tile: + // - limited by tile size (Q_TILE_SZ) + // - limited by chunk boundary (ir1 - ir) + // - limited by head boundary (neq1 - iq1) to avoid crossing into next head + const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1))); + GGML_ASSERT(tile_rows > 0); + + const uint32_t h = iq2; // head index + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + + float S[Q_TILE_SZ]; + float M[Q_TILE_SZ]; + + for (int i = 0 ; i < Q_TILE_SZ; ++i) { + S[i] = 0.; + M[i] = -INFINITY; + } + + // Per-thread scratch layout: + // Q_q: Q_TILE_SZ * DK (converted Q tile in KV type) + // KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float) + // mask: Q_TILE_SZ * KV_TILE_SZ (mask in float) + // VKQ32: Q_TILE_SZ * DV (FP32 output accumulator) + // V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion) + float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32); + + void * Q_q = base; + float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float)); + float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ; + float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ; + float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile + + memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float)); + memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float)); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + for (int tq = 0; tq < tile_rows; tq++) { + const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3)); + kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK); + } + // Zero-pad remaining rows + for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { + memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size); + } + + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { + + // skip the tile entirely if all the masks are -inf + if (mask) { + bool can_skip = true; + for (int tq = 0; tq < tile_rows; tq++) { + const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]); + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]); + if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) { + can_skip = false; + } + } + } + + if (can_skip) { + continue; + } + } + + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + const void * q_row = (const char *)Q_q + tq * DK * kv_type_size; + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3); + float s; + kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1); + KQ[tq * KV_TILE_SZ + tk] = s * scale; + } + } + + if (logit_softcap != 0.0f) { + ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ); + ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap); + } + + if (mask) { + ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32); + } + + bool skip[Q_TILE_SZ] = {}; + + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + float * kq_row = KQ + tq * KV_TILE_SZ; + + float tile_max; + ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row); + + if (tile_max == -INFINITY) { + skip[tq] = true; + continue; + } + + const float Mold = M[tq]; + const float Mnew = fmaxf(Mold, tile_max); + + if (Mnew > Mold) { + const float ms = expf(Mold - Mnew); + ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms); + S[tq] *= ms; + } + M[tq] = Mnew; + + + S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew); + } + + // Convert V tile to F32 first (if F16), then do MAD + // On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster. + // TODO: on ARM, native f16 should be faster + if (kv_type == GGML_TYPE_F16) { + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3)); + ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV); + } + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + if (skip[tq]) continue; + float * vkq_row = VKQ32 + tq * DV; + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + const float p = KQ[tq * KV_TILE_SZ + tk]; + ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p); + } + } + } else { + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + if (skip[tq]) continue; + float * vkq_row = VKQ32 + tq * DV; + for (int tk = 0; tk < KV_TILE_SZ; tk++) { + const float p = KQ[tq * KV_TILE_SZ + tk]; + const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3)); + ggml_vec_mad_f32(DV, vkq_row, v_row, p); + } + } + } + } + + // sinks (apply only to valid rows in the tile) + if (sinks) { + const float s = ((float *)((char *) sinks->data))[h]; + + for (int tq = 0; tq < tile_rows; tq++) { + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[tq]) { + ms = expf(M[tq] - s); + ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms); + } else { + vs = expf(s - M[tq]); + } + + S[tq] = S[tq] * ms + vs; + } + } + + for (int tq = 0; tq < tile_rows; tq++) { + // V /= S + const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq]; + ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv); + + // dst indices + const int i1 = iq1 + tq; + const int i2 = iq2; + const int i3 = iq3; + + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1); + } + + ir += tile_rows; + } +} + static void ggml_compute_forward_flash_attn_ext_f16( const ggml_compute_params * params, ggml_tensor * dst) { @@ -8343,6 +8618,15 @@ static void ggml_compute_forward_flash_attn_ext_f16( // The number of elements in each chunk const int64_t dr = (nr + nchunk - 1) / nchunk; + static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV; + static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; + const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16); + const bool use_tiled = (q->type == GGML_TYPE_F32 && + kv_is_f32_or_f16 && + k->type == v->type && + nek1 % KV_TILE_SZ == 0 && + neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size + // The first chunk comes from our thread_id, the rest will get auto-assigned. int current_chunk = ith; @@ -8350,7 +8634,11 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t ir0 = dr * current_chunk; const int64_t ir1 = MIN(ir0 + dr, nr); - ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1); + if (use_tiled) { + ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1); + } else { + ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1); + } current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); } From 4372b87b8e7b941fdc0d0176963e166747169454 Mon Sep 17 00:00:00 2001 From: ccbinn Date: Mon, 26 Jan 2026 02:07:19 +0800 Subject: [PATCH 042/831] metal : fix recommendedMaxWorkingSetSize availability on legacy iOS/macOS (llama/19088) Co-authored-by: chenbin11 --- ggml/src/ggml-metal/ggml-metal-device.m | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index eb4e2c209ce..7f9c384c344 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -785,8 +785,12 @@ ggml_metal_device_t ggml_metal_device_init(void) { dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; dev->props.max_buffer_size = dev->mtl_device.maxBufferLength; - dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize; dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength; + if (@available(macOS 10.12, iOS 16.0, *)) { + dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize; + } else { + dev->props.max_working_set_size = dev->mtl_device.maxBufferLength; + } strncpy(dev->props.name, [[dev->mtl_device name] UTF8String], sizeof(dev->props.name) - 1); From f63848eada9a8a1c1a0ab52c389a15e189e33c58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 25 Jan 2026 21:19:47 +0100 Subject: [PATCH 043/831] CUDA: faster FA for GQA > 1 but not power of 2 (llama/19092) --- ggml/src/ggml-cuda/fattn-common.cuh | 22 +++--- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 30 +++++---- ggml/src/ggml-cuda/fattn.cu | 67 ++++++++++++++++--- ...ttn-mma-f16-instance-ncols1_1-ncols2_32.cu | 5 ++ ...ttn-mma-f16-instance-ncols1_2-ncols2_32.cu | 5 ++ .../template-instances/generate_cu_files.py | 6 +- 6 files changed, 99 insertions(+), 36 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 13c5b0a4594..1f5f1b9206c 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -643,9 +643,10 @@ static __global__ void flash_attn_stream_k_fixup( const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_z = (ne02 + (ncols2 - 1)) / ncols2; - const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; - const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x; + const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; @@ -654,15 +655,15 @@ static __global__ void flash_attn_stream_k_fixup( return; } - const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2)); - const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); - const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. + const int sequence = kbc0 / (iter_k*iter_j*iter_z); + const int zt = (kbc0 - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); + const int jt = (kbc0 - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. - if (jt*ncols1 + j >= ne01) { + if (jt*ncols1 + j >= ne01 || zt*ncols2 + c >= ne02) { return; } - dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid; + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt*(ncols2*D) + (j*ne02 + c)*D + tid; // Load the partial result that needs a fixup: float dst_val = 0.0f; @@ -681,7 +682,7 @@ static __global__ void flash_attn_stream_k_fixup( int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne03) / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. bidx--; kbc_stop = kbc; @@ -883,7 +884,8 @@ void launch_fattn( } const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); - const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + const int ntiles_z = ((Q->ne[2] + ncols2 - 1) / ncols2); + const int ntiles_total = ntiles_x * ntiles_z * Q->ne[3]; // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or @@ -958,7 +960,7 @@ void launch_fattn( blocks_num.x = ntiles_x; blocks_num.y = parallel_blocks; - blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3]; + blocks_num.z = ntiles_z*Q->ne[3]; if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 3e7d67b40dc..9004d46904e 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -940,6 +940,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int stride_V, const int stride_mask, const int jt, + const int zt, const int kb0_start, const int kb0_stop) { #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) @@ -1022,7 +1023,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (jt*ncols1 + j < int(ne01.z)) { + if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt*ncols2 + c < ne02)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); @@ -1408,7 +1409,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j_dst = jc_dst / ncols2; const int c_dst = jc_dst % ncols2; - if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) { + if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) { continue; } @@ -1522,10 +1523,11 @@ static __global__ void flash_attn_ext_f16( const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; + const int iter_z = (ne02 + (ncols2 - 1)) / ncols2; // kbc == k block continuous, current index in continuous ijk space. - int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; - const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x; + const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x; // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). @@ -1536,9 +1538,9 @@ static __global__ void flash_attn_ext_f16( int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); while (kbc < kbc_stop && kb0_stop == iter_k) { - const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); - const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2 - const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. + const int sequence = kbc / (iter_k*iter_j*iter_z); + const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2 + const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. const int head0 = zt * ncols2; @@ -1561,12 +1563,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop); } else { constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile. flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop); } kbc += iter_k; @@ -1580,9 +1582,9 @@ static __global__ void flash_attn_ext_f16( return; } - const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); - const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2 - const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. + const int sequence = kbc / (iter_k*iter_j*iter_z); + const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2 + const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. const int head0 = zt * ncols2; @@ -1605,7 +1607,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); + ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop); #else GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, @@ -1739,3 +1741,5 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index a5e66241817..2f5dbd13a39 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -18,9 +18,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con } } - if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); - return; + if constexpr (ncols2 <= 16) { + if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); + return; + } } if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) { @@ -33,6 +35,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con template static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -60,17 +63,38 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; - if (use_gqa_opt && gqa_ratio % 8 == 0) { + // On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute: + if (cc == GGML_CUDA_CC_VOLTA) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio > 4) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio % 4 == 0) { + if (use_gqa_opt && gqa_ratio > 2) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio % 2 == 0) { + if (use_gqa_opt && gqa_ratio > 1) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; } @@ -79,6 +103,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con } static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -121,8 +146,30 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; - GGML_ASSERT(gqa_ratio % 4 == 0); - if (gqa_ratio % 16 == 0) { + if (gqa_ratio == 20) { // GLM 4.7 Flash + if (cc >= GGML_CUDA_CC_BLACKWELL) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + if (Q->ne[1] <= 4) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + if (cc >= GGML_CUDA_CC_TURING) { + if (Q->ne[1] <= 4) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + // Volta: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + } else if (gqa_ratio % 16 == 0) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); } else { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); @@ -234,7 +281,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // The effective batch size for the kernel can be increased by gqa_ratio. // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded, - bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; for (const ggml_tensor * t : {Q, K, V, mask}) { if (t == nullptr || ggml_is_quantized(t->type)) { continue; @@ -268,7 +315,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; } - if (!gqa_opt_applies || gqa_ratio % 4 != 0) { + if (!gqa_opt_applies) { return BEST_FATTN_KERNEL_NONE; } if (!V_is_K_view) { diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu new file mode 100644 index 00000000000..1f554d81e5e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu new file mode 100644 index 00000000000..264751d65ec --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 10be71ab576..e382df1ae20 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -71,7 +71,7 @@ def get_short_name(long_quant_name): f.write(SOURCE_FATTN_VEC.format(type_k=type_k, type_v=type_v)) for ncols in [8, 16, 32, 64]: - for ncols2 in [1, 2, 4, 8, 16]: + for ncols2 in [1, 2, 4, 8, 16, 32]: if ncols2 > ncols: continue ncols1 = ncols // ncols2 @@ -83,9 +83,9 @@ def get_short_name(long_quant_name): continue if head_size_kq == 72: continue - if head_size_kq != 576 and ncols2 == 16: + if head_size_kq != 576 and ncols2 in (16, 32): continue - if head_size_kq == 576 and ncols2 not in (4, 16): + if head_size_kq == 576 and ncols2 not in (4, 16, 32): continue head_size_v = head_size_kq if head_size_kq != 576 else 512 f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) From 41d5d7bb0efda27c65ad302c71a0f185cb93fbf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 26 Jan 2026 23:24:58 +0100 Subject: [PATCH 044/831] CUDA: fix padding of GQA to power of 2 in FA (llama/19115) --- ggml/src/ggml-cuda/fattn-common.cuh | 43 +++++++++-------- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 69 +++++++++++++++------------- 2 files changed, 62 insertions(+), 50 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 1f5f1b9206c..3d7daccfdf8 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -629,8 +629,8 @@ static __global__ void flash_attn_mask_to_KV_max( template // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11, - const int nbatch_fa) { + float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, + const int ne11, const int ne12, const int nbatch_fa) { constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -641,12 +641,14 @@ static __global__ void flash_attn_stream_k_fixup( const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - const int iter_z = (ne02 + (ncols2 - 1)) / ncols2; + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x; - const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x; + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; + + const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; @@ -655,15 +657,19 @@ static __global__ void flash_attn_stream_k_fixup( return; } - const int sequence = kbc0 / (iter_k*iter_j*iter_z); - const int zt = (kbc0 - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); - const int jt = (kbc0 - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; + + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. - if (jt*ncols1 + j >= ne01 || zt*ncols2 + c >= ne02) { + if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) { return; } - dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt*(ncols2*D) + (j*ne02 + c)*D + tid; + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; // Load the partial result that needs a fixup: float dst_val = 0.0f; @@ -682,7 +688,7 @@ static __global__ void flash_attn_stream_k_fixup( int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne03) / gridDim.x; + const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. bidx--; kbc_stop = kbc; @@ -883,9 +889,10 @@ void launch_fattn( } } - const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); - const int ntiles_z = ((Q->ne[2] + ncols2 - 1) / ncols2); - const int ntiles_total = ntiles_x * ntiles_z * Q->ne[3]; + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2); + const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3]; // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or @@ -960,7 +967,7 @@ void launch_fattn( blocks_num.x = ntiles_x; blocks_num.y = parallel_blocks; - blocks_num.z = ntiles_z*Q->ne[3]; + blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3]; if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); @@ -1014,7 +1021,7 @@ void launch_fattn( flash_attn_stream_k_fixup <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa); + ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 9004d46904e..0b8ef90794c 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -933,6 +933,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float logit_softcap, const uint3 ne01, const int ne02, + const int gqa_ratio, const int ne11, const int stride_Q1, const int stride_Q2, @@ -940,7 +941,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int stride_V, const int stride_mask, const int jt, - const int zt, + const int zt_gqa, const int kb0_start, const int kb0_stop) { #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) @@ -1023,7 +1024,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt*ncols2 + c < ne02)) { + if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); @@ -1409,7 +1410,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j_dst = jc_dst / ncols2; const int c_dst = jc_dst % ncols2; - if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) { + if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) { continue; } @@ -1448,7 +1449,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #else GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup, - scale, slope, logit_softcap, ne01, ne02, + scale, slope, logit_softcap, ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; @@ -1521,13 +1522,13 @@ static __global__ void flash_attn_ext_f16( const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2); - const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; - const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; - const int iter_z = (ne02 + (ncols2 - 1)) / ncols2; + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; + const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; // kbc == k block continuous, current index in continuous ijk space. - int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x; - const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x; + int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). @@ -1538,22 +1539,24 @@ static __global__ void flash_attn_ext_f16( int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); while (kbc < kbc_stop && kb0_stop == iter_k) { - const int sequence = kbc / (iter_k*iter_j*iter_z); - const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2 - const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; - const int head0 = zt * ncols2; + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV); const half * mask_h = ncols2 == 1 && !mask ? nullptr : (const half *) (mask + nb33*(sequence % ne33)); - float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2); - const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); - const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; + const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV); + const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr; - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f; if (KV_max) { kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); @@ -1563,12 +1566,12 @@ static __global__ void flash_attn_ext_f16( constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop); + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); } else { constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile. flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop); + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); } kbc += iter_k; @@ -1582,22 +1585,24 @@ static __global__ void flash_attn_ext_f16( return; } - const int sequence = kbc / (iter_k*iter_j*iter_z); - const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2 - const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index. + const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; - const int head0 = zt * ncols2; + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV); const half * mask_h = ncols2 == 1 && !mask ? nullptr : (const half *) (mask + nb33*(sequence % ne33)); - float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2); - const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); - const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; + const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV); + const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr; - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f; if (KV_max) { kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); @@ -1607,7 +1612,7 @@ static __global__ void flash_attn_ext_f16( constexpr bool needs_fixup = false; flash_attn_ext_f16_process_tile (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop); + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); #else GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, From 56f82a9f33f864ead89261dafd048596a774cea0 Mon Sep 17 00:00:00 2001 From: lhez Date: Fri, 30 Jan 2026 10:34:38 +0200 Subject: [PATCH 045/831] opencl: add flattened q6_K mv (llama/19054) --- ggml/src/ggml-opencl/CMakeLists.txt | 3 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 260 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 70 +++++ .../{mul_mv_q6_k.cl => mul_mv_q6_k_f32.cl} | 0 .../kernels/mul_mv_q6_k_f32_flat.cl | 194 +++++++++++++ 5 files changed, 518 insertions(+), 9 deletions(-) rename ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl => mul_mv_q6_k_f32.cl} (100%) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 79039c30e14..0259474b6e1 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -85,7 +85,8 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_0_f32_8x_flat mul_mv_q4_0_f32_1d_8x_flat mul_mv_q4_0_f32_1d_16x_flat - mul_mv_q6_k + mul_mv_q6_k_f32 + mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 mul_mv_q8_0_f32_flat mul_mv_mxfp4_f32 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 27b2761ef1e..678e40965ad 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -533,8 +533,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_restore_block_q4_0_noshuffle; + cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q6_K_f32; + cl_kernel kernel_mul_mv_q6_K_f32_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat; cl_kernel kernel_solve_tri_f32; @@ -892,6 +894,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); GGML_LOG_CONT("."); } @@ -1114,14 +1118,14 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } - // mul_mv_q6_k + // mul_mv_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_q6_k.cl.h" + #include "mul_mv_q6_k_f32.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_q6_k.cl"); + const std::string kernel_src = read_file("mul_mv_q6_k_f32.cl"); #endif backend_ctx->program_mul_mv_q6_K = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); @@ -1130,6 +1134,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_q6_k_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q6_k_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q6_k_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q6_K_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_q8_0_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2919,6 +2940,50 @@ struct ggml_tensor_extra_cl_q8_0 { } }; +struct ggml_tensor_extra_cl_q6_K { + // Lower 4 bits of quantized weights. + cl_mem ql = nullptr; + // Upper 2 bits of quantized weights. + cl_mem qh = nullptr; + // Scales for each block. + cl_mem s = nullptr; + // Scales for each super block. + cl_mem d = nullptr; + + size_t size_ql = 0; + size_t size_qh = 0; + size_t size_s = 0; + size_t size_d = 0; + + ~ggml_tensor_extra_cl_q6_K() { + reset(); + } + + void reset() { + if (ql != nullptr) { + CL_CHECK(clReleaseMemObject(ql)); + ql = nullptr; + } + if (qh != nullptr) { + CL_CHECK(clReleaseMemObject(qh)); + qh = nullptr; + } + if (s != nullptr) { + CL_CHECK(clReleaseMemObject(s)); + s = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + + size_ql = 0; + size_qh = 0; + size_s = 0; + size_d = 0; + } +}; + //------------------------------------------------------------------------------ // Backend API //------------------------------------------------------------------------------ @@ -3465,6 +3530,12 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) { delete e; } + for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K) { + delete e; + } + for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) { + delete e; + } } ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() { @@ -3527,6 +3598,21 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() { + ggml_tensor_extra_cl_q6_K * extra; + if (temp_tensor_extras_q6_K.empty()) { + extra = new ggml_tensor_extra_cl_q6_K(); + } else { + extra = temp_tensor_extras_q6_K.back(); + temp_tensor_extras_q6_K.pop_back(); + } + + temp_tensor_extras_q6_K_in_use.push_back(extra); + + extra->reset(); + return extra; + } + void reset() { for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) { temp_tensor_extras.push_back(e); @@ -3547,6 +3633,11 @@ struct ggml_backend_opencl_buffer_context { temp_tensor_extras_q8_0.push_back(e); } temp_tensor_extras_q8_0_in_use.clear(); + + for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) { + temp_tensor_extras_q6_K.push_back(e); + } + temp_tensor_extras_q6_K_in_use.clear(); } // Pools for extras. Available extras are in `temp_tensor_extras`. Extras @@ -3562,6 +3653,8 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_mxfp4_in_use; std::vector temp_tensor_extras_q8_0; std::vector temp_tensor_extras_q8_0_in_use; + std::vector temp_tensor_extras_q6_K; + std::vector temp_tensor_extras_q6_K_in_use; // The buffer_context is initially created by ggml_backend_buft_alloc_buffer // before any tensor is initialized (at the beginning of alloc_tensor_range). @@ -4068,6 +4161,92 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } + if (tensor->type == GGML_TYPE_Q6_K) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q6_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q6_K(); + + size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4; + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16; + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) && + "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // Subbuffer for ql + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_ql; + extra->ql = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Subbuffer for qh + region.origin = align_to(previous_origin + size_ql, backend_ctx->alignment); + region.size = size_qh; + extra->qh = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Subbuffer for scales + region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment); + region.size = size_s; + extra->s = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for d. + region.origin = align_to(previous_origin + size_s, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Flatten the weights + cl_kernel kernel = backend_ctx->kernel_convert_block_q6_K; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + extra->size_ql = size_ql; + extra->size_qh = size_qh; + extra->size_s = size_s; + extra->size_d = size_d; + + tensor->extra = extra; + return; + } #endif // GGML_OPENCL_SOA_Q ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; @@ -4277,6 +4456,34 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {1, 1, 1}; + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (tensor->type == GGML_TYPE_Q6_K) { + ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + cl_event evt; CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); @@ -7765,6 +7972,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; #endif const int ne00 = src0 ? src0->ne[0] : 0; @@ -8648,14 +8856,49 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_q6_K_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 2; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 2; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r3)); +#else kernel = backend_ctx->kernel_mul_mv_q6_K_f32; if (backend_ctx->gpu_family == INTEL) { - nth0 = 2; - nth1 = 16; + nth0 = 16; + nth1 = 2; + ndst = 1; } else if (backend_ctx->gpu_family == ADRENO) { - nth0 = 2; - nth1 = 64; + nth0 = 64; + nth1 = 2; + ndst = 1; } else { GGML_ASSERT(false && "TODO: Unknown GPU"); } @@ -8675,6 +8918,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q break; case GGML_TYPE_MXFP4: { #ifdef GGML_OPENCL_SOA_Q @@ -8777,7 +9021,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } else if (src0t == GGML_TYPE_Q5_K) { GGML_ASSERT(false && "not implemented"); } else if (src0t == GGML_TYPE_Q6_K) { - size_t global_work_size[] = {(size_t)(ne01+1)/2*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 513a4d3e28f..adf576a8394 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -46,6 +46,16 @@ struct block_q4_0 uint8_t qs[QK4_0 / 2]; }; +//------------------------------------------------------------------------------ +// block_q6_K +//------------------------------------------------------------------------------ +struct block_q6_K { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + half d; // super-block scale +}; + //------------------------------------------------------------------------------ // kernel_convert_block_q4_0 // Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). @@ -263,3 +273,63 @@ kernel void kernel_restore_block_q8_0( b->qs[i] = q[i]; } } + +//------------------------------------------------------------------------------ +// kernel_convert_block_q6_K +// Convert the block_q6_K format to 3 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +// Each thread processes a super block. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q6_K( + global struct block_q6_K * src0, + global uchar * dst_ql, + global uchar * dst_qh, + global char * dst_s, + global half * dst_d +) { + global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0); + global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); + global char * s = (global char *) dst_s + QK_K/16*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK_K/2; ++i) { + ql[i] = b->ql[i]; + } + for (int i = 0; i < QK_K/4; ++i) { + qh[i] = b->qh[i]; + } + for (int i = 0; i < QK_K/16; ++i) { + s[i] = b->scales[i]; + } +} + +// Restore block_q6_K from flattened arrays. +// Each thread processes a super block. +kernel void kernel_restore_block_q6_K( + global uchar * dst_ql, + global uchar * dst_qh, + global char * dst_s, + global half * dst_d, + global struct block_q6_K * dst +) { + global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0); + global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); + global char * s = (global char *) dst_s + QK_K/16*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + b->d = *d; + + for (int i = 0; i < QK_K/2; ++i) { + b->ql[i] = ql[i]; + } + for (int i = 0; i < QK_K/4; ++i) { + b->qh[i] = qh[i]; + } + for (int i = 0; i < QK_K/16; ++i) { + b->scales[i] = s[i]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl rename to ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl new file mode 100644 index 00000000000..86fe09c6dd6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl @@ -0,0 +1,194 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// kernel_mul_mv_q6_K_f32_flat +//------------------------------------------------------------------------------ +#define Q6_K_MASK1 0x03 +#define Q6_K_MASK2 0x0C +#define Q6_K_MASK3 0x30 +#define Q6_K_MASK4 0xC0 + +#define QK_K 256 + +inline float block_q_6_K_dot_y_flat( + global uchar * blk_ql, + global uchar * blk_qh, + global char * blk_scales, + global half * blk_d, + global float * yy, + int ib, + int ip, + int is, + int l0 +) { + int y_offset = 128*ip + l0; + int q_offset_l = 64*ip + l0; + int q_offset_h = 32*ip + l0; + + global uchar * q1 = blk_ql + ib*128 + q_offset_l; + global uchar * q2 = q1 + QK_K/8; + global uchar * qh = blk_qh + ib*64 + q_offset_h; + global char * sc = blk_scales + ib*16 + is; + + global float * y = yy + ib * QK_K + y_offset; + + float dall = blk_d[ib]; + + float sumf = 0; + float4 sums = {0.f, 0.f, 0.f, 0.f}; + + sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & Q6_K_MASK1) << 4)) - 32.f); + sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & Q6_K_MASK2) << 2)) - 32.f); + sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & Q6_K_MASK3) << 0)) - 32.f); + sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & Q6_K_MASK4) >> 2)) - 32.f); + + sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & Q6_K_MASK1) << 4)) - 32.f); + sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & Q6_K_MASK2) << 2)) - 32.f); + sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & Q6_K_MASK3) << 0)) - 32.f); + sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & Q6_K_MASK4) >> 2)) - 32.f); + + sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & Q6_K_MASK1) << 4)) - 32.f); + sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & Q6_K_MASK2) << 2)) - 32.f); + sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & Q6_K_MASK3) << 0)) - 32.f); + sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & Q6_K_MASK4) >> 2)) - 32.f); + + sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & Q6_K_MASK1) << 4)) - 32.f); + sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & Q6_K_MASK2) << 2)) - 32.f); + sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & Q6_K_MASK3) << 0)) - 32.f); + sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & Q6_K_MASK4) >> 2)) - 32.f); + + sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); + + return sumf; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q6_K_f32_flat( + global uchar * src0_ql, + global uchar * src0_qh, + global char * src0_s, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + int first_row = (N_SIMDGROUP * r0 + get_sub_group_id()) * N_DST; + + ulong offset_src0 = first_row*nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + ulong offset_src0_ql = offset_src0 * 128; + ulong offset_src0_qh = offset_src0 * 64; + ulong offset_src0_s = offset_src0 * 16; + ulong offset_src0_d = offset_src0; + + global uchar * blk_ql = (global uchar *) src0_ql + offset_src0_ql; + global uchar * blk_qh = (global uchar *) src0_qh + offset_src0_qh; + global char * blk_scales = (global char *) src0_s + offset_src0_s; + global half * blk_d = (global half *) src0_d + offset_src0_d; + global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 + int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 + int ip = tid/8; // first or second half of (super) block (0 or 1) + int il = tid%8; // each half has 8 parts, one per scale + int n = 4; // 4 scales at a time (and 4 sums) + int l0 = n*il; // offset into half-block, 0..28 + int is = 8*ip + l0/16; // 0, 1, 8, 9 + + float4 sumf = 0; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + if (first_row + 0 < ne01) { + sumf.s0 += block_q_6_K_dot_y_flat(blk_ql + 0*nb*128, blk_qh + 0*nb*64, blk_scales + 0*nb*16, blk_d + 0*nb, yy, ib, ip, is, l0); + } + if (first_row + 1 < ne01) { + sumf.s1 += block_q_6_K_dot_y_flat(blk_ql + 1*nb*128, blk_qh + 1*nb*64, blk_scales + 1*nb*16, blk_d + 1*nb, yy, ib, ip, is, l0); + } + if (first_row + 2 < ne01) { + sumf.s2 += block_q_6_K_dot_y_flat(blk_ql + 2*nb*128, blk_qh + 2*nb*64, blk_scales + 2*nb*16, blk_d + 2*nb, yy, ib, ip, is, l0); + } + if (first_row + 3 < ne01) { + sumf.s3 += block_q_6_K_dot_y_flat(blk_ql + 3*nb*128, blk_qh + 3*nb*64, blk_scales + 3*nb*16, blk_d + 3*nb, yy, ib, ip, is, l0); + } + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), + sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), + sub_group_reduce_add(sumf.s3) + ); + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} From b2e2032856d189c158aafacf853b5fb353461923 Mon Sep 17 00:00:00 2001 From: shalinib-ibm Date: Tue, 27 Jan 2026 09:22:34 +0530 Subject: [PATCH 046/831] ggml-cpu: Enable FP16 MMA kernels on PPC (llama/19060) --- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 81 +++++++++++++++++++-------- 1 file changed, 58 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 7dc36d4f8ad..8f980c16b96 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -1797,10 +1797,27 @@ class tinyBLAS_Q0_AVX { } \ } \ +template +struct mma_instr; + +template<> +struct mma_instr { + static inline void outer_product(acc_t *acc, vec_t a, vec_t b) { + __builtin_mma_xvbf16ger2pp(acc, a, b); + } +}; + +template<> +struct mma_instr { + static inline void outer_product(acc_t *acc, vec_t a, vec_t b) { + __builtin_mma_xvf16ger2pp(acc, a, b); + } +}; + template -class tinyBLAS_BF16_PPC { +class tinyBLAS_HP16_PPC { public: - tinyBLAS_BF16_PPC(int64_t k, + tinyBLAS_HP16_PPC(int64_t k, const TA *A, int64_t lda, const TB *B, int64_t ldb, TC *C, int64_t ldc, @@ -2118,8 +2135,8 @@ class tinyBLAS_BF16_PPC { packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A); packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); } } SAVE_ACC(&acc_0, ii, jj); @@ -2135,8 +2152,8 @@ class tinyBLAS_BF16_PPC { packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A); packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); } } SAVE_ACC(&acc_0, ii, jj); @@ -2155,10 +2172,10 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]); - __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_2, vec_A[x+4], vec_B[x]); + mma_instr::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]); } } @@ -2189,7 +2206,7 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B); for (int x = 0; x<2; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); } } __builtin_mma_disassemble_acc(vec_C, &acc_0); @@ -2224,8 +2241,8 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B); for (int x = 0; x<4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); } } __builtin_mma_disassemble_acc(vec_C, &acc_0); @@ -3418,16 +3435,19 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return tb.matmul(m, n); } #elif defined(__MMA__) - if ((k % 8)) - return false; - if(Btype == GGML_TYPE_BF16) { - tinyBLAS_BF16_PPC tb{ k, - (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc, - params->ith, params->nth}; - tb.matmul(m, n); - return true; + if (k % 8) { + return false; + } + + if (Btype == GGML_TYPE_BF16) { + tinyBLAS_HP16_PPC tb{ k, + (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc, + params->ith, params->nth }; + + tb.matmul(m, n); + return true; } #elif defined(__riscv_zvfbfwma) #if LMUL == 1 @@ -3516,6 +3536,21 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 #endif return tb.matmul(m, n); } +#elif defined(__MMA__) + if (k % 8) { + return false; + } + + if (Btype == GGML_TYPE_F16) { + tinyBLAS_HP16_PPC tb{ k, + (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc, + params->ith, params->nth }; + + tb.matmul(m, n); + return true; + } #endif return false; } From 5fcbbdc0ddda79214fe40d828fc338a4f28c29ff Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Tue, 27 Jan 2026 06:52:44 +0000 Subject: [PATCH 047/831] Reduce CPU-side stalls due to the CUDA command buffer being full (llama/19042) * [CUDA] Reduce CPU-side stalls due to the CUDA command buffer being full With pipeline parallelism, during prompt processing, the CPU-side CUDA command buffer gets full, stalling the CPU. Due to this, enough work doesn't get submitted to the GPU, causing bubbles in the GPU timeline. Fix this by setting the CUDA environment variable CUDA_SCALE_LAUNCH_QUEUES to 4x to increase the command buffer size. * Set the env variable in the CUDA backend registry allocation * Add link to PR in code comment * Remove warning logs and update documentation --- ggml/src/ggml-cuda/ggml-cuda.cu | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 99f0919a514..e9df0ea4a7c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4876,6 +4876,16 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { + // Set CUDA_SCALE_LAUNCH_QUEUES before any CUDA API call to improve multi-GPU pipeline parallelism performance + // PR: https://github.com/ggml-org/llama.cpp/pull/19042 + if (getenv("CUDA_SCALE_LAUNCH_QUEUES") == nullptr) { +#ifdef _WIN32 + _putenv_s("CUDA_SCALE_LAUNCH_QUEUES", "4x"); +#else + setenv("CUDA_SCALE_LAUNCH_QUEUES", "4x", 0); // don't overwrite if already set +#endif // _WIN32 + } + ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; From 00885e08e2fbb52eb8190f2554d5b2179e967733 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alberto=20Cabrera=20P=C3=A9rez?= <1478977+Alcpz@users.noreply.github.com> Date: Tue, 27 Jan 2026 09:08:10 +0000 Subject: [PATCH 048/831] ggml-cpu: aarm64: q6_K repack gemm and gemv (and generic) implementations (i8mm) #18860 (llama/18888) * Boilerplate for q6_K repack * q6_K repack to q6_Kx8 implementation Signed-off-by: Alberto Cabrera * q6_K generic gemv and gemm * wip, gemm_q6_K 8x8 * Still WIP: loading of q8s, q6h and q6l * first working version of q6_K gemm * Moved q6 loads outside of sb block, Unrolled inner loop * Replaced modulo with mask * First implementation of GEMV * ggml_vdotq_s32 -> vdotq_s32 * Reduce width of accumulators in q6_K gemv * Bsums instead of calc bias. Preload scales to use vget_lane. Unroll. * Reuse scales in GEMM (same GEMV opt) * Added todos for bsum and different qh repack * Arch fallback * VSLIQ for merging qh adn ql * Removed TODO, already tested * Apply suggestions Co-authored-by: Georgi Gerganov * Removed unused import --------- Signed-off-by: Alberto Cabrera Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cpu/arch-fallback.h | 15 + ggml/src/ggml-cpu/arch/arm/repack.cpp | 429 +++++++++++++++++++++++++- ggml/src/ggml-cpu/repack.cpp | 339 ++++++++++++++++++-- ggml/src/ggml-cpu/repack.h | 16 +- 4 files changed, 771 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 0a85a4cff30..427c1146e46 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -1,3 +1,4 @@ + #pragma once // Rename `_generic` functions if no native implementation is available. @@ -42,6 +43,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -53,6 +55,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +# define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -73,6 +76,7 @@ #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 @@ -80,6 +84,7 @@ #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 @@ -102,6 +107,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -113,6 +119,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -136,6 +143,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -147,6 +155,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -177,6 +186,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -187,6 +197,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -216,6 +227,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -227,6 +239,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -258,6 +271,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -269,6 +283,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 883d862901b..f40226494cd 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -1055,10 +1055,10 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, // FUSED BIAS: Compute and subtract bias immediately // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min - int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo); - bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi); + int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo); + bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi); float32x4_t bias_f32 = vcvtq_f32_s32(bias); - acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32); + acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32); } } // for sb } // for b @@ -1072,6 +1072,208 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q6_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[2]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + acc_f32[0] = vdupq_n_f32(0); + acc_f32[1] = vdupq_n_f32(0); + + for (int b = 0; b < nb; b++) { + float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d); + + int32x2_t acc[col_pairs]; + for (int i = 0; i < col_pairs; i++) { + acc[i] = vdup_n_s32(0); + } + + // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block) + // Reused for bias and dequantization later + int16_t q6_scales[16 * 8]; + for (int i = 0; i < 16; i++) { + int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, scales); + } + + // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift + int32x4_t bias_lo = vdupq_n_s32(0); + int32x4_t bias_hi = vdupq_n_s32(0); + + // Load bsums in chunks of 4 to process with vectorized operations + for (int i = 0; i < 16; i += 4) { + int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i); + int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8); + int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4); + int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8); + int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4); + int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8); + int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4); + int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8); + int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4); + + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3); + } + bias_lo = vshlq_n_s32(bias_lo, 5); + bias_hi = vshlq_n_s32(bias_hi, 5); + + // Process two 128-value halves per superblock + for (int half = 0; half < 2; half++) { + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + // A subblock (sb) is a set of weights that share the scale + // Since q6_K scales are per 16 elements + // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) + for (int sb = 0; sb < QK_K / 64; sb++) { + const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16; + const int8_t * q8_base_h = q8_base_l + 64; + + // Load and duplicate q8 values (each register covers two interleaved columns of q6) + int8x16_t q8_l[2]; + int8x16_t q8_h[2]; + for (int i = 0; i < 2; i++) { + q8_l[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_l + i * 8)); + q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8)); + } + + // TODO: Test other qh repack patterns to reduce loads + const int ql_off_base = sb * QK_K / 2; + const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes + + // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1) + ggml_uint8x16x4_t q6_ql_0 = ggml_vld1q_u8_x4(ql_base + ql_off_base); + ggml_uint8x16x4_t q6_ql_1 = ggml_vld1q_u8_x4(ql_base + ql_off_base + 64); + ggml_uint8x16x4_t q6_qh_0 = ggml_vld1q_u8_x4(qh_base + qh_off_base); + ggml_uint8x16x4_t q6_qh_1 = ggml_vld1q_u8_x4(qh_base + qh_off_base + 64); + + // Adjust qh for subblocks 2 and 3 (shift right by 2) + if (sb > 1) { + q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2); + q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2); + q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2); + q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2); + q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2); + q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2); + q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2); + q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2); + } + + // Process column pairs (0-1, 2-3, 4-5, 6-7) + for (int cp = 0; cp < col_pairs; cp++) { + const uint8x16_t q6_qs_cp_0_l = q6_ql_0.val[cp]; + const uint8x16_t q6_qs_cp_1_l = q6_ql_1.val[cp]; + const uint8x16_t q6_qs_cp_0_h = q6_qh_0.val[cp]; + const uint8x16_t q6_qs_cp_1_h = q6_qh_1.val[cp]; + + // Extract high 2 bits for upper nibble reconstruction + const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi); + const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi); + + // q6 = (low4 | high2<<4), without -32 bias (handled via bsums) + const int8x16_t q6_l0 = vreinterpretq_s8_u8( + vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)); + const int8x16_t q6_l1 = vreinterpretq_s8_u8( + vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)); + const int8x16_t q6_h0 = + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)); + const int8x16_t q6_h1 = + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)); + + int32x4_t sb_acc_l = vdupq_n_s32(0); + sb_acc_l = vdotq_s32(sb_acc_l, q6_l0, q8_l[0]); + sb_acc_l = vdotq_s32(sb_acc_l, q6_l1, q8_l[1]); + + int32x4_t sb_acc_h = vdupq_n_s32(0); + sb_acc_h = vdotq_s32(sb_acc_h, q6_h0, q8_h[0]); + sb_acc_h = vdotq_s32(sb_acc_h, q6_h1, q8_h[1]); + + // Pairwise add to get per-column sums: [col0, col1] + int32x2_t sum_l = vpadd_s32(vget_low_s32(sb_acc_l), vget_high_s32(sb_acc_l)); + int32x2_t sum_h = vpadd_s32(vget_low_s32(sb_acc_h), vget_high_s32(sb_acc_h)); + + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + // Access scales using array indexing (scales are interleaved by column) + const int32x2_t scale_vec_l = { (int32_t) q6_scales[scale_idx_l * 8 + cp * 2], + (int32_t) q6_scales[scale_idx_l * 8 + cp * 2 + 1] }; + const int32x2_t scale_vec_h = { (int32_t) q6_scales[scale_idx_h * 8 + cp * 2], + (int32_t) q6_scales[scale_idx_h * 8 + cp * 2 + 1] }; + + // Accumulate scaled results + acc[cp] = vmla_s32(acc[cp], sum_l, scale_vec_l); + acc[cp] = vmla_s32(acc[cp], sum_h, scale_vec_h); + } + } + } // for half + + // Bias correction + acc[0] = vsub_s32(acc[0], vget_low_s32(bias_lo)); + acc[1] = vsub_s32(acc[1], vget_high_s32(bias_lo)); + acc[2] = vsub_s32(acc[2], vget_low_s32(bias_hi)); + acc[3] = vsub_s32(acc[3], vget_high_s32(bias_hi)); + + // Apply superblock scale (no mins for q6_K) + // acc[cp] has [c0, c1] + float32x2_t w_01 = vmul_f32(vcvt_f32_s32(acc[0]), vget_low_f32(sb_scale_0)); + float32x2_t w_23 = vmul_f32(vcvt_f32_s32(acc[1]), vget_high_f32(sb_scale_0)); + float32x2_t w_45 = vmul_f32(vcvt_f32_s32(acc[2]), vget_low_f32(sb_scale_1)); + float32x2_t w_67 = vmul_f32(vcvt_f32_s32(acc[3]), vget_high_f32(sb_scale_1)); + + acc_f32[0] = vaddq_f32(acc_f32[0], vcombine_f32(w_01, w_23)); + acc_f32[1] = vaddq_f32(acc_f32[1], vcombine_f32(w_45, w_67)); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, @@ -3146,8 +3348,8 @@ void ggml_gemm_q5_K_8x8_q8_K(int n, const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4)); int32x4_t acc_0 = sb_acc[0]; acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]); - int32x4_t acc_2 = sb_acc[2]; - acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]); + int32x4_t acc_2 = sb_acc[2]; + acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]); const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0)); int32x4_t acc_1 = sb_acc[1]; acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]); @@ -3271,6 +3473,223 @@ void ggml_gemm_q5_K_8x8_q8_K(int n, ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q6_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + constexpr int q8_k_blocklen = 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + const int8x16_t m32s = vdupq_n_s8(32); + + // 8 accumulators: 4 q8 rows × 2 col groups (0-3, 4-7) + float32x4_t acc_f32[blocklen]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int i = 0; i < blocklen; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + int32x4_t acc[8]; // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7] + for (int i = 0; i < 8; i++) { + acc[i] = vdupq_n_s32(0); + } + + // Q6_K has simple 8-bit scales, 16 per block (one per 16 values) + // Reused for bias and dequantization later + int16_t q6_scales[16 * 8]; + for (int i = 0; i < 16; ++i) { + int16x8_t s16 = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, s16); + } + + // Process two 128-value halves per superblock + for (int half = 0; half < 2; half++) { + + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + // A subblock (sb) is a set of weights that share the scale + // Since q6_K scales are per 16 elements + // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) + for (int sb = 0; sb < QK_K / 64; sb++) { + // Q6_K weight index increasing by 64 instead of 32 requires + // loading various q8 memory regions + const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64; + const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64; + + int8x16_t q8_l_01[2]; + int8x16_t q8_l_23[2]; + for (int i = 0; i < 2; i++) { + const int offset = i * 32; + q8_l_01[i] = vld1q_s8(q8_base_l + offset); // 0..7 & 8..15 (r01) + q8_l_23[i] = vld1q_s8(q8_base_l + offset + 16); // 0..7 & 8..15 (r23) + } + + int8x16_t q8_h_01[2]; + int8x16_t q8_h_23[2]; + for (int i = 0; i < 2; i++) { + const int offset = i * 32; + q8_h_01[i] = vld1q_s8(q8_base_h + offset); + q8_h_23[i] = vld1q_s8(q8_base_h + offset + 16); + } + + const int ql_off_base = sb * QK_K / 2; + + uint8x16_t q6_ql_0[4]; + uint8x16_t q6_ql_1[4]; + for (int k = 0; k < 4; k++) { + q6_ql_0[k] = vld1q_u8(ql_base + ql_off_base + 16 * k); + q6_ql_1[k] = vld1q_u8(ql_base + ql_off_base + 64 + 16 * k); + } + + const int qh_off_base = (sb * QK_K / 2) & 255; // wrap after 256 bytes + uint8x16_t q6_qh_0[4]; + uint8x16_t q6_qh_1[4]; + for (int k = 0; k < 4; k++) { + q6_qh_0[k] = vld1q_u8(qh_base + qh_off_base + 16 * k); + q6_qh_1[k] = vld1q_u8(qh_base + qh_off_base + 64 + 16 * k); + } + + // Adjust for the proper high bits (Sb 2 and 3) + if (sb > 1) { + for (int k = 0; k < 4; k++) { + q6_qh_0[k] = vshrq_n_u8(q6_qh_0[k], 2); + q6_qh_1[k] = vshrq_n_u8(q6_qh_1[k], 2); + } + } + + // Process column pairs (0-1, 2-3, 4-5, 6-7) + for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + const uint8x16_t q6_qs_cp_0_l = q6_ql_0[cp]; + const uint8x16_t q6_qs_cp_1_l = q6_ql_1[cp]; + const uint8x16_t q6_qs_cp_0_h = q6_qh_0[cp]; + const uint8x16_t q6_qs_cp_1_h = q6_qh_1[cp]; + + // Extract high 2 bits for upper nibble reconstruction + const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi); + const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi); + + // q6 = (low4 | high2<<4) - 32 + // Use vsliq_n_u8 to combine shift-left-insert in one instruction (like Q5_K) + const int8x16_t q6_l0 = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)), + m32s); + const int8x16_t q6_l1 = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)), + m32s); + const int8x16_t q6_h0 = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)), m32s); + const int8x16_t q6_h1 = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)), m32s); + + // row pair 0, base_l + int32x4_t sb_acc_0l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_01[0]); + sb_acc_0l = vmmlaq_s32(sb_acc_0l, q6_l1, q8_l_01[1]); + // row pair 0, base_h + int32x4_t sb_acc_0h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_01[0]); + sb_acc_0h = vmmlaq_s32(sb_acc_0h, q6_h1, q8_h_01[1]); + // row pair 1, base_l + int32x4_t sb_acc_1l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_23[0]); + sb_acc_1l = vmmlaq_s32(sb_acc_1l, q6_l1, q8_l_23[1]); + // row pair 1, base_h + int32x4_t sb_acc_1h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_23[0]); + sb_acc_1h = vmmlaq_s32(sb_acc_1h, q6_h1, q8_h_23[1]); + + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + const int32x4_t scale_vec_l = { + q6_scales[scale_idx_l * 8 + cp * 2 + 0], + q6_scales[scale_idx_l * 8 + cp * 2 + 0], + q6_scales[scale_idx_l * 8 + cp * 2 + 1], + q6_scales[scale_idx_l * 8 + cp * 2 + 1], + }; + const int32x4_t scale_vec_h = { + q6_scales[scale_idx_h * 8 + cp * 2 + 0], + q6_scales[scale_idx_h * 8 + cp * 2 + 0], + q6_scales[scale_idx_h * 8 + cp * 2 + 1], + q6_scales[scale_idx_h * 8 + cp * 2 + 1], + }; + + acc[cp] = vmlaq_s32(acc[cp], sb_acc_0l, scale_vec_l); + acc[cp] = vmlaq_s32(acc[cp], sb_acc_0h, scale_vec_h); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1l, scale_vec_l); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1h, scale_vec_h); + } + } + } // for half + + // Reorder i8mm output to match memory layout + for (int i = 0; i < 8; i++) { + int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i])); + acc[i] = vcombine_s32(aux.val[0], aux.val[1]); + } + int32x4_t reorder_acc[8] = { + vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])), + vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])), + vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])), + vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])), + vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])), + vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])), + vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])), + vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])), + }; + + // Apply superblock scale (no mins for q6_K) + for (int i = 0; i < q8_k_blocklen; i++) { + for (int j = 0; j < 2; j++) { + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]); + float32x4_t q6_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q6_ptr[b].d + j * 4))); + const float32x4_t scale = vmulq_f32(q6_d, q8_d); + + acc_f32[2 * i + j] = + vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale); + } + } + } // for b + + // Store results + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + ggml_gemm_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 19e021e59aa..24e8ab46182 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -703,6 +703,97 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n, } } + +void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + } + + for (int l = 0; l < nb; l++) { + + + for (int k = 0; k < 16; k++) { + // k = 0.. 7 weights 0-63 low, 64-127 high + // k = 8..15 weights 128-191 low, 192-255 high + const int base_l = (k / 8) * 128 + (k % 8) * 8; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + // qh_half: offset to the correct 32-byte half (0 or 32) + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + for (int j = 0; j < ncols_interleaved; j++) { + // Interleaved scales + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * 64 + j * 8 + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + // qh indexing with 8-byte interleaving (like q5_K) + const int qh_byte_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_byte_l / 8; + const int qh_pos_l = qh_byte_l % 8; + const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_byte_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_byte_h / 8; + const int qh_pos_h = qh_byte_h % 8; + const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t a_l = a_ptr[l].qs[base_l + i]; + const int8_t a_h = a_ptr[l].qs[base_h + i]; + + sumi_l += q_l * a_l; + sumi_h += q_h * a_h; + } + + sumf[j] += + (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1133,15 +1224,7 @@ void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, assert (nr % 4 == 0); assert (nc % ncols_interleaved == 0); - UNUSED(s); UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); float sumf[4][8]; float sum_minf[4][8]; @@ -1402,6 +1485,111 @@ void ggml_gemm_q5_K_8x8_q8_K_generic(int n, } } +void ggml_gemm_q6_K_8x8_q8_K_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + + float sumf[4][8]; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0f; + } + } + + for (int l = 0; l < nb; l++) { + for (int k = 0; k < 16; k++) { + // k = 0.. 7 weights 0-63 low, 64-127 high + // k = 8..15 weights 128-191 low, 192-255 high + const int base_l = (k / 8) * 128 + (k % 8) * 8; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + // qh_half: offset to the correct 32-byte half (0 or 32) + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + // Activation base indices for q8_Kx4 interleaved format + // Layout: 128-value halves (k/8), then 8-value sub-blocks (k%8) with stride 32 + const int q8_base = (k / 8) * 512 + (k % 8) * 32; + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + // Interleaved scales + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * 64 + j * 8 + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / 8; + const int qh_pos_l = qh_idx_l % 8; + const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / 8; + const int qh_pos_h = qh_idx_h % 8; + const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t q8_l = a_ptr[l].qs[q8_base + m * 8 + i]; + const int8_t q8_h = a_ptr[l].qs[q8_base + m * 8 + i + 256]; + + sumi_l += q_l * q8_l; + sumi_h += q_h * q8_h; + } + + sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * + a_ptr[l].d[m]; + } + } + } + } + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1801,8 +1989,7 @@ static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_in // Every 16 byte is packed such that it contains scales and mins for corresponding sub blocks from Q2_K structure // For eg - First 16 bytes contains 16 scales and 16 mins - each of first and second sub blocks from different Q2_K structures - for(int i = 0; i < 128; i++){ - + for (int i = 0; i < 128; i++) { // Index for selecting which q2k super block int src1 = (i % 16) / 2; // Index for selecting scale @@ -1902,6 +2089,52 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in return out; } +static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_interleave) { + block_q6_Kx8 out; + constexpr int n_blocks = 8; // Kx8 + for (int i = 0; i < n_blocks; i++) { + out.d[i] = in[i].d; + } + + const int end_ls = QK_K * 4 / blck_size_interleave; + // Interleave Q6_K quants by taking 8 bytes at a time + for (int i = 0; i < end_ls; ++i) { + int src_id = i % n_blocks; + int src_offset = (i / n_blocks) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elem_ls; + memcpy(&elem_ls, &in[src_id].ql[src_offset], sizeof(uint64_t)); + memcpy(&out.ql[dst_offset], &elem_ls, sizeof(uint64_t)); + } + + // Interleave high bits using same 8-byte pattern as low bits + const int end_hs = end_ls / 2; + for (int i = 0; i < end_hs; ++i) { + int src_id = i % n_blocks; + int src_offset = (i / n_blocks) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elem_hs; + memcpy(&elem_hs, &in[src_id].qh[src_offset], sizeof(uint64_t)); + memcpy(&out.qh[dst_offset], &elem_hs, sizeof(uint64_t)); + } + + // The below logic is designed so as to unpack and rearrange scales in Q6_K + // The output Q6_Kx8 structure interleaves the 8 bit scales in the same fashion as the quants + // Q6_K structure has an 8-bit scale per 16 elements -> 16 scales + // scales: [0 bl0 0 bl1 ... 0 bl7][1 bl0 ... 1 bl7] ... [15 bl0 ... 15 bl7] (bl = block) + constexpr int n_scales = QK_K / 16; + + for (int i = 0; i < n_blocks; i++) { + for (int j = 0; j < n_scales; j++) { + out.scales[j * n_blocks + i] = in[i].scales[j]; + } + } + + return out; +} + static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_0); GGML_ASSERT(interleave_block == 4 || interleave_block == 8); @@ -1983,7 +2216,7 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block for (int b = 0; b < nrow; b += nrows_interleaved) { for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++ ) { + for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * nblocks]; } *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block); @@ -2027,6 +2260,35 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, return 0; } +static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; + + block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data; + const block_q6_K * src = (const block_q6_K *) data; + block_q6_K dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q6_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q6_Kx8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_0); GGML_ASSERT(interleave_block == 8); @@ -2249,6 +2511,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size); } @@ -2286,7 +2552,14 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> +void gemv(int n, + float * s, + size_t bs, + const void * vx, + const void * vy, + int nr, + int nc) { ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -2302,6 +2575,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -2330,7 +2607,14 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +template <> +void gemm(int n, + float * s, + size_t bs, + const void * vx, + const void * vy, + int nr, + int nc) { ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -2350,6 +2634,10 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -2714,20 +3002,19 @@ template (ne00, - (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, - src0_cur + src0_cur_start * nb01, - src1_col, 1, src0_cur_end - src0_cur_start); + gemv( + ne00, (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, + src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); } } #undef MMID_MATRIX_ROW @@ -2743,7 +3030,6 @@ template q4_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits q4_0_4x8_q8_0; @@ -2756,6 +3042,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons // instance for Q5_K static const ggml::cpu::repack::tensor_traits q5_K_8x8_q8_K; + // instance for Q6_K + static const ggml::cpu::repack::tensor_traits q6_K_8x8_q8_K; + // instance for Q2 static const ggml::cpu::repack::tensor_traits q2_K_8x8_q8_K; @@ -2812,6 +3101,12 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q5_K_8x8_q8_K; } } + } else if (cur->type == GGML_TYPE_Q6_K) { + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 8 == 0) { + return &q6_K_8x8_q8_K; + } + } } else if (cur->type == GGML_TYPE_IQ4_NL) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index da87103157e..855320eeeb6 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -65,6 +65,16 @@ struct block_q5_Kx8 { static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5, "wrong q5_K block size/padding"); +struct block_q6_Kx8 { + ggml_half d[8]; + int8_t scales[QK_K / 16 * 8]; + uint8_t ql[QK_K / 2 * 8]; // low bits of 6-bit quants (groups of 2) + uint8_t qh[QK_K / 4 * 8]; // high bits of 6-bit quants (groups of 4) +}; + +static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8, + "wrong q6_K block size/padding"); + struct block_q8_Kx4 { float d[4]; // delta int8_t qs[QK_K * 4]; // quants @@ -95,13 +105,14 @@ void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTR void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); -void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -111,6 +122,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -130,6 +142,7 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -139,6 +152,7 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); From 9d94d0f78271a1eb99e14a11292881d3e3f1e3e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 27 Jan 2026 14:28:56 +0100 Subject: [PATCH 049/831] CUDA: tune GLM 4.7 Flash FA kernel selection logic (llama/19097) --- ggml/src/ggml-cuda/fattn.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 2f5dbd13a39..b061fdf9a24 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -148,6 +148,10 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg const int gqa_ratio = Q->ne[2] / K->ne[2]; if (gqa_ratio == 20) { // GLM 4.7 Flash if (cc >= GGML_CUDA_CC_BLACKWELL) { + if (Q->ne[1] <= 4 && K->ne[1] >= 65536) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); break; } @@ -161,6 +165,10 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg } if (cc >= GGML_CUDA_CC_TURING) { if (Q->ne[1] <= 4) { + if (K->ne[1] <= 16384) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst); break; } From 9c75c793a6bdec85ea5d969a61dd96414b4ef64a Mon Sep 17 00:00:00 2001 From: Vishal Singh Date: Wed, 28 Jan 2026 03:51:36 +0530 Subject: [PATCH 050/831] ggml-zendnn : update ZenDNN git tag to main branch (llama/19133) --- ggml/src/ggml-zendnn/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt index bdbfc74369f..f5cf6eedd3a 100644 --- a/ggml/src/ggml-zendnn/CMakeLists.txt +++ b/ggml/src/ggml-zendnn/CMakeLists.txt @@ -21,7 +21,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") ExternalProject_Add( zendnn GIT_REPOSITORY https://github.com/amd/ZenDNN.git - GIT_TAG zendnnl + GIT_TAG 21ce8f7879c86bf3637f707fae6f29e0951db5fe PREFIX ${ZENDNN_PREFIX} SOURCE_DIR ${ZENDNN_SOURCE_DIR} BINARY_DIR ${ZENDNN_BUILD_DIR} From dfdd2fee832e2c346f79743a7b617a0545985650 Mon Sep 17 00:00:00 2001 From: Nikhil Jain Date: Tue, 27 Jan 2026 20:53:36 -0800 Subject: [PATCH 051/831] ggml webgpu: Split shared state (webgpu_context) into global state and per-thread state (llama/18976) * Squashed commit of the following: commit b3c6bf4b0450d8d452b934df27a0fb7cb53cd755 Author: Abhijit Ramesh Date: Mon Dec 1 18:29:00 2025 -0800 ggml webgpu: fix xielu parameter passing (llama/11) The XIELU operation was incorrectly using static_cast to convert float parameters to uint32_t, which converted numeric values instead of preserving IEEE 754 bit patterns. This caused incorrect values to be interpreted by the GPU shader. * Use reinterpret_cast to preserve float bit patterns when passing through uint32_t params buffer * Update WGSL shader parameter types from u32 to f32 * Re-enable XIELU support (was disabled due to numerical issues) Fixes NMSE test failures for XIELU operation on WebGPU backend. commit 5ca9b5e49ea7cddc9ab7c8b43a11a9c76a4dff4a Author: neha-ha <137219201+neha-ha@users.noreply.github.com> Date: Tue Nov 18 12:17:00 2025 -0800 Refactored pipelines and workgroup calculations (llama/10) * refactored pipelines * refactored workgroup calculation * removed commented out block of prior maps * Clean up ceiling division pattern --------- Co-authored-by: Neha Abbas Co-authored-by: Reese Levine Author: James Contini Date: Wed Oct 29 23:13:06 2025 -0700 formatted embed wgsl and ggml-webgpu.cpp commit e1f6baea31645e5d96ad53664acae856f74b96f4 Author: James Contini Date: Wed Oct 29 23:08:37 2025 -0700 implemented REPL_Template support and removed bug in unary operators kernel commit 8c70b8fece445cdc9a8c660dbddbf201e52da2bb Author: James Contini Date: Wed Oct 15 16:14:20 2025 -0700 responded and dealt with PR comments commit f9282c660c10dec4487d434549bdb707a9cd9f37 Author: James Contini Date: Sun Oct 12 13:41:41 2025 -0700 removed unnecesarry checking if node->src[1] exists for unary operators commit 4cf28d7dec41c29186d66152735b244c5699f9dc Author: James Contini Date: Sun Oct 12 13:32:45 2025 -0700 All operators (inlcluding xielu) working commit 74c6add1761a59d2c2ff60b60e8ad3c8300f6d3e Author: James Contini Date: Fri Oct 10 13:16:48 2025 -0700 fixed autoconfig commit 362749910be4f0120c8ffb21ceddeb7d2c088e51 Author: James Contini Date: Fri Oct 10 13:10:46 2025 -0700 removed vestigial files commit cb0858333785757804c5104e59c4981843207c16 Author: James Contini Date: Fri Oct 10 12:59:32 2025 -0700 abides by editor-config commit 5360e2852a4b51197d7d67d0a5d42e908b02d7ed Author: James Contini Date: Fri Oct 10 12:45:57 2025 -0700 rms_norm double declaration bug atoned commit 7b09baa4aa53711be5a126043670cc182c78bfcd Merge: 8a6ec843 74b8fc17 Author: James Contini Date: Fri Oct 10 11:50:03 2025 -0700 resolving merge conflicts commit 8a6ec843a50ab82f8cef59b4558eb63f318ba02d Author: James Contini Date: Wed Oct 8 18:06:47 2025 -0700 unary operators pass ggml tests commit c3ae38278a2db236adc5912c9140e4f0d63f2c19 Author: James Contini Date: Wed Oct 1 16:22:40 2025 -0700 neg passes backend test commit aa1c9b2f8877a405470ca56709c42a1fd43713de Author: James Contini Date: Tue Sep 30 23:55:27 2025 -0700 neg f16xf32xip builds and runs, havent actually ran a model that uses neg kernel yet though Co-authored-by: James Contini Co-authored-by: Neha Abbas Co-authored-by: Abhijit Ramesh * Remove extra code and format * Add ops documentation (finally) * ggml webgpu: add SOFTPLUS unary operator Implements SOFTPLUS (log(1 + exp(x))) with f16/f32 support. Uses f32 precision for intermediate calculations to prevent f16 overflow. * Add shader implementation and 4 variants (f32/f16, inplace/non-inplace) * Register pipelines and device support * Follow Vulkan backend numerical stability pattern * ggml webgpu: add EXPM1 unary operator Implements EXPM1 (exp(x) - 1) with f16/f32 support. * Add shader implementation and 4 variants (f32/f16, inplace/non-inplace) * Register pipelines and device support * ggml webgpu: add FLOOR unary operator Implements FLOOR (rounds down to nearest integer) with f16/f32 support. * Add shader implementation and 4 variants (f32/f16, inplace/non-inplace) * Register pipelines and device support * ggml webgpu: add CEIL unary operator Implements CEIL (rounds up to nearest integer) with f16/f32 support. * Add shader implementation and 4 variants (f32/f16, inplace/non-inplace) * Register pipelines and device support * ggml webgpu: add ROUND unary operator Implements ROUND (rounds to nearest integer) with f16/f32 support. * Add shader implementation and 4 variants (f32/f16, inplace/non-inplace) * Register pipelines and device support * ggml webgpu: add TRUNC unary operator Implements TRUNC (truncates towards zero) with f16/f32 support. * Add shader implementation and 4 variants (f32/f16, inplace/non-inplace) * Register pipelines and device support * docs : update WebGPU support for unary operators (FLOOR, CEIL, ROUND, TRUNC, EXPM1, SOFTPLUS) * Updates to webgpu get_memory * Move shared state (webgpu_context) and device creation out of registration context, device context, and buffer context, and move into backend context * Small cleanup * Move Instance, Device, Adapter, Device creation, and capabilities to global state while moving Queue, pipelines, and buffers to per-thread state. * Cleanups * More cleanup * Move staging_buf mutex to global context * Resolve merge * Resolve merge * Resolve merge * Clean up merge errors, delete forward declaration, and run clang-format * Rename device_init to backend_init * Move webgpu_context to backend_context * Move buffer context members into global context and refactor function calls * Run clang-format * Remove commends * Move parameter buffers to per-thread, add single memset_tensor param buf * Fix CI compilation issue * Fix builds for emscripten not supporting subgroups * cleanup * cleanup --------- Co-authored-by: Reese Levine --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 1246 ++++++++++++++------------ 1 file changed, 662 insertions(+), 584 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 584cea7698b..22e2bfeb4ce 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -47,7 +47,6 @@ double cpu_total_time_##id = \ std::chrono::duration(cpu_total_end_##id - cpu_total_start_##id).count(); \ (ctx)->cpu_time_ms[#id] += cpu_total_time_##id; - // fine-grained timing (not included in totals) # define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now(); @@ -74,13 +73,13 @@ #define WEBGPU_MAX_WG_SIZE 288 #define WEBGPU_MUL_MAT_WG_SIZE 256 -#define WEBGPU_NUM_PARAM_BUFS 32u +#define WEBGPU_NUM_PARAM_BUFS 16u #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 // Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters -#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32 +#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 @@ -267,30 +266,67 @@ struct webgpu_command { #endif }; -// All the base objects needed to run operations on a WebGPU device -struct webgpu_context_struct { +struct webgpu_capabilities_base { + wgpu::Limits limits; + bool supports_subgroup_matrix = false; + + uint32_t sg_mat_m = 0; + uint32_t sg_mat_n = 0; + uint32_t sg_mat_k = 0; + + uint32_t subgroup_size = 0; + uint32_t max_subgroup_size = 0; + size_t memset_bytes_per_thread; +}; + +// Stores global webgpu members +struct webgpu_global_context_struct { wgpu::Instance instance; wgpu::Adapter adapter; wgpu::Device device; wgpu::Queue queue; - wgpu::Limits limits; - uint32_t max_subgroup_size; + webgpu_capabilities_base capabilities; + // Shared buffer to move data from device to host + wgpu::Buffer get_tensor_staging_buf; + // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches. + std::recursive_mutex mutex; - bool supports_subgroup_matrix = false; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; + webgpu_buf_pool memset_buf_pool; + std::map memset_pipelines; // variant or type index + std::atomic_uint inflight_threads = 0; - std::recursive_mutex mutex; - std::atomic_uint inflight_threads = 0; +#ifdef GGML_WEBGPU_CPU_PROFILE + // Profiling: labeled CPU time in ms (total) + std::unordered_map cpu_time_ms; + // Profiling: detailed CPU time in ms + std::unordered_map cpu_detail_ms; +#endif - webgpu_buf_pool param_buf_pool; - webgpu_buf_pool set_rows_error_buf_pool; +#ifdef GGML_WEBGPU_GPU_PROFILE + // Profiling: per-shader GPU time in ms + std::unordered_map shader_gpu_time_ms; + // Profiling: pool of timestamp query buffers (one per operation) + webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; +#endif + +#ifdef GGML_WEBGPU_DEBUG + wgpu::Buffer debug_host_buf; + wgpu::Buffer debug_dev_buf; +#endif +}; + +typedef std::shared_ptr webgpu_global_context; + +// All the base objects needed to run operations on a WebGPU device +struct webgpu_context_struct { + // Points to global instances owned by ggml_backend_webgpu_reg_context + webgpu_global_context global_ctx; pre_wgsl::Preprocessor p; - std::map memset_pipelines; // variant or type index + webgpu_buf_pool param_buf_pool; + webgpu_buf_pool set_rows_error_buf_pool; std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized std::map>> @@ -326,57 +362,42 @@ struct webgpu_context_struct { size_t memset_bytes_per_thread; - // Staging buffer for reading data from the GPU - wgpu::Buffer get_tensor_staging_buf; - -#ifdef GGML_WEBGPU_DEBUG - wgpu::Buffer debug_host_buf; - wgpu::Buffer debug_dev_buf; -#endif - -#ifdef GGML_WEBGPU_CPU_PROFILE - // Profiling: labeled CPU time in ms (total) - std::unordered_map cpu_time_ms; - // Profiling: detailed CPU time in ms - std::unordered_map cpu_detail_ms; -#endif - -#ifdef GGML_WEBGPU_GPU_PROFILE - // Profiling: per-shader GPU time in ms - std::unordered_map shader_gpu_time_ms; - // Profiling: pool of timestamp query buffers (one per operation) - webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; -#endif }; typedef std::shared_ptr webgpu_context; +// Metadata required for the ggml backend registration/discovery interface struct ggml_backend_webgpu_reg_context { - webgpu_context webgpu_ctx; - size_t device_count; - const char * name; + // Since the Instance is a global entrypoint into the WebGPU API, it lives here + webgpu_global_context webgpu_global_ctx; + size_t device_count; + const char * name; }; +// Per-device struct for the global logical device interface struct ggml_backend_webgpu_device_context { - webgpu_context webgpu_ctx; - std::string device_name; - std::string device_desc; + webgpu_global_context webgpu_global_ctx; + std::string device_name; + std::string device_desc; }; +// Per-thread data required to actually run WebGPU operations in a backend instance struct ggml_backend_webgpu_context { - webgpu_context webgpu_ctx; - std::string name; + webgpu_context webgpu_ctx; + std::once_flag init_once; + std::string name; }; +// Per-thread data related to buffers struct ggml_backend_webgpu_buffer_context { - webgpu_context webgpu_ctx; - wgpu::Buffer buffer; - std::string label; + wgpu::Buffer buffer; + std::string label; + webgpu_global_context global_ctx; - ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) : - webgpu_ctx(std::move(ctx)), + ggml_backend_webgpu_buffer_context(wgpu::Buffer buf, std::string lbl, webgpu_global_context global_ctx_) : buffer(std::move(buf)), - label(std::move(lbl)) {} + label(std::move(lbl)), + global_ctx(std::move(global_ctx_)) {} }; /* WebGPU object initializations */ @@ -444,7 +465,7 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** WebGPU Actions */ // Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait(webgpu_context & ctx, +static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, std::vector & futures, bool block = true) { // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, @@ -476,11 +497,11 @@ static void ggml_backend_webgpu_wait(webgpu_context & ct } } -static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, - wgpu::Buffer & buffer, - wgpu::MapMode mode, - size_t offset, - size_t size) { +static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, + wgpu::Buffer & buffer, + wgpu::MapMode mode, + size_t offset, + size_t size) { ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, [](wgpu::MapAsyncStatus status, wgpu::StringView message) { if (status != wgpu::MapAsyncStatus::Success) { @@ -495,7 +516,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, // This function adds debugging information to shaders, as WebGPU does not support printing directly. // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and // debug statements in the shader, and then call this function after encoding the commands and submitting them. -static void ggml_backend_webgpu_debug(webgpu_context & ctx) { +static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); @@ -507,7 +528,10 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) { } #endif -static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector commands) { +static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context ctx, + std::vector commands, + webgpu_buf_pool & param_buf_pool, + webgpu_buf_pool * set_rows_error_buf_pool = nullptr) { std::vector command_buffers; std::vector params_bufs; std::vector set_rows_error_bufs; @@ -528,19 +552,19 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( wgpu::CallbackMode::AllowSpontaneous, - [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + [¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { if (status != wgpu::QueueWorkDoneStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); } // Free the staged buffers - ctx->param_buf_pool.free_bufs(params_bufs); + param_buf_pool.free_bufs(params_bufs); }); futures.push_back({ p_f }); for (const auto & bufs : set_rows_error_bufs) { wgpu::Future f = bufs.host_buf.MapAsync( wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, - [ctx, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { + [set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { if (status != wgpu::MapAsyncStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); } else { @@ -549,7 +573,9 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); } // We can't unmap in here due to WebGPU reentrancy limitations. - ctx->set_rows_error_buf_pool.free_bufs({ bufs }); + if (set_rows_error_buf_pool) { + set_rows_error_buf_pool->free_bufs({ bufs }); + } } }); futures.push_back({ f }); @@ -581,7 +607,8 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, } static webgpu_command ggml_backend_webgpu_build_multi( - webgpu_context & ctx, + webgpu_global_context & ctx, + webgpu_buf_pool & param_buf_pool, const std::vector & pipelines, const std::vector> & params_list, const std::vector> & bind_group_entries_list, @@ -595,7 +622,7 @@ static webgpu_command ggml_backend_webgpu_build_multi( std::vector bind_groups; for (size_t i = 0; i < pipelines.size(); i++) { - webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); + webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs(); ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); @@ -672,34 +699,37 @@ static webgpu_command ggml_backend_webgpu_build_multi( return result; } -static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx, +static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & ctx, + webgpu_buf_pool & param_buf_pool, webgpu_pipeline & pipeline, std::vector params, std::vector bind_group_entries, uint32_t wg_x, uint32_t wg_y = 1, std::optional set_rows_error_bufs = std::nullopt) { - return ggml_backend_webgpu_build_multi(ctx, + return ggml_backend_webgpu_build_multi(ctx, param_buf_pool, { pipeline }, { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs); } -static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, - wgpu::Buffer & buf, - uint32_t value, - size_t offset, - size_t size) { +static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, + wgpu::Buffer & buf, + uint32_t value, + size_t offset, + size_t size) { std::vector params = { (uint32_t) offset, (uint32_t) size, value }; std::vector entries = { { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() } }; - size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->memset_bytes_per_thread; + size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); - webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipelines[0], params, entries, wg_x); - std::vector futures = { ggml_backend_webgpu_submit(ctx, { command }) }; + webgpu_command command = + ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); + std::vector futures = { ggml_backend_webgpu_submit(ctx, { command }, + ctx->memset_buf_pool) }; ggml_backend_webgpu_wait(ctx, futures); } @@ -720,19 +750,19 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { #ifdef GGML_WEBGPU_CPU_PROFILE std::cout << "\n[ggml_webgpu cpu profiling summary]\n"; double total_cpu = 0.0; - for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) { total_cpu += kv.second; } std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n"; std::cout << "ggml_webgpu: cpu breakdown:\n"; - for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) { double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0; std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; } - if (ctx->webgpu_ctx->cpu_detail_ms.size() > 0) { + if (ctx->webgpu_ctx->global_ctx->cpu_detail_ms.size() > 0) { std::cout << "ggml_webgpu: cpu detailed breakdown:\n"; } - for (const auto & kv : ctx->webgpu_ctx->cpu_detail_ms) { + for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_detail_ms) { double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0; std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; } @@ -741,12 +771,12 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { #ifdef GGML_WEBGPU_GPU_PROFILE std::cout << "\n[ggml_webgpu gpu profiling summary]\n"; double total_gpu = 0.0; - for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { total_gpu += kv.second; } std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n"; std::cout << "\nggml_webgpu: gpu breakdown:\n"; - for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0; std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; } @@ -772,12 +802,12 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { size_t offset = ggml_webgpu_tensor_offset(t); - return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); + return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); } static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { size_t offset = ggml_webgpu_tensor_offset(t); - return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1); + return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); } static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { @@ -818,28 +848,30 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g }; uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx, ctx->cpy_pipelines[src->type][dst->type], params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type], + params, entries, wg_x); } static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { const bool circular = ggml_get_op_params_i32(dst, 8) != 0; ggml_webgpu_pad_pipeline_key pipeline_key = { .circular = circular }; - ggml_webgpu_pad_shader_lib_context shader_lib_ctx = { .key = pipeline_key, - .max_wg_size = - ctx->limits.maxComputeInvocationsPerWorkgroup }; + ggml_webgpu_pad_shader_lib_context shader_lib_ctx = { + .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + }; webgpu_pipeline pipeline; { // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); auto it = ctx->pad_pipelines.find(pipeline_key); if (it != ctx->pad_pipelines.end()) { pipeline = it->second; } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); pipeline.context = processed.decisions; ctx->pad_pipelines.emplace(pipeline_key, pipeline); } @@ -891,7 +923,7 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g }; uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, @@ -907,21 +939,22 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, .vec4 = src->ne[0] % 4 == 0, .i64_idx = idx->type == GGML_TYPE_I64 }; - ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = { .key = key, - .max_wg_size = - ctx->limits.maxComputeInvocationsPerWorkgroup }; + ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = { + .key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + }; webgpu_pipeline pipeline; // TODO: remove guard once pipeline caches are per-thread { - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); auto it = ctx->set_rows_pipelines.find(key); if (it != ctx->set_rows_pipelines.end()) { pipeline = it->second; } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); pipeline.context = processed.decisions; ctx->set_rows_pipelines.emplace(key, pipeline); } @@ -981,7 +1014,8 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } uint32_t wg_x = CEIL_DIV(threads, decisions.wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1, + error_bufs); } static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, @@ -1023,7 +1057,7 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0; webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized]; - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, @@ -1098,19 +1132,21 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG); uint32_t total_wg = output_groups * batches; - wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; - wg_y = CEIL_DIV(total_wg, ctx->limits.maxComputeWorkgroupsPerDimension); + wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); } else { pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; uint32_t wg_m; uint32_t wg_n; #ifndef __EMSCRIPTEN__ - if (ctx->supports_subgroup_matrix) { + if (ctx->global_ctx->capabilities.supports_subgroup_matrix) { // The total number of subgroups/workgroups needed per matrix. - uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m; + uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * + ctx->global_ctx->capabilities.sg_mat_m; wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); - uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n; - wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); + uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * + ctx->global_ctx->capabilities.sg_mat_n; + wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); } else { #endif uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; @@ -1124,9 +1160,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; } } - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); } +#ifndef __EMSCRIPTEN__ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * Q, ggml_tensor * K, @@ -1210,8 +1247,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - bool kv_direct = - (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); ggml_webgpu_flash_attn_pipeline_key key = { .kv_type = K->type, @@ -1223,25 +1260,27 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .uses_logit_softcap = logit_softcap != 0.0f, }; - webgpu_pipeline pipeline; + webgpu_pipeline pipeline; // TODO: remove guard once pipeline caches are per-thread { - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); auto it = ctx->flash_attn_pipelines.find(key); if (it != ctx->flash_attn_pipelines.end()) { - pipeline = it->second; + pipeline = it->second; } else { - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .key = key, - .sg_mat_m = ctx->sg_mat_m, - .sg_mat_n = ctx->sg_mat_n, - .sg_mat_k = ctx->sg_mat_k, - .wg_mem_limit_bytes = - ctx->limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->max_subgroup_size }; + ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { + .key = key, + .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, + .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, + .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size + }; ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); pipeline.context = processed.decisions; ctx->flash_attn_pipelines.emplace(key, pipeline); } @@ -1250,11 +1289,11 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_flash_attn_shader_decisions decisions = *static_cast(pipeline.context); - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } +#endif static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; @@ -1264,21 +1303,22 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s ggml_webgpu_unary_pipeline_key pipeline_key = { .type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace }; - ggml_webgpu_unary_shader_lib_context shader_lib_ctx = { .key = pipeline_key, - .max_wg_size = - ctx->limits.maxComputeInvocationsPerWorkgroup }; + ggml_webgpu_unary_shader_lib_context shader_lib_ctx = { + .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + }; webgpu_pipeline pipeline; { // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); auto it = ctx->unary_pipelines.find(pipeline_key); if (it != ctx->unary_pipelines.end()) { pipeline = it->second; } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); pipeline.context = processed.decisions; ctx->unary_pipelines.emplace(pipeline_key, pipeline); } @@ -1346,7 +1386,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s } uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, @@ -1391,7 +1431,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -1426,7 +1466,8 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipelines[inplace], params, entries, ggml_nrows(src)); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params, + entries, ggml_nrows(src)); } static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, @@ -1513,7 +1554,7 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace]; uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { @@ -1565,7 +1606,7 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split]; uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -1602,7 +1643,8 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx, ctx->scale_pipelines[inplace], params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->scale_pipelines[inplace], params, + entries, wg_x); } static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, @@ -1674,7 +1716,8 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries, + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, + ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries, ggml_nrows(dst)); } @@ -1696,25 +1739,26 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { .vec4 = src->ne[0] % 4 == 0, - .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; webgpu_pipeline pipeline; { // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4); if (it != ctx->argmax_pipelines.end()) { pipeline = it->second; } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax"); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline); } } uint32_t wg_x = ggml_nelements(dst); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -1722,13 +1766,13 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr // ascending order is 0, descending order is 1 const int32_t order = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0); - ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = { .max_wg_size = - ctx->limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = - ctx->limits.maxComputeWorkgroupStorageSize, - .order = order }; + ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = { + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + .order = order + }; - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); webgpu_pipeline argsort_pipeline; auto it = ctx->argsort_pipelines.find(order); if (it != ctx->argsort_pipelines.end()) { @@ -1736,7 +1780,8 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx); - argsort_pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + argsort_pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); argsort_pipeline.context = processed.decisions; ctx->argsort_pipelines.emplace(order, argsort_pipeline); } @@ -1751,7 +1796,7 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx); argsort_merge_pipeline = - ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); argsort_merge_pipeline.context = processed.decisions; ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline); } @@ -1780,9 +1825,10 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr const bool start_in_tmp = (merge_passes % 2) == 1; - const size_t dst_offset = ggml_webgpu_tensor_offset(dst); - const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t); - const size_t tmp_offset = ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->limits.minStorageBufferOffsetAlignment); + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t); + const size_t tmp_offset = + ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT); const size_t dst_binding_size = ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT); @@ -1813,10 +1859,10 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr }; const uint32_t total_wg_init = npr * nrows; - const uint32_t max_wg = ctx->limits.maxComputeWorkgroupsPerDimension; - const uint32_t wg_x_init = std::min(total_wg_init, max_wg); - const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); - std::vector init_entries = { + const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + const uint32_t wg_x_init = std::min(total_wg_init, max_wg); + const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); + std::vector init_entries = { { .binding = 0, .buffer = ggml_webgpu_tensor_buf(src), .offset = ggml_webgpu_tensor_align_offset(ctx, src), @@ -1830,7 +1876,8 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr workgroups_list.push_back({ wg_x_init, wg_y_init }); if (merge_passes == 0) { - return ggml_backend_webgpu_build_multi(ctx, pipelines, params_list, entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, + entries_list, workgroups_list); } bool in_is_tmp = start_in_tmp; @@ -1891,7 +1938,8 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr in_is_tmp = !in_is_tmp; } - return ggml_backend_webgpu_build_multi(ctx, pipelines, params_list, entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list, + workgroups_list); } static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -1912,24 +1960,25 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { .vec4 = false, - .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; webgpu_pipeline pipeline; // TODO: remove guard once pipeline caches are per-thread { - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); auto it = ctx->cumsum_pipelines.find(1); if (it != ctx->cumsum_pipelines.end()) { pipeline = it->second; } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum"); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); ctx->cumsum_pipelines.emplace(1, pipeline); } } uint32_t wg_x = ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -1956,25 +2005,26 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { .vec4 = false, - .max_wg_size = ctx->limits.maxComputeInvocationsPerWorkgroup, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; webgpu_pipeline pipeline; { // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->mutex); + std::lock_guard lock(ctx->global_ctx->mutex); auto it = ctx->sum_rows_pipelines.find(1); if (it != ctx->sum_rows_pipelines.end()) { pipeline = it->second; } else { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows"); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); ctx->sum_rows_pipelines.emplace(1, pipeline); } } uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } // Returns the encoded command, or std::nullopt if the operation is a no-op @@ -2009,7 +2059,11 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_MUL_MAT: return ggml_webgpu_mul_mat(ctx, src0, src1, node); case GGML_OP_FLASH_ATTN_EXT: +#ifndef __EMSCRIPTEN__ return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); +#else + return std::nullopt; +#endif case GGML_OP_ADD: { int inplace = ggml_webgpu_tensor_equal(src0, node); @@ -2070,12 +2124,12 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); - ggml_backend_webgpu_context * backend_ctx = static_cast(backend->context); + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; webgpu_context ctx = backend_ctx->webgpu_ctx; WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); - ctx->inflight_threads++; + ctx->global_ctx->inflight_threads++; std::vector commands; std::vector futures; @@ -2084,25 +2138,27 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str commands.push_back(*cmd); } // compute the batch size based on the number of inflight threads - uint32_t inflight_threads = ctx->inflight_threads; + uint32_t inflight_threads = ctx->global_ctx->inflight_threads; uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); if (commands.size() >= batch_size) { - futures.push_back(ggml_backend_webgpu_submit(ctx, commands)); + futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, + &ctx->set_rows_error_buf_pool)); // Process events and check for completed submissions - ctx->instance.ProcessEvents(); - ggml_backend_webgpu_wait(ctx, futures, false); + ctx->global_ctx->instance.ProcessEvents(); + ggml_backend_webgpu_wait(ctx->global_ctx, futures, false); commands.clear(); } } if (!commands.empty()) { - webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands); + webgpu_submission_futures new_futures = + ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool); futures.push_back(new_futures); } - ggml_backend_webgpu_wait(ctx, futures); - ctx->inflight_threads--; - WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx); + ggml_backend_webgpu_wait(ctx->global_ctx, futures); + ctx->global_ctx->inflight_threads--; + WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -2159,8 +2215,8 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe // This is a trick to set all bytes of a u32 to the same 1 byte value. uint32_t val32 = (uint32_t) value * 0x01010101; - ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size); - WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->webgpu_ctx); + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset, size); + WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->global_ctx); } static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, @@ -2169,15 +2225,14 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, size_t offset, size_t size) { WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor); - ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; - webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; + ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; - webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); + buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); if (size % 4 != 0) { // If size is not a multiple of 4, we need to memset the remaining bytes @@ -2190,21 +2245,21 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i]; } // memset the remaining bytes - ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), - remaining_size); + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, + total_offset + (size - remaining_size), remaining_size); } else { // wait for WriteBuffer to complete - webgpu_ctx->instance.WaitAny( - webgpu_ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, + buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { if (status != wgpu::QueueWorkDoneStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); } }), - UINT64_MAX); + UINT64_MAX); } - WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, webgpu_ctx); + WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx); } static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, @@ -2216,8 +2271,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; - wgpu::Device device = webgpu_ctx->device; + wgpu::Device device = buf_ctx->global_ctx->device; size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; @@ -2227,42 +2281,45 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, final_size = size + (4 - (size % 4)); } - std::lock_guard lock(webgpu_ctx->mutex); + std::lock_guard lock(buf_ctx->global_ctx->mutex); - if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) { + if (buf_ctx->global_ctx->get_tensor_staging_buf == nullptr || + buf_ctx->global_ctx->get_tensor_staging_buf.GetSize() < final_size) { // Create a new staging buffer if it doesn't exist or is too small - if (webgpu_ctx->get_tensor_staging_buf) { - webgpu_ctx->get_tensor_staging_buf.Destroy(); + if (buf_ctx->global_ctx->get_tensor_staging_buf) { + buf_ctx->global_ctx->get_tensor_staging_buf.Destroy(); } - ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size, + ggml_webgpu_create_buffer(device, buf_ctx->global_ctx->get_tensor_staging_buf, final_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf"); } // Copy the data from the buffer to the staging buffer wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size); + encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, buf_ctx->global_ctx->get_tensor_staging_buf, 0, + final_size); wgpu::CommandBuffer commands = encoder.Finish(); // Submit the command buffer to the queue - webgpu_ctx->queue.Submit(1, &commands); + buf_ctx->global_ctx->queue.Submit(1, &commands); // Map the staging buffer to read the data - ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size); + ggml_backend_webgpu_map_buffer(buf_ctx->global_ctx, buf_ctx->global_ctx->get_tensor_staging_buf, + wgpu::MapMode::Read, 0, final_size); // Must specify size here since the staging buffer might be larger than the tensor size - const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size); + const void * mapped_range = buf_ctx->global_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size); // Copy the data from the mapped range to the output buffer std::memcpy(data, mapped_range, size); - webgpu_ctx->get_tensor_staging_buf.Unmap(); - WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, webgpu_ctx); + buf_ctx->global_ctx->get_tensor_staging_buf.Unmap(); + WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, buf_ctx->global_ctx); } static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")"); WEBGPU_CPU_PROFILE_TOTAL_START(clear); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; - ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size); - WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->webgpu_ctx); + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, value, 0, buffer->size); + WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->global_ctx); } static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { @@ -2292,28 +2349,30 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b int buffer_id = buffer_count++; std::string buf_name = "tensor_buf" + std::to_string(buffer_id); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes"); - ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); - wgpu::Buffer buf; - ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT), + ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); + wgpu::Buffer buf; + ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT), wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, buf_name.c_str()); ggml_backend_webgpu_buffer_context * buf_ctx = - new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf, buf_name); + new ggml_backend_webgpu_buffer_context(buf, buf_name, ctx->webgpu_global_ctx); return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size); } static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); - return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment; + ggml_backend_webgpu_device_context * dev_ctx = + static_cast(buft->device->context); + return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; } // maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding. static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { - ggml_backend_webgpu_device_context * ctx = static_cast(buft->device->context); - return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize; + ggml_backend_webgpu_device_context * dev_ctx = + static_cast(buft->device->context); + return dev_ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize; } static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, @@ -2322,7 +2381,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer size_t res = ggml_nbytes(tensor); switch (tensor->op) { case GGML_OP_ARGSORT: - res = ROUNDUP_POW2(res * 2 + ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment, + res = ROUNDUP_POW2(res * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, WEBGPU_STORAGE_BUF_BINDING_MULT); break; case GGML_OP_TOP_K: @@ -2330,8 +2389,9 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const ggml_tensor * src0 = tensor->src[0]; if (src0) { const size_t full = sizeof(int32_t) * ggml_nelements(src0); - res = ROUNDUP_POW2(full * 2 + ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment, - WEBGPU_STORAGE_BUF_BINDING_MULT); + res = ROUNDUP_POW2( + full * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, + WEBGPU_STORAGE_BUF_BINDING_MULT); } } break; @@ -2359,7 +2419,7 @@ static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); // TODO: for now, return maxBufferSize as both free and total memory // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates. - uint64_t max_buffer_size = ctx->webgpu_ctx->limits.maxBufferSize; + uint64_t max_buffer_size = ctx->webgpu_global_ctx->capabilities.limits.maxBufferSize; // If we're on a 32-bit system, clamp to UINTPTR_MAX #if UINTPTR_MAX < UINT64_MAX uint64_t max_ptr_size = static_cast(UINTPTR_MAX); @@ -2402,66 +2462,67 @@ static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_si return constants; } -static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { +static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { // we use the maximum workgroup size for the memset pipeline - size_t max_threads = WEBGPU_MAX_WG_SIZE * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension; + size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; // Size the bytes_per_thread so that the largest buffer size can be handled - webgpu_ctx->memset_bytes_per_thread = CEIL_DIV(webgpu_ctx->limits.maxStorageBufferBindingSize, max_threads); + ctx->capabilities.memset_bytes_per_thread = + CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads); std::vector constants(2); - constants[0].key = "wg_size"; - constants[0].value = WEBGPU_MAX_WG_SIZE; - constants[1].key = "bytes_per_thread"; - constants[1].value = webgpu_ctx->memset_bytes_per_thread; - webgpu_ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_memset, "memset", constants); + constants[0].key = "wg_size"; + constants[0].value = WEBGPU_MAX_WG_SIZE; + constants[1].key = "bytes_per_thread"; + constants[1].value = ctx->capabilities.memset_bytes_per_thread; + ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { // Q4/Q5/Q8 classic quantizations webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); // K-quantizations webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); // IQ quantizations (2-, 3-, 4-bit variants) webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); // 1-bit and 4-bit IQ variants webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); std::string proc_mul_mat_f32_f32; std::string proc_mul_mat_f32_f32_vec; @@ -2474,18 +2535,18 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { std::vector mul_mat_constants; #ifndef __EMSCRIPTEN__ - if (webgpu_ctx->supports_subgroup_matrix) { + if (webgpu_ctx->global_ctx->capabilities.supports_subgroup_matrix) { std::map sg_matrix_repls; - sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size); + sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = + std::to_string(webgpu_ctx->global_ctx->capabilities.max_subgroup_size); sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K); sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M); sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M); sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N); - sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m); - sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n); - sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k); - + sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_m); + sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_n); + sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_k); proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); @@ -2522,21 +2583,21 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { #endif webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); + webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); std::vector mul_mat_vec_constants(3); mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; @@ -2547,171 +2608,171 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); + webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); } static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants); + webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); } static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants); webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants); webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants); webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); } static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32, "add_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32, "add_f32", constants); webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16, "add_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16, "add_f16", constants); webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants); webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants); } static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32, "sub_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32, "sub_f32", constants); webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16, "sub_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16, "sub_f16", constants); webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants); webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants); } static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32, "mul_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32, "mul_f32", constants); webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16, "mul_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16, "mul_f16", constants); webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants); webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants); } static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32, "div_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32, "div_f32", constants); webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16, "div_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16, "div_f16", constants); webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants); webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants); } static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); webgpu_ctx->rms_norm_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm, "rms_norm", constants); - webgpu_ctx->rms_norm_pipelines[1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants); + webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants); } static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32, "rope_f32", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants); webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16, "rope_f16", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants); webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants); + webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); } static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { @@ -2719,68 +2780,68 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { // REGLU webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32, "reglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16, "reglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants); // GEGLU webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32, "geglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16, "geglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants); // SWIGLU webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants); // SWIGLU_OAI webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); // GEGLU_ERF webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); // GEGLU_QUICK webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants); webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); + webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); } static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); webgpu_ctx->scale_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32, "scale_f32", constants); - webgpu_ctx->scale_pipelines[1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32_inplace, "scale_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32, "scale_f32", constants); + webgpu_ctx->scale_pipelines[1] = ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32_inplace, + "scale_f32_inplace", constants); } static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { @@ -2788,56 +2849,243 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { // f32 (no mask) webgpu_ctx->soft_max_pipelines[2][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants); - webgpu_ctx->soft_max_pipelines[2][0][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[2][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants); + webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants); + webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants); webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); // f32 mask (mask_type = 0) - webgpu_ctx->soft_max_pipelines[0][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants); + webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants); webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[0][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); + webgpu_ctx->soft_max_pipelines[0][1][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, + "soft_max_f32_mask_f32_sink_inplace", constants); // f16 mask (mask_type = 1) - webgpu_ctx->soft_max_pipelines[1][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants); + webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline( + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants); webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); - webgpu_ctx->soft_max_pipelines[1][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants); + webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); + webgpu_ctx->soft_max_pipelines[1][1][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, + "soft_max_f32_mask_f16_sink_inplace", constants); +} + +static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { + wgpu::RequestAdapterOptions options = {}; + +#ifndef __EMSCRIPTEN__ + // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 + const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; + wgpu::DawnTogglesDescriptor adapterTogglesDesc; + adapterTogglesDesc.enabledToggles = adapterEnabledToggles; + adapterTogglesDesc.enabledToggleCount = 2; + options.nextInChain = &adapterTogglesDesc; +#endif + + ctx->webgpu_global_ctx->instance.WaitAny( + ctx->webgpu_global_ctx->instance.RequestAdapter( + &options, wgpu::CallbackMode::AllowSpontaneous, + [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + ctx->webgpu_global_ctx->adapter = std::move(adapter); + }), + UINT64_MAX); + GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr); + + ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits); + + wgpu::AdapterInfo info{}; +#ifndef __EMSCRIPTEN__ + wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; + if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + info.nextInChain = &subgroup_matrix_configs; + } +#endif + ctx->webgpu_global_ctx->adapter.GetInfo(&info); + wgpu::SupportedFeatures features; + ctx->webgpu_global_ctx->adapter.GetFeatures(&features); + // we require f16 support + GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); + +#ifndef __EMSCRIPTEN__ + // Only support square f16 matrices of size 8 or 16 for now + bool valid_subgroup_matrix_config = false; + if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { + const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; + if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && + config.componentType == wgpu::SubgroupMatrixComponentType::F16 && + config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { + ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M; + ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N; + ctx->webgpu_global_ctx->capabilities.sg_mat_k = config.K; + valid_subgroup_matrix_config = true; + break; + } + } + } + ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; +#endif + + // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. + // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. + ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize; + // Initialize device + std::vector required_features = { wgpu::FeatureName::ShaderF16 }; + +#ifndef __EMSCRIPTEN__ + required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); + if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { + required_features.push_back(wgpu::FeatureName::Subgroups); + required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); + } +#endif + +#ifdef GGML_WEBGPU_GPU_PROFILE + required_features.push_back(wgpu::FeatureName::TimestampQuery); +#endif + + wgpu::DeviceDescriptor dev_desc; + dev_desc.requiredLimits = &ctx->webgpu_global_ctx->capabilities.limits; + dev_desc.requiredFeatures = required_features.data(); + dev_desc.requiredFeatureCount = required_features.size(); + dev_desc.SetDeviceLostCallback( + wgpu::CallbackMode::AllowSpontaneous, + [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + GGML_UNUSED(device); + GGML_UNUSED(reason); + GGML_UNUSED(message); + //TODO: uncomment once proper free logic is in place + //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), + //std::string(message).c_str()); + }); + dev_desc.SetUncapturedErrorCallback( + [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { + GGML_UNUSED(device); + GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), + std::string(message).c_str()); + }); + +#ifndef __EMSCRIPTEN__ + // Enable Dawn-specific toggles to increase native performance + // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, + // only for native performance? + const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", + "disable_polyfills_on_integer_div_and_mod" }; + const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; + wgpu::DawnTogglesDescriptor deviceTogglesDesc; + deviceTogglesDesc.enabledToggles = deviceEnabledToggles; + deviceTogglesDesc.enabledToggleCount = 4; + deviceTogglesDesc.disabledToggles = deviceDisabledToggles; + deviceTogglesDesc.disabledToggleCount = 1; + + dev_desc.nextInChain = &deviceTogglesDesc; +#endif + + ctx->webgpu_global_ctx->instance.WaitAny( + ctx->webgpu_global_ctx->adapter.RequestDevice( + &dev_desc, wgpu::CallbackMode::AllowSpontaneous, + [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { + if (status != wgpu::RequestDeviceStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str()); + return; + } + ctx->webgpu_global_ctx->device = std::move(device); + }), + UINT64_MAX); + GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr); + + ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx); + ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue(); + +#ifdef GGML_WEBGPU_GPU_PROFILE + // Initialize buffer pool for timestamp queries, used for profiling + ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, + WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, + wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); +#endif + + GGML_LOG_INFO( + "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " + "device_desc: %s\n", + info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID, + std::string(info.device).c_str(), std::string(info.description).c_str()); + return true; +} + +static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { + ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; + webgpu_context webgpu_ctx = std::make_shared(); + webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; + webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); + + ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); + ggml_webgpu_init_get_rows_pipeline(webgpu_ctx); + ggml_webgpu_init_cpy_pipeline(webgpu_ctx); + ggml_webgpu_init_add_pipeline(webgpu_ctx); + ggml_webgpu_init_sub_pipeline(webgpu_ctx); + ggml_webgpu_init_mul_pipeline(webgpu_ctx); + ggml_webgpu_init_div_pipeline(webgpu_ctx); + ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); + ggml_webgpu_init_rope_pipeline(webgpu_ctx); + ggml_webgpu_init_glu_pipeline(webgpu_ctx); + ggml_webgpu_init_scale_pipeline(webgpu_ctx); + ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); +#ifdef GGML_WEBGPU_DEBUG + // Initialize debug buffers + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf, + WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_dev_buf, + WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf"); +#endif + return webgpu_ctx; } -// TODO: move most initialization logic here -static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { +static ggml_backend_t ggml_backend_webgpu_backend_init(ggml_backend_dev_t dev, const char * params) { GGML_UNUSED(params); - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()"); + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_backend_init()"); - ggml_backend_webgpu_device_context * dev_ctx = static_cast(dev->context); - webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx; + ggml_backend_webgpu_device_context * dev_ctx = static_cast(dev->context); - static ggml_backend_webgpu_context backend_ctx; - backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; - backend_ctx.webgpu_ctx = webgpu_ctx; + auto * backend_ctx = new ggml_backend_webgpu_context(); + backend_ctx->name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; + backend_ctx->webgpu_ctx = initialize_webgpu_context(dev); // See GGML Backend Interface section - static ggml_backend backend = { + auto * backend = new ggml_backend(); + *backend = { /* .guid = */ ggml_backend_webgpu_guid(), /* .interface = */ ggml_backend_webgpu_i, /* .device = */ dev, - /* .context = */ &backend_ctx, + /* .context = */ backend_ctx, }; - return &backend; + return backend; } static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) { @@ -2854,7 +3102,8 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm }, /* .device = */ dev, - /* .context = */ NULL, + /* .context = */ + NULL }; return &ggml_backend_webgpu_buffer_type; @@ -2895,16 +3144,16 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) { static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { ggml_backend_webgpu_device_context * ctx = static_cast(dev->context); - webgpu_context webgpu_ctx = ctx->webgpu_ctx; - ggml_tensor * src0 = op->src[0]; ggml_tensor * src1 = op->src[1]; ggml_tensor * src2 = op->src[2]; // on smaller devices (or CI), tensors may be larger than the max storage buffer size - if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || - (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || - (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) { + if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize || + (src0 != nullptr && + ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) || + (src1 != nullptr && + ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) { return false; } @@ -2984,17 +3233,19 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } case GGML_OP_FLASH_ATTN_EXT: { - if (!webgpu_ctx->supports_subgroup_matrix) { +#ifndef __EMSCRIPTEN__ + if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { break; } // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize; + size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; const bool has_mask = op->src[3] != nullptr; - const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 && + const bool kv_direct = src1->type == GGML_TYPE_F16 && + (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], - has_mask, kv_direct); + ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n, + (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct); if (min_bytes > limit_bytes) { break; } @@ -3003,6 +3254,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && src2->type == src1->type && op->type == GGML_TYPE_F32; +#endif break; } case GGML_OP_RMS_NORM: @@ -3099,10 +3351,13 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const default: break; } - if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || - (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || - (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) || - (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) { + if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize || + (src0 != nullptr && + ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) || + (src1 != nullptr && + ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) || + (src2 != nullptr && + ggml_nbytes(src2) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) { supports_op = false; WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: "); } @@ -3127,7 +3382,7 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = { /* .get_memory = */ ggml_backend_webgpu_device_get_memory, /* .get_type = */ ggml_backend_webgpu_device_get_type, /* .get_props = */ ggml_backend_webgpu_device_get_props, - /* .init_backend = */ ggml_backend_webgpu_device_init, + /* .init_backend = */ ggml_backend_webgpu_backend_init, /* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type, /* .get_host_buffer_type = */ NULL, /* .buffer_from_host_ptr = */ NULL, @@ -3156,6 +3411,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) { // TODO: Does this need to be thread safe? Is it only called once? // TODO: move most logic to device_init function so backend can be freed/initialized properly // Only one device is supported for now + static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) { GGML_ASSERT(index == 0); WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()"); @@ -3164,189 +3420,12 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_backend_webgpu_reg_context * reg_ctx = static_cast(reg->context); - webgpu_context ctx = reg_ctx->webgpu_ctx; - - wgpu::RequestAdapterOptions options = {}; - -#ifndef __EMSCRIPTEN__ - // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 - const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; - wgpu::DawnTogglesDescriptor adapterTogglesDesc; - adapterTogglesDesc.enabledToggles = adapterEnabledToggles; - adapterTogglesDesc.enabledToggleCount = 2; - options.nextInChain = &adapterTogglesDesc; -#endif - - ctx->instance.WaitAny(ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, - [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - ctx->adapter = std::move(adapter); - }), - UINT64_MAX); - GGML_ASSERT(ctx->adapter != nullptr); - - ctx->adapter.GetLimits(&ctx->limits); - - wgpu::AdapterInfo info{}; -#ifndef __EMSCRIPTEN__ - wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; - if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { - info.nextInChain = &subgroup_matrix_configs; - } -#endif - ctx->adapter.GetInfo(&info); - - wgpu::SupportedFeatures features; - ctx->adapter.GetFeatures(&features); - // we require f16 support - GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); - -#ifndef __EMSCRIPTEN__ - // Only support square f16 matrices of size 8 or 16 for now - bool valid_subgroup_matrix_config = false; - if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { - for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { - const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; - if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && - config.componentType == wgpu::SubgroupMatrixComponentType::F16 && - config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { - ctx->sg_mat_m = config.M; - ctx->sg_mat_n = config.N; - ctx->sg_mat_k = config.K; - valid_subgroup_matrix_config = true; - break; - } - } - } - - ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; -#endif - // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. - // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. - ctx->max_subgroup_size = info.subgroupMaxSize; - - // Initialize device - std::vector required_features = { wgpu::FeatureName::ShaderF16 }; - -#ifndef __EMSCRIPTEN__ - required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); - if (ctx->supports_subgroup_matrix) { - required_features.push_back(wgpu::FeatureName::Subgroups); - required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); - } -#endif - -#ifdef GGML_WEBGPU_GPU_PROFILE - required_features.push_back(wgpu::FeatureName::TimestampQuery); -#endif - - wgpu::DeviceDescriptor dev_desc; - dev_desc.requiredLimits = &ctx->limits; - dev_desc.requiredFeatures = required_features.data(); - dev_desc.requiredFeatureCount = required_features.size(); - dev_desc.SetDeviceLostCallback( - wgpu::CallbackMode::AllowSpontaneous, - [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { - GGML_UNUSED(device); - GGML_UNUSED(reason); - GGML_UNUSED(message); - //TODO: uncomment once proper free logic is in place - //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), - //std::string(message).c_str()); - }); - dev_desc.SetUncapturedErrorCallback( - [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { - GGML_UNUSED(device); - GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast(reason), - std::string(message).c_str()); - }); - -#ifndef __EMSCRIPTEN__ - // Enable Dawn-specific toggles to increase native performance - // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, - // only for native performance? - const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", - "disable_polyfills_on_integer_div_and_mod" }; - const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; - wgpu::DawnTogglesDescriptor deviceTogglesDesc; - deviceTogglesDesc.enabledToggles = deviceEnabledToggles; - deviceTogglesDesc.enabledToggleCount = 4; - deviceTogglesDesc.disabledToggles = deviceDisabledToggles; - deviceTogglesDesc.disabledToggleCount = 1; - - dev_desc.nextInChain = &deviceTogglesDesc; -#endif - - ctx->instance.WaitAny(ctx->adapter.RequestDevice( - &dev_desc, wgpu::CallbackMode::AllowSpontaneous, - [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { - if (status != wgpu::RequestDeviceStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", - std::string(message).c_str()); - return; - } - ctx->device = std::move(device); - }), - UINT64_MAX); - GGML_ASSERT(ctx->device != nullptr); - - // Initialize (compute) queue - ctx->queue = ctx->device.GetQueue(); - - // Create buffer pool for shader parameters - ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); - -#ifdef GGML_WEBGPU_GPU_PROFILE - // Initialize buffer pool for timestamp queries (profiling) - ctx->timestamp_query_buf_pool.init(ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, - WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, - wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, - wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); -#endif - - ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); - - ggml_webgpu_init_memset_pipeline(ctx); - ggml_webgpu_init_mul_mat_pipeline(ctx); - ggml_webgpu_init_get_rows_pipeline(ctx); - ggml_webgpu_init_cpy_pipeline(ctx); - ggml_webgpu_init_add_pipeline(ctx); - ggml_webgpu_init_sub_pipeline(ctx); - ggml_webgpu_init_mul_pipeline(ctx); - ggml_webgpu_init_div_pipeline(ctx); - ggml_webgpu_init_rms_norm_pipeline(ctx); - ggml_webgpu_init_rope_pipeline(ctx); - ggml_webgpu_init_glu_pipeline(ctx); - ggml_webgpu_init_scale_pipeline(ctx); - ggml_webgpu_init_soft_max_pipeline(ctx); - -#ifdef GGML_WEBGPU_DEBUG - // Initialize debug buffers - ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf"); - ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), - wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf"); -#endif + create_webgpu_device(reg_ctx); static ggml_backend_webgpu_device_context device_ctx; - device_ctx.webgpu_ctx = ctx; - device_ctx.device_name = GGML_WEBGPU_NAME; - device_ctx.device_desc = info.description; - - GGML_LOG_INFO( - "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " - "device_desc: %s\n", - info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID, - std::string(info.device).c_str(), std::string(info.description).c_str()); - + device_ctx.device_name = GGML_WEBGPU_NAME; + device_ctx.device_desc = GGML_WEBGPU_NAME; + device_ctx.webgpu_global_ctx = reg_ctx->webgpu_global_ctx; // See GGML Backend Device Interface section static ggml_backend_device device = { /* .iface = */ ggml_backend_webgpu_device_i, @@ -3354,7 +3433,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t /* .context = */ &device_ctx, }; - WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, ctx); + WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, reg_ctx->webgpu_global_ctx); return &device; } @@ -3370,10 +3449,7 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { ggml_backend_reg_t ggml_backend_webgpu_reg() { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); - webgpu_context webgpu_ctx = std::make_shared(); - static ggml_backend_webgpu_reg_context ctx; - ctx.webgpu_ctx = webgpu_ctx; ctx.name = GGML_WEBGPU_NAME; ctx.device_count = 1; @@ -3390,15 +3466,17 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { instance_descriptor.nextInChain = &instanceTogglesDesc; #endif - webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); + wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor); + ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); + ctx.webgpu_global_ctx->instance = std::move(inst); #ifdef __EMSCRIPTEN__ - if (webgpu_ctx->instance == nullptr) { + if (ctx.webgpu_global_ctx->instance == nullptr) { GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n"); return nullptr; } #endif - GGML_ASSERT(webgpu_ctx->instance != nullptr); + GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr); static ggml_backend_reg reg = { /* .api_version = */ GGML_BACKEND_API_VERSION, @@ -3411,7 +3489,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ggml_backend_t ggml_backend_webgpu_init(void) { ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0); - return ggml_backend_webgpu_device_init(dev, nullptr); + return ggml_backend_webgpu_backend_init(dev, nullptr); } GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg) From f28a7330257a073d443429c82f558b0492603391 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 28 Jan 2026 09:15:11 +0200 Subject: [PATCH 052/831] CUDA: tune GLM 4.7 Flash FA kernel selection logic (DGX Spark) (llama/19142) --- ggml/src/ggml-cuda/common.cuh | 1 + ggml/src/ggml-cuda/fattn.cu | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 09a491a836a..3335f443aeb 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -53,6 +53,7 @@ // While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms #define GGML_CUDA_CC_BLACKWELL 1200 +#define GGML_CUDA_CC_DGX_SPARK 1210 #define GGML_CUDA_CC_RUBIN 1300 #define GGML_CUDA_CC_OFFSET_AMD 0x1000000 #define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000 diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index b061fdf9a24..fe18ff6c7dc 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -147,6 +147,14 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; if (gqa_ratio == 20) { // GLM 4.7 Flash + if (cc >= GGML_CUDA_CC_DGX_SPARK) { + if (Q->ne[1] <= 8) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } if (cc >= GGML_CUDA_CC_BLACKWELL) { if (Q->ne[1] <= 4 && K->ne[1] >= 65536) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); From 7fb0f823de760f1be8e6d951516d3adc95b8bb61 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 28 Jan 2026 09:15:27 +0200 Subject: [PATCH 053/831] cuda : fix "V is K view" check for non-unified KV cache (llama/19145) --- ggml/src/ggml-cuda/fattn-common.cuh | 2 +- ggml/src/ggml-cuda/fattn.cu | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 3d7daccfdf8..b6a7460da83 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -789,7 +789,7 @@ void launch_fattn( const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; - const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src); + const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); const ggml_tensor * mask = dst->src[3]; const ggml_tensor * sinks = dst->src[4]; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index fe18ff6c7dc..195904ee206 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -310,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } } - const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src); + const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); const int cc = ggml_cuda_info().devices[device].cc; From 3701413a713bd85f0219d19d45c0503a6f8708c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alberto=20Cabrera=20P=C3=A9rez?= <1478977+Alcpz@users.noreply.github.com> Date: Wed, 28 Jan 2026 07:15:56 +0000 Subject: [PATCH 054/831] ggml-cpu: arm64: Q4_K scale unroll and vectorization (llama/19108) --- ggml/src/ggml-cpu/arch/arm/repack.cpp | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index f40226494cd..99bb70274c5 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -3148,16 +3148,17 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, // Scales[i] corresponds to column i const int scale_offset = cp * 2; - for (int blk = 0; blk < 2; blk++) { - const int32x4_t block_scale = { - (int32_t) q4sb_scales[blk][scale_offset], - (int32_t) q4sb_scales[blk][scale_offset], - (int32_t) q4sb_scales[blk][scale_offset + 1], - (int32_t) q4sb_scales[blk][scale_offset + 1], - }; - acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale); - acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale); - } + const int32_t scale_00 = q4sb_scales[0][scale_offset]; + const int32_t scale_01 = q4sb_scales[0][scale_offset + 1]; + const int32_t scale_10 = q4sb_scales[1][scale_offset]; + const int32_t scale_11 = q4sb_scales[1][scale_offset + 1]; + const int32x4_t block_scale_0 = vcombine_s32(vdup_n_s32(scale_00), vdup_n_s32(scale_01)); + const int32x4_t block_scale_1 = vcombine_s32(vdup_n_s32(scale_10), vdup_n_s32(scale_11)); + + acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale_0); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale_0); + acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale_1); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale_1); } // Multiply Acc bsum + mins From 531d7b6781bd57e20f157d7d37dfe5c0b18a4341 Mon Sep 17 00:00:00 2001 From: Kevin Pouget Date: Wed, 28 Jan 2026 10:49:40 +0100 Subject: [PATCH 055/831] ggml: new backend for Virglrenderer API Remoting acceleration (v2) (llama/18718) --- ggml/CMakeLists.txt | 3 + ggml/include/ggml-virtgpu.h | 16 + ggml/src/CMakeLists.txt | 1 + ggml/src/ggml-backend-reg.cpp | 14 + ggml/src/ggml-virtgpu/CMakeLists.txt | 70 +++ .../ggml-virtgpu/apir_cs_ggml-rpc-front.cpp | 87 +++ ggml/src/ggml-virtgpu/backend/CMakeLists.txt | 21 + .../backend/apir_cs_ggml-rpc-back.cpp | 115 ++++ .../ggml-virtgpu/backend/backend-convert.h | 13 + .../backend/backend-dispatched-backend.cpp | 65 +++ .../backend-dispatched-buffer-type.cpp | 89 ++++ .../backend/backend-dispatched-buffer.cpp | 131 +++++ .../backend/backend-dispatched-device.cpp | 148 ++++++ .../backend/backend-dispatched.cpp | 46 ++ .../backend/backend-dispatched.gen.h | 130 +++++ .../ggml-virtgpu/backend/backend-dispatched.h | 23 + .../ggml-virtgpu/backend/backend-virgl-apir.h | 32 ++ ggml/src/ggml-virtgpu/backend/backend.cpp | 148 ++++++ .../backend/shared/api_remoting.h | 90 ++++ .../backend/shared/apir_backend.gen.h | 36 ++ .../backend/shared/apir_backend.h | 46 ++ .../src/ggml-virtgpu/backend/shared/apir_cs.h | 383 ++++++++++++++ .../backend/shared/apir_cs_ggml.h | 211 ++++++++ .../ggml-virtgpu/backend/shared/apir_cs_rpc.h | 54 ++ .../ggml-virtgpu/ggml-backend-buffer-type.cpp | 98 ++++ ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp | 119 +++++ ggml/src/ggml-virtgpu/ggml-backend-device.cpp | 144 +++++ ggml/src/ggml-virtgpu/ggml-backend-reg.cpp | 137 +++++ ggml/src/ggml-virtgpu/ggml-backend.cpp | 69 +++ ggml/src/ggml-virtgpu/ggml-remoting.h | 68 +++ .../ggml-virtgpu/ggmlremoting_functions.yaml | 168 ++++++ ggml/src/ggml-virtgpu/include/apir_hw.h | 9 + ggml/src/ggml-virtgpu/regenerate_remoting.py | 322 +++++++++++ ggml/src/ggml-virtgpu/virtgpu-apir.h | 15 + .../ggml-virtgpu/virtgpu-forward-backend.cpp | 50 ++ .../virtgpu-forward-buffer-type.cpp | 125 +++++ .../ggml-virtgpu/virtgpu-forward-buffer.cpp | 157 ++++++ .../ggml-virtgpu/virtgpu-forward-device.cpp | 200 +++++++ ggml/src/ggml-virtgpu/virtgpu-forward-impl.h | 29 + ggml/src/ggml-virtgpu/virtgpu-forward.gen.h | 51 ++ ggml/src/ggml-virtgpu/virtgpu-shm.cpp | 99 ++++ ggml/src/ggml-virtgpu/virtgpu-shm.h | 23 + ggml/src/ggml-virtgpu/virtgpu-utils.cpp | 179 +++++++ ggml/src/ggml-virtgpu/virtgpu-utils.h | 86 +++ ggml/src/ggml-virtgpu/virtgpu.cpp | 498 ++++++++++++++++++ ggml/src/ggml-virtgpu/virtgpu.h | 92 ++++ 46 files changed, 4710 insertions(+) create mode 100644 ggml/include/ggml-virtgpu.h create mode 100644 ggml/src/ggml-virtgpu/CMakeLists.txt create mode 100644 ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp create mode 100644 ggml/src/ggml-virtgpu/backend/CMakeLists.txt create mode 100644 ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp create mode 100644 ggml/src/ggml-virtgpu/backend/backend-convert.h create mode 100644 ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp create mode 100644 ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp create mode 100644 ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp create mode 100644 ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp create mode 100644 ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp create mode 100644 ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h create mode 100644 ggml/src/ggml-virtgpu/backend/backend-dispatched.h create mode 100644 ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h create mode 100644 ggml/src/ggml-virtgpu/backend/backend.cpp create mode 100644 ggml/src/ggml-virtgpu/backend/shared/api_remoting.h create mode 100644 ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h create mode 100644 ggml/src/ggml-virtgpu/backend/shared/apir_backend.h create mode 100644 ggml/src/ggml-virtgpu/backend/shared/apir_cs.h create mode 100644 ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h create mode 100644 ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h create mode 100644 ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp create mode 100644 ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp create mode 100644 ggml/src/ggml-virtgpu/ggml-backend-device.cpp create mode 100644 ggml/src/ggml-virtgpu/ggml-backend-reg.cpp create mode 100644 ggml/src/ggml-virtgpu/ggml-backend.cpp create mode 100644 ggml/src/ggml-virtgpu/ggml-remoting.h create mode 100644 ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml create mode 100644 ggml/src/ggml-virtgpu/include/apir_hw.h create mode 100755 ggml/src/ggml-virtgpu/regenerate_remoting.py create mode 100644 ggml/src/ggml-virtgpu/virtgpu-apir.h create mode 100644 ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp create mode 100644 ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp create mode 100644 ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp create mode 100644 ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp create mode 100644 ggml/src/ggml-virtgpu/virtgpu-forward-impl.h create mode 100644 ggml/src/ggml-virtgpu/virtgpu-forward.gen.h create mode 100644 ggml/src/ggml-virtgpu/virtgpu-shm.cpp create mode 100644 ggml/src/ggml-virtgpu/virtgpu-shm.h create mode 100644 ggml/src/ggml-virtgpu/virtgpu-utils.cpp create mode 100644 ggml/src/ggml-virtgpu/virtgpu-utils.h create mode 100644 ggml/src/ggml-virtgpu/virtgpu.cpp create mode 100644 ggml/src/ggml-virtgpu/virtgpu.h diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 0176ca1ce93..b0b8e57898c 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -228,6 +228,8 @@ option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU) option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF) option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON) option(GGML_ZDNN "ggml: use zDNN" OFF) +option(GGML_VIRTGPU "ggml: use the VirtGPU/Virglrenderer API Remoting frontend" OFF) +option(GGML_VIRTGPU_BACKEND "ggml: build the VirtGPU/Virglrenderer API Remoting backend" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF) @@ -320,6 +322,7 @@ set(GGML_PUBLIC_HEADERS include/ggml-opt.h include/ggml-metal.h include/ggml-rpc.h + include/ggml-virtgpu.h include/ggml-sycl.h include/ggml-vulkan.h include/ggml-webgpu.h diff --git a/ggml/include/ggml-virtgpu.h b/ggml/include/ggml-virtgpu.h new file mode 100644 index 00000000000..1cb4bd7a038 --- /dev/null +++ b/ggml/include/ggml-virtgpu.h @@ -0,0 +1,16 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_REMOTING_FRONTEND_NAME "RemotingFrontend" + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_virtgpu_reg(); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 6192a870466..260ad48f0e8 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -451,6 +451,7 @@ ggml_add_backend(HIP) ggml_add_backend(METAL) ggml_add_backend(MUSA) ggml_add_backend(RPC) +ggml_add_backend(VirtGPU) ggml_add_backend(SYCL) ggml_add_backend(Vulkan) ggml_add_backend(WebGPU) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 6bee1bc4b49..dd991f262e6 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -69,6 +69,10 @@ #include "ggml-rpc.h" #endif +#ifdef GGML_USE_VIRTGPU_FRONTEND +#include "ggml-virtgpu.h" +#endif + #ifdef GGML_USE_CANN #include "ggml-cann.h" #endif @@ -180,7 +184,12 @@ struct ggml_backend_registry { register_backend(ggml_backend_sycl_reg()); #endif #ifdef GGML_USE_VULKAN + // Add runtime disable check + if (getenv("GGML_DISABLE_VULKAN") == nullptr) { register_backend(ggml_backend_vk_reg()); + } else { + GGML_LOG_DEBUG("Vulkan backend disabled by GGML_DISABLE_VULKAN environment variable\n"); + } #endif #ifdef GGML_USE_WEBGPU register_backend(ggml_backend_webgpu_reg()); @@ -188,6 +197,10 @@ struct ggml_backend_registry { #ifdef GGML_USE_ZDNN register_backend(ggml_backend_zdnn_reg()); #endif +#ifdef GGML_USE_VIRTGPU_FRONTEND + register_backend(ggml_backend_virtgpu_reg()); +#endif + #ifdef GGML_USE_OPENCL register_backend(ggml_backend_opencl_reg()); #endif @@ -604,6 +617,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) { ggml_backend_load_best("rpc", silent, dir_path); ggml_backend_load_best("sycl", silent, dir_path); ggml_backend_load_best("vulkan", silent, dir_path); + ggml_backend_load_best("virtgpu", silent, dir_path); ggml_backend_load_best("opencl", silent, dir_path); ggml_backend_load_best("hexagon", silent, dir_path); ggml_backend_load_best("musa", silent, dir_path); diff --git a/ggml/src/ggml-virtgpu/CMakeLists.txt b/ggml/src/ggml-virtgpu/CMakeLists.txt new file mode 100644 index 00000000000..e6b020beb5b --- /dev/null +++ b/ggml/src/ggml-virtgpu/CMakeLists.txt @@ -0,0 +1,70 @@ +cmake_minimum_required(VERSION 3.19) +cmake_policy(SET CMP0114 NEW) + +include(ExternalProject) + +message(STATUS "Including the VirtGPU/Virglrenderer API Remoting") + +# Download venus_hw.h from virglrenderer repository +ExternalProject_Add( + venus_hw_header + URL https://gitlab.freedesktop.org/virgl/virglrenderer/-/raw/virglrenderer-1.2.0/src/venus_hw.h + DOWNLOAD_NO_EXTRACT YES + DOWNLOAD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include + DOWNLOAD_NAME venus_hw.h + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + LOG_DOWNLOAD ON +) + +if (NOT GGML_VIRTGPU_BACKEND STREQUAL "ONLY") + message(STATUS "Enable the VirtGPU/Virglrenderer API Remoting frontend library") + + find_package(PkgConfig REQUIRED) + pkg_check_modules(DRM REQUIRED libdrm) + if (NOT GGML_BACKEND_DL) + # cannot simply use USE_VIRTGPU, as in the 'else()' case the + # frontend isn't compiled + target_compile_definitions(ggml PUBLIC "GGML_USE_VIRTGPU_FRONTEND") + endif() + + ggml_add_backend_library(ggml-virtgpu + ggml-backend-buffer.cpp + ggml-backend.cpp + ggml-backend-device.cpp + ggml-backend-reg.cpp + ggml-backend-buffer-type.cpp + virtgpu-apir.h + virtgpu-forward.gen.h + virtgpu.cpp + virtgpu-shm.cpp + virtgpu-utils.cpp + virtgpu-forward-device.cpp + virtgpu-forward-buffer-type.cpp + virtgpu-forward-buffer.cpp + virtgpu-forward-backend.cpp + virtgpu-forward-impl.h + apir_cs_ggml-rpc-front.cpp + ../../include/ggml-virtgpu.h) + + target_include_directories(ggml-virtgpu PUBLIC /usr/include/libdrm/) + + target_link_libraries(ggml-virtgpu PUBLIC ${DRM_LIBRARIES}) + target_include_directories(ggml-virtgpu PUBLIC ${DRM_INCLUDE_DIRS}) + target_compile_options(ggml-virtgpu PUBLIC ${DRM_CFLAGS_OTHER}) + + target_include_directories(ggml-virtgpu PUBLIC ./include) + target_include_directories(ggml-virtgpu PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + + # Ensure venus_hw.h is downloaded before building ggml-virtgpu + add_dependencies(ggml-virtgpu venus_hw_header) + + target_compile_options(ggml-virtgpu PRIVATE -std=c++20) +else() + message(STATUS "Not building the VirtGPU/Virglrenderer API Remoting frontend library") +endif() + +if (NOT GGML_VIRTGPU_BACKEND STREQUAL "OFF") + add_subdirectory("backend") +endif() diff --git a/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp b/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp new file mode 100644 index 00000000000..f60ae3556ca --- /dev/null +++ b/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp @@ -0,0 +1,87 @@ +#include "backend/shared/apir_cs_rpc.h" +#include "ggml-backend-impl.h" +#include "ggml-impl.h" +#include "ggml-remoting.h" + +#include +#include +#include +#include + +apir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor) { + apir_rpc_tensor result; + result.id = reinterpret_cast(tensor); + result.type = tensor->type; + if (tensor->buffer) { + ggml_backend_buffer_t buffer = tensor->buffer; + + result.buffer = BUFFER_TO_HOST_HANDLE(buffer); + } else { + result.buffer = 0; + } + for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { + result.ne[i] = tensor->ne[i]; + result.nb[i] = tensor->nb[i]; + } + result.op = tensor->op; + for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { + result.op_params[i] = tensor->op_params[i]; + } + result.flags = tensor->flags; + for (uint32_t i = 0; i < GGML_MAX_SRC; i++) { + result.src[i] = reinterpret_cast(tensor->src[i]); + } + result.view_src = reinterpret_cast(tensor->view_src); + result.view_offs = tensor->view_offs; + result.data = reinterpret_cast(tensor->data); + if (tensor->data) { + if (!tensor->buffer) { + GGML_ABORT("tensor has data but not buffer"); + } + // tensor->data is serialized as an offset to the buffer base address + result.data -= reinterpret_cast(BUFFER_TO_GGML_CONTEXT(tensor->buffer)->base); + } + snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name); + return result; +} + +void apir_add_tensor(ggml_tensor * tensor, + std::vector & tensors, + std::unordered_set & visited) { + if (tensor == nullptr) { + return; + } + if (visited.find(tensor) != visited.end()) { + return; + } + visited.insert(tensor); + for (int i = 0; i < GGML_MAX_SRC; i++) { + apir_add_tensor(tensor->src[i], tensors, visited); + } + apir_add_tensor(tensor->view_src, tensors, visited); + tensors.push_back(apir_serialize_tensor(tensor)); +} + +void apir_serialize_graph(const ggml_cgraph * cgraph, std::vector & output) { + uint32_t n_nodes = cgraph->n_nodes; + std::vector tensors; + std::unordered_set visited; + for (uint32_t i = 0; i < n_nodes; i++) { + apir_add_tensor(cgraph->nodes[i], tensors, visited); + } + // serialization format: + // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(apir_rpc_tensor)) | + uint32_t n_tensors = tensors.size(); + int output_size = + sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(apir_rpc_tensor); + output.resize(output_size, 0); + memcpy(output.data(), &n_nodes, sizeof(n_nodes)); + for (uint32_t i = 0; i < n_nodes; i++) { + memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); + } + uint32_t * out_ntensors = (uint32_t *) (output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t)); + *out_ntensors = n_tensors; + apir_rpc_tensor * out_tensors = + (apir_rpc_tensor *) (output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t)); + memcpy(out_tensors, tensors.data(), n_tensors * sizeof(apir_rpc_tensor)); +} diff --git a/ggml/src/ggml-virtgpu/backend/CMakeLists.txt b/ggml/src/ggml-virtgpu/backend/CMakeLists.txt new file mode 100644 index 00000000000..0b49c403b9a --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/CMakeLists.txt @@ -0,0 +1,21 @@ +cmake_minimum_required(VERSION 3.19) +cmake_policy(SET CMP0114 NEW) + +message(STATUS "Enable the VirtGPU/Virglrenderer backend library") + +ggml_add_backend_library(ggml-virtgpu-backend + backend.cpp + backend-dispatched.cpp + backend-dispatched-backend.cpp + backend-dispatched-device.cpp + backend-dispatched-buffer.cpp + backend-dispatched-buffer-type.cpp + shared/api_remoting.h + shared/apir_backend.h + shared/apir_cs.h + apir_cs_ggml-rpc-back.cpp) + +target_compile_options(ggml-virtgpu-backend PRIVATE -std=c++20) + +# Add include directory for ggml-backend-impl.h and other core headers +target_include_directories(ggml-virtgpu-backend PRIVATE ../..) diff --git a/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp b/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp new file mode 100644 index 00000000000..60a8a93bfb8 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp @@ -0,0 +1,115 @@ +#include "ggml-backend-impl.h" +#include "ggml-impl.h" +#include "shared/apir_cs_rpc.h" + +#include +#include +#include +#include + +std::unordered_set backend_buffers; + +void apir_track_backend_buffer(ggml_backend_buffer_t buffer) { + backend_buffers.insert(buffer); +} + +bool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer) { + auto it = backend_buffers.find(buffer); + if (it == backend_buffers.end()) { + return false; + } + + backend_buffers.erase(it); + return true; +} + +std::unordered_set apir_get_track_backend_buffers() { + return backend_buffers; +} + +ggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor) { + ggml_tensor * result = + ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { + result->nb[i] = tensor->nb[i]; + } + result->buffer = reinterpret_cast(tensor->buffer); + if (result->buffer && backend_buffers.find(result->buffer) == backend_buffers.end()) { + printf("WARNING: HOST BUFFER NOT FOUND | %p\n", (void *) result->buffer); + result->buffer = nullptr; + } + + uint64_t tensor_data = tensor->data; + if (result->buffer) { + // require that the tensor data does not go beyond the buffer end + uint64_t tensor_size = (uint64_t) ggml_nbytes(result); + uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer); + uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer); + + // tensor->data is serialized as an offset to the buffer base address + tensor_data += buffer_start; + + GGML_ASSERT(tensor_data + tensor_size >= tensor_data); // check for overflow + GGML_ASSERT(tensor_data >= buffer_start && tensor_data + tensor_size <= buffer_start + buffer_size); + } + + result->op = (ggml_op) tensor->op; + for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { + result->op_params[i] = tensor->op_params[i]; + } + result->flags = tensor->flags; + result->data = reinterpret_cast(tensor_data); + ggml_set_name(result, tensor->name); + return result; +} + +ggml_tensor * apir_create_node(uint64_t id, + ggml_context * ctx, + const std::unordered_map & tensor_ptrs, + std::unordered_map & tensor_map) { + if (id == 0) { + return nullptr; + } + if (tensor_map.find(id) != tensor_map.end()) { + return tensor_map[id]; + } + const apir_rpc_tensor * tensor = tensor_ptrs.at(id); + ggml_tensor * result = apir_deserialize_tensor(ctx, tensor); + if (result == nullptr) { + return nullptr; + } + tensor_map[id] = result; + for (int i = 0; i < GGML_MAX_SRC; i++) { + result->src[i] = apir_create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map); + } + result->view_src = apir_create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map); + result->view_offs = tensor->view_offs; + return result; +} + +ggml_cgraph * apir_deserialize_graph(uint32_t n_nodes, + uint32_t n_tensors, + const apir_rpc_tensor * tensors, + const uint64_t * nodes) { + size_t buf_size = ggml_tensor_overhead() * (n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); + ggml_init_params params = { + /*.mem_size =*/buf_size, + /*.mem_buffer =*/NULL, + /*.no_alloc =*/true, + }; + ggml_context * ctx = ggml_init(params); + ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); + graph->n_nodes = n_nodes; + std::unordered_map tensor_ptrs; + for (uint32_t i = 0; i < n_tensors; i++) { + tensor_ptrs[tensors[i].id] = &tensors[i]; + } + std::unordered_map tensor_map; + for (uint32_t i = 0; i < n_nodes; i++) { + int64_t id; + memcpy(&id, &nodes[i], sizeof(id)); + graph->nodes[i] = apir_create_node(id, ctx, tensor_ptrs, tensor_map); + } + + return graph; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-convert.h b/ggml/src/ggml-virtgpu/backend/backend-convert.h new file mode 100644 index 00000000000..1978d21f7ef --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-convert.h @@ -0,0 +1,13 @@ +#include "shared/apir_backend.h" + +#define BUFFER_TO_HOST_HANDLE(name) ggml_buffer_to_apir_handle(name) + +static inline apir_buffer_host_handle_t ggml_buffer_to_apir_handle(ggml_backend_buffer_t buffer) { + // in the backend, the buffer handle is the buffer pointer + return (apir_buffer_host_handle_t) buffer; +} + +static inline apir_buffer_type_host_handle_t ggml_buffer_type_to_apir_handle(ggml_backend_buffer_type_t buft) { + // in the backend, the buffer handle is the buffer pointer + return (apir_buffer_type_host_handle_t) buft; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp new file mode 100644 index 00000000000..77b4ee71e12 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp @@ -0,0 +1,65 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "shared/apir_backend.h" + +#include + +uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(enc); + + static bool async_backend_initialized = false; + static bool async_backend; + + if (!async_backend_initialized) { + ggml_backend_dev_props props; + + dev->iface.get_props(dev, &props); + async_backend = props.caps.async; + async_backend_initialized = true; + } + + uint32_t shmem_res_id; + apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id); + + const void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); + if (!shmem_data) { + GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n"); + apir_decoder_set_fatal(dec); + return 1; + } + size_t cgraph_size; + apir_decode_size_t(dec, &cgraph_size); + + apir_decoder secondary_dec = apir_new_decoder((const char *) shmem_data, cgraph_size); + + ggml_cgraph * cgraph = apir_decode_ggml_cgraph(&secondary_dec, cgraph_size); + + ggml_status status; +#if APIR_BACKEND_CHECK_SUPPORTS_OP == 1 + for (int idx = 0; idx < cgraph->n_nodes; idx++) { + ggml_tensor * op = ggml_graph_node(cgraph, idx); + if (dev->iface.supports_op(dev, op)) { + continue; + } + GGML_LOG_ERROR("Graph node %d (%s) not supported by the backend\n", idx, ggml_op_desc(op)); + + status = GGML_STATUS_ABORTED; + apir_encode_ggml_status(enc, &status); + + return 0; + } +#endif + status = bck->iface.graph_compute(bck, cgraph); + + if (async_backend) { + bck->iface.synchronize(bck); + } + + apir_encode_ggml_status(enc, &status); + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp new file mode 100644 index 00000000000..8ea1bb4fb49 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp @@ -0,0 +1,89 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include + +uint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + const char * string = buft->iface.get_name(buft); + + const size_t string_size = strlen(string) + 1; + apir_encode_array_size(enc, string_size); + apir_encode_char_array(enc, string, string_size); + + return 0; +} + +uint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + size_t value = buft->iface.get_alignment(buft); + apir_encode_size_t(enc, &value); + + return 0; +} + +uint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + size_t value = buft->iface.get_max_size(buft); + apir_encode_size_t(enc, &value); + + return 0; +} + +uint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + bool is_host = buft->iface.is_host(buft); + apir_encode_bool_t(enc, &is_host); + + return 0; +} + +uint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + size_t size; + apir_decode_size_t(dec, &size); + + ggml_backend_buffer_t buffer; + + buffer = buft->iface.alloc_buffer(buft, size); + + apir_encode_ggml_buffer(enc, buffer); + + if (buffer) { + apir_track_backend_buffer(buffer); + } + + return 0; +} + +uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec); + + size_t value = buft->iface.get_alloc_size(buft, op); + + apir_encode_size_t(enc, &value); + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp new file mode 100644 index 00000000000..cf81888e989 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp @@ -0,0 +1,131 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include + +uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + uintptr_t base = (uintptr_t) buffer->iface.get_base(buffer); + apir_encode_uintptr_t(enc, &base); + + return 0; +} + +uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(enc); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + ggml_tensor * tensor; + // safe to remove the const qualifier here + tensor = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec); + + uint32_t shmem_res_id; + apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id); + + size_t offset; + apir_decode_size_t(dec, &offset); + + size_t size; + apir_decode_size_t(dec, &size); + + void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); + + if (!shmem_data) { + GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n"); + return 1; + } + + buffer->iface.set_tensor(buffer, tensor, shmem_data, offset, size); + + return 0; +} + +uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(enc); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + const ggml_tensor * tensor; + // safe to remove the const qualifier here + tensor = apir_decode_ggml_tensor(dec); + + uint32_t shmem_res_id; + apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id); + + size_t offset; + apir_decode_size_t(dec, &offset); + + size_t size; + apir_decode_size_t(dec, &size); + + void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); + if (!shmem_data) { + GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n"); + return 1; + } + + buffer->iface.get_tensor(buffer, tensor, shmem_data, offset, size); + + return 0; +} + +uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + const ggml_tensor * src; + // safe to remove the const qualifier here + src = apir_decode_ggml_tensor(dec); + ggml_tensor * dst = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec); + + bool ret = buffer->iface.cpy_tensor(buffer, src, (ggml_tensor *) dst); + + apir_encode_bool_t(enc, &ret); + + return 0; +} + +uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(enc); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + uint8_t value; + apir_decode_uint8_t(dec, &value); + + buffer->iface.clear(buffer, value); + + return 0; +} + +uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(enc); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + if (!apir_untrack_backend_buffer(buffer)) { + GGML_LOG_WARN("%s: unknown buffer %p\n", __func__, (void *) buffer); + return 1; + } + + buffer->iface.free_buffer(buffer); + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp new file mode 100644 index 00000000000..497f737a881 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp @@ -0,0 +1,148 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include + +uint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + int32_t dev_count = reg->iface.get_device_count(reg); + apir_encode_int32_t(enc, &dev_count); + + return 0; +} + +uint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + int32_t dev_count = reg->iface.get_device_count(reg); + apir_encode_int32_t(enc, &dev_count); + + return 0; +} + +uint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + const char * string = dev->iface.get_name(dev); + + const size_t string_size = strlen(string) + 1; + apir_encode_array_size(enc, string_size); + apir_encode_char_array(enc, string, string_size); + + return 0; +} + +uint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + const char * string = dev->iface.get_description(dev); + + const size_t string_size = strlen(string) + 1; + apir_encode_array_size(enc, string_size); + apir_encode_char_array(enc, string, string_size); + + return 0; +} + +uint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + uint32_t type = dev->iface.get_type(dev); + apir_encode_uint32_t(enc, &type); + + return 0; +} + +uint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + size_t free, total; + dev->iface.get_memory(dev, &free, &total); + + apir_encode_size_t(enc, &free); + apir_encode_size_t(enc, &total); + + return 0; +} + +uint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + + const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec); + + bool supports_op = dev->iface.supports_op(dev, op); + + apir_encode_bool_t(enc, &supports_op); + + return 0; +} + +uint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + ggml_backend_buffer_type_t bufft = dev->iface.get_buffer_type(dev); + + apir_encode_ggml_buffer_type(enc, bufft); + + return 0; +} + +uint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + ggml_backend_dev_props props; + dev->iface.get_props(dev, &props); + + apir_encode_bool_t(enc, &props.caps.async); + apir_encode_bool_t(enc, &props.caps.host_buffer); + apir_encode_bool_t(enc, &props.caps.buffer_from_host_ptr); + apir_encode_bool_t(enc, &props.caps.events); + + return 0; +} + +uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + uint32_t shmem_res_id; + apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id); + + void * shmem_ptr = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); + if (!shmem_ptr) { + GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n"); + apir_decoder_set_fatal(dec); + return 1; + } + + size_t size; + apir_decode_size_t(dec, &size); + size_t max_tensor_size; + apir_decode_size_t(dec, &max_tensor_size); + + ggml_backend_buffer_t buffer; + buffer = dev->iface.buffer_from_host_ptr(dev, shmem_ptr, size, max_tensor_size); + + apir_encode_ggml_buffer(enc, buffer); + apir_encode_ggml_buffer_type(enc, buffer->buft); + + if (buffer) { + apir_track_backend_buffer(buffer); + } + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp new file mode 100644 index 00000000000..51d445725f0 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp @@ -0,0 +1,46 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include + +ggml_backend_reg_t reg = NULL; +ggml_backend_dev_t dev = NULL; +ggml_backend_t bck = NULL; + +uint64_t timer_start = 0; +uint64_t timer_total = 0; +uint64_t timer_count = 0; + +uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p) { + if (reg != NULL) { + GGML_LOG_WARN("%s: already initialized\n", __func__); + return APIR_BACKEND_INITIALIZE_ALREADY_INITED; + } + ggml_backend_reg_t (*ggml_backend_reg_fct)(void) = (ggml_backend_reg_t (*)()) ggml_backend_reg_fct_p; + + reg = ggml_backend_reg_fct(); + if (reg == NULL) { + GGML_LOG_ERROR("%s: backend registration failed\n", __func__); + return APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED; + } + + if (!reg->iface.get_device_count(reg)) { + GGML_LOG_ERROR("%s: backend initialization failed: no device found\n", __func__); + return APIR_BACKEND_INITIALIZE_NO_DEVICE; + } + + dev = reg->iface.get_device(reg, 0); + + if (!dev) { + GGML_LOG_ERROR("%s: backend initialization failed: no device received\n", __func__); + return APIR_BACKEND_INITIALIZE_NO_DEVICE; + } + + bck = dev->iface.init_backend(dev, NULL); + + return APIR_BACKEND_INITIALIZE_SUCCESS; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h b/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h new file mode 100644 index 00000000000..b81fd5039bd --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h @@ -0,0 +1,130 @@ +#pragma once + +/* device */ +uint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +/* buffer-type */ +uint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +/* buffer */ +uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +/* backend */ +uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +static inline const char * backend_dispatch_command_name(ApirBackendCommandType type) { + switch (type) { + /* device */ + case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT: + return "backend_device_get_device_count"; + case APIR_COMMAND_TYPE_DEVICE_GET_COUNT: + return "backend_device_get_count"; + case APIR_COMMAND_TYPE_DEVICE_GET_NAME: + return "backend_device_get_name"; + case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION: + return "backend_device_get_description"; + case APIR_COMMAND_TYPE_DEVICE_GET_TYPE: + return "backend_device_get_type"; + case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY: + return "backend_device_get_memory"; + case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP: + return "backend_device_supports_op"; + case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE: + return "backend_device_get_buffer_type"; + case APIR_COMMAND_TYPE_DEVICE_GET_PROPS: + return "backend_device_get_props"; + case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR: + return "backend_device_buffer_from_ptr"; + /* buffer-type */ + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME: + return "backend_buffer_type_get_name"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT: + return "backend_buffer_type_get_alignment"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE: + return "backend_buffer_type_get_max_size"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST: + return "backend_buffer_type_is_host"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER: + return "backend_buffer_type_alloc_buffer"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE: + return "backend_buffer_type_get_alloc_size"; + /* buffer */ + case APIR_COMMAND_TYPE_BUFFER_GET_BASE: + return "backend_buffer_get_base"; + case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR: + return "backend_buffer_set_tensor"; + case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR: + return "backend_buffer_get_tensor"; + case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR: + return "backend_buffer_cpy_tensor"; + case APIR_COMMAND_TYPE_BUFFER_CLEAR: + return "backend_buffer_clear"; + case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER: + return "backend_buffer_free_buffer"; + /* backend */ + case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE: + return "backend_backend_graph_compute"; + + default: + return "unknown"; + } +} + +extern "C" { +static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = { + + /* device */ + + /* APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT = */ backend_device_get_device_count, + /* APIR_COMMAND_TYPE_DEVICE_GET_COUNT = */ backend_device_get_count, + /* APIR_COMMAND_TYPE_DEVICE_GET_NAME = */ backend_device_get_name, + /* APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION = */ backend_device_get_description, + /* APIR_COMMAND_TYPE_DEVICE_GET_TYPE = */ backend_device_get_type, + /* APIR_COMMAND_TYPE_DEVICE_GET_MEMORY = */ backend_device_get_memory, + /* APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP = */ backend_device_supports_op, + /* APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE = */ backend_device_get_buffer_type, + /* APIR_COMMAND_TYPE_DEVICE_GET_PROPS = */ backend_device_get_props, + /* APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR = */ backend_device_buffer_from_ptr, + + /* buffer-type */ + + /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME = */ backend_buffer_type_get_name, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT = */ backend_buffer_type_get_alignment, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE = */ backend_buffer_type_get_max_size, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST = */ backend_buffer_type_is_host, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER = */ backend_buffer_type_alloc_buffer, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE = */ backend_buffer_type_get_alloc_size, + + /* buffer */ + + /* APIR_COMMAND_TYPE_BUFFER_GET_BASE = */ backend_buffer_get_base, + /* APIR_COMMAND_TYPE_BUFFER_SET_TENSOR = */ backend_buffer_set_tensor, + /* APIR_COMMAND_TYPE_BUFFER_GET_TENSOR = */ backend_buffer_get_tensor, + /* APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR = */ backend_buffer_cpy_tensor, + /* APIR_COMMAND_TYPE_BUFFER_CLEAR = */ backend_buffer_clear, + /* APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER = */ backend_buffer_free_buffer, + + /* backend */ + + /* APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE = */ backend_backend_graph_compute, +}; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.h b/ggml/src/ggml-virtgpu/backend/backend-dispatched.h new file mode 100644 index 00000000000..6ccbecf078d --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include + +#include + +#include "backend-convert.h" +#include "backend-virgl-apir.h" +#include "shared/apir_backend.h" +#include "shared/apir_cs.h" +#include "shared/apir_cs_ggml.h" + +struct virgl_apir_context { + uint32_t ctx_id; + virgl_apir_callbacks * iface; +}; + +typedef uint32_t (*backend_dispatch_t)(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +#include "backend-dispatched.gen.h" + +uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p); diff --git a/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h b/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h new file mode 100644 index 00000000000..44b347f853f --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h @@ -0,0 +1,32 @@ +#pragma once + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "shared/api_remoting.h" + +#include +#include +#include + +extern ggml_backend_reg_t reg; +extern ggml_backend_dev_t dev; +extern ggml_backend_t bck; + +struct virgl_apir_callbacks { + const char * (*get_config)(uint32_t virgl_ctx_id, const char * key); + void * (*get_shmem_ptr)(uint32_t virgl_ctx_id, uint32_t res_id); +}; + +extern "C" { +ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs); +void apir_backend_deinit(uint32_t virgl_ctx_id); +uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id, + virgl_apir_callbacks * virgl_cbs, + uint32_t cmd_type, + char * dec_cur, + const char * dec_end, + char * enc_cur, + const char * enc_end, + char ** enc_cur_after); +} diff --git a/ggml/src/ggml-virtgpu/backend/backend.cpp b/ggml/src/ggml-virtgpu/backend/backend.cpp new file mode 100644 index 00000000000..95d602ed603 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend.cpp @@ -0,0 +1,148 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" + +#include "shared/api_remoting.h" +#include "shared/apir_backend.h" +#include "shared/apir_cs.h" + +#include +#include + +#include + +#define APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV "APIR_LLAMA_CPP_GGML_LIBRARY_PATH" +#define APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV "APIR_LLAMA_CPP_GGML_LIBRARY_REG" +#define APIR_LLAMA_CPP_LOG_TO_FILE_ENV "APIR_LLAMA_CPP_LOG_TO_FILE" + +#define GGML_DEFAULT_BACKEND_REG "ggml_backend_init" + +static void * backend_library_handle = NULL; +static FILE * apir_logfile = NULL; + +static void log_to_file_callback(enum ggml_log_level level, const char * text, void * user_data) { + FILE * logfile = (FILE *)user_data; + fprintf(logfile, "[%d] %s", level, text); + fflush(logfile); +} + +extern "C" { +void apir_backend_deinit(uint32_t virgl_ctx_id) { + GGML_UNUSED(virgl_ctx_id); + + auto buffers = apir_get_track_backend_buffers(); + for (const auto & buffer : buffers) { + apir_untrack_backend_buffer(buffer); + buffer->iface.free_buffer(buffer); + } + + if (dev) { + size_t free, total; + dev->iface.get_memory(dev, &free, &total); + GGML_LOG_INFO("%s: free memory: %ld MB\n", __func__, (size_t) free / 1024 / 1024); + } + + if (backend_library_handle) { + GGML_LOG_INFO("%s: The GGML backend library was loaded. Unloading it.\n", __func__); + dlclose(backend_library_handle); + backend_library_handle = NULL; + } + + if (apir_logfile) { + fclose(apir_logfile); + apir_logfile = NULL; + } +} + +#define APIR_GGML_LIBRARY_PATH_KEY "ggml.library.path" +#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg" + +ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs) { + const char * dlsym_error; + + const char * apir_log_to_file = getenv(APIR_LLAMA_CPP_LOG_TO_FILE_ENV); + if (apir_log_to_file) { + apir_logfile = fopen(apir_log_to_file, "w"); + if (apir_logfile) { + ggml_log_set(log_to_file_callback, apir_logfile); + } else { + GGML_LOG_INFO("Could not open the log file at '%s'\n", apir_log_to_file); + } + } + + const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY); + const char * virgl_library_reg = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_REG_KEY); + const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG; + + if (!library_name) { + GGML_LOG_ERROR("cannot open the GGML library: env var '%s' not defined\n", APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV); + + return APIR_LOAD_LIBRARY_ENV_VAR_MISSING; + } + + backend_library_handle = dlopen(library_name, RTLD_LAZY); + + if (!backend_library_handle) { + GGML_LOG_ERROR("cannot open the GGML library: %s\n", dlerror()); + + return APIR_LOAD_LIBRARY_CANNOT_OPEN; + } + + if (!library_reg) { + GGML_LOG_ERROR("cannot register the GGML library: env var '%s' not defined\n", APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV); + + return APIR_LOAD_LIBRARY_ENV_VAR_MISSING; + } + + void * ggml_backend_reg_fct = dlsym(backend_library_handle, library_reg); + dlsym_error = dlerror(); + if (dlsym_error) { + GGML_LOG_ERROR("cannot find the GGML backend registration symbol '%s' (from %s): %s\n", library_reg, + APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error); + + return APIR_LOAD_LIBRARY_SYMBOL_MISSING; + } + + uint32_t ret = backend_dispatch_initialize(ggml_backend_reg_fct); + + return (ApirLoadLibraryReturnCode) (APIR_LOAD_LIBRARY_INIT_BASE_INDEX + ret); +} + +uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id, + virgl_apir_callbacks * virgl_cbs, + uint32_t cmd_type, + char * dec_cur, + const char * dec_end, + char * enc_cur, + const char * enc_end, + char ** enc_cur_after) { + apir_encoder enc = { + .cur = enc_cur, + .start = enc_cur, + .end = enc_end, + .fatal = false, + }; + + apir_decoder dec = { + .cur = dec_cur, + .end = dec_end, + .fatal = false, + }; + + virgl_apir_context ctx = { + .ctx_id = virgl_ctx_id, + .iface = virgl_cbs, + }; + + if (cmd_type >= APIR_BACKEND_DISPATCH_TABLE_COUNT) { + GGML_LOG_ERROR("Received an invalid dispatch index (%d >= %d)\n", cmd_type, APIR_BACKEND_DISPATCH_TABLE_COUNT); + return APIR_BACKEND_FORWARD_INDEX_INVALID; + } + + backend_dispatch_t forward_fct = apir_backend_dispatch_table[cmd_type]; + uint32_t ret = forward_fct(&enc, &dec, &ctx); + + *enc_cur_after = enc.cur; + + return ret; +} +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h b/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h new file mode 100644 index 00000000000..f19a5d12d17 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h @@ -0,0 +1,90 @@ +#pragma once + +/* the rest of this file must match virglrenderer/src/apir-protocol.h */ + +#include + +#include + +#define APIR_PROTOCOL_MAJOR 0 +#define APIR_PROTOCOL_MINOR 1 + +#define APIR_HANDSHAKE_MAGIC 0xab1e + +enum ApirCommandType { + APIR_COMMAND_TYPE_HANDSHAKE = 0, + APIR_COMMAND_TYPE_LOADLIBRARY = 1, + APIR_COMMAND_TYPE_FORWARD = 2, + + APIR_COMMAND_TYPE_LENGTH = 3, +}; + +typedef uint64_t ApirCommandFlags; + +enum ApirLoadLibraryReturnCode { + APIR_LOAD_LIBRARY_SUCCESS = 0, + APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR = 1, + APIR_LOAD_LIBRARY_ALREADY_LOADED = 2, + APIR_LOAD_LIBRARY_ENV_VAR_MISSING = 3, + APIR_LOAD_LIBRARY_CANNOT_OPEN = 4, + APIR_LOAD_LIBRARY_SYMBOL_MISSING = 5, + APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6, // anything above this is a APIR backend library initialization return code +}; + +enum ApirForwardReturnCode { + APIR_FORWARD_SUCCESS = 0, + APIR_FORWARD_NO_DISPATCH_FCT = 1, + APIR_FORWARD_TIMEOUT = 2, + + APIR_FORWARD_BASE_INDEX = 3, // anything above this is a APIR backend library forward return code +} ; + +__attribute__((unused)) static inline const char * apir_command_name(ApirCommandType type) { + switch (type) { + case APIR_COMMAND_TYPE_HANDSHAKE: + return "HandShake"; + case APIR_COMMAND_TYPE_LOADLIBRARY: + return "LoadLibrary"; + case APIR_COMMAND_TYPE_FORWARD: + return "Forward"; + default: + return "unknown"; + } +} + +__attribute__((unused)) static const char * apir_load_library_error(ApirLoadLibraryReturnCode code) { +#define APIR_LOAD_LIBRARY_ERROR(code_name) \ + do { \ + if (code == code_name) \ + return #code_name; \ + } while (0) + + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_SUCCESS); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_ALREADY_LOADED); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_ENV_VAR_MISSING); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_CANNOT_OPEN); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_SYMBOL_MISSING); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_INIT_BASE_INDEX); + + return "Unknown APIR_COMMAND_TYPE_LoadLibrary error"; + +#undef APIR_LOAD_LIBRARY_ERROR +} + +__attribute__((unused)) static const char * apir_forward_error(ApirForwardReturnCode code) { +#define APIR_FORWARD_ERROR(code_name) \ + do { \ + if (code == code_name) \ + return #code_name; \ + } while (0) + + APIR_FORWARD_ERROR(APIR_FORWARD_SUCCESS); + APIR_FORWARD_ERROR(APIR_FORWARD_NO_DISPATCH_FCT); + APIR_FORWARD_ERROR(APIR_FORWARD_TIMEOUT); + APIR_FORWARD_ERROR(APIR_FORWARD_BASE_INDEX); + + return "Unknown APIR_COMMAND_TYPE_FORWARD error"; + +#undef APIR_FORWARD_ERROR +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h new file mode 100644 index 00000000000..d214b6f2a90 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h @@ -0,0 +1,36 @@ +typedef enum ApirBackendCommandType { + + /* device */ + APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT = 0, + APIR_COMMAND_TYPE_DEVICE_GET_COUNT = 1, + APIR_COMMAND_TYPE_DEVICE_GET_NAME = 2, + APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION = 3, + APIR_COMMAND_TYPE_DEVICE_GET_TYPE = 4, + APIR_COMMAND_TYPE_DEVICE_GET_MEMORY = 5, + APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP = 6, + APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE = 7, + APIR_COMMAND_TYPE_DEVICE_GET_PROPS = 8, + APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR = 9, + + /* buffer-type */ + APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME = 10, + APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT = 11, + APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE = 12, + APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST = 13, + APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER = 14, + APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE = 15, + + /* buffer */ + APIR_COMMAND_TYPE_BUFFER_GET_BASE = 16, + APIR_COMMAND_TYPE_BUFFER_SET_TENSOR = 17, + APIR_COMMAND_TYPE_BUFFER_GET_TENSOR = 18, + APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR = 19, + APIR_COMMAND_TYPE_BUFFER_CLEAR = 20, + APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER = 21, + + /* backend */ + APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE = 22, + + // last command_type index + 1 + APIR_BACKEND_DISPATCH_TABLE_COUNT = 23, +} ApirBackendCommandType; diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h new file mode 100644 index 00000000000..f3efa52c721 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h @@ -0,0 +1,46 @@ +#pragma once + +#include "apir_backend.gen.h" + +#include // for uintptr_t +#include // for timespec, clock_gettime + +#define APIR_BACKEND_INITIALIZE_SUCCESS 0 +#define APIR_BACKEND_INITIALIZE_CANNOT_OPEN_BACKEND_LIBRARY 1 +#define APIR_BACKEND_INITIALIZE_CANNOT_OPEN_GGML_LIBRARY 2 +#define APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS 3 +#define APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS 4 +#define APIR_BACKEND_INITIALIZE_BACKEND_FAILED 5 +#define APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED 6 +#define APIR_BACKEND_INITIALIZE_ALREADY_INITED 7 +#define APIR_BACKEND_INITIALIZE_NO_DEVICE 8 + + +// new entries here need to be added to the apir_backend_initialize_error function below + +#define APIR_BACKEND_FORWARD_INDEX_INVALID 6 + +// 0 is fast, 1 avoids the backend to crash if an unsupported tensor is received +#define APIR_BACKEND_CHECK_SUPPORTS_OP 0 + +typedef uintptr_t apir_buffer_type_host_handle_t; +typedef uintptr_t apir_buffer_host_handle_t; + +static const char * apir_backend_initialize_error(int code) { +#define APIR_BACKEND_INITIALIZE_ERROR(code_name) \ + do { \ + if (code == code_name) \ + return #code_name; \ + } while (0) + + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_SUCCESS); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_CANNOT_OPEN_BACKEND_LIBRARY); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_CANNOT_OPEN_GGML_LIBRARY); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_FAILED); + + return "Unknown APIR_BACKEND_INITIALIZE error:/"; + +#undef APIR_BACKEND_INITIALIZE_ERROR +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h new file mode 100644 index 00000000000..27a61091ffd --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h @@ -0,0 +1,383 @@ +#pragma once + +#include "ggml-impl.h" + +#include +#include + +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) + +struct apir_encoder { + char * cur; + const char * start; + const char * end; + bool fatal; + +}; + +struct apir_decoder { + const char * cur; + const char * end; + bool fatal; +}; + +/* + * new encoder and decoder + */ + +static apir_decoder apir_new_decoder(const char * ptr, size_t size) { + apir_decoder dec = { + .cur = ptr, + .end = ptr + size, + .fatal = false, + }; + + return dec; +} + +static apir_encoder apir_new_encoder(char * ptr, size_t size) { + apir_encoder enc = { + .cur = ptr, + .start = ptr, + .end = ptr + size, + .fatal = false, + }; + + return enc; +} + +/* + * fatal flag handling + */ + +static inline void apir_encoder_reset_fatal(apir_encoder * enc) { + enc->fatal = false; +} + +static inline void apir_encoder_set_fatal(apir_encoder * enc) { + enc->fatal = true; +} + +static inline bool apir_encoder_get_fatal(const apir_encoder * enc) { + return enc->fatal; +} + +static inline void apir_decoder_reset_fatal(apir_decoder * dec) { + dec->fatal = false; +} + +static inline void apir_decoder_set_fatal(apir_decoder * dec) { + dec->fatal = true; +} + +static inline bool apir_decoder_get_fatal(const apir_decoder * dec) { + return dec->fatal; +} + +/* + * encode peek + */ + +static inline bool apir_decoder_peek_internal(apir_decoder * dec, + size_t size, + void * val, + size_t val_size) { + assert(val_size <= size); + + if (unlikely(size > (size_t) (dec->end - dec->cur))) { + GGML_LOG_ERROR("reading too much from the decoder ...\n"); + apir_decoder_set_fatal(dec); + memset(val, 0, val_size); + return false; + } + + /* we should not rely on the compiler to optimize away memcpy... */ + memcpy(val, dec->cur, val_size); + return true; +} + +static inline void apir_decoder_peek(apir_decoder * dec, size_t size, void * val, size_t val_size) { + apir_decoder_peek_internal(dec, size, val, val_size); +} + +static inline const void * apir_decoder_use_inplace(apir_decoder * dec, size_t size) { + if (unlikely(size > (size_t) (dec->end - dec->cur))) { + GGML_LOG_ERROR("reading too much from the decoder ...\n"); + apir_decoder_set_fatal(dec); + return NULL; + } + const void * addr = dec->cur; + dec->cur += size; + + return addr; +} + +/* + * read/write + */ + +static inline void apir_decoder_read(apir_decoder * dec, size_t size, void * val, size_t val_size) { + if (apir_decoder_peek_internal(dec, size, val, val_size)) { + dec->cur += size; + } +} + +static inline char * apir_encoder_write(apir_encoder * enc, size_t size, const void * val, size_t val_size) { + assert(val_size <= size); + assert(size <= ((size_t) (enc->end - enc->cur))); + + char * write_addr = enc->cur; + /* we should not rely on the compiler to optimize away memcpy... */ + memcpy(write_addr, val, val_size); + enc->cur += size; + + return write_addr; +} + +/* + * encode/decode + */ + +static inline void apir_decode(apir_decoder * dec, size_t size, void * data, size_t data_size) { + assert(size % 4 == 0); + apir_decoder_read(dec, size, data, data_size); +} + +static inline void apir_encode(apir_encoder * enc, size_t size, const void * data, size_t data_size) { + assert(size % 4 == 0); + apir_encoder_write(enc, size, data, data_size); +} + +/* + * typed encode/decode + */ + +/* uint8_t */ + +static inline void apir_encode_uint8_t(apir_encoder * enc, const uint8_t * val) { + apir_encode(enc, sizeof(int), val, sizeof(*val)); +} + +static inline void apir_decode_uint8_t(apir_decoder * dec, uint8_t * val) { + apir_decode(dec, sizeof(int), val, sizeof(*val)); +} + +/* uint64_t */ + +static inline void apir_encode_uint64_t(apir_encoder * enc, const uint64_t * val) { + apir_encode(enc, 8, val, sizeof(*val)); +} + +static inline void apir_decode_uint64_t(apir_decoder * dec, uint64_t * val) { + apir_decode(dec, 8, val, sizeof(*val)); +} + +static inline void apir_encode_uint64_t_array(apir_encoder * enc, const uint64_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_encode(enc, size, val, size); +} + +static inline void apir_decode_uint64_t_array(apir_decoder * dec, uint64_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_decode(dec, size, val, size); +} + +static inline const uint64_t * apir_decode_uint64_t_array_inplace(apir_decoder * dec, uint32_t count) { + return (uint64_t *) (uintptr_t) apir_decoder_use_inplace(dec, count * sizeof(uint64_t)); +} + +/* int32_t */ + +static inline void apir_encode_int32_t(apir_encoder * enc, const int32_t * val) { + apir_encode(enc, 4, val, sizeof(*val)); +} + +static inline void apir_decode_int32_t(apir_decoder * dec, int32_t * val) { + apir_decode(dec, 4, val, sizeof(*val)); +} + +static inline void apir_encode_int32_t_array(apir_encoder * enc, const int32_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_encode(enc, size, val, size); +} + +static inline void apir_decode_int32_t_array(apir_decoder * dec, int32_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_decode(dec, size, val, size); +} + +/* array size (uint64_t) */ + +static inline void apir_encode_array_size(apir_encoder * enc, uint64_t size) { + apir_encode_uint64_t(enc, &size); +} + +static inline uint64_t apir_decode_array_size(apir_decoder * dec, uint64_t expected_size) { + uint64_t size; + apir_decode_uint64_t(dec, &size); + if (size != expected_size) { + GGML_LOG_ERROR("Couldn't decode array from the decoder\n"); + apir_decoder_set_fatal(dec); + size = 0; + } + return size; +} + +static inline uint64_t apir_decode_array_size_unchecked(apir_decoder * dec) { + uint64_t size; + apir_decode_uint64_t(dec, &size); + return size; +} + +/* non-array pointer */ + +static inline bool apir_encode_simple_pointer(apir_encoder * enc, const void * val) { + apir_encode_array_size(enc, val ? 1 : 0); + return val; +} + +static inline bool apir_decode_simple_pointer(apir_decoder * dec) { + return apir_decode_array_size_unchecked(dec); +} + +/* uint32_t */ + +static inline void apir_encode_uint32_t(apir_encoder * enc, const uint32_t * val) { + apir_encode(enc, 4, val, sizeof(*val)); +} + +static inline void apir_decode_uint32_t(apir_decoder * dec, uint32_t * val) { + apir_decode(dec, 4, val, sizeof(*val)); +} + +static inline void apir_encode_uint32_t_array(apir_encoder * enc, const uint32_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_encode(enc, size, val, size); +} + +static inline void apir_decode_uint32_t_array(apir_decoder * dec, uint32_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_decode(dec, size, val, size); +} + +/* size_t */ + +static inline void apir_encode_size_t(apir_encoder * enc, const size_t * val) { + const uint64_t tmp = *val; + apir_encode_uint64_t(enc, &tmp); +} + +static inline void apir_decode_size_t(apir_decoder * dec, size_t * val) { + uint64_t tmp; + apir_decode_uint64_t(dec, &tmp); + *val = tmp; +} + +static inline void apir_encode_size_t_array(apir_encoder * enc, const size_t * val, uint32_t count) { + if (sizeof(size_t) == sizeof(uint64_t)) { + apir_encode_uint64_t_array(enc, (const uint64_t *) val, count); + } else { + for (uint32_t i = 0; i < count; i++) { + apir_encode_size_t(enc, &val[i]); + } + } +} + +static inline void apir_decode_size_t_array(apir_decoder * dec, size_t * val, uint32_t count) { + if (sizeof(size_t) == sizeof(uint64_t)) { + apir_decode_uint64_t_array(dec, (uint64_t *) val, count); + } else { + for (uint32_t i = 0; i < count; i++) { + apir_decode_size_t(dec, &val[i]); + } + } +} + +/* opaque blob */ + +static inline void apir_encode_blob_array(apir_encoder * enc, const void * val, size_t size) { + apir_encode(enc, (size + 3) & ~3, val, size); +} + +static inline void apir_decode_blob_array(apir_decoder * dec, void * val, size_t size) { + apir_decode(dec, (size + 3) & ~3, val, size); +} + +/* string */ + +static inline void apir_encode_char_array(apir_encoder * enc, const char * val, size_t size) { + assert(size && strlen(val) < size); + apir_encode_blob_array(enc, val, size); +} + +static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t size) { + apir_decode_blob_array(dec, val, size); + if (size) { + val[size - 1] = '\0'; + } else { + GGML_LOG_ERROR("Couldn't decode the blog array\n"); + apir_decoder_set_fatal(dec); + } +} + +/* (temp) buffer allocation */ + +static inline void * apir_decoder_alloc_array(size_t size, size_t count) { + size_t alloc_size; + if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) { + GGML_LOG_ERROR("overflow in array allocation of %zu * %zu bytes\n", size, count); + return NULL; + } + + return malloc(alloc_size); +} + +/* bool */ + +static inline void apir_encode_bool_t(apir_encoder * enc, const bool * val) { + apir_encode(enc, sizeof(int), val, sizeof(bool)); +} + +static inline void apir_decode_bool_t(apir_decoder * dec, bool * val) { + apir_decode(dec, sizeof(int), val, sizeof(bool)); +} + +/* apir_buffer_type_host_handle_t */ + +static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc, + const apir_buffer_type_host_handle_t * val) { + apir_encode(enc, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t)); +} + +static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec, + apir_buffer_type_host_handle_t * val) { + apir_decode(dec, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t)); +} + +/* apir_buffer_host_handle_t */ + +static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc, + const apir_buffer_host_handle_t * val) { + apir_encode(enc, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t)); +} + +static inline void apir_decode_apir_buffer_host_handle_t(apir_decoder * dec, apir_buffer_host_handle_t * val) { + apir_decode(dec, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t)); +} + +/* uintptr_t */ + +static inline void apir_encode_uintptr_t(apir_encoder * enc, const uintptr_t * val) { + apir_encode(enc, sizeof(*val), val, sizeof(*val)); +} + +static inline void apir_decode_uintptr_t(apir_decoder * dec, uintptr_t * val) { + apir_decode(dec, sizeof(*val), val, sizeof(*val)); +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h new file mode 100644 index 00000000000..070c3b25fb1 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h @@ -0,0 +1,211 @@ +#include "ggml-impl.h" +#include "apir_cs.h" +#include "apir_cs_rpc.h" + +// ggml_buffer_to_apir_host_handle(ggml_backend_buffer_t buffer); + +static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc, + const apir_buffer_host_handle_t * handle); + +static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec); + +/* apir_rpc_tensor */ + +static inline void apir_encode_rcp_tensor(apir_encoder * enc, const apir_rpc_tensor * apir_rpc_tensor) { + size_t apir_rpc_tensor_size = sizeof(*apir_rpc_tensor); + apir_encode(enc, apir_rpc_tensor_size, apir_rpc_tensor, apir_rpc_tensor_size); +} + +static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_inplace(apir_decoder * dec) { + size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor); + + return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size); +} + +static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec, + uint32_t n_tensors) { + size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor) * n_tensors; + + return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size); +} + +/* ggml_tensor */ + +static inline void apir_encode_ggml_tensor(apir_encoder * enc, const ggml_tensor * tensor) { + apir_rpc_tensor serialized = apir_serialize_tensor(tensor); + + apir_encode_rcp_tensor(enc, &serialized); +} + +static inline const ggml_tensor * apir_decode_ggml_tensor(apir_decoder * dec) { + const apir_rpc_tensor * apir_rpc_tensor = apir_decode_apir_rpc_tensor_inplace(dec); + ggml_init_params params{ + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context * ctx = ggml_init(params); + + const ggml_tensor * tensor = apir_deserialize_tensor(ctx, apir_rpc_tensor); + + return tensor; +} + +/* *** ggml_backend_buffer_type_t *** */ + +// ggml_backend_buffer_type_t is a POINTER (to a struct). +// Only the host pointer is shared between the host and guest. +// The guest stores it in `buft->context`. +// The host simply writes the pointer address in the buffer variable. + +static inline void apir_encode_ggml_buffer_type(apir_encoder * enc, ggml_backend_buffer_type_t buft) { + apir_buffer_type_host_handle_t handle = ggml_buffer_type_to_apir_handle(buft); + apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle)); +} + +static inline ggml_backend_buffer_type_t apir_decode_ggml_buffer_type(apir_decoder * dec) { + apir_buffer_type_host_handle_t handle; + + apir_decoder_read(dec, sizeof(handle), &handle, sizeof(handle)); + + return (ggml_backend_buffer_type_t) handle; +} + +static inline apir_buffer_type_host_handle_t apir_decode_apir_buffer_type_host_handle(apir_decoder * dec) { + apir_buffer_type_host_handle_t handle; + + apir_decoder_read(dec, sizeof(handle), &handle, sizeof(handle)); + + return handle; +} + +/* *** ggml_backend_type_t *** */ + +// ggml_backend_buffer_t is a POINTER. +// same logic as for ggml_backend_buffer_type_t + +static inline void apir_encode_ggml_buffer(apir_encoder * enc, const ggml_backend_buffer_t buffer) { + apir_buffer_host_handle_t handle = BUFFER_TO_HOST_HANDLE(buffer); + apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle)); +} + +static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec) { + ggml_backend_buffer_t buffer; + size_t buffer_ptr_size = sizeof(buffer); + + apir_decoder_read(dec, buffer_ptr_size, &buffer, buffer_ptr_size); + + return buffer; +} + +/* enum ggml_status */ + +static inline void apir_encode_ggml_status(apir_encoder * enc, const ggml_status * status) { + apir_encoder_write(enc, sizeof(*status), status, sizeof(*status)); +} + +static inline void apir_decode_ggml_status(apir_decoder * dec, ggml_status * status) { + apir_decoder_read(dec, sizeof(*status), status, sizeof(*status)); +} + +/* virtgpu_shmem */ + +static inline void apir_encode_virtgpu_shmem_res_id(apir_encoder * enc, uint32_t shmem_res_id) { + apir_encode_uint32_t(enc, &shmem_res_id); +} + +static inline void apir_decode_virtgpu_shmem_res_id(apir_decoder * dec, uint32_t * shmem_res_id) { + apir_decode_uint32_t(dec, shmem_res_id); +} + +/* ggml_cgraph */ + +static inline size_t apir_serialize_ggml_cgraph(ggml_cgraph * cgraph, std::vector & cgraph_data) { + apir_serialize_graph(cgraph, cgraph_data); + + return cgraph_data.size(); +} + +static inline void apir_encode_cgraph_data(apir_encoder * enc, std::vector & cgraph_data) { + size_t cgraph_size = cgraph_data.size(); + + apir_encode(enc, cgraph_size, cgraph_data.data(), cgraph_size); +} + +static inline ggml_cgraph * apir_decode_ggml_cgraph(apir_decoder * dec, size_t cgraph_size) { + GGML_UNUSED(cgraph_size); + + uint32_t n_nodes; + apir_decode_uint32_t(dec, &n_nodes); + const uint64_t * nodes = apir_decode_uint64_t_array_inplace(dec, n_nodes); + + uint32_t n_tensors; + apir_decode_uint32_t(dec, &n_tensors); + const apir_rpc_tensor * tensors = apir_decode_apir_rpc_tensor_array_inplace(dec, n_tensors); + + return apir_deserialize_graph(n_nodes, n_tensors, tensors, nodes); +} + +static inline void apir_encode_ggml_buffer_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle) { + apir_encoder_write(enc, sizeof(*handle), &handle, sizeof(*handle)); +} + +static inline void apir_encode_ggml_tensor_inline(apir_encoder * enc, const ggml_tensor * tensor) { + size_t tensor_size = sizeof(*tensor); + + if (tensor->extra) { + GGML_ABORT("Cannot pass tensors with extra"); + } + + if (tensor->src[0] && tensor->buffer) { + static int first = 1; + if (first) { + GGML_LOG_WARN("Cannot pass tensors with src and buffer\n"); + first = 0; + } + } + + apir_encoder_write(enc, tensor_size, tensor, tensor_size); + + // tensor->data is a pointer inside the device buffer. No need to touch it + // tensor->buffer is a pointer to a buffer. Encoding the buffer handle in sequence. + // (could also make a copy of the tensor, and update locally.) + + if (tensor->buffer) { + apir_buffer_host_handle_t buffer_handle = ggml_buffer_to_apir_handle(tensor->buffer); + apir_encode_ggml_buffer_handle(enc, &buffer_handle); + } + + if (tensor->view_src) { + apir_encoder_write(enc, tensor_size, tensor->view_src, tensor_size); + } + + for (int i = 0; tensor->src[i]; i++) { + const ggml_tensor * tensor_src = tensor->src[i]; + apir_encoder_write(enc, tensor_size, tensor_src, tensor_size); + } +} + +static inline const ggml_tensor * apir_decode_ggml_tensor_inplace(apir_decoder * dec) { + // it safe to remove the `const` qualifier here, we *do* want to + // modify the shared memory data to fix the `src` pointers. + ggml_tensor * tensor = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor)); + + // tensor->data is a pointer inside the device buffer. No need to touch it + // tensor->buffer is a pointer to a buffer. Decode the buffer handle encoded in sequence. + if (tensor->buffer) { + tensor->buffer = apir_decode_ggml_buffer(dec); + } + + if (tensor->view_src) { + ggml_tensor * tensor_view_src = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor)); + tensor->view_src = tensor_view_src; + } + + for (int i = 0; tensor->src[i]; i++) { + ggml_tensor * tensor_src = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor)); + tensor->src[i] = tensor_src; // overwrite op->src[i] pointer with the actual location of the src tensor + } + + return tensor; +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h new file mode 100644 index 00000000000..f6817989528 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h @@ -0,0 +1,54 @@ +#include "ggml.h" +#include "ggml-backend-impl.h" + +#include +#include +#include +#include + +// ggml_tensor is serialized into apir_rpc_tensor +struct apir_rpc_tensor { + uint64_t id; + uint32_t type; + uint64_t buffer; + uint32_t ne[GGML_MAX_DIMS]; + uint32_t nb[GGML_MAX_DIMS]; + uint32_t op; + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; + int32_t flags; + uint64_t src[GGML_MAX_SRC]; + uint64_t view_src; + uint64_t view_offs; + uint64_t data; + char name[GGML_MAX_NAME]; + + char padding[4]; +}; + +/* frontend */ + +apir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor); + +void apir_serialize_graph(const ggml_cgraph * cgraph, std::vector & output); + +/* backend */ + +void apir_track_backend_buffer(ggml_backend_buffer_t buffer); +bool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer); +std::unordered_set apir_get_track_backend_buffers(); + +void apir_add_tensor(ggml_tensor * tensor, + std::vector & tensors, + std::unordered_set & visited); + +ggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor); + +ggml_tensor * apir_create_node(uint64_t id, + ggml_context * ctx, + const std::unordered_map & tensor_ptrs, + std::unordered_map & tensor_map); + +ggml_cgraph * apir_deserialize_graph(uint32_t n_nodes, + uint32_t n_tensors, + const apir_rpc_tensor * tensors, + const uint64_t * nodes); diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp new file mode 100644 index 00000000000..7f650659b8a --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp @@ -0,0 +1,98 @@ +#include "ggml-remoting.h" + +static ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context)); + if (!context) { + GGML_ABORT("Couldn't allocate the buffer context ..."); + } + + context->gpu = gpu; + + bool async__unused, host_buffer__unused, events__unused; + bool buffer_from_host_ptr; + apir_device_get_props(gpu, &async__unused, &host_buffer__unused, &buffer_from_host_ptr, &events__unused); + + if (buffer_from_host_ptr) { + context->apir_context = apir_device_buffer_from_ptr(gpu, size, size); + context->base = context->apir_context.shmem.mmap_ptr; + context->is_from_ptr = true; + } else { + context->apir_context = apir_buffer_type_alloc_buffer(gpu, buft, size); + context->is_from_ptr = false; + context->base = NULL; + } + + ggml_backend_buffer_t buffer = + ggml_backend_buffer_init(buft, ggml_backend_remoting_buffer_interface, (void *) context, size); + + return buffer; +} + +static const char * ggml_backend_remoting_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + return apir_buffer_type_get_name(gpu, buft); +} + +static size_t ggml_backend_remoting_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + static size_t align = 0; + + if (align == 0) { + align = apir_buffer_type_get_alignment(gpu, buft); + } + + return align; +} + +static size_t ggml_backend_remoting_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + static size_t max_size = 0; + if (max_size == 0) { + max_size = apir_buffer_type_get_max_size(gpu, buft); + } + + return max_size; +} + +static bool ggml_backend_remoting_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + return apir_buffer_type_is_host(gpu, buft); +} + +static size_t ggml_backend_remoting_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, + const ggml_tensor * tensor) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + if (tensor->buffer == NULL + || !tensor->buffer->context + || !buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) { + return ggml_nbytes(tensor); + } + + return apir_buffer_type_get_alloc_size(gpu, buft, tensor); +} + +const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_type_interface = { + /* .get_name = */ ggml_backend_remoting_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_remoting_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_remoting_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_remoting_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_remoting_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; + +const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_from_ptr_type_interface = { + /* .get_name = */ ggml_backend_remoting_buffer_type_get_name, + /* .alloc_buffer = */ NULL, + /* .get_alignment = */ ggml_backend_remoting_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_remoting_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_remoting_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp new file mode 100644 index 00000000000..6b95362dd80 --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp @@ -0,0 +1,119 @@ +#include "ggml-remoting.h" + +#define BUFFER_TO_GPU(name) ((ggml_backend_remoting_buffer_context *) (name)->context)->gpu + +static void * ggml_backend_remoting_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) buffer->context; + if (context->base) { + return context->base; + } + + context->base = apir_buffer_get_base(BUFFER_TO_GPU(buffer), BUFFER_TO_APIR_CONTEXT(buffer)); + + return context->base; +} + +static void ggml_backend_remoting_buffer_set_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + + ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer); + if (context->is_from_ptr) { + memcpy((char *) tensor->data + offset, data, size); + } else { + apir_buffer_set_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), tensor, data, offset, size); + } + + return; +} + +static void ggml_backend_remoting_buffer_get_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer); + if (context->is_from_ptr) { + memcpy(data, (const char *) tensor->data + offset, size); + } else { + apir_buffer_get_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), tensor, data, offset, size); + } +} + +static void ggml_backend_remoting_buffer_set_tensor_from_ptr(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + UNUSED(buffer); + + memcpy((char *) tensor->data + offset, data, size); + + return; +} + +static void ggml_backend_remoting_buffer_get_tensor_from_ptr(ggml_backend_buffer_t buffer, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + UNUSED(buffer); + + memcpy(data, (const char *) tensor->data + offset, size); +} + +static bool ggml_backend_remoting_buffer_cpy_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * src, + ggml_tensor * dst) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + + bool ret = apir_buffer_cpy_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), src, dst); + + return ret; +} + +static void ggml_backend_remoting_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + + apir_buffer_clear(gpu, BUFFER_TO_APIR_CONTEXT(buffer), value); + + return; +} + +static void ggml_backend_remoting_buffer_free_buffer(ggml_backend_buffer_t buffer) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + + apir_buffer_free_buffer(gpu, BUFFER_TO_APIR_CONTEXT(buffer)); + + ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer); + free(context); + buffer->context = NULL; +} + +const ggml_backend_buffer_i ggml_backend_remoting_buffer_interface = { + /* .free_buffer = */ ggml_backend_remoting_buffer_free_buffer, + /* .get_base = */ ggml_backend_remoting_buffer_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor, + /* .clear = */ ggml_backend_remoting_buffer_clear, + /* .reset = */ NULL, +}; + +const ggml_backend_buffer_i ggml_backend_remoting_buffer_from_ptr_interface = { + /* .free_buffer = */ ggml_backend_remoting_buffer_free_buffer, + /* .get_base = */ ggml_backend_remoting_buffer_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor_from_ptr, + /* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor_from_ptr, + /* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor, + /* .clear = */ ggml_backend_remoting_buffer_clear, + /* .reset = */ NULL, +}; diff --git a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp new file mode 100644 index 00000000000..579eb990781 --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp @@ -0,0 +1,144 @@ +#include "ggml-remoting.h" + +static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + return apir_device_get_name(gpu); +} + +static const char * ggml_backend_remoting_device_get_description(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + return apir_device_get_description(gpu); +} + +static enum ggml_backend_dev_type ggml_backend_remoting_device_get_type(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + static enum ggml_backend_dev_type type; + static bool has_type = false; + if (!has_type) { + has_type = true; + type = (enum ggml_backend_dev_type) apir_device_get_type(gpu); + } + + return type; +} + +static void ggml_backend_remoting_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + virtgpu * gpu = DEV_TO_GPU(dev); + + return apir_device_get_memory(gpu, free, total); +} + +static bool ggml_backend_remoting_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { +#if USE_ALWAYS_TRUE_SUPPORTS_OP == 1 + /* ggml-rpc cheats it like this */ + /* with the current implementation of serialize_tensor, the src/view aren't properly passed */ + UNUSED(dev); + UNUSED(op); + + return true; +#else + virtgpu * gpu = DEV_TO_GPU(dev); + + return apir_device_supports_op(gpu, op); +#endif +} + +static bool ggml_backend_remoting_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + bool supported = buft->device == dev; + + return supported; +} + +static bool ggml_backend_remoting_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + UNUSED(dev); + UNUSED(op); + + return false; +} + +static void ggml_backend_remoting_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_remoting_device_get_name(dev); + props->description = ggml_backend_remoting_device_get_description(dev); + props->type = ggml_backend_remoting_device_get_type(dev); + ggml_backend_remoting_device_get_memory(dev, &props->memory_free, &props->memory_total); + + virtgpu * gpu = DEV_TO_GPU(dev); + apir_device_get_props(gpu, &props->caps.async, &props->caps.host_buffer, &props->caps.buffer_from_host_ptr, + &props->caps.events); + + props->caps.buffer_from_host_ptr = false; + props->caps.async = false; + props->caps.events = false; +} + +ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + apir_buffer_type_host_handle_t ctx = apir_device_get_buffer_type(gpu); + + static ggml_backend_buffer_type buft{ + /* .iface = */ ggml_backend_remoting_buffer_type_interface, + /* .device = */ dev, + /* .context = */ (void *) ctx, + }; + + return &buft; +} + +static ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_from_ptr_type(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + apir_buffer_type_host_handle_t ctx = apir_device_get_buffer_type(gpu); + + static ggml_backend_buffer_type buft{ + /* .iface = */ ggml_backend_remoting_buffer_from_ptr_type_interface, + /* .device = */ dev, + /* .context = */ (void *) ctx, + }; + + return &buft; +} + +static ggml_backend_buffer_t ggml_backend_remoting_device_buffer_from_ptr(ggml_backend_dev_t dev, + void * ptr, + size_t size, + size_t max_tensor_size) { + virtgpu * gpu = DEV_TO_GPU(dev); + + ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context)); + if (!context) { + GGML_ABORT("Couldn't allocate the buffer context ..."); + } + + context->gpu = gpu; + context->apir_context = apir_device_buffer_from_ptr(gpu, size, max_tensor_size); + context->base = ptr; + context->is_from_ptr = true; + + ggml_backend_buffer_t buffer = + ggml_backend_buffer_init(ggml_backend_remoting_device_get_buffer_from_ptr_type(dev), + ggml_backend_remoting_buffer_from_ptr_interface, (void *) context, size); + + return buffer; +} + +const ggml_backend_device_i ggml_backend_remoting_device_interface = { + /* .get_name = */ ggml_backend_remoting_device_get_name, + /* .get_description = */ ggml_backend_remoting_device_get_description, + /* .get_memory = */ ggml_backend_remoting_device_get_memory, + /* .get_type = */ ggml_backend_remoting_device_get_type, + /* .get_props = */ ggml_backend_remoting_device_get_props, + /* .init_backend = */ ggml_backend_remoting_device_init, + /* .get_buffer_type = */ ggml_backend_remoting_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_remoting_device_buffer_from_ptr, + /* .supports_op = */ ggml_backend_remoting_device_supports_op, + /* .supports_buft = */ ggml_backend_remoting_device_supports_buft, + /* .offload_op = */ ggml_backend_remoting_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; diff --git a/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp b/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp new file mode 100644 index 00000000000..c46cf51c022 --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp @@ -0,0 +1,137 @@ +#include "ggml-remoting.h" +#include "ggml-virtgpu.h" + +#include +#include + +static virtgpu * apir_initialize() { + static virtgpu * apir_gpu_instance = NULL; + static bool apir_initialized = false; + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + + if (apir_initialized) { + return apir_gpu_instance; + } + + apir_gpu_instance = create_virtgpu(); + if (!apir_gpu_instance) { + GGML_ABORT("failed to initialize the virtgpu"); + } + + apir_initialized = true; + } + + return apir_gpu_instance; +} + +static int ggml_backend_remoting_get_device_count() { + virtgpu * gpu = apir_initialize(); + if (!gpu) { + GGML_LOG_WARN("apir_initialize failed\n"); + return 0; + } + + return apir_device_get_count(gpu); +} + +static size_t ggml_backend_remoting_reg_get_device_count(ggml_backend_reg_t reg) { + UNUSED(reg); + + return ggml_backend_remoting_get_device_count(); +} + +static std::vector devices; + +ggml_backend_dev_t ggml_backend_remoting_get_device(size_t device) { + GGML_ASSERT(device < devices.size()); + return devices[device]; +} + +static void ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) { + if (devices.size() > 0) { + GGML_LOG_INFO("%s: already initialized\n", __func__); + return; + } + + virtgpu * gpu = apir_initialize(); + if (!gpu) { + GGML_LOG_ERROR("apir_initialize failed\n"); + return; + } + + static bool initialized = false; + + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + for (int i = 0; i < ggml_backend_remoting_get_device_count(); i++) { + ggml_backend_remoting_device_context * ctx = new ggml_backend_remoting_device_context; + char desc[256] = "API Remoting device"; + + ctx->device = i; + ctx->name = GGML_REMOTING_FRONTEND_NAME + std::to_string(i); + ctx->description = desc; + ctx->gpu = gpu; + + ggml_backend_dev_t dev = new ggml_backend_device{ + /* .iface = */ ggml_backend_remoting_device_interface, + /* .reg = */ reg, + /* .context = */ ctx, + }; + devices.push_back(dev); + } + initialized = true; + } + } +} + +static ggml_backend_dev_t ggml_backend_remoting_reg_get_device(ggml_backend_reg_t reg, size_t device) { + UNUSED(reg); + + return ggml_backend_remoting_get_device(device); +} + +static const char * ggml_backend_remoting_reg_get_name(ggml_backend_reg_t reg) { + UNUSED(reg); + + return GGML_REMOTING_FRONTEND_NAME; +} + +static const ggml_backend_reg_i ggml_backend_remoting_reg_i = { + /* .get_name = */ ggml_backend_remoting_reg_get_name, + /* .get_device_count = */ ggml_backend_remoting_reg_get_device_count, + /* .get_device = */ ggml_backend_remoting_reg_get_device, + /* .get_proc_address = */ NULL, +}; + +ggml_backend_reg_t ggml_backend_virtgpu_reg() { + virtgpu * gpu = apir_initialize(); + if (!gpu) { + GGML_LOG_ERROR("virtgpu_apir_initialize failed\n"); + return NULL; + } + + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_remoting_reg_i, + /* .context = */ gpu, + }; + + static bool initialized = false; + if (initialized) { + return ® + } + initialized = true; + + ggml_backend_remoting_reg_init_devices(®); + + GGML_LOG_INFO("%s: initialized\n", __func__); + + return ® +} + +GGML_BACKEND_DL_IMPL(ggml_backend_virtgpu_reg) diff --git a/ggml/src/ggml-virtgpu/ggml-backend.cpp b/ggml/src/ggml-virtgpu/ggml-backend.cpp new file mode 100644 index 00000000000..5cd6c0c0608 --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend.cpp @@ -0,0 +1,69 @@ +#include "ggml-remoting.h" +#include "../../include/ggml-virtgpu.h" + +static const char * ggml_backend_remoting_get_name(ggml_backend_t backend) { + UNUSED(backend); + + return "API Remoting backend"; +} + +static void ggml_backend_remoting_free(ggml_backend_t backend) { + delete backend; +} + +static ggml_status ggml_backend_remoting_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + virtgpu * gpu = DEV_TO_GPU(backend->device); + + return apir_backend_graph_compute(gpu, cgraph); +} + +static void ggml_backend_remoting_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { + virtgpu * gpu = DEV_TO_GPU(backend->device); +#if true + UNUSED(gpu); + UNUSED(cgraph); +#else + // not working yet + + apir_backend_graph_optimize(gpu, cgraph); +#endif +} + +static ggml_backend_i ggml_backend_remoting_interface = { + /* .get_name = */ ggml_backend_remoting_get_name, + /* .free = */ ggml_backend_remoting_free, + /* .set_tensor_async = */ NULL, // ggml_backend_remoting_set_tensor_async, + /* .get_tensor_async = */ NULL, // ggml_backend_remoting_get_tensor_async, + /* .cpy_tensor_async = */ NULL, // ggml_backend_remoting_cpy_tensor_async, + /* .synchronize = */ NULL, // ggml_backend_remoting_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_remoting_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ ggml_backend_remoting_graph_optimize, +}; + +static ggml_guid_t ggml_backend_remoting_guid() { + static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x14, 0x03, 0x86, 0x02, + 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b }; + + return &guid; +} + +ggml_backend_t ggml_backend_remoting_device_init(ggml_backend_dev_t dev, const char * params) { + UNUSED(params); + + ggml_backend_remoting_device_context * ctx = (ggml_backend_remoting_device_context *) dev->context; + + ggml_backend_t remoting_backend = new ggml_backend{ + /* .guid = */ ggml_backend_remoting_guid(), + /* .interface = */ ggml_backend_remoting_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_virtgpu_reg(), ctx->device), + /* .context = */ ctx, + }; + + return remoting_backend; +} diff --git a/ggml/src/ggml-virtgpu/ggml-remoting.h b/ggml/src/ggml-virtgpu/ggml-remoting.h new file mode 100644 index 00000000000..36fc6b2a7bd --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-remoting.h @@ -0,0 +1,68 @@ +#pragma once + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "virtgpu.h" + +#include +#include + +// USE_ALWAYS_TRUE_SUPPORTS_OP: 1 is fast, 0 avoid micro-benchmark crashes + +#define USE_ALWAYS_TRUE_SUPPORTS_OP 1 +#define USE_METAL_GUEST_SUPPORTS_OP 0 + +#define DEV_TO_GPU(name) ((ggml_backend_remoting_device_context *) (name)->context)->gpu + +#define BUFFER_TO_GGML_CONTEXT(name) ((ggml_backend_remoting_buffer_context *) (name)->context) + +#define BUFFER_TO_APIR_CONTEXT(name) &((ggml_backend_remoting_buffer_context *) (name)->context)->apir_context + +#define BUFFER_TO_HOST_HANDLE(name) ((ggml_backend_remoting_buffer_context *) (name)->context)->apir_context.host_handle + +#define GET_DEVICE_CONTEXT() (ggml_backend_remoting_device_context *) ggml_backend_remoting_get_device(0)->context + +#define BUFT_TO_GPU(name) ((ggml_backend_remoting_device_context *) (name)->device->context)->gpu + +struct ggml_backend_remoting_device_context { + size_t device; + std::string name; + std::string description; + + std::vector> shared_memory; + + virtgpu * gpu; +}; + +struct ggml_backend_remoting_buffer_context { + apir_buffer_context_t apir_context; + + virtgpu * gpu; + + void * base; + + bool is_from_ptr; +}; + +extern const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_type_interface; +extern const ggml_backend_device_i ggml_backend_remoting_device_interface; +extern const ggml_backend_buffer_i ggml_backend_remoting_buffer_interface; +extern const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_from_ptr_type_interface; +extern const ggml_backend_buffer_i ggml_backend_remoting_buffer_from_ptr_interface; + +ggml_backend_dev_t ggml_backend_remoting_get_device(size_t device); +ggml_backend_t ggml_backend_remoting_device_init(ggml_backend_dev_t dev, const char * params); +ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev); + +static inline apir_buffer_type_host_handle_t ggml_buffer_type_to_apir_handle(ggml_backend_buffer_type_t buft) { + // in the backend, the buffer handle is the buffer pointer + return (apir_buffer_type_host_handle_t) buft->context; +} + +static inline apir_buffer_host_handle_t ggml_buffer_to_apir_handle(ggml_backend_buffer_t buffer) { + if (!buffer->context) { + GGML_ABORT("%s: no context available :/", __func__); + } + return BUFFER_TO_HOST_HANDLE(buffer); +} diff --git a/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml b/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml new file mode 100644 index 00000000000..0b7cccfe9cf --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml @@ -0,0 +1,168 @@ +# YAML schema for GGML remoting API functions +# This defines the structure for generating the remoting layer code + +# Configuration for the generated files +config: + # Base path for the generated files + base_path: "ggml/src" + + # Header files to update + files: + apir_backend_header: "ggml-virtgpu-apir/backend/shared/apir_backend.gen.h" + backend_dispatched_header: "ggml-virtgpu-apir/backend/backend-dispatched.gen.h" + virtgpu_forward_header: "ggml-virtgpu-apir/virtgpu-forward.gen.h" + +# Simplified function definitions with grouping and metadata combined +functions: + device: + group_description: "device" + functions: + get_device_count: + # No specific metadata - uses default void return and base params + + get_count: + frontend_return: "int" + + get_name: + frontend_return: "const char *" + + get_description: + frontend_return: "const char *" + + get_type: + frontend_return: "uint32_t" + + get_memory: + frontend_return: "void" + frontend_extra_params: + - "size_t *free" + - "size_t *total" + + supports_op: + frontend_return: "bool" + frontend_extra_params: + - "const ggml_tensor *op" + + get_buffer_type: + frontend_return: "apir_buffer_type_host_handle_t" + + get_props: + frontend_return: "void" + frontend_extra_params: + - "bool *async" + - "bool *host_buffer" + - "bool *buffer_from_host_ptr" + - "bool *events" + + buffer_from_ptr: + frontend_return: "apir_buffer_context_t" + frontend_extra_params: + - "size_t size" + - "size_t max_tensor_size" + + buffer_type: + group_description: "buffer-type" + functions: + get_name: + frontend_return: "const char *" + frontend_extra_params: + - "ggml_backend_buffer_type_t buft" + + get_alignment: + frontend_return: "size_t" + frontend_extra_params: + - "ggml_backend_buffer_type_t buft" + + get_max_size: + frontend_return: "size_t" + frontend_extra_params: + - "ggml_backend_buffer_type_t buft" + + is_host: + frontend_return: "bool" + frontend_extra_params: + - "ggml_backend_buffer_type_t buft" + + alloc_buffer: + frontend_return: "apir_buffer_context_t" + frontend_extra_params: + - "ggml_backend_buffer_type_t buffer_buft" + - "size_t size" + + get_alloc_size: + frontend_return: "size_t" + frontend_extra_params: + - "ggml_backend_buffer_type_t buft" + - "const ggml_tensor *op" + + buffer: + group_description: "buffer" + functions: + get_base: + frontend_return: "void *" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + + set_tensor: + frontend_return: "void" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + - "ggml_tensor *tensor" + - "const void *data" + - "size_t offset" + - "size_t size" + + get_tensor: + frontend_return: "void" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + - "const ggml_tensor *tensor" + - "void *data" + - "size_t offset" + - "size_t size" + + cpy_tensor: + frontend_return: "bool" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + - "const ggml_tensor *src" + - "const ggml_tensor *dst" + + clear: + frontend_return: "void" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + - "uint8_t value" + + free_buffer: + frontend_return: "void" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + + backend: + group_description: "backend" + functions: + graph_compute: + frontend_return: "ggml_status" + frontend_extra_params: + - "ggml_cgraph *cgraph" + + graph_optimize: + frontend_return: "ggml_cgraph *" + frontend_extra_params: + - "ggml_cgraph *cgraph" + enabled: false + +# Naming patterns used for code generation +naming_patterns: + # How to generate enum names + enum_prefix: "APIR_COMMAND_TYPE_" + + # How to generate backend function names + backend_function_prefix: "backend_" + + # How to generate frontend function names + frontend_function_prefix: "apir_" + + # Standard frontend first parameter + frontend_base_param: "struct virtgpu *gpu" diff --git a/ggml/src/ggml-virtgpu/include/apir_hw.h b/ggml/src/ggml-virtgpu/include/apir_hw.h new file mode 100644 index 00000000000..33af045ca2b --- /dev/null +++ b/ggml/src/ggml-virtgpu/include/apir_hw.h @@ -0,0 +1,9 @@ +#pragma once + +#include + +struct virgl_renderer_capset_apir { + uint32_t apir_version; + uint32_t supports_blob_resources; + uint32_t reserved[4]; // For future expansion +}; diff --git a/ggml/src/ggml-virtgpu/regenerate_remoting.py b/ggml/src/ggml-virtgpu/regenerate_remoting.py new file mode 100755 index 00000000000..4174a24327f --- /dev/null +++ b/ggml/src/ggml-virtgpu/regenerate_remoting.py @@ -0,0 +1,322 @@ +#!/usr/bin/env python3 +""" +# Generated by Claude AI + +Script to completely regenerate the GGML remoting codebase from YAML configuration. + +This script reads api_functions.yaml and regenerates all the header files and +implementation templates for the GGML remoting layer. + +Usage: + python regenerate_remoting.py + +The script will: +1. Read ggmlremoting_functions.yaml configuration +2. Generate updated header files +3. Generate implementation templates in dedicated files +4. Show a summary of what was generated +""" + +import yaml +from typing import Dict, List, Any +from pathlib import Path +import os +import subprocess +import shutil +import logging + +NL = '\n' # can't have f"{'\n'}" in f-strings + + +class RemotingCodebaseGenerator: + def __init__(self, yaml_path: str = "ggmlremoting_functions.yaml"): + """Initialize the generator with the YAML configuration.""" + self.yaml_path = yaml_path + + if not Path(yaml_path).exists(): + raise FileNotFoundError(f"Configuration file {yaml_path} not found") + + with open(yaml_path, 'r') as f: + self.config = yaml.safe_load(f) + + self.functions = self.config['functions'] + self.naming_patterns = self.config['naming_patterns'] + self.config_data = self.config['config'] + + # Check if clang-format is available + self.clang_format_available = self._check_clang_format_available() + + def _check_clang_format_available(self) -> bool: + """Check if clang-format is available in the system PATH.""" + return shutil.which("clang-format") is not None + + def _format_file_with_clang_format(self, file_path: Path) -> bool: + """Format a file with clang-format -i. Returns True if successful, False otherwise.""" + if not self.clang_format_available: + return False + + try: + subprocess.run( + ["clang-format", "-i", str(file_path)], + check=True, + capture_output=True, + text=True + ) + return True + except subprocess.CalledProcessError: + logging.exception(f" ⚠️ clang-format failed for {file_path}") + return False + except Exception as e: + logging.exception(f" ⚠️ Unexpected error formatting {file_path}: {e}") + return False + + def generate_enum_name(self, group_name: str, function_name: str) -> str: + """Generate the APIR_COMMAND_TYPE enum name for a function.""" + prefix = self.naming_patterns['enum_prefix'] + return f"{prefix}{group_name.upper()}_{function_name.upper()}" + + def generate_backend_function_name(self, group_name: str, function_name: str) -> str: + """Generate the backend function name.""" + function_key = f"{group_name}_{function_name}" + overrides = self.naming_patterns.get('backend_function_overrides', {}) + + if function_key in overrides: + return overrides[function_key] + + prefix = self.naming_patterns['backend_function_prefix'] + return f"{prefix}{group_name}_{function_name}" + + def generate_frontend_function_name(self, group_name: str, function_name: str) -> str: + """Generate the frontend function name.""" + prefix = self.naming_patterns['frontend_function_prefix'] + return f"{prefix}{group_name}_{function_name}" + + def get_enabled_functions(self) -> List[Dict[str, Any]]: + """Get all enabled functions with their metadata.""" + functions = [] + enum_value = 0 + + for group_name, group_data in self.functions.items(): + group_description = group_data['group_description'] + + for function_name, func_metadata in group_data['functions'].items(): + # Handle case where func_metadata is None or empty (functions with only comments) + if func_metadata is None: + func_metadata = {} + + # Functions are enabled by default unless explicitly disabled + if func_metadata.get('enabled', True): + functions.append({ + 'group_name': group_name, + 'function_name': function_name, + 'enum_name': self.generate_enum_name(group_name, function_name), + 'enum_value': enum_value, + 'backend_function': self.generate_backend_function_name(group_name, function_name), + 'frontend_function': self.generate_frontend_function_name(group_name, function_name), + 'frontend_return': func_metadata.get('frontend_return', 'void'), + 'frontend_extra_params': func_metadata.get('frontend_extra_params', []), + 'group_description': group_description, + 'newly_added': func_metadata.get('newly_added', False) + }) + enum_value += 1 + + return functions + + def generate_apir_backend_header(self) -> str: + """Generate the complete apir_backend.h file.""" + functions = self.get_enabled_functions() + + # Generate the enum section + enum_lines = ["typedef enum ApirBackendCommandType {"] + current_group = None + + for func in functions: + # Add comment for new group + if func['group_name'] != current_group: + enum_lines.append("") + enum_lines.append(f" /* {func['group_description']} */") + current_group = func['group_name'] + + enum_lines.append(f" {func['enum_name']} = {func['enum_value']},") + + # Add the count + total_count = len(functions) + enum_lines.append("\n // last command_type index + 1") + enum_lines.append(f" APIR_BACKEND_DISPATCH_TABLE_COUNT = {total_count},") + enum_lines.append("} ApirBackendCommandType;") + + # Full header template + header_content = NL.join(enum_lines) + "\n" + + return header_content + + def generate_backend_dispatched_header(self) -> str: + """Generate the complete backend-dispatched.h file.""" + functions = self.get_enabled_functions() + + # Function declarations + decl_lines = [] + current_group = None + + for func in functions: + if func['group_name'] != current_group: + decl_lines.append(f"\n/* {func['group_description']} */") + current_group = func['group_name'] + + signature = "uint32_t" + params = "apir_encoder *enc, apir_decoder *dec, virgl_apir_context *ctx" + decl_lines.append(f"{signature} {func['backend_function']}({params});") + + # Switch cases + switch_lines = [] + current_group = None + + for func in functions: + if func['group_name'] != current_group: + switch_lines.append(f" /* {func['group_description']} */") + current_group = func['group_name'] + + switch_lines.append(f" case {func['enum_name']}: return \"{func['backend_function']}\";") + + # Dispatch table + table_lines = [] + current_group = None + + for func in functions: + if func['group_name'] != current_group: + table_lines.append(f"\n /* {func['group_description']} */") + table_lines.append("") + current_group = func['group_name'] + + table_lines.append(f" /* {func['enum_name']} = */ {func['backend_function']},") + + header_content = f'''\ +#pragma once + +{NL.join(decl_lines)} + +static inline const char *backend_dispatch_command_name(ApirBackendCommandType type) +{{ + switch (type) {{ +{NL.join(switch_lines)} + + default: return "unknown"; + }} +}} + +extern "C" {{ +static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {{ + {NL.join(table_lines)} +}}; +}} +''' + return header_content + + def generate_virtgpu_forward_header(self) -> str: + """Generate the complete virtgpu-forward.gen.h file.""" + functions = self.get_enabled_functions() + + decl_lines = [] + current_group = None + + for func in functions: + if func['group_name'] != current_group: + decl_lines.append("") + decl_lines.append(f"/* {func['group_description']} */") + current_group = func['group_name'] + + # Build parameter list + params = [self.naming_patterns['frontend_base_param']] + params.extend(func['frontend_extra_params']) + param_str = ', '.join(params) + + decl_lines.append(f"{func['frontend_return']} {func['frontend_function']}({param_str});") + + header_content = f'''\ +#pragma once +{NL.join(decl_lines)} +''' + return header_content + + def regenerate_codebase(self) -> None: + """Regenerate the entire remoting codebase.""" + logging.info("🔄 Regenerating GGML Remoting Codebase...") + logging.info("=" * 50) + + # Detect if we're running from frontend directory + current_dir = os.getcwd() + is_frontend_dir = current_dir.endswith('ggml-virtgpu') + + if is_frontend_dir: + # Running from ggml/src/ggml-virtgpu-apir + logging.info("📍 Detected frontend directory execution") + frontend_base = Path(".") + else: + # Running from project root (fallback to original behavior) + logging.info("📍 Detected project root execution") + base_path = self.config_data.get('base_path', 'ggml/src') + frontend_base = Path(base_path) / "ggml-virtgpu" + + # Compute final file paths + backend_base = frontend_base / "backend" + apir_backend_path = backend_base / "shared" / "apir_backend.gen.h" + backend_dispatched_path = backend_base / "backend-dispatched.gen.h" + virtgpu_forward_path = frontend_base / "virtgpu-forward.gen.h" + + # Create output directories for each file + apir_backend_path.parent.mkdir(parents=True, exist_ok=True) + backend_dispatched_path.parent.mkdir(parents=True, exist_ok=True) + virtgpu_forward_path.parent.mkdir(parents=True, exist_ok=True) + + # Generate header files + logging.info("📁 Generating header files...") + + apir_backend_content = self.generate_apir_backend_header() + apir_backend_path.write_text(apir_backend_content) + logging.info(f" ✅ {apir_backend_path.resolve()}") + + backend_dispatched_content = self.generate_backend_dispatched_header() + backend_dispatched_path.write_text(backend_dispatched_content) + logging.info(f" ✅ {backend_dispatched_path.resolve()}") + + virtgpu_forward_content = self.generate_virtgpu_forward_header() + virtgpu_forward_path.write_text(virtgpu_forward_content) + logging.info(f" ✅ {virtgpu_forward_path.resolve()}") + + # Format generated files with clang-format + generated_files = [apir_backend_path, backend_dispatched_path, virtgpu_forward_path] + + if not self.clang_format_available: + logging.warning("\n⚠️clang-format not found in PATH. Generated files will not be formatted." + " Install clang-format to enable automatic code formatting.") + else: + logging.info("\n🎨 Formatting files with clang-format...") + for file_path in generated_files: + if self._format_file_with_clang_format(file_path): + logging.info(f" ✅ Formatted {file_path.name}") + else: + logging.warning(f" ❌ Failed to format {file_path.name}") + + # Generate summary + functions = self.get_enabled_functions() + total_functions = len(functions) + + logging.info("\n📊 Generation Summary:") + logging.info("=" * 50) + logging.info(f" Total functions: {total_functions}") + logging.info(f" Function groups: {len(self.functions)}") + logging.info(" Header files: 3") + logging.info(f" Working directory: {current_dir}") + + +def main(): + try: + generator = RemotingCodebaseGenerator() + generator.regenerate_codebase() + except Exception as e: + logging.exception(f"❌ Error: {e}") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/ggml/src/ggml-virtgpu/virtgpu-apir.h b/ggml/src/ggml-virtgpu/virtgpu-apir.h new file mode 100644 index 00000000000..238f960acd2 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-apir.h @@ -0,0 +1,15 @@ +#include "backend/shared/apir_backend.h" +#include "ggml-alloc.h" +#include "ggml-impl.h" +#include "ggml.h" +#include "virtgpu-shm.h" +#include "virtgpu-utils.h" + +struct apir_buffer_context_t { + apir_buffer_host_handle_t host_handle; + + struct virtgpu_shmem shmem; + apir_buffer_type_host_handle_t buft_host_handle; +}; + +#include "virtgpu-forward.gen.h" diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp new file mode 100644 index 00000000000..bf3c41011ac --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp @@ -0,0 +1,50 @@ +#include "virtgpu-forward-impl.h" + +static long long current_time_ms() { + timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); // Use CLOCK_MONOTONIC for elapsed time + return (long long) ts.tv_sec * 1000000000LL + ts.tv_nsec; +} + +ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE); + + std::vector cgraph_data; + size_t cgraph_size = apir_serialize_ggml_cgraph(cgraph, cgraph_data); + + virtgpu_shmem temp_shmem; // Local storage for large buffers + virtgpu_shmem * shmem = &temp_shmem; + + if (cgraph_size <= gpu->data_shmem.mmap_size) { + // prefer the init-time allocated page, if large enough + shmem = &gpu->data_shmem; + } else if (virtgpu_shmem_create(gpu, cgraph_size, shmem)) { + GGML_ABORT("Couldn't allocate the guest-host shared buffer"); + } + + apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id); + + apir_encode_size_t(encoder, &cgraph_size); + + char * shmem_data = (char *) shmem->mmap_ptr; + apir_encoder secondary_enc = apir_new_encoder(shmem_data, cgraph_size); + + apir_encode_cgraph_data(&secondary_enc, cgraph_data); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + ggml_status status = GGML_STATUS_ABORTED; + apir_decode_ggml_status(decoder, &status); + + remote_call_finish(gpu, encoder, decoder); + + if (shmem != &gpu->data_shmem) { + virtgpu_shmem_destroy(gpu, shmem); + } + + return status; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp new file mode 100644 index 00000000000..03cb09e0643 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp @@ -0,0 +1,125 @@ +#include "virtgpu-forward-impl.h" + +const char * apir_buffer_type_get_name(virtgpu * gpu, ggml_backend_buffer_type_t buft) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME); + + apir_encode_ggml_buffer_type(encoder, buft); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + const size_t string_size = apir_decode_array_size_unchecked(decoder); + char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); + if (!string) { + GGML_LOG_ERROR("%s: Could not allocate the device name buffer\n", __func__); + apir_decoder_set_fatal(decoder); + } + apir_decode_char_array(decoder, string, string_size); + + remote_call_finish(gpu, encoder, decoder); + + return string; +} + +size_t apir_buffer_type_get_alignment(virtgpu * gpu, ggml_backend_buffer_type_t buft) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT); + + apir_encode_ggml_buffer_type(encoder, buft); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + size_t alignment; + apir_decode_size_t(decoder, &alignment); + + remote_call_finish(gpu, encoder, decoder); + + return alignment; +} + +size_t apir_buffer_type_get_max_size(virtgpu * gpu, ggml_backend_buffer_type_t buft) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE); + + apir_encode_ggml_buffer_type(encoder, buft); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + size_t max_size; + apir_decode_size_t(decoder, &max_size); + + remote_call_finish(gpu, encoder, decoder); + + return max_size; +} + +bool apir_buffer_type_is_host(virtgpu * gpu, ggml_backend_buffer_type_t buft) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST); + + apir_encode_ggml_buffer_type(encoder, buft); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + bool is_host; + apir_decode_bool_t(decoder, &is_host); + + remote_call_finish(gpu, encoder, decoder); + + return is_host; +} + +apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, ggml_backend_buffer_type_t buft, size_t size) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + apir_buffer_context_t buffer_context; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER); + + apir_encode_ggml_buffer_type(encoder, buft); + + apir_encode_size_t(encoder, &size); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_apir_buffer_host_handle_t(decoder, &buffer_context.host_handle); + + remote_call_finish(gpu, encoder, decoder); + + return buffer_context; +} + +size_t apir_buffer_type_get_alloc_size(virtgpu * gpu, ggml_backend_buffer_type_t buft, const ggml_tensor * op) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE); + + apir_encode_ggml_buffer_type(encoder, buft); + + apir_encode_ggml_tensor_inline(encoder, op); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + size_t alloc_size; + apir_decode_size_t(decoder, &alloc_size); + + remote_call_finish(gpu, encoder, decoder); + + return alloc_size; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp new file mode 100644 index 00000000000..3181e394407 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp @@ -0,0 +1,157 @@ +#include "virtgpu-forward-impl.h" + +void * apir_buffer_get_base(virtgpu * gpu, apir_buffer_context_t * buffer_context) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_GET_BASE); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + uintptr_t base; + apir_decode_uintptr_t(decoder, &base); + + remote_call_finish(gpu, encoder, decoder); + + return (void *) base; +} + +void apir_buffer_set_tensor(virtgpu * gpu, + apir_buffer_context_t * buffer_context, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_SET_TENSOR); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + apir_encode_ggml_tensor(encoder, tensor); + + virtgpu_shmem temp_shmem; // Local storage for large buffers + virtgpu_shmem * shmem = &temp_shmem; + + if (size <= gpu->data_shmem.mmap_size) { + // prefer the init-time allocated page, if large enough + shmem = &gpu->data_shmem; + + } else if (virtgpu_shmem_create(gpu, size, shmem)) { + GGML_ABORT("Couldn't allocate the guest-host shared buffer"); + } + + memcpy(shmem->mmap_ptr, data, size); + apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id); + + apir_encode_size_t(encoder, &offset); + apir_encode_size_t(encoder, &size); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + remote_call_finish(gpu, encoder, decoder); + + if (shmem != &gpu->data_shmem) { + virtgpu_shmem_destroy(gpu, shmem); + } + + return; +} + +void apir_buffer_get_tensor(virtgpu * gpu, + apir_buffer_context_t * buffer_context, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_GET_TENSOR); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + apir_encode_ggml_tensor(encoder, tensor); + + virtgpu_shmem temp_shmem; // Local storage for large buffers + virtgpu_shmem * shmem = &temp_shmem; + + if (size <= gpu->data_shmem.mmap_size) { + // prefer the init-time allocated page, if large enough + shmem = &gpu->data_shmem; + + } else if (virtgpu_shmem_create(gpu, size, shmem)) { + GGML_ABORT("Couldn't allocate the guest-host shared buffer"); + } + + apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id); + apir_encode_size_t(encoder, &offset); + apir_encode_size_t(encoder, &size); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + memcpy(data, shmem->mmap_ptr, size); + + remote_call_finish(gpu, encoder, decoder); + + if (shmem != &gpu->data_shmem) { + virtgpu_shmem_destroy(gpu, shmem); + } +} + +bool apir_buffer_cpy_tensor(virtgpu * gpu, + apir_buffer_context_t * buffer_context, + const ggml_tensor * src, + const ggml_tensor * dst) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + apir_encode_ggml_tensor(encoder, src); + apir_encode_ggml_tensor(encoder, dst); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + bool ret_val; + apir_decode_bool_t(decoder, &ret_val); + + remote_call_finish(gpu, encoder, decoder); + + return ret_val; +} + +void apir_buffer_clear(virtgpu * gpu, apir_buffer_context_t * buffer_context, uint8_t value) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_CLEAR); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + apir_encode_uint8_t(encoder, &value); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + remote_call_finish(gpu, encoder, decoder); +} + +void apir_buffer_free_buffer(virtgpu * gpu, apir_buffer_context_t * buffer_context) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + remote_call_finish(gpu, encoder, decoder); +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp new file mode 100644 index 00000000000..3e45e55bdcb --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp @@ -0,0 +1,200 @@ +#include "virtgpu-forward-impl.h" +#include "virtgpu-shm.h" + +int apir_device_get_count(virtgpu * gpu) { + static int32_t dev_count = -1; + if (dev_count != -1) { + return dev_count; + } + + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_COUNT); + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_int32_t(decoder, &dev_count); + + remote_call_finish(gpu, encoder, decoder); + + return dev_count; +} + +const char * apir_device_get_name(virtgpu * gpu) { + static char * string = nullptr; + if (string) { + return string; + } + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_NAME); + REMOTE_CALL(gpu, encoder, decoder, ret); + + const size_t string_size = apir_decode_array_size_unchecked(decoder); + string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); + if (!string) { + GGML_LOG_ERROR("%s: Could not allocate the device name buffer\n", __func__); + return NULL; + } + apir_decode_char_array(decoder, string, string_size); + + remote_call_finish(gpu, encoder, decoder); + + return string; +} + +const char * apir_device_get_description(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + const size_t string_size = apir_decode_array_size_unchecked(decoder); + char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); + if (!string) { + GGML_LOG_ERROR("%s: Could not allocate the device description buffer\n", __func__); + + return NULL; + } + apir_decode_char_array(decoder, string, string_size); + + remote_call_finish(gpu, encoder, decoder); + + return string; +} + +uint32_t apir_device_get_type(virtgpu * gpu) { + static uint32_t dev_type = 255; + if (dev_type != 255) { + return dev_type; + } + + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_TYPE); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_uint32_t(decoder, &dev_type); + + remote_call_finish(gpu, encoder, decoder); + + return dev_type; +} + +void apir_device_get_memory(virtgpu * gpu, size_t * free, size_t * total) { + static size_t dev_free = 0; + static size_t dev_total = 0; + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_MEMORY); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_size_t(decoder, &dev_free); + apir_decode_size_t(decoder, &dev_total); + + *free = dev_free; + *total = dev_total; + + remote_call_finish(gpu, encoder, decoder); + + return; +} + +bool apir_device_supports_op(virtgpu * gpu, const ggml_tensor * op) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP); + + apir_encode_ggml_tensor_inline(encoder, op); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + bool supports_op; + apir_decode_bool_t(decoder, &supports_op); + + remote_call_finish(gpu, encoder, decoder); + + return supports_op; +} + +apir_buffer_type_host_handle_t apir_device_get_buffer_type(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_buffer_type_host_handle_t buft_handle; + apir_decode_apir_buffer_type_host_handle_t(decoder, &buft_handle); + + remote_call_finish(gpu, encoder, decoder); + + return buft_handle; +} + +void apir_device_get_props(virtgpu * gpu, + bool * async, + bool * host_buffer, + bool * buffer_from_host_ptr, + bool * events) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_PROPS); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_bool_t(decoder, async); + apir_decode_bool_t(decoder, host_buffer); + apir_decode_bool_t(decoder, buffer_from_host_ptr); + apir_decode_bool_t(decoder, events); + + remote_call_finish(gpu, encoder, decoder); + + return; +} + +apir_buffer_context_t apir_device_buffer_from_ptr(virtgpu * gpu, size_t size, size_t max_tensor_size) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + apir_buffer_context_t buffer_context; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR); + + if (virtgpu_shmem_create(gpu, size, &buffer_context.shmem)) { + GGML_ABORT("Couldn't allocate the guest-host shared buffer"); + } + + apir_encode_virtgpu_shmem_res_id(encoder, buffer_context.shmem.res_id); + + apir_encode_size_t(encoder, &size); + apir_encode_size_t(encoder, &max_tensor_size); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_apir_buffer_host_handle_t(decoder, &buffer_context.host_handle); + buffer_context.buft_host_handle = apir_decode_apir_buffer_type_host_handle(decoder); + + remote_call_finish(gpu, encoder, decoder); + + return buffer_context; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h b/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h new file mode 100644 index 00000000000..eea3e7e5a9b --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h @@ -0,0 +1,29 @@ +#include "virtgpu.h" + +#include "ggml-remoting.h" +#include "backend/shared/apir_backend.h" +#include "backend/shared/apir_cs_ggml.h" + +#include "ggml-backend-impl.h" + +#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__) \ + do { \ + int32_t forward_flag = (int32_t) apir_command_type__; \ + encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, forward_flag); \ + if (!encoder_name) { \ + GGML_ABORT("%s: failed to prepare the remote call encoder", __func__); \ + } \ + } while (0) + +#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name) \ + do { \ + ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \ + if (!decoder_name) { \ + GGML_ABORT("%s: failed to kick the remote call", __func__); \ + } \ + if (ret_name < APIR_FORWARD_BASE_INDEX) { \ + GGML_ABORT("%s: failed to forward the API call: %s: code %d", __func__, \ + apir_forward_error(ret_name), ret_name); \ + } \ + ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX); \ + } while (0) diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h b/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h new file mode 100644 index 00000000000..c27c07f0865 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h @@ -0,0 +1,51 @@ +#pragma once + +/* device */ +void apir_device_get_device_count(struct virtgpu * gpu); +int apir_device_get_count(struct virtgpu * gpu); +const char * apir_device_get_name(struct virtgpu * gpu); +const char * apir_device_get_description(struct virtgpu * gpu); +uint32_t apir_device_get_type(struct virtgpu * gpu); +void apir_device_get_memory(struct virtgpu * gpu, size_t * free, size_t * total); +bool apir_device_supports_op(struct virtgpu * gpu, const ggml_tensor * op); +apir_buffer_type_host_handle_t apir_device_get_buffer_type(struct virtgpu * gpu); +void apir_device_get_props(struct virtgpu * gpu, + bool * async, + bool * host_buffer, + bool * buffer_from_host_ptr, + bool * events); +apir_buffer_context_t apir_device_buffer_from_ptr(struct virtgpu * gpu, size_t size, size_t max_tensor_size); + +/* buffer-type */ +const char * apir_buffer_type_get_name(struct virtgpu * gpu, ggml_backend_buffer_type_t buft); +size_t apir_buffer_type_get_alignment(struct virtgpu * gpu, ggml_backend_buffer_type_t buft); +size_t apir_buffer_type_get_max_size(struct virtgpu * gpu, ggml_backend_buffer_type_t buft); +bool apir_buffer_type_is_host(struct virtgpu * gpu, ggml_backend_buffer_type_t buft); +apir_buffer_context_t apir_buffer_type_alloc_buffer(struct virtgpu * gpu, + ggml_backend_buffer_type_t buffer_buft, + size_t size); +size_t apir_buffer_type_get_alloc_size(struct virtgpu * gpu, ggml_backend_buffer_type_t buft, const ggml_tensor * op); + +/* buffer */ +void * apir_buffer_get_base(struct virtgpu * gpu, apir_buffer_context_t * buffer_context); +void apir_buffer_set_tensor(struct virtgpu * gpu, + apir_buffer_context_t * buffer_context, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size); +void apir_buffer_get_tensor(struct virtgpu * gpu, + apir_buffer_context_t * buffer_context, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size); +bool apir_buffer_cpy_tensor(struct virtgpu * gpu, + apir_buffer_context_t * buffer_context, + const ggml_tensor * src, + const ggml_tensor * dst); +void apir_buffer_clear(struct virtgpu * gpu, apir_buffer_context_t * buffer_context, uint8_t value); +void apir_buffer_free_buffer(struct virtgpu * gpu, apir_buffer_context_t * buffer_context); + +/* backend */ +ggml_status apir_backend_graph_compute(struct virtgpu * gpu, ggml_cgraph * cgraph); diff --git a/ggml/src/ggml-virtgpu/virtgpu-shm.cpp b/ggml/src/ggml-virtgpu/virtgpu-shm.cpp new file mode 100644 index 00000000000..4def405a62b --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-shm.cpp @@ -0,0 +1,99 @@ +#include "virtgpu-shm.h" + +#include "virtgpu.h" + +#include + +static uint32_t virtgpu_ioctl_resource_create_blob(virtgpu * gpu, + uint32_t blob_mem, + uint32_t blob_flags, + size_t blob_size, + uint64_t blob_id, + uint32_t * res_id) { +#ifdef SIMULATE_BO_SIZE_FIX + blob_size = align64(blob_size, 4096); +#endif + + drm_virtgpu_resource_create_blob args = { + .blob_mem = blob_mem, + .blob_flags = blob_flags, + .bo_handle = 0, + .res_handle = 0, + .size = blob_size, + .pad = 0, + .cmd_size = 0, + .cmd = 0, + .blob_id = blob_id, + }; + + if (virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_RESOURCE_CREATE_BLOB, &args)) { + return 0; + } + + *res_id = args.res_handle; + return args.bo_handle; +} + +static void virtgpu_ioctl_gem_close(virtgpu * gpu, uint32_t gem_handle) { + drm_gem_close args = { + .handle = gem_handle, + .pad = 0, + }; + + const int ret = virtgpu_ioctl(gpu, DRM_IOCTL_GEM_CLOSE, &args); + assert(!ret); +#ifdef NDEBUG + UNUSED(ret); +#endif +} + +static void * virtgpu_ioctl_map(virtgpu * gpu, uint32_t gem_handle, size_t size) { + drm_virtgpu_map args = { + .offset = 0, + .handle = gem_handle, + .pad = 0, + }; + + if (virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_MAP, &args)) { + return NULL; + } + + void * ptr = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, gpu->fd, args.offset); + if (ptr == MAP_FAILED) { + return NULL; + } + + return ptr; +} + +void virtgpu_shmem_destroy(virtgpu * gpu, virtgpu_shmem * shmem) { + munmap(shmem->mmap_ptr, shmem->mmap_size); + virtgpu_ioctl_gem_close(gpu, shmem->gem_handle); +} + +int virtgpu_shmem_create(virtgpu * gpu, size_t size, virtgpu_shmem * shmem) { + size = align64(size, 16384); + + uint32_t res_id; + uint32_t gem_handle = virtgpu_ioctl_resource_create_blob(gpu, VIRTGPU_BLOB_MEM_HOST3D, + VIRTGPU_BLOB_FLAG_USE_MAPPABLE, size, 0, &res_id); + + if (!gem_handle) { + return 1; + } + + void * ptr = virtgpu_ioctl_map(gpu, gem_handle, size); + if (!ptr) { + virtgpu_ioctl_gem_close(gpu, gem_handle); + GGML_LOG_ERROR("virtgpu_ioctl_map FAILED\n"); + exit(1); + return 1; + } + + shmem->res_id = res_id; + shmem->mmap_size = size; + shmem->mmap_ptr = ptr; + shmem->gem_handle = gem_handle; + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-shm.h b/ggml/src/ggml-virtgpu/virtgpu-shm.h new file mode 100644 index 00000000000..606860a0946 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-shm.h @@ -0,0 +1,23 @@ +#pragma once + +#include "virtgpu-utils.h" + +#include + +#include +#include +#include +#include + +struct virtgpu; + +struct virtgpu_shmem { + uint32_t res_id; + size_t mmap_size; + void * mmap_ptr; + + uint32_t gem_handle; +}; + +int virtgpu_shmem_create(virtgpu * gpu, size_t size, virtgpu_shmem * shmem); +void virtgpu_shmem_destroy(virtgpu * gpu, virtgpu_shmem * shmem); diff --git a/ggml/src/ggml-virtgpu/virtgpu-utils.cpp b/ggml/src/ggml-virtgpu/virtgpu-utils.cpp new file mode 100644 index 00000000000..8a2805e9902 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-utils.cpp @@ -0,0 +1,179 @@ +#include "virtgpu-utils.h" + +#include +#include + +#include + +#define NODE_ALLOC_ALIGN 64 +#define NODE_PTR_MASK (~((uintptr_t) NODE_ALLOC_ALIGN - 1)) +#define NODE_LEVEL_MASK ((uintptr_t) NODE_ALLOC_ALIGN - 1) +#define NULL_NODE 0 + +#define os_malloc_aligned(_size, _align) _aligned_malloc(_size, _align) +#define os_free_aligned(_ptr) free(_ptr) +#define p_atomic_cmpxchg(v, old, _new) __sync_val_compare_and_swap((v), (old), (_new)) + +static inline uint64_t util_logbase2_64(uint64_t n) { +#if defined(HAVE___BUILTIN_CLZLL) + return ((sizeof(uint64_t) * 8 - 1) - __builtin_clzll(n | 1)); +#else + uint64_t pos = 0ull; + if (n >= 1ull << 32) { + n >>= 32; + pos += 32; + } + if (n >= 1ull << 16) { + n >>= 16; + pos += 16; + } + if (n >= 1ull << 8) { + n >>= 8; + pos += 8; + } + if (n >= 1ull << 4) { + n >>= 4; + pos += 4; + } + if (n >= 1ull << 2) { + n >>= 2; + pos += 2; + } + if (n >= 1ull << 1) { + pos += 1; + } + return pos; +#endif +} + +void util_sparse_array_init(util_sparse_array * arr, size_t elem_size, size_t node_size) { + memset(arr, 0, sizeof(*arr)); + arr->elem_size = elem_size; + arr->node_size_log2 = util_logbase2_64(node_size); + assert(node_size >= 2 && node_size == (1ull << arr->node_size_log2)); +} + +static inline void * os_malloc_aligned(size_t size, size_t alignment) { + void * ptr; + alignment = (alignment + sizeof(void *) - 1) & ~(sizeof(void *) - 1); + if (posix_memalign(&ptr, alignment, size) != 0) { + return NULL; + } + return ptr; +} + +static inline void * _util_sparse_array_node_data(uintptr_t handle) { + return (void *) (handle & NODE_PTR_MASK); +} + +static inline unsigned _util_sparse_array_node_level(uintptr_t handle) { + return handle & NODE_LEVEL_MASK; +} + +static inline void _util_sparse_array_node_finish(util_sparse_array * arr, uintptr_t node) { + if (_util_sparse_array_node_level(node) > 0) { + uintptr_t * children = (uintptr_t *) _util_sparse_array_node_data(node); + size_t node_size = 1ull << arr->node_size_log2; + for (size_t i = 0; i < node_size; i++) { + if (children[i]) { + _util_sparse_array_node_finish(arr, children[i]); + } + } + } + + os_free_aligned(_util_sparse_array_node_data(node)); +} + +static inline uintptr_t _util_sparse_array_node(void * data, unsigned level) { + assert(data != NULL); + assert(((uintptr_t) data & NODE_LEVEL_MASK) == 0); + assert((level & NODE_PTR_MASK) == 0); + return (uintptr_t) data | level; +} + +inline uintptr_t _util_sparse_array_node_alloc(util_sparse_array * arr, unsigned level) { + size_t size; + if (level == 0) { + size = arr->elem_size << arr->node_size_log2; + } else { + size = sizeof(uintptr_t) << arr->node_size_log2; + } + + void * data = os_malloc_aligned(size, NODE_ALLOC_ALIGN); + memset(data, 0, size); + + return _util_sparse_array_node(data, level); +} + +static inline uintptr_t _util_sparse_array_set_or_free_node(uintptr_t * node_ptr, uintptr_t cmp_node, uintptr_t node) { + uintptr_t prev_node = p_atomic_cmpxchg(node_ptr, cmp_node, node); + + if (prev_node != cmp_node) { + /* We lost the race. Free this one and return the one that was already + * allocated. + */ + os_free_aligned(_util_sparse_array_node_data(node)); + return prev_node; + } else { + return node; + } +} + +void * util_sparse_array_get(util_sparse_array * arr, uint64_t idx) { + const unsigned node_size_log2 = arr->node_size_log2; + uintptr_t root = p_atomic_read(&arr->root); + if (unlikely(!root)) { + unsigned root_level = 0; + uint64_t idx_iter = idx >> node_size_log2; + while (idx_iter) { + idx_iter >>= node_size_log2; + root_level++; + } + uintptr_t new_root = _util_sparse_array_node_alloc(arr, root_level); + root = _util_sparse_array_set_or_free_node(&arr->root, NULL_NODE, new_root); + } + + while (1) { + unsigned root_level = _util_sparse_array_node_level(root); + uint64_t root_idx = idx >> (root_level * node_size_log2); + if (likely(root_idx < (1ull << node_size_log2))) { + break; + } + + /* In this case, we have a root but its level is low enough that the + * requested index is out-of-bounds. + */ + uintptr_t new_root = _util_sparse_array_node_alloc(arr, root_level + 1); + + uintptr_t * new_root_children = (uintptr_t *) _util_sparse_array_node_data(new_root); + new_root_children[0] = root; + + /* We only add one at a time instead of the whole tree because it's + * easier to ensure correctness of both the tree building and the + * clean-up path. Because we're only adding one node we never have to + * worry about trying to free multiple things without freeing the old + * things. + */ + root = _util_sparse_array_set_or_free_node(&arr->root, root, new_root); + } + + void * node_data = _util_sparse_array_node_data(root); + unsigned node_level = _util_sparse_array_node_level(root); + while (node_level > 0) { + uint64_t child_idx = (idx >> (node_level * node_size_log2)) & ((1ull << node_size_log2) - 1); + + uintptr_t * children = (uintptr_t *) node_data; + uintptr_t child = p_atomic_read(&children[child_idx]); + + if (unlikely(!child)) { + child = _util_sparse_array_node_alloc(arr, node_level - 1); + child = _util_sparse_array_set_or_free_node(&children[child_idx], NULL_NODE, child); + } + + node_data = _util_sparse_array_node_data(child); + node_level = _util_sparse_array_node_level(child); + } + + uint64_t elem_idx = idx & ((1ull << node_size_log2) - 1); + return (void *) ((char *) node_data + (elem_idx * arr->elem_size)); +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-utils.h b/ggml/src/ggml-virtgpu/virtgpu-utils.h new file mode 100644 index 00000000000..a0036b4e2bc --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-utils.h @@ -0,0 +1,86 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define unlikely(x) __builtin_expect(!!(x), 0) +#define likely(x) __builtin_expect(!!(x), 1) + +#ifndef UNUSED +# define UNUSED(x) (void) (x) +#endif + +/** Checks is a value is a power of two. Does not handle zero. */ +#define IS_POT(v) (((v) & ((v) - 1)) == 0) + +/** Checks is a value is a power of two. Zero handled. */ +#define IS_POT_NONZERO(v) ((v) != 0 && IS_POT(v)) + +/** Align a value to a power of two */ +#define ALIGN_POT(x, pot_align) (((x) + (pot_align) - 1) & ~((pot_align) - 1)) + +#define p_atomic_read(_v) __atomic_load_n((_v), __ATOMIC_ACQUIRE) + +static inline bool util_is_power_of_two_nonzero64(uint64_t v) { + return IS_POT_NONZERO(v); +} + +static inline uint64_t align64(uint64_t value, uint64_t alignment) { + assert(util_is_power_of_two_nonzero64(alignment)); + return ALIGN_POT(value, alignment); +} + +struct list_head { + list_head * prev; + list_head * next; +}; + +struct util_sparse_array { + size_t elem_size; + unsigned node_size_log2; + + uintptr_t root; +}; + +void * util_sparse_array_get(util_sparse_array * arr, uint64_t idx); +void util_sparse_array_init(util_sparse_array * arr, size_t elem_size, size_t node_size); + +inline void os_time_sleep(int64_t usecs) { + timespec time; + time.tv_sec = usecs / 1000000; + time.tv_nsec = (usecs % 1000000) * 1000; + while (clock_nanosleep(CLOCK_MONOTONIC, 0, &time, &time) == EINTR) + ; +} + +struct timer_data { + long long start; + long long total; + long long count; +}; + +static inline void start_timer(timer_data * timer) { + timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + timer->start = (long long) ts.tv_sec * 1000000000LL + ts.tv_nsec; +} + +// returns the duration in ns +static inline long long stop_timer(timer_data * timer) { + timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + long long timer_end = (long long) ts.tv_sec * 1000000000LL + ts.tv_nsec; + + long long duration = (timer_end - timer->start); + timer->total += duration; + timer->count += 1; + + return duration; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu.cpp b/ggml/src/ggml-virtgpu/virtgpu.cpp new file mode 100644 index 00000000000..005c8e21db8 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu.cpp @@ -0,0 +1,498 @@ +#include "virtgpu.h" + +#include +#include + +#include +#include +#include + +static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr dev); +static virt_gpu_result_t virtgpu_open(virtgpu * gpu); + +static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu); +static virt_gpu_result_t virtgpu_init_context(virtgpu * gpu); + +static int virtgpu_ioctl_context_init(virtgpu * gpu, virgl_renderer_capset capset_id); +static int virtgpu_ioctl_get_caps(virtgpu * gpu, + virgl_renderer_capset id, + uint32_t version, + void * capset, + size_t capset_size); +static uint64_t virtgpu_ioctl_getparam(virtgpu * gpu, uint64_t param); +static void virtgpu_init_renderer_info(virtgpu * gpu); + +static void log_call_duration(long long call_duration_ns, const char * name); + +const uint64_t APIR_HANDSHAKE_MAX_WAIT_MS = 2 * 1000; // 2s +const uint64_t APIR_LOADLIBRARY_MAX_WAIT_MS = 60 * 1000; // 60s + +static int virtgpu_handshake(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + + encoder = remote_call_prepare(gpu, APIR_COMMAND_TYPE_HANDSHAKE, 0); + if (!encoder) { + GGML_ABORT("%s: failed to prepare the remote call encoder", __func__); + return 1; + } + + /* write handshake props */ + + uint32_t guest_major = APIR_PROTOCOL_MAJOR; + uint32_t guest_minor = APIR_PROTOCOL_MINOR; + apir_encode_uint32_t(encoder, &guest_major); + apir_encode_uint32_t(encoder, &guest_minor); + + /* *** */ + + uint32_t ret_magic; + long long call_duration_ns; + ret_magic = remote_call(gpu, encoder, &decoder, APIR_HANDSHAKE_MAX_WAIT_MS, &call_duration_ns); + log_call_duration(call_duration_ns, "API Remoting handshake"); + + if (!decoder) { + GGML_ABORT( + "%s: failed to initiate the communication with the virglrenderer library. " + "Most likely, the wrong virglrenderer library was loaded in the hypervisor.", + __func__); + return 1; + } + + /* read handshake return values */ + + uint32_t host_major; + uint32_t host_minor; + + if (ret_magic != APIR_HANDSHAKE_MAGIC) { + GGML_ABORT("%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic, + apir_backend_initialize_error(ret_magic)); + } else { + apir_decode_uint32_t(decoder, &host_major); + apir_decode_uint32_t(decoder, &host_minor); + } + + remote_call_finish(gpu, encoder, decoder); + + if (ret_magic != APIR_HANDSHAKE_MAGIC) { + return 1; + } + + GGML_LOG_INFO("%s: Guest is running with %u.%u\n", __func__, guest_major, guest_minor); + GGML_LOG_INFO("%s: Host is running with %u.%u\n", __func__, host_major, host_minor); + + if (guest_major != host_major) { + GGML_LOG_ERROR("Host major (%d) and guest major (%d) version differ\n", host_major, guest_major); + } else if (guest_minor != host_minor) { + GGML_LOG_WARN("Host minor (%d) and guest minor (%d) version differ\n", host_minor, guest_minor); + } + + return 0; +} + +static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirLoadLibraryReturnCode ret; + + encoder = remote_call_prepare(gpu, APIR_COMMAND_TYPE_LOADLIBRARY, 0); + if (!encoder) { + GGML_ABORT("%s: hypercall error: failed to prepare the remote call encoder", __func__); + return APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR; + } + + long long call_duration_ns; + + ret = (ApirLoadLibraryReturnCode) remote_call(gpu, encoder, &decoder, APIR_LOADLIBRARY_MAX_WAIT_MS, + &call_duration_ns); + log_call_duration(call_duration_ns, "API Remoting LoadLibrary"); + + if (!decoder) { + GGML_ABORT("%s: hypercall error: failed to kick the API remoting hypercall.\n", __func__); + return APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR; + } + + remote_call_finish(gpu, encoder, decoder); + + if (ret == APIR_LOAD_LIBRARY_SUCCESS) { + GGML_LOG_INFO("%s: The API Remoting backend was successfully loaded and initialized\n", __func__); + + return ret; + } + + // something wrong happened, find out what. + + if (ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) { + GGML_ABORT("%s: virglrenderer could not load the API Remoting backend library: %s (code %d)", __func__, + apir_load_library_error(ret), ret); + return ret; + } + + GGML_LOG_INFO("%s: virglrenderer successfully loaded the API Remoting backend library", __func__); + + ApirLoadLibraryReturnCode apir_ret = (ApirLoadLibraryReturnCode) (ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX); + + if (apir_ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) { + GGML_ABORT("%s: the API Remoting backend library couldn't load the backend library: apir code=%d | %s)", + __func__, apir_ret, apir_load_library_error(apir_ret)); + } else { + uint32_t lib_ret = apir_ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX; + GGML_ABORT("%s: the API Remoting backend library initialize its backend library: apir code=%d)", __func__, + lib_ret); + } + return ret; +} + +virtgpu * create_virtgpu() { + virtgpu * gpu = new virtgpu(); + + gpu->use_apir_capset = getenv("GGML_REMOTING_USE_APIR_CAPSET") != nullptr; + util_sparse_array_init(&gpu->shmem_array, sizeof(virtgpu_shmem), 1024); + + if (virtgpu_open(gpu) != APIR_SUCCESS) { + GGML_ABORT("%s: failed to open the virtgpu device", __func__); + return NULL; + } + + if (virtgpu_init_capset(gpu) != APIR_SUCCESS) { + GGML_ABORT("%s: failed to initialize the GPU capset", __func__); + return NULL; + } + + if (virtgpu_init_context(gpu) != APIR_SUCCESS) { + GGML_ABORT("%s: failed to initialize the GPU context", __func__); + return NULL; + } + + if (virtgpu_shmem_create(gpu, SHMEM_REPLY_SIZE, &gpu->reply_shmem)) { + GGML_ABORT("%s: failed to create the shared reply memory pages", __func__); + return NULL; + } + + if (virtgpu_shmem_create(gpu, SHMEM_DATA_SIZE, &gpu->data_shmem)) { + GGML_ABORT("%s: failed to create the shared data memory pages", __func__); + return NULL; + } + + if (virtgpu_handshake(gpu)) { + GGML_ABORT("%s: failed to handshake with the virglrenderer library", __func__); + return NULL; + } + + if (virtgpu_load_library(gpu) != APIR_LOAD_LIBRARY_SUCCESS) { + GGML_ABORT("%s: failed to load the backend library", __func__); + return NULL; + } + + return gpu; +} + +static virt_gpu_result_t virtgpu_open(virtgpu * gpu) { + drmDevicePtr devs[8]; + int count = drmGetDevices2(0, devs, ARRAY_SIZE(devs)); + if (count < 0) { + GGML_LOG_ERROR("%s: failed to enumerate DRM devices\n", __func__); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + virt_gpu_result_t result = APIR_ERROR_INITIALIZATION_FAILED; + for (int i = 0; i < count; i++) { + result = virtgpu_open_device(gpu, devs[i]); + if (result == APIR_SUCCESS) { + break; + } + } + + drmFreeDevices(devs, count); + + return result; +} + +static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr dev) { + const char * node_path = dev->nodes[DRM_NODE_RENDER]; + + int fd = open(node_path, O_RDWR | O_CLOEXEC); + if (fd < 0) { + GGML_ABORT("failed to open %s", node_path); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + drmVersionPtr version = drmGetVersion(fd); + if (!version || strcmp(version->name, "virtio_gpu") || version->version_major != 0) { + if (version) { + GGML_ABORT("unknown DRM driver %s version %d", version->name, version->version_major); + } else { + GGML_ABORT("failed to get DRM driver version"); + } + + if (version) { + drmFreeVersion(version); + } + close(fd); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + gpu->fd = fd; + + drmFreeVersion(version); + + GGML_LOG_INFO("using DRM device %s\n", node_path); + + return APIR_SUCCESS; +} + +static virt_gpu_result_t virtgpu_init_context(virtgpu * gpu) { + assert(!gpu->capset.version); + const int ret = virtgpu_ioctl_context_init(gpu, gpu->capset.id); + if (ret) { + GGML_LOG_INFO("failed to initialize context: %s\n", strerror(errno)); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + return APIR_SUCCESS; +} + +static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu) { + if (gpu->use_apir_capset) { + GGML_LOG_INFO("Using the APIR capset\n"); + gpu->capset.id = VIRTGPU_DRM_CAPSET_APIR; + } else { + GGML_LOG_INFO("Using the Venus capset\n"); + gpu->capset.id = VIRTGPU_DRM_CAPSET_VENUS; + } + gpu->capset.version = 0; + + int ret = + virtgpu_ioctl_get_caps(gpu, gpu->capset.id, gpu->capset.version, &gpu->capset.data, sizeof(gpu->capset.data)); + + if (ret) { + GGML_LOG_INFO("failed to get APIR v%d capset: %s\n", gpu->capset.version, strerror(errno)); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + assert(gpu->capset.data.supports_blob_resources); + + return APIR_SUCCESS; +} + +static int virtgpu_ioctl_context_init(virtgpu * gpu, virgl_renderer_capset capset_id) { + drm_virtgpu_context_set_param ctx_set_params[3] = { + { + .param = VIRTGPU_CONTEXT_PARAM_CAPSET_ID, + .value = capset_id, + }, + { + .param = VIRTGPU_CONTEXT_PARAM_NUM_RINGS, + .value = 1, + }, + { + .param = VIRTGPU_CONTEXT_PARAM_POLL_RINGS_MASK, + .value = 0, /* don't generate drm_events on fence signaling */ + }, + }; + + drm_virtgpu_context_init args = { + .num_params = ARRAY_SIZE(ctx_set_params), + .pad = 0, + .ctx_set_params = (uintptr_t) &ctx_set_params, + }; + + return virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_CONTEXT_INIT, &args); +} + +static int virtgpu_ioctl_get_caps(virtgpu * gpu, + virgl_renderer_capset id, + uint32_t version, + void * capset, + size_t capset_size) { + drm_virtgpu_get_caps args = { + .cap_set_id = id, + .cap_set_ver = version, + .addr = (uintptr_t) capset, + .size = (__u32) capset_size, + .pad = 0, + }; + + return virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_GET_CAPS, &args); +} + +static uint64_t virtgpu_ioctl_getparam(virtgpu * gpu, uint64_t param) { + /* val must be zeroed because kernel only writes the lower 32 bits */ + uint64_t val = 0; + drm_virtgpu_getparam args = { + .param = param, + .value = (uintptr_t) &val, + }; + + const int ret = virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_GETPARAM, &args); + return ret ? 0 : val; +} + +apir_encoder * remote_call_prepare(virtgpu * gpu, ApirCommandType apir_cmd_type, int32_t cmd_flags) { + /* + * Prepare the command encoder and its buffer + */ + + static char encoder_buffer[4096]; + + static apir_encoder enc; + enc = { + .cur = encoder_buffer, + .start = encoder_buffer, + .end = encoder_buffer + sizeof(encoder_buffer), + .fatal = false, + }; + + /* + * Fill the command encoder with the common args: + * - cmd_type (int32_t) + * - cmd_flags (int32_t) + * - reply res id (uint32_t) + */ + + int32_t cmd_type = apir_cmd_type; + + // for testing during the hypervisor transition + if (!gpu->use_apir_capset) { + cmd_type += VENUS_COMMAND_TYPE_LENGTH; + } + apir_encode_int32_t(&enc, &cmd_type); + apir_encode_int32_t(&enc, &cmd_flags); + + uint32_t reply_res_id = gpu->reply_shmem.res_id; + apir_encode_uint32_t(&enc, &reply_res_id); + + return &enc; +} + +void remote_call_finish(virtgpu * gpu, apir_encoder * enc, apir_decoder * dec) { + UNUSED(gpu); + + if (!enc) { + GGML_LOG_ERROR("Invalid (null) encoder\n"); + } + + if (!dec) { + GGML_LOG_ERROR("Invalid (null) decoder\n"); + } + + if (apir_encoder_get_fatal(enc)) { + GGML_LOG_ERROR("Failed to encode the output parameters.\n"); + } + + if (apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR("Failed to decode the input parameters.\n"); + } +} + +uint32_t remote_call(virtgpu * gpu, + apir_encoder * encoder, + apir_decoder ** decoder, + float max_wait_ms, + long long * call_duration_ns) { + /* + * Prepare the reply notification pointer + */ + + volatile std::atomic_uint * atomic_reply_notif = (volatile std::atomic_uint *) gpu->reply_shmem.mmap_ptr; + *atomic_reply_notif = 0; + + /* + * Trigger the execbuf ioctl + */ + + drm_virtgpu_execbuffer args = { + .flags = VIRTGPU_EXECBUF_RING_IDX, + .size = (uint32_t) (encoder->cur - encoder->start), + .command = (uintptr_t) encoder->start, + + .bo_handles = 0, + .num_bo_handles = 0, + + .fence_fd = 0, + .ring_idx = 0, + .syncobj_stride = 0, + .num_in_syncobjs = 0, + .num_out_syncobjs = 0, + .in_syncobjs = 0, + .out_syncobjs = 0, + }; + + *decoder = NULL; + + int ret = drmIoctl(gpu->fd, DRM_IOCTL_VIRTGPU_EXECBUFFER, &args); + + if (ret != 0) { + GGML_ABORT("%s: the virtgpu EXECBUFFER ioctl failed (%d)", __func__, ret); + } + + /* + * Wait for the response notification + */ + timer_data wait_host_reply_timer = { 0, 0, 0 }; + + start_timer(&wait_host_reply_timer); + + timespec ts_start, ts_end; + clock_gettime(CLOCK_MONOTONIC, &ts_start); + long long start_time = (long long) ts_start.tv_sec * 1000000000LL + ts_start.tv_nsec; + + bool timedout = false; + uint32_t notif_value = 0; + while (true) { + notif_value = std::atomic_load_explicit(atomic_reply_notif, std::memory_order_acquire); + + if (notif_value != 0) { + break; + } + + int64_t base_sleep_us = 15; + + os_time_sleep(base_sleep_us); + + if (max_wait_ms) { + clock_gettime(CLOCK_MONOTONIC, &ts_end); + long long end_time = (long long) ts_end.tv_sec * 1000000000LL + ts_end.tv_nsec; + float duration_ms = (end_time - start_time) / 1000000; + + if (duration_ms > max_wait_ms) { + timedout = true; + break; + } + } + } + + if (call_duration_ns) { + *call_duration_ns = stop_timer(&wait_host_reply_timer); + } + + if (max_wait_ms && timedout) { + GGML_LOG_ERROR("timed out waiting for the host answer...\n"); + return APIR_FORWARD_TIMEOUT; + } + + /* + * Prepare the decoder + */ + static apir_decoder response_dec; + response_dec.cur = (char *) gpu->reply_shmem.mmap_ptr + sizeof(*atomic_reply_notif); + response_dec.end = (char *) gpu->reply_shmem.mmap_ptr + gpu->reply_shmem.mmap_size; + *decoder = &response_dec; + + // extract the actual return value from the notif flag + uint32_t returned_value = notif_value - 1; + return returned_value; +} + +static void log_call_duration(long long call_duration_ns, const char * name) { + double call_duration_ms = (double) call_duration_ns / 1e6; // 1 millisecond = 1e6 nanoseconds + double call_duration_s = (double) call_duration_ns / 1e9; // 1 second = 1e9 nanoseconds + + if (call_duration_s > 1) { + GGML_LOG_INFO("%s: waited %.2fs for the %s host reply...\n", __func__, call_duration_s, name); + } else if (call_duration_ms > 1) { + GGML_LOG_INFO("%s: waited %.2fms for the %s host reply...\n", __func__, call_duration_ms, name); + } else { + GGML_LOG_INFO("%s: waited %lldns for the %s host reply...\n", __func__, call_duration_ns, name); + } +} diff --git a/ggml/src/ggml-virtgpu/virtgpu.h b/ggml/src/ggml-virtgpu/virtgpu.h new file mode 100644 index 00000000000..d4bb42e20b2 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu.h @@ -0,0 +1,92 @@ +#pragma once + +#include "virtgpu-utils.h" +#include "virtgpu-shm.h" +#include "virtgpu-apir.h" + +#include "backend/shared/api_remoting.h" +#include "backend/shared/apir_cs.h" + +#include +#include +#include +#include +#include +#include +#include + +#include + +#define VIRGL_RENDERER_UNSTABLE_APIS 1 +#include "apir_hw.h" +#include +#include "venus_hw.h" + +#ifndef VIRTGPU_DRM_CAPSET_APIR +// Will be defined include/drm/virtgpu_drm.h when +// https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590/diffs +// is merged +#define VIRTGPU_DRM_CAPSET_APIR 10 +#endif + +// Mesa/Virlgrenderer Venus internal. Only necessary during the +// Venus->APIR transition in Virglrenderer +#define VENUS_COMMAND_TYPE_LENGTH 331 + +#ifndef VIRTGPU_DRM_CAPSET_VENUS // only available with Linux >= v6.16 +#define VIRTGPU_DRM_CAPSET_VENUS 4 +#endif + +typedef uint32_t virgl_renderer_capset; + +/* from src/virtio/vulkan/vn_renderer_virtgpu.c */ +#define VIRTGPU_PCI_VENDOR_ID 0x1af4 +#define VIRTGPU_PCI_DEVICE_ID 0x1050 +#define VIRTGPU_BLOB_MEM_GUEST_VRAM 0x0004 +#define VIRTGPU_PARAM_GUEST_VRAM 9 + +#define SHMEM_DATA_SIZE 0x1830000 // 24MiB +#define SHMEM_REPLY_SIZE 0x4000 + +#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0])) + +enum virt_gpu_result_t { + APIR_SUCCESS = 0, + APIR_ERROR_INITIALIZATION_FAILED = -1, +}; + +#define PRINTFLIKE(f, a) __attribute__((format(__printf__, f, a))) + +struct virtgpu { + bool use_apir_capset; + + int fd; + + struct { + virgl_renderer_capset id; + uint32_t version; + virgl_renderer_capset_apir data; + } capset; + + util_sparse_array shmem_array; + + /* APIR communication pages */ + virtgpu_shmem reply_shmem; + virtgpu_shmem data_shmem; +}; + +static inline int virtgpu_ioctl(virtgpu * gpu, unsigned long request, void * args) { + return drmIoctl(gpu->fd, request, args); +} + +virtgpu * create_virtgpu(); + +apir_encoder * remote_call_prepare(virtgpu * gpu, ApirCommandType apir_cmd_type, int32_t cmd_flags); + +uint32_t remote_call(virtgpu * gpu, + apir_encoder * enc, + apir_decoder ** dec, + float max_wait_ms, + long long * call_duration_ns); + +void remote_call_finish(virtgpu * gpu, apir_encoder * enc, apir_decoder * dec); From dda7d9cd1c2ba9808abbd66a0d2268d946a56fe3 Mon Sep 17 00:00:00 2001 From: Oleksandr Kuvshynov <661042+okuvshynov@users.noreply.github.com> Date: Wed, 28 Jan 2026 06:35:54 -0500 Subject: [PATCH 056/831] vulkan: handle device dedup on MacOS + Vega II Duo cards (llama/19058) Deduplication here relied on the fact that vulkan would return unique UUID for different physical GPUs. It is at the moment not always the case. On Mac Pro 2019 running Mac OS, with 2 Vega II Duo cards (so, 4 GPU total), MotlenVK would assign same UUID to pairs of GPUs, unless they are connected with Infinity Fabric. See more details here: KhronosGroup/MoltenVK#2683. The right way is to fix that in MoltenVK, but until it is fixed, llama.cpp would only recognize 2 of 4 GPUs in such configuration. The deduplication logic here is changed to only filter GPUs if UUID is same but driver is different. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b5e5dba95fe..514f290d098 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5522,22 +5522,32 @@ static void ggml_vk_instance_init() { if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) { // Check if there are two physical devices corresponding to the same GPU + // This handles the case where the same GPU appears with different drivers (e.g., RADV + AMDVLK on Linux), + // see https://github.com/ggml-org/llama.cpp/pull/7582 for original deduplication. + // However, for MoltenVK on macOS, multiple GPUs on the same card may report the same UUID, + // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Until this is fixed, we'll only deduplicate + // when drivers differ (same driver + same UUID = likely different GPUs) auto old_device = std::find_if( vk_instance.device_indices.begin(), vk_instance.device_indices.end(), - [&devices, &new_id](const size_t k){ + [&devices, &new_id, &new_driver](const size_t k){ vk::PhysicalDeviceProperties2 old_props; + vk::PhysicalDeviceDriverProperties old_driver; vk::PhysicalDeviceIDProperties old_id; - old_props.pNext = &old_id; + old_props.pNext = &old_driver; + old_driver.pNext = &old_id; devices[k].getProperties2(&old_props); - bool equals = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); - equals = equals || ( + bool same_uuid = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); + same_uuid = same_uuid || ( old_id.deviceLUIDValid && new_id.deviceLUIDValid && std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID)) ); - return equals; + // Only deduplicate if same UUID AND different drivers + // (same driver + same UUID on MoltenVK = likely different GPUs on multi-GPU card) + bool different_driver = (old_driver.driverID != new_driver.driverID); + return same_uuid && different_driver; } ); if (old_device == vk_instance.device_indices.end()) { From cc0c103b5d78c74ed43bc648a284e121ef1e35d2 Mon Sep 17 00:00:00 2001 From: Patryk Kaminski Date: Wed, 28 Jan 2026 16:33:54 +0100 Subject: [PATCH 057/831] ggml-sycl: remove unused syclcompat header (llama/19140) The syclcompat/math.hpp is not used anymore. The change that intrduced it was successfuly reverted (https://github.com/ggml-org/llama.cpp/pull/17826). This include path will become obsolete and dropped in oneAPI 2026.0 effectively breaking ggml-sycl builds. --- ggml/src/ggml-sycl/dpct/helper.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index 30ec1e8dafc..8ae8098717d 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -15,7 +15,6 @@ #include #include -#include #include #ifdef GGML_SYCL_USE_INTEL_ONEMKL From 33148bb52336b057186ff7edf2ee9019e9facc5f Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Wed, 28 Jan 2026 18:52:45 +0100 Subject: [PATCH 058/831] Vulkan Flash Attention Coopmat1 Refactor (llama/19075) * vulkan: use coopmat for flash attention p*v matrix multiplication * fix P loading issue * fix barrier position * remove reduction that is no longer needed * move max thread reduction into loop * remove osh padding * add bounds checks and padding * remove unused code * fix shmem sizes, loop duration and accesses * don't overwrite Qf, add new shared psh buffer instead * add missing bounds checks * use subgroup reductions * optimize * move bounds check, reduce barriers * support other Bc values and other subgroup sizes * remove D_split * replace Of register array with shared memory Ofsh array * parallelize HSV across the rowgroups * go back to Of in registers, not shmem * vectorize sfsh * don't store entire K tile in shmem * fixes * load large k tiles to shmem on Nvidia * adapt shared memory host check function to shader changes * remove Bc 32 case * remove unused variable * fix missing mask reduction tmspsh barrier * fix mask bounds check * fix rowmax f16 under/overflow to inf * fix flash_attn_cm2 BLOCK_SIZE preprocessor directives --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 64 ++- .../vulkan-shaders/flash_attn_base.glsl | 6 + .../vulkan-shaders/flash_attn_cm1.comp | 414 +++++++++++------- .../vulkan-shaders/flash_attn_cm2.comp | 6 +- 4 files changed, 317 insertions(+), 173 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 514f290d098..3852867c291 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3162,17 +3162,31 @@ static void ggml_vk_load_shaders(vk_device& device) { // For scalar, use 128 (arbitrary) // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. const uint32_t D = (hsk|hsv); - uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) - ? scalar_flash_attention_workgroup_size - : ((small_rows && (D % 32) == 0) ? 256 : 128); auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache); + uint32_t wg_size; + switch (path) { + case FA_COOPMAT2: + wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128); + break; + case FA_COOPMAT1: + wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc + break; + default: + wg_size = scalar_flash_attention_workgroup_size; + break; + } + // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. const uint32_t D_lsb = D ^ (D & (D-1)); uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); - return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split}; + // Nvidia prefers shared memory use to load large tiles of K + // AMD prefers loading K directly from global memory + const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA ? 1 : 0; + + return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem}; }; #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ @@ -3187,15 +3201,15 @@ static void ggml_vk_load_shaders(vk_device& device) { if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } \ } \ } \ @@ -8344,41 +8358,49 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); return supported; } -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) { +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { // Needs to be kept up to date on shader changes GGML_UNUSED(hsv); - const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = coopmat1_flash_attention_num_large_rows; - const uint32_t Bc = scalar_flash_attention_Bc; + const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false); + const uint32_t Br = rows_cols[0]; + const uint32_t Bc = rows_cols[1]; + + const uint32_t MatBr = 16, MatBc = 16; + + const uint32_t row_split = Bc / MatBc; const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16); const uint32_t acctype = f32acc ? 4 : 2; const uint32_t f16vec4 = 8; - const uint32_t tmpsh = wg_size * sizeof(float); - const uint32_t tmpshv4 = wg_size * 4 * acctype; + const uint32_t tmpsh = (Bc / MatBc) * sizeof(float); const uint32_t qstride = hsk_pad / 4 + 2; const uint32_t Qf = Br * qstride * f16vec4; + const uint32_t psh_stride = Br / 4 + 2; + const uint32_t Psh = Bc * psh_stride * f16vec4; + const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br; const uint32_t sfsh = Bc * sfshstride * acctype; - const uint32_t kshstride = hsk_pad / 4 + 2; - const uint32_t ksh = Bc * kshstride * f16vec4; + const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA; + const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2; + const uint32_t vsh_stride = MatBc / 4 * row_split; + const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4; - const uint32_t slope = Br * sizeof(float); + const uint32_t slope = Br * acctype; - const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; + const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported); return supported; } @@ -8442,7 +8464,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); - const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32); + const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type); if (!coopmat_shape_supported || !coopmat_shmem_supported) { path = FA_SCALAR; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 29b5c7c3a41..23a4d2c0058 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -8,6 +8,8 @@ layout (constant_id = 3) const uint32_t HSK = 32; layout (constant_id = 4) const uint32_t HSV = 32; layout (constant_id = 5) const uint32_t Clamp = 0; layout (constant_id = 6) const uint32_t D_split = 16; +layout (constant_id = 7) const uint32_t SubGroupSize = 32; +layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0; // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths const uint32_t HSK_pad = (HSK + 15) & ~15; @@ -74,6 +76,10 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16 layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; #endif +#ifndef BLOCK_SIZE +#define BLOCK_SIZE 1 +#endif + #if defined(DATA_A_F32) #undef BLOCK_SIZE #define BLOCK_SIZE 4 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 0eb50fe58f9..83d52d19d67 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -7,6 +7,7 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable #extension GL_KHR_shader_subgroup_vote : enable #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable @@ -14,12 +15,13 @@ #include "types.glsl" #include "flash_attn_base.glsl" -const uint32_t HSK_per_thread = HSK / D_split; -const uint32_t HSV_per_thread = HSV / D_split; +// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd +const uint32_t MatBr = 16; +const uint32_t MatBc = 16; -const uint32_t row_split = 4; +const uint32_t row_split = Bc / MatBc; const uint32_t rows_per_thread = Br / row_split; -const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; +const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; @@ -40,24 +42,24 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } -// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd -const uint32_t MatBr = 16; -const uint32_t MatBc = 16; - -shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; -shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; +shared float tmpsh[row_split]; const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 shared f16vec4 Qf[Br * qstride]; +const uint psh_stride = Br / 4 + 2; +shared f16vec4 Psh[Bc * psh_stride]; + // Avoid padding for hsk==256 to make it fit in 48KB shmem. -const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br; -shared ACC_TYPE sfsh[Bc * sfshstride]; +const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4; +shared ACC_TYPEV4 sfsh[Bc * sfshstride]; -const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4 -shared f16vec4 ksh[Bc * kshstride]; +const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4 +const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups +const uint vsh_stride = v_cols; +shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)]; -shared float slope[Br]; +shared ACC_TYPE slope[Br]; void main() { #ifdef NEEDS_INIT_IQ_SHMEM @@ -69,9 +71,9 @@ void main() { const uint32_t tid = gl_LocalInvocationIndex; const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t d_per_thread = (HSV/4 + threads_per_rowgroup - 1) / threads_per_rowgroup; const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; - const uint32_t d_tid = gl_LocalInvocationIndex % D_split; - const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + const uint32_t col_tid = gl_LocalInvocationIndex % threads_per_rowgroup; #define tile_row(r) (row_tid * rows_per_thread + (r)) @@ -102,9 +104,9 @@ void main() { } barrier(); - ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + ACC_TYPEV4 Of[rows_per_thread][d_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) { Of[r][d] = ACC_TYPEV4(0.0); } } @@ -125,13 +127,11 @@ void main() { uint r = tid; slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); } - barrier(); } else { if (tid < Br) { uint r = tid; - slope[r] = 1.0; + slope[r] = ACC_TYPE(1.0); } - barrier(); } #if BLOCK_SIZE > 1 @@ -149,19 +149,45 @@ void main() { [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - float mask_cache[Bc * Br / WorkGroupSize]; + f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize]; if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; float max_mask = NEG_FLT_MAX_OVER_2; - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) % Bc; - uint32_t r = (idx + tid) / Bc; - if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / (Br / 4); + uint32_t r = (idx + tid) % (Br / 4); + if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { + if ((!KV_bounds_check || j * Bc + c < KV)) { + f16vec4 m; + if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]); + max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3])); + } else if (i * Br + r * 4 + 2 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], + 0.0); + max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])); + } else if (i * Br + r * 4 + 1 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + 0.0, + 0.0); + max_mask = max(max(max_mask, float(m[0])), float(m[1])); + } else if (i * Br + r * 4 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + 0.0, + 0.0, + 0.0); + max_mask = max(max_mask, float(m[0])); + } else { + m = f16vec4(0.0); + } mask_cache[idx / WorkGroupSize] = m; - max_mask = max(max_mask, m); } } } @@ -180,26 +206,28 @@ void main() { } } - [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { - uint32_t d = (idx + tid) % (HSK / 4); - uint32_t c = (idx + tid) / (HSK / 4); - if (c < Bc && d < HSK / 4) { - f16vec4 K_Tf = f16vec4(0); - if (!KV_bounds_check || j * Bc + c < KV) { + if (K_LOAD_SHMEM != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t c = (idx + tid) / (HSK / 4); + if (c < Bc && d < HSK / 4) { + f16vec4 K_Tf = f16vec4(0); + if (!KV_bounds_check || j * Bc + c < KV) { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); #else - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); #endif - } + } - ksh[c * kshstride + d] = K_Tf; + ksh[c * kshstride + d] = K_Tf; + } } + barrier(); } - barrier(); // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 @@ -208,11 +236,55 @@ void main() { coopmat KMat; coopmat QMat; - for (uint32_t d = 0; d < HSK_pad / 16; ++d) { - coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); + [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) { + if (K_LOAD_SHMEM == 0) { +#if BLOCK_SIZE == 1 + if (KV_bounds_check || d * 16 + 16 > HSK) { +#endif + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) { + uint32_t col_vec = (idx + tid) % (MatBr / 4); + uint32_t row = (idx + tid) / (MatBr / 4); + if (idx + tid < Bc * MatBr / 4) { + f16vec4 K_Tf = f16vec4(0); + if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); +#else + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); +#endif + } - uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; - coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + ksh[row * kshstride + col_vec] = K_Tf; + } + } + barrier(); +#if BLOCK_SIZE == 1 + } +#endif + +#if BLOCK_SIZE == 1 + if (KV_bounds_check || d * 16 + 16 > HSK) +#endif + { + uint coord = (gl_SubgroupID * MatBc) * kshstride; + coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + } +#if BLOCK_SIZE == 1 + else { + const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4; + coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor); + } +#endif + } else { + uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; + coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + } + + coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); SfMat = coopMatMulAdd(KMat, QMat, SfMat); } @@ -222,26 +294,26 @@ void main() { barrier(); if (p.logit_softcap != 0.0f) { - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) / Br; - uint32_t r = (idx + tid) % Br; - if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / (Br / 4); + uint32_t r = (idx + tid) % (Br / 4); + if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { + sfsh[c * sfshstride + r] = ACC_TYPEV4(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); } } barrier(); } if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) % Bc; - uint32_t r = (idx + tid) / Bc; - if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - float f = mask_cache[idx / WorkGroupSize]; - sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f); + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / (Br / 4); + uint32_t r = (idx + tid) % (Br / 4); + if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { + if (!KV_bounds_check || j * Bc + c < KV) { + // Mask nem1 bounds check is handled when loading masks + ACC_TYPEV4 masks = ACC_TYPEV4(mask_cache[idx / WorkGroupSize]); + ACC_TYPEV4 slopes = ACC_TYPEV4(slope[r * 4], slope[r * 4 + 1], slope[r * 4 + 2], slope[r * 4 + 3]); + sfsh[c * sfshstride + r] += slopes * masks; } } } @@ -250,121 +322,154 @@ void main() { float eMf[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint r_vec = tile_row(r) / 4; + const uint r_comp = tile_row(r) % 4; + float rowmaxf = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } - rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); + rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp])); } float Moldf = Mf[r]; + // Compute max across the row + rowmaxf = subgroupMax(rowmaxf); + // M = max(rowmax, Mold) // P = e^(S - M) // eM = e^(Mold - M) Mf[r] = max(rowmaxf, Moldf); eMf[r] = exp(Moldf - Mf[r]); + + Lf[r] = eMf[r]*Lf[r]; } - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; + Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local]; } } - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Lf[r] = eMf[r]*Lf[r]; - } + // Calculate and store Pf in Psh [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { - continue; - } - float Pf[rows_per_thread]; - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); - Lf[r] += Pf[r]; + const uint col = c * cols_per_iter + col_tid; + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) { + const uint row = tile_row(r); + if (KV_bounds_check && j * Bc + col >= KV) { + Psh[col * psh_stride + row / 4] = f16vec4(0.0f); + } else { + const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]); + const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec)); + [[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) { + Lf[r + vec_idx] += Pf[vec_idx]; + } + Psh[col * psh_stride + row / 4] = Pf; + } } - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + } + + const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up + + // Each subgroup handles HSV/4 columns + [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) { + const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16; + + SfMat = coopmat(0); + + // Preload V tiles for [Bc, 16 * num subgroups] + const uint v_rows = Bc; + const uint v_total = v_rows * v_cols; + const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x; + +#if BLOCK_SIZE == 1 + // For f16, only preload if not aligned + if (KV_bounds_check) { +#endif + [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) { + const uint idx = i * gl_WorkGroupSize.x + tid; + const uint row = idx / v_cols; + const uint col = idx % v_cols; + + const uint v_row = j * Bc + row; + const uint v_col = hsv_tile * MatBc * row_split + col * 4; + + const uint coord = v_row * v_stride * BLOCK_SIZE + v_col; + const uint ib = coord / BLOCK_SIZE; + const uint iqs = coord % BLOCK_SIZE; + + if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); #else - vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); + ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; #endif - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf); + } else { + ksh[row * vsh_stride + col] = f16vec4(0.0f); } } - } +#if BLOCK_SIZE == 1 + } +#endif - barrier(); - } + barrier(); - // prevent race on tmpsh - barrier(); + [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) { + coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor); - // reduce across threads +#if BLOCK_SIZE == 1 + if (!KV_bounds_check) { + // F16 values can be loaded directly from global memory + const uint v_tile_row = j * Bc + bc_chunk * MatBc; + const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; + coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); + } else +#endif + { + const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4); + coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + + SfMat = coopMatMulAdd(KMat, QMat, SfMat); + } + + // Store SfMat to sfsh and load into Of + const uint osh_stride = row_split * MatBc / 4; + const uint o_offset = gl_SubgroupID * MatBc / 4; + coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor); - float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread]; - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - FLOAT_TYPE M = Mf[r]; - tmpsh[tid] = M; - // Compute max across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { - M = max(M, tmpsh[tid ^ s]); - barrier(); - tmpsh[tid] = M; barrier(); - } - rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; - barrier(); - } - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Moldf[r] = Mf[r]; + const uint hsv_per_tile = row_split * MatBc; + const uint hsv_base = hsv_tile * hsv_per_tile; + const uint d_values_per_tile = hsv_per_tile / 4; - // M = max(rowmax, Mold) - // eM = e^(Mold - M) - Mf[r] = max(rowmaxf[r], Moldf[r]); - eMf[r] = exp(Moldf[r] - Mf[r]); + const uint d_start = hsv_tile * d_values_per_tile; + const uint d_end = min(d_start + d_values_per_tile, HSV / 4); - Lf[r] = eMf[r]*Lf[r]; - } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - FLOAT_TYPE L = Lf[r]; - tmpsh[tid] = L; - // Compute sum across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { - L += tmpsh[tid ^ s]; - barrier(); - tmpsh[tid] = L; - barrier(); + [[unroll]] for (uint32_t d_local = 0; d_local < d_per_thread; ++d_local) { + const uint d = d_local * threads_per_rowgroup + col_tid; + const uint hsv_col = 4 * d; + + if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) { + const uint local_hsv = (hsv_col - hsv_base) / 4; + Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]); + } + } + } } - Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - - Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; - tmpshv4[tid] = Of[r][d]; - - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { - Of[r][d] += tmpshv4[tid ^ s]; - barrier(); - tmpshv4[tid] = Of[r][d]; - barrier(); - } - Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup]; - barrier(); - } + Lf[r] = subgroupAdd(Lf[r]); } // If there is split_k, then the split_k resolve shader does the final @@ -375,9 +480,12 @@ void main() { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV/4) break; + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N); } } } @@ -404,8 +512,9 @@ void main() { if (sink > Mf[r]) { ms = exp(Mf[r] - sink); - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - Of[r][d] *= ACC_TYPE(ms); + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d_local = d0 / threads_per_rowgroup; + Of[r][d_local] *= ACC_TYPE(ms); } } else { vs = exp(sink - Mf[r]); @@ -420,11 +529,12 @@ void main() { Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]); } - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] *= ACC_TYPE(Lfrcp[r]); + Of[r][d_local] *= ACC_TYPE(Lfrcp[r]); #if defined(ACC_TYPE_MAX) - Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX); + Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX); #endif } } @@ -434,9 +544,12 @@ void main() { if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV / 4) break; + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N); } } } @@ -444,9 +557,12 @@ void main() { } else { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (i * Br + tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV / 4) break; + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index d49a8da65fb..54f1b0b6226 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -55,7 +55,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele return max(elem0, elem1); } -#if defined(BLOCK_SIZE) +#if BLOCK_SIZE > 1 #define DECODEFUNC , DEQUANTFUNC #else #define DECODEFUNC @@ -85,7 +85,7 @@ void main() { tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); -#if defined(BLOCK_SIZE) +#if BLOCK_SIZE > 1 tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); #endif @@ -98,7 +98,7 @@ void main() { if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { q_stride &= ~7; -#if !defined(BLOCK_SIZE) +#if BLOCK_SIZE == 1 k_stride &= ~7; v_stride &= ~7; #endif From f0e85bb142fdfdd3cb30d964385ddebea4f84c12 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Thu, 29 Jan 2026 09:20:22 +0800 Subject: [PATCH 059/831] sycl: fix norm kernels: l2_norm, group_norm, rms_norm by remove assert to support more cases (llama/19154) Co-authored-by: Neo Zhang Jianyu --- ggml/src/ggml-sycl/ggml-sycl.cpp | 6 ++---- ggml/src/ggml-sycl/norm.cpp | 3 --- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index ce2f0d41c96..3a4c092af5d 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4606,14 +4606,12 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type); #endif case GGML_OP_NORM: - return true; case GGML_OP_L2_NORM: case GGML_OP_GROUP_NORM: - return ggml_is_contiguous(op->src[0]); case GGML_OP_RMS_NORM: - return ((op->src[0]->ne[0] % WARP_SIZE) == 0); + return true; case GGML_OP_RMS_NORM_BACK: - return ((op->src[0]->ne[0] % WARP_SIZE) == 0); + return ggml_is_contiguous(op->src[0]); case GGML_OP_SCALE: return true; case GGML_OP_CONT: diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index 823d3a4828c..00702b5d09c 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -251,7 +251,6 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i const float eps, queue_ptr stream, int device) { const sycl::range<3> global_dims(nsamples, nchannels, nrows); - GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); stream->submit([&](sycl::handler& cgh) { @@ -334,7 +333,6 @@ static void group_norm_f32_sycl(const float* x, float* dst, static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples, const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) { - GGML_ASSERT(ncols % WARP_SIZE == 0); // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); const sycl::range<3> global_dims(nsamples, nchannels, nrows); @@ -374,7 +372,6 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const float eps, queue_ptr stream, int device) { - GGML_ASSERT(ncols % WARP_SIZE == 0); // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); From 62ba8b537fef7b96727dbd3efa0d035ada52cb7d Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 29 Jan 2026 10:31:28 +0800 Subject: [PATCH 060/831] CUDA: refactor topk-moe to enable more models (GLM 4.7, Nemotron etc.) (llama/19126) --- ggml/src/ggml-cuda/ggml-cuda.cu | 279 ++++++++++++++++++++------- ggml/src/ggml-cuda/topk-moe.cu | 322 +++++++++++++++++++------------- ggml/src/ggml-cuda/topk-moe.cuh | 34 ++-- 3 files changed, 418 insertions(+), 217 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e9df0ea4a7c..76d0f12550e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3080,63 +3080,166 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope, return true; } -static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops, std::initializer_list unary_ops) { -#ifndef NDEBUG - const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY); - GGML_ASSERT(unary_ops.size() == num_unary); -#endif - - //TODO: remove special case once ggml_can_fuse can handle empty nodes - std::initializer_list topk_moe_ops = - ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false); - std::initializer_list topk_moe_ops_with_norm = - ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false); - std::initializer_list topk_moe_ops_delayed_softmax = - ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true); +static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) { + args.sigmoid = false; + args.softmax = false; + args.delayed_softmax = false; + args.prob_bias = false; + args.norm = false; - const auto is_equal = [](const std::initializer_list & list1, - const std::initializer_list & list2) { - return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end()); - }; + const int n_nodes = cgraph->n_nodes; + ggml_tensor ** nodes = cgraph->nodes; - if (is_equal(topk_moe_ops_with_norm, ops) && - ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) { - ggml_tensor * softmax = cgraph->nodes[node_idx]; - ggml_tensor * weights = cgraph->nodes[node_idx + 9]; - ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; - ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; - int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; + if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) { + args.softmax = true; + } - if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { - return true; + if (nodes[node_idx]->op == GGML_OP_UNARY) { + if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) { + return false; } + args.sigmoid = true; } - if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { - ggml_tensor * softmax = cgraph->nodes[node_idx]; - ggml_tensor * weights = cgraph->nodes[node_idx + 4]; - ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; - ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; - int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; + if (nodes[node_idx]->op == GGML_OP_ARGSORT) { + args.delayed_softmax = true; + } - if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { - return true; + node_idx++; + + if (args.sigmoid || args.softmax) { + // SOFTMAX -> RESHAPE + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE || + nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + ggml_tensor * probs_reshaped = nodes[node_idx]; + node_idx++; + + if (node_idx >= n_nodes) { + return false; + } + + // src of bias add is the unreshaped probs (-2 instead of -1) + if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) { + args.prob_bias = true; + node_idx++; + } + // RESHAPE/ADD -> ARGSORT + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) { + return false; + } + + if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) { + return false; + } + + node_idx++; + + // ARGSORT-> VIEW + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW || + nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + node_idx++; + + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) { + return false; + } + + // GET_ROWS + if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) { + return false; + } + node_idx++; + } else if (args.delayed_softmax) { + if (node_idx - 2 < 0) { + return false; + } + ggml_tensor * probs_reshaped = nodes[node_idx - 2]; + + // VIEW->ARGSORT + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW || + nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + node_idx++; + + // GET_ROWS + if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] || + nodes[node_idx]->src[0] != probs_reshaped) { + return false; } + node_idx++; + + static const std::vector remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; + + for (const ggml_op op : remaining_ops) { + if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + node_idx++; + } + } + + // At this point we can check for norm + scale. Everything is now at least valid till the norm + if (node_idx >= n_nodes) { + return true; } - if (is_equal(topk_moe_ops_delayed_softmax, ops) && - ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) { - ggml_tensor * softmax = cgraph->nodes[node_idx + 4]; - ggml_tensor * weights = cgraph->nodes[node_idx + 5]; - ggml_tensor * get_rows = cgraph->nodes[node_idx + 2]; - ggml_tensor * argsort = cgraph->nodes[node_idx + 0]; - int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; + if (nodes[node_idx]->op == GGML_OP_RESHAPE) { + //check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE + static const std::vector norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP }; - if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { + args.norm = true; + for (const ggml_op op : norm_ops) { + if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) { + node_idx++; + } else { + args.norm = false; + return true; + } + } + + // DIV <- CLAMP, RESHAPE + if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] || + nodes[node_idx]->src[0] != nodes[node_idx - 3]) { + args.norm = false; + return true; + } + node_idx++; + + if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + args.norm = false; return true; } + + node_idx++; } + if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) { + args.scale = true; + } + + return true; +} + +static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, + int node_idx, + std::initializer_list ops, + std::initializer_list unary_ops) { +#ifndef NDEBUG + const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY); + GGML_ASSERT(unary_ops.size() == num_unary); +#endif + + const auto is_equal = [](const std::initializer_list & list1, + const std::initializer_list & list2) { + return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end()); + }; + std::initializer_list mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU }; std::initializer_list mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU }; @@ -3398,35 +3501,75 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud // start of fusion operations static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); if (!disable_fusion) { + ggml_cuda_topk_moe_args args; + + if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX || + cgraph->nodes[i]->op == GGML_OP_ARGSORT) { + const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args); + + std::vector ops; + + if (can_fuse) { + const ggml_tensor * logits = node->src[0]; + ggml_tensor * weights = nullptr; + ggml_tensor * ids = nullptr; + const ggml_tensor * bias = nullptr; + const ggml_tensor * clamp = nullptr; + const ggml_tensor * scale = nullptr; + + if (!args.delayed_softmax) { + ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX; + int out_nodes[2]; // nodes which can't be elided + + if (args.prob_bias) { + bias = cgraph->nodes[i + 2]->src[1]; + ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, + GGML_OP_VIEW, GGML_OP_GET_ROWS }); + out_nodes[0] = i + 4; + ids = cgraph->nodes[i + 4]; + } else { + ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, + GGML_OP_GET_ROWS }); + out_nodes[0] = i + 3; + ids = cgraph->nodes[i + 3]; + } - if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) { - ggml_tensor * weights = cgraph->nodes[i + 9]; - ggml_tensor * selected_experts = cgraph->nodes[i + 3]; - ggml_tensor * clamp = cgraph->nodes[i + 7]; - ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true, - /*delayed softmax*/ false, clamp); - i += 9; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) { - ggml_tensor * weights = cgraph->nodes[i + 4]; - ggml_tensor * selected_experts = cgraph->nodes[i + 3]; - ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false, - /*delayed softmax*/ false); - i += 4; - continue; - } + if (args.norm) { + ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, + GGML_OP_DIV, GGML_OP_RESHAPE }); + clamp = cgraph->nodes[i + ops.size() - 3]; + } + if (args.scale) { + ops.insert(ops.end(), { GGML_OP_SCALE }); + scale = cgraph->nodes[i + ops.size() - 1]; + } - if (ggml_cuda_can_fuse(cgraph, i, - ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) { - ggml_tensor * weights = cgraph->nodes[i + 5]; - ggml_tensor * ids = cgraph->nodes[i + 1]; + weights = cgraph->nodes[i + ops.size() - 1]; + out_nodes[1] = i + ops.size() - 1; - ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false, - /*delayed_softmax*/ true); - i += 5; - continue; + if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && + ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) { + ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); + i += ops.size() - 1; + continue; + } + } else if (!args.norm && !args.prob_bias) { + //special case gpt-oss, no norm, no bias. + ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, + GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }); + weights = cgraph->nodes[i + 5]; + ids = cgraph->nodes[i + 1]; + const ggml_tensor * softmax = cgraph->nodes[i + 4]; + + int out_nodes[2] = { i + 1, i + 5 }; + if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && + ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) { + ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); + i += ops.size() - 1; + continue; + } + } + } } if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) { diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 48e569efa0d..08a88990dde 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -5,6 +5,13 @@ #include #include +// Kernel config struct - passed by value to CUDA kernel +struct topk_moe_config { + bool use_sigmoid; + bool with_norm; + bool delayed_softmax; +}; + // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path. template __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) { @@ -50,6 +57,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in } } +template +__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) { +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int idx = lane + i * WARP_SIZE; + const bool active = !use_limit || (idx < limit); + vals[i] = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY; + } +} + /* This kernel does the following: 1. optionally softmax over the logits per token [n_experts, n_tokens] @@ -59,13 +76,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models */ -template -__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, - float * weights, - int32_t * ids, - const int n_rows, - const int n_expert_used, - const float clamp_val) { +template +__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, + float * weights, + int32_t * ids, + float * bias, + const int n_rows, + const int n_expert_used, + const float clamp_val, + const float scale_val, + const topk_moe_config config) { const int row = blockIdx.x * blockDim.y + threadIdx.y; if (row >= n_rows) { return; @@ -79,14 +99,41 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * float wt[experts_per_thread]; + // Initialize all slots to -INFINITY +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + wt[i] = -INFINITY; + } + #pragma unroll for (int i = 0; i < n_experts; i += WARP_SIZE) { const int expert = i + threadIdx.x; wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY; } - if constexpr (!delayed_softmax) { - softmax_warp_inplace(wt, n_experts, threadIdx.x); + if (!config.delayed_softmax) { + if (config.use_sigmoid) { + sigmoid_warp_inplace(wt, n_experts, threadIdx.x); + } else { + softmax_warp_inplace(wt, n_experts, threadIdx.x); + } + } + + // selection_wt is only needed when bias is present (selection uses wt + bias) + // when no bias, we use wt directly for both selection and weight values + float selection_wt[has_bias ? experts_per_thread : 1]; + + if constexpr (has_bias) { +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + selection_wt[i] = -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_experts; i += WARP_SIZE) { + const int expert = i + threadIdx.x; + selection_wt[i / WARP_SIZE] = + (n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY; + } } //at this point, each thread holds either a portion of the softmax distribution @@ -106,22 +153,56 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * float max_val = wt[0]; int max_expert = threadIdx.x; + if constexpr (has_bias) { + float max_val_s = selection_wt[0]; + #pragma unroll - for (int i = 1; i < experts_per_thread; i++) { - const int expert = threadIdx.x + i * WARP_SIZE; - if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { - max_val = wt[i]; - max_expert = expert; + for (int i = 1; i < experts_per_thread; i++) { + const int expert = threadIdx.x + i * WARP_SIZE; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) { + max_val = wt[i]; + max_val_s = selection_wt[i]; + max_expert = expert; + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { + const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); + const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE); + const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); + if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) { + max_val = val; + max_val_s = val_s; + max_expert = expert; + } + } + + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { + selection_wt[max_expert / WARP_SIZE] = -INFINITY; + } + } else { +#pragma unroll + for (int i = 1; i < experts_per_thread; i++) { + const int expert = threadIdx.x + i * WARP_SIZE; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { + max_val = wt[i]; + max_expert = expert; + } } - } #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { - const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); - const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); - if (val > max_val || (val == max_val && expert < max_expert)) { - max_val = val; - max_expert = expert; + for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { + const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); + const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); + if (val > max_val || (val == max_val && expert < max_expert)) { + max_val = val; + max_expert = expert; + } + } + + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { + wt[max_expert / WARP_SIZE] = -INFINITY; } } @@ -130,16 +211,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * } if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { - wt[max_expert / WARP_SIZE] = -INFINITY; - ids[k] = max_expert; - if constexpr (with_norm) { + if (config.with_norm) { wt_sum += max_val; } } } - if constexpr (with_norm) { + if (config.with_norm) { wt_sum = warp_reduce_sum(wt_sum); wt_sum = max(wt_sum, clamp_val); const float inv_sum = 1.0f / wt_sum; @@ -149,7 +228,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * } } - if constexpr (delayed_softmax) { + if (config.delayed_softmax) { softmax_warp_inplace(output_weights, n_expert_used, threadIdx.x); } @@ -157,25 +236,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * for (int i = 0; i < experts_per_thread; i++) { const int idx = i * WARP_SIZE + threadIdx.x; if (idx < n_expert_used) { - weights[idx] = output_weights[i]; + weights[idx] = output_weights[i] * scale_val; } } - - if (!with_norm) { - GGML_UNUSED(clamp_val); - } } -template +template static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, const float * logits, float * weights, int32_t * ids, + float * bias, const int n_rows, const int n_expert, const int n_expert_used, - const float clamp_val) { - static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization"); + const float clamp_val, + const float scale_val, + const topk_moe_config config) { + GGML_ASSERT(!(config.with_norm && config.delayed_softmax) && + "delayed softmax is not supported with weight normalization"); const int rows_per_block = 4; dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); dim3 block_dims(WARP_SIZE, rows_per_block, 1); @@ -183,44 +262,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, switch (n_expert) { case 1: - topk_moe_cuda<1, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<1, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 2: - topk_moe_cuda<2, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<2, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 4: - topk_moe_cuda<4, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<4, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 8: - topk_moe_cuda<8, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<8, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 16: - topk_moe_cuda<16, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<16, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 32: - topk_moe_cuda<32, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<32, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 64: - topk_moe_cuda<64, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<64, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 128: - topk_moe_cuda<128, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<128, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 256: - topk_moe_cuda<256, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<256, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; case 512: - topk_moe_cuda<512, with_norm, delayed_softmax> - <<>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + topk_moe_cuda<512, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); + break; + case 576: + topk_moe_cuda<576, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, + clamp_val, scale_val, config); break; default: GGML_ASSERT(false && "fatal error"); @@ -228,13 +311,14 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, } } -void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, - const ggml_tensor * logits, - ggml_tensor * weights, - ggml_tensor * ids, - const bool with_norm, - const bool delayed_softmax, - ggml_tensor * clamp) { +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, + const ggml_tensor * logits, + ggml_tensor * weights, + ggml_tensor * ids, + const ggml_tensor * clamp, + const ggml_tensor * scale, + const ggml_tensor * bias, + const ggml_cuda_topk_moe_args & args) { GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(ids->type == GGML_TYPE_I32); @@ -245,107 +329,75 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, const float * logits_d = (const float *) logits->data; float * weights_d = (float *) weights->data; int32_t * ids_d = (int32_t *) ids->data; + float * bias_d = bias ? (float *) bias->data : nullptr; + + float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f; GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); const int n_expert_used = weights->ne[1]; + const bool with_norm = clamp != nullptr; + float clamp_val = -INFINITY; - if (with_norm) { - if (clamp) { - clamp_val = ggml_get_op_params_f32(clamp, 0); - } - launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val); + if (clamp) { + clamp_val = ggml_get_op_params_f32(clamp, 0); + } + + topk_moe_config config; + config.use_sigmoid = args.sigmoid; + config.with_norm = with_norm; + config.delayed_softmax = args.delayed_softmax; + + if (bias) { + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val, + scale_val, config); } else { - GGML_ASSERT(clamp == nullptr); - if (delayed_softmax) { - launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, - clamp_val); - } else { - launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, - clamp_val); - } + launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val, + scale_val, config); } } -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op, const ggml_tensor * weights, - const ggml_tensor * get_rows, - const ggml_tensor * argsort, - const ggml_tensor * clamp, - int n_expert) { - ggml_tensor * probs = get_rows->src[0]; - if (probs->op != GGML_OP_RESHAPE) { + const ggml_tensor * logits, + const ggml_tensor * ids) { + const int n_expert = ids->nb[1] / ids->nb[0]; + if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) { return false; } - probs = probs->src[0]; - ggml_tensor * selection_probs = argsort->src[0]; - if (probs != selection_probs) { + if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) { return false; } - float scale = 1.0f; - float max_bias = 0.0f; + if (gating_op->op == GGML_OP_SOFT_MAX) { + const ggml_tensor * softmax = gating_op; + float scale = 1.0f; + float max_bias = 0.0f; - memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float)); - memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float)); + memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float)); - if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) { - return false; - } - - if (scale != 1.0f || max_bias != 0.0f) { - return false; - } - - // don't fuse when masks or sinks are present - if (softmax->src[1] || softmax->src[2]) { - return false; - } + if (!ggml_is_contiguous(softmax->src[0])) { + return false; + } - // n_expert must be a power of 2 - if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) { - return false; - } + if (scale != 1.0f || max_bias != 0.0f) { + return false; + } - if (clamp) { - if (clamp->op != GGML_OP_CLAMP) { + // don't fuse when masks or sinks are present + if (softmax->src[1] || softmax->src[2]) { return false; } - float max_val = ggml_get_op_params_f32(clamp, 1); + } else if (gating_op->op == GGML_OP_UNARY) { + ggml_unary_op op = ggml_get_unary_op(gating_op); - if (max_val != INFINITY) { + if (op != GGML_UNARY_OP_SIGMOID) { return false; } } - return true; } - -std::initializer_list ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) { - static std::initializer_list norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, - GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, - GGML_OP_RESHAPE }; - - static std::initializer_list no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS }; - - static std::initializer_list delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW, - GGML_OP_GET_ROWS, GGML_OP_RESHAPE, - GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; - - GGML_ASSERT(!norm || !delayed_softmax); - - if (delayed_softmax) { - return delayed_softmax_ops; - } - - if (norm) { - return norm_ops; - } - - return no_norm_ops; -} diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh index 6b6c13c5870..243dc2f1c41 100644 --- a/ggml/src/ggml-cuda/topk-moe.cuh +++ b/ggml/src/ggml-cuda/topk-moe.cuh @@ -3,19 +3,25 @@ #include -void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, - const ggml_tensor * logits, - ggml_tensor * weights, - ggml_tensor * ids, - const bool with_norm, - const bool delayed_softmax = false, - ggml_tensor * weight_clamp = nullptr); +struct ggml_cuda_topk_moe_args { + bool sigmoid{}; + bool softmax{}; + bool delayed_softmax{}; + bool prob_bias{}; + bool norm{}; + bool scale{}; +}; -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, - const ggml_tensor * weights, - const ggml_tensor * get_rows, - const ggml_tensor * argsort, - const ggml_tensor * clamp, - int n_expert); +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, + const ggml_tensor * logits, + ggml_tensor * weights, + ggml_tensor * ids, + const ggml_tensor * clamp, + const ggml_tensor * scale, + const ggml_tensor * bias, + const ggml_cuda_topk_moe_args & args); -std::initializer_list ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false); +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op, + const ggml_tensor * weights, + const ggml_tensor * logits, + const ggml_tensor * ids); From e0a2182970ed7c1e4e75f381c4b5e18407d2d670 Mon Sep 17 00:00:00 2001 From: Vishal Singh Date: Thu, 29 Jan 2026 09:58:57 +0530 Subject: [PATCH 061/831] ggml-zendnn : resolve ZenDNN backend cross-module symbol dependency (llama/19159) --- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index afbecde7a5a..551c15bb4ae 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -2,7 +2,6 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" -#include "ggml-cpu.h" #include "zendnnl.hpp" #include @@ -122,8 +121,8 @@ static void ggml_zendnn_compute_forward_mul_mat( GGML_TENSOR_BINARY_OP_LOCALS - ggml_type const vec_dot_type = ggml_get_type_traits_cpu(src0->type)->vec_dot_type; - ggml_from_float_t const from_float = ggml_get_type_traits_cpu(vec_dot_type)->from_float; + ggml_type const vec_dot_type = src0->type; + ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref; GGML_ASSERT(ne0 == ne01); GGML_ASSERT(ne1 == ne11); From 34a3e28a084bf2819e12a9d7bd920313ba44663b Mon Sep 17 00:00:00 2001 From: yulo <77381088+zhang-hui-yulo@users.noreply.github.com> Date: Thu, 29 Jan 2026 18:10:53 +0800 Subject: [PATCH 062/831] HIP: add mmf for CDNA (llama/18896) * refactor mmf rows_per_block * speed up compile * pass cdna compile * fix cuda error * clean up mmf * f32 mmf * clean float mma * fix mmf error * faster mmf * extend tile k * fix compile error * Revert "extend tile k" This reverts commit 4d2ef3d483932659801a59a5af0b6b48f6ffd5c7. * fix smem overflow * speed up compiling mmf * speed up compile for hip * 512 block for cdna * config pad size * fix as comment * update select logic * move some code to cuh * fix as comment * correct cdna3 config --------- Co-authored-by: zhang hui --- ggml/src/ggml-cuda/mma.cuh | 102 ++++++++++++- ggml/src/ggml-cuda/mmf.cu | 40 +++-- ggml/src/ggml-cuda/mmf.cuh | 243 ++++++++++++++++++++----------- ggml/src/ggml-hip/CMakeLists.txt | 2 + 4 files changed, 288 insertions(+), 99 deletions(-) diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 42085d10027..dd45d6c78fd 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -333,7 +333,33 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { - return 4 * (threadIdx.x / 16) + l; + return ne * (threadIdx.x / 16) + l; + } else { + NO_DEVICE_CODE; + return -1; + } + } +#elif defined(AMD_MFMA_AVAILABLE) + static constexpr int ne = I * J / 64; + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 8) { + return ne * (threadIdx.x / 16) + l; } else { NO_DEVICE_CODE; return -1; @@ -391,7 +417,22 @@ namespace ggml_cuda_mma { static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if defined(AMD_WMMA_AVAILABLE) - static constexpr int ne = I * J / 32; + static constexpr int ne = tile::ne; + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + return tile::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile::get_i(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile::get_j(l); + } +#elif defined(AMD_MFMA_AVAILABLE) + static constexpr int ne = tile::ne; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { @@ -945,6 +986,32 @@ namespace ggml_cuda_mma { #endif // AMPERE_MMA_AVAILABLE } + template + static __device__ __forceinline__ void mma( + tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) { +#ifdef AMD_MFMA_AVAILABLE + using floatx4_t = __attribute__((ext_vector_type(4))) float; + floatx4_t& acc_frag = reinterpret_cast(D.x[0]); +#if defined(CDNA3) + using floatx2_t = __attribute__((ext_vector_type(2))) float; + const floatx2_t& a_frag = reinterpret_cast(A.x[0]); + const floatx2_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA1) +#pragma unroll + for (int i = 0; i < 2; ++i) { + acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0); + } +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(CDNA3) +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE + } + static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D, const tile<16, 8, int> & A, const tile<8, 8, int> & B, @@ -1054,6 +1121,13 @@ namespace ggml_cuda_mma { GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // RDNA4 +#elif defined(AMD_MFMA_AVAILABLE) + using halfx4_t = __attribute__((ext_vector_type(4))) _Float16; + using floatx4_t = __attribute__((ext_vector_type(4))) float; + floatx4_t& acc_frag = reinterpret_cast(D.x[0]); + const halfx4_t& a_frag = reinterpret_cast(A.x[0]); + const halfx4_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0); #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -1081,11 +1155,31 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // RDNA4 +#endif // defined(RDNA4) +#elif defined(AMD_MFMA_AVAILABLE) + using floatx4_t = __attribute__((ext_vector_type(4))) float; + floatx4_t& acc_frag = reinterpret_cast(D.x[0]); +#if defined(CDNA3) || defined(CDNA2) + using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16; + const bf16x4_t& a_frag = reinterpret_cast(A.x[0]); + const bf16x4_t& b_frag = reinterpret_cast(B.x[0]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0); +#elif defined(CDNA1) +#pragma unroll + for (int i = 0; i < 2; ++i) { + using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16; + const bf16x2_t& a_frag = reinterpret_cast(A.x[i]); + const bf16x2_t& b_frag = reinterpret_cast(B.x[i]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0); + } #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // AMPERE_MMA_AVAILABLE +#endif // defined(CDNA3) || defined(CDNA2) +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(AMD_WMMA_AVAILABLE) } template diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 6643f243b12..aad4c34aa66 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -2,6 +2,13 @@ #include "mmf.cuh" #include "mmid.cuh" +static __forceinline__ int mmf_get_rows_per_block(const int cc) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return MMF_ROWS_PER_BLOCK_CDNA; + } else { + return MMF_ROWS_PER_BLOCK; + } +} void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); @@ -89,28 +96,32 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr ids_info_ptr = &ids_info; } + const int device = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[device].cc; + const int rows_per_block = mmf_get_rows_per_block(cc); + switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; constexpr int vals_per_T = 1; - mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + mul_mat_f_switch_rows_per_block( + rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; case GGML_TYPE_F16: { const half2 * src0_d = (const half2 *) src0->data; constexpr int vals_per_T = 2; - mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + mul_mat_f_switch_rows_per_block( + rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; case GGML_TYPE_BF16: { const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; constexpr int vals_per_T = 2; - mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + mul_mat_f_switch_rows_per_block( + rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; @@ -140,7 +151,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const return false; } } - if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) { + if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) { + return false; + } + + if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) { return false; } @@ -153,6 +168,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const } else { if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) { return false; + } else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) { + //TODO: truse CDNA2 as CDNA1, tune the perf when CDNA2 is available. + return false; + } else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) { + return false; } else if (src1_ncols > 16) { return false; } @@ -160,11 +180,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const switch (type) { case GGML_TYPE_F32: - return ampere_mma_available(cc); + return ampere_mma_available(cc) || amd_mfma_available(cc); case GGML_TYPE_F16: - return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc); + return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc); case GGML_TYPE_BF16: - return ampere_mma_available(cc) || amd_wmma_available(cc); + return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc); default: return false; } diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index e36730948ff..c2a8d54c95a 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -7,6 +7,31 @@ using namespace ggml_cuda_mma; #define MMF_ROWS_PER_BLOCK 32 +#define MMF_ROWS_PER_BLOCK_CDNA 64 + +static __forceinline__ int64_t mmf_get_max_block_size(int cc) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return 512; + } else { + return 256; + } +} + +static __forceinline__ int mmf_get_padding(int cc) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return 2; + } else { + return 4; + } +} + +static constexpr __device__ int mmf_get_padding() { +#if defined(AMD_MFMA_AVAILABLE) + return 2; +#else + return 4; +#endif // defined(AMD_MFMA_AVAILABLE) +} struct mmf_ids_data { const int32_t * ids_src_compact = nullptr; @@ -29,23 +54,25 @@ static __global__ void mul_mat_f( const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added -#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE) - // Special case for tf32, just dummy mma layout as wmma doesn't support it. - constexpr bool is_tf32 = std::is_same_v; - constexpr int tile_B_I = is_tf32 ? 8 : 16; - constexpr int tile_C_J = is_tf32 ? 8 : 16; - constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout(); - typedef tile<16, 8, T, ab_layout> tile_A; - typedef tile tile_B; - typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C; + if constexpr (!(std::is_same_v || std::is_same_v) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, get_input_data_layout()> tile_A; + typedef tile<16, 8, T, get_input_data_layout()> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; +#elif defined(AMD_MFMA_AVAILABLE) + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; #else #ifdef VOLTA_MMA_AVAILABLE - if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else { + if constexpr (!std::is_same_v || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; #else + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { typedef tile<16, 8, T> tile_A; typedef tile<8, 8, T> tile_B; typedef tile<16, 8, float> tile_C; @@ -57,7 +84,7 @@ static __global__ void mul_mat_f( } constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - constexpr int tile_k_padded = warp_size + 4; + constexpr int tile_k_padded = warp_size + mmf_get_padding(); constexpr int ntA = rows_per_block / tile_A::I; constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; @@ -198,7 +225,7 @@ static __global__ void mul_mat_f( } float * buf_iw = (float *) compute_base; - constexpr int kiw = nwarps*rows_per_block + 4; + constexpr int kiw = nwarps*rows_per_block + mmf_get_padding(); if (nwarps > 1) { __syncthreads(); @@ -228,27 +255,34 @@ static __global__ void mul_mat_f( return; } - float sum = 0.0f; - static_assert(rows_per_block == warp_size, "need loop/check"); + float sum[rows_per_block/warp_size] = {0.0f}; + static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size."); #pragma unroll for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { - const int i = i0 + threadIdx.x; +#pragma unroll + for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) { + const int i = i0 + i1*warp_size + threadIdx.x; - sum += buf_iw[j*kiw + i]; + sum[i1] += buf_iw[j*kiw + i]; + } } if constexpr (!has_ids) { - dst[j*stride_col_dst + row0 + threadIdx.x] = sum; +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } } else { const int slot = (j < cols_per_block) ? slot_map[j] : -1; if (slot >= 0 && (col_base + j) < ncols_dst_total) { - dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum; +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } } } } -#ifdef VOLTA_MMA_AVAILABLE } -#endif //VOLTA_MMA_AVAILABLE #else GGML_UNUSED_VARS(x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, @@ -256,7 +290,7 @@ static __global__ void mul_mat_f( channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); NO_DEVICE_CODE; -#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } //This kernel is for larger batch sizes of mul_mat_id @@ -271,23 +305,25 @@ static __global__ void mul_mat_f_ids( const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const uint3 sis1_fd, const uint3 nch_fd) { // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added -#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE) - // Special case for tf32, just dummy mma layout as wmma doesn't support it. - constexpr bool is_tf32 = std::is_same_v; - constexpr int tile_B_I = is_tf32 ? 8 : 16; - constexpr int tile_C_J = is_tf32 ? 8 : 16; - constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout(); - typedef tile<16, 8, T, ab_layout> tile_A; - typedef tile tile_B; - typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C; + if constexpr (!(std::is_same_v || std::is_same_v) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, get_input_data_layout()> tile_A; + typedef tile<16, 8, T, get_input_data_layout()> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; +#elif defined(AMD_MFMA_AVAILABLE) + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; #else #ifdef VOLTA_MMA_AVAILABLE - if constexpr (!std::is_same_v) {NO_DEVICE_CODE;} else { + if constexpr (!std::is_same_v || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; #else + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { typedef tile<16, 8, T> tile_A; typedef tile<8, 8, T> tile_B; typedef tile<16, 8, float> tile_C; @@ -300,7 +336,7 @@ static __global__ void mul_mat_f_ids( constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - constexpr int tile_k_padded = warp_size + 4; + constexpr int tile_k_padded = warp_size + mmf_get_padding(); constexpr int ntA = rows_per_block / tile_A::I; constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; @@ -467,7 +503,7 @@ static __global__ void mul_mat_f_ids( } float * buf_iw = (float *) compute_base; - constexpr int kiw = nwarps*rows_per_block + 4; + constexpr int kiw = nwarps*rows_per_block + mmf_get_padding(); if (nwarps > 1) { __syncthreads(); @@ -497,13 +533,16 @@ static __global__ void mul_mat_f_ids( return; } - float sum = 0.0f; - static_assert(rows_per_block == warp_size, "need loop/check"); + float sum[rows_per_block/warp_size] = {0.0f}; + static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size."); #pragma unroll for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { - const int i = i0 + threadIdx.x; +#pragma unroll + for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) { + const int i = i0 + i1*warp_size + threadIdx.x; - sum += buf_iw[j*kiw + i]; + sum[i1] += buf_iw[j * kiw + i]; + } } const int global_j = col_base + j; @@ -513,23 +552,24 @@ static __global__ void mul_mat_f_ids( const int token = (int) qrm.x; if (token < ncols_dst_total) { const int slot = (int) qrm.y; - dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum; +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } } } } -#ifdef VOLTA_MMA_AVAILABLE } -#endif // VOLTA_MMA_AVAILABLE #else GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); NO_DEVICE_CODE; -#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } -template +template static inline void mul_mat_f_switch_ids( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst, @@ -553,7 +593,7 @@ static inline void mul_mat_f_switch_ids( const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1); const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst); - mul_mat_f_ids<<>> + mul_mat_f_ids<<>> (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, @@ -564,19 +604,19 @@ static inline void mul_mat_f_switch_ids( dim3 block_nums_ids = block_nums; block_nums_ids.y *= col_tiles; - mul_mat_f<<>> + mul_mat_f<<>> (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } else { - mul_mat_f<<>> + mul_mat_f<<>> (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } } -template +template void mul_mat_f_cuda( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, @@ -605,7 +645,7 @@ void mul_mat_f_cuda( int64_t nwarps_best = 1; int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); - int64_t max_block_size = 256; + int64_t max_block_size = mmf_get_max_block_size(cc); for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); if (niter < niter_best) { @@ -614,10 +654,9 @@ void mul_mat_f_cuda( } } - constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; - const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4; - const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I; - const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4; + const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4; + const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I; + const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4; const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; @@ -628,56 +667,56 @@ void mul_mat_f_cuda( switch (nwarps_best) { case 1: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 2: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 3: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 4: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 5: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 6: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 7: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 8: { - mul_mat_f_switch_ids( + mul_mat_f_switch_ids( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, @@ -691,7 +730,7 @@ void mul_mat_f_cuda( GGML_UNUSED_VARS(nchannels_y); } -template +template static void mul_mat_f_switch_cols_per_block( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, @@ -708,82 +747,82 @@ static void mul_mat_f_switch_cols_per_block( switch (ncols_case) { case 1: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 2: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 3: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 4: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 5: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 6: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 7: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 8: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 9: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 10: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 11: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 12: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 13: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 14: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 15: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 16: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; @@ -793,8 +832,36 @@ static void mul_mat_f_switch_cols_per_block( } } -#define DECL_MMF_CASE_HELPER(T, ncols_dst) \ - template void mul_mat_f_cuda( \ +template +static void mul_mat_f_switch_rows_per_block( + const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream, const mmf_ids_data * ids_data) { + switch (rows_per_block) { + case MMF_ROWS_PER_BLOCK: { + mul_mat_f_switch_cols_per_block( + x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case MMF_ROWS_PER_BLOCK_CDNA: { + mul_mat_f_switch_cols_per_block( + x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + default: + GGML_ABORT("unsupported rows_per_block: %i", rows_per_block); + } +} + +#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \ + template void mul_mat_f_cuda( \ const T * x, const float * y, const int32_t * ids, float * dst, \ const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \ const int64_t stride_col_id, const int64_t stride_row_id, \ @@ -803,16 +870,22 @@ static void mul_mat_f_switch_cols_per_block( const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ cudaStream_t stream, const mmf_ids_data * ids_data); -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#if !defined(GGML_USE_MUSA) #define DECL_MMF_CASE_EXTERN(ncols_dst) \ - extern DECL_MMF_CASE_HELPER(float, ncols_dst) \ - extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \ - extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) #define DECL_MMF_CASE(ncols_dst) \ - DECL_MMF_CASE_HELPER(float, ncols_dst) \ - DECL_MMF_CASE_HELPER(half2, ncols_dst) \ - DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \ + DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \ + DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \ + DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) DECL_MMF_CASE_EXTERN(1); DECL_MMF_CASE_EXTERN(2); diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 23b6889919f..80037d24361 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -62,6 +62,8 @@ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) +file(GLOB SRCS "../ggml-cuda/template-instances/mmf*.cu") +list(APPEND GGML_SOURCES_ROCM ${SRCS}) if (GGML_CUDA_FA_ALL_QUANTS) file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu") From b997e690ef1205089f078d7a979515397661d15d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 29 Jan 2026 18:45:30 +0200 Subject: [PATCH 063/831] cuda : fix nkvo, offload and cuda graph node properties matching (llama/19165) * cuda : fix nkvo * cont : more robust cuda graph node property matching * cont : restore pre-leafs implementation * cont : comments + static_assert --- ggml/src/ggml-cuda/common.cuh | 12 +++++- ggml/src/ggml-cuda/fattn.cu | 5 --- ggml/src/ggml-cuda/ggml-cuda.cu | 69 +++++++++++++++++++++++---------- 3 files changed, 59 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 3335f443aeb..43280644e48 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1122,15 +1122,17 @@ struct ggml_tensor_extra_gpu { #endif struct ggml_cuda_graph_node_properties { - void * node_address; + void * node_data; ggml_op node_op; int32_t flags; int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; - void * src_address[GGML_MAX_SRC]; + void * src_data[GGML_MAX_SRC]; int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; }; +static_assert(std::is_trivial::value, "ggml_cuda_graph_node_properties must be trivial"); + struct ggml_cuda_graph { #ifdef USE_CUDA_GRAPH ~ggml_cuda_graph() { @@ -1150,6 +1152,12 @@ struct ggml_cuda_graph { int number_consecutive_updates = 0; std::vector props; + // these are extra tensors (inputs) that participate in the ggml graph but are not nodes + // they properties also have to match in order to be able to safely reuse a CUDA graph + // ref: https://github.com/ggml-org/llama.cpp/pull/18583 + // ref: https://github.com/ggml-org/llama.cpp/pull/19165 + std::vector extra; + void record_update(bool use_graph, bool update_required) { if (use_graph && update_required) { number_consecutive_updates++; diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 195904ee206..721edd99944 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -310,8 +310,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } } - const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); - const int cc = ggml_cuda_info().devices[device].cc; switch (K->ne[0]) { @@ -334,9 +332,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (!gqa_opt_applies) { return BEST_FATTN_KERNEL_NONE; } - if (!V_is_K_view) { - return BEST_FATTN_KERNEL_NONE; - } break; default: return BEST_FATTN_KERNEL_NONE; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 76d0f12550e..cfcffde8a21 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -70,17 +70,18 @@ #include #include #include -#include +#include #include #include #include #include #include -#include -#include -#include +#include +#include +#include #include #include +#include static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); @@ -2916,7 +2917,8 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { } static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { - props->node_address = node->data; + memset(props, 0, sizeof(ggml_cuda_graph_node_properties)); + props->node_data = node->data; props->node_op = node->op; props->flags = node->flags; for (int i = 0; i < GGML_MAX_DIMS; i++) { @@ -2924,14 +2926,17 @@ static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties props->nb[i] = node->nb[i]; } for (int i = 0; i < GGML_MAX_SRC; i++) { - props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; + if (!node->src[i]) { + continue; + } + + props->src_data[i] = node->src[i]->data; } memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); } static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { - if (node->data != props->node_address && - node->op != GGML_OP_VIEW) { + if (node->data != props->node_data && node->op != GGML_OP_VIEW) { return false; } @@ -2948,12 +2953,18 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_ } } - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (node->src[i] && - node->src[i]->data != props->src_address[i] && - node->op != GGML_OP_VIEW - ) { - return false; + if (node->op != GGML_OP_VIEW) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (!node->src[i]) { + if (props->src_data[i] != nullptr) { + return false; + } + continue; + } + + if (node->src[i]->data != props->src_data[i]) { + return false; + } } } @@ -2974,7 +2985,6 @@ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) { } static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { - bool res = false; const void * graph_key = ggml_cuda_graph_get_key(cgraph); @@ -2985,15 +2995,20 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx } // Check if the graph size has changed - if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { + if (graph->props.size() != (size_t)cgraph->n_nodes) { res = true; - graph->props.resize(cgraph->n_nodes + cgraph->n_leafs); + graph->props.resize(cgraph->n_nodes); } // Loop over nodes in GGML graph to determine if CUDA graph update is required // and store properties to allow this comparison for the next token + std::unordered_set seen_node; + std::vector srcs_extra; for (int i = 0; i < cgraph->n_nodes; i++) { bool props_match = true; + + seen_node.insert(cgraph->nodes[i]); + if (!res) { props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]); } @@ -3001,17 +3016,31 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx res = true; } ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]); + + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + ggml_tensor * src = cgraph->nodes[i]->src[src_idx]; + if (src && seen_node.find(src) == seen_node.end()) { + srcs_extra.push_back(src); + } + } + } + + if (graph->extra.size() != (size_t) srcs_extra.size()) { + res = true; + graph->extra.resize(srcs_extra.size()); } - for (int i = 0; i < cgraph->n_leafs; i++) { + for (size_t i = 0; i < srcs_extra.size(); ++i) { bool props_match = true; + if (!res) { - props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]); + props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]); } + if (!props_match) { res = true; } - ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]); + ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]); } return res; From 2a89a3f35c0567ce9d132ed72f24b75835c407e2 Mon Sep 17 00:00:00 2001 From: Todor Boinovski Date: Thu, 29 Jan 2026 12:33:21 -0800 Subject: [PATCH 064/831] hexagon: enable offloading to Hexagon on Windows on Snapdragon (llama/19150) * hexagon: updates to enable offloading to HTP on WoS * Update windows.md * Update windows.md * hexagon: enable -O3 optimizations * hexagon: move all _WINDOWS conditional compilation to _WIN32 * hexagon: updates to enable offloading to HTP on WoS * hexagon: use run-time vs load-time dynamic linking for cdsp driver interface * refactor htp-drv * hexagon: add run-bench.ps1 script * hexagon: htdrv refactor * hexagon: unify Android and Windows build readmes * hexagon: update README.md * hexagon: refactor htpdrv * hexagon: drv refactor * hexagon: more drv refactor * hexagon: fixes for android builds * hexagon: factor out dl into ggml-backend-dl * hexagon: add run-tool.ps1 script * hexagon: merge htp-utils in htp-drv and remove unused code * wos: no need for getopt_custom.h * wos: add missing CR in htpdrv * hexagon: ndev enforecement applies only to the Android devices * hexagon: add support for generating and signing .cat file * hexagon: add .inf file * hexagon: working auto-signing and improved windows builds * hexagon: futher improve skel build * hexagon: add rough WoS guide * hexagon: updated windows guide * hexagon: improve cmake handling of certs and logging * hexagon: improve windows setup/build doc * hexagon: more windows readme updates * hexagon: windows readme updates * hexagon: windows readme updates * hexagon: windows readme updates * hexagon: windows readme updates * Update windows.md * Update windows.md * snapdragon: rename docs/backend/hexagon to docs/backends/snapdragon Also added a power shell script to simplify build env setup. * hexagon: remove trailing whitespace and move cmake requirement to user-presets * hexagon: fix CMakeUserPresets path in workflow yaml * hexagon: introduce local version of libdl.h * hexagon: fix src1 reuse logic gpt-oss needs a bigger lookahead window. The check for src[1] itself being quantized was wrong. --------- Co-authored-by: Max Krasnyansky --- ggml/src/CMakeLists.txt | 1 + ggml/src/ggml-backend-dl.cpp | 48 +++ ggml/src/ggml-backend-dl.h | 45 +++ ggml/src/ggml-backend-reg.cpp | 67 +--- ggml/src/ggml-hexagon/CMakeLists.txt | 115 ++++--- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 59 ++-- ggml/src/ggml-hexagon/htp-drv.cpp | 418 +++++++++++++++++++++++ ggml/src/ggml-hexagon/htp-drv.h | 121 +++++++ ggml/src/ggml-hexagon/htp-utils.c | 454 ------------------------- ggml/src/ggml-hexagon/htp-utils.h | 221 ------------ ggml/src/ggml-hexagon/libdl.h | 79 +++++ ggml/src/ggml-hexagon/libggml-htp.inf | 38 +++ 12 files changed, 848 insertions(+), 818 deletions(-) create mode 100644 ggml/src/ggml-backend-dl.cpp create mode 100644 ggml/src/ggml-backend-dl.h create mode 100644 ggml/src/ggml-hexagon/htp-drv.cpp create mode 100644 ggml/src/ggml-hexagon/htp-drv.h delete mode 100644 ggml/src/ggml-hexagon/htp-utils.c delete mode 100644 ggml/src/ggml-hexagon/htp-utils.h create mode 100644 ggml/src/ggml-hexagon/libdl.h create mode 100644 ggml/src/ggml-hexagon/libggml-htp.inf diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 260ad48f0e8..265023733e7 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -222,6 +222,7 @@ if (GGML_SCHED_NO_REALLOC) endif() add_library(ggml + ggml-backend-dl.cpp ggml-backend-reg.cpp) add_library(ggml::ggml ALIAS ggml) diff --git a/ggml/src/ggml-backend-dl.cpp b/ggml/src/ggml-backend-dl.cpp new file mode 100644 index 00000000000..a65cf009055 --- /dev/null +++ b/ggml/src/ggml-backend-dl.cpp @@ -0,0 +1,48 @@ +#include "ggml-backend-dl.h" + +#ifdef _WIN32 + +dl_handle * dl_load_library(const fs::path & path) { + // suppress error dialogs for missing DLLs + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + HMODULE handle = LoadLibraryW(path.wstring().c_str()); + + SetErrorMode(old_mode); + + return handle; +} + +void * dl_get_sym(dl_handle * handle, const char * name) { + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + void * p = (void *) GetProcAddress(handle, name); + + SetErrorMode(old_mode); + + return p; +} + +const char * dl_error() { + return ""; +} + +#else + +dl_handle * dl_load_library(const fs::path & path) { + dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL); + return handle; +} + +void * dl_get_sym(dl_handle * handle, const char * name) { + return dlsym(handle, name); +} + +const char * dl_error() { + const char *rslt = dlerror(); + return rslt != nullptr ? rslt : ""; +} + +#endif diff --git a/ggml/src/ggml-backend-dl.h b/ggml/src/ggml-backend-dl.h new file mode 100644 index 00000000000..f74b7c94894 --- /dev/null +++ b/ggml/src/ggml-backend-dl.h @@ -0,0 +1,45 @@ +#pragma once + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +#endif +#include + +namespace fs = std::filesystem; + +#ifdef _WIN32 + +using dl_handle = std::remove_pointer_t; + +struct dl_handle_deleter { + void operator()(HMODULE handle) { + FreeLibrary(handle); + } +}; + +#else + +using dl_handle = void; + +struct dl_handle_deleter { + void operator()(void * handle) { + dlclose(handle); + } +}; + +#endif + +using dl_handle_ptr = std::unique_ptr; + +dl_handle * dl_load_library(const fs::path & path); +void * dl_get_sym(dl_handle * handle, const char * name); +const char * dl_error(); + diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index dd991f262e6..8a693f84af5 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -1,5 +1,6 @@ #include "ggml-backend-impl.h" #include "ggml-backend.h" +#include "ggml-backend-dl.h" #include "ggml-impl.h" #include #include @@ -98,72 +99,6 @@ static std::string path_str(const fs::path & path) { } } -#ifdef _WIN32 - -using dl_handle = std::remove_pointer_t; - -struct dl_handle_deleter { - void operator()(HMODULE handle) { - FreeLibrary(handle); - } -}; - -static dl_handle * dl_load_library(const fs::path & path) { - // suppress error dialogs for missing DLLs - DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); - SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); - - HMODULE handle = LoadLibraryW(path.wstring().c_str()); - - SetErrorMode(old_mode); - - return handle; -} - -static void * dl_get_sym(dl_handle * handle, const char * name) { - DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); - SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); - - void * p = (void *) GetProcAddress(handle, name); - - SetErrorMode(old_mode); - - return p; -} - -static const char * dl_error() { - return ""; -} - -#else - -using dl_handle = void; - -struct dl_handle_deleter { - void operator()(void * handle) { - dlclose(handle); - } -}; - -static void * dl_load_library(const fs::path & path) { - dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL); - - return handle; -} - -static void * dl_get_sym(dl_handle * handle, const char * name) { - return dlsym(handle, name); -} - -static const char * dl_error() { - const char *rslt = dlerror(); - return rslt != nullptr ? rslt : ""; -} - -#endif - -using dl_handle_ptr = std::unique_ptr; - struct ggml_backend_reg_entry { ggml_backend_reg_t reg; dl_handle_ptr handle; diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt index d58e2878237..2b69197017f 100644 --- a/ggml/src/ggml-hexagon/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -1,7 +1,17 @@ +file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT) +file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT) + +if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}" OR NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}") + message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.") +endif() + +message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels") + include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) include(ExternalProject) option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) +set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate") set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)") add_library(htp_iface OBJECT @@ -25,56 +35,71 @@ else() target_link_options(htp_iface PUBLIC -ldl) endif() -link_custom_library(htp_iface cdsprpc) -link_custom_library(htp_iface rpcmem) - set(TARGET_NAME ggml-hexagon) ggml_add_backend_library(${TARGET_NAME} - ggml-hexagon.cpp htp-utils.c htp-utils.h ../../include/ggml-hexagon.h) + ggml-hexagon.cpp + htp-drv.cpp + htp-drv.h + libdl.h + ../../include/ggml-hexagon.h) target_link_libraries(${TARGET_NAME} PRIVATE htp_iface) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR}) -# Build HTP bits -set(HTP_CMAKE_ARGS - -DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake - -DCMAKE_BUILD_TYPE=Release - -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR} - -DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT} - -DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT} - -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG} - -DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) - -ExternalProject_Add(htp-v68 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v68 -DPREBUILT_LIB_DIR="toolv19_v68") - -ExternalProject_Add(htp-v69 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v69 -DPREBUILT_LIB_DIR="toolv19_v69") - -ExternalProject_Add(htp-v73 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73") - -ExternalProject_Add(htp-v75 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v75 -DPREBUILT_LIB_DIR="toolv19_v75") - -ExternalProject_Add(htp-v79 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v79 -DPREBUILT_LIB_DIR="toolv19_v79") - -ExternalProject_Add(htp-v81 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v81 -DPREBUILT_LIB_DIR="toolv19_v81") +# Build HTP skels +set(HTP_SKELS) +function(build_htp_skel V) + ExternalProject_Add(htp-${V} + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON + BUILD_BYPRODUCTS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so + CMAKE_ARGS + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake + -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR} + -DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT} + -DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT} + -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG} + -DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE} + -DDSP_VERSION=${V} + -DPREBUILT_LIB_DIR="toolv19_${V}") + list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so) + set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE) +endfunction() + +build_htp_skel(v68) +build_htp_skel(v69) +build_htp_skel(v73) +build_htp_skel(v75) +build_htp_skel(v79) +build_htp_skel(v81) # Install Hexagon skels required at runtime -install(FILES - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v68.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v69.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v81.so - TYPE LIB) +install(FILES ${HTP_SKELS} TYPE LIB) + +if (CMAKE_SYSTEM_NAME MATCHES Windows AND GGML_HEXAGON_HTP_CERT) + file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/arm64" WINSDK_BIN0_ARM64) + file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/x86" WINSDK_BIN0_X86) + file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/arm64" WINSDK_BIN1_ARM64) + file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/x86" WINSDK_BIN1_X86) + + set(WINSDK_PATHS ${WINSDK_BIN0_ARM64} ${WINSDK_BIN0_X86} ${WINSDK_BIN1_ARM64} ${WINSDK_BIN1_X86}) + + find_program(INF2CAT NAMES inf2cat.exe PATHS ${WINSDK_PATHS} REQUIRED) + find_program(SIGNTOOL NAMES signtool.exe PATHS ${WINSDK_PATHS} REQUIRED) + + message(STATUS "hexagon: using ${GGML_HEXAGON_HTP_CERT} to sign libggml-htp skels") + + set(LIBGGML_HTP_CAT ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp.cat) + add_custom_target(libggml-htp-cat + BYPRODUCTS ${LIBGGML_HTP_CAT} + DEPENDS libggml-htp.inf ${HTP_SKELS} + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/libggml-htp.inf ${CMAKE_CURRENT_BINARY_DIR} + COMMAND ${INF2CAT} /driver:${CMAKE_CURRENT_BINARY_DIR} /os:10_25H2_ARM64 + COMMAND ${SIGNTOOL} sign /fd sha256 /f ${GGML_HEXAGON_HTP_CERT} ${LIBGGML_HTP_CAT} + COMMENT "generating and signing libggml-htp.cat file" + VERBATIM + ) + + add_dependencies(${TARGET_NAME} libggml-htp-cat) + install(FILES ${LIBGGML_HTP_CAT} TYPE LIB) +endif() diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 5b835c11c72..4f0a1620fbf 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -14,9 +14,6 @@ #ifdef _WIN32 # include -# ifndef _WINDOWS -# define _WINDOWS -# endif #else # include # include @@ -25,8 +22,6 @@ #pragma clang diagnostic ignored "-Wnested-anon-types" #pragma clang diagnostic ignored "-Wgnu-anonymous-struct" -#include "htp-utils.h" - #include #include #include @@ -40,6 +35,7 @@ #include "op-desc.h" #include "htp-msg.h" #include "htp_iface.h" +#include "htp-drv.h" static size_t opt_ndev = 1; static size_t opt_nhvx = 0; // use all @@ -150,9 +146,9 @@ void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_ 0, // flags - the framework will autoset this n_bufs, // number of buffers bufs, // buffer references - sizeof(req), + sizeof(req), // Message length (const uint8_t *) &req, // Message - 1000000 // Timeout + DSPQUEUE_TIMEOUT // Timeout ); if (err != 0) { @@ -182,13 +178,13 @@ void ggml_hexagon_session::flush() { // Read response packet from queue int err = dspqueue_read(q, &flags, - HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references - &n_bufs, // Number of buffer references - bufs, // Buffer references - sizeof(rsp), // Max message length - &rsp_size, // Message length - (uint8_t *) &rsp, - 1000000); // Timeout + HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references + &n_bufs, // Number of buffer references + bufs, // Buffer references + sizeof(rsp), // Max message length + &rsp_size, // Message length + (uint8_t *) &rsp, // Message + DSPQUEUE_TIMEOUT); // Timeout if (err == AEE_EEXPIRED) { // TODO: might need to bail out if the HTP is stuck on something @@ -269,13 +265,7 @@ struct ggml_backend_hexagon_buffer_context { ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) { size += 4 * 1024; // extra page for padding - if (rpcmem_alloc2) { - this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); - } else { - GGML_LOG_INFO("ggml-hex: %s rpcmem_alloc2 not found, falling back to rpcmem_alloc\n", sess->name.c_str()); - this->base = (uint8_t *) rpcmem_alloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); - } - + this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); if (!this->base) { GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size); throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)"); @@ -2461,12 +2451,12 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) { } static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) { - return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type) && ggml_is_quantized(op1->src[1]->type)); + return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type)); } static inline bool is_compute_op(ggml_tensor *node) { - return !(ggml_op_is_empty(node->op) || ggml_is_empty(node)); + return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE); } // scan the graph and figure out last compute op index @@ -2488,7 +2478,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg const int last = last_compute_op(graph); - const struct ggml_tensor * prev_quant_op = nullptr; // prev executed op with quantizer + const struct ggml_tensor * prev_op = nullptr; // prev executed op for (int i = 0; i < graph->n_nodes; ++i) { ggml_tensor * node = graph->nodes[i]; @@ -2497,17 +2487,15 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg continue; } - if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { - continue; - } - uint32_t flags = 0; // skip quantizer if src1 is reused - if (op_reuse_src1(node, prev_quant_op)) { + if (op_reuse_src1(node, prev_op)) { flags |= HTP_OPFLAGS_SKIP_QUANTIZE; } + prev_op = node; + // ask for early notification for the last Op if (i == last) { flags |= HTP_OPFLAGS_EARLY_WAKEUP; @@ -2520,7 +2508,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg } else { ggml_hexagon_dispatch_op>(sess, node, flags); } - prev_quant_op = node; break; case GGML_OP_MUL_MAT_ID: if (ggml_is_quantized(node->src[0]->type)) { @@ -2528,7 +2515,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg } else { ggml_hexagon_dispatch_op>(sess, node, flags); } - prev_quant_op = node; break; case GGML_OP_MUL: case GGML_OP_ADD: @@ -2670,7 +2656,7 @@ static std::vector ggml_hexagon_graph_optimize_reorder(const std::vectorcontext = new ggml_hexagon_registry(reg); HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req), @@ -3180,6 +3170,11 @@ ggml_backend_reg_t ggml_backend_hexagon_reg(void) { static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { + auto nErr = htpdrv_init(); + if (nErr != AEE_SUCCESS) { + return NULL; + } + ggml_hexagon_init(®); } diff --git a/ggml/src/ggml-hexagon/htp-drv.cpp b/ggml/src/ggml-hexagon/htp-drv.cpp new file mode 100644 index 00000000000..2530bb06d6c --- /dev/null +++ b/ggml/src/ggml-hexagon/htp-drv.cpp @@ -0,0 +1,418 @@ +// sample drv interface + +#pragma clang diagnostic ignored "-Wgnu-anonymous-struct" +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wsign-compare" + +#include +#include +#include +#include +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +#endif +#include "ggml-impl.h" +#include "htp-drv.h" +#include "libdl.h" + +#include + +// +// Driver API types +// + +typedef void * (*rpcmem_alloc_pfn_t)(int heapid, uint32_t flags, int size); +typedef void * (*rpcmem_alloc2_pfn_t)(int heapid, uint32_t flags, size_t size); +typedef void (*rpcmem_free_pfn_t)(void * po); +typedef int (*rpcmem_to_fd_pfn_t)(void * po); + +typedef AEEResult (*dspqueue_create_pfn_t)(int domain, + uint32_t flags, + uint32_t req_queue_size, + uint32_t resp_queue_size, + dspqueue_callback_t packet_callback, + dspqueue_callback_t error_callback, + void * callback_context, + dspqueue_t * queue); +typedef AEEResult (*dspqueue_close_pfn_t)(dspqueue_t queue); +typedef AEEResult (*dspqueue_export_pfn_t)(dspqueue_t queue, uint64_t *queue_id); +typedef AEEResult (*dspqueue_write_pfn_t)(dspqueue_t queue, uint32_t flags, + uint32_t num_buffers, + struct dspqueue_buffer *buffers, + uint32_t message_length, + const uint8_t *message, + uint32_t timeout_us); +typedef AEEResult (*dspqueue_read_pfn_t)(dspqueue_t queue, uint32_t *flags, + uint32_t max_buffers, uint32_t *num_buffers, + struct dspqueue_buffer *buffers, + uint32_t max_message_length, + uint32_t *message_length, uint8_t *message, + uint32_t timeout_us); + +typedef int (*fastrpc_mmap_pfn_t)(int domain, int fd, void *addr, int offset, size_t length, enum fastrpc_map_flags flags); +typedef int (*fastrpc_munmap_pfn_t)(int domain, int fd, void *addr, size_t length); + +typedef int (*remote_handle64_open_pfn_t)(const char* name, remote_handle64 *ph); +typedef int (*remote_handle64_invoke_pfn_t)(remote_handle64 h, uint32_t dwScalars, remote_arg *pra); +typedef int (*remote_handle64_close_pfn_t)(remote_handle h); +typedef int (*remote_handle_control_pfn_t)(uint32_t req, void* data, uint32_t datalen); +typedef int (*remote_handle64_control_pfn_t)(remote_handle64 h, uint32_t req, void* data, uint32_t datalen); +typedef int (*remote_session_control_pfn_t)(uint32_t req, void *data, uint32_t datalen); + +// +// Driver API pfns +// + +rpcmem_alloc_pfn_t rpcmem_alloc_pfn = nullptr; +rpcmem_alloc2_pfn_t rpcmem_alloc2_pfn = nullptr; +rpcmem_free_pfn_t rpcmem_free_pfn = nullptr; +rpcmem_to_fd_pfn_t rpcmem_to_fd_pfn = nullptr; + +fastrpc_mmap_pfn_t fastrpc_mmap_pfn = nullptr; +fastrpc_munmap_pfn_t fastrpc_munmap_pfn = nullptr; + +dspqueue_create_pfn_t dspqueue_create_pfn = nullptr; +dspqueue_close_pfn_t dspqueue_close_pfn = nullptr; +dspqueue_export_pfn_t dspqueue_export_pfn = nullptr; +dspqueue_write_pfn_t dspqueue_write_pfn = nullptr; +dspqueue_read_pfn_t dspqueue_read_pfn = nullptr; + +remote_handle64_open_pfn_t remote_handle64_open_pfn = nullptr; +remote_handle64_invoke_pfn_t remote_handle64_invoke_pfn = nullptr; +remote_handle64_close_pfn_t remote_handle64_close_pfn = nullptr; +remote_handle_control_pfn_t remote_handle_control_pfn = nullptr; +remote_handle64_control_pfn_t remote_handle64_control_pfn = nullptr; +remote_session_control_pfn_t remote_session_control_pfn = nullptr; + +// +// Driver API +// + +void * rpcmem_alloc(int heapid, uint32_t flags, int size) { + return rpcmem_alloc_pfn(heapid, flags, size); +} + +void * rpcmem_alloc2(int heapid, uint32_t flags, size_t size) { + if (rpcmem_alloc2_pfn) { + return rpcmem_alloc2_pfn(heapid, flags, size); + } else { + GGML_LOG_INFO("ggml-hex: rpcmem_alloc2 not found, falling back to rpcmem_alloc\n"); + return rpcmem_alloc_pfn(heapid, flags, size); + } +} + +void rpcmem_free(void * po) { + return rpcmem_free_pfn(po); +} + +int rpcmem_to_fd(void * po) { + return rpcmem_to_fd_pfn(po); +} + +HTPDRV_API int fastrpc_mmap(int domain, int fd, void * addr, int offset, size_t length, enum fastrpc_map_flags flags) { + return fastrpc_mmap_pfn(domain, fd, addr, offset, length, flags); +} + +HTPDRV_API int fastrpc_munmap(int domain, int fd, void * addr, size_t length) { + return fastrpc_munmap_pfn(domain, fd, addr, length); +} + +AEEResult dspqueue_create(int domain, + uint32_t flags, + uint32_t req_queue_size, + uint32_t resp_queue_size, + dspqueue_callback_t packet_callback, + dspqueue_callback_t error_callback, + void * callback_context, + dspqueue_t * queue) { + return dspqueue_create_pfn(domain, flags, req_queue_size, resp_queue_size, packet_callback, error_callback, + callback_context, queue); +} + +AEEResult dspqueue_close(dspqueue_t queue) { + return dspqueue_close_pfn(queue); +} + +AEEResult dspqueue_export(dspqueue_t queue, uint64_t * queue_id) { + return dspqueue_export_pfn(queue, queue_id); +} + +AEEResult dspqueue_write(dspqueue_t queue, + uint32_t flags, + uint32_t num_buffers, + struct dspqueue_buffer * buffers, + uint32_t message_length, + const uint8_t * message, + uint32_t timeout_us) { + return dspqueue_write_pfn(queue, flags, num_buffers, buffers, message_length, message, timeout_us); +} + +AEEResult dspqueue_read(dspqueue_t queue, + uint32_t * flags, + uint32_t max_buffers, + uint32_t * num_buffers, + struct dspqueue_buffer * buffers, + uint32_t max_message_length, + uint32_t * message_length, + uint8_t * message, + uint32_t timeout_us) { + return dspqueue_read_pfn(queue, flags, max_buffers, num_buffers, buffers, max_message_length, message_length, + message, timeout_us); +} + +HTPDRV_API int remote_handle64_open(const char * name, remote_handle64 * ph) { + return remote_handle64_open_pfn(name, ph); +} + +HTPDRV_API int remote_handle64_invoke(remote_handle64 h, uint32_t dwScalars, remote_arg * pra) { + return remote_handle64_invoke_pfn(h, dwScalars, pra); +} + +HTPDRV_API int remote_handle64_close(remote_handle64 h) { + return remote_handle64_close_pfn(h); +} + +HTPDRV_API int remote_handle_control(uint32_t req, void * data, uint32_t datalen) { + return remote_handle_control_pfn(req, data, datalen); +} + +HTPDRV_API int remote_handle64_control(remote_handle64 h, uint32_t req, void * data, uint32_t datalen) { + return remote_handle64_control_pfn(h, req, data, datalen); +} + +HTPDRV_API int remote_session_control(uint32_t req, void * data, uint32_t datalen) { + return remote_session_control_pfn(req, data, datalen); +} + +#ifdef _WIN32 + +static std::string wstr_to_str(std::wstring_view wstr) { + std::string result; + if (wstr.empty()) { + return result; + } + auto bytes_needed = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, + wstr.data(), (int) wstr.size(), + nullptr, 0, nullptr, nullptr); + if (bytes_needed == 0) { + GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError()); + throw std::runtime_error("Invalid wstring input"); + } + + result.resize(bytes_needed, '\0'); + int bytes_written = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, + wstr.data(), (int) wstr.size(), + result.data(), bytes_needed, + nullptr, nullptr); + if (bytes_written == 0) { + GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError()); + throw std::runtime_error("Wstring conversion failed"); + } + return result; +} + +static std::string get_driver_path() { + std::wstring serviceName = L"qcnspmcdm"; + std::string result; + + // Get a handle to the SCM database. + SC_HANDLE schSCManager = OpenSCManagerW(NULL, NULL, STANDARD_RIGHTS_READ); + if (nullptr == schSCManager) { + GGML_LOG_ERROR("ggml-hex: Failed to open SCManager. Error: %lu\n", GetLastError()); + return result; + } + + // Get a handle to the service. + SC_HANDLE schService = OpenServiceW(schSCManager, // SCM database + serviceName.c_str(), // name of service + SERVICE_QUERY_CONFIG); // need query config access + + if (nullptr == schService) { + GGML_LOG_ERROR("ggml-hex: Failed to open qcnspmcdm service. Error: %lu\n", GetLastError()); + CloseServiceHandle(schSCManager); + return result; + } + + // Store the size of buffer used as an output. + DWORD bufferSize; + if (!QueryServiceConfigW(schService, NULL, 0, &bufferSize) && + (GetLastError() != ERROR_INSUFFICIENT_BUFFER)) { + GGML_LOG_ERROR("ggml-hex: Failed to query service config. Error: %lu\n", GetLastError()); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + return result; + } + // Get the configuration of the service. + LPQUERY_SERVICE_CONFIGW serviceConfig = + static_cast(LocalAlloc(LMEM_FIXED, bufferSize)); + if (!QueryServiceConfigW(schService, serviceConfig, bufferSize, &bufferSize)) { + fprintf(stderr, "ggml-hex: Failed to query service config. Error: %lu\n", GetLastError()); + LocalFree(serviceConfig); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + return result; + } + + // Read the driver file path get its parent directory + std::wstring driverPath = std::wstring(serviceConfig->lpBinaryPathName); + driverPath = driverPath.substr(0, driverPath.find_last_of(L"\\")); + + // Clean up resources + LocalFree(serviceConfig); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + + // Driver path would contain invalid path string, like: + // \SystemRoot\System32\DriverStore\FileRepository\qcadsprpc8280.inf_arm64_c2b9460c9a072f37 + // "\SystemRoot" should be replace with a correct one (e.g. C:\Windows) + const std::wstring systemRootPlaceholder = L"\\SystemRoot"; + if (0 != driverPath.compare(0, systemRootPlaceholder.length(), systemRootPlaceholder)) { + GGML_LOG_ERROR("ggml-hex: String pattern not found in driver path.\n"); + return result; + } + + // Replace \SystemRoot with an absolute path from system ENV windir + const std::wstring systemRootEnv = L"windir"; + + // Query the number of wide charactors this variable requires + DWORD numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), NULL, 0); + if (numWords == 0) { + GGML_LOG_ERROR("ggml-hex: Failed get systemRoot environment variable\n"); + return result; + } + + // Query the actual system root name from environment variable + std::vector systemRoot(numWords + 1); + numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), systemRoot.data(), numWords + 1); + if (numWords == 0) { + GGML_LOG_ERROR("ggml-hex: Failed to read windir environment variable\n"); + return result; + } + driverPath.replace(0, systemRootPlaceholder.length(), std::wstring(systemRoot.data())); + + return wstr_to_str(driverPath); +} + +#endif + +using dl_handle_ptr = std::unique_ptr; + +int htpdrv_init() { + static dl_handle_ptr lib_cdsp_rpc_handle = nullptr; + static bool initialized = false; +#ifdef _WIN32 + std::string drv_path = get_driver_path() + "\\" + "libcdsprpc.dll"; +#else + std::string drv_path = "libcdsprpc.so"; +#endif + if (initialized) { + GGML_LOG_INFO("ggml-hex: Driver already loaded\n"); + return AEE_SUCCESS; + } + GGML_LOG_INFO("ggml-hex: Loading driver %s\n", drv_path.c_str()); + + fs::path path{ drv_path.c_str() }; + dl_handle_ptr handle { dl_load_library(path) }; + if (!handle) { + GGML_LOG_ERROR("ggml-hex: failed to load %s: %s\n", path.u8string().c_str(), dl_error()); + return AEE_EUNABLETOLOAD; + } + +#define dlsym(drv, type, pfn, symbol, ignore) \ + do { \ + pfn = (type) dl_get_sym(drv, #symbol); \ + if (!ignore && nullptr == pfn) { \ + GGML_LOG_ERROR("ggml-hex: failed to dlsym %s\n", #symbol); \ + return AEE_EUNABLETOLOAD; \ + } \ + } while (0) + + dlsym(handle.get(), rpcmem_alloc_pfn_t, rpcmem_alloc_pfn, rpcmem_alloc, false); + dlsym(handle.get(), rpcmem_alloc2_pfn_t, rpcmem_alloc2_pfn, rpcmem_alloc2, true); + dlsym(handle.get(), rpcmem_free_pfn_t, rpcmem_free_pfn, rpcmem_free, false); + dlsym(handle.get(), rpcmem_to_fd_pfn_t, rpcmem_to_fd_pfn, rpcmem_to_fd, false); + dlsym(handle.get(), fastrpc_mmap_pfn_t, fastrpc_mmap_pfn, fastrpc_mmap, false); + dlsym(handle.get(), fastrpc_munmap_pfn_t, fastrpc_munmap_pfn, fastrpc_munmap, false); + dlsym(handle.get(), dspqueue_create_pfn_t, dspqueue_create_pfn, dspqueue_create, false); + dlsym(handle.get(), dspqueue_close_pfn_t, dspqueue_close_pfn, dspqueue_close, false); + dlsym(handle.get(), dspqueue_export_pfn_t, dspqueue_export_pfn, dspqueue_export, false); + dlsym(handle.get(), dspqueue_write_pfn_t, dspqueue_write_pfn, dspqueue_write, false); + dlsym(handle.get(), dspqueue_read_pfn_t, dspqueue_read_pfn, dspqueue_read, false); + dlsym(handle.get(), remote_handle64_open_pfn_t, remote_handle64_open_pfn, remote_handle64_open, false); + dlsym(handle.get(), remote_handle64_invoke_pfn_t, remote_handle64_invoke_pfn, remote_handle64_invoke, false); + dlsym(handle.get(), remote_handle_control_pfn_t, remote_handle_control_pfn, remote_handle_control, false); + dlsym(handle.get(), remote_handle64_control_pfn_t, remote_handle64_control_pfn, remote_handle64_control, false); + dlsym(handle.get(), remote_session_control_pfn_t, remote_session_control_pfn, remote_session_control, false); + dlsym(handle.get(), remote_handle64_close_pfn_t, remote_handle64_close_pfn, remote_handle64_close, false); + + lib_cdsp_rpc_handle = std::move(handle); + initialized = true; + + return AEE_SUCCESS; +} + +domain * get_domain(int domain_id) { + int i = 0; + int size = sizeof(supported_domains) / sizeof(domain); + + for (i = 0; i < size; i++) { + if (supported_domains[i].id == domain_id) { + return &supported_domains[i]; + } + } + + return NULL; +} + +int get_hex_arch_ver(int domain, int * arch) { + if (!remote_handle_control_pfn) { + GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n"); + return AEE_EUNSUPPORTEDAPI; + } + + struct remote_dsp_capability arch_ver; + arch_ver.domain = (uint32_t) domain; + arch_ver.attribute_ID = ARCH_VER; + arch_ver.capability = (uint32_t) 0; + + int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver)); + if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) { + GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n"); + return AEE_EUNSUPPORTEDAPI; + } + + if (err != AEE_SUCCESS) { + GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err); + return err; + } + + switch (arch_ver.capability & 0xff) { + case 0x68: + *arch = 68; + return 0; + case 0x69: + *arch = 69; + return 0; + case 0x73: + *arch = 73; + return 0; + case 0x75: + *arch = 75; + return 0; + case 0x79: + *arch = 79; + return 0; + case 0x81: + *arch = 81; + return 0; + } + return -1; +} diff --git a/ggml/src/ggml-hexagon/htp-drv.h b/ggml/src/ggml-hexagon/htp-drv.h new file mode 100644 index 00000000000..6eba7ba17d8 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp-drv.h @@ -0,0 +1,121 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef _WIN32 +# pragma clang diagnostic ignored "-Wignored-attributes" +#endif + +#include +#include +#include +#include + +#if defined(_WIN32) && !defined(__MINGW32__) +# ifdef GGML_BACKEND_BUILD +# define HTPDRV_API __declspec(dllexport) extern +# else +# define HTPDRV_API __declspec(dllimport) extern +# endif +#else +# define HTPDRV_API __attribute__ ((visibility ("default"))) extern +#endif + +/* Offset to differentiate HLOS and Hexagon error codes. + Stores the value of AEE_EOFFSET for Hexagon. */ +#ifndef DSP_OFFSET +# define DSP_OFFSET 0x80000400 +#endif + +/* Errno for connection reset by peer. */ +#ifndef ECONNRESET +# ifdef __hexagon__ +# define ECONNRESET 104 +# endif +#endif + +/* Abstraction of different OS specific sleep APIs. + SLEEP accepts input in seconds. */ +#ifndef SLEEP +# ifdef __hexagon__ +# define SLEEP(x) \ + { /* Do nothing for simulator. */ \ + } +# else +# ifdef _WIN32 +# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */ +# else +# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */ +# endif +# endif +#endif + +/* Include windows specific header files. */ +#ifdef _WIN32 +# include +# include +# define _CRT_SECURE_NO_WARNINGS 1 +# define _WINSOCK_DEPRECATED_NO_WARNINGS 1 +#endif + +/* Includes and defines for all HLOS except windows */ +#if !defined(__hexagon__) && !defined(_WIN32) +# include "unistd.h" + +# include +#endif + +/* Includes and defines for Hexagon and all HLOS except Windows. */ +#if !defined(_WIN32) +/* Weak reference to remote symbol for compilation. */ +# pragma weak remote_session_control +# pragma weak remote_handle_control +# pragma weak remote_handle64_control +# pragma weak fastrpc_mmap +# pragma weak fastrpc_munmap +# pragma weak rpcmem_alloc2 +#endif + +#if !defined(_WIN32) +# pragma weak remote_system_request +#endif + +#ifdef _WIN32 +# define DSPQUEUE_TIMEOUT DSPQUEUE_TIMEOUT_NONE +#else +# define DSPQUEUE_TIMEOUT 1000000 +#endif + +/** + * htpdrv_init API: driver interface entry point + * + * @return Return AEE error codes as defined in Hexagon SDK. + */ +HTPDRV_API int htpdrv_init(void); + +/** + * get_domain API: get domain struct from domain value. + * + * @param[in] domain value of a domain + * @return Returns domain struct of the domain if it is supported or else + * returns NULL. + * + */ +HTPDRV_API domain * get_domain(int domain_id); + +/** + * get_hex_arch_ver API: query the Hexagon processor architecture version information + * + * @param[in] domain_id value of a domain + * @param[out] Arch version (73, 75, ...) + * @return 0 if query is successful. + * non-zero if error, return value points to the error. + * + */ +HTPDRV_API int get_hex_arch_ver(int domain, int * arch); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-hexagon/htp-utils.c b/ggml/src/ggml-hexagon/htp-utils.c deleted file mode 100644 index 3f335bf71c0..00000000000 --- a/ggml/src/ggml-hexagon/htp-utils.c +++ /dev/null @@ -1,454 +0,0 @@ - -#pragma clang diagnostic ignored "-Wgnu-anonymous-struct" -#pragma clang diagnostic ignored "-Wmissing-prototypes" -#pragma clang diagnostic ignored "-Wsign-compare" - -#define GGML_COMMON_IMPL_C -#include "ggml-backend-impl.h" -#include "ggml-common.h" -#include "ggml-hexagon.h" -#include "ggml-impl.h" - -#include "htp-utils.h" - -#include -#include -#include -#include -#include -#include -#include - -domain * get_domain(int domain_id) { - int i = 0; - int size = sizeof(supported_domains) / sizeof(domain); - - for (i = 0; i < size; i++) { - if (supported_domains[i].id == domain_id) { - return &supported_domains[i]; - } - } - - return NULL; -} - -bool is_valid_domain_id(int domain_id, int compute_only) { - int i = 0; - int size = sizeof(supported_domains) / sizeof(domain); - - if (compute_only) { - return is_CDSP(domain_id); - } - - for (i = 0; i < size; i++) { - if (supported_domains[i].id == domain_id) { - return true; - } - } - - return false; -} - -int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info) { - int nErr = AEE_SUCCESS; - int ss_info = 0; - if (domain_type != NULL) { - if (strcmp(domain_type, "LPASS") == 0) { - ss_info = FASTRPC_LPASS; - } else if (strcmp(domain_type, "HPASS") == 0) { - ss_info = FASTRPC_HPASS; - } else { - ss_info = FASTRPC_NSP; - } - } - system_req_payload req = { 0 }; - req.id = FASTRPC_GET_DOMAINS; - req.sys.domains = NULL; - fastrpc_domain * domain = NULL; - if (ss_info != 0) { - req.sys.flags = DOMAINS_LIST_FLAGS_SET_TYPE(req.sys.flags, ss_info); - } else { - req.sys.flags = 0; - } -#ifdef _WIN32 - nErr = AEE_EUNSUPPORTED; - goto bail; -#endif - if (remote_system_request) { - nErr = remote_system_request(&req); - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr); - goto bail; - } - // Allocate memory for domain-info array - req.sys.max_domains = req.sys.num_domains; - if ((req.sys.domains = calloc(req.sys.num_domains, sizeof(fastrpc_domain))) == NULL) { - nErr = AEE_ENOMEMORY; - GGML_LOG_ERROR("Unable to allocate memory for req.sys.domains"); - goto bail; - } - - nErr = remote_system_request(&req); - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr); - goto bail; - } - - for (int i = 0; i < req.sys.num_domains; i++) { - // Verify that only requested type domains were returned - domain = &req.sys.domains[i]; - if (domain->type != ss_info && domain_type != NULL) { - nErr = -1; - GGML_LOG_ERROR("Incorrect data received from remote_system_request.\n"); - goto bail; - } - } - *domains_info = req.sys.domains; - *num_domains = req.sys.num_domains; - } else { - nErr = AEE_EUNSUPPORTED; - goto bail; - } -bail: - if (nErr && !req.sys.domains) { - free(req.sys.domains); - } - return nErr; -} - -int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id) { - int err = 0; - remote_rpc_effective_domain_id_t sess = { 0 }; - - sess.domain_name = domain_name; - sess.domain_name_len = strlen(domain_name); - sess.session_id = session_id; - - err = remote_session_control(FASTRPC_GET_EFFECTIVE_DOMAIN_ID, &sess, sizeof(sess)); - if (err) { - GGML_LOG_ERROR("Error 0x%x: failed to get effective domain id for %s, session id %d\n", err, sess.domain_name, - session_id); - return err; - } - - *effec_domain_id = sess.effective_domain_id; - return err; -} - -int get_dsp_support(int * domain) { - int nErr = AEE_SUCCESS; - *domain = CDSP_DOMAIN_ID; // DSP domain default value is CDSP_DOMAIN_ID - - if (remote_handle_control) { - struct remote_dsp_capability dsp_capability_domain = { CDSP_DOMAIN_ID, DOMAIN_SUPPORT, 0 }; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - goto bail; - } - - if (dsp_capability_domain.capability == 0) { - dsp_capability_domain.domain = ADSP_DOMAIN_ID; // Check for ADSP support. - dsp_capability_domain.attribute_ID = DOMAIN_SUPPORT; - dsp_capability_domain.capability = 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, - sizeof(struct remote_dsp_capability)); - if (dsp_capability_domain.capability) { - *domain = ADSP_DOMAIN_ID; // For targets like Agatti (not having cDSP), domain is ADSP_DOMAIN_ID - } - } - - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("\nget_dsp_support failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return nErr; -} - -int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr) { - int nErr = AEE_SUCCESS; - *capability = 0; - - if (attr == VTCM_PAGE || attr == VTCM_COUNT) { - } else { - nErr = AEE_EBADPARM; - GGML_LOG_ERROR("Unsupported attr. Only VTCM_PAGE and VTCM_COUNT supported\n"); - goto bail; - } - if (remote_handle_control) { - if (domain == ADSP_DOMAIN_ID || domain == CDSP_DOMAIN_ID) { - /* - * Query the DSP for VTCM information - * Since the ADSP does not have a dedicated VTCM, we expect the output to be 0 - */ - struct remote_dsp_capability dsp_capability_vtcm_dsp; - dsp_capability_vtcm_dsp.domain = (uint32_t) domain; - dsp_capability_vtcm_dsp.attribute_ID = attr; - dsp_capability_vtcm_dsp.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_vtcm_dsp, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (nErr == AEE_SUCCESS) { - *capability = dsp_capability_vtcm_dsp.capability; - } else { - GGML_LOG_ERROR("\nget_vtcm_info failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTED; - GGML_LOG_ERROR("Unsupported domain %d\n", domain); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return nErr; -} - -bool is_unsignedpd_supported(int domain_id) { - int nErr = AEE_SUCCESS; - if (remote_handle_control) { - struct remote_dsp_capability dsp_capability_domain = { domain_id, UNSIGNED_PD_SUPPORT, 0 }; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device. Falling back to signed pd.\n"); - return false; - } - if (nErr) { - GGML_LOG_ERROR("\nERROR 0x%x: FastRPC Capability API failed. Falling back to signed pd.", nErr); - return false; - } - if (dsp_capability_domain.capability == 1) { - return true; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device. Falling back to signed pd.\n"); - return false; - } - return false; -} - -bool get_unsignedpd_support(void) { - return is_unsignedpd_supported(CDSP_DOMAIN_ID); -} - -bool is_async_fastrpc_supported(int domain) { - int nErr = AEE_SUCCESS; - if (remote_handle_control) { - if (domain == CDSP_DOMAIN_ID) { - /* - * Query the DSP for ASYNC_FASTRPC_SUPPORT information - * Async fastrpc is supported only on CDSP - */ - struct remote_dsp_capability dsp_capability_async_support; - dsp_capability_async_support.domain = (uint32_t) domain; - dsp_capability_async_support.attribute_ID = ASYNC_FASTRPC_SUPPORT; - dsp_capability_async_support.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_async_support, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (dsp_capability_async_support.capability == 1) { - return true; - } - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("\nis_async_fastrpc_supported failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTED; - GGML_LOG_ERROR("Async fastrpc is not supported on domain %d\n", domain); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return false; -} - -bool is_status_notification_supported(int domain) { - int nErr = AEE_SUCCESS; - - if (remote_handle_control) { - /* - * Query the DSP for STATUS_NOTIFICATION_SUPPORT information - * DSP User PD status notification Support - */ - struct remote_dsp_capability dsp_capability_status_notification_support; - dsp_capability_status_notification_support.domain = (uint32_t) domain; - dsp_capability_status_notification_support.attribute_ID = STATUS_NOTIFICATION_SUPPORT; - dsp_capability_status_notification_support.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_status_notification_support, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (dsp_capability_status_notification_support.capability == 1) { - return true; - } - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("\nis_status_notification_supported failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return false; -} - -int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr) { - int nErr = AEE_SUCCESS; - *capability = 0; - - if (attr != HMX_SUPPORT_SPATIAL && attr != HMX_SUPPORT_DEPTH) { - nErr = AEE_EBADPARM; - GGML_LOG_ERROR("Unsupported attr. Only HMX_SUPPORT_SPATIAL and HMX_SUPPORT_DEPTH supported\n"); - goto bail; - } - if (remote_handle_control) { - if (domain == CDSP_DOMAIN_ID) { - /* - * Query the DSP for HMX SUPPORT information - * HMX is supported on CDSP only - */ - struct remote_dsp_capability dsp_capability_hmx_dsp; - dsp_capability_hmx_dsp.domain = (uint32_t) domain; - dsp_capability_hmx_dsp.attribute_ID = attr; - dsp_capability_hmx_dsp.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hmx_dsp, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (nErr == AEE_SUCCESS) { - *capability = dsp_capability_hmx_dsp.capability; - } else { - GGML_LOG_ERROR("\nget_hmx_support_info failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTED; - GGML_LOG_ERROR("HMX support is not there for domain %d\n", domain); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return nErr; -} - -int get_hex_arch_ver(int domain, int * arch) { - if (!remote_handle_control) { - GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n"); - return AEE_EUNSUPPORTEDAPI; - } - - struct remote_dsp_capability arch_ver; - arch_ver.domain = (uint32_t) domain; - arch_ver.attribute_ID = ARCH_VER; - arch_ver.capability = (uint32_t) 0; - - int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver)); - if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) { - GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n"); - return AEE_EUNSUPPORTEDAPI; - } - - if (err != AEE_SUCCESS) { - GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err); - return err; - } - - switch (arch_ver.capability & 0xff) { - case 0x68: - *arch = 68; - return 0; - case 0x69: - *arch = 69; - return 0; - case 0x73: - *arch = 73; - return 0; - case 0x75: - *arch = 75; - return 0; - case 0x79: - *arch = 79; - return 0; - case 0x81: - *arch = 81; - return 0; - } - return -1; -} - -int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr) { - int nErr = AEE_SUCCESS; - *capability = 0; - - if (remote_handle_control) { - if (domain == CDSP_DOMAIN_ID) { - /* - * Query the DSP for HVX SUPPORT information - * HVX is supported on CDSP only - */ - struct remote_dsp_capability dsp_capability_hvx_dsp; - dsp_capability_hvx_dsp.domain = (uint32_t) domain; - dsp_capability_hvx_dsp.attribute_ID = attr; - dsp_capability_hvx_dsp.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hvx_dsp, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (nErr == AEE_SUCCESS) { - *capability = dsp_capability_hvx_dsp.capability; - } else { - GGML_LOG_ERROR("\nget_hvx_support_info failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTED; - GGML_LOG_ERROR("HVX support is not available on domain %d\n", domain); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return nErr; -} diff --git a/ggml/src/ggml-hexagon/htp-utils.h b/ggml/src/ggml-hexagon/htp-utils.h deleted file mode 100644 index 7bbae3a0b73..00000000000 --- a/ggml/src/ggml-hexagon/htp-utils.h +++ /dev/null @@ -1,221 +0,0 @@ -#ifndef HTP_UTILS_H -#define HTP_UTILS_H - -#ifdef __cplusplus -extern "C" { -#endif - -#include -#include -#include -#include -#include - -/* Offset to differentiate HLOS and Hexagon error codes. - Stores the value of AEE_EOFFSET for Hexagon. */ -#ifndef DSP_OFFSET -# define DSP_OFFSET 0x80000400 -#endif - -/* Errno for connection reset by peer. */ -#ifndef ECONNRESET -# ifdef __hexagon__ -# define ECONNRESET 104 -# endif -#endif - -/* Abstraction of different OS specific sleep APIs. - SLEEP accepts input in seconds. */ -#ifndef SLEEP -# ifdef __hexagon__ -# define SLEEP(x) \ - { /* Do nothing for simulator. */ \ - } -# else -# ifdef _WINDOWS -# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */ -# else -# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */ -# endif -# endif -#endif - -/* Include windows specific header files. */ -#ifdef _WINDOWS -# include -# include -# define _CRT_SECURE_NO_WARNINGS 1 -# define _WINSOCK_DEPRECATED_NO_WARNINGS 1 -/* Including this file for custom implementation of getopt function. */ -# include "getopt_custom.h" -#endif - -/* Includes and defines for all HLOS except windows */ -#if !defined(__hexagon__) && !defined(_WINDOWS) -# include "unistd.h" - -# include -#endif - -/* Includes and defines for Hexagon and all HLOS except Windows. */ -#if !defined(_WINDOWS) -/* Weak reference to remote symbol for compilation. */ -# pragma weak remote_session_control -# pragma weak remote_handle_control -# pragma weak remote_handle64_control -# pragma weak fastrpc_mmap -# pragma weak fastrpc_munmap -# pragma weak rpcmem_alloc2 -#endif - -#if !defined(_WINDOWS) -# pragma weak remote_system_request -#endif -/** - * Wrapper for FastRPC Capability API: query DSP support. - * - * @param[out] domain pointer to supported domain. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - */ -int get_dsp_support(int * domain); - -/** - * Wrapper for FastRPC Capability API: query VTCM information. - * - * @param[in] domain value of domain in the queried. - * @param[out] capability capability value of the attribute queried. - * @param[in] attr value of the attribute to the queried. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - */ -int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr); - -/** - * Wrapper for FastRPC Capability API: query unsigned pd support on CDSP domain. - * - * @return true if unsigned pd is supported. - * false if unsigned pd is not supported, capability query failed. - */ - -bool get_unsignedpd_support(void); - -/** - * Wrapper for FastRPC Capability API: query unsigned pd support. - * - * @param[in] domain value of domain in the queried. - * @return true if unsigned pd is supported. - * false if unsigned pd is not supported, capability query failed. - */ - -bool is_unsignedpd_supported(int domain_id); - -/** - * is_valid_domain_id API: query a domain id is valid. - * - * @param[in] domain value of domain in the queried. - * @param[in] compute_only value of domain is only compared with CDSP domains supported by the target when enabled. - * @return true if value of domain is valid. - * false if value of domain is not valid. - */ - -bool is_valid_domain_id(int domain_id, int compute_only); - -/** - * get_domain API: get domain struct from domain value. - * - * @param[in] domain value of a domain - * @return Returns domain struct of the domain if it is supported or else - * returns NULL. - * - */ - -domain * get_domain(int domain_id); - -/** - * get_domains_info API: get information for all the domains available on the device - * - * @param[in] domain_type pointer to domain type - * @param[in] num_domains pointer to number of domains - * @param[in] domains_info pointer to save discovered domains information. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - * It is user's responsibility to free the memory used to store the domains info whose address is present in domains_info before closing the application. - * - */ - -int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info); - -/** - * get_effective_domain_id API: get effective domain id for given session id - * - * @param[in] domain_name pointer to domain name - * @param[in] session_id - * @param[in] effec_domain_id pointer to save obtained effective domain id. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - */ - -int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id); - -/** - * is_async_fastrpc_supported API: query a domain id has async fastrpc supported or not - * - * @param[in] domain_id value of a domain - * @return Returns true or false stating support of Async FastRPC - * - */ - -bool is_async_fastrpc_supported(int domain_id); - -/** - * is_status_notification_supported API: query the DSP for STATUS_NOTIFICATION_SUPPORT information - * - * @param[in] domain_id value of a domain - * @return Returns true or false stating status notification support information - * - */ -bool is_status_notification_supported(int domain_id); - -/** - * get_hmx_support_info API: query the DSP for HMX SUPPORT information - * - * @param[in] domain_id value of a domain - * @param[out] capability capability value of the attribute queried. - * @param[in] attr value of the attribute to the queried. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - */ -int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr); - -/** - * get_hex_arch_ver API: query the Hexagon processor architecture version information - * - * @param[in] domain_id value of a domain - * @param[out] Arch version (73, 75, ...) - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - */ -int get_hex_arch_ver(int domain, int * arch); - -/** - * get_hvx_support_info API: query the DSP for HVX SUPPORT information - * - * @param[in] domain_id value of a domain - * @param[out] capability capability value of the attribute queried. - * @param[in] attr value of the attribute to the queried. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - */ -int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr); - -#ifdef __cplusplus -} -#endif - -#endif //DSP_CAPABILITIES_UTILS_H diff --git a/ggml/src/ggml-hexagon/libdl.h b/ggml/src/ggml-hexagon/libdl.h new file mode 100644 index 00000000000..8ca5016f039 --- /dev/null +++ b/ggml/src/ggml-hexagon/libdl.h @@ -0,0 +1,79 @@ +#pragma once + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +#endif +#include + +namespace fs = std::filesystem; + +#ifdef _WIN32 + +using dl_handle = std::remove_pointer_t; + +struct dl_handle_deleter { + void operator()(HMODULE handle) { + FreeLibrary(handle); + } +}; + +static inline dl_handle * dl_load_library(const fs::path & path) { + // suppress error dialogs for missing DLLs + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + HMODULE handle = LoadLibraryW(path.wstring().c_str()); + + SetErrorMode(old_mode); + + return handle; +} + +static inline void * dl_get_sym(dl_handle * handle, const char * name) { + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + void * p = (void *) GetProcAddress(handle, name); + + SetErrorMode(old_mode); + + return p; +} + +static inline const char * dl_error() { + return ""; +} + +#else + +using dl_handle = void; + +struct dl_handle_deleter { + void operator()(void * handle) { + dlclose(handle); + } +}; + +static inline dl_handle * dl_load_library(const fs::path & path) { + dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL); + return handle; +} + +static inline void * dl_get_sym(dl_handle * handle, const char * name) { + return dlsym(handle, name); +} + +static inline const char * dl_error() { + const char *rslt = dlerror(); + return rslt != nullptr ? rslt : ""; +} + +#endif diff --git a/ggml/src/ggml-hexagon/libggml-htp.inf b/ggml/src/ggml-hexagon/libggml-htp.inf new file mode 100644 index 00000000000..656d2d9ab26 --- /dev/null +++ b/ggml/src/ggml-hexagon/libggml-htp.inf @@ -0,0 +1,38 @@ +[Version] +Signature = "$WINDOWS NT$" +Class = ComputeAccelerator +ClassGuid = {F01A9D53-3FF6-48D2-9F97-C8A7004BE10C} +Provider = %GGML% +DriverVer = 01/01/2026,1.0.0.0 +CatalogFile = libggml-htp.cat +PnpLockDown = 1 + +[DestinationDirs] +Drivers_Dir = 6 + +[SourceDisksNames] +1 = %DiskId% + +[SourceDisksFiles] +libggml-htp-v68.so = 1 +libggml-htp-v69.so = 1 +libggml-htp-v73.so = 1 +libggml-htp-v75.so = 1 +libggml-htp-v81.so = 1 + +[ControlFlags] +ExcludeFromSelect = * + +[DefaultInstall.NTarm64] +CopyFiles=Drivers_Dir + +[Drivers_Dir] +libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v81.so,,,0x10 ;COPYFLG_NO_OVERWRITE + +[Strings] +GGML = 'GGML' +DiskId = 'GGML HTP library' From 829e70044b51d48e6a57c043a61b5b086b3acb0f Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Thu, 29 Jan 2026 14:05:30 -0800 Subject: [PATCH 065/831] ggml-webgpu: improve flastAttention performance by software pipelining (llama/19151) * webgpu : pipeline flash_attn Q/K loads in WGSL * ggml-webgpu: unroll Q*K accumlation inner loop * ggml-webgpu: vectorization * ggml-webgpu: unrolling * ggml-webgpu: remove redundant unrolling * ggml-webgpu: restore the config * ggml-webgpu: remove redundant comments * ggml-webgpu: formatting * ggml-webgpu: formatting and remove vectorization * ggml-webgpu: remove unnecessary constants * ggml-webgpu: change QKV buffer to read_write to pass validation * ggml-webgpu: add explanation for the additional bracket around Q K accumulate * Indentation and for -> if for tail * Kick off CI on wgsl only commits --------- Co-authored-by: Reese Levine --- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 165 +++++++++++------- 1 file changed, 105 insertions(+), 60 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index de7c132a624..b6822161464 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -114,7 +114,7 @@ struct Params { #define PARAMS_BINDING 4 #endif -@group(0) @binding(DST_BINDING) var dst: array; +@group(0) @binding(DST_BINDING) var dst: array>; @group(0) @binding(PARAMS_BINDING) var params: Params; // Just a very small float value. @@ -160,14 +160,21 @@ fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 { return v; } +fn load_f32x4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { + return (*buf)[scalar_index >> 2u]; +} + +fn load_kvx4(buf: ptr>, read_write>, scalar_index: u32) -> vec4 { + return (*buf)[scalar_index >> 2u]; +} @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(subgroup_id) subgroup_id: u32, - @builtin(subgroup_size) subgroup_size: u32, - @builtin(num_subgroups) num_subgroups: u32, - @builtin(subgroup_invocation_id) sg_inv_id: u32) { + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { // initialize row max for online softmax for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { @@ -231,9 +238,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { // clear inter_shmem to ensure zero-initialized accumulators - for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { - inter_shmem[elem_idx] = 0.0; - } + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + inter_shmem[elem_idx] = 0.0; + } // load k tile into shared memory #if defined(KV_Q4_0) @@ -309,48 +316,77 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // accumulate q block * k block into registers across the entire KV tile // TODO: this loop seems to be the current largest bottleneck - for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) { - let inter_offset = kv_block * SG_MAT_N; - var acc: subgroup_matrix_result = subgroupMatrixLoad< - subgroup_matrix_result>(&inter_shmem, inter_offset, false, KV_TILE); + // this bracket exists to scope the lifetime of variables, reducing register pressure + { #ifdef KV_DIRECT - let k_block_row = kv_tile + kv_block * SG_MAT_N; - let k_global_offset = k_head_offset + k_block_row * params.stride_k1; + let k_block_row = kv_tile + subgroup_id * SG_MAT_N; + var k_global_offset = k_head_offset + k_block_row * params.stride_k1; #else - let k_block_offset = kv_block * SG_MAT_N * HEAD_DIM_QK; + var k_block_offset = subgroup_id * SG_MAT_N * HEAD_DIM_QK; #endif - for (var head_dim_block = 0u; head_dim_block < HEAD_DIM_QK; head_dim_block += SG_MAT_K) { - // load q submatrix from shared memory - var q_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( - &q_shmem, - head_dim_block, - false, - HEAD_DIM_QK - ); + for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) { + let inter_offset = kv_block * SG_MAT_N; + var acc: subgroup_matrix_result = subgroupMatrixLoad>(&inter_shmem, inter_offset, false, KV_TILE); + + var q_cur = subgroupMatrixLoad>(&q_shmem, 0u, false, HEAD_DIM_QK); - // load k submatrix from device or shared memory #ifdef KV_DIRECT - var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( - &K, - k_global_offset + head_dim_block, - true, - params.stride_k1 - ); + var k_cur = subgroupMatrixLoad>(&K, k_global_offset + 0u, true, params.stride_k1); #else - var k_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( - &kv_shmem, - k_block_offset + head_dim_block, - true, - HEAD_DIM_QK - ); + var k_cur = subgroupMatrixLoad>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK); #endif - acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc); - } - // store acc to shared memory for softmax (S matrix from paper) - subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE); + var t: u32 = 1u; + for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) { + let h0 = t * SG_MAT_K; + var q0 = subgroupMatrixLoad>(&q_shmem, h0, false, HEAD_DIM_QK); +#ifdef KV_DIRECT + var k0 = subgroupMatrixLoad>(&K, k_global_offset + h0, true, params.stride_k1); +#else + var k0 = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK); +#endif + acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); + q_cur = q0; + k_cur = k0; + + let h1 = (t + 1u) * SG_MAT_K; + var q1g = subgroupMatrixLoad>(&q_shmem, h1, false, HEAD_DIM_QK); +#ifdef KV_DIRECT + var k1g = subgroupMatrixLoad>(&K, k_global_offset + h1, true, params.stride_k1); +#else + var k1g = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK); +#endif + acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); + q_cur = q1g; + k_cur = k1g; + } + + // handle odd tail + if (t < HEAD_DIM_QK / SG_MAT_K) { + let h = t * SG_MAT_K; + var qn = subgroupMatrixLoad>(&q_shmem, h, false, HEAD_DIM_QK); +#ifdef KV_DIRECT + var kn = subgroupMatrixLoad>(&K, k_global_offset + h, true, params.stride_k1); +#else + var kn = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK); +#endif + acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); + q_cur = qn; + k_cur = kn; + } + + acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); + +#ifdef KV_DIRECT + k_global_offset += num_subgroups * SG_MAT_N * params.stride_k1; +#else + k_block_offset += num_subgroups * SG_MAT_N * HEAD_DIM_QK; +#endif + subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE); + } } + #ifdef MASK // load mask tile into shared memory for this KV block // TODO: optimize and skip if mask is -INF for the entire tile @@ -495,7 +531,6 @@ fn main(@builtin(workgroup_id) wg_id: vec3, false, HEAD_DIM_V ); - for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) { let p_offset = kv_block * SG_MAT_N; var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( @@ -527,11 +562,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // O += P * V o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat); } - // store O back to shared memory subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V); } - workgroupBarrier(); } @@ -566,26 +599,38 @@ fn main(@builtin(workgroup_id) wg_id: vec3, o_shmem[idx] = f16(val); } } - workgroupBarrier(); #endif - - // write output back to global memory for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { - let exp_sum = exp_sum_shmem[q_tile_row]; - let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0); + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { break; } - for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - let o_val = o_shmem[q_tile_row * HEAD_DIM_V + elem_idx]; - let scaled = f32(o_val) * scale; - dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled; - } + let exp_sum = exp_sum_shmem[q_tile_row]; + let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); + + let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride; + + for (var elem_base = sg_inv_id * 4u; + elem_base < HEAD_DIM_V; + elem_base += subgroup_size * 4u) { + + let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); + let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); + let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); + let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); + + let v = vec4( + f32(o_shmem[i0]) * scale, + f32(o_shmem[i1]) * scale, + f32(o_shmem[i2]) * scale, + f32(o_shmem[i3]) * scale + ); + + let dst_vec_index: u32 = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = v; + } } } From 1b3c27efae3ab3fc4a29c1f78108881aeec3c3a2 Mon Sep 17 00:00:00 2001 From: RachelMantel Date: Fri, 30 Jan 2026 06:00:49 +0200 Subject: [PATCH 066/831] sycl: implement GGML_OP_TRI (llama/19089) * sycl: implement GGML_OP_TRI * docs: update ops.md for SYCL TRI * docs: regenerate ops.md * docs: update SYCL support for GGML_OP_TRI --- ggml/src/ggml-sycl/ggml-sycl.cpp | 69 ++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 3a4c092af5d..d20b7ec57df 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2263,6 +2263,65 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); } +static void tri_f32_sycl( + const float * src, + float * dst, + const int64_t ne0, + const int64_t ne1, + const int64_t ne2, + const int64_t ne3, + const ggml_tri_type ttype, + dpct::queue_ptr main_stream +) { + const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3; + + main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) { + const int64_t idx = (int64_t) tid[0]; + + const int64_t i0 = idx % ne0; + const int64_t t1 = idx / ne0; + const int64_t i1 = t1 % ne1; + + bool keep = false; + switch (ttype) { + case GGML_TRI_TYPE_LOWER: keep = (i0 < i1); break; + case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break; + case GGML_TRI_TYPE_UPPER: keep = (i0 > i1); break; + case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break; + default: keep = false; break; + } + + dst[idx] = keep ? src[idx] : 0.0f; + }); +} + +static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + GGML_ASSERT(src0); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float * src0_dd = static_cast(src0->data); + float * dst_dd = static_cast(dst->data); + + const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0); + + const int64_t ne0 = src0->ne[0]; + const int64_t ne1 = src0->ne[1]; + const int64_t ne2 = src0->ne[2]; + const int64_t ne3 = src0->ne[3]; + + tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream); +} + + inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -3912,6 +3971,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_TRANSPOSE: GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__); break; + case GGML_OP_TRI: + ggml_sycl_op_tri(ctx, dst); + break; case GGML_OP_DIAG_MASK_INF: ggml_sycl_diag_mask_inf(ctx, dst); break; @@ -4616,6 +4678,13 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return true; case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; + case GGML_OP_TRI: + { + const ggml_tensor * src0 = op->src[0]; + return src0 && + op->type == GGML_TYPE_F32 && + ggml_is_contiguous(src0); + } case GGML_OP_DIAG_MASK_INF: return true; case GGML_OP_SOFT_MAX: From 2a16e7a67f3c1bdaa3a77b3a9646eedde41dbb83 Mon Sep 17 00:00:00 2001 From: s8322 Date: Fri, 30 Jan 2026 06:01:38 +0200 Subject: [PATCH 067/831] sycl: implement GGML_UNARY_OP_SOFTPLUS (llama/19114) * sycl: add softplus unary op implementation * sycl: add softplus unary op implementation * docs(ops): mark SYCL SOFTPLUS as supported * docs: update SYCL status for SOFTPLUS --- ggml/src/ggml-sycl/element_wise.cpp | 20 ++++++++++++++++++++ ggml/src/ggml-sycl/element_wise.hpp | 2 ++ ggml/src/ggml-sycl/ggml-sycl.cpp | 4 ++++ 3 files changed, 26 insertions(+) diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 8d83b2446bd..651b875b636 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -123,6 +123,15 @@ static __dpct_inline__ T op_log(T x) { return sycl::log(x); } +template +static __dpct_inline__ T op_softplus(T x) { + const float xf = (float) x; + const float ax = sycl::fabs(xf); + const float m = sycl::fmax(xf, 0.0f); + const float y = m + sycl::log1p(sycl::exp(-ax)); + return (T) y; +} + template static __dpct_inline__ T op_neg(T x) { return -x; @@ -695,6 +704,12 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor }); } +static inline void ggml_sycl_op_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) { + return op_softplus(x); + }); +} + static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) { return op_neg(x); @@ -1101,6 +1116,11 @@ void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_log(ctx, dst); } +void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_softplus(ctx, dst); +} + void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_neg(ctx, dst); diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index 0913a2e529b..7c71974687a 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -61,6 +61,8 @@ void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index d20b7ec57df..74b4ed91cc0 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3845,6 +3845,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_UNARY_OP_EXP: ggml_sycl_exp(ctx, dst); break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_sycl_softplus(ctx, dst); + break; case GGML_UNARY_OP_SGN: ggml_sycl_sgn(ctx, dst); break; @@ -4466,6 +4469,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_ELU: return true; case GGML_UNARY_OP_FLOOR: From 5dca0db99c60f12f19aff65958a9ffa92bf37d4e Mon Sep 17 00:00:00 2001 From: bssrdf Date: Thu, 29 Jan 2026 23:57:52 -0500 Subject: [PATCH 068/831] add tensor type checking as part of cuda graph properties (llama/19186) --- ggml/src/ggml-cuda/common.cuh | 1 + ggml/src/ggml-cuda/ggml-cuda.cu | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 43280644e48..a3256d59dd0 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1124,6 +1124,7 @@ struct ggml_tensor_extra_gpu { struct ggml_cuda_graph_node_properties { void * node_data; ggml_op node_op; + enum ggml_type node_type; int32_t flags; int64_t ne[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS]; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index cfcffde8a21..e9e9592ebad 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2920,6 +2920,7 @@ static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties memset(props, 0, sizeof(ggml_cuda_graph_node_properties)); props->node_data = node->data; props->node_op = node->op; + props->node_type = node->type; props->flags = node->flags; for (int i = 0; i < GGML_MAX_DIMS; i++) { props->ne[i] = node->ne[i]; @@ -2944,6 +2945,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_ return false; } + if (node->type != props->node_type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { if (node->ne[i] != props->ne[i]) { return false; From b529c0610fdd476cac8e378a59facca484a083db Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 30 Jan 2026 13:50:43 +0200 Subject: [PATCH 069/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 44fa890d78b..95f370ece52 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -b6d1f0f247adcfa25c0ca1ffe97e651fe1afd5e2 +f7cb4b731a38e1f6d24e61c966acc35b0cc31263 From 953e503fd970d80c37bb0be28111704d0961e44e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 30 Jan 2026 13:50:58 +0200 Subject: [PATCH 070/831] talk-llama : sync llama.cpp --- examples/talk-llama/CMakeLists.txt | 1 + examples/talk-llama/llama-adapter.cpp | 20 +- examples/talk-llama/llama-adapter.h | 4 +- examples/talk-llama/llama-arch.cpp | 33 ++ examples/talk-llama/llama-arch.h | 1 + examples/talk-llama/llama-chat.cpp | 20 + examples/talk-llama/llama-chat.h | 1 + examples/talk-llama/llama-context.cpp | 482 ++++++++++-------- examples/talk-llama/llama-context.h | 12 +- examples/talk-llama/llama-cparams.h | 2 + examples/talk-llama/llama-graph.cpp | 363 ++++++++++--- examples/talk-llama/llama-graph.h | 88 +++- examples/talk-llama/llama-hparams.cpp | 55 +- examples/talk-llama/llama-hparams.h | 53 +- examples/talk-llama/llama-kv-cache.cpp | 304 ++++++++--- examples/talk-llama/llama-kv-cache.h | 2 - .../talk-llama/llama-memory-hybrid-iswa.cpp | 275 ++++++++++ .../talk-llama/llama-memory-hybrid-iswa.h | 140 +++++ examples/talk-llama/llama-mmap.cpp | 19 +- examples/talk-llama/llama-model-loader.cpp | 26 +- examples/talk-llama/llama-model-saver.cpp | 4 +- examples/talk-llama/llama-model.cpp | 399 ++++++++++----- examples/talk-llama/llama-model.h | 5 +- examples/talk-llama/llama-quant.cpp | 111 ++-- examples/talk-llama/llama-sampling.cpp | 183 ++++++- examples/talk-llama/llama-vocab.cpp | 61 ++- examples/talk-llama/llama-vocab.h | 1 + examples/talk-llama/llama.cpp | 66 ++- examples/talk-llama/llama.h | 38 +- examples/talk-llama/models/deepseek2.cpp | 28 +- examples/talk-llama/models/exaone-moe.cpp | 146 ++++++ examples/talk-llama/models/gemma3n-iswa.cpp | 12 +- examples/talk-llama/models/minicpm3.cpp | 1 + examples/talk-llama/models/models.h | 4 + examples/talk-llama/models/nemotron-h.cpp | 2 +- examples/talk-llama/models/plm.cpp | 1 + examples/talk-llama/models/qwen3vl-moe.cpp | 19 +- examples/talk-llama/models/qwen3vl.cpp | 19 +- 38 files changed, 2285 insertions(+), 716 deletions(-) create mode 100644 examples/talk-llama/llama-memory-hybrid-iswa.cpp create mode 100644 examples/talk-llama/llama-memory-hybrid-iswa.h create mode 100644 examples/talk-llama/models/exaone-moe.cpp diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt index cac46705d6c..20caaa99de5 100644 --- a/examples/talk-llama/CMakeLists.txt +++ b/examples/talk-llama/CMakeLists.txt @@ -22,6 +22,7 @@ if (WHISPER_SDL2) llama-kv-cache-iswa.cpp llama-memory-recurrent.cpp llama-memory-hybrid.cpp + llama-memory-hybrid-iswa.cpp llama-memory.cpp llama-mmap.cpp llama-model-loader.cpp diff --git a/examples/talk-llama/llama-adapter.cpp b/examples/talk-llama/llama-adapter.cpp index bdc24c2d6b1..d6a5800e63a 100644 --- a/examples/talk-llama/llama-adapter.cpp +++ b/examples/talk-llama/llama-adapter.cpp @@ -146,11 +146,9 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) { return nullptr; } -static void llama_adapter_lora_init_impl(const char * path_lora, llama_adapter_lora & adapter) { +static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) { LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); - llama_model & model = adapter.model; - ggml_context * ctx_init; gguf_init_params meta_gguf_params = { /* .no_alloc = */ true, @@ -413,17 +411,17 @@ static void llama_adapter_lora_init_impl(const char * path_lora, llama_adapter_l } } - // update number of nodes used - model.n_lora_nodes += adapter.get_n_nodes(); + // register adapter with model + model.loras.insert(&adapter); LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); } llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) { - llama_adapter_lora * adapter = new llama_adapter_lora(*model); + llama_adapter_lora * adapter = new llama_adapter_lora(); try { - llama_adapter_lora_init_impl(path_lora, *adapter); + llama_adapter_lora_init_impl(*model, path_lora, *adapter); return adapter; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); @@ -473,12 +471,8 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, return snprintf(buf, buf_size, "%s", it->second.c_str()); } -void llama_adapter_lora_free(llama_adapter_lora * adapter) { - // update number of nodes used - GGML_ASSERT(adapter->model.n_lora_nodes >= adapter->get_n_nodes()); - adapter->model.n_lora_nodes -= adapter->get_n_nodes(); - - delete adapter; +void llama_adapter_lora_free(llama_adapter_lora *) { + // deprecated: adapters are freed by llama_model's destructor } uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) { diff --git a/examples/talk-llama/llama-adapter.h b/examples/talk-llama/llama-adapter.h index 42d64a6e0b5..d275d25425e 100644 --- a/examples/talk-llama/llama-adapter.h +++ b/examples/talk-llama/llama-adapter.h @@ -59,8 +59,6 @@ struct llama_adapter_lora_weight { }; struct llama_adapter_lora { - llama_model & model; - // map tensor name to lora_a_b std::unordered_map ab_map; @@ -75,7 +73,7 @@ struct llama_adapter_lora { // activated lora (aLoRA) std::vector alora_invocation_tokens; - llama_adapter_lora(llama_model & model) : model(model) {} + llama_adapter_lora() = default; ~llama_adapter_lora() = default; llama_adapter_lora_weight * get_weight(ggml_tensor * w); diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index f736ee67050..a54bc1956ae 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -81,6 +81,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE4, "exaone4" }, + { LLM_ARCH_EXAONE_MOE, "exaone-moe" }, { LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, { LLM_ARCH_RWKV7, "rwkv7" }, @@ -1728,6 +1729,38 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_UP, LLM_TENSOR_FFN_POST_NORM, }; + case LLM_ARCH_EXAONE_MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + }; case LLM_ARCH_RWKV6: return { LLM_TENSOR_TOKEN_EMBD, diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 68ec6a18b18..270d28b16a4 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -85,6 +85,7 @@ enum llm_arch { LLM_ARCH_NEMOTRON_H_MOE, LLM_ARCH_EXAONE, LLM_ARCH_EXAONE4, + LLM_ARCH_EXAONE_MOE, LLM_ARCH_RWKV6, LLM_ARCH_RWKV6QWEN2, LLM_ARCH_RWKV7, diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index b54ebbd155d..3c7e0afdae8 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -57,6 +57,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 }, + { "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, { "granite", LLM_CHAT_TEMPLATE_GRANITE }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, @@ -137,6 +138,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("[gMASK]")) { return LLM_CHAT_TEMPLATE_CHATGLM_4; } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { + if (tmpl_contains("<|tool_declare|>")) { + return LLM_CHAT_TEMPLATE_EXAONE_MOE; + } return tmpl_contains("") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE; } else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) { return LLM_CHAT_TEMPLATE_GLMEDGE; @@ -576,6 +580,22 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "[|assistant|]"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_MOE) { + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|system|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "user") { + ss << "<|user|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "assistant") { + ss << "<|assistant|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "tool") { + ss << "<|tool|>\n" << trim(message->content) << "<|endofturn|>\n"; + } + } + if (add_ass) { + ss << "<|assistant|>\n"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { // this template requires the model to have "\n\n" as EOT token for (size_t i = 0; i < chat.size(); i++) { diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index e1f795249c8..9ed1db128ec 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -36,6 +36,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_MINICPM, LLM_CHAT_TEMPLATE_EXAONE_3, LLM_CHAT_TEMPLATE_EXAONE_4, + LLM_CHAT_TEMPLATE_EXAONE_MOE, LLM_CHAT_TEMPLATE_RWKV_WORLD, LLM_CHAT_TEMPLATE_GRANITE, LLM_CHAT_TEMPLATE_GIGACHAT, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index f220010a1b4..10b306a8537 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -146,6 +146,7 @@ llama_context::llama_context( } cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; + cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -155,6 +156,9 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; + // intialized later + cparams.pipeline_parallel = false; + { const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; @@ -249,11 +253,7 @@ llama_context::llama_context( // graph outputs buffer { - // resized during inference when a batch uses more outputs - // Create a dummy batch for initialization. - llama_batch dummy_batch = {}; - dummy_batch.n_tokens = 0; - if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) { + if (output_reserve(params.n_seq_max) < params.n_seq_max) { throw std::runtime_error("failed to reserve initial output buffer"); } @@ -302,16 +302,6 @@ llama_context::llama_context( LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); - const uint32_t n_seqs = cparams.n_seq_max; - const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - - const size_t max_nodes = this->graph_max_nodes(n_tokens); - - LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); - - gf_res_prev.reset(new llm_graph_result(max_nodes)); - gf_res_reserve.reset(new llm_graph_result(max_nodes)); - // TODO: move these checks to ggml_backend_sched // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary bool pipeline_parallel = @@ -340,177 +330,218 @@ llama_context::llama_context( } } - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); + cparams.pipeline_parallel = pipeline_parallel; - if (pipeline_parallel) { - LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); + if (cparams.pipeline_parallel) { + LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__); } - llama_memory_context_ptr mctx; - if (memory) { - LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); - mctx = memory->init_full(); - if (!mctx) { - throw std::runtime_error("failed to initialize memory module"); + sched_reserve(); + + if (!cparams.flash_attn) { + if (ggml_is_quantized(params.type_v)) { + throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); } } + } - cross.v_embd.clear(); - - // avoid reserving graphs with zero outputs - assume one output per sequence - n_outputs = n_seqs; - - LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); + // Initialize the full vocabulary token ids for backend samplers. + { + const int n_vocab = model.vocab.n_tokens(); - // resolve automatic Flash Attention use - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { - auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); - if (!gf) { - throw std::runtime_error("failed to split graph for Flash Attention check"); - } + sampling.token_ids_full_vocab.resize(n_vocab); + for (int i = 0; i < n_vocab; ++i) { + sampling.token_ids_full_vocab[i] = i; + } + } +} - const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; - bool fa_device_mismatch = false; - for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { - ggml_tensor * n = ggml_graph_node(gf, i); - if (n->op != GGML_OP_FLASH_ATTN_EXT) { - continue; - } - ggml_backend_dev_t device_fa = ggml_backend_get_device( - ggml_backend_sched_get_tensor_backend(sched.get(), n)); +llama_context::~llama_context() { + if (!model.hparams.no_alloc) { + for (size_t i = 0; i < backend_ptrs.size(); ++i) { + ggml_backend_t backend = backend_ptrs[i]; + ggml_backend_buffer_type_t buft = backend_buft[i]; - // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer - GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); - const int il = std::stoi(n->name + prefix_len); - ggml_backend_dev_t device_kv = model.dev_layer(il); - if (device_fa != device_kv) { - LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " - "is assigned to device %s (usually due to missing support)\n", - __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); - // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways - fa_device_mismatch = true; - break; - } - } - if (fa_device_mismatch) { - cparams.flash_attn = false; - LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); - if (ggml_is_quantized(params.type_v)) { - throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); - } + const size_t size_exp = backend_buf_exp_size[i]; + const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); + if (size_exp == size_act) { + LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", + __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); } else { - cparams.flash_attn = true; - LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); + LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n", + __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); } } + } + ggml_opt_free(opt_ctx); +} - // reserve worst-case graph - int n_splits_pp = -1; - int n_nodes_pp = -1; +void llama_context::sched_reserve() { + if (!sched_need_reserve) { + return; + } - int n_splits_tg = -1; - int n_nodes_tg = -1; + sched_need_reserve = false; - // reserve pp (prompt processing) graph first so that buffers are only allocated once - { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), - model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); - if (!gf) { - if (pipeline_parallel) { - LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); - gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); - } - if (!gf) { - throw std::runtime_error("failed to allocate compute pp buffers"); - } - } + LLAMA_LOG_INFO("%s: reserving ...\n", __func__); + + synchronize(); + + const int64_t t_start_us = ggml_time_us(); + + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + const size_t max_nodes = this->graph_max_nodes(n_tokens); + + LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); + + gf_res_prev.reset(new llm_graph_result(max_nodes)); + gf_res_reserve.reset(new llm_graph_result(max_nodes)); - n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); - n_nodes_pp = ggml_graph_n_nodes(gf); + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload)); + + llama_memory_context_ptr mctx; + if (memory) { + LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); + mctx = memory->init_full(); + if (!mctx) { + throw std::runtime_error("failed to initialize memory module"); } + } - // reserve with tg (token generation) graph to get the number of splits and nodes - { - auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); - if (!gf) { - throw std::runtime_error("failed to allocate compute tg buffers"); - } + // avoid reserving graphs with zero outputs - assume one output per sequence + const int n_outputs = n_seqs; + + LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); - n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); - n_nodes_tg = ggml_graph_n_nodes(gf); + // resolve automatic Flash Attention use + if (cparams.auto_fa) { + auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to split graph for Flash Attention check"); } - // reserve again with pp graph to avoid ggml-alloc reallocations during inference - { - // TODO: not sure if the following graph would be worster case for multi-stream KV caches: - // - // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); - // - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); - if (!gf) { - throw std::runtime_error("failed to allocate compute pp buffers"); + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; + bool fa_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_FLASH_ATTN_EXT) { + continue; + } + ggml_backend_dev_t device_fa = ggml_backend_get_device( + ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_fa != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); + // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways + fa_device_mismatch = true; + break; } } + if (fa_device_mismatch) { + cparams.flash_attn = false; + LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); + } else { + cparams.flash_attn = true; + LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); + } - for (size_t i = 0; i < backend_ptrs.size(); ++i) { - ggml_backend_t backend = backend_ptrs[i]; - ggml_backend_buffer_type_t buft = backend_buft[i]; - if (!model.hparams.no_alloc) { - backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); + cparams.auto_fa = false; + } + + // reserve worst-case graph + int n_splits_pp = -1; + int n_nodes_pp = -1; + + int n_splits_tg = -1; + int n_nodes_tg = -1; + + // reserve pp (prompt processing) graph first so that buffers are only allocated once + { + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), + model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); + if (!gf) { + if (cparams.pipeline_parallel) { + LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); + cparams.pipeline_parallel = false; + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); + gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); } - if (backend_buf_exp_size[i] > 1) { - LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, - ggml_backend_buft_name(buft), - backend_buf_exp_size[i] / 1024.0 / 1024.0); + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); } } - if (n_nodes_pp == n_nodes_tg) { - LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); - } else { - LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); - } + n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_pp = ggml_graph_n_nodes(gf); + } - if (n_splits_pp == n_splits_tg) { - LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); - } else { - LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); + // reserve with tg (token generation) graph to get the number of splits and nodes + { + auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); + if (!gf) { + throw std::runtime_error("failed to allocate compute tg buffers"); } + + n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_tg = ggml_graph_n_nodes(gf); } - // Initialize the full vocabulary token ids for backend samplers. + // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - const int n_vocab = model.vocab.n_tokens(); + // TODO: not sure if the following graph would be worster case for multi-stream KV caches: + // + // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); + // + auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } + } - sampling.token_ids_full_vocab.resize(n_vocab); - for (int i = 0; i < n_vocab; ++i) { - sampling.token_ids_full_vocab[i] = i; + for (size_t i = 0; i < backend_ptrs.size(); ++i) { + ggml_backend_t backend = backend_ptrs[i]; + ggml_backend_buffer_type_t buft = backend_buft[i]; + if (!model.hparams.no_alloc) { + backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); + } + if (backend_buf_exp_size[i] > 1) { + LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, + ggml_backend_buft_name(buft), + backend_buf_exp_size[i] / 1024.0 / 1024.0); } } -} -llama_context::~llama_context() { - if (!model.hparams.no_alloc) { - for (size_t i = 0; i < backend_ptrs.size(); ++i) { - ggml_backend_t backend = backend_ptrs[i]; - ggml_backend_buffer_type_t buft = backend_buft[i]; + if (n_nodes_pp == n_nodes_tg) { + LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); + } else { + LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); + } - const size_t size_exp = backend_buf_exp_size[i]; - const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); - if (size_exp == size_act) { - LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", - __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); - } else { - LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n", - __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); - } - } + if (n_splits_pp == n_splits_tg) { + LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); + } else { + LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); } - ggml_opt_free(opt_ctx); + + const int64_t t_end_us = ggml_time_us(); + + LLAMA_LOG_INFO("%s: reserve took %.2f ms, sched copies = %d\n", + __func__, (t_end_us - t_start_us)/1000.0, ggml_backend_sched_get_n_copies(sched.get())); } void llama_context::synchronize() { + if (!sched) { + return; + } + ggml_backend_sched_synchronize(sched.get()); // FIXME: if multiple single tokens are evaluated without a synchronization, @@ -758,7 +789,7 @@ float * llama_context::get_embeddings_ith(int32_t i) { throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); } - const uint32_t n_embd_out = model.hparams.get_n_embd_out(); + const uint32_t n_embd_out = model.hparams.n_embd_out(); return embd + j*n_embd_out; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); @@ -951,21 +982,41 @@ void llama_context::set_embeddings(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); cparams.embeddings = value; + + // TODO: not sure yet if we want to reserve here + //sched_need_reserve = true; } void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + if (cparams.causal_attn == value) { + return; + } + cparams.causal_attn = value; + + sched_need_reserve = true; } void llama_context::set_warmup(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + if (cparams.warmup == value) { + return; + } + cparams.warmup = value; + + // warmups are usually with small batches, so no need to reserve + //sched_need_reserve = true; } bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { + if (!sampler && sampling.samplers.count(seq_id) == 0) { + return true; + } + LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); const bool can_offload = @@ -985,12 +1036,18 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { sampling.samplers[seq_id] = sampler; + sched_need_reserve = true; + return true; } if (sampler && !can_offload) { LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id); + if (sampling.samplers.count(seq_id) > 0) { + sched_need_reserve = true; + } + sampling.samplers.erase(seq_id); return false; @@ -998,6 +1055,8 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { sampling.samplers.erase(seq_id); + sched_need_reserve = true; + return true; } @@ -1006,16 +1065,27 @@ void llama_context::set_adapter_lora( float scale) { LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); + if (auto it = loras.find(adapter); it != loras.end()) { + if (it->second == scale) { + return; + } + } + loras[adapter] = scale; + + sched_need_reserve = true; } bool llama_context::rm_adapter_lora( llama_adapter_lora * adapter) { LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); - auto pos = loras.find(adapter); - if (pos != loras.end()) { - loras.erase(pos); + auto it = loras.find(adapter); + if (it != loras.end()) { + loras.erase(it); + + sched_need_reserve = true; + return true; } @@ -1025,7 +1095,13 @@ bool llama_context::rm_adapter_lora( void llama_context::clear_adapter_lora() { LLAMA_LOG_DEBUG("%s: call\n", __func__); + if (loras.empty()) { + return; + } + loras.clear(); + + sched_need_reserve = true; } bool llama_context::apply_adapter_cvec( @@ -1036,6 +1112,8 @@ bool llama_context::apply_adapter_cvec( int32_t il_end) { LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); + // TODO: should we reserve? + return cvec.apply(model, data, len, n_embd, il_start, il_end); } @@ -1138,10 +1216,12 @@ int llama_context::encode(const llama_batch & batch_inp) { // TODO: this clear of the buffer can easily be forgotten - need something better embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_tokens; // reserve output buffer - if (output_reserve(n_tokens, batch_inp) < n_tokens) { + if (output_reserve(n_tokens) < n_tokens) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); return -2; }; @@ -1177,7 +1257,7 @@ int llama_context::encode(const llama_batch & batch_inp) { auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); // extract logits - if (logits && t_logits) { + if (logits && t_logits) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits != nullptr); @@ -1195,7 +1275,7 @@ int llama_context::encode(const llama_batch & batch_inp) { { // extract token embeddings GGML_ASSERT(embd != nullptr); - const uint32_t n_embd_out = hparams.get_n_embd_out(); + const uint32_t n_embd_out = hparams.n_embd_out(); GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size); ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float)); @@ -1372,6 +1452,23 @@ static void copy_tensor_async_candidates( } } +static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map & samplers) { + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + if (!ubatch.output[i]) { + continue; + } + + // Check if the output token has at least one sequence without a backend sampler. + for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) { + llama_seq_id seq_id = ubatch.seq_id[i][j]; + if (samplers.find(seq_id) == samplers.end()) { + return true; + } + } + } + return false; // all sequences use backend sampling +} + int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT @@ -1451,6 +1548,8 @@ int llama_context::decode(const llama_batch & batch_inp) { embd_seq.clear(); output_swaps.clear(); + sched_reserve(); + bool did_optimize = false; // handle any pending shifts/copies @@ -1502,7 +1601,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } // reserve output buffer - if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { + if (output_reserve(n_outputs_all) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); return -2; }; @@ -1575,10 +1674,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } // extract logits - // For multi-sequence batches that mix backend samplers and CPU sampler - // this is currently inefficient as we copy all logits even for the - // backend sampled tokens. - if (logits && t_logits && n_outputs > 0) { + if (logits && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits != nullptr); @@ -1602,7 +1698,7 @@ int llama_context::decode(const llama_batch & batch_inp) { { // extract token embeddings GGML_ASSERT(embd != nullptr); - const uint32_t n_embd_out = hparams.get_n_embd_out(); + const uint32_t n_embd_out = hparams.n_embd_out(); float * embd_out = embd + n_outputs_prev*n_embd_out; if (n_outputs) { @@ -1648,11 +1744,8 @@ int llama_context::decode(const llama_batch & batch_inp) { } } - // This flag indicates whether a backend sampler has actually sampled a specific - // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings. - const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty(); - - if (has_samplers && has_sampled) { + // Copy backend sampling output if this ubatch produced any sampling tensors. + if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); const auto stride = n_vocab; @@ -1727,7 +1820,8 @@ int llama_context::decode(const llama_batch & batch_inp) { // output // -uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) { +uint32_t llama_context::output_reserve(int32_t n_outputs) { + const auto & hparams = model.hparams; const auto & vocab = model.vocab; @@ -1735,7 +1829,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); - const auto n_embd_out = hparams.get_n_embd_out(); + const auto n_embd_out = hparams.n_embd_out(); bool has_logits = true; bool has_embd = cparams.embeddings; @@ -1746,45 +1840,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba has_embd = true; } - // Check which sampling modes are needed for the current batch. - // TODO: avoid this branching by working with the worst-case - bool has_sampling = false; - bool cpu_logits = false; - - if (batch.logits) { - for (int32_t i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { - continue; - } - for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { - llama_seq_id seq_id = batch.seq_id[i][j]; - if (sampling.samplers.find(seq_id) != sampling.samplers.end()) { - has_sampling = true; - } else { - cpu_logits = true; - } - } - } - } else { - // When batch.logits is nullptr (when loading state with a dummy batch), - // allocate CPU logits. - cpu_logits = true; - } size_t backend_float_count = 0; size_t backend_token_count = 0; - // Allocate CPU logits buffer only if needed by sequences in this batch - logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0; + logits_size = has_logits ? n_vocab*n_outputs_max : 0; embd_size = has_embd ? n_embd_out*n_outputs_max : 0; - // TODO: avoid this branching by working with the worst-case - if (!has_sampling) { - sampling.logits_size = 0; - sampling.probs_size = 0; - sampling.sampled_size = 0; - sampling.candidates_size = 0; - } else { + // Allocate backend sampling output buffers if there are backend samplers configured. + const bool has_sampling = !sampling.samplers.empty(); + if (has_sampling) { sampling.logits_size = n_vocab*n_outputs_max; sampling.probs_size = n_vocab*n_outputs_max; sampling.sampled_size = n_outputs_max; @@ -1842,7 +1907,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba size_t offset = 0; uint8_t * base = (uint8_t *) output_base; - logits = (has_logits && cpu_logits) ? output_base : nullptr; + logits = has_logits ? output_base : nullptr; offset += logits_size * sizeof(float); embd = has_embd ? (float *) (base + offset) : nullptr; @@ -1955,7 +2020,9 @@ uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { return std::max(n_tokens * 40, 32u * model.n_tensors()); } uint32_t res = std::max(1024u, 8u*model.n_tensors()); - res += model.n_lora_nodes; + for (const auto & lora : model.loras) { + res += lora->get_n_nodes(); + } return res; } @@ -2085,13 +2152,6 @@ llm_graph_cb llama_context::graph_get_cb() const { ggml_set_name(cur, name); } - if (!cparams.offload_kqv) { - if (strcmp(name, "kqv_merged_cont") == 0) { - // all nodes between the KV store and the attention output are run on the CPU - ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); - } - } - // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends // FIXME: fix in ggml_backend_sched const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; @@ -2471,6 +2531,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } } + // [TAG_CONTEXT_STATE_LOGITS] // write logits { LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); @@ -2532,10 +2593,7 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { auto n_outputs = this->n_outputs; io.read_to(&n_outputs, sizeof(n_outputs)); - // Create a dummy batch for state loading. - llama_batch dummy_batch = {}; - dummy_batch.n_tokens = 0; - if (n_outputs > output_reserve(n_outputs, dummy_batch)) { + if (n_outputs > output_reserve(n_outputs)) { throw std::runtime_error("could not reserve outputs"); } @@ -2780,7 +2838,7 @@ void llama_context::opt_epoch_iter( } // reserve output buffer - if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { + if (output_reserve(n_outputs_all) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); GGML_ABORT("TODO: handle this error"); }; @@ -2815,7 +2873,7 @@ void llama_context::opt_epoch_iter( }; ctx_compute_opt = ggml_init(params); } - ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); + ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits()); ggml_opt_alloc(opt_ctx, train); res->set_inputs(&ubatch); diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index b29edf4db21..8e71cdd1dc5 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -40,6 +40,14 @@ struct llama_context { ~llama_context(); + // reserve a new backend scheduler (if needed) + // for example, when: + // - changing loras + // - changing samplers + // - changing attention type + // - etc. + void sched_reserve(); + void synchronize(); const llama_model & get_model() const; @@ -204,7 +212,7 @@ struct llama_context { // Make sure enough space is available for outputs. // Returns max number of outputs for which space was reserved. - uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch); + uint32_t output_reserve(int32_t n_outputs); void output_reorder(); @@ -314,6 +322,8 @@ struct llama_context { ggml_backend_sched_ptr sched; + bool sched_need_reserve = true; + ggml_backend_t backend_cpu = nullptr; std::vector backends; diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index fcef8fa9760..2da3bbd6f94 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -30,10 +30,12 @@ struct llama_cparams { bool causal_attn; bool offload_kqv; bool flash_attn; + bool auto_fa; bool no_perf; bool warmup; bool op_offload; bool kv_unified; + bool pipeline_parallel; enum llama_pooling_type pooling_type; diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 374ff1ebf3a..16d42c4ae3d 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -7,6 +7,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" +#include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" #include @@ -22,7 +23,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { } if (ubatch->embd) { - const int64_t n_embd = embd->ne[0]; + GGML_ASSERT(n_embd == embd->ne[0]); + const int64_t n_tokens = ubatch->n_tokens; ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd)); @@ -32,8 +34,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) { bool res = true; - res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); - res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); + res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); + res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); return res; } @@ -96,11 +98,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { int32_t * data = (int32_t *) pos_bucket->data; - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_tokens; ++i) { - data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true); - } + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_tokens; ++i) { + data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true); } } } @@ -323,34 +323,32 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_tokens = ubatch->n_tokens; const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) { - for (int h = 0; h < 1; ++h) { - for (int i1 = 0; i1 < n_tokens; ++i1) { - const llama_seq_id s1 = ubatch->seq_id[i1][0]; - const llama_pos p1 = ubatch->pos[i1]; + for (int i1 = 0; i1 < n_tokens; ++i1) { + const llama_seq_id s1 = ubatch->seq_id[i1][0]; + const llama_pos p1 = ubatch->pos[i1]; - const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv; + const uint64_t idst = i1*n_kv; - for (int i0 = 0; i0 < n_tokens; ++i0) { - const llama_seq_id s0 = ubatch->seq_id[i0][0]; - const llama_pos p0 = ubatch->pos[i0]; + for (int i0 = 0; i0 < n_tokens; ++i0) { + const llama_seq_id s0 = ubatch->seq_id[i0][0]; + const llama_pos p0 = ubatch->pos[i0]; - // mask different sequences - if (s0 != s1) { - continue; - } - - // mask future tokens - if (cparams.causal_attn && p0 > p1) { - continue; - } + // mask different sequences + if (s0 != s1) { + continue; + } - // apply SWA if any - if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { - continue; - } + // mask future tokens + if (cparams.causal_attn && p0 > p1) { + continue; + } - data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; + // apply SWA if any + if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { + continue; } + + data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; } } }; @@ -409,6 +407,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) { + mctx->set_input_k_idxs(self_k_idxs, ubatch); + + mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); +} + +bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + + res &= self_kq_mask->ne[0] == mctx->get_n_kv(); + res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + + return res; +} + void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); @@ -454,27 +473,19 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { float * data = (float *) cross_kq_mask->data; - for (int h = 0; h < 1; ++h) { - for (int i = 0; i < n_tokens; ++i) { - for (int j = 0; j < n_enc; ++j) { - float f = -INFINITY; + for (int i = 0; i < n_tokens; ++i) { + for (int j = 0; j < n_enc; ++j) { + float f = -INFINITY; - for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[i][s]; + for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[i][s]; - if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { - f = 0.0f; - } + if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { + f = 0.0f; } - - data[h*(n_enc*n_tokens) + i*n_enc + j] = f; } - } - for (int i = n_tokens; i < n_tokens; ++i) { - for (int j = 0; j < n_enc; ++j) { - data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY; - } + data[i*n_enc + j] = f; } } } @@ -522,6 +533,76 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { + const auto * attn_ctx = mctx->get_attn(); + + // base tensors may not be allocated if there are no non-SWA attention layers + if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { + attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + + attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + } + + // swa tensors may not be allocated if there are no SWA attention layers + if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { + attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch); + attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch); + + attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); + } + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + const auto * attn_ctx = mctx->get_attn(); + + // base tensors may not be allocated if there are no non-SWA attention layers + if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv(); + res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + } + + // swa tensors may not be allocated if there are no SWA attention layers + if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { + res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv(); + res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + } + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; +} + void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) { // set the inputs only for the active samplers in the current ubatch std::unordered_set active_samplers; @@ -575,7 +656,8 @@ int64_t llm_graph_result::get_max_nodes() const { } void llm_graph_result::reset() { - t_tokens = nullptr; + t_inp_tokens = nullptr; + t_inp_embd = nullptr; t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; @@ -1279,17 +1361,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // input embeddings with optional lora ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { - const int64_t n_embd = hparams.n_embd_inp(); + const int64_t n_embd_inp = hparams.n_embd_inp(); + const int64_t n_embd = hparams.n_embd; + + assert(n_embd_inp >= n_embd); + + auto inp = std::make_unique(n_embd_inp); - auto inp = std::make_unique(); + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + res->t_inp_tokens = inp->tokens; - ggml_tensor * cur = nullptr; + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens); + cb(inp->embd, "inp_embd", -1); + ggml_set_input(inp->embd); - if (ubatch.token) { - inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); - //cb(inp->tokens, "inp_tokens", -1); - ggml_set_input(inp->tokens); - res->t_tokens = inp->tokens; + // select one of the 2 inputs, based on the batch contents + // ref: https://github.com/ggml-org/llama.cpp/pull/18550 + std::array inps; + + // token embeddings path (ubatch.token != nullptr) + { + auto & cur = inps[0]; cur = ggml_get_rows(ctx0, tok_embd, inp->tokens); @@ -1310,19 +1404,36 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { cur = ggml_add(ctx0, cur, inpL_delta); } - } else { - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); - ggml_set_input(inp->embd); + + if (n_embd_inp != n_embd) { + cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0); + } + } + + // vector embeddings path (ubatch.embd != nullptr) + { + auto & cur = inps[1]; cur = inp->embd; } + assert(ggml_are_same_shape (inps[0], inps[1])); + assert(ggml_are_same_stride(inps[0], inps[1])); + + ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1); + + if (n_embd_inp != n_embd) { + cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0); + } + + res->t_inp_embd = cur; + // For Granite architecture if (hparams.f_embedding_scale != 0.0f) { cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale); } - cb(cur, "inp_embd", -1); + cb(cur, "embd", -1); res->add_input(std::move(inp)); @@ -1421,7 +1532,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const { //} const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp(); - const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; + const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc); ggml_set_input(cur); @@ -1728,9 +1839,11 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v_cur, ggml_tensor * kq_b, ggml_tensor * sinks, - ggml_tensor * v_mla, + ggml_tensor * v_mla, // TODO: remove float kq_scale, int il) const { + GGML_ASSERT(v_mla == nullptr); + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced // expand k later to enable rope fusion which directly writes into k-v cache @@ -1773,6 +1886,93 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +static std::unique_ptr build_attn_inp_k_impl( + ggml_context * ctx0, + const llama_ubatch & ubatch, + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_context * mctx_cur) { + + auto inp = std::make_unique(hparams, cparams, mctx_cur); + + { + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); + + const auto n_kv = mctx_cur->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); + + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(inp->self_kq_mask); + + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + } + + return inp; +} + +llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur); + + return (llm_graph_input_attn_k *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_k * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * sinks, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + // expand k later to enable rope fusion which directly writes into k-v cache + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, v_cur); + ggml_build_forward_expand(gf, k_cur); + + const auto * mctx_cur = inp->mctx; + + // store to KV cache + { + const auto & k_idxs = inp->get_k_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + } + + const auto & kq_mask = inp->get_kq_mask(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0); + + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { + // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, @@ -2068,6 +2268,47 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); + + // build iswa attention input + const auto * attn_ctx = mctx_cur->get_attn(); + + auto inp_attn = std::make_unique(hparams, cparams, attn_ctx); + + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + { + const auto n_kv = attn_ctx->get_base()->get_n_kv(); + + inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); + inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); + + inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(inp_attn->self_kq_mask); + + inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; + } + + { + const auto n_kv = attn_ctx->get_swa()->get_n_kv(); + + inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); + inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); + + inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(inp_attn->self_kq_mask_swa); + + inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; + } + + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp)); +} + void llm_graph_context::build_dense_out( ggml_tensor * dense_2, ggml_tensor * dense_3) const { diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 503ffd695aa..4090d8116c9 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -24,6 +24,7 @@ class llama_kv_cache_context; class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; +class llama_memory_hybrid_iswa_context; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -105,7 +106,7 @@ using llm_graph_input_ptr = std::unique_ptr; class llm_graph_input_embd : public llm_graph_input_i { public: - llm_graph_input_embd() = default; + llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {} virtual ~llm_graph_input_embd() = default; void set_input(const llama_ubatch * ubatch) override; @@ -114,6 +115,8 @@ class llm_graph_input_embd : public llm_graph_input_i { ggml_tensor * tokens = nullptr; // I32 [n_batch] ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] + + const int64_t n_embd = 0; }; class llm_graph_input_pos : public llm_graph_input_i { @@ -314,6 +317,39 @@ class llm_graph_input_attn_kv : public llm_graph_input_i { const llama_kv_cache_context * mctx; }; +// V-less input for the KV cache +// ref: https://github.com/ggml-org/llama.cpp/pull/19067 +class llm_graph_input_attn_k : public llm_graph_input_i { +public: + llm_graph_input_attn_k( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_context * mctx) : + hparams(hparams), + cparams(cparams), + mctx(mctx) { + } + ~llm_graph_input_attn_k() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * get_k_idxs() const { return self_k_idxs; } + + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + + ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + const llama_hparams hparams; + const llama_cparams cparams; + + const llama_kv_cache_context * mctx; +}; + class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { public: llm_graph_input_attn_kv_iswa( @@ -397,6 +433,34 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i { const llama_memory_hybrid_context * mctx; }; +class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i { +public: + llm_graph_input_mem_hybrid_iswa( + const llama_cparams & cparams, + std::unique_ptr inp_attn, + std::unique_ptr inp_rs, + const llama_memory_hybrid_iswa_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + cparams(cparams), + mctx(mctx) { } + virtual ~llm_graph_input_mem_hybrid_iswa() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; + + llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + + const llama_cparams cparams; + + const llama_memory_hybrid_iswa_context * mctx; +}; + class llm_graph_input_sampling : public llm_graph_input_i { public: llm_graph_input_sampling(std::map samplers) : @@ -537,7 +601,7 @@ class llm_graph_result { virtual ~llm_graph_result() = default; - ggml_tensor * get_tokens() const { return t_tokens; } + ggml_tensor * get_inp_tokens() const { return t_inp_tokens; } ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } @@ -564,7 +628,8 @@ class llm_graph_result { void set_params(const llm_graph_params & params); // important graph nodes - ggml_tensor * t_tokens = nullptr; + ggml_tensor * t_inp_tokens = nullptr; + ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens] ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; @@ -801,6 +866,21 @@ struct llm_graph_context { ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove + float kq_scale, + int il) const; + + llm_graph_input_attn_k * build_attn_inp_k() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_k * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] float kq_scale, int il) const; @@ -881,6 +961,8 @@ struct llm_graph_context { llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; + llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const; + // // pooling // diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index c847ef91b7a..392f9160cef 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -72,8 +72,8 @@ uint32_t llama_hparams::n_embd_inp() const { return n_embd_inp; } -uint32_t llama_hparams::get_n_embd_out() const { - return n_embd_out > 0 ? n_embd_out : n_embd; +uint32_t llama_hparams::n_embd_out() const { + return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd; } uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const { @@ -175,6 +175,21 @@ bool llama_hparams::is_swa(uint32_t il) const { GGML_ABORT("fatal error"); } +bool llama_hparams::is_mla() const { + assert((n_embd_head_k_mla_impl == 0 && n_embd_head_v_mla_impl == 0) || + (n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0)); + + return n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0; +} + +uint32_t llama_hparams::n_embd_head_k_mla() const { + return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k; +} + +uint32_t llama_hparams::n_embd_head_v_mla() const { + return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v; +} + bool llama_hparams::has_kv(uint32_t il) const { if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { @@ -200,42 +215,6 @@ uint32_t llama_hparams::n_layer_kv() const { return res; } -bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) { - assert(p0 >= 0 && p1 >= 0); - - switch (swa_type) { - case LLAMA_SWA_TYPE_NONE: - { - } break; - case LLAMA_SWA_TYPE_STANDARD: - { - if (p1 - p0 >= (int32_t) n_swa) { - return true; - } - } break; - case LLAMA_SWA_TYPE_CHUNKED: - { - const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; - - if (p0 < pos_chunk_start) { - return true; - } - } break; - case LLAMA_SWA_TYPE_SYMMETRIC: - { - const int32_t half_n_swa = (int32_t) n_swa / 2; - const int32_t pos_diff = p1 - p0; - - // Mask if outside the symmetric window - if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { - return true; - } - } break; - } - - return false; -} - bool llama_hparams::use_mrope() const { return rope_sections[0] > 0 && rope_sections[1] > 0; } diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 7ae3ec292ef..caed0ec1b76 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -3,6 +3,7 @@ #include "llama.h" #include +#include // bump if necessary #define LLAMA_MAX_LAYERS 512 @@ -52,8 +53,8 @@ struct llama_hparams { uint32_t n_rel_attn_bkts = 0; // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA - uint32_t n_embd_head_k_mla = 0; - uint32_t n_embd_head_v_mla = 0; + uint32_t n_embd_head_k_mla_impl = 0; + uint32_t n_embd_head_v_mla_impl = 0; // for WavTokenizer struct llama_hparams_posnet posnet; @@ -163,7 +164,7 @@ struct llama_hparams { uint32_t n_cls_out = 1; // output embedding dimension (0 = use n_embd) - uint32_t n_embd_out = 0; + uint32_t n_embd_out_impl = 0; // llama4 smallthinker uint32_t n_moe_layer_step = 0; @@ -238,7 +239,7 @@ struct llama_hparams { uint32_t n_embd_inp() const; // dimension of output embeddings - uint32_t get_n_embd_out() const; + uint32_t n_embd_out() const; // dimension of key embeddings across all k-v heads uint32_t n_embd_k_gqa(uint32_t il = 0) const; @@ -268,15 +269,57 @@ struct llama_hparams { bool is_swa(uint32_t il) const; + // note: currently only support if either all or none of the layers are MLA + bool is_mla() const; + + uint32_t n_embd_head_k_mla() const; + uint32_t n_embd_head_v_mla() const; + bool has_kv(uint32_t il) const; // number of layers for which has_kv() returns true uint32_t n_layer_kv() const; // note that this function uses different SWA parameters from those in the hparams + // note: inlined on purpose for performance reasons // TODO: think of a better place for this function // TODO: pack the SWA params in a struct? - static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1); + static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) { + assert(p0 >= 0 && p1 >= 0); + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; + case LLAMA_SWA_TYPE_SYMMETRIC: + { + const int32_t half_n_swa = (int32_t) n_swa / 2; + const int32_t pos_diff = p1 - p0; + + // Mask if outside the symmetric window + if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { + return true; + } + } break; + } + + return false; + } + bool use_mrope() const; }; diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 3186242d60f..f3c9b49f30a 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -97,6 +97,8 @@ llama_kv_cache::llama_kv_cache( __func__, hparams.n_embd_v_gqa_max()); } + const bool is_mla = hparams.is_mla(); + for (uint32_t il = 0; il < hparams.n_layer; il++) { if (!hparams.has_kv(il)) { LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il); @@ -130,18 +132,21 @@ llama_kv_cache::llama_kv_cache( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); - ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); + const bool has_k = true; + const bool has_v = !is_mla; + + ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr; + ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr; - ggml_format_name(k, "cache_k_l%d", il); - ggml_format_name(v, "cache_v_l%d", il); + has_k && ggml_format_name(k, "cache_k_l%d", il); + has_v && ggml_format_name(v, "cache_v_l%d", il); std::vector k_stream; std::vector v_stream; for (uint32_t s = 0; s < n_stream; ++s) { - k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2])); - v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2])); + k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr); + v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr); } map_layer_ids[il] = layers.size(); @@ -647,7 +652,10 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co const auto & layer = layers[il]; ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]); - ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]); + + if (layer.v_stream[ssrc]) { + ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]); + } } } } @@ -852,7 +860,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, const llama_seq_id seq_id_cell = cells.seq_get(idx); // SWA mask - if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { can_use = true; } } @@ -1237,90 +1245,236 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const { } } -void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { - const uint32_t n_tokens = ubatch->n_tokens; +struct args_set_input_kq_mask { + const llama_hparams & hparams; + const llama_ubatch * ubatch; - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - float * data = (float *) dst->data; + const std::vector & v_cells; + const std::vector & seq_to_stream; - const int64_t n_kv = dst->ne[0]; - const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch + uint32_t n_swa; + llama_swa_type swa_type; - GGML_ASSERT(n_tokens%n_stream == 0); + int64_t n_kv; + int64_t n_stream; + int64_t n_tps; +}; - // n_tps == n_tokens_per_stream - const int64_t n_tps = n_tokens/n_stream; +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + //const auto & hparams = args.hparams; + const auto & ubatch = args.ubatch; - std::fill(data, data + ggml_nelements(dst), -INFINITY); - - // Use only the previous KV cells of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. - // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: - // Causal mask: - // xxx------- - // xxxx------ - // xxxxx----- - // Non-causal mask: - // xxxxx----- - // xxxxx----- - // xxxxx----- - // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 - // TODO: optimize this section - for (uint32_t h = 0; h < 1; ++h) { - for (uint32_t s = 0; s < n_stream; ++s) { - for (uint32_t ii = 0; ii < n_tps; ++ii) { - const uint32_t i = s*n_tps + ii; + const auto & v_cells = args.v_cells; + const auto & seq_to_stream = args.seq_to_stream; + + const uint32_t n_swa = args.n_swa; + const llama_swa_type swa_type = args.swa_type; - const llama_seq_id seq_id = ubatch->seq_id[i][0]; + const int64_t n_kv = args.n_kv; + const int64_t n_stream = args.n_stream; + const int64_t n_tps = args.n_tps; - const auto & cells = v_cells[seq_to_stream[seq_id]]; + // the min position in the batch for each sequence + llama_pos seq_pos_min[LLAMA_MAX_SEQ]; + std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX); - const llama_pos p1 = ubatch->pos[i]; + for (uint32_t i = 0; i < ubatch->n_tokens; ++i) { + const llama_seq_id seq_id = ubatch->seq_id[i][0]; - // for M-RoPE - const bool is_2d = ubatch->is_pos_2d(); - const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0; - const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0; + seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]); + } - const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii); + for (uint32_t s = 0; s < n_stream; ++s) { + // bookeeping of the KQ mask cells that could change for other tokens of the same sequence + std::unordered_map seq_srct; + std::unordered_map> seq_idxs; - for (uint32_t j = 0; j < n_kv; ++j) { - if (cells.is_empty(j)) { - continue; - } + for (uint32_t ii = 0; ii < n_tps; ++ii) { + const uint32_t i = s*n_tps + ii; + + const llama_seq_id seq_id = ubatch->seq_id[i][0]; + + const auto & cells = v_cells.at(seq_to_stream[seq_id]); + + llama_pos p0 = -1; + const llama_pos p1 = ubatch->pos[i]; + + // for M-RoPE + const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0; + const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0; + + const uint64_t idst = n_kv*i; + + // for tokens of the same sequence, the mask is mostly the same, so we can reuse it + // the only cells that could change are the ones that are with similar positions as the + // ones in the batch (i.e. due to causal masking, SWA, etc.) + // keep track of those cells and shortcut the loop to save time + // note: this optimization is not compatible with Alibi position encoding + // ref: https://github.com/ggml-org/llama.cpp/pull/18842 + bool prev = false; - // mask the token if not the same sequence - if (!cells.seq_has(j, seq_id)) { - continue; + auto & idxs = seq_idxs[seq_id]; + + if (!alibi) { + if (seq_srct.find(seq_id) != seq_srct.end()) { + const uint32_t srct = seq_srct[seq_id]; + + const uint64_t idst_prev = n_kv*srct; + + std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst); + + prev = true; + } else { + idxs.clear(); + idxs.reserve(ubatch->n_tokens + n_swa + 32); + + seq_srct[seq_id] = i; + } + } + + for (uint32_t jj = 0; jj < n_kv; ++jj) { + uint32_t j = jj; + + // we have an exiting mask for this sequence -> update just seq_idxs + if (!alibi) { + if (prev) { + if (jj >= idxs.size()) { + break; + } + + j = idxs[jj]; } + } + + if (cells.is_empty(j)) { + goto skip; + } + + // mask the token if not the same sequence + if (!cells.seq_has(j, seq_id)) { + goto skip; + } + + p0 = cells.pos_get(j); - const llama_pos p0 = cells.pos_get(j); + if (!alibi) { + if (!prev) { + // record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32 + if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) { + idxs.push_back(j); + } + } + } + if (causal) { // mask future tokens - if (causal_attn && p0 > p1) { - continue; + if (p0 > p1) { + goto skip; } // M-RoPE causal mask - if (causal_attn && is_2d && p0 == p1) { - const auto & p0_ext = cells.ext_get(j); - if (p0_ext.is_2d_gt(p1_x, p1_y)) { - continue; + if (is_2d) { + if (p0 == p1) { + const auto & p0_ext = cells.ext_get(j); + + if (p0_ext.is_2d_gt(p1_x, p1_y)) { + goto skip; + } } } + } - // apply SWA if any - if (is_masked_swa(p0, p1)) { - continue; + // apply SWA if any + if (swa) { + if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { + goto skip; } + } - data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; + if (alibi) { + data[idst + j] = -std::abs(p0 - p1); + } else { + data[idst + j] = 0.0f; } + + continue; +skip: + data[idst + j] = -INFINITY; } } } } +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool alibi = args.hparams.use_alibi; + if (alibi) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } +} + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool is_2d = args.ubatch->is_pos_2d(); + if (is_2d) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } +} + +template +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { + const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE; + if (swa) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } +} + +void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + const uint32_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + float * data = (float *) dst->data; + + const int64_t n_kv = dst->ne[0]; + const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch + + GGML_ASSERT(n_tokens%n_stream == 0); + + // n_tps == n_tokens_per_stream + const int64_t n_tps = n_tokens/n_stream; + + //const int64_t t_start = ggml_time_us(); + + const args_set_input_kq_mask args = { + /*.hparams =*/ hparams, + /*.ubatch =*/ ubatch, + /*.v_cells =*/ v_cells, + /*.seq_to_stream =*/ seq_to_stream, + /*.n_swa =*/ n_swa, + /*.swa_type =*/ swa_type, + /*.n_kv =*/ n_kv, + /*.n_stream =*/ n_stream, + /*.n_tps =*/ n_tps, + }; + + if (causal_attn) { + set_input_kq_mask_impl (args, data); + } else { + set_input_kq_mask_impl(args, data); + } + + //const int64_t t_end = ggml_time_us(); + + //LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0); +} + void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { const int64_t n_tokens = ubatch->n_tokens; @@ -1370,7 +1524,7 @@ size_t llama_kv_cache::size_v_bytes() const { size_t size_v_bytes = 0; for (const auto & layer : layers) { - size_v_bytes += ggml_nbytes(layer.v); + size_v_bytes += layer.v ? ggml_nbytes(layer.v) : 0; } return size_v_bytes; @@ -1448,6 +1602,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; + const auto & n_rot = hparams.n_rot; + + const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0; + auto inp = std::make_unique(this); inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); @@ -1468,10 +1626,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co ggml_tensor * k = ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, get_size()*n_stream, + n_rot, n_head_kv, get_size()*n_stream, ggml_row_size(layer.k->type, n_embd_head_k), ggml_row_size(layer.k->type, n_embd_k_gqa), - 0); + ggml_row_size(layer.k->type, n_embd_nope)); ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); @@ -1483,10 +1641,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co return gf; } -bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const { - return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1); -} - void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { GGML_UNUSED(flags); @@ -1652,6 +1806,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; + if (!v) { + continue; + } // Write value type const int32_t v_type_i = (int32_t) v->type; @@ -1678,6 +1835,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; + if (!v) { + continue; + } // Write value type const int32_t v_type_i = (int32_t) v->type; @@ -1881,6 +2041,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; + if (!v) { + continue; + } // Read type of value int32_t v_type_i_ref; @@ -1922,6 +2085,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; + if (!v) { + continue; + } // Read type of value int32_t v_type_i_ref; diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index 0c4ed648456..e194bf3e26f 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -257,8 +257,6 @@ class llama_kv_cache : public llama_memory_i { size_t size_k_bytes() const; size_t size_v_bytes() const; - bool is_masked_swa(llama_pos p0, llama_pos p1) const; - ggml_tensor * build_rope_shift( const llama_cparams & cparams, ggml_context * ctx, diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.cpp b/examples/talk-llama/llama-memory-hybrid-iswa.cpp new file mode 100644 index 00000000000..411769672af --- /dev/null +++ b/examples/talk-llama/llama-memory-hybrid-iswa.cpp @@ -0,0 +1,275 @@ +#include "llama-memory-hybrid-iswa.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-context.h" + +// +// llama_memory_hybrid_iswa +// + +llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool swa_full, + uint32_t kv_size, + uint32_t n_ubatch, + uint32_t n_pad, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn, + const layer_filter_cb & filter_recr) : + hparams(model.hparams), + mem_attn(new llama_kv_cache_iswa( + model, + type_k, + type_v, + v_trans, + offload, + swa_full, + unified, + kv_size, + n_seq_max, + n_ubatch, + n_pad, + filter_attn == nullptr ? + [&](int32_t il) { return !hparams.is_recurrent(il); } + : filter_attn, + nullptr + )), + mem_recr(new llama_memory_recurrent( + model, + type_r, + type_s, + offload, + rs_size, + n_seq_max, + filter_recr == nullptr ? + [&](int32_t il) { return hparams.is_recurrent(il); } + : filter_recr + )) {} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { + do { + balloc.split_reset(); + + // follow the recurrent pattern for creating the ubatch splits + std::vector ubatches; + + while (true) { + llama_ubatch ubatch; + + if (embd_all) { + // if all tokens are output, split by sequence + ubatch = balloc.split_seq(n_ubatch); + } else { + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); + } + + if (ubatch.n_tokens == 0) { + break; + } + + ubatches.push_back(std::move(ubatch)); // NOLINT + } + + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + + // prepare the recurrent batches first + if (!mem_recr->prepare(ubatches)) { + // TODO: will the recurrent cache be in an undefined context at this point? + LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + // prepare the attention cache (iswa version returns both base and swa slot infos) + auto sinfos_base = mem_attn->get_base()->prepare(ubatches); + if (sinfos_base.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches); + if (sinfos_swa.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique( + this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); + } while(false); + + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() { + return std::make_unique(this); +} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique(this, lctx, optimize); +} + +bool llama_memory_hybrid_iswa::get_can_shift() const { + // Shifting is trivially supported for recurrent + return mem_attn->get_can_shift(); +} + +void llama_memory_hybrid_iswa::clear(bool data) { + mem_attn->clear(data); + mem_recr->clear(data); +} + +bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // Try removing from the recurrent cache first since it may fail. If it does + // fail, the cache will not have been mutated. + if (!mem_recr->seq_rm(seq_id, p0, p1)) { + return false; + } + return mem_attn->seq_rm(seq_id, p0, p1); +} + +void llama_memory_hybrid_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1); + mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) { + mem_attn->seq_keep(seq_id); + mem_recr->seq_keep(seq_id); +} + +void llama_memory_hybrid_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + mem_attn->seq_add(seq_id, p0, p1, shift); + mem_recr->seq_add(seq_id, p0, p1, shift); +} + +void llama_memory_hybrid_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + mem_attn->seq_div(seq_id, p0, p1, d); + mem_recr->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const { + // the min of the total cache is the max of the two caches' min values + return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id)); +} + +llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const { + // the max of the total cache is the min of the two caches' max values + return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id)); +} + +std::map llama_memory_hybrid_iswa::memory_breakdown() const { + std::map mb = mem_attn->memory_breakdown(); + for (const auto & buft_size : mem_recr->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; +} + +void llama_memory_hybrid_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + mem_attn->state_write(io, seq_id, flags); + mem_recr->state_write(io, seq_id, flags); +} + +void llama_memory_hybrid_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + mem_attn->state_read(io, seq_id, flags); + mem_recr->state_read(io, seq_id, flags); +} + +llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const { + return mem_attn.get(); +} + +llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const { + return mem_recr.get(); +} + +// +// llama_memory_hybrid_iswa_context +// + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) : status(status) {} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) : + ctx_attn(mem->get_mem_attn()->init_full()), + ctx_recr(mem->get_mem_recr()->init_full()), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + llama_context * lctx, + bool optimize) : + ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)), + ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, + std::vector ubatches) : + ubatches(std::move(ubatches)), + // note: here we copy the ubatches. not sure if this is ideal + ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn(), std::move(sinfos_base), std::move(sinfos_swa), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +bool llama_memory_hybrid_iswa_context::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + ctx_attn->next(); + ctx_recr->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_memory_hybrid_iswa_context::apply() { + assert(!llama_memory_status_is_fail(status)); + + bool res = true; + + res = res & ctx_attn->apply(); + res = res & ctx_recr->apply(); + + return res; +} + +llama_memory_status llama_memory_hybrid_iswa_context::get_status() const { + return status; +} + +const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; +} + +const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const { + return static_cast(ctx_attn.get()); +} + +const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const { + return static_cast(ctx_recr.get()); +} diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.h b/examples/talk-llama/llama-memory-hybrid-iswa.h new file mode 100644 index 00000000000..807c8aac96c --- /dev/null +++ b/examples/talk-llama/llama-memory-hybrid-iswa.h @@ -0,0 +1,140 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cache-iswa.h" +#include "llama-memory.h" +#include "llama-memory-recurrent.h" + +#include +#include + +// +// llama_memory_hybrid_iswa +// + +// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to +// support models where each layer may be either attention-based (with SWA support) or recurrent + +class llama_memory_hybrid_iswa : public llama_memory_i { +public: + llama_memory_hybrid_iswa( + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool swa_full, + uint32_t kv_size, + uint32_t n_ubatch, + uint32_t n_pad, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn = nullptr, + const layer_filter_cb & filter_recr = nullptr); + + ~llama_memory_hybrid_iswa() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map memory_breakdown() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + + // + // llama_memory_hybrid_iswa specific API + // + + llama_kv_cache_iswa * get_mem_attn() const; + llama_memory_recurrent * get_mem_recr() const; + +private: + const llama_hparams & hparams; + + const std::unique_ptr mem_attn; + const std::unique_ptr mem_recr; +}; + +class llama_memory_hybrid_iswa_context : public llama_memory_context_i { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + // init failure + explicit llama_memory_hybrid_iswa_context(llama_memory_status status); + + // init full + explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem); + + // init update + explicit llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + llama_context * lctx, + bool optimize); + + // init success + llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, + std::vector ubatches); + + ~llama_memory_hybrid_iswa_context() = default; + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_memory_hybrid_iswa_context + // + + const llama_kv_cache_iswa_context * get_attn() const; + const llama_memory_recurrent_context * get_recr() const; + +private: + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector ubatches; + + const llama_memory_context_ptr ctx_attn; + const llama_memory_context_ptr ctx_recr; + + const llama_memory_status status; +}; diff --git a/examples/talk-llama/llama-mmap.cpp b/examples/talk-llama/llama-mmap.cpp index 2da857b3aae..0261e4c72c9 100644 --- a/examples/talk-llama/llama-mmap.cpp +++ b/examples/talk-llama/llama-mmap.cpp @@ -244,11 +244,14 @@ struct llama_file::impl { } errno = 0; if (fd == -1) { - std::size_t ret = std::fread(ptr, len, 1, fp); + const size_t curr_off = tell(); + const size_t to_read = std::min(len, size - curr_off); + + std::size_t ret = std::fread(ptr, to_read, 1, fp); if (ferror(fp)) { throw std::runtime_error(format("read error: %s", strerror(errno))); } - if (ret != 1) { + if (to_read > 0 && ret != 1) { throw std::runtime_error("unexpectedly reached end of file"); } } else { @@ -262,7 +265,8 @@ struct llama_file::impl { continue; // Interrupted by signal, retry } // Fallback to std::fread in case the DMA controller cannot access the buffer - if (errno == EFAULT) { + if (errno == EFAULT || errno == EINVAL) { + LLAMA_LOG_WARN("%s: Falling back to buffered IO due to %s\n", __func__, strerror(errno)); auto curr_off = tell(); close(fd); fd = -1; @@ -381,6 +385,9 @@ int llama_file::file_id() const { #ifdef _WIN32 return _fileno(pimpl->fp); #else + if (pimpl->fd != -1) { + return pimpl->fd; + } #if defined(fileno) return fileno(pimpl->fp); #else @@ -611,9 +618,9 @@ struct llama_mlock::impl { char* errmsg = std::strerror(errno); bool suggest = (errno == ENOMEM); -#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX) - // visionOS/tvOS dont't support RLIMIT_MEMLOCK - // Skip resource limit checks on visionOS/tvOS +#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX) || defined(__HAIKU__) + // visionOS/tvOS/Haiku don't support RLIMIT_MEMLOCK + // Skip resource limit checks on these platforms suggest = false; #else struct rlimit lock_limit; diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index e66febaa021..1501e392ca8 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -2,6 +2,7 @@ #include "ggml.h" +#include #include #include #include @@ -344,6 +345,7 @@ namespace GGUFMeta { GGUFMeta::GKV::get_kv(ctx, kid); switch (arr_info.gt) { + case GGUF_TYPE_BOOL: case GGUF_TYPE_UINT32: case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value) || (std::is_same::value)); break; @@ -365,7 +367,13 @@ namespace GGUFMeta { result[i] = value; } } else { - std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + if (arr_info.gt == GGUF_TYPE_BOOL) { + std::transform((const bool *)arr_info.data, (const bool *)arr_info.data + arr_info.length, result.begin(), [](bool x) { + return static_cast(x); + }); + } else { + std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + } } return true; @@ -531,12 +539,18 @@ llama_model_loader::llama_model_loader( files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io)); contexts.emplace_back(ctx); - use_direct_io = use_direct_io && files.back()->has_direct_io(); + if (use_mmap && use_direct_io) { + if (files.back()->has_direct_io()) { + LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); + use_mmap = false; + } else { + LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__); + use_direct_io = false; - // Disable mmap in case Direct I/O is enabled and available - if (use_direct_io && use_mmap) { - use_mmap = false; - LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); + // reopen file using std::fopen for mmap + files.pop_back(); + files.emplace_back(new llama_file(fname.c_str(), "rb", false)); + } } // Save tensors data offset of the main file. diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index ae27c71ce23..36e353074e0 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -146,8 +146,8 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens()); add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); - if (hparams.n_embd_out > 0) { - add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out); + if (hparams.n_embd_out_impl > 0) { + add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl); } add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer); add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index f6cea8f8db4..72490a89b56 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -8,6 +8,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" #include "llama-memory-hybrid.h" +#include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" #include "ggml-cpp.h" @@ -446,7 +447,7 @@ struct llama_model::impl { llama_mlocks mlock_bufs; llama_mlocks mlock_mmaps; - // contexts where the model tensors metadata is stored as well ass the corresponding buffers: + // contexts where the model tensors metadata is stored as well as the corresponding buffers: std::vector>> ctxs_bufs; buft_list_t cpu_buft_list; @@ -468,7 +469,11 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern; } -llama_model::~llama_model() = default; +llama_model::~llama_model() { + for (auto * lora : loras) { + delete lora; + } +} void llama_model::load_stats(llama_model_loader & ml) { pimpl->n_elements = ml.n_elements; @@ -507,7 +512,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); - ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out, false); + ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl, false); ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); @@ -1692,15 +1697,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK2: { // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); + const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); if (!is_lite) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); @@ -1709,7 +1715,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { // for compatibility with existing DeepSeek V2 and V2.5 GGUFs // that have no expert_gating_func model parameter set - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) { + // GLM 4.7 Lite + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } else { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } } if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { @@ -1726,6 +1737,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_layer) { case 27: type = LLM_TYPE_16B; break; + case 47: type = LLM_TYPE_30B_A3B; break; case 60: type = LLM_TYPE_236B; break; case 61: type = LLM_TYPE_671B; break; default: type = LLM_TYPE_UNKNOWN; @@ -1933,6 +1945,34 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_EXAONE_MOE: + { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 128; + hparams.set_swa_pattern(4); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_30B_A3B; break; + case 48: + case 49: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: { @@ -4871,14 +4911,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_DEEPSEEK2: { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); - - const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + const bool is_mla = hparams.is_mla(); // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA - const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; - const int64_t n_embd_head_v_mla = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); const int64_t n_embd_head_qk_rope = hparams.n_rot; const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; @@ -4903,13 +4940,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!is_lite) { + if (q_lora_rank > 0) { layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); } layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - if (!is_lite) { + if (q_lora_rank > 0) { layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); } else { @@ -5516,6 +5553,84 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_EXAONE_MOE: + { + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp; + const int64_t head_dim = hparams.n_embd_head_k; + const int64_t n_qo_dim = n_head * head_dim; + const int64_t n_kv_dim = n_head_kv * head_dim; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, flags); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0) | flags); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // dense layers for first n_layer_dense_lead layers or nextn_predict_layers layers at the end + if (i < (int) hparams.n_layer_dense_lead || (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers)) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); + + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_RWKV6: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6481,7 +6596,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // for LFM2-ColBert-350M - dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.get_n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); } break; case LLM_ARCH_SMALLTHINKER: { @@ -7101,59 +7216,59 @@ void llama_model::print_info() const { }; // hparams - LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); - LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); - LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); + LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); + LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); if (!hparams.vocab_only) { - LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); - LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); - LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); - LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); - LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); - LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); - LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); - LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); - LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); - LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); - LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); - LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); - LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); - LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); - LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); - LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); - LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); - LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); - LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); - LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); - LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); + LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); + LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); + LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); + LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); + LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); + LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); + LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); + LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); + LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); + LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); + LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); + LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); + LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); + LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); + LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); + LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); + LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa); - LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa); + LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa); + LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa); } - LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); - LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul); - LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); + LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); + LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); // MRoPE (Multi-axis Rotary Position Embedding) sections if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { - LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); + LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); } if (!classifier_labels.empty()) { - LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); + LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); size_t i = 0; for (auto label : classifier_labels) { - LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); + LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); } } } @@ -7167,55 +7282,55 @@ void llama_model::print_info() const { arch == LLM_ARCH_QWEN3NEXT || arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); } - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); if (pimpl->n_elements >= 1e12) { - LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); } else if (pimpl->n_elements >= 1e9) { - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); } else if (pimpl->n_elements >= 1e6) { - LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); } else { - LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); } // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); if (arch == LLM_ARCH_DEEPSEEK) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); } if (arch == LLM_ARCH_DEEPSEEK2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); - LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); - LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla); - LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } if (arch == LLM_ARCH_QWEN2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } if (arch == LLM_ARCH_MINICPM || @@ -7223,41 +7338,41 @@ void llama_model::print_info() const { arch == LLM_ARCH_GRANITE_MOE || arch == LLM_ARCH_GRANITE_HYBRID || arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); - LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); - LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } if (arch == LLM_ARCH_BAILINGMOE) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } if (arch == LLM_ARCH_BAILINGMOE2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); } if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } if (arch == LLM_ARCH_GROVEMOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); - LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); - LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); + LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); + LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); } vocab.print_info(); @@ -7413,23 +7528,44 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } - res = new llama_memory_hybrid( - /* model */ *this, - /* attn_type_k */ params.type_k, - /* attn_type_v */ params.type_v, - /* attn_v_trans */ !cparams.flash_attn, - /* attn_kv_size */ cparams.n_ctx, - /* attn_n_pad */ 1, - /* attn_n_swa */ hparams.n_swa, - /* attn_swa_type */ hparams.swa_type, - /* recurrent_type_k */ GGML_TYPE_F32, - /* recurrent_type_v */ GGML_TYPE_F32, - /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), - /* n_seq_max */ cparams.n_seq_max, - /* offload */ cparams.offload_kqv, - /* unified */ cparams.kv_unified, - /* filter_attn */ std::move(filter_attn), - /* filter_recr */ std::move(filter_recr)); + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + // Use hybrid-iswa for hybrid models with SWA + res = new llama_memory_hybrid_iswa( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_swa_full */ params.swa_full, + /* attn_kv_size */ cparams.n_ctx, + /* attn_n_ubatch */ cparams.n_ubatch, + /* attn_n_pad */ 1, + /* recurrent_type_r */ GGML_TYPE_F32, + /* recurrent_type_s */ GGML_TYPE_F32, + /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } else { + res = new llama_memory_hybrid( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_kv_size */ cparams.n_ctx, + /* attn_n_pad */ 1, + /* attn_n_swa */ hparams.n_swa, + /* attn_swa_type */ hparams.swa_type, + /* recurrent_type_k */ GGML_TYPE_F32, + /* recurrent_type_v */ GGML_TYPE_F32, + /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } } else { llama_memory_i::layer_reuse_cb reuse = nullptr; @@ -7811,6 +7947,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique>(*this, params); } } break; + case LLM_ARCH_EXAONE_MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_RWKV6: { llm = std::make_unique(*this, params); @@ -7985,7 +8125,7 @@ llama_model_params llama_model_default_params() { /*.kv_overrides =*/ nullptr, /*.vocab_only =*/ false, /*.use_mmap =*/ true, - /*.use_direct_io =*/ true, + /*.use_direct_io =*/ false, /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, @@ -8021,7 +8161,7 @@ int32_t llama_model_n_embd_inp(const llama_model * model) { } int32_t llama_model_n_embd_out(const llama_model * model) { - return model->hparams.get_n_embd_out(); + return model->hparams.n_embd_out(); } int32_t llama_model_n_layer(const llama_model * model) { @@ -8171,6 +8311,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_NEMOTRON: case LLM_ARCH_EXAONE: case LLM_ARCH_EXAONE4: + case LLM_ARCH_EXAONE_MOE: case LLM_ARCH_MINICPM3: case LLM_ARCH_BAILINGMOE2: case LLM_ARCH_DOTS1: diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 79200a0d97a..d1de16e3f28 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -11,6 +11,7 @@ #include #include #include +#include #include struct llama_cparams; @@ -476,8 +477,8 @@ struct llama_model { // for quantize-stats only std::vector> tensors_by_name; - // for keeping track of extra nodes used by lora adapters - uint32_t n_lora_nodes = 0; + // for keeping track of associated LoRA adapters + std::unordered_set loras; int64_t t_load_us = 0; int64_t t_start_us = 0; diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 048d65a75c2..776222cb6f2 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -422,57 +422,6 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t ++qs.i_ffn_up; } - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // IK: let's remove this, else Q2_K is almost the same as Q3_K_S - //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // This can be used to reduce the size of the Q5_K_S model. - // The associated PPL increase is fully in line with the size reduction - //else { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; - //} - bool convert_incompatible_tensor = false; - { - const int64_t nx = tensor->ne[0]; - const int64_t ny = tensor->ne[1]; - const int64_t qk_k = ggml_blck_size(new_type); - - if (nx % qk_k != 0) { - LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); - convert_incompatible_tensor = true; - } else { - ++qs.n_k_quantized; - } - } - - if (convert_incompatible_tensor) { - switch (new_type) { - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; - case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; - case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; - case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; - default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); - } - if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { - new_type = GGML_TYPE_F16; - } - LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); - ++qs.n_fallback; - } - return new_type; } @@ -596,7 +545,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } std::vector splits = {}; - llama_model_loader ml(fname_inp, splits, use_mmap, /*use_direct_io*/ true, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); + llama_model_loader ml(fname_inp, splits, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); @@ -875,21 +824,69 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // get more optimal quantization type based on the tensor shape, layer, etc. if (!params->pure && ggml_is_quantized(default_type)) { - int fallback = qs.n_fallback; - new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); - // unless the user specifies a type, and the tensor geometry will not require fallback quantisation - if (params->tensor_types && qs.n_fallback - fallback == 0) { + // if the user provided tensor types - use those + bool manual = false; + if (params->tensor_types) { const std::vector & tensor_types = *static_cast *>(params->tensor_types); const std::string tensor_name(tensor->name); for (const auto & [tname, qtype] : tensor_types) { if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { if (qtype != new_type) { - LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); + LLAMA_LOG_WARN("(manual override: %s -> %s) ", ggml_type_name(new_type), ggml_type_name(qtype)); new_type = qtype; // if two or more types are specified for the same tensor, the last match wins + manual = true; + break; } } } } + + // if not manual - use the standard logic for choosing the quantization type based on the selected mixture + if (!manual) { + new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); + } + + // incompatible tensor shapes are handled here - fallback to a compatible type + { + bool convert_incompatible_tensor = false; + + const int64_t nx = tensor->ne[0]; + const int64_t ny = tensor->ne[1]; + const int64_t qk_k = ggml_blck_size(new_type); + + if (nx % qk_k != 0) { + LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); + convert_incompatible_tensor = true; + } else { + ++qs.n_k_quantized; + } + + if (convert_incompatible_tensor) { + switch (new_type) { + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; + case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; + default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); + } + if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { + new_type = GGML_TYPE_F16; + } + LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); + ++qs.n_fallback; + } + } } if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { new_type = params->token_embedding_type; diff --git a/examples/talk-llama/llama-sampling.cpp b/examples/talk-llama/llama-sampling.cpp index 11f0394c4ce..5dde513065b 100644 --- a/examples/talk-llama/llama-sampling.cpp +++ b/examples/talk-llama/llama-sampling.cpp @@ -1513,12 +1513,9 @@ static void llama_sampler_top_p_backend_apply( mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32)); mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]); - // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: - // top_p_bias = (mask * 1e9f) - 1e9f. - // So entries in the mask that we want to discard will become -1e9f, and - // others will be 0 (meaning that will not effect the logits). - const float large_val = 1e9f; - struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); + // Apply -INFINITY bias for masked-out tokens + // log(1) = 0 (keep), log(0) = -INF (discard) + struct ggml_tensor * top_p_bias = ggml_log(ctx, mask); ggml_set_name(top_p_bias, "top_p_bias"); data->logits = ggml_add(ctx, sorted_logits, top_p_bias); @@ -1673,15 +1670,11 @@ static void llama_sampler_min_p_backend_apply( struct ggml_tensor * mask = ggml_step(ctx, sub); ggml_set_name(mask, "min_p_mask"); - // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: - // min_p_bias = (mask * 1e9f) - 1e9f. - // So entries in the mask that we want to discard will become -1e9f, and - // others will be 0 (meaning that will not effect the logits). - const float large_val = 1e9f; - struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); + // Apply -INFINITY bias for masked-out tokens + // log(1) = 0 (keep), log(0) = -INF (discard) + struct ggml_tensor * min_p_bias = ggml_log(ctx, mask); ggml_set_name(min_p_bias, "min_p_bias"); - // Add the min_p bias to the logits. data->logits = ggml_add(ctx, data->logits, min_p_bias); ggml_set_name(data->logits, "min_p_logits"); @@ -3293,6 +3286,170 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa return result; } +// adaptive-p sampler state +// +// maintains an exponential moving average of the *ORIGINAL* probabilities +// of selected tokens, used to compute an adapted target at each sampling step. +// +// see llama.h for a full description of the sampler +// +// ref: https://github.com/ggml-org/llama.cpp/pull/17927 +// +struct llama_sampler_adaptive_p { + const float target; // target probability (0.0 - 1.0; negative = disabled) + const float decay; // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99) + const uint32_t seed; // original RNG seed + uint32_t seed_cur; // actual RNG seed + std::mt19937 rng; // RNG state + float weighted_sum; // sum(p_i * decay^i) + float total_weight; // sum(decay^i), converges to 1/(1-decay) + std::vector original_probs; // pre-transform probs, cached for EMA update + llama_token pending_token_id; // token ID of selected token + int32_t pending_token_idx; // index of orig. prob. of selected token in original_probs +}; + +// adaptive probability transformation constants +static constexpr float DISTRIBUTION_WIDTH = 0.3f; +static constexpr float PEAK_LOGIT_VALUE = 5.0f; +static constexpr float SHARPNESS = 10.0f; +static constexpr float INV_WIDTH = 1.0f / DISTRIBUTION_WIDTH; + +static const char * llama_sampler_adaptive_p_name(const struct llama_sampler * /*smpl*/) { + return "adaptive-p"; +} + +static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p, false); + + if (ctx->target < 0.0f) { + // at negative target values, adaptive-p is no-op + // we simply sample from the existing distribution + cur_p->selected = llama_sample_dist(cur_p, ctx->rng); + return; + } + + // store the original probabilities + ctx->original_probs.resize(cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { + ctx->original_probs[i] = cur_p->data[i].p; + } + + // using the EMA, compute the adapted target probability for the current sampling step + auto target = std::clamp(ctx->target, 0.0f, 1.0f); + float adapted_target = std::clamp( + ctx->total_weight == 0.0f ? target : 2.0f * target - (ctx->weighted_sum / ctx->total_weight), + 0.0f, 1.0f + ); + + // adaptive probability transform + // + // quadratic near target for fine differentiation, transitioning to linear decay in the + // tails. unbounded negative logits ensure proper suppression of far-from-target tokens + // after the softmax. + // + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit == -INFINITY) { + // don't transform logits that are -INFINITY + // (as masked out by e.g. min-p and top-p when using backend sampling) + continue; + } + float dist = std::abs((cur_p->data[i].p - adapted_target) * INV_WIDTH); + cur_p->data[i].logit = PEAK_LOGIT_VALUE - SHARPNESS * dist * dist / (1.0f + dist); + } + + // softmax and sample from the transformed distribution + llama_sampler_softmax_impl(cur_p, false); + const int idx = llama_sample_dist(cur_p, ctx->rng); + cur_p->selected = idx; + + // store the selected token ID for acceptance later + ctx->pending_token_id = cur_p->data[idx].id; + ctx->pending_token_idx = idx; +} + +static void llama_sampler_adaptive_p_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + if (ctx->pending_token_id == token) { + GGML_ASSERT(ctx->pending_token_id != LLAMA_TOKEN_NULL); + GGML_ASSERT(ctx->pending_token_idx != -1); + // update EMA with the original probability of the selected token + ctx->weighted_sum = ctx->original_probs[ctx->pending_token_idx] + ctx->decay * ctx->weighted_sum; + ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; + } + ctx->pending_token_id = LLAMA_TOKEN_NULL; + ctx->pending_token_idx = -1; +} + +static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + // ctx->target and ctx->decay never change after init, so it's safe to keep them as is. + // original_probs is completely overwritten on every call to _apply. + // so we only need to reset the EMA state and pending token. + ctx->weighted_sum = ctx->target / (1.0f - ctx->decay); + ctx->total_weight = 1.0f / (1.0f - ctx->decay); + ctx->pending_token_id = LLAMA_TOKEN_NULL; + ctx->pending_token_idx = -1; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_adaptive_p *) smpl->ctx; + auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed); + auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx; + + // copy everything (target, decay, seed, and RNG are already set) + result_ctx->weighted_sum = ctx->weighted_sum; + result_ctx->total_weight = ctx->total_weight; + result_ctx->pending_token_id = ctx->pending_token_id; + result_ctx->pending_token_idx = ctx->pending_token_idx; + + return result; +} + +static void llama_sampler_adaptive_p_free(struct llama_sampler * smpl) { + delete (llama_sampler_adaptive_p *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_adaptive_p_i = { + /* .name = */ llama_sampler_adaptive_p_name, + /* .accept = */ llama_sampler_adaptive_p_accept, + /* .apply = */ llama_sampler_adaptive_p_apply, + /* .reset = */ llama_sampler_adaptive_p_reset, + /* .clone = */ llama_sampler_adaptive_p_clone, + /* .free = */ llama_sampler_adaptive_p_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, +}; + +struct llama_sampler * llama_sampler_init_adaptive_p( + float target, + float decay, + uint32_t seed +) { + auto seed_cur = get_rng_seed(seed); + float clamped_decay = std::clamp(decay, 0.0f, 0.99f); + return llama_sampler_init( + /* .iface = */ &llama_sampler_adaptive_p_i, + /* .ctx = */ new llama_sampler_adaptive_p { + /* .target = */ target, + /* .decay = */ clamped_decay, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + /* .weighted_sum = */ target / (1.0f - clamped_decay), + /* .total_weight = */ 1.0f / (1.0f - clamped_decay), + /* .original_probs = */ {}, + /* .pending_token_id = */ LLAMA_TOKEN_NULL, + /* .pending_token_idx = */ -1 + } + ); +} + // logit-bias struct llama_sampler_logit_bias : public llama_sampler_backend { diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index a20c6525e46..a23950d007c 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -461,6 +461,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -1965,6 +1972,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "exaone4") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "exaone-moe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE; } else if ( tokenizer_pre == "chameleon") { pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; @@ -2436,7 +2446,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { auto & attr = id_to_token[t.second].attr; if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>" || t.first == "<|constrain|>") { - attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED); + LLAMA_LOG_WARN("%s: setting token '%s' (%d) attribute to USER_DEFINED (%u), old attributes: %u\n", + __func__, t.first.c_str(), t.second, LLAMA_TOKEN_ATTR_USER_DEFINED, attr); + + attr = LLAMA_TOKEN_ATTR_USER_DEFINED; } } @@ -2489,7 +2502,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_eog_ids.erase(end_id); auto & attr = id_to_token[end_id].attr; - attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED); + attr = LLAMA_TOKEN_ATTR_USER_DEFINED; LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__); } @@ -3289,34 +3302,34 @@ int32_t llama_vocab::impl::detokenize( } void llama_vocab::impl::print_info() const { - LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str()); - LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, vocab.n_tokens()); - LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) bpe_ranks.size()); + LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str()); + LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, vocab.n_tokens()); + LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) bpe_ranks.size()); // special tokens - if (special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, special_bos_id, id_to_token.at(special_bos_id).text.c_str() ); } - if (special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, special_eos_id, id_to_token.at(special_eos_id).text.c_str() ); } - if (special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, special_eot_id, id_to_token.at(special_eot_id).text.c_str() ); } - if (special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, special_eom_id, id_to_token.at(special_eom_id).text.c_str() ); } - if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token.at(special_unk_id).text.c_str() ); } - if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token.at(special_sep_id).text.c_str() ); } - if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token.at(special_pad_id).text.c_str() ); } - if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, special_mask_id, id_to_token.at(special_mask_id).text.c_str() ); } - - if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token.at(linefeed_id).text.c_str() ); } - - if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); } - if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); } - if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); } - if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); } - if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); } - if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); } + if (special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, special_bos_id, id_to_token.at(special_bos_id).text.c_str() ); } + if (special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, special_eos_id, id_to_token.at(special_eos_id).text.c_str() ); } + if (special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, special_eot_id, id_to_token.at(special_eot_id).text.c_str() ); } + if (special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, special_eom_id, id_to_token.at(special_eom_id).text.c_str() ); } + if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token.at(special_unk_id).text.c_str() ); } + if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token.at(special_sep_id).text.c_str() ); } + if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token.at(special_pad_id).text.c_str() ); } + if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, special_mask_id, id_to_token.at(special_mask_id).text.c_str() ); } + + if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token.at(linefeed_id).text.c_str() ); } + + if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); } + if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); } + if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); } + if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); } + if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); } + if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); } for (const auto & id : special_eog_ids) { - LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() ); + LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() ); } - LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len); + LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len); } llama_vocab::llama_vocab() : pimpl(new impl(*this)) { diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 2b240a5491b..28c3a82b91e 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -53,6 +53,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, + LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, }; struct LLM_KV; diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index f1096d960e1..6da90d6f1f8 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -311,8 +311,12 @@ static void llama_params_fit_impl( __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); } } else { - LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n", - __func__, hp_nct, n_ctx_min); + if (n_ctx_min == UINT32_MAX) { + LLAMA_LOG_INFO("%s: user has requested full context size of %" PRIu32 " -> no change\n", __func__, hp_nct); + } else { + LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n", + __func__, hp_nct, n_ctx_min); + } } } else { LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx); @@ -1091,25 +1095,55 @@ int32_t llama_chat_apply_template( // model split // -int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { +int32_t llama_split_path( + char * split_path, + size_t maxlen, + const char * path_prefix, + int32_t split_no, + int32_t split_count) { + static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf"; - if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) { - return strlen(split_path); + + const int written = snprintf( + split_path, + maxlen, + SPLIT_PATH_FORMAT, + path_prefix, + split_no + 1, + split_count + ); + + if (written < 0 || (size_t) written >= maxlen) { + return 0; } - return 0; + + return (int32_t) written; } -int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) { - std::string str_split_path(split_path); +int32_t llama_split_prefix( + char * split_prefix, + size_t maxlen, + const char * split_path, + int32_t split_no, + int32_t split_count) { + + const std::string str_split_path(split_path); + char postfix[32]; - snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count); - std::string str_postfix(postfix); - - // check if split_prefix ends with postfix - int size_prefix = str_split_path.size() - str_postfix.size(); - if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) { - snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path); - return size_prefix; + snprintf(postfix, sizeof(postfix), "-%05d-of-%05d.gguf", split_no + 1, split_count); + + const std::string str_postfix(postfix); + if (str_split_path.size() <= str_postfix.size()) { + return 0; + } + + const size_t size_prefix = str_split_path.size() - str_postfix.size(); + + if (str_split_path.compare(size_prefix, std::string::npos, str_postfix) == 0) { + const size_t copy_len = std::min(size_prefix + 1, maxlen); + snprintf(split_prefix, copy_len, "%s", split_path); + + return (int32_t) size_prefix; } return 0; diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 1c17efb9fa1..bf4e28a8be1 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -309,7 +309,7 @@ extern "C" { // Keep the booleans together to avoid misalignment during copy-by-value. bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible - bool use_direct_io; // use direct io, takes precedence over use_mmap + bool use_direct_io; // use direct io, takes precedence over use_mmap when supported bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool use_extra_bufts; // use extra buffer types (used for weight repacking) @@ -489,6 +489,7 @@ extern "C" { // - returns true if the parameters could be successfully modified to fit device memory // - this function is NOT thread safe because it modifies the global llama logger state // - only parameters that have the same value as in llama_default_model_params are modified + // with the exception of the context size which is modified if and only if equal to 0 LLAMA_API enum llama_params_fit_status llama_params_fit( const char * path_model, struct llama_model_params * mparams, @@ -646,7 +647,8 @@ extern "C" { // Manually free a LoRA adapter // NOTE: loaded adapters will be free when the associated model is deleted - LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); + LLAMA_API DEPRECATED(void llama_adapter_lora_free(struct llama_adapter_lora * adapter), + "adapters are now freed together with the associated model"); // Get the invocation tokens if the current lora is an alora LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter); @@ -1255,7 +1257,6 @@ extern "C" { // [EXPERIMENTAL] // attach a sampler to the context // note: prefer initializing the context with llama_context_params.samplers when possible - // note: changing the samplers of a context can cause graph reallocations and degraded performance LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl); // mirror of llama_sampler_i: @@ -1395,6 +1396,33 @@ extern "C" { const char ** seq_breakers, size_t num_breakers); + /// adaptive-p: select tokens near a configurable target probability over time. + /// + /// the adaptive-p sampler transforms the token probability distribution to favor tokens + /// that fall near a user-configurable probability target. + /// + /// internally, the sampler maintains an exponential moving average of the *ORIGINAL* + /// probabilities of selected tokens at each sampling step. it uses this EMA to compute an + /// adapted target probability at each sampling step, thus maintaining the desired target + /// probability over time. + /// + /// adaptive-p selects a token ID rather than just mutating candidates, so it must be last + /// in the sampler chain (like mirostat, dist, greedy). + /// + /// only mild truncation before this sampler is recommended. we suggest applying min-p + /// before adaptive-p as the only other active sampler in the chain. + /// + /// @param target select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) + /// @param decay EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99) + /// @param seed RNG seed + /// + /// ref: https://github.com/ggml-org/llama.cpp/pull/17927 + /// + LLAMA_API struct llama_sampler * llama_sampler_init_adaptive_p( + float target, + float decay, + uint32_t seed); + LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( int32_t n_vocab, int32_t n_logit_bias, @@ -1448,12 +1476,12 @@ extern "C" { /// @details Build a split GGUF final path for this chunk. /// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf" // Returns the split_path length. - LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count); + LLAMA_API int32_t llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int32_t split_no, int32_t split_count); /// @details Extract the path prefix from the split_path if and only if the split_no and split_count match. /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0" // Returns the split_prefix length. - LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count); + LLAMA_API int32_t llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int32_t split_no, int32_t split_count); // Print system information LLAMA_API const char * llama_print_system_info(void); diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp index ca63a62ad1b..297dca51369 100644 --- a/examples/talk-llama/models/deepseek2.cpp +++ b/examples/talk-llama/models/deepseek2.cpp @@ -2,14 +2,11 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); - - const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + const bool is_mla = hparams.is_mla(); // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA - const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; - const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + const int64_t n_embd_head_k = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v = hparams.n_embd_head_v_mla(); const int64_t n_embd_head_qk_rope = hparams.n_rot; const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope; @@ -43,7 +40,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv(); + auto * inp_attn_kv = !is_mla ? build_attn_inp_kv() : nullptr; + auto * inp_attn_k = is_mla ? build_attn_inp_k() : nullptr; ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -57,6 +55,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // self_attention { ggml_tensor * q = NULL; + + const bool is_lite = model.layers[il].wq; + if (!is_lite) { q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); cb(q, "q", il); @@ -124,14 +125,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} // note: rope must go first for in-place context shifting in build_rope_shift() - ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0); + ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); cb(Qcur, "Qcur", il); kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); cb(kv_cmpr, "kv_cmpr_reshape", il); // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} - ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0); + ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0); cb(Kcur, "Kcur", il); // {kv_lora_rank, 1, n_tokens} @@ -145,7 +146,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr } // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) - cur = build_attn(inp_attn, + cur = build_attn(inp_attn_k, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il); } else { @@ -169,11 +170,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr Vcur = ggml_cont(ctx0, Vcur); cb(Vcur, "Vcur_cont", il); - // note: rope must go first for in-place context shifting in build_rope_shift() - ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0); + ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0); cb(Qcur, "Qcur", il); - ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0); + ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); cb(Kcur, "Kcur", il); if (inp_attn_scale) { @@ -183,7 +183,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr } // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) - cur = build_attn(inp_attn, + cur = build_attn(inp_attn_kv, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp new file mode 100644 index 00000000000..bef5b2ad351 --- /dev/null +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -0,0 +1,146 @@ +#include "models.h" + + +llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn_iswa = build_attn_inp_kv_iswa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { + ggml_tensor * inpSA = inpL; + + // use RoPE for SWA layers + const bool is_local_layer = hparams.is_swa(il); + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + + if (is_local_layer) { + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn_iswa, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + cb(cur, "attn_out", il); + } + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // norm + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // dense branch + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + // final norm + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/gemma3n-iswa.cpp b/examples/talk-llama/models/gemma3n-iswa.cpp index 93defbeef9c..7db6d3bf4ec 100644 --- a/examples/talk-llama/models/gemma3n-iswa.cpp +++ b/examples/talk-llama/models/gemma3n-iswa.cpp @@ -245,12 +245,12 @@ ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) { // equivalent to get_per_layer_inputs() in python code // output shape: [n_embd_altup, n_layer, n_tokens] ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { - auto inp = std::make_unique(); + auto inp = std::make_unique(n_embd); ggml_tensor * inp_per_layer; if (ubatch.token) { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); ggml_set_input(inp->tokens); - res->t_tokens = inp->tokens; + res->t_inp_tokens = inp->tokens; inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens); inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup)); @@ -258,12 +258,12 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { res->add_input(std::move(inp)); } else { // Vision embedding path: use padding token (ID=0) embedding + // TODO: verify if this is the correct behavior in transformers implementation const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer - // Extract and dequantize padding token embedding (column 0) - ggml_tensor * padding_q = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0); - ggml_tensor * padding_f32 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, embd_size); - inp_per_layer = ggml_cpy(ctx0, padding_q, padding_f32); + // Extract and dequantize padding token embedding (row 0) + ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0); + inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32); // Reshape to [n_embd_altup, n_layer, 1] inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1); diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index f374a9fd030..297cc34ba58 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -9,6 +9,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; ggml_tensor * cur; diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index 6c40f48042b..3a44f7f140f 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -167,6 +167,10 @@ struct llm_build_exaone : public llm_graph_context { llm_build_exaone(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_exaone_moe : public llm_graph_context { + llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_falcon : public llm_graph_context { llm_build_falcon(const llama_model & model, const llm_graph_params & params); }; diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index eb135e63f18..079c730ac29 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -67,7 +67,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * const llama_model & model, const int64_t n_embd_head, const int il) { - // compute Q and K and (optionally) RoPE them + // compute Q and K ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index 481cbba6907..612a487c564 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -5,6 +5,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; ggml_tensor * cur; diff --git a/examples/talk-llama/models/qwen3vl-moe.cpp b/examples/talk-llama/models/qwen3vl-moe.cpp index f72f80a8376..e5e1a2150c8 100644 --- a/examples/talk-llama/models/qwen3vl-moe.cpp +++ b/examples/talk-llama/models/qwen3vl-moe.cpp @@ -2,7 +2,8 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const size_t n_deepstack_layers = hparams.n_deepstack_layers; - const int64_t n_embd = hparams.n_embd; + + const int64_t n_embd = hparams.n_embd; const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -16,17 +17,6 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - std::vector deepstack_features(n_deepstack_layers, nullptr); - - if (ubatch.embd) { - // Image input: split main embd and deepstack embds - ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0); - for (size_t i = 0; i < n_deepstack_layers; i++) { - deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float)); - } - inpL = inpL_main; - } - // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -120,8 +110,9 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ cur = build_cvec(cur, il); cb(cur, "l_out", il); - if (ubatch.embd && (size_t)il < n_deepstack_layers) { - cur = ggml_add(ctx0, cur, deepstack_features[il]); + if (il < (int) n_deepstack_layers) { + ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float)); + cur = ggml_add(ctx0, cur, ds); cb(cur, "deepstack_out", il); } diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index 0bae52239ca..0f8315b3240 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -2,7 +2,8 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const size_t n_deepstack_layers = hparams.n_deepstack_layers; - const int64_t n_embd = hparams.n_embd; + + const int64_t n_embd = hparams.n_embd; const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -16,17 +17,6 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - std::vector deepstack_features(n_deepstack_layers, nullptr); - - if (ubatch.embd) { - // Image input: split main embd and deepstack embds - ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0); - for (size_t i = 0; i < n_deepstack_layers; i++) { - deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float)); - } - inpL = inpL_main; - } - // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -113,8 +103,9 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ cur = build_cvec(cur, il); cb(cur, "l_out", il); - if (ubatch.embd && (size_t)il < n_deepstack_layers) { - cur = ggml_add(ctx0, cur, deepstack_features[il]); + if (il < (int) n_deepstack_layers) { + ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float)); + cur = ggml_add(ctx0, cur, ds); cb(cur, "deepstack_out", il); } From acbace0571476bcbc7bb49b8c28eaeaedfb5bfd7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 30 Jan 2026 15:56:15 +0200 Subject: [PATCH 071/831] cuda : fix compile warnings (#0) --- ggml/src/ggml-cuda/ggml-cuda.cu | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e9e9592ebad..08383edb402 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3910,14 +3910,14 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud // Launch graph CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream())); #else + GGML_UNUSED(graph_key); graph_evaluated_or_captured = true; #endif // USE_CUDA_GRAPH } } -static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) { - #ifdef USE_CUDA_GRAPH +static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) { ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); if (graph->graph == nullptr) { @@ -3930,12 +3930,8 @@ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, co } return graph->is_enabled(); -#else - GGML_UNUSED(cuda_ctx); - GGML_UNUSED(graph_key); - return false; -#endif // USE_CUDA_GRAPH } +#endif // USE_CUDA_GRAPH static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; From bf422cb7042b0163910f3766a528da03987c3bc3 Mon Sep 17 00:00:00 2001 From: Frieder Bluemle Date: Fri, 30 Jan 2026 05:57:26 -0800 Subject: [PATCH 072/831] scripts : Fix dSYMs path case for macOS xcframework build (#3630) The script creates dSYMs/ but references dSYMS/ for macOS, causing build failures on case-sensitive filesystems. --- build-xcframework.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build-xcframework.sh b/build-xcframework.sh index bbf2764d729..4d462bbf4f3 100755 --- a/build-xcframework.sh +++ b/build-xcframework.sh @@ -559,7 +559,7 @@ xcodebuild -create-xcframework \ -framework $(pwd)/build-ios-device/framework/whisper.framework \ -debug-symbols $(pwd)/build-ios-device/dSYMs/whisper.dSYM \ -framework $(pwd)/build-macos/framework/whisper.framework \ - -debug-symbols $(pwd)/build-macos/dSYMS/whisper.dSYM \ + -debug-symbols $(pwd)/build-macos/dSYMs/whisper.dSYM \ -framework $(pwd)/build-visionos/framework/whisper.framework \ -debug-symbols $(pwd)/build-visionos/dSYMs/whisper.dSYM \ -framework $(pwd)/build-visionos-sim/framework/whisper.framework \ From aa1bc0d1a6dfd70dbb9f60c11df12441e03a9075 Mon Sep 17 00:00:00 2001 From: KITAITI Makoto Date: Fri, 30 Jan 2026 22:59:36 +0900 Subject: [PATCH 073/831] ruby : add `VAD::Context#segments_from_samples`, allow Pathname, etc. (#3633) * ruby : Bump version to 1.3.6 * Fix code in example * Add sample code to transcribe from MemoryView * Define GetVADContext macro * Use GetVADContext * Extract parse_full_args function * Use parse_full_args in ruby_whisper_full_parallel * Free samples after use * Check return value of parse_full_args() * Define GetVADParams macro * Add VAD::Context#segments_from_samples * Add tests for VAD::Context#segments_from_samples * Add signature for VAD::Context#segments_from_samples * Add sample code for VAD::Context#segments_from_samples * Add test for Whisper::Context#transcribe with Pathname * Make Whisper::Context#transcribe and Whisper::VAD::Context#detect accept Pathname * Update signature of Whisper::Context#transcribe * Fix variable name * Don't free memory view * Make parse_full_args return struct * Fallback when failed to get MemoryView * Add num of samples when too long * Check members of MemoryView * Fix a typo * Remove unnecessary include * Fix a typo * Fix a typo * Care the case of MemoryView doesn't fit spec * Add TODO comment * Add optimazation option to compiler flags * Use ALLOC_N instead of malloc * Add description to sample code * Rename and change args: parse_full_args -> parse_samples * Free samples when exception raised * Assign type check result to a variable * Define wrapper function of whisper_full * Change signature of parse_samples for rb_ensure * Ensure release MemoryView * Extract fill_samples function * Free samples memory when filling it failed * Free samples memory when transcription failed * Prepare transcription in wrapper funciton * Change function name * Simplify function boundary --- bindings/ruby/README.md | 35 +- bindings/ruby/ext/extconf.rb | 1 + bindings/ruby/ext/ruby_whisper.c | 2 - bindings/ruby/ext/ruby_whisper.h | 20 ++ bindings/ruby/ext/ruby_whisper_context.c | 322 +++++++++++------- bindings/ruby/ext/ruby_whisper_model.c | 1 - bindings/ruby/ext/ruby_whisper_params.c | 1 - bindings/ruby/ext/ruby_whisper_segment.c | 1 - bindings/ruby/ext/ruby_whisper_token.c | 1 - bindings/ruby/ext/ruby_whisper_transcribe.cpp | 5 +- bindings/ruby/ext/ruby_whisper_vad_context.c | 49 ++- .../ext/ruby_whisper_vad_context_detect.cpp | 11 +- bindings/ruby/ext/ruby_whisper_vad_params.c | 1 - bindings/ruby/ext/ruby_whisper_vad_segment.c | 1 - bindings/ruby/ext/ruby_whisper_vad_segments.c | 1 - bindings/ruby/sig/whisper.rbs | 6 +- bindings/ruby/test/test_vad_context.rb | 66 +++- bindings/ruby/test/test_whisper.rb | 20 ++ bindings/ruby/whispercpp.gemspec | 2 +- 19 files changed, 396 insertions(+), 150 deletions(-) diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index ea202753b67..86774158355 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -323,7 +323,24 @@ whisper end ``` -The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy. +The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. + +If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy. + +```ruby +require "torchaudio" +require "arrow-numo-narray" +require "whisper" + +waveform, sample_rate = TorchAudio.load("test/fixtures/jfk.wav") +# Convert Torch::Tensor to Arrow::Array via Numo::NArray +samples = waveform.squeeze.numo.to_arrow.to_arrow_array + +whisper = Whisper::Context.new("base") +whisper + # Arrow::Array exports MemoryView + .full(Whisper::Params.new, samples) +``` Using VAD separately from ASR ----------------------------- @@ -334,13 +351,27 @@ VAD feature itself is useful. You can use it separately from ASR: vad = Whisper::VAD::Context.new("silero-v6.2.0") vad .detect("path/to/audio.wav", Whisper::VAD::Params.new) - .each_with_index do |segment, index| + .each.with_index do |segment, index| segment => {start_time: st, end_time: ed} # `Segment` responds to `#deconstruct_keys` puts "[%{nth}: %{st} --> %{ed}]" % {nth: index + 1, st:, ed:} end ``` +You may also low level API `Whisper::VAD::Context#segments_from_samples` as such `Whisper::Context#full`: + +```ruby +# Ruby Array +reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :float, 16000)) +samples = reader.enum_for(:each_buffer).map(&:samples).flatten + +# Or, object which exports MemoryView +waveform, sample_rate = TorchAudio.load("test/fixtures/jfk.wav") +samples = waveform.squeeze.numo.to_arrow.to_arrow_array + +segments = vad.segments_from_samples(Whisper::VAD::Params.new, samples) +``` + Development ----------- diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 8a5ac67457b..acff501aa3b 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -7,6 +7,7 @@ have_library("gomp") rescue nil libs = Dependencies.new(cmake, options).to_s +$CFLAGS << " -O3 -march=native" $INCFLAGS << " -Isources/include -Isources/ggml/include -Isources/examples" $LOCAL_LIBS << " #{libs}" $cleanfiles << " build #{libs}" diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index ac677e9e3df..eb95829c032 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -1,5 +1,3 @@ -#include -#include #include "ruby_whisper.h" VALUE mWhisper; diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 3f5660c374d..c2c9866ae0d 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -1,6 +1,8 @@ #ifndef RUBY_WHISPER_H #define RUBY_WHISPER_H +#include +#include #include "whisper.h" typedef struct { @@ -55,6 +57,13 @@ typedef struct { struct whisper_vad_context *context; } ruby_whisper_vad_context; +typedef struct parsed_samples_t { + float *samples; + int n_samples; + rb_memory_view_t memview; + bool memview_exported; +} parsed_samples_t; + #define GetContext(obj, rw) do { \ TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \ if ((rw)->context == NULL) { \ @@ -69,6 +78,17 @@ typedef struct { } \ } while (0) +#define GetVADContext(obj, rwvc) do { \ + TypedData_Get_Struct((obj), ruby_whisper_vad_context, &ruby_whisper_vad_context_type, (rwvc)); \ + if ((rwvc)->context == NULL) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetVADParams(obj, rwvp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_vad_params, &ruby_whisper_vad_params_type, (rwvp)); \ +} while (0) + #define GetVADSegments(obj, rwvss) do { \ TypedData_Get_Struct((obj), ruby_whisper_vad_segments, &ruby_whisper_vad_segments_type, (rwvss)); \ if ((rwvss)->segments == NULL) { \ diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index a7b5f8513db..84790e3dedf 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -1,5 +1,3 @@ -#include -#include #include "ruby_whisper.h" extern ID id_to_s; @@ -27,6 +25,27 @@ extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context); ID transcribe_option_names[1]; +typedef struct fill_samples_args { + float *dest; + VALUE *src; + int n_samples; +} fill_samples_args; + +typedef struct full_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; +} full_args; + +typedef struct full_parallel_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; + int n_processors; +} full_parallel_args; + static void ruby_whisper_free(ruby_whisper *rw) { @@ -272,82 +291,175 @@ VALUE ruby_whisper_model_type(VALUE self) return rb_str_new2(whisper_model_type_readable(rw->context)); } -/* - * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text - * Not thread safe for same context - * Uses the specified decoding strategy to obtain the text. - * - * call-seq: - * full(params, samples, n_samples) -> nil - * full(params, samples) -> nil - * - * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. - */ -VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) +static bool +check_memory_view(rb_memory_view_t *memview) { - if (argc < 2 || argc > 3) { - rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + if (strcmp(memview->format, "f") != 0) { + rb_warn("currently only format \"f\" is supported for MemoryView, but given: %s", memview->format); + return false; + } + if (memview->ndim != 1) { + rb_warn("currently only 1 dimensional MemoryView is supported, but given: %zd", memview->ndim); + return false; } - ruby_whisper *rw; - ruby_whisper_params *rwp; - GetContext(self, rw); - VALUE params = argv[0]; - TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - VALUE samples = argv[1]; - int n_samples; - rb_memory_view_t view; - const bool memory_view_available_p = rb_memory_view_available_p(samples); - if (argc == 3) { - n_samples = NUM2INT(argv[2]); - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) < n_samples) { - rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples); + return true; +} + +static VALUE +fill_samples(VALUE rb_args) +{ + fill_samples_args *args = (fill_samples_args *)rb_args; + + if (RB_TYPE_P(*args->src, T_ARRAY)) { + for (int i = 0; i < args->n_samples; i++) { + args->dest[i] = RFLOAT_VALUE(rb_ary_entry(*args->src, i)); + } + } else { + // TODO: use rb_block_call + VALUE iter = rb_funcall(*args->src, id_to_enum, 1, rb_str_new2("each")); + for (int i = 0; i < args->n_samples; i++) { + // TODO: check if iter is exhausted and raise ArgumentError appropriately + VALUE sample = rb_funcall(iter, id_next, 0); + args->dest[i] = RFLOAT_VALUE(sample); + } + } + + return Qnil; +} + +struct parsed_samples_t +parse_samples(VALUE *samples, VALUE *n_samples) +{ + bool memview_available = rb_memory_view_available_p(*samples); + struct parsed_samples_t parsed = {0}; + parsed.memview_exported = false; + const bool is_array = RB_TYPE_P(*samples, T_ARRAY); + + if (!NIL_P(*n_samples)) { + parsed.n_samples = NUM2INT(*n_samples); + if (is_array) { + if (RARRAY_LEN(*samples) < parsed.n_samples) { + rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(*samples), parsed.n_samples); } } // Should check when samples.respond_to?(:length)? } else { - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) > INT_MAX) { + if (is_array) { + if (RARRAY_LEN(*samples) > INT_MAX) { rb_raise(rb_eArgError, "samples are too long"); } - n_samples = (int)RARRAY_LEN(samples); - } else if (memory_view_available_p) { - if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { - view.obj = Qnil; - rb_raise(rb_eArgError, "unable to get a memory view"); + parsed.n_samples = (int)RARRAY_LEN(*samples); + } else if (memview_available) { + bool memview_got = rb_memory_view_get(*samples, &parsed.memview, RUBY_MEMORY_VIEW_SIMPLE); + if (memview_got) { + parsed.memview_exported = check_memory_view(&parsed.memview); + if (!parsed.memview_exported) { + rb_memory_view_release(&parsed.memview); + parsed.memview = (rb_memory_view_t){0}; + } } - ssize_t n_samples_size = view.byte_size / view.item_size; - if (n_samples_size > INT_MAX) { - rb_raise(rb_eArgError, "samples are too long"); + if (parsed.memview_exported) { + ssize_t n_samples_size = parsed.memview.byte_size / parsed.memview.item_size; + if (n_samples_size > INT_MAX) { + rb_memory_view_release(&parsed.memview); + rb_raise(rb_eArgError, "samples are too long: %zd", n_samples_size); + } + parsed.n_samples = (int)n_samples_size; + } else { + rb_warn("unable to get a memory view. fallbacks to Ruby object"); + if (rb_respond_to(*samples, id_length)) { + parsed.n_samples = NUM2INT(rb_funcall(*samples, id_length, 0)); + } else { + rb_raise(rb_eArgError, "samples must respond to :length"); + } } - n_samples = (int)n_samples_size; - } else if (rb_respond_to(samples, id_length)) { - n_samples = NUM2INT(rb_funcall(samples, id_length, 0)); + } else if (rb_respond_to(*samples, id_length)) { + parsed.n_samples = NUM2INT(rb_funcall(*samples, id_length, 0)); } else { - rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given"); + rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of float when n_samples is not given"); } } - float * c_samples = (float *)malloc(n_samples * sizeof(float)); - if (memory_view_available_p) { - c_samples = (float *)view.data; + + if (parsed.memview_exported) { + parsed.samples = (float *)parsed.memview.data; } else { - if (TYPE(samples) == T_ARRAY) { - for (int i = 0; i < n_samples; i++) { - c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i)); - } - } else { - // TODO: use rb_block_call - VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each")); - for (int i = 0; i < n_samples; i++) { - // TODO: check if iter is exhausted and raise ArgumentError appropriately - VALUE sample = rb_funcall(iter, id_next, 0); - c_samples[i] = RFLOAT_VALUE(sample); - } + parsed.samples = ALLOC_N(float, parsed.n_samples); + fill_samples_args args = { + parsed.samples, + samples, + parsed.n_samples, + }; + int state; + rb_protect(fill_samples, (VALUE)&args, &state); + if (state) { + xfree(parsed.samples); + rb_jump_tag(state); } } - prepare_transcription(rwp, &self); - const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples); + + return parsed; +} + +VALUE +release_samples(VALUE rb_parsed_args) +{ + parsed_samples_t *parsed_args = (parsed_samples_t *)rb_parsed_args; + + if (parsed_args->memview_exported) { + rb_memory_view_release(&parsed_args->memview); + } else { + xfree(parsed_args->samples); + } + *parsed_args = (parsed_samples_t){0}; + + return Qnil; +} + +static VALUE +full_body(VALUE rb_args) +{ + full_args *args = (full_args *)rb_args; + + ruby_whisper *rw; + ruby_whisper_params *rwp; + GetContext(*args->context, rw); + TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); + + prepare_transcription(rwp, args->context); + int result = whisper_full(rw->context, rwp->params, args->samples, args->n_samples); + + return INT2NUM(result); +} + +/* + * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + * Not thread safe for same context + * Uses the specified decoding strategy to obtain the text. + * + * call-seq: + * full(params, samples, n_samples) -> nil + * full(params, samples) -> nil + * + * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. + */ +VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) +{ + if (argc < 2 || argc > 3) { + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + } + + VALUE n_samples = argc == 2 ? Qnil : argv[2]; + + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + full_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + }; + VALUE rb_result = rb_ensure(full_body, (VALUE)&args, release_samples, (VALUE)&parsed); + const int result = NUM2INT(rb_result); if (0 == result) { return self; } else { @@ -355,6 +467,22 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) } } +static VALUE +full_parallel_body(VALUE rb_args) +{ + full_parallel_args *args = (full_parallel_args *)rb_args; + + ruby_whisper *rw; + ruby_whisper_params *rwp; + GetContext(*args->context, rw); + TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); + + prepare_transcription(rwp, args->context); + int result = whisper_full_parallel(rw->context, rwp->params, args->samples, args->n_samples, args->n_processors); + + return INT2NUM(result); +} + /* * Split the input audio in chunks and process each chunk separately using whisper_full_with_state() * Result is stored in the default state of the context @@ -372,19 +500,11 @@ static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) { if (argc < 2 || argc > 4) { - rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..4)", argc); } - ruby_whisper *rw; - ruby_whisper_params *rwp; - GetContext(self, rw); - VALUE params = argv[0]; - TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - VALUE samples = argv[1]; - int n_samples; + VALUE n_samples = argc == 2 ? Qnil : argv[2]; int n_processors; - rb_memory_view_t view; - const bool memory_view_available_p = rb_memory_view_available_p(samples); switch (argc) { case 2: n_processors = 1; @@ -396,56 +516,16 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) n_processors = NUM2INT(argv[3]); break; } - if (argc >= 3 && !NIL_P(argv[2])) { - n_samples = NUM2INT(argv[2]); - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) < n_samples) { - rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples); - } - } - // Should check when samples.respond_to?(:length)? - } else if (memory_view_available_p) { - if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { - view.obj = Qnil; - rb_raise(rb_eArgError, "unable to get a memory view"); - } - ssize_t n_samples_size = view.byte_size / view.item_size; - if (n_samples_size > INT_MAX) { - rb_raise(rb_eArgError, "samples are too long"); - } - n_samples = (int)n_samples_size; - } else { - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) > INT_MAX) { - rb_raise(rb_eArgError, "samples are too long"); - } - n_samples = (int)RARRAY_LEN(samples); - } else if (rb_respond_to(samples, id_length)) { - n_samples = NUM2INT(rb_funcall(samples, id_length, 0)); - } else { - rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given"); - } - } - float * c_samples = (float *)malloc(n_samples * sizeof(float)); - if (memory_view_available_p) { - c_samples = (float *)view.data; - } else { - if (TYPE(samples) == T_ARRAY) { - for (int i = 0; i < n_samples; i++) { - c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i)); - } - } else { - // FIXME: use rb_block_call - VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each")); - for (int i = 0; i < n_samples; i++) { - // TODO: check if iter is exhausted and raise ArgumentError - VALUE sample = rb_funcall(iter, id_next, 0); - c_samples[i] = RFLOAT_VALUE(sample); - } - } - } - prepare_transcription(rwp, &self); - const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors); + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + const full_parallel_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + n_processors, + }; + const VALUE rb_result = rb_ensure(full_parallel_body, (VALUE)&args, release_samples, (VALUE)&parsed); + const int result = NUM2INT(rb_result); if (0 == result) { return self; } else { diff --git a/bindings/ruby/ext/ruby_whisper_model.c b/bindings/ruby/ext/ruby_whisper_model.c index b196a8b5cb5..0e91fb3f87f 100644 --- a/bindings/ruby/ext/ruby_whisper_model.c +++ b/bindings/ruby/ext/ruby_whisper_model.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" extern const rb_data_type_t ruby_whisper_type; diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 4dfe2575a39..61eb1733676 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define BOOL_PARAMS_SETTER(self, prop, value) \ diff --git a/bindings/ruby/ext/ruby_whisper_segment.c b/bindings/ruby/ext/ruby_whisper_segment.c index 5229cb53900..ee0d66c4cc8 100644 --- a/bindings/ruby/ext/ruby_whisper_segment.c +++ b/bindings/ruby/ext/ruby_whisper_segment.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define N_KEY_NAMES 6 diff --git a/bindings/ruby/ext/ruby_whisper_token.c b/bindings/ruby/ext/ruby_whisper_token.c index ea4f4e635d2..56a7eab2231 100644 --- a/bindings/ruby/ext/ruby_whisper_token.c +++ b/bindings/ruby/ext/ruby_whisper_token.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define N_KEY_NAMES 11 diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index 594b2db90e3..c00fbcd1def 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #include "common-whisper.h" #include @@ -13,6 +12,7 @@ extern const rb_data_type_t ruby_whisper_params_type; extern ID id_to_s; extern ID id_call; +extern ID id_to_path; extern ID transcribe_option_names[1]; extern void @@ -50,6 +50,9 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { rb_raise(rb_eRuntimeError, "Expected file path to wave file"); } + if (rb_respond_to(wave_file_path, id_to_path)) { + wave_file_path = rb_funcall(wave_file_path, id_to_path, 0); + } std::string fname_inp = StringValueCStr(wave_file_path); std::vector pcmf32; // mono-channel F32 PCM diff --git a/bindings/ruby/ext/ruby_whisper_vad_context.c b/bindings/ruby/ext/ruby_whisper_vad_context.c index bf2ed2ba465..97c9736b6f4 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_context.c +++ b/bindings/ruby/ext/ruby_whisper_vad_context.c @@ -1,12 +1,23 @@ -#include #include "ruby_whisper.h" extern ID id_to_s; extern VALUE cVADContext; +extern const rb_data_type_t ruby_whisper_vad_params_type; extern VALUE ruby_whisper_vad_detect(VALUE self, VALUE file_path, VALUE params); extern VALUE ruby_whisper_normalize_model_path(VALUE model_path); +extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples); +extern VALUE release_samples(VALUE parsed); + +extern VALUE ruby_whisper_vad_segments_s_init(struct whisper_vad_segments *segments); + +typedef struct segments_from_samples_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; +} segments_from_samples_args; static size_t ruby_whisper_vad_context_memsize(const void *p) @@ -66,10 +77,46 @@ ruby_whisper_vad_context_initialize(VALUE self, VALUE model_path) return Qnil; } +static VALUE +segments_from_samples_body(VALUE rb_args) +{ + segments_from_samples_args *args = (segments_from_samples_args *)rb_args; + + ruby_whisper_vad_context *rwvc; + ruby_whisper_vad_params *rwvp; + GetVADContext(*args->context, rwvc); + GetVADParams(*args->params, rwvp); + + struct whisper_vad_segments *segments = whisper_vad_segments_from_samples(rwvc->context, rwvp->params, args->samples, args->n_samples); + + return ruby_whisper_vad_segments_s_init(segments); +} + +static VALUE +ruby_whisper_vad_segments_from_samples(int argc, VALUE *argv, VALUE self) +{ + if (argc < 2 || argc > 3) { + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + } + + VALUE n_samples = argc == 2 ? Qnil : argv[2]; + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + segments_from_samples_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + }; + VALUE segments = rb_ensure(segments_from_samples_body, (VALUE)&args, release_samples, (VALUE)&parsed); + + return segments; +} + void init_ruby_whisper_vad_context(VALUE *mVAD) { cVADContext = rb_define_class_under(*mVAD, "Context", rb_cObject); rb_define_alloc_func(cVADContext, ruby_whisper_vad_context_s_allocate); rb_define_method(cVADContext, "initialize", ruby_whisper_vad_context_initialize, 1); + rb_define_method(cVADContext, "segments_from_samples", ruby_whisper_vad_segments_from_samples, -1); rb_define_method(cVADContext, "detect", ruby_whisper_vad_detect, 2); } diff --git a/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp b/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp index 58609f87742..802b0222dbd 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp +++ b/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #include "common-whisper.h" #include @@ -8,6 +7,8 @@ extern "C" { #endif +extern ID id_to_path; + extern VALUE cVADSegments; extern const rb_data_type_t ruby_whisper_vad_context_type; @@ -25,12 +26,12 @@ ruby_whisper_vad_detect(VALUE self, VALUE file_path, VALUE params) { std::vector> pcmf32s; whisper_vad_segments *segments; - TypedData_Get_Struct(self, ruby_whisper_vad_context, &ruby_whisper_vad_context_type, rwvc); - if (rwvc->context == NULL) { - rb_raise(rb_eRuntimeError, "Doesn't have referenxe to context internally"); - } + GetVADContext(self, rwvc); TypedData_Get_Struct(params, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + if (rb_respond_to(file_path, id_to_path)) { + file_path = rb_funcall(file_path, id_to_path, 0); + } cpp_file_path = StringValueCStr(file_path); if (!read_audio_data(cpp_file_path, pcmf32, pcmf32s, false)) { diff --git a/bindings/ruby/ext/ruby_whisper_vad_params.c b/bindings/ruby/ext/ruby_whisper_vad_params.c index f254bfa2138..28256650e32 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_params.c +++ b/bindings/ruby/ext/ruby_whisper_vad_params.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define DEFINE_PARAM(param_name, nth) \ diff --git a/bindings/ruby/ext/ruby_whisper_vad_segment.c b/bindings/ruby/ext/ruby_whisper_vad_segment.c index 49ff0aadcce..84a007bb725 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_segment.c +++ b/bindings/ruby/ext/ruby_whisper_vad_segment.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" #define N_KEY_NAMES 2 diff --git a/bindings/ruby/ext/ruby_whisper_vad_segments.c b/bindings/ruby/ext/ruby_whisper_vad_segments.c index 1bb375937a4..db62fdb6222 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_segments.c +++ b/bindings/ruby/ext/ruby_whisper_vad_segments.c @@ -1,4 +1,3 @@ -#include #include "ruby_whisper.h" extern ID id___method__; diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 1137e3f36ab..0e7b2c276e8 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -37,8 +37,8 @@ module Whisper # puts text # end # - def transcribe: (string, Params, ?n_processors: Integer) -> self - | (string, Params, ?n_processors: Integer) { (String) -> void } -> self + def transcribe: (path, Params, ?n_processors: Integer) -> self + | (path, Params, ?n_processors: Integer) { (String) -> void } -> self def model_n_vocab: () -> Integer def model_n_audio_ctx: () -> Integer @@ -603,6 +603,8 @@ module Whisper class Context def self.new: (String | path | ::URI::HTTP model_name_or_path) -> instance + def segments_from_samples: (Params, Array[Float] samples, ?Integer n_samples) -> Segments + | (Params, _Samples, ?Integer n_samples) -> Segments def detect: (path wav_file_path, Params) -> Segments end diff --git a/bindings/ruby/test/test_vad_context.rb b/bindings/ruby/test/test_vad_context.rb index 704916db6de..b4558d34faf 100644 --- a/bindings/ruby/test/test_vad_context.rb +++ b/bindings/ruby/test/test_vad_context.rb @@ -9,6 +9,25 @@ def test_initialize def test_detect context = Whisper::VAD::Context.new("silero-v6.2.0") segments = context.detect(AUDIO, Whisper::VAD::Params.new) + assert_segments segments + end + + def test_invalid_model_type + assert_raise TypeError do + Whisper::VAD::Context.new(Object.new) + end + end + + def test_allocate + vad = Whisper::VAD::Context.allocate + assert_raise do + vad.detect(AUDIO, Whisper::VAD::Params.new) + end + end + + private + + def assert_segments(segments) assert_instance_of Whisper::VAD::Segments, segments i = 0 @@ -35,16 +54,47 @@ def test_detect assert_equal 4, segments.length end - def test_invalid_model_type - assert_raise TypeError do - Whisper::VAD::Context.new(Object.new) + sub_test_case "from samples" do + def setup + super + @vad = Whisper::VAD::Context.new("silero-v6.2.0") + @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15} end - end - def test_allocate - vad = Whisper::VAD::Context.allocate - assert_raise do - vad.detect(AUDIO, Whisper::VAD::Params.new) + def test_segments_from_samples + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, @samples, @samples.length) + assert_segments segments + end + + def test_segments_from_samples_without_length + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, @samples) + assert_segments segments + end + + def test_segments_from_samples_enumerator + samples = @samples.each + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, samples, @samples.length) + assert_segments segments + end + + def test_segments_from_samples_enumerator_without_length + samples = @samples.each + assert_raise ArgumentError do + @vad.segments_from_samples(Whisper::VAD::Params.new, samples) + end + end + + def test_segments_from_samples_enumerator_with_too_large_length + samples = @samples.each.take(10).to_enum + assert_raise StopIteration do + @vad.segments_from_samples(Whisper::VAD::Params.new, samples, 11) + end + end + + def test_segments_from_samples_with_memory_view + samples = JFKReader.new(AUDIO) + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, samples) + assert_segments segments end end end diff --git a/bindings/ruby/test/test_whisper.rb b/bindings/ruby/test/test_whisper.rb index 96e248aca3a..29071210072 100644 --- a/bindings/ruby/test/test_whisper.rb +++ b/bindings/ruby/test/test_whisper.rb @@ -1,6 +1,7 @@ require_relative "helper" require "stringio" require "etc" +require "pathname" # Exists to detect memory-related bug Whisper.log_set ->(level, buffer, user_data) {}, nil @@ -20,6 +21,15 @@ def test_whisper } end + def test_whisper_pathname + @whisper = Whisper::Context.new("base.en") + params = Whisper::Params.new + + @whisper.transcribe(Pathname(AUDIO), params) {|text| + assert_match(/ask not what your country can do for you, ask what you can do for your country/, text) + } + end + def test_transcribe_non_parallel @whisper = Whisper::Context.new("base.en") params = Whisper::Params.new @@ -207,6 +217,16 @@ def test_full_with_memory_view assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text) end + def test_full_with_memroy_view_gc + samples = JFKReader.new(AUDIO) + @whisper.full(@params, samples) + GC.start + require "fiddle" + Fiddle::MemoryView.export samples do |view| + assert_equal 176000, view.to_s.unpack("#{view.format}*").length + end + end + def test_full_parallel nprocessors = 2 @whisper.full_parallel(@params, @samples, @samples.length, nprocessors) diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 2e05769a22c..88b94e7eb8a 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -3,7 +3,7 @@ require_relative "extsources" Gem::Specification.new do |s| s.name = "whispercpp" s.authors = ["Georgi Gerganov", "Todd A. Fisher"] - s.version = '1.3.5' + s.version = '1.3.6' s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby} s.email = 'todd.fisher@gmail.com' s.extra_rdoc_files = ['LICENSE', 'README.md'] From 941bdabbe4561bc6de68981aea01bc5ab05781c5 Mon Sep 17 00:00:00 2001 From: KITAITI Makoto Date: Wed, 4 Feb 2026 20:33:09 +0900 Subject: [PATCH 074/831] ruby : add `Whisper::Context::Params`, fix token memory management (#3647) * Don't convert to temporary VALUE * Define Whisper::Context::Params * Add test for Whisper::Context::Params * Implement Whisper::Context::Params * Add tests for Context::Params * Fix Whisper::Token memory management * Add test for token_timestamps * Make Context accept Context::Params * Make Context::Params.new accept keyword args * Add test for Context::Params.new with keyword args * Add signature of Context::Params * Add example for Whisper::Token * Fix typos * Revert "Don't convert to temporary VALUE" This reverts commit dee66e738491ae742fc981dc6e18ad92f1b05316. * Hold Token#text as Ruby objectd * Don't use pointer for ruby_whisper_context_params.params * Use RUBY_DEFAULT_FREE instead of custom function * Update bindings/ruby/README.md Co-authored-by: Daniel Bevenius * Add document for Whisper::Context::Params --------- Co-authored-by: Daniel Bevenius --- bindings/ruby/README.md | 66 +++++++ bindings/ruby/ext/ruby_whisper.c | 22 ++- bindings/ruby/ext/ruby_whisper.h | 12 +- bindings/ruby/ext/ruby_whisper_context.c | 18 +- .../ruby/ext/ruby_whisper_context_params.c | 163 ++++++++++++++++++ bindings/ruby/ext/ruby_whisper_token.c | 37 +++- bindings/ruby/sig/whisper.rbs | 39 +++++ bindings/ruby/test/test_context_params.rb | 82 +++++++++ bindings/ruby/test/test_token.rb | 11 ++ 9 files changed, 435 insertions(+), 15 deletions(-) create mode 100644 bindings/ruby/ext/ruby_whisper_context_params.c create mode 100644 bindings/ruby/test/test_context_params.rb diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 86774158355..c6280a6926a 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -247,6 +247,58 @@ whisper.transcribe("path/to/audio.wav", params) ``` +### Tokens ### + +Each segment has tokens. + +To enable token timestamps, you need to set `Whisper::Params#token_timestamps = true`. Then, retrieve tokens from segments using `Whisper::Segment#each_token`. + +```ruby +whisper = Whisper::Context.new("base.en") +params = Whisper::Params.new(token_timestamps: true) +whisper + .transcribe("path/to/audio.wav", params) + .each_segment do |segment| + segment.each_token do |token| + token => {start_time:, end_time:, text:, probability:} + st = "%05.2fs" % (start_time / 1000.0) + et = "%05.2fs" % (end_time / 1000.0) + prob = "%.1f%%" % (probability * 100) + puts "[#{st} --> #{et}] #{text} (#{prob})" + end + end +``` + +``` +[00.00s --> 00.00s] [_BEG_] (84.2%) +[00.32s --> 00.37s] And (71.2%) +[00.37s --> 00.53s] so (98.5%) +[00.69s --> 00.85s] my (70.7%) +[00.85s --> 01.59s] fellow (99.5%) +[01.59s --> 02.10s] Americans (90.1%) +[02.85s --> 03.30s] , (28.4%) +[03.30s --> 04.14s] ask (79.8%) +[04.14s --> 04.28s] not (78.9%) +[05.03s --> 05.35s] what (93.3%) +[05.41s --> 05.74s] your (98.8%) +[05.74s --> 06.41s] country (99.6%) +[06.41s --> 06.74s] can (97.7%) +[06.74s --> 06.92s] do (99.0%) +[07.00s --> 07.00s] for (95.8%) +[07.01s --> 07.52s] you (98.5%) +[07.81s --> 08.05s] , (49.3%) +[08.19s --> 08.37s] ask (65.6%) +[08.37s --> 08.75s] what (98.8%) +[08.91s --> 09.04s] you (98.2%) +[09.04s --> 09.32s] can (96.9%) +[09.32s --> 09.38s] do (90.3%) +[09.44s --> 09.76s] for (91.8%) +[09.76s --> 09.99s] your (98.2%) +[10.02s --> 10.36s] country (99.6%) +[10.51s --> 10.99s] . (87.0%) +[11.00s --> 11.00s] [_TT_550] (7.6%) +``` + ### Models ### You can see model information: @@ -342,6 +394,20 @@ whisper .full(Whisper::Params.new, samples) ``` +Custom context params +--------------------- + +You can use customize `Whisper::Context`'s behavior using `Whisper::Context::Params`. + +```ruby +context_params = Whisper::Context::Params.new( + use_gpu: false, + flash_attn: false, + # etc +) +whisper = Whisper::Context.new("base", context_params) +``` + Using VAD separately from ASR ----------------------------- diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index eb95829c032..ba71d4ba594 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -33,7 +33,8 @@ static bool is_log_callback_finalized = false; // High level API extern VALUE ruby_whisper_segment_allocate(VALUE klass); -extern void init_ruby_whisper_context(VALUE *mWhisper); +extern VALUE init_ruby_whisper_context(VALUE *mWhisper); +extern void init_ruby_whisper_context_params(VALUE *cContext); extern void init_ruby_whisper_params(VALUE *mWhisper); extern void init_ruby_whisper_error(VALUE *mWhisper); extern void init_ruby_whisper_segment(VALUE *mWhisper); @@ -162,6 +163,22 @@ void Init_whisper() { rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG)); rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT)); + rb_define_const(mWhisper, "AHEADS_NONE", INT2NUM(WHISPER_AHEADS_NONE)); + rb_define_const(mWhisper, "AHEADS_N_TOP_MOST", INT2NUM(WHISPER_AHEADS_N_TOP_MOST)); + rb_define_const(mWhisper, "AHEADS_CUSTOM", INT2NUM(WHISPER_AHEADS_CUSTOM)); + rb_define_const(mWhisper, "AHEADS_TINY_EN", INT2NUM(WHISPER_AHEADS_TINY_EN)); + rb_define_const(mWhisper, "AHEADS_TINY", INT2NUM(WHISPER_AHEADS_TINY)); + rb_define_const(mWhisper, "AHEADS_BASE_EN", INT2NUM(WHISPER_AHEADS_BASE_EN)); + rb_define_const(mWhisper, "AHEADS_BASE", INT2NUM(WHISPER_AHEADS_BASE)); + rb_define_const(mWhisper, "AHEADS_SMALL_EN", INT2NUM(WHISPER_AHEADS_SMALL_EN)); + rb_define_const(mWhisper, "AHEADS_SMALL", INT2NUM(WHISPER_AHEADS_SMALL)); + rb_define_const(mWhisper, "AHEADS_MEDIUM_EN", INT2NUM(WHISPER_AHEADS_MEDIUM_EN)); + rb_define_const(mWhisper, "AHEADS_MEDIUM", INT2NUM(WHISPER_AHEADS_MEDIUM)); + rb_define_const(mWhisper, "AHEADS_LARGE_V1", INT2NUM(WHISPER_AHEADS_LARGE_V1)); + rb_define_const(mWhisper, "AHEADS_LARGE_V2", INT2NUM(WHISPER_AHEADS_LARGE_V2)); + rb_define_const(mWhisper, "AHEADS_LARGE_V3", INT2NUM(WHISPER_AHEADS_LARGE_V3)); + rb_define_const(mWhisper, "AHEADS_LARGE_V3_TURBO", INT2NUM(WHISPER_AHEADS_LARGE_V3_TURBO)); + rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0); rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1); rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1); @@ -170,7 +187,8 @@ void Init_whisper() { rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2); rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1); - init_ruby_whisper_context(&mWhisper); + cContext = init_ruby_whisper_context(&mWhisper); + init_ruby_whisper_context_params(&cContext); init_ruby_whisper_params(&mWhisper); init_ruby_whisper_error(&mWhisper); init_ruby_whisper_segment(&mWhisper); diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index c2c9866ae0d..8dfd103c17a 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -16,6 +16,10 @@ typedef struct { struct whisper_context *context; } ruby_whisper; +typedef struct ruby_whisper_context_params { + struct whisper_context_params params; +} ruby_whisper_context_params; + typedef struct { struct whisper_full_params params; bool diarize; @@ -37,7 +41,7 @@ typedef struct { typedef struct { whisper_token_data *token_data; - const char *text; + VALUE text; } ruby_whisper_token; typedef struct { @@ -71,7 +75,11 @@ typedef struct parsed_samples_t { } \ } while (0) -#define GetToken(obj, rwt) do { \ +#define GetContextParams(obj, rwcp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_context_params, &ruby_whisper_context_params_type, (rwcp)); \ +} while (0) + +#define GetToken(obj, rwt) do { \ TypedData_Get_Struct((obj), ruby_whisper_token, &ruby_whisper_token_type, (rwt)); \ if ((rwt)->token_data == NULL) { \ rb_raise(rb_eRuntimeError, "Not initialized"); \ diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index 84790e3dedf..a8118d12773 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -18,6 +18,7 @@ extern VALUE eError; extern VALUE cModel; extern const rb_data_type_t ruby_whisper_params_type; +extern const rb_data_type_t ruby_whisper_context_params_type; extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self); extern VALUE rb_whisper_model_s_new(VALUE context); extern VALUE rb_whisper_segment_s_new(VALUE context, int index); @@ -143,16 +144,25 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { ruby_whisper *rw; VALUE whisper_model_file_path; + VALUE context_params; + struct whisper_context_params params; // TODO: we can support init from buffer here too maybe another ruby object to expose - rb_scan_args(argc, argv, "01", &whisper_model_file_path); + rb_scan_args(argc, argv, "11", &whisper_model_file_path, &context_params); TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path); if (!rb_respond_to(whisper_model_file_path, id_to_s)) { rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); } - rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params()); + if (NIL_P(context_params)) { + params = whisper_context_default_params(); + } else { + ruby_whisper_context_params *rwcp; + GetContextParams(context_params, rwcp); + params = rwcp->params; + } + rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), params); if (rw->context == NULL) { rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context"); } @@ -711,7 +721,7 @@ ruby_whisper_get_model(VALUE self) return rb_whisper_model_s_new(self); } -void +VALUE init_ruby_whisper_context(VALUE *mWhisper) { cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject); @@ -749,4 +759,6 @@ init_ruby_whisper_context(VALUE *mWhisper) rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0); rb_define_method(cContext, "model", ruby_whisper_get_model, 0); + + return cContext; } diff --git a/bindings/ruby/ext/ruby_whisper_context_params.c b/bindings/ruby/ext/ruby_whisper_context_params.c new file mode 100644 index 00000000000..87df21d4b5e --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_context_params.c @@ -0,0 +1,163 @@ +#include "ruby_whisper.h" + +#define NUM_PARAMS 6 + +#define DEF_BOOLEAN_ATTR_METHOD(name) \ +static VALUE \ +ruby_whisper_context_params_get_ ## name(VALUE self) { \ + ruby_whisper_context_params *rwcp; \ + GetContextParams(self, rwcp); \ + return rwcp->params.name ? Qtrue : Qfalse; \ +} \ +static VALUE \ +ruby_whisper_context_params_set_ ## name(VALUE self, VALUE value) { \ + ruby_whisper_context_params *rwcp; \ + GetContextParams(self, rwcp); \ + rwcp->params.name = RTEST(value); \ + return value; \ +} + +#define DEF_INT_ATTR_METHOD(name) \ +static VALUE \ +ruby_whisper_context_params_get_ ## name(VALUE self) { \ + ruby_whisper_context_params *rwcp; \ + GetContextParams(self, rwcp); \ + return INT2NUM(rwcp->params.name); \ +} \ +static VALUE \ +ruby_whisper_context_params_set_ ## name(VALUE self, VALUE value) { \ + ruby_whisper_context_params *rwcp; \ + GetContextParams(self, rwcp); \ + rwcp->params.name = NUM2INT(value); \ + return value; \ +} + +#define DEFINE_PARAM(param_name, nth) \ + id_ ## param_name = rb_intern(#param_name); \ + param_names[nth] = id_ ## param_name; \ + rb_define_method(cContextParams, #param_name, ruby_whisper_context_params_get_ ## param_name, 0); \ + rb_define_method(cContextParams, #param_name "=", ruby_whisper_context_params_set_ ## param_name, 1); + +VALUE cContextParams; + +static ID param_names[NUM_PARAMS]; +static ID id_use_gpu; +static ID id_flash_attn; +static ID id_gpu_device; +static ID id_dtw_token_timestamps; +static ID id_dtw_aheads_preset; +static ID id_dtw_n_top; + +static size_t +ruby_whisper_context_params_memsize(const void *p) +{ + const ruby_whisper_context_params *rwcp = (ruby_whisper_context_params *)p; + if (!rwcp) { + return 0; + } + return sizeof(ruby_whisper_context_params); +} + +const rb_data_type_t ruby_whisper_context_params_type = { + "ruby_whisper_context_params", + {0, RUBY_DEFAULT_FREE, ruby_whisper_context_params_memsize,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_context_params_s_allocate(VALUE klass) +{ + ruby_whisper_context_params *rwcp; + return TypedData_Make_Struct(klass, ruby_whisper_context_params, &ruby_whisper_context_params_type, rwcp); +} + +DEF_BOOLEAN_ATTR_METHOD(use_gpu); +DEF_BOOLEAN_ATTR_METHOD(flash_attn); +DEF_INT_ATTR_METHOD(gpu_device); +DEF_BOOLEAN_ATTR_METHOD(dtw_token_timestamps); +DEF_INT_ATTR_METHOD(dtw_aheads_preset); + +static VALUE +ruby_whisper_context_params_get_dtw_n_top(VALUE self) { + ruby_whisper_context_params *rwcp; + GetContextParams(self, rwcp); + + int dtw_n_top = rwcp->params.dtw_n_top; + + return dtw_n_top == -1 ? Qnil : INT2NUM(dtw_n_top); +} + +static VALUE +ruby_whisper_context_params_set_dtw_n_top(VALUE self, VALUE value) { + ruby_whisper_context_params *rwcp; + GetContextParams(self, rwcp); + + rwcp->params.dtw_n_top = NIL_P(value) ? -1 : NUM2INT(value); + + return value; +} + +#define SET_PARAM_IF_SAME(param_name) \ + if (id == id_ ## param_name) { \ + ruby_whisper_context_params_set_ ## param_name(self, value); \ + continue; \ + } + +static VALUE +ruby_whisper_context_params_initialize(int argc, VALUE *argv, VALUE self) +{ + ruby_whisper_context_params *rwcp; + TypedData_Get_Struct(self, ruby_whisper_context_params, &ruby_whisper_context_params_type, rwcp); + rwcp->params = whisper_context_default_params(); + + VALUE kw_hash; + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return Qnil; + } + + VALUE values[NUM_PARAMS] = {Qundef}; + rb_get_kwargs(kw_hash, param_names, 0, NUM_PARAMS, values); + + ID id; + VALUE value; + for (int i = 0; i < NUM_PARAMS; i++) { + id = param_names[i]; + value = values[i]; + if (value == Qundef) { + continue; + } + SET_PARAM_IF_SAME(use_gpu) + SET_PARAM_IF_SAME(flash_attn) + SET_PARAM_IF_SAME(gpu_device) + SET_PARAM_IF_SAME(dtw_token_timestamps) + SET_PARAM_IF_SAME(dtw_aheads_preset) + SET_PARAM_IF_SAME(dtw_n_top) + } + + return Qnil; +} + +#undef SET_PARAM_IF_SAME + +void +init_ruby_whisper_context_params(VALUE *cContext) +{ + cContextParams = rb_define_class_under(*cContext, "Params", rb_cObject); + + rb_define_alloc_func(cContextParams, ruby_whisper_context_params_s_allocate); + rb_define_method(cContextParams, "initialize", ruby_whisper_context_params_initialize, -1); + + DEFINE_PARAM(use_gpu, 0) + DEFINE_PARAM(flash_attn, 1) + DEFINE_PARAM(gpu_device, 2) + DEFINE_PARAM(dtw_token_timestamps, 3) + DEFINE_PARAM(dtw_aheads_preset, 4) + DEFINE_PARAM(dtw_n_top, 5) +} + +#undef DEFINE_PARAM +#undef DEF_INT_ATTR_METHOD +#undef DEF_BOOLEAN_ATTR_METHOD +#undef NUM_PARAMS diff --git a/bindings/ruby/ext/ruby_whisper_token.c b/bindings/ruby/ext/ruby_whisper_token.c index 56a7eab2231..73f5a547daf 100644 --- a/bindings/ruby/ext/ruby_whisper_token.c +++ b/bindings/ruby/ext/ruby_whisper_token.c @@ -24,12 +24,34 @@ ruby_whisper_token_memsize(const void *p) if (!rwt) { return 0; } - return sizeof(rwt); + size_t size = sizeof(*rwt); + if (rwt->token_data) { + size += sizeof(*rwt->token_data); + } + return size; +} + +static void +ruby_whisper_token_mark(void *p) +{ + ruby_whisper_token *rwt = (ruby_whisper_token *)p; + rb_gc_mark(rwt->text); +} + +static void +ruby_whisper_token_free(void *p) +{ + ruby_whisper_token *rwt = (ruby_whisper_token *)p; + if (rwt->token_data) { + xfree(rwt->token_data); + rwt->token_data = NULL; + } + xfree(rwt); } static const rb_data_type_t ruby_whisper_token_type = { "ruby_whisper_token", - {0, RUBY_DEFAULT_FREE, ruby_whisper_token_memsize,}, + {ruby_whisper_token_mark, ruby_whisper_token_free, ruby_whisper_token_memsize,}, 0, 0, 0 }; @@ -40,19 +62,19 @@ ruby_whisper_token_allocate(VALUE klass) ruby_whisper_token *rwt; VALUE token = TypedData_Make_Struct(klass, ruby_whisper_token, &ruby_whisper_token_type, rwt); rwt->token_data = NULL; - rwt->text = NULL; + rwt->text = Qnil; return token; } VALUE ruby_whisper_token_s_init(struct whisper_context *context, int i_segment, int i_token) { - whisper_token_data token_data = whisper_full_get_token_data(context, i_segment, i_token); const VALUE token = ruby_whisper_token_allocate(cToken); ruby_whisper_token *rwt; TypedData_Get_Struct(token, ruby_whisper_token, &ruby_whisper_token_type, rwt); - rwt->token_data = &token_data; - rwt->text = whisper_full_get_token_text(context, i_segment, i_token); + rwt->token_data = ALLOC(whisper_token_data); + *(rwt->token_data) = whisper_full_get_token_data(context, i_segment, i_token); + rwt->text = rb_str_new2(whisper_full_get_token_text(context, i_segment, i_token)); return token; } @@ -182,10 +204,9 @@ ruby_whisper_token_get_text(VALUE self) { ruby_whisper_token *rwt; GetToken(self, rwt); - return rb_str_new2(rwt->text); + return rwt->text; } - /* * Start time of the token. * diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 0e7b2c276e8..9ade451c6b2 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -17,6 +17,21 @@ module Whisper LOG_LEVEL_ERROR: Integer LOG_LEVEL_DEBUG: Integer LOG_LEVEL_CONT: Integer + AHEADS_NONE: Integer + AHEADS_N_TOP_MOST: Integer + AHEADS_CUSTOM: Integer + AHEADS_TINY_EN: Integer + AHEADS_TINY: Integer + AHEADS_BASE_EN: Integer + AHEADS_BASE: Integer + AHEADS_SMALL_EN: Integer + AHEADS_SMALL: Integer + AHEADS_MEDIUM_EN: Integer + AHEADS_MEDIUM: Integer + AHEADS_LARGE_V1: Integer + AHEADS_LARGE_V2: Integer + AHEADS_LARGE_V3: Integer + AHEADS_LARGE_V3_TURBO: Integer def self.lang_max_id: () -> Integer def self.lang_id: (string name) -> Integer @@ -120,6 +135,30 @@ module Whisper def to_srt: () -> String def to_webvtt: () -> String + + class Params + def self.new: ( + use_gpu: boolish, + flash_attn: boolish, + gpu_device: Integer, + dtw_token_timestamps: boolish, + dtw_aheads_preset: Integer, + dtw_n_top: Integer | nil, + ) -> instance + + def use_gpu=: (boolish) -> boolish + def use_gpu: () -> (true | false) + def flash_attn=: (boolish) -> boolish + def flash_attn: () -> (true | false) + def gpu_device=: (Integer) -> Integer + def gpu_device: () -> Integer + def dtw_token_timestamps=: (boolish) -> boolish + def dtw_token_timestamps: () -> (true | false) + def dtw_aheads_preset=: (Integer) -> Integer + def dtw_aheads_preset: () -> Integer + def dtw_n_top=: (Integer | nil) -> (Integer | nil) + def dtw_n_top: () -> (Integer | nil) + end end class Params diff --git a/bindings/ruby/test/test_context_params.rb b/bindings/ruby/test/test_context_params.rb new file mode 100644 index 00000000000..8d19fdc94cb --- /dev/null +++ b/bindings/ruby/test/test_context_params.rb @@ -0,0 +1,82 @@ +require_relative "helper" + +class TestContextParams < TestBase + PARAM_NAMES = [ + :use_gpu, + :flash_attn, + :gpu_device, + :dtw_token_timestamps, + :dtw_aheads_preset, + :dtw_n_top + ] + + def test_new + params = Whisper::Context::Params.new + assert_instance_of Whisper::Context::Params, params + end + + def test_attributes + params = Whisper::Context::Params.new + + assert_true params.use_gpu + params.use_gpu = false + assert_false params.use_gpu + + assert_true params.flash_attn + params.flash_attn = false + assert_false params.flash_attn + + assert_equal 0, params.gpu_device + params.gpu_device = 1 + assert_equal 1, params.gpu_device + + assert_false params.dtw_token_timestamps + params.dtw_token_timestamps = true + assert_true params.dtw_token_timestamps + + assert_equal Whisper::AHEADS_NONE, params.dtw_aheads_preset + params.dtw_aheads_preset =Whisper::AHEADS_BASE + assert_equal Whisper::AHEADS_BASE, params.dtw_aheads_preset + + assert_nil params.dtw_n_top + params.dtw_n_top = 6 + assert_equal 6, params.dtw_n_top + params.dtw_n_top = nil + assert_nil params.dtw_n_top + end + + def test_new_with_kw_args + params = Whisper::Context::Params.new(use_gpu: false) + assert_false params.use_gpu + end + + def test_new_with_kw_wargs_non_existent + assert_raise ArgumentError do + Whisper::Context::Params.new(non_existent: "value") + end + end + + data(PARAM_NAMES.collect {|param| [param, param]}.to_h) + def test_new_with_kw_args_default_values(param) + default_params = Whisper::Context::Params.new + default_value = default_params.send(param) + value = if param == :dtw_n_top + 6 + else + case default_value + in true | false + !default_value + in Integer + default_value + 1 + end + end + params = Whisper::Context::Params.new(param => value) + assert_equal value, params.send(param) + + PARAM_NAMES.reject {|name| name == param}.each do |name| + expected = default_params.send(name) + actual = params.send(name) + assert_equal expected, actual + end + end +end diff --git a/bindings/ruby/test/test_token.rb b/bindings/ruby/test/test_token.rb index e5834b1b480..a23f6813675 100644 --- a/bindings/ruby/test/test_token.rb +++ b/bindings/ruby/test/test_token.rb @@ -56,6 +56,17 @@ def test_text @segment.each_token.collect(&:text) end + def test_token_timestamps + params = Whisper::Params.new(token_timestamps: true) + whisper.transcribe(TestBase::AUDIO, params) + prev = -1 + whisper.each_segment.first.each_token do |token| + assert token.start_time >= prev + assert token.end_time >= token.start_time + prev = token.end_time + end + end + def test_deconstruct_keys_with_nil keys = %i[id tid probability log_probability pt ptsum t_dtw voice_length start_time end_time text] expected = keys.collect {|key| [key, @token.send(key)] }.to_h From fc1a3e579e21b0b59a4ef049d9a59c131f1ce3ec Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 30 Jan 2026 16:29:51 +0200 Subject: [PATCH 075/831] cmake : remove unused file (ggml/1419) --- ggml/cmake/BuildTypes.cmake | 54 ------------------------------------- 1 file changed, 54 deletions(-) delete mode 100644 ggml/cmake/BuildTypes.cmake diff --git a/ggml/cmake/BuildTypes.cmake b/ggml/cmake/BuildTypes.cmake deleted file mode 100644 index a9c7b6c91ec..00000000000 --- a/ggml/cmake/BuildTypes.cmake +++ /dev/null @@ -1,54 +0,0 @@ -# Add new build types - -# ReleaseGG - Release with enabled asserts - -SET(CMAKE_CXX_FLAGS_RELEASEGG - "-O3" - CACHE STRING "Flags used by the c++ compiler during release builds with enabled asserts." - FORCE ) -SET(CMAKE_C_FLAGS_RELEASEGG - "-O3" - CACHE STRING "Flags used by the compiler during release builds with enabled asserts." - FORCE ) -SET(CMAKE_EXE_LINKER_FLAGS_RELEASEGG - "" - CACHE STRING "Flags used for linking binaries during release builds with enabled asserts." - FORCE ) -SET(CMAKE_SHARED_LINKER_FLAGS_RELEASEGG - "" - CACHE STRING "Flags used by the shared libraries linker during release builds with enabled asserts." - FORCE ) -MARK_AS_ADVANCED( - CMAKE_CXX_FLAGS_RELEASEGG - CMAKE_C_FLAGS_RELEASEGG - CMAKE_EXE_LINKER_FLAGS_RELEASEGG - CMAKE_SHARED_LINKER_FLAGS_RELEASEGG ) - -# RelWithDebInfoGG - RelWithDebInfo with enabled asserts - -SET(CMAKE_CXX_FLAGS_RELWITHDEBINFOGG - "-O2 -g" - CACHE STRING "Flags used by the c++ compiler during release builds with debug symbols and enabled asserts." - FORCE ) -SET(CMAKE_C_FLAGS_RELWITHDEBINFOGG - "-O2 -g" - CACHE STRING "Flags used by the compiler during release builds with debug symbols and enabled asserts." - FORCE ) -SET(CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG - "" - CACHE STRING "Flags used for linking binaries during release builds with debug symbols and enabled asserts." - FORCE ) -SET(CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG - "" - CACHE STRING "Flags used by the shared libraries linker during release builds with debug symbols and enabled asserts." - FORCE ) -MARK_AS_ADVANCED( - CMAKE_CXX_FLAGS_RELWITHDEBINFOGG - CMAKE_C_FLAGS_RELWITHDEBINFOGG - CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG - CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG ) - -if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) - set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) - set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo" "ReleaseGG" "RelWithDebInfoGG") -endif() From 06e37504073aa0f301bfa9a7c6ff246625a0f22f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Feb 2026 09:58:02 +0200 Subject: [PATCH 076/831] ggml : bump version to 0.9.6 (ggml/1423) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index b0b8e57898c..590242e3f01 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 5) +set(GGML_VERSION_PATCH 6) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) From efd6344939a03c0e5fd41220856055040d0712fd Mon Sep 17 00:00:00 2001 From: Simon Redman Date: Fri, 30 Jan 2026 11:27:16 -0500 Subject: [PATCH 077/831] Correctly fetch q8_1 quantize pipeline in test as needed by 8a3519b (llama/19194) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3852867c291..a99375c0885 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -11956,7 +11956,8 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, } } if (mmq) { - ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it); + vk_pipeline pipeline_quantize_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline_quantize_q8_1, num_it); } ggml_pipeline_allocate_descriptor_sets(ctx); From db9c88744de9e7fa775284929fa03f32e6e813ec Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Fri, 30 Jan 2026 10:19:27 -0800 Subject: [PATCH 078/831] opencl: add optimized q8_0 mm kernel for adreno (llama/18871) * Add Q8_0 OpenCL kernel Co-authored-by: yunjie * opencl: fix build for non-adreno * opencl: refactor q8_0 * opencl: enforce subgroup size of 64 for adreno for q8_0 * For A750 and older generations, subgroup size can be 64 or 128. This kernel assumes subgroup size 64. * opencl: suppress warning when adreno kernels are disabled --------- Co-authored-by: yunjie Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 464 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 31 ++ .../gemv_noshuffle_general_q8_0_f32.cl | 195 ++++++++ .../kernels/mul_mm_q8_0_f32_8x4.cl | 129 +++++ 5 files changed, 819 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 0259474b6e1..fa5fadd112b 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -101,6 +101,8 @@ set(GGML_OPENCL_KERNELS mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm mul_mm_q8_0_f32_l4_lm + mul_mm_q8_0_f32_8x4 + gemv_noshuffle_general_q8_0_f32 mul norm relu diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 678e40965ad..4850c11d147 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -226,7 +226,8 @@ static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) { return ADRENO_GPU_GEN::A7X; } - if (strstr(device_name, "830")) { + if (strstr(device_name, "830") || + strstr(device_name, "840")) { return ADRENO_GPU_GEN::A8X; } @@ -529,7 +530,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; - cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0; + cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_restore_block_q4_0_noshuffle; @@ -696,6 +697,8 @@ struct ggml_backend_opencl_context { cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; + cl_kernel kernel_mul_mm_q8_0_f32_8x4; + cl_kernel CL_mul_mat_vec_q8_0_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS void free() { @@ -894,6 +897,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); GGML_LOG_CONT("."); @@ -2290,6 +2294,46 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q8_0_f32_8x4 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_q8_8x4_gemm { + #include "mul_mm_q8_0_f32_8x4.cl.h" + }; +#else + const std::string kernel_src_q8_8x4_gemm = read_file("mul_mm_q8_0_f32_8x4.cl"); +#endif + backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_q8_8x4_gemm.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mm_q8_0_f32_8x4", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_general_q8_0_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv_general { + #include "gemv_noshuffle_general_q8_0_f32.cl.h" + }; +#else + const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_general_q8_0_f32.cl"); +#endif + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->CL_mul_mat_vec_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std + " -cl-mad-enable " " -cl-fast-relaxed-math"; @@ -3745,6 +3789,15 @@ inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ct return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); } +inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + + bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor); + + size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; + + return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 +} + static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); @@ -4159,6 +4212,130 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; + // Transpose the weights and scales +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (enable_adreno_trans_weight(backend_ctx, tensor)) { + + int M = tensor->ne[1]; // ne01 + int K = tensor->ne[0]; // ne00 + + GGML_ASSERT(K % 32 == 0); + GGML_ASSERT(M % 4 == 0); + GGML_ASSERT(tensor->ne[2] == 1); + GGML_ASSERT(tensor->ne[3] == 1); + + // Transpose weights + size_t q_size_bytes = K * M / 4 * sizeof(float); + cl_buffer_region region; + region.origin = 0; + region.size = q_size_bytes; + cl_mem qT_d = clCreateSubBuffer( + backend_ctx->prealloc_quant_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &err); + CL_CHECK(err); + + cl_mem q_d_image1D; + cl_mem qT_d_image1D; + + cl_image_format img_fmt_1d; + cl_image_desc img_desc_1d; + + img_fmt_1d = { CL_RGBA, CL_FLOAT }; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 4 / 4; + img_desc_1d.buffer = extra->q; + q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + + img_fmt_1d = { CL_RGBA, CL_FLOAT }; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 4 / 4; + img_desc_1d.buffer = qT_d; + qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + + int height_q = M / 4; + int width_q = K / 4 / 4; + kernel = backend_ctx->kernel_transpose_32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_q)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_q)); + + size_t local_size_q[3] = {4, 16, 1}; + size_t global_size_q[3] = {static_cast(width_q), static_cast(height_q), 1}; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + + // Transpose scales + size_t d_size_bytes = M * (K / 32) * 2; + region.origin = 0; + region.size = d_size_bytes; + cl_mem dT_d = clCreateSubBuffer( + backend_ctx->prealloc_scales_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &err); + CL_CHECK(err); + + cl_mem d_d_image1D; + cl_mem dT_d_image1D; + + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_fmt_1d = { CL_R, CL_HALF_FLOAT }; + img_desc_1d.image_width = M * K / 32; + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.buffer = extra->d; + d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + + img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 32 / 4; + img_desc_1d.buffer = dT_d; + dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + CL_CHECK(err); + + int height_s = M / 4; + int width_s = K / 32; + + kernel = backend_ctx->kernel_transpose_16_4x1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s)); + + size_t local_size_s[3] = {4, 16, 1}; + size_t global_size_s[3] = {static_cast(width_s), static_cast(height_s), 1}; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + + // copy transposed buffer contents to original buffers + CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + + CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + + CL_CHECK(clReleaseMemObject(qT_d)); + CL_CHECK(clReleaseMemObject(dT_d)); + + CL_CHECK(clReleaseMemObject(q_d_image1D)); + CL_CHECK(clReleaseMemObject(d_d_image1D)); + CL_CHECK(clReleaseMemObject(qT_d_image1D)); + CL_CHECK(clReleaseMemObject(dT_d_image1D)); + } // end transpose +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + return; } if (tensor->type == GGML_TYPE_Q6_K) { @@ -4448,6 +4625,36 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_nbytes(tensor), NULL, &err); CL_CHECK(err); +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (enable_adreno_trans_weight(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0_trans; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + GGML_ASSERT(tensor->ne[2] == 1); // ??? + GGML_ASSERT(tensor->ne[3] == 1); // ??? + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), 1, 1}; + size_t local_work_size[3] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); @@ -7947,6 +8154,252 @@ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_ten CL_CHECK(clReleaseMemObject(D_sub_buffer)); } +static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const enum ggml_type src0t = src0->type; + const enum ggml_type src1t = src1->type; + + GGML_ASSERT(src0t == GGML_TYPE_Q8_0); + GGML_ASSERT(src1t == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + + GGML_ASSERT(src1->view_offs == 0); + GGML_ASSERT(dst->view_offs == 0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + + const int ne10 = src1->ne[0]; + const int ne12 = src1->ne[2]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 == ne10); + GGML_ASSERT((ne00 % 32) == 0); + GGML_ASSERT(ne0 == ne01); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + // init CL objects + cl_int status; + cl_image_format img_fmt_1d; + cl_image_desc img_desc_1d; + cl_buffer_region region; + cl_mem A_image1d; + cl_mem B_image1d; + cl_mem B_sub_buffer; + cl_mem S_image1d; + + cl_mem D_image1d; + cl_mem D_sub_buffer; + + int M = ne01; + int N = ne1; + int K = ne00; + + // create an image for A + img_fmt_1d = { CL_R, CL_FLOAT}; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 4; // Divide by 4 for char -> float + img_desc_1d.buffer = extra0_q8_0->q; + A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); + CL_CHECK(status); + + // create an image for Scale + img_fmt_1d = { CL_R, CL_HALF_FLOAT}; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * K / 32; // Block size is 32 + img_desc_1d.buffer = extra0_q8_0->d; + S_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); + CL_CHECK(status); + + // create a sub_buffer for B + region.origin = (extra1->offset); // + src1->view_offs); + region.size = K * N * sizeof(float); + B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create an image for B from sub_buffer: RGBA (OCL) + img_fmt_1d = {CL_RGBA, CL_FLOAT}; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = K * N / 4; + img_desc_1d.buffer = B_sub_buffer; + B_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); + CL_CHECK(status); + + // Create subbuffer and image1d_buffer for dst + region.origin = (extrad->offset); // + dst->view_offs; + region.size = M * N * sizeof(float); + D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + img_fmt_1d = {CL_R, CL_FLOAT}; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = M * N; + img_desc_1d.buffer = D_sub_buffer; + D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); + CL_CHECK(status); + + size_t local_work_size[3] = {1, 1, 1}; + size_t global_work_size[3] = {1, 1, 1}; + + if (N == 1) { + kernel = backend_ctx->CL_mul_mat_vec_q8_0_f32; + + int r2 = 1; + int r3 = 1; + cl_uint k_arg = 0; + + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extra0_q8_0->d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extra1->offset)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extrad->offset)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r3)); + + size_t wavesize = backend_ctx->adreno_wave_size; + local_work_size[0] = wavesize; + local_work_size[1] = 4; // reduce factor + local_work_size[2] = 1; + + global_work_size[0] = ((M + wavesize - 1) / wavesize) * wavesize; + global_work_size[1] = 4; // reduce factor + global_work_size[2] = 1; + } else { + cl_ulong offsetd = extrad->offset + dst->view_offs; + cl_mem B_image1d_trans = nullptr; + // for B transpose + cl_mem B_d = nullptr; + int padding; + + //how many extra elements beyond multiple of 8 + int extra_elements = N % 8; + + //how much padding to add + padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // Specify the starting offset (in bytes) + region.origin = 0; + // Specify the size of the sub-buffer (divide by 2 for FP16) + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + B_d = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + + cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16) + cl_image_desc image_desc_B_d_output = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(K * (N + padding)/4), + 0, 0, 0, 0, 0, 0, 0, { B_d } + }; + B_image1d_trans = clCreateImage( + context, + 0, + &image_format_B_d_output, + &image_desc_B_d_output, + NULL, + &status); + CL_CHECK(status); + + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_image1d)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_size_t[2] = { 1, 16 }; + size_t global_size_t[2] = { + static_cast(width_B), + static_cast(padded_height_B) + }; + + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst); + + kernel = backend_ctx->kernel_mul_mm_q8_0_f32_8x4; + + int N_with_padding = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d_trans)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &K)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &M)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &N_with_padding)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &N)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd)); + + global_work_size[0] = (size_t)(N + 7) / 8; + global_work_size[1] = (size_t)(M + 3) / 4; + global_work_size[2] = 1; + + local_work_size[0] = 2; + local_work_size[1] = 128; + local_work_size[2] = 1; + } + + // enqueue kernel with profiling + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(A_image1d)); + CL_CHECK(clReleaseMemObject(B_sub_buffer)); + CL_CHECK(clReleaseMemObject(B_image1d)); + CL_CHECK(clReleaseMemObject(S_image1d)); + CL_CHECK(clReleaseMemObject(D_sub_buffer)); + CL_CHECK(clReleaseMemObject(D_image1d)); +#else + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -8064,6 +8517,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co int padding; // <--------------------------------------------> // + // q8_0 x fp32 + if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 && + enable_adreno_trans_weight(backend_ctx, src0)) { + ggml_cl_mul_mat_q8_0_f32_adreno(backend, src0, src1, dst); + return; + } + // q4_0 x fp32 if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { // TODO: remove duplicate definitions of image description + format -- move to top diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index adf576a8394..9fb434713df 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -274,6 +274,37 @@ kernel void kernel_restore_block_q8_0( } } +kernel void kernel_restore_block_q8_0_trans( + global uchar * src_q, + global half * src_d, + global block_q8_0 * dst, + uint ne00, + uint ne01 +){ + uint num_blk_per_row = ne00 / QK8_0; + + global block_q8_0 * b = (global block_q8_0 *) dst + get_global_id(0) * num_blk_per_row; + global uchar * q = (global uchar *) src_q + get_global_id(0) * 4; // 4 8-bit packed + global half * d = (global half *) src_d + get_global_id(0); + + for (uint blk = 0; blk < num_blk_per_row; blk++) { + b->d = *d; + + for (uint i = 0; i < QK8_0; i+=4) { + b->qs[i] = q[0]; + b->qs[i+1] = q[1]; + b->qs[i+2] = q[2]; + b->qs[i+3] = q[3]; + + q += 4 * ne01; // M stride + } + + d += ne01; + + b++; + } +} + //------------------------------------------------------------------------------ // kernel_convert_block_q6_K // Convert the block_q6_K format to 3 separate arrays (AOS -> SOA). diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl new file mode 100644 index 00000000000..f944ef3a992 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl @@ -0,0 +1,195 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK8_0 32 +#define N_SIMDGROUP 4 + +#define dequantizeBlockAccum_ns_sgbroadcast_1(total_sums, bits8, scale, y) \ + float shared_y; \ + char elem; \ + \ + shared_y = sub_group_broadcast(y.s0, 0); \ + elem = (char)(bits8.s0 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + elem = (char)((bits8.s0 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + elem = (char)((bits8.s0 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + elem = (char)((bits8.s0 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s4, 0); \ + elem = (char)(bits8.s1 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + elem = (char)((bits8.s1 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + elem = (char)((bits8.s1 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + elem = (char)((bits8.s1 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s0, 1); \ + elem = (char)(bits8.s2 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + elem = (char)((bits8.s2 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + elem = (char)((bits8.s2 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + elem = (char)((bits8.s2 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s4, 1); \ + elem = (char)(bits8.s3 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + elem = (char)((bits8.s3 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + elem = (char)((bits8.s3 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + elem = (char)((bits8.s3 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s0, 2); \ + elem = (char)(bits8.s4 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + elem = (char)((bits8.s4 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + elem = (char)((bits8.s4 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + elem = (char)((bits8.s4 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s4, 2); \ + elem = (char)(bits8.s5 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + elem = (char)((bits8.s5 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + elem = (char)((bits8.s5 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + elem = (char)((bits8.s5 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s0, 3); \ + elem = (char)(bits8.s6 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + elem = (char)((bits8.s6 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + elem = (char)((bits8.s6 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + elem = (char)((bits8.s6 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s4, 3); \ + elem = (char)(bits8.s7 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + elem = (char)((bits8.s7 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + elem = (char)((bits8.s7 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + elem = (char)((bits8.s7 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +__kernel void kernel_gemv_noshuffle( + __read_only image1d_buffer_t src0_q, // quantized A + global half * src0_d, // A scales + __read_only image1d_buffer_t src1, // B + ulong offset1, // offset to B (0) + global float * dst, // C + ulong offsetd, // offset to C + int ne00, // K + int ne01, // M + int ne02, // 1 + int ne10, // K + int ne12, // 1 + int ne0, // M + int ne1, // N + int r2, // 1 + int r3) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M; + uint BLOCK_STRIDE_A = 8 * M; // 32 / 4 = 8 + + __private uint8 regA; + __private half regS; + __private float8 regB; + + __private float totalSum = (float)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + #pragma unroll 1 /* tell compiler not to unroll */ + for (uint k = groupId; k < (K / QK8_0); k += N_SIMDGROUP) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of one rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load weights for one block in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; + regA.s4 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s5 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; + + dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB); + } + + // reduction in local memory, assumes #wave=4 + __local float reduceLM[SIMDGROUP_WIDTH * 3]; + if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum; + if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum; + if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum; + barrier(CLK_LOCAL_MEM_FENCE); + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + dst[gid] = totalSum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl new file mode 100644 index 00000000000..51ce2121ce2 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl @@ -0,0 +1,129 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_mul_mm_q8_0_f32_8x4( + global const uint * src0_q, + global const half * src0_d, + __read_only image1d_buffer_t src1, + global float * dst, + int k, + int m, + int n, + int n_no_padding, + ulong offsetd +) { + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + dst = (global float *)((global char*)dst + offsetd); + + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 deq; + + __global const uint* wptr = src0_q + gx_2; + __global const half* sptr = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + uint4 pack4 = vload4(0, wptr + (i / 4) * m); + half4 scale = vload4(0, sptr + (i / 32) * m); + + char4 p0 = as_char4(pack4.s0); + char4 p1 = as_char4(pack4.s1); + char4 p2 = as_char4(pack4.s2); + char4 p3 = as_char4(pack4.s3); + + // ------------------- j = 0 (k = i+0) ------------------- + B.s0123 = read_imageh(src1, gy * 2 + (i + 0) * n_4); + B.s4567 = read_imageh(src1, gy * 2 + (i + 0) * n_4 + 1); + + half4 wj0 = convert_half4((char4)(p0.s0, p1.s0, p2.s0, p3.s0)) * scale; + + c0 += B * wj0.s0; + c1 += B * wj0.s1; + c2 += B * wj0.s2; + c3 += B * wj0.s3; + + // ------------------- j = 1 (k = i+1) ------------------- + B.s0123 = read_imageh(src1, gy * 2 + (i + 1) * n_4); + B.s4567 = read_imageh(src1, gy * 2 + (i + 1) * n_4 + 1); + + half4 wj1 = convert_half4((char4)(p0.s1, p1.s1, p2.s1, p3.s1)) * scale; + + c0 += B * wj1.s0; + c1 += B * wj1.s1; + c2 += B * wj1.s2; + c3 += B * wj1.s3; + + // ------------------- j = 2 (k = i+2) ------------------- + B.s0123 = read_imageh(src1, gy * 2 + (i + 2) * n_4); + B.s4567 = read_imageh(src1, gy * 2 + (i + 2) * n_4 + 1); + + half4 wj2 = convert_half4((char4)(p0.s2, p1.s2, p2.s2, p3.s2)) * scale; + + c0 += B * wj2.s0; + c1 += B * wj2.s1; + c2 += B * wj2.s2; + c3 += B * wj2.s3; + + // ------------------- j = 3 (k = i+3) ------------------- + B.s0123 = read_imageh(src1, gy * 2 + (i + 3) * n_4); + B.s4567 = read_imageh(src1, gy * 2 + (i + 3) * n_4 + 1); + + half4 wj3 = convert_half4((char4)(p0.s3, p1.s3, p2.s3, p3.s3)) * scale; + + c0 += B * wj3.s0; + c1 += B * wj3.s1; + c2 += B * wj3.s2; + c3 += B * wj3.s3; + } + + int idx = (gy << 3) * m + (gx << 2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} From 9b927dd8496495b247e47a01435f95478d30583b Mon Sep 17 00:00:00 2001 From: nullname Date: Sat, 31 Jan 2026 13:14:20 +0800 Subject: [PATCH 079/831] ggml-hexagon: flash-attention and reduce-sum optimizations (llama/19141) * wip * ggml-hexagon: add vectorized dot product function for FP32 and FP16 accumulation * ggml-hexagon: optimize dot product functions for FP16 and FP32 with new vectorized implementations * wip * ggml-hexagon: optimize hvx_vec_dump_f32_n and hvx_vec_reduce_sum_qf32x2 functions for improved performance * ggml-hexagon: refactor dot product functions to use a common loading function for improved readability * optimize vector dot product functions to use unified reduction for improved performance * wip * ggml-hexagon: add vectorized dot product function for FP32 and FP16 accumulation * ggml-hexagon: optimize dot product functions for FP16 and FP32 with new vectorized implementations * wip * ggml-hexagon: optimize hvx_vec_dump_f32_n and hvx_vec_reduce_sum_qf32x2 functions for improved performance * ggml-hexagon: refactor dot product functions to use a common loading function for improved readability * optimize vector dot product functions to use unified reduction for improved performance * hexagon: optimize reduce-sum for v75+ * hexagon: always keep row_sums in sf/fp32 * ggml-hexagon: enhance directory checks for HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT * fix compiling error after rebase --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/CMakeLists.txt | 16 ++- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 159 ++++++++++++++++++--- ggml/src/ggml-hexagon/htp/hvx-dump.h | 9 +- ggml/src/ggml-hexagon/htp/hvx-reduce.h | 41 ++++++ ggml/src/ggml-hexagon/htp/matmul-ops.c | 107 +++++++------- ggml/src/ggml-hexagon/htp/softmax-ops.c | 4 +- ggml/src/ggml-hexagon/htp/unary-ops.c | 4 +- 7 files changed, 248 insertions(+), 92 deletions(-) diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt index 2b69197017f..f3a583543c6 100644 --- a/ggml/src/ggml-hexagon/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -1,8 +1,20 @@ file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT) file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT) -if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}" OR NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}") - message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT and HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.") +if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}") + message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT point to the correct Hexagon SDK installation.") +endif() + +if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}") + message("Try to read HEXAGON_TOOLS_ROOT from hexagon_sdk.json") + file(READ "${HEXAGON_SDK_ROOT}/hexagon_sdk.json" HEXAGON_SDK_CONFIG_PATH) + string(JSON HEXAGON_TOOLS_PATH GET ${HEXAGON_SDK_CONFIG_PATH} "root" "tools" "info" 0 "path") + message("Found HEXAGON_TOOLS_PATH: ${HEXAGON_TOOLS_PATH}") + set(HEXAGON_TOOLS_ROOT "${HEXAGON_SDK_ROOT}/${HEXAGON_TOOLS_PATH}") + file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT) + if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}") + message(FATAL_ERROR "Make sure HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.") + endif() endif() message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels") diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index c7cb2a4e0bc..c1846374437 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -17,6 +17,12 @@ #include "htp-msg.h" #include "htp-ops.h" +static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) { + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements + return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); +} + // Dot product of FP32 and FP16 vectors, accumulating to float static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 @@ -33,23 +39,19 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict #pragma unroll(4) for (i = 0; i < nvec; i++) { // Load y (fp32) and convert into fp16 - HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements - HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements - HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); // Load x (fp16) HVX_Vector x_hf = vx[i]; HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); } if (nloe) { // Load y (fp32) and convert into fp16 - HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements - HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements - HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); // Load x (fp16) HVX_Vector x_hf = vx[i]; @@ -62,13 +64,72 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); } - rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s)); - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); + rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); + hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); +} + +// Dot product of FP32 and FP16 vectors, accumulating to float +static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r, + const void * restrict y, + const void * restrict x0, + const void * restrict x1, + unsigned int n, + float s) { + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 + const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 + const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum0 = Q6_V_vsplat_R(0); + HVX_Vector rsum1 = Q6_V_vsplat_R(0); + + uint32_t i = 0; - hvx_vec_store_u(r, 4, rsum); + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); + // Load x (fp16) + HVX_Vector x0_hf = vx0[i]; + HVX_Vector x1_hf = vx1[i]; + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); + + // Load x (fp16) + HVX_Vector x0_hf = vx0[i]; + HVX_Vector x1_hf = vx1[i]; + + // Zero-out unused elements + // Note that we need to clear both x and y because they may contain NANs + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + x0_hf = Q6_V_vand_QV(bmask, x0_hf); + x1_hf = Q6_V_vand_QV(bmask, x1_hf); + y_hf = Q6_V_vand_QV(bmask, y_hf); + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + } + + HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); + hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); } // Dot product of two F16 vectors, accumulating to float @@ -91,7 +152,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); } if (nloe) { @@ -103,12 +164,62 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); + } + + rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); + hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); +} + +static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, + const void * restrict y, + const void * restrict x0, + const void * restrict x1, + unsigned int n, + float s) { + const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 + const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum0 = Q6_V_vsplat_R(0); + HVX_Vector rsum1 = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = vy[i]; + HVX_Vector x0_hf = vx0[i]; + HVX_Vector x1_hf = vx1[i]; + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + } + + if (nloe) { + HVX_Vector y_hf = vy[i]; + + // Load x (fp16) and zero-out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); } - rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s)); - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); - hvx_vec_store_u(r, 4, rsum); + HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); + hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); } // MAD: y (F32) += x (F16) * s (float) @@ -317,20 +428,22 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in // Inner loop processing the block from VTCM uint32_t ic = 0; + const bool is_q_fp32 = (q->type == HTP_TYPE_F32); + // Process in blocks of 32 (VLEN_FP32) - static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); + static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); HVX_Vector_x4 scores_x4; HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores - float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE]; - for (int j = 0; j < VLEN_FP32; ++j) { + float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; + for (int j = 0; j < VLEN_FP32; j += 2) { const uint32_t cur_ic = ic + j; const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; - if (q->type == HTP_TYPE_F32) { - hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); + if (is_q_fp32) { + hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); } else { - hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); + hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); } } @@ -403,7 +516,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in float s_val; const uint8_t * k_ptr = k_base + ic * size_k_row_padded; - if (q->type == HTP_TYPE_F32) { + if (is_q_fp32) { hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); } else { hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); diff --git a/ggml/src/ggml-hexagon/htp/hvx-dump.h b/ggml/src/ggml-hexagon/htp/hvx-dump.h index e882227893e..85201fc3453 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-dump.h +++ b/ggml/src/ggml-hexagon/htp/hvx-dump.h @@ -28,19 +28,16 @@ static void hvx_vec_dump_f16(char * pref, HVX_Vector v) { } static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) { - union { - HVX_Vector v; - float d[32]; - } u = { .v = v }; + HVX_VectorAlias u = { .v = v }; const uint32_t n0 = n / 16; const uint32_t n1 = n % 16; int i = 0; for (; i < n0; i++) { - hex_dump_f32_line(pref, u.d + (16 * i), 16); + hex_dump_f32_line(pref, u.fp32 + (16 * i), 16); } if (n1) { - hex_dump_f32_line(pref, u.d + (16 * i), n1); + hex_dump_f32_line(pref, u.fp32 + (16 * i), n1); } } diff --git a/ggml/src/ggml-hexagon/htp/hvx-reduce.h b/ggml/src/ggml-hexagon/htp/hvx-reduce.h index 8845fe73ea1..1ca7c05d983 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-reduce.h +++ b/ggml/src/ggml-hexagon/htp/hvx-reduce.h @@ -44,6 +44,45 @@ static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) { return hvx_vec_reduce_sum_n_qf32(in, 32); } +#if __HVX_ARCH__ > 75 + +static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { + HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); + HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); + + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16)); + return sum_sf; +} + +static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // fp32 nbytes + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(sum, width); // rotate right + sum = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum + width = width << 1; + } + return sum; +} + +#else + +static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { + HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); + HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); + + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16)); + return Q6_Vsf_equals_Vqf32(sum_qf); +} + static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) { unsigned int total = n * 4; // total vec nbytes unsigned int width = 4; // fp32 nbytes @@ -57,6 +96,8 @@ static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) return sum; } +#endif + static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) { return hvx_vec_reduce_sum_n_f32(in, 32); } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 1603ff2b3b6..d251eeed33a 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -11,6 +11,7 @@ #include "hex-dma.h" #include "hvx-utils.h" +#include "hvx-dump.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -320,7 +321,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); // Multiply and accumulate into int32. @@ -344,7 +345,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks @@ -362,14 +363,14 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * // Zero out unused scales HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - // Reduce and convert into fp32 - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); hvx_vec_store_u(&s[0], 4, r0_sum); } @@ -402,7 +403,7 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); HVX_Vector r1_sum = Q6_V_vsplat_R(0); @@ -432,8 +433,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks @@ -456,20 +457,18 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - // Convert into fp32 and reduce - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); - - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(&s[0], 8, rsum); } static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -493,7 +492,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); // Multiply and accumulate into int32. @@ -517,7 +516,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks @@ -535,14 +534,14 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * // Zero out unused scales HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - // Reduce and convert into fp32 - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); hvx_vec_store_u(&s[0], 4, r0_sum); } @@ -605,8 +604,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks @@ -629,20 +628,18 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - // Convert into fp32 and reduce - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); - - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(&s[0], 8, rsum); } static void vec_dot_mxfp4x4x2_q8x4x2(const int n, @@ -669,7 +666,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); // Multiply and accumulate into int32. @@ -708,7 +705,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } // Process leftovers @@ -741,14 +738,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, // Zero-out unused scales HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - // Reduce and convert into fp32 - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); hvx_vec_store_u(&s[0], 4, r0_sum); } @@ -781,13 +778,13 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales - // Row sum (qf32) + // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); HVX_Vector r1_sum = Q6_V_vsplat_R(0); // Multiply and accumulate into int32. // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). + // Apply scale to acc and accumulate into the row sum (f32). const uint32_t nb = n / qk; // num full blocks int32_t nloe = n % qk; // num leftover elemements (must be signed) @@ -829,8 +826,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } // Process leftovers @@ -867,24 +864,22 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - // Zero-out unused scales + // Zero-out unused values HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - // Convert into fp32 and reduce - r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); - - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(&s[0], 8, rsum); } static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -913,7 +908,7 @@ static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * res rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); + rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); hvx_vec_store_u(&s[0], 4, rsum); } @@ -957,11 +952,8 @@ static void vec_dot_f16_f16_aa_rx2(const int n, rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); } - rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum1)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4); - - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1)); + hvx_vec_store_u(&s[0], 8, rsum); } static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { @@ -990,7 +982,7 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); + rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); hvx_vec_store_u(&s[0], 4, rsum); } @@ -1042,7 +1034,8 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); + // Convert into fp32 and reduce + rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); hvx_vec_store_u(&s[0], 4, rsum); } diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 1b6b2eba4ae..e91a16d947f 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -154,8 +154,8 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, v_pad[i] = v3; } - v = hvx_vec_reduce_sum_qf32(sum_vec); - sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v)); + v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); + sum_vec = hvx_vec_repl4(v); HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v); HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec); diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index be8be8c4e64..1a27cb6e63e 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -57,8 +57,8 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); } - HVX_Vector reduced_sum = hvx_vec_reduce_sum_qf32(sum_v); - sum_v = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum)); + HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); + sum_v = hvx_vec_repl4(reduced_sum); HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); From aca5953d8d5b850f77220fc07563bfcc86a297d3 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Sun, 1 Feb 2026 14:13:38 -0800 Subject: [PATCH 080/831] Bump cmake max version (needed for Windows on Snapdragon builds) (llama/19188) * Bump max cmake version (needed for Windows on Snapdragon builds) * cmake: move max version setting into ggml/CMakeLists --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 590242e3f01..aa0ecde02a7 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories. +cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories. project("ggml" C CXX ASM) ### GGML Version From a0256b8159d90a2640f01cca7aca5e2bc301419d Mon Sep 17 00:00:00 2001 From: Nikhil Jain Date: Sun, 1 Feb 2026 18:47:29 -0800 Subject: [PATCH 081/831] Remove pipeline cache mutexes (llama/19195) * Remove mutex for pipeline caches, since they are now per-thread. * Add comment * Run clang-format * Cleanup * Run CI again * Run CI once more * Run clang-format --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 215 ++++++++++++--------------- 1 file changed, 94 insertions(+), 121 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 22e2bfeb4ce..4ef50e365ef 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -146,8 +146,13 @@ struct webgpu_submission_futures { struct webgpu_buf_pool { std::vector free; - std::mutex mutex; - + // The pool must be synchronized because + // 1. The memset pool is shared globally by every ggml buffer, + // since allocating a pool per ggml buffer would consume too much memory. + // 2. For the per-thread buffer pools in webgpu_context, + // buffers are allocated and freed in Dawn callbacks, + // which can run on a different thread than the calling thread. + std::mutex mutex; std::condition_variable cv; void init(wgpu::Device device, @@ -266,7 +271,7 @@ struct webgpu_command { #endif }; -struct webgpu_capabilities_base { +struct webgpu_capabilities { wgpu::Limits limits; bool supports_subgroup_matrix = false; @@ -286,11 +291,11 @@ struct webgpu_global_context_struct { wgpu::Device device; wgpu::Queue queue; - webgpu_capabilities_base capabilities; + webgpu_capabilities capabilities; // Shared buffer to move data from device to host - wgpu::Buffer get_tensor_staging_buf; + wgpu::Buffer get_tensor_staging_buf; // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches. - std::recursive_mutex mutex; + std::recursive_mutex mutex; webgpu_buf_pool memset_buf_pool; std::map memset_pipelines; // variant or type index @@ -361,7 +366,6 @@ struct webgpu_context_struct { std::unordered_map pad_pipelines; size_t memset_bytes_per_thread; - }; typedef std::shared_ptr webgpu_context; @@ -383,9 +387,8 @@ struct ggml_backend_webgpu_device_context { // Per-thread data required to actually run WebGPU operations in a backend instance struct ggml_backend_webgpu_context { - webgpu_context webgpu_ctx; - std::once_flag init_once; - std::string name; + webgpu_context webgpu_ctx; + std::string name; }; // Per-thread data related to buffers @@ -861,20 +864,15 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g }; webgpu_pipeline pipeline; - { - // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->global_ctx->mutex); - auto it = ctx->pad_pipelines.find(pipeline_key); - if (it != ctx->pad_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->pad_pipelines.emplace(pipeline_key, pipeline); - } + auto it = ctx->pad_pipelines.find(pipeline_key); + if (it != ctx->pad_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->pad_pipelines.emplace(pipeline_key, pipeline); } ggml_webgpu_generic_shader_decisions decisions = @@ -944,20 +942,16 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, }; webgpu_pipeline pipeline; - // TODO: remove guard once pipeline caches are per-thread - { - std::lock_guard lock(ctx->global_ctx->mutex); - auto it = ctx->set_rows_pipelines.find(key); - if (it != ctx->set_rows_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->set_rows_pipelines.emplace(key, pipeline); - } + auto it = ctx->set_rows_pipelines.find(key); + if (it != ctx->set_rows_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->set_rows_pipelines.emplace(key, pipeline); } ggml_webgpu_generic_shader_decisions decisions = @@ -1261,29 +1255,25 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, }; webgpu_pipeline pipeline; - // TODO: remove guard once pipeline caches are per-thread - { - std::lock_guard lock(ctx->global_ctx->mutex); - auto it = ctx->flash_attn_pipelines.find(key); - if (it != ctx->flash_attn_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { - .key = key, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size - }; - - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->flash_attn_pipelines.emplace(key, pipeline); - } + auto it = ctx->flash_attn_pipelines.find(key); + if (it != ctx->flash_attn_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { + .key = key, + .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, + .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, + .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size + }; + + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->flash_attn_pipelines.emplace(key, pipeline); } ggml_webgpu_flash_attn_shader_decisions decisions = @@ -1308,20 +1298,16 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s }; webgpu_pipeline pipeline; - { - // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->global_ctx->mutex); - auto it = ctx->unary_pipelines.find(pipeline_key); - if (it != ctx->unary_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->unary_pipelines.emplace(pipeline_key, pipeline); - } + auto it = ctx->unary_pipelines.find(pipeline_key); + if (it != ctx->unary_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->unary_pipelines.emplace(pipeline_key, pipeline); } ggml_webgpu_generic_shader_decisions decisions = @@ -1743,19 +1729,15 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src }; webgpu_pipeline pipeline; - { - // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->global_ctx->mutex); - auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4); - if (it != ctx->argmax_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax"); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline); - } + auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4); + if (it != ctx->argmax_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax"); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline); } uint32_t wg_x = ggml_nelements(dst); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); @@ -1772,9 +1754,8 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr .order = order }; - std::lock_guard lock(ctx->global_ctx->mutex); - webgpu_pipeline argsort_pipeline; - auto it = ctx->argsort_pipelines.find(order); + webgpu_pipeline argsort_pipeline; + auto it = ctx->argsort_pipelines.find(order); if (it != ctx->argsort_pipelines.end()) { argsort_pipeline = it->second; } else { @@ -1963,19 +1944,15 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; webgpu_pipeline pipeline; - // TODO: remove guard once pipeline caches are per-thread - { - std::lock_guard lock(ctx->global_ctx->mutex); - auto it = ctx->cumsum_pipelines.find(1); - if (it != ctx->cumsum_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum"); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->cumsum_pipelines.emplace(1, pipeline); - } + auto it = ctx->cumsum_pipelines.find(1); + if (it != ctx->cumsum_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum"); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ctx->cumsum_pipelines.emplace(1, pipeline); } uint32_t wg_x = ggml_nrows(dst); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); @@ -2009,19 +1986,15 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s }; webgpu_pipeline pipeline; - { - // TODO: remove guard once pipeline caches are per-thread - std::lock_guard lock(ctx->global_ctx->mutex); - auto it = ctx->sum_rows_pipelines.find(1); - if (it != ctx->sum_rows_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows"); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->sum_rows_pipelines.emplace(1, pipeline); - } + auto it = ctx->sum_rows_pipelines.find(1); + if (it != ctx->sum_rows_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows"); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + ctx->sum_rows_pipelines.emplace(1, pipeline); } uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); @@ -3016,10 +2989,10 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { #ifdef GGML_WEBGPU_GPU_PROFILE // Initialize buffer pool for timestamp queries, used for profiling - ctx->webgpu_global_ctx->timestamp_query_buf_pool.init(ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, - WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, - wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, - wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); + ctx->webgpu_global_ctx->timestamp_query_buf_pool.init( + ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, + wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); #endif GGML_LOG_INFO( From 0e219ebf89382291f495cdc32f2c83575bf8d21f Mon Sep 17 00:00:00 2001 From: Christian Kastner Date: Mon, 2 Feb 2026 07:38:55 +0100 Subject: [PATCH 082/831] docs : Minor cleanups (llama/19252) * Update old URLs to github.com/ggml-org/ * Bump copyrights --- LICENSE | 2 +- ggml/include/ggml-cann.h | 2 +- ggml/include/ggml.h | 2 +- ggml/src/ggml-cann/acl_tensor.cpp | 2 +- ggml/src/ggml-cann/acl_tensor.h | 2 +- ggml/src/ggml-cann/aclnn_ops.cpp | 2 +- ggml/src/ggml-cann/aclnn_ops.h | 2 +- ggml/src/ggml-cann/common.h | 2 +- ggml/src/ggml-cann/ggml-cann.cpp | 2 +- ggml/src/ggml-metal/CMakeLists.txt | 2 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 2 +- ggml/src/ggml-sycl/ggml-sycl.cpp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp | 2 +- ggml/src/ggml.c | 2 +- 14 files changed, 14 insertions(+), 14 deletions(-) diff --git a/LICENSE b/LICENSE index acb96ce78e0..e7dca554bcb 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023-2024 The ggml authors +Copyright (c) 2023-2026 The ggml authors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/ggml/include/ggml-cann.h b/ggml/include/ggml-cann.h index b469e228d06..74af465337a 100644 --- a/ggml/include/ggml-cann.h +++ b/ggml/include/ggml-cann.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1988d16dc42..f759e2d5883 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -6,7 +6,7 @@ // This documentation is still a work in progress. // If you wish some specific topics to be covered, feel free to drop a comment: // -// https://github.com/ggerganov/whisper.cpp/issues/40 +// https://github.com/ggml-org/whisper.cpp/issues/40 // // ## Overview // diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp index 7b7042a1f54..e95d3c4d88d 100644 --- a/ggml/src/ggml-cann/acl_tensor.cpp +++ b/ggml/src/ggml-cann/acl_tensor.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/src/ggml-cann/acl_tensor.h b/ggml/src/ggml-cann/acl_tensor.h index 7deac383420..4737773a4d4 100644 --- a/ggml/src/ggml-cann/acl_tensor.h +++ b/ggml/src/ggml-cann/acl_tensor.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 02867e4fdb5..87ac05748e8 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index b76e4707ac7..3effa1c289c 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index fb3e7572e2c..0120f0dfd1e 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 42c6c67a40b..6b2dbdd3591 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt index 9c0b3db8599..42054d841aa 100644 --- a/ggml/src/ggml-metal/CMakeLists.txt +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -71,7 +71,7 @@ else() # disabling fast math is needed in order to pass tests/test-backend-ops # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1 # note: unfortunately, we have to call it default.metallib instead of ggml.metallib - # ref: https://github.com/ggerganov/whisper.cpp/issues/1720 + # ref: https://github.com/ggml-org/whisper.cpp/issues/1720 # note: adding -g causes segmentation fault during compile #set(XC_FLAGS -fno-fast-math -fno-inline -g) set(XC_FLAGS -fno-fast-math -fno-inline) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 4850c11d147..0f0eb3a9d87 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -3740,7 +3740,7 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff // Reuse extra of the parent tensor. The offset of this view tensor // becomes `extra->offset + view_offs` and needs to be calculated when // it is used. This changes is needed because of the change to - // ggml_alloc.c in https://github.com/ggerganov/llama.cpp/pull/7640. + // ggml_alloc.c in https://github.com/ggml-org/llama.cpp/pull/7640. // `buffer` passed in here will always be `tensor->buffer`. It is OK // to allocate extras from the same buffer context for ordinary // intermediate tensors. But for views into kv cache tensors, doing so diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 74b4ed91cc0..12f1e7717b7 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3390,7 +3390,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // mmvq and mmq need the __dp4a instruction which is available for gen12+ - // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e + // Workaround in https://github.com/ggml-org/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS); #ifdef SYCL_USE_XMX use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index bbdbf9dcaaa..ca486a288a1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -330,7 +330,7 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, in_path, "-o", out_path}; #endif - // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 + // disable spirv-opt for coopmat shaders for https://github.com/ggml-org/llama.cpp/issues/10734 // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344 // disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860 if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 1725ad16545..e1471b540ed 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6562,7 +6562,7 @@ static void ggml_compute_backward( case GGML_OP_DIAG_MASK_INF: { if (src0_needs_grads) { /* ggml_diag_mask_inf_impl() shouldn't be here */ - /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ + /* ref: https://github.com/ggml-org/llama.cpp/pull/4203#discussion_r1412377992 */ const int n_past = ((const int32_t *) tensor->op_params)[0]; ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false)); } From 625c8d863e1444874aef4e5fb17d86fd1feafec3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 2 Feb 2026 10:00:05 +0100 Subject: [PATCH 083/831] ggml-backend: fix async set/get fallback sync (llama/19179) --- ggml/src/ggml-backend.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 354876574a0..22c656996cc 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -258,6 +258,7 @@ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); if (backend->iface.set_tensor_async == NULL) { + ggml_backend_synchronize(backend); ggml_backend_tensor_set(tensor, data, offset, size); } else { backend->iface.set_tensor_async(backend, tensor, data, offset, size); @@ -271,6 +272,7 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); if (backend->iface.get_tensor_async == NULL) { + ggml_backend_synchronize(backend); ggml_backend_tensor_get(tensor, data, offset, size); } else { backend->iface.get_tensor_async(backend, tensor, data, offset, size); From 73e04555eb597c728ccc7be463d7ec055f766345 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 2 Feb 2026 14:29:44 +0200 Subject: [PATCH 084/831] metal : support virtual devices (llama/18919) * metal : support virtual devices * cont : manage buffer type context memory * metal : add events * cont : implement cpy_tensor_async --- ggml/src/ggml-metal/ggml-metal-context.h | 8 + ggml/src/ggml-metal/ggml-metal-context.m | 105 +++++- ggml/src/ggml-metal/ggml-metal-device.cpp | 8 +- ggml/src/ggml-metal/ggml-metal-device.h | 16 +- ggml/src/ggml-metal/ggml-metal-device.m | 71 +++- ggml/src/ggml-metal/ggml-metal.cpp | 424 ++++++++++++++++------ 6 files changed, 506 insertions(+), 126 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-context.h b/ggml/src/ggml-metal/ggml-metal-context.h index ec2b686b733..abf4b06ed2a 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.h +++ b/ggml/src/ggml-metal/ggml-metal-context.h @@ -15,14 +15,22 @@ typedef struct ggml_metal * ggml_metal_t; ggml_metal_t ggml_metal_init(ggml_metal_device_t dev); void ggml_metal_free(ggml_metal_t ctx); +const char * ggml_metal_get_name(ggml_metal_t ctx); + void ggml_metal_synchronize(ggml_metal_t ctx); void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); +bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); enum ggml_status ggml_metal_graph_compute (ggml_metal_t ctx, struct ggml_cgraph * gf); void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf); +void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev); +void ggml_metal_event_wait (ggml_metal_t ctx, ggml_metal_event_t ev); + +ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx); + void ggml_metal_set_n_cb (ggml_metal_t ctx, int n_cb); void ggml_metal_set_abort_callback (ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data); bool ggml_metal_supports_family (ggml_metal_t ctx, int family); diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m index 42a35736eea..a412d70aed5 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -24,9 +24,13 @@ }; struct ggml_metal { + char name[128]; + ggml_metal_device_t dev; ggml_metal_library_t lib; + ggml_metal_event_t ev_cpy; // for async copies + dispatch_queue_t d_queue; // additional, inference-time compiled pipelines @@ -117,7 +121,11 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { } } - //const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + res->ev_cpy = ggml_metal_device_event_init(dev); + + const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + + snprintf(res->name, sizeof(res->name), "%s", props_dev->name); res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); @@ -206,9 +214,15 @@ void ggml_metal_free(ggml_metal_t ctx) { dispatch_release(ctx->d_queue); + ggml_metal_device_event_free(ctx->dev, ctx->ev_cpy); + free(ctx); } +const char * ggml_metal_get_name(ggml_metal_t ctx) { + return ctx->name; +} + void ggml_metal_synchronize(ggml_metal_t ctx) { // wait for any backend operations to finish if (ctx->cmd_buf_last) { @@ -273,8 +287,8 @@ void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, // wrap the source data into a Metal buffer id device = ggml_metal_device_get_obj(ctx->dev); id buf_src = [device newBufferWithBytes:data - length:size - options:MTLResourceStorageModeShared]; + length:size + options:MTLResourceStorageModeShared]; GGML_ASSERT(buf_src); @@ -316,9 +330,9 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te @autoreleasepool { id device = ggml_metal_device_get_obj(ctx->dev); id buf_dst = [device newBufferWithBytesNoCopy:data - length:size - options:MTLResourceStorageModeShared - deallocator:nil]; + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; GGML_ASSERT(buf_dst); @@ -356,6 +370,49 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te } } +bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { + @autoreleasepool { + struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(src); + struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(dst); + + if (bid_src.metal == nil || bid_dst.metal == nil) { + return false; + } + + // queue the copy operation into the Metal context + // this will be queued at the end, after any currently ongoing GPU operations + id queue = ggml_metal_device_get_queue(ctx_src->dev); + id cmd_buf = [queue commandBuffer]; + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:bid_src.metal + sourceOffset:bid_src.offs + toBuffer:bid_dst.metal + destinationOffset:bid_dst.offs + size:ggml_nbytes(src)]; + + [encoder endEncoding]; + + ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src); + ggml_metal_event_record(ctx_src, ev_cpy); + + [cmd_buf commit]; + + // do not wait here for completion + //[cmd_buf waitUntilCompleted]; + + // instead, remember a reference to the command buffer and wait for it later if needed + [ctx_src->cmd_bufs_ext addObject:cmd_buf]; + ctx_src->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + + ggml_metal_event_wait(ctx_dst, ev_cpy); + + return true; + } +} + enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) { // number of nodes encoded by the main thread (empirically determined) const int n_main = 64; @@ -530,6 +587,42 @@ void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) { //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0); } +void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev) { + @autoreleasepool { + id queue = ggml_metal_device_get_queue(ctx->dev); + id cmd_buf = [queue commandBuffer]; + + ggml_metal_event_encode_signal(ev, cmd_buf); + + [cmd_buf commit]; + + [ctx->cmd_bufs_ext addObject:cmd_buf]; + ctx->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + } +} + +void ggml_metal_event_wait(ggml_metal_t ctx, ggml_metal_event_t ev) { + @autoreleasepool { + id queue = ggml_metal_device_get_queue(ctx->dev); + id cmd_buf = [queue commandBuffer]; + + ggml_metal_event_encode_wait(ev, cmd_buf); + + [cmd_buf commit]; + + [ctx->cmd_bufs_ext addObject:cmd_buf]; + ctx->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + } +} + +ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx) { + return ctx->ev_cpy; +} + void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) { if (ctx->n_cb != n_cb) { ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS); diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 04c6137c5a7..377b0d3eb8f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -17,10 +17,12 @@ struct ggml_metal_device_deleter { typedef std::unique_ptr ggml_metal_device_ptr; -ggml_metal_device_t ggml_metal_device_get(void) { - static ggml_metal_device_ptr ctx { ggml_metal_device_init() }; +ggml_metal_device_t ggml_metal_device_get(int device) { + static std::vector devs; - return ctx.get(); + devs.emplace_back(ggml_metal_device_init(device)); + + return devs.back().get(); } struct ggml_metal_pipelines { diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 3d01c56fb81..afb091e7255 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -205,7 +205,9 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets); // struct ggml_metal_device_props { + int device; char name[128]; + char desc[128]; size_t max_buffer_size; size_t max_working_set_size; @@ -224,11 +226,15 @@ struct ggml_metal_device_props { int op_offload_min_batch_size; }; -ggml_metal_device_t ggml_metal_device_init(void); +typedef struct ggml_metal_event * ggml_metal_event_t; + +void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf); +void ggml_metal_event_encode_wait (ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf); + +ggml_metal_device_t ggml_metal_device_init(int device); void ggml_metal_device_free(ggml_metal_device_t dev); -// return a singleton that is automatically destroyed when the program exits -ggml_metal_device_t ggml_metal_device_get(void); +ggml_metal_device_t ggml_metal_device_get(int device); void * ggml_metal_device_get_obj (ggml_metal_device_t dev); // id void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id @@ -240,6 +246,10 @@ void ggml_metal_device_rsets_rm (ggml_metal_device_t dev, ggml_metal_rset_t rset void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev); +ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev); +void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev); +void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev); + void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total); bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 7f9c384c344..285dd1630e7 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -24,9 +24,6 @@ static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; static const NSInteger MTLGPUFamilyMetal4_GGML = 5002; -// virtual address for GPU memory allocations -static atomic_uintptr_t g_addr_device = 0x000000400ULL; - #if !GGML_METAL_EMBED_LIBRARY // Here to assist with NSBundle Path Hack @interface GGMLMetalClass : NSObject @@ -523,6 +520,9 @@ void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder) { ggml_metal_library_t library; struct ggml_metal_device_props props; + + // virtual address for GPU memory allocations + atomic_uintptr_t addr_virt; }; // @@ -618,7 +618,7 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) { free(rsets); } -ggml_metal_device_t ggml_metal_device_init(void) { +ggml_metal_device_t ggml_metal_device_init(int device) { ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device)); assert(dev != NULL); @@ -632,6 +632,9 @@ ggml_metal_device_t ggml_metal_device_init(void) { GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); } + dev->addr_virt = 0x000000400ULL; + + dev->props.device = device; dev->props.has_simdgroup_reduction = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; dev->props.has_simdgroup_reduction |= [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; @@ -792,7 +795,8 @@ ggml_metal_device_t ggml_metal_device_init(void) { dev->props.max_working_set_size = dev->mtl_device.maxBufferLength; } - strncpy(dev->props.name, [[dev->mtl_device name] UTF8String], sizeof(dev->props.name) - 1); + snprintf(dev->props.name, sizeof(dev->props.name), "%s%d", "MTL", device); + snprintf(dev->props.desc, sizeof(dev->props.desc), "%s", [[dev->mtl_device name] UTF8String]); dev->library = ggml_metal_library_init(dev); if (!dev->library) { @@ -922,6 +926,59 @@ void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) { atomic_store_explicit(&dev->rsets->d_loop, 2*dev->rsets->keep_alive_s, memory_order_relaxed); } +struct ggml_metal_event { + void * obj; // id + + atomic_int value; +}; + +void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { + id event = (id)ev->obj; + + id cmd_buf = (id) cmd_buf_raw; + + [cmd_buf encodeSignalEvent:event value:atomic_fetch_add_explicit(&ev->value, 1, memory_order_relaxed) + 1]; +} + +void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { + id event = (id)ev->obj; + + id cmd_buf = (id) cmd_buf_raw; + + [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)]; +} + +ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) { + id event = [dev->mtl_device newEvent]; + + ggml_metal_event_t ev = calloc(1, sizeof(struct ggml_metal_event)); + + ev->obj = (__bridge void *)event; + ev->value = 0; + + return ev; +} + +void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev) { + id event = ev->obj; + [event release]; + + free(ev); + + GGML_UNUSED(dev); +} + +void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev) { + @autoreleasepool { + id event = ev->obj; + + id cmd_buf = [dev->mtl_queue commandBuffer]; + [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)]; + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } +} + void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) { if (@available(macOS 10.12, iOS 16.0, *)) { *total = dev->mtl_device.recommendedMaxWorkingSetSize; @@ -1344,8 +1401,8 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, res->all_data = ggml_metal_host_malloc(size_aligned); res->is_shared = true; } else { - // use virtual address from g_addr_device counter - res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed); + // use virtual address + res->all_data = (void *) atomic_fetch_add_explicit(&dev->addr_virt, size_aligned, memory_order_relaxed); res->is_shared = false; } res->all_size = size_aligned; diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 56b59f0afdf..a616dcdb461 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -7,11 +7,12 @@ #include "ggml-metal-context.h" #include "ggml-metal-ops.h" -// globals +#define GGML_METAL_NAME "MTL" +#define GGML_METAL_MAX_DEVICES 16 -// initialized in ggml_backend_metal_reg -static ggml_backend_reg g_ggml_metal_reg; -static ggml_backend_device g_ggml_metal_device; +// number of Metal devices +// note: can be overriden with GGML_METAL_DEVICES env to simulate virtual devices +static int g_devices = 1; //////////////////////////////////////////////////////////////////////////////// // backend interface @@ -165,10 +166,28 @@ static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { /* .reset = */ NULL, }; +static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_metal_buffer_shared_free_buffer || + buffer->iface.free_buffer == ggml_backend_metal_buffer_private_free_buffer; +} + // // buffer types // +struct ggml_backend_metal_buffer_type { + int device; + std::string name; +}; + +struct ggml_backend_metal_buffer_type_deleter { + void operator()(ggml_backend_metal_buffer_type * ctx) const { + delete ctx; + } +}; + +typedef std::unique_ptr ggml_backend_metal_buffer_type_ptr; + // common method for allocating shread or private Metal buffers static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) { ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; @@ -218,9 +237,9 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ // default (shared) buffer type static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) { - return "Metal"; + ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; - GGML_UNUSED(buft); + return ctx->name.c_str(); } static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -249,29 +268,54 @@ static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_ty GGML_UNUSED(buft); } -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) { - static ggml_backend_buffer_type ggml_backend_buffer_type_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, - /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size, - /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, - }, - /* .device = */ &g_ggml_metal_device, - /* .context = */ NULL, - }; +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(int device) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + static std::vector bufts; + static std::vector ctxs; + + static bool initialized = false; + if (!initialized) { + bufts.reserve(g_devices); + ctxs.reserve(g_devices); + + for (int i = 0; i < g_devices; ++i) { + ggml_backend_metal_buffer_type * raw_ctx = + new ggml_backend_metal_buffer_type { + /* .device = */ i, + /* .name = */ GGML_METAL_NAME + std::to_string(i), + }; + ctxs.emplace_back(raw_ctx); + + ggml_backend_buffer_type buft = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), + /* .context = */ raw_ctx, + }; + + bufts.emplace_back(buft); + } + + initialized = true; + } - return &ggml_backend_buffer_type_metal; + return &bufts[device]; } // default (private) buffer type static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) { - return "Metal_Private"; + ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; - GGML_UNUSED(buft); + return ctx->name.c_str(); } static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -300,29 +344,53 @@ static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_t GGML_UNUSED(buft); } -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) { - static ggml_backend_buffer_type ggml_backend_buffer_type_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, - /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size, - /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, - }, - /* .device = */ &g_ggml_metal_device, - /* .context = */ NULL, - }; +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(int device) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + static std::vector bufts; + static std::vector ctxs; + + static bool initialized = false; + if (!initialized) { + bufts.reserve(g_devices); + ctxs.reserve(g_devices); + + for (int i = 0; i < g_devices; ++i) { + ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{ + /* .device = */ i, + /* .name = */ GGML_METAL_NAME + std::to_string(i) + "_Private" + }; + ctxs.emplace_back(raw_ctx); + + ggml_backend_buffer_type buft = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), + /* .context = */ raw_ctx, + }; + + bufts.emplace_back(buft); + } + + initialized = true; + } - return &ggml_backend_buffer_type_metal; + return &bufts[device]; } // mapped buffer type static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) { - return "Metal_Mapped"; + ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; - GGML_UNUSED(buft); + return ctx->name.c_str(); } static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -352,31 +420,55 @@ static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_ty GGML_UNUSED(buft); } -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) { - // note: not obvious, but this buffer type still needs to implement .alloc_buffer: - // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099 - static ggml_backend_buffer_type ggml_backend_buffer_type_mapped_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, - /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size, - /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, - }, - /* .device = */ &g_ggml_metal_device, - /* .context = */ NULL, - }; +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(int device) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + static std::vector bufts; + static std::vector ctxs; + + static bool initialized = false; + if (!initialized) { + bufts.reserve(g_devices); + ctxs.reserve(g_devices); + + for (int i = 0; i < g_devices; ++i) { + ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{ + /* .device = */ i, + /* .name = */ GGML_METAL_NAME + std::to_string(i) + "_Mapped" + }; + ctxs.emplace_back(raw_ctx); + + // note: not obvious, but this buffer type still needs to implement .alloc_buffer: + // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099 + ggml_backend_buffer_type buft = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), + /* .context = */ raw_ctx, + }; + + bufts.emplace_back(buft); + } + + initialized = true; + } - return &ggml_backend_buffer_type_mapped_metal; + return &bufts[device]; } // backend static const char * ggml_backend_metal_name(ggml_backend_t backend) { - return "Metal"; + ggml_metal_t ctx = (ggml_metal_t)backend->context; - GGML_UNUSED(backend); + return ggml_metal_get_name(ctx); } static void ggml_backend_metal_free(ggml_backend_t backend) { @@ -409,12 +501,24 @@ static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const gg } static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { - return false; + if (!ggml_backend_is_metal(backend_src) || !ggml_backend_is_metal(backend_dst)) { + return false; + } - GGML_UNUSED(backend_src); - GGML_UNUSED(backend_dst); - GGML_UNUSED(src); - GGML_UNUSED(dst); + if (!ggml_backend_buffer_is_metal(src->buffer) || !ggml_backend_buffer_is_metal(dst->buffer)) { + return false; + } + + ggml_metal_t ctx_src = (ggml_metal_t)backend_src->context; + ggml_metal_t ctx_dst = (ggml_metal_t)backend_dst->context; + + //ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer; + //ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; + + //ggml_metal_buffer_t buf_ctx_src = (ggml_metal_buffer_t)buf_src->context; + //ggml_metal_buffer_t buf_ctx_dst = (ggml_metal_buffer_t)buf_dst->context; + + return ggml_metal_cpy_tensor_async(ctx_src, ctx_dst, src, dst); } static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { @@ -423,6 +527,20 @@ static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, return ggml_metal_graph_compute(ctx, cgraph); } +static void ggml_backend_metal_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + ggml_metal_event_t ev = (ggml_metal_event_t)event->context; + + ggml_metal_event_record(ctx, ev); +} + +static void ggml_backend_metal_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + ggml_metal_event_t ev = (ggml_metal_event_t)event->context; + + ggml_metal_event_wait(ctx, ev); +} + static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_metal_t ctx = (ggml_metal_t)backend->context; @@ -435,7 +553,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { ggml_metal_t ctx = (ggml_metal_t)backend->context; ggml_metal_set_n_cb(ctx, n_cb); - } static ggml_backend_i ggml_backend_metal_i = { @@ -450,12 +567,8 @@ static ggml_backend_i ggml_backend_metal_i = { /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_metal_graph_compute, - - // the events API is needed only for multi-GPU setups, so likely no need to implement it for Metal - // in any case, these docs seem relevant if we ever decide to implement it: - // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events - /* .event_record = */ NULL, - /* .event_wait = */ NULL, + /* .event_record = */ ggml_backend_metal_event_record, + /* .event_wait = */ ggml_backend_metal_event_wait, /* .graph_optimize = */ ggml_backend_metal_graph_optimize, }; @@ -519,15 +632,17 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { // backend device static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) { - return "Metal"; + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; - GGML_UNUSED(dev); + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); + + return props_dev->name; } static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) { ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; - return ggml_metal_device_get_props(ctx_dev)->name; + return ggml_metal_device_get_props(ctx_dev)->desc; } static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { @@ -550,14 +665,14 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_bac ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = { - /* .async = */ true, - /* .host_buffer = */ false, - /* .buffer_from_host_ptr = */ true, - /* .events = */ false, + /* .async = */ true, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ true, }; } -static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) { +static ggml_backend_t ggml_backend_metal_device_init_backend(ggml_backend_dev_t dev, const char * params) { ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; ggml_metal_t ctx = ggml_metal_init(ctx_dev); @@ -587,7 +702,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); - return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared() : ggml_backend_metal_buffer_type_private(); + return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared(props_dev->device) : ggml_backend_metal_buffer_type_private(props_dev->device); } static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { @@ -595,7 +710,9 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backen ggml_metal_buffer_t res = ggml_metal_buffer_map(ctx_dev, ptr, size, max_tensor_size); - return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(), ggml_backend_metal_buffer_shared_i, res, size); + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); + + return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(props_dev->device), ggml_backend_metal_buffer_shared_i, res, size); } static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { @@ -606,9 +723,10 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { return + buft->device == dev && ( buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name || buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name || - buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name; + buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name); GGML_UNUSED(dev); } @@ -632,45 +750,97 @@ static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const g get_op_batch_size(op) >= ggml_metal_device_get_props(ctx_dev)->op_offload_min_batch_size; } +static ggml_backend_event_t ggml_backend_metal_device_event_new(ggml_backend_dev_t dev) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_event_t event = ggml_metal_device_event_init(ctx_dev); + GGML_ASSERT(event); + + ggml_backend_event_t ev = new ggml_backend_event { + /* .device = */ dev, + /* .context = */ event, + }; + + return ev; +} + +static void ggml_backend_metal_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_event_t ev = (ggml_metal_event_t)event->context; + + ggml_metal_device_event_free(ctx_dev, ev); + + delete event; +} + +static void ggml_backend_metal_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_event_t evt = (ggml_metal_event_t)event->context; + + ggml_metal_device_event_synchronize(ctx_dev, evt); +} + static ggml_backend_device_i ggml_backend_metal_device_i = { /* .get_name = */ ggml_backend_metal_device_get_name, /* .get_description = */ ggml_backend_metal_device_get_description, /* .get_memory = */ ggml_backend_metal_device_get_memory, /* .get_type = */ ggml_backend_metal_device_get_type, /* .get_props = */ ggml_backend_metal_device_get_props, - /* .init_backend = */ ggml_backend_metal_device_init, + /* .init_backend = */ ggml_backend_metal_device_init_backend, /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type, /* .get_host_buffer_type = */ NULL, /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped, /* .supports_op = */ ggml_backend_metal_device_supports_op, /* .supports_buft = */ ggml_backend_metal_device_supports_buft, /* .offload_op = */ ggml_backend_metal_device_offload_op, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, + /* .event_new = */ ggml_backend_metal_device_event_new, + /* .event_free = */ ggml_backend_metal_device_event_free, + /* .event_synchronize = */ ggml_backend_metal_device_event_synchronize, }; // backend registry +struct ggml_backend_metal_reg { + std::vector devices; +}; + +typedef struct ggml_backend_metal_reg * ggml_backend_metal_reg_t; + +static ggml_backend_metal_reg_t ggml_backend_metal_reg_init(void) { + ggml_backend_metal_reg_t ctx = new struct ggml_backend_metal_reg; + + return ctx; +} + +static void ggml_backend_metal_reg_free(ggml_backend_metal_reg_t ctx) { + delete ctx; +} + +struct ggml_backend_metal_reg_deleter { + void operator()(ggml_backend_metal_reg_t ctx) { + ggml_backend_metal_reg_free(ctx); + } +}; + +typedef std::unique_ptr ggml_backend_metal_reg_ptr; + static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) { - return "Metal"; + return GGML_METAL_NAME; GGML_UNUSED(reg); } static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) { - return 1; - - GGML_UNUSED(reg); + ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context; + return ctx->devices.size(); } static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) { - GGML_ASSERT(index == 0); - - return &g_ggml_metal_device; - - GGML_UNUSED(reg); - GGML_UNUSED(index); + ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context; + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; } static ggml_backend_feature g_ggml_backend_metal_features[] = { @@ -698,27 +868,67 @@ static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const static ggml_backend_reg_i ggml_backend_metal_reg_i = { /* .get_name = */ ggml_backend_metal_reg_get_name, - /* .device_count = */ ggml_backend_metal_reg_device_count, - /* .device_get = */ ggml_backend_metal_reg_device_get, + /* .get_device_count = */ ggml_backend_metal_reg_device_count, + /* .get_device = */ ggml_backend_metal_reg_device_get, /* .get_proc_address = */ ggml_backend_metal_get_proc_address, }; +static ggml_backend_dev_t ggml_backend_metal_device_init(ggml_backend_reg_t reg, int device) { + return new ggml_backend_device { + /* .iface = */ ggml_backend_metal_device_i, + /* .reg = */ reg, + /* .context = */ ggml_metal_device_get(device), + }; +} + +static void ggml_backend_metal_device_free(ggml_backend_dev_t dev) { + delete dev; +} + +struct ggml_backend_device_deleter { + void operator()(ggml_backend_dev_t ctx) { + ggml_backend_metal_device_free(ctx); + } +}; + +typedef std::unique_ptr ggml_backend_device_ptr; + ggml_backend_reg_t ggml_backend_metal_reg(void) { + static ggml_backend_reg reg; + static bool initialized = false; + { - g_ggml_metal_reg = { - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_metal_reg_i, - /* .context = */ NULL, - }; - - g_ggml_metal_device = { - /* .iface = */ ggml_backend_metal_device_i, - /* .reg = */ &g_ggml_metal_reg, - /* .context = */ ggml_metal_device_get(), - }; + static std::mutex mutex; + std::lock_guard lock(mutex); + + const char * env = getenv("GGML_METAL_DEVICES"); + if (env) { + g_devices = atoi(env); + } + + static std::vector devs; + + if (!initialized) { + static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init()); + + for (int i = 0; i < g_devices; ++i) { + auto * dev = ggml_backend_metal_device_init(®, i); + devs.emplace_back(dev); + + reg_ctx->devices.push_back(dev); + } + + reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_metal_reg_i, + /* .context = */ reg_ctx.get(), + }; + } + + initialized = true; } - return &g_ggml_metal_reg; + return ® } GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg) From 74353e90a157990cbe2103495943d26a75e04384 Mon Sep 17 00:00:00 2001 From: Tamar Date: Mon, 2 Feb 2026 15:05:51 +0200 Subject: [PATCH 085/831] sycl: implement GGML_OP_TOP_K (llama/19242) --- ggml/src/ggml-sycl/ggml-sycl.cpp | 140 +++++++++++++++++++++++++++++++ 1 file changed, 140 insertions(+) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 12f1e7717b7..c5139fd3dd2 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1840,6 +1840,110 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, } } +static void top_k_f32_sycl( + const float * src, + int32_t * dst_indices, + const int64_t ncols, + const int64_t nrows, + const int k, + dpct::queue_ptr main_stream +) { + const int block_size = 128; + + const sycl::range<1> block_dims(block_size); + const sycl::range<1> grid_dims(nrows); + + main_stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor shared_vals(sycl::range<1>(block_size * k), cgh); + sycl::local_accessor shared_idx(sycl::range<1>(block_size * k), cgh); + + cgh.parallel_for( + sycl::nd_range<1>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<1> item_ct1) { + const int row = item_ct1.get_group(0); + const int tid = item_ct1.get_local_id(0); + + if (row >= nrows) return; + + const float * src_row = src + row * ncols; + int32_t * dst_idx_row = dst_indices + row * k; + + float local_vals[32]; + int local_idx[32]; + + for (int i = 0; i < k; i++) { + local_vals[i] = -FLT_MAX; + local_idx[i] = -1; + } + + for (int col = tid; col < ncols; col += block_size) { + float val = src_row[col]; + + if (val > local_vals[k-1]) { + int pos = k - 1; + while (pos > 0 && val > local_vals[pos - 1]) { + pos--; + } + + for (int i = k - 1; i > pos; i--) { + local_vals[i] = local_vals[i - 1]; + local_idx[i] = local_idx[i - 1]; + } + local_vals[pos] = val; + local_idx[pos] = col; + } + } + + for (int i = 0; i < k; i++) { + shared_vals[tid * k + i] = local_vals[i]; + shared_idx[tid * k + i] = local_idx[i]; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + + if (tid == 0) { + float final_vals[32]; + int final_idx[32]; + + for (int i = 0; i < k; i++) { + final_vals[i] = -FLT_MAX; + final_idx[i] = -1; + } + + for (int t = 0; t < block_size; t++) { + for (int i = 0; i < k; i++) { + float val = shared_vals[t * k + i]; + int idx = shared_idx[t * k + i]; + + if (val > final_vals[k-1]) { + int pos = k - 1; + while (pos > 0 && val > final_vals[pos - 1]) { + pos--; + } + + for (int j = k - 1; j > pos; j--) { + final_vals[j] = final_vals[j - 1]; + final_idx[j] = final_idx[j - 1]; + } + final_vals[pos] = val; + final_idx[pos] = idx; + } + } + } + + for (int i = 0; i < k; i++) { + dst_idx_row[i] = final_idx[i]; + } + + if (k > 1) { + int32_t temp = dst_idx_row[0]; + dst_idx_row[0] = dst_idx_row[1]; + dst_idx_row[1] = temp; + } + } + }); + }); +} + static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols, const int nrows, queue_ptr stream) { const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE); @@ -2231,6 +2335,30 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * main_stream, ctx.device); } +static void ggml_sycl_op_top_k(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float * src0_dd = static_cast(src0->data); + int32_t * dst_dd = static_cast(dst->data); + + const int k = dst->ne[0]; + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + GGML_ASSERT(k > 0 && k <= 32); + GGML_ASSERT(k <= ncols); + + top_k_f32_sycl(src0_dd, dst_dd, ncols, nrows, k, main_stream); +} + inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_I32); @@ -4007,6 +4135,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_ARGSORT: ggml_sycl_argsort(ctx, dst); break; + case GGML_OP_TOP_K: + ggml_sycl_op_top_k(ctx, dst); + break; case GGML_OP_TIMESTEP_EMBEDDING: ggml_sycl_op_timestep_embedding(ctx, dst); break; @@ -4710,6 +4841,15 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ARGSORT: return op->src[0]->ne[0] * sizeof(int) <= ggml_sycl_info().devices[device].smpbo; + case GGML_OP_TOP_K: { + const ggml_tensor * src0 = op->src[0]; + const int k = op->ne[0]; + return src0 && + op->type == GGML_TYPE_I32 && + src0->type == GGML_TYPE_F32 && + ggml_is_contiguous(src0) && + k > 0 && k <= 32; + } case GGML_OP_POOL_2D: case GGML_OP_ACC: return true; From c4003da2b838a923eeb7879f91dc6ddd1f413af2 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Mon, 2 Feb 2026 21:06:21 +0800 Subject: [PATCH 086/831] Remove support for Nvidia & AMD GPU, because the oneAPI plugin for Nvidia & AMD GPU is unavailable: download/installation channels are out of work. (llama/19246) User can't build up the software for Nvidia & AMD GPU. rm the oneMath since it is only used in NV and AMD code path. --- ggml/src/ggml-sycl/CMakeLists.txt | 97 +++--------------------------- ggml/src/ggml-sycl/dpct/helper.hpp | 65 ++++++-------------- ggml/src/ggml-sycl/ggml-sycl.cpp | 35 ++++------- ggml/src/ggml-sycl/outprod.cpp | 6 +- ggml/src/ggml-sycl/rope.cpp | 1 - ggml/src/ggml-sycl/wkv.cpp | 2 +- 6 files changed, 44 insertions(+), 162 deletions(-) diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 5a89d8dd688..eefdd9725ca 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -1,7 +1,7 @@ message(STATUS "GGML_SYCL_TARGET=${GGML_SYCL_TARGET}") -if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$") - message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD") +if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL)$") + message(FATAL_ERROR "GGML_SYCL_TARGET: Invalid target, the supported options are [INTEL]") endif() check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL) @@ -125,106 +125,27 @@ endif() target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL}) if (GGML_SYCL_F16) - if (GGML_SYCL_TARGET STREQUAL "AMD") - message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.") - endif() add_compile_definitions(GGML_SYCL_F16) endif() if (GGML_SYCL_TARGET STREQUAL "INTEL") add_compile_definitions(GGML_SYCL_WARP_SIZE=16) target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required) -elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA") - add_compile_definitions(GGML_SYCL_WARP_SIZE=32) -elseif (GGML_SYCL_TARGET STREQUAL "AMD") - # INFO: Allowed Sub_group_sizes are not consistent through all - # hip targets. For example, 64 is used for certain models, but the backend - # does not support it. - # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32) - add_compile_definitions(GGML_SYCL_WARP_SIZE=32) -else() - # default for other target - add_compile_definitions(GGML_SYCL_WARP_SIZE=32) -endif() - -if (GGML_SYCL_GRAPH) - target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH) -endif() -# Link against Intel oneMKL or oneMath -if (GGML_SYCL_TARGET STREQUAL "INTEL") - # Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically - # See https://github.com/uxlfoundation/oneMath/issues/654 + # Link against Intel oneMKL if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") set(SYCL_COMPILER ON) endif() find_package(MKL REQUIRED) target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS) - target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL) else() - find_package(oneMath QUIET) - if (NOT oneMath_FOUND) - message(STATUS "oneMath not found: oneMath will be automatically downloaded") - # Use FetchContent to automatically pull and build oneMath - include(FetchContent) - set(BUILD_FUNCTIONAL_TESTS False) - set(BUILD_EXAMPLES False) - set(TARGET_DOMAINS blas) - if (GGML_SYCL_TARGET STREQUAL "NVIDIA") - set(ENABLE_MKLCPU_BACKEND False) - set(ENABLE_MKLGPU_BACKEND False) - set(ENABLE_CUBLAS_BACKEND True) - elseif (GGML_SYCL_TARGET STREQUAL "AMD") - set(ENABLE_MKLCPU_BACKEND False) - set(ENABLE_MKLGPU_BACKEND False) - set(ENABLE_ROCBLAS_BACKEND True) - # Ensure setting a string variable here is not overriden by oneMath CACHE variables - cmake_policy(SET CMP0126 NEW) - # Setting the device architecture is only needed and useful for AMD devices in oneMath - set(HIP_TARGETS ${GGML_SYCL_DEVICE_ARCH} CACHE STRING "oneMath HIP target" FORCE) - endif() - FetchContent_Declare( - ONEMATH - GIT_REPOSITORY https://github.com/uxlfoundation/oneMath.git - GIT_TAG 8efe85f5aaebb37f1d8c503b7af66315feabf142 - ) - FetchContent_MakeAvailable(ONEMATH) - # Create alias to match with find_package targets name - function(onemath_alias target) - if (TARGET ${target}_obj) - # Silence verbose warnings from external libraries - target_compile_options(${target}_obj PRIVATE -w) - endif() - if (TARGET ${target}) - add_library(ONEMATH::${target} ALIAS ${target}) - endif() - endfunction() - onemath_alias(onemath) - onemath_alias(onemath_blas_mklcpu) - onemath_alias(onemath_blas_mklgpu) - onemath_alias(onemath_blas_cublas) - onemath_alias(onemath_blas_rocblas) - endif() + # default for other target + message(FATAL_ERROR "GGML_SYCL_TARGET is not supported") + add_compile_definitions(GGML_SYCL_WARP_SIZE=32) +endif() - # Below oneMath compile-time dispatching is used for better performance - if (GGML_SYCL_TARGET STREQUAL "NVIDIA") - target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_cublas) - target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda") - target_link_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda") - target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_NVIDIA) - elseif (GGML_SYCL_TARGET STREQUAL "AMD") - if (NOT GGML_SYCL_DEVICE_ARCH) - message(FATAL_ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") - endif() - target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_rocblas) - target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa") - target_link_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa") - target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_AMD) - else() - # Fallback to oneMath runtime dispatcher - target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath) - target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GENERIC) - endif() +if (GGML_SYCL_GRAPH) + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH) endif() if (GGML_SYCL_DEVICE_ARCH) diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index 8ae8098717d..ece66a7ac1f 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -15,17 +15,9 @@ #include #include -#include - -#ifdef GGML_SYCL_USE_INTEL_ONEMKL #include -// Allow to use the same namespace for Intel oneMKL and oneMath -namespace oneapi { - namespace math = mkl; -} -#else -#include -#endif + +#include #include "ggml.h" @@ -91,32 +83,13 @@ inline std::string get_device_backend_and_type(const sycl::device &device) { } template struct matrix_info_t { - oneapi::math::transpose transpose_info[2]; + oneapi::mkl::transpose transpose_info[2]; Ts value_info[2]; std::int64_t size_info[3]; std::int64_t ld_info[3]; std::int64_t groupsize_info; }; -inline auto get_onemath_backend(sycl::queue& queue) -#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL) - -> sycl::queue& -#endif -{ -// If the backend is known at compile-time, use oneMath backend_selector to use -// compile-time dispatching and avoid the need to dlopen libraries. Otherwise -// fallback to runtime dispatching. -#if defined(GGML_SYCL_NVIDIA) - return oneapi::math::backend_selector{ queue }; -#elif defined(GGML_SYCL_AMD) - return oneapi::math::backend_selector{ queue }; -#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL) - return queue; -#else - static_assert(false, "Unsupported backend"); -#endif -} - namespace dpct { typedef sycl::queue *queue_ptr; @@ -1734,7 +1707,7 @@ namespace dpct namespace detail { template - inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, + inline void gemm_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb, const void * beta, void * c, int ldc) { Ts alpha_value = dpct::get_value(reinterpret_cast(alpha), q); @@ -1742,7 +1715,7 @@ namespace dpct auto data_a = get_memory(a); auto data_b = get_memory(b); auto data_c = get_memory(c); - oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a, + oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, beta_value, data_c, ldc); } @@ -1774,7 +1747,7 @@ namespace dpct }; template - inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, + inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b, int ldb, const void * beta, void ** c, int ldc, int batch_size, matrix_info_t * matrix_info) { @@ -1793,8 +1766,8 @@ namespace dpct matrix_info->ld_info[2] = ldc; matrix_info->groupsize_info = batch_size; - sycl::event e = oneapi::math::blas::column_major::gemm_batch( - get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1, + sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( + q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast(matrix_info->value_info), reinterpret_cast(a), matrix_info->ld_info, reinterpret_cast(b), matrix_info->ld_info + 1, @@ -1803,7 +1776,7 @@ namespace dpct } template - inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, + inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a, int lda, long long int stride_a, const void * b, int ldb, long long int stride_b, const void * beta, void * c, int ldc, long long int stride_c, int batch_size) { @@ -1812,7 +1785,7 @@ namespace dpct auto data_a = get_memory(a); auto data_b = get_memory(b); auto data_c = get_memory(c); - oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, + oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c, batch_size); } @@ -2299,7 +2272,7 @@ namespace dpct sycl::range<3>(x, y, 1), direction); } - inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n, + inline void gemm(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b, library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc, library_data_t scaling_type) { @@ -2366,7 +2339,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_impl( + detail::gemm_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); break; } @@ -2405,7 +2378,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_impl( + detail::gemm_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); break; } @@ -2447,7 +2420,7 @@ namespace dpct /// \param [in] ldc Leading dimension of C. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, + inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda, const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[], library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type, @@ -2485,7 +2458,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl( + detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } @@ -2493,7 +2466,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( + detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } @@ -2569,7 +2542,7 @@ namespace dpct /// \param [in] stride_c Stride between the different C matrices. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, + inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda, long long int stride_a, const void * b, library_data_t b_type, int ldb, long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc, @@ -2642,7 +2615,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl( + detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); break; @@ -2651,7 +2624,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl( + detail::gemm_batch_impl( q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); break; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index c5139fd3dd2..a03d26d7f20 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2167,8 +2167,8 @@ inline void ggml_sycl_op_mul_mat_sycl( const sycl::half alpha_f16 = 1.0f; const sycl::half beta_f16 = 0.0f; SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( - *stream, oneapi::math::transpose::trans, - oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10, + *stream, oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00, src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, dst_f16.get(), dpct::library_data_t::real_half, ldc, @@ -2211,8 +2211,8 @@ inline void ggml_sycl_op_mul_mat_sycl( { const float alpha = 1.0f; const float beta = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm( - get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff, + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( + *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc))); } @@ -3165,8 +3165,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons const int64_t smb = ne12 == 1 ? s13 : s12; // there is no broadcast and src0, src1 are contiguous across dims 2, 3 - SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans, - oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma, src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf, mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type))); @@ -3190,7 +3190,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons }); SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + *queue, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta, (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get()))); @@ -3524,12 +3524,11 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE); #endif // SYCL_USE_XMX - // mmvq path is faster in the CUDA backend. - if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda - // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization - // is enabled takes precedence over DMMV, the current if-else implementation - // requires disabling DMMV if both conditions are met - || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) { + // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization + // is enabled takes precedence over DMMV, the current if-else implementation + // requires disabling DMMV if both conditions are met + if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) && + ggml_sycl_supports_reorder_mmvq(src0->type)))) { use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; } @@ -4189,16 +4188,6 @@ void ggml_backend_sycl_get_device_memory(int device, size_t *free, GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n"); ggml_sycl_set_device(device); - /* - DPCT1009:218: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string was - inserted. You need to rewrite this code. - */ - /* - DPCT1106:217: 'cudaMemGetInfo' was migrated with the Intel extensions for - device information which may not be supported by all compilers or runtimes. - You may need to adjust the code. - */ SYCL_CHECK(CHECK_TRY_ERROR( dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total))); } diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp index 3a17f3a1b88..f52b11f0d6e 100644 --- a/ggml/src/ggml-sycl/outprod.cpp +++ b/ggml/src/ggml-sycl/outprod.cpp @@ -32,12 +32,12 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { // Handle transposition of src1 const bool src1_T = ggml_is_transposed(src1); - const oneapi::math::transpose src1_op = src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans; + const oneapi::mkl::transpose src1_op = src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans; const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); try { - // Perform matrix multiplication using oneMath GEMM - oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op, + // Perform matrix multiplication using oneMKL GEMM + oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0); } catch (sycl::exception const& exc) { diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index 69140b19a4c..aeaa58b95b3 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -207,7 +207,6 @@ static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, cons const int p = sector; theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p); } else { - // Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0] const int p = sector - sections.v[0]; theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p); } diff --git a/ggml/src/ggml-sycl/wkv.cpp b/ggml/src/ggml-sycl/wkv.cpp index c10e2f7645e..b56e0c2400f 100644 --- a/ggml/src/ggml-sycl/wkv.cpp +++ b/ggml/src/ggml-sycl/wkv.cpp @@ -1,7 +1,7 @@ #include #include "wkv.hpp" -constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE +constexpr int WKV_BLOCK_SIZE = 64; // Helper function for the main kernel template From 871063016d1f72a74a55a0cc5e0db485aba8f74e Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 3 Feb 2026 01:19:55 +0800 Subject: [PATCH 087/831] ggml-cpu: FA split across kv for faster TG (llama/19209) * ggml-cpu: split across kv for faster TG * simplify sinks application * add ref impl --- ggml/include/ggml-cpu.h | 5 + ggml/src/ggml-cpu/ggml-cpu-impl.h | 3 + ggml/src/ggml-cpu/ggml-cpu.c | 22 ++- ggml/src/ggml-cpu/ggml-cpu.cpp | 15 ++ ggml/src/ggml-cpu/ops.cpp | 231 ++++++++++++++++++++++-------- 5 files changed, 210 insertions(+), 66 deletions(-) diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index 4f3b99c8d07..e3e067c916f 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -19,6 +19,9 @@ extern "C" { // abort ggml_graph_compute when true ggml_abort_callback abort_callback; void * abort_callback_data; + + // use only reference implementations + bool use_ref; }; // numa strategies @@ -132,6 +135,8 @@ extern "C" { GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool); GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data); + GGML_BACKEND_API void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref); + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t); diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 0e8dd0ae053..88a9c9ec057 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -24,6 +24,9 @@ struct ggml_compute_params { void * wdata; struct ggml_threadpool * threadpool; + + // use reference implementation + bool use_ref; }; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index b1de2ae8716..3e5f01e3fb6 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -5,7 +5,6 @@ #include "ggml-backend.h" #include "traits.h" #include "ggml-cpu-impl.h" -#include "ggml-cpu.h" #include "ggml-impl.h" #include "quants.h" #include "ggml-threading.h" @@ -2867,12 +2866,20 @@ struct ggml_cplan ggml_graph_plan( } break; case GGML_OP_FLASH_ATTN_EXT: { + const int64_t neq2 = node->src[0]->ne[2]; // number of query heads const int64_t DK = node->src[1]->ne[0]; const int64_t DV = node->src[2]->ne[0]; // Tiled flash attention scratch (tile sizes defined in common.h) // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding - cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks; + size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks; + + // Decode path: n_kv_chunks = n_tasks (one chunk per thread) + // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ + size_t n_chunks = n_tasks; + size_t decode = sizeof(float)*(neq2*n_chunks*(2+DV) + n_tasks*(DK + 2*DV)); + + cur += MAX(prefill, decode); } break; case GGML_OP_FLASH_ATTN_BACK: { @@ -2929,11 +2936,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { set_numa_thread_affinity(state->ith); struct ggml_compute_params params = { - /*.ith =*/ state->ith, - /*.nth =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK, - /*.wsize =*/ cplan->work_size, - /*.wdata =*/ cplan->work_data, - /*.threadpool=*/ tp, + /*.ith =*/ state->ith, + /*.nth =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK, + /*.wsize =*/ cplan->work_size, + /*.wdata =*/ cplan->work_data, + /*.threadpool =*/ tp, + /*.use_ref =*/ cplan->use_ref, }; GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index f4713a42185..ddf1737a317 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -105,6 +105,8 @@ struct ggml_backend_cpu_context { ggml_abort_callback abort_callback; void * abort_callback_data; + + bool use_ref; // use reference implementation }; static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) { @@ -143,6 +145,7 @@ static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback; cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data; + cpu_plan->cplan.use_ref = cpu_ctx->use_ref; return cpu_plan; } @@ -182,6 +185,7 @@ static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, s cplan.abort_callback = cpu_ctx->abort_callback; cplan.abort_callback_data = cpu_ctx->abort_callback_data; + cplan.use_ref = cpu_ctx->use_ref; return ggml_graph_compute(cgraph, &cplan); } @@ -223,6 +227,7 @@ ggml_backend_t ggml_backend_cpu_init(void) { ctx->work_size = 0; ctx->abort_callback = NULL; ctx->abort_callback_data = NULL; + ctx->use_ref = false; ggml_backend_t cpu_backend = new ggml_backend { /* .guid = */ ggml_backend_cpu_guid(), @@ -270,6 +275,13 @@ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_ ctx->abort_callback_data = abort_callback_data; } +void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->use_ref = use_ref; +} + // CPU backend - device struct ggml_backend_cpu_device_context { @@ -646,6 +658,9 @@ static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const ch if (strcmp(name, "ggml_backend_cpu_is_numa") == 0) { return (void *)ggml_is_numa; } + if (strcmp(name, "ggml_backend_cpu_set_use_ref") == 0) { + return (void *)ggml_backend_cpu_set_use_ref; + } // threadpool - TODO: move to ggml-base if (strcmp(name, "ggml_threadpool_new") == 0) { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 48c89643619..ce15b18ce0e 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8042,12 +8042,14 @@ void ggml_compute_forward_top_k( } } -// ggml_compute_forward_flash_attn_ext - static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const ggml_compute_params * params, ggml_tensor * dst, - int ir0, int ir1) { + int ir0, int ir1, + int64_t ic_start, int64_t ic_end, + float * partials, int64_t partial_stride) { + + const bool write_partials = (partials != nullptr); const ggml_tensor * q = dst->src[0]; const ggml_tensor * k = dst->src[1]; const ggml_tensor * v = dst->src[2]; @@ -8124,7 +8126,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( int ith = params->ith; - // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices const int iq3 = ir/(neq2*neq1); @@ -8165,7 +8166,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf - for (int64_t ic = 0; ic < nek1; ++ic) { + for (int64_t ic = ic_start; ic < ic_end; ++ic) { const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f; if (mv == -INFINITY) { continue; @@ -8238,8 +8239,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( } } - // sinks - if (sinks) { + // sinks - apply only on the first kv-chunk + if (sinks && ic_start == 0) { const float s = ((float *)((char *) sinks->data))[h]; float ms = 1.0f; @@ -8247,6 +8248,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( if (s > M) { ms = expf(M - s); + M = s; ggml_vec_scale_f32(DV, VKQ32, ms); } else { vs = expf(s - M); @@ -8255,20 +8257,26 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( S = S*ms + vs; } - // V /= S - const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; - ggml_vec_scale_f32(DV, VKQ32, S_inv); - - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + if (write_partials) { + // Write M, S, VKQ to partials for later reduction + // partials layout: [M, S, VKQ[DV]] per query head + float * partial = partials + ir * partial_stride; + partial[0] = M; + partial[1] = S; + memcpy(partial + 2, VKQ32, DV * sizeof(float)); + } else { + // V /= S + const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; + ggml_vec_scale_f32(DV, VKQ32, S_inv); - // original - //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; - // permute(0, 2, 1, 3) - memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1); + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1); + } } } @@ -8546,6 +8554,78 @@ static void ggml_compute_forward_flash_attn_ext_tiled( } } +// Reduction function: combines partial results across KV chunks +// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV] +static void ggml_flash_attn_ext_reduce_partials( + const ggml_compute_params * params, + ggml_tensor * dst, + const int64_t n_chunks, + const int64_t chunk_size) { + + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + + const int64_t DK = k->ne[0]; + const int64_t DV = v->ne[0]; + const int64_t nek1 = k->ne[1]; + const int64_t n_q_heads = q->ne[2]; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32; + float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread; + + const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32); + const int64_t partial_size = 2 + DV; + const float * partials_base = (const float *) params->wdata + partials_offset; + + // Output layout + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const size_t nb1 = dst->nb[1]; + + // Each thread reduces a subset of query heads + for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) { + float M_final = -INFINITY; + float S_final = 0.0f; + float * VKQ_final = thread_wdata; + memset(VKQ_final, 0, DV * sizeof(float)); + + // Combine partials from all chunks + for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) { + const int64_t ic_start = chunk_idx * chunk_size; + if (ic_start >= nek1) continue; + + const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size; + const float M_chunk = partial[0]; + const float S_chunk = partial[1]; + const float * VKQ_chunk = partial + 2; + + if (S_chunk == 0.0f) continue; + + const float M_new = fmaxf(M_final, M_chunk); + const float scale_old = expf(M_final - M_new); + const float scale_new = expf(M_chunk - M_new); + + for (int64_t d = 0; d < DV; ++d) { + VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new; + } + S_final = S_final * scale_old + S_chunk * scale_new; + M_final = M_new; + } + + // Normalize and write to output + if (S_final != 0.0f) { + const float S_inv = 1.0f / S_final; + ggml_vec_scale_f32(DV, VKQ_final, S_inv); + } + // iq1=0, iq3=0 for decode + memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1); + } +} + static void ggml_compute_forward_flash_attn_ext_f16( const ggml_compute_params * params, ggml_tensor * dst) { @@ -8567,6 +8647,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t DV = nev0; const int64_t N = neq1; + GGML_ASSERT(ne0 == DV); GGML_ASSERT(ne2 == N); @@ -8587,60 +8668,92 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - // parallelize by q rows using ggml_vec_dot_f32 - - // total rows in q - const int64_t nr = neq1*neq2*neq3; - - // rows per thread const int ith = params->ith; const int nth = params->nth; - // disable for NUMA - const bool disable_chunking = ggml_is_numa(); + // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking) + const bool use_ref = params->use_ref; - // 4x chunks per thread - int nth_scaled = nth * 4; - int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; - int64_t nchunk = (nr + chunk_size - 1) / chunk_size; + const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16); + const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512; - if (nth == 1 || nchunk < nth || disable_chunking) { - nchunk = nth; - } + if (use_split_kv_path) { + const int64_t chunk_size = (nek1 + nth - 1) / nth; - if (ith == 0) { - // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. - ggml_threadpool_chunk_set(params->threadpool, nth); - } + // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ] + const int64_t partial_size = 2 + DV; + float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32); - ggml_barrier(params->threadpool); + const int64_t ic_start = ith * chunk_size; + const int64_t ic_end = std::min(ic_start + chunk_size, nek1); - // The number of elements in each chunk - const int64_t dr = (nr + nchunk - 1) / nchunk; + const int64_t partial_stride = nth * partial_size; + float * chunk_partials = partials_base + ith * partial_size; - static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV; - static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; - const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16); - const bool use_tiled = (q->type == GGML_TYPE_F32 && - kv_is_f32_or_f16 && - k->type == v->type && - nek1 % KV_TILE_SZ == 0 && - neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size + if (ic_start < nek1) { + for (int64_t q_head = 0; q_head < neq2; q_head++) { + ggml_compute_forward_flash_attn_ext_f16_one_chunk( + params, dst, q_head, q_head + 1, ic_start, ic_end, + chunk_partials, partial_stride); + } + } else { + for (int64_t q_head = 0; q_head < neq2; q_head++) { + float * q_partials = chunk_partials + q_head * partial_stride; + q_partials[0] = -INFINITY; // M + q_partials[1] = 0.0f; // S + } + } - // The first chunk comes from our thread_id, the rest will get auto-assigned. - int current_chunk = ith; + ggml_barrier(params->threadpool); + ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size); + } else { - while (current_chunk < nchunk) { - const int64_t ir0 = dr * current_chunk; - const int64_t ir1 = MIN(ir0 + dr, nr); + // total rows in q + const int64_t nr = neq1*neq2*neq3; - if (use_tiled) { - ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1); - } else { - ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1); + // disable for NUMA + const bool disable_chunking = ggml_is_numa(); + + // 4x chunks per thread + int nth_scaled = nth * 4; + int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + int64_t nchunk = (nr + chunk_size - 1) / chunk_size; + + if (nth == 1 || nchunk < nth || disable_chunking) { + nchunk = nth; } - current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + if (ith == 0) { + ggml_threadpool_chunk_set(params->threadpool, nth); + } + + ggml_barrier(params->threadpool); + + const int64_t dr = (nr + nchunk - 1) / nchunk; + + static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV; + static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; + const bool use_tiled = !use_ref && + (q->type == GGML_TYPE_F32 && + kv_is_f32_or_f16 && + k->type == v->type && + nek1 % KV_TILE_SZ == 0 && + neq1 >= Q_TILE_SZ); + + int current_chunk = ith; + + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); + + if (use_tiled) { + ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1); + } else { + ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0); + } + + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + } } } From 591072fcc8226e7364ef409fd0b3fb5885638668 Mon Sep 17 00:00:00 2001 From: lhez Date: Mon, 2 Feb 2026 15:54:43 -0800 Subject: [PATCH 088/831] opencl: refactor some ops, concat, repeat, tanh and scale (llama/19226) * opencl: refactor concat * opencl: refactor repeat * opencl: refactor tanh * opencl: enable fp16 for tanh * opencl: refactor scale * opencl: fix unused variables --- ggml/src/ggml-opencl/ggml-opencl.cpp | 474 +++++++++++-------------- ggml/src/ggml-opencl/kernels/concat.cl | 140 +++----- ggml/src/ggml-opencl/kernels/repeat.cl | 63 ++-- ggml/src/ggml-opencl/kernels/scale.cl | 18 +- ggml/src/ggml-opencl/kernels/tanh.cl | 142 +++++--- 5 files changed, 395 insertions(+), 442 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0f0eb3a9d87..508b2b8f037 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -453,7 +453,6 @@ struct ggml_backend_opencl_context { cl_program program_rms_norm; cl_program program_group_norm; cl_program program_rope; - cl_program program_scale; cl_program program_silu; cl_program program_sigmoid; cl_program program_softmax_f32; @@ -462,11 +461,8 @@ struct ggml_backend_opencl_context { cl_program program_softmax_4_f16; cl_program program_argsort_f32_i32; cl_program program_sum_rows_f32; - cl_program program_repeat; cl_program program_pad; - cl_program program_tanh; cl_program program_upscale; - cl_program program_concat; cl_program program_conv_2d_f16; cl_program program_conv_2d_f32; cl_program program_conv_2d_f16_f32; @@ -485,7 +481,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16; cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16; cl_kernel kernel_add_id; - cl_kernel kernel_scale; + cl_kernel kernel_scale_f32, kernel_scale_f32_4; cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4; cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4; cl_kernel kernel_mean_f32; @@ -544,18 +540,17 @@ struct ggml_backend_opencl_context { cl_kernel kernel_im2col_f32, kernel_im2col_f16; cl_kernel kernel_argsort_f32_i32; cl_kernel kernel_sum_rows_f32; - cl_kernel kernel_repeat; + cl_kernel kernel_repeat_f32; cl_kernel kernel_pad; - cl_kernel kernel_tanh_f32_nd; - cl_kernel kernel_tanh_f16_nd; + cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc; + cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc; cl_kernel kernel_expm1_f32_nd; cl_kernel kernel_expm1_f16_nd; cl_kernel kernel_softplus_f32_nd; cl_kernel kernel_softplus_f16_nd; cl_kernel kernel_upscale; cl_kernel kernel_upscale_bilinear; - cl_kernel kernel_concat_f32_contiguous; - cl_kernel kernel_concat_f32_non_contiguous; + cl_kernel kernel_concat_f32; cl_kernel kernel_conv_2d_f16; cl_kernel kernel_conv_2d_f32; cl_kernel kernel_conv_2d_f16_f32; @@ -1483,10 +1478,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #else const std::string kernel_src = read_file("scale.cl"); #endif - backend_ctx->program_scale = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program_scale, "kernel_scale", &err), err)); + CL_CHECK((backend_ctx->kernel_scale_f32 = clCreateKernel(prog, "kernel_scale_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_scale_f32_4 = clCreateKernel(prog, "kernel_scale_f32_4", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -1814,16 +1811,11 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #else const std::string kernel_src = read_file("repeat.cl"); #endif - if (!kernel_src.empty()) { - backend_ctx->program_repeat = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_repeat = clCreateKernel(backend_ctx->program_repeat, "kernel_repeat", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: repeat kernel source not found or empty. Repeat operations will not be available.\n"); - backend_ctx->program_repeat = nullptr; - backend_ctx->kernel_repeat = nullptr; - } + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_repeat_f32 = clCreateKernel(prog, "kernel_repeat_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } // pad @@ -1856,18 +1848,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #else const std::string kernel_src = read_file("tanh.cl"); #endif - if (!kernel_src.empty()) { - backend_ctx->program_tanh = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_tanh_f32_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f32_nd", &err), err)); - CL_CHECK((backend_ctx->kernel_tanh_f16_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f16_nd", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: tanh kernel source not found or empty. Tanh operation will not be available.\n"); - backend_ctx->program_tanh = nullptr; - backend_ctx->kernel_tanh_f32_nd = nullptr; - backend_ctx->kernel_tanh_f16_nd = nullptr; - } + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_tanh_f32 = clCreateKernel(prog, "kernel_tanh_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f32_4 = clCreateKernel(prog, "kernel_tanh_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f32_nc = clCreateKernel(prog, "kernel_tanh_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f16 = clCreateKernel(prog, "kernel_tanh_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f16_4 = clCreateKernel(prog, "kernel_tanh_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f16_nc = clCreateKernel(prog, "kernel_tanh_f16_nc", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } // expm1 @@ -1959,22 +1949,13 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #include "concat.cl.h" }; #else - const std::string kernel_src = read_file("concat.cl"); #endif - if (!kernel_src.empty()) { - backend_ctx->program_concat = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_concat_f32_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_contiguous", &err), err)); - CL_CHECK((backend_ctx->kernel_concat_f32_non_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_non_contiguous", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: concat kernel source not found or empty. Concat operations will not be available.\n"); - backend_ctx->program_concat = nullptr; - backend_ctx->kernel_concat_f32_contiguous = nullptr; - backend_ctx->kernel_concat_f32_non_contiguous = nullptr; - } + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_concat_f32 = clCreateKernel(prog, "kernel_concat_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } // timestep_embedding @@ -3318,8 +3299,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_SIGMOID: return ggml_is_contiguous(op->src[0]); case GGML_UNARY_OP_TANH: - return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || - (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16); + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_UNARY_OP_EXPM1: return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16); @@ -7029,79 +7009,87 @@ static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - cl_ulong offset0_abs = extra0->offset + src0->view_offs; - cl_ulong offsetd_abs = extrad->offset + dst->view_offs; + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_kernel kernel; - if (dst->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_tanh_f32_nd; - } else if (dst->type == GGML_TYPE_F16) { - kernel = backend_ctx->kernel_tanh_f16_nd; - } else { - GGML_ASSERT(false && "Unsupported type for ggml_cl_tanh"); - } - GGML_ASSERT(kernel != nullptr); + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; - const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; const int ne03 = src0->ne[3]; - const cl_ulong nb00 = src0->nb[0]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3]; + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; - const int ne10 = dst->ne[0]; const int ne11 = dst->ne[1]; const int ne12 = dst->ne[2]; const int ne13 = dst->ne[3]; - const cl_ulong nb10 = dst->nb[0]; const cl_ulong nb11 = dst->nb[1]; const cl_ulong nb12 = dst->nb[2]; const cl_ulong nb13 = dst->nb[3]; + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs)); + cl_kernel kernel; - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03)); + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_tanh_f32_4; + } else { + kernel = backend_ctx->kernel_tanh_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_tanh_f32; + } else { + kernel = backend_ctx->kernel_tanh_f16; + } + } - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - size_t global_work_size[3]; - if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements - return; - } - global_work_size[0] = (size_t)ne10; - global_work_size[1] = (size_t)ne11; - global_work_size[2] = (size_t)ne12; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; - size_t lws0 = 16, lws1 = 4, lws2 = 1; - if (ne10 < 16) lws0 = ne10; - if (ne11 < 4) lws1 = ne11; - if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1; + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } - while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + } else { + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_tanh_f32_nc; + } else { + kernel = backend_ctx->kernel_tanh_f16_nc; + } + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + + int nth = 64; - size_t local_work_size[] = {lws0, lws1, lws2}; + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; - size_t* local_work_size_ptr = local_work_size; - if (!backend_ctx->non_uniform_workgroups) { - if (global_work_size[0] % local_work_size[0] != 0 || - global_work_size[1] % local_work_size[1] != 0 || - global_work_size[2] % local_work_size[2] != 0) { - local_work_size_ptr = NULL; - } + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } - if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return; - - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -7319,53 +7307,58 @@ static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, con ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - if (backend_ctx->kernel_repeat == nullptr) { - GGML_LOG_WARN("%s: repeat kernel not available, skipping OpenCL execution.\n", __func__); - return; - } + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_ulong off_src0 = extra_src0->offset + src0->view_offs; - cl_ulong off_dst = extra_dst->offset + dst->view_offs; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; - const int src0_ne0 = src0->ne[0]; const int src0_ne1 = src0->ne[1]; const int src0_ne2 = src0->ne[2]; const int src0_ne3 = src0->ne[3]; - const cl_ulong src0_nb0 = src0->nb[0]; const cl_ulong src0_nb1 = src0->nb[1]; const cl_ulong src0_nb2 = src0->nb[2]; const cl_ulong src0_nb3 = src0->nb[3]; + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - const int dst_ne0 = dst->ne[0]; const int dst_ne1 = dst->ne[1]; const int dst_ne2 = dst->ne[2]; const int dst_ne3 = dst->ne[3]; - const cl_ulong dst_nb0 = dst->nb[0]; const cl_ulong dst_nb1 = dst->nb[1]; const cl_ulong dst_nb2 = dst->nb[2]; const cl_ulong dst_nb3 = dst->nb[3]; + cl_kernel kernel = backend_ctx->kernel_repeat_f32; - cl_kernel kernel = backend_ctx->kernel_repeat; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3)); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra_dst->data_device)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &off_src0)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &src0_ne0)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &src0_ne1)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &src0_ne2)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &src0_ne3)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &src0_nb0)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &src0_nb1)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &src0_nb2)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &src0_nb3)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &dst_ne0)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &dst_ne1)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &dst_ne2)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dst_ne3)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &dst_nb0)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &dst_nb1)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &dst_nb2)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &dst_nb3)); - - size_t gws0 = dst_ne1 > 0 ? (size_t)dst_ne1 : 1; - size_t gws1 = dst_ne2 > 0 ? (size_t)dst_ne2 : 1; - size_t gws2 = dst_ne3 > 0 ? (size_t)dst_ne3 : 1; - - size_t global_work_size[] = { gws0, gws1, gws2 }; + int nth = 64; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); + size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { @@ -7589,121 +7582,76 @@ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, con GGML_ASSERT(dst->type == GGML_TYPE_F32); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - cl_command_queue queue = backend_ctx->queue; - if (backend_ctx->kernel_concat_f32_contiguous == nullptr || backend_ctx->kernel_concat_f32_non_contiguous == nullptr) { - GGML_LOG_WARN("%s: concat kernels not available, skipping OpenCL execution.\n", __func__); - return; - } + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - ggml_tensor_extra_cl * extra0_cl = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra1_cl = (ggml_tensor_extra_cl *)src1->extra; - ggml_tensor_extra_cl * extrad_cl = (ggml_tensor_extra_cl *)dst->extra; + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_ulong off_src0 = extra0_cl->offset + src0->view_offs; - cl_ulong off_src1 = extra1_cl->offset + src1->view_offs; - cl_ulong off_dst = extrad_cl->offset + dst->view_offs; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; - const int32_t dim = ((const int32_t *) dst->op_params)[0]; + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const cl_ulong nb10 = src1->nb[0]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + const cl_int dim = ((const int32_t *) dst->op_params)[0]; GGML_ASSERT(dim >= 0 && dim <= 3); - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { - if (dim == 3) { + int nth = MIN(64, ne0); - size_t nbytes_src0 = ggml_nbytes(src0); - size_t nbytes_src1 = ggml_nbytes(src1); + cl_kernel kernel = backend_ctx->kernel_concat_f32; - CL_CHECK(clEnqueueCopyBuffer(queue, extra0_cl->data_device, extrad_cl->data_device, - off_src0, off_dst, nbytes_src0, 0, NULL, NULL)); - CL_CHECK(clEnqueueCopyBuffer(queue, extra1_cl->data_device, extrad_cl->data_device, - off_src1, off_dst + nbytes_src0, nbytes_src1, 0, NULL, NULL)); - } else { + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_int), &dim)); + + size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; - cl_kernel kernel = backend_ctx->kernel_concat_f32_contiguous; - size_t global_work_size[3]; - - for (int i3 = 0; i3 < dst->ne[3]; ++i3) { - cl_ulong current_off_src0 = off_src0 + (i3 * src0->nb[3]); - cl_ulong current_off_src1 = off_src1 + (i3 * src1->nb[3]); - cl_ulong current_off_dst = off_dst + (i3 * dst->nb[3]); - - int d_ne00 = src0->ne[0]; int d_ne01 = src0->ne[1]; int d_ne02 = src0->ne[2]; - int d_ne10 = src1->ne[0]; int d_ne11 = src1->ne[1]; int d_ne12 = src1->ne[2]; - int d_ne0 = dst->ne[0]; int d_ne1 = dst->ne[1]; int d_ne2 = dst->ne[2]; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), ¤t_off_src0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), ¤t_off_src1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), ¤t_off_dst)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &d_ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne10)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &d_ne11)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &d_ne12)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dim)); - - global_work_size[0] = d_ne0; - global_work_size[1] = d_ne1; - global_work_size[2] = d_ne2; - - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); - } - } - } else { - cl_kernel kernel = backend_ctx->kernel_concat_f32_non_contiguous; - - cl_long ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3]; - cl_ulong nb00 = src0->nb[0], nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3]; - - cl_ulong nb10 = src1->nb[0], nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3]; - - cl_long d_ne0 = dst->ne[0], d_ne1 = dst->ne[1], d_ne2 = dst->ne[2], d_ne3 = dst->ne[3]; - cl_ulong d_nb0 = dst->nb[0], d_nb1 = dst->nb[1], d_nb2 = dst->nb[2], d_nb3 = dst->nb[3]; - - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_src1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &off_dst)); - - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_long), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_long), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_long), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_long), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); - - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); - - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_long), &d_ne0)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_long), &d_ne1)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_long), &d_ne2)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_long), &d_ne3)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &d_nb0)); - CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &d_nb1)); - CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &d_nb2)); - CL_CHECK(clSetKernelArg(kernel, 25, sizeof(cl_ulong), &d_nb3)); - CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &dim)); - - size_t global_work_size_nc[] = { d_ne1 > 0 ? (size_t)d_ne1 : 1, - d_ne2 > 0 ? (size_t)d_ne2 : 1, - d_ne3 > 0 ? (size_t)d_ne3 : 1 }; - - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_nc, NULL, dst); - } + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { @@ -8394,6 +8342,7 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t CL_CHECK(clReleaseMemObject(D_sub_buffer)); CL_CHECK(clReleaseMemObject(D_image1d)); #else + GGML_UNUSED(backend); GGML_UNUSED(src0); GGML_UNUSED(src1); GGML_UNUSED(dst); @@ -9913,7 +9862,16 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_kernel kernel = backend_ctx->kernel_scale; + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_scale_f32_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_scale_f32; + } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -9922,8 +9880,6 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias)); - int n = ggml_nelements(dst)/4; - size_t global_work_size[] = {(size_t)n, 1, 1}; size_t local_work_size[] = {64, 1, 1}; diff --git a/ggml/src/ggml-opencl/kernels/concat.cl b/ggml/src/ggml-opencl/kernels/concat.cl index 132758469c6..0c1b3d785ca 100644 --- a/ggml/src/ggml-opencl/kernels/concat.cl +++ b/ggml/src/ggml-opencl/kernels/concat.cl @@ -1,109 +1,51 @@ -kernel void kernel_concat_f32_contiguous( - global const char * p_src0, ulong off_src0, - global const char * p_src1, ulong off_src1, - global char * p_dst, ulong off_dst, - int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice - int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes) - int d_ne0, int d_ne1, int d_ne2, // dst->ne[0..2] for the slice - int dim +kernel void kernel_concat_f32( + global const char * src0, + ulong offset0, + global const char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int dim ) { - global const float * src0 = (global const float*)((global char*)p_src0 + off_src0); - global const float * src1 = (global const float*)((global char*)p_src1 + off_src1); - global float * dst = (global float*)((global char*)p_dst + off_dst); + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; - int i0 = get_global_id(0); // Index along dst's 0th dimension - int i1 = get_global_id(1); // Index along dst's 1st dimension - int i2 = get_global_id(2); // Index along dst's 2nd dimension + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) { - return; - } - - ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0; - ulong src_idx; - - if (dim == 0) { - if (i0 < d_ne00) { // Data from src0 - src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; - dst[dst_idx] = src0[src_idx]; - } else { // Data from src1 - src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00); - dst[dst_idx] = src1[src_idx]; - } - } else if (dim == 1) { - if (i1 < d_ne01) { // Data from src0 - src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; - dst[dst_idx] = src0[src_idx]; - } else { // Data from src1 - src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0; - dst[dst_idx] = src1[src_idx]; - } - } else if (dim == 2) { - if (i2 < d_ne02) { // Data from src0 - src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; - dst[dst_idx] = src0[src_idx]; - } else { // Data from src1 - - src_idx = (ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0; - dst[dst_idx] = src1[src_idx]; - } - } -} - -kernel void kernel_concat_f32_non_contiguous( - global const char * p_src0, ulong off_src0, - global const char * p_src1, ulong off_src1, - global char * p_dst, ulong off_dst, - - long ne00, long ne01, long ne02, long ne03, - ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); - ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1 + global const float * x; - long d_ne0, long d_ne1, long d_ne2, long d_ne3, - ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3, - int dim -) { - global const char * src0_base = p_src0 + off_src0; - global const char * src1_base = p_src1 + off_src1; - global char * dst_base = p_dst + off_dst; - - long current_i1 = get_global_id(0); // Index for dst_dim_1 - long current_i2 = get_global_id(1); // Index for dst_dim_2 - long current_i3 = get_global_id(2); // Index for dst_dim_3 - - if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3) { - return; - } - - global const float * x_val_ptr; - global float * y_val_ptr; - - for (long current_i0 = 0; current_i0 < d_ne0; ++current_i0) { - bool use_src0; - long s_i0 = current_i0, s_i1 = current_i1, s_i2 = current_i2, s_i3 = current_i3; - - if (dim == 0) { - use_src0 = (current_i0 < ne00); - if (!use_src0) { s_i0 = current_i0 - ne00; } - } else if (dim == 1) { - use_src0 = (current_i1 < ne01); - if (!use_src0) { s_i1 = current_i1 - ne01; } - } else if (dim == 2) { - use_src0 = (current_i2 < ne02); - if (!use_src0) { s_i2 = current_i2 - ne02; } - } else { // dim == 3 - use_src0 = (current_i3 < ne03); - if (!use_src0) { s_i3 = current_i3 - ne03; } - } - - if (use_src0) { - x_val_ptr = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00); + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (global const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); } else { - x_val_ptr = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10); + x = (global const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); } - y_val_ptr = (global float *)(dst_base + (ulong)current_i3*d_nb3 + (ulong)current_i2*d_nb2 + (ulong)current_i1*d_nb1 + (ulong)current_i0*d_nb0); - *y_val_ptr = *x_val_ptr; + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = *x; } } diff --git a/ggml/src/ggml-opencl/kernels/repeat.cl b/ggml/src/ggml-opencl/kernels/repeat.cl index 079498f5ab9..53951a55434 100644 --- a/ggml/src/ggml-opencl/kernels/repeat.cl +++ b/ggml/src/ggml-opencl/kernels/repeat.cl @@ -1,39 +1,38 @@ -kernel void kernel_repeat( - global const char * src0_data_in, - global char * dst_data_in, - ulong src0_offset, - ulong dst_offset, - int src0_ne0, int src0_ne1, int src0_ne2, int src0_ne3, - ulong src0_nb0, ulong src0_nb1, ulong src0_nb2, ulong src0_nb3, - int dst_ne0, int dst_ne1, int dst_ne2, int dst_ne3, - ulong dst_nb0, ulong dst_nb1, ulong dst_nb2, ulong dst_nb3 +kernel void kernel_repeat_f32( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - global const char * src0_data = src0_data_in + src0_offset; - global char * dst_data = dst_data_in + dst_offset; + src0 = src0 + offset0; + dst = dst + offsetd; - const int d3 = get_global_id(2); - const int d2 = get_global_id(1); - const int d1 = get_global_id(0); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - if (d3 >= dst_ne3 || d2 >= dst_ne2 || d1 >= dst_ne1) { - return; - } - - const int s3 = d3 % src0_ne3; - const int s2 = d2 % src0_ne2; - const int s1 = d1 % src0_ne1; - - const global char * p_src0_slice = src0_data + (ulong)s3*src0_nb3 + (ulong)s2*src0_nb2 + (ulong)s1*src0_nb1; - global char * p_dst_slice = dst_data + (ulong)d3*dst_nb3 + (ulong)d2*dst_nb2 + (ulong)d1*dst_nb1; + const int i03 = i3%ne03; + const int i02 = i2%ne02; + const int i01 = i1%ne01; - for (int d0 = 0; d0 < dst_ne0; ++d0) { - // Determine source index for dimension 0 based on tiling/broadcasting. - const int s0 = d0 % src0_ne0; + global const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1; - const global char * restrict current_src_el_ptr = p_src0_slice + (ulong)s0*src0_nb0; - global char * restrict current_dst_el_ptr = p_dst_slice + (ulong)d0*dst_nb0; - for (int k = 0; k < src0_nb0; ++k) { - current_dst_el_ptr[k] = current_src_el_ptr[k]; - } + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i00 = i0%ne00; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i00*nb00)); } } diff --git a/ggml/src/ggml-opencl/kernels/scale.cl b/ggml/src/ggml-opencl/kernels/scale.cl index aeca8a456e4..17ed97f0d66 100644 --- a/ggml/src/ggml-opencl/kernels/scale.cl +++ b/ggml/src/ggml-opencl/kernels/scale.cl @@ -1,9 +1,19 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable -//------------------------------------------------------------------------------ -// scale -//------------------------------------------------------------------------------ -kernel void kernel_scale( +kernel void kernel_scale_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + float scale, + float bias +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias; +} + +kernel void kernel_scale_f32_4( global float4 * src0, ulong offset0, global float4 * dst, diff --git a/ggml/src/ggml-opencl/kernels/tanh.cl b/ggml/src/ggml-opencl/kernels/tanh.cl index d9da86b1489..2c4887ad3e0 100644 --- a/ggml/src/ggml-opencl/kernels/tanh.cl +++ b/ggml/src/ggml-opencl/kernels/tanh.cl @@ -1,63 +1,109 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable -#ifdef cl_intel_required_subgroup_size -#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#endif - -kernel void kernel_tanh_f32_nd( - global void * p_src0_base, ulong off_src0_abs, - global void * p_dst_base, ulong off_dst_abs, - int ne00, int ne01, int ne02, int ne03, - ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, int ne11, int ne12, int ne13, - ulong nb10, ulong nb11, ulong nb12, ulong nb13 +kernel void kernel_tanh_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + dst[get_global_id(0)] = tanh(src0[get_global_id(0)]); +} + +kernel void kernel_tanh_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = tanh(src0[get_global_id(0)]); +} + +kernel void kernel_tanh_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = tanh(src0[get_global_id(0)]); +} + +kernel void kernel_tanh_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = tanh(src0[get_global_id(0)]); +} + +kernel void kernel_tanh_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = tanh(*src_val_ptr); - } + *y = tanh(*x); } } -kernel void kernel_tanh_f16_nd( - global void * p_src0_base, ulong off_src0_abs, - global void * p_dst_base, ulong off_dst_abs, - int ne00, int ne01, int ne02, int ne03, - ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, int ne11, int ne12, int ne13, - ulong nb10, ulong nb11, ulong nb12, ulong nb13 +kernel void kernel_tanh_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = tanh(*src_val_ptr); - } + *y = tanh(*x); } } From 6ec362d2e0d705892d6ded95cafbd7877d332f83 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Tue, 3 Feb 2026 12:11:02 +0530 Subject: [PATCH 089/831] cuda : revert CUDA_SCALE_LAUNCH_QUEUES override until investigated (llama/19227) Hangs were reported on Jetson Orin AGX if we set CUDA_SCALE_LAUNCH_QUEUES=4x. Reverting the previous PR (#19042) and updating the document to consider setting CUDA_SCALE_LAUNCH_QUEUES=4x for faster throughput on multi-GPU systems. --- ggml/src/ggml-cuda/ggml-cuda.cu | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 08383edb402..1bcd1ab1f8f 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -5049,16 +5049,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { - // Set CUDA_SCALE_LAUNCH_QUEUES before any CUDA API call to improve multi-GPU pipeline parallelism performance - // PR: https://github.com/ggml-org/llama.cpp/pull/19042 - if (getenv("CUDA_SCALE_LAUNCH_QUEUES") == nullptr) { -#ifdef _WIN32 - _putenv_s("CUDA_SCALE_LAUNCH_QUEUES", "4x"); -#else - setenv("CUDA_SCALE_LAUNCH_QUEUES", "4x", 0); // don't overwrite if already set -#endif // _WIN32 - } - ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; From 57107b2bf88356afd512b1a1c4208606746d8ee2 Mon Sep 17 00:00:00 2001 From: George <35490284+noctrex@users.noreply.github.com> Date: Tue, 3 Feb 2026 08:43:39 +0200 Subject: [PATCH 090/831] ggml: added cleanups in ggml_quantize_free (llama/19278) Add missing cleanup calls for IQ2_S, IQ1_M quantization types and IQ3XS with 512 blocks during quantization cleanup. --- ggml/src/ggml.c | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e1471b540ed..500cb6b72f9 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7517,8 +7517,11 @@ void ggml_quantize_free(void) { iq2xs_free_impl(GGML_TYPE_IQ2_XXS); iq2xs_free_impl(GGML_TYPE_IQ2_XS); + iq2xs_free_impl(GGML_TYPE_IQ2_S); iq2xs_free_impl(GGML_TYPE_IQ1_S); + iq2xs_free_impl(GGML_TYPE_IQ1_M); iq3xs_free_impl(256); + iq3xs_free_impl(512); ggml_critical_section_end(); } From 698265d754069d43c58f74628427bc96b21c107e Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Tue, 3 Feb 2026 11:33:14 +0100 Subject: [PATCH 091/831] CUDA: Fix loop unrolling for BW in mul_mat_q_stream_k_fixup (llama/19053) By providing stride_* variables as size_t (i.e., 64-bit) the compiler can correctly unroll the [two for-loops](https://github.com/ggml-org/llama.cpp/blob/557515be1e93ed8939dd8a7c7d08765fdbe8be31/ggml/src/ggml-cuda/mmq.cuh#L3789-L3816) on BW. This gives some perf for prefill/pp phase on BW, while not affecting other SMs: | GPU | Model | Test | t/s master | t/s osimons/fix_bw_mmq_fixup_kernel | Speedup | |:--------------------------------------------------------|:----------------------|:-------|-------------:|--------------------------------------:|----------:| | NVIDIA RTX 6000 Ada Generation | gpt-oss 20B MXFP4 MoE | pp8096 | 8404.05 | 8375.79 | 1.00 | | NVIDIA RTX 6000 Ada Generation | llama 3B Q4_K_M | pp8096 | 16148.93 | 16019.60 | 0.99 | | NVIDIA RTX 6000 Ada Generation | llama 8B Q4_0 | pp8096 | 8008.29 | 7978.80 | 1.00 | | NVIDIA RTX 6000 Ada Generation | nemotron_h 9B BF16 | pp8096 | 4263.16 | 4248.53 | 1.00 | | NVIDIA RTX 6000 Ada Generation | nemotron_h 9B Q4_K_M | pp8096 | 5165.11 | 5157.43 | 1.00 | | NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | gpt-oss 20B MXFP4 MoE | pp8096 | 12582.80 | 12758.37 | 1.01 | | NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | llama 3B Q4_K_M | pp8096 | 16879.10 | 17619.47 | 1.04 | | NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | llama 8B Q4_0 | pp8096 | 10649.90 | 10982.65 | 1.03 | | NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | nemotron_h 9B BF16 | pp8096 | 7717.73 | 7716.22 | 1.00 | | NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition | nemotron_h 9B Q4_K_M | pp8096 | 7301.90 | 7370.38 | 1.01 | --- ggml/src/ggml-cuda/mmq.cuh | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index a382e6a6979..f80f98cda2c 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3697,13 +3697,20 @@ static __global__ void mul_mat_q( tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); } - template -static __global__ void mul_mat_q_stream_k_fixup( - const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, - const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, - const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst, - const int ncols_max) { +static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, + const int32_t * expert_bounds, + float * __restrict__ dst, + const float * __restrict__ tmp_last_tile, + const int ncols_x, + const int nrows_x, + const int ncols_dst, + const size_t stride_col_dst, + const int nchannels_y, + const size_t stride_channel_dst, + const int nsamples_y, + const size_t stride_sample_dst, + const int ncols_max) { constexpr int mmq_y = get_mmq_y_device(); constexpr int qk = ggml_cuda_type_traits::qk; constexpr int ITER_K = get_iter_k(type); From ce8a2da62004f47522839f1719907f785262e684 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Feb 2026 13:43:29 +0200 Subject: [PATCH 092/831] metal : minor cleanup (llama/19251) --- ggml/src/ggml-metal/ggml-metal-impl.h | 4 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 38 ++++-------- ggml/src/ggml-metal/ggml-metal.metal | 86 +++++++------------------- 3 files changed, 34 insertions(+), 94 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 59d88b01a55..e074f2ef3db 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -81,10 +81,10 @@ #define FC_COUNT_EQUAL 1000 // op-specific constants -#define OP_FLASH_ATTN_EXT_NQPTG 8 +#define OP_FLASH_ATTN_EXT_NQPSG 8 #define OP_FLASH_ATTN_EXT_NCPSG 64 -#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1 +#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1 #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 // kernel argument structs diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 7f4cfbba226..f97c4435dec 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2295,7 +2295,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) { // return res; //} - const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG; + const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG; const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG; const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; @@ -2411,7 +2411,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { // half8x8 kernel - const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup + const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup GGML_ASSERT(nqptg <= 32); @@ -2578,9 +2578,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { #undef FATTN_SMEM } else { // half4x4 kernel - const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup + const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !! - const int nkpsg = 1*ncpsg; + const int nhptg = 1; // heads per threadgroup GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 1 == 0); @@ -2632,6 +2632,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); } + // note: for simplicity assume the K is larger or equal than V + GGML_ASSERT(ne10 >= ne20); + // ne00 + 2*ncpsg*(nsg) // for each query, we load it as f16 in shared memory (ne00) // and store the soft_max values and the mask @@ -2639,28 +2642,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { // ne20*(nsg) // each simdgroup has a full f32 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16)) - - int64_t nsgmax = 2; - while (true) { - const size_t smem = FATTN_SMEM(nsgmax); - // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes - if (smem > props_dev->max_theadgroup_memory_size/2) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; - - // simdgroups per threadgroup (a.k.a. warps) - //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); - const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32))); +#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16)) int64_t nsg = 1; - while (nsg <= nsgt) { - nsg *= 2; - } - nsg /= 2; // workgroups // each workgroup handles nsg*nkpsg cache values @@ -2673,7 +2657,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { } else { nwg = 32; nsg = 1; - while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) { + while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) { nsg *= 2; } } @@ -2739,7 +2723,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1); } else { // sanity checks assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0); @@ -2752,7 +2736,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer(enc, bid_tmp, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1); // sync the 2 kernels ggml_metal_op_concurrency_reset(ctx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 17e358d1a8d..3259213fd61 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -5931,7 +5931,7 @@ template< void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), short DK, // K head size short DV, // V head size - short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup + short Q = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext( constant ggml_metal_kargs_flash_attn_ext & args, @@ -6141,11 +6141,10 @@ template< void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), short DK, // K head size short DV, // V head size - short NE, // head elements per thread - short Q, // queries per threadgroup - short C, // cache items per threadgroup - short NSG> // number of simd groups -void kernel_flash_attn_ext_vec_impl( + short NE = 4, // head elements per thread + short Q = OP_FLASH_ATTN_EXT_VEC_NQPSG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, device const char * k, @@ -6162,6 +6161,7 @@ void kernel_flash_attn_ext_vec_impl( static_assert(DV % 32 == 0, "DV must be divisible by 32"); #define NWG (FC_flash_attn_ext_vec_nwg) +#define NSG (FC_flash_attn_ext_vec_nsg) #define NS10 (FC_flash_attn_ext_vec_ns10) #define NS20 (FC_flash_attn_ext_vec_ns20) @@ -6190,12 +6190,12 @@ void kernel_flash_attn_ext_vec_impl( const short T = PK + NSG*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t - threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask - threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + NSG*PK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + NSG*PK); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + NSG*PK + NSG*SH); // scratch buffer for the results // store the result for all queries in shared memory (the O matrix from the paper) so4 += tiisg; @@ -6213,11 +6213,13 @@ void kernel_flash_attn_ext_vec_impl( // load heads from Q to shared memory device const float4 * q4 = (device const float4 *) ((device const char *) q); - for (short i = tiisg; i < PK4; i += NW) { - if (iq1 < args.ne01 && i < DK4) { - sq4[i] = (q4_t) q4[i]; - } else { - sq4[i] = (q4_t) 0.0f; + if (iq1 < args.ne01) { + for (short i = tiisg; i < PK4; i += NW) { + if (i < DK4) { + sq4[i] = (q4_t) q4[i]; + } else { + sq4[i] = (q4_t) 0.0f; + } } } @@ -6295,7 +6297,7 @@ void kernel_flash_attn_ext_vec_impl( } // skip -INF blocks - if (simd_max(sm[tiisg]) == -INFINITY) { + if (simd_max(sm[tiisg]) <= -MAXHALF) { continue; } @@ -6569,57 +6571,11 @@ void kernel_flash_attn_ext_vec_impl( } #undef NWG +#undef NSG #undef NS10 #undef NS20 } -template< - typename q4_t, // query types in shared memory - typename k4_t, // key types in shared memory - typename v4_t, // value types in shared memory - typename qk_t, // Q*K types - typename s_t, // soft-max types - typename s4_t, - typename o4_t, // attention accumulation types - typename kd4_t, // key type in device memory - short nl_k, - void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), - typename vd4_t, // value type in device memory - short nl_v, - void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), - short DK, // K head size - short DV, // V head size - short NE = 4, // head elements per thread - short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup - short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup -kernel void kernel_flash_attn_ext_vec( - constant ggml_metal_kargs_flash_attn_ext_vec & args, - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device const char * sinks, - device const char * pad, - device char * dst, - threadgroup half * shmem_f16 [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { -#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg - switch (FC_flash_attn_ext_vec_nsg) { - // note: disabled cases to reduce library load time - case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - case 2: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - case 4: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - //case 8: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - //case 16: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - //case 32: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; - } -#undef FWD_TMPL -#undef FWD_ARGS -} - // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max // From 8eede801e3d1799a12f969bba044aaffe59bace7 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 3 Feb 2026 23:31:23 +0800 Subject: [PATCH 093/831] CUDA: use mmvq for mul-mat-id for small batch sizes (llama/18958) * CUDA: use mmvq for mul-mat-id for small batch sizes * add mmvq too * Fix perf issue on ampere. Use mmvf mm-id only for non-nvidia GPUs * templatize multi_token_path --- ggml/src/ggml-cuda/ggml-cuda.cu | 14 ++- ggml/src/ggml-cuda/mmvf.cu | 194 +++++++++++++++++++++----------- ggml/src/ggml-cuda/mmvf.cuh | 2 + ggml/src/ggml-cuda/mmvq.cu | 135 ++++++++++++++-------- 4 files changed, 224 insertions(+), 121 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 1bcd1ab1f8f..eeb8625dbeb 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2279,13 +2279,19 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - if (ne2 == 1) { + static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE); + if (ne2 <= MMVQ_MAX_BATCH_SIZE) { if (ggml_is_quantized(src0->type)) { - ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); + if (ne2 <= 4) { + ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); + return; + } } else { - ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); + if (GGML_CUDA_CC_IS_AMD(cc)) { + ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); + return; + } } - return; } if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) { diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 32948e4d7a1..d9147202429 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -4,26 +4,48 @@ #include "mmvf.cuh" #include "convert.cuh" -template +template static __global__ void mul_mat_vec_f( const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, - const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, + const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ids_stride) { const int row = blockIdx.x; + // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens) const int channel_dst = blockIdx.y; - const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio); - const int channel_y = ids ? channel_dst % nchannels_y : channel_dst; - const int sample_dst = blockIdx.z; + const int tid = threadIdx.x; + + int token_idx; + int channel_x; + int channel_y; + int sample_dst; + + if constexpr (is_multi_token_id) { + // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case + token_idx = blockIdx.z; + channel_x = ids[channel_dst + token_idx * ids_stride]; + channel_y = fastmodulo(channel_dst, nchannels_y); + sample_dst = 0; + } else { + token_idx = ids ? blockIdx.z : 0; + channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv((uint32_t) channel_dst, channel_ratio); + channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst; + sample_dst = ids ? 0 : blockIdx.z; + } + const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio); const int sample_y = sample_dst; - const int tid = threadIdx.x; constexpr int warp_size = ggml_cuda_get_physical_warp_size(); x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; + if constexpr (is_multi_token_id) { + y += token_idx*stride_col_y2*2; + dst += token_idx*stride_col_dst; + } bool use_gate = false; bool use_bias = false; @@ -56,8 +78,10 @@ static __global__ void mul_mat_vec_f( if (use_gate) { gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; } + + const int channel_bias = ids ? channel_x : channel_dst; + if constexpr (has_fusion) { - const int channel_bias = ids ? channel_x : channel_dst; if (use_bias) { x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; } @@ -349,36 +373,36 @@ static __global__ void mul_mat_vec_f( } } -template +template static void mul_mat_vec_f_switch_fusion( const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, - const int64_t ncols, const int64_t nrows, + const int64_t ncols, const uint3 nchannels_y, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) { + const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) { const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f<<>> + (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; } } GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + mul_mat_vec_f<<>> + (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } -template +template void launch_mul_mat_vec_f_cuda( const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int64_t ncols, const int64_t nrows, @@ -386,12 +410,13 @@ void launch_mul_mat_vec_f_cuda( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(stride_row % 2 == 0); GGML_ASSERT(stride_col_y % 2 == 0); GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0); const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); @@ -415,56 +440,56 @@ void launch_mul_mat_vec_f_cuda( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0); - const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); + const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens); const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 64: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 96: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 128: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 160: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 192: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 224: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 256: { - mul_mat_vec_f_switch_fusion - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; default: { GGML_ABORT("fatal error"); @@ -480,55 +505,88 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + const int64_t ids_stride, cudaStream_t stream) { + + const bool has_ids = ids != nullptr; + + if (has_ids && ncols_dst > 1) { + // Multi-token MUL_MAT_ID path only - single-token goes through regular path below + constexpr int c_ncols_dst = 1; + launch_mul_mat_vec_f_cuda + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + ncols_dst, ids_stride, stream); + return; + } + + if (has_ids) { + // Single-token MUL_MAT_ID path + constexpr int c_ncols_dst = 1; + launch_mul_mat_vec_f_cuda + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + ncols_dst, ids_stride, stream); + return; + } + switch (ncols_dst) { case 1: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 2: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 3: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 4: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 5: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 6: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 7: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 8: launch_mul_mat_vec_f_cuda (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; default: GGML_ABORT("fatal error"); @@ -544,21 +602,21 @@ static void mul_mat_vec_f_cuda( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - enum ggml_prec prec, cudaStream_t stream) { + const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) { if constexpr(std::is_same_v) { if (prec == GGML_PREC_DEFAULT) { mul_mat_vec_f_cuda_switch_ncols_dst (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); return; } } mul_mat_vec_f_cuda_switch_ncols_dst (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); } void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, @@ -573,7 +631,7 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor const size_t ts_src1 = ggml_type_size(src1->type); const size_t ts_dst = ggml_type_size(dst->type); - GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. + GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE); GGML_ASSERT(ne13 == ne3); GGML_ASSERT( nb00 == ts_src0); @@ -626,29 +684,31 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor const int64_t ncols_dst = ids ? ne2 : ne1; const int64_t nchannels_y = ids ? ne11 : ne12; const int64_t nchannels_dst = ids ? ne1 : ne2; + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; const int64_t stride_channel_dst = ids ? s1 : s2; const int64_t stride_channel_y = ids ? s11 : s12; - GGML_ASSERT(!ids || ncols_dst == 1); + const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); @@ -695,19 +755,19 @@ void ggml_cuda_op_mul_mat_vec_f( const float * src0_d = (const float *) src0_dd_i; mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0_dd_i; mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); diff --git a/ggml/src/ggml-cuda/mmvf.cuh b/ggml/src/ggml-cuda/mmvf.cuh index a09fbdc7202..a50f7c02180 100644 --- a/ggml/src/ggml-cuda/mmvf.cuh +++ b/ggml/src/ggml-cuda/mmvf.cuh @@ -1,5 +1,7 @@ #include "common.cuh" +#define MMVF_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVF kernels. + void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index d671551c171..ce25ccf427c 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -137,15 +137,15 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -// tell the compiler to use as many registers as it wants, see nwarps definition below -template +template __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, - const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) { + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, + const uint32_t ids_stride) { constexpr int qk = ggml_cuda_type_traits::qk; constexpr int qi = ggml_cuda_type_traits::qi; @@ -162,11 +162,25 @@ static __global__ void mul_mat_vec_q( const int blocks_per_row_x = ncols_x / qk; constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; - // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. const uint32_t channel_dst = blockIdx.y; - const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); - const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; - const uint32_t sample_dst = blockIdx.z; + + uint32_t token_idx = 0; + uint32_t channel_x; + uint32_t channel_y; + uint32_t sample_dst; + + if constexpr (is_multi_token_id) { + // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case + token_idx = blockIdx.z; + channel_x = ids[channel_dst + token_idx * ids_stride]; + channel_y = fastmodulo(channel_dst, nchannels_y); + sample_dst = 0; + } else { + channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + sample_dst = blockIdx.z; + } + const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); const uint32_t sample_y = sample_dst; @@ -188,11 +202,11 @@ static __global__ void mul_mat_vec_q( active_glu = fusion.glu_op; } - const uint32_t channel_bias = ids ? channel_x : channel_dst; float x_biases[ncols_dst] = { 0.0f }; float gate_biases[ncols_dst] = { 0.0f }; if constexpr (has_fusion) { + const uint32_t channel_bias = ids ? channel_x : channel_dst; if (use_bias) { x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; // 1. Hide latency by prefetching bias and gate here @@ -222,6 +236,9 @@ static __global__ void mul_mat_vec_q( float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}}; const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; + if constexpr (is_multi_token_id) { + y += token_idx*stride_col_y; + } const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { @@ -275,6 +292,10 @@ static __global__ void mul_mat_vec_q( dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; + if constexpr (is_multi_token_id) { + dst += token_idx*stride_col_dst; + } + // sum up partial sums and write back result #pragma unroll for (int j = 0; j < ncols_dst; ++j) { @@ -335,40 +356,41 @@ static __global__ void mul_mat_vec_q( } static std::pair calc_launch_params( - const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y, + const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, const int warp_size, const mmvq_parameter_table_id table_id) { const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); - const dim3 block_nums(nblocks, nchannels_y, nsamples_y); + const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1); return {block_nums, block_dims}; } -template +template static void mul_mat_vec_q_switch_fusion( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, - const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) { + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, + const uint32_t ids_stride, cudaStream_t stream) { const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; } } GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } template @@ -379,7 +401,7 @@ static void mul_mat_vec_q_switch_ncols_dst( const int nchannels_x, const int nchannels_y, const int nchannels_dst, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - cudaStream_t stream) { + const int ids_stride, cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); @@ -393,8 +415,19 @@ static void mul_mat_vec_q_switch_ncols_dst( const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + const bool has_ids = ids != nullptr; + + if (has_ids && ncols_dst > 1) { + // Multi-token MUL_MAT_ID path only - single-token goes through regular path below + constexpr int c_ncols_dst = 1; + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + return; + } - GGML_ASSERT(!ids || ncols_dst == 1); switch (ncols_dst) { case 1: { constexpr int c_ncols_dst = 1; @@ -402,7 +435,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 2: { constexpr int c_ncols_dst = 2; @@ -410,7 +443,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 3: { constexpr int c_ncols_dst = 3; @@ -418,7 +451,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 4: { constexpr int c_ncols_dst = 4; @@ -426,7 +459,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 5: { constexpr int c_ncols_dst = 5; @@ -434,7 +467,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 6: { constexpr int c_ncols_dst = 6; @@ -442,7 +475,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 7: { constexpr int c_ncols_dst = 7; @@ -450,7 +483,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 8: { constexpr int c_ncols_dst = 8; @@ -458,7 +491,7 @@ static void mul_mat_vec_q_switch_ncols_dst( mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; default: GGML_ABORT("fatal error"); @@ -474,127 +507,127 @@ static void mul_mat_vec_q_switch_type( const int nchannels_x, const int nchannels_y, const int nchannels_dst, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - cudaStream_t stream) { + const int ids_stride, cudaStream_t stream) { switch (type_x) { case GGML_TYPE_Q4_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q4_1: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_1: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q8_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_MXFP4: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q2_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q3_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q4_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q6_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_XXS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_XS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_S: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ3_XXS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ1_S: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ1_M: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ4_NL: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ4_XS: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ3_S: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; default: GGML_ABORT("fatal error"); @@ -622,7 +655,7 @@ void ggml_cuda_mul_mat_vec_q( GGML_ASSERT( nb0 == ts_dst); GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); - GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. + GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE); const float * src1_d = (const float *) src1->data; const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; @@ -693,11 +726,13 @@ void ggml_cuda_mul_mat_vec_q( const int64_t stride_channel_dst = ids ? s1 : s2; const int64_t stride_channel_y = ids ? s11 : s12; + const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + mul_mat_vec_q_switch_type( src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, stream); + ne03, ne3, s03, s13, s3, ids_stride, stream); } void ggml_cuda_op_mul_mat_vec_q( @@ -726,7 +761,7 @@ void ggml_cuda_op_mul_mat_vec_q( ggml_cuda_mm_fusion_args_device fusion_local{}; mul_mat_vec_q_switch_type( src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream); + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream); GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size); } From aa34558b6ffa27408948ca11a8430ccbf5bad51d Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 3 Feb 2026 17:37:32 +0100 Subject: [PATCH 094/831] vulkan: disable coopmat1 fa on Nvidia Turing (llama/19290) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a99375c0885..cb7fa2c9cbb 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -254,6 +254,7 @@ enum vk_device_architecture { AMD_RDNA3, INTEL_XE2, NVIDIA_PRE_TURING, + NVIDIA_TURING, }; static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { @@ -336,18 +337,34 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& const std::vector ext_props = device.enumerateDeviceExtensionProperties(); bool cooperative_matrix = false; + bool sm_builtins = false; // Detect "pre-turing" based on lack of coopmat support. for (const auto& properties : ext_props) { if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) { cooperative_matrix = true; - break; + } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) { + sm_builtins = true; } } if (!cooperative_matrix) { return vk_device_architecture::NVIDIA_PRE_TURING; } + + if (sm_builtins) { + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; + + props2.pNext = &sm_props; + + device.getProperties2(&props2); + + // Turing has 32, following architectures have 48 + if (sm_props.shaderWarpsPerSM == 32) { + return vk_device_architecture::NVIDIA_TURING; + } + } } return vk_device_architecture::OTHER; } @@ -8460,6 +8477,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + if (path == FA_COOPMAT1 && ctx->device->architecture == vk_device_architecture::NVIDIA_TURING) { + // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 + path = FA_SCALAR; + } + if (path == FA_COOPMAT1) { const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); From 5dda94dd2e2583d53c1b935a42dfb7e4a5f66566 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Feb 2026 23:43:14 +0200 Subject: [PATCH 095/831] metal : add solve_tri (llama/19302) --- ggml/src/ggml-metal/ggml-metal-device.cpp | 30 +++++++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal-impl.h | 30 ++++++++- ggml/src/ggml-metal/ggml-metal-ops.cpp | 61 ++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 77 +++++++++++++++++++++++ 7 files changed, 200 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 377b0d3eb8f..4cd3d93d813 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -534,6 +534,36 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_ return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + const int nsg = 8; + const int n = op->src[1]->ne[1]; + const int k = op->src[1]->ne[0]; + + snprintf(base, 256, "kernel_solve_tri_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_nsg=%d_n=%d_k=%d", base, nsg, n, k); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0); + ggml_metal_cv_set_int16(cv, n, FC_SOLVE_TRI + 1); + ggml_metal_cv_set_int16(cv, k, FC_SOLVE_TRI + 2); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + res.nsg = nsg; + res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16); + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { char base[256]; char name[256]; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index afb091e7255..d8984327124 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -121,6 +121,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 285dd1630e7..8a0b85c6e4d 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1152,6 +1152,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return has_simdgroup_reduction; case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: + case GGML_OP_SOLVE_TRI: return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index e074f2ef3db..640ade8f880 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -78,7 +78,8 @@ #define FC_MUL_MM 700 #define FC_ROPE 800 #define FC_SSM_CONV 900 -#define FC_COUNT_EQUAL 1000 +#define FC_SOLVE_TRI 1000 +#define FC_COUNT_EQUAL 1100 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -733,6 +734,33 @@ typedef struct { uint64_t nb0; } ggml_metal_kargs_ssm_scan; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_solve_tri; + typedef struct { int32_t ne00t; int32_t ne00; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index f97c4435dec..753fcec3175 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -341,6 +341,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_rwkv(ctx, idx); } break; + case GGML_OP_SOLVE_TRI: + { + n_fuse = ggml_metal_op_solve_tri(ctx, idx); + } break; case GGML_OP_MUL_MAT: { n_fuse = ggml_metal_op_mul_mat(ctx, idx); @@ -1557,6 +1561,63 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_solve_tri args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + const int nsg = pipeline.nsg; + + ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1); + + return 1; +} + int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 10686a334e0..2e4c7d3fa11 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -60,6 +60,7 @@ int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 3259213fd61..c09a54e6614 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2737,6 +2737,83 @@ kernel void kernel_rwkv_wkv7_f32( } } +constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]]; +constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]]; +constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]]; + +kernel void kernel_solve_tri_f32( + constant ggml_metal_kargs_solve_tri & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + ushort3 tgpig[[threadgroup_position_in_grid]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + constexpr short NW = N_SIMDWIDTH; + + const short NSG = FC_solve_tri_nsg; + const short N = FC_solve_tri_n; + const short K = FC_solve_tri_k; + const short NP = PAD2(N, NW); + + const int32_t ne02 = args.ne02; + const int32_t ne03 = args.ne03; + + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + const int32_t i01 = tgpig.x*NSG + sgitg; + + threadgroup float * sh0 = (threadgroup float *) shmem; + + device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N; + device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01; + device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01; + + for (short rr = 0; rr < N; rr += NSG) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + threadgroup float * sh0_cur = sh0 + sgitg*NP; + + for (short t = 0; t*NW < N; ++t) { + const short idx = t*NW + tiisg; + sh0_cur[idx] = src0_ptr[idx]; + } + + src0_ptr += NSG*N; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (i01 >= args.ne10) { + continue; + } + + for (short ir = 0; ir < NSG && rr + ir < N; ++ir) { + const short r = rr + ir; + + threadgroup float * sh0_cur = sh0 + ir*NP; + + float sum = 0.0f; + + for (short t = 0; t*NW < r; ++t) { + const short idx = t*NW + tiisg; + sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r); + } + + sum = simd_sum(sum); + + if (tiisg == 0) { + const float diag = sh0_cur[r]; + + dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag; + } + } + } +} + kernel void kernel_argmax_f32( constant ggml_metal_kargs_argmax & args, device const char * src0, From 4685ec95557ecd53225c5a1abac2392454a4e20d Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 4 Feb 2026 09:43:29 +0800 Subject: [PATCH 096/831] ggml-cpu: use LUT for converting e8->f32 scales on x86 (llama/19288) * ggml-cpu: use LUT for converting e8->f32 scales on x86 * add dispatch based on macro --- ggml/src/ggml-cpu/arch/x86/quants.c | 18 +++++++++--------- ggml/src/ggml-cpu/ggml-cpu.c | 8 ++++++++ ggml/src/ggml-cpu/simd-mappings.h | 11 +++++++++++ 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index cb49320a67f..74d699f633d 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -268,9 +268,9 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0))); } -static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) { - return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)), - _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0))); +static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const uint8_t x1, const float y1) { + return _mm256_set_m128(_mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)), + _mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0))); } #endif #elif defined(__SSSE3__) @@ -782,6 +782,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo __m256 accum1 = _mm256_setzero_ps(); __m256 accum2 = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs); const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs); @@ -795,10 +796,10 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); - accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)), - _mm256_cvtepi32_ps(p_1), accum1); - accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)), - _mm256_cvtepi32_ps(p_2), accum2); + const __m256 scale0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 0].e)); + const __m256 scale1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 1].e)); + accum1 = _mm256_fmadd_ps(scale0, _mm256_cvtepi32_ps(p_1), accum1); + accum2 = _mm256_fmadd_ps(scale1, _mm256_cvtepi32_ps(p_2), accum2); } sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); @@ -830,7 +831,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif for (; ib < nb; ++ib) { - const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e); + const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib].e); int sumi1 = 0; int sumi2 = 0; for (int j = 0; j < QK_MXFP4/2; ++j) { @@ -3817,4 +3818,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 3e5f01e3fb6..b003fe13fd9 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -75,6 +75,9 @@ // precomputed f32 table for f16 (256 KB) (simd-mappings.h) float ggml_table_f32_f16[1 << 16]; +// precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h) +float ggml_table_f32_e8m0_half[1 << 8]; + #if defined(__ARM_ARCH) struct ggml_arm_arch_features_type { int sve_cnt; @@ -3681,6 +3684,11 @@ void ggml_cpu_init(void) { ggml_table_gelu_quick_f16[i] = GGML_CPU_FP32_TO_FP16(ggml_gelu_quick_f32(f)); } + // initialize E8M0 half table (256 entries) + for (int i = 0; i < (1 << 8); ++i) { + ggml_table_f32_e8m0_half[i] = GGML_E8M0_TO_FP32_HALF(i); + } + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index e367f110b46..630e506542b 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -116,6 +116,17 @@ extern "C" { // defined in ggml-cpu.c, initialized in ggml_cpu_init() extern float ggml_table_f32_f16[1 << 16]; +// precomputed f32 table for e8m0 half (1 KB) +// defined in ggml-cpu.c, initialized in ggml_cpu_init() +extern float ggml_table_f32_e8m0_half[1 << 8]; + +// Use lookup table for E8M0 on x86 (faster than bit manipulation) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +#define GGML_CPU_E8M0_TO_FP32_HALF(x) ggml_table_f32_e8m0_half[(uint8_t)(x)] +#else +#define GGML_CPU_E8M0_TO_FP32_HALF(x) GGML_E8M0_TO_FP32_HALF(x) +#endif + // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON. // This is also true for POWER9. From 2763054f99b3609ee438bfec70a1ecbe9a563829 Mon Sep 17 00:00:00 2001 From: Kevin Pouget Date: Wed, 4 Feb 2026 03:46:18 +0100 Subject: [PATCH 097/831] ggml-virtgpu: make the code thread safe (llama/19204) * ggml-virtgpu: regenerate_remoting.py: add the ability to deprecate a function * ggml-virtgpu: deprecate buffer_type is_host remoting not necessary * ggml-virtgpu: stop using static vars as cache The static init isn't thread safe. * ggml-virtgpu: protect the use of the shared memory to transfer data * ggml-virtgpu: make the remote calls thread-safe * ggml-virtgpu: backend: don't continue if couldn't allocate the tensor memory * ggml-virtgpu: add a cleanup function for consistency * ggml-virtgpu: backend: don't crash if buft->iface.get_max_size is missing * fix style and ordering * Remove the static variable in apir_device_get_count * ggml-virtgpu: improve the logging * fix review minor formatting changes --- ggml/include/ggml-virtgpu.h | 2 - .../ggml-virtgpu/apir_cs_ggml-rpc-front.cpp | 2 +- .../backend/backend-dispatched-backend.cpp | 4 +- .../backend-dispatched-buffer-type.cpp | 12 +- .../backend/backend-dispatched-buffer.cpp | 6 +- .../backend/backend-dispatched-device.cpp | 2 +- .../backend/backend-dispatched.cpp | 8 +- .../backend/backend-dispatched.gen.h | 5 +- .../ggml-virtgpu/backend/backend-dispatched.h | 2 + ggml/src/ggml-virtgpu/backend/backend.cpp | 32 ++-- .../src/ggml-virtgpu/backend/shared/apir_cs.h | 11 +- .../backend/shared/apir_cs_ggml.h | 14 +- .../ggml-virtgpu/ggml-backend-buffer-type.cpp | 29 +--- ggml/src/ggml-virtgpu/ggml-backend-device.cpp | 65 +++++--- ggml/src/ggml-virtgpu/ggml-backend-reg.cpp | 94 ++++++++--- ggml/src/ggml-virtgpu/ggml-remoting.h | 5 +- .../ggml-virtgpu/ggmlremoting_functions.yaml | 20 ++- ggml/src/ggml-virtgpu/regenerate_remoting.py | 18 ++- .../ggml-virtgpu/virtgpu-forward-backend.cpp | 14 +- .../virtgpu-forward-buffer-type.cpp | 41 ++--- .../ggml-virtgpu/virtgpu-forward-buffer.cpp | 28 +++- .../ggml-virtgpu/virtgpu-forward-device.cpp | 22 +-- ggml/src/ggml-virtgpu/virtgpu-forward-impl.h | 6 +- ggml/src/ggml-virtgpu/virtgpu-forward.gen.h | 21 +-- ggml/src/ggml-virtgpu/virtgpu-shm.cpp | 3 +- ggml/src/ggml-virtgpu/virtgpu.cpp | 149 ++++++++++++------ ggml/src/ggml-virtgpu/virtgpu.h | 23 +++ 27 files changed, 399 insertions(+), 239 deletions(-) diff --git a/ggml/include/ggml-virtgpu.h b/ggml/include/ggml-virtgpu.h index 1cb4bd7a038..faaba8f246d 100644 --- a/ggml/include/ggml-virtgpu.h +++ b/ggml/include/ggml-virtgpu.h @@ -7,8 +7,6 @@ extern "C" { #endif -#define GGML_REMOTING_FRONTEND_NAME "RemotingFrontend" - GGML_BACKEND_API ggml_backend_reg_t ggml_backend_virtgpu_reg(); #ifdef __cplusplus diff --git a/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp b/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp index f60ae3556ca..d2e87330a63 100644 --- a/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +++ b/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp @@ -36,7 +36,7 @@ apir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor) { result.data = reinterpret_cast(tensor->data); if (tensor->data) { if (!tensor->buffer) { - GGML_ABORT("tensor has data but not buffer"); + GGML_ABORT("%s: tensor has data but not buffer", __func__); } // tensor->data is serialized as an offset to the buffer base address result.data -= reinterpret_cast(BUFFER_TO_GGML_CONTEXT(tensor->buffer)->base); diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp index 77b4ee71e12..cc879e51d04 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp @@ -27,7 +27,7 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v const void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); if (!shmem_data) { - GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n"); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__); apir_decoder_set_fatal(dec); return 1; } @@ -45,7 +45,7 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v if (dev->iface.supports_op(dev, op)) { continue; } - GGML_LOG_ERROR("Graph node %d (%s) not supported by the backend\n", idx, ggml_op_desc(op)); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", idx, ggml_op_desc(op)); status = GGML_STATUS_ABORTED; apir_encode_ggml_status(enc, &status); diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp index 8ea1bb4fb49..d55eec27610 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp @@ -36,18 +36,22 @@ uint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec ggml_backend_buffer_type_t buft; buft = apir_decode_ggml_buffer_type(dec); - size_t value = buft->iface.get_max_size(buft); + size_t value = SIZE_MAX; + if (buft->iface.get_max_size) { + value = buft->iface.get_max_size(buft); + } + apir_encode_size_t(enc, &value); return 0; } +/* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST is deprecated. Keeping the handler for backward compatibility. */ uint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { GGML_UNUSED(ctx); - ggml_backend_buffer_type_t buft; - buft = apir_decode_ggml_buffer_type(dec); + GGML_UNUSED(dec); + const bool is_host = false; - bool is_host = buft->iface.is_host(buft); apir_encode_bool_t(enc, &is_host); return 0; diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp index cf81888e989..8cc063ff0a6 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp @@ -40,7 +40,7 @@ uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); if (!shmem_data) { - GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n"); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__); return 1; } @@ -71,7 +71,7 @@ uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); if (!shmem_data) { - GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n"); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__); return 1; } @@ -121,7 +121,7 @@ uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virg buffer = apir_decode_ggml_buffer(dec); if (!apir_untrack_backend_buffer(buffer)) { - GGML_LOG_WARN("%s: unknown buffer %p\n", __func__, (void *) buffer); + GGML_LOG_WARN(GGML_VIRTGPU_BCK "%s: unknown buffer %p\n", __func__, (void *) buffer); return 1; } diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp index 497f737a881..c7acb8b51ce 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp @@ -124,7 +124,7 @@ uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, void * shmem_ptr = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); if (!shmem_ptr) { - GGML_LOG_ERROR("Couldn't get the shmem addr from virgl\n"); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__); apir_decoder_set_fatal(dec); return 1; } diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp index 51d445725f0..64152eef0d8 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp @@ -17,26 +17,26 @@ uint64_t timer_count = 0; uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p) { if (reg != NULL) { - GGML_LOG_WARN("%s: already initialized\n", __func__); + GGML_LOG_WARN(GGML_VIRTGPU_BCK "%s: already initialized\n", __func__); return APIR_BACKEND_INITIALIZE_ALREADY_INITED; } ggml_backend_reg_t (*ggml_backend_reg_fct)(void) = (ggml_backend_reg_t (*)()) ggml_backend_reg_fct_p; reg = ggml_backend_reg_fct(); if (reg == NULL) { - GGML_LOG_ERROR("%s: backend registration failed\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend registration failed\n", __func__); return APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED; } if (!reg->iface.get_device_count(reg)) { - GGML_LOG_ERROR("%s: backend initialization failed: no device found\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed: no device found\n", __func__); return APIR_BACKEND_INITIALIZE_NO_DEVICE; } dev = reg->iface.get_device(reg, 0); if (!dev) { - GGML_LOG_ERROR("%s: backend initialization failed: no device received\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed: no device received\n", __func__); return APIR_BACKEND_INITIALIZE_NO_DEVICE; } diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h b/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h index b81fd5039bd..481d7f3150d 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h @@ -16,6 +16,7 @@ uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, uint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); uint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); uint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +/* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST is deprecated. Keeping the handler for backward compatibility. */ uint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); uint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); @@ -62,7 +63,7 @@ static inline const char * backend_dispatch_command_name(ApirBackendCommandType case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE: return "backend_buffer_type_get_max_size"; case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST: - return "backend_buffer_type_is_host"; + return "backend_buffer_type_is_host (DEPRECATED)"; case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER: return "backend_buffer_type_alloc_buffer"; case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE: @@ -110,7 +111,7 @@ static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATC /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME = */ backend_buffer_type_get_name, /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT = */ backend_buffer_type_get_alignment, /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE = */ backend_buffer_type_get_max_size, - /* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST = */ backend_buffer_type_is_host, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST = */ backend_buffer_type_is_host /* DEPRECATED */, /* APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER = */ backend_buffer_type_alloc_buffer, /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE = */ backend_buffer_type_get_alloc_size, diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.h b/ggml/src/ggml-virtgpu/backend/backend-dispatched.h index 6ccbecf078d..10311631d4f 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.h @@ -11,6 +11,8 @@ #include "shared/apir_cs.h" #include "shared/apir_cs_ggml.h" +#define GGML_VIRTGPU_BCK "ggml-virtgpu-backend: " + struct virgl_apir_context { uint32_t ctx_id; virgl_apir_callbacks * iface; diff --git a/ggml/src/ggml-virtgpu/backend/backend.cpp b/ggml/src/ggml-virtgpu/backend/backend.cpp index 95d602ed603..d93414a078b 100644 --- a/ggml/src/ggml-virtgpu/backend/backend.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend.cpp @@ -35,14 +35,8 @@ void apir_backend_deinit(uint32_t virgl_ctx_id) { buffer->iface.free_buffer(buffer); } - if (dev) { - size_t free, total; - dev->iface.get_memory(dev, &free, &total); - GGML_LOG_INFO("%s: free memory: %ld MB\n", __func__, (size_t) free / 1024 / 1024); - } - if (backend_library_handle) { - GGML_LOG_INFO("%s: The GGML backend library was loaded. Unloading it.\n", __func__); + GGML_LOG_INFO(GGML_VIRTGPU_BCK "The GGML backend library was loaded. Unloading it.\n"); dlclose(backend_library_handle); backend_library_handle = NULL; } @@ -65,7 +59,7 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct if (apir_logfile) { ggml_log_set(log_to_file_callback, apir_logfile); } else { - GGML_LOG_INFO("Could not open the log file at '%s'\n", apir_log_to_file); + GGML_LOG_INFO(GGML_VIRTGPU_BCK "Could not open the log file at '%s'\n", apir_log_to_file); } } @@ -74,7 +68,10 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG; if (!library_name) { - GGML_LOG_ERROR("cannot open the GGML library: env var '%s' not defined\n", APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK + "%s: cannot open the GGML library: env var '%s' not defined\n", + __func__, APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV); + return APIR_LOAD_LIBRARY_ENV_VAR_MISSING; } @@ -82,13 +79,16 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct backend_library_handle = dlopen(library_name, RTLD_LAZY); if (!backend_library_handle) { - GGML_LOG_ERROR("cannot open the GGML library: %s\n", dlerror()); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK + "%s: cannot open the GGML library: %s\n", __func__, dlerror()); return APIR_LOAD_LIBRARY_CANNOT_OPEN; } if (!library_reg) { - GGML_LOG_ERROR("cannot register the GGML library: env var '%s' not defined\n", APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK + "%s: cannot register the GGML library: env var '%s' not defined\n", + __func__, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV); return APIR_LOAD_LIBRARY_ENV_VAR_MISSING; } @@ -96,8 +96,10 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct void * ggml_backend_reg_fct = dlsym(backend_library_handle, library_reg); dlsym_error = dlerror(); if (dlsym_error) { - GGML_LOG_ERROR("cannot find the GGML backend registration symbol '%s' (from %s): %s\n", library_reg, - APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK + "%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n", + __func__, library_reg, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error); + return APIR_LOAD_LIBRARY_SYMBOL_MISSING; } @@ -134,7 +136,9 @@ uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id, }; if (cmd_type >= APIR_BACKEND_DISPATCH_TABLE_COUNT) { - GGML_LOG_ERROR("Received an invalid dispatch index (%d >= %d)\n", cmd_type, APIR_BACKEND_DISPATCH_TABLE_COUNT); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK + "%s: Received an invalid dispatch index (%d >= %d)\n", + __func__, cmd_type, APIR_BACKEND_DISPATCH_TABLE_COUNT); return APIR_BACKEND_FORWARD_INDEX_INVALID; } diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h index 27a61091ffd..1bc3a5f685b 100644 --- a/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h @@ -86,7 +86,7 @@ static inline bool apir_decoder_peek_internal(apir_decoder * dec, assert(val_size <= size); if (unlikely(size > (size_t) (dec->end - dec->cur))) { - GGML_LOG_ERROR("reading too much from the decoder ...\n"); + GGML_LOG_ERROR("%s: reading too much from the decoder ...\n", __func__); apir_decoder_set_fatal(dec); memset(val, 0, val_size); return false; @@ -103,7 +103,7 @@ static inline void apir_decoder_peek(apir_decoder * dec, size_t size, void * val static inline const void * apir_decoder_use_inplace(apir_decoder * dec, size_t size) { if (unlikely(size > (size_t) (dec->end - dec->cur))) { - GGML_LOG_ERROR("reading too much from the decoder ...\n"); + GGML_LOG_ERROR("%s: reading too much from the decoder ...\n", __func__); apir_decoder_set_fatal(dec); return NULL; } @@ -221,7 +221,7 @@ static inline uint64_t apir_decode_array_size(apir_decoder * dec, uint64_t expec uint64_t size; apir_decode_uint64_t(dec, &size); if (size != expected_size) { - GGML_LOG_ERROR("Couldn't decode array from the decoder\n"); + GGML_LOG_ERROR("%s: Couldn't decode array from the decoder\n", __func__); apir_decoder_set_fatal(dec); size = 0; } @@ -322,7 +322,7 @@ static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t if (size) { val[size - 1] = '\0'; } else { - GGML_LOG_ERROR("Couldn't decode the blog array\n"); + GGML_LOG_ERROR("%s: Couldn't decode the blog array\n", __func__); apir_decoder_set_fatal(dec); } } @@ -332,7 +332,8 @@ static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t static inline void * apir_decoder_alloc_array(size_t size, size_t count) { size_t alloc_size; if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) { - GGML_LOG_ERROR("overflow in array allocation of %zu * %zu bytes\n", size, count); + GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n", + __func__, size, count); return NULL; } diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h index 070c3b25fb1..289f4b77d74 100644 --- a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h @@ -39,11 +39,17 @@ static inline void apir_encode_ggml_tensor(apir_encoder * enc, const ggml_tensor static inline const ggml_tensor * apir_decode_ggml_tensor(apir_decoder * dec) { const apir_rpc_tensor * apir_rpc_tensor = apir_decode_apir_rpc_tensor_inplace(dec); + + if (!apir_rpc_tensor) { + return NULL; + } + ggml_init_params params{ /*.mem_size =*/ ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; + ggml_context * ctx = ggml_init(params); const ggml_tensor * tensor = apir_deserialize_tensor(ctx, apir_rpc_tensor); @@ -71,6 +77,10 @@ static inline ggml_backend_buffer_type_t apir_decode_ggml_buffer_type(apir_decod return (ggml_backend_buffer_type_t) handle; } +static inline void apir_encode_apir_buffer_type_host_handle(apir_encoder * enc, apir_buffer_type_host_handle_t handle) { + apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle)); +} + static inline apir_buffer_type_host_handle_t apir_decode_apir_buffer_type_host_handle(apir_decoder * dec) { apir_buffer_type_host_handle_t handle; @@ -154,13 +164,13 @@ static inline void apir_encode_ggml_tensor_inline(apir_encoder * enc, const ggml size_t tensor_size = sizeof(*tensor); if (tensor->extra) { - GGML_ABORT("Cannot pass tensors with extra"); + GGML_ABORT("%s: Cannot pass tensors with extra", __func__); } if (tensor->src[0] && tensor->buffer) { static int first = 1; if (first) { - GGML_LOG_WARN("Cannot pass tensors with src and buffer\n"); + GGML_LOG_WARN("%s: Cannot pass tensors with src and buffer\n", __func__); first = 0; } } diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp index 7f650659b8a..c493a8e2ae3 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp @@ -6,7 +6,7 @@ static ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context)); if (!context) { - GGML_ABORT("Couldn't allocate the buffer context ..."); + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the buffer context ...", __func__); } context->gpu = gpu; @@ -20,7 +20,7 @@ static ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml context->base = context->apir_context.shmem.mmap_ptr; context->is_from_ptr = true; } else { - context->apir_context = apir_buffer_type_alloc_buffer(gpu, buft, size); + context->apir_context = apir_buffer_type_alloc_buffer(gpu, gpu->cached_buffer_type.host_handle, size); context->is_from_ptr = false; context->base = NULL; } @@ -34,36 +34,19 @@ static ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml static const char * ggml_backend_remoting_buffer_type_get_name(ggml_backend_buffer_type_t buft) { virtgpu * gpu = BUFT_TO_GPU(buft); - return apir_buffer_type_get_name(gpu, buft); + return gpu->cached_buffer_type.name; } static size_t ggml_backend_remoting_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { virtgpu * gpu = BUFT_TO_GPU(buft); - static size_t align = 0; - - if (align == 0) { - align = apir_buffer_type_get_alignment(gpu, buft); - } - - return align; + return gpu->cached_buffer_type.alignment; } static size_t ggml_backend_remoting_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { virtgpu * gpu = BUFT_TO_GPU(buft); - static size_t max_size = 0; - if (max_size == 0) { - max_size = apir_buffer_type_get_max_size(gpu, buft); - } - - return max_size; -} - -static bool ggml_backend_remoting_buffer_type_is_host(ggml_backend_buffer_type_t buft) { - virtgpu * gpu = BUFT_TO_GPU(buft); - - return apir_buffer_type_is_host(gpu, buft); + return gpu->cached_buffer_type.max_size; } static size_t ggml_backend_remoting_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, @@ -76,7 +59,7 @@ static size_t ggml_backend_remoting_buffer_type_get_alloc_size(ggml_backend_buff return ggml_nbytes(tensor); } - return apir_buffer_type_get_alloc_size(gpu, buft, tensor); + return apir_buffer_type_get_alloc_size(gpu, gpu->cached_buffer_type.host_handle, tensor); } const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_type_interface = { diff --git a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp index 579eb990781..c7d2881058b 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp @@ -3,32 +3,27 @@ static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) { virtgpu * gpu = DEV_TO_GPU(dev); - return apir_device_get_name(gpu); + return gpu->cached_device_info.name; } static const char * ggml_backend_remoting_device_get_description(ggml_backend_dev_t dev) { virtgpu * gpu = DEV_TO_GPU(dev); - return apir_device_get_description(gpu); + // Return the pre-cached description from the virtgpu structure + return gpu->cached_device_info.description; } static enum ggml_backend_dev_type ggml_backend_remoting_device_get_type(ggml_backend_dev_t dev) { virtgpu * gpu = DEV_TO_GPU(dev); - static enum ggml_backend_dev_type type; - static bool has_type = false; - if (!has_type) { - has_type = true; - type = (enum ggml_backend_dev_type) apir_device_get_type(gpu); - } - - return type; + return (enum ggml_backend_dev_type) gpu->cached_device_info.type; } static void ggml_backend_remoting_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { virtgpu * gpu = DEV_TO_GPU(dev); - return apir_device_get_memory(gpu, free, total); + *free = gpu->cached_device_info.memory_free; + *total = gpu->cached_device_info.memory_total; } static bool ggml_backend_remoting_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { @@ -77,13 +72,22 @@ static void ggml_backend_remoting_device_get_props(ggml_backend_dev_t dev, ggml_ ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev) { virtgpu * gpu = DEV_TO_GPU(dev); - apir_buffer_type_host_handle_t ctx = apir_device_get_buffer_type(gpu); - - static ggml_backend_buffer_type buft{ - /* .iface = */ ggml_backend_remoting_buffer_type_interface, - /* .device = */ dev, - /* .context = */ (void *) ctx, - }; + static std::atomic initialized = false; + static ggml_backend_buffer_type buft; + + if (!initialized) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + if (!initialized) { + buft = { + /* .iface = */ ggml_backend_remoting_buffer_type_interface, + /* .device = */ dev, + /* .context = */ (void *) gpu->cached_buffer_type.host_handle, + }; + initialized = true; + } + } return &buft; } @@ -91,13 +95,22 @@ ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_bac static ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_from_ptr_type(ggml_backend_dev_t dev) { virtgpu * gpu = DEV_TO_GPU(dev); - apir_buffer_type_host_handle_t ctx = apir_device_get_buffer_type(gpu); - - static ggml_backend_buffer_type buft{ - /* .iface = */ ggml_backend_remoting_buffer_from_ptr_type_interface, - /* .device = */ dev, - /* .context = */ (void *) ctx, - }; + static std::atomic initialized = false; + static ggml_backend_buffer_type buft; + + if (!initialized) { + static std::mutex mutex; + std::lock_guard lock(mutex); + + if (!initialized) { + buft = { + /* .iface = */ ggml_backend_remoting_buffer_from_ptr_type_interface, + /* .device = */ dev, + /* .context = */ (void *) gpu->cached_buffer_type.host_handle, + }; + initialized = true; + } + } return &buft; } @@ -110,7 +123,7 @@ static ggml_backend_buffer_t ggml_backend_remoting_device_buffer_from_ptr(ggml_b ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context)); if (!context) { - GGML_ABORT("Couldn't allocate the buffer context ..."); + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the buffer context ...", __func__); } context->gpu = gpu; diff --git a/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp b/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp index c46cf51c022..2d02cfec1d3 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp @@ -4,37 +4,70 @@ #include #include +void ggml_virtgpu_cleanup(virtgpu * gpu); + static virtgpu * apir_initialize() { - static virtgpu * apir_gpu_instance = NULL; - static bool apir_initialized = false; + static virtgpu * gpu = NULL; + static std::atomic initialized = false; + + if (initialized) { + // fast track + return gpu; + } { static std::mutex mutex; std::lock_guard lock(mutex); - if (apir_initialized) { - return apir_gpu_instance; + if (initialized) { + // thread safe + return gpu; + } + + gpu = create_virtgpu(); + if (!gpu) { + initialized = true; + return NULL; } - apir_gpu_instance = create_virtgpu(); - if (!apir_gpu_instance) { - GGML_ABORT("failed to initialize the virtgpu"); + // Pre-fetch and cache all device information, it will not change + gpu->cached_device_info.description = apir_device_get_description(gpu); + if (!gpu->cached_device_info.description) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device description", __func__); + } + gpu->cached_device_info.name = apir_device_get_name(gpu); + if (!gpu->cached_device_info.name) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device name", __func__); } + gpu->cached_device_info.device_count = apir_device_get_count(gpu); + gpu->cached_device_info.type = apir_device_get_type(gpu); + + apir_device_get_memory(gpu, + &gpu->cached_device_info.memory_free, + &gpu->cached_device_info.memory_total); + + apir_buffer_type_host_handle_t buft_host_handle = apir_device_get_buffer_type(gpu); + gpu->cached_buffer_type.host_handle = buft_host_handle; + gpu->cached_buffer_type.name = apir_buffer_type_get_name(gpu, buft_host_handle); + if (!gpu->cached_buffer_type.name) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu buffer type name", __func__); + } + gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle); + gpu->cached_buffer_type.max_size = apir_buffer_type_get_max_size(gpu, buft_host_handle); - apir_initialized = true; + initialized = true; } - return apir_gpu_instance; + return gpu; } static int ggml_backend_remoting_get_device_count() { virtgpu * gpu = apir_initialize(); if (!gpu) { - GGML_LOG_WARN("apir_initialize failed\n"); return 0; } - return apir_device_get_count(gpu); + return gpu->cached_device_info.device_count; } static size_t ggml_backend_remoting_reg_get_device_count(ggml_backend_reg_t reg) { @@ -52,17 +85,21 @@ ggml_backend_dev_t ggml_backend_remoting_get_device(size_t device) { static void ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) { if (devices.size() > 0) { - GGML_LOG_INFO("%s: already initialized\n", __func__); + GGML_LOG_INFO(GGML_VIRTGPU "%s: already initialized\n", __func__); return; } virtgpu * gpu = apir_initialize(); if (!gpu) { - GGML_LOG_ERROR("apir_initialize failed\n"); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: apir_initialize failed\n", __func__); return; } - static bool initialized = false; + static std::atomic initialized = false; + + if (initialized) { + return; // fast track + } { static std::mutex mutex; @@ -70,10 +107,10 @@ static void ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) { if (!initialized) { for (int i = 0; i < ggml_backend_remoting_get_device_count(); i++) { ggml_backend_remoting_device_context * ctx = new ggml_backend_remoting_device_context; - char desc[256] = "API Remoting device"; + char desc[256] = "ggml-virtgpu API Remoting device"; ctx->device = i; - ctx->name = GGML_REMOTING_FRONTEND_NAME + std::to_string(i); + ctx->name = GGML_VIRTGPU_NAME + std::to_string(i); ctx->description = desc; ctx->gpu = gpu; @@ -98,7 +135,7 @@ static ggml_backend_dev_t ggml_backend_remoting_reg_get_device(ggml_backend_reg_ static const char * ggml_backend_remoting_reg_get_name(ggml_backend_reg_t reg) { UNUSED(reg); - return GGML_REMOTING_FRONTEND_NAME; + return GGML_VIRTGPU_NAME; } static const ggml_backend_reg_i ggml_backend_remoting_reg_i = { @@ -111,8 +148,7 @@ static const ggml_backend_reg_i ggml_backend_remoting_reg_i = { ggml_backend_reg_t ggml_backend_virtgpu_reg() { virtgpu * gpu = apir_initialize(); if (!gpu) { - GGML_LOG_ERROR("virtgpu_apir_initialize failed\n"); - return NULL; + GGML_LOG_ERROR(GGML_VIRTGPU "%s: virtgpu_apir_initialize failed\n", __func__); } static ggml_backend_reg reg = { @@ -129,9 +165,25 @@ ggml_backend_reg_t ggml_backend_virtgpu_reg() { ggml_backend_remoting_reg_init_devices(®); - GGML_LOG_INFO("%s: initialized\n", __func__); - return ® } +// public function, not exposed in the GGML interface at the moment +void ggml_virtgpu_cleanup(virtgpu * gpu) { + if (gpu->cached_device_info.name) { + free(gpu->cached_device_info.name); + gpu->cached_device_info.name = NULL; + } + if (gpu->cached_device_info.description) { + free(gpu->cached_device_info.description); + gpu->cached_device_info.description = NULL; + } + if (gpu->cached_buffer_type.name) { + free(gpu->cached_buffer_type.name); + gpu->cached_buffer_type.name = NULL; + } + + mtx_destroy(&gpu->data_shmem_mutex); +} + GGML_BACKEND_DL_IMPL(ggml_backend_virtgpu_reg) diff --git a/ggml/src/ggml-virtgpu/ggml-remoting.h b/ggml/src/ggml-virtgpu/ggml-remoting.h index 36fc6b2a7bd..08766408676 100644 --- a/ggml/src/ggml-virtgpu/ggml-remoting.h +++ b/ggml/src/ggml-virtgpu/ggml-remoting.h @@ -8,6 +8,9 @@ #include #include +#define GGML_VIRTGPU_NAME "ggml-virtgpu" +#define GGML_VIRTGPU "ggml-virtgpu: " + // USE_ALWAYS_TRUE_SUPPORTS_OP: 1 is fast, 0 avoid micro-benchmark crashes #define USE_ALWAYS_TRUE_SUPPORTS_OP 1 @@ -62,7 +65,7 @@ static inline apir_buffer_type_host_handle_t ggml_buffer_type_to_apir_handle(ggm static inline apir_buffer_host_handle_t ggml_buffer_to_apir_handle(ggml_backend_buffer_t buffer) { if (!buffer->context) { - GGML_ABORT("%s: no context available :/", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: no context available :/", __func__); } return BUFFER_TO_HOST_HANDLE(buffer); } diff --git a/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml b/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml index 0b7cccfe9cf..14ef2433e46 100644 --- a/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +++ b/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml @@ -24,10 +24,10 @@ functions: frontend_return: "int" get_name: - frontend_return: "const char *" + frontend_return: "char *" get_description: - frontend_return: "const char *" + frontend_return: "char *" get_type: frontend_return: "uint32_t" @@ -64,35 +64,33 @@ functions: group_description: "buffer-type" functions: get_name: - frontend_return: "const char *" + frontend_return: "char *" frontend_extra_params: - - "ggml_backend_buffer_type_t buft" + - "apir_buffer_type_host_handle_t host_handle" get_alignment: frontend_return: "size_t" frontend_extra_params: - - "ggml_backend_buffer_type_t buft" + - "apir_buffer_type_host_handle_t host_handle" get_max_size: frontend_return: "size_t" frontend_extra_params: - - "ggml_backend_buffer_type_t buft" + - "apir_buffer_type_host_handle_t host_handle" is_host: - frontend_return: "bool" - frontend_extra_params: - - "ggml_backend_buffer_type_t buft" + deprecated: true alloc_buffer: frontend_return: "apir_buffer_context_t" frontend_extra_params: - - "ggml_backend_buffer_type_t buffer_buft" + - "apir_buffer_type_host_handle_t host_handle" - "size_t size" get_alloc_size: frontend_return: "size_t" frontend_extra_params: - - "ggml_backend_buffer_type_t buft" + - "apir_buffer_type_host_handle_t host_handle" - "const ggml_tensor *op" buffer: diff --git a/ggml/src/ggml-virtgpu/regenerate_remoting.py b/ggml/src/ggml-virtgpu/regenerate_remoting.py index 4174a24327f..aeb48a4087e 100755 --- a/ggml/src/ggml-virtgpu/regenerate_remoting.py +++ b/ggml/src/ggml-virtgpu/regenerate_remoting.py @@ -116,7 +116,7 @@ def get_enabled_functions(self) -> List[Dict[str, Any]]: 'frontend_return': func_metadata.get('frontend_return', 'void'), 'frontend_extra_params': func_metadata.get('frontend_extra_params', []), 'group_description': group_description, - 'newly_added': func_metadata.get('newly_added', False) + 'deprecated': func_metadata.get('deprecated', False), }) enum_value += 1 @@ -165,6 +165,9 @@ def generate_backend_dispatched_header(self) -> str: signature = "uint32_t" params = "apir_encoder *enc, apir_decoder *dec, virgl_apir_context *ctx" + if func['deprecated']: + decl_lines.append(f"/* {func['enum_name']} is deprecated. Keeping the handler for backward compatibility. */") + decl_lines.append(f"{signature} {func['backend_function']}({params});") # Switch cases @@ -176,7 +179,9 @@ def generate_backend_dispatched_header(self) -> str: switch_lines.append(f" /* {func['group_description']} */") current_group = func['group_name'] - switch_lines.append(f" case {func['enum_name']}: return \"{func['backend_function']}\";") + deprecated = " (DEPRECATED)" if func['deprecated'] else "" + + switch_lines.append(f" case {func['enum_name']}: return \"{func['backend_function']}{deprecated}\";") # Dispatch table table_lines = [] @@ -188,7 +193,8 @@ def generate_backend_dispatched_header(self) -> str: table_lines.append("") current_group = func['group_name'] - table_lines.append(f" /* {func['enum_name']} = */ {func['backend_function']},") + deprecated = " /* DEPRECATED */" if func['deprecated'] else "" + table_lines.append(f" /* {func['enum_name']} = */ {func['backend_function']}{deprecated},") header_content = f'''\ #pragma once @@ -225,6 +231,10 @@ def generate_virtgpu_forward_header(self) -> str: decl_lines.append(f"/* {func['group_description']} */") current_group = func['group_name'] + if func['deprecated']: + decl_lines.append(f"/* {func['frontend_function']} is deprecated. */") + continue + # Build parameter list params = [self.naming_patterns['frontend_base_param']] params.extend(func['frontend_extra_params']) @@ -287,7 +297,7 @@ def regenerate_codebase(self) -> None: generated_files = [apir_backend_path, backend_dispatched_path, virtgpu_forward_path] if not self.clang_format_available: - logging.warning("\n⚠️clang-format not found in PATH. Generated files will not be formatted." + logging.warning("\n⚠️clang-format not found in PATH. Generated files will not be formatted.\n" " Install clang-format to enable automatic code formatting.") else: logging.info("\n🎨 Formatting files with clang-format...") diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp index bf3c41011ac..07d9a668496 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp @@ -18,12 +18,17 @@ ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) { virtgpu_shmem temp_shmem; // Local storage for large buffers virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; if (cgraph_size <= gpu->data_shmem.mmap_size) { - // prefer the init-time allocated page, if large enough + // Lock mutex before using shared data_shmem buffer + if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) { + GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); + } + using_shared_shmem = true; shmem = &gpu->data_shmem; } else if (virtgpu_shmem_create(gpu, cgraph_size, shmem)) { - GGML_ABORT("Couldn't allocate the guest-host shared buffer"); + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); } apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id); @@ -42,7 +47,10 @@ ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) { remote_call_finish(gpu, encoder, decoder); - if (shmem != &gpu->data_shmem) { + // Unlock mutex before cleanup + if (using_shared_shmem) { + mtx_unlock(&gpu->data_shmem_mutex); + } else { virtgpu_shmem_destroy(gpu, shmem); } diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp index 03cb09e0643..cab74fd1707 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp @@ -1,20 +1,20 @@ #include "virtgpu-forward-impl.h" -const char * apir_buffer_type_get_name(virtgpu * gpu, ggml_backend_buffer_type_t buft) { +char * apir_buffer_type_get_name(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) { apir_encoder * encoder; apir_decoder * decoder; ApirForwardReturnCode ret; REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME); - apir_encode_ggml_buffer_type(encoder, buft); + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); REMOTE_CALL(gpu, encoder, decoder, ret); const size_t string_size = apir_decode_array_size_unchecked(decoder); char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); if (!string) { - GGML_LOG_ERROR("%s: Could not allocate the device name buffer\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device name buffer\n", __func__); apir_decoder_set_fatal(decoder); } apir_decode_char_array(decoder, string, string_size); @@ -24,14 +24,14 @@ const char * apir_buffer_type_get_name(virtgpu * gpu, ggml_backend_buffer_type_t return string; } -size_t apir_buffer_type_get_alignment(virtgpu * gpu, ggml_backend_buffer_type_t buft) { +size_t apir_buffer_type_get_alignment(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) { apir_encoder * encoder; apir_decoder * decoder; ApirForwardReturnCode ret; REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT); - apir_encode_ggml_buffer_type(encoder, buft); + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); REMOTE_CALL(gpu, encoder, decoder, ret); @@ -43,14 +43,14 @@ size_t apir_buffer_type_get_alignment(virtgpu * gpu, ggml_backend_buffer_type_t return alignment; } -size_t apir_buffer_type_get_max_size(virtgpu * gpu, ggml_backend_buffer_type_t buft) { +size_t apir_buffer_type_get_max_size(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) { apir_encoder * encoder; apir_decoder * decoder; ApirForwardReturnCode ret; REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE); - apir_encode_ggml_buffer_type(encoder, buft); + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); REMOTE_CALL(gpu, encoder, decoder, ret); @@ -62,26 +62,7 @@ size_t apir_buffer_type_get_max_size(virtgpu * gpu, ggml_backend_buffer_type_t b return max_size; } -bool apir_buffer_type_is_host(virtgpu * gpu, ggml_backend_buffer_type_t buft) { - apir_encoder * encoder; - apir_decoder * decoder; - ApirForwardReturnCode ret; - - REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST); - - apir_encode_ggml_buffer_type(encoder, buft); - - REMOTE_CALL(gpu, encoder, decoder, ret); - - bool is_host; - apir_decode_bool_t(decoder, &is_host); - - remote_call_finish(gpu, encoder, decoder); - - return is_host; -} - -apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, ggml_backend_buffer_type_t buft, size_t size) { +apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, size_t size) { apir_encoder * encoder; apir_decoder * decoder; ApirForwardReturnCode ret; @@ -90,7 +71,7 @@ apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, ggml_backend_ REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER); - apir_encode_ggml_buffer_type(encoder, buft); + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); apir_encode_size_t(encoder, &size); @@ -103,14 +84,14 @@ apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, ggml_backend_ return buffer_context; } -size_t apir_buffer_type_get_alloc_size(virtgpu * gpu, ggml_backend_buffer_type_t buft, const ggml_tensor * op) { +size_t apir_buffer_type_get_alloc_size(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, const ggml_tensor * op) { apir_encoder * encoder; apir_decoder * decoder; ApirForwardReturnCode ret; REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE); - apir_encode_ggml_buffer_type(encoder, buft); + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); apir_encode_ggml_tensor_inline(encoder, op); diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp index 3181e394407..86eee358cf4 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp @@ -36,13 +36,18 @@ void apir_buffer_set_tensor(virtgpu * gpu, virtgpu_shmem temp_shmem; // Local storage for large buffers virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; if (size <= gpu->data_shmem.mmap_size) { - // prefer the init-time allocated page, if large enough + // Lock mutex before using shared data_shmem buffer + if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) { + GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); + } + using_shared_shmem = true; shmem = &gpu->data_shmem; } else if (virtgpu_shmem_create(gpu, size, shmem)) { - GGML_ABORT("Couldn't allocate the guest-host shared buffer"); + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); } memcpy(shmem->mmap_ptr, data, size); @@ -55,7 +60,10 @@ void apir_buffer_set_tensor(virtgpu * gpu, remote_call_finish(gpu, encoder, decoder); - if (shmem != &gpu->data_shmem) { + // Unlock mutex before cleanup + if (using_shared_shmem) { + mtx_unlock(&gpu->data_shmem_mutex); + } else { virtgpu_shmem_destroy(gpu, shmem); } @@ -79,13 +87,18 @@ void apir_buffer_get_tensor(virtgpu * gpu, virtgpu_shmem temp_shmem; // Local storage for large buffers virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; if (size <= gpu->data_shmem.mmap_size) { - // prefer the init-time allocated page, if large enough + // Lock mutex before using shared data_shmem buffer + if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) { + GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); + } + using_shared_shmem = true; shmem = &gpu->data_shmem; } else if (virtgpu_shmem_create(gpu, size, shmem)) { - GGML_ABORT("Couldn't allocate the guest-host shared buffer"); + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); } apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id); @@ -98,7 +111,10 @@ void apir_buffer_get_tensor(virtgpu * gpu, remote_call_finish(gpu, encoder, decoder); - if (shmem != &gpu->data_shmem) { + // Unlock mutex before cleanup + if (using_shared_shmem) { + mtx_unlock(&gpu->data_shmem_mutex); + } else { virtgpu_shmem_destroy(gpu, shmem); } } diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp index 3e45e55bdcb..4b6b8f527be 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp @@ -2,11 +2,6 @@ #include "virtgpu-shm.h" int apir_device_get_count(virtgpu * gpu) { - static int32_t dev_count = -1; - if (dev_count != -1) { - return dev_count; - } - apir_encoder * encoder; apir_decoder * decoder; ApirForwardReturnCode ret; @@ -14,6 +9,7 @@ int apir_device_get_count(virtgpu * gpu) { REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_COUNT); REMOTE_CALL(gpu, encoder, decoder, ret); + int32_t dev_count = -1; apir_decode_int32_t(decoder, &dev_count); remote_call_finish(gpu, encoder, decoder); @@ -21,11 +17,7 @@ int apir_device_get_count(virtgpu * gpu) { return dev_count; } -const char * apir_device_get_name(virtgpu * gpu) { - static char * string = nullptr; - if (string) { - return string; - } +char * apir_device_get_name(virtgpu * gpu) { apir_encoder * encoder; apir_decoder * decoder; ApirForwardReturnCode ret; @@ -34,9 +26,9 @@ const char * apir_device_get_name(virtgpu * gpu) { REMOTE_CALL(gpu, encoder, decoder, ret); const size_t string_size = apir_decode_array_size_unchecked(decoder); - string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); + char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); if (!string) { - GGML_LOG_ERROR("%s: Could not allocate the device name buffer\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device name buffer\n", __func__); return NULL; } apir_decode_char_array(decoder, string, string_size); @@ -46,7 +38,7 @@ const char * apir_device_get_name(virtgpu * gpu) { return string; } -const char * apir_device_get_description(virtgpu * gpu) { +char * apir_device_get_description(virtgpu * gpu) { apir_encoder * encoder; apir_decoder * decoder; ApirForwardReturnCode ret; @@ -58,7 +50,7 @@ const char * apir_device_get_description(virtgpu * gpu) { const size_t string_size = apir_decode_array_size_unchecked(decoder); char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); if (!string) { - GGML_LOG_ERROR("%s: Could not allocate the device description buffer\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device description buffer\n", __func__); return NULL; } @@ -181,7 +173,7 @@ apir_buffer_context_t apir_device_buffer_from_ptr(virtgpu * gpu, size_t size, si REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR); if (virtgpu_shmem_create(gpu, size, &buffer_context.shmem)) { - GGML_ABORT("Couldn't allocate the guest-host shared buffer"); + GGML_ABORT(GGML_VIRTGPU "Couldn't allocate the guest-host shared buffer"); } apir_encode_virtgpu_shmem_res_id(encoder, buffer_context.shmem.res_id); diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h b/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h index eea3e7e5a9b..f23c75bb968 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h @@ -11,7 +11,7 @@ int32_t forward_flag = (int32_t) apir_command_type__; \ encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, forward_flag); \ if (!encoder_name) { \ - GGML_ABORT("%s: failed to prepare the remote call encoder", __func__); \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); \ } \ } while (0) @@ -19,10 +19,10 @@ do { \ ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \ if (!decoder_name) { \ - GGML_ABORT("%s: failed to kick the remote call", __func__); \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to kick the remote call", __func__); \ } \ if (ret_name < APIR_FORWARD_BASE_INDEX) { \ - GGML_ABORT("%s: failed to forward the API call: %s: code %d", __func__, \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to forward the API call: %s: code %d", __func__, \ apir_forward_error(ret_name), ret_name); \ } \ ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX); \ diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h b/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h index c27c07f0865..fe4cae20253 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +++ b/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h @@ -3,8 +3,8 @@ /* device */ void apir_device_get_device_count(struct virtgpu * gpu); int apir_device_get_count(struct virtgpu * gpu); -const char * apir_device_get_name(struct virtgpu * gpu); -const char * apir_device_get_description(struct virtgpu * gpu); +char * apir_device_get_name(struct virtgpu * gpu); +char * apir_device_get_description(struct virtgpu * gpu); uint32_t apir_device_get_type(struct virtgpu * gpu); void apir_device_get_memory(struct virtgpu * gpu, size_t * free, size_t * total); bool apir_device_supports_op(struct virtgpu * gpu, const ggml_tensor * op); @@ -17,14 +17,15 @@ void apir_device_get_props(struct virtgpu * gpu, apir_buffer_context_t apir_device_buffer_from_ptr(struct virtgpu * gpu, size_t size, size_t max_tensor_size); /* buffer-type */ -const char * apir_buffer_type_get_name(struct virtgpu * gpu, ggml_backend_buffer_type_t buft); -size_t apir_buffer_type_get_alignment(struct virtgpu * gpu, ggml_backend_buffer_type_t buft); -size_t apir_buffer_type_get_max_size(struct virtgpu * gpu, ggml_backend_buffer_type_t buft); -bool apir_buffer_type_is_host(struct virtgpu * gpu, ggml_backend_buffer_type_t buft); -apir_buffer_context_t apir_buffer_type_alloc_buffer(struct virtgpu * gpu, - ggml_backend_buffer_type_t buffer_buft, - size_t size); -size_t apir_buffer_type_get_alloc_size(struct virtgpu * gpu, ggml_backend_buffer_type_t buft, const ggml_tensor * op); +char * apir_buffer_type_get_name(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); +size_t apir_buffer_type_get_alignment(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); +size_t apir_buffer_type_get_max_size(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); +apir_buffer_context_t apir_buffer_type_alloc_buffer(struct virtgpu * gpu, + apir_buffer_type_host_handle_t host_handle, + size_t size); +size_t apir_buffer_type_get_alloc_size(struct virtgpu * gpu, + apir_buffer_type_host_handle_t host_handle, + const ggml_tensor * op); /* buffer */ void * apir_buffer_get_base(struct virtgpu * gpu, apir_buffer_context_t * buffer_context); diff --git a/ggml/src/ggml-virtgpu/virtgpu-shm.cpp b/ggml/src/ggml-virtgpu/virtgpu-shm.cpp index 4def405a62b..ce6b3b3e607 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-shm.cpp @@ -85,8 +85,7 @@ int virtgpu_shmem_create(virtgpu * gpu, size_t size, virtgpu_shmem * shmem) { void * ptr = virtgpu_ioctl_map(gpu, gem_handle, size); if (!ptr) { virtgpu_ioctl_gem_close(gpu, gem_handle); - GGML_LOG_ERROR("virtgpu_ioctl_map FAILED\n"); - exit(1); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: virtgpu_ioctl_map failed\n", __func__); return 1; } diff --git a/ggml/src/ggml-virtgpu/virtgpu.cpp b/ggml/src/ggml-virtgpu/virtgpu.cpp index 005c8e21db8..1e650dc65b2 100644 --- a/ggml/src/ggml-virtgpu/virtgpu.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu.cpp @@ -33,7 +33,7 @@ static int virtgpu_handshake(virtgpu * gpu) { encoder = remote_call_prepare(gpu, APIR_COMMAND_TYPE_HANDSHAKE, 0); if (!encoder) { - GGML_ABORT("%s: failed to prepare the remote call encoder", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); return 1; } @@ -52,7 +52,7 @@ static int virtgpu_handshake(virtgpu * gpu) { log_call_duration(call_duration_ns, "API Remoting handshake"); if (!decoder) { - GGML_ABORT( + GGML_ABORT(GGML_VIRTGPU "%s: failed to initiate the communication with the virglrenderer library. " "Most likely, the wrong virglrenderer library was loaded in the hypervisor.", __func__); @@ -65,7 +65,8 @@ static int virtgpu_handshake(virtgpu * gpu) { uint32_t host_minor; if (ret_magic != APIR_HANDSHAKE_MAGIC) { - GGML_ABORT("%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic, + GGML_ABORT(GGML_VIRTGPU + "%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic, apir_backend_initialize_error(ret_magic)); } else { apir_decode_uint32_t(decoder, &host_major); @@ -78,13 +79,13 @@ static int virtgpu_handshake(virtgpu * gpu) { return 1; } - GGML_LOG_INFO("%s: Guest is running with %u.%u\n", __func__, guest_major, guest_minor); - GGML_LOG_INFO("%s: Host is running with %u.%u\n", __func__, host_major, host_minor); + GGML_LOG_INFO(GGML_VIRTGPU "%s: Guest is running with %u.%u\n", __func__, guest_major, guest_minor); + GGML_LOG_INFO(GGML_VIRTGPU "%s: Host is running with %u.%u\n", __func__, host_major, host_minor); if (guest_major != host_major) { - GGML_LOG_ERROR("Host major (%d) and guest major (%d) version differ\n", host_major, guest_major); + GGML_LOG_ERROR(GGML_VIRTGPU "Host major (%d) and guest major (%d) version differ\n", host_major, guest_major); } else if (guest_minor != host_minor) { - GGML_LOG_WARN("Host minor (%d) and guest minor (%d) version differ\n", host_minor, guest_minor); + GGML_LOG_WARN(GGML_VIRTGPU "Host minor (%d) and guest minor (%d) version differ\n", host_minor, guest_minor); } return 0; @@ -97,7 +98,7 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) { encoder = remote_call_prepare(gpu, APIR_COMMAND_TYPE_LOADLIBRARY, 0); if (!encoder) { - GGML_ABORT("%s: hypercall error: failed to prepare the remote call encoder", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: hypercall error: failed to prepare the API Remoting command encoder", __func__); return APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR; } @@ -108,36 +109,67 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) { log_call_duration(call_duration_ns, "API Remoting LoadLibrary"); if (!decoder) { - GGML_ABORT("%s: hypercall error: failed to kick the API remoting hypercall.\n", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: hypercall error: failed to trigger the API Remoting hypercall.\n", __func__); return APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR; } remote_call_finish(gpu, encoder, decoder); if (ret == APIR_LOAD_LIBRARY_SUCCESS) { - GGML_LOG_INFO("%s: The API Remoting backend was successfully loaded and initialized\n", __func__); + GGML_LOG_INFO(GGML_VIRTGPU "The API Remoting backend was successfully loaded and initialized\n"); return ret; } // something wrong happened, find out what. - if (ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) { - GGML_ABORT("%s: virglrenderer could not load the API Remoting backend library: %s (code %d)", __func__, - apir_load_library_error(ret), ret); + if (ret == APIR_LOAD_LIBRARY_ENV_VAR_MISSING) { + GGML_ABORT(GGML_VIRTGPU + "%s: virglrenderer could not open the API Remoting backend library, " + "some environment variables are missing. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", + __func__, apir_load_library_error(ret)); + } else if (ret == APIR_LOAD_LIBRARY_CANNOT_OPEN) { + GGML_ABORT(GGML_VIRTGPU + "%s: virglrenderer could not open the API Remoting backend library. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", + __func__, apir_load_library_error(ret)); + } else if (ret == APIR_LOAD_LIBRARY_ENV_VAR_MISSING) { + GGML_ABORT(GGML_VIRTGPU + "%s: could not load the backend library, some symbols are missing. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s) ", + __func__, apir_load_library_error(ret)); + } else { + GGML_ABORT(GGML_VIRTGPU + "%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)", __func__, + apir_load_library_error(ret), ret); + } return ret; } - GGML_LOG_INFO("%s: virglrenderer successfully loaded the API Remoting backend library", __func__); + GGML_LOG_INFO(GGML_VIRTGPU + "%s: virglrenderer successfully loaded the API Remoting backend library.\n", __func__); ApirLoadLibraryReturnCode apir_ret = (ApirLoadLibraryReturnCode) (ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX); - if (apir_ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) { - GGML_ABORT("%s: the API Remoting backend library couldn't load the backend library: apir code=%d | %s)", + if (apir_ret == APIR_LOAD_LIBRARY_CANNOT_OPEN) { + GGML_ABORT(GGML_VIRTGPU + "%s: the API Remoting backend library couldn't load the GGML backend library. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", + __func__, apir_load_library_error(apir_ret)); + } else if (apir_ret == APIR_LOAD_LIBRARY_SYMBOL_MISSING) { + GGML_ABORT(GGML_VIRTGPU + "%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", + __func__, apir_load_library_error(apir_ret)); + } else if (apir_ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) { + GGML_ABORT(GGML_VIRTGPU + "%s: the API Remoting backend library couldn't load the GGML backend library: apir code=%d | %s)", __func__, apir_ret, apir_load_library_error(apir_ret)); } else { uint32_t lib_ret = apir_ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX; - GGML_ABORT("%s: the API Remoting backend library initialize its backend library: apir code=%d)", __func__, + GGML_ABORT(GGML_VIRTGPU + "%s: the API Remoting backend library initialize its backend library: apir code=%d)", __func__, lib_ret); } return ret; @@ -149,38 +181,58 @@ virtgpu * create_virtgpu() { gpu->use_apir_capset = getenv("GGML_REMOTING_USE_APIR_CAPSET") != nullptr; util_sparse_array_init(&gpu->shmem_array, sizeof(virtgpu_shmem), 1024); + // Initialize mutex to protect shared data_shmem buffer + if (mtx_init(&gpu->data_shmem_mutex, mtx_plain) != thrd_success) { + delete gpu; + GGML_ABORT(GGML_VIRTGPU + "%s: failed to initialize data_shmem mutex", __func__); + return NULL; + } + if (virtgpu_open(gpu) != APIR_SUCCESS) { - GGML_ABORT("%s: failed to open the virtgpu device", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU + "%s: failed to open the virtgpu device\n", __func__); return NULL; } if (virtgpu_init_capset(gpu) != APIR_SUCCESS) { - GGML_ABORT("%s: failed to initialize the GPU capset", __func__); + if (gpu->use_apir_capset) { + GGML_ABORT(GGML_VIRTGPU + "%s: failed to initialize the virtgpu APIR capset. Make sure that the virglrenderer library supports it.", __func__); + } else { + GGML_ABORT(GGML_VIRTGPU + "%s: failed to initialize the virtgpu Venus capset", __func__); + } return NULL; } if (virtgpu_init_context(gpu) != APIR_SUCCESS) { - GGML_ABORT("%s: failed to initialize the GPU context", __func__); + GGML_ABORT(GGML_VIRTGPU + "%s: failed to initialize the GPU context", __func__); return NULL; } if (virtgpu_shmem_create(gpu, SHMEM_REPLY_SIZE, &gpu->reply_shmem)) { - GGML_ABORT("%s: failed to create the shared reply memory pages", __func__); + GGML_ABORT(GGML_VIRTGPU + "%s: failed to create the shared reply memory pages", __func__); return NULL; } if (virtgpu_shmem_create(gpu, SHMEM_DATA_SIZE, &gpu->data_shmem)) { - GGML_ABORT("%s: failed to create the shared data memory pages", __func__); + GGML_ABORT(GGML_VIRTGPU + "%s: failed to create the shared data memory pages", __func__); return NULL; } if (virtgpu_handshake(gpu)) { - GGML_ABORT("%s: failed to handshake with the virglrenderer library", __func__); + GGML_ABORT(GGML_VIRTGPU + "%s: failed to handshake with the virglrenderer library", __func__); return NULL; } if (virtgpu_load_library(gpu) != APIR_LOAD_LIBRARY_SUCCESS) { - GGML_ABORT("%s: failed to load the backend library", __func__); + GGML_ABORT(GGML_VIRTGPU + "%s: failed to load the backend library", __func__); return NULL; } @@ -191,7 +243,8 @@ static virt_gpu_result_t virtgpu_open(virtgpu * gpu) { drmDevicePtr devs[8]; int count = drmGetDevices2(0, devs, ARRAY_SIZE(devs)); if (count < 0) { - GGML_LOG_ERROR("%s: failed to enumerate DRM devices\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU + "%s: failed to enumerate DRM devices\n", __func__); return APIR_ERROR_INITIALIZATION_FAILED; } @@ -213,16 +266,19 @@ static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr d int fd = open(node_path, O_RDWR | O_CLOEXEC); if (fd < 0) { - GGML_ABORT("failed to open %s", node_path); + GGML_ABORT(GGML_VIRTGPU + "%s: failed to open %s", __func__, node_path); return APIR_ERROR_INITIALIZATION_FAILED; } drmVersionPtr version = drmGetVersion(fd); if (!version || strcmp(version->name, "virtio_gpu") || version->version_major != 0) { if (version) { - GGML_ABORT("unknown DRM driver %s version %d", version->name, version->version_major); + GGML_LOG_ERROR(GGML_VIRTGPU + "%s: unknown DRM driver %s version %d\n", __func__, version->name, version->version_major); } else { - GGML_ABORT("failed to get DRM driver version"); + GGML_LOG_ERROR(GGML_VIRTGPU + "%s: failed to get DRM driver version\n", __func__); } if (version) { @@ -236,7 +292,7 @@ static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr d drmFreeVersion(version); - GGML_LOG_INFO("using DRM device %s\n", node_path); + GGML_LOG_INFO(GGML_VIRTGPU "using DRM device %s\n", node_path); return APIR_SUCCESS; } @@ -245,7 +301,7 @@ static virt_gpu_result_t virtgpu_init_context(virtgpu * gpu) { assert(!gpu->capset.version); const int ret = virtgpu_ioctl_context_init(gpu, gpu->capset.id); if (ret) { - GGML_LOG_INFO("failed to initialize context: %s\n", strerror(errno)); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to initialize context: %s\n", __func__, strerror(errno)); return APIR_ERROR_INITIALIZATION_FAILED; } @@ -254,10 +310,10 @@ static virt_gpu_result_t virtgpu_init_context(virtgpu * gpu) { static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu) { if (gpu->use_apir_capset) { - GGML_LOG_INFO("Using the APIR capset\n"); + GGML_LOG_INFO(GGML_VIRTGPU "Using the APIR capset\n"); gpu->capset.id = VIRTGPU_DRM_CAPSET_APIR; } else { - GGML_LOG_INFO("Using the Venus capset\n"); + GGML_LOG_INFO(GGML_VIRTGPU "Using the Venus capset\n"); gpu->capset.id = VIRTGPU_DRM_CAPSET_VENUS; } gpu->capset.version = 0; @@ -266,7 +322,9 @@ static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu) { virtgpu_ioctl_get_caps(gpu, gpu->capset.id, gpu->capset.version, &gpu->capset.data, sizeof(gpu->capset.data)); if (ret) { - GGML_LOG_INFO("failed to get APIR v%d capset: %s\n", gpu->capset.version, strerror(errno)); + GGML_LOG_ERROR(GGML_VIRTGPU + "%s: failed to get APIR v%d capset: %s\n", + __func__, gpu->capset.version, strerror(errno)); return APIR_ERROR_INITIALIZATION_FAILED; } @@ -333,9 +391,9 @@ apir_encoder * remote_call_prepare(virtgpu * gpu, ApirCommandType apir_cmd_type, * Prepare the command encoder and its buffer */ - static char encoder_buffer[4096]; + thread_local char encoder_buffer[4096]; - static apir_encoder enc; + thread_local apir_encoder enc; enc = { .cur = encoder_buffer, .start = encoder_buffer, @@ -369,19 +427,19 @@ void remote_call_finish(virtgpu * gpu, apir_encoder * enc, apir_decoder * dec) { UNUSED(gpu); if (!enc) { - GGML_LOG_ERROR("Invalid (null) encoder\n"); + GGML_ABORT(GGML_VIRTGPU "%s: Invalid (null) encoder", __func__); } if (!dec) { - GGML_LOG_ERROR("Invalid (null) decoder\n"); + GGML_ABORT(GGML_VIRTGPU "%s: Invalid (null) decoder", __func__); } if (apir_encoder_get_fatal(enc)) { - GGML_LOG_ERROR("Failed to encode the output parameters.\n"); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Failed to encode the output parameters.", __func__); } if (apir_decoder_get_fatal(dec)) { - GGML_LOG_ERROR("Failed to decode the input parameters.\n"); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Failed to decode the input parameters.", __func__); } } @@ -423,7 +481,7 @@ uint32_t remote_call(virtgpu * gpu, int ret = drmIoctl(gpu->fd, DRM_IOCTL_VIRTGPU_EXECBUFFER, &args); if (ret != 0) { - GGML_ABORT("%s: the virtgpu EXECBUFFER ioctl failed (%d)", __func__, ret); + GGML_ABORT(GGML_VIRTGPU "%s: the virtgpu EXECBUFFER ioctl failed (%d)", __func__, ret); } /* @@ -467,7 +525,7 @@ uint32_t remote_call(virtgpu * gpu, } if (max_wait_ms && timedout) { - GGML_LOG_ERROR("timed out waiting for the host answer...\n"); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: timed out waiting for the host answer...\n", __func__); return APIR_FORWARD_TIMEOUT; } @@ -489,10 +547,13 @@ static void log_call_duration(long long call_duration_ns, const char * name) { double call_duration_s = (double) call_duration_ns / 1e9; // 1 second = 1e9 nanoseconds if (call_duration_s > 1) { - GGML_LOG_INFO("%s: waited %.2fs for the %s host reply...\n", __func__, call_duration_s, name); + GGML_LOG_INFO(GGML_VIRTGPU + "waited %.2fs for the %s host reply...\n", call_duration_s, name); } else if (call_duration_ms > 1) { - GGML_LOG_INFO("%s: waited %.2fms for the %s host reply...\n", __func__, call_duration_ms, name); + GGML_LOG_INFO(GGML_VIRTGPU + "waited %.2fms for the %s host reply...\n", call_duration_ms, name); } else { - GGML_LOG_INFO("%s: waited %lldns for the %s host reply...\n", __func__, call_duration_ns, name); + GGML_LOG_INFO(GGML_VIRTGPU + "waited %lldns for the %s host reply...\n", call_duration_ns, name); } } diff --git a/ggml/src/ggml-virtgpu/virtgpu.h b/ggml/src/ggml-virtgpu/virtgpu.h index d4bb42e20b2..68e0f3a376e 100644 --- a/ggml/src/ggml-virtgpu/virtgpu.h +++ b/ggml/src/ggml-virtgpu/virtgpu.h @@ -17,6 +17,8 @@ #include +#include "ggml-remoting.h" + #define VIRGL_RENDERER_UNSTABLE_APIS 1 #include "apir_hw.h" #include @@ -73,6 +75,27 @@ struct virtgpu { /* APIR communication pages */ virtgpu_shmem reply_shmem; virtgpu_shmem data_shmem; + + /* Mutex to protect shared data_shmem buffer from concurrent access */ + mtx_t data_shmem_mutex; + + /* Cached device information to prevent memory leaks and race conditions */ + struct { + char * description; + char * name; + int32_t device_count; + uint32_t type; + size_t memory_free; + size_t memory_total; + } cached_device_info; + + /* Cached buffer type information to prevent memory leaks and race conditions */ + struct { + apir_buffer_type_host_handle_t host_handle; + char * name; + size_t alignment; + size_t max_size; + } cached_buffer_type; }; static inline int virtgpu_ioctl(virtgpu * gpu, unsigned long request, void * args) { From eecc9bfa690b124345d138db7fc76735ff60d78f Mon Sep 17 00:00:00 2001 From: will-lms Date: Thu, 5 Feb 2026 01:05:09 -0500 Subject: [PATCH 098/831] metal : add missing includes (llama/19348) --- ggml/src/ggml-metal/ggml-metal.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index a616dcdb461..1c705362fb7 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -7,6 +7,9 @@ #include "ggml-metal-context.h" #include "ggml-metal-ops.h" +#include +#include + #define GGML_METAL_NAME "MTL" #define GGML_METAL_MAX_DEVICES 16 From e0a3f393ad7aa34d1043ca9e9c93c43659284962 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 5 Feb 2026 01:38:59 -0600 Subject: [PATCH 099/831] vulkan: fix non-contig rope (llama/19299) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 32 ++++-- .../ggml-vulkan/vulkan-shaders/rms_norm.comp | 5 +- .../vulkan-shaders/rope_funcs.glsl | 99 +++++++------------ .../vulkan-shaders/rope_multi.comp | 11 ++- .../ggml-vulkan/vulkan-shaders/rope_neox.comp | 11 ++- .../ggml-vulkan/vulkan-shaders/rope_norm.comp | 11 ++- .../vulkan-shaders/rope_params.glsl | 15 ++- .../vulkan-shaders/rope_vision.comp | 11 ++- 8 files changed, 100 insertions(+), 95 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index cb7fa2c9cbb..af57685a37d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1263,25 +1263,30 @@ struct vk_op_diag_mask_push_constants { struct vk_op_rope_push_constants { uint32_t rope_mode; - uint32_t ncols; uint32_t nrows; uint32_t n_dims; float freq_scale; - uint32_t p_delta_rows; float freq_base; float ext_factor; float attn_factor; float corr_dims[2]; float theta_scale; uint32_t has_ff; - uint32_t ne02; - uint32_t s1; - uint32_t s2; int32_t sections[4]; uint32_t is_imrope; uint32_t is_back; uint32_t set_rows_stride; + uint32_t ne00; + uint32_t ne01; + uint32_t ne02; + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; }; +static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128"); // For fused rms_norm+mul+rope(+view+set_rows) struct vk_op_rms_norm_mul_rope_push_constants { @@ -10405,12 +10410,22 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor * uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type); uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type); + uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type); + + uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type); + uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type); + uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type); vk_op_rope_push_constants rope { - (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], - freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, - has_ff, (uint32_t)src0->ne[2], nb01, nb02, + (uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, + freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff, { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride, + + (uint32_t)src0->ne[0], + (uint32_t)src0->ne[1], + (uint32_t)src0->ne[2], + nb01, nb02, nb03, + nb11, nb12, nb13, }; return rope; @@ -14798,6 +14813,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_REPEAT_BACK: return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ROPE: + return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_ROPE_BACK: case GGML_OP_NONE: case GGML_OP_RESHAPE: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index 9d6d3665427..55b89f19a7a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -112,12 +112,11 @@ void rms_norm(uint num_iters) { #if RMS_NORM_ROPE_FUSION barrier(); rope_params rp = p.rope; - uint rope_row = (samp*nchannels + channel)*nrows + row; for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) { if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) { - rope_neox(t, rope_row, rp); + rope_neox(t, row, channel, samp, rp); } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) { - rope_norm(t, rope_row, rp); + rope_norm(t, row, channel, samp, rp); } } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl index aacec984696..2e53459909d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl @@ -4,12 +4,12 @@ float rope_yarn_ramp(const float low, const float high, const uint i0) { return 1.0f - min(1.0f, max(0.0f, y)); } -uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) { +uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) { #if RMS_NORM_ROPE_FUSION // Per-row offset in shared memory const uint ix = i0; #else - const uint ix = i02*p.nb02 + i01*p.nb01 + i0; + const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0; #endif return ix; } @@ -34,26 +34,19 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out sin_theta = sin(theta) * mscale; } -void rope_norm(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - - if (i0 >= ne0) { +void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - uint idst = i1*ne0 + i0; - const uint ix = rope_a_coord(i0, i01, i02, p); + uint idst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0, i1, i2, i3, p); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. if (p.set_rows_stride != 0) { - idst = i01*ne0 + i0; - idst += rope_data_i[i02].x * p.set_rows_stride; + idst = i1*p.nb11 + i0; + idst += rope_data_i[i2].x * p.set_rows_stride; } if (i0 >= p.n_dims) { @@ -63,7 +56,7 @@ void rope_norm(const uint i0, const uint i1, rope_params p) { return; } - const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f); + const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; @@ -77,25 +70,19 @@ void rope_norm(const uint i0, const uint i1, rope_params p) { rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); } -void rope_neox(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - - if (i0 >= ne0) { +void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - uint idst = i1*ne0 + i0/2; - const uint ix = rope_a_coord(i0/2, i01, i02, p); + uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. if (p.set_rows_stride != 0) { - idst = i01*ne0 + i0/2; - idst += rope_data_i[i02].x * p.set_rows_stride; + idst = i1*p.nb11 + i0/2; + idst += rope_data_i[i2].x * p.set_rows_stride; } if (i0 >= p.n_dims) { @@ -105,7 +92,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) { return; } - const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f); + const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; @@ -120,26 +107,19 @@ void rope_neox(const uint i0, const uint i1, rope_params p) { } -void rope_multi(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - uint ne2 = p.ne02; - - if (i0 >= ne0) { +void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - uint idst = i1*ne0 + i0/2; - const uint ix = rope_a_coord(i0/2, i01, i02, p); + uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. if (p.set_rows_stride != 0) { - idst = i01*ne0 + i0/2; - idst += rope_data_i[i02].x * p.set_rows_stride; + idst = i1*p.nb11 + i0/2; + idst += rope_data_i[i2].x * p.set_rows_stride; } if (i0 >= p.n_dims) { @@ -156,26 +136,26 @@ void rope_multi(const uint i0, const uint i1, rope_params p) { float theta_base = 0.0; if (p.is_imrope != 0) { if (sector % 3 == 1 && sector < 3 * p.sections[1]) { - theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f); } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) { - theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f); } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) { - theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f); } else { - theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f); } } else { if (sector < p.sections[0]) { - theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f); } else if (sector >= p.sections[0] && sector < sec_w) { - theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f); } else if (sector >= sec_w && sector < sec_w + p.sections[2]) { - theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f); } else if (sector >= sec_w + p.sections[2]) { - theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f); } } @@ -191,20 +171,13 @@ void rope_multi(const uint i0, const uint i1, rope_params p) { rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); } -void rope_vision(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - uint ne2 = p.ne02; - - if (i0 >= ne0) { +void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - const uint idst = i1*ne0 + i0/2; - const uint ix = rope_a_coord(i0/2, i01, i02, p); + const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); const int sect_dims = p.sections[0] + p.sections[1]; const int sec_w = p.sections[1] + p.sections[0]; @@ -213,11 +186,11 @@ void rope_vision(const uint i0, const uint i1, rope_params p) { float theta_base = 0.0; if (sector < p.sections[0]) { const uint p0 = sector; - theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0); + theta_base = rope_data_pos[i2]*pow(p.theta_scale, p0); } else if (sector >= p.sections[0] && sector < sec_w) { const uint p0 = sector - p.sections[0]; - theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0); + theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, p0); } const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index f7587468a81..1528fbeeaec 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_multi(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_multi(i0, i1, i2, i3, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index acb8ed78155..ad0896095db 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_neox(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_neox(i0, i1, i2, i3, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index 0033cdb224f..11220817df0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_norm(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_norm(i0, i1, i2, i3, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl index 939cf3c51cd..ec6ceaca9bd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -5,24 +5,29 @@ struct rope_params { uint rope_mode; - uint ncols; uint nrows; uint n_dims; float freq_scale; - uint p_delta_rows; float freq_base; float ext_factor; float attn_factor; float corr_dims[2]; float theta_scale; uint has_ff; - uint ne02; - uint nb01; - uint nb02; int sections[4]; uint is_imrope; uint is_back; uint set_rows_stride; + + uint ne00; + uint ne01; + uint ne02; + uint nb01; + uint nb02; + uint nb03; + uint nb11; + uint nb12; + uint nb13; }; #endif // !defined(GGML_ROPE_PARAMS) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp index d93800b5e76..ca71efb2f55 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_vision(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_vision(i0, i1, i2, i3, pc); } From 5a786f76480fb1fd11060a834b49103320d70328 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 5 Feb 2026 01:48:33 -0600 Subject: [PATCH 100/831] vulkan: Set k_load_shmem to false when K is too large (llama/19301) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index af57685a37d..2f6570181ad 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3204,9 +3204,10 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t D_lsb = D ^ (D & (D-1)); uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); - // Nvidia prefers shared memory use to load large tiles of K + // Nvidia prefers shared memory use to load large tiles of K. + // Switch to loading from global memory when it would use too much shared memory. // AMD prefers loading K directly from global memory - const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA ? 1 : 0; + const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0; return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem}; }; @@ -8412,7 +8413,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br; const uint32_t sfsh = Bc * sfshstride * acctype; - const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA; + const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256; const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2; const uint32_t vsh_stride = MatBc / 4 * row_split; const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4; From 932def31988609785b9e350801cb4b278084ca03 Mon Sep 17 00:00:00 2001 From: Oleksandr Kuvshynov <661042+okuvshynov@users.noreply.github.com> Date: Thu, 5 Feb 2026 03:06:59 -0500 Subject: [PATCH 101/831] vulkan: fix GPU deduplication logic. (llama/19222) * vulkan: fix GPU deduplication logic. As reported in https://github.com/ggml-org/llama.cpp/issues/19221, the (same uuid, same driver) logic is problematic for windows+intel igpu. Let's just avoid filtering for MoltenVK which is apple-specific, and keep the logic the same as before 88d23ad5 - just dedup based on UUID. Verified that MacOS + 4xVega still reports 4 GPUs with this version. * vulkan: only skip dedup when both drivers are moltenVk --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2f6570181ad..ff9cb7355c2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5561,9 +5561,9 @@ static void ggml_vk_instance_init() { // Check if there are two physical devices corresponding to the same GPU // This handles the case where the same GPU appears with different drivers (e.g., RADV + AMDVLK on Linux), // see https://github.com/ggml-org/llama.cpp/pull/7582 for original deduplication. - // However, for MoltenVK on macOS, multiple GPUs on the same card may report the same UUID, - // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Until this is fixed, we'll only deduplicate - // when drivers differ (same driver + same UUID = likely different GPUs) + // MoltenVK on macOS may report the same UUID for distinct GPUs on multi-GPU cards, + // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Skip when both old/new + // driver is MoltenVK auto old_device = std::find_if( vk_instance.device_indices.begin(), vk_instance.device_indices.end(), @@ -5580,11 +5580,9 @@ static void ggml_vk_instance_init() { old_id.deviceLUIDValid && new_id.deviceLUIDValid && std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID)) ); + bool both_molten_vk = (new_driver.driverID == vk::DriverId::eMoltenvk && old_driver.driverID == vk::DriverId::eMoltenvk); - // Only deduplicate if same UUID AND different drivers - // (same driver + same UUID on MoltenVK = likely different GPUs on multi-GPU card) - bool different_driver = (old_driver.driverID != new_driver.driverID); - return same_uuid && different_driver; + return same_uuid && !both_molten_vk; } ); if (old_device == vk_instance.device_indices.end()) { From 0781df25183fc674d84db01a0da94289c83b3a03 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 5 Feb 2026 10:08:45 +0200 Subject: [PATCH 102/831] metal : add diag (llama/19330) --- ggml/src/ggml-metal/ggml-metal-device.cpp | 20 ++++++++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 4 +- ggml/src/ggml-metal/ggml-metal-impl.h | 19 ++++++++++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 46 +++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 20 ++++++++++ 7 files changed, 110 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 4cd3d93d813..6af0dd88d55 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -176,6 +176,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_me return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + const int n = op->src[0]->ne[0]; + + snprintf(base, 256, "kernel_diag_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_n=%d", base, n); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + res.nsg = 1; + res.smem = 0; + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) { char base[256]; char name[256]; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index d8984327124..84dcec30830 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -108,6 +108,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 8a0b85c6e4d..c8e737d4187 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1152,8 +1152,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return has_simdgroup_reduction; case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: - case GGML_OP_SOLVE_TRI: return true; + case GGML_OP_SOLVE_TRI: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: return has_simdgroup_reduction; @@ -1235,6 +1235,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return false; }; } + case GGML_OP_DIAG: + return true; case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return has_simdgroup_reduction; diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 640ade8f880..7f73cb97bbb 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -792,6 +792,25 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_set_rows; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_diag; + typedef struct { int64_t ne00; int64_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 753fcec3175..e0ed6c7805c 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -361,6 +361,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_set_rows(ctx, idx); } break; + case GGML_OP_DIAG: + { + n_fuse = ggml_metal_op_diag(ctx, idx); + } break; case GGML_OP_L2_NORM: { n_fuse = ggml_metal_op_l2_norm(ctx, idx); @@ -1259,6 +1263,48 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS(int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_diag args = { + /*.ne00 =*/ne00, + /*.ne01 =*/ne01, + /*.ne02 =*/ne02, + /*.ne03 =*/ne03, + /*.nb00 =*/nb00, + /*.nb01 =*/nb01, + /*.nb02 =*/nb02, + /*.nb03 =*/nb03, + /*.ne0 =*/ne0, + /*.ne1 =*/ne1, + /*.ne2 =*/ne2, + /*.ne3 =*/ne3, + /*.nb0 =*/nb0, + /*.nb1 =*/nb1, + /*.nb2 =*/nb2, + /*.nb3 =*/nb3, + }; + + auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1); + + return 1; +} + int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 2e4c7d3fa11..3c64e4f6007 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -56,6 +56,7 @@ int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx); int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_diag (ggml_metal_op_t ctx, int idx); int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c09a54e6614..e54cdab39dd 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -8815,6 +8815,26 @@ kernel void kernel_set_rows_f( } } +kernel void kernel_diag_f32( + constant ggml_metal_kargs_diag & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]]) { + constexpr short NW = N_SIMDWIDTH; + + const int32_t i3 = tgpig.z; + const int32_t i2 = tgpig.y; + const int32_t i1 = tgpig.x; + + device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03); + device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3); + + for (int i0 = tiitg; i0 < args.ne0; i0 += NW) { + dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f; + } +} + constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; From a567c140a3d8a043f9bf78d48a79066f30d8d7a7 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 5 Feb 2026 09:26:38 -0600 Subject: [PATCH 103/831] vulkan: Preprocess FA mask to detect all-neg-inf and all-zero. (llama/19281) Write out a 2-bit code per block and avoid loading the mask when it matches these two common cases. Apply this optimization when the mask is relatively large (i.e. prompt processing). --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 109 +++++++++++--- .../vulkan-shaders/flash_attn.comp | 39 ++--- .../vulkan-shaders/flash_attn_base.glsl | 6 + .../vulkan-shaders/flash_attn_cm1.comp | 110 ++++++++------ .../vulkan-shaders/flash_attn_cm2.comp | 69 +++++---- .../vulkan-shaders/flash_attn_mask_opt.comp | 142 ++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 + 7 files changed, 356 insertions(+), 121 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ff9cb7355c2..4357da24d42 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -402,18 +402,19 @@ enum FaCodePath { }; struct vk_fa_pipeline_state { - vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc) - : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {} + vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, bool use_mask_opt) + : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), use_mask_opt(use_mask_opt) {} uint32_t HSK, HSV; bool small_rows, small_cache; FaCodePath path; bool aligned; bool f32acc; + bool use_mask_opt; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) < - std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc); + return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt) < + std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.use_mask_opt); } }; @@ -820,6 +821,8 @@ struct vk_device_struct { std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT]; + std::map, vk_pipeline> pipeline_fa_mask_opt; + vk_pipeline pipeline_flash_attn_split_k_reduce; vk_pipeline pipeline_count_experts; @@ -1549,6 +1552,18 @@ struct vk_op_flash_attn_split_k_reduce_push_constants { uint32_t sinks; }; +struct vk_op_flash_attn_mask_opt_push_constants { + uint32_t nem0; + uint32_t nem1; + uint32_t nem2; + uint32_t nbm1; + uint32_t nbm2; + uint32_t nbm3; + uint32_t nbd1; + uint32_t nbd2; + uint32_t nbd3; +}; + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -1757,6 +1772,7 @@ class vk_perf_logger { " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " << " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " << " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")"; + *n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3]; return name.str(); } if (node->op == GGML_OP_TOP_K) { @@ -3177,7 +3193,7 @@ static void ggml_vk_load_shaders(vk_device& device) { return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, bool use_mask_opt) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. @@ -3209,7 +3225,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // AMD prefers loading K directly from global memory const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0; - return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem}; + return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, use_mask_opt}; }; #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ @@ -3221,18 +3237,19 @@ static void ggml_vk_load_shaders(vk_device& device) { FaCodePath path = fa.first.path; \ bool aligned = fa.first.aligned; \ bool f32acc = fa.first.f32acc; \ + bool use_mask_opt = fa.first.use_mask_opt; \ if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } \ } \ } \ @@ -4028,6 +4045,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); + for (auto &it : device->pipeline_fa_mask_opt) { + auto BrBc = it.first; + ggml_vk_create_pipeline(device, it.second, "fa_mask_opt", fa_mask_opt_len, fa_mask_opt_data, "main", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size); + } + if (device->subgroup_clustered && device->subgroup_require_full_support) { ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); } else { @@ -8400,8 +8422,6 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t acctype = f32acc ? 4 : 2; const uint32_t f16vec4 = 8; - const uint32_t tmpsh = (Bc / MatBc) * sizeof(float); - const uint32_t qstride = hsk_pad / 4 + 2; const uint32_t Qf = Br * qstride * f16vec4; @@ -8418,7 +8438,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t slope = Br * acctype; - const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope; + const uint32_t total_size = Qf + Psh + sfsh + ksh + slope; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported); @@ -8445,6 +8465,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const uint32_t nem0 = mask ? mask->ne[0] : 0; const uint32_t nem1 = mask ? mask->ne[1] : 0; const uint32_t nem2 = mask ? mask->ne[2] : 0; const uint32_t nem3 = mask ? mask->ne[3] : 0; @@ -8574,7 +8595,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; - vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc); + // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. + bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768; + + vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt); vk_pipeline pipeline = nullptr; @@ -8625,10 +8649,32 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_preallocate_buffers(ctx, subctx); } - { - // Request descriptor sets - if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache); + const uint32_t Br = rows_cols[0]; + const uint32_t Bc = rows_cols[1]; + + const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc); + const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3; + + vk_pipeline pipeline_fa_mask_opt = nullptr; + if (use_mask_opt) { + std::lock_guard guard(ctx->device->mutex); + auto &pipelines = ctx->device->pipeline_fa_mask_opt; + auto it = pipelines.find({Br, Bc}); + if (it != pipelines.end()) { + pipeline_fa_mask_opt = it->second; + } else { + pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared(); + } + assert(pipeline_fa_mask_opt); + ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1); + + if (ctx->prealloc_size_y < mask_opt_size) { + ctx->prealloc_size_y = mask_opt_size; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); } } @@ -8655,9 +8701,30 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf; vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf; + vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf; uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2; + if (use_mask_opt) + { + const vk_op_flash_attn_mask_opt_push_constants opt_pc = { + nem0, + nem1, + nem2, + (uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)), + (uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)), + (uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)), + mask_opt_num_dwords, + mask_opt_num_dwords * CEIL_DIV(nem1, Br), + mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2, + }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt, + { mask_buf, mask_opt_buf }, opt_pc, + { mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 }); + ggml_vk_sync_buffers(ctx, subctx); + } + const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, @@ -8672,13 +8739,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx gqa_ratio, split_kv, split_k }; if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + if (ctx->prealloc_split_k_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } workgroups_x *= pipeline->wg_denoms[0]; vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf}, + {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf}, // We only use split_k when group query attention is enabled, which means // there's no more than one tile of rows (i.e. workgroups_x would have been // one). We reuse workgroups_x to mean the number of splits, so we need to @@ -8697,7 +8766,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx workgroups_x *= pipeline->wg_denoms[0]; } ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf}, + {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf}, pc, { workgroups_x, workgroups_y, workgroups_z }); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 3ce8d07be80..49a3c530cb6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -94,6 +94,10 @@ void main() { } } + const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc); + // mo_offset will point to the tile starting at row i*Br and col 0 + uint32_t mo_offset = mo_stride * i; + #if BLOCK_SIZE > 1 uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; @@ -104,15 +108,28 @@ void main() { uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride; } + uint32_t mask_opt = 0; + uint32_t mask_opt_idx = ~0; + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (USE_MASK_OPT && mask_opt_idx != j / 16) { + mask_opt_idx = j / 16; + mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; + } + uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { + // skip this block + continue; + } + // Only load if the block is not all zeros + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) { bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - float max_mask = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; @@ -120,25 +137,12 @@ void main() { if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); masksh[c][r] = m; - max_mask = max(max_mask, m); } else { masksh[c][r] = float(0); } } } - // skip the block if the mask is entirely -inf - bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); - barrier(); - if (gl_SubgroupInvocationID == 0) { - tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; - } barrier(); - [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { - max_mask = max(max_mask, tmpsh[s]); - } - if (max_mask <= NEG_FLT_MAX_OVER_2) { - continue; - } } float Sf[Br][cols_per_thread]; @@ -185,7 +189,7 @@ void main() { } } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { float mvf = masksh[c * cols_per_iter + col_tid][r]; @@ -256,9 +260,6 @@ void main() { barrier(); } - // prevent race on tmpsh - barrier(); - // reduce across threads [[unroll]] for (uint32_t r = 0; r < Br; ++r) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 23a4d2c0058..252451101ab 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -10,6 +10,7 @@ layout (constant_id = 5) const uint32_t Clamp = 0; layout (constant_id = 6) const uint32_t D_split = 16; layout (constant_id = 7) const uint32_t SubGroupSize = 32; layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0; +layout (constant_id = 9) const bool USE_MASK_OPT = false; // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths const uint32_t HSK_pad = (HSK + 15) & ~15; @@ -66,6 +67,11 @@ layout (binding = 4) readonly buffer S {float data_s[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; +layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; + +#define MASK_OPT_ALL_NEG_INF 1 +#define MASK_OPT_ALL_ZERO 2 + #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 #if defined(DATA_A_F32) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 83d52d19d67..89af3697e1d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -42,8 +42,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } -shared float tmpsh[row_split]; - const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 shared f16vec4 Qf[Br * qstride]; @@ -134,6 +132,10 @@ void main() { } } + const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc); + // mo_offset will point to the tile starting at row i*Br and col 0 + uint32_t mo_offset = mo_stride * i; + #if BLOCK_SIZE > 1 uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; @@ -144,66 +146,74 @@ void main() { uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride; } + uint32_t mask_opt = 0; + uint32_t mask_opt_idx = ~0; + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize]; + [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) { + mask_cache[idx] = f16vec4(0); + } + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - float max_mask = NEG_FLT_MAX_OVER_2; - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) / (Br / 4); - uint32_t r = (idx + tid) % (Br / 4); - if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { - if ((!KV_bounds_check || j * Bc + c < KV)) { - f16vec4 m; - if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) { - m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], - data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], - data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], - data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]); - max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3])); - } else if (i * Br + r * 4 + 2 < p.nem1) { - m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], - data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], - data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], - 0.0); - max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])); - } else if (i * Br + r * 4 + 1 < p.nem1) { - m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], - data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], - 0.0, - 0.0); - max_mask = max(max(max_mask, float(m[0])), float(m[1])); - } else if (i * Br + r * 4 < p.nem1) { - m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], - 0.0, - 0.0, - 0.0); - max_mask = max(max_mask, float(m[0])); - } else { - m = f16vec4(0.0); + if (USE_MASK_OPT && mask_opt_idx != j / 16) { + mask_opt_idx = j / 16; + mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; + } + uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { + // skip this block + continue; + } + // Only load if the block is not all zeros + if (mask_opt_bits != MASK_OPT_ALL_ZERO) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + float max_mask = NEG_FLT_MAX_OVER_2; + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / (Br / 4); + uint32_t r = (idx + tid) % (Br / 4); + if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { + if ((!KV_bounds_check || j * Bc + c < KV)) { + f16vec4 m; + if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]); + max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3])); + } else if (i * Br + r * 4 + 2 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], + 0.0); + max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])); + } else if (i * Br + r * 4 + 1 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + 0.0, + 0.0); + max_mask = max(max(max_mask, float(m[0])), float(m[1])); + } else if (i * Br + r * 4 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + 0.0, + 0.0, + 0.0); + max_mask = max(max_mask, float(m[0])); + } else { + m = f16vec4(0.0); + } + mask_cache[idx / WorkGroupSize] = m; } - mask_cache[idx / WorkGroupSize] = m; } } } - // skip the block if the mask is entirely -inf - bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); - barrier(); - if (gl_SubgroupInvocationID == 0) { - tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; - } - barrier(); - [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { - max_mask = max(max_mask, tmpsh[s]); - } - if (max_mask <= NEG_FLT_MAX_OVER_2) { - continue; - } } if (K_LOAD_SHMEM != 0) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 54f1b0b6226..47b110621b7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -138,48 +138,53 @@ void main() { coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } + const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc); + // mo_offset will point to the tile starting at row i*Br and col 0 + uint32_t mo_offset = mo_stride * i; + uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/; if (p.nem2 != 1 || p.nem3 != 1) { m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride; } + uint32_t mask_opt = 0; + uint32_t mask_opt_idx = ~0; + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - coopmat mv; + coopmat mv = coopmat(0); if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - - if (nem1_bounds_check) { - tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); - tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); - tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); - tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t - - coopmat mvmax; - - coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); - - // skip the block if the mask is entirely -inf - coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); - if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { - continue; - } - } else { - tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); - // Don't clamp against nem1 when GQA is enabled - uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1; - tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV); - tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); - - coopmat mvmax; - coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); - - // skip the block if the mask is entirely -inf - coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); - if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { - continue; + if (USE_MASK_OPT && mask_opt_idx != j / 16) { + mask_opt_idx = j / 16; + mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; + } + uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { + // skip this block + continue; + } + // Only load if the block is not all zeros + if (mask_opt_bits != MASK_OPT_ALL_ZERO) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + if (nem1_bounds_check) { + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t + + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + } else { + tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); + // Don't clamp against nem1 when GQA is enabled + uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1; + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp new file mode 100644 index 00000000000..8c92c1adcda --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp @@ -0,0 +1,142 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable + +layout (constant_id = 0) const uint BLOCK_SIZE = 128; +layout (constant_id = 1) const uint NUM_SUBGROUPS = 4; +layout (constant_id = 2) const uint Br = 32; +layout (constant_id = 3) const uint Bc = 32; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float16_t data_a[];}; +layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];}; +layout (binding = 1) writeonly buffer D {uint data_d[];}; + +layout (push_constant) uniform parameter { + uint nem0; + uint nem1; + uint nem2; + uint nbm1; + uint nbm2; + uint nbm3; + uint nbd1; + uint nbd2; + uint nbd3; +}; + +#define MASK_OPT_ALL_NEG_INF 1 +#define MASK_OPT_ALL_ZERO 2 + +shared float minsh[NUM_SUBGROUPS]; +shared float maxsh[NUM_SUBGROUPS]; + +// For each Br x Bc block of the mask (input) buffer, read all values and check +// if it's all -inf or all zero. Write out a two-bit code indicating which it is +// (or zero for neither). Each workgroup processes 16 tiles and writes out a +// 32-bit result mask. +// +// TODO: This is a lot of work per workgroup, might make sense to split this into +// more workgroups in the future. +void main() { + // Each workgroup handles a row + const uint tid = gl_LocalInvocationIndex; + const uint i0 = gl_WorkGroupID.x; + const uint i1 = gl_WorkGroupID.y; + const uint i2 = gl_WorkGroupID.z % nem2; + const uint i3 = gl_WorkGroupID.z / nem2; + + float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF); + + uint result = 0; + + // Fast path for fully in-bounds blocks where we can do f16vec4 loads + if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 && + ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) { + [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { + float min_v = FLT_MAX_OVER_2; + float max_v = -FLT_MAX_OVER_2; + [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) { + uint j0 = (i + tid) % (Bc / 4); + uint j1 = (i + tid) / (Bc / 4); + + j0 *= 4; + j0 += (i0 * 16 + block_x) * Bc; + j1 += i1 * Br; + + vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]); + [[unroll]] for (int c = 0; c < 4; ++c) { + min_v = min(min_v, f[c]); + max_v = max(max_v, f[c]); + } + } + min_v = subgroupMin(min_v); + max_v = subgroupMax(max_v); + if (gl_SubgroupInvocationID == 0) { + minsh[gl_SubgroupID] = min_v; + maxsh[gl_SubgroupID] = max_v; + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) { + min_v = min(min_v, minsh[i]); + max_v = max(max_v, maxsh[i]); + } + if (max_v <= -FLT_MAX_OVER_2) { + result |= 1 << (2*block_x); + } + if (min_v == 0.0f && max_v == 0.0f) { + result |= 2 << (2*block_x); + } + } + barrier(); + } + } else { + [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { + float min_v = FLT_MAX_OVER_2; + float max_v = -FLT_MAX_OVER_2; + [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) { + if ((Br * Bc % BLOCK_SIZE) != 0 && i + tid >= Br * Bc) { + continue; + } + uint j0 = (i + tid) % Bc; + uint j1 = (i + tid) / Bc; + + j0 += (i0 * 16 + block_x) * Bc; + j1 += i1 * Br; + + if (j0 < nem0 && j1 < nem1) { + float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]); + min_v = min(min_v, f); + max_v = max(max_v, f); + } + } + min_v = subgroupMin(min_v); + max_v = subgroupMax(max_v); + if (gl_SubgroupInvocationID == 0) { + minsh[gl_SubgroupID] = min_v; + maxsh[gl_SubgroupID] = max_v; + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) { + min_v = min(min_v, minsh[i]); + max_v = max(max_v, maxsh[i]); + } + if (max_v <= -FLT_MAX_OVER_2) { + result |= 1 << (2*block_x); + } + if (min_v == 0.0f && max_v == 0.0f) { + result |= 2 << (2*block_x); + } + } + barrier(); + } + } + + if (tid == 0) { + data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ca486a288a1..42ebc21e2a6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -790,6 +790,8 @@ void process_shaders() { string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {}); + string_to_spv("fa_mask_opt", "flash_attn_mask_opt.comp", {}); + string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}}); From 34d332aca55f44f47e6724b011f8e5903d86bfe9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 5 Feb 2026 19:07:22 +0200 Subject: [PATCH 104/831] metal : adaptive CPU/GPU interleave based on number of nodes (llama/19369) --- ggml/src/ggml-metal/ggml-metal-context.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m index a412d70aed5..c7e8ebd3f32 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -415,7 +415,7 @@ bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, con enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) { // number of nodes encoded by the main thread (empirically determined) - const int n_main = 64; + const int n_main = MAX(64, 0.1*gf->n_nodes); // number of threads in addition to the main thread const int n_cb = ctx->n_cb; From 2a7d5490f17fc8fdeae04251220e9fa7733a8539 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Feb 2026 07:55:06 +0200 Subject: [PATCH 105/831] cuda : cuda graphs now compare all node params (llama/19383) --- ggml/src/ggml-cuda/ggml-cuda.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index eeb8625dbeb..9e77c231c85 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2979,8 +2979,7 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_ } } - if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) && - memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { return false; } From 776cf61857698423e10b562f06d6380fc777684e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Feb 2026 09:25:11 +0200 Subject: [PATCH 106/831] metal : skip loading all-zero mask (llama/19337) * metal : skip loading all-zero mask * cont : minor --- ggml/src/ggml-metal/ggml-metal.metal | 63 +++++++++++++++++----------- 1 file changed, 39 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index e54cdab39dd..612a42a1ea8 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -5285,6 +5285,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E // scan the blocks of the mask that are not masked // 0 - masked (i.e. full of -INF, skip) // 1 - not masked (i.e. at least one element of the mask is not -INF) +// 2 - all zero kernel void kernel_flash_attn_ext_blk( constant ggml_metal_kargs_flash_attn_ext_blk & args, device const char * mask, @@ -5306,27 +5307,29 @@ kernel void kernel_flash_attn_ext_blk( device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg; - // fast route - if (res == 0) { - if (simd_max(*mask_src) > -MAXHALF/2) { - res = 1; - } - } - // detailed check of the elements of the block if ((C > NW || Q > 1) && res == 0) { - half m = -MAXHALF; + half mmin = MAXHALF; + half mmax = -MAXHALF; FOR_UNROLL (short j = 0; j < Q; ++j) { FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) { - m = max(m, mask_src[ii*NW]); + mmin = min(mmin, mask_src[ii*NW]); + mmax = max(mmax, mask_src[ii*NW]); } mask_src += args.nb31/2; } - if (simd_max(m) > -MAXHALF/2) { - res = 1; + mmin = simd_min(mmin); + mmax = simd_max(mmax); + + if (mmax > -MAXHALF) { + if (mmin == 0.0 && mmax == 0.0) { + res = 2; + } else { + res = 1; + } } } @@ -5568,9 +5571,13 @@ void kernel_flash_attn_ext_impl( ic = 0; } + char blk_cur = 1; + // read the mask into shared mem if (FC_flash_attn_ext_has_mask) { - if (blk[ic0] == 0) { + blk_cur = blk[ic0]; + + if (blk_cur == 0) { FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { pm2[jj] += NW; } @@ -5578,16 +5585,22 @@ void kernel_flash_attn_ext_impl( continue; } - FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { - const short j = jj*NSG + sgitg; + if (blk_cur == 1) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - if (FC_flash_attn_ext_bc_mask) { - sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); - } else { - sm2[j*SH + tiisg] = pm2[jj][tiisg]; - } + if (FC_flash_attn_ext_bc_mask) { + sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); + } else { + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + } - pm2[jj] += NW; + pm2[jj] += NW; + } + } else if (blk_cur == 2) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + pm2[jj] += NW; + } } #if 0 @@ -5752,10 +5765,12 @@ void kernel_flash_attn_ext_impl( } // mqk = mqk + slope*mask - if (FC_flash_attn_ext_has_bias) { - s2 += s2_t(sm2[j*SH + tiisg])*slope; - } else { - s2 += s2_t(sm2[j*SH + tiisg]); + if (blk_cur != 2) { + if (FC_flash_attn_ext_has_bias) { + s2 += s2_t(sm2[j*SH + tiisg])*slope; + } else { + s2 += s2_t(sm2[j*SH + tiisg]); + } } M[jj] = simd_max(max(M[jj], max(s2[0], s2[1]))); From c1b63354bb566143cf7987c20fed9256a0b79338 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 6 Feb 2026 01:49:58 -0600 Subject: [PATCH 107/831] vulkan: make FA mask/softcap enables spec constants (llama/19309) * vulkan: make FA mask/softcap enables spec constants * don't specialize for sinks * bump timeout a little bit --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 56 ++++++++++--------- .../vulkan-shaders/flash_attn.comp | 6 +- .../vulkan-shaders/flash_attn_base.glsl | 7 ++- .../vulkan-shaders/flash_attn_cm1.comp | 6 +- .../vulkan-shaders/flash_attn_cm2.comp | 6 +- 5 files changed, 44 insertions(+), 37 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 4357da24d42..72097ffd0ff 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -402,19 +402,19 @@ enum FaCodePath { }; struct vk_fa_pipeline_state { - vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, bool use_mask_opt) - : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), use_mask_opt(use_mask_opt) {} + vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags) + : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {} uint32_t HSK, HSV; bool small_rows, small_cache; FaCodePath path; bool aligned; bool f32acc; - bool use_mask_opt; + uint32_t flags; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt) < - std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.use_mask_opt); + return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) < + std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags); } }; @@ -3193,7 +3193,7 @@ static void ggml_vk_load_shaders(vk_device& device) { return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1}; }; - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, bool use_mask_opt) -> std::vector { + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector { // For large number of rows, 128 invocations seems to work best. // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we // can't use 256 for D==80. @@ -3225,7 +3225,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // AMD prefers loading K directly from global memory const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0; - return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, use_mask_opt}; + return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags}; }; #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ @@ -3237,19 +3237,19 @@ static void ggml_vk_load_shaders(vk_device& device) { FaCodePath path = fa.first.path; \ bool aligned = fa.first.aligned; \ bool f32acc = fa.first.f32acc; \ - bool use_mask_opt = fa.first.use_mask_opt; \ + uint32_t flags = fa.first.flags; \ if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,use_mask_opt), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,use_mask_opt), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ } \ } \ } \ @@ -8595,10 +8595,26 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768; - vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, use_mask_opt); + uint32_t flags = (use_mask_opt ? 1 : 0) | + (mask != nullptr ? 2 : 0) | + (logit_softcap != 0 ? 4 : 0); + + vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags); vk_pipeline pipeline = nullptr; @@ -8678,18 +8694,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } } - float scale = 1.0f; - float max_bias = 0.0f; - float logit_softcap = 0.0f; - - memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); - - if (logit_softcap != 0) { - scale /= logit_softcap; - } - const uint32_t n_head_kv = neq2; const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); @@ -8703,7 +8707,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf; vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf; - uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2; + uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2; if (use_mask_opt) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 49a3c530cb6..914f131c965 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -127,7 +127,7 @@ void main() { continue; } // Only load if the block is not all zeros - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) { + if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { @@ -181,7 +181,7 @@ void main() { } } - if (p.logit_softcap != 0.0f) { + if (LOGIT_SOFTCAP) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); @@ -189,7 +189,7 @@ void main() { } } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0 && mask_opt_bits != MASK_OPT_ALL_ZERO) { + if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { float mvf = masksh[c * cols_per_iter + col_tid][r]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 252451101ab..74005cffb3f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -10,7 +10,11 @@ layout (constant_id = 5) const uint32_t Clamp = 0; layout (constant_id = 6) const uint32_t D_split = 16; layout (constant_id = 7) const uint32_t SubGroupSize = 32; layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0; -layout (constant_id = 9) const bool USE_MASK_OPT = false; +layout (constant_id = 9) const uint32_t Flags = 0; + +const bool USE_MASK_OPT = (Flags & 1) != 0; +const bool MASK_ENABLE = (Flags & 2) != 0; +const bool LOGIT_SOFTCAP = (Flags & 4) != 0; // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths const uint32_t HSK_pad = (HSK + 15) & ~15; @@ -60,7 +64,6 @@ layout (push_constant) uniform parameter { } p; #define SINK_ENABLE_BIT (1<<24) -#define MASK_ENABLE_BIT (1<<16) #define N_LOG2_MASK 0xFFFF layout (binding = 4) readonly buffer S {float data_s[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 89af3697e1d..b3177738234 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -160,7 +160,7 @@ void main() { mask_cache[idx] = f16vec4(0); } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (MASK_ENABLE) { if (USE_MASK_OPT && mask_opt_idx != j / 16) { mask_opt_idx = j / 16; @@ -303,7 +303,7 @@ void main() { coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor); barrier(); - if (p.logit_softcap != 0.0f) { + if (LOGIT_SOFTCAP) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) / (Br / 4); uint32_t r = (idx + tid) % (Br / 4); @@ -314,7 +314,7 @@ void main() { barrier(); } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (MASK_ENABLE) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) / (Br / 4); uint32_t r = (idx + tid) % (Br / 4); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 47b110621b7..b07c21f6e55 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -155,7 +155,7 @@ void main() { for (uint32_t j = start_j; j < end_j; ++j) { coopmat mv = coopmat(0); - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (MASK_ENABLE) { if (USE_MASK_OPT && mask_opt_idx != j / 16) { mask_opt_idx = j / 16; @@ -197,14 +197,14 @@ void main() { coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC); S = coopMatMulAdd(Qf16, K_T, S); - if (p.logit_softcap != 0.0f) { + if (LOGIT_SOFTCAP) { [[unroll]] for (int k = 0; k < S.length(); ++k) { S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); } } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (MASK_ENABLE) { S += slopeMat*coopmat(mv); } From cea22b3075684fc4d949982eb412b47f8da205cc Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 6 Feb 2026 02:15:13 -0600 Subject: [PATCH 108/831] vulkan: For coopmat2 FA, use fp16 accumulators for the final result (llama/19376) The cpu and cuda backends use fp16 for the VKQ accumulator type, this change does the same for vulkan. This helps particularly with large head sizes which are very register-limited. I tried this for the coopmat1 path and it slowed down a bit. I didn't try for scalar. I applied the softmax bias that the cuda backend uses to avoid overflow, although I was not able to reproduce the original bug without it. --- .../vulkan-shaders/flash_attn_base.glsl | 4 ++++ .../vulkan-shaders/flash_attn_cm2.comp | 20 +++++++++---------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 74005cffb3f..4142c1e6eaa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -240,3 +240,7 @@ void init_indices() // and breaking the alignment detection. m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; } + +// Bias applied to softmax to stay in fp16 range. +// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606 +const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index b07c21f6e55..39f0c4d23b9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -117,7 +117,7 @@ void main() { Qf16 = coopmat(Q); Qf16 *= float16_t(p.scale); - coopmat O = coopmat(0); + coopmat O = coopmat(0); coopmat L, M; @@ -223,6 +223,8 @@ void main() { coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce); + rowmax += coopmat(FATTN_KQ_MAX_OFFSET); + coopmat Mold = M; // M = max(rowmax, Mold) @@ -265,11 +267,8 @@ void main() { // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); - // multiply with fp16 accumulation, then add to O. - coopmat PV = coopmat(0); - PV = coopMatMulAdd(P_A, V, PV); - - O = eMdiag * O + coopmat(PV); + O *= coopmat(eMdiag); + O = coopMatMulAdd(P_A, V, O); } // If there is split_k, then the split_k resolve shader does the final @@ -311,7 +310,7 @@ void main() { if (sink > Mr[i]) { ms = exp(Mr[i] - sink); - O[i] *= ms; + O[i] *= float16_t(ms); } else { vs = exp(sink - Mr[i]); } @@ -325,15 +324,16 @@ void main() { Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]); } - O = Ldiag*O; + coopmat O_D = coopmat(O); + + O_D = coopmat(Ldiag)*O_D; #if defined(ACC_TYPE_MAX) - [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } + [[unroll]] for (uint i = 0; i < O_D.length(); ++i) { O_D[i] = clamp(O_D[i], D_TYPE(-ACC_TYPE_MAX), D_TYPE(ACC_TYPE_MAX)); } #endif uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; - coopmat O_D = coopmat(O); if (p.gqa_ratio > 1) { coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); } else { From f2f73208171564bc9fc1fa361fa434c5b0ce4ba3 Mon Sep 17 00:00:00 2001 From: Nechama Krashinski Date: Fri, 6 Feb 2026 17:13:44 +0200 Subject: [PATCH 109/831] sycl: add F16 support for GGML_OP_CEIL (llama/19306) * Fix SYCL CEIL operator * sycl: implement GGML_OP_CEIL --- ggml/src/ggml-sycl/element_wise.cpp | 13 +++---------- ggml/src/ggml-sycl/ggml-sycl.cpp | 2 +- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 651b875b636..00d54b83f82 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -836,16 +836,9 @@ static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tens } static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, - [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { - const int num_blocks = ceil_div(k_elements, 256); - stream->parallel_for( - sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), - sycl::range<1>(256)), - [=](sycl::nd_item<1> item_ct1) { - unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1); - }); - }); + ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) { + return op_ceil(x); + }); } static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index a03d26d7f20..0614d7e8f3a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4591,9 +4591,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_CEIL: return true; case GGML_UNARY_OP_FLOOR: - case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: case GGML_UNARY_OP_TRUNC: #if defined (GGML_SYCL_F16) From 1739af663a850fb02f8e3437a67f8e0d4dfad8b5 Mon Sep 17 00:00:00 2001 From: Abhijit Ramesh Date: Fri, 6 Feb 2026 10:33:30 -0800 Subject: [PATCH 110/831] ggml-webgpu: JIT compile binary operators and handle binding overlaps (llama/19310) * ggml webgpu: port binary operators to use pre-wgsl * Add binary.wgsl: unified shader with conditionals for all 4 ops * Add gen_binary_shaders.cpp: build tool for using pre_wgsl preprocessor * Remove bin_op.tmpl.wgsl and binary.wgsl (Python template) * Update CMake to generate binary operator shaders at build time * ggml-webgpu: migrate binary ops to JIT compilation with overlap handling * port binary operators from AOT to pre-wgsl JIT compilation * add src1=dst overlap handling for binary ops * use compile-time workgroup size defines instead of runtime overrides * ggml-webgpu: complete overlap handling for binary ops * add support for inplace & overlap case in binding setup * restructure conditional logic to handle all overlap cases * ensure all buffer bindings are correctly assigned for edge cases * ggml-webgpu: remove unused binary overlap cases Remove src0==src1 binary overlap case that never occurs in practice. * keep INPLACE (src0==dst), OVERLAP (src1==dst), DEFAULT * remove unused src0==src1 and all-same variant * refactor wgsl to eliminate duplication --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 69 +++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 178 ++++++++--------- .../ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl | 188 ------------------ ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl | 107 ++++++++++ .../ggml-webgpu/wgsl-shaders/binary_head.tmpl | 45 ----- 5 files changed, 257 insertions(+), 330 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 84d88e81d45..6997f6bdd31 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -465,4 +465,73 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader( return result; } +/** Binary **/ + +struct ggml_webgpu_binary_pipeline_key { + int type; + int op; + bool inplace; + bool overlap; + + bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { + return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap; + } +}; + +struct ggml_webgpu_binary_pipeline_key_hash { + size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.overlap); + return seed; + } +}; + +struct ggml_webgpu_binary_shader_lib_context { + ggml_webgpu_binary_pipeline_key key; + uint32_t max_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_binary_shader_lib_context & context) { + std::vector defines; + std::string op_name = ggml_op_name((ggml_op) context.key.op); + std::string variant = op_name; + + defines.push_back(std::string("OP_") + op_name); + + switch (context.key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for binary shader"); + } + + if (context.key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } else if (context.key.overlap) { + defines.push_back("OVERLAP"); + variant += "_overlap"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; + return result; +} #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 4ef50e365ef..f7ceca11212 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -348,13 +348,12 @@ struct webgpu_context_struct { std::unordered_map set_rows_pipelines; - std::map> get_rows_pipelines; // src_type, vectorized + std::map> get_rows_pipelines; // src_type, vectorized - std::map> cpy_pipelines; // src_type, dst_type - std::map> add_pipelines; // type, inplace - std::map> sub_pipelines; // type, inplace - std::map> mul_pipelines; // type, inplace - std::map> div_pipelines; // type, inplace + std::map> cpy_pipelines; // src_type, dst_type + + std::unordered_map + binary_pipelines; std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace @@ -823,6 +822,28 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); } +// Used to determine if two tensors share the same buffer and their byte ranges overlap, +static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { + return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && + ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && + ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); +} + +struct binary_overlap_flags { + bool inplace; // src0 == dst + bool overlap; // src1 == dst +}; + +static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + binary_overlap_flags flags = {}; + flags.inplace = ggml_webgpu_tensor_equal(src0, dst); + flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); + + return flags; +} + static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -1375,14 +1396,42 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst, - webgpu_pipeline & pipeline, - bool inplace) { +static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); + + ggml_webgpu_binary_pipeline_key pipeline_key = { + .type = dst->type, + .op = dst->op, + .inplace = flags.inplace, + .overlap = flags.overlap, + }; + ggml_webgpu_binary_shader_lib_context shader_lib_ctx = { + .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + }; + + webgpu_pipeline pipeline; + auto it = ctx->binary_pipelines.find(pipeline_key); + if (it != ctx->binary_pipelines.end()) { + pipeline = it->second; + } else { + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx); + pipeline = + ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); + pipeline.context = processed.decisions; + ctx->binary_pipelines.emplace(pipeline_key, pipeline); + } + + ggml_webgpu_generic_shader_decisions decisions = + *static_cast(pipeline.context); + + uint32_t ne = (uint32_t) ggml_nelements(dst); + std::vector params = { - (uint32_t) ggml_nelements(dst), + ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1399,24 +1448,30 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, (uint32_t) src1->ne[3], }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) } - }; - if (!inplace) { + std::vector entries; + + entries.push_back({ + .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0), + }); + + entries.push_back({ + .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1), + }); + + if (!flags.inplace && !flags.overlap) { entries.push_back({ .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -2038,25 +2093,10 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return std::nullopt; #endif case GGML_OP_ADD: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace); - } case GGML_OP_SUB: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace); - } case GGML_OP_MUL: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace); - } case GGML_OP_DIV: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace); - } + return ggml_webgpu_binary_op(ctx, src0, src1, node); case GGML_OP_RMS_NORM: return ggml_webgpu_rms_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -2665,58 +2705,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); } -static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32, "add_f32", constants); - webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16, "add_f16", constants); - webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants); - webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants); -} - -static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32, "sub_f32", constants); - webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16, "sub_f16", constants); - webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants); - webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants); -} - -static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32, "mul_f32", constants); - webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16, "mul_f16", constants); - webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants); - webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants); -} - -static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32, "div_f32", constants); - webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16, "div_f16", constants); - webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants); - webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants); -} - static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); @@ -3018,10 +3006,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); ggml_webgpu_init_get_rows_pipeline(webgpu_ctx); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); - ggml_webgpu_init_add_pipeline(webgpu_ctx); - ggml_webgpu_init_sub_pipeline(webgpu_ctx); - ggml_webgpu_init_mul_pipeline(webgpu_ctx); - ggml_webgpu_init_div_pipeline(webgpu_ctx); ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); ggml_webgpu_init_rope_pipeline(webgpu_ctx); ggml_webgpu_init_glu_pipeline(webgpu_ctx); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl deleted file mode 100644 index 1ce4d83fa8e..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +++ /dev/null @@ -1,188 +0,0 @@ -#define(VARIANTS) - -[ - { - "SHADER_NAME": "add_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "+" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "add_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "+" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "add_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "+" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "add_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "+" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "mul_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "*" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "mul_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "*" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "mul_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "*" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "mul_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "*" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sub_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "-" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sub_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "-" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sub_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "-" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sub_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "-" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "div_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "/" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "div_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "/" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "div_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "/" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "div_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "/" - }, - "DECLS": ["INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(NOT_INPLACE) - -fn update(dst_i: u32, src0_i: u32, src1_i: u32) { - dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; -} - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -#enddecl(NOT_INPLACE) - -#decl(INPLACE) - -fn update(dst_i: u32, src0_i: u32, src1_i: u32) { - src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; -} - -@group(0) @binding(2) -var params: Params; - -#enddecl(INPLACE) - -#end(DECLS) - - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl new file mode 100644 index 00000000000..55dd66408a3 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl @@ -0,0 +1,107 @@ +enable f16; + +struct Params { + ne: u32, + + // offsets in elements + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src1_0: u32, + stride_src1_1: u32, + stride_src1_2: u32, + stride_src1_3: u32, + + a_ne0: u32, + a_ne1: u32, + a_ne2: u32, + + b_ne0: u32, + b_ne1: u32, + b_ne2: u32, + b_ne3: u32, +}; + +fn src1_index(_i: u32) -> u32 { + var i = _i; + let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); + i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); + let a_i2 = i / (params.a_ne1 * params.a_ne0); + i = i % (params.a_ne1 * params.a_ne0); + let a_i1 = i / params.a_ne0; + let a_i0 = i % params.a_ne0; + + // handle repetition of b + // index loops back to the beginning and repeats after elements are exhausted = modulo + let b_i0 = a_i0 % params.b_ne0; + let b_i1 = a_i1 % params.b_ne1; + let b_i2 = a_i2 % params.b_ne2; + let b_i3 = a_i3 % params.b_ne3; + + // compute index for position in b's flat array + return b_i0 * params.stride_src1_0 + + b_i1 * params.stride_src1_1 + + b_i2 * params.stride_src1_2 + + b_i3 * params.stride_src1_3; +} + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_F16 +#define DataType f16 +#endif + +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var src1 : array; + +#ifdef INPLACE +@group(0) @binding(2) +var params: Params; + +#elif defined(OVERLAP) +@group(0) @binding(2) +var params: Params; + +#else +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; +#endif + +fn op(a: DataType, b: DataType) -> DataType { +#ifdef OP_ADD + return a + b; +#elif defined(OP_SUB) + return a - b; +#elif defined(OP_MUL) + return a * b; +#elif defined(OP_DIV) + return a / b; +#endif +} + +fn update(dst_i: u32, src0_i: u32, src1_i: u32){ + let result = op(src0[src0_i], src1[src1_i]); + +#ifdef INPLACE + src0[dst_i] = result; +#elif defined(OVERLAP) + src1[dst_i] = result; +#else + dst[dst_i] = result; +#endif +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl deleted file mode 100644 index 4b254f468d6..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +++ /dev/null @@ -1,45 +0,0 @@ -struct Params { - ne: u32, - - // offsets in elements - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - - stride_src1_0: u32, - stride_src1_1: u32, - stride_src1_2: u32, - stride_src1_3: u32, - - a_ne0: u32, - a_ne1: u32, - a_ne2: u32, - - b_ne0: u32, - b_ne1: u32, - b_ne2: u32, - b_ne3: u32, -}; - -fn src1_index(_i: u32) -> u32 { - var i = _i; - let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); - i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); - let a_i2 = i / (params.a_ne1 * params.a_ne0); - i = i % (params.a_ne1 * params.a_ne0); - let a_i1 = i / params.a_ne0; - let a_i0 = i % params.a_ne0; - - // handle repetition of b - // index loops back to the beginning and repeats after elements are exhausted = modulo - let b_i0 = a_i0 % params.b_ne0; - let b_i1 = a_i1 % params.b_ne1; - let b_i2 = a_i2 % params.b_ne2; - let b_i3 = a_i3 % params.b_ne3; - - // compute index for position in b's flat array - return b_i0 * params.stride_src1_0 + - b_i1 * params.stride_src1_1 + - b_i2 * params.stride_src1_2 + - b_i3 * params.stride_src1_3; -} From a9a0a51fbadc01630237c8fd869743ef50ef501b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Feb 2026 07:37:15 +0200 Subject: [PATCH 111/831] metal : fix event synchronization in cpy_tensor_async (llama/19402) --- ggml/src/ggml-metal/ggml-metal-context.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m index c7e8ebd3f32..5d3a8ce412a 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -394,7 +394,7 @@ bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, con [encoder endEncoding]; ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src); - ggml_metal_event_record(ctx_src, ev_cpy); + ggml_metal_event_encode_signal(ev_cpy, cmd_buf); [cmd_buf commit]; From 55d7cb2e938169a2d336e706b9fbf2fc26eb02dc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Feb 2026 10:35:56 +0200 Subject: [PATCH 112/831] metal : consolidate bin kernels (llama/19390) * metal : refactor bin kernels * cont * cont : fix cv --- ggml/src/ggml-metal/ggml-metal-device.cpp | 78 ++++- ggml/src/ggml-metal/ggml-metal-device.h | 6 +- ggml/src/ggml-metal/ggml-metal-device.m | 10 +- ggml/src/ggml-metal/ggml-metal-impl.h | 1 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 33 +- ggml/src/ggml-metal/ggml-metal.metal | 377 ++++++++-------------- 6 files changed, 217 insertions(+), 288 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 6af0dd88d55..4c4c3ce36c4 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1392,34 +1392,78 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v GGML_UNUSED(op); } -ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin( - ggml_metal_library_t lib, - ggml_op op, - int32_t n_fuse, - bool row) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) { char base[256]; char name[256]; - const char * op_str = "undefined"; - switch (op) { - case GGML_OP_ADD: op_str = "add"; break; - case GGML_OP_SUB: op_str = "sub"; break; - case GGML_OP_MUL: op_str = "mul"; break; - case GGML_OP_DIV: op_str = "div"; break; + int op_num = -1; + + switch (op->op) { + case GGML_OP_ADD: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; default: GGML_ABORT("fatal error"); }; - if (row) { - snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse); - } else { - snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse); + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t1_str = ggml_type_name(op->src[1]->type); + const char * t_str = ggml_type_name(op->type); + + const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0); + + const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536; + + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : ""); + snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); + ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1); + ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - snprintf(name, 256, "%s", base); + res.c4 = is_c4; + res.cnt = is_rb; + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) { + char base[256]; + char name[256]; + + int op_num = -1; + + switch (op) { + case GGML_OP_ADD: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32"); + snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); + ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1); + ggml_metal_cv_set_bool (cv, false, FC_BIN + 2); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } return res; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 84dcec30830..93d7f6a216f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -53,6 +53,9 @@ struct ggml_metal_pipeline_with_params { int nr1; size_t smem; + + bool c4; + bool cnt; }; int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline); @@ -134,7 +137,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse ); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one (ggml_metal_library_t lib, enum ggml_op op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c8e737d4187..891d70c85a4 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -346,10 +346,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta struct ggml_metal_pipeline_with_params res = { /*.pipeline =*/ nil, + /*.nsg =*/ 0, /*.nr0 =*/ 0, /*.nr1 =*/ 0, - /*.nsg =*/ 0, /*.smem =*/ 0, + /*.c4 =*/ false, + /*.cnt =*/ false, }; res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); @@ -362,10 +364,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { struct ggml_metal_pipeline_with_params res = { /*.pipeline =*/ nil, + /*.nsg =*/ 0, /*.nr0 =*/ 0, /*.nr1 =*/ 0, - /*.nsg =*/ 0, /*.smem =*/ 0, + /*.c4 =*/ false, + /*.cnt =*/ false, }; [lib->lock lock]; @@ -1054,7 +1058,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_ADD_ID: - return op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: case GGML_OP_REPEAT: case GGML_OP_SCALE: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 7f73cb97bbb..77bb403c15d 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -80,6 +80,7 @@ #define FC_SSM_CONV 900 #define FC_SOLVE_TRI 1000 #define FC_COUNT_EQUAL 1100 +#define FC_BIN 1200 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e0ed6c7805c..dbf25433c25 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -707,7 +707,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.o1 =*/ { 0 }, }; - auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); + auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -2895,8 +2895,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); - bool bcast_row = false; - ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); @@ -2990,18 +2988,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { struct ggml_metal_pipeline_with_params pipeline; - if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(op->src[0])); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true); - - bcast_row = true; - } else { - pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false); - } + pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse); if (n_fuse > 1) { bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); @@ -3015,20 +3002,28 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { } } + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne10 = ne10/4; + args.ne0 = ne0/4; + } + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, bid_src0, 1); ggml_metal_encoder_set_buffer (enc, bid_src1, 2); ggml_metal_encoder_set_buffer (enc, bid_dst, 3); - if (bcast_row) { - const int64_t n = ggml_nelements(op)/4; + if (pipeline.cnt) { + const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op); ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); } else { - int nth = 32; + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + int nth = 1; - while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + while (2*nth < args.ne0 && nth < nth_max) { nth *= 2; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 612a42a1ea8..35cc3bbdfdf 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -895,11 +895,13 @@ enum ggml_sort_order { GGML_SORT_ORDER_DESC, }; -// general-purpose kernel for addition, subtraction, multiplication and division of two tensors -// pros: works for non-contiguous tensors, supports broadcast across all dims -// cons: not very efficient -template -kernel void kernel_add_fuse_impl( +// OP: 0 - add, 1 - sub, 2 - mul, 3 - div +constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; +constant short FC_bin_f [[function_constant(FC_BIN + 1)]]; +constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]]; + +template +kernel void kernel_bin_fuse_impl( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -907,139 +909,153 @@ kernel void kernel_add_fuse_impl( uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; +#define FC_OP FC_bin_op +#define FC_F FC_bin_f +#define FC_RB FC_bin_rb - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + if (FC_RB) { + // row broadcast + const uint i0 = tgpig.x; + const uint i1 = i0%args.ne10; - device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); - device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); + device const T0 * src0_row = (device const T0 *) (src0); + device T * dst_row = (device T *) (dst); - device const float * src1_ptr[F]; - for (short j = 0; j < F; ++j) { - src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); - } + if (FC_F == 1) { + device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; + if (FC_OP == 0) { + dst_row[i0] = src0_row[i0] + src1_row[i1]; + } + + if (FC_OP == 1) { + dst_row[i0] = src0_row[i0] - src1_row[i1]; + } + + if (FC_OP == 2) { + dst_row[i0] = src0_row[i0] * src1_row[i1]; + } + + if (FC_OP == 3) { + dst_row[i0] = src0_row[i0] / src1_row[i1]; + } + } else { + T0 res = src0_row[i0]; + + if (FC_OP == 0) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res += ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + if (FC_OP == 1) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res -= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } - float res = src0_ptr[i0]; + if (FC_OP == 2) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res *= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + if (FC_OP == 3) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res /= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } -#pragma unroll - for (short j = 0; j < F; ++j) { - res += src1_ptr[j][i10]; + dst_row[i0] = res; } + } else { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - dst_ptr[i0] = res; - } -} + if (i01 >= args.ne01) { + return; + } -typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; -template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; -template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; -template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; -template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; -template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>; -template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>; -template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; -template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; + device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); + device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); -kernel void kernel_sub_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; + if (FC_F == 1) { + device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + if (FC_OP == 0) { + dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10]; + } - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); - } -} + if (FC_OP == 1) { + dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10]; + } -kernel void kernel_mul_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; + if (FC_OP == 2) { + dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10]; + } - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + if (FC_OP == 3) { + dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10]; + } + } + } else { + device const T1 * src1_ptr[8]; + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + } - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; - if (args.ne10 == 1) { - const float x = *((device float *)(src1_ptr)); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; - } - } else { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); - } - } -} + T res = src0_ptr[i0]; -kernel void kernel_div_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; + if (FC_OP == 0) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res += src1_ptr[j][i10]; + } + } + + if (FC_OP == 1) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res -= src1_ptr[j][i10]; + } + } - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + if (FC_OP == 2) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res *= src1_ptr[j][i10]; + } + } - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + if (FC_OP == 3) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res /= src1_ptr[j][i10]; + } + } - if (args.ne10 == 1) { - const float x = 1.0f / *((device float *)(src1_ptr)); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; - } - } else { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); + dst_ptr[i0] = res; + } } } + +#undef FC_OP +#undef FC_F +#undef FC_RB } +typedef decltype(kernel_bin_fuse_impl) kernel_bin_fuse_t; + +template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl; +template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl; + kernel void kernel_add_id( constant ggml_metal_kargs_add_id & args, device const char * src0, @@ -1057,7 +1073,7 @@ kernel void kernel_add_id( const size_t nb1 = args.ne0 * sizeof(float); const size_t nb2 = args.ne1 * nb1; - device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); + device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02); device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11); @@ -1098,141 +1114,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; -// assumption: src1 is a row -// broadcast src1 into src0 -template -kernel void kernel_add_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res += ((device const float4 *) (src1 + args.o1[j]))[i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t; - -template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; -template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>; -template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>; -template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>; -template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>; -template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>; -template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>; -template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>; - -template -kernel void kernel_sub_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res -= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t; - -template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; - -template -kernel void kernel_mul_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res *= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t; - -template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; - -template -kernel void kernel_div_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res /= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t; - -template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; - kernel void kernel_scale_f32( constant ggml_metal_kargs_scale & args, device const float * src0, From b0e81c1a2e7d849058940790e123aa67fa8e24eb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Feb 2026 10:38:22 +0200 Subject: [PATCH 113/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 95f370ece52..ff14b73caa8 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -f7cb4b731a38e1f6d24e61c966acc35b0cc31263 +5cecdad692d868e28dbd2f7c468504770108f30c From 4b23ff249e7f93137cb870b28fb27818e074c255 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Feb 2026 10:39:43 +0200 Subject: [PATCH 114/831] talk-llama : sync llama.cpp --- examples/talk-llama/CMakeLists.txt | 2 +- examples/talk-llama/llama-arch.cpp | 134 ++- examples/talk-llama/llama-arch.h | 15 + examples/talk-llama/llama-chat.cpp | 2 +- examples/talk-llama/llama-context.cpp | 9 +- examples/talk-llama/llama-grammar.cpp | 2 +- examples/talk-llama/llama-graph.cpp | 115 ++- examples/talk-llama/llama-graph.h | 29 + examples/talk-llama/llama-hparams.cpp | 14 + examples/talk-llama/llama-hparams.h | 10 +- examples/talk-llama/llama-kv-cache-iswa.cpp | 4 +- examples/talk-llama/llama-kv-cache.cpp | 12 +- .../talk-llama/llama-memory-recurrent.cpp | 22 +- examples/talk-llama/llama-model.cpp | 275 +++++++ examples/talk-llama/llama-model.h | 14 + examples/talk-llama/llama-quant.cpp | 4 +- .../{llama-sampling.cpp => llama-sampler.cpp} | 83 +- .../{llama-sampling.h => llama-sampler.h} | 2 - examples/talk-llama/llama-vocab.cpp | 50 +- examples/talk-llama/models/deepseek2.cpp | 2 +- examples/talk-llama/models/kimi-linear.cpp | 772 ++++++++++++++++++ examples/talk-llama/models/models.h | 31 + examples/talk-llama/models/openelm.cpp | 2 +- examples/talk-llama/models/qwen3next.cpp | 12 +- examples/talk-llama/models/step35-iswa.cpp | 168 ++++ examples/talk-llama/unicode.cpp | 49 +- 26 files changed, 1658 insertions(+), 176 deletions(-) rename examples/talk-llama/{llama-sampling.cpp => llama-sampler.cpp} (98%) rename examples/talk-llama/{llama-sampling.h => llama-sampler.h} (92%) create mode 100644 examples/talk-llama/models/kimi-linear.cpp create mode 100644 examples/talk-llama/models/step35-iswa.cpp diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt index 20caaa99de5..549842a2474 100644 --- a/examples/talk-llama/CMakeLists.txt +++ b/examples/talk-llama/CMakeLists.txt @@ -29,7 +29,7 @@ if (WHISPER_SDL2) llama-model-saver.cpp llama-model.cpp llama-quant.cpp - llama-sampling.cpp + llama-sampler.cpp llama-vocab.cpp unicode.cpp unicode-data.cpp diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index a54bc1956ae..bd78f1e5562 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -117,9 +117,11 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, - { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_STEP35, "step35" }, { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_MAINCODER, "maincoder" }, + { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -161,6 +163,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, { LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, "%s.expert_chunk_feed_forward_length" }, + { LLM_KV_SWIGLU_CLAMP_EXP, "%s.swiglu_clamp_exp" }, + { LLM_KV_SWIGLU_CLAMP_SHEXP, "%s.swiglu_clamp_shexp" }, { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, @@ -219,21 +223,21 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, - { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, - { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, - { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, - { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, - { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, - { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, - { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, - { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, - { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, - { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, - { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, - { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, - { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, - { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, - { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, + { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, { LLM_KV_SPLIT_NO, "split.no" }, { LLM_KV_SPLIT_COUNT, "split.count" }, @@ -246,6 +250,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_KDA_HEAD_DIM, "%s.kda.head_dim" }, + { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, { LLM_KV_POSNET_EMBEDDING_LENGTH, "%s.posnet.embedding_length" }, @@ -371,6 +377,15 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_SSM_CONV1D_Q, "blk.%d.ssm_conv1d_q" }, + { LLM_TENSOR_SSM_CONV1D_K, "blk.%d.ssm_conv1d_k" }, + { LLM_TENSOR_SSM_CONV1D_V, "blk.%d.ssm_conv1d_v" }, + { LLM_TENSOR_SSM_F_A, "blk.%d.ssm_f_a" }, + { LLM_TENSOR_SSM_F_B, "blk.%d.ssm_f_b" }, + { LLM_TENSOR_SSM_BETA, "blk.%d.ssm_beta" }, + { LLM_TENSOR_SSM_G_A, "blk.%d.ssm_g_a" }, + { LLM_TENSOR_SSM_G_B, "blk.%d.ssm_g_b" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, @@ -2267,6 +2282,35 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_UP_EXPS, LLM_TENSOR_FFN_EXP_PROBS_B, }; + case LLM_ARCH_STEP35: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ROPE_FACTORS_LONG, + LLM_TENSOR_ROPE_FACTORS_SHORT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + }; case LLM_ARCH_GPTJ: case LLM_ARCH_UNKNOWN: return { @@ -2289,6 +2333,54 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, }; + case LLM_ARCH_KIMI_LINEAR: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ROPE_FREQS, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + // Dense FFN (layer 0 only) + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + // MoE FFN (layers 1+) + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_EXP_PROBS_B, + // Shared experts + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + // KDA (using SSM_ enum prefix, keeping GGUF names for backward compat) + LLM_TENSOR_SSM_CONV1D_Q, + LLM_TENSOR_SSM_CONV1D_K, + LLM_TENSOR_SSM_CONV1D_V, + LLM_TENSOR_SSM_F_A, + LLM_TENSOR_SSM_F_B, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_G_A, + LLM_TENSOR_SSM_G_B, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_NORM, + // MLA + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_KV_A_NORM, + }; default: GGML_ABORT("unknown architecture for tensor mapping"); } @@ -2392,6 +2484,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // Kimi KDA - Conv tensors are 4D [d_conv, 1, d_inner, 1], reshaped to 2D at runtime + {LLM_TENSOR_SSM_CONV1D_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_F_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_F_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_BETA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_G_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_G_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2573,6 +2674,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_NEMOTRON_H: case LLM_ARCH_NEMOTRON_H_MOE: case LLM_ARCH_QWEN3NEXT: + case LLM_ARCH_KIMI_LINEAR: return true; default: return false; diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 270d28b16a4..e8263369b80 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -122,8 +122,10 @@ enum llm_arch { LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, LLM_ARCH_MIMO2, + LLM_ARCH_STEP35, LLM_ARCH_LLAMA_EMBED, LLM_ARCH_MAINCODER, + LLM_ARCH_KIMI_LINEAR, LLM_ARCH_UNKNOWN, }; @@ -165,6 +167,8 @@ enum llm_kv { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, + LLM_KV_SWIGLU_CLAMP_EXP, + LLM_KV_SWIGLU_CLAMP_SHEXP, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, @@ -250,6 +254,8 @@ enum llm_kv { LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, + LLM_KV_KDA_HEAD_DIM, + LLM_KV_WKV_HEAD_SIZE, LLM_KV_TOKENIZER_MODEL, @@ -398,6 +404,15 @@ enum llm_tensor { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next + // Kimi Linear KDA (using SSM_ prefix for consistency) + LLM_TENSOR_SSM_CONV1D_Q, // kimi: Q conv1d weight + LLM_TENSOR_SSM_CONV1D_K, // kimi: K conv1d weight + LLM_TENSOR_SSM_CONV1D_V, // kimi: V conv1d weight + LLM_TENSOR_SSM_F_A, // kimi: forget gate projection A + LLM_TENSOR_SSM_F_B, // kimi: forget gate projection B + LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient + LLM_TENSOR_SSM_G_A, // kimi: output gate projection A + LLM_TENSOR_SSM_G_B, // kimi: output gate projection B LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W2, diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index 3c7e0afdae8..c415a998f33 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -233,7 +233,7 @@ int32_t llm_chat_apply_template( llm_chat_template tmpl, const std::vector & chat, std::string & dest, bool add_ass) { - // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 + // Taken from the research: https://github.com/ggml-org/llama.cpp/issues/5527 std::stringstream ss; if (tmpl == LLM_CHAT_TEMPLATE_CHATML) { // chatml template diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 10b306a8537..a6df893a311 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -317,6 +317,7 @@ llama_context::llama_context( auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) { // ignore CPU backend + // TODO: should we ignore ACCEL types too? continue; } auto * dev = ggml_backend_get_device(backend.get()); @@ -1026,11 +1027,7 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { llama_sampler_chain_n(sampler) > 0; if (sampler && can_offload) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output()); - auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output()); - if (host_buft) { - buft = host_buft; - } + auto * buft = ggml_backend_dev_buffer_type(model.dev_output()); sampler->iface->backend_init(sampler, buft); @@ -2016,7 +2013,7 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { - if (model.arch == LLM_ARCH_QWEN3NEXT) { + if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR) { return std::max(n_tokens * 40, 32u * model.n_tensors()); } uint32_t res = std::max(1024u, 8u*model.n_tensors()); diff --git a/examples/talk-llama/llama-grammar.cpp b/examples/talk-llama/llama-grammar.cpp index 64ea2fd00a9..2d55070cecc 100644 --- a/examples/talk-llama/llama-grammar.cpp +++ b/examples/talk-llama/llama-grammar.cpp @@ -2,7 +2,7 @@ #include "llama-impl.h" #include "llama-vocab.h" -#include "llama-sampling.h" +#include "llama-sampler.h" #include #include diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 16d42c4ae3d..bba747d37b5 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -13,6 +13,8 @@ #include #include #include +#include +#include #include void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { @@ -533,6 +535,50 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { return res; } +// TODO: Hybrid input classes are a bit redundant. +// Instead of creating a hybrid input, the graph can simply create 2 separate inputs. +// Refactoring is required in the future. +void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) { + mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + + mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + + res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); + res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; +} + void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { const auto * attn_ctx = mctx->get_attn(); @@ -970,6 +1016,26 @@ ggml_tensor * llm_graph_context::build_ffn( switch (type_op) { case LLM_FFN_SILU: if (gate && type_gate == LLM_FFN_PAR) { + // Step35: HF clamps gate (after SiLU) and up before multiplication + if (arch == LLM_ARCH_STEP35 && il >= 0) { + const float limit = hparams.swiglu_clamp_shexp[il]; + constexpr float eps = 1e-6f; + if (limit > eps) { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_silu_clamped", il); + + tmp = ggml_clamp(ctx0, tmp, -limit, limit); + cb(tmp, "ffn_up_clamped", il); + + cur = ggml_mul(ctx0, gate_act, tmp); + cb(cur, "ffn_swiglu_limited", il); + type_gate = LLM_FFN_SEQ; + break; + } + } + cur = ggml_swiglu_split(ctx0, cur, tmp); cb(cur, "ffn_swiglu", il); type_gate = LLM_FFN_SEQ; @@ -1272,6 +1338,25 @@ ggml_tensor * llm_graph_context::build_moe_ffn( switch (type_op) { case LLM_FFN_SILU: if (gate_exps) { + // Step35: per-layer clamp for routed experts + if (arch == LLM_ARCH_STEP35 && il >= 0) { + const float limit = hparams.swiglu_clamp_exp[il]; + constexpr float eps = 1e-6f; + if (limit > eps) { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_moe_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_moe_silu_clamped", il); + + up = ggml_clamp(ctx0, up, -limit, limit); + cb(up, "ffn_moe_up_clamped", il); + + cur = ggml_mul(ctx0, gate_act, up); + cb(cur, "ffn_moe_swiglu_limited", il); + break; + } + } + cur = ggml_swiglu_split(ctx0, cur, up); cb(cur, "ffn_moe_swiglu", il); } else { @@ -2268,6 +2353,17 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const { + const auto * mctx_cur = static_cast(mctx); + + auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr()); + auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); + + auto inp = std::make_unique(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp)); +} + llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const { const auto * mctx_cur = static_cast(mctx); @@ -2419,6 +2515,9 @@ void llm_graph_context::build_sampling() const { return; } + std::array outs; + outs[0] = res->t_logits; + auto inp_sampling = std::make_unique(samplers); res->add_input(std::move(inp_sampling)); @@ -2439,14 +2538,14 @@ void llm_graph_context::build_sampling() const { // add a dummy row of logits // this trick makes the graph static, regardless of which samplers are activated // this is important in order to minimize graph reallocations - // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550) ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0); for (const auto & [seq_id, sampler] : samplers) { const auto it = seq_to_logit_row.find(seq_id); // inactive samplers always work on the first row - const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0; + const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0; + const int i_out = it != seq_to_logit_row.end() ? 1 : 0; ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); ggml_format_name(logits_seq, "logits_seq_%d", seq_id); @@ -2463,22 +2562,26 @@ void llm_graph_context::build_sampling() const { if (data.sampled != nullptr) { res->t_sampled[seq_id] = data.sampled; - ggml_build_forward_expand(gf, data.sampled); + outs[1] = data.sampled; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.probs != nullptr) { res->t_sampled_probs[seq_id] = data.probs; - ggml_build_forward_expand(gf, data.probs); + outs[1] = data.probs; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.logits != nullptr) { res->t_sampled_logits[seq_id] = data.logits; - ggml_build_forward_expand(gf, data.logits); + outs[1] = data.logits; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.candidates != nullptr) { res->t_candidates[seq_id] = data.candidates; - ggml_build_forward_expand(gf, data.candidates); + outs[1] = data.candidates; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } } diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 4090d8116c9..1d69ff1a6fc 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -433,6 +433,34 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i { const llama_memory_hybrid_context * mctx; }; +class llm_graph_input_mem_hybrid_k : public llm_graph_input_i { +public: + llm_graph_input_mem_hybrid_k( + const llama_cparams & cparams, + std::unique_ptr inp_attn, + std::unique_ptr inp_rs, + const llama_memory_hybrid_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + cparams(cparams), + mctx(mctx) { } + virtual ~llm_graph_input_mem_hybrid_k() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + std::unique_ptr inp_attn; + std::unique_ptr inp_rs; + + llm_graph_input_attn_k * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + + const llama_cparams cparams; + + const llama_memory_hybrid_context * mctx; +}; + class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i { public: llm_graph_input_mem_hybrid_iswa( @@ -960,6 +988,7 @@ struct llm_graph_context { // llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; + llm_graph_input_mem_hybrid_k * build_inp_mem_hybrid_k() const; llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const; diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 392f9160cef..756dda1a7ab 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -139,6 +139,13 @@ uint32_t llama_hparams::n_embd_r() const { return n_embd * (n_shortconv_l_cache - 1); } + if (n_embd_head_kda != 0) { + // for Kimi KDA layers + // Conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim + const uint32_t d_inner = n_head() * n_embd_head_kda; // 32 * 128 = 4096 + return 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner; + } + // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // Corresponds to Mamba's conv_states size @@ -151,6 +158,13 @@ uint32_t llama_hparams::n_embd_s() const { return n_embd * wkv_head_size; } + if (n_embd_head_kda != 0) { + // for Kimi KDA layers + // Full recurrent state: head_dim * head_dim * n_head + // h tensor shape for delta attention: [head_dim, head_dim, n_head] + return n_embd_head_kda * n_embd_head_kda * n_head(); // 128 * 128 * 32 = 524288 + } + // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index caed0ec1b76..6c695bdbf66 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -137,6 +137,9 @@ struct llama_hparams { uint32_t ssm_dt_rank = 0; uint32_t ssm_n_group = 0; + // for Kimi Linear KDA + uint32_t n_embd_head_kda = 0; + // for hybrid state space models std::array recurrent_layer_arr; @@ -195,7 +198,7 @@ struct llama_hparams { uint32_t n_deepstack_layers = 0; // needed by encoder-decoder models (e.g. T5, FLAN-T5) - // ref: https://github.com/ggerganov/llama.cpp/pull/8141 + // ref: https://github.com/ggml-org/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; uint32_t dec_n_layer = 0; @@ -203,6 +206,11 @@ struct llama_hparams { enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + + // Step35: optional per-layer clamps for (Swi)GLU + std::array swiglu_clamp_exp; // clamping for expert FFN + std::array swiglu_clamp_shexp; // shared expert + // this value n_pattern means that every nth layer is dense (i.e. non-SWA) // dense_first means whether the pattern is start with a dense layer // note that if n_pattern == 0, all layers are SWA diff --git a/examples/talk-llama/llama-kv-cache-iswa.cpp b/examples/talk-llama/llama-kv-cache-iswa.cpp index 3a34102a23d..26e2cb4270b 100644 --- a/examples/talk-llama/llama-kv-cache-iswa.cpp +++ b/examples/talk-llama/llama-kv-cache-iswa.cpp @@ -218,7 +218,9 @@ llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, } bool llama_kv_cache_iswa::get_can_shift() const { - return kv_base->get_size() == kv_swa->get_size(); + return kv_base->get_can_shift() && + kv_swa->get_can_shift() && + kv_base->get_size() == kv_swa->get_size(); } void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index f3c9b49f30a..cb702b2a59f 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -974,6 +974,10 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & } bool llama_kv_cache::get_can_shift() const { + // Step35 uses per-layer RoPE dims; K-shift assumes a single global n_rot. + if (model.arch == LLM_ARCH_STEP35) { + return false; + } return true; } @@ -1772,8 +1776,6 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t io.write(&v_trans, sizeof(v_trans)); io.write(&n_layer, sizeof(n_layer)); - std::vector tmp_buf; - // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (const auto & layer : layers) { @@ -1791,7 +1793,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); io.write(&k_size_row, sizeof(k_size_row)); - // Read each range of cells of k_size length each into tmp_buf and write out + // Read each range of cells of k_size length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * k_size_row; @@ -1818,7 +1820,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); io.write(&v_size_row, sizeof(v_size_row)); - // Read each range of cells of v_size length each into tmp_buf and write out + // Read each range of cells of v_size length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * v_size_row; @@ -1852,7 +1854,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Read each range of cells of v_size_el length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * kv_size) * v_size_el; diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index 812bf253049..f0038036dcb 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -785,23 +785,21 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: io.write(&s_trans, sizeof(s_trans)); io.write(&n_layer, sizeof(n_layer)); - std::vector tmp_buf; - - // Iterate and write all the keys first, each row is a cell + // Iterate and write all the R tensors first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) if (r_l[il] == nullptr) continue; - // Write key type + // Write R tensor type const int32_t r_type_i = (int32_t)r_l[il]->type; io.write(&r_type_i, sizeof(r_type_i)); - // Write row size of key + // Write row size of R tensor const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); io.write(&r_size_row, sizeof(r_size_row)); - // Read each range of cells of k_size length each into tmp_buf and write out + // Write each range of cells of r_size_row length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * r_size_row; @@ -814,15 +812,15 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) if (s_l[il] == nullptr) continue; - // Write value type + // Write S tensor type const int32_t s_type_i = (int32_t)s_l[il]->type; io.write(&s_type_i, sizeof(s_type_i)); - // Write row size of value + // Write row size of S tensor const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); io.write(&s_size_row, sizeof(s_size_row)); - // Read each range of cells of s_size length each into tmp_buf and write out + // Write each range of S tensor rows for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * s_size_row; @@ -830,7 +828,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: } } } else { - // When v is transposed, we also need the element size and get the element ranges from each row + // When S tensor is transposed, we also need the element size and get the element ranges from each row const uint32_t mem_size = size; for (uint32_t il = 0; il < n_layer; ++il) { // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) @@ -838,7 +836,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_s = hparams.n_embd_s(); - // Write value type + // Write S tensor type const int32_t s_type_i = (int32_t)s_l[il]->type; io.write(&s_type_i, sizeof(s_type_i)); @@ -851,7 +849,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_s; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Write each range of cells of s_size_el length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * mem_size) * s_size_el; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 72490a89b56..674d06c8910 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -125,10 +125,12 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; + case LLM_TYPE_48B_A3B: return "48B.A3B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_102B_A12B: return "102B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; + case LLM_TYPE_196B_A11B: return "196B.A11B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; @@ -559,6 +561,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.swiglu_clamp_exp.begin(), hparams.swiglu_clamp_exp.end(), 0.0f); + std::fill(hparams.swiglu_clamp_shexp.begin(), hparams.swiglu_clamp_shexp.end(), 0.0f); ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -2450,6 +2454,66 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_KIMI_LINEAR: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot); + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); + + // MLA qk_rope_head_dim (for reference) + // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192 + + // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) + // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent + } + + // MoE parameters - Kimi uses moe_intermediate_size = 1024 + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + switch (hparams.n_layer) { + case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_STEP35: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + // MoE + SWA parameters + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Step35 uses sigmoid gating by default (if not set in GGUF) + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); + + switch (hparams.n_layer) { + case 45: type = LLM_TYPE_196B_A11B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -6752,6 +6816,141 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); } } break; + case LLM_ARCH_KIMI_LINEAR: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // Check for KDA specific tensors to determine layer type or if it's a mixed model + // Assuming KDA layer if KDA tensors are present + + // KDA uses head_dim = 128 (from linear_attn_config.head_dim) + const int64_t n_embd_head_k_kda = hparams.n_embd_head_kda; + const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; + const int64_t ssm_d_conv = hparams.ssm_d_conv; + + // Try loading KDA specific tensors (using SSM_ prefix) + // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) + // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_q_conv) { + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, TENSOR_NOT_REQUIRED); + } + + if (layer.ssm_q_conv) { + // KDA Layer - Conv1d weights may be 3D or 4D + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_k_conv) { + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_v_conv) { + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head}, 0); + } + + // q, k, v projections + // Python: q_proj, k_proj, v_proj + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_v_kda * n_head}, 0); + + // KDA specific projections + // f_a_proj, f_b_proj + layer.ssm_f_a = create_tensor(tn(LLM_TENSOR_SSM_F_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); // head_dim + layer.ssm_f_b = create_tensor(tn(LLM_TENSOR_SSM_F_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); // projection_size + + // b_proj (beta mixing coefficient) + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), {n_embd, n_head}, 0); + + // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization) Note: -exp(A_log) is applied in convert_hf_to_gguf.py + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_a) { + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + } + + // dt_bias - shape [n_embd_head_k_kda * n_head] = [4096] + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_embd_head_k_kda * n_head}, 0); + + // g_a_proj, g_b_proj (output gate) + layer.ssm_g_a = create_tensor(tn(LLM_TENSOR_SSM_G_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); + layer.ssm_g_b = create_tensor(tn(LLM_TENSOR_SSM_G_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); + + // o_norm (reusing SSM_NORM) + layer.ssm_o_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {n_embd_head_k_kda}, 0); // FusedRMSNormGated + + // o_proj + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v_kda * n_head, n_embd}, 0); + + } else { + // MLA Layer - use MLA-specific head dimensions + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (layer.attn_q_a_norm) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + // Kimi MLA without Q compression: wq = [n_embd, n_head * n_embd_head_k_mla] + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + // Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA) + // Note: hparams.n_rot may be 72 (from conversion) but actual is 64 + const int64_t qk_rope_head_dim = hparams.n_rot; // From config: qk_rope_head_dim + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0); + // Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled) + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED); + if (!layer.wkv_b) { // MLA KV cache enabled + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // MoE intermediate size (different from dense FFN) + const int64_t n_ff_exp = hparams.n_ff_exp; + + // Kimi uses n_layer_dense_lead to determine which layers use dense FFN vs MoE + // first_k_dense_replace = 1 means layer 0 uses dense FFN, layers 1+ use MoE + if (i < (int) hparams.n_layer_dense_lead) { + // Dense FFN layer - use normal n_ff + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + // MoE layer - use n_ff_exp (1024) instead of n_ff (9216) + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared experts use moe_intermediate_size * num_shared_experts + // Kimi: shared_expert_intermediate_size = 1024 * 1 = 1024 + // Tensors are 2D: [n_embd, n_ff_shexp] or [n_ff_shexp, n_embd] + const int64_t n_ff_shexp_actual = n_ff_exp * (hparams.n_expert_shared > 0 ? hparams.n_expert_shared : 1); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } + } + } break; case LLM_ARCH_COGVLM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6940,6 +7139,72 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_STEP35: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor + // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. + uint32_t n_rot_max = 0; + for (int i = 0; i < n_layer; ++i) { + n_rot_max = std::max(n_rot_max, hparams.n_rot); + } + if (n_rot_max == 0) { + n_rot_max = n_rot; + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + // optional rope factors (llama3) / longrope tensors + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_l}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); + + // head-wise attention gate (Step35 self_attn.g_proj) + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // dense MLP (leading dense blocks) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + // shared expert MLP + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + } + } break; case LLM_ARCH_MAINCODER: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -8086,6 +8351,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_KIMI_LINEAR: + { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_STEP35: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -8235,6 +8508,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_NEMOTRON_H: case LLM_ARCH_NEMOTRON_H_MOE: + case LLM_ARCH_KIMI_LINEAR: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values @@ -8330,6 +8604,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_AFMOE: case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_MIMO2: + case LLM_ARCH_STEP35: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index d1de16e3f28..7b580043b33 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -118,10 +118,12 @@ enum llm_type { LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, + LLM_TYPE_48B_A3B, // Kimi Linear LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, LLM_TYPE_102B_A12B, // Solar-Open LLM_TYPE_106B_A12B, // GLM-4.5-Air + LLM_TYPE_196B_A11B, // Step3.5-Flash LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big @@ -411,6 +413,18 @@ struct llama_layer { struct ggml_tensor * ffn_act_beta = nullptr; struct ggml_tensor * ffn_act_eps = nullptr; + // Kimi Linear KDA (using ssm_ prefix for consistency) + // Note: ssm_dt_b already exists above (mamba bias), reused for Kimi dt_bias + struct ggml_tensor * ssm_q_conv = nullptr; + struct ggml_tensor * ssm_k_conv = nullptr; + struct ggml_tensor * ssm_v_conv = nullptr; + struct ggml_tensor * ssm_f_a = nullptr; + struct ggml_tensor * ssm_f_b = nullptr; + struct ggml_tensor * ssm_beta = nullptr; + struct ggml_tensor * ssm_g_a = nullptr; + struct ggml_tensor * ssm_g_b = nullptr; + struct ggml_tensor * ssm_o_norm = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 776222cb6f2..a7891647c3d 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -787,9 +787,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight"); quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); - // do not quantize Mamba's small yet 2D weights + // do not quantize Mamba /Kimi's small conv1d weights // NOTE: can't use LLM_TN here because the layer number is not known - quantize &= name.find("ssm_conv1d.weight") == std::string::npos; + quantize &= name.find("ssm_conv1d") == std::string::npos; quantize &= name.find("shortconv.conv.weight") == std::string::npos; // do not quantize RWKV's small yet 2D weights diff --git a/examples/talk-llama/llama-sampling.cpp b/examples/talk-llama/llama-sampler.cpp similarity index 98% rename from examples/talk-llama/llama-sampling.cpp rename to examples/talk-llama/llama-sampler.cpp index 5dde513065b..9bbc5dbde24 100644 --- a/examples/talk-llama/llama-sampling.cpp +++ b/examples/talk-llama/llama-sampler.cpp @@ -1,4 +1,4 @@ -#include "llama-sampling.h" +#include "llama-sampler.h" #include "llama-impl.h" #include "llama-vocab.h" @@ -1025,11 +1025,7 @@ struct llama_sampler_dist : public llama_sampler_backend { std::mt19937 rng; - // backend input - struct ggml_tensor * inp_uniform; - - ggml_context_ptr inp_ctx; - ggml_backend_buffer_ptr inp_buf; + ggml_tensor * inp_uniform; }; static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) { @@ -1138,37 +1134,10 @@ static bool llama_sampler_dist_backend_init( ggml_backend_buffer_type_t buft) { auto * sctx = (llama_sampler_dist *) smpl->ctx; - // allocate inputs - { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - sctx->inp_ctx.reset(ggml_init(params)); - - // Create the uniform random scalar input tensor. This will be set by - // llama_sampler_dist_backend_set_input after this graph is built. - sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1); - ggml_set_name (sctx->inp_uniform, "uniform"); - ggml_set_input(sctx->inp_uniform); - - // Allocate all tensors from our context to the backend - sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); - - ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); - } - const bool res = llama_sampler_backend_support(smpl, buft); sctx->init(res); - if (!res) { - sctx->inp_ctx.reset(nullptr); - sctx->inp_buf.reset(nullptr); - } - return res; } @@ -1178,8 +1147,13 @@ static void llama_sampler_dist_backend_apply( struct ggml_cgraph * gf, struct llama_sampler_data * data) { GGML_UNUSED(gf); + auto * sctx = (llama_sampler_dist *) smpl->ctx; + sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + ggml_set_name (sctx->inp_uniform, "uniform"); + ggml_set_input(sctx->inp_uniform); + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); ggml_set_name(probs, "dist_probs"); @@ -1226,6 +1200,7 @@ static void llama_sampler_dist_backend_apply( static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) { auto * sctx = (llama_sampler_dist *) smpl->ctx; + GGML_ASSERT(sctx->inp_uniform != nullptr); // We sample in double precision and cast to float to match rnd numbers of @@ -1262,8 +1237,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { /* .seed_cur = */ seed_cur, /* .rng = */ std::mt19937(seed_cur), /* .inp_uniform = */ nullptr, - /* .inp_ctx = */ nullptr, - /* .inp_buf = */ nullptr, } ); } @@ -3461,9 +3434,6 @@ struct llama_sampler_logit_bias : public llama_sampler_backend { struct ggml_tensor * inp_logit_bias; struct ggml_tensor * inp_logit_idxs; - - ggml_context_ptr inp_ctx; - ggml_backend_buffer_ptr inp_buf; }; static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) { @@ -3526,6 +3496,16 @@ static void llama_sampler_logit_bias_backend_apply( return; } + const size_t n = sctx->logit_bias.size(); + + sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n); + ggml_set_name(sctx->inp_logit_bias, "logit_bias"); + ggml_set_input(sctx->inp_logit_bias); + + sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n); + ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); + ggml_set_input(sctx->inp_logit_idxs); + ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f); cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur)); @@ -3562,6 +3542,8 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm static bool llama_sampler_logit_bias_backend_init( struct llama_sampler * smpl, ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; sctx->init(true); @@ -3570,29 +3552,6 @@ static bool llama_sampler_logit_bias_backend_init( return true; } - ggml_init_params params = { - /*.mem_size =*/ 2*ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - sctx->inp_ctx.reset(ggml_init(params)); - - const size_t n = sctx->logit_bias.size(); - - sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n); - ggml_set_name(sctx->inp_logit_bias, "logit_bias"); - ggml_set_input(sctx->inp_logit_bias); - - sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n); - ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); - ggml_set_input(sctx->inp_logit_idxs); - - // Allocate all tensors from our context to the backend - sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); - - ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); - return true; } @@ -3628,8 +3587,6 @@ struct llama_sampler * llama_sampler_init_logit_bias( /* .to_search = */ {}, /* .inp_logit_bias = */ nullptr, /* .inp_logit_idxs = */ nullptr, - /* .inp_ctx = */ nullptr, - /* .inp_buf = */ nullptr, } ); } diff --git a/examples/talk-llama/llama-sampling.h b/examples/talk-llama/llama-sampler.h similarity index 92% rename from examples/talk-llama/llama-sampling.h rename to examples/talk-llama/llama-sampler.h index 6a963c0bb73..b9bfc20d251 100644 --- a/examples/talk-llama/llama-sampling.h +++ b/examples/talk-llama/llama-sampler.h @@ -1,7 +1,5 @@ #pragma once -// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? - #include "llama.h" #include diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index a23950d007c..6d6bdfa090c 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -90,7 +90,7 @@ static_assert(std::is_trivially_copyable::value, "llm_symbol is not // // SPM tokenizer // original implementation: -// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 +// https://github.com/ggml-org/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 // struct llm_bigram_spm { @@ -285,7 +285,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { // original regex from tokenizer.json //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", - // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989 + // adapted: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2080233989 "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; @@ -1752,26 +1752,33 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // read bpe merges and populate bpe ranks const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + // Kimi-K2 uses custom tokenization without traditional BPE merges + const bool is_kimi_k2 = (tokenizer_pre == "kimi-k2"); + if (merges_keyidx == -1) { - throw std::runtime_error("cannot find tokenizer merges in model file\n"); - } + if (!is_kimi_k2) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + // Kimi-K2 doesn't need merges, skip + LLAMA_LOG_INFO("%s: Kimi-K2 tokenizer detected, skipping BPE merges\n", __func__); + } else { + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); - const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - for (int i = 0; i < n_merges; i++) { - const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); - //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + std::string first; + std::string second; - std::string first; - std::string second; + const size_t pos = word.find(' ', 1); - const size_t pos = word.find(' ', 1); + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } - if (pos != std::string::npos) { - first = word.substr(0, pos); - second = word.substr(pos + 1); + bpe_ranks.emplace(std::make_pair(first, second), i); } - - bpe_ranks.emplace(std::make_pair(first, second), i); } // default special tokens @@ -2226,6 +2233,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|end_of_text|>" // granite || t.first == "" || t.first == "_" + || t.first == "[EOT]" // Kimi-K2 || t.first == "<|end▁of▁sentence|>" // DeepSeek || t.first == "" // smoldocling ) { @@ -2262,6 +2270,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "
"
                         || t.first == "▁
"          // CodeLlama
                         || t.first == "<|code_prefix|>" // GLM-4.5
+                        || t.first == "<|prefix|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_pre_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2282,6 +2291,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
                         || t.first == "<|code_suffix|>" // GLM-4.5
+                        || t.first == "<|suffix|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_suf_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2302,6 +2312,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == "▁"         // CodeLlama
                         || t.first == "<|code_middle|>" // GLM-4.5
+                        || t.first == "<|middle|>"      // Falcon-H1-Tiny-Coder
                         ) {
                     special_fim_mid_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2319,6 +2330,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                         || t.first == ""
                         || t.first == ""   // Granite
                         || t.first == ""
+                        || t.first == "[PAD]" // Kimi-K2
                         ) {
                     special_fim_pad_id = t.second;
                     if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -2390,7 +2402,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
 
         // maintain a list of tokens that cause end-of-generation
         // this is currently determined based on the token text, which is obviously not ideal
-        // ref: https://github.com/ggerganov/llama.cpp/issues/9606
+        // ref: https://github.com/ggml-org/llama.cpp/issues/9606
         special_eog_ids.clear();
 
         if (special_fim_pad_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_pad_id) == 0) {
@@ -2421,6 +2433,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|eom_id|>"
                     || t.first == ""
                     || t.first == "_"
+                    || t.first == "[EOT]" // Kimi-K2
+                    || t.first == "[EOS]" // Kimi-K2
                     || t.first == "<|end_of_text|>"
                     || t.first == "" // smoldocling
                ) {
@@ -3079,7 +3093,7 @@ std::vector llama_vocab::impl::tokenize(
 }
 
 int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const {
-    // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843
+    // ref: https://github.com/ggml-org/llama.cpp/pull/7587#discussion_r1620983843
     static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL;
     const llama_token_attr attr = token_get_attr(token);
     if (!special && (attr & attr_special)) {
diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp
index 297dca51369..987f449934c 100644
--- a/examples/talk-llama/models/deepseek2.cpp
+++ b/examples/talk-llama/models/deepseek2.cpp
@@ -14,7 +14,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
     const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
     // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
-    // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
+    // See https://github.com/ggml-org/llama.cpp/discussions/7416 for detailed explanation.
     // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX]
 
     // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor
diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp
new file mode 100644
index 00000000000..0f037d1a393
--- /dev/null
+++ b/examples/talk-llama/models/kimi-linear.cpp
@@ -0,0 +1,772 @@
+#include "models.h"
+#include "ggml.h"
+
+#define CHUNK_SIZE 64
+
+// Causal Conv1d function for Q,K,V
+// When qkv is 0, it is Q, 1 is K, 2 is V
+static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) {
+    const int64_t d_inner = head_dim * n_head;
+    const int64_t conv_state_size = (d_conv - 1) * d_inner;
+    const int64_t n_embd_r_total = 3 * conv_state_size;  // Q + K + V
+
+    // conv_state_all is [n_embd_r_total, n_seqs], split into Q, K, V
+    // Each conv state is [(d_conv-1) * d_inner] per sequence, need to reshape to [d_conv-1, d_inner, n_seqs]
+    // Memory layout: for each seq, Q state is first conv_state_size elements, then K, then V
+    // conv_state_all has stride: nb[0] = element_size, nb[1] = n_embd_r_total * element_size
+    // View Q conv state: offset 0, size conv_state_size per seq
+    // conv_state_all is [n_embd_r_total, n_seqs] with memory layout:
+    //   state[i + seq * n_embd_r_total] where i = conv_step + channel * (d_conv-1) + {0, conv_state_size, 2*conv_state_size} for Q/K/V
+    // We want [d_conv-1, d_inner, n_seqs] view:
+    //   nb1 = (d_conv-1) * element_size (stride between channels)
+    //   nb2 = n_embd_r_total * element_size (stride between seqs)
+    ggml_tensor * conv_state_x = ggml_view_3d(ctx0, conv_state_all, d_conv - 1, d_inner, n_seqs,
+        (d_conv - 1) * ggml_element_size(conv_state_all),  // nb1: stride between channels
+        n_embd_r_total * ggml_element_size(conv_state_all),  // nb2: stride between seqs
+        qkv * conv_state_size * ggml_element_size(conv_state_all));
+
+// Causal Conv1d function for Q,K,V
+// When qkv is 0, it is Q, 1 is K, 2 is V
+    // Step 1: Q, K, V projections -> [d_inner, n_tokens]
+    ggml_tensor * x_proj = ggml_mul_mat(ctx0, proj_w, x);
+
+    // Reshape input: {d_inner, n_tokens} -> {d_inner, n_seq_tokens, n_seqs}
+    ggml_tensor * x_3d = ggml_reshape_3d(ctx0, x_proj, d_inner, n_seq_tokens, n_seqs);
+
+    // Concat Q conv state and current input: {d_conv-1 + n_seq_tokens, d_inner, n_seqs}
+    ggml_tensor * conv_x = ggml_concat(ctx0, conv_state_x, ggml_transpose(ctx0, x_3d), 0);
+
+    // Save last (d_conv-1) columns back to Q conv state
+    ggml_tensor * last_conv_x = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs,
+        conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]);
+    ggml_build_forward_expand(gf,
+        ggml_cpy(ctx0, last_conv_x,
+            ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs,
+                (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all))));
+    // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner]
+    // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv]
+    // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step]
+    // ggml_ssm_conv computes: c[conv_step + channel * d_conv]
+    // GGUF layout: [d_conv, 1, d_inner] or [d_conv, 1, d_inner, 1] -> reshape to [d_conv, d_inner]
+    // Reshape conv weight from [d_conv, 1, d_inner, 1] to [d_conv, d_inner] for ggml_ssm_conv
+    ggml_tensor * conv_weight = ggml_reshape_2d(ctx0, conv_w, d_conv, d_inner);
+
+    // Apply conv1d
+    // ggml_ssm_conv output: {d_inner, n_seq_tokens, n_seqs}
+    ggml_tensor * Xcur = ggml_ssm_conv(ctx0, conv_x, conv_weight);
+    // Reshape to 2D for bias add: {d_inner, n_tokens}
+    Xcur = ggml_reshape_2d(ctx0, Xcur, d_inner, n_tokens);
+    Xcur = ggml_silu(ctx0, Xcur);
+
+    return ggml_reshape_4d(ctx0, Xcur, head_dim, n_head, n_seq_tokens, n_seqs);
+}
+
+llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) :
+    llm_graph_context_mamba(params), model(model) {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+    cb(inpL, "model.embed_tokens", -1);
+
+    // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
+    // So we don't need inp_pos
+
+    auto * inp_kv = !hparams.is_mla() ? build_inp_mem_hybrid() : nullptr;
+    auto * inp_k = hparams.is_mla() ? build_inp_mem_hybrid_k() : nullptr;
+    auto * inp_rs = hparams.is_mla() ? inp_k->get_recr() : inp_kv->get_recr();
+    auto * inp_attn_kv = !hparams.is_mla() ? inp_kv->get_attn() : nullptr;
+    auto * inp_attn_k = hparams.is_mla() ? inp_k->get_attn() : nullptr;
+
+    // Output ids for selecting which tokens to output
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    ggml_tensor * chunked_causal_mask =
+        ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f),
+                    GGML_TRI_TYPE_LOWER);
+
+    ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f));
+    ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity);
+
+    ggml_build_forward_expand(gf, chunked_causal_mask);
+    ggml_build_forward_expand(gf, chunked_identity);
+    ggml_build_forward_expand(gf, chunked_diag_mask);
+
+    // Kimi dimension constants
+    const int64_t n_head = hparams.n_head();
+    const int64_t head_dim = hparams.n_embd_head_kda;
+    const int64_t d_conv = hparams.ssm_d_conv;
+    const int64_t d_inner = n_head * head_dim;  // 32 * 128 = 4096
+    const int64_t n_seqs = ubatch.n_seqs;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    // Verify batch consistency for recurrent layers
+    GGML_ASSERT(n_seqs != 0);
+    GGML_ASSERT(ubatch.equal_seqs());
+    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
+
+    // MLA params
+    const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
+    const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
+    const int64_t kv_lora_rank = hparams.n_lora_kv;
+    // qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot
+    // Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim]
+    const int64_t n_embd_head_qk_rope = hparams.n_rot;  // config.qk_rope_head_dim
+    const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;  // 192 - 64 = 128
+    // Attention scale for MLA
+    const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla);
+
+    for (int il = 0; il < n_layer; ++il) {
+        const auto & layer = model.layers[il];
+        ggml_tensor * inpSA = inpL;
+
+        // Attention Norm
+        cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // Check layer type by checking which tensors exist
+        // KDA layers have ssm_a_log tensor, MLA layers have wkv_a_mqa tensor
+        bool is_kda = (layer.ssm_a != nullptr);
+        bool is_mla = (layer.wkv_a_mqa != nullptr);
+
+        if (is_kda) {
+            // === KDA Layer (Kimi Delta Attention) with Recurrent State ===
+            // Reference: vLLM kda.py
+            const auto * mctx_cur = inp_rs->mctx;
+            const auto kv_head = mctx_cur->get_head();
+
+            // Get conv states from r_l tensor (Q, K, V each have separate state)
+            ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
+            cb(conv_states_all, "conv_states_all", il);
+            ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs);
+            ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, cur, layer.wq, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+            ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, cur, layer.wk, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+            ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, cur, layer.wv, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head);
+
+            // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias)
+            ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur);
+            ggml_tensor * g1 = ggml_mul_mat(ctx0, layer.ssm_f_b, f_a);
+            cb(g1, "g1 f_b(f_a(cur))", il);
+            g1 = ggml_add(ctx0, g1, layer.ssm_dt_b);
+            g1 = ggml_softplus(ctx0, g1);
+            g1 = ggml_reshape_3d(ctx0, g1, head_dim, n_head, n_tokens);
+
+            // A_log shape is [1, n_head] or [1, n_head, 1, 1], need to broadcast to [head_dim, n_head, n_tokens]. No need to -exp(a_log) because it was done in convert_hf_to_gguf.py
+            // Reshape to [1, n_head, 1] for broadcasting with g1 [head_dim, n_head, n_tokens]
+            ggml_tensor * A = ggml_reshape_3d(ctx0, layer.ssm_a, 1, n_head, 1);
+            g1 = ggml_mul(ctx0, g1, A);
+            cb(g1, "kda_g1", il);
+
+            // Compute beta (mixing coefficient)
+            ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur);
+            beta = ggml_reshape_4d(ctx0, beta, n_head, 1, n_seq_tokens, n_seqs);
+            cb(beta, "kda_beta", il);
+
+            // Reshape for KDA recurrence
+            // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs}
+            cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
+
+            g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs);
+
+            // Get SSM state and compute KDA recurrence using ggml_kda_scan
+            ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
+            ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs);
+            state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs);
+            // Choose between build_kda_chunking and build_kda_recurrent based on n_tokens
+            std::pair attn_out = n_seq_tokens == 1 ?
+                build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) :
+                build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il);
+
+            ggml_tensor * output = attn_out.first;
+            ggml_tensor * new_state = attn_out.second;
+            cb(output, "attn_output", il);
+            cb(new_state, "new_state", il);
+
+            // Update the recurrent states
+            ggml_build_forward_expand(gf,
+                                     ggml_cpy(ctx0, new_state,
+                                              ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
+                                                           kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
+
+            // Output gating g2 = g_b(g_a(x))
+            ggml_tensor * cur_2d = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
+            ggml_tensor * g_a = ggml_mul_mat(ctx0, layer.ssm_g_a, cur_2d);
+            ggml_tensor * g2 = ggml_mul_mat(ctx0, layer.ssm_g_b, g_a);
+            cb(g2, "g2 g_b(g_a(cur_2d))", il);
+            g2 = ggml_reshape_3d(ctx0, g2, head_dim, n_head, n_seq_tokens * n_seqs);
+
+            // Apply o_norm with sigmoid gating
+            // Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish)
+            // Formula: output = RMSNorm(x) * sigmoid(g)
+            ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, output, head_dim, n_head,  n_seq_tokens * n_seqs);
+            ggml_tensor * normed = build_norm(attn_out_final, layer.ssm_o_norm, nullptr, LLM_NORM_RMS, il);
+            cb(normed, "kda_normed", il);
+            ggml_tensor * gate = ggml_sigmoid(ctx0, g2);
+            ggml_tensor * gated = ggml_mul(ctx0, normed, gate);
+
+            // Output projection
+            gated = ggml_cont_2d(ctx0, gated, d_inner, n_tokens);
+            cur = ggml_mul_mat(ctx0, layer.wo, gated);
+            cb(cur, "kda_out", il);
+
+        } else if (is_mla) {
+            // === MLA Layer (Multi-head Latent Attention) without KV Cache ===
+            // Reference: vLLM mla.py
+            // Step 1: Q projection and reshape
+            // vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim]
+            // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM)
+            ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.wq, cur);
+
+            // Step 2: KV compression
+            // kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens]
+            ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur);
+
+            // Split: kv_cmpr = kv_lora[:kv_lora_rank], k_pe = kv_lora[kv_lora_rank:]
+            ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens,
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0);
+            ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens,
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
+                ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
+            // Note: Kimi MLA does NOT apply RoPE (rotary_emb=None in vLLM)
+            // k_pe is used directly without RoPE
+            // Normalize kv_c
+            kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
+
+            if (layer.wk_b && layer.wv_b) { // MLA KV cache enabled
+                // extract q_nope
+                ggml_tensor * q_nope =
+                    ggml_view_3d(ctx0, Qcur, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
+                                 ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, 0);
+                cb(q_nope, "q_nope", il);
+
+                // and {n_embd_head_qk_rope, n_head, n_tokens}
+                ggml_tensor * q_pe = ggml_view_3d(
+                    ctx0, Qcur, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla),
+                    ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, ggml_row_size(Qcur->type, n_embd_head_qk_nope));
+                cb(q_pe, "q_pe", il);
+
+                // {n_embd_head_qk_nope, n_tokens, n_head}
+                q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+                cb(q_nope, "q_nope_perm", il);
+
+                // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
+                ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, layer.wk_b, q_nope);
+                cb(q_nope_absorbed, "q_nope_absorbed", il);
+
+                // {kv_lora_rank, n_head, n_tokens}
+                q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
+                cb(q_nope_absorbed, "q_nope_absorbed_perm", il);
+
+                // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
+                // note: rope must go first for in-place context shifting in build_rope_shift()
+                Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
+                cb(Qcur, "Qcur", il);
+
+                kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
+                cb(kv_cmpr, "kv_cmpr_reshape", il);
+
+                // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
+                ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
+                cb(Kcur, "Kcur", il);
+
+                // {kv_lora_rank, 1, n_tokens}
+                ggml_tensor * Vcur = kv_cmpr;
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn_k, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il);
+                cb(cur, "mla_out", il);
+            } else { // MLA KV cache disabled. Fall back to MHA KV cache.
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens);
+                cb(Qcur, "mla_Q", il);
+                // KV decompression: kv = kv_b_proj(kv_c_normed)
+                ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr);
+                const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla;
+
+                // Split kv into k_nope and v
+                ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
+                    ggml_row_size(kv->type, kv_per_head),
+                    ggml_row_size(kv->type, kv_per_head * n_head), 0);
+                ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens,
+                    ggml_row_size(kv->type, kv_per_head),
+                    ggml_row_size(kv->type, kv_per_head * n_head),
+                    ggml_row_size(kv->type, n_embd_head_qk_nope));
+                Vcur = ggml_cont(ctx0, Vcur);
+                cb(Vcur, "mla_V", il);
+
+                // Concatenate k_nope + k_pe (broadcast k_pe to all heads)
+                // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens]
+                // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads
+                // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens]
+                ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens);
+                ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target);
+                ggml_tensor * Kcur = ggml_concat(ctx0, k_pe_repeated, k_nope, 0);
+                cb(Kcur, "mla_K", il);
+
+                // Direct softmax attention (with MHA KV cache)
+                // Use build_attn with inp_attn for proper mask handling
+                cur = build_attn(inp_attn_kv, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il);
+                cb(cur, "mla_out", il);
+            }
+        } else {
+            // Unknown layer type - this should not happen
+            GGML_ABORT("Kimi layer is neither KDA nor MLA - missing required tensors");
+        }
+
+        // On last layer, select only the output tokens
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        // Residual
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // FFN Norm
+        cur = build_norm(ffn_inp, layer.ffn_norm, NULL, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        if ((uint32_t) il < hparams.n_layer_dense_lead) {
+            // Dense FFN layer
+            cur = build_ffn(cur,
+                layer.ffn_up, NULL, NULL,
+                layer.ffn_gate, NULL, NULL,
+                layer.ffn_down, NULL, NULL,
+                NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE layer
+            // Kimi uses moe_renormalize=True and routed_scaling_factor (stored as expert_weights_scale) = 2.446
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                layer.ffn_gate_inp,
+                layer.ffn_up_exps,
+                layer.ffn_gate_exps,
+                layer.ffn_down_exps,
+                layer.ffn_exp_probs_b,
+                hparams.n_expert,
+                hparams.n_expert_used,
+                LLM_FFN_SILU, true,
+                true, hparams.expert_weights_scale,
+                (llama_expert_gating_func_type) hparams.expert_gating_func,
+                il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // Shared expert
+            {
+                ggml_tensor * ffn_shexp = build_ffn(cur,
+                        layer.ffn_up_shexp, NULL, NULL,
+                        layer.ffn_gate_shexp, NULL, NULL,
+                        layer.ffn_down_shexp, NULL, NULL,
+                        NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(ffn_shexp, "ffn_shexp", il);
+
+                cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                cb(cur, "ffn_out", il);
+            }
+        }
+        // Residual
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        inpL = cur;
+    }
+    cur = inpL;
+
+    // Final Norm
+    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // Output
+    cur = ggml_mul_mat(ctx0, model.output, cur);
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
+
+/*
+    This is a ggml implementation of the naive_chunk_kda function of
+    https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py
+*/
+std::pair llm_build_kimi_linear::build_kda_chunking(
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * gk,
+        ggml_tensor * beta,
+        ggml_tensor * state,
+        ggml_tensor * causal_mask,
+        ggml_tensor * identity,
+        ggml_tensor * diag_mask,
+        int           il) {
+    GGML_ASSERT(ggml_is_contiguous(state));
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(gk->ne[0] == S_v && gk->ne[1] == H_v && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+
+    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
+
+    // TODO: can this ever be false?
+    const bool use_qk_l2norm = true;
+
+    if (use_qk_l2norm) {
+        const float eps_norm = hparams.f_norm_rms_eps;
+
+        q = ggml_l2_norm(ctx0, q, eps_norm);
+        k = ggml_l2_norm(ctx0, k, eps_norm);
+    }
+
+    const float scale = 1.0f / sqrtf(S_v);
+
+    beta = ggml_sigmoid(ctx0, beta);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(beta, "beta_in", il);
+    cb(gk, "gk_in", il);
+
+    q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs);
+    v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+    gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs);
+
+    beta  = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3));
+    state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs);
+
+    cb(q, "q_perm", il);
+    cb(k, "k_perm", il);
+    cb(v, "v_perm", il);
+    cb(beta, "beta_perm", il);
+    cb(gk, "gk_perm", il);
+    cb(state, "state_in", il);
+
+    GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs);
+    GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs);
+
+    // Do padding
+    const int64_t chunk_size = CHUNK_SIZE;
+
+    const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size;
+    const int64_t n_chunks = (n_tokens + pad) / chunk_size;
+
+    q = ggml_pad(ctx0, q, 0, pad, 0, 0);
+    k = ggml_pad(ctx0, k, 0, pad, 0, 0);
+    v = ggml_pad(ctx0, v, 0, pad, 0, 0);
+    gk = ggml_pad(ctx0, gk, 0, pad, 0, 0);
+    beta = ggml_pad(ctx0, beta, 0, pad, 0, 0);
+
+    cb(q, "q_pad", il);
+    cb(k, "k_pad", il);
+    cb(v, "v_pad", il);
+    cb(beta, "beta_pad", il);
+    cb(gk, "gk_pad", il);
+
+    ggml_tensor * v_beta = ggml_mul(ctx0, v, beta);
+    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta);
+
+    cb(v_beta, "v_beta", il);
+    cb(k_beta, "k_beta", il);
+
+    const int64_t HB = H_k * n_seqs;
+
+    q      = ggml_cont_4d(ctx0, q,      S_k, chunk_size, n_chunks, HB);
+    k      = ggml_cont_4d(ctx0, k,      S_k, chunk_size, n_chunks, HB);
+    k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, HB);
+    v      = ggml_cont_4d(ctx0, v,      S_v, chunk_size, n_chunks, HB);
+    v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, HB);
+
+    gk    = ggml_cont_4d(ctx0, gk, S_k, chunk_size, n_chunks, HB);
+    beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, HB);
+
+    // switch for cumsum
+    gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 1, 0, 2, 3), chunk_size, S_k, n_chunks, HB);
+    cb(gk, "gk", il);
+    ggml_tensor * gk_cumsum = ggml_cumsum(ctx0, gk);
+    cb(gk_cumsum, "gk_cumsum", il);
+
+/*
+    Compute Akk and Aqk loop together
+    Akk loop:
+    for i in range(BT):
+        k_i = k[..., i, :] # k_i [B,H,NT,S]
+        g_i = g[..., i:i+1, :] # g_i [B,H,NT,1,S]
+        A[..., i] = torch.einsum('... c d, ... d -> ... c', k * (g - g_i).exp(), k_i)
+    Aqk loop:
+    for j in range(BT):
+        k_j = k[:, :, i, j]
+        g_j = g[:, :, i, j:j+1, :]
+        A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j)
+*/
+    const int64_t CHB = n_chunks * H_k * n_seqs;
+    ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB);  // [chunk_size, 1, S_k, CHB]
+    ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB);  // [1, chunk_size, S_k, CHB]
+
+    ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB);  // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB]
+    // decay_mask [chunk_size,chunk_size,S_k,CHB]
+    ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i);
+    cb(decay_mask, "decay_mask", il);
+
+    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+    cb(decay_mask, "decay_masked", il);
+    decay_mask = ggml_exp(ctx0, decay_mask);
+    decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
+
+    // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched
+    decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB);
+
+    ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB);
+    ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB);
+    ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB);
+
+    ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i);
+    ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i);
+
+    // decay_k_i [S.BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB]
+    ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j);
+    ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j);
+    Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB)));
+    Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB)));
+    cb(Akk, "Akk", il);
+    cb(Aqk, "Aqk", il);
+
+    Akk = ggml_mul(ctx0, Akk, beta);
+    Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask));
+    cb(Akk, "attn_pre_solve", il);
+
+    Aqk = ggml_mul(ctx0, Aqk, diag_mask);
+    Aqk = ggml_scale(ctx0, Aqk, scale); // scale q
+    cb(Aqk, "Aqk_masked", il);
+
+    // for i in range(1, chunk_size):
+    //          row = attn[..., i, :i].clone()
+    //          sub = attn[..., :i, :i].clone()
+    //          attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+    // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
+    //
+    // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A)
+    ggml_tensor * attn_lower = ggml_mul(ctx0, Akk, causal_mask);
+    ggml_tensor * lhs        = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
+
+    ggml_tensor * lin_solve  = ggml_solve_tri(ctx0, lhs, Akk, true, true, false);
+    Akk                      = ggml_mul(ctx0, lin_solve, causal_mask);
+    Akk                      = ggml_add(ctx0, Akk, identity);
+
+    cb(Akk, "attn_solved", il);
+
+    // switch back for downstream
+    gk_cumsum = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3), S_k, chunk_size, n_chunks, HB);
+    ggml_tensor * gkexp      = ggml_exp(ctx0, gk_cumsum);
+    cb(gk_cumsum, "gk_cumsum", il);
+
+    // u = (A*beta[..., None, :]) @ v  aka U_[t]
+    ggml_tensor * vb = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), Akk);
+
+    ggml_tensor * kbeta_gkexp = ggml_mul(ctx0, k_beta, gkexp);
+    cb(kbeta_gkexp, "kbeta_gkexp", il);
+
+    ggml_tensor * k_cumdecay = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gkexp)), Akk);
+    cb(k_cumdecay, "k_cumdecay", il);
+
+    ggml_tensor * core_attn_out = nullptr;
+    ggml_tensor * new_state = ggml_dup(ctx0, state);
+
+    cb(new_state, "new_state", il);
+
+    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
+// extract one chunk worth of data
+        auto chunkify = [=](ggml_tensor * t) {
+                    return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
+                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
+        };
+        auto chunkify_A = [=](ggml_tensor * t) {
+                    return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, chunk_size, 1, t->ne[3],
+                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
+        };
+
+
+// k [S,BT,NT,H*B] => k_chunk [S,BT,1,H*B]
+        ggml_tensor * k_chunk = chunkify(k);
+        ggml_tensor * q_chunk = chunkify(q);
+        ggml_tensor * vb_chunk = chunkify(vb);
+
+// gk_cumsum [S,BT,NT,H*B] => gk_cs_chunk [S,BT,1,H*B]
+        ggml_tensor * gk_cs_chunk = chunkify(gk_cumsum);
+        ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
+        ggml_tensor * gkexp_chunk = ggml_exp(ctx0, gk_cs_chunk);
+        ggml_tensor * Aqk_chunk = chunkify_A(Aqk);
+
+        ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
+
+        // new_state [S,S,1,H*B] k_cumdecay_chunk [S,BT,1,H*B]
+        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state or W_[t] @ S_[t]
+        ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
+
+        // v_new = v_i - v_prime or U_[t] - W_[t]*S_[t]
+        ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, vb_chunk, v_prime), v_prime);
+        ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
+
+        // q_chunk [S,BT,1,H*B] gkexp_chunk [S,BT,1,H*B]
+        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
+        // or Gamma_[t]*Q_]t] @ S
+        ggml_tensor * q_gk_exp   = ggml_mul(ctx0, q_chunk, gkexp_chunk);
+        ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_gk_exp);
+        attn_inter = ggml_scale(ctx0, attn_inter, scale); // scale q
+
+        // v_new_t [S,BT,1,H*B] Aqk [BT,BT,1,H*B]
+        // core_attn_out[:, :, i] = attn_inter + attn @ v_new or A' @ (U_[t] - W_[t]*S_[t])
+        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, Aqk_chunk);
+
+        // o[:, :, i] = (q_i * g_i.exp()) @ S + A @ v_i
+        ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
+
+        core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
+
+        ggml_tensor * gk_cum_last =
+            ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cs_chunk, gk_cs_chunk->ne[0], 1, gk_cs_chunk->ne[2], gk_cs_chunk->ne[3],
+                                        gk_cs_chunk->nb[1], gk_cs_chunk->nb[2], gk_cs_chunk->nb[3],
+                                        gk_cs_chunk->nb[1] * (gk_cs_chunk->ne[1] - 1)));
+
+        ggml_tensor * gkexp_last = ggml_exp(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, gk_cum_last)));
+
+        ggml_tensor * gk_diff = ggml_neg(ctx0, ggml_sub(ctx0, gk_cs_chunk, gk_cum_last));
+
+        ggml_tensor * gk_diff_exp = ggml_exp(ctx0, gk_diff);
+
+        ggml_tensor * key_gkdiff = ggml_mul(ctx0, k_chunk, gk_diff_exp);
+
+        // rearrange((g_i[:,:,-1:] - g_i).exp()*k_i, 'b h c k -> b h k c') @ (U_[t] - W_[t] @ S)
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gkdiff)));
+
+        new_state = ggml_add(ctx0,
+            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gkexp_last, gkexp_last->ne[0], gkexp_last->ne[1], H_v, n_seqs)),
+            ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
+    }
+
+    core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
+
+    // truncate padded tokens
+    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
+            S_v, n_tokens, H_v, n_seqs,
+            ggml_row_size(core_attn_out->type, S_v),
+            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
+            ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
+    output_tokens = ggml_cont(ctx0, output_tokens);
+    // permute back to (S_v, H_v, n_tokens, n_seqs)
+    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
+    output_tokens = ggml_cont(ctx0, output_tokens);
+
+    cb(new_state, "output_state", il);
+
+    return {output_tokens, new_state};
+}
+
+std::pair llm_build_kimi_linear::build_kda_autoregressive(
+    ggml_tensor * q,
+    ggml_tensor * k,
+    ggml_tensor * v,
+    ggml_tensor * gk,
+    ggml_tensor * beta,
+    ggml_tensor * state,
+    int il) {
+    GGML_ASSERT(ggml_is_contiguous(v));
+    GGML_ASSERT(ggml_is_contiguous(gk));
+
+    const int64_t S_k      = q->ne[0];
+    const int64_t H_k      = q->ne[1];
+    const int64_t n_tokens = q->ne[2];
+    const int64_t n_seqs   = q->ne[3];
+
+    const int64_t S_v = v->ne[0];
+    const int64_t H_v = v->ne[1];
+
+    GGML_ASSERT(n_tokens == 1);
+    GGML_ASSERT(v->ne[2] == n_tokens);
+    GGML_ASSERT(k->ne[2] == n_tokens);
+    GGML_ASSERT(gk->ne[0] == S_k && gk->ne[1] == H_k && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs);
+    GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs);
+    GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_k && state->ne[2] == H_v && state->ne[3] == n_seqs);
+
+    GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs);
+    GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs);
+
+    GGML_ASSERT(H_k == H_v);  // we did a repeat to make sure this is the case
+
+    const float eps_norm = hparams.f_norm_rms_eps;
+
+    q = ggml_l2_norm(ctx0, q, eps_norm);
+    k = ggml_l2_norm(ctx0, k, eps_norm);
+
+    const float scale = 1.0f / sqrtf(S_v);
+
+    q    = ggml_scale(ctx0, q, scale);
+    beta = ggml_sigmoid(ctx0, beta);
+
+    cb(q, "q_in", il);
+    cb(k, "k_in", il);
+    cb(v, "v_in", il);
+    cb(beta, "beta_in", il);
+    cb(gk, "gk_in", il);
+
+// g [H,1,B,1] g_t [1,H,B,1] => [1,1,H,B]
+// gk [S,H,1,B] => [S,1,H,B] gk_t [1,S,H,B]
+// beta [H,1,1,B] beta_t [1,H,1,B] => [1,1,H,B]
+    gk = ggml_reshape_4d(ctx0, gk, S_k, 1, H_k, n_seqs);
+    ggml_tensor * gk_t = ggml_cont(ctx0, ggml_transpose(ctx0, gk));
+    ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs);
+
+    // Apply exponential to gk_t
+    gk_t = ggml_exp(ctx0, gk_t);
+    // Apply the gated delta rule for the single timestep
+    // last_recurrent_state = last_recurrent_state * gk_t
+    // S = S * g_i[..., None].exp()
+    state = ggml_mul(ctx0, state, gk_t);
+
+    ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+
+// state [S,S,H,B] k [S,1,H,B] k_state [S_v,1,H,B]
+    k = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs);
+    ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k);
+
+    // v_i - (k_i[..., None] * S).sum(-2)
+    v = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs);
+    ggml_tensor * v_diff = ggml_sub(ctx0, v, k_state);
+
+    // b_i[..., None] * k_i
+    ggml_tensor * k_beta = ggml_mul(ctx0, k, beta_t);
+
+    // S = S + torch.einsum('b h k, b h v -> b h k v', b_i[..., None] * k_i, v_i - (k_i[..., None] * S).sum(-2))
+    // v_diff_t [1,S_v,H,B] k_beta_t [1,S_k,H,B] state [S_v,S_k,H,B]
+    state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta))));
+
+    q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs);
+    state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state));
+    ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q);
+    // core_attn_out should be [S_v, 1, H_v, n_seqs] after this
+    cb(core_attn_out, "output_tokens", il);
+    cb(state, "new_state", il);
+
+    return {core_attn_out, state};
+}
+
diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h
index 3a44f7f140f..cfcbb9aaa5b 100644
--- a/examples/talk-llama/models/models.h
+++ b/examples/talk-llama/models/models.h
@@ -288,6 +288,33 @@ struct llm_build_jamba : public llm_graph_context_mamba {
     llm_build_jamba(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_kimi_linear : public llm_graph_context_mamba {
+    llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
+
+    std::pair build_kda_autoregressive(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * gk,
+                ggml_tensor * beta,
+                ggml_tensor * state,
+                        int   il);
+
+    std::pair build_kda_chunking(
+                ggml_tensor * q,
+                ggml_tensor * k,
+                ggml_tensor * v,
+                ggml_tensor * gk,
+                ggml_tensor * beta,
+                ggml_tensor * state,
+                ggml_tensor * causal_mask,
+                ggml_tensor * identity,
+                ggml_tensor * diag_mask,
+                        int   il);
+
+    const llama_model & model;
+};
+
 struct llm_build_lfm2 : public llm_graph_context {
     const llama_model & model;
 
@@ -556,6 +583,10 @@ struct llm_build_starcoder : public llm_graph_context {
     llm_build_starcoder(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_step35_iswa : public llm_graph_context {
+    llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_t5_dec : public llm_graph_context {
     llm_build_t5_dec(const llama_model & model, const llm_graph_params & params);
 };
diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp
index ee46a3375e8..fbf682ec835 100644
--- a/examples/talk-llama/models/openelm.cpp
+++ b/examples/talk-llama/models/openelm.cpp
@@ -43,7 +43,7 @@ llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_
             ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head);
             cb(Kcur, "Kcur", il);
 
-            ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)));
+            ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv));
             cb(Vcur, "Vcur", il);
 
             Qcur = build_norm(Qcur,
diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp
index 57b6659baf0..99b1a76a485 100644
--- a/examples/talk-llama/models/qwen3next.cpp
+++ b/examples/talk-llama/models/qwen3next.cpp
@@ -265,9 +265,15 @@ std::pair llm_build_qwen3next::build_delta_net_chu
     cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
 
     ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp);
+    ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
+                                                 1, chunk_size, n_chunks, g_diff_exp->ne[3]);
+
+    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
     cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
 
+    ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
+    cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)
+
 
     // state to be updated per chunk
     ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
@@ -322,9 +328,9 @@ std::pair llm_build_qwen3next::build_delta_net_chu
             : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
 
         // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-        ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk));
+        ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
         //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
 
         // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
         ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
diff --git a/examples/talk-llama/models/step35-iswa.cpp b/examples/talk-llama/models/step35-iswa.cpp
new file mode 100644
index 00000000000..f8737815a67
--- /dev/null
+++ b/examples/talk-llama/models/step35-iswa.cpp
@@ -0,0 +1,168 @@
+#include "models.h"
+
+llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+    ggml_tensor * inp_pos     = build_inp_pos();
+    auto        * inp_attn    = build_attn_inp_kv_iswa();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        const uint32_t n_head_l    = hparams.n_head(il);
+        const uint32_t n_head_kv_l = hparams.n_head_kv(il);
+
+        const float freq_base_l  = model.get_rope_freq_base(cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+        cur = inpL;
+
+        // dump pre-attn RMSNorm input to pinpoint layer boundary issues
+        cb(cur, "attn_norm_in", il);
+
+        // self-attention
+        {
+            cur = build_norm(cur, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens);
+
+            // Q/K per-head RMSNorm (Step35 q_norm / k_norm)
+            if (model.layers[il].attn_q_norm) {
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+            }
+            if (model.layers[il].attn_k_norm) {
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+            }
+
+            // RoPE (partial rotary factors per layer)
+            const bool is_swa = hparams.is_swa(il);
+            ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il);
+            const int64_t n_rot_l = is_swa ? hparams.n_rot : (hparams.n_rot / 2);
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, rope_factors,
+                n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, rope_factors,
+                n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+            );
+            cb(Qcur, "Qcur_pos", il);
+            cb(Kcur, "Kcur_pos", il);
+
+            const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k));
+            ggml_tensor * attn_out = build_attn(inp_attn,
+                    nullptr, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+            cb(attn_out, "attn_out", il);
+            // head-wise attention gate: sigmoid(g_proj(x)) in torch
+            if (model.layers[il].wqkv_gate) {
+                ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, cur); // [n_head_l, n_tokens]
+                cb(gate, "attn_gate", il);
+
+                gate = ggml_sigmoid(ctx0, gate);
+                cb(gate, "attn_gate_sigmoid", il);
+
+                // reshape + broadcast to [n_embd_head_v, n_head_l, n_tokens]
+                ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens);
+                ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate,       1,          n_head_l, n_tokens);
+                cb(gate_3d, "attn_gate_3d", il);
+
+                attn_3d = ggml_mul(ctx0, attn_3d, gate_3d);
+                cb(attn_3d, "attn_gated_3d", il);
+
+                attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens);
+                cb(attn_out, "attn_gated", il);
+            }
+
+            // output projection
+            cur = build_lora_mm(model.layers[il].wo, attn_out);
+            cb(cur, "attn_proj", il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        cur = build_norm(ffn_inp, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        // feed-forward
+        if (model.layers[il].ffn_gate_inp == nullptr) {
+            // dense MLP
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   nullptr,
+                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr,
+                    nullptr,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE routed experts
+            const bool  norm_w  = hparams.expert_weights_norm;
+            const float w_scale = hparams.expert_weights_scale;
+            const bool  scale_w = w_scale != 0.0f;
+            ggml_tensor * moe_out = build_moe_ffn(cur,
+                    model.layers[il].ffn_gate_inp,
+                    model.layers[il].ffn_up_exps,
+                    model.layers[il].ffn_gate_exps,
+                    model.layers[il].ffn_down_exps,
+                    model.layers[il].ffn_exp_probs_b,
+                    n_expert, n_expert_used,
+                    LLM_FFN_SILU,
+                    norm_w, scale_w, w_scale,
+                    (llama_expert_gating_func_type) hparams.expert_gating_func,
+                    il);
+            cb(moe_out, "ffn_moe_out", il);
+
+            // shared expert MLP (always added on MoE layers in Step35)
+            ggml_tensor * sh_out = build_ffn(cur,
+                    model.layers[il].ffn_up_shexp,   nullptr, nullptr,
+                    model.layers[il].ffn_gate_shexp, nullptr, nullptr,
+                    model.layers[il].ffn_down_shexp, nullptr, nullptr,
+                    nullptr,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(sh_out, "ffn_shared_out", il);
+
+            cur = ggml_add(ctx0, moe_out, sh_out);
+            cb(cur, "ffn_out", il);
+        }
+        cur = ggml_add(ctx0, cur, ffn_inp);
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1);
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    cur = build_lora_mm(model.output, cur);
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp
index b47dcbe6198..adfc489d1f0 100644
--- a/examples/talk-llama/unicode.cpp
+++ b/examples/talk-llama/unicode.cpp
@@ -497,49 +497,26 @@ static std::vector unicode_regex_split_custom_llama3(const std::string &
     return bpe_offsets;
 }
 
-// use std::wregex to split the text
-static std::vector unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector & offsets) {
-    std::wregex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs);
-    std::vector bpe_offsets; // store the offset of each word
-    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
-    size_t start = 0;
-    for (auto offset : offsets) {
-        std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr);
-        std::wcregex_iterator end;
-
-        int64_t start_idx = 0;
-        while (it != end) {
-            std::wcmatch match = *it;
-            if (match.position() > start_idx) {
-                bpe_offsets.emplace_back(match.position() - start_idx);
-            }
-            bpe_offsets.emplace_back(match.length());
-            start_idx = match.position() + match.length();
-            ++it;
-        }
-
-        if (start_idx < (int64_t) offset) {
-            bpe_offsets.emplace_back(offset - start_idx);
-        }
-        start += offset;
-    }
-
-    return bpe_offsets;
-}
-
-// use std::regex to split the text
-static std::vector unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector & offsets) {
-    std::regex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs);
+template 
+static std::vector unicode_regex_split_stl(const std::basic_string & text, const std::basic_string & regex, const std::vector & offsets) {
+    using BidirIt = typename std::basic_string::const_iterator;
+#ifdef _MSC_VER
+    // Bypass bug in MSVC: https://github.com/ggml-org/llama.cpp/issues/17830
+    constexpr auto regex_flags = std::regex_constants::ECMAScript;
+#else
+    constexpr auto regex_flags = std::regex_constants::optimize | std::regex_constants::nosubs;
+#endif
+    std::basic_regex expr(regex, regex_flags);
     std::vector bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
     size_t start = 0;
     for (auto offset : offsets) {
-        std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr);
-        std::cregex_iterator end;
+        std::regex_iterator it(text.begin() + start, text.begin() + start + offset, expr);
+        std::regex_iterator end;
 
         int64_t start_idx = 0;
         while (it != end) {
-            std::cmatch match = *it;
+            std::match_results match = *it;
             if (match.position() > start_idx) {
                 bpe_offsets.emplace_back(match.position() - start_idx);
             }

From 193f7cdaaf9144599dcb8ab5017fd0d3fd8ad828 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov 
Date: Mon, 9 Feb 2026 09:59:22 +0200
Subject: [PATCH 115/831] ci : try fix mirrors (#3655)

---
 .github/workflows/build.yml | 20 --------------------
 1 file changed, 20 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 5c1cf93ba2a..823dba7d573 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -174,10 +174,6 @@ jobs:
             sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
             sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
 
-            apt-get update
-            apt-get install -y ca-certificates
-            sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
-
             apt update
             apt install -y build-essential libsdl2-dev cmake git
             cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a
@@ -210,10 +206,6 @@ jobs:
             sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
             sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
 
-            apt-get update
-            apt-get install -y ca-certificates
-            sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
-
             apt update
             apt install -y build-essential libsdl2-dev cmake git
             cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp
@@ -338,10 +330,6 @@ jobs:
             sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
             sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
 
-            apt-get update
-            apt-get install -y ca-certificates
-            sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
-
             apt update
             apt install -y build-essential cmake libsdl2-dev git
             cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a
@@ -376,10 +364,6 @@ jobs:
             sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
             sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
 
-            apt-get update
-            apt-get install -y ca-certificates
-            sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
-
             apt update
             apt install -y build-essential cmake libsdl2-dev git
             cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp
@@ -417,10 +401,6 @@ jobs:
             sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
             sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list
 
-            apt-get update
-            apt-get install -y ca-certificates
-            sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list
-
             apt update
             apt install -y clang build-essential cmake libsdl2-dev git
             cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang

From eb27fa2252ee757a1ddf4fe55d10dbe4b24c3a97 Mon Sep 17 00:00:00 2001
From: Sid Mohan <61345237+sidmohan0@users.noreply.github.com>
Date: Mon, 9 Feb 2026 00:10:13 -0800
Subject: [PATCH 116/831] server : fix hardcoded /inference path in default
 HTML page (#3639)

Closes #3596
---
 examples/server/server.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index b77d8a3ed46..c5354efc314 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -748,9 +748,9 @@ int main(int argc, char ** argv) {
     
         

Whisper.cpp Server

-

/inference

+

)" + sparams.request_path + sparams.inference_path + R"(

-    curl 127.0.0.1:)" + std::to_string(sparams.port) + R"(/inference \
+    curl 127.0.0.1:)" + std::to_string(sparams.port) + sparams.request_path + sparams.inference_path + R"( \
     -H "Content-Type: multipart/form-data" \
     -F file="@<file-path>" \
     -F temperature="0.0" \
@@ -767,7 +767,7 @@ int main(int argc, char ** argv) {
 
         

Try it out

-
+
From 525be69a66788499af010cff1548ee33f284547f Mon Sep 17 00:00:00 2001 From: Christian Kastner Date: Mon, 9 Feb 2026 11:32:18 +0100 Subject: [PATCH 117/831] cmake: Drop obsolete build-time configuration of backends (#3649) The backend configuration now happens in ggml. This updated configuration mirrors that of llama.cpp. --- cmake/whisper-config.cmake.in | 45 ++++------------------------------- 1 file changed, 5 insertions(+), 40 deletions(-) diff --git a/cmake/whisper-config.cmake.in b/cmake/whisper-config.cmake.in index 6a3fa22701f..b70c1e5af44 100644 --- a/cmake/whisper-config.cmake.in +++ b/cmake/whisper-config.cmake.in @@ -3,60 +3,25 @@ set(WHISPER_BUILD_COMMIT @WHISPER_BUILD_COMMIT@) set(WHISPER_BUILD_NUMBER @WHISPER_BUILD_NUMBER@) set(WHISPER_SHARED_LIB @BUILD_SHARED_LIBS@) -set(GGML_BLAS @GGML_BLAS@) -set(GGML_CUDA @GGML_CUDA@) -set(GGML_METAL @GGML_METAL@) -set(GGML_HIPBLAS @GGML_HIPBLAS@) -set(GGML_ACCELERATE @GGML_ACCELERATE@) - @PACKAGE_INIT@ set_and_check(WHISPER_INCLUDE_DIR "@PACKAGE_WHISPER_INCLUDE_INSTALL_DIR@") set_and_check(WHISPER_LIB_DIR "@PACKAGE_WHISPER_LIB_INSTALL_DIR@") set_and_check(WHISPER_BIN_DIR "@PACKAGE_WHISPER_BIN_INSTALL_DIR@") -# Ensure transient dependencies satisfied - -find_package(Threads REQUIRED) - -if (APPLE AND GGML_ACCELERATE) - find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED) -endif() - -if (GGML_BLAS) - find_package(BLAS REQUIRED) -endif() - -if (GGML_CUDA) - find_package(CUDAToolkit REQUIRED) -endif() - -if (GGML_METAL) - find_library(FOUNDATION_LIBRARY Foundation REQUIRED) - find_library(METAL_FRAMEWORK Metal REQUIRED) - find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) -endif() - -if (GGML_HIPBLAS) - find_package(hip REQUIRED) - find_package(hipblas REQUIRED) - find_package(rocblas REQUIRED) -endif() +find_package(ggml REQUIRED HINTS ${LLAMA_LIB_DIR}/cmake) find_library(whisper_LIBRARY whisper REQUIRED - HINTS ${WHISPER_LIB_DIR}) - -set(_whisper_link_deps "Threads::Threads" "@WHISPER_EXTRA_LIBS@") -set(_whisper_transient_defines "@WHISPER_TRANSIENT_DEFINES@") + HINTS ${WHISPER_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH +) add_library(whisper UNKNOWN IMPORTED) - set_target_properties(whisper PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${WHISPER_INCLUDE_DIR}" - INTERFACE_LINK_LIBRARIES "${_whisper_link_deps}" - INTERFACE_COMPILE_DEFINITIONS "${_whisper_transient_defines}" + INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;" IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" IMPORTED_LOCATION "${whisper_LIBRARY}" INTERFACE_COMPILE_FEATURES cxx_std_11 From 052066c4f760eb1041db8b3f4e93d17728d21af5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A1draic=20Slattery?= Date: Mon, 9 Feb 2026 11:32:46 +0100 Subject: [PATCH 118/831] chore: Update outdated GitHub Actions versions (#3646) --- .github/workflows/bindings-go.yml | 4 +- .github/workflows/bindings-ruby.yml | 2 +- .github/workflows/build.yml | 118 ++++++++++++++-------------- .github/workflows/docker.yml | 4 +- .github/workflows/examples-wasm.yml | 6 +- .github/workflows/examples.yml | 4 +- close-issue.yml | 2 +- 7 files changed, 70 insertions(+), 70 deletions(-) diff --git a/.github/workflows/bindings-go.yml b/.github/workflows/bindings-go.yml index ff420f2b636..83473e4636a 100644 --- a/.github/workflows/bindings-go.yml +++ b/.github/workflows/bindings-go.yml @@ -13,10 +13,10 @@ jobs: ubuntu-22: runs-on: ubuntu-22.04 steps: - - uses: actions/setup-go@v5 + - uses: actions/setup-go@v6 with: go-version: '^1.23' - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - run: | cd bindings/go make test diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml index 680862fb764..c3f158e26e4 100644 --- a/.github/workflows/bindings-ruby.yml +++ b/.github/workflows/bindings-ruby.yml @@ -17,5 +17,5 @@ jobs: - uses: ruby/setup-ruby@v1 with: ruby-version: '3.2' - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - run: rake test diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 823dba7d573..8ce887fd111 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -67,7 +67,7 @@ jobs: steps: - name: Checkout with full history - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -127,7 +127,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -159,7 +159,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -191,7 +191,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -223,7 +223,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: hendrikmuhs/ccache-action@v1.2.16 @@ -255,7 +255,7 @@ jobs: # # steps: # - name: Clone -# uses: actions/checkout@v4 +# uses: actions/checkout@v6 # # - name: Build # uses: cross-platform-actions/action@v0.27.0 @@ -281,7 +281,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -315,7 +315,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -349,7 +349,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -386,7 +386,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -420,7 +420,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -460,7 +460,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: add oneAPI to apt shell: bash @@ -484,7 +484,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Build id: cmake_build @@ -512,7 +512,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: add oneAPI to apt shell: bash @@ -536,7 +536,7 @@ jobs: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Build id: cmake_build @@ -561,7 +561,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup ${{ matrix.sys }} uses: msys2/setup-msys2@v2 @@ -616,7 +616,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Add msbuild to PATH uses: microsoft/setup-msbuild@v2 @@ -646,31 +646,31 @@ jobs: - name: Upload SDL2.dll if: matrix.sdl2 == 'ON' - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: ${{ matrix.s2arc }}_SDL2.dll path: build/bin/${{ matrix.build }}/SDL2.dll - name: Upload whisper dll - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: whisper_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/whisper.dll - name: Upload ggml dll - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: ggml_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml.dll - name: Upload ggml base dll - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: ggml_base_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml-base.dll - name: Upload ggml cpu dll - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: ggml_cpu_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml-cpu.dll @@ -682,7 +682,7 @@ jobs: - name: Upload binaries if: matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: whisper-bin-${{ matrix.arch }}.zip path: whisper-bin-${{ matrix.arch }}.zip @@ -711,10 +711,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: script: | core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); @@ -768,7 +768,7 @@ jobs: - name: Upload binaries if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: whisper-blas-bin-${{ matrix.arch }}.zip path: whisper-blas-bin-${{ matrix.arch }}.zip @@ -792,7 +792,7 @@ jobs: sdl2_ver: 2.28.5 steps: - name: Clone repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Install Ninja id: install_ninja @@ -977,7 +977,7 @@ jobs: - name: Upload binaries if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip path: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip @@ -993,7 +993,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup emsdk uses: mymindstorm/setup-emsdk@v14 @@ -1016,7 +1016,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Configure run: | @@ -1058,7 +1058,7 @@ jobs: - name: Upload artifacts if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip @@ -1070,12 +1070,12 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: path: whisper - name: Install Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: zulu java-version: 21 @@ -1099,10 +1099,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: set up JDK 11 - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: java-version: '11' distribution: 'temurin' @@ -1125,36 +1125,36 @@ jobs: needs: ['windows'] runs-on: windows-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Install Java - uses: actions/setup-java@v4 + uses: actions/setup-java@v5 with: distribution: zulu java-version: 20 - name: Download Whisper Windows lib - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: whisper_x64.dll - name: Download GGML Windows lib - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: ggml_x64.dll - name: Download GGML Base Windows lib - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: ggml_base_x64.dll - name: Download GGML CPU Windows lib - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: ggml_cpu_x64.dll - name: Download SDL2.dll - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: x64_SDL2.dll @@ -1201,7 +1201,7 @@ jobs: Compress-Archive -Path "bindings/java/build/libs/whispercpp-*.jar" -DestinationPath "whispercpp.jar.zip" - name: Upload jar - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v6 with: name: whispercpp.jar.zip path: whispercpp.jar.zip @@ -1225,7 +1225,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test quantize run: | @@ -1249,7 +1249,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: fetch-depth: 0 @@ -1262,7 +1262,7 @@ jobs: # Downloads all the artifacts from the previous jobs - name: Download artifacts id: download-artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: path: ./artifact @@ -1312,7 +1312,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set environment variables id: set_vars @@ -1338,7 +1338,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Build shell: bash @@ -1358,7 +1358,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1383,7 +1383,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1408,7 +1408,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1433,7 +1433,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1458,7 +1458,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: ccache uses: ggml-org/ccache-action@v1.2.16 @@ -1483,7 +1483,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1497,7 +1497,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1511,7 +1511,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1525,7 +1525,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1538,7 +1538,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci @@ -1551,7 +1551,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Test id: ggml-ci diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 0e2fb1f2b9e..57f062e9f7c 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -25,7 +25,7 @@ jobs: steps: - name: Check out the repo - uses: actions/checkout@v3 + uses: actions/checkout@v6 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -67,7 +67,7 @@ jobs: echo "tags=$TAGS" >> $GITHUB_OUTPUT - name: Build and push Docker image (tagged) - uses: docker/build-push-action@v5 + uses: docker/build-push-action@v6 with: context: . push: ${{ github.event_name == 'push' }} diff --git a/.github/workflows/examples-wasm.yml b/.github/workflows/examples-wasm.yml index ebbbdfe20ca..927438cdad8 100644 --- a/.github/workflows/examples-wasm.yml +++ b/.github/workflows/examples-wasm.yml @@ -22,10 +22,10 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup Pages - uses: actions/configure-pages@v4 + uses: actions/configure-pages@v5 - name: Setup emsdk uses: mymindstorm/setup-emsdk@v14 @@ -88,7 +88,7 @@ jobs: find staging -type f | sort - name: Upload artifact - uses: actions/upload-pages-artifact@v3 + uses: actions/upload-pages-artifact@v4 with: path: ./staging diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 74ef8e0faae..1c9ade5a300 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -17,7 +17,7 @@ jobs: node-version: [ 16.x, 18.x ] steps: - name: Clone - uses: actions/checkout@v1 + uses: actions/checkout@v6 - name: Dependencies run: | @@ -27,7 +27,7 @@ jobs: sudo apt-get install libsdl2-dev - name: Use Node.js ${{ matrix.node-version }} - uses: actions/setup-node@v1 + uses: actions/setup-node@v6 with: node-version: ${{ matrix.node-version }} cache: 'npm' diff --git a/close-issue.yml b/close-issue.yml index 276a217d450..f661de1cd45 100644 --- a/close-issue.yml +++ b/close-issue.yml @@ -15,7 +15,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v5 + - uses: actions/stale@v10 with: exempt-issue-labels: "refactor,help wanted,good first issue,research,bug,roadmap" days-before-issue-stale: 30 From 764482c3175d9c3bc6089c1ec84df7d1b9537d83 Mon Sep 17 00:00:00 2001 From: Nuno Date: Mon, 9 Feb 2026 11:33:06 +0100 Subject: [PATCH 119/831] ci: add vulkan docker image (#3644) Signed-off-by: rare-magma --- .devops/main-vulkan.Dockerfile | 20 ++++++++++++++++++++ .github/workflows/docker.yml | 1 + README.md | 15 ++++++++++++++- 3 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 .devops/main-vulkan.Dockerfile diff --git a/.devops/main-vulkan.Dockerfile b/.devops/main-vulkan.Dockerfile new file mode 100644 index 00000000000..2be22e4d53b --- /dev/null +++ b/.devops/main-vulkan.Dockerfile @@ -0,0 +1,20 @@ +FROM ubuntu:24.04 AS build +WORKDIR /app + +RUN apt-get update && \ + apt-get install -y build-essential wget cmake git libvulkan-dev glslc \ + && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* + +COPY .. . +RUN make base.en CMAKE_ARGS="-DGGML_VULKAN=1" + +FROM ubuntu:24.04 AS runtime +WORKDIR /app + +RUN apt-get update && \ + apt-get install -y curl ffmpeg libsdl2-dev wget cmake git libvulkan1 mesa-vulkan-drivers \ + && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* + +COPY --from=build /app /app +ENV PATH=/app/build/bin:$PATH +ENTRYPOINT [ "bash", "-c" ] diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 57f062e9f7c..6c0de0ece70 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -22,6 +22,7 @@ jobs: - { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64" } - { tag: "main-intel", dockerfile: ".devops/main-intel.Dockerfile", platform: "linux/amd64" } - { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" } + - { tag: "main-vulkan", dockerfile: ".devops/main-vulkan.Dockerfile", platform: "linux/amd64" } steps: - name: Check out the repo diff --git a/README.md b/README.md index 6d4988e6fa5..c0d8edb99bc 100644 --- a/README.md +++ b/README.md @@ -443,11 +443,12 @@ ffmpeg -i samples/jfk.wav jfk.opus ### Images -We have two Docker images available for this project: +We have multiple Docker images available for this project: 1. `ghcr.io/ggml-org/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`) 2. `ghcr.io/ggml-org/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`) 3. `ghcr.io/ggml-org/whisper.cpp:main-musa`: Same as `main` but compiled with MUSA support. (platforms: `linux/amd64`) +4. `ghcr.io/ggml-org/whisper.cpp:main-vulkan`: Same as `main` but compiled with Vulkan support. (platforms: `linux/amd64`) ### Usage @@ -456,15 +457,27 @@ We have two Docker images available for this project: docker run -it --rm \ -v path/to/models:/models \ whisper.cpp:main "./models/download-ggml-model.sh base /models" + # transcribe an audio file docker run -it --rm \ -v path/to/models:/models \ -v path/to/audios:/audios \ whisper.cpp:main "whisper-cli -m /models/ggml-base.bin -f /audios/jfk.wav" + # transcribe an audio file in samples folder docker run -it --rm \ -v path/to/models:/models \ whisper.cpp:main "whisper-cli -m /models/ggml-base.bin -f ./samples/jfk.wav" + +# run the web server +docker run -it --rm -p "8080:8080" \ + -v path/to/models:/models \ + whisper.cpp:main "whisper-server --host 127.0.0.1 -m /models/ggml-base.bin" + +# run the bench too on the small.en model using 4 threads +docker run -it --rm \ + -v path/to/models:/models \ + whisper.cpp:main "whisper-bench -m /models/ggml-small.en.bin -t 4" ``` ## Installing with Conan From 808904277e765c607e4c8ee029508e00a5a1ab1a Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Sun, 8 Feb 2026 14:12:51 +0100 Subject: [PATCH 120/831] CUDA: Fix non-contig rope (llama/19338) * Rename variables + fix rope_neox Seems memory layout is shared with Vulkan so we can port fix from https://github.com/ggml-org/llama.cpp/pull/19299 * Fix rope_multi * Fix rope_vision * Fix rope_norm * Rename ne* to ne0* for consistent variable naming * cont : consistent stride names --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cuda/rope.cu | 366 +++++++++++++++++++++++-------------- 1 file changed, 233 insertions(+), 133 deletions(-) diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 88ed79111a1..45a49a5dc2a 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -43,10 +43,15 @@ static __device__ void rope_yarn( template static __global__ void rope_norm(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int32_t * pos, const float freq_scale, @@ -59,23 +64,23 @@ static __global__ void rope_norm(const T * x, const int set_rows_stride) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; - - int idst = row_dst * ne0 + i0; - const int ix = channel_x*s2 + row_x*s1 + i0; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; + int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03; // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. if (set_rows_stride != 0) { - idst = row_x * ne0 + i0; - idst += row_indices[channel_x] * set_rows_stride; + idst = i1 * s1 + i0; + idst += row_indices[i2] * set_rows_stride; } const auto & store_coaelsced = [&](float x0, float x1) { @@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x, return; } - const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -110,10 +115,15 @@ static __global__ void rope_norm(const T * x, template static __global__ void rope_neox(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int32_t * pos, const float freq_scale, @@ -126,23 +136,24 @@ static __global__ void rope_neox(const T * x, const int set_rows_stride) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - int idst = row_dst * ne0 + i0 / 2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. if (set_rows_stride != 0) { - idst = row_x * ne0 + i0 / 2; - idst += row_indices[channel_x] * set_rows_stride; + idst = i1 * s1 + i0 / 2; + idst += row_indices[i2] * set_rows_stride; } if (i0 >= n_dims) { @@ -152,7 +163,7 @@ static __global__ void rope_neox(const T * x, return; } - const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -168,24 +179,42 @@ static __global__ void rope_neox(const T * x, dst[idst + n_dims / 2] = ggml_cuda_cast(x0 * sin_theta + x1 * cos_theta); } -template -static __global__ void rope_multi( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, - const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) { - const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - - if (i0 >= ne0) { +template +static __global__ void rope_multi(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const mrope_sections sections, + const bool is_imrope) { + const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y); + + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; if (i0 >= n_dims) { dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; @@ -200,27 +229,24 @@ static __global__ void rope_multi( float theta_base = 0.0; if (is_imrope) { - if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h - theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); - } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w - theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); - } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h + theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w + theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t + theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f); } else { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f); } } else { if (sector < sections.v[0]) { - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sections.v[0] && sector < sec_w) { - theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f); } } @@ -238,37 +264,53 @@ static __global__ void rope_multi( dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; } -template -static __global__ void rope_vision( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, - const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, - const float theta_scale, const float * freq_factors, const mrope_sections sections) { +template +static __global__ void rope_vision(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; const int sect_dims = sections.v[0] + sections.v[1]; - const int sec_w = sections.v[1] + sections.v[0]; - const int sector = (i0 / 2) % sect_dims; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0; if (sector < sections.v[0]) { const int p = sector; - theta_base = pos[channel_x]*powf(theta_scale, p); - } - else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[i2] * powf(theta_scale, p); + } else if (sector >= sections.v[0] && sector < sec_w) { const int p = sector - sections.v[0]; - theta_base = pos[channel_x + ne2]*powf(theta_scale, p); + theta_base = pos[i2 + ne02] * powf(theta_scale, p); } const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -288,10 +330,15 @@ static __global__ void rope_vision( template static void rope_norm_cuda(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int nr, const int32_t * pos, @@ -304,31 +351,36 @@ static void rope_norm_cuda(const T * x, const int64_t * row_indices, const int set_rows_stride, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_norm<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { rope_norm<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } } template static void rope_neox_cuda(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int nr, const int32_t * pos, @@ -341,55 +393,92 @@ static void rope_neox_cuda(const T * x, const int64_t * row_indices, const int set_rows_stride, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_neox<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { rope_neox<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } } -template -static void rope_multi_cuda( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); +template +static void rope_multi_cuda(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const mrope_sections sections, + const bool is_imrope, + cudaStream_t stream) { + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_multi<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } else { rope_multi<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } } -template -static void rope_vision_cuda( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); +template +static void rope_vision_cuda(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const mrope_sections sections, + cudaStream_t stream) { + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq) // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE); @@ -398,11 +487,11 @@ static void rope_vision_cuda( if (freq_factors == nullptr) { rope_vision<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections); } else { rope_vision<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections); } } @@ -445,6 +534,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, const size_t s01 = src0->nb[1] / ggml_type_size(src0->type); const size_t s02 = src0->nb[2] / ggml_type_size(src0->type); + const size_t s03 = src0->nb[3] / ggml_type_size(src0->type); + + const size_t s1 = dst->nb[1] / ggml_type_size(dst->type); + const size_t s2 = dst->nb[2] / ggml_type_size(dst->type); + const size_t s3 = dst->nb[3] / ggml_type_size(dst->type); //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; @@ -495,57 +589,63 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, // compute if (is_neox) { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { - rope_neox_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { - rope_neox_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { - rope_neox_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else { GGML_ABORT("fatal error"); } } else if (is_mrope && !is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_multi_cuda( - (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); + rope_multi_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, is_imrope, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_multi_cuda( - (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); + rope_multi_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, is_imrope, stream); } else { GGML_ABORT("fatal error"); } } else if (is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_vision_cuda( - (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + rope_vision_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_vision_cuda( - (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + rope_vision_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, stream); } else { GGML_ABORT("fatal error"); } } else { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { - rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { - rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { - rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else { GGML_ABORT("fatal error"); } From a36210c8362b8e01bbcbd7e6fa89c68b90068974 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 10 Feb 2026 08:07:16 +0200 Subject: [PATCH 121/831] cuda : extend GGML_OP_PAD to work with non-cont src0 (llama/19429) * cuda : extend GGML_OP_PAD to work with non-cont src0 * tests : add permuted pad --- ggml/src/ggml-cpu/ops.cpp | 3 +-- ggml/src/ggml-cuda/ggml-cuda.cu | 3 ++- ggml/src/ggml-cuda/pad.cu | 23 +++++++++++++---------- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index ce15b18ce0e..ed45350207e 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7629,8 +7629,7 @@ static void ggml_compute_forward_pad_f32( const ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_ASSERT( dst->nb[0] == sizeof(float)); + assert(dst->nb[0] == sizeof(float)); const int ith = params->ith; const int nth = params->nth; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 9e77c231c85..b163468789f 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4834,8 +4834,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_GROUP_NORM: - case GGML_OP_PAD: return ggml_is_contiguous(op->src[0]); + case GGML_OP_PAD: + return true; case GGML_OP_UPSCALE: case GGML_OP_PAD_REFLECT_1D: case GGML_OP_ARANGE: diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu index 660c192e48a..31cd00f7781 100644 --- a/ggml/src/ggml-cuda/pad.cu +++ b/ggml/src/ggml-cuda/pad.cu @@ -7,7 +7,7 @@ __device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) { return (coord + size) % size; } -static __global__ void pad_f32(const float * src, float * dst, +static __global__ void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst, const int lp0, const int rp0, const int lp1, const int rp1, const int lp2, const int rp2, const int lp3, const int rp3, const int ne0, const int ne1, const int ne2, const int ne3, @@ -34,11 +34,8 @@ static __global__ void pad_f32(const float * src, float * dst, const int64_t i01 = i1 - lp1; const int64_t i02 = i2 - lp2; const int64_t i03 = i3 - lp3; - const int64_t ne02 = ne2 - lp2 - rp2; - const int64_t ne01 = ne1 - lp1 - rp1; - const int64_t ne00 = ne0 - lp0 - rp0; - const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00; + const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; dst[dst_idx] = src[src_idx]; } else { @@ -57,21 +54,21 @@ static __global__ void pad_f32(const float * src, float * dst, const int64_t i02 = wrap_around(i2 - lp2, ne02); const int64_t i03 = wrap_around(i3 - lp3, ne03); - const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00; + const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; dst[dst_idx] = src[src_idx]; } } -static void pad_f32_cuda(const float * src, float * dst, +static void pad_f32_cuda(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst, const int lp0, const int rp0, const int lp1, const int rp1, const int lp2, const int rp2, const int lp3, const int rp3, const int ne0, const int ne1, const int ne2, const int ne3, const bool circular, cudaStream_t stream) { int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; dim3 gridDim(num_blocks, ne1, ne2 * ne3); - pad_f32<<>>(src, dst, + pad_f32<<>>(src, s00, s01, s02, s03, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3, circular); } @@ -82,9 +79,10 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); + GGML_TENSOR_UNARY_OP_LOCALS; + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); const int32_t lp0 = ((const int32_t *) (dst->op_params))[0]; const int32_t rp0 = ((const int32_t *) (dst->op_params))[1]; @@ -96,7 +94,12 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int32_t rp3 = ((const int32_t *) (dst->op_params))[7]; const int32_t circular = ((const int32_t *) (dst->op_params))[8]; - pad_f32_cuda(src0_d, dst_d, + const size_t s00 = nb00 / ggml_type_size(src0->type); + const size_t s01 = nb01 / ggml_type_size(src0->type); + const size_t s02 = nb02 / ggml_type_size(src0->type); + const size_t s03 = nb03 / ggml_type_size(src0->type); + + pad_f32_cuda(src0_d, s00, s01, s02, s03, dst_d, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (bool) circular, stream); From 6a74f56212cf576a7212f21f6706eb10ce5e0c18 Mon Sep 17 00:00:00 2001 From: hipudding Date: Tue, 10 Feb 2026 14:18:59 +0800 Subject: [PATCH 122/831] CANN: implement quantized MUL_MAT_ID for MoE models (llama/19228) Implement ggml_cann_mul_mat_id_quant function to support quantized matrix multiplication for Mixture of Experts (MoE) architectures on CANN backend. Key features: - Support Q4_0 and Q8_0 quantized weight formats - Use IndexSelect to dynamically route expert-specific weights based on indices - Leverage WeightQuantBatchMatmulV2 for efficient quantized computation - Handle automatic F16 type conversion for hardware compatibility - Support both per-expert and broadcast input modes Implementation details: - Extract expert weights and scales using CANN IndexSelect operation - Process each batch and expert combination independently - Create proper tensor views with correct stride for matmul operations - Automatic input/output type casting to/from F16 as needed Testing: All test cases passed for supported types (F32, F16, Q4_0, Q8_0). --- ggml/src/ggml-cann/aclnn_ops.cpp | 315 ++++++++++++++++++++----------- 1 file changed, 204 insertions(+), 111 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 87ac05748e8..fc7c3e3b724 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -3286,130 +3286,223 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor } /** - * @brief Performs expert-specific matrix multiplication (MoE) with - * quantized precision using the CANN backend. - * - * This function executes a matrix multiplication operation tailored for - * Mixture of Experts (MoE) models, where the input tensor is multiplied - * with expert-specific quantized weight matrices. It leverages the CANN - * backend to perform efficient low-precision computations and stores the - * quantized result in the destination tensor `dst`. - * - * Quantization techniques reduce memory footprint and improve performance - * by using lower-bit representations (e.g., int8) instead of floating-point. - * This function is designed to work with such formats and may incorporate - * optimizations like identity-based fast paths or routing masks for sparse - * expert selection. - * - * @param ctx The context for executing CANN backend operations. - * @param dst The destination tensor where the quantized MoE multiplication result - * will be stored. - * - * @note This function assumes quantized data types and is designed for - * MoE architectures with potential sparse expert routing. + * @brief Performs quantized matrix multiplication for Mixture of Experts (MoE) + * models using the CANN backend. + * + * This function implements MUL_MAT_ID operation for quantized weight matrices + * (Q4_0 and Q8_0 formats). It selects expert-specific weight matrices based on + * the provided expert indices, and computes matrix multiplication using CANN's + * WeightQuantBatchMatmulV2 operator. + * + * The function performs the following steps: + * 1. Converts input/output tensors to F16 format if necessary + * 2. Uses IndexSelect to extract expert-specific weights and scales based on indices + * 3. Performs quantized matrix multiplication for each expert using WeightQuantBatchMatmulV2 + * 4. Converts output back to the target type if needed + * + * Tensor shapes: + * - dst: [M, K, N, 1] - output tensor + * - src0: [D, M, A, 1] - quantized weight matrices (Q4_0 or Q8_0) + * - src1: [D, B, N, 1] - input activations (B = K for per-expert input, or B = 1 for broadcast) + * - ids: [K, N] - expert indices for routing + * + * @param ctx The CANN backend context for operation execution. + * @param dst The destination tensor where the multiplication result will be stored. + * + * @note Only Q4_0 and Q8_0 quantization formats are supported. + * @note The function handles automatic type conversion to/from F16 as needed by the hardware. */ static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - // TODO: Use aclnnGroupedMatMul - //dst [M, K, N, 1] - ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] - ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 - ggml_tensor * ids = dst->src[2]; //ids [K, N] + // dst: [M, K, N, 1] + // src0: [D, M, A, 1] - quantized weights + // src1: [D, B, N, 1] - input activations, B = K or B = 1 + // ids: [K, N] - expert indices + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + ggml_tensor * ids = dst->src[2]; - GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(src0->ne[3] == 1); + GGML_ASSERT(src1->ne[3] == 1); + GGML_ASSERT(dst->ne[3] == 1); + GGML_ASSERT(src1->ne[2] == ids->ne[1]); + + const int64_t n_batches = ids->ne[1]; + const int64_t n_select_experts = ids->ne[0]; + const enum ggml_type type = src0->type; + + const int32_t group_size = QK8_0; // Both Q4_0 and Q8_0 use group size of 32 + GGML_ASSERT(group_size == QK4_0); + + // Calculate element size for quantized weights + const float weight_elem_size = + (type == GGML_TYPE_Q4_0) ? 0.5f : + (type == GGML_TYPE_Q8_0) ? 1.0f : + (GGML_ABORT("MUL_MAT_ID only supports Q4_0 and Q8_0"), 0.0f); + + // Calculate scale offset in memory + const size_t weight_size = src0->ne[0] * src0->ne[1] * src0->ne[2] * weight_elem_size; + const size_t scale_elem_size = sizeof(uint16_t); + char * scale_data = (char *) src0->data + weight_size; + + // Allocate buffers for selected expert weights and scales + const size_t selected_weight_size = src0->ne[0] * src0->ne[1] * n_select_experts * weight_elem_size; + ggml_cann_pool_alloc selected_weight_alloc(ctx.pool(), selected_weight_size); + void * selected_weight_buffer = selected_weight_alloc.get(); + + const size_t selected_scale_size = (src0->ne[0] / group_size) * src0->ne[1] * n_select_experts * scale_elem_size; + ggml_cann_pool_alloc selected_scale_alloc(ctx.pool(), selected_scale_size); + void * selected_scale_buffer = selected_scale_alloc.get(); + + // Helper lambda to allocate and cast tensor to F16 if needed + constexpr size_t f16_elem_size = sizeof(uint16_t); + auto prepare_f16_buffer = [&](ggml_tensor * tensor, ggml_cann_pool_alloc & allocator, + bool need_cast = false) -> void * { + if (tensor->type == GGML_TYPE_F16) { + return tensor->data; + } - // copy index from npu to cpu - int64_t n_as = ne02; // A - int64_t n_ids = ids->ne[0]; // K + size_t total_size = f16_elem_size; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + total_size *= tensor->ne[i]; + } + void * buffer = allocator.alloc(total_size); - std::vector ids_host(ggml_nbytes(ids)); - ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids->data, ggml_nbytes(ids), - ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream())); - ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); + if (need_cast == false) { + return buffer; + } - char * src0_original = (char *) src0->data; - char * src1_original = (char *) src1->data; - char * dst_original = (char *) dst->data; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS] = { f16_elem_size }; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + ne[i] = tensor->ne[i]; + if (i > 0) { + nb[i] = nb[i - 1] * ne[i - 1]; + } + } - ggml_tensor src0_row = *src0; - ggml_tensor src1_row = *src1; - ggml_tensor dst_row = *dst; + acl_tensor_ptr src_tensor = ggml_cann_create_tensor(tensor); + acl_tensor_ptr f16_tensor = ggml_cann_create_tensor(buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS); + aclnn_cast(ctx, src_tensor.get(), f16_tensor.get(), ACL_FLOAT16); - const enum ggml_type type = dst->src[0]->type; - float weight_elem_size; - if (type == GGML_TYPE_Q4_0) { - weight_elem_size = float(sizeof(uint8_t)) / 2; - } else if (type == GGML_TYPE_Q8_0) { - weight_elem_size = float(sizeof(uint8_t)); - } else { - GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 "); - } + return buffer; + }; - // src0_row [D, M, 1, 1] weight without permute - src0_row.ne[2] = 1; - src0_row.ne[3] = 1; - src0_row.nb[0] = weight_elem_size; - src0_row.nb[1] = weight_elem_size * ne00; - src0_row.nb[2] = weight_elem_size * ne00; - src0_row.nb[3] = weight_elem_size * ne00; - size_t weight_stride = ne00 * ne01 * weight_elem_size; - size_t weight_size = weight_stride * ne02 * ne03; + // Prepare input and output buffers + ggml_cann_pool_alloc input_alloc(ctx.pool()); + void * input_buffer = prepare_f16_buffer(src1, input_alloc, true); - // scale [D, M, 1, 1] -> scale && permute - size_t scale_elem_size = sizeof(uint16_t); - size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size; + ggml_cann_pool_alloc output_alloc(ctx.pool()); + void * output_buffer = prepare_f16_buffer(dst, output_alloc, false); + + // Process each batch + for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) { + // Create index tensor for current batch + const size_t index_offset = batch_idx * ids->nb[1]; + acl_tensor_ptr batch_indices = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, index_offset); + + // Select quantized weights using expert indices + // Q4_0 stores 2 values per byte, Q8_0 stores 1 value per byte + const int64_t weight_d = (type == GGML_TYPE_Q4_0) ? src0->ne[0] / 2 : src0->ne[0]; + const int64_t weight_m = src0->ne[1]; + const int64_t weight_n_experts = src0->ne[2]; + + int64_t weight_ne[3] = { weight_d, weight_m, weight_n_experts }; + size_t weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), weight_d * weight_m * sizeof(int8_t) }; + + acl_tensor_ptr all_weights = + ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, 3); + + int64_t selected_weight_ne[3] = { weight_d, weight_m, n_select_experts }; + size_t selected_weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), + weight_d * weight_m * sizeof(int8_t) }; + + acl_tensor_ptr selected_weights = ggml_cann_create_tensor(selected_weight_buffer, ACL_INT8, sizeof(int8_t), + selected_weight_ne, selected_weight_nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_weights.get(), 0, batch_indices.get(), selected_weights.get()); + + // Select scales using the same expert indices + const int64_t scale_d = src0->ne[0] / group_size; + int64_t scale_ne[3] = { scale_d, weight_m, weight_n_experts }; + size_t scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, scale_d * weight_m * scale_elem_size }; + + acl_tensor_ptr all_scales = + ggml_cann_create_tensor(scale_data, ACL_FLOAT16, scale_elem_size, scale_ne, scale_nb, 3); + + int64_t selected_scale_ne[3] = { scale_d, weight_m, n_select_experts }; + size_t selected_scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, + scale_d * weight_m * scale_elem_size }; + + acl_tensor_ptr selected_scales = ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, + selected_scale_ne, selected_scale_nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_scales.get(), 0, batch_indices.get(), selected_scales.get()); + + // Process each expert for current batch + // IndexSelect output layout: [D, M, K] in contiguous format + // WeightQuantBatchMatmulV2 expects: [M, D] with row-major stride + for (int64_t expert_idx = 0; expert_idx < n_select_experts; expert_idx++) { + // Determine input offset: broadcast if src1->ne[1]==1, otherwise use per-expert input + const size_t input_offset = + (batch_idx * src1->ne[1] + (src1->ne[1] == 1 ? 0 : expert_idx)) * src1->ne[0] * f16_elem_size; + const size_t output_offset = (batch_idx * dst->ne[1] + expert_idx) * dst->ne[0] * f16_elem_size; + + // Create weight view for current expert: [D, M, K] -> [M, D] + int64_t weight_view_ne[2] = { weight_m, src0->ne[0] }; + float weight_view_nb[2] = { src0->ne[0] * weight_elem_size, weight_elem_size }; + const size_t weight_view_offset = expert_idx * selected_weight_nb[2]; - // src1_row [D, 1, 1, 1] -> input - src1_row.ne[1] = 1; - src1_row.ne[2] = 1; - src1_row.ne[3] = 1; - src1_row.nb[2] = nb11; - src1_row.nb[3] = nb11; - - // dst_row [M, 1, 1, 1] -> out - dst_row.ne[1] = 1; - dst_row.ne[2] = 1; - dst_row.ne[3] = 1; - dst_row.nb[2] = nb1; - dst_row.nb[3] = nb1; - - //create weight for one row - ggml_cann_pool_alloc weight_allocator(ctx.pool()); - void * weight_buffer = weight_allocator.alloc(nb02); - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - // expert index - int32_t i02 = *(int32_t *) (ids_host.data() + iid1 * ids->nb[1] + id * ids->nb[0]); - GGML_ASSERT(i02 >= 0 && i02 < n_as); - - // If B = 1 (broadcast), always use 0; otherwise, use id. - int64_t i11 = (ne11 == 1 ? 0 : id); - int64_t i12 = iid1; - - int64_t i1 = id; - int64_t i2 = i12; - - void * src0_tmp_ptr = src0_original + i02 * weight_stride; - void * scale_tmp_ptr = src0_original + weight_size + i02 * scale_stride; - void * src1_tmp_ptr = src1_original + i11 * nb11 + i12 * nb12; - void * dst_tmp_ptr = dst_original + i1 * nb1 + i2 * nb2; - - // mem cpy - ACL_CHECK(aclrtMemcpyAsync(weight_buffer, weight_stride, src0_tmp_ptr, weight_stride, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); - void * scale_buffer = (char *) weight_buffer + weight_stride; - ACL_CHECK(aclrtMemcpyAsync(scale_buffer, scale_stride, scale_tmp_ptr, scale_stride, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); - - src0_row.data = weight_buffer; - src1_row.data = src1_tmp_ptr; - dst_row.data = dst_tmp_ptr; - dst_row.src[0] = &src0_row; - dst_row.src[1] = &src1_row; - - ggml_cann_mul_mat(ctx, &dst_row); + acl_tensor_ptr weight_view = + ggml_cann_create_tensor(selected_weight_buffer, ggml_cann_type_mapping(type), weight_elem_size, + weight_view_ne, weight_view_nb, 2, ACL_FORMAT_ND, weight_view_offset); + + // Create scale view for current expert: [D, M, K] -> [M, D] + int64_t scale_view_ne[2] = { weight_m, scale_d }; + size_t scale_view_nb[2] = { selected_scale_nb[1], selected_scale_nb[0] }; + const size_t scale_view_offset = expert_idx * selected_scale_nb[2]; + + acl_tensor_ptr scale_view = + ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, scale_view_ne, + scale_view_nb, 2, ACL_FORMAT_ND, scale_view_offset); + + // Create input activation tensor [D, 1] + int64_t input_ne[2] = { src1->ne[0], 1 }; + size_t input_nb[2] = { f16_elem_size, src1->ne[0] * f16_elem_size }; + + acl_tensor_ptr input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, f16_elem_size, input_ne, + input_nb, 2, ACL_FORMAT_ND, input_offset); + + // Create output tensor [M, 1] + int64_t output_ne[2] = { dst->ne[0], 1 }; + size_t output_nb[2] = { f16_elem_size, dst->ne[0] * f16_elem_size }; + + acl_tensor_ptr output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, output_ne, + output_nb, 2, ACL_FORMAT_ND, output_offset); + + // Perform quantized matrix multiplication + GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, input_tensor.get(), weight_view.get(), + scale_view.get(), nullptr, nullptr, nullptr, nullptr, group_size, + output_tensor.get()); } } - return; + + // Cast output back to original type if we used a temporary F16 buffer + if (dst->type != GGML_TYPE_F16) { + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS] = { f16_elem_size }; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + ne[i] = dst->ne[i]; + if (i > 0) { + nb[i] = nb[i - 1] * ne[i - 1]; + } + } + + acl_tensor_ptr f16_output = + ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS); + acl_tensor_ptr dst_tensor = ggml_cann_create_tensor(dst); + + aclnn_cast(ctx, f16_output.get(), dst_tensor.get(), ggml_cann_type_mapping(dst->type)); + } } void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) { From 2de2fc9270fa7ff5f804b69c4eaaf72f7741ba51 Mon Sep 17 00:00:00 2001 From: Raul Torres <138264735+rauletorresc@users.noreply.github.com> Date: Tue, 10 Feb 2026 06:19:30 +0000 Subject: [PATCH 123/831] CANN: Remove unnecessary wrapper for `gml_backend_buft_is_cann` (llama/18968) --- ggml/src/ggml-cann/ggml-cann.cpp | 85 +++++++++++++------------------- 1 file changed, 35 insertions(+), 50 deletions(-) diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 6b2dbdd3591..3f3de9f0bcb 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -794,19 +794,44 @@ struct ggml_backend_cann_buffer_context { ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); } }; +// cann buffer type +/** + * @brief Structure representing context information for a specific backend + * buffer type. + */ +struct ggml_backend_cann_buffer_type_context { + int32_t device; /**< Device identifier associated with the buffer context. */ + std::string name; /**< Name associated with the buffer context. */ +}; + /** - * @brief Check if a buffer is a CANN buffer. + * @brief Retrieves the name associated with a CANN buffer type. * - * This function checks if a given buffer is a CANN buffer by comparing its - * `get_name` function pointer to `ggml_backend_cann_buffer_get_name`. + * This function returns the descriptive name associated with the specified + * CANN buffer type context. * - * @param buffer The buffer to check. - * @return true if the buffer is a CANN buffer, false otherwise. + * @param buft Pointer to the buffer type context. + * @return Const pointer to the C-style string containing the name. */ -static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft); +static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) { + ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context; -static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) { - return ggml_backend_buft_is_cann(buffer->buft); + return buft_ctx->name.c_str(); +} + +/** + * @brief Checks if the backend buffer type is associated with the CANN backend. + * + * This function checks whether the provided backend buffer type is associated + * with the CANN backend based on the comparison of its name retrieval function + * pointer. + * + * @param buft Pointer to the backend buffer type to check. + * @return bool Returns true if the buffer type is associated with the CANN + * backend, otherwise false. + */ +static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_cann_buffer_type_name; } /** @@ -1271,7 +1296,7 @@ static void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer, static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { - if (ggml_backend_buffer_is_cann(src->buffer)) { + if (ggml_backend_buft_is_cann(src->buffer->buft)) { ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context; ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context; @@ -1335,31 +1360,6 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { /* .reset = */ NULL, }; -// cann buffer type -/** - * @brief Structure representing context information for a specific backend - * buffer type. - */ -struct ggml_backend_cann_buffer_type_context { - int32_t device; /**< Device identifier associated with the buffer context. */ - std::string name; /**< Name associated with the buffer context. */ -}; - -/** - * @brief Retrieves the name associated with a CANN buffer type. - * - * This function returns the descriptive name associated with the specified - * CANN buffer type context. - * - * @param buft Pointer to the buffer type context. - * @return Const pointer to the C-style string containing the name. - */ -static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) { - ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context; - - return buft_ctx->name.c_str(); -} - /** * @brief Allocates a new CANN buffer of the specified type and size. * @@ -1997,7 +1997,7 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src, GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src)); - if (!ggml_backend_buffer_is_cann(src->buffer) || !ggml_backend_buffer_is_cann(dst->buffer)) { + if (!ggml_backend_buft_is_cann(src->buffer->buft) || !ggml_backend_buft_is_cann(dst->buffer->buft)) { return false; } @@ -2523,21 +2523,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten GGML_UNUSED(dev); } -/** - * @brief Checks if the backend buffer type is associated with the CANN backend. - * - * This function checks whether the provided backend buffer type is associated - * with the CANN backend based on the comparison of its name retrieval function - * pointer. - * - * @param buft Pointer to the backend buffer type to check. - * @return bool Returns true if the buffer type is associated with the CANN - * backend, otherwise false. - */ -static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == ggml_backend_cann_buffer_type_name; -} - /** * @brief Records an event on the CANN backend stream. * From b0fe2e84fa5bedfeb7df335153146cb7293cba01 Mon Sep 17 00:00:00 2001 From: k4ss4n <128936199+k4ss4n@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:57:48 +0100 Subject: [PATCH 124/831] ggml : use noexcept overload for is_regular_file in backend registration (llama/19452) using noexcept std::filesystem::directory_entry::is_regular_file overload prevents abnormal termination upon throwing an error (as caused by symlinks to non-existent folders on linux) Resolves: #18560 --- ggml/src/ggml-backend-reg.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 8a693f84af5..311fa5fe368 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -471,9 +471,10 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, int best_score = 0; fs::path best_path; + std::error_code ec; for (const auto & search_path : search_paths) { - if (std::error_code ec; !fs::exists(search_path, ec)) { + if (!fs::exists(search_path, ec)) { if (ec) { GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str()); } else { @@ -483,7 +484,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, } fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); for (const auto & entry : dir_it) { - if (entry.is_regular_file()) { + if (entry.is_regular_file(ec)) { auto filename = entry.path().filename(); auto ext = entry.path().extension(); if (filename.native().find(file_prefix) == 0 && ext == file_extension) { From d77265c8181f916a84a8e452ee2bffc2e476b6d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alberto=20Cabrera=20P=C3=A9rez?= <1478977+Alcpz@users.noreply.github.com> Date: Tue, 10 Feb 2026 10:47:45 +0000 Subject: [PATCH 125/831] ggml-cpu: arm64: q6_K repack gemm and gemv (and generic) implementations (dotprod) (llama/19360) * First working version of GEMM and GEMV * interleave loads and compute * Clang-format * Added missing fallback. Removed tested TODO. * Swap M and N to be consistent with the repack template convention --- ggml/src/ggml-cpu/arch-fallback.h | 16 +- ggml/src/ggml-cpu/arch/arm/repack.cpp | 400 +++++++++++++++++++++++- ggml/src/ggml-cpu/repack.cpp | 425 ++++++++++++++------------ ggml/src/ggml-cpu/repack.h | 4 + 4 files changed, 643 insertions(+), 202 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 427c1146e46..c6eb75b2300 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -43,6 +43,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -55,7 +56,8 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K -# define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -76,6 +78,7 @@ #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -84,6 +87,7 @@ #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -107,6 +111,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -119,6 +124,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 @@ -143,6 +149,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -155,6 +162,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 @@ -186,6 +194,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -197,6 +206,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 @@ -227,6 +237,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -239,6 +250,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 @@ -271,6 +283,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -283,6 +296,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 99bb70274c5..fd05c609f7e 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -1072,6 +1072,195 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q6_K_8x4_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_groups = ncols_interleaved / 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[2]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int i = 0; i < col_groups; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d); + + int32x4_t acc[col_groups]; + for (int i = 0; i < col_groups; i++) { + acc[i] = vdupq_n_s32(0); + } + + // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block) + // Reused for bias and dequantization later + int16_t q6_scales[16 * 8]; + for (int i = 0; i < 16; i++) { + int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, scales); + } + + // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift + int32x4_t bias_lo = vdupq_n_s32(0); + int32x4_t bias_hi = vdupq_n_s32(0); + + // Load bsums in chunks of 4 to process with vectorized operations + for (int i = 0; i < 16; i += 4) { + int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i); + int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8); + int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4); + int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8); + int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4); + int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8); + int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4); + int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8); + int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4); + + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3); + } + bias_lo = vshlq_n_s32(bias_lo, 5); + bias_hi = vshlq_n_s32(bias_hi, 5); + + // Process two 128-value halves per superblock + for (int half = 0; half < 2; half++) { + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + // A subblock (sb) is a set of weights that share the scale + // Since q6_K scales are per 16 elements + // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) + for (int sb = 0; sb < QK_K / 64; sb++) { + const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16; + const int8_t * q8_base_h = q8_base_l + 64; + + // Load and duplicate q8 values (each register covers four interleaved columns of q6) + int8x16_t q8_l[4]; + int8x16_t q8_h[4]; + for (int i = 0; i < 4; i++) { + q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4)); + q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4)); + } + + const int ql_off_base = sb * QK_K / 2; + const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes + + // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1) + uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base); + uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64); + uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base); + uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64); + + // Adjust qh for subblocks 2 and 3 (shift right by 2) + if (sb > 1) { + q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2); + q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2); + q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2); + q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2); + q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2); + q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2); + q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2); + q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2); + } + + const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3], + q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] }; + const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3], + q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] }; + + // Process column groups (0-3, 4-7) + for (int g = 0; g < col_groups; g++) { + int32x4_t sb_acc_l = vdupq_n_s32(0); + int32x4_t sb_acc_h = vdupq_n_s32(0); + + for (int chunk = 0; chunk < 4; chunk++) { + const int idx = chunk * 2 + g; + + const uint8x16_t q6_qs_l = q6_ql[idx]; + const uint8x16_t q6_qs_h = q6_qh[idx]; + + // Extract high 2 bits for upper nibble reconstruction + const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi); + + // q6 = (low4 | high2<<4), without -32 bias (handled via bsums) + const int8x16_t q6_l = + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4)); + const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh)); + + sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]); + sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]); + } + + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4)); + const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4)); + + acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l); + acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h); + } + } + } // for half + + // Bias correction + acc[0] = vsubq_s32(acc[0], bias_lo); + acc[1] = vsubq_s32(acc[1], bias_hi); + + // Apply superblock scale (no mins for q6_K) + // acc[g] has [c0, c1, c2, c3] + float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0); + float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1); + + acc_f32[0] = vaddq_f32(acc_f32[0], w_0123); + acc_f32[1] = vaddq_f32(acc_f32[1], w_4567); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, @@ -1177,15 +1366,14 @@ void ggml_gemv_q6_K_8x8_q8_K(int n, q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8)); } - // TODO: Test other qh repack patterns to reduce loads const int ql_off_base = sb * QK_K / 2; const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1) - ggml_uint8x16x4_t q6_ql_0 = ggml_vld1q_u8_x4(ql_base + ql_off_base); - ggml_uint8x16x4_t q6_ql_1 = ggml_vld1q_u8_x4(ql_base + ql_off_base + 64); - ggml_uint8x16x4_t q6_qh_0 = ggml_vld1q_u8_x4(qh_base + qh_off_base); - ggml_uint8x16x4_t q6_qh_1 = ggml_vld1q_u8_x4(qh_base + qh_off_base + 64); + uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base); + uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64); + uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base); + uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64); // Adjust qh for subblocks 2 and 3 (shift right by 2) if (sb > 1) { @@ -3474,6 +3662,208 @@ void ggml_gemm_q5_K_8x8_q8_K(int n, ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q6_K_8x4_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int q8_k_blocklen = 4; + constexpr int col_groups = ncols_interleaved / 4; + constexpr int acc_size = q8_k_blocklen * col_groups; // 4 rows, 2 column groups + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + const int8x16_t m32s = vdupq_n_s8(32); + + float32x4_t acc_f32[acc_size]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int i = 0; i < acc_size; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); + float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); + float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d); + + float32x4_t sbd_scale_0123[q8_k_blocklen]; + float32x4_t sbd_scale_4567[q8_k_blocklen]; + + sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0); + sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0); + sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1); + sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1); + sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2); + sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2); + sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3); + sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3); + + int32x4_t acc_s32[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_s32[i] = vdupq_n_s32(0); + } + + int16_t q6_scales[8 * 16]; + for (int i = 0; i < 16; i++) { + int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, scales); + } + + for (int half = 0; half < 2; half++) { + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + for (int sb = 0; sb < QK_K / 64; sb++) { + int32x4_t acc_lo[acc_size]; + int32x4_t acc_hi[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + + const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64; + const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64; + + // 4 rows * 16 elements per scale + // 4 reads of 16 bytes each + constexpr int reads_per_sb = 4; + int8x16_t q8_l[reads_per_sb]; + int8x16_t q8_h[reads_per_sb]; + for (int k = 0; k < reads_per_sb; k++) { + q8_l[k] = vld1q_s8(q8_base_l + 16 * k); + q8_h[k] = vld1q_s8(q8_base_h + 16 * k); + } + + const int ql_off_base = sb * QK_K / 2; + const int qh_off_base = ql_off_base & 255; + + uint8x16_t q6_ql_0123[reads_per_sb]; + uint8x16_t q6_ql_4567[reads_per_sb]; + uint8x16_t q6_qh_0123[reads_per_sb]; + uint8x16_t q6_qh_4567[reads_per_sb]; + + for (int k = 0; k < reads_per_sb; k++) { + q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32); + q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16); + q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32); + q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16); + } + + if (sb > 1) { + for (int k = 0; k < reads_per_sb; k++) { + q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2); + q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2); + } + } + + for (int k = 0; k < reads_per_sb; k++) { + // q = (ql | qh) - 32 + const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo); + const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi); + const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo); + const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi); + + const int8x16_t q6_0123_lo = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s); + const int8x16_t q6_0123_hi = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s); + + acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0); // 0..3 r0 c0123 + acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1); // 0..3 r1 c0123 + acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2); // 0..3 r2 c0123 + acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3); // 0..3 r3 c0123 + + acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0); // 64..67 r0 c0123 + acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1); // 64..67 r1 c0123 + acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2); // 64..67 r2 c0123 + acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3); // 64..67 r3 c0123 + + const int8x16_t q6_4567_lo = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s); + const int8x16_t q6_4567_hi = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s); + + acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0); // 0..3 r0 c4567 + acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1); // 0..3 r1 c4567 + acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2); // 0..3 r2 c4567 + acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3); // 0..3 r3 c4567 + + acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0); // 64..67 r0 c4567 + acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1); // 64..67 r1 c4567 + acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2); // 64..67 r2 c4567 + acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3); // 64..67 r3 c4567 + } + + // Scale and bias + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + for (int g = 0; g < col_groups; g++) { + const int16x4_t scales_l16 = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4); + const int16x4_t scales_h16 = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4); + const int32x4_t scale_vec_l = vmovl_s16(scales_l16); + const int32x4_t scale_vec_h = vmovl_s16(scales_h16); + const int acc_offset = g * q8_k_blocklen; + + for (int row = 0; row < q8_k_blocklen; row++) { + const int idx = row * 2 + g; + acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l); + acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h); + } + } + } + } + + // Finally we apply the superblock scales + for (int row = 0; row < q8_k_blocklen; row++) { + const int idx0 = 2 * row; + const int idx1 = 2 * row + 1; + const int32x4_t acc_0123 = acc_s32[idx0]; + const int32x4_t acc_4567 = acc_s32[idx1]; + + acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]); + acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]); + } + } // for b + + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 24e8ab46182..4cb7cdeb07b 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -256,6 +256,200 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); } +template +static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + const int blocks_per_half = 64 / blocklen; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + } + + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + for (int j = 0; j < ncols_interleaved; j++) { + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / blocklen; + const int qh_pos_l = qh_idx_l % blocklen; + const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / blocklen; + const int qh_pos_h = qh_idx_h % blocklen; + const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t a_l = a_ptr[l].qs[base_l + i]; + const int8_t a_h = a_ptr[l].qs[base_h + i]; + + sumi_l += q_l * a_l; + sumi_h += q_h * a_h; + } + + sumf[j] += + (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +template +static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + const int blocks_per_half = 64 / blocklen; + const int q8_half_stride = 512; + const int q8_low_high_step = 256; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + + float sumf[4][8]; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0f; + } + } + + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4); + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / blocklen; + const int qh_pos_l = qh_idx_l % blocklen; + const int qh_offset_l = + qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / blocklen; + const int qh_pos_h = qh_idx_h % blocklen; + const int qh_offset_h = + qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i]; + const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step]; + + sumi_l += q_l * q8_l; + sumi_h += q_h * q8_h; + } + + sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * + a_ptr[l].d[m]; + } + } + } + } + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + extern "C" { void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -704,94 +898,12 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n, } -void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - constexpr int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[8]; - - const block_q8_K * a_ptr = (const block_q8_K *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0f; - } - - for (int l = 0; l < nb; l++) { - - - for (int k = 0; k < 16; k++) { - // k = 0.. 7 weights 0-63 low, 64-127 high - // k = 8..15 weights 128-191 low, 192-255 high - const int base_l = (k / 8) * 128 + (k % 8) * 8; - const int base_h = base_l + 64; - - const int scale_idx_l = base_l / 16; - const int scale_idx_h = base_h / 16; - - // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half - const int qh_shift_l = ((base_l % 128) / 32) * 2; - const int qh_shift_h = ((base_h % 128) / 32) * 2; - - // qh_half: offset to the correct 32-byte half (0 or 32) - const int qh_half_l = (base_l / 128) * 32; - const int qh_half_h = (base_h / 128) * 32; - - for (int j = 0; j < ncols_interleaved; j++) { - // Interleaved scales - const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j]; - const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j]; - - int sumi_l = 0; - int sumi_h = 0; - - for (int i = 0; i < blocklen; i++) { - const int ql_pos = k * 64 + j * 8 + i; - const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; - const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; - - // qh indexing with 8-byte interleaving (like q5_K) - const int qh_byte_l = qh_half_l + ((base_l + i) % 32); - const int qh_chunk_l = qh_byte_l / 8; - const int qh_pos_l = qh_byte_l % 8; - const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l; - const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; - - const int qh_byte_h = qh_half_h + ((base_h + i) % 32); - const int qh_chunk_h = qh_byte_h / 8; - const int qh_pos_h = qh_byte_h % 8; - const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h; - const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; - - const int q_l = ((hi_2_l << 4) | l_4) - 32; - const int q_h = ((hi_2_h << 4) | hi_4) - 32; - - const int8_t a_l = a_ptr[l].qs[base_l + i]; - const int8_t a_h = a_ptr[l].qs[base_h + i]; - - sumi_l += q_l * a_l; - sumi_h += q_h * a_h; - } - - sumf[j] += - (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; - } - } - } +void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j]; - } - } +void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1485,109 +1597,12 @@ void ggml_gemm_q5_K_8x8_q8_K_generic(int n, } } -void ggml_gemm_q6_K_8x8_q8_K_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert(n % qk == 0); - assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - - float sumf[4][8]; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); - - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0f; - } - } - - for (int l = 0; l < nb; l++) { - for (int k = 0; k < 16; k++) { - // k = 0.. 7 weights 0-63 low, 64-127 high - // k = 8..15 weights 128-191 low, 192-255 high - const int base_l = (k / 8) * 128 + (k % 8) * 8; - const int base_h = base_l + 64; - - const int scale_idx_l = base_l / 16; - const int scale_idx_h = base_h / 16; - - // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half - const int qh_shift_l = ((base_l % 128) / 32) * 2; - const int qh_shift_h = ((base_h % 128) / 32) * 2; - - // qh_half: offset to the correct 32-byte half (0 or 32) - const int qh_half_l = (base_l / 128) * 32; - const int qh_half_h = (base_h / 128) * 32; - - // Activation base indices for q8_Kx4 interleaved format - // Layout: 128-value halves (k/8), then 8-value sub-blocks (k%8) with stride 32 - const int q8_base = (k / 8) * 512 + (k % 8) * 32; - - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - // Interleaved scales - const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j]; - const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j]; - - int sumi_l = 0; - int sumi_h = 0; - - for (int i = 0; i < blocklen; i++) { - const int ql_pos = k * 64 + j * 8 + i; - const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; - const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; - - const int qh_idx_l = qh_half_l + ((base_l + i) % 32); - const int qh_chunk_l = qh_idx_l / 8; - const int qh_pos_l = qh_idx_l % 8; - const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l; - const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; - - const int qh_idx_h = qh_half_h + ((base_h + i) % 32); - const int qh_chunk_h = qh_idx_h / 8; - const int qh_pos_h = qh_idx_h % 8; - const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h; - const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; - - const int q_l = ((hi_2_l << 4) | l_4) - 32; - const int q_h = ((hi_2_h << 4) | hi_4) - 32; - - const int8_t q8_l = a_ptr[l].qs[q8_base + m * 8 + i]; - const int8_t q8_h = a_ptr[l].qs[q8_base + m * 8 + i + 256]; - - sumi_l += q_l * q8_l; - sumi_h += q_h * q8_h; - } - - sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * - a_ptr[l].d[m]; - } - } - } - } +void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } +void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -2097,18 +2112,18 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in } const int end_ls = QK_K * 4 / blck_size_interleave; - // Interleave Q6_K quants by taking 8 bytes at a time + // Interleave Q6_K quants by taking blck_size_interleave bytes at a time for (int i = 0; i < end_ls; ++i) { int src_id = i % n_blocks; int src_offset = (i / n_blocks) * blck_size_interleave; int dst_offset = i * blck_size_interleave; uint64_t elem_ls; - memcpy(&elem_ls, &in[src_id].ql[src_offset], sizeof(uint64_t)); - memcpy(&out.ql[dst_offset], &elem_ls, sizeof(uint64_t)); + memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave); + memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave); } - // Interleave high bits using same 8-byte pattern as low bits + // Interleave high bits using same chunk size as low bits const int end_hs = end_ls / 2; for (int i = 0; i < end_hs; ++i) { int src_id = i % n_blocks; @@ -2116,8 +2131,8 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in int dst_offset = i * blck_size_interleave; uint64_t elem_hs; - memcpy(&elem_hs, &in[src_id].qh[src_offset], sizeof(uint64_t)); - memcpy(&out.qh[dst_offset], &elem_hs, sizeof(uint64_t)); + memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave); + memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave); } // The below logic is designed so as to unpack and rearrange scales in Q6_K @@ -2262,7 +2277,7 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q6_K); - GGML_ASSERT(interleave_block == 8); + GGML_ASSERT(interleave_block == 4 || interleave_block == 8); constexpr int nrows_interleaved = 8; block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data; @@ -2511,6 +2526,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_8_bl(t, 4, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size); } @@ -2575,6 +2594,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -2634,6 +2657,10 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -3043,6 +3070,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q5_K_8x8_q8_K; // instance for Q6_K + static const ggml::cpu::repack::tensor_traits q6_K_8x4_q8_K; static const ggml::cpu::repack::tensor_traits q6_K_8x8_q8_K; // instance for Q2 @@ -3107,6 +3135,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q6_K_8x8_q8_K; } } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 8 == 0) { + return &q6_K_8x4_q8_K; + } + } } else if (cur->type == GGML_TYPE_IQ4_NL) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 855320eeeb6..39b6b482388 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -112,6 +112,7 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -122,6 +123,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -142,6 +144,7 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -152,6 +155,7 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); From 562255fd77cb4a047c6ed54f7c4f573dbcd699d2 Mon Sep 17 00:00:00 2001 From: Nikhil Jain Date: Tue, 10 Feb 2026 08:04:00 -0800 Subject: [PATCH 126/831] Plug memory leaks and free resources on shutdown (llama/19315) * Fix memory leaks in shader lib, backend, backend_context, buffer_context, and webgpu_buf_pool * Free pools * Cleanup * More cleanup * Run clang-format * Fix arg-parser and tokenizer test errors that free an unallocated buffer * Fix device lost callback to not print on device teardown * Fix include and run clang-format * remove unused unused * Update binary ops --------- Co-authored-by: Reese Levine --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 81 ++++++++--------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 89 +++++++++++-------- 2 files changed, 94 insertions(+), 76 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 6997f6bdd31..63f797f142d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -4,6 +4,7 @@ #include "ggml.h" #include "pre_wgsl.hpp" +#include #include #include @@ -18,9 +19,9 @@ #define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u struct ggml_webgpu_processed_shader { - std::string wgsl; - std::string variant; - void * decisions; + std::string wgsl; + std::string variant; + std::shared_ptr decisions; }; // Same hash combine function as in boost @@ -192,13 +193,13 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_flash_attn_shader_decisions * decisions = new ggml_webgpu_flash_attn_shader_decisions(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; - result.decisions = decisions; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared(); + decisions->q_tile = q_tile; + decisions->kv_tile = kv_tile; + decisions->wg_size = wg_size; + result.decisions = decisions; return result; } @@ -270,11 +271,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader( defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; return result; } @@ -305,11 +306,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader( } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions(); - decisions->wg_size = wg_size; - result.decisions = decisions; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + result.decisions = decisions; return result; } @@ -324,11 +325,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader( uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size); defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions(); - decisions->wg_size = wg_size; - result.decisions = decisions; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + result.decisions = decisions; return result; } @@ -391,11 +392,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader( defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; return result; } @@ -457,11 +458,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader( defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; return result; } @@ -527,11 +528,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader( defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + result.decisions = decisions; return result; } #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f7ceca11212..32e120266a9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -186,11 +186,17 @@ struct webgpu_buf_pool { void cleanup() { std::lock_guard lock(mutex); for (auto & bufs : free) { - bufs.host_buf.Destroy(); - bufs.dev_buf.Destroy(); + if (bufs.host_buf) { + bufs.host_buf.Destroy(); + } + if (bufs.dev_buf) { + bufs.dev_buf.Destroy(); + } } free.clear(); } + + ~webgpu_buf_pool() { this->cleanup(); } }; #ifdef GGML_WEBGPU_GPU_PROFILE @@ -252,13 +258,15 @@ struct webgpu_gpu_profile_buf_pool { } free.clear(); } + + ~webgpu_gpu_profile_buf_pool() { this->cleanup(); } }; #endif struct webgpu_pipeline { wgpu::ComputePipeline pipeline; std::string name; - void * context = nullptr; + std::shared_ptr context = nullptr; }; struct webgpu_command { @@ -319,6 +327,23 @@ struct webgpu_global_context_struct { wgpu::Buffer debug_host_buf; wgpu::Buffer debug_dev_buf; #endif + + ~webgpu_global_context_struct() { + if (this->get_tensor_staging_buf) { + this->get_tensor_staging_buf.Destroy(); + this->get_tensor_staging_buf = nullptr; + } +#ifdef GGML_WEBGPU_DEBUG + if (this->debug_host_buf) { + this->debug_host_buf.Destroy(); + this->debug_host_buf = nullptr; + } + if (this->debug_dev_buf) { + this->debug_dev_buf.Destroy(); + this->debug_dev_buf = nullptr; + } +#endif + } }; typedef std::shared_ptr webgpu_global_context; @@ -744,7 +769,6 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) { return ctx->name.c_str(); } -// TODO: implement proper cleanup static void ggml_backend_webgpu_free(ggml_backend_t backend) { ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")"); @@ -788,9 +812,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n"; #endif -#if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE) - GGML_UNUSED(ctx); -#endif + delete ctx; + delete backend; } static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { @@ -896,8 +919,7 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g ctx->pad_pipelines.emplace(pipeline_key, pipeline); } - ggml_webgpu_generic_shader_decisions decisions = - *static_cast(pipeline.context); + auto * decisions = static_cast(pipeline.context.get()); const uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -941,7 +963,7 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -975,8 +997,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, ctx->set_rows_pipelines.emplace(key, pipeline); } - ggml_webgpu_generic_shader_decisions decisions = - *static_cast(pipeline.context); + auto * decisions = static_cast(pipeline.context.get()); std::optional error_bufs = std::nullopt; if (key.i64_idx) { @@ -1028,7 +1049,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, } else { threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } - uint32_t wg_x = CEIL_DIV(threads, decisions.wg_size); + uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1, error_bufs); } @@ -1297,10 +1318,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ctx->flash_attn_pipelines.emplace(key, pipeline); } - ggml_webgpu_flash_attn_shader_decisions decisions = - *static_cast(pipeline.context); + auto * decisions = static_cast(pipeline.context.get()); - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile); + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1331,8 +1351,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s ctx->unary_pipelines.emplace(pipeline_key, pipeline); } - ggml_webgpu_generic_shader_decisions decisions = - *static_cast(pipeline.context); + auto * decisions = static_cast(pipeline.context.get()); uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -1392,7 +1411,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1425,8 +1444,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, ctx->binary_pipelines.emplace(pipeline_key, pipeline); } - ggml_webgpu_generic_shader_decisions decisions = - *static_cast(pipeline.context); + auto * decisions = static_cast(pipeline.context.get()); uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -1471,7 +1489,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1821,8 +1839,7 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr argsort_pipeline.context = processed.decisions; ctx->argsort_pipelines.emplace(order, argsort_pipeline); } - ggml_webgpu_argsort_shader_decisions argsort_decisions = - *static_cast(argsort_pipeline.context); + auto * argsort_decisions = static_cast(argsort_pipeline.context.get()); webgpu_pipeline argsort_merge_pipeline; it = ctx->argsort_merge_pipelines.find(order); @@ -1839,13 +1856,13 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr const uint32_t src_ne0 = (uint32_t) src->ne[0]; const uint32_t nrows = (uint32_t) ggml_nrows(src); - const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions.wg_size); + const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions->wg_size); const uint32_t block_size = - is_top_k ? std::min(argsort_decisions.wg_size, (uint32_t) dst->ne[0]) : argsort_decisions.wg_size; + is_top_k ? std::min(argsort_decisions->wg_size, (uint32_t) dst->ne[0]) : argsort_decisions->wg_size; uint32_t out_ne0 = src_ne0; if (is_top_k) { if (npr > 1) { - const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions.wg_size; + const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions->wg_size; out_ne0 = (npr - 1) * block_size + std::min(last_tile, block_size); } else { out_ne0 = block_size; @@ -2198,7 +2215,10 @@ static ggml_backend_i ggml_backend_webgpu_i = { static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_webgpu_buffer_context * ctx = static_cast(buffer->context); - ctx->buffer.Destroy(); + if (ctx != nullptr && ctx->buffer != nullptr) { + ctx->buffer.Destroy(); + delete ctx; + } } // Returns the "fake" base pointer. @@ -2926,12 +2946,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { dev_desc.SetDeviceLostCallback( wgpu::CallbackMode::AllowSpontaneous, [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + if (reason == wgpu::DeviceLostReason::Destroyed) { + return; + } GGML_UNUSED(device); - GGML_UNUSED(reason); - GGML_UNUSED(message); - //TODO: uncomment once proper free logic is in place - //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), - //std::string(message).c_str()); + GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), + std::string(message).c_str()); }); dev_desc.SetUncapturedErrorCallback( [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { @@ -3365,10 +3385,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) { return ctx->device_count; } -// TODO: Does this need to be thread safe? Is it only called once? -// TODO: move most logic to device_init function so backend can be freed/initialized properly // Only one device is supported for now - static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) { GGML_ASSERT(index == 0); WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()"); From 57c620b4b19fe96d48b081349c7536954cb442a6 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Tue, 10 Feb 2026 22:31:19 +0100 Subject: [PATCH 127/831] CUDA : Update CCCL-tag for 3.2 to final release from RC (llama/19486) CCCL 3.2 has been released since it was added to llama.cpp as part of the backend-sampling PR, and it makes sense to update from RC to final released version. https://github.com/NVIDIA/cccl/releases/tag/v3.2.0 --- ggml/src/ggml-cuda/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index d313c1ac9af..262f88204e0 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -64,7 +64,7 @@ if (CUDAToolkit_FOUND) FetchContent_Declare( CCCL GIT_REPOSITORY https://github.com/nvidia/cccl.git - GIT_TAG v3.2.0-rc2 + GIT_TAG v3.2.0 GIT_SHALLOW TRUE ) From de949fb1db96ba4daf174b0eba514c9159c65980 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 11 Feb 2026 07:51:12 +0200 Subject: [PATCH 128/831] metal : consolidate unary ops (llama/19490) --- ggml/src/ggml-metal/ggml-metal-device.cpp | 82 +-- ggml/src/ggml-metal/ggml-metal-device.m | 21 +- ggml/src/ggml-metal/ggml-metal-impl.h | 75 ++- ggml/src/ggml-metal/ggml-metal-ops.cpp | 209 ++------ ggml/src/ggml-metal/ggml-metal-ops.h | 4 - ggml/src/ggml-metal/ggml-metal.metal | 614 +++++++--------------- 6 files changed, 352 insertions(+), 653 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 4c4c3ce36c4..949e344cc8c 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -212,61 +212,69 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta } ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) { - GGML_ASSERT(ggml_is_contiguous(op->src[0])); - char base[256]; char name[256]; - const int64_t n = ggml_nelements(op); + int op_num = -1; - const char * op_str = "undefined"; switch (op->op) { - case GGML_OP_SCALE: op_str = "scale"; break; - case GGML_OP_FILL: op_str = "fill"; break; - case GGML_OP_CLAMP: op_str = "clamp"; break; - case GGML_OP_SQR: op_str = "sqr"; break; - case GGML_OP_SQRT: op_str = "sqrt"; break; - case GGML_OP_SIN: op_str = "sin"; break; - case GGML_OP_COS: op_str = "cos"; break; - case GGML_OP_LOG: op_str = "log"; break; - case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break; + case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break; + case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break; + case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break; + case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break; + case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break; + case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break; + case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break; + case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break; + case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break; case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { - case GGML_UNARY_OP_TANH: op_str = "tanh"; break; - case GGML_UNARY_OP_RELU: op_str = "relu"; break; - case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break; - case GGML_UNARY_OP_GELU: op_str = "gelu"; break; - case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break; - case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break; - case GGML_UNARY_OP_SILU: op_str = "silu"; break; - case GGML_UNARY_OP_ELU: op_str = "elu"; break; - case GGML_UNARY_OP_NEG: op_str = "neg"; break; - case GGML_UNARY_OP_ABS: op_str = "abs"; break; - case GGML_UNARY_OP_SGN: op_str = "sgn"; break; - case GGML_UNARY_OP_STEP: op_str = "step"; break; - case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break; - case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break; - case GGML_UNARY_OP_EXP: op_str = "exp"; break; - case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break; - case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break; + case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break; + case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break; + case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break; + case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break; + case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break; + case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break; + case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break; + case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break; + case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break; + case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break; + case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break; + case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break; + case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break; + case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break; + case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break; + case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break; + case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); }; - const char * suffix = ""; - if (n % 4 == 0) { - suffix = "_4"; - } + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t_str = ggml_type_name(op->type); - snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix); - snprintf(name, 256, "%s", base); + const bool is_c4 = op->src[0]->ne[0] % 4 == 0; + const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768; + + snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : ""); + snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0); + ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } + res.c4 = is_c4; + res.cnt = is_cnt; + return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 891d70c85a4..50a2a3e7f72 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1011,6 +1011,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te } switch (op->op) { + case GGML_OP_SCALE: + case GGML_OP_FILL: + case GGML_OP_CLAMP: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_LOG: + return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_TANH: @@ -1030,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_EXPM1: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; default: return false; } @@ -1061,8 +1070,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: case GGML_OP_REPEAT: - case GGML_OP_SCALE: - case GGML_OP_FILL: case GGML_OP_CONV_TRANSPOSE_1D: return true; case GGML_OP_CONV_TRANSPOSE_2D: @@ -1070,14 +1077,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; - case GGML_OP_CLAMP: - return op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SIN: - case GGML_OP_COS: - case GGML_OP_LOG: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SUM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_TRI: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 77bb403c15d..44141f8e3d9 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -80,7 +80,8 @@ #define FC_SSM_CONV 900 #define FC_SOLVE_TRI 1000 #define FC_COUNT_EQUAL 1100 -#define FC_BIN 1200 +#define FC_UNARY 1200 +#define FC_BIN 1300 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -89,6 +90,35 @@ #define OP_FLASH_ATTN_EXT_VEC_NQPSG 1 #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 +#define OP_UNARY_NUM_SCALE 10 +#define OP_UNARY_NUM_FILL 11 +#define OP_UNARY_NUM_CLAMP 12 +#define OP_UNARY_NUM_SQR 13 +#define OP_UNARY_NUM_SQRT 14 +#define OP_UNARY_NUM_SIN 15 +#define OP_UNARY_NUM_COS 16 +#define OP_UNARY_NUM_LOG 17 +#define OP_UNARY_NUM_LEAKY_RELU 18 + +#define OP_UNARY_NUM_TANH 100 +#define OP_UNARY_NUM_RELU 101 +#define OP_UNARY_NUM_SIGMOID 102 +#define OP_UNARY_NUM_GELU 103 +#define OP_UNARY_NUM_GELU_ERF 104 +#define OP_UNARY_NUM_GELU_QUICK 105 +#define OP_UNARY_NUM_SILU 106 +#define OP_UNARY_NUM_ELU 107 +#define OP_UNARY_NUM_NEG 108 +#define OP_UNARY_NUM_ABS 109 +#define OP_UNARY_NUM_SGN 110 +#define OP_UNARY_NUM_STEP 111 +#define OP_UNARY_NUM_HARDSWISH 112 +#define OP_UNARY_NUM_HARDSIGMOID 113 +#define OP_UNARY_NUM_EXP 114 +#define OP_UNARY_NUM_SOFTPLUS 115 +#define OP_UNARY_NUM_EXPM1 116 + + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage @@ -124,6 +154,31 @@ typedef struct { int32_t dim; } ggml_metal_kargs_concat; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + float slope; + float scale; + float bias; + float val; + float min; + float max; +} ggml_metal_kargs_unary; + typedef struct { int32_t ne00; int32_t ne01; @@ -181,20 +236,6 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_repeat; -typedef struct { - float scale; - float bias; -} ggml_metal_kargs_scale; - -typedef struct { - float val; -} ggml_metal_kargs_fill; - -typedef struct { - float min; - float max; -} ggml_metal_kargs_clamp; - typedef struct { int64_t nk0; int64_t ne00; @@ -881,10 +922,6 @@ typedef struct { int max_period; } ggml_metal_kargs_timestep_embedding; -typedef struct { - float slope; -} ggml_metal_kargs_leaky_relu; - typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index dbf25433c25..b159a8e7fd0 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -287,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { n_fuse = ggml_metal_op_acc(ctx, idx); } break; case GGML_OP_SCALE: - { - n_fuse = ggml_metal_op_scale(ctx, idx); - } break; case GGML_OP_FILL: - { - n_fuse = ggml_metal_op_fill(ctx, idx); - } break; case GGML_OP_CLAMP: - { - n_fuse = ggml_metal_op_clamp(ctx, idx); - } break; + case GGML_OP_LEAKY_RELU: case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: @@ -426,10 +418,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_top_k(ctx, idx); } break; - case GGML_OP_LEAKY_RELU: - { - n_fuse = ggml_metal_op_leaky_relu(ctx, idx); - } break; case GGML_OP_TRI: { n_fuse = ggml_metal_op_tri(ctx, idx); @@ -722,7 +710,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { return 1; } -int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { +int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; @@ -733,133 +721,80 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - float scale; - float bias; - memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float)); - memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float)); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); - ggml_metal_kargs_scale args = { - /*.scale =*/ scale, - /*.bias =*/ bias, - }; + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); - int64_t n = ggml_nelements(op); + ggml_metal_kargs_unary args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.slope =*/ 0.0, + /*.scale =*/ 0.0, + /*.bias =*/ 0.0, + /*.val =*/ 0.0, + /*.min =*/ 0.0, + /*.max =*/ 0.0, + }; - if (n % 4 == 0) { - n /= 4; + if (op->op == GGML_OP_LEAKY_RELU) { + args.slope = ggml_get_op_params_f32(op, 0); } - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - - return 1; -} - -int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - const float val = ggml_get_op_params_f32(op, 0); - - ggml_metal_kargs_fill args = { - /*.val =*/ val - }; + if (op->op == GGML_OP_SCALE) { + args.scale = ggml_get_op_params_f32(op, 0); + args.bias = ggml_get_op_params_f32(op, 1); + } - int64_t n = ggml_nelements(op); + if (op->op == GGML_OP_FILL) { + args.val = ggml_get_op_params_f32(op, 0); + } - if (n % 4 == 0) { - n /= 4; + if (op->op == GGML_OP_CLAMP) { + args.min = ggml_get_op_params_f32(op, 0); + args.max = ggml_get_op_params_f32(op, 1); } auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - - return 1; -} - -int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - float min; - float max; - memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float)); - memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float)); - - ggml_metal_kargs_clamp args = { - /*.min =*/ min, - /*.max =*/ max, - }; - - int64_t n = ggml_nelements(op); - - if (n % 4 == 0) { - n /= 4; + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; } - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + if (pipeline.cnt) { + const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op); - return 1; -} + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + } else { + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); -int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); + const int nth = MIN(args.ne00, nth_max); - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; + const int nk0 = (args.ne00 + nth - 1)/nth; - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - int64_t n = ggml_nelements(op); - - if (n % 4 == 0) { - n /= 4; + ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1); } - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - return 1; } @@ -4084,42 +4019,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { return 1; } -int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - float slope; - memcpy(&slope, op->op_params, sizeof(float)); - - ggml_metal_kargs_leaky_relu args = { - /*.slope =*/ slope - }; - - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - int64_t n = ggml_nelements(op); - - if (n % 4 == 0) { - n /= 4; - } - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - - return 1; -} - int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 3c64e4f6007..29456d70d5e 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -46,9 +46,6 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op); int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx); int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx); int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx); @@ -86,7 +83,6 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 35cc3bbdfdf..7d841341a18 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -895,6 +895,192 @@ enum ggml_sort_order { GGML_SORT_ORDER_DESC, }; +constant float GELU_COEF_A = 0.044715f; +constant float GELU_QUICK_COEF = -1.702f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; +constant float SQRT_2_INV = 0.70710678118654752440084436210484f; + +// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation +// ref: https://www.johndcook.com/blog/python_erf/ +constant float p_erf = 0.3275911f; +constant float a1_erf = 0.254829592f; +constant float a2_erf = -0.284496736f; +constant float a3_erf = 1.421413741f; +constant float a4_erf = -1.453152027f; +constant float a5_erf = 1.061405429f; + +template +T erf_approx(T x) { + T sign_x = sign(x); + x = fabs(x); + T t = 1.0f / (1.0f + p_erf * x); + T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + return sign_x * y; +} + +constant short FC_unary_op [[function_constant(FC_UNARY + 0)]]; +constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]]; + +template +kernel void kernel_unary_impl( + constant ggml_metal_kargs_unary & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { +#define FC_OP FC_unary_op +#define FC_CNT FC_unary_cnt + + device const T0 * src0_ptr; + device T * dst_ptr; + + int i0; + + if (FC_CNT) { + i0 = tgpig.x; + + src0_ptr = (device const T0 *) (src0); + dst_ptr = (device T *) (dst); + } else { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int k0 = tgpig.x/args.ne01; + const int i01 = tgpig.x - k0*args.ne01; + + i0 = k0*ntg.x + tpitg.x; + + src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 ); + } + + { + //threadgroup_barrier(mem_flags::mem_none); + + if (!FC_CNT) { + if (i0 >= args.ne0) { + return; + } + } + + device const T0 & x = src0_ptr[i0]; + + if (FC_OP == OP_UNARY_NUM_SCALE) { + dst_ptr[i0] = args.scale * x + args.bias; + } + + if (FC_OP == OP_UNARY_NUM_FILL) { + dst_ptr[i0] = args.val; + } + + if (FC_OP == OP_UNARY_NUM_CLAMP) { + dst_ptr[i0] = clamp(x, args.min, args.max); + } + + if (FC_OP == OP_UNARY_NUM_SQR) { + dst_ptr[i0] = x * x; + } + + if (FC_OP == OP_UNARY_NUM_SQRT) { + dst_ptr[i0] = sqrt(x); + } + + if (FC_OP == OP_UNARY_NUM_SIN) { + dst_ptr[i0] = sin(x); + } + + if (FC_OP == OP_UNARY_NUM_COS) { + dst_ptr[i0] = cos(x); + } + + if (FC_OP == OP_UNARY_NUM_LOG) { + dst_ptr[i0] = log(x); + } + + if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) { + dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(x * args.slope); + } + + if (FC_OP == OP_UNARY_NUM_TANH) { + dst_ptr[i0] = precise::tanh(x); + } + + if (FC_OP == OP_UNARY_NUM_RELU) { + dst_ptr[i0] = fmax(0.0f, x); + } + + if (FC_OP == OP_UNARY_NUM_SIGMOID) { + dst_ptr[i0] = 1.0f / (1.0f + exp(-x)); + } + + if (FC_OP == OP_UNARY_NUM_GELU) { + dst_ptr[i0] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + } + + if (FC_OP == OP_UNARY_NUM_GELU_ERF) { + dst_ptr[i0] = 0.5f*x*(1.0f + erf_approx(SQRT_2_INV*x)); + } + + if (FC_OP == OP_UNARY_NUM_GELU_QUICK) { + dst_ptr[i0] = x * (1.0f/(1.0f + exp(GELU_QUICK_COEF*x))); + } + + if (FC_OP == OP_UNARY_NUM_SILU) { + dst_ptr[i0] = x / (1.0f + exp(-x)); + } + + if (FC_OP == OP_UNARY_NUM_ELU) { + dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(exp(x) - 1.0f); + } + + if (FC_OP == OP_UNARY_NUM_NEG) { + dst_ptr[i0] = -x; + } + + if (FC_OP == OP_UNARY_NUM_ABS) { + dst_ptr[i0] = fabs(x); + } + + if (FC_OP == OP_UNARY_NUM_SGN) { + dst_ptr[i0] = T(x > 0.0f) - T(x < 0.0f); + } + + if (FC_OP == OP_UNARY_NUM_STEP) { + dst_ptr[i0] = T(x > 0.0f); + } + + if (FC_OP == OP_UNARY_NUM_HARDSWISH) { + dst_ptr[i0] = x * fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f)); + } + + if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) { + dst_ptr[i0] = fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f)); + } + + if (FC_OP == OP_UNARY_NUM_EXP) { + dst_ptr[i0] = exp(x); + } + + if (FC_OP == OP_UNARY_NUM_SOFTPLUS) { + dst_ptr[i0] = select(log(1.0f + exp(x)), x, x > 20.0f); + } + + if (FC_OP == OP_UNARY_NUM_EXPM1) { + // TODO: precise implementation + dst_ptr[i0] = exp(x) - 1.0f; + } + } + +#undef FC_OP +#undef FC_CNT +} + +typedef decltype(kernel_unary_impl) kernel_unary_t; + +template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl; + + // OP: 0 - add, 1 - sub, 2 - mul, 3 - div constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; constant short FC_bin_f [[function_constant(FC_BIN + 1)]]; @@ -1114,414 +1300,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; -kernel void kernel_scale_f32( - constant ggml_metal_kargs_scale & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * args.scale + args.bias; -} - -kernel void kernel_scale_f32_4( - constant ggml_metal_kargs_scale & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * args.scale + args.bias; -} - -kernel void kernel_fill_f32( - constant ggml_metal_kargs_fill & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = args.val; -} - -kernel void kernel_fill_f32_4( - constant ggml_metal_kargs_fill & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = args.val; -} - -kernel void kernel_clamp_f32( - constant ggml_metal_kargs_clamp & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = clamp(src0[tpig], args.min, args.max); -} - -kernel void kernel_clamp_f32_4( - constant ggml_metal_kargs_clamp & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = clamp(src0[tpig], args.min, args.max); -} - -kernel void kernel_relu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - -kernel void kernel_relu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - -kernel void kernel_sigmoid_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); -} - -kernel void kernel_sigmoid_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); -} - -kernel void kernel_tanh_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = precise::tanh(src0[tpig]); -} - -kernel void kernel_tanh_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = precise::tanh(src0[tpig]); -} - -constant float GELU_COEF_A = 0.044715f; -constant float GELU_QUICK_COEF = -1.702f; -constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; -constant float SQRT_2_INV = 0.70710678118654752440084436210484f; - -kernel void kernel_gelu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - // BEWARE !!! - // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! - // This was observed with Falcon 7B and 40B models - // - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_quick_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -kernel void kernel_gelu_quick_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation -// ref: https://www.johndcook.com/blog/python_erf/ -constant float p_erf = 0.3275911f; -constant float a1_erf = 0.254829592f; -constant float a2_erf = -0.284496736f; -constant float a3_erf = 1.421413741f; -constant float a4_erf = -1.453152027f; -constant float a5_erf = 1.061405429f; - -template -T erf_approx(T x) { - T sign_x = sign(x); - x = fabs(x); - T t = 1.0f / (1.0f + p_erf * x); - T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); - return sign_x * y; -} - -kernel void kernel_gelu_erf_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); -} - -kernel void kernel_gelu_erf_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); -} - -kernel void kernel_silu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_silu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_elu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f); -} - -kernel void kernel_elu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f); - dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f); - dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f); - dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f); -} - -kernel void kernel_sqr_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src0[tpig]; -} - -kernel void kernel_sqr_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src0[tpig]; -} - -kernel void kernel_sqrt_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sqrt(src0[tpig]); -} - -kernel void kernel_sqrt_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sqrt(src0[tpig]); -} - -kernel void kernel_sin_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sin(src0[tpig]); -} - -kernel void kernel_sin_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sin(src0[tpig]); -} - -kernel void kernel_cos_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = cos(src0[tpig]); -} - -kernel void kernel_cos_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = cos(src0[tpig]); -} - -kernel void kernel_log_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = log(src0[tpig]); -} - -kernel void kernel_log_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = log(src0[tpig]); -} - -kernel void kernel_neg_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = -src0[tpig]; -} - -kernel void kernel_neg_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = -src0[tpig]; -} - -kernel void kernel_abs_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = fabs(src0[tpig]); -} - -kernel void kernel_abs_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = fabs(src0[tpig]); -} - -kernel void kernel_sgn_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sign(src0[tpig]); -} - -kernel void kernel_sgn_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sign(src0[tpig]); -} - -kernel void kernel_step_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = step(0.0f, src0[tpig]); -} - -kernel void kernel_step_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = step(0.0f, src0[tpig]); -} - -kernel void kernel_hardswish_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} - -kernel void kernel_hardswish_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} - -kernel void kernel_hardsigmoid_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} - -kernel void kernel_hardsigmoid_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} - -kernel void kernel_exp_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]); -} - -kernel void kernel_exp_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]); -} - -kernel void kernel_softplus_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); -} - -kernel void kernel_softplus_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); -} - -kernel void kernel_expm1_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]) - 1.0f; -} - -kernel void kernel_expm1_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]) - 1.0f; -} - kernel void kernel_reglu_f32( constant ggml_metal_kargs_glu & args, device const char * src0, @@ -5072,24 +4850,6 @@ kernel void kernel_argsort_merge_f32_i32( template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; -kernel void kernel_leaky_relu_f32( - constant ggml_metal_kargs_leaky_relu & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = x > 0.0f ? x : x * args.slope; -} - -kernel void kernel_leaky_relu_f32_4( - constant ggml_metal_kargs_leaky_relu & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope); -} - constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]]; constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]]; @@ -9939,7 +9699,7 @@ kernel void kernel_opt_step_sgd_f32( template kernel void kernel_memset( - constant ggml_metal_kargs_fill & args, + constant ggml_metal_kargs_memset & args, device T * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = args.val; From 350435805607207934b4e79554012649e8d0cc53 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 11 Feb 2026 07:52:00 +0200 Subject: [PATCH 129/831] ggml : extend bin bcast for permuted src1 (llama/19484) * tests : extend bin bcast for permuted src1 * cont : extend bin support * cont : s0 is always 1 * tests : simplify --- ggml/src/ggml-cpu/binary-ops.cpp | 8 ++--- ggml/src/ggml-cuda/binbcast.cu | 62 ++++++++++++++++---------------- 2 files changed, 34 insertions(+), 36 deletions(-) diff --git a/ggml/src/ggml-cpu/binary-ops.cpp b/ggml/src/ggml-cpu/binary-ops.cpp index 14f5b43ae0e..75e38290015 100644 --- a/ggml/src/ggml-cpu/binary-ops.cpp +++ b/ggml/src/ggml-cpu/binary-ops.cpp @@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds GGML_ASSERT(nb00 == sizeof(src0_t)); const auto [ir0, ir1] = get_thread_range(params, src0); - const bool is_src1_contiguous = (nb10 == sizeof(src1_t)); - - if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - } + const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1); #ifdef GGML_USE_ACCELERATE vDSP_fn_t vDSP_op = nullptr; @@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - if (is_src1_contiguous) { + if (is_src1_contiguous_rows) { // src1 is broadcastable across src0 and dst in i1, i2, i3 const int64_t nr0 = ne00 / ne10; diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 0e6d777b1e6..7339fe0c070 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t * src0, const uint3 ne11, const uint3 ne12, const uint3 ne13, - /*int s0, */ const int s1, + /*const int s0,*/ + const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, + const int s00, + const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, + const int s10, + const int s11, const int s12, const int s13, src1_ptrs... src1s) { @@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t * src0, for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { const uint32_t i10 = fastmodulo(i0, ne10); - float result = src0_row ? (float) src0_row[i0] : 0.0f; + float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10]); + result = bin_op(result, (float)src1[i_src1 + i10*s10]); } dst_row[i0] = (dst_t) result; @@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const uint3 ne11, const uint3 ne12, const uint3 ne13, - /*int s0, */ const int s1, + /*const int s0,*/ + const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, + const int s00, + const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, + const int s10, + const int s11, const int s12, const int s13, src1_ptrs... src1s) { @@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const int i10 = fastmodulo(i0, ne10); - float result = src0_row ? (float) src0_row[i0] : 0.0f; + float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10]); + result = bin_op(result, (float)src1[i_src1 + i10*s10]); } dst_row[i0] = (dst_t) result; @@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * cnb[3] *= cne[3]; }; - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { for (int i = 0; i < 4; i++) { if (nr[i] != 1) { break; @@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * size_t nb12 = cnb1[2]; size_t nb13 = cnb1[3]; - size_t s0 = nb0 / sizeof(dst_t); + //size_t s0 = nb0 / sizeof(dst_t); size_t s1 = nb1 / sizeof(dst_t); size_t s2 = nb2 / sizeof(dst_t); size_t s3 = nb3 / sizeof(dst_t); @@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * GGML_ASSERT(nb12 % sizeof(src1_t) == 0); GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s00 == 1); - GGML_ASSERT(s10 == 1); - const int block_size = 128; int64_t hne0 = std::max(ne0 / 2LL, 1LL); @@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * k_bin_bcast_unravel<<>>( src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } else { k_bin_bcast_unravel <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13); + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13); } } else { const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); if constexpr (sizeof...(I) > 0) { k_bin_bcast<<>>( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + /*s0,*/ s1, s2, s3, + s00 ,s01, s02, s03, + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } else { k_bin_bcast<<>>( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13); + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13); } } } From 09587ceb12a4107c0075b3a8fd5699984689abf6 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Tue, 10 Feb 2026 23:21:12 -0800 Subject: [PATCH 130/831] hexagon: Add ARGSORT, DIV, SQR, SQRT, SUM_ROWS, GEGLU (llama/19406) * hexagon: add ARGSORT op Co-authored-by: Yarden Tal * hexagon: argsort reject tensors with huge rows for now * Adding support for DIV,SQR,SQRT,SUM_ROWS ops in hexagon backend * hexagon : Add GEGLU op * hexagon: fix editor config check * hexagon: rewrite and optimize binary ops ADD/SUB/MUL/DIV/ADD_ID to use DMA --------- Co-authored-by: Yarden Tal Co-authored-by: Manohara Hosakoppa Krishnamurthy --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 113 ++- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 3 + ggml/src/ggml-hexagon/htp/act-ops.c | 152 +++- ggml/src/ggml-hexagon/htp/argsort-ops.c | 281 +++++++ ggml/src/ggml-hexagon/htp/binary-ops.c | 918 +++++++++++++++++------ ggml/src/ggml-hexagon/htp/htp-msg.h | 64 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 2 + ggml/src/ggml-hexagon/htp/hvx-arith.h | 251 ++++--- ggml/src/ggml-hexagon/htp/hvx-base.h | 6 + ggml/src/ggml-hexagon/htp/hvx-copy.h | 2 - ggml/src/ggml-hexagon/htp/hvx-div.h | 116 +++ ggml/src/ggml-hexagon/htp/hvx-sigmoid.h | 27 + ggml/src/ggml-hexagon/htp/hvx-sqrt.h | 68 +- ggml/src/ggml-hexagon/htp/hvx-utils.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 107 +++ ggml/src/ggml-hexagon/htp/sum-rows-ops.c | 115 +++ ggml/src/ggml-hexagon/htp/unary-ops.c | 64 ++ 17 files changed, 1904 insertions(+), 386 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/argsort-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/hvx-div.h create mode 100644 ggml/src/ggml-hexagon/htp/sum-rows-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 4f0a1620fbf..54f9986498f 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1935,11 +1935,6 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se return false; } - // TODO: add support for non-contigiuos tensors - if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { - return false; - } - return true; } @@ -1991,6 +1986,25 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses return true; } +static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (!hex_supported_src0_type(src0->type)) { + return false; + } + if (!hex_supported_dst_type(dst->type)) { + return false; + } + + // TODO: add support for non-contigiuos tensors + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { + return false; + } + + return true; +} + static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * src0 = op->src[0]; @@ -2111,6 +2125,26 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * return true; } +static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // values + const struct ggml_tensor * dst = op; // indices + + if (src0->type != GGML_TYPE_F32) { + return false; + } + + if (dst->type != GGML_TYPE_I32) { + return false; + } + + if (src0->ne[0] > (16*1024)) { + // reject tensors with huge rows for now + return false; + } + + return true; +} + static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const int32_t * op_params = &op->op_params[0]; @@ -2278,6 +2312,9 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu case GGML_OP_SUB: req->op = HTP_OP_SUB; break; + case GGML_OP_DIV: + req->op = HTP_OP_DIV; + break; default: GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op); break; @@ -2316,6 +2353,17 @@ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * return n_bufs; } +static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_ARGSORT; + memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + template static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { switch (t->op) { @@ -2370,6 +2418,16 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf supported = true; break; + case GGML_OP_SQR: + req->op = HTP_OP_SQR; + supported = true; + break; + + case GGML_OP_SQRT: + req->op = HTP_OP_SQRT; + supported = true; + break; + case GGML_OP_UNARY: if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { req->op = HTP_OP_UNARY_SILU; @@ -2387,6 +2445,9 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) { req->op = HTP_OP_GLU_SWIGLU_OAI; supported = true; + } else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) { + req->op = HTP_OP_GLU_GEGLU; + supported = true; } break; @@ -2411,6 +2472,17 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf return n_bufs; } +static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); + req->op = HTP_OP_SUM_ROWS; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); req->op = HTP_OP_ROPE; @@ -2519,6 +2591,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg case GGML_OP_MUL: case GGML_OP_ADD: case GGML_OP_SUB: + case GGML_OP_DIV: ggml_hexagon_dispatch_op>(sess, node, flags); break; case GGML_OP_ADD_ID: @@ -2528,6 +2601,13 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg case GGML_OP_SCALE: ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_SQR: + case GGML_OP_SQRT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + case GGML_OP_SUM_ROWS: + ggml_hexagon_dispatch_op(sess, node, flags); + break; case GGML_OP_UNARY: if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) || (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) { @@ -2536,7 +2616,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg break; case GGML_OP_GLU: if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) || - (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) { + (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) || + (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) { ggml_hexagon_dispatch_op(sess, node, flags); } break; @@ -2564,6 +2645,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_ARGSORT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); } @@ -2916,6 +3001,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons case GGML_OP_MUL: case GGML_OP_ADD: case GGML_OP_SUB: + case GGML_OP_DIV: supp = ggml_hexagon_supported_binary(sess, op); break; @@ -2928,6 +3014,15 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_unary(sess, op); break; + case GGML_OP_SQR: + case GGML_OP_SQRT: + supp = ggml_hexagon_supported_unary(sess, op); + break; + + case GGML_OP_SUM_ROWS: + supp = ggml_hexagon_supported_sum_rows(sess, op); + break; + case GGML_OP_SOFT_MAX: supp = ggml_hexagon_supported_softmax(sess, op); break; @@ -2943,7 +3038,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons case GGML_OP_GLU: { const auto glu_op = ggml_get_glu_op(op); - if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) { + if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) { supp = ggml_hexagon_supported_activations(sess, op); } break; @@ -2968,6 +3063,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cpy(sess, op); break; + case GGML_OP_ARGSORT: + supp = ggml_hexagon_supported_argsort(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index e8ef203045c..2c23b60da3d 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) include_directories( ${HEXAGON_SDK_ROOT}/incs ${HEXAGON_SDK_ROOT}/incs/stddef + ${CMAKE_CURRENT_SOURCE_DIR}/../../../include ${CMAKE_CURRENT_SOURCE_DIR}/../.. ${CMAKE_CURRENT_SOURCE_DIR}/.. ${CMAKE_CURRENT_SOURCE_DIR} @@ -21,6 +22,7 @@ add_library(${HTP_LIB} SHARED matmul-ops.c binary-ops.c unary-ops.c + sum-rows-ops.c softmax-ops.c act-ops.c rope-ops.c @@ -28,6 +30,7 @@ add_library(${HTP_LIB} SHARED set-rows-ops.c get-rows-ops.c cpy-ops.c + argsort-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index c3daf5adb2e..950d836ad34 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -410,7 +410,7 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, // gelu = x * sigmoid(1.702 * x) // current implementation hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0); hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); - hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -516,7 +516,7 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, // silu = x * sigmoid(x) hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0); - hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -541,6 +541,143 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static const float GELU_COEF_A = 0.044715f; +static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +static void glu_geglu_f32_per_thread(const struct htp_tensor * src0, + const struct htp_tensor * src1, + struct htp_tensor * dst, + const int32_t * op_params, + struct htp_spad * src0_spad, + struct htp_spad * src1_spad, + struct htp_spad * dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { + htp_act_preamble3; + + size_t src0_row_size = nb01; + size_t src1_row_size = nb11; + size_t dst_row_size = nb1; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; + const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + const bool src1_valid = src1->ne[0]; + const int nc = (src1_valid) ? ne00 : ne00 / 2; + if (!src1_valid) { + const int32_t swapped = op_params[1]; + data_src1 = data_src0; + src1_row_size = src0_row_size; + + const size_t nc_in_bytes = nc * SIZEOF_FP32; + data_src0 += swapped ? nc_in_bytes : 0; + data_src1 += swapped ? 0 : nc_in_bytes; + } + + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); + uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); + uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + + // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 + size_t src0_spad_half_size = src0_spad->size_per_thread / 2; + size_t src1_spad_half_size = src1_spad->size_per_thread / 2; + size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + + const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + if (BLOCK == 0) { + FARF(ERROR, + "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + src0_spad->size_per_thread, src0_row_size_aligned); + return; + } + + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + // Dummy DMA transation for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)), + src1_row_size_aligned, src1_row_size, block_size); + } + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + float * dst_spad = (float *) dma_queue_pop(dma_queue).src; + float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + float * src1_spad = (float *) dma_queue_pop(dma_queue).dst; + + for (uint32_t ib = 0; ib < block_size; ib++) { + const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float))); + const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float))); + uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float))); + + // geglu tanh implementation + // geglu(x, g) = gelu(x) * g + // gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))) + hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc); // res = x*x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc); // res = res * GELU_COEF_A + hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f + hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = result * SQRT_2_OVER_PI + hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // res = tanh(res) + hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f + hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc); // res = res + 0.5f + hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc); // res = res * g + } + + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, + dst_row_size_aligned, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)), + src1_row_size_aligned, src1_row_size, pref_block_size); + } + } + + dma_queue_flush(dma_queue); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void unary_silu_f32(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = (struct htp_ops_context *) data; unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, @@ -559,6 +696,12 @@ static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) { &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); } +static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, + &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + static int execute_op_activations_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; @@ -593,6 +736,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { act_op_func = unary_gelu_f32; op_type = "gelu-f32"; break; + + case HTP_OP_GLU_GEGLU: + act_op_func = glu_geglu_f32; + op_type = "geglu-f32"; + break; default: FARF(ERROR, "Unsupported activations Op %u\n", octx->op); return HTP_STATUS_NO_SUPPORT; diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c new file mode 100644 index 00000000000..a4cee980be8 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -0,0 +1,281 @@ +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "ggml.h" + +#include "hvx-utils.h" +#include "hex-dma.h" + +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +struct htp_argsort_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; +}; + +static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y) +{ + const HVX_Vector one = Q6_V_vsplat_R(1); + const HVX_Vector zero = Q6_V_vzero(); + + HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y); + HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero); + HVX_Vector sum = hvx_vec_reduce_sum_i32(matches); + return hvx_vec_get_i32(sum) == 32; +} + +// Sorts values and mirrors swaps to indices. +static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) { + if (left >= right) return; + + int pivot_idx = (left + right) / 2; + float pivot = values[pivot_idx]; + int i = left; + int j = right; + + HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot); + while (i <= j) { + // Vectorized scan for i + while (i <= j) { + // Check if we have at least one full vector + if (i + 32 <= j) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + i); + if (all_greater_f32(pivot_vec, vals_vec)) { + // If all elements are < pivot, we can skip this whole block + i += 32; + continue; + } + } + + // Scalar fallback / cleanup + if (values[i] < pivot) { + i++; + } else { + break; + } + } + + // Vectorized scan for j + while (i <= j) { + if (j - 32 >= i) { + // Load 32 elements ending at j. + // Since we want `values[j] > pivot`, let's load from j-31 to j. + HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31); + if (all_greater_f32(vals_vec, pivot_vec)) { + j -= 32; + continue; + } + } + + if (values[j] > pivot) { + j--; + } else { + break; + } + } + + if (i <= j) { + float tmp_val = values[i]; + values[i] = values[j]; + values[j] = tmp_val; + + int32_t tmp_idx = indices[i]; + indices[i] = indices[j]; + indices[j] = tmp_idx; + i++; + j--; + } + } + + if (left < j) quicksort_values_indices_asc(values, indices, left, j); + if (i < right) quicksort_values_indices_asc(values, indices, i, right); +} + +static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) { + if (left >= right) return; + + int pivot_idx = (left + right) / 2; + float pivot = values[pivot_idx]; + int i = left; + int j = right; + + HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot); + + while (i <= j) { + // Vectorized scan for i (values[i] > pivot) + while (i <= j) { + if (i + 32 <= j) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + i); + if (all_greater_f32(vals_vec, pivot_vec)) { + i += 32; + continue; + } + } + + if (values[i] > pivot) { + i++; + } else { + break; + } + } + + // Vectorized scan for j (values[j] < pivot) + while (i <= j) { + if (j - 32 >= i) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31); + if (all_greater_f32(pivot_vec, vals_vec)) { + j -= 32; + continue; + } + } + + if (values[j] < pivot) { + j--; + } else { + break; + } + } + + if (i <= j) { + float tmp_val = values[i]; + values[i] = values[j]; + values[j] = tmp_val; + + int32_t tmp_idx = indices[i]; + indices[i] = indices[j]; + indices[j] = tmp_idx; + i++; + j--; + } + } + + if (left < j) quicksort_values_indices_desc(values, indices, left, j); + if (i < right) quicksort_values_indices_desc(values, indices, i, right); +} + +static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { + struct htp_argsort_context * actx = (struct htp_argsort_context *)data; + struct htp_ops_context * octx = actx->octx; + + // Unpack context + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * dst = &octx->dst; + + // Scratchpad memory + uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i; + + // Dimensions + uint32_t ne00 = src0->ne[0]; + uint32_t ne01 = src0->ne[1]; + uint32_t ne02 = src0->ne[2]; + uint32_t ne03 = src0->ne[3]; + + uint32_t nb01 = src0->nb[1]; + //uint32_t nb02 = src0->nb[2]; + //uint32_t nb03 = src0->nb[3]; + + uint32_t nb1 = dst->nb[1]; + //uint32_t nb2 = dst->nb[2]; + //uint32_t nb3 = dst->nb[3]; + + // Sort order + enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0]; + + // Rows to process + uint32_t total_rows = ne01 * ne02 * ne03; + uint32_t rows_per_thread = actx->nrows_per_thread; + uint32_t start_row = rows_per_thread * i; + uint32_t end_row = MIN(start_row + rows_per_thread, total_rows); + + // Scratchpad layout: + // We need space for one row of float data (values) and one row of int32 indices. + // values: ne00 * sizeof(float) + // indices: ne00 * sizeof(int32_t) + // Padded to 128 bytes. + + size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + float * values_buf = (float *) spad; + int32_t * indices_buf = (int32_t *) (spad + values_size); + + for (uint32_t r = start_row; r < end_row; r++) { + uint32_t src_offset = r * nb01; + uint32_t dst_offset = r * nb1; + + uint8_t * src_ptr = (uint8_t *) src0->data + src_offset; + uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset; + + hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1); + hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00); + + // Initialize indices + for (uint32_t j = 0; j < ne00; j++) { + indices_buf[j] = j; + } + + // Sort values and mirror swaps to indices + if (order == GGML_SORT_ORDER_ASC) { + quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1); + } else { + quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1); + } + + // Copy indices back to DDR + hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00); + } +} + +int op_argsort(struct htp_ops_context * octx) { + // Check supported types + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + // Allocate scratchpad + // We need 1 row of float + 1 row of int32 per thread. + uint32_t ne00 = octx->src0.ne[0]; + size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128); + size_t spad_per_thread = values_size + indices_size; + + // Make sure we round up to 256 for alignment requirements + spad_per_thread = hex_round_up(spad_per_thread, 256); + + size_t total_spad_size = spad_per_thread * octx->n_threads; + + if (octx->ctx->vtcm_size < total_spad_size) { + FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src0_spad.size = total_spad_size; + octx->src0_spad.size_per_thread = spad_per_thread; + + FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", + octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3], + octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3], + octx->src0.data, octx->dst.data); + + uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; + uint32_t n_jobs = MIN(total_rows, octx->n_threads); + + struct htp_argsort_context actx; + actx.octx = octx; + actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs; + + // Run jobs + worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c index de22afe460e..00dbcf87986 100644 --- a/ggml/src/ggml-hexagon/htp/binary-ops.c +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -17,15 +17,37 @@ #include "htp-msg.h" #include "htp-ops.h" -typedef void (*hvx_elemwise_f32_func)(uint8_t * data_dst, const uint8_t * src0, const uint8_t * src1, const uint32_t num_elems); - -static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 }; -static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_aa, hvx_add_f32_aa, hvx_sub_f32_aa }; +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +// Context for binary operations +struct htp_binary_context { + struct htp_ops_context * octx; + struct fastdiv_values dim1_div; + struct fastdiv_values dim2_div; + struct fastdiv_values dim12_div; + + struct fastdiv_values src1_dim1_div; // ne11 + struct fastdiv_values src1_dim2_div; // ne12 + struct fastdiv_values src1_dim3_div; // ne13 + + uint32_t nrows_per_thread; + bool split_at_ne01; + bool split_at_ne02; + + // Precomputed values + uint32_t block_max; + size_t src0_row_size_aligned; + size_t src1_row_size_aligned; + size_t dst_row_size_aligned; + uint32_t src1_fetch_rows; // 1 or block_max + uint32_t src1_dma_stride; // 0 or stride +}; #define htp_binary_preamble \ const struct htp_tensor * src0 = &octx->src0; \ const struct htp_tensor * src1 = &octx->src1; \ - const struct htp_tensor * src2 = &octx->src2; \ struct htp_tensor * dst = &octx->dst; \ \ const uint32_t ne00 = src0->ne[0]; \ @@ -38,266 +60,696 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_aa, hvx_add_f3 const uint32_t ne12 = src1->ne[2]; \ const uint32_t ne13 = src1->ne[3]; \ \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ const uint32_t nb01 = src0->nb[1]; \ const uint32_t nb02 = src0->nb[2]; \ const uint32_t nb03 = src0->nb[3]; \ \ - const uint32_t nb10 = src1->nb[0]; \ const uint32_t nb11 = src1->nb[1]; \ const uint32_t nb12 = src1->nb[2]; \ const uint32_t nb13 = src1->nb[3]; \ \ - const uint32_t nb0 = dst->nb[0]; \ const uint32_t nb1 = dst->nb[1]; \ const uint32_t nb2 = dst->nb[2]; \ - const uint32_t nb3 = dst->nb[3]; \ - \ - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + const uint32_t nb3 = dst->nb[3]; -static void binary_job_f32_per_thread(struct htp_ops_context * octx, - uint8_t * spad_data, - uint32_t nth, - uint32_t ith, - enum htp_op op) { - htp_binary_preamble; +static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, + uint32_t ne01, uint32_t ne02) { + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; - const size_t src0_row_size = nb01; - const size_t src1_row_size = nb11; - const size_t dst_row_size = nb1; + uint32_t rows_left = end_row - ir; + uint32_t block_limit = rows_left; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows - const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows + if (bctx->split_at_ne01) { + block_limit = MIN(block_limit, ne01 - i01); + } + if (bctx->split_at_ne02) { + uint32_t rows_in_plane = (ne02 * ne01) - rem; + block_limit = MIN(block_limit, rows_in_plane); + } - const uint32_t src0_start_row = src0_nrows_per_thread * ith; - const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + return MIN(bctx->block_max, block_limit); +} - // no work for this thread - if (src0_start_row >= src0_end_row) { - return; +// Macro for scalar op switch +#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \ + case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \ + case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \ + case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \ + default: break; \ } - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); +// Macro for vector op switch (All Aligned) +#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \ + default: break; \ + } - int is_aligned = 1; - int opt_path = 0; - if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) || - (0 == hex_is_aligned((void *) dst->data, VLEN))) { - is_aligned = 0; +// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned) +#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \ + default: break; \ } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; + +// Macro for vector op switch (All Unaligned - generic loop used in element repeat) +#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \ + default: break; \ } - hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? func_table_HVX_opt[op] : func_table_HVX[op]; +// 1. Scalar src1 (ne10 == 1) +static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; - uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size); + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + // Preamble + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } - const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size); - uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size); + // Main loop + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + // src1 indices (broadcast/repeat) + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_dim1_div); + + uint8_t * src1_ptr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint32_t s1_stride = (ne11 == 1) ? 0 : nb11; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + float val = *(float *)src1_ptr; + src1_ptr += s1_stride; + COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00); + } - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); - const uint32_t ne02_ne01 = ne02 * ne01; + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { - const uint32_t i03 = fastdiv(ir, &octx->src0_div21); - const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1); - const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} - const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3); - const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2); - const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1); +// 2. Vector Same Shape (ne1x == ne0x) or Simple Broadcast +static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; - const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size; + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t src1_spad_half = octx->src1_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint32_t i13 = (ne13 == 1) ? 0 : i03; + uint32_t i12 = (ne12 == 1) ? 0 : i02; + uint32_t i11 = (ne11 == 1) ? 0 : i01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } - if (ir + 1 < src0_end_row) { - hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1); - if (src1_row_size == src0_row_size) { - hex_l2fetch(src1_ptr, src1_row_size, src1_row_size, 1); - } + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00); } - const uint32_t nr0 = ne00 / ne10; - if (nr0 > 1) { - if ((1 == is_aligned) && (nr0 == ne00)) { - hvx_splat_f32_a(spad_data_th, *(float *) src1_ptr, nr0); - } else { - for (uint32_t r = 0; r < nr0; r++) { - memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11); - } - } - func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, ne00); - } else { - func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + + uint32_t p13 = (ne13 == 1) ? 0 : p03; + uint32_t p12 = (ne12 == 1) ? 0 : p02; + uint32_t p11 = (ne11 == 1) ? 0 : p01; + + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11; + + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size); + + ir_prefetch += next_block_size; } - - src0_ptr += src0_row_size; - dst_ptr += dst_row_size; + ir += current_block_size; } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, - ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + dma_queue_flush(q); } -static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx, - uint8_t * spad_data, - uint32_t nth, - uint32_t ith, - hvx_elemwise_f32_func func_HVX) { +// 3. Row Broadcast (ne11 == 1, ne12 == 1, single row src1) +static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const size_t src0_row_size = nb01; - const size_t src1_row_size = nb11; - const size_t dst_row_size = nb1; + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - const uint32_t src0_start_row = src0_nrows_per_thread * ith; - const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; - // no work for this thread - if (src0_start_row >= src0_end_row) { - return; - } + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + void * s1_ptr = (void *) src1_spad; - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; - const uint32_t ne02_ne01 = ne02 * ne01; - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { - // src0 indices - const uint32_t i03 = fastdiv(ir, &octx->src0_div21); - const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1); - const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01); + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - // src1 indices - const int i11 = *(int32_t *) ((char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]); - assert(i11 >= 0 && i11 < ne11); + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - float * restrict dst_ptr = (float *) (data_dst + i03 * nb3 + i02 * nb2 + i01 * nb1); - const float * restrict src0_ptr = (const float *) (data_src0 + i03 * nb03 + i02 * nb02 + i01 * nb01); - const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11); + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } - if (ir + 1 < src0_end_row) { - hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1); - if (src1_row_size == src0_row_size) { - hex_l2fetch(src1_ptr + ne10, src1_row_size, src1_row_size, 1); - } + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00); } - const uint32_t nr0 = ne00 / ne10; - if (nr0 > 1) { - for (uint32_t r = 0; r < nr0; r++) { - memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10); - } - func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data, ne00); - } else { - func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 4. Vector Complex (ne10 == ne00, complex broadcast) +static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; } - t2 = HAP_perf_get_qtimer_count(); + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div); + + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + + // Read src1 from DDR (unaligned) + COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00); + } - FARF(HIGH, "add-id-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], - src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], - dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); } -static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; +// 5. Element Repeat (ne10 != ne00) +static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; - switch (octx->op) { - case HTP_OP_MUL: - case HTP_OP_ADD: - case HTP_OP_SUB: - binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op); - break; + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } - case HTP_OP_ADD_ID: - binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32); - break; + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div); + + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1_row = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + + // Repeat src1 row + for (uint32_t c = 0; c < ne00; c += ne10) { + uint32_t len = MIN(ne10, ne00 - c); + // Use UUU for speed and simplicity + COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len); + } + } - default: - FARF(ERROR, "Unknown Binary Op %u", octx->op); - break; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; } + dma_queue_flush(q); } -static int execute_op_binary_f32(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; +// 6. ADD_ID (src1 gathered via src2 indices) +static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * src2 = &octx->src2; struct htp_tensor * dst = &octx->dst; - worker_callback_t binary_op_func; - const char * op_type = NULL; - - switch (octx->op) { - case HTP_OP_MUL: - binary_op_func = binary_job_dispatcher_f32; - op_type = "mul-f32"; - break; - - case HTP_OP_ADD: - binary_op_func = binary_job_dispatcher_f32; - op_type = "add-f32"; - break; - - case HTP_OP_SUB: - binary_op_func = binary_job_dispatcher_f32; - op_type = "sub-f32"; - break; - - case HTP_OP_ADD_ID: - binary_op_func = binary_job_dispatcher_f32; - op_type = "add-id-f32"; - break; - - default: - FARF(ERROR, "Unsupported binary-Op %u\n", octx->op); - return HTP_STATUS_NO_SUPPORT; + const uint32_t ne00 = src0->ne[0]; + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; + const uint32_t ne03 = src0->ne[3]; + const uint32_t ne11 = src1->ne[1]; // for bounds check + + const uint32_t nb01 = src0->nb[1]; + const uint32_t nb02 = src0->nb[2]; + const uint32_t nb03 = src0->nb[3]; + const uint32_t nb11 = src1->nb[1]; // src1 row stride + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; // linear within block since we split at ne01 + + const int32_t idx = *(int32_t *)((char *)src2->data + r_i01 * src2->nb[0] + i02 * src2->nb[1]); + + uint8_t * r_src1 = (uint8_t *)src1->data + idx * nb11; + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + + hvx_add_f32_aau(r_dst, r_src0, r_src1, ne00); + } + + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; } + dma_queue_flush(q); +} + +static int execute_op_binary_f32(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + struct htp_tensor * dst = &octx->dst; - const int n_threads = octx->n_threads; + const uint32_t n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; - const size_t src0_row_size = src0->nb[1]; - const size_t src1_row_size = src1->nb[1]; - const size_t dst_row_size = dst->nb[1]; + // Use packed row sizes for VTCM allocation + const size_t src0_row_size = src0->ne[0] * sizeof(float); + const size_t src1_row_size = src1->ne[0] * sizeof(float); + const size_t dst_row_size = dst->ne[0] * sizeof(float); + + // Align to VLEN + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + + bool is_add_id = (octx->op == HTP_OP_ADD_ID); + bool is_scalar = !is_add_id && (src1->ne[0] == 1); + + // Determine which kernel we will use to alloc memory and dispatch + bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] && + (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) && + (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) && + (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1); + + bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1); + bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]); + bool use_repeat = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]); + + size_t spad_row_total; + if (is_scalar) { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); + } else if (is_row_bcast) { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); + } else if (use_vector_same) { + spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned); + } else if (is_add_id) { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly + } else { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); + } - // VTCM scratchpads for all tensors - octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; + size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total); + // Adjust for static src1 in row_bcast case + if (is_row_bcast) { + size_t needed_static = src1_row_size_aligned; + if (octx->ctx->vtcm_size < needed_static) return HTP_STATUS_VTCM_TOO_SMALL; + size_t avail = octx->ctx->vtcm_size - needed_static; + rows_per_buffer = avail / (n_threads * spad_row_total); + } - size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + if (rows_per_buffer < 1) { + FARF(ERROR, "binary-f32: VTCM too small\n"); + return HTP_STATUS_VTCM_TOO_SMALL; + } - FARF(HIGH, - "%s: (%ux%ux%ux%u) * (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", - op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], - src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, - octx->dst_spad.size); + octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned; + octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned; - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "binary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, - octx->ctx->vtcm_size, spad_size); + if (is_scalar || use_complex || use_repeat || is_add_id) { + octx->src1_spad.size_per_thread = 0; + } else if (is_row_bcast) { + octx->src1_spad.size_per_thread = 0; + } else { + octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned; + } + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + if (is_row_bcast) { + octx->src1_spad.size = src1_row_size_aligned; + } else { + octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; + } + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) { return HTP_STATUS_VTCM_TOO_SMALL; } @@ -305,39 +757,71 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); + if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + return HTP_STATUS_OK; + } + + uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + dma_queue * q = octx->ctx->dma[0]; + if (is_row_bcast) { + dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1); + } - octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]); - octx->src0_div3 = init_fastdiv_values(src0->ne[3]); - octx->src0_div2 = init_fastdiv_values(src0->ne[2]); - octx->src0_div1 = init_fastdiv_values(src0->ne[1]); + struct htp_binary_context bctx; + bctx.octx = octx; + bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + bctx.block_max = rows_per_buffer; + bctx.src0_row_size_aligned = src0_row_size_aligned; + bctx.src1_row_size_aligned = src1_row_size_aligned; + bctx.dst_row_size_aligned = dst_row_size_aligned; - octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]); - octx->src1_div3 = init_fastdiv_values(src1->ne[3]); - octx->src1_div2 = init_fastdiv_values(src1->ne[2]); - octx->src1_div1 = init_fastdiv_values(src1->ne[1]); + bctx.dim1_div = init_fastdiv_values(src0->ne[1]); + bctx.dim2_div = init_fastdiv_values(src0->ne[2]); + bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]); - worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs); - } + bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]); + bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]); + bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]); - return err; -} + bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]); + bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]); -int op_binary(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; + bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]); + bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]); - switch (octx->src0.type) { - case HTP_TYPE_F32: - err = execute_op_binary_f32(octx); - break; + bctx.split_at_ne01 = (src0->ne[2] > 1) && + ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1); - default: - err = HTP_STATUS_NO_SUPPORT; - break; + bctx.split_at_ne02 = (src0->ne[3] > 1) && + ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2); + + // Precompute specific kernel parameters + if (use_vector_same) { + bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1]; + bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer; } - return err; + worker_callback_t worker_func; + if (is_add_id) worker_func = binary_job_add_id; + else if (is_scalar) worker_func = binary_job_scalar; + else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast; + else if (use_vector_same) worker_func = binary_job_vector_same_shape; + else if (use_complex) worker_func = binary_job_vector_complex; + else worker_func = binary_job_element_repeat; + + if (is_row_bcast) { + dma_queue_pop(q); + } + + worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs); + + return HTP_STATUS_OK; +} + +int op_binary(struct htp_ops_context * octx) { + if (octx->src0.type == HTP_TYPE_F32) { + return execute_op_binary_f32(octx); + } + return HTP_STATUS_NO_SUPPORT; } diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index f49e8ee4478..25403bb1126 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -42,32 +42,36 @@ enum htp_data_type { HTP_TYPE_COUNT }; -// These values are manually translated over to HTP -// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!! +// Do not reorder first 4 (used as an index) enum htp_op { - HTP_OP_MUL = 0, - HTP_OP_ADD = 1, - HTP_OP_SUB = 2, - HTP_OP_DIV = 3, - HTP_OP_MUL_MAT = 4, - HTP_OP_MUL_MAT_ID = 5, - HTP_OP_RMS_NORM = 6, - HTP_OP_UNARY_SILU = 7, - HTP_OP_UNARY_GELU = 8, - HTP_OP_GLU_SWIGLU = 9, - HTP_OP_GLU_SWIGLU_OAI = 10, - HTP_OP_SOFTMAX = 11, - HTP_OP_ADD_ID = 12, - HTP_OP_ROPE = 13, - HTP_OP_FLASH_ATTN_EXT = 14, - HTP_OP_SET_ROWS = 15, - HTP_OP_SCALE = 16, - HTP_OP_GET_ROWS = 17, - HTP_OP_CPY = 18, + HTP_OP_MUL = 0, + HTP_OP_ADD = 1, + HTP_OP_SUB = 2, + HTP_OP_DIV = 3, + HTP_OP_MUL_MAT, + HTP_OP_MUL_MAT_ID, + HTP_OP_RMS_NORM, + HTP_OP_UNARY_SILU, + HTP_OP_UNARY_GELU, + HTP_OP_GLU_SWIGLU, + HTP_OP_GLU_SWIGLU_OAI, + HTP_OP_GLU_GEGLU, + HTP_OP_SOFTMAX, + HTP_OP_ADD_ID, + HTP_OP_ROPE, + HTP_OP_FLASH_ATTN_EXT, + HTP_OP_SET_ROWS, + HTP_OP_GET_ROWS, + HTP_OP_SCALE, + HTP_OP_CPY, + HTP_OP_ARGSORT, + HTP_OP_SQR, + HTP_OP_SQRT, + HTP_OP_SUM_ROWS, INVALID }; -static inline size_t htp_type_block_size(uint32_t t) { +static inline size_t htp_t_block_size(uint32_t t) { switch (t) { case HTP_TYPE_F32: return 1; @@ -103,22 +107,6 @@ static inline size_t htp_type_nbytes(uint32_t t) { return 0; } -static const char * htp_type_name(uint32_t t) { - switch (t) { - case HTP_TYPE_F32: - return "fp32"; - case HTP_TYPE_F16: - return "fp16"; - case HTP_TYPE_Q4_0: - return "q4_0"; - case HTP_TYPE_Q8_0: - return "q8_0"; - case HTP_TYPE_MXFP4: - return "mxfp4"; - } - return 0; -} - // Internal types #define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) #define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 602a2775a47..c0d72587ce5 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -90,6 +90,7 @@ int op_matmul(struct htp_ops_context * octx); int op_matmul_id(struct htp_ops_context * octx); int op_binary(struct htp_ops_context * octx); int op_unary(struct htp_ops_context * octx); +int op_sum_rows(struct htp_ops_context * octx); int op_activations(struct htp_ops_context * octx); int op_softmax(struct htp_ops_context * octx); int op_add_id(struct htp_ops_context * octx); @@ -98,5 +99,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx); int op_set_rows(struct htp_ops_context * octx); int op_get_rows(struct htp_ops_context * octx); int op_cpy(struct htp_ops_context * octx); +int op_argsort(struct htp_ops_context * octx); #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-arith.h b/ggml/src/ggml-hexagon/htp/hvx-arith.h index 3449739a4fa..2577cdd0418 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-arith.h +++ b/ggml/src/ggml-hexagon/htp/hvx-arith.h @@ -46,127 +46,76 @@ #define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) #endif -// ADD variants - -static inline void hvx_add_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD); -} - -static inline void hvx_add_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD); -} - -static inline void hvx_add_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD); -} - -static inline void hvx_add_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD); -} - -// SUB variants - -static inline void hvx_sub_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB); -} - -static inline void hvx_sub_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB); -} - -static inline void hvx_sub_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB); -} - -static inline void hvx_sub_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB); -} - -// MUL variants - -static inline void hvx_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL); -} - -static inline void hvx_mul_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL); -} - -static inline void hvx_mul_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL); -} - -static inline void hvx_mul_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL); -} - -// Dispatchers - -static inline void hvx_add_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) { - hvx_add_f32_aa(dst, src0, src1, num_elems); - } else { - hvx_add_f32_au(dst, src0, src1, num_elems); - } - } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) { - hvx_add_f32_ua(dst, src0, src1, num_elems); - } else { - hvx_add_f32_uu(dst, src0, src1, num_elems); - } -} - -static inline void hvx_sub_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) { - hvx_sub_f32_aa(dst, src0, src1, num_elems); - } else { - hvx_sub_f32_au(dst, src0, src1, num_elems); - } - } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) { - hvx_sub_f32_ua(dst, src0, src1, num_elems); - } else { - hvx_sub_f32_uu(dst, src0, src1, num_elems); - } -} - -static inline void hvx_mul_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) { - hvx_mul_f32_aa(dst, src0, src1, num_elems); - } else { - hvx_mul_f32_au(dst, src0, src1, num_elems); - } - } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) { - hvx_mul_f32_ua(dst, src0, src1, num_elems); - } else { - hvx_mul_f32_uu(dst, src0, src1, num_elems); - } -} +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \ +static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \ +} \ + +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL) + +// Dispatcher logic +#define HVX_BINARY_DISPATCHER(OP_NAME) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128)) { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \ + else OP_NAME##_aau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \ + else OP_NAME##_auu(dst, src0, src1, num_elems); \ + } \ + } else { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \ + else OP_NAME##_uau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \ + else OP_NAME##_uuu(dst, src0, src1, num_elems); \ + } \ + } \ +} + +HVX_BINARY_DISPATCHER(hvx_add_f32) +HVX_BINARY_DISPATCHER(hvx_sub_f32) +HVX_BINARY_DISPATCHER(hvx_mul_f32) // Mul-Mul Optimized - static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) { assert((unsigned long) dst % 128 == 0); assert((unsigned long) src0 % 128 == 0); @@ -443,6 +392,68 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * } } +// +// Square +// + +#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \ + } \ + if (nloe) { \ + HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \ + vec_store((void *) &vdst[i], nloe * elem_size, v); \ + } \ + } while(0) + +static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { + if (hex_is_aligned((void *) dst, 128)) { + if (hex_is_aligned((void *) src, 128)) { + hvx_sqr_f32_aa(dst, src, num_elems); + } else { + hvx_sqr_f32_au(dst, src, num_elems); + } + } else { + if (hex_is_aligned((void *) src, 128)) { + hvx_sqr_f32_ua(dst, src, num_elems); + } else { + hvx_sqr_f32_uu(dst, src, num_elems); + } + } +} + #undef HVX_OP_ADD #undef HVX_OP_SUB #undef HVX_OP_MUL @@ -453,5 +464,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * #undef hvx_scalar_loop_body #undef HVX_OP_MIN_SCALAR #undef HVX_OP_CLAMP_SCALAR +#undef DEFINE_HVX_BINARY_OP_VARIANTS +#undef HVX_BINARY_DISPATCHER #endif // HVX_ARITH_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index ffa6e18e645..12a1b7f1288 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -66,6 +66,12 @@ static inline float hvx_vec_get_f32(HVX_Vector v) { return x; } +static inline int32_t hvx_vec_get_i32(HVX_Vector v) { + int32_t __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 4, v); + return x; +} + static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) { // abs by clearing the fp16 sign bit HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff); diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h index 6b617b76177..ae0dbed0306 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-copy.h +++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h @@ -136,8 +136,6 @@ static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restr dst_type * restrict vdst = (dst_type *) dst; \ src_type * restrict vsrc = (src_type *) src; \ \ - const HVX_Vector zero = Q6_V_vsplat_R(0); \ - \ const uint32_t elem_size = sizeof(__fp16); \ const uint32_t epv = 128 / elem_size; \ const uint32_t nvec = n / epv; \ diff --git a/ggml/src/ggml-hexagon/htp/hvx-div.h b/ggml/src/ggml-hexagon/htp/hvx-div.h new file mode 100644 index 00000000000..7dae012e0ed --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-div.h @@ -0,0 +1,116 @@ +#ifndef HVX_DIV_H +#define HVX_DIV_H + +#include + +#include +#include +#include +#include +#include + +#include "hvx-base.h" +#include "hex-utils.h" +#include "hvx-inverse.h" +#include "hvx-arith.h" + +#if __HVX_ARCH__ < 79 +#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + +#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src0_type * restrict vsrc0 = (src0_type *) src0; \ + src1_type * restrict vsrc1 = (src1_type *) src1; \ + \ + const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + \ + const uint32_t nvec = n / VLEN_FP32; \ + const uint32_t nloe = n % VLEN_FP32; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ + HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \ + vdst[i] = res; \ + } \ + if (nloe) { \ + HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ + HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \ + } \ + } while(0) + +// 3-letter suffix variants +static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src0 % 128 == 0); + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src0 % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) src0 % 128 == 0); + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) src0 % 128 == 0); + hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { + if (hex_is_aligned((void *) dst, 128)) { + if (hex_is_aligned((void *) src0, 128)) { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems); + else hvx_div_f32_aau(dst, src0, src1, num_elems); + } else { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems); + else hvx_div_f32_auu(dst, src0, src1, num_elems); + } + } else { + if (hex_is_aligned((void *) src0, 128)) { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems); + else hvx_div_f32_uau(dst, src0, src1, num_elems); + } else { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems); + else hvx_div_f32_uuu(dst, src0, src1, num_elems); + } + } +} + +#undef HVX_OP_MUL + +#endif // HVX_DIV_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h index 1b4aaff0c92..095193277ea 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +++ b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h @@ -91,6 +91,27 @@ static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) { } \ } while(0) +#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t epv = 128 / sizeof(float); \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \ + } \ + if (nloe) { \ + HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \ + vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \ + } \ + } while(0) + static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { assert((unsigned long) dst % 128 == 0); assert((unsigned long) src % 128 == 0); @@ -111,4 +132,10 @@ static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * re hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); } +static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + #endif /* HVX_SIGMOID_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-sqrt.h b/ggml/src/ggml-hexagon/htp/hvx-sqrt.h index 28ee9f68d3e..e31a1006d21 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +++ b/ggml/src/ggml-hexagon/htp/hvx-sqrt.h @@ -12,11 +12,17 @@ #define RSQRT_ONE_HALF 0x3f000000 // 0.5 #define RSQRT_THREE_HALVES 0x3fc00000 // 1.5 +#if __HVX_ARCH__ < 79 +#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) { //Algorithm : // x2 = input*0.5 // y = * (long *) &input - // y = 0x5f3759df - (y>>2) + // y = 0x5f3759df - (y>>1) // y = y*(threehalfs - x2*y*y) HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST); @@ -57,4 +63,64 @@ static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) { return Q6_Vsf_equals_Vqf32(temp); } +// Compute sqrt(x) as x*inv_sqrt(x) +#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t nvec = n / VLEN_FP32; \ + const uint32_t nloe = n % VLEN_FP32; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \ + HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \ + vdst[i] = sqrt_res; \ + } \ + if (nloe) { \ + HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \ + HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res); \ + } \ + } while(0) + +static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) { + if ((unsigned long) dst % 128 == 0) { + if ((unsigned long) src % 128 == 0) { + hvx_sqrt_f32_aa(dst, src, num_elems); + } else { + hvx_sqrt_f32_au(dst, src, num_elems); + } + } else { + if ((unsigned long) src % 128 == 0) { + hvx_sqrt_f32_ua(dst, src, num_elems); + } else { + hvx_sqrt_f32_uu(dst, src, num_elems); + } + } +} + #endif /* HVX_SQRT_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 7b79a5ea322..a518ad37331 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -12,6 +12,7 @@ #include "hvx-sigmoid.h" #include "hvx-sqrt.h" #include "hvx-arith.h" +#include "hvx-div.h" #include "hvx-base.h" #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index e28a67a95dc..62708eee5cf 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -440,6 +440,45 @@ static void proc_matmul_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_argsort(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { struct dspqueue_buffer rsp_bufs[1]; @@ -679,6 +718,45 @@ static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * re send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_sum_rows(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_activations_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs, @@ -951,6 +1029,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { case HTP_OP_MUL: case HTP_OP_ADD: case HTP_OP_SUB: + case HTP_OP_DIV: if (n_bufs != 3) { FARF(ERROR, "Bad binary-req buffer list"); continue; @@ -968,6 +1047,25 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_unary_req(ctx, &req, bufs); break; + case HTP_OP_SQR: + case HTP_OP_SQRT: + if (n_bufs != 2) { + FARF(ERROR, "Bad unary-req buffer list"); + continue; + } + + proc_unary_req(ctx, &req, bufs); + break; + + case HTP_OP_SUM_ROWS: + if (n_bufs != 2) { + FARF(ERROR, "Bad unary-req buffer list"); + continue; + } + + proc_sum_rows_req(ctx, &req, bufs); + break; + case HTP_OP_UNARY_SILU: case HTP_OP_UNARY_GELU: if (n_bufs != 2) { @@ -980,6 +1078,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { case HTP_OP_GLU_SWIGLU: case HTP_OP_GLU_SWIGLU_OAI: case HTP_OP_SOFTMAX: + case HTP_OP_GLU_GEGLU: if ((n_bufs != 2) && (n_bufs != 3)) { FARF(ERROR, "Bad act-req buffer list"); continue; @@ -1035,6 +1134,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_cpy_req(ctx, &req, bufs); break; + case HTP_OP_ARGSORT: + if (n_bufs != 2) { + FARF(ERROR, "Bad argsort-req buffer list"); + continue; + } + proc_argsort_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; diff --git a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c new file mode 100644 index 00000000000..62e45da2b35 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c @@ -0,0 +1,115 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include +#include + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + + +#define sum_rows_preamble \ + struct htp_tensor *src0 = &octx->src0;\ + struct htp_tensor *dst = &octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + +static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) { + sum_rows_preamble; + + const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + const size_t src0_row_size = nb01; + const size_t dst_row_size = nb1; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return HTP_STATUS_OK; + } + + int opt_path = 0; + if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) { + opt_path = 1; + } + + const uint8_t * restrict data_src = (const uint8_t *) src0->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size)); + float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size)); + + for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) { + const float * restrict src_local = src_th + (ir * ne00); + + if (ir + 1 < src0_nrows_per_thread) { + hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1); + } + + if (1 == opt_path) { + dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00); + } else { + dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00); + } + } + + return HTP_STATUS_OK; +} + +static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) { + sum_rows_thread_f32((struct htp_ops_context *) data, n, i); +} + +int op_sum_rows(struct htp_ops_context * octx) { + sum_rows_preamble; + + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const int n_threads = octx->n_threads; + const uint32_t src0_nrows = ne01 * ne02 * ne03; + + uint32_t n_jobs = MIN(n_threads, src0_nrows); + octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + + worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs); + + return HTP_STATUS_OK; +} + diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 1a27cb6e63e..ce879bf0370 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -132,6 +132,56 @@ static void rms_norm_htp_f32(const float * restrict src, } } +static void sqr_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + hex_l2fetch(src_local + row_elems, row_size, row_size, 1); + } + + if (1 == opt_path) { + hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } else { + hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } + } +} + +static void sqrt_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + hex_l2fetch(src_local + row_elems, row_size, row_size, 1); + } + + if (1 == opt_path) { + hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } else { + hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } + } +} + static void unary_job_f32_per_thread(const struct htp_tensor * src, struct htp_tensor * dst, uint8_t * spad, @@ -181,6 +231,12 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src, case HTP_OP_SCALE: scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); break; + case HTP_OP_SQR: + sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; + case HTP_OP_SQRT: + sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; default: break; @@ -218,6 +274,14 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { unary_op_func = unary_job_dispatcher_f32; op_type = "scale-f32"; break; + case HTP_OP_SQR: + unary_op_func = unary_job_dispatcher_f32; + op_type = "sqr-f32"; + break; + case HTP_OP_SQRT: + unary_op_func = unary_job_dispatcher_f32; + op_type = "sqrt-f32"; + break; default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); From 3ffa1fd84e853fb4307dc4bbd249b7f12ca2f93b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 11 Feb 2026 14:53:19 +0200 Subject: [PATCH 131/831] metal : extend l2_norm support for non-cont src0 (llama/19502) --- ggml/src/ggml-metal/ggml-metal-device.cpp | 11 ++++-- ggml/src/ggml-metal/ggml-metal-device.m | 3 +- ggml/src/ggml-metal/ggml-metal-impl.h | 15 +++++++- ggml/src/ggml-metal/ggml-metal-ops.cpp | 46 ++++++++++++++++------- ggml/src/ggml-metal/ggml-metal.metal | 30 ++++++++++----- 5 files changed, 75 insertions(+), 30 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 949e344cc8c..517559d12a6 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1480,13 +1480,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_met ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_L2_NORM); - GGML_ASSERT(op->src[0]->ne[0] % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); - char base[256]; char name[256]; - snprintf(base, 256, "kernel_l2_norm_f32"); + const bool is_c4 = op->src[0]->ne[0] % 4 == 0; + + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t_str = ggml_type_name(op->type); + + snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : ""); snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); @@ -1494,6 +1496,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } + res.c4 = is_c4; res.smem = 32*sizeof(float); return res; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 50a2a3e7f72..c714ef3add9 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1086,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: case GGML_OP_GROUP_NORM: - return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_L2_NORM: - return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); + return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_COUNT_EQUAL: return has_simdgroup_reduction && op->src[0]->type == GGML_TYPE_I32 && diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 44141f8e3d9..952e1be076e 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -539,8 +539,21 @@ typedef struct { typedef struct { int32_t ne00; - int32_t ne00_4; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; float eps; } ggml_metal_kargs_l2_norm; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index b159a8e7fd0..7db95d1c84d 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2979,39 +2979,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + float eps; memcpy(&eps, op->op_params, sizeof(float)); - int nth = 32; // SIMD width - ggml_metal_kargs_l2_norm args = { - /*.ne00 =*/ ne00, - /*.ne00_4 =*/ ne00/4, - /*.nb01 =*/ nb01, - /*.eps =*/ eps, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.eps =*/ eps, }; auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); - while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; } nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - nth = std::min(nth, ne00/4); const size_t smem = pipeline.smem; - const int64_t nrows = ggml_nrows(op->src[0]); - ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 7d841341a18..a385a50b942 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2706,26 +2706,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; -kernel void kernel_l2_norm_f32( +template +kernel void kernel_l2_norm_impl( constant ggml_metal_kargs_l2_norm & args, device const char * src0, device char * dst, threadgroup float * shmem_f32 [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - ushort tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + if (sgitg == 0) { shmem_f32[tiisg] = 0.0f; } - device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); float sumf = 0.0f; // parallel sum - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { sumf += dot(x[i00], x[i00]); } sumf = simd_sum(sumf); @@ -2743,12 +2749,16 @@ kernel void kernel_l2_norm_f32( const float scale = 1.0f/sqrt(max(sumf, args.eps)); - device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { y[i00] = x[i00] * scale; } } +typedef decltype(kernel_l2_norm_impl) kernel_l2_norm_t; + +template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl; +template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl; + kernel void kernel_group_norm_f32( constant ggml_metal_kargs_group_norm & args, device const float * src0, From f3e78985bec1b6f5f4f1c4ebd719d31b2b72109c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 11 Feb 2026 18:58:43 +0200 Subject: [PATCH 132/831] ggml : unary ops support non-cont src0 + metal F16 unary ops (llama/19511) * ggml : unary ops support non-cont src0 * metal : support F16 unary ops + fix ELU --- ggml/src/ggml-cpu/ops.cpp | 144 +++++++++++++++++------- ggml/src/ggml-cpu/unary-ops.cpp | 2 +- ggml/src/ggml-metal/ggml-metal-device.m | 4 +- ggml/src/ggml-metal/ggml-metal.metal | 84 ++++++++------ ggml/src/ggml.c | 2 +- 5 files changed, 159 insertions(+), 77 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index ed45350207e..4352e132807 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2096,10 +2096,14 @@ static void ggml_compute_forward_gelu_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2113,10 +2117,14 @@ static void ggml_compute_forward_gelu_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2135,10 +2143,14 @@ static void ggml_compute_forward_gelu_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2152,10 +2164,14 @@ static void ggml_compute_forward_gelu_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2276,10 +2292,14 @@ static void ggml_compute_forward_gelu_erf_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2293,10 +2313,14 @@ static void ggml_compute_forward_gelu_erf_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_erf_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2315,10 +2339,14 @@ static void ggml_compute_forward_gelu_erf_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2332,10 +2360,14 @@ static void ggml_compute_forward_gelu_erf_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_erf_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2379,10 +2411,14 @@ static void ggml_compute_forward_gelu_quick_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2396,10 +2432,14 @@ static void ggml_compute_forward_gelu_quick_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_quick_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2418,10 +2458,14 @@ static void ggml_compute_forward_gelu_quick_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2435,10 +2479,14 @@ static void ggml_compute_forward_gelu_quick_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_quick_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2482,10 +2530,14 @@ static void ggml_compute_forward_silu_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2499,10 +2551,14 @@ static void ggml_compute_forward_silu_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_silu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2521,10 +2577,14 @@ static void ggml_compute_forward_silu_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2538,10 +2598,14 @@ static void ggml_compute_forward_silu_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_silu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index 1d9873ad0f2..1d8344436f0 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -111,7 +111,7 @@ template static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst)); GGML_TENSOR_UNARY_OP_LOCALS diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c714ef3add9..4ea0bfb94da 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1019,7 +1019,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_LOG: - return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_TANH: @@ -1039,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_EXPM1: - return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index a385a50b942..0036ba90ec9 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -910,7 +910,7 @@ constant float a4_erf = -1.453152027f; constant float a5_erf = 1.061405429f; template -T erf_approx(T x) { +inline T erf_approx(T x) { T sign_x = sign(x); x = fabs(x); T t = 1.0f / (1.0f + p_erf * x); @@ -918,10 +918,27 @@ T erf_approx(T x) { return sign_x * y; } +template T elu_approx(T x); + +template<> inline float elu_approx(float x) { + return (x > 0.f) ? x : (exp(x) - 1); +} + +template<> inline float4 elu_approx(float4 x) { + float4 res; + + res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f); + res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f); + res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f); + res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f); + + return res; +} + constant short FC_unary_op [[function_constant(FC_UNARY + 0)]]; constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]]; -template +template kernel void kernel_unary_impl( constant ggml_metal_kargs_unary & args, device const char * src0, @@ -963,111 +980,111 @@ kernel void kernel_unary_impl( } } - device const T0 & x = src0_ptr[i0]; + const TC x = (TC) src0_ptr[i0]; if (FC_OP == OP_UNARY_NUM_SCALE) { - dst_ptr[i0] = args.scale * x + args.bias; + dst_ptr[i0] = (T) (args.scale * x + args.bias); } if (FC_OP == OP_UNARY_NUM_FILL) { - dst_ptr[i0] = args.val; + dst_ptr[i0] = (T) args.val; } if (FC_OP == OP_UNARY_NUM_CLAMP) { - dst_ptr[i0] = clamp(x, args.min, args.max); + dst_ptr[i0] = (T) clamp(x, args.min, args.max); } if (FC_OP == OP_UNARY_NUM_SQR) { - dst_ptr[i0] = x * x; + dst_ptr[i0] = (T) (x * x); } if (FC_OP == OP_UNARY_NUM_SQRT) { - dst_ptr[i0] = sqrt(x); + dst_ptr[i0] = (T) sqrt(x); } if (FC_OP == OP_UNARY_NUM_SIN) { - dst_ptr[i0] = sin(x); + dst_ptr[i0] = (T) sin(x); } if (FC_OP == OP_UNARY_NUM_COS) { - dst_ptr[i0] = cos(x); + dst_ptr[i0] = (T) cos(x); } if (FC_OP == OP_UNARY_NUM_LOG) { - dst_ptr[i0] = log(x); + dst_ptr[i0] = (T) log(x); } if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) { - dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(x * args.slope); + dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope)); } if (FC_OP == OP_UNARY_NUM_TANH) { - dst_ptr[i0] = precise::tanh(x); + dst_ptr[i0] = (T) precise::tanh(x); } if (FC_OP == OP_UNARY_NUM_RELU) { - dst_ptr[i0] = fmax(0.0f, x); + dst_ptr[i0] = (T) fmax(0, x); } if (FC_OP == OP_UNARY_NUM_SIGMOID) { - dst_ptr[i0] = 1.0f / (1.0f + exp(-x)); + dst_ptr[i0] = (T) (1 / (1 + exp(-x))); } if (FC_OP == OP_UNARY_NUM_GELU) { - dst_ptr[i0] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x)))); } if (FC_OP == OP_UNARY_NUM_GELU_ERF) { - dst_ptr[i0] = 0.5f*x*(1.0f + erf_approx(SQRT_2_INV*x)); + dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x))); } if (FC_OP == OP_UNARY_NUM_GELU_QUICK) { - dst_ptr[i0] = x * (1.0f/(1.0f + exp(GELU_QUICK_COEF*x))); + dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x)))); } if (FC_OP == OP_UNARY_NUM_SILU) { - dst_ptr[i0] = x / (1.0f + exp(-x)); + dst_ptr[i0] = (T) (x / (1 + exp(-x))); } if (FC_OP == OP_UNARY_NUM_ELU) { - dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(exp(x) - 1.0f); + dst_ptr[i0] = (T) elu_approx(x); } if (FC_OP == OP_UNARY_NUM_NEG) { - dst_ptr[i0] = -x; + dst_ptr[i0] = (T) -x; } if (FC_OP == OP_UNARY_NUM_ABS) { - dst_ptr[i0] = fabs(x); + dst_ptr[i0] = (T) fabs(x); } if (FC_OP == OP_UNARY_NUM_SGN) { - dst_ptr[i0] = T(x > 0.0f) - T(x < 0.0f); + dst_ptr[i0] = T(x > 0) - T(x < 0); } if (FC_OP == OP_UNARY_NUM_STEP) { - dst_ptr[i0] = T(x > 0.0f); + dst_ptr[i0] = T(x > 0); } if (FC_OP == OP_UNARY_NUM_HARDSWISH) { - dst_ptr[i0] = x * fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f)); + dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5))); } if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) { - dst_ptr[i0] = fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f)); + dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5)); } if (FC_OP == OP_UNARY_NUM_EXP) { - dst_ptr[i0] = exp(x); + dst_ptr[i0] = (T) exp(x); } if (FC_OP == OP_UNARY_NUM_SOFTPLUS) { - dst_ptr[i0] = select(log(1.0f + exp(x)), x, x > 20.0f); + dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20); } if (FC_OP == OP_UNARY_NUM_EXPM1) { // TODO: precise implementation - dst_ptr[i0] = exp(x) - 1.0f; + dst_ptr[i0] = (T) (exp(x) - 1); } } @@ -1075,11 +1092,12 @@ kernel void kernel_unary_impl( #undef FC_CNT } -typedef decltype(kernel_unary_impl) kernel_unary_t; - -template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl; -template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl; +typedef decltype(kernel_unary_impl) kernel_unary_t; +template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl; // OP: 0 - add, 1 - sub, 2 - mul, 3 - div constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 500cb6b72f9..e2a6ff67be7 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5749,7 +5749,7 @@ static struct ggml_tensor * ggml_unary_impl( struct ggml_tensor * a, enum ggml_unary_op op, bool inplace) { - GGML_ASSERT(ggml_is_contiguous_1(a)); + GGML_ASSERT(ggml_is_contiguous_rows(a)); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); From 0326fd37dd4170f0264a909c1ab58ec230dee02b Mon Sep 17 00:00:00 2001 From: lhez Date: Wed, 11 Feb 2026 10:33:13 -0800 Subject: [PATCH 133/831] opencl: add general Q6_K mm and Q4_K mv (llama/19347) * opencl: add general q6_k mm * opencl: refine condition for q6_K mm * opencl: add general q4_K mv * opencl: fix whitespace --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 123 +++++++++++- .../kernels/mul_mm_q6_k_f32_l4_lm.cl | 158 +++++++++++++++ .../ggml-opencl/kernels/mul_mv_q4_k_f32.cl | 180 ++++++++++++++++++ 4 files changed, 461 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index fa5fadd112b..b6094fb68b0 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -85,6 +85,7 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_0_f32_8x_flat mul_mv_q4_0_f32_1d_8x_flat mul_mv_q4_0_f32_1d_16x_flat + mul_mv_q4_k_f32 mul_mv_q6_k_f32 mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 @@ -101,6 +102,7 @@ set(GGML_OPENCL_KERNELS mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm mul_mm_q8_0_f32_l4_lm + mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 gemv_noshuffle_general_q8_0_f32 mul diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 508b2b8f037..40474c193bb 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -532,6 +532,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_restore_block_q4_0_noshuffle; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; + cl_kernel kernel_mul_mv_q4_K_f32; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_mul_mv_q6_K_f32_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; @@ -564,6 +565,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_f32_f32_l4_lm; cl_kernel kernel_mul_mm_f16_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q6_k_f32_l4_lm; std::vector profiling_info; @@ -1117,6 +1119,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_q4_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_k_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1358,6 +1377,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q6_k_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q6_k_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q6_k_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q6_k_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_f16_f32_kq_kqv { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3364,6 +3400,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } else if (op->src[0]->type == GGML_TYPE_F32) { return op->src[1]->type == GGML_TYPE_F32; } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q6_K) { return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } else if (op->src[0]->type == GGML_TYPE_Q8_0) { @@ -8927,6 +8964,50 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q6_K: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } default: break; } @@ -9262,7 +9343,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: + case GGML_TYPE_Q4_K: { + kernel = backend_ctx->kernel_mul_mv_q4_K_f32; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + break; + } case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: #ifdef GGML_OPENCL_SOA_Q @@ -9424,7 +9540,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } else if (src0t == GGML_TYPE_Q4_K) { - GGML_ASSERT(false && "not implemented"); + size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } else if (src0t == GGML_TYPE_Q3_K) { GGML_ASSERT(false && "not implemented"); } else if (src0t == GGML_TYPE_Q5_K) { diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl new file mode 100644 index 00000000000..3602c92fef4 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl @@ -0,0 +1,158 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 2 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q6_k_f32_l4_lm( + global uchar * src0_ql, + global uchar * src0_qh, + global char * src0_s, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + + int ib = idx / 128; // 2 values per idx + int iqs = idx % 128; // 0..127 + + int n = iqs / 64; // 0,1 + int b = (iqs % 64) / 32; // 0,1 + int is_b = (iqs % 16) / 8; // 0,1 + int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + int is = 8 * n + qhshift + is_b; // 0..15 + int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is]; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32); + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32); + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl new file mode 100644 index 00000000000..71ab9898213 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl @@ -0,0 +1,180 @@ +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// block_q4_K +//------------------------------------------------------------------------------ +#define QK_K 256 +#define K_SCALE_SIZE 12 + +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qs[QK_K/2]; // 4-bit quants +} block_q4_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // number of rows each SIMD group works on +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +#undef BLOCK_STRIDE +// number of (super) blocks each subgroup processes +// each thread in a subgroup processes a block (32 weights) +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_K_f32( + global char * src0, + int offset0, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; // super block index + int it = get_sub_group_local_id()%8; // block index (inside super block) + int iq = it/4; // 0 or 1 - first or second half of the super block + int ir = it%4; // 0...3 - block index in the half super block + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * sc = (global ushort *)x[ib].scales + iq; + global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir; + global half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F); + acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00); + acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0); + acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000); + acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F); + acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00); + acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0); + acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += nb01/2; + sc += nb01/2; + dh += nb01/2; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} From 304205679c650c3e5977e0b01b7b9bd022336767 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 11 Feb 2026 23:04:27 -0800 Subject: [PATCH 134/831] hexagon: further optimization and tuning of matmul and dot kernels (llama/19407) * ggml-hexagon: implement 2x2 matmul kernel * hexmm: implement vec_dot_rx2x2 for Q8_0 and MXFP4 * hexagon: fix editor config failures * hexagon: refactor matmul ops to use context struct and remove wrappers Also implement vec_dot_f16 2x2 * hexagon: refactor dyn quantizers to use mmctx * hexagon: remove mm fastdiv from op_ctx * hexagon: refactor matmul entry point to reduce code duplication --------- Co-authored-by: Trivikram Reddy --- ggml/src/ggml-hexagon/htp/htp-ops.h | 13 - ggml/src/ggml-hexagon/htp/matmul-ops.c | 1505 +++++++++++++----------- 2 files changed, 847 insertions(+), 671 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index c0d72587ce5..f1ad24dbfaa 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -64,25 +64,12 @@ struct htp_ops_context { struct fastdiv_values broadcast_rv2; struct fastdiv_values broadcast_rv3; - struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1 - struct fastdiv_values mm_div_ne1; // fastdiv values for ne1 - struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02 - struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03 - struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12 struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11 struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10 struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11 - struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01 - struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02 - struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03 - - struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00 - struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01 - struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02 - uint32_t flags; }; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index d251eeed33a..c360abe8dae 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -23,10 +23,30 @@ #define MM_SPAD_SRC1_NROWS 16 #define MM_SPAD_DST_NROWS 2 -struct htp_matmul_type { +struct htp_matmul_context { const char * type; - void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); - void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy); + struct htp_ops_context * octx; + + void (*vec_dot_1x1)(const int n, float * restrict s0, + const void * restrict vx0, + const void * restrict vy0); + + void (*vec_dot_2x1)(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0); + + void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1); + + // Precomputed values + uint32_t src0_nrows_per_thread; + uint32_t src1_nrows_per_thread; + + struct fastdiv_values mm_div_ne12_ne1; + struct fastdiv_values mm_div_ne1; + struct fastdiv_values mm_div_r2; + struct fastdiv_values mm_div_r3; }; // vdelta control to replicate first 4x fp32 values across lanes @@ -122,6 +142,7 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { HVX_Vector v6_7 = vptr[3]; // ... const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 @@ -133,15 +154,14 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 // Convert uint4 to int4 (i.e. x - 8) - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - v0 = Q6_Vb_vsub_VbVb(v0, i8); - v1 = Q6_Vb_vsub_VbVb(v1, i8); - v2 = Q6_Vb_vsub_VbVb(v2, i8); - v3 = Q6_Vb_vsub_VbVb(v3, i8); - v4 = Q6_Vb_vsub_VbVb(v4, i8); - v5 = Q6_Vb_vsub_VbVb(v5, i8); - v6 = Q6_Vb_vsub_VbVb(v6, i8); - v7 = Q6_Vb_vsub_VbVb(v7, i8); + v0 = Q6_Vb_vsub_VbVb(v0, i8); + v1 = Q6_Vb_vsub_VbVb(v1, i8); + v2 = Q6_Vb_vsub_VbVb(v2, i8); + v3 = Q6_Vb_vsub_VbVb(v3, i8); + v4 = Q6_Vb_vsub_VbVb(v4, i8); + v5 = Q6_Vb_vsub_VbVb(v5, i8); + v6 = Q6_Vb_vsub_VbVb(v6, i8); + v7 = Q6_Vb_vsub_VbVb(v7, i8); HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; return r; @@ -156,6 +176,7 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) HVX_Vector v6_7 = vptr[3]; // ... const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 @@ -166,15 +187,14 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; - v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); - v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); - v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); - v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); - v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); - v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); + v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); + v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); + v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); + v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); + v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); + v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; return r; @@ -196,46 +216,6 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) { return r; } -static inline HVX_Vector_x4 hvx_vec_load_x4_f16(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0 = vptr[0]; // first 64 vals - HVX_Vector v1 = vptr[1]; // second 64 vals - HVX_Vector v2 = vptr[2]; // third 64 vals - HVX_Vector v3 = vptr[3]; // forth 64 vals - - HVX_Vector_x4 r = { v0, v1, v2, v3 }; - return r; -} - -static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict ptr) { - const HVX_VectorPair * restrict vptr = (const HVX_VectorPair *) ptr; - - HVX_VectorPair v0 = vptr[0]; // first 64 vals - HVX_VectorPair v1 = vptr[1]; // second 64 vals - HVX_VectorPair v2 = vptr[2]; // third 64 vals - HVX_VectorPair v3 = vptr[3]; // forth 64 vals - - HVX_Vector vq0_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v0), Q6_V_vzero()); - HVX_Vector vq0_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v0), Q6_V_vzero()); - HVX_Vector vq1_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v1), Q6_V_vzero()); - HVX_Vector vq1_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v1), Q6_V_vzero()); - HVX_Vector vq2_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v2), Q6_V_vzero()); - HVX_Vector vq2_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v2), Q6_V_vzero()); - HVX_Vector vq3_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v3), Q6_V_vzero()); - HVX_Vector vq3_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v3), Q6_V_vzero()); - - HVX_Vector vh0 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq0_hi, vq0_lo)); - HVX_Vector vh1 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq1_hi, vq1_lo)); - HVX_Vector vh2 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq2_hi, vq2_lo)); - HVX_Vector vh3 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq3_hi, vq3_lo)); - - // vcombine does a shuffle, use vdeal to undo - - HVX_Vector_x4 r = { Q6_Vh_vdeal_Vh(vh0), Q6_Vh_vdeal_Vh(vh1), Q6_Vh_vdeal_Vh(vh2), Q6_Vh_vdeal_Vh(vh3) }; - return r; -} - // Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors). // Accumulate each block into a single int32 value. // Return a single HVX vector with 32x int32 accumulators. @@ -300,26 +280,26 @@ static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, return hvx_vec_rmpy_x8_n(x, y, 1024); } -static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -372,36 +352,34 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(&s[0], 4, r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { +static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -468,13 +446,143 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; @@ -486,11 +594,11 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -543,36 +651,34 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(&s[0], 4, r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { +static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (qf32) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -639,16 +745,143 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q8_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } -static void vec_dot_mxfp4x4x2_q8x4x2(const int n, - float * restrict s, - const void * restrict vx, - const void * restrict vy) { +static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_MXFP4x4x2 * 4; @@ -660,11 +893,11 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -747,36 +980,34 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(&s[0], 4, r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { +static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_MXFP4x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -879,10 +1110,180 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_MXFP4x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); + vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); + vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); + vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); + vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); + vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); + vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -912,14 +1313,12 @@ static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * res hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f16_aa_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { - const HVX_Vector * restrict x0 = (const HVX_Vector *) vx; - const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size); - const HVX_Vector * restrict y = (const HVX_Vector *) vy; +static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y = (const HVX_Vector *) vy0; uint32_t nvec = n / VLEN_FP16; uint32_t nloe = n % VLEN_FP16; @@ -953,10 +1352,86 @@ static void vec_dot_f16_f16_aa_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1)); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0; + const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1; + + uint32_t nvec = n / VLEN_FP16; + uint32_t nloe = n % VLEN_FP16; + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector r0_hf = x0[i]; + HVX_Vector r1_hf = x1[i]; + HVX_Vector c0_hf = y0[i]; + HVX_Vector c1_hf = y1[i]; + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); + HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); + HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); + HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); + + HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); + HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); + HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); + HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + + HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]); + HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]); + + HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); + HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); + HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); + HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); + + HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); + HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); + HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); + HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); + + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } -static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_UVector * restrict x = (const HVX_UVector *) vx; const HVX_UVector * restrict y = (const HVX_UVector *) vy; @@ -986,7 +1461,7 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) { +static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; @@ -1083,14 +1558,16 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -#define htp_matmul_preamble \ - htp_matmul_tensors_preamble; \ - dma_queue *dma_queue = octx->ctx->dma[ith]; \ - uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; +#define htp_matmul_preamble \ + struct htp_matmul_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + htp_matmul_tensors_preamble; \ + dma_queue *dma_queue = octx->ctx->dma[ith]; \ + uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; // *** matmul with support for 4d tensors and full broadcasting -static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; uint64_t t1, t2; @@ -1136,13 +1613,13 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { - const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1); - const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1); + const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1); + const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1); const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); // broadcast src0 into src1 - const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3); - const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2); + const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3); + const uint32_t i02 = fastdiv(i12, &mmctx->mm_div_r2); const uint32_t i1 = i11; const uint32_t i2 = i12; @@ -1155,7 +1632,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { const uint8_t * restrict src0_row = src0_base + ir0 * nb01; - mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col); } } } @@ -1170,7 +1647,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx } // src1 tensor is already in VTCM spad -static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows @@ -1195,7 +1672,7 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; + uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; uint8_t * restrict src1_data = src1_spad->data; @@ -1219,11 +1696,21 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - #pragma unroll(2) - for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { + // Process src1 columns in pairs (2×2 tiling) + uint32_t ir1 = 0; + for (; ir1 + 1 < src1_nrows; ir1 += 2) { + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); + float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size)); + float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1); + } + + // Handle remaining src1 rows (fallback to 2×1) + for (; ir1 < src1_nrows; ++ir1) { const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col); } // Prefetch next (n + spad_nrows) row @@ -1247,20 +1734,20 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth, + FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } // q8x4x2 src1 tensor is already in VTCM spad -static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; const uint32_t src0_nrows = ne01; @@ -1311,7 +1798,7 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col); + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); // Prefetch next (n + spad_nrows) row const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); @@ -1329,14 +1816,14 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); } hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth, + FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); @@ -1350,7 +1837,7 @@ struct mmid_row_mapping { }; // src1 tensor is already in VTCM spad -static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; struct htp_tensor * restrict ids = &octx->src2; @@ -1423,11 +1910,10 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx const int rm2 = row_mapping.i2; // token idx const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = - (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); } // Prefetch next (n + spad_nrows) row @@ -1453,25 +1939,24 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx const int rm2 = row_mapping.i2; // token idx const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = - (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type, + FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } // src1 tensor is already in VTCM spad -static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matvec_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; struct htp_tensor * restrict ids = &octx->src2; @@ -1531,7 +2016,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); // Prefetch next (n + spad_nrows) row const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); @@ -1549,13 +2034,13 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type, + FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); @@ -1754,12 +2239,14 @@ static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, ui hvx_copy_f16_ua(y_d, t_d, nb * 8); } -static void quantize_f32_q8x4x2(const struct htp_tensor * src, - uint8_t * restrict dst, - struct htp_spad * spad, - uint32_t nth, - uint32_t ith, - uint32_t nrows_per_thread) { +static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + struct htp_spad * spad = &octx->src0_spad; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1799,8 +2286,14 @@ static void quantize_f32_q8x4x2(const struct htp_tensor * src, ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void quantize_f32_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, - uint32_t nrows_per_thread, uint32_t dst_stride) { +static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1835,8 +2328,14 @@ static void quantize_f32_f16(const struct htp_tensor * src, uint8_t * restrict d } // TODO just a plain copy that should be done via the DMA during the Op setup -static void quantize_f16_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, - uint32_t nrows_per_thread, uint32_t dst_stride) { +static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1870,213 +2369,76 @@ static void quantize_f16_f16(const struct htp_tensor * src, uint8_t * restrict d ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void htp_quantize_f32_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_f32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread); -} - -static void htp_quantize_f32_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_f32_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); -} - -static void htp_quantize_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_f16_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); -} - -// ** matmul/matvec callbacks for worker_pool - -static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_aa; - mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_aa; - mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f32"; - mt.vec_dot = vec_dot_f16_f32_uu; - - matmul_4d(&mt, octx, n, i); -} - -static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_uu; - - matmul_4d(&mt, octx, n, i); -} - -// ** matmul-id callbacks for worker_pool - -static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matvec_id(&mt, octx, n, i); -} - -static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matmul_id(&mt, octx, n, i); -} - -static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matvec_id(&mt, octx, n, i); -} - -static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matmul_id(&mt, octx, n, i); +static inline bool htp_is_permuted(const struct htp_tensor * t) { + return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3]; } -static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matvec_id(&mt, octx, n, i); +static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) { + switch (type) { + case HTP_TYPE_Q4_0: + mmctx->type = "q4x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; + return 0; + case HTP_TYPE_Q8_0: + mmctx->type = "q8x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; + return 0; + case HTP_TYPE_MXFP4: + mmctx->type = "mxfp4x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2; + return 0; + default: + return -1; + } } -static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matmul_id(&mt, octx, n, i); -} +static void htp_mminit_spad(struct htp_ops_context * octx, + size_t dst_row_size, + size_t src0_row_size_padded, + size_t src1_row_size, + uint32_t src1_nrows, + size_t src2_spad_size_per_thread) { + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + + if (src2_spad_size_per_thread > 0) { + octx->src2_spad.size_per_thread = src2_spad_size_per_thread; + octx->src2_spad.size = octx->src2_spad.size_per_thread; + } -// ** main matmul entry point + // src0 spad is also used in dynamic quantizer to store padded src1 rows + size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + if (octx->src0_spad.size_per_thread < src1_row_size_padded) { + octx->src0_spad.size_per_thread = src1_row_size_padded; + } -static inline bool htp_is_permuted(const struct htp_tensor * t) { - return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3]; + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; } int op_matmul(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; - const char * op_type; + struct htp_matmul_context mmctx_struct = {0}; + struct htp_matmul_context * mmctx = &mmctx_struct; + mmctx->octx = octx; const uint32_t src0_nrows = ne01 * ne02 * ne03; const uint32_t src1_nrows = ne11 * ne12 * ne13; + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; size_t src1_row_size = nb11; @@ -2085,181 +2447,95 @@ int op_matmul(struct htp_ops_context * octx) { size_t src1_row_size_padded; worker_callback_t quant_job_func; - worker_callback_t matmul_job_func; + worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d; bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE); - switch (src0->type) { - case HTP_TYPE_Q4_0: - op_type = "q4x4x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2; - } else { - matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2; - } - - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } + if (src0->type == HTP_TYPE_F16) { + // Try optimized f16-f16 path first (src1 in VTCM) + const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128); + const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256); + const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; + const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - case HTP_TYPE_Q8_0: - op_type = "q8x4x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2; - } else { - matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2; - } + // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). + // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization + if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { + // Optimized path + quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16; + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2; - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size + src1_row_size = f16_src1_row_size; // row size post quantization octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - octx->src1_spad.size = octx->src1_spad.size_per_thread; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_MXFP4: - op_type = "mxfp4x4x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2; + } else { + // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required + quant_job_func = NULL; + if (src1->type == HTP_TYPE_F32) { + mmctx->type = "f16-f32"; + mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1; + matmul_job_func = matmul_4d; } else { - matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2; + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1; + matmul_job_func = matmul_4d; } - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size + src1_row_size = nb11; // original row size in DDR octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src1_spad.size = octx->src1_spad.size_per_thread; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - case HTP_TYPE_F16: - { - // Try optimized f16-f16 path first (src1 in VTCM) - const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128); - const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256); - const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; - const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - - const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - - // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). - // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. - const bool is_batched = (ne02 > 1) || (ne03 > 1); - const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); - - if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { - // Optimized path - op_type = "f16-f16"; - quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_f32_f16 : htp_quantize_f16_f16; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_f16_f16; - } else { - matmul_job_func = htp_matvec_2d_f16_f16; - } - - src1_row_size = f16_src1_row_size; // row size post quantization - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - } else { - // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required - quant_job_func = NULL; - if (src1->type == HTP_TYPE_F32) { - op_type = "f16-f32"; - matmul_job_func = htp_matmul_4d_f16_f32; - } else { - op_type = "f16-f16"; - matmul_job_func = htp_matmul_4d_f16_f16; - } - - src1_row_size = nb11; // original row size in DDR - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); - - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - - // Init fastdiv for matmul_4d (supports broadcasting) - octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); - octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); - octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); - octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); - - need_quant = false; - } - } - break; + // Init fastdiv for matmul_4d (supports broadcasting) + mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); - default: + need_quant = false; + } + } else { + if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { return HTP_STATUS_NO_SUPPORT; + } + + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); } // VTCM scratchpads for all tensors size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; - FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", op_type, + FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size); - FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, src0->ne[0], + FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, + FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -2268,39 +2544,31 @@ int op_matmul(struct htp_ops_context * octx) { octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; - octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even - octx->src0_spad.stride = src0_row_size_padded; octx->src1_spad.stride = src1_row_size; if (need_quant) { - // Run quant jobs - const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); - octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; - worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs); + const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); } if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - // Run matmul jobs const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, octx, n_matmul_jobs); + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); } return HTP_STATUS_OK; } -// ** main matmul-id entry point - int op_matmul_id(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; - struct htp_tensor * restrict ids = &octx->src2; - - const char * op_type; + struct htp_matmul_context mmctx_struct = {0}; + struct htp_matmul_context * mmctx = &mmctx_struct; + mmctx->octx = octx; - worker_callback_t quant_job_func; - worker_callback_t matmul_id_job_func; + struct htp_tensor * restrict ids = &octx->src2; const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; @@ -2310,6 +2578,13 @@ int op_matmul_id(struct htp_ops_context * octx) { const uint32_t src0_nrows = ne01; // per expert const uint32_t src1_nrows = ne11 * ne12 * ne13; + worker_callback_t quant_job_func; + worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id; + + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + size_t src1_row_size; size_t src1_row_size_padded; @@ -2320,112 +2595,29 @@ int op_matmul_id(struct htp_ops_context * octx) { size_t matrix_row_counts_size = n_as * sizeof(uint32_t); size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); - switch (src0->type) { - case HTP_TYPE_Q4_0: - op_type = "q4x2x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2; - } else { - matmul_id_job_func = htp_matvec_id_q4x4x2_q8x4x2; - } - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_Q8_0: - op_type = "q8x2x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2; - } else { - matmul_id_job_func = htp_matvec_id_q8x4x2_q8x4x2; - } - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_MXFP4: - op_type = "mxfp4x2x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2; - } else { - matmul_id_job_func = htp_matvec_id_mxfp4x4x2_q8x4x2; - } - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } + if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; + } - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); - default: - return HTP_STATUS_NO_SUPPORT; - } + const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; - FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", op_type, + FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size); - FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, + FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, - octx->ctx->vtcm_size, spad_size); + FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -2434,8 +2626,8 @@ int op_matmul_id(struct htp_ops_context * octx) { octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size; - octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even + octx->src0_spad.stride = src0_row_size_padded; + octx->src1_spad.stride = src1_row_size; if (src1_nrows > 1) { // initialize matrix_row_counts and map @@ -2447,8 +2639,7 @@ int op_matmul_id(struct htp_ops_context * octx) { // group rows by src0 matrix for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx for (uint32_t id = 0; id < n_ids; ++id) { // expert idx - const uint32_t i02 = - *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); assert(i02 >= 0 && i02 < n_as); @@ -2460,16 +2651,14 @@ int op_matmul_id(struct htp_ops_context * octx) { // Setup worker pool callbacks if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) { - // Run quant jobs const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); - octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; - worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs); + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); } if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - // Run matmul-id jobs const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, octx, n_matmul_jobs); + worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); } return HTP_STATUS_OK; From 39b5f414a3460b3048a92e60805ed185d06a951b Mon Sep 17 00:00:00 2001 From: Mario Limonciello Date: Thu, 12 Feb 2026 02:38:35 -0600 Subject: [PATCH 135/831] Add a workaround for compilation with ROCWMMA_FATTN and gfx9 (llama/19461) There is an upstream problem [1] with AMD's LLVM 22 fork and rocWMMA 2.2.0 causing compilation issues on devices without native fp16 support (CDNA devices). The specialized types aren't resolved properly: ``` /opt/rocm/include/rocwmma/internal/mfma_impl.hpp:2549:37: error: ambiguous partial specializations of 'amdgcn_mfma<__half, __half, __half, 16, 16, 16>' 2549 | using ARegsT = typename Impl::ARegsT; ``` Add a workaround to explicitly declare the types and cast when compiling with HIP and ROCWMMA_FATTN [2]. When this is actually fixed upstream some guards can be used to detect and wrap the version that has the fix to only apply when necessary. Link: https://github.com/ROCm/rocm-libraries/issues/4398 [1] Link: https://github.com/ggml-org/llama.cpp/issues/19269 [2] Signed-off-by: Mario Limonciello --- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 31 +++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 8694fd06c7b..35735d48b2e 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16( constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); +#if defined(GGML_USE_HIP) + typedef wmma::fragment frag_a_K; + typedef wmma::fragment frag_a_V; + typedef wmma::fragment frag_b; + typedef wmma::fragment frag_c_KQ; + typedef wmma::fragment frag_c_VKQ; +#else typedef wmma::fragment frag_a_K; typedef wmma::fragment frag_a_V; typedef wmma::fragment frag_b; typedef wmma::fragment frag_c_KQ; typedef wmma::fragment frag_c_VKQ; +#endif constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. @@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16( __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; + +#if defined(GGML_USE_HIP) + const _Float16 * K_h_f16 = reinterpret_cast(K_h); + const _Float16 * V_h_f16 = reinterpret_cast(V_h); + _Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ); + _Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ); +#else + const half * K_h_f16 = K_h; + const half * V_h_f16 = V_h; + half * KQ_f16 = KQ; + half * VKQ_f16 = VKQ; +#endif + #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16( for (int i0 = 0; i0 < D; i0 += 16) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded); } } @@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { frag_a_K K_a; - wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); @@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; wmma::load_matrix_sync( KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*(kqar*kqs_padded) + k, + KQ_f16 + j0*(kqar*kqs_padded) + k, kqar*kqs_padded); } } @@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); + wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); @@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { wmma::store_matrix_sync( - KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], D_padded, wmma::mem_col_major); } From d8e3e2ef0891ff858aeacc62b5b44ccaae62ba6f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 12 Feb 2026 11:35:28 +0200 Subject: [PATCH 136/831] metal : update sum_rows kernel to support float4 (llama/19524) --- ggml/src/ggml-metal/ggml-metal-device.cpp | 33 +++++++--- ggml/src/ggml-metal/ggml-metal-impl.h | 3 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 18 ++++-- ggml/src/ggml-metal/ggml-metal.metal | 79 ++++++++++++++--------- 4 files changed, 91 insertions(+), 42 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 517559d12a6..06f3d804590 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -328,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l } ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { - GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); char base[256]; char name[256]; - const char * op_str = "undefined"; + int op_num = -1; + switch (op->op) { - case GGML_OP_SUM_ROWS: - op_str = "sum_rows"; break; - case GGML_OP_MEAN: - op_str = "mean"; break; + case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break; + case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break; default: GGML_ABORT("fatal error"); }; - snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t_str = ggml_type_name(op->type); - snprintf(name, 256, "%s", base); + const bool is_c4 = op->src[0]->ne[0] % 4 == 0; + + snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : ""); + snprintf(name, 256, "%s_op=%d", base, op_num); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } res.smem = 32*sizeof(float); + if (is_c4) { + res.smem *= 4; + } + + res.c4 = is_c4; + return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 952e1be076e..383e0d6e93b 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -82,6 +82,7 @@ #define FC_COUNT_EQUAL 1100 #define FC_UNARY 1200 #define FC_BIN 1300 +#define FC_SUM_ROWS 1400 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -118,6 +119,8 @@ #define OP_UNARY_NUM_SOFTPLUS 115 #define OP_UNARY_NUM_EXPM1 116 +#define OP_SUM_ROWS_NUM_SUM_ROWS 10 +#define OP_SUM_ROWS_NUM_MEAN 11 // kernel argument structs // diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 7db95d1c84d..20880d9551e 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -904,6 +904,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + ggml_metal_kargs_sum_rows args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -925,21 +930,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + int nth = 32; // SIMD width - while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; } nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - nth = std::min(nth, ne00); + nth = std::min(nth, (int) args.ne00); const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0036ba90ec9..6c349aa0c92 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -77,6 +77,14 @@ static inline float dot(float x, float y) { return x*y; } +static inline float sum(float x) { + return x; +} + +static inline float sum(float4 x) { + return x[0] + x[1] + x[2] + x[3]; +} + // NOTE: this is not dequantizing - we are simply fitting the template template void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { @@ -1501,33 +1509,35 @@ kernel void kernel_op_sum_f32( } } -template -kernel void kernel_sum_rows( +constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]]; + +template +kernel void kernel_sum_rows_impl( constant ggml_metal_kargs_sum_rows & args, - device const float * src0, - device float * dst, - threadgroup float * shmem_f32 [[threadgroup(0)]], + device const char * src0, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - int64_t i3 = tgpig.z; - int64_t i2 = tgpig.y; - int64_t i1 = tgpig.x; +#define FC_OP FC_sum_rows_op - if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { - return; - } + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; + + threadgroup T0 * shmem_t = (threadgroup T0 *) shmem; if (sgitg == 0) { - shmem_f32[tiisg] = 0.0f; + shmem_t[tiisg] = 0.0f; } - device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); - device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); - float sumf = 0; + T0 sumf = T0(0.0f); for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { sumf += src_row[i0]; @@ -1538,23 +1548,33 @@ kernel void kernel_sum_rows( threadgroup_barrier(mem_flags::mem_threadgroup); if (tiisg == 0) { - shmem_f32[sgitg] = sumf; + shmem_t[sgitg] = sumf; } threadgroup_barrier(mem_flags::mem_threadgroup); - sumf = shmem_f32[tiisg]; + sumf = shmem_t[tiisg]; sumf = simd_sum(sumf); if (tpitg.x == 0) { - dst_row[0] = norm ? sumf / args.ne00 : sumf; + if (FC_OP == OP_SUM_ROWS_NUM_MEAN) { + if (is_same::value) { + dst_row[0] = sum(sumf) / (4*args.ne00); + } else { + dst_row[0] = sum(sumf) / args.ne00; + } + } else { + dst_row[0] = sum(sumf); + } } + +#undef FC_OP } -typedef decltype(kernel_sum_rows) kernel_sum_rows_t; +typedef decltype(kernel_sum_rows_impl) kernel_sum_rows_t; -template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; -template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; +template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl; +template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl; template kernel void kernel_cumsum_blk( @@ -2435,9 +2455,6 @@ kernel void kernel_solve_tri_f32( const short K = FC_solve_tri_k; const short NP = PAD2(N, NW); - const int32_t ne02 = args.ne02; - const int32_t ne03 = args.ne03; - const int32_t i03 = tgpig.z; const int32_t i02 = tgpig.y; const int32_t i01 = tgpig.x*NSG + sgitg; @@ -5949,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec( static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); - const short T = PK + NSG*SH; // shared memory size per query in (half) + //const short T = PK + NSG*SH; // shared memory size per query in (half) //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t @@ -8537,7 +8554,9 @@ kernel void kernel_mul_mm( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); +#ifdef GGML_METAL_HAS_TENSOR threadgroup float * sc = (threadgroup float *)(shmem); +#endif constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -8660,8 +8679,8 @@ kernel void kernel_mul_mm( const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; - const short dx = sx; - const short dy = sy; + //const short dx = sx; + //const short dy = sy; const short ly = (tiitg/NL1)%8; @@ -8910,7 +8929,9 @@ kernel void kernel_mul_mm_id( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); +#ifdef GGML_METAL_HAS_TENSOR threadgroup float * sc = (threadgroup float *)(shmem); +#endif constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -9045,8 +9066,8 @@ kernel void kernel_mul_mm_id( const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; - const short dx = sx; - const short dy = sy; + //const short dx = sx; + //const short dy = sy; const short ly = (tiitg/NL1)%8; From 9f87eeccdf982faf6aeeb1cef53c6f2681c792ff Mon Sep 17 00:00:00 2001 From: lhez Date: Thu, 12 Feb 2026 14:52:37 -0800 Subject: [PATCH 137/831] opencl: add basic support for q4_1 (llama/19534) * opencl: add q4_1 mv * opencl: clean up * opencl: add flattened q4_1 mv * opencl: clean up * opencl: add basic q4_1 mm * opencl: fix whitespace * opencl: add general q4_0 mm --- ggml/src/ggml-opencl/CMakeLists.txt | 4 + ggml/src/ggml-opencl/ggml-opencl.cpp | 404 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 51 +++ .../kernels/mul_mm_q4_0_f32_l4_lm.cl | 163 +++++++ .../kernels/mul_mm_q4_1_f32_l4_lm.cl | 165 +++++++ .../ggml-opencl/kernels/mul_mv_q4_1_f32.cl | 219 ++++++++++ .../kernels/mul_mv_q4_1_f32_flat.cl | 229 ++++++++++ 7 files changed, 1231 insertions(+), 4 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index b6094fb68b0..f3891936911 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -85,6 +85,8 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_0_f32_8x_flat mul_mv_q4_0_f32_1d_8x_flat mul_mv_q4_0_f32_1d_16x_flat + mul_mv_q4_1_f32 + mul_mv_q4_1_f32_flat mul_mv_q4_k_f32 mul_mv_q6_k_f32 mul_mv_q6_k_f32_flat @@ -101,6 +103,8 @@ set(GGML_OPENCL_KERNELS gemv_moe_mxfp4_f32 mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm + mul_mm_q4_0_f32_l4_lm + mul_mm_q4_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 40474c193bb..ae3f79fd0d6 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -525,6 +525,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_f16_f32_kq; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; + cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; @@ -532,6 +533,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_restore_block_q4_0_noshuffle; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; + cl_kernel kernel_mul_mv_q4_1_f32; + cl_kernel kernel_mul_mv_q4_1_f32_flat; cl_kernel kernel_mul_mv_q4_K_f32; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_mul_mv_q6_K_f32_flat; @@ -564,6 +567,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mv_id_mxfp4_f32_flat; cl_kernel kernel_mul_mm_f32_f32_l4_lm; cl_kernel kernel_mul_mm_f16_f32_l4_lm; + cl_kernel kernel_mul_mm_q4_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q4_1_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; cl_kernel kernel_mul_mm_q6_k_f32_l4_lm; @@ -888,6 +893,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err)); @@ -1119,6 +1126,40 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_q4_1_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_1_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_1_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_1_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_1_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_q4_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1361,6 +1402,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q4_0_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q4_0_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q4_0_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_0_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mm_q4_1_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q4_1_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q4_1_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_1_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + // mul_mm_q8_0_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2923,6 +2996,59 @@ struct ggml_tensor_extra_cl_q4_0 { } }; +struct ggml_tensor_extra_cl_q4_1 { + // Quantized values. + cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; + // Scales. + cl_mem d = nullptr; + // Scales in image1d_buffer_t. + cl_mem d_img = nullptr; + // Min + cl_mem m = nullptr; + // Min in image1d_buffer_t. + cl_mem m_img = nullptr; + // Size of quantized values. + size_t size_q = 0; + // Size of scales. + size_t size_d = 0; + // Size of min values. + size_t size_m = 0; + + ~ggml_tensor_extra_cl_q4_1() { + reset(); + } + + void reset() { + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. + // They must be properly released so that the original buffer can be + // properly released to avoid memory leak. + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (m != nullptr) { + CL_CHECK(clReleaseMemObject(m)); + m = nullptr; + } + // Currently, q_img and d_img are only initialized when SMALL_ALLOC is + // enabled. They point to the images in ggml_backend_opencl_buffer_context. + // So, there is no need to release them here. + // TODO: initialize them for non SMALL_PATH path, or remove them. + q_img = nullptr; + d_img = nullptr; + m_img = nullptr; + size_q = 0; + size_d = 0; + size_m = 0; + } +}; + struct ggml_tensor_extra_cl_mxfp4 { // Quantized values. cl_mem q = nullptr; @@ -3399,8 +3525,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return true; } else if (op->src[0]->type == GGML_TYPE_F32) { return op->src[1]->type == GGML_TYPE_F32; - } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 || - op->src[0]->type == GGML_TYPE_Q4_K || + } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || + op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q6_K) { return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } else if (op->src[0]->type == GGML_TYPE_Q8_0) { @@ -3629,6 +3756,21 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_q4_1 * ggml_opencl_alloc_temp_tensor_extra_q4_1() { + ggml_tensor_extra_cl_q4_1 * extra; + if (temp_tensor_extras_q4_1.empty()) { + extra = new ggml_tensor_extra_cl_q4_1(); + } else { + extra = temp_tensor_extras_q4_1.back(); + temp_tensor_extras_q4_1.pop_back(); + } + + temp_tensor_extras_q4_1_in_use.push_back(extra); + + extra->reset(); + return extra; + } + ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() { ggml_tensor_extra_cl_mxfp4 * extra; if (temp_tensor_extras_mxfp4.empty()) { @@ -3685,6 +3827,11 @@ struct ggml_backend_opencl_buffer_context { } temp_tensor_extras_q4_0_in_use.clear(); + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) { + temp_tensor_extras_q4_1.push_back(e); + } + temp_tensor_extras_q4_1_in_use.clear(); + for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) { temp_tensor_extras_mxfp4.push_back(e); } @@ -3710,6 +3857,8 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_in_use; std::vector temp_tensor_extras_q4_0; std::vector temp_tensor_extras_q4_0_in_use; + std::vector temp_tensor_extras_q4_1; + std::vector temp_tensor_extras_q4_1_in_use; std::vector temp_tensor_extras_mxfp4; std::vector temp_tensor_extras_mxfp4_in_use; std::vector temp_tensor_extras_q8_0; @@ -4079,6 +4228,75 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } + if (tensor->type == GGML_TYPE_Q4_1) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q4_1 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_1(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_m = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_d + size_m + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, mins, then quants. + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for mins. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_m; + extra->m = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_m, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + return; + } if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -4581,7 +4799,35 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, size, data, 0, NULL, NULL)); CL_CHECK(clReleaseMemObject(data_device)); return; - } else if (tensor->type == GGML_TYPE_MXFP4) { + } + if (tensor->type == GGML_TYPE_Q4_1) { + ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra; cl_int err; @@ -8409,6 +8655,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; @@ -8922,6 +9169,91 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q4_0: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q4_1: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } case GGML_TYPE_Q8_0: { if (ne11 < 32) { break; @@ -9262,7 +9594,71 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); #endif // GGML_OPENCL_SOA_Q break; - case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1: { +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q4_1_f32_flat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r3)); +#else + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q4_1_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } case GGML_TYPE_Q8_0: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_q8_0_f32_flat; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 9fb434713df..2c244ce3215 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -46,6 +46,15 @@ struct block_q4_0 uint8_t qs[QK4_0 / 2]; }; +//------------------------------------------------------------------------------ +// block_q4_1 +//------------------------------------------------------------------------------ +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + //------------------------------------------------------------------------------ // block_q6_K //------------------------------------------------------------------------------ @@ -148,6 +157,48 @@ kernel void kernel_restore_block_q4_0_noshuffle( } } +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_1 +// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_1( + global struct block_q4_1 * src0, + global uchar * dst_q, + global half * dst_d, + global half * dst_m +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + + for (int i = 0; i < QK4_1/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q4_1( + global uchar * src_q, + global half * src_d, + global half * src_m, + global struct block_q4_1 * dst +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + for (int i = 0; i < QK4_1/2; ++i) { + b->qs[i] = q[i]; + } +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl new file mode 100644 index 00000000000..4100e3080a2 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl @@ -0,0 +1,163 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q4_0_f32_l4_lm( + global uchar4 * src0_q, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + global uchar4 * qs = src0_q + ib*4 + iqs; + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)) - 8.0f)*d; + float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl new file mode 100644 index 00000000000..d0d2f08361e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl @@ -0,0 +1,165 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q4_1_f32_l4_lm( + global uchar4 * src0_q, + global half * src0_d, + global half * src0_m, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + float m = (float)src0_m[ib]; + global uchar4 * qs = src0_q + ib*4 + iqs; + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)))*d + m; + float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)))*d + m; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl new file mode 100644 index 00000000000..6fe828f20e7 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl @@ -0,0 +1,219 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_1 32 + +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + +inline float block_q4_1_dot_y( + global const struct block_q4_1 * qb_curr, + float sumy, + float16 yl, + int il +) { + float d = qb_curr->d; + float m = qb_curr->m; + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + global const ushort * qs = ((global const ushort *) qb_curr + 2 + il/2); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_1 * x = (global struct block_q4_1 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q4_1_dot_y(x+ib+0*nb, sumy, yl, il); + sumf.s1 += block_q4_1_dot_y(x+ib+1*nb, sumy, yl, il); + sumf.s2 += block_q4_1_dot_y(x+ib+2*nb, sumy, yl, il); + sumf.s3 += block_q4_1_dot_y(x+ib+3*nb, sumy, yl, il); + + yb += QK4_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_1_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl new file mode 100644 index 00000000000..d7c4645d675 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl @@ -0,0 +1,229 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_1 32 + +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + +inline float block_q4_1_dot_y_flat( + global const uchar * x, + global const half * dh, + global const half * mh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + float m = *mh; + global const ushort * qs = ((global const ushort *) x + il/2); + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_flat( + global void * src0_q, + global void * src0_d, + global void * src0_m, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + // The number of scales/mins is the same as the number of blocks. + ulong offset0_dm = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)); + // Each block contains QK4_1/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_1/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_dm; + global half * m = (global half *) src0_m + offset0_dm; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 0*nb*QK4_1/2, d + ib + 0*nb, m + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 1*nb*QK4_1/2, d + ib + 1*nb, m + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 2*nb*QK4_1/2, d + ib + 2*nb, m + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 3*nb*QK4_1/2, d + ib + 3*nb, m + ib + 3*nb, sumy, yl, il); + + yb += QK4_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_1_f32_flat( + global void * src0_q, + global void * src0_d, + global void * src0_m, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_q, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} From 195af60a8b9114de79ae4f3d269f8878199de8f0 Mon Sep 17 00:00:00 2001 From: Shupei Fan Date: Fri, 13 Feb 2026 07:07:49 +0800 Subject: [PATCH 138/831] hexagon: fix typo in vtcm_needs_release (llama/19545) --- ggml/src/ggml-hexagon/htp/main.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 62708eee5cf..92a1422896c 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -189,7 +189,7 @@ static int vtcm_release_callback(unsigned int rctx, void * state) { // otherwise we'll release it once we're done with the current Op. if (ctx->vtcm_inuse) { - ctx->vtcm_needs_release = false; + ctx->vtcm_needs_release = true; return 0; } From c5325e50fce6ee1d8344f0dd0a215f996989a2df Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 13 Feb 2026 07:34:52 +0200 Subject: [PATCH 139/831] metal : support GGML_OP_SET (llama/19548) --- ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 132 ++++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + 3 files changed, 134 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 4ea0bfb94da..b4ca9c5dd6f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1159,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: return has_simdgroup_reduction; + case GGML_OP_SET: case GGML_OP_CPY: case GGML_OP_DUP: case GGML_OP_CONT: diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 20880d9551e..c04e9fc7ffa 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -426,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx); } break; + case GGML_OP_SET: + { + n_fuse = ggml_metal_op_set(ctx, idx); + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: @@ -1609,6 +1613,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + const size_t pnb1 = ((const int32_t *) op->op_params)[0]; + const size_t pnb2 = ((const int32_t *) op->op_params)[1]; + const size_t pnb3 = ((const int32_t *) op->op_params)[2]; + const size_t offs = ((const int32_t *) op->op_params)[3]; + + const bool inplace = (bool) ((const int32_t *) op->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + //const id pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj; + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + + ggml_metal_kargs_cpy args = { + /*.nk0 =*/ ne00, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + ggml_metal_op_concurrency_reset(ctx); + } + + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type); + + GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0); + + int64_t nk0 = ne10; + if (ggml_is_quantized(op->src[1]->type)) { + nk0 = ne10/16; + } else if (ggml_is_quantized(op->type)) { + nk0 = ne10/ggml_blck_size(op->type); + } + + int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + // when rows are small, we can batch them together in a single threadgroup + int nrptg = 1; + + // TODO: relax this constraint in the future + if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) { + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; + + if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nrptg--; + } + } + } + + nth = std::min(nth, nk0); + + ggml_metal_kargs_cpy args = { + /*.nk0 =*/ nk0, + /*.ne00 =*/ ne10, + /*.ne01 =*/ ne11, + /*.ne02 =*/ ne12, + /*.ne03 =*/ ne13, + /*.nb00 =*/ nb10, + /*.nb01 =*/ nb11, + /*.nb02 =*/ nb12, + /*.nb03 =*/ nb13, + /*.ne0 =*/ ne10, + /*.ne1 =*/ ne11, + /*.ne2 =*/ ne12, + /*.ne3 =*/ ne13, + /*.nb0 =*/ ggml_element_size(op), + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + }; + + const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1; + + bid_dst.offs += offs; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1); + + return 1; +} + int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 29456d70d5e..f3e38c7aa9d 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -59,6 +59,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_set (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx); From 0e94faa19cdcb1e1f779b4c86a3087b715785e19 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 13 Feb 2026 07:35:57 +0200 Subject: [PATCH 140/831] metal : improve concurrency (llama/19555) --- ggml/src/ggml-metal/ggml-metal-common.cpp | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp index 95627d38665..87e13786849 100644 --- a/ggml/src/ggml-metal/ggml-metal-common.cpp +++ b/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -264,15 +264,25 @@ static std::vector ggml_metal_graph_optimize_reorder(const std::vector ggml_metal_graph_optimize_reorder(const std::vector Date: Fri, 13 Feb 2026 10:37:55 +0100 Subject: [PATCH 141/831] CUDA: Do not mutate cgraph for fused ADDs (llama/19566) * Do not mutate cgraph for fused ADDs 1. We should try to minimize in-place changes to the incoming ggml_cgraph where possible (those should happen in graph_optimize) 2. Modifying in-place leads to an additional, unnecessary graph capture step as we store the properties before modifying the graph in-place in the cuda-backend * Assert ggml_tensor is trivially copyable * Update ggml/src/ggml-cuda/ggml-cuda.cu Co-authored-by: Aman Gupta --------- Co-authored-by: Aman Gupta --- ggml/src/ggml-cuda/ggml-cuda.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b163468789f..7dc688483ad 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3640,11 +3640,13 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud n_fuse++; if (n_fuse > 1) { + ggml_tensor fused_add_node; + memcpy(&fused_add_node, node, sizeof(ggml_tensor)); for (int j = 0; j < n_fuse - 1; ++j) { - node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; } - cgraph->nodes[i + n_fuse - 1]->data = node->data; - ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse); + fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data; + ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse); i += n_fuse - 1; continue; From 58e3d5a42dedfdfdd397bb80838cc9fa4e23bb89 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 13 Feb 2026 17:01:40 +0530 Subject: [PATCH 142/831] CUDA: loop over ne2*ne3 in case it overflows (llama/19538) * CUDA: loop over ne2*ne3 in case it overflows * use fastdiv --- ggml/src/ggml-cuda/convert.cu | 62 +++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index ba3d4eeb880..09b6d5db6a0 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -7,7 +7,8 @@ template static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, - const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t ne00, const int64_t ne01, + const int64_t ne0203, const uint3 ne02, const int64_t s01, const int64_t s02, const int64_t s03) { const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x); @@ -16,23 +17,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ } const int64_t i01 = blockIdx.y; - const int64_t i02 = blockIdx.z % ne02; - const int64_t i03 = blockIdx.z / ne02; - const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; - const int64_t ib = ibx0 + i00/qk; // block index - const int64_t iqs = (i00%qk)/qr; // quant index - const int64_t iybs = i00 - i00%qk; // y block start index - const int64_t y_offset = qr == 1 ? 1 : qk/2; + const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; - // dequantize - float2 v; - dequantize_kernel(vx, ib, iqs, v); + const int64_t ib = ibx0 + i00/qk; // block index + const int64_t iqs = (i00%qk)/qr; // quant index + const int64_t iybs = i00 - i00%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; - const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs; - y[iy0 + 0] = ggml_cuda_cast(v.x); - y[iy0 + y_offset] = ggml_cuda_cast(v.y); + // dequantize + float2 v; + dequantize_kernel(vx, ib, iqs, v); + + const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs; + y[iy0 + 0] = ggml_cuda_cast(v.x); + y[iy0 + y_offset] = ggml_cuda_cast(v.y); + } } template @@ -485,9 +490,11 @@ template static void dequantize_block_cuda(const void * vx, dst_t * y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { - const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03); + const int64_t ne0203 = ne02*ne03; + const uint3 ne02_fdv = init_fastdiv_values(ne02); + const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535)); dequantize_block<<>> - (vx, y, ne00, ne01, ne02, s01, s02, s03); + (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } template @@ -612,7 +619,8 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t template static __global__ void convert_unary( - const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, + const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, + const int64_t ne0203, const uint3 ne02, const int64_t s01, const int64_t s02, const int64_t s03) { const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; @@ -621,23 +629,29 @@ static __global__ void convert_unary( } const int64_t i01 = blockIdx.y; - const int64_t i02 = blockIdx.z % ne02; - const int64_t i03 = blockIdx.z / ne02; const src_t * x = (const src_t *) vx; - const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; - const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00; - y[iy] = ggml_cuda_cast(x[ix]); + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; + + const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; + const int64_t iy = (i0203*ne01 + i01)*ne00 + i00; + y[iy] = ggml_cuda_cast(x[ix]); + } } template static void convert_unary_cuda(const void * vx, dst_t * y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { - const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03); + const int64_t ne0203 = ne02*ne03; + const uint3 ne02_fdv = init_fastdiv_values(ne02); + const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535)); convert_unary<<>> - (vx, y, ne00, ne01, ne02, s01, s02, s03); + (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } template From 628b545b7e2b19de43ce81279be4986e75dac2f5 Mon Sep 17 00:00:00 2001 From: ymcki <84055651+ymcki@users.noreply.github.com> Date: Fri, 13 Feb 2026 20:31:37 +0800 Subject: [PATCH 143/831] fix vulkan ggml_acc only works in 3d but not 4d (llama/19426) * fix vulkan ggml_acc only works in 3d but not 4d * removed clamp in test_acc_block * use the correct stride and its test case * cuda : fix "supports op" condition * change src0 to src1 in ggml_vk_acc. Update acc.comp with jeffbolznv\'s suggestion except to keep the boundary check * version without boundary check * revert back to boundary check version --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cuda/ggml-cuda.cu | 5 ++++- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 12 ++++++------ ggml/src/ggml-vulkan/vulkan-shaders/acc.comp | 17 +++++++++-------- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 7dc688483ad..85ce96958fa 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4822,8 +4822,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: - case GGML_OP_ACC: return true; + case GGML_OP_ACC: + // TODO: extend support like so: + //return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]); + return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); case GGML_OP_SUM: return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_TOP_K: diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 72097ffd0ff..e5dcd3cbda2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -9801,16 +9801,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 - int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 - // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused - int offset = dst->op_params[3] / 4; // offset in bytes + int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32 + int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32 + int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32 + int offset = dst->op_params[3] / src0_type_size; // offset in bytes ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, { (uint32_t)ggml_nelements(src0), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3, 0, 0.0f, 0.0f, offset, }); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp index 5084a70ed49..3d61168b56f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp @@ -13,17 +13,18 @@ void main() { const uint offset = p.param3; const uint src1_i = idx - offset; - const uint oz = src1_i / p.nb02; - const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; - const uint ox = src1_i % p.nb01; + const uint i3 = src1_i / p.nb03; + const uint rem2 = src1_i - i3 * p.nb03; + const uint i2 = rem2 / p.nb02; + const uint rem1 = rem2 - i2 * p.nb02; + const uint i1 = rem1 / p.nb01; + const uint i0 = rem1 % p.nb01; uint i00, i01, i02, i03; - get_indices(idx, i00, i01, i02, i03); - if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { - data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); + if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) { + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)])); } else { - data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)])); + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx])); } } - From e8a25654b261d16f23706ae47e050672fc429402 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alberto=20Cabrera=20P=C3=A9rez?= <1478977+Alcpz@users.noreply.github.com> Date: Fri, 13 Feb 2026 12:32:14 +0000 Subject: [PATCH 144/831] Fix wrong memcpy length for block_interleave == 4 (llama/19575) --- ggml/src/ggml-cpu/repack.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 4cb7cdeb07b..f94426ddd7f 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1916,9 +1916,10 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in int src_offset = (i / 8) * blck_size_interleave; int dst_offset = i * blck_size_interleave; + // buffer large enough for the max interleave block size (8 bytes) uint64_t elems; - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); - memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + memcpy(&elems, &in[src_id].qs[src_offset], blck_size_interleave); + memcpy(&out.qs[dst_offset], &elems, blck_size_interleave); } // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K From ec57bf407cb1b02998bde2b395f27eb96b0e9bc8 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 13 Feb 2026 11:35:29 -0800 Subject: [PATCH 145/831] vulkan: restore -inf check in FA shaders (llama/19582) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 +++- .../ggml-vulkan/vulkan-shaders/flash_attn.comp | 17 +++++++++++++++++ .../vulkan-shaders/flash_attn_cm1.comp | 15 +++++++++++++++ .../vulkan-shaders/flash_attn_cm2.comp | 14 ++++++++++++++ 4 files changed, 49 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e5dcd3cbda2..82933ae0330 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8422,6 +8422,8 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t acctype = f32acc ? 4 : 2; const uint32_t f16vec4 = 8; + const uint32_t tmpsh = (Bc / MatBc) * sizeof(float); + const uint32_t qstride = hsk_pad / 4 + 2; const uint32_t Qf = Br * qstride * f16vec4; @@ -8438,7 +8440,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t slope = Br * acctype; - const uint32_t total_size = Qf + Psh + sfsh + ksh + slope; + const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 914f131c965..0735f678549 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -130,6 +130,7 @@ void main() { if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + float max_mask = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; @@ -137,12 +138,25 @@ void main() { if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); masksh[c][r] = m; + max_mask = max(max_mask, m); } else { masksh[c][r] = float(0); } } } + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } } float Sf[Br][cols_per_thread]; @@ -260,6 +274,9 @@ void main() { barrier(); } + // prevent race on tmpsh + barrier(); + // reduce across threads [[unroll]] for (uint32_t r = 0; r < Br; ++r) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index b3177738234..19630972daf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -42,6 +42,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } +shared float tmpsh[row_split]; + const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 shared f16vec4 Qf[Br * qstride]; @@ -213,6 +215,19 @@ void main() { } } } + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); + barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 39f0c4d23b9..853f17fa16e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -176,7 +176,14 @@ void main() { tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t + coopmat mvmax; + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + // skip the block if the mask is entirely -inf + coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); + if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { + continue; + } } else { tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); // Don't clamp against nem1 when GQA is enabled @@ -184,7 +191,14 @@ void main() { tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV); tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + coopmat mvmax; + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + // skip the block if the mask is entirely -inf + coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); + if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { + continue; + } } } } From e6476d4c12f8e921bea9be6e0f65f4e07cbe08e3 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Fri, 13 Feb 2026 16:27:30 -0800 Subject: [PATCH 146/831] hexagon: further optimizations and refactoring for flash attention (llama/19583) * ggml-hexagon: fa improvements ggml-hexagon: optimize flash attention calculations with improved variable handling ggml-hexagon: streamline flash attention operations by removing redundant checks for FP32 ggml-hexagon: optimize hvx_dot_f16_f16_aa_rx2 by simplifying variable handling for unused elements ggml-hexagon: optimize flash attention by changing slope vector type to F16 * hexfa: fixed test-backend-ops failurs due to leftover element handling * hexagon: refactor and optimize fa to use local context struct * ggml-hexagon: optimize flash-attention using hvx_vec_expf Use HVX for online softmax. --------- Co-authored-by: chraac --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 524 ++++++++++----------- 1 file changed, 253 insertions(+), 271 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index c1846374437..74c777d4c3e 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -17,121 +17,6 @@ #include "htp-msg.h" #include "htp-ops.h" -static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) { - HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements - HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements - return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); -} - -// Dot product of FP32 and FP16 vectors, accumulating to float -static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 - const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 - - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements - - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum = Q6_V_vsplat_R(0); - - uint32_t i = 0; - - #pragma unroll(4) - for (i = 0; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x_hf = vx[i]; - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x_hf = vx[i]; - - // Zero-out unused elements - // Note that we need to clear both x and y because they may contain NANs - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - x_hf = Q6_V_vand_QV(bmask, x_hf); - y_hf = Q6_V_vand_QV(bmask, y_hf); - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); - } - - rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); - hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); -} - -// Dot product of FP32 and FP16 vectors, accumulating to float -static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r, - const void * restrict y, - const void * restrict x0, - const void * restrict x1, - unsigned int n, - float s) { - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 - const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 - const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 - - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements - - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); - - uint32_t i = 0; - - #pragma unroll(2) - for (i = 0; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - // Load x (fp16) - HVX_Vector x0_hf = vx0[i]; - HVX_Vector x1_hf = vx1[i]; - - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x0_hf = vx0[i]; - HVX_Vector x1_hf = vx1[i]; - - // Zero-out unused elements - // Note that we need to clear both x and y because they may contain NANs - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - x0_hf = Q6_V_vand_QV(bmask, x0_hf); - x1_hf = Q6_V_vand_QV(bmask, x1_hf); - y_hf = Q6_V_vand_QV(bmask, y_hf); - - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); - } - - HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); - hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); -} - // Dot product of two F16 vectors, accumulating to float static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 @@ -140,8 +25,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vsplat_R(0); uint32_t i = 0; @@ -156,11 +40,10 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict } if (nloe) { - HVX_Vector y_hf = vy[i]; - // Load x (fp16) and zero-out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); @@ -181,12 +64,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); + HVX_Vector rsum0 = Q6_V_vsplat_R(0); + HVX_Vector rsum1 = Q6_V_vsplat_R(0); uint32_t i = 0; @@ -204,12 +86,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, } if (nloe) { - HVX_Vector y_hf = vy[i]; - // Load x (fp16) and zero-out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); - HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); @@ -222,7 +103,7 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); } -// MAD: y (F32) += x (F16) * s (float) +// MAD: y (F32) += x (F16) * s (F32) static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; HVX_Vector * restrict ptr_y = (HVX_Vector *) y; @@ -259,15 +140,125 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict } } +// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32) +static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, + const void * restrict x0, + const void * restrict x1, + float s0, + float s1, + int n) { + const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0; + const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1; + HVX_Vector * restrict ptr_y = (HVX_Vector *) y; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector S0 = hvx_vec_splat_f16(s0); + HVX_Vector S1 = hvx_vec_splat_f16(s1); + + uint32_t i = 0; + #pragma unroll(2) + for (i = 0; i < nvec; ++i) { + // Multiply x * s -> pair of F32 vectors + HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0); + HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1); + + HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p)); + HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p)); + + ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2])); + ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1])); + } + + if (nloe) { + HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0); + HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1); + + HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p)); + HVX_Vector xs = xs_p_lo; + i = 2 * i; // index for ptr_y + + if (nloe >= 32) { + ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + nloe -= 32; ++i; + xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p)); + } + + if (nloe) { + HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + hvx_vec_store_a(&ptr_y[i], nloe * 4, xy); + } + } +} + #define FLASH_ATTN_BLOCK_SIZE 128 -static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) { +struct htp_fa_context { + const struct htp_ops_context * octx; + + struct fastdiv_values src0_div21; + struct fastdiv_values src0_div1; + + struct fastdiv_values broadcast_rk2; + struct fastdiv_values broadcast_rk3; + struct fastdiv_values broadcast_rv2; + struct fastdiv_values broadcast_rv3; + + struct fastdiv_values src3_div2; + struct fastdiv_values src3_div3; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t n_blocks; + + size_t size_q_row_padded; + size_t size_k_row_padded; + size_t size_v_row_padded; + + size_t size_k_block; + size_t size_v_block; + size_t size_m_block; + + bool is_q_fp32; +}; + +static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) { + assert((size_t) dst % 128 == 0); + assert((size_t) src % 128 == 0); + + const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src; + HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst; + + const uint32_t nvec = n / VLEN_FP32; + const uint32_t nloe = n % VLEN_FP32; + + uint32_t i = 0; + #pragma unroll(4) + for (; i < nvec; ++i) { + vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs)); + } + if (nloe) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v)); + } +} + +static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_fa_context * factx = (struct htp_fa_context *) data; + const struct htp_ops_context * octx = factx->octx; const struct htp_tensor * q = &octx->src0; const struct htp_tensor * k = &octx->src1; const struct htp_tensor * v = &octx->src2; const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * dst = &octx->dst; const uint32_t neq0 = q->ne[0]; const uint32_t neq1 = q->ne[1]; @@ -304,18 +295,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const uint32_t nb2 = dst->nb[2]; const uint32_t nb3 = dst->nb[3]; - float scale = 1.0f; - float max_bias = 0.0f; - float logit_softcap = 0.0f; - - memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); - - if (logit_softcap != 0) { - scale /= logit_softcap; - } - // total rows in q const uint32_t nr = neq1*neq2*neq3; @@ -331,18 +310,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const uint32_t DV = nev0; const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2); - const size_t size_q_row_padded = hex_round_up(size_q_row, 128); - const size_t size_k_row = DK * sizeof(__fp16); const size_t size_v_row = DV * sizeof(__fp16); - const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask - - const size_t size_k_row_padded = hex_round_up(size_k_row, 128); - const size_t size_v_row_padded = hex_round_up(size_v_row, 128); - - const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; - const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; - const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith; @@ -351,31 +320,28 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith; uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith; - const uint32_t n_head = neq2; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap); for (uint32_t ir = ir0; ir < ir1; ++ir) { - const uint32_t iq3 = fastdiv(ir, &octx->src0_div21); - const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1); + const uint32_t iq3 = fastdiv(ir, &factx->src0_div21); + const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1); const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1); - const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3); - const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2); + const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3); + const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2); - const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3); - const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2); + const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3); + const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2); // Fetch Q row const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3); - dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1); + dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1); const uint32_t h = iq2; // head index - const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f; + const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f; - float S = 0.0f; // sum - float M = -INFINITY; // maximum KQ value + HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); + HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); // Clear accumulator hvx_splat_f32_a(spad_a, 0, DV); @@ -383,40 +349,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const __fp16 * mp_base = NULL; if (mask) { - const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2); - const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3); + const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2); + const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3); mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]); } - const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE; - // Prefetch first two blocks - for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) { + for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) { const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); // K const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); - uint8_t * k_dst = spad_k + (ib % 2) * size_k_block; - dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size); + uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block; + dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size); // V const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); - uint8_t * v_dst = spad_v + (ib % 2) * size_v_block; - dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size); + uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block; + dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size); // Mask if (mask) { const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start); - uint8_t * m_dst = spad_m + (ib % 2) * size_m_block; + uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block; // Mask is 1D contiguous for this row dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); } } - const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + if (factx->is_q_fp32) { + hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16 + } - for (uint32_t ib = 0; ib < n_blocks; ++ib) { + const HVX_Vector slope_vec = hvx_vec_splat_f16(slope); + for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) { const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); @@ -428,8 +396,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in // Inner loop processing the block from VTCM uint32_t ic = 0; - const bool is_q_fp32 = (q->type == HTP_TYPE_F32); - // Process in blocks of 32 (VLEN_FP32) static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); HVX_Vector_x4 scores_x4; @@ -437,22 +403,18 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; - for (int j = 0; j < VLEN_FP32; j += 2) { + for (uint32_t j = 0; j < VLEN_FP32; j += 2) { const uint32_t cur_ic = ic + j; - const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; - if (is_q_fp32) { - hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); - } else { - hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); - } + const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded; + hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale); } HVX_Vector scores = *(HVX_Vector *) scores_arr; // 2. Softcap - if (logit_softcap != 0.0f) { + if (factx->logit_softcap != 0.0f) { scores = hvx_vec_tanh_f32(scores); - scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap)); + scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap); scores = Q6_Vsf_equals_Vqf32(scores); } @@ -460,70 +422,59 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in if (mask) { const __fp16 * mp = m_base + ic; HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp; - - HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00); - HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16); - - HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair)); - - HVX_Vector slope_vec = hvx_vec_splat_f32(slope); - HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec); - scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val)); + HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); + HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); + scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores); scores = Q6_Vsf_equals_Vqf32(scores); } scores_x4.v[iv] = scores; - v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max); + v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max } { // 4. Online Softmax Update - v_max = hvx_vec_reduce_max_f32(v_max); - float m_block = hvx_vec_get_f32(v_max); - float M_old = M; - float M_new = (m_block > M) ? m_block : M; - M = M_new; + HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec); + HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec); + HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec)); + M_vec = M_new_vec; - const float ms = expf(M_old - M_new); - hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); - HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new); HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) { HVX_Vector scores = scores_x4.v[iv]; - HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); + HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec); HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P)); // 5. Accumulate V float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; - *(HVX_Vector*)p_arr = P; + *(HVX_Vector *) p_arr = P; - for (int j = 0; j < VLEN_FP32; ++j) { - const uint32_t cur_ic = ic2 + j; - const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; - hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + for (uint32_t j = 0; j < VLEN_FP32; j += 2) { + const uint32_t cur_ic = ic2 + j; + const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; + hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV); } } p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec); - S = S * ms + hvx_vec_get_f32(p_sum_vec); + S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec)); } + // Sync scalars for leftover/next block if needed + float M = hvx_vec_get_f32(M_vec); + float S = hvx_vec_get_f32(S_vec); + // Leftover for (; ic < current_block_size; ++ic) { float s_val; - const uint8_t * k_ptr = k_base + ic * size_k_row_padded; - - if (is_q_fp32) { - hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); - } else { - hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); - } - - if (logit_softcap != 0.0f) { - s_val = logit_softcap * tanhf(s_val); + const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded; + hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale); + if (factx->logit_softcap != 0.0f) { + s_val = factx->logit_softcap * tanhf(s_val); } if (mask) { @@ -532,37 +483,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in } const float Mold = M; - float ms = 1.0f; float vs = 1.0f; if (s_val > M) { M = s_val; - ms = expf(Mold - M); - hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M); + HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); + hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); + + float ms = hvx_vec_get_f32(ms_vec); + S = S * ms + vs; } else { - vs = expf(s_val - M); + HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M); + vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); + S += vs; } - const uint8_t * v_ptr = v_base + ic * size_v_row_padded; + const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded; hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs); - - S = S * ms + vs; } + M_vec = hvx_vec_splat_f32(M); + S_vec = hvx_vec_splat_f32(S); // Issue DMA for next+1 block (if exists) - if (ib + 2 < n_blocks) { + if (ib + 2 < factx->n_blocks) { const uint32_t next_ib = ib + 2; const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE; const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start); // K const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); - dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size); + dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size); // V const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); - dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size); + dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size); // Mask if (mask) { @@ -573,20 +529,26 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in } // sinks + float M = hvx_vec_get_f32(M_vec); + float S = hvx_vec_get_f32(S_vec); + if (sinks) { const float s = ((float *)((char *) sinks->data))[h]; - float ms = 1.0f; float vs = 1.0f; if (s > M) { - ms = expf(M - s); - hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + HVX_Vector diff_vec = hvx_vec_splat_f32(M - s); + HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); + hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); + + float ms = hvx_vec_get_f32(ms_vec); + S = S * ms + vs; } else { - vs = expf(s - M); + HVX_Vector diff_vec = hvx_vec_splat_f32(s - M); + vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); + S += vs; } - - S = S * ms + vs; } const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; @@ -609,53 +571,73 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in } } -static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - flash_attn_ext_f16_thread(octx, i, n); -} - int op_flash_attn_ext(struct htp_ops_context * octx) { const struct htp_tensor * q = &octx->src0; const struct htp_tensor * k = &octx->src1; const struct htp_tensor * v = &octx->src2; - const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; + const struct htp_tensor * dst = &octx->dst; // Check support - if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || - k->type != HTP_TYPE_F16 || - v->type != HTP_TYPE_F16) { + if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) { return HTP_STATUS_NO_SUPPORT; } - octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); - octx->src0_div1 = init_fastdiv_values(q->ne[1]); + struct htp_fa_context factx; + factx.octx = octx; + + factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); + factx.src0_div1 = init_fastdiv_values(q->ne[1]); - octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]); - octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]); - octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]); - octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]); + factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]); + factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]); + factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]); + factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]); if (mask) { - octx->src3_div2 = init_fastdiv_values(mask->ne[2]); - octx->src3_div3 = init_fastdiv_values(mask->ne[3]); + factx.src3_div2 = init_fastdiv_values(mask->ne[2]); + factx.src3_div3 = init_fastdiv_values(mask->ne[3]); + } + + factx.is_q_fp32 = (q->type == HTP_TYPE_F32); + factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128); + factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128); + factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128); + + size_t size_q_block = factx.size_q_row_padded * 1; // single row for now + factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; + factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; + factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + + factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE; + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; } - size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128); - size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128); - size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128); + factx.scale = scale; + factx.max_bias = max_bias; + factx.logit_softcap = logit_softcap; - size_t size_q_block = size_q_row_padded * 1; // single row for now - size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; - size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; - size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + uint32_t n_head = q->ne[2]; + factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2); + factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 octx->src0_spad.size_per_thread = size_q_block * 1; - octx->src1_spad.size_per_thread = size_k_block * 2; - octx->src2_spad.size_per_thread = size_v_block * 2; - octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0; + octx->src1_spad.size_per_thread = factx.size_k_block * 2; + octx->src2_spad.size_per_thread = factx.size_v_block * 2; + octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0; octx->dst_spad.size_per_thread = size_vkq_acc; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; @@ -677,7 +659,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads); + worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads); } return HTP_STATUS_OK; From fc6bbab817d1bf108f17a723e9c9b165b50b3499 Mon Sep 17 00:00:00 2001 From: Sophon Date: Sat, 14 Feb 2026 13:29:17 +0800 Subject: [PATCH 147/831] vulkan: Add vendor id for Qualcomm drivers (llama/19569) This commit allows Qualcomm native vulkan driver to be used on Windows instead of Mesa Dozen. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 82933ae0330..e919d2223e7 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -92,6 +92,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_VENDOR_ID_APPLE 0x106b #define VK_VENDOR_ID_INTEL 0x8086 #define VK_VENDOR_ID_NVIDIA 0x10de +#define VK_VENDOR_ID_QUALCOMM 0x5143 #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 @@ -5641,6 +5642,10 @@ static void ggml_vk_instance_init() { driver_priorities[vk::DriverId::eMesaNvk] = 2; #endif break; + case VK_VENDOR_ID_QUALCOMM: + driver_priorities[vk::DriverId::eQualcommProprietary] = 1; + driver_priorities[vk::DriverId::eMesaTurnip] = 2; + break; } driver_priorities[vk::DriverId::eMesaDozen] = 100; From 197e9ab6eba776d6fe49aa2f73d5fc4194aa5320 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 13 Feb 2026 21:36:38 -0800 Subject: [PATCH 148/831] vulkan: support GGML_OP_SET (llama/19584) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 +++++++++++++++++--- ggml/src/ggml-vulkan/vulkan-shaders/acc.comp | 9 ++++++++- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e919d2223e7..a9f75f0d00d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -688,6 +688,7 @@ struct vk_device_struct { vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_acc_f32; + vk_pipeline pipeline_set_f32; // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16] vk_pipeline pipeline_add[2][2][2]; @@ -4182,7 +4183,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -8822,6 +8824,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_acc_f32; } return nullptr; + case GGML_OP_SET: + if (src0->type == src1->type && src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) { + return ctx->device->pipeline_set_f32; + } + return nullptr; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -9813,7 +9821,7 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32 int offset = dst->op_params[3] / src0_type_size; // offset in bytes - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, @@ -12507,6 +12515,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_ACC: + case GGML_OP_SET: ggml_vk_acc(ctx, compute_ctx, src0, src1, node); break; @@ -14967,7 +14976,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: - return op->src[0]->type == GGML_TYPE_F32; + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_SET: + return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32); case GGML_OP_CONCAT: return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32); case GGML_OP_ADD1: @@ -15618,6 +15630,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ACC) { tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); + } else if (tensor->op == GGML_OP_SET) { + tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); } else if (tensor->op == GGML_OP_NORM) { tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_GROUP_NORM) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp index 3d61168b56f..6ba3d1d89e0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp @@ -3,6 +3,9 @@ #include "types.glsl" #include "generic_binary_head.glsl" +// false for SET, true for ACC +layout(constant_id = 1) const bool ACC = true; + layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; void main() { @@ -23,7 +26,11 @@ void main() { uint i00, i01, i02, i03; if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) { - data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)])); + if (ACC) { + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)])); + } else { + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)])); + } } else { data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx])); } From cc448def01c660ab1da099a8f3916dc78e30f0e2 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 13 Feb 2026 21:42:04 -0800 Subject: [PATCH 149/831] vulkan: support L2_NORM with contiguous rows (llama/19604) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 12 ++++++++---- .../ggml-vulkan/vulkan-shaders/l2_norm.comp | 19 +++++++++++-------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a9f75f0d00d..114992da08d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4082,7 +4082,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -10639,8 +10639,10 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub } static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f }); + const float * op_params = (const float *)dst->op_params; + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = op_params[0]; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p)); } static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -14912,8 +14914,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; case GGML_OP_NORM: case GGML_OP_GROUP_NORM: - case GGML_OP_L2_NORM: return ggml_is_contiguous(op->src[0]); + case GGML_OP_L2_NORM: + return ggml_is_contiguous_rows(op->src[0]) && + op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp index 83ef2f87958..7d0a1de0df9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -1,6 +1,6 @@ #version 450 -#include "generic_head.glsl" +#include "generic_unary_head.glsl" #include "types.glsl" #extension GL_EXT_control_flow_attributes : enable @@ -8,19 +8,22 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - shared FLOAT_TYPE sum[BLOCK_SIZE]; void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + const uint i3 = row / (p.ne11 * p.ne12); + const uint i3_offset = i3 * p.ne12 * p.ne11; + const uint i2 = (row - i3_offset) / p.ne11; + const uint i2_offset = i2 * p.ne11; + const uint i1 = row - i3_offset - i2_offset; + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]); sum[tid] += xi * xi; } @@ -35,7 +38,7 @@ void main() { const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1))); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { + data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0])); } } From fbdac5119c9a8a7a7b2d8f84f49bc3890a3f561d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Feb 2026 09:54:03 +0200 Subject: [PATCH 150/831] metal : fix ACC op (llama/19427) --- ggml/src/ggml-metal/ggml-metal-device.m | 2 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 28 +++++++++++++++---------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index b4ca9c5dd6f..3db7f126291 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1067,8 +1067,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_ADD_ID: - return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: + return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_REPEAT: case GGML_OP_CONV_TRANSPOSE_1D: return true; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index c04e9fc7ffa..3d5db0b79f5 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -620,8 +620,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); GGML_ASSERT(op->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(op->src[0])); - GGML_ASSERT(ggml_is_contiguous(op->src[1])); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); const size_t pnb1 = ((const int32_t *) op->op_params)[0]; const size_t pnb2 = ((const int32_t *) op->op_params)[1]; @@ -671,10 +671,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { } ggml_metal_kargs_bin args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, + /*.ne00 =*/ ne10, + /*.ne01 =*/ ne11, + /*.ne02 =*/ ne12, + /*.ne03 =*/ ne13, /*.nb00 =*/ nb00, /*.nb01 =*/ pnb1, /*.nb02 =*/ pnb2, @@ -687,10 +687,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.nb11 =*/ nb11, /*.nb12 =*/ nb12, /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, + /*.ne0 =*/ ne10, + /*.ne1 =*/ ne11, + /*.ne2 =*/ ne12, + /*.ne3 =*/ ne13, /*.nb0 =*/ nb0, /*.nb1 =*/ pnb1, /*.nb2 =*/ pnb2, @@ -707,7 +707,13 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + int nth = 1; + + while (2*nth < args.ne0 && nth < nth_max) { + nth *= 2; + } ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1); From 226e8c041c720e1293b519ce928487c702fff5c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Sat, 14 Feb 2026 11:22:57 +0100 Subject: [PATCH 151/831] ggml : fix GGML_DEBUG with OpenMP (llama/19599) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit last_graph is only available without OpenMP, but ggml_graph_compute_thread() is called in both cases. Signed-off-by: Adrien Gallouët --- ggml/src/ggml-cpu/ggml-cpu.c | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index b003fe13fd9..e048d5e5e77 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2947,7 +2947,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { /*.use_ref =*/ cplan->use_ref, }; - GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); +#ifdef GGML_USE_OPENMP + GGML_PRINT_DEBUG("thread #%d compute-start cplan %p\n", state->ith, (const void *)cplan); +#else + GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph); +#endif for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) { struct ggml_tensor * node = cgraph->nodes[node_n]; @@ -2974,7 +2978,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } } - GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); +#ifdef GGML_USE_OPENMP + GGML_PRINT_DEBUG("thread #%d compute-done cplan %p\n", state->ith, (const void *)cplan); +#else + GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph); +#endif ggml_barrier(state->threadpool); From 4ac70ce791baedf27cd14f36313b8056a0fe45a8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Feb 2026 12:57:36 +0200 Subject: [PATCH 152/831] models : optimize qwen3next graph (llama/19375) * models : optimizing qwen3next graph * cont * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * cont : remove redundant q, g chunking * minor * minor * avoid passing masks around * avoid concats during chunking * naming + shapes * update names and use prefix to disable CUDA graphs --- ggml/src/ggml-cuda/ggml-cuda.cu | 6 +++++- ggml/src/ggml-metal/ggml-metal-common.cpp | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 85ce96958fa..bed5c71a1bd 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2872,6 +2872,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased"; const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out"; const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d"; + const std::string delta_net_prefix = "dnet_add"; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2902,7 +2903,8 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 && strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 && - strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) { + strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0 && + strncmp(node->name, delta_net_prefix.c_str(), delta_net_prefix.size()) != 0) { // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation // by means of matching node names. See // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and @@ -4544,6 +4546,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: case GGML_UNARY_OP_TRUNC: + // TODO: should become: + //return ggml_is_contiguous_rows(op->src[0]); return ggml_is_contiguous(op->src[0]); default: return false; diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp index 87e13786849..2eb9820bff9 100644 --- a/ggml/src/ggml-metal/ggml-metal-common.cpp +++ b/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -273,6 +273,7 @@ static std::vector ggml_metal_graph_optimize_reorder(const std::vector Date: Sun, 15 Feb 2026 19:42:09 +0200 Subject: [PATCH 153/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index ff14b73caa8..8db0963de78 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -5cecdad692d868e28dbd2f7c468504770108f30c +68fee723b1f0c2432258b77710f3ca973b3bc5cc From 364c77f4ca2737e3287652e0e8a8c6dce3231bba Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 15 Feb 2026 19:43:28 +0200 Subject: [PATCH 154/831] talk-llama : sync llama.cpp --- examples/talk-llama/llama-arch.cpp | 117 +++- examples/talk-llama/llama-arch.h | 14 +- examples/talk-llama/llama-context.cpp | 311 ++++----- examples/talk-llama/llama-context.h | 36 +- examples/talk-llama/llama-hparams.h | 6 +- examples/talk-llama/llama-impl.h | 10 + examples/talk-llama/llama-mmap.cpp | 20 +- examples/talk-llama/llama-model.cpp | 369 +++++++++- examples/talk-llama/llama-model.h | 12 + examples/talk-llama/llama-vocab.cpp | 11 + examples/talk-llama/llama-vocab.h | 1 + examples/talk-llama/llama.h | 27 +- examples/talk-llama/models/deepseek2.cpp | 5 +- examples/talk-llama/models/kimi-linear.cpp | 7 +- examples/talk-llama/models/models.h | 113 +++ examples/talk-llama/models/qwen35.cpp | 740 ++++++++++++++++++++ examples/talk-llama/models/qwen35moe.cpp | 774 +++++++++++++++++++++ examples/talk-llama/models/qwen3next.cpp | 548 +++++++-------- examples/talk-llama/unicode.cpp | 31 +- 19 files changed, 2597 insertions(+), 555 deletions(-) create mode 100644 examples/talk-llama/models/qwen35.cpp create mode 100644 examples/talk-llama/models/qwen35moe.cpp diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index bd78f1e5562..416c17463ee 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -37,6 +37,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN3NEXT, "qwen3next" }, { LLM_ARCH_QWEN3VL, "qwen3vl" }, { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, + { LLM_ARCH_QWEN35, "qwen35" }, + { LLM_ARCH_QWEN35MOE, "qwen35moe" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, @@ -72,6 +74,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, + { LLM_ARCH_GLM_DSA, "glm-dsa" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, @@ -195,6 +198,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" }, + { LLM_KV_FULL_ATTENTION_INTERVAL, "%s.full_attention_interval" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -222,6 +226,9 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" }, + { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, + { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, @@ -366,6 +373,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, + { LLM_TENSOR_SSM_ALPHA, "blk.%d.ssm_alpha" }, { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, @@ -512,6 +520,10 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" }, { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" }, { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" }, + { LLM_TENSOR_INDEXER_K_NORM, "blk.%d.indexer.k_norm" }, + { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" }, + { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" }, + { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, }; static std::set llm_get_tensor_names(llm_arch arch) { @@ -968,7 +980,6 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_ATTN_OUT, LLM_TENSOR_ATTN_QKV, LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, @@ -985,6 +996,63 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, }; + case LLM_ARCH_QWEN35: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_SSM_A_NOSCAN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_ALPHA, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + }; + case LLM_ARCH_QWEN35MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_SSM_A_NOSCAN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_ALPHA, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + }; case LLM_ARCH_QWEN3VL: case LLM_ARCH_CHAMELEON: case LLM_ARCH_HUNYUAN_DENSE: @@ -1597,6 +1665,46 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; + case LLM_ARCH_GLM_DSA: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_INDEXER_K_NORM, + LLM_TENSOR_INDEXER_PROJ, + LLM_TENSOR_INDEXER_ATTN_K, + LLM_TENSOR_INDEXER_ATTN_Q_B, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + }; case LLM_ARCH_BITNET: return { LLM_TENSOR_TOKEN_EMBD, @@ -2456,6 +2564,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2582,6 +2691,10 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_VISEXP_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_VISEXP_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_VISEXP_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are currently ignored (reserved for future MTP support) // These tensors only exist in the last layer(s) and are treated as output tensors {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, @@ -2675,6 +2788,8 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_NEMOTRON_H_MOE: case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_KIMI_LINEAR: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: return true; default: return false; diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index e8263369b80..521944370b4 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -41,6 +41,8 @@ enum llm_arch { LLM_ARCH_QWEN3NEXT, LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE, + LLM_ARCH_QWEN35, + LLM_ARCH_QWEN35MOE, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, @@ -76,6 +78,7 @@ enum llm_arch { LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, + LLM_ARCH_GLM_DSA, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, @@ -199,6 +202,7 @@ enum llm_kv { LLM_KV_EMBEDDING_SCALE, LLM_KV_TOKEN_SHIFT_COUNT, LLM_KV_INTERLEAVE_MOE_LAYER_STEP, + LLM_KV_FULL_ATTENTION_INTERVAL, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -226,6 +230,9 @@ enum llm_kv { LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, + LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, + LLM_KV_ATTENTION_INDEXER_TOP_K, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, @@ -404,13 +411,14 @@ enum llm_tensor { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next + LLM_TENSOR_SSM_ALPHA, // qwen3.5 // Kimi Linear KDA (using SSM_ prefix for consistency) LLM_TENSOR_SSM_CONV1D_Q, // kimi: Q conv1d weight LLM_TENSOR_SSM_CONV1D_K, // kimi: K conv1d weight LLM_TENSOR_SSM_CONV1D_V, // kimi: V conv1d weight LLM_TENSOR_SSM_F_A, // kimi: forget gate projection A LLM_TENSOR_SSM_F_B, // kimi: forget gate projection B - LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient + LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient and qwen3.5 LLM_TENSOR_SSM_G_A, // kimi: output gate projection A LLM_TENSOR_SSM_G_B, // kimi: output gate projection B LLM_TENSOR_TIME_MIX_W0, @@ -513,6 +521,10 @@ enum llm_tensor { LLM_TENSOR_VISEXP_FFN_GATE, LLM_TENSOR_VISEXP_FFN_DOWN, LLM_TENSOR_VISEXP_FFN_UP, + LLM_TENSOR_INDEXER_K_NORM, + LLM_TENSOR_INDEXER_PROJ, + LLM_TENSOR_INDEXER_ATTN_K, + LLM_TENSOR_INDEXER_ATTN_Q_B, LLM_TENSOR_NEXTN_EH_PROJ, LLM_TENSOR_NEXTN_EMBED_TOKENS, LLM_TENSOR_NEXTN_ENORM, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index a6df893a311..99035b6cace 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -677,7 +677,7 @@ enum llama_pooling_type llama_context::pooling_type() const { float * llama_context::get_logits() { output_reorder(); - return logits; + return logits.data; } int64_t llama_context::output_resolve_row(int32_t i) const { @@ -715,7 +715,7 @@ float * llama_context::get_logits_ith(int32_t i) { output_reorder(); try { - if (logits == nullptr) { + if (logits.data == nullptr) { throw std::runtime_error("no logits"); } @@ -739,7 +739,7 @@ float * llama_context::get_logits_ith(int32_t i) { throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); } - return logits + j*model.vocab.n_tokens(); + return logits.data + j*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -753,11 +753,11 @@ float * llama_context::get_logits_ith(int32_t i) { float * llama_context::get_embeddings() { output_reorder(); - return embd; + return embd.data; } llama_token * llama_context::get_sampled_tokens() const{ - return sampling.sampled; + return sampling.sampled.data; } float * llama_context::get_embeddings_ith(int32_t i) { @@ -766,7 +766,7 @@ float * llama_context::get_embeddings_ith(int32_t i) { output_reorder(); try { - if (embd == nullptr) { + if (embd.data == nullptr) { throw std::runtime_error("no embeddings"); } @@ -791,7 +791,7 @@ float * llama_context::get_embeddings_ith(int32_t i) { } const uint32_t n_embd_out = model.hparams.n_embd_out(); - return embd + j*n_embd_out; + return embd.data + j*n_embd_out; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -814,14 +814,14 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); - if (sampling.sampled == nullptr) { + if (!sampling.sampled.has_data()) { return LLAMA_TOKEN_NULL; } try { const int64_t row = output_resolve_row(idx); - GGML_ASSERT(row < (int64_t) sampling.sampled_size); - return sampling.sampled[row]; + GGML_ASSERT(row < (int64_t) sampling.sampled.size); + return sampling.sampled.data[row]; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what()); return LLAMA_TOKEN_NULL; @@ -831,7 +831,7 @@ llama_token llama_context::get_sampled_token_ith(int32_t idx) { float * llama_context::get_sampled_probs_ith(int32_t idx) { output_reorder(); - if (sampling.probs == nullptr) { + if (!sampling.probs.has_data()) { return nullptr; } @@ -840,7 +840,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) { if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) { return nullptr; } - return sampling.probs + row*model.vocab.n_tokens(); + return sampling.probs.data + row*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what()); return nullptr; @@ -850,7 +850,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) { float * llama_context::get_sampled_logits_ith(int32_t idx) { output_reorder(); - if (sampling.logits == nullptr) { + if (!sampling.logits.has_data()) { return nullptr; } @@ -859,7 +859,7 @@ float * llama_context::get_sampled_logits_ith(int32_t idx) { if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) { return nullptr; } - return sampling.logits + row*model.vocab.n_tokens(); + return sampling.logits.data + row*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what()); return nullptr; @@ -871,13 +871,14 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { try { const int64_t row = output_resolve_row(idx); - if (sampling.candidates != nullptr && + if (sampling.candidates.has_data() && (size_t) row < sampling.candidates_count.size() && sampling.candidates_count[row] > 0) { - return sampling.candidates + row*model.vocab.n_tokens(); + return sampling.candidates.data + row*model.vocab.n_tokens(); } } catch (const std::exception & err) { // fallback to full vocab list + GGML_UNUSED(err); } return sampling.token_ids_full_vocab.data(); @@ -886,7 +887,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { size_t llama_context::get_sampled_candidates_count(int32_t idx) { output_reorder(); - if (sampling.candidates == nullptr) { + if (!sampling.candidates.has_data()) { return 0; } @@ -905,7 +906,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) { size_t llama_context::get_sampled_logits_count(int32_t idx) { output_reorder(); - if (sampling.logits == nullptr) { + if (!sampling.logits.has_data()) { return model.vocab.n_tokens(); } @@ -924,7 +925,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) { size_t llama_context::get_sampled_probs_count(int32_t idx) { output_reorder(); - if (sampling.probs == nullptr) { + if (!sampling.probs.has_data()) { return 0; } @@ -1057,51 +1058,43 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { return true; } -void llama_context::set_adapter_lora( - llama_adapter_lora * adapter, - float scale) { - LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); +void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { + LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); - if (auto it = loras.find(adapter); it != loras.end()) { - if (it->second == scale) { - return; - } + if (adapters_lora_are_same(adapters, n_adapters, scales)) { + return; } - loras[adapter] = scale; + loras.clear(); + + for (size_t i = 0; i < n_adapters; i ++) { + if (scales[i] != 0.0f) { + loras[adapters[i]] = scales[i]; + } + } sched_need_reserve = true; } -bool llama_context::rm_adapter_lora( - llama_adapter_lora * adapter) { - LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); - - auto it = loras.find(adapter); - if (it != loras.end()) { - loras.erase(it); - - sched_need_reserve = true; +bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { + LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); - return true; + if (n_adapters != loras.size()) { + return false; } - return false; -} - -void llama_context::clear_adapter_lora() { - LLAMA_LOG_DEBUG("%s: call\n", __func__); + for (size_t i = 0; i < n_adapters; i ++) { + auto it = loras.find(adapters[i]); - if (loras.empty()) { - return; + if (it == loras.end() || it->second != scales[i]) { + return false; + } } - loras.clear(); - - sched_need_reserve = true; + return true; } -bool llama_context::apply_adapter_cvec( +bool llama_context::set_adapter_cvec( const float * data, size_t len, int32_t n_embd, @@ -1254,16 +1247,16 @@ int llama_context::encode(const llama_batch & batch_inp) { auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); // extract logits - if (logits && t_logits) { + if (logits.data && t_logits) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + GGML_ASSERT(logits.data != nullptr); - ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); + ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float)); } // extract embeddings - if (embd && t_embd) { + if (embd.data && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -1271,11 +1264,11 @@ int llama_context::encode(const llama_batch & batch_inp) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings - GGML_ASSERT(embd != nullptr); + GGML_ASSERT(embd.data != nullptr); const uint32_t n_embd_out = hparams.n_embd_out(); - GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float)); + GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float)); } break; case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: @@ -1323,7 +1316,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cross.n_embd = t_embd->ne[0]; cross.n_enc = t_embd->ne[1]; cross.v_embd.resize(cross.n_embd*cross.n_enc); - memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd)); + memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd)); const auto & batch = balloc->get_batch(); @@ -1363,11 +1356,10 @@ static std::map build_seq_to_output_row(const llama_ubat static void copy_tensor_async_ints( const std::map & tensor_map, - llama_token * sampled, - size_t sampled_size, + const buffer_view & sampled, const std::map & seq_to_row, ggml_backend_sched_t sched) { - if (sampled == nullptr) { + if (!sampled.has_data()) { return; } @@ -1378,23 +1370,23 @@ static void copy_tensor_async_ints( } const uint32_t row = it->second; - GGML_ASSERT(row < sampled_size); + GGML_ASSERT(row < sampled.size); GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row])); + ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row])); } } static void copy_tensor_async_floats( const std::map & tensor_map, - float * dst, + const buffer_view & dst, size_t stride, std::vector & counts, const std::map & seq_to_row, ggml_backend_sched_t sched) { - if (dst == nullptr) { + if (!dst.has_data()) { return; } @@ -1410,7 +1402,7 @@ static void copy_tensor_async_floats( GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - float * row_ptr = dst + (size_t) row * stride; + float * row_ptr = dst.data + (size_t) row * stride; ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); // Update the actual number of logits/probabilities that were written for this row. @@ -1420,12 +1412,12 @@ static void copy_tensor_async_floats( static void copy_tensor_async_candidates( const std::map & tensor_map, - llama_token * dst, + const buffer_view & dst, size_t stride, std::vector & counts, const std::map & seq_to_row, ggml_backend_sched_t sched) { - if (dst == nullptr) { + if (!dst.has_data()) { return; } @@ -1441,7 +1433,7 @@ static void copy_tensor_async_candidates( GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - llama_token * row_ptr = dst + (size_t) row * stride; + llama_token * row_ptr = dst.data + (size_t) row * stride; ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); // Update the actual number of candidates that were written. @@ -1671,22 +1663,22 @@ int llama_context::decode(const llama_batch & batch_inp) { } // extract logits - if (logits && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { + if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + GGML_ASSERT(logits.data != nullptr); - float * logits_out = logits + n_outputs_prev*n_vocab; + float * logits_out = logits.data + n_outputs_prev*n_vocab; if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size); ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); } } // extract embeddings - if (embd && t_embd && n_outputs > 0) { + if (embd.data && t_embd && n_outputs > 0) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -1694,13 +1686,13 @@ int llama_context::decode(const llama_batch & batch_inp) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings - GGML_ASSERT(embd != nullptr); + GGML_ASSERT(embd.data != nullptr); const uint32_t n_embd_out = hparams.n_embd_out(); - float * embd_out = embd + n_outputs_prev*n_embd_out; + float * embd_out = embd.data + n_outputs_prev*n_embd_out; if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size); ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float)); } } break; @@ -1747,7 +1739,7 @@ int llama_context::decode(const llama_batch & batch_inp) { const auto stride = n_vocab; // async copy the sampling data from the backend to the host - copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get()); + copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get()); copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get()); copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get()); @@ -1818,7 +1810,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // uint32_t llama_context::output_reserve(int32_t n_outputs) { - const auto & hparams = model.hparams; const auto & vocab = model.vocab; @@ -1841,19 +1832,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { size_t backend_float_count = 0; size_t backend_token_count = 0; - logits_size = has_logits ? n_vocab*n_outputs_max : 0; - embd_size = has_embd ? n_embd_out*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); if (has_sampling) { - sampling.logits_size = n_vocab*n_outputs_max; - sampling.probs_size = n_vocab*n_outputs_max; - sampling.sampled_size = n_outputs_max; - sampling.candidates_size = n_vocab*n_outputs_max; - - backend_float_count = sampling.logits_size + sampling.probs_size; - backend_token_count = sampling.sampled_size + sampling.candidates_size; + backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs + backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates } if (output_ids.empty()) { @@ -1863,7 +1849,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits_size + embd_size + backend_float_count) * sizeof(float) + + (logits.size + embd.size + backend_float_count) * sizeof(float) + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required @@ -1878,8 +1864,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { // TODO: not needed? buf_output = nullptr; - logits = nullptr; - embd = nullptr; + logits.data = nullptr; + embd.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1898,35 +1884,27 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); - logits = nullptr; - embd = nullptr; - size_t offset = 0; uint8_t * base = (uint8_t *) output_base; - logits = has_logits ? output_base : nullptr; - offset += logits_size * sizeof(float); - - embd = has_embd ? (float *) (base + offset) : nullptr; - offset += embd_size * sizeof(float); + logits = has_logits ? buffer_view{output_base, logits.size} : buffer_view{nullptr, 0}; + offset += logits.size * sizeof(float); - sampling.logits = nullptr; - sampling.probs = nullptr; - sampling.sampled = nullptr; - sampling.candidates = nullptr; + embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; + offset += embd.size * sizeof(float); if (has_sampling) { - sampling.logits = (float *) (base + offset); - offset += sampling.logits_size * sizeof(float); + sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.logits.size * sizeof(float); - sampling.probs = (float *) (base + offset); - offset += sampling.probs_size * sizeof(float); + sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.probs.size * sizeof(float); - sampling.sampled = (llama_token *) (base + offset); - offset += sampling.sampled_size * sizeof(llama_token); + sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max}; + offset += sampling.sampled.size * sizeof(llama_token); - sampling.candidates = (llama_token *) (base + offset); - offset += sampling.candidates_size * sizeof(llama_token); + sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.candidates.size * sizeof(llama_token); // The count vectors keep track of the actual number of logits/probs/candidates // copied from the backend for each output row. @@ -1939,7 +1917,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); - std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL); + std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL); + } else { + sampling.logits = {nullptr, 0}; + sampling.probs = {nullptr, 0}; + sampling.sampled = {nullptr, 0}; + sampling.candidates = {nullptr, 0}; + + sampling.logits_count.clear(); + sampling.probs_count.clear(); + sampling.candidates_count.clear(); } // set all ids as invalid (negative) @@ -1958,49 +1945,42 @@ void llama_context::output_reorder() { const uint64_t i0 = output_swaps[s].i0; const uint64_t i1 = output_swaps[s].i1; - if (logits_size > 0) { + if (logits.size > 0) { for (uint64_t k = 0; k < n_vocab; k++) { - std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); + std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]); } } - if (embd_size > 0) { + if (embd.size > 0) { for (uint64_t k = 0; k < n_embd; k++) { - std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); + std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]); } } - if (sampling.logits && sampling.logits_size > 0) { + if (!sampling.samplers.empty()) { + assert(sampling.logits.size > 0); + assert(sampling.probs.size > 0); + assert(sampling.candidates.size > 0); + assert(sampling.sampled.size > 0); + assert(sampling.logits_count.size() > 0); + assert(sampling.probs_count.size() > 0); + assert(sampling.candidates_count.size() > 0); + for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]); + std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]); } - } - if (sampling.probs && sampling.probs_size > 0) { for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]); + std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]); } - } - if (sampling.candidates && sampling.candidates_size > 0) { for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]); + std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]); } - } - - if (sampling.sampled && sampling.sampled_size > 0) { - std::swap(sampling.sampled[i0], sampling.sampled[i1]); - } - - if (!sampling.logits_count.empty()) { - std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); - } - if (!sampling.probs_count.empty()) { - std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); - } - - if (!sampling.candidates_count.empty()) { + std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]); + std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); + std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]); } } @@ -2013,7 +1993,7 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { - if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR) { + if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) { return std::max(n_tokens * 40, 32u * model.n_tensors()); } uint32_t res = std::max(1024u, 8u*model.n_tensors()); @@ -2533,12 +2513,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { { LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); - const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens()); + const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens()); io.write(&logits_size, sizeof(logits_size)); if (logits_size) { - io.write(logits, logits_size * sizeof(float)); + io.write(logits.data, logits_size * sizeof(float)); } } @@ -2546,12 +2526,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { { LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__); - const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd); + const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd); io.write(&embd_size, sizeof(embd_size)); if (embd_size) { - io.write(embd, embd_size * sizeof(float)); + io.write(embd.data, embd_size * sizeof(float)); } } @@ -2619,12 +2599,12 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { uint64_t logits_size; io.read_to(&logits_size, sizeof(logits_size)); - if (this->logits_size < logits_size) { + if (this->logits.size < logits_size) { throw std::runtime_error("logits buffer too small"); } if (logits_size) { - io.read_to(this->logits, logits_size * sizeof(float)); + io.read_to(this->logits.data, logits_size * sizeof(float)); } } @@ -2635,12 +2615,12 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { uint64_t embd_size; io.read_to(&embd_size, sizeof(embd_size)); - if (this->embd_size < embd_size) { + if (this->embd.size < embd_size) { throw std::runtime_error("embeddings buffer too small"); } if (embd_size) { - io.read_to(this->embd, embd_size * sizeof(float)); + io.read_to(this->embd.data, embd_size * sizeof(float)); } } @@ -3218,35 +3198,28 @@ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) { // llama adapter API -int32_t llama_set_adapter_lora( +int32_t llama_set_adapters_lora( llama_context * ctx, - llama_adapter_lora * adapter, - float scale) { - ctx->set_adapter_lora(adapter, scale); - - return 0; -} - -int32_t llama_rm_adapter_lora( - llama_context * ctx, - llama_adapter_lora * adapter) { - bool res = ctx->rm_adapter_lora(adapter); + llama_adapter_lora ** adapters, + size_t n_adapters, + float * scales) { + if (adapters == nullptr || scales == nullptr) { + GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call"); + } - return res ? 0 : -1; -} + ctx->set_adapters_lora(adapters, n_adapters, scales); -void llama_clear_adapter_lora(llama_context * ctx) { - ctx->clear_adapter_lora(); + return 0; } -int32_t llama_apply_adapter_cvec( +int32_t llama_set_adapter_cvec( llama_context * ctx, - const float * data, - size_t len, - int32_t n_embd, - int32_t il_start, - int32_t il_end) { - bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end); + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) { + bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end); return res ? 0 : -1; } diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 8e71cdd1dc5..a8e53f335cc 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -4,6 +4,7 @@ #include "llama-cparams.h" #include "llama-graph.h" #include "llama-adapter.h" +#include "llama-impl.h" #include "ggml-cpp.h" #include "ggml-opt.h" @@ -104,16 +105,11 @@ struct llama_context { void set_causal_attn(bool value); void set_warmup(bool value); - void set_adapter_lora( - llama_adapter_lora * adapter, - float scale); + void set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); - bool rm_adapter_lora( - llama_adapter_lora * adapter); + bool adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); - void clear_adapter_lora(); - - bool apply_adapter_cvec( + bool set_adapter_cvec( const float * data, size_t len, int32_t n_embd, @@ -269,34 +265,26 @@ struct llama_context { std::unique_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) - size_t logits_size = 0; // capacity (of floats) for logits - float * logits = nullptr; + buffer_view logits = {nullptr, 0}; // embeddings output (2-dimensional array: [n_outputs][n_embd]) // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE - size_t embd_size = 0; // capacity (of floats) for embeddings - float * embd = nullptr; + buffer_view embd = {nullptr, 0}; - // TODO: simplify struct sampling_info { + // !samplers.empty() to check if any samplers are active std::map samplers; - float * logits = nullptr; - size_t logits_size = 0; - - llama_token * sampled = nullptr; - size_t sampled_size = 0; - - float * probs = nullptr; - size_t probs_size = 0; - - llama_token * candidates = nullptr; - size_t candidates_size = 0; + buffer_view logits = {nullptr, 0}; + buffer_view sampled = {nullptr, 0}; + buffer_view probs = {nullptr, 0}; + buffer_view candidates = {nullptr, 0}; std::vector logits_count; std::vector probs_count; std::vector candidates_count; + // optimization std::vector token_ids_full_vocab; }; diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 6c695bdbf66..c4b2a99da5a 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -42,7 +42,6 @@ struct llama_hparams { uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; - uint32_t n_embd_features = 0; uint32_t n_layer; int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache uint32_t n_rot; @@ -194,6 +193,11 @@ struct llama_hparams { std::array xielu_beta; std::array xielu_eps; + // DSA (deepseek sparse attention) + uint32_t indexer_n_head = 0; + uint32_t indexer_head_size = 0; + uint32_t indexer_top_k = 0; + // qwen3vl deepstack uint32_t n_deepstack_layers = 0; diff --git a/examples/talk-llama/llama-impl.h b/examples/talk-llama/llama-impl.h index c3391e79f51..dfd9fee9f44 100644 --- a/examples/talk-llama/llama-impl.h +++ b/examples/talk-llama/llama-impl.h @@ -49,6 +49,16 @@ struct time_meas { int64_t & t_acc; }; +template +struct buffer_view { + T * data; + size_t size = 0; + + bool has_data() const { + return data && size > 0; + } +}; + void replace_all(std::string & s, const std::string & search, const std::string & replace); // TODO: rename to llama_format ? diff --git a/examples/talk-llama/llama-mmap.cpp b/examples/talk-llama/llama-mmap.cpp index 0261e4c72c9..c03228e9ce2 100644 --- a/examples/talk-llama/llama-mmap.cpp +++ b/examples/talk-llama/llama-mmap.cpp @@ -504,6 +504,8 @@ struct llama_mmap::impl { } } #elif defined(_WIN32) + HANDLE hMapping = nullptr; + impl(struct llama_file * file, size_t prefetch, bool numa) { GGML_UNUSED(numa); @@ -511,7 +513,7 @@ struct llama_mmap::impl { HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id()); - HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); + hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); if (hMapping == NULL) { DWORD error = GetLastError(); @@ -520,9 +522,9 @@ struct llama_mmap::impl { addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); DWORD error = GetLastError(); - CloseHandle(hMapping); if (addr == NULL) { + CloseHandle(hMapping); throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); } @@ -554,9 +556,17 @@ struct llama_mmap::impl { } ~impl() { - if (!UnmapViewOfFile(addr)) { - LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); + if (hMapping) { + if (addr) { + if (!UnmapViewOfFile(addr)) { + LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } + if (!CloseHandle(hMapping)) { + LLAMA_LOG_WARN("warning: CloseHandle failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } } } #else diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 674d06c8910..c26584aa67f 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -125,6 +125,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; + case LLM_TYPE_35B_A3B: return "35B.A3B"; case LLM_TYPE_48B_A3B: return "48B.A3B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; @@ -136,6 +137,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_300B_A47B: return "300B.A47B"; case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; + case LLM_TYPE_744B_A40B: return "744B.A40B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; default: return "?B"; @@ -522,7 +524,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false); if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { - ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); + ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl); ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); ml.get_key(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer); @@ -1820,6 +1823,50 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM_DSA: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + + // DSA parameters + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 79: type = LLM_TYPE_744B_A40B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -2403,8 +2450,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // Mark recurrent layers (linear attention layers) - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval" + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } } switch (hparams.n_layer) { @@ -2412,6 +2463,62 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_QWEN35: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN35MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_35B_A3B; break; + case 48: type = LLM_TYPE_80B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_MISTRAL3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5430,6 +5537,108 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; + case LLM_ARCH_GLM_DSA: + { + const bool is_mla = hparams.is_mla(); + if (!is_mla) { + throw std::runtime_error("GLM_DSA architecture requires MLA"); + } + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later + flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // DSA indexer + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5985,9 +6194,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_WAVTOKENIZER_DEC: { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0); - conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0); + conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0); // posnet @@ -6083,8 +6292,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); } - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, hparams.n_embd_out()}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {hparams.n_embd_out()}, 0); } break; case LLM_ARCH_BAILINGMOE: { @@ -7101,6 +7310,131 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0); } } break; + case LLM_ARCH_QWEN35MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + + // Shared experts + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + } + } break; + case LLM_ARCH_QWEN35: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; case LLM_ARCH_MIMO2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7545,6 +7879,8 @@ void llama_model::print_info() const { arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID || arch == LLM_ARCH_QWEN3NEXT || + arch == LLM_ARCH_QWEN35 || + arch == LLM_ARCH_QWEN35MOE || arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); @@ -7576,7 +7912,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); } - if (arch == LLM_ARCH_DEEPSEEK2) { + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA) { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); @@ -7776,7 +8112,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, nullptr); } else if (llm_arch_is_hybrid(arch)) { - // The main difference between hybrid architectures is the // layer filters, so pick the right one here llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; @@ -7801,7 +8136,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_type_v */ params.type_v, /* attn_v_trans */ !cparams.flash_attn, /* attn_swa_full */ params.swa_full, - /* attn_kv_size */ cparams.n_ctx, + /* attn_kv_size */ cparams.n_ctx_seq, /* attn_n_ubatch */ cparams.n_ubatch, /* attn_n_pad */ 1, /* recurrent_type_r */ GGML_TYPE_F32, @@ -7818,7 +8153,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_type_k */ params.type_k, /* attn_type_v */ params.type_v, /* attn_v_trans */ !cparams.flash_attn, - /* attn_kv_size */ cparams.n_ctx, + /* attn_kv_size */ cparams.n_ctx_seq, /* attn_n_pad */ 1, /* attn_n_swa */ hparams.n_swa, /* attn_swa_type */ hparams.swa_type, @@ -8149,6 +8484,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_GLM_DSA: { llm = std::make_unique(*this, params); } break; @@ -8343,6 +8679,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_QWEN35: + { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_QWEN35MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_MISTRAL3: { llm = std::make_unique(*this, params); @@ -8542,6 +8886,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MISTRAL3: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: + case LLM_ARCH_GLM_DSA: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -8611,6 +8956,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { return LLAMA_ROPE_TYPE_MROPE; case LLM_ARCH_QWEN3VL: case LLM_ARCH_QWEN3VLMOE: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: return LLAMA_ROPE_TYPE_IMROPE; case LLM_ARCH_GLM4: diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 7b580043b33..b3505914293 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -118,6 +118,7 @@ enum llm_type { LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, + LLM_TYPE_35B_A3B, // Qwen3.5 LLM_TYPE_48B_A3B, // Kimi Linear LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, @@ -129,6 +130,7 @@ enum llm_type { LLM_TYPE_300B_A47B, // Ernie MoE big LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 + LLM_TYPE_744B_A40B, // GLM-5 LLM_TYPE_E2B, LLM_TYPE_E4B, }; @@ -322,6 +324,9 @@ struct llama_layer { // qwen3next struct ggml_tensor * ssm_beta_alpha = nullptr; + // qwen3.5 + struct ggml_tensor * ssm_alpha = nullptr; + // rwkv struct ggml_tensor * time_mix_w1 = nullptr; struct ggml_tensor * time_mix_w2 = nullptr; @@ -425,6 +430,13 @@ struct llama_layer { struct ggml_tensor * ssm_g_b = nullptr; struct ggml_tensor * ssm_o_norm = nullptr; + // DSA (deepseek sparse attention) + struct ggml_tensor * indexer_k_norm = nullptr; + struct ggml_tensor * indexer_k_norm_b = nullptr; + struct ggml_tensor * indexer_proj = nullptr; + struct ggml_tensor * indexer_attn_k = nullptr; + struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 6d6bdfa090c..62e137fb842 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -368,6 +368,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_QWEN35: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_PORO: case LLAMA_VOCAB_PRE_TYPE_BLOOM: case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH: @@ -1926,6 +1933,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "kormo") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; + } else if ( + tokenizer_pre == "qwen35") { + pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN35; + clean_spaces = false; } else if ( tokenizer_pre == "stablelm2") { pre_type = LLAMA_VOCAB_PRE_TYPE_STABLELM2; diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 28c3a82b91e..718238fb866 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -54,6 +54,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, + LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, }; struct LLM_KV; diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index bf4e28a8be1..d2d7f59ebc6 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -482,7 +482,7 @@ extern "C" { enum llama_params_fit_status { LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit - LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occured, e.g. because no model could be found at the specified path + LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occurred, e.g. because no model could be found at the specified path }; // fits mparams and cparams to free device memory (assumes system memory is unlimited) @@ -656,21 +656,12 @@ extern "C" { // The following functions operate on a llama_context, hence the naming: llama_verb_... - // Add a loaded LoRA adapter to given context - // This will not modify model's weight - LLAMA_API int32_t llama_set_adapter_lora( + // Set LoRa adapters on the context. Will only modify if the adapters currently in context are different. + LLAMA_API int32_t llama_set_adapters_lora( struct llama_context * ctx, - struct llama_adapter_lora * adapter, - float scale); - - // Remove a specific LoRA adapter from given context - // Return -1 if the adapter is not present in the context - LLAMA_API int32_t llama_rm_adapter_lora( - struct llama_context * ctx, - struct llama_adapter_lora * adapter); - - // Remove all LoRA adapters from given context - LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx); + struct llama_adapter_lora ** adapters, + size_t n_adapters, + float * scales); // Apply a loaded control vector to a llama_context, or if data is NULL, clear // the currently loaded vector. @@ -678,7 +669,7 @@ extern "C" { // to an n_embd x n_layers buffer starting from layer 1. // il_start and il_end are the layer range the vector should apply to (both inclusive) // See llama_control_vector_load in common to load a control vector. - LLAMA_API int32_t llama_apply_adapter_cvec( + LLAMA_API int32_t llama_set_adapter_cvec( struct llama_context * ctx, const float * data, size_t len, @@ -1150,9 +1141,9 @@ extern "C" { // /// Apply chat template. Inspired by hf apply_chat_template() on python. - /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" + /// /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template - /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. + /// @param tmpl A Jinja template to use for this chat. /// @param chat Pointer to a list of multiple llama_chat_message /// @param n_msg Number of llama_chat_message in this chat /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp index 987f449934c..b2c1f160601 100644 --- a/examples/talk-llama/models/deepseek2.cpp +++ b/examples/talk-llama/models/deepseek2.cpp @@ -45,7 +45,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + int effective_n_layers = hparams.n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < effective_n_layers; ++il) { ggml_tensor * inpSA = inpL; // norm @@ -188,7 +189,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } } - if (il == n_layer - 1 && inp_out_ids) { + if (il == effective_n_layers - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp index 0f037d1a393..942844d071f 100644 --- a/examples/talk-llama/models/kimi-linear.cpp +++ b/examples/talk-llama/models/kimi-linear.cpp @@ -41,8 +41,11 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]); ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_x, - ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs, - (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all)))); + ggml_view_3d(ctx0, conv_states_all, + d_conv - 1, d_inner, n_seqs, + (d_conv - 1) * ggml_element_size(conv_states_all), // nb1: contiguous within one channel's conv taps + n_embd_r_total * ggml_element_size(conv_states_all), // nb2: stride between sequences (skip over K,V states) + (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all)))); // offset to first seq's Q/K/V state // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner] // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv] // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step] diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index cfcbb9aaa5b..ec6f80e5265 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -476,6 +476,7 @@ struct llm_build_qwen3vl : public llm_graph_context { struct llm_build_qwen3vlmoe : public llm_graph_context { llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params); }; + struct llm_build_qwen3next : public llm_graph_context_mamba { llm_build_qwen3next(const llama_model & model, const llm_graph_params & params); private: @@ -485,6 +486,118 @@ struct llm_build_qwen3next : public llm_graph_context_mamba { ggml_tensor * inp_pos, int il); + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + // returns pair of output and new state + std::pair build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il); + + // returns pair of output and new state + std::pair build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair build_qkvz( + ggml_tensor * input, + int il); + + const llama_model & model; +}; + +struct llm_build_qwen35 : public llm_graph_context_mamba { + llm_build_qwen35(const llama_model & model, const llm_graph_params & params); +private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il); + + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + // returns pair of output and new state + std::pair build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il); + + // returns pair of output and new state + std::pair build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair build_qkvz( + ggml_tensor * input, + int il); + + const llama_model & model; +}; + +struct llm_build_qwen35moe : public llm_graph_context_mamba { + llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params); +private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il); + ggml_tensor * build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp new file mode 100644 index 00000000000..592c170457b --- /dev/null +++ b/examples/talk-llama/models/qwen35.cpp @@ -0,0 +1,740 @@ +#include "ggml.h" +#include "models.h" + +#define CHUNK_SIZE 64 + +llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) : + llm_graph_context_mamba(params), model(model) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + cb(inpL, "model.input_embed", -1); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * causal_mask = + ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), + GGML_TRI_TYPE_LOWER); + + ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); + ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); + + ggml_build_forward_expand(gf, causal_mask); + ggml_build_forward_expand(gf, identity); + ggml_build_forward_expand(gf, diag_mask); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // Determine layer type and build appropriate attention mechanism + if (hparams.is_recurrent(il)) { + // Linear attention layer (gated delta net) + cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); + } else { + // Full attention layer + cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Residual connection + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + // Save the tensor before post-attention norm for residual connection + ggml_tensor * ffn_residual = cur; + + // Post-attention norm + ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(attn_post_norm, "attn_post_norm", il); + + // Dense FFN layer - without residual connection + cur = build_layer_ffn(attn_post_norm, il); + cb(cur, "ffn_out", il); + + // Residual connection for FFN - add to the tensor from before post_attention_layernorm + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "post_ffn", il); + + // Input for next layer + inpL = cur; + } + cur = inpL; + + // Final norm + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // LM head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +// utility to get one slice from the third dimension +// input dim: [x, y, c, b] +// output dim: [x, y, 1, b] +static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { + return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); +} + +std::pair llm_build_qwen35::build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + const float eps_norm = hparams.f_norm_rms_eps; + + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + + beta = ggml_sigmoid(ctx0, beta); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); + + beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); + state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + + cb(q, "q_perm", il); + cb(k, "k_perm", il); + cb(v, "v_perm", il); + cb(beta, "beta_perm", il); + cb(g, "g_perm", il); + cb(state, "state_in", il); + + GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); + + // Do padding + const int64_t chunk_size = CHUNK_SIZE; + + const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; + const int64_t n_chunks = (n_tokens + pad) / chunk_size; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, pad, 0, 0, 0); + beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); + + cb(q, "q_pad", il); + cb(k, "k_pad", il); + cb(v, "v_pad", il); + cb(beta, "beta_pad", il); + cb(g, "g_pad", il); + + ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); + ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + + cb(v_beta, "v_beta", il); + cb(k_beta, "k_beta", il); + + q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); + k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); + v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); + + g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); + beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); + + ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * gcs_j_broadcast = + ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); + cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + decay_mask = ggml_exp(ctx0, decay_mask); + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + + ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); + + ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); + ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); + cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); + ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_mul(ctx0, lin_solve, causal_mask); + attn = ggml_add(ctx0, attn, identity); + cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); + + ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); + ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); + + ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); + cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * k_cumdecay = + ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); + cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); + attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); + attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); + cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + + // vectorized calculation of key_gdiff + // improved from the chunked version: + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + // get last element in g_cumsum along chunk_size dimension (ne0) + // example: [[x, y, z, ..., last], ...] -> [[last], ...] + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], + g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], + (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); + g_last = ggml_cont(ctx0, g_last); + cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); + cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); + cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, + 1, chunk_size, n_chunks, g_diff_exp->ne[3]); + + ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); + cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); + cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs) + + // state to be updated per chunk + ggml_tensor * new_state = state; // ggml_dup(ctx0, state); + cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) + + // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) + ggml_tensor * core_attn_out = nullptr; + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + // shape: (S_k, chunk_size, 1, H_k * n_seqs) + ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul + + // shape: (S_v, chunk_size, 1, H_v * n_seqs) + ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat + + // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul + + // shape: (chunk_size, 1, H_v * n_seqs) + ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat + + // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + // replaced by precomputed attn_kq + ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); + cb(attn_chunk, "attn_chunk", il); + + ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); + + // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); + cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) + + // v_new = v_i - v_prime + ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); + ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + cb(v_new, "v_new_chunk", il); + + // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); + cb(attn_inter, "attn_inter_chunk", il); + + // core_attn_out[:, :, i] = attn_inter + attn @ v_new + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); + cb(v_attn, "v_attn_chunk", il); + + ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); + cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) + + core_attn_out = core_attn_out == nullptr + ? core_attn_out_chunk + : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); + + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk); + //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); + + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); + new_state = ggml_add(ctx0, + ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), + ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); + } + + // truncate padded tokens + ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(core_attn_out->type, S_v), + ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), + ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); + output_tokens = ggml_cont(ctx0, output_tokens); + cb(output_tokens, "output_tokens", il); + + // permute back to (S_v, H_v, n_tokens, n_seqs) + output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); + output_tokens = ggml_cont(ctx0, output_tokens); + + return {output_tokens, new_state}; +} + +std::pair llm_build_qwen35::build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + const float eps_norm = hparams.f_norm_rms_eps; + + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + beta = ggml_sigmoid(ctx0, beta); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + + ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); + ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); + + // Apply exponential to g_t + g_t = ggml_exp(ctx0, g_t); + + // Apply the gated delta rule for the single timestep + // last_recurrent_state = last_recurrent_state * g_t + state = ggml_mul(ctx0, state, g_t); + + // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); + ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); + // we need to sum over dim=-2, so we transpose, sum, then transpose again + kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); + + // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) + ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); + // delta = (v_t - kv_mem) * beta_t + ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] + ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); + + // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta + ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); + state = ggml_add(ctx0, state, k_t_delta); + + // Compute the attention output + // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t + ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); + // again, since it's over dim = -2, transpose, sum, transpose back + ggml_tensor * core_attn_out = + ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); + + // core_attn_out should be [S_v, 1, H_v, n_seqs] after this + cb(core_attn_out, "output_tokens", il); + cb(state, "new_state", il); + + return {core_attn_out, state}; +} + +std::pair llm_build_qwen35::build_qkvz( + ggml_tensor * input, + int il) { + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input); + qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); + cb(qkv_mixed, "linear_attn_qkv_mixed", il); + + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input); + cb(z, "z", il); + + return { qkv_mixed, z }; +} + +ggml_tensor * llm_build_qwen35::build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer) { + ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer); + ggml_tensor * gated_silu = ggml_silu(ctx0, gate); + + return ggml_mul(ctx0, normalized, gated_silu); +} + +ggml_tensor * llm_build_qwen35::build_layer_attn( + llm_graph_input_attn_kv * inp, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention + + // Qwen3Next uses a single Q projection that outputs query + gate + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ] + cb(Qcur_full, "Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0); + cb(Qcur, "Qcur_reshaped", il); + + // Apply Q normalization + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + // Apply K normalization + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "gate_reshaped", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply MRoPE + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // Attention computation + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp, + nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_pregate", il); + + ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); + cb(gate_sigmoid, "gate_sigmoid", il); + + cur = ggml_mul(ctx0, cur, gate_sigmoid); + cb(cur, "attn_gated", il); + + cur = build_lora_mm(model.layers[il].wo, cur); + cb(cur, "attn_output", il); + + return cur; +} + +ggml_tensor * llm_build_qwen35::build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il) { + const auto * mctx_cur = inp->mctx; + + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = d_inner / num_v_heads; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + const auto kv_head = mctx_cur->get_head(); + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // Input projections + auto qkvz = build_qkvz(cur, il); + ggml_tensor * qkv_mixed = qkvz.first; + ggml_tensor * z = qkvz.second; + + ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); + beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs); + cb(beta, "beta", il); + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); + alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + cb(alpha, "alpha", il); + + ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); + ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); + cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus + cb(gate, "gate", il); + + // Get convolution states from cache + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); + + // Build the convolution states tensor + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + // Calculate convolution kernel size + ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; + const int64_t conv_kernel_size = conv_kernel->ne[0]; + const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); + cb(qkv_mixed, "qkv_mixed_permuted", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + // Update convolution state cache + // Extract the last (conv_kernel_size - 1) states from conv_input + ggml_tensor * last_conv_states = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], + conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); + cb(last_conv_states, "last_conv_states", il); + + ggml_tensor * state_update_target = + ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); + cb(state_update_target, "state_update_target", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + cb(conv_states_all, "conv_states_updated", il); + + // Apply SSM convolution + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + cb(conv_output_proper, "conv_output_raw", il); + + ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); + cb(conv_output_silu, "conv_output_silu", il); + + ggml_tensor * conv_qkv_mix = conv_output_silu; + + // Calculate the total conv dimension + int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; + int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); + + // Extract the convolved Q, K, V from conv_output + ggml_tensor * q_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); + cb(q_conv, "q_conv", il); + ggml_tensor * k_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(k_conv, "k_conv", il); + ggml_tensor * v_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv, + 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(v_conv, "v_conv", il); + + // Unsqueeze them + q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); + cb(state, "state_predelta", il); + + // if head keys and value keys are different, repeat Q/K to match V's head count + // V heads are in tiled order (from conversion), so simple tiled repeat works + if (num_k_heads != num_v_heads) { + GGML_ASSERT(num_v_heads % num_k_heads == 0); + q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + } + + cb(q_conv, "q_conv_predelta", il); + cb(k_conv, "k_conv_predelta", il); + cb(v_conv, "v_conv_predelta", il); + + // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens + std::pair attn_out; // pair of (output, new_state) + if (n_seq_tokens == 1) { + attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); + } else { + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + } + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + // Update the recurrent states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + // Reshape both attn_out_final and z to 2D tensors for normalization + // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // Apply gated normalization: self.norm(core_attn_out, z) + ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + + // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] + ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(final_output, "final_output", il); + + // Output projection + cur = build_lora_mm(model.layers[il].ssm_out, final_output); + cb(cur, "linear_attn_out", il); + + // Reshape back to original dimensions + cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; +} + +ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il) { + // Qwen3.5 does not use MoE FFN + GGML_ASSERT(model.layers[il].ffn_gate_inp == nullptr); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + return cur; +} diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp new file mode 100644 index 00000000000..0db8f825c67 --- /dev/null +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -0,0 +1,774 @@ +#include "ggml.h" +#include "models.h" + +#define CHUNK_SIZE 64 + +llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) : + llm_graph_context_mamba(params), model(model) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + cb(inpL, "model.input_embed", -1); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * causal_mask = + ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), + GGML_TRI_TYPE_LOWER); + + ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); + ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); + + ggml_build_forward_expand(gf, causal_mask); + ggml_build_forward_expand(gf, identity); + ggml_build_forward_expand(gf, diag_mask); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // Determine layer type and build appropriate attention mechanism + if (hparams.is_recurrent(il)) { + // Linear attention layer (gated delta net) + cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); + } else { + // Full attention layer + cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Residual connection + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + // Save the tensor before post-attention norm for residual connection + ggml_tensor * ffn_residual = cur; + + // Post-attention norm + ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(attn_post_norm, "attn_post_norm", il); + + // MOE FFN layer + cur = build_layer_ffn(attn_post_norm, il); + cb(cur, "ffn_out", il); + + // Residual connection for FFN - add to the tensor from before post_attention_layernorm + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "post_moe", il); + + // Input for next layer + inpL = cur; + } + cur = inpL; + + // Final norm + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // LM head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +// utility to get one slice from the third dimension +// input dim: [x, y, c, b] +// output dim: [x, y, 1, b] +static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { + return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); +} + +std::pair llm_build_qwen35moe::build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + const float eps_norm = hparams.f_norm_rms_eps; + + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + + beta = ggml_sigmoid(ctx0, beta); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); + g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); + + beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); + state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + + cb(q, "q_perm", il); + cb(k, "k_perm", il); + cb(v, "v_perm", il); + cb(beta, "beta_perm", il); + cb(g, "g_perm", il); + cb(state, "state_in", il); + + GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); + GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); + + // Do padding + const int64_t chunk_size = CHUNK_SIZE; + + const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; + const int64_t n_chunks = (n_tokens + pad) / chunk_size; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, pad, 0, 0, 0); + beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); + + cb(q, "q_pad", il); + cb(k, "k_pad", il); + cb(v, "v_pad", il); + cb(beta, "beta_pad", il); + cb(g, "g_pad", il); + + ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); + ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + + cb(v_beta, "v_beta", il); + cb(k_beta, "k_beta", il); + + q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); + k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); + v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); + + g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); + beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); + + ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); + cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); + ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * gcs_j_broadcast = + ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); + + ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); + cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + decay_mask = ggml_exp(ctx0, decay_mask); + decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + + ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); + + ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); + ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); + cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); + ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_mul(ctx0, lin_solve, causal_mask); + attn = ggml_add(ctx0, attn, identity); + cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); + + ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); + ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); + + ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); + cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * k_cumdecay = + ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); + cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); + attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); + attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); + cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + + + // vectorized calculation of key_gdiff + // improved from the chunked version: + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + // get last element in g_cumsum along chunk_size dimension (ne0) + // example: [[x, y, z, ..., last], ...] -> [[last], ...] + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], + g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], + (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); + g_last = ggml_cont(ctx0, g_last); + cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); + cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); + cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, + 1, chunk_size, n_chunks, g_diff_exp->ne[3]); + + ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); + cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) + + ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); + cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs) + + + // state to be updated per chunk + ggml_tensor * new_state = state; // ggml_dup(ctx0, state); + cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) + + // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) + ggml_tensor * core_attn_out = nullptr; + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + // shape: (S_k, chunk_size, 1, H_k * n_seqs) + ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul + + // shape: (S_v, chunk_size, 1, H_v * n_seqs) + ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat + + // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul + + // shape: (chunk_size, 1, H_v * n_seqs) + ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat + + // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) + // replaced by precomputed attn_kq + ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); + cb(attn_chunk, "attn_chunk", il); + + ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); + + // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); + cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) + + // v_new = v_i - v_prime + ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); + ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); + cb(v_new, "v_new_chunk", il); + + // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); + cb(attn_inter, "attn_inter_chunk", il); + + // core_attn_out[:, :, i] = attn_inter + attn @ v_new + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); + cb(v_attn, "v_attn_chunk", il); + + ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); + cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) + + core_attn_out = core_attn_out == nullptr + ? core_attn_out_chunk + : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); + + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk); + //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? + ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); + + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); + new_state = ggml_add(ctx0, + ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), + ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); + } + + // truncate padded tokens + ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(core_attn_out->type, S_v), + ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), + ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); + output_tokens = ggml_cont(ctx0, output_tokens); + cb(output_tokens, "output_tokens", il); + + // permute back to (S_v, H_v, n_tokens, n_seqs) + output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); + output_tokens = ggml_cont(ctx0, output_tokens); + + return {output_tokens, new_state}; +} + +std::pair llm_build_qwen35moe::build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * beta, + ggml_tensor * state, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing + GGML_ASSERT(v->ne[2] == n_tokens); + GGML_ASSERT(k->ne[2] == n_tokens); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); + GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + + GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case + + const float eps_norm = hparams.f_norm_rms_eps; + + q = ggml_l2_norm(ctx0, q, eps_norm); + k = ggml_l2_norm(ctx0, k, eps_norm); + + const float scale = 1.0f / sqrtf(S_v); + + q = ggml_scale(ctx0, q, scale); + beta = ggml_sigmoid(ctx0, beta); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(beta, "beta_in", il); + cb(g, "g_in", il); + + state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + + ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); + ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); + + // Apply exponential to g_t + g_t = ggml_exp(ctx0, g_t); + + // Apply the gated delta rule for the single timestep + // last_recurrent_state = last_recurrent_state * g_t + state = ggml_mul(ctx0, state, g_t); + + // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) + ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); + ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); + // we need to sum over dim=-2, so we transpose, sum, then transpose again + kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); + + // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) + ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); + // delta = (v_t - kv_mem) * beta_t + ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] + ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); + + // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta + ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); + state = ggml_add(ctx0, state, k_t_delta); + + // Compute the attention output + // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) + ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t + ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); + // again, since it's over dim = -2, transpose, sum, transpose back + ggml_tensor * core_attn_out = + ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); + + // core_attn_out should be [S_v, 1, H_v, n_seqs] after this + cb(core_attn_out, "output_tokens", il); + cb(state, "new_state", il); + + return {core_attn_out, state}; +} + +std::pair llm_build_qwen35moe::build_qkvz( + ggml_tensor * input, + int il) { + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input); + qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); + cb(qkv_mixed, "linear_attn_qkv_mixed", il); + + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input); + cb(z, "z", il); + + return { qkv_mixed, z }; +} + +ggml_tensor * llm_build_qwen35moe::build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer) { + ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer); + ggml_tensor * gated_silu = ggml_silu(ctx0, gate); + + return ggml_mul(ctx0, normalized, gated_silu); +} + +ggml_tensor * llm_build_qwen35moe ::build_layer_attn( + llm_graph_input_attn_kv * inp, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention + + // Qwen3Next uses a single Q projection that outputs query + gate + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ] + cb(Qcur_full, "Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0); + cb(Qcur, "Qcur_reshaped", il); + + // Apply Q normalization + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + // Apply K normalization + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "gate_reshaped", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply IMRoPE + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // Attention computation + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp, + nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_pregate", il); + + ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); + cb(gate_sigmoid, "gate_sigmoid", il); + + cur = ggml_mul(ctx0, cur, gate_sigmoid); + cb(cur, "attn_gated", il); + + cur = build_lora_mm(model.layers[il].wo, cur); + cb(cur, "attn_output", il); + + return cur; +} + +ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il) { + const auto * mctx_cur = inp->mctx; + + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = d_inner / num_v_heads; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + const auto kv_head = mctx_cur->get_head(); + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // Input projections + auto qkvz = build_qkvz(cur, il); + ggml_tensor * qkv_mixed = qkvz.first; + ggml_tensor * z = qkvz.second; + + ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); + beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs); + cb(beta, "beta", il); + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); + alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + cb(alpha, "alpha", il); + + ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); + ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); + cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus + cb(gate, "gate", il); + + // Get convolution states from cache + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); + + // Build the convolution states tensor + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + // Calculate convolution kernel size + ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; + const int64_t conv_kernel_size = conv_kernel->ne[0]; + const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); + cb(qkv_mixed, "qkv_mixed_permuted", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + // Update convolution state cache + // Extract the last (conv_kernel_size - 1) states from conv_input + ggml_tensor * last_conv_states = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], + conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); + cb(last_conv_states, "last_conv_states", il); + + ggml_tensor * state_update_target = + ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); + cb(state_update_target, "state_update_target", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + cb(conv_states_all, "conv_states_updated", il); + + // Apply SSM convolution + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + cb(conv_output_proper, "conv_output_raw", il); + + ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); + cb(conv_output_silu, "conv_output_silu", il); + + ggml_tensor * conv_qkv_mix = conv_output_silu; + + // Calculate the total conv dimension + int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; + int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); + + // Extract the convolved Q, K, V from conv_output + ggml_tensor * q_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); + cb(q_conv, "q_conv", il); + ggml_tensor * k_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(k_conv, "k_conv", il); + ggml_tensor * v_conv = + ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv, + 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + cb(v_conv, "v_conv", il); + + // Unsqueeze them + q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); + cb(state, "state_predelta", il); + + // if head keys and value keys are different, repeat Q/K to match V's head count + // V heads are in tiled order (from conversion), so simple tiled repeat works + if (num_k_heads != num_v_heads) { + GGML_ASSERT(num_v_heads % num_k_heads == 0); + q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + } + + cb(q_conv, "q_conv_predelta", il); + cb(k_conv, "k_conv_predelta", il); + cb(v_conv, "v_conv_predelta", il); + + // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens + std::pair attn_out; // pair of (output, new_state) + if (n_seq_tokens == 1) { + attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); + } else { + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + } + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + // Update the recurrent states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + // Reshape both attn_out_final and z to 2D tensors for normalization + // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + + // Apply gated normalization: self.norm(core_attn_out, z) + ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + + // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] + ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(final_output, "final_output", il); + + // Output projection + cur = build_lora_mm(model.layers[il].ssm_out, final_output); + cb(cur, "linear_attn_out", il); + + // Reshape back to original dimensions + cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; +} + +ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int il) { + // Check if this is an MoE layer + GGML_ASSERT(model.layers[il].ffn_gate_inp != nullptr); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, LLM_FFN_SILU, + true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + cb(moe_out, "ffn_moe_out", il); + + // Add shared experts if present - following Qwen3Next reference implementation + if (model.layers[il].ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + // Apply shared expert gating as in the reference implementation + // The shared expert has its own gate that is sigmoided + // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token) + ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); + cb(shared_gate, "shared_expert_gate", il); + + // Apply sigmoid to the gate + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "shared_expert_gate_sigmoid", il); + + + // Apply the gate to the shared expert output + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + + return cur; +} diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index 99b1a76a485..aea8b29513e 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -16,17 +16,6 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - ggml_tensor * causal_mask = - ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), - GGML_TRI_TYPE_LOWER); - - ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); - ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); - - ggml_build_forward_expand(gf, causal_mask); - ggml_build_forward_expand(gf, identity); - ggml_build_forward_expand(gf, diag_mask); - for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -36,7 +25,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) - cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); + cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { // Full attention layer cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il); @@ -99,11 +88,8 @@ std::pair llm_build_qwen3next::build_delta_net_chu ggml_tensor * k, ggml_tensor * v, ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, + ggml_tensor * b, + ggml_tensor * s, int il) { const int64_t S_k = q->ne[0]; const int64_t H_k = q->ne[1]; @@ -113,134 +99,123 @@ std::pair llm_build_qwen3next::build_delta_net_chu const int64_t S_v = v->ne[0]; const int64_t H_v = v->ne[1]; - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - const float scale = 1.0f / sqrtf(S_v); + const float scale = 1.0f / sqrtf(S_k); q = ggml_scale(ctx0, q, scale); - beta = ggml_sigmoid(ctx0, beta); - cb(q, "q_in", il); cb(k, "k_in", il); cb(v, "v_in", il); - cb(beta, "beta_in", il); + cb(b, "b_in", il); cb(g, "g_in", il); - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(g, "g_perm", il); - cb(state, "state_in", il); - - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] + g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [ 1, n_tokens, H_v, n_seqs] + b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [ 1, n_tokens, H_v, n_seqs] - // Do padding - const int64_t chunk_size = CHUNK_SIZE; + const int CS = CHUNK_SIZE; - const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; - const int64_t n_chunks = (n_tokens + pad) / chunk_size; + const int pad = (CS - n_tokens % CS) % CS; + const int n_chunks = (n_tokens + pad) / CS; q = ggml_pad(ctx0, q, 0, pad, 0, 0); k = ggml_pad(ctx0, k, 0, pad, 0, 0); v = ggml_pad(ctx0, v, 0, pad, 0, 0); - g = ggml_pad(ctx0, g, pad, 0, 0, 0); - beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, 0, pad, 0, 0); + b = ggml_pad(ctx0, b, 0, pad, 0, 0); - cb(q, "q_pad", il); - cb(k, "k_pad", il); - cb(v, "v_pad", il); - cb(beta, "beta_pad", il); - cb(g, "g_pad", il); + ggml_tensor * v_b = ggml_mul(ctx0, v, b); + ggml_tensor * k_b = ggml_mul(ctx0, k, b); - ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); + cb(v_b, "v_b", il); + cb(k_b, "k_b", il); - cb(v_beta, "v_beta", il); - cb(k_beta, "k_beta", il); + q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs); + k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs); + v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs); - q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); - k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); - k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); - v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); - v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); + g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs); - g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); - beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); + // [CS, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_cs = ggml_cumsum(ctx0, g); + cb(g_cs, "g_cs", il); - ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); - cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + ggml_tensor * g_cs_i = g_cs; + ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs); - ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); - ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); + g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs); - ggml_tensor * gcs_j_broadcast = - ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); + // [CS, CS, n_chunks, H_v * n_seqs] + ggml_tensor * decay_mask; + decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); + decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); + decay_mask = ggml_exp(ctx0, decay_mask); + cb(decay_mask, "decay_mask", il); - ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); - cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * kb; + kb = ggml_mul_mat(ctx0, k, k_b); + kb = ggml_mul (ctx0, kb, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - decay_mask = ggml_exp(ctx0, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * attn; + attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER); - ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); + ggml_tensor * identity; + identity = ggml_view_1d(ctx0, attn, CS, 0); + identity = ggml_fill (ctx0, identity, 1.0f); + identity = ggml_diag (ctx0, identity); - ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); - ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); - cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + ggml_tensor * lhs = ggml_add(ctx0, attn, identity); + cb(lhs, "dnet_add_ch_lhs", il); - ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); + attn = ggml_neg(ctx0, attn); - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_mul(ctx0, lin_solve, causal_mask); - attn = ggml_add(ctx0, attn, identity); - cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_add(ctx0, lin_solve, identity); + cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs] - v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); + // [S_v, CS, n_chunks, H_v * n_seqs] + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn); - ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); - ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); + // [CS, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_exp = ggml_exp(ctx0, g_cs); - ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); - cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) + k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b)); - ggml_tensor * k_cumdecay = - ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); - cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + // [CS, S_k, n_chunks, H_k * n_seqs] + ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp); + cb(kbg, "k_beta_g_exp", il); - ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); - attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); - attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); - cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) + // [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn); + cb(k_cd, "k_cumdecay", il); + // [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp); + ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t); + + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + kq = ggml_mul(ctx0, kq, decay_mask); + kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG); + cb(kq, "kq", il); // vectorized calculation of key_gdiff // improved from the chunked version: @@ -250,109 +225,98 @@ std::pair llm_build_qwen3next::build_delta_net_chu // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - // get last element in g_cumsum along chunk_size dimension (ne0) + // get last element in g_cumsum along CS dimension (ne0) // example: [[x, y, z, ..., last], ...] -> [[last], ...] - ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], - g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], - (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); + // [1, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3], + g_cs->nb[1], + g_cs->nb[2], + g_cs->nb[3], + ggml_row_size(g_cs->type, g_cs->ne[0] - 1)); + cb(g_last, "g_last", il); + + // TODO: remove this cont when CUDA supports non-cont unary ops g_last = ggml_cont(ctx0, g_last); - cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) + // [1, 1, n_chunks, H_v * n_seqs] ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); - cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); - cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) + cb(g_last_exp, "g_last_exp", il); - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, - 1, chunk_size, n_chunks, g_diff_exp->ne[3]); + // [CS, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last)); + cb(g_diff, "g_diff", il); - ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); - cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp); - ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); - cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs) + // [S_k, CS, n_chunks, H_v * n_seqs] + ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t); + cb(kg, "key_gdiff", il); + // [CS, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg)); + cb(kg_t, "key_gdiff_t", il); - // state to be updated per chunk - ggml_tensor * new_state = state; // ggml_dup(ctx0, state); - cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) + ggml_tensor * s_t = ggml_transpose(ctx0, s); + s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs); + cb(s_t, "dnet_add_ch_state", il); - // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) - ggml_tensor * core_attn_out = nullptr; + // [CS, S_v, n_chunks, H_v * n_seqs] + ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v)); for (int64_t chunk = 0; chunk < n_chunks; chunk++) { - // shape: (S_k, chunk_size, 1, H_k * n_seqs) - ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul + ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk); // [S_k, CS, 1, H_k * n_seqs] + ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk); // [ CS, S_v, 1, H_v * n_seqs] + ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk); // [ CS, CS, 1, H_k * n_seqs] + ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs] + ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs] - // shape: (S_v, chunk_size, 1, H_v * n_seqs) - ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat + // [CS, S_v, 1, H_v * n_seqs] + ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t); + cb(v_t_p, "v_prime", il); - // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul + // [CS, S_v, 1, H_v * n_seqs] + ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p); + cb(v_t_new, "v_t_new", il); - // shape: (chunk_size, 1, H_v * n_seqs) - ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq); + cb(v_attn, "v_attn", il); - // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - // replaced by precomputed attn_kq - ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); - cb(attn_chunk, "attn_chunk", il); + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp); + cb(attn_inter, "attn_inter", il); - ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn); + cb(o_ch, "dnet_add_ch_attn_out", il); - // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); - cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) - - // v_new = v_i - v_prime - ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); - ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); - cb(v_new, "v_new_chunk", il); - - // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); - cb(attn_inter, "attn_inter_chunk", il); - - // core_attn_out[:, :, i] = attn_inter + attn @ v_new - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); - cb(v_attn, "v_attn_chunk", il); - - ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); - cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) - - core_attn_out = core_attn_out == nullptr - ? core_attn_out_chunk - : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); + v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]); // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk); - //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); + // TODO: head broadcast might not work here - probably will need a transpose + ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs] // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); - new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), - ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); + ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk); + s_t = ggml_mul(ctx0, s_t, ch_g_last_exp); + s_t = ggml_add(ctx0, s_t, kgv); + cb(s_t, "dnet_add_ch_state", il); } + s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs); + // truncate padded tokens - ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, + ggml_tensor * o = ggml_view_4d(ctx0, v, S_v, n_tokens, H_v, n_seqs, - ggml_row_size(core_attn_out->type, S_v), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); - output_tokens = ggml_cont(ctx0, output_tokens); - cb(output_tokens, "output_tokens", il); + ggml_row_size(v->type, S_v), + ggml_row_size(v->type, S_v * CS * n_chunks), + ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0); - // permute back to (S_v, H_v, n_tokens, n_seqs) - output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); - output_tokens = ggml_cont(ctx0, output_tokens); + o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs] + s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] - return {output_tokens, new_state}; + return {o, s}; } std::pair llm_build_qwen3next::build_delta_net_autoregressive( @@ -360,8 +324,8 @@ std::pair llm_build_qwen3next::build_delta_net_aut ggml_tensor * k, ggml_tensor * v, ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, + ggml_tensor * b, // beta + ggml_tensor * s, // state int il) { const int64_t S_k = q->ne[0]; const int64_t H_k = q->ne[1]; @@ -371,75 +335,72 @@ std::pair llm_build_qwen3next::build_delta_net_aut const int64_t S_v = v->ne[0]; const int64_t H_v = v->ne[1]; - GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); + GGML_ASSERT(n_tokens == 1); + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); + const float scale = 1.0f / sqrtf(S_k); - const float scale = 1.0f / sqrtf(S_v); + q = ggml_scale(ctx0, q, scale); - q = ggml_scale(ctx0, q, scale); - beta = ggml_sigmoid(ctx0, beta); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] cb(q, "q_in", il); cb(k, "k_in", il); cb(v, "v_in", il); - cb(beta, "beta_in", il); + cb(b, "b_in", il); cb(g, "g_in", il); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); + g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs); - ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); - ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); + // [S_v, S_v, H_v, n_seqs] + g = ggml_exp(ctx0, g); + s = ggml_mul(ctx0, s, g); - // Apply exponential to g_t - g_t = ggml_exp(ctx0, g_t); + ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s)); - // Apply the gated delta rule for the single timestep - // last_recurrent_state = last_recurrent_state * g_t - state = ggml_mul(ctx0, state, g_t); + // [1, S_v, H_v, n_seqs] + ggml_tensor * sk; + sk = ggml_mul (ctx0, s_t, k); + sk = ggml_sum_rows(ctx0, sk); - // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); - ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); - // we need to sum over dim=-2, so we transpose, sum, then transpose again - kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); + // [S_v, 1, H_v, n_seqs] + ggml_tensor * d; + d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk)); + d = ggml_mul(ctx0, d, b); - // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) - ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); - // delta = (v_t - kv_mem) * beta_t - ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] - ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); + // [1, S_v, H_v, n_seqs] + ggml_tensor * d_t; + d_t = ggml_transpose(ctx0, d); - // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta - ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); - state = ggml_add(ctx0, state, k_t_delta); + // [S_v, S_v, H_v, n_seqs] + ggml_tensor * kd; + k = ggml_repeat(ctx0, k, s); + kd = ggml_mul (ctx0, k, d_t); - // Compute the attention output - // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t - ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); - // again, since it's over dim = -2, transpose, sum, transpose back - ggml_tensor * core_attn_out = - ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); + s_t = ggml_add(ctx0, s_t, kd); - // core_attn_out should be [S_v, 1, H_v, n_seqs] after this - cb(core_attn_out, "output_tokens", il); - cb(state, "new_state", il); + cb(s_t, "dnet_add_ar_state", il); - return {core_attn_out, state}; + ggml_tensor * s_q = ggml_mul (ctx0, s_t, q); + ggml_tensor * o = ggml_sum_rows(ctx0, s_q); + + o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs] + s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] + + return {o, s}; } ggml_tensor * llm_build_qwen3next::build_norm_gated( @@ -472,39 +433,29 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( // Split Q projection into query and gate // The split should be along dimension 0 (the feature dimension) ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, - Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0); + Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0); + cb(Qcur, "Qcur_view", il); + ggml_tensor * gate = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full)); - cb(Qcur, "Qcur", il); cb(gate, "gate", il); - // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cb(Qcur, "Qcur_reshaped", il); - - // Apply Q normalization - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); - cb(Qcur, "Qcur_normed", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - // Apply K normalization Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); - cb(Kcur, "Kcur_normed", il); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads) - gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); - cb(gate, "gate_reshaped", il); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); - // Apply RoPE Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -519,7 +470,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - // Attention computation const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp, @@ -527,10 +477,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_pregate", il); - ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); - cb(gate_sigmoid, "gate_sigmoid", il); + // TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "gate_sigmoid", il); + + gate = ggml_reshape_2d(ctx0, gate, n_embd_head * n_head, n_tokens); - cur = ggml_mul(ctx0, cur, gate_sigmoid); + cur = ggml_mul(ctx0, cur, gate); cb(cur, "attn_gated", il); cur = build_lora_mm(model.layers[il].wo, cur); @@ -560,7 +515,6 @@ std::pair llm_build_qwen3next::build_qkvz( cb(z, "z", il); return { qkv_mixed, z }; - } else { // legacy (slower) path ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input); @@ -624,9 +578,6 @@ std::pair llm_build_qwen3next::build_qkvz( ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il) { const auto * mctx_cur = inp->mctx; @@ -671,7 +622,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped)); cb(a, "a", il); - ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs); + // TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont + b = ggml_cont(ctx0, b); + + ggml_tensor * beta = ggml_sigmoid(ctx0, b); + + beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs); // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads] ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs); @@ -679,6 +635,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus cb(gate, "gate", il); @@ -686,8 +643,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); - // Build the convolution states tensor ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); cb(conv_states, "conv_states", il); @@ -696,11 +651,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); cb(conv_states, "conv_states_reshaped", il); - qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); - cb(qkv_mixed, "qkv_mixed_permuted", il); + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); cb(conv_input, "conv_input", il); @@ -720,7 +676,10 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); cb(conv_states_all, "conv_states_updated", il); - // Apply SSM convolution + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); + cb(state, "state_predelta", il); + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); cb(conv_output_proper, "conv_output_raw", il); @@ -734,26 +693,36 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); // Extract the convolved Q, K, V from conv_output - ggml_tensor * q_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); + ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + 0); + + ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + + ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_v_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads)); + cb(q_conv, "q_conv", il); - ggml_tensor * k_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, - head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(k_conv, "k_conv", il); - ggml_tensor * v_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv, - 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(v_conv, "v_conv", il); - // Unsqueeze them - q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + const float eps_norm = hparams.f_norm_rms_eps; - ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); - state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); - cb(state, "state_predelta", il); + q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm); + k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm); + + //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes if (num_k_heads != num_v_heads) { @@ -786,7 +755,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( if (n_seq_tokens == 1) { attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); } ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; @@ -795,19 +764,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); - - // Reshape both attn_out_final and z to 2D tensors for normalization - // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_cpy(ctx0, new_state, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // Apply gated normalization: self.norm(core_attn_out, z) - ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il); // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); @@ -818,7 +783,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(cur, "linear_attn_out", il); // Reshape back to original dimensions - cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; } @@ -839,7 +805,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int if (model.layers[il].ffn_up_shexp != nullptr) { ggml_tensor * ffn_shexp = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, NULL, @@ -852,11 +818,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); cb(shared_gate, "shared_expert_gate", il); - // Apply sigmoid to the gate shared_gate = ggml_sigmoid(ctx0, shared_gate); cb(shared_gate, "shared_expert_gate_sigmoid", il); - // Apply the gate to the shared expert output ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); cb(ffn_shexp, "ffn_shexp_gated", il); diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index adfc489d1f0..b88d953bd27 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -1,16 +1,10 @@ -#if defined(_MSC_VER) -#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING -#endif - #include "unicode.h" #include "unicode-data.h" #include #include -#include #include #include -#include #include #include #include @@ -199,27 +193,6 @@ static std::unordered_map unicode_utf8_to_byte_map() { return map; } -static inline std::wstring unicode_wstring_from_utf8(const std::string & s) { -#if defined(__clang__) - // disable C++17 deprecation warning for std::codecvt_utf8 -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif - - std::wstring_convert> conv; - -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#endif - - return conv.from_bytes(s); -} - static std::vector unicode_byte_encoding_process(const std::vector & bpe_words) { std::vector bpe_encoded_words; for (const auto & word : bpe_words) { @@ -1028,10 +1001,10 @@ std::vector unicode_regex_split(const std::string & text, const std break; } } + const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); if (use_collapsed) { // sanity-check that the original regex does not contain any non-ASCII characters - const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); for (size_t i = 0; i < cpts_regex.size(); ++i) { if (cpts_regex[i] >= 128) { throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported"); @@ -1087,7 +1060,7 @@ std::vector unicode_regex_split(const std::string & text, const std bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets); } else { // no unicode category used, we can use std::wregex directly - const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr); + std::wstring wregex_expr(cpts_regex.begin(), cpts_regex.end()); // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback std::wstring wtext(cpts.begin(), cpts.end()); From 21411d81ea736ed5d9cdea4df360d3c4b60a4adb Mon Sep 17 00:00:00 2001 From: Maxime Grenu <69890511+cluster2600@users.noreply.github.com> Date: Thu, 19 Feb 2026 16:18:42 +0100 Subject: [PATCH 155/831] docs : fix duplicate word typo in VAD section (#3670) The VAD section contained a spurious 'the' at the end of a sentence, creating the run-on 'Using this information the / only the speech segments...'. Replace the orphaned 'the' with a comma so the sentence reads correctly: 'Using this information, only the speech segments...'. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c0d8edb99bc..474a1301da7 100644 --- a/README.md +++ b/README.md @@ -755,7 +755,7 @@ argument to `whisper-cli`. In addition to this option a VAD model is also required. The way this works is that first the audio samples are passed through -the VAD model which will detect speech segments. Using this information the +the VAD model which will detect speech segments. Using this information, only the speech segments that are detected are extracted from the original audio input and passed to whisper for processing. This reduces the amount of audio data that needs to be processed by whisper and can significantly speed up the From cec1dd9d1276a1df679858222f3b1dc0551c5220 Mon Sep 17 00:00:00 2001 From: Dmitry Atamanov Date: Fri, 27 Feb 2026 15:15:15 +0500 Subject: [PATCH 156/831] examples : update miniaudio library to 0.11.24 (#3672) --- examples/miniaudio.h | 6638 ++++++++++++++++++++++++++++-------------- 1 file changed, 4507 insertions(+), 2131 deletions(-) diff --git a/examples/miniaudio.h b/examples/miniaudio.h index c74bebeb3c7..24e676bb264 100644 --- a/examples/miniaudio.h +++ b/examples/miniaudio.h @@ -1,6 +1,6 @@ /* Audio playback and capture library. Choice of public domain or MIT-0. See license statements at the end of this file. -miniaudio - v0.11.22 - 2025-02-24 +miniaudio - v0.11.24 - 2026-01-17 David Reid - mackron@gmail.com @@ -12,18 +12,10 @@ GitHub: https://github.com/mackron/miniaudio /* 1. Introduction =============== -To use miniaudio, include "miniaudio.h": - - ```c - #include "miniaudio.h" - ``` - -The implementation is contained in "miniaudio.c". Just compile this like any other source file. You -can include miniaudio.c if you want to compile your project as a single translation unit: - - ```c - #include "miniaudio.c" - ``` +To use miniaudio, just include "miniaudio.h" like any other header and add "miniaudio.c" to your +source tree. If you don't want to add it to your source tree you can compile and link to it like +any other library. Note that ABI compatibility is not guaranteed between versions, even with bug +fix releases, so take care if compiling as a shared object. miniaudio includes both low level and high level APIs. The low level API is good for those who want to do all of their mixing themselves and only require a light weight interface to the underlying @@ -303,7 +295,7 @@ The engine encapsulates both the resource manager and the node graph to create a use high level API. The resource manager and node graph APIs are covered in more later sections of this manual. -The code below shows how you can initialize an engine using it's default configuration. +The code below shows how you can initialize an engine using its default configuration. ```c ma_result result; @@ -391,7 +383,7 @@ Sounds are not started by default. Start a sound with `ma_sound_start()` and sto `ma_sound_stop()`. When a sound is stopped, it is not rewound to the start. Use `ma_sound_seek_to_pcm_frame(&sound, 0)` to seek back to the start of a sound. By default, starting and stopping sounds happens immediately, but sometimes it might be convenient to schedule the sound -the be started and/or stopped at a specific time. This can be done with the following functions: +to be started and/or stopped at a specific time. This can be done with the following functions: ```c ma_sound_set_start_time_in_pcm_frames() @@ -463,6 +455,11 @@ is at the end, use `ma_sound_at_end()`. Looping of a sound can be controlled wit miniaudio should work cleanly out of the box without the need to download or install any dependencies. See below for platform-specific details. +This library has been designed to be added directly to your source tree which is the preferred way +of using it, but you can compile it as a normal library if that's your preference. Be careful if +compiling as a shared object because miniaudio is not ABI compatible between any release, including +bug fix releases. It's recommended you link statically. + Note that GCC and Clang require `-msse2`, `-mavx2`, etc. for SIMD optimizations. If you get errors about undefined references to `__sync_val_compare_and_swap_8`, `__atomic_load_8`, @@ -532,7 +529,7 @@ you'll need to disable run-time linking with `MA_NO_RUNTIME_LINKING` and link wi The Emscripten build emits Web Audio JavaScript directly and should compile cleanly out of the box. You cannot use `-std=c*` compiler flags, nor `-ansi`. -You can enable the use of AudioWorkets by defining `MA_ENABLE_AUDIO_WORKLETS` and then compiling +You can enable the use of AudioWorklets by defining `MA_ENABLE_AUDIO_WORKLETS` and then compiling with the following options: -sAUDIO_WORKLET=1 -sWASM_WORKERS=1 -sASYNCIFY @@ -881,7 +878,7 @@ read data within a certain range of the underlying data. To do this you can use This is useful if you have a sound bank where many sounds are stored in the same file and you want the data source to only play one of those sub-sounds. Note that once the range is set, everything -that takes a position, such as cursors and loop points, should always be relatvie to the start of +that takes a position, such as cursors and loop points, should always be relative to the start of the range. When the range is set, any previously defined loop point will be reset. Custom loop points can also be used with data sources. By default, data sources will loop after @@ -889,7 +886,7 @@ they reach the end of the data source, but if you need to loop at a specific loc the following: ```c - result = ma_data_set_loop_point_in_pcm_frames(pDataSource, loopBegInFrames, loopEndInFrames); + result = ma_data_source_set_loop_point_in_pcm_frames(pDataSource, loopBegInFrames, loopEndInFrames); if (result != MA_SUCCESS) { return result; // Failed to set the loop point. } @@ -3750,7 +3747,7 @@ extern "C" { #define MA_VERSION_MAJOR 0 #define MA_VERSION_MINOR 11 -#define MA_VERSION_REVISION 22 +#define MA_VERSION_REVISION 24 #define MA_VERSION_STRING MA_XSTRINGIFY(MA_VERSION_MAJOR) "." MA_XSTRINGIFY(MA_VERSION_MINOR) "." MA_XSTRINGIFY(MA_VERSION_REVISION) #if defined(_MSC_VER) && !defined(__clang__) @@ -3857,37 +3854,65 @@ typedef ma_uint16 wchar_t; #define MA_SIZE_MAX 0xFFFFFFFF /* When SIZE_MAX is not defined by the standard library just default to the maximum 32-bit unsigned integer. */ #endif +#define MA_UINT64_MAX (((ma_uint64)0xFFFFFFFF << 32) | (ma_uint64)0xFFFFFFFF) /* Weird shifting syntax is for VC6 compatibility. */ + /* Platform/backend detection. */ -#if defined(_WIN32) || defined(__COSMOPOLITAN__) +#if defined(_WIN32) #define MA_WIN32 #if defined(MA_FORCE_UWP) || (defined(WINAPI_FAMILY) && ((defined(WINAPI_FAMILY_PC_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PC_APP) || (defined(WINAPI_FAMILY_PHONE_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PHONE_APP))) #define MA_WIN32_UWP #elif defined(WINAPI_FAMILY) && (defined(WINAPI_FAMILY_GAMES) && WINAPI_FAMILY == WINAPI_FAMILY_GAMES) #define MA_WIN32_GDK + #elif defined(NXDK) + #define MA_WIN32_NXDK #else #define MA_WIN32_DESKTOP #endif + + /* The original Xbox. */ + #if defined(NXDK) /* <-- Add other Xbox compiler toolchains here, and then add a toolchain-specific define in case we need to discriminate between them later. */ + #define MA_XBOX + + #if defined(NXDK) + #define MA_XBOX_NXDK + #endif + #endif +#endif +#if defined(__MSDOS__) || defined(MSDOS) || defined(_MSDOS) || defined(__DOS__) + #define MA_DOS + + /* No threading allowed on DOS. */ + #ifndef MA_NO_THREADING + #define MA_NO_THREADING + #endif + + /* No runtime linking allowed on DOS. */ + #ifndef MA_NO_RUNTIME_LINKING + #define MA_NO_RUNTIME_LINKING + #endif #endif -#if !defined(_WIN32) /* If it's not Win32, assume POSIX. */ +#if !defined(MA_WIN32) && !defined(MA_DOS) /* If it's not Win32, assume POSIX. */ #define MA_POSIX - /* - Use the MA_NO_PTHREAD_IN_HEADER option at your own risk. This is intentionally undocumented. - You can use this to avoid including pthread.h in the header section. The downside is that it - results in some fixed sized structures being declared for the various types that are used in - miniaudio. The risk here is that these types might be too small for a given platform. This - risk is yours to take and no support will be offered if you enable this option. - */ - #ifndef MA_NO_PTHREAD_IN_HEADER - #include /* Unfortunate #include, but needed for pthread_t, pthread_mutex_t and pthread_cond_t types. */ - typedef pthread_t ma_pthread_t; - typedef pthread_mutex_t ma_pthread_mutex_t; - typedef pthread_cond_t ma_pthread_cond_t; - #else - typedef ma_uintptr ma_pthread_t; - typedef union ma_pthread_mutex_t { char __data[40]; ma_uint64 __alignment; } ma_pthread_mutex_t; - typedef union ma_pthread_cond_t { char __data[48]; ma_uint64 __alignment; } ma_pthread_cond_t; + #if !defined(MA_NO_THREADING) + /* + Use the MA_NO_PTHREAD_IN_HEADER option at your own risk. This is intentionally undocumented. + You can use this to avoid including pthread.h in the header section. The downside is that it + results in some fixed sized structures being declared for the various types that are used in + miniaudio. The risk here is that these types might be too small for a given platform. This + risk is yours to take and no support will be offered if you enable this option. + */ + #ifndef MA_NO_PTHREAD_IN_HEADER + #include /* Unfortunate #include, but needed for pthread_t, pthread_mutex_t and pthread_cond_t types. */ + typedef pthread_t ma_pthread_t; + typedef pthread_mutex_t ma_pthread_mutex_t; + typedef pthread_cond_t ma_pthread_cond_t; + #else + typedef ma_uintptr ma_pthread_t; + typedef union ma_pthread_mutex_t { char __data[40]; ma_uint64 __alignment; } ma_pthread_mutex_t; + typedef union ma_pthread_cond_t { char __data[48]; ma_uint64 __alignment; } ma_pthread_cond_t; + #endif #endif #if defined(__unix__) @@ -3914,8 +3939,11 @@ typedef ma_uint16 wchar_t; #if defined(__PROSPERO__) #define MA_PROSPERO #endif - #if defined(__NX__) - #define MA_NX + #if defined(__3DS__) + #define MA_3DS + #endif + #if defined(__SWITCH__) || defined(__NX__) + #define MA_SWITCH #endif #if defined(__BEOS__) || defined(__HAIKU__) #define MA_BEOS @@ -3925,12 +3953,13 @@ typedef ma_uint16 wchar_t; #endif #endif -#if defined(__has_c_attribute) - #if __has_c_attribute(fallthrough) - #define MA_FALLTHROUGH [[fallthrough]] - #endif +#if !defined(MA_FALLTHROUGH) && defined(__cplusplus) && __cplusplus >= 201703L + #define MA_FALLTHROUGH [[fallthrough]] #endif -#if !defined(MA_FALLTHROUGH) && defined(__has_attribute) && (defined(__clang__) || defined(__GNUC__)) +#if !defined(MA_FALLTHROUGH) && defined(__STDC_VERSION__) && __STDC_VERSION__ >= 202000L + #define MA_FALLTHROUGH [[fallthrough]] +#endif +#if !defined(MA_FALLTHROUGH) && defined(__has_attribute) #if __has_attribute(fallthrough) #define MA_FALLTHROUGH __attribute__((fallthrough)) #endif @@ -3967,7 +3996,7 @@ typedef ma_uint16 wchar_t; #define MA_NO_INLINE __attribute__((noinline)) #else #define MA_INLINE MA_GNUC_INLINE_HINT - #define MA_NO_INLINE __attribute__((noinline)) + #define MA_NO_INLINE #endif #elif defined(__WATCOMC__) #define MA_INLINE __inline @@ -4153,9 +4182,13 @@ typedef enum MA_CHANNEL_AUX_29 = 49, MA_CHANNEL_AUX_30 = 50, MA_CHANNEL_AUX_31 = 51, + + /* Count. */ + MA_CHANNEL_POSITION_COUNT, + + /* Aliases. */ MA_CHANNEL_LEFT = MA_CHANNEL_FRONT_LEFT, MA_CHANNEL_RIGHT = MA_CHANNEL_FRONT_RIGHT, - MA_CHANNEL_POSITION_COUNT = (MA_CHANNEL_AUX_31 + 1) } _ma_channel_position; /* Do not use `_ma_channel_position` directly. Use `ma_channel` instead. */ typedef enum @@ -4350,7 +4383,7 @@ typedef struct typedef struct { - ma_int32 state; + ma_uint32 state; } ma_lcg; @@ -6569,22 +6602,18 @@ This section contains the APIs for device playback and capture. Here is where yo ************************************************************************************************************************************************************/ #ifndef MA_NO_DEVICE_IO /* Some backends are only supported on certain platforms. */ -#if defined(MA_WIN32) +#if defined(MA_WIN32) && !defined(MA_XBOX) #define MA_SUPPORT_WASAPI #if defined(MA_WIN32_DESKTOP) /* DirectSound and WinMM backends are only supported on desktops. */ #define MA_SUPPORT_DSOUND #define MA_SUPPORT_WINMM - - /* Don't enable JACK here if compiling with Cosmopolitan. It'll be enabled in the Linux section below. */ - #if !defined(__COSMOPOLITAN__) - #define MA_SUPPORT_JACK /* JACK is technically supported on Windows, but I don't know how many people use it in practice... */ - #endif + #define MA_SUPPORT_JACK /* JACK is technically supported on Windows, but I don't know how many people use it in practice... */ #endif #endif #if defined(MA_UNIX) && !defined(MA_ORBIS) && !defined(MA_PROSPERO) #if defined(MA_LINUX) - #if !defined(MA_ANDROID) && !defined(__COSMOPOLITAN__) /* ALSA is not supported on Android. */ + #if !defined(MA_ANDROID) && !defined(MA_EMSCRIPTEN) /* ALSA is not supported on Android. */ #define MA_SUPPORT_ALSA #endif #endif @@ -7426,6 +7455,7 @@ struct ma_context ma_proc snd_pcm_hw_params_set_rate_resample; ma_proc snd_pcm_hw_params_set_rate; ma_proc snd_pcm_hw_params_set_rate_near; + ma_proc snd_pcm_hw_params_set_rate_minmax; ma_proc snd_pcm_hw_params_set_buffer_size_near; ma_proc snd_pcm_hw_params_set_periods_near; ma_proc snd_pcm_hw_params_set_access; @@ -7986,6 +8016,7 @@ struct ma_device /*AAudioStream**/ ma_ptr pStreamPlayback; /*AAudioStream**/ ma_ptr pStreamCapture; ma_mutex rerouteLock; + ma_atomic_bool32 isTearingDown; ma_aaudio_usage usage; ma_aaudio_content_type contentType; ma_aaudio_input_preset inputPreset; @@ -9644,7 +9675,7 @@ Parameters ---------- pBackends (out, optional) A pointer to the buffer that will receive the enabled backends. Set to NULL to retrieve the backend count. Setting - the capacity of the buffer to `MA_BUFFER_COUNT` will guarantee it's large enough for all backends. + the capacity of the buffer to `MA_BACKEND_COUNT` will guarantee it's large enough for all backends. backendCap (in) The capacity of the `pBackends` buffer. @@ -10489,6 +10520,7 @@ typedef struct ma_decoding_backend_vtable** ppCustomDecodingBackendVTables; ma_uint32 customDecodingBackendCount; void* pCustomDecodingBackendUserData; + ma_resampler_config resampling; } ma_resource_manager_config; MA_API ma_resource_manager_config ma_resource_manager_config_init(void); @@ -10816,6 +10848,7 @@ MA_API ma_result ma_node_graph_read_pcm_frames(ma_node_graph* pNodeGraph, void* MA_API ma_uint32 ma_node_graph_get_channels(const ma_node_graph* pNodeGraph); MA_API ma_uint64 ma_node_graph_get_time(const ma_node_graph* pNodeGraph); MA_API ma_result ma_node_graph_set_time(ma_node_graph* pNodeGraph, ma_uint64 globalTime); +MA_API ma_uint32 ma_node_graph_get_processing_size_in_frames(const ma_node_graph* pNodeGraph); @@ -11123,6 +11156,7 @@ typedef struct ma_bool8 isPitchDisabled; /* Pitching can be explicitly disabled with MA_SOUND_FLAG_NO_PITCH to optimize processing. */ ma_bool8 isSpatializationDisabled; /* Spatialization can be explicitly disabled with MA_SOUND_FLAG_NO_SPATIALIZATION. */ ma_uint8 pinnedListenerIndex; /* The index of the listener this node should always use for spatialization. If set to MA_LISTENER_INDEX_CLOSEST the engine will use the closest listener. */ + ma_resampler_config resampling; } ma_engine_node_config; MA_API ma_engine_node_config ma_engine_node_config_init(ma_engine* pEngine, ma_engine_node_type type, ma_uint32 flags); @@ -11137,7 +11171,7 @@ typedef struct ma_uint32 volumeSmoothTimeInPCMFrames; ma_mono_expansion_mode monoExpansionMode; ma_fader fader; - ma_linear_resampler resampler; /* For pitch shift. */ + ma_resampler resampler; /* For pitch shift. */ ma_spatializer spatializer; ma_panner panner; ma_gainer volumeGainer; /* This will only be used if volumeSmoothTimeInPCMFrames is > 0. */ @@ -11193,6 +11227,7 @@ typedef struct ma_uint64 loopPointEndInPCMFrames; ma_sound_end_proc endCallback; /* Fired when the sound reaches the end. Will be fired from the audio thread. Do not restart, uninitialize or otherwise change the state of the sound from here. Instead fire an event or set a variable to indicate to a different thread to change the start of the sound. Will not be fired in response to a scheduled stop with ma_sound_set_stop_time_*(). */ void* pEndCallbackUserData; + ma_resampler_config pitchResampling; #ifndef MA_NO_RESOURCE_MANAGER ma_resource_manager_pipeline_notifications initNotifications; #endif @@ -11211,7 +11246,10 @@ struct ma_sound MA_ATOMIC(4, ma_bool32) atEnd; ma_sound_end_proc endCallback; void* pEndCallbackUserData; - ma_bool8 ownsDataSource; + float* pProcessingCache; /* Will be null if pDataSource is null. */ + ma_uint32 processingCacheFramesRemaining; + ma_uint32 processingCacheCap; + ma_bool8 ownsDataSource; /* We're declaring a resource manager data source object here to save us a malloc when loading a @@ -11255,7 +11293,7 @@ typedef struct ma_log* pLog; /* When set to NULL, will use the context's log. */ ma_uint32 listenerCount; /* Must be between 1 and MA_ENGINE_MAX_LISTENERS. */ ma_uint32 channels; /* The number of channels to use when mixing and spatializing. When set to 0, will use the native channel count of the device. */ - ma_uint32 sampleRate; /* The sample rate. When set to 0 will use the native channel count of the device. */ + ma_uint32 sampleRate; /* The sample rate. When set to 0 will use the native sample rate of the device. */ ma_uint32 periodSizeInFrames; /* If set to something other than 0, updates will always be exactly this size. The underlying device may be a different size, but from the perspective of the mixer that won't matter.*/ ma_uint32 periodSizeInMilliseconds; /* Used if periodSizeInFrames is unset. */ ma_uint32 gainSmoothTimeInFrames; /* The number of frames to interpolate the gain of spatialized sounds across. If set to 0, will use gainSmoothTimeInMilliseconds. */ @@ -11269,6 +11307,8 @@ typedef struct ma_vfs* pResourceManagerVFS; /* A pointer to a pre-allocated VFS object to use with the resource manager. This is ignored if pResourceManager is not NULL. */ ma_engine_process_proc onProcess; /* Fired at the end of each call to ma_engine_read_pcm_frames(). For engine's that manage their own internal device (the default configuration), this will be fired from the audio thread, and you do not need to call ma_engine_read_pcm_frames() manually in order to trigger this. */ void* pProcessUserData; /* User data that's passed into onProcess. */ + ma_resampler_config resourceManagerResampling; /* The resampling config to use with the resource manager. */ + ma_resampler_config pitchResampling; /* The resampling config for the pitch and Doppler effects. You will typically want this to be a fast resampler. For high quality stuff, it's recommended that you pre-resample. */ } ma_engine_config; MA_API ma_engine_config ma_engine_config_init(void); @@ -11298,6 +11338,7 @@ struct ma_engine ma_mono_expansion_mode monoExpansionMode; ma_engine_process_proc onProcess; void* pProcessUserData; + ma_resampler_config pitchResamplingConfig; }; MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEngine); @@ -11358,8 +11399,12 @@ MA_API ma_engine* ma_sound_get_engine(const ma_sound* pSound); MA_API ma_data_source* ma_sound_get_data_source(const ma_sound* pSound); MA_API ma_result ma_sound_start(ma_sound* pSound); MA_API ma_result ma_sound_stop(ma_sound* pSound); -MA_API ma_result ma_sound_stop_with_fade_in_pcm_frames(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. */ -MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. */ +MA_API ma_result ma_sound_stop_with_fade_in_pcm_frames(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. If you want to restart the sound, first reset it with `ma_sound_reset_stop_time_and_fade()`. There are plans to make this less awkward in the future. */ +MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. If you want to restart the sound, first reset it with `ma_sound_reset_stop_time_and_fade()`. There are plans to make this less awkward in the future. */ +MA_API void ma_sound_reset_start_time(ma_sound* pSound); +MA_API void ma_sound_reset_stop_time(ma_sound* pSound); +MA_API void ma_sound_reset_fade(ma_sound* pSound); +MA_API void ma_sound_reset_stop_time_and_fade(ma_sound* pSound); /* Resets fades and scheduled stop time. Does not seek back to the start. */ MA_API void ma_sound_set_volume(ma_sound* pSound, float volume); MA_API float ma_sound_get_volume(const ma_sound* pSound); MA_API void ma_sound_set_pan(ma_sound* pSound, float pan); @@ -11419,11 +11464,11 @@ MA_API ma_bool32 ma_sound_is_looping(const ma_sound* pSound); MA_API ma_bool32 ma_sound_at_end(const ma_sound* pSound); MA_API ma_result ma_sound_seek_to_pcm_frame(ma_sound* pSound, ma_uint64 frameIndex); /* Just a wrapper around ma_data_source_seek_to_pcm_frame(). */ MA_API ma_result ma_sound_seek_to_second(ma_sound* pSound, float seekPointInSeconds); /* Abstraction to ma_sound_seek_to_pcm_frame() */ -MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap); -MA_API ma_result ma_sound_get_cursor_in_pcm_frames(ma_sound* pSound, ma_uint64* pCursor); -MA_API ma_result ma_sound_get_length_in_pcm_frames(ma_sound* pSound, ma_uint64* pLength); -MA_API ma_result ma_sound_get_cursor_in_seconds(ma_sound* pSound, float* pCursor); -MA_API ma_result ma_sound_get_length_in_seconds(ma_sound* pSound, float* pLength); +MA_API ma_result ma_sound_get_data_format(const ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap); +MA_API ma_result ma_sound_get_cursor_in_pcm_frames(const ma_sound* pSound, ma_uint64* pCursor); +MA_API ma_result ma_sound_get_length_in_pcm_frames(const ma_sound* pSound, ma_uint64* pLength); +MA_API ma_result ma_sound_get_cursor_in_seconds(const ma_sound* pSound, float* pCursor); +MA_API ma_result ma_sound_get_length_in_seconds(const ma_sound* pSound, float* pLength); MA_API ma_result ma_sound_set_end_callback(ma_sound* pSound, ma_sound_end_proc callback, void* pUserData); MA_API ma_result ma_sound_group_init(ma_engine* pEngine, ma_uint32 flags, ma_sound_group* pParentGroup, ma_sound_group* pGroup); @@ -11544,16 +11589,22 @@ IMPLEMENTATION #endif #if !defined(MA_WIN32) -#include -#include /* select() (used for ma_sleep()). */ -#include -#endif + #if !defined(MA_NO_THREADING) + #include + #include /* For pthreads. */ + #endif -#ifdef MA_NX -#include /* For nanosleep() */ + #include /* select() (used for ma_sleep()). */ + #include /* For nanosleep() */ + #include #endif -#include /* For fstat(), etc. */ +/* For fstat(), etc. */ +#if defined(MA_XBOX_NXDK) + #include /* Suggestion for NXDK: Add a sys/stat.h wrapper for compatibility. */ +#else + #include +#endif #ifdef MA_EMSCRIPTEN #include @@ -11606,7 +11657,7 @@ IMPLEMENTATION #endif /* Intrinsics Support */ -#if (defined(MA_X64) || defined(MA_X86)) && !defined(__COSMOPOLITAN__) +#if defined(MA_X64) || defined(MA_X86) #if defined(_MSC_VER) && !defined(__clang__) /* MSVC. */ #if _MSC_VER >= 1400 && !defined(MA_NO_SSE2) /* 2005 */ @@ -11861,7 +11912,7 @@ static MA_INLINE ma_bool32 ma_has_neon(void) #endif #ifndef MA_RESTRICT - #if defined(__clang__) || defined(__GNUC__) || defined(_MSC_VER) + #if defined(__clang__) || defined(_MSC_VER) || (defined(__GNUC__) && (__GNUC__ > 2 || (__GNUC__ == 2 && __GNUC_MINOR__ >= 95))) #define MA_RESTRICT __restrict #else #define MA_RESTRICT @@ -11955,7 +12006,7 @@ static void ma_sleep__posix(ma_uint32 milliseconds) (void)milliseconds; MA_ASSERT(MA_FALSE); /* The Emscripten build should never sleep. */ #else - #if (defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 199309L) || defined(MA_NX) + #if (defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 199309L) || defined(MA_SWITCH) struct timespec ts; ts.tv_sec = milliseconds / 1000; ts.tv_nsec = milliseconds % 1000 * 1000000; @@ -11997,7 +12048,7 @@ static MA_INLINE void ma_yield(void) #endif #endif #else - __asm__ __volatile__ ("pause"); + __asm__ __volatile__ ("rep; nop"); #endif #elif (defined(__arm__) && defined(__ARM_ARCH) && __ARM_ARCH >= 7) || defined(_M_ARM64) || (defined(_M_ARM) && _M_ARM >= 7) || defined(__ARM_ARCH_6K__) || defined(__ARM_ARCH_6T2__) /* ARM */ @@ -12020,7 +12071,7 @@ static MA_INLINE unsigned int ma_disable_denormals(void) { unsigned int prevState; - #if defined(_MSC_VER) + #if defined(_MSC_VER) && !defined(MA_XBOX_NXDK) { /* Older versions of Visual Studio don't support the "safe" versions of _controlfp_s(). I don't @@ -12043,7 +12094,7 @@ static MA_INLINE unsigned int ma_disable_denormals(void) } #elif defined(MA_X86) || defined(MA_X64) { - #if defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__) || defined(__COSMOPOLITAN__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ + #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ { prevState = _mm_getcsr(); _mm_setcsr(prevState | MA_MM_DENORMALS_ZERO_MASK | MA_MM_FLUSH_ZERO_MASK); @@ -12067,7 +12118,7 @@ static MA_INLINE unsigned int ma_disable_denormals(void) static MA_INLINE void ma_restore_denormals(unsigned int prevState) { - #if defined(_MSC_VER) + #if defined(_MSC_VER) && !defined(MA_XBOX_NXDK) { /* Older versions of Visual Studio do not support _controlfp_s(). See ma_disable_denormals(). */ #if _MSC_VER <= 1200 @@ -12083,7 +12134,7 @@ static MA_INLINE void ma_restore_denormals(unsigned int prevState) } #elif defined(MA_X86) || defined(MA_X64) { - #if defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__) || defined(__COSMOPOLITAN__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ + #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ { _mm_setcsr(prevState); } @@ -12719,6 +12770,29 @@ MA_API MA_NO_INLINE int ma_strcmp(const char* str1, const char* str2) return ((unsigned char*)str1)[0] - ((unsigned char*)str2)[0]; } +MA_API MA_NO_INLINE int ma_wcscmp(const wchar_t* str1, const wchar_t* str2) +{ + if (str1 == str2) return 0; + + /* These checks differ from the standard implementation. It's not important, but I prefer it just for sanity. */ + if (str1 == NULL) return -1; + if (str2 == NULL) return 1; + + for (;;) { + if (str1[0] == L'\0') { + break; + } + if (str1[0] != str2[0]) { + break; + } + + str1 += 1; + str2 += 1; + } + + return ((unsigned short*)str1)[0] - ((unsigned short*)str2)[0]; +} + MA_API MA_NO_INLINE int ma_strappend(char* dst, size_t dstSize, const char* srcA, const char* srcB) { int result; @@ -12736,6 +12810,22 @@ MA_API MA_NO_INLINE int ma_strappend(char* dst, size_t dstSize, const char* srcA return result; } +MA_API MA_NO_INLINE size_t ma_wcslen(const wchar_t* str) +{ + const wchar_t* end; + + if (str == NULL) { + return 0; + } + + end = str; + while (end[0] != '\0') { + end += 1; + } + + return end - str; +} + MA_API MA_NO_INLINE char* ma_copy_string(const char* src, const ma_allocation_callbacks* pAllocationCallbacks) { size_t sz; @@ -12758,7 +12848,7 @@ MA_API MA_NO_INLINE char* ma_copy_string(const char* src, const ma_allocation_ca MA_API MA_NO_INLINE wchar_t* ma_copy_string_w(const wchar_t* src, const ma_allocation_callbacks* pAllocationCallbacks) { - size_t sz = wcslen(src)+1; + size_t sz = ma_wcslen(src)+1; wchar_t* dst = (wchar_t*)ma_malloc(sz * sizeof(*dst), pAllocationCallbacks); if (dst == NULL) { return NULL; @@ -13189,7 +13279,7 @@ MA_API ma_result ma_fopen(FILE** ppFile, const char* pFilePath, const char* pOpe return MA_INVALID_ARGS; } -#if defined(_MSC_VER) && _MSC_VER >= 1400 +#if (defined(_MSC_VER) && _MSC_VER >= 1400) && !defined(MA_XBOX_NXDK) err = fopen_s(ppFile, pFilePath, pOpenMode); if (err != 0) { return ma_result_from_errno(err); @@ -13231,7 +13321,7 @@ _wfopen() isn't always available in all compilation environments. This can be reviewed as compatibility issues arise. The preference is to use _wfopen_s() and _wfopen() as opposed to the wcsrtombs() fallback, so if you notice your compiler not detecting this properly I'm happy to look at adding support. */ -#if defined(_WIN32) +#if defined(_WIN32) && !defined(MA_XBOX_NXDK) #if defined(_MSC_VER) || defined(__MINGW64__) || (!defined(__STRICT_ANSI__) && !defined(_NO_EXT_KEYS)) #define MA_HAS_WFOPEN #endif @@ -13247,29 +13337,34 @@ MA_API ma_result ma_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_ return MA_INVALID_ARGS; } -#if defined(MA_HAS_WFOPEN) + #if defined(MA_HAS_WFOPEN) { /* Use _wfopen() on Windows. */ - #if defined(_MSC_VER) && _MSC_VER >= 1400 - errno_t err = _wfopen_s(ppFile, pFilePath, pOpenMode); - if (err != 0) { - return ma_result_from_errno(err); + #if defined(_MSC_VER) && _MSC_VER >= 1400 + { + errno_t err = _wfopen_s(ppFile, pFilePath, pOpenMode); + if (err != 0) { + return ma_result_from_errno(err); + } } - #else - *ppFile = _wfopen(pFilePath, pOpenMode); - if (*ppFile == NULL) { - return ma_result_from_errno(errno); + #else + { + *ppFile = _wfopen(pFilePath, pOpenMode); + if (*ppFile == NULL) { + return ma_result_from_errno(errno); + } } - #endif + #endif + (void)pAllocationCallbacks; } -#else - /* - Use fopen() on anything other than Windows. Requires a conversion. This is annoying because fopen() is locale specific. The only real way I can - think of to do this is with wcsrtombs(). Note that wcstombs() is apparently not thread-safe because it uses a static global mbstate_t object for - maintaining state. I've checked this with -std=c89 and it works, but if somebody get's a compiler error I'll look into improving compatibility. - */ + #elif !defined(MA_XBOX_NXDK) && !defined(MA_DOS) /* If your compiler does not support wcsrtombs(), add it here. */ { + /* + Use fopen() on anything other than Windows. Requires a conversion. This is annoying because fopen() is locale specific. The only real way I can + think of to do this is with wcsrtombs(). Note that wcstombs() is apparently not thread-safe because it uses a static global mbstate_t object for + maintaining state. I've checked this with -std=c89 and it works, but if somebody get's a compiler error I'll look into improving compatibility. + */ mbstate_t mbs; size_t lenMB; const wchar_t* pFilePathTemp = pFilePath; @@ -13310,11 +13405,16 @@ MA_API ma_result ma_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_ ma_free(pFilePathMB, pAllocationCallbacks); } + #else + { + /* Getting here means there is no way to open the file with a wide character string. */ + *ppFile = NULL; + } + #endif if (*ppFile == NULL) { return MA_ERROR; } -#endif return MA_SUCCESS; } @@ -13323,7 +13423,7 @@ MA_API ma_result ma_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_ static MA_INLINE void ma_copy_memory_64(void* dst, const void* src, ma_uint64 sizeInBytes) { -#if 0xFFFFFFFFFFFFFFFF <= MA_SIZE_MAX +#if MA_SIZE_MAX > 0xFFFFFFFF MA_COPY_MEMORY(dst, src, (size_t)sizeInBytes); #else while (sizeInBytes > 0) { @@ -13343,7 +13443,7 @@ static MA_INLINE void ma_copy_memory_64(void* dst, const void* src, ma_uint64 si static MA_INLINE void ma_zero_memory_64(void* dst, ma_uint64 sizeInBytes) { -#if 0xFFFFFFFFFFFFFFFF <= MA_SIZE_MAX +#if MA_SIZE_MAX > 0xFFFFFFFF MA_ZERO_MEMORY(dst, (size_t)sizeInBytes); #else while (sizeInBytes > 0) { @@ -13472,6 +13572,18 @@ static ma_result ma_allocation_callbacks_init_copy(ma_allocation_callbacks* pDst Logging **************************************************************************************************************************************************************/ +#ifndef ma_va_copy + #if !defined(_MSC_VER) || _MSC_VER >= 1800 + #if (defined(__GNUC__) && __GNUC__ < 3) + #define ma_va_copy(dst, src) ((dst) = (src)) /* This is untested. Not sure if this is correct for old GCC. */ + #else + #define ma_va_copy(dst, src) va_copy((dst), (src)) + #endif + #else + #define ma_va_copy(dst, src) ((dst) = (src)) + #endif +#endif + MA_API const char* ma_log_level_to_string(ma_uint32 logLevel) { switch (logLevel) @@ -13712,9 +13824,15 @@ MA_API ma_result ma_log_postv(ma_log* pLog, ma_uint32 level, const char* pFormat int length; char pFormattedMessageStack[1024]; char* pFormattedMessageHeap = NULL; + va_list args2; /* First try formatting into our fixed sized stack allocated buffer. If this is too small we'll fallback to a heap allocation. */ - length = vsnprintf(pFormattedMessageStack, sizeof(pFormattedMessageStack), pFormat, args); + ma_va_copy(args2, args); + { + length = vsnprintf(pFormattedMessageStack, sizeof(pFormattedMessageStack), pFormat, args2); + } + va_end(args2); + if (length < 0) { return MA_INVALID_OPERATION; /* An error occurred when trying to convert the buffer. */ } @@ -13755,17 +13873,10 @@ MA_API ma_result ma_log_postv(ma_log* pLog, ma_uint32 level, const char* pFormat char* pFormattedMessage = NULL; va_list args2; - #if _MSC_VER >= 1800 - { - va_copy(args2, args); - } - #else + ma_va_copy(args2, args); { - args2 = args; + formattedLen = ma_vscprintf(&pLog->allocationCallbacks, pFormat, args2); } - #endif - - formattedLen = ma_vscprintf(&pLog->allocationCallbacks, pFormat, args2); va_end(args2); if (formattedLen <= 0) { @@ -13964,7 +14075,7 @@ miniaudio's purposes. #define MA_LCG_A 48271 #define MA_LCG_C 0 -static ma_lcg g_maLCG = {MA_DEFAULT_LCG_SEED}; /* Non-zero initial seed. Use ma_seed() to use an explicit seed. */ +static ma_lcg g_maLCG = {MA_DEFAULT_LCG_SEED}; /* Non-zero initial seed. Use ma_lcg_seed() to use an explicit seed. */ static MA_INLINE void ma_lcg_seed(ma_lcg* pLCG, ma_int32 seed) { @@ -14013,7 +14124,7 @@ static MA_INLINE ma_int32 ma_lcg_rand_range_s32(ma_lcg* pLCG, ma_int32 lo, ma_in } - +#if 0 /* Currently unused. */ static MA_INLINE void ma_seed(ma_int32 seed) { ma_lcg_seed(&g_maLCG, seed); @@ -14038,6 +14149,7 @@ static MA_INLINE float ma_rand_f32(void) { return ma_lcg_rand_f32(&g_maLCG); } +#endif static MA_INLINE float ma_rand_range_f32(float lo, float hi) { @@ -14097,6 +14209,7 @@ Atomics **************************************************************************************************************************************************************/ /* c89atomic.h begin */ #ifndef ma_atomic_h +#define ma_atomic_h #if defined(__cplusplus) extern "C" { #endif @@ -14108,11 +14221,63 @@ extern "C" { #endif #endif typedef int ma_atomic_memory_order; -#define MA_ATOMIC_HAS_8 -#define MA_ATOMIC_HAS_16 -#define MA_ATOMIC_HAS_32 -#define MA_ATOMIC_HAS_64 -#if (defined(_MSC_VER) ) || defined(__WATCOMC__) || defined(__DMC__) +#if !defined(MA_ATOMIC_MODERN_MSVC) && \ + !defined(MA_ATOMIC_LEGACY_MSVC) && \ + !defined(MA_ATOMIC_LEGACY_MSVC_ASM) && \ + !defined(MA_ATOMIC_MODERN_GCC) && \ + !defined(MA_ATOMIC_LEGACY_GCC) && \ + !defined(MA_ATOMIC_LEGACY_GCC_ASM) + #if defined(_MSC_VER) || defined(__WATCOMC__) || defined(__DMC__) || defined(__BORLANDC__) + #if (defined(_MSC_VER) && _MSC_VER > 1600) + #define MA_ATOMIC_MODERN_MSVC + #else + #if defined(MA_X64) + #define MA_ATOMIC_LEGACY_MSVC + #else + #define MA_ATOMIC_LEGACY_MSVC_ASM + #endif + #endif + #elif (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 7))) || defined(__clang__) + #define MA_ATOMIC_MODERN_GCC + #else + #if defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 1)) + #define MA_ATOMIC_LEGACY_GCC + #else + #define MA_ATOMIC_LEGACY_GCC_ASM + #endif + #endif +#endif +#if defined(MA_ATOMIC_MODERN_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC) + #include + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + #define MA_ATOMIC_MSVC_ARM_INTRINSIC_NORETURN(dst, src, order, intrin, ma_atomicType, msvcType) \ + switch (order) \ + { \ + case ma_atomic_memory_order_relaxed: \ + { \ + intrin##_nf((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + case ma_atomic_memory_order_consume: \ + case ma_atomic_memory_order_acquire: \ + { \ + intrin##_acq((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + case ma_atomic_memory_order_release: \ + { \ + intrin##_rel((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + case ma_atomic_memory_order_acq_rel: \ + case ma_atomic_memory_order_seq_cst: \ + default: \ + { \ + intrin((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + } #define MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, intrin, ma_atomicType, msvcType) \ ma_atomicType result; \ switch (order) \ @@ -14138,720 +14303,1501 @@ typedef int ma_atomic_memory_order; } break; \ } \ return result; - #define MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, expected, desired, order, intrin, ma_atomicType, msvcType) \ + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, 1, order, _InterlockedExchange, ma_atomic_flag, long); + } + #else + { + (void)order; + return (ma_atomic_flag)_InterlockedExchange((volatile long*)dst, (long)1); + } + #endif + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC_NORETURN(dst, 0, order, _InterlockedExchange, ma_atomic_flag, long); + } + #else + { + (void)order; + _InterlockedExchange((volatile long*)dst, (long)0); + } + #endif + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + (void)order; + return (ma_uint32)_InterlockedCompareExchange((volatile long*)dst, 0, 0); + } +#endif +#if defined(MA_ATOMIC_LEGACY_MSVC_ASM) + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + ma_atomic_flag result = 0; + (void)order; + __asm { + mov ecx, dst + mov eax, 1 + xchg [ecx], eax + mov result, eax + } + return result; + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov dword ptr [esi], 0 + } + } else { + __asm { + mov esi, dst + mov eax, 0 + xchg [esi], eax + } + } + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + ma_atomic_flag result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov eax, [esi] + mov result, eax + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov eax, [esi] + lock add dword ptr [esp], 0 + mov result, eax + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov eax, [esi] + mov result, eax + lock add dword ptr [esp], 0 + } + } + return result; + } +#endif +#if defined(MA_ATOMIC_MODERN_GCC) + #define ma_atomic_memory_order_relaxed __ATOMIC_RELAXED + #define ma_atomic_memory_order_consume __ATOMIC_CONSUME + #define ma_atomic_memory_order_acquire __ATOMIC_ACQUIRE + #define ma_atomic_memory_order_release __ATOMIC_RELEASE + #define ma_atomic_memory_order_acq_rel __ATOMIC_ACQ_REL + #define ma_atomic_memory_order_seq_cst __ATOMIC_SEQ_CST + typedef ma_uint32 ma_atomic_flag; + #define ma_atomic_flag_test_and_set_explicit(dst, order) __atomic_exchange_n(dst, 1, order) + #define ma_atomic_flag_clear_explicit(dst, order) __atomic_store_n(dst, 0, order) + #define ma_atomic_flag_load_explicit(dst, order) __atomic_load_n(dst, order) +#endif +#if defined(MA_ATOMIC_LEGACY_GCC) + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, 1); + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + if (order > ma_atomic_memory_order_release) { + __sync_synchronize(); + } + __sync_lock_release(dst); + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + (void)order; + return __sync_val_compare_and_swap((ma_atomic_flag*)dst, 0, 0); + } +#endif +#if defined(MA_ATOMIC_LEGACY_GCC_ASM) + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + #if defined(MA_X86) + #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addl $0, (%%esp)" ::: "memory") + #elif defined(MA_X64) + #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addq $0, (%%rsp)" ::: "memory") + #else + #error Unsupported architecture. + #endif + #define MA_ATOMIC_XCHG_GCC_X86(instructionSizeSuffix, result, dst, src) \ + __asm__ __volatile__( \ + "xchg"instructionSizeSuffix" %0, %1" \ + : "=r"(result), \ + "=m"(*dst) \ + : "0"(src), \ + "m"(*dst) \ + : "memory" \ + ) + #define MA_ATOMIC_LOAD_RELAXED_GCC_X86(instructionSizeSuffix, result, dst) \ + __asm__ __volatile__( \ + "mov"instructionSizeSuffix" %1, %0" \ + : "=r"(result) \ + : "m"(*dst) \ + ) + #define MA_ATOMIC_LOAD_RELEASE_GCC_X86(instructionSizeSuffix, result, dst) \ + ma_atomic_thread_fence(ma_atomic_memory_order_release); \ + __asm__ __volatile__( \ + "mov"instructionSizeSuffix" %1, %0" \ + : "=r"(result) \ + : "m"(*dst) \ + : "memory" \ + ) + #define MA_ATOMIC_LOAD_SEQ_CST_GCC_X86(instructionSizeSuffix, result, dst) \ + ma_atomic_thread_fence(ma_atomic_memory_order_seq_cst); \ + __asm__ __volatile__( \ + "mov"instructionSizeSuffix" %1, %0" \ + : "=r"(result) \ + : "m"(*dst) \ + : "memory" \ + ); \ + ma_atomic_thread_fence(ma_atomic_memory_order_seq_cst) + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + ma_atomic_flag result; + #if defined(MA_X86) || defined(MA_X64) + { + (void)order; + MA_ATOMIC_XCHG_GCC_X86("l", result, dst, 1); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__( + "movl $0, %0" + : "=m"(*dst) + ); + } else if (order == ma_atomic_memory_order_release) { + __asm__ __volatile__( + "movl $0, %0" + : "=m"(*dst) + : + : "memory" + ); + } else { + ma_atomic_flag tmp = 0; + __asm__ __volatile__( + "xchgl %0, %1" + : "=r"(tmp), + "=m"(*dst) + : "0"(tmp), + "m"(*dst) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. + } + #endif + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_atomic_flag result; + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("l", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("l", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("l", result, dst); + } + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } +#endif +#define ma_atomic_flag_test_and_set(dst) ma_atomic_flag_test_and_set_explicit(dst, ma_atomic_memory_order_acquire) +#define ma_atomic_flag_clear(dst) ma_atomic_flag_clear_explicit(dst, ma_atomic_memory_order_release) +typedef ma_atomic_flag ma_atomic_spinlock; +static MA_INLINE void ma_atomic_spinlock_lock(volatile ma_atomic_spinlock* pSpinlock) +{ + for (;;) { + if (ma_atomic_flag_test_and_set_explicit(pSpinlock, ma_atomic_memory_order_acquire) == 0) { + break; + } + while (ma_atomic_flag_load_explicit(pSpinlock, ma_atomic_memory_order_relaxed) == 1) { + } + } +} +static MA_INLINE void ma_atomic_spinlock_unlock(volatile ma_atomic_spinlock* pSpinlock) +{ + ma_atomic_flag_clear_explicit(pSpinlock, ma_atomic_memory_order_release); +} +ma_atomic_spinlock ma_atomic_global_lock; +#if defined(MA_ATOMIC_MODERN_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC_ASM) || defined(MA_ATOMIC_LEGACY_GCC) || defined(MA_ATOMIC_LEGACY_GCC_ASM) + #if defined(MA_X64) || (defined(MA_X86) && ((defined(__GNUC__) && defined(__i486__)) || (defined(_M_IX86) && _M_IX86 >= 400))) + #if defined(MA_ATOMIC_LEGACY_MSVC) && defined(MA_X64) + #else + #define MA_ATOMIC_IS_LOCK_FREE_8 1 + #define MA_ATOMIC_IS_LOCK_FREE_16 1 + #endif + #define MA_ATOMIC_IS_LOCK_FREE_32 1 + #if defined(MA_X64) || (defined(MA_X86) && ((defined(__GNUC__) && defined(__i586__)) || (defined(_M_IX86) && _M_IX86 >= 500))) + #define MA_ATOMIC_IS_LOCK_FREE_64 1 + #else + #endif + #else + #endif + #if defined(MA_ARM32) || defined(MA_ARM64) + #define MA_ATOMIC_IS_LOCK_FREE_8 1 + #define MA_ATOMIC_IS_LOCK_FREE_16 1 + #define MA_ATOMIC_IS_LOCK_FREE_32 1 + #if defined(MA_ARM64) || defined(__ARM_ARCH_7A__) || defined(__ARM_ARCH_7R__) || defined(__ARM_ARCH_6K__) || defined(__ARM_ARCH_6Z__) || defined(__ARM_ARCH_6ZK__) + #define MA_ATOMIC_IS_LOCK_FREE_64 1 + #endif + #endif + #if defined(MA_ATOMIC_PPC32) || defined(MA_ATOMIC_PPC64) + #if (defined(__GNUC__) && (__GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 7))) && !defined(__clang__) + #else + #define MA_ATOMIC_IS_LOCK_FREE_8 1 + #define MA_ATOMIC_IS_LOCK_FREE_16 1 + #endif + #define MA_ATOMIC_IS_LOCK_FREE_32 1 + #if defined(MA_ATOMIC_PPC64) + #define MA_ATOMIC_IS_LOCK_FREE_64 1 + #endif + #endif + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_8(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + return 1; + #else + return 0; + #endif + } + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_16(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + return 1; + #else + return 0; + #endif + } + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_32(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + return 1; + #else + return 0; + #endif + } + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_64(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + return 1; + #else + return 0; + #endif + } +#endif +#define MA_ATOMIC_COMPARE_AND_SWAP_LOCK(sizeInBits, dst, expected, replacement) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *dst; \ + if (result == expected) { \ + *dst = replacement; \ + } \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_LOAD_EXPLICIT_LOCK(sizeInBits, ptr, order) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *ptr; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_STORE_EXPLICIT_LOCK(sizeInBits, dst, src, order) \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + *dst = src; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock) +#define MA_ATOMIC_STORE_EXPLICIT_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, src) != oldValue); \ + (void)order +#define MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *dst; \ + *dst = src; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, src) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_ADD_LOCK(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *dst; \ + *dst += src; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_FETCH_ADD_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = oldValue + src; \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_AND_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = (ma_uint##sizeInBits)(oldValue & src); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_OR_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = (ma_uint##sizeInBits)(oldValue | src); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_XOR_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = (ma_uint##sizeInBits)(oldValue ^ src); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#if defined(MA_ATOMIC_MODERN_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC) + #define MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, expected, replacement, order, intrin, ma_atomicType, msvcType) \ ma_atomicType result; \ switch (order) \ { \ case ma_atomic_memory_order_relaxed: \ { \ - result = (ma_atomicType)intrin##_nf((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin##_nf((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ case ma_atomic_memory_order_consume: \ case ma_atomic_memory_order_acquire: \ { \ - result = (ma_atomicType)intrin##_acq((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin##_acq((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ case ma_atomic_memory_order_release: \ { \ - result = (ma_atomicType)intrin##_rel((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin##_rel((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ case ma_atomic_memory_order_acq_rel: \ case ma_atomic_memory_order_seq_cst: \ default: \ { \ - result = (ma_atomicType)intrin((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ } \ return result; - #define ma_atomic_memory_order_relaxed 0 - #define ma_atomic_memory_order_consume 1 - #define ma_atomic_memory_order_acquire 2 - #define ma_atomic_memory_order_release 3 - #define ma_atomic_memory_order_acq_rel 4 - #define ma_atomic_memory_order_seq_cst 5 - #if _MSC_VER < 1600 && defined(MA_X86) - #define MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + #define ma_atomic_compare_and_swap_8( dst, expected, replacement) (ma_uint8 )_InterlockedCompareExchange8((volatile char*)dst, (char)replacement, (char)expected) + #else + static MA_INLINE ma_uint8 __stdcall ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); + } + #endif + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + #define ma_atomic_compare_and_swap_16(dst, expected, replacement) (ma_uint16)_InterlockedCompareExchange16((volatile short*)dst, (short)replacement, (short)expected) + #else + static MA_INLINE ma_uint16 __stdcall ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); + } #endif - #if _MSC_VER < 1600 - #undef MA_ATOMIC_HAS_8 - #undef MA_ATOMIC_HAS_16 + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + #define ma_atomic_compare_and_swap_32(dst, expected, replacement) (ma_uint32)_InterlockedCompareExchange((volatile long*)dst, (long)replacement, (long)expected) + #else + static MA_INLINE ma_uint32 __stdcall ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); + } #endif - #if !defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #include + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + #define ma_atomic_compare_and_swap_64(dst, expected, replacement) (ma_uint64)_InterlockedCompareExchange64((volatile ma_int64*)dst, (ma_int64)replacement, (ma_int64)expected) + #else + static MA_INLINE ma_uint64 __stdcall ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); + } #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 desired) + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + #if defined(MA_ARM) { - ma_uint8 result = 0; - __asm { - mov ecx, dst - mov al, expected - mov dl, desired - lock cmpxchg [ecx], dl - mov result, al - } - return result; + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange8, ma_uint8, char); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 desired) + #else { - ma_uint16 result = 0; - __asm { - mov ecx, dst - mov ax, expected - mov dx, desired - lock cmpxchg [ecx], dx - mov result, ax - } - return result; + (void)order; + return ma_atomic_compare_and_swap_8((volatile ma_uint8*)ptr, 0, 0); } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, ptr, order); + } #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 desired) + } + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + #if defined(MA_ARM) { - ma_uint32 result = 0; - __asm { - mov ecx, dst - mov eax, expected - mov edx, desired - lock cmpxchg [ecx], edx - mov result, eax - } - return result; + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange16, ma_uint16, short); } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 desired) + #else { - ma_uint32 resultEAX = 0; - ma_uint32 resultEDX = 0; - __asm { - mov esi, dst - mov eax, dword ptr expected - mov edx, dword ptr expected + 4 - mov ebx, dword ptr desired - mov ecx, dword ptr desired + 4 - lock cmpxchg8b qword ptr [esi] - mov resultEAX, eax - mov resultEDX, edx - } - return ((ma_uint64)resultEDX << 32) | resultEAX; + (void)order; + return ma_atomic_compare_and_swap_16((volatile ma_uint16*)ptr, 0, 0); } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, ptr, order); + } #endif - #else - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_compare_and_swap_8( dst, expected, desired) (ma_uint8 )_InterlockedCompareExchange8((volatile char*)dst, (char)desired, (char)expected) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_compare_and_swap_16(dst, expected, desired) (ma_uint16)_InterlockedCompareExchange16((volatile short*)dst, (short)desired, (short)expected) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_compare_and_swap_32(dst, expected, desired) (ma_uint32)_InterlockedCompareExchange((volatile long*)dst, (long)desired, (long)expected) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_compare_and_swap_64(dst, expected, desired) (ma_uint64)_InterlockedCompareExchange64((volatile ma_int64*)dst, (ma_int64)desired, (ma_int64)expected) - #endif - #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + } + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + #if defined(MA_ARM) { - ma_uint8 result = 0; - (void)order; - __asm { - mov ecx, dst - mov al, src - lock xchg [ecx], al - mov result, al - } - return result; + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange, ma_uint32, long); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + #else { - ma_uint16 result = 0; (void)order; - __asm { - mov ecx, dst - mov ax, src - lock xchg [ecx], ax - mov result, ax - } - return result; + return ma_atomic_compare_and_swap_32((volatile ma_uint32*)ptr, 0, 0); } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, ptr, order); + } #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + } + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange64, ma_uint64, long long); + } + #else { - ma_uint32 result = 0; (void)order; - __asm { - mov ecx, dst - mov eax, src - lock xchg [ecx], eax - mov result, eax - } - return result; + return ma_atomic_compare_and_swap_64((volatile ma_uint64*)ptr, 0, 0); } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(64, ptr, order); + } #endif - #else - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange8, ma_uint8, char); + } #else + { (void)order; return (ma_uint8)_InterlockedExchange8((volatile char*)dst, (char)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange16, ma_uint16, short); + } #else + { (void)order; return (ma_uint16)_InterlockedExchange16((volatile short*)dst, (short)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange, ma_uint32, long); + } #else + { (void)order; return (ma_uint32)_InterlockedExchange((volatile long*)dst, (long)src); - #endif } - #endif - #if defined(MA_ATOMIC_HAS_64) && defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange64, ma_uint64, long long); - #else - (void)order; - return (ma_uint64)_InterlockedExchange64((volatile long long*)dst, (long long)src); #endif - } + } #else - #endif - #endif - #if defined(MA_ATOMIC_HAS_64) && !defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - do { - oldValue = *dst; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); } - #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + #if defined(MA_32BIT) { - ma_uint8 result = 0; - (void)order; - __asm { - mov ecx, dst - mov al, src - lock xadd [ecx], al - mov result, al - } - return result; + MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(64, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + #else { - ma_uint16 result = 0; - (void)order; - __asm { - mov ecx, dst - mov ax, src - lock xadd [ecx], ax - mov result, ax + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange64, ma_uint64, long long); } - return result; + #else + { + (void)order; + return (ma_uint64)_InterlockedExchange64((volatile long long*)dst, (long long)src); + } + #endif } + #endif + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + #if defined(MA_ARM) { - ma_uint32 result = 0; - (void)order; - __asm { - mov ecx, dst - mov eax, src - lock xadd [ecx], eax - mov result, eax - } - return result; - } - #endif - #else - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - #if defined(MA_ARM) MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd8, ma_uint8, char); + } #else + { (void)order; return (ma_uint8)_InterlockedExchangeAdd8((volatile char*)dst, (char)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd16, ma_uint16, short); + } #else + { (void)order; return (ma_uint16)_InterlockedExchangeAdd16((volatile short*)dst, (short)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd, ma_uint32, long); + } #else + { (void)order; return (ma_uint32)_InterlockedExchangeAdd((volatile long*)dst, (long)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_64) && defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + #if defined(MA_32BIT) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd64, ma_uint64, long long); + MA_ATOMIC_FETCH_ADD_CAS(64, dst, src, order); + } #else - (void)order; - return (ma_uint64)_InterlockedExchangeAdd64((volatile long long*)dst, (long long)src); - #endif + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd64, ma_uint64, long long); + } + #else + { + (void)order; + return (ma_uint64)_InterlockedExchangeAdd64((volatile long long*)dst, (long long)src); + } + #endif } + #endif + } #else - #endif - #endif - #if defined(MA_ATOMIC_HAS_64) && !defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue + src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); } - #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - static MA_INLINE void __stdcall ma_atomic_thread_fence(ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_8(dst, (ma_uint8)(-(ma_int8)src), order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_16(dst, (ma_uint16)(-(ma_int16)src), order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_32(dst, (ma_uint32)(-(ma_int32)src), order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_64(dst, (ma_uint64)(-(ma_int64)src), order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) { - (void)order; - __asm { - lock add [esp], 0 - } + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd8, ma_uint8, char); } - #else - #if defined(MA_X64) - #define ma_atomic_thread_fence(order) __faststorefence(), (void)order - #elif defined(MA_ARM64) - #define ma_atomic_thread_fence(order) __dmb(_ARM64_BARRIER_ISH), (void)order #else - static MA_INLINE void ma_atomic_thread_fence(ma_atomic_memory_order order) - { - volatile ma_uint32 barrier = 0; - ma_atomic_fetch_add_explicit_32(&barrier, 0, order); - } - #endif - #endif - #define ma_atomic_compiler_fence() ma_atomic_thread_fence(ma_atomic_memory_order_seq_cst) - #define ma_atomic_signal_fence(order) ma_atomic_thread_fence(order) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) { + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange8, ma_uint8, char); + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd16, ma_uint16, short); + } #else - (void)order; - return ma_atomic_compare_and_swap_8((volatile ma_uint8*)ptr, 0, 0); + { + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); + } #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd, ma_uint32, long); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) + #else { + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange16, ma_uint16, short); + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd64, ma_uint64, long long); + } #else - (void)order; - return ma_atomic_compare_and_swap_16((volatile ma_uint16*)ptr, 0, 0); + { + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); + } #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr8, ma_uint8, char); } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) + #else { + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange, ma_uint32, long); + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr16, ma_uint16, short); + } #else - (void)order; - return ma_atomic_compare_and_swap_32((volatile ma_uint32*)ptr, 0, 0); + { + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); + } #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr, ma_uint32, long); } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) + #else { + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange64, ma_uint64, long long); + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr64, ma_uint64, long long); + } #else - (void)order; - return ma_atomic_compare_and_swap_64((volatile ma_uint64*)ptr, 0, 0); + { + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); + } #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor8, ma_uint8, char); } - #endif - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue - src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #else + { + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue - src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor16, ma_uint16, short); } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #else + { + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor, ma_uint32, long); } - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + #else { + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd8, ma_uint8, char); + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor64, ma_uint64, long long); + } #else - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue & src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + { + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); + } #endif + } + #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) + #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) + #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) + #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) + #if defined(MA_X64) + #define ma_atomic_thread_fence(order) __faststorefence(), (void)order + #elif defined(MA_ARM64) + #define ma_atomic_thread_fence(order) __dmb(_ARM64_BARRIER_ISH), (void)order + #else + static MA_INLINE void ma_atomic_thread_fence(ma_atomic_memory_order order) + { + volatile ma_uint32 barrier = 0; + ma_atomic_fetch_add_explicit_32(&barrier, 0, order); } #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + #define ma_atomic_signal_fence(order) _ReadWriteBarrier(), (void)order +#endif +#if defined(MA_ATOMIC_LEGACY_MSVC_ASM) + static MA_INLINE ma_uint8 __stdcall ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd16, ma_uint16, short); + ma_uint8 result = 0; + __asm { + mov ecx, dst + mov al, expected + mov dl, replacement + lock cmpxchg [ecx], dl + mov result, al + } + return result; + } #else - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue & src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd, ma_uint32, long); + ma_uint16 result = 0; + __asm { + mov ecx, dst + mov ax, expected + mov dx, replacement + lock cmpxchg [ecx], dx + mov result, ax + } + return result; + } #else - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd64, ma_uint64, long long); + ma_uint32 result = 0; + __asm { + mov ecx, dst + mov eax, expected + mov edx, replacement + lock cmpxchg [ecx], edx + mov result, eax + } + return result; + } #else - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); } - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor8, ma_uint8, char); + ma_uint32 resultEAX = 0; + ma_uint32 resultEDX = 0; + __asm { + mov esi, dst + mov eax, dword ptr expected + mov edx, dword ptr expected + 4 + mov ebx, dword ptr replacement + mov ecx, dword ptr replacement + 4 + lock cmpxchg8b qword ptr [esi] + mov resultEAX, eax + mov resultEDX, edx + } + return ((ma_uint64)resultEDX << 32) | resultEAX; + } #else - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor16, ma_uint16, short); + ma_uint8 result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov al, [esi] + mov result, al + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov al, [esi] + lock add dword ptr [esp], 0 + mov result, al + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov al, [esi] + mov result, al + lock add dword ptr [esp], 0 + } + } + return result; + } #else - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, dst, order); + } #endif + } + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + ma_uint16 result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov ax, [esi] + mov result, ax + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov ax, [esi] + lock add dword ptr [esp], 0 + mov result, ax + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov ax, [esi] + mov result, ax + lock add dword ptr [esp], 0 + } + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor, ma_uint32, long); + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, dst, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + ma_uint32 result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov eax, [esi] + mov result, eax + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov eax, [esi] + lock add dword ptr [esp], 0 + mov result, eax + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov eax, [esi] + mov result, eax + lock add dword ptr [esp], 0 + } + } + return result; + } #else - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, dst, order); + } #endif + } + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* dst, ma_atomic_memory_order order) + { + (void)order; + return ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, 0, 0); + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov al, src + mov [esi], al + } + } else { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + __asm { + mov esi, dst + mov al, src + xchg [esi], al + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov ax, src + mov [esi], ax + } + } else { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + __asm { + mov esi, dst + mov ax, src + xchg [esi], ax + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif + } + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov eax, src + mov [esi], eax + } + } else { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + __asm { + mov esi, dst + mov eax, src + xchg [esi], eax + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif + } + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor64, ma_uint64, long long); + MA_ATOMIC_STORE_EXPLICIT_CAS(64, dst, src, order); + } #else - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(64, dst, src, order); + } #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + ma_uint8 result = 0; + (void)order; + __asm { + mov ecx, dst + mov al, src + lock xchg [ecx], al + mov result, al + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr8, ma_uint8, char); + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + ma_uint16 result = 0; + (void)order; + __asm { + mov ecx, dst + mov ax, src + lock xchg [ecx], ax + mov result, ax + } + return result; + } #else - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue | src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + ma_uint32 result = 0; (void)order; - return oldValue; + __asm { + mov ecx, dst + mov eax, src + xchg [ecx], eax + mov result, eax + } + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); + } #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(64, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr16, ma_uint16, short); + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + ma_uint8 result = 0; + (void)order; + __asm { + mov ecx, dst + mov al, src + lock xadd [ecx], al + mov result, al + } + return result; + } #else - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue | src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + ma_uint16 result = 0; (void)order; - return oldValue; + __asm { + mov ecx, dst + mov ax, src + lock xadd [ecx], ax + mov result, ax + } + return result; + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); + } #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + ma_uint32 result = 0; + (void)order; + __asm { + mov ecx, dst + mov eax, src + lock xadd [ecx], eax + mov result, eax + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr, ma_uint32, long); + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + MA_ATOMIC_FETCH_ADD_CAS(64, dst, src, order); + } #else - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); + } #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + ma_uint8 result = 0; + (void)order; + __asm { + mov ecx, dst + mov al, src + neg al + lock xadd [ecx], al + mov result, al + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr64, ma_uint64, long long); + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, (ma_uint8)(-(ma_int8)src), order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + ma_uint16 result = 0; + (void)order; + __asm { + mov ecx, dst + mov ax, src + neg ax + lock xadd [ecx], ax + mov result, ax + } + return result; + } #else - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, (ma_uint16)(-(ma_int16)src), order); + } + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + ma_uint32 result = 0; (void)order; - return oldValue; + __asm { + mov ecx, dst + mov eax, src + neg eax + lock xadd [ecx], eax + mov result, eax + } + return result; + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, (ma_uint32)(-(ma_int32)src), order); + } #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_ADD_CAS(64, dst, (ma_uint64)(-(ma_int64)src), order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); + } + static MA_INLINE void __stdcall ma_atomic_thread_fence(ma_atomic_memory_order order) + { + (void)order; + __asm { + lock add dword ptr [esp], 0 } - #endif - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_test_and_set_explicit_8( dst, order) ma_atomic_exchange_explicit_8 (dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_test_and_set_explicit_16(dst, order) ma_atomic_exchange_explicit_16(dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_test_and_set_explicit_32(dst, order) ma_atomic_exchange_explicit_32(dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_test_and_set_explicit_64(dst, order) ma_atomic_exchange_explicit_64(dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_clear_explicit_8( dst, order) ma_atomic_store_explicit_8 (dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_clear_explicit_16(dst, order) ma_atomic_store_explicit_16(dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_clear_explicit_32(dst, order) ma_atomic_store_explicit_32(dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_clear_explicit_64(dst, order) ma_atomic_store_explicit_64(dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_8) - typedef ma_uint8 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(ptr, order) (ma_bool32)ma_atomic_test_and_set_explicit_8(ptr, order) - #define ma_atomic_flag_clear_explicit(ptr, order) ma_atomic_clear_explicit_8(ptr, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_8(ptr, order) - #else - typedef ma_uint32 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(ptr, order) (ma_bool32)ma_atomic_test_and_set_explicit_32(ptr, order) - #define ma_atomic_flag_clear_explicit(ptr, order) ma_atomic_clear_explicit_32(ptr, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_32(ptr, order) - #endif -#elif defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 7))) + } + #define ma_atomic_signal_fence(order) __asm {}; (void)order +#endif +#if defined(MA_ATOMIC_MODERN_GCC) #define MA_ATOMIC_HAS_NATIVE_COMPARE_EXCHANGE - #define MA_ATOMIC_HAS_NATIVE_IS_LOCK_FREE - #define ma_atomic_memory_order_relaxed __ATOMIC_RELAXED - #define ma_atomic_memory_order_consume __ATOMIC_CONSUME - #define ma_atomic_memory_order_acquire __ATOMIC_ACQUIRE - #define ma_atomic_memory_order_release __ATOMIC_RELEASE - #define ma_atomic_memory_order_acq_rel __ATOMIC_ACQ_REL - #define ma_atomic_memory_order_seq_cst __ATOMIC_SEQ_CST - #define ma_atomic_compiler_fence() __asm__ __volatile__("":::"memory") #define ma_atomic_thread_fence(order) __atomic_thread_fence(order) #define ma_atomic_signal_fence(order) __atomic_signal_fence(order) #define ma_atomic_is_lock_free_8(ptr) __atomic_is_lock_free(1, ptr) #define ma_atomic_is_lock_free_16(ptr) __atomic_is_lock_free(2, ptr) #define ma_atomic_is_lock_free_32(ptr) __atomic_is_lock_free(4, ptr) #define ma_atomic_is_lock_free_64(ptr) __atomic_is_lock_free(8, ptr) - #define ma_atomic_test_and_set_explicit_8( dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_test_and_set_explicit_16(dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_test_and_set_explicit_32(dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_test_and_set_explicit_64(dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_clear_explicit_8( dst, order) __atomic_store_n(dst, 0, order) - #define ma_atomic_clear_explicit_16(dst, order) __atomic_store_n(dst, 0, order) - #define ma_atomic_clear_explicit_32(dst, order) __atomic_store_n(dst, 0, order) - #define ma_atomic_clear_explicit_64(dst, order) __atomic_store_n(dst, 0, order) #define ma_atomic_store_explicit_8( dst, src, order) __atomic_store_n(dst, src, order) #define ma_atomic_store_explicit_16(dst, src, order) __atomic_store_n(dst, src, order) #define ma_atomic_store_explicit_32(dst, src, order) __atomic_store_n(dst, src, order) @@ -14864,14 +15810,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_explicit_16(dst, src, order) __atomic_exchange_n(dst, src, order) #define ma_atomic_exchange_explicit_32(dst, src, order) __atomic_exchange_n(dst, src, order) #define ma_atomic_exchange_explicit_64(dst, src, order) __atomic_exchange_n(dst, src, order) - #define ma_atomic_compare_exchange_strong_explicit_8( dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_strong_explicit_16(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_strong_explicit_32(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_strong_explicit_64(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_8( dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_16(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_32(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_64(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) #define ma_atomic_fetch_add_explicit_8( dst, src, order) __atomic_fetch_add(dst, src, order) #define ma_atomic_fetch_add_explicit_16(dst, src, order) __atomic_fetch_add(dst, src, order) #define ma_atomic_fetch_add_explicit_32(dst, src, order) __atomic_fetch_add(dst, src, order) @@ -14892,19 +15838,19 @@ typedef int ma_atomic_memory_order; #define ma_atomic_fetch_and_explicit_16(dst, src, order) __atomic_fetch_and(dst, src, order) #define ma_atomic_fetch_and_explicit_32(dst, src, order) __atomic_fetch_and(dst, src, order) #define ma_atomic_fetch_and_explicit_64(dst, src, order) __atomic_fetch_and(dst, src, order) - static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 desired) + static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } - static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 desired) + static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } - static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 desired) + static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } #if defined(__clang__) @@ -14913,636 +15859,1134 @@ typedef int ma_atomic_memory_order; #pragma clang diagnostic ignored "-Watomic-alignment" #endif #endif - static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 desired) + static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } #if defined(__clang__) #pragma clang diagnostic pop #endif - typedef ma_uint8 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(dst, order) (ma_bool32)__atomic_test_and_set(dst, order) - #define ma_atomic_flag_clear_explicit(dst, order) __atomic_clear(dst, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_8(ptr, order) -#else - #define ma_atomic_memory_order_relaxed 1 - #define ma_atomic_memory_order_consume 2 - #define ma_atomic_memory_order_acquire 3 - #define ma_atomic_memory_order_release 4 - #define ma_atomic_memory_order_acq_rel 5 - #define ma_atomic_memory_order_seq_cst 6 - #define ma_atomic_compiler_fence() __asm__ __volatile__("":::"memory") - #if defined(__GNUC__) +#endif +#if defined(MA_ATOMIC_LEGACY_GCC) || defined(MA_ATOMIC_LEGACY_GCC_ASM) + #define ma_atomic_signal_fence(order) __asm__ __volatile__("":::"memory") + #if defined(MA_ATOMIC_LEGACY_GCC) #define ma_atomic_thread_fence(order) __sync_synchronize(), (void)order + static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + return __sync_val_compare_and_swap(dst, expected, replacement); + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + return __sync_val_compare_and_swap(dst, expected, replacement); + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + return __sync_val_compare_and_swap(dst, expected, replacement); + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + return __sync_val_compare_and_swap(dst, expected, replacement); + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return ma_atomic_compare_and_swap_8((ma_uint8*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, ptr, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return ma_atomic_compare_and_swap_16((ma_uint16*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, ptr, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return ma_atomic_compare_and_swap_32((ma_uint32*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, ptr, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return ma_atomic_compare_and_swap_64((ma_uint64*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(64, ptr, order); + } + #endif + } static MA_INLINE ma_uint8 ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - if (order > ma_atomic_memory_order_acquire) { - __sync_synchronize(); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); } - return __sync_lock_test_and_set(dst, src); + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - do { - oldValue = *dst; - } while (__sync_val_compare_and_swap(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - do { - oldValue = *dst; - } while (__sync_val_compare_and_swap(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - do { - oldValue = *dst; - } while (__sync_val_compare_and_swap(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif } + #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) + #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) + #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) + #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) static MA_INLINE ma_uint8 ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_sub(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, (ma_uint8)(-(ma_int8)src), order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_sub(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, (ma_uint16)(-(ma_int16)src), order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_sub(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, (ma_uint32)(-(ma_int32)src), order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_sub(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, (ma_uint64)(-(ma_int64)src), order); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_or(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_or(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_or(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_or(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_xor(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_xor(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_xor(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_xor(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); + } + #endif } - static MA_INLINE ma_uint8 ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + #elif defined(MA_ATOMIC_LEGACY_GCC_ASM) + #define MA_ATOMIC_CMPXCHG_GCC_X86(instructionSizeSuffix, result, dst, expected, replacement) \ + __asm__ __volatile__( \ + "lock; cmpxchg"instructionSizeSuffix" %2, %1" \ + : "=a"(result), \ + "=m"(*dst) \ + : "r"(replacement), \ + "0"(expected), \ + "m"(*dst) \ + : "cc", "memory") + #define MA_ATOMIC_XADD_GCC_X86(instructionSizeSuffix, result, dst, src) \ + __asm__ __volatile__( \ + "lock; xadd"instructionSizeSuffix" %0, %1" \ + : "=a"(result), \ + "=m"(*dst) \ + : "0"(src), \ + "m"(*dst) \ + : "cc", "memory") + static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) { - (void)order; - return __sync_fetch_and_and(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint8 result; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("b", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); + } + #endif } - static MA_INLINE ma_uint16 ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) { - (void)order; - return __sync_fetch_and_and(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint16 result; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("w", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); + } + #endif } - static MA_INLINE ma_uint32 ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) { - (void)order; - return __sync_fetch_and_and(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint32 result; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("l", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); + } + #endif } - static MA_INLINE ma_uint64 ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) { - (void)order; - return __sync_fetch_and_and(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint64 result; + #if defined(MA_X86) + { + ma_uint32 resultEAX; + ma_uint32 resultEDX; + __asm__ __volatile__( + "pushl %%ebx\n" + "movl %4, %%ebx\n" + "lock cmpxchg8b (%%edi)\n" + "popl %%ebx\n" + : "=a"(resultEAX), + "=d"(resultEDX) + : "a"((ma_uint32)(expected & 0xFFFFFFFF)), + "d"((ma_uint32)(expected >> 32)), + "r"((ma_uint32)(replacement & 0xFFFFFFFF)), + "c"((ma_uint32)(replacement >> 32)), + "D"(dst) + : "memory", "cc"); + result = ((ma_uint64)resultEDX << 32) | resultEAX; + } + #elif defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("q", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); + } + #endif } - #define ma_atomic_compare_and_swap_8( dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #define ma_atomic_compare_and_swap_16(dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #define ma_atomic_compare_and_swap_32(dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #define ma_atomic_compare_and_swap_64(dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #else - #if defined(MA_X86) - #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addl $0, (%%esp)" ::: "memory", "cc") - #elif defined(MA_X64) - #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addq $0, (%%rsp)" ::: "memory", "cc") - #else - #error Unsupported architecture. Please submit a feature request. - #endif - static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 desired) + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* dst, ma_atomic_memory_order order) { - ma_uint8 result; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint8 result; + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("b", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("b", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("b", result, dst); + } + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, dst, order); + } + #endif } - static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 desired) + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* dst, ma_atomic_memory_order order) { - ma_uint16 result; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint16 result; + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("w", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("w", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("w", result, dst); + } + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, dst, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint32 result; + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("l", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("l", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("l", result, dst); + } + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, dst, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint64 result; + #if defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("q", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("q", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("q", result, dst); + } + } + #elif defined(MA_X86) + { + (void)order; + return ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, 0, 0); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(64, dst, order); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint8 result; + (void)order; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("b", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint16 result; + (void)order; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("w", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif } - static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 desired) + static MA_INLINE ma_uint32 ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 result; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint32 result; + (void)order; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("l", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif } - static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 desired) + static MA_INLINE ma_uint64 ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - volatile ma_uint64 result; - #if defined(MA_X86) - ma_uint32 resultEAX; - ma_uint32 resultEDX; - __asm__ __volatile__("push %%ebx; xchg %5, %%ebx; lock; cmpxchg8b %0; pop %%ebx" : "+m"(*dst), "=a"(resultEAX), "=d"(resultEDX) : "a"(expected & 0xFFFFFFFF), "d"(expected >> 32), "r"(desired & 0xFFFFFFFF), "c"(desired >> 32) : "cc"); - result = ((ma_uint64)resultEDX << 32) | resultEAX; - #elif defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint64 result; + (void)order; + #if defined(MA_X86) + { + MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(64, dst, src, order); + } + #elif defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("q", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif } - static MA_INLINE ma_uint8 ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + static MA_INLINE void ma_atomic_store_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 result = 0; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movb %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgb %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif } - static MA_INLINE ma_uint16 ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + static MA_INLINE void ma_atomic_store_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 result = 0; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movw %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgw %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif } - static MA_INLINE ma_uint32 ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + static MA_INLINE void ma_atomic_store_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movl %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgl %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif } - static MA_INLINE ma_uint64 ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + static MA_INLINE void ma_atomic_store_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 result; - (void)order; - #if defined(MA_X86) - do { - result = *dst; - } while (ma_atomic_compare_and_swap_64(dst, result, src) != result); - #elif defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movq %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgq %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_CAS(64, dst, src, order); + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_uint8 result; + (void)order; + MA_ATOMIC_XADD_GCC_X86("b", result, dst, src); + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_uint16 result; + (void)order; + MA_ATOMIC_XADD_GCC_X86("w", result, dst, src); + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_uint32 result; + (void)order; + MA_ATOMIC_XADD_GCC_X86("l", result, dst, src); + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - #if defined(MA_X86) - ma_uint64 oldValue; - ma_uint64 newValue; - (void)order; - do { - oldValue = *dst; - newValue = oldValue + src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - return oldValue; - #elif defined(MA_X64) - ma_uint64 result; - (void)order; - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - return result; - #endif + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) + { + MA_ATOMIC_FETCH_ADD_CAS(64, dst, src, order); + } + #elif defined(MA_X64) + { + ma_uint64 result; + MA_ATOMIC_XADD_GCC_X86("q", result, dst, src); + (void)order; + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue - src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + return ma_atomic_fetch_add_explicit_8(dst, (ma_uint8)(-(ma_int8)src), order); } static MA_INLINE ma_uint16 ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue - src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + return ma_atomic_fetch_add_explicit_16(dst, (ma_uint16)(-(ma_int16)src), order); } static MA_INLINE ma_uint32 ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + return ma_atomic_fetch_add_explicit_32(dst, (ma_uint32)(-(ma_int32)src), order); } static MA_INLINE ma_uint64 ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + return ma_atomic_fetch_add_explicit_64(dst, (ma_uint64)(-(ma_int64)src), order); } static MA_INLINE ma_uint8 ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue & src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); } static MA_INLINE ma_uint16 ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue & src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); } static MA_INLINE ma_uint32 ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); } static MA_INLINE ma_uint64 ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); } - static MA_INLINE ma_uint8 ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint8 ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); } - static MA_INLINE ma_uint16 ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint16 ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); } - static MA_INLINE ma_uint32 ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint32 ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); } - static MA_INLINE ma_uint64 ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint64 ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); } - static MA_INLINE ma_uint8 ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint8 ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue | src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); } - static MA_INLINE ma_uint16 ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint16 ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue | src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); } - static MA_INLINE ma_uint32 ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint32 ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); } - static MA_INLINE ma_uint64 ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint64 ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); } + #else + #error Unsupported compiler. #endif - #define ma_atomic_signal_fence(order) ma_atomic_thread_fence(order) - static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_8((ma_uint8*)ptr, 0, 0); - } - static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_16((ma_uint16*)ptr, 0, 0); - } - static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_32((ma_uint32*)ptr, 0, 0); - } - static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_64((ma_uint64*)ptr, 0, 0); - } - #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) - #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) - #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) - #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) - #define ma_atomic_test_and_set_explicit_8( dst, order) ma_atomic_exchange_explicit_8 (dst, 1, order) - #define ma_atomic_test_and_set_explicit_16(dst, order) ma_atomic_exchange_explicit_16(dst, 1, order) - #define ma_atomic_test_and_set_explicit_32(dst, order) ma_atomic_exchange_explicit_32(dst, 1, order) - #define ma_atomic_test_and_set_explicit_64(dst, order) ma_atomic_exchange_explicit_64(dst, 1, order) - #define ma_atomic_clear_explicit_8( dst, order) ma_atomic_store_explicit_8 (dst, 0, order) - #define ma_atomic_clear_explicit_16(dst, order) ma_atomic_store_explicit_16(dst, 0, order) - #define ma_atomic_clear_explicit_32(dst, order) ma_atomic_store_explicit_32(dst, 0, order) - #define ma_atomic_clear_explicit_64(dst, order) ma_atomic_store_explicit_64(dst, 0, order) - typedef ma_uint8 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(ptr, order) (ma_bool32)ma_atomic_test_and_set_explicit_8(ptr, order) - #define ma_atomic_flag_clear_explicit(ptr, order) ma_atomic_clear_explicit_8(ptr, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_8(ptr, order) #endif #if !defined(MA_ATOMIC_HAS_NATIVE_COMPARE_EXCHANGE) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_8(volatile ma_uint8* dst, ma_uint8* expected, ma_uint8 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint8 expectedValue; - ma_uint8 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_8(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_8(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_8(expected, result, failureOrder); - return 0; - } - } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_16(volatile ma_uint16* dst, ma_uint16* expected, ma_uint16 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint16 expectedValue; - ma_uint16 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_16(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_16(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_16(expected, result, failureOrder); - return 0; - } - } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_32(volatile ma_uint32* dst, ma_uint32* expected, ma_uint32 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint32 expectedValue; - ma_uint32 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_32(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_32(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_32(expected, result, failureOrder); - return 0; - } - } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_64(volatile ma_uint64* dst, volatile ma_uint64* expected, ma_uint64 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint64 expectedValue; - ma_uint64 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_64(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_64(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_64(expected, result, failureOrder); - return 0; - } - } - #endif - #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8 (dst, expected, desired, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, desired, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, desired, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, desired, successOrder, failureOrder) -#endif -#if !defined(MA_ATOMIC_HAS_NATIVE_IS_LOCK_FREE) - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_8(volatile void* ptr) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_8(volatile ma_uint8* dst, ma_uint8* expected, ma_uint8 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - (void)ptr; - return 1; + ma_uint8 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_8(dst, *expected, replacement); + if (result == *expected) { + return 1; + } else { + *expected = result; + return 0; + } } - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_16(volatile void* ptr) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_16(volatile ma_uint16* dst, ma_uint16* expected, ma_uint16 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - (void)ptr; - return 1; + ma_uint16 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_16(dst, *expected, replacement); + if (result == *expected) { + return 1; + } else { + *expected = result; + return 0; + } } - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_32(volatile void* ptr) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_32(volatile ma_uint32* dst, ma_uint32* expected, ma_uint32 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - (void)ptr; - return 1; + ma_uint32 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_32(dst, *expected, replacement); + if (result == *expected) { + return 1; + } else { + *expected = result; + return 0; + } } - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_64(volatile void* ptr) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_64(volatile ma_uint64* dst, volatile ma_uint64* expected, ma_uint64 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - (void)ptr; - #if defined(MA_64BIT) - return 1; - #else - #if defined(MA_X86) || defined(MA_X64) + ma_uint64 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_64(dst, *expected, replacement); + if (result == *expected) { return 1; - #else + } else { + *expected = result; return 0; - #endif - #endif + } } + #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8 (dst, expected, replacement, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, replacement, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, replacement, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, replacement, successOrder, failureOrder) #endif #if defined(MA_64BIT) static MA_INLINE ma_bool32 ma_atomic_is_lock_free_ptr(volatile void** ptr) @@ -15561,17 +17005,17 @@ typedef int ma_atomic_memory_order; { return (void*)ma_atomic_exchange_explicit_64((volatile ma_uint64*)dst, (ma_uint64)src, order); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_strong_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_strong_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_weak_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_weak_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder); } - static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* desired) + static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* replacement) { - return (void*)ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, (ma_uint64)expected, (ma_uint64)desired); + return (void*)ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, (ma_uint64)expected, (ma_uint64)replacement); } #elif defined(MA_32BIT) static MA_INLINE ma_bool32 ma_atomic_is_lock_free_ptr(volatile void** ptr) @@ -15590,36 +17034,26 @@ typedef int ma_atomic_memory_order; { return (void*)ma_atomic_exchange_explicit_32((volatile ma_uint32*)dst, (ma_uint32)src, order); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_strong_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_strong_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_weak_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_weak_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder); } - static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* desired) + static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* replacement) { - return (void*)ma_atomic_compare_and_swap_32((volatile ma_uint32*)dst, (ma_uint32)expected, (ma_uint32)desired); + return (void*)ma_atomic_compare_and_swap_32((volatile ma_uint32*)dst, (ma_uint32)expected, (ma_uint32)replacement); } #else #error Unsupported architecture. #endif -#define ma_atomic_flag_test_and_set(ptr) ma_atomic_flag_test_and_set_explicit(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_flag_clear(ptr) ma_atomic_flag_clear_explicit(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_store_ptr(dst, src) ma_atomic_store_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_load_ptr(ptr) ma_atomic_load_explicit_ptr((volatile void**)ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_exchange_ptr(dst, src) ma_atomic_exchange_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_ptr(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_ptr((volatile void**)dst, (void**)expected, (void*)desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_ptr(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_ptr((volatile void**)dst, (void**)expected, (void*)desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_8( ptr) ma_atomic_test_and_set_explicit_8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_16(ptr) ma_atomic_test_and_set_explicit_16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_32(ptr) ma_atomic_test_and_set_explicit_32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_64(ptr) ma_atomic_test_and_set_explicit_64(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_8( ptr) ma_atomic_clear_explicit_8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_16(ptr) ma_atomic_clear_explicit_16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_32(ptr) ma_atomic_clear_explicit_32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_64(ptr) ma_atomic_clear_explicit_64(ptr, ma_atomic_memory_order_seq_cst) +#define ma_atomic_store_ptr(dst, src) ma_atomic_store_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) +#define ma_atomic_load_ptr(ptr) ma_atomic_load_explicit_ptr((volatile void**)ptr, ma_atomic_memory_order_seq_cst) +#define ma_atomic_exchange_ptr(dst, src) ma_atomic_exchange_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_ptr(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_ptr((volatile void**)dst, (void**)expected, (void*)replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_ptr(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_ptr((volatile void**)dst, (void**)expected, (void*)replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_8( dst, src) ma_atomic_store_explicit_8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_16(dst, src) ma_atomic_store_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_32(dst, src) ma_atomic_store_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15632,14 +17066,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_16(dst, src) ma_atomic_exchange_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_32(dst, src) ma_atomic_exchange_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_64(dst, src) ma_atomic_exchange_explicit_64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_8( dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_16(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_32(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_64(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_8( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_16( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_32( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_64( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_8( dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_16(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_32(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_64(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_8( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_16( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_32( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_64( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_8( dst, src) ma_atomic_fetch_add_explicit_8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_16(dst, src) ma_atomic_fetch_add_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_32(dst, src) ma_atomic_fetch_add_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15660,14 +17094,6 @@ typedef int ma_atomic_memory_order; #define ma_atomic_fetch_and_16(dst, src) ma_atomic_fetch_and_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_32(dst, src) ma_atomic_fetch_and_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_64(dst, src) ma_atomic_fetch_and_explicit_64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_explicit_i8( ptr, order) (ma_int8 )ma_atomic_test_and_set_explicit_8( (ma_uint8* )ptr, order) -#define ma_atomic_test_and_set_explicit_i16(ptr, order) (ma_int16)ma_atomic_test_and_set_explicit_16((ma_uint16*)ptr, order) -#define ma_atomic_test_and_set_explicit_i32(ptr, order) (ma_int32)ma_atomic_test_and_set_explicit_32((ma_uint32*)ptr, order) -#define ma_atomic_test_and_set_explicit_i64(ptr, order) (ma_int64)ma_atomic_test_and_set_explicit_64((ma_uint64*)ptr, order) -#define ma_atomic_clear_explicit_i8( ptr, order) ma_atomic_clear_explicit_8( (ma_uint8* )ptr, order) -#define ma_atomic_clear_explicit_i16(ptr, order) ma_atomic_clear_explicit_16((ma_uint16*)ptr, order) -#define ma_atomic_clear_explicit_i32(ptr, order) ma_atomic_clear_explicit_32((ma_uint32*)ptr, order) -#define ma_atomic_clear_explicit_i64(ptr, order) ma_atomic_clear_explicit_64((ma_uint64*)ptr, order) #define ma_atomic_store_explicit_i8( dst, src, order) ma_atomic_store_explicit_8( (ma_uint8* )dst, (ma_uint8 )src, order) #define ma_atomic_store_explicit_i16(dst, src, order) ma_atomic_store_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_store_explicit_i32(dst, src, order) ma_atomic_store_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) @@ -15680,14 +17106,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_explicit_i16(dst, src, order) (ma_int16)ma_atomic_exchange_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_exchange_explicit_i32(dst, src, order) (ma_int32)ma_atomic_exchange_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) #define ma_atomic_exchange_explicit_i64(dst, src, order) (ma_int64)ma_atomic_exchange_explicit_64((ma_uint64*)dst, (ma_uint64)src, order) -#define ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder) #define ma_atomic_fetch_add_explicit_i8( dst, src, order) (ma_int8 )ma_atomic_fetch_add_explicit_8( (ma_uint8* )dst, (ma_uint8 )src, order) #define ma_atomic_fetch_add_explicit_i16(dst, src, order) (ma_int16)ma_atomic_fetch_add_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_fetch_add_explicit_i32(dst, src, order) (ma_int32)ma_atomic_fetch_add_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) @@ -15708,14 +17134,6 @@ typedef int ma_atomic_memory_order; #define ma_atomic_fetch_and_explicit_i16(dst, src, order) (ma_int16)ma_atomic_fetch_and_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_fetch_and_explicit_i32(dst, src, order) (ma_int32)ma_atomic_fetch_and_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) #define ma_atomic_fetch_and_explicit_i64(dst, src, order) (ma_int64)ma_atomic_fetch_and_explicit_64((ma_uint64*)dst, (ma_uint64)src, order) -#define ma_atomic_test_and_set_i8( ptr) ma_atomic_test_and_set_explicit_i8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_i16(ptr) ma_atomic_test_and_set_explicit_i16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_i32(ptr) ma_atomic_test_and_set_explicit_i32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_i64(ptr) ma_atomic_test_and_set_explicit_i64(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i8( ptr) ma_atomic_clear_explicit_i8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i16(ptr) ma_atomic_clear_explicit_i16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i32(ptr) ma_atomic_clear_explicit_i32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i64(ptr) ma_atomic_clear_explicit_i64(ptr, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_i8( dst, src) ma_atomic_store_explicit_i8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_i16(dst, src) ma_atomic_store_explicit_i16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_i32(dst, src) ma_atomic_store_explicit_i32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15728,14 +17146,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_i16(dst, src) ma_atomic_exchange_explicit_i16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_i32(dst, src) ma_atomic_exchange_explicit_i32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_i64(dst, src) ma_atomic_exchange_explicit_i64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i8( dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i16(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i32(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i64(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i8( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i16(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i32(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i64(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i8( dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i16(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i32(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i64(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i8( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i16(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i32(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i64(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_i8( dst, src) ma_atomic_fetch_add_explicit_i8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_i16(dst, src) ma_atomic_fetch_add_explicit_i16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_i32(dst, src) ma_atomic_fetch_add_explicit_i32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15812,28 +17230,28 @@ static MA_INLINE double ma_atomic_exchange_explicit_f64(volatile double* dst, do r.i = ma_atomic_exchange_explicit_64((volatile ma_uint64*)dst, x.i, order); return r.f; } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f32(volatile float* dst, float* expected, float desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f32(volatile float* dst, float* expected, float replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if32 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_strong_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, d.i, successOrder, failureOrder); } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f64(volatile double* dst, double* expected, double desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f64(volatile double* dst, double* expected, double replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if64 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_strong_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, d.i, successOrder, failureOrder); } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f32(volatile float* dst, float* expected, float desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f32(volatile float* dst, float* expected, float replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if32 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_weak_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, d.i, successOrder, failureOrder); } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f64(volatile double* dst, double* expected, double desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f64(volatile double* dst, double* expected, double replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if64 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_weak_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, d.i, successOrder, failureOrder); } static MA_INLINE float ma_atomic_fetch_add_explicit_f32(volatile float* dst, float src, ma_atomic_memory_order order) @@ -15924,10 +17342,10 @@ static MA_INLINE double ma_atomic_fetch_and_explicit_f64(volatile double* dst, d #define ma_atomic_load_f64(ptr) (double)ma_atomic_load_explicit_f64(ptr, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_f32(dst, src) (float )ma_atomic_exchange_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_f64(dst, src) (double)ma_atomic_exchange_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_f32(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_f32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_f64(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_f64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_f32(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_f32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_f64(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_f64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_f32(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_f32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_f64(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_f64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_f32(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_f32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_f64(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_f64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_f32(dst, src) ma_atomic_fetch_add_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_f64(dst, src) ma_atomic_fetch_add_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_sub_f32(dst, src) ma_atomic_fetch_sub_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15938,39 +17356,24 @@ static MA_INLINE double ma_atomic_fetch_and_explicit_f64(volatile double* dst, d #define ma_atomic_fetch_xor_f64(dst, src) ma_atomic_fetch_xor_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_f32(dst, src) ma_atomic_fetch_and_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_f64(dst, src) ma_atomic_fetch_and_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) -static MA_INLINE float ma_atomic_compare_and_swap_f32(volatile float* dst, float expected, float desired) +static MA_INLINE float ma_atomic_compare_and_swap_f32(volatile float* dst, float expected, float replacement) { ma_atomic_if32 r; ma_atomic_if32 e, d; e.f = expected; - d.f = desired; + d.f = replacement; r.i = ma_atomic_compare_and_swap_32((volatile ma_uint32*)dst, e.i, d.i); return r.f; } -static MA_INLINE double ma_atomic_compare_and_swap_f64(volatile double* dst, double expected, double desired) +static MA_INLINE double ma_atomic_compare_and_swap_f64(volatile double* dst, double expected, double replacement) { ma_atomic_if64 r; ma_atomic_if64 e, d; e.f = expected; - d.f = desired; + d.f = replacement; r.i = ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, e.i, d.i); return r.f; } -typedef ma_atomic_flag ma_atomic_spinlock; -static MA_INLINE void ma_atomic_spinlock_lock(volatile ma_atomic_spinlock* pSpinlock) -{ - for (;;) { - if (ma_atomic_flag_test_and_set_explicit(pSpinlock, ma_atomic_memory_order_acquire) == 0) { - break; - } - while (ma_atomic_flag_load_explicit(pSpinlock, ma_atomic_memory_order_relaxed) == 1) { - } - } -} -static MA_INLINE void ma_atomic_spinlock_unlock(volatile ma_atomic_spinlock* pSpinlock) -{ - ma_atomic_flag_clear_explicit(pSpinlock, ma_atomic_memory_order_release); -} #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))) #pragma GCC diagnostic pop #endif @@ -16176,7 +17579,7 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority int result; pthread_attr_t* pAttr = NULL; -#if !defined(__EMSCRIPTEN__) && !defined(__3DS__) +#if !defined(MA_EMSCRIPTEN) && !defined(MA_3DS) && !defined(MA_SWITCH) /* Try setting the thread priority. It's not critical if anything fails here. */ pthread_attr_t attr; if (pthread_attr_init(&attr) == 0) { @@ -16208,9 +17611,18 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority } #endif - if (stackSize > 0) { - pthread_attr_setstacksize(&attr, stackSize); + #if defined(_POSIX_THREAD_ATTR_STACKSIZE) && _POSIX_THREAD_ATTR_STACKSIZE >= 0 + { + if (stackSize > 0) { + pthread_attr_setstacksize(&attr, stackSize); + } + } + #else + { + (void)stackSize; /* Suppress unused parameter warning. */ } + #endif + if (scheduler != -1) { int priorityMin = sched_get_priority_min(scheduler); @@ -16218,7 +17630,7 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority int priorityStep = (priorityMax - priorityMin) / 7; /* 7 = number of priorities supported by miniaudio. */ struct sched_param sched; - if (pthread_attr_getschedparam(&attr, &sched) == 0) { + if (priorityMin != -1 && priorityMax != -1 && pthread_attr_getschedparam(&attr, &sched) == 0) { if (priority == ma_thread_priority_idle) { sched.sched_priority = priorityMin; } else if (priority == ma_thread_priority_realtime) { @@ -16267,6 +17679,21 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority } if (result != 0) { + /* + There have been reports that attempting to create a realtime thread can sometimes fail. In this case, + fall back to a normal priority thread. + + I'm including a compile-time option here to disable this functionality for those who have a hard + requirement on realtime threads and would rather an explicit failure. + */ + #ifndef MA_NO_PTHREAD_REALTIME_PRIORITY_FALLBACK + { + if(result == EPERM && priority == ma_thread_priority_realtime) { + return ma_thread_create__posix(pThread, ma_thread_priority_normal, stackSize, entryProc, pData); + } + } + #endif + return ma_result_from_errno(result); } @@ -16538,7 +17965,7 @@ static ma_result ma_event_signal__win32(ma_event* pEvent) static ma_result ma_semaphore_init__win32(int initialValue, ma_semaphore* pSemaphore) { - *pSemaphore = CreateSemaphoreW(NULL, (LONG)initialValue, LONG_MAX, NULL); + *pSemaphore = CreateSemaphore(NULL, (LONG)initialValue, LONG_MAX, NULL); if (*pSemaphore == NULL) { return ma_result_from_GetLastError(GetLastError()); } @@ -17432,10 +18859,12 @@ static MA_INLINE ma_uint16 ma_job_extract_slot(ma_uint64 toc) return (ma_uint16)(toc & 0x0000FFFF); } +#if 0 /* Currently unused, but might make use of this later. */ static MA_INLINE ma_uint16 ma_job_extract_code(ma_uint64 toc) { return (ma_uint16)((toc & 0xFFFF0000) >> 16); } +#endif static MA_INLINE ma_uint64 ma_job_toc_to_allocation(ma_uint64 toc) { @@ -17900,6 +19329,13 @@ MA_API ma_result ma_job_queue_next(ma_job_queue* pQueue, ma_job* pJob) Dynamic Linking *******************************************************************************/ +/* Disable run-time linking on certain backends and platforms. */ +#ifndef MA_NO_RUNTIME_LINKING + #if defined(MA_EMSCRIPTEN) || defined(MA_ORBIS) || defined(MA_PROSPERO) || defined(MA_SWITCH) || defined(MA_DOS) + #define MA_NO_RUNTIME_LINKING + #endif +#endif + #ifdef MA_POSIX /* No need for dlfcn.h if we're not using runtime linking. */ #ifndef MA_NO_RUNTIME_LINKING @@ -17909,104 +19345,124 @@ Dynamic Linking MA_API ma_handle ma_dlopen(ma_log* pLog, const char* filename) { -#ifndef MA_NO_RUNTIME_LINKING - ma_handle handle; + #ifndef MA_NO_RUNTIME_LINKING + { + ma_handle handle; - ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading library: %s\n", filename); + ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading library: %s\n", filename); - #ifdef MA_WIN32 - /* From MSDN: Desktop applications cannot use LoadPackagedLibrary; if a desktop application calls this function it fails with APPMODEL_ERROR_NO_PACKAGE.*/ - #if !defined(MA_WIN32_UWP) || !(defined(WINAPI_FAMILY) && ((defined(WINAPI_FAMILY_PHONE_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PHONE_APP))) - handle = (ma_handle)LoadLibraryA(filename); + #ifdef MA_WIN32 + /* From MSDN: Desktop applications cannot use LoadPackagedLibrary; if a desktop application calls this function it fails with APPMODEL_ERROR_NO_PACKAGE.*/ + #if !defined(MA_WIN32_UWP) || !(defined(WINAPI_FAMILY) && ((defined(WINAPI_FAMILY_PHONE_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PHONE_APP))) + handle = (ma_handle)LoadLibraryA(filename); + #else + /* *sigh* It appears there is no ANSI version of LoadPackagedLibrary()... */ + WCHAR filenameW[4096]; + if (MultiByteToWideChar(CP_UTF8, 0, filename, -1, filenameW, sizeof(filenameW)) == 0) { + handle = NULL; + } else { + handle = (ma_handle)LoadPackagedLibrary(filenameW, 0); + } + #endif #else - /* *sigh* It appears there is no ANSI version of LoadPackagedLibrary()... */ - WCHAR filenameW[4096]; - if (MultiByteToWideChar(CP_UTF8, 0, filename, -1, filenameW, sizeof(filenameW)) == 0) { - handle = NULL; - } else { - handle = (ma_handle)LoadPackagedLibrary(filenameW, 0); - } + handle = (ma_handle)dlopen(filename, RTLD_NOW); #endif - #else - handle = (ma_handle)dlopen(filename, RTLD_NOW); - #endif - /* - I'm not considering failure to load a library an error nor a warning because seamlessly falling through to a lower-priority - backend is a deliberate design choice. Instead I'm logging it as an informational message. - */ - if (handle == NULL) { - ma_log_postf(pLog, MA_LOG_LEVEL_INFO, "Failed to load library: %s\n", filename); - } + /* + I'm not considering failure to load a library an error nor a warning because seamlessly falling through to a lower-priority + backend is a deliberate design choice. Instead I'm logging it as an informational message. + */ + if (handle == NULL) { + ma_log_postf(pLog, MA_LOG_LEVEL_INFO, "Failed to load library: %s\n", filename); + } - return handle; -#else - /* Runtime linking is disabled. */ - (void)pLog; - (void)filename; - return NULL; -#endif + return handle; + } + #else + { + /* Runtime linking is disabled. */ + (void)pLog; + (void)filename; + return NULL; + } + #endif } MA_API void ma_dlclose(ma_log* pLog, ma_handle handle) { -#ifndef MA_NO_RUNTIME_LINKING - #ifdef MA_WIN32 - FreeLibrary((HMODULE)handle); - #else - /* Hack for Android bug (see https://github.com/android/ndk/issues/360). Calling dlclose() pre-API 28 may segfault. */ - #if !defined(MA_ANDROID) || (defined(__ANDROID_API__) && __ANDROID_API__ >= 28) + #ifndef MA_NO_RUNTIME_LINKING + { + #ifdef MA_WIN32 { - dlclose((void*)handle); + FreeLibrary((HMODULE)handle); } #else { - (void)handle; + /* Hack for Android bug (see https://github.com/android/ndk/issues/360). Calling dlclose() pre-API 28 may segfault. */ + #if !defined(MA_ANDROID) || (defined(__ANDROID_API__) && __ANDROID_API__ >= 28) + { + dlclose((void*)handle); + } + #else + { + (void)handle; + } + #endif } #endif - #endif - (void)pLog; -#else - /* Runtime linking is disabled. */ - (void)pLog; - (void)handle; -#endif + (void)pLog; + } + #else + { + /* Runtime linking is disabled. */ + (void)pLog; + (void)handle; + } + #endif } MA_API ma_proc ma_dlsym(ma_log* pLog, ma_handle handle, const char* symbol) { -#ifndef MA_NO_RUNTIME_LINKING - ma_proc proc; + #ifndef MA_NO_RUNTIME_LINKING + { + ma_proc proc; - ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading symbol: %s\n", symbol); + ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading symbol: %s\n", symbol); -#ifdef _WIN32 - proc = (ma_proc)GetProcAddress((HMODULE)handle, symbol); -#else -#if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Wpedantic" -#endif - proc = (ma_proc)dlsym((void*)handle, symbol); -#if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) - #pragma GCC diagnostic pop -#endif -#endif + #ifdef _WIN32 + { + proc = (ma_proc)GetProcAddress((HMODULE)handle, symbol); + } + #else + { + #if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wpedantic" + #endif + proc = (ma_proc)dlsym((void*)handle, symbol); + #if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) + #pragma GCC diagnostic pop + #endif + } + #endif - if (proc == NULL) { - ma_log_postf(pLog, MA_LOG_LEVEL_WARNING, "Failed to load symbol: %s\n", symbol); - } + if (proc == NULL) { + ma_log_postf(pLog, MA_LOG_LEVEL_WARNING, "Failed to load symbol: %s\n", symbol); + } - (void)pLog; /* It's possible for pContext to be unused. */ - return proc; -#else - /* Runtime linking is disabled. */ - (void)pLog; - (void)handle; - (void)symbol; - return NULL; -#endif + (void)pLog; /* It's possible for pContext to be unused. */ + return proc; + } + #else + { + /* Runtime linking is disabled. */ + (void)pLog; + (void)handle; + (void)symbol; + return NULL; + } + #endif } @@ -18020,13 +19476,6 @@ DEVICE I/O ************************************************************************************************************************************************************* ************************************************************************************************************************************************************/ -/* Disable run-time linking on certain backends and platforms. */ -#ifndef MA_NO_RUNTIME_LINKING - #if defined(MA_EMSCRIPTEN) || defined(MA_ORBIS) || defined(MA_PROSPERO) - #define MA_NO_RUNTIME_LINKING - #endif -#endif - #ifdef MA_APPLE #include #endif @@ -18039,12 +19488,6 @@ DEVICE I/O #ifdef MA_POSIX #include - #include - - /* No need for dlfcn.h if we're not using runtime linking. */ - #ifndef MA_NO_RUNTIME_LINKING - #include - #endif #endif /* This must be set to at least 26. */ @@ -18299,7 +19742,7 @@ MA_API ma_bool32 ma_is_loopback_supported(ma_backend backend) -#if defined(MA_WIN32) +#if defined(MA_WIN32) && !defined(MA_XBOX) /* WASAPI error codes. */ #define MA_AUDCLNT_E_NOT_INITIALIZED ((HRESULT)0x88890001) #define MA_AUDCLNT_E_ALREADY_INITIALIZED ((HRESULT)0x88890002) @@ -18514,6 +19957,11 @@ typedef LONG (WINAPI * MA_PFN_RegCloseKey)(HKEY hKey); typedef LONG (WINAPI * MA_PFN_RegQueryValueExA)(HKEY hKey, const char* lpValueName, DWORD* lpReserved, DWORD* lpType, BYTE* lpData, DWORD* lpcbData); #endif /* MA_WIN32_DESKTOP */ +static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_PCM = {0x00000001, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; +static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_IEEE_FLOAT = {0x00000003, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; +/*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_ALAW = {0x00000006, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ +/*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_MULAW = {0x00000007, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ + MA_API size_t ma_strlen_WCHAR(const WCHAR* str) { size_t len = 0; @@ -18577,7 +20025,7 @@ Timing *******************************************************************************/ #if defined(MA_WIN32) && !defined(MA_POSIX) static LARGE_INTEGER g_ma_TimerFrequency; /* <-- Initialized to zero since it's static. */ - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { LARGE_INTEGER counter; @@ -18589,7 +20037,7 @@ Timing pTimer->counter = counter.QuadPart; } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { LARGE_INTEGER counter; if (!QueryPerformanceCounter(&counter)) { @@ -18600,7 +20048,7 @@ Timing } #elif defined(MA_APPLE) && (MAC_OS_X_VERSION_MIN_REQUIRED < 101200) static ma_uint64 g_ma_TimerFrequency = 0; - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { mach_timebase_info_data_t baseTime; mach_timebase_info(&baseTime); @@ -18609,7 +20057,7 @@ Timing pTimer->counter = mach_absolute_time(); } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { ma_uint64 newTimeCounter = mach_absolute_time(); ma_uint64 oldTimeCounter = pTimer->counter; @@ -18634,15 +20082,15 @@ Timing #define MA_CLOCK_ID CLOCK_REALTIME #endif - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { struct timespec newTime; clock_gettime(MA_CLOCK_ID, &newTime); - pTimer->counter = (newTime.tv_sec * 1000000000) + newTime.tv_nsec; + pTimer->counter = ((ma_int64)newTime.tv_sec * 1000000000) + newTime.tv_nsec; } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { ma_uint64 newTimeCounter; ma_uint64 oldTimeCounter; @@ -18650,21 +20098,21 @@ Timing struct timespec newTime; clock_gettime(MA_CLOCK_ID, &newTime); - newTimeCounter = (newTime.tv_sec * 1000000000) + newTime.tv_nsec; + newTimeCounter = ((ma_uint64)newTime.tv_sec * 1000000000) + newTime.tv_nsec; oldTimeCounter = pTimer->counter; return (newTimeCounter - oldTimeCounter) / 1000000000.0; } #else - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { struct timeval newTime; gettimeofday(&newTime, NULL); - pTimer->counter = (newTime.tv_sec * 1000000) + newTime.tv_usec; + pTimer->counter = ((ma_int64)newTime.tv_sec * 1000000) + newTime.tv_usec; } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { ma_uint64 newTimeCounter; ma_uint64 oldTimeCounter; @@ -18672,7 +20120,7 @@ Timing struct timeval newTime; gettimeofday(&newTime, NULL); - newTimeCounter = (newTime.tv_sec * 1000000) + newTime.tv_usec; + newTimeCounter = ((ma_uint64)newTime.tv_sec * 1000000) + newTime.tv_usec; oldTimeCounter = pTimer->counter; return (newTimeCounter - oldTimeCounter) / 1000000.0; @@ -19248,14 +20696,6 @@ static MA_INLINE void ma_device__set_state(ma_device* pDevice, ma_device_state n } -#if defined(MA_WIN32) - static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_PCM = {0x00000001, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; - static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_IEEE_FLOAT = {0x00000003, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; - /*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_ALAW = {0x00000006, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ - /*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_MULAW = {0x00000007, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ -#endif - - MA_API ma_uint32 ma_get_format_priority_index(ma_format format) /* Lower = better. */ { @@ -19967,7 +21407,7 @@ static ma_result ma_context_init__null(ma_context* pContext, const ma_context_co WIN32 COMMON *******************************************************************************/ -#if defined(MA_WIN32) +#if defined(MA_WIN32) && !defined(MA_XBOX) #if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) #define ma_CoInitializeEx(pContext, pvReserved, dwCoInit) ((pContext->win32.CoInitializeEx) ? ((MA_PFN_CoInitializeEx)pContext->win32.CoInitializeEx)(pvReserved, dwCoInit) : ((MA_PFN_CoInitialize)pContext->win32.CoInitialize)(pvReserved)) #define ma_CoUninitialize(pContext) ((MA_PFN_CoUninitialize)pContext->win32.CoUninitialize)() @@ -19982,7 +21422,7 @@ WIN32 COMMON #define ma_PropVariantClear(pContext, pvar) PropVariantClear(pvar) #endif -#if !defined(MAXULONG_PTR) && !defined(__WATCOMC__) +#if !defined(MAXULONG_PTR) && !defined(__WATCOMC__) && !defined(MA_XBOX_NXDK) typedef size_t DWORD_PTR; #endif @@ -20409,11 +21849,21 @@ typedef enum MA_AudioCategory_Other = 0 /* <-- miniaudio is only caring about Other. */ } MA_AUDIO_STREAM_CATEGORY; +typedef enum +{ + MA_AUDCLNT_STREAMOPTIONS_NONE, + MA_AUDCLNT_STREAMOPTIONS_RAW, + MA_AUDCLNT_STREAMOPTIONS_MATCH_FORMAT, + MA_AUDCLNT_STREAMOPTIONS_AMBISONICS, + MA_AUDCLNT_STREAMOPTIONS_POST_VOLUME_LOOPBACK +} MA_AUDCLNT_STREAMOPTIONS; + typedef struct { ma_uint32 cbSize; BOOL bIsOffload; MA_AUDIO_STREAM_CATEGORY eCategory; + MA_AUDCLNT_STREAMOPTIONS Options; } ma_AudioClientProperties; /* IUnknown */ @@ -21588,6 +23038,7 @@ static ma_result ma_context_get_MMDevice__wasapi(ma_context* pContext, ma_device { ma_IMMDeviceEnumerator* pDeviceEnumerator; HRESULT hr; + HRESULT CoInitializeResult; MA_ASSERT(pContext != NULL); MA_ASSERT(ppMMDevice != NULL); @@ -21601,12 +23052,17 @@ static ma_result ma_context_get_MMDevice__wasapi(ma_context* pContext, ma_device The community has reported that this seems to fix the crash. There are future plans to move all WASAPI operation over to a single thread to make everything safer, but in the meantime while we wait for that to come online I'm happy enough to use this hack instead. + + CoUninitialize should only be called if we successfully initialized. S_OK and S_FALSE both mean that we need to + call CoUninitialize since the internal ref count was increased. RPC_E_CHANGED_MODE means that CoInitializeEx was + called with a different COINIT value, and we don't call CoUninitialize in that case. Other errors are possible, + so we check for S_OK and S_FALSE specifically. */ - ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); + CoInitializeResult = ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); { hr = ma_CoCreateInstance(pContext, &MA_CLSID_MMDeviceEnumerator, NULL, CLSCTX_ALL, &MA_IID_IMMDeviceEnumerator, (void**)&pDeviceEnumerator); - } - ma_CoUninitialize(pContext); + } + if (CoInitializeResult == S_OK || CoInitializeResult == S_FALSE) { ma_CoUninitialize(pContext); } if (FAILED(hr)) { /* <-- This is checking the call above to ma_CoCreateInstance(). */ ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[WASAPI] Failed to create IMMDeviceEnumerator.\n"); @@ -21950,7 +23406,7 @@ static ma_result ma_context_get_IAudioClient__wasapi(ma_context* pContext, ma_de pActivationParams = &activationParams; /* When requesting a specific device ID we need to use a special device ID. */ - MA_COPY_MEMORY(virtualDeviceID.wasapi, MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK, (wcslen(MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK) + 1) * sizeof(wchar_t)); /* +1 for the null terminator. */ + MA_COPY_MEMORY(virtualDeviceID.wasapi, MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK, (ma_wcslen(MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK) + 1) * sizeof(wchar_t)); /* +1 for the null terminator. */ pDeviceID = &virtualDeviceID; } else { pActivationParams = NULL; /* No activation parameters required. */ @@ -26679,6 +28135,9 @@ typedef snd_pcm_channel_area_t ma_snd_pcm_channel_area_t; typedef snd_pcm_chmap_t ma_snd_pcm_chmap_t; typedef snd_pcm_state_t ma_snd_pcm_state_t; +/* snd_pcm_state_t */ +#define MA_SND_PCM_STATE_XRUN SND_PCM_STATE_XRUN + /* snd_pcm_stream_t */ #define MA_SND_PCM_STREAM_PLAYBACK SND_PCM_STREAM_PLAYBACK #define MA_SND_PCM_STREAM_CAPTURE SND_PCM_STREAM_CAPTURE @@ -26874,6 +28333,7 @@ typedef int (* ma_snd_pcm_hw_params_set_channels_minmax_proc) ( typedef int (* ma_snd_pcm_hw_params_set_rate_resample_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int val); typedef int (* ma_snd_pcm_hw_params_set_rate_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int val, int dir); typedef int (* ma_snd_pcm_hw_params_set_rate_near_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int *val, int *dir); +typedef int (* ma_snd_pcm_hw_params_set_rate_minmax_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int *min, int *mindir, unsigned int *max, int *maxdir); typedef int (* ma_snd_pcm_hw_params_set_buffer_size_near_proc)(ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, ma_snd_pcm_uframes_t *val); typedef int (* ma_snd_pcm_hw_params_set_periods_near_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int *val, int *dir); typedef int (* ma_snd_pcm_hw_params_set_access_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, ma_snd_pcm_access_t _access); @@ -28640,8 +30100,9 @@ static ma_result ma_context_init__alsa(ma_context* pContext, const ma_context_co ma_snd_pcm_hw_params_get_format_mask_proc _snd_pcm_hw_params_get_format_mask = snd_pcm_hw_params_get_format_mask; ma_snd_pcm_hw_params_set_channels_proc _snd_pcm_hw_params_set_channels = snd_pcm_hw_params_set_channels; ma_snd_pcm_hw_params_set_channels_near_proc _snd_pcm_hw_params_set_channels_near = snd_pcm_hw_params_set_channels_near; + ma_snd_pcm_hw_params_set_channels_minmax_proc _snd_pcm_hw_params_set_channels_minmax = snd_pcm_hw_params_set_channels_minmax; ma_snd_pcm_hw_params_set_rate_resample_proc _snd_pcm_hw_params_set_rate_resample = snd_pcm_hw_params_set_rate_resample; - ma_snd_pcm_hw_params_set_rate_near _snd_pcm_hw_params_set_rate = snd_pcm_hw_params_set_rate; + ma_snd_pcm_hw_params_set_rate_proc _snd_pcm_hw_params_set_rate = snd_pcm_hw_params_set_rate; ma_snd_pcm_hw_params_set_rate_near_proc _snd_pcm_hw_params_set_rate_near = snd_pcm_hw_params_set_rate_near; ma_snd_pcm_hw_params_set_rate_minmax_proc _snd_pcm_hw_params_set_rate_minmax = snd_pcm_hw_params_set_rate_minmax; ma_snd_pcm_hw_params_set_buffer_size_near_proc _snd_pcm_hw_params_set_buffer_size_near = snd_pcm_hw_params_set_buffer_size_near; @@ -28693,9 +30154,9 @@ static ma_result ma_context_init__alsa(ma_context* pContext, const ma_context_co ma_snd_pcm_info_proc _snd_pcm_info = snd_pcm_info; ma_snd_pcm_info_sizeof_proc _snd_pcm_info_sizeof = snd_pcm_info_sizeof; ma_snd_pcm_info_get_name_proc _snd_pcm_info_get_name = snd_pcm_info_get_name; - ma_snd_pcm_poll_descriptors _snd_pcm_poll_descriptors = snd_pcm_poll_descriptors; - ma_snd_pcm_poll_descriptors_count _snd_pcm_poll_descriptors_count = snd_pcm_poll_descriptors_count; - ma_snd_pcm_poll_descriptors_revents _snd_pcm_poll_descriptors_revents = snd_pcm_poll_descriptors_revents; + ma_snd_pcm_poll_descriptors_proc _snd_pcm_poll_descriptors = snd_pcm_poll_descriptors; + ma_snd_pcm_poll_descriptors_count_proc _snd_pcm_poll_descriptors_count = snd_pcm_poll_descriptors_count; + ma_snd_pcm_poll_descriptors_revents_proc _snd_pcm_poll_descriptors_revents = snd_pcm_poll_descriptors_revents; ma_snd_config_update_free_global_proc _snd_config_update_free_global = snd_config_update_free_global; pContext->alsa.snd_pcm_open = (ma_proc)_snd_pcm_open; @@ -28711,6 +30172,7 @@ static ma_result ma_context_init__alsa(ma_context* pContext, const ma_context_co pContext->alsa.snd_pcm_hw_params_set_rate_resample = (ma_proc)_snd_pcm_hw_params_set_rate_resample; pContext->alsa.snd_pcm_hw_params_set_rate = (ma_proc)_snd_pcm_hw_params_set_rate; pContext->alsa.snd_pcm_hw_params_set_rate_near = (ma_proc)_snd_pcm_hw_params_set_rate_near; + pContext->alsa.snd_pcm_hw_params_set_rate_minmax = (ma_proc)_snd_pcm_hw_params_set_rate_minmax; pContext->alsa.snd_pcm_hw_params_set_buffer_size_near = (ma_proc)_snd_pcm_hw_params_set_buffer_size_near; pContext->alsa.snd_pcm_hw_params_set_periods_near = (ma_proc)_snd_pcm_hw_params_set_periods_near; pContext->alsa.snd_pcm_hw_params_set_access = (ma_proc)_snd_pcm_hw_params_set_access; @@ -29436,7 +30898,7 @@ typedef void (* ma_pa_threaded_mainloop_unlock_proc) ( typedef void (* ma_pa_threaded_mainloop_wait_proc) (ma_pa_threaded_mainloop* m); typedef void (* ma_pa_threaded_mainloop_signal_proc) (ma_pa_threaded_mainloop* m, int wait_for_accept); typedef void (* ma_pa_threaded_mainloop_accept_proc) (ma_pa_threaded_mainloop* m); -typedef int (* ma_pa_threaded_mainloop_get_retval_proc) (ma_pa_threaded_mainloop* m); +typedef int (* ma_pa_threaded_mainloop_get_retval_proc) (const ma_pa_threaded_mainloop* m); typedef ma_pa_mainloop_api* (* ma_pa_threaded_mainloop_get_api_proc) (ma_pa_threaded_mainloop* m); typedef int (* ma_pa_threaded_mainloop_in_thread_proc) (ma_pa_threaded_mainloop* m); typedef void (* ma_pa_threaded_mainloop_set_name_proc) (ma_pa_threaded_mainloop* m, const char* name); @@ -29445,13 +30907,13 @@ typedef void (* ma_pa_context_unref_proc) ( typedef int (* ma_pa_context_connect_proc) (ma_pa_context* c, const char* server, ma_pa_context_flags_t flags, const ma_pa_spawn_api* api); typedef void (* ma_pa_context_disconnect_proc) (ma_pa_context* c); typedef void (* ma_pa_context_set_state_callback_proc) (ma_pa_context* c, ma_pa_context_notify_cb_t cb, void* userdata); -typedef ma_pa_context_state_t (* ma_pa_context_get_state_proc) (ma_pa_context* c); +typedef ma_pa_context_state_t (* ma_pa_context_get_state_proc) (const ma_pa_context* c); typedef ma_pa_operation* (* ma_pa_context_get_sink_info_list_proc) (ma_pa_context* c, ma_pa_sink_info_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_context_get_source_info_list_proc) (ma_pa_context* c, ma_pa_source_info_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_context_get_sink_info_by_name_proc) (ma_pa_context* c, const char* name, ma_pa_sink_info_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_context_get_source_info_by_name_proc)(ma_pa_context* c, const char* name, ma_pa_source_info_cb_t cb, void* userdata); typedef void (* ma_pa_operation_unref_proc) (ma_pa_operation* o); -typedef ma_pa_operation_state_t (* ma_pa_operation_get_state_proc) (ma_pa_operation* o); +typedef ma_pa_operation_state_t (* ma_pa_operation_get_state_proc) (const ma_pa_operation* o); typedef ma_pa_channel_map* (* ma_pa_channel_map_init_extend_proc) (ma_pa_channel_map* m, unsigned channels, ma_pa_channel_map_def_t def); typedef int (* ma_pa_channel_map_valid_proc) (const ma_pa_channel_map* m); typedef int (* ma_pa_channel_map_compatible_proc) (const ma_pa_channel_map* m, const ma_pa_sample_spec* ss); @@ -29460,12 +30922,12 @@ typedef void (* ma_pa_stream_unref_proc) ( typedef int (* ma_pa_stream_connect_playback_proc) (ma_pa_stream* s, const char* dev, const ma_pa_buffer_attr* attr, ma_pa_stream_flags_t flags, const ma_pa_cvolume* volume, ma_pa_stream* sync_stream); typedef int (* ma_pa_stream_connect_record_proc) (ma_pa_stream* s, const char* dev, const ma_pa_buffer_attr* attr, ma_pa_stream_flags_t flags); typedef int (* ma_pa_stream_disconnect_proc) (ma_pa_stream* s); -typedef ma_pa_stream_state_t (* ma_pa_stream_get_state_proc) (ma_pa_stream* s); +typedef ma_pa_stream_state_t (* ma_pa_stream_get_state_proc) (const ma_pa_stream* s); typedef const ma_pa_sample_spec* (* ma_pa_stream_get_sample_spec_proc) (ma_pa_stream* s); typedef const ma_pa_channel_map* (* ma_pa_stream_get_channel_map_proc) (ma_pa_stream* s); typedef const ma_pa_buffer_attr* (* ma_pa_stream_get_buffer_attr_proc) (ma_pa_stream* s); typedef ma_pa_operation* (* ma_pa_stream_set_buffer_attr_proc) (ma_pa_stream* s, const ma_pa_buffer_attr* attr, ma_pa_stream_success_cb_t cb, void* userdata); -typedef const char* (* ma_pa_stream_get_device_name_proc) (ma_pa_stream* s); +typedef const char* (* ma_pa_stream_get_device_name_proc) (const ma_pa_stream* s); typedef void (* ma_pa_stream_set_write_callback_proc) (ma_pa_stream* s, ma_pa_stream_request_cb_t cb, void* userdata); typedef void (* ma_pa_stream_set_read_callback_proc) (ma_pa_stream* s, ma_pa_stream_request_cb_t cb, void* userdata); typedef void (* ma_pa_stream_set_suspended_callback_proc) (ma_pa_stream* s, ma_pa_stream_notify_cb_t cb, void* userdata); @@ -29473,15 +30935,15 @@ typedef void (* ma_pa_stream_set_moved_callback_proc) ( typedef int (* ma_pa_stream_is_suspended_proc) (const ma_pa_stream* s); typedef ma_pa_operation* (* ma_pa_stream_flush_proc) (ma_pa_stream* s, ma_pa_stream_success_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_stream_drain_proc) (ma_pa_stream* s, ma_pa_stream_success_cb_t cb, void* userdata); -typedef int (* ma_pa_stream_is_corked_proc) (ma_pa_stream* s); +typedef int (* ma_pa_stream_is_corked_proc) (const ma_pa_stream* s); typedef ma_pa_operation* (* ma_pa_stream_cork_proc) (ma_pa_stream* s, int b, ma_pa_stream_success_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_stream_trigger_proc) (ma_pa_stream* s, ma_pa_stream_success_cb_t cb, void* userdata); typedef int (* ma_pa_stream_begin_write_proc) (ma_pa_stream* s, void** data, size_t* nbytes); typedef int (* ma_pa_stream_write_proc) (ma_pa_stream* s, const void* data, size_t nbytes, ma_pa_free_cb_t free_cb, int64_t offset, ma_pa_seek_mode_t seek); typedef int (* ma_pa_stream_peek_proc) (ma_pa_stream* s, const void** data, size_t* nbytes); typedef int (* ma_pa_stream_drop_proc) (ma_pa_stream* s); -typedef size_t (* ma_pa_stream_writable_size_proc) (ma_pa_stream* s); -typedef size_t (* ma_pa_stream_readable_size_proc) (ma_pa_stream* s); +typedef size_t (* ma_pa_stream_writable_size_proc) (const ma_pa_stream* s); +typedef size_t (* ma_pa_stream_readable_size_proc) (const ma_pa_stream* s); typedef struct { @@ -29777,9 +31239,10 @@ static ma_result ma_init_pa_mainloop_and_pa_context__pulse(ma_context* pContext, } /* Now we need to connect to the context. Everything is asynchronous so we need to wait for it to connect before returning. */ - result = ma_result_from_pulse(((ma_pa_context_connect_proc)pContext->pulse.pa_context_connect)((ma_pa_context*)pPulseContext, pServerName, (tryAutoSpawn) ? 0 : MA_PA_CONTEXT_NOAUTOSPAWN, NULL)); + result = ma_result_from_pulse(((ma_pa_context_connect_proc)pContext->pulse.pa_context_connect)((ma_pa_context*)pPulseContext, pServerName, (tryAutoSpawn) ? MA_PA_CONTEXT_NOFLAGS : MA_PA_CONTEXT_NOAUTOSPAWN, NULL)); if (result != MA_SUCCESS) { ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[PulseAudio] Failed to connect PulseAudio context."); + ((ma_pa_context_unref_proc)pContext->pulse.pa_context_unref)((ma_pa_context*)(pPulseContext)); ((ma_pa_mainloop_free_proc)pContext->pulse.pa_mainloop_free)((ma_pa_mainloop*)(pMainLoop)); return result; } @@ -29788,6 +31251,7 @@ static ma_result ma_init_pa_mainloop_and_pa_context__pulse(ma_context* pContext, result = ma_wait_for_pa_context_to_connect__pulse(pContext, pMainLoop, pPulseContext); if (result != MA_SUCCESS) { ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[PulseAudio] Waiting for connection failed."); + ((ma_pa_context_unref_proc)pContext->pulse.pa_context_unref)((ma_pa_context*)(pPulseContext)); ((ma_pa_mainloop_free_proc)pContext->pulse.pa_mainloop_free)((ma_pa_mainloop*)(pMainLoop)); return result; } @@ -30510,7 +31974,7 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi const ma_pa_buffer_attr* pActualAttr = NULL; const ma_pa_channel_map* pActualChannelMap = NULL; ma_uint32 iChannel; - ma_pa_stream_flags_t streamFlags; + int streamFlags; MA_ASSERT(pDevice != NULL); MA_ZERO_OBJECT(&pDevice->pulse); @@ -30568,8 +32032,13 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi ss.channels = pDescriptorCapture->channels; } + /* PulseAudio has a maximum channel count of 32. We'll get a crash if this is exceeded. */ + if (ss.channels > 32) { + ss.channels = 32; + } + /* Use a default channel map. */ - ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, pConfig->pulse.channelMap); + ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, (ma_pa_channel_map_def_t)pConfig->pulse.channelMap); /* Use the requested sample rate if one was specified. */ if (pDescriptorCapture->sampleRate != 0) { @@ -30626,7 +32095,7 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi streamFlags |= MA_PA_STREAM_DONT_MOVE; } - error = ((ma_pa_stream_connect_record_proc)pDevice->pContext->pulse.pa_stream_connect_record)((ma_pa_stream*)pDevice->pulse.pStreamCapture, devCapture, &attr, streamFlags); + error = ((ma_pa_stream_connect_record_proc)pDevice->pContext->pulse.pa_stream_connect_record)((ma_pa_stream*)pDevice->pulse.pStreamCapture, devCapture, &attr, (ma_pa_stream_flags_t)streamFlags); if (error != MA_PA_OK) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[PulseAudio] Failed to connect PulseAudio capture stream."); result = ma_result_from_pulse(error); @@ -30720,8 +32189,13 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi ss.channels = pDescriptorPlayback->channels; } + /* PulseAudio has a maximum channel count of 32. We'll get a crash if this is exceeded. */ + if (ss.channels > 32) { + ss.channels = 32; + } + /* Use a default channel map. */ - ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, pConfig->pulse.channelMap); + ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, (ma_pa_channel_map_def_t)pConfig->pulse.channelMap); /* Use the requested sample rate if one was specified. */ @@ -30783,7 +32257,7 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi streamFlags |= MA_PA_STREAM_DONT_MOVE; } - error = ((ma_pa_stream_connect_playback_proc)pDevice->pContext->pulse.pa_stream_connect_playback)((ma_pa_stream*)pDevice->pulse.pStreamPlayback, devPlayback, &attr, streamFlags, NULL, NULL); + error = ((ma_pa_stream_connect_playback_proc)pDevice->pContext->pulse.pa_stream_connect_playback)((ma_pa_stream*)pDevice->pulse.pStreamPlayback, devPlayback, &attr, (ma_pa_stream_flags_t)streamFlags, NULL, NULL); if (error != MA_PA_OK) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[PulseAudio] Failed to connect PulseAudio playback stream."); result = ma_result_from_pulse(error); @@ -31338,6 +32812,7 @@ typedef JackProcessCallback ma_JackProcessCallback; typedef JackBufferSizeCallback ma_JackBufferSizeCallback; typedef JackShutdownCallback ma_JackShutdownCallback; #define MA_JACK_DEFAULT_AUDIO_TYPE JACK_DEFAULT_AUDIO_TYPE +#define ma_JackNullOption JackNullOption #define ma_JackNoStartServer JackNoStartServer #define ma_JackPortIsInput JackPortIsInput #define ma_JackPortIsOutput JackPortIsOutput @@ -31352,6 +32827,7 @@ typedef int (* ma_JackProcessCallback) (ma_jack_nframes_t nframes, void* arg) typedef int (* ma_JackBufferSizeCallback)(ma_jack_nframes_t nframes, void* arg); typedef void (* ma_JackShutdownCallback) (void* arg); #define MA_JACK_DEFAULT_AUDIO_TYPE "32 bit float mono audio" +#define ma_JackNullOption 0 #define ma_JackNoStartServer 1 #define ma_JackPortIsInput 1 #define ma_JackPortIsOutput 2 @@ -31392,7 +32868,7 @@ static ma_result ma_context_open_client__jack(ma_context* pContext, ma_jack_clie maxClientNameSize = ((ma_jack_client_name_size_proc)pContext->jack.jack_client_name_size)(); /* Includes null terminator. */ ma_strncpy_s(clientName, ma_min(sizeof(clientName), maxClientNameSize), (pContext->jack.pClientName != NULL) ? pContext->jack.pClientName : "miniaudio", (size_t)-1); - pClient = ((ma_jack_client_open_proc)pContext->jack.jack_client_open)(clientName, (pContext->jack.tryStartServer) ? 0 : ma_JackNoStartServer, &status, NULL); + pClient = ((ma_jack_client_open_proc)pContext->jack.jack_client_open)(clientName, (pContext->jack.tryStartServer) ? ma_JackNullOption : ma_JackNoStartServer, &status, NULL); if (pClient == NULL) { return MA_FAILED_TO_OPEN_BACKEND_DEVICE; } @@ -36994,7 +38470,7 @@ OSS Backend #define MA_OSS_DEFAULT_DEVICE_NAME "/dev/dsp" -static int ma_open_temp_device__oss() +static int ma_open_temp_device__oss(void) { /* The OSS sample code uses "/dev/mixer" as the device for getting system properties so I'm going to do the same. */ int fd = open("/dev/mixer", O_RDONLY, 0); @@ -37834,25 +39310,30 @@ static void ma_stream_error_callback__aaudio(ma_AAudioStream* pStream, void* pUs (void)error; ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] ERROR CALLBACK: error=%d, AAudioStream_getState()=%d\n", error, ((MA_PFN_AAudioStream_getState)pDevice->pContext->aaudio.AAudioStream_getState)(pStream)); + /* When we get an error, we'll assume that the stream is in an erroneous state and needs to be restarted. From the documentation, we cannot do this from the error callback. Therefore we are going to use an event thread for the AAudio backend to do this cleanly and safely. */ - job = ma_job_init(MA_JOB_TYPE_DEVICE_AAUDIO_REROUTE); - job.data.device.aaudio.reroute.pDevice = pDevice; - - if (pStream == pDevice->aaudio.pStreamCapture) { - job.data.device.aaudio.reroute.deviceType = ma_device_type_capture; + if (ma_atomic_bool32_get(&pDevice->aaudio.isTearingDown)) { + ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Device Disconnected. Tearing down device.\n"); } else { - job.data.device.aaudio.reroute.deviceType = ma_device_type_playback; - } - - result = ma_device_job_thread_post(&pDevice->pContext->aaudio.jobThread, &job); - if (result != MA_SUCCESS) { - ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Device Disconnected. Failed to post job for rerouting.\n"); - return; + job = ma_job_init(MA_JOB_TYPE_DEVICE_AAUDIO_REROUTE); + job.data.device.aaudio.reroute.pDevice = pDevice; + + if (pStream == pDevice->aaudio.pStreamCapture) { + job.data.device.aaudio.reroute.deviceType = ma_device_type_capture; + } else { + job.data.device.aaudio.reroute.deviceType = ma_device_type_playback; + } + + result = ma_device_job_thread_post(&pDevice->pContext->aaudio.jobThread, &job); + if (result != MA_SUCCESS) { + ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Device Disconnected. Failed to post job for rerouting.\n"); + return; + } } } @@ -38169,7 +39650,7 @@ static ma_result ma_close_streams__aaudio(ma_device* pDevice) { MA_ASSERT(pDevice != NULL); - /* When re-routing, streams may have been closed and never re-opened. Hence the extra checks below. */ + /* When rerouting, streams may have been closed and never re-opened. Hence the extra checks below. */ if (pDevice->type == ma_device_type_capture || pDevice->type == ma_device_type_duplex) { ma_close_stream__aaudio(pDevice->pContext, (ma_AAudioStream*)pDevice->aaudio.pStreamCapture); pDevice->aaudio.pStreamCapture = NULL; @@ -38186,6 +39667,12 @@ static ma_result ma_device_uninit__aaudio(ma_device* pDevice) { MA_ASSERT(pDevice != NULL); + /* + Note: Closing the streams may cause a timeout error, which would then trigger rerouting in our error callback. + We must not schedule a reroute when device is getting destroyed. + */ + ma_atomic_bool32_set(&pDevice->aaudio.isTearingDown, MA_TRUE); + /* Wait for any rerouting to finish before attempting to close the streams. */ ma_mutex_lock(&pDevice->aaudio.rerouteLock); { @@ -38193,7 +39680,7 @@ static ma_result ma_device_uninit__aaudio(ma_device* pDevice) } ma_mutex_unlock(&pDevice->aaudio.rerouteLock); - /* Destroy re-routing lock. */ + /* Destroy rerouting lock. */ ma_mutex_uninit(&pDevice->aaudio.rerouteLock); return MA_SUCCESS; @@ -38429,17 +39916,22 @@ static ma_result ma_device_stop__aaudio(ma_device* pDevice) static ma_result ma_device_reinit__aaudio(ma_device* pDevice, ma_device_type deviceType) { + const ma_int32 maxAttempts = 4; /* Reasonable retry limit. */ + ma_result result; - int32_t retries = 0; + ma_int32 iAttempt; MA_ASSERT(pDevice != NULL); - /* - TODO: Stop retrying if main thread is about to uninit device. - */ - ma_mutex_lock(&pDevice->aaudio.rerouteLock); - { -error_disconnected: + /* We got disconnected! Retry a few times, until we find a connected device! */ + iAttempt = 0; + while (iAttempt++ < maxAttempts) { + /* Device tearing down? No need to reroute! */ + if (ma_atomic_bool32_get(&pDevice->aaudio.isTearingDown)) { + result = MA_SUCCESS; /* Caller should continue as normal. */ + break; + } + /* The first thing to do is close the streams. */ ma_close_streams__aaudio(pDevice); @@ -38495,14 +39987,16 @@ static ma_result ma_device_reinit__aaudio(ma_device* pDevice, ma_device_type dev result = ma_device_init_streams__aaudio(pDevice, &deviceConfig, &descriptorPlayback, &descriptorCapture); if (result != MA_SUCCESS) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_WARNING, "[AAudio] Failed to create stream after route change."); - goto done; + /* Reroute failed! */ + break; } result = ma_device_post_init(pDevice, deviceType, &descriptorPlayback, &descriptorCapture); if (result != MA_SUCCESS) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_WARNING, "[AAudio] Failed to initialize device after route change."); ma_close_streams__aaudio(pDevice); - goto done; + /* Reroute failed! */ + break; } /* We'll only ever do this in response to a reroute. */ @@ -38513,26 +40007,23 @@ static ma_result ma_device_reinit__aaudio(ma_device* pDevice, ma_device_type dev if (pDevice->aaudio.noAutoStartAfterReroute == MA_FALSE) { result = ma_device_start__aaudio(pDevice); if (result != MA_SUCCESS) { - /* We got disconnected! Retry a few times, until we find a connected device! */ - retries += 1; - if (retries <= 3) { - ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change, retrying(%d)", retries); - goto error_disconnected; + if (iAttempt < maxAttempts) { + ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change, retrying(%d)", iAttempt); + } else { + ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change, giving up."); } - ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change."); - goto done; } } else { - ma_device_stop(pDevice); /* Do a full device stop so we set internal state correctly. */ + ma_device_stop(pDevice); /* Do a full device stop so we set internal state correctly. */ } } - - result = MA_SUCCESS; - } -done: - /* Re-routing done */ - ma_mutex_unlock(&pDevice->aaudio.rerouteLock); + if (result == MA_SUCCESS) { + /* Reroute successful! */ + break; + } + } + return result; } @@ -38698,7 +40189,7 @@ static ma_result ma_context_init__aaudio(ma_context* pContext, const ma_context_ static ma_result ma_job_process__device__aaudio_reroute(ma_job* pJob) { - ma_result result; + ma_result result = MA_SUCCESS; ma_device* pDevice; MA_ASSERT(pJob != NULL); @@ -38706,19 +40197,22 @@ static ma_result ma_job_process__device__aaudio_reroute(ma_job* pJob) pDevice = (ma_device*)pJob->data.device.aaudio.reroute.pDevice; MA_ASSERT(pDevice != NULL); - /* Here is where we need to reroute the device. To do this we need to uninitialize the stream and reinitialize it. */ - result = ma_device_reinit__aaudio(pDevice, (ma_device_type)pJob->data.device.aaudio.reroute.deviceType); - if (result != MA_SUCCESS) { - /* - Getting here means we failed to reroute the device. The best thing I can think of here is to - just stop the device. - */ - ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[AAudio] Stopping device due to reroute failure."); - ma_device_stop(pDevice); - return result; + ma_mutex_lock(&pDevice->aaudio.rerouteLock); + { + /* Here is where we need to reroute the device. To do this we need to uninitialize the stream and reinitialize it. */ + result = ma_device_reinit__aaudio(pDevice, (ma_device_type)pJob->data.device.aaudio.reroute.deviceType); + if (result != MA_SUCCESS) { + /* + Getting here means we failed to reroute the device. The best thing I can think of here is to + just stop the device. + */ + ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[AAudio] Stopping device due to reroute failure."); + ma_device_stop(pDevice); + } } + ma_mutex_unlock(&pDevice->aaudio.rerouteLock); - return MA_SUCCESS; + return result; } #else /* Getting here means there is no AAudio backend so we need a no-op job implementation. */ @@ -40269,8 +41763,11 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const frameCount = pDevice->capture.internalPeriodSizeInFrames; } + /* + If this is called by the device has not yet been started we need to return early, making sure we output silence to + the output buffer. + */ if (ma_device_get_state(pDevice) != ma_device_state_started) { - /* Fill the output buffer with zero to avoid a noise sound */ for (int i = 0; i < outputCount; i += 1) { MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float)); } @@ -40292,7 +41789,9 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const if (outputCount > 0) { /* If it's a capture-only device, we'll need to output silence. */ if (pDevice->type == ma_device_type_capture) { - MA_ZERO_MEMORY(pOutputs[0].data, frameCount * pDevice->playback.internalChannels * sizeof(float)); + for (int i = 0; i < outputCount; i += 1) { + MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float)); + } } else { ma_device_process_pcm_frames_playback__webaudio(pDevice, frameCount, pDevice->webaudio.pIntermediaryBuffer); @@ -40302,6 +41801,14 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const pOutputs[0].data[frameCount*iChannel + iFrame] = pDevice->webaudio.pIntermediaryBuffer[iFrame*pDevice->playback.internalChannels + iChannel]; } } + + /* + Just above we output data to the first output buffer. Here we just make sure we're putting silence into any + remaining output buffers. + */ + for (int i = 1; i < outputCount; i += 1) { /* <-- Note that the counter starts at 1 instead of 0. */ + MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float)); + } } } @@ -40782,8 +42289,8 @@ static ma_result ma_context_uninit__webaudio(ma_context* pContext) /* Remove the global miniaudio object from window if there are no more references to it. */ EM_ASM({ if (typeof(window.miniaudio) !== 'undefined') { - miniaudio.unlock_event_types.map(function(event_type) { - document.removeEventListener(event_type, miniaudio.unlock, true); + window.miniaudio.unlock_event_types.map(function(event_type) { + document.removeEventListener(event_type, window.miniaudio.unlock, true); }); window.miniaudio.referenceCount -= 1; @@ -41236,13 +42743,13 @@ MA_API ma_result ma_device_post_init(ma_device* pDevice, ma_device_type deviceTy static ma_thread_result MA_THREADCALL ma_worker_thread(void* pData) { ma_device* pDevice = (ma_device*)pData; -#ifdef MA_WIN32 +#if defined(MA_WIN32) && !defined(MA_XBOX) HRESULT CoInitializeResult; #endif MA_ASSERT(pDevice != NULL); -#ifdef MA_WIN32 +#if defined(MA_WIN32) && !defined(MA_XBOX) CoInitializeResult = ma_CoInitializeEx(pDevice->pContext, NULL, MA_COINIT_VALUE); #endif @@ -41333,8 +42840,8 @@ static ma_thread_result MA_THREADCALL ma_worker_thread(void* pData) ma_event_signal(&pDevice->stopEvent); } -#ifdef MA_WIN32 - if (CoInitializeResult == S_OK) { +#if defined(MA_WIN32) && !defined(MA_XBOX) + if (CoInitializeResult == S_OK || CoInitializeResult == S_FALSE) { ma_CoUninitialize(pDevice->pContext); } #endif @@ -41358,67 +42865,92 @@ static ma_bool32 ma_device__is_initialized(ma_device* pDevice) static ma_result ma_context_uninit_backend_apis__win32(ma_context* pContext) { /* For some reason UWP complains when CoUninitialize() is called. I'm just not going to call it on UWP. */ -#if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) - if (pContext->win32.CoInitializeResult == S_OK) { - ma_CoUninitialize(pContext); - } + #if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) + { + /* TODO: Remove this once the new single threaded backend system is in place in 0.12. */ + #if !defined(MA_XBOX) + { + if (pContext->win32.CoInitializeResult == S_OK || pContext->win32.CoInitializeResult == S_FALSE) { + ma_CoUninitialize(pContext); /* TODO: Remove this once the new single threaded backend system is in place in 0.12. */ + } + } + #endif - #if defined(MA_WIN32_DESKTOP) - ma_dlclose(ma_context_get_log(pContext), pContext->win32.hUser32DLL); - ma_dlclose(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL); - #endif + #if defined(MA_WIN32_DESKTOP) + ma_dlclose(ma_context_get_log(pContext), pContext->win32.hUser32DLL); + ma_dlclose(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL); + #endif - ma_dlclose(ma_context_get_log(pContext), pContext->win32.hOle32DLL); -#else - (void)pContext; -#endif + ma_dlclose(ma_context_get_log(pContext), pContext->win32.hOle32DLL); + } + #else + { + (void)pContext; + } + #endif return MA_SUCCESS; } static ma_result ma_context_init_backend_apis__win32(ma_context* pContext) { -#if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) - #if defined(MA_WIN32_DESKTOP) - /* User32.dll */ - pContext->win32.hUser32DLL = ma_dlopen(ma_context_get_log(pContext), "user32.dll"); - if (pContext->win32.hUser32DLL == NULL) { - return MA_FAILED_TO_INIT_BACKEND; - } + /* + TODO: Reassess all of this stuff and move everything to the relevant backends. For example, I think + GetForegroundWindow() and GetDesktopWindow() are only used by the DirectSound backend. + */ + #if (defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK)) && !defined(MA_XBOX) + { + #if defined(MA_WIN32_DESKTOP) + { + /* User32.dll */ + pContext->win32.hUser32DLL = ma_dlopen(ma_context_get_log(pContext), "user32.dll"); + if (pContext->win32.hUser32DLL == NULL) { + return MA_FAILED_TO_INIT_BACKEND; + } + + pContext->win32.GetForegroundWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetForegroundWindow"); + pContext->win32.GetDesktopWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetDesktopWindow"); - pContext->win32.GetForegroundWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetForegroundWindow"); - pContext->win32.GetDesktopWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetDesktopWindow"); + /* Advapi32.dll */ + pContext->win32.hAdvapi32DLL = ma_dlopen(ma_context_get_log(pContext), "advapi32.dll"); + if (pContext->win32.hAdvapi32DLL == NULL) { + return MA_FAILED_TO_INIT_BACKEND; + } + + pContext->win32.RegOpenKeyExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegOpenKeyExA"); + pContext->win32.RegCloseKey = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegCloseKey"); + pContext->win32.RegQueryValueExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegQueryValueExA"); + } + #endif - /* Advapi32.dll */ - pContext->win32.hAdvapi32DLL = ma_dlopen(ma_context_get_log(pContext), "advapi32.dll"); - if (pContext->win32.hAdvapi32DLL == NULL) { + /* Ole32.dll */ + pContext->win32.hOle32DLL = ma_dlopen(ma_context_get_log(pContext), "ole32.dll"); + if (pContext->win32.hOle32DLL == NULL) { return MA_FAILED_TO_INIT_BACKEND; } - pContext->win32.RegOpenKeyExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegOpenKeyExA"); - pContext->win32.RegCloseKey = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegCloseKey"); - pContext->win32.RegQueryValueExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegQueryValueExA"); + pContext->win32.CoInitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitialize"); + pContext->win32.CoInitializeEx = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitializeEx"); + pContext->win32.CoUninitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoUninitialize"); + pContext->win32.CoCreateInstance = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoCreateInstance"); + pContext->win32.CoTaskMemFree = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoTaskMemFree"); + pContext->win32.PropVariantClear = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "PropVariantClear"); + pContext->win32.StringFromGUID2 = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "StringFromGUID2"); + } + #else + { + (void)pContext; /* Unused. */ + } #endif - /* Ole32.dll */ - pContext->win32.hOle32DLL = ma_dlopen(ma_context_get_log(pContext), "ole32.dll"); - if (pContext->win32.hOle32DLL == NULL) { - return MA_FAILED_TO_INIT_BACKEND; + /* TODO: Remove this once the new single threaded backend system is in place in 0.12. */ + #if !defined(MA_XBOX) + { + pContext->win32.CoInitializeResult = ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); } + #endif - pContext->win32.CoInitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitialize"); - pContext->win32.CoInitializeEx = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitializeEx"); - pContext->win32.CoUninitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoUninitialize"); - pContext->win32.CoCreateInstance = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoCreateInstance"); - pContext->win32.CoTaskMemFree = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoTaskMemFree"); - pContext->win32.PropVariantClear = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "PropVariantClear"); - pContext->win32.StringFromGUID2 = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "StringFromGUID2"); -#else - (void)pContext; /* Unused. */ -#endif - - pContext->win32.CoInitializeResult = ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); return MA_SUCCESS; } #else @@ -44016,7 +45548,7 @@ static MA_INLINE void ma_pcm_s16_to_s32__reference(void* dst, const void* src, m ma_uint64 i; for (i = 0; i < count; i += 1) { - dst_s32[i] = src_s16[i] << 16; + dst_s32[i] = (ma_int32)src_s16[i] << 16; } (void)ditherMode; @@ -49347,15 +50879,15 @@ static /*__attribute__((noinline))*/ ma_result ma_gainer_process_pcm_frames_inte a += d; } } + + pFramesOut = ma_offset_ptr(pFramesOut, interpolatedFrameCount * sizeof(float)); + pFramesIn = ma_offset_ptr(pFramesIn, interpolatedFrameCount * sizeof(float)); } + frameCount -= interpolatedFrameCount; + /* Make sure the timer is updated. */ pGainer->t = (ma_uint32)ma_min(pGainer->t + interpolatedFrameCount, pGainer->config.smoothTimeInFrames); - - /* Adjust our arguments so the next part can work normally. */ - frameCount -= interpolatedFrameCount; - pFramesOut = ma_offset_ptr(pFramesOut, interpolatedFrameCount * sizeof(float)); - pFramesIn = ma_offset_ptr(pFramesIn, interpolatedFrameCount * sizeof(float)); } /* All we need to do here is apply the new gains using an optimized path. */ @@ -50783,13 +52315,16 @@ static float ma_calculate_angular_gain(ma_vec3f dirA, ma_vec3f dirB, float coneI MA_API ma_result ma_spatializer_process_pcm_frames(ma_spatializer* pSpatializer, ma_spatializer_listener* pListener, void* pFramesOut, const void* pFramesIn, ma_uint64 frameCount) { - ma_channel* pChannelMapIn = pSpatializer->pChannelMapIn; - ma_channel* pChannelMapOut = pListener->config.pChannelMapOut; + ma_channel* pChannelMapIn; + ma_channel* pChannelMapOut; - if (pSpatializer == NULL) { + if (pSpatializer == NULL || pListener == NULL) { return MA_INVALID_ARGS; } + pChannelMapIn = pSpatializer->pChannelMapIn; + pChannelMapOut = pListener->config.pChannelMapOut; + /* If we're not spatializing we need to run an optimized path. */ if (ma_atomic_load_i32(&pSpatializer->attenuationModel) == ma_attenuation_model_none) { if (ma_spatializer_listener_is_enabled(pListener)) { @@ -50834,23 +52369,17 @@ MA_API ma_result ma_spatializer_process_pcm_frames(ma_spatializer* pSpatializer, We'll need the listener velocity for doppler pitch calculations. The speed of sound is defined by the listener, so we'll grab that here too. */ - if (pListener != NULL) { - listenerVel = ma_spatializer_listener_get_velocity(pListener); - speedOfSound = pListener->config.speedOfSound; - } else { - listenerVel = ma_vec3f_init_3f(0, 0, 0); - speedOfSound = MA_DEFAULT_SPEED_OF_SOUND; - } + listenerVel = ma_spatializer_listener_get_velocity(pListener); + speedOfSound = pListener->config.speedOfSound; - if (pListener == NULL || ma_spatializer_get_positioning(pSpatializer) == ma_positioning_relative) { - /* There's no listener or we're using relative positioning. */ + if (ma_spatializer_get_positioning(pSpatializer) == ma_positioning_relative) { relativePos = ma_spatializer_get_position(pSpatializer); relativeDir = ma_spatializer_get_direction(pSpatializer); } else { /* - We've found a listener and we're using absolute positioning. We need to transform the - sound's position and direction so that it's relative to listener. Later on we'll use - this for determining the factors to apply to each channel to apply the panning effect. + We're using absolute positioning. We need to transform the sound's position and + direction so that it's relative to listener. Later on we'll use this for determining + the factors to apply to each channel to apply the panning effect. */ ma_spatializer_get_relative_position_and_direction(pSpatializer, pListener, &relativePos, &relativeDir); } @@ -52885,7 +54414,7 @@ static ma_bool32 ma_is_spatial_channel_position(ma_channel channelPosition) return MA_FALSE; } - if (channelPosition >= MA_CHANNEL_AUX_0 && channelPosition <= MA_CHANNEL_AUX_31) { + if (channelPosition >= MA_CHANNEL_AUX_0) { return MA_FALSE; } @@ -56408,8 +57937,12 @@ MA_API size_t ma_channel_map_to_string(const ma_channel* pChannelMap, ma_uint32 } /* Null terminate. Don't increment the length here. */ - if (pBufferOut != NULL && bufferCap > len + 1) { - pBufferOut[len] = '\0'; + if (pBufferOut != NULL) { + if (bufferCap > len) { + pBufferOut[len] = '\0'; + } else if (bufferCap > 0) { + pBufferOut[bufferCap - 1] = '\0'; + } } return len; @@ -56620,7 +58153,7 @@ MA_API ma_result ma_rb_init_ex(size_t subbufferSizeInBytes, size_t subbufferCoun Here is where we allocate our own buffer. We always want to align this to MA_SIMD_ALIGNMENT for future SIMD optimization opportunity. To do this we need to make sure the stride is a multiple of MA_SIMD_ALIGNMENT. */ - pRB->subbufferStrideInBytes = (pRB->subbufferSizeInBytes + (MA_SIMD_ALIGNMENT-1)) & ~MA_SIMD_ALIGNMENT; + pRB->subbufferStrideInBytes = ma_align(pRB->subbufferSizeInBytes, MA_SIMD_ALIGNMENT); bufferSizeInBytes = (size_t)pRB->subbufferCount*pRB->subbufferStrideInBytes; pRB->pBuffer = ma_aligned_malloc(bufferSizeInBytes, MA_SIMD_ALIGNMENT, &pRB->allocationCallbacks); @@ -59515,7 +61048,7 @@ MA_API ma_result ma_vfs_info(ma_vfs* pVFS, ma_vfs_file file, ma_file_info* pInfo } -#if !defined(MA_USE_WIN32_FILEIO) && (defined(MA_WIN32) && defined(MA_WIN32_DESKTOP) && !defined(MA_NO_WIN32_FILEIO) && !defined(MA_POSIX)) +#if !defined(MA_USE_WIN32_FILEIO) && (defined(MA_WIN32) && (defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_NXDK)) && !defined(MA_NO_WIN32_FILEIO) && !defined(MA_POSIX)) #define MA_USE_WIN32_FILEIO #endif @@ -59592,25 +61125,34 @@ static ma_result ma_default_vfs_open__win32(ma_vfs* pVFS, const char* pFilePath, static ma_result ma_default_vfs_open_w__win32(ma_vfs* pVFS, const wchar_t* pFilePath, ma_uint32 openMode, ma_vfs_file* pFile) { - HANDLE hFile; - DWORD dwDesiredAccess; - DWORD dwShareMode; - DWORD dwCreationDisposition; + #if !defined(MA_XBOX_NXDK) + { + HANDLE hFile; + DWORD dwDesiredAccess; + DWORD dwShareMode; + DWORD dwCreationDisposition; - (void)pVFS; + (void)pVFS; - /* Load some Win32 symbols dynamically so we can dynamically check for the existence of SetFilePointerEx. */ - ma_win32_fileio_init(); + /* Load some Win32 symbols dynamically so we can dynamically check for the existence of SetFilePointerEx. */ + ma_win32_fileio_init(); - ma_default_vfs__get_open_settings_win32(openMode, &dwDesiredAccess, &dwShareMode, &dwCreationDisposition); + ma_default_vfs__get_open_settings_win32(openMode, &dwDesiredAccess, &dwShareMode, &dwCreationDisposition); - hFile = CreateFileW(pFilePath, dwDesiredAccess, dwShareMode, NULL, dwCreationDisposition, FILE_ATTRIBUTE_NORMAL, NULL); - if (hFile == INVALID_HANDLE_VALUE) { - return ma_result_from_GetLastError(GetLastError()); - } + hFile = CreateFileW(pFilePath, dwDesiredAccess, dwShareMode, NULL, dwCreationDisposition, FILE_ATTRIBUTE_NORMAL, NULL); + if (hFile == INVALID_HANDLE_VALUE) { + return ma_result_from_GetLastError(GetLastError()); + } - *pFile = hFile; - return MA_SUCCESS; + *pFile = hFile; + return MA_SUCCESS; + } + #else + { + /* No CreateFileW() available. */ + return MA_NOT_IMPLEMENTED; + } + #endif } static ma_result ma_default_vfs_close__win32(ma_vfs* pVFS, ma_vfs_file file) @@ -59781,19 +61323,28 @@ static ma_result ma_default_vfs_tell__win32(ma_vfs* pVFS, ma_vfs_file file, ma_i static ma_result ma_default_vfs_info__win32(ma_vfs* pVFS, ma_vfs_file file, ma_file_info* pInfo) { - BY_HANDLE_FILE_INFORMATION fi; - BOOL result; - (void)pVFS; - result = GetFileInformationByHandle((HANDLE)file, &fi); - if (result == 0) { - return ma_result_from_GetLastError(GetLastError()); - } + #if !defined(MA_XBOX_NXDK) + { + BY_HANDLE_FILE_INFORMATION fi; + BOOL result; - pInfo->sizeInBytes = ((ma_uint64)fi.nFileSizeHigh << 32) | ((ma_uint64)fi.nFileSizeLow); + result = GetFileInformationByHandle((HANDLE)file, &fi); + if (result == 0) { + return ma_result_from_GetLastError(GetLastError()); + } - return MA_SUCCESS; + pInfo->sizeInBytes = ((ma_uint64)fi.nFileSizeHigh << 32) | ((ma_uint64)fi.nFileSizeLow); + + return MA_SUCCESS; + } + #else + { + /* GetFileInformationByHandle() is unavailable. */ + return MA_NOT_IMPLEMENTED; + } + #endif } #else static ma_result ma_default_vfs_open__stdio(ma_vfs* pVFS, const char* pFilePath, ma_uint32 openMode, ma_vfs_file* pFile) @@ -60131,6 +61682,8 @@ static ma_result ma_default_vfs_tell(ma_vfs* pVFS, ma_vfs_file file, ma_int64* p static ma_result ma_default_vfs_info(ma_vfs* pVFS, ma_vfs_file file, ma_file_info* pInfo) { + ma_result result; + if (pInfo == NULL) { return MA_INVALID_ARGS; } @@ -60142,10 +61695,42 @@ static ma_result ma_default_vfs_info(ma_vfs* pVFS, ma_vfs_file file, ma_file_inf } #if defined(MA_USE_WIN32_FILEIO) - return ma_default_vfs_info__win32(pVFS, file, pInfo); + result = ma_default_vfs_info__win32(pVFS, file, pInfo); #else - return ma_default_vfs_info__stdio(pVFS, file, pInfo); + result = ma_default_vfs_info__stdio(pVFS, file, pInfo); #endif + + if (result == MA_NOT_IMPLEMENTED) { + /* Not implemented. Fall back to seek/tell/seek. */ + ma_int64 cursor; + ma_int64 sizeInBytes; + + result = ma_default_vfs_tell(pVFS, file, &cursor); + if (result != MA_SUCCESS) { + return result; + } + + result = ma_default_vfs_seek(pVFS, file, 0, ma_seek_origin_end); + if (result != MA_SUCCESS) { + return result; + } + + result = ma_default_vfs_tell(pVFS, file, &sizeInBytes); + if (result != MA_SUCCESS) { + return result; + } + + pInfo->sizeInBytes = sizeInBytes; + + result = ma_default_vfs_seek(pVFS, file, cursor, ma_seek_origin_start); + if (result != MA_SUCCESS) { + return result; + } + + MA_ASSERT(result == MA_SUCCESS); + } + + return result; } @@ -60324,6 +61909,8 @@ Decoding and Encoding Headers. These are auto-generated from a tool. **************************************************************************************************************************************************************/ #if !defined(MA_NO_WAV) && (!defined(MA_NO_DECODING) || !defined(MA_NO_ENCODING)) +#define MA_HAS_WAV + /* dr_wav_h begin */ #ifndef ma_dr_wav_h #define ma_dr_wav_h @@ -60333,8 +61920,8 @@ extern "C" { #define MA_DR_WAV_STRINGIFY(x) #x #define MA_DR_WAV_XSTRINGIFY(x) MA_DR_WAV_STRINGIFY(x) #define MA_DR_WAV_VERSION_MAJOR 0 -#define MA_DR_WAV_VERSION_MINOR 13 -#define MA_DR_WAV_VERSION_REVISION 18 +#define MA_DR_WAV_VERSION_MINOR 14 +#define MA_DR_WAV_VERSION_REVISION 4 #define MA_DR_WAV_VERSION_STRING MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MAJOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MINOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_REVISION) #include #define MA_DR_WAVE_FORMAT_PCM 0x1 @@ -60350,8 +61937,9 @@ MA_API void ma_dr_wav_version(ma_uint32* pMajor, ma_uint32* pMinor, ma_uint32* p MA_API const char* ma_dr_wav_version_string(void); typedef enum { - ma_dr_wav_seek_origin_start, - ma_dr_wav_seek_origin_current + MA_DR_WAV_SEEK_SET, + MA_DR_WAV_SEEK_CUR, + MA_DR_WAV_SEEK_END } ma_dr_wav_seek_origin; typedef enum { @@ -60388,6 +61976,7 @@ MA_API ma_uint16 ma_dr_wav_fmt_get_format(const ma_dr_wav_fmt* pFMT); typedef size_t (* ma_dr_wav_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead); typedef size_t (* ma_dr_wav_write_proc)(void* pUserData, const void* pData, size_t bytesToWrite); typedef ma_bool32 (* ma_dr_wav_seek_proc)(void* pUserData, int offset, ma_dr_wav_seek_origin origin); +typedef ma_bool32 (* ma_dr_wav_tell_proc)(void* pUserData, ma_int64* pCursor); typedef ma_uint64 (* ma_dr_wav_chunk_proc)(void* pChunkUserData, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pReadSeekUserData, const ma_dr_wav_chunk_header* pChunkHeader, ma_dr_wav_container container, const ma_dr_wav_fmt* pFMT); typedef struct { @@ -60432,6 +62021,11 @@ typedef enum ma_dr_wav_metadata_type_list_info_genre = 1 << 15, ma_dr_wav_metadata_type_list_info_album = 1 << 16, ma_dr_wav_metadata_type_list_info_tracknumber = 1 << 17, + ma_dr_wav_metadata_type_list_info_location = 1 << 18, + ma_dr_wav_metadata_type_list_info_organization = 1 << 19, + ma_dr_wav_metadata_type_list_info_keywords = 1 << 20, + ma_dr_wav_metadata_type_list_info_medium = 1 << 21, + ma_dr_wav_metadata_type_list_info_description = 1 << 22, ma_dr_wav_metadata_type_list_all_info_strings = ma_dr_wav_metadata_type_list_info_software | ma_dr_wav_metadata_type_list_info_copyright | ma_dr_wav_metadata_type_list_info_title @@ -60440,7 +62034,12 @@ typedef enum | ma_dr_wav_metadata_type_list_info_date | ma_dr_wav_metadata_type_list_info_genre | ma_dr_wav_metadata_type_list_info_album - | ma_dr_wav_metadata_type_list_info_tracknumber, + | ma_dr_wav_metadata_type_list_info_tracknumber + | ma_dr_wav_metadata_type_list_info_location + | ma_dr_wav_metadata_type_list_info_organization + | ma_dr_wav_metadata_type_list_info_keywords + | ma_dr_wav_metadata_type_list_info_medium + | ma_dr_wav_metadata_type_list_info_description, ma_dr_wav_metadata_type_list_all_adtl = ma_dr_wav_metadata_type_list_label | ma_dr_wav_metadata_type_list_note | ma_dr_wav_metadata_type_list_labelled_cue_region, @@ -60457,8 +62056,8 @@ typedef struct { ma_uint32 cuePointId; ma_uint32 type; - ma_uint32 firstSampleByteOffset; - ma_uint32 lastSampleByteOffset; + ma_uint32 firstSampleOffset; + ma_uint32 lastSampleOffset; ma_uint32 sampleFraction; ma_uint32 playCount; } ma_dr_wav_smpl_loop; @@ -60493,7 +62092,7 @@ typedef struct ma_uint8 dataChunkId[4]; ma_uint32 chunkStart; ma_uint32 blockStart; - ma_uint32 sampleByteOffset; + ma_uint32 sampleOffset; } ma_dr_wav_cue_point; typedef struct { @@ -60595,6 +62194,7 @@ typedef struct ma_dr_wav_read_proc onRead; ma_dr_wav_write_proc onWrite; ma_dr_wav_seek_proc onSeek; + ma_dr_wav_tell_proc onTell; void* pUserData; ma_allocation_callbacks allocationCallbacks; ma_dr_wav_container container; @@ -60637,9 +62237,9 @@ typedef struct ma_bool8 isUnsigned; } aiff; } ma_dr_wav; -MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, ma_dr_wav_chunk_proc onChunk, void* pReadSeekTellUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_write(ma_dr_wav* pWav, const ma_dr_wav_data_format* pFormat, ma_dr_wav_write_proc onWrite, ma_dr_wav_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_write_sequential(ma_dr_wav* pWav, const ma_dr_wav_data_format* pFormat, ma_uint64 totalSampleCount, ma_dr_wav_write_proc onWrite, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_write_sequential_pcm_frames(ma_dr_wav* pWav, const ma_dr_wav_data_format* pFormat, ma_uint64 totalPCMFrameCount, ma_dr_wav_write_proc onWrite, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); @@ -60711,9 +62311,9 @@ MA_API ma_bool32 ma_dr_wav_init_memory_write(ma_dr_wav* pWav, void** ppData, siz MA_API ma_bool32 ma_dr_wav_init_memory_write_sequential(ma_dr_wav* pWav, void** ppData, size_t* pDataSize, const ma_dr_wav_data_format* pFormat, ma_uint64 totalSampleCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_memory_write_sequential_pcm_frames(ma_dr_wav* pWav, void** ppData, size_t* pDataSize, const ma_dr_wav_data_format* pFormat, ma_uint64 totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_WAV_NO_CONVERSION_API -MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_WAV_NO_STDIO MA_API ma_int16* ma_dr_wav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); MA_API float* ma_dr_wav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); @@ -60744,6 +62344,8 @@ MA_API ma_bool32 ma_dr_wav_fourcc_equal(const ma_uint8* a, const char* b); #endif /* MA_NO_WAV */ #if !defined(MA_NO_FLAC) && !defined(MA_NO_DECODING) +#define MA_HAS_FLAC + /* dr_flac_h begin */ #ifndef ma_dr_flac_h #define ma_dr_flac_h @@ -60753,8 +62355,8 @@ extern "C" { #define MA_DR_FLAC_STRINGIFY(x) #x #define MA_DR_FLAC_XSTRINGIFY(x) MA_DR_FLAC_STRINGIFY(x) #define MA_DR_FLAC_VERSION_MAJOR 0 -#define MA_DR_FLAC_VERSION_MINOR 12 -#define MA_DR_FLAC_VERSION_REVISION 43 +#define MA_DR_FLAC_VERSION_MINOR 13 +#define MA_DR_FLAC_VERSION_REVISION 3 #define MA_DR_FLAC_VERSION_STRING MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_MAJOR) "." MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_MINOR) "." MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_REVISION) #include #if defined(_MSC_VER) && _MSC_VER >= 1700 @@ -60817,8 +62419,9 @@ typedef enum } ma_dr_flac_container; typedef enum { - ma_dr_flac_seek_origin_start, - ma_dr_flac_seek_origin_current + MA_DR_FLAC_SEEK_SET, + MA_DR_FLAC_SEEK_CUR, + MA_DR_FLAC_SEEK_END } ma_dr_flac_seek_origin; typedef struct { @@ -60841,8 +62444,9 @@ typedef struct typedef struct { ma_uint32 type; - const void* pRawData; ma_uint32 rawDataSize; + ma_uint64 rawDataOffset; + const void* pRawData; union { ma_dr_flac_streaminfo streaminfo; @@ -60888,12 +62492,14 @@ typedef struct ma_uint32 colorDepth; ma_uint32 indexColorCount; ma_uint32 pictureDataSize; + ma_uint64 pictureDataOffset; const ma_uint8* pPictureData; } picture; } data; } ma_dr_flac_metadata; typedef size_t (* ma_dr_flac_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead); typedef ma_bool32 (* ma_dr_flac_seek_proc)(void* pUserData, int offset, ma_dr_flac_seek_origin origin); +typedef ma_bool32 (* ma_dr_flac_tell_proc)(void* pUserData, ma_int64* pCursor); typedef void (* ma_dr_flac_meta_proc)(void* pUserData, ma_dr_flac_metadata* pMetadata); typedef struct { @@ -60905,6 +62511,7 @@ typedef struct { ma_dr_flac_read_proc onRead; ma_dr_flac_seek_proc onSeek; + ma_dr_flac_tell_proc onTell; void* pUserData; size_t unalignedByteCount; ma_dr_flac_cache_t unalignedCache; @@ -60964,10 +62571,10 @@ typedef struct ma_dr_flac_bs bs; ma_uint8 pExtraData[1]; } ma_dr_flac; -MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); MA_API void ma_dr_flac_close(ma_dr_flac* pFlac); MA_API ma_uint64 ma_dr_flac_read_pcm_frames_s32(ma_dr_flac* pFlac, ma_uint64 framesToRead, ma_int32* pBufferOut); MA_API ma_uint64 ma_dr_flac_read_pcm_frames_s16(ma_dr_flac* pFlac, ma_uint64 framesToRead, ma_int16* pBufferOut); @@ -60981,9 +62588,9 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_with_metadata_w(const wchar_t* pFileName #endif MA_API ma_dr_flac* ma_dr_flac_open_memory(const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_dr_flac* ma_dr_flac_open_memory_with_metadata(const void* pData, size_t dataSize, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_FLAC_NO_STDIO MA_API ma_int32* ma_dr_flac_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_int16* ma_dr_flac_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); @@ -61031,6 +62638,14 @@ MA_API ma_bool32 ma_dr_flac_next_cuesheet_track(ma_dr_flac_cuesheet_track_iterat #endif /* MA_NO_FLAC */ #if !defined(MA_NO_MP3) && !defined(MA_NO_DECODING) +#define MA_HAS_MP3 + +#ifndef MA_DR_MP3_NO_SIMD + #if (defined(MA_NO_NEON) && defined(MA_ARM)) || (defined(MA_NO_SSE2) && (defined(MA_X86) || defined(MA_X64))) + #define MA_DR_MP3_NO_SIMD + #endif +#endif + /* dr_mp3_h begin */ #ifndef ma_dr_mp3_h #define ma_dr_mp3_h @@ -61040,31 +62655,57 @@ extern "C" { #define MA_DR_MP3_STRINGIFY(x) #x #define MA_DR_MP3_XSTRINGIFY(x) MA_DR_MP3_STRINGIFY(x) #define MA_DR_MP3_VERSION_MAJOR 0 -#define MA_DR_MP3_VERSION_MINOR 6 -#define MA_DR_MP3_VERSION_REVISION 40 +#define MA_DR_MP3_VERSION_MINOR 7 +#define MA_DR_MP3_VERSION_REVISION 3 #define MA_DR_MP3_VERSION_STRING MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_MAJOR) "." MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_MINOR) "." MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_REVISION) #include #define MA_DR_MP3_MAX_PCM_FRAMES_PER_MP3_FRAME 1152 #define MA_DR_MP3_MAX_SAMPLES_PER_FRAME (MA_DR_MP3_MAX_PCM_FRAMES_PER_MP3_FRAME*2) MA_API void ma_dr_mp3_version(ma_uint32* pMajor, ma_uint32* pMinor, ma_uint32* pRevision); MA_API const char* ma_dr_mp3_version_string(void); +#define MA_DR_MP3_MAX_BITRESERVOIR_BYTES 511 +#define MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE 2304 +#define MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE typedef struct { - int frame_bytes, channels, hz, layer, bitrate_kbps; + int frame_bytes, channels, sample_rate, layer, bitrate_kbps; } ma_dr_mp3dec_frame_info; typedef struct +{ + const ma_uint8 *buf; + int pos, limit; +} ma_dr_mp3_bs; +typedef struct +{ + const ma_uint8 *sfbtab; + ma_uint16 part_23_length, big_values, scalefac_compress; + ma_uint8 global_gain, block_type, mixed_block_flag, n_long_sfb, n_short_sfb; + ma_uint8 table_select[3], region_count[3], subblock_gain[3]; + ma_uint8 preflag, scalefac_scale, count1_table, scfsi; +} ma_dr_mp3_L3_gr_info; +typedef struct +{ + ma_dr_mp3_bs bs; + ma_uint8 maindata[MA_DR_MP3_MAX_BITRESERVOIR_BYTES + MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES]; + ma_dr_mp3_L3_gr_info gr_info[4]; + float grbuf[2][576], scf[40], syn[18 + 15][2*32]; + ma_uint8 ist_pos[2][39]; +} ma_dr_mp3dec_scratch; +typedef struct { float mdct_overlap[2][9*32], qmf_state[15*2*32]; int reserv, free_format_bytes; ma_uint8 header[4], reserv_buf[511]; + ma_dr_mp3dec_scratch scratch; } ma_dr_mp3dec; MA_API void ma_dr_mp3dec_init(ma_dr_mp3dec *dec); MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int mp3_bytes, void *pcm, ma_dr_mp3dec_frame_info *info); MA_API void ma_dr_mp3dec_f32_to_s16(const float *in, ma_int16 *out, size_t num_samples); typedef enum { - ma_dr_mp3_seek_origin_start, - ma_dr_mp3_seek_origin_current + MA_DR_MP3_SEEK_SET, + MA_DR_MP3_SEEK_CUR, + MA_DR_MP3_SEEK_END } ma_dr_mp3_seek_origin; typedef struct { @@ -61073,8 +62714,24 @@ typedef struct ma_uint16 mp3FramesToDiscard; ma_uint16 pcmFramesToDiscard; } ma_dr_mp3_seek_point; +typedef enum +{ + MA_DR_MP3_METADATA_TYPE_ID3V1, + MA_DR_MP3_METADATA_TYPE_ID3V2, + MA_DR_MP3_METADATA_TYPE_APE, + MA_DR_MP3_METADATA_TYPE_XING, + MA_DR_MP3_METADATA_TYPE_VBRI +} ma_dr_mp3_metadata_type; +typedef struct +{ + ma_dr_mp3_metadata_type type; + const void* pRawData; + size_t rawDataSize; +} ma_dr_mp3_metadata; typedef size_t (* ma_dr_mp3_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead); typedef ma_bool32 (* ma_dr_mp3_seek_proc)(void* pUserData, int offset, ma_dr_mp3_seek_origin origin); +typedef ma_bool32 (* ma_dr_mp3_tell_proc)(void* pUserData, ma_int64* pCursor); +typedef void (* ma_dr_mp3_meta_proc)(void* pUserData, const ma_dr_mp3_metadata* pMetadata); typedef struct { ma_uint32 channels; @@ -61087,7 +62744,9 @@ typedef struct ma_uint32 sampleRate; ma_dr_mp3_read_proc onRead; ma_dr_mp3_seek_proc onSeek; + ma_dr_mp3_meta_proc onMeta; void* pUserData; + void* pUserDataMeta; ma_allocation_callbacks allocationCallbacks; ma_uint32 mp3FrameChannels; ma_uint32 mp3FrameSampleRate; @@ -61096,13 +62755,20 @@ typedef struct ma_uint8 pcmFrames[sizeof(float)*MA_DR_MP3_MAX_SAMPLES_PER_FRAME]; ma_uint64 currentPCMFrame; ma_uint64 streamCursor; + ma_uint64 streamLength; + ma_uint64 streamStartOffset; ma_dr_mp3_seek_point* pSeekPoints; ma_uint32 seekPointCount; + ma_uint32 delayInPCMFrames; + ma_uint32 paddingInPCMFrames; + ma_uint64 totalPCMFrameCount; + ma_bool32 isVBR; + ma_bool32 isCBR; size_t dataSize; size_t dataCapacity; size_t dataConsumed; ma_uint8* pData; - ma_bool32 atEnd : 1; + ma_bool32 atEnd; struct { const ma_uint8* pData; @@ -61110,9 +62776,12 @@ typedef struct size_t currentReadPos; } memory; } ma_dr_mp3; -MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, ma_dr_mp3_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_mp3_init_memory_with_metadata(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_MP3_NO_STDIO +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata(ma_dr_mp3* pMP3, const char* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_mp3_init_file(ma_dr_mp3* pMP3, const char* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_mp3_init_file_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks); #endif @@ -61125,8 +62794,8 @@ MA_API ma_uint64 ma_dr_mp3_get_mp3_frame_count(ma_dr_mp3* pMP3); MA_API ma_bool32 ma_dr_mp3_get_mp3_and_pcm_frame_count(ma_dr_mp3* pMP3, ma_uint64* pMP3FrameCount, ma_uint64* pPCMFrameCount); MA_API ma_bool32 ma_dr_mp3_calculate_seek_points(ma_dr_mp3* pMP3, ma_uint32* pSeekPointCount, ma_dr_mp3_seek_point* pSeekPoints); MA_API ma_bool32 ma_dr_mp3_bind_seek_table(ma_dr_mp3* pMP3, ma_uint32 seekPointCount, ma_dr_mp3_seek_point* pSeekPoints); -MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API float* ma_dr_mp3_open_memory_and_read_pcm_frames_f32(const void* pData, size_t dataSize, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_int16* ma_dr_mp3_open_memory_and_read_pcm_frames_s16(const void* pData, size_t dataSize, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_MP3_NO_STDIO @@ -61591,7 +63260,6 @@ static ma_result ma_decoder_init_custom_from_memory__internal(const void* pData, /* WAV */ #ifdef ma_dr_wav_h -#define MA_HAS_WAV typedef struct { @@ -61679,8 +63347,10 @@ static ma_bool32 ma_wav_dr_callback__seek(void* pUserData, int offset, ma_dr_wav MA_ASSERT(pWav != NULL); maSeekOrigin = ma_seek_origin_start; - if (origin == ma_dr_wav_seek_origin_current) { - maSeekOrigin = ma_seek_origin_current; + if (origin == MA_DR_WAV_SEEK_CUR) { + maSeekOrigin = ma_seek_origin_current; + } else if (origin == MA_DR_WAV_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; } result = pWav->onSeek(pWav->pReadSeekTellUserData, offset, maSeekOrigin); @@ -61690,6 +63360,26 @@ static ma_bool32 ma_wav_dr_callback__seek(void* pUserData, int offset, ma_dr_wav return MA_TRUE; } + +static ma_bool32 ma_wav_dr_callback__tell(void* pUserData, ma_int64* pCursor) +{ + ma_wav* pWav = (ma_wav*)pUserData; + ma_result result; + + MA_ASSERT(pWav != NULL); + MA_ASSERT(pCursor != NULL); + + if (pWav->onTell == NULL) { + return MA_FALSE; /* Not implemented. */ + } + + result = pWav->onTell(pWav->pReadSeekTellUserData, pCursor); + if (result != MA_SUCCESS) { + return MA_FALSE; /* Failed to tell. */ + } + + return MA_TRUE; +} #endif static ma_result ma_wav_init_internal(const ma_decoding_backend_config* pConfig, ma_wav* pWav) @@ -61784,7 +63474,7 @@ MA_API ma_result ma_wav_init(ma_read_proc onRead, ma_seek_proc onSeek, ma_tell_p { ma_bool32 wavResult; - wavResult = ma_dr_wav_init(&pWav->dr, ma_wav_dr_callback__read, ma_wav_dr_callback__seek, pWav, pAllocationCallbacks); + wavResult = ma_dr_wav_init(&pWav->dr, ma_wav_dr_callback__read, ma_wav_dr_callback__seek, ma_wav_dr_callback__tell, pWav, pAllocationCallbacks); if (wavResult != MA_TRUE) { return MA_INVALID_FILE; } @@ -62275,7 +63965,6 @@ static ma_result ma_decoder_init_wav_from_memory__internal(const void* pData, si /* FLAC */ #ifdef ma_dr_flac_h -#define MA_HAS_FLAC typedef struct { @@ -62363,8 +64052,10 @@ static ma_bool32 ma_flac_dr_callback__seek(void* pUserData, int offset, ma_dr_fl MA_ASSERT(pFlac != NULL); maSeekOrigin = ma_seek_origin_start; - if (origin == ma_dr_flac_seek_origin_current) { - maSeekOrigin = ma_seek_origin_current; + if (origin == MA_DR_FLAC_SEEK_CUR) { + maSeekOrigin = ma_seek_origin_current; + } else if (origin == MA_DR_FLAC_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; } result = pFlac->onSeek(pFlac->pReadSeekTellUserData, offset, maSeekOrigin); @@ -62374,6 +64065,26 @@ static ma_bool32 ma_flac_dr_callback__seek(void* pUserData, int offset, ma_dr_fl return MA_TRUE; } + +static ma_bool32 ma_flac_dr_callback__tell(void* pUserData, ma_int64* pCursor) +{ + ma_flac* pFlac = (ma_flac*)pUserData; + ma_result result; + + MA_ASSERT(pFlac != NULL); + MA_ASSERT(pCursor != NULL); + + if (pFlac->onTell == NULL) { + return MA_FALSE; /* Not implemented. */ + } + + result = pFlac->onTell(pFlac->pReadSeekTellUserData, pCursor); + if (result != MA_SUCCESS) { + return MA_FALSE; /* Failed to tell. */ + } + + return MA_TRUE; +} #endif static ma_result ma_flac_init_internal(const ma_decoding_backend_config* pConfig, ma_flac* pFlac) @@ -62425,7 +64136,7 @@ MA_API ma_result ma_flac_init(ma_read_proc onRead, ma_seek_proc onSeek, ma_tell_ #if !defined(MA_NO_FLAC) { - pFlac->dr = ma_dr_flac_open(ma_flac_dr_callback__read, ma_flac_dr_callback__seek, pFlac, pAllocationCallbacks); + pFlac->dr = ma_dr_flac_open(ma_flac_dr_callback__read, ma_flac_dr_callback__seek, ma_flac_dr_callback__tell, pFlac, pAllocationCallbacks); if (pFlac->dr == NULL) { return MA_INVALID_FILE; } @@ -62897,7 +64608,6 @@ static ma_result ma_decoder_init_flac_from_memory__internal(const void* pData, s /* MP3 */ #ifdef ma_dr_mp3_h -#define MA_HAS_MP3 typedef struct { @@ -62986,9 +64696,12 @@ static ma_bool32 ma_mp3_dr_callback__seek(void* pUserData, int offset, ma_dr_mp3 MA_ASSERT(pMP3 != NULL); - maSeekOrigin = ma_seek_origin_start; - if (origin == ma_dr_mp3_seek_origin_current) { - maSeekOrigin = ma_seek_origin_current; + if (origin == MA_DR_MP3_SEEK_SET) { + maSeekOrigin = ma_seek_origin_start; + } else if (origin == MA_DR_MP3_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; + } else { + maSeekOrigin = ma_seek_origin_current; } result = pMP3->onSeek(pMP3->pReadSeekTellUserData, offset, maSeekOrigin); @@ -62998,6 +64711,21 @@ static ma_bool32 ma_mp3_dr_callback__seek(void* pUserData, int offset, ma_dr_mp3 return MA_TRUE; } + +static ma_bool32 ma_mp3_dr_callback__tell(void* pUserData, ma_int64* pCursor) +{ + ma_mp3* pMP3 = (ma_mp3*)pUserData; + ma_result result; + + MA_ASSERT(pMP3 != NULL); + + result = pMP3->onTell(pMP3->pReadSeekTellUserData, pCursor); + if (result != MA_SUCCESS) { + return MA_FALSE; + } + + return MA_TRUE; +} #endif static ma_result ma_mp3_init_internal(const ma_decoding_backend_config* pConfig, ma_mp3* pMP3) @@ -63098,7 +64826,7 @@ MA_API ma_result ma_mp3_init(ma_read_proc onRead, ma_seek_proc onSeek, ma_tell_p { ma_bool32 mp3Result; - mp3Result = ma_dr_mp3_init(&pMP3->dr, ma_mp3_dr_callback__read, ma_mp3_dr_callback__seek, pMP3, pAllocationCallbacks); + mp3Result = ma_dr_mp3_init(&pMP3->dr, ma_mp3_dr_callback__read, ma_mp3_dr_callback__seek, ma_mp3_dr_callback__tell, NULL, pMP3, pAllocationCallbacks); if (mp3Result != MA_TRUE) { return MA_INVALID_FILE; } @@ -64557,11 +66285,9 @@ static ma_result ma_decoder_init__internal(ma_decoder_read_proc onRead, ma_decod We use trial and error to open a decoder. We prioritize custom decoders so that if they implement the same encoding format they take priority over the built-in decoders. */ + result = ma_decoder_init_custom__internal(pConfig, pDecoder); if (result != MA_SUCCESS) { - result = ma_decoder_init_custom__internal(pConfig, pDecoder); - if (result != MA_SUCCESS) { - onSeek(pDecoder, 0, ma_seek_origin_start); - } + onSeek(pDecoder, 0, ma_seek_origin_start); } /* @@ -64825,14 +66551,6 @@ MA_API ma_result ma_decoder_init_memory(const void* pData, size_t dataSize, cons /* Initialization was successful. Finish up. */ result = ma_decoder__postinit(&config, pDecoder); if (result != MA_SUCCESS) { - /* - The backend was initialized successfully, but for some reason post-initialization failed. This is most likely - due to an out of memory error. We're going to abort with an error here and not try to recover. - */ - if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) { - pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks); - } - return result; } } else { @@ -64997,14 +66715,16 @@ static ma_bool32 ma_path_extension_equal_w(const wchar_t* path, const wchar_t* e ext1 = extension; ext2 = ma_path_extension_w(path); -#if defined(_MSC_VER) || defined(__WATCOMC__) || defined(__DMC__) - return _wcsicmp(ext1, ext2) == 0; -#else - /* - I'm not aware of a wide character version of strcasecmp(). I'm therefore converting the extensions to multibyte strings and comparing those. This - isn't the most efficient way to do it, but it should work OK. - */ + #if (defined(_MSC_VER) || defined(__WATCOMC__) || defined(__DMC__)) && !defined(MA_XBOX_NXDK) + { + return _wcsicmp(ext1, ext2) == 0; + } + #elif !defined(MA_XBOX_NXDK) && !defined(MA_DOS) { + /* + I'm not aware of a wide character version of strcasecmp(). I'm therefore converting the extensions to multibyte strings and comparing those. This + isn't the most efficient way to do it, but it should work OK. + */ char ext1MB[4096]; char ext2MB[4096]; const wchar_t* pext1 = ext1; @@ -65024,7 +66744,13 @@ static ma_bool32 ma_path_extension_equal_w(const wchar_t* path, const wchar_t* e return strcasecmp(ext1MB, ext2MB) == 0; } -#endif + #else + { + /* Getting here means we don't have a way to do a case-sensitive comparison for wide strings. Fall back to a simple case-sensitive comparison. */ + /* TODO: Implement our own wchar_t-to-char conversion routine and then use the char* version for comparing. */ + return ma_wcscmp(ext1, ext2) == 0; + } + #endif } #endif /* MA_HAS_PATH_API */ @@ -65125,11 +66851,9 @@ MA_API ma_result ma_decoder_init_vfs(ma_vfs* pVFS, const char* pFilePath, const We use trial and error to open a decoder. We prioritize custom decoders so that if they implement the same encoding format they take priority over the built-in decoders. */ + result = ma_decoder_init_custom__internal(&config, pDecoder); if (result != MA_SUCCESS) { - result = ma_decoder_init_custom__internal(&config, pDecoder); - if (result != MA_SUCCESS) { - ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); - } + ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); } /* @@ -65258,11 +66982,9 @@ MA_API ma_result ma_decoder_init_vfs_w(ma_vfs* pVFS, const wchar_t* pFilePath, c We use trial and error to open a decoder. We prioritize custom decoders so that if they implement the same encoding format they take priority over the built-in decoders. */ + result = ma_decoder_init_custom__internal(&config, pDecoder); if (result != MA_SUCCESS) { - result = ma_decoder_init_custom__internal(&config, pDecoder); - if (result != MA_SUCCESS) { - ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); - } + ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); } /* @@ -65444,14 +67166,6 @@ MA_API ma_result ma_decoder_init_file(const char* pFilePath, const ma_decoder_co /* Initialization was successful. Finish up. */ result = ma_decoder__postinit(&config, pDecoder); if (result != MA_SUCCESS) { - /* - The backend was initialized successfully, but for some reason post-initialization failed. This is most likely - due to an out of memory error. We're going to abort with an error here and not try to recover. - */ - if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) { - pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks); - } - return result; } } else { @@ -65594,14 +67308,6 @@ MA_API ma_result ma_decoder_init_file_w(const wchar_t* pFilePath, const ma_decod /* Initialization was successful. Finish up. */ result = ma_decoder__postinit(&config, pDecoder); if (result != MA_SUCCESS) { - /* - The backend was initialized successfully, but for some reason post-initialization failed. This is most likely - due to an out of memory error. We're going to abort with an error here and not try to recover. - */ - if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) { - pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks); - } - return result; } } else { @@ -66119,10 +67825,18 @@ static ma_bool32 ma_encoder__internal_on_seek_wav(void* pUserData, int offset, m { ma_encoder* pEncoder = (ma_encoder*)pUserData; ma_result result; + ma_seek_origin maSeekOrigin; MA_ASSERT(pEncoder != NULL); - result = pEncoder->onSeek(pEncoder, offset, (origin == ma_dr_wav_seek_origin_start) ? ma_seek_origin_start : ma_seek_origin_current); + maSeekOrigin = ma_seek_origin_start; + if (origin == MA_DR_WAV_SEEK_CUR) { + maSeekOrigin = ma_seek_origin_current; + } else if (origin == MA_DR_WAV_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; + } + + result = pEncoder->onSeek(pEncoder, offset, maSeekOrigin); if (result != MA_SUCCESS) { return MA_FALSE; } else { @@ -67644,7 +69358,7 @@ static MA_INLINE ma_uint32 ma_hash_getblock(const ma_uint32* blocks, int i) ma_uint32 block; /* Try silencing a sanitization warning about unaligned access by doing a memcpy() instead of assignment. */ - MA_COPY_MEMORY(&block, ma_offset_ptr(blocks, i * sizeof(block)), sizeof(block)); + MA_COPY_MEMORY(&block, ma_offset_ptr(blocks, i * (int) sizeof(block)), sizeof(block)); if (ma_is_little_endian()) { return block; @@ -67720,7 +69434,7 @@ static ma_uint32 ma_hash_string_32(const char* str) static ma_uint32 ma_hash_string_w_32(const wchar_t* str) { - return ma_hash_32(str, (int)wcslen(str) * sizeof(*str), MA_DEFAULT_HASH_SEED); + return ma_hash_32(str, (int)ma_wcslen(str) * sizeof(*str), MA_DEFAULT_HASH_SEED); } @@ -67880,6 +69594,7 @@ static MA_INLINE ma_resource_manager_data_buffer_node* ma_resource_manager_data_ return ma_resource_manager_data_buffer_node_find_min(pDataBufferNode->pChildHi); } +#if 0 /* Currently unused, but might make use of this later. */ static MA_INLINE ma_resource_manager_data_buffer_node* ma_resource_manager_data_buffer_node_find_inorder_predecessor(ma_resource_manager_data_buffer_node* pDataBufferNode) { MA_ASSERT(pDataBufferNode != NULL); @@ -67887,6 +69602,7 @@ static MA_INLINE ma_resource_manager_data_buffer_node* ma_resource_manager_data_ return ma_resource_manager_data_buffer_node_find_max(pDataBufferNode->pChildLo); } +#endif static ma_result ma_resource_manager_data_buffer_node_remove(ma_resource_manager* pResourceManager, ma_resource_manager_data_buffer_node* pDataBufferNode) { @@ -68237,6 +69953,7 @@ MA_API ma_resource_manager_config ma_resource_manager_config_init(void) config.decodedSampleRate = 0; config.jobThreadCount = 1; /* A single miniaudio-managed job thread by default. */ config.jobQueueCapacity = MA_JOB_TYPE_RESOURCE_MANAGER_QUEUE_CAPACITY; + config.resampling = ma_resampler_config_init(ma_format_unknown, 0, 0, 0, ma_resample_algorithm_linear); /* Format/channels/rate doesn't matter here. */ /* Flags. */ config.flags = 0; @@ -68490,6 +70207,7 @@ static ma_decoder_config ma_resource_manager__init_decoder_config(ma_resource_ma config.ppCustomBackendVTables = pResourceManager->config.ppCustomDecodingBackendVTables; config.customBackendCount = pResourceManager->config.customDecodingBackendCount; config.pCustomBackendUserData = pResourceManager->config.pCustomDecodingBackendUserData; + config.resampling = pResourceManager->config.resampling; return config; } @@ -69009,16 +70727,19 @@ static ma_result ma_resource_manager_data_buffer_node_acquire_critical_section(m /* Failed to post job. Probably ran out of memory. */ ma_log_postf(ma_resource_manager_get_log(pResourceManager), MA_LOG_LEVEL_ERROR, "Failed to post MA_JOB_TYPE_RESOURCE_MANAGER_LOAD_DATA_BUFFER_NODE job. %s.\n", ma_result_description(result)); - /* - Fences were acquired before posting the job, but since the job was not able to - be posted, we need to make sure we release them so nothing gets stuck waiting. - */ - if (pInitFence != NULL) { ma_fence_release(pInitFence); } - if (pDoneFence != NULL) { ma_fence_release(pDoneFence); } - if ((flags & MA_RESOURCE_MANAGER_DATA_SOURCE_FLAG_WAIT_INIT) != 0) { ma_resource_manager_inline_notification_uninit(pInitNotification); } else { + /* + Fences were acquired before posting the job, but since the job was not able to + be posted, we need to make sure we release them so nothing gets stuck waiting. + + In the WAIT_INIT case, these will have already been released in ma_job_process() + so we should only release fences in this branch. + */ + if (pInitFence != NULL) { ma_fence_release(pInitFence); } + if (pDoneFence != NULL) { ma_fence_release(pDoneFence); } + /* These will have been freed by the job thread, but with WAIT_INIT they will already have happened since the job has already been handled. */ ma_free(pFilePathCopy, &pResourceManager->config.allocationCallbacks); ma_free(pFilePathWCopy, &pResourceManager->config.allocationCallbacks); @@ -69812,13 +71533,13 @@ MA_API ma_result ma_resource_manager_data_buffer_get_data_format(ma_resource_man MA_API ma_result ma_resource_manager_data_buffer_get_cursor_in_pcm_frames(ma_resource_manager_data_buffer* pDataBuffer, ma_uint64* pCursor) { - /* We cannot be using the data source after it's been uninitialized. */ - MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); - if (pDataBuffer == NULL || pCursor == NULL) { return MA_INVALID_ARGS; } + /* We cannot be using the data source after it's been uninitialized. */ + MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); + *pCursor = 0; switch (ma_resource_manager_data_buffer_node_get_data_supply_type(pDataBuffer->pNode)) @@ -69852,13 +71573,13 @@ MA_API ma_result ma_resource_manager_data_buffer_get_cursor_in_pcm_frames(ma_res MA_API ma_result ma_resource_manager_data_buffer_get_length_in_pcm_frames(ma_resource_manager_data_buffer* pDataBuffer, ma_uint64* pLength) { - /* We cannot be using the data source after it's been uninitialized. */ - MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); - if (pDataBuffer == NULL || pLength == NULL) { return MA_INVALID_ARGS; } + /* We cannot be using the data source after it's been uninitialized. */ + MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); + if (ma_resource_manager_data_buffer_node_get_data_supply_type(pDataBuffer->pNode) == ma_resource_manager_data_supply_type_unknown) { return MA_BUSY; /* Still loading. */ } @@ -71213,8 +72934,6 @@ static ma_result ma_job_process__resource_manager__free_data_buffer_node(ma_job* return ma_resource_manager_post_job(pResourceManager, pJob); /* Out of order. */ } - ma_resource_manager_data_buffer_node_free(pResourceManager, pDataBufferNode); - /* The event needs to be signalled last. */ if (pJob->data.resourceManager.freeDataBufferNode.pDoneNotification != NULL) { ma_async_notification_signal(pJob->data.resourceManager.freeDataBufferNode.pDoneNotification); @@ -71225,6 +72944,9 @@ static ma_result ma_job_process__resource_manager__free_data_buffer_node(ma_job* } ma_atomic_fetch_add_32(&pDataBufferNode->executionPointer, 1); + + ma_resource_manager_data_buffer_node_free(pResourceManager, pDataBufferNode); + return MA_SUCCESS; } @@ -72097,6 +73819,15 @@ MA_API ma_result ma_node_graph_set_time(ma_node_graph* pNodeGraph, ma_uint64 glo return ma_node_set_time(&pNodeGraph->endpoint, globalTime); /* Global time is just the local time of the endpoint. */ } +MA_API ma_uint32 ma_node_graph_get_processing_size_in_frames(const ma_node_graph* pNodeGraph) +{ + if (pNodeGraph == NULL) { + return 0; + } + + return pNodeGraph->processingSizeInFrames; +} + #define MA_NODE_OUTPUT_BUS_FLAG_HAS_READ 0x01 /* Whether or not this bus ready to read more data. Only used on nodes with multiple output buses. */ @@ -73256,12 +74987,12 @@ MA_API ma_node_state ma_node_get_state_by_time_range(const ma_node* pNode, ma_ui its start time not having been reached yet. Also, the stop time may have also been reached in which case it'll be considered stopped. */ - if (ma_node_get_state_time(pNode, ma_node_state_started) > globalTimeBeg) { - return ma_node_state_stopped; /* Start time has not yet been reached. */ + if (ma_node_get_state_time(pNode, ma_node_state_stopped) < globalTimeBeg) { + return ma_node_state_stopped; /* End time is before the start of the range. */ } - if (ma_node_get_state_time(pNode, ma_node_state_stopped) <= globalTimeEnd) { - return ma_node_state_stopped; /* Stop time has been reached. */ + if (ma_node_get_state_time(pNode, ma_node_state_started) > globalTimeEnd) { + return ma_node_state_stopped; /* Start time is after the end of the range. */ } /* Getting here means the node is marked as started and is within its start/stop times. */ @@ -73341,14 +75072,14 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde return MA_INVALID_ARGS; /* Invalid output bus index. */ } + globalTimeBeg = globalTime; + globalTimeEnd = globalTime + frameCount; + /* Don't do anything if we're in a stopped state. */ - if (ma_node_get_state_by_time_range(pNode, globalTime, globalTime + frameCount) != ma_node_state_started) { + if (ma_node_get_state_by_time_range(pNode, globalTimeBeg, globalTimeEnd) != ma_node_state_started) { return MA_SUCCESS; /* We're in a stopped state. This is not an error - we just need to not read anything. */ } - - globalTimeBeg = globalTime; - globalTimeEnd = globalTime + frameCount; startTime = ma_node_get_state_time(pNode, ma_node_state_started); stopTime = ma_node_get_state_time(pNode, ma_node_state_stopped); @@ -73361,11 +75092,16 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde therefore need to offset it by a number of frames to accommodate. The same thing applies for the stop time. */ - timeOffsetBeg = (globalTimeBeg < startTime) ? (ma_uint32)(globalTimeEnd - startTime) : 0; + timeOffsetBeg = (globalTimeBeg < startTime) ? (ma_uint32)(startTime - globalTimeBeg) : 0; timeOffsetEnd = (globalTimeEnd > stopTime) ? (ma_uint32)(globalTimeEnd - stopTime) : 0; /* Trim based on the start offset. We need to silence the start of the buffer. */ if (timeOffsetBeg > 0) { + MA_ASSERT(timeOffsetBeg <= frameCount); + if (timeOffsetBeg > frameCount) { + timeOffsetBeg = frameCount; + } + ma_silence_pcm_frames(pFramesOut, timeOffsetBeg, ma_format_f32, ma_node_get_output_channels(pNode, outputBusIndex)); pFramesOut += timeOffsetBeg * ma_node_get_output_channels(pNode, outputBusIndex); frameCount -= timeOffsetBeg; @@ -73373,6 +75109,11 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde /* Trim based on the end offset. We don't need to silence the tail section because we'll just have a reduced value written to pFramesRead. */ if (timeOffsetEnd > 0) { + MA_ASSERT(timeOffsetEnd <= frameCount); + if (timeOffsetEnd > frameCount) { + timeOffsetEnd = frameCount; + } + frameCount -= timeOffsetEnd; } @@ -74787,12 +76528,20 @@ static void ma_sound_set_at_end(ma_sound* pSound, ma_bool32 atEnd) MA_ASSERT(pSound != NULL); ma_atomic_exchange_32(&pSound->atEnd, atEnd); + /* + When this function is called the state of the sound will not yet be in a stopped state. This makes it confusing + because an end callback will intuitively expect ma_sound_is_playing() to return false from inside the callback. + I'm therefore no longer firing the callback here and will instead fire it manually in the *next* processing step + when the state should be set to stopped as expected. + */ + #if 0 /* Fire any callbacks or events. */ if (atEnd) { if (pSound->endCallback != NULL) { pSound->endCallback(pSound->pEndCallbackUserData, pSound); } } + #endif } static ma_bool32 ma_sound_get_at_end(const ma_sound* pSound) @@ -74812,6 +76561,7 @@ MA_API ma_engine_node_config ma_engine_node_config_init(ma_engine* pEngine, ma_e config.isPitchDisabled = (flags & MA_SOUND_FLAG_NO_PITCH) != 0; config.isSpatializationDisabled = (flags & MA_SOUND_FLAG_NO_SPATIALIZATION) != 0; config.monoExpansionMode = pEngine->monoExpansionMode; + config.resampling = pEngine->pitchResamplingConfig; return config; } @@ -74838,7 +76588,7 @@ static void ma_engine_node_update_pitch_if_required(ma_engine_node* pEngineNode) if (isUpdateRequired) { float basePitch = (float)pEngineNode->sampleRate / ma_engine_get_sample_rate(pEngineNode->pEngine); - ma_linear_resampler_set_rate_ratio(&pEngineNode->resampler, basePitch * pEngineNode->oldPitch * pEngineNode->oldDopplerPitch); + ma_resampler_set_rate_ratio(&pEngineNode->resampler, basePitch * pEngineNode->oldPitch * pEngineNode->oldDopplerPitch); } } @@ -74857,22 +76607,6 @@ static ma_bool32 ma_engine_node_is_spatialization_enabled(const ma_engine_node* return !ma_atomic_load_explicit_32(&pEngineNode->isSpatializationDisabled, ma_atomic_memory_order_acquire); } -static ma_uint64 ma_engine_node_get_required_input_frame_count(const ma_engine_node* pEngineNode, ma_uint64 outputFrameCount) -{ - ma_uint64 inputFrameCount = 0; - - if (ma_engine_node_is_pitching_enabled(pEngineNode)) { - ma_result result = ma_linear_resampler_get_required_input_frame_count(&pEngineNode->resampler, outputFrameCount, &inputFrameCount); - if (result != MA_SUCCESS) { - inputFrameCount = 0; - } - } else { - inputFrameCount = outputFrameCount; /* No resampling, so 1:1. */ - } - - return inputFrameCount; -} - static ma_result ma_engine_node_set_volume(ma_engine_node* pEngineNode, float volume) { if (pEngineNode == NULL) { @@ -75014,7 +76748,7 @@ static void ma_engine_node_process_pcm_frames__general(ma_engine_node* pEngineNo ma_uint64 resampleFrameCountIn = framesAvailableIn; ma_uint64 resampleFrameCountOut = framesAvailableOut; - ma_linear_resampler_process_pcm_frames(&pEngineNode->resampler, pRunningFramesIn, &resampleFrameCountIn, pWorkingBuffer, &resampleFrameCountOut); + ma_resampler_process_pcm_frames(&pEngineNode->resampler, pRunningFramesIn, &resampleFrameCountIn, pWorkingBuffer, &resampleFrameCountOut); isWorkingBufferValid = MA_TRUE; framesJustProcessedIn = (ma_uint32)resampleFrameCountIn; @@ -75138,6 +76872,11 @@ static void ma_engine_node_process_pcm_frames__sound(ma_node* pNode, const float /* If we're marked at the end we need to stop the sound and do nothing. */ if (ma_sound_at_end(pSound)) { ma_sound_stop(pSound); + + if (pSound->endCallback != NULL) { + pSound->endCallback(pSound->pEndCallbackUserData, pSound); + } + *pFrameCountOut = 0; return; } @@ -75175,55 +76914,74 @@ static void ma_engine_node_process_pcm_frames__sound(ma_node* pNode, const float /* Keep reading until we've read as much as was requested or we reach the end of the data source. */ while (totalFramesRead < frameCount) { ma_uint32 framesRemaining = frameCount - totalFramesRead; - ma_uint32 framesToRead; ma_uint64 framesJustRead; ma_uint32 frameCountIn; ma_uint32 frameCountOut; const float* pRunningFramesIn; float* pRunningFramesOut; - /* - The first thing we need to do is read into the temporary buffer. We can calculate exactly - how many input frames we'll need after resampling. - */ - framesToRead = (ma_uint32)ma_engine_node_get_required_input_frame_count(&pSound->engineNode, framesRemaining); - if (framesToRead > tempCapInFrames) { - framesToRead = tempCapInFrames; - } + /* If there's any input frames sitting in the cache get those processed first. */ + if (pSound->processingCacheFramesRemaining > 0) { + pRunningFramesIn = pSound->pProcessingCache; + frameCountIn = pSound->processingCacheFramesRemaining; - result = ma_data_source_read_pcm_frames(pSound->pDataSource, temp, framesToRead, &framesJustRead); + pRunningFramesOut = ma_offset_pcm_frames_ptr_f32(ppFramesOut[0], totalFramesRead, ma_node_get_output_channels(pNode, 0)); + frameCountOut = framesRemaining; - /* If we reached the end of the sound we'll want to mark it as at the end and stop it. This should never be returned for looping sounds. */ - if (result == MA_AT_END) { - ma_sound_set_at_end(pSound, MA_TRUE); /* This will be set to false in ma_sound_start(). */ - } + ma_engine_node_process_pcm_frames__general(&pSound->engineNode, &pRunningFramesIn, &frameCountIn, &pRunningFramesOut, &frameCountOut); - pRunningFramesOut = ma_offset_pcm_frames_ptr_f32(ppFramesOut[0], totalFramesRead, ma_node_get_output_channels(pNode, 0)); + MA_ASSERT(frameCountIn <= pSound->processingCacheFramesRemaining); + pSound->processingCacheFramesRemaining -= frameCountIn; - frameCountIn = (ma_uint32)framesJustRead; - frameCountOut = framesRemaining; + /* Move any remaining data in the cache down. */ + if (pSound->processingCacheFramesRemaining > 0) { + MA_MOVE_MEMORY(pSound->pProcessingCache, ma_offset_pcm_frames_ptr_f32(pSound->pProcessingCache, frameCountIn, dataSourceChannels), pSound->processingCacheFramesRemaining * ma_get_bytes_per_frame(ma_format_f32, dataSourceChannels)); + } + + totalFramesRead += (ma_uint32)frameCountOut; /* Safe cast. */ - /* Convert if necessary. */ - if (dataSourceFormat == ma_format_f32) { - /* Fast path. No data conversion necessary. */ - pRunningFramesIn = (float*)temp; - ma_engine_node_process_pcm_frames__general(&pSound->engineNode, &pRunningFramesIn, &frameCountIn, &pRunningFramesOut, &frameCountOut); + if (result != MA_SUCCESS || ma_sound_at_end(pSound)) { + break; /* Might have reached the end. */ + } } else { - /* Slow path. Need to do sample format conversion to f32. If we give the f32 buffer the same count as the first temp buffer, we're guaranteed it'll be large enough. */ - float tempf32[MA_DATA_CONVERTER_STACK_BUFFER_SIZE]; /* Do not do `MA_DATA_CONVERTER_STACK_BUFFER_SIZE/sizeof(float)` here like we've done in other places. */ - ma_convert_pcm_frames_format(tempf32, ma_format_f32, temp, dataSourceFormat, framesJustRead, dataSourceChannels, ma_dither_mode_none); + /* Getting here means there's nothing in the cache. Read more data from the data source. */ + if (dataSourceFormat == ma_format_f32) { + /* Fast path. No conversion to f32 necessary. */ + result = ma_data_source_read_pcm_frames(pSound->pDataSource, pSound->pProcessingCache, pSound->processingCacheCap, &framesJustRead); + } else { + /* Slow path. Need to convert to f32. */ + ma_uint64 totalFramesConverted = 0; + + while (totalFramesConverted < pSound->processingCacheCap) { + ma_uint64 framesConverted; + ma_uint32 framesToConvertThisIteration = pSound->processingCacheCap - (ma_uint32)totalFramesConverted; + if (framesToConvertThisIteration > tempCapInFrames) { + framesToConvertThisIteration = tempCapInFrames; + } - /* Now that we have our samples in f32 format we can process like normal. */ - pRunningFramesIn = tempf32; - ma_engine_node_process_pcm_frames__general(&pSound->engineNode, &pRunningFramesIn, &frameCountIn, &pRunningFramesOut, &frameCountOut); - } + result = ma_data_source_read_pcm_frames(pSound->pDataSource, temp, framesToConvertThisIteration, &framesConverted); + if (result != MA_SUCCESS) { + break; + } - /* We should have processed all of our input frames since we calculated the required number of input frames at the top. */ - MA_ASSERT(frameCountIn == framesJustRead); - totalFramesRead += (ma_uint32)frameCountOut; /* Safe cast. */ + ma_convert_pcm_frames_format(ma_offset_pcm_frames_ptr_f32(pSound->pProcessingCache, totalFramesConverted, dataSourceChannels), ma_format_f32, temp, dataSourceFormat, framesConverted, dataSourceChannels, ma_dither_mode_none); + totalFramesConverted += framesConverted; + } - if (result != MA_SUCCESS || ma_sound_at_end(pSound)) { - break; /* Might have reached the end. */ + framesJustRead = totalFramesConverted; + } + + MA_ASSERT(framesJustRead <= pSound->processingCacheCap); + pSound->processingCacheFramesRemaining = (ma_uint32)framesJustRead; + + /* If we reached the end of the sound we'll want to mark it as at the end and stop it. This should never be returned for looping sounds. */ + if (result == MA_AT_END) { + ma_sound_set_at_end(pSound, MA_TRUE); /* This will be set to false in ma_sound_start(). */ + } + + if (result != MA_SUCCESS || ma_sound_at_end(pSound)) { + break; + } } } } @@ -75246,25 +77004,6 @@ static void ma_engine_node_process_pcm_frames__group(ma_node* pNode, const float ma_engine_node_process_pcm_frames__general((ma_engine_node*)pNode, ppFramesIn, pFrameCountIn, ppFramesOut, pFrameCountOut); } -static ma_result ma_engine_node_get_required_input_frame_count__group(ma_node* pNode, ma_uint32 outputFrameCount, ma_uint32* pInputFrameCount) -{ - ma_uint64 inputFrameCount; - - MA_ASSERT(pInputFrameCount != NULL); - - /* Our pitch will affect this calculation. We need to update it. */ - ma_engine_node_update_pitch_if_required((ma_engine_node*)pNode); - - inputFrameCount = ma_engine_node_get_required_input_frame_count((ma_engine_node*)pNode, outputFrameCount); - if (inputFrameCount > 0xFFFFFFFF) { - inputFrameCount = 0xFFFFFFFF; /* Will never happen because miniaudio will only ever process in relatively small chunks. */ - } - - *pInputFrameCount = (ma_uint32)inputFrameCount; - - return MA_SUCCESS; -} - static ma_node_vtable g_ma_engine_node_vtable__sound = { @@ -75278,7 +77017,7 @@ static ma_node_vtable g_ma_engine_node_vtable__sound = static ma_node_vtable g_ma_engine_node_vtable__group = { ma_engine_node_process_pcm_frames__group, - ma_engine_node_get_required_input_frame_count__group, + NULL, /* onGetRequiredInputFrameCount */ 1, /* Groups have one input bus. */ 1, /* Groups have one output bus. */ MA_NODE_FLAG_DIFFERENT_PROCESSING_RATES /* The engine node does resampling so should let miniaudio know about it. */ @@ -75324,9 +77063,10 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo ma_result result; size_t tempHeapSize; ma_node_config baseNodeConfig; - ma_linear_resampler_config resamplerConfig; + ma_resampler_config resamplerConfig; ma_spatializer_config spatializerConfig; ma_gainer_config gainerConfig; + ma_uint32 sampleRate; ma_uint32 channelsIn; ma_uint32 channelsOut; ma_channel defaultStereoChannelMap[2] = {MA_CHANNEL_SIDE_LEFT, MA_CHANNEL_SIDE_RIGHT}; /* <-- Consistent with the default channel map of a stereo listener. Means channel conversion can run on a fast path. */ @@ -75345,6 +77085,7 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo pHeapLayout->sizeInBytes = 0; + sampleRate = (pConfig->sampleRate > 0) ? pConfig->sampleRate : ma_engine_get_sample_rate(pConfig->pEngine); channelsIn = (pConfig->channelsIn != 0) ? pConfig->channelsIn : ma_engine_get_channels(pConfig->pEngine); channelsOut = (pConfig->channelsOut != 0) ? pConfig->channelsOut : ma_engine_get_channels(pConfig->pEngine); @@ -75364,10 +77105,13 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo /* Resmapler. */ - resamplerConfig = ma_linear_resampler_config_init(ma_format_f32, channelsIn, 1, 1); /* Input and output sample rates don't affect the calculation of the heap size. */ - resamplerConfig.lpfOrder = 0; + resamplerConfig = pConfig->resampling; + resamplerConfig.format = ma_format_f32; + resamplerConfig.channels = channelsIn; + resamplerConfig.sampleRateIn = sampleRate; + resamplerConfig.sampleRateOut = ma_engine_get_sample_rate(pConfig->pEngine); - result = ma_linear_resampler_get_heap_size(&resamplerConfig, &tempHeapSize); + result = ma_resampler_get_heap_size(&resamplerConfig, &tempHeapSize); if (result != MA_SUCCESS) { return result; /* Failed to retrieve the size of the heap for the resampler. */ } @@ -75435,7 +77179,7 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p ma_result result; ma_engine_node_heap_layout heapLayout; ma_node_config baseNodeConfig; - ma_linear_resampler_config resamplerConfig; + ma_resampler_config resamplerConfig; ma_fader_config faderConfig; ma_spatializer_config spatializerConfig; ma_panner_config pannerConfig; @@ -75510,10 +77254,13 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p */ /* We'll always do resampling first. */ - resamplerConfig = ma_linear_resampler_config_init(ma_format_f32, baseNodeConfig.pInputChannels[0], pEngineNode->sampleRate, ma_engine_get_sample_rate(pEngineNode->pEngine)); - resamplerConfig.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ + resamplerConfig = pConfig->resampling; + resamplerConfig.format = ma_format_f32; + resamplerConfig.channels = baseNodeConfig.pInputChannels[0]; + resamplerConfig.sampleRateIn = pEngineNode->sampleRate; + resamplerConfig.sampleRateOut = ma_engine_get_sample_rate(pEngineNode->pEngine); - result = ma_linear_resampler_init_preallocated(&resamplerConfig, ma_offset_ptr(pHeap, heapLayout.resamplerOffset), &pEngineNode->resampler); + result = ma_resampler_init_preallocated(&resamplerConfig, ma_offset_ptr(pHeap, heapLayout.resamplerOffset), &pEngineNode->resampler); if (result != MA_SUCCESS) { goto error1; } @@ -75572,7 +77319,7 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p /* No need for allocation callbacks here because we use a preallocated heap. */ error3: ma_spatializer_uninit(&pEngineNode->spatializer, NULL); -error2: ma_linear_resampler_uninit(&pEngineNode->resampler, NULL); +error2: ma_resampler_uninit(&pEngineNode->resampler, NULL); error1: ma_node_uninit(&pEngineNode->baseNode, NULL); error0: return result; } @@ -75621,7 +77368,7 @@ MA_API void ma_engine_node_uninit(ma_engine_node* pEngineNode, const ma_allocati } ma_spatializer_uninit(&pEngineNode->spatializer, pAllocationCallbacks); - ma_linear_resampler_uninit(&pEngineNode->resampler, pAllocationCallbacks); + ma_resampler_uninit(&pEngineNode->resampler, pAllocationCallbacks); /* Free the heap last. */ if (pEngineNode->_ownsHeap) { @@ -75643,8 +77390,12 @@ MA_API ma_sound_config ma_sound_config_init_2(ma_engine* pEngine) if (pEngine != NULL) { config.monoExpansionMode = pEngine->monoExpansionMode; + config.pitchResampling = pEngine->pitchResamplingConfig; } else { config.monoExpansionMode = ma_mono_expansion_mode_default; + + config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear); + config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ } config.rangeEndInPCMFrames = ~((ma_uint64)0); @@ -75666,8 +77417,12 @@ MA_API ma_sound_group_config ma_sound_group_config_init_2(ma_engine* pEngine) if (pEngine != NULL) { config.monoExpansionMode = pEngine->monoExpansionMode; + config.pitchResampling = pEngine->pitchResamplingConfig; } else { config.monoExpansionMode = ma_mono_expansion_mode_default; + + config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear); + config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ } return config; @@ -75679,8 +77434,12 @@ MA_API ma_engine_config ma_engine_config_init(void) ma_engine_config config; MA_ZERO_OBJECT(&config); - config.listenerCount = 1; /* Always want at least one listener. */ - config.monoExpansionMode = ma_mono_expansion_mode_default; + config.listenerCount = 1; /* Always want at least one listener. */ + config.monoExpansionMode = ma_mono_expansion_mode_default; + config.resourceManagerResampling = ma_resampler_config_init(ma_format_unknown, 0, 0, 0, ma_resample_algorithm_linear); + + config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear); + config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ return config; } @@ -75761,6 +77520,7 @@ MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEng pEngine->defaultVolumeSmoothTimeInPCMFrames = engineConfig.defaultVolumeSmoothTimeInPCMFrames; pEngine->onProcess = engineConfig.onProcess; pEngine->pProcessUserData = engineConfig.pProcessUserData; + pEngine->pitchResamplingConfig = engineConfig.pitchResampling; ma_allocation_callbacks_init_copy(&pEngine->allocationCallbacks, &engineConfig.allocationCallbacks); #if !defined(MA_NO_RESOURCE_MANAGER) @@ -75943,6 +77703,7 @@ MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEng resourceManagerConfig.decodedSampleRate = ma_engine_get_sample_rate(pEngine); ma_allocation_callbacks_init_copy(&resourceManagerConfig.allocationCallbacks, &pEngine->allocationCallbacks); resourceManagerConfig.pVFS = engineConfig.pResourceManagerVFS; + resourceManagerConfig.resampling = engineConfig.resourceManagerResampling; /* The Emscripten build cannot use threads unless it's targeting pthreads. */ #if defined(MA_EMSCRIPTEN) && !defined(__EMSCRIPTEN_PTHREADS__) @@ -76668,13 +78429,32 @@ static ma_result ma_sound_init_from_data_source_internal(ma_engine* pEngine, con } + /* + When pulling data from a data source we need a processing cache to hold onto unprocessed input data from the data source + after doing resampling. + */ + if (pSound->pDataSource != NULL) { + pSound->processingCacheFramesRemaining = 0; + pSound->processingCacheCap = ma_node_graph_get_processing_size_in_frames(&pEngine->nodeGraph); + if (pSound->processingCacheCap == 0) { + pSound->processingCacheCap = 512; + } + + pSound->pProcessingCache = (float*)ma_calloc(pSound->processingCacheCap * ma_get_bytes_per_frame(ma_format_f32, engineNodeConfig.channelsIn), &pEngine->allocationCallbacks); + if (pSound->pProcessingCache == NULL) { + ma_engine_node_uninit(&pSound->engineNode, &pEngine->allocationCallbacks); + return MA_OUT_OF_MEMORY; + } + } + + /* Apply initial range and looping state to the data source if applicable. */ if (pConfig->rangeBegInPCMFrames != 0 || pConfig->rangeEndInPCMFrames != ~((ma_uint64)0)) { ma_data_source_set_range_in_pcm_frames(ma_sound_get_data_source(pSound), pConfig->rangeBegInPCMFrames, pConfig->rangeEndInPCMFrames); } if (pConfig->loopPointBegInPCMFrames != 0 || pConfig->loopPointEndInPCMFrames != ~((ma_uint64)0)) { - ma_data_source_set_range_in_pcm_frames(ma_sound_get_data_source(pSound), pConfig->loopPointBegInPCMFrames, pConfig->loopPointEndInPCMFrames); + ma_data_source_set_loop_point_in_pcm_frames(ma_sound_get_data_source(pSound), pConfig->loopPointBegInPCMFrames, pConfig->loopPointEndInPCMFrames); } ma_sound_set_looping(pSound, pConfig->isLooping || ((pConfig->flags & MA_SOUND_FLAG_LOOPING) != 0)); @@ -76736,6 +78516,7 @@ MA_API ma_result ma_sound_init_from_file_internal(ma_engine* pEngine, const ma_s result = ma_resource_manager_data_source_init_ex(pEngine->pResourceManager, &resourceManagerDataSourceConfig, pSound->pResourceManagerDataSource); if (result != MA_SUCCESS) { + ma_free(pSound->pResourceManagerDataSource, &pEngine->allocationCallbacks); goto done; } @@ -76904,6 +78685,11 @@ MA_API void ma_sound_uninit(ma_sound* pSound) */ ma_engine_node_uninit(&pSound->engineNode, &pSound->engineNode.pEngine->allocationCallbacks); + if (pSound->pProcessingCache != NULL) { + ma_free(pSound->pProcessingCache, &pSound->engineNode.pEngine->allocationCallbacks); + pSound->pProcessingCache = NULL; + } + /* Once the sound is detached from the group we can guarantee that it won't be referenced by the mixer thread which means it's safe for us to destroy the data source. */ #ifndef MA_NO_RESOURCE_MANAGER if (pSound->ownsDataSource) { @@ -76999,6 +78785,27 @@ MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_ui return ma_sound_stop_with_fade_in_pcm_frames(pSound, (fadeLengthInMilliseconds * sampleRate) / 1000); } +MA_API void ma_sound_reset_start_time(ma_sound* pSound) +{ + ma_sound_set_start_time_in_pcm_frames(pSound, 0); +} + +MA_API void ma_sound_reset_stop_time(ma_sound* pSound) +{ + ma_sound_set_stop_time_in_pcm_frames(pSound, ~(ma_uint64)0); +} + +MA_API void ma_sound_reset_fade(ma_sound* pSound) +{ + ma_sound_set_fade_in_pcm_frames(pSound, 0, 1, 0); +} + +MA_API void ma_sound_reset_stop_time_and_fade(ma_sound* pSound) +{ + ma_sound_reset_stop_time(pSound); + ma_sound_reset_fade(pSound); +} + MA_API void ma_sound_set_volume(ma_sound* pSound, float volume) { if (pSound == NULL) { @@ -77541,7 +79348,12 @@ MA_API ma_uint64 ma_sound_get_time_in_pcm_frames(const ma_sound* pSound) MA_API ma_uint64 ma_sound_get_time_in_milliseconds(const ma_sound* pSound) { - return ma_sound_get_time_in_pcm_frames(pSound) * 1000 / ma_engine_get_sample_rate(ma_sound_get_engine(pSound)); + ma_uint32 sampleRate = ma_engine_get_sample_rate(ma_sound_get_engine(pSound)); + if (sampleRate == 0) { + return 0; /* Prevent a division by zero. */ + } + + return ma_sound_get_time_in_pcm_frames(pSound) * 1000 / sampleRate; } MA_API void ma_sound_set_looping(ma_sound* pSound, ma_bool32 isLooping) @@ -77625,7 +79437,7 @@ MA_API ma_result ma_sound_seek_to_second(ma_sound* pSound, float seekPointInSeco return ma_sound_seek_to_pcm_frame(pSound, frameIndex); } -MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap) +MA_API ma_result ma_sound_get_data_format(const ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap) { if (pSound == NULL) { return MA_INVALID_ARGS; @@ -77645,7 +79457,7 @@ MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, } if (pSampleRate != NULL) { - *pSampleRate = pSound->engineNode.resampler.config.sampleRateIn; + *pSampleRate = pSound->engineNode.resampler.sampleRateIn; } if (pChannelMap != NULL) { @@ -77658,7 +79470,7 @@ MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, } } -MA_API ma_result ma_sound_get_cursor_in_pcm_frames(ma_sound* pSound, ma_uint64* pCursor) +MA_API ma_result ma_sound_get_cursor_in_pcm_frames(const ma_sound* pSound, ma_uint64* pCursor) { ma_uint64 seekTarget; @@ -77680,7 +79492,7 @@ MA_API ma_result ma_sound_get_cursor_in_pcm_frames(ma_sound* pSound, ma_uint64* } } -MA_API ma_result ma_sound_get_length_in_pcm_frames(ma_sound* pSound, ma_uint64* pLength) +MA_API ma_result ma_sound_get_length_in_pcm_frames(const ma_sound* pSound, ma_uint64* pLength) { if (pSound == NULL) { return MA_INVALID_ARGS; @@ -77694,7 +79506,7 @@ MA_API ma_result ma_sound_get_length_in_pcm_frames(ma_sound* pSound, ma_uint64* return ma_data_source_get_length_in_pcm_frames(pSound->pDataSource, pLength); } -MA_API ma_result ma_sound_get_cursor_in_seconds(ma_sound* pSound, float* pCursor) +MA_API ma_result ma_sound_get_cursor_in_seconds(const ma_sound* pSound, float* pCursor) { ma_result result; ma_uint64 cursorInPCMFrames; @@ -77720,7 +79532,7 @@ MA_API ma_result ma_sound_get_cursor_in_seconds(ma_sound* pSound, float* pCursor return MA_SUCCESS; } -MA_API ma_result ma_sound_get_length_in_seconds(ma_sound* pSound, float* pLength) +MA_API ma_result ma_sound_get_length_in_seconds(const ma_sound* pSound, float* pLength) { if (pSound == NULL) { return MA_INVALID_ARGS; @@ -78539,12 +80351,12 @@ MA_PRIVATE ma_bool32 ma_dr_wav__seek_forward(ma_dr_wav_seek_proc onSeek, ma_uint ma_uint64 bytesRemainingToSeek = offset; while (bytesRemainingToSeek > 0) { if (bytesRemainingToSeek > 0x7FFFFFFF) { - if (!onSeek(pUserData, 0x7FFFFFFF, ma_dr_wav_seek_origin_current)) { + if (!onSeek(pUserData, 0x7FFFFFFF, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } bytesRemainingToSeek -= 0x7FFFFFFF; } else { - if (!onSeek(pUserData, (int)bytesRemainingToSeek, ma_dr_wav_seek_origin_current)) { + if (!onSeek(pUserData, (int)bytesRemainingToSeek, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } bytesRemainingToSeek = 0; @@ -78555,17 +80367,17 @@ MA_PRIVATE ma_bool32 ma_dr_wav__seek_forward(ma_dr_wav_seek_proc onSeek, ma_uint MA_PRIVATE ma_bool32 ma_dr_wav__seek_from_start(ma_dr_wav_seek_proc onSeek, ma_uint64 offset, void* pUserData) { if (offset <= 0x7FFFFFFF) { - return onSeek(pUserData, (int)offset, ma_dr_wav_seek_origin_start); + return onSeek(pUserData, (int)offset, MA_DR_WAV_SEEK_SET); } - if (!onSeek(pUserData, 0x7FFFFFFF, ma_dr_wav_seek_origin_start)) { + if (!onSeek(pUserData, 0x7FFFFFFF, MA_DR_WAV_SEEK_SET)) { return MA_FALSE; } offset -= 0x7FFFFFFF; for (;;) { if (offset <= 0x7FFFFFFF) { - return onSeek(pUserData, (int)offset, ma_dr_wav_seek_origin_current); + return onSeek(pUserData, (int)offset, MA_DR_WAV_SEEK_CUR); } - if (!onSeek(pUserData, 0x7FFFFFFF, ma_dr_wav_seek_origin_current)) { + if (!onSeek(pUserData, 0x7FFFFFFF, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } offset -= 0x7FFFFFFF; @@ -78588,7 +80400,7 @@ MA_PRIVATE ma_bool32 ma_dr_wav__on_seek(ma_dr_wav_seek_proc onSeek, void* pUserD if (!onSeek(pUserData, offset, origin)) { return MA_FALSE; } - if (origin == ma_dr_wav_seek_origin_start) { + if (origin == MA_DR_WAV_SEEK_SET) { *pCursor = offset; } else { *pCursor += offset; @@ -78707,12 +80519,12 @@ MA_PRIVATE ma_uint64 ma_dr_wav__read_smpl_to_metadata_obj(ma_dr_wav__metadata_pa ma_uint8 smplLoopData[MA_DR_WAV_SMPL_LOOP_BYTES]; bytesJustRead = ma_dr_wav__metadata_parser_read(pParser, smplLoopData, sizeof(smplLoopData), &totalBytesRead); if (bytesJustRead == sizeof(smplLoopData)) { - pMetadata->data.smpl.pLoops[iSampleLoop].cuePointId = ma_dr_wav_bytes_to_u32(smplLoopData + 0); - pMetadata->data.smpl.pLoops[iSampleLoop].type = ma_dr_wav_bytes_to_u32(smplLoopData + 4); - pMetadata->data.smpl.pLoops[iSampleLoop].firstSampleByteOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 8); - pMetadata->data.smpl.pLoops[iSampleLoop].lastSampleByteOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 12); - pMetadata->data.smpl.pLoops[iSampleLoop].sampleFraction = ma_dr_wav_bytes_to_u32(smplLoopData + 16); - pMetadata->data.smpl.pLoops[iSampleLoop].playCount = ma_dr_wav_bytes_to_u32(smplLoopData + 20); + pMetadata->data.smpl.pLoops[iSampleLoop].cuePointId = ma_dr_wav_bytes_to_u32(smplLoopData + 0); + pMetadata->data.smpl.pLoops[iSampleLoop].type = ma_dr_wav_bytes_to_u32(smplLoopData + 4); + pMetadata->data.smpl.pLoops[iSampleLoop].firstSampleOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 8); + pMetadata->data.smpl.pLoops[iSampleLoop].lastSampleOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 12); + pMetadata->data.smpl.pLoops[iSampleLoop].sampleFraction = ma_dr_wav_bytes_to_u32(smplLoopData + 16); + pMetadata->data.smpl.pLoops[iSampleLoop].playCount = ma_dr_wav_bytes_to_u32(smplLoopData + 20); } else { break; } @@ -78756,7 +80568,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav__read_cue_to_metadata_obj(ma_dr_wav__metadata_par pMetadata->data.cue.pCuePoints[iCuePoint].dataChunkId[3] = cuePointData[11]; pMetadata->data.cue.pCuePoints[iCuePoint].chunkStart = ma_dr_wav_bytes_to_u32(cuePointData + 12); pMetadata->data.cue.pCuePoints[iCuePoint].blockStart = ma_dr_wav_bytes_to_u32(cuePointData + 16); - pMetadata->data.cue.pCuePoints[iCuePoint].sampleByteOffset = ma_dr_wav_bytes_to_u32(cuePointData + 20); + pMetadata->data.cue.pCuePoints[iCuePoint].sampleOffset = ma_dr_wav_bytes_to_u32(cuePointData + 20); } else { break; } @@ -79096,7 +80908,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse if (pParser->stage == ma_dr_wav__metadata_parser_stage_count) { ma_uint8 buffer[4]; size_t bytesJustRead; - if (!pParser->onSeek(pParser->pReadSeekUserData, 28, ma_dr_wav_seek_origin_current)) { + if (!pParser->onSeek(pParser->pReadSeekUserData, 28, MA_DR_WAV_SEEK_CUR)) { return bytesRead; } bytesRead += 28; @@ -79191,7 +81003,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse return bytesRead; } allocSizeNeeded += ma_dr_wav__strlen(buffer) + 1; - allocSizeNeeded += (size_t)pChunkHeader->sizeInBytes - MA_DR_WAV_BEXT_BYTES; + allocSizeNeeded += (size_t)pChunkHeader->sizeInBytes - MA_DR_WAV_BEXT_BYTES + 1; ma_dr_wav__metadata_request_extra_memory_for_stage_2(pParser, allocSizeNeeded, 1); pParser->metadataCount += 1; } else { @@ -79274,6 +81086,16 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_album); } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_tracknumber, "ITRK")) { subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_tracknumber); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_location, "IARL")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_location); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_organization, "ICMS")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_organization); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_keywords, "IKEY")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_keywords); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_medium, "IMED")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_medium); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_description, "ISBJ")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_description); } else if ((allowedMetadataTypes & ma_dr_wav_metadata_type_unknown) != 0) { subchunkBytesRead = ma_dr_wav__metadata_process_unknown_chunk(pParser, subchunkId, subchunkDataSize, listType); } @@ -79281,13 +81103,13 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse MA_DR_WAV_ASSERT(subchunkBytesRead <= subchunkDataSize); if (subchunkBytesRead < subchunkDataSize) { ma_uint64 bytesToSeek = subchunkDataSize - subchunkBytesRead; - if (!pParser->onSeek(pParser->pReadSeekUserData, (int)bytesToSeek, ma_dr_wav_seek_origin_current)) { + if (!pParser->onSeek(pParser->pReadSeekUserData, (int)bytesToSeek, MA_DR_WAV_SEEK_CUR)) { break; } bytesRead += bytesToSeek; } if ((subchunkDataSize % 2) == 1) { - if (!pParser->onSeek(pParser->pReadSeekUserData, 1, ma_dr_wav_seek_origin_current)) { + if (!pParser->onSeek(pParser->pReadSeekUserData, 1, MA_DR_WAV_SEEK_CUR)) { break; } bytesRead += 1; @@ -79324,7 +81146,7 @@ MA_API ma_uint16 ma_dr_wav_fmt_get_format(const ma_dr_wav_fmt* pFMT) return ma_dr_wav_bytes_to_u16(pFMT->subFormat); } } -MA_PRIVATE ma_bool32 ma_dr_wav_preinit(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pReadSeekUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_PRIVATE ma_bool32 ma_dr_wav_preinit(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pReadSeekTellUserData, const ma_allocation_callbacks* pAllocationCallbacks) { if (pWav == NULL || onRead == NULL || onSeek == NULL) { return MA_FALSE; @@ -79332,7 +81154,8 @@ MA_PRIVATE ma_bool32 ma_dr_wav_preinit(ma_dr_wav* pWav, ma_dr_wav_read_proc onRe MA_DR_WAV_ZERO_MEMORY(pWav, sizeof(*pWav)); pWav->onRead = onRead; pWav->onSeek = onSeek; - pWav->pUserData = pReadSeekUserData; + pWav->onTell = onTell; + pWav->pUserData = pReadSeekTellUserData; pWav->allocationCallbacks = ma_dr_wav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) { return MA_FALSE; @@ -79546,14 +81369,14 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p fmt.channelMask = ma_dr_wav_bytes_to_u32_ex(fmtext + 2, pWav->container); ma_dr_wav_bytes_to_guid(fmtext + 6, fmt.subFormat); } else { - if (pWav->onSeek(pWav->pUserData, fmt.extendedSize, ma_dr_wav_seek_origin_current) == MA_FALSE) { + if (pWav->onSeek(pWav->pUserData, fmt.extendedSize, MA_DR_WAV_SEEK_CUR) == MA_FALSE) { return MA_FALSE; } } cursor += fmt.extendedSize; bytesReadSoFar += fmt.extendedSize; } - if (pWav->onSeek(pWav->pUserData, (int)(header.sizeInBytes - bytesReadSoFar), ma_dr_wav_seek_origin_current) == MA_FALSE) { + if (pWav->onSeek(pWav->pUserData, (int)(header.sizeInBytes - bytesReadSoFar), MA_DR_WAV_SEEK_CUR) == MA_FALSE) { return MA_FALSE; } cursor += (header.sizeInBytes - bytesReadSoFar); @@ -79704,15 +81527,26 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p return MA_FALSE; } offset = ma_dr_wav_bytes_to_u32_ex(offsetAndBlockSizeData + 0, pWav->container); - if (ma_dr_wav__seek_forward(pWav->onSeek, offset, pWav->pUserData) == MA_FALSE) { - return MA_FALSE; - } - cursor += offset; - pWav->dataChunkDataPos = cursor; + pWav->dataChunkDataPos = cursor + offset; dataChunkSize = chunkSize; - if (sequential || !isProcessingMetadata) { - break; + if (dataChunkSize > offset) { + dataChunkSize -= offset; + } else { + dataChunkSize = 0; + } + if (sequential) { + if (foundChunk_fmt) { + if (ma_dr_wav__seek_forward(pWav->onSeek, offset, pWav->pUserData) == MA_FALSE) { + return MA_FALSE; + } + cursor += offset; + break; + } else { + return MA_FALSE; + } } else { + chunkSize += header.paddingSize; + chunkSize -= sizeof(offsetAndBlockSizeData); if (ma_dr_wav__seek_forward(pWav->onSeek, chunkSize, pWav->pUserData) == MA_FALSE) { break; } @@ -79776,6 +81610,17 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p pWav->pMetadata = metadataParser.pMetadata; pWav->metadataCount = metadataParser.metadataCount; } + if (pWav->onTell != NULL && pWav->onSeek != NULL) { + if (pWav->onSeek(pWav->pUserData, 0, MA_DR_WAV_SEEK_END) == MA_TRUE) { + ma_int64 fileSize; + if (pWav->onTell(pWav->pUserData, &fileSize)) { + if (dataChunkSize + pWav->dataChunkDataPos > (ma_uint64)fileSize) { + dataChunkSize = (ma_uint64)fileSize - pWav->dataChunkDataPos; + } + } + } else { + } + } if (dataChunkSize == 0xFFFFFFFF && (pWav->container == ma_dr_wav_container_riff || pWav->container == ma_dr_wav_container_rifx) && pWav->isSequentialWrite == MA_FALSE) { dataChunkSize = 0; for (;;) { @@ -79795,8 +81640,14 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p pWav->sampleRate = fmt.sampleRate; pWav->channels = fmt.channels; pWav->bitsPerSample = fmt.bitsPerSample; - pWav->bytesRemaining = dataChunkSize; pWav->translatedFormatTag = translatedFormatTag; + if (!ma_dr_wav__is_compressed_format_tag(translatedFormatTag)) { + ma_uint32 bytesPerFrame = ma_dr_wav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame > 0) { + dataChunkSize -= (dataChunkSize % bytesPerFrame); + } + } + pWav->bytesRemaining = dataChunkSize; pWav->dataChunkDataSize = dataChunkSize; if (sampleCountFromFactChunk != 0) { pWav->totalPCMFrameCount = sampleCountFromFactChunk; @@ -79851,20 +81702,20 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p #endif return MA_TRUE; } -MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_wav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0, pAllocationCallbacks); + return ma_dr_wav_init_ex(pWav, onRead, onSeek, onTell, NULL, pUserData, NULL, 0, pAllocationCallbacks); } -MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, ma_dr_wav_chunk_proc onChunk, void* pReadSeekTellUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) { - if (!ma_dr_wav_preinit(pWav, onRead, onSeek, pReadSeekUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, onRead, onSeek, onTell, pReadSeekTellUserData, pAllocationCallbacks)) { return MA_FALSE; } return ma_dr_wav_init__internal(pWav, onChunk, pChunkUserData, flags); } -MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) { - if (!ma_dr_wav_preinit(pWav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return MA_FALSE; } return ma_dr_wav_init__internal(pWav, NULL, NULL, flags | MA_DR_WAV_WITH_METADATA); @@ -80026,8 +81877,8 @@ MA_PRIVATE size_t ma_dr_wav__write_or_count_metadata(ma_dr_wav* pWav, ma_dr_wav_ for (iLoop = 0; iLoop < pMetadata->data.smpl.sampleLoopCount; ++iLoop) { bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].cuePointId); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].type); - bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].firstSampleByteOffset); - bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].lastSampleByteOffset); + bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].firstSampleOffset); + bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].lastSampleOffset); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].sampleFraction); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].playCount); } @@ -80061,7 +81912,7 @@ MA_PRIVATE size_t ma_dr_wav__write_or_count_metadata(ma_dr_wav* pWav, ma_dr_wav_ bytesWritten += ma_dr_wav__write_or_count(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].dataChunkId, 4); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].chunkStart); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].blockStart); - bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].sampleByteOffset); + bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].sampleOffset); } } break; case ma_dr_wav_metadata_type_acid: @@ -80147,15 +81998,20 @@ MA_PRIVATE size_t ma_dr_wav__write_or_count_metadata(ma_dr_wav* pWav, ma_dr_wav_ if (pMetadata->type & ma_dr_wav_metadata_type_list_all_info_strings) { const char* pID = NULL; switch (pMetadata->type) { - case ma_dr_wav_metadata_type_list_info_software: pID = "ISFT"; break; - case ma_dr_wav_metadata_type_list_info_copyright: pID = "ICOP"; break; - case ma_dr_wav_metadata_type_list_info_title: pID = "INAM"; break; - case ma_dr_wav_metadata_type_list_info_artist: pID = "IART"; break; - case ma_dr_wav_metadata_type_list_info_comment: pID = "ICMT"; break; - case ma_dr_wav_metadata_type_list_info_date: pID = "ICRD"; break; - case ma_dr_wav_metadata_type_list_info_genre: pID = "IGNR"; break; - case ma_dr_wav_metadata_type_list_info_album: pID = "IPRD"; break; - case ma_dr_wav_metadata_type_list_info_tracknumber: pID = "ITRK"; break; + case ma_dr_wav_metadata_type_list_info_software: pID = "ISFT"; break; + case ma_dr_wav_metadata_type_list_info_copyright: pID = "ICOP"; break; + case ma_dr_wav_metadata_type_list_info_title: pID = "INAM"; break; + case ma_dr_wav_metadata_type_list_info_artist: pID = "IART"; break; + case ma_dr_wav_metadata_type_list_info_comment: pID = "ICMT"; break; + case ma_dr_wav_metadata_type_list_info_date: pID = "ICRD"; break; + case ma_dr_wav_metadata_type_list_info_genre: pID = "IGNR"; break; + case ma_dr_wav_metadata_type_list_info_album: pID = "IPRD"; break; + case ma_dr_wav_metadata_type_list_info_tracknumber: pID = "ITRK"; break; + case ma_dr_wav_metadata_type_list_info_location: pID = "IARL"; break; + case ma_dr_wav_metadata_type_list_info_organization: pID = "ICMS"; break; + case ma_dr_wav_metadata_type_list_info_keywords: pID = "IKEY"; break; + case ma_dr_wav_metadata_type_list_info_medium: pID = "IMED"; break; + case ma_dr_wav_metadata_type_list_info_description: pID = "ISBJ"; break; default: break; } MA_DR_WAV_ASSERT(pID != NULL); @@ -80370,7 +82226,7 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init_write__internal(ma_dr_wav* pWav, const ma_dr } pWav->dataChunkDataSizeTargetWrite = initialDataChunkSize; if (pFormat->container == ma_dr_wav_container_riff) { - ma_uint32 chunkSizeRIFF = 28 + (ma_uint32)initialDataChunkSize; + ma_uint32 chunkSizeRIFF = 36 + (ma_uint32)initialDataChunkSize; runningPos += ma_dr_wav__write(pWav, "RIFF", 4); runningPos += ma_dr_wav__write_u32ne_to_le(pWav, chunkSizeRIFF); runningPos += ma_dr_wav__write(pWav, "WAVE", 4); @@ -80493,7 +82349,31 @@ MA_PRIVATE size_t ma_dr_wav__on_write_stdio(void* pUserData, const void* pData, } MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_stdio(void* pUserData, int offset, ma_dr_wav_seek_origin origin) { - return fseek((FILE*)pUserData, offset, (origin == ma_dr_wav_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0; + int whence = SEEK_SET; + if (origin == MA_DR_WAV_SEEK_CUR) { + whence = SEEK_CUR; + } else if (origin == MA_DR_WAV_SEEK_END) { + whence = SEEK_END; + } + return fseek((FILE*)pUserData, offset, whence) == 0; +} +MA_PRIVATE ma_bool32 ma_dr_wav__on_tell_stdio(void* pUserData, ma_int64* pCursor) +{ + FILE* pFileStdio = (FILE*)pUserData; + ma_int64 result; + MA_DR_WAV_ASSERT(pFileStdio != NULL); + MA_DR_WAV_ASSERT(pCursor != NULL); +#if defined(_WIN32) && !defined(NXDK) + #if defined(_MSC_VER) && _MSC_VER > 1200 + result = _ftelli64(pFileStdio); + #else + result = ftell(pFileStdio); + #endif +#else + result = ftell(pFileStdio); +#endif + *pCursor = result; + return MA_TRUE; } MA_API ma_bool32 ma_dr_wav_init_file(ma_dr_wav* pWav, const char* filename, const ma_allocation_callbacks* pAllocationCallbacks) { @@ -80502,7 +82382,7 @@ MA_API ma_bool32 ma_dr_wav_init_file(ma_dr_wav* pWav, const char* filename, cons MA_PRIVATE ma_bool32 ma_dr_wav_init_file__internal_FILE(ma_dr_wav* pWav, FILE* pFile, ma_dr_wav_chunk_proc onChunk, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) { ma_bool32 result; - result = ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_stdio, ma_dr_wav__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + result = ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_stdio, ma_dr_wav__on_seek_stdio, ma_dr_wav__on_tell_stdio, (void*)pFile, pAllocationCallbacks); if (result != MA_TRUE) { fclose(pFile); return result; @@ -80639,25 +82519,26 @@ MA_PRIVATE size_t ma_dr_wav__on_read_memory(void* pUserData, void* pBufferOut, s MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_memory(void* pUserData, int offset, ma_dr_wav_seek_origin origin) { ma_dr_wav* pWav = (ma_dr_wav*)pUserData; + ma_int64 newCursor; MA_DR_WAV_ASSERT(pWav != NULL); - if (origin == ma_dr_wav_seek_origin_current) { - if (offset > 0) { - if (pWav->memoryStream.currentReadPos + offset > pWav->memoryStream.dataSize) { - return MA_FALSE; - } - } else { - if (pWav->memoryStream.currentReadPos < (size_t)-offset) { - return MA_FALSE; - } - } - pWav->memoryStream.currentReadPos += offset; + if (origin == MA_DR_WAV_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_WAV_SEEK_CUR) { + newCursor = (ma_int64)pWav->memoryStream.currentReadPos; + } else if (origin == MA_DR_WAV_SEEK_END) { + newCursor = (ma_int64)pWav->memoryStream.dataSize; } else { - if ((ma_uint32)offset <= pWav->memoryStream.dataSize) { - pWav->memoryStream.currentReadPos = offset; - } else { - return MA_FALSE; - } + MA_DR_WAV_ASSERT(!"Invalid seek origin"); + return MA_FALSE; + } + newCursor += offset; + if (newCursor < 0) { + return MA_FALSE; + } + if ((size_t)newCursor > pWav->memoryStream.dataSize) { + return MA_FALSE; } + pWav->memoryStream.currentReadPos = (size_t)newCursor; return MA_TRUE; } MA_PRIVATE size_t ma_dr_wav__on_write_memory(void* pUserData, const void* pDataIn, size_t bytesToWrite) @@ -80691,25 +82572,34 @@ MA_PRIVATE size_t ma_dr_wav__on_write_memory(void* pUserData, const void* pDataI MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_memory_write(void* pUserData, int offset, ma_dr_wav_seek_origin origin) { ma_dr_wav* pWav = (ma_dr_wav*)pUserData; + ma_int64 newCursor; MA_DR_WAV_ASSERT(pWav != NULL); - if (origin == ma_dr_wav_seek_origin_current) { - if (offset > 0) { - if (pWav->memoryStreamWrite.currentWritePos + offset > pWav->memoryStreamWrite.dataSize) { - offset = (int)(pWav->memoryStreamWrite.dataSize - pWav->memoryStreamWrite.currentWritePos); - } - } else { - if (pWav->memoryStreamWrite.currentWritePos < (size_t)-offset) { - offset = -(int)pWav->memoryStreamWrite.currentWritePos; - } - } - pWav->memoryStreamWrite.currentWritePos += offset; + if (origin == MA_DR_WAV_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_WAV_SEEK_CUR) { + newCursor = (ma_int64)pWav->memoryStreamWrite.currentWritePos; + } else if (origin == MA_DR_WAV_SEEK_END) { + newCursor = (ma_int64)pWav->memoryStreamWrite.dataSize; } else { - if ((ma_uint32)offset <= pWav->memoryStreamWrite.dataSize) { - pWav->memoryStreamWrite.currentWritePos = offset; - } else { - pWav->memoryStreamWrite.currentWritePos = pWav->memoryStreamWrite.dataSize; - } + MA_DR_WAV_ASSERT(!"Invalid seek origin"); + return MA_FALSE; + } + newCursor += offset; + if (newCursor < 0) { + return MA_FALSE; + } + if ((size_t)newCursor > pWav->memoryStreamWrite.dataSize) { + return MA_FALSE; } + pWav->memoryStreamWrite.currentWritePos = (size_t)newCursor; + return MA_TRUE; +} +MA_PRIVATE ma_bool32 ma_dr_wav__on_tell_memory(void* pUserData, ma_int64* pCursor) +{ + ma_dr_wav* pWav = (ma_dr_wav*)pUserData; + MA_DR_WAV_ASSERT(pWav != NULL); + MA_DR_WAV_ASSERT(pCursor != NULL); + *pCursor = (ma_int64)pWav->memoryStream.currentReadPos; return MA_TRUE; } MA_API ma_bool32 ma_dr_wav_init_memory(ma_dr_wav* pWav, const void* data, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) @@ -80721,7 +82611,7 @@ MA_API ma_bool32 ma_dr_wav_init_memory_ex(ma_dr_wav* pWav, const void* data, siz if (data == NULL || dataSize == 0) { return MA_FALSE; } - if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, pWav, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, ma_dr_wav__on_tell_memory, pWav, pAllocationCallbacks)) { return MA_FALSE; } pWav->memoryStream.data = (const ma_uint8*)data; @@ -80734,7 +82624,7 @@ MA_API ma_bool32 ma_dr_wav_init_memory_with_metadata(ma_dr_wav* pWav, const void if (data == NULL || dataSize == 0) { return MA_FALSE; } - if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, pWav, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, ma_dr_wav__on_tell_memory, pWav, pAllocationCallbacks)) { return MA_FALSE; } pWav->memoryStream.data = (const ma_uint8*)data; @@ -80793,30 +82683,30 @@ MA_API ma_result ma_dr_wav_uninit(ma_dr_wav* pWav) } if (pWav->onSeek && !pWav->isSequentialWrite) { if (pWav->container == ma_dr_wav_container_riff) { - if (pWav->onSeek(pWav->pUserData, 4, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, 4, MA_DR_WAV_SEEK_SET)) { ma_uint32 riffChunkSize = ma_dr_wav__riff_chunk_size_riff(pWav->dataChunkDataSize, pWav->pMetadata, pWav->metadataCount); ma_dr_wav__write_u32ne_to_le(pWav, riffChunkSize); } - if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 4, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 4, MA_DR_WAV_SEEK_SET)) { ma_uint32 dataChunkSize = ma_dr_wav__data_chunk_size_riff(pWav->dataChunkDataSize); ma_dr_wav__write_u32ne_to_le(pWav, dataChunkSize); } } else if (pWav->container == ma_dr_wav_container_w64) { - if (pWav->onSeek(pWav->pUserData, 16, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, 16, MA_DR_WAV_SEEK_SET)) { ma_uint64 riffChunkSize = ma_dr_wav__riff_chunk_size_w64(pWav->dataChunkDataSize); ma_dr_wav__write_u64ne_to_le(pWav, riffChunkSize); } - if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 8, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 8, MA_DR_WAV_SEEK_SET)) { ma_uint64 dataChunkSize = ma_dr_wav__data_chunk_size_w64(pWav->dataChunkDataSize); ma_dr_wav__write_u64ne_to_le(pWav, dataChunkSize); } } else if (pWav->container == ma_dr_wav_container_rf64) { int ds64BodyPos = 12 + 8; - if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 0, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 0, MA_DR_WAV_SEEK_SET)) { ma_uint64 riffChunkSize = ma_dr_wav__riff_chunk_size_rf64(pWav->dataChunkDataSize, pWav->pMetadata, pWav->metadataCount); ma_dr_wav__write_u64ne_to_le(pWav, riffChunkSize); } - if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 8, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 8, MA_DR_WAV_SEEK_SET)) { ma_uint64 dataChunkSize = ma_dr_wav__data_chunk_size_rf64(pWav->dataChunkDataSize); ma_dr_wav__write_u64ne_to_le(pWav, dataChunkSize); } @@ -80863,7 +82753,7 @@ MA_API size_t ma_dr_wav_read_raw(ma_dr_wav* pWav, size_t bytesToRead, void* pBuf if (bytesToSeek > 0x7FFFFFFF) { bytesToSeek = 0x7FFFFFFF; } - if (pWav->onSeek(pWav->pUserData, (int)bytesToSeek, ma_dr_wav_seek_origin_current) == MA_FALSE) { + if (pWav->onSeek(pWav->pUserData, (int)bytesToSeek, MA_DR_WAV_SEEK_CUR) == MA_FALSE) { break; } bytesRead += bytesToSeek; @@ -80962,7 +82852,7 @@ MA_PRIVATE ma_bool32 ma_dr_wav_seek_to_first_pcm_frame(ma_dr_wav* pWav) if (pWav->onWrite != NULL) { return MA_FALSE; } - if (!pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos, ma_dr_wav_seek_origin_start)) { + if (!pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos, MA_DR_WAV_SEEK_SET)) { return MA_FALSE; } if (ma_dr_wav__is_compressed_format_tag(pWav->translatedFormatTag)) { @@ -81043,7 +82933,7 @@ MA_API ma_bool32 ma_dr_wav_seek_to_pcm_frame(ma_dr_wav* pWav, ma_uint64 targetFr } while (offset > 0) { int offset32 = ((offset > INT_MAX) ? INT_MAX : (int)offset); - if (!pWav->onSeek(pWav->pUserData, offset32, ma_dr_wav_seek_origin_current)) { + if (!pWav->onSeek(pWav->pUserData, offset32, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } pWav->readCursorInPCMFrames += offset32 / bytesPerFrame; @@ -81169,12 +83059,12 @@ MA_API ma_uint64 ma_dr_wav_write_pcm_frames(ma_dr_wav* pWav, ma_uint64 framesToW MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_uint64 framesToRead, ma_int16* pBufferOut) { ma_uint64 totalFramesRead = 0; - static ma_int32 adaptationTable[] = { + static const ma_int32 adaptationTable[] = { 230, 230, 230, 230, 307, 409, 512, 614, 768, 614, 512, 409, 307, 230, 230, 230 }; - static ma_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 }; - static ma_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 }; + static const ma_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 }; + static const ma_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 }; MA_DR_WAV_ASSERT(pWav != NULL); MA_DR_WAV_ASSERT(framesToRead > 0); while (pWav->readCursorInPCMFrames < pWav->totalPCMFrameCount) { @@ -81193,7 +83083,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][0]; pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.cachedFrameCount = 2; - if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table)) { + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) { return totalFramesRead; } } else { @@ -81215,7 +83105,8 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[1][1]; pWav->msadpcm.cachedFrameCount = 2; - if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) { + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table) || + pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) { return totalFramesRead; } } @@ -81252,6 +83143,9 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ if (pWav->channels == 1) { ma_int32 newSample0; ma_int32 newSample1; + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) { + return totalFramesRead; + } newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; newSample0 += nibble0 * pWav->msadpcm.delta[0]; newSample0 = ma_dr_wav_clamp(newSample0, -32768, 32767); @@ -81276,6 +83170,9 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ } else { ma_int32 newSample0; ma_int32 newSample1; + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) { + return totalFramesRead; + } newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; newSample0 += nibble0 * pWav->msadpcm.delta[0]; newSample0 = ma_dr_wav_clamp(newSample0, -32768, 32767); @@ -81285,6 +83182,9 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ } pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.prevFrames[0][1] = newSample0; + if (pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) { + return totalFramesRead; + } newSample1 = ((pWav->msadpcm.prevFrames[1][1] * coeff1Table[pWav->msadpcm.predictor[1]]) + (pWav->msadpcm.prevFrames[1][0] * coeff2Table[pWav->msadpcm.predictor[1]])) >> 8; newSample1 += nibble1 * pWav->msadpcm.delta[1]; newSample1 = ma_dr_wav_clamp(newSample1, -32768, 32767); @@ -81307,11 +83207,11 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint { ma_uint64 totalFramesRead = 0; ma_uint32 iChannel; - static ma_int32 indexTable[16] = { + static const ma_int32 indexTable[16] = { -1, -1, -1, -1, 2, 4, 6, 8, -1, -1, -1, -1, 2, 4, 6, 8 }; - static ma_int32 stepTable[89] = { + static const ma_int32 stepTable[89] = { 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 19, 21, 23, 25, 28, 31, 34, 37, 41, 45, 50, 55, 60, 66, 73, 80, 88, 97, 107, 118, @@ -81334,7 +83234,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint } pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); if (header[2] >= ma_dr_wav_countof(stepTable)) { - pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, ma_dr_wav_seek_origin_current); + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, MA_DR_WAV_SEEK_CUR); pWav->ima.bytesRemainingInBlock = 0; return totalFramesRead; } @@ -81349,7 +83249,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint } pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); if (header[2] >= ma_dr_wav_countof(stepTable) || header[6] >= ma_dr_wav_countof(stepTable)) { - pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, ma_dr_wav_seek_origin_current); + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, MA_DR_WAV_SEEK_CUR); pWav->ima.bytesRemainingInBlock = 0; return totalFramesRead; } @@ -81424,7 +83324,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint return totalFramesRead; } #ifndef MA_DR_WAV_NO_CONVERSION_API -static unsigned short g_ma_dr_wavAlawTable[256] = { +static const unsigned short ma_dr_wav_gAlawTable[256] = { 0xEA80, 0xEB80, 0xE880, 0xE980, 0xEE80, 0xEF80, 0xEC80, 0xED80, 0xE280, 0xE380, 0xE080, 0xE180, 0xE680, 0xE780, 0xE480, 0xE580, 0xF540, 0xF5C0, 0xF440, 0xF4C0, 0xF740, 0xF7C0, 0xF640, 0xF6C0, 0xF140, 0xF1C0, 0xF040, 0xF0C0, 0xF340, 0xF3C0, 0xF240, 0xF2C0, 0xAA00, 0xAE00, 0xA200, 0xA600, 0xBA00, 0xBE00, 0xB200, 0xB600, 0x8A00, 0x8E00, 0x8200, 0x8600, 0x9A00, 0x9E00, 0x9200, 0x9600, @@ -81442,7 +83342,7 @@ static unsigned short g_ma_dr_wavAlawTable[256] = { 0x0560, 0x0520, 0x05E0, 0x05A0, 0x0460, 0x0420, 0x04E0, 0x04A0, 0x0760, 0x0720, 0x07E0, 0x07A0, 0x0660, 0x0620, 0x06E0, 0x06A0, 0x02B0, 0x0290, 0x02F0, 0x02D0, 0x0230, 0x0210, 0x0270, 0x0250, 0x03B0, 0x0390, 0x03F0, 0x03D0, 0x0330, 0x0310, 0x0370, 0x0350 }; -static unsigned short g_ma_dr_wavMulawTable[256] = { +static const unsigned short ma_dr_wav_gMulawTable[256] = { 0x8284, 0x8684, 0x8A84, 0x8E84, 0x9284, 0x9684, 0x9A84, 0x9E84, 0xA284, 0xA684, 0xAA84, 0xAE84, 0xB284, 0xB684, 0xBA84, 0xBE84, 0xC184, 0xC384, 0xC584, 0xC784, 0xC984, 0xCB84, 0xCD84, 0xCF84, 0xD184, 0xD384, 0xD584, 0xD784, 0xD984, 0xDB84, 0xDD84, 0xDF84, 0xE104, 0xE204, 0xE304, 0xE404, 0xE504, 0xE604, 0xE704, 0xE804, 0xE904, 0xEA04, 0xEB04, 0xEC04, 0xED04, 0xEE04, 0xEF04, 0xF004, @@ -81462,11 +83362,11 @@ static unsigned short g_ma_dr_wavMulawTable[256] = { }; static MA_INLINE ma_int16 ma_dr_wav__alaw_to_s16(ma_uint8 sampleIn) { - return (short)g_ma_dr_wavAlawTable[sampleIn]; + return (short)ma_dr_wav_gAlawTable[sampleIn]; } static MA_INLINE ma_int16 ma_dr_wav__mulaw_to_s16(ma_uint8 sampleIn) { - return (short)g_ma_dr_wavMulawTable[sampleIn]; + return (short)ma_dr_wav_gMulawTable[sampleIn]; } MA_PRIVATE void ma_dr_wav__pcm_to_s16(ma_int16* pOut, const ma_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) { @@ -82529,6 +84429,10 @@ MA_PRIVATE ma_int16* ma_dr_wav__read_pcm_frames_and_close_s16(ma_dr_wav* pWav, u ma_int16* pSampleData; ma_uint64 framesRead; MA_DR_WAV_ASSERT(pWav != NULL); + if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(ma_int16)) { + ma_dr_wav_uninit(pWav); + return NULL; + } sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(ma_int16); if (sampleDataSize > MA_SIZE_MAX) { ma_dr_wav_uninit(pWav); @@ -82563,6 +84467,10 @@ MA_PRIVATE float* ma_dr_wav__read_pcm_frames_and_close_f32(ma_dr_wav* pWav, unsi float* pSampleData; ma_uint64 framesRead; MA_DR_WAV_ASSERT(pWav != NULL); + if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(float)) { + ma_dr_wav_uninit(pWav); + return NULL; + } sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(float); if (sampleDataSize > MA_SIZE_MAX) { ma_dr_wav_uninit(pWav); @@ -82597,6 +84505,10 @@ MA_PRIVATE ma_int32* ma_dr_wav__read_pcm_frames_and_close_s32(ma_dr_wav* pWav, u ma_int32* pSampleData; ma_uint64 framesRead; MA_DR_WAV_ASSERT(pWav != NULL); + if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(ma_int32)) { + ma_dr_wav_uninit(pWav); + return NULL; + } sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(ma_int32); if (sampleDataSize > MA_SIZE_MAX) { ma_dr_wav_uninit(pWav); @@ -82625,7 +84537,7 @@ MA_PRIVATE ma_int32* ma_dr_wav__read_pcm_frames_and_close_s32(ma_dr_wav* pWav, u } return pSampleData; } -MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_wav wav; if (channelsOut) { @@ -82637,12 +84549,12 @@ MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRe if (totalFrameCountOut) { *totalFrameCountOut = 0; } - if (!ma_dr_wav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_init(&wav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_wav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); } -MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_wav wav; if (channelsOut) { @@ -82654,12 +84566,12 @@ MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, if (totalFrameCountOut) { *totalFrameCountOut = 0; } - if (!ma_dr_wav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_init(&wav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_wav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); } -MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_wav wav; if (channelsOut) { @@ -82671,7 +84583,7 @@ MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRe if (totalFrameCountOut) { *totalFrameCountOut = 0; } - if (!ma_dr_wav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_init(&wav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_wav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); @@ -83979,7 +85891,7 @@ static MA_INLINE ma_uint32 ma_dr_flac__clz_lzcnt(ma_dr_flac_cache_t x) { ma_uint64 r; __asm__ __volatile__ ( - "lzcnt{ %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" + "rep; bsr{q %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" ); return (ma_uint32)r; } @@ -83987,11 +85899,11 @@ static MA_INLINE ma_uint32 ma_dr_flac__clz_lzcnt(ma_dr_flac_cache_t x) { ma_uint32 r; __asm__ __volatile__ ( - "lzcnt{l %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" + "rep; bsr{l %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" ); return r; } - #elif defined(MA_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 5) && !defined(__ARM_ARCH_6M__) && !defined(MA_64BIT) + #elif defined(MA_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 5) && !defined(__ARM_ARCH_6M__) && !(defined(__thumb__) && !defined(__thumb2__)) && !defined(MA_64BIT) { unsigned int r; __asm__ __volatile__ ( @@ -84106,23 +86018,23 @@ static ma_bool32 ma_dr_flac__seek_to_byte(ma_dr_flac_bs* bs, ma_uint64 offsetFro MA_DR_FLAC_ASSERT(offsetFromStart > 0); if (offsetFromStart > 0x7FFFFFFF) { ma_uint64 bytesRemaining = offsetFromStart; - if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_start)) { + if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } bytesRemaining -= 0x7FFFFFFF; while (bytesRemaining > 0x7FFFFFFF) { - if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_current)) { + if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } bytesRemaining -= 0x7FFFFFFF; } if (bytesRemaining > 0) { - if (!bs->onSeek(bs->pUserData, (int)bytesRemaining, ma_dr_flac_seek_origin_current)) { + if (!bs->onSeek(bs->pUserData, (int)bytesRemaining, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } } else { - if (!bs->onSeek(bs->pUserData, (int)offsetFromStart, ma_dr_flac_seek_origin_start)) { + if (!bs->onSeek(bs->pUserData, (int)offsetFromStart, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } } @@ -86600,6 +88512,7 @@ typedef struct { ma_dr_flac_read_proc onRead; ma_dr_flac_seek_proc onSeek; + ma_dr_flac_tell_proc onTell; ma_dr_flac_meta_proc onMeta; ma_dr_flac_container container; void* pUserData; @@ -86728,11 +88641,12 @@ static void ma_dr_flac__free_from_callbacks(void* p, const ma_allocation_callbac pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData); } } -static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, void* pUserData, void* pUserDataMD, ma_uint64* pFirstFramePos, ma_uint64* pSeektablePos, ma_uint32* pSeekpointCount, ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, void* pUserData, void* pUserDataMD, ma_uint64* pFirstFramePos, ma_uint64* pSeektablePos, ma_uint32* pSeekpointCount, ma_allocation_callbacks* pAllocationCallbacks) { ma_uint64 runningFilePos = 42; ma_uint64 seektablePos = 0; ma_uint32 seektableSize = 0; + (void)onTell; for (;;) { ma_dr_flac_metadata metadata; ma_uint8 isLastBlock = 0; @@ -86743,8 +88657,9 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea } runningFilePos += 4; metadata.type = blockType; - metadata.pRawData = NULL; metadata.rawDataSize = 0; + metadata.rawDataOffset = runningFilePos; + metadata.pRawData = NULL; switch (blockType) { case MA_DR_FLAC_METADATA_BLOCK_TYPE_APPLICATION: @@ -86944,53 +88859,124 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea return MA_FALSE; } if (onMeta) { - void* pRawData; - const char* pRunningData; - const char* pRunningDataEnd; - pRawData = ma_dr_flac__malloc_from_callbacks(blockSize, pAllocationCallbacks); - if (pRawData == NULL) { - return MA_FALSE; + ma_bool32 result = MA_TRUE; + ma_uint32 blockSizeRemaining = blockSize; + char* pMime = NULL; + char* pDescription = NULL; + void* pPictureData = NULL; + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.type, 4) != 4) { + result = MA_FALSE; + goto done_flac; } - if (onRead(pUserData, pRawData, blockSize) != blockSize) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; + blockSizeRemaining -= 4; + metadata.data.picture.type = ma_dr_flac__be2host_32(metadata.data.picture.type); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.mimeLength, 4) != 4) { + result = MA_FALSE; + goto done_flac; } - metadata.pRawData = pRawData; - metadata.rawDataSize = blockSize; - pRunningData = (const char*)pRawData; - pRunningDataEnd = (const char*)pRawData + blockSize; - metadata.data.picture.type = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.mimeLength = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - if ((pRunningDataEnd - pRunningData) - 24 < (ma_int64)metadata.data.picture.mimeLength) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; + blockSizeRemaining -= 4; + metadata.data.picture.mimeLength = ma_dr_flac__be2host_32(metadata.data.picture.mimeLength); + pMime = (char*)ma_dr_flac__malloc_from_callbacks(metadata.data.picture.mimeLength + 1, pAllocationCallbacks); + if (pMime == NULL) { + result = MA_FALSE; + goto done_flac; } - metadata.data.picture.mime = pRunningData; pRunningData += metadata.data.picture.mimeLength; - metadata.data.picture.descriptionLength = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - if ((pRunningDataEnd - pRunningData) - 20 < (ma_int64)metadata.data.picture.descriptionLength) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; + if (blockSizeRemaining < metadata.data.picture.mimeLength || onRead(pUserData, pMime, metadata.data.picture.mimeLength) != metadata.data.picture.mimeLength) { + result = MA_FALSE; + goto done_flac; } - metadata.data.picture.description = pRunningData; pRunningData += metadata.data.picture.descriptionLength; - metadata.data.picture.width = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.height = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.colorDepth = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.indexColorCount = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.pictureDataSize = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.pPictureData = (const ma_uint8*)pRunningData; - if (pRunningDataEnd - pRunningData < (ma_int64)metadata.data.picture.pictureDataSize) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); + blockSizeRemaining -= metadata.data.picture.mimeLength; + pMime[metadata.data.picture.mimeLength] = '\0'; + metadata.data.picture.mime = (const char*)pMime; + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.descriptionLength, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.descriptionLength = ma_dr_flac__be2host_32(metadata.data.picture.descriptionLength); + pDescription = (char*)ma_dr_flac__malloc_from_callbacks(metadata.data.picture.descriptionLength + 1, pAllocationCallbacks); + if (pDescription == NULL) { + result = MA_FALSE; + goto done_flac; + } + if (blockSizeRemaining < metadata.data.picture.descriptionLength || onRead(pUserData, pDescription, metadata.data.picture.descriptionLength) != metadata.data.picture.descriptionLength) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= metadata.data.picture.descriptionLength; + pDescription[metadata.data.picture.descriptionLength] = '\0'; + metadata.data.picture.description = (const char*)pDescription; + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.width, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.width = ma_dr_flac__be2host_32(metadata.data.picture.width); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.height, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.height = ma_dr_flac__be2host_32(metadata.data.picture.height); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.colorDepth, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.colorDepth = ma_dr_flac__be2host_32(metadata.data.picture.colorDepth); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.indexColorCount, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.indexColorCount = ma_dr_flac__be2host_32(metadata.data.picture.indexColorCount); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.pictureDataSize, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.pictureDataSize = ma_dr_flac__be2host_32(metadata.data.picture.pictureDataSize); + if (blockSizeRemaining < metadata.data.picture.pictureDataSize) { + result = MA_FALSE; + goto done_flac; + } + metadata.data.picture.pictureDataOffset = runningFilePos + (blockSize - blockSizeRemaining); + #ifndef MA_DR_FLAC_NO_PICTURE_METADATA_MALLOC + pPictureData = ma_dr_flac__malloc_from_callbacks(metadata.data.picture.pictureDataSize, pAllocationCallbacks); + if (pPictureData != NULL) { + if (onRead(pUserData, pPictureData, metadata.data.picture.pictureDataSize) != metadata.data.picture.pictureDataSize) { + result = MA_FALSE; + goto done_flac; + } + } else + #endif + { + if (!onSeek(pUserData, metadata.data.picture.pictureDataSize, MA_DR_FLAC_SEEK_CUR)) { + result = MA_FALSE; + goto done_flac; + } + } + blockSizeRemaining -= metadata.data.picture.pictureDataSize; + (void)blockSizeRemaining; + metadata.data.picture.pPictureData = (const ma_uint8*)pPictureData; + if (metadata.data.picture.pictureDataOffset != 0 || metadata.data.picture.pPictureData != NULL) { + onMeta(pUserDataMD, &metadata); + } else { + } + done_flac: + ma_dr_flac__free_from_callbacks(pMime, pAllocationCallbacks); + ma_dr_flac__free_from_callbacks(pDescription, pAllocationCallbacks); + ma_dr_flac__free_from_callbacks(pPictureData, pAllocationCallbacks); + if (result != MA_TRUE) { return MA_FALSE; } - onMeta(pUserDataMD, &metadata); - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); } } break; case MA_DR_FLAC_METADATA_BLOCK_TYPE_PADDING: { if (onMeta) { metadata.data.padding.unused = 0; - if (!onSeek(pUserData, blockSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { isLastBlock = MA_TRUE; } else { onMeta(pUserDataMD, &metadata); @@ -87000,7 +88986,7 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea case MA_DR_FLAC_METADATA_BLOCK_TYPE_INVALID: { if (onMeta) { - if (!onSeek(pUserData, blockSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { isLastBlock = MA_TRUE; } } @@ -87009,12 +88995,15 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea { if (onMeta) { void* pRawData = ma_dr_flac__malloc_from_callbacks(blockSize, pAllocationCallbacks); - if (pRawData == NULL) { - return MA_FALSE; - } - if (onRead(pUserData, pRawData, blockSize) != blockSize) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; + if (pRawData != NULL) { + if (onRead(pUserData, pRawData, blockSize) != blockSize) { + ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); + return MA_FALSE; + } + } else { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { + return MA_FALSE; + } } metadata.pRawData = pRawData; metadata.rawDataSize = blockSize; @@ -87024,7 +89013,7 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea } break; } if (onMeta == NULL && blockSize > 0) { - if (!onSeek(pUserData, blockSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { isLastBlock = MA_TRUE; } } @@ -87288,6 +89277,7 @@ typedef struct { ma_dr_flac_read_proc onRead; ma_dr_flac_seek_proc onSeek; + ma_dr_flac_tell_proc onTell; void* pUserData; ma_uint64 currentBytePos; ma_uint64 firstBytePos; @@ -87306,29 +89296,29 @@ static size_t ma_dr_flac_oggbs__read_physical(ma_dr_flac_oggbs* oggbs, void* buf } static ma_bool32 ma_dr_flac_oggbs__seek_physical(ma_dr_flac_oggbs* oggbs, ma_uint64 offset, ma_dr_flac_seek_origin origin) { - if (origin == ma_dr_flac_seek_origin_start) { + if (origin == MA_DR_FLAC_SEEK_SET) { if (offset <= 0x7FFFFFFF) { - if (!oggbs->onSeek(oggbs->pUserData, (int)offset, ma_dr_flac_seek_origin_start)) { + if (!oggbs->onSeek(oggbs->pUserData, (int)offset, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } oggbs->currentBytePos = offset; return MA_TRUE; } else { - if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_start)) { + if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } oggbs->currentBytePos = offset; - return ma_dr_flac_oggbs__seek_physical(oggbs, offset - 0x7FFFFFFF, ma_dr_flac_seek_origin_current); + return ma_dr_flac_oggbs__seek_physical(oggbs, offset - 0x7FFFFFFF, MA_DR_FLAC_SEEK_CUR); } } else { while (offset > 0x7FFFFFFF) { - if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_current)) { + if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } oggbs->currentBytePos += 0x7FFFFFFF; offset -= 0x7FFFFFFF; } - if (!oggbs->onSeek(oggbs->pUserData, (int)offset, ma_dr_flac_seek_origin_current)) { + if (!oggbs->onSeek(oggbs->pUserData, (int)offset, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } oggbs->currentBytePos += offset; @@ -87354,7 +89344,7 @@ static ma_bool32 ma_dr_flac_oggbs__goto_next_page(ma_dr_flac_oggbs* oggbs, ma_dr continue; } if (header.serialNumber != oggbs->serialNumber) { - if (pageBodySize > 0 && !ma_dr_flac_oggbs__seek_physical(oggbs, pageBodySize, ma_dr_flac_seek_origin_current)) { + if (pageBodySize > 0 && !ma_dr_flac_oggbs__seek_physical(oggbs, pageBodySize, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } continue; @@ -87416,7 +89406,7 @@ static ma_bool32 ma_dr_flac_oggbs__seek_to_next_packet(ma_dr_flac_oggbs* oggbs) } bytesToEndOfPacketOrPage += segmentSize; } - ma_dr_flac_oggbs__seek_physical(oggbs, bytesToEndOfPacketOrPage, ma_dr_flac_seek_origin_current); + ma_dr_flac_oggbs__seek_physical(oggbs, bytesToEndOfPacketOrPage, MA_DR_FLAC_SEEK_CUR); oggbs->bytesRemainingInPage -= bytesToEndOfPacketOrPage; if (atEndOfPage) { if (!ma_dr_flac_oggbs__goto_next_page(oggbs)) { @@ -87469,36 +89459,44 @@ static ma_bool32 ma_dr_flac__on_seek_ogg(void* pUserData, int offset, ma_dr_flac int bytesSeeked = 0; MA_DR_FLAC_ASSERT(oggbs != NULL); MA_DR_FLAC_ASSERT(offset >= 0); - if (origin == ma_dr_flac_seek_origin_start) { - if (!ma_dr_flac_oggbs__seek_physical(oggbs, (int)oggbs->firstBytePos, ma_dr_flac_seek_origin_start)) { + if (origin == MA_DR_FLAC_SEEK_SET) { + if (!ma_dr_flac_oggbs__seek_physical(oggbs, (int)oggbs->firstBytePos, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_fail_on_crc_mismatch)) { return MA_FALSE; } - return ma_dr_flac__on_seek_ogg(pUserData, offset, ma_dr_flac_seek_origin_current); - } - MA_DR_FLAC_ASSERT(origin == ma_dr_flac_seek_origin_current); - while (bytesSeeked < offset) { - int bytesRemainingToSeek = offset - bytesSeeked; - MA_DR_FLAC_ASSERT(bytesRemainingToSeek >= 0); - if (oggbs->bytesRemainingInPage >= (size_t)bytesRemainingToSeek) { - bytesSeeked += bytesRemainingToSeek; - (void)bytesSeeked; - oggbs->bytesRemainingInPage -= bytesRemainingToSeek; - break; - } - if (oggbs->bytesRemainingInPage > 0) { - bytesSeeked += (int)oggbs->bytesRemainingInPage; - oggbs->bytesRemainingInPage = 0; - } - MA_DR_FLAC_ASSERT(bytesRemainingToSeek > 0); - if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_fail_on_crc_mismatch)) { - return MA_FALSE; + return ma_dr_flac__on_seek_ogg(pUserData, offset, MA_DR_FLAC_SEEK_CUR); + } else if (origin == MA_DR_FLAC_SEEK_CUR) { + while (bytesSeeked < offset) { + int bytesRemainingToSeek = offset - bytesSeeked; + MA_DR_FLAC_ASSERT(bytesRemainingToSeek >= 0); + if (oggbs->bytesRemainingInPage >= (size_t)bytesRemainingToSeek) { + bytesSeeked += bytesRemainingToSeek; + (void)bytesSeeked; + oggbs->bytesRemainingInPage -= bytesRemainingToSeek; + break; + } + if (oggbs->bytesRemainingInPage > 0) { + bytesSeeked += (int)oggbs->bytesRemainingInPage; + oggbs->bytesRemainingInPage = 0; + } + MA_DR_FLAC_ASSERT(bytesRemainingToSeek > 0); + if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_fail_on_crc_mismatch)) { + return MA_FALSE; + } } + } else if (origin == MA_DR_FLAC_SEEK_END) { + return MA_FALSE; } return MA_TRUE; } +static ma_bool32 ma_dr_flac__on_tell_ogg(void* pUserData, ma_int64* pCursor) +{ + (void)pUserData; + (void)pCursor; + return MA_FALSE; +} static ma_bool32 ma_dr_flac_ogg__seek_to_pcm_frame(ma_dr_flac* pFlac, ma_uint64 pcmFrameIndex) { ma_dr_flac_oggbs* oggbs = (ma_dr_flac_oggbs*)pFlac->_oggbs; @@ -87515,7 +89513,7 @@ static ma_bool32 ma_dr_flac_ogg__seek_to_pcm_frame(ma_dr_flac* pFlac, ma_uint64 runningGranulePosition = 0; for (;;) { if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_recover_on_crc_mismatch)) { - ma_dr_flac_oggbs__seek_physical(oggbs, originalBytePos, ma_dr_flac_seek_origin_start); + ma_dr_flac_oggbs__seek_physical(oggbs, originalBytePos, MA_DR_FLAC_SEEK_SET); return MA_FALSE; } runningFrameBytePos = oggbs->currentBytePos - ma_dr_flac_ogg__get_page_header_size(&oggbs->currentPageHeader) - oggbs->pageDataSize; @@ -87534,7 +89532,7 @@ static ma_bool32 ma_dr_flac_ogg__seek_to_pcm_frame(ma_dr_flac* pFlac, ma_uint64 } } } - if (!ma_dr_flac_oggbs__seek_physical(oggbs, runningFrameBytePos, ma_dr_flac_seek_origin_start)) { + if (!ma_dr_flac_oggbs__seek_physical(oggbs, runningFrameBytePos, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_recover_on_crc_mismatch)) { @@ -87629,7 +89627,7 @@ static ma_bool32 ma_dr_flac__init_private__ogg(ma_dr_flac_init_info* pInit, ma_d if (mappingVersion[0] != 1) { return MA_FALSE; } - if (!onSeek(pUserData, 2, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, 2, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } if (onRead(pUserData, sig, 4) != 4) { @@ -87674,17 +89672,17 @@ static ma_bool32 ma_dr_flac__init_private__ogg(ma_dr_flac_init_info* pInit, ma_d return MA_FALSE; } } else { - if (!onSeek(pUserData, bytesRemainingInPage, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, bytesRemainingInPage, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } } else { - if (!onSeek(pUserData, bytesRemainingInPage, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, bytesRemainingInPage, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } } else { - if (!onSeek(pUserData, pageBodySize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, pageBodySize, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } @@ -87698,7 +89696,7 @@ static ma_bool32 ma_dr_flac__init_private__ogg(ma_dr_flac_init_info* pInit, ma_d return MA_TRUE; } #endif -static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD) +static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD) { ma_bool32 relaxed; ma_uint8 id[4]; @@ -87708,12 +89706,14 @@ static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_fla MA_DR_FLAC_ZERO_MEMORY(pInit, sizeof(*pInit)); pInit->onRead = onRead; pInit->onSeek = onSeek; + pInit->onTell = onTell; pInit->onMeta = onMeta; pInit->container = container; pInit->pUserData = pUserData; pInit->pUserDataMD = pUserDataMD; pInit->bs.onRead = onRead; pInit->bs.onSeek = onSeek; + pInit->bs.onTell = onTell; pInit->bs.pUserData = pUserData; ma_dr_flac__reset_cache(&pInit->bs); relaxed = container != ma_dr_flac_container_unknown; @@ -87736,7 +89736,7 @@ static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_fla if (flags & 0x10) { headerSize += 10; } - if (!onSeek(pUserData, headerSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, headerSize, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } pInit->runningFilePos += headerSize; @@ -87779,7 +89779,7 @@ static void ma_dr_flac__init_from_info(ma_dr_flac* pFlac, const ma_dr_flac_init_ pFlac->totalPCMFrameCount = pInit->totalPCMFrameCount; pFlac->container = pInit->container; } -static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac_init_info init; ma_uint32 allocationSize; @@ -87794,7 +89794,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on ma_allocation_callbacks allocationCallbacks; ma_dr_flac* pFlac; ma_dr_flac__init_cpu_caps(); - if (!ma_dr_flac__init_private(&init, onRead, onSeek, onMeta, container, pUserData, pUserDataMD)) { + if (!ma_dr_flac__init_private(&init, onRead, onSeek, onTell, onMeta, container, pUserData, pUserDataMD)) { return NULL; } if (pAllocationCallbacks != NULL) { @@ -87827,6 +89827,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on MA_DR_FLAC_ZERO_MEMORY(pOggbs, sizeof(*pOggbs)); pOggbs->onRead = onRead; pOggbs->onSeek = onSeek; + pOggbs->onTell = onTell; pOggbs->pUserData = pUserData; pOggbs->currentBytePos = init.oggFirstBytePos; pOggbs->firstBytePos = init.oggFirstBytePos; @@ -87841,15 +89842,17 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on if (init.hasMetadataBlocks) { ma_dr_flac_read_proc onReadOverride = onRead; ma_dr_flac_seek_proc onSeekOverride = onSeek; + ma_dr_flac_tell_proc onTellOverride = onTell; void* pUserDataOverride = pUserData; #ifndef MA_DR_FLAC_NO_OGG if (init.container == ma_dr_flac_container_ogg) { onReadOverride = ma_dr_flac__on_read_ogg; onSeekOverride = ma_dr_flac__on_seek_ogg; + onTellOverride = ma_dr_flac__on_tell_ogg; pUserDataOverride = (void*)pOggbs; } #endif - if (!ma_dr_flac__read_and_decode_metadata(onReadOverride, onSeekOverride, onMeta, pUserDataOverride, pUserDataMD, &firstFramePos, &seektablePos, &seekpointCount, &allocationCallbacks)) { + if (!ma_dr_flac__read_and_decode_metadata(onReadOverride, onSeekOverride, onTellOverride, onMeta, pUserDataOverride, pUserDataMD, &firstFramePos, &seektablePos, &seekpointCount, &allocationCallbacks)) { #ifndef MA_DR_FLAC_NO_OGG ma_dr_flac__free_from_callbacks(pOggbs, &allocationCallbacks); #endif @@ -87875,6 +89878,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on pOggbs = NULL; pFlac->bs.onRead = ma_dr_flac__on_read_ogg; pFlac->bs.onSeek = ma_dr_flac__on_seek_ogg; + pFlac->bs.onTell = ma_dr_flac__on_tell_ogg; pFlac->bs.pUserData = (void*)pInternalOggbs; pFlac->_oggbs = (void*)pInternalOggbs; } @@ -87894,7 +89898,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on pFlac->pSeekpoints = (ma_dr_flac_seekpoint*)((ma_uint8*)pFlac->pDecodedSamples + decodedSamplesAllocationSize); MA_DR_FLAC_ASSERT(pFlac->bs.onSeek != NULL); MA_DR_FLAC_ASSERT(pFlac->bs.onRead != NULL); - if (pFlac->bs.onSeek(pFlac->bs.pUserData, (int)seektablePos, ma_dr_flac_seek_origin_start)) { + if (pFlac->bs.onSeek(pFlac->bs.pUserData, (int)seektablePos, MA_DR_FLAC_SEEK_SET)) { ma_uint32 iSeekpoint; for (iSeekpoint = 0; iSeekpoint < seekpointCount; iSeekpoint += 1) { if (pFlac->bs.onRead(pFlac->bs.pUserData, pFlac->pSeekpoints + iSeekpoint, MA_DR_FLAC_SEEKPOINT_SIZE_IN_BYTES) == MA_DR_FLAC_SEEKPOINT_SIZE_IN_BYTES) { @@ -87907,7 +89911,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on break; } } - if (!pFlac->bs.onSeek(pFlac->bs.pUserData, (int)pFlac->firstFLACFramePosInBytes, ma_dr_flac_seek_origin_start)) { + if (!pFlac->bs.onSeek(pFlac->bs.pUserData, (int)pFlac->firstFLACFramePosInBytes, MA_DR_FLAC_SEEK_SET)) { ma_dr_flac__free_from_callbacks(pFlac, &allocationCallbacks); return NULL; } @@ -87950,8 +89954,31 @@ static size_t ma_dr_flac__on_read_stdio(void* pUserData, void* bufferOut, size_t } static ma_bool32 ma_dr_flac__on_seek_stdio(void* pUserData, int offset, ma_dr_flac_seek_origin origin) { - MA_DR_FLAC_ASSERT(offset >= 0); - return fseek((FILE*)pUserData, offset, (origin == ma_dr_flac_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0; + int whence = SEEK_SET; + if (origin == MA_DR_FLAC_SEEK_CUR) { + whence = SEEK_CUR; + } else if (origin == MA_DR_FLAC_SEEK_END) { + whence = SEEK_END; + } + return fseek((FILE*)pUserData, offset, whence) == 0; +} +static ma_bool32 ma_dr_flac__on_tell_stdio(void* pUserData, ma_int64* pCursor) +{ + FILE* pFileStdio = (FILE*)pUserData; + ma_int64 result; + MA_DR_FLAC_ASSERT(pFileStdio != NULL); + MA_DR_FLAC_ASSERT(pCursor != NULL); +#if defined(_WIN32) && !defined(NXDK) + #if defined(_MSC_VER) && _MSC_VER > 1200 + result = _ftelli64(pFileStdio); + #else + result = ftell(pFileStdio); + #endif +#else + result = ftell(pFileStdio); +#endif + *pCursor = result; + return MA_TRUE; } MA_API ma_dr_flac* ma_dr_flac_open_file(const char* pFileName, const ma_allocation_callbacks* pAllocationCallbacks) { @@ -87960,7 +89987,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file(const char* pFileName, const ma_allocati if (ma_fopen(&pFile, pFileName, "rb") != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, (void*)pFile, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return NULL; @@ -87975,7 +90002,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_w(const wchar_t* pFileName, const ma_all if (ma_wfopen(&pFile, pFileName, L"rb", pAllocationCallbacks) != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, (void*)pFile, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return NULL; @@ -87990,7 +90017,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_with_metadata(const char* pFileName, ma_ if (ma_fopen(&pFile, pFileName, "rb") != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return pFlac; @@ -88005,7 +90032,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_with_metadata_w(const wchar_t* pFileName if (ma_wfopen(&pFile, pFileName, L"rb", pAllocationCallbacks) != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return pFlac; @@ -88033,24 +90060,34 @@ static size_t ma_dr_flac__on_read_memory(void* pUserData, void* bufferOut, size_ static ma_bool32 ma_dr_flac__on_seek_memory(void* pUserData, int offset, ma_dr_flac_seek_origin origin) { ma_dr_flac__memory_stream* memoryStream = (ma_dr_flac__memory_stream*)pUserData; + ma_int64 newCursor; MA_DR_FLAC_ASSERT(memoryStream != NULL); - MA_DR_FLAC_ASSERT(offset >= 0); - if (offset > (ma_int64)memoryStream->dataSize) { + if (origin == MA_DR_FLAC_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_FLAC_SEEK_CUR) { + newCursor = (ma_int64)memoryStream->currentReadPos; + } else if (origin == MA_DR_FLAC_SEEK_END) { + newCursor = (ma_int64)memoryStream->dataSize; + } else { + MA_DR_FLAC_ASSERT(!"Invalid seek origin"); return MA_FALSE; } - if (origin == ma_dr_flac_seek_origin_current) { - if (memoryStream->currentReadPos + offset <= memoryStream->dataSize) { - memoryStream->currentReadPos += offset; - } else { - return MA_FALSE; - } - } else { - if ((ma_uint32)offset <= memoryStream->dataSize) { - memoryStream->currentReadPos = offset; - } else { - return MA_FALSE; - } + newCursor += offset; + if (newCursor < 0) { + return MA_FALSE; } + if ((size_t)newCursor > memoryStream->dataSize) { + return MA_FALSE; + } + memoryStream->currentReadPos = (size_t)newCursor; + return MA_TRUE; +} +static ma_bool32 ma_dr_flac__on_tell_memory(void* pUserData, ma_int64* pCursor) +{ + ma_dr_flac__memory_stream* memoryStream = (ma_dr_flac__memory_stream*)pUserData; + MA_DR_FLAC_ASSERT(memoryStream != NULL); + MA_DR_FLAC_ASSERT(pCursor != NULL); + *pCursor = (ma_int64)memoryStream->currentReadPos; return MA_TRUE; } MA_API ma_dr_flac* ma_dr_flac_open_memory(const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) @@ -88060,7 +90097,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_memory(const void* pData, size_t dataSize, co memoryStream.data = (const ma_uint8*)pData; memoryStream.dataSize = dataSize; memoryStream.currentReadPos = 0; - pFlac = ma_dr_flac_open(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, &memoryStream, pAllocationCallbacks); + pFlac = ma_dr_flac_open(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, ma_dr_flac__on_tell_memory, &memoryStream, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } @@ -88085,7 +90122,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_memory_with_metadata(const void* pData, size_ memoryStream.data = (const ma_uint8*)pData; memoryStream.dataSize = dataSize; memoryStream.currentReadPos = 0; - pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, onMeta, ma_dr_flac_container_unknown, &memoryStream, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, ma_dr_flac__on_tell_memory, onMeta, ma_dr_flac_container_unknown, &memoryStream, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } @@ -88103,21 +90140,21 @@ MA_API ma_dr_flac* ma_dr_flac_open_memory_with_metadata(const void* pData, size_ } return pFlac; } -MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, NULL, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, NULL, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); } -MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, NULL, container, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, NULL, container, pUserData, pUserData, pAllocationCallbacks); } -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onMeta, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, onMeta, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); } -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onMeta, container, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, onMeta, container, pUserData, pUserData, pAllocationCallbacks); } MA_API void ma_dr_flac_close(ma_dr_flac* pFlac) { @@ -90345,57 +92382,42 @@ static type* ma_dr_flac__full_read_and_close_ ## extension (ma_dr_flac* pFlac, u { \ type* pSampleData = NULL; \ ma_uint64 totalPCMFrameCount; \ + type buffer[4096]; \ + ma_uint64 pcmFramesRead; \ + size_t sampleDataBufferSize = sizeof(buffer); \ \ MA_DR_FLAC_ASSERT(pFlac != NULL); \ \ - totalPCMFrameCount = pFlac->totalPCMFrameCount; \ - \ - if (totalPCMFrameCount == 0) { \ - type buffer[4096]; \ - ma_uint64 pcmFramesRead; \ - size_t sampleDataBufferSize = sizeof(buffer); \ + totalPCMFrameCount = 0; \ \ - pSampleData = (type*)ma_dr_flac__malloc_from_callbacks(sampleDataBufferSize, &pFlac->allocationCallbacks); \ - if (pSampleData == NULL) { \ - goto on_error; \ - } \ - \ - while ((pcmFramesRead = (ma_uint64)ma_dr_flac_read_pcm_frames_##extension(pFlac, sizeof(buffer)/sizeof(buffer[0])/pFlac->channels, buffer)) > 0) { \ - if (((totalPCMFrameCount + pcmFramesRead) * pFlac->channels * sizeof(type)) > sampleDataBufferSize) { \ - type* pNewSampleData; \ - size_t newSampleDataBufferSize; \ + pSampleData = (type*)ma_dr_flac__malloc_from_callbacks(sampleDataBufferSize, &pFlac->allocationCallbacks); \ + if (pSampleData == NULL) { \ + goto on_error; \ + } \ \ - newSampleDataBufferSize = sampleDataBufferSize * 2; \ - pNewSampleData = (type*)ma_dr_flac__realloc_from_callbacks(pSampleData, newSampleDataBufferSize, sampleDataBufferSize, &pFlac->allocationCallbacks); \ - if (pNewSampleData == NULL) { \ - ma_dr_flac__free_from_callbacks(pSampleData, &pFlac->allocationCallbacks); \ - goto on_error; \ - } \ + while ((pcmFramesRead = (ma_uint64)ma_dr_flac_read_pcm_frames_##extension(pFlac, sizeof(buffer)/sizeof(buffer[0])/pFlac->channels, buffer)) > 0) { \ + if (((totalPCMFrameCount + pcmFramesRead) * pFlac->channels * sizeof(type)) > sampleDataBufferSize) { \ + type* pNewSampleData; \ + size_t newSampleDataBufferSize; \ \ - sampleDataBufferSize = newSampleDataBufferSize; \ - pSampleData = pNewSampleData; \ + newSampleDataBufferSize = sampleDataBufferSize * 2; \ + pNewSampleData = (type*)ma_dr_flac__realloc_from_callbacks(pSampleData, newSampleDataBufferSize, sampleDataBufferSize, &pFlac->allocationCallbacks); \ + if (pNewSampleData == NULL) { \ + ma_dr_flac__free_from_callbacks(pSampleData, &pFlac->allocationCallbacks); \ + goto on_error; \ } \ \ - MA_DR_FLAC_COPY_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), buffer, (size_t)(pcmFramesRead*pFlac->channels*sizeof(type))); \ - totalPCMFrameCount += pcmFramesRead; \ - } \ - \ - \ - MA_DR_FLAC_ZERO_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), (size_t)(sampleDataBufferSize - totalPCMFrameCount*pFlac->channels*sizeof(type))); \ - } else { \ - ma_uint64 dataSize = totalPCMFrameCount*pFlac->channels*sizeof(type); \ - if (dataSize > (ma_uint64)MA_SIZE_MAX) { \ - goto on_error; \ - } \ - \ - pSampleData = (type*)ma_dr_flac__malloc_from_callbacks((size_t)dataSize, &pFlac->allocationCallbacks); \ - if (pSampleData == NULL) { \ - goto on_error; \ + sampleDataBufferSize = newSampleDataBufferSize; \ + pSampleData = pNewSampleData; \ } \ \ - totalPCMFrameCount = ma_dr_flac_read_pcm_frames_##extension(pFlac, pFlac->totalPCMFrameCount, pSampleData); \ + MA_DR_FLAC_COPY_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), buffer, (size_t)(pcmFramesRead*pFlac->channels*sizeof(type))); \ + totalPCMFrameCount += pcmFramesRead; \ } \ \ + \ + MA_DR_FLAC_ZERO_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), (size_t)(sampleDataBufferSize - totalPCMFrameCount*pFlac->channels*sizeof(type))); \ + \ if (sampleRateOut) *sampleRateOut = pFlac->sampleRate; \ if (channelsOut) *channelsOut = pFlac->channels; \ if (totalPCMFrameCountOut) *totalPCMFrameCountOut = totalPCMFrameCount; \ @@ -90410,7 +92432,7 @@ on_error: MA_DR_FLAC_DEFINE_FULL_READ_AND_CLOSE(s32, ma_int32) MA_DR_FLAC_DEFINE_FULL_READ_AND_CLOSE(s16, ma_int16) MA_DR_FLAC_DEFINE_FULL_READ_AND_CLOSE(f32, float) -MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac* pFlac; if (channelsOut) { @@ -90422,13 +92444,13 @@ MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc on if (totalPCMFrameCountOut) { *totalPCMFrameCountOut = 0; } - pFlac = ma_dr_flac_open(onRead, onSeek, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open(onRead, onSeek, onTell, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } return ma_dr_flac__full_read_and_close_s32(pFlac, channelsOut, sampleRateOut, totalPCMFrameCountOut); } -MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac* pFlac; if (channelsOut) { @@ -90440,13 +92462,13 @@ MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc on if (totalPCMFrameCountOut) { *totalPCMFrameCountOut = 0; } - pFlac = ma_dr_flac_open(onRead, onSeek, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open(onRead, onSeek, onTell, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } return ma_dr_flac__full_read_and_close_s16(pFlac, channelsOut, sampleRateOut, totalPCMFrameCountOut); } -MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac* pFlac; if (channelsOut) { @@ -90458,7 +92480,7 @@ MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRea if (totalPCMFrameCountOut) { *totalPCMFrameCountOut = 0; } - pFlac = ma_dr_flac_open(onRead, onSeek, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open(onRead, onSeek, onTell, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } @@ -90680,12 +92702,9 @@ MA_API const char* ma_dr_mp3_version_string(void) #define MA_DR_MP3_NO_SIMD #endif #define MA_DR_MP3_OFFSET_PTR(p, offset) ((void*)((ma_uint8*)(p) + (offset))) -#define MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE 2304 #ifndef MA_DR_MP3_MAX_FRAME_SYNC_MATCHES #define MA_DR_MP3_MAX_FRAME_SYNC_MATCHES 10 #endif -#define MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE -#define MA_DR_MP3_MAX_BITRESERVOIR_BYTES 511 #define MA_DR_MP3_SHORT_BLOCK_TYPE 2 #define MA_DR_MP3_STOP_BLOCK_TYPE 3 #define MA_DR_MP3_MODE_MONO 3 @@ -90735,7 +92754,7 @@ MA_API const char* ma_dr_mp3_version_string(void) #define MA_DR_MP3_VMUL_S(x, s) _mm_mul_ps(x, _mm_set1_ps(s)) #define MA_DR_MP3_VREV(x) _mm_shuffle_ps(x, x, _MM_SHUFFLE(0, 1, 2, 3)) typedef __m128 ma_dr_mp3_f4; -#if defined(_MSC_VER) || defined(MA_DR_MP3_ONLY_SIMD) +#if (defined(_MSC_VER) || defined(MA_DR_MP3_ONLY_SIMD)) && !defined(__clang__) #define ma_dr_mp3_cpuid __cpuid #else static __inline__ __attribute__((always_inline)) void ma_dr_mp3_cpuid(int CPUInfo[], const int InfoType) @@ -90851,11 +92870,6 @@ static __inline__ __attribute__((always_inline)) ma_int32 ma_dr_mp3_clip_int16_a #define MA_DR_MP3_FREE(p) free((p)) #endif typedef struct -{ - const ma_uint8 *buf; - int pos, limit; -} ma_dr_mp3_bs; -typedef struct { float scf[3*64]; ma_uint8 total_bands, stereo_bands, bitalloc[64], scfcod[64]; @@ -90864,22 +92878,6 @@ typedef struct { ma_uint8 tab_offset, code_tab_width, band_count; } ma_dr_mp3_L12_subband_alloc; -typedef struct -{ - const ma_uint8 *sfbtab; - ma_uint16 part_23_length, big_values, scalefac_compress; - ma_uint8 global_gain, block_type, mixed_block_flag, n_long_sfb, n_short_sfb; - ma_uint8 table_select[3], region_count[3], subblock_gain[3]; - ma_uint8 preflag, scalefac_scale, count1_table, scfsi; -} ma_dr_mp3_L3_gr_info; -typedef struct -{ - ma_dr_mp3_bs bs; - ma_uint8 maindata[MA_DR_MP3_MAX_BITRESERVOIR_BYTES + MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES]; - ma_dr_mp3_L3_gr_info gr_info[4]; - float grbuf[2][576], scf[40], syn[18 + 15][2*32]; - ma_uint8 ist_pos[2][39]; -} ma_dr_mp3dec_scratch; static void ma_dr_mp3_bs_init(ma_dr_mp3_bs *bs, const ma_uint8 *data, int bytes) { bs->buf = data; @@ -91262,6 +93260,10 @@ static float ma_dr_mp3_L3_ldexp_q2(float y, int exp_q2) } while ((exp_q2 -= e) > 0); return y; } +#if (defined(__GNUC__) && (__GNUC__ >= 13)) && !defined(__clang__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstringop-overflow" +#endif static void ma_dr_mp3_L3_decode_scalefactors(const ma_uint8 *hdr, ma_uint8 *ist_pos, ma_dr_mp3_bs *bs, const ma_dr_mp3_L3_gr_info *gr, float *scf, int ch) { static const ma_uint8 g_scf_partitions[3][28] = { @@ -91320,7 +93322,10 @@ static void ma_dr_mp3_L3_decode_scalefactors(const ma_uint8 *hdr, ma_uint8 *ist_ scf[i] = ma_dr_mp3_L3_ldexp_q2(gain, iscf[i] << scf_shift); } } -static const float g_ma_dr_mp3_pow43[129 + 16] = { +#if (defined(__GNUC__) && (__GNUC__ >= 13)) && !defined(__clang__) + #pragma GCC diagnostic pop +#endif +static const float ma_dr_mp3_g_pow43[129 + 16] = { 0,-1,-2.519842f,-4.326749f,-6.349604f,-8.549880f,-10.902724f,-13.390518f,-16.000000f,-18.720754f,-21.544347f,-24.463781f,-27.473142f,-30.567351f,-33.741992f,-36.993181f, 0,1,2.519842f,4.326749f,6.349604f,8.549880f,10.902724f,13.390518f,16.000000f,18.720754f,21.544347f,24.463781f,27.473142f,30.567351f,33.741992f,36.993181f,40.317474f,43.711787f,47.173345f,50.699631f,54.288352f,57.937408f,61.644865f,65.408941f,69.227979f,73.100443f,77.024898f,81.000000f,85.024491f,89.097188f,93.216975f,97.382800f,101.593667f,105.848633f,110.146801f,114.487321f,118.869381f,123.292209f,127.755065f,132.257246f,136.798076f,141.376907f,145.993119f,150.646117f,155.335327f,160.060199f,164.820202f,169.614826f,174.443577f,179.305980f,184.201575f,189.129918f,194.090580f,199.083145f,204.107210f,209.162385f,214.248292f,219.364564f,224.510845f,229.686789f,234.892058f,240.126328f,245.389280f,250.680604f,256.000000f,261.347174f,266.721841f,272.123723f,277.552547f,283.008049f,288.489971f,293.998060f,299.532071f,305.091761f,310.676898f,316.287249f,321.922592f,327.582707f,333.267377f,338.976394f,344.709550f,350.466646f,356.247482f,362.051866f,367.879608f,373.730522f,379.604427f,385.501143f,391.420496f,397.362314f,403.326427f,409.312672f,415.320884f,421.350905f,427.402579f,433.475750f,439.570269f,445.685987f,451.822757f,457.980436f,464.158883f,470.357960f,476.577530f,482.817459f,489.077615f,495.357868f,501.658090f,507.978156f,514.317941f,520.677324f,527.056184f,533.454404f,539.871867f,546.308458f,552.764065f,559.238575f,565.731879f,572.243870f,578.774440f,585.323483f,591.890898f,598.476581f,605.080431f,611.702349f,618.342238f,625.000000f,631.675540f,638.368763f,645.079578f }; @@ -91330,7 +93335,7 @@ static float ma_dr_mp3_L3_pow_43(int x) int sign, mult = 256; if (x < 129) { - return g_ma_dr_mp3_pow43[16 + x]; + return ma_dr_mp3_g_pow43[16 + x]; } if (x < 1024) { @@ -91339,7 +93344,7 @@ static float ma_dr_mp3_L3_pow_43(int x) } sign = 2*x & 64; frac = (float)((x & 63) - sign) / ((x & ~63) + sign); - return g_ma_dr_mp3_pow43[16 + ((x + sign) >> 6)]*(1.f + frac*((4.f/3) + frac*(2.f/9)))*mult; + return ma_dr_mp3_g_pow43[16 + ((x + sign) >> 6)]*(1.f + frac*((4.f/3) + frac*(2.f/9)))*mult; } static void ma_dr_mp3_L3_huffman(float *dst, ma_dr_mp3_bs *bs, const ma_dr_mp3_L3_gr_info *gr_info, const float *scf, int layer3gr_limit) { @@ -91409,7 +93414,7 @@ static void ma_dr_mp3_L3_huffman(float *dst, ma_dr_mp3_bs *bs, const ma_dr_mp3_L *dst = one*ma_dr_mp3_L3_pow_43(lsb)*((ma_int32)bs_cache < 0 ? -1: 1); } else { - *dst = g_ma_dr_mp3_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; + *dst = ma_dr_mp3_g_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; } MA_DR_MP3_FLUSH_BITS(lsb ? 1 : 0); } @@ -91437,7 +93442,7 @@ static void ma_dr_mp3_L3_huffman(float *dst, ma_dr_mp3_bs *bs, const ma_dr_mp3_L for (j = 0; j < 2; j++, dst++, leaf >>= 4) { int lsb = leaf & 0x0F; - *dst = g_ma_dr_mp3_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; + *dst = ma_dr_mp3_g_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; MA_DR_MP3_FLUSH_BITS(lsb ? 1 : 0); } MA_DR_MP3_CHECK_BITS; @@ -92245,7 +94250,6 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int int i = 0, igr, frame_size = 0, success = 1; const ma_uint8 *hdr; ma_dr_mp3_bs bs_frame[1]; - ma_dr_mp3dec_scratch scratch; if (mp3_bytes > 4 && dec->header[0] == 0xff && ma_dr_mp3_hdr_compare(dec->header, mp3)) { frame_size = ma_dr_mp3_hdr_frame_bytes(mp3, dec->free_format_bytes) + ma_dr_mp3_hdr_padding(mp3); @@ -92268,7 +94272,7 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int MA_DR_MP3_COPY_MEMORY(dec->header, hdr, MA_DR_MP3_HDR_SIZE); info->frame_bytes = i + frame_size; info->channels = MA_DR_MP3_HDR_IS_MONO(hdr) ? 1 : 2; - info->hz = ma_dr_mp3_hdr_sample_rate_hz(hdr); + info->sample_rate = ma_dr_mp3_hdr_sample_rate_hz(hdr); info->layer = 4 - MA_DR_MP3_HDR_GET_LAYER(hdr); info->bitrate_kbps = ma_dr_mp3_hdr_bitrate_kbps(hdr); ma_dr_mp3_bs_init(bs_frame, hdr + MA_DR_MP3_HDR_SIZE, frame_size - MA_DR_MP3_HDR_SIZE); @@ -92278,23 +94282,23 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int } if (info->layer == 3) { - int main_data_begin = ma_dr_mp3_L3_read_side_info(bs_frame, scratch.gr_info, hdr); + int main_data_begin = ma_dr_mp3_L3_read_side_info(bs_frame, dec->scratch.gr_info, hdr); if (main_data_begin < 0 || bs_frame->pos > bs_frame->limit) { ma_dr_mp3dec_init(dec); return 0; } - success = ma_dr_mp3_L3_restore_reservoir(dec, bs_frame, &scratch, main_data_begin); + success = ma_dr_mp3_L3_restore_reservoir(dec, bs_frame, &dec->scratch, main_data_begin); if (success && pcm != NULL) { for (igr = 0; igr < (MA_DR_MP3_HDR_TEST_MPEG1(hdr) ? 2 : 1); igr++, pcm = MA_DR_MP3_OFFSET_PTR(pcm, sizeof(ma_dr_mp3d_sample_t)*576*info->channels)) { - MA_DR_MP3_ZERO_MEMORY(scratch.grbuf[0], 576*2*sizeof(float)); - ma_dr_mp3_L3_decode(dec, &scratch, scratch.gr_info + igr*info->channels, info->channels); - ma_dr_mp3d_synth_granule(dec->qmf_state, scratch.grbuf[0], 18, info->channels, (ma_dr_mp3d_sample_t*)pcm, scratch.syn[0]); + MA_DR_MP3_ZERO_MEMORY(dec->scratch.grbuf[0], 576*2*sizeof(float)); + ma_dr_mp3_L3_decode(dec, &dec->scratch, dec->scratch.gr_info + igr*info->channels, info->channels); + ma_dr_mp3d_synth_granule(dec->qmf_state, dec->scratch.grbuf[0], 18, info->channels, (ma_dr_mp3d_sample_t*)pcm, dec->scratch.syn[0]); } } - ma_dr_mp3_L3_save_reservoir(dec, &scratch); + ma_dr_mp3_L3_save_reservoir(dec, &dec->scratch); } else { #ifdef MA_DR_MP3_ONLY_MP3 @@ -92305,15 +94309,15 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int return ma_dr_mp3_hdr_frame_samples(hdr); } ma_dr_mp3_L12_read_scale_info(hdr, bs_frame, sci); - MA_DR_MP3_ZERO_MEMORY(scratch.grbuf[0], 576*2*sizeof(float)); + MA_DR_MP3_ZERO_MEMORY(dec->scratch.grbuf[0], 576*2*sizeof(float)); for (i = 0, igr = 0; igr < 3; igr++) { - if (12 == (i += ma_dr_mp3_L12_dequantize_granule(scratch.grbuf[0] + i, bs_frame, sci, info->layer | 1))) + if (12 == (i += ma_dr_mp3_L12_dequantize_granule(dec->scratch.grbuf[0] + i, bs_frame, sci, info->layer | 1))) { i = 0; - ma_dr_mp3_L12_apply_scf_384(sci, sci->scf + igr, scratch.grbuf[0]); - ma_dr_mp3d_synth_granule(dec->qmf_state, scratch.grbuf[0], 12, info->channels, (ma_dr_mp3d_sample_t*)pcm, scratch.syn[0]); - MA_DR_MP3_ZERO_MEMORY(scratch.grbuf[0], 576*2*sizeof(float)); + ma_dr_mp3_L12_apply_scf_384(sci, sci->scf + igr, dec->scratch.grbuf[0]); + ma_dr_mp3d_synth_granule(dec->qmf_state, dec->scratch.grbuf[0], 12, info->channels, (ma_dr_mp3d_sample_t*)pcm, dec->scratch.syn[0]); + MA_DR_MP3_ZERO_MEMORY(dec->scratch.grbuf[0], 576*2*sizeof(float)); pcm = MA_DR_MP3_OFFSET_PTR(pcm, sizeof(ma_dr_mp3d_sample_t)*384*info->channels); } if (bs_frame->pos > bs_frame->limit) @@ -92491,19 +94495,41 @@ static ma_allocation_callbacks ma_dr_mp3_copy_allocation_callbacks_or_defaults(c } static size_t ma_dr_mp3__on_read(ma_dr_mp3* pMP3, void* pBufferOut, size_t bytesToRead) { - size_t bytesRead = pMP3->onRead(pMP3->pUserData, pBufferOut, bytesToRead); + size_t bytesRead; + MA_DR_MP3_ASSERT(pMP3 != NULL); + MA_DR_MP3_ASSERT(pMP3->onRead != NULL); + if (bytesToRead == 0) { + return 0; + } + bytesRead = pMP3->onRead(pMP3->pUserData, pBufferOut, bytesToRead); pMP3->streamCursor += bytesRead; return bytesRead; } +static size_t ma_dr_mp3__on_read_clamped(ma_dr_mp3* pMP3, void* pBufferOut, size_t bytesToRead) +{ + MA_DR_MP3_ASSERT(pMP3 != NULL); + MA_DR_MP3_ASSERT(pMP3->onRead != NULL); + if (pMP3->streamLength == MA_UINT64_MAX) { + return ma_dr_mp3__on_read(pMP3, pBufferOut, bytesToRead); + } else { + ma_uint64 bytesRemaining; + bytesRemaining = (pMP3->streamLength - pMP3->streamCursor); + if (bytesToRead > bytesRemaining) { + bytesToRead = (size_t)bytesRemaining; + } + return ma_dr_mp3__on_read(pMP3, pBufferOut, bytesToRead); + } +} static ma_bool32 ma_dr_mp3__on_seek(ma_dr_mp3* pMP3, int offset, ma_dr_mp3_seek_origin origin) { MA_DR_MP3_ASSERT(offset >= 0); + MA_DR_MP3_ASSERT(origin == MA_DR_MP3_SEEK_SET || origin == MA_DR_MP3_SEEK_CUR); if (!pMP3->onSeek(pMP3->pUserData, offset, origin)) { return MA_FALSE; } - if (origin == ma_dr_mp3_seek_origin_start) { + if (origin == MA_DR_MP3_SEEK_SET) { pMP3->streamCursor = (ma_uint64)offset; - } else { + } else{ pMP3->streamCursor += offset; } return MA_TRUE; @@ -92513,18 +94539,18 @@ static ma_bool32 ma_dr_mp3__on_seek_64(ma_dr_mp3* pMP3, ma_uint64 offset, ma_dr_ if (offset <= 0x7FFFFFFF) { return ma_dr_mp3__on_seek(pMP3, (int)offset, origin); } - if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, ma_dr_mp3_seek_origin_start)) { + if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, MA_DR_MP3_SEEK_SET)) { return MA_FALSE; } offset -= 0x7FFFFFFF; while (offset > 0) { if (offset <= 0x7FFFFFFF) { - if (!ma_dr_mp3__on_seek(pMP3, (int)offset, ma_dr_mp3_seek_origin_current)) { + if (!ma_dr_mp3__on_seek(pMP3, (int)offset, MA_DR_MP3_SEEK_CUR)) { return MA_FALSE; } offset = 0; } else { - if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, ma_dr_mp3_seek_origin_current)) { + if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, MA_DR_MP3_SEEK_CUR)) { return MA_FALSE; } offset -= 0x7FFFFFFF; @@ -92532,7 +94558,18 @@ static ma_bool32 ma_dr_mp3__on_seek_64(ma_dr_mp3* pMP3, ma_uint64 offset, ma_dr_ } return MA_TRUE; } -static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames) +static void ma_dr_mp3__on_meta(ma_dr_mp3* pMP3, ma_dr_mp3_metadata_type type, const void* pRawData, size_t rawDataSize) +{ + if (pMP3->onMeta) { + ma_dr_mp3_metadata metadata; + MA_DR_MP3_ZERO_OBJECT(&metadata); + metadata.type = type; + metadata.pRawData = pRawData; + metadata.rawDataSize = rawDataSize; + pMP3->onMeta(pMP3->pUserDataMeta, &metadata); + } +} +static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames, ma_dr_mp3dec_frame_info* pMP3FrameInfo, const ma_uint8** ppMP3FrameData) { ma_uint32 pcmFramesRead = 0; MA_DR_MP3_ASSERT(pMP3 != NULL); @@ -92559,7 +94596,7 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d pMP3->pData = pNewData; pMP3->dataCapacity = newDataCap; } - bytesRead = ma_dr_mp3__on_read(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); + bytesRead = ma_dr_mp3__on_read_clamped(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); if (bytesRead == 0) { if (pMP3->dataSize == 0) { pMP3->atEnd = MA_TRUE; @@ -92578,16 +94615,20 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d return 0; } pcmFramesRead = ma_dr_mp3dec_decode_frame(&pMP3->decoder, pMP3->pData + pMP3->dataConsumed, (int)pMP3->dataSize, pPCMFrames, &info); - if (info.frame_bytes > 0) { - pMP3->dataConsumed += (size_t)info.frame_bytes; - pMP3->dataSize -= (size_t)info.frame_bytes; - } + pMP3->dataConsumed += (size_t)info.frame_bytes; + pMP3->dataSize -= (size_t)info.frame_bytes; if (pcmFramesRead > 0) { pcmFramesRead = ma_dr_mp3_hdr_frame_samples(pMP3->decoder.header); pMP3->pcmFramesConsumedInMP3Frame = 0; pMP3->pcmFramesRemainingInMP3Frame = pcmFramesRead; pMP3->mp3FrameChannels = info.channels; - pMP3->mp3FrameSampleRate = info.hz; + pMP3->mp3FrameSampleRate = info.sample_rate; + if (pMP3FrameInfo != NULL) { + *pMP3FrameInfo = info; + } + if (ppMP3FrameData != NULL) { + *ppMP3FrameData = pMP3->pData + pMP3->dataConsumed - (size_t)info.frame_bytes; + } break; } else if (info.frame_bytes == 0) { size_t bytesRead; @@ -92604,7 +94645,7 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d pMP3->pData = pNewData; pMP3->dataCapacity = newDataCap; } - bytesRead = ma_dr_mp3__on_read(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); + bytesRead = ma_dr_mp3__on_read_clamped(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); if (bytesRead == 0) { pMP3->atEnd = MA_TRUE; return 0; @@ -92614,7 +94655,7 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d }; return pcmFramesRead; } -static ma_uint32 ma_dr_mp3_decode_next_frame_ex__memory(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames) +static ma_uint32 ma_dr_mp3_decode_next_frame_ex__memory(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames, ma_dr_mp3dec_frame_info* pMP3FrameInfo, const ma_uint8** ppMP3FrameData) { ma_uint32 pcmFramesRead = 0; ma_dr_mp3dec_frame_info info; @@ -92630,36 +94671,44 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__memory(ma_dr_mp3* pMP3, ma_dr_m pMP3->pcmFramesConsumedInMP3Frame = 0; pMP3->pcmFramesRemainingInMP3Frame = pcmFramesRead; pMP3->mp3FrameChannels = info.channels; - pMP3->mp3FrameSampleRate = info.hz; + pMP3->mp3FrameSampleRate = info.sample_rate; + if (pMP3FrameInfo != NULL) { + *pMP3FrameInfo = info; + } + if (ppMP3FrameData != NULL) { + *ppMP3FrameData = pMP3->memory.pData + pMP3->memory.currentReadPos; + } break; } else if (info.frame_bytes > 0) { pMP3->memory.currentReadPos += (size_t)info.frame_bytes; + pMP3->streamCursor += (size_t)info.frame_bytes; } else { break; } } pMP3->memory.currentReadPos += (size_t)info.frame_bytes; + pMP3->streamCursor += (size_t)info.frame_bytes; return pcmFramesRead; } -static ma_uint32 ma_dr_mp3_decode_next_frame_ex(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames) +static ma_uint32 ma_dr_mp3_decode_next_frame_ex(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames, ma_dr_mp3dec_frame_info* pMP3FrameInfo, const ma_uint8** ppMP3FrameData) { if (pMP3->memory.pData != NULL && pMP3->memory.dataSize > 0) { - return ma_dr_mp3_decode_next_frame_ex__memory(pMP3, pPCMFrames); + return ma_dr_mp3_decode_next_frame_ex__memory(pMP3, pPCMFrames, pMP3FrameInfo, ppMP3FrameData); } else { - return ma_dr_mp3_decode_next_frame_ex__callbacks(pMP3, pPCMFrames); + return ma_dr_mp3_decode_next_frame_ex__callbacks(pMP3, pPCMFrames, pMP3FrameInfo, ppMP3FrameData); } } static ma_uint32 ma_dr_mp3_decode_next_frame(ma_dr_mp3* pMP3) { MA_DR_MP3_ASSERT(pMP3 != NULL); - return ma_dr_mp3_decode_next_frame_ex(pMP3, (ma_dr_mp3d_sample_t*)pMP3->pcmFrames); + return ma_dr_mp3_decode_next_frame_ex(pMP3, (ma_dr_mp3d_sample_t*)pMP3->pcmFrames, NULL, NULL); } #if 0 static ma_uint32 ma_dr_mp3_seek_next_frame(ma_dr_mp3* pMP3) { ma_uint32 pcmFrameCount; MA_DR_MP3_ASSERT(pMP3 != NULL); - pcmFrameCount = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFrameCount = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFrameCount == 0) { return 0; } @@ -92669,33 +94718,252 @@ static ma_uint32 ma_dr_mp3_seek_next_frame(ma_dr_mp3* pMP3) return pcmFrameCount; } #endif -static ma_bool32 ma_dr_mp3_init_internal(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_mp3_init_internal(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, ma_dr_mp3_meta_proc onMeta, void* pUserData, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) { + ma_dr_mp3dec_frame_info firstFrameInfo; + const ma_uint8* pFirstFrameData; + ma_uint32 firstFramePCMFrameCount; + ma_uint32 detectedMP3FrameCount = 0xFFFFFFFF; MA_DR_MP3_ASSERT(pMP3 != NULL); MA_DR_MP3_ASSERT(onRead != NULL); ma_dr_mp3dec_init(&pMP3->decoder); pMP3->onRead = onRead; pMP3->onSeek = onSeek; + pMP3->onMeta = onMeta; pMP3->pUserData = pUserData; + pMP3->pUserDataMeta = pUserDataMeta; pMP3->allocationCallbacks = ma_dr_mp3_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); if (pMP3->allocationCallbacks.onFree == NULL || (pMP3->allocationCallbacks.onMalloc == NULL && pMP3->allocationCallbacks.onRealloc == NULL)) { return MA_FALSE; } - if (ma_dr_mp3_decode_next_frame(pMP3) == 0) { + pMP3->streamCursor = 0; + pMP3->streamLength = MA_UINT64_MAX; + pMP3->streamStartOffset = 0; + pMP3->delayInPCMFrames = 0; + pMP3->paddingInPCMFrames = 0; + pMP3->totalPCMFrameCount = MA_UINT64_MAX; + #if 1 + if (onSeek != NULL && onTell != NULL) { + if (onSeek(pUserData, 0, MA_DR_MP3_SEEK_END)) { + ma_int64 streamLen; + int streamEndOffset = 0; + if (onTell(pUserData, &streamLen)) { + if (streamLen > 128) { + char id3[3]; + if (onSeek(pUserData, streamEndOffset - 128, MA_DR_MP3_SEEK_END)) { + if (onRead(pUserData, id3, 3) == 3 && id3[0] == 'T' && id3[1] == 'A' && id3[2] == 'G') { + streamEndOffset -= 128; + streamLen -= 128; + if (onMeta != NULL) { + ma_uint8 tag[128]; + tag[0] = 'T'; tag[1] = 'A'; tag[2] = 'G'; + if (onRead(pUserData, tag + 3, 125) == 125) { + ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_ID3V1, tag, 128); + } + } + } else { + } + } else { + } + } else { + } + if (streamLen > 32) { + char ape[32]; + if (onSeek(pUserData, streamEndOffset - 32, MA_DR_MP3_SEEK_END)) { + if (onRead(pUserData, ape, 32) == 32 && ape[0] == 'A' && ape[1] == 'P' && ape[2] == 'E' && ape[3] == 'T' && ape[4] == 'A' && ape[5] == 'G' && ape[6] == 'E' && ape[7] == 'X') { + ma_uint32 tagSize = + ((ma_uint32)ape[24] << 0) | + ((ma_uint32)ape[25] << 8) | + ((ma_uint32)ape[26] << 16) | + ((ma_uint32)ape[27] << 24); + if (32 + tagSize < streamLen) { + streamEndOffset -= 32 + tagSize; + streamLen -= 32 + tagSize; + if (onMeta != NULL) { + if (onSeek(pUserData, streamEndOffset, MA_DR_MP3_SEEK_END)) { + size_t apeTagSize = (size_t)tagSize + 32; + ma_uint8* pTagData = (ma_uint8*)ma_dr_mp3_malloc(apeTagSize, pAllocationCallbacks); + if (pTagData != NULL) { + if (onRead(pUserData, pTagData, apeTagSize) == apeTagSize) { + ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_APE, pTagData, apeTagSize); + } + ma_dr_mp3_free(pTagData, pAllocationCallbacks); + } + } + } + } else { + } + } + } + } else { + } + if (!onSeek(pUserData, 0, MA_DR_MP3_SEEK_SET)) { + return MA_FALSE; + } + pMP3->streamLength = (ma_uint64)streamLen; + if (pMP3->memory.pData != NULL) { + pMP3->memory.dataSize = (size_t)pMP3->streamLength; + } + } else { + if (!onSeek(pUserData, 0, MA_DR_MP3_SEEK_SET)) { + return MA_FALSE; + } + } + } else { + } + } else { + } + #endif + #if 1 + { + char header[10]; + if (onRead(pUserData, header, 10) == 10) { + if (header[0] == 'I' && header[1] == 'D' && header[2] == '3') { + ma_uint32 tagSize = + (((ma_uint32)header[6] & 0x7F) << 21) | + (((ma_uint32)header[7] & 0x7F) << 14) | + (((ma_uint32)header[8] & 0x7F) << 7) | + (((ma_uint32)header[9] & 0x7F) << 0); + if (header[5] & 0x10) { + tagSize += 10; + } + if (onMeta != NULL) { + size_t tagSizeWithHeader = 10 + tagSize; + ma_uint8* pTagData = (ma_uint8*)ma_dr_mp3_malloc(tagSizeWithHeader, pAllocationCallbacks); + if (pTagData != NULL) { + MA_DR_MP3_COPY_MEMORY(pTagData, header, 10); + if (onRead(pUserData, pTagData + 10, tagSize) == tagSize) { + ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_ID3V2, pTagData, tagSizeWithHeader); + } + ma_dr_mp3_free(pTagData, pAllocationCallbacks); + } + } else { + if (onSeek != NULL) { + if (!onSeek(pUserData, tagSize, MA_DR_MP3_SEEK_CUR)) { + return MA_FALSE; + } + } else { + char discard[1024]; + while (tagSize > 0) { + size_t bytesToRead = tagSize; + if (bytesToRead > sizeof(discard)) { + bytesToRead = sizeof(discard); + } + if (onRead(pUserData, discard, bytesToRead) != bytesToRead) { + return MA_FALSE; + } + tagSize -= (ma_uint32)bytesToRead; + } + } + } + pMP3->streamStartOffset += 10 + tagSize; + pMP3->streamCursor = pMP3->streamStartOffset; + } else { + if (onSeek != NULL) { + if (!onSeek(pUserData, 0, MA_DR_MP3_SEEK_SET)) { + return MA_FALSE; + } + } else { + } + } + } else { + return MA_FALSE; + } + } + #endif + firstFramePCMFrameCount = ma_dr_mp3_decode_next_frame_ex(pMP3, (ma_dr_mp3d_sample_t*)pMP3->pcmFrames, &firstFrameInfo, &pFirstFrameData); + if (firstFramePCMFrameCount > 0) { + MA_DR_MP3_ASSERT(pFirstFrameData != NULL); + #if 1 + MA_DR_MP3_ASSERT(firstFrameInfo.frame_bytes > 0); + { + ma_dr_mp3_bs bs; + ma_dr_mp3_L3_gr_info grInfo[4]; + ma_dr_mp3_bs_init(&bs, pFirstFrameData + MA_DR_MP3_HDR_SIZE, firstFrameInfo.frame_bytes - MA_DR_MP3_HDR_SIZE); + if (MA_DR_MP3_HDR_IS_CRC(pFirstFrameData)) { + ma_dr_mp3_bs_get_bits(&bs, 16); + } + if (ma_dr_mp3_L3_read_side_info(&bs, grInfo, pFirstFrameData) >= 0) { + ma_bool32 isXing = MA_FALSE; + ma_bool32 isInfo = MA_FALSE; + const ma_uint8* pTagData; + const ma_uint8* pTagDataBeg; + pTagDataBeg = pFirstFrameData + MA_DR_MP3_HDR_SIZE + (bs.pos/8); + pTagData = pTagDataBeg; + isXing = (pTagData[0] == 'X' && pTagData[1] == 'i' && pTagData[2] == 'n' && pTagData[3] == 'g'); + isInfo = (pTagData[0] == 'I' && pTagData[1] == 'n' && pTagData[2] == 'f' && pTagData[3] == 'o'); + if (isXing || isInfo) { + ma_uint32 bytes = 0; + ma_uint32 flags = pTagData[7]; + pTagData += 8; + if (flags & 0x01) { + detectedMP3FrameCount = (ma_uint32)pTagData[0] << 24 | (ma_uint32)pTagData[1] << 16 | (ma_uint32)pTagData[2] << 8 | (ma_uint32)pTagData[3]; + pTagData += 4; + } + if (flags & 0x02) { + bytes = (ma_uint32)pTagData[0] << 24 | (ma_uint32)pTagData[1] << 16 | (ma_uint32)pTagData[2] << 8 | (ma_uint32)pTagData[3]; + (void)bytes; + pTagData += 4; + } + if (flags & 0x04) { + pTagData += 100; + } + if (flags & 0x08) { + pTagData += 4; + } + if (pTagData[0]) { + pTagData += 21; + if (pTagData - pFirstFrameData + 14 < firstFrameInfo.frame_bytes) { + int delayInPCMFrames; + int paddingInPCMFrames; + delayInPCMFrames = (( (ma_uint32)pTagData[0] << 4) | ((ma_uint32)pTagData[1] >> 4)) + (528 + 1); + paddingInPCMFrames = ((((ma_uint32)pTagData[1] & 0xF) << 8) | ((ma_uint32)pTagData[2] )) - (528 + 1); + if (paddingInPCMFrames < 0) { + paddingInPCMFrames = 0; + } + pMP3->delayInPCMFrames = (ma_uint32)delayInPCMFrames; + pMP3->paddingInPCMFrames = (ma_uint32)paddingInPCMFrames; + } + } + if (isXing) { + pMP3->isVBR = MA_TRUE; + } else if (isInfo) { + pMP3->isCBR = MA_TRUE; + } + if (onMeta != NULL) { + ma_dr_mp3_metadata_type metadataType = isXing ? MA_DR_MP3_METADATA_TYPE_XING : MA_DR_MP3_METADATA_TYPE_VBRI; + size_t tagDataSize; + tagDataSize = (size_t)firstFrameInfo.frame_bytes; + tagDataSize -= (size_t)(pTagDataBeg - pFirstFrameData); + ma_dr_mp3__on_meta(pMP3, metadataType, pTagDataBeg, tagDataSize); + } + pMP3->pcmFramesRemainingInMP3Frame = 0; + pMP3->streamStartOffset += (ma_uint32)(firstFrameInfo.frame_bytes); + pMP3->streamCursor = pMP3->streamStartOffset; + ma_dr_mp3dec_init(&pMP3->decoder); + } + } else { + } + } + #endif + } else { ma_dr_mp3__free_from_callbacks(pMP3->pData, &pMP3->allocationCallbacks); return MA_FALSE; } + if (detectedMP3FrameCount != 0xFFFFFFFF) { + pMP3->totalPCMFrameCount = detectedMP3FrameCount * firstFramePCMFrameCount; + } pMP3->channels = pMP3->mp3FrameChannels; pMP3->sampleRate = pMP3->mp3FrameSampleRate; return MA_TRUE; } -MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, ma_dr_mp3_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { if (pMP3 == NULL || onRead == NULL) { return MA_FALSE; } MA_DR_MP3_ZERO_OBJECT(pMP3); - return ma_dr_mp3_init_internal(pMP3, onRead, onSeek, pUserData, pAllocationCallbacks); + return ma_dr_mp3_init_internal(pMP3, onRead, onSeek, onTell, onMeta, pUserData, pUserData, pAllocationCallbacks); } static size_t ma_dr_mp3__on_read_memory(void* pUserData, void* pBufferOut, size_t bytesToRead) { @@ -92716,29 +94984,39 @@ static size_t ma_dr_mp3__on_read_memory(void* pUserData, void* pBufferOut, size_ static ma_bool32 ma_dr_mp3__on_seek_memory(void* pUserData, int byteOffset, ma_dr_mp3_seek_origin origin) { ma_dr_mp3* pMP3 = (ma_dr_mp3*)pUserData; + ma_int64 newCursor; MA_DR_MP3_ASSERT(pMP3 != NULL); - if (origin == ma_dr_mp3_seek_origin_current) { - if (byteOffset > 0) { - if (pMP3->memory.currentReadPos + byteOffset > pMP3->memory.dataSize) { - byteOffset = (int)(pMP3->memory.dataSize - pMP3->memory.currentReadPos); - } - } else { - if (pMP3->memory.currentReadPos < (size_t)-byteOffset) { - byteOffset = -(int)pMP3->memory.currentReadPos; - } - } - pMP3->memory.currentReadPos += byteOffset; + if (origin == MA_DR_MP3_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_MP3_SEEK_CUR) { + newCursor = (ma_int64)pMP3->memory.currentReadPos; + } else if (origin == MA_DR_MP3_SEEK_END) { + newCursor = (ma_int64)pMP3->memory.dataSize; } else { - if ((ma_uint32)byteOffset <= pMP3->memory.dataSize) { - pMP3->memory.currentReadPos = byteOffset; - } else { - pMP3->memory.currentReadPos = pMP3->memory.dataSize; - } + MA_DR_MP3_ASSERT(!"Invalid seek origin"); + return MA_FALSE; } + newCursor += byteOffset; + if (newCursor < 0) { + return MA_FALSE; + } + if ((size_t)newCursor > pMP3->memory.dataSize) { + return MA_FALSE; + } + pMP3->memory.currentReadPos = (size_t)newCursor; return MA_TRUE; } -MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_mp3__on_tell_memory(void* pUserData, ma_int64* pCursor) { + ma_dr_mp3* pMP3 = (ma_dr_mp3*)pUserData; + MA_DR_MP3_ASSERT(pMP3 != NULL); + MA_DR_MP3_ASSERT(pCursor != NULL); + *pCursor = (ma_int64)pMP3->memory.currentReadPos; + return MA_TRUE; +} +MA_API ma_bool32 ma_dr_mp3_init_memory_with_metadata(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) +{ + ma_bool32 result; if (pMP3 == NULL) { return MA_FALSE; } @@ -92749,7 +95027,21 @@ MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_ pMP3->memory.pData = (const ma_uint8*)pData; pMP3->memory.dataSize = dataSize; pMP3->memory.currentReadPos = 0; - return ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_memory, ma_dr_mp3__on_seek_memory, pMP3, pAllocationCallbacks); + result = ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_memory, ma_dr_mp3__on_seek_memory, ma_dr_mp3__on_tell_memory, onMeta, pMP3, pUserDataMeta, pAllocationCallbacks); + if (result == MA_FALSE) { + return MA_FALSE; + } + if (pMP3->streamLength <= (ma_uint64)MA_SIZE_MAX) { + pMP3->memory.dataSize = (size_t)pMP3->streamLength; + } + if (pMP3->streamStartOffset > (ma_uint64)MA_SIZE_MAX) { + return MA_FALSE; + } + return MA_TRUE; +} +MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) +{ + return ma_dr_mp3_init_memory_with_metadata(pMP3, pData, dataSize, NULL, NULL, pAllocationCallbacks); } #ifndef MA_DR_MP3_NO_STDIO #include @@ -92760,36 +95052,76 @@ static size_t ma_dr_mp3__on_read_stdio(void* pUserData, void* pBufferOut, size_t } static ma_bool32 ma_dr_mp3__on_seek_stdio(void* pUserData, int offset, ma_dr_mp3_seek_origin origin) { - return fseek((FILE*)pUserData, offset, (origin == ma_dr_mp3_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0; + int whence = SEEK_SET; + if (origin == MA_DR_MP3_SEEK_CUR) { + whence = SEEK_CUR; + } else if (origin == MA_DR_MP3_SEEK_END) { + whence = SEEK_END; + } + return fseek((FILE*)pUserData, offset, whence) == 0; } -MA_API ma_bool32 ma_dr_mp3_init_file(ma_dr_mp3* pMP3, const char* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_mp3__on_tell_stdio(void* pUserData, ma_int64* pCursor) +{ + FILE* pFileStdio = (FILE*)pUserData; + ma_int64 result; + MA_DR_MP3_ASSERT(pFileStdio != NULL); + MA_DR_MP3_ASSERT(pCursor != NULL); +#if defined(_WIN32) && !defined(NXDK) + #if defined(_MSC_VER) && _MSC_VER > 1200 + result = _ftelli64(pFileStdio); + #else + result = ftell(pFileStdio); + #endif +#else + result = ftell(pFileStdio); +#endif + *pCursor = result; + return MA_TRUE; +} +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata(ma_dr_mp3* pMP3, const char* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) { ma_bool32 result; FILE* pFile; + if (pMP3 == NULL) { + return MA_FALSE; + } + MA_DR_MP3_ZERO_OBJECT(pMP3); if (ma_fopen(&pFile, pFilePath, "rb") != MA_SUCCESS) { return MA_FALSE; } - result = ma_dr_mp3_init(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + result = ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, ma_dr_mp3__on_tell_stdio, onMeta, (void*)pFile, pUserDataMeta, pAllocationCallbacks); if (result != MA_TRUE) { fclose(pFile); return result; } return MA_TRUE; } -MA_API ma_bool32 ma_dr_mp3_init_file_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) { ma_bool32 result; FILE* pFile; + if (pMP3 == NULL) { + return MA_FALSE; + } + MA_DR_MP3_ZERO_OBJECT(pMP3); if (ma_wfopen(&pFile, pFilePath, L"rb", pAllocationCallbacks) != MA_SUCCESS) { return MA_FALSE; } - result = ma_dr_mp3_init(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + result = ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, ma_dr_mp3__on_tell_stdio, onMeta, (void*)pFile, pUserDataMeta, pAllocationCallbacks); if (result != MA_TRUE) { fclose(pFile); return result; } return MA_TRUE; } +MA_API ma_bool32 ma_dr_mp3_init_file(ma_dr_mp3* pMP3, const char* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +{ + return ma_dr_mp3_init_file_with_metadata(pMP3, pFilePath, NULL, NULL, pAllocationCallbacks); +} +MA_API ma_bool32 ma_dr_mp3_init_file_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +{ + return ma_dr_mp3_init_file_with_metadata_w(pMP3, pFilePath, NULL, NULL, pAllocationCallbacks); +} #endif MA_API void ma_dr_mp3_uninit(ma_dr_mp3* pMP3) { @@ -92859,17 +95191,38 @@ static ma_uint64 ma_dr_mp3_read_pcm_frames_raw(ma_dr_mp3* pMP3, ma_uint64 frames MA_DR_MP3_ASSERT(pMP3 != NULL); MA_DR_MP3_ASSERT(pMP3->onRead != NULL); while (framesToRead > 0) { - ma_uint32 framesToConsume = (ma_uint32)MA_DR_MP3_MIN(pMP3->pcmFramesRemainingInMP3Frame, framesToRead); + ma_uint32 framesToConsume; + if (pMP3->currentPCMFrame < pMP3->delayInPCMFrames) { + ma_uint32 framesToSkip = (ma_uint32)MA_DR_MP3_MIN(pMP3->pcmFramesRemainingInMP3Frame, pMP3->delayInPCMFrames - pMP3->currentPCMFrame); + pMP3->currentPCMFrame += framesToSkip; + pMP3->pcmFramesConsumedInMP3Frame += framesToSkip; + pMP3->pcmFramesRemainingInMP3Frame -= framesToSkip; + } + framesToConsume = (ma_uint32)MA_DR_MP3_MIN(pMP3->pcmFramesRemainingInMP3Frame, framesToRead); + if (pMP3->totalPCMFrameCount != MA_UINT64_MAX && pMP3->totalPCMFrameCount > pMP3->paddingInPCMFrames) { + if (pMP3->currentPCMFrame < (pMP3->totalPCMFrameCount - pMP3->paddingInPCMFrames)) { + ma_uint64 framesRemainigToPadding = (pMP3->totalPCMFrameCount - pMP3->paddingInPCMFrames) - pMP3->currentPCMFrame; + if (framesToConsume > framesRemainigToPadding) { + framesToConsume = (ma_uint32)framesRemainigToPadding; + } + } else { + break; + } + } if (pBufferOut != NULL) { - #if defined(MA_DR_MP3_FLOAT_OUTPUT) - float* pFramesOutF32 = (float*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(float) * totalFramesRead * pMP3->channels); - float* pFramesInF32 = (float*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(float) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); - MA_DR_MP3_COPY_MEMORY(pFramesOutF32, pFramesInF32, sizeof(float) * framesToConsume * pMP3->channels); - #else - ma_int16* pFramesOutS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(ma_int16) * totalFramesRead * pMP3->channels); - ma_int16* pFramesInS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(ma_int16) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); - MA_DR_MP3_COPY_MEMORY(pFramesOutS16, pFramesInS16, sizeof(ma_int16) * framesToConsume * pMP3->channels); - #endif + #if defined(MA_DR_MP3_FLOAT_OUTPUT) + { + float* pFramesOutF32 = (float*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(float) * totalFramesRead * pMP3->channels); + float* pFramesInF32 = (float*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(float) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); + MA_DR_MP3_COPY_MEMORY(pFramesOutF32, pFramesInF32, sizeof(float) * framesToConsume * pMP3->channels); + } + #else + { + ma_int16* pFramesOutS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(ma_int16) * totalFramesRead * pMP3->channels); + ma_int16* pFramesInS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(ma_int16) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); + MA_DR_MP3_COPY_MEMORY(pFramesOutS16, pFramesInS16, sizeof(ma_int16) * framesToConsume * pMP3->channels); + } + #endif } pMP3->currentPCMFrame += framesToConsume; pMP3->pcmFramesConsumedInMP3Frame += framesToConsume; @@ -92879,6 +95232,9 @@ static ma_uint64 ma_dr_mp3_read_pcm_frames_raw(ma_dr_mp3* pMP3, ma_uint64 frames if (framesToRead == 0) { break; } + if (pMP3->totalPCMFrameCount != MA_UINT64_MAX && pMP3->totalPCMFrameCount > pMP3->paddingInPCMFrames && pMP3->currentPCMFrame >= (pMP3->totalPCMFrameCount - pMP3->paddingInPCMFrames)) { + break; + } MA_DR_MP3_ASSERT(pMP3->pcmFramesRemainingInMP3Frame == 0); if (ma_dr_mp3_decode_next_frame(pMP3) == 0) { break; @@ -92958,7 +95314,7 @@ static ma_bool32 ma_dr_mp3_seek_to_start_of_stream(ma_dr_mp3* pMP3) { MA_DR_MP3_ASSERT(pMP3 != NULL); MA_DR_MP3_ASSERT(pMP3->onSeek != NULL); - if (!ma_dr_mp3__on_seek(pMP3, 0, ma_dr_mp3_seek_origin_start)) { + if (!ma_dr_mp3__on_seek_64(pMP3, pMP3->streamStartOffset, MA_DR_MP3_SEEK_SET)) { return MA_FALSE; } ma_dr_mp3_reset(pMP3); @@ -93024,7 +95380,7 @@ static ma_bool32 ma_dr_mp3_seek_to_pcm_frame__seek_table(ma_dr_mp3* pMP3, ma_uin seekPoint.mp3FramesToDiscard = 0; seekPoint.pcmFramesToDiscard = 0; } - if (!ma_dr_mp3__on_seek_64(pMP3, seekPoint.seekPosInBytes, ma_dr_mp3_seek_origin_start)) { + if (!ma_dr_mp3__on_seek_64(pMP3, seekPoint.seekPosInBytes, MA_DR_MP3_SEEK_SET)) { return MA_FALSE; } ma_dr_mp3_reset(pMP3); @@ -93035,7 +95391,7 @@ static ma_bool32 ma_dr_mp3_seek_to_pcm_frame__seek_table(ma_dr_mp3* pMP3, ma_uin if (iMP3Frame == seekPoint.mp3FramesToDiscard-1) { pPCMFrames = (ma_dr_mp3d_sample_t*)pMP3->pcmFrames; } - pcmFramesRead = ma_dr_mp3_decode_next_frame_ex(pMP3, pPCMFrames); + pcmFramesRead = ma_dr_mp3_decode_next_frame_ex(pMP3, pPCMFrames, NULL, NULL); if (pcmFramesRead == 0) { return MA_FALSE; } @@ -93077,7 +95433,7 @@ MA_API ma_bool32 ma_dr_mp3_get_mp3_and_pcm_frame_count(ma_dr_mp3* pMP3, ma_uint6 totalMP3FrameCount = 0; for (;;) { ma_uint32 pcmFramesInCurrentMP3Frame; - pcmFramesInCurrentMP3Frame = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFramesInCurrentMP3Frame = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFramesInCurrentMP3Frame == 0) { break; } @@ -93101,10 +95457,26 @@ MA_API ma_bool32 ma_dr_mp3_get_mp3_and_pcm_frame_count(ma_dr_mp3* pMP3, ma_uint6 MA_API ma_uint64 ma_dr_mp3_get_pcm_frame_count(ma_dr_mp3* pMP3) { ma_uint64 totalPCMFrameCount; - if (!ma_dr_mp3_get_mp3_and_pcm_frame_count(pMP3, NULL, &totalPCMFrameCount)) { + if (pMP3 == NULL) { return 0; } - return totalPCMFrameCount; + if (pMP3->totalPCMFrameCount != MA_UINT64_MAX) { + totalPCMFrameCount = pMP3->totalPCMFrameCount; + if (totalPCMFrameCount >= pMP3->delayInPCMFrames) { + totalPCMFrameCount -= pMP3->delayInPCMFrames; + } else { + } + if (totalPCMFrameCount >= pMP3->paddingInPCMFrames) { + totalPCMFrameCount -= pMP3->paddingInPCMFrames; + } else { + } + return totalPCMFrameCount; + } else { + if (!ma_dr_mp3_get_mp3_and_pcm_frame_count(pMP3, NULL, &totalPCMFrameCount)) { + return 0; + } + return totalPCMFrameCount; + } } MA_API ma_uint64 ma_dr_mp3_get_mp3_frame_count(ma_dr_mp3* pMP3) { @@ -93174,7 +95546,7 @@ MA_API ma_bool32 ma_dr_mp3_calculate_seek_points(ma_dr_mp3* pMP3, ma_uint32* pSe MA_DR_MP3_ASSERT(pMP3->streamCursor >= pMP3->dataSize); mp3FrameInfo[iMP3Frame].bytePos = pMP3->streamCursor - pMP3->dataSize; mp3FrameInfo[iMP3Frame].pcmFrameIndex = runningPCMFrameCount; - pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFramesInCurrentMP3FrameIn == 0) { return MA_FALSE; } @@ -93198,7 +95570,7 @@ MA_API ma_bool32 ma_dr_mp3_calculate_seek_points(ma_dr_mp3* pMP3, ma_uint32* pSe } mp3FrameInfo[MA_DR_MP3_COUNTOF(mp3FrameInfo)-1].bytePos = pMP3->streamCursor - pMP3->dataSize; mp3FrameInfo[MA_DR_MP3_COUNTOF(mp3FrameInfo)-1].pcmFrameIndex = runningPCMFrameCount; - pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFramesInCurrentMP3FrameIn == 0) { pSeekPoints[iSeekPoint].seekPosInBytes = mp3FrameInfo[0].bytePos; pSeekPoints[iSeekPoint].pcmFrameIndex = nextTargetPCMFrame; @@ -93264,6 +95636,8 @@ static float* ma_dr_mp3__full_read_and_close_f32(ma_dr_mp3* pMP3, ma_dr_mp3_conf pNewFrames = (float*)ma_dr_mp3__realloc_from_callbacks(pFrames, (size_t)newFramesBufferSize, (size_t)oldFramesBufferSize, &pMP3->allocationCallbacks); if (pNewFrames == NULL) { ma_dr_mp3__free_from_callbacks(pFrames, &pMP3->allocationCallbacks); + pFrames = NULL; + totalFramesRead = 0; break; } pFrames = pNewFrames; @@ -93315,6 +95689,8 @@ static ma_int16* ma_dr_mp3__full_read_and_close_s16(ma_dr_mp3* pMP3, ma_dr_mp3_c pNewFrames = (ma_int16*)ma_dr_mp3__realloc_from_callbacks(pFrames, (size_t)newFramesBufferSize, (size_t)oldFramesBufferSize, &pMP3->allocationCallbacks); if (pNewFrames == NULL) { ma_dr_mp3__free_from_callbacks(pFrames, &pMP3->allocationCallbacks); + pFrames = NULL; + totalFramesRead = 0; break; } pFrames = pNewFrames; @@ -93336,18 +95712,18 @@ static ma_int16* ma_dr_mp3__full_read_and_close_s16(ma_dr_mp3* pMP3, ma_dr_mp3_c } return pFrames; } -MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_mp3 mp3; - if (!ma_dr_mp3_init(&mp3, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_mp3_init(&mp3, onRead, onSeek, onTell, NULL, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_mp3__full_read_and_close_f32(&mp3, pConfig, pTotalFrameCount); } -MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_mp3 mp3; - if (!ma_dr_mp3_init(&mp3, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_mp3_init(&mp3, onRead, onSeek, onTell, NULL, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_mp3__full_read_and_close_s16(&mp3, pConfig, pTotalFrameCount); From 4bea3cd329fa141bdce3819d91f4c20075850ec0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 15 Feb 2026 22:21:04 +0200 Subject: [PATCH 157/831] ggml : bump version to 0.9.7 (ggml/1425) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index aa0ecde02a7..4323afe57b5 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 6) +set(GGML_VERSION_PATCH 7) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) From 7ee772ab2becef7adbba13b53e74dd8d4a481583 Mon Sep 17 00:00:00 2001 From: SamareshSingh <97642706+ssam18@users.noreply.github.com> Date: Sat, 14 Feb 2026 23:22:53 -0600 Subject: [PATCH 158/831] cmake : fix KleidiAI install target failure with EXCLUDE_FROM_ALL (llama/19581) * cmake: fix KleidiAI install target failure with EXCLUDE_FROM_ALL Fix for the bug #19501 by adding EXCLUDE_FROM_ALL to FetchContent_Declare. This properly excludes KleidiAI from both build and install targets, preventing install failures when GGML_CPU_KLEIDIAI=ON is used. The KleidiAI source files are still compiled into libggml-cpu.so, preserving all functionality. * addressed code review comments --- ggml/src/ggml-cpu/CMakeLists.txt | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 7622d0bf49b..6aea7f7bfb9 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -569,12 +569,14 @@ function(ggml_add_cpu_backend_variant_impl tag_name) cmake_policy(SET CMP0135 NEW) endif() + # TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+ + # Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28 FetchContent_Declare(KleidiAI_Download URL ${KLEIDIAI_DOWNLOAD_URL} DOWNLOAD_EXTRACT_TIMESTAMP NEW URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}) - FetchContent_MakeAvailable(KleidiAI_Download) + FetchContent_Populate(KleidiAI_Download) FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC POPULATED KLEIDIAI_POPULATED) @@ -585,11 +587,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) add_compile_definitions(GGML_USE_CPU_KLEIDIAI) - # Remove kleidiai target after fetching it - if (TARGET kleidiai) - set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE) - endif() - list(APPEND GGML_CPU_SOURCES ggml-cpu/kleidiai/kleidiai.cpp ggml-cpu/kleidiai/kernels.cpp From 76f769d06fe831a6c2c13bcbec341bab92455bd6 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sun, 15 Feb 2026 11:09:24 +0530 Subject: [PATCH 159/831] ggml-cpu: FA add GEMM microkernel (llama/19422) * ggml-cpu: FA add GEMM microkernel * add guard for sizeless vector types * fix case where DV % GGML_F32_EPR !=0 * move memset out of the loop * move another memset out of the loop * use RM=4 for arm * simd_gemm: convert everything to int * convert everything to size_t to avoid warnings * fixup * add pragma for ignoring aggressive loop optimizations --- ggml/src/ggml-cpu/common.h | 4 +- ggml/src/ggml-cpu/ggml-cpu.c | 4 +- ggml/src/ggml-cpu/ops.cpp | 121 ++++++++++++++++------------- ggml/src/ggml-cpu/simd-gemm.h | 139 ++++++++++++++++++++++++++++++++++ 4 files changed, 211 insertions(+), 57 deletions(-) create mode 100644 ggml/src/ggml-cpu/simd-gemm.h diff --git a/ggml/src/ggml-cpu/common.h b/ggml/src/ggml-cpu/common.h index 1057b5bb152..abbadc359c5 100644 --- a/ggml/src/ggml-cpu/common.h +++ b/ggml/src/ggml-cpu/common.h @@ -6,8 +6,8 @@ #include "ggml-impl.h" #include "simd-mappings.h" -#define GGML_FA_TILE_Q 32 -#define GGML_FA_TILE_KV 16 +#define GGML_FA_TILE_Q 64 +#define GGML_FA_TILE_KV 64 #ifdef __cplusplus diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index e048d5e5e77..64eb01a4e18 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2874,8 +2874,8 @@ struct ggml_cplan ggml_graph_plan( const int64_t DV = node->src[2]->ne[0]; // Tiled flash attention scratch (tile sizes defined in common.h) - // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding - size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks; + // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding + size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks; // Decode path: n_kv_chunks = n_tasks (one chunk per thread) // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 4352e132807..b7a70e06f1d 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3,6 +3,7 @@ #include "ggml-cpu.h" #include "ggml-impl.h" #include "binary-ops.h" +#include "simd-gemm.h" #include "ggml.h" #include "unary-ops.h" #include "vec.h" @@ -8389,10 +8390,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled( GGML_ASSERT(k->type == v->type); const ggml_type kv_type = k->type; - const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type); - const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float; - const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot; - const size_t kv_type_size = ggml_type_size(kv_type); // broadcast factors const int64_t rk2 = neq2/nek2; @@ -8424,8 +8421,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled( static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; - GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ"); - int ir = ir0; while (ir < ir1) { // q indices for the start of this tile @@ -8452,18 +8447,20 @@ static void ggml_compute_forward_flash_attn_ext_tiled( } // Per-thread scratch layout: - // Q_q: Q_TILE_SZ * DK (converted Q tile in KV type) + // Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar) // KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float) // mask: Q_TILE_SZ * KV_TILE_SZ (mask in float) // VKQ32: Q_TILE_SZ * DV (FP32 output accumulator) - // V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion) - float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32); + // V32: KV_TILE_SZ * DV (F32 buffer for V tile) + // K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path) + float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32); void * Q_q = base; float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float)); float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ; float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ; - float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile + float * V32 = VKQ32 + Q_TILE_SZ * DV; + float * K_f32 = V32 + KV_TILE_SZ * DV; memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float)); memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float)); @@ -8476,28 +8473,38 @@ static void ggml_compute_forward_flash_attn_ext_tiled( const int iv3 = iq3 / rv3; const int iv2 = iq2 / rv2; - for (int tq = 0; tq < tile_rows; tq++) { - const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3)); - kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK); - } - // Zero-pad remaining rows - for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { - memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size); + { + float * Q_f32 = (float *)Q_q; + for (int tq = 0; tq < tile_rows; tq++) { + const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3)); + memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float)); + } + for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { + memset(Q_f32 + tq * DK, 0, DK * sizeof(float)); + } } + memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float)); + memset(V32, 0, KV_TILE_SZ * DV * sizeof(float)); + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { + const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic); // skip the tile entirely if all the masks are -inf if (mask) { bool can_skip = true; for (int tq = 0; tq < tile_rows; tq++) { const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]); - for (int tk = 0; tk < KV_TILE_SZ; tk++) { + for (int tk = 0; tk < kv_tile; tk++) { mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]); if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) { can_skip = false; } } + // Pad remaining mask entries with -inf + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + mask32[tq * KV_TILE_SZ + tk] = -INFINITY; + } } if (can_skip) { @@ -8505,13 +8512,32 @@ static void ggml_compute_forward_flash_attn_ext_tiled( } } - for (int tq = 0; tq < Q_TILE_SZ; tq++) { - const void * q_row = (const char *)Q_q + tq * DK * kv_type_size; - for (int tk = 0; tk < KV_TILE_SZ; tk++) { - const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3); - float s; - kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1); - KQ[tq * KV_TILE_SZ + tk] = s * scale; + // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim) + // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns + for (int tk = 0; tk < kv_tile; tk++) { + const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3; + if (kv_type == GGML_TYPE_F16) { + const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data; + for (int64_t dk = 0; dk < DK; dk++) { + K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]); + } + } else { + const float * k_f32_src = (const float *)k_data; + for (int64_t dk = 0; dk < DK; dk++) { + K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk]; + } + } + } + memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float)); + simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ); + ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale); + + // Set padded KQ entries to -inf so softmax gives them zero weight + if (kv_tile < KV_TILE_SZ) { + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + KQ[tq * KV_TILE_SZ + tk] = -INFINITY; + } } } @@ -8551,33 +8577,22 @@ static void ggml_compute_forward_flash_attn_ext_tiled( S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew); } - // Convert V tile to F32 first (if F16), then do MAD - // On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster. - // TODO: on ARM, native f16 should be faster - if (kv_type == GGML_TYPE_F16) { - for (int tk = 0; tk < KV_TILE_SZ; tk++) { - const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3)); - ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV); - } - for (int tq = 0; tq < Q_TILE_SZ; tq++) { - if (skip[tq]) continue; - float * vkq_row = VKQ32 + tq * DV; - for (int tk = 0; tk < KV_TILE_SZ; tk++) { - const float p = KQ[tq * KV_TILE_SZ + tk]; - ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p); - } + // V accumulation: VKQ32 += softmax(KQ) * V + // Pack V tile to contiguous F32, zero-padded + for (int tk = 0; tk < kv_tile; tk++) { + const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3; + if (kv_type == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV); + } else { + memcpy(V32 + tk * DV, v_data, DV * sizeof(float)); } - } else { - for (int tq = 0; tq < Q_TILE_SZ; tq++) { - if (skip[tq]) continue; - float * vkq_row = VKQ32 + tq * DV; - for (int tk = 0; tk < KV_TILE_SZ; tk++) { - const float p = KQ[tq * KV_TILE_SZ + tk]; - const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3)); - ggml_vec_mad_f32(DV, vkq_row, v_row, p); - } + } + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + if (skip[tq]) { + memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float)); } } + simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV); } // sinks (apply only to valid rows in the tile) @@ -8794,15 +8809,15 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t dr = (nr + nchunk - 1) / nchunk; - static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV; static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; - const bool use_tiled = !use_ref && + bool use_tiled = !use_ref && (q->type == GGML_TYPE_F32 && kv_is_f32_or_f16 && k->type == v->type && - nek1 % KV_TILE_SZ == 0 && neq1 >= Q_TILE_SZ); - +#ifdef GGML_SIMD + use_tiled &= (DV % GGML_F32_EPR == 0); +#endif int current_chunk = ith; while (current_chunk < nchunk) { diff --git a/ggml/src/ggml-cpu/simd-gemm.h b/ggml/src/ggml-cpu/simd-gemm.h new file mode 100644 index 00000000000..cd98a1b0332 --- /dev/null +++ b/ggml/src/ggml-cpu/simd-gemm.h @@ -0,0 +1,139 @@ +#pragma once + +// Computes C[M x N] += A[M x K] * B[K x N] + +#include "ggml-cpu-impl.h" +#include "vec.h" +#include "common.h" +#include "simd-mappings.h" + +// TODO: add support for sizeless vector types +#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic) + +// TODO: untested on avx512 +// These are in units of GGML_F32_EPR +#if defined(__AVX512F__) || defined (__ARM_NEON__) + static constexpr int GEMM_RM = 4; + static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32 +#elif defined(__AVX2__) || defined(__AVX__) + static constexpr int GEMM_RM = 6; + static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16 +#else + static constexpr int GEMM_RM = 2; + static constexpr int GEMM_RN = 2; +#endif + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Waggressive-loop-optimizations" +#endif + +template +static inline void simd_gemm_ukernel( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int64_t K, int64_t N, + int64_t ii, int64_t jj) +{ + static constexpr int KN = GGML_F32_EPR; + + GGML_F32_VEC acc[RM][RN]; + for (int i = 0; i < RM; i++) { + for (int r = 0; r < RN; r++) { + acc[i][r] = GGML_F32_VEC_LOAD(C + (ii + i) * N + jj + r * KN); + } + } + + for (int64_t kk = 0; kk < K; kk++) { + GGML_F32_VEC Bv[RN]; + for (int r = 0; r < RN; r++) { + Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + jj + r * KN); + } + for (int i = 0; i < RM; i++) { + GGML_F32_VEC p = GGML_F32_VEC_SET1(A[(ii + i) * K + kk]); + for (int r = 0; r < RN; r++) { + acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p); + } + } + } + + for (int i = 0; i < RM; i++) { + for (int r = 0; r < RN; r++) { + GGML_F32_VEC_STORE(C + (ii + i) * N + jj + r * KN, acc[i][r]); + } + } +} + +// C[M x N] += A[M x K] * B[K x N] +static void simd_gemm( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int64_t M, int64_t K, int64_t N) +{ + static constexpr int KN = GGML_F32_EPR; + + int64_t ii = 0; + for (; ii + GEMM_RM <= M; ii += GEMM_RM) { + int64_t jj = 0; + for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) { + simd_gemm_ukernel(C, A, B, K, N, ii, jj); + } + for (; jj + KN <= N; jj += KN) { + simd_gemm_ukernel(C, A, B, K, N, ii, jj); + } + for (; jj < N; jj++) { + for (int i = 0; i < GEMM_RM; i++) { + float a = C[(ii + i) * N + jj]; + for (int64_t kk = 0; kk < K; kk++) { + a += A[(ii + i) * K + kk] * B[kk * N + jj]; + } + C[(ii + i) * N + jj] = a; + } + } + } + + // Tail rows: one at a time + for (; ii < M; ii++) { + int64_t jj = 0; + for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) { + simd_gemm_ukernel<1, GEMM_RN>(C, A, B, K, N, ii, jj); + } + for (; jj + KN <= N; jj += KN) { + simd_gemm_ukernel<1, 1>(C, A, B, K, N, ii, jj); + } + for (; jj < N; jj++) { + float a = C[ii * N + jj]; + for (int64_t kk = 0; kk < K; kk++) { + a += A[ii * K + kk] * B[kk * N + jj]; + } + C[ii * N + jj] = a; + } + } +} + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif + +#else // scalar path + +static void simd_gemm( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int64_t M, int64_t K, int64_t N) +{ + for (int64_t i = 0; i < M; i++) { + for (int64_t j = 0; j < N; j++) { + float sum = C[i * N + j]; + for (int64_t kk = 0; kk < K; kk++) { + sum += A[i * K + kk] * B[kk * N + j]; + } + C[i * N + j] = sum; + } + } +} + +#endif // GGML_SIMD From 7b5a1ebaa63211d69cc57511050181197bb0cd8d Mon Sep 17 00:00:00 2001 From: Aaron Teo Date: Sun, 15 Feb 2026 18:20:35 +0800 Subject: [PATCH 160/831] ggml-cpu: optimize ggml_vec_dot_bf16 for s390x (llama/19399) --- ggml/src/ggml-cpu/simd-mappings.h | 26 ++++++++++++++++++++++++++ ggml/src/ggml-cpu/vec.cpp | 3 +-- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 630e506542b..22de55700d4 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -1160,6 +1160,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { float32x4_t tmp = x[0] + vec_reve(x[0]); \ res = tmp[0] + tmp[1]; \ } +#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \ +{ \ + float32x4_t v = vec_add(vec_add(s0, s1), \ + vec_add(s2, s3)); \ + v = vec_add(v, vec_sld(v, v, 8)); \ + v = vec_add(v, vec_sld(v, v, 4)); \ + res += (ggml_float)vec_extract(v, 0); \ +} #define GGML_F32_VEC GGML_F32x4 #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO @@ -1209,6 +1217,24 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { #define GGML_F16_VEC_MUL GGML_F32x4_MUL #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE +// BF16 s390x +#define GGML_BF16_STEP 16 +#define GGML_BF16_EPR 8 + +#define GGML_BF16x8 __vector unsigned short +#define GGML_BF16x8_ZERO vec_splats((unsigned short)0) +#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p)) + +#define GGML_BF16_VEC GGML_BF16x8 +#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO +#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD +#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO)) +#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO)) +#define GGML_BF16_FMA_LO(acc, x, y) \ + (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y)) +#define GGML_BF16_FMA_HI(acc, x, y) \ + (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y)) + #elif defined(__riscv_v_intrinsic) // compatible with vlen >= 128 diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 8708cd4e92f..d0e4001338a 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -236,8 +236,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); sumf += __riscv_vfmv_f_s_f32m1_f32(redsum); -#endif -#if defined(__POWER9_VECTOR__) +#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__) const int np = (n & ~(GGML_BF16_STEP - 1)); if (np > 0) { GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO}; From 22f0861efccba73a9614cf2fc620822a6850986a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 15 Feb 2026 14:56:35 +0200 Subject: [PATCH 161/831] ggml : avoid UB in gemm ukernel (llama/19642) --- ggml/src/ggml-cpu/simd-gemm.h | 57 +++++++++++++++++------------------ 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/ggml/src/ggml-cpu/simd-gemm.h b/ggml/src/ggml-cpu/simd-gemm.h index cd98a1b0332..78d663e593e 100644 --- a/ggml/src/ggml-cpu/simd-gemm.h +++ b/ggml/src/ggml-cpu/simd-gemm.h @@ -2,9 +2,6 @@ // Computes C[M x N] += A[M x K] * B[K x N] -#include "ggml-cpu-impl.h" -#include "vec.h" -#include "common.h" #include "simd-mappings.h" // TODO: add support for sizeless vector types @@ -23,44 +20,38 @@ static constexpr int GEMM_RN = 2; #endif -#if defined(__GNUC__) && !defined(__clang__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Waggressive-loop-optimizations" -#endif - template static inline void simd_gemm_ukernel( float * GGML_RESTRICT C, const float * GGML_RESTRICT A, const float * GGML_RESTRICT B, - int64_t K, int64_t N, - int64_t ii, int64_t jj) + int K, int N) { static constexpr int KN = GGML_F32_EPR; GGML_F32_VEC acc[RM][RN]; - for (int i = 0; i < RM; i++) { + for (int64_t i = 0; i < RM; i++) { for (int r = 0; r < RN; r++) { - acc[i][r] = GGML_F32_VEC_LOAD(C + (ii + i) * N + jj + r * KN); + acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN); } } for (int64_t kk = 0; kk < K; kk++) { GGML_F32_VEC Bv[RN]; for (int r = 0; r < RN; r++) { - Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + jj + r * KN); + Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN); } - for (int i = 0; i < RM; i++) { - GGML_F32_VEC p = GGML_F32_VEC_SET1(A[(ii + i) * K + kk]); + for (int64_t i = 0; i < RM; i++) { + GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]); for (int r = 0; r < RN; r++) { acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p); } } } - for (int i = 0; i < RM; i++) { + for (int64_t i = 0; i < RM; i++) { for (int r = 0; r < RN; r++) { - GGML_F32_VEC_STORE(C + (ii + i) * N + jj + r * KN, acc[i][r]); + GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]); } } } @@ -70,7 +61,7 @@ static void simd_gemm( float * GGML_RESTRICT C, const float * GGML_RESTRICT A, const float * GGML_RESTRICT B, - int64_t M, int64_t K, int64_t N) + int M, int K, int N) { static constexpr int KN = GGML_F32_EPR; @@ -78,38 +69,44 @@ static void simd_gemm( for (; ii + GEMM_RM <= M; ii += GEMM_RM) { int64_t jj = 0; for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) { - simd_gemm_ukernel(C, A, B, K, N, ii, jj); + simd_gemm_ukernel(C + jj, A, B + jj, K, N); } for (; jj + KN <= N; jj += KN) { - simd_gemm_ukernel(C, A, B, K, N, ii, jj); + simd_gemm_ukernel(C + jj, A, B + jj, K, N); } for (; jj < N; jj++) { - for (int i = 0; i < GEMM_RM; i++) { - float a = C[(ii + i) * N + jj]; + for (int64_t i = 0; i < GEMM_RM; i++) { + float a = C[i * N + jj]; for (int64_t kk = 0; kk < K; kk++) { - a += A[(ii + i) * K + kk] * B[kk * N + jj]; + a += A[i + kk] * B[kk * N + jj]; } - C[(ii + i) * N + jj] = a; + C[i * N + jj] = a; } } + + A += GEMM_RM * K; + C += GEMM_RM * N; } // Tail rows: one at a time for (; ii < M; ii++) { int64_t jj = 0; for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) { - simd_gemm_ukernel<1, GEMM_RN>(C, A, B, K, N, ii, jj); + simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N); } for (; jj + KN <= N; jj += KN) { - simd_gemm_ukernel<1, 1>(C, A, B, K, N, ii, jj); + simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N); } for (; jj < N; jj++) { - float a = C[ii * N + jj]; + float a = C[jj]; for (int64_t kk = 0; kk < K; kk++) { - a += A[ii * K + kk] * B[kk * N + jj]; + a += A[kk] * B[kk * N + jj]; } - C[ii * N + jj] = a; + C[jj] = a; } + + A += K; + C += N; } } @@ -123,7 +120,7 @@ static void simd_gemm( float * GGML_RESTRICT C, const float * GGML_RESTRICT A, const float * GGML_RESTRICT B, - int64_t M, int64_t K, int64_t N) + int M, int K, int N) { for (int64_t i = 0; i < M; i++) { for (int64_t j = 0; j < N; j++) { From df2f8d3bc44015e1cf174b61accfea4a564ddd6d Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Sun, 15 Feb 2026 13:59:38 +0100 Subject: [PATCH 162/831] cmake : check if KleidiAI API has been fetched (llama/19640) This commit addresses a build issue with the KleidiAI backend when building multiple cpu backends. Commmit 3a00c98584e42a20675b6569d81beadb282b0952 ("cmake : fix KleidiAI install target failure with EXCLUDE_FROM_ALL") introduced a change where FetchContent_Populate is called instead of FetchContent_MakeAvailable, where the latter does handle this case (it is idempotent but FetchContent_Populate is not). I missed this during my review and I should not have commited without verifying the CI failure, sorry about that. --- ggml/src/ggml-cpu/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 6aea7f7bfb9..43d6f7f54f7 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -576,13 +576,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) DOWNLOAD_EXTRACT_TIMESTAMP NEW URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}) - FetchContent_Populate(KleidiAI_Download) FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC POPULATED KLEIDIAI_POPULATED) if (NOT KLEIDIAI_POPULATED) - message(FATAL_ERROR "KleidiAI source downloaded failed.") + FetchContent_Populate(KleidiAI_Download) + FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC) endif() add_compile_definitions(GGML_USE_CPU_KLEIDIAI) From 02a9f660b8e05d0fd5afae17c0316f16df667a81 Mon Sep 17 00:00:00 2001 From: David Friehs Date: Sun, 15 Feb 2026 18:08:42 +0100 Subject: [PATCH 163/831] cuda: optimize iq2xxs/iq2xs/iq3xxs dequantization (llama/19624) * cuda: optimize iq2xxs/iq2xs/iq3xxs dequantization - load all 8 int8 for a grid position in one load - calculate signs via popcnt instead of fetching from ksigns table - broadcast signs to drop individual shift/mask * cuda: iq2xxs: simplify sum scaling express `(sum * scale + sum / 2) / 4` as `(sum * (scale * 2 + 1)) / 8` express `((aux32 >> 28) * 2 + 1)` as `(aux32 >> 27 | 1)` saves 3 registers for mul_mat_vec_q (152 -> 149) according to nsight AFAICT no overflow can occur here as iq2xxs values are far too small * uint -> uint32_t error: identifier "uint" is undefined --- ggml/src/ggml-cuda/mmq.cuh | 37 ++++++++++++++------------ ggml/src/ggml-cuda/vecdotq.cuh | 48 ++++++++++++++++++++++------------ 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index f80f98cda2c..255e59f6fc6 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2715,14 +2715,14 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR2_XXS; ++l) { - const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]); - const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F]; + const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]]; + const uint32_t signs = unpack_ksigns(aux32 >> (7 * l)); - const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); - const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0); - const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); - const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; @@ -2733,12 +2733,12 @@ template static __device__ __forceinline__ void loa #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } - const int ls = aux32 >> 28; + const int ls = aux32 >> 27 | 1; // (scale * 2 + 1) const float d = bxi->d; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4 #else - x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2776,11 +2776,14 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR2_XS; ++l) { - const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF)); - const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF]; + const uint32_t signs = unpack_ksigns(q2[l] >> 9); + + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); - const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; @@ -2904,11 +2907,13 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR3_XXS; ++l) { const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); + const uint32_t signs = unpack_ksigns(aux32 >> (7*l)); - const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F)); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); - const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 6baab1176ff..ab803aca21b 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -94,6 +94,15 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con #endif } +static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) { + // v is a 7 bit int, with the 8th sign being encodable as popcnt + // with xor we can "correct" the bit instead of having to mask + const uint32_t p = __popc(v) & 1; + const uint32_t s = v ^ p << 7; + // broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors + return s * 0x01010101; +} + // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q @@ -905,22 +914,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( int sumi = 0; #pragma unroll for (int k0 = 0; k0 < 8; k0 += 2) { - const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]); - const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F]; + const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]]; + const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2)); - const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); - const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0); sumi = ggml_cuda_dp4a(grid0, u0, sumi); - const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); - const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1); const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1); sumi = ggml_cuda_dp4a(grid1, u1, sumi); } - const int ls = aux32 >> 28; - sumi = (ls*sumi + sumi/2)/4; + const int ls = aux32 >> 27 | 1; // (scale * 2 + 1) + sumi = sumi * ls / 8; // (sumi * scale + sumi / 2) / 4 const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds); return d * sumi; } @@ -942,13 +951,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( int sumi1 = 0; #pragma unroll for (int l0 = 0; l0 < 8; l0 += 2) { - const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF)); - const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9)); - - const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF]; + const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); if (l0 < 4) { @@ -1028,13 +1039,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( #pragma unroll for (int l0 = 0; l0 < 8; l0 += 2) { const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]); + const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2)); - const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F)); - - const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); sumi = ggml_cuda_dp4a(grid_l, u0, sumi); From f8f7c1d8918114f2c2ee6a5c249ccf61a0298c45 Mon Sep 17 00:00:00 2001 From: abhijain1204fujitsu <139222713+abhijain1204fujitsu@users.noreply.github.com> Date: Mon, 16 Feb 2026 12:08:43 +0530 Subject: [PATCH 164/831] ggml: aarch64: Implement SVE in Gemm q4_k 8x8 q8_k Kernel (llama/19132) * Updated repack.cpp * Updated repack.cpp * Updated repack.cpp * Added if condition to support only vector length 256. * Changed the format removed comments and duplicate variable * If SVE 256 not present then was using generic function to compute, hence slowing the performance. So added code if SVE 256 is not present then use NEON code. * Code format change suggestion --------- Co-authored-by: Vithule, Prashant --- ggml/src/ggml-cpu/arch/arm/repack.cpp | 310 ++++++++++++++++++++++++++ 1 file changed, 310 insertions(+) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index fd05c609f7e..3a3b32efb2b 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -3226,6 +3226,316 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, UNUSED(ncols_interleaved); UNUSED(blocklen); +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntb() * 8 == 256) { + constexpr int q8_k_blocklen = 4; + const svuint8_t m4b_1 = svdup_n_u8(0x0f); + // 8 accumulators: 2 row pairs × 4 col pairs + svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67; + uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 }; + svbool_t pg = svptrue_pat_b32(SV_VL8); + svuint32_t idx = svld1(pg, idx_arr); + + static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7}; + svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data); + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + acc_f32_01 = svdup_n_f32(0); + acc_f32_23 = svdup_n_f32(0); + acc_f32_45 = svdup_n_f32(0); + acc_f32_67 = svdup_n_f32(0); + + for (int b = 0; b < nb; b++) { + // bsums pairs belongs to the same q8_k subblock + // 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum + const int16x8_t bsums[4]{ + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + + int32_t bsums_arr32[4][8]; + + for (int q8_row = 0; q8_row < 4; q8_row++) { + int16x8_t v16 = bsums[q8_row]; + + // low 4 + int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16)); + vst1q_s32(&bsums_arr32[q8_row][0], v32_lo); + + // high 4 + int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16)); + vst1q_s32(&bsums_arr32[q8_row][4], v32_hi); + } + + svint32_t sb_acc_0 = svdup_n_s32(0); + svint32_t sb_acc_2 = svdup_n_s32(0); + + svint32_t acc_00 = svdup_n_s32(0); + svint32_t acc_11 = svdup_n_s32(0); + svint32_t acc_22 = svdup_n_s32(0); + svint32_t acc_33 = svdup_n_s32(0); + svint32_t acc_44 = svdup_n_s32(0); + svint32_t acc_55 = svdup_n_s32(0); + svint32_t acc_66 = svdup_n_s32(0); + svint32_t acc_77 = svdup_n_s32(0); + + svint32_t bias_acc_00 = svdup_n_s32(0); + svint32_t bias_acc_22 = svdup_n_s32(0); + svint32_t bias_acc_44 = svdup_n_s32(0); + svint32_t bias_acc_66 = svdup_n_s32(0); + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3; + svint32_t q4sb_mins_0, q4sb_mins_1; + { + // 2-superblock I am working on + const int offset = sb * 24 + 0 * 12; + const uint8_t * scales_in = &q4_ptr[b].scales[offset]; + + const int offset1 = sb * 24 + 12; + const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1]; + + constexpr uint32_t kmask1 = 0x3f3f3f3f; + constexpr uint32_t kmask2 = 0x0f0f0f0f; + constexpr uint32_t kmask3 = 0x03030303; + constexpr uint8_t scales_size = 12; + + uint32_t sm[3]; + memcpy(sm, scales_in, scales_size); + + uint32_t sm1[3]; + memcpy(sm1, scales_in1, scales_size); + + const uint32_t mins_0_3 = sm[1] & kmask1; + const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4); + + const uint32_t mins_0_3_1 = sm1[1] & kmask1; + const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4); + + svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7)); + svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1)); + + /* reinterpret u32 → u8 */ + svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp); + svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1); + + /* widen u8 → u16->u32 (lower half only) */ + svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8)); + svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1)); + + q4sb_mins_0 = svreinterpret_s32_u32(mins_u16); + q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1); + + uint32_t scales_u32_0 = sm[0] & kmask1; + uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4); + uint32_t scales_u32_2 = sm1[0] & kmask1; + uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4); + + svuint32_t S01 = svdup_n_u32(scales_u32_0); + svuint32_t S23 = svdup_n_u32(scales_u32_1); + svuint32_t R01 = svdup_n_u32(scales_u32_2); + svuint32_t R23 = svdup_n_u32(scales_u32_3); + + svint8_t S01_b = svreinterpret_s8_u32(S01); + svint8_t S23_b = svreinterpret_s8_u32(S23); + svint8_t R01_b = svreinterpret_s8_u32(R01); + svint8_t R23_b = svreinterpret_s8_u32(R23); + + svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b))); + svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b))); + svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b))); + svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b))); + + block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx); + block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx); + block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx); + block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx); + } + + const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256; + + // Load 32-byte per row pair, 1 subblock each time + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements + const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); + + svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112)); + svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144)); + svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176)); + svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208)); + + svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128)); + svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160)); + svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192)); + svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224)); + + // Q4s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + + sb_acc_0 = svdup_n_s32(0); + sb_acc_2 = svdup_n_s32(0); + + svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); + svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); + svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); + svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); + + svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4)); + svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4)); + svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4)); + svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4)); + + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2); + + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6); + + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3); + + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7); + + if(cp == 0) { + acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0); + acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0); + } + if(cp == 1) { + acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1); + acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1); + } + if(cp == 2) { + acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2); + acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2); + } + if(cp == 3) { + acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3); + acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3); + } + } + + bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0); + bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1); + + bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0); + bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1); + + bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0); + bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1); + + bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0); + bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1); + } // for sb + + + acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4)); + acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4)); + acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4)); + acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4)); + acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4)); + acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4)); + acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4)); + acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4)); + + svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1); + svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1); + + svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1); + svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1); + + // Broadcast q8 scalar + svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]); + + svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0))); + + svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0))); + + svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1); + acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[1]); + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1); + acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[2]); + + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1); + acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[3]); + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1); + acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1); + + } // for b + + // With the previous reorder, the tile is already in the correct memory layout. + // Predicate for exactly 4 lanes + svbool_t pg4 = svptrue_pat_b32(SV_VL4); + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + + if (i == 0 && j == 0) { + // acc_f32_0 → lower half of acc_f32_01 + svst1_f32(pg4, s + offset, acc_f32_01); + } else if (i == 0 && j == 1) { + // acc_f32_1 → upper half of acc_f32_01 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4)); + } else if (i == 1 && j == 0) { + // acc_f32_2 + svst1_f32(pg4, s + offset, acc_f32_23); + } else if (i == 1 && j == 1) { + // acc_f32_3 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4)); + } else if (i == 2 && j == 0) { + // acc_f32_4 + svst1_f32(pg4, s + offset, acc_f32_45); + } else if (i == 2 && j == 1) { + // acc_f32_5 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4)); + } else if (i == 3 && j == 0) { + // acc_f32_6 + svst1_f32(pg4, s + offset, acc_f32_67); + } else if (i == 3 && j == 1) { + // acc_f32_7 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4)); + } + } + } + } // for x + } // for y + return; + } +#endif // SVE compile-time end + #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) constexpr int q8_k_blocklen = 4; const uint8x16_t m4b = vdupq_n_u8(0x0f); From 5d9d72ec124dca3340785c51951bb9cc149bc1c5 Mon Sep 17 00:00:00 2001 From: Mario Limonciello Date: Mon, 16 Feb 2026 07:46:08 -0600 Subject: [PATCH 165/831] Adjust workaround for ROCWMMA_FATTN/GFX9 to only newer ROCm veresions (llama/19591) Avoids issues with ROCm 6.4.4. Closes: https://github.com/ggml-org/llama.cpp/issues/19580 Fixes: 6845f7f87 ("Add a workaround for compilation with ROCWMMA_FATTN and gfx9 (#19461)") Signed-off-by: Mario Limonciello (AMD) --- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 35735d48b2e..f19defbff93 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -63,7 +63,7 @@ static __global__ void flash_attn_ext_f16( constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); -#if defined(GGML_USE_HIP) +#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000 typedef wmma::fragment frag_a_K; typedef wmma::fragment frag_a_V; typedef wmma::fragment frag_b; @@ -135,7 +135,7 @@ static __global__ void flash_attn_ext_f16( __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; -#if defined(GGML_USE_HIP) +#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000 const _Float16 * K_h_f16 = reinterpret_cast(K_h); const _Float16 * V_h_f16 = reinterpret_cast(V_h); _Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ); From 5ee5748722ab9674c6b1c9147bea9ffe06d6ebf8 Mon Sep 17 00:00:00 2001 From: Judd <4046440+foldl@users.noreply.github.com> Date: Mon, 16 Feb 2026 23:43:34 +0800 Subject: [PATCH 166/831] ggml : make `ggml_is_view` as API (llama/19539) * make `ggml_is_view` as API * introduce `ggml_aux_is_view` as inline version for internal use. * change `ggml_aux_is_view` to `ggml_impl_is_view` --- ggml/include/ggml.h | 1 + ggml/src/ggml-alloc.c | 13 ++++--------- ggml/src/ggml-impl.h | 4 ++++ ggml/src/ggml.c | 4 ++++ 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index f759e2d5883..77af0e7fb6a 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -752,6 +752,7 @@ extern "C" { GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor); GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_view (const struct ggml_tensor * tensor); GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor); GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor); GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor); diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 41419b617bd..7f414b2311c 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -17,11 +17,6 @@ //#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__) #define AT_PRINTF(...) - -static bool ggml_is_view(const struct ggml_tensor * t) { - return t->view_src != NULL; -} - // ops that return true for this function must not use restrict pointers for their backend implementations bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { @@ -627,7 +622,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor GGML_ASSERT(buffer_id >= 0); struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); - if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) { + if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_impl_is_view(node)) { hn->allocated = true; assert(hn->addr.offset == 0); @@ -658,7 +653,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent); if (p_hn->n_children == 1 && p_hn->n_views == 0) { - if (ggml_is_view(parent)) { + if (ggml_impl_is_view(parent)) { struct ggml_tensor * view_src = parent->view_src; struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src); if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) { @@ -739,7 +734,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr // GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to // control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node // itself is never used and should not be considered a dependency - if (ggml_is_view(node) && node->op != GGML_OP_NONE) { + if (ggml_impl_is_view(node) && node->op != GGML_OP_NONE) { struct ggml_tensor * view_src = node->view_src; ggml_gallocr_hash_get(galloc, view_src)->n_views += 1; } @@ -806,7 +801,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated); if (p_hn->n_children == 0 && p_hn->n_views == 0) { - if (ggml_is_view(parent)) { + if (ggml_impl_is_view(parent)) { struct ggml_tensor * view_src = parent->view_src; struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src); view_src_hn->n_views -= 1; diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index baadfe9a7b3..e3714b38a6a 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -98,6 +98,10 @@ static bool ggml_op_is_empty(enum ggml_op op) { } } +static inline bool ggml_impl_is_view(const struct ggml_tensor * t) { + return t->view_src != NULL; +} + static inline float ggml_compute_softplus_f32(float input) { return (input > 20.0f) ? input : logf(1 + expf(input)); } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e2a6ff67be7..ed819eaa4c5 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1496,6 +1496,10 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso (t0->nb[3] == t1->nb[3]); } +bool ggml_is_view(const struct ggml_tensor * t) { + return ggml_impl_is_view(t); +} + // check if t1 can be represented as a repetition of t0 bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); From cf4bd07028c007d42aeaa1f987a8159f7cb0cc92 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 17 Feb 2026 12:31:49 +0200 Subject: [PATCH 167/831] cuda : enable CUDA graphs for MMID 1 <= BS <= 4 (llama/19645) * cuda : enable CUDA graphs for MMID BS <= 4 * cont : add stream capture check Co-authored-by: Oliver Simons * cont : add MMVQ_MMID_MAX_BATCH_SIZE --------- Co-authored-by: Oliver Simons --- ggml/src/ggml-cuda/ggml-cuda.cu | 43 ++++++++------------------------- ggml/src/ggml-cuda/mmvq.cuh | 1 + 2 files changed, 11 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index bed5c71a1bd..ffa35eeb654 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2278,11 +2278,12 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + // [TAG_MUL_MAT_ID_CUDA_GRAPHS] if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE); if (ne2 <= MMVQ_MAX_BATCH_SIZE) { if (ggml_is_quantized(src0->type)) { - if (ne2 <= 4) { + if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) { ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); return; } @@ -2305,6 +2306,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } + // note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization + // TODO: add asserts to verify this. should work with CUDA, HIP, etc. cudaStream_t stream = ctx.stream(); GGML_ASSERT(nb12 % nb11 == 0); @@ -2865,15 +2868,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { bool use_cuda_graph = true; // Loop over nodes in GGML graph to obtain info needed for CUDA graph - const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; - const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj"; - const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased"; - const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased"; - const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased"; - const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out"; - const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d"; - const std::string delta_net_prefix = "dnet_add"; - for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2888,31 +2882,14 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { #endif } - if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) { - use_cuda_graph = false; // This node type is not supported by CUDA graph capture -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); -#endif - } - - if (node->op == GGML_OP_ADD && - node->src[1] && node->src[1]->ne[1] > 1 && - (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && - (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) && - strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 && - strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && - strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 && - strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 && - strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0 && - strncmp(node->name, delta_net_prefix.c_str(), delta_net_prefix.size()) != 0) { - // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation - // by means of matching node names. See - // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and - // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773, - // Generally, changes in batch size or context size can cause changes to the grid size of some kernels. + // [TAG_MUL_MAT_ID_CUDA_GRAPHS] + if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) { + // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs + // TODO: figure out a way to enable for larger batch sizes, without hurting performance + // ref: https://github.com/ggml-org/llama.cpp/pull/18958 use_cuda_graph = false; #ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); #endif } diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 4bb10cfaec2..8a154631f69 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -1,6 +1,7 @@ #include "common.cuh" #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. +#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr); From 58855d08c21baca953b9958c4b2a2b71bd5a2e47 Mon Sep 17 00:00:00 2001 From: Talha Can Havadar Date: Tue, 17 Feb 2026 12:22:46 +0100 Subject: [PATCH 168/831] ggml: ggml-cpu: force-no-lto-for-cpu-feats (llama/19609) When LTO enabled in build environments it forces all builds to have LTO in place. But feature detection logic is fragile, and causing Illegal instruction errors with lto. This disables LTO for the feature detection code to prevent cross-module optimization from inlining architecture-specific instructions into the score function. Without this, LTO can cause SIGILL when loading backends on older CPUs (e.g., loading power10 backend on power9 crashes before feature check runs). --- ggml/src/ggml-cpu/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 43d6f7f54f7..3dc948e4d8e 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -9,6 +9,11 @@ function(ggml_add_cpu_backend_features cpu_name arch) target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN}) target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED) set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) + # Disable LTO for the feature detection code to prevent cross-module optimization + # from inlining architecture-specific instructions into the score function. + # Without this, LTO can cause SIGILL when loading backends on older CPUs + # (e.g., loading power10 backend on power9 crashes before feature check runs). + target_compile_options(${GGML_CPU_FEATS_NAME} PRIVATE -fno-lto) target_link_libraries(${cpu_name} PRIVATE ${GGML_CPU_FEATS_NAME}) endfunction() From 6fadc749a98d07cfb4250a91a064d93c0ec38818 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Tue, 17 Feb 2026 13:56:09 -0800 Subject: [PATCH 169/831] opencl: optimize mean and sum_row kernels (llama/19614) * opencl: optimize mean and sum_row kernels * opencl: add comment for max subgroups * opencl: format --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/ggml-opencl.cpp | 32 ++++-- ggml/src/ggml-opencl/kernels/mean.cl | 127 ++++++++++++++++++++--- ggml/src/ggml-opencl/kernels/sum_rows.cl | 127 ++++++++++++++++++++--- 3 files changed, 251 insertions(+), 35 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index ae3f79fd0d6..3dd12e177f3 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -484,7 +484,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_scale_f32, kernel_scale_f32_4; cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4; cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4; - cl_kernel kernel_mean_f32; + cl_kernel kernel_mean_f32, kernel_mean_f32_4; cl_kernel kernel_silu, kernel_silu_4; cl_kernel kernel_gelu, kernel_gelu_4; cl_kernel kernel_gelu_erf, kernel_gelu_erf_4; @@ -543,7 +543,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_solve_tri_f32; cl_kernel kernel_im2col_f32, kernel_im2col_f16; cl_kernel kernel_argsort_f32_i32; - cl_kernel kernel_sum_rows_f32; + cl_kernel kernel_sum_rows_f32, kernel_sum_rows_f32_4; cl_kernel kernel_repeat_f32; cl_kernel kernel_pad; cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc; @@ -1837,6 +1837,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mean_f32_4 = clCreateKernel(prog, "kernel_mean_f32_4", &err), err)); CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); @@ -1874,6 +1875,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sum_rows_f32_4 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32_4", &err), err)); GGML_LOG_CONT("."); } @@ -3587,7 +3589,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: - return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_FLASH_ATTN_EXT: { const ggml_tensor * q = op->src[0]; @@ -6400,7 +6402,6 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const GGML_UNUSED(src1); GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - GGML_ASSERT(ggml_is_contiguous(src0)); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -6423,7 +6424,14 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const const cl_ulong nb2 = dst->nb[2]; const cl_ulong nb3 = dst->nb[3]; - cl_kernel kernel = backend_ctx->kernel_mean_f32; + cl_kernel kernel; + + const bool is_c4 = ne00 % 4 == 0; + if (is_c4) { + kernel = backend_ctx->kernel_mean_f32_4; + } else { + kernel = backend_ctx->kernel_mean_f32; + } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -6440,7 +6448,7 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3)); - size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03}; + size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)64, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); @@ -11088,7 +11096,6 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c GGML_UNUSED(src1); GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - GGML_ASSERT(ggml_is_contiguous(src0)); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -11111,7 +11118,14 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c const cl_ulong nb2 = dst->nb[2]; const cl_ulong nb3 = dst->nb[3]; - cl_kernel kernel = backend_ctx->kernel_sum_rows_f32; + cl_kernel kernel; + + const bool is_c4 = ne00 % 4 == 0; + if (is_c4) { + kernel = backend_ctx->kernel_sum_rows_f32_4; + } else { + kernel = backend_ctx->kernel_sum_rows_f32; + } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -11128,7 +11142,7 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3)); - size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03}; + size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)64, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); diff --git a/ggml/src/ggml-opencl/kernels/mean.cl b/ggml/src/ggml-opencl/kernels/mean.cl index 5c3e8bcd863..7c7e0a587ee 100644 --- a/ggml/src/ggml-opencl/kernels/mean.cl +++ b/ggml/src/ggml-opencl/kernels/mean.cl @@ -1,8 +1,13 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +// Most devices have max workgroup size of 1024, so this is enough for subgroup +// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes +#define MAX_SUBGROUPS 64 kernel void kernel_mean_f32( - global float * src0, + global char * src0, ulong offset0, - global float * dst, + global char * dst, ulong offsetd, int ne00, int ne01, @@ -15,25 +20,121 @@ kernel void kernel_mean_f32( ulong nb2, ulong nb3 ) { - src0 = (global float *)((global char *)src0 + offset0); - dst = (global float *)((global char *)dst + offsetd); + src0 = src0 + offset0; + dst = dst + offsetd; - int i3 = get_global_id(2); - int i2 = get_global_id(1); - int i1 = get_global_id(0); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int lid = get_local_id(0); + const int lsize = get_local_size(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + __local float lmem[MAX_SUBGROUPS]; if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { return; } - global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); - global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + if(sg_id == 0){ + lmem[sg_lid] = 0.0f; + } - float row_sum = 0; + global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3); - for (int i0 = 0; i0 < ne00; i0++) { - row_sum += src_row[i0]; + float sumf = 0.0f; + + for (int i0 = lid; i0 < ne00; i0 += lsize) { + sumf += src_row[i0]; } - dst_row[0] = row_sum / ne00; + sumf = sub_group_reduce_add(sumf); + + barrier(CLK_LOCAL_MEM_FENCE); + + if(sg_lid == 0){ + lmem[sg_id] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + sumf = lmem[sg_lid]; + sumf = sub_group_reduce_add(sumf); + + if (lid == 0) { + dst_row[0] = sumf / ne00; + } +} + +kernel void kernel_mean_f32_4( + global char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int lid = get_local_id(0); + const int lsize = get_local_size(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + __local float lmem[MAX_SUBGROUPS]; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + if(sg_id == 0){ + lmem[sg_lid] = 0.0f; + } + + global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3); + + float4 sum_vec = (float4)0.0f; + + for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) { + sum_vec += src_row[i0]; + } + + float sumf = dot(sum_vec, (float4)(1.0f)); + sumf = sub_group_reduce_add(sumf); + + barrier(CLK_LOCAL_MEM_FENCE); + + if(sg_lid == 0){ + lmem[sg_id] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + sumf = lmem[sg_lid]; + sumf = sub_group_reduce_add(sumf); + + if (lid == 0) { + dst_row[0] = sumf / ne00; + } } diff --git a/ggml/src/ggml-opencl/kernels/sum_rows.cl b/ggml/src/ggml-opencl/kernels/sum_rows.cl index c5f7c570f95..84630aa8a30 100644 --- a/ggml/src/ggml-opencl/kernels/sum_rows.cl +++ b/ggml/src/ggml-opencl/kernels/sum_rows.cl @@ -1,8 +1,13 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +// Most devices have max workgroup size of 1024, so this is enough for subgroup +// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes +#define MAX_SUBGROUPS 64 kernel void kernel_sum_rows_f32( - global float * src0, + global char * src0, ulong offset0, - global float * dst, + global char * dst, ulong offsetd, int ne00, int ne01, @@ -15,25 +20,121 @@ kernel void kernel_sum_rows_f32( ulong nb2, ulong nb3 ) { - src0 = (global float *)((global char *)src0 + offset0); - dst = (global float *)((global char *)dst + offsetd); + src0 = src0 + offset0; + dst = dst + offsetd; - int i3 = get_global_id(2); - int i2 = get_global_id(1); - int i1 = get_global_id(0); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int lid = get_local_id(0); + const int lsize = get_local_size(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + __local float lmem[MAX_SUBGROUPS]; if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { return; } - global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); - global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + if(sg_id == 0){ + lmem[sg_lid] = 0.0f; + } - float row_sum = 0; + global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3); - for (int i0 = 0; i0 < ne00; i0++) { - row_sum += src_row[i0]; + float sumf = 0.0f; + + for (int i0 = lid; i0 < ne00; i0 += lsize) { + sumf += src_row[i0]; } - dst_row[0] = row_sum; + sumf = sub_group_reduce_add(sumf); + + barrier(CLK_LOCAL_MEM_FENCE); + + if(sg_lid == 0){ + lmem[sg_id] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + sumf = lmem[sg_lid]; + sumf = sub_group_reduce_add(sumf); + + if (lid == 0) { + dst_row[0] = sumf; + } +} + +kernel void kernel_sum_rows_f32_4( + global char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int lid = get_local_id(0); + const int lsize = get_local_size(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + __local float lmem[MAX_SUBGROUPS]; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + if(sg_id == 0){ + lmem[sg_lid] = 0.0f; + } + + global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3); + + float4 sum_vec = (float4)0.0f; + + for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) { + sum_vec += src_row[i0]; + } + + float sumf = dot(sum_vec, (float4)(1.0f)); + sumf = sub_group_reduce_add(sumf); + + barrier(CLK_LOCAL_MEM_FENCE); + + if(sg_lid == 0){ + lmem[sg_id] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + sumf = lmem[sg_lid]; + sumf = sub_group_reduce_add(sumf); + + if (lid == 0) { + dst_row[0] = sumf; + } } From 51ce7de94ca9508f2baf05c82f027ed393e019c5 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Tue, 17 Feb 2026 14:47:18 -0800 Subject: [PATCH 170/831] opencl: refactor expm1 and softplus (llama/19404) * opencl: refactor expm1 * opencl: refactor softplus * opencl: use h for half literals --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/ggml-opencl.cpp | 310 +++++++++++------------ ggml/src/ggml-opencl/kernels/expm1.cl | 143 +++++++---- ggml/src/ggml-opencl/kernels/softplus.cl | 148 ++++++----- 3 files changed, 319 insertions(+), 282 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 3dd12e177f3..3da022ed86c 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -548,10 +548,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_pad; cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc; cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc; - cl_kernel kernel_expm1_f32_nd; - cl_kernel kernel_expm1_f16_nd; - cl_kernel kernel_softplus_f32_nd; - cl_kernel kernel_softplus_f16_nd; + cl_kernel kernel_expm1_f32, kernel_expm1_f32_4, kernel_expm1_f32_nc; + cl_kernel kernel_expm1_f16, kernel_expm1_f16_4, kernel_expm1_f16_nc; + cl_kernel kernel_softplus_f32, kernel_softplus_f32_4, kernel_softplus_f32_nc; + cl_kernel kernel_softplus_f16, kernel_softplus_f16_4, kernel_softplus_f16_nc; cl_kernel kernel_upscale; cl_kernel kernel_upscale_bilinear; cl_kernel kernel_concat_f32; @@ -1980,20 +1980,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #else const std::string kernel_src = read_file("expm1.cl"); #endif - cl_program prog; - if (!kernel_src.empty()) { - prog = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_expm1_f32_nd = clCreateKernel(prog, "kernel_expm1_f32_nd", &err), err)); - CL_CHECK((backend_ctx->kernel_expm1_f16_nd = clCreateKernel(prog, "kernel_expm1_f16_nd", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: expm1 kernel source not found or empty. Expm1 operation will not be available.\n"); - prog = nullptr; - backend_ctx->kernel_expm1_f32_nd = nullptr; - backend_ctx->kernel_expm1_f16_nd = nullptr; - } + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_expm1_f32 = clCreateKernel(prog, "kernel_expm1_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f32_4 = clCreateKernel(prog, "kernel_expm1_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f32_nc = clCreateKernel(prog, "kernel_expm1_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f16 = clCreateKernel(prog, "kernel_expm1_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f16_4 = clCreateKernel(prog, "kernel_expm1_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f16_nc = clCreateKernel(prog, "kernel_expm1_f16_nc", &err), err)); CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } // softplus @@ -2005,20 +2001,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #else const std::string kernel_src = read_file("softplus.cl"); #endif - cl_program prog; - if (!kernel_src.empty()) { - prog = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_softplus_f32_nd = clCreateKernel(prog, "kernel_softplus_f32_nd", &err), err)); - CL_CHECK((backend_ctx->kernel_softplus_f16_nd = clCreateKernel(prog, "kernel_softplus_f16_nd", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: softplus kernel source not found or empty. Softplus operation will not be available.\n"); - prog = nullptr; - backend_ctx->kernel_softplus_f32_nd = nullptr; - backend_ctx->kernel_softplus_f16_nd = nullptr; - } + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_softplus_f32 = clCreateKernel(prog, "kernel_softplus_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f32_4 = clCreateKernel(prog, "kernel_softplus_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f32_nc = clCreateKernel(prog, "kernel_softplus_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f16 = clCreateKernel(prog, "kernel_softplus_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f16_4 = clCreateKernel(prog, "kernel_softplus_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f16_nc = clCreateKernel(prog, "kernel_softplus_f16_nc", &err), err)); CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } // upscale @@ -3465,11 +3457,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_TANH: return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_UNARY_OP_EXPM1: - return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || - (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16); + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_UNARY_OP_SOFTPLUS: - return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || - (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16); + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; default: return false; } @@ -7396,18 +7386,8 @@ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, cons ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - cl_ulong offset0_abs = extra0->offset + src0->view_offs; - cl_ulong offsetd_abs = extrad->offset + dst->view_offs; - - cl_kernel kernel; - if (dst->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_expm1_f32_nd; - } else if (dst->type == GGML_TYPE_F16) { - kernel = backend_ctx->kernel_expm1_f16_nd; - } else { - GGML_ASSERT(false && "Unsupported type for ggml_cl_expm1"); - } - GGML_ASSERT(kernel != nullptr); + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; @@ -7419,70 +7399,74 @@ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, cons const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3]; - const int ne10 = dst->ne[0]; - const int ne11 = dst->ne[1]; - const int ne12 = dst->ne[2]; - const int ne13 = dst->ne[3]; + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - const cl_ulong nb10 = dst->nb[0]; - const cl_ulong nb11 = dst->nb[1]; - const cl_ulong nb12 = dst->nb[2]; - const cl_ulong nb13 = dst->nb[3]; + cl_kernel kernel; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs)); + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_expm1_f32_4; + } else { + kernel = backend_ctx->kernel_expm1_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_expm1_f32; + } else { + kernel = backend_ctx->kernel_expm1_f16; + } + } - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13)); - - size_t global_work_size[3]; - if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements - return; - } - global_work_size[0] = (size_t)ne10; - global_work_size[1] = (size_t)ne11; - global_work_size[2] = (size_t)ne12; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } - size_t lws0 = 16, lws1 = 4, lws2 = 1; - if (ne10 < 16) lws0 = ne10; - if (ne11 < 4) lws1 = ne11; - if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + } else { + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_expm1_f32_nc; + } else { + kernel = backend_ctx->kernel_expm1_f16_nc; + } - while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + int nth = 64; - size_t local_work_size[] = {lws0, lws1, lws2}; + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; - size_t* local_work_size_ptr = local_work_size; - if (!backend_ctx->non_uniform_workgroups) { - if (global_work_size[0] % local_work_size[0] != 0 || - global_work_size[1] % local_work_size[1] != 0 || - global_work_size[2] % local_work_size[2] != 0) { - local_work_size_ptr = NULL; - } + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } - if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return; - - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -7498,18 +7482,8 @@ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, c ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - cl_ulong offset0_abs = extra0->offset + src0->view_offs; - cl_ulong offsetd_abs = extrad->offset + dst->view_offs; - - cl_kernel kernel; - if (dst->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_softplus_f32_nd; - } else if (dst->type == GGML_TYPE_F16) { - kernel = backend_ctx->kernel_softplus_f16_nd; - } else { - GGML_ASSERT(false && "Unsupported type for ggml_cl_softplus"); - } - GGML_ASSERT(kernel != nullptr); + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; @@ -7521,70 +7495,74 @@ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, c const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3]; - const int ne10 = dst->ne[0]; - const int ne11 = dst->ne[1]; - const int ne12 = dst->ne[2]; - const int ne13 = dst->ne[3]; + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - const cl_ulong nb10 = dst->nb[0]; - const cl_ulong nb11 = dst->nb[1]; - const cl_ulong nb12 = dst->nb[2]; - const cl_ulong nb13 = dst->nb[3]; + cl_kernel kernel; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs)); + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_softplus_f32_4; + } else { + kernel = backend_ctx->kernel_softplus_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_softplus_f32; + } else { + kernel = backend_ctx->kernel_softplus_f16; + } + } - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13)); - - size_t global_work_size[3]; - if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements - return; - } - global_work_size[0] = (size_t)ne10; - global_work_size[1] = (size_t)ne11; - global_work_size[2] = (size_t)ne12; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } - size_t lws0 = 16, lws1 = 4, lws2 = 1; - if (ne10 < 16) lws0 = ne10; - if (ne11 < 4) lws1 = ne11; - if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + } else { + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_softplus_f32_nc; + } else { + kernel = backend_ctx->kernel_softplus_f16_nc; + } - while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + int nth = 64; - size_t local_work_size[] = {lws0, lws1, lws2}; + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; - size_t* local_work_size_ptr = local_work_size; - if (!backend_ctx->non_uniform_workgroups) { - if (global_work_size[0] % local_work_size[0] != 0 || - global_work_size[1] % local_work_size[1] != 0 || - global_work_size[2] % local_work_size[2] != 0) { - local_work_size_ptr = NULL; - } + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } - if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return; - - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) { diff --git a/ggml/src/ggml-opencl/kernels/expm1.cl b/ggml/src/ggml-opencl/kernels/expm1.cl index 126298a2cdb..05442ac2043 100644 --- a/ggml/src/ggml-opencl/kernels/expm1.cl +++ b/ggml/src/ggml-opencl/kernels/expm1.cl @@ -3,80 +3,111 @@ //------------------------------------------------------------------------------ // expm1 //------------------------------------------------------------------------------ -kernel void kernel_expm1_f32_nd( - global void * p_src0_base, - ulong off_src0_abs, - global void * p_dst_base, - ulong off_dst_abs, - int ne00, - int ne01, - int ne02, - int ne03, + +kernel void kernel_expm1_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f; +} + +kernel void kernel_expm1_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f; +} + +kernel void kernel_expm1_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h; +} + +kernel void kernel_expm1_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h; +} + +kernel void kernel_expm1_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13 + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = exp(*src_val_ptr) - 1; - } + *y = exp(*x) - 1.0f; } } -kernel void kernel_expm1_f16_nd( - global void * p_src0_base, - ulong off_src0_abs, - global void * p_dst_base, - ulong off_dst_abs, - int ne00, - int ne01, - int ne02, - int ne03, +kernel void kernel_expm1_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13 + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = exp(*src_val_ptr) - 1; - } + *y = exp(*x) - 1.0f; } } diff --git a/ggml/src/ggml-opencl/kernels/softplus.cl b/ggml/src/ggml-opencl/kernels/softplus.cl index 033766e2e07..6f8b7474165 100644 --- a/ggml/src/ggml-opencl/kernels/softplus.cl +++ b/ggml/src/ggml-opencl/kernels/softplus.cl @@ -3,86 +3,114 @@ //------------------------------------------------------------------------------ // softplus //------------------------------------------------------------------------------ -inline float softplus_f32(float x){ - float ax = fabs(x); - float m = fmax(x, 0.0f); - return log1p(exp(-ax)) + m; + +kernel void kernel_softplus_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)])); +} + +kernel void kernel_softplus_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)])); +} + +kernel void kernel_softplus_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + const float x = convert_float(src0[get_global_id(0)]); + dst[get_global_id(0)] = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x))); +} + +kernel void kernel_softplus_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + const float4 x = convert_float4(src0[get_global_id(0)]); + dst[get_global_id(0)] = convert_half4_rte((x > 20.0f) ? x : log(1.0f + exp(x))); } -kernel void kernel_softplus_f32_nd( - global void * p_src0_base, - ulong off_src0_abs, - global void * p_dst_base, - ulong off_dst_abs, - int ne00, - int ne01, - int ne02, - int ne03, +kernel void kernel_softplus_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13 + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = softplus_f32(*src_val_ptr); - } + *y = (*x > 20.0f) ? *x : log(1.0f + exp(*x)); } } -kernel void kernel_softplus_f16_nd( - global void * p_src0_base, - ulong off_src0_abs, - global void * p_dst_base, - ulong off_dst_abs, - int ne00, - int ne01, - int ne02, - int ne03, +kernel void kernel_softplus_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13 + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * hx = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * hy = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = (half)(softplus_f32((float)(*src_val_ptr))); - } + const float x = convert_float(*hx); + *hy = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x))); } } From f1da0a26f5adefd33a5d5b88ebfc9350ec2afa67 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 18 Feb 2026 01:47:10 -0800 Subject: [PATCH 171/831] vulkan: split mul_mat into multiple dispatches to avoid overflow (llama/19509) * vulkan: split mul_mat into multiple dispatches to avoid overflow The batch dimensions can be greater than the max workgroup count limit, in which case we need to split into multiple dispatches and pass the base index through a push constant. Fall back for the less common p021 and nc variants. * address feedback --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 80 ++++++++++++------- .../vulkan-shaders/mul_mat_vec_base.glsl | 5 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 8 +- .../vulkan-shaders/mul_mm_cm2.comp | 8 +- 4 files changed, 66 insertions(+), 35 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 114992da08d..a8840a0773b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -944,6 +944,7 @@ struct vk_mat_mat_push_constants { uint32_t M; uint32_t N; uint32_t K; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t base_work_group_z; uint32_t num_batches; uint32_t k_split; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; uint32_t padded_N; @@ -963,6 +964,7 @@ struct vk_mat_vec_push_constants { uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t fusion_flags; + uint32_t base_work_group_y; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; @@ -6773,8 +6775,16 @@ static void ggml_vk_matmul( uint32_t padded_n) { VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); if (split_k == 1) { - const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch }); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2])); + + uint32_t base_work_group_z = 0; + while (base_work_group_z < batch) { + uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + + const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z }); + base_work_group_z += groups_z; + } return; } @@ -6788,9 +6798,17 @@ static void ggml_vk_matmul( uint32_t k_split = CEIL_DIV(k, split_k); k_split = ROUNDUP_POW2(k_split, 256); - const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; - // Make sure enough workgroups get assigned for split k to work - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2])); + + uint32_t base_work_group_z = 0; + while (base_work_group_z < batch) { + uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; + // Make sure enough workgroups get assigned for split k to work + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z }); + base_work_group_z += groups_z; + } ggml_vk_sync_buffers(ctx, subctx); const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 }); @@ -7186,7 +7204,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (qx_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } @@ -7484,7 +7501,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (quantize_y) { ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); } vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]); @@ -7579,22 +7595,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1; } - // compute - const vk_mat_vec_push_constants pc = { - (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, - stride_batch_x, stride_batch_y, stride_batch_d, - fusion_flags, - (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, - }; - ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, - { - d_X, - d_Y, - d_D, - d_F0, - d_F1, - }, - pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); + ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1])); + + uint32_t base_work_group_y = 0; + while (base_work_group_y < ne12 * ne13) { + + uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + const vk_mat_vec_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + stride_batch_x, stride_batch_y, stride_batch_d, + fusion_flags, base_work_group_y, + (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, + }; + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { + d_X, + d_Y, + d_D, + d_F0, + d_F1, + }, + pc, { groups_x, groups_y, groups_z }); + base_work_group_y += groups_y; + } if (x_non_contig) { ctx->prealloc_x_need_sync = true; @@ -7832,10 +7855,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c src1->nb[2] <= src1->nb[1] && src1->nb[1] <= src1->nb[3] && src0->ne[3] == 1 && - src1->ne[3] == 1) { + src1->ne[3] == 1 && + src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] && + src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) { ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx); } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 && - !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { + !ggml_is_permuted(src0) && !ggml_is_permuted(src1) && + src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] && + src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] && + src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) { ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx); // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) // when ne12 and ne13 are one. @@ -11560,7 +11588,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t } } - ggml_pipeline_request_descriptor_sets(ctx, p, num_it); if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); @@ -12069,7 +12096,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, // y[i] = i % k; } - ggml_pipeline_request_descriptor_sets(ctx, p, num_it); if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index 4f2c7003065..4aeda68c7f2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -32,6 +32,7 @@ layout (push_constant) uniform parameter uint expert_i1; uint nbi1; #else + uint base_work_group_y; uint ne02; uint ne12; uint broadcast2; @@ -45,9 +46,9 @@ uint expert_id; void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #ifdef MUL_MAT_ID - const uint expert_i0 = gl_GlobalInvocationID.y; + const uint expert_i0 = gl_WorkGroupID.y; #else - const uint batch_idx = gl_GlobalInvocationID.y; + const uint batch_idx = gl_WorkGroupID.y + p.base_work_group_y; #endif #ifndef MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 775e9a70f6d..79344d33005 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -90,6 +90,8 @@ layout (push_constant) uniform parameter uint nbi1; uint ne11; #else + uint base_work_group_z; + uint num_batches; uint k_split; uint ne02; uint ne12; @@ -139,7 +141,7 @@ void main() { const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.z; + const uint expert_idx = gl_WorkGroupID.z; if (ic * BN >= data_expert_count[expert_idx]) { return; } @@ -149,7 +151,7 @@ void main() { #endif #ifndef MUL_MAT_ID - const uint batch_idx = gl_GlobalInvocationID.z; + const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z; const uint i13 = batch_idx / p.ne12; const uint i12 = batch_idx % p.ne12; @@ -366,7 +368,7 @@ void main() { const uint dc = ic * BN + warp_c * WN; #ifndef MUL_MAT_ID - const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches; #endif #ifdef COOPMAT diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index b6614d2fc59..717d124e019 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -53,6 +53,8 @@ layout (push_constant) uniform parameter uint nbi1; uint ne11; #else + uint base_work_group_z; + uint num_batches; uint k_split; uint ne02; uint ne12; @@ -197,7 +199,7 @@ void main() { const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.z; + const uint expert_idx = gl_WorkGroupID.z; if (ic * BN >= data_expert_count[expert_idx]) { return; } @@ -215,7 +217,7 @@ void main() { #endif #ifndef MUL_MAT_ID - const uint batch_idx = gl_GlobalInvocationID.z; + const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z; const uint i13 = batch_idx / p.ne12; const uint i12 = batch_idx % p.ne12; @@ -255,7 +257,7 @@ void main() { #else uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K); uint pos_b = batch_idx * p.batch_stride_b; - uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches; #endif uint stride_a = p.stride_a / QUANT_K; From fc7a78f4d8662f8d970629622f8a4f15bd90719b Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 25 Feb 2026 09:33:32 +0200 Subject: [PATCH 172/831] ggml webgpu: shader library organization (llama/19530) * Basic JIT compilation for mul_mat, get_rows, and scale (ggml/17) * scale jit working * preliminary working jit for getrows and mulmat, needs refining * simplified mul_mat preprocessing switch statement * get_rows fixes, mul_mat refinement * formatted + last edits * removed some extraneous prints * fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish * small fix * some changes, working * get_rows and mul_mat jit fixed and working * Update formatting * formatting * Add header --------- Co-authored-by: Neha Abbas Co-authored-by: Reese Levine * Start work on all-encompassing shader library * refactor argmax, set_rows * Refactor all but flashattention, mat mul * flashattention and matrix multiplication moved to new format * clean up preprocessing * Formatting * remove duplicate constants * Split large shaders into multiple static strings --------- Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com> --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 1465 ++++++++++++----- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 840 +++------- .../wgsl-shaders/common_decls.tmpl | 137 +- .../ggml-webgpu/wgsl-shaders/embed_wgsl.py | 45 +- .../{get_rows.tmpl.wgsl => get_rows.wgsl} | 312 +--- .../{mul_mat.tmpl.wgsl => mul_mat.wgsl} | 312 +--- .../wgsl-shaders/mul_mat_decls.tmpl | 54 +- ...g_tile.tmpl.wgsl => mul_mat_reg_tile.wgsl} | 143 +- ...tmpl.wgsl => mul_mat_subgroup_matrix.wgsl} | 154 +- ...mul_mat_vec.tmpl.wgsl => mul_mat_vec.wgsl} | 163 +- .../{scale.tmpl.wgsl => scale.wgsl} | 45 +- 11 files changed, 1667 insertions(+), 2003 deletions(-) rename ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl => get_rows.wgsl} (83%) rename ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl => mul_mat.wgsl} (84%) rename ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl => mul_mat_reg_tile.wgsl} (55%) rename ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl => mul_mat_subgroup_matrix.wgsl} (66%) rename ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_vec.tmpl.wgsl => mul_mat_vec.wgsl} (61%) rename ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl => scale.wgsl} (78%) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 63f797f142d..0d5a818dacb 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1,11 +1,16 @@ #ifndef GGML_WEBGPU_SHADER_LIB_HPP #define GGML_WEBGPU_SHADER_LIB_HPP +#include "ggml-wgsl-shaders.hpp" #include "ggml.h" #include "pre_wgsl.hpp" +#include + +#include #include #include +#include #include #define GGML_WEBGPU_F16_SIZE_BYTES 2 @@ -18,17 +23,203 @@ #define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u -struct ggml_webgpu_processed_shader { - std::string wgsl; - std::string variant; - std::shared_ptr decisions; -}; +// Matrix multiplication parameters + +// Register tiling parameters +#define WEBGPU_MUL_MAT_TILE_M 8 +#define WEBGPU_MUL_MAT_TILE_N 8 +#define WEBGPU_MUL_MAT_WG_SIZE_M 8 +#define WEBGPU_MUL_MAT_WG_SIZE_N 8 +#define WEBGPU_MUL_MAT_TILE_K 32 + +// Subgroup matrix parameters +// The number of subgroups in the M dimension +#define WEBGPU_MUL_MAT_SUBGROUP_M 2 +// The number of subgroups in the N dimension +#define WEBGPU_MUL_MAT_SUBGROUP_N 2 +// The number of subgroup matrices each subgroup accumulates over +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 + +// Matrix-vector multiplication parameters +#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 +// Must be multiple of 4 to work with vectorized paths, and must divide +// mul_mat_vec wg size +#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 +#define WEBGPU_MUL_MAT_VEC_TILE_K 256 + +// default size for legacy matrix multiplication +#define WEBGPU_MUL_MAT_WG_SIZE 256 // Same hash combine function as in boost template inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) { seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } +struct ggml_webgpu_shader_lib_context { + ggml_tensor * src0; + ggml_tensor * src1; + ggml_tensor * src2; + ggml_tensor * src3; + ggml_tensor * src4; + ggml_tensor * dst; + + uint32_t max_wg_size; + size_t wg_mem_limit_bytes = 0; + bool inplace = false; + bool overlap = false; + bool supports_subgroup_matrix = false; + uint32_t sg_mat_m = 0; + uint32_t sg_mat_n = 0; + uint32_t sg_mat_k = 0; + uint32_t max_subgroup_size = 0; +}; + +struct webgpu_pipeline { + wgpu::ComputePipeline pipeline; + std::string name; + std::shared_ptr context = nullptr; +}; + +struct ggml_webgpu_generic_shader_decisions { + uint32_t wg_size = 0; +}; + +/** Argsort **/ + +struct ggml_webgpu_argsort_shader_lib_context { + uint32_t max_wg_size; + size_t wg_mem_limit_bytes; + int32_t order; +}; + +/** Set Rows **/ + +struct ggml_webgpu_set_rows_pipeline_key { + int dst_type; + int vec4; + int i64_idx; + + bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const { + return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx; + } +}; + +struct ggml_webgpu_set_rows_pipeline_key_hash { + size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.dst_type); + ggml_webgpu_hash_combine(seed, key.vec4); + ggml_webgpu_hash_combine(seed, key.i64_idx); + return seed; + } +}; + +struct ggml_webgpu_set_rows_shader_decisions { + bool vec4; + bool i64_idx; + uint32_t wg_size; +}; + +/** Get Rows **/ + +struct ggml_webgpu_get_rows_pipeline_key { + ggml_type src_type; + int vectorized; + + bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const { + return src_type == other.src_type && vectorized == other.vectorized; + } +}; + +struct ggml_webgpu_get_rows_pipeline_key_hash { + size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; + +/** Pad **/ +struct ggml_webgpu_pad_pipeline_key { + bool circular; + + bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; } +}; + +struct ggml_webgpu_pad_pipeline_key_hash { + size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.circular); + return seed; + } +}; + +/** Scale **/ + +struct ggml_webgpu_scale_pipeline_key { + int inplace; + + bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; } +}; + +struct ggml_webgpu_scale_pipeline_key_hash { + size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + +/** Binary **/ + +struct ggml_webgpu_binary_pipeline_key { + int type; + int op; + bool inplace; + bool overlap; + + bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { + return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap; + } +}; + +struct ggml_webgpu_binary_pipeline_key_hash { + size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.overlap); + return seed; + } +}; + +/** Unary **/ + +struct ggml_webgpu_unary_pipeline_key { + int type; + int op; + bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella + bool inplace; + + bool operator==(const ggml_webgpu_unary_pipeline_key & other) const { + return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace; + } +}; + +struct ggml_webgpu_unary_pipeline_key_hash { + size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.is_unary); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + /** FlashAttention */ struct ggml_webgpu_flash_attn_pipeline_key { @@ -100,439 +291,941 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } -static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t q_tile = context.sg_mat_m; - const size_t base_q_bytes = - (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!context.key.kv_direct) { - bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); - } - if (context.key.has_mask) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; -} - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_shader_lib_context & context) { - std::vector defines; - std::string variant = "flash_attn"; - - switch (context.key.kv_type) { - case GGML_TYPE_F32: - defines.push_back("KV_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("KV_F16"); - break; - case GGML_TYPE_Q4_0: - defines.push_back("KV_Q4_0"); - break; - case GGML_TYPE_Q8_0: - defines.push_back("KV_Q8_0"); - break; - default: - GGML_ABORT("Unsupported KV type for flash attention shader"); - } - variant += std::string("_") + ggml_type_name(context.key.kv_type); - - if (context.key.has_mask) { - defines.push_back("MASK"); - variant += "_mask"; - } - if (context.key.has_sinks) { - defines.push_back("SINKS"); - variant += "_sinks"; - } - if (context.key.uses_logit_softcap) { - defines.push_back("LOGIT_SOFTCAP"); - variant += "_lgsc"; - } - - if (context.key.kv_direct) { - defines.push_back("KV_DIRECT"); - variant += "_kvdirect"; - } - - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk); - - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); - // For now these are not part of the variant name - defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); - defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); - defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - - // Add chosen Q/KV tile sizes - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (context.key.kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - // Avoids having to use bounds-checks and decreasing performance for direct KV loads - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= context.sg_mat_n; - } - } - - defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); - - // workgroup size - uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; - result.decisions = decisions; - return result; -} +/** Matrix Multiplication **/ -/** Generic **/ +struct ggml_webgpu_legacy_mul_mat_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; -struct ggml_webgpu_generic_shader_lib_context { - int vec4; - uint32_t max_wg_size; + bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type; + } }; -struct ggml_webgpu_generic_shader_decisions { - uint32_t wg_size; +struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + return seed; + } }; -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_generic_shader_lib_context & context, - const std::string & base_variant) { - std::vector defines; - std::string variant = base_variant; +struct ggml_webgpu_mul_mat_vec_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + int vectorized; - if (context.vec4) { - defines.push_back("VEC4"); - variant += "_vec"; + bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized; } +}; - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - return result; -} +struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_vec_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; -/** Pad **/ +struct ggml_webgpu_mul_mat_vec_shader_decisions { + uint32_t wg_size; + uint32_t tile_k; + uint32_t outputs_per_wg; + uint32_t vec_size; +}; -struct ggml_webgpu_pad_pipeline_key { - bool circular; +struct ggml_webgpu_mul_mat_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + int vectorized; + int use_subgroup_matrix; - bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; } + bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && + use_subgroup_matrix == other.use_subgroup_matrix; + } }; -struct ggml_webgpu_pad_pipeline_key_hash { - size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const { +struct ggml_webgpu_mul_mat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const { size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.circular); + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix); return seed; } }; -struct ggml_webgpu_pad_shader_lib_context { - ggml_webgpu_pad_pipeline_key key; - uint32_t max_wg_size; +struct ggml_webgpu_mul_mat_shader_decisions { + uint32_t tile_k; + uint32_t wg_size_m; + uint32_t wg_size_n; + uint32_t wg_size; + uint32_t outputs_per_wg; + int use_subgroup_matrix; + + uint32_t tile_m; + uint32_t tile_n; + + // Subgroup matrix parameters + uint32_t subgroup_m; + uint32_t subgroup_n; + uint32_t subgroup_matrix_m; + uint32_t subgroup_matrix_n; + + uint32_t mul_mat_wg_size; }; -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_pad_shader_lib_context & context) { - std::vector defines; - std::string variant = "pad"; +class ggml_webgpu_shader_lib { + wgpu::Device device; + pre_wgsl::Preprocessor preprocessor; + + std::unordered_map sum_rows_pipelines; // key is fixed, no variants yet + std::unordered_map argmax_pipelines; // key is vec4 + std::unordered_map argsort_pipelines; // key is order + std::unordered_map argsort_merge_pipelines; // key is order + std::unordered_map cumsum_pipelines; // key is fixed, no variants yet + std::unordered_map + get_rows_pipelines; // src_type, vectorized + std::unordered_map + unary_pipelines; // type/op/inplace + std::unordered_map + scale_pipelines; // inplace + std::unordered_map + pad_pipelines; // circular/non-circular + std::unordered_map + binary_pipelines; // type/op/inplace/overlap + std::unordered_map + flash_attn_pipelines; + std::unordered_map + mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec) + std::unordered_map + mul_mat_vec_pipelines; // fast mat-vec (n==1) + std::unordered_map + mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + + std::unordered_map + set_rows_pipelines; + + public: + ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } + + webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { + auto it = sum_rows_pipelines.find(1); + if (it != sum_rows_pipelines.end()) { + return it->second; + } + std::vector defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - if (context.key.circular) { - defines.push_back("CIRCULAR"); - variant += "_circular"; + auto processed = preprocessor.preprocess(wgsl_sum_rows, defines); + sum_rows_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "sum_rows"); + return sum_rows_pipelines[1]; } - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) { + bool vec4 = context.src0->ne[0] % 4 == 0; - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; - return result; -} + auto it = argmax_pipelines.find(vec4); + if (it != argmax_pipelines.end()) { + return it->second; + } + std::string variant = "argmax"; + std::vector defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + if (vec4) { + defines.push_back("VEC4"); + variant += "_vec4"; + } -/** Argsort **/ + auto processed = preprocessor.preprocess(wgsl_argmax, defines); + argmax_pipelines[vec4] = ggml_webgpu_create_pipeline(device, processed, variant); + return argmax_pipelines.at(vec4); + } -struct ggml_webgpu_argsort_shader_lib_context { - uint32_t max_wg_size; - size_t wg_mem_limit_bytes; - int32_t order; -}; + webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type, + .vec4 = context.src0->ne[0] % 4 == 0, + .i64_idx = context.src1->type == GGML_TYPE_I64 }; -struct ggml_webgpu_argsort_shader_decisions { - uint32_t wg_size = 0; -}; + auto it = set_rows_pipelines.find(key); + if (it != set_rows_pipelines.end()) { + return it->second; + } -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_argsort_shader_lib_context & context) { - std::vector defines; - std::string variant = "argsort"; - defines.push_back(std::string("ORDER=") + std::to_string(context.order)); - variant += std::string("_order") + std::to_string(context.order); - uint32_t wg_size = 1; - while (wg_size * 2 <= context.max_wg_size && - wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) { - wg_size *= 2; - } - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->wg_size = wg_size; - result.decisions = decisions; - return result; -} + std::vector defines; + std::string variant = "set_rows"; + + switch (context.dst->type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + variant += "_dstf32"; + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + variant += "_dstf16"; + break; + default: + GGML_ABORT("Unsupported dst type for set_rows shader"); + } -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_argsort_shader_lib_context & context) { - std::vector defines; - std::string variant = "argsort_merge"; - defines.push_back(std::string("ORDER=") + std::to_string(context.order)); - variant += std::string("_order") + std::to_string(context.order); - uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size); - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->wg_size = wg_size; - result.decisions = decisions; - return result; -} + if (key.vec4) { + defines.push_back("VEC4"); + variant += "_vec4"; + } + if (key.i64_idx) { + defines.push_back("I64_IDX"); + variant += "_i64idx"; + } -/** Set Rows **/ + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); -struct ggml_webgpu_set_rows_pipeline_key { - int dst_type; - int vec4; - int i64_idx; + auto processed = preprocessor.preprocess(wgsl_set_rows, defines); + auto decisions = std::make_shared(); + decisions->vec4 = key.vec4; + decisions->i64_idx = key.i64_idx; + decisions->wg_size = context.max_wg_size; + set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); + set_rows_pipelines[key].context = decisions; + return set_rows_pipelines[key]; + } - bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const { - return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx; + webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) { + auto it = cumsum_pipelines.find(1); + if (it != cumsum_pipelines.end()) { + return it->second; + } + + std::vector defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_cumsum, defines); + cumsum_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "cumsum"); + return cumsum_pipelines[1]; } -}; -struct ggml_webgpu_set_rows_pipeline_key_hash { - size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.dst_type); - ggml_webgpu_hash_combine(seed, key.vec4); - ggml_webgpu_hash_combine(seed, key.i64_idx); - return seed; + webgpu_pipeline get_argsort_pipeline(const ggml_webgpu_shader_lib_context & context) { + bool is_top_k = context.dst->op == GGML_OP_TOP_K; + // ascending order is 0, descending order is 1 + const int32_t order = + is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0); + + auto it = argsort_pipelines.find(order); + if (it != argsort_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "argsort"; + defines.push_back(std::string("ORDER=") + std::to_string(order)); + variant += std::string("_order") + std::to_string(order); + uint32_t wg_size = 1; + while (wg_size * 2 <= context.max_wg_size && + wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) { + wg_size *= 2; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + auto processed = preprocessor.preprocess(wgsl_argsort, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + argsort_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant); + argsort_pipelines[order].context = decisions; + return argsort_pipelines[order]; } -}; -struct ggml_webgpu_set_rows_shader_lib_context { - ggml_webgpu_set_rows_pipeline_key key; - uint32_t max_wg_size; -}; + webgpu_pipeline get_argsort_merge_pipeline(const ggml_webgpu_shader_lib_context & context) { + bool is_top_k = context.dst->op == GGML_OP_TOP_K; + // ascending order is 0, descending order is 1 + const int32_t order = + is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0); -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_set_rows_shader_lib_context & context) { - std::vector defines; - std::string variant = "set_rows"; - - switch (context.key.dst_type) { - case GGML_TYPE_F32: - defines.push_back("DST_F32"); - variant += "_dstf32"; - break; - case GGML_TYPE_F16: - defines.push_back("DST_F16"); - variant += "_dstf16"; - break; - default: - GGML_ABORT("Unsupported dst type for set_rows shader"); - } - - if (context.key.vec4) { - defines.push_back("VEC4"); - variant += "_vec"; - } - if (context.key.i64_idx) { - defines.push_back("I64_IDX"); - variant += "_i64idx"; - } - - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; - return result; -} + auto it = argsort_merge_pipelines.find(order); + if (it != argsort_merge_pipelines.end()) { + return it->second; + } -struct ggml_webgpu_unary_pipeline_key { - int type; - int op; - bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella - bool inplace; + std::vector defines; + std::string variant = "argsort_merge"; + defines.push_back(std::string("ORDER=") + std::to_string(order)); + variant += std::string("_order") + std::to_string(order); + uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - bool operator==(const ggml_webgpu_unary_pipeline_key & other) const { - return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace; + auto processed = preprocessor.preprocess(wgsl_argsort_merge, defines); + argsort_merge_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant); + return argsort_merge_pipelines[order]; } -}; -struct ggml_webgpu_unary_pipeline_key_hash { - size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.type); - ggml_webgpu_hash_combine(seed, key.op); - ggml_webgpu_hash_combine(seed, key.is_unary); - ggml_webgpu_hash_combine(seed, key.inplace); - return seed; + webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0; + ggml_webgpu_get_rows_pipeline_key key = { + .src_type = context.src0->type, + .vectorized = (int) vectorized, + }; + + auto it = get_rows_pipelines.find(key); + if (it != get_rows_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "get_rows"; + + const struct ggml_type_traits * type_traits = ggml_get_type_traits(key.src_type); + const char * type_str = type_traits->type_name; + + switch (key.src_type) { + case GGML_TYPE_F32: + if (key.vectorized) { + defines.push_back("F32_VEC"); + defines.push_back("SRC_TYPE=vec4"); + defines.push_back("DST_TYPE=vec4"); + defines.push_back("BLOCK_SIZE=4u"); + } else { + defines.push_back("F32"); + defines.push_back("SRC_TYPE=f32"); + defines.push_back("DST_TYPE=f32"); + defines.push_back("BLOCK_SIZE=1u"); + } + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("F16"); + defines.push_back("SRC_TYPE=f16"); + defines.push_back("DST_TYPE=f32"); + defines.push_back("BLOCK_SIZE=1u"); + variant += "_f16"; + break; + case GGML_TYPE_I32: + defines.push_back("I32"); + defines.push_back("SRC_TYPE=i32"); + defines.push_back("DST_TYPE=i32"); + defines.push_back("BLOCK_SIZE=1u"); + variant += "_i32"; + break; + default: + { + std::string type_upper = type_str; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back(type_upper + "_T"); + defines.push_back(type_upper); + defines.push_back(type_upper + "_SCALE_MIN"); + defines.push_back(type_upper + "_TABLES"); + defines.push_back(type_upper + "_GRID"); + + variant += "_"; + variant += type_str; + + defines.push_back(std::string("SRC_TYPE=") + type_str); + defines.push_back("DST_TYPE=f32"); + + if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || + key.src_type == GGML_TYPE_IQ4_NL) { + defines.push_back("BLOCK_SIZE=32u"); + } else if (key.src_type >= GGML_TYPE_Q2_K) { + defines.push_back("BLOCK_SIZE=256u"); + } else { + defines.push_back("BLOCK_SIZE=1u"); + } + break; + } + } + + if (key.vectorized) { + variant += "_vec"; + } + + defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_get_rows, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + get_rows_pipelines[key] = pipeline; + return get_rows_pipelines[key]; } -}; -struct ggml_webgpu_unary_shader_lib_context { - ggml_webgpu_unary_pipeline_key key; - uint32_t max_wg_size; -}; + webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace }; -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_unary_shader_lib_context & context) { - std::vector defines; - std::string variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) : - ggml_op_name((ggml_op) context.key.op); - // Operation-specific behavior - defines.push_back(variant); - - switch (context.key.type) { - case GGML_TYPE_F32: - defines.push_back("TYPE_F32"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("TYPE_F16"); - variant += "_f16"; - break; - default: - GGML_ABORT("Unsupported type for unary shader"); - } - - if (context.key.inplace) { - defines.push_back("INPLACE"); - variant += "_inplace"; - } - - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; - return result; -} + auto it = scale_pipelines.find(key); + if (it != scale_pipelines.end()) { + return it->second; + } -/** Binary **/ + std::vector defines; + std::string variant = "scale"; -struct ggml_webgpu_binary_pipeline_key { - int type; - int op; - bool inplace; - bool overlap; + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } - bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { - return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_scale, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + scale_pipelines[key] = pipeline; + return scale_pipelines[key]; } -}; -struct ggml_webgpu_binary_pipeline_key_hash { - size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.type); - ggml_webgpu_hash_combine(seed, key.op); - ggml_webgpu_hash_combine(seed, key.inplace); - ggml_webgpu_hash_combine(seed, key.overlap); - return seed; + webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 }; + + auto it = pad_pipelines.find(key); + if (it != pad_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "pad"; + + if (key.circular) { + defines.push_back("CIRCULAR"); + variant += "_circular"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_pad, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + pad_pipelines[key] = pipeline; + return pad_pipelines[key]; } -}; -struct ggml_webgpu_binary_shader_lib_context { - ggml_webgpu_binary_pipeline_key key; - uint32_t max_wg_size; + webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_vec_pipeline_key key = { + .src0_type = context.src0->type, + .src1_type = context.src1->type, + // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float + .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0, + }; + + auto it = mul_mat_vec_pipelines.find(key); + if (it != mul_mat_vec_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "mul_mat_vec"; + + // src1 type (vector) + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat_vec shader"); + } + + // src0 type (matrix row) + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("MUL_ACC_FLOAT"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("MUL_ACC_FLOAT"); + break; + default: + { + // Quantized types: use helpers but accumulate in f16 + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + std::string src0_name = src0_traits->type_name; + std::string type_upper = src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("MUL_ACC_" + type_upper); + + // For fast path we always dequantize from f16 inside the shader + defines.push_back("SRC0_INNER_TYPE=f16"); + break; + } + } + + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + + uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; + uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K; + uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); + defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); + + auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + decisions->tile_k = tile_k; + decisions->outputs_per_wg = outputs_per_wg; + decisions->vec_size = key.vectorized ? 4 : 1; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_vec_pipelines[key] = pipeline; + return mul_mat_vec_pipelines[key]; + } + + webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_pipeline_key key = { + .src0_type = context.src0->type, + .src1_type = context.src1->type, + .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0, + .use_subgroup_matrix = context.supports_subgroup_matrix + }; + + auto it = mul_mat_fast_pipelines.find(key); + if (it != mul_mat_fast_pipelines.end()) { + return it->second; + } + + const char * shader_src = key.use_subgroup_matrix ? wgsl_mul_mat_subgroup_matrix : wgsl_mul_mat_reg_tile; + std::vector defines; + std::string variant = key.use_subgroup_matrix ? "mul_mat_subgroup_matrix" : "mul_mat_reg_tile"; + + // src1 type + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat fast shader"); + } + + // src0 type + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + const char * src0_name = src0_traits->type_name; + + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("FLOAT"); + defines.push_back("MUL_ACC_FLOAT"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("FLOAT"); + defines.push_back("MUL_ACC_FLOAT"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + variant += "_f16"; + break; + default: + { + std::string type_upper = src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("INIT_SRC0_SHMEM_" + type_upper); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + + // Use f16 inside the shader for quantized types + defines.push_back("SRC0_INNER_TYPE=f16"); + + variant += std::string("_") + src0_name; + break; + } + } + + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + + // Tiles + defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); + defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); + defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); + + // Subgroup matrix specifics + if (key.use_subgroup_matrix) { + defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); + defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u"); + defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u"); + defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u"); + defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u"); + defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u"); + defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u"); + defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u"); + } + + // variant suffix for src1 type + variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16"); + if (key.vectorized) { + variant += "_vectorized"; + } + + if (!key.use_subgroup_matrix) { + defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); + defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); + } + + auto processed = preprocessor.preprocess(shader_src, defines); + + auto decisions = std::make_shared(); + decisions->tile_k = WEBGPU_MUL_MAT_TILE_K; + decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; + decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; + decisions->use_subgroup_matrix = key.use_subgroup_matrix; + if (key.use_subgroup_matrix) { + decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M; + decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N; + decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M; + decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N; + decisions->wg_size = context.max_subgroup_size; + } else { + decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M; + decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N; + decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N; + decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE; + } + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_fast_pipelines[key] = pipeline; + return mul_mat_fast_pipelines[key]; + } + + webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type, + .src1_type = context.src1->type }; + + auto it = mul_mat_legacy_pipelines.find(key); + if (it != mul_mat_legacy_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "mul_mat"; + + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_TYPE=f32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_TYPE=f16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat legacy shader"); + } + + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + const char * src0_name = src0_traits->type_name; + + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_TYPE=f32"); + defines.push_back("FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_TYPE=f16"); + defines.push_back("FLOAT"); + variant += "_f16"; + break; + default: + { + // quantized types + std::string type_upper = src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back(std::string("SRC0_TYPE=") + src0_name); + defines.push_back("BYTE_HELPERS"); + defines.push_back(type_upper + "_T"); + defines.push_back(type_upper); + defines.push_back(type_upper + "_SCALE_MIN"); + defines.push_back(type_upper + "_TABLES"); + defines.push_back(type_upper + "_GRID"); + + variant += std::string("_") + src0_name; + break; + } + } + + auto processed = preprocessor.preprocess(wgsl_mul_mat, defines); + + auto decisions = std::make_shared(); + decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_legacy_pipelines[key] = pipeline; + return mul_mat_legacy_pipelines[key]; + } + + webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool is_unary = context.dst->op == GGML_OP_UNARY; + const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; + ggml_webgpu_unary_pipeline_key key = { + .type = context.dst->type, + .op = op, + .is_unary = is_unary, + .inplace = context.inplace, + }; + + auto it = unary_pipelines.find(key); + if (it != unary_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = + key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op); + defines.push_back(variant); + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for unary shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_unary, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + unary_pipelines[key] = pipeline; + return unary_pipelines[key]; + } + + webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_binary_pipeline_key key = { + .type = context.dst->type, + .op = context.dst->op, + .inplace = context.inplace, + .overlap = context.overlap, + }; + + auto it = binary_pipelines.find(key); + if (it != binary_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string op_name = ggml_op_name((ggml_op) key.op); + std::string variant = op_name; + + defines.push_back(std::string("OP_") + op_name); + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for binary shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } else if (key.overlap) { + defines.push_back("OVERLAP"); + variant += "_overlap"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_binary, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + binary_pipelines[key] = pipeline; + return binary_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool has_mask = context.src3 != nullptr; + const bool has_sinks = context.src4 != nullptr; + + bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) && + (context.src1->ne[1] % context.sg_mat_n == 0); + + ggml_webgpu_flash_attn_pipeline_key key = { + .kv_type = context.src1->type, + .head_dim_qk = (uint32_t) context.src0->ne[0], + .head_dim_v = (uint32_t) context.src2->ne[0], + .kv_direct = kv_direct, + .has_mask = has_mask, + .has_sinks = has_sinks, + .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f, + }; + + auto it = flash_attn_pipelines.find(key); + if (it != flash_attn_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "flash_attn"; + + switch (key.kv_type) { + case GGML_TYPE_F32: + defines.push_back("KV_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("KV_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("KV_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("KV_Q8_0"); + break; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); + } + variant += std::string("_") + ggml_type_name(key.kv_type); + + if (key.has_mask) { + defines.push_back("MASK"); + variant += "_mask"; + } + if (key.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (key.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + if (key.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + + uint32_t q_tile = context.sg_mat_m; + uint32_t kv_tile = + std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k, + context.wg_mem_limit_bytes, context.max_subgroup_size }), + context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + if (key.kv_direct) { + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= context.sg_mat_n; + } + } + + defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + + uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + auto processed = preprocessor.preprocess(wgsl_flash_attn, defines); + auto decisions = std::make_shared(); + decisions->q_tile = q_tile; + decisions->kv_tile = kv_tile; + decisions->wg_size = wg_size; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + flash_attn_pipelines[key] = pipeline; + return flash_attn_pipelines[key]; + } + + private: + static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, + std::string shader_code, + std::string label) { + wgpu::ShaderSourceWGSL shader_source; + shader_source.code = shader_code.c_str(); + + wgpu::ShaderModuleDescriptor shader_desc; + shader_desc.nextInChain = &shader_source; + + wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); + + wgpu::ComputePipelineDescriptor pipeline_desc; + pipeline_desc.label = label.c_str(); + pipeline_desc.compute.module = shader_module; + pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code + pipeline_desc.layout = nullptr; // nullptr means auto layout + return { device.CreateComputePipeline(&pipeline_desc), label }; + } + + static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { + const size_t limit_bytes = context.wg_mem_limit_bytes; + const size_t q_tile = context.sg_mat_m; + const size_t base_q_bytes = + (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!context.key.kv_direct) { + bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); + } + if (context.key.has_mask) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; + } }; -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_binary_shader_lib_context & context) { - std::vector defines; - std::string op_name = ggml_op_name((ggml_op) context.key.op); - std::string variant = op_name; - - defines.push_back(std::string("OP_") + op_name); - - switch (context.key.type) { - case GGML_TYPE_F32: - defines.push_back("TYPE_F32"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("TYPE_F16"); - variant += "_f16"; - break; - default: - GGML_ABORT("Unsupported type for binary shader"); - } - - if (context.key.inplace) { - defines.push_back("INPLACE"); - variant += "_inplace"; - } else if (context.key.overlap) { - defines.push_back("OVERLAP"); - variant += "_overlap"; - } - - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; - return result; -} #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 32e120266a9..17bb2f47126 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,7 +8,6 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-webgpu-shader-lib.hpp" -#include "ggml-wgsl-shaders.hpp" #include "pre_wgsl.hpp" #ifdef __EMSCRIPTEN__ @@ -23,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -69,50 +69,29 @@ /* Constants */ -// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed. -#define WEBGPU_MAX_WG_SIZE 288 - -#define WEBGPU_MUL_MAT_WG_SIZE 256 #define WEBGPU_NUM_PARAM_BUFS 16u #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 -// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool +// Maximum number of in-flight submissions per-thread, to avoid exhausting the +// parameter buffer pool #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 -// For operations which process a row in parallel, this seems like a reasonable default +// For operations which process a row in parallel, this seems like a reasonable +// default #define WEBGPU_ROW_SPLIT_WG_SIZE 64 -// Matrix multiplication parameters - -// Register tiling parameters -#define WEBGPU_MUL_MAT_TILE_M 8 -#define WEBGPU_MUL_MAT_TILE_N 8 -#define WEBGPU_MUL_MAT_WG_SIZE_M 8 -#define WEBGPU_MUL_MAT_WG_SIZE_N 8 -#define WEBGPU_MUL_MAT_TILE_K 32 - -// Subgroup matrix parameters -// The number of subgroups in the M dimension -#define WEBGPU_MUL_MAT_SUBGROUP_M 2 -// The number of subgroups in the N dimension -#define WEBGPU_MUL_MAT_SUBGROUP_N 2 -// The number of subgroup matrices each subgroup accumulates over -#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 -#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 - -// Matrix-vector multiplication parameters -#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 -// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size -#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_TILE_K 256 +// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to +// implementations so this can be removed, necessary only for get_rows right now +#define WEBGPU_MAX_WG_SIZE 288 /* End Constants */ -// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. +// This is a "fake" base pointer, since WebGPU buffers do not have pointers to +// their locations. static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT // Always returns the base offset of a tensor, regardless of views. @@ -263,12 +242,6 @@ struct webgpu_gpu_profile_buf_pool { }; #endif -struct webgpu_pipeline { - wgpu::ComputePipeline pipeline; - std::string name; - std::shared_ptr context = nullptr; -}; - struct webgpu_command { wgpu::CommandBuffer commands; std::vector params_bufs; @@ -353,41 +326,18 @@ struct webgpu_context_struct { // Points to global instances owned by ggml_backend_webgpu_reg_context webgpu_global_context global_ctx; - pre_wgsl::Preprocessor p; + std::unique_ptr shader_lib; webgpu_buf_pool param_buf_pool; webgpu_buf_pool set_rows_error_buf_pool; - std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized - std::map>> - mul_mat_vec_pipelines; // src0_type, src1_type, vectorized - - std::unordered_map - flash_attn_pipelines; - - std::unordered_map argmax_pipelines; // key is vec4 - std::unordered_map argsort_pipelines; // key is order (asc/desc) - std::unordered_map argsort_merge_pipelines; // key is order (asc/desc) - std::unordered_map cumsum_pipelines; // key is fixed, no variants yet - std::unordered_map sum_rows_pipelines; // key is fixed, no variants yet - - std::unordered_map - set_rows_pipelines; - std::map> get_rows_pipelines; // src_type, vectorized - - std::map> cpy_pipelines; // src_type, dst_type - - std::unordered_map - binary_pipelines; + std::map> cpy_pipelines; // src_type, dst_type std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace std::map>> glu_pipelines; // glu_op, type, split - std::map scale_pipelines; // inplace + std::map>> soft_max_pipelines; // mask_type, has_sink, inplace - std::unordered_map - unary_pipelines; - std::unordered_map pad_pipelines; size_t memset_bytes_per_thread; }; @@ -429,25 +379,6 @@ struct ggml_backend_webgpu_buffer_context { /* WebGPU object initializations */ -// Process a WGSL shader string, replacing tokens of the form {{KEY}} with -// the corresponding values provided in `repls`. -static std::string ggml_webgpu_process_shader_repls(const char * src, - const std::map & repls) { - if (!src) { - return std::string(); - } - std::string s = src; - for (const auto & kv : repls) { - std::string token = "{{" + kv.first + "}}"; - size_t pos = 0; - while ((pos = s.find(token, pos)) != std::string::npos) { - s.replace(pos, token.length(), kv.second); - pos += kv.second.length(); - } - } - return s; -} - static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, const char * shader_code, const char * label, @@ -495,8 +426,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, std::vector & futures, bool block = true) { - // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, - // inflight_max may be 0, meaning that we must wait on all futures. + // If we have too many in-flight submissions, wait on the oldest one first. If + // there are many threads, inflight_max may be 0, meaning that we must wait on + // all futures. uint64_t timeout_ms = block ? UINT64_MAX : 0; uint32_t inflight_threads = ctx->inflight_threads; uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); @@ -681,7 +613,8 @@ static webgpu_command ggml_backend_webgpu_build_multi( encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); } - // If there are SET_ROWS operations in this submission, copy their error buffers to the host. + // If there are SET_ROWS operations in this submission, copy their error + // buffers to the host. if (set_rows_error_bufs) { encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, set_rows_error_bufs->host_buf.GetSize()); @@ -900,24 +833,11 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g } static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - const bool circular = ggml_get_op_params_i32(dst, 8) != 0; - - ggml_webgpu_pad_pipeline_key pipeline_key = { .circular = circular }; - ggml_webgpu_pad_shader_lib_context shader_lib_ctx = { - .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; - webgpu_pipeline pipeline; - auto it = ctx->pad_pipelines.find(pipeline_key); - if (it != ctx->pad_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->pad_pipelines.emplace(pipeline_key, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -971,36 +891,25 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { - // For set rows specifically, we need to check if src and idx are empty tensors. + // For set rows specifically, we need to check if src and idx are empty + // tensors. if (ggml_is_empty(src) || ggml_is_empty(idx)) { return std::nullopt; } - ggml_webgpu_set_rows_pipeline_key key = { .dst_type = dst->type, - .vec4 = src->ne[0] % 4 == 0, - .i64_idx = idx->type == GGML_TYPE_I64 }; - - ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = { - .key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = idx, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; - webgpu_pipeline pipeline; - auto it = ctx->set_rows_pipelines.find(key); - if (it != ctx->set_rows_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->set_rows_pipelines.emplace(key, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); std::optional error_bufs = std::nullopt; - if (key.i64_idx) { + if (decisions->i64_idx) { error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs(); if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { error_bufs->host_buf.Unmap(); @@ -1038,13 +947,13 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - if (key.i64_idx) { + if (decisions->i64_idx) { entries.push_back( { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() }); } uint32_t threads; - if (key.vec4) { + if (decisions->vec4) { threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); } else { threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; @@ -1054,26 +963,47 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, error_bufs); } +// Workgroup size is a common constant +static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_size) { + std::vector constants(1); + constants[0].key = "wg_size"; + constants[0].value = wg_size; + return constants; +} + static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { - std::vector params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - // Convert byte-strides to element-strides - (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)), - (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), - (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), - (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Shape of dst - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], - // Shape of idx - (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2]) + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, + .max_wg_size = WEBGPU_MAX_WG_SIZE, }; + webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), + (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), + (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + (uint32_t) (idx->ne[1]), + (uint32_t) (idx->ne[2]) }; + std::vector entries = { { .binding = 0, .buffer = ggml_webgpu_tensor_buf(src), @@ -1089,10 +1019,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size); - uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0; - webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized]; return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1100,25 +1028,74 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + // Determine if this is a mat-vec operation + bool is_vec = (dst->ne[1] == 1); + + // Determine if we should use fast path + bool use_fast = false; + switch (src1->type) { + case GGML_TYPE_F16: + use_fast = (src0->type == GGML_TYPE_F16); + break; + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + use_fast = true; + break; + default: + break; + } + break; + default: + break; + } + + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix, + .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, + .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, + .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, + }; + + // Get or create pipeline + webgpu_pipeline pipeline; + + if (use_fast && is_vec) { + pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx); + } else if (use_fast) { + pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx); + } else { + pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx); + } + + // Build params std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) dst->ne[0], // number of rows in result (M, transposed) - (uint32_t) dst->ne[1], // number of columns in result (N) - (uint32_t) src0->ne[0], // number of columns in src0/src1 (K) - (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1 - (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1 - (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2 - (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2 - (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3 - (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3 - (uint32_t) src0->ne[2], // batch size in dimension 2 - (uint32_t) src0->ne[3], // batch size in dimension 3 - (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2 - (uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3 + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) src0->ne[0], + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) src0->ne[2], + (uint32_t) src0->ne[3], + (uint32_t) (src1->ne[2] / src0->ne[2]), + (uint32_t) (src1->ne[3] / src0->ne[3]) }; + // Build bind group entries std::vector entries = { { .binding = 0, .buffer = ggml_webgpu_tensor_buf(src0), @@ -1134,68 +1111,44 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, }; - webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0]; - - uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE); + // Calculate workgroup dimensions + uint32_t wg_x = 1; uint32_t wg_y = 1; - bool use_fast = false; - switch (src1->type) { - case GGML_TYPE_F16: - use_fast = (src0->type == GGML_TYPE_F16); - break; - case GGML_TYPE_F32: - switch (src0->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - use_fast = true; - break; - default: - break; - } - break; - default: - break; - } - - if (use_fast) { - int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; - if (dst->ne[1] == 1) { - // We don't support vectorized mul_mat_vec for quantized types - vectorized = vectorized && (src0->type < 2); - pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; - uint32_t batches = dst->ne[2] * dst->ne[3]; - uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG); - uint32_t total_wg = output_groups * batches; - wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); + if (use_fast && is_vec) { + auto decisions = static_cast(pipeline.context.get()); + + uint32_t batches = dst->ne[2] * dst->ne[3]; + uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); + uint32_t total_wg = output_groups * batches; + wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); + } else if (use_fast) { + auto decisions = static_cast(pipeline.context.get()); + + // Fast-path tiled/subgroup calculations + uint32_t wg_m, wg_n; + if (decisions->use_subgroup_matrix) { + uint32_t wg_m_sg_tile = + decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m; + wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); + uint32_t wg_n_sg_tile = + decisions->subgroup_n * decisions->subgroup_matrix_n * ctx->global_ctx->capabilities.sg_mat_n; + wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); } else { - pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; - uint32_t wg_m; - uint32_t wg_n; -#ifndef __EMSCRIPTEN__ - if (ctx->global_ctx->capabilities.supports_subgroup_matrix) { - // The total number of subgroups/workgroups needed per matrix. - uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * - ctx->global_ctx->capabilities.sg_mat_m; - wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); - uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * - ctx->global_ctx->capabilities.sg_mat_n; - wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); - } else { -#endif - uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; - uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; - wg_m = CEIL_DIV(dst->ne[0], tile_m_s); - wg_n = CEIL_DIV(dst->ne[1], tile_n_s); -#ifndef __EMSCRIPTEN__ - } -#endif - - wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + uint32_t tile_m_s = decisions->tile_m * decisions->wg_size_m; + uint32_t tile_n_s = decisions->tile_n * decisions->wg_size_n; + wg_m = CEIL_DIV(dst->ne[0], tile_m_s); + wg_n = CEIL_DIV(dst->ne[1], tile_n_s); } + wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + } else { // legacy + auto decisions = static_cast(pipeline.context.get()); + uint32_t wg_size = decisions->wg_size; + wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); + wg_y = 1; } + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); } @@ -1283,40 +1236,22 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - - ggml_webgpu_flash_attn_pipeline_key key = { - .kv_type = K->type, - .head_dim_qk = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .kv_direct = kv_direct, - .has_mask = static_cast(has_mask), - .has_sinks = static_cast(has_sinks), - .uses_logit_softcap = logit_softcap != 0.0f, + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = Q, + .src1 = K, + .src2 = V, + .src3 = mask, + .src4 = sinks, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, + .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, + .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, }; - webgpu_pipeline pipeline; - auto it = ctx->flash_attn_pipelines.find(key); - if (it != ctx->flash_attn_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { - .key = key, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size - }; - - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->flash_attn_pipelines.emplace(key, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -1329,27 +1264,16 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); - int op = is_unary ? (int) ggml_get_unary_op(dst) : dst->op; - ggml_webgpu_unary_pipeline_key pipeline_key = { - .type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace - }; - ggml_webgpu_unary_shader_lib_context shader_lib_ctx = { - .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = inplace, }; - webgpu_pipeline pipeline; - auto it = ctx->unary_pipelines.find(pipeline_key); - if (it != ctx->unary_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->unary_pipelines.emplace(pipeline_key, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -1421,30 +1345,18 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * dst) { binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); - ggml_webgpu_binary_pipeline_key pipeline_key = { - .type = dst->type, - .op = dst->op, - .inplace = flags.inplace, - .overlap = flags.overlap, - }; - ggml_webgpu_binary_shader_lib_context shader_lib_ctx = { - .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = flags.inplace, + .overlap = flags.overlap, }; - webgpu_pipeline pipeline; - auto it = ctx->binary_pipelines.find(pipeline_key); - if (it != ctx->binary_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->binary_pipelines.emplace(pipeline_key, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -1669,8 +1581,20 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, } static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - int inplace = ggml_webgpu_tensor_equal(src, dst); + bool inplace = ggml_webgpu_tensor_equal(src, dst); + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = inplace, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + // params unchanged std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1688,12 +1612,14 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, *(uint32_t *) &dst->op_params[1] // bias }; + // bindgroups unchanged std::vector entries = { { .binding = 0, .buffer = ggml_webgpu_tensor_buf(src), .offset = ggml_webgpu_tensor_align_offset(ctx, src), .size = ggml_webgpu_tensor_binding_size(ctx, src) } }; + if (!inplace) { entries.push_back({ .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), @@ -1701,9 +1627,8 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->scale_pipelines[inplace], params, - entries, wg_x); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, @@ -1796,63 +1721,30 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { - .vec4 = src->ne[0] % 4 == 0, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; - webgpu_pipeline pipeline; - auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4); - if (it != ctx->argmax_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax"); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline); - } - uint32_t wg_x = ggml_nelements(dst); + webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); + uint32_t wg_x = ggml_nelements(dst); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - bool is_top_k = dst->op == GGML_OP_TOP_K; - // ascending order is 0, descending order is 1 - const int32_t order = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0); + bool is_top_k = dst->op == GGML_OP_TOP_K; - ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .order = order }; - webgpu_pipeline argsort_pipeline; - auto it = ctx->argsort_pipelines.find(order); - if (it != ctx->argsort_pipelines.end()) { - argsort_pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx); - argsort_pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - argsort_pipeline.context = processed.decisions; - ctx->argsort_pipelines.emplace(order, argsort_pipeline); - } - auto * argsort_decisions = static_cast(argsort_pipeline.context.get()); - - webgpu_pipeline argsort_merge_pipeline; - it = ctx->argsort_merge_pipelines.find(order); - if (it != ctx->argsort_merge_pipelines.end()) { - argsort_merge_pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx); - argsort_merge_pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - argsort_merge_pipeline.context = processed.decisions; - ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline); - } + webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx); + auto * argsort_decisions = static_cast(argsort_pipeline.context.get()); + + webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx); const uint32_t src_ne0 = (uint32_t) src->ne[0]; const uint32_t nrows = (uint32_t) ggml_nrows(src); @@ -2011,22 +1903,15 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { - .vec4 = false, + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; - webgpu_pipeline pipeline; - auto it = ctx->cumsum_pipelines.find(1); - if (it != ctx->cumsum_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum"); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->cumsum_pipelines.emplace(1, pipeline); - } - uint32_t wg_x = ggml_nrows(dst); + + webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); + uint32_t wg_x = ggml_nrows(dst); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -2052,22 +1937,12 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { - .vec4 = false, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; - webgpu_pipeline pipeline; - auto it = ctx->sum_rows_pipelines.find(1); - if (it != ctx->sum_rows_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows"); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->sum_rows_pipelines.emplace(1, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); + uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -2233,7 +2108,9 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe size_t offset, size_t size) { if (size == 0) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do."); + WEBGPU_LOG_DEBUG( + "ggml_backend_webgpu_buffer_memset_tensor: size is zero, " + "nothing to do."); return; } @@ -2310,7 +2187,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, size_t final_size = size; if (size % 4 != 0) { - // If size is not a multiple of 4, we need to round it up to the next multiple of 4 + // If size is not a multiple of 4, we need to round it up to the next + // multiple of 4 final_size = size + (4 - (size % 4)); } @@ -2364,7 +2242,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor, /* .cpy_tensor = */ NULL, // TODO: optional, implement this /* .clear = */ ggml_backend_webgpu_buffer_clear, - /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor + /* .reset = */ NULL, // TODO: optional, think it coordinates with + // .init_tensor }; /* End GGML Backend Buffer Interface */ @@ -2401,7 +2280,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_ return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; } -// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding. +// maxBufferSize might be larger, but you can't bind more than +// maxStorageBufferBindingSize to a single binding. static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { ggml_backend_webgpu_device_context * dev_ctx = static_cast(buft->device->context); @@ -2487,14 +2367,6 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) { return reinterpret_cast((void *) guid_str); } -// Workgroup size is a common constant -static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_size) { - std::vector constants(1); - constants[0].key = "wg_size"; - constants[0].value = wg_size; - return constants; -} - static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { // we use the maximum workgroup size for the memset pipeline size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; @@ -2509,207 +2381,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { - // Q4/Q5/Q8 classic quantizations - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); - - // K-quantizations - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); - - // IQ quantizations (2-, 3-, 4-bit variants) - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); - - // 1-bit and 4-bit IQ variants - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); - - std::string proc_mul_mat_f32_f32; - std::string proc_mul_mat_f32_f32_vec; - std::string proc_mul_mat_f16_f32; - std::string proc_mul_mat_f16_f32_vec; - std::string proc_mul_mat_f16_f16; - std::string proc_mul_mat_f16_f16_vec; - std::string proc_mul_mat_q4_0_f32; - std::string proc_mul_mat_q4_0_f32_vec; - - std::vector mul_mat_constants; -#ifndef __EMSCRIPTEN__ - if (webgpu_ctx->global_ctx->capabilities.supports_subgroup_matrix) { - std::map sg_matrix_repls; - sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = - std::to_string(webgpu_ctx->global_ctx->capabilities.max_subgroup_size); - sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K); - sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M); - sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); - sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M); - sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N); - sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_m); - sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_n); - sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_k); - proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); - proc_mul_mat_f32_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); - proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); - proc_mul_mat_f16_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls); - proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); - proc_mul_mat_f16_f16_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); - proc_mul_mat_q4_0_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls); - proc_mul_mat_q4_0_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls); - } else { -#endif - mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K }); - mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M }); - mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N }); - - std::map reg_repls; - reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M); - reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N); - - proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); - proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); - proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); - proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); - proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); - proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); - proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); - proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); -#ifndef __EMSCRIPTEN__ - } -#endif - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); - - std::vector mul_mat_vec_constants(3); - mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; - mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE; - mul_mat_vec_constants[1].key = "TILE_K"; - mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K; - mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG"; - mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; - - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); -} - -static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); -} - static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); @@ -2816,15 +2487,6 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); } -static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->scale_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32, "scale_f32", constants); - webgpu_ctx->scale_pipelines[1] = ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32_inplace, - "scale_f32_inplace", constants); -} - static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); @@ -3015,6 +2677,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; + webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); @@ -3023,13 +2686,10 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); - ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); - ggml_webgpu_init_get_rows_pipeline(webgpu_ctx); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); ggml_webgpu_init_rope_pipeline(webgpu_ctx); ggml_webgpu_init_glu_pipeline(webgpu_ctx); - ggml_webgpu_init_scale_pipeline(webgpu_ctx); ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers @@ -3071,11 +2731,11 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { /* .iface = */ { /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, - /* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size, - /* .is_host = */ NULL, // defaults to false + /* .alloc_buffer = */ + ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */ + ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */ + ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */ + ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false }, /* .device = */ dev, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 389c97bb51b..9a5b18ebc07 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -1,5 +1,4 @@ -#decl(BYTE_HELPERS) - +#ifdef BYTE_HELPERS fn get_byte(value: u32, index: u32) -> u32 { return (value >> (index * 8)) & 0xFF; } @@ -7,76 +6,74 @@ fn get_byte(value: u32, index: u32) -> u32 { fn get_byte_i32(value: u32, index: u32) -> i32 { return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; } +#endif -#enddecl(BYTE_HELPERS) - -#decl(Q4_0_T) +#ifdef Q4_0_T struct q4_0 { d: f16, qs: array }; -#enddecl(Q4_0_T) +#endif -#decl(Q4_1_T) +#ifdef Q4_1_T struct q4_1 { d: f16, m: f16, qs: array }; -#enddecl(Q4_1_T) +#endif -#decl(Q5_0_T) +#ifdef Q5_0_T struct q5_0 { d: f16, qh: array, qs: array }; -#enddecl(Q5_0_T) +#endif -#decl(Q5_1_T) +#ifdef Q5_1_T struct q5_1 { d: f16, m: f16, qh: u32, qs: array }; -#enddecl(Q5_1_T) +#endif -#decl(Q8_0_T) +#ifdef Q8_0_T struct q8_0 { d: f16, qs: array }; -#enddecl(Q8_0_T) +#endif -#decl(Q8_1_T) +#ifdef Q8_1_T struct q8_1 { d: f16, m: f16, qs: array }; -#enddecl(Q8_1_T) +#endif -#decl(Q2_K_T) -struct q2_k { +#ifdef Q2_K_T +struct q2_K { scales: array, qs: array, d: f16, dmin: f16 }; -#enddecl(Q2_K_T) +#endif -#decl(Q3_K_T) -struct q3_k { +#ifdef Q3_K_T +struct q3_K { hmask: array, qs: array, scales: array, d: f16 }; -#enddecl(Q3_K_T) - -#decl(Q45_K_SCALE_MIN) +#endif +#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN) fn get_scale_min(is: u32, scales: array) -> vec2 { if (is < 4) { let sc_byte = get_byte(scales[is / 4], is % 4); @@ -91,69 +88,67 @@ fn get_scale_min(is: u32, scales: array) -> vec2 { return vec2(f32(sc), f32(m)); } } - -#enddecl(Q45_K_SCALE_MIN) - -#decl(Q4_K_T) -struct q4_k { +#endif +#ifdef Q4_K_T +struct q4_K { d: f16, dmin: f16, scales: array, qs: array }; -#enddecl(Q4_K_T) +#endif -#decl(Q5_K_T) -struct q5_k { +#ifdef Q5_K_T +struct q5_K { d: f16, dmin: f16, scales: array, qh: array, qs: array }; -#enddecl(Q5_K_T) +#endif -#decl(Q6_K_T) -struct q6_k { +#ifdef Q6_K_T +struct q6_K { ql: array, qh: array, scales: array, d: f16 }; -#enddecl(Q6_K_T) +#endif -#decl(IQ2_XXS_T) +#ifdef IQ2_XXS_T struct iq2_xxs { d: f16, qs: array }; -#enddecl(IQ2_XXS_T) +#endif -#decl(IQ2_XS_T) +#ifdef IQ2_XS_T struct iq2_xs { d: f16, qs: array, scales: array }; -#enddecl(IQ2_XS_T) +#endif -#decl(IQ2_S_T) +#ifdef IQ2_S_T struct iq2_s { d: f16, qs: array, qh: array, scales: array }; -#enddecl(IQ2_S_T) +#endif -#decl(IQ3_XSS_T) +#ifdef IQ3_XXS_T struct iq3_xxs { d: f16, qs: array }; -#enddecl(IQ3_XSS_T) +#endif -#decl(IQ3_S_T) +#ifdef IQ3_S_T struct iq3_s { d: f16, qs: array, @@ -161,41 +156,41 @@ struct iq3_s { signs: array, scales: array }; -#enddecl(IQ3_S_T) +#endif -#decl(IQ1_S_T) +#ifdef IQ1_S_T struct iq1_s { d: f16, qs: array, qh: array }; -#enddecl(IQ1_S_T) +#endif -#decl(IQ1_M_T) +#ifdef IQ1_M_T struct iq1_m { qs: array, qh: array, scales: array }; -#enddecl(IQ1_M_T) +#endif -#decl(IQ4_NL_T) +#ifdef IQ4_NL_T struct iq4_nl { d: f16, qs: array, }; -#enddecl(IQ4_NL_T) +#endif -#decl(IQ4_XS_T) +#ifdef IQ4_XS_T struct iq4_xs { d: f16, scales_h: f16, scales_l: u32, qs: array }; -#enddecl(IQ4_XS_T) +#endif -#decl(IQ23_TABLES) +#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES) const kmask_iq2xs : array = array( 0x08040201u, // 1, 2, 4, 8 0x80402010u // 16, 32, 64, 128 @@ -211,9 +206,9 @@ const ksigns_iq2xs: array = array( 0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c, 0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc ); -#enddecl(IQ23_TABLES) +#endif -#decl(IQ2_XXS_GRID) +#ifdef IQ2_XXS_GRID const iq2xxs_grid = array( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808, @@ -280,9 +275,9 @@ const iq2xxs_grid = array( 0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819, 0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19 ); -#enddecl(IQ2_XXS_GRID) +#endif -#decl(IQ2_XS_GRID) +#ifdef IQ2_XS_GRID const iq2xs_grid = array( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, @@ -413,9 +408,9 @@ const iq2xs_grid = array( 0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19, 0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b ); -#enddecl(IQ2_XS_GRID) +#endif -#decl(IQ2_S_GRID) +#ifdef IQ2_S_GRID const iq2s_grid = array( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, @@ -674,10 +669,9 @@ const iq2s_grid = array( 0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b, 0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b ); -#enddecl(IQ2_S_GRID) - -#decl(IQ3_XSS_GRID) +#endif +#ifdef IQ3_XXS_GRID const iq3xxs_grid = array( 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, @@ -712,10 +706,9 @@ const iq3xxs_grid = array( 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04 ); -#enddecl(IQ3_XSS_GRID) - -#decl(IQ3_S_GRID) +#endif +#ifdef IQ3_S_GRID const iq3s_grid = array( 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, @@ -782,9 +775,9 @@ const iq3s_grid = array( 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101 ); -#enddecl(IQ3_S_GRID) +#endif -#decl(IQ1_GRID) +#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID) const IQ1_DELTA: f32 = 0.125; @@ -919,12 +912,12 @@ const iq1_grid = array( 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 ); -#enddecl(IQ1_GRID) +#endif -#decl(IQ4_GRID) +#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID) const kvalues_iq4nl = array( -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 ); -#enddecl(IQ4_GRID) +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index d61df5bb9e5..8b5cfe715e7 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -56,12 +56,46 @@ def replacer(match): return include_pattern.sub(replacer, shader) -def write_shader(shader_name, shader_code, output_dir, outfile): +def chunk_shader(shader_code, max_chunk_len=60000): + """Split shader_code into safe raw-string sized chunks.""" + return [shader_code[i : i + max_chunk_len] for i in range(0, len(shader_code), max_chunk_len)] + + +def raw_delim(shader_code): + """Pick a raw-string delimiter that does not appear in the shader.""" + delim = "wgsl" + while f"){delim}\"" in shader_code: + delim += "_x" + return delim + + +def write_shader(shader_name, shader_code, output_dir, outfile, input_dir): + shader_code = expand_includes(shader_code, input_dir) + if output_dir: wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl") with open(wgsl_filename, "w", encoding="utf-8") as f_out: f_out.write(shader_code) - outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n') + + delim = raw_delim(shader_code) + chunks = chunk_shader(shader_code) + + if len(chunks) == 1: + outfile.write(f'const char* wgsl_{shader_name} = R"{delim}({shader_code}){delim}";\n\n') + else: + for idx, chunk in enumerate(chunks): + outfile.write(f'static const char wgsl_{shader_name}_part{idx}[] = R"{delim}({chunk}){delim}";\n\n') + outfile.write(f'static const std::string& wgsl_{shader_name}_str() {{\n') + outfile.write(' static const std::string s = []{\n') + outfile.write(' std::string tmp;\n') + outfile.write(f' tmp.reserve({len(shader_code)});\n') + for idx in range(len(chunks)): + outfile.write(f' tmp.append(wgsl_{shader_name}_part{idx});\n') + outfile.write(' return tmp;\n') + outfile.write(' }();\n') + outfile.write(' return s;\n') + outfile.write('}\n') + outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n') def generate_variants(fname, input_dir, output_dir, outfile): @@ -74,7 +108,7 @@ def generate_variants(fname, input_dir, output_dir, outfile): try: variants = ast.literal_eval(extract_block(text, "VARIANTS")) except ValueError: - write_shader(shader_base_name, text, output_dir, outfile) + write_shader(shader_base_name, text, output_dir, outfile, input_dir) else: try: decls_map = parse_decls(extract_block(text, "DECLS")) @@ -123,7 +157,7 @@ def generate_variants(fname, input_dir, output_dir, outfile): output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] else: output_name = shader_base_name - write_shader(output_name, final_shader, output_dir, outfile) + write_shader(output_name, final_shader, output_dir, outfile, input_dir) def main(): @@ -137,7 +171,8 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) with open(args.output_file, "w", encoding="utf-8") as out: - out.write("// Auto-generated shader embedding\n\n") + out.write("// Auto-generated shader embedding\n") + out.write("#include \n\n") for fname in sorted(os.listdir(args.input_dir)): if fname.endswith(".wgsl"): generate_variants(fname, args.input_dir, args.output_dir, out) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl similarity index 83% rename from ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index f80ce1fc550..b10800e36d2 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -1,222 +1,31 @@ -#define(VARIANTS) - -[ - { - "SHADER_SUFFIX": "f32_vec", - "REPLS": { - "TYPE" : "vec4", - "DST_TYPE": "vec4", - "BLOCK_SIZE": 4 - }, - "DECLS": ["F32_VEC"] - }, - { - "REPLS": { - "TYPE" : "f32", - "DST_TYPE": "f32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["F32"] - }, - { - "REPLS": { - "TYPE" : "f16", - "DST_TYPE": "f32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["F16"] - }, - { - "REPLS": { - "TYPE" : "i32", - "DST_TYPE": "i32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["I32"] - }, - { - "REPLS": { - "TYPE" : "q4_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"] - }, - { - "REPLS": { - "TYPE" : "q4_1", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"] - }, - { - "REPLS": { - "TYPE" : "q5_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"] - }, - { - "REPLS": { - "TYPE" : "q5_1", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"] - }, - { - "REPLS": { - "TYPE" : "q8_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"] - }, - { - "REPLS": { - "TYPE" : "q2_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"] - }, - { - "REPLS": { - "TYPE" : "q3_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"] - }, - { - "REPLS": { - "TYPE" : "q4_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"] - }, - { - "REPLS": { - "TYPE" : "q5_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"] - }, - { - "REPLS": { - "TYPE" : "q6_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"] - }, - { - "REPLS": { - "TYPE" : "iq2_xxs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"] - }, - { - "REPLS": { - "TYPE" : "iq2_xs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"] - }, - { - "REPLS": { - "TYPE": "iq2_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"] - }, - { - "REPLS": { - "TYPE": "iq3_xxs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"] - }, - { - "REPLS": { - "TYPE": "iq3_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"] - }, - { - "REPLS": { - "TYPE": "iq1_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"] - }, - { - "REPLS": { - "TYPE": "iq1_m", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"] - }, - { - "REPLS": { - "TYPE": "iq4_nl", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"] - }, - { - "REPLS": { - "TYPE": "iq4_xs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(F32_VEC) +enable f16; +#include "common_decls.tmpl" + +#ifdef F32_VEC fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset]; } -#enddecl(F32_VEC) +#endif -#decl(F32) +#ifdef F32 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[dst_base + offset] = src[src_base + offset]; } -#enddecl(F32) +#endif -#decl(F16) +#ifdef F16 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[dst_base + offset] = f32(src[src_base + offset]); } -#enddecl(F16) +#endif -#decl(I32) +#ifdef I32 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[dst_base + offset] = src[src_base + offset]; } -#enddecl(I32) +#endif -#decl(Q4_0) +#ifdef Q4_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q4_0 = src[src_base + offset]; let d = f32(block_q4_0.d); @@ -232,9 +41,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q4_0) +#endif -#decl(Q4_1) +#ifdef Q4_1 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q4_1 = src[src_base + offset]; let d = f32(block_q4_1.d); @@ -251,9 +60,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q4_1) +#endif -#decl(Q5_0) +#ifdef Q5_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q5_0 = src[src_base + offset]; let d = f32(block_q5_0.d); @@ -272,10 +81,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(Q5_0) - -#decl(Q5_1) +#ifdef Q5_1 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q5_1 = src[src_base + offset]; let d = f32(block_q5_1.d); @@ -294,9 +102,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q5_1) +#endif -#decl(Q8_0) +#ifdef Q8_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q8_0 = src[src_base + offset]; let d = f32(block_q8_0.d); @@ -310,9 +118,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q8_0) +#endif -#decl(Q2_K) +#ifdef Q2_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -340,9 +148,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q2_K) +#endif -#decl(Q3_K) +#ifdef Q3_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -398,9 +206,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q3_K) +#endif -#decl(Q4_K) +#ifdef Q4_K // 8 blocks of 32 elements each fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; @@ -425,9 +233,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q4_K) +#endif -#decl(Q5_K) +#ifdef Q5_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -455,9 +263,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q5_K) +#endif -#decl(Q6_K) +#ifdef Q6_K // 16 blocks of 16 elements each fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; @@ -511,10 +319,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { sc_b_idx += 8; } } +#endif -#enddecl(Q6_K) - -#decl(IQ2_XXS) +#ifdef IQ2_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -536,9 +343,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ2_XXS) +#endif -#decl(IQ2_XS) +#ifdef IQ2_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -568,9 +375,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ2_XS) +#endif -#decl(IQ2_S) +#ifdef IQ2_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -608,10 +415,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(IQ2_S) - -#decl(IQ3_XSS) +#ifdef IQ3_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -638,9 +444,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ3_XSS) +#endif -#decl(IQ3_S) +#ifdef IQ3_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -683,9 +489,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ3_S) +#endif -#decl(IQ1_S) +#ifdef IQ1_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -707,10 +513,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(IQ1_S) - -#decl(IQ1_M) +#ifdef IQ1_M fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; @@ -751,10 +556,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(IQ1_M) - -#decl(IQ4_NL) +#ifdef IQ4_NL fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -770,9 +574,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst_i++; } } -#enddecl(IQ4_NL) +#endif -#decl(IQ4_XS) +#ifdef IQ4_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -791,24 +595,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst_i += 16; } } -#enddecl(IQ4_XS) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -DECLS +#endif @group(0) @binding(0) -var src: array<{{TYPE}}>; +var src: array; @group(0) @binding(1) var idx: array; @group(0) @binding(2) -var dst: array<{{DST_TYPE}}>; +var dst: array; struct Params { offset_src: u32, // in elements @@ -842,8 +638,7 @@ struct Params { @group(0) @binding(3) var params: Params; -override wg_size: u32; -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x >= params.n_rows * params.ne2 * params.ne3) { return; @@ -866,9 +661,8 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3; let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3; - for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) { + for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) { copy_elements(i_src_row, i_dst_row, i); } } -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl similarity index 84% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl index 0f8e6e5ac3d..6aba47317c6 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -1,195 +1,24 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_1", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_1", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"] - }, - { - "REPLS": { - "SRC0_TYPE": "q8_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q2_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q3_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q6_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_xxs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_xs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq3_xxs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq3_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq1_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq1_m", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq4_nl", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq4_xs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(FLOAT) +enable f16; + +#include "common_decls.tmpl" + +#ifdef FLOAT +const BLOCK_SIZE = 1u; + +#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL) +const BLOCK_SIZE = 32u; + +#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS) +const BLOCK_SIZE = 256u; +#endif + +#ifdef FLOAT fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]); } -#enddecl(FLOAT) +#endif -#decl(Q4_0) +#ifdef Q4_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q4_0 = src0[src0_idx_base + offset]; let d = f32(block_q4_0.d); @@ -207,9 +36,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q4_0) +#endif -#decl(Q4_1) +#ifdef Q4_1 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q4_1 = src0[src0_idx_base + offset]; let d = f32(block_q4_1.d); @@ -228,9 +57,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q4_1) +#endif -#decl(Q5_0) +#ifdef Q5_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q5_0 = src0[src0_idx_base + offset]; let d = f32(block_q5_0.d); @@ -251,9 +80,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q5_0) +#endif -#decl(Q5_1) +#ifdef Q5_1 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q5_1 = src0[src0_idx_base + offset]; let d = f32(block_q5_1.d); @@ -274,9 +103,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q5_1) +#endif -#decl(Q8_0) +#ifdef Q8_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q8_0 = src0[src0_idx_base + offset]; let d = f32(block_q8_0.d); @@ -292,9 +121,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q8_0) +#endif -#decl(Q8_1) +#ifdef Q8_1 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q8_1 = src0[src0_idx_base + offset]; let d = f32(block_q8_1.d); @@ -311,9 +140,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q8_1) +#endif -#decl(Q2_K) +#ifdef Q2_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -344,10 +173,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q2_K) - -#decl(Q3_K) +#ifdef Q3_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -406,10 +234,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q3_K) - -#decl(Q4_K) +#ifdef Q4_K // 8 blocks of 32 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -436,10 +263,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q4_K) - -#decl(Q5_K) +#ifdef Q5_K // 8 blocks of 32 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -470,10 +296,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q5_K) - -#decl(Q6_K) +#ifdef Q6_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -529,10 +354,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q6_K) - -#decl(IQ2_XXS) +#ifdef IQ2_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -556,10 +380,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ2_XXS) - -#decl(IQ2_XS) +#ifdef IQ2_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -591,10 +414,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ2_XS) - -#decl(IQ2_S) +#ifdef IQ2_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -634,11 +456,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif - -#enddecl(IQ2_S) - -#decl(IQ3_XSS) +#ifdef IQ3_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -667,10 +487,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ3_XSS) - -#decl(IQ3_S) +#ifdef IQ3_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -715,9 +534,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(IQ3_S) +#endif -#decl(IQ1_S) +#ifdef IQ1_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -741,10 +560,10 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ1_S) -#decl(IQ1_M) +#ifdef IQ1_M fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -787,10 +606,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ1_M) - -#decl(IQ4_NL) +#ifdef IQ4_NL fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -808,10 +626,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ4_NL) - -#decl(IQ4_XS) +#ifdef IQ4_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -832,16 +649,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } - -#enddecl(IQ4_XS) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -DECLS +#endif struct MulMatParams { offset_src0: u32, // in elements/blocks @@ -864,8 +672,8 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) @group(0) @binding(2) var dst: array; // M rows, N columns @group(0) @binding(3) var params: MulMatParams; @@ -898,10 +706,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; var sum = 0.0; - for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) { + for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) { sum += multiply_add(src0_idx_base, src1_idx_base, i); } dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; } - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 109ff8d6159..5c1074ebc10 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -1,58 +1,65 @@ -#decl(SHMEM_VEC) +#ifdef VEC +#define VEC_SIZE 4 +#define SHMEM_TYPE vec4 +#define DST_TYPE vec4 +#define SRC0_TYPE vec4 +#define SRC1_TYPE vec4 + fn store_shmem(val: vec4, idx: u32) { shmem[idx] = val.x; shmem[idx + 1] = val.y; shmem[idx + 2] = val.z; shmem[idx + 3] = val.w; } -#enddecl(SHMEM_VEC) +#endif + +#ifdef SCALAR +#define VEC_SIZE 1 +#define SHMEM_TYPE f16 +#define DST_TYPE f32 +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE -#decl(SHMEM_SCALAR) fn store_shmem(val: f16, idx: u32) { shmem[idx] = val; } -#enddecl(SHMEM_SCALAR) - -#decl(INIT_SRC0_SHMEM_FLOAT) +#endif +#ifdef INIT_SRC0_SHMEM_FLOAT fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { let tile_m = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; let global_m = offset_m + tile_m; let global_k = k_outer + tile_k; let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let src0_val = select( // taking a slight performance hit to avoid oob - {{SRC0_TYPE}}(0.0), - src0[src0_idx/{{VEC_SIZE}}], + SRC0_TYPE(0.0), + src0[src0_idx/VEC_SIZE], global_m < params.m && global_k < params.k); - store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx); + store_shmem(SHMEM_TYPE(src0_val), elem_idx); } } +#endif -#enddecl(INIT_SRC0_SHMEM_FLOAT) - -#decl(INIT_SRC1_SHMEM) - +#ifdef INIT_SRC1_SHMEM_FLOAT fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { let tile_n = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; let global_n = offset_n + tile_n; let global_k = k_outer + tile_k; let src1_idx = batch_offset + global_n * params.stride_11 + global_k; let src1_val = select( - {{SRC1_TYPE}}(0.0), - src1[src1_idx/{{VEC_SIZE}}], + SRC1_TYPE(0.0), + src1[src1_idx/VEC_SIZE], global_n < params.n && global_k < params.k); - store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx); + store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx); } } +#endif -#enddecl(INIT_SRC1_SHMEM) - -#decl(INIT_SRC0_SHMEM_Q4_0) - +#ifdef INIT_SRC0_SHMEM_Q4_0 const BLOCK_SIZE = 32u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; @@ -93,5 +100,4 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } } - -#enddecl(INIT_SRC0_SHMEM_Q4_0) +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl similarity index 55% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl index 6b1dd26cd9e..771e5cd1ee3 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -1,115 +1,19 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32_vec", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(VEC) +enable f16; + +#include "common_decls.tmpl" +#include "mul_mat_decls.tmpl" + +#ifdef VEC fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { return vec4(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); } -#enddecl(VEC) +#endif -#decl(SCALAR) +#ifdef SCALAR fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { return f32(acc[tm][tn]); } -#enddecl(SCALAR) - -#end(DECLS) - -#define(SHADER) -enable f16; +#endif struct MulMatParams { offset_src0: u32, @@ -130,14 +34,12 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) @group(0) @binding(3) var params: MulMatParams; -DECLS - fn get_local_n(thread_id: u32) -> u32 { return thread_id / WORKGROUP_SIZE_M; } @@ -145,18 +47,9 @@ fn get_local_m(thread_id: u32) -> u32 { return thread_id % WORKGROUP_SIZE_M; } -// TILE_M must be multiple of 4 for vec4 loads -const TILE_M = {{WEBGPU_TILE_M}}u; -const TILE_N = {{WEBGPU_TILE_N}}u; - -override WORKGROUP_SIZE_M: u32; -override WORKGROUP_SIZE_N: u32; -override TILE_K: u32; - -override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; -override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; -override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; - +const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; +const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; +const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) @@ -233,15 +126,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var tn = 0u; tn < TILE_N; tn++) { let global_col = output_col_base + tn; if (global_col < params.n) { - for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) { + for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) { let global_row = output_row_base + tm; if (global_row < params.m) { let dst_idx = dst_batch_offset + global_col * params.m + global_row; - dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm); + dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm); } } } } } - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl similarity index 66% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl index 47c8ce36ab3..64529e03cdc 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl @@ -1,100 +1,12 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32_vec", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(VEC) +diagnostic(off, chromium.subgroup_matrix_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; + +#include "common_decls.tmpl" +#include "mul_mat_decls.tmpl" + +#ifdef VEC fn store_dst(shmem_idx: u32, dst_idx: u32) { dst[dst_idx] = vec4( f32(shmem[shmem_idx]), @@ -103,21 +15,13 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) { f32(shmem[shmem_idx + 3]) ); } -#enddecl(VEC) +#endif -#decl(SCALAR) +#ifdef SCALAR fn store_dst(shmem_idx: u32, dst_idx: u32) { dst[dst_idx] = f32(shmem[shmem_idx]); } -#enddecl(SCALAR) - -#end(DECLS) - -#define(SHADER) -diagnostic(off, chromium.subgroup_matrix_uniformity); -enable f16; -enable subgroups; -enable chromium_experimental_subgroup_matrix; +#endif struct MulMatParams { offset_src0: u32, @@ -138,36 +42,19 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) +// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) @group(0) @binding(3) var params: MulMatParams; -DECLS +const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; +const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; -// Note: These are string interpolated at build time, cannot use override constants due to limitations in -// current Dawn version type definitions/matrix load requirements for constant memory sizes. -const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u; -const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u; // For portability we assume the max subgroup size, meaning some subgroups will be masked out if the // runtime subgroup size is smaller. -const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u; - const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N; - -const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u; -const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u; -const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u; - -const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u; -const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u; - -const TILE_K = {{WEBGPU_TILE_K}}u; - -const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; -const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; - const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE; const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; @@ -285,7 +172,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; - for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + for (var idx = thread_id * VEC_SIZE; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { let local_row = idx % WG_TILE_STRIDE; let local_col = idx / WG_TILE_STRIDE; @@ -294,9 +181,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_col < params.n && global_row < params.m) { let dst_idx = dst_batch_offset + global_col * params.m + global_row; - store_dst(idx, dst_idx/{{VEC_SIZE}}); + store_dst(idx, dst_idx/VEC_SIZE); } } } -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl similarity index 61% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index ffbb6403285..f9ea95e07b9 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -1,84 +1,17 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE": "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE": "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE": "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(VEC) -fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { - return f32(dot({{SRC1_TYPE}}(src0_val), src1_val)); + +enable f16; + +#include "common_decls.tmpl" + +#ifdef VEC + +#define VEC_SIZE 4 +#define DST_TYPE vec4 +#define SRC0_TYPE vec4 +#define SRC1_TYPE vec4 + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(dot(SRC1_TYPE(src0_val), src1_val)); } fn store_val(group_base: u32) -> vec4 { @@ -87,33 +20,37 @@ fn store_val(group_base: u32) -> vec4 { partial_sums[group_base + THREADS_PER_OUTPUT * 2], partial_sums[group_base + THREADS_PER_OUTPUT * 3]); } -#enddecl(VEC) +#endif + +#ifdef SCALAR + +#define VEC_SIZE 1 +#define DST_TYPE f32 +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE -#decl(SCALAR) -fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { return f32(src0_val) * f32(src1_val); } fn store_val(group_base: u32) -> f32 { return partial_sums[group_base]; } -#enddecl(SCALAR) - -#decl(MUL_ACC_FLOAT) +#endif +#ifdef MUL_ACC_FLOAT fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { var local_sum = 0.0; - for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) { - let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}]; - let b = shared_vector[i / {{VEC_SIZE}}]; + for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) { + let a = src0[(idx_base + k_outer + i) / VEC_SIZE]; + let b = shared_vector[i / VEC_SIZE]; local_sum += inner_dot(a, b); } return local_sum; } +#endif -#enddecl(MUL_ACC_FLOAT) - -#decl(MUL_ACC_Q4_0) +#ifdef MUL_ACC_Q4_0 const BLOCK_SIZE = 32; const NQ = 16u; // number of weights per thread @@ -145,15 +82,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { } return local_sum; } - -#enddecl(MUL_ACC_Q4_0) - -#end(DECLS) - -#define(SHADER) -enable f16; - -DECLS +#endif struct MulMatParams { offset_src0: u32, @@ -174,22 +103,20 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // Matrix (M x K) -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed) -@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // Result vector (transposed) +// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) @group(0) @binding(3) var params: MulMatParams; -override WORKGROUP_SIZE: u32; -override TILE_K: u32; -override OUTPUTS_PER_WG: u32; -override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG; +const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG; // Shared memory for collaborative loading and reduction -var shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile -var partial_sums: array; // For reduction +var shared_vector: array; // Cache vector tile +var partial_sums: array; // For reduction -@compute @workgroup_size(WORKGROUP_SIZE) +@compute @workgroup_size(WG_SIZE) fn main( @builtin(local_invocation_id) local_id: vec3, @builtin(workgroup_id) wg_id: vec3, @@ -232,8 +159,8 @@ fn main( let tile_size = min(TILE_K, params.k - k_tile); // Cooperatively load vector tile into shared memory (all threads) - for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) { - shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}]; + for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) { + shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE]; } workgroupBarrier(); @@ -250,7 +177,7 @@ fn main( workgroupBarrier(); let group_base = thread_group * THREADS_PER_OUTPUT; let thread_base = group_base + thread_in_group; - var offset = THREADS_PER_OUTPUT / 2; + var offset: u32 = THREADS_PER_OUTPUT / 2; while (offset > 0) { if (thread_in_group < offset) { partial_sums[thread_base] += partial_sums[thread_base + offset]; @@ -260,8 +187,8 @@ fn main( } // Store back to global memory - if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) { - dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base); + if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) { + dst[dst_idx / VEC_SIZE] = store_val(group_base); } } -#end(SHADER) + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl similarity index 78% rename from ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl index 040e80dfea2..3b70a876d70 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl @@ -1,44 +1,21 @@ -#define(VARIANTS) - -[ - { - "SHADER_NAME": "scale_f32", - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "scale_f32_inplace", - "DECLS": ["INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(NOT_INPLACE) +#ifdef INPLACE @group(0) @binding(1) -var dst: array; - -@group(0) @binding(2) var params: Params; fn store_scale(val: f32, offset: u32) { - dst[offset] = val; + src[offset] = val; } -#enddecl(NOT_INPLACE) - -#decl(INPLACE) +#else @group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) var params: Params; fn store_scale(val: f32, offset: u32) { - src[offset] = val; + dst[offset] = val; } -#enddecl(INPLACE) - -#end(DECLS) - -#define(SHADER) +#endif struct Params { offset_src: u32, @@ -65,10 +42,7 @@ struct Params { @group(0) @binding(0) var src: array; -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x >= params.ne) { return; @@ -87,4 +61,3 @@ fn main(@builtin(global_invocation_id) gid: vec3) { store_scale(src[i_src] * params.scale + params.bias, i_dst); } -#end(SHADER) From 8b3a52ba871d092e619835b0bf844d9f11e3a6c8 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 18 Feb 2026 16:06:29 -0700 Subject: [PATCH 173/831] ggml webgpu: Fix bug in dispatching large matrix-vector multiplication (llama/19535) * Fix bug in dispatching large matrix-vector multiplication --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 17bb2f47126..b5fee480562 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1121,7 +1121,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); uint32_t total_wg = output_groups * batches; - wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + // TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups + wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); } else if (use_fast) { auto decisions = static_cast(pipeline.context.get()); From cc9e5cf89d5e324e4974c62ef8ccd770ffc42252 Mon Sep 17 00:00:00 2001 From: shalinib-ibm Date: Thu, 19 Feb 2026 11:58:53 +0530 Subject: [PATCH 174/831] llamafile: powerpc: add FP16 MMA path for Q4/Q8 matmul (llama/19709) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Avoid xvi8ger4pp signed→unsigned bias correction by dequantizing Q4/Q8 inputs to FP16 and using FP16×FP16→FP32 MMA. This removes post-processing overhead and improves performance. Performance Impact: 1.5 ~ 2x improvement in PP_Speed for Q4 and Q8 Models, measured with llama-bench and llama-batched-bench. Q8 Model: granite-4.0-h-micro-Q8_0.gguf (from huggingface) Q4 Model: Meta-Llama3-8b Q4 model (generated with llama-quantize from f32 model) llama-bench Q8 Model Results: model                                size     params backend    threads             test Base t/s Patch t/s granitehybrid 3B Q8_0            3.16 GiB     3.19 B CPU              10             pp8         64.48 ± 4.72         73.99 ± 0.27 granitehybrid 3B Q8_0            3.16 GiB     3.19 B CPU              10             pp16         80.11 ± 0.32         112.53 ± 0.40 granitehybrid 3B Q8_0            3.16 GiB     3.19 B CPU              10             pp32         89.10 ± 0.27         152.95 ± 0.68 granitehybrid 3B Q8_0            3.16 GiB     3.19 B CPU              10             pp64         93.65 ± 0.25         187.83 ± 0.83 granitehybrid 3B Q8_0            3.16 GiB     3.19 B CPU              10           pp128         99.93 ± 0.02         201.32 ± 0.11 granitehybrid 3B Q8_0            3.16 GiB     3.19 B CPU              10           pp256         102.32 ± 0.40         208.32 ± 0.41 granitehybrid 3B Q8_0            3.16 GiB     3.19 B CPU              10           pp512         103.42 ± 0.40         209.98 ± 0.14 granitehybrid 3B Q8_0            3.16 GiB     3.19 B CPU              10           tg128         20.35 ± 0.01         19.57 ± 0.01 llama-bench Q4 Model Results: model                                size     params backend    threads             test               Base    t/s                Patch   t/s llama 8B Q4_0                    4.33 GiB     8.03 B CPU              10             pp8         34.77 ± 0.10         41.23 ± 0.08 llama 8B Q4_0                    4.33 GiB     8.03 B CPU              10             pp16         40.81 ± 0.04         64.55 ± 0.15 llama 8B Q4_0                    4.33 GiB     8.03 B CPU              10             pp32         44.65 ± 0.05         90.84 ± 0.22 llama 8B Q4_0                    4.33 GiB     8.03 B CPU              10             pp64         47.49 ± 0.03         114.39 ± 0.11 llama 8B Q4_0                    4.33 GiB     8.03 B CPU              10           pp128         49.29 ± 0.24         120.13 ± 0.19 llama 8B Q4_0                    4.33 GiB     8.03 B CPU              10           pp256         49.77 ± 0.23         121.51 ± 0.11 llama 8B Q4_0                    4.33 GiB     8.03 B CPU              10           pp512         49.89 ± 0.23         117.52 ± 0.10 llama 8B Q4_0                    4.33 GiB     8.03 B CPU              10           tg128         13.40 ± 0.01         13.37 ± 0.00 Llama perplexity Results: Model Base Final PPL Estimate Patch Final PPL Estimate granite-4.0-h-micro-Q8_0 1.3862 +/- 0.04424 1.3868 +/- 0.04432 Meta-Llama3-8b Q4 1.3801 +/- 0.04116 1.3803 +/- 0.04116 Signed-off-by: Shalini.Salomi.Bodapati --- ggml/src/ggml-cpu/llamafile/sgemm-ppc.h | 333 ------------ ggml/src/ggml-cpu/llamafile/sgemm.cpp | 664 ++++++++++++++++++------ 2 files changed, 508 insertions(+), 489 deletions(-) delete mode 100644 ggml/src/ggml-cpu/llamafile/sgemm-ppc.h diff --git a/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h b/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h deleted file mode 100644 index a7078687288..00000000000 --- a/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +++ /dev/null @@ -1,333 +0,0 @@ -#pragma once - -typedef vector unsigned char vec_t; -typedef __vector_quad acc_t; - -template -class tinyBLAS_Q0_PPC { - public: - tinyBLAS_Q0_PPC(int64_t k, - const TA *A, int64_t lda, - const block_q8_0 *B, int64_t ldb, - float *C, int64_t ldc, - int ith, int nth); - - void matmul(int64_t m, int64_t n); - void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) { - vec_t A_pack[mc*kc*2]; - vec_t B_pack[nc*kc*2]; - int comparray[mc*kc]; - constexpr bool is_Ablock_q4 = std::is_same_v; - int64_t ytiles = m / mc; - int64_t xtiles = n / nc; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; - if (end > tiles) { - end = tiles; - } - for (int64_t job = start; job < end; ++job) { - int64_t ii = (job / xtiles) * mc; - int64_t jj = (job % xtiles) * nc; - for (int64_t kk = 0; kk < k; kk += kc) { - if constexpr(is_Ablock_q4) { - packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray); - } else { - packNormal_large(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray); - } - packNormal_large(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true); - KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray); - } - } - } - - private: - inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { - for (int I = 0; I < RM; I++) { - for (int J = 0; J < RN; J++) { - *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J); - } - } - } - - inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { - for (int I = 0; I < RM; I++) { - for (int J = 0; J < RN; J++) { - float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I); - *c_ptr += *((float*)&fin_res[idx+I]+J); - } - } - } - - template - inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) { - vector signed int vec_C[4]; - vector float CA[4] = {0}; - vector float res[4] = {0}; - __builtin_mma_disassemble_acc(vec_C, ACC); - for (int i = 0; i < 4; i++) { - CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0)); - res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); - fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]); - } - } - - inline void process_q4_elements(vector signed char (&c)[2], int* ca) { - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector signed char v8 = vec_splats((signed char)0x8); - vector signed int vsum = {0}; - vector signed int vsum2 = {0}; - c[0] = vec_and(c[1], lowMask); - c[1] = vec_sr(c[1], v4); - c[0] = vec_sub(c[0], v8); - c[1] = vec_sub(c[1], v8); - vsum = vec_sum4s(c[0], vsum); - vsum2 = vec_sum4s(c[1], vsum2); - vsum = vec_add(vsum, vsum2); - *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - } - - template - inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) { - vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; - vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; - vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; - vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - V2 t1, t2, t3, t4, t5, t6, t7, t8; - vector unsigned char xor_vector; - uint8_t flip_vec = 0x80; - xor_vector = vec_splats(flip_vec); - t1 = vec_perm(s1, s2, swiz1); - t2 = vec_perm(s1, s2, swiz2); - t3 = vec_perm(s3, s4, swiz1); - t4 = vec_perm(s3, s4, swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - } - - template - inline void kernel(int64_t ii, int64_t jj) { - if constexpr(RM == 4 && RN == 8) { - KERNEL_4x8(ii,jj); - } else if constexpr(RM == 8 && RN == 4) { - KERNEL_8x4(ii,jj); - } else if constexpr(RM == 8 && RN == 8) { - KERNEL_8x8(ii,jj); - } else { - assert(false && "RN/RM values not supported"); - } - } - template - void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray); - template - void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip); - void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n); - void KERNEL_4x8(int64_t ii, int64_t jj); - void KERNEL_8x4(int64_t ii, int64_t jj); - void KERNEL_8x8(int64_t ii, int64_t jj); - void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN); - template - void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n); - - void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){ - for (int I = 0; I<8; I++) { - float a_scale = unhalf((A+((ii+I)*lda)+blk)->d); - for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d)); - *((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d)); - } - } - } - - inline void process_q8_elements(const int8_t *qs, int *ca) { - vector signed char c1 = vec_xl(0, qs); - vector signed char c2 = vec_xl(16, qs); - vector signed int vsum1 = {0}; - vector signed int vsum2 = {0}; - vsum1 = vec_sum4s(c1, vsum1); - vsum2 = vec_sum4s(c2, vsum2); - vector signed int vsum = vec_add(vsum1, vsum2); - *ca = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - } - - template - void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) { - int64_t i, j; - block_q8_0 *aoffset = NULL; - VA *vecOffset = NULL; - block_q8_0* aoffsets[8]; - __vector_pair arr[8]; - VB c[8][2] = {0}; - VB c1[8] = {0}; VB c2[8] = {0}; - aoffset = const_cast(a); - vecOffset = vec; - j = (rows >> 3); - int index = 0; - if (j > 0) { - do { - for (int it = 0; it < 8; it++) - aoffsets[it] = aoffset + it*lda; - aoffset += 8 * lda; - for (int blk = 0; blk < kc; blk++) { - for (int it = 0; it < 8; it++) { - arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs); - __builtin_vsx_disassemble_pair(c[it], &arr[it]); - c1[it] = c[it][0]; - c2[it] = c[it][1]; - if (comparray){ - process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]); - } - } - vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); - vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip); - vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip); - vecOffset += 256; - } - j--; - index += 8*kc; - } while(j > 0); - } - - } - - void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) { - int64_t i, j; - TA *aoffset = NULL; - int8_t *vecOffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; - vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - aoffset = const_cast(a); - vecOffset = vec; - int index = 0; - j = (rows >> 3); - if (j > 0) { - do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; - aoffset += 8 * lda; - for (int blk = 0; blk < kc; blk++) { - c1[1] = reinterpret_cast(vec_xl(0, (aoffset1+blk)->qs)); - c2[1] = reinterpret_cast(vec_xl(0, (aoffset2+blk)->qs)); - c3[1] = reinterpret_cast(vec_xl(0, (aoffset3+blk)->qs)); - c4[1] = reinterpret_cast(vec_xl(0, (aoffset4+blk)->qs)); - c5[1] = reinterpret_cast(vec_xl(0, (aoffset5+blk)->qs)); - c6[1] = reinterpret_cast(vec_xl(0, (aoffset6+blk)->qs)); - c7[1] = reinterpret_cast(vec_xl(0, (aoffset7+blk)->qs)); - c8[1] = reinterpret_cast(vec_xl(0, (aoffset8+blk)->qs)); - - process_q4_elements(c1, &comparray[index + 8*blk+0]); - process_q4_elements(c2, &comparray[index + 8*blk+1]); - process_q4_elements(c3, &comparray[index + 8*blk+2]); - process_q4_elements(c4, &comparray[index + 8*blk+3]); - process_q4_elements(c5, &comparray[index + 8*blk+4]); - process_q4_elements(c6, &comparray[index + 8*blk+5]); - process_q4_elements(c7, &comparray[index + 8*blk+6]); - process_q4_elements(c8, &comparray[index + 8*blk+7]); - vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); - vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false); - vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false); - vecOffset += 256; - } - j--; - index += 8*kc; - } while (j > 0); - } - } - - void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) { - acc_t acc[8]; - for (int i = 0; i < mc ; i += 8) { - for (int j = 0; j < nc; j += 8) { - vector float fin_res[16] = {0}; - vector float vs[16] = {0}; - for (int64_t kk = 0; kk < kc; kk+=2) { - for (int x = 0; x < 8; x++) { - __builtin_mma_xxsetaccz(&acc[x]); - } - int A_block_idx = (i/8)*(16*kc) + kk*16; - int B_block_idx = (j/8)*(16*kc)+ kk*16; - vec_t *A_block = &vec_A[A_block_idx]; - vec_t *B_block = &vec_B[B_block_idx]; - for (int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]); - __builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]); - } - compute_scale(ii+i, jj+j, l+kk, vs); - int c_index = (i/8)*(8*kc)+ kk*8; - int* c_block = &comparray[c_index]; - compute(&acc[0], 0, 0, c_block, vs, fin_res); - compute(&acc[1], 4, 4, c_block, vs, fin_res); - compute(&acc[2], 0, 8, c_block, vs, fin_res); - compute(&acc[3], 4, 12, c_block, vs, fin_res); - - A_block_idx = (i/8)*(16*kc) + (kk+1)*16; - B_block_idx = (j/8)*(16*kc)+ (kk+1)*16; - A_block = &vec_A[A_block_idx]; - B_block = &vec_B[B_block_idx]; - for (int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]); - __builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]); - } - compute_scale(ii+i, jj+j, l+kk+1, vs); - c_index = (i/8)*(8*kc)+ (kk+1)*8; - c_block = &comparray[c_index]; - compute(&acc[4], 0, 0, c_block, vs, fin_res); - compute(&acc[5], 4, 4, c_block, vs, fin_res); - compute(&acc[6], 0, 8, c_block, vs, fin_res); - compute(&acc[7], 4, 12, c_block, vs, fin_res); - - } - if (l == 0) { - save_res(ii+i, jj+j, 0, fin_res); - save_res(ii+i+4, jj+j, 4, fin_res); - save_res(ii+i, jj+j+4, 8, fin_res); - save_res(ii+i+4, jj+j+4, 12, fin_res); - } else { - add_save_res(ii+i, jj+j, 0, fin_res); - add_save_res(ii+i+4, jj+j, 4, fin_res); - add_save_res(ii+i, jj+j+4, 8, fin_res); - add_save_res(ii+i+4, jj+j+4, 12, fin_res); - } - } - } - } - - const TA *const A; - const block_q8_0 *const B; - float *C; - const int64_t k; - int64_t kc; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; - const int ith; - const int nth; -}; diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 8f980c16b96..da412fd009b 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -121,7 +121,8 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); } #endif #if defined(__MMA__) -#include "sgemm-ppc.h" +typedef vector unsigned char vec_t; +typedef __vector_quad acc_t; #endif //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED FUSED MULTIPLY ADD @@ -2153,7 +2154,7 @@ class tinyBLAS_HP16_PPC { packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); - mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_1, vec_A[x+4], vec_B[x]); } } SAVE_ACC(&acc_0, ii, jj); @@ -2301,43 +2302,299 @@ class tinyBLAS_HP16_PPC { const int nth; }; - template - tinyBLAS_Q0_PPC::tinyBLAS_Q0_PPC(int64_t k, - const TA *A, int64_t lda, - const block_q8_0 *B, int64_t ldb, - float *C, int64_t ldc, - int ith, int nth) +template +class tinyBLAS_Q0_PPC { + public: + tinyBLAS_Q0_PPC(int64_t k, + const TA * A, int64_t lda, + const block_q8_0 * B, int64_t ldb, + float * C, int64_t ldc, + int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { - kc = 64; } - template - void tinyBLAS_Q0_PPC::matmul(int64_t m, int64_t n) { - int mc = 64; int nc = 64; - if (n % 8 == 0 && n < nc) { - nc = n; - mc = 32 ; - kc = 32; - } - const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0); - if (is_aligned) { - this->matmul_tiled_q0(m, n, mc, nc, kc); + void matmul(int64_t m, int64_t n) { + const int64_t mc = 64; + const int64_t kc = 64; + int64_t nc = 64; + int64_t n_aligned = 0; + if (n % 64 == 0) { + n_aligned = n; + } else if (n == 4) { + n_aligned = 4; + } else if (n < 64) { + n_aligned = (n / 8) * 8; + } else { + n_aligned = (n / 64) * 64; + } + + if (n_aligned > 0) { + if (n_aligned % 64 == 0) nc = 64; + else if (n_aligned == n) nc = n; + else if (n_aligned % 32 == 0) nc = 32; + else if (n_aligned % 24 == 0) nc = 24; + else if (n_aligned % 16 == 0) nc = 16; + else nc = 8; + } + bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0); + if (can_use_tiled) { + matmul_tiled(m, n_aligned, mc, nc, kc); + if (n > n_aligned) { + mnpack(0, m, n_aligned, n); + } } else { mnpack(0, m, 0, n); } } - template - template - void tinyBLAS_Q0_PPC::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray) { + private: + inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) { + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J); + } + } + } + + inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) { + vec_t vec_C[4]; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int I = 0; I < 4; I++) { + for (int J = 0; J < 4; J++) { + *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J); + } + } + } + + inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) { + vec_t vec_C[4]; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int I = 0; I < 4; I++) { + for (int J = 0; J < 4; J++) { + float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I); + *c_ptr += *((float *)&vec_C[I] + J); + } + } + } + + template + inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) { + vector signed int vec_C[4]; + vector float CA[4] = {0}; + vector float res[4] = {0}; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int i = 0; i < 4; i++) { + CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0)); + res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); + fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]); + } + } + + inline void process_q4_elements(vector signed char (&c)[2], int * ca) { + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector signed char v8 = vec_splats((signed char)0x8); + vector signed int vsum = {0}; + vector signed int vsum2 = {0}; + c[0] = vec_and(c[1], lowMask); + c[1] = vec_sr(c[1], v4); + c[0] = vec_sub(c[0], v8); + c[1] = vec_sub(c[1], v8); + vsum = vec_sum4s(c[0], vsum); + vsum2 = vec_sum4s(c[1], vsum2); + vsum = vec_add(vsum, vsum2); + *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + } + + template + inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) { + vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + V2 t1, t2, t3, t4, t5, t6, t7, t8; + vector unsigned char xor_vector; + uint8_t flip_vec = 0x80; + xor_vector = vec_splats(flip_vec); + t1 = vec_perm(s1, s2, swiz1); + t2 = vec_perm(s1, s2, swiz2); + t3 = vec_perm(s3, s4, swiz1); + t4 = vec_perm(s3, s4, swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 16); + vec_xst(t7, 0, vecOffset + 32); + vec_xst(t8, 0, vecOffset + 48); + } + + inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) { + const vector signed char lowMask = vec_splats((signed char)0x0F); + const vector signed char v8 = vec_splats((signed char)0x08); + const vector unsigned char v4 = vec_splats((unsigned char)4); + lo = vec_and(packed, lowMask); + hi = vec_sr(packed, v4); + lo = vec_sub(lo, v8); + hi = vec_sub(hi, v8); + } + + inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) { + vec_t t[8], s[8]; + vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23}; + vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; + vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + for (int i = 0; i < 4; i += 2) { + t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1); + t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2); + } + for (int i = 4; i < 8; i += 2) { + t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1); + t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2); + } + s[0] = vec_perm(t[0], t[2], swiz3); + s[1] = vec_perm(t[0], t[2], swiz4); + s[2] = vec_perm(t[1], t[3], swiz3); + s[3] = vec_perm(t[1], t[3], swiz4); + s[4] = vec_perm(t[4], t[6], swiz3); + s[5] = vec_perm(t[4], t[6], swiz4); + s[6] = vec_perm(t[5], t[7], swiz3); + s[7] = vec_perm(t[5], t[7], swiz4); + for (int i = 0; i < 8; ++i) { + vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16)); + } + } + + static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) { + vector signed short i16_hi = vec_unpackh(raw); + vector signed short i16_lo = vec_unpackl(raw); + + vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0); + vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0); + vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0); + vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0); + out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale)); + out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale)); + } + + void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) { + unsigned char * vecOffset = vec; + for (int i = 0; i < rows; i += 8) { + const block_q4_0 * rows_base[8]; + for (int r = 0; r < 8; r++) { + rows_base[r] = a + (i + r) * lda; + } + for (int blk = 0; blk < blocks; blk++) { + vector unsigned short hp_res[8][4]; + for (int r = 0; r < 8; r++) { + const block_q4_0 * current_blk = rows_base[r] + blk; + vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d)); + vector signed char v_qs = reinterpret_cast(vec_xl(0, current_blk->qs)); + vector signed char c1, c2; + unpack_q4_to_q8(v_qs, c1, c2); + convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]); + convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]); + } + for (int c = 0; c < 4; c++) { + vector unsigned char c_arr[8]; + for (int r = 0; r < 8; r++) { + c_arr[r] = (vector unsigned char)hp_res[r][c]; + } + vector_permute_store_fp16((vec_t *)c_arr, vecOffset); + vecOffset += 128; + } + } + } + } + + template + static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) { + unsigned char * vecOffset = vec; + const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23}; + const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; + const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + + for (int i = 0; i < rows; i += chunk_size) { + const block_q8_0 * rows_base[chunk_size]; + for (int r = 0; r < chunk_size; r++) { + rows_base[r] = a + (i + r) * lda; + } + for (int blk = 0; blk < blocks; blk++) { + vector unsigned short hp_res[chunk_size][4]; + for (int r = 0; r < chunk_size; r++) { + const block_q8_0 * b = rows_base[r] + blk; + vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d)); + vector signed char c[2]; + __vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs); + __builtin_vsx_disassemble_pair(c, & pair); + convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]); + convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]); + } + for (int col = 0; col < 4; col++) { + if constexpr (chunk_size == 8) { + vec_t t[8]; + t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1); + t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2); + t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1); + t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2); + t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1); + t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2); + t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1); + t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2); + + vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0)); + vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16)); + vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32)); + vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48)); + vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64)); + vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80)); + vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96)); + vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112)); + vecOffset += 128; + } else { + vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1); + vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2); + vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1); + vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2); + + vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0)); + vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16)); + vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32)); + vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48)); + vecOffset += 64; + } + } + } + } + } + + void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) { + if (rows == 4) { + pack_q8_block<4>(a, lda, rows, blocks, vec); + } else { + pack_q8_block<8>(a, lda, rows, blocks, vec); + } + } + + template + void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array & comparray) { int64_t i, j; - TA *aoffset = NULL; - int8_t *vecOffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + TA * aoffset = NULL; + int8_t * vecOffset = NULL; + TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL; + TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL; vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - aoffset = const_cast(a); + aoffset = const_cast(a); vecOffset = vec; j = (rows >> 3); if (j > 0) { @@ -2363,18 +2620,18 @@ class tinyBLAS_HP16_PPC { c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); - process_q4_elements(c1, &comparray[0]); - process_q4_elements(c2, &comparray[1]); - process_q4_elements(c3, &comparray[2]); - process_q4_elements(c4, &comparray[3]); - process_q4_elements(c5, &comparray[4]); - process_q4_elements(c6, &comparray[5]); - process_q4_elements(c7, &comparray[6]); - process_q4_elements(c8, &comparray[7]); + process_q4_elements(c1, & comparray[0]); + process_q4_elements(c2, & comparray[1]); + process_q4_elements(c3, & comparray[2]); + process_q4_elements(c4, & comparray[3]); + process_q4_elements(c5, & comparray[4]); + process_q4_elements(c6, & comparray[5]); + process_q4_elements(c7, & comparray[6]); + process_q4_elements(c8, & comparray[7]); vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); - vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false); - vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false); + vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false); + vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2405,12 +2662,12 @@ class tinyBLAS_HP16_PPC { c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); - process_q4_elements(c1, &comparray[0]); - process_q4_elements(c2, &comparray[1]); - process_q4_elements(c3, &comparray[2]); - process_q4_elements(c4, &comparray[3]); + process_q4_elements(c1, & comparray[0]); + process_q4_elements(c2, & comparray[1]); + process_q4_elements(c3, & comparray[2]); + process_q4_elements(c4, & comparray[3]); vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2434,12 +2691,12 @@ class tinyBLAS_HP16_PPC { case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); break; } - process_q4_elements(c1, &comparray[0]); - process_q4_elements(c2, &comparray[1]); - process_q4_elements(c3, &comparray[2]); - process_q4_elements(c4, &comparray[3]); + process_q4_elements(c1, & comparray[0]); + process_q4_elements(c2, & comparray[1]); + process_q4_elements(c3, & comparray[2]); + process_q4_elements(c4, & comparray[3]); vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2450,39 +2707,38 @@ class tinyBLAS_HP16_PPC { } } - template template - void tinyBLAS_Q0_PPC::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) { int64_t i, j; - block_q8_0 *aoffset = NULL; - VA *vecOffset = NULL; - block_q8_0* aoffsets[8]; + block_q8_0 * aoffset = NULL; + VA * vecOffset = NULL; + block_q8_0 * aoffsets[8]; __vector_pair arr[8]; VB c[8][2] = {0}; VB c1[8] = {0}; VB c2[8] = {0}; - aoffset = const_cast(a); + aoffset = const_cast(a); vecOffset = vec; j = (rows >> 3); if (j > 0) { do { aoffsets[0] = aoffset; for (int it = 1; it < 8; it++) - aoffsets[it] = aoffsets[it-1] + lda; + aoffsets[it] = aoffsets[it - 1] + lda; aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { do { for (int it = 0; it < 8; it++) { - arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); - __builtin_vsx_disassemble_pair(c[it], &arr[it]); + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], & arr[it]); c1[it] = c[it][0]; c2[it] = c[it][1]; } vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); - vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip); - vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip); + vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip); + vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip); for (int it = 0; it < 8; it++) aoffsets[it] += lda; vecOffset += 256; @@ -2501,13 +2757,13 @@ class tinyBLAS_HP16_PPC { if (i > 0) { do { for (int it = 0; it < 4; it++) { - arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); - __builtin_vsx_disassemble_pair(c[it], &arr[it]); + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], & arr[it]); c1[it] = c[it][0]; c2[it] = c[it][1]; } vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip); for (int it = 0; it < 4; it++) { aoffsets[it] += lda; } @@ -2520,24 +2776,24 @@ class tinyBLAS_HP16_PPC { if (rows & 3) { aoffsets[0] = aoffset; for (int it = 1; it < 3; it++ ) - aoffsets[it] = aoffsets[it-1] + lda; + aoffsets[it] = aoffsets[it - 1] + lda; i = (cols >> 3); if (i > 0) { do { switch(rows) { - case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs); - __builtin_vsx_disassemble_pair(c[2], &arr[2]); + case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs); + __builtin_vsx_disassemble_pair(c[2], & arr[2]); c1[2] = c[2][0]; c2[2] = c[2][1]; - case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs); - __builtin_vsx_disassemble_pair(c[1], &arr[1]); + case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs); + __builtin_vsx_disassemble_pair(c[1], & arr[1]); c1[1] = c[1][0]; c2[1] = c[1][1]; - case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs); - __builtin_vsx_disassemble_pair(c[0], &arr[0]); + case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs); + __builtin_vsx_disassemble_pair(c[0], & arr[0]); c1[0] = c[0][0]; c2[0] = c[0][1]; break; } vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip); for (int it = 0; it < 3; it++) aoffsets[it] += lda; vecOffset += 128; @@ -2547,8 +2803,7 @@ class tinyBLAS_HP16_PPC { } } - template - void tinyBLAS_Q0_PPC::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { int m_rem = MIN(m - m0, 16); int n_rem = MIN(n - n0, 16); @@ -2585,8 +2840,7 @@ class tinyBLAS_HP16_PPC { } - template - void tinyBLAS_Q0_PPC::KERNEL_4x8(int64_t ii, int64_t jj) { + void KERNEL_4x8(int64_t ii, int64_t jj) { vec_t vec_A[8], vec_B[16] = {0}; acc_t acc_0, acc_1; std::array comparray {}; @@ -2594,26 +2848,26 @@ class tinyBLAS_HP16_PPC { vector float vs[8] = {0}; bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { - __builtin_mma_xxsetaccz(&acc_0); - __builtin_mma_xxsetaccz(&acc_1); + __builtin_mma_xxsetaccz(& acc_0); + __builtin_mma_xxsetaccz(& acc_1); if (std::is_same_v) { - packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray); } else { - packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false); } - packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); + packNormal((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true); for(int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]); } for (int I = 0; I<4; I++) { for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); - *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); + *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); + *((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d)); } } if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < 4; i++) { comparray[i] = 0; int ca = 0; @@ -2624,15 +2878,14 @@ class tinyBLAS_HP16_PPC { aoffset += lda; } } - compute(&acc_0, 0, 0, comparray, vs, fin_res); - compute(&acc_1, 0, 4, comparray, vs, fin_res); + compute(& acc_0, 0, 0, comparray, vs, fin_res); + compute(& acc_1, 0, 4, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); - save_res(ii, jj+4, 4, fin_res); + save_res(ii, jj + 4, 4, fin_res); } - template - void tinyBLAS_Q0_PPC::KERNEL_8x4(int64_t ii, int64_t jj) { + void KERNEL_8x4(int64_t ii, int64_t jj) { vec_t vec_A[16], vec_B[8] = {0}; acc_t acc_0, acc_1; std::array comparray {}; @@ -2640,25 +2893,25 @@ class tinyBLAS_HP16_PPC { vector float vs[8] = {0}; bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { - __builtin_mma_xxsetaccz(&acc_0); - __builtin_mma_xxsetaccz(&acc_1); + __builtin_mma_xxsetaccz(& acc_0); + __builtin_mma_xxsetaccz(& acc_1); if (std::is_same_v) { - packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray); } else { - packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false); } - packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true); + packNormal((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true); for(int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]); } - for (int I = 0; I<8; I++) { - for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); + for (int I = 0; I < 8; I++) { + for (int J = 0; J < 4; J++) { + *((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); } } if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < 8; i++) { comparray[i] = 0; int ca = 0; @@ -2669,15 +2922,14 @@ class tinyBLAS_HP16_PPC { aoffset += lda; } } - compute(&acc_0, 0, 0, comparray, vs, fin_res); - compute(&acc_1, 4, 4, comparray, vs, fin_res); + compute(& acc_0, 0, 0, comparray, vs, fin_res); + compute(& acc_1, 4, 4, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); - save_res(ii+4, jj, 4, fin_res); + save_res(ii + 4, jj, 4, fin_res); } - template - void tinyBLAS_Q0_PPC::KERNEL_8x8(int64_t ii, int64_t jj) { + void KERNEL_8x8(int64_t ii, int64_t jj) { vec_t vec_A[16], vec_B[16] = {0}; acc_t acc_0, acc_1, acc_2, acc_3; acc_t acc_4, acc_5, acc_6, acc_7; @@ -2686,30 +2938,30 @@ class tinyBLAS_HP16_PPC { vector float vs[16] = {0}; bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { - __builtin_mma_xxsetaccz(&acc_0); - __builtin_mma_xxsetaccz(&acc_1); - __builtin_mma_xxsetaccz(&acc_2); - __builtin_mma_xxsetaccz(&acc_3); + __builtin_mma_xxsetaccz(& acc_0); + __builtin_mma_xxsetaccz(& acc_1); + __builtin_mma_xxsetaccz(& acc_2); + __builtin_mma_xxsetaccz(& acc_3); if (std::is_same_v) { - packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray); } else { - packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false); } - packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); + packNormal((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true); for(int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]); - __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]); + __builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]); } - for (int I = 0; I<8; I++) { - for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); - *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); + for (int I = 0; I < 8 ; I++) { + for (int J = 0; J < 4; J++) { + *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); + *((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d)); } } if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < 8; i++) { comparray[i] = 0; int ca = 0; @@ -2720,19 +2972,99 @@ class tinyBLAS_HP16_PPC { aoffset += lda; } } - compute(&acc_0, 0, 0, comparray, vs, fin_res); - compute(&acc_1, 4, 4, comparray, vs, fin_res); - compute(&acc_2, 0, 8, comparray, vs, fin_res); - compute(&acc_3, 4, 12, comparray, vs, fin_res); + compute(& acc_0, 0, 0, comparray, vs, fin_res); + compute(& acc_1, 4, 4, comparray, vs, fin_res); + compute(& acc_2, 0, 8, comparray, vs, fin_res); + compute(& acc_3, 4, 12, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); - save_res(ii+4, jj, 4, fin_res); - save_res(ii, jj+4, 8, fin_res); - save_res(ii+4, jj+4, 12, fin_res); + save_res(ii + 4, jj, 4, fin_res); + save_res(ii, jj + 4, 8, fin_res); + save_res(ii + 4, jj + 4, 12, fin_res); + } + + void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) { + acc_t acc[8]; + for (int i = 0; i < mc ; i += 16) { + for (int j = 0; j < nc; j += 8) { + int A0_base = (i / 16) * (2 * 32 * kc); + int B0_base = (j / 8) * (32 * kc); + for (int x = 0; x < 8; x++) { + __builtin_mma_xxsetaccz(&acc[x]); + } + for (int64_t kk = 0; kk < kc; kk++) { + int A0_block_idx = A0_base + kk * 32; + int B0_block_idx = B0_base + kk * 32; + int A1_block_idx = A0_block_idx + 32 * kc; + int B1_block_idx = B0_block_idx + 32 * kc; + vec_t * A0_block = & vec_A[A0_block_idx]; + vec_t * B0_block = & vec_B[B0_block_idx]; + vec_t * A1_block = & vec_A[A1_block_idx]; + for (int it = 0; it < 4; it++) { + for (int x = 0; x < 4; x++) { + __builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]); + __builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]); + __builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]); + __builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]); + } + } + } + if (l == 0) { + save_acc(& acc[0], ii + i, jj + j); + save_acc(& acc[1], ii + i, jj + j + 4); + save_acc(& acc[2], ii + i + 4, jj + j); + save_acc(& acc[3], ii + i + 4, jj + j + 4); + save_acc(& acc[4], ii + i + 8, jj + j); + save_acc(& acc[5], ii + i + 8, jj + j + 4); + save_acc(& acc[6], ii + i + 12, jj + j); + save_acc(& acc[7], ii + i + 12, jj + j + 4); + } else { + add_save_acc(& acc[0], ii + i, jj + j); + add_save_acc(& acc[1], ii + i, jj + j + 4); + add_save_acc(& acc[2], ii + i + 4, jj + j); + add_save_acc(& acc[3], ii + i + 4, jj + j + 4); + add_save_acc(& acc[4], ii + i + 8, jj + j); + add_save_acc(& acc[5], ii + i + 8, jj + j + 4); + add_save_acc(& acc[6], ii + i + 12, jj + j); + add_save_acc(& acc[7], ii + i + 12, jj + j + 4); + } + } + } } - template - void tinyBLAS_Q0_PPC::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { + void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) { + vec_t A_pack[mc * kc * 4]; + vec_t B_pack[nc * kc * 4]; + constexpr bool is_Ablock_q4 = std::is_same_v; + int64_t ytiles = m / mc; + int64_t xtiles = n / nc; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) { + end = tiles; + } + for (int64_t job = start; job < end; ++job) { + int64_t ii = (job / xtiles) * mc; + int64_t jj = (job % xtiles) * nc; + for (int64_t kk = 0; kk < k; kk += kc) { + if constexpr(is_Ablock_q4) { + packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack); + } else { + packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack); + } + packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack); + KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack); + } + } + } + + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2754,32 +3086,32 @@ class tinyBLAS_HP16_PPC { vector float fin_res[4] = {0}; vector float vs[4] = {0}; vector float CA[4] = {0}; - __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value - __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value + __builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value + __builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value for (int l = 0; l < k; l++) { - __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead - __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead - __builtin_mma_xxsetaccz(&acc_0); + __builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead + __builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead + __builtin_mma_xxsetaccz(& acc_0); if (isAblock_q4) { - packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray); } else { - packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false); } - packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true); - for(int x = 0; x < 8; x+=4) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]); - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]); - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]); + packNormal((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true); + for (int x = 0; x < 8; x += 4) { + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]); } - for (int I = 0; Id) * unhalf((B+((jj+J)*ldb)+l)->d)); + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + *((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); } } - __builtin_mma_disassemble_acc(vec_C, &acc_0); + __builtin_mma_disassemble_acc(vec_C, & acc_0); if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < RM; i++) { comparray[i] = 0; int ca = 0; @@ -2800,9 +3132,21 @@ class tinyBLAS_HP16_PPC { } } - template + template + inline void kernel(int64_t ii, int64_t jj) { + if constexpr(RM == 4 && RN == 8) { + KERNEL_4x8(ii,jj); + } else if constexpr(RM == 8 && RN == 4) { + KERNEL_8x4(ii,jj); + } else if constexpr(RM == 8 && RN == 8) { + KERNEL_8x8(ii,jj); + } else { + assert(false && "RN/RM values not supported"); + } + } + template - NOINLINE void tinyBLAS_Q0_PPC::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2814,12 +3158,20 @@ class tinyBLAS_HP16_PPC { for (int64_t job = start; job < end; ++job) { int64_t ii = m0 + job / xtiles * RM; int64_t jj = n0 + job % xtiles * RN; - this->kernel(ii, jj); + kernel(ii, jj); } } - -template class tinyBLAS_Q0_PPC; -template class tinyBLAS_Q0_PPC; + const TA * const A; + const block_q8_0 * const B; + float * C; + const int64_t k; + int64_t kc; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; class tinyBLAS_PPC { public: From ade724fced721a56ebd2d6c5eb4dd14ddecf59ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 19 Feb 2026 12:42:58 +0100 Subject: [PATCH 175/831] CUDA: fix kernel selection logic for tile FA (llama/19686) * CUDA: fix kernel selection logic for tile FA * add comment --- ggml/src/ggml-cuda/fattn-tile.cuh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index b6db5822818..f3fa80ab23d 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -1186,8 +1186,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; + // On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases. + // However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented. const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc); - const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX; + const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; if constexpr (DV == 512) { From 3f68f30907e8285368737363eb6412928559a3c0 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 19 Feb 2026 14:59:16 +0100 Subject: [PATCH 176/831] vulkan: fix MMQ shader push constants and multi-dispatch (llama/19732) --- ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index 335d7f6a682..aae1c2e8ae9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -57,6 +57,8 @@ layout (push_constant) uniform parameter uint nbi1; uint ne11; #else + uint base_work_group_z; + uint num_batches; uint k_split; uint ne02; uint ne12; @@ -108,7 +110,7 @@ void main() { const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.z; + const uint expert_idx = gl_WorkGroupID.z; if (ic * BN >= data_expert_count[expert_idx]) { return; } @@ -118,7 +120,7 @@ void main() { #endif #ifndef MUL_MAT_ID - const uint batch_idx = gl_GlobalInvocationID.z; + const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z; const uint i13 = batch_idx / p.ne12; const uint i12 = batch_idx % p.ne12; @@ -276,7 +278,7 @@ void main() { const uint dc = ic * BN + warp_c * WN; #ifndef MUL_MAT_ID - const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches; #endif [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { From 0158795ebc9a57fafc3ab3be4814b367590767c4 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Fri, 20 Feb 2026 01:18:30 +0900 Subject: [PATCH 177/831] ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support. (llama/19700) * ggml-webgpu: Add unary op (SQR, SQRT, SIN, COS) support. * Fix to cast the src value to f32 before sin/cos computing. --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 20 ++++++++++++++++++++ ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 14 ++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b5fee480562..1c00d3cb2b1 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2008,6 +2008,14 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_LOG: return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_SQR: + return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_SQRT: + return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_SIN: + return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_COS: + return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_PAD: return ggml_webgpu_pad(ctx, src0, node); case GGML_OP_ARGMAX: @@ -2967,6 +2975,18 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_LOG: supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; + case GGML_OP_SQR: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_SQRT: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_SIN: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_COS: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; case GGML_OP_PAD: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index d639d984970..feaf6d0ac29 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -170,6 +170,20 @@ fn main(@builtin(global_invocation_id) gid: vec3) { #ifdef TRUNC let res = trunc(src[params.offset_src + src_idx]); #endif +#ifdef SQR + let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx]; +#endif +#ifdef SQRT + let res = sqrt(src[params.offset_src + src_idx]); +#endif +#ifdef SIN + let res_f32 = sin(f32(src[params.offset_src + src_idx])); + let res = TYPE(res_f32); +#endif +#ifdef COS + let res_f32 = cos(f32(src[params.offset_src + src_idx])); + let res = TYPE(res_f32); +#endif #ifdef INPLACE src[params.offset_src + src_idx] = res; From 0c10a15447a644d09600e24d948e861b422d94b2 Mon Sep 17 00:00:00 2001 From: Taimur Ahmad Date: Fri, 20 Feb 2026 16:30:07 +0500 Subject: [PATCH 178/831] ggml-cpu: add RVV vec dot kernels for quantization types (llama/18784) * ggml-cpu: add rvv vec_dot for iq2_s Co-authored-by: Rehan Qasim * ggml-cpu: add rvv vec_dot for iq3_s Co-authored-by: Rehan Qasim * ggml-cpu: add rvv vec_dot for tq1_0, tq2_0 Co-authored-by: Rehan Qasim ggml-cpu: add rvv vec_dot for tq1_0, tq2_0 * ggml-cpu: add rvv vec_dot for iq1_s, iq1_m Co-authored-by: Rehan Qasim * ggml-cpu: add vlen switch for rvv vec_dot --------- Co-authored-by: Rehan Qasim --- ggml/src/ggml-cpu/arch-fallback.h | 6 - ggml/src/ggml-cpu/arch/riscv/quants.c | 770 ++++++++++++++++++++++++++ 2 files changed, 770 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index c6eb75b2300..55526e6fb38 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -171,15 +171,9 @@ #elif defined(__riscv) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K -#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K -#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K #define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K -#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K #define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K -#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K -#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K -#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index ae0ebb3cad1..bf9f4df1182 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -1954,3 +1954,773 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } +static const uint8_t sign_gather_indices_arr[64] = { + 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3, + 4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7 +}; + +static const uint8_t sign_bit_masks_arr[64] = { + 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, + 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128 +}; + +static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + // --- Pre-load Constants --- + uint16_t gather_qh_arr[8] = {0, 0, 0, 0, 1, 1, 1, 1}; + vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 8); + uint16_t shift_qh_arr[8] = {11, 9, 7, 5, 11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 8); + + // Constants for sign extraction + vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64); + vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + for (int ib = 0; ib < 4; ++ib) { + // Combine low + high bits + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 8); + qs += 8; + uint16_t qh_val; + memcpy(&qh_val, qh, 2); + qh += 2; + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8((const uint8_t*)&qh_val, 2); + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 2); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 8); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 8); + + // Mask: We want bits 11-12. 0x1800 = 0001 1000 0000 0000 + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 8); + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 8); + + // Multiply by 8 to get byte offset, instead of element offset + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 8); + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 8); + + // Lookup Grid using Byte Offsets + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 8); + + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + // Load signs and generate sign mask + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 8); + signs_ptr += 8; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64); + + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); + q8 += 64; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 64); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 0), v_zero, 16)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 1), v_zero, 16)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 2), v_zero, 16)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 3), v_zero, 16)); + + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + scales += 2; + + sum_block += s0 * (2 * (sc0 & 0xF) + 1); + sum_block += s1 * (2 * (sc0 >> 4) + 1); + sum_block += s2 * (2 * (sc1 & 0xF) + 1); + sum_block += s3 * (2 * (sc1 >> 4) + 1); + } + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + +static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + // Pre-load Constants + vuint8m2_t v_ids = __riscv_vid_v_u8m2(32); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 32); + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 32); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 32); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 32); + uint16_t shift_qh_arr[4] = {11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 4); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + float sum_block = 0.0f; + + for (int ib = 0; ib < 8; ++ib) { + + // Load Low Bits [4 bytes] + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 4); + qs += 4; + + // Load 1 byte. It contains bits for 4 mini-blocks. + uint8_t qh_val = *qh++; + + // Combine Low + High bits of 10bit indices + vuint8mf4_t v_qh_raw = __riscv_vmv_v_x_u8mf4(qh_val, 4); + vuint16mf2_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qh_raw, 4); + vuint16mf2_t v_qh_mf2 = __riscv_vsll_vv_u16mf2(v_qh_u16, v_shift_qh, 4); + v_qh_mf2 = __riscv_vand_vx_u16mf2(v_qh_mf2, 0x1800, 4); + vuint16mf2_t v_qs_u16_mf2 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 4); + vuint16mf2_t v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16_mf2, 3, 4); + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_mf2, 4); + + // Lookup Grid + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(__riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 4))); + + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 4); + signs_ptr += 4; + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32); + + // generating sign mask + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // apply signs + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative,v_q8, v_q8, 0, 32); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 32); + + // Reduction + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + // Reduce 0-15 (First Half) + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m4_i16m2(v_dot, 0), v_zero, 16)); + + // Reduce 16-31 (Second Half) + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m4_i16m2(v_dot, 1), v_zero, 16)); + + // Apply sub Scales + uint8_t sc = *scales++; + + sum_block += s0 * (2 * (sc & 0xF) + 1); + sum_block += s1 * (2 * (sc >> 4) + 1); + } + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + +void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const uint64_t * grid64 = (const uint64_t *)iq3s_grid; + + // --- Pre-load Constants --- + const uint16_t qh_bit_shifts_arr[16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + }; + vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64); + vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64); + vuint16m1_t v_qh_shifts = __riscv_vle16_v_u16m1(qh_bit_shifts_arr, 16); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d); + const float combined_scale = d * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum_block = 0.0f; + + // Loop: Process 64 weights (16 mini-blocks of 4) per iteration + for (int ib = 0; ib < 4; ++ib) { + + vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 16); + qs += 16; + + uint16_t qh_val; + memcpy(&qh_val, qh, 2); + qh += 2; + + vuint16m1_t v_qh_val = __riscv_vmv_v_x_u16m1(qh_val, 16); + // Extract bits: (qh >> i) & 1 + v_qh_val = __riscv_vsrl_vv_u16m1(v_qh_val, v_qh_shifts, 16); + v_qh_val = __riscv_vand_vx_u16m1(v_qh_val, 1, 16); + + vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 16); + v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 16); + v_qh_val = __riscv_vsll_vx_u16m1(v_qh_val, 10, 16); + vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_val, 16); + + // Grid value is 4xuint8 + vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2((const uint32_t *)grid64, v_grid_offsets, 16); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 8); + signs += 8; + + // Generate sign mask + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); + q8 += 64; + + // Apply Signs + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64); + vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 64); + + // Reduction + vint16m2_t v_dot_lo = __riscv_vget_v_i16m4_i16m2(v_dot, 0); + vint16m2_t v_dot_hi = __riscv_vget_v_i16m4_i16m2(v_dot, 1); + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t s_lo = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_lo, v_zero, 32)); + int32_t s_hi = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_hi, v_zero, 32)); + + // Apply sub-scales + uint8_t sc_byte = *scales++; + int sc_lo = (sc_byte & 0xF) * 2 + 1; + int sc_hi = (sc_byte >> 4) * 2 + 1; + + sum_block += s_lo * sc_lo + s_hi * sc_hi; + } + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + +void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + for (int i = 0; i < nb; i++) { + // First loop. + vint32m4_t suml1; + { + const int vl = 32; + vuint8m1_t tq = __riscv_vle8_v_u8m1(x[i].qs, vl); + + vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tq, 3, vl), 8, vl); + vuint16m2_t tq1 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 3, vl), 3, vl), 8, vl); + vuint16m2_t tq2 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 9, vl), 3, vl), 8, vl); + vuint16m2_t tq3 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 27, vl), 3, vl), 8, vl); + vuint16m2_t tq4 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 81, vl), 3, vl), 8, vl); + + vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 0, vl), vl); + vint16m2_t q81 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 32, vl), vl); + vint16m2_t q82 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 64, vl), vl); + vint16m2_t q83 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 96, vl), vl); + vint16m2_t q84 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 128, vl), vl); + + vint16m2_t sum0 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl); + vint16m2_t sum1 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq1, 1, vl)), q81, vl); + vint16m2_t sum2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq2, 1, vl)), q82, vl); + vint16m2_t sum3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq3, 1, vl)), q83, vl); + vint16m2_t sum4 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq4, 1, vl)), q84, vl); + + vint32m4_t sumi0 = __riscv_vwadd_vv_i32m4(sum0, sum1, vl); + vint32m4_t sumi1 = __riscv_vwadd_vv_i32m4(sum2, sum3, vl); + suml1 = __riscv_vadd_vv_i32m4(__riscv_vwcvt_x_x_v_i32m4(sum4, vl), __riscv_vadd_vv_i32m4(sumi0, sumi1, vl), vl); + } + + // Second loop. + vint32m2_t suml2; + { + const int vl = 16; + vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs + 32, vl); + + vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3 * 1, vl), 8, vl); + vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl); + vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl); + vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl); + vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl); + + vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 160, vl), vl); + vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 176, vl), vl); + vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 192, vl), vl); + vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 208, vl), vl); + vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 224, vl), vl); + + vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); + vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl); + vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl); + vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl); + vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl); + + vint32m2_t sumi0 = __riscv_vwadd_vv_i32m2(sum0, sum1, vl); + vint32m2_t sumi1 = __riscv_vwadd_vv_i32m2(sum2, sum3, vl); + suml2 = __riscv_vadd_vv_i32m2(__riscv_vwcvt_x_x_v_i32m2(sum4, vl), __riscv_vadd_vv_i32m2(sumi0, sumi1, vl), vl); + } + + // Third loop. + vint32m2_t suml3; + { + const int vl = 16; + + uint32_t qh; + memcpy(&qh, &x[i].qh[0], 4); + // Prevent fusion with vmv. + __asm__ __volatile__("" : "+r"(qh)); + vuint8mf2_t tq = __riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4)); + + vuint8mf2_t p = __riscv_vle8_v_u8mf2(pow, vl); + + vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vv_u8mf2(tq, p, vl), 3, vl), 8, vl); + + vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 240, vl), vl); + + vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); + suml3 = __riscv_vwcvt_x_x_v_i32m2(sum0, vl); + } + + vint32m2_t sumb = __riscv_vadd_vv_i32m2(__riscv_vget_v_i32m4_i32m2(suml1, 0), __riscv_vget_v_i32m4_i32m2(suml1, 1), 16); + sumb = __riscv_vadd_vv_i32m2(sumb, suml2, 16); + sumb = __riscv_vadd_vv_i32m2(sumb, suml3, 16); + + vint32m1_t sum = __riscv_vredsum_vs_i32m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); + sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + } + + *s = sumf; +} + +void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x[0].qs); j += 32) { + const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32]; + const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32]; + const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32]; + const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32]; + const uint8_t* px = &x[i].qs[j]; + + size_t vlmax_16m2 = __riscv_vsetvl_e16m2(32); + vint16m2_t vacc16 = __riscv_vmv_v_x_i16m2(0, vlmax_16m2); + + size_t vl = __riscv_vsetvl_e8m1(32); + + vuint8m1_t vx_u8 = __riscv_vle8_v_u8m1(px, vl); + + vint8m1_t vy0 = __riscv_vle8_v_i8m1(py0 , vl); + vint8m1_t vy1 = __riscv_vle8_v_i8m1(py1, vl); + vint8m1_t vy2 = __riscv_vle8_v_i8m1(py2, vl); + vint8m1_t vy3 = __riscv_vle8_v_i8m1(py3, vl); + + // l=0 (bits 1:0) + vuint8m1_t t0 = __riscv_vand_vx_u8m1(vx_u8, 0x03, vl); + vint8m1_t vq0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t0), 1, vl); + + // l=1 (bits 3:2) + vuint8m1_t t1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 2, vl), 0x03, vl); + vint8m1_t vq1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t1), 1, vl); + + // l=2 (bits 5:4) + vuint8m1_t t2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 4, vl), 0x03, vl); + vint8m1_t vq2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t2), 1, vl); + + // l=3 (bits 7:6) + vuint8m1_t t3 = __riscv_vsrl_vx_u8m1(vx_u8, 6, vl); // No final AND needed as vsrl shifts in zeros + vint8m1_t vq3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t3), 1, vl); + + // 4. Multiply and accumulate + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq0, vy0, vl); + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq1, vy1, vl); + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq2, vy2, vl); + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq3, vy3, vl); + + vlmax_16m2 = __riscv_vsetvl_e16m2(32); + vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t vred32 = __riscv_vwredsum_vs_i16m2_i32m1(vacc16, vzero32, vlmax_16m2); + + sumi += __riscv_vmv_x_s_i32m1_i32(vred32); + } + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + sumf += (float)sumi * d; + } + + *s = sumf; +} + +void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16mf2_t qh = __riscv_vle16_v_u16mf2(x[i].qh, 8); + + // Calculate ls. + vuint16mf2_t temp = __riscv_vsrl_vx_u16mf2(qh, 12, 8); + temp = __riscv_vand_vx_u16mf2(temp, 7, 8); + vint32m1_t ls = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vwmulu_vx_u32m1(temp, 2, 8)); + ls = __riscv_vadd_vx_i32m1(ls, 1, 8); + + // Calculate delta. + vbool32_t mask = __riscv_vmseq_vx_u16mf2_b32(__riscv_vand_vx_u16mf2(qh, 0x8000, 8), 0, 8); + vint32m1_t delta_neg = __riscv_vmv_v_x_i32m1(-1, 8); + vint32m1_t delta_pos = __riscv_vmv_v_x_i32m1(1, 8); + vint32m1_t delta = __riscv_vmerge_vvm_i32m1(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8m1_t qs = __riscv_vle8_v_u8m1(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m2_t qh_shift = __riscv_vreinterpret_v_u64m2_u16m2(__riscv_vmv_v_x_u64m2(shift, 8)); + vuint16m2_t qh_gather_index = __riscv_vreinterpret_v_i16m2_u16m2( + __riscv_vdiv_vx_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vid_v_u16m2(32)), 4, 32)); + vuint16m2_t qh_ext = __riscv_vlmul_ext_v_u16m1_u16m2(__riscv_vlmul_ext_v_u16mf2_u16m1(qh)); + vuint16m2_t qh_index = __riscv_vrgather_vv_u16m2(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m2(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m2(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m2(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m2(qh_index, __riscv_vzext_vf2_u16m2(qs, 32), 32); + vuint16m2_t index = __riscv_vsll_vx_u16m2(qh_index, 3, 32); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-4 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m2_u16m1(index, 0); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 16)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 128); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 1), one_scalar, 32)); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 2), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 3), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 5-8 + { + vuint16m1_t grid_index1 = __riscv_vget_v_u16m2_u16m1(index, 1); + vint8m4_t grid1 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index1, 16)); + vint8m4_t q81 = __riscv_vle8_v_i8m4(&y[i].qs[128], 128); + vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(grid1, q81, 128); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 0), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 1), one_scalar, 32)); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 2), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 3), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + vint32m1_t lsums = __riscv_vle32_v_i32m1(&lsums_s[0], 8); + + // Calculate the bsums. + vint16m1_t bsums_0 = __riscv_vle16_v_i16m1(y[i].bsums, 16); + const vuint32m1_t bsums_i32 = __riscv_vreinterpret_v_u16m1_u32m1(__riscv_vreinterpret_v_i16m1_u16m1(bsums_0)); + const vint16mf2_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 0, 8)); + const vint16mf2_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 16, 8)); + const vint32m1_t bsums = __riscv_vwadd_vv_i32m1(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32m1_t sumi_v = __riscv_vmul_vv_i32m1(ls, lsums, 8); + vint32m1_t sumi1_v = __riscv_vmul_vv_i32m1(__riscv_vmul_vv_i32m1(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} + +void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 16); + + // We process 4 sub-blocks together. + for (int ib = 0; ib < QK_K/128; ib++) { + // Load qh for 4 sub-blocks. + const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 8); + const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 8); + const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 8); + const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1( + __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 8)), 16); + qh += 8; + + // Prepare grid indices. + const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 16), 16); + const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 8)); + vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 16), 0x700, 16), 16); + index = __riscv_vsll_vx_u16m1(index, 3, 16); + qs += 16; + + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, index, 16))); + + // Prepare the deltas. + const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16( + __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 8)), 16), 0, 16); + const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 16); + const vint64m4_t delta_neg = __riscv_vmv_v_x_i64m4(0xffffffffffffffff, 16); + const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4( + __riscv_vmerge_vvm_i64m4(delta_pos, delta_neg, mask, 16)); + + // Load q8 for sub-blocks. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); + q8 += 128; + + // Calculate the lsums. + const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 128); + const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 128); + + // Prepare the scales. + const int16_t ls_0_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_0_1 = 2*((sc[0] >> 3) & 0x7) + 1; + const int16_t ls_1_0 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_1_1 = 2*((sc[0] >> 9) & 0x7) + 1; + const int16_t ls_2_0 = 2*((sc[1] >> 0) & 0x7) + 1; + const int16_t ls_2_1 = 2*((sc[1] >> 3) & 0x7) + 1; + const int16_t ls_3_0 = 2*((sc[1] >> 6) & 0x7) + 1; + const int16_t ls_3_1 = 2*((sc[1] >> 9) & 0x7) + 1; + sc += 2; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16); + // + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16); + // + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16); + // + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16); + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 16)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 16)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + +void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} From 98915f889a53d9cfe02625fb01479225d02630d8 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sat, 21 Feb 2026 15:09:36 +0530 Subject: [PATCH 179/831] Improve CUDA graph capture (llama/19754) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Improve CUDA graph capture Currently, CUDA graphs are eagerly enabled on the first call to ggml_backend_cuda_graph_compute. If the graph properties keep changing (4+ consecutive updates), the graph is permanently disabled. This is suboptimal because: - The first call always incurs CUDA graph capture overhead even if the graph is unstable - Once permanently disabled, CUDA graphs never re-enable even after the graph stabilizes (e.g., switching from prompt processing to decode) The new approach delays CUDA graph activation until warmup completes: the same cgraph must be called at least twice with matching properties before CUDA graph capture begins. This avoids wasted capture overhead on volatile graphs and allows graphs to become eligible once they stabilize. This also fixes issues such as https://github.com/ggml-org/llama.cpp/discussions/19708 * Update ggml/src/ggml-cuda/ggml-cuda.cu Co-authored-by: Johannes Gäßler * Remove EM dashes * Update ggml/src/ggml-cuda/ggml-cuda.cu Co-authored-by: Aman Gupta --------- Co-authored-by: Johannes Gäßler Co-authored-by: Aman Gupta --- ggml/src/ggml-cuda/common.cuh | 17 ++-------------- ggml/src/ggml-cuda/ggml-cuda.cu | 35 ++++++++++++++++++++++++--------- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a3256d59dd0..36d8a3aaab2 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1149,8 +1149,7 @@ struct ggml_cuda_graph { size_t num_nodes = 0; std::vector nodes; bool disable_due_to_gpu_arch = false; - bool disable_due_to_too_many_updates = false; - int number_consecutive_updates = 0; + bool warmup_complete = false; std::vector props; // these are extra tensors (inputs) that participate in the ggml graph but are not nodes @@ -1159,21 +1158,9 @@ struct ggml_cuda_graph { // ref: https://github.com/ggml-org/llama.cpp/pull/19165 std::vector extra; - void record_update(bool use_graph, bool update_required) { - if (use_graph && update_required) { - number_consecutive_updates++; - } else { - number_consecutive_updates = 0; - } - if (number_consecutive_updates >= 4) { - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); - disable_due_to_too_many_updates = true; - } - } - bool is_enabled() const { static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); - return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates); + return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env); } #endif }; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ffa35eeb654..7e6d3303549 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2979,10 +2979,6 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx const void * graph_key = ggml_cuda_graph_get_key(cgraph); ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); - if (graph->instance == nullptr) { - res = true; - } - // Check if the graph size has changed if (graph->props.size() != (size_t)cgraph->n_nodes) { res = true; @@ -3931,14 +3927,35 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, #ifdef USE_CUDA_GRAPH graph_key = ggml_cuda_graph_get_key(cgraph); - use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); + ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); if (graph->is_enabled()) { - cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph); - use_cuda_graph = ggml_cuda_graph_check_compability(cgraph); - - graph->record_update(use_cuda_graph, cuda_graph_update_required); + const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph); + if (graph_compatible) { + const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph); + + if (!graph->warmup_complete) { + // Warmup: need at least 2 calls with no property change on the 2nd call + if (!properties_changed) { + graph->warmup_complete = true; + GGML_LOG_DEBUG("%s: CUDA graph warmup complete\n", __func__); + use_cuda_graph = true; + cuda_graph_update_required = true; + } + // else: properties changed or first call - execute directly (use_cuda_graph stays false) + } else { + // Post-warmup: normal CUDA graph operation + if (properties_changed) { + // Properties changed - reset warmup, execute directly until stable again + graph->warmup_complete = false; + GGML_LOG_DEBUG("%s: CUDA graph warmup reset\n", __func__); + } else { + use_cuda_graph = true; + cuda_graph_update_required = graph->instance == nullptr; + } + } + } } #endif // USE_CUDA_GRAPH From 06fbd9c5f23b2ca0c5c2141e3500048cddb9688a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alberto=20Cabrera=20P=C3=A9rez?= <1478977+Alcpz@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:42:52 +0000 Subject: [PATCH 180/831] ggml-cpu: arm64: q5_K repack gemm and gemv (and generic) implementations (dotprod) (llama/19356) * Generic GEMV and boilerplate for q5_K dotprod * Generic GEMM and boilerplate for q5_K dotprod * ARM64 q5_K dotprod GEMM * ARM64 q5_K dotprod GEMV --- ggml/src/ggml-cpu/arch-fallback.h | 16 +- ggml/src/ggml-cpu/arch/arm/repack.cpp | 388 +++++++++++++++++++++++ ggml/src/ggml-cpu/repack.cpp | 435 ++++++++++++++------------ ggml/src/ggml-cpu/repack.h | 4 + 4 files changed, 642 insertions(+), 201 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 55526e6fb38..4dfe28e1d64 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -42,6 +42,7 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K @@ -55,9 +56,10 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K -#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -77,6 +79,7 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K @@ -86,6 +89,7 @@ #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K @@ -110,6 +114,7 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K @@ -123,6 +128,7 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K @@ -148,6 +154,7 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K @@ -161,6 +168,7 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K @@ -187,6 +195,7 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K @@ -199,6 +208,7 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K @@ -230,6 +240,7 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K @@ -243,6 +254,7 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K @@ -276,6 +288,7 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K @@ -289,6 +302,7 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 3a3b32efb2b..c2e4623f371 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -785,6 +785,165 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q5_K_8x4_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567 + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[col_groups]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int i = 0; i < col_groups; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0123 = vmulq_f32(q5_d_0, q8_d); + float32x4_t sb_scale_4567 = vmulq_f32(q5_d_1, q8_d); + float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3 + float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7 + float32x4_t sb_min_0123 = vmulq_f32(q5_dmin_0, q8_d); + float32x4_t sb_min_4567 = vmulq_f32(q5_dmin_1, q8_d); + + // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567 + int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + int32x4_t acc_lo[col_groups]; + int32x4_t acc_hi[col_groups]; + + // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block + const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8)); + int16_t bsums_arr[8]; + vst1q_s16(bsums_arr, bsums); + + uint8x16_t qh[col_groups][8]; + for (int c = 0; c < col_groups; c++) { + for (int i = 0; i < 8; i++) { + qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c); + } + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + for (int i = 0; i < col_groups; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q5sb_mins[2]; + int16x8_t q5sb_scales[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q5sb[8]; + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb); + q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb)); + } + + int8x16_t q8_qs[4]; + for (int i = 0; i < 4; i++) { + q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16); + } + + for (int c = 0; c < col_groups; c++) { + uint8x16_t q5_cols[8]; + uint8x16_t hbit_lo[8]; + uint8x16_t hbit_hi[8]; + int8x16_t q5_lo[8]; + int8x16_t q5_hi[8]; + + for (int i = 0; i < 8; i++) { + q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c); + hbit_lo[i] = vandq_u8(qh[c][i], mone); + hbit_hi[i] = vshlq_n_u8(vandq_u8(qh[c][i], mtwo), 3); + qh[c][i] = vshrq_n_u8(qh[c][i], 2); + q5_lo[i] = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_cols[i], m4b), hbit_lo[i], 4)); + q5_hi[i] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_cols[i], 4), hbit_hi[i])); + } + + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[0], q8_qs[0], 0); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[1], q8_qs[0], 1); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[2], q8_qs[0], 2); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[3], q8_qs[0], 3); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[4], q8_qs[1], 0); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[5], q8_qs[1], 1); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[6], q8_qs[1], 2); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[7], q8_qs[1], 3); + + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[0], q8_qs[2], 0); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[1], q8_qs[2], 1); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[2], q8_qs[2], 2); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[3], q8_qs[2], 3); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[4], q8_qs[3], 0); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[5], q8_qs[3], 1); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[6], q8_qs[3], 2); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[7], q8_qs[3], 3); + } + + // Scales + // row c0123 blk0 and blk1 + const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]); + const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]); + const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]), + vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0]))); + acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123); + // row c4567 blk0 and blk1 + const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]); + const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]); + const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]), + vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1]))); + acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567); + + // Bias Correction + const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); + const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); + + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0])); + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1])); + } // for sb + + acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123); + acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, @@ -3205,6 +3364,235 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q5_K_8x4_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int q8_k_blocklen = 4; + constexpr int acc_size = 2 * 4; // 2 row pairs, 4 col pairs + constexpr int col_groups = ncols_interleaved / 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + // 8 accumulators: 2 row pairs, 4 col pairs + float32x4_t acc_f32[acc_size]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int i = 0; i < acc_size; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + // d5 0 1 2 3, 4 5 6 7 + float32x4_t q5_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); + float32x4_t q5_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); + // d8 0 1 2 3 + float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d); + // mins + float32x4_t q5_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); + float32x4_t q5_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); + + // Precomputation of scales and mins + float32x4_t sbd_scale_0123[q8_k_blocklen]; + float32x4_t sbd_scale_4567[q8_k_blocklen]; + float32x4_t sbd_min_0123[q8_k_blocklen]; + float32x4_t sbd_min_4567[q8_k_blocklen]; + + sbd_scale_0123[0] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 0); + sbd_scale_4567[0] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 0); + sbd_min_0123[0] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 0); + sbd_min_4567[0] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 0); + + sbd_scale_0123[1] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 1); + sbd_scale_4567[1] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 1); + sbd_min_0123[1] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 1); + sbd_min_4567[1] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 1); + + sbd_scale_0123[2] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 2); + sbd_scale_4567[2] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 2); + sbd_min_0123[2] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 2); + sbd_min_4567[2] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 2); + + sbd_scale_0123[3] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 3); + sbd_scale_4567[3] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 3); + sbd_min_0123[3] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 3); + sbd_min_4567[3] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 3); + + // Precomputation of bsums, each vpaddq calcs all the bsums for each row + const int16x8_t bsums[q8_k_blocklen] = { + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[QK_K / 64][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } + + // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 .. + int32x4_t bias_acc[acc_size]; + for (int i = 0; i < acc_size; i++) { + bias_acc[i] = vdupq_n_s32(0); + } + + uint8x16_t qh[col_groups][8]; + for (int c = 0; c < col_groups; c++) { + for (int i = 0; i < 8; i++) { + qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c); + } + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Int accumulators for qs vecdot (4 row * 2 col quartets) + int32x4_t acc_lo[acc_size]; + int32x4_t acc_hi[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q5sb_scales[2]; + int16x8_t q5sb_mins[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q5sb[8]; + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb); + q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb)); + } + + constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows + for (int k = 0; k < reads_per_sb; k++) { + const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k); + const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128); + + // 0..3 & 32..35 + const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k); + const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16); + + // NOTE: This is the only difference with q4_K + const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone); + const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3); + qh[0][k] = vshrq_n_u8(qh[0][k], 2); + const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone); + const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3); + qh[1][k] = vshrq_n_u8(qh[1][k], 2); + // From here, same as q4_K + + const int8x16_t q5_0123_lo = + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4)); + const int8x16_t q5_0123_hi = + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123)); + + acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0); // 0..3 r0 c0123 + acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1); // 0..3 r1 c0123 + acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q5_0123_lo, q8_blk0, 2); // 0..3 r2 c0123 + acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q5_0123_lo, q8_blk0, 3); // 0..3 r3 c0123 + + acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q5_0123_hi, q8_blk1, 0); // 32..35 r0 c0123 + acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q5_0123_hi, q8_blk1, 1); // 32..35 r1 c0123 + acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2); // 32..35 r2 c0123 + acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3); // 32..35 r3 c0123 + + const int8x16_t q5_4567_lo = + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4)); + const int8x16_t q5_4567_hi = + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567)); + + acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0); // 0..3 r0 c4567 + acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1); // 0..3 r1 c4567 + acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2); // 0..3 r2 c4567 + acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q5_4567_lo, q8_blk0, 3); // 0..3 r3 c4567 + + acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q5_4567_hi, q8_blk1, 0); // 32..35 r0 c4567 + acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q5_4567_hi, q8_blk1, 1); // 32..35 r1 c4567 + acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q5_4567_hi, q8_blk1, 2); // 32..35 r2 c4567 + acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q5_4567_hi, q8_blk1, 3); // 32..35 r3 c4567 + } + + // Scale and bias application + // acc is stored interleaved to match output layout + const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]); + const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]); + const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]); + const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]); + for (int row = 0; row < q8_k_blocklen; row++) { + // Bias correction + // row c0123 blk0 and blk1 + const float32x4_t sumf_0123 = + vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]), + vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row]))); + acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123); + + // row c4567 blk0 and blk1 + const float32x4_t sumf_4567 = + vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]), + vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4]))); + acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567); + + // Bias + const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]); + const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]); + + // row c0123 blk0 and blk1 + bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q5sb_mins[0])); + bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q5sb_mins[1])); + + // row c4567 blk0 and blk1 + bias_acc[2 * row + 1] = + vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0])); + bias_acc[2 * row + 1] = + vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1])); + } + } // for sb + + for (int row = 0; row < q8_k_blocklen; row++) { + acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]); + acc_f32[2 * row + 1] = + vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]); + } + } // for b + + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemm_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index f94426ddd7f..1b3d23cbedc 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -450,6 +450,208 @@ static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n, } } +template +static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[ncols_interleaved]; + float sum_minf[ncols_interleaved]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + constexpr int scale_stride = 32; + uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride; + uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16; + + const int qh_shift = (k / (32 / blocklen)) * 2; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * blocklen + i) % 32; + const int qh_chunk = qh_idx / blocklen; + const int qh_pos = qh_idx % blocklen; + const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + +template +static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][ncols_interleaved]; + float sum_minf[4][ncols_interleaved]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + constexpr int scale_stride = 32; + uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride; + uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16; + + const int qh_shift = (k / (32 / blocklen)) * 2; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * blocklen + i) % 32; + const int qh_chunk = qh_idx / blocklen; + const int qh_pos = qh_idx % blocklen; + const int b_qh_offset = + qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k / (32 / blocklen)) * 256 + + (k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int m = 0; m < 4; m++) { + const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + extern "C" { void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -803,98 +1005,12 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -void ggml_gemv_q5_K_8x8_q8_K_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[8]; - float sum_minf[8]; - uint32_t utmp[32]; - int sumi1; - int sumi2; - int sumi; - - const block_q8_K * a_ptr = (const block_q8_K *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0; - sum_minf[j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int sb = 0; sb < 8; sb++) { - memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); - utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); - const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; - utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); - utmp[sb * 4 + 2] = uaux_0; - utmp[sb * 4 + 0] &= kmask1; - } - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32; - uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16; - - const int qh_shift = (k / 4) * 2; - for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; - - const int qh_idx = (k * 8 + i) % 32; - const int qh_chunk = qh_idx / 8; - const int qh_pos = qh_idx % 8; - const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos; - - const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; - const uint8_t h0 = (qh_val >> qh_shift) & 1; - const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; - - const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); - const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); - - const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i; +void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} - sumi1 = (v0 * a_ptr[l].qs[q8_offset]); - sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]); - sumi1 = sumi1 * scales_0[j]; - sumi2 = sumi2 * scales_1[j]; - sumi += sumi1 + sumi2; - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; - } - } - for (int sb = 0; sb < 8; sb++) { - uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; - for (int j = 0; j < ncols_interleaved; j++) { - sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * - GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; - } - } - } - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; - } - } +void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } @@ -1494,107 +1610,12 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -void ggml_gemm_q5_K_8x8_q8_K_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - constexpr uint32_t kmask1 = 0x3f3f3f3f; - constexpr uint32_t kmask2 = 0x0f0f0f0f; - constexpr uint32_t kmask3 = 0x03030303; - - assert(n % qk == 0); - assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); - - float sumf[4][8]; - float sum_minf[4][8]; - uint32_t utmp[32]; - int sumi1; - int sumi2; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0; - sum_minf[m][j] = 0.0; - } - } - for (int l = 0; l < nb; l++) { - for (int sb = 0; sb < 8; sb++) { - memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); - utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); - const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; - utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); - utmp[sb * 4 + 2] = uaux_0; - utmp[sb * 4 + 0] &= kmask1; - } - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32; - uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16; - - const int qh_shift = (k / 4) * 2; - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; - - const int qh_idx = (k * 8 + i) % 32; - const int qh_chunk = qh_idx / 8; - const int qh_pos = qh_idx % 8; - const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos; - - const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; - const uint8_t h0 = (qh_val >> qh_shift) & 1; - const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; - - const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); - const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); - - const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i; +void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} - sumi1 = (v0 * a_ptr[l].qs[q8_offset]); - sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]); - sumi1 = sumi1 * scales_0[j]; - sumi2 = sumi2 * scales_1[j]; - sumi += sumi1 + sumi2; - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; - } - } - } - for (int sb = 0; sb < 8; sb++) { - uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; - for (int m = 0; m < 4; m++) { - const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); - for (int j = 0; j < ncols_interleaved; j++) { - sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * - GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; - } - } - } - } +void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -2029,18 +2050,16 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in const int end = QK_K * 4 / blck_size_interleave; - // Interleave Q5_K quants by taking 8 bytes at a time + // Interleave Q5_K quants by taking blck_size_interleave bytes at a time for (int i = 0; i < end; ++i) { int src_id = i % 8; int src_offset = (i / 8) * blck_size_interleave; int dst_offset = i * blck_size_interleave; - uint64_t elems; - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); - memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave); } - // Repeat for low bits 8 bytes at a time as well, since + // Repeat for high bits with the same chunk size, since // the high bits are interleaved in Q5_K and the index is // qh_idx = (qs_idx % 32); // qh_val = qh[qh_idx] >> (qs_idx / 32); @@ -2049,9 +2068,7 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in int src_offset = (i / 8) * blck_size_interleave; int dst_offset = i * blck_size_interleave; - uint64_t elems; - memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t)); - memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t)); + memcpy(&out.qh[dst_offset], &in[src_id].qh[src_offset], blck_size_interleave); } // The below logic is copied over from Q4_K @@ -2249,7 +2266,7 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q5_K); - GGML_ASSERT(interleave_block == 8); + GGML_ASSERT(interleave_block == 4 || interleave_block == 8); constexpr int nrows_interleaved = 8; block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data; @@ -2523,6 +2540,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_8_bl(t, 4, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size); } @@ -2591,6 +2612,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -2654,6 +2679,10 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -3068,6 +3097,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; // instance for Q5_K + static const ggml::cpu::repack::tensor_traits q5_K_8x4_q8_K; static const ggml::cpu::repack::tensor_traits q5_K_8x8_q8_K; // instance for Q6_K @@ -3130,6 +3160,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q5_K_8x8_q8_K; } } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 8 == 0) { + return &q5_K_8x4_q8_K; + } + } } else if (cur->type == GGML_TYPE_Q6_K) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 39b6b482388..ddf03d7642d 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -111,6 +111,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -122,6 +123,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -143,6 +145,7 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -154,6 +157,7 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); From 53b571a47e204c202af3fe541ed2d8cc9e7ee1d0 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Mon, 23 Feb 2026 16:32:14 -0800 Subject: [PATCH 181/831] hexagon refactor all Ops to use local context struct (llama/19819) * hexagon: refactor set/get/sum-rows ops to use local context * hexagon: refactor ROPE and Softmax Ops to use local context Improves performance a bit by precomputing things and saving in the context. * hexagon: refactor activation ops to use local context struct * hexagon: refactor unary ops to use local context struct and DMA/VTCM * hexagon: use aligned hvx_scale function * hexagon: remove unused fields from op_context * hexagon: rewrite ROPE to use DMA and VTCM scratchpad * hex-rope: keep N rows in scratchpad (instead of just two) * hex-rope: introduce rowidx cache * hex-rope: remove unused fields * hex-rope: rewrite dma prefetch logic to allow for multi-row fetch/compute also removes the need for fastdiv. * hex-rope: minor formatting * hex-rope: use indices and unroll the loops * hex-rope: more updates to cleanup rope-block handling * hexagon: cleanup supported type/dims checks * hexagon: all reduce funcs replicated across lanes There is no need to explicitly replicate the first value. * snapdragon: update adb and windows scripts to use ubatch-size 256 Updated Ops support handles larger ubatches. --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 104 ++--- ggml/src/ggml-hexagon/htp/act-ops.c | 436 ++++++++++----------- ggml/src/ggml-hexagon/htp/get-rows-ops.c | 33 +- ggml/src/ggml-hexagon/htp/hex-dma.h | 30 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 26 -- ggml/src/ggml-hexagon/htp/matmul-ops.c | 84 +--- ggml/src/ggml-hexagon/htp/rope-ops.c | 478 ++++++++++++----------- ggml/src/ggml-hexagon/htp/set-rows-ops.c | 53 +-- ggml/src/ggml-hexagon/htp/softmax-ops.c | 248 ++++++------ ggml/src/ggml-hexagon/htp/sum-rows-ops.c | 83 ++-- ggml/src/ggml-hexagon/htp/unary-ops.c | 334 +++++++++------- 11 files changed, 943 insertions(+), 966 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 54f9986498f..7a44443a8a3 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1749,23 +1749,6 @@ static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backe return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer; } -static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) { - if (x->ne[0] != y->ne[0]) { - return false; - } - if (x->ne[1] != y->ne[1]) { - return false; - } - if (x->ne[2] != y->ne[2]) { - return false; - } - if (x->ne[3] != y->ne[3]) { - return false; - } - - return true; -} - static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * src0 = op->src[0]; const struct ggml_tensor * src1 = op->src[1]; @@ -1797,43 +1780,6 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess return opt_experimental; } -static bool hex_supported_src0_type(ggml_type t) { - return t == GGML_TYPE_F32; -} - -static bool hex_supported_src1_type(ggml_type t) { - return t == GGML_TYPE_F32; -} - -static bool hex_supported_src2_type(ggml_type t) { - return t == GGML_TYPE_F32; -} - -static bool hex_supported_src1_type2(ggml_type t) { - return t == GGML_TYPE_F16; -} - -static bool hex_supported_src1_type3(ggml_type t) { - return t == GGML_TYPE_I32; -} - -static bool hex_supported_dst_type(ggml_type t) { - return t == GGML_TYPE_F32; -} - -static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) { - // TODO: support broadcast for ne[2 and 3] - if (x->ne[0] != y->ne[0]) { - return false; - } - if (x->ne[2] != y->ne[2]) { - return false; - } - if (x->ne[3] != y->ne[3]) { - return false; - } - return true; -} static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; @@ -1919,19 +1865,19 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se const struct ggml_tensor * src1 = op->src[1]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_src1_type(src1->type)) { + if (src1->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dims2(src0, dst)) { + if (!ggml_are_same_shape(src0, dst)) { return false; } - if (!ggml_can_repeat(src1, src0)) { + if (!ggml_can_repeat(src1, src0) || ggml_is_permuted(src1)) { return false; } @@ -1943,16 +1889,16 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se const struct ggml_tensor * src1 = op->src[1]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_src1_type(src1->type)) { + if (src1->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dims2(src0, dst)) { + if (!ggml_are_same_shape(src0, dst)) { return false; } @@ -1968,13 +1914,13 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses const struct ggml_tensor * src0 = op->src[0]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dims2(src0, dst)) { + if (!ggml_are_same_shape(src0, dst)) { return false; } @@ -1990,10 +1936,10 @@ static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * const struct ggml_tensor * src0 = op->src[0]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } @@ -2011,10 +1957,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session const struct ggml_tensor * src1 = op->src[1]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } @@ -2023,10 +1969,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session } if (src1) { - if (!hex_supported_src1_type(src1->type)) { + if (src1->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dims2(src0, src1)) { + if (!ggml_are_same_shape(src0, src1)) { return false; } if (!ggml_is_contiguous(src1)) { @@ -2047,15 +1993,15 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s return false; // FIXME: add support for sinks } - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } if (src1) { - if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) { + if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) { return false; } if (src0->ne[0] != src1->ne[0]) { @@ -2162,17 +2108,17 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess const struct ggml_tensor * src2 = op->src[2]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; // FIXME: add support for GGML_TYPE_F16 for src0 } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_src1_type3(src1->type)) { + if (src1->type != GGML_TYPE_I32) { return false; } if (src2) { - if (!hex_supported_src2_type(src2->type)) { + if (src2->type != GGML_TYPE_F32) { return false; } int n_dims = op_params[1]; diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index 950d836ad34..21bd4050a1d 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -69,27 +69,45 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0, - const struct htp_tensor * src1, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * src1_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +struct htp_act_context { + struct htp_ops_context * octx; + + // Precomputed values + const uint8_t * data_src0; + const uint8_t * data_src1; + uint8_t * data_dst; + + size_t src0_row_size; + size_t src1_row_size; + size_t dst_row_size; + + size_t src0_row_size_aligned; + size_t src1_row_size_aligned; + size_t dst_row_size_aligned; + + size_t src0_spad_half_size; + size_t src1_spad_half_size; + size_t dst_spad_half_size; + + uint32_t block; + uint32_t src0_nrows; + uint32_t src0_nrows_per_thread; + int nc; +}; + +static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * src1 = &actx->octx->src1; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble3; - size_t src0_row_size = nb01; - size_t src1_row_size = nb11; - size_t dst_row_size = nb1; - - - - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + size_t src0_row_size = actx->src0_row_size; + size_t src1_row_size = actx->src1_row_size; + size_t dst_row_size = actx->dst_row_size; + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -101,43 +119,34 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0, uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; - - const bool src1_valid = src1->ne[0]; - const int nc = (src1_valid) ? ne00 : ne00 / 2; - if (!src1_valid) { - const int32_t swapped = op_params[1]; - data_src1 = data_src0; - src1_row_size = src0_row_size; + const uint8_t * restrict data_src0 = actx->data_src0; + const uint8_t * restrict data_src1 = actx->data_src1; + uint8_t * restrict data_dst = actx->data_dst; - const size_t nc_in_bytes = nc * SIZEOF_FP32; - data_src0 += swapped ? nc_in_bytes : 0; - data_src1 += swapped ? 0 : nc_in_bytes; - } + const int nc = actx->nc; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t src1_row_size_aligned = actx->src1_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); - uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread); + uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t src1_spad_half_size = src1_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t src1_spad_half_size = actx->src1_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -196,27 +205,22 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0, - const struct htp_tensor * src1, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * src1_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * src1 = &actx->octx->src1; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble3; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - size_t src0_row_size = nb01; - size_t src1_row_size = nb11; - size_t dst_row_size = nb1; + size_t src0_row_size = actx->src0_row_size; + size_t src1_row_size = actx->src1_row_size; + size_t dst_row_size = actx->dst_row_size; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -226,45 +230,36 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; - - const bool src1_valid = src1->ne[0]; - const int nc = (src1_valid) ? ne00 : ne00 / 2; - if (!src1_valid) { - const int32_t swapped = op_params[1]; - data_src1 = data_src0; - src1_row_size = src0_row_size; + const uint8_t * restrict data_src0 = actx->data_src0; + const uint8_t * restrict data_src1 = actx->data_src1; + uint8_t * restrict data_dst = actx->data_dst; - const size_t nc_in_bytes = nc * SIZEOF_FP32; - data_src0 += swapped ? nc_in_bytes : 0; - data_src1 += swapped ? 0 : nc_in_bytes; - } + const int nc = actx->nc; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t src1_row_size_aligned = actx->src1_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); - uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread); + uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t src1_spad_half_size = src1_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t src1_spad_half_size = actx->src1_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least " "%zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } - const float alpha = ((const float *) (op_params))[2]; - const float limit = ((const float *) (op_params))[3]; + const float alpha = ((const float *) (actx->octx->op_params))[2]; + const float limit = ((const float *) (actx->octx->op_params))[3]; + + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { @@ -335,26 +330,22 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0, } -static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble2; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size = actx->src0_row_size; + const size_t dst_row_size = actx->dst_row_size; + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -364,25 +355,29 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * data_src0 = (const uint8_t *) src0->data; - uint8_t * data_dst = (uint8_t *) dst->data; + const uint8_t * data_src0 = actx->data_src0; + uint8_t * data_dst = actx->data_dst; - uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + // nc/ne0 matches. + const int ne0_val = actx->nc; // == dst->ne[0] - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); + + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; // In gelu = x*sigmoid(x*1.702) - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -408,9 +403,9 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); // gelu = x * sigmoid(1.702 * x) // current implementation - hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0); - hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); - hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0_val); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -435,34 +430,23 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - - -static void unary_silu_f32_per_thread(const struct htp_tensor * src0, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble2; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size = actx->src0_row_size; + const size_t dst_row_size = actx->dst_row_size; + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -472,24 +456,27 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * data_src0 = (const uint8_t *) src0->data; - uint8_t * data_dst = (uint8_t *) dst->data; + const uint8_t * data_src0 = actx->data_src0; + uint8_t * data_dst = actx->data_dst; - uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + const int ne0_val = actx->nc; // == dst->ne[0] - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; + + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -515,8 +502,8 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); // silu = x * sigmoid(x) - hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0); - hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0_val); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -544,27 +531,22 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, static const float GELU_COEF_A = 0.044715f; static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; -static void glu_geglu_f32_per_thread(const struct htp_tensor * src0, - const struct htp_tensor * src1, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * src1_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * src1 = &actx->octx->src1; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble3; - size_t src0_row_size = nb01; - size_t src1_row_size = nb11; - size_t dst_row_size = nb1; + size_t src0_row_size = actx->src0_row_size; + size_t src1_row_size = actx->src1_row_size; + size_t dst_row_size = actx->dst_row_size; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -574,43 +556,34 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + const uint8_t * restrict data_src0 = actx->data_src0; + const uint8_t * restrict data_src1 = actx->data_src1; + uint8_t * restrict data_dst = actx->data_dst; - const bool src1_valid = src1->ne[0]; - const int nc = (src1_valid) ? ne00 : ne00 / 2; - if (!src1_valid) { - const int32_t swapped = op_params[1]; - data_src1 = data_src0; - src1_row_size = src0_row_size; - - const size_t nc_in_bytes = nc * SIZEOF_FP32; - data_src0 += swapped ? nc_in_bytes : 0; - data_src1 += swapped ? 0 : nc_in_bytes; - } + const int nc = actx->nc; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t src1_row_size_aligned = actx->src1_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); - uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread); + uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t src1_spad_half_size = src1_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t src1_spad_half_size = actx->src1_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -678,33 +651,7 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void unary_silu_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - -static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, - &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - -static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, - &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - -static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, - &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - static int execute_op_activations_f32(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; struct htp_tensor * dst = &octx->dst; @@ -719,26 +666,26 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { switch (octx->op) { case HTP_OP_UNARY_SILU: - act_op_func = unary_silu_f32; + act_op_func = (worker_callback_t)unary_silu_f32_per_thread; op_type = "silu-f32"; break; case HTP_OP_GLU_SWIGLU: - act_op_func = glu_swiglu_f32; + act_op_func = (worker_callback_t)glu_swiglu_f32_per_thread; op_type = "swiglu-f32"; break; case HTP_OP_GLU_SWIGLU_OAI: - act_op_func = glu_swiglu_oai_f32; + act_op_func = (worker_callback_t)glu_swiglu_oai_f32_per_thread; op_type = "swiglu-oai-f32"; break; case HTP_OP_UNARY_GELU: - act_op_func = unary_gelu_f32; + act_op_func = (worker_callback_t)unary_gelu_f32_per_thread; op_type = "gelu-f32"; break; case HTP_OP_GLU_GEGLU: - act_op_func = glu_geglu_f32; + act_op_func = (worker_callback_t)glu_geglu_f32_per_thread; op_type = "geglu-f32"; break; default: @@ -797,13 +744,58 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); } - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs); + if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + return HTP_STATUS_OK; } - return err; + uint32_t n_jobs = MIN(n_threads, src0_nrows); + + // Prepare context + struct htp_act_context actx; + actx.octx = octx; + + actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + + actx.src0_row_size = src0_row_size; + actx.src1_row_size = src1_row_size; + actx.dst_row_size = dst_row_size; + + actx.src0_row_size_aligned = src0_row_size_aligned; + actx.src1_row_size_aligned = src1_row_size_aligned; + actx.dst_row_size_aligned = dst_row_size_aligned; + + actx.src0_spad_half_size = octx->src0_spad.size_per_thread / 2; + actx.src1_spad_half_size = octx->src1_spad.size_per_thread / 2; + actx.dst_spad_half_size = octx->dst_spad.size_per_thread / 2; + + actx.block = actx.src0_spad_half_size / actx.src0_row_size_aligned; + actx.src0_nrows = src0_nrows; + + actx.nc = dst->ne[0]; + + // Pointers and GLU logic + const uint8_t * data_src0 = (const uint8_t *) src0->data; + const uint8_t * data_src1 = (const uint8_t *) src1->data; + + if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) { + const int32_t swapped = octx->op_params[1]; + data_src1 = data_src0; + actx.src1_row_size = actx.src0_row_size; + + size_t nc_in_bytes = actx.nc * SIZEOF_FP32; + if (swapped) { + data_src0 += nc_in_bytes; + } else { + data_src1 += nc_in_bytes; + } + } + + actx.data_src0 = data_src0; + actx.data_src1 = data_src1; + actx.data_dst = (uint8_t *) dst->data; + + worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs); + return HTP_STATUS_OK; } int op_activations(struct htp_ops_context * octx) { diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index a657cd2dcf2..bf24bbda70a 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -15,6 +15,13 @@ #include "htp-ops.h" #include "hvx-utils.h" +struct get_rows_context { + struct htp_ops_context * octx; + uint32_t src1_nrows_per_thread; + struct fastdiv_values get_rows_div_ne10; + struct fastdiv_values get_rows_div_ne10_ne11; +}; + #define get_rows_preamble \ const uint32_t ne00 = octx->src0.ne[0]; \ const uint32_t ne01 = octx->src0.ne[1]; \ @@ -39,20 +46,22 @@ \ const uint32_t nr = ne10 * ne11 * ne12; -static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { +static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { + struct get_rows_context * grctx = (struct get_rows_context *)data; + struct htp_ops_context * octx = grctx->octx; get_rows_preamble; // parallelize by src1 elements (which correspond to dst rows) - const uint32_t dr = octx->src1_nrows_per_thread; + const uint32_t dr = grctx->src1_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); for (uint32_t i = ir0; i < ir1; ++i) { - const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11); + const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11); const uint32_t rem = i - i12 * ne11 * ne10; - const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10); + const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10); const uint32_t i10 = rem - i11 * ne10; const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; @@ -68,12 +77,6 @@ static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3; hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); } - - return HTP_STATUS_OK; -} - -static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { - get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); } int op_get_rows(struct htp_ops_context * octx) { @@ -95,12 +98,14 @@ int op_get_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); - octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); + struct get_rows_context grctx; + grctx.octx = octx; + grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); + grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); const uint32_t n_jobs = MIN(nr, octx->n_threads); - octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h index d1ddb0ecbf0..350ab9d966f 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -102,7 +102,7 @@ static inline bool dma_queue_push(dma_queue * q, dmlink(q->tail, desc); q->tail = desc; - // FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src); + // FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src); q->push_idx = (q->push_idx + 1) & q->idx_mask; return true; } @@ -144,11 +144,37 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) { dptr = q->dptr[q->pop_idx]; - // FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst); + // FARF(ERROR, "dma-pop: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src); q->pop_idx = (q->pop_idx + 1) & q->idx_mask; return dptr; } +static inline dma_ptr dma_queue_pop_nowait(dma_queue * q) { + dma_ptr dptr = { NULL }; + + if (q->push_idx == q->pop_idx) { + return dptr; + } + + dptr = q->dptr[q->pop_idx]; + + // FARF(ERROR, "dma-pop-nowait: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src); + q->pop_idx = (q->pop_idx + 1) & q->idx_mask; + return dptr; +} + +static inline bool dma_queue_empty(dma_queue * q) { + return q->push_idx == q->pop_idx; +} + +static inline uint32_t dma_queue_depth(dma_queue * q) { + return (q->push_idx - q->pop_idx) & q->idx_mask; +} + +static inline uint32_t dma_queue_capacity(dma_queue * q) { + return q->capacity; +} + #ifdef __cplusplus } // extern "C" #endif diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index f1ad24dbfaa..127ab1d6659 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -44,32 +44,6 @@ struct htp_ops_context { uint32_t src0_nrows_per_thread; uint32_t src1_nrows_per_thread; - struct fastdiv_values src0_div1; // fastdiv values for ne1 - struct fastdiv_values src0_div2; // fastdiv values for ne2 - struct fastdiv_values src0_div3; // fastdiv values for ne3 - struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1 - - struct fastdiv_values src1_div1; // fastdiv values for ne1 - struct fastdiv_values src1_div2; // fastdiv values for ne2 - struct fastdiv_values src1_div3; // fastdiv values for ne3 - struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1 - - struct fastdiv_values src3_div1; // fastdiv values for ne1 - struct fastdiv_values src3_div2; // fastdiv values for ne2 - struct fastdiv_values src3_div3; // fastdiv values for ne3 - struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1 - - struct fastdiv_values broadcast_rk2; - struct fastdiv_values broadcast_rk3; - struct fastdiv_values broadcast_rv2; - struct fastdiv_values broadcast_rv3; - - struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12 - struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11 - - struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10 - struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11 - uint32_t flags; }; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index c360abe8dae..6f6f51f01f5 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -49,62 +49,6 @@ struct htp_matmul_context { struct fastdiv_values mm_div_r3; }; -// vdelta control to replicate first 4x fp32 values across lanes -static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = { - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, - 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, - 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, - 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, - 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, - 0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, - 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, -}; - -// vdelta control to replicate and interleave first 8x fp32 values across lanes -static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = { - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00, - 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, - 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, - 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, - 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44, - 0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, - 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, -}; - -// vdelta control to replicate first fp32 value across all elements -static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = { - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, - 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, - 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, - 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, - 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, - 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, - 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, -}; - -// vdelta control to replicate first fp16 value across all elements -static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = { - 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, - 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, - 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, - 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, - 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, - 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, - 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, -}; - -// vdelta control to replicate first fp16 value across all elements -static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = { - 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, -}; - // vdelta control to expand first 32 e8m0 values into 32 uint32 elements static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = { 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00, @@ -2067,10 +2011,10 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements // Convert to QF32 - HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); - HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); - HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); - HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes // Combine and convert to fp16 HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); @@ -2080,11 +2024,6 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - // Replicate first fp16 scale across all lanes - HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16; - vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl); - vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl); - HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); @@ -2130,13 +2069,8 @@ static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restric HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); // Compute max and scale - HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); - HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); - - // Replicate first fp16 scale across all lanes - HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16; - vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl); - vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl); + HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes + HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 @@ -2179,11 +2113,7 @@ static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restric // Compute max and scale HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); - vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); - - // Replicate first fp16 scale across all lanes - HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16; - vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl); + vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16); diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 943ca5c952e..aa6a6c9008d 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -10,6 +10,7 @@ #include "hex-dma.h" #include "hvx-utils.h" +#include "hex-fastdiv.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -21,6 +22,9 @@ #define HTP_ROPE_TYPE_NORMAL 0 #define HTP_ROPE_TYPE_NEOX 2 +#define HTP_ROPE_SPAD_NROWS 16 +#define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2) + #define htp_rope_preamble \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ @@ -42,7 +46,7 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -struct rope_th_ctx { +struct htp_rope_context { int32_t n_dims; int32_t mode; int32_t n_ctx_orig; @@ -57,7 +61,19 @@ struct rope_th_ctx { float theta_scale; float corr_dims[2]; + uint32_t src0_nrows_per_thread; + size_t spad_stride; + struct htp_ops_context * octx; + + size_t src0_row_size; + size_t dst_row_size; + size_t src0_row_size_aligned; + size_t dst_row_size_aligned; + size_t theta_cache_offset; + uint32_t src0_nrows; + + uint64_t t_start; }; static float rope_yarn_ramp(const float low, const float high, const int i0) { @@ -117,64 +133,23 @@ static void rope_corr_dims(int n_dims, dims[1] = MIN(n_dims - 1, end); } -static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) { - memset(rope_ctx, 0, sizeof(struct rope_th_ctx)); - - const int32_t * op_params = &octx->op_params[0]; - - rope_ctx->n_dims = ((const int32_t *) op_params)[1]; - rope_ctx->mode = ((const int32_t *) op_params)[2]; - rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4]; - - memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float)); - memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float)); - memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float)); - memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float)); - memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float)); - memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float)); - memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4); - - rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims); - - rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast, - rope_ctx->beta_slow, rope_ctx->corr_dims); - - rope_ctx->octx = octx; - FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims, - rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor); -} +static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { + const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0; + const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache; + HVX_Vector * restrict vdst = (HVX_Vector *) dst; -static void hvx_calc_rope_neox_f32(const float * restrict src0, - float * restrict dst, - const int num_elems, - const float * restrict theta_cache) { - // for (int i = 0; i < num_elems; i += 2) { - //const float cos_theta = theta_cache[i + 0]; - //const float sin_theta = theta_cache[i + 1]; + uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2 - //const float x0 = src[0]; - //const float x1 = src[num_elems/2]; + uint32_t he = ne / 2; // half_dims offset in elements + uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors - //dst[0] = x0*cos_theta - x1*sin_theta; - //dst[num_elems/2] = x0*sin_theta + x1*cos_theta; + #pragma unroll(2) + for (uint32_t i = 0; i < nvec; i += 2) { + HVX_Vector v0 = vsrc[i/2+0]; + HVX_Vector v1 = vsrc[i/2+hv]; - //src += 1; - //dst += 1; - // } - - const uint8_t * restrict src0_curr = (const uint8_t *) src0; - const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache; - uint8_t * restrict dst_curr = (uint8_t *) dst; - - int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once - int half_size = (sizeof(float) * (num_elems / 2)); - - for (int i = 0; i < step_of_1; i++) { - HVX_Vector v0 = *(HVX_Vector *) src0_curr; - HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size); - - HVX_Vector v2 = *(HVX_Vector *) theta_curr; - HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN); + HVX_Vector v2 = vtheta[i+0]; + HVX_Vector v3 = vtheta[i+1]; HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta @@ -186,45 +161,34 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0, HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4); - *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5); + vdst[i/2+0] = Q6_Vsf_equals_Vqf32(v4); + vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5); + } - src0_curr += VLEN; - theta_curr += 2 * VLEN; - dst_curr += VLEN; + for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) { + const float cos_theta = theta_cache[i+0]; + const float sin_theta = theta_cache[i+1]; + float x0 = src0[i/2]; + float x1 = src0[i/2 + he]; + dst[i/2] = x0 * cos_theta - x1 * sin_theta; + dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta; } } -static void hvx_calc_rope_f32(const float * restrict src0, - float * restrict dst, - const int num_elems, - const float * restrict theta_cache) { - // for (int i = 0; i < num_elems; i += 2) { - //const float cos_theta = theta_cache[i + 0]; - //const float sin_theta = theta_cache[i + 1]; - - //const float x0 = src[0]; - //const float x1 = src[1]; +static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { + const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0; + const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache; + HVX_Vector * restrict vdst = (HVX_Vector *) dst; - //dst[0] = x0*cos_theta - x1*sin_theta; - //dst[1] = x0*sin_theta + x1*cos_theta; + uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two - //src += 2; - //dst += 2; - // } + #pragma unroll(2) + for (uint32_t i = 0; i < nvec; i+=2) { + HVX_Vector v0 = vsrc[i+0]; + HVX_Vector v1 = vsrc[i+1]; - const uint8_t * restrict src0_curr = (const uint8_t *) src0; - const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache; - uint8_t * restrict dst_curr = (uint8_t *) dst; - - int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once - - for (int i = 0; i < step_of_1; i++) { - HVX_Vector v0 = *(HVX_Vector *) src0_curr; - HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v2 = *(HVX_Vector *) theta_curr; - HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN); + HVX_Vector v2 = vtheta[i+0]; + HVX_Vector v3 = vtheta[i+1]; HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1 HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta @@ -239,151 +203,182 @@ static void hvx_calc_rope_f32(const float * restrict src0, HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); - *(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore); - *(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore); + vdst[i+0] = Q6_V_lo_W(vstore); + vdst[i+1] = Q6_V_hi_W(vstore); + } + + for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) { + const float cos_theta = theta_cache[i+0]; + const float sin_theta = theta_cache[i+1]; + float x0 = src0[i+0]; + float x1 = src0[i+1]; + dst[i+0] = x0 * cos_theta - x1 * sin_theta; + dst[i+1] = x0 * sin_theta + x1 * cos_theta; + } +} + +static void inline rope_basic_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src, + uint32_t nr, uint32_t ne0, const float * restrict theta_cache) { + #pragma unroll(4) + for (uint32_t i = 0; i < nr; i++) { + float * d = (float *) (dst + i * rctx->dst_row_size_aligned); + float * s = (float *) (src + i * rctx->src0_row_size_aligned); + + hvx_rope_f32_aa(d, s, rctx->n_dims, theta_cache); + + // fill the remain channels with data from src tensor + if (rctx->n_dims < ne0) { + hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims); + } + } +} + +static void inline rope_neox_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src, + uint32_t nr, uint32_t ne0, const float * restrict theta_cache) { + #pragma unroll(4) + for (uint32_t i = 0; i < nr; i++) { + float * d = (float *) (dst + i * rctx->dst_row_size_aligned); + float * s = (float *) (src + i * rctx->src0_row_size_aligned); - src0_curr += 2 * VLEN; - theta_curr += 2 * VLEN; - dst_curr += 2 * VLEN; + hvx_rope_neox_f32_aa(d, s, rctx->n_dims, theta_cache); + + // fill the remain channels with data from src tensor + if (rctx->n_dims < ne0) { + hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims); + } } } -static void rope_hex_f32(struct rope_th_ctx * rope_ctx, - const uint32_t ir0, - const uint32_t ir1, - int nth, - int ith, - const int opt_path) { - struct htp_ops_context * octx = rope_ctx->octx; +static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_rope_context * rctx = (struct htp_rope_context *) data; + struct htp_ops_context * octx = rctx->octx; const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; const struct htp_tensor * src2 = &octx->src2; struct htp_tensor * dst = &octx->dst; - const int32_t mode = rope_ctx->mode; - const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; - htp_rope_preamble; - const int32_t * pos = (const int32_t *) src1->data; + const uint32_t src0_nrows = rctx->src0_nrows; + const uint32_t src0_nrows_per_thread = rctx->src0_nrows_per_thread; - float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01)); + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - const float * freq_factors = NULL; - if (src2 != NULL) { - freq_factors = (const float *) src2->data; + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; } - const uint32_t i1_end = MIN(ir1, ne1); - const int32_t half_dims = rope_ctx->n_dims / 2; - const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float); - for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch - for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len - const int32_t p = pos[i2]; + uint64_t tt = HAP_perf_get_qtimer_count(); - rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor, - rope_ctx->attn_factor, wp0, rope_ctx->theta_scale); + const int32_t mode = rctx->mode; + const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; - for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads - const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01); - float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1); + // VTCM setup + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + float * theta_cache = (float *) (src0_spad_base); + src0_spad_base = src0_spad_base + rctx->theta_cache_offset; + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - const float * src_loc = src; - float * dst_data_loc = dst_data; + dma_queue * dma_queue = octx->ctx->dma[ith]; + const int32_t * pos = (const int32_t *) src1->data; + const float * freq_factors = src2->data ? (const float *) src2->data : NULL; - if (1 == opt_path) { - if (is_neox) { - hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); - } else { - hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); - } + uint32_t ir = 0; + uint32_t prev_i2 = (uint32_t) -1; - src_loc += rope_ctx->n_dims; - dst_data_loc += rope_ctx->n_dims; - } else { - for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) { - const float cos_theta = wp0[i0 + 0]; - const float sin_theta = wp0[i0 + 1]; + for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch + for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len + for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads + if (ir < src0_start_row) { ir++; i1++; continue; } + if (ir >= src0_end_row) goto done; - if (is_neox) { - const float x0 = src_loc[0]; - const float x1 = src_loc[half_dims]; + // Rows in this block + const uint32_t nrows = MIN(src0_end_row - ir, ne1 - i1); - dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; - dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta; + // Depth before prefetch + uint32_t dma_depth = dma_queue_depth(dma_queue); - src_loc += 1; - dst_data_loc += 1; - } else { - const float x0 = src_loc[0]; - const float x1 = src_loc[1]; + // FARF(HIGH, "rope-block %u: ir %u n-rows %u dma-depth %u : usec %u", ith, ir, nrows, dma_depth, + // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); - dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; - dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta; + // Prefetch loop + for (uint32_t pnr = 0, pr = 0; pr < nrows && pr < HTP_ROPE_SPAD_NROWS; pr += pnr) { + pnr = MIN(nrows - pr, HTP_ROPE_SPAD_BLOCK); - src_loc += 2; - dst_data_loc += 2; - } - } + uint32_t pi1 = i1 + pr; + uint32_t pir = ir + pr; + + // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr((void *) dst->data, dst_spad_base + pr * rctx->dst_row_size_aligned), 0, 0, 0); - src_loc += (is_neox ? half_dims : 0); - dst_data_loc += (is_neox ? half_dims : 0); + const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01; + uint8_t * src_spad = src0_spad_base + pr * rctx->src0_row_size_aligned; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr), + rctx->src0_row_size_aligned, rctx->src0_row_size, pnr); + + // FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr); } - // TODO: use simd to speed up the remaining elements copy - memcpy(dst_data_loc, src_loc, remain_bytes); - } - } - } -} + // Update theta cache + if (i2 != prev_i2) { + prev_i2 = i2; -static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) { - struct htp_ops_context * octx = rope_ctx->octx; + const int32_t p = pos[i2]; + rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale); - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache, + // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); + } - htp_rope_preamble; + // Skip DMA transactions from prev block (if any) + // No need to wait for these since the DMA is setup for in-order processing + for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); } - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + // Compute loop + for (uint32_t cnr = 0, cr = 0; cr < nrows; cr += cnr, ir += cnr, i1 += cnr) { + // Number of rows to compute + cnr = MIN(nrows - cr, HTP_ROPE_SPAD_BLOCK); - const uint32_t src0_start_row = src0_nrows_per_thread * ith; - const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + uint8_t * dst_spad = (uint8_t *) dma_queue_pop(dma_queue).src; + uint8_t * src_spad = (uint8_t *) dma_queue_pop(dma_queue).dst; - // no work for this thread - if (src0_start_row >= src0_end_row) { - return; - } + // FARF(HIGH, "rope-compute %u: ir %u i1 %u i2 %u i3 %u src-spad %p cnr %u : usec %u", ith, ir, i1, i2, i3, src_spad, cnr, + // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + if (is_neox) { + rope_neox_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache); + } else { + rope_basic_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache); + } - int is_aligned = 1; - int opt_path = 0; - if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) || - (0 == hex_is_aligned((void *) dst->data, VLEN))) { - FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n"); - is_aligned = 0; - } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; - } + uint8_t * dst_addr = (uint8_t *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1; + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(dst_addr, dst_spad), rctx->dst_row_size, rctx->dst_row_size_aligned, cnr); - rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path); + // Prefetch more rows (if any) + if ((cr + HTP_ROPE_SPAD_NROWS) < nrows) { + uint32_t pnr = MIN(nrows - (cr + HTP_ROPE_SPAD_NROWS), HTP_ROPE_SPAD_BLOCK); + uint32_t pi1 = i1 + HTP_ROPE_SPAD_NROWS; + uint32_t pir = ir + HTP_ROPE_SPAD_NROWS; - t2 = HAP_perf_get_qtimer_count(); + const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr), + rctx->src0_row_size_aligned, rctx->src0_row_size, pnr); - FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row, - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); -} + // FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr); + } + } + } + } + } -static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { - struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data; +done: + dma_queue_flush(dma_queue); + tt = HAP_perf_get_qtimer_count() - tt; - rope_job_f32_per_thread(rope_ctx, n, i); + FARF(HIGH, "rope-f32: %d/%d: (%u:%u) usec %u\n", ith, nth, src0_start_row, src0_end_row, (unsigned) HAP_perf_qtimer_count_to_us(tt)); } static int execute_op_rope_f32(struct htp_ops_context * octx) { @@ -394,17 +389,10 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { const struct htp_tensor * src2 = &octx->src2; struct htp_tensor * dst = &octx->dst; - worker_callback_t op_func; - const char * op_type = NULL; - - struct rope_th_ctx rope_ctx; + const char * op_type = "rope-f32"; switch (octx->op) { case HTP_OP_ROPE: - op_func = rope_job_dispatcher_f32; - op_type = "rope-f32"; - - init_rope_ctx(&rope_ctx, octx); break; default: @@ -415,49 +403,79 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { const uint32_t n_threads = octx->n_threads; const size_t src0_row_size = src0->nb[1]; - const size_t src1_row_size = src0_row_size; const size_t dst_row_size = dst->nb[1]; - // VTCM scratchpads for all tensors - // N rows per thread, padded to HVX vector size - octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; - - size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; - - if (src2->ne[0]) { - FARF(HIGH, - "%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u " - "dst-spad-size %u\n", - op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], - src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], - dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); - } else { - FARF(HIGH, - "%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", - op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], - src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, - octx->dst_spad.size); - } + // Aligned row sizes for VTCM + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128); + + // Calculate spad sizes per thread + size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned; + size_t dst_spad_per_thread = HTP_ROPE_SPAD_NROWS * dst_row_size_aligned; + size_t spad_per_thread = src0_spad_per_thread + dst_spad_per_thread; - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, - spad_size); + // Check if we fit in VTCM + size_t total_vtcm_needed = spad_per_thread * n_threads; + if (octx->ctx->vtcm_size < total_vtcm_needed) { + FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, total_vtcm_needed); return HTP_STATUS_VTCM_TOO_SMALL; } + // Assign sizes + octx->src0_spad.size_per_thread = src0_spad_per_thread; + octx->dst_spad.size_per_thread = dst_spad_per_thread; + octx->src0_spad.size = n_threads * src0_spad_per_thread; + octx->dst_spad.size = n_threads * dst_spad_per_thread; + octx->src1_spad.size = 0; + + // Assign pointers octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src1_spad.data = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + // Fill context + struct htp_rope_context rctx; + memset(&rctx, 0, sizeof(struct htp_rope_context)); + + rctx.t_start = HAP_perf_get_qtimer_count(); + + rctx.octx = octx; + + const int32_t * op_params = &octx->op_params[0]; + rctx.n_dims = ((const int32_t *) op_params)[1]; + rctx.mode = ((const int32_t *) op_params)[2]; + rctx.n_ctx_orig = ((const int32_t *) op_params)[4]; + memcpy(&rctx.freq_base, (int32_t *) op_params + 5, sizeof(float)); + memcpy(&rctx.freq_scale, (int32_t *) op_params + 6, sizeof(float)); + memcpy(&rctx.ext_factor, (int32_t *) op_params + 7, sizeof(float)); + memcpy(&rctx.attn_factor, (int32_t *) op_params + 8, sizeof(float)); + memcpy(&rctx.beta_fast, (int32_t *) op_params + 9, sizeof(float)); + memcpy(&rctx.beta_slow, (int32_t *) op_params + 10, sizeof(float)); + memcpy(&rctx.sections, (int32_t *) op_params + 11, sizeof(int) * 4); + + rctx.theta_scale = powf(rctx.freq_base, -2.0f / rctx.n_dims); + + rope_corr_dims(rctx.n_dims, rctx.n_ctx_orig, rctx.freq_base, rctx.beta_fast, rctx.beta_slow, rctx.corr_dims); + + rctx.src0_row_size = src0_row_size; + rctx.dst_row_size = dst_row_size; + rctx.src0_row_size_aligned = src0_row_size_aligned; + rctx.dst_row_size_aligned = dst_row_size_aligned; + rctx.theta_cache_offset = theta_cache_size_aligned; + + uint32_t ne0 = dst->ne[0]; uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + rctx.src0_nrows = src0_nrows; + + FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0, + rctx.ext_factor, rctx.theta_scale, rctx.attn_factor); if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs); + uint32_t n_jobs = MIN(n_threads, src0_nrows); + rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs); } return err; diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index 904484da9de..2fd6c907724 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -43,11 +43,21 @@ \ const uint32_t nr = ne01; -static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { +struct htp_set_rows_context { + struct htp_ops_context * octx; + struct fastdiv_values div_ne12; + struct fastdiv_values div_ne11; + uint32_t src0_nrows_per_thread; +}; + +static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { + struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data; + struct htp_ops_context * octx = srctx->octx; + set_rows_preamble; // parallelize by rows of src0 - const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; @@ -56,8 +66,8 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, for (uint32_t i03 = 0; i03 < ne03; ++i03) { for (uint32_t i02 = 0; i02 < ne02; ++i02) { for (uint32_t i = ir0; i < ir1; ++i) { - const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); - const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11); const uint32_t i10 = i; const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; @@ -76,15 +86,16 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, } } } - - return HTP_STATUS_OK; } -static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) { +static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) { + struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data; + struct htp_ops_context * octx = srctx->octx; + set_rows_preamble; // parallelize by rows of src0 - const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; @@ -93,8 +104,8 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, for (uint32_t i03 = 0; i03 < ne03; ++i03) { for (uint32_t i02 = 0; i02 < ne02; ++i02) { for (uint32_t i = ir0; i < ir1; ++i) { - const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); - const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11); const uint32_t i10 = i; const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; @@ -112,16 +123,6 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, } } } - - return HTP_STATUS_OK; -} - -static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) { - set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i); -} - -static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { - set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); } int op_set_rows(struct htp_ops_context * octx) { @@ -143,18 +144,20 @@ int op_set_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - octx->set_rows_div_ne12 = init_fastdiv_values(ne12); - octx->set_rows_div_ne11 = init_fastdiv_values(ne11); + struct htp_set_rows_context srctx; + srctx.octx = octx; + srctx.div_ne12 = init_fastdiv_values(ne12); + srctx.div_ne11 = init_fastdiv_values(ne11); const uint32_t n_jobs = MIN(nr, octx->n_threads); - octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; switch(octx->dst.type) { case HTP_TYPE_F32: - worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs); break; case HTP_TYPE_F16: - worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs); break; default: return HTP_STATUS_NO_SUPPORT; diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index e91a16d947f..6e22eb6a639 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -10,6 +10,7 @@ #include "hex-dma.h" #include "hvx-utils.h" +#include "hex-fastdiv.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -48,7 +49,7 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -struct softmax_th_ctx { +struct htp_softmax_context { bool use_f16; bool use_src1; uint32_t n_head; @@ -59,28 +60,48 @@ struct softmax_th_ctx { float m0; float m1; + uint32_t src0_nrows_per_thread; + struct fastdiv_values fastdiv_ne01; + struct fastdiv_values fastdiv_ne02; + struct fastdiv_values fastdiv_ne12; // For mask broadcasting + struct fastdiv_values fastdiv_ne13; // For mask broadcasting + size_t spad_stride; + struct htp_ops_context * octx; }; -static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) { +static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) { const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; - memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx)); + memset(smctx, 0, sizeof(struct htp_softmax_context)); + + memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float)); + memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float)); + + smctx->n_head = src0->ne[2]; + smctx->n_head_log2 = 1u << (uint32_t) floor(log2(smctx->n_head)); + + smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2); + smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2); - memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float)); - memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float)); + smctx->use_src1 = (src1->ne[0] != 0); + smctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16); - softmax_ctx->n_head = src0->ne[2]; - softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head)); + smctx->octx = octx; - softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2); - softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2); + // Initialize fastdiv values + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; - softmax_ctx->use_src1 = (src1->ne[0] != 0); - softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16); + if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01); + if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02); - softmax_ctx->octx = octx; + const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; + const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; + + if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12); + if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13); } static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src, @@ -139,8 +160,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1); } - HVX_Vector v = hvx_vec_reduce_max_f32(max_vec); - max_vec = hvx_vec_repl4(v); + max_vec = hvx_vec_reduce_max_f32(max_vec); // replicated over all lanes #pragma unroll(4) for (int i = 0; i < step_of_1; i++) { @@ -154,8 +174,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, v_pad[i] = v3; } - v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); - sum_vec = hvx_vec_repl4(v); + sum_vec = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); // replicated over all lanes HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v); HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec); @@ -183,83 +202,9 @@ static float hvx_softmax_f32(const uint8_t * restrict src, return sum; } -static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) { - struct htp_ops_context * octx = softmax_ctx->octx; - - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * dst = &octx->dst; - - htp_softmax_preamble3; - - uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01); - uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01); - uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1); - - float * wp0 = (float *) src0_spad_data; - float * wp1 = (float *) src1_spad_data; - float * wp2 = (float *) dst_spad_data; - - for (uint32_t i03 = 0; i03 < ne03; i03++) { - for (uint32_t i02 = 0; i02 < ne02; i02++) { - for (uint32_t i01 = ith; i01 < ne01; i01 += nth) { - const uint32_t i11 = i01; - const uint32_t i12 = i02 % ne12; - const uint32_t i13 = i03 % ne13; - - // ALiBi - const uint32_t h = i02; // head - - const float slope = (softmax_ctx->max_bias > 0.0f) ? - h < softmax_ctx->n_head_log2 ? - powf(softmax_ctx->m0, h + 1) : - powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) : - 1.0f; - - float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03); - float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3); - - // broadcast the mask across rows - __fp16 * mp_f16 = (softmax_ctx->use_src1) ? - (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : - NULL; - float * mp_f32 = (softmax_ctx->use_src1) ? - (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : - NULL; - - if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) { - hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale, - (const uint8_t *) mp_f32, slope); - } else { - hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale); - if (mp_f32) { - if (softmax_ctx->use_f16) { - for (int i = 0; i < ne00; ++i) { - wp0[i] += slope * (float) mp_f16[i]; - } - } else { - for (int i = 0; i < ne00; ++i) { - wp0[i] += slope * mp_f32[i]; - } - } - } - } - - if (1 == opt_path) { - hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); - } else { - float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00); - float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); - sum = sum > 0.0 ? (1.0 / sum) : 1; - hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); - } - } - } - } -} - -static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) { - struct htp_ops_context * octx = softmax_ctx->octx; +static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_softmax_context * smctx = (struct htp_softmax_context *) data; + struct htp_ops_context * octx = smctx->octx; const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; @@ -268,7 +213,7 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int htp_softmax_preamble3; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + const uint32_t src0_nrows_per_thread = smctx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -291,20 +236,103 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int opt_path = 1; } - softmax_htp_f32(nth, ith, softmax_ctx, opt_path); + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride); + uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride); + uint8_t * dst_spad_data = octx->dst_spad.data + (ith * smctx->spad_stride); + + float * wp0 = (float *) src0_spad_data; + float * wp1 = (float *) src1_spad_data; + float * wp2 = (float *) dst_spad_data; + + uint32_t prev_i2 = (uint32_t)-1; + float slope = 1.0f; + + for (uint32_t r = src0_start_row; r < src0_end_row; ++r) { + uint32_t i1 = fastmodulo(r, ne01, &smctx->fastdiv_ne01); + uint32_t r_div_ne01 = fastdiv(r, &smctx->fastdiv_ne01); + uint32_t i2 = fastmodulo(r_div_ne01, ne02, &smctx->fastdiv_ne02); + uint32_t i3 = fastdiv(r_div_ne01, &smctx->fastdiv_ne02); + + // Map to original logic indices + // i01 = i1 + // i02 = i2 + // i03 = i3 + + const uint32_t i11 = i1; + // const uint32_t i12 = i2 % ne12; + // const uint32_t i13 = i3 % ne13; + + uint32_t i12, i13; + if (ne12 == ne02) { + i12 = i2; + } else { + i12 = fastmodulo(i2, ne12, &smctx->fastdiv_ne12); + } + + if (ne13 == ne03) { + i13 = i3; + } else { + i13 = fastmodulo(i3, ne13, &smctx->fastdiv_ne13); + } + + // ALiBi + if (i2 != prev_i2) { + const uint32_t h = i2; // head + + slope = (smctx->max_bias > 0.0f) ? + h < smctx->n_head_log2 ? + powf(smctx->m0, h + 1) : + powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) : + 1.0f; + prev_i2 = i2; + } + + float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03); + float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3); + + // broadcast the mask across rows + __fp16 * mp_f16 = (smctx->use_src1) ? + (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : + NULL; + float * mp_f32 = (smctx->use_src1) ? + (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : + NULL; + + if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) { + hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale, + (const uint8_t *) mp_f32, slope); + } else { + hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale); + if (mp_f32) { + if (smctx->use_f16) { + for (int i = 0; i < ne00; ++i) { + wp0[i] += slope * (float) mp_f16[i]; + } + } else { + for (int i = 0; i < ne00; ++i) { + wp0[i] += slope * mp_f32[i]; + } + } + } + } + + if (1 == opt_path) { + hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); + } else { + float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00); + float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); + sum = sum > 0.0 ? (1.0 / sum) : 1; + hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); + } + } t2 = HAP_perf_get_qtimer_count(); FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, - softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, + smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) { - struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data; - softmax_job_f32_per_thread(p_softmax_ctx, n, i); -} - static int execute_op_softmax_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; @@ -312,17 +340,12 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { const struct htp_tensor * src1 = &octx->src1; struct htp_tensor * dst = &octx->dst; - worker_callback_t op_func; - const char * op_type = NULL; - - struct softmax_th_ctx softmax_ctx; + struct htp_softmax_context smctx; + const char * op_type = "softmax-f32"; switch (octx->op) { case HTP_OP_SOFTMAX: - op_func = softmax_job_dispatcher_f32; - op_type = "softmax-f32"; - - init_softmax_ctx(&softmax_ctx, octx); + init_softmax_ctx(&smctx, octx); break; default: @@ -342,6 +365,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; + // Use stride for calculating offset + smctx.spad_stride = hex_round_up(src0_row_size, 128); + size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; if (src1->ne[0]) { @@ -371,8 +397,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs); + smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs); } return err; diff --git a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c index 62e45da2b35..04fa72182a3 100644 --- a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c @@ -17,7 +17,6 @@ #include "htp-msg.h" #include "htp-ops.h" - #define sum_rows_preamble \ struct htp_tensor *src0 = &octx->src0;\ struct htp_tensor *dst = &octx->dst; \ @@ -42,53 +41,54 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; \ -static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) { - sum_rows_preamble; +struct sum_rows_context { + const uint8_t * src_data; + uint8_t * dst_data; + uint32_t ne00; + size_t src_stride; + size_t dst_stride; + uint32_t rows_per_thread; + uint32_t total_rows; + bool opt_path; +}; - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; +static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) { + const struct sum_rows_context * smctx = (const struct sum_rows_context *) data; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t rows_per_thread = smctx->rows_per_thread; + const uint32_t total_rows = smctx->total_rows; - const uint32_t src0_start_row = src0_nrows_per_thread * ith; - const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t start_row = rows_per_thread * ith; + const uint32_t end_row = MIN(start_row + rows_per_thread, total_rows); - // no work for this thread - if (src0_start_row >= src0_end_row) { - return HTP_STATUS_OK; + if (start_row >= end_row) { + return; } - int opt_path = 0; - if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) { - opt_path = 1; - } + const size_t src_stride = smctx->src_stride; + const size_t dst_stride = smctx->dst_stride; + const uint32_t ne00 = smctx->ne00; + const bool opt_path = smctx->opt_path; - const uint8_t * restrict data_src = (const uint8_t *) src0->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + const float * restrict src_th = (const float *) (smctx->src_data + (start_row * src_stride)); + float * restrict dst_th = (float *) (smctx->dst_data + (start_row * dst_stride)); - const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size)); - float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size)); + // Calculate actual number of rows for this thread + const uint32_t n_rows = end_row - start_row; - for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) { - const float * restrict src_local = src_th + (ir * ne00); + for (uint32_t ir = 0; ir < n_rows; ir++) { + const float * restrict src_local = src_th + (ir * (src_stride / sizeof(float))); - if (ir + 1 < src0_nrows_per_thread) { - hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1); + if (ir + 1 < n_rows) { + hex_l2fetch(src_local + (src_stride / sizeof(float)), src_stride, src_stride, 1); } - if (1 == opt_path) { + if (opt_path) { dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00); } else { dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00); } } - - return HTP_STATUS_OK; -} - -static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) { - sum_rows_thread_f32((struct htp_ops_context *) data, n, i); } int op_sum_rows(struct htp_ops_context * octx) { @@ -106,10 +106,25 @@ int op_sum_rows(struct htp_ops_context * octx) { const uint32_t src0_nrows = ne01 * ne02 * ne03; uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs); + bool opt_path = false; + if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) { + opt_path = true; + } + + struct sum_rows_context smctx = { + .src_data = (const uint8_t *) src0->data, + .dst_data = (uint8_t *) dst->data, + .ne00 = ne00, + .src_stride = nb01, + .dst_stride = nb1, + .rows_per_thread = rows_per_thread, + .total_rows = src0_nrows, + .opt_path = opt_path, + }; + + worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs); return HTP_STATUS_OK; } - diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index ce879bf0370..98135c50ab8 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -17,6 +17,28 @@ #include "htp-msg.h" #include "htp-ops.h" +struct htp_unary_context { + struct htp_ops_context * octx; + + // Precomputed values + const uint8_t * data_src0; + uint8_t * data_dst; + + size_t src0_row_size; + size_t dst_row_size; + + size_t src0_row_size_aligned; + size_t dst_row_size_aligned; + + size_t src0_spad_half_size; + size_t dst_spad_half_size; + + uint32_t block; + uint32_t src0_nrows; + uint32_t src0_nrows_per_thread; + uint32_t nc; +}; + #define htp_unary_preamble \ const uint32_t ne00 = src->ne[0]; \ const uint32_t ne01 = src->ne[1]; \ @@ -57,8 +79,7 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); } - HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); - sum_v = hvx_vec_repl4(reduced_sum); + sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); @@ -75,128 +96,95 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } -static void scale_htp_f32(const float * restrict src, - float * restrict dst, - uint8_t * restrict spad, - const uint32_t num_rows, - const uint32_t row_elems, - const size_t row_size, - int32_t * op_params, - int opt_path) { +static void scale_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { float scale = 0.f; float bias = 0.f; memcpy(&scale, &op_params[0], sizeof(float)); memcpy(&bias, &op_params[1], sizeof(float)); for (uint32_t ir = 0; ir < num_rows; ir++) { - const float * restrict src_local = src + (ir * row_elems); - float * restrict dst_local = dst + (ir * row_elems); - - if (ir + 1 < num_rows) { - hex_l2fetch(src_local + row_elems, row_size, row_size, 1); - } + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); - hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); + hvx_scale_offset_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); } } -static void rms_norm_htp_f32(const float * restrict src, - float * restrict dst, - uint8_t * restrict spad, - const uint32_t num_rows, - const uint32_t row_elems, - const size_t row_size, - int32_t * op_params, - int opt_path) { +static void rms_norm_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { float epsilon = 0.f; memcpy(&epsilon, op_params, sizeof(float)); for (uint32_t ir = 0; ir < num_rows; ir++) { - const float * restrict src_local = src + (ir * row_elems); - float * restrict dst_local = dst + (ir * row_elems); + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); - if (ir + 1 < num_rows) { - hex_l2fetch(src_local + row_elems, row_size, row_size, 1); - } - - if (1 == opt_path) { - hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); - } else { - float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems); - - const float mean = sum / row_elems; - const float scale = 1.0f / sqrtf(mean + epsilon); - - hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale); - } + hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); } } -static void sqr_htp_f32(const float * restrict src, - float * restrict dst, - uint8_t * restrict spad, - const uint32_t num_rows, - const uint32_t row_elems, - const size_t row_size, - int32_t * op_params, - int opt_path) { +static void sqr_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { for (uint32_t ir = 0; ir < num_rows; ir++) { - const float * restrict src_local = src + (ir * row_elems); - float * restrict dst_local = dst + (ir * row_elems); - - if (ir + 1 < num_rows) { - hex_l2fetch(src_local + row_elems, row_size, row_size, 1); - } + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); - if (1 == opt_path) { - hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); - } else { - hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); - } + hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); } } -static void sqrt_htp_f32(const float * restrict src, - float * restrict dst, - uint8_t * restrict spad, - const uint32_t num_rows, - const uint32_t row_elems, - const size_t row_size, - int32_t * op_params, - int opt_path) { +static void sqrt_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { for (uint32_t ir = 0; ir < num_rows; ir++) { - const float * restrict src_local = src + (ir * row_elems); - float * restrict dst_local = dst + (ir * row_elems); + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); - if (ir + 1 < num_rows) { - hex_l2fetch(src_local + row_elems, row_size, row_size, 1); - } - - if (1 == opt_path) { - hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); - } else { - hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); - } + hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); } } -static void unary_job_f32_per_thread(const struct htp_tensor * src, - struct htp_tensor * dst, - uint8_t * spad, - int htp_op, - int32_t * op_params, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread) { +static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_unary_context * uctx = (const struct htp_unary_context *) data; + struct htp_ops_context * octx = uctx->octx; + const struct htp_tensor * src = &octx->src0; + const struct htp_tensor * dst = &octx->dst; + htp_unary_preamble; - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; + int htp_op = octx->op; + int32_t * op_params = octx->op_params; + uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const size_t src0_row_size = uctx->src0_row_size; + const size_t dst_row_size = uctx->dst_row_size; + const size_t src0_row_size_aligned = uctx->src0_row_size_aligned; + const size_t dst_row_size_aligned = uctx->dst_row_size_aligned; + + const uint32_t src0_nrows = uctx->src0_nrows; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -208,79 +196,104 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src, uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - int is_aligned = 1; - int opt_path = 0; - if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) { - is_aligned = 0; - } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; + const uint8_t * restrict data_src = uctx->data_src0; + uint8_t * restrict data_dst = uctx->data_dst; + + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + + size_t src0_spad_half_size = uctx->src0_spad_half_size; + size_t dst_spad_half_size = uctx->dst_spad_half_size; + + const int BLOCK = uctx->block; + if (BLOCK == 0) { + FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + octx->src0_spad.size_per_thread, src0_row_size_aligned); + return; } - const uint8_t * restrict data_src = (const uint8_t *) src->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + dma_queue * dma_queue = octx->ctx->dma[ith]; - const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size)); - float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size)); - uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01); + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); - switch (htp_op) { - case HTP_OP_RMS_NORM: - rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); - break; - case HTP_OP_SCALE: - scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); - break; - case HTP_OP_SQR: - sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); - break; - case HTP_OP_SQRT: - sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); - break; + // Dummy DMA transation for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); - default: - break; + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); } + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + float * dst_spad = (float *) dma_queue_pop(dma_queue).src; + float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + + // Process block in VTCM + switch (htp_op) { + case HTP_OP_RMS_NORM: + rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_SCALE: + scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_SQR: + sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_SQRT: + sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + default: + break; + } + + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), + dst_row_size, dst_row_size_aligned, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); + } + } + + dma_queue_flush(dma_queue); + t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0], + FARF(HIGH, "unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, src->ne[0], src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - - unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i, - octx->src0_nrows_per_thread); -} - static int execute_op_unary_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; const struct htp_tensor * src0 = &octx->src0; struct htp_tensor * dst = &octx->dst; - worker_callback_t unary_op_func; - const char * op_type = NULL; + const char * op_type = NULL; switch (octx->op) { case HTP_OP_RMS_NORM: - unary_op_func = unary_job_dispatcher_f32; - op_type = "rmsnorm-f32"; + op_type = "rmsnorm-f32"; break; case HTP_OP_SCALE: - unary_op_func = unary_job_dispatcher_f32; - op_type = "scale-f32"; + op_type = "scale-f32"; break; case HTP_OP_SQR: - unary_op_func = unary_job_dispatcher_f32; - op_type = "sqr-f32"; + op_type = "sqr-f32"; break; case HTP_OP_SQRT: - unary_op_func = unary_job_dispatcher_f32; - op_type = "sqrt-f32"; + op_type = "sqrt-f32"; break; default: @@ -294,32 +307,61 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const size_t src0_row_size = src0->nb[1]; const size_t dst_row_size = dst->nb[1]; - // VTCM scratchpads for all tensors - octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); - size_t spad_size = octx->src0_spad.size + octx->dst_spad.size; + // VTCM scratchpads for all tensors + // N rows per thread, padded to HVX vector size + // Double buffering requires 2x size per buffer - FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); + size_t spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned); + size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row); // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { + if (vtcm_row_per_thread == 0) { FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, - spad_size); + spad_size_per_row * n_threads); return HTP_STATUS_VTCM_TOO_SMALL; } + octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread * 2; + octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_row_per_thread * 2; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + struct htp_unary_context uctx = { + .octx = octx, + .src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs, + .src0_nrows = src0_nrows, + + .data_src0 = (const uint8_t *)src0->data, + .data_dst = (uint8_t *)dst->data, + + .src0_row_size = src0_row_size, + .dst_row_size = dst_row_size, + + .src0_row_size_aligned = src0_row_size_aligned, + .dst_row_size_aligned = dst_row_size_aligned, + + .src0_spad_half_size = octx->src0_spad.size_per_thread / 2, + .dst_spad_half_size = octx->dst_spad.size_per_thread / 2, + + .block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned, + .nc = src0->ne[0], + }; - worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs); } return err; From 344eae3d226eb9143fdf83f3097f2ca37f1021cd Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 24 Feb 2026 00:43:12 -0600 Subject: [PATCH 182/831] vulkan: fix data race in mul_mat_id shader (llama/19790) --- ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp | 4 +++- ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 717d124e019..497a18ff8a7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -167,7 +167,9 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { uint id = ids[iter++]; uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - ballots_sh[gl_SubgroupID] = ballot; + if (gl_SubgroupInvocationID == 0) { + ballots_sh[gl_SubgroupID] = ballot; + } barrier(); uint subgroup_base = 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl index 743004ff8ad..26c5c12a49a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl @@ -43,7 +43,9 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { uint id = ids[iter++]; uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - ballots_sh[gl_SubgroupID] = ballot; + if (gl_SubgroupInvocationID == 0) { + ballots_sh[gl_SubgroupID] = ballot; + } barrier(); uint subgroup_base = 0; From dcc877688df2c1b57c34f09c5b4b194e8f86a09a Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 24 Feb 2026 00:48:32 -0600 Subject: [PATCH 183/831] vulkan: fix coopmat1 without bf16 support (llama/19793) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a8840a0773b..88b3e4e58eb 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3780,10 +3780,12 @@ static void ggml_vk_load_shaders(vk_device& device) { && !device->coopmat_bf16_support #endif ) { + const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32; + // use scalar tile sizes l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 }; - s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 }; + s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, 2, 2, 1, subgroup_size_8 }; l_wg_denoms = {128, 128, 1 }; m_wg_denoms = { 64, 64, 1 }; From 90800b5aa51e58f36cc46794e6ddc4c765eb446b Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 24 Feb 2026 08:35:48 +0100 Subject: [PATCH 184/831] Vulkan Scalar Flash Attention Refactor (llama/19625) * vulkan: allow using fp16 in scalar flash attention shader * split rows inside of subgroups for faster synchronization * use row_split when Br >= 4, change reductions to use shared memory if row_split == 1 * use f32 scalar FA if f16 is not supported by device * fix amd workgroup size issue * optimize masksh use * add medium rows FA shader Br size * fixes * add padding to mask shmem buffer * cache q values into registers for KQ * fuse lf accumulation, pf and v accumulation into a loop * stage K loads through shmem * stage V loads through shmem * only stage through shmem on Nvidia * default to Bc 32 * also stage V through shmem when this is done for K * dynamic subgroups for intel * use vectorized stores * use float_type for dequantize4 functions * use smaller scalar rows size for smaller rows count * relax flash attention split_k condition to allow non-gqa use * use minimal subgroup size on Intel * fix shmem support function * fix rebase issues * fixes * Bc 4 for scalar FA is not a valid configuration * Use wave32 on AMD RDNA for scalar FA * add Intel shader core count lookup-table * fix regressions * device tuning * tmpsh size fix * fix editorconfig * refactor fa tuning logic into a single place * fix gqa opt logic * fix block_rows with small n_rows * amd tuning * fix hsk=72/80 issue * tuning * allow condition skipping for column check * use float16 for Of if available * address feedback * fix bad RDNA performance on head size <= 128 by limiting occupancy * allow printing pipeline stats * cleanup and fixes * limit occupancy for GCN for small batch FA with large HSK * disable f16 FA for GCN AMD GPUs on the proprietary driver --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 572 +++++++++++------- .../vulkan-shaders/flash_attn.comp | 530 ++++++++++------ .../vulkan-shaders/flash_attn_base.glsl | 63 +- .../vulkan-shaders/flash_attn_cm1.comp | 224 ++++--- .../vulkan-shaders/flash_attn_cm2.comp | 42 +- .../vulkan-shaders/vulkan-shaders-gen.cpp | 80 +-- 6 files changed, 970 insertions(+), 541 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 88b3e4e58eb..8a9cfaf1654 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -403,19 +403,20 @@ enum FaCodePath { }; struct vk_fa_pipeline_state { - vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags) - : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {} - uint32_t HSK, HSV; - bool small_rows, small_cache; + uint32_t Br, Bc; + uint32_t D_split, row_split; + bool shmem_staging; FaCodePath path; + uint32_t workgroup_size, subgroup_size; bool aligned; bool f32acc; uint32_t flags; + uint32_t limit_occupancy_shmem; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) < - std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags); + return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) < + std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem); } }; @@ -623,6 +624,8 @@ struct vk_device_struct { // floor(log2(maxComputeWorkGroupInvocations)) uint32_t max_workgroup_size_log2 {}; + bool flash_attention_fp16; + bool coopmat_support; bool coopmat_acc_f32_support {}; bool coopmat_acc_f16_support {}; @@ -1656,6 +1659,7 @@ static bool vk_perf_logger_concurrent = false; static bool vk_enable_sync_logger = false; // number of calls between perf logger prints static uint32_t vk_perf_logger_frequency = 1; +static std::string vk_pipeline_stats_filter; class vk_perf_logger { public: @@ -2172,7 +2176,32 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin executableInfo.pipeline = pipeline->pipeline; auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo); + + bool print_stats = !vk_pipeline_stats_filter.empty() && + pipeline->name.find(vk_pipeline_stats_filter) != std::string::npos; + if (print_stats) { + std::cerr << "ggml_vulkan: pipeline stats for " << pipeline->name << ":" << std::endl; + } + for (auto & s : statistics) { + if (print_stats) { + std::cerr << "ggml_vulkan: " << s.name.data() << ": "; + switch (s.format) { + case vk::PipelineExecutableStatisticFormatKHR::eBool32: + std::cerr << (s.value.b32 ? "true" : "false"); + break; + case vk::PipelineExecutableStatisticFormatKHR::eInt64: + std::cerr << s.value.i64; + break; + case vk::PipelineExecutableStatisticFormatKHR::eUint64: + std::cerr << s.value.u64; + break; + case vk::PipelineExecutableStatisticFormatKHR::eFloat64: + std::cerr << s.value.f64; + break; + } + std::cerr << std::endl; + } // "Register Count" is reported by NVIDIA drivers. if (strcmp(s.name, "Register Count") == 0) { VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers"); @@ -2755,78 +2784,214 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events ); } -// number of rows/cols for flash attention shader -static constexpr uint32_t flash_attention_num_small_rows = 32; -static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; +struct vk_fa_tuning_params { + FaCodePath path; + uint32_t workgroup_size; + uint32_t subgroup_size; + uint32_t block_rows; + uint32_t block_cols; + uint32_t d_split; + uint32_t row_split; + bool shmem_staging; + bool disable_subgroups; + uint32_t limit_occupancy_shmem; + + void print() const { + std::cerr << "path=" << path << " workgroup_size=" << workgroup_size << " subgroup_size=" << subgroup_size << + " block_rows=" << block_rows << " block_cols=" << block_cols << " d_split=" << d_split << + " row_split=" << row_split << " shmem_staging=" << shmem_staging << " disable_subgroups=" << disable_subgroups << + " limit_occupancy_shmem=" << limit_occupancy_shmem << std::endl; + } +}; + +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); + +static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { + GGML_UNUSED(kv_type); + + vk_fa_tuning_params result{}; + result.path = FA_SCALAR; -static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) { - if (hsv >= 192) { - return 2; - } else if ((hsv | hsk) & 8 || small_cache) { - return 4; + if (device->vendor_id == VK_VENDOR_ID_INTEL) { + // Disable subgroup use due to performance issues when enforcing subgroup sizes + result.subgroup_size = 32; + result.disable_subgroups = true; + } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) { + result.subgroup_size = n_rows < 4 ? 32 : device->subgroup_size; } else { - return 8; + result.subgroup_size = device->subgroup_size; } -} -// The FA coopmat1 shader assumes 16x16x16 matrix multiply support. -// 128 threads split into four subgroups, each subgroup does 1/4 -// of the Bc dimension. -static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; -static constexpr uint32_t scalar_flash_attention_Bc = 64; -static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; + // Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers + uint32_t row_split_max_hsk = 64; + if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !device->uma) { + row_split_max_hsk = n_rows <= 8 ? 64 : 128; + } + result.row_split = (n_rows < 4 || hsk <= row_split_max_hsk) ? 1 : 4; -static uint32_t get_fa_num_small_rows(FaCodePath path) { - if (path == FA_COOPMAT2) { - return flash_attention_num_small_rows; + if (result.subgroup_size > 32 && (n_rows < 4 || hsk < (result.row_split == 1 ? 128 : 64))) { + result.workgroup_size = result.subgroup_size * 2; } else { - return scalar_flash_attention_num_small_rows; + result.workgroup_size = result.subgroup_size * 4; } -} -static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) { - GGML_UNUSED(clamp); + const uint32_t D = hsk | hsv; + + const bool reduce_block_rows = D & 8 || n_kv < 1024 || device->vendor_id == VK_VENDOR_ID_INTEL; - if (path == FA_SCALAR) { - if (small_rows) { - return {scalar_flash_attention_num_small_rows, 64}; + if (n_rows == 1) { + result.block_rows = 1; + result.block_cols = 64; + } else { + // row_split 1 means higher register use per row, so block size has to be adjusted + if (result.row_split == 1) { + result.block_rows = n_rows == 2 ? 2 : ((n_rows <= 4 || reduce_block_rows) ? 4 : 8); } else { - if ((hsv | hsk) & 8) { - // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter - // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. - return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64}; - } else { - return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32}; - } + result.block_rows = n_rows <= 4 ? 4 : ((n_rows <= 8 || reduce_block_rows) ? 8 : 16); } + + result.block_cols = (D & 8) ? 64 : 32; } - if (path == FA_COOPMAT1) { - if (small_rows) { - return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc}; - } else { - return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; + const uint32_t D_lsb = D ^ (D & (D-1)); // extract lowest set bit + + result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4); + + result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; + + if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) { + result.block_rows /= 2; + } + + // On AMD RDNA, for small head sizes and big batch size the shader uses few registers, so too many subgroups get scheduled + // at once and end up thrashing the cache. Fix this by setting a large (unused) shmem buffer that reduces occupancy. + // This targets an occupancy of 4 subgroups per SIMD. + if (device->vendor_id == VK_VENDOR_ID_AMD && device->properties.limits.maxComputeSharedMemorySize == 65536) { + if (device->architecture != AMD_GCN && n_rows >= 64 && hsk <= 128) { + // 30kb target for hsk > 64, 26kb for <= 64 due to smaller workgroup size + // Values are guessed, tested on RDNA2 + result.limit_occupancy_shmem = (hsk <= 64 ? 26 : 30) * 1024 / 4 / 4; + } else if (device->architecture == AMD_GCN && n_rows <= 8 && hsk >= 256) { + // Same thing for GCN, with an occupancy target of 2 subgroups per SIMD. + // Here low-batch FA with large head size is affected. + // n_rows < 4 switch because workgroup size switches from 128 to 256 there. + result.limit_occupancy_shmem = (n_rows < 4 ? 14 : 26) * 1024 / 4 / 4; } } - // small rows, large cols + return result; +} + +static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { + GGML_UNUSED(n_rows); + GGML_UNUSED(n_kv); + GGML_UNUSED(kv_type); + GGML_UNUSED(f32acc); + + vk_fa_tuning_params result{}; + result.path = FA_COOPMAT1; + + const uint32_t D = hsk | hsv; + + const uint32_t coopmat_block_rows = 16; + const uint32_t coopmat_block_cols = 16; + + const uint32_t num_subgroups = 4; + + result.block_rows = coopmat_block_rows; + result.block_cols = coopmat_block_cols * num_subgroups; + result.row_split = num_subgroups; + result.subgroup_size = device->subgroup_size; + result.workgroup_size = num_subgroups * result.subgroup_size; + + const uint32_t D_lsb = D ^ (D & (D-1)); // extract lowest set bit + result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4); + + result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; + + return result; +} + +static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { + GGML_UNUSED(n_kv); + GGML_UNUSED(f32acc); + + vk_fa_tuning_params result{}; + result.path = FA_COOPMAT2; + + const uint32_t D = hsk | hsv; + + const bool small_rows = n_rows < 32; + if (small_rows) { - return {get_fa_num_small_rows(FA_COOPMAT2), 32}; + result.block_rows = 32; + result.block_cols = 32; + } else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) { + result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64; + result.block_cols = 32; + } else { + result.block_rows = 64; + result.block_cols = 64; } - // small cols to reduce register count - if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) { - if (hsk >= 512 || hsv >= 512) { - return {32, 32}; - } else { - return {64, 32}; + result.subgroup_size = device->subgroup_size; + result.workgroup_size = (small_rows && (D % 32) == 0) ? 256 : 128; + + return result; +} + +static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { + FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : + device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + + if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) { + // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 + path = FA_SCALAR; + } + + if (path == FA_COOPMAT1) { + bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || + (!f32acc && device->coopmat_support_16x16x16_f16acc); + const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); + + if (!shape_ok || !shmem_ok) { + path = FA_SCALAR; } } - return {64, 64}; + + // scalar is faster than coopmat when N==1 + if (n_rows == 1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) { + path = FA_SCALAR; + } + + switch (path) { + case FA_SCALAR: + return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + case FA_COOPMAT1: + return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + case FA_COOPMAT2: + return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + default: + throw std::runtime_error("unsupported FaCodePath"); + } +} + +static vk_fa_pipeline_state get_fa_pipeline_state(const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc, + bool use_mask, bool use_mask_opt, bool use_logit_softcap) { + uint32_t flags = (use_mask_opt ? 1 : 0) | + (use_mask ? 2 : 0) | + (use_logit_softcap ? 4 : 0); + + const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size; + + return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem}; } -static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) { - return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1]; +static std::vector get_fa_spec_constants(const vk_fa_pipeline_state& state) { + return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split, + state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem}; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -3193,76 +3358,43 @@ static void ggml_vk_load_shaders(vk_device& device) { align, disable_robustness, require_full_subgroups, required_subgroup_size); }; - auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array { - return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1}; - }; - - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector { - // For large number of rows, 128 invocations seems to work best. - // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we - // can't use 256 for D==80. - // For scalar, use 128 (arbitrary) - // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. - const uint32_t D = (hsk|hsv); - auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache); - - uint32_t wg_size; - switch (path) { - case FA_COOPMAT2: - wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128); - break; - case FA_COOPMAT1: - wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc - break; - default: - wg_size = scalar_flash_attention_workgroup_size; - break; - } - - // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. - // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. - const uint32_t D_lsb = D ^ (D & (D-1)); - uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); - - // Nvidia prefers shared memory use to load large tiles of K. - // Switch to loading from global memory when it would use too much shared memory. - // AMD prefers loading K directly from global memory - const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0; - - return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags}; - }; - #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ - uint32_t HSK = fa.first.HSK; \ - uint32_t HSV = fa.first.HSV; \ - bool small_rows = fa.first.small_rows; \ - bool small_cache = fa.first.small_cache; \ FaCodePath path = fa.first.path; \ + uint32_t Br = fa.first.Br; \ + uint32_t Bc = fa.first.Bc; \ bool aligned = fa.first.aligned; \ bool f32acc = fa.first.f32acc; \ - uint32_t flags = fa.first.flags; \ + uint32_t fa_sgs = fa.first.subgroup_size; \ + bool fa_ds = fa.first.subgroup_size == 0; \ if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } \ } \ } \ } - CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) + if (device->flash_attention_fp16) { + CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) + } else { + CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) + } #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1) @@ -4535,6 +4667,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch); +static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev); static vk_device ggml_vk_get_device(size_t idx) { VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); @@ -4751,6 +4884,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->shader_core_count = sm_props.shaderSMCount; } else if (amd_shader_core_properties2) { device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount; + } else if (device->vendor_id == VK_VENDOR_ID_INTEL) { + device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device); } else { device->shader_core_count = 0; } @@ -4970,11 +5105,7 @@ static vk_device ggml_vk_get_device(size_t idx) { #if defined(VK_KHR_cooperative_matrix) device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; - - // coopmat1 fa shader currently assumes 32 invocations per subgroup - device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support && - device->subgroup_size_control && device->subgroup_min_size <= 32 && - device->subgroup_max_size >= 32; + device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support; #endif if (coopmat2_support) { @@ -5292,6 +5423,10 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mmvq_mode = 1; } + // Driver issues with older AMD GPUs on Windows, see https://github.com/ggml-org/llama.cpp/pull/19625#issuecomment-3940840613 + const bool is_amd_proprietary_gcn = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture == AMD_GCN && device->driver_id == vk::DriverId::eAmdProprietary; + device->flash_attention_fp16 = device->fp16 && !is_amd_proprietary_gcn; + return device; } @@ -5542,6 +5677,10 @@ static void ggml_vk_instance_init() { vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr; vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr; vk_memory_logger_enabled = getenv("GGML_VK_MEMORY_LOGGER") != nullptr; + const char* GGML_VK_PIPELINE_STATS = getenv("GGML_VK_PIPELINE_STATS"); + if (GGML_VK_PIPELINE_STATS != nullptr) { + vk_pipeline_stats_filter = GGML_VK_PIPELINE_STATS; + } const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY"); if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) { @@ -8421,21 +8560,27 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { + GGML_UNUSED(f32acc); // Needs to be kept up to date on shader changes - GGML_UNUSED(hsv); - const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache); - const uint32_t Bc = scalar_flash_attention_Bc; + const uint32_t wg_size = params.workgroup_size; + const uint32_t Br = params.block_rows; + const uint32_t Bc = params.block_cols; + const uint32_t float_type_size = device->flash_attention_fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + + // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); - const uint32_t tmpshv4 = wg_size * 4 * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * float_type_size; + + const uint32_t masksh = Bc * (Br + 1) * float_type_size; - const uint32_t masksh = Bc * Br * sizeof(float); + const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; - const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float); + const uint32_t D = std::max(hsk, hsv); + const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size; - const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf; + const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); @@ -8443,18 +8588,17 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con return supported; } -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { // Needs to be kept up to date on shader changes - GGML_UNUSED(hsv); - const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false); - const uint32_t Br = rows_cols[0]; - const uint32_t Bc = rows_cols[1]; + const uint32_t Br = params.block_rows; + const uint32_t Bc = params.block_cols; const uint32_t MatBr = 16, MatBc = 16; const uint32_t row_split = Bc / MatBc; const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16); + const uint32_t hsv_pad = ROUNDUP_POW2(hsv, 16); const uint32_t acctype = f32acc ? 4 : 2; const uint32_t f16vec4 = 8; @@ -8470,17 +8614,19 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br; const uint32_t sfsh = Bc * sfshstride * acctype; - const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256; - const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2; + const uint32_t kvshstride = (params.shmem_staging ? std::max(hsk_pad, hsv_pad) : MatBr) / 4 + 2; const uint32_t vsh_stride = MatBc / 4 * row_split; - const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4; + const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4; + + const uint32_t osh_stride = params.row_split * MatBr / 4; + const uint32_t pvsh = MatBc * osh_stride * f16vec4; const uint32_t slope = Br * acctype; - const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope; + const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + pvsh + slope; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); return supported; } @@ -8538,48 +8684,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(q->type == GGML_TYPE_F32); assert(k->type == v->type); - FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : - ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; - - if (path == FA_COOPMAT1 && ctx->device->architecture == vk_device_architecture::NVIDIA_TURING) { - // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 - path = FA_SCALAR; - } - - if (path == FA_COOPMAT1) { - const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || - (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); - - const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type); - - if (!coopmat_shape_supported || !coopmat_shmem_supported) { - path = FA_SCALAR; - } - } - uint32_t gqa_ratio = 1; uint32_t qk_ratio = neq2 / nek2; uint32_t workgroups_x = (uint32_t)neq1; uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - const bool small_cache = nek1 < 1024; + const bool f32acc = !ctx->device->flash_attention_fp16 || dst->op_params[3] == GGML_PREC_F32; // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). - uint32_t max_gqa; - switch (path) { - case FA_SCALAR: - case FA_COOPMAT1: - // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache); - break; - case FA_COOPMAT2: - max_gqa = get_fa_num_small_rows(FA_COOPMAT2); - break; - default: - GGML_ASSERT(0); - } + vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc); + const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u); if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { @@ -8591,24 +8707,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx workgroups_y /= gqa_ratio; } - bool small_rows = N <= get_fa_num_small_rows(path); - - // coopmat1 does not actually support "small rows" (it needs 16 rows). - // So use scalar instead. - if (small_rows && path == FA_COOPMAT1) { - path = FA_SCALAR; - } - - // scalar is faster than coopmat2 when N==1 - if (N == 1 && path == FA_COOPMAT2) { - path = FA_SCALAR; - } - - // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory - if (path == FA_SCALAR && - !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) { - small_rows = true; - } + tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc); const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); @@ -8622,18 +8721,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride /= 4; } - uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache); + const uint32_t alignment = tuning_params.block_cols; bool aligned = (KV % alignment) == 0 && // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned. - if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) { + if (((HSK | HSV) % 16) != 0 && tuning_params.path == FA_COOPMAT2) { aligned = false; } - bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; - float scale = 1.0f; float max_bias = 0.0f; float logit_softcap = 0.0f; @@ -8648,12 +8745,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768; - - uint32_t flags = (use_mask_opt ? 1 : 0) | - (mask != nullptr ? 2 : 0) | - (logit_softcap != 0 ? 4 : 0); - - vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags); + vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(tuning_params, HSK, HSV, aligned, f32acc, + mask != nullptr, use_mask_opt, logit_softcap != 0); vk_pipeline pipeline = nullptr; @@ -8675,22 +8768,35 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t split_kv = KV; uint32_t split_k = 1; + // Intel Alchemist prefers more workgroups + const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1; + // Use a placeholder core count if one isn't available. split_k is a big help for perf. - const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; + const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16; + + const uint32_t Br = fa_pipeline_state.Br; + const uint32_t Bc = fa_pipeline_state.Bc; + + GGML_ASSERT(Br == pipeline->wg_denoms[0]); + const uint32_t Tr = CEIL_DIV(N, Br); // Try to use split_k when KV is large enough to be worth the overhead. - // Must either be a single batch or be using gqa, we can't mix the two. - if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) { - // Try to run two workgroups per SM. + if (gqa_ratio > 1 && workgroups_x <= Br) { split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z); - if (split_k > 1) { - // Try to evenly split KV into split_k chunks, but it needs to be a multiple - // of "align", so recompute split_k based on that. - split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); - split_k = CEIL_DIV(KV, split_kv); + } else if (gqa_ratio <= 1) { + uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z; + if (total_wgs_no_split < shader_core_count * 2) { + split_k = shader_core_count * 2 / total_wgs_no_split; } } + if (split_k > 1) { + // Try to evenly split KV into split_k chunks, but it needs to be a multiple + // of "align", so recompute split_k based on that. + split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); + split_k = CEIL_DIV(KV, split_kv); + } + // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3]. @@ -8704,10 +8810,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_preallocate_buffers(ctx, subctx); } - auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache); - const uint32_t Br = rows_cols[0]; - const uint32_t Bc = rows_cols[1]; - const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc); const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3; @@ -8787,15 +8889,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (ctx->prealloc_split_k_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - workgroups_x *= pipeline->wg_denoms[0]; + + // We reuse workgroups_x to mean the number of splits, so we need to + // cancel out the divide by wg_denoms[0]. + uint32_t dispatch_x; + if (gqa_ratio > 1) { + workgroups_x *= pipeline->wg_denoms[0]; + dispatch_x = split_k * workgroups_x; + } else { + dispatch_x = Tr * split_k * pipeline->wg_denoms[0]; + } + vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf}, - // We only use split_k when group query attention is enabled, which means - // there's no more than one tile of rows (i.e. workgroups_x would have been - // one). We reuse workgroups_x to mean the number of splits, so we need to - // cancel out the divide by wg_denoms[0]. - pc, { split_k * workgroups_x, workgroups_y, workgroups_z }); + pc, { dispatch_x, workgroups_y, workgroups_z }); ggml_vk_sync_buffers(ctx, subctx); const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) }; @@ -15420,6 +15528,46 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope } } +static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) { + VkPhysicalDeviceProperties2 props = vkdev.getProperties2(); + + if (props.properties.vendorID != VK_VENDOR_ID_INTEL) { + return 0; + } + + const uint32_t device_id = props.properties.deviceID; + + switch (device_id) { + case 0x56A6: // A310 + return 6; + case 0x5693: // A370M + case 0x56A5: // A380 + case 0x56B1: // Pro A40/A50 + return 8; + case 0x5697: // A530M + return 12; + case 0x5692: // A550M + case 0x56B3: // Pro A60 + return 16; + case 0x56A2: // A580 + return 24; + case 0x5691: // A730M + case 0x56A1: // A750 + return 28; + case 0x56A0: // A770 + case 0x5690: // A770M + return 32; + case 0xE212: // Pro B50 + return 16; + case 0xE20C: // B570 + return 18; + case 0xE20B: // B580 + return 20; + default: + return 0; + } +} + // checks #ifdef GGML_VULKAN_CHECK_RESULTS @@ -16096,7 +16244,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * ggml_vk_print_graph_origin(tensor, done); } - if (avg_err > 0.5 || std::isnan(avg_err)) { + if (avg_err > 0.01 || std::isnan(avg_err)) { std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; if (src0 != nullptr) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 0735f678549..135ab1ad625 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -3,9 +3,13 @@ #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_subgroup_extended_types_float16 : require +#endif + #extension GL_KHR_shader_subgroup_shuffle : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -15,8 +19,10 @@ const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; -const uint32_t cols_per_iter = WorkGroupSize / D_split; +const uint32_t rows_per_thread = Br / row_split; +const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; +const uint32_t num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGroupSize; layout (binding = 0) readonly buffer Q {float data_q[];}; @@ -27,20 +33,22 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -// Store the output when doing grouped query attention. -// Rows index by Q's dimension 2, and the first N rows are valid. -D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - uint32_t offset = (iq2 + r) * HSV + c; - data_o[o_offset + offset] = D_TYPE(elem); - return elem; -} +// If SubGroupSize is set to 0 then only use shmem reductions +const uint32_t tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize; +shared float tmpsh[tmpsh_size]; +shared FLOAT_TYPEV4 tmpshv4[tmpsh_size]; -shared FLOAT_TYPE tmpsh[WorkGroupSize]; -shared vec4 tmpshv4[WorkGroupSize]; +const uint32_t masksh_stride = Br + 1; +shared FLOAT_TYPE masksh[Bc * masksh_stride]; -shared float masksh[Bc][Br]; -shared vec4 Qf[Br][HSK / 4]; +const uint32_t qf_stride = HSK / 4 + 1; +shared FLOAT_TYPEV4 Qf[Br * qf_stride]; + +const uint32_t D = HSK > HSV ? HSK : HSV; +const uint32_t kvsh_stride = D / 4 + 1; +shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1]; + +shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1]; void main() { #ifdef NEEDS_INIT_IQ_SHMEM @@ -50,8 +58,24 @@ void main() { init_indices(); const uint32_t tid = gl_LocalInvocationIndex; + const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; + const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup; const uint32_t d_tid = gl_LocalInvocationIndex % D_split; - const uint32_t col_tid = gl_LocalInvocationIndex / D_split; + const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + + if (LIMIT_OCCUPANCY_SHMEM > 0) { + // This just exists to avoid the occupancy_limiter array getting optimized out + occupancy_limiter[tid] = vec4(tid); + + barrier(); + + if (occupancy_limiter[tid] == vec4(99999.0)) { + data_ov4[0] = D_TYPEV4(occupancy_limiter[tid]); + } + } + +#define tile_row(r) (row_tid * rows_per_thread + (r)) uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4; @@ -60,37 +84,37 @@ void main() { uint32_t r = (idx + tid) / (HSK / 4); if (r < Br && d < HSK / 4 && i * Br + r < N) { - Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; + Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } } barrier(); - vec4 Of[Br][HSV_per_thread / 4]; + FLOAT_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] = vec4(0.0); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = FLOAT_TYPEV4(0.0); } } - float Lf[Br], Mf[Br]; + float Lf[rows_per_thread], Mf[rows_per_thread]; // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Lf[r] = 0; Mf[r] = NEG_FLT_MAX_OVER_2; } - float slope[Br]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - slope[r] = 1.0; + ACC_TYPE slope[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + slope[r] = ACC_TYPE(1.0); } // ALiBi if (p.max_bias > 0.0f) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), iq2); } } @@ -113,75 +137,141 @@ void main() { uint32_t mask_opt = 0; uint32_t mask_opt_idx = ~0; + uint32_t mask_opt_bits = 0; [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { + if (MASK_ENABLE) { + if (USE_MASK_OPT && mask_opt_idx != j / 16) { + mask_opt_idx = j / 16; + mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; + } + mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { + // skip this block + continue; + } + // Only load if the block is not all zeros + if (mask_opt_bits != MASK_OPT_ALL_ZERO) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - if (USE_MASK_OPT && mask_opt_idx != j / 16) { - mask_opt_idx = j / 16; - mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; - } - uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; - if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { - // skip this block - continue; - } - // Only load if the block is not all zeros - if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - - float max_mask = NEG_FLT_MAX_OVER_2; - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) % Bc; - uint32_t r = (idx + tid) / Bc; - if (idx + tid < Bc * Br) { - if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); - masksh[c][r] = m; - max_mask = max(max_mask, m); - } else { - masksh[c][r] = float(0); + float max_mask = NEG_FLT_MAX_OVER_2; + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br) { + if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { + FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + masksh[c * masksh_stride + r] = m; + max_mask = max(max_mask, float(m)); + } else { + masksh[c * masksh_stride + r] = FLOAT_TYPE(0); + } } } - } - // skip the block if the mask is entirely -inf - bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); - barrier(); - if (gl_SubgroupInvocationID == 0) { - tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; - } - barrier(); - [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { - max_mask = max(max_mask, tmpsh[s]); - } - if (max_mask <= NEG_FLT_MAX_OVER_2) { - continue; + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); + barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } } } - float Sf[Br][cols_per_thread]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + ACC_TYPE Sf[rows_per_thread][cols_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - Sf[r][c] = 0.0; + Sf[r][c] = ACC_TYPE(0.0); } } + if (SHMEM_STAGING != 0) { + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t c = (idx + tid) / (HSK / 4); + if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) { + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); + if (!KV_bounds_check || j * Bc + c < KV) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); +#else + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); +#endif + } - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { - continue; + kvsh[c * kvsh_stride + d] = K_Tf; + } } + barrier(); + } + + // More d iterations means Q register caching becomes relevant + // Few iterations means the additional registers needed are worse than the speed-up from caching + if (HSK_per_thread / 4 > 4) { [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { + FLOAT_TYPEV4 Q_cache[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid]; + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + + FLOAT_TYPEV4 K_Tf; + if (SHMEM_STAGING != 0) { + K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; + } else { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else - vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); #endif - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf); + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf)); + } + } + } + } else { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { + FLOAT_TYPEV4 K_Tf; + if (SHMEM_STAGING != 0) { + K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; + } else { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); +#else + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); +#endif + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf)); + } } } } @@ -189,89 +279,109 @@ void main() { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { // Compute sum across the D_split [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Sf[r][c] += subgroupShuffleXor(Sf[r][c], s); } } } if (LOGIT_SOFTCAP) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); + Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c])); } } } if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float mvf = masksh[c * cols_per_iter + col_tid][r]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)]; Sf[r][c] += slope[r]*mvf; } } - barrier(); } - float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - rowmaxf[r] = NEG_FLT_MAX_OVER_2; + float eMf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float rowmaxf = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } - rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); + rowmaxf = max(rowmaxf, float(Sf[r][c])); } - Moldf[r] = Mf[r]; + float Moldf = Mf[r]; // M = max(rowmax, Mold) // P = e^(S - M) // eM = e^(Mold - M) - Mf[r] = max(rowmaxf[r], Moldf[r]); - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - Pf[r][c] = exp(Sf[r][c] - Mf[r]); - } - eMf[r] = exp(Moldf[r] - Mf[r]); + Mf[r] = max(rowmaxf, Moldf); + eMf[r] = exp(Moldf - Mf[r]); + Lf[r] = eMf[r]*Lf[r]; + } - // Compute sum across row of P - rowsumf[r] = 0.0; - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { - continue; - } - rowsumf[r] += Pf[r][c]; + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = FLOAT_TYPE(eMf[r]) * Of[r][d]; } - - Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; } - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] = eMf[r] * Of[r][d]; + if (SHMEM_STAGING != 0) { + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSV / 4); + uint32_t c = (idx + tid) / (HSV / 4); + if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) { + FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0); + if (!KV_bounds_check || j * Bc + c < KV) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); +#endif + } + + kvsh[c * kvsh_stride + d] = V_Tf; + } } + barrier(); } [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } + + FLOAT_TYPE Pf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Pf[r] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r])); + Lf[r] += Pf[r]; + } + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + FLOAT_TYPEV4 Vf; + if (SHMEM_STAGING != 0) { + Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; + } else { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); #else - vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); + Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); #endif - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] += Pf[r][c] * Vf; + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf); } } } - - barrier(); } // prevent race on tmpsh @@ -279,58 +389,108 @@ void main() { // reduce across threads - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float rowmaxf, eMf; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float rowmaxf = Mf[r]; - tmpsh[tid] = Mf[r]; // Compute max across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { - if (tid < s) { - tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]); + if (SubGroupSize > 0) { + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s)); } + if (row_split == 1) { + // Reduce inside workgroup with shmem + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf; + } + barrier(); + rowmaxf = tmpsh[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]); + } + } + } else { + barrier(); + tmpsh[tid] = rowmaxf; barrier(); + [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) { + if (rowgroup_tid < s) { + tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]); + } + barrier(); + } + rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid]; } - rowmaxf = tmpsh[d_tid]; - barrier(); float Moldf = Mf[r]; // M = max(rowmax, Mold) // eM = e^(Mold - M) Mf[r] = max(rowmaxf, Moldf); - eMf = exp(Moldf - Mf[r]); + float eMf = exp(Moldf - Mf[r]); Lf[r] = eMf*Lf[r]; - tmpsh[tid] = Lf[r]; - // Compute sum across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { - if (tid < s) { - tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s]; + if (SubGroupSize > 0) { + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + Lf[r] += subgroupShuffleXor(Lf[r], s); } + if (row_split == 1) { + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r]; + } + barrier(); + Lf[r] = tmpsh[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + Lf[r] += tmpsh[s * D_split + d_tid]; + } + } + } else { barrier(); + tmpsh[tid] = Lf[r]; + barrier(); + [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) { + if (rowgroup_tid < s) { + tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s]; + } + barrier(); + } + Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid]; } - Lf[r] = tmpsh[d_tid]; - barrier(); [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + Of[r][d] = FLOAT_TYPE(eMf) * Of[r][d]; - Of[r][d] = eMf * Of[r][d]; - tmpshv4[tid] = Of[r][d]; - - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { - if (tid < s) { - Of[r][d] += tmpshv4[tid + s]; - tmpshv4[tid] = Of[r][d]; + if (SubGroupSize > 0) { + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + Of[r][d] += subgroupShuffleXor(Of[r][d], s); + } + if (row_split == 1) { + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpshv4[gl_SubgroupID * D_split + d_tid] = Of[r][d]; + } + barrier(); + Of[r][d] = tmpshv4[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + Of[r][d] += tmpshv4[s * D_split + d_tid]; + } } + } else { + barrier(); + tmpshv4[tid] = Of[r][d]; barrier(); + [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) { + if (rowgroup_tid < s) { + Of[r][d] += tmpshv4[tid ^ s]; + tmpshv4[tid] = Of[r][d]; + } + barrier(); + } + Of[r][d] = tmpshv4[row_tid * threads_per_rowgroup + d_tid]; } - Of[r][d] = tmpshv4[d_tid]; - barrier(); } } @@ -338,33 +498,53 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - // note: O and Q have swapped coord 1,2. - uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (r < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + if (p.gqa_ratio > 1) { + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4; + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (row < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N); } } } - } - o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (r < N) { - perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); - perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (row < N) { + perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } } - } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + const uint global_row = i * Br + row; + + if (global_row < N) { + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4; + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + data_ov4[o_offset + iq2 * HSV/4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]); + } + } + + if (global_row < N && d_tid == 0 && col_tid == 0) { + uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[lm_offset + iq2] = D_TYPE(Lf[r]); + data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]); + } + } + } return; } if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2); float ms = 1.0f; float vs = 1.0f; @@ -373,7 +553,7 @@ void main() { ms = exp(Mf[r] - sink); [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - Of[r][d] *= ms; + Of[r][d] *= FLOAT_TYPE(ms); } } else { vs = exp(sink - Mf[r]); @@ -383,39 +563,37 @@ void main() { } } - float Lfrcp[Br]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float Lfrcp[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]); } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] *= Lfrcp[r]; -#if defined(ACC_TYPE_MAX) - Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX)); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] *= FLOAT_TYPE(Lfrcp[r]); +#if defined(FLOAT_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX); #endif } } - uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4; if (p.gqa_ratio > 1) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (r < N) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (row < N) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); - } + gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N); } } } } else { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (i * Br + r < N) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (i * Br + row < N) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); - } + data_ov4[o_offset + (iq2 * HSV + (i * Br + row) * p.ne1 * HSV) / 4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 4142c1e6eaa..d444542b533 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -1,16 +1,18 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (constant_id = 0) const uint32_t WorkGroupSize = 128; -layout (constant_id = 1) const uint32_t Br = 1; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t HSK = 32; -layout (constant_id = 4) const uint32_t HSV = 32; -layout (constant_id = 5) const uint32_t Clamp = 0; -layout (constant_id = 6) const uint32_t D_split = 16; -layout (constant_id = 7) const uint32_t SubGroupSize = 32; -layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0; -layout (constant_id = 9) const uint32_t Flags = 0; +layout (constant_id = 0) const uint32_t WorkGroupSize = 128; +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t HSK = 32; +layout (constant_id = 4) const uint32_t HSV = 32; +layout (constant_id = 5) const uint32_t Clamp = 0; +layout (constant_id = 6) const uint32_t D_split = 16; +layout (constant_id = 7) const uint32_t row_split = 1; +layout (constant_id = 8) const uint32_t SubGroupSize = 32; +layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0; +layout (constant_id = 10) const uint32_t Flags = 0; +layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0; const bool USE_MASK_OPT = (Flags & 1) != 0; const bool MASK_ENABLE = (Flags & 2) != 0; @@ -69,6 +71,7 @@ layout (push_constant) uniform parameter { layout (binding = 4) readonly buffer S {float data_s[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; +layout (binding = 5) writeonly buffer OV4 {D_TYPEV4 data_ov4[];}; layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; @@ -94,12 +97,12 @@ layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16 #define BLOCK_SIZE 4 #define BLOCK_BYTE_SIZE 16 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { // iqs is currently always zero in the flash attention shaders if (binding_idx == BINDING_IDX_K) { - return k_packed.k_data_packed[a_offset + ib]; + return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]); } else { - return v_packed.v_data_packed[a_offset + ib]; + return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]); } } #endif @@ -107,7 +110,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); @@ -115,7 +118,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); } else { uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); @@ -123,24 +126,24 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); } } #endif #if defined(DATA_A_Q8_0) #define BLOCK_BYTE_SIZE 34 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); } else { const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); } } #endif @@ -189,10 +192,16 @@ void init_indices() KV = p.KV; if (p.k_num > 1) { - i = 0; - // batch and split_k share gl_WorkGroupID.x - gqa_iq1 = gl_WorkGroupID.x / p.k_num; - split_k_index = gl_WorkGroupID.x % p.k_num; + if (p.gqa_ratio > 1) { + i = 0; + // batch and split_k share gl_WorkGroupID.x + gqa_iq1 = gl_WorkGroupID.x / p.k_num; + split_k_index = gl_WorkGroupID.x % p.k_num; + } else { + gqa_iq1 = 0; + split_k_index = gl_WorkGroupID.x % p.k_num; + i = gl_WorkGroupID.x / p.k_num; + } } else if (p.gqa_ratio > 1) { i = 0; gqa_iq1 = gl_WorkGroupID.x; @@ -244,3 +253,11 @@ void init_indices() // Bias applied to softmax to stay in fp16 range. // Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606 const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f; + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * HSV / 4 + c; + data_ov4[o_offset + offset] = D_TYPEV4(elems); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 19630972daf..526e8da384e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -19,7 +19,6 @@ const uint32_t MatBr = 16; const uint32_t MatBc = 16; -const uint32_t row_split = Bc / MatBc; const uint32_t rows_per_thread = Br / row_split; const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; @@ -33,15 +32,6 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -// Store the output when doing grouped query attention. -// Rows index by Q's dimension 2, and the first N rows are valid. -D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - uint32_t offset = (iq2 + r) * HSV + c; - data_o[o_offset + offset] = D_TYPE(elem); - return elem; -} - shared float tmpsh[row_split]; const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 @@ -54,10 +44,14 @@ shared f16vec4 Psh[Bc * psh_stride]; const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4; shared ACC_TYPEV4 sfsh[Bc * sfshstride]; -const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4 +const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad; +const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4 const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups const uint vsh_stride = v_cols; -shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)]; +shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)]; + +const uint32_t osh_stride = row_split * MatBr / 4; +shared f16vec4 pvsh[MatBc * osh_stride]; shared ACC_TYPE slope[Br]; @@ -84,11 +78,6 @@ void main() { Qf[i + tid] = f16vec4(0); } } - [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) { - if (i + tid < Bc * kshstride) { - ksh[i + tid] = f16vec4(0); - } - } barrier(); } @@ -104,10 +93,10 @@ void main() { } barrier(); - ACC_TYPEV4 Of[rows_per_thread][d_per_thread]; + f16vec4 Of[rows_per_thread][d_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) { - Of[r][d] = ACC_TYPEV4(0.0); + Of[r][d] = f16vec4(0.0); } } @@ -153,22 +142,22 @@ void main() { uint32_t mask_opt = 0; uint32_t mask_opt_idx = ~0; + uint32_t mask_opt_bits = 0; + f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize]; [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize]; [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) { mask_cache[idx] = f16vec4(0); } if (MASK_ENABLE) { - if (USE_MASK_OPT && mask_opt_idx != j / 16) { mask_opt_idx = j / 16; mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; } - uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { // skip this block continue; @@ -231,24 +220,24 @@ void main() { } } - if (K_LOAD_SHMEM != 0) { - [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { - uint32_t d = (idx + tid) % (HSK / 4); - uint32_t c = (idx + tid) / (HSK / 4); - if (c < Bc && d < HSK / 4) { + if (SHMEM_STAGING != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK_pad / 4); + uint32_t c = (idx + tid) / (HSK_pad / 4); + if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) { f16vec4 K_Tf = f16vec4(0); - if (!KV_bounds_check || j * Bc + c < KV) { + if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); #endif } - ksh[c * kshstride + d] = K_Tf; + kvsh[c * kvsh_stride + d] = K_Tf; } } barrier(); @@ -262,7 +251,11 @@ void main() { coopmat QMat; [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) { - if (K_LOAD_SHMEM == 0) { + // If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem + // If not, f16 K is loaded directly from global memory if aligned, otherwise + // staged through a Bc * MatBr size staging buffer. + // If K is not type f16, then it is always staged for dequantization. + if (SHMEM_STAGING == 0) { #if BLOCK_SIZE == 1 if (KV_bounds_check || d * 16 + 16 > HSK) { #endif @@ -277,13 +270,13 @@ void main() { uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4; uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); #endif } - ksh[row * kshstride + col_vec] = K_Tf; + kvsh[row * kvsh_stride + col_vec] = K_Tf; } } barrier(); @@ -295,8 +288,8 @@ void main() { if (KV_bounds_check || d * 16 + 16 > HSK) #endif { - uint coord = (gl_SubgroupID * MatBc) * kshstride; - coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + uint coord = (gl_SubgroupID * MatBc) * kvsh_stride; + coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); } #if BLOCK_SIZE == 1 else { @@ -305,8 +298,8 @@ void main() { } #endif } else { - uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; - coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4; + coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); } coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); @@ -329,7 +322,7 @@ void main() { barrier(); } - if (MASK_ENABLE) { + if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) / (Br / 4); uint32_t r = (idx + tid) % (Br / 4); @@ -374,7 +367,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local]; + Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local]; } } @@ -397,19 +390,47 @@ void main() { } } + if (SHMEM_STAGING != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSV_pad / 4); + uint32_t c = (idx + tid) / (HSV_pad / 4); + if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) { + f16vec4 V_Tf = f16vec4(0); + if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); +#endif + } + + kvsh[c * kvsh_stride + d] = V_Tf; + } + } + } + barrier(); + const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up // Each subgroup handles HSV/4 columns [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) { const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16; - SfMat = coopmat(0); + coopmat PVMat = coopmat(0); // Preload V tiles for [Bc, 16 * num subgroups] const uint v_rows = Bc; const uint v_total = v_rows * v_cols; const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x; + // If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem. + // If not, f16 V is loaded directly from global memory if aligned, otherwise + // staged through a Bc * MatBr size staging buffer. + // If V is not type f16, then it is always staged for dequantization. + if (SHMEM_STAGING == 0) { #if BLOCK_SIZE == 1 // For f16, only preload if not aligned if (KV_bounds_check) { @@ -428,44 +449,52 @@ void main() { if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { #if BLOCK_SIZE > 1 - ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); + kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); #else - ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; + kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; #endif } else { - ksh[row * vsh_stride + col] = f16vec4(0.0f); + kvsh[row * vsh_stride + col] = f16vec4(0.0f); } } + #if BLOCK_SIZE == 1 } #endif - + } barrier(); - [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) { - coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor); + const uint o_offset = gl_SubgroupID * MatBr / 4; + + if (hsv_offset < HSV_pad) { + [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) { + coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor); + if (SHMEM_STAGING == 0) { #if BLOCK_SIZE == 1 - if (!KV_bounds_check) { - // F16 values can be loaded directly from global memory - const uint v_tile_row = j * Bc + bc_chunk * MatBc; - const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; - coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); - } else + if (!KV_bounds_check) { + // F16 values can be loaded directly from global memory + const uint v_tile_row = j * Bc + bc_chunk * MatBc; + const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; + coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); + } else #endif - { - const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4); - coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor); + { + const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4); + coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + } else { + const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4); + coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + + PVMat = coopMatMulAdd(KMat, QMat, PVMat); } - SfMat = coopMatMulAdd(KMat, QMat, SfMat); + // Store PVMat to pvsh and load into Of + coopMatStore(PVMat, pvsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor); } - // Store SfMat to sfsh and load into Of - const uint osh_stride = row_split * MatBc / 4; - const uint o_offset = gl_SubgroupID * MatBc / 4; - coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor); - barrier(); const uint hsv_per_tile = row_split * MatBc; @@ -484,7 +513,7 @@ void main() { if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) { const uint local_hsv = (hsv_col - hsv_base) / 4; - Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]); + Of[r][d_local] += pvsh[row * osh_stride + local_hsv]; } } } @@ -500,27 +529,48 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - // note: O and Q have swapped coord 1,2. - uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + if (p.gqa_ratio > 1) { + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4; - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - if (tile_row(r) < N) { - [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { - const uint d = d0 + col_tid; - if (d >= HSV/4) break; - const uint d_local = d0 / threads_per_rowgroup; - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV/4) break; + const uint d_local = d0 / threads_per_rowgroup; + gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N); } } } - } - o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - if (tile_row(r) < N) { - perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); - perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + const uint global_row = i * Br + row; + + if (global_row < N) { + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4; + + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV/4) break; + data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]); + } + } + + if (global_row < N && col_tid == 0) { + uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[lm_offset + iq2] = D_TYPE(Lf[r]); + data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]); + } } } @@ -539,7 +589,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; - Of[r][d_local] *= ACC_TYPE(ms); + Of[r][d_local] *= float16_t(ms); } } else { vs = exp(sink - Mf[r]); @@ -557,14 +607,14 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d_local] *= ACC_TYPE(Lfrcp[r]); -#if defined(ACC_TYPE_MAX) - Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX); + Of[r][d_local] *= float16_t(Lfrcp[r]); +#if defined(FLOAT_TYPE_MAX) + Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX); #endif } } - uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { @@ -573,9 +623,7 @@ void main() { const uint d = d0 + col_tid; if (d >= HSV / 4) break; const uint d_local = d0 / threads_per_rowgroup; - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N); - } + gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N); } } } @@ -586,9 +634,7 @@ void main() { const uint d = d0 + col_tid; if (d >= HSV / 4) break; const uint d_local = d0 / threads_per_rowgroup; - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]); - } + data_ov4[o_offset + (iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV) / 4 + d] = D_TYPEV4(Of[r][d_local]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 853f17fa16e..0ea181342ce 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -72,6 +72,28 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } +// Store O values for non-GQA split_k. Rows are tokens, not heads. +D_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) { + uint32_t global_row = i * Br + r; + if (global_row < N && c < HSV) { + uint32_t o_off = HSV * p.ne1 + * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[o_off + iq2 * HSV + c] = D_TYPE(elem); + } + return elem; +} + +// Store L/M values for non-GQA split_k. +ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) { + uint32_t global_row = i * Br + r; + if (global_row < N && c == 0) { + uint32_t lm_off = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[lm_off + lm_base + iq2] = D_TYPE(elem); + } + return elem; +} + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -290,13 +312,19 @@ void main() { if (p.k_num > 1) { coopmat O_D = coopmat(O); - // note: O and Q have swapped coord 1,2. - uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); - - o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); - coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + if (p.gqa_ratio > 1) { + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); + coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + } else { + coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N); + coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N); + coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N); + } return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 42ebc21e2a6..85455988c57 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -595,8 +595,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } void process_shaders() { - std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}}; - // matmul for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) { // No coopmats @@ -622,49 +620,63 @@ void process_shaders() { } } - // flash attention - for (const auto& f16acc : {false, true}) { - std::map fa_base_dict = base_dict; - fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; - fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4"; - if (f16acc) { - fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; + for (const bool& fp16 : {false, true}) { + std::map base_dict; + if (fp16) { + base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}}; + } else { + base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}}; } - for (const auto& tname : type_names) { - if (tname == "bf16") continue; + // flash attention + for (const bool& f16acc : {false, true}) { + std::map fa_base_dict = base_dict; + fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float"; + fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4"; + if (fp16 && f16acc) { + fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; + } + + for (const auto& tname : type_names) { + if (tname == "bf16") continue; + if (fp16) { #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc); - } else { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); - } + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); + } else { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc); + } #endif #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); - } + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); + } #endif - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); + } + + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); + } } } } + std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}}; + for (const auto& tname : type_names) { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); From 279be33a83890d79e68607e4354b43212b57d38e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 24 Feb 2026 20:17:11 +0200 Subject: [PATCH 185/831] ggml/gguf : prevent integer overflows (llama/19856) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * gguf : prevent integer overflow for ggml_context mem size * ggml : fix int overflows in ggml_new_object() * gguf : prevent string exhaustion * gguf : prevent array elements exhaustion * ggml : fix negative tensor type oob * py : assert that alignment is non-zero power of 2 * ggml : check int overflow in ggml_new_tensor_impl and ggml_new_object * gguf-py : error on duplicate keys when reading * py : restore tensor_fields * enforce proper alignment in add_custom_alignment * gguf : better name * gguf : fix ctx size for no_alloc == true * gguf : minor print fix * ggml : print values when overflow * ggml : remove deprecated ggml_type_sizef() * ggml : relax ggml_type asserts to debug-only * gguf : add mem_size overflow test * gguf : add file size check for arrays * ggml : relax asseerts for ggml_get_type_traits() * flake8 fix --------- Co-authored-by: Sigbjørn Skjæret --- ggml/include/ggml.h | 4 ---- ggml/src/ggml.c | 33 +++++++++++++++++++++++++++------ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 77af0e7fb6a..fcc51f1f71a 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -730,10 +730,6 @@ extern "C" { GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row - GGML_DEPRECATED( - GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float - "use ggml_row_size() instead"); - GGML_API const char * ggml_type_name(enum ggml_type type); GGML_API const char * ggml_op_name (enum ggml_op op); GGML_API const char * ggml_op_symbol(enum ggml_op op); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ed819eaa4c5..e9529fbb662 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -899,7 +899,8 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { }; const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) { - GGML_ASSERT(type < GGML_TYPE_COUNT); + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); return &type_traits[type]; } @@ -1265,27 +1266,33 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) { } int64_t ggml_blck_size(enum ggml_type type) { + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); return type_traits[type].blck_size; } size_t ggml_type_size(enum ggml_type type) { + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); return type_traits[type].type_size; } size_t ggml_row_size(enum ggml_type type, int64_t ne) { + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); assert(ne % ggml_blck_size(type) == 0); return ggml_type_size(type)*ne/ggml_blck_size(type); } -double ggml_type_sizef(enum ggml_type type) { - return ((double)(type_traits[type].type_size))/type_traits[type].blck_size; -} - const char * ggml_type_name(enum ggml_type type) { - return type < GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE"; + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); + return type_traits[type].type_name; } bool ggml_is_quantized(enum ggml_type type) { + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); return type_traits[type].is_quantized; } @@ -1629,11 +1636,23 @@ static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml const size_t cur_end = cur_offs + cur_size; // align to GGML_MEM_ALIGN + GGML_ASSERT(size <= SIZE_MAX - (GGML_MEM_ALIGN - 1)); size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN); char * const mem_buffer = ctx->mem_buffer; struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); + // integer overflow checks + if (cur_end > SIZE_MAX - size_needed) { + GGML_LOG_WARN("%s: overflow detected in cur_end (%zu) + size_needed (%zu)\n", __func__, cur_end, size_needed); + return NULL; + } + if (cur_end + size_needed > SIZE_MAX - GGML_OBJECT_SIZE) { + GGML_LOG_WARN("%s: overflow detected in cur_end (%zu) + size_needed (%zu) + GGML_OBJECT_SIZE (%zu)\n", __func__, + cur_end, size_needed, (size_t) GGML_OBJECT_SIZE); + return NULL; + } + if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size); @@ -1702,6 +1721,8 @@ static struct ggml_tensor * ggml_new_tensor_impl( obj_alloc_size = data_size; } + GGML_ASSERT(GGML_TENSOR_SIZE <= SIZE_MAX - obj_alloc_size); + struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size); GGML_ASSERT(obj_new); From fb55b2654b8510f044676fa3298133d01e40d5d5 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 25 Feb 2026 11:25:38 -0600 Subject: [PATCH 186/831] vulkan: check for memory overlap before doing fusion (llama/19768) * vulkan: check for memory overlap before doing fusion * Update ggml/src/ggml-vulkan/ggml-vulkan.cpp * address feedback --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 133 ++++++++++++++++++++++++--- 1 file changed, 119 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8a9cfaf1654..a1149e606e4 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -13820,12 +13820,11 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const return true; } -// Check whether the tensors overlap in memory but are not equal. -// Fusions can potenitally overwrite src tensors in ways that are not prevented -// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them -// to overlap if they are exactly equal. -// XXX TODO this check is probably missing from several fusion optimizations. -static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) { +// Check whether the tensors overlap in memory. +// Fusions can potentially overwrite src tensors in ways that are not prevented +// by ggml-alloc. If the fusion src is being applied in a way that's elementwise +// with the destination, then it's OK for them to overlap if they are exactly equal. +static bool ggml_vk_tensors_overlap(const ggml_tensor * a, const ggml_tensor * b, bool elementwise) { ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context; vk_buffer a_buf = a_buf_ctx->dev_buffer; ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context; @@ -13836,7 +13835,7 @@ static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const g auto b_base = vk_tensor_offset(b) + b->view_offs; auto b_size = ggml_nbytes(b); - if (a_base == b_base && a_size == b_size) { + if (elementwise && a_base == b_base && a_size == b_size) { return false; } @@ -13874,13 +13873,6 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co return false; } - // must not overwrite srcs in a way that's not elementwise - ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0]; - if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) || - ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) { - return false; - } - // conditions for pipeline creation if (!(ctx->device->float_controls_rte_fp16 && sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) { @@ -13942,6 +13934,18 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru return num_adds; } +static int32_t find_first_set(uint32_t x) { + int32_t ret = 0; + if (!x) { + return -1; + } + while (!(x & 1)) { + x >>= 1; + ret++; + } + return ret; +} + static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; @@ -14040,6 +14044,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg total_mul_mat_bytes += bytes; } + // op_srcs_fused_elementwise indicates whether an op's srcs all contribute to + // the fused result in an elementwise-way. This affects whether the memory for + // the src is allowed to overlap the memory for the destination. + // The array is sized to handle the largest fusion (asserted later). + bool op_srcs_fused_elementwise[12]; + ctx->fused_topk_moe_mode = TOPK_MOE_COUNT; ctx->fused_topk_moe_scale = false; const char *fusion_string {}; @@ -14048,39 +14058,68 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg if (num_adds) { ctx->num_additional_fused_ops = num_adds - 1; fusion_string = "MULTI_ADD"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, true); } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) { ctx->num_additional_fused_ops = 2; fusion_string = "MUL_MAT_ADD_ADD"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) { ctx->num_additional_fused_ops = 1; fusion_string = "MUL_MAT_ADD"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 2; fusion_string = "MUL_MAT_ID_ADD_ID_MUL"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) { ctx->num_additional_fused_ops = 1; fusion_string = "MUL_MAT_ID_ADD_ID"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; fusion_string = "MUL_MAT_ID_MUL"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) && ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) && ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) { ctx->num_additional_fused_ops = 4; fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = false; + op_srcs_fused_elementwise[2] = false; + op_srcs_fused_elementwise[3] = false; + op_srcs_fused_elementwise[4] = false; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&& ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) { ctx->num_additional_fused_ops = 2; fusion_string = "RMS_NORM_MUL_ROPE"; + // rope is approximately elementwise - whole rows are done by a single workgroup and it's row-wise + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; fusion_string = "RMS_NORM_MUL"; + // rms_norm is not elementwise, but whole rows must be consumed and the scale factor computed before + // they are overwritten, and one workgroup per row. So close enough. + op_srcs_fused_elementwise[0] = true; + op_srcs_fused_elementwise[1] = true; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { ctx->num_additional_fused_ops = 2; fusion_string = "ROPE_VIEW_SET_ROWS"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = false; + op_srcs_fused_elementwise[2] = false; } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { @@ -14089,6 +14128,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 3; ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM; fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) && ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) { @@ -14097,6 +14137,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 4; ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS; fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { @@ -14105,6 +14146,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 3; ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX; fusion_string = "TOPK_MOE_EARLY_SOFTMAX"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) { @@ -14113,6 +14155,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 1; ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX; fusion_string = "TOPK_MOE_LATE_SOFTMAX"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) { // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano. @@ -14120,11 +14163,73 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) { ctx->fused_topk_moe_scale = true; ctx->num_additional_fused_ops++; + op_srcs_fused_elementwise[ctx->num_additional_fused_ops] = false; } } } + GGML_ASSERT(ctx->num_additional_fused_ops < (int)(sizeof(op_srcs_fused_elementwise) / sizeof(op_srcs_fused_elementwise[0]))); ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops; + // Check whether fusion would overwrite src operands while they're still in use. + // If so, disable fusion. + if (ctx->num_additional_fused_ops) { + // There are up to two output nodes - topk_moe has two. + uint32_t bits = ctx->fused_ops_write_mask & ~(1 << ctx->num_additional_fused_ops); + ggml_tensor *output_nodes[2] {}; + output_nodes[0] = cgraph->nodes[i + ctx->num_additional_fused_ops]; + if (bits) { + int output_idx = find_first_set(bits); + GGML_ASSERT(bits == (1u << output_idx)); + output_nodes[1] = cgraph->nodes[i + output_idx]; + } + + bool need_disable = false; + + // topk_moe often overwrites the source, but for a given row all the src values are + // loaded before anything is stored. If there's only one row, this is safe, so treat + // this as a special case. + bool is_topk_moe_single_row = ctx->fused_topk_moe_mode != TOPK_MOE_COUNT && + ggml_nrows(cgraph->nodes[i]->src[0]) == 1; + + if (!is_topk_moe_single_row) { + for (int j = 0; j < 2; ++j) { + ggml_tensor *dst = output_nodes[j]; + if (!dst) { + continue; + } + // Loop over all srcs of all nodes in the fusion. If the src overlaps + // the destination and the src is not an intermediate node that's being + // elided, then disable fusion. + for (int k = 0; k <= ctx->num_additional_fused_ops; ++k) { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + ggml_tensor *src = cgraph->nodes[i + k]->src[s]; + if (!src || src->op == GGML_OP_NONE) { + continue; + } + if (ggml_vk_tensors_overlap(src, dst, op_srcs_fused_elementwise[k])) { + bool found = false; + for (int n = 0; n < k; ++n) { + if (cgraph->nodes[i + n] == src) { + found = true; + break; + } + } + if (!found) { + need_disable = true; + } + } + } + } + } + } + if (need_disable) { + ctx->num_additional_fused_ops = 0; + ctx->fused_ops_write_mask = 1; + ctx->fused_topk_moe_mode = TOPK_MOE_COUNT; + ctx->fused_topk_moe_scale = false; + } + } + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; bool submit = (submitted_nodes >= nodes_per_submit) || From 4cac408c6030953fe192b869d1d4a85ab0911bd8 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Thu, 26 Feb 2026 10:27:20 +0800 Subject: [PATCH 187/831] support permuted, remove check s0/s10 (llama/19889) Co-authored-by: Neo Zhang Jianyu --- ggml/src/ggml-sycl/binbcast.cpp | 41 +++++++++++++++++---------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index 0a3883ae1ed..92dd18889f4 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -11,8 +11,8 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13, + int s00, int s01, int s02, int s03, + int s10, int s11, int s12, int s13, const sycl::nd_item<3> &item_ct1) { const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -44,7 +44,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, for (int i0 = i0s; i0 < ne0; i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]); } } @@ -53,8 +53,8 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13, + int s00, int s01, int s02, int s03, + int s10, int s11, int s12, int s13, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + @@ -82,7 +82,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t dst_t * dst_row = dst + i_dst; const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]); } @@ -95,7 +95,8 @@ struct bin_bcast_sycl { const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03, const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0, const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous, - const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) { + const bool src1_is_contiguous, const bool src0_is_permuted, const bool src1_is_permuted, + queue_ptr stream) { int nr0 = ne10 / ne0; int nr1 = ne11/ne1; int nr2 = ne12/ne2; @@ -123,7 +124,7 @@ struct bin_bcast_sycl { cnb[3] *= cne[3]; }; - if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) { + if (src0_is_contiguous && src1_is_contiguous && !src0_is_permuted && !src1_is_permuted) { for (int i = 0; i < 4; i++) { if (nr[i] != 1) { break; @@ -164,7 +165,7 @@ struct bin_bcast_sycl { size_t nb12 = cnb1[2]; size_t nb13 = cnb1[3]; - size_t s0 = nb0 / sizeof(dst_t); + // size_t s0 = nb0 / sizeof(dst_t); size_t s1 = nb1 / sizeof(dst_t); size_t s2 = nb2 / sizeof(dst_t); size_t s3 = nb3 / sizeof(dst_t); @@ -196,9 +197,6 @@ struct bin_bcast_sycl { GGML_ASSERT(nb12 % sizeof(src1_t) == 0); GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s10 == 1); - const int block_size = 128; int64_t hne0 = std::max(ne0/2LL, 1LL); @@ -232,8 +230,8 @@ struct bin_bcast_sycl { [=](sycl::nd_item<3> item_ct1) { k_bin_bcast_unravel( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, - s03, s11, s12, s13, item_ct1); + ne10, ne11, ne12, ne13, s1, s2, s3, s00, s01, s02, + s03, s10, s11, s12, s13, item_ct1); }); } } else { @@ -251,7 +249,7 @@ struct bin_bcast_sycl { [=](sycl::nd_item<3> item_ct1) { k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, - s1, s2, s3, s01, s02, s03, s11, s12, s13, + s1, s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, item_ct1); }); } @@ -268,24 +266,27 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, - ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, - nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), + nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, - nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), + main_stream); } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, - nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), + main_stream); } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, - nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), + main_stream); } else { fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); From f877e1b20211e58a7e9dca1f0ed81c3deec0efa7 Mon Sep 17 00:00:00 2001 From: Kevin Pouget Date: Thu, 26 Feb 2026 13:00:57 +0100 Subject: [PATCH 188/831] ggml-virtgpu: improve the reliability of the code (llama/19846) * ggml-virtgpu-backend: validate the consistency of the received objects This patch adds consistency checks in the ggml-virtgpu-backend (running on the host side) to ensure that the data received from the guest is consistent (valid pointers, valid sizes and offsets). * ggml-virtgpu-backend: add fallback/skips for optional ggml backend methods ``` 1. bck->iface.synchronize(bck) 2. buft->iface.get_alloc_size(buft, op) 3. buft->iface.get_max_size(buft) ``` these three methods are optional in the GGML interface. `get_max_size` was already properly defaulted, but `backend sychronize` and `butf get_max_size` would have segfaulted the backend if not implemented. * ggml-virtgpu-backend: fix log format missing argument * ggml-virtgpu-backend: improve the abort message * ggml-virtgpu-backend: more safety checks * ggml-virtgpu-backend: new error code * ggml-virtgpu-backend: initialize all the error codes * ggml-virtgpu: add a missing comment generated by the code generator * ggml-virtgpu: add the '[virtgpu]' prefix to the device/buffer names * ggml-virtgpu: apir_device_buffer_from_ptr: improve the error message * ggml-virtgpu: shared: make it match the latest api_remoting.h of Virglrenderer APIR (still unmerged) * ggml-virtgpu: update the code generator to have dispatch_command_name in a host/guest shared file * ggml-virtgpu: REMOTE_CALL: fail if the backend returns an error * docs/backend/VirtGPU.md: indicate that the RAM+VRAM size is limed to 64 GB with libkrun * ggml-virtgpu: turn off clang-format header ordering for some of the files Compilation breaks when ordered alphabetically. * ggml-virtgpu: clang-format * ggml-virtgpu/backend/shared/api_remoting: better comments for the APIR return codes --- .../backend/backend-dispatched-backend.cpp | 43 +++++++++- .../backend-dispatched-buffer-type.cpp | 14 ++- .../backend/backend-dispatched-buffer.cpp | 48 +++++++++++ .../backend/backend-dispatched.cpp | 13 ++- .../backend/backend-dispatched.gen.h | 58 ------------- .../ggml-virtgpu/backend/backend-dispatched.h | 2 + .../ggml-virtgpu/backend/backend-virgl-apir.h | 2 +- ggml/src/ggml-virtgpu/backend/backend.cpp | 38 ++++----- .../backend/shared/api_remoting.h | 21 +++-- .../backend/shared/apir_backend.gen.h | 58 +++++++++++++ .../backend/shared/apir_backend.h | 6 +- .../src/ggml-virtgpu/backend/shared/apir_cs.h | 20 ++--- .../backend/shared/apir_cs_ggml.h | 27 ++++-- .../ggml-virtgpu/backend/shared/apir_cs_rpc.h | 4 + .../ggml-virtgpu/ggml-backend-buffer-type.cpp | 6 +- ggml/src/ggml-virtgpu/ggml-backend-device.cpp | 7 +- ggml/src/ggml-virtgpu/ggml-backend-reg.cpp | 56 ++++++++---- ggml/src/ggml-virtgpu/ggml-backend.cpp | 2 +- ggml/src/ggml-virtgpu/ggml-remoting.h | 2 +- ggml/src/ggml-virtgpu/include/apir_hw.h | 6 +- ggml/src/ggml-virtgpu/regenerate_remoting.py | 47 +++++----- .../ggml-virtgpu/virtgpu-forward-backend.cpp | 6 +- .../virtgpu-forward-buffer-type.cpp | 8 +- .../ggml-virtgpu/virtgpu-forward-buffer.cpp | 12 +-- .../ggml-virtgpu/virtgpu-forward-device.cpp | 4 +- ggml/src/ggml-virtgpu/virtgpu-forward-impl.h | 47 +++++----- ggml/src/ggml-virtgpu/virtgpu-forward.gen.h | 1 + ggml/src/ggml-virtgpu/virtgpu.cpp | 85 ++++++++----------- ggml/src/ggml-virtgpu/virtgpu.h | 8 +- 29 files changed, 395 insertions(+), 256 deletions(-) diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp index cc879e51d04..03a037f1cbd 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp @@ -7,9 +7,21 @@ #include +static uint32_t validate_graph_operation(size_t cgraph_size, uint32_t shmem_res_id, const char * operation) { + if (cgraph_size == 0) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Zero-size computation graph\n", operation); + return 1; + } + + // place-holder: validate that the size of shmem_res_id is <= cgraph_size + // need to add another method in the Virgl->APIR callback interface + GGML_UNUSED(shmem_res_id); + + return 0; // Valid +} + uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { GGML_UNUSED(ctx); - GGML_UNUSED(enc); static bool async_backend_initialized = false; static bool async_backend; @@ -34,10 +46,26 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v size_t cgraph_size; apir_decode_size_t(dec, &cgraph_size); + if (validate_graph_operation(cgraph_size, shmem_res_id, __func__) != 0) { + apir_decoder_set_fatal(dec); + return 1; + } + apir_decoder secondary_dec = apir_new_decoder((const char *) shmem_data, cgraph_size); ggml_cgraph * cgraph = apir_decode_ggml_cgraph(&secondary_dec, cgraph_size); + if (!cgraph || apir_decoder_get_fatal(&secondary_dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Failed to deserialize computation graph\n", __func__); + return 1; + } + + if (cgraph->n_nodes < 0 || cgraph->n_leafs < 0) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid negative node/leaf count: nodes=%d leafs=%d\n", __func__, + cgraph->n_nodes, cgraph->n_leafs); + return 1; + } + ggml_status status; #if APIR_BACKEND_CHECK_SUPPORTS_OP == 1 for (int idx = 0; idx < cgraph->n_nodes; idx++) { @@ -45,7 +73,8 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v if (dev->iface.supports_op(dev, op)) { continue; } - GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", idx, ggml_op_desc(op)); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", __func__, idx, + ggml_op_desc(op)); status = GGML_STATUS_ABORTED; apir_encode_ggml_status(enc, &status); @@ -53,9 +82,17 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v return 0; } #endif + + // Check if backend is properly initialized + if (!bck) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Backend not initialized (bck is null)\n", __func__); + + return 1; + } + status = bck->iface.graph_compute(bck, cgraph); - if (async_backend) { + if (async_backend && bck->iface.synchronize) { bck->iface.synchronize(bck); } diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp index d55eec27610..c66dbaa9e8f 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp @@ -85,7 +85,19 @@ uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * d const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec); - size_t value = buft->iface.get_alloc_size(buft, op); + // Check for decode error + if (op == nullptr) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Failed to decode tensor\n", __func__); + apir_decoder_set_fatal(dec); + return 1; + } + + size_t value; + if (buft->iface.get_alloc_size) { + value = buft->iface.get_alloc_size(buft, op); + } else { + value = ggml_nbytes(op); // Default fallback + } apir_encode_size_t(enc, &value); diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp index 8cc063ff0a6..3ade8d99b4e 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp @@ -6,11 +6,26 @@ #include +static uint32_t validate_buffer_operation(size_t offset, size_t size, const char * operation) { + // Only check for critical integer overflow - no arbitrary size limits + if (offset > SIZE_MAX - size) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Integer overflow in offset+size: %zu + %zu\n", operation, offset, size); + return 1; + } + + return 0; // Valid +} + uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { GGML_UNUSED(ctx); ggml_backend_buffer_t buffer; buffer = apir_decode_ggml_buffer(dec); + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + uintptr_t base = (uintptr_t) buffer->iface.get_base(buffer); apir_encode_uintptr_t(enc, &base); @@ -24,6 +39,11 @@ uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl ggml_backend_buffer_t buffer; buffer = apir_decode_ggml_buffer(dec); + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + ggml_tensor * tensor; // safe to remove the const qualifier here tensor = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec); @@ -37,6 +57,10 @@ uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl size_t size; apir_decode_size_t(dec, &size); + if (validate_buffer_operation(offset, size, __func__) != 0) { + return 1; + } + void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); if (!shmem_data) { @@ -56,6 +80,11 @@ uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl ggml_backend_buffer_t buffer; buffer = apir_decode_ggml_buffer(dec); + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + const ggml_tensor * tensor; // safe to remove the const qualifier here tensor = apir_decode_ggml_tensor(dec); @@ -69,6 +98,10 @@ uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl size_t size; apir_decode_size_t(dec, &size); + if (validate_buffer_operation(offset, size, __func__) != 0) { + return 1; + } + void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); if (!shmem_data) { GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__); @@ -86,6 +119,11 @@ uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl ggml_backend_buffer_t buffer; buffer = apir_decode_ggml_buffer(dec); + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + const ggml_tensor * src; // safe to remove the const qualifier here src = apir_decode_ggml_tensor(dec); @@ -105,6 +143,11 @@ uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir ggml_backend_buffer_t buffer; buffer = apir_decode_ggml_buffer(dec); + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + uint8_t value; apir_decode_uint8_t(dec, &value); @@ -120,6 +163,11 @@ uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virg ggml_backend_buffer_t buffer; buffer = apir_decode_ggml_buffer(dec); + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + if (!apir_untrack_backend_buffer(buffer)) { GGML_LOG_WARN(GGML_VIRTGPU_BCK "%s: unknown buffer %p\n", __func__, (void *) buffer); return 1; diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp index 64152eef0d8..c80e4aabe1f 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp @@ -1,6 +1,6 @@ #include "backend-dispatched.h" -#include "backend-virgl-apir.h" +#include "backend-virgl-apir.h" #include "ggml-backend-impl.h" #include "ggml-backend.h" #include "ggml-impl.h" @@ -28,19 +28,24 @@ uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p) { return APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED; } - if (!reg->iface.get_device_count(reg)) { - GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed: no device found\n", __func__); + size_t device_count = reg->iface.get_device_count(reg); + if (!device_count) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: no device found\n", __func__); return APIR_BACKEND_INITIALIZE_NO_DEVICE; } dev = reg->iface.get_device(reg, 0); if (!dev) { - GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed: no device received\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: failed to get device\n", __func__); return APIR_BACKEND_INITIALIZE_NO_DEVICE; } bck = dev->iface.init_backend(dev, NULL); + if (!bck) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed\n", __func__); + return APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED; + } return APIR_BACKEND_INITIALIZE_SUCCESS; } diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h b/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h index 481d7f3150d..3dc334e4ce4 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h @@ -32,64 +32,6 @@ uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virg /* backend */ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); -static inline const char * backend_dispatch_command_name(ApirBackendCommandType type) { - switch (type) { - /* device */ - case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT: - return "backend_device_get_device_count"; - case APIR_COMMAND_TYPE_DEVICE_GET_COUNT: - return "backend_device_get_count"; - case APIR_COMMAND_TYPE_DEVICE_GET_NAME: - return "backend_device_get_name"; - case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION: - return "backend_device_get_description"; - case APIR_COMMAND_TYPE_DEVICE_GET_TYPE: - return "backend_device_get_type"; - case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY: - return "backend_device_get_memory"; - case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP: - return "backend_device_supports_op"; - case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE: - return "backend_device_get_buffer_type"; - case APIR_COMMAND_TYPE_DEVICE_GET_PROPS: - return "backend_device_get_props"; - case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR: - return "backend_device_buffer_from_ptr"; - /* buffer-type */ - case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME: - return "backend_buffer_type_get_name"; - case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT: - return "backend_buffer_type_get_alignment"; - case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE: - return "backend_buffer_type_get_max_size"; - case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST: - return "backend_buffer_type_is_host (DEPRECATED)"; - case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER: - return "backend_buffer_type_alloc_buffer"; - case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE: - return "backend_buffer_type_get_alloc_size"; - /* buffer */ - case APIR_COMMAND_TYPE_BUFFER_GET_BASE: - return "backend_buffer_get_base"; - case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR: - return "backend_buffer_set_tensor"; - case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR: - return "backend_buffer_get_tensor"; - case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR: - return "backend_buffer_cpy_tensor"; - case APIR_COMMAND_TYPE_BUFFER_CLEAR: - return "backend_buffer_clear"; - case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER: - return "backend_buffer_free_buffer"; - /* backend */ - case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE: - return "backend_backend_graph_compute"; - - default: - return "unknown"; - } -} - extern "C" { static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = { diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.h b/ggml/src/ggml-virtgpu/backend/backend-dispatched.h index 10311631d4f..740ee9e3ffc 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.h @@ -1,5 +1,6 @@ #pragma once +// clang-format off #include #include @@ -10,6 +11,7 @@ #include "shared/apir_backend.h" #include "shared/apir_cs.h" #include "shared/apir_cs_ggml.h" +// clang-format on #define GGML_VIRTGPU_BCK "ggml-virtgpu-backend: " diff --git a/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h b/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h index 44b347f853f..c65a01cdf9b 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +++ b/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h @@ -19,7 +19,7 @@ struct virgl_apir_callbacks { }; extern "C" { -ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs); +ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs); void apir_backend_deinit(uint32_t virgl_ctx_id); uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id, virgl_apir_callbacks * virgl_cbs, diff --git a/ggml/src/ggml-virtgpu/backend/backend.cpp b/ggml/src/ggml-virtgpu/backend/backend.cpp index d93414a078b..535a05f3e69 100644 --- a/ggml/src/ggml-virtgpu/backend/backend.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend.cpp @@ -1,6 +1,5 @@ #include "backend-dispatched.h" #include "backend-virgl-apir.h" - #include "shared/api_remoting.h" #include "shared/apir_backend.h" #include "shared/apir_cs.h" @@ -17,10 +16,10 @@ #define GGML_DEFAULT_BACKEND_REG "ggml_backend_init" static void * backend_library_handle = NULL; -static FILE * apir_logfile = NULL; +static FILE * apir_logfile = NULL; static void log_to_file_callback(enum ggml_log_level level, const char * text, void * user_data) { - FILE * logfile = (FILE *)user_data; + FILE * logfile = (FILE *) user_data; fprintf(logfile, "[%d] %s", level, text); fflush(logfile); } @@ -48,9 +47,9 @@ void apir_backend_deinit(uint32_t virgl_ctx_id) { } #define APIR_GGML_LIBRARY_PATH_KEY "ggml.library.path" -#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg" +#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg" -ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs) { +ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs) { const char * dlsym_error; const char * apir_log_to_file = getenv(APIR_LLAMA_CPP_LOG_TO_FILE_ENV); @@ -63,15 +62,13 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct } } - const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY); + const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY); const char * virgl_library_reg = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_REG_KEY); - const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG; + const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG; if (!library_name) { - GGML_LOG_ERROR(GGML_VIRTGPU_BCK - "%s: cannot open the GGML library: env var '%s' not defined\n", - __func__, APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV); - + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot open the GGML library: env var '%s' not defined\n", __func__, + APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV); return APIR_LOAD_LIBRARY_ENV_VAR_MISSING; } @@ -79,16 +76,14 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct backend_library_handle = dlopen(library_name, RTLD_LAZY); if (!backend_library_handle) { - GGML_LOG_ERROR(GGML_VIRTGPU_BCK - "%s: cannot open the GGML library: %s\n", __func__, dlerror()); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot open the GGML library: %s\n", __func__, dlerror()); return APIR_LOAD_LIBRARY_CANNOT_OPEN; } if (!library_reg) { - GGML_LOG_ERROR(GGML_VIRTGPU_BCK - "%s: cannot register the GGML library: env var '%s' not defined\n", - __func__, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot register the GGML library: env var '%s' not defined\n", __func__, + APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV); return APIR_LOAD_LIBRARY_ENV_VAR_MISSING; } @@ -96,11 +91,9 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct void * ggml_backend_reg_fct = dlsym(backend_library_handle, library_reg); dlsym_error = dlerror(); if (dlsym_error) { - GGML_LOG_ERROR(GGML_VIRTGPU_BCK - "%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n", + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n", __func__, library_reg, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error); - return APIR_LOAD_LIBRARY_SYMBOL_MISSING; } @@ -132,13 +125,12 @@ uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id, virgl_apir_context ctx = { .ctx_id = virgl_ctx_id, - .iface = virgl_cbs, + .iface = virgl_cbs, }; if (cmd_type >= APIR_BACKEND_DISPATCH_TABLE_COUNT) { - GGML_LOG_ERROR(GGML_VIRTGPU_BCK - "%s: Received an invalid dispatch index (%d >= %d)\n", - __func__, cmd_type, APIR_BACKEND_DISPATCH_TABLE_COUNT); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Received an invalid dispatch index (%d >= %d)\n", __func__, cmd_type, + APIR_BACKEND_DISPATCH_TABLE_COUNT); return APIR_BACKEND_FORWARD_INDEX_INVALID; } diff --git a/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h b/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h index f19a5d12d17..6bf97e8a3a2 100644 --- a/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +++ b/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h @@ -16,28 +16,32 @@ enum ApirCommandType { APIR_COMMAND_TYPE_LOADLIBRARY = 1, APIR_COMMAND_TYPE_FORWARD = 2, - APIR_COMMAND_TYPE_LENGTH = 3, + APIR_COMMAND_TYPE_LENGTH = 3, }; typedef uint64_t ApirCommandFlags; enum ApirLoadLibraryReturnCode { APIR_LOAD_LIBRARY_SUCCESS = 0, + // these error codes are returned by the Virglrenderer APIR component APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR = 1, APIR_LOAD_LIBRARY_ALREADY_LOADED = 2, APIR_LOAD_LIBRARY_ENV_VAR_MISSING = 3, APIR_LOAD_LIBRARY_CANNOT_OPEN = 4, APIR_LOAD_LIBRARY_SYMBOL_MISSING = 5, - APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6, // anything above this is a APIR backend library initialization return code + // any value greater than this is an APIR *backend library* initialization return code + APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6, }; enum ApirForwardReturnCode { - APIR_FORWARD_SUCCESS = 0, - APIR_FORWARD_NO_DISPATCH_FCT = 1, - APIR_FORWARD_TIMEOUT = 2, - - APIR_FORWARD_BASE_INDEX = 3, // anything above this is a APIR backend library forward return code -} ; + APIR_FORWARD_SUCCESS = 0, + // these error codes are returned by the Virglrenderer APIR component + APIR_FORWARD_NO_DISPATCH_FCT = 1, + APIR_FORWARD_TIMEOUT = 2, + APIR_FORWARD_FAILED_TO_SYNC_STREAMS = 3, + // any value greater than this index an APIR *backend library* forward return code + APIR_FORWARD_BASE_INDEX = 4, +}; __attribute__((unused)) static inline const char * apir_command_name(ApirCommandType type) { switch (type) { @@ -82,6 +86,7 @@ __attribute__((unused)) static const char * apir_forward_error(ApirForwardReturn APIR_FORWARD_ERROR(APIR_FORWARD_SUCCESS); APIR_FORWARD_ERROR(APIR_FORWARD_NO_DISPATCH_FCT); APIR_FORWARD_ERROR(APIR_FORWARD_TIMEOUT); + APIR_FORWARD_ERROR(APIR_FORWARD_FAILED_TO_SYNC_STREAMS); APIR_FORWARD_ERROR(APIR_FORWARD_BASE_INDEX); return "Unknown APIR_COMMAND_TYPE_FORWARD error"; diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h index d214b6f2a90..520ac9c7299 100644 --- a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h @@ -34,3 +34,61 @@ typedef enum ApirBackendCommandType { // last command_type index + 1 APIR_BACKEND_DISPATCH_TABLE_COUNT = 23, } ApirBackendCommandType; + +static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) { + switch (type) { + /* device */ + case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT: + return "device_get_device_count"; + case APIR_COMMAND_TYPE_DEVICE_GET_COUNT: + return "device_get_count"; + case APIR_COMMAND_TYPE_DEVICE_GET_NAME: + return "device_get_name"; + case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION: + return "device_get_description"; + case APIR_COMMAND_TYPE_DEVICE_GET_TYPE: + return "device_get_type"; + case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY: + return "device_get_memory"; + case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP: + return "device_supports_op"; + case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE: + return "device_get_buffer_type"; + case APIR_COMMAND_TYPE_DEVICE_GET_PROPS: + return "device_get_props"; + case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR: + return "device_buffer_from_ptr"; + /* buffer-type */ + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME: + return "buffer_type_get_name"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT: + return "buffer_type_get_alignment"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE: + return "buffer_type_get_max_size"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST: + return "buffer_type_is_host"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER: + return "buffer_type_alloc_buffer"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE: + return "buffer_type_get_alloc_size"; + /* buffer */ + case APIR_COMMAND_TYPE_BUFFER_GET_BASE: + return "buffer_get_base"; + case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR: + return "buffer_set_tensor"; + case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR: + return "buffer_get_tensor"; + case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR: + return "buffer_cpy_tensor"; + case APIR_COMMAND_TYPE_BUFFER_CLEAR: + return "buffer_clear"; + case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER: + return "buffer_free_buffer"; + /* backend */ + case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE: + return "backend_graph_compute"; + + default: + return "unknown"; + } +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h index f3efa52c721..da1e21b5b2f 100644 --- a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h @@ -14,7 +14,7 @@ #define APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED 6 #define APIR_BACKEND_INITIALIZE_ALREADY_INITED 7 #define APIR_BACKEND_INITIALIZE_NO_DEVICE 8 - +#define APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED 9 // new entries here need to be added to the apir_backend_initialize_error function below @@ -39,6 +39,10 @@ static const char * apir_backend_initialize_error(int code) { APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS); APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS); APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_FAILED); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_ALREADY_INITED); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_NO_DEVICE); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED); return "Unknown APIR_BACKEND_INITIALIZE error:/"; diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h index 1bc3a5f685b..64bf2ec9609 100644 --- a/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h @@ -13,7 +13,6 @@ struct apir_encoder { const char * start; const char * end; bool fatal; - }; struct apir_decoder { @@ -28,8 +27,8 @@ struct apir_decoder { static apir_decoder apir_new_decoder(const char * ptr, size_t size) { apir_decoder dec = { - .cur = ptr, - .end = ptr + size, + .cur = ptr, + .end = ptr + size, .fatal = false, }; @@ -79,10 +78,7 @@ static inline bool apir_decoder_get_fatal(const apir_decoder * dec) { * encode peek */ -static inline bool apir_decoder_peek_internal(apir_decoder * dec, - size_t size, - void * val, - size_t val_size) { +static inline bool apir_decoder_peek_internal(apir_decoder * dec, size_t size, void * val, size_t val_size) { assert(val_size <= size); if (unlikely(size > (size_t) (dec->end - dec->cur))) { @@ -332,8 +328,7 @@ static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t static inline void * apir_decoder_alloc_array(size_t size, size_t count) { size_t alloc_size; if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) { - GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n", - __func__, size, count); + GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n", __func__, size, count); return NULL; } @@ -352,20 +347,19 @@ static inline void apir_decode_bool_t(apir_decoder * dec, bool * val) { /* apir_buffer_type_host_handle_t */ -static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc, +static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc, const apir_buffer_type_host_handle_t * val) { apir_encode(enc, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t)); } -static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec, +static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec, apir_buffer_type_host_handle_t * val) { apir_decode(dec, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t)); } /* apir_buffer_host_handle_t */ -static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc, - const apir_buffer_host_handle_t * val) { +static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc, const apir_buffer_host_handle_t * val) { apir_encode(enc, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t)); } diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h index 289f4b77d74..fabe3e401ca 100644 --- a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h @@ -1,11 +1,10 @@ -#include "ggml-impl.h" #include "apir_cs.h" #include "apir_cs_rpc.h" +#include "ggml-impl.h" // ggml_buffer_to_apir_host_handle(ggml_backend_buffer_t buffer); -static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc, - const apir_buffer_host_handle_t * handle); +static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle); static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec); @@ -22,8 +21,7 @@ static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_inplace(apir_decoder return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size); } -static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec, - uint32_t n_tensors) { +static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec, uint32_t n_tensors) { size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor) * n_tensors; return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size); @@ -45,9 +43,9 @@ static inline const ggml_tensor * apir_decode_ggml_tensor(apir_decoder * dec) { } ggml_init_params params{ - /*.mem_size =*/ ggml_tensor_overhead(), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, + /*.mem_size =*/ggml_tensor_overhead(), + /*.mem_buffer =*/NULL, + /*.no_alloc =*/true, }; ggml_context * ctx = ggml_init(params); @@ -105,6 +103,19 @@ static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec) apir_decoder_read(dec, buffer_ptr_size, &buffer, buffer_ptr_size); + // SECURITY: Validate buffer handle against tracked buffers to prevent + // guest VM from providing arbitrary host memory addresses + if (buffer) { + extern std::unordered_set backend_buffers; + if (backend_buffers.find(buffer) == backend_buffers.end()) { + GGML_LOG_WARN("ggml-virtgpu-backend: %s: Invalid buffer handle from guest: %p\n", __func__, + (void *) buffer); + // Set fatal flag to prevent further processing with invalid handle + apir_decoder_set_fatal(dec); + return NULL; + } + } + return buffer; } diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h index f6817989528..4cb2f047d1e 100644 --- a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h @@ -1,3 +1,6 @@ +#pragma once + +// clang-format off #include "ggml.h" #include "ggml-backend-impl.h" @@ -5,6 +8,7 @@ #include #include #include +// clang-format on // ggml_tensor is serialized into apir_rpc_tensor struct apir_rpc_tensor { diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp index c493a8e2ae3..8fa20ff43bd 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp @@ -34,6 +34,7 @@ static ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml static const char * ggml_backend_remoting_buffer_type_get_name(ggml_backend_buffer_type_t buft) { virtgpu * gpu = BUFT_TO_GPU(buft); + // Return the prefixed name that was built once during initialization return gpu->cached_buffer_type.name; } @@ -53,9 +54,8 @@ static size_t ggml_backend_remoting_buffer_type_get_alloc_size(ggml_backend_buff const ggml_tensor * tensor) { virtgpu * gpu = BUFT_TO_GPU(buft); - if (tensor->buffer == NULL - || !tensor->buffer->context - || !buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) { + if (tensor->buffer == NULL || !tensor->buffer->context || + !buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) { return ggml_nbytes(tensor); } diff --git a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp index c7d2881058b..ec8156bb868 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp @@ -3,6 +3,7 @@ static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) { virtgpu * gpu = DEV_TO_GPU(dev); + // Return the prefixed name that was built once during initialization return gpu->cached_device_info.name; } @@ -22,7 +23,7 @@ static enum ggml_backend_dev_type ggml_backend_remoting_device_get_type(ggml_bac static void ggml_backend_remoting_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { virtgpu * gpu = DEV_TO_GPU(dev); - *free = gpu->cached_device_info.memory_free; + *free = gpu->cached_device_info.memory_free; *total = gpu->cached_device_info.memory_total; } @@ -72,7 +73,7 @@ static void ggml_backend_remoting_device_get_props(ggml_backend_dev_t dev, ggml_ ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev) { virtgpu * gpu = DEV_TO_GPU(dev); - static std::atomic initialized = false; + static std::atomic initialized = false; static ggml_backend_buffer_type buft; if (!initialized) { @@ -95,7 +96,7 @@ ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_bac static ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_from_ptr_type(ggml_backend_dev_t dev) { virtgpu * gpu = DEV_TO_GPU(dev); - static std::atomic initialized = false; + static std::atomic initialized = false; static ggml_backend_buffer_type buft; if (!initialized) { diff --git a/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp b/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp index 2d02cfec1d3..a4df5956aa3 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp @@ -7,8 +7,8 @@ void ggml_virtgpu_cleanup(virtgpu * gpu); static virtgpu * apir_initialize() { - static virtgpu * gpu = NULL; - static std::atomic initialized = false; + static virtgpu * gpu = NULL; + static std::atomic initialized = false; if (initialized) { // fast track @@ -31,29 +31,53 @@ static virtgpu * apir_initialize() { } // Pre-fetch and cache all device information, it will not change - gpu->cached_device_info.description = apir_device_get_description(gpu); + gpu->cached_device_info.description = apir_device_get_description(gpu); if (!gpu->cached_device_info.description) { GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device description", __func__); } - gpu->cached_device_info.name = apir_device_get_name(gpu); - if (!gpu->cached_device_info.name) { - GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device name", __func__); - } gpu->cached_device_info.device_count = apir_device_get_count(gpu); gpu->cached_device_info.type = apir_device_get_type(gpu); - apir_device_get_memory(gpu, - &gpu->cached_device_info.memory_free, - &gpu->cached_device_info.memory_total); + { + // Get the remote name and create prefixed version + char * rmt_device_name = apir_device_get_name(gpu); + if (!rmt_device_name) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to get the virtgpu device name", __func__); + } + + size_t device_name_len = strlen(rmt_device_name) + 11; // "[virtgpu] " + null terminator + gpu->cached_device_info.name = (char *) malloc(device_name_len); + if (!gpu->cached_device_info.name) { + free(rmt_device_name); + GGML_ABORT(GGML_VIRTGPU "%s: failed to allocate memory for prefixed device name", __func__); + } + snprintf(gpu->cached_device_info.name, device_name_len, "[virtgpu] %s", rmt_device_name); + free(rmt_device_name); + } + + apir_device_get_memory(gpu, &gpu->cached_device_info.memory_free, &gpu->cached_device_info.memory_total); apir_buffer_type_host_handle_t buft_host_handle = apir_device_get_buffer_type(gpu); gpu->cached_buffer_type.host_handle = buft_host_handle; - gpu->cached_buffer_type.name = apir_buffer_type_get_name(gpu, buft_host_handle); - if (!gpu->cached_buffer_type.name) { - GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu buffer type name", __func__); + { + // Get the remote name and create prefixed version + char * rmt_name = apir_buffer_type_get_name(gpu, buft_host_handle); + if (!rmt_name) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to get the virtgpu buffer type name", __func__); + } + + size_t prefixed_len = strlen(rmt_name) + 11; // "[virtgpu] " + null terminator + gpu->cached_buffer_type.name = (char *) malloc(prefixed_len); + if (!gpu->cached_buffer_type.name) { + free(rmt_name); + GGML_ABORT(GGML_VIRTGPU "%s: failed to allocate memory for prefixed buffer type name", __func__); + } + snprintf(gpu->cached_buffer_type.name, prefixed_len, "[virtgpu] %s", rmt_name); + free(rmt_name); } - gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle); - gpu->cached_buffer_type.max_size = apir_buffer_type_get_max_size(gpu, buft_host_handle); + + gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle); + gpu->cached_buffer_type.max_size = apir_buffer_type_get_max_size(gpu, buft_host_handle); initialized = true; } @@ -98,7 +122,7 @@ static void ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) { static std::atomic initialized = false; if (initialized) { - return; // fast track + return; // fast track } { diff --git a/ggml/src/ggml-virtgpu/ggml-backend.cpp b/ggml/src/ggml-virtgpu/ggml-backend.cpp index 5cd6c0c0608..a63ee2b9d2f 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend.cpp @@ -1,5 +1,5 @@ -#include "ggml-remoting.h" #include "../../include/ggml-virtgpu.h" +#include "ggml-remoting.h" static const char * ggml_backend_remoting_get_name(ggml_backend_t backend) { UNUSED(backend); diff --git a/ggml/src/ggml-virtgpu/ggml-remoting.h b/ggml/src/ggml-virtgpu/ggml-remoting.h index 08766408676..4f70326bee2 100644 --- a/ggml/src/ggml-virtgpu/ggml-remoting.h +++ b/ggml/src/ggml-virtgpu/ggml-remoting.h @@ -9,7 +9,7 @@ #include #define GGML_VIRTGPU_NAME "ggml-virtgpu" -#define GGML_VIRTGPU "ggml-virtgpu: " +#define GGML_VIRTGPU "ggml-virtgpu: " // USE_ALWAYS_TRUE_SUPPORTS_OP: 1 is fast, 0 avoid micro-benchmark crashes diff --git a/ggml/src/ggml-virtgpu/include/apir_hw.h b/ggml/src/ggml-virtgpu/include/apir_hw.h index 33af045ca2b..7d6ea2265db 100644 --- a/ggml/src/ggml-virtgpu/include/apir_hw.h +++ b/ggml/src/ggml-virtgpu/include/apir_hw.h @@ -3,7 +3,7 @@ #include struct virgl_renderer_capset_apir { - uint32_t apir_version; - uint32_t supports_blob_resources; - uint32_t reserved[4]; // For future expansion + uint32_t apir_version; + uint32_t supports_blob_resources; + uint32_t reserved[4]; // For future expansion }; diff --git a/ggml/src/ggml-virtgpu/regenerate_remoting.py b/ggml/src/ggml-virtgpu/regenerate_remoting.py index aeb48a4087e..dae75fd1c80 100755 --- a/ggml/src/ggml-virtgpu/regenerate_remoting.py +++ b/ggml/src/ggml-virtgpu/regenerate_remoting.py @@ -145,8 +145,31 @@ def generate_apir_backend_header(self) -> str: enum_lines.append(f" APIR_BACKEND_DISPATCH_TABLE_COUNT = {total_count},") enum_lines.append("} ApirBackendCommandType;") + # Generate function name mapping + func_lines = [] + func_lines.append("static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) {") + func_lines.append(" switch (type) {") + + current_group = None + for func in functions: + # Add comment for new group + if func['group_name'] != current_group: + func_lines.append(f" /* {func['group_description']} */") + current_group = func['group_name'] + + # Generate clean function name without backend_ prefix + clean_name = f"{func['group_name']}_{func['function_name']}" + func_lines.append(f" case {func['enum_name']}:") + func_lines.append(f" return \"{clean_name}\";") + + func_lines.append("") + func_lines.append(" default:") + func_lines.append(" return \"unknown\";") + func_lines.append(" }") + func_lines.append("}") + # Full header template - header_content = NL.join(enum_lines) + "\n" + header_content = NL.join(enum_lines) + "\n\n" + NL.join(func_lines) + "\n" return header_content @@ -170,19 +193,6 @@ def generate_backend_dispatched_header(self) -> str: decl_lines.append(f"{signature} {func['backend_function']}({params});") - # Switch cases - switch_lines = [] - current_group = None - - for func in functions: - if func['group_name'] != current_group: - switch_lines.append(f" /* {func['group_description']} */") - current_group = func['group_name'] - - deprecated = " (DEPRECATED)" if func['deprecated'] else "" - - switch_lines.append(f" case {func['enum_name']}: return \"{func['backend_function']}{deprecated}\";") - # Dispatch table table_lines = [] current_group = None @@ -201,15 +211,6 @@ def generate_backend_dispatched_header(self) -> str: {NL.join(decl_lines)} -static inline const char *backend_dispatch_command_name(ApirBackendCommandType type) -{{ - switch (type) {{ -{NL.join(switch_lines)} - - default: return "unknown"; - }} -}} - extern "C" {{ static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {{ {NL.join(table_lines)} diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp index 07d9a668496..4593690c638 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp @@ -17,8 +17,8 @@ ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) { size_t cgraph_size = apir_serialize_ggml_cgraph(cgraph, cgraph_data); virtgpu_shmem temp_shmem; // Local storage for large buffers - virtgpu_shmem * shmem = &temp_shmem; - bool using_shared_shmem = false; + virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; if (cgraph_size <= gpu->data_shmem.mmap_size) { // Lock mutex before using shared data_shmem buffer @@ -26,7 +26,7 @@ ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) { GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); } using_shared_shmem = true; - shmem = &gpu->data_shmem; + shmem = &gpu->data_shmem; } else if (virtgpu_shmem_create(gpu, cgraph_size, shmem)) { GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); } diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp index cab74fd1707..38f8ec945e0 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp @@ -62,7 +62,9 @@ size_t apir_buffer_type_get_max_size(virtgpu * gpu, apir_buffer_type_host_handle return max_size; } -apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, size_t size) { +apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, + apir_buffer_type_host_handle_t host_handle, + size_t size) { apir_encoder * encoder; apir_decoder * decoder; ApirForwardReturnCode ret; @@ -84,7 +86,9 @@ apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, apir_buffer_t return buffer_context; } -size_t apir_buffer_type_get_alloc_size(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, const ggml_tensor * op) { +size_t apir_buffer_type_get_alloc_size(virtgpu * gpu, + apir_buffer_type_host_handle_t host_handle, + const ggml_tensor * op) { apir_encoder * encoder; apir_decoder * decoder; ApirForwardReturnCode ret; diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp index 86eee358cf4..228284f4a42 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp @@ -35,8 +35,8 @@ void apir_buffer_set_tensor(virtgpu * gpu, apir_encode_ggml_tensor(encoder, tensor); virtgpu_shmem temp_shmem; // Local storage for large buffers - virtgpu_shmem * shmem = &temp_shmem; - bool using_shared_shmem = false; + virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; if (size <= gpu->data_shmem.mmap_size) { // Lock mutex before using shared data_shmem buffer @@ -44,7 +44,7 @@ void apir_buffer_set_tensor(virtgpu * gpu, GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); } using_shared_shmem = true; - shmem = &gpu->data_shmem; + shmem = &gpu->data_shmem; } else if (virtgpu_shmem_create(gpu, size, shmem)) { GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); @@ -86,8 +86,8 @@ void apir_buffer_get_tensor(virtgpu * gpu, apir_encode_ggml_tensor(encoder, tensor); virtgpu_shmem temp_shmem; // Local storage for large buffers - virtgpu_shmem * shmem = &temp_shmem; - bool using_shared_shmem = false; + virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; if (size <= gpu->data_shmem.mmap_size) { // Lock mutex before using shared data_shmem buffer @@ -95,7 +95,7 @@ void apir_buffer_get_tensor(virtgpu * gpu, GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); } using_shared_shmem = true; - shmem = &gpu->data_shmem; + shmem = &gpu->data_shmem; } else if (virtgpu_shmem_create(gpu, size, shmem)) { GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp index 4b6b8f527be..9f513c138dd 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp @@ -26,7 +26,7 @@ char * apir_device_get_name(virtgpu * gpu) { REMOTE_CALL(gpu, encoder, decoder, ret); const size_t string_size = apir_decode_array_size_unchecked(decoder); - char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); + char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); if (!string) { GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device name buffer\n", __func__); return NULL; @@ -173,7 +173,7 @@ apir_buffer_context_t apir_device_buffer_from_ptr(virtgpu * gpu, size_t size, si REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR); if (virtgpu_shmem_create(gpu, size, &buffer_context.shmem)) { - GGML_ABORT(GGML_VIRTGPU "Couldn't allocate the guest-host shared buffer"); + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate %ldb of guest-host shared buffer", __func__, size); } apir_encode_virtgpu_shmem_res_id(encoder, buffer_context.shmem.res_id); diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h b/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h index f23c75bb968..4d0b6e05c74 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h @@ -1,29 +1,36 @@ -#include "virtgpu.h" +#pragma once +// clang-format off +#include "virtgpu.h" #include "ggml-remoting.h" #include "backend/shared/apir_backend.h" #include "backend/shared/apir_cs_ggml.h" - #include "ggml-backend-impl.h" +// clang-format on -#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__) \ - do { \ - int32_t forward_flag = (int32_t) apir_command_type__; \ - encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, forward_flag); \ - if (!encoder_name) { \ - GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); \ - } \ +#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__) \ + int32_t REMOTE_CALL_PREPARE_forward_flag = (int32_t) apir_command_type__; \ + const char * REMOTE_CALL_PREPARE_command_name = apir_dispatch_command_name(apir_command_type__); \ + do { \ + encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, REMOTE_CALL_PREPARE_forward_flag); \ + if (!encoder_name) { \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); \ + } \ } while (0) -#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name) \ - do { \ - ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \ - if (!decoder_name) { \ - GGML_ABORT(GGML_VIRTGPU "%s: failed to kick the remote call", __func__); \ - } \ - if (ret_name < APIR_FORWARD_BASE_INDEX) { \ - GGML_ABORT(GGML_VIRTGPU "%s: failed to forward the API call: %s: code %d", __func__, \ - apir_forward_error(ret_name), ret_name); \ - } \ - ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX); \ +#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name) \ + do { \ + ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \ + if (!decoder_name) { \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to kick the remote call", __func__); \ + } \ + if (ret_name < APIR_FORWARD_BASE_INDEX) { \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to forward the API call: %s: code %d", __func__, \ + apir_forward_error(ret_name), ret_name); \ + } \ + ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX); \ + if (ret_name != 0) { \ + GGML_ABORT(GGML_VIRTGPU "backend function '%s' failed (return code: %d)", \ + REMOTE_CALL_PREPARE_command_name, ret_name); \ + } \ } while (0) diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h b/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h index fe4cae20253..44b0ad1ffa1 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +++ b/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h @@ -20,6 +20,7 @@ apir_buffer_context_t apir_device_buffer_from_ptr(struct virtgpu * gpu, char * apir_buffer_type_get_name(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); size_t apir_buffer_type_get_alignment(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); size_t apir_buffer_type_get_max_size(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); +/* apir_buffer_type_is_host is deprecated. */ apir_buffer_context_t apir_buffer_type_alloc_buffer(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, size_t size); diff --git a/ggml/src/ggml-virtgpu/virtgpu.cpp b/ggml/src/ggml-virtgpu/virtgpu.cpp index 1e650dc65b2..a84a77399d9 100644 --- a/ggml/src/ggml-virtgpu/virtgpu.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu.cpp @@ -53,9 +53,9 @@ static int virtgpu_handshake(virtgpu * gpu) { if (!decoder) { GGML_ABORT(GGML_VIRTGPU - "%s: failed to initiate the communication with the virglrenderer library. " - "Most likely, the wrong virglrenderer library was loaded in the hypervisor.", - __func__); + "%s: failed to initiate the communication with the virglrenderer library. " + "Most likely, the wrong virglrenderer library was loaded in the hypervisor.", + __func__); return 1; } @@ -65,8 +65,7 @@ static int virtgpu_handshake(virtgpu * gpu) { uint32_t host_minor; if (ret_magic != APIR_HANDSHAKE_MAGIC) { - GGML_ABORT(GGML_VIRTGPU - "%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic, + GGML_ABORT(GGML_VIRTGPU "%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic, apir_backend_initialize_error(ret_magic)); } else { apir_decode_uint32_t(decoder, &host_major); @@ -140,15 +139,13 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) { "Make sure virglrenderer is correctly configured by the hypervisor. (%s) ", __func__, apir_load_library_error(ret)); } else { - GGML_ABORT(GGML_VIRTGPU - "%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)", __func__, - apir_load_library_error(ret), ret); + GGML_ABORT(GGML_VIRTGPU "%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)", + __func__, apir_load_library_error(ret), ret); } return ret; } - GGML_LOG_INFO(GGML_VIRTGPU - "%s: virglrenderer successfully loaded the API Remoting backend library.\n", __func__); + GGML_LOG_INFO(GGML_VIRTGPU "%s: virglrenderer successfully loaded the API Remoting backend library.\n", __func__); ApirLoadLibraryReturnCode apir_ret = (ApirLoadLibraryReturnCode) (ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX); @@ -158,10 +155,11 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) { "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", __func__, apir_load_library_error(apir_ret)); } else if (apir_ret == APIR_LOAD_LIBRARY_SYMBOL_MISSING) { - GGML_ABORT(GGML_VIRTGPU - "%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. " - "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", - __func__, apir_load_library_error(apir_ret)); + GGML_ABORT( + GGML_VIRTGPU + "%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", + __func__, apir_load_library_error(apir_ret)); } else if (apir_ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) { GGML_ABORT(GGML_VIRTGPU "%s: the API Remoting backend library couldn't load the GGML backend library: apir code=%d | %s)", @@ -169,8 +167,8 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) { } else { uint32_t lib_ret = apir_ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX; GGML_ABORT(GGML_VIRTGPU - "%s: the API Remoting backend library initialize its backend library: apir code=%d)", __func__, - lib_ret); + "%s: the API Remoting backend library failed to initialize its backend library: apir code=%d)", + __func__, lib_ret); } return ret; } @@ -184,55 +182,49 @@ virtgpu * create_virtgpu() { // Initialize mutex to protect shared data_shmem buffer if (mtx_init(&gpu->data_shmem_mutex, mtx_plain) != thrd_success) { delete gpu; - GGML_ABORT(GGML_VIRTGPU - "%s: failed to initialize data_shmem mutex", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize data_shmem mutex", __func__); return NULL; } if (virtgpu_open(gpu) != APIR_SUCCESS) { - GGML_LOG_ERROR(GGML_VIRTGPU - "%s: failed to open the virtgpu device\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to open the virtgpu device\n", __func__); return NULL; } if (virtgpu_init_capset(gpu) != APIR_SUCCESS) { if (gpu->use_apir_capset) { GGML_ABORT(GGML_VIRTGPU - "%s: failed to initialize the virtgpu APIR capset. Make sure that the virglrenderer library supports it.", __func__); + "%s: failed to initialize the virtgpu APIR capset. Make sure that the virglrenderer library " + "supports it.", + __func__); } else { - GGML_ABORT(GGML_VIRTGPU - "%s: failed to initialize the virtgpu Venus capset", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu Venus capset", __func__); } return NULL; } if (virtgpu_init_context(gpu) != APIR_SUCCESS) { - GGML_ABORT(GGML_VIRTGPU - "%s: failed to initialize the GPU context", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the GPU context", __func__); return NULL; } if (virtgpu_shmem_create(gpu, SHMEM_REPLY_SIZE, &gpu->reply_shmem)) { - GGML_ABORT(GGML_VIRTGPU - "%s: failed to create the shared reply memory pages", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to create the shared reply memory pages", __func__); return NULL; } if (virtgpu_shmem_create(gpu, SHMEM_DATA_SIZE, &gpu->data_shmem)) { - GGML_ABORT(GGML_VIRTGPU - "%s: failed to create the shared data memory pages", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to create the shared data memory pages", __func__); return NULL; } if (virtgpu_handshake(gpu)) { - GGML_ABORT(GGML_VIRTGPU - "%s: failed to handshake with the virglrenderer library", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to handshake with the virglrenderer library", __func__); return NULL; } if (virtgpu_load_library(gpu) != APIR_LOAD_LIBRARY_SUCCESS) { - GGML_ABORT(GGML_VIRTGPU - "%s: failed to load the backend library", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to load the backend library", __func__); return NULL; } @@ -243,8 +235,7 @@ static virt_gpu_result_t virtgpu_open(virtgpu * gpu) { drmDevicePtr devs[8]; int count = drmGetDevices2(0, devs, ARRAY_SIZE(devs)); if (count < 0) { - GGML_LOG_ERROR(GGML_VIRTGPU - "%s: failed to enumerate DRM devices\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to enumerate DRM devices\n", __func__); return APIR_ERROR_INITIALIZATION_FAILED; } @@ -266,19 +257,17 @@ static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr d int fd = open(node_path, O_RDWR | O_CLOEXEC); if (fd < 0) { - GGML_ABORT(GGML_VIRTGPU - "%s: failed to open %s", __func__, node_path); + GGML_ABORT(GGML_VIRTGPU "%s: failed to open %s", __func__, node_path); return APIR_ERROR_INITIALIZATION_FAILED; } drmVersionPtr version = drmGetVersion(fd); if (!version || strcmp(version->name, "virtio_gpu") || version->version_major != 0) { if (version) { - GGML_LOG_ERROR(GGML_VIRTGPU - "%s: unknown DRM driver %s version %d\n", __func__, version->name, version->version_major); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: unknown DRM driver %s version %d\n", __func__, version->name, + version->version_major); } else { - GGML_LOG_ERROR(GGML_VIRTGPU - "%s: failed to get DRM driver version\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to get DRM driver version\n", __func__); } if (version) { @@ -322,9 +311,8 @@ static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu) { virtgpu_ioctl_get_caps(gpu, gpu->capset.id, gpu->capset.version, &gpu->capset.data, sizeof(gpu->capset.data)); if (ret) { - GGML_LOG_ERROR(GGML_VIRTGPU - "%s: failed to get APIR v%d capset: %s\n", - __func__, gpu->capset.version, strerror(errno)); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to get APIR v%d capset: %s\n", __func__, gpu->capset.version, + strerror(errno)); return APIR_ERROR_INITIALIZATION_FAILED; } @@ -547,13 +535,10 @@ static void log_call_duration(long long call_duration_ns, const char * name) { double call_duration_s = (double) call_duration_ns / 1e9; // 1 second = 1e9 nanoseconds if (call_duration_s > 1) { - GGML_LOG_INFO(GGML_VIRTGPU - "waited %.2fs for the %s host reply...\n", call_duration_s, name); + GGML_LOG_INFO(GGML_VIRTGPU "waited %.2fs for the %s host reply...\n", call_duration_s, name); } else if (call_duration_ms > 1) { - GGML_LOG_INFO(GGML_VIRTGPU - "waited %.2fms for the %s host reply...\n", call_duration_ms, name); + GGML_LOG_INFO(GGML_VIRTGPU "waited %.2fms for the %s host reply...\n", call_duration_ms, name); } else { - GGML_LOG_INFO(GGML_VIRTGPU - "waited %lldns for the %s host reply...\n", call_duration_ns, name); + GGML_LOG_INFO(GGML_VIRTGPU "waited %lldns for the %s host reply...\n", call_duration_ns, name); } } diff --git a/ggml/src/ggml-virtgpu/virtgpu.h b/ggml/src/ggml-virtgpu/virtgpu.h index 68e0f3a376e..f82d8fb50ba 100644 --- a/ggml/src/ggml-virtgpu/virtgpu.h +++ b/ggml/src/ggml-virtgpu/virtgpu.h @@ -1,5 +1,6 @@ #pragma once +// clang-format off #include "virtgpu-utils.h" #include "virtgpu-shm.h" #include "virtgpu-apir.h" @@ -23,20 +24,21 @@ #include "apir_hw.h" #include #include "venus_hw.h" +// clang-format on #ifndef VIRTGPU_DRM_CAPSET_APIR // Will be defined include/drm/virtgpu_drm.h when // https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590/diffs // is merged -#define VIRTGPU_DRM_CAPSET_APIR 10 +# define VIRTGPU_DRM_CAPSET_APIR 10 #endif // Mesa/Virlgrenderer Venus internal. Only necessary during the // Venus->APIR transition in Virglrenderer #define VENUS_COMMAND_TYPE_LENGTH 331 -#ifndef VIRTGPU_DRM_CAPSET_VENUS // only available with Linux >= v6.16 -#define VIRTGPU_DRM_CAPSET_VENUS 4 +#ifndef VIRTGPU_DRM_CAPSET_VENUS // only available with Linux >= v6.16 +# define VIRTGPU_DRM_CAPSET_VENUS 4 #endif typedef uint32_t virgl_renderer_capset; From e722ee1bf52a064c8de600e7c627e1f9763b7754 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 26 Feb 2026 19:11:04 +0100 Subject: [PATCH 189/831] vulkan: fix fp16 Flash Attention on Windows AMD RDNA2 and below (llama/19921) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 22 +++++++++---------- .../vulkan-shaders/flash_attn.comp | 9 +++++++- .../vulkan-shaders/flash_attn_base.glsl | 7 +++--- 3 files changed, 22 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a1149e606e4..0fae68628b6 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -624,8 +624,6 @@ struct vk_device_struct { // floor(log2(maxComputeWorkGroupInvocations)) uint32_t max_workgroup_size_log2 {}; - bool flash_attention_fp16; - bool coopmat_support; bool coopmat_acc_f32_support {}; bool coopmat_acc_f16_support {}; @@ -2978,11 +2976,15 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ } } -static vk_fa_pipeline_state get_fa_pipeline_state(const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc, +static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc, bool use_mask, bool use_mask_opt, bool use_logit_softcap) { + const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary && + (device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2); + uint32_t flags = (use_mask_opt ? 1 : 0) | (use_mask ? 2 : 0) | - (use_logit_softcap ? 4 : 0); + (use_logit_softcap ? 4 : 0) | + (old_amd_windows ? 8 : 0); const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size; @@ -3384,7 +3386,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } \ } - if (device->flash_attention_fp16) { + if (device->fp16) { CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) @@ -5423,10 +5425,6 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mmvq_mode = 1; } - // Driver issues with older AMD GPUs on Windows, see https://github.com/ggml-org/llama.cpp/pull/19625#issuecomment-3940840613 - const bool is_amd_proprietary_gcn = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture == AMD_GCN && device->driver_id == vk::DriverId::eAmdProprietary; - device->flash_attention_fp16 = device->fp16 && !is_amd_proprietary_gcn; - return device; } @@ -8567,7 +8565,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t Br = params.block_rows; const uint32_t Bc = params.block_cols; - const uint32_t float_type_size = device->flash_attention_fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); @@ -8690,7 +8688,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - const bool f32acc = !ctx->device->flash_attention_fp16 || dst->op_params[3] == GGML_PREC_F32; + const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32; // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). @@ -8745,7 +8743,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768; - vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(tuning_params, HSK, HSV, aligned, f32acc, + vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc, mask != nullptr, use_mask_opt, logit_softcap != 0); vk_pipeline pipeline = nullptr; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 135ab1ad625..ec48f5b1152 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -465,7 +465,14 @@ void main() { if (SubGroupSize > 0) { [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { - Of[r][d] += subgroupShuffleXor(Of[r][d], s); + if (!OLD_AMD_WINDOWS) { + Of[r][d] += subgroupShuffleXor(Of[r][d], s); + } else { + // Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below. + // Shuffle full vec4 as workaround. + // See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697 + Of[r][d] += FLOAT_TYPEV4(subgroupShuffleXor(vec4(Of[r][d]), s)); + } } if (row_split == 1) { barrier(); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index d444542b533..172d38f034e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -14,9 +14,10 @@ layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0; layout (constant_id = 10) const uint32_t Flags = 0; layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0; -const bool USE_MASK_OPT = (Flags & 1) != 0; -const bool MASK_ENABLE = (Flags & 2) != 0; -const bool LOGIT_SOFTCAP = (Flags & 4) != 0; +const bool USE_MASK_OPT = (Flags & 1) != 0; +const bool MASK_ENABLE = (Flags & 2) != 0; +const bool LOGIT_SOFTCAP = (Flags & 4) != 0; +const bool OLD_AMD_WINDOWS = (Flags & 8) != 0; // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths const uint32_t HSK_pad = (HSK + 15) & ~15; From 316d921c1a1fe1148b6c212d2b12a2acf60df2b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Thu, 26 Feb 2026 21:39:11 +0100 Subject: [PATCH 190/831] ggml : fix AMX and add batched support (llama/19925) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit llama-perplexity -hf ggml-org/Qwen3-0.6B-GGUF:Q4_0 -f wikitext-2-raw/wiki.test.raw -c 2048 -b 2048 --chunks 2 before this commit: ``` perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1 perplexity: 2.31 seconds per pass - ETA 0.07 minutes [1]17.3868,[2]22.2199, Final estimate: PPL = 22.2199 +/- 1.59692 llama_perf_context_print: load time = 878.56 ms llama_perf_context_print: prompt eval time = 2037.82 ms / 4096 tokens ( 0.50 ms per token, 2009.99 tokens per second) llama_perf_context_print: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second) llama_perf_context_print: total time = 6403.17 ms / 4097 tokens llama_perf_context_print: graphs reused = 0 llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted | llama_memory_breakdown_print: | - Host | 845 = 318 + 224 + 302 | llama_memory_breakdown_print: | - CPU_REPACK | 288 = 288 + 0 + 0 | llama_memory_breakdown_print: | - AMX | 31 = 31 + 0 + 0 | ``` after this commit: ``` perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1 perplexity: 1.98 seconds per pass - ETA 0.05 minutes [1]17.2005,[2]21.8220, Final estimate: PPL = 21.8220 +/- 1.56485 llama_perf_context_print: load time = 719.23 ms llama_perf_context_print: prompt eval time = 1676.23 ms / 4096 tokens ( 0.41 ms per token, 2443.58 tokens per second) llama_perf_context_print: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second) llama_perf_context_print: total time = 4258.74 ms / 4097 tokens llama_perf_context_print: graphs reused = 0 llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted | llama_memory_breakdown_print: | - Host | 845 = 318 + 224 + 302 | llama_memory_breakdown_print: | - AMX | 319 = 319 + 0 + 0 | ``` (no more CPU_REPACK) after this commit, disabling amx: ``` perplexity: calculating perplexity over 2 chunks, n_ctx=2048, batch_size=2048, n_seq=1 perplexity: 2.34 seconds per pass - ETA 0.07 minutes [1]17.2005,[2]21.8220, Final estimate: PPL = 21.8220 +/- 1.56485 llama_perf_context_print: load time = 841.91 ms llama_perf_context_print: prompt eval time = 2057.28 ms / 4096 tokens ( 0.50 ms per token, 1990.98 tokens per second) llama_perf_context_print: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second) llama_perf_context_print: total time = 6454.51 ms / 4097 tokens llama_perf_context_print: graphs reused = 0 llama_memory_breakdown_print: | memory breakdown [MiB] | total free self model context compute unaccounted | llama_memory_breakdown_print: | - Host | 845 = 318 + 224 + 302 | llama_memory_breakdown_print: | - CPU_REPACK | 319 = 319 + 0 + 0 | ``` => same perplexity. Signed-off-by: Adrien Gallouët --- ggml/src/ggml-cpu/amx/amx.cpp | 61 +++++++++---- ggml/src/ggml-cpu/amx/mmq.cpp | 164 +++++++++++++++++----------------- 2 files changed, 124 insertions(+), 101 deletions(-) diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 895a5713753..9baf3e025e6 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -141,27 +141,50 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ namespace ggml::cpu::amx { class extra_buffer_type : ggml::cpu::extra_buffer_type { bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { - // handle only 2d gemm for now - auto is_contiguous_2d = [](const struct ggml_tensor * t) { - return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1; - }; - - if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous - is_contiguous_2d(op->src[1]) && // src1 must be contiguous - op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() && - op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315) - op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x - (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) { - // src1 must be host buffer - if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + if (op->op != GGML_OP_MUL_MAT) { + return false; + } + auto * src0 = op->src[0]; + auto * src1 = op->src[1]; + + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + return false; + } + if (!src0->buffer || src0->buffer->buft != ggml_backend_amx_buffer_type()) { + return false; + } + if (src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft)) { + return false; + } + if (op->ne[0] % (TILE_N * 2)) { + return false; + } + int alignment; + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + alignment = TILE_K; + break; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_XS: + alignment = 256; // QK_K + break; + case GGML_TYPE_F16: + alignment = 16; + break; + default: return false; - } - // src1 must be float32 - if (op->src[1]->type == GGML_TYPE_F32) { - return true; - } } - return false; + if (src0->ne[0] % alignment) { + return false; + } + if (src1->type != GGML_TYPE_F32) { + return false; + } + return true; } ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index 47c61b88164..b5aca76633c 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -1,4 +1,3 @@ - #if defined(__GNUC__) #pragma GCC diagnostic ignored "-Wpedantic" #pragma GCC diagnostic ignored "-Wunused-local-typedefs" @@ -202,35 +201,27 @@ struct tile_config_t{ // advanced-matrix-extensions-intrinsics-functions.html // -#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb -void ggml_tile_config_init(void) { - static thread_local bool is_first_time = true; +inline void ggml_tile_config_init(void) { + static thread_local bool done = false; - if (!is_first_time) { + if (done) { return; } - static thread_local tile_config_t tc; - tile_config_t current_tc; - _tile_storeconfig(¤t_tc); - - // load only when config changes - if (tc.palette_id == 0 || (memcmp(¤t_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 && - memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) { - tc.palette_id = 1; - tc.start_row = 0; - TC_CONFIG_TILE(TMM0, 8, 64); - TC_CONFIG_TILE(TMM1, 8, 64); - TC_CONFIG_TILE(TMM2, 16, 32); - TC_CONFIG_TILE(TMM3, 16, 32); - TC_CONFIG_TILE(TMM4, 16, 64); - TC_CONFIG_TILE(TMM5, 16, 64); - TC_CONFIG_TILE(TMM6, 16, 64); - TC_CONFIG_TILE(TMM7, 16, 64); - _tile_loadconfig(&tc); - } - - is_first_time = false; + alignas(64) tile_config_t tc = {}; + tc.palette_id = 1; + tc.start_row = 0; + tc.rows[0] = 8; tc.colsb[0] = 64; + tc.rows[1] = 8; tc.colsb[1] = 64; + tc.rows[2] = 16; tc.colsb[2] = 32; + tc.rows[3] = 16; tc.colsb[3] = 32; + tc.rows[4] = 16; tc.colsb[4] = 64; + tc.rows[5] = 16; tc.colsb[5] = 64; + tc.rows[6] = 16; tc.colsb[6] = 64; + tc.rows[7] = 16; tc.colsb[7] = 64; + + _tile_loadconfig(&tc); + done = true; } // we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation. @@ -268,33 +259,6 @@ int get_row_size(int K) { return row_size; } -// vectorized dtype conversion -inline float FP16_TO_FP32(ggml_half val) { - __m256i v = _mm256_setr_epi16( - val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); - __m512 o = _mm512_cvtph_ps(v); - return _mm512_cvtss_f32(o); -} - -inline __m512 FP16_TO_FP32_VEC(ggml_half val) { - __m256i v = _mm256_set1_epi16(val); - return _mm512_cvtph_ps(v); -} - -// horizontal reduce -inline float _mm512_reduce_max_ps(const __m512 x) { - __m512 v = x; - __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E); - v = _mm512_max_ps(v, v1); - v1 = _mm512_shuffle_f32x4(v, v, 0xB1); - v = _mm512_max_ps(v, v1); - v1 = _mm512_shuffle_ps(v, v, 0x4E); - v = _mm512_max_ps(v, v1); - v1 = _mm512_shuffle_ps(v, v, 0xB1); - v = _mm512_max_ps(v, v1); - return _mm512_cvtss_f32(v); -} - // transpose utils #define SHUFFLE_EPI32(a, b, mask) \ _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask)) @@ -1370,9 +1334,9 @@ struct tinygemm_kernel_avx #define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \ tinygemm_kernel_avx::apply( \ - K, (const float *)src1->data + mb_start * K, \ - (const type *)src0->data + nb_start * K, \ - (float *)dst->data + mb_start * ldc + nb_start, ldc); + K, (const float *)src1->data + src1_offset + mb_start * K, \ + (const type *)src0->data + src0_offset + nb_start * K, \ + (float *)dst->data + dst_offset + mb_start * ldc + nb_start, ldc) // re-organize in the format {NB, KB, TILE_SIZE}: @@ -2019,11 +1983,11 @@ struct tinygemm_kernel_vnni::apply( \ - KB, (const char *)wdata + 0 * row_size_A, \ - (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \ - (float *) dst->data + 0 * N + nb_start, ldc) +#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \ + tinygemm_kernel_vnni::apply( \ + KB, wdata_batch, \ + (const char *)src0->data + src0_offset + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \ + (float *) dst->data + dst_offset + nb_start, ldc) template ::value, int>::type = 0> @@ -2079,7 +2043,7 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t)); if (need_unpack) { - unpack_B(Tile1, B_blk0); + unpack_B(Tile1, B_blk1); _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); } else { _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); @@ -2336,6 +2300,13 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d }); } +// ne2 is passed explicitly to help compiler optimize repeated calls +inline int64_t ggml_batch_offset(const ggml_tensor * t, int64_t batch_idx, int64_t ne2) { + const int64_t i2 = batch_idx % ne2; + const int64_t i3 = batch_idx / ne2; + return i3 * t->nb[3] + i2 * t->nb[2]; +} + size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { struct ggml_tensor * src0 = dst->src[0]; @@ -2348,12 +2319,13 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { const int M = dst->ne[1]; const int K = src0->ne[0]; + const int64_t n_batch = dst->ne[2] * dst->ne[3]; size_t desired_wsize = 0; GGML_DISPATCH_QTYPES(TYPE, [&] { const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); - desired_wsize = M * row_size_A; + desired_wsize = n_batch * M * row_size_A; }); return desired_wsize; @@ -2365,7 +2337,7 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { // src1: input in shape of {M, K}, float32 // dst: output in shape of {M, N}, float32 // -// the function performs: dst = src1 @ src0.T +// the function performs: dst = src1 @ src0.T for each batch // void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) { struct ggml_tensor * src0 = dst->src[0]; @@ -2382,17 +2354,26 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te const int K = src0->ne[0]; const int ldc = dst->nb[1] / dst->nb[0]; + const int64_t ne2 = dst->ne[2]; + const int64_t n_batch = ne2 * dst->ne[3]; + if (is_floating_type) { constexpr int BLOCK_M = 4; constexpr int BLOCK_N = 6; const int MB = div_up(M, BLOCK_M); const int NB = div_up(N, BLOCK_N); - parallel_for_ggml(params, MB * NB, [&](int begin, int end) { + parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) { GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] { for (int i = begin; i < end; ++i) { - int mb = i / NB; - int nb = i % NB; + int batch_idx = i / (MB * NB); + int remaining = i % (MB * NB); + int mb = remaining / NB; + int nb = remaining % NB; + + int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2); + int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2); + int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2); int mb_start = mb * BLOCK_M; int mb_size = std::min(BLOCK_M, M - mb_start); @@ -2424,10 +2405,10 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te void * wdata = params->wdata; //TODO: performance improvement: merge quant A - if (params->ith == 0) { + // if (params->ith == 0) { GGML_DISPATCH_QTYPES(TYPE, [&] { const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); - const size_t desired_wsize = M * row_size_A; + const size_t desired_wsize = n_batch * M * row_size_A; if (params->wsize < desired_wsize) { GGML_ABORT("insufficient work space size"); } @@ -2436,12 +2417,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size); - const float * A_data = static_cast(src1->data); - for (int m = 0; m < M; ++m) { - from_float(A_data + m * K, (char *)wdata + m * row_size_A, K); - } + parallel_for_ggml(params, n_batch, [&](int begin, int end) { + for (int batch_idx = begin; batch_idx < end; ++batch_idx) { + int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2); + const float * A_data = (const float *)((const char *)src1->data + src1_offset); + char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A; + + for (int m = 0; m < M; ++m) { + from_float(A_data + m * K, wdata_batch + m * row_size_A, K); + } + } + }); }); - } + // } ggml_barrier(params->threadpool); @@ -2451,13 +2439,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te constexpr int BLOCK_N = TILE_N * kTilesN; const int NB = div_up(N, BLOCK_N); - parallel_for_ggml(params, NB, [&](int begin, int end) { + parallel_for_ggml(params, n_batch * NB, [&](int begin, int end) { GGML_DISPATCH_QTYPES(TYPE, [&] { const int KB = K / blck_size; const int TILE_SIZE = get_tile_size(); const int row_size_A = KB * sizeof(vec_dot_type); for (int i = begin; i < end; ++i) { - int nb = i; + int batch_idx = i / NB; + int nb = i % NB; + + int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2); + int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2); + const char * wdata_batch = (const char *)wdata + batch_idx * row_size_A; + int nb_start = nb * BLOCK_N; int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96 @@ -2481,7 +2475,7 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te const int MB = div_up(M, BLOCK_M); const int NB = div_up(N, BLOCK_N); - parallel_for_ggml(params, MB * NB, [&](int begin, int end) { + parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) { // init tile config for each thread ggml_tile_config_init(); @@ -2491,8 +2485,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te const int row_size_A = KB * sizeof(vec_dot_type); for (int i = begin; i < end; ++i) { - int mb = i / NB; - int nb = i % NB; + int batch_idx = i / (MB * NB); + int remaining = i % (MB * NB); + int mb = remaining / NB; + int nb = remaining % NB; + + int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2); + int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2); + const char * wdata_batch = (const char *)wdata + batch_idx * M * row_size_A; int mb_start = mb * BLOCK_M; int mb_size = std::min(BLOCK_M, M - mb_start); @@ -2501,9 +2501,9 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te tinygemm_kernel_amx( mb_size, nb_size, KB, - (const char *)wdata + mb_start * row_size_A, - (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE), - (float *) dst->data + mb_start * N + nb_start, ldc); + wdata_batch + mb_start * row_size_A, + (const char *)src0->data + src0_offset + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE), + (float *) dst->data + dst_offset + mb_start * N + nb_start, ldc); } }); }); From 9c1fd5cc6e75973b3d22ae661373972a528d05d4 Mon Sep 17 00:00:00 2001 From: Vishal Singh Date: Fri, 27 Feb 2026 06:13:41 +0530 Subject: [PATCH 191/831] ggml-zendnn: update code for latest ZenDNN API (llama/19923) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - adapt ggml-zendnn.cpp to the new lowoha::matmul interface - update the ZenDNN git tag in CMake to the latest release (ZenDNN‑2026‑WW08) - add static lib support in CMake --- ggml/src/ggml-zendnn/CMakeLists.txt | 63 ++++++++++++++-------------- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 6 +-- 2 files changed, 34 insertions(+), 35 deletions(-) diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt index f5cf6eedd3a..9bdb4e836d3 100644 --- a/ggml/src/ggml-zendnn/CMakeLists.txt +++ b/ggml/src/ggml-zendnn/CMakeLists.txt @@ -1,12 +1,19 @@ ggml_add_backend_library(ggml-zendnn ggml-zendnn.cpp) -# Get ZenDNN path if (NOT DEFINED ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "") set(ZENDNN_ROOT "$ENV{ZENDNN_ROOT}") endif() -# Check if path is still empty or OFF +if (BUILD_SHARED_LIBS) + set(ZENDNN_SHARED_LIB ON) + set(ZENDNN_ARCHIVE_LIB OFF) +else() + set(ZENDNN_SHARED_LIB OFF) + set(ZENDNN_ARCHIVE_LIB ON) +endif() + +# Download and build ZenDNN if not provided if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") message(STATUS "ZENDNN_ROOT not set. Automatically downloading and building ZenDNN...") message(STATUS "This will take several minutes on first build...") @@ -21,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") ExternalProject_Add( zendnn GIT_REPOSITORY https://github.com/amd/ZenDNN.git - GIT_TAG 21ce8f7879c86bf3637f707fae6f29e0951db5fe + GIT_TAG a18adf8c605fb5f5e52cefd7eda08a7b18febbaf # ZenDNN-2026-WW08 PREFIX ${ZENDNN_PREFIX} SOURCE_DIR ${ZENDNN_SOURCE_DIR} BINARY_DIR ${ZENDNN_BUILD_DIR} @@ -32,7 +39,9 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") -DZENDNNL_BUILD_DOXYGEN=OFF -DZENDNNL_BUILD_GTEST=OFF -DZENDNNL_BUILD_BENCHDNN=OFF - # Enable ALL matmul algorithm backends + -DZENDNNL_DEPENDS_FBGEMM=OFF + -DZENDNNL_LIB_BUILD_ARCHIVE=${ZENDNN_ARCHIVE_LIB} + -DZENDNNL_LIB_BUILD_SHARED=${ZENDNN_SHARED_LIB} -DZENDNNL_DEPENDS_AOCLDLP=ON -DZENDNNL_DEPENDS_ONEDNN=ON -DZENDNNL_DEPENDS_LIBXSMM=ON @@ -45,47 +54,37 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") LOG_INSTALL ON ) - # Add dependency so ZenDNN builds before our library add_dependencies(ggml-zendnn zendnn) - - # Set ZENDNN_ROOT to the installation directory set(ZENDNN_ROOT ${ZENDNN_INSTALL_DIR}) - message(STATUS "ZenDNN will be built to: ${ZENDNN_ROOT}") else() message(STATUS "Using custom ZenDNN installation at: ${ZENDNN_ROOT}") endif() -# ZenDNN headers + libs target_include_directories(ggml-zendnn PRIVATE ${ZENDNN_ROOT}/zendnnl/include - ${ZENDNN_ROOT}/deps/aocldlp/include - ${ZENDNN_ROOT}/deps/aoclutils/include ${ZENDNN_ROOT}/deps/json/include - ${ZENDNN_ROOT}/deps/libxsmm/include + ${ZENDNN_ROOT}/deps/aoclutils/include + ${ZENDNN_ROOT}/deps/aocldlp/include ${ZENDNN_ROOT}/deps/onednn/include -) + ${ZENDNN_ROOT}/deps/libxsmm/include) -target_link_directories(ggml-zendnn PRIVATE - ${ZENDNN_ROOT}/zendnnl/lib - ${ZENDNN_ROOT}/deps/aocldlp/lib - ${ZENDNN_ROOT}/deps/aoclutils/lib - ${ZENDNN_ROOT}/deps/libxsmm/lib - ${ZENDNN_ROOT}/deps/onednn/lib -) +if (ZENDNN_SHARED_LIB) + target_link_directories(ggml-zendnn PRIVATE ${ZENDNN_ROOT}/zendnnl/lib) + target_link_libraries(ggml-zendnn PRIVATE zendnnl) +elseif (ZENDNN_ARCHIVE_LIB) + target_link_libraries(ggml-zendnn PRIVATE + ${ZENDNN_ROOT}/zendnnl/lib/libzendnnl_archive.a + ${ZENDNN_ROOT}/deps/aoclutils/${CMAKE_INSTALL_LIBDIR}/libaoclutils.a + ${ZENDNN_ROOT}/deps/aoclutils/${CMAKE_INSTALL_LIBDIR}/libau_cpuid.a + ${ZENDNN_ROOT}/deps/aocldlp/lib/libaocl-dlp.a + ${ZENDNN_ROOT}/deps/onednn/${CMAKE_INSTALL_LIBDIR}/libdnnl.a + ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmm.a + ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmmext.a + ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmmnoblas.a) +endif() -target_link_libraries(ggml-zendnn PRIVATE - zendnnl_archive # ZenDNN main - aocl-dlp # AOCL libraries - aoclutils - au_cpuid - dnnl # OneDNN - xsmm # libxsmm small matrix math - xsmmext - xsmmnoblas - m - pthread -) +target_link_libraries(ggml-zendnn PRIVATE m pthread) if (GGML_OPENMP) target_link_libraries(ggml-zendnn PRIVATE OpenMP::OpenMP_CXX) diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 551c15bb4ae..c8760304008 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -41,13 +41,13 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int const TA * A, int64_t lda, const TB * B, int64_t ldb, TC * C, int64_t ldc) { - zendnnl::lowoha::lowoha_params params; + zendnnl::lowoha::matmul::matmul_params params; params.dtypes.src = ggml_to_zendnn_type(); params.dtypes.wei = ggml_to_zendnn_type(); params.dtypes.dst = ggml_to_zendnn_type(); params.num_threads = ctx->n_threads; - zendnnl::lowoha::status_t status = zendnnl::lowoha::matmul_direct( + zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct( 'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major) n, // M: rows of B and C m, // N: cols of A^T and C @@ -63,7 +63,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int params // params ); - if (status != zendnnl::lowoha::status_t::success) { + if (status != zendnnl::error_handling::status_t::success) { GGML_LOG_ERROR("%s, ZenDNN matmul failed: status=%d\n", __func__, static_cast(status)); return false; } From 64f48603e6aad90e0a41ef4b240dfe5febcd9b3d Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Fri, 27 Feb 2026 09:26:07 +0800 Subject: [PATCH 192/831] replace the magic nunber 768 by max work group size to support iGPU (llama/19920) Co-authored-by: Neo Zhang Jianyu --- ggml/src/ggml-sycl/add-id.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-sycl/add-id.cpp b/ggml/src/ggml-sycl/add-id.cpp index 00c073cf937..8929017a999 100644 --- a/ggml/src/ggml-sycl/add-id.cpp +++ b/ggml/src/ggml-sycl/add-id.cpp @@ -55,7 +55,11 @@ void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { const int32_t* src2_d = (const int32_t*)src2->data; float* dst_d = (float*)dst->data; - int threads = std::min((int)ne00, 768); // cols + const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + + int threads = std::min((unsigned int)ne00, max_work_group_size); // cols + ctx.stream()->parallel_for( sycl::nd_range<3>( sycl::range<3>(1, ne02, ne01) * sycl::range<3>(1, 1, threads), From 473405606721ce248077ca99b5ee0db95b8bdd6f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 27 Feb 2026 12:19:27 +0200 Subject: [PATCH 193/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 8db0963de78..769d0fb6684 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -68fee723b1f0c2432258b77710f3ca973b3bc5cc +4773cde162a55f0d10a6a6d7c2ea4378e30e0b01 From 84f8db71d86e810d30df8c83e38f116f4e89ab16 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 27 Feb 2026 12:23:40 +0200 Subject: [PATCH 194/831] talk-llama : sync llama.cpp --- examples/talk-llama/CMakeLists.txt | 2 +- examples/talk-llama/llama-adapter.h | 3 + examples/talk-llama/llama-arch.cpp | 46 ++ examples/talk-llama/llama-arch.h | 5 + examples/talk-llama/llama-context.cpp | 187 +------ examples/talk-llama/llama-context.h | 7 +- examples/talk-llama/llama-graph.cpp | 207 +++++--- examples/talk-llama/llama-graph.h | 11 +- examples/talk-llama/llama-impl.cpp | 4 +- examples/talk-llama/llama-kv-cache.cpp | 3 + .../talk-llama/llama-memory-recurrent.cpp | 2 +- examples/talk-llama/llama-model-saver.cpp | 1 + examples/talk-llama/llama-model.cpp | 216 ++++++-- examples/talk-llama/llama-model.h | 25 +- examples/talk-llama/llama-quant.cpp | 262 ++++++---- examples/talk-llama/llama-vocab.cpp | 40 +- examples/talk-llama/llama-vocab.h | 3 + examples/talk-llama/llama.h | 1 + examples/talk-llama/models/deepseek2.cpp | 4 +- examples/talk-llama/models/delta-net-base.cpp | 376 ++++++++++++++ examples/talk-llama/models/eurobert.cpp | 97 ++++ examples/talk-llama/models/falcon-h1.cpp | 4 +- examples/talk-llama/models/glm4.cpp | 17 +- examples/talk-llama/models/granite-hybrid.cpp | 2 +- examples/talk-llama/models/jais2.cpp | 123 +++++ examples/talk-llama/models/jamba.cpp | 2 +- examples/talk-llama/models/kimi-linear.cpp | 421 +--------------- examples/talk-llama/models/lfm2.cpp | 263 +++++----- ...graph-context-mamba.cpp => mamba-base.cpp} | 8 +- examples/talk-llama/models/mamba.cpp | 3 +- examples/talk-llama/models/models.h | 165 +++--- examples/talk-llama/models/modern-bert.cpp | 7 - examples/talk-llama/models/nemotron-h.cpp | 10 +- examples/talk-llama/models/paddleocr.cpp | 122 +++++ examples/talk-llama/models/plamo2.cpp | 4 +- examples/talk-llama/models/qwen35.cpp | 464 ++--------------- examples/talk-llama/models/qwen35moe.cpp | 468 +++--------------- examples/talk-llama/models/qwen3next.cpp | 336 +------------ examples/talk-llama/models/rwkv6-base.cpp | 2 + examples/talk-llama/models/rwkv7-base.cpp | 2 + examples/talk-llama/unicode.cpp | 6 + 41 files changed, 1719 insertions(+), 2212 deletions(-) create mode 100644 examples/talk-llama/models/delta-net-base.cpp create mode 100644 examples/talk-llama/models/eurobert.cpp create mode 100644 examples/talk-llama/models/jais2.cpp rename examples/talk-llama/models/{graph-context-mamba.cpp => mamba-base.cpp} (97%) create mode 100644 examples/talk-llama/models/paddleocr.cpp diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt index 549842a2474..1adeef8f511 100644 --- a/examples/talk-llama/CMakeLists.txt +++ b/examples/talk-llama/CMakeLists.txt @@ -34,7 +34,7 @@ if (WHISPER_SDL2) unicode.cpp unicode-data.cpp ${SRC_MODELS}) - target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS}) + target_include_directories(${TARGET} PRIVATE . ${SDL2_INCLUDE_DIRS}) target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/talk-llama/llama-adapter.h b/examples/talk-llama/llama-adapter.h index d275d25425e..aa3ab63ad75 100644 --- a/examples/talk-llama/llama-adapter.h +++ b/examples/talk-llama/llama-adapter.h @@ -39,6 +39,8 @@ struct llama_adapter_cvec { std::vector tensors; // per layer }; +using llama_adapter_cvec_ptr = std::shared_ptr; + // // llama_adapter_lora // @@ -84,3 +86,4 @@ struct llama_adapter_lora { }; using llama_adapter_loras = std::unordered_map; +using llama_adapter_loras_ptr = std::unique_ptr; diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 416c17463ee..47e8d5278ac 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -26,6 +26,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_NEO_BERT, "neo-bert" }, { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, { LLM_ARCH_JINA_BERT_V3, "jina-bert-v3" }, + { LLM_ARCH_EUROBERT, "eurobert" }, { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, @@ -79,6 +80,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_JAIS2, "jais2" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" }, @@ -120,6 +122,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_PADDLEOCR, "paddleocr" }, { LLM_ARCH_MIMO2, "mimo2" }, { LLM_ARCH_STEP35, "step35" }, { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, @@ -346,6 +349,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_GATE_UP_EXPS, "blk.%d.ffn_gate_up_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, @@ -367,6 +371,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_TOKEN_TYPES, "token_types" }, { LLM_TENSOR_CLS, "cls" }, { LLM_TENSOR_CLS_OUT, "cls.output" }, + { LLM_TENSOR_CLS_NORM, "cls.norm" }, { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, { LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" }, @@ -737,6 +742,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { case LLM_ARCH_INTERNLM2: case LLM_ARCH_GRANITE: case LLM_ARCH_ERNIE4_5: + case LLM_ARCH_PADDLEOCR: case LLM_ARCH_SMOLLM3: case LLM_ARCH_DREAM: case LLM_ARCH_LLADA: @@ -815,6 +821,20 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, }; + case LLM_ARCH_EUROBERT: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + }; case LLM_ARCH_MODERN_BERT: return { LLM_TENSOR_TOKEN_EMBD, @@ -828,6 +848,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CLS_NORM, }; case LLM_ARCH_JINA_BERT_V2: return { @@ -984,6 +1005,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_DOWN_SHEXP, @@ -1041,6 +1063,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_DOWN_SHEXP, @@ -1581,6 +1604,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_DOWN_SHEXP, @@ -1633,6 +1657,12 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_DOWN, LLM_TENSOR_ATTN_POST_NORM, LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; case LLM_ARCH_GLM4_MOE: return { @@ -1783,6 +1813,20 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, }; + case LLM_ARCH_JAIS2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + }; case LLM_ARCH_NEMOTRON_H: return { LLM_TENSOR_TOKEN_EMBD, @@ -2512,6 +2556,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, @@ -2644,6 +2689,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_GATE_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_DOWN_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_GATE_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 521944370b4..6d1b1df31c0 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -30,6 +30,7 @@ enum llm_arch { LLM_ARCH_NEO_BERT, LLM_ARCH_JINA_BERT_V2, LLM_ARCH_JINA_BERT_V3, + LLM_ARCH_EUROBERT, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, @@ -83,6 +84,7 @@ enum llm_arch { LLM_ARCH_T5, LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, + LLM_ARCH_JAIS2, LLM_ARCH_NEMOTRON, LLM_ARCH_NEMOTRON_H, LLM_ARCH_NEMOTRON_H_MOE, @@ -124,6 +126,7 @@ enum llm_arch { LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, + LLM_ARCH_PADDLEOCR, LLM_ARCH_MIMO2, LLM_ARCH_STEP35, LLM_ARCH_LLAMA_EMBED, @@ -370,6 +373,7 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN_EXPS, // merged experts LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, @@ -497,6 +501,7 @@ enum llm_tensor { LLM_TENSOR_ENC_OUTPUT_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CLS_NORM, LLM_TENSOR_CONV1D, LLM_TENSOR_CONVNEXT_DW, LLM_TENSOR_CONVNEXT_NORM, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 99035b6cace..98d055d34ef 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -22,6 +22,8 @@ llama_context::llama_context( const llama_model & model, llama_context_params params) : model(model), + cvec(std::make_unique()), + loras(std::make_unique()), balloc(std::make_unique(model.hparams.n_pos_per_embd())) { // TODO warning when creating llama_context with awkward ctx size that is not a power of 2, // may need to be backend-dependent @@ -710,8 +712,6 @@ int64_t llama_context::output_resolve_row(int32_t i) const { } float * llama_context::get_logits_ith(int32_t i) { - int64_t j = -1; - output_reorder(); try { @@ -719,26 +719,7 @@ float * llama_context::get_logits_ith(int32_t i) { throw std::runtime_error("no logits"); } - // TODO: use output_resolve_row() - if (i < 0) { - j = n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); - } - } else if ((size_t) i >= output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); - } else { - j = output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); - } - + const int64_t j = output_resolve_row(i); return logits.data + j*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); @@ -761,8 +742,6 @@ llama_token * llama_context::get_sampled_tokens() const{ } float * llama_context::get_embeddings_ith(int32_t i) { - int64_t j = -1; - output_reorder(); try { @@ -770,26 +749,7 @@ float * llama_context::get_embeddings_ith(int32_t i) { throw std::runtime_error("no embeddings"); } - // TODO: use output_resolve_row() - if (i < 0) { - j = n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); - } - } else if ((size_t) i >= output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); - } else { - j = output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); - } - + const int64_t j = output_resolve_row(i); const uint32_t n_embd_out = model.hparams.n_embd_out(); return embd.data + j*n_embd_out; } catch (const std::exception & err) { @@ -1065,11 +1025,11 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a return; } - loras.clear(); + loras.reset(new llama_adapter_loras()); for (size_t i = 0; i < n_adapters; i ++) { if (scales[i] != 0.0f) { - loras[adapters[i]] = scales[i]; + loras->insert({adapters[i], scales[i]}); } } @@ -1079,14 +1039,14 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); - if (n_adapters != loras.size()) { + if (n_adapters != loras->size()) { return false; } for (size_t i = 0; i < n_adapters; i ++) { - auto it = loras.find(adapters[i]); + auto it = loras->find(adapters[i]); - if (it == loras.end() || it->second != scales[i]) { + if (it == loras->end() || it->second != scales[i]) { return false; } } @@ -1104,7 +1064,7 @@ bool llama_context::set_adapter_cvec( // TODO: should we reserve? - return cvec.apply(model, data, len, n_embd, il_start, il_end); + return cvec->apply(model, data, len, n_embd, il_start, il_end); } llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { @@ -2081,8 +2041,8 @@ llm_graph_params llama_context::graph_params( /*.gtype =*/ gtype, /*.sched =*/ sched.get(), /*.backend_cpu =*/ backend_cpu, - /*.cvec =*/ &cvec, - /*.loras =*/ &loras, + /*.cvec =*/ cvec.get(), + /*.loras =*/ loras.get(), /*.mctx =*/ mctx, /*.cross =*/ &cross, /*.samplers =*/ sampling.samplers, @@ -2480,64 +2440,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { // TODO: add more model-specific info which should prevent loading the session file if not identical } - // write output ids - { - LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); - - const auto n_outputs = this->n_outputs; - const auto & output_ids = this->output_ids; - - std::vector w_output_pos; - - w_output_pos.resize(n_outputs); - - // build a more compact representation of the output ids - for (size_t i = 0; i < n_batch(); ++i) { - // map an output id to a position in the batch - int64_t pos = output_ids[i]; - if (pos >= 0) { - GGML_ASSERT(pos < n_outputs); - w_output_pos[pos] = i; - } - } - - io.write(&n_outputs, sizeof(n_outputs)); - - if (n_outputs) { - io.write(w_output_pos.data(), n_outputs * sizeof(int32_t)); - } - } - - // [TAG_CONTEXT_STATE_LOGITS] - // write logits - { - LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); - - const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens()); - - io.write(&logits_size, sizeof(logits_size)); - - if (logits_size) { - io.write(logits.data, logits_size * sizeof(float)); - } - } - - // write embeddings - { - LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__); - - const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd); - - io.write(&embd_size, sizeof(embd_size)); - - if (embd_size) { - io.write(embd.data, embd_size * sizeof(float)); - } - } - - // TODO: handle sampling buffers and samplers state ? - // https://github.com/ggml-org/llama.cpp/pull/17004 - if (memory != nullptr) { LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); memory->state_write(io); @@ -2563,70 +2465,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { // TODO: add more info which needs to be identical but which is not verified otherwise } - // read output ids - { - LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__); - - auto n_outputs = this->n_outputs; - io.read_to(&n_outputs, sizeof(n_outputs)); - - if (n_outputs > output_reserve(n_outputs)) { - throw std::runtime_error("could not reserve outputs"); - } - - std::vector output_pos; - - if (n_outputs) { - output_pos.resize(n_outputs); - io.read_to(output_pos.data(), n_outputs * sizeof(int32_t)); - - for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { - int32_t id = output_pos[i]; - if ((uint32_t) id >= n_batch()) { - throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch())); - } - this->output_ids[id] = i; - } - - this->n_outputs = n_outputs; - } - } - - // read logits - { - LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__); - - uint64_t logits_size; - io.read_to(&logits_size, sizeof(logits_size)); - - if (this->logits.size < logits_size) { - throw std::runtime_error("logits buffer too small"); - } - - if (logits_size) { - io.read_to(this->logits.data, logits_size * sizeof(float)); - } - } - - // read embeddings - { - LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__); - - uint64_t embd_size; - io.read_to(&embd_size, sizeof(embd_size)); - - if (this->embd.size < embd_size) { - throw std::runtime_error("embeddings buffer too small"); - } - - if (embd_size) { - io.read_to(this->embd.data, embd_size * sizeof(float)); - } - } - - // TODO: handle sampling buffers and samplers state ? - // https://github.com/ggml-org/llama.cpp/pull/17004 - if (memory) { LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); @@ -2759,6 +2597,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params llama_set_param(model->cls_b, param_filter, param_filter_ud); llama_set_param(model->cls_out, param_filter, param_filter_ud); llama_set_param(model->cls_out_b, param_filter, param_filter_ud); + llama_set_param(model->cls_norm, param_filter, param_filter_ud); for (struct llama_layer & layer : model->layers) { for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index a8e53f335cc..e0d0085c1c3 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -256,9 +256,10 @@ struct llama_context { const llama_model & model; - llama_cparams cparams; - llama_adapter_cvec cvec; - llama_adapter_loras loras; + llama_cparams cparams; + + llama_adapter_cvec_ptr cvec; + llama_adapter_loras_ptr loras; llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index bba747d37b5..23a86ea2905 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -17,6 +17,41 @@ #include #include +// dedup helpers + +static ggml_tensor * build_kq_mask( + ggml_context * ctx, + const llama_kv_cache_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); +} + +static bool can_reuse_kq_mask( + ggml_tensor * kq_mask, + const llama_kv_cache_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + bool res = true; + + res &= (kq_mask->ne[0] == n_kv); + res &= (kq_mask->ne[1] == n_tokens/n_stream); + res &= (kq_mask->ne[2] == 1); + res &= (kq_mask->ne[3] == n_stream); + + return res; +} + +// impl + void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { const int64_t n_tokens = ubatch->n_tokens; @@ -150,7 +185,10 @@ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) { } void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + if (cparams.embeddings && + (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) { + const int64_t n_tokens = ubatch->n_tokens; const int64_t n_seq_tokens = ubatch->n_seq_tokens; const int64_t n_seqs_unq = ubatch->n_seqs_unq; @@ -403,8 +441,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= self_kq_mask->ne[0] == mctx->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams); return res; } @@ -424,8 +461,7 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; - res &= self_kq_mask->ne[0] == mctx->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams); return res; } @@ -455,11 +491,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; - - res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv(); - res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); return res; } @@ -521,8 +554,7 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams); res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -565,8 +597,7 @@ bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; - res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams); res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -625,8 +656,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); } // swa tensors may not be allocated if there are no SWA attention layers @@ -634,8 +664,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv(); - res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); } res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -1099,8 +1128,8 @@ ggml_tensor * llm_graph_context::build_ffn( if (down) { cur = build_lora_mm(down, cur); - if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { - // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { + // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -1136,7 +1165,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in) const { + ggml_tensor * probs_in, + ggml_tensor * gate_up_exps) const { return build_moe_ffn( cur, gate_inp, /* gate_inp_b */ nullptr, @@ -1152,7 +1182,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( w_scale, gating_op, il, - probs_in + probs_in, + gate_up_exps ); } @@ -1175,7 +1206,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in) const { + ggml_tensor * probs_in, + ggml_tensor * gate_up_exps, + ggml_tensor * gate_up_exps_b) const { const int64_t n_embd = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN @@ -1314,27 +1347,49 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_weighted", il); } - ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(up, "ffn_moe_up", il); + ggml_tensor * up = nullptr; + ggml_tensor * experts = nullptr; - if (up_exps_b) { - up = ggml_add_id(ctx0, up, up_exps_b, selected_experts); - cb(up, "ffn_moe_up_biased", il); - } + if (gate_up_exps) { + // merged gate_up path: one mul_mat_id, then split into gate and up views + ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens] + cb(gate_up, "ffn_moe_gate_up", il); - ggml_tensor * experts = nullptr; - if (gate_exps) { - cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + if (gate_up_exps_b) { + gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts); + cb(gate_up, "ffn_moe_gate_up_biased", il); + } + + const int64_t n_ff = gate_up->ne[0] / 2; + cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0); cb(cur, "ffn_moe_gate", il); + up = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]); + cb(up, "ffn_moe_up", il); } else { - cur = up; - } + // separate gate and up path + up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); + + if (up_exps_b) { + up = ggml_add_id(ctx0, up, up_exps_b, selected_experts); + cb(up, "ffn_moe_up_biased", il); + } - if (gate_exps_b) { - cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts); - cb(cur, "ffn_moe_gate_biased", il); + if (gate_exps) { + cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(cur, "ffn_moe_gate", il); + } else { + cur = up; + } + + if (gate_exps_b) { + cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts); + cb(cur, "ffn_moe_gate_biased", il); + } } + const bool has_gate = gate_exps || gate_up_exps; + switch (type_op) { case LLM_FFN_SILU: if (gate_exps) { @@ -1356,7 +1411,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( break; } } + } + if (has_gate) { cur = ggml_swiglu_split(ctx0, cur, up); cb(cur, "ffn_moe_swiglu", il); } else { @@ -1364,7 +1421,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_silu", il); } break; case LLM_FFN_GELU: - if (gate_exps) { + if (has_gate) { cur = ggml_geglu_split(ctx0, cur, up); cb(cur, "ffn_moe_geglu", il); } else { @@ -1380,7 +1437,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_swiglu_oai", il); } break; case LLM_FFN_RELU: - if (gate_exps) { + if (has_gate) { cur = ggml_reglu_split(ctx0, cur, up); cb(cur, "ffn_moe_reglu", il); } else { @@ -1388,7 +1445,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_relu", il); } break; case LLM_FFN_RELU_SQR: - if (gate_exps) { + if (has_gate) { // TODO: add support for gated squared relu GGML_ABORT("fatal error: gated squared relu not implemented"); } else { @@ -1695,7 +1752,8 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * cur; - if (cparams.flash_attn && kq_b == nullptr) { + const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr; + if (use_flash_attn) { GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet"); if (v_trans) { @@ -1891,14 +1949,11 @@ static std::unique_ptr build_attn_inp_kv_impl( { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); - const auto n_tokens = ubatch.n_tokens; - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); + ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1958,8 +2013,8 @@ ggml_tensor * llm_graph_context::build_attn( if (wo) { cur = build_lora_mm(wo, cur); - if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { - // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { + // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -1983,13 +2038,9 @@ static std::unique_ptr build_attn_inp_k_impl( { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); - const auto n_tokens = ubatch.n_tokens; - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -2188,15 +2239,11 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const auto inp = std::make_unique(hparams, cparams, mctx_cur); - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - { - const auto n_kv = mctx_cur->get_base()->get_n_kv(); - inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); ggml_set_input(inp->self_kq_mask); ggml_set_name(inp->self_kq_mask, "self_kq_mask"); @@ -2207,12 +2254,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const { GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA"); - const auto n_kv = mctx_cur->get_swa()->get_n_kv(); - inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); ggml_set_input(inp->self_kq_mask_swa); ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); @@ -2374,27 +2419,21 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() auto inp_attn = std::make_unique(hparams, cparams, attn_ctx); - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - { - const auto n_kv = attn_ctx->get_base()->get_n_kv(); - inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); ggml_set_input(inp_attn->self_kq_mask); inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; } { - const auto n_kv = attn_ctx->get_swa()->get_n_kv(); - inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); ggml_set_input(inp_attn->self_kq_mask_swa); inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; @@ -2407,8 +2446,9 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() void llm_graph_context::build_dense_out( ggml_tensor * dense_2, + ggml_tensor * dense_2_b, ggml_tensor * dense_3) const { - if (!cparams.embeddings || !(dense_2 || dense_3)) { + if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) { return; } ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd; @@ -2417,6 +2457,9 @@ void llm_graph_context::build_dense_out( if (dense_2) { cur = ggml_mul_mat(ctx0, dense_2, cur); } + if (dense_2_b) { + cur = ggml_add(ctx0, cur, dense_2_b); + } if (dense_3) { cur = ggml_mul_mat(ctx0, dense_3, cur); } @@ -2430,7 +2473,8 @@ void llm_graph_context::build_pooling( ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, - ggml_tensor * cls_out_b) const { + ggml_tensor * cls_out_b, + ggml_tensor * cls_norm) const { if (!cparams.embeddings) { return; } @@ -2469,8 +2513,15 @@ void llm_graph_context::build_pooling( } break; case LLAMA_POOLING_TYPE_RANK: { - ggml_tensor * inp_cls = build_inp_cls(); - cur = ggml_get_rows(ctx0, inp, inp_cls); + if (arch == LLM_ARCH_MODERN_BERT) { + // modern bert gte reranker builds mean first then applies prediction head and classifier + // https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411 + ggml_tensor * inp_mean = build_inp_mean(); + cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean); + } else { + ggml_tensor * inp_cls = build_inp_cls(); + cur = ggml_get_rows(ctx0, inp, inp_cls); + } // classification head // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 @@ -2479,7 +2530,15 @@ void llm_graph_context::build_pooling( if (cls_b) { cur = ggml_add(ctx0, cur, cls_b); } - cur = ggml_tanh(ctx0, cur); + if (arch == LLM_ARCH_MODERN_BERT) { + cur = ggml_gelu(ctx0, cur); + } else { + cur = ggml_tanh(ctx0, cur); + } + if (cls_norm) { + // head norm + cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1); + } } // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 1d69ff1a6fc..e8f006977d2 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -814,7 +814,8 @@ struct llm_graph_context { float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in = nullptr) const; + ggml_tensor * probs_in = nullptr, + ggml_tensor * gate_up_exps = nullptr) const; ggml_tensor * build_moe_ffn( ggml_tensor * cur, @@ -835,7 +836,9 @@ struct llm_graph_context { float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in = nullptr) const; + ggml_tensor * probs_in = nullptr, + ggml_tensor * gate_up_exps = nullptr, + ggml_tensor * gate_up_exps_b = nullptr) const; // // inputs @@ -1000,7 +1003,8 @@ struct llm_graph_context { ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, - ggml_tensor * cls_out_b) const; + ggml_tensor * cls_out_b, + ggml_tensor * cls_norm) const; // // sampling (backend sampling) @@ -1014,6 +1018,7 @@ struct llm_graph_context { void build_dense_out( ggml_tensor * dense_2, + ggml_tensor * dense_2_b, ggml_tensor * dense_3) const; }; diff --git a/examples/talk-llama/llama-impl.cpp b/examples/talk-llama/llama-impl.cpp index 8e3e7b223a6..710a5a1e08d 100644 --- a/examples/talk-llama/llama-impl.cpp +++ b/examples/talk-llama/llama-impl.cpp @@ -109,9 +109,9 @@ std::string llama_format_tensor_shape(const std::vector & ne) { std::string llama_format_tensor_shape(const struct ggml_tensor * t) { char buf[256]; - snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]); + snprintf(buf, sizeof(buf), "%6" PRId64, t->ne[0]); for (int i = 1; i < GGML_MAX_DIMS; i++) { - snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]); + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %6" PRId64, t->ne[i]); } return buf; } diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index cb702b2a59f..6b668ee9abd 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -978,6 +978,9 @@ bool llama_kv_cache::get_can_shift() const { if (model.arch == LLM_ARCH_STEP35) { return false; } + if (hparams.n_pos_per_embd() > 1) { + return false; + } return true; } diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index f0038036dcb..6e8413f493d 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -163,7 +163,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos const auto & cell = cells[tail_id]; // partial intersection is invalid if it includes the final pos if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n"); + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); return false; } // invalidate tails which will be cleared diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index 36e353074e0..676efeda709 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -271,6 +271,7 @@ void llama_model_saver::add_tensors_from_model() { add_tensor(model.cls_b); add_tensor(model.cls_out); add_tensor(model.cls_out_b); + add_tensor(model.cls_norm); for (const struct llama_layer & layer : model.layers) { for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index c26584aa67f..dabf3b3086e 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -123,6 +123,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_8B_A1B: return "8B.A1B"; case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; + case LLM_TYPE_24B_A2B: return "24B.A2B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; case LLM_TYPE_35B_A3B: return "35B.A3B"; @@ -908,7 +909,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); + hparams.set_swa_pattern(swa_period, true); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } @@ -978,6 +979,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { type = LLM_TYPE_250M; } } break; + case LLM_ARCH_EUROBERT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + if (hparams.n_layer == 12) { + type = LLM_TYPE_SMALL; // 0.2B + } + } break; case LLM_ARCH_BLOOM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1703,8 +1714,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_DEEPSEEK2: { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B + const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); @@ -1784,7 +1795,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // NextN/MTP parameters (GLM-OCR) + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + switch (hparams.n_layer) { + case 17: type = LLM_TYPE_1B; break; // GLM-OCR case 40: type = LLM_TYPE_9B; break; case 61: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; @@ -1929,6 +1948,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_JAIS2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + case 68: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_NEMOTRON: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -2226,7 +2255,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: + case LLM_ARCH_PADDLEOCR: { + // paddleocr need mrope_section + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); if (arch == LLM_ARCH_ERNIE4_5_MOE) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); @@ -2340,6 +2373,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 10752: type = LLM_TYPE_2_6B; break; default: type = LLM_TYPE_UNKNOWN; } + if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); is_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.swa_layers[il] = !hparams.recurrent_layer_arr[il]; + } + } } break; case LLM_ARCH_LFM2MOE: { @@ -2353,7 +2392,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; } - type = LLM_TYPE_8B_A1B; + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_8B_A1B; break; + case 40: type = LLM_TYPE_24B_A2B; break; + default: type = LLM_TYPE_UNKNOWN; + } } break; case LLM_ARCH_SMALLTHINKER: { @@ -2937,6 +2980,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // TODO: move to a separate function const auto tn = LLM_TN(arch); + + // helper: try merged gate_up_exps first, fall back to separate gate and up + auto create_tensor_gate_up_exps = [&](llama_layer & layer, int bid, int64_t n_embd_, int64_t n_ff_, int64_t n_expert_, int flags) { + layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", bid), {n_embd_, n_ff_ * 2, n_expert_}, TENSOR_NOT_REQUIRED); + if (layer.ffn_gate_up_exps == nullptr) { + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); + } + }; switch (arch) { case LLM_ARCH_LLAMA: case LLM_ARCH_REFACT: @@ -3505,9 +3557,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); } - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_norm = create_tensor(tn(LLM_TENSOR_CLS_NORM, "weight"), {n_embd}, TENSOR_NOT_REQUIRED); } break; case LLM_ARCH_NEO_BERT: @@ -3536,6 +3589,29 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); } } break; + case LLM_ARCH_EUROBERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } + } break; case LLM_ARCH_JINA_BERT_V2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings @@ -5154,9 +5230,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); // Shared expert branch layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); @@ -5360,6 +5435,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); } } break; + case LLM_ARCH_JAIS2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // attention biases - all have shape n_embd (output dimension of projections) + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + // Jais-2 uses simple MLP (no gate) with biases + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + } + } break; case LLM_ARCH_CHATGLM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5410,30 +5524,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + auto & layer = layers[i]; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, flags); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); } - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, flags); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } } } break; case LLM_ARCH_GLM4_MOE: @@ -6549,6 +6681,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: + case LLM_ARCH_PADDLEOCR: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6869,7 +7002,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // for LFM2-ColBert-350M - dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers_b = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "bias"), {hparams.n_embd_out() }, TENSOR_NOT_REQUIRED); } break; case LLM_ARCH_SMALLTHINKER: { @@ -7299,9 +7433,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); // Shared experts layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); @@ -7365,9 +7498,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); // Shared experts const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; @@ -8088,6 +8220,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: + case LLM_ARCH_EUROBERT: case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_MODERN_BERT: case LLM_ARCH_GEMMA_EMBEDDING: @@ -8285,6 +8418,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_EUROBERT: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_BLOOM: { llm = std::make_unique(*this, params); @@ -8527,6 +8664,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_JAIS2: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_NEMOTRON: { llm = std::make_unique(*this, params); @@ -8622,6 +8763,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_PADDLEOCR: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_HUNYUAN_MOE: { llm = std::make_unique(*this, params); @@ -8645,7 +8790,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: { - llm = std::make_unique(*this, params); + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + llm = std::make_unique>(*this, params); + } else { + llm = std::make_unique>(*this, params); + } } break; case LLM_ARCH_SMALLTHINKER: { @@ -8708,7 +8857,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } // add on pooling layer - llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + llm->build_pooling(cls, cls_b, cls_out, cls_out_b, cls_norm); // add backend sampling layers (if any) llm->build_sampling(); @@ -8717,7 +8866,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // there will be two additional dense projection layers // dense linear projections are applied after pooling // TODO: move reranking logic here and generalize - llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + llm->build_dense_out(dense_2_out_layers, dense_2_out_layers_b, dense_3_out_layers); llm->res->set_outputs(); @@ -8899,6 +9048,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MODERN_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_EUROBERT: case LLM_ARCH_STABLELM: case LLM_ARCH_BITNET: case LLM_ARCH_QWEN: @@ -8935,6 +9085,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_BAILINGMOE2: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_JAIS2: case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: @@ -8953,6 +9104,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: + case LLM_ARCH_PADDLEOCR: return LLAMA_ROPE_TYPE_MROPE; case LLM_ARCH_QWEN3VL: case LLM_ARCH_QWEN3VLMOE: diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index b3505914293..d7c3e7d1c1a 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -116,6 +116,7 @@ enum llm_type { LLM_TYPE_8B_A1B, // lfm2moe LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small + LLM_TYPE_24B_A2B, // lfm2moe LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, LLM_TYPE_35B_A3B, // Qwen3.5 @@ -279,14 +280,16 @@ struct llama_layer { struct ggml_tensor * ffn_up_enc = nullptr; // ff MoE - struct ggml_tensor * ffn_gate_inp = nullptr; - struct ggml_tensor * ffn_gate_exps = nullptr; - struct ggml_tensor * ffn_down_exps = nullptr; - struct ggml_tensor * ffn_up_exps = nullptr; - struct ggml_tensor * ffn_gate_inp_b = nullptr; - struct ggml_tensor * ffn_gate_exps_b = nullptr; - struct ggml_tensor * ffn_down_exps_b = nullptr; - struct ggml_tensor * ffn_up_exps_b = nullptr; + struct ggml_tensor * ffn_gate_inp = nullptr; + struct ggml_tensor * ffn_gate_exps = nullptr; + struct ggml_tensor * ffn_down_exps = nullptr; + struct ggml_tensor * ffn_up_exps = nullptr; + struct ggml_tensor * ffn_gate_up_exps = nullptr; + struct ggml_tensor * ffn_gate_inp_b = nullptr; + struct ggml_tensor * ffn_gate_exps_b = nullptr; + struct ggml_tensor * ffn_down_exps_b = nullptr; + struct ggml_tensor * ffn_up_exps_b = nullptr; + struct ggml_tensor * ffn_gate_up_exps_b = nullptr; // ff shared expert (shexp) struct ggml_tensor * ffn_gate_inp_shexp = nullptr; @@ -475,6 +478,7 @@ struct llama_model { struct ggml_tensor * cls_b = nullptr; struct ggml_tensor * cls_out = nullptr; struct ggml_tensor * cls_out_b = nullptr; + struct ggml_tensor * cls_norm = nullptr; struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; @@ -491,8 +495,9 @@ struct llama_model { //Dense linear projections for SentenceTransformers models like embeddinggemma // For Sentence Transformers models structure see // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models - struct ggml_tensor * dense_2_out_layers = nullptr; - struct ggml_tensor * dense_3_out_layers = nullptr; + struct ggml_tensor * dense_2_out_layers = nullptr; + struct ggml_tensor * dense_2_out_layers_b = nullptr; + struct ggml_tensor * dense_3_out_layers = nullptr; // gguf metadata std::unordered_map gguf_kv; diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index a7891647c3d..24770430e1c 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -479,6 +479,17 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * return new_size; } +static bool tensor_type_requires_imatrix(const ggml_tensor * t, const ggml_type dst_type, const llama_ftype ftype) { + return ( + dst_type == GGML_TYPE_IQ2_XXS || dst_type == GGML_TYPE_IQ2_XS || + dst_type == GGML_TYPE_IQ3_XXS || dst_type == GGML_TYPE_IQ1_S || + dst_type == GGML_TYPE_IQ2_S || dst_type == GGML_TYPE_IQ1_M || + ( // Q2_K_S is the worst k-quant type - only allow it without imatrix for token embeddings + dst_type == GGML_TYPE_Q2_K && ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(t->name, "token_embd.weight") != 0 + ) + ); +} + static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { ggml_type default_type; llama_ftype ftype = params->ftype; @@ -735,24 +746,36 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: }; const auto tn = LLM_TN(model.arch); - new_ofstream(0); + + // no output file for --dry-run + if (!params->dry_run) { + new_ofstream(0); + } + + // flag for `--dry-run`, to let the user know if imatrix will be required for a real + // quantization, as a courtesy + bool will_require_imatrix = false; + for (const auto * it : tensors) { const auto & weight = *it; ggml_tensor * tensor = weight.tensor; - if (weight.idx != cur_split && params->keep_split) { + if (!params->dry_run && (weight.idx != cur_split && params->keep_split)) { close_ofstream(); new_ofstream(weight.idx); } const std::string name = ggml_get_name(tensor); + const size_t tensor_size = ggml_nbytes(tensor); - if (!ml.use_mmap) { - if (read_data.size() < ggml_nbytes(tensor)) { - read_data.resize(ggml_nbytes(tensor)); + if (!params->dry_run) { + if (!ml.use_mmap) { + if (read_data.size() < tensor_size) { + read_data.resize(tensor_size); + } + tensor->data = read_data.data(); } - tensor->data = read_data.data(); + ml.load_data_for(tensor); } - ml.load_data_for(tensor); LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", ++idx, ml.n_tensors, @@ -900,129 +923,155 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: quantize = tensor->type != new_type; } - if (!quantize) { - new_type = tensor->type; - new_data = tensor->data; - new_size = ggml_nbytes(tensor); - LLAMA_LOG_INFO("size = %8.3f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0); + // we have now decided on the target type for this tensor + if (params->dry_run) { + // the --dry-run option calculates the final quantization size without quantizting + if (quantize) { + new_size = ggml_nrows(tensor) * ggml_row_size(new_type, tensor->ne[0]); + LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB (%s)\n", + tensor_size/1024.0/1024.0, + new_size/1024.0/1024.0, + ggml_type_name(new_type)); + if (!will_require_imatrix && tensor_type_requires_imatrix(tensor, new_type, params->ftype)) { + will_require_imatrix = true; + } + } else { + new_size = tensor_size; + LLAMA_LOG_INFO("size = %8.3f MiB\n", new_size/1024.0/1024.0); + } + total_size_org += tensor_size; + total_size_new += new_size; + continue; } else { - const int64_t nelements = ggml_nelements(tensor); + // no --dry-run, perform quantization + if (!quantize) { + new_type = tensor->type; + new_data = tensor->data; + new_size = tensor_size; + LLAMA_LOG_INFO("size = %8.3f MiB\n", tensor_size/1024.0/1024.0); + } else { + const int64_t nelements = ggml_nelements(tensor); - const float * imatrix = nullptr; - if (imatrix_data) { - auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped)); - if (it == imatrix_data->end()) { - LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); - } else { - if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) { - imatrix = it->second.data(); + const float * imatrix = nullptr; + if (imatrix_data) { + auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped)); + if (it == imatrix_data->end()) { + LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); } else { - LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__, - int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name); - - // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix - // this is a significant error and it may be good idea to abort the process if this happens, - // since many people will miss the error and not realize that most of the model is being quantized without an imatrix - // tok_embd should be ignored in this case, since it always causes this warning - if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { - throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s", - int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name)); + if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) { + imatrix = it->second.data(); + } else { + LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__, + int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name); + + // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix + // this is a significant error and it may be good idea to abort the process if this happens, + // since many people will miss the error and not realize that most of the model is being quantized without an imatrix + // tok_embd should be ignored in this case, since it always causes this warning + if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { + throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s", + int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name)); + } } } } - } - if ((new_type == GGML_TYPE_IQ2_XXS || - new_type == GGML_TYPE_IQ2_XS || - new_type == GGML_TYPE_IQ2_S || - new_type == GGML_TYPE_IQ1_S || - (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) || - (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) { - LLAMA_LOG_ERROR("\n\n============================================================\n"); - LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); - LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); - LLAMA_LOG_ERROR("============================================================\n\n"); - throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name)); - } + if (!imatrix && tensor_type_requires_imatrix(tensor, new_type, params->ftype)) { + LLAMA_LOG_ERROR("\n\n============================================================\n"); + LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); + LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); + LLAMA_LOG_ERROR("============================================================\n\n"); + throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name)); + } - float * f32_data; + float * f32_data; - if (tensor->type == GGML_TYPE_F32) { - f32_data = (float *) tensor->data; - } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { - throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); - } else { - llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread); - f32_data = (float *) f32_conv_buf.data(); - } + if (tensor->type == GGML_TYPE_F32) { + f32_data = (float *) tensor->data; + } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { + throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); + } else { + llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread); + f32_data = (float *) f32_conv_buf.data(); + } - LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); - fflush(stdout); + LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); + fflush(stdout); - if (work.size() < (size_t)nelements * 4) { - work.resize(nelements * 4); // upper bound on size - } - new_data = work.data(); + if (work.size() < (size_t)nelements * 4) { + work.resize(nelements * 4); // upper bound on size + } + new_data = work.data(); - const int64_t n_per_row = tensor->ne[0]; - const int64_t nrows = tensor->ne[1]; + const int64_t n_per_row = tensor->ne[0]; + const int64_t nrows = tensor->ne[1]; - static const int64_t min_chunk_size = 32 * 512; - const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)); + static const int64_t min_chunk_size = 32 * 512; + const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)); - const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; - const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; - const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1; + const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; + const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; + const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1; - // quantize each expert separately since they have different importance matrices - new_size = 0; - for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) { - const float * f32_data_03 = f32_data + i03 * nelements_matrix; - void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows; - const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr; + // quantize each expert separately since they have different importance matrices + new_size = 0; + for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) { + const float * f32_data_03 = f32_data + i03 * nelements_matrix; + void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows; + const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr; - new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); + new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); - // TODO: temporary sanity check that the F16 -> MXFP4 is lossless + // TODO: temporary sanity check that the F16 -> MXFP4 is lossless #if 0 - if (new_type == GGML_TYPE_MXFP4) { - auto * x = f32_data_03; - - //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row); - std::vector deq(nrows*n_per_row); - const ggml_type_traits * qtype = ggml_get_type_traits(new_type); - qtype->to_float(new_data_03, deq.data(), deq.size()); - - double err = 0.0f; - for (int i = 0; i < (int) deq.size(); ++i) { - err += fabsf(deq[i] - x[i]); - //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) { - if (deq[i] != x[i]) { - LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]); + if (new_type == GGML_TYPE_MXFP4) { + auto * x = f32_data_03; + + //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row); + std::vector deq(nrows*n_per_row); + const ggml_type_traits * qtype = ggml_get_type_traits(new_type); + qtype->to_float(new_data_03, deq.data(), deq.size()); + + double err = 0.0f; + for (int i = 0; i < (int) deq.size(); ++i) { + err += fabsf(deq[i] - x[i]); + //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) { + if (deq[i] != x[i]) { + LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]); + } } + //LLAMA_LOG_INFO("err = %f\n", err); + GGML_ASSERT(err == 0.00000); } - //LLAMA_LOG_INFO("err = %f\n", err); - GGML_ASSERT(err == 0.00000); - } #endif + } + LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", tensor_size/1024.0/1024.0, new_size/1024.0/1024.0); } - LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); - } - total_size_org += ggml_nbytes(tensor); - total_size_new += new_size; + total_size_org += tensor_size; + total_size_new += new_size; + + // update the gguf meta data as we go + gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); + GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); + gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data); + + // write tensor data + padding + fout.write((const char *) new_data, new_size); + zeros(fout, GGML_PAD(new_size, align) - new_size); + } // no --dry-run + } // iterate over tensors + + if (!params->dry_run) { + close_ofstream(); + } - // update the gguf meta data as we go - gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); - GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); - gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data); + LLAMA_LOG_INFO("%s: model size = %8.2f MiB (%.2f BPW)\n", __func__, total_size_org/1024.0/1024.0, total_size_org*8.0/ml.n_elements); + LLAMA_LOG_INFO("%s: quant size = %8.2f MiB (%.2f BPW)\n", __func__, total_size_new/1024.0/1024.0, total_size_new*8.0/ml.n_elements); - // write tensor data + padding - fout.write((const char *) new_data, new_size); - zeros(fout, GGML_PAD(new_size, align) - new_size); + if (!params->imatrix && params->dry_run && will_require_imatrix) { + LLAMA_LOG_WARN("%s: WARNING: dry run completed successfully, but actually completing this quantization will require an imatrix!\n", + __func__ + ); } - close_ofstream(); - - LLAMA_LOG_INFO("%s: model size = %8.2f MiB\n", __func__, total_size_org/1024.0/1024.0); - LLAMA_LOG_INFO("%s: quant size = %8.2f MiB\n", __func__, total_size_new/1024.0/1024.0); if (qs.n_fallback > 0) { LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n", @@ -1045,6 +1094,7 @@ llama_model_quantize_params llama_model_quantize_default_params() { /*.only_copy =*/ false, /*.pure =*/ false, /*.keep_split =*/ false, + /*.dry_run =*/ false, /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, /*.tensor_type =*/ nullptr, diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 62e137fb842..194eed238ec 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -289,6 +289,15 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_JAIS2: + regex_exprs = { + // original regex from tokenizer.json + //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s{512}(?!\\S)|\\s{256}(?!\\S)|\\s{128}(?!\\S)|\\s{64}(?!\\S)|\\s{32}(?!\\S)|\\s{16}(?!\\S)|\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s{1}", + + // adapted: same as llama3 but with cascading whitespace pattern + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s{512}(?!\\S)|\\s{256}(?!\\S)|\\s{128}(?!\\S)|\\s{64}(?!\\S)|\\s{32}(?!\\S)|\\s{16}(?!\\S)|\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s{1}", + }; + break; case LLAMA_VOCAB_PRE_TYPE_DBRX: case LLAMA_VOCAB_PRE_TYPE_SMAUG: regex_exprs = { @@ -308,6 +317,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: case LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE: + case LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM: regex_exprs = { "\\p{N}{1,3}", "[一-龥぀-ゟ゠-ヿ]+", @@ -422,6 +432,14 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_TINY_AYA: + regex_exprs = { + // original regex from tokenizer.json: "\\d{1,3}(?=(?:\\d{3})*\\b)" + "\\d{1,3}(?=(?:\\d{3})*\\b)", + // original regex from tokenizer.json: "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_KIMI_K2: regex_exprs = { // K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp @@ -1872,7 +1890,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "falcon-h1" || tokenizer_pre == "pixtral" || tokenizer_pre == "midm-2.0" || - tokenizer_pre == "lfm2") { + tokenizer_pre == "lfm2" || + tokenizer_pre == "jina-v5-nano") { pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; ignore_merges = true; add_bos = true; @@ -1912,8 +1931,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jina-v2-de" || tokenizer_pre == "a.x-4.0" || tokenizer_pre == "mellum" || - tokenizer_pre == "modern-bert" ) { + tokenizer_pre == "modern-bert") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "jais-2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2; } else if ( tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-code" || @@ -2005,10 +2027,15 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "megrez") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; } else if ( - tokenizer_pre == "gpt-4o" || - tokenizer_pre == "llama4") { + tokenizer_pre == "gpt-4o" || + tokenizer_pre == "llama4" || + tokenizer_pre == "kanana2") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; + } else if ( + tokenizer_pre == "tiny_aya") { + pre_type = LLAMA_VOCAB_PRE_TYPE_TINY_AYA; + clean_spaces = false; } else if ( tokenizer_pre == "superbpe") { pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE; @@ -2039,6 +2066,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "hunyuan-dense") { pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE; clean_spaces = false; + } else if ( + tokenizer_pre == "joyai-llm") { + pre_type = LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM; + clean_spaces = false; } else if ( tokenizer_pre == "kimi-k2") { pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; @@ -2441,6 +2472,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|calls|>" // solar-open || t.first == "" || t.first == "<|endoftext|>" + || t.first == "" // paddleocr || t.first == "<|eom_id|>" || t.first == "" || t.first == "_" diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 718238fb866..be5b08012df 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -55,6 +55,9 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, + LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, + LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, + LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, }; struct LLM_KV; diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index d2d7f59ebc6..077f66dc651 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -389,6 +389,7 @@ extern "C" { bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored bool pure; // quantize all tensors to the default type bool keep_split; // quantize to the same number of shards + bool dry_run; // calculate and show the final quantization size without performing quantization void * imatrix; // pointer to importance matrix data void * kv_overrides; // pointer to vector containing overrides void * tensor_types; // pointer to vector containing tensor types diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp index b2c1f160601..b608396e50e 100644 --- a/examples/talk-llama/models/deepseek2.cpp +++ b/examples/talk-llama/models/deepseek2.cpp @@ -218,7 +218,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr LLM_FFN_SILU, hparams.expert_weights_norm, hparams.expert_weights_scale, hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, - il); + il, + nullptr, + model.layers[il].ffn_gate_up_exps); cb(moe_out, "ffn_moe_out", il); // FFN shared expert diff --git a/examples/talk-llama/models/delta-net-base.cpp b/examples/talk-llama/models/delta-net-base.cpp new file mode 100644 index 00000000000..99f1fdd9538 --- /dev/null +++ b/examples/talk-llama/models/delta-net-base.cpp @@ -0,0 +1,376 @@ +#include "models.h" + +#define CHUNK_SIZE 64 + +// utility to get one slice from the third dimension +// input dim: [x, y, c, b] +// output dim: [x, y, 1, b] +static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { + return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); +} + +llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {} + +std::pair llm_build_delta_net_base::build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + const bool kda = (g->ne[0] == S_k && g->ne[1] == H_k); + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); + GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); + GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + const float scale = 1.0f / sqrtf(S_k); + + q = ggml_scale(ctx0, q, scale); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(b, "b_in", il); + cb(g, "g_in", il); + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] + g = ggml_permute(ctx0, g, 0, 2, 1, 3); // [g_0, n_tokens, H_v, n_seqs] + b = ggml_permute(ctx0, b, 0, 2, 1, 3); // [ 1, n_tokens, H_v, n_seqs] + + const int CS = CHUNK_SIZE; + + const int pad = (CS - n_tokens % CS) % CS; + const int n_chunks = (n_tokens + pad) / CS; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, 0, pad, 0, 0); + b = ggml_pad(ctx0, b, 0, pad, 0, 0); + + ggml_tensor * v_b = ggml_mul(ctx0, v, b); + ggml_tensor * k_b = ggml_mul(ctx0, k, b); + + cb(v_b, "v_b", il); + cb(k_b, "k_b", il); + + q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs); + k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs); + v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs); + + g = ggml_reshape_4d(ctx0, g, g->ne[0], CS, n_chunks, H_v * n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs); + + // [CS, g_0, n_chunks, H_v * n_seqs] + // TODO: extend ggml_cumsum with axis parameter to avoid transpose + ggml_tensor * g_cs = ggml_cumsum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, g))); + cb(g_cs, "g_cs", il); + + ggml_tensor * kb = nullptr; + ggml_tensor * kq = nullptr; + if (kda) { + const int64_t CHB = n_chunks * H_k * n_seqs; + + ggml_tensor * g_cs_i = ggml_reshape_4d(ctx0, g_cs, CS, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB] + ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, S_k, CHB); // [1, chunk_size, S_k, CHB] + + g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB] + + // decay_mask [chunk_size,chunk_size,S_k,CHB] + ggml_tensor * decay_mask; + decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); + decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); + decay_mask = ggml_exp(ctx0, decay_mask); + cb(decay_mask, "decay_mask", il); + + // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched + decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, CS, CS, CHB); + + ggml_tensor * k_b_i = ggml_reshape_4d(ctx0, k_b, S_k, CS, 1, CHB); + ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, CS, CHB); + ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, CS, 1, CHB); + + ggml_tensor * decay_k_b_i = ggml_mul(ctx0, decay_mask, k_b_i); + ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i); + + // decay_k_b_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB] + kb = ggml_mul_mat(ctx0, decay_k_b_i, k_j); + kq = ggml_mul_mat(ctx0, decay_q_i, k_j); + + kb = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kb, CS, CS, n_chunks, H_v * n_seqs))); + kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kq, CS, CS, n_chunks, H_v * n_seqs))); + } else { + ggml_tensor * g_cs_i = g_cs; + ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs); + + g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs); + + // [CS, CS, n_chunks, H_v * n_seqs] + ggml_tensor * decay_mask; + decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); + decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); + decay_mask = ggml_exp(ctx0, decay_mask); + cb(decay_mask, "decay_mask", il); + + // [CS, CS, n_chunks, H_k * n_seqs] + kb = ggml_mul_mat(ctx0, k, k_b); + kb = ggml_mul (ctx0, kb, decay_mask); + + // [CS, CS, n_chunks, H_k * n_seqs] + kq = ggml_mul_mat(ctx0, k, q); + kq = ggml_mul(ctx0, kq, decay_mask); + } + + kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG); + cb(kq, "kq", il); + + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * attn; + attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER); + cb(attn, "attn", il); + + ggml_tensor * identity; + identity = ggml_view_1d(ctx0, attn, CS, 0); + identity = ggml_fill (ctx0, identity, 1.0f); + identity = ggml_diag (ctx0, identity); + + ggml_tensor * lhs = ggml_add(ctx0, attn, identity); + cb(lhs, "dnet_add_ch_lhs", il); + + attn = ggml_neg(ctx0, attn); + cb(attn, "attn_pre_solve", il); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_add(ctx0, lin_solve, identity); + cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs] + + // [S_v, CS, n_chunks, H_v * n_seqs] + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn); + + // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * g_exp = ggml_exp(ctx0, g_cs); + + k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b)); + + // [CS, S_k, n_chunks, H_k * n_seqs] + ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp); + cb(kbg, "k_beta_g_exp", il); + + // [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn); + cb(k_cd, "k_cumdecay", il); + + // [1, CS, n_chunks, H_k * n_seqs] KDA: [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * g_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_exp)); + ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t); + + // vectorized calculation of key_gdiff + // improved from the chunked version: + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + // get last element in g_cumsum along CS dimension (ne0) + // example: [[x, y, z, ..., last], ...] -> [[last], ...] + // [1, 1, n_chunks, H_v * n_seqs] KDA: [1, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, g_cs->ne[1], g_cs->ne[2], g_cs->ne[3], + g_cs->nb[1], + g_cs->nb[2], + g_cs->nb[3], + ggml_row_size(g_cs->type, g_cs->ne[0] - 1)); + cb(g_last, "g_last", il); + + // TODO: remove this cont when CUDA supports non-cont unary ops + g_last = ggml_cont(ctx0, g_last); + + // [1, 1, n_chunks, H_v * n_seqs] KDA: [S_k, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_last_exp_t = ggml_transpose(ctx0, ggml_exp(ctx0, g_last)); + cb(g_last_exp_t, "g_last_exp_t", il); + + // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last)); + cb(g_diff, "g_diff", il); + + ggml_tensor * g_diff_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_exp(ctx0, g_diff))); + + // [S_k, CS, n_chunks, H_v * n_seqs] + ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t); + cb(kg, "key_gdiff", il); + + // [CS, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg)); + cb(kg_t, "key_gdiff_t", il); + + ggml_tensor * s_t = ggml_transpose(ctx0, s); + s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs); + cb(s_t, "dnet_add_ch_state", il); + + // [CS, S_v, n_chunks, H_v * n_seqs] + ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v)); + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk); // [S_k, CS, 1, H_k * n_seqs] + ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk); // [ CS, S_v, 1, H_v * n_seqs] + ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk); // [ CS, CS, 1, H_k * n_seqs] + ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs] + ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs] + + // [CS, S_v, 1, H_v * n_seqs] + ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t); + cb(v_t_p, "v_prime", il); + + // [CS, S_v, 1, H_v * n_seqs] + ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p); + cb(v_t_new, "v_t_new", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq); + cb(v_attn, "v_attn", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp); + cb(attn_inter, "attn_inter", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn); + cb(o_ch, "dnet_add_ch_attn_out", il); + + v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]); + + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // TODO: head broadcast might not work here - probably will need a transpose + ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs] + + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk); + + s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t); + s_t = ggml_add(ctx0, s_t, kgv); + cb(s_t, "dnet_add_ch_state", il); + } + + s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs); + + // truncate padded tokens + ggml_tensor * o = ggml_view_4d(ctx0, v, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(v->type, S_v), + ggml_row_size(v->type, S_v * CS * n_chunks), + ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0); + o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs] + s = ggml_transpose(ctx0, s_t); + cb(s, "output_state", il); + + return {o, s}; +} + +std::pair llm_build_delta_net_base::build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, // beta + ggml_tensor * s, // state + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_tokens == 1); + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); + GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); + GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + const float scale = 1.0f / sqrtf(S_k); + + q = ggml_scale(ctx0, q, scale); + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(b, "b_in", il); + cb(g, "g_in", il); + + // GDA: [1, 1, H_v, n_seqs] + // KDA: [1, S_k, H_v, n_seqs] + g = ggml_reshape_4d(ctx0, g, 1, g->ne[0], H_v, n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs); + + // [S_v, S_v, H_v, n_seqs] + g = ggml_exp(ctx0, g); + s = ggml_mul(ctx0, s, g); + + ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s)); + + // [1, S_v, H_v, n_seqs] + ggml_tensor * sk; + sk = ggml_mul (ctx0, s_t, k); + sk = ggml_sum_rows(ctx0, sk); + + // [S_v, 1, H_v, n_seqs] + ggml_tensor * d; + d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk)); + d = ggml_mul(ctx0, d, b); + + // [1, S_v, H_v, n_seqs] + ggml_tensor * d_t; + d_t = ggml_transpose(ctx0, d); + + // [S_v, S_v, H_v, n_seqs] + ggml_tensor * kd; + k = ggml_repeat(ctx0, k, s); + kd = ggml_mul (ctx0, k, d_t); + + s_t = ggml_add(ctx0, s_t, kd); + + cb(s_t, "dnet_add_ar_state", il); + + ggml_tensor * s_q = ggml_mul (ctx0, s_t, q); + ggml_tensor * o = ggml_sum_rows(ctx0, s_q); + + o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs] + s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] + + return {o, s}; +} diff --git a/examples/talk-llama/models/eurobert.cpp b/examples/talk-llama/models/eurobert.cpp new file mode 100644 index 00000000000..86e3176edc0 --- /dev/null +++ b/examples/talk-llama/models/eurobert.cpp @@ -0,0 +1,97 @@ +#include "models.h" + +llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * inp_pos = build_inp_pos(); + + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "inp_embd", -1); + + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * cur = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + + { + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + + Qcur = build_lora_mm(model.layers[il].wq, cur); + Kcur = build_lora_mm(model.layers[il].wk, cur); + Vcur = build_lora_mm(model.layers[il].wv, cur); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "kqv_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + cur = ggml_add(ctx0, cur, inpL); + + ggml_tensor * ffn_inp = cur; + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_embd", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index b641a094079..785a7e5e662 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -1,9 +1,7 @@ #include "models.h" - - llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index 204aa3932af..bcd837b30d6 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -29,7 +29,10 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // Only process up to last layer (skip final NextN layer) + // Final layer tensors are loaded but not processed in forward pass + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; // Pre-attention norm @@ -100,7 +103,7 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -130,9 +133,13 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); cb(cur, "post_mlp_norm", il); } - // Add residual connection after post-MLP norm - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } // Final norm cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1); diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index f6ca4c17a21..726ecdcca77 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -2,7 +2,7 @@ llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp new file mode 100644 index 00000000000..a69fcaa3bb3 --- /dev/null +++ b/examples/talk-llama/models/jais2.cpp @@ -0,0 +1,123 @@ +#include "models.h" + +// JAIS-2 model graph builder +// Uses: LayerNorm (not RMSNorm), relu2 activation, separate Q/K/V, RoPE embeddings +llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // KV input for attention + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + // Pre-attention LayerNorm + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // Self-attention with separate Q, K, V projections + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur_bias", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur_bias", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur_bias", il); + + // Reshape for attention + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // Residual connection + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // Pre-FFN LayerNorm + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + // FFN with relu2 activation (ReLU squared) - no gate projection + // up -> relu2 -> down + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, // no gate + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + // Residual connection + inpL = ggml_add(ctx0, cur, ffn_inp); + inpL = build_cvec(inpL, il); + cb(inpL, "l_out", il); + } + + // Final LayerNorm + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + cb(cur, "result_norm", -1); + + res->t_embd = cur; + + // Output projection + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index a0187772ccb..ceab5817407 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -1,6 +1,6 @@ #include "models.h" -llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { +llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp index 942844d071f..83d11241f8d 100644 --- a/examples/talk-llama/models/kimi-linear.cpp +++ b/examples/talk-llama/models/kimi-linear.cpp @@ -1,7 +1,7 @@ #include "models.h" #include "ggml.h" -#define CHUNK_SIZE 64 +#include "llama-memory-recurrent.h" // Causal Conv1d function for Q,K,V // When qkv is 0, it is Q, 1 is K, 2 is V @@ -65,7 +65,7 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t } llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_build_delta_net_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; @@ -84,17 +84,6 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll // Output ids for selecting which tokens to output ggml_tensor * inp_out_ids = build_inp_out_ids(); - ggml_tensor * chunked_causal_mask = - ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), - GGML_TRI_TYPE_LOWER); - - ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); - ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity); - - ggml_build_forward_expand(gf, chunked_causal_mask); - ggml_build_forward_expand(gf, chunked_identity); - ggml_build_forward_expand(gf, chunked_diag_mask); - // Kimi dimension constants const int64_t n_head = hparams.n_head(); const int64_t head_dim = hparams.n_embd_head_kda; @@ -127,6 +116,8 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); + ggml_build_forward_expand(gf, cur); + // Check layer type by checking which tensors exist // KDA layers have ssm_a_log tensor, MLA layers have wkv_a_mqa tensor bool is_kda = (layer.ssm_a != nullptr); @@ -160,27 +151,35 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll g1 = ggml_mul(ctx0, g1, A); cb(g1, "kda_g1", il); + g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs); + // Compute beta (mixing coefficient) ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur); - beta = ggml_reshape_4d(ctx0, beta, n_head, 1, n_seq_tokens, n_seqs); + beta = ggml_reshape_4d(ctx0, beta, 1, n_head, n_seq_tokens, n_seqs); cb(beta, "kda_beta", il); + beta = ggml_sigmoid(ctx0, beta); + // Reshape for KDA recurrence // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); - g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs); - // Get SSM state and compute KDA recurrence using ggml_kda_scan ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs); - // Choose between build_kda_chunking and build_kda_recurrent based on n_tokens + + const float eps_norm = hparams.f_norm_rms_eps; + + Qcur = ggml_l2_norm(ctx0, Qcur, eps_norm); + Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm); + + // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens std::pair attn_out = n_seq_tokens == 1 ? - build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) : - build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il); + build_delta_net_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) : + build_delta_net_chunking(Qcur, Kcur, Vcur, g1, beta, state, il); - ggml_tensor * output = attn_out.first; + ggml_tensor * output = ggml_cont(ctx0, attn_out.first); ggml_tensor * new_state = attn_out.second; cb(output, "attn_output", il); cb(new_state, "new_state", il); @@ -391,385 +390,3 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_build_forward_expand(gf, cur); } - -/* - This is a ggml implementation of the naive_chunk_kda function of - https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py -*/ -std::pair llm_build_kimi_linear::build_kda_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * gk, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il) { - GGML_ASSERT(ggml_is_contiguous(state)); - - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(gk->ne[0] == S_v && gk->ne[1] == H_v && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - // TODO: can this ever be false? - const bool use_qk_l2norm = true; - - if (use_qk_l2norm) { - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - } - - const float scale = 1.0f / sqrtf(S_v); - - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(gk, "gk_in", il); - - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(gk, "gk_perm", il); - cb(state, "state_in", il); - - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); - - // Do padding - const int64_t chunk_size = CHUNK_SIZE; - - const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; - const int64_t n_chunks = (n_tokens + pad) / chunk_size; - - q = ggml_pad(ctx0, q, 0, pad, 0, 0); - k = ggml_pad(ctx0, k, 0, pad, 0, 0); - v = ggml_pad(ctx0, v, 0, pad, 0, 0); - gk = ggml_pad(ctx0, gk, 0, pad, 0, 0); - beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); - - cb(q, "q_pad", il); - cb(k, "k_pad", il); - cb(v, "v_pad", il); - cb(beta, "beta_pad", il); - cb(gk, "gk_pad", il); - - ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); - - cb(v_beta, "v_beta", il); - cb(k_beta, "k_beta", il); - - const int64_t HB = H_k * n_seqs; - - q = ggml_cont_4d(ctx0, q, S_k, chunk_size, n_chunks, HB); - k = ggml_cont_4d(ctx0, k, S_k, chunk_size, n_chunks, HB); - k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, HB); - v = ggml_cont_4d(ctx0, v, S_v, chunk_size, n_chunks, HB); - v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, HB); - - gk = ggml_cont_4d(ctx0, gk, S_k, chunk_size, n_chunks, HB); - beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, HB); - - // switch for cumsum - gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 1, 0, 2, 3), chunk_size, S_k, n_chunks, HB); - cb(gk, "gk", il); - ggml_tensor * gk_cumsum = ggml_cumsum(ctx0, gk); - cb(gk_cumsum, "gk_cumsum", il); - -/* - Compute Akk and Aqk loop together - Akk loop: - for i in range(BT): - k_i = k[..., i, :] # k_i [B,H,NT,S] - g_i = g[..., i:i+1, :] # g_i [B,H,NT,1,S] - A[..., i] = torch.einsum('... c d, ... d -> ... c', k * (g - g_i).exp(), k_i) - Aqk loop: - for j in range(BT): - k_j = k[:, :, i, j] - g_j = g[:, :, i, j:j+1, :] - A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j) -*/ - const int64_t CHB = n_chunks * H_k * n_seqs; - ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB] - ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB] - - ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB] - // decay_mask [chunk_size,chunk_size,S_k,CHB] - ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i); - cb(decay_mask, "decay_mask", il); - - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - cb(decay_mask, "decay_masked", il); - decay_mask = ggml_exp(ctx0, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - - // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched - decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB); - - ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB); - ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB); - ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB); - - ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i); - ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i); - - // decay_k_i [S.BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB] - ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j); - ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j); - Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB))); - Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB))); - cb(Akk, "Akk", il); - cb(Aqk, "Aqk", il); - - Akk = ggml_mul(ctx0, Akk, beta); - Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask)); - cb(Akk, "attn_pre_solve", il); - - Aqk = ggml_mul(ctx0, Aqk, diag_mask); - Aqk = ggml_scale(ctx0, Aqk, scale); // scale q - cb(Aqk, "Aqk_masked", il); - - // for i in range(1, chunk_size): - // row = attn[..., i, :i].clone() - // sub = attn[..., :i, :i].clone() - // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) - // - // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A) - ggml_tensor * attn_lower = ggml_mul(ctx0, Akk, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, Akk, true, true, false); - Akk = ggml_mul(ctx0, lin_solve, causal_mask); - Akk = ggml_add(ctx0, Akk, identity); - - cb(Akk, "attn_solved", il); - - // switch back for downstream - gk_cumsum = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3), S_k, chunk_size, n_chunks, HB); - ggml_tensor * gkexp = ggml_exp(ctx0, gk_cumsum); - cb(gk_cumsum, "gk_cumsum", il); - - // u = (A*beta[..., None, :]) @ v aka U_[t] - ggml_tensor * vb = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), Akk); - - ggml_tensor * kbeta_gkexp = ggml_mul(ctx0, k_beta, gkexp); - cb(kbeta_gkexp, "kbeta_gkexp", il); - - ggml_tensor * k_cumdecay = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gkexp)), Akk); - cb(k_cumdecay, "k_cumdecay", il); - - ggml_tensor * core_attn_out = nullptr; - ggml_tensor * new_state = ggml_dup(ctx0, state); - - cb(new_state, "new_state", il); - - for (int64_t chunk = 0; chunk < n_chunks; chunk++) { -// extract one chunk worth of data - auto chunkify = [=](ggml_tensor * t) { - return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); - }; - auto chunkify_A = [=](ggml_tensor * t) { - return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, chunk_size, 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); - }; - - -// k [S,BT,NT,H*B] => k_chunk [S,BT,1,H*B] - ggml_tensor * k_chunk = chunkify(k); - ggml_tensor * q_chunk = chunkify(q); - ggml_tensor * vb_chunk = chunkify(vb); - -// gk_cumsum [S,BT,NT,H*B] => gk_cs_chunk [S,BT,1,H*B] - ggml_tensor * gk_cs_chunk = chunkify(gk_cumsum); - ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay); - ggml_tensor * gkexp_chunk = ggml_exp(ctx0, gk_cs_chunk); - ggml_tensor * Aqk_chunk = chunkify_A(Aqk); - - ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); - - // new_state [S,S,1,H*B] k_cumdecay_chunk [S,BT,1,H*B] - // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state or W_[t] @ S_[t] - ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); - - // v_new = v_i - v_prime or U_[t] - W_[t]*S_[t] - ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, vb_chunk, v_prime), v_prime); - ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); - - // q_chunk [S,BT,1,H*B] gkexp_chunk [S,BT,1,H*B] - // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - // or Gamma_[t]*Q_]t] @ S - ggml_tensor * q_gk_exp = ggml_mul(ctx0, q_chunk, gkexp_chunk); - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_gk_exp); - attn_inter = ggml_scale(ctx0, attn_inter, scale); // scale q - - // v_new_t [S,BT,1,H*B] Aqk [BT,BT,1,H*B] - // core_attn_out[:, :, i] = attn_inter + attn @ v_new or A' @ (U_[t] - W_[t]*S_[t]) - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, Aqk_chunk); - - // o[:, :, i] = (q_i * g_i.exp()) @ S + A @ v_i - ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); - - core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1); - - ggml_tensor * gk_cum_last = - ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cs_chunk, gk_cs_chunk->ne[0], 1, gk_cs_chunk->ne[2], gk_cs_chunk->ne[3], - gk_cs_chunk->nb[1], gk_cs_chunk->nb[2], gk_cs_chunk->nb[3], - gk_cs_chunk->nb[1] * (gk_cs_chunk->ne[1] - 1))); - - ggml_tensor * gkexp_last = ggml_exp(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, gk_cum_last))); - - ggml_tensor * gk_diff = ggml_neg(ctx0, ggml_sub(ctx0, gk_cs_chunk, gk_cum_last)); - - ggml_tensor * gk_diff_exp = ggml_exp(ctx0, gk_diff); - - ggml_tensor * key_gkdiff = ggml_mul(ctx0, k_chunk, gk_diff_exp); - - // rearrange((g_i[:,:,-1:] - g_i).exp()*k_i, 'b h c k -> b h k c') @ (U_[t] - W_[t] @ S) - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gkdiff))); - - new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gkexp_last, gkexp_last->ne[0], gkexp_last->ne[1], H_v, n_seqs)), - ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); - } - - core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs); - - // truncate padded tokens - ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, - S_v, n_tokens, H_v, n_seqs, - ggml_row_size(core_attn_out->type, S_v), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); - output_tokens = ggml_cont(ctx0, output_tokens); - // permute back to (S_v, H_v, n_tokens, n_seqs) - output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); - output_tokens = ggml_cont(ctx0, output_tokens); - - cb(new_state, "output_state", il); - - return {output_tokens, new_state}; -} - -std::pair llm_build_kimi_linear::build_kda_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * gk, - ggml_tensor * beta, - ggml_tensor * state, - int il) { - GGML_ASSERT(ggml_is_contiguous(v)); - GGML_ASSERT(ggml_is_contiguous(gk)); - - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(n_tokens == 1); - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(gk->ne[0] == S_k && gk->ne[1] == H_k && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_k && state->ne[2] == H_v && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(gk, "gk_in", il); - -// g [H,1,B,1] g_t [1,H,B,1] => [1,1,H,B] -// gk [S,H,1,B] => [S,1,H,B] gk_t [1,S,H,B] -// beta [H,1,1,B] beta_t [1,H,1,B] => [1,1,H,B] - gk = ggml_reshape_4d(ctx0, gk, S_k, 1, H_k, n_seqs); - ggml_tensor * gk_t = ggml_cont(ctx0, ggml_transpose(ctx0, gk)); - ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); - - // Apply exponential to gk_t - gk_t = ggml_exp(ctx0, gk_t); - // Apply the gated delta rule for the single timestep - // last_recurrent_state = last_recurrent_state * gk_t - // S = S * g_i[..., None].exp() - state = ggml_mul(ctx0, state, gk_t); - - ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state)); - -// state [S,S,H,B] k [S,1,H,B] k_state [S_v,1,H,B] - k = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs); - ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k); - - // v_i - (k_i[..., None] * S).sum(-2) - v = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); - ggml_tensor * v_diff = ggml_sub(ctx0, v, k_state); - - // b_i[..., None] * k_i - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta_t); - - // S = S + torch.einsum('b h k, b h v -> b h k v', b_i[..., None] * k_i, v_i - (k_i[..., None] * S).sum(-2)) - // v_diff_t [1,S_v,H,B] k_beta_t [1,S_k,H,B] state [S_v,S_k,H,B] - state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta)))); - - q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs); - state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state)); - ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q); - // core_attn_out should be [S_v, 1, H_v, n_seqs] after this - cb(core_attn_out, "output_tokens", il); - cb(state, "new_state", il); - - return {core_attn_out, state}; -} - diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index 7f805d78795..cf01ad62557 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -1,18 +1,149 @@ #include "models.h" +#include "../llama-memory-hybrid-iswa.h" #include "../llama-memory-hybrid.h" - -llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) : - llm_graph_context(params), - model(model) { +template +llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + using inp_hybrid_type = std::conditional_t; + using inp_attn_type = std::conditional_t; + using mem_hybrid_ctx = std::conditional_t; + + // lambda helpers for readability + auto build_dense_feed_forward = [&model, this](ggml_tensor * cur, int il) -> ggml_tensor * { + GGML_ASSERT(!model.layers[il].ffn_up_b); + GGML_ASSERT(!model.layers[il].ffn_gate_b); + GGML_ASSERT(!model.layers[il].ffn_down_b); + return build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + }; + auto build_moe_feed_forward = [&model, this](ggml_tensor * cur, int il) -> ggml_tensor * { + return build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, + static_cast(hparams.expert_gating_func), il); + }; + auto build_attn_block = [&model, this](ggml_tensor * cur, + ggml_tensor * inp_pos, + inp_attn_type * inp_attn, + int il) -> ggml_tensor * { + GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il)); + const auto n_embd_head = hparams.n_embd_head_v; + const auto n_head_kv = hparams.n_head_kv(il); + + auto * q = build_lora_mm(model.layers[il].wq, cur); + cb(q, "model.layers.{}.self_attn.q_proj", il); + auto * k = build_lora_mm(model.layers[il].wk, cur); + cb(k, "model.layers.{}.self_attn.k_proj", il); + auto * v = build_lora_mm(model.layers[il].wv, cur); + cb(v, "model.layers.{}.self_attn.v_proj", il); + + q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens); + k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens); + v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens); + + // qk norm + q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(q, "model.layers.{}.self_attn.q_layernorm", il); + k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(k, "model.layers.{}.self_attn.k_layernorm", il); + + // RoPE + q = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, + attn_factor, beta_fast, beta_slow); + k = ggml_rope_ext(ctx0, k, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, + attn_factor, beta_fast, beta_slow); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + + cb(cur, "model.layers.{}.self_attn.out_proj", il); + + return cur; + }; + auto build_shortconv_block = [&model, this](ggml_tensor * cur, + llm_graph_input_rs * inp_recr, + int il) -> ggml_tensor * { + const auto * mctx_cur = static_cast(mctx)->get_recr(); + const uint32_t kv_head = mctx_cur->get_head(); + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + GGML_ASSERT(hparams.n_shortconv_l_cache > 1); + const uint32_t d_conv = hparams.n_shortconv_l_cache - 1; + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur); + cb(bcx, "model.layers.{}.conv.in_proj", il); + + constexpr auto n_chunks = 3; + GGML_ASSERT(bcx->ne[0] % n_chunks == 0); + const auto chunk_size = bcx->ne[0] / n_chunks; + auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], + 0 * chunk_size * ggml_element_size(bcx)); + auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], + 1 * chunk_size * ggml_element_size(bcx)); + auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], + 2 * chunk_size * ggml_element_size(bcx)); + + auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x)); + + // read conv state + auto * conv_state = mctx_cur->get_r_l(il); + auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs); + auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs); + + bx = ggml_concat(ctx0, conv, bx, 0); + GGML_ASSERT(bx->ne[0] > conv->ne[0]); + + // last d_conv columns is a new conv state + auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], + (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx)); + GGML_ASSERT(ggml_are_same_shape(conv, new_conv)); + + // write new conv conv state + ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, + ggml_view_1d(ctx0, conv_state, ggml_nelements(new_conv), + kv_head * d_conv * n_embd * ggml_element_size(new_conv)))); + + auto * conv_kernel = model.layers[il].shortconv.conv; + auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel); + cb(conv_out, "model.layers.{}.conv.conv", il); + + auto * y = ggml_mul(ctx0, c, conv_out); + y = build_lora_mm(model.layers[il].shortconv.out_proj, y); + cb(y, "model.layers.{}.conv.out_proj", il); + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs); + + return y; + }; + + // actual graph construction starts here ggml_tensor * cur = build_inp_embd(model.tok_embd); cb(cur, "model.embed_tokens", -1); ggml_build_forward_expand(gf, cur); + inp_hybrid_type * inp_hybrid = nullptr; + if constexpr (iswa) { + inp_hybrid = build_inp_mem_hybrid_iswa(); + } else { + inp_hybrid = build_inp_mem_hybrid(); + } + ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_hybrid = build_inp_mem_hybrid(); ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { @@ -54,122 +185,6 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_lfm2::build_moe_feed_forward(ggml_tensor * cur, int il) const { - return build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, - static_cast(hparams.expert_gating_func), il); -} - -ggml_tensor * llm_build_lfm2::build_dense_feed_forward(ggml_tensor * cur, int il) const { - GGML_ASSERT(!model.layers[il].ffn_up_b); - GGML_ASSERT(!model.layers[il].ffn_gate_b); - GGML_ASSERT(!model.layers[il].ffn_down_b); - return build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); -} - -ggml_tensor * llm_build_lfm2::build_attn_block(ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_attn_kv * inp_attn, - int il) const { - GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il)); - const auto n_embd_head = hparams.n_embd_head_v; - const auto n_head_kv = hparams.n_head_kv(il); - - auto * q = build_lora_mm(model.layers[il].wq, cur); - cb(q, "model.layers.{}.self_attn.q_proj", il); - auto * k = build_lora_mm(model.layers[il].wk, cur); - cb(k, "model.layers.{}.self_attn.k_proj", il); - auto * v = build_lora_mm(model.layers[il].wv, cur); - cb(v, "model.layers.{}.self_attn.v_proj", il); - - q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens); - k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens); - v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens); - - // qk norm - q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); - cb(q, "model.layers.{}.self_attn.q_layernorm", il); - k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); - cb(k, "model.layers.{}.self_attn.k_layernorm", il); - - // RoPE - q = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, - attn_factor, beta_fast, beta_slow); - k = ggml_rope_ext(ctx0, k, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, - attn_factor, beta_fast, beta_slow); - - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, - q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); - - cb(cur, "model.layers.{}.self_attn.out_proj", il); - - return cur; -} - -ggml_tensor * llm_build_lfm2::build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il) { - const auto * mctx_cur = static_cast(mctx)->get_recr(); - const uint32_t kv_head = mctx_cur->get_head(); - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - GGML_ASSERT(n_seqs != 0); - GGML_ASSERT(ubatch.equal_seqs()); - GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - - GGML_ASSERT(hparams.n_shortconv_l_cache > 1); - const uint32_t d_conv = hparams.n_shortconv_l_cache - 1; - - // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} - cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); - - auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur); - cb(bcx, "model.layers.{}.conv.in_proj", il); - - constexpr auto n_chunks = 3; - GGML_ASSERT(bcx->ne[0] % n_chunks == 0); - const auto chunk_size = bcx->ne[0] / n_chunks; - auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], - 0 * chunk_size * ggml_element_size(bcx)); - auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], - 1 * chunk_size * ggml_element_size(bcx)); - auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], - 2 * chunk_size * ggml_element_size(bcx)); - - auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x)); - - // read conv state - auto * conv_state = mctx_cur->get_r_l(il); - auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs); - auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs); - - bx = ggml_concat(ctx0, conv, bx, 0); - GGML_ASSERT(bx->ne[0] > conv->ne[0]); - - // last d_conv columns is a new conv state - auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], - (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx)); - GGML_ASSERT(ggml_are_same_shape(conv, new_conv)); - - // write new conv conv state - ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, - ggml_view_1d(ctx0, conv_state, ggml_nelements(new_conv), - kv_head * d_conv * n_embd * ggml_element_size(new_conv)))); - - auto * conv_kernel = model.layers[il].shortconv.conv; - auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel); - cb(conv_out, "model.layers.{}.conv.conv", il); - - auto * y = ggml_mul(ctx0, c, conv_out); - y = build_lora_mm(model.layers[il].shortconv.out_proj, y); - cb(y, "model.layers.{}.conv.out_proj", il); - // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} - y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs); - - return y; -} +// Explicit template instantiations +template struct llm_build_lfm2; +template struct llm_build_lfm2; diff --git a/examples/talk-llama/models/graph-context-mamba.cpp b/examples/talk-llama/models/mamba-base.cpp similarity index 97% rename from examples/talk-llama/models/graph-context-mamba.cpp rename to examples/talk-llama/models/mamba-base.cpp index b9a363b32b6..aaac9487dfa 100644 --- a/examples/talk-llama/models/graph-context-mamba.cpp +++ b/examples/talk-llama/models/mamba-base.cpp @@ -1,8 +1,10 @@ #include "models.h" -llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {} +#include "llama-memory-recurrent.h" -ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp, +llm_build_mamba_base::llm_build_mamba_base(const llm_graph_params & params) : llm_graph_context(params) {} + +ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, @@ -143,7 +145,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * in return cur; } -ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * inp, +ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, diff --git a/examples/talk-llama/models/mamba.cpp b/examples/talk-llama/models/mamba.cpp index 46819613c2d..55fd2e055c4 100644 --- a/examples/talk-llama/models/mamba.cpp +++ b/examples/talk-llama/models/mamba.cpp @@ -1,7 +1,6 @@ #include "models.h" - -llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { +llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index ec6f80e5265..0712d03d8d9 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -1,23 +1,51 @@ #pragma once -#include "../llama-model.h" -#include "../llama-graph.h" +#include "llama-model.h" +#include "llama-graph.h" -// TODO: remove in follow-up PR - move to .cpp files -#include "../llama-memory-recurrent.h" +// note: almost all graphs require atleast sqrtf, so include cmath globally #include -struct llm_graph_context_mamba : public llm_graph_context { - llm_graph_context_mamba(const llm_graph_params & params); +// +// base classes +// - virtual ~llm_graph_context_mamba() = default; +struct llm_build_mamba_base : public llm_graph_context { + llm_build_mamba_base(const llm_graph_params & params); + + virtual ~llm_build_mamba_base() = default; ggml_tensor * build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); ggml_tensor * build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il) const; }; -// Base class for RWKV-related models +struct llm_build_delta_net_base : public llm_graph_context { + llm_build_delta_net_base(const llm_graph_params & params); + + virtual ~llm_build_delta_net_base() = default; + + // returns pair of output and new state + std::pair build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); + + // returns pair of output and new state + std::pair build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); +}; + struct llm_build_rwkv6_base : public llm_graph_context { const llama_model & model; @@ -58,6 +86,10 @@ struct llm_build_rwkv7_base : public llm_graph_context { int il) const; }; +// +// models +// + struct llm_build_afmoe : public llm_graph_context { llm_build_afmoe(const llama_model & model, const llm_graph_params & params); }; @@ -158,6 +190,10 @@ struct llm_build_ernie4_5_moe : public llm_graph_context { llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_paddleocr : public llm_graph_context { + llm_build_paddleocr(const llama_model & model, const llm_graph_params & params); +}; + template struct llm_build_exaone4 : public llm_graph_context { llm_build_exaone4(const llama_model & model, const llm_graph_params & params); @@ -175,7 +211,7 @@ struct llm_build_falcon : public llm_graph_context { llm_build_falcon(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_falcon_h1 : public llm_graph_context_mamba { +struct llm_build_falcon_h1 : public llm_build_mamba_base { llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params); }; @@ -253,7 +289,7 @@ struct llm_build_granite : public llm_graph_context { const int il); }; -struct llm_build_granite_hybrid : public llm_graph_context_mamba { +struct llm_build_granite_hybrid : public llm_build_mamba_base { llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params); ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il); ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, @@ -284,11 +320,15 @@ struct llm_build_jais : public llm_graph_context { llm_build_jais(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_jamba : public llm_graph_context_mamba { +struct llm_build_jais2 : public llm_graph_context { + llm_build_jais2(const llama_model & model, const llm_graph_params & params); +}; + +struct llm_build_jamba : public llm_build_mamba_base { llm_build_jamba(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_kimi_linear : public llm_graph_context_mamba { +struct llm_build_kimi_linear : public llm_build_delta_net_base { llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params); std::pair build_kda_autoregressive( @@ -315,15 +355,9 @@ struct llm_build_kimi_linear : public llm_graph_context_mamba { const llama_model & model; }; +template struct llm_build_lfm2 : public llm_graph_context { - const llama_model & model; - llm_build_lfm2(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, int il) const; - ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, int il) const; - ggml_tensor * build_attn_block(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, int il) const; - ggml_tensor * build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il); - }; struct llm_build_llada : public llm_graph_context { @@ -347,7 +381,7 @@ struct llm_build_maincoder : public llm_graph_context { llm_build_maincoder(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_mamba : public llm_graph_context_mamba { +struct llm_build_mamba : public llm_build_mamba_base { llm_build_mamba(const llama_model & model, const llm_graph_params & params); }; @@ -379,17 +413,21 @@ struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_nemotron_h : public llm_graph_context_mamba { +struct llm_build_nemotron_h : public llm_build_mamba_base { llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il); + ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il); ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, - const llama_model & model, const int64_t n_embd_head, const int il); + const llama_model & model, int64_t n_embd_head, int il); }; struct llm_build_neo_bert : public llm_graph_context { llm_build_neo_bert(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_eurobert : public llm_graph_context { + llm_build_eurobert(const llama_model & model, const llm_graph_params & params); +}; + template struct llm_build_olmo2 : public llm_graph_context { llm_build_olmo2(const llama_model & model, const llm_graph_params & params); @@ -428,7 +466,7 @@ struct llm_build_phi3 : public llm_graph_context { llm_build_phi3(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_plamo2 : public llm_graph_context_mamba { +struct llm_build_plamo2 : public llm_build_mamba_base { llm_build_plamo2(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); @@ -477,7 +515,7 @@ struct llm_build_qwen3vlmoe : public llm_graph_context { llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_qwen3next : public llm_graph_context_mamba { +struct llm_build_qwen3next : public llm_build_delta_net_base { llm_build_qwen3next(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( @@ -495,26 +533,6 @@ struct llm_build_qwen3next : public llm_graph_context_mamba { ggml_tensor * cur, int il); - // returns pair of output and new state - std::pair build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); - - // returns pair of output and new state - std::pair build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); - ggml_tensor * build_norm_gated( ggml_tensor * input, ggml_tensor * weights, @@ -529,7 +547,7 @@ struct llm_build_qwen3next : public llm_graph_context_mamba { const llama_model & model; }; -struct llm_build_qwen35 : public llm_graph_context_mamba { +struct llm_build_qwen35 : public llm_build_delta_net_base { llm_build_qwen35(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( @@ -542,38 +560,12 @@ struct llm_build_qwen35 : public llm_graph_context_mamba { ggml_tensor * build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il); ggml_tensor * build_layer_ffn( ggml_tensor * cur, int il); - // returns pair of output and new state - std::pair build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il); - - // returns pair of output and new state - std::pair build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); - ggml_tensor * build_norm_gated( ggml_tensor * input, ggml_tensor * weights, @@ -588,7 +580,8 @@ struct llm_build_qwen35 : public llm_graph_context_mamba { const llama_model & model; }; -struct llm_build_qwen35moe : public llm_graph_context_mamba { +// TODO: derive llm_build_delta_net_base instead +struct llm_build_qwen35moe : public llm_build_delta_net_base { llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( @@ -601,38 +594,12 @@ struct llm_build_qwen35moe : public llm_graph_context_mamba { ggml_tensor * build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il); ggml_tensor * build_layer_ffn( ggml_tensor * cur, int il); - // returns pair of output and new state - std::pair build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il); - - // returns pair of output and new state - std::pair build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); - ggml_tensor * build_norm_gated( ggml_tensor * input, ggml_tensor * weights, diff --git a/examples/talk-llama/models/modern-bert.cpp b/examples/talk-llama/models/modern-bert.cpp index bb12ed819f7..32066c712b4 100644 --- a/examples/talk-llama/models/modern-bert.cpp +++ b/examples/talk-llama/models/modern-bert.cpp @@ -104,13 +104,6 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll LLM_NORM, -1); cb(cur, "final_norm_out", -1); - if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { - // extracting cls token - cur = ggml_view_1d(ctx0, cur, hparams.n_embd, 0); - cb(cur, "cls_pooled_embd", -1); - } - - cb(cur, "res_embd", -1); res->t_embd = cur; ggml_build_forward_expand(gf, cur); } diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index 079c730ac29..d61d62a8c96 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -1,9 +1,7 @@ #include "models.h" - - llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -65,8 +63,8 @@ llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, const llama_model & model, - const int64_t n_embd_head, - const int il) { + int64_t n_embd_head, + int il) { // compute Q and K ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); @@ -106,7 +104,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * return cur; } -ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) { +ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) { if (model.layers[il].ffn_gate_inp == nullptr) { cur = build_ffn(cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, diff --git a/examples/talk-llama/models/paddleocr.cpp b/examples/talk-llama/models/paddleocr.cpp new file mode 100644 index 00000000000..39a368df53b --- /dev/null +++ b/examples/talk-llama/models/paddleocr.cpp @@ -0,0 +1,122 @@ +#include "models.h" + +llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + + // NOTE: same with qwen2vl.cpp, but bias tensors are optional + + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + { + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + } + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1) { + // skip computing output for unused tokens + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index 31115a08f95..3af236843bb 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -1,7 +1,9 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index 592c170457b..bacf7a4c2ee 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -1,10 +1,9 @@ -#include "ggml.h" #include "models.h" -#define CHUNK_SIZE 64 +#include "llama-memory-recurrent.h" llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_build_delta_net_base(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -24,27 +23,18 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - ggml_tensor * causal_mask = - ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), - GGML_TRI_TYPE_LOWER); - - ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); - ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); - - ggml_build_forward_expand(gf, causal_mask); - ggml_build_forward_expand(gf, identity); - ggml_build_forward_expand(gf, diag_mask); - for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); + ggml_build_forward_expand(gf, cur); + // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) - cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); + cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { // Full attention layer cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); @@ -94,361 +84,6 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa ggml_build_forward_expand(gf, cur); } -// utility to get one slice from the third dimension -// input dim: [x, y, c, b] -// output dim: [x, y, 1, b] -static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { - return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); -} - -std::pair llm_build_qwen35::build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(g, "g_perm", il); - cb(state, "state_in", il); - - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); - - // Do padding - const int64_t chunk_size = CHUNK_SIZE; - - const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; - const int64_t n_chunks = (n_tokens + pad) / chunk_size; - - q = ggml_pad(ctx0, q, 0, pad, 0, 0); - k = ggml_pad(ctx0, k, 0, pad, 0, 0); - v = ggml_pad(ctx0, v, 0, pad, 0, 0); - g = ggml_pad(ctx0, g, pad, 0, 0, 0); - beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); - - cb(q, "q_pad", il); - cb(k, "k_pad", il); - cb(v, "v_pad", il); - cb(beta, "beta_pad", il); - cb(g, "g_pad", il); - - ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); - - cb(v_beta, "v_beta", il); - cb(k_beta, "k_beta", il); - - q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); - k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); - k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); - v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); - v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); - - g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); - beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); - - ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); - cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); - ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * gcs_j_broadcast = - ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); - cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - decay_mask = ggml_exp(ctx0, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - - ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); - - ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); - ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); - cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_mul(ctx0, lin_solve, causal_mask); - attn = ggml_add(ctx0, attn, identity); - cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); - - ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); - ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); - - ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); - cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * k_cumdecay = - ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); - cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); - attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); - attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); - cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - - // vectorized calculation of key_gdiff - // improved from the chunked version: - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - // get last element in g_cumsum along chunk_size dimension (ne0) - // example: [[x, y, z, ..., last], ...] -> [[last], ...] - ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], - g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], - (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); - g_last = ggml_cont(ctx0, g_last); - cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); - cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); - cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, - 1, chunk_size, n_chunks, g_diff_exp->ne[3]); - - ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); - cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); - cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs) - - // state to be updated per chunk - ggml_tensor * new_state = state; // ggml_dup(ctx0, state); - cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) - - // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) - ggml_tensor * core_attn_out = nullptr; - - for (int64_t chunk = 0; chunk < n_chunks; chunk++) { - // shape: (S_k, chunk_size, 1, H_k * n_seqs) - ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul - - // shape: (S_v, chunk_size, 1, H_v * n_seqs) - ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat - - // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul - - // shape: (chunk_size, 1, H_v * n_seqs) - ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat - - // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - // replaced by precomputed attn_kq - ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); - cb(attn_chunk, "attn_chunk", il); - - ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); - - // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); - cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) - - // v_new = v_i - v_prime - ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); - ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); - cb(v_new, "v_new_chunk", il); - - // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); - cb(attn_inter, "attn_inter_chunk", il); - - // core_attn_out[:, :, i] = attn_inter + attn @ v_new - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); - cb(v_attn, "v_attn_chunk", il); - - ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); - cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) - - core_attn_out = core_attn_out == nullptr - ? core_attn_out_chunk - : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); - - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk); - //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); - - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); - new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), - ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); - } - - // truncate padded tokens - ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, - S_v, n_tokens, H_v, n_seqs, - ggml_row_size(core_attn_out->type, S_v), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); - output_tokens = ggml_cont(ctx0, output_tokens); - cb(output_tokens, "output_tokens", il); - - // permute back to (S_v, H_v, n_tokens, n_seqs) - output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); - output_tokens = ggml_cont(ctx0, output_tokens); - - return {output_tokens, new_state}; -} - -std::pair llm_build_qwen35::build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); - ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); - - // Apply exponential to g_t - g_t = ggml_exp(ctx0, g_t); - - // Apply the gated delta rule for the single timestep - // last_recurrent_state = last_recurrent_state * g_t - state = ggml_mul(ctx0, state, g_t); - - // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); - ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); - // we need to sum over dim=-2, so we transpose, sum, then transpose again - kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); - - // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) - ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); - // delta = (v_t - kv_mem) * beta_t - ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] - ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); - - // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta - ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); - state = ggml_add(ctx0, state, k_t_delta); - - // Compute the attention output - // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t - ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); - // again, since it's over dim = -2, transpose, sum, transpose back - ggml_tensor * core_attn_out = - ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); - - // core_attn_out should be [S_v, 1, H_v, n_seqs] after this - cb(core_attn_out, "output_tokens", il); - cb(state, "new_state", il); - - return {core_attn_out, state}; -} - std::pair llm_build_qwen35::build_qkvz( ggml_tensor * input, int il) { @@ -560,9 +195,6 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( ggml_tensor * llm_build_qwen35::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il) { const auto * mctx_cur = inp->mctx; @@ -586,8 +218,11 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( ggml_tensor * z = qkvz.second; ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); - beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs); + beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); cb(beta, "beta", il); + + beta = ggml_sigmoid(ctx0, beta); + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); cb(alpha, "alpha", il); @@ -595,15 +230,16 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus cb(gate, "gate", il); + gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); + // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); - // Build the convolution states tensor ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); cb(conv_states, "conv_states", il); @@ -612,11 +248,12 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); cb(conv_states, "conv_states_reshaped", il); - qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); - cb(qkv_mixed, "qkv_mixed_permuted", il); + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); cb(conv_input, "conv_input", il); @@ -634,9 +271,11 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(state_update_target, "state_update_target", il); ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); - cb(conv_states_all, "conv_states_updated", il); - // Apply SSM convolution + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); + cb(state, "state_predelta", il); + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); cb(conv_output_proper, "conv_output_raw", il); @@ -650,31 +289,41 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); // Extract the convolved Q, K, V from conv_output - ggml_tensor * q_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); + ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + 0); + + ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + + ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_v_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads)); + cb(q_conv, "q_conv", il); - ggml_tensor * k_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, - head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(k_conv, "k_conv", il); - ggml_tensor * v_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv, - 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(v_conv, "v_conv", il); - // Unsqueeze them - q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + const float eps_norm = hparams.f_norm_rms_eps; - ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); - state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); - cb(state, "state_predelta", il); + q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm); + k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm); - // if head keys and value keys are different, repeat Q/K to match V's head count - // V heads are in tiled order (from conversion), so simple tiled repeat works + //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + // if head keys and value keys are different, repeat to force tensors into matching shapes if (num_k_heads != num_v_heads) { GGML_ASSERT(num_v_heads % num_k_heads == 0); + // TODO: try to avoid these explicit repeats by utilizing op broadcast q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); } @@ -688,7 +337,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( if (n_seq_tokens == 1) { attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); } ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; @@ -697,19 +346,15 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); - - // Reshape both attn_out_final and z to 2D tensors for normalization - // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_cpy(ctx0, new_state, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // Apply gated normalization: self.norm(core_attn_out, z) - ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il); // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); @@ -720,7 +365,8 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(cur, "linear_attn_out", il); // Reshape back to original dimensions - cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; } diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index 0db8f825c67..22d708f2062 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -1,10 +1,9 @@ -#include "ggml.h" #include "models.h" -#define CHUNK_SIZE 64 +#include "llama-memory-recurrent.h" llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_build_delta_net_base(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -24,27 +23,18 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - ggml_tensor * causal_mask = - ggml_tri(ctx0, ggml_fill(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), - GGML_TRI_TYPE_LOWER); - - ggml_tensor * identity = ggml_diag(ctx0, ggml_fill(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); - ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); - - ggml_build_forward_expand(gf, causal_mask); - ggml_build_forward_expand(gf, identity); - ggml_build_forward_expand(gf, diag_mask); - for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); + ggml_build_forward_expand(gf, cur); + // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) - cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); + cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { // Full attention layer cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); @@ -94,362 +84,6 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr ggml_build_forward_expand(gf, cur); } -// utility to get one slice from the third dimension -// input dim: [x, y, c, b] -// output dim: [x, y, 1, b] -static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { - return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); -} - -std::pair llm_build_qwen35moe::build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(g, "g_perm", il); - cb(state, "state_in", il); - - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); - - // Do padding - const int64_t chunk_size = CHUNK_SIZE; - - const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; - const int64_t n_chunks = (n_tokens + pad) / chunk_size; - - q = ggml_pad(ctx0, q, 0, pad, 0, 0); - k = ggml_pad(ctx0, k, 0, pad, 0, 0); - v = ggml_pad(ctx0, v, 0, pad, 0, 0); - g = ggml_pad(ctx0, g, pad, 0, 0, 0); - beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); - - cb(q, "q_pad", il); - cb(k, "k_pad", il); - cb(v, "v_pad", il); - cb(beta, "beta_pad", il); - cb(g, "g_pad", il); - - ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); - - cb(v_beta, "v_beta", il); - cb(k_beta, "k_beta", il); - - q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); - k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); - k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); - v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); - v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); - - g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); - beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); - - ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); - cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); - ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * gcs_j_broadcast = - ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); - cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - decay_mask = ggml_exp(ctx0, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - - ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); - - ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); - ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); - cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_mul(ctx0, lin_solve, causal_mask); - attn = ggml_add(ctx0, attn, identity); - cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); - - ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); - ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); - - ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); - cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * k_cumdecay = - ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); - cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); - attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); - attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); - cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - - // vectorized calculation of key_gdiff - // improved from the chunked version: - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - // get last element in g_cumsum along chunk_size dimension (ne0) - // example: [[x, y, z, ..., last], ...] -> [[last], ...] - ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], - g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], - (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); - g_last = ggml_cont(ctx0, g_last); - cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); - cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); - cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, - 1, chunk_size, n_chunks, g_diff_exp->ne[3]); - - ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); - cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); - cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs) - - - // state to be updated per chunk - ggml_tensor * new_state = state; // ggml_dup(ctx0, state); - cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) - - // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) - ggml_tensor * core_attn_out = nullptr; - - for (int64_t chunk = 0; chunk < n_chunks; chunk++) { - // shape: (S_k, chunk_size, 1, H_k * n_seqs) - ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul - - // shape: (S_v, chunk_size, 1, H_v * n_seqs) - ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat - - // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul - - // shape: (chunk_size, 1, H_v * n_seqs) - ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat - - // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - // replaced by precomputed attn_kq - ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); - cb(attn_chunk, "attn_chunk", il); - - ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); - - // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); - cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) - - // v_new = v_i - v_prime - ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); - ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); - cb(v_new, "v_new_chunk", il); - - // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); - cb(attn_inter, "attn_inter_chunk", il); - - // core_attn_out[:, :, i] = attn_inter + attn @ v_new - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); - cb(v_attn, "v_attn_chunk", il); - - ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); - cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) - - core_attn_out = core_attn_out == nullptr - ? core_attn_out_chunk - : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); - - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk); - //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); - - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); - new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), - ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); - } - - // truncate padded tokens - ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, - S_v, n_tokens, H_v, n_seqs, - ggml_row_size(core_attn_out->type, S_v), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); - output_tokens = ggml_cont(ctx0, output_tokens); - cb(output_tokens, "output_tokens", il); - - // permute back to (S_v, H_v, n_tokens, n_seqs) - output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); - output_tokens = ggml_cont(ctx0, output_tokens); - - return {output_tokens, new_state}; -} - -std::pair llm_build_qwen35moe::build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); - ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); - - // Apply exponential to g_t - g_t = ggml_exp(ctx0, g_t); - - // Apply the gated delta rule for the single timestep - // last_recurrent_state = last_recurrent_state * g_t - state = ggml_mul(ctx0, state, g_t); - - // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); - ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); - // we need to sum over dim=-2, so we transpose, sum, then transpose again - kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); - - // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) - ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); - // delta = (v_t - kv_mem) * beta_t - ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] - ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); - - // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta - ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); - state = ggml_add(ctx0, state, k_t_delta); - - // Compute the attention output - // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t - ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); - // again, since it's over dim = -2, transpose, sum, transpose back - ggml_tensor * core_attn_out = - ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); - - // core_attn_out should be [S_v, 1, H_v, n_seqs] after this - cb(core_attn_out, "output_tokens", il); - cb(state, "new_state", il); - - return {core_attn_out, state}; -} - std::pair llm_build_qwen35moe::build_qkvz( ggml_tensor * input, int il) { @@ -561,9 +195,6 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il) { const auto * mctx_cur = inp->mctx; @@ -587,8 +218,11 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( ggml_tensor * z = qkvz.second; ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); - beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs); + beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); cb(beta, "beta", il); + + beta = ggml_sigmoid(ctx0, beta); + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); cb(alpha, "alpha", il); @@ -596,15 +230,16 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus cb(gate, "gate", il); + gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); + // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); - // Build the convolution states tensor ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); cb(conv_states, "conv_states", il); @@ -613,11 +248,12 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); cb(conv_states, "conv_states_reshaped", il); - qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); - cb(qkv_mixed, "qkv_mixed_permuted", il); + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); cb(conv_input, "conv_input", il); @@ -635,9 +271,11 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(state_update_target, "state_update_target", il); ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); - cb(conv_states_all, "conv_states_updated", il); - // Apply SSM convolution + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); + cb(state, "state_predelta", il); + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); cb(conv_output_proper, "conv_output_raw", il); @@ -651,31 +289,41 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); // Extract the convolved Q, K, V from conv_output - ggml_tensor * q_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); + ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + 0); + + ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + + ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_v_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads)); + cb(q_conv, "q_conv", il); - ggml_tensor * k_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, - head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(k_conv, "k_conv", il); - ggml_tensor * v_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv, - 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(v_conv, "v_conv", il); - // Unsqueeze them - q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + const float eps_norm = hparams.f_norm_rms_eps; - ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); - state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); - cb(state, "state_predelta", il); + q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm); + k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm); - // if head keys and value keys are different, repeat Q/K to match V's head count - // V heads are in tiled order (from conversion), so simple tiled repeat works + //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + // if head keys and value keys are different, repeat to force tensors into matching shapes if (num_k_heads != num_v_heads) { GGML_ASSERT(num_v_heads % num_k_heads == 0); + // TODO: try to avoid these explicit repeats by utilizing op broadcast q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); } @@ -689,7 +337,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( if (n_seq_tokens == 1) { attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); } ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; @@ -698,19 +346,15 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); - - // Reshape both attn_out_final and z to 2D tensors for normalization - // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_cpy(ctx0, new_state, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // Apply gated normalization: self.norm(core_attn_out, z) - ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il); // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); @@ -721,7 +365,8 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(cur, "linear_attn_out", il); // Reshape back to original dimensions - cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; } @@ -735,7 +380,8 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, n_expert, n_expert_used, LLM_FFN_SILU, - true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, model.layers[il].ffn_gate_up_exps); cb(moe_out, "ffn_moe_out", il); // Add shared experts if present - following Qwen3Next reference implementation diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index aea8b29513e..f2621200f23 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -1,10 +1,9 @@ -#include "ggml.h" #include "models.h" -#define CHUNK_SIZE 64 +#include "llama-memory-recurrent.h" llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_build_delta_net_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; @@ -22,6 +21,8 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); + ggml_build_forward_expand(gf, cur); + // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) @@ -83,326 +84,6 @@ static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); } -std::pair llm_build_qwen3next::build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * b, - ggml_tensor * s, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(S_k == S_v); - GGML_ASSERT(H_v % H_k == 0); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); - - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); - GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - - const float scale = 1.0f / sqrtf(S_k); - - q = ggml_scale(ctx0, q, scale); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(b, "b_in", il); - cb(g, "g_in", il); - - q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] - k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] - v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] - g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [ 1, n_tokens, H_v, n_seqs] - b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [ 1, n_tokens, H_v, n_seqs] - - const int CS = CHUNK_SIZE; - - const int pad = (CS - n_tokens % CS) % CS; - const int n_chunks = (n_tokens + pad) / CS; - - q = ggml_pad(ctx0, q, 0, pad, 0, 0); - k = ggml_pad(ctx0, k, 0, pad, 0, 0); - v = ggml_pad(ctx0, v, 0, pad, 0, 0); - g = ggml_pad(ctx0, g, 0, pad, 0, 0); - b = ggml_pad(ctx0, b, 0, pad, 0, 0); - - ggml_tensor * v_b = ggml_mul(ctx0, v, b); - ggml_tensor * k_b = ggml_mul(ctx0, k, b); - - cb(v_b, "v_b", il); - cb(k_b, "k_b", il); - - q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs); - k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs); - k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs); - v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs); - v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs); - - g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs); - b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs); - - // [CS, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_cs = ggml_cumsum(ctx0, g); - cb(g_cs, "g_cs", il); - - ggml_tensor * g_cs_i = g_cs; - ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs); - - g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs); - - // [CS, CS, n_chunks, H_v * n_seqs] - ggml_tensor * decay_mask; - decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); - decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); - decay_mask = ggml_exp(ctx0, decay_mask); - cb(decay_mask, "decay_mask", il); - - // [CS, CS, n_chunks, H_k * n_seqs] - ggml_tensor * kb; - kb = ggml_mul_mat(ctx0, k, k_b); - kb = ggml_mul (ctx0, kb, decay_mask); - - // [CS, CS, n_chunks, H_k * n_seqs] - ggml_tensor * attn; - attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER); - - ggml_tensor * identity; - identity = ggml_view_1d(ctx0, attn, CS, 0); - identity = ggml_fill (ctx0, identity, 1.0f); - identity = ggml_diag (ctx0, identity); - - ggml_tensor * lhs = ggml_add(ctx0, attn, identity); - cb(lhs, "dnet_add_ch_lhs", il); - - attn = ggml_neg(ctx0, attn); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_add(ctx0, lin_solve, identity); - cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs] - - // [S_v, CS, n_chunks, H_v * n_seqs] - v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn); - - // [CS, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_exp = ggml_exp(ctx0, g_cs); - - k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b)); - - // [CS, S_k, n_chunks, H_k * n_seqs] - ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp); - cb(kbg, "k_beta_g_exp", il); - - // [S_k, CS, n_chunks, H_k * n_seqs] - ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn); - cb(k_cd, "k_cumdecay", il); - - // [S_k, CS, n_chunks, H_k * n_seqs] - ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp); - ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t); - - // [CS, CS, n_chunks, H_k * n_seqs] - ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - kq = ggml_mul(ctx0, kq, decay_mask); - kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG); - cb(kq, "kq", il); - - // vectorized calculation of key_gdiff - // improved from the chunked version: - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - // get last element in g_cumsum along CS dimension (ne0) - // example: [[x, y, z, ..., last], ...] -> [[last], ...] - // [1, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3], - g_cs->nb[1], - g_cs->nb[2], - g_cs->nb[3], - ggml_row_size(g_cs->type, g_cs->ne[0] - 1)); - cb(g_last, "g_last", il); - - // TODO: remove this cont when CUDA supports non-cont unary ops - g_last = ggml_cont(ctx0, g_last); - - // [1, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); - cb(g_last_exp, "g_last_exp", il); - - // [CS, 1, n_chunks, H_v * n_seqs] - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last)); - cb(g_diff, "g_diff", il); - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp); - - // [S_k, CS, n_chunks, H_v * n_seqs] - ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t); - cb(kg, "key_gdiff", il); - - // [CS, S_k, n_chunks, H_v * n_seqs] - ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg)); - cb(kg_t, "key_gdiff_t", il); - - ggml_tensor * s_t = ggml_transpose(ctx0, s); - s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs); - cb(s_t, "dnet_add_ch_state", il); - - // [CS, S_v, n_chunks, H_v * n_seqs] - ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v)); - - for (int64_t chunk = 0; chunk < n_chunks; chunk++) { - ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk); // [S_k, CS, 1, H_k * n_seqs] - ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk); // [ CS, S_v, 1, H_v * n_seqs] - ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk); // [ CS, CS, 1, H_k * n_seqs] - ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs] - ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs] - - // [CS, S_v, 1, H_v * n_seqs] - ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t); - cb(v_t_p, "v_prime", il); - - // [CS, S_v, 1, H_v * n_seqs] - ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p); - cb(v_t_new, "v_t_new", il); - - // [S_v, CS, 1, H_v * n_seqs] - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq); - cb(v_attn, "v_attn", il); - - // [S_v, CS, 1, H_v * n_seqs] - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp); - cb(attn_inter, "attn_inter", il); - - // [S_v, CS, 1, H_v * n_seqs] - ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn); - cb(o_ch, "dnet_add_ch_attn_out", il); - - v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]); - - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // TODO: head broadcast might not work here - probably will need a transpose - ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs] - - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk); - s_t = ggml_mul(ctx0, s_t, ch_g_last_exp); - s_t = ggml_add(ctx0, s_t, kgv); - cb(s_t, "dnet_add_ch_state", il); - } - - s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs); - - // truncate padded tokens - ggml_tensor * o = ggml_view_4d(ctx0, v, - S_v, n_tokens, H_v, n_seqs, - ggml_row_size(v->type, S_v), - ggml_row_size(v->type, S_v * CS * n_chunks), - ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0); - - o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs] - s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] - - return {o, s}; -} - -std::pair llm_build_qwen3next::build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * b, // beta - ggml_tensor * s, // state - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(n_tokens == 1); - - GGML_ASSERT(S_k == S_v); - GGML_ASSERT(H_v % H_k == 0); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); - - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); - GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - - const float scale = 1.0f / sqrtf(S_k); - - q = ggml_scale(ctx0, q, scale); - - q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] - k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] - v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(b, "b_in", il); - cb(g, "g_in", il); - - g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs); - b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs); - - // [S_v, S_v, H_v, n_seqs] - g = ggml_exp(ctx0, g); - s = ggml_mul(ctx0, s, g); - - ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s)); - - // [1, S_v, H_v, n_seqs] - ggml_tensor * sk; - sk = ggml_mul (ctx0, s_t, k); - sk = ggml_sum_rows(ctx0, sk); - - // [S_v, 1, H_v, n_seqs] - ggml_tensor * d; - d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk)); - d = ggml_mul(ctx0, d, b); - - // [1, S_v, H_v, n_seqs] - ggml_tensor * d_t; - d_t = ggml_transpose(ctx0, d); - - // [S_v, S_v, H_v, n_seqs] - ggml_tensor * kd; - k = ggml_repeat(ctx0, k, s); - kd = ggml_mul (ctx0, k, d_t); - - s_t = ggml_add(ctx0, s_t, kd); - - cb(s_t, "dnet_add_ar_state", il); - - ggml_tensor * s_q = ggml_mul (ctx0, s_t, q); - ggml_tensor * o = ggml_sum_rows(ctx0, s_q); - - o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs] - s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] - - return {o, s}; -} - ggml_tensor * llm_build_qwen3next::build_norm_gated( ggml_tensor * input, ggml_tensor * weights, @@ -627,8 +308,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * beta = ggml_sigmoid(ctx0, b); - beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs); - // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads] ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs); @@ -639,6 +318,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus cb(gate, "gate", il); + beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); + gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); + // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); @@ -674,7 +356,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(state_update_target, "state_update_target", il); ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); - cb(conv_states_all, "conv_states_updated", il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -798,7 +479,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, n_expert, n_expert_used, LLM_FFN_SILU, - true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, model.layers[il].ffn_gate_up_exps); cb(moe_out, "ffn_moe_out", il); // Add shared experts if present - following Qwen3Next reference implementation diff --git a/examples/talk-llama/models/rwkv6-base.cpp b/examples/talk-llama/models/rwkv6-base.cpp index 7beed2daffb..83aeab7280b 100644 --- a/examples/talk-llama/models/rwkv6-base.cpp +++ b/examples/talk-llama/models/rwkv6-base.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {} diff --git a/examples/talk-llama/models/rwkv7-base.cpp b/examples/talk-llama/models/rwkv7-base.cpp index cda44653849..7fcab77745c 100644 --- a/examples/talk-llama/models/rwkv7-base.cpp +++ b/examples/talk-llama/models/rwkv7-base.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {} diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index b88d953bd27..1475b53b659 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -769,6 +769,12 @@ static std::vector unicode_regex_split_custom(const std::string & text, } else if (regex_expr == "\\p{AFMoE_digits}") { // AFMOE digit pattern - use custom implementation for proper splitting bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets); + } else if (regex_expr == "\\d{1,3}(?=(?:\\d{3})*\\b)") { + // tiny_aya digit grouping pattern from tokenizer.json: + // {"type": "Split", "pattern": {"Regex": "\\d{1,3}(?=(?:\\d{3})*\\b)"}, "behavior": "Isolated"} + // Splits digits into groups of 3 from the right (e.g., 1234567 -> 1, 234, 567) + // TODO: Revisit this regex, incase there are any subtle tokenization differences with the original regex. + bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets); } return bpe_offsets; From aaf8bdf3b8e48b4b2c28b35865691b0bd3b7df07 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 27 Feb 2026 12:24:33 +0200 Subject: [PATCH 195/831] scripts : sync gguf --- scripts/sync-ggml-am.sh | 4 +++- scripts/sync-ggml.sh | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh index 1f87e23122b..bc7c1b2fe15 100755 --- a/scripts/sync-ggml-am.sh +++ b/scripts/sync-ggml-am.sh @@ -60,8 +60,8 @@ while read c; do cmake/common.cmake \ cmake/ggml-config.cmake.in \ src/ggml-cpu/cmake/FindSIMD.cmake \ - src/ggml*.h \ src/ggml* \ + src/gguf* \ include/ggml*.h \ include/gguf*.h \ examples/common.h \ @@ -105,6 +105,7 @@ if [ -f $SRC_WHISPER/ggml-src.patch ]; then # src/ggml-cpu/cmake/FindSIMD.cmake -> ggml/src/ggml-cpu/cmake/FindSIMD.cmake # # src/ggml* -> ggml/src/ggml*.c + # src/gguf* -> ggml/src/gguf*.c # # include/ggml*.h -> ggml/include/ggml*.h # include/gguf*.h -> ggml/include/gguf*.h @@ -126,6 +127,7 @@ if [ -f $SRC_WHISPER/ggml-src.patch ]; then -e 's/(^[[:space:]]| [ab]\/)cmake\/ggml-config.cmake.in/\1ggml\/cmake\/ggml-config.cmake.in/g' \ -e 's/(^[[:space:]]| [ab]\/)src\/ggml-cpu\/cmake\/FindSIMD.cmake/\1ggml\/src\/ggml-cpu\/cmake\/FindSIMD.cmake/g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)/\1ggml\/src\/ggml\2/g' \ + -e 's/([[:space:]]| [ab]\/)src\/gguf(.*)/\1ggml\/src\/gguf\2/g' \ -e 's/(^[[:space:]]| [ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \ -e 's/(^[[:space:]]| [ab]\/)include\/gguf(.*)\.h/\1ggml\/include\/gguf\2.h/g' \ -e 's/(^[[:space:]]| [ab]\/)examples\/common\.h/\1examples\/common.h/g' \ diff --git a/scripts/sync-ggml.sh b/scripts/sync-ggml.sh index 4296ddf5f50..099d5445c8c 100755 --- a/scripts/sync-ggml.sh +++ b/scripts/sync-ggml.sh @@ -7,6 +7,7 @@ cp -rpv ../ggml/cmake/* ./ggml/cmake/ cp -rpv ../ggml/src/ggml-cpu/cmake/* ./ggml/src/ggml-cpu/cmake/ cp -rpv ../ggml/src/ggml* ./ggml/src/ +cp -rpv ../ggml/src/gguf* ./ggml/src/ cp -rpv ../ggml/include/ggml*.h ./ggml/include/ cp -rpv ../ggml/include/gguf*.h ./ggml/include/ From 9453b4b9be9b73adfc35051083f37cefa039acee Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 27 Feb 2026 12:24:59 +0200 Subject: [PATCH 196/831] gguf : sync (ggml/0) --- ggml/src/gguf.cpp | 273 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 229 insertions(+), 44 deletions(-) diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index 53504399c57..cbeedf6c4b6 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -15,6 +15,17 @@ #include #include +#define GGUF_MAX_STRING_LENGTH (1024*1024*1024) +#define GGUF_MAX_ARRAY_ELEMENTS (1024*1024*1024) + +#ifdef _WIN32 +# define gguf_ftell _ftelli64 +# define gguf_fseek _fseeki64 +#else +# define gguf_ftell ftello +# define gguf_fseek fseeko +#endif + template struct type_to_gguf_type; @@ -217,17 +228,64 @@ struct gguf_context { }; struct gguf_reader { - FILE * file; + gguf_reader(FILE * file) : file(file) { + // read the remaining bytes once and update on each read + nbytes_remain = file_remain(file); + } - gguf_reader(FILE * file) : file(file) {} + // helper for remaining bytes in a file + static uint64_t file_remain(FILE * file) { + const int64_t cur = gguf_ftell(file); + if (cur < 0) { + return 0; + } + if (gguf_fseek(file, 0, SEEK_END) != 0) { + gguf_fseek(file, cur, SEEK_SET); + + return 0; + } + const int64_t end = gguf_ftell(file); + if (end < 0) { + gguf_fseek(file, cur, SEEK_SET); + + return 0; + } + gguf_fseek(file, cur, SEEK_SET); + return static_cast(end - cur); + } template bool read(T & dst) const { - return fread(&dst, 1, sizeof(dst), file) == sizeof(dst); + const size_t size = sizeof(dst); + if (nbytes_remain < size) { + return false; + } + const size_t nread = fread(&dst, 1, size, file); + nbytes_remain -= nread; + return nread == size; } template bool read(std::vector & dst, const size_t n) const { + if (n > GGUF_MAX_ARRAY_ELEMENTS) { + return false; + } + if constexpr (std::is_same::value) { + // strings are prefixed with their length, so we need to account for that + if (n > SIZE_MAX / sizeof(uint64_t)) { + return false; + } + if (nbytes_remain < n * sizeof(uint64_t)) { + return false; + } + } else { + if (n > SIZE_MAX / sizeof(T)) { + return false; + } + if (nbytes_remain < n * sizeof(T)) { + return false; + } + } dst.resize(n); for (size_t i = 0; i < dst.size(); ++i) { if constexpr (std::is_same::value) { @@ -273,17 +331,37 @@ struct gguf_reader { } bool read(std::string & dst) const { - uint64_t size = -1; + uint64_t size = 0; if (!read(size)) { return false; } - dst.resize(size); - return fread(dst.data(), 1, dst.length(), file) == dst.length(); + if (size > GGUF_MAX_STRING_LENGTH) { + GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds maximum %" PRIu64 "\n", __func__, size, (uint64_t) GGUF_MAX_STRING_LENGTH); + return false; + } + if (size > nbytes_remain) { + GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds remaining file size %" PRIu64 " bytes\n", __func__, size, nbytes_remain); + return false; + } + dst.resize(static_cast(size)); + const size_t nread = fread(dst.data(), 1, size, file); + nbytes_remain -= nread; + return nread == size; } bool read(void * dst, const size_t size) const { - return fread(dst, 1, size, file) == size; + if (size > nbytes_remain) { + return false; + } + const size_t nread = fread(dst, 1, size, file); + nbytes_remain -= nread; + return nread == size; } + +private: + FILE * file; + + mutable uint64_t nbytes_remain; }; struct gguf_context * gguf_init_empty(void) { @@ -523,7 +601,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // tensor shape { - uint32_t n_dims = -1; + uint32_t n_dims = 0; ok = ok && gr.read(n_dims); if (n_dims > GGML_MAX_DIMS) { GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", @@ -568,8 +646,8 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // check that tensor type is within defined range if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) { - GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n", - __func__, info.t.name, info.t.type, ggml_type_name(info.t.type)); + GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d. should be in [0, %d)\n", + __func__, info.t.name, info.t.type, GGML_TYPE_COUNT); ok = false; break; } @@ -585,6 +663,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par break; } + // check that the size of the tensor in bytes is representable + if (ok && uint64_t(ggml_nelements(&info.t)/ggml_blck_size(info.t.type)) > SIZE_MAX/ggml_type_size(info.t.type)) { + GGML_LOG_ERROR("%s: tensor '%s' with shape (%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") has a size in bytes > %zu\n", + __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], SIZE_MAX); + ok = false; + break; + } + // calculate byte offsets given the tensor shape and type info.t.nb[0] = type_size; info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size); @@ -610,14 +696,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors); // we require the data section to be aligned, so take into account any padding - if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) { + if (gguf_fseek(file, GGML_PAD(gguf_ftell(file), ctx->alignment), SEEK_SET) != 0) { GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__); gguf_free(ctx); return nullptr; } // store the current file offset - this is where the data section starts - ctx->offset = ftell(file); + ctx->offset = gguf_ftell(file); // compute the total size of the data section, taking into account the alignment { @@ -649,10 +735,34 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // the ggml_tensor structs to the appropriate locations in the binary blob // compute the exact size needed for the new ggml_context - const size_t mem_size = - params.no_alloc ? - (n_tensors )*ggml_tensor_overhead() : - (n_tensors + 1)*ggml_tensor_overhead() + ctx->size; + size_t mem_size = 0; + if (params.no_alloc) { + if (n_tensors != 0 && SIZE_MAX / n_tensors < ggml_tensor_overhead()) { + GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__); + gguf_free(ctx); + return nullptr; + } + + const size_t overhead = n_tensors * ggml_tensor_overhead(); + + mem_size = overhead; + } else { + if ((n_tensors + 1) != 0 && SIZE_MAX / (n_tensors + 1) < ggml_tensor_overhead()) { + GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__); + gguf_free(ctx); + return nullptr; + } + + const size_t overhead = (n_tensors + 1) * ggml_tensor_overhead(); + + if (SIZE_MAX - overhead < ctx->size) { + GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__); + gguf_free(ctx); + return nullptr; + } + + mem_size = overhead + ctx->size; + } struct ggml_init_params pdata = { /*mem_size =*/ mem_size, @@ -734,7 +844,7 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p FILE * file = ggml_fopen(fname, "rb"); if (!file) { - GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname); + GGML_LOG_ERROR("%s: failed to open GGUF file '%s' (%s)\n", __func__, fname, strerror(errno)); return nullptr; } @@ -1166,50 +1276,51 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const } -struct gguf_writer { - std::vector & buf; +struct gguf_writer_base { + size_t written_bytes {0u}; + + ~gguf_writer_base(void) = default; - gguf_writer(std::vector & buf) : buf(buf) {} + // we bet on devirtualization + virtual void write(int8_t val) = 0; + virtual void write(const std::vector & val) = 0; + virtual void write_tensor_data(const struct gguf_tensor_info & info, size_t offset_data, size_t alignment) = 0; template - void write(const T & val) const { + void write(const T & val) { for (size_t i = 0; i < sizeof(val); ++i) { - buf.push_back(reinterpret_cast(&val)[i]); + write(reinterpret_cast(&val)[i]); } } - void write(const std::vector & val) const { - buf.insert(buf.end(), val.begin(), val.end()); - } - - void write(const bool & val) const { + void write(const bool & val) { const int8_t val8 = val ? 1 : 0; write(val8); } - void write(const std::string & val) const { + void write(const std::string & val) { { const uint64_t n = val.length(); write(n); } for (size_t i = 0; i < val.length(); ++i) { - buf.push_back(reinterpret_cast(val.data())[i]); + write((val.data())[i]); } } - void write(const char * val) const { + void write(const char * val) { write(std::string(val)); } - void write(const enum ggml_type & val) const { + void write(const enum ggml_type & val) { write(int32_t(val)); } - void write(const enum gguf_type & val) const { + void write(const enum gguf_type & val) { write(int32_t(val)); } - void write(const struct gguf_kv & kv) const { + void write(const struct gguf_kv & kv) { const uint64_t ne = kv.get_ne(); write(kv.get_key()); @@ -1250,7 +1361,7 @@ struct gguf_writer { } } - void write_tensor_meta(const struct gguf_tensor_info & info) const { + void write_tensor_meta(const struct gguf_tensor_info & info) { write(info.t.name); const uint32_t n_dims = ggml_n_dims(&info.t); @@ -1263,14 +1374,33 @@ struct gguf_writer { write(info.offset); } - void pad(const size_t alignment) const { - while (buf.size() % alignment != 0) { + void pad(const size_t alignment) { + while (written_bytes % alignment != 0) { const int8_t zero = 0; write(zero); } } +}; + +// vector buffer based writer +struct gguf_writer_buf final : public gguf_writer_base { + std::vector & buf; + + gguf_writer_buf(std::vector & buf) : buf(buf) {} - void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const { + using gguf_writer_base::write; + + void write(const int8_t val) override { + buf.push_back(val); + written_bytes++; + } + + void write(const std::vector & val) override { + buf.insert(buf.end(), val.begin(), val.end()); + written_bytes += val.size(); + } + + void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override { GGML_ASSERT(buf.size() - offset_data == info.offset); GGML_ASSERT(ggml_is_contiguous(&info.t)); @@ -1284,14 +1414,58 @@ struct gguf_writer { GGML_ASSERT(info.t.data); memcpy(buf.data() + offset, info.t.data, nbytes); } + written_bytes += nbytes; pad(alignment); } }; -void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta) { - const struct gguf_writer gw(buf); +// file based writer +struct gguf_writer_file final : public gguf_writer_base { + FILE * file; + + gguf_writer_file(FILE* file) : file(file) {} + + using gguf_writer_base::write; + + void write(const int8_t val) override { + const auto real_val = static_cast(val); + const auto ret = fputc(real_val, file); + written_bytes++; + if (ret != real_val) { + throw std::runtime_error("unexpected fputc result '" + std::to_string(ret) + "' instead of '" + std::to_string((int)real_val) + "'"); + } + } + + void write(const std::vector & val) override { + const auto ret = fwrite(val.data(), 1, val.size(), file); + written_bytes += val.size(); + if (ret != val.size()) { + throw std::runtime_error("unexpected fwrite number of bytes written, '" + std::to_string(ret) + "' instead of '" + std::to_string(val.size()) + "'"); + } + } + + void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override { + GGML_ASSERT(written_bytes - offset_data == info.offset); + + GGML_ASSERT(ggml_is_contiguous(&info.t)); + const size_t nbytes = ggml_nbytes(&info.t); + + std::vector buf(nbytes); + if (info.t.buffer) { + ggml_backend_tensor_get(&info.t, buf.data(), 0, nbytes); + } else { + GGML_ASSERT(info.t.data); + memcpy(buf.data(), info.t.data, nbytes); + } + write(buf); + pad(alignment); + } +}; + +template +static void gguf_write_out(const struct gguf_context * ctx, writer_t & gw, bool only_meta) { const int64_t n_kv = gguf_get_n_kv(ctx); const int64_t n_tensors = gguf_get_n_tensors(ctx); @@ -1321,7 +1495,7 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & bu return; } - const size_t offset_data = gw.buf.size(); + const size_t offset_data = gw.written_bytes; // write tensor data for (int64_t i = 0; i < n_tensors; ++i) { @@ -1329,6 +1503,11 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & bu } } +void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta) { + gguf_writer_buf gw(buf); + gguf_write_out(ctx, gw, only_meta); +} + bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { FILE * file = ggml_fopen(fname, "wb"); @@ -1337,11 +1516,17 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo return false; } - std::vector buf; - gguf_write_to_buf(ctx, buf, only_meta); - const bool ok = fwrite(buf.data(), 1, buf.size(), file) == buf.size(); + try { + gguf_writer_file gw(file); + gguf_write_out(ctx, gw, only_meta); + } catch (const std::runtime_error& ex) { + GGML_LOG_ERROR("%s: failed to write GGUF data into '%s': %s\n", __func__, fname, ex.what()); + fclose(file); + return false; + } + fclose(file); - return ok; + return true; } size_t gguf_get_meta_size(const struct gguf_context * ctx) { From 30c5194c9691e4e9a98b3dea9f19727397d3f46e Mon Sep 17 00:00:00 2001 From: KITAITI Makoto Date: Thu, 5 Mar 2026 14:36:42 +0900 Subject: [PATCH 197/831] ruby : null-check (#3689) * Introduce null-check to prevent SEGV * Fix error message --- bindings/ruby/ext/ruby_whisper_context.c | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index a8118d12773..c39d43bd76c 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -304,11 +304,11 @@ VALUE ruby_whisper_model_type(VALUE self) static bool check_memory_view(rb_memory_view_t *memview) { - if (strcmp(memview->format, "f") != 0) { + if (memview->format != NULL && strcmp(memview->format, "f") != 0) { rb_warn("currently only format \"f\" is supported for MemoryView, but given: %s", memview->format); return false; } - if (memview->ndim != 1) { + if (memview->format != NULL && memview->ndim != 1) { rb_warn("currently only 1 dimensional MemoryView is supported, but given: %zd", memview->ndim); return false; } @@ -377,7 +377,7 @@ parse_samples(VALUE *samples, VALUE *n_samples) } parsed.n_samples = (int)n_samples_size; } else { - rb_warn("unable to get a memory view. fallbacks to Ruby object"); + rb_warn("unable to get a memory view. falls back to Ruby object"); if (rb_respond_to(*samples, id_length)) { parsed.n_samples = NUM2INT(rb_funcall(*samples, id_length, 0)); } else { From b524b5a1f0af8cf78da05d15835afc96e25a7449 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 27 Feb 2026 18:15:09 +0800 Subject: [PATCH 198/831] ggml-cpu: add repack for mxfp4 (llama/19738) --- ggml/src/ggml-cpu/arch-fallback.h | 28 +++ ggml/src/ggml-cpu/arch/arm/repack.cpp | 156 +++++++++++++ ggml/src/ggml-cpu/arch/x86/repack.cpp | 104 ++++++++- ggml/src/ggml-cpu/repack.cpp | 318 ++++++++++++++++++++++++++ ggml/src/ggml-cpu/repack.h | 21 ++ 5 files changed, 625 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 4dfe28e1d64..ebbd4b47e05 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -48,6 +48,8 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -62,6 +64,8 @@ #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) @@ -69,8 +73,10 @@ #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // repack.cpp @@ -84,6 +90,7 @@ #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -94,6 +101,7 @@ #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__POWERPC__) || defined(__powerpc__) @@ -120,6 +128,8 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -134,6 +144,8 @@ #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__loongarch64) @@ -160,6 +172,8 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -174,6 +188,8 @@ #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__riscv) @@ -201,6 +217,8 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -214,6 +232,8 @@ #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__s390x__) @@ -246,6 +266,8 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -260,6 +282,8 @@ #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__wasm__) @@ -294,6 +318,8 @@ #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -308,6 +334,8 @@ #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #endif diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index c2e4623f371..3eed0105bf1 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -498,6 +498,81 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float * res_ptr = s; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb); + + float32x4_t sumf = vdupq_n_f32(0); + for (int l = 0; l < nb; l++) { + uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0); + uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16); + uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32); + uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48); + + int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4); + int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F); + int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4); + int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F); + int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4); + int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F); + int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4); + int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F); + + int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0); + int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16); + + int32x4_t sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0); + sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0); + sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1); + sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1); + sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2); + sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2); + sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3); + sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3); + + float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d)); + float32x4_t b_d = { + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]), + }; + float32x4_t d = a_d * b_d; + + sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi)); + } + + vst1q_f32(res_ptr + x * 4, sumf); + } + return; +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + ggml_gemv_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { constexpr int qk = QK_K; const int nb = n / qk; @@ -3164,6 +3239,87 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb); + + float32x4_t sumf[4]; + for (int m = 0; m < 4; m++) { + sumf[m] = vdupq_n_f32(0); + } + + for (int l = 0; l < nb; l++) { + float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d)); + float32x4_t b_d = { + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]), + }; + + int32x4_t sumi_0 = vdupq_n_s32(0); + int32x4_t sumi_1 = vdupq_n_s32(0); + int32x4_t sumi_2 = vdupq_n_s32(0); + int32x4_t sumi_3 = vdupq_n_s32(0); + + for (int k = 0; k < 4; k++) { + int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0); + int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64); + + uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k); + int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4); + int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF); + + sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3); + sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3); + } + + sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0)); + sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1)); + sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2)); + sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3)); + } + + for (int m = 0; m < 4; m++) { + vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]); + } + } + } + return; +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + ggml_gemm_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { constexpr int qk = QK_K; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index 7dda9eea0c5..bd6906c4159 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -522,7 +522,8 @@ template static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) { static_assert( std::is_same_v || - std::is_same_v, + std::is_same_v || + std::is_same_v, "Unsupported block type"); const int qk = QK8_0; @@ -580,6 +581,18 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t std::is_same_v || std::is_same_v) { col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); + } else if constexpr (std::is_same_v) { + // Load 8 E8M0 exponents and convert to float via LUT + // Rearranged to match changemask order: 0,4,1,5,2,6,3,7 + col_scale_f32 = _mm256_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0])); } // Load and convert to FP32 scale from block_q8_0 @@ -628,7 +641,8 @@ template static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) { static_assert( std::is_same_v || - std::is_same_v, + std::is_same_v || + std::is_same_v, "Unsupported block type"); const int qk = QK8_0; @@ -749,6 +763,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t std::is_same_v || std::is_same_v) { col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + } else if constexpr (std::is_same_v) { + //TODO: simd-ify + col_scale_f32 = _mm512_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0])); } // Process LHS in pairs of rows @@ -941,6 +974,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t std::is_same_v || std::is_same_v) { col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + } else if constexpr (std::is_same_v) { + //TODO: simd-ify + col_scale_f32 = _mm512_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0])); } // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 @@ -1123,6 +1175,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t std::is_same_v || std::is_same_v) { col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); + } else if constexpr (std::is_same_v) { + col_scale_f32 = _mm256_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0])); } // Process LHS in groups of four @@ -1283,6 +1345,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t std::is_same_v || std::is_same_v) { col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); + } else if constexpr (std::is_same_v) { + col_scale_f32 = _mm256_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0])); } // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 @@ -1625,6 +1697,19 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined(__AVX2__) + __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + + gemv_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut); + + return; +#endif + + ggml_gemv_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -3423,6 +3508,21 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined(__AVX2__) || defined(__AVX512F__) + { + __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + + gemm_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut); + + return; + } +#endif // defined(__AVX2__) || defined(__AVX512F__) + + ggml_gemm_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 1b3d23cbedc..5edba4212f6 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1098,6 +1098,82 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, @@ -1726,6 +1802,94 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); + } + sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][8]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); + } + sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, @@ -2510,6 +2674,121 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b GGML_UNUSED(data_size); } + +static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) { + block_mxfp4x4 out; + + for (int i = 0; i < 4; i++) { + out.e[i] = in[i].e; + } + + const int end = QK_MXFP4 * 2 / blck_size_interleave; + + if (blck_size_interleave == 4) { + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t)); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + GGML_ASSERT(interleave_block == 4); + + const block_mxfp4 * src = (const block_mxfp4 *)data; + block_mxfp4x4 * dst = ( block_mxfp4x4 *)t->data; + + block_mxfp4 dst_tmp[4]; + + int nrow = ggml_nrows(t); + int nrows_interleaved = 4; + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_mxfp4x4(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size_interleave) { + block_mxfp4x8 out; + + for (int i = 0; i < 8; i++) { + out.e[i] = in[i].e; + } + + const int end = QK_MXFP4 * 4 / blck_size_interleave; + + if (blck_size_interleave == 8) { + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_mxfp4_to_mxfp4_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + GGML_ASSERT(interleave_block == 8); + + const block_mxfp4 * src = (const block_mxfp4 *)data; + block_mxfp4x8 * dst = ( block_mxfp4x8 *)t->data; + + block_mxfp4 dst_tmp[8]; + + int nrow = ggml_nrows(t); + int nrows_interleaved = 8; + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4)); + + if (t->ne[1] % nrows_interleaved != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_mxfp4x8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + namespace ggml::cpu::repack { // repack template @@ -2569,6 +2848,14 @@ template <> int repack(struct ggml_tensor * t, const void * return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_4_bl(t, 4, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_8_bl(t, 8, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size); } @@ -2636,6 +2923,14 @@ template <> void gemv(int n, float * s, size ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -2703,6 +2998,14 @@ template <> void gemm(int n, float * s, size ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -3111,6 +3414,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits iq4_nl_4x4_q8_0; static const ggml::cpu::repack::tensor_traits iq4_nl_8x8_q8_0; + // instance for MXFP4 + static const ggml::cpu::repack::tensor_traits mxfp4_4x4_q8_0; + static const ggml::cpu::repack::tensor_traits mxfp4_8x8_q8_0; + // instance for Q8_0 static const ggml::cpu::repack::tensor_traits q8_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits q8_0_4x8_q8_0; @@ -3187,6 +3494,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &iq4_nl_4x4_q8_0; } } + } else if (cur->type == GGML_TYPE_MXFP4) { + if (ggml_cpu_has_avx2()) { + if (cur->ne[1] % 8 == 0) { + return &mxfp4_8x8_q8_0; + } + } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 4 == 0) { + return &mxfp4_4x4_q8_0; + } + } } else if (cur->type == GGML_TYPE_Q8_0) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { if (cur->ne[1] % 4 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index ddf03d7642d..b9f821630c4 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -97,6 +97,19 @@ struct block_iq4_nlx8 { static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding"); +struct block_mxfp4x4 { + uint8_t e[4]; + uint8_t qs[QK_MXFP4 * 2]; +}; +static_assert(sizeof(block_mxfp4x4) == 4 + QK_MXFP4 * 2, "wrong mxfp4x4 block size/padding"); + +struct block_mxfp4x8 { + uint8_t e[8]; + uint8_t qs[QK_MXFP4 * 4]; +}; +static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding"); + + #if defined(__cplusplus) extern "C" { #endif @@ -117,6 +130,8 @@ void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -129,6 +144,8 @@ void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -151,6 +168,8 @@ void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -163,6 +182,8 @@ void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); From 699eaf3a102551e033a5e8b342108936ef604d33 Mon Sep 17 00:00:00 2001 From: Jayant Lohia Date: Sat, 28 Feb 2026 00:07:26 +0530 Subject: [PATCH 199/831] CUDA: add CDNA3 MFMA support for flash attention MMA kernel (llama/19806) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CUDA: add CDNA3 MFMA support for flash attention MMA kernel Add MI300X (gfx942) MFMA tensor core flash attention using v_mfma_f32_16x16x16_f16 (FP16 in, FP32 accumulate). - Add FATTN_WARP_SIZE=64 for CDNA wavefront64 - Add CDNA config for head sizes 64, 80, 96, 112, 128 - Add FP16 MFMA intrinsic path in mma.cuh - Add manual V transpose load for MFMA register layout - Route CDNA to MMA for prompt processing, VEC for token generation - Fix Q loading and combine stride granularity for non-power-of-2 heads Benchmarks (Qwen2.5-1.5B Q4_K_M, MI300X): pp512 +7%, pp1024 +13%, pp2048 +23%, pp4096 +39% tg128 -10% (FA overhead, VEC used for both) All 2480 flash attention tests pass. Ref: https://github.com/ggml-org/llama.cpp/issues/17917 * address review: replace FATTN_WARP_SIZE with constexpr, improve dispatch - Replace #define FATTN_WARP_SIZE with constexpr int warp_size = ggml_cuda_get_physical_warp_size() in each device function - Use ne[1]*gqa_ratio threshold for MMA vs tile dispatch. Benchmarked crossover on MI300X @ d32768 with power-of-2 GQA models: hsk=64 (Llama 1B, gqa=4): MMA wins at eff >= 128 (+11%) hsk=128 (Llama 3B, gqa=4): MMA wins at eff >= 128 (+4%) Unified threshold: eff_nq >= 128 for all head sizes. - Remove VEC fallback; small batches fall through to tile kernel * Update ggml/src/ggml-cuda/fattn.cu * use ggml_cuda_info().devices warp_size instead of hardcoded check --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 246 ++++++++++++++++++--------- ggml/src/ggml-cuda/fattn.cu | 12 ++ ggml/src/ggml-cuda/mma.cuh | 30 +++- 3 files changed, 203 insertions(+), 85 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 0b8ef90794c..beb7e32e4fc 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -111,6 +111,44 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); } +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) { + // Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async). + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true); + + // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy + // compile-time static_asserts even though the kernel guard prevents runtime execution. + // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility. + return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false); +} + static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { if (ampere_mma_available(cc)) { return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); @@ -118,6 +156,9 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c if (turing_mma_available(cc)) { return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); } + if (amd_mfma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols); + } if (amd_wmma_available(cc)) { return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols); } @@ -130,6 +171,8 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); #elif defined(TURING_MMA_AVAILABLE) return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); +#elif defined(AMD_MFMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols); #elif defined(VOLTA_MMA_AVAILABLE) return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); #elif defined(AMD_WMMA_AVAILABLE) @@ -205,15 +248,15 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, } static constexpr __device__ int get_cols_per_thread() { -#if defined(AMD_WMMA_AVAILABLE) - return 1; // RDNA has a single column. +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + return 1; // AMD has a single column per thread. #else return 2; // This is specifically KQ columns, Volta only has a single VKQ column. -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } static __host__ int get_cols_per_warp(const int cc) { - if (turing_mma_available(cc) || amd_wmma_available(cc)) { + if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) { return 16; } else { // Volta @@ -241,6 +284,7 @@ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, c template static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); // K/V data is loaded with decreasing granularity for D for better memory bandwidth. // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. if constexpr (use_cp_async) { @@ -252,10 +296,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); auto load = [&] __device__ (auto n) { - const int stride_k = WARP_SIZE >> n; - const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int stride_k = warp_size >> n; + const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + const int stride_i = warp_size / stride_k; if (k0_start == k0_stop) { return; @@ -263,7 +307,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k); if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { break; @@ -271,7 +315,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); } @@ -287,10 +331,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( } else { // TODO use ggml_cuda_memcpy_1 auto load = [&] __device__ (const int n) { - const int stride_k = WARP_SIZE >> n; - const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); + const int stride_k = warp_size >> n; + const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k); const int k0_stop = D2 - D2 % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + const int stride_i = warp_size / stride_k; if (k0_start == k0_stop) { return; @@ -298,7 +342,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k); if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { break; @@ -306,7 +350,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); } @@ -324,18 +368,19 @@ template= 32 ? nbatch_fa * sizeof(half) : 64; - constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; + constexpr int cols_per_warp = 8*warp_size/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); #pragma unroll for (int j1 = 0; j1 < ncols1; j1 += stride_j) { - const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp); const int j_vram = fastmodulo(j0 + j_sram, ne01); if (j1 + stride_j > ncols1 && j_sram >= ncols1) { @@ -357,25 +402,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } #pragma unroll - for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) { const int i = i0 + threadIdx.x; tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f); } } - } else if constexpr (nbatch_fa < 2*WARP_SIZE) { - constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; + } else if constexpr (nbatch_fa < 2*warp_size) { + constexpr int cols_per_warp = 2*warp_size/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; #pragma unroll for (int j1 = 0; j1 < ncols1; j1 += stride_j) { - const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp); const int j_vram = fastmodulo(j0 + j_sram, ne01); if (j1 + stride_j > ncols1 && j_sram >= ncols1) { break; } - const int i = threadIdx.x % (WARP_SIZE/cols_per_warp); + const int i = threadIdx.x % (warp_size/cols_per_warp); ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i); } @@ -390,7 +435,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } #pragma unroll - for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) { + for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) { const int i = i0 + 2*threadIdx.x; ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i); @@ -428,7 +473,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int jt, const int kb0, const int k_VKQ_sup) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = T_B_KQ::I; constexpr int cols_per_thread = get_cols_per_thread(); @@ -447,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k_VKQ_0 = kb0 * nbatch_fa; #if defined(TURING_MMA_AVAILABLE) T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))]; -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; #else // Volta T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; @@ -500,13 +546,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); } else { // Wide version of KQ_C is column-major -#if defined(AMD_WMMA_AVAILABLE) - // RDNA matrix C is column-major. +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + // AMD matrix C is column-major. mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); #else // swap A and B for CUDA. mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A); -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } } @@ -526,13 +572,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); } else { // Wide version of KQ_C is column-major -#if defined(AMD_WMMA_AVAILABLE) - // RDNA matrix C is column-major. +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + // AMD matrix C is column-major. mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); #else // swap A and B for CUDA. mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } } @@ -585,12 +631,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int KQ_idx = 0; #else // Turing + Volta: const int KQ_idx = l % 2; -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET); } } @@ -601,7 +647,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = 16; offset >= 4; offset >>= 1) { - KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size)); } } @@ -611,12 +657,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int KQ_idx = 0; #else // Turing + Volta: const int KQ_idx = l % 2; -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]); KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; } else { @@ -649,12 +695,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int KQ_idx = 0; #else // Turing + Volta: const int KQ_idx = (l/2) % 2; -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET); } } @@ -666,6 +712,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Values per KQ column are spread across 4 threads: constexpr int offset_first = 2; constexpr int offset_last = 1; +#elif defined(AMD_MFMA_AVAILABLE) + // MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16). + constexpr int offset_first = 32; + constexpr int offset_last = 16; #elif defined(AMD_WMMA_AVAILABLE) // Values per KQ column are spread across 2 threads: constexpr int offset_first = 16; @@ -677,7 +727,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // defined(TURING_MMA_AVAILABLE) #pragma unroll for (int offset = offset_first; offset >= offset_last; offset >>= 1) { - KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size)); } } @@ -687,12 +737,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int KQ_idx = 0; #else // Turing + Volta: const int KQ_idx = (l/2) % 2; -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]); KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; } else { @@ -739,7 +789,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } } -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) const half2 KQ_max_scale_h2 = make_half2( KQ_max_scale[0], KQ_max_scale[0]); #pragma unroll @@ -818,7 +868,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2; -#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; #pragma unroll for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { @@ -830,24 +880,38 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. #if defined(LDMATRIX_TRANS_AVAILABLE) load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); +#elif defined(AMD_MFMA_AVAILABLE) + // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg]. + // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T. + // Load with transposed addressing: 4 strided half loads. + { + const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2; + const half * xs0_h = (const half *) xs0; + const int stride_h = stride_tile_V * 2; // stride in half units + half * A_h = (half *) A.x; +#pragma unroll + for (int l = 0; l < 4; ++l) { + A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16]; + } + } #else // TODO: Try to transpose tile_V when loading gmem to smem. // Use mma to transpose T_A_VKQ for RDNA. T_A_VKQ A_trans; load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); mma(A, A_trans, A_identity); -#endif // defined(TURING_MMA_AVAILABLE) +#endif // defined(LDMATRIX_TRANS_AVAILABLE) if constexpr (T_B_KQ::I == 8) { mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { // Wide version of VKQ_C is column-major. -#if defined(AMD_WMMA_AVAILABLE) - // RDNA matrix C is column-major. +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + // AMD matrix C is column-major. mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); #else // swap A and B for CUDA. mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } } @@ -866,7 +930,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A); } } -#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. @@ -879,7 +943,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) } #if defined(TURING_MMA_AVAILABLE) @@ -899,7 +963,7 @@ template<> struct mma_tile_sizes<8> { using T_B_VKQ = tile< 8, 8, half2>; // column-major using T_C_VKQ = tile<16, 4, half2>; // row-major }; -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile<16, 8, half2>; // column-major @@ -944,9 +1008,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int zt_gqa, const int kb0_start, const int kb0_stop) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; using T_A_KQ = typename mma_tile_sizes::T_A_KQ; using T_B_KQ = typename mma_tile_sizes::T_B_KQ; @@ -986,7 +1051,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; #if defined(TURING_MMA_AVAILABLE) T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; #else // Volta T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; @@ -1004,10 +1069,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // The loading is done with decreasing granularity for D for better memory bandwidth. const half2 scale_h2 = make_half2(scale, scale); #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); + for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) { + const int k0_start = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k); - const int stride_jc = WARP_SIZE / stride_k; + const int stride_jc = warp_size / stride_k; if (k0_start == k0_stop) { continue; @@ -1015,7 +1080,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) { - const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k); if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) { break; @@ -1027,7 +1092,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y); @@ -1035,7 +1100,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } else { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f); } @@ -1127,6 +1192,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // The partial sums are spread across 8/4 threads. constexpr int offset_first = cols_per_warp == 8 ? 16 : 2; constexpr int offset_last = cols_per_warp == 8 ? 4 : 1; +#elif defined(AMD_MFMA_AVAILABLE) + // The partial sums are spread across 4 threads (wavefront64, 16 cols). + constexpr int offset_first = 32; + constexpr int offset_last = 16; #elif defined(AMD_WMMA_AVAILABLE) // The partial sums are spread across 2 threads. constexpr int offset_first = 16; @@ -1140,7 +1209,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = offset_first; offset >= offset_last; offset >>= 1) { - KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size); } } } @@ -1189,7 +1258,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } } -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); #pragma unroll for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { @@ -1249,7 +1318,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4); const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); const bool thread_should_write = threadIdx.x % 4 < cols_per_thread; -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0); const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]); const bool thread_should_write = threadIdx.x / 16 < cols_per_thread; @@ -1283,14 +1352,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Warps with threadIdx.y % np != 0 must NOT return early. // All threads must return simultaneously to avoid race conditions with work on the next tile. - constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; + constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1; - const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); + const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; float2 meta[nmeta]; #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { - meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2]; + meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2]; } float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. @@ -1300,8 +1369,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset < WARP_SIZE) { - KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); + if (offset < warp_size) { + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size)); } } @@ -1318,8 +1387,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset < WARP_SIZE) { - KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); + if (offset < warp_size) { + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size); } } @@ -1328,19 +1397,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Write back combined meta data: #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { - if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { + if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) { // Combined KQ max scale + rowsum. - meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); + meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); } } // Combined KQ max + rowsum. - static_assert(cols_per_warp <= WARP_SIZE); - if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + static_assert(cols_per_warp <= warp_size); + if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } - if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } @@ -1388,10 +1457,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2)); #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); + for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) { + const int k0_start = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k); - const int stride_jc = WARP_SIZE / stride_k; + const int stride_jc = warp_size / stride_k; if (k0_start == k0_stop) { continue; @@ -1399,7 +1468,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { - const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k); if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { break; @@ -1417,7 +1486,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine; #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); float2 dstk_val = make_float2(0.0f, 0.0f); #pragma unroll @@ -1453,7 +1522,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) } template @@ -1480,7 +1549,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { @@ -1508,10 +1577,18 @@ static __global__ void flash_attn_ext_f16( } #endif // defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) + if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) { + NO_DEVICE_CODE; + return; + } +#endif // defined(AMD_MFMA_AVAILABLE) + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols); - constexpr int nwarps = nthreads / WARP_SIZE; + constexpr int nwarps = nthreads / warp_size; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. @@ -1624,7 +1701,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) } template @@ -1644,7 +1721,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc); const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc)); - const int nwarps = nthreads / WARP_SIZE; + const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size; + const int nwarps = nthreads / warp_size_host; constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu @@ -1694,7 +1772,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml } launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host); } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 721edd99944..85c177f496f 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -440,6 +440,18 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_MMA_F16; } + // Use MFMA flash attention for CDNA (MI100+): + if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) { + const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); + // MMA vs tile crossover benchmarked on MI300X @ d32768: + // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) + // hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%) + if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) { + return BEST_FATTN_KERNEL_MMA_F16; + } + // Fall through to tile kernel for small effective batch sizes. + } + // If there are no tensor cores available, use the generic tile kernel: if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index dd45d6c78fd..5d1dadd3e4f 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -668,7 +668,7 @@ namespace ggml_cuda_mma { return ret; } -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { tile ret; @@ -964,6 +964,34 @@ namespace ggml_cuda_mma { GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // defined(RDNA4) +#elif defined(AMD_MFMA_AVAILABLE) + // MFMA: FP16 input, FP32 accumulate, convert back to half2. + using halfx4_t = __attribute__((ext_vector_type(4))) _Float16; + using floatx4_t = __attribute__((ext_vector_type(4))) float; + + // Convert existing half2 accumulator to float for MFMA: + floatx4_t acc_f32; + { + const halfx4_t acc_h = reinterpret_cast(D.x[0]); +#pragma unroll + for (int i = 0; i < 4; ++i) { + acc_f32[i] = (float)acc_h[i]; + } + } + + const halfx4_t& a_frag = reinterpret_cast(A.x[0]); + const halfx4_t& b_frag = reinterpret_cast(B.x[0]); + acc_f32 = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_f32, 0, 0, 0); + + // Convert back to half2: + { + halfx4_t result_h; +#pragma unroll + for (int i = 0; i < 4; ++i) { + result_h[i] = (_Float16)acc_f32[i]; + } + reinterpret_cast(D.x[0]) = result_h; + } #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; From ca3f6bbd3cbf796e1966b0893ea248d456a83592 Mon Sep 17 00:00:00 2001 From: oobabooga Date: Sun, 1 Mar 2026 02:40:22 -0300 Subject: [PATCH 200/831] cuda: cap grid.y at 65535 in non-contiguous dequantize/convert kernels (llama/19999) --- ggml/src/ggml-cuda/convert.cu | 64 +++++++++++++++++------------------ 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 09b6d5db6a0..b70492c7d6c 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -16,27 +16,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ return; } - const int64_t i01 = blockIdx.y; - - for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { - const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); - const int64_t i02 = dm.y; - const int64_t i03 = dm.x; - - const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; - - const int64_t ib = ibx0 + i00/qk; // block index - const int64_t iqs = (i00%qk)/qr; // quant index - const int64_t iybs = i00 - i00%qk; // y block start index - const int64_t y_offset = qr == 1 ? 1 : qk/2; - - // dequantize - float2 v; - dequantize_kernel(vx, ib, iqs, v); - - const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs; - y[iy0 + 0] = ggml_cuda_cast(v.x); - y[iy0 + y_offset] = ggml_cuda_cast(v.y); + for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) { + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; + + const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; + + const int64_t ib = ibx0 + i00/qk; // block index + const int64_t iqs = (i00%qk)/qr; // quant index + const int64_t iybs = i00 - i00%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + float2 v; + dequantize_kernel(vx, ib, iqs, v); + + const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs; + y[iy0 + 0] = ggml_cuda_cast(v.x); + y[iy0 + y_offset] = ggml_cuda_cast(v.y); + } } } @@ -492,7 +492,7 @@ static void dequantize_block_cuda(const void * vx, dst_t * y, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { const int64_t ne0203 = ne02*ne03; const uint3 ne02_fdv = init_fastdiv_values(ne02); - const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535)); + const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535)); dequantize_block<<>> (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } @@ -628,18 +628,18 @@ static __global__ void convert_unary( return; } - const int64_t i01 = blockIdx.y; - const src_t * x = (const src_t *) vx; - for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { - const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); - const int64_t i02 = dm.y; - const int64_t i03 = dm.x; + for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) { + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; - const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; - const int64_t iy = (i0203*ne01 + i01)*ne00 + i00; - y[iy] = ggml_cuda_cast(x[ix]); + const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; + const int64_t iy = (i0203*ne01 + i01)*ne00 + i00; + y[iy] = ggml_cuda_cast(x[ix]); + } } } @@ -649,7 +649,7 @@ static void convert_unary_cuda(const void * vx, dst_t * y, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { const int64_t ne0203 = ne02*ne03; const uint3 ne02_fdv = init_fastdiv_values(ne02); - const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535)); + const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535)); convert_unary<<>> (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } From 2a9649c4205215cfc8ca6e296f7299fd25153bf2 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sun, 1 Mar 2026 17:32:14 +0100 Subject: [PATCH 201/831] vulkan: improve partial offloading performance on AMD (llama/19976) * vulkan: fix and enable cpy_tensor_async function * use transfer_queue for async transfers on AMD, synchronize with timeline semaphore * update offload_op logic * fix missing transfer submission * disable async transfer queue on AMD GCN * revert op batch size change * fix cpy_tensor_async checks --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 263 ++++++++++++++++++--------- 1 file changed, 177 insertions(+), 86 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0fae68628b6..72b11d378a7 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -590,6 +590,7 @@ struct vk_device_struct { vk_queue transfer_queue; bool single_queue; bool support_async; + bool async_use_transfer_queue; uint32_t subgroup_size; uint32_t subgroup_size_log2; uint32_t shader_core_count; @@ -1858,6 +1859,10 @@ struct ggml_backend_vk_context { vk_context_ref compute_ctx; + vk_context_ref transfer_ctx; + vk_semaphore transfer_semaphore; + uint64_t transfer_semaphore_last_submitted {}; + std::vector tensor_ctxs; std::vector descriptor_pools; @@ -1866,6 +1871,7 @@ struct ggml_backend_vk_context { uint32_t pipeline_descriptor_set_requirements {}; vk_command_pool compute_cmd_pool; + vk_command_pool transfer_cmd_pool; // number of additional consecutive nodes that are being fused with the // node currently being processed @@ -5391,13 +5397,19 @@ static vk_device ggml_vk_get_device(size_t idx) { ggml_vk_load_shaders(device); + const bool prefers_transfer_queue = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN; + if (!device->single_queue) { const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); + + device->async_use_transfer_queue = prefers_transfer_queue || (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr); } else { // TODO: Use pointer or reference to avoid copy device->transfer_queue.copyFrom(device->compute_queue); device->transfer_queue.cmd_pool.init(device, &device->transfer_queue); + + device->async_use_transfer_queue = false; } device->buffer_type = { @@ -5871,6 +5883,15 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->almost_ready_fence = ctx->device->device.createFence({}); ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); + if (ctx->device->async_use_transfer_queue) { + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + ctx->transfer_semaphore.s = ctx->device->device.createSemaphore(ci); + ctx->transfer_semaphore.value = 0; + + ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); + } if (vk_perf_logger_enabled) { ctx->perf_logger = std::unique_ptr(new vk_perf_logger()); @@ -6419,6 +6440,47 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { subctx->s = subctx->seqs[subctx->seqs.size() - 1].data(); } +static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) { + if (!ctx->compute_ctx.expired()) { + return ctx->compute_ctx.lock(); + } + + vk_context result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + + ctx->compute_ctx = result; + ggml_vk_ctx_begin(ctx->device, result); + + if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) { + result->s->wait_semaphores.push_back(ctx->transfer_semaphore); + ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value; + } + + return result; +} + +// Submit any pending transfer queue work and signal the transfer semaphore. +// The next compute context created via ggml_vk_get_compute_ctx will wait on this semaphore. +// Returns true if work was submitted. +static bool ggml_vk_submit_transfer_ctx(ggml_backend_vk_context * ctx) { + if (!ctx->device->async_use_transfer_queue || ctx->transfer_ctx.expired()) { + return false; + } + + vk_context cpy_ctx = ctx->transfer_ctx.lock(); + ggml_vk_ctx_end(cpy_ctx); + + for (auto& cpy : cpy_ctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ctx->transfer_semaphore.value++; + cpy_ctx->seqs.back().back().signal_semaphores.push_back(ctx->transfer_semaphore); + + ggml_vk_submit(cpy_ctx, {}); + ctx->transfer_ctx.reset(); + return true; +} + static size_t ggml_vk_align_size(size_t width, size_t align) { VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")"); return CEIL_DIV(width, align) * align; @@ -12529,15 +12591,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } } - vk_context compute_ctx; - - if (ctx->compute_ctx.expired()) { - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); { // This logic detects dependencies between modes in the graph and calls ggml_vk_sync_buffers @@ -13055,6 +13109,9 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); + if (ctx->device->async_use_transfer_queue) { + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); + } for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); @@ -13116,6 +13173,11 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->descriptor_sets.clear(); ctx->compute_cmd_pool.destroy(ctx->device->device); + if (ctx->device->async_use_transfer_queue) { + ctx->device->device.destroySemaphore(ctx->transfer_semaphore.s); + + ctx->transfer_cmd_pool.destroy(ctx->device->device); + } if (vk_perf_logger_enabled) { ctx->perf_logger->print_timings(true); } @@ -13387,34 +13449,38 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - vk_context compute_ctx; + vk_context cpy_ctx; - if (ctx->compute_ctx.expired()) { - // Initialize new transfer context - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); + if (ctx->device->async_use_transfer_queue) { + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = cpy_ctx; + ggml_vk_ctx_begin(ctx->device, cpy_ctx); + } else { + cpy_ctx = ctx->transfer_ctx.lock(); + } } else { - compute_ctx = ctx->compute_ctx.lock(); + cpy_ctx = ggml_vk_get_compute_ctx(ctx); } vk_buffer buf = buf_ctx->dev_buffer; auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_write_async(compute_ctx, buf, dst_offset, data, size); + bool ret = ggml_vk_buffer_write_async(cpy_ctx, buf, dst_offset, data, size); if (!ret) { ggml_vk_ensure_sync_staging_buffer(ctx, size); - ggml_vk_sync_buffers(nullptr, compute_ctx); + ggml_vk_sync_buffers(nullptr, cpy_ctx); vk::BufferCopy buffer_cpy; buffer_cpy.srcOffset = 0; buffer_cpy.dstOffset = dst_offset; buffer_cpy.size = size; - compute_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); - deferred_memcpy(ctx->sync_staging->ptr, data, size, &compute_ctx->in_memcpys); + cpy_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); + deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys); ggml_vk_synchronize(ctx); } } @@ -13426,16 +13492,7 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - vk_context compute_ctx; - - if (ctx->compute_ctx.expired()) { - // Initialize new transfer context - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); vk_buffer buf = buf_ctx->dev_buffer; @@ -13458,31 +13515,60 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ } } -static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { +static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()"); - ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; - if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) { - ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; - ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context; - vk_context compute_ctx; + if (dst->buffer->buft != ggml_backend_vk_get_default_buffer_type(backend_dst)) { + return false; + } - if (ctx->compute_ctx.expired()) { - // Initialize new transfer context - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; + + if (ggml_backend_buffer_is_vk(src->buffer)) { + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; + + // Async copy only works within the same device + if (src_buf_ctx->dev_buffer->device != dst_buf->device) { + return false; } - vk_buffer src_buf = src_buf_ctx->dev_buffer; - vk_buffer dst_buf = dst_buf_ctx->dev_buffer; + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); - ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, + src_buf_ctx->dev_buffer, vk_tensor_offset(src) + src->view_offs, + ggml_nbytes(src)); return true; } + if (ggml_backend_buffer_is_host(src->buffer)) { + vk_buffer pinned_buf = nullptr; + size_t pinned_offset = 0; + ggml_vk_host_get(ctx->device, src->data, pinned_buf, pinned_offset); + if (pinned_buf == nullptr) { + return false; + } + + vk_context cpy_ctx; + if (ctx->device->async_use_transfer_queue) { + if (ctx->transfer_ctx.expired()) { + cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = cpy_ctx; + ggml_vk_ctx_begin(ctx->device, cpy_ctx); + } else { + cpy_ctx = ctx->transfer_ctx.lock(); + } + } else { + cpy_ctx = ggml_vk_get_compute_ctx(ctx); + } + + return ggml_vk_buffer_write_async(cpy_ctx, dst_buf, + vk_tensor_offset(dst) + dst->view_offs, + src->data, ggml_nbytes(src)); + } + + GGML_UNUSED(backend_src); return false; } @@ -13491,6 +13577,10 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { bool do_transfer = !ctx->compute_ctx.expired(); + if (ggml_vk_submit_transfer_ctx(ctx)) { + ctx->submit_pending = true; + } + vk_context compute_ctx; if (do_transfer) { compute_ctx = ctx->compute_ctx.lock(); @@ -13506,7 +13596,22 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { } if (ctx->submit_pending) { - { + if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) { + vk::TimelineSemaphoreSubmitInfo tl_info{ + 1, &ctx->transfer_semaphore.value, + 0, nullptr, + }; + vk::PipelineStageFlags stage = ctx->device->transfer_queue.stage_flags; + vk::SubmitInfo si{ + 1, &ctx->transfer_semaphore.s, &stage, + 0, nullptr, + 0, nullptr, + }; + si.setPNext(&tl_info); + std::lock_guard guard(queue_mutex); + ctx->device->compute_queue.queue.submit({ si }, ctx->fence); + ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value; + } else { std::lock_guard guard(queue_mutex); ctx->device->compute_queue.queue.submit({}, ctx->fence); } @@ -13972,6 +14077,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool first_node_in_batch = true; // true if next node will be first node in a batch int submit_node_idx = 0; // index to first node in a batch + ggml_vk_submit_transfer_ctx(ctx); + vk_context compute_ctx; if (vk_perf_logger_enabled) { // allocate/resize the query pool @@ -13997,9 +14104,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0); GGML_ASSERT(ctx->compute_ctx.expired()); - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); + compute_ctx = ggml_vk_get_compute_ctx(ctx); ctx->query_idx = 0; compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); } @@ -14009,13 +14114,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg if (ctx->prealloc_size_add_rms_partials) { ggml_vk_preallocate_buffers(ctx, nullptr); - if (ctx->compute_ctx.expired()) { - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + compute_ctx = ggml_vk_get_compute_ctx(ctx); // initialize partial sums to zero. ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials); ggml_vk_sync_buffers(ctx, compute_ctx); @@ -14238,13 +14337,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit); if (vk_perf_logger_enabled && enqueued) { - if (ctx->compute_ctx.expired()) { - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + compute_ctx = ggml_vk_get_compute_ctx(ctx); if (!vk_perf_logger_concurrent) { // track a single node/fusion for the current query ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i]; @@ -14579,16 +14672,9 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; vk_event *vkev = (vk_event *)event->context; - vk_context compute_ctx; + ggml_vk_submit_transfer_ctx(ctx); - if (ctx->compute_ctx.expired()) { - // Initialize new transfer context - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); // the backend interface doesn't have an explicit reset, so reset it here // before we record the command to set it @@ -14609,16 +14695,7 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; vk_event *vkev = (vk_event *)event->context; - vk_context compute_ctx; - - if (ctx->compute_ctx.expired()) { - // Initialize new transfer context - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); ggml_vk_wait_events(compute_ctx, {vkev->event}); ggml_vk_ctx_end(compute_ctx); @@ -14631,7 +14708,7 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .free = */ ggml_backend_vk_free, /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async, /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async, - /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, + /* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async, /* .synchronize = */ ggml_backend_vk_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, @@ -15367,11 +15444,25 @@ static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_ba return buft_ctx->device->idx == ctx->device; } +static int64_t ggml_vk_get_op_batch_size(const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_GET_ROWS: + return 0; + case GGML_OP_MUL_MAT: + return op->ne[1]; + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + return op->ne[2]; + default: + return ggml_nrows(op); + } +} + static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context; - return (op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS) || - (op->ne[2] >= dev_ctx->op_offload_min_batch_size && op->op == GGML_OP_MUL_MAT_ID); + return ggml_vk_get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size; } static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) { From e2be9edd5ac5ac174d2aa90b5c2b29a509b874b9 Mon Sep 17 00:00:00 2001 From: Aaron Teo Date: Mon, 2 Mar 2026 16:23:56 +0800 Subject: [PATCH 202/831] ggml-cpu: optimise s390x multiply extend instructions (llama/20032) --- ggml/src/ggml-cpu/arch/s390/quants.c | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index 19d225a4837..34184ed8510 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -181,11 +181,11 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs); const int16x8_t v_xylso = vec_mulo(v_xls, v_yl); - const int16x8_t v_xylse = vec_mule(v_xls, v_yl); + const int16x8_t v_xyl = vec_meadd(v_xls, v_yl, v_xylso); const int16x8_t v_xyhso = vec_mulo(v_xhs, v_yh); - const int16x8_t v_xyhse = vec_mule(v_xhs, v_yh); + const int16x8_t v_xyh = vec_meadd(v_xhs, v_yh, v_xyhso); - int16x8_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_); + int16x8_t v_xy_ = v_xyl + v_xyh; v_xy_ += vec_reve(v_xy_); const float32x4_t v_xy = vec_float(vec_unpackh(v_xy_)); const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)); @@ -890,8 +890,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8); const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh); - const int32x4_t v_minse = vec_mule(v_ysums, v_minsh); - const int32x4_t v_mins = v_minso + v_minse; + const int32x4_t v_mins = vec_meadd(v_ysums, v_minsh, v_minso); sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]); const uint8_t * scales = (const uint8_t *)utmp; @@ -1004,8 +1003,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8); const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh); - const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh); - const int32x4_t v_mins = vec_add(v_minsho, v_minshe); + const int32x4_t v_mins = vec_meadd(v_ysums, v_minsh, v_minsho); const int32_t mins = vec_hsum_i32x4(v_mins); const uint8_t * scales = (const uint8_t *)utmp; @@ -1110,10 +1108,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int16x8_t v_scaleh = vec_unpackl(v_scale); const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel); - const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel); + const int32x4_t v_minsl = vec_meadd(v_ysumsl, v_scalel, v_minslo); const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh); - const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh); - const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe; + const int32x4_t v_minsh = vec_meadd(v_ysumsh, v_scaleh, v_minsho); + const int32x4_t v_mins = vec_add(v_minsl, v_minsh); const int32_t mins = vec_hsum_i32x4(v_mins); From 923a29242953c347d9719b331148014d075a6d9c Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Mon, 2 Mar 2026 15:58:25 +0100 Subject: [PATCH 203/831] vulkan: tune MMVQ for Intel Windows (llama/19988) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 72b11d378a7..23d6d39e0e8 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7574,6 +7574,18 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return false; } + if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) { + // Intel Windows proprietary driver tuning + switch (src0_type) { + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return false; + default: + return true; + } + } + switch (src0_type) { // From tests on A770 Linux, may need more tuning case GGML_TYPE_Q4_0: From de686fafad1fd6449492091314fd6a9e85eac027 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Tue, 3 Mar 2026 00:59:53 +0900 Subject: [PATCH 204/831] ggml-webgpu: Support non-contiguous `src0` and overlapping `src0/src1` in binary ops (llama/19850) * ggml-webgpu: Add binary op support for overlapping and non-contiguous. * Add newline to binary.wgsl * Append the test of binary op for src overlapping to test_bin_bcast. * Remove unnecessary newline. --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 9 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 82 +++++++++++++------ ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl | 50 +++++++++-- 3 files changed, 109 insertions(+), 32 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 0d5a818dacb..369475eaf50 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -68,6 +68,7 @@ struct ggml_webgpu_shader_lib_context { size_t wg_mem_limit_bytes = 0; bool inplace = false; bool overlap = false; + bool src_overlap = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; @@ -179,9 +180,10 @@ struct ggml_webgpu_binary_pipeline_key { int op; bool inplace; bool overlap; + bool src_overlap; bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { - return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap; + return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap; } }; @@ -192,6 +194,7 @@ struct ggml_webgpu_binary_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.op); ggml_webgpu_hash_combine(seed, key.inplace); ggml_webgpu_hash_combine(seed, key.overlap); + ggml_webgpu_hash_combine(seed, key.src_overlap); return seed; } }; @@ -1044,6 +1047,7 @@ class ggml_webgpu_shader_lib { .op = context.dst->op, .inplace = context.inplace, .overlap = context.overlap, + .src_overlap = context.src_overlap, }; auto it = binary_pipelines.find(key); @@ -1076,6 +1080,9 @@ class ggml_webgpu_shader_lib { } else if (key.overlap) { defines.push_back("OVERLAP"); variant += "_overlap"; + } else if (key.src_overlap) { + defines.push_back("SRC_OVERLAP"); + variant += "_src_overlap"; } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1c00d3cb2b1..4dc56e1dc58 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -788,6 +788,7 @@ static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { struct binary_overlap_flags { bool inplace; // src0 == dst bool overlap; // src1 == dst + bool src_overlap; }; static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, @@ -796,6 +797,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 binary_overlap_flags flags = {}; flags.inplace = ggml_webgpu_tensor_equal(src0, dst); flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); + flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); return flags; } @@ -1353,6 +1355,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, .inplace = flags.inplace, .overlap = flags.overlap, + .src_overlap = flags.src_overlap, }; webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); @@ -1361,11 +1364,28 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, uint32_t ne = (uint32_t) ggml_nelements(dst); + size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0); + size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1); + + uint32_t offset_merged_src0 = 0; + uint32_t offset_merged_src1 = 0; + if (flags.src_overlap) { + size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); + offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); + offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); + } + std::vector params = { ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + offset_merged_src0, + offset_merged_src1, + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), @@ -1381,25 +1401,43 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, std::vector entries; - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0), - }); - - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1), - }); - - if (!flags.inplace && !flags.overlap) { - entries.push_back({ .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + if (flags.src_overlap) { + size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); + size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0), + src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1)); + entries.push_back({ + .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = merged_offset, + .size = merged_end - merged_offset, + }); + entries.push_back({ + .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst), + }); + } else { + entries.push_back({ + .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = src0_webgpu_tensor_align_offset, + .size = ggml_webgpu_tensor_binding_size(ctx, src0), + }); + entries.push_back({ + .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = src1_webgpu_tensor_align_offset, + .size = ggml_webgpu_tensor_binding_size(ctx, src1), + }); + if (!flags.inplace && !flags.overlap) { + entries.push_back({ + .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst), + }); + } } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -2816,10 +2854,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE - // see https://github.com/ggml-org/llama.cpp/pull/16857 supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && - (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1); + (src1->type == op->type); break; case GGML_OP_CPY: case GGML_OP_CONT: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl index 55dd66408a3..a748dc1b86c 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl @@ -7,6 +7,13 @@ struct Params { offset_src0: u32, offset_src1: u32, offset_dst: u32, + offset_merged_src0: u32, + offset_merged_src1: u32, + + stride_src0_0: u32, + stride_src0_1: u32, + stride_src0_2: u32, + stride_src0_3: u32, stride_src1_0: u32, stride_src1_1: u32, @@ -23,6 +30,21 @@ struct Params { b_ne3: u32, }; +fn src0_index(_i: u32) -> u32 { + var i = _i; + let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); + i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); + let a_i2 = i / (params.a_ne1 * params.a_ne0); + i = i % (params.a_ne1 * params.a_ne0); + let a_i1 = i / params.a_ne0; + let a_i0 = i % params.a_ne0; + + return a_i0 * params.stride_src0_0 + + a_i1 * params.stride_src0_1 + + a_i2 * params.stride_src0_2 + + a_i3 * params.stride_src0_3; +} + fn src1_index(_i: u32) -> u32 { var i = _i; let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); @@ -53,17 +75,22 @@ fn src1_index(_i: u32) -> u32 { #define DataType f16 #endif +#ifdef SRC_OVERLAP @group(0) @binding(0) -var src0: array; +var merged_src: array; @group(0) @binding(1) -var src1 : array; +var dst: array; -#ifdef INPLACE @group(0) @binding(2) var params: Params; +#else +@group(0) @binding(0) +var src0: array; -#elif defined(OVERLAP) +@group(0) @binding(1) +var src1 : array; +#if defined(INPLACE) || defined(OVERLAP) @group(0) @binding(2) var params: Params; @@ -74,6 +101,7 @@ var dst: array; @group(0) @binding(3) var params: Params; #endif +#endif fn op(a: DataType, b: DataType) -> DataType { #ifdef OP_ADD @@ -87,13 +115,17 @@ fn op(a: DataType, b: DataType) -> DataType { #endif } -fn update(dst_i: u32, src0_i: u32, src1_i: u32){ +fn update(dst_i: u32, src0_i: u32, src1_i: u32) { +#ifdef SRC_OVERLAP + let result = op(merged_src[src0_i], merged_src[src1_i]); +#else let result = op(src0[src0_i], src1[src1_i]); +#endif #ifdef INPLACE - src0[dst_i] = result; + src0[src0_i] = result; #elif defined(OVERLAP) - src1[dst_i] = result; + src1[src1_i] = result; #else dst[dst_i] = result; #endif @@ -102,6 +134,8 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32){ @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x < params.ne) { - update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); + let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x); + let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x); + update(params.offset_dst + gid.x, src0_i, src1_i); } } From 22034a5f6f6c5687ac015ef48f5e43372f5ed77b Mon Sep 17 00:00:00 2001 From: Nikhil Jain Date: Mon, 2 Mar 2026 10:23:34 -0800 Subject: [PATCH 205/831] ggml webgpu: Clean up per-thread parameter buffer pool and job submission logic (llama/19772) * Allow webgpu_buf_pool to resize if needed, remove inflight_threads, and replace inflight_threads with num_kernels for submission * Run clang-format * Keep track of num batched kernels that have not been submitted yet * Run clang-format * Increase buf pool max size * Increase param buf pool init size * Remove webgpu buf pool resizing * Merge with master * Add buffer pool growth * Move buffer pool growth outside of lock * Reduce max pool size to 32 * Run clang-format * Only resize param buf pool --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 65 ++++++++++++++++++++-------- 1 file changed, 47 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 4dc56e1dc58..913cf7f8825 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -133,12 +133,28 @@ struct webgpu_buf_pool { // which can run on a different thread than the calling thread. std::mutex mutex; std::condition_variable cv; + size_t cur_pool_size; + size_t max_pool_size; + wgpu::Device device; + wgpu::BufferUsage host_buf_usage; + wgpu::BufferUsage dev_buf_usage; + size_t buf_size; + bool should_grow; void init(wgpu::Device device, int num_bufs, size_t buf_size, wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage) { + wgpu::BufferUsage host_buf_usage, + bool should_grow = false, + size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) { + this->max_pool_size = max_pool_size; + this->cur_pool_size = num_bufs; + this->device = device; + this->host_buf_usage = host_buf_usage; + this->dev_buf_usage = dev_buf_usage; + this->buf_size = buf_size; + this->should_grow = should_grow; for (int i = 0; i < num_bufs; i++) { wgpu::Buffer host_buf; wgpu::Buffer dev_buf; @@ -150,6 +166,25 @@ struct webgpu_buf_pool { webgpu_pool_bufs alloc_bufs() { std::unique_lock lock(mutex); + if (!free.empty()) { + webgpu_pool_bufs bufs = free.back(); + free.pop_back(); + return bufs; + } + + // Try growing the pool if no free buffers + if (free.empty() && cur_pool_size < max_pool_size && should_grow) { + cur_pool_size++; + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); + ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); + + if (!(host_buf && dev_buf)) { + GGML_ABORT("webgpu_buf_pool: failed to allocate buffers"); + } + return webgpu_pool_bufs{ host_buf, dev_buf }; + } cv.wait(lock, [this] { return !free.empty(); }); webgpu_pool_bufs bufs = free.back(); free.pop_back(); @@ -243,6 +278,7 @@ struct webgpu_gpu_profile_buf_pool { #endif struct webgpu_command { + uint32_t num_kernels; wgpu::CommandBuffer commands; std::vector params_bufs; std::optional set_rows_error_bufs; @@ -280,7 +316,6 @@ struct webgpu_global_context_struct { webgpu_buf_pool memset_buf_pool; std::map memset_pipelines; // variant or type index - std::atomic_uint inflight_threads = 0; #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) @@ -426,13 +461,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, std::vector & futures, bool block = true) { - // If we have too many in-flight submissions, wait on the oldest one first. If - // there are many threads, inflight_max may be 0, meaning that we must wait on - // all futures. - uint64_t timeout_ms = block ? UINT64_MAX : 0; - uint32_t inflight_threads = ctx->inflight_threads; - uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); - while (futures.size() >= inflight_max && futures.size() > 0) { + // If we have too many in-flight submissions, wait on the oldest one first. + uint64_t timeout_ms = block ? UINT64_MAX : 0; + while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); futures.erase(futures.begin()); } @@ -651,6 +682,7 @@ static webgpu_command ggml_backend_webgpu_build_multi( result.commands = commands; result.params_bufs = params_bufs_list; result.set_rows_error_bufs = set_rows_error_bufs; + result.num_kernels = pipelines.size(); #ifdef GGML_WEBGPU_GPU_PROFILE result.timestamp_query_bufs = ts_bufs; // TODO: handle multiple pipeline names @@ -2081,19 +2113,17 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); - ctx->global_ctx->inflight_threads++; - std::vector commands; std::vector futures; + uint32_t num_batched_kernels = 0; for (int i = 0; i < cgraph->n_nodes; i++) { if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { commands.push_back(*cmd); + num_batched_kernels += cmd.value().num_kernels; } - // compute the batch size based on the number of inflight threads - uint32_t inflight_threads = ctx->global_ctx->inflight_threads; - uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), - WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); - if (commands.size() >= batch_size) { + + if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { + num_batched_kernels = 0; futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool)); // Process events and check for completed submissions @@ -2109,7 +2139,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str } ggml_backend_webgpu_wait(ctx->global_ctx, futures); - ctx->global_ctx->inflight_threads--; WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -2727,7 +2756,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true); webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, From 3145384715c1af0a4895fcac5a77d75b6f02028c Mon Sep 17 00:00:00 2001 From: Abhijit Ramesh Date: Mon, 2 Mar 2026 19:35:11 -0800 Subject: [PATCH 206/831] ggml webgpu: fix workgroup dispatch limit for large batch sizes (llama/19965) * ggml-webgpu: fix workgroup dispatch limit for large batch sizes WebGPU limits workgroup sizes to 65535 per dimension. Large MUL_MAT operations with batch sizes exceedeing this limi would fail. * add compute_2d_workgroups() helper to split total workgroup ID across X/Y dimensions * update mul_mat_reg_tile.wgsl to reconstruct linear workgroup ID from 2D dispatch * update mul_mat_subgroup_matrix.wgsl to reconstruct linear workgroup ID from 2D dispatch * update mul_mat.wgsl to compute global index from 2D workgroup coordinates * refactor all three mul_mat dispatch paths to use the shared helper * ggml-webgpu: add bounds checking for over-dispatched workgroups 2D workgroup dispatch can over-dispatch when total workgroups don't divide evenly into the 65535 per-dimension limit. Extra workgroups would compute invalid batch indices, causing memory corruption. * add batch_idx bound check to mul_mat_reg_tile.wgsl and mul_mat_subgroup_matrix.wgsl to prevent over-dispatched workgroups from accessing invalid memory * fixes test failures with large batch sizes (eg., bs=[128, 1024]) * ggml-webgpu: add back TODO for spliting large sizes into batches * Optimize 2d workgroup provisioning * Set some parameters that increase speed --------- Co-authored-by: Reese Levine --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 28 ++++++++++++------- .../src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl | 13 ++++++--- .../wgsl-shaders/mul_mat_reg_tile.wgsl | 14 ++++++++-- .../wgsl-shaders/mul_mat_subgroup_matrix.wgsl | 14 ++++++++-- 4 files changed, 49 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 913cf7f8825..19451618ec5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -31,6 +31,13 @@ #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1)) #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N)) +// Return a rectangular grid of workgroups with minimal over-provisioned workgroups. +// Assumes that the total number of workgroups does not exceed max_per_dim^2. +static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) { + wg_y = std::max(1u, CEIL_DIV(total_wg, max_per_dim)); + wg_x = CEIL_DIV(total_wg, wg_y); +} + #ifdef GGML_WEBGPU_DEBUG # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl # define WEBGPU_DEBUG_BUF_ELEMS 512 @@ -69,8 +76,8 @@ /* Constants */ -#define WEBGPU_NUM_PARAM_BUFS 16u -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u +#define WEBGPU_NUM_PARAM_BUFS 48u +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 // Maximum number of in-flight submissions per-thread, to avoid exhausting the // parameter buffer pool @@ -1146,8 +1153,9 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, }; // Calculate workgroup dimensions - uint32_t wg_x = 1; - uint32_t wg_y = 1; + uint32_t wg_x = 1; + uint32_t wg_y = 1; + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; if (use_fast && is_vec) { auto decisions = static_cast(pipeline.context.get()); @@ -1155,9 +1163,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); uint32_t total_wg = output_groups * batches; - // TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups - wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); - wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } else if (use_fast) { auto decisions = static_cast(pipeline.context.get()); @@ -1176,12 +1182,14 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, wg_m = CEIL_DIV(dst->ne[0], tile_m_s); wg_n = CEIL_DIV(dst->ne[1], tile_n_s); } - wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + } else { // legacy auto decisions = static_cast(pipeline.context.get()); uint32_t wg_size = decisions->wg_size; - wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); - wg_y = 1; + uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl index 6aba47317c6..5b9f5b36224 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -679,19 +679,24 @@ struct MulMatParams { @group(0) @binding(3) var params: MulMatParams; @compute @workgroup_size(256) -fn main(@builtin(global_invocation_id) global_id: vec3) { +fn main(@builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + let global_idx = wg_linear * 256u + local_id.x; + let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - if (global_id.x >= total) { + if (global_idx >= total) { return; } let dst2_stride = params.m * params.n; let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - let dst3_idx = global_id.x / dst3_stride; + let dst3_idx = global_idx / dst3_stride; let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension let src13_idx = dst3_idx; // src1 is not broadcast - let dst3_rem = global_id.x % dst3_stride; + let dst3_rem = global_idx % dst3_stride; let dst2_idx = dst3_rem / dst2_stride; let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl index 771e5cd1ee3..761e3017c14 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -54,7 +54,8 @@ var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, - @builtin(local_invocation_id) local_id: vec3) { + @builtin(local_invocation_id) local_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { let thread_id = local_id.x; let local_m = get_local_m(thread_id); @@ -64,9 +65,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); let wg_per_matrix = wg_m_count * wg_n_count; - let batch_idx = wg_id.x / wg_per_matrix; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let wg_in_batch = wg_id.x % wg_per_matrix; + let batch_idx = wg_linear / wg_per_matrix; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + if (batch_idx >= total_batches) { + return; + } + + let wg_in_batch = wg_linear % wg_per_matrix; let wg_m = wg_in_batch % wg_m_count; let wg_n = wg_in_batch / wg_m_count; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl index 64529e03cdc..9f9ef279f29 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl @@ -69,7 +69,8 @@ var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3, - @builtin(subgroup_id) subgroup_id: u32) { + @builtin(subgroup_id) subgroup_id: u32, + @builtin(num_workgroups) num_wg: vec3) { let thread_id = local_id.x; let subgroup_m = subgroup_id % SUBGROUP_M; @@ -79,9 +80,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE; let wg_per_matrix = wg_m_count * wg_n_count; - let batch_idx = wg_id.x / wg_per_matrix; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let wg_in_batch = wg_id.x % wg_per_matrix; + let batch_idx = wg_linear / wg_per_matrix; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + if (batch_idx >= total_batches) { + return; + } + + let wg_in_batch = wg_linear % wg_per_matrix; let wg_m = wg_in_batch % wg_m_count; let wg_n = wg_in_batch / wg_m_count; From 3a96680718399a9b61d0ad5c41438ca2096893a9 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Mon, 2 Mar 2026 19:49:41 -0800 Subject: [PATCH 207/831] opencl: add optimized q4_1 mm kernel for adreno (llama/19840) * Add Q4_1 OpenCL Kernels * opencl: refactor transpose * opencl: format * opencl: refactor q4_1 unpack * opencl: move `ggml_cl_mul_mat_q4_1_f32_adreno` * opencl: refactor `ggml_cl_mul_mat_q4_1_f32_adreno` and kernels * opencl: rename kernel files and kernes * opencl: fix build for non adreno * opencl: move code around and format --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 386 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 52 +++ .../kernels/gemm_noshuffle_q4_1_f32.cl | 132 ++++++ .../gemv_noshuffle_general_q8_0_f32.cl | 2 +- .../kernels/gemv_noshuffle_q4_1_f32.cl | 283 +++++++++++++ ggml/src/ggml-opencl/kernels/transpose.cl | 26 ++ 7 files changed, 879 insertions(+), 4 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index f3891936911..0fe1dd38476 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -108,6 +108,8 @@ set(GGML_OPENCL_KERNELS mul_mm_q8_0_f32_l4_lm mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 + gemv_noshuffle_q4_1_f32 + gemm_noshuffle_q4_1_f32 gemv_noshuffle_general_q8_0_f32 mul norm diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 3da022ed86c..0b9a021d204 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -531,6 +531,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_restore_block_q4_0_noshuffle; + cl_kernel kernel_convert_block_q4_1_noshuffle; + cl_kernel kernel_restore_block_q4_1_noshuffle; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q4_1_f32; @@ -683,7 +685,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_transpose_32; cl_kernel kernel_transpose_32_16; cl_kernel kernel_transpose_16; + cl_kernel kernel_transpose_8_buf; cl_kernel kernel_transpose_16_buf; + cl_kernel kernel_transpose_32_buf; cl_kernel kernel_transpose_16_4x1; // Gemm and Gemv related programs, kernels, etc @@ -699,6 +703,8 @@ struct ggml_backend_opencl_context { cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; + cl_kernel kernel_gemv_noshuffle_q4_1_f32; + cl_kernel kernel_gemm_noshuffle_q4_1_f32; cl_kernel kernel_mul_mm_q8_0_f32_8x4; cl_kernel CL_mul_mat_vec_q8_0_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS @@ -893,6 +899,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); @@ -2258,7 +2266,9 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err)); CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err)); CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_8_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_8_buf", &err), err)); CL_CHECK((backend_ctx->kernel_transpose_16_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_buf", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_32_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_buf", &err), err)); CL_CHECK((backend_ctx->kernel_transpose_16_4x1 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_4x1", &err), err)); GGML_LOG_CONT("."); } @@ -2378,6 +2388,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemm_noshuffle_q4_1_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q4_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q4_1_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_1_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q4_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q4_1_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q4_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q4_1_f32.cl"); +#endif + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_1_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_q8_0_f32_8x4 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2413,7 +2462,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve cl_program prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q8_0_f32", &err), err)); CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -2923,6 +2972,82 @@ static void ggml_cl2_free(ggml_backend_t backend) { } } +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS +static void transpose_2d( + ggml_backend_opencl_context * backend_ctx, + cl_kernel kernel, + cl_mem src, cl_mem dst, size_t size, + cl_int stride, cl_int rows, + bool blocking = true +) { + static ggml_cl_buffer buf; + + cl_event evt; + cl_int err; + + buf.allocate(backend_ctx->context, size); + + cl_mem trans; + cl_buffer_region region; + + region.origin = 0; + region.size = size; + CL_CHECK((trans = clCreateSubBuffer( + buf.buffer, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &src)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &stride)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &rows)); + + size_t local_size[3] = {64, 1, 1}; + size_t global_size[3] = {(size_t)stride, (size_t)rows, 1};; + CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, + global_size, local_size, 0, NULL, NULL)); + + if (blocking) { + CL_CHECK(clEnqueueCopyBuffer(backend_ctx->queue, trans, dst, 0, 0, size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseEvent(evt)); + } else { + CL_CHECK(clEnqueueCopyBuffer(backend_ctx->queue, trans, dst, 0, 0, size, 0, NULL, NULL)); + } + + CL_CHECK(clReleaseMemObject(trans)); +} + +static void transpose_2d_as_8b( + ggml_backend_opencl_context * backend_ctx, + cl_mem src, cl_mem dst, size_t size, + cl_int stride, cl_int rows, + bool blocking = true +) { + transpose_2d(backend_ctx, backend_ctx->kernel_transpose_8_buf, + src, dst, size, stride, rows, blocking); +} + +static void transpose_2d_as_16b( + ggml_backend_opencl_context * backend_ctx, + cl_mem src, cl_mem dst, size_t size, + cl_int stride, cl_int rows, + bool blocking = true +) { + transpose_2d(backend_ctx, backend_ctx->kernel_transpose_16_buf, + src, dst, size, stride, rows, blocking); +} + +static void transpose_2d_as_32b( + ggml_backend_opencl_context * backend_ctx, + cl_mem src, cl_mem dst, size_t size, + cl_int stride, cl_int rows, + bool blocking = true +) { + transpose_2d(backend_ctx, backend_ctx->kernel_transpose_32_buf, + src, dst, size, stride, rows, blocking); +} +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + //------------------------------------------------------------------------------ // Tensor extra management //------------------------------------------------------------------------------ @@ -4271,7 +4396,15 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; + + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q4_1_noshuffle; + } + #else cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; + #endif // GGML_OPENCL_USE_ADRENO_KERNELS CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); @@ -4287,6 +4420,22 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + // Transpose q as ushort + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + // Transpose m as ushort + transpose_2d_as_16b(backend_ctx, extra->m, extra->m, size_m, K/32, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } if (tensor->type == GGML_TYPE_MXFP4) { @@ -4795,6 +4944,53 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, if (tensor->type == GGML_TYPE_Q4_1) { ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_m; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; + + GGML_ASSERT(K % ggml_blck_size(tensor->type) == 0); + + size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2; + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + size_t size_m = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + GGML_ASSERT(size_d + size_q + size_m == ggml_nbytes(tensor) && "Incorrect tensor size"); + + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_m.allocate(backend_ctx->context, size_m); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + // transpose q, d, m back + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); + transpose_2d_as_16b(backend_ctx, extra->m, buf_trans_m.buffer, size_m, M, K/32); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_m.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_F0)); + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } +#endif + cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err); @@ -4886,8 +5082,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; - GGML_ASSERT(tensor->ne[2] == 1); // ??? - GGML_ASSERT(tensor->ne[3] == 1); // ??? + GGML_ASSERT(tensor->ne[2] == 1); + GGML_ASSERT(tensor->ne[3] == 1); CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); @@ -8371,6 +8567,180 @@ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_ten CL_CHECK(clReleaseMemObject(D_sub_buffer)); } +static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q4_1->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q4_1_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q4_1_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); @@ -8736,6 +9106,16 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co int padding; // <--------------------------------------------> // + // NOTE: Kernels using image1d_buffer_t (e.g., src0_q) would normally require + // a limit check, but q4_0 / q4_1 tensors are very unlikely to exceed that + // limit, so the check is omitted. + + // q4_1 x fp32 + if (src0t == GGML_TYPE_Q4_1 && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q4_1_f32_adreno(backend, src0, src1, dst); + return; + } + // q8_0 x fp32 if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 && enable_adreno_trans_weight(backend_ctx, src0)) { diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 2c244ce3215..78ef9c177f6 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -199,6 +199,58 @@ kernel void kernel_restore_block_q4_1( } } +kernel void kernel_convert_block_q4_1_noshuffle( + global struct block_q4_1 * src0, + global uchar * dst_q, + global half * dst_d, + global half * dst_m +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + for (int i = 0; i < QK4_1/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + q[i + QK4_1/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + +#ifdef ADRENO_GPU + if (get_global_id(0) == 65536*4096) { + printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0)); + } +#endif + } +} + +kernel void kernel_restore_block_q4_1_noshuffle( + global uchar * src_q, + global half * src_d, + global half * src_m, + global struct block_q4_1 * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + for (int i = 0; i < QK4_1/4; ++i) { + uchar x0 = q[i + 0 ] ; + uchar x1 = q[i + QK4_1/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl new file mode 100644 index 00000000000..5c4d5cc8e2c --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl @@ -0,0 +1,132 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_q4_1_f32( + global const ushort * src0_q, + global const half * src0_d, + global const half * src0_m, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding +) { + dst = (global float *)((global char *)dst + offsetd); + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort* weight_ptr = src0_q + gx_2; + global const half* scale_ptr = src0_d + gx_2; + global const half* min_ptr = src0_m + gx_2; + + for(int i = 0; i < k; i += 4) { + B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1); + + ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m)); + + half4 scale = vload4(0, scale_ptr + (i/32)*(m)); + half4 minv = vload4(0, min_ptr + (i/32)*(m)); + + // j=0 + dequantized_weights.s0 = (bits4.s0 & (0x000F)) * scale.s0 + minv.s0; + dequantized_weights.s1 = (bits4.s1 & (0x000F)) * scale.s1 + minv.s1; + dequantized_weights.s2 = (bits4.s2 & (0x000F)) * scale.s2 + minv.s2; + dequantized_weights.s3 = (bits4.s3 & (0x000F)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1); + dequantized_weights.s0 = ((bits4.s0 & (0x00F0)) >> 4) * scale.s0 + minv.s0; + dequantized_weights.s1 = ((bits4.s1 & (0x00F0)) >> 4) * scale.s1 + minv.s1; + dequantized_weights.s2 = ((bits4.s2 & (0x00F0)) >> 4) * scale.s2 + minv.s2; + dequantized_weights.s3 = ((bits4.s3 & (0x00F0)) >> 4) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1); + dequantized_weights.s0 = ((bits4.s0 & (0x0F00)) >> 8) * scale.s0 + minv.s0; + dequantized_weights.s1 = ((bits4.s1 & (0x0F00)) >> 8) * scale.s1 + minv.s1; + dequantized_weights.s2 = ((bits4.s2 & (0x0F00)) >> 8) * scale.s2 + minv.s2; + dequantized_weights.s3 = ((bits4.s3 & (0x0F00)) >> 8) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1); + dequantized_weights.s0 = ((bits4.s0 & (0xF000)) >> 12) * scale.s0 + minv.s0; + dequantized_weights.s1 = ((bits4.s1 & (0xF000)) >> 12) * scale.s1 + minv.s1; + dequantized_weights.s2 = ((bits4.s2 & (0xF000)) >> 12) * scale.s2 + minv.s2; + dequantized_weights.s3 = ((bits4.s3 & (0xF000)) >> 12) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl index f944ef3a992..9703b693e56 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl @@ -121,7 +121,7 @@ #ifdef ADRENO_GPU REQD_SUBGROUP_SIZE_64 #endif -__kernel void kernel_gemv_noshuffle( +__kernel void kernel_gemv_noshuffle_q8_0_f32( __read_only image1d_buffer_t src0_q, // quantized A global half * src0_d, // A scales __read_only image1d_buffer_t src1, // B diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl new file mode 100644 index 00000000000..fdc1472454f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl @@ -0,0 +1,283 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK4_0 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q4_1_f32( + read_only image1d_buffer_t src0_q, + global half2 * src0_d, + global half2 * src0_m, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + private uint4 regA; + private half2 regS; + private half2 regM; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + for (uint k = groupId; k < (K / QK4_0); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows + regM = src0_m[gid + k * LINE_STRIDE_A]; // each fiber loads min of two rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/transpose.cl b/ggml/src/ggml-opencl/kernels/transpose.cl index 1279b6531b9..ad89bdcbdec 100644 --- a/ggml/src/ggml-opencl/kernels/transpose.cl +++ b/ggml/src/ggml-opencl/kernels/transpose.cl @@ -44,6 +44,19 @@ kernel void kernel_transpose_16_4x1( write_imageh(output, i * rows + j, (half4)(temp0, temp1, temp2, temp3)); } +// Transpose treating each element as 8-bit using buffer +kernel void kernel_transpose_8_buf( + global const uchar * input, + global uchar * output, + const int ldi, + const int ldo +) { + const int x = get_global_id(0); + const int y = get_global_id(1); + + output[x*ldo + y] = input[y*ldi + x]; +} + // Transpose treating each element as 16-bit using buffer kernel void kernel_transpose_16_buf( global const ushort * input, @@ -57,6 +70,19 @@ kernel void kernel_transpose_16_buf( output[x*ldo + y] = input[y*ldi + x]; } +// Transpose treating each element as 32-bit using buffer +kernel void kernel_transpose_32_buf( + global const uint * input, + global uint * output, + const int ldi, + const int ldo +) { + const int x = get_global_id(0); + const int y = get_global_id(1); + + output[x*ldo + y] = input[y*ldi + x]; +} + // 32-bit transpose, loading/storing a 4x4 tile of elements kernel void kernel_transpose_32( __read_only image1d_buffer_t input, From 169d723fa000f0325e0464585418a6f6a260152e Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Tue, 3 Mar 2026 10:40:26 +0100 Subject: [PATCH 208/831] kleidiai : add sme fp16 compute path for q4_0 gemm on aarch64 (llama/20043) --- ggml/src/ggml-cpu/CMakeLists.txt | 11 +++++--- ggml/src/ggml-cpu/kleidiai/kernels.cpp | 35 +++++++++++++------------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 3dc948e4d8e..6ca3176a2f2 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -566,9 +566,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_TAG "v1.16.0") + set(KLEIDIAI_COMMIT_TAG "v1.22.0") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "0a9e9008adb6031f9e8cf70dff4a3321") + set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1") if (POLICY CMP0135) cmake_policy(SET CMP0135 NEW) @@ -608,6 +608,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/) set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}") @@ -648,7 +649,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (NOT SME_ENABLED MATCHES -1) list(APPEND GGML_KLEIDIAI_SOURCES - ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S @@ -656,10 +656,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa_asm.S ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_f16pmrx2_f32_neon.c ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S) - set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2") + set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2+sme2+fp16") endif() if (NOT SVE_ENABLED MATCHES -1) diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index d114f2d49bf..40f7c0df650 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates // SPDX-License-Identifier: MIT // @@ -9,7 +9,6 @@ #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" -#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" #include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" #include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" @@ -20,6 +19,7 @@ #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h" #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h" #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h" +#include "kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.h" #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h" #include "kai_lhs_quant_pack_qsi8d32p_f32.h" @@ -31,6 +31,7 @@ #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" #include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h" +#include "kai_lhs_pack_f16pmrx2_f32_neon.h" #include "kai_common.h" @@ -309,24 +310,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { { /* SME GEMM */ /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_lhs_offset_ex = */ &kernel_offs_fn3, - /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, - /* .run_kernel_ex = */ &kernel_run_fn11, + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { - /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .get_packed_offset_ex = */ &lhs_offs_fn6, - /* .packed_size_ex = */ &lhs_ps_fn6, - /* .pack_func_ex = */ &lhs_pack_float_fn10, + /* .get_offset = */ kai_get_lhs_offset_lhs_pack_f16pmrx2_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_void_fn10, }, /* SME GEMV */ /* .kern_info = */ { From b1b018dfd11060e7f5f633ff57d2714966b38e5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Wed, 4 Mar 2026 11:57:09 +0100 Subject: [PATCH 209/831] ggml : use a simple std::thread in AMX without OpenMP (llama/20074) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Disabling OpenMP generally provides better inference performance (at least in my testing) but the loading becomes slightly slower. Benchmark results for `convert_B_packed_format()`: Before this commit: N K | No OpenMP OpenMP | Diff | Speedup ------------------------------------------------------------ 512 2880 | 640.9us 263.5us | -58.9% | 0.41x 2880 4096 | 2.55ms 261.7us | -89.8% | 0.10x 201088 2880 | 256.44ms 21.61ms | -91.6% | 0.08x ------------------------------------------------------------ Total: 325.43ms vs 31.05ms After: N K | No OpenMP OpenMP | Diff | Speedup ------------------------------------------------------------ 512 2880 | 1.49ms 263.5us | -82.3% | 0.18x 2880 4096 | 1.55ms 261.7us | -83.1% | 0.17x 201088 2880 | 24.03ms 21.61ms | -10.1% | 0.90x ------------------------------------------------------------ Total: 78.97ms vs 31.05ms Tested with unsloth/gpt-oss-20b-GGUF:Q4_K_M. Signed-off-by: Adrien Gallouët --- ggml/src/ggml-cpu/amx/common.h | 44 ++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cpu/amx/common.h b/ggml/src/ggml-cpu/amx/common.h index f392e898518..26a6ec1a2d0 100644 --- a/ggml/src/ggml-cpu/amx/common.h +++ b/ggml/src/ggml-cpu/amx/common.h @@ -9,6 +9,8 @@ #if defined(GGML_USE_OPENMP) #include +#else +#include #endif #define TILE_M 16 @@ -56,18 +58,40 @@ inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { } template -inline void parallel_for(int n, const func_t& f) { +inline void parallel_for(int n, const func_t & f) { + if (n <= 0) { + return; + } #if defined(GGML_USE_OPENMP) -#pragma omp parallel -{ - int nth = omp_get_num_threads(); - int ith = omp_get_thread_num(); - int tbegin, tend; - balance211(n, nth, ith, tbegin, tend); - f(tbegin, tend); -} + #pragma omp parallel + { + int nth = omp_get_num_threads(); + int ith = omp_get_thread_num(); + int tbegin, tend; + balance211(n, nth, ith, tbegin, tend); + f(tbegin, tend); + } #else - f(0, n); + int nth = std::thread::hardware_concurrency(); + if (nth <= 1) { + f(0, n); + return; + } + if (nth > n) { + nth = n; + } + std::vector threads; + threads.reserve(nth); + for (int ith = 0; ith < nth; ++ith) { + threads.emplace_back([&f, n, ith, nth] { + int tbegin, tend; + balance211(n, nth, ith, tbegin, tend); + f(tbegin, tend); + }); + } + for (auto & t : threads) { + t.join(); + } #endif } From 5d25427e58cd571a0a4741b201c61b2121f8c654 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 4 Mar 2026 12:04:31 +0100 Subject: [PATCH 210/831] ggml: fix ggml_is_contiguous_n for ne == 1 (llama/20092) --- ggml/src/ggml.c | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e9529fbb662..d644cca8a6e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1410,16 +1410,14 @@ static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) { } next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { - if (tensor->ne[i] != 1) { - if (i > n) { - if (tensor->nb[i] != next_nb) { - return false; - } - next_nb *= tensor->ne[i]; - } else { - // this dimension does not need to be contiguous - next_nb = tensor->ne[i]*tensor->nb[i]; + if (i > n) { + if (tensor->ne[i] != 1 && tensor->nb[i] != next_nb) { + return false; } + next_nb *= tensor->ne[i]; + } else { + // this dimension does not need to be contiguous + next_nb = tensor->ne[i]*tensor->nb[i]; } } return true; From 8d78d409460b56cc128c7a36c87c92c024408eef Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Thu, 5 Mar 2026 04:19:00 +0900 Subject: [PATCH 211/831] Add concat op to webgpu. (llama/20068) --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 55 ++++++++++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 67 +++++++++++++++++ ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl | 75 +++++++++++++++++++ 3 files changed, 197 insertions(+) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 369475eaf50..17c5e0fb51f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -173,6 +173,22 @@ struct ggml_webgpu_scale_pipeline_key_hash { } }; +/** Concat **/ + +struct ggml_webgpu_concat_pipeline_key { + int type; + + bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; } +}; + +struct ggml_webgpu_concat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + return seed; + } +}; + /** Binary **/ struct ggml_webgpu_binary_pipeline_key { @@ -403,6 +419,8 @@ class ggml_webgpu_shader_lib { pad_pipelines; // circular/non-circular std::unordered_map binary_pipelines; // type/op/inplace/overlap + std::unordered_map + concat_pipelines; // type std::unordered_map flash_attn_pipelines; std::unordered_maptype, + }; + + auto it = concat_pipelines.find(key); + if (it != concat_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "concat"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_I32: + defines.push_back("TYPE_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported type for concat shader"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_concat, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + concat_pipelines[key] = pipeline; + return concat_pipelines[key]; + } + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool has_mask = context.src3 != nullptr; const bool has_sinks = context.src4 != nullptr; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 19451618ec5..334919e589f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1484,6 +1484,68 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } +static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + uint32_t ne = (uint32_t) ggml_nelements(dst); + uint32_t dim = (uint32_t) dst->op_params[0]; + + std::vector params = { + ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + dim, + (uint32_t)src0->ne[dim] + }; + + std::vector entries = { + { + .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) + }, + { + .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) + }, + { + .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) + } + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); +} + static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { int inplace = ggml_webgpu_tensor_equal(src, dst); @@ -2068,6 +2130,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_MUL: case GGML_OP_DIV: return ggml_webgpu_binary_op(ctx, src0, src1, node); + case GGML_OP_CONCAT: + return ggml_webgpu_concat(ctx, src0, src1, node); case GGML_OP_RMS_NORM: return ggml_webgpu_rms_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -2894,6 +2958,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && (src1->type == op->type); break; + case GGML_OP_CONCAT: + supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32); + break; case GGML_OP_CPY: case GGML_OP_CONT: supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl new file mode 100644 index 00000000000..a22d245d2cc --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl @@ -0,0 +1,75 @@ +struct Params { + ne: u32, + + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src0_0: u32, + stride_src0_1: u32, + stride_src0_2: u32, + stride_src0_3: u32, + + stride_src1_0: u32, + stride_src1_1: u32, + stride_src1_2: u32, + stride_src1_3: u32, + + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, + + dim: u32, + src0_nedim: u32 +}; + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_I32 +#define DataType i32 +#endif + +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var src1 : array; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + + if (gid.x < params.ne) { + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + var ni = array(i0, i1, i2, i3); + + if (ni[params.dim] < params.src0_nedim) { + let src_i = ni[0] * params.stride_src0_0 + + ni[1] * params.stride_src0_1 + + ni[2] * params.stride_src0_2 + + ni[3] * params.stride_src0_3; + dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i]; + } else { + ni[params.dim] -= params.src0_nedim; + let src_i = ni[0] * params.stride_src1_0 + + ni[1] * params.stride_src1_1 + + ni[2] * params.stride_src1_2 + + ni[3] * params.stride_src1_3; + dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i]; + } + } +} From 4834971a4f62ba9e77d73a3261f666c5123ee0e0 Mon Sep 17 00:00:00 2001 From: Nikhil Jain Date: Wed, 4 Mar 2026 11:54:55 -0800 Subject: [PATCH 212/831] Fix wait logic for inflight jobs (llama/20096) * Enable tmate debugging for investigating thread safety issue * Refactor wait and submit to operate on vector, and fix wait to delete only the future that is completed. * Cleanup * Remove clear change and run clang-format * Cleanup --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 100 +++++++++++++++++---------- 1 file changed, 65 insertions(+), 35 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 334919e589f..b2ef2d59010 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -123,11 +123,6 @@ struct webgpu_pool_bufs { wgpu::Buffer dev_buf; }; -// The futures to wait on for a single queue submission -struct webgpu_submission_futures { - std::vector futures; -}; - // Holds a pool of parameter buffers for WebGPU operations struct webgpu_buf_pool { std::vector free; @@ -463,26 +458,60 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** End WebGPU object initializations */ /** WebGPU Actions */ +static void erase_completed(std::vector & futures) { + futures.erase(std::remove_if(futures.begin(), futures.end(), + [](const wgpu::FutureWaitInfo & info) { return info.completed; }), + futures.end()); +} // Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, - std::vector & futures, - bool block = true) { +static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, + std::vector & futures, + bool block = true) { // If we have too many in-flight submissions, wait on the oldest one first. + if (futures.empty()) { + return; + } uint64_t timeout_ms = block ? UINT64_MAX : 0; while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { - ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); - futures.erase(futures.begin()); + auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX); + if (waitStatus == wgpu::WaitStatus::Error) { + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); + } + if (futures[0].completed) { + futures.erase(futures.begin()); + } + } + + if (futures.empty()) { + return; } - size_t i = 0; - while (i < futures.size()) { - auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms); + + if (block) { + while (!futures.empty()) { + auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); + switch (waitStatus) { + case wgpu::WaitStatus::Success: + // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished. + erase_completed(futures); + break; + case wgpu::WaitStatus::Error: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); + break; + default: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); + break; + } + } + } else { + // Poll once and return + auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); switch (waitStatus) { case wgpu::WaitStatus::Success: - futures.erase(futures.begin() + i); + // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished. + erase_completed(futures); break; case wgpu::WaitStatus::TimedOut: - i++; break; case wgpu::WaitStatus::Error: GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); @@ -525,10 +554,11 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { } #endif -static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context ctx, - std::vector commands, - webgpu_buf_pool & param_buf_pool, - webgpu_buf_pool * set_rows_error_buf_pool = nullptr) { +static std::vector ggml_backend_webgpu_submit( + webgpu_global_context ctx, + std::vector commands, + webgpu_buf_pool & param_buf_pool, + webgpu_buf_pool * set_rows_error_buf_pool = nullptr) { std::vector command_buffers; std::vector params_bufs; std::vector set_rows_error_bufs; @@ -600,7 +630,7 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_contex futures.push_back({ f }); } #endif - return { futures }; + return futures; } static webgpu_command ggml_backend_webgpu_build_multi( @@ -727,8 +757,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); - std::vector futures = { ggml_backend_webgpu_submit(ctx, { command }, - ctx->memset_buf_pool) }; + auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool); ggml_backend_webgpu_wait(ctx, futures); } @@ -836,7 +865,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 binary_overlap_flags flags = {}; flags.inplace = ggml_webgpu_tensor_equal(src0, dst); flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); - flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); + flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); return flags; } @@ -1153,8 +1182,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, }; // Calculate workgroup dimensions - uint32_t wg_x = 1; - uint32_t wg_y = 1; + uint32_t wg_x = 1; + uint32_t wg_y = 1; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; if (use_fast && is_vec) { @@ -1410,7 +1439,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, uint32_t offset_merged_src0 = 0; uint32_t offset_merged_src1 = 0; if (flags.src_overlap) { - size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); + size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); } @@ -1419,7 +1448,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), offset_merged_src0, offset_merged_src1, (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), @@ -2185,9 +2214,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); - std::vector commands; - std::vector futures; - uint32_t num_batched_kernels = 0; + std::vector commands; + std::vector futures; + uint32_t num_batched_kernels = 0; for (int i = 0; i < cgraph->n_nodes; i++) { if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { commands.push_back(*cmd); @@ -2195,9 +2224,10 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str } if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { - num_batched_kernels = 0; - futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, - &ctx->set_rows_error_buf_pool)); + num_batched_kernels = 0; + std::vector compute_futures = ggml_backend_webgpu_submit( + ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool); + futures.insert(futures.end(), compute_futures.begin(), compute_futures.end()); // Process events and check for completed submissions ctx->global_ctx->instance.ProcessEvents(); ggml_backend_webgpu_wait(ctx->global_ctx, futures, false); @@ -2205,9 +2235,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str } } if (!commands.empty()) { - webgpu_submission_futures new_futures = + auto new_futures = ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool); - futures.push_back(new_futures); + futures.insert(futures.end(), new_futures.begin(), new_futures.end()); } ggml_backend_webgpu_wait(ctx->global_ctx, futures); From 2c50962528a7424931a74941cd215147c6357ebf Mon Sep 17 00:00:00 2001 From: lhez Date: Wed, 4 Mar 2026 21:32:26 -0800 Subject: [PATCH 213/831] opencl: add `SET`, support i32 for `CPY`, minor refactor for cpy (llama/20101) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 151 ++++++++++++++++++++++----- ggml/src/ggml-opencl/kernels/cpy.cl | 45 ++++++++ 2 files changed, 168 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0b9a021d204..a4403a5c273 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -416,7 +416,6 @@ struct ggml_backend_opencl_context { cl_program program_add; cl_program program_add_id; cl_program program_clamp; - cl_program program_cpy; cl_program program_cvt; cl_program program_diag_mask_inf; cl_program program_gelu; @@ -514,7 +513,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_set_rows_f32_i64, kernel_set_rows_f32_i32, kernel_set_rows_f16_i64, kernel_set_rows_f16_i32; cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16; - cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32; + cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32, kernel_cpy_i32_i32; cl_kernel kernel_mul_mat_f32_f32; cl_kernel kernel_mul_mat_f16_f16; cl_kernel kernel_mul_mat_f16_f32_1row; @@ -873,13 +872,14 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #else const std::string kernel_src = read_file("cpy.cl"); #endif - backend_ctx->program_cpy = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(prog, "kernel_cpy_f16_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(prog, "kernel_cpy_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(prog, "kernel_cpy_f32_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(prog, "kernel_cpy_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_i32_i32 = clCreateKernel(prog, "kernel_cpy_i32_i32", &err), err)); GGML_LOG_CONT("."); } @@ -3544,9 +3544,21 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te default: return false; } + case GGML_TYPE_I32: + switch (op->type) { + case GGML_TYPE_I32: + return true; + default: + return false; + } default: return false; } + case GGML_OP_SET: { + return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32) && + op->type == op->src[0]->type && + op->type == op->src[1]->type; + } case GGML_OP_SCALE: return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: @@ -10782,28 +10794,13 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const // GGML_OP_DUP and GGML_OP_CONT happen between src0 and dst. UNUSED(dst); - const int ne00 = src0 ? src0->ne[0] : 0; - const int ne01 = src0 ? src0->ne[1] : 0; - const int ne02 = src0 ? src0->ne[2] : 0; - const int ne03 = src0 ? src0->ne[3] : 0; - - const cl_ulong nb00 = src0 ? src0->nb[0] : 0; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const cl_ulong nb03 = src0 ? src0->nb[3] : 0; - - const int ne10 = src1 ? src1->ne[0] : 0; - const int ne11 = src1 ? src1->ne[1] : 0; - const int ne12 = src1 ? src1->ne[2] : 0; - const int ne13 = src1 ? src1->ne[3] : 0; - - const cl_ulong nb10 = src1 ? src1->nb[0] : 0; - const cl_ulong nb11 = src1 ? src1->nb[1] : 0; - const cl_ulong nb12 = src1 ? src1->nb[2] : 0; - const cl_ulong nb13 = src1 ? src1->nb[3] : 0; + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne1, src1, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb); - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type src0t = src0->type; + const enum ggml_type src1t = src1->type; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -10840,6 +10837,15 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const GGML_ASSERT(false && "not implemented"); } break; + case GGML_TYPE_I32: + switch (src1t) { + case GGML_TYPE_I32: + kernel = backend_ctx->kernel_cpy_i32_i32; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + break; default: GGML_ASSERT(false && "not implemented"); } @@ -10878,6 +10884,89 @@ static void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const UNUSED(src1); } +static void ggml_cl_set(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32) && + src1->type == src0->type && dst->type == src0->type); + + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne1, src1, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const cl_ulong pnb1 = ((const int32_t *)dst->op_params)[0]; + const cl_ulong pnb2 = ((const int32_t *)dst->op_params)[1]; + const cl_ulong pnb3 = ((const int32_t *)dst->op_params)[2]; + const cl_ulong offs = ((const int32_t *)dst->op_params)[3]; + const bool inplace = (bool)((const int32_t *)dst->op_params)[4]; + + cl_kernel kernel = nullptr; + + // for inplace case, dst is a view of src0 and is updated on top of it + // so for non-inplace case, copy src0 to dst first + if (!inplace) { + ggml_cl_cpy(backend, src0, dst, nullptr); + } + + // then copy src1 to dst with specified offset + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_cpy_f32_f32; + } else if (src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + kernel = backend_ctx->kernel_cpy_i32_i32; + } else { + GGML_ASSERT(false && "not implemented"); + } + + offsetd += offs; + cl_ulong nb = ggml_element_size(dst); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &pnb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &pnb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &pnb3)); + + int max_local_size = backend_ctx->get_kernel_workgroup_size(kernel); + + const int nth = MIN(max_local_size, ne00); + + size_t global_work_size[] = {(size_t)ne11*nth, (size_t)ne12, (size_t)ne13}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -11651,6 +11740,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_cpy; break; + case GGML_OP_SET: + if (!any_on_device) { + return false; + } + func = ggml_cl_set; + break; case GGML_OP_DUP: case GGML_OP_CONT: if (!any_on_device) { diff --git a/ggml/src/ggml-opencl/kernels/cpy.cl b/ggml/src/ggml-opencl/kernels/cpy.cl index 9369351a60c..820aa538a34 100644 --- a/ggml/src/ggml-opencl/kernels/cpy.cl +++ b/ggml/src/ggml-opencl/kernels/cpy.cl @@ -182,3 +182,48 @@ kernel void kernel_cpy_f32_f32( dst_data[i00] = src[0]; } } + +kernel void kernel_cpy_i32_i32( + global int * src0, + ulong offset0, + global int * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global int*)((global char*)src0 + offset0); + dst = (global int*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global int * dst_data = (global int *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const int * src = (global int *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} From 2e79b85f66b942432fb5d0a2648be5a38b711ab1 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 4 Mar 2026 21:55:29 -0800 Subject: [PATCH 214/831] hexagon: Flash Attention optimizations (dma, mpyacc, multi-row) and MatMul updates (llama/20118) * ggml-hexagon: enhance hvx_dot_f16_f16_aa_rx4 for improved performance by expanding vector handling and optimizing accumulation # Conflicts: # ggml/src/ggml-hexagon/htp/flash-attn-ops.c * ggml-hexagon: optimize hvx_dot_f16_f16_aa_rx4 and enhance hvx_vec_reduce_sum_f32x4 for improved performance and reduced complexity * ggml-hexagon: add hvx_dot_f16_f16_aa_rx32 for enhanced vector processing in flash attention # Conflicts: # ggml/src/ggml-hexagon/htp/flash-attn-ops.c * optimize hvx_dot_f16_f16_aa_rx4 and hvx_dot_f16_f16_aa_rx32 by removing unused scale parameter and improving vector accumulation # Conflicts: # ggml/src/ggml-hexagon/htp/flash-attn-ops.c * ggml-hexagon: refactor hvx_dot_f16_f16_aa_rx4 for improved readability and return HVX_Vector for better integration # Conflicts: # ggml/src/ggml-hexagon/htp/flash-attn-ops.c * ggml-hexagon: initialize sums variable in hvx_dot_f16_f16_aa_rx32 for clarity * ggml-hexagon: fix compiling error * fix hvx_dot_f16_f16_aa_rx4 to handle leftover elements correctly using masking * refactor hvx_dot_f16_f16_aa_rx4 to accept vector and leftover element counts as parameters for improved clarity and flexibility * wip * fa: instrumentation and dma reordering * hex-fa: use block-size 64 to improve DMA pipelining * hex-fa: optimize vec-dot for v79 and above * hex-fa: use block size 64 * hex-fa: avoid scalar fp32->fp16 conversions * hex-fa: simplify dot_f16 functions using optimized vec_mpyacc * hex-fa: rewrite mad_f32_f16 using hvx_vec_mpyacc * hex-mm: use mpyacc in matmul dot functions --------- Co-authored-by: chraac --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 367 ++++++++++++--------- ggml/src/ggml-hexagon/htp/hvx-base.h | 21 +- ggml/src/ggml-hexagon/htp/hvx-copy.h | 4 +- ggml/src/ggml-hexagon/htp/hvx-reduce.h | 30 ++ ggml/src/ggml-hexagon/htp/matmul-ops.c | 88 ++--- 5 files changed, 291 insertions(+), 219 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 74c777d4c3e..6dc978dd68a 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -10,6 +10,7 @@ #include "hex-dma.h" #include "hvx-utils.h" +#include "hvx-dump.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -17,6 +18,16 @@ #include "htp-msg.h" #include "htp-ops.h" +// Must be multiple of 32 +#define FLASH_ATTN_BLOCK_SIZE (32 * 2) + +// This is a bit of a hack because the compiler is strugling to properly inline +// the default hvx_vec_f32_to_f16 with output into the local array. +static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) +{ + *(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1); +} + // Dot product of two F16 vectors, accumulating to float static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 @@ -25,175 +36,184 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; #pragma unroll(4) for (i = 0; i < nvec; i++) { - HVX_Vector y_hf = vy[i]; - HVX_Vector x_hf = vx[i]; - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]); } if (nloe) { - // Load x (fp16) and zero-out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf); } - rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); - hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); + HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum))); + hvx_vec_store_u(r, 4, rsum); } -static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, - const void * restrict y, - const void * restrict x0, - const void * restrict x1, - unsigned int n, - float s) { - const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 - const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 - - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements - - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); +static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y, + const uint8_t * restrict x, + const size_t stride_x, + const size_t nvec, + const size_t nloe) { + const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16 + const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16 + const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16 + const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16 + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 + + HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; - #pragma unroll(4) for (i = 0; i < nvec; i++) { HVX_Vector y_hf = vy[i]; HVX_Vector x0_hf = vx0[i]; HVX_Vector x1_hf = vx1[i]; + HVX_Vector x2_hf = vx2[i]; + HVX_Vector x3_hf = vx3[i]; - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf); + rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf); + rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf); } if (nloe) { // Load x (fp16) and zero-out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); - HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); - HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]); + HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]); + + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf); + rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf); + rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf); + } + + HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p))); + HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p))); + HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p))); + HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p))); - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } }; + return hvx_vec_reduce_sum_f32x4(rsum0123); +} - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); +static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y, + const uint8_t * restrict x, + const size_t stride_x, + const size_t n, + float s) { + + const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + const size_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector sums; // initialize at j = 0 + const size_t stride_x_4 = stride_x * 4; + for (uint32_t j = 0; j < VLEN_FP32; j += 4) { + HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe); + HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32); + sums = Q6_V_vmux_QVV(pred, sums, sums_x4); + x += stride_x_4; } - HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); - hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); + sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums); + return Q6_Vsf_equals_Vqf32(sums); } -// MAD: y (F32) += x (F16) * s (F32) -static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { - const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; - HVX_Vector * restrict ptr_y = (HVX_Vector *) y; +// MAD: y (F32) += x (F16) * s (F16) +static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, int n) { + const HVX_Vector * restrict vx0 = (const HVX_Vector *) x; + + HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y; + HVX_Vector * restrict vy = (HVX_Vector *) y; uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector S = hvx_vec_splat_f16(s); + HVX_Vector S0 = hvx_vec_splat_f16(*s); uint32_t i = 0; - #pragma unroll(4) + + #pragma unroll(2) for (i = 0; i < nvec; ++i) { - // Multiply x * s -> pair of F32 vectors - HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); - ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2])); - ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1])); + vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0); } if (nloe) { - HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); + HVX_VectorPair xy_p = vy_p[i]; + xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0); - HVX_Vector xs = Q6_V_lo_W(xs_p); - i = 2 * i; // index for ptr_y + HVX_Vector xy = Q6_V_lo_W(xy_p); + i = 2 * i; // index for vy - if (nloe >= 32) { - ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p); + if (nloe >= VLEN_FP32) { + vy[i] = xy; + nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p); } if (nloe) { - HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - hvx_vec_store_a(&ptr_y[i], nloe * 4, xy); + hvx_vec_store_a(&vy[i], nloe * 4, xy); } } } -// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32) -static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, - const void * restrict x0, - const void * restrict x1, - float s0, - float s1, - int n) { - const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0; - const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1; - HVX_Vector * restrict ptr_y = (HVX_Vector *) y; +// MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16) +static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1, + const __fp16 * restrict s0, const __fp16 * restrict s1, int n) { + const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0; + const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1; + + HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y; + HVX_Vector * restrict vy = (HVX_Vector *) y; uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector S0 = hvx_vec_splat_f16(s0); - HVX_Vector S1 = hvx_vec_splat_f16(s1); + HVX_Vector S0 = hvx_vec_splat_f16(*s0); + HVX_Vector S1 = hvx_vec_splat_f16(*s1); uint32_t i = 0; + #pragma unroll(2) for (i = 0; i < nvec; ++i) { - // Multiply x * s -> pair of F32 vectors - HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0); - HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1); - - HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p)); - HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p)); - - ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2])); - ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1])); + vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0); + vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1); } if (nloe) { - HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0); - HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1); + HVX_VectorPair xy_p = vy_p[i]; + xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0); + xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1); - HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p)); - HVX_Vector xs = xs_p_lo; - i = 2 * i; // index for ptr_y + HVX_Vector xy = Q6_V_lo_W(xy_p); + i = 2 * i; // index for vy - if (nloe >= 32) { - ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - nloe -= 32; ++i; - xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p)); + if (nloe >= VLEN_FP32) { + vy[i] = xy; + nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p); } if (nloe) { - HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - hvx_vec_store_a(&ptr_y[i], nloe * 4, xy); + hvx_vec_store_a(&vy[i], nloe * 4, xy); } } } -#define FLASH_ATTN_BLOCK_SIZE 128 - struct htp_fa_context { const struct htp_ops_context * octx; @@ -226,7 +246,12 @@ struct htp_fa_context { size_t size_v_block; size_t size_m_block; + uint32_t qrows; + uint32_t qrows_per_thread; + bool is_q_fp32; + + uint64_t t_start; }; static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) { @@ -296,9 +321,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const uint32_t nb3 = dst->nb[3]; // total rows in q - const uint32_t nr = neq1*neq2*neq3; - - const uint32_t dr = (nr + nth - 1) / nth; + const uint32_t nr = factx->qrows; + const uint32_t dr = factx->qrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = MIN(ir0 + dr, nr); @@ -337,15 +361,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3); dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1); - const uint32_t h = iq2; // head index - const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f; - - HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); - HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); - - // Clear accumulator - hvx_splat_f32_a(spad_a, 0, DV); - float * VKQ32 = (float *) spad_a; + // FARF(HIGH, "fa %u: prefetch Q: ir %u iq1 %u iq2 %u iq3 %u q_row_ptr %p size %u : usec %u", ith, ir, iq1, iq2, iq3, q_row_ptr, size_q_row, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); const __fp16 * mp_base = NULL; if (mask) { @@ -376,8 +393,23 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * // Mask is 1D contiguous for this row dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); } + + // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", + // ith, ir, ib, iq1, iq2, iq3, + // size_k_row, size_v_row, current_block_size, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); } + const uint32_t h = iq2; // head index + const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f; + + HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); + HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); + + // Clear accumulator + hvx_splat_f32_a(spad_a, 0, DV); + float * VKQ32 = (float *) (spad_a + 0); + uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; if (factx->is_q_fp32) { hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16 @@ -393,23 +425,19 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * uint8_t * v_base = dma_queue_pop(dma).dst; // V __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M + // FARF(HIGH, "fa %u: process: ir %u ib %u : iq1 %u iq2 %u iq3 %u q_ptr_vtcm %p : usec %u", + // ith, ir, ib, iq1, iq2, iq3, q_ptr_vtcm, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); + // Inner loop processing the block from VTCM uint32_t ic = 0; - // Process in blocks of 32 (VLEN_FP32) - static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); - HVX_Vector_x4 scores_x4; + // Process in sub-blocks of 32 (VLEN_FP32) + HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32]; HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores - float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; - for (uint32_t j = 0; j < VLEN_FP32; j += 2) { - const uint32_t cur_ic = ic + j; - const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded; - hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale); - } - - HVX_Vector scores = *(HVX_Vector *) scores_arr; + HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale); // 2. Softcap if (factx->logit_softcap != 0.0f) { @@ -428,35 +456,35 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * scores = Q6_Vsf_equals_Vqf32(scores); } - scores_x4.v[iv] = scores; + sb_scores[iv] = scores; v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max } { // 4. Online Softmax Update HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec); - HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec); - HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec)); + HVX_Vector diff_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec)); + HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); M_vec = M_new_vec; hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) { - HVX_Vector scores = scores_x4.v[iv]; + HVX_Vector scores = sb_scores[iv]; HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec); HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P)); // 5. Accumulate V - float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; - *(HVX_Vector *) p_arr = P; + __fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16]; + hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0)); for (uint32_t j = 0; j < VLEN_FP32; j += 2) { const uint32_t cur_ic = ic2 + j; const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; - hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV); + hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV); } } @@ -464,47 +492,50 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec)); } - // Sync scalars for leftover/next block if needed - float M = hvx_vec_get_f32(M_vec); - float S = hvx_vec_get_f32(S_vec); + if (ic < current_block_size) { + // Sync scalars for leftover/next block if needed + float M = hvx_vec_get_f32(M_vec); + float S = hvx_vec_get_f32(S_vec); + + // Leftover + for (; ic < current_block_size; ++ic) { + float s_val; + const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded; + hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale); + if (factx->logit_softcap != 0.0f) { + s_val = factx->logit_softcap * tanhf(s_val); + } - // Leftover - for (; ic < current_block_size; ++ic) { - float s_val; - const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded; - hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale); - if (factx->logit_softcap != 0.0f) { - s_val = factx->logit_softcap * tanhf(s_val); - } + if (mask) { + const float m_val = m_base[ic]; + s_val += slope * m_val; + } - if (mask) { - const float m_val = m_base[ic]; - s_val += slope * m_val; - } + const float Mold = M; + __fp16 vs = 1.0f; + + if (s_val > M) { + M = s_val; + HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M); + HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); + hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); + + float ms = hvx_vec_get_f32(ms_vec); + S = S * ms + vs; + } else { + HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M); + vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); + S += vs; + } - const float Mold = M; - float vs = 1.0f; - - if (s_val > M) { - M = s_val; - HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M); - HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); - hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); - - float ms = hvx_vec_get_f32(ms_vec); - S = S * ms + vs; - } else { - HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M); - vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); - S += vs; - } + const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded; - const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV); + } - hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs); + M_vec = hvx_vec_splat_f32(M); + S_vec = hvx_vec_splat_f32(S); } - M_vec = hvx_vec_splat_f32(M); - S_vec = hvx_vec_splat_f32(S); // Issue DMA for next+1 block (if exists) if (ib + 2 < factx->n_blocks) { @@ -525,6 +556,11 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start); dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); } + + // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", + // ith, ir, next_ib, iq1, iq2, iq3, + // size_k_row, size_v_row, next_block_size, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); } } @@ -586,6 +622,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { struct htp_fa_context factx; factx.octx = octx; + factx.t_start = HAP_perf_get_qtimer_count(); + factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); factx.src0_div1 = init_fastdiv_values(q->ne[1]); @@ -632,6 +670,15 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2); factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); + // total rows in q + const uint32_t neq0 = q->ne[0]; + const uint32_t neq1 = q->ne[1]; + const uint32_t neq2 = q->ne[2]; + const uint32_t neq3 = q->ne[3]; + + factx.qrows = neq1*neq2*neq3; + factx.qrows_per_thread = (factx.qrows + octx->n_threads - 1) / octx->n_threads; + size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 octx->src0_spad.size_per_thread = size_q_block * 1; diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index 12a1b7f1288..701637f22b2 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -38,7 +38,7 @@ static inline HVX_Vector hvx_vec_splat_f32(float v) { return Q6_V_vsplat_R(u.i); } -static inline HVX_Vector hvx_vec_splat_f16(float v) { +static inline HVX_Vector hvx_vec_splat_f16(_Float16 v) { union { __fp16 f; uint16_t i; } u = { .f = v }; return Q6_Vh_vsplat_R(u.i); } @@ -170,4 +170,23 @@ static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) { return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0); } +#if __HVX_ARCH__ < 79 + +static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vector x, HVX_Vector y) +{ + HVX_VectorPair m = Q6_Wqf32_vmpy_VhfVhf(x, y); + HVX_Vector a0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(m), Q6_V_lo_W(acc))); + HVX_Vector a1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(m), Q6_V_hi_W(acc))); + return Q6_W_vcombine_VV(a1, a0); +} + +#else + +static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vector x, HVX_Vector y) +{ + return Q6_Wsf_vmpyacc_WsfVhfVhf(acc, x, y); +} + +#endif + #endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h index ae0dbed0306..851482e01b2 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-copy.h +++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h @@ -42,11 +42,11 @@ static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float)); } -static inline void hvx_splat_f16_a(uint8_t * restrict dst, float v, uint32_t n) { +static inline void hvx_splat_f16_a(uint8_t * restrict dst, _Float16 v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); } -static inline void hvx_splat_f16_u(uint8_t * restrict dst, float v, uint32_t n) { +static inline void hvx_splat_f16_u(uint8_t * restrict dst, _Float16 v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); } diff --git a/ggml/src/ggml-hexagon/htp/hvx-reduce.h b/ggml/src/ggml-hexagon/htp/hvx-reduce.h index 1ca7c05d983..3c0073ef6d8 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-reduce.h +++ b/ggml/src/ggml-hexagon/htp/hvx-reduce.h @@ -46,6 +46,21 @@ static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) { #if __HVX_ARCH__ > 75 +static inline HVX_Vector hvx_vec_reduce_sum_f32x4(HVX_Vector_x4 in) { + HVX_VectorPair sum_p01 = Q6_W_vshuff_VVR(in.v[1], in.v[0], 4); + HVX_VectorPair sum_p23 = Q6_W_vshuff_VVR(in.v[3], in.v[2], 4); + HVX_Vector sum_sf01 = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p01), Q6_V_hi_W(sum_p01)); + HVX_Vector sum_sf23 = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p23), Q6_V_hi_W(sum_p23)); + + HVX_VectorPair sum_p0123 = Q6_W_vshuff_VVR(sum_sf23, sum_sf01, 8); + HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p0123), Q6_V_hi_W(sum_p0123)); + + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8)); + return sum_sf; +} + static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); @@ -72,6 +87,21 @@ static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) #else +static inline HVX_Vector hvx_vec_reduce_sum_f32x4(HVX_Vector_x4 in) { + HVX_VectorPair sum_p01 = Q6_W_vshuff_VVR(in.v[1], in.v[0], 4); + HVX_VectorPair sum_p23 = Q6_W_vshuff_VVR(in.v[3], in.v[2], 4); + HVX_Vector sum_qf01 = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p01), Q6_V_hi_W(sum_p01)); + HVX_Vector sum_qf23 = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p23), Q6_V_hi_W(sum_p23)); + + HVX_VectorPair sum_p0123 = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(sum_qf23), Q6_Vsf_equals_Vqf32(sum_qf01), 8); + HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p0123), Q6_V_hi_W(sum_p0123)); + + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8)); + return Q6_Vsf_equals_Vqf32(sum_qf); +} + static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 6f6f51f01f5..9ca74aedfef 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -1234,27 +1234,24 @@ static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; #pragma unroll(4) for (i = 0; i < nvec; i++) { - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]); } if (nloe) { HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]); HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf); } - rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); - hvx_vec_store_u(&s[0], 4, rsum); + HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p))); + hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum)); } static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, @@ -1267,35 +1264,30 @@ static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, uint32_t nvec = n / VLEN_FP16; uint32_t nloe = n % VLEN_FP16; - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); + HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; #pragma unroll(2) for (i = 0; i < nvec; i++) { HVX_Vector y_hf = y[i]; - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf); - - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); - rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf); } if (nloe) { HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]); HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]); - HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); - - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); - rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf); } - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1)); + HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p))); + HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p))); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1); hvx_vec_store_u(s0, 8, rsum); } @@ -1311,10 +1303,10 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res uint32_t nloe = n % VLEN_FP16; // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + HVX_VectorPair r0_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair r0_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair r1_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair r1_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; @@ -1326,20 +1318,10 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res HVX_Vector c1_hf = y1[i]; // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 - HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); - HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); - HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); - HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); - - HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); - HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); - HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); - HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); + r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf); + r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf); + r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf); + r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf); } if (nloe) { @@ -1350,23 +1332,17 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]); HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]); - HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); - HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); - HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); - HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); - - HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); - HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); - HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); - HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); - + r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf); + r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf); + r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf); + r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf); } + HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p))); + HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p))); + HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p))); + HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p))); + // Reduce and store results HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); From 67abc63e9d98ff08411149aab7af306481be8979 Mon Sep 17 00:00:00 2001 From: Marcel Petrick Date: Thu, 5 Mar 2026 08:50:21 +0100 Subject: [PATCH 215/831] chore : correct typos [no ci] (llama/20041) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(docs): correct typos found during code review Non-functional changes only: - Fixed minor spelling mistakes in comments - Corrected typos in user-facing strings - No variables, logic, or functional code was modified. Signed-off-by: Marcel Petrick * Update docs/backend/CANN.md Co-authored-by: Aaron Teo * Revert "Auxiliary commit to revert individual files from 846d1c301281178efbc6ce6060ad34c1ebe45af8" This reverts commit 02fcf0c7db661d5ff3eff96b2b2db9fdb7213256. * Update tests/test-backend-ops.cpp Co-authored-by: Sigbjørn Skjæret * Update tests/test-backend-ops.cpp Co-authored-by: Sigbjørn Skjæret --------- Signed-off-by: Marcel Petrick Co-authored-by: Aaron Teo Co-authored-by: Sigbjørn Skjæret --- ggml/include/ggml-backend.h | 2 +- ggml/include/ggml-opt.h | 2 +- ggml/include/ggml.h | 2 +- ggml/src/ggml-cpu/amx/mmq.cpp | 6 ++--- ggml/src/ggml-cpu/arch/arm/quants.c | 2 +- ggml/src/ggml-cpu/arch/arm/repack.cpp | 4 ++-- ggml/src/ggml-cpu/arch/x86/repack.cpp | 32 ++++++++++++------------- ggml/src/ggml-cpu/ggml-cpu.c | 2 +- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 4 ++-- ggml/src/ggml-cpu/ops.cpp | 4 ++-- ggml/src/ggml-cpu/repack.cpp | 4 ++-- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 2 +- ggml/src/ggml-cuda/fattn-vec.cuh | 2 +- ggml/src/ggml-cuda/fattn-wmma-f16.cuh | 2 +- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- ggml/src/ggml-cuda/quantize.cu | 2 +- ggml/src/ggml-cuda/softmax.cu | 2 +- ggml/src/ggml-cuda/solve_tri.cu | 2 +- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 22 ++++++++--------- ggml/src/ggml-hexagon/htp-drv.cpp | 2 +- ggml/src/ggml-hexagon/htp/hvx-inverse.h | 2 +- ggml/src/ggml-hexagon/htp/rope-ops.c | 2 +- ggml/src/ggml-hexagon/htp/worker-pool.c | 2 +- ggml/src/ggml-metal/ggml-metal-device.m | 2 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 6 ++--- ggml/src/ggml-metal/ggml-metal.cpp | 2 +- ggml/src/ggml-metal/ggml-metal.metal | 2 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 10 ++++---- ggml/src/ggml-sycl/common.hpp | 8 +++---- ggml/src/ggml-sycl/quants.hpp | 2 +- ggml/src/ggml-sycl/softmax.cpp | 2 +- ggml/src/ggml-vulkan/CMakeLists.txt | 2 +- 32 files changed, 72 insertions(+), 72 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index a9d1778641e..9fd3f7f32a0 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -259,7 +259,7 @@ extern "C" { Example usage: // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned - // preferrably to run on the same backend as the buffer + // preferably to run on the same backend as the buffer ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true); diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h index 4703a05afe1..1c2ed79b774 100644 --- a/ggml/include/ggml-opt.h +++ b/ggml/include/ggml-opt.h @@ -138,7 +138,7 @@ extern "C" { GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params); GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx); - // set gradients to zero, initilize loss, and optionally reset the optimizer + // set gradients to zero, initialize loss, and optionally reset the optimizer GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index fcc51f1f71a..784d69206b4 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2575,7 +2575,7 @@ extern "C" { struct ggml_tensor * grad, struct ggml_tensor * sgd_params); // alpha, weight decay - // build forward mutiple tensors and select one of them for computing + // build forward multiple tensors and select one of them for computing // this is useful for creating graphs that have constant topology but compute different things based on the input // ref: https://github.com/ggml-org/llama.cpp/pull/18550 // diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index b5aca76633c..93a6d397f79 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -195,7 +195,7 @@ struct tile_config_t{ // will be needed. // // Here another commonly used pattern 1-3-3 is skipped, as it is mostly used when m <=16; -// and the sinlge batch gemm (m=1) has a special fast path with `avx512-vnni`. +// and the single batch gemm (m=1) has a special fast path with `avx512-vnni`. // // ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/ // advanced-matrix-extensions-intrinsics-functions.html @@ -1379,8 +1379,8 @@ struct tinygemm_kernel_vnni 4 #if _WIN32_WINNT >= 0x0602 diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index da412fd009b..5fd452a03d2 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -533,7 +533,7 @@ class tinyBLAS { if constexpr (RN > 1) { return mnpack(m, n, SIZE_N, BN); } else { - GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N); + GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N); GGML_ASSERT(false); // we have miss something. } } @@ -711,7 +711,7 @@ class tinyBLAS_RVV { if constexpr (RN > 1) { return mnpack(m, n, SIZE_N, BN); } else { - GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N); + GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N); GGML_ASSERT(false); // we have miss something. } } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index b7a70e06f1d..ca1b3059b8c 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -375,7 +375,7 @@ static void ggml_compute_forward_dup_bytes( const size_t rs = ne00 * type_size; if (nb00 == type_size) { - // src0 is contigous on first dimension, copy by rows + // src0 is contiguous on first dimension, copy by rows for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { id += rs * ir0; @@ -1795,7 +1795,7 @@ void ggml_compute_forward_repeat( { ggml_compute_forward_repeat_f32(params, dst); } break; - // TODO: templateify the implemenation and support for I64 + // TODO: templateify the implementation and support for I64 // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225 //case GGML_TYPE_I64: // { diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 5edba4212f6..02c3cc3119b 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -3032,7 +3032,7 @@ template src[1])); - size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc. + size = GGML_PAD(size, sizeof(int64_t)); // + padding for next block. const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert const int64_t ne12 = op->src[1]->ne[2]; // n_tokens @@ -3297,7 +3297,7 @@ template wdata; auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t)); - // total of [n_as][ne12 + 1] elemets of type mmid_row_mapping (2*int32_t = int64_t) + // total of [n_as][ne12 + 1] elements of type mmid_row_mapping (2*int32_t = int64_t) auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12] diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index beb7e32e4fc..fff70c8eb89 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1215,7 +1215,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } // If attention sinks are used, potentially re-scale if KQ_max is small. - // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum + // Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum // so it's being done unconditionally for every thread. if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) { float KQ_max_scale[cols_per_thread]; diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 3f4a78cc6e5..7cbe32633e5 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -10,7 +10,7 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { return 128; } -// Currenlty llvm with the amdgcn target does not support unrolling loops +// Currently llvm with the amdgcn target does not support unrolling loops // that contain a break that can not be resolved at compile time. #ifdef __clang__ #pragma clang diagnostic push diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index cd3bfd4051a..aaf711a618c 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -18,7 +18,7 @@ #if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 #define GGML_USE_WMMA_FATTN #elif defined(RDNA4) -#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance" +#warning "rocwmma fattn is not supported on RDNA4 on rocwmma < v2.0.0, expect degraded performance" #endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 #endif // defined(GGML_HIP_ROCWMMA_FATTN) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 7e6d3303549..b56e3d50f58 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3330,7 +3330,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return false; } - //rms_norm kernel assumes contigous rows + //rms_norm kernel assumes contiguous rows if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { return false; } diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index a8c68e44b16..4300ffc148c 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -235,7 +235,7 @@ static __global__ void quantize_mmq_q8_1( q.z = roundf(xi.z*d_inv); q.w = roundf(xi.w*d_inv); - // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth: + // Write back 4 int8 values as a single 32 bit value for better memory bandwidth: char4 * yqs4 = (char4 *) y[ib].qs; yqs4[iqs/4] = q; diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index dc06d06930e..285c0e9543a 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -46,7 +46,7 @@ struct soft_max_params { }; // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. -// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. +// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here. #ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wpass-failed" diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu index 177ffc268f1..07ca33f513b 100644 --- a/ggml/src/ggml-cuda/solve_tri.cu +++ b/ggml/src/ggml-cuda/solve_tri.cu @@ -83,7 +83,7 @@ static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx, // ====================== // When ncols_template == 0 the bounds for the loops in this function are not // known and can't be unrolled. As we want to keep pragma unroll for all other -// cases we supress the clang transformation warning here. +// cases we suppress the clang transformation warning here. #ifdef __clang__ # pragma clang diagnostic push # pragma clang diagnostic ignored "-Wpass-failed" diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 7a44443a8a3..3006e217796 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -139,7 +139,7 @@ struct ggml_hexagon_session { }; void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) { - // Bump pending flag (cleared in the session::flush once we get the responce) + // Bump pending flag (cleared in the session::flush once we get the response) this->op_pending++; // atomic inc int err = dspqueue_write(this->queue, @@ -443,7 +443,7 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Repack the scales ggml_half * d = (ggml_half *) (y_d + i * dblk_size); @@ -503,7 +503,7 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); @@ -552,7 +552,7 @@ static void init_row_q4x4x2(block_q4_0 * x, int64_t k) { // Init the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales x[i * 8 + 0].d = 0; @@ -770,7 +770,7 @@ static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) { // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Repack the scales ggml_half * d = (ggml_half *) (y_d + i * dblk_size); @@ -829,7 +829,7 @@ static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) { // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); @@ -878,7 +878,7 @@ static void init_row_q8x4x2(block_q8_0 * x, int64_t k) { // Init the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales x[i * 8 + 0].d = 0; @@ -1120,7 +1120,7 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Repack the scales uint8_t * e = (uint8_t *) (y_e + i * eblk_size); @@ -1180,7 +1180,7 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales const uint8_t * e = (const uint8_t *) (y_e + i * eblk_size); @@ -1229,7 +1229,7 @@ static void init_row_mxfp4x4x2(block_mxfp4 * x, int64_t k) { // Init the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales x[i * 8 + 0].e = 0; @@ -2670,7 +2670,7 @@ static std::vector ggml_hexagon_graph_optimize_reorder(const std::vectorn_jobs); unsigned int i = atomic_fetch_add(&pool->next_job, 1); if (i >= n) { - // Spurios wakeup + // Spurious wakeup continue; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 3db7f126291..4cce414abfe 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1281,7 +1281,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te bool use_residency_sets; // optional MTLResidencySet - // note: cannot use explicity "id" here because it is not available on certain OSes + // note: cannot use explicitly "id" here because it is not available on certain OSes id rset; // pointers to global device diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 3d5db0b79f5..b3390352ffc 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -631,7 +631,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { const bool inplace = (bool) ((const int32_t *) op->op_params)[4]; if (!inplace) { - // run a separete kernel to cpy src->dst + // run a separate kernel to cpy src->dst // not sure how to avoid this // TODO: make a simpler cpy_bytes kernel @@ -1644,7 +1644,7 @@ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { const bool inplace = (bool) ((const int32_t *) op->op_params)[4]; if (!inplace) { - // run a separete kernel to cpy src->dst + // run a separate kernel to cpy src->dst // not sure how to avoid this // TODO: make a simpler cpy_bytes kernel @@ -2005,7 +2005,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup int16_t r1ptg = 4; // num src1 rows per threadgroup - // note: not sure how optimal are those across all different hardware. there might be someting cleverer + // note: not sure how optimal are those across all different hardware. there might be something cleverer switch (ne11) { case 2: r1ptg = 2; break; diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 1c705362fb7..9382ce53b36 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -14,7 +14,7 @@ #define GGML_METAL_MAX_DEVICES 16 // number of Metal devices -// note: can be overriden with GGML_METAL_DEVICES env to simulate virtual devices +// note: can be overridden with GGML_METAL_DEVICES env to simulate virtual devices static int g_devices = 1; //////////////////////////////////////////////////////////////////////////////// diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 6c349aa0c92..a58e641ad86 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4218,7 +4218,7 @@ kernel void kernel_im2col( template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; -// TODO: obolete -- remove +// TODO: obsolete -- remove //typedef void (im2col_ext_t)( // constant ggml_metal_kargs_im2col & args, // device const float * x, diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index a4403a5c273..7af032ce0e1 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -313,7 +313,7 @@ struct ProfilingInfo { cl_ulong cmd_duration_ns; // The time for the kernel to complete - COMPLETE - END cl_ulong cmd_complete_duration_ns; - // Total time to finish the kernel - COMPELTE - QUEUED + // Total time to finish the kernel - COMPLETE - QUEUED cl_ulong cmd_total_duration_ns; // Global and local work sizes. size_t global_size[3]; @@ -2555,7 +2555,7 @@ static std::vector ggml_opencl_probe_devices(ggml_backend_r cl_platform_id platform_ids[NPLAT]; if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) { - GGML_LOG_ERROR("ggml_opencl: plaform IDs not available.\n"); + GGML_LOG_ERROR("ggml_opencl: platform IDs not available.\n"); return found_devices; } @@ -3339,7 +3339,7 @@ static void ggml_backend_opencl_synchronize(ggml_backend_t backend) { CL_CHECK(clReleaseEvent(evt)); } -// Syncronizes the 'backend_ctx's device with others so that commands +// Synchronizes the 'backend_ctx's device with others so that commands // enqueued to it won't start until commands in the other devices have // completed. static void sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) { @@ -3997,7 +3997,7 @@ struct ggml_backend_opencl_buffer_context { // The buffer_context is initially created by ggml_backend_buft_alloc_buffer // before any tensor is initialized (at the beginning of alloc_tensor_range). - // Hence, there is alway a buffer object in this vector. When each tensor is + // Hence, there is always a buffer object in this vector. When each tensor is // being initialized, this original buffer object will be released if both // flattening and small allocation are enabled, and additional buffer // objects will be created in init_tensor to represent flattened quantized @@ -4132,7 +4132,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, //GGML_ASSERT(offset == 0); // We create subbuffers from the original tensor buffer for scales and - // quants - i.e., scales and quants are aliases into the buffer obejct + // quants - i.e., scales and quants are aliases into the buffer object // that backs the original tensor. This is a cleaner way to adapt to the // new memory management. // In the old code, we allocate new buffers for scales and quants diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 519638fd416..04c9e1d7864 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -76,10 +76,10 @@ extern int g_ggml_sycl_prioritize_dmmv; #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP -#define VER_4VEC 610 // todo for hardward optimize. -#define VER_GEN9 700 // todo for hardward optimize. -#define VER_GEN12 1000000 // todo for hardward optimize. -#define VER_GEN13 (VER_GEN12 + 1030) // todo for hardward optimize. +#define VER_4VEC 610 // todo for hardware optimize. +#define VER_GEN9 700 // todo for hardware optimize. +#define VER_GEN12 1000000 // todo for hardware optimize. +#define VER_GEN13 (VER_GEN12 + 1030) // todo for hardware optimize. #define GGML_SYCL_MAX_NODES 8192 // TODO: adapt to hardwares diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index d0d5ac9a4e8..14490fea5be 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -29,7 +29,7 @@ namespace ggml_sycl_reordered { // [qs0, qs1, qs2, ..., qsN] [d0, d1, d2, ..., dN] // // Notes: out-of-bounds qs will run into d values -// Aligment relies on the allocated size of qs +// Alignment relies on the allocated size of qs template struct block_q_t; diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index b41124acc13..15d92e5e04c 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -37,7 +37,7 @@ struct soft_max_params { }; // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. -// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. +// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here. #ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wpass-failed" diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index de01336cd3f..715a263a6d0 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -90,7 +90,7 @@ if (Vulkan_FOUND) target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build - # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector + # Possibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0) endif() From 51f397c1af380ec1101edc4364d5de15a6ef963e Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger <47689530+aendk@users.noreply.github.com> Date: Thu, 5 Mar 2026 12:53:21 +0100 Subject: [PATCH 216/831] CUDA: Improve performance via less synchronizations between token (llama/17795) * Adds CPU-to-CUDA copy capability to ggml_backend_cuda_cpy_tensor_async() * Adds function to relax sync requirements between input copies on supported backends (CUDA for now) * Exchanges synchronous copy with async copy function. * Adds macro guards to allow compilation in non-CUDA builds * Reworked backend detection in ggml-backend.cpp to avoid linking conflicts * Relax requirement of checks in async CUDA copies from backend and buffer type to just buffer type, to avoid linking issues * Minor cleanup * Makes opt-in to relax use of explicit syncs more general. Backends like vulkan which require a synchronization between HtoD copies and graph execution could also adopt this change now. * Reintroduces stricter check for CPU->CUDA backend async copy via GGML_DEVICE_TYPE_CPU. * Corrects initialization of ggml_backend_sync_mode in ggml_backend_sched_split initialization * Simplifies synchronizations to adhere to `saaasg` pattern. * Apply suggestion from @ggerganov (src->buffer to buf_src) Co-authored-by: Georgi Gerganov * Apply suggestion from @ggerganov (src->buffer to buf_src) v2 Co-authored-by: Georgi Gerganov --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-backend.cpp | 14 +++++++++----- ggml/src/ggml-cuda/ggml-cuda.cu | 14 ++++++++++---- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 22c656996cc..bc57df20ba2 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1455,6 +1455,10 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s int split_backend_id = split->backend_id; ggml_backend_t split_backend = sched->backends[split_backend_id]; + if (sched->events[split_backend_id][sched->cur_copy] == NULL) { + ggml_backend_synchronize(split_backend); + } + // copy the input tensors to the split backend for (int input_id = 0; input_id < split->n_inputs; input_id++) { ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]); @@ -1465,16 +1469,12 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); - } else { - ggml_backend_synchronize(split_backend); } - ggml_backend_tensor_copy(input, input_cpy); + ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy); } else { // wait for the split backend to finish using the input before overwriting it if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]); - } else { - ggml_backend_synchronize(split_backend); } // when offloading MoE weights, we can reduce the amount of data copied by copying only the experts that are used @@ -1578,6 +1578,10 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } } + if (sched->events[split_backend_id][sched->cur_copy] == NULL) { + ggml_backend_synchronize(split_backend); + } + if (!sched->callback_eval) { enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph); if (ec != GGML_STATUS_SUCCESS) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b56e3d50f58..b2dcaf42fc3 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2803,11 +2803,14 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer; ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; - if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) { + //enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA + bool copy_from_host = ggml_backend_buffer_is_host(buf_src) && ggml_backend_dev_type(backend_src->device) == GGML_BACKEND_DEVICE_TYPE_CPU; + + if (!(copy_from_host || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) { return false; } - if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) { + if (!(copy_from_host || ggml_backend_buffer_is_cuda(buf_src)) || !ggml_backend_buffer_is_cuda(dst->buffer)) { return false; } @@ -2818,14 +2821,17 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; - if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) { + if ((copy_from_host && cuda_ctx_dst->device != buf_ctx_dst->device) || + !copy_from_host && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); #endif return false; } - if (backend_src != backend_dst) { + if (copy_from_host) { + CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream())); + } else if (backend_src != backend_dst) { // copy on src stream if (cuda_ctx_src->device == cuda_ctx_dst->device) { CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream())); From f56fb1be3bf8cd59e47e2ed9fa6768aea8cc89da Mon Sep 17 00:00:00 2001 From: YardenTal44 Date: Fri, 6 Mar 2026 04:29:13 +0200 Subject: [PATCH 217/831] hexagon: add fp16 support for binary ops: add,sub,mul,div (llama/20139) * hexagon: add fp16 support for binary ops: add,sub,mul,div * hexagon: fix test-backend-ops failures for fp16 binary ops on older arches ( --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 21 +- ggml/src/ggml-hexagon/htp/act-ops.c | 8 +- ggml/src/ggml-hexagon/htp/argsort-ops.c | 12 +- ggml/src/ggml-hexagon/htp/binary-ops.c | 194 ++++++++++----- ggml/src/ggml-hexagon/htp/cpy-ops.c | 7 +- ggml/src/ggml-hexagon/htp/get-rows-ops.c | 7 +- ggml/src/ggml-hexagon/htp/hvx-arith.h | 299 +++++++++++------------ ggml/src/ggml-hexagon/htp/hvx-base.h | 48 ++++ ggml/src/ggml-hexagon/htp/hvx-div.h | 267 +++++++++++++++----- ggml/src/ggml-hexagon/htp/hvx-inverse.h | 92 ++++--- ggml/src/ggml-hexagon/htp/rope-ops.c | 11 +- ggml/src/ggml-hexagon/htp/set-rows-ops.c | 9 +- ggml/src/ggml-hexagon/htp/softmax-ops.c | 10 +- ggml/src/ggml-hexagon/htp/sum-rows-ops.c | 8 +- ggml/src/ggml-hexagon/htp/unary-ops.c | 8 +- 15 files changed, 630 insertions(+), 371 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 3006e217796..b70da8f3b28 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1865,15 +1865,26 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se const struct ggml_tensor * src1 = op->src[1]; const struct ggml_tensor * dst = op; - if (src0->type != GGML_TYPE_F32) { - return false; + if (src0->type == GGML_TYPE_F32) { + if (src1->type != GGML_TYPE_F32) { + return false; + } + if (dst->type != GGML_TYPE_F32) { + return false; + } } - if (src1->type != GGML_TYPE_F32) { - return false; + else if (src0->type == GGML_TYPE_F16) { + if (src1->type != GGML_TYPE_F16) { + return false; + } + if (dst->type != GGML_TYPE_F16) { + return false; + } } - if (dst->type != GGML_TYPE_F32) { + else { return false; } + if (!ggml_are_same_shape(src0, dst)) { return false; } diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index 21bd4050a1d..d8b924981e0 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -693,8 +693,8 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); size_t src0_row_size = src0->nb[1]; size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used @@ -748,13 +748,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - uint32_t n_jobs = MIN(n_threads, src0_nrows); - // Prepare context struct htp_act_context actx; actx.octx = octx; - actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + actx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; actx.src0_row_size = src0_row_size; actx.src1_row_size = src1_row_size; @@ -794,7 +792,7 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { actx.data_src1 = data_src1; actx.data_dst = (uint8_t *) dst->data; - worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index a4cee980be8..170220e8f80 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -241,6 +241,9 @@ int op_argsort(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } + const uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; + const uint32_t n_threads = MIN(total_rows, octx->n_threads); + // Allocate scratchpad // We need 1 row of float + 1 row of int32 per thread. uint32_t ne00 = octx->src0.ne[0]; @@ -251,7 +254,7 @@ int op_argsort(struct htp_ops_context * octx) { // Make sure we round up to 256 for alignment requirements spad_per_thread = hex_round_up(spad_per_thread, 256); - size_t total_spad_size = spad_per_thread * octx->n_threads; + size_t total_spad_size = spad_per_thread * n_threads; if (octx->ctx->vtcm_size < total_spad_size) { FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size); @@ -267,15 +270,12 @@ int op_argsort(struct htp_ops_context * octx) { octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3], octx->src0.data, octx->dst.data); - uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; - uint32_t n_jobs = MIN(total_rows, octx->n_threads); - struct htp_argsort_context actx; actx.octx = octx; - actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs; + actx.nrows_per_thread = (total_rows + n_threads - 1) / n_threads; // Run jobs - worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c index 00dbcf87986..ec90f22de52 100644 --- a/ggml/src/ggml-hexagon/htp/binary-ops.c +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -95,43 +95,87 @@ static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_ } // Macro for scalar op switch -#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \ - switch (octx->op) { \ - case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \ - case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \ - case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \ - case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \ - default: break; \ +#define COMPUTE_SCALAR_OP(DST, SRC, VAL, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \ + case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \ + case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \ + case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (*(float *)VAL), N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + case HTP_OP_SUB: hvx_sub_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + case HTP_OP_MUL: hvx_mul_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + case HTP_OP_DIV: hvx_div_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + default: break; \ + } \ } // Macro for vector op switch (All Aligned) -#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \ - switch (octx->op) { \ - case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \ - case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \ - case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \ - case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \ - default: break; \ +#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f16_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f16_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f16_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f16_aaa(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ } // Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned) -#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \ - switch (octx->op) { \ - case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \ - case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \ - case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \ - case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \ - default: break; \ +#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f16_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f16_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f16_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f16_aau(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ } // Macro for vector op switch (All Unaligned - generic loop used in element repeat) -#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \ - switch (octx->op) { \ - case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \ - case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \ - case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \ - case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \ - default: break; \ +#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f16_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f16_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f16_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f16_uuu(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ } // 1. Scalar src1 (ne10 == 1) @@ -140,6 +184,8 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -170,7 +216,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -199,13 +245,12 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { for (uint32_t r = 0; r < current_block_size; r++) { uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; - float val = *(float *)src1_ptr; + COMPUTE_SCALAR_OP(r_dst, r_src0, src1_ptr, src0_type, ne00); src1_ptr += s1_stride; - COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00); } uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -216,7 +261,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } ir += current_block_size; @@ -230,6 +275,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -268,8 +315,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); - dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -284,7 +331,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned; uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; - COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00); + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00); } uint32_t i03, i02, i01, rem; @@ -293,7 +340,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi i02 = fastdiv(rem, &bctx->dim1_div); i01 = rem - i02 * ne01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -310,8 +357,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); - dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } @@ -326,6 +373,8 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -359,7 +408,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -373,7 +422,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; - COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00); + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00); } uint32_t i03, i02, i01, rem; @@ -382,7 +431,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, i02 = fastdiv(rem, &bctx->dim1_div); i01 = rem - i02 * ne01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -392,7 +441,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, p02 = fastdiv(prem, &bctx->dim1_div); p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } ir += current_block_size; @@ -406,6 +455,8 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -435,7 +486,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -462,11 +513,11 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; // Read src1 from DDR (unaligned) - COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00); + COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, src0_type, ne00); } uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -476,7 +527,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * p02 = fastdiv(prem, &bctx->dim1_div); p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } ir += current_block_size; @@ -490,6 +541,9 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); + const uint32_t row_size_bytes = ne00 * elem_size_bytes;; const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -519,7 +573,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -549,12 +603,12 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * for (uint32_t c = 0; c < ne00; c += ne10) { uint32_t len = MIN(ne10, ne00 - c); // Use UUU for speed and simplicity - COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len); + COMPUTE_VECTOR_OP_UUU(r_dst + c * elem_size_bytes, r_src0 + c * elem_size_bytes, r_src1_row, src0_type, len); } } uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -564,7 +618,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * p02 = fastdiv(prem, &bctx->dim1_div); p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } ir += current_block_size; @@ -672,18 +726,20 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { dma_queue_flush(q); } -static int execute_op_binary_f32(struct htp_ops_context * octx) { +static int execute_op_binary(struct htp_ops_context * octx) { const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; struct htp_tensor * dst = &octx->dst; - const uint32_t n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); // Use packed row sizes for VTCM allocation - const size_t src0_row_size = src0->ne[0] * sizeof(float); - const size_t src1_row_size = src1->ne[0] * sizeof(float); - const size_t dst_row_size = dst->ne[0] * sizeof(float); + const uint32_t src0_type = octx->src0.type; + const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); + const size_t src0_row_size = src0->ne[0] * elem_size; + const size_t src1_row_size = src1->ne[0] * elem_size; + const size_t dst_row_size = dst->ne[0] * elem_size; // Align to VLEN const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); @@ -694,7 +750,7 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { bool is_scalar = !is_add_id && (src1->ne[0] == 1); // Determine which kernel we will use to alloc memory and dispatch - bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] && + bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) && (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) && (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) && (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1); @@ -726,7 +782,7 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { } if (rows_per_buffer < 1) { - FARF(ERROR, "binary-f32: VTCM too small\n"); + FARF(ERROR, "binary: VTCM too small\n"); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -761,16 +817,14 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - uint32_t n_jobs = MIN(n_threads, src0_nrows); - dma_queue * q = octx->ctx->dma[0]; if (is_row_bcast) { - dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1); + dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * elem_size, 1); } struct htp_binary_context bctx; bctx.octx = octx; - bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; bctx.block_max = rows_per_buffer; bctx.src0_row_size_aligned = src0_row_size_aligned; bctx.src1_row_size_aligned = src1_row_size_aligned; @@ -814,14 +868,24 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { dma_queue_pop(q); } - worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_threads); return HTP_STATUS_OK; } int op_binary(struct htp_ops_context * octx) { - if (octx->src0.type == HTP_TYPE_F32) { - return execute_op_binary_f32(octx); + + // Does not support permutations of src1 + const struct htp_tensor * src1 = &octx->src1; + if (src1->nb[1] < src1->nb[0]) { + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t src0_type = octx->src0.type; + if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) { + return execute_op_binary(octx); } + return HTP_STATUS_NO_SUPPORT; } + diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c index 559ca183789..a40d866b9c3 100644 --- a/ggml/src/ggml-hexagon/htp/cpy-ops.c +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -202,6 +202,8 @@ static void cpy_work_func(unsigned int n, unsigned int i, void *data) { int op_cpy(struct htp_ops_context * octx) { cpy_preamble; + const uint32_t n_threads = MIN(nr, octx->n_threads); + struct htp_copy_context ct; ct.octx = octx; @@ -227,8 +229,7 @@ int op_cpy(struct htp_ops_context * octx) { const bool transposed = (nb00 > nb01) || (nb0 > nb1); const bool sameshape = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3); - const uint32_t n_jobs = MIN(nr, octx->n_threads); - ct.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; if (sametype && sameshape) { ct.copy = cpy_thread_sametype_sameshape; @@ -245,7 +246,7 @@ int op_cpy(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index bf24bbda70a..047d2850aaa 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -82,6 +82,8 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da int op_get_rows(struct htp_ops_context * octx) { get_rows_preamble; + const uint32_t n_threads = MIN(nr, octx->n_threads); + if (octx->src0.type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } @@ -103,9 +105,8 @@ int op_get_rows(struct htp_ops_context * octx) { grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); - const uint32_t n_jobs = MIN(nr, octx->n_threads); - grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads; - worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/hvx-arith.h b/ggml/src/ggml-hexagon/htp/hvx-arith.h index 2577cdd0418..82e3416970b 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-arith.h +++ b/ggml/src/ggml-hexagon/htp/hvx-arith.h @@ -13,14 +13,15 @@ // Binary operations (add, mul, sub) // -#define hvx_arith_loop_body(dst_type, src0_type, src1_type, vec_store, vec_op) \ +#define UNUSED(x) (void)(x) + +#define hvx_arith_loop_body(dst_type, src0_type, src1_type, elem_size, vec_store, vec_op) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ src0_type * restrict vsrc0 = (src0_type *) src0; \ src1_type * restrict vsrc1 = (src1_type *) src1; \ \ - const uint32_t elem_size = sizeof(float); \ - const uint32_t epv = 128 / elem_size; \ + const uint32_t epv = 128 / (elem_size); \ const uint32_t nvec = n / epv; \ const uint32_t nloe = n % epv; \ \ @@ -32,62 +33,74 @@ } \ if (nloe) { \ HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]); \ - vec_store((void *) &vdst[i], nloe * elem_size, v); \ + vec_store((void *) &vdst[i], nloe * (elem_size), v); \ } \ } while(0) #if __HVX_ARCH__ < 79 -#define HVX_OP_ADD(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) -#define HVX_OP_SUB(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)) -#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) + +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) + #else -#define HVX_OP_ADD(a, b) Q6_Vsf_vadd_VsfVsf(a, b) -#define HVX_OP_SUB(a, b) Q6_Vsf_vsub_VsfVsf(a, b) -#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) + +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) + #endif +#define HVX_OP_ADD_F16(a, b) hvx_vec_add_f16_f16(a, b) +#define HVX_OP_SUB_F16(a, b) hvx_vec_sub_f16_f16(a, b) +#define HVX_OP_MUL_F16(a, b) hvx_vec_mul_f16_f16(a, b) + // Generic macro to define alignment permutations for an op -#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \ +#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO, ELEM_TYPE) \ static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) dst % 128 == 0); \ assert((uintptr_t) src0 % 128 == 0); \ assert((uintptr_t) src1 % 128 == 0); \ - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ } \ static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) dst % 128 == 0); \ assert((uintptr_t) src0 % 128 == 0); \ - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ } \ static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) dst % 128 == 0); \ assert((uintptr_t) src1 % 128 == 0); \ - hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ } \ static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) dst % 128 == 0); \ - hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ } \ static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) src0 % 128 == 0); \ assert((uintptr_t) src1 % 128 == 0); \ - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ } \ static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) src0 % 128 == 0); \ - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ } \ static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) src1 % 128 == 0); \ - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ } \ static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ } \ -DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD) -DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB) -DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD_F32, float) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB_F32, float) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL_F32, float) + +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f16, HVX_OP_ADD_F16, _Float16) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f16, HVX_OP_SUB_F16, _Float16) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f16, HVX_OP_MUL_F16, _Float16) // Dispatcher logic #define HVX_BINARY_DISPATCHER(OP_NAME) \ @@ -115,6 +128,10 @@ HVX_BINARY_DISPATCHER(hvx_add_f32) HVX_BINARY_DISPATCHER(hvx_sub_f32) HVX_BINARY_DISPATCHER(hvx_mul_f32) +HVX_BINARY_DISPATCHER(hvx_add_f16) +HVX_BINARY_DISPATCHER(hvx_sub_f16) +HVX_BINARY_DISPATCHER(hvx_mul_f16) + // Mul-Mul Optimized static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) { assert((unsigned long) dst % 128 == 0); @@ -136,26 +153,25 @@ static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * re _Pragma("unroll(4)") for (; i < nvec; i++) { - HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]); + HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]); vdst[i] = HVX_OP_MUL(v1, vsrc2[i]); } if (nloe) { - HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]); - HVX_Vector v2 = HVX_OP_MUL(v1, vsrc2[i]); + HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]); + HVX_Vector v2 = HVX_OP_MUL_F32(v1, vsrc2[i]); hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2); } } // Scalar Operations -#define hvx_scalar_loop_body(dst_type, src_type, vec_store, scalar_op_macro) \ +#define hvx_scalar_loop_body(dst_type, src_type, elem_size, vec_store, scalar_op_macro) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ src_type * restrict vsrc = (src_type *) src; \ \ - const uint32_t elem_size = sizeof(float); \ - const uint32_t epv = 128 / elem_size; \ + const uint32_t epv = 128 / (elem_size); \ const uint32_t nvec = n / epv; \ const uint32_t nloe = n % epv; \ \ @@ -169,138 +185,88 @@ static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * re if (nloe) { \ HVX_Vector v = vsrc[i]; \ v = scalar_op_macro(v); \ - vec_store((void *) &vdst[i], nloe * elem_size, v); \ + vec_store((void *) &vdst[i], nloe * (elem_size), v); \ } \ } while(0) -#define HVX_OP_ADD_SCALAR(v) \ +#define HVX_OP_ADD_SCALAR_F32(v) \ ({ \ const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \ - HVX_Vector out = HVX_OP_ADD(v, val_vec); \ + HVX_Vector out = HVX_OP_ADD_F32(v, val_vec); \ Q6_V_vmux_QVV(pred_inf, inf, out); \ }) -#define HVX_OP_MUL_SCALAR(v) HVX_OP_MUL(v, val_vec) -#define HVX_OP_SUB_SCALAR(v) HVX_OP_SUB(v, val_vec) - -// Add Scalar Variants - -static inline void hvx_add_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD_SCALAR); -} - -static inline void hvx_add_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); - assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD_SCALAR); -} - -static inline void hvx_add_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD_SCALAR); -} - -static inline void hvx_add_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - static const float kInf = INFINITY; - const HVX_Vector inf = hvx_vec_splat_f32(kInf); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD_SCALAR); -} - -// Sub Scalar Variants +#define HVX_OP_MUL_SCALAR_F32(v) HVX_OP_MUL_F32(v, val_vec) +#define HVX_OP_SUB_SCALAR_F32(v) HVX_OP_SUB_F32(v, val_vec) -static inline void hvx_sub_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB_SCALAR); -} - -static inline void hvx_sub_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB_SCALAR); -} - -static inline void hvx_sub_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB_SCALAR); -} - -static inline void hvx_sub_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB_SCALAR); -} +#define HVX_OP_ADD_SCALAR_F16(v) \ + ({ \ + const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VhVh(inf, v); \ + HVX_Vector out = HVX_OP_ADD_F16(v, val_vec); \ + Q6_V_vmux_QVV(pred_inf, inf, out); \ + }) -// Mul Scalar Variants +#define HVX_OP_MUL_SCALAR_F16(v) HVX_OP_MUL_F16(v, val_vec) +#define HVX_OP_SUB_SCALAR_F16(v) HVX_OP_SUB_F16(v, val_vec) -static inline void hvx_mul_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL_SCALAR); -} +// Scalar Variants -static inline void hvx_mul_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL_SCALAR); -} +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(OP_NAME, OP_MACRO, SPLAT_MACRO, ELEM_TYPE) \ +static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src % 128 == 0); \ + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + assert((uintptr_t) dst % 128 == 0); \ + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + assert((uintptr_t) src % 128 == 0); \ + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ +} \ -static inline void hvx_mul_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL_SCALAR); -} +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f32, HVX_OP_ADD_SCALAR_F32, hvx_vec_splat_f32, float) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f32, HVX_OP_SUB_SCALAR_F32, hvx_vec_splat_f32, float) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f32, HVX_OP_MUL_SCALAR_F32, hvx_vec_splat_f32, float) -static inline void hvx_mul_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL_SCALAR); -} +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f16, HVX_OP_ADD_SCALAR_F16, hvx_vec_splat_f16, _Float16) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f16, HVX_OP_SUB_SCALAR_F16, hvx_vec_splat_f16, _Float16) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f16, HVX_OP_MUL_SCALAR_F16, hvx_vec_splat_f16, _Float16) -static inline void hvx_add_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { - hvx_add_scalar_f32_aa(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) dst, 128)) { - hvx_add_scalar_f32_au(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) src, 128)) { - hvx_add_scalar_f32_ua(dst, src, val, num_elems); - } else { - hvx_add_scalar_f32_uu(dst, src, val, num_elems); - } +// Dispatcher logic +#define HVX_BINARY_SCALAR_DISPATCHER(OP_NAME, ELEM_TYPE) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_aa(dst, src, val, num_elems); \ + } else if (hex_is_aligned((void *) dst, 128)) { \ + OP_NAME##_au(dst, src, val, num_elems); \ + } else if (hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_ua(dst, src, val, num_elems); \ + } else { \ + OP_NAME##_uu(dst, src, val, num_elems); \ + } \ } -static inline void hvx_mul_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { - hvx_mul_scalar_f32_aa(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) dst, 128)) { - hvx_mul_scalar_f32_au(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) src, 128)) { - hvx_mul_scalar_f32_ua(dst, src, val, num_elems); - } else { - hvx_mul_scalar_f32_uu(dst, src, val, num_elems); - } -} +HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f32, float) +HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f32, float) +HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f32, float) -static inline void hvx_sub_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { - hvx_sub_scalar_f32_aa(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) dst, 128)) { - hvx_sub_scalar_f32_au(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) src, 128)) { - hvx_sub_scalar_f32_ua(dst, src, val, num_elems); - } else { - hvx_sub_scalar_f32_uu(dst, src, val, num_elems); - } -} +HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f16, _Float16) +HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f16, _Float16) +HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f16, _Float16) // MIN Scalar variants @@ -310,24 +276,24 @@ static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * const HVX_Vector val_vec = hvx_vec_splat_f32(val); assert((unsigned long) dst % 128 == 0); assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MIN_SCALAR); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR); } static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { const HVX_Vector val_vec = hvx_vec_splat_f32(val); assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MIN_SCALAR); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR); } static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { const HVX_Vector val_vec = hvx_vec_splat_f32(val); assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MIN_SCALAR); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR); } static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { const HVX_Vector val_vec = hvx_vec_splat_f32(val); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MIN_SCALAR); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR); } static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { @@ -357,27 +323,27 @@ static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t const HVX_Vector max_vec = hvx_vec_splat_f32(max); assert((unsigned long) dst % 128 == 0); assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); } static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { const HVX_Vector min_vec = hvx_vec_splat_f32(min); const HVX_Vector max_vec = hvx_vec_splat_f32(max); assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); } static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { const HVX_Vector min_vec = hvx_vec_splat_f32(min); const HVX_Vector max_vec = hvx_vec_splat_f32(max); assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); } static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { const HVX_Vector min_vec = hvx_vec_splat_f32(min); const HVX_Vector max_vec = hvx_vec_splat_f32(max); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); } static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) { @@ -396,7 +362,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * // Square // -#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \ +#define hvx_sqr_f32_loop_body(dst_type, src_type, vec_store) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ src_type * restrict vsrc = (src_type *) src; \ @@ -410,10 +376,10 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * \ _Pragma("unroll(4)") \ for (; i < nvec; i++) { \ - vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \ + vdst[i] = HVX_OP_MUL_F32(vsrc[i], vsrc[i]); \ } \ if (nloe) { \ - HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \ + HVX_Vector v = HVX_OP_MUL_F32(vsrc[i], vsrc[i]); \ vec_store((void *) &vdst[i], nloe * elem_size, v); \ } \ } while(0) @@ -421,21 +387,21 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { assert((unsigned long) dst % 128 == 0); assert((unsigned long) src % 128 == 0); - hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); + hvx_sqr_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); } static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { assert((unsigned long) dst % 128 == 0); - hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); + hvx_sqr_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); } static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { assert((unsigned long) src % 128 == 0); - hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); + hvx_sqr_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); } static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); + hvx_sqr_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); } static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { @@ -454,17 +420,24 @@ static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict } } -#undef HVX_OP_ADD -#undef HVX_OP_SUB -#undef HVX_OP_MUL +#undef HVX_OP_ADD_F32 +#undef HVX_OP_SUB_F32 +#undef HVX_OP_MUL_F32 +#undef HVX_OP_ADD_F16 +#undef HVX_OP_SUB_F16 +#undef HVX_OP_MUL_F16 #undef hvx_arith_loop_body -#undef HVX_OP_ADD_SCALAR -#undef HVX_OP_SUB_SCALAR -#undef HVX_OP_MUL_SCALAR +#undef HVX_OP_ADD_SCALAR_F32 +#undef HVX_OP_SUB_SCALAR_F32 +#undef HVX_OP_MUL_SCALAR_F32 +#undef HVX_OP_ADD_SCALAR_F16 +#undef HVX_OP_SUB_SCALAR_F16 +#undef HVX_OP_MUL_SCALAR_F16 #undef hvx_scalar_loop_body #undef HVX_OP_MIN_SCALAR #undef HVX_OP_CLAMP_SCALAR #undef DEFINE_HVX_BINARY_OP_VARIANTS #undef HVX_BINARY_DISPATCHER +#undef UNUSED #endif // HVX_ARITH_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index 701637f22b2..578ca288fb6 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -189,4 +189,52 @@ static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vect #endif +#if __HVX_ARCH__ < 79 + +static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b) +{ + const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16 + const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in IEEE FP16 + HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one); + HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone); + HVX_Vector a0 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p)); + HVX_Vector a1 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p)); + return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0)); +} + +static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b) +{ + const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16 + const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in IEEE FP16 + HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one); + HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone); + HVX_Vector a0 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p)); + HVX_Vector a1 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p)); + return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0)); +} + +static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b)); +} + +#else + +static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_vadd_VhfVhf(a, b); +} + +static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_vsub_VhfVhf(a, b); +} + +static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_vmpy_VhfVhf(a, b); +} + +#endif // __HVX_ARCH__ < 79 + #endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-div.h b/ggml/src/ggml-hexagon/htp/hvx-div.h index 7dae012e0ed..05cefea039f 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-div.h +++ b/ggml/src/ggml-hexagon/htp/hvx-div.h @@ -15,11 +15,144 @@ #include "hvx-arith.h" #if __HVX_ARCH__ < 79 -#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) #else -#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) #endif +// Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32. +static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX_Vector vec2_sf_const, HVX_Vector vec_hf_one_1_0) { +#if __HVX_ARCH__ < 79 + HVX_VectorPair src_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0); + HVX_Vector src_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(src_to_f32)); + HVX_Vector src_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(src_to_f32)); +#else + HVX_VectorPair src_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0); + HVX_Vector src_to_f32_0 = Q6_V_lo_W(src_to_f32); + HVX_Vector src_to_f32_1 = Q6_V_hi_W(src_to_f32); +#endif + + HVX_Vector div_f32_0 = HVX_OP_MUL_F32(src_to_f32_0, vec2_sf_const); + HVX_Vector div_f32_1 = HVX_OP_MUL_F32(src_to_f32_1, vec2_sf_const); + +#if __HVX_ARCH__ < 79 + HVX_Vector res = hvx_vec_f32_to_f16(div_f32_0, div_f32_1); +#else + HVX_Vector res = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1); +#endif + return res; +} + +#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ + \ + const uint32_t nvec = n / VLEN_FP16; \ + const uint32_t nloe = n % VLEN_FP16; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ + vdst[i] = res; \ + } \ + if (nloe) { \ + HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ + } \ + } while(0) + +static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src % 128 == 0); + hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} +static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + assert((uintptr_t) dst % 128 == 0); + hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} +static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + assert((uintptr_t) src % 128 == 0); + hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} +static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +// Compute div by using hvx_vec_inverse_f32_guard. Requires first by exapnding fp32 to fp16 and convert the result back to fp32. +static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector vec_hf_one_1_0) { +#if __HVX_ARCH__ < 79 + // Convert first input to fp32 + HVX_VectorPair vec1_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1, vec_hf_one_1_0); // *1.0 + HVX_Vector vec1_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec1_to_f32)); + HVX_Vector vec1_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec1_to_f32)); + + // Convert second input to fp32 + HVX_VectorPair vec2_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec2, vec_hf_one_1_0); // *1.0 + HVX_Vector vec2_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec2_to_f32)); + HVX_Vector vec2_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec2_to_f32)); +#else + // Convert first input to fp32 + HVX_VectorPair vec1_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1, vec_hf_one_1_0); // *1.0 + HVX_Vector vec1_to_f32_0 = Q6_V_lo_W(vec1_to_f32); + HVX_Vector vec1_to_f32_1 = Q6_V_hi_W(vec1_to_f32); + + // Convert second input to fp32 + HVX_VectorPair vec2_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec2, vec_hf_one_1_0); // *1.0 + HVX_Vector vec2_to_f32_0 = Q6_V_lo_W(vec2_to_f32); + HVX_Vector vec2_to_f32_1 = Q6_V_hi_W(vec2_to_f32); +#endif + + // Inverse second input in fp32 + HVX_Vector vec2_inv_f32_0 = hvx_vec_inverse_f32_guard(vec2_to_f32_0, f32_nan_inf_mask); + HVX_Vector vec2_inv_f32_1 = hvx_vec_inverse_f32_guard(vec2_to_f32_1, f32_nan_inf_mask); + + // Multiply first input by inverse of second, in fp32 + HVX_Vector div_f32_0 = HVX_OP_MUL_F32(vec1_to_f32_0, vec2_inv_f32_0); + HVX_Vector div_f32_1 = HVX_OP_MUL_F32(vec1_to_f32_1, vec2_inv_f32_1); + + // Convert back to fp16 +#if __HVX_ARCH__ < 79 + HVX_Vector recip = hvx_vec_f32_to_f16(div_f32_0, div_f32_1); +#else + HVX_Vector recip = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1); +#endif + + return recip; +} + +#define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src0_type * restrict vsrc0 = (src0_type *) src0; \ + src1_type * restrict vsrc1 = (src1_type *) src1; \ + \ + const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ + \ + const uint32_t nvec = n / VLEN_FP16; \ + const uint32_t nloe = n % VLEN_FP16; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \ + vdst[i] = res; \ + } \ + if (nloe) { \ + HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ + } \ + } while(0) + #define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ @@ -36,81 +169,83 @@ _Pragma("unroll(4)") \ for (; i < nvec; i++) { \ HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ - HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \ + HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1); \ vdst[i] = res; \ } \ if (nloe) { \ HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ - HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \ + HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1); \ vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \ } \ } while(0) -// 3-letter suffix variants -static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) dst % 128 == 0); - assert((uintptr_t) src0 % 128 == 0); - assert((uintptr_t) src1 % 128 == 0); - hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); -} - -static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) dst % 128 == 0); - assert((uintptr_t) src0 % 128 == 0); - hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_DIV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \ +static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); \ +} \ + +// Dispatcher logic +#define HVX_DIV_DISPATCHER(OP_NAME) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128)) { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \ + else OP_NAME##_aau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \ + else OP_NAME##_auu(dst, src0, src1, num_elems); \ + } \ + } else { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \ + else OP_NAME##_uau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \ + else OP_NAME##_uuu(dst, src0, src1, num_elems); \ + } \ + } \ } -static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) dst % 128 == 0); - assert((uintptr_t) src1 % 128 == 0); - hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); -} +DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f32, hvx_div_f32_loop_body) +DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f16, hvx_div_f16_loop_body) -static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) dst % 128 == 0); - hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); -} - -static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) src0 % 128 == 0); - assert((uintptr_t) src1 % 128 == 0); - hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); -} - -static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) src0 % 128 == 0); - hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); -} - -static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) src1 % 128 == 0); - hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); -} - -static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); -} - -static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { - if (hex_is_aligned((void *) dst, 128)) { - if (hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems); - else hvx_div_f32_aau(dst, src0, src1, num_elems); - } else { - if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems); - else hvx_div_f32_auu(dst, src0, src1, num_elems); - } - } else { - if (hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems); - else hvx_div_f32_uau(dst, src0, src1, num_elems); - } else { - if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems); - else hvx_div_f32_uuu(dst, src0, src1, num_elems); - } - } -} +HVX_DIV_DISPATCHER(hvx_div_f32) +HVX_DIV_DISPATCHER(hvx_div_f16) -#undef HVX_OP_MUL +#undef HVX_OP_MUL_F32 #endif // HVX_DIV_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-inverse.h b/ggml/src/ggml-hexagon/htp/hvx-inverse.h index 53db94aae2b..f2054f45bac 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-inverse.h +++ b/ggml/src/ggml-hexagon/htp/hvx-inverse.h @@ -137,40 +137,74 @@ static inline HVX_Vector hvx_vec_inverse_f32_guard(HVX_Vector v_sf, HVX_Vector n } \ } while(0) -static inline void hvx_inverse_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - hvx_inverse_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); -} +static inline HVX_Vector hvx_vec_inverse_f16_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) { + HVX_Vector out = hvx_vec_inverse_f16(v_sf); -static inline void hvx_inverse_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - hvx_inverse_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); -} + HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask); + const HVX_VectorPred pred = Q6_Q_vcmp_eq_VhVh(nan_inf_mask, masked_out); -static inline void hvx_inverse_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - assert((unsigned long) src % 128 == 0); - hvx_inverse_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); + return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out); } -static inline void hvx_inverse_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - hvx_inverse_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); -} +#define hvx_inverse_f16_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const HVX_Vector nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \ + \ + const uint32_t nvec = n / VLEN_FP16; \ + const uint32_t nloe = n % VLEN_FP16; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \ + } \ + if (nloe) { \ + HVX_Vector v = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, v); \ + } \ + } while(0) -static inline void hvx_inverse_f32(uint8_t * restrict dst, uint8_t * restrict src, const int num_elems) { - if ((unsigned long) dst % 128 == 0) { - if ((unsigned long) src % 128 == 0) { - hvx_inverse_f32_aa(dst, src, num_elems); - } else { - hvx_inverse_f32_au(dst, src, num_elems); - } - } else { - if ((unsigned long) src % 128 == 0) { - hvx_inverse_f32_ua(dst, src, num_elems); - } else { - hvx_inverse_f32_uu(dst, src, num_elems); - } - } +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_INV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \ +static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_Vector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_UVector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + assert((uintptr_t) src % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_Vector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + OP_LOOP_BODY(HVX_UVector, HVX_UVector, hvx_vec_store_u); \ +} \ + +// Dispatcher logic +#define HVX_INV_DISPATCHER(OP_NAME) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_aa(dst, src, num_elems); \ + } else if (hex_is_aligned((void *) dst, 128)) { \ + OP_NAME##_au(dst, src, num_elems); \ + } else if (hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_ua(dst, src, num_elems); \ + } else { \ + OP_NAME##_uu(dst, src, num_elems); \ + } \ } +DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f32, hvx_inverse_f32_loop_body) +DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f16, hvx_inverse_f16_loop_body) + +HVX_INV_DISPATCHER(hvx_inverse_f32) +HVX_INV_DISPATCHER(hvx_inverse_f16) + #endif // HVX_INVERSE_H diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 9aeb80d0b8b..be9469538f6 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -400,7 +400,9 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t n_threads = octx->n_threads; + const uint32_t ne0 = dst->ne[0]; + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); const size_t src0_row_size = src0->nb[1]; const size_t dst_row_size = dst->nb[1]; @@ -465,17 +467,14 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { rctx.dst_row_size_aligned = dst_row_size_aligned; rctx.theta_cache_offset = theta_cache_size_aligned; - uint32_t ne0 = dst->ne[0]; - uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; rctx.src0_nrows = src0_nrows; + rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0, rctx.ext_factor, rctx.theta_scale, rctx.attn_factor); if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_threads); } return err; diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index 2fd6c907724..4b6967749f8 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -128,6 +128,8 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da int op_set_rows(struct htp_ops_context * octx) { set_rows_preamble; + const uint32_t n_threads = MIN(nr, octx->n_threads); + if (octx->src0.type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } @@ -149,15 +151,14 @@ int op_set_rows(struct htp_ops_context * octx) { srctx.div_ne12 = init_fastdiv_values(ne12); srctx.div_ne11 = init_fastdiv_values(ne11); - const uint32_t n_jobs = MIN(nr, octx->n_threads); - srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; switch(octx->dst.type) { case HTP_TYPE_F32: - worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads); break; case HTP_TYPE_F16: - worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_threads); break; default: return HTP_STATUS_NO_SUPPORT; diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 6e22eb6a639..8dae7f1ed55 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -353,7 +353,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t n_threads = octx->n_threads; + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); const size_t src0_row_size = src0->nb[1]; const size_t src1_row_size = src0_row_size; @@ -393,12 +394,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; - uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs); + smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads); } return err; diff --git a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c index 04fa72182a3..352650b689b 100644 --- a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c @@ -102,11 +102,9 @@ int op_sum_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - const int n_threads = octx->n_threads; const uint32_t src0_nrows = ne01 * ne02 * ne03; - - uint32_t n_jobs = MIN(n_threads, src0_nrows); - uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); + const uint32_t rows_per_thread = (src0_nrows + n_threads - 1) / n_threads; bool opt_path = false; if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) { @@ -124,7 +122,7 @@ int op_sum_rows(struct htp_ops_context * octx) { .opt_path = opt_path, }; - worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 98135c50ab8..5bbd5040d3d 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -301,8 +301,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - const int n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); const size_t src0_row_size = src0->nb[1]; const size_t dst_row_size = dst->nb[1]; @@ -338,11 +338,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - struct htp_unary_context uctx = { .octx = octx, - .src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs, + .src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads, .src0_nrows = src0_nrows, .data_src0 = (const uint8_t *)src0->data, @@ -361,7 +359,7 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { .nc = src0->ne[0], }; - worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads); } return err; From 1d94b0be4fbd916b1210d84de684658935a2c5f8 Mon Sep 17 00:00:00 2001 From: lhez Date: Thu, 5 Mar 2026 21:16:39 -0800 Subject: [PATCH 218/831] opencl: add neg, exp and diag (llama/20127) * opencl: add `neg` * opencl: add `exp` * opencl: add `diag` --- ggml/src/ggml-opencl/CMakeLists.txt | 3 + ggml/src/ggml-opencl/ggml-opencl.cpp | 293 +++++++++++++++++++++++++++ ggml/src/ggml-opencl/kernels/diag.cl | 27 +++ ggml/src/ggml-opencl/kernels/exp.cl | 125 ++++++++++++ ggml/src/ggml-opencl/kernels/neg.cl | 125 ++++++++++++ 5 files changed, 573 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/diag.cl create mode 100644 ggml/src/ggml-opencl/kernels/exp.cl create mode 100644 ggml/src/ggml-opencl/kernels/neg.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 0fe1dd38476..fb3ae17eaf4 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -63,6 +63,7 @@ set(GGML_OPENCL_KERNELS cpy cvt diag_mask_inf + diag div gelu gemv_noshuffle_general @@ -112,6 +113,7 @@ set(GGML_OPENCL_KERNELS gemm_noshuffle_q4_1_f32 gemv_noshuffle_general_q8_0_f32 mul + neg norm relu rms_norm @@ -134,6 +136,7 @@ set(GGML_OPENCL_KERNELS tsembd upscale tanh + exp expm1 softplus pad diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 7af032ce0e1..4ef33a7765d 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -499,6 +499,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_rms_norm, kernel_rms_norm_mul; cl_kernel kernel_group_norm, kernel_group_norm_mul_add; cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; + cl_kernel kernel_diag_f32; cl_kernel kernel_soft_max, kernel_soft_max_4; cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; std::map, cl_kernel> kernels_flash_attn_f16; @@ -549,6 +550,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_pad; cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc; cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc; + cl_kernel kernel_neg_f32, kernel_neg_f32_4, kernel_neg_f32_nc; + cl_kernel kernel_neg_f16, kernel_neg_f16_4, kernel_neg_f16_nc; + cl_kernel kernel_exp_f32, kernel_exp_f32_4, kernel_exp_f32_nc; + cl_kernel kernel_exp_f16, kernel_exp_f16_4, kernel_exp_f16_nc; cl_kernel kernel_expm1_f32, kernel_expm1_f32_4, kernel_expm1_f32_nc; cl_kernel kernel_expm1_f16, kernel_expm1_f16_4, kernel_expm1_f16_nc; cl_kernel kernel_softplus_f32, kernel_softplus_f32_4, kernel_softplus_f32_nc; @@ -932,6 +937,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // diag + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "diag.cl.h" + }; +#else + const std::string kernel_src = read_file("diag.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_diag_f32 = clCreateKernel(prog, "kernel_diag_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gelu { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1979,6 +2001,48 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // neg + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "neg.cl.h" + }; +#else + const std::string kernel_src = read_file("neg.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_neg_f32 = clCreateKernel(prog, "kernel_neg_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f32_4 = clCreateKernel(prog, "kernel_neg_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f32_nc = clCreateKernel(prog, "kernel_neg_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f16 = clCreateKernel(prog, "kernel_neg_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f16_4 = clCreateKernel(prog, "kernel_neg_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f16_nc = clCreateKernel(prog, "kernel_neg_f16_nc", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // exp + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "exp.cl.h" + }; +#else + const std::string kernel_src = read_file("exp.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_exp_f32 = clCreateKernel(prog, "kernel_exp_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f32_4 = clCreateKernel(prog, "kernel_exp_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f32_nc = clCreateKernel(prog, "kernel_exp_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f16 = clCreateKernel(prog, "kernel_exp_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f16_4 = clCreateKernel(prog, "kernel_exp_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f16_nc = clCreateKernel(prog, "kernel_exp_f16_nc", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // expm1 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3592,6 +3656,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_SIGMOID: return ggml_is_contiguous(op->src[0]); case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_EXP: return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_UNARY_OP_EXPM1: return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; @@ -3677,6 +3743,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: return true; + case GGML_OP_DIAG: + return true; case GGML_OP_DIAG_MASK_INF: return op->ne[3] == 1; case GGML_OP_ROPE: { @@ -7581,6 +7649,170 @@ static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const } } +static void ggml_cl_neg(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); + + cl_kernel kernel; + + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_neg_f32_4; + } else { + kernel = backend_ctx->kernel_neg_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_neg_f32; + } else { + kernel = backend_ctx->kernel_neg_f16; + } + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &n)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n, 64)*64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } else { + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_neg_f32_nc; + } else { + kernel = backend_ctx->kernel_neg_f16_nc; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + + int nth = 64; + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } +} + +static void ggml_cl_exp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); + + cl_kernel kernel; + + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_exp_f32_4; + } else { + kernel = backend_ctx->kernel_exp_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_exp_f32; + } else { + kernel = backend_ctx->kernel_exp_f16; + } + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &n)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n, 64)*64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } else { + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_exp_f32_nc; + } else { + kernel = backend_ctx->kernel_exp_f16_nc; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + + int nth = 64; + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } +} + static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -11029,6 +11261,49 @@ static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * sr } } +static void ggml_cl_diag(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); + + cl_kernel kernel = backend_ctx->kernel_diag_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb3)); + + int nth = 64; + + size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -11845,6 +12120,18 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_tanh; break; + case GGML_UNARY_OP_NEG: + if (!any_on_device) { + return false; + } + func = ggml_cl_neg; + break; + case GGML_UNARY_OP_EXP: + if (!any_on_device) { + return false; + } + func = ggml_cl_exp; + break; case GGML_UNARY_OP_EXPM1: if (!any_on_device) { return false; @@ -11971,6 +12258,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_nop; break; + case GGML_OP_DIAG: + if (!any_on_device) { + return false; + } + func = ggml_cl_diag; + break; case GGML_OP_DIAG_MASK_INF: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/diag.cl b/ggml/src/ggml-opencl/kernels/diag.cl new file mode 100644 index 00000000000..884efa08fdd --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/diag.cl @@ -0,0 +1,27 @@ +kernel void kernel_diag_f32( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + ulong nb0, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + global const float * src0_ptr = (global const float *)(src0 + i2*nb02 + i3*nb03); + global float * dst_ptr = (global float *)(dst + i1*nb01 + i2*nb2 + i3*nb3); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f; + } +} diff --git a/ggml/src/ggml-opencl/kernels/exp.cl b/ggml/src/ggml-opencl/kernels/exp.cl new file mode 100644 index 00000000000..a2458b6579c --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/exp.cl @@ -0,0 +1,125 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_exp_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]); +} + +kernel void kernel_exp_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]); +} + +kernel void kernel_exp_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]); +} + +kernel void kernel_exp_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]); +} + +kernel void kernel_exp_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = exp(*x); + } +} + +kernel void kernel_exp_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = exp(*x); + } +} diff --git a/ggml/src/ggml-opencl/kernels/neg.cl b/ggml/src/ggml-opencl/kernels/neg.cl new file mode 100644 index 00000000000..a862d8bc585 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/neg.cl @@ -0,0 +1,125 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_neg_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = -src0[get_global_id(0)]; +} + +kernel void kernel_neg_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = -src0[get_global_id(0)]; +} + +kernel void kernel_neg_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = -src0[get_global_id(0)]; +} + +kernel void kernel_neg_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = -src0[get_global_id(0)]; +} + +kernel void kernel_neg_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = -*x; + } +} + +kernel void kernel_neg_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = -*x; + } +} From 596b655dbd8ebbd88de1a8857ebd5318c867ccad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 6 Mar 2026 09:12:49 +0100 Subject: [PATCH 219/831] ggml-cpu: fix data race for debug asserts (llama/20148) --- ggml/src/ggml-cpu/ops.cpp | 76 +++++++++++++++++++-------------------- 1 file changed, 38 insertions(+), 38 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index ca1b3059b8c..243f01caf8e 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2129,12 +2129,12 @@ static void ggml_compute_forward_gelu_f32( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k]; GGML_UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2176,13 +2176,13 @@ static void ggml_compute_forward_gelu_f16( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); GGML_UNUSED(v); assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -2325,12 +2325,12 @@ static void ggml_compute_forward_gelu_erf_f32( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k]; GGML_UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2372,13 +2372,13 @@ static void ggml_compute_forward_gelu_erf_f16( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); GGML_UNUSED(v); assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -2444,12 +2444,12 @@ static void ggml_compute_forward_gelu_quick_f32( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k]; GGML_UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2491,13 +2491,13 @@ static void ggml_compute_forward_gelu_quick_f16( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); GGML_UNUSED(v); assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -2563,12 +2563,12 @@ static void ggml_compute_forward_silu_f32( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k]; GGML_UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2610,13 +2610,13 @@ static void ggml_compute_forward_silu_f16( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k]; + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); GGML_UNUSED(v); assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -2766,7 +2766,7 @@ static void ggml_compute_forward_silu_back_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2802,7 +2802,7 @@ static void ggml_compute_forward_silu_back_f16( (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])), (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1]))); - #ifndef NDEBUG +#ifndef NDEBUG for (int k = 0; k < nc; k++) { const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); @@ -2810,7 +2810,7 @@ static void ggml_compute_forward_silu_back_f16( assert(!isnan(v)); assert(!isinf(v)); } - #endif +#endif // NDEBUG } } @@ -2893,7 +2893,7 @@ static void ggml_compute_forward_reglu_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2953,7 +2953,7 @@ static void ggml_compute_forward_reglu_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -3036,7 +3036,7 @@ static void ggml_compute_forward_geglu_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3096,7 +3096,7 @@ static void ggml_compute_forward_geglu_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -3179,7 +3179,7 @@ static void ggml_compute_forward_swiglu_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3239,7 +3239,7 @@ static void ggml_compute_forward_swiglu_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -3330,7 +3330,7 @@ static void ggml_compute_forward_swiglu_oai_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3409,7 +3409,7 @@ static void ggml_compute_forward_geglu_erf_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3469,7 +3469,7 @@ static void ggml_compute_forward_geglu_erf_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -3552,7 +3552,7 @@ static void ggml_compute_forward_geglu_quick_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3612,7 +3612,7 @@ static void ggml_compute_forward_geglu_quick_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -5303,7 +5303,7 @@ static void ggml_compute_forward_soft_max_f32( //printf("p[%d] = %f\n", i, p[i]); assert(!isnan(wp[i])); } -#endif +#endif // NDEBUG float max = -INFINITY; ggml_vec_max_f32(ne00, &max, wp); @@ -5328,7 +5328,7 @@ static void ggml_compute_forward_soft_max_f32( assert(!isnan(dp[i])); assert(!isinf(dp[i])); } -#endif +#endif // NDEBUG } } } @@ -5402,7 +5402,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32( assert(!isnan(dy[i])); assert(!isnan(y[i])); } -#endif +#endif // NDEBUG // Jii = yi - yi*yi // Jij = -yi*yj // J = diag(y)-y.T*y @@ -5435,7 +5435,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32( assert(!isnan(dx[i])); assert(!isinf(dx[i])); } -#endif +#endif // NDEBUG } } @@ -10700,7 +10700,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32( assert(!isnan(s0[i])); assert(!isnan(s1[i])); } -#endif +#endif // NDEBUG float max = -INFINITY; ggml_vec_max_f32(nc, &max, s0); @@ -10719,7 +10719,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32( assert(!isnan(st[i])); assert(!isinf(st[i])); } -#endif +#endif // NDEBUG } sums[ith] = sum_thread; ggml_barrier(params->threadpool); @@ -10792,7 +10792,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( assert(!isnan(s0[i])); assert(!isnan(s1[i])); } -#endif +#endif // NDEBUG // soft_max float max = -INFINITY; @@ -10810,7 +10810,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( assert(!isnan(ds0[i])); assert(!isinf(ds0[i])); } -#endif +#endif // NDEBUG } } From d2d235f4679b3177e4c662d589d973550e06ab2c Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 6 Mar 2026 23:09:59 +0800 Subject: [PATCH 220/831] CUDA: use shared mem for ssm_conv (llama/20128) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CUDA: use shared mem for ssm_conv * fuse silu + ssm_conv * fuse unary + mul * enable for fp16 * formatting Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/ggml-cuda.cu | 54 +++++++++++++++++++++ ggml/src/ggml-cuda/ssm-conv.cu | 85 ++++++++++++++++++++------------- ggml/src/ggml-cuda/ssm-conv.cuh | 2 +- ggml/src/ggml-cuda/unary.cu | 55 +++++++++++++++++++++ ggml/src/ggml-cuda/unary.cuh | 2 + 5 files changed, 165 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b2dcaf42fc3..35015bc7f3b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3348,6 +3348,46 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return true; } + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY + && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { + const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; + const ggml_tensor * silu = cgraph->nodes[node_idx+1]; + + if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + + return true; + } + + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL + && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) { + const ggml_tensor * unary = cgraph->nodes[node_idx]; + const ggml_tensor * mul = cgraph->nodes[node_idx+1]; + + if (ggml_get_unary_op(unary) != unary_ops.begin()[0]) { + return false; + } + + if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) { + return false; + } + + if (unary->type != mul->type) { + return false; + } + + const ggml_tensor * other = (mul->src[0] == unary) ? mul->src[1] : mul->src[0]; + if (other->type != unary->type) { + return false; + } + if (!ggml_is_contiguous_1(other) || !ggml_is_contiguous_1(unary->src[0]) || !ggml_are_same_shape(other, unary)) { + return false; + } + + return true; + } + if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) { const ggml_tensor *scale = cgraph->nodes[node_idx]; @@ -3836,6 +3876,20 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { + ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) || + ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) || + ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) { + ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; + } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { i += 2; ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 6d5ea704c65..85e82b5a422 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -1,6 +1,7 @@ #include "ssm-conv.cuh" +#include "unary.cuh" -template +template static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, @@ -41,11 +42,11 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float for (size_t j = 0; j < d_conv; j++) { sumf += x[(i + j) % d_conv] * w[j]; } - y_block[i * stride_y + tid] = sumf; + y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } -template +template static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, @@ -65,36 +66,46 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const int stride_w = src1_nb1 / sizeof(float); const int stride_y = dst_nb1 / sizeof(float); - float x[d_conv] = { 0.0f }; - float w[d_conv] = { 0.0f }; + const int64_t local_n_t = min(split_n_t, n_t - bidz * split_n_t); + const int n_cols = d_conv - 1 + split_n_t; + + extern __shared__ float smem[]; + constexpr int load_cols = d_conv - 1 + split_n_t; + constexpr int total_elems = split_d_inner * load_cols; + int row = tid / load_cols; + int col = tid % load_cols; #pragma unroll - for (size_t j = 0; j < d_conv; j++) { - w[j] = w_block[tid * stride_w + j]; + for (int idx = tid; idx < total_elems; idx += split_d_inner) { + if (row < (int)split_d_inner) { + smem[row * n_cols + col] = x_block[row * stride_x + col]; + } + + col += split_d_inner; + row += col / load_cols; + col = col % load_cols; } + __syncthreads(); + // Load weights into registers (done once, small) + float w[d_conv] = { 0.0f }; #pragma unroll - for (int64_t i = 0; i < split_n_t; i++) { - if (bidz * split_n_t + i < n_t) { - float sumf = 0.0f; - - if (i == 0) { - for (size_t j = 0; j < d_conv; j++) { - x[j] = x_block[tid * stride_x + j]; - } - } else { - x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1]; - } + for (size_t j = 0; j < d_conv; j++) { + w[j] = w_block[tid * stride_w + j]; + } + // Compute from shared memory + for (int64_t i = 0; i < local_n_t; i++) { + float sumf = 0.0f; #pragma unroll - for (size_t j = 0; j < d_conv; j++) { - sumf += x[(i + j) % d_conv] * w[j]; - } - y_block[i * stride_y + tid] = sumf; + for (size_t j = 0; j < d_conv; j++) { + sumf += smem[tid * n_cols + i + j] * w[j]; } + y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } +template static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t, @@ -106,12 +117,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int constexpr int kNC = decltype(NC)::value; if (n_t <= 32) { const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); - ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { const int64_t split_n_t = 32; dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); - ssm_conv_long_token_f32<<>>( + const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float); + ssm_conv_long_token_f32<<>>( src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } }; @@ -124,27 +136,36 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int } } -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) { const struct ggml_tensor * src0 = dst->src[0]; // conv_x const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight + const bool fuse_silu = silu_dst != nullptr; + + // When fusing, write to silu_dst (the node downstream references). + const struct ggml_tensor * out = fuse_silu ? silu_dst : dst; const int64_t nc = src1->ne[0]; // d_conv const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = dst->ne[1]; // tokens per sequence - const int64_t n_s = dst->ne[2]; // number of sequences in the batch + const int64_t n_t = out->ne[1]; // tokens per sequence + const int64_t n_s = out->ne[2]; // number of sequences in the batch - GGML_ASSERT(dst->ne[0] == nr); + GGML_ASSERT(out->ne[0] == nr); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); const float * src0_d = (const float *) src0->data; const float * src1_d = (const float *) src1->data; - float * dst_d = (float *) dst->data; + float * dst_d = (float *) out->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1], - dst->nb[2], nc, nr, n_t, n_s, stream); + GGML_ASSERT(out->type == GGML_TYPE_F32); + if (fuse_silu) { + ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + out->nb[2], nc, nr, n_t, n_s, stream); + } else { + ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + out->nb[2], nc, nr, n_t, n_s, stream); + } } diff --git a/ggml/src/ggml-cuda/ssm-conv.cuh b/ggml/src/ggml-cuda/ssm-conv.cuh index 8e6c1f00bfa..f96a1cd2484 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cuh +++ b/ggml/src/ggml-cuda/ssm-conv.cuh @@ -1,3 +1,3 @@ #include "common.cuh" -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index d4866067a4f..4ad30fa1f35 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -560,3 +560,58 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream); } } + +/* fused unary + mul */ + +template +static void ggml_cuda_op_unary_mul_impl(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) { + // unary_node: UNARY op applied to unary_node->src[0] + // mul_node: MUL(a, b) where one of a/b is unary_node + // Output goes to mul_node->data + + const ggml_tensor * unary_src = unary_node->src[0]; // input to the unary op + const ggml_tensor * other_src = (mul_node->src[0] == unary_node) ? mul_node->src[1] : mul_node->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(unary_src)); + GGML_ASSERT(unary_src->nb[0] == ggml_element_size(unary_src)); + GGML_ASSERT(ggml_is_contiguous_1(other_src)); + GGML_ASSERT(other_src->nb[0] == ggml_element_size(other_src)); + GGML_ASSERT(ggml_are_same_shape(unary_src, other_src)); + + GGML_ASSERT(unary_src->type == GGML_TYPE_F32 || unary_src->type == GGML_TYPE_F16); + GGML_ASSERT(unary_src->type == other_src->type); + GGML_ASSERT(unary_src->type == mul_node->type); + + cudaStream_t stream = ctx.stream(); + + const int64_t k = ggml_nelements(mul_node); + const int64_t nc = unary_src->ne[0]; + const int64_t unary_stride = unary_src->nb[1]; + const int64_t other_stride = other_src->nb[1]; + + if (unary_src->type == GGML_TYPE_F16) { + unary_gated_cuda((const half *) unary_src->data, (const half *) other_src->data, + (half *) mul_node->data, k, nc, + unary_stride / sizeof(half), other_stride / sizeof(half), stream); + } else { + unary_gated_cuda((const float *) unary_src->data, (const float *) other_src->data, + (float *) mul_node->data, k, nc, + unary_stride / sizeof(float), other_stride / sizeof(float), stream); + } +} + +void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) { + switch (ggml_get_unary_op(unary_node)) { + case GGML_UNARY_OP_SILU: + ggml_cuda_op_unary_mul_impl(ctx, unary_node, mul_node); + break; + case GGML_UNARY_OP_SIGMOID: + ggml_cuda_op_unary_mul_impl(ctx, unary_node, mul_node); + break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_cuda_op_unary_mul_impl(ctx, unary_node, mul_node); + break; + default: + GGML_ABORT("Unsupported unary op for fused unary+mul"); + } +} diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 609046e5694..f1dd2183a6c 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -89,6 +89,8 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node); + __device__ __forceinline__ float ggml_cuda_op_silu_single(float x) { return x / (1.0f + expf(-x)); } From 548f2e51907ee66a07bbe96b791f79f88e61d5c9 Mon Sep 17 00:00:00 2001 From: shalinib-ibm Date: Fri, 6 Mar 2026 20:52:39 +0530 Subject: [PATCH 221/831] ggml-cpu: Fix gcc 15 ICE on ppc64le (ggml/20083) (llama/20130) This patch addresses an Internal Compiler Error (Segmentation fault) observed with gcc 15 by replacing the intrinsic + cast by doing a cat on the data first and then calling the intrinsic. This bypasses the buggy compiler path while maintaining identical instruction selection. Performance Verification: Assembly analysis on RHEL 9 (GCC 15.1.1) confirms that both the original code and this fix generate the identical Power10 prefixed load instruction: `plxv 40, 2(14)` This ensures zero performance regression while unblocking builds on newer toolchains. Reproduced on: - Alpine Linux + GCC 15.2.0-r2 - RHEL 9 + GCC 15.1.1 (gcc-toolset-15) Signed-off-by: Shalini Salomi Bodapati --- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 32 +++++++++++++-------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 5fd452a03d2..c89e5076f26 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -2497,7 +2497,7 @@ class tinyBLAS_Q0_PPC { for (int r = 0; r < 8; r++) { const block_q4_0 * current_blk = rows_base[r] + blk; vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d)); - vector signed char v_qs = reinterpret_cast(vec_xl(0, current_blk->qs)); + vector signed char v_qs = vec_xl(0, (const vector signed char *)current_blk->qs); vector signed char c1, c2; unpack_q4_to_q8(v_qs, c1, c2); convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]); @@ -2611,14 +2611,14 @@ class tinyBLAS_Q0_PPC { i = (cols >> 2); if (i > 0) { do { - c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); - c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); - c5[1] = reinterpret_cast(vec_xl(0, aoffset5->qs)); - c6[1] = reinterpret_cast(vec_xl(0, aoffset6->qs)); - c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); - c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); + c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs); + c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs); + c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs); + c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs); + c5[1] = vec_xl(0, (const vector signed char *)aoffset5->qs); + c6[1] = vec_xl(0, (const vector signed char *)aoffset6->qs); + c7[1] = vec_xl(0, (const vector signed char *)aoffset7->qs); + c8[1] = vec_xl(0, (const vector signed char *)aoffset8->qs); process_q4_elements(c1, & comparray[0]); process_q4_elements(c2, & comparray[1]); @@ -2657,10 +2657,10 @@ class tinyBLAS_Q0_PPC { i = (cols >> 2); if (i > 0) { do { - c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); - c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); + c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs); + c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs); + c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs); + c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs); process_q4_elements(c1, & comparray[0]); process_q4_elements(c2, & comparray[1]); @@ -2686,9 +2686,9 @@ class tinyBLAS_Q0_PPC { if (i > 0) { do { switch(rows) { - case 3: c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - case 2: c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + case 3: c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs); + case 2: c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs); + case 1: c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs); break; } process_q4_elements(c1, & comparray[0]); From 5d9b73dc066bfaa2987bc62e0f1faf7b24cfe6e9 Mon Sep 17 00:00:00 2001 From: Aaron Teo Date: Fri, 6 Mar 2026 23:24:38 +0800 Subject: [PATCH 222/831] ggml: update comments for backends which have no memory to report (llama/20157) Signed-off-by: Aaron Teo --- ggml/src/ggml-blas/ggml-blas.cpp | 4 ++-- ggml/src/ggml-opencl/ggml-opencl.cpp | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index 2e9ddf2240d..5de64b816fc 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -339,8 +339,8 @@ static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t } static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - // TODO - *free = 0; + // no memory to report + *free = 0; *total = 0; GGML_UNUSED(dev); diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 4ef33a7765d..0a2c86c6e22 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -5345,7 +5345,8 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_ } static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - *free = 0; + // no memory to report + *free = 0; *total = 0; GGML_UNUSED(dev); From d658720fa5221182009a79f8a29eb73fb8b814e4 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 7 Mar 2026 00:05:43 +0800 Subject: [PATCH 223/831] ggml-cuda: add mem check for fusion (llama/19916) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-cuda: add mem check for fusion * Replace NaNs with -FLT_MAX * fix typo Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/ggml-cuda.cu | 69 ++++++++++++++++++++++++++++++++- ggml/src/ggml-cuda/topk-moe.cu | 12 ++++++ 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 35015bc7f3b..54dc43bc088 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3412,6 +3412,69 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return false; } +// returns whether the write (out) nodes overwrite the read nodes in operation +static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph, + int node_idx, + int node_count, + int * out_nodes, + int out_count) { + auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) { + const int64_t a_start = (int64_t) a->data; + const int64_t a_end = a_start + ggml_nbytes(a); + + const int64_t b_start = (int64_t) b->data; + const int64_t b_end = b_start + ggml_nbytes(b); + + if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) { + return true; + } + + return false; + }; + + bool is_ok = true; + // for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok + if (ggml_nrows(cgraph->nodes[node_idx]) == 1) { + return true; + } + + for (int i = 0; i < out_count; ++i) { + const ggml_tensor * dst = cgraph->nodes[out_nodes[i]]; + + for (int j = node_idx; j < node_idx + node_count; ++j) { + // Loop over all srcs of all nodes in the fusion. If the src overlaps + // the destination and the src is not an intermediate node that's being + // elided, then disable fusion. + + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + const ggml_tensor * src = cgraph->nodes[j]->src[src_idx]; + + if (!src || src->op == GGML_OP_NONE) { + continue; + } + + if (nodes_overlap(dst, src)) { + bool found = false; + + for (int k = node_idx; k < j; ++k) { + if (cgraph->nodes[k] == src) { + found = true; + break; + } + } + + if (!found) { + is_ok = false; + break; + } + } + } + } + } + + return is_ok; +} + static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) { bool graph_evaluated_or_captured = false; @@ -3608,7 +3671,8 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud out_nodes[1] = i + ops.size() - 1; if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && - ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) { + ggml_cuda_should_use_topk_moe(node, logits, weights, ids) && + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) { ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); i += ops.size() - 1; continue; @@ -3623,7 +3687,8 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud int out_nodes[2] = { i + 1, i + 5 }; if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && - ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) { + ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) && + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) { ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); i += ops.size() - 1; continue; diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 08a88990dde..3020e5c7433 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -119,6 +119,18 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * } } + // Sanitize NaN to -FLT_MAX so the iterative argmax produces unique expert IDs. + // NaN comparisons always return false, which would cause the same expert to be + // selected repeatedly. -FLT_MAX compares normally and is still excluded by the + // -INFINITY sentinel used after each selection round. + // More relevant for the cuBLAS path. See https://github.com/ggml-org/llama.cpp/issues/19659 +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + if (__isnanf(wt[i])) { + wt[i] = -FLT_MAX; + } + } + // selection_wt is only needed when bias is present (selection uses wt + bias) // when no bias, we use wt directly for both selection and weight values float selection_wt[has_bias ? experts_per_thread : 1]; From 247ec204d867ea89e6dcab95ad727ca9cf1f29a5 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Fri, 6 Mar 2026 08:32:40 -0800 Subject: [PATCH 224/831] cpu: skip redudant ROPE cache updates (llama/20149) --- ggml/src/ggml-cpu/ops.cpp | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 243f01caf8e..2c372f9635b 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5803,28 +5803,33 @@ static void ggml_compute_forward_rope_flt( const int32_t * pos = (const int32_t *) src1->data; + int64_t last_i2 = -1; + for (int64_t i3 = 0; i3 < ne3; i3++) { // batch for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len - - float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - if (!mrope_used) { - const int64_t p = pos[i2]; - ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } - else { - const int64_t p_t = pos[i2]; - const int64_t p_h = pos[i2 + ne2]; - const int64_t p_w = pos[i2 + ne2 * 2]; - const int64_t p_e = pos[i2 + ne2 * 3]; - ggml_mrope_cache_init( - p_t, p_h, p_w, p_e, sections, is_imrope, is_vision, - freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } - for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads - if (ir++ < ir0) continue; + if (ir++ < ir0) continue; // skip rows mapped to other threads if (ir > ir1) break; + float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; + if (last_i2 != i2) { + if (!mrope_used) { + const int64_t p = pos[i2]; + ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + else { + const int64_t p_t = pos[i2]; + const int64_t p_h = pos[i2 + ne2]; + const int64_t p_w = pos[i2 + ne2 * 2]; + const int64_t p_e = pos[i2 + ne2 * 3]; + ggml_mrope_cache_init( + p_t, p_h, p_w, p_e, sections, is_imrope, is_vision, + freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + + last_i2 = i2; + } + T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); From 78b3801d54e681d8a7394d096678f978a2df956c Mon Sep 17 00:00:00 2001 From: Todor Boinovski Date: Fri, 6 Mar 2026 09:59:26 -0800 Subject: [PATCH 225/831] hexagon: add f32 ssm_conv op (llama/20122) * hexagon: add ssm_conv op * hexagon: hvx kernel is functional * hexagon: improvements to ssm-conv hvx kernel * hexagon: added dma to ssm-conv hvx kernel * hexagon: ssm-conv dynamically compute gather scratchpad * hex-ssm-conv: add local context and fix various issues (spad indexing, etc) --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 57 ++++ ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/htp-msg.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 4 +- ggml/src/ggml-hexagon/htp/hvx-utils.h | 8 + ggml/src/ggml-hexagon/htp/main.c | 49 ++++ ggml/src/ggml-hexagon/htp/ssm-conv.c | 339 +++++++++++++++++++++++ 7 files changed, 456 insertions(+), 3 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/ssm-conv.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index b70da8f3b28..d6e9776b878 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2152,6 +2152,44 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess return true; } +static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * dst = op; + + // Only support FP32 for now + if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + // Check IO tensor shapes and dims + if (src0->ne[3] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || dst->ne[3] != 1) { + return false; // src0 should be effectively 3D + } + + const int d_conv = src1->ne[0]; + const int d_inner = src0->ne[1]; + const int n_t = dst->ne[1]; + const int n_s = dst->ne[2]; + + if (src0->ne[0] != d_conv - 1 + n_t || src0->ne[1] != d_inner || src0->ne[2] != n_s) { + return false; + } + if (src1->ne[0] != d_conv || src1->ne[1] != d_inner) { + return false; + } + if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) { + return false; + } + + // TODO: add support for non-contiguous tensors + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + return false; + } + + return true; +} + enum dspqbuf_type { DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0, DSPQBUF_TYPE_CPU_WRITE_DSP_READ, @@ -2468,6 +2506,17 @@ static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buf return n_bufs; } +static inline size_t init_ssm_conv_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_SSM_CONV; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CONSTANT); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->name.c_str(); @@ -2606,6 +2655,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_SSM_CONV: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); } @@ -3024,6 +3077,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_argsort(sess, op); break; + case GGML_OP_SSM_CONV: + supp = ggml_hexagon_supported_ssm_conv(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 2c23b60da3d..02d07a503d5 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -31,6 +31,7 @@ add_library(${HTP_LIB} SHARED get-rows-ops.c cpy-ops.c argsort-ops.c + ssm-conv.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index 25403bb1126..52dcc36d8f7 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -68,6 +68,7 @@ enum htp_op { HTP_OP_SQR, HTP_OP_SQRT, HTP_OP_SUM_ROWS, + HTP_OP_SSM_CONV, INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 127ab1d6659..2ef20936f1b 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -41,9 +41,6 @@ struct htp_ops_context { worker_pool_context_t * wpool; // worker pool uint32_t n_threads; // num threads - uint32_t src0_nrows_per_thread; - uint32_t src1_nrows_per_thread; - uint32_t flags; }; @@ -61,5 +58,6 @@ int op_set_rows(struct htp_ops_context * octx); int op_get_rows(struct htp_ops_context * octx); int op_cpy(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); +int op_ssm_conv(struct htp_ops_context * octx); #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index a518ad37331..08343798794 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -15,4 +15,12 @@ #include "hvx-div.h" #include "hvx-base.h" +#ifndef GATHER_TYPE +# if defined(__hexagon__) +# define GATHER_TYPE(_a) (intptr_t) _a +# else +# define GATHER_TYPE(_a) (HVX_Vector *) _a +# endif +#endif + #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 92a1422896c..3f99dbb32c4 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -757,6 +757,47 @@ static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; + + // We've written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup OP context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_ssm_conv(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_activations_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs, @@ -1142,6 +1183,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_argsort_req(ctx, &req, bufs); break; + case HTP_OP_SSM_CONV: + if (n_bufs != 3) { + FARF(ERROR, "Bad ssm-conv-req buffer list"); + continue; + } + proc_ssm_conv_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c new file mode 100644 index 00000000000..b3c1ef9572e --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -0,0 +1,339 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "hex-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" + +#define htp_ssm_conv_tensors_preamble \ + struct htp_tensor * restrict src0 = &octx->src0; \ + struct htp_tensor * restrict src1 = &octx->src1; \ + struct htp_tensor * restrict dst = &octx->dst; \ + struct htp_spad * restrict src0_spad = &octx->src0_spad; \ + struct htp_spad * restrict src1_spad = &octx->src1_spad; \ + struct htp_spad * restrict dst_spad = &octx->dst_spad; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_ssm_conv_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; + uint64_t t_start; +}; + +#define htp_ssm_conv_preamble \ + struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \ + struct htp_ops_context * octx = scctx->octx; \ + htp_ssm_conv_tensors_preamble; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; + +// Scalar FP32 SSM_CONV implementation +static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { + htp_ssm_conv_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; + const uint32_t n_s = dst->ne[2]; + + const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); // stride for inner dimension + const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float); // stride for sequence dimension + const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); // stride for inner dimension + const uint32_t dst_stride_token = dst->nb[1] / sizeof(float); // stride for token dimension + const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float); // stride for sequence dimension + + const float * src0_data = (const float *) src0->data; + const float * src1_data = (const float *) src1->data; + float * dst_data = (float *) dst->data; + + // Calculate row range for this thread + const uint32_t d_inner_per_thread = scctx->nrows_per_thread; + const uint32_t d_inner_start = d_inner_per_thread * ith; + const uint32_t d_inner_end = MIN(d_inner_start + d_inner_per_thread, d_inner); + + // No work for this thread + if (d_inner_start >= d_inner_end) { + return; + } + + for (uint32_t i3 = 0; i3 < n_s; ++i3) { + for (uint32_t i2 = 0; i2 < n_t; ++i2) { + for (uint32_t i1 = d_inner_start; i1 < d_inner_end; ++i1) { + float sumf = 0.0f; + + for (uint32_t i0 = 0; i0 < d_conv; ++i0) { + const uint32_t src0_idx = (i2 + i0) + i1 * src0_stride_inner + i3 * src0_stride_seq; + const uint32_t src1_idx = i0 + i1 * src1_stride_inner; + + sumf += src0_data[src0_idx] * src1_data[src1_idx]; + } + + const uint32_t dst_idx = i1 + i2 * dst_stride_token + i3 * dst_stride_seq; + dst_data[dst_idx] = sumf; + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "ssm-conv-f32 %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], d_inner_start, d_inner_end, + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], + dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension +static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { + htp_ssm_conv_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const int nc = src1->ne[0]; // d_conv + const int ncs = src0->ne[0]; // d_conv - 1 + n_t + + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; + const uint32_t n_s = dst->ne[2]; + + const float * src0_data = (const float *) src0->data; + const float * src1_data = (const float *) src1->data; + float * dst_data = (float *) dst->data; + + // Calculate row range for this thread + const int dr = scctx->nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = MIN(ir0 + dr, d_inner); + const int ir = ir1 - ir0; + + if (ir0 >= ir1) { + return; // No work for this thread + } + + // src0 and src1 gather offsets + uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 }; + uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 }; + + for (uint32_t i = 0; i < VLEN_FP32; ++i) { + src0_offsets[i] = i * (ncs) * sizeof(float); + src1_offsets[i] = i * (d_conv) * sizeof(float); + } + + const uint32_t src0_gather_len = VLEN * ncs; + const uint32_t src1_gather_len = VLEN * d_conv; + + // gather scratchpads + HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + 0); + HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + VLEN); + + float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]); + float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]); + + uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + + // copy src1 workload to VTCM + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir); + + // FARF(HIGH, "ssm-conv-src1-fetch %d: ir0 %u size %u\n", ith, ir0, nb11 * ir); + + for (uint32_t i3 = 0; i3 < n_s; ++i3) { + float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2])); + + // copy src0 workload to VTCM + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir); + + // FARF(HIGH, "ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\n", ith, ir0, i3, nb01 * ir); + + dma_queue_flush(dma_queue); + + for (uint32_t i2 = 0; i2 < n_t; ++i2) { + float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2])); + + const uint32_t nvec = ir / VLEN_FP32; + const uint32_t nloe = ir % VLEN_FP32; + uint32_t i1 = 0; + + for (uint32_t vi1 = 0; vi1 < nvec; vi1++) { + HVX_Vector acc_vec = Q6_V_vsplat_R(0); + + for (uint32_t i0 = 0; i0 < d_conv; ++i0) { + Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])), + src0_gather_len, (*(const HVX_Vector *) src0_offsets)); + Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)), + src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + + HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); + acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); + } + + *(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec); + i1 += VLEN_FP32; + } + + if (nloe) { + HVX_Vector acc_vec = Q6_V_vsplat_R(0); + + for (uint32_t i0 = 0; i0 < d_conv; ++i0) { + Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])), + src0_gather_len, (*(const HVX_Vector *) src0_offsets)); + Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)), + src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + + HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); + acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); + } + + hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec)); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], + dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_ssm_conv_f32(struct htp_ops_context * octx) { + htp_ssm_conv_tensors_preamble; + + if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) { + FARF(ERROR, "ssm_conv: only (F32 x F32 -> F32) OPs supported"); + return HTP_STATUS_NO_SUPPORT; + } + + struct htp_ssm_conv_context scctx = { 0 }; + scctx.octx = octx; + + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; // tokens per sequence + const uint32_t n_s = dst->ne[2]; // number of sequences in the batch + + const uint32_t n_threads = MIN(octx->n_threads, d_inner); + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + uint32_t use_hvx = 0; + if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) { + int is_aligned = hex_is_aligned((void *) src0->data, VLEN) && + hex_is_aligned((void *) src1->data, VLEN) && + hex_is_aligned((void *) dst->data, VLEN); + + if (is_aligned) { + use_hvx = 1; + } + } + + if (use_hvx) { + scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; // d_inner chunks per thread + scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); // round up to even + + octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256); + octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256); + octx->dst_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads; + + // Compute gather scratchpad size for src0 and src1 + const size_t gather_spad_size = n_threads * VLEN * 2; + + octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + + FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n", + gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, + octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, + octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data); + + const size_t total_spad_size = + gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + + if (total_spad_size > octx->ctx->vtcm_size) { + FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu", total_spad_size, + octx->ctx->vtcm_size); + use_hvx = 0; + } + } + + FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d\n", src0->ne[0], + src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], + dst->ne[1], dst->ne[2], dst->ne[3], use_hvx); + + if (use_hvx) { + worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32_hvx, &scctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32, &scctx, n_threads); + } + } + + return HTP_STATUS_OK; +} + +int op_ssm_conv(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + struct htp_tensor * dst = &octx->dst; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_ssm_conv_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} From 6e063fae5a283c6def4c4d38d91b119be67d0b6b Mon Sep 17 00:00:00 2001 From: Bartowski <3266127+bartowski1182@users.noreply.github.com> Date: Fri, 6 Mar 2026 16:06:56 -0500 Subject: [PATCH 226/831] quants : Add memsets and other fixes for IQ quants (llama/19861) * Add memsets and other fixes for IQ quants * Make memset unconditional, change Laux back to L * Move another memset --- ggml/src/ggml-quants.c | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index de5cbd75e86..e8e25633fb8 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3104,6 +3104,11 @@ static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML } float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight); float eff_max = scale*kMaxQ; + if (eff_max <= 0) { + scales[ib] = 0; + memset(L, 0, 32); + continue; + } float best = 0; for (int is = -6; is <= 6; ++is) { float id = (2*kMaxQ-1+is*0.1f)/eff_max; @@ -3273,9 +3278,9 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_ } float max = xval[0]; for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + memset(L, 0, 16); if (max < GROUP_MAX_EPS) { scales[ib] = 0; - memset(L, 0, 16); continue; } float best = 0; @@ -3714,9 +3719,9 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * GGML_RESTRICT } float max = xval[0]; for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]); + memset(L, 0, 32); if (max < GROUP_MAX_EPS_IQ3_XXS) { scales[ib] = 0; - memset(L, 0, 32); continue; } float best = 0; @@ -3922,6 +3927,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * GGML_RESTRICT } float max = xval[0]; for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]); + memset(L, 0, block_size); if (!max) { scales[ib] = 0; continue; @@ -4245,6 +4251,7 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); if (max < GROUP_MAX_EPS_IQ1_S) { scales[ib] = 0; + shifts[ib] = 1; memset(L, 1, block_size); continue; } @@ -4285,7 +4292,12 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R } } } - GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0); + if (besti1 < 0 || besti2 < 0 || best_shift == 0) { + scales[ib] = 0; + shifts[ib] = 1; + memset(L, 1, block_size); + continue; + } for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; @@ -4429,6 +4441,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); if (max < GROUP_MAX_EPS_IQ1_M) { scales[ib] = 0; + shifts[ib] = 0; memset(L, 1, block_size); continue; } @@ -4527,7 +4540,12 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R } } } - GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0); + if (besti1 < 0 || besti2 < 0 || best_k < 0) { + scales[ib] = 0; + shifts[ib] = 0; + memset(L, 1, block_size); + continue; + } for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; @@ -4874,6 +4892,7 @@ static void quantize_row_iq2_s_impl(const float * GGML_RESTRICT x, void * GGML_R } float max = xval[0]; for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + memset(L, 0, 16); if (max < GROUP_MAX_EPS_IQ2_S) { scales[ib] = 0; continue; From 910034df28633ccde365f58c9d3b83b5a65ba58d Mon Sep 17 00:00:00 2001 From: lhez Date: Fri, 6 Mar 2026 18:03:05 -0800 Subject: [PATCH 227/831] opencl: add l2_norm (llama/20160) --- ggml/src/ggml-opencl/CMakeLists.txt | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 84 +++++++++++++++++++++++++ ggml/src/ggml-opencl/kernels/l2_norm.cl | 71 +++++++++++++++++++++ 3 files changed, 156 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/l2_norm.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index fb3ae17eaf4..70802c9c001 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -116,6 +116,7 @@ set(GGML_OPENCL_KERNELS neg norm relu + l2_norm rms_norm rope scale diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0a2c86c6e22..67e4b9277f5 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -497,6 +497,7 @@ struct ggml_backend_opencl_context { kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16; cl_kernel kernel_norm, kernel_norm_mul_add; cl_kernel kernel_rms_norm, kernel_rms_norm_mul; + cl_kernel kernel_l2_norm_f32; cl_kernel kernel_group_norm, kernel_group_norm_mul_add; cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; cl_kernel kernel_diag_f32; @@ -1585,6 +1586,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // l2_norm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "l2_norm.cl.h" + }; +#else + const std::string kernel_src = read_file("l2_norm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_l2_norm_f32 = clCreateKernel(prog, "kernel_l2_norm_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // rope { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3689,6 +3707,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return true; case GGML_OP_RMS_NORM: return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_L2_NORM: + return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_REPEAT: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded case GGML_OP_PAD: @@ -7554,6 +7574,64 @@ static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_cl_l2_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + + size_t sgs; + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; + } else { + GGML_ASSERT(false && "Unsupported GPU"); + } + + cl_kernel kernel = backend_ctx->kernel_l2_norm_f32; + + int nth = sgs; + while (nth < ne00 && nth < (int)backend_ctx->get_kernel_workgroup_size(kernel)) { + nth *= 2; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth/sgs, NULL)); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -12184,6 +12262,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_rms_norm; break; + case GGML_OP_L2_NORM: + if (!any_on_device) { + return false; + } + func = ggml_cl_l2_norm; + break; case GGML_OP_GROUP_NORM: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/l2_norm.cl b/ggml/src/ggml-opencl/kernels/l2_norm.cl new file mode 100644 index 00000000000..39f400199fa --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/l2_norm.cl @@ -0,0 +1,71 @@ +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_l2_norm_f32( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + float eps, + local float * sum +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float * y = (global float *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float sumf = 0; + + // parallel sum + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + sumf += x[i00] * x[i00]; + } + sumf = sub_group_reduce_add(sumf); + + if (get_sub_group_local_id() == 0) { + sum[get_sub_group_id()] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // broadcast + for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + const float scale = 1.0f/sqrt(max(sum[0], eps)); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = x[i00] * scale; + } +} From 49489bfbd1704bf8e222ea7bef49153da23d45d2 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 7 Mar 2026 15:41:10 +0800 Subject: [PATCH 228/831] ggml: add GATED_DELTA_NET op (llama/19504) * ggml: add GATED_DELTA_NET op * remove the transpose * add KDA * add qwen35 dense * llama : check for fused gated delta net backend support --------- Co-authored-by: Georgi Gerganov --- ggml/include/ggml.h | 10 ++ ggml/src/ggml-cpu/ggml-cpu.c | 10 ++ ggml/src/ggml-cpu/ops.cpp | 184 ++++++++++++++++++++ ggml/src/ggml-cpu/ops.h | 1 + ggml/src/ggml-cuda/gated_delta_net.cu | 223 +++++++++++++++++++++++++ ggml/src/ggml-cuda/gated_delta_net.cuh | 4 + ggml/src/ggml-cuda/ggml-cuda.cu | 5 + ggml/src/ggml.c | 57 ++++++- 8 files changed, 492 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-cuda/gated_delta_net.cu create mode 100644 ggml/src/ggml-cuda/gated_delta_net.cuh diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 784d69206b4..566e2714790 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -556,6 +556,7 @@ extern "C" { GGML_OP_GATED_LINEAR_ATTN, GGML_OP_RWKV_WKV7, GGML_OP_SOLVE_TRI, + GGML_OP_GATED_DELTA_NET, GGML_OP_UNARY, @@ -2463,6 +2464,15 @@ extern "C" { bool lower, bool uni); + GGML_API struct ggml_tensor * ggml_gated_delta_net( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state); + // custom operators typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7c4026fac4e..dc2b5ffaa77 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2021,6 +2021,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_solve_tri(params, tensor); } break; + case GGML_OP_GATED_DELTA_NET: + { + ggml_compute_forward_gated_delta_net(params, tensor); + } break; case GGML_OP_MAP_CUSTOM1: { ggml_compute_forward_map_custom1(params, tensor); @@ -2200,6 +2204,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_COUNT_EQUAL: case GGML_OP_SOLVE_TRI: + case GGML_OP_GATED_DELTA_NET: { n_tasks = n_threads; } break; @@ -2905,6 +2910,11 @@ struct ggml_cplan ggml_graph_plan( { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); } break; + case GGML_OP_GATED_DELTA_NET: + { + const int64_t S_v = node->src[2]->ne[0]; + cur = S_v * sizeof(float) * n_tasks; + } break; case GGML_OP_COUNT: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 2c372f9635b..331e071a267 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10380,6 +10380,190 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s } } +// ggml_compute_forward_gated_delta_net +static void ggml_compute_forward_gated_delta_net_one_chunk( + const ggml_compute_params * params, + ggml_tensor * dst, + int64_t ir0, + int64_t ir1) { + + ggml_tensor * src_q = dst->src[0]; + ggml_tensor * src_k = dst->src[1]; + ggml_tensor * src_v = dst->src[2]; + ggml_tensor * src_g = dst->src[3]; + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + + const int64_t S_v = src_v->ne[0]; + const int64_t H = src_v->ne[1]; + const int64_t n_tokens = src_v->ne[2]; + const int64_t n_seqs = src_v->ne[3]; + + GGML_ASSERT(ggml_is_contiguous_rows(src_q)); + GGML_ASSERT(ggml_is_contiguous_rows(src_k)); + GGML_ASSERT(ggml_is_contiguous_rows(src_v)); + GGML_ASSERT(ggml_is_contiguous(src_g)); + GGML_ASSERT(ggml_is_contiguous(src_beta)); + GGML_ASSERT(ggml_is_contiguous(src_state)); + + GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v); + GGML_ASSERT(src_beta->ne[0] == 1); + + GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); + GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne); + GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb); + GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne); + GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb); + GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); + + const bool kda = (neg0 == S_v); + + // scratch layout per thread: [delta(S_v)] + const int64_t scratch_per_thread = S_v; + const int ith = params->ith; + + float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32; + + // output layout: [attn_scores | new_states] + // attn_scores: S_v * H * n_tokens * n_seqs floats + // new_states: S_v * S_v * H * n_seqs floats + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + float * attn_out_base = (float *)dst->data; + float * state_out_base = (float *)dst->data + attn_score_elems; + + const float * state_in_base = (const float *)src_state->data; + + const int64_t rq1 = nev1 / neq1; + const int64_t rk1 = nev1 / nek1; + const int64_t rq3 = nev3 / neq3; + const int64_t rk3 = nev3 / nek3; + + const float scale = 1.0f / sqrtf((float) S_v); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t iv1 = ir % H; // head_index + const int64_t iv3 = ir / H; // sequence + + const int64_t iq1 = iv1 / rq1; + const int64_t ik1 = iv1 / rk1; + + const int64_t iq3 = iv3 / rq3; + const int64_t ik3 = iv3 / rk3; + + float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v; + + // copy input state into output buffer and operate in-place + const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v; + memcpy(s_out, s_in, S_v * S_v * sizeof(float)); + + // attn output pointer for first token of this (head, seq) + float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v; + + for (int64_t t = 0; t < n_tokens; t++) { + const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1); + const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1); + const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1); + + const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1); + const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); + + if (kda) { + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i])); + } + } else { + ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0])); + } + + // delta[j] = sum_i S[j][i] * k[i] + memset(delta, 0, S_v * sizeof(float)); + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]); + } + for (int64_t j = 0; j < S_v; ++j) { + delta[j] = (v_d[j] - delta[j]) * beta_val; + } + + // outer product: S[j][i] += k[i] * delta[j] + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]); + } + + // attn_out[j] = sum_i S[j][i] * q[i] + memset(attn_data, 0, S_v * sizeof(float)); + for (int64_t i = 0; i < S_v; ++i) { + ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]); + } + ggml_vec_scale_f32(S_v, attn_data, scale); + + attn_data += S_v * H; // advance to next token + } + + } +} + + +static void ggml_compute_forward_gated_delta_net_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + ggml_tensor * V = dst->src[2]; + int64_t nr = V->ne[1] * V->ne[3]; + + // disable for NUMA + const bool disable_chunking = ggml_is_numa(); + + int nth = params->nth; + int ith = params->ith; + + // 4x chunks per thread + int nth_scaled = nth * 4; + int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + int64_t nchunk = (nr + chunk_size - 1) / chunk_size; + + if (nth == 1 || nchunk < nth || disable_chunking) { + nchunk = nth; + } + + if (ith == 0) { + ggml_threadpool_chunk_set(params->threadpool, nth); + } + + ggml_barrier(params->threadpool); + + const int64_t dr = (nr + nchunk - 1) / nchunk; + + int current_chunk = ith; + + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); + + ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1); + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + } +} + +void ggml_compute_forward_gated_delta_net( + const ggml_compute_params * params, + ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gated_delta_net_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_rwkv_wkv7 static void ggml_compute_forward_rwkv_wkv7_f32( diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 0fdfee79766..3fa1443abc4 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -102,6 +102,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu new file mode 100644 index 00000000000..d8e81114559 --- /dev/null +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -0,0 +1,223 @@ +#include "gated_delta_net.cuh" +#include "ggml-cuda/common.cuh" + +template +__global__ void gated_delta_net_cuda(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * curr_state, + float * dst, + int64_t H, + int64_t n_tokens, + int64_t n_seqs, + int64_t sq1, + int64_t sq2, + int64_t sq3, + int64_t sv1, + int64_t sv2, + int64_t sv3, + int64_t sb1, + int64_t sb2, + int64_t sb3, + int64_t rq1, + int64_t rq3, + float scale) { + const int64_t h_idx = blockIdx.x; + const int64_t sequence = blockIdx.y; + const int col = threadIdx.x; // each thread owns one column + + const int64_t iq1 = h_idx / rq1; + const int64_t iq3 = sequence / rq3; + + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + float * attn_data = dst; + float * state = dst + attn_score_elems; + + const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; + state += state_offset; + curr_state += state_offset; + attn_data += (sequence * n_tokens * H + h_idx) * S_v; + + // Load state column into registers + float s[S_v]; +#pragma unroll + for (int i = 0; i < S_v; i++) { + s[i] = curr_state[i * S_v + col]; + } + + for (int t = 0; t < n_tokens; t++) { + const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1; + + const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1; + const float * beta_t = beta + gb_offset; + const float * g_t = g + gb_offset * (KDA ? S_v : 1); + + const float beta_val = *beta_t; + + if constexpr (!KDA) { + const float g_val = expf(*g_t); + + // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i] + float kv_col = 0.0f; +#pragma unroll + for (int i = 0; i < S_v; i++) { + kv_col += s[i] * k_t[i]; + } + + // delta[col] = (v[col] - g * kv[col]) * beta + float delta_col = (v_t[col] - g_val * kv_col) * beta_val; + + // fused: S[i][col] = g * S[i][col] + k[i] * delta[col] + // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] + float attn_col = 0.0f; +#pragma unroll + for (int i = 0; i < S_v; i++) { + s[i] = g_val * s[i] + k_t[i] * delta_col; + attn_col += s[i] * q_t[i]; + } + + attn_data[col] = attn_col * scale; + } else { + // kv[col] = sum_i g[i] * S[i][col] * k[i] + float kv_col = 0.0f; +#pragma unroll + for (int i = 0; i < S_v; i++) { + kv_col += expf(g_t[i]) * s[i] * k_t[i]; + } + + // delta[col] = (v[col] - kv[col]) * beta + float delta_col = (v_t[col] - kv_col) * beta_val; + + // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col] + // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] + float attn_col = 0.0f; +#pragma unroll + for (int i = 0; i < S_v; i++) { + s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col; + attn_col += s[i] * q_t[i]; + } + + attn_data[col] = attn_col * scale; + } + + attn_data += S_v * H; + } + + // Write state back to global memory +#pragma unroll + for (int i = 0; i < S_v; i++) { + state[i * S_v + col] = s[i]; + } +} + +template +static void launch_gated_delta_net( + const float * q_d, const float * k_d, const float * v_d, + const float * g_d, const float * b_d, const float * s_d, + float * dst_d, + int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs, + int64_t sq1, int64_t sq2, int64_t sq3, + int64_t sv1, int64_t sv2, int64_t sv3, + int64_t sb1, int64_t sb2, int64_t sb3, + int64_t rq1, int64_t rq3, + float scale, cudaStream_t stream) { + + dim3 grid_dims(H, n_seqs, 1); + dim3 block_dims(S_v, 1, 1); + + switch (S_v) { + case 32: + gated_delta_net_cuda<32, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale); + break; + case 64: + gated_delta_net_cuda<64, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale); + break; + case 128: + gated_delta_net_cuda<128, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale); + break; + default: + GGML_ABORT("fatal error"); + break; + } +} + +void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_tensor * src_q = dst->src[0]; + ggml_tensor * src_k = dst->src[1]; + ggml_tensor * src_v = dst->src[2]; + ggml_tensor * src_g = dst->src[3]; + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + + GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); + GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); + + const int64_t S_v = nev0; + const int64_t H = nev1; + const int64_t n_tokens = nev2; + const int64_t n_seqs = nev3; + + const bool kda = (src_g->ne[0] == S_v); + + const int64_t rq1 = nev1 / neq1; + const int64_t rq3 = nev3 / neq3; + + const float * q_d = (const float *) src_q->data; + const float * k_d = (const float *) src_k->data; + const float * v_d = (const float *) src_v->data; + const float * g_d = (const float *) src_g->data; + const float * b_d = (const float *) src_beta->data; + + const float * s_d = (const float *) src_state->data; + float * dst_d = (float *) dst->data; + + GGML_ASSERT(ggml_is_contiguous_rows(src_q)); + GGML_ASSERT(ggml_is_contiguous_rows(src_k)); + GGML_ASSERT(ggml_is_contiguous_rows(src_v)); + GGML_ASSERT(ggml_are_same_stride(src_q, src_k)); + GGML_ASSERT(src_g->ne[0] == 1 || kda); + GGML_ASSERT(ggml_is_contiguous(src_g)); + GGML_ASSERT(ggml_is_contiguous(src_beta)); + GGML_ASSERT(ggml_is_contiguous(src_state)); + + // strides in floats (beta strides used for both g and beta offset computation) + const int64_t sq1 = nbq1 / sizeof(float); + const int64_t sq2 = nbq2 / sizeof(float); + const int64_t sq3 = nbq3 / sizeof(float); + const int64_t sv1 = nbv1 / sizeof(float); + const int64_t sv2 = nbv2 / sizeof(float); + const int64_t sv3 = nbv3 / sizeof(float); + const int64_t sb1 = nbb1 / sizeof(float); + const int64_t sb2 = nbb2 / sizeof(float); + const int64_t sb3 = nbb3 / sizeof(float); + + const float scale = 1.0f / sqrtf((float) S_v); + + cudaStream_t stream = ctx.stream(); + + if (kda) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, rq1, rq3, scale, stream); + } +} diff --git a/ggml/src/ggml-cuda/gated_delta_net.cuh b/ggml/src/ggml-cuda/gated_delta_net.cuh new file mode 100644 index 00000000000..7375e81c0c3 --- /dev/null +++ b/ggml/src/ggml-cuda/gated_delta_net.cuh @@ -0,0 +1,4 @@ +#include "common.cuh" +#include "ggml.h" + +void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 54dc43bc088..a8007a06360 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -53,6 +53,7 @@ #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" +#include "ggml-cuda/gated_delta_net.cuh" #include "ggml-cuda/set.cuh" #include "ggml-cuda/set-rows.cuh" #include "ggml-cuda/pad_reflect_1d.cuh" @@ -2733,6 +2734,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GATED_LINEAR_ATTN: ggml_cuda_op_gated_linear_attn(ctx, dst); break; + case GGML_OP_GATED_DELTA_NET: + ggml_cuda_op_gated_delta_net(ctx, dst); + break; case GGML_OP_RWKV_WKV7: ggml_cuda_op_rwkv_wkv7(ctx, dst); break; @@ -4972,6 +4976,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_LEAKY_RELU: case GGML_OP_RWKV_WKV6: case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_GATED_DELTA_NET: case GGML_OP_RWKV_WKV7: return true; case GGML_OP_FLASH_ATTN_EXT: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d644cca8a6e..aeafc395d71 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1031,6 +1031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GATED_LINEAR_ATTN", "RWKV_WKV7", "SOLVE_TRI", + "GATED_DELTA_NET", "UNARY", @@ -1048,7 +1049,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1140,6 +1141,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "gated_linear_attn(k, v, q, gate, s)", "rwkv_wkv7(r, w, k, v, a, b, s)", "A X = B, A triangular, solve X", + "gated_delta_net(q, k, v, g, beta, s)", "unary(x)", @@ -1157,7 +1159,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -6124,6 +6126,57 @@ struct ggml_tensor * ggml_solve_tri( return result; } +// ggml_gated_delta_net + +struct ggml_tensor * ggml_gated_delta_net( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state) { + GGML_ASSERT(ggml_is_contiguous_rows(q)); + GGML_ASSERT(ggml_is_contiguous_rows(k)); + GGML_ASSERT(ggml_is_contiguous_rows(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ggml_is_contiguous(state)); + + GGML_ASSERT(q->type == GGML_TYPE_F32); + GGML_ASSERT(k->type == GGML_TYPE_F32); + GGML_ASSERT(v->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(beta->type == GGML_TYPE_F32); + GGML_ASSERT(state->type == GGML_TYPE_F32); + + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; + const int64_t n_tokens = v->ne[2]; + const int64_t n_seqs = v->ne[3]; + + // gate: scalar [1, H, T, B] or vector [S_v, H, T, B] (KDA) + GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); + GGML_ASSERT(beta->ne[0] == 1); + + GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs); + + // concat output and new_state into a single tensor + // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs + const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + result->op = GGML_OP_GATED_DELTA_NET; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = g; + result->src[4] = beta; + result->src[5] = state; + + return result; +} + //////////////////////////////////////////////////////////////////////////////// struct ggml_hash_set ggml_hash_set_new(size_t size) { From 8a9b0ba1dff0d73c132a89a047320a5718e71f85 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Sun, 8 Mar 2026 12:00:07 +0800 Subject: [PATCH 229/831] supprt Flash Attention for fp32/fp16/Q4/Q5/Q8 (llama/20190) * support flash-attention for fp32/fp16/Q4/Q5/Q8 * rm warining * update for JIT --- ggml/src/ggml-sycl/CMakeLists.txt | 6 + ggml/src/ggml-sycl/backend.hpp | 1 + ggml/src/ggml-sycl/common.hpp | 226 ++- ggml/src/ggml-sycl/convert.cpp | 70 +- ggml/src/ggml-sycl/convert.hpp | 17 +- ggml/src/ggml-sycl/count-equal.cpp | 2 +- ggml/src/ggml-sycl/dpct/helper.hpp | 772 ++++++++++ ggml/src/ggml-sycl/fattn-common.hpp | 1179 +++++++++++++++ ggml/src/ggml-sycl/fattn-tile.cpp | 55 + ggml/src/ggml-sycl/fattn-tile.hpp | 1338 +++++++++++++++++ ggml/src/ggml-sycl/fattn-vec.hpp | 667 ++++++++ ggml/src/ggml-sycl/fattn.cpp | 225 +++ ggml/src/ggml-sycl/fattn.hpp | 22 + ggml/src/ggml-sycl/ggml-sycl.cpp | 60 +- ggml/src/ggml-sycl/presets.hpp | 3 + ggml/src/ggml-sycl/softmax.cpp | 10 +- .../fattn-tile-instance-dkq112-dv112.cpp | 5 + .../fattn-tile-instance-dkq128-dv128.cpp | 5 + .../fattn-tile-instance-dkq256-dv256.cpp | 5 + .../fattn-tile-instance-dkq40-dv40.cpp | 5 + .../fattn-tile-instance-dkq576-dv512.cpp | 5 + .../fattn-tile-instance-dkq64-dv64.cpp | 5 + .../fattn-tile-instance-dkq72-dv72.cpp | 5 + .../fattn-tile-instance-dkq80-dv80.cpp | 5 + .../fattn-tile-instance-dkq96-dv96.cpp | 5 + .../fattn-vec-instance-f16-f16.cpp | 7 + .../fattn-vec-instance-f16-q4_0.cpp | 7 + .../fattn-vec-instance-f16-q4_1.cpp | 7 + .../fattn-vec-instance-f16-q5_0.cpp | 7 + .../fattn-vec-instance-f16-q5_1.cpp | 7 + .../fattn-vec-instance-f16-q8_0.cpp | 7 + .../fattn-vec-instance-q4_0-f16.cpp | 7 + .../fattn-vec-instance-q4_0-q4_0.cpp | 7 + .../fattn-vec-instance-q4_0-q4_1.cpp | 7 + .../fattn-vec-instance-q4_0-q5_0.cpp | 7 + .../fattn-vec-instance-q4_0-q5_1.cpp | 7 + .../fattn-vec-instance-q4_0-q8_0.cpp | 7 + .../fattn-vec-instance-q4_1-f16.cpp | 7 + .../fattn-vec-instance-q4_1-q4_0.cpp | 7 + .../fattn-vec-instance-q4_1-q4_1.cpp | 7 + .../fattn-vec-instance-q4_1-q5_0.cpp | 7 + .../fattn-vec-instance-q4_1-q5_1.cpp | 7 + .../fattn-vec-instance-q4_1-q8_0.cpp | 7 + .../fattn-vec-instance-q5_0-f16.cpp | 7 + .../fattn-vec-instance-q5_0-q4_0.cpp | 7 + .../fattn-vec-instance-q5_0-q4_1.cpp | 7 + .../fattn-vec-instance-q5_0-q5_0.cpp | 7 + .../fattn-vec-instance-q5_0-q5_1.cpp | 7 + .../fattn-vec-instance-q5_0-q8_0.cpp | 7 + .../fattn-vec-instance-q5_1-f16.cpp | 7 + .../fattn-vec-instance-q5_1-q4_0.cpp | 7 + .../fattn-vec-instance-q5_1-q4_1.cpp | 7 + .../fattn-vec-instance-q5_1-q5_0.cpp | 7 + .../fattn-vec-instance-q5_1-q5_1.cpp | 7 + .../fattn-vec-instance-q5_1-q8_0.cpp | 7 + .../fattn-vec-instance-q8_0-f16.cpp | 7 + .../fattn-vec-instance-q8_0-q4_0.cpp | 7 + .../fattn-vec-instance-q8_0-q4_1.cpp | 7 + .../fattn-vec-instance-q8_0-q5_0.cpp | 7 + .../fattn-vec-instance-q8_0-q5_1.cpp | 7 + .../fattn-vec-instance-q8_0-q8_0.cpp | 7 + ggml/src/ggml-sycl/vecdotq.hpp | 13 + 62 files changed, 4936 insertions(+), 27 deletions(-) create mode 100644 ggml/src/ggml-sycl/fattn-common.hpp create mode 100644 ggml/src/ggml-sycl/fattn-tile.cpp create mode 100644 ggml/src/ggml-sycl/fattn-tile.hpp create mode 100644 ggml/src/ggml-sycl/fattn-vec.hpp create mode 100644 ggml/src/ggml-sycl/fattn.cpp create mode 100644 ggml/src/ggml-sycl/fattn.hpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index eefdd9725ca..7b07b227874 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -25,6 +25,11 @@ ggml_add_backend_library(ggml-sycl file(GLOB GGML_HEADERS_SYCL "*.hpp") file(GLOB GGML_SOURCES_SYCL "*.cpp") +file(GLOB SRCS "template-instances/fattn-tile*.cpp") +list(APPEND GGML_SOURCES_SYCL ${SRCS}) +file(GLOB SRCS "template-instances/fattn-vec*.cpp") +list(APPEND GGML_SOURCES_SYCL ${SRCS}) + target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL}) if (WIN32) @@ -145,6 +150,7 @@ else() endif() if (GGML_SYCL_GRAPH) + message(STATUS "find GGML_SYCL_GRAPH") target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH) endif() diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 75657f3fca2..b30b7f2beb7 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -23,6 +23,7 @@ #include "dequantize.hpp" #include "dmmv.hpp" #include "element_wise.hpp" +#include "fattn.hpp" #include "gla.hpp" #include "im2col.hpp" #include "mmq.hpp" diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 04c9e1d7864..298fddc1038 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -19,10 +19,13 @@ #include #include "dpct/helper.hpp" +#include "ggml.h" +#include "ggml-impl.h" #include "ggml-sycl.h" #include "presets.hpp" #include "sycl_hw.hpp" +namespace syclexp = sycl::ext::oneapi::experimental; #if GGML_SYCL_DNNL #include "dnnl.hpp" @@ -31,6 +34,9 @@ #define GGML_COMMON_DECL_SYCL #define GGML_COMMON_IMPL_SYCL +#define SYCL_FLASH_ATTN //remove it to disable FLASH_ATTENTION in building. +#define SYCL_FAST_FP16 //don't change. remove it will break fattn-tile.hpp building + /* suppress warning spam */ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wnested-anon-types" @@ -45,6 +51,8 @@ void ggml_sycl_host_free(void* ptr); extern int g_ggml_sycl_debug; extern int g_ggml_sycl_disable_optimize; extern int g_ggml_sycl_prioritize_dmmv; +extern int g_ggml_sycl_enable_flash_attention; + #if defined(__clang__) && __has_builtin(__builtin_expect) // Hint the optimizer to pipeline the more likely following instruction in branches @@ -170,6 +178,10 @@ static size_t g_scratch_offset = 0; int get_current_device_id(); +inline int ggml_sycl_get_device() { + return get_current_device_id(); +} + inline dpct::err0 ggml_sycl_set_device(const int device) try { int current_device_id; SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id())); @@ -194,11 +206,14 @@ struct optimize_feature { }; struct sycl_device_info { - int cc; // compute capability + int cc; // compute capability int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum // number of compute units on a SYCL device. // size_t smpb; // max. shared memory per block size_t smpbo; // max. shared memory per block (with opt-in) + int warp_size; // max sub_group_size of SYCL + int max_wg_per_cu; // max work groups per compute unit - refer to + // cudaOccupancyMaxActiveBlocksPerMultiprocessor bool vmm; // virtual memory support size_t total_vram; //sycl_hw_info hw_info; \\ device id and aarch, currently not used @@ -435,13 +450,15 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) { return a; } -template +/* use WARP_SIZE or WARP_32_SIZE*/ +template static __dpct_inline__ int warp_reduce_sum(int x) { return sycl::reduce_over_group( sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>()); } -template +/* use WARP_SIZE or WARP_32_SIZE*/ +template static __dpct_inline__ float warp_reduce_sum(float x) { #pragma unroll for (int offset = width / 2; offset > 0; offset >>= 1) { @@ -451,7 +468,19 @@ static __dpct_inline__ float warp_reduce_sum(float x) { return x; } -template +/* use WARP_SIZE or WARP_32_SIZE*/ +template +static __dpct_inline__ float warp_reduce_sum(float x, const sycl::nd_item<3>& item_ct1) { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + x += dpct::permute_sub_group_by_xor( + item_ct1.get_sub_group(), x, offset); + } + return x; +} + +/* use WARP_SIZE or WARP_32_SIZE*/ +template static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) { #pragma unroll for (int offset = width / 2; offset > 0; offset >>= 1) { @@ -465,7 +494,8 @@ static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) { return a; } -template +/* use WARP_SIZE or WARP_32_SIZE*/ +template static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) { #pragma unroll for (int offset = width / 2; offset > 0; offset >>= 1) { @@ -481,7 +511,52 @@ static constexpr int ggml_sycl_get_physical_warp_size() { return WARP_SIZE; } -template +/* use WARP_SIZE or WARP_32_SIZE*/ +template +static __dpct_inline__ int warp_reduce_all(int x) { + if (width == ggml_sycl_get_physical_warp_size()) { + return sycl::all_of_group( + sycl::ext::oneapi::this_work_item::get_sub_group(), + (~0xffffffff & + (0x1 << sycl::ext::oneapi::this_work_item::get_sub_group() + .get_local_linear_id())) || + x); + } else { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + x = dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), x, + offset, width) && + x; + } + return x; + } +} + +/* use WARP_SIZE or WARP_32_SIZE*/ +template +static __dpct_inline__ int warp_reduce_any(int x) { + if (width == ggml_sycl_get_physical_warp_size()) { + return sycl::any_of_group( + sycl::ext::oneapi::this_work_item::get_sub_group(), + (0xffffffff & + (0x1 << sycl::ext::oneapi::this_work_item::get_sub_group() + .get_local_linear_id())) && + x); + } else { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + x = dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), x, + offset, width) || + x; + } + return x; + } +} + +/* use WARP_SIZE or WARP_32_SIZE*/ +template static __dpct_inline__ float warp_reduce_max(float x) { #pragma unroll for (int offset = width / 2; offset > 0; offset >>= 1) { @@ -629,6 +704,42 @@ static const sycl::uint3 init_fastdiv_values(uint32_t d) { return sycl::uint3(mp, L, d); } +// Maximum number of bytes that can be copied in a single instruction. +// Set by test result. +static constexpr int ggml_sycl_get_max_cpy_bytes() { + return 16; +} + +// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes. +template +static __dpct_inline__ void ggml_sycl_memcpy_1(void * dst, const void * src) { + if constexpr (alignment != 0) { + static_assert(nbytes % alignment == 0, "bad alignment"); + } + constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment; + +#pragma unroll + for (int i = 0; i < nbytes/nb_per_cpy; ++i) { + if constexpr (nb_per_cpy == 1) { + ((char *) dst)[i] = ((const char *) src)[i]; + } else if constexpr (nb_per_cpy == 2) { + ((short *) dst)[i] = ((const short *) src)[i]; + } else if constexpr (nb_per_cpy == 4) { + ((int *) dst)[i] = ((const int *) src)[i]; + } else if constexpr (nb_per_cpy == 8) { + ((sycl::int2 *) dst)[i] = ((const sycl::int2 *) src)[i]; + } else if constexpr (nb_per_cpy == 16) { + ((sycl::int4 *) dst)[i] = ((const sycl::int4 *) src)[i]; + } else { + static_assert(nbytes == 0 && nbytes == -1, "bad nbytes"); + } + } +} +template +sycl::half2 __dpct_inline__ make_half2( T x, T y) { + sycl::half2 res(static_cast(x),static_cast(y)); + return res; +} static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) { const uint32_t hi = sycl::mul_hi(n, fastdiv_values.x()); @@ -636,6 +747,17 @@ static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_va } +template +sycl::float2 __dpct_inline__ make_float2( T x, T y) { + sycl::float2 res(static_cast(x),static_cast(y)); + return res; +} + +sycl::float2 __dpct_inline__ __half22float2(sycl::half2 &H) { + sycl::float2 float2_value(static_cast(H.x()), static_cast(H.y())); + return float2_value; +} + static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) { const uint32_t div_val = fastdiv(n, fastdiv_values); const uint32_t mod_val = n - div_val * fastdiv_values.z(); @@ -659,5 +781,97 @@ static __dpct_inline__ float ggml_sycl_e8m0_to_fp32(uint8_t x) { return result; } +sycl::float2 __dpct_inline__ __half22float2(const sycl::half2 &H) { + sycl::float2 float2_value(static_cast(H.x()), static_cast(H.y())); + return float2_value; +} + +float __dpct_inline__ __half2float(sycl::half H) { + return static_cast(H); +} + +static __dpct_inline__ void ggml_sycl_mad(float & acc, const float v, const float u) { + acc += v*u; +} + +static __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::float2 v, const sycl::float2 u) { + acc += v.x() * u.x(); + acc += v.y() * u.y(); +} + +static __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::half2 v, const sycl::half2 u) { +#ifdef GGML_SYCL_F16 + const sycl::float2 tmp = (v * u).template convert(); + acc += tmp.x() + tmp.y(); +#else + const sycl::float2 tmpv = __half22float2(v); + const sycl::float2 tmpu = __half22float2(u); + acc += tmpv.x() * tmpu.x(); + acc += tmpv.y() * tmpu.y(); +#endif // GGML_SYCL_F16 +} + +static __dpct_inline__ void ggml_sycl_mad(sycl::half2 & acc, const sycl::half2 v, const sycl::half2 u) { +#ifdef GGML_SYCL_F16 + acc += v*u; +#else + const sycl::float2 tmpv = __half22float2(v); + const sycl::float2 tmpu = __half22float2(u); + sycl::float2 tmpacc = __half22float2(acc); + // tmpacc.x += tmpv.x() * tmpu.x(); + // tmpacc.y += tmpv.y() * tmpu.y(); + sycl::float2 tmp1(tmpacc.x() + tmpv.x() * tmpu.x(), tmpacc.y() + tmpv.y() * tmpu.y()); + acc = make_half2(tmp1.x(), tmp1.y()); +#endif // GGML_SYCL_F16 +} + +template +struct ggml_sycl_unroll { + template + void operator()(const Func & f, Args... args) const { + f(n - 1, args...); + ggml_sycl_unroll{}(f, args...); + } +}; + +template <> +struct ggml_sycl_unroll<1> { + template + void operator()(const Func & f, Args... args) const { + f(0, args...); + } +}; + +static __dpct_inline__ sycl::half2 ggml_sycl_hmax2(const sycl::half2 a, const sycl::half2 b) { + sycl::half2 ret; + reinterpret_cast(ret.x()) = + sycl::vec(sycl::fmax(a[0], b[0])).convert()[0]; + reinterpret_cast(ret.y()) = + sycl::vec(sycl::fmax(a[1], b[1])).convert()[0]; + return ret; +} + +static __dpct_inline__ sycl::half ggml_sycl_hmax(const sycl::half a, const sycl::half b) { + return sycl::vec( + sycl::fmax(sycl::vec(a).convert()[0], + sycl::vec(b).convert()[0])) + .convert()[0]; +} + +static __dpct_inline__ uint32_t __hgt2_mask(const sycl::half2 a, const sycl::half2 b) { + const uint32_t mask_low = 0x0000FFFF * (float(a[0]) > float(b[0])); + const uint32_t mask_high = 0xFFFF0000 * (float(a[1]) > float(b[1])); + return mask_low | mask_high; +} + +static __dpct_inline__ uint32_t fastmodulo(uint32_t n, const sycl::uint3 fastdiv_values) { + // expects fastdiv_values to contain in (see init_fastdiv_values) + return n - fastdiv(n, fastdiv_values) * fastdiv_values.z(); +} + +static bool fast_fp16_available(const int cc) { + GGML_UNUSED(cc); + return true; //Intel GPUs always support FP16. +} #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 8bdae36458c..d17aca2cac4 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -482,6 +482,63 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t }); } +template +static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t s01, const int64_t s02, const int64_t s03) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i00 = 2 * (int64_t(item_ct1.get_local_range(2)) * item_ct1.get_group(2) + item_ct1.get_local_id(2)); + + if (i00 >= ne00) { + return; + } + + const int64_t i01 = item_ct1.get_group(1); + const int64_t i02 = item_ct1.get_group(0) % ne02; + const int64_t i03 = item_ct1.get_group(0) / ne02; + + const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; + + const int64_t ib = ibx0 + i00/qk; // block index + const int64_t iqs = (i00%qk)/qr; // quant index + const int64_t iybs = i00 - i00%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + #ifdef GGML_SYCL_F16 + sycl::half2 v; + #else + sycl::float2 v; + #endif + + dequantize_kernel(vx, ib, iqs, v); + + const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs; + y[iy0 + 0] = ggml_sycl_cast(v.x()); + y[iy0 + y_offset] = ggml_sycl_cast(v.y()); +} + + +template +static void dequantize_block_nc_sycl(const void * vx, + dst_t * y, + const int64_t ne00, + const int64_t ne01, + const int64_t ne02, + const int64_t ne03, + const int64_t s01, + const int64_t s02, + const int64_t s03, + dpct::queue_ptr stream) { + const dpct::dim3 num_blocks((ne00 + 2 * SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2 * SYCL_DEQUANTIZE_BLOCK_SIZE), ne01, + ne02 * ne03); + stream->parallel_for(sycl::nd_range<3>(num_blocks * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + dequantize_block_nc(vx, y, ne00, ne01, ne02, s01, s02, s03); + }); +} template static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03, @@ -662,7 +719,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { } } -to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) { + +to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_nc_sycl; @@ -670,6 +728,16 @@ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) { case GGML_TYPE_BF16: return convert_unary_nc_sycl; #endif + case GGML_TYPE_Q4_0: + return dequantize_block_nc_sycl; + case GGML_TYPE_Q4_1: + return dequantize_block_nc_sycl; + case GGML_TYPE_Q5_0: + return dequantize_block_nc_sycl; + case GGML_TYPE_Q5_1: + return dequantize_block_nc_sycl; + case GGML_TYPE_Q8_0: + return dequantize_block_nc_sycl; default: return nullptr; } diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp index f8cb573e368..f93bd0df7d7 100644 --- a/ggml/src/ggml-sycl/convert.hpp +++ b/ggml/src/ggml-sycl/convert.hpp @@ -29,6 +29,21 @@ using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne0 int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue); typedef to_t_nc_sycl_t to_fp16_nc_sycl_t; -to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type); +to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type); + +template + inline dst_t ggml_sycl_cast(src_t x) { + if constexpr (std::is_same_v) { + return x; + } else if constexpr (std::is_same_v) { + return sycl::ext::oneapi::bfloat16(float(x)); + } else if constexpr (std::is_same_v) { + return static_cast(x); + } else if constexpr(std::is_same_v) { + return int32_t(x); + } else { + return float(x); + } +} #endif // GGML_SYCL_CONVERT_HPP diff --git a/ggml/src/ggml-sycl/count-equal.cpp b/ggml/src/ggml-sycl/count-equal.cpp index b0a8b4820de..4580354cd9d 100644 --- a/ggml/src/ggml-sycl/count-equal.cpp +++ b/ggml/src/ggml-sycl/count-equal.cpp @@ -18,7 +18,7 @@ static void count_equal(const T *__restrict__ x, const T *__restrict__ y, nequal += xi == yi; } - nequal = warp_reduce_sum(nequal); + nequal = warp_reduce_sum(nequal); if (item_ct1.get_local_id(2) != 0) { return; diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index ece66a7ac1f..791d3cac52e 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -2997,6 +2997,778 @@ namespace dpct return 0; } + template + class args_selector; + + /// args_selector is a helper class for extracting arguments from an + /// array of pointers to arguments or buffer of arguments to pass to a + /// kernel function. + /// + /// \param R(Ts...) The type of the kernel + /// \param n_nondefault_params The number of nondefault parameters of the + /// kernel (excluding parameters that like sycl::nd_item, etc.) \param + /// n_default_params The number of default parameters of the kernel + /// + /// Example usage: + /// With the following kernel: + /// void foo(sycl::float2 *x, int n, sycl::nd_item<3> item_ct1, float + /// f=.1) {} + /// and with the declaration: + /// args_selector<2, 1, decltype(foo)> selector(kernelParams, extra); + /// we have: + /// selector.get<0>() returns a reference to sycl::float*, + /// selector.get<1>() returns a reference to int, + /// selector.get<2>() returns a reference to float + template + class args_selector { + private: + void **kernel_params; + char *args_buffer; + + template static constexpr int account_for_default_params() { + constexpr int n_total_params = sizeof...(Ts); + if constexpr (i >= n_nondefault_params) { + return n_total_params - n_default_params + + (i - n_nondefault_params); + } else { + return i; + } + } + + public: + /// Get the type of the ith argument of R(Ts...) + /// \param [in] i Index of parameter to get + /// \returns Type of ith parameter + template + using arg_type = std::tuple_element_t(), + std::tuple>; + static constexpr int params_num = sizeof...(Ts); + + private: + template static constexpr int get_offset() { + if constexpr (i == 0) { + // we can assume args_buffer is properly aligned to the + // first argument + return 0; + } else { + constexpr int prev_off = get_offset(); + constexpr int prev_past_end = + prev_off + sizeof(arg_type); + using T = arg_type; + // is the past-the-end of the i-1st element properly aligned + // with the ith element's alignment? + if constexpr (prev_past_end % alignof(T) == 0) { + return prev_past_end; + } + // otherwise bump prev_past_end to match alignment + else { + return prev_past_end + + (alignof(T) - (prev_past_end % alignof(T))); + } + } + } + + static char *get_args_buffer(void **extra) { + if (!extra) + return nullptr; + for (; (std::size_t)*extra != 0; ++extra) { + if ((std::size_t)*extra == 1) { + return static_cast(*(extra + 1)); + } + } + return nullptr; + } + + public: + /// If kernel_params is nonnull, then args_selector will + /// extract arguments from kernel_params. Otherwise, it + /// will extract them from extra. + /// \param [in] kernel_params Array of pointers to arguments + /// a or null pointer. + /// \param [in] extra Array containing pointer to argument buffer. + args_selector(void **kernel_params, void **extra) + : kernel_params(kernel_params), + args_buffer(get_args_buffer(extra)) {} + + /// Get a reference to the ith argument extracted from kernel_params + /// or extra. + /// \param [in] i Index of argument to get + /// \returns Reference to the ith argument + template arg_type &get() { + if (kernel_params) { + return *static_cast *>(kernel_params[i]); + } else { + return *reinterpret_cast *>(args_buffer + + get_offset()); + } + } + }; // COPY from DPCT head file + // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp + + /// Utility class for launching SYCL kernels through kernel + /// function wrapper. + /// For example: + /// A SYCL kernel function: + /// void kernel_func(int *ptr, sycl::nd_item<3> item); + /// Kernel function wrapper: + /// void kernel_func_wrapper(int *ptr) { + /// sycl::queue queue = *dpct::kernel_launcher::_que; + /// unsigned int localMemSize = dpct::kernel_launcher::_local_mem_size; + /// sycl::nd_range<3> nr = dpct::kernel_launcher::_nr; + /// queue.parallel_for( + /// nr, + /// [=](sycl::nd_item<3> item_ct1) { + /// kernel_func(ptr, item_ct1); + /// }); + /// } + /// Then launch the kernel through wrapper like: + /// typedef void(*fpt)(int *); + /// fpt fp = kernel_func_wrapper; + /// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0, + /// device_ptr); + /// If the origin function type is erased, then need to register it first: + /// void *fp = (void *)wrapper_register(&kernel_func_wrapper).get(); + /// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), args, + /// 0, 0); + class kernel_launcher { + template + static void launch_helper(FuncT &&func, ArgSelector &selector, + std::index_sequence) { + func(selector.template get()...); + } + static void set_execution_config(dim3 group_range, dim3 local_range, + unsigned int local_mem_size, + queue_ptr que) { + if (que) { + _que = que; + } else { + _que = &get_default_queue(); + } + _nr = sycl::nd_range<3>( + static_cast>(group_range * local_range), + static_cast>(local_range)); + _local_mem_size = local_mem_size; + + + }; + static inline std::mutex kernel_function_ptr_map_mutex; + + public: + /// Variables for storing execution configuration. + static inline thread_local sycl::queue *_que = nullptr; + static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>(); + static inline thread_local unsigned int _local_mem_size = 0; + /// Map for retrieving launchable functor from a raw pointer. + static inline std::map< + const void *, + std::function> + kernel_function_ptr_map = {}; + + /// Registers a kernel function pointer with a corresponding launchable + /// functor. + /// \param [in] func Pointer to the kernel function. + /// \param [in] launcher Functor to handle kernel invocation. + static void register_kernel_ptr( + const void *func, + std::function + launcher) { + std::lock_guard lock(kernel_function_ptr_map_mutex); + kernel_function_ptr_map[func] = std::move(launcher); + } + /// Launches a kernel function with arguments provided directly through + /// kernel function wrapper. + /// \tparam FuncT Type of the kernel function wrapper. + /// \tparam ArgsT Types of kernel arguments. + /// \param [in] func Pointer to the kernel function wrapper. + /// \param [in] group_range SYCL group range. + /// \param [in] local_range SYCL local range. + /// \param [in] local_mem_size The size of local memory required by the + /// kernel function. \param [in] que SYCL queue used to execute kernel. + /// \param [in] args Kernel arguments. + template + static std::enable_if_t, void> + launch(FuncT *func, dim3 group_range, dim3 local_range, + unsigned int local_mem_size, queue_ptr que, ArgsT... args) { + set_execution_config(group_range, local_range, local_mem_size, que); + func(args...); + } + /// Launches a kernel function through registered kernel function + /// wrapper. \param [in] func Pointer to the registered kernel function + /// wrapper. \param [in] group_range SYCL group range. \param [in] + /// local_range SYCL local range. \param [in] args Array of pointers to + /// kernel arguments. \param [in] local_mem_size The size of local + /// memory required by the kernel function. \param [in] que SYCL queue + /// used to execute kernel. + static void launch(const void *func, dim3 group_range, dim3 local_range, + void **args, unsigned int local_mem_size, + queue_ptr que) { + std::lock_guard lock(kernel_function_ptr_map_mutex); + auto Iter = kernel_function_ptr_map.find(func); + if (Iter == kernel_function_ptr_map.end()) { + throw std::runtime_error("dpct::launch() : no registered " + "kernel function wrapper found."); + } + (Iter->second)(group_range, local_range, args, local_mem_size, que); + } + /// Launches a kernel function with packed arguments through kernel + /// function wrapper. + /// \tparam FuncT Type of the kernel function wrapper. + /// \param [in] func Pointer to the kernel function wrapper. + /// \param [in] group_range SYCL group range. + /// \param [in] local_range SYCL local range. + /// \param [in] args Array of pointers to kernel arguments. + /// \param [in] local_mem_size The size of local memory required by the + /// kernel function. \param [in] que SYCL queue used to execute kernel. + template + static std::enable_if_t, void> + launch(FuncT *func, dim3 group_range, dim3 local_range, void **args, + unsigned int local_mem_size, queue_ptr que) { + constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num; + set_execution_config(group_range, local_range, local_mem_size, que); + args_selector selector(args, nullptr); + launch_helper(func, selector, std::make_index_sequence{}); + } + }; // COPY from DPCT head file + // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/kernel.hpp + + // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp + template + T select_from_sub_group( + sycl::sub_group g, + T x, + int remote_local_id, + int logical_sub_group_size = 32) { + unsigned int start_index = g.get_local_linear_id() / + logical_sub_group_size * + logical_sub_group_size; + return sycl::select_from_group( + g, x, start_index + remote_local_id % logical_sub_group_size); + } + + // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp + template + void ldmatrix(uintptr_t addr, T* m, bool trans = false, unsigned mat = 0) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + int lane_group8_row = lane / 8; + int lane_group8_col = lane % 8; + + if (!trans) { + // calculate the source lane + int src_lane = 2 * lane_group8_row; + if (lane_group8_col >= 4) + src_lane += 1; + + // Broadcast the address from the source lane + auto recv_addr_uintp = + dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane); + + // Cast the received address from uintptr_t to the type of 'm' + auto recv_addr = reinterpret_cast(recv_addr_uintp); + + // Non-transposed load + *m = recv_addr[lane_group8_col % 4]; + } else { + // calculate the source lane + int src_lane = (lane % 4) * 2; + + // Broadcast the address from the source lane + auto recv_addr_uintp_1 = + dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane); + auto recv_addr_uintp_2 = + dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1); + + // Cast the received address from uintptr_t to 'half *' + auto recv_addr_1 = reinterpret_cast(recv_addr_uintp_1); + auto recv_addr_2 = reinterpret_cast(recv_addr_uintp_2); + + // Transposed load + int index = lane / 4; + sycl::half val0 = recv_addr_1[index]; + sycl::half val1 = recv_addr_2[index]; + + // Combine the two 16-bits into one 32-bit value + sycl::half2 val = sycl::half2(val0, val1); + *m = *reinterpret_cast(&val); + } + } + + template + void ldmatrix(uintptr_t addr, T* m1, T* m2, bool trans = false) { + // Load 1st matrix + ldmatrix(addr, m1, trans, 0); + // Load 2nd matrix + ldmatrix(addr, m2, trans, 1); + } + + template + void ldmatrix( + uintptr_t addr, T* m1, T* m2, T* m3, T* m4, bool trans = false) { + // Load 1st matrix + ldmatrix(addr, m1, trans, 0); + // Load 2nd matrix + ldmatrix(addr, m2, trans, 1); + // Load 3rd matrix + ldmatrix(addr, m3, trans, 2); + // Load 4th matrix + ldmatrix(addr, m4, trans, 3); + } + + // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp + + /// A helper struct that defines the pack type for the input matrix + /// fragments + /// of mma() function based on the type of input matrix fragments. + /// The MMAType struct is specialized for different types of input matrices. + /// Currently, the specialization for f16, bf16 and s8 types is defined + /// below. \tparam [in] T The type of the input matrix fragments + template + struct MMAType { + using PackType = uint32_t; + }; + + /// Each work item of a sub-group (limited to size 32) calling this function + /// calculates a subset fragment for the output matrix D using MAD operation + /// on A, B & C matrix fragments (D = A * B + C). Current supported shapes & + /// types: + /// - m8n8k4 (f32.f16.f16.f32) + /// - m8n8k16 (s32.s8.s8.s32) + /// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32) + /// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32) + /// - m16n8k32 (s32.s8.s8.s32) + /// Here, m, n & k define the shapes of A, B & C matrices respectively + /// (A = [m x k], B = [k x n], C = [m x n]). + /// \tparam [in] M The rows of A, C & D matrices + /// \tparam [in] N The columns of B, C, D matrices + /// \tparam [in] K The columns & rows of A & B matrices respectively + /// \tparam [in] ABType The type of the input matrix (A & B) fragment + /// \tparam [in] CDType The type of the output matrix (C & D) fragment + /// \param [out] d_mat_frag The fragment of the output matrix D to store the + /// result of A * B + C + /// \param [in] a_mat_frag The fragment of the input matrix A to be + /// multiplied with B matrix fragment \param [in] b_mat_frag The fragment of + /// the input matrix B to be multiplied with A matrix fragment \param [in] + /// c_mat_frag The fragment of the input matrix C to be added with the + /// result of A * B fragments + template + void mma( + volatile void** d_mat_frag, + void* a_mat_frag, + void* b_mat_frag, + void* c_mat_frag) { + auto d = reinterpret_cast(d_mat_frag); + auto a = + reinterpret_cast::PackType*>(a_mat_frag); + auto b = + reinterpret_cast::PackType*>(b_mat_frag); + auto c = reinterpret_cast(c_mat_frag); + + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + static_assert( + (M == 8 && N == 8 && K == 4) || (M == 8 && N == 8 && K == 16) || + (M == 16 && N == 8 && K == 8) || (M == 16 && N == 8 && K == 16) || + (M == 16 && N == 8 && K == 32), + "Unsupported MMA shape!"); + + short row_load_offset = 4 * (lane >> 2); + short col_load_offset = 8 * (lane % 4); + + if constexpr (M == 8 && N == 8 && K == 4) { + if constexpr (std::is_floating_point_v) { + col_load_offset = row_load_offset % 16; + + // Init D matrix with fragments of C matrix + *d[0] = c[0]; + *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; + *d[4] = c[4]; + *d[5] = c[5]; + *d[6] = c[6]; + *d[7] = c[7]; + + // Calculate the row and col offset indices to iterate through the row + // & col fragments of A & B matrices + int r_ind = (lane % 2) ? 1 : 0; + int c_ind = ((lane % 4) / 2) ? 2 : 0; + + // Each sub-group is responsible for computing a fragment size of 8*8 + // elements of matrix D for each of 4 MMA computations. + // Each work item computes 8 elements of matrix D by gathering + // their corresponding col & row matrix fragments of length k (4) + // from A & B matrices respectively using below mapping logic: + // row0 = (i % 4) if (lane < 16) else (i % 4) + 4 + // col0 = (lane % 4) + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment + // of matrix A (row) and matrix B (col) using the row & col offsets. + typename MMAType::PackType recv_a[2], recv_b[2]; + + for (int i = 0; i < 4; i++) { + // Load partial fragment from col0 of matrix A ({a0, a1}) + recv_a[0] = + dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from col0 of matrix A ({a2, a3}) + recv_a[1] = + dpct::select_from_sub_group(sg, a[1], row_load_offset + i); + + // Load partial fragment from row0 of matrix B ({b0, b1}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from row0 of matrix B ({b2, b3}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[1], col_load_offset + i); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment (for + // even work item indices) d0 += col0{ a0 } * row0{ b0 } d1 += col0{ + // a0 } * row0{ b1 } d2 += col1{ a2 } * row0{ b0 } d3 += col1{ a2 } + // * row0{ b1 } (for odd work item indices) d0 += col0{ a1 } * row0{ + // b2 } d1 += col0{ a1 } * row0{ b3 } d2 += col1{ a3 } * row0{ b2 } + // d3 += col1{ a3 } * row0{ b3 } + *d[0] += + static_cast(ra[r_ind]) * static_cast(rb[c_ind]); + *d[1] += static_cast(ra[r_ind]) * + static_cast(rb[c_ind + 1]); + *d[2] += static_cast(ra[r_ind + 2]) * + static_cast(rb[c_ind]); + *d[3] += static_cast(ra[r_ind + 2]) * + static_cast(rb[c_ind + 1]); + + // Load partial fragment from row1 of matrix B ({b0, b1}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 16); + // Load partial fragment from row1 of matrix B ({b2, b3}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 16); + + // (for even work item indices) + // d0 += col0{ a0 } * row1{ b0 } + // d1 += col0{ a0 } * row1{ b1 } + // d2 += col1{ a2 } * row1{ b0 } + // d3 += col1{ a2 } * row1{ b1 } + // (for odd work item indices) + // d0 += col0{ a1 } * row1{ b2 } + // d1 += col0{ a1 } * row1{ b3 } + // d2 += col1{ a3 } * row1{ b2 } + // d3 += col1{ a3 } * row1{ b3 } + *d[4] += + static_cast(ra[r_ind]) * static_cast(rb[c_ind]); + *d[5] += static_cast(ra[r_ind]) * + static_cast(rb[c_ind + 1]); + *d[6] += static_cast(ra[r_ind + 2]) * + static_cast(rb[c_ind]); + *d[7] += static_cast(ra[r_ind + 2]) * + static_cast(rb[c_ind + 1]); + } + } + } else if constexpr (M == 8 && N == 8 && K == 16) { + if constexpr (std::is_integral_v) { + // Init D matrix with fragments of C matrix + *d[0] = c[0]; + *d[1] = c[1]; + + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 2 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (16) + // from A & B matrices respectively using below mapping logic: + // row0 = ((lane % 4) * 4) + i + // col0 = (lane >> 2) + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment + // of matrix A (row) and matrix B (col) using the row & col offsets. + for (int i = 0; i < 4; i++) { + typename MMAType::PackType recv_a, recv_b[2]; + + // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3}) + recv_a = dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4); + + auto a = reinterpret_cast(&recv_a); + auto b = reinterpret_cast(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment d0 + // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{ + // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row0{ a0, a1, a2, + // a3 } * col0{ b0, b1, b2, b3 } d3 += row0{ a0, a1, a2, a3 } * + // col1{ b0, b1, b2, b3 } + for (int j = 0; j < 4; j++) { + *d[0] += a[j] * b[j]; + *d[1] += a[j] * b[j + 4]; + } + } + } + } else if constexpr (M == 16 && N == 8 && K == 8) { + if constexpr (std::is_floating_point_v) { + // Init D matrix fragment with C matrix fragment + *d[0] = c[0]; + *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; + + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 4 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (8) + // from A & B matrices respectively using below mapping logic: + // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 + // col0 = (lane % 4) * 2 + (i & 0x1) + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment + // of matrix A (row) and matrix B (col) using the row & col offsets. + for (int i = 0; i < 4; i++) { + typename MMAType::PackType recv_a[2], recv_b[2]; + + // Load partial fragment from row0 of matrix A ({a0, a1}) + recv_a[0] = + dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a2, a3}) + recv_a[1] = + dpct::select_from_sub_group(sg, a[1], row_load_offset + i); + // Load partial fragment from col0 of matrix B ({b0, b1}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b0, b1}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment d0 + // += row0{ a0, a1 } * col0{ b0, b1 } d1 += row0{ a0, a1 } * col1{ + // b0, b1 } d2 += row1{ a2, a3 } * col0{ b0, b1 } d3 += row1{ a2, a3 + // } * col1{ b0, b1 } + for (int j = 0; j < 2; j++) { + *d[0] += static_cast(ra[j]) * static_cast(rb[j]); + *d[1] += + static_cast(ra[j]) * static_cast(rb[j + 2]); + *d[2] += + static_cast(ra[j + 2]) * static_cast(rb[j]); + *d[3] += + static_cast(ra[j + 2]) * static_cast(rb[j + 2]); + } + } + } + } else if constexpr (M == 16 && N == 8 && K == 16) { + if constexpr (std::is_floating_point_v) { + // Init D matrix fragment with C matrix fragment + *d[0] = c[0]; + *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; + + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 4 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (8) + // from A & B matrices respectively using below mapping logic: + // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 + // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1 + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment + // of matrix A (row) and matrix B (col) using the row & col offsets. + for (int i = 0; i < 4; i++) { + typename MMAType::PackType recv_a[4], recv_b[4]; + + // Load partial fragment from row0 of matrix A ({a0, a1}) + recv_a[0] = + dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from row0 of matrix A ({a2, a3}) + recv_a[1] = + dpct::select_from_sub_group(sg, a[2], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a0, a1}) + recv_a[2] = + dpct::select_from_sub_group(sg, a[1], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a2, a3}) + recv_a[3] = + dpct::select_from_sub_group(sg, a[3], row_load_offset + i); + + // Load partial fragment from col0 of matrix B ({b0, b1}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from col0 of matrix B ({b2, b3}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[1], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b0, b1}) + recv_b[2] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i); + // Load partial fragment from col1 of matrix B ({b2, b3}) + recv_b[3] = + dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment d0 + // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{ + // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a0, a1, a2, + // a3 } * col0{ b0, b1, b2, b3 } d3 += row1{ a0, a1, a2, a3 } * + // col1{ b0, b1, b2, b3 } + for (int j = 0; j < 4; j++) { + *d[0] += static_cast(ra[j]) * static_cast(rb[j]); + *d[1] += + static_cast(ra[j]) * static_cast(rb[j + 4]); + *d[2] += + static_cast(ra[j + 4]) * static_cast(rb[j]); + *d[3] += static_cast(ra[j + 4]) * + static_cast(rb[j + 4]); + } + } + } else if constexpr (std::is_integral_v) { + // Init D matrix with fragments of C matrix + *d[0] = c[0]; + *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; + + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 4 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (8) + // from A & B matrices respectively using below mapping logic: + // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 + // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1 + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment + // of matrix A (row) and matrix B (col) using the row & col offsets. + for (int i = 0; i < 4; i++) { + typename MMAType::PackType recv_a[2], recv_b[2]; + + // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3}) + recv_a[0] = + dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7}) + recv_a[1] = + dpct::select_from_sub_group(sg, a[1], row_load_offset + i); + // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment d0 + // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{ + // a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 } d2 += row1{ a4, a5, a6, + // a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } * + // col1{ b4, b5, b6, b7 } + for (int i = 0; i < 4; i++) { + *d[0] += ra[i] * rb[i]; + *d[1] += ra[i] * rb[i + 4]; + *d[2] += ra[i + 4] * rb[i]; + *d[3] += ra[i + 4] * rb[i + 4]; + } + } + } + } else if constexpr (M == 16 && N == 8 && K == 32) { + if constexpr (std::is_integral_v) { + // Init D matrix with fragments of C matrix + *d[0] = c[0]; + *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; + + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 4 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (32) + // from A & B matrices respectively using below mapping logic: + // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 + // col0 = ((lane % 4) * 4) + (i & 0x3) & col1 = ((lane % 4) * 4) + (i + // & 0x3) As each row & col fragment of A & B matrices is distributed + // across 4 work items, each iteration of below loop loads a partial + // fragment of matrix A (row) and matrix B (col) using the row & col + // offsets. + for (int i = 0; i < 4; i++) { + typename MMAType::PackType recv_a[2], recv_b[2]; + + // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3}) + recv_a[0] = + dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7}) + recv_a[1] = + dpct::select_from_sub_group(sg, a[1], row_load_offset + i); + // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4); + + auto a = reinterpret_cast(recv_a); + auto b = reinterpret_cast(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment d0 + // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{ + // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a4, a5, a6, + // a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } * + // col1{ b0, b1, b2, b3 } + for (int j = 0; j < 4; j++) { + *d[0] += a[j] * b[j]; + *d[1] += a[j] * b[j + 4]; + *d[2] += a[j + 4] * b[j]; + *d[3] += a[j + 4] * b[j + 4]; + } + } + + for (int i = 0; i < 4; i++) { + typename MMAType::PackType recv_a[2], recv_b[2]; + + // Load partial fragment from row0 of matrix A ({a8, a9, a10, a11}) + recv_a[0] = + dpct::select_from_sub_group(sg, a[2], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a12, a13, a14, + // a15}) + recv_a[1] = + dpct::select_from_sub_group(sg, a[3], row_load_offset + i); + // Load partial fragment from col0 of matrix B ({b4, b5, b6, b7}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[1], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 4); + + auto a = reinterpret_cast(recv_a); + auto b = reinterpret_cast(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment d0 + // += row0{ a8, a9, a10, a11 } * col0{ b4, b5, b6, b7 } d1 += row0{ + // a8, a9, a10, a11 } * col1{ b4, b5, b6, b7 } d2 += row1{ a12, a13, + // a14, a15 } * col0{ b4, b5, b6, b7 } d3 += row1{ a12, a13, a14, + // a15 } * col1{ b4, b5, b6, b7 } + for (int j = 0; j < 4; j++) { + *d[0] += a[j] * b[j]; + *d[1] += a[j] * b[j + 4]; + *d[2] += a[j + 4] * b[j]; + *d[3] += a[j + 4] * b[j + 4]; + } + } + } + } + } } // COPY from DPCT head files #endif // GGML_SYCL_DPCT_HELPER_HPP diff --git a/ggml/src/ggml-sycl/fattn-common.hpp b/ggml/src/ggml-sycl/fattn-common.hpp new file mode 100644 index 00000000000..ed00d03c3b6 --- /dev/null +++ b/ggml/src/ggml-sycl/fattn-common.hpp @@ -0,0 +1,1179 @@ +#pragma once + +#include +#include "dpct/helper.hpp" +#include "common.hpp" +#include "convert.hpp" +#include "vecdotq.hpp" + +#include "ggml.h" + +#include +#include +#include + + +#define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF sycl::half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. +#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f) + +typedef void (*fattn_kernel_t)( + const char* Q, + const char* K, + const char* V, + const char* mask, + const char* sinks, + const int* KV_max, + float* dst, + sycl::float2* dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, + const sycl::uint3 ne01, + const int32_t ne02, + const int32_t ne03, + const int32_t nb01, + const int32_t nb02, + const int32_t nb03, + const int32_t ne10, + const int32_t ne11, + const int32_t ne12, + const int32_t ne13, + const int32_t nb11, + const int32_t nb12, + const int64_t nb13, + const int32_t nb21, + const int32_t nb22, + const int64_t nb23, + const int32_t ne31, + const int32_t ne32, + const int32_t ne33, + const int32_t nb31, + const int32_t nb32, + const int64_t nb33); + +typedef float (*vec_dot_KQ_t)( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); + +template +static __dpct_inline__ float vec_dot_fattn_vec_KQ_f16(const char * __restrict__ K_c, + const void * __restrict__ Q_v, + const int * __restrict__ Q_q8, + const void * __restrict__ Q_ds_v) { + const sycl::half2 * K_h2 = (const sycl::half2 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + + constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { + sycl::half2 tmp[cpy_ne]; + ggml_sycl_memcpy_1( + tmp, + K_h2 + k_KQ_0 + (sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_id(2) % nthreads) * cpy_ne); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { +#ifdef GGML_SYCL_F16 + ggml_sycl_mad(sum, tmp[k_KQ_1] , ((const sycl::half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#else + ggml_sycl_mad(sum, __half22float2(tmp[k_KQ_1]), ((const sycl::float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#endif // GGML_SYCL_F16 + } + } + + return sum; +} + +template +static __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_0(const char * __restrict__ K_c, + const void * __restrict__ Q_v, + const int * __restrict__ Q_q8, + const void * __restrict__ Q_ds_v) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + + const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = + k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_0; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_sycl_memcpy_1(&v, K_q4_0[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_sycl_dp4a(v, u, 0); + + const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads]; + sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x() - (8/QI8_1)*Q_ds.y()); + } + + return sum; +} + +template +static __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_1(const char * __restrict__ K_c, + const void * __restrict__ Q_v, + const int * __restrict__ Q_q8, + const void * __restrict__ Q_ds_v) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = + k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_sycl_memcpy_1(&v, K_q4_1[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_sycl_dp4a(v, u, 0); + + const sycl::float2 K_dm = (K_q4_1[ib].dm).template convert(); + const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads]; + + sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1; + } + + return sum; +} + +template +static __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_0(const char * __restrict__ K_c, + const void * __restrict__ Q_v, + const int * __restrict__ Q_q8, + const void * __restrict__ Q_ds_v) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = + k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_0; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_sycl_memcpy_1(&v, K_q5_0[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + + { + int vh; + ggml_sycl_memcpy_1(&vh, K_q5_0[ib].qh); + vh >>= iqs8 * QI5_0; + + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + } + + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_sycl_dp4a(v, u, 0); + + const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads]; + + sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x() - (16/QI8_1)*Q_ds.y()); + } + + return sum; +} + +template +static __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_1(const char * __restrict__ K_c, + const void * __restrict__ Q_v, + const int * __restrict__ Q_q8, + const void * __restrict__ Q_ds_v) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = + k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_1; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_sycl_memcpy_1(&v, K_q5_1[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + + { + int vh; + ggml_sycl_memcpy_1(&vh, K_q5_1[ib].qh); + vh >>= iqs8 * QI5_0; + + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + } + + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_sycl_dp4a(v, u, 0); + + const sycl::float2 K_dm = (K_q5_1[ib].dm).template convert(); + const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads]; + + sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1; + } + + return sum; +} + +template +static __dpct_inline__ float vec_dot_fattn_vec_KQ_q8_0(const char * __restrict__ K_c, + const void * __restrict__ Q_v, + const int * __restrict__ Q_q8, + const void * __restrict__ Q_ds_v) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = + k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads); + + const int ib = k_KQ / QI8_0; + const int iqs = k_KQ % QI8_0; + + int v; + ggml_sycl_memcpy_1(&v, K_q8_0[ib].qs + 4*iqs); + + const sycl::float2 * Q_ds = (const sycl::float2 *) Q_ds_v; + const float Q_d = Q_ds[k_KQ_0 / nthreads].x(); + + sum += vec_dot_q8_0_q8_1_impl(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d); + } + + return sum; +} + +template +static __dpct_inline__ void quantize_q8_1_to_shared(const float * __restrict__ x, + const float scale, + int * __restrict__ yq32, + void * __restrict__ yds) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + + float vals[sizeof(int)] = { 0.0f }; +#pragma unroll + for (int l = 0; l < int(sizeof(int)); ++l) { + vals[l] = + (ni == warp_size || item_ct1.get_local_id(2) < ni) ? scale * x[4 * item_ct1.get_local_id(2) + l] : 0.0f; + } + + float amax = sycl::fabs(vals[0]); + float sum = vals[0]; +#pragma unroll + for (int l = 1; l < int(sizeof(int)); ++l) { + amax = sycl::fmax(amax, sycl::fabs(vals[l])); + sum += vals[l]; + } +#pragma unroll + for (int mask = QI8_1/2; mask > 0; mask >>= 1) { + amax = sycl::fmax( + amax, dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), amax, mask)); + sum += dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), sum, mask); + } + + const float d = amax / 127; + int q32 = 0; + int8_t * q8 = (int8_t *) &q32; + + if (d != 0.0f) { +#pragma unroll + for (int l = 0; l < int(sizeof(int)); ++l) { + q8[l] = sycl::round(vals[l] / d); + } + } + + yq32[item_ct1.get_local_id(2)] = q32; + if (item_ct1.get_local_id(2) % QI8_1 == 0 && (ni == warp_size || item_ct1.get_local_id(2) < ni)) { + if (std::is_same::value) { + ((sycl::half2 *) yds)[item_ct1.get_local_id(2)/QI8_1] = make_half2(d, sum); + } else { + ((sycl::float2 *) yds)[item_ct1.get_local_id(2)/QI8_1] = make_float2(d, sum); + } + } +} + +typedef void (*dequantize_V_t)(const void *, void *, const int64_t); + +template +static __dpct_inline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + if constexpr (std::is_same_v) { + ggml_sycl_memcpy_1(dst, (const sycl::half *) vx + i0); + } else if constexpr (std::is_same_v) { + static_assert(ne % 2 == 0, "bad ne"); + sycl::half2 tmp[ne / 2]; + ggml_sycl_memcpy_1(tmp, (const sycl::half *) vx + i0); + sycl::float2 * dst_f2 = (sycl::float2 *) dst; +#pragma unroll + for (int l = 0; l < ne/2; ++l) { + dst_f2[l] = tmp[l].template convert(); + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } +} + +template +static __dpct_inline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q4_0 * x = (const block_q4_0 *) vx; + + const int64_t ib = i0 / QK4_0; + const int iqs = i0 % (QK4_0/2); + const int shift = (i0 % QK4_0) / (QK4_0/2); + + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_sycl_memcpy_1(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + q = dpct::vectorized_binary(q, 0x08080808, dpct::sub_sat()); + + const int8_t * q8 = (const int8_t *) &q; + +#ifdef GGML_SYCL_F16 + if constexpr (std::is_same_v) { + const sycl::half2 d = sycl::half2(x[ib].d); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]); + } + } else +#endif // GGML_SYCL_F16 + if constexpr (std::is_same_v) { + const float d = x[ib].d; + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * q8[l]; + } + } else { + static_assert(std::is_same_v, "bad type"); + } +} + +template +static __dpct_inline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q4_1 * x = (const block_q4_1 *) vx; + + const int64_t ib = i0 / QK4_1; + const int iqs = i0 % (QK4_1/2); + const int shift = (i0 % QK4_1) / (QK4_1/2); + + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_sycl_memcpy_1(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + const int8_t * q8 = (const int8_t *) &q; + +#ifdef GGML_SYCL_F16 + if constexpr (std::is_same_v) { + const sycl::half2 dm = x[ib].dm; + const sycl::half2 d = sycl::half2(dm[0]); + const sycl::half2 m = sycl::half2(dm[1]); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m; + } + } else +#endif // GGML_SYCL_F16 + if constexpr (std::is_same_v) { + const sycl::float2 dm = (x[ib].dm).template convert(); + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = dm.x() * q8[l] + dm.y(); + } + } else { + static_assert(std::is_same_v, "bad type"); + } +} + +template +static __dpct_inline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q5_0 * x = (const block_q5_0 *) vx; + + const int64_t ib = i0 / QK5_0; + const int idq = i0 % QK5_0; + const int iqs = i0 % (QK5_0/2); + const int shift = (i0 % QK5_0) / (QK5_0/2); + + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_sycl_memcpy_1(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + { + int qh; + ggml_sycl_memcpy_1(&qh, x[ib].qh); +#pragma unroll + for (int l = 0; l < ne; ++l) { + q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4); + } + } + + q = dpct::vectorized_binary(q, 0x10101010, dpct::sub_sat()); + + const int8_t * q8 = (const int8_t *) &q; + +#ifdef GGML_SYCL_F16 + if constexpr (std::is_same_v) { + const sycl::half2 d = sycl::half2(x[ib].d); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]); + } + } else +#endif // GGML_SYCL_F16 + if constexpr (std::is_same_v) { + const float d = x[ib].d; + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * q8[l]; + } + } else { + static_assert(std::is_same_v, "bad type"); + } +} + +template +static __dpct_inline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q5_1 * x = (const block_q5_1 *) vx; + + const int64_t ib = i0 / QK5_1; + const int idq = i0 % QK5_1; + const int iqs = i0 % (QK5_1/2); + const int shift = (i0 % QK5_1) / (QK5_1/2); + + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_sycl_memcpy_1(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + { + int qh; + ggml_sycl_memcpy_1(&qh, x[ib].qh); +#pragma unroll + for (int l = 0; l < ne; ++l) { + q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4); + } + } + + const int8_t * q8 = (const int8_t *) &q; + +#ifdef GGML_SYCL_F16 + if constexpr (std::is_same_v) { + const sycl::half2 dm = x[ib].dm; + const sycl::half2 d = sycl::half2(dm[0]); + const sycl::half2 m = sycl::half2(dm[1]); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m; + } + } else +#endif // GGML_SYCL_F16 + if constexpr (std::is_same_v) { + const sycl::float2 dm = (x[ib].dm).template convert(); + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = dm.x() * q8[l] + dm.y(); + } + } else { + static_assert(std::is_same_v, "bad type"); + } +} + +template +static __dpct_inline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q8_0 * x = (const block_q8_0 *) vx; + + const int64_t ib = i0 / QK8_0; + const int iqs = i0 % QK8_0; + + static_assert(ne % 2 == 0, "bad ne"); + int8_t qs[ne]; + ggml_sycl_memcpy_1(qs, x[ib].qs + iqs); + +#ifdef GGML_SYCL_F16 + if constexpr (std::is_same::value) { + const sycl::half2 d = sycl::half2(x[ib].d); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((sycl::half2 *) dst)[l0 / 2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]); + } + } else +#endif // GGML_SYCL_F16 + if constexpr (std::is_same::value) { + const float d = x[ib].d; + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * qs[l]; + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } +} + +template +constexpr vec_dot_KQ_t get_vec_dot_KQ() { + if constexpr (type_K == GGML_TYPE_F16) { + return vec_dot_fattn_vec_KQ_f16; + } else if constexpr (type_K == GGML_TYPE_Q4_0) { + return vec_dot_fattn_vec_KQ_q4_0; + } else if constexpr (type_K == GGML_TYPE_Q4_1) { + return vec_dot_fattn_vec_KQ_q4_1; + } else if constexpr (type_K == GGML_TYPE_Q5_0) { + return vec_dot_fattn_vec_KQ_q5_0; + } else if constexpr (type_K == GGML_TYPE_Q5_1) { + return vec_dot_fattn_vec_KQ_q5_1; + } else if constexpr (type_K == GGML_TYPE_Q8_0) { + return vec_dot_fattn_vec_KQ_q8_0; + } else { + static_assert(type_K == -1, "bad type"); + return nullptr; + } +} + +template +constexpr dequantize_V_t get_dequantize_V() { + if constexpr (type_V == GGML_TYPE_F16) { + return dequantize_V_f16; + } else if constexpr (type_V == GGML_TYPE_Q4_0) { + return dequantize_V_q4_0; + } else if constexpr (type_V == GGML_TYPE_Q4_1) { + return dequantize_V_q4_1; + } else if constexpr (type_V == GGML_TYPE_Q5_0) { + return dequantize_V_q5_0; + } else if constexpr (type_V == GGML_TYPE_Q5_1) { + return dequantize_V_q5_1; + } else if constexpr (type_V == GGML_TYPE_Q8_0) { + return dequantize_V_q8_0; + } else { + static_assert(type_V == -1, "bad type"); + return nullptr; + } +} + +template +static void flash_attn_mask_to_KV_max(const sycl::half2 * __restrict__ mask, + int * __restrict__ KV_max, + const int ne30, + const int s31, + const int s33, + int * buf_iw) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int ne31 = item_ct1.get_group_range(2); + const int tid = item_ct1.get_local_id(2); + const int sequence = item_ct1.get_group(1); + const int jt = item_ct1.get_group(2); + + mask += sequence*s33 + jt*ncols1*s31; + + if (tid < warp_size) { + buf_iw[tid] = 1; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + + int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE; + for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) { + int all_inf = 1; + +#pragma unroll + for (int j = 0; j < ncols1; ++j) { + const sycl::float2 tmp = + mask[j * s31 + KV_max_sj / 2 + tid].template convert(); + all_inf = all_inf && int(sycl::isinf((float) (tmp.x()))) && int(sycl::isinf((float) (tmp.y()))); + } + + all_inf = warp_reduce_all(all_inf); + if (tid % warp_size == 0) { + buf_iw[tid / warp_size] = all_inf; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + all_inf = buf_iw[tid % warp_size]; + item_ct1.barrier(sycl::access::fence_space::local_space); + all_inf = warp_reduce_all(all_inf); + + if (!all_inf) { + break; + } + } + + // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE. + // If the break was triggered it's the lower edge of the tile with the first non-masked values. + // In either case, walk back the decrementation by FATTN_KQ_STRIDE. + KV_max_sj += FATTN_KQ_STRIDE; + + if (item_ct1.get_local_id(2) != 0) { + return; + } + + KV_max[sequence*ne31 + jt] = KV_max_sj; +} + +template // D == head size + +static void flash_attn_stream_k_fixup(float * __restrict__ dst, + const sycl::float2 * __restrict__ dst_fixup, + const int ne01, + const int ne02, + const int ne03, + const int ne11, + const int ne12, + const int nbatch_fa) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + constexpr int ncols = ncols1 * ncols2; + + const int bidx0 = item_ct1.get_group(2); + const int j = item_ct1.get_group(1); + const int c = item_ct1.get_group(0); + const int jc = j*ncols2 + c; + const int tid = item_ct1.get_local_id(2); + + const float * dst_fixup_data = ((const float *) dst_fixup) + item_ct1.get_group_range(2) * (2 * 2 * ncols); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; + + const int kbc0 = int64_t(bidx0 + 0) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2); + const int kbc0_stop = + int64_t(bidx0 + 1) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2); + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; + const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; + + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. + + if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) { + return; + } + + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; + + // Load the partial result that needs a fixup: + float dst_val = 0.0f; + float max_val = 0.0f; + float rowsum = 0.0f; + { + dst_val = *dst; + + const sycl::float2 tmp = dst_fixup[bidx0 * ncols + jc]; + max_val = tmp.x(); + rowsum = tmp.y(); + } + + // Iterate over previous blocks and compute the combined results. + // All SYCL blocks that get here must have a previous block that needs a fixup. + int bidx = bidx0 - 1; + int kbc_stop = kbc0; + while(true) { + const int kbc = int64_t(bidx) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2); + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; + + const sycl::float2 tmp = dst_fixup[(item_ct1.get_group_range(2) + bidx) * ncols + jc]; + + // Scale the current and new value accumulators depending on the max. values. + const float max_val_new = sycl::fmax(max_val, tmp.x()); + + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x() - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_add) : 0.0f; + + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val * rowsum + scale_add * tmp.y(); + + max_val = max_val_new; + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + break; + } + bidx--; + kbc_stop = kbc; + } + + // Write back final result: + *dst = dst_val / rowsum; +} + +template // D == head size + +static void flash_attn_combine_results(const float * __restrict__ VKQ_parts, + const sycl::float2 * __restrict__ VKQ_meta, + float * __restrict__ dst, + const int parallel_blocks, + uint8_t * dpct_local) { + // Dimension 0: threadIdx.x + // Dimension 1: blockIdx.x + // Dimension 2: blockIdx.y + // Dimension 3: blockIdx.z + // Memory layout is permuted with [0, 2, 1, 3] + + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int ne01 = item_ct1.get_group_range(2); + const int ne02 = item_ct1.get_group_range(1); + + const int col = item_ct1.get_group(2); + const int head = item_ct1.get_group(1); + const int sequence = item_ct1.get_group(0); + + const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head; + + VKQ_parts += j_dst_unrolled * parallel_blocks*D; + VKQ_meta += j_dst_unrolled * parallel_blocks; + dst += j_dst_unrolled * D; + + const int tid = item_ct1.get_local_id(2); + __builtin_assume(tid < D); + + auto meta = (sycl::float2 *) dpct_local; + for (int i = tid; i < 2*parallel_blocks; i += D) { + ((float *) meta)[i] = ((const float *)VKQ_meta) [i]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + float kqmax = meta[0].x(); + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = sycl::max(kqmax, meta[l].x()); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; + for (int l = 0; l < parallel_blocks; ++l) { + const float KQ_max_scale = sycl::native::exp(meta[l].x() - kqmax); + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid]; + VKQ_denominator += KQ_max_scale * meta[l].y(); + } + + dst[tid] = VKQ_numerator / VKQ_denominator; +} + +template +static void lauch_kernel( + dpct::dim3 group_range, + dpct::dim3 local_range, + queue_ptr q, + unsigned int local_mem_size, + const char* __restrict__ Q, + const char* __restrict__ K, + const char* __restrict__ V, + const char* __restrict__ mask, + const char* __restrict__ sinks, + const int* __restrict__ KV_max, + float* __restrict__ dst, + sycl::float2* __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, + const sycl::uint3 ne01, + const int32_t ne02, + const int32_t ne03, + const int32_t nb01, + const int32_t nb02, + const int32_t nb03, + const int32_t ne10, + const int32_t ne11, + const int32_t ne12, + const int32_t ne13, + const int32_t nb11, + const int32_t nb12, + const int64_t nb13, + const int32_t nb21, + const int32_t nb22, + const int64_t nb23, + const int32_t ne31, + const int32_t ne32, + const int32_t ne33, + const int32_t nb31, + const int32_t nb32, + const int64_t nb33) { + GGML_UNUSED(local_mem_size); + q->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>( + static_cast>(group_range * local_range), + static_cast>(local_range)), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { + GGML_UNUSED(item_ct1); + fattn_kernel(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, ne00, + ne01, ne02, ne03, nb01, nb02, nb03, ne10, ne11, + ne12, ne13, nb11, nb12, nb13, nb21, nb22, nb23, + ne31, ne32, ne33, nb31, nb32, nb33); + }); + }); +} + +template +void launch_fattn( + ggml_backend_sycl_context & ctx, ggml_tensor * dst, const int nwarps, const size_t nbytes_shared, + const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k) { + + constexpr int ncols = ncols1 * ncols2; + + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); + + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(Q->nb[0] == ggml_element_size(Q)); + GGML_ASSERT(K->nb[0] == ggml_element_size(K)); + GGML_ASSERT(V->nb[0] == ggml_element_size(V)); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + + ggml_sycl_pool & pool = ctx.pool(); + dpct::queue_ptr main_stream = ctx.stream(); + const int id = ggml_sycl_get_device(); + const int nsm = ggml_sycl_info().devices[id].nsm; + + ggml_sycl_pool_alloc K_f16(pool); + ggml_sycl_pool_alloc V_f16(pool); + ggml_sycl_pool_alloc KV_max(pool); + ggml_sycl_pool_alloc dst_tmp(pool); + ggml_sycl_pool_alloc dst_tmp_meta(pool); + + const char * K_data = (const char *) K->data; + size_t nb11 = K->nb[1]; + size_t nb12 = K->nb[2]; + size_t nb13 = K->nb[3]; + + const char * V_data = (const char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; + + if (need_f16_K && K->type != GGML_TYPE_F16) { + const size_t bs = ggml_blck_size(K->type); + const size_t ts = ggml_type_size(K->type); + + K_f16.alloc(ggml_nelements(K)); + if (ggml_is_contiguously_allocated(K)) { + to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(K->type, dst); + to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); + + nb11 = nb11 * bs * sizeof(sycl::half) / ts; + nb12 = nb12 * bs * sizeof(sycl::half) / ts; + nb13 = nb13 * bs * sizeof(sycl::half) / ts; + } else { + GGML_ASSERT(K->nb[0] == ts); + to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(K->type); + const int64_t s01 = nb11 / ts; + const int64_t s02 = nb12 / ts; + const int64_t s03 = nb13 / ts; + to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); + + nb11 = K->ne[0] * sizeof(sycl::half); + nb12 = K->ne[1] * nb11; + nb13 = K->ne[2] * nb12; + } + K_data = (char *) K_f16.ptr; + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V_is_K_view) { + V_data = K_data; + nb21 = nb11; + nb22 = nb12; + nb23 = nb13; + } else { + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + V_f16.alloc(ggml_nelements(V)); + if (ggml_is_contiguously_allocated(V)) { + to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(V->type, dst); + to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + nb21 = nb21 * bs * sizeof(sycl::half) / ts; + nb22 = nb22 * bs * sizeof(sycl::half) / ts; + nb23 = nb23 * bs * sizeof(sycl::half) / ts; + } else { + GGML_ASSERT(V->nb[0] == ts); + to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(V->type); + const int64_t s01 = nb21 / ts; + const int64_t s02 = nb22 / ts; + const int64_t s03 = nb23 / ts; + to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + + nb21 = V->ne[0] * sizeof(sycl::half); + nb22 = V->ne[1] * nb21; + nb23 = V->ne[2] * nb22; + } + V_data = (char *) V_f16.ptr; + } + } + + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2); + const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3]; + + // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. + // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or + // multiple sequences of possibly different lengths. + if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { + const int s31 = mask->nb[1] / sizeof(sycl::half2); + const int s33 = mask->nb[3] / sizeof(sycl::half2); + + const dpct::dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1); + const dpct::dim3 block_dim_KV_max(FATTN_KQ_STRIDE / 2, 1, 1); + + const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y; + const int iter_k = K->ne[1] / FATTN_KQ_STRIDE; + + KV_max.alloc(ne_KV_max); + { + dpct::has_capability_or_fail(main_stream->get_device(), { sycl::aspect::fp16 }); + + main_stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor buf_iw_acc_ct1(sycl::range<1>(warp_size), cgh); + + auto mask_data_ct0 = (const sycl::half2 *) mask->data; + auto KV_max_ptr_ct1 = KV_max.ptr; + + cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + flash_attn_mask_to_KV_max( + mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33, + buf_iw_acc_ct1.get_multi_ptr().get()); + }); + }); + } + SYCL_CHECK(0); + } + + const dpct::dim3 block_dim(warp_size, nwarps, 1); + + // Max. number of active blocks limited by occupancy. + int max_blocks_per_sm = ggml_sycl_info().devices[id].max_wg_per_cu; + int parallel_blocks = max_blocks_per_sm; + dpct::dim3 blocks_num; + if (stream_k) { + // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. + const int max_blocks = max_blocks_per_sm*nsm; + const int nblocks_stream_k = max_blocks; + const bool use_stream_k = true; + + blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; + blocks_num.y = 1; + blocks_num.z = 1; + + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2))); + } + } else { + const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size. + + // parallel_blocks must not be larger than what the tensor size allows: + parallel_blocks = std::min(parallel_blocks, ntiles_KQ); + // todo fix the hard code change + // parallel_blocks = ntiles_KQ; + + // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects. + // Test whether parallel_blocks can be set to a higher value for better efficiency. + const int blocks_per_wave = nsm * max_blocks_per_sm; + int nwaves_best = 0; + int efficiency_percent_best = 0; + for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) { + const int nblocks_total = ntiles_total * parallel_blocks_test; + const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave; + const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); + + // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead. + if (efficiency_percent_best >= 95 && nwaves > nwaves_best) { + break; + } + + if (efficiency_percent > efficiency_percent_best) { + nwaves_best = nwaves; + efficiency_percent_best = efficiency_percent; + parallel_blocks = parallel_blocks_test; + } + } + + blocks_num.x = ntiles_x; + blocks_num.y = parallel_blocks; + blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3]; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + } + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + // TODO other tensor dimensions after removal of WMMA kernel: + const sycl::uint3 ne01 = init_fastdiv_values(Q->ne[1]); + + GGML_ASSERT(block_dim.x % warp_size == 0); + + lauch_kernel( + blocks_num, block_dim, main_stream, (unsigned int) nbytes_shared, (const char *) Q->data, K_data, V_data, + mask ? ((const char *) mask->data) : nullptr, sinks ? ((const char *) sinks->data) : nullptr, KV_max.ptr, + !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, (sycl::float2 *)dst_tmp_meta.ptr, scale, max_bias, m0, m1, + n_head_log2, logit_softcap, Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], K->ne[0], + K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, mask ? mask->ne[1] : 0, + mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, + mask ? mask->nb[3] : 0); + SYCL_CHECK(0); + + if (stream_k) { + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + const dpct::dim3 block_dim_combine(DV, 1, 1); + const dpct::dim3 blocks_num_combine = { blocks_num.x, ncols1, ncols2 }; + + main_stream->submit([&](sycl::handler & cgh) { + auto KQV_data_ct0 = (float *) KQV->data; + auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr; + auto Q_ne_ct2 = Q->ne[1]; + auto Q_ne_ct3 = Q->ne[2]; + auto Q_ne_ct4 = Q->ne[3]; + auto K_ne_ct5 = K->ne[1]; + auto K_ne_ct6 = K->ne[2]; + + cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + flash_attn_stream_k_fixup(KQV_data_ct0, dst_tmp_meta_ptr_ct1, + Q_ne_ct2, Q_ne_ct3, Q_ne_ct4, + K_ne_ct5, K_ne_ct6, nbatch_fa); + }); + }); + } + } else if (parallel_blocks > 1) { + const dpct::dim3 block_dim_combine(DV, 1, 1); + const dpct::dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]); + const size_t nbytes_shared_combine = parallel_blocks * sizeof(sycl::float2); + main_stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor dpct_local_acc_ct1(sycl::range<1>(nbytes_shared_combine), cgh); + + auto dst_tmp_ptr_ct0 = dst_tmp.ptr; + auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr; + auto KQV_data_ct2 = (float *) KQV->data; + + cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + flash_attn_combine_results( + dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks, + dpct_local_acc_ct1.get_multi_ptr().get()); + }); + }); + } + SYCL_CHECK(0); +} diff --git a/ggml/src/ggml-sycl/fattn-tile.cpp b/ggml/src/ggml-sycl/fattn-tile.cpp new file mode 100644 index 00000000000..9d4f019cf51 --- /dev/null +++ b/ggml/src/ggml-sycl/fattn-tile.cpp @@ -0,0 +1,55 @@ +#include +#include +#include "dpct/helper.hpp" +#include "common.hpp" +#include "fattn-common.hpp" +#include "fattn-tile.hpp" +#include +#include +namespace syclex = sycl::ext::oneapi::experimental; + +void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + switch (K->ne[0]) { + case 40: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case< 40, 40>(ctx, dst); + } break; + case 64: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case< 64, 64>(ctx, dst); + } break; + case 72: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case< 72, 72>(ctx, dst); + } break; + case 80: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case< 80, 80>(ctx, dst); + } break; + case 96: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case< 96, 96>(ctx, dst); + } break; + case 112: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case<112, 112>(ctx, dst); + } break; + case 128: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case<128, 128>(ctx, dst); + } break; + case 256: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst); + } break; + case 576: { + GGML_ASSERT(V->ne[0] == 512); + ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst); + } break; + default: { + GGML_ABORT("Unsupported head size"); + } break; + } +} diff --git a/ggml/src/ggml-sycl/fattn-tile.hpp b/ggml/src/ggml-sycl/fattn-tile.hpp new file mode 100644 index 00000000000..29fd0f8c9ec --- /dev/null +++ b/ggml/src/ggml-sycl/fattn-tile.hpp @@ -0,0 +1,1338 @@ +#include +#include +#include "dpct/helper.hpp" +#include "common.hpp" +#include "fattn-common.hpp" + +#include +#include + +namespace syclex = sycl::ext::oneapi::experimental; + +#define GGML_SYCL_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \ + if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \ + static_assert((nthreads) <= 512, "bad nthreads"); \ + static_assert((occupancy) <= 8, "bad occupancy"); \ + static_assert((nbatch_fa) <= 256, "bad nbatch_fa"); \ + static_assert((nbatch_K) <= 256, "bad nbatch_K"); \ + return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23); \ + } \ + +static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, const int DV, const int ncols) { + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 64, 40) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 64, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 64, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 64, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 64, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 64, 72) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 64, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 64, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 64, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 64, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + + return 0; +} + +static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, const int DV, const int ncols) { + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 128, 3, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 3, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 128, 3, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 3, 32, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 3, 64, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) + + return 0; +} + +static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) { + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) + + return 0; +} + +static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) { + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) + + return 0; +} + +static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) { + if(fast_fp16_available(cc)) + return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols); + else + return ggml_sycl_fattn_tile_get_config_fp32(DKQ, DV, ncols); +} + +static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) { +#ifdef SYCL_FAST_FP16 + return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols); +#else + return ggml_sycl_fattn_tile_get_config_fp32(DKQ, DV, ncols); +#endif // SYCL_FAST_FP16 +} + +static int ggml_sycl_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1); +} + +static constexpr int ggml_sycl_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1); +} + +static int ggml_sycl_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1); +} + +static constexpr int ggml_sycl_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1); +} + +static int ggml_sycl_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1); +} + +static constexpr int ggml_sycl_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1); +} + +static int ggml_sycl_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1); +} + +static constexpr int ggml_sycl_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1); +} + +template +static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const __restrict__ KV, + sycl::half2 * const __restrict__ tile_KV, + const int stride_KV, + const int i_sup) { + constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] (const int n) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j); + const int j0_stop = ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + item_ct1.get_local_id(1) * stride_i + + (stride_j == warp_size ? 0 : item_ct1.get_local_id(2) / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0 * cpy_ne + (stride_j == warp_size ? item_ct1.get_local_id(2) : + item_ct1.get_local_id(2) % stride_j) * + cpy_ne; + + const __dpct_align__(16) sycl::half2 zero[cpy_ne] = { + { 0.0f, 0.0f } + }; + ggml_sycl_memcpy_1( + tile_KV + i*(J/2 + J_padding) + j, + !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + } + } + } + }; + // 1: max 64*16=512 bytes, 512 half + // 2: max 32*16=512 bytes, 256 half + // 3: max 16*16=256 bytes, 128 half + // 4: max 8*16=128 bytes, 64 half + // 5: max 4*16= 64 bytes, 32 half + // 6: max 2*16= 32 bytes, 16 half + // 7: max 1*16= 16 bytes, 8 half + static_assert(J % 8 == 0, "bad J"); + static_assert((J/2) % cpy_ne == 0, "bad J"); + ggml_sycl_unroll<7>{}(load); +} + +template +static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const __restrict__ KV, + float * const __restrict__ tile_KV, + const int stride_KV, + const int i_sup) { + constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] (const int n) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j); + const int j0_stop = (J/cpy_ne) - (J/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + item_ct1.get_local_id(1) * stride_i + + (stride_j == warp_size ? 0 : item_ct1.get_local_id(2) / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0 * (cpy_ne / 2) + (stride_j == warp_size ? item_ct1.get_local_id(2) : + item_ct1.get_local_id(2) % stride_j) * + (cpy_ne / 2); + + const sycl::half2 zero[cpy_ne / 2] = { + { 0.0f, 0.0f } + }; + __dpct_align__(16) sycl::half2 tmp_h2[cpy_ne / 2]; + ggml_sycl_memcpy_1( + tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + + __dpct_align__(16) sycl::float2 tmp_f2[cpy_ne / 2]; +#pragma unroll + for (int l = 0; l < cpy_ne/2; ++l) { + tmp_f2[l] = tmp_h2[l].template convert(); + } + ggml_sycl_memcpy_1(tile_KV + i*(J + J_padding) + 2*j, tmp_f2); + } + } + } + }; + // 1: max 32*16=512 bytes, 128 float + // 2: max 16*16=256 bytes, 64 float + // 3: max 8*16=128 bytes, 32 float + // 4: max 4*16= 64 bytes, 16 float + // 5: max 2*16= 32 bytes, 8 float + static_assert(J % 8 == 0, "bad J"); + static_assert(J % cpy_ne == 0, "bad J"); + ggml_sycl_unroll<5>{}(load); +} + +// Function that performs a single iteration in for the KQ matrix multiplication: +template +static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp, + const sycl::half2 * const __restrict__ K_h2, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int k_VKQ_0, + const int k_VKQ_sup, + const int k_KQ_0, + float * KQ_acc) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + flash_attn_tile_load_tile + (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); + item_ct1.barrier(); + +#ifdef SYCL_FAST_FP16 + static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) { + __dpct_align__(16) sycl::half2 K_k[nbatch_fa / (np * warp_size)][cpy_ne]; + __dpct_align__(16) sycl::half2 Q_k[cpw][cpy_ne]; +#else + static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) { + __dpct_align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + __dpct_align__(16) float Q_k[cpw][cpy_ne]; +#endif // SYCL_FAST_FP16 + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2); + +#ifdef SYCL_FAST_FP16 + ggml_sycl_memcpy_1(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]); +#else + ggml_sycl_memcpy_1(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K + cpy_ne) + k_KQ_1]); +#endif // SYCL_FAST_FP16 + } +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw; + +#ifdef SYCL_FAST_FP16 + ggml_sycl_memcpy_1(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]); +#else + ggml_sycl_memcpy_1(&Q_k[jc0], &Q_tmp[jc* DKQ + k_KQ_0 + k_KQ_1]); +#endif // SYCL_FAST_FP16 + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { +#pragma unroll + for (int k = 0; k < cpy_ne; ++k) { + ggml_sycl_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]); + } + } + } + } + + if (k_KQ_0 + nbatch_K < DKQ) { + item_ct1.barrier(); // Sync not needed on last iteration. + } +} + +// Function that performs a single iteration of the main loop over up to nbatch_fa tokens. +template +/* +The total declared local variable size in device function flash_attn_tile_iter exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. +*/ +static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, + const sycl::half2 * const __restrict__ K_h2, + const sycl::half2 * const __restrict__ V_h2, + const sycl::half * const __restrict__ mask, + const sycl::uint3 ne01, + const float logit_softcap, + const float slope, + T_KQ * const KQ, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int stride_V2, + const int stride_mask, + float * const KQ_max, + float * const KQ_sum, + T_acc * const VKQ, + const int k_VKQ_0, + const int k_VKQ_max, + const int col_Q_0, + float * KQ_max_new_shared) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + +#ifdef SYCL_FAST_FP16 + constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne; +#else + constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne; +#endif // SYCL_FAST_FP16 + static_assert(cpw % KQ_cs == 0, "bad KQ_cs"); + const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data + + float KQ_max_new[cpw]; +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_max_new[jc0] = KQ_max[jc0]; + } + + float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication. + + // KQ = K @ Q matrix multiplication: + constexpr int nbatch_K_last = DKQ % nbatch_K; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) { + flash_attn_tile_iter_KQ( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + if (nbatch_K_last > 0) { + constexpr int k_KQ_0 = DKQ - nbatch_K_last; + flash_attn_tile_iter_KQ( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + + // Apply logit softcap + mask, update KQ_max: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int j = fastmodulo(col_Q_0 + (jc0 + (item_ct1.get_local_id(1) / np) * cpw) / ncols2, ne01); + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2); + +#if defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16) + // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation. + // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again. + KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0] *= 4.0f; +#endif // defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16) + + if (use_logit_softcap) { + KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] = + logit_softcap * sycl::tanh((float) KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0]); + } + + if (!oob_check || i_KQ < k_VKQ_sup) { + KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] += + (ncols2 > 1 || mask) ? slope * sycl::vec(mask[j * stride_mask + k_VKQ_0 + i_KQ]) + .convert()[0] : + 0.0f; + + KQ_max_new[jc0] = + sycl::fmax((float) KQ_max_new[jc0], + (float) (KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] + FATTN_KQ_MAX_OFFSET)); + } + } + + KQ_max_new[jc0] = warp_reduce_max(KQ_max_new[jc0]); + } + + if constexpr (np == 1) { + item_ct1.barrier(); + } else { + static_assert(cpw == 1, "bad cpw"); + + if (item_ct1.get_local_id(2) == 0) { + KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0]; + } + item_ct1.barrier(); + KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np]; + KQ_max_new[0] = warp_reduce_max(KQ_max_new[0]); + } + + // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) { +#ifdef SYCL_FAST_FP16 + __dpct_align__(16) sycl::half tmp[nbatch_fa / (np * warp_size)][KQ_cs]; +#else + __dpct_align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs]; +#endif // SYCL_FAST_FP16 + +#pragma unroll + for (int jc1 = 0; jc1 < KQ_cs; ++jc1) { + const int jc = jc0 + jc1; + + const float KQ_max_scale = sycl::native::exp((float) (KQ_max[jc] - KQ_max_new[jc])); + KQ_max[jc] = KQ_max_new[jc]; + + float KQ_sum_add = 0.0f; +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const float val = + !oob_check || i0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2) < + static_cast(k_VKQ_sup) ? + sycl::native::exp((float) (KQ_acc[(i0 / (np * warp_size)) * cpw + jc] - KQ_max[jc])) : + 0.0f; + KQ_sum_add += val; + tmp[i0/(np*warp_size)][jc1] = val; + } + KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add; + +#ifdef SYCL_FAST_FP16 + const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale_h2.x(); + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale_h2.y(); + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale; + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale; + } +#endif // SYCL_FAST_FP16 + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const int i = i0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2); + + ggml_sycl_memcpy_1( + KQ + (jc0 / KQ_cs + (item_ct1.get_local_id(1) / np) * (cpw / KQ_cs)) * (nbatch_fa * KQ_cs) + i * KQ_cs, + tmp[i0 / (np * warp_size)]); + } + } + + // VKQ = V @ KQ matrix multiplication: + static_assert(DV <= DKQ, "bad DV"); + static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K"); + constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K. + static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V"); + static_assert(nbatch_V % np == 0, "bad nbatch_V"); +#pragma unroll + for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { + flash_attn_tile_load_tile + (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); + item_ct1.barrier(); + +#ifdef SYCL_FAST_FP16 +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + __dpct_align__(16) sycl::half2 V_k[(DVp / 2) / warp_size]; + __dpct_align__(16) sycl::half2 KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_sycl_memcpy_1(&V_k[i0 / warp_size], + &KV_tmp[(k1 + item_ct1.get_local_id(1) % np) * (DV / 2) + i0 + + item_ct1.get_local_id(2) * cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0 / KQ_cs + (item_ct1.get_local_id(1) / np) * (cpw / KQ_cs); + + __dpct_align__(16) sycl::half tmp[KQ_cs]; + ggml_sycl_memcpy_1( + &tmp, KQ + jc_KQ * (nbatch_fa * KQ_cs) + (k0 + k1 + item_ct1.get_local_id(1) % np) * KQ_cs); +#pragma unroll + for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) { + KQ_k[jc_VKQ_0 + jc_VKQ_1] = sycl::half2(tmp[jc_VKQ_1]); + } + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x() += + V_k[i0/warp_size].x()*KQ_k[jc_VKQ_0].x(); + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y() += + V_k[i0/warp_size].y()*KQ_k[jc_VKQ_0].y(); + } + } + } +#else +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + __dpct_align__(16) sycl::float2 V_k[(DVp/2)/warp_size]; + __dpct_align__(16) float KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_sycl_memcpy_1(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + item_ct1.get_local_id(1) % np)*DV + i0 + item_ct1.get_local_id(2)*cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0/KQ_cs + (item_ct1.get_local_id(1) / np)*(cpw/KQ_cs); + + ggml_sycl_memcpy_1( + &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + item_ct1.get_local_id(1) % np)*KQ_cs); + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x() += V_k[i0/warp_size].x()*KQ_k[jc_VKQ_0]; + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y() += V_k[i0/warp_size].y()*KQ_k[jc_VKQ_0]; + } + } + } +#endif // SYCL_FAST_FP16 + item_ct1.barrier(); + } +} + +template // D == head size +/* +The total declared local variable size in device function flash_attn_tile exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. +*/ +static void flash_attn_tile(const char * Q, + const char * K, + const char * V, + const char * mask, + const char * sinks, + const int * KV_max, + float * dst, + sycl::float2 * dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, + const sycl::uint3 ne01, + const int32_t ne02, + const int32_t ne03, + const int32_t nb01, + const int32_t nb02, + const int32_t nb03, + const int32_t ne10, + const int32_t ne11, + const int32_t ne12, + const int32_t ne13, + const int32_t nb11, + const int32_t nb12, + const int64_t nb13, + const int32_t nb21, + const int32_t nb22, + const int64_t nb23, + const int32_t ne31, + const int32_t ne32, + const int32_t ne33, + const int32_t nb31, + const int32_t nb32, + const int64_t nb33) { +#ifdef SYCL_FLASH_ATTN + // Skip unused kernel variants for faster compilation: + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + if ((use_logit_softcap && !(DV == 128 || DV == 256))) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + return; + } + + static_assert(ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined"); + + constexpr int ncols = ncols1*ncols2; + + constexpr int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size; + constexpr int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2); + constexpr int nbatch_K = ggml_sycl_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2); + + // In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int col_Q_0 = item_ct1.get_group(2) * ncols1; // Index of the first Q column for this SYCL block to work on. + + const int sequence = item_ct1.get_group(0) / (ne02 / ncols2); + const int head0 = item_ct1.get_group(0) * ncols2 - sequence * ne02; // == item_ct1.get_group(0) % (ne02/ncols2) + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0); + const sycl::half2 * K_h2 = (const sycl::half2 *) (K + nb13 * sequence + nb12 * (head0 / gqa_ratio)); + const sycl::half2 * V_h2 = + (const sycl::half2 *) (V + nb23 * sequence + nb22 * (head0 / gqa_ratio)); // K and V have same shape + + const sycl::half * maskh = mask ? (const sycl::half *) (mask + nb33 * (sequence % ne33)) : nullptr; + + const int stride_K2 = nb11 / sizeof(sycl::half2); + const int stride_V2 = nb21 / sizeof(sycl::half2); + const int stride_mask = nb31 / sizeof(sycl::half); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + + constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp. + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column. + + static_assert(cpw == 1 || np == 1, "bad cpw / np"); + static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0"); + + constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size. + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + + // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel. + // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11. + // KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV). + // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications. + // VKQ == Accumulators in registers for the final VKQ result. + + +#ifdef SYCL_FAST_FP16 + constexpr size_t lsm_size1 = ncols * DKQ/2 ; + constexpr size_t lsm_size2 = nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV ; + constexpr size_t lsm_size3 = ncols * nbatch_fa; + constexpr size_t lsm_size4 = nwarps; + + constexpr size_t local_share_mem_size = lsm_size1 * sizeof(sycl::half2) + + lsm_size2 * sizeof(sycl::half2) + + lsm_size3 * sizeof(sycl::half) + + lsm_size4 * sizeof(float); + + syclex::work_group_static lsm; + + sycl::half2 *Q_tmp = (sycl::half2 *)&lsm; + sycl::half2 *KV_tmp = (sycl::half2*)(Q_tmp +lsm_size1); + sycl::half *KQ = (sycl::half *)(KV_tmp+lsm_size2); + float *KQ_max_new_shared = (float *)(KQ+lsm_size3); + + __dpct_align__(16) sycl::half2 VKQ[cpw * ((DVp / 2) / warp_size)] = { + { 0.0f, 0.0f } + }; +#else + constexpr size_t lsm_size1 = ncols * DKQ ; + constexpr size_t lsm_size2 = nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV; + constexpr size_t lsm_size3 = ncols * nbatch_fa; + constexpr size_t lsm_size4 = nwarps; + + constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2 +lsm_size3 + lsm_size4) * sizeof(float); + + syclex::work_group_static lsm; + + float *Q_tmp = (float *)&lsm; + float *KV_tmp = Q_tmp +lsm_size1; + float *KQ = KV_tmp+lsm_size2; + float *KQ_max_new_shared = KQ+lsm_size3; + + __dpct_align__(16) sycl::float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; + + +#endif // SYCL_FAST_FP16 + + float KQ_max[cpw] = {}; + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + KQ_max[j0/nwarps] = -FLT_MAX/2.0f; + } + float KQ_sum[cpw] = {0.0f}; + + // Load Q data, convert to FP16 if fast: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size; + +#pragma unroll + for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { + if (i0 + np * warp_size * cpy_ne_D <= DKQ || + i0 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D) + item_ct1.get_local_id(2) * cpy_ne_D < + DKQ) { + __dpct_align__(16) float tmp_f[cpy_ne_D] = { 0.0f }; + ggml_sycl_memcpy_1( + tmp_f, &Q_f[c * (nb02 / sizeof(float)) + fastmodulo(col_Q_0 + j, ne01) * (nb01 / sizeof(float)) + + i0 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D) + + item_ct1.get_local_id(2) * cpy_ne_D]); + +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp_f[i1] *= scale; + } + +#ifdef SYCL_FAST_FP16 + __dpct_align__(16) sycl::half2 tmp_h2[cpy_ne_D / 2]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { + tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); +#if defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16) + // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation. + // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again. + tmp_h2[i1 / 2] *= sycl::half2(0.25f, 0.25f); +#endif // defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16) + } + ggml_sycl_memcpy_1( + &Q_tmp[jc * (DKQ / 2) + i0 / 2 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D / 2) + + item_ct1.get_local_id(2) * (cpy_ne_D / 2)], + tmp_h2); +#else + ggml_sycl_memcpy_1( + &Q_tmp[jc* DKQ + i0 + (item_ct1.get_local_id(1) % np)*(warp_size*cpy_ne_D) + item_ct1.get_local_id(2)* cpy_ne_D], + tmp_f); +#endif // SYCL_FAST_FP16 + } + } + } + + item_ct1.barrier(); + + // Main loop over KV cache: + const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11; + if (ncols2 == 1) { + // Branch with out-of-bounds checks. + int k_VKQ_0 = item_ct1.get_group(1) * nbatch_fa; + while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { + constexpr bool oob_check = false; + flash_attn_tile_iter(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, + stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0, + KQ_max_new_shared); + k_VKQ_0 += item_ct1.get_group_range(1) * nbatch_fa; + } + if (k_VKQ_0 < k_VKQ_max) { + constexpr bool oob_check = true; + flash_attn_tile_iter(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, + stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0, + KQ_max_new_shared); + } + } else { + // Branch without out-of-bounds checks. + for (int k_VKQ_0 = item_ct1.get_group(1) * nbatch_fa; k_VKQ_0 < k_VKQ_max; + k_VKQ_0 += item_ct1.get_group_range(1) * nbatch_fa) { + + constexpr bool oob_check = false; + flash_attn_tile_iter(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, + stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0, + KQ_max_new_shared); + } + } + +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_sum[jc0] = warp_reduce_sum(KQ_sum[jc0]); + } + + if constexpr (np > 1) { + static_assert(cpw == 1, "bad cpw"); + static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small"); + +#ifdef SYCL_FAST_FP16 + sycl::half2 * VKQ_combine = (sycl::half2 *) KV_tmp; +#else + float * VKQ_combine = (float *) KV_tmp; +#endif // SYCL_FAST_FP16 + + float * KQ_sum_combine = (float *) Q_tmp; + + if (item_ct1.get_local_id(1) % np != 0) { + +#ifdef SYCL_FAST_FP16 + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_sycl_memcpy_1( + &VKQ_combine[item_ct1.get_local_id(1) * (DVp / 2) + i0 + item_ct1.get_local_id(2) * cpy_ne_D], + &VKQ[i0 / warp_size]); + } +#else + + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; + +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_sycl_memcpy_1( + &VKQ_combine[item_ct1.get_local_id(1)*DVp + i0 + item_ct1.get_local_id(2)*cpy_ne_D], ((const float *) VKQ) + i0/warp_size); + } +#endif // SYCL_FAST_FP16 + + if (item_ct1.get_local_id(2) == 0) { + KQ_sum_combine[item_ct1.get_local_id(1)] = KQ_sum[0]; + } + return; + } + + item_ct1.barrier(); + +#pragma unroll + for (int ip = 1; ip < np; ++ip) { +#ifdef SYCL_FAST_FP16 + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + __dpct_align__(16) sycl::half2 tmp[cpy_ne_D]; + ggml_sycl_memcpy_1(tmp, &VKQ_combine[(item_ct1.get_local_id(1) + ip) * (DVp / 2) + i0 + + item_ct1.get_local_id(2) * cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + VKQ[i0/warp_size + i1] += tmp[i1]; + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + __dpct_align__(16) float tmp[cpy_ne_D]; + ggml_sycl_memcpy_1(tmp, &VKQ_combine[(item_ct1.get_local_id(1) + ip)*DVp + i0 + item_ct1.get_local_id(2)*cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + ((float *)VKQ)[i0/warp_size + i1] += tmp[i1]; + } + } +#endif // SYCL_FAST_FP16 + + KQ_sum[0] += KQ_sum_combine[item_ct1.get_local_id(1) + ip]; + } + } + + // Attention sink: adjust KQ max and sum only for the first of all parallel blocks: + if (sinks && item_ct1.get_group(1) == 0) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw; + const float sink = ((const float *) sinks)[head0 + jc % ncols2]; + + float KQ_max_new_j = sycl::fmax((float) KQ_max[jc0], sink); + const float KQ_max_scale = sycl::native::exp((float) (KQ_max[jc0] - KQ_max_new_j)); + KQ_max[jc0] = KQ_max_new_j; + + const float val = sycl::native::exp((float) (sink - KQ_max[jc0])); + KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val; + +#ifdef SYCL_FAST_FP16 + const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale; + } +#endif // SYCL_FAST_FP16 + } + } + + // Write back results: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z())) { + return; + } + + const float scale = item_ct1.get_group_range(1) == 1 ? 1.0f / KQ_sum[jc0] : 1.0f; + + const int j_dst_unrolled = + ((sequence * int(ne01.z()) + col_Q_0 + j) * ne02 + head0 + c) * item_ct1.get_group_range(1) + + item_ct1.get_group(1); + +#ifdef SYCL_FAST_FP16 + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + __dpct_align__(16) sycl::float2 tmp[cpy_ne_D]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp[i1] = VKQ[jc0 * ((DVp / 2) / warp_size) + i0 / warp_size + i1] + .template convert(); + tmp[i1].x() *= scale; + tmp[i1].y() *= scale; + } + if (i0 + warp_size * cpy_ne_D <= DV / 2 || i0 + item_ct1.get_local_id(2) * cpy_ne_D < DV / 2) { + ggml_sycl_memcpy_1( + &dst[j_dst_unrolled * DV + 2 * i0 + item_ct1.get_local_id(2) * (2 * cpy_ne_D)], tmp); + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + if (i0 + warp_size*cpy_ne_D <= DV || i0 + item_ct1.get_local_id(2)*cpy_ne_D < DV) { +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x() *= scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y() *= scale; + } + ggml_sycl_memcpy_1( + &dst[j_dst_unrolled*DV + i0 + item_ct1.get_local_id(2)*cpy_ne_D], + &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]); + } + } +#endif // SYCL_FAST_FP16 + + if (item_ct1.get_group_range(1) != 1 && item_ct1.get_local_id(2) == 0) { + dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]); + } + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); +#endif // SYCL_FLASH_ATTN +} + +template +static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + const int id = ggml_sycl_get_device(); + const int cc = ggml_sycl_info().devices[id].cc; + const int warp_size = WARP_32_SIZE; //can't support WARP_16_SIZE + + constexpr size_t nbytes_shared = 0; + + if constexpr (DV <= 256) { + if (Q->ne[1] > 16/ncols2) { + constexpr int cols_per_block = 32; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + } + + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + + if constexpr (ncols2 <= 8) { + if (Q->ne[1] > 4/ncols2) { + constexpr int cols_per_block = 8; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + } + + if constexpr (ncols2 <= 4) { + if (Q->ne[1] > 2/ncols2) { + constexpr int cols_per_block = 4; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + } + + if constexpr (ncols2 <= 2) { + constexpr int cols_per_block = 2; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + + GGML_ABORT("fatal error"); +} + +template +static void launch_fattn_tile_switch_ncols2(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + // On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases. + // However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented. + //const bool nvidia = GGML_SYCL_CC_IS_NVIDIA(ggml_sycl_info().devices[ggml_sycl_get_device()].cc); + const int gqa_limit = gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; + const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; + + if constexpr (DV == 512) { + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + } + + if constexpr (DV <= 256) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + GGML_ABORT("fatal error"); +} + +template +void ggml_sycl_flash_attn_ext_tile_case(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_switch_ncols2(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_switch_ncols2(ctx, dst); + } +} + +void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#define DECL_FATTN_TILE_CASE(DKQ, DV) \ + template void ggml_sycl_flash_attn_ext_tile_case \ + (ggml_backend_sycl_context & ctx, ggml_tensor * dst) \ + +extern DECL_FATTN_TILE_CASE( 40, 40); +extern DECL_FATTN_TILE_CASE( 64, 64); +extern DECL_FATTN_TILE_CASE( 72, 72); +extern DECL_FATTN_TILE_CASE( 80, 80); +extern DECL_FATTN_TILE_CASE( 96, 96); +extern DECL_FATTN_TILE_CASE(112, 112); +extern DECL_FATTN_TILE_CASE(128, 128); +extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(576, 512); + diff --git a/ggml/src/ggml-sycl/fattn-vec.hpp b/ggml/src/ggml-sycl/fattn-vec.hpp new file mode 100644 index 00000000000..48c389052f4 --- /dev/null +++ b/ggml/src/ggml-sycl/fattn-vec.hpp @@ -0,0 +1,667 @@ +#ifndef GGML_SYCL_FATTN_VEC_HPP +#define GGML_SYCL_FATTN_VEC_HPP + +#include +#include +#include +#include + +#include "dpct/helper.hpp" +#include "common.hpp" +#include "ggml.h" +#include "fattn-common.hpp" +#include +#include + +namespace syclex = sycl::ext::oneapi::experimental; + +static int ggml_sycl_fattn_vec_get_nthreads_host(const int cc) { + return 128; + GGML_UNUSED(cc); +} + +static constexpr int ggml_sycl_fattn_vec_get_nthreads_device() { + return 128; +} + +// Currenlty llvm with the amdgcn target dose not support unrolling loops +// that contain a break that can not be resolved at compile time. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ + +template // D == head size +static void flash_attn_ext_vec(const char* __restrict__ Q, + const char* __restrict__ K, + const char* __restrict__ V, + const char* __restrict__ mask, + const char* __restrict__ sinks, + const int* __restrict__ KV_max, + float* __restrict__ dst, + sycl::float2* __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, + const sycl::uint3 ne01, + const int32_t ne02, + const int32_t ne03, + const int32_t nb01, + const int32_t nb02, + const int32_t nb03, + const int32_t ne10, + const int32_t ne11, + const int32_t ne12, + const int32_t ne13, + const int32_t nb11, + const int32_t nb12, + const int64_t nb13, + const int32_t nb21, + const int32_t nb22, + const int64_t nb23, + const int32_t ne31, + const int32_t ne32, + const int32_t ne33, + const int32_t nb31, + const int32_t nb32, + const int64_t nb33) { +#ifdef SYCL_FLASH_ATTN + // Skip unused kernel variants for faster compilation: + + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + if (use_logit_softcap && !(D == 128 || D == 256)) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + return; + } + + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int nthreads_KQ_q = (D/4 < warp_size ? D/4 : warp_size); + constexpr int nthreads_V_q = (D/4 < warp_size ? D/4 : warp_size); + + constexpr int nthreads = ggml_sycl_fattn_vec_get_nthreads_device(); + constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; + constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; + + static_assert(warp_size % nthreads_KQ == 0, "bad nthreads_K"); + static_assert(warp_size % nthreads_V == 0, "bad nthreads_V"); + + constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; + constexpr int V_cols_per_iter = warp_size / nthreads_V; + + constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ(); + constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; +#ifdef GGML_SYCL_F16 + constexpr dequantize_V_t dequantize_V = get_dequantize_V(); +#else + constexpr dequantize_V_t dequantize_V = get_dequantize_V(); +#endif // GGML_SYCL_F16 + + const int ic0 = item_ct1.get_group(2) * ncols; // Index of the Q/QKV column to work on. + + const int sequence = item_ct1.get_group(0) / ne02; + const int head = item_ct1.get_group(0) - sequence * ne02; + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + Q += nb03*sequence + nb02* head + nb01*ic0; + K += nb13*sequence + nb12*(head / gqa_ratio); + V += nb23*sequence + nb22*(head / gqa_ratio); + + const sycl::half * maskh = (const sycl::half *) (mask + nb33 * (sequence % ne33) + nb31 * ic0); + + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); + + static_assert(D % (2*warp_size) == 0, "D not divisible by 2*warp_size == 64."); + constexpr int nwarps = nthreads / warp_size; + const int tid = warp_size * item_ct1.get_local_id(1) + item_ct1.get_local_id(2); + __builtin_assume(tid < nthreads); + + constexpr int ne_KQ = ncols*D; + constexpr int ne_combine = nwarps*V_cols_per_iter*D; + + constexpr size_t lsm_size1 = ncols * warp_size; + constexpr size_t lsm_size2 = ncols * warp_size; +#ifdef GGML_SYCL_F16 + sycl::half2 VKQ[ncols][(D / 2) / nthreads_V] = { { { 0.0f, 0.0f } } }; + constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine); + constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2)*sizeof(float) + lsm_size3*sizeof(sycl::half); + + syclex::work_group_static lsm; + + float *KQ_max_shared = (float *)&lsm; + float *KQ_sum_shared = KQ_max_shared+lsm_size1; + sycl::half* KQ = (sycl::half*)(KQ_sum_shared + lsm_size2); + + +#else + sycl::float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; + + constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine); + constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2 + lsm_size3)*sizeof(float); + + + syclex::work_group_static lsm; + float *KQ_max_shared = (float *)&lsm; + float *KQ_sum_shared = KQ_max_shared+lsm_size1; + float* KQ = KQ_sum_shared + lsm_size2; + +#endif // GGML_SYCL_F16 + + float KQ_max[ncols]; + float KQ_sum[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_max[j] = -FLT_MAX/2.0f; + KQ_sum[j] = 0.0f; + } + + // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: +#ifdef GGML_SYCL_F16 + sycl::half2 Q_reg[ncols][(D / 2) / nthreads_KQ] = {{{0.0f, 0.0f}}}; // Will be initialized completely. +#else + sycl::float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. +#endif // GGML_SYCL_F16 + int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; + sycl::float2 Q_ds[ncols][1 > D / (sizeof(int) * nthreads_KQ) ? 1 : D / (sizeof(int) * nthreads_KQ)]; + if constexpr (Q_q8_1) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + item_ct1.get_local_id(1); + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + + // Reuse KQ as temporary storage for converting Q to q8_1: + int * tmp_q_i32 = (int *) &KQ[j*D]; + sycl::float2 * tmp_q_ds = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int)); + + // Set memory to zero if out of bounds: + if (ncols > 1 && ic0 + j >= int(ne01.z())) { +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += warp_size) { + const int i = i0 + item_ct1.get_local_id(2); + + if (i0 + warp_size <= int(D/sizeof(int)) || i < int(D/sizeof(int))) { + tmp_q_i32[i] = 0; + } + } + if (item_ct1.get_local_id(2) < D/QK8_1) { + tmp_q_ds[item_ct1.get_local_id(2)] = sycl::float2(0.0f, 0.0f); + } + } else { + const float * Q_f = (const float *) (Q + j*nb01); + constexpr int nthreads_quantize = D/sizeof(int) < warp_size ? D/sizeof(int) : warp_size; +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) { + quantize_q8_1_to_shared + (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1); + } + } + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + int * tmp_q_i32 = (int *) &KQ[j*D]; + sycl::float2 * tmp_q_ds = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int)); + +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) { + const int i = + i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ); + + Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i]; + Q_ds[j][i0/nthreads_KQ] = tmp_q_ds[i/QI8_1]; + } + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + } else { +#ifdef GGML_SYCL_F16 + const sycl::half2 scale_h2 = sycl::half2(scale, scale); +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j * nb01); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { + const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : + item_ct1.get_local_id(2) % nthreads_KQ) * + cpy_ne; + + sycl::float2 tmp[cpy_ne] = { + { 0.0f, 0.0f } + }; + if (ncols == 1 || ic0 + j < int(ne01.z())) { + ggml_sycl_memcpy_1(tmp, &Q_j[i]); + ggml_sycl_memcpy_1(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); + } +#pragma unroll + for (int i1 = 0; i1 < cpy_ne; ++i1) { + Q_reg[j][i0 / nthreads_KQ + i1] = sycl::half2(tmp[i1].x(), tmp[i1].y()); + } + } +#pragma unroll + for (int k = 0; k < (D/2)/nthreads_KQ; ++k) { + Q_reg[j][k] *= scale_h2; + } + } +#else +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { + const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ)*cpy_ne; + if (ncols == 1 || ic0 + j < int(ne01.z())) { + ggml_sycl_memcpy_1(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]); + ggml_sycl_memcpy_1(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]); + } + } +#pragma unroll + for (int k = 0; k < (D/2)/nthreads_KQ; ++k) { + Q_reg[j][k].x() *= scale; + Q_reg[j][k].y() *= scale; + } + } +#endif // GGML_SYCL_F16 + } + + const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11; + K += item_ct1.get_group(1) * nthreads * nb11; + V += item_ct1.get_group(1) * nthreads * nb21; + maskh += item_ct1.get_group(1) * nthreads; + for (int k_VKQ_0 = item_ct1.get_group(1) * nthreads; k_VKQ_0 < k_VKQ_max; + k_VKQ_0 += item_ct1.get_group_range(1) * nthreads, + // Increment pointers after each loop: + K += item_ct1.get_group_range(1) * nthreads * nb11, V += item_ct1.get_group_range(1) * nthreads * nb21, + maskh += item_ct1.get_group_range(1) * nthreads) { + // Calculate KQ tile and keep track of new maximum KQ values: + float KQ_reg[ncols]={}; // KQ in registers. + float KQ_max_new[ncols]={}; + + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_max_new[j] = KQ_max[j]; + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) { + const int i_KQ = item_ct1.get_local_id(1) * warp_size + + (nthreads_KQ == warp_size ? 0 : (item_ct1.get_local_id(2) & ~(nthreads_KQ - 1))) + i_KQ_0; + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum(sum); + + if (use_logit_softcap) { + sum = logit_softcap * sycl::tanh(sum); + } + if (mask) { + sum += slope * sycl::vec(maskh[j * ne11 + i_KQ]) + .convert()[0]; + } + + KQ_max_new[j] = sycl::fmax((float) KQ_max_new[j], sum); + + if (int(nthreads_KQ == warp_size ? item_ct1.get_local_id(2) + : item_ct1.get_local_id(2) % + nthreads_KQ) == i_KQ_0) { + KQ_reg[j] = sum; + } + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { +#pragma unroll + for (int offset = nthreads_KQ; offset < warp_size; offset <<= 1) { + KQ_max_new[j] = sycl::fmax( + (float)KQ_max_new[j], + (float)dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), + KQ_max_new[j], + offset, + warp_size)); + } + const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - KQ_max_new[j])); + KQ_max[j] = KQ_max_new[j]; + + KQ_reg[j] = sycl::native::exp((float) (KQ_reg[j] - KQ_max[j])); + KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j]; + KQ[j*nthreads + tid] = KQ_reg[j]; + +#ifdef GGML_SYCL_F16 + const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale; + VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale; + } +#endif // GGML_SYCL_F16 + } + + sycl::group_barrier(sycl::ext::oneapi::this_work_item::get_sub_group()); + +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += V_cols_per_iter) { + const int k = item_ct1.get_local_id(1) * warp_size + k0 + + (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V); + +#ifdef GGML_SYCL_F16 + sycl::half2 KQ_k[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_k[j] = sycl::half2(KQ[j * nthreads + k]); + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + sycl::half2 tmp[V_rows_per_thread / 2]; + dequantize_V(V + k * nb21, tmp, + 2 * i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : + item_ct1.get_local_id(2) % nthreads_V) * + V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j]; + } + } + } +#else + float KQ_k[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_k[j] = KQ[j*nthreads + k]; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + sycl::float2 tmp[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x() += tmp[i_VKQ_1].x()*KQ_k[j]; + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y() += tmp[i_VKQ_1].y()*KQ_k[j]; + } + } + } +#endif // GGML_SYCL_F16 + } + } + + if (sinks && item_ct1.get_group(1) == 0) { + const float sink = ((const float *) sinks)[head]; + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + item_ct1.get_local_id(1); + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + const float kqmax_new_j = sycl::fmax(sink, (float) KQ_max[j]); + const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - kqmax_new_j)); + KQ_max[j] = kqmax_new_j; + + KQ_sum[j] = KQ_sum[j] * KQ_max_scale + + (item_ct1.get_local_id(2) == 0 ? sycl::native::exp((float) (sink - KQ_max[j])) : 0.0f); +#ifdef GGML_SYCL_F16 + const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale; + VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale; + } +#endif // GGML_SYCL_F16 + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (item_ct1.get_local_id(1) == 0) { + KQ_max_shared[j*warp_size+item_ct1.get_local_id(2)] = -FLT_MAX / 2.0f; + KQ_sum_shared[j*warp_size+item_ct1.get_local_id(2)] = 0.0f; + } + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (item_ct1.get_local_id(2) == 0) { + KQ_max_shared[j*warp_size+item_ct1.get_local_id(1)] = KQ_max[j]; + } + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + +#pragma unroll + for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z())) { + break; + } + + float kqmax_new = KQ_max_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)]; + kqmax_new = warp_reduce_max(kqmax_new); + const float kqmax_scale = sycl::native::exp((float) (KQ_max[j_VKQ] - kqmax_new)); + KQ_max[j_VKQ] = kqmax_new; + +#ifdef GGML_SYCL_F16 + sycl::half2 * VKQ_tmp = (sycl::half2 *) KQ + item_ct1.get_local_id(1) * (V_cols_per_iter * D / 2) + + (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V) * (D / 2); + + const sycl::half2 kqmax_scale_h2 = sycl::half2(kqmax_scale, kqmax_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + const int i_VKQ = + i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V) * + (V_rows_per_thread / 2); + + ggml_sycl_memcpy_1(VKQ_tmp + i_VKQ, + &VKQ[j_VKQ][i_VKQ_0 / nthreads_V]); + } +#else + sycl::float2 * VKQ_tmp = (sycl::float2 *) KQ + item_ct1.get_local_id(1)*(V_cols_per_iter*D/2) + + (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V)*(D/2); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j_VKQ][i_VKQ_0/nthreads_V].x() *= kqmax_scale; + VKQ[j_VKQ][i_VKQ_0/nthreads_V].y() *= kqmax_scale; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + const int i_VKQ = i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*(V_rows_per_thread/2); + + ggml_sycl_memcpy_1(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); + ggml_sycl_memcpy_1(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]); + } +#endif // GGML_SYCL_F16 + + KQ_sum[j_VKQ] *= kqmax_scale; + KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); + if (item_ct1.get_local_id(2) == 0) { + KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(1)] = KQ_sum[j_VKQ]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + + if (nthreads <= D || tid < D) { + KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)]; + KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]); + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += nthreads) { + float dst_val = 0; +#pragma unroll + for (int w = 0; w < nwarps; ++w) { +#pragma unroll + for (int v = 0; v < V_cols_per_iter; ++v) { + dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]); + } + } + if (item_ct1.get_group_range(1) == 1) { + dst_val /= KQ_sum[j_VKQ]; + } + dst[(((sequence * int(ne01.z()) + ic0 + j_VKQ) * ne02 + head) * item_ct1.get_group_range(1) + + item_ct1.get_group(1)) * + D + + i0 + tid] = dst_val; + } + } + + if (j_VKQ < ncols-1) { + item_ct1.barrier(sycl::access::fence_space::local_space); + } + + } + + if (item_ct1.get_group_range(1) != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z()))) { + dst_meta[((sequence * int(ne01.z()) + ic0 + tid) * ne02 + head) * item_ct1.get_group_range(1) + + item_ct1.get_group(1)] = make_float2(KQ_max[tid], KQ_sum[tid]); + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + +#endif // SYCL_FLASH_ATTN +} +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + + +template +void ggml_sycl_flash_attn_ext_vec_case_impl(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + + const int warp_size = WARP_16_SIZE; //better performance than WARP_32_SIZE + + const int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc; + + const int nthreads = ggml_sycl_fattn_vec_get_nthreads_host(cc); + const int nwarps = nthreads / warp_size; + + const bool need_f16_K = type_K == GGML_TYPE_F16; + const bool need_f16_V = type_V == GGML_TYPE_F16; + constexpr size_t nbytes_shared = 0; + + launch_fattn, warp_size>( + ctx, dst, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); +} + +template +void ggml_sycl_flash_attn_ext_vec_case(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (Q->ne[1] == 1) { + constexpr int cols_per_block = 1; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_sycl_flash_attn_ext_vec_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_sycl_flash_attn_ext_vec_case_impl(ctx, dst); + } + return; + } + + constexpr int cols_per_block = 2; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_sycl_flash_attn_ext_vec_case_impl(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_sycl_flash_attn_ext_vec_case_impl(ctx, dst); + } +} + +#define DECL_FATTN_VEC_CASE(D, type_K, type_V) \ + template void ggml_sycl_flash_attn_ext_vec_case \ + (ggml_backend_sycl_context & ctx, ggml_tensor * dst) \ + +#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ + +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) + +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) + +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) + +#endif // GGML_SYCL_FATTN_VEC_HPP diff --git a/ggml/src/ggml-sycl/fattn.cpp b/ggml/src/ggml-sycl/fattn.cpp new file mode 100644 index 00000000000..c276ed89827 --- /dev/null +++ b/ggml/src/ggml-sycl/fattn.cpp @@ -0,0 +1,225 @@ +// +// MIT license +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + + +#include +#include "dpct/helper.hpp" +#include "common.hpp" +#include "fattn-common.hpp" +#include "fattn-tile.hpp" +#include "fattn-vec.hpp" +#include "fattn.hpp" + + +#define FATTN_VEC_CASE(D, type_K, type_V) \ + { \ + const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \ + const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \ + if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \ + ggml_sycl_flash_attn_ext_vec_case(ctx, dst); \ + return; \ + } \ + } \ + +#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \ + FATTN_VEC_CASE( 64, type_K, type_V) \ + FATTN_VEC_CASE(128, type_K, type_V) \ + FATTN_VEC_CASE(256, type_K, type_V) \ + +static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_tensor * Q = dst->src[0]; + ggml_tensor * K = dst->src[1]; + ggml_tensor * V = dst->src[2]; + +#ifdef GGML_SYCL_FA_ALL_QUANTS + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) +#else + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) +#endif // GGML_SYCL_FA_ALL_QUANTS + + GGML_ABORT("Not match KV type in vec"); +} + +// Best FlashAttention kernel for a specific GPU: +enum best_fattn_kernel { + BEST_FATTN_KERNEL_NONE = 0, + BEST_FATTN_KERNEL_VEC = 100, + BEST_FATTN_KERNEL_TILE = 200, +}; + +static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const ggml_tensor * dst) { + GGML_UNUSED(device); +#ifndef SYCL_FLASH_ATTN + GGML_UNUSED(dst); + return BEST_FATTN_KERNEL_NONE; +#endif// SYCL_FLASH_ATTN + + if(!g_ggml_sycl_enable_flash_attention) return BEST_FATTN_KERNEL_NONE; + + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + for (const ggml_tensor * t : {Q, K, V, mask}) { + if (t == nullptr || ggml_is_quantized(t->type)) { + continue; + } + for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { + if (t->nb[i] % 16 != 0) { + gqa_opt_applies = false; + break; + } + } + } + + switch (K->ne[0]) { + case 40: + case 64: + case 72: + case 80: + case 96: + case 128: + case 112: + case 256: + if (V->ne[0] != K->ne[0]) { + return BEST_FATTN_KERNEL_NONE; + } + break; + case 576: + if (V->ne[0] != 512) { + return BEST_FATTN_KERNEL_NONE; + } + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + break; + default: + return BEST_FATTN_KERNEL_NONE; + } + +#ifndef GGML_SYCL_FA_ALL_QUANTS + if (K->type != V->type) { + return BEST_FATTN_KERNEL_NONE; + } +#endif // GGML_SYCL_FA_ALL_QUANTS + + switch (K->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + break; + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: +#ifndef GGML_SYCL_FA_ALL_QUANTS + return BEST_FATTN_KERNEL_NONE; +#endif // GGML_SYCL_FA_ALL_QUANTS + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + break; + default: + return BEST_FATTN_KERNEL_NONE; + } + + if (mask && mask->ne[2] != 1) { + return BEST_FATTN_KERNEL_NONE; + } + + // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + + // Todo: Use the XMX kernel if possible: + + // If there are no tensor cores available, use the generic tile kernel: + if (can_use_vector_kernel) { + if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { + if (Q->ne[1] == 1) { + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_VEC; + } + } + } else { + if (Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + } + } + return BEST_FATTN_KERNEL_TILE; +} + +void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_set_device(ctx.device); + switch (ggml_sycl_get_best_fattn_kernel(ggml_sycl_get_device(), dst)) { + case BEST_FATTN_KERNEL_NONE: + GGML_ABORT("Not support Flash-Attention"); + case BEST_FATTN_KERNEL_TILE: + ggml_sycl_flash_attn_ext_tile(ctx, dst); + break; + case BEST_FATTN_KERNEL_VEC: + ggml_sycl_flash_attn_ext_vec(ctx, dst); + break; + } +} + +bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst) { + return ggml_sycl_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE; +} diff --git a/ggml/src/ggml-sycl/fattn.hpp b/ggml/src/ggml-sycl/fattn.hpp new file mode 100644 index 00000000000..f2a8ffc97de --- /dev/null +++ b/ggml/src/ggml-sycl/fattn.hpp @@ -0,0 +1,22 @@ +// +// MIT license +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_FATTN_HPP +#define GGML_SYCL_FATTN_HPP + +#include "common.hpp" + +void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst); + +#endif // GGML_SYCL_FATTN_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 0614d7e8f3a..dfacde0af33 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -62,6 +62,8 @@ int g_ggml_sycl_disable_graph = 0; int g_ggml_sycl_disable_dnn = 0; int g_ggml_sycl_prioritize_dmmv = 0; int g_ggml_sycl_use_async_mem_op = 0; +int g_ggml_sycl_enable_flash_attention = 1; + static ggml_sycl_device_info ggml_sycl_init() { ggml_sycl_device_info info = {}; @@ -94,11 +96,12 @@ static ggml_sycl_device_info ggml_sycl_init() { info.devices[i].cc = 100 * prop.get_major_version() + 10 * prop.get_minor_version(); - info.devices[i].nsm = prop.get_max_compute_units(); + info.devices[i].nsm = prop.get_max_compute_units() / 16; //16: Number of Xe Cores info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu); info.devices[i].smpbo = prop.get_local_mem_size(); - info.max_work_group_sizes[i] = prop.get_max_work_group_size(); + info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units(); + } for (int id = 0; id < info.device_count; ++id) { @@ -211,7 +214,37 @@ static void ggml_check_sycl() try { g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); + +#ifdef SYCL_FLASH_ATTN + g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1); +#else + g_ggml_sycl_enable_flash_attention = 0; +#endif + GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); + + GGML_LOG_INFO("Build with Macros:\n"); +#if defined(GGML_SYCL_FORCE_MMQ) + GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n"); +#endif +#if defined(GGML_SYCL_F16) + GGML_LOG_INFO(" GGML_SYCL_F16: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_F16: no\n"); +#endif +#if defined(GGML_SYCL_GRAPH) + GGML_LOG_INFO(" GGML_SYCL_GRAPH: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_GRAPH: no\n"); +#endif +#if defined(GGML_SYCL_DNNL) + GGML_LOG_INFO(" GGML_SYCL_DNNL: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_DNNL: no\n"); +#endif + GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize); @@ -226,16 +259,12 @@ static void ggml_check_sycl() try { GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n"); #endif GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); - GGML_LOG_INFO("Build with Macros:\n"); -#if defined(GGML_SYCL_FORCE_MMQ) - GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n"); -#else - GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n"); -#endif -#if defined(GGML_SYCL_F16) - GGML_LOG_INFO(" GGML_SYCL_F16: yes\n"); + +#ifdef SYCL_FLASH_ATTN + GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention); #else - GGML_LOG_INFO(" GGML_SYCL_F16: no\n"); + GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d disabled by compile flag\n", + g_ggml_sycl_enable_flash_attention); #endif /* NOT REMOVE, keep it for next optimize for XMX. @@ -3012,7 +3041,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons } #if GGML_SYCL_DNNL - // oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl + // oneDNN handles strided data and does not need overhead of ggml_get_to_fp16_nc_sycl const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1; src1_f16_alloc.alloc(ne_src1); const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); @@ -3021,7 +3050,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons # else const int64_t ne_src1 = ggml_nelements(src1); src1_f16_alloc.alloc(ne_src1); - const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type); + const to_fp16_nc_sycl_t to_fp16_nc_sycl = ggml_get_to_fp16_nc_sycl(src1->type); GGML_ASSERT(to_fp16_nc_sycl != nullptr); to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue); #endif @@ -4158,6 +4187,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_ARANGE: ggml_sycl_arange(ctx, dst); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_sycl_flash_attn_ext(ctx, dst); + break; default: return false; } @@ -4862,6 +4894,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return op->type == GGML_TYPE_F32; case GGML_OP_ARANGE: return op->type == GGML_TYPE_F32; + case GGML_OP_FLASH_ATTN_EXT: + return ggml_sycl_flash_attn_ext_supported(device, op); default: return false; } diff --git a/ggml/src/ggml-sycl/presets.hpp b/ggml/src/ggml-sycl/presets.hpp index b6517374230..dc4dad1d37a 100644 --- a/ggml/src/ggml-sycl/presets.hpp +++ b/ggml/src/ggml-sycl/presets.hpp @@ -73,4 +73,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA #define MUL_MAT_SRC1_COL_STRIDE 128 #define QK_WARP_SIZE 32 +#define WARP_32_SIZE 32 +#define WARP_16_SIZE 16 + #endif // GGML_SYCL_PRESETS_HPP diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index 15d92e5e04c..fdf9b843e01 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -102,7 +102,7 @@ static void soft_max_f32(const float * x, max_val = sycl::max(max_val, val); } // find the max value in the block - max_val = warp_reduce_max(max_val); + max_val = warp_reduce_max(max_val); if (block_size > WARP_SIZE) { if (warp_id == 0) { @@ -116,7 +116,7 @@ static void soft_max_f32(const float * x, item_ct1.barrier(); max_val = buf_iw[lane_id]; - max_val = warp_reduce_max(max_val); + max_val = warp_reduce_max(max_val); } float tmp = 0.0f; // partial sum @@ -133,7 +133,7 @@ static void soft_max_f32(const float * x, vals[col] = val; } // find the sum of exps in the block - tmp = warp_reduce_sum(tmp); + tmp = warp_reduce_sum(tmp); if (block_size > WARP_SIZE) { item_ct1.barrier(); if (warp_id == 0) { @@ -153,7 +153,7 @@ static void soft_max_f32(const float * x, for (size_t i = 1; i < nreduce; i += 1) { tmp += buf_iw[lane_id + i * WARP_SIZE]; } - tmp = warp_reduce_sum(tmp); + tmp = warp_reduce_sum(tmp); } if (sinks) { tmp += sycl::native::exp(sinks[i02] - max_val); @@ -191,7 +191,7 @@ static void soft_max_back_f32(const float *grad, const float *dstf, float *dst, dgf_dot += dstf[col]*grad[col]; } - dgf_dot = warp_reduce_sum(dgf_dot); + dgf_dot = warp_reduce_sum(dgf_dot); for (int col = tid; col < ncols; col += WARP_SIZE) { dst[col] = scale * (grad[col] - dgf_dot) * dstf[col]; diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp new file mode 100644 index 00000000000..5c06d42fdbd --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(112, 112); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp new file mode 100644 index 00000000000..f74e1202b83 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(128, 128); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp new file mode 100644 index 00000000000..b574fe9308d --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(256, 256); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp new file mode 100644 index 00000000000..8c8fb692c43 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(40, 40); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp new file mode 100644 index 00000000000..f218552e85f --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp new file mode 100644 index 00000000000..99303a53a3c --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(64, 64); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp new file mode 100644 index 00000000000..50592768afd --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(72, 72); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp new file mode 100644 index 00000000000..74f1ea5e90c --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(80, 80); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp new file mode 100644 index 00000000000..cefb46dddc7 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(96, 96); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp new file mode 100644 index 00000000000..32cf4f2859b --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp new file mode 100644 index 00000000000..a61a19021bb --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp new file mode 100644 index 00000000000..63b74fb347a --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp new file mode 100644 index 00000000000..46e2d9853c5 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp new file mode 100644 index 00000000000..7aabb6ff6e4 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp new file mode 100644 index 00000000000..148ea217f62 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp new file mode 100644 index 00000000000..4b169dbcdbc --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp new file mode 100644 index 00000000000..79f530b1815 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp new file mode 100644 index 00000000000..2f7db51ce82 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp new file mode 100644 index 00000000000..9e3bf0b14a1 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp new file mode 100644 index 00000000000..18081879cec --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp new file mode 100644 index 00000000000..1c387b0d87c --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp new file mode 100644 index 00000000000..f005b3762cc --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp new file mode 100644 index 00000000000..3553b1cdd16 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp new file mode 100644 index 00000000000..687ec567115 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp new file mode 100644 index 00000000000..2663bfe7466 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp new file mode 100644 index 00000000000..641b7c7ae2a --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp new file mode 100644 index 00000000000..3d3181d4719 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp new file mode 100644 index 00000000000..85d5026ad4f --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp new file mode 100644 index 00000000000..1e81401a2c9 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp new file mode 100644 index 00000000000..54251473f97 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp new file mode 100644 index 00000000000..d418c1fb21e --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp new file mode 100644 index 00000000000..0f26cfabd09 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp new file mode 100644 index 00000000000..4fb98723519 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp new file mode 100644 index 00000000000..85b79cd1976 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp new file mode 100644 index 00000000000..7348323b28b --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp new file mode 100644 index 00000000000..f19af2aa0ba --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp new file mode 100644 index 00000000000..d7075bac600 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp new file mode 100644 index 00000000000..627f9a57755 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp new file mode 100644 index 00000000000..23304eecd35 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp new file mode 100644 index 00000000000..95acb5d4fbf --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp new file mode 100644 index 00000000000..5e88f4bab8a --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp new file mode 100644 index 00000000000..69f297feb0c --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp new file mode 100644 index 00000000000..455842a9421 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp new file mode 100644 index 00000000000..f7ef7391571 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp new file mode 100644 index 00000000000..1c633bdf2fa --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 43482b3672c..9a267d85a0c 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -650,6 +650,19 @@ static __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int *v, const int *u, return d8_0*d8_1 * sumi; } +template +static __dpct_inline__ T vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const T & d8_0, const T & d8_1) { + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = ggml_sycl_dp4a(v[i], u[i], sumi); + } + + return d8_0*d8_1 * ((T) sumi); +} + template static __dpct_inline__ float vec_dot_q8_1_q8_1_impl(const int *v, const int *u, const sycl::half2 &dm8, From 4b0653a792a1bc5039a91bcf87910252b77fa0ba Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 8 Mar 2026 06:33:48 -0500 Subject: [PATCH 230/831] vulkan: Fix data races in coopmat1 mul_mat(_id) (llama/20084) * vulkan: Fix data races in coopmat1 mul_mat(_id) Add barriers between coopmat store and regular loads. We sort of got away with this because it was the same subgroup accessing the values, but it's still a race and may not work. * switch to subgroup control barriers --- ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 79344d33005..23f3bd8d6d0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -377,6 +377,7 @@ void main() { [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + barrier(); [[unroll]] for (uint col = 0; col < TN; col += storestride) { const uint row_i = dc + cm_col * TN + col + store_c; if (row_i >= _ne1) break; @@ -387,6 +388,7 @@ void main() { data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); } } + barrier(); } } #else @@ -404,18 +406,22 @@ void main() { // Full coopMat is within bounds, but stride_d is not aligned coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease); [[unroll]] for (uint col = 0; col < TN; col += storestride) { data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); } + controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease); } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { // Partial coopMat is within bounds coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease); [[unroll]] for (uint col = 0; col < TN; col += storestride) { if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); } } + controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease); } } } From 8d97f59639e75de2ea885bd27df51d1f48cc4dc1 Mon Sep 17 00:00:00 2001 From: GiantPrince <90118823+GiantPrince@users.noreply.github.com> Date: Sun, 8 Mar 2026 07:38:17 -0400 Subject: [PATCH 231/831] ggml-vulkan: Add ELU op support (llama/20183) * ggml-Vulkan: add ELU support * ggml-Vulkan: remove extra spaces and variables * ggml-Vulkan: fix format issue * ggml-Vulkan: fix format issue * fix whitespace issue * Update Vulkan.csv and ops.md --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 9 +++++++ ggml/src/ggml-vulkan/vulkan-shaders/elu.comp | 27 +++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 ++ 3 files changed, 38 insertions(+) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/elu.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 23d6d39e0e8..0bf7d2e2473 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -744,6 +744,7 @@ struct vk_device_struct { // [src/dst 0=fp32,1=fp16] vk_pipeline pipeline_exp[2]; + vk_pipeline pipeline_elu[2]; vk_pipeline pipeline_gelu[2]; vk_pipeline pipeline_gelu_erf[2]; vk_pipeline pipeline_gelu_quick[2]; @@ -4373,6 +4374,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + CREATE_UNARY(elu) CREATE_UNARY(gelu) CREATE_UNARY(gelu_erf) CREATE_UNARY(gelu_quick) @@ -9241,6 +9243,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const switch (ggml_get_unary_op(dst)) { case GGML_UNARY_OP_EXP: return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_ELU: + return ctx->device->pipeline_elu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_SILU: return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU: @@ -12852,6 +12856,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: @@ -14951,6 +14956,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: @@ -16074,6 +16080,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_EXP: tensor_clone = ggml_exp(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_ELU: + tensor_clone = ggml_elu(ggml_ctx, src_clone[0]); + break; case GGML_UNARY_OP_SILU: tensor_clone = ggml_silu(ggml_ctx, src_clone[0]); break; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp new file mode 100644 index 00000000000..84dcbd8c88f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + float x = float(data_a[i]); + + if (x < 0.0f) { + x = exp(x) - 1; + } + + data_d[i] = D_TYPE(x); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 85455988c57..ed077dfb6c1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -867,6 +867,8 @@ void process_shaders() { string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("elu_f16", "elu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("elu_f32", "elu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); From f099ed27b8669c1ab1ac729788961acf25c53ae5 Mon Sep 17 00:00:00 2001 From: Michael Huang <15768500+tehsiuhuang@users.noreply.github.com> Date: Sun, 8 Mar 2026 21:45:43 -0700 Subject: [PATCH 232/831] cuda : display total and free VRAM capacity during device initialization (llama/20185) --- ggml/src/ggml-cuda/ggml-cuda.cu | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index a8007a06360..0fafaf00931 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -205,7 +205,14 @@ static ggml_cuda_device_info ggml_cuda_init() { GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0; - GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); + for (int id = 0; id < info.device_count; ++id) { + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); + total_vram += prop.totalGlobalMem; + } + GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices (Total VRAM: %zu MiB):\n", + __func__, info.device_count, (size_t)(total_vram / (1024 * 1024))); + total_vram = 0; std::vector> turing_devices_without_mma; for (int id = 0; id < info.device_count; ++id) { @@ -243,6 +250,12 @@ static ggml_cuda_device_info ggml_cuda_init() { #else info.devices[id].supports_cooperative_launch = false; #endif // !(GGML_USE_MUSA) + + // cudaMemGetInfo returns info for the current device + size_t free_mem; + CUDA_CHECK(cudaSetDevice(id)); + CUDA_CHECK(cudaMemGetInfo(&free_mem, NULL)); + #if defined(GGML_USE_HIP) info.devices[id].smpbo = prop.sharedMemPerBlock; @@ -257,22 +270,25 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc += prop.minor * 0x10; } } - GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n", + GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d, VRAM: %zu MiB (%zu MiB free)\n", id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, - device_vmm ? "yes" : "no", prop.warpSize); + device_vmm ? "yes" : "no", prop.warpSize, + (size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024)); #elif defined(GGML_USE_MUSA) // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs. info.devices[id].warp_size = 32; info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100; info.devices[id].cc += prop.minor * 0x10; - GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", - id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB (%zu MiB free)\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no", + (size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024)); #else info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; - GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", - id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB (%zu MiB free)\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no", + (size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024)); std::string device_name(prop.name); if (device_name == "NVIDIA GeForce MX450") { turing_devices_without_mma.push_back({ id, device_name }); From 890c047e306536aa906d76b01eb45623bd12ced4 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Mon, 9 Mar 2026 07:23:45 +0100 Subject: [PATCH 233/831] vulkan: skip zero size tensors in backend copies (llama/20233) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 31 +++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0bf7d2e2473..70e992ce233 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -13253,6 +13253,10 @@ static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, g ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; vk_buffer buf = buf_ctx->dev_buffer; + if (size == 0) { + return; + } + uint32_t val32 = (uint32_t)value * 0x01010101; ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size); } @@ -13262,6 +13266,10 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; vk_buffer buf = buf_ctx->dev_buffer; + if (size == 0) { + return; + } + ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } @@ -13269,12 +13277,20 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + if (size == 0) { + return; + } + vk_buffer buf = buf_ctx->dev_buffer; ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + if (ggml_nbytes(src) == 0) { + return true; + } + if (ggml_backend_buffer_is_vk(src->buffer)) { ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; @@ -13464,6 +13480,10 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); + if (size == 0) { + return; + } + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; vk_context cpy_ctx; @@ -13507,6 +13527,10 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); + if (size == 0) { + return; + } + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); @@ -13533,9 +13557,14 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ } static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { - VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()"); + VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async(" << src << " -> " << dst << ", size=" << ggml_nbytes(src) << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context; + // Skip zero-size tensors + if (ggml_nbytes(src) == 0) { + return true; + } + if (dst->buffer->buft != ggml_backend_vk_get_default_buffer_type(backend_dst)) { return false; } From 65dbf3c31a44501c9cea2995587d6090db4a7d5f Mon Sep 17 00:00:00 2001 From: Bertay Eren <39909689+bertaye@users.noreply.github.com> Date: Mon, 9 Mar 2026 09:24:16 +0300 Subject: [PATCH 234/831] ggml-vulkan: add SGN operator, auto-generate Vulkan.csv and ops.md (llama/20219) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 9 ++++++++ ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp | 21 +++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 ++ 3 files changed, 32 insertions(+) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 70e992ce233..61d112c50a7 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -763,6 +763,7 @@ struct vk_device_struct { vk_pipeline pipeline_ceil[2]; vk_pipeline pipeline_floor[2]; vk_pipeline pipeline_trunc[2]; + vk_pipeline pipeline_sgn[2]; vk_pipeline pipeline_add1_f16_f16; vk_pipeline pipeline_add1_f16_f32; @@ -4393,6 +4394,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(ceil) CREATE_UNARY(floor) CREATE_UNARY(trunc) + CREATE_UNARY(sgn) #undef CREATE_UNARY #define CREATE_UNARY_RTE(name) \ @@ -9281,6 +9283,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_TRUNC: return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_SGN: + return ctx->device->pipeline_sgn[dst->type == GGML_TYPE_F16]; default: break; } @@ -12875,6 +12879,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_FLOOR: case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_SGN: ggml_vk_unary(ctx, compute_ctx, src0, node); break; case GGML_UNARY_OP_XIELU: @@ -15004,6 +15009,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_FLOOR: case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_SGN: return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && @@ -16170,6 +16176,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_TRUNC: tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_SGN: + tensor_clone = ggml_sgn(ggml_ctx, src_clone[0]); + break; default: std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp new file mode 100644 index 00000000000..a9c147bf9ac --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp @@ -0,0 +1,21 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + data_d[i] = D_TYPE(sign(float(data_a[i]))); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ed077dfb6c1..fb8941232bc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -871,6 +871,8 @@ void process_shaders() { string_to_spv("elu_f32", "elu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("sgn_f16", "sgn.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("sgn_f32", "sgn.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); From 3984ae384d7e0011f0a182693089c3cca67a4a6f Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 9 Mar 2026 16:15:36 +0800 Subject: [PATCH 235/831] ggml-cuda: disable gdn for musa (llama/20278) --- ggml/src/ggml-cuda/ggml-cuda.cu | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0fafaf00931..cda275b8c58 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4992,9 +4992,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_LEAKY_RELU: case GGML_OP_RWKV_WKV6: case GGML_OP_GATED_LINEAR_ATTN: - case GGML_OP_GATED_DELTA_NET: case GGML_OP_RWKV_WKV7: return true; + case GGML_OP_GATED_DELTA_NET: + //TODO: enable once MUSA compiler is solved https://github.com/ggml-org/llama.cpp/pull/19504#issuecomment-4018634327 +#ifdef GGML_USE_MUSA + return false; +#else + return true; +#endif // GGML_USE_MUSA case GGML_OP_FLASH_ATTN_EXT: return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op); case GGML_OP_CROSS_ENTROPY_LOSS: From d19c65e9daa7eb0ea383cf666aed9e3580c4a844 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 9 Mar 2026 16:45:11 +0200 Subject: [PATCH 236/831] metal : add upscale (llama/20284) --- ggml/src/ggml-metal/ggml-metal-device.cpp | 23 +++- ggml/src/ggml-metal/ggml-metal-device.m | 2 +- ggml/src/ggml-metal/ggml-metal-impl.h | 2 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 59 +++++---- ggml/src/ggml-metal/ggml-metal.metal | 154 +++++++++++++++++++++- 5 files changed, 211 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 06f3d804590..169c63dd7a4 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1717,12 +1717,29 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_met char base[256]; char name[256]; - snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type)); - snprintf(name, 256, "%s", base); + const int32_t mode_flags = ggml_get_op_params_i32(op, 0); + const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF); + + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS); + + if (mode == GGML_SCALE_MODE_BILINEAR) { + snprintf(base, 256, "kernel_upscale_bilinear_%s", ggml_type_name(op->src[0]->type)); + } else if (mode == GGML_SCALE_MODE_BICUBIC) { + snprintf(base, 256, "kernel_upscale_bicubic_%s", ggml_type_name(op->src[0]->type)); + } else { + snprintf(base, 256, "kernel_upscale_nearest_%s", ggml_type_name(op->src[0]->type)); + } + snprintf(name, 256, "%s_aa=%d", base, antialias); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, antialias, FC_UPSCALE + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } return res; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 4cce414abfe..23bd2b2ab72 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1108,7 +1108,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->type == GGML_TYPE_F32 && (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_POOL_1D: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_POOL_2D: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 383e0d6e93b..bf51055e367 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -83,6 +83,7 @@ #define FC_UNARY 1200 #define FC_BIN 1300 #define FC_SUM_ROWS 1400 +#define FC_UPSCALE 1500 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -890,6 +891,7 @@ typedef struct { float sf1; float sf2; float sf3; + float poffs; } ggml_metal_kargs_upscale; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index b3390352ffc..524e1116629 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3729,32 +3729,43 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - const float sf0 = (float)ne0/op->src[0]->ne[0]; - const float sf1 = (float)ne1/op->src[0]->ne[1]; - const float sf2 = (float)ne2/op->src[0]->ne[2]; - const float sf3 = (float)ne3/op->src[0]->ne[3]; + float sf0 = (float)ne0/op->src[0]->ne[0]; + float sf1 = (float)ne1/op->src[0]->ne[1]; + float sf2 = (float)ne2/op->src[0]->ne[2]; + float sf3 = (float)ne3/op->src[0]->ne[3]; + + const int32_t mode_flags = ggml_get_op_params_i32(op, 0); + + float poffs = 0.5f; + + if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { + poffs = 0.0f; + sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; + sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1; + } ggml_metal_kargs_upscale args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.sf0 =*/ sf0, - /*.sf1 =*/ sf1, - /*.sf2 =*/ sf2, - /*.sf3 =*/ sf3 + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.sf0 =*/ sf0, + /*.sf1 =*/ sf1, + /*.sf2 =*/ sf2, + /*.sf3 =*/ sf3, + /*.poffs =*/ poffs, }; auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index a58e641ad86..5cfd69dd866 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4530,7 +4530,9 @@ kernel void kernel_conv_transpose_2d( uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]); -kernel void kernel_upscale_f32( +constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]]; + +kernel void kernel_upscale_nearest_f32( constant ggml_metal_kargs_upscale & args, device const char * src0, device char * dst, @@ -4556,6 +4558,156 @@ kernel void kernel_upscale_f32( } } +static inline float bilinear_tri(float x) { + return MAX(0.0f, 1.0f - fabs(x)); +} + +kernel void kernel_upscale_bilinear_f32( + constant ggml_metal_kargs_upscale & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3 / args.sf3; + const int64_t i02 = i2 / args.sf2; + + const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs; + const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01))); + const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1)); + const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01)); + + src0 += i03*args.nb03 + i02*args.nb02; + + device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); + + if (FC_upscale_aa) { + const float support0 = MAX(1.0f, 1.0f / args.sf0); + const float invscale0 = 1.0f / support0; + const float support1 = MAX(1.0f, 1.0f / args.sf1); + const float invscale1 = 1.0f / support1; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs; + + int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs)); + int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs)); + + int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs)); + int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs)); + + float sum = 0.0f; + float wsum = 0.0f; + + for (int64_t sy = y_min; sy < y_max; ++sy) { + const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1); + for (int64_t sx = x_min; sx < x_max; ++sx) { + const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0); + const float w = wx * wy; + const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00); + sum += (*src_ptr) * w; + wsum += w; + } + } + + const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f; + dst_ptr[i0] = v; + } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs; + const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00))); + const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1)); + const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00)); + + device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00); + device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00); + device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00); + device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00); + + const float v = + (*src00) * (1.0f - fd0) * (1.0f - fd1) + + (*src10) * fd0 * (1.0f - fd1) + + (*src01) * (1.0f - fd0) * fd1 + + (*src11) * fd0 * fd1; + + dst_ptr[i0] = v; + } + } +} + +static inline float bicubic_weight1(float x) { + const float a = -0.75f; + return ((a + 2) * x - (a + 3)) * x * x + 1; +} + +static inline float bicubic_weight2(float x) { + const float a = -0.75f; + return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; +} + +kernel void kernel_upscale_bicubic_f32( + constant ggml_metal_kargs_upscale & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3 / args.sf3; + const int64_t i02 = i2 / args.sf2; + + const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs; + const int64_t i01 = (int64_t)floor(f01); + const float fd1 = f01 - (float)i01; + + const float w_y0 = bicubic_weight2(fd1 + 1.0f); + const float w_y1 = bicubic_weight1(fd1); + const float w_y2 = bicubic_weight1(1.0f - fd1); + const float w_y3 = bicubic_weight2(2.0f - fd1); + + const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02; + + device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1); + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs; + const int64_t i00 = (int64_t)floor(f00); + const float fd0 = f00 - (float)i00; + + const float w_x0 = bicubic_weight2(fd0 + 1.0f); + const float w_x1 = bicubic_weight1(fd0); + const float w_x2 = bicubic_weight1(1.0f - fd0); + const float w_x3 = bicubic_weight2(2.0f - fd0); + + float sum = 0.0f; + + for (int dy = -1; dy <= 2; ++dy) { + const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy)); + const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3; + + for (int dx = -1; dx <= 2; ++dx) { + const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx)); + const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3; + + const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00); + sum += (*src_ptr) * wx * wy; + } + } + + dst_ptr[i0] = sum; + } +} + kernel void kernel_pad_f32( constant ggml_metal_kargs_pad & args, device const char * src0, From ae21974f4f3ded7c06f97804527009612a0be948 Mon Sep 17 00:00:00 2001 From: Paul Flynn Date: Mon, 9 Mar 2026 10:48:12 -0400 Subject: [PATCH 237/831] metal : extend mul_mv_ext to BF16, Q2_K, Q3_K (llama/20250) Enable mul_mv_ext small-batch kernels (BS 2-8) for BF16, Q2_K, and Q3_K quantization types. These types previously fell through to the slower single-row mul_mv path. BF16 uses the float4 dequantize path (like F16). Q2_K and Q3_K use the float4x4 K-quant path (like Q4_K/Q5_K/Q6_K). Co-authored-by: Claude Opus 4.6 --- ggml/src/ggml-metal/ggml-metal-ops.cpp | 3 +++ ggml/src/ggml-metal/ggml-metal.metal | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 524e1116629..267755d08cc 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1963,6 +1963,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { ( op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function op->src[0]->type == GGML_TYPE_F16 || + op->src[0]->type == GGML_TYPE_BF16 || op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_Q5_0 || @@ -1977,6 +1978,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q5_K || op->src[0]->type == GGML_TYPE_Q6_K || + op->src[0]->type == GGML_TYPE_Q2_K || + op->src[0]->type == GGML_TYPE_Q3_K || false) && (ne11 >= 4 && ne11 <= 8) ) ) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 5cfd69dd866..82ebbb4e409 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3481,6 +3481,13 @@ template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4 template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>; template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, bfloat4, 4, dequantize_bf16_t4>; +template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, bfloat4, 4, dequantize_bf16_t4>; +template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, bfloat4, 4, dequantize_bf16_t4>; +template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>; +#endif + template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>; @@ -3531,6 +3538,16 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>; template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q2_K, 256, dequantize_q2_K>; +template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q2_K, 256, dequantize_q2_K>; +template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q2_K, 256, dequantize_q2_K>; +template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q2_K, 256, dequantize_q2_K>; + +template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q3_K, 256, dequantize_q3_K>; +template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q3_K, 256, dequantize_q3_K>; +template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q3_K, 256, dequantize_q3_K>; +template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q3_K, 256, dequantize_q3_K>; + template void kernel_mul_mv_t_t_impl( args_t args, From cabe3d95f4403994cc8e6e6fbb58c12e10fb9b88 Mon Sep 17 00:00:00 2001 From: Julian Pscheid Date: Mon, 9 Mar 2026 23:32:24 -0700 Subject: [PATCH 238/831] metal: handle command buffer failures gracefully in synchronize (llama/20306) Replace GGML_ABORT("fatal error") in ggml_metal_synchronize() with error flag + return. This aligns synchronize error handling with graph_compute, which already returns GGML_STATUS_FAILED for the same condition. When a command buffer fails (e.g., iOS GPU access revocation during backgrounding, macOS eGPU disconnect, OOM), the backend enters an error state instead of killing the host process. Subsequent graph_compute calls return GGML_STATUS_FAILED immediately. Recovery requires recreating the backend. Failed extra command buffers are properly released on the error path to avoid Metal object leaks. --- ggml/src/ggml-metal/ggml-metal-context.m | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m index 5d3a8ce412a..1136ce99b09 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -75,6 +75,10 @@ // abort ggml_metal_graph_compute if callback returns true ggml_abort_callback abort_callback; void * abort_callback_data; + + // error state - set when a command buffer fails during synchronize + // once set, graph_compute will return GGML_STATUS_FAILED until the backend is recreated + bool has_error; }; ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { @@ -158,6 +162,8 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { res->capture_started = false; res->capture_scope = nil; + res->has_error = false; + res->gf = nil; res->encode_async = nil; for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { @@ -246,7 +252,8 @@ void ggml_metal_synchronize(ggml_metal_t ctx) { if (status == MTLCommandBufferStatusError) { GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); } - GGML_ABORT("fatal error"); + ctx->has_error = true; + return; } } } @@ -262,7 +269,15 @@ void ggml_metal_synchronize(ggml_metal_t ctx) { if (status == MTLCommandBufferStatusError) { GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); } - GGML_ABORT("fatal error"); + + // release this and all remaining command buffers before returning + for (size_t j = i; j < ctx->cmd_bufs_ext.count; ++j) { + [ctx->cmd_bufs_ext[j] release]; + } + [ctx->cmd_bufs_ext removeAllObjects]; + + ctx->has_error = true; + return; } [cmd_buf release]; @@ -414,6 +429,11 @@ bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, con } enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) { + if (ctx->has_error) { + GGML_LOG_ERROR("%s: backend is in error state from a previous command buffer failure - recreate the backend to recover\n", __func__); + return GGML_STATUS_FAILED; + } + // number of nodes encoded by the main thread (empirically determined) const int n_main = MAX(64, 0.1*gf->n_nodes); From bd64b8af4ddc8fcd01eaa3728af124a4984a1639 Mon Sep 17 00:00:00 2001 From: Taimur Ahmad Date: Tue, 10 Mar 2026 11:49:52 +0500 Subject: [PATCH 239/831] ggml-cpu: add RVV repack GEMM and GEMV for quantization types (llama/19121) * ggml-cpu: add rvv ggml_quantize_mat_4x8 for q8_0 Co-authored-by: Rehan Qasim * ggml-cpu: add rvv repacking for iq4_nl * ggml-cpu: add generic impl for iq4_nl gemm/gemv * ggml-cpu: add rvv repacking for q8_0 * ggml-cpu: refactor; add rvv repacking for q4_0, q4_K * ggml-cpu: refactor; add rvv repacking for q2_K Co-authored-by: Rehan Qasim * ggml-cpu: refactor rvv repack --------- Co-authored-by: Rehan Qasim --- ggml/src/ggml-cpu/arch-fallback.h | 3 +- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 1391 +++++++++++++++++++++++ ggml/src/ggml-cpu/repack.cpp | 1205 +++++++++++++++++++- ggml/src/ggml-cpu/repack.h | 61 +- 4 files changed, 2651 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index ebbd4b47e05..48315610f2f 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -202,8 +202,9 @@ #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 // repack.cpp +#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 -#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x1_generic ggml_quantize_mat_q8_K_4x1 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 2a35ff9ad87..cd5807879ea 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -24,6 +24,94 @@ #define UNUSED GGML_UNUSED +void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + +#if defined(__riscv_v_intrinsic) + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + const size_t vl_calc = __riscv_vsetvl_e32m8(QK8_0); + const size_t vl_save = __riscv_vsetvl_e64m2(4); + vfloat32m1_t v_scalar_zero = __riscv_vfmv_s_f_f32m1(0.0f, __riscv_vsetvl_e32m1(1)); + + for (int i = 0; i < nb; i++) { + const float *x_block_base = x + i * QK8_0; + vint8m2_t q_r0, q_r1, q_r2, q_r3; + { + vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 0 * k, vl_calc); + vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc); + vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc); + float amax = __riscv_vfmv_f_s_f32m1_f32(v_max); + + float d = amax / 127.0f; + y[i].d[0] = GGML_CPU_FP32_TO_FP16(d); + + float id = d ? 1.0f / d : 0.0f; + vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc); + vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc); + q_r0 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc); + } + asm volatile ("" ::: "memory"); + + { + vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 1 * k, vl_calc); + vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc); + vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc); + float amax = __riscv_vfmv_f_s_f32m1_f32(v_max); + + float d = amax / 127.0f; + y[i].d[1] = GGML_CPU_FP32_TO_FP16(d); + float id = d ? 1.0f / d : 0.0f; + + vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc); + vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc); + q_r1 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc); + } + asm volatile ("" ::: "memory"); + { + vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 2 * k, vl_calc); + vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc); + vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc); + float amax = __riscv_vfmv_f_s_f32m1_f32(v_max); + + float d = amax / 127.0f; + y[i].d[2] = GGML_CPU_FP32_TO_FP16(d); + float id = d ? 1.0f / d : 0.0f; + + vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc); + vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc); + q_r2 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc); + } + asm volatile ("" ::: "memory"); + { + vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 3 * k, vl_calc); + vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc); + vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc); + float amax = __riscv_vfmv_f_s_f32m1_f32(v_max); + + float d = amax / 127.0f; + y[i].d[3] = GGML_CPU_FP32_TO_FP16(d); + float id = d ? 1.0f / d : 0.0f; + + vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc); + vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc); + q_r3 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc); + } + vint64m2_t v_q64_r0 = __riscv_vreinterpret_v_i8m2_i64m2(q_r0); + vint64m2_t v_q64_r1 = __riscv_vreinterpret_v_i8m2_i64m2(q_r1); + vint64m2_t v_q64_r2 = __riscv_vreinterpret_v_i8m2_i64m2(q_r2); + vint64m2_t v_q64_r3 = __riscv_vreinterpret_v_i8m2_i64m2(q_r3); + vint64m2x4_t v_quant_tuple = __riscv_vcreate_v_i64m2x4(v_q64_r0, v_q64_r1, v_q64_r2, v_q64_r3); + __riscv_vsseg4e64_v_i64m2x4((int64_t*)y[i].qs, v_quant_tuple, vl_save); + } +#else + UNUSED(nb); + UNUSED(y); + ggml_quantize_mat_q8_0_4x4_generic(x, vy, k); +#endif +} + void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -115,6 +203,486 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + + // 1x16 Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 1x16 Integer Accumulator + vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK4_0 / 2; i++) { + // Load `b_ptr`. + const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16); + const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); + + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], b_0_lo, 16); + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], b_0_hi, 16); + } + + const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + } + + __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + } + return; +#endif + ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const block_q8_K * a_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + + // 1x16 Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16); + + // Load `dmin`. + const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d, 16); + + // We process 4 sub-blocks at once. + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64); + + // High bits. + vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. + vint32m2_t bsums = __riscv_vmv_v_x_i32m2(0, 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8] + a_ptr[l].bsums[j * 8 + 1], __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 2] + a_ptr[l].bsums[j * 8 + 3], __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + + sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, 16), 16), 16); + + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + i], b_s_0, 16); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 32 + i], b_s_1, 16); + } + + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_s_0_16, 16); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_s_1_16, 16); + } + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + 64 + i], b_s_0, 16); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 96 + i], b_s_1, 16); + } + + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_s_0_16, 16); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_s_1_16, 16); + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], 16), 16); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, 16); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + } + + __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + } + return; +#endif + ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + + // 1x16 Accumulator1 + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 1x16 integer accumulator + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK4_NL / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16); + const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16); + // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); + // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); + + const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], 16); + const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], 16); + sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + } + + __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + } + return; +#endif + ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + UNUSED(bs); + +#if defined __riscv_v_intrinsic + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + + // 1x16 Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 1x16 Integer Accumulator + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK8_0; i++) { + // Load `b_ptr`. + const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16); + + sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + } + + __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + } + return; +#endif + ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr == 1); + assert(nc % 16 == 0); + + UNUSED(bs); + + const int N_COLS_TILE = 16; + const int num_k_blocks = n / QK_K; + + const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); + for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { + + const block_q8_K* lhs_base_ptr = (const block_q8_K*)vy; + const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks; + + vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int k_block = 0; k_block < num_k_blocks; ++k_block) { + const block_q8_K* lhs_current = &lhs_base_ptr[k_block]; + const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block]; + + // 1. Prepare Global Min Scales + vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl); + vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl); + + vfloat32m2_t v_g_min_final = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d, vl); + + vint32m2_t v_isum = __riscv_vmv_v_x_i32m2(0, vl); + + const uint8_t* rhs_qs_ptr = rhs_current->qs; + const uint8_t* rhs_sc_ptr = rhs_current->scales; + const int8_t* lhs_qs_ptr = lhs_current->qs; + + // --- Phase Loop (4 phases x 64 elements) --- + for (int phase = 0; phase < 4; ++phase) { + + // A. Load Scales/Mins + vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3; + vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3; + + { + vuint8mf2_t v_raw; + // Sub-block 0 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl); + v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 1 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl); + v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 2 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl); + v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 3 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl); + v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + rhs_sc_ptr += 64; + } + + int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16); + int k_offsets[4] = {0, 32, 64, 96}; + + // B. Inner Dot Product Loop + for (int l = 0; l < 16; ++l) { + vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += 16; + + // Sub-block 0 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[0] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + // Sub-block 1 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[1] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + // Sub-block 2 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[2] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + // Sub-block 3 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[3] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + } + + // correction + int sb_base_abs = base_k_phase / 16; + + // Sub-block 0 + { + int sb_idx = sb_base_abs + (k_offsets[0] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + // Sub-block 1 + { + int sb_idx = sb_base_abs + (k_offsets[1] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + // Sub-block 2 + { + int sb_idx = sb_base_abs + (k_offsets[2] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + // Sub-block 3 + { + int sb_idx = sb_base_abs + (k_offsets[3] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + + } // End Phase Loop + + // Apply global Scales + vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl); + vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl); + + vfloat32m2_t v_g_all_final = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d, vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all_final, vl); + v_sumf = __riscv_vfadd_vv_f32m2(v_sumf, v_sum, vl); + + } // End K-Block + __riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl); + + } +} + void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -340,3 +908,826 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } + +void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 4x16 integer accumulators + vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK4_0 / 2; i++) { + // Load `b_ptr`. + const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16); + const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); + + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], b_0_lo, 16); + sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], b_0_lo, 16); + sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], b_0_lo, 16); + sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], b_0_lo, 16); + + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], b_0_hi, 16); + sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], b_0_hi, 16); + sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], b_0_hi, 16); + sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], b_0_hi, 16); + } + + // Do the final accumulation in i32 to prevent overflow. + const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); + const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, 16); + const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, 16); + const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, 16); + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + } + } + return; +#endif + ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16); + + // Load `dmin`. + const vfloat32m2_t dmins = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16); + + // We process 4 sub-blocks at once. + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64); + + // High bits. + vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. + vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, 16); + + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + + const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[0], 16); + const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[1], 16); + const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[2], 16); + const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, 16), 16), 16); + sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, 16), 16), 16); + sumf_2 = __riscv_vfsub_vv_f32m2(sumf_2, __riscv_vfmul_vv_f32m2(dmins_d_2, __riscv_vfcvt_f_x_v_f32m2(bsums_2, 16), 16), 16); + sumf_3 = __riscv_vfsub_vv_f32m2(sumf_3, __riscv_vfmul_vv_f32m2(dmins_d_3, __riscv_vfcvt_f_x_v_f32m2(bsums_3, 16), 16), 16); + + + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + // 4x16 integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + i * 4], b_s_0, 16); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 1], b_s_0, 16); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 2], b_s_0, 16); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 3], b_s_0, 16); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4], b_s_1, 16); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 1], b_s_1, 16); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 2], b_s_1, 16); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 3], b_s_1, 16); + } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_0_s_0_16, 16); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_0_s_1_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_1_s_0_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_1_s_1_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_2_s_0_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_2_s_1_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_3_s_0_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_3_s_1_16, 16); + } + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + // 4x16 integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4], b_s_0, 16); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 1], b_s_0, 16); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 2], b_s_0, 16); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 3], b_s_0, 16); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4], b_s_1, 16); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 1], b_s_1, 16); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 2], b_s_1, 16); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 3], b_s_1, 16); + } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_0_s_0_16, 16); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_0_s_1_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_1_s_0_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_1_s_1_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_2_s_0_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_2_s_1_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_3_s_0_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_3_s_1_16, 16); + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16), 16); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + } + } + return; +#endif + ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 4x16 integer accumulators + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK4_NL / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16); + const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16); + // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); + // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); + + const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], 16); + const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], 16); + const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], 16); + const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 3], 16); + + const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4], 16); + const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 1], 16); + const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 2], 16); + const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 3], 16); + + sumi_0 = __riscv_vadd_vv_i32m2(sumi_0, __riscv_vwadd_vv_i32m2(sumi_0_lo, sumi_0_hi, 16), 16); + sumi_1 = __riscv_vadd_vv_i32m2(sumi_1, __riscv_vwadd_vv_i32m2(sumi_1_lo, sumi_1_hi, 16), 16); + sumi_2 = __riscv_vadd_vv_i32m2(sumi_2, __riscv_vwadd_vv_i32m2(sumi_2_lo, sumi_2_hi, 16), 16); + sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + } + } + return; +#endif + ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined __riscv_v_intrinsic + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 4x16 Integer Accumulators + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK8_0; i++) { + // Load `b_ptr`. + const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16); + + sumi_0 = __riscv_vwadd_wv_i32m2(sumi_0, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 0], 16), 16); + sumi_1 = __riscv_vwadd_wv_i32m2(sumi_1, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 1], 16), 16); + sumi_2 = __riscv_vwadd_wv_i32m2(sumi_2, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 2], 16), 16); + sumi_3 = __riscv_vwadd_wv_i32m2(sumi_3, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 3], 16), 16); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + } + } + return; +#endif + ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + const int num_k_blocks = n / QK_K; + const int N_ROWS_TILE = 4; + const int N_COLS_TILE = 16; + assert(nr % N_ROWS_TILE == 0); + assert(nc % N_COLS_TILE == 0); + + const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); + // --- Tiling Loops --- +#pragma GCC unroll 1 + for (int row_tile = 0; row_tile < nr; row_tile += N_ROWS_TILE) { +#pragma GCC unroll 1 + for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { + // Base Pointers + const block_q8_Kx4* lhs_base_ptr = (const block_q8_Kx4*)vy + (row_tile / N_ROWS_TILE) * num_k_blocks; + const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks; + + // Persistent Float Accumulators + vfloat32m2_t v_sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + // --- Super-Block Loop (K=0..255) --- +#pragma GCC unroll 1 + for (int k_block = 0; k_block < num_k_blocks; ++k_block) { + const block_q8_Kx4* lhs_current = &lhs_base_ptr[k_block]; + const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block]; + + // 1. Load Global Min Scales (Keep as F16/LMUL=1 to save registers) + vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl); + vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl); + + // 2. Initialize Integer Accumulators + vint32m2_t v_isum_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_3 = __riscv_vmv_v_x_i32m2(0, vl); + + const uint8_t* rhs_qs_ptr = rhs_current->qs; + const uint8_t* rhs_sc_ptr = rhs_current->scales; + const int8_t* lhs_qs_ptr = lhs_current->qs; + + // --- Phase Loop (4 phases x 64 elements) --- +#pragma GCC unroll 1 + for (int phase = 0; phase < 4; ++phase) { + + // A. Load Scales/Mins for the 4 interleaved sub-blocks + vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3; + vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3; + + // Unrolled Load Logic + { + vuint8mf2_t v_raw; + // Sub-block 0 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl); + v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 1 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl); + v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 2 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl); + v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 3 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl); + v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + rhs_sc_ptr += 64; + } + + int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16); + int k_offsets[4] = {0, 32, 64, 96}; + + // B. Inner Dot Product Loop +#pragma GCC unroll 1 + for (int l = 0; l < 16; ++l) { + vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += 16; + + // Unroll over 4 sub-blocks (0, 1, 2, 3 relative to phase) + + // --- Sub-block 0 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[0] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } + // --- Sub-block 1 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[1] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } + // --- Sub-block 2 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[2] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } + // --- Sub-block 3 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[3] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } + } + + // C CORRECTION + int sb_base_abs = base_k_phase / 16; + + // --- Correction Sub-block 0 --- + { + int sb_abs = sb_base_abs + (k_offsets[0] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0); + + // Row 0 + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + // Row 1 + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + // Row 2 + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + // Row 3 + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); + } + + // --- Correction Sub-block 1 --- + { + int sb_abs = sb_base_abs + (k_offsets[1] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1); + + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); + } + + // --- Correction Sub-block 2 --- + { + int sb_abs = sb_base_abs + (k_offsets[2] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2); + + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); + } + + // --- Correction Sub-block 3 --- + { + int sb_abs = sb_base_abs + (k_offsets[3] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3); + + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); + } + + } // End Phase Loop + + // --- Apply Main Scales --- + vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl); + vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl); + + { + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[0], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_0, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_0 = __riscv_vfadd_vv_f32m2(v_sumf_0, v_sum, vl); + } + // Row 1 + { + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[1], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_1, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_1 = __riscv_vfadd_vv_f32m2(v_sumf_1, v_sum, vl); + } + // Row 2 + { + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[2], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_2, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_2 = __riscv_vfadd_vv_f32m2(v_sumf_2, v_sum, vl); + } + // Row 3 + { + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[3], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_3, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_3 = __riscv_vfadd_vv_f32m2(v_sumf_3, v_sum, vl); + } + + } // End K-Block + + __riscv_vse32_v_f32m2(s + (row_tile + 0) * bs + col_tile, v_sumf_0, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 1) * bs + col_tile, v_sumf_1, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 2) * bs + col_tile, v_sumf_2, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 3) * bs + col_tile, v_sumf_3, vl); + } + } +} diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 02c3cc3119b..6b76ab3bfb1 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -48,6 +48,90 @@ static inline int nearest_int(float fval) { extern "C" { +#if defined __riscv_zvfh +void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + + // scalar + const int blck_size_interleave = 1; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +} + +void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK_K == 256); + assert(k % QK_K == 0); + const int nb = k / QK_K; + + block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; + + const int blck_size_interleave = 1; + float srcv[4][QK_K]; + float iscale[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + float max = 0; + + for (int j = 0; j < QK_K; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK_K + j]; + // Update the maximum value of the corresponding super block + if(amax < fabsf(srcv[row_iter][j])) { + amax = fabsf(srcv[row_iter][j]); + max = srcv[row_iter][j]; + } + } + + iscale[row_iter] = amax ? -127.f/max : 0; + y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0; + } + + for (int j = 0; j < QK_K / 4; j++) { + y[i].bsums[j] = 0; + } + for (int j = 0; j < QK_K * 4; j++) { + int src_id = j % 4; + int src_offset = j / 4; + int index = ((j >> 6) << 2) + (j & 3); + + float x0 = srcv[src_id][src_offset] * iscale[src_id]; + y[i].qs[j] = nearest_int(x0); + y[i].bsums[index] += y[i].qs[j]; + } + } +} +#endif + void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); @@ -124,7 +208,6 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG } } - void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK_K == 256); assert(k % QK_K == 0); @@ -256,6 +339,20 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); } +#if defined __riscv_zvfh +template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_0_4x1(x, vy, n_per_row); +} + +template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_K_4x1(x, vy, n_per_row); +} +#endif + template static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n, float * GGML_RESTRICT s, @@ -1268,6 +1365,294 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n, } } +#if defined __riscv_zvfh +void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[16]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + float sumf[16]; + float sum_minf[16]; + uint8_t scales[128]; + uint8_t mins[128]; + int sumi1; + int sumi2; + int sumi; + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + sum_minf[j] = 0.0f; + } + for (int l = 0; l < nb; l++) { + for (int i = 0; i < 128; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; + } + for (int i = 0; i < 64; i++) { + scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; + scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30); + mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2; + } + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * 16]; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += min[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * 16]; + uint8_t *scales_1 = &scales[(sb + 1) * 16]; + for (int i = 0; i < QK4_0; i++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4); + sumi1 = (v0 * a_ptr[l].qs[sb * 32 + i]); + sumi2 = (v1 * a_ptr[l].qs[sb * 32 + 32 + i]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + +void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[16]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[16]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * blocklen + i]; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr == 1); + assert(nc % 16 == 0); + + UNUSED(bs); + + const int nb = n / QK_K; + const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; + const block_q8_K * y = (const block_q8_K *)vy; + + // Layout: Even-Low(0,2,4,6), Odd-Low(1,3,5,7), Even-High(8...), Odd-High(9...) + const int sb_perm[16] = { + 0, 4, 1, 5, 2, 6, 3, 7, // 0-7 + 8, 12, 9, 13, 10, 14, 11, 15 // 8-15 + }; + + for (int col_tile = 0; col_tile < nc; col_tile += 16) { + const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; + const block_q8_K * y_ptr = y; + + float sumf[16] = {0}; + + // Loop over K-blocks + for (int k_block = 0; k_block < nb; ++k_block) { + int32_t isum[16] = {0}; + int32_t summs[16] = {0}; + + const uint8_t * qs_rhs = x_ptr[k_block].qs; + const uint8_t * sc_rhs = x_ptr[k_block].scales; + const int8_t * qs_lhs = y_ptr[k_block].qs; + const int16_t * bs_lhs = y_ptr[k_block].bsums; + + // Iterate over sub-blocks 0..15 + for (int sb = 0; sb < 16; ++sb) { + // Correction Term + int16_t bsum = bs_lhs[sb]; + int scale_offset = sb_perm[sb] * 16; + + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + summs[col] += bsum * (sc_val >> 4); // Min is high 4 bits + } + + // Main Dot Product + // Calculate base offsets for Q2 unpacking based on SB + int byte_base; + if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; + else byte_base = (sb % 2 == 0) ? 32 : 48; + + int shift = ((sb / 2) % 4) * 2; + + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + int32_t d_sb = sc_val & 0xF; // Scale is low 4 bits + + // Process 16 elements (l=0..15) + for (int l = 0; l < 16; ++l) { + // Q2: Interleaved by column. Byte `l` contains 4 k-values. + int qs_idx = (byte_base + l) * 16 + col; + uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; + + // Q8: Linear access + int k = sb * 16 + l; + int8_t q8_val = qs_lhs[k]; + + isum[col] += q8_val * q2_val * d_sb; + } + } + } + + // Finalize K-Block + for (int col = 0; col < 16; ++col) { + float d_lhs = y_ptr[k_block].d; + float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); + float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); + + float d_all = d_lhs * d_rhs; + float d_min = d_lhs * dm_rhs; + + sumf[col] += (isum[col] * d_all) - (summs[col] * d_min); + } + } + + for (int col = 0; col < 16; ++col) { + s[col_tile + col] = sumf[col]; + } + } +} +#endif + void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -1942,6 +2327,8 @@ void ggml_gemm_q8_0_4x4_q8_0_generic(int n, } } + + void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, @@ -1994,6 +2381,342 @@ void ggml_gemm_q8_0_4x8_q8_0_generic(int n, } } +#if defined __riscv_zvfh +void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][16]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][16]; + float sum_minf[4][16]; + uint8_t scales[128]; + uint8_t mins[128]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int i = 0; i < 128; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; + } + for (int i = 0; i < 64; i++) { + scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; + scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30); + mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2; + } + + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * 16]; + for(int m = 0; m < 4; m++) { + const int16_t bsums = a_ptr[l].bsums[sb * 8 + m] + a_ptr[l].bsums[sb * 8 + m + 4]; + for(int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += min[j] * bsums * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * 16]; + uint8_t *scales_1 = &scales[(sb + 1) * 16]; + + for (int i = 0; i < QK4_0; i++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + + const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4); + sumi1 = (v0 * a_ptr[l].qs[sb * 4 * 32 + i * 4 + m]); + sumi2 = (v1 * a_ptr[l].qs[sb * 4 * 32 + 32 * 4 + i * 4 + m]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + +void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][16]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4])); + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][16]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; + } + sumf[m][j] += + sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + + +void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr % 4 == 0); + assert(nc % 16 == 0); + const int nb = n / QK_K; + const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; + const block_q8_Kx4 * y = (const block_q8_Kx4 *)vy; + + const int sb_perm[16] = { + 0, 4, 1, 5, 2, 6, 3, 7, + 8, 12, 9, 13, 10, 14, 11, 15 + }; + + // Iterate Rows in tiles of 4 + for (int row_tile = 0; row_tile < nr; row_tile += 4) { + // Iterate Columns in tiles of 16 + for (int col_tile = 0; col_tile < nc; col_tile += 16) { + + const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; + const block_q8_Kx4 * y_ptr = y + (row_tile / 4) * nb; + + float sumf[4][16]; + memset(sumf, 0, sizeof(sumf)); + + for (int k_block = 0; k_block < nb; ++k_block) { + int32_t isum[4][16]; + int32_t summs[4][16]; + memset(isum, 0, sizeof(isum)); + memset(summs, 0, sizeof(summs)); + + const uint8_t * qs_rhs = x_ptr[k_block].qs; + const uint8_t * sc_rhs = x_ptr[k_block].scales; + const int8_t * qs_lhs = y_ptr[k_block].qs; + const int16_t * bs_lhs = y_ptr[k_block].bsums; + + for (int sb = 0; sb < 16; ++sb) { + int scale_offset = sb_perm[sb] * 16; + + int byte_base; + if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; + else byte_base = (sb % 2 == 0) ? 32 : 48; + int shift = ((sb / 2) % 4) * 2; + + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + int32_t d_sb = sc_val & 0xF; + int32_t m_sb = sc_val >> 4; + + // Correction Term + for (int r = 0; r < 4; ++r) { + int bsum_idx = (sb / 4) * 16 + r * 4 + (sb % 4); + summs[r][col] += bs_lhs[bsum_idx] * m_sb; + } + + // Main Dot Product + for (int l = 0; l < 16; ++l) { + int qs_idx = (byte_base + l) * 16 + col; + uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; + + // Calculate Q8 index for this specific k and row + int k = sb * 16 + l; + int q8_idx = (k / 4) * 16 + (k % 4); + + for (int r = 0; r < 4; ++r) { + // Add r*4 to jump to the correct row within the 4x4 chunk + int8_t q8_val = qs_lhs[q8_idx + r * 4]; + isum[r][col] += q8_val * q2_val * d_sb; + } + } + } + } + + // Finalize K-Block + for (int col = 0; col < 16; ++col) { + float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); + float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); + + for (int r = 0; r < 4; ++r) { + float d_lhs = y_ptr[k_block].d[r]; + float d_all = d_lhs * d_rhs; + float d_min = d_lhs * dm_rhs; + sumf[r][col] += (isum[r][col] * d_all) - (summs[r][col] * d_min); + } + } + } + + for (int r = 0; r < 4; ++r) { + for (int col = 0; col < 16; ++col) { + s[(row_tile + r) * bs + (col_tile + col)] = sumf[r][col]; + } + } + } + } +} +#endif + } // extern "C" static block_q8_0x4 make_block_q8_0x4(block_q8_0 * in, unsigned int blck_size_interleave) { @@ -2082,6 +2805,31 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in return out; } +static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x16 out; + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_0 * 8 / blck_size_interleave; + + if (blck_size_interleave == 1) { + const uint8_t xor_mask = 0x88; + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = i / 16; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset] ^ xor_mask; + } + } else { + GGML_ASSERT(false); + } + + return out; +} + static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) { block_q4_Kx8 out; //Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure @@ -2159,6 +2907,58 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in return out; } +static block_q4_Kx16 make_block_q4_Kx16(block_q4_K * in, unsigned int blck_size_interleave) { + block_q4_Kx16 out; + //Delta(scale) and dmin values of the 16 Q4_K structures are copied onto the output interleaved structure + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < 16; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + const int end = QK_K * 8 / blck_size_interleave; + + if (blck_size_interleave == 1) { + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = i / 16; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + + // RVV repacking. + // + // Extract sums and mins for all 8 sub-blocks for each block of Q4_K. + uint8_t s[128], m[128]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 16; j++) { + s[i * 16 + j] = in[j].scales[i] & 63; + m[i * 16 + j] = in[j].scales[i + 4] & 63; + } + } + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 16; j++) { + s[64 + i * 16 + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[64 + i * 16 + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + } + } + + for (int i = 0; i < 128; i++) { + out.scales[i] = (s[i] & 15) | ((m[i] & 15) << 4); + } + for (int i = 0; i < 64; i++) { + out.scales[128 + i] = ((s[i] & 48) >> 4) | ((m[i] & 48) >> 2) | (s[64 + i] & 48) | ((m[64 + i] & 48) << 2); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_interleave) { block_q2_Kx8 out; @@ -2332,6 +3132,68 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in return out; } +static block_q2_Kx16 make_block_q2_Kx16(const block_q2_K * in, unsigned int blck_size_interleave) { + block_q2_Kx16 out; + constexpr int N_COLS = 16; + + // 1. Copy Super-Scales (d) and Super-Mins (dmin) + for (int i = 0; i < N_COLS; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + // 2. Interleave Q2_K Data + const int bytes_per_col = 64; + const int total_bytes = N_COLS * bytes_per_col; + const int end = total_bytes / blck_size_interleave; + + for (int i = 0; i < end; ++i) { + int src_col_id = i % N_COLS; + int src_offset = (i / N_COLS) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + memcpy(&out.qs[dst_offset], &in[src_col_id].qs[src_offset], blck_size_interleave); + } + + // 3. Repack Scales into the Optimized "Sequential-Parallel" Layout + int out_idx = 0; + + // Arrays define the sub-block order for each group + const int even_low_sbs[] = {0, 2, 4, 6}; + const int odd_low_sbs[] = {1, 3, 5, 7}; + const int even_high_sbs[] = {8, 10, 12, 14}; + const int odd_high_sbs[] = {9, 11, 13, 15}; + + // Pack Group 1: Even-Low + for (int sb : even_low_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 2: Odd-Low + for (int sb : odd_low_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 3: Even-High + for (int sb : even_high_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 4: Odd-High + for (int sb : odd_high_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + return out; +} + static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_0); GGML_ASSERT(interleave_block == 4 || interleave_block == 8); @@ -2394,6 +3256,36 @@ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q4_K_to_q4_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + constexpr int nrows_interleaved = 16; + + block_q4_Kx16 * dst = (block_q4_Kx16*)t->data; + const block_q4_K * src = (const block_q4_K*) data; + block_q4_K dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_Kx16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q2_K); GGML_ASSERT(interleave_block == 8); @@ -2425,6 +3317,71 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q2_K_to_q2_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q2_K); + constexpr int nrows_interleaved = 16; + + block_q2_Kx16 * dst = (block_q2_Kx16*)t->data; + const block_q2_K * src = (const block_q2_K*) data; + + block_q2_K dst_tmp[nrows_interleaved]; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + // This loop gathers 16 separate blocks (one from each column) + // that correspond to the same K-dimension chunk. + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + + *dst++ = make_block_q2_Kx16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + constexpr int nrows_interleaved = 16; + + block_q4_0x16 * dst = (block_q4_0x16*)t->data; + const block_q4_0 * src = (const block_q4_0*) data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, @@ -2549,6 +3506,60 @@ static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t, return 0; } +static block_q8_0x16 make_block_q8_0x16(block_q8_0 * in, unsigned int blck_size_interleave) { + block_q8_0x16 out; + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + const int end = QK8_0 * 16 / blck_size_interleave; + + if (blck_size_interleave == 1) { + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = i / 16; + int dst_offset = i; + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_q8_0_to_q8_0_16_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + constexpr int nrows_interleaved = 16; + + block_q8_0x16 * dst = (block_q8_0x16 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + block_q8_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q8_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) { block_iq4_nlx4 out; @@ -2674,6 +3685,62 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b GGML_UNUSED(data_size); } +static block_iq4_nlx16 make_block_iq4_nlx16(block_iq4_nl * in, unsigned int blck_size_interleave) { + block_iq4_nlx16 out; + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_NL * 8 / blck_size_interleave; + + if (blck_size_interleave == 1) { + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = i / 16; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); + GGML_ASSERT(interleave_block == 1); + + const block_iq4_nl * src = (const block_iq4_nl *)data; + block_iq4_nlx16 * dst = ( block_iq4_nlx16 *)t->data; + + block_iq4_nl dst_tmp[16]; + + int nrow = ggml_nrows(t); + int nrows_interleaved = 16; + int nblocks = t->ne[0] / QK4_NL; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); + + if (t->ne[1] % nrows_interleaved != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_iq4_nlx16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) { block_mxfp4x4 out; @@ -2864,6 +3931,28 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size); } +#if defined __riscv_zvfh +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 1, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_16_bl(t, 1, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_16_bl(t, 1, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q8_0_to_q8_0_16_bl(t, 1, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_K_to_q2_K_16_bl(t, 1, data, data_size); +} +#endif + // gemv template void gemv(int, float *, size_t, const void *, const void *, int, int); @@ -2939,6 +4028,28 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +#endif + // gemm template void gemm(int, float *, size_t, const void *, const void *, int, int); @@ -3014,6 +4125,28 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +#endif + class tensor_traits_base : public ggml::cpu::tensor_traits { public: virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; @@ -3422,9 +4555,20 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q8_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits q8_0_4x8_q8_0; + // instances for RISC-V + // + // These implement outer-product style matrix multiplication kernels with + // an interleave of 1. +#if defined __riscv_zvfh + static const ggml::cpu::repack::tensor_traits q4_0_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits q4_K_16x1_q8_K; + static const ggml::cpu::repack::tensor_traits iq4_nl_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits q8_0_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits q2_K_16x1_q8_K; +#endif + if (cur->type == GGML_TYPE_Q4_0) { - if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) - || (ggml_cpu_has_riscv_v() && (ggml_cpu_get_rvv_vlen() >= QK4_0))) { + if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) { if (cur->ne[1] % 8 == 0) { return &q4_0_8x8_q8_0; } @@ -3439,6 +4583,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q4_0_4x4_q8_0; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &q4_0_16x1_q8_0; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } } else if (cur->type == GGML_TYPE_Q4_K) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { @@ -3455,12 +4610,34 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q4_K_8x4_q8_K; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &q4_K_16x1_q8_K; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } } else if (cur->type == GGML_TYPE_Q2_K) { if (ggml_cpu_has_avx512()) { if (cur->ne[1] % 8 == 0) { return &q2_K_8x8_q8_K; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &q2_K_16x1_q8_K; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } } else if (cur->type == GGML_TYPE_Q5_K) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { if (cur->ne[1] % 8 == 0) { @@ -3494,6 +4671,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &iq4_nl_4x4_q8_0; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &iq4_nl_16x1_q8_0; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } } else if (cur->type == GGML_TYPE_MXFP4) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { @@ -3516,6 +4704,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q8_0_4x4_q8_0; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &q8_0_16x1_q8_0; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } } return nullptr; diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index b9f821630c4..cb21edf6239 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -28,13 +28,17 @@ template struct block { // control size static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding"); static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding"); +static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<4,16> size/padding"); static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding"); static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding"); +static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding"); using block_q4_0x4 = block<4, 4>; using block_q4_0x8 = block<4, 8>; +using block_q4_0x16 = block<4, 16>; using block_q8_0x4 = block<8, 4>; using block_q8_0x8 = block<8, 8>; +using block_q8_0x16 = block<8, 16>; struct block_q4_Kx8 { ggml_half d[8]; // super-block scale for quantized scales @@ -44,7 +48,14 @@ struct block_q4_Kx8 { }; static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding"); +struct block_q4_Kx16 { + ggml_half d[16]; // super-block scale for quantized scales + ggml_half dmin[16]; // super-block scale for quantized mins + uint8_t scales[192]; // scales and mins, quantized with 6 bits + uint8_t qs[2048]; // 4--bit quants +}; +static_assert(sizeof(block_q4_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 8, "wrong q4_K block size/padding"); struct block_q2_Kx8 { ggml_half d[8]; // super-block scale for quantized scales ggml_half dmin[8]; // super-block scale for quantized mins @@ -53,6 +64,13 @@ struct block_q2_Kx8 { }; static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding"); +struct block_q2_Kx16 { + ggml_half d[16]; // Super-block scale for quantized scales + ggml_half dmin[16]; // Super-block scale for quantized mins + uint8_t scales[256]; // Sub-block scales (16 cols * 16 sub-blocks) + uint8_t qs[1024]; // Data (16 cols * 64 bytes per block) +}; +static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, "wrong q2_K block size/padding"); struct block_q5_Kx8 { ggml_half d[8]; // super-block scale for quantized scales @@ -97,6 +115,12 @@ struct block_iq4_nlx8 { static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding"); +struct block_iq4_nlx16 { + ggml_half d[16]; // deltas for 16 iq4_nl blocks + uint8_t qs[QK4_NL * 8]; // nibbles / quants for 16 iq4_nl blocks +}; + +static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "wrong iq4_nlx16 block size/padding"); struct block_mxfp4x4 { uint8_t e[4]; uint8_t qs[QK_MXFP4 * 2]; @@ -109,7 +133,6 @@ struct block_mxfp4x8 { }; static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding"); - #if defined(__cplusplus) extern "C" { #endif @@ -132,6 +155,8 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -146,10 +171,22 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +#if defined __riscv_zvfh +void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +#endif // Native implementations void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); @@ -170,6 +207,8 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -184,10 +223,22 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +#if defined __riscv_zvfh +void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +#endif #if defined(__cplusplus) } // extern "C" From dfa6858d0268741c7eea15724588435b53a14ed8 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Tue, 10 Mar 2026 08:25:25 +0100 Subject: [PATCH 240/831] kleidiai : support for concurrent sme and neon kernel execution (llama/20070) --- ggml/src/ggml-cpu/kleidiai/kernels.cpp | 6 +- ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 1217 ++++++++++++++++++----- 2 files changed, 968 insertions(+), 255 deletions(-) diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index 40f7c0df650..8c4d7bc925f 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -520,7 +520,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .packed_stride_ex = */ &rhs_stride_fn4, /* .pack_func_ex = */ &rhs_pack_fn12, }, - /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + /* .required_cpu = */ CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, /* .rhs_type = */ GGML_TYPE_Q4_0, /* .op_type = */ GGML_TYPE_F32, @@ -631,7 +631,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .packed_stride_ex = */ &rhs_stride_fn4, /* .pack_func_ex = */ &rhs_pack_fn12, }, - /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + /* .required_cpu = */ CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, /* .rhs_type = */ GGML_TYPE_Q4_0, /* .op_type = */ GGML_TYPE_F32, @@ -801,7 +801,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = { /* .packed_stride_ex = */ &rhs_stride_fn4, /* .pack_func_ex = */ &rhs_pack_scale_fn12, }, - /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + /* .required_cpu = */ CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, /* .rhs_type = */ GGML_TYPE_Q8_0, /* .op_type = */ GGML_TYPE_F32, diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index ad23e73184e..9bcc18d442c 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -1,20 +1,31 @@ -// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates // SPDX-License-Identifier: MIT // #include #include +#include #include #include -#include #include +#include #include #include #include #include #include +#include +#include +#include +#include +#include +#include +#include #if defined(__linux__) #include #include +#include +#include +#include #elif defined(__APPLE__) #include #include @@ -39,11 +50,18 @@ #define GGML_COMMON_DECL_CPP #include "ggml-common.h" +static constexpr int GGML_KLEIDIAI_MAX_KERNEL_SLOTS = 2; +static constexpr uint32_t GGML_KLEIDIAI_PACK_MAGIC = 0x4b4c4149; // "KLAI" +static constexpr uint16_t GGML_KLEIDIAI_PACK_VERSION = 1; +static constexpr size_t GGML_KLEIDIAI_PACK_ALIGN = 64; + struct ggml_kleidiai_context { cpu_feature features; ggml_kleidiai_kernels * kernels_q4; ggml_kleidiai_kernels * kernels_q8; -} static ctx = { CPU_FEATURE_NONE, NULL, NULL }; + int sme_thread_cap; // <= 0 means “SME disabled/unknown”; + int thread_hint; // <= 0 means “no hint” +} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 }; static const char* cpu_feature_to_string(cpu_feature f) { if (f == CPU_FEATURE_NONE) { @@ -63,41 +81,335 @@ static const char* cpu_feature_to_string(cpu_feature f) { } } -static void init_kleidiai_context(void) { +static size_t detect_num_smcus() { + if (!ggml_cpu_has_sme()) { + return 0; + } + +#if defined(__linux__) && defined(__aarch64__) + // Linux/aarch64: Best-effort count of Streaming Mode Compute Units (SMCUs) via SMIDR_EL1 sysfs. + size_t num_private = 0; + std::set shared_ids; + + for (size_t cpu = 0;; ++cpu) { + const std::string path = + "/sys/devices/system/cpu/cpu" + std::to_string(cpu) + + "/regs/identification/smidr_el1"; + + std::ifstream file(path); + if (!file.is_open()) { + break; + } + + uint64_t smidr = 0; + if (!(file >> std::hex >> smidr)) { + continue; + } + + // Arm ARM: SMIDR_EL1 + const uint32_t sh = (uint32_t)((smidr >> 13) & 0x3); + // Build an "affinity-like" identifier for shared SMCUs. + // Keep the original packing logic, but isolate it here. + const uint32_t id = (uint32_t)((smidr & 0xFFFu) | ((smidr >> 20) & 0xFFFFF000u)); + + switch (sh) { + case 0b10: // private SMCU + ++num_private; + break; + case 0b11: // shared SMCU + shared_ids.emplace(id); + break; + case 0b00: + // Ambiguous / implementation-defined. Be conservative: + // treat id==0 as private, otherwise as shared. + if (id == 0) ++num_private; + else shared_ids.emplace(id); + break; + default: + break; + } + } + + return num_private + shared_ids.size(); + +#elif defined(__APPLE__) && defined(__aarch64__) + // table for known M4 variants. Users can override via GGML_KLEIDIAI_SME=. + char chip_name[256] = {}; + size_t size = sizeof(chip_name); + + if (sysctlbyname("machdep.cpu.brand_string", chip_name, &size, nullptr, 0) == 0) { + const std::string brand(chip_name); + + struct ModelSMCU { const char *match; size_t smcus; }; + static const ModelSMCU table[] = { + { "M4 Ultra", 2 }, + { "M4 Max", 2 }, + { "M4 Pro", 2 }, + { "M4", 1 }, + }; + for (const auto &e : table) { + if (brand.find(e.match) != std::string::npos) { + return e.smcus; + } + } + } + return 1; + +#else + return 1; +#endif +} + +static int parse_uint_env(const char *s, const char *name, bool *ok) { + if (!s) { *ok = false; return 0; } + char *end = nullptr; + long v = strtol(s, &end, 10); + if (end == s || *end != '\0') { + GGML_LOG_WARN("kleidiai: invalid %s='%s' (expected integer)\n", name, s); + *ok = false; + return 0; + } + if (v < 0 || v > INT_MAX) { + GGML_LOG_WARN("kleidiai: out-of-range %s='%s'\n", name, s); + *ok = false; + return 0; + } + *ok = true; + return (int)v; +} + +static void init_kleidiai_context(void) { ggml_critical_section_start(); static bool initialized = false; if (!initialized) { initialized = true; - const char *env_var = getenv("GGML_KLEIDIAI_SME"); - int sme_enabled = 0; + + const char *env_sme = getenv("GGML_KLEIDIAI_SME"); + const char *env_threads = getenv("GGML_TOTAL_THREADS"); + + const bool cpu_has_sme = ggml_cpu_has_sme(); + size_t detected_smcus = 0; ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) | (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) | ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE); - if (env_var) { - sme_enabled = atoi(env_var); + if (env_threads) { + bool ok = false; + int hint = parse_uint_env(env_threads, "GGML_TOTAL_THREADS", &ok); + if (ok && hint > 0) { + ctx.thread_hint = hint; + } } - if (sme_enabled != 0) { - ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; + // SME policy: + // - If CPU doesn't support SME: SME always off. + // - Else: + // - env unset => auto-detect cores; enable if detected > 0. + // - env=0 => force off. + // - env>0 => force N cores (skip detection). + int sme_cores = 0; + bool sme_env_ok = false; + bool sme_env_set = (env_sme != nullptr); + + if (!cpu_has_sme) { + if (sme_env_set) { + bool ok = false; + int req = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok); + if (ok && req > 0) { + GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME=%d but SME is not supported on this CPU; disabling SME\n", req); + } + } + sme_cores = 0; + } else { + if (sme_env_set) { + bool ok = false; + int v = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok); + sme_env_ok = ok; + + if (!ok) { + GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME set but parsing failed; falling back to runtime SME-core detection\n"); + detected_smcus = detect_num_smcus(); + sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0; + } else if (v == 0) { + sme_cores = 0; + } else { + sme_cores = v; + } + } else { + detected_smcus = detect_num_smcus(); + sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0; + } + + if (!sme_env_set && sme_cores == 0) { + GGML_LOG_WARN("kleidiai: SME supported but runtime SME-core detection returned 0; falling back to NEON\n"); + } + + if (sme_cores > 0) { + ctx.features |= CPU_FEATURE_SME; + } } + + // Kernel selection ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features); ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features); -#ifndef NDEBUG - if (ctx.kernels_q4) { - GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu)); + + if (!ctx.kernels_q4) { + GGML_LOG_INFO("kleidiai: no compatible q4 kernels found for CPU features mask %d\n", (int)ctx.features); + } else { + GGML_LOG_INFO("kleidiai: primary q4 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu)); + } + + if (!ctx.kernels_q8) { + GGML_LOG_INFO("kleidiai: no compatible q8 kernels found for CPU features mask %d\n", (int)ctx.features); + } else { + GGML_LOG_INFO("kleidiai: primary q8 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu)); } - if (ctx.kernels_q8) { - GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu)); + + ctx.sme_thread_cap = (ctx.features & CPU_FEATURE_SME) ? sme_cores : 0; + + if (ctx.features & CPU_FEATURE_SME) { + if (sme_env_set && sme_env_ok && sme_cores > 0) { + GGML_LOG_INFO("kleidiai: SME enabled (GGML_KLEIDIAI_SME=%d override)\n", sme_cores); + } else { + GGML_LOG_INFO("kleidiai: SME enabled (runtime-detected SME cores=%d)\n", sme_cores); + } + } else { + GGML_LOG_INFO("kleidiai: SME disabled\n"); } -#endif } + ggml_critical_section_end(); } +static inline int kleidiai_sme_thread_cap() { + return ctx.sme_thread_cap; +} + +static inline size_t align_up(size_t value, size_t alignment) { + if (alignment == 0) { + return value; + } + const size_t remainder = value % alignment; + return remainder == 0 ? value : value + (alignment - remainder); +} + +static inline bool kleidiai_pack_fallback_allowed() { + if (ctx.sme_thread_cap <= 0) { + return false; + } + if (ctx.thread_hint <= 0) { + return true; + } + return ctx.thread_hint > ctx.sme_thread_cap; +} + +struct kleidiai_weight_header { + uint32_t magic; + uint16_t version; + uint16_t slot_count; + uint64_t offsets[GGML_KLEIDIAI_MAX_KERNEL_SLOTS]; + uint64_t sizes[GGML_KLEIDIAI_MAX_KERNEL_SLOTS]; +}; + +static inline kleidiai_weight_header * kleidiai_weight_header_from_ptr(void * data) { + return reinterpret_cast(data); +} + +static inline const kleidiai_weight_header * kleidiai_weight_header_from_ptr(const void * data) { + return reinterpret_cast(data); +} + +static inline bool kleidiai_is_weight_header_valid(const kleidiai_weight_header * header) { + if (!header) { + return false; + } + if (header->magic != GGML_KLEIDIAI_PACK_MAGIC || header->version != GGML_KLEIDIAI_PACK_VERSION) { + return false; + } + if (header->slot_count == 0 || header->slot_count > GGML_KLEIDIAI_MAX_KERNEL_SLOTS) { + return false; + } + return true; +} + +static inline uint8_t * kleidiai_weight_slot_ptr(kleidiai_weight_header * header, int slot) { + if (!kleidiai_is_weight_header_valid(header)) { + return nullptr; + } + if (slot < 0 || slot >= header->slot_count) { + return nullptr; + } + return reinterpret_cast(header) + header->offsets[slot]; +} + +static inline const uint8_t * kleidiai_weight_slot_ptr(const kleidiai_weight_header * header, int slot) { + if (!kleidiai_is_weight_header_valid(header)) { + return nullptr; + } + if (slot < 0 || slot >= header->slot_count) { + return nullptr; + } + return reinterpret_cast(header) + header->offsets[slot]; +} + +static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q4() { + return ctx.kernels_q4; +} + +static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q8() { + return ctx.kernels_q8; +} + +template +static int kleidiai_collect_kernel_chain_common( + ggml_kleidiai_kernels * primary, + cpu_feature features, + std::array & out, + SelectFallback select_fallback) { + int count = 0; + if (!primary) { + return 0; + } + out[count++] = primary; + + if ((primary->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) { + const cpu_feature fallback_mask = static_cast(features & ~CPU_FEATURE_SME); + if (fallback_mask != CPU_FEATURE_NONE) { + ggml_kleidiai_kernels * fallback = select_fallback(fallback_mask); + if (fallback && fallback != primary && + fallback->lhs_type == primary->lhs_type && + fallback->rhs_type == primary->rhs_type && + fallback->op_type == primary->op_type) { + out[count++] = fallback; + } + } + } + + return count; +} + +static int kleidiai_collect_kernel_chain(const struct ggml_tensor * op, + std::array & out) { + ggml_kleidiai_kernels * primary = ggml_kleidiai_select_kernels(ctx.features, op); + return kleidiai_collect_kernel_chain_common(primary, ctx.features, out, + [&](cpu_feature mask) { return ggml_kleidiai_select_kernels(mask, op); }); +} + +static int kleidiai_collect_q4_chain(std::array & out) { + ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q4(); + return kleidiai_collect_kernel_chain_common(primary, ctx.features, out, + [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q4_0(mask); }); +} + +static int kleidiai_collect_q8_chain(std::array & out) { + ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q8(); + return kleidiai_collect_kernel_chain_common(primary, ctx.features, out, + [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q8_0(mask); }); +} + static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); return tensor->ne[dim]; @@ -126,49 +438,108 @@ class tensor_traits : public ggml::cpu::tensor_traits { if (op->op != GGML_OP_MUL_MAT) { return false; } - ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op); - if (!kernels) { + + std::array kernel_chain; + const int slot_count = kleidiai_collect_kernel_chain(op, kernel_chain); + if (slot_count == 0) { return false; } - bool is_gemv = op->src[1]->ne[1] == 1; - kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; - lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; - size_t k = op->src[0]->ne[0]; - size_t n = op->src[0]->ne[1]; - size_t m = op->src[1]->ne[1]; - - size_t mr = kernel->get_mr(); - size_t kr = kernel->get_kr(); - size_t sr = kernel->get_sr(); - - if (kernels->rhs_type == GGML_TYPE_Q4_0) { - if (!lhs_info->packed_size_ex) return false; - size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr); - } else if (kernels->rhs_type == GGML_TYPE_Q8_0) { - if (!lhs_info->packed_size_ex) return false; - size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr); - } else if (kernels->rhs_type == GGML_TYPE_F16) { - if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false; + const bool is_gemv = op->src[1]->ne[1] == 1; + const size_t k = op->src[0]->ne[0]; + const size_t n = op->src[0]->ne[1]; + const size_t m = op->src[1]->ne[1]; + + if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) { + const size_t qk = (op->src[0]->type == GGML_TYPE_Q4_0) ? QK4_0 : QK8_0; + + size_t cursor = 0; + bool any_slot = false; + + for (int slot = 0; slot < slot_count; ++slot) { + ggml_kleidiai_kernels * kernels = kernel_chain[slot]; + lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; + kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; + + if (!lhs_info || !lhs_info->packed_size_ex || !kernel) { + return false; + } + + const size_t mr = kernel->get_mr(); + const size_t kr = kernel->get_kr(); + const size_t sr = kernel->get_sr(); + + const size_t packed = lhs_info->packed_size_ex(m, k, qk, mr, kr, sr); + + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + cursor += packed; + any_slot = true; + } + + if (!any_slot) { + return false; + } + + size = cursor; + return true; + } + + if (op->src[0]->type == GGML_TYPE_F16) { const int64_t lhs_batch_size0 = op->src[1]->ne[2]; const int64_t rhs_batch_size0 = op->src[0]->ne[2]; + GGML_ASSERT(rhs_batch_size0 > 0); const int64_t r = lhs_batch_size0 / rhs_batch_size0; - size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) + - kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) + - k * n * sizeof(float) + n * sizeof(float); - } else { - return false; + + size_t cursor = 0; + bool any_slot = false; + + for (int slot = 0; slot < slot_count; ++slot) { + ggml_kleidiai_kernels * kernels = kernel_chain[slot]; + lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; + kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; + if (!lhs_info || !lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex || !kernel) { + return false; + } + + const size_t mr = kernel->get_mr(); + const size_t kr = kernel->get_kr(); + const size_t sr = kernel->get_sr(); + + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + cursor += lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr); + any_slot = true; + } + + for (int slot = 0; slot < slot_count; ++slot) { + ggml_kleidiai_kernels * kernels = kernel_chain[slot]; + kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; + if (!kernel || !kernels->rhs_info.packed_size_ex) { + return false; + } + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + cursor += kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0); + } + + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + cursor += k * n * sizeof(float); + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + cursor += n * sizeof(float); + + if (!any_slot) { + return false; + } + + size = cursor; + return true; } - return true; + return false; } bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override { if (dst->op == GGML_OP_MUL_MAT) { - if (dst->src[0]->type == GGML_TYPE_Q4_0) { - return compute_forward_q4_0(params, dst); - } else if (dst->src[0]->type == GGML_TYPE_Q8_0) { - return compute_forward_q8_0(params, dst); + if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) { + return compute_forward_qx(params, dst); } else if (dst->src[0]->type == GGML_TYPE_F16) { return compute_forward_fp16(params, dst); } @@ -331,204 +702,457 @@ class tensor_traits : public ggml::cpu::tensor_traits { return true; } - bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0); + bool compute_forward_qx(struct ggml_compute_params * params, struct ggml_tensor * dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0); const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; GGML_TENSOR_BINARY_OP_LOCALS - ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); - if (!kernels) { - return false; - } - - bool is_gemv = src1->ne[1] == 1; - kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; - lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; - - GGML_ASSERT(kernel); - if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex || - !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) { - return false; - } + const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data); + const bool has_header = kleidiai_is_weight_header_valid(header); + const bool is_gemv = src1->ne[1] == 1; + std::array kernel_chain; + const int slot_total = kleidiai_collect_kernel_chain(dst, kernel_chain); - const int ith = params->ith; - const int nth_raw = params->nth; - const int nth = nth_raw > 0 ? nth_raw : 1; + auto weight_for_slot = [&](int slot_index, size_t & size_out) -> const uint8_t * { + if (slot_index < 0 || slot_index >= slot_total) { + return nullptr; + } + if (has_header) { + if (slot_index < header->slot_count) { + size_out = static_cast(header->sizes[slot_index]); + return kleidiai_weight_slot_ptr(header, slot_index); + } + return nullptr; + } + if (slot_index == 0) { + size_out = ggml_nbytes(src0); + return static_cast(src0->data); + } + return nullptr; + }; + + struct runtime_slot { + int slot_index; + ggml_kleidiai_kernels * kernels; + kernel_info * kernel; + lhs_packing_info * lhs_info; + size_t mr; + size_t nr; + size_t kr; + size_t sr; + size_t n_step; + size_t lhs_packed_size; + size_t lhs_offset; + size_t n_offset; + size_t n_cols; + int assigned_threads; + int thread_begin; + int thread_end; + const uint8_t * rhs_base; + }; + + std::array runtime{}; + int runtime_count = 0; + + for (int slot = 0; slot < slot_total && runtime_count < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) { + ggml_kleidiai_kernels * kernels = kernel_chain[slot]; + kernel_info * kinfo = is_gemv ? &kernels->gemv : &kernels->gemm; + lhs_packing_info * linfo = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; + if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || !linfo->get_offset || + !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset) { + continue; + } - const size_t k = ne00; - const size_t m = ne11; - const size_t n = ne01; + size_t rhs_size = 0; + const uint8_t * rhs_ptr = weight_for_slot(slot, rhs_size); + if (!rhs_ptr || rhs_size == 0) { + continue; + } - size_t mr = kernel->get_mr(); - size_t kr = kernel->get_kr(); - size_t sr = kernel->get_sr(); + runtime[runtime_count] = { + slot, + kernels, + kinfo, + linfo, + kinfo->get_mr(), + kinfo->get_nr(), + kinfo->get_kr(), + kinfo->get_sr(), + kinfo->get_n_step(), + 0, + 0, + 0, + 0, + 0, + 0, + 0, + rhs_ptr + }; + ++runtime_count; + } - const uint8_t * lhs = static_cast(src1->data); - uint8_t * lhs_packed = (uint8_t*)params->wdata; - const uint8_t * rhs_packed = static_cast(src0->data); + if (runtime_count == 0) { + ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst); + if (!fallback) { + return false; + } + kernel_info * kinfo = is_gemv ? &fallback->gemv : &fallback->gemm; + lhs_packing_info * linfo = is_gemv ? &fallback->gemv_lhs_info : &fallback->gemm_lhs_info; + rhs_packing_info * rinfo = &fallback->rhs_info; + if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || + !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset || + !rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) { + return false; + } + kernel_chain[0] = fallback; + runtime[0] = { + 0, + fallback, + kinfo, + linfo, + kinfo->get_mr(), + kinfo->get_nr(), + kinfo->get_kr(), + kinfo->get_sr(), + kinfo->get_n_step(), + 0, + 0, + 0, + 0, + 0, + 0, + 0, + nullptr + }; + size_t rhs_size_fallback = 0; + const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback); + if (!rhs_base) { + rhs_base = static_cast(src0->data); + } + runtime[0].rhs_base = rhs_base; + runtime_count = 1; + } - const size_t n_step = kernel->get_n_step(); - const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); - const size_t n_start = ith * num_n_per_thread; + const int nth_total = params->nth > 0 ? params->nth : 1; + const int ith_total = params->ith; - size_t n_to_process = 0; - if (n_start < n) { - n_to_process = num_n_per_thread; - if ((n_start + n_to_process) > n) { - n_to_process = n - n_start; + int sme_slot = -1; + for (int i = 0; i < runtime_count; ++i) { + if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) { + sme_slot = i; + break; } } - // Calculate number of columns to be processed per thread - const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth; - const size_t m_start = ith * num_m_per_thread; - size_t m_to_process = num_m_per_thread; - if ((m_start + m_to_process) > m) { - m_to_process = m - m_start; + const int sme_cap_limit = ctx.sme_thread_cap; + const bool use_hybrid = sme_cap_limit > 0 && + runtime_count > 1 && + nth_total > sme_cap_limit; + // Heuristic: disable hybrid for very small workloads where per-slot overhead dominates. + // If rows are small or average columns per thread are small, keep single-slot. + size_t min_cols_per_thread = 0; + if (runtime_count > 0 && nth_total > 0) { + min_cols_per_thread = (size_t) std::max(1, (int64_t)ne01 / (int64_t)nth_total); } + const bool too_small_for_hybrid = (min_cols_per_thread < 2) || (ne11 < 128); - if (m_start < m) { - // Transform LHS - const size_t src_stride = src1->nb[1]; - const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); - const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr); - void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); - - // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer - lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); - } + const bool hybrid_enabled = use_hybrid && !too_small_for_hybrid; - ggml_barrier(params->threadpool); + if (!hybrid_enabled) { + int chosen_slot = 0; + if (too_small_for_hybrid && sme_slot != -1) { + chosen_slot = sme_slot; + } else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) { + chosen_slot = 1; + } + if (chosen_slot != 0 && chosen_slot < runtime_count) { + runtime[0] = runtime[chosen_slot]; + } + runtime_count = runtime_count > 0 ? 1 : 0; - // Perform the operation - const size_t dst_stride = dst->nb[1]; - const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr); - const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0); - const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); - const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); - const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); - float *dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); + // Recompute SME slot based on the collapsed runtime[0] + sme_slot = -1; + if (runtime_count > 0 && + (runtime[0].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) { + sme_slot = 0; + } + } - if (n_to_process > 0) { - kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, - sizeof(float), -FLT_MAX, FLT_MAX); + int sme_cap = kleidiai_sme_thread_cap(); + if (sme_cap < 0) { + sme_cap = nth_total; } + sme_cap = std::min(sme_cap, nth_total); - return true; - } + int threads_remaining = nth_total; + if (sme_slot != -1) { + int sme_threads = std::min(std::max(sme_cap, 0), threads_remaining); + runtime[sme_slot].assigned_threads = sme_threads; + threads_remaining -= sme_threads; + } - bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0); + int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS]; + int fallback_count = 0; + for (int i = 0; i < runtime_count; ++i) { + if (i == sme_slot) { + continue; + } + fallback_indices[fallback_count++] = i; + } - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; + for (int fi = 0; fi < fallback_count; ++fi) { + if (threads_remaining <= 0) { + break; + } + const int slot_index = fallback_indices[fi]; + const int slots_left = fallback_count - fi; + int share = (threads_remaining + slots_left - 1) / slots_left; + share = std::min(share, threads_remaining); + runtime[slot_index].assigned_threads = share; + threads_remaining -= share; + } - GGML_TENSOR_BINARY_OP_LOCALS + if (threads_remaining > 0) { + const int fallback_slot = (sme_slot != -1) ? sme_slot : 0; + runtime[fallback_slot].assigned_threads += threads_remaining; + threads_remaining = 0; + } - ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); - if (!kernels) { - return false; + int thread_cursor = 0; + for (int i = 0; i < runtime_count; ++i) { + runtime[i].thread_begin = thread_cursor; + thread_cursor += runtime[i].assigned_threads; + runtime[i].thread_end = thread_cursor; } - bool is_gemv = src1->ne[1] == 1; - kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; - lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; + if (thread_cursor < nth_total && runtime_count > 0) { + runtime[runtime_count - 1].assigned_threads += nth_total - thread_cursor; + runtime[runtime_count - 1].thread_end = nth_total; + } - if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex || - !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) { + int local_slot = -1; + int local_ith = 0; + for (int i = 0; i < runtime_count; ++i) { + if (ith_total >= runtime[i].thread_begin && ith_total < runtime[i].thread_end) { + local_slot = i; + local_ith = ith_total - runtime[i].thread_begin; + break; + } + } + if (local_slot == -1) { return false; } - const int ith = params->ith; - const int nth_raw = params->nth; - const int nth = nth_raw > 0 ? nth_raw : 1; - const size_t k = ne00; const size_t m = ne11; const size_t n = ne01; - size_t mr = kernel->get_mr(); - size_t kr = kernel->get_kr(); - size_t sr = kernel->get_sr(); - - const uint8_t * lhs = static_cast(src1->data); - uint8_t * lhs_packed = static_cast(params->wdata); - const uint8_t * rhs_packed = static_cast(src0->data); + size_t cursor = 0; + for (int i = 0; i < runtime_count; ++i) { + const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type; + const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : + slot_rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0; + runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr); + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + runtime[i].lhs_offset = cursor; + cursor += runtime[i].lhs_packed_size; + } - const size_t n_step = kernel->get_n_step(); - const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); - const size_t n_start = ith * num_n_per_thread; + GGML_ASSERT(cursor <= params->wsize); + uint8_t * scratch = static_cast(params->wdata); - size_t n_to_process = 0; - if (n_start < n) { - n_to_process = num_n_per_thread; - if ((n_start + n_to_process) > n) { - n_to_process = n - n_start; + size_t assigned_cols = 0; + uint64_t weighted_total = 0; + if (runtime_count > 1 && sme_slot != -1) { + for (int i = 0; i < runtime_count; ++i) { + const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1; + weighted_total += (uint64_t)runtime[i].assigned_threads * weight; } } + for (int i = 0; i < runtime_count; ++i) { + runtime[i].n_offset = assigned_cols; + if (runtime[i].assigned_threads == 0) { + runtime[i].n_cols = 0; + continue; + } + const size_t remaining_cols = n - assigned_cols; + if (remaining_cols == 0) { + runtime[i].n_cols = 0; + continue; + } + const size_t step = runtime[i].n_step ? runtime[i].n_step : 1; + size_t target = 0; + if (weighted_total > 0) { + const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1; + target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total); + } else { + target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total); + } + target = std::min(target, remaining_cols); + size_t aligned = round_down(target, step); + if (aligned == 0 && remaining_cols >= step) { + aligned = step; + } + runtime[i].n_cols = aligned; + assigned_cols += aligned; + } - const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth; - const size_t m_start = ith * num_m_per_thread; - size_t m_to_process = num_m_per_thread; - if ((m_start + m_to_process) > m) { - m_to_process = m - m_start; + if (assigned_cols < n) { + for (int i = runtime_count - 1; i >= 0; --i) { + if (runtime[i].assigned_threads > 0) { + runtime[i].n_cols += n - assigned_cols; + break; + } + } } + const size_t dst_stride = dst->nb[1]; - if (m_start < m) { - const size_t src_stride = src1->nb[1]; - const float * src_ptr = reinterpret_cast(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); - const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr); - void * lhs_packed_ptr = static_cast(lhs_packed + lhs_packed_offset); + for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) { + const uint8_t * lhs_batch_base = static_cast(src1->data) + batch_idx * src1->nb[2]; + uint8_t * dst_batch_base = static_cast(dst->data) + batch_idx * dst->nb[2]; - lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); - } + if (runtime[local_slot].assigned_threads > 0) { + runtime_slot & slot = runtime[local_slot]; + const ggml_type slot_rhs_type = slot.kernels->rhs_type; + const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : + slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0; + const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr); + int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads; + max_threads = std::max(1, max_threads); + const int64_t use_threads = std::min(slot.assigned_threads, max_threads); - ggml_barrier(params->threadpool); + if (local_ith < use_threads) { + const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / use_threads), slot.mr); + const int64_t num_m_per_threadN_1 = (int64_t)m - (use_threads - 1) * num_m_per_thread0; - const size_t dst_stride = dst->nb[1]; - const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr); - const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0); - const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); - const void * rhs_ptr = static_cast(rhs_packed + rhs_packed_offset); - const void * lhs_ptr = static_cast(lhs_packed + lhs_packed_offset); - float * dst_ptr = reinterpret_cast(static_cast(dst->data) + dst_offset); + const int64_t m_start = (int64_t)local_ith * num_m_per_thread0; + const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; + + const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr); + const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr); + const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0; + + int64_t remaining = m_count; + int64_t cur = m_start; + + uint8_t * lhs_packed = scratch + slot.lhs_offset; + while (remaining > 0) { + const int64_t row_in_group = cur; + const int64_t avail = (int64_t)m - row_in_group; + const int64_t take = std::min(avail, remaining); + + const size_t src_off = slot.lhs_info->get_offset(row_in_group, src1->nb[1]); + const void * src_ptr = lhs_batch_base + src_off; + const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes; + void * dst_ptr = lhs_packed + dst_off; + + slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr); + + cur += take; + remaining -= take; + } + } + } + + ggml_barrier(params->threadpool); - if (n_to_process > 0) { - kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, - sizeof(float), -FLT_MAX, FLT_MAX); + runtime_slot & slot = runtime[local_slot]; + if (slot.n_cols > 0 && slot.assigned_threads > 0) { + int64_t active_threads = slot.assigned_threads; + const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads; + if (max_threads > 0) { + active_threads = std::min(active_threads, std::max(1, max_threads)); + } + active_threads = std::max(1, active_threads); + + if (local_ith < active_threads) { + const size_t step = slot.n_step ? slot.n_step : 1; + const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step); + const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0; + const size_t local_start = (size_t)local_ith * chunk0; + const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0; + + if (cols > 0) { + const ggml_type slot_rhs_type = slot.kernels->rhs_type; + const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : + slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0; + const size_t slot_rhs_block_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : + slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0; + const size_t global_start = slot.n_offset + local_start; + const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr); + const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg); + const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride); + + const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset; + const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset; + float * dst_ptr = reinterpret_cast(dst_batch_base + dst_offset); + + slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg, + lhs_ptr, + rhs_ptr, + dst_ptr, + dst_stride, + sizeof(float), + -FLT_MAX, + FLT_MAX); + } + } + } + + if (batch_idx != ne12 - 1) { + ggml_barrier(params->threadpool); + } } return true; } bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0); const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; GGML_TENSOR_BINARY_OP_LOCALS + const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data); + const bool has_header = kleidiai_is_weight_header_valid(header); + + std::array kernel_chain; + const bool want_q8 = src0->type == GGML_TYPE_Q8_0; + const int chain_count = want_q8 ? kleidiai_collect_q8_chain(kernel_chain) + : kleidiai_collect_q4_chain(kernel_chain); + ggml_kleidiai_kernels * kernels = nullptr; - size_t block_len = 0; - size_t num_bytes_multiplier = 0; + const uint8_t * packed_base = static_cast(src0->data); - if (dst->src[0]->type == GGML_TYPE_Q4_0) { - if (!ctx.kernels_q4) { - return false; + if (has_header && chain_count > 0) { + int select_slot = 0; + if (select_slot >= header->slot_count) { + select_slot = header->slot_count - 1; } - kernels = ctx.kernels_q4; - block_len = QK4_0; - num_bytes_multiplier = sizeof(uint16_t); - } else if (dst->src[0]->type == GGML_TYPE_Q8_0) { - if (!ctx.kernels_q8) { - return false; + if (select_slot >= 0 && select_slot < chain_count) { + kernels = kernel_chain[select_slot]; + const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, select_slot); + if (slot_ptr) { + packed_base = slot_ptr; + } } - kernels = ctx.kernels_q8; - block_len = QK8_0; - num_bytes_multiplier = sizeof(float); - } else { + } + + if (!kernels && chain_count > 0) { + kernels = kernel_chain[0]; + if (has_header) { + const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, 0); + if (slot_ptr) { + packed_base = slot_ptr; + } + } + } + + if (!kernels) { return false; } @@ -541,6 +1165,19 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t nc = ne00; const int64_t nr = ggml_nelements(src1); + const ggml_type rhs_type = kernels->rhs_type; + size_t block_len = 0; + size_t num_bytes_multiplier = 0; + if (rhs_type == GGML_TYPE_Q4_0) { + block_len = QK4_0; + num_bytes_multiplier = sizeof(uint16_t); + } else if (rhs_type == GGML_TYPE_Q8_0) { + block_len = QK8_0; + num_bytes_multiplier = sizeof(float); + } else { + return false; + } + const size_t block_rows = kernel->get_nr(); const size_t kr = kernel->get_kr(); @@ -559,7 +1196,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]); float *out = (float *)((char *)dst->data + i * nb1); - rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier); + rhs_info->to_float(packed_base, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier); } return true; @@ -567,36 +1204,39 @@ class tensor_traits : public ggml::cpu::tensor_traits { public: int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) { + GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0); const size_t n = tensor->ne[1]; const size_t k = tensor->ne[0]; - if (tensor->type == GGML_TYPE_Q4_0) { - if (!ctx.kernels_q4) { - return -1; - } - size_t nr = ctx.kernels_q4->gemm.get_nr(); - size_t kr = ctx.kernels_q4->gemm.get_kr(); - size_t sr = ctx.kernels_q4->gemm.get_sr(); + kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(tensor->data); + if (!header) { + return -1; + } - struct kai_rhs_pack_qs4cxs1s0_param params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, - static_cast(data), - nullptr, nullptr, tensor->data, 0, ¶ms); - GGML_UNUSED(data_size); - return 0; - } else if (tensor->type == GGML_TYPE_Q8_0) { - if (!ctx.kernels_q8) { - return -1; - } + header->magic = GGML_KLEIDIAI_PACK_MAGIC; + header->version = GGML_KLEIDIAI_PACK_VERSION; + header->slot_count = 0; + + uint8_t * base_ptr = static_cast(tensor->data); + size_t cursor = sizeof(kleidiai_weight_header); + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + + std::array kernel_chain; + const bool want_q8 = tensor->type == GGML_TYPE_Q8_0; + const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain) + : kleidiai_collect_q4_chain(kernel_chain); + const bool allow_fallback = kleidiai_pack_fallback_allowed(); + + std::vector qdata; + std::vector scales; + + if (want_q8 && slot_total > 0) { + qdata.resize(n * k, 0); + scales.resize(n, 0.0f); const size_t row_stride = tensor->nb[1]; const size_t k_blocks = (k + QK8_0 - 1) / QK8_0; - std::vector qdata(n * k, 0); - std::vector scales(n, 0.0f); - for (size_t row = 0; row < n; ++row) { const auto * row_blocks = reinterpret_cast( static_cast(data) + row * row_stride); @@ -610,7 +1250,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { if (linear_idx >= k) { break; } - const float value = d * blk.qs[l]; + const float value = d * static_cast(blk.qs[l]); max_abs = std::max(max_abs, std::fabs(value)); } } @@ -627,31 +1267,73 @@ class tensor_traits : public ggml::cpu::tensor_traits { if (linear_idx >= k) { break; } - const float value = d * blk.qs[l]; + const float value = d * static_cast(blk.qs[l]); int32_t q = scale > 0.0f ? static_cast(std::lround(value * inv_scale)) : 0; q = std::clamp(q, -127, 127); qdata[row * k + linear_idx] = static_cast(q); } } } + } + + for (int slot = 0; slot < slot_total && slot < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) { + if (!allow_fallback && slot > 0) { + break; + } + ggml_kleidiai_kernels * kernels = kernel_chain[slot]; + kernel_info * kernel = &kernels->gemm; + rhs_packing_info * rhs_info = &kernels->rhs_info; + if (!rhs_info || !rhs_info->pack_func_ex || !rhs_info->packed_size_ex || !kernel) { + continue; + } + + const size_t nr = kernel->get_nr(); + const size_t kr = kernel->get_kr(); + const size_t sr = kernel->get_sr(); + const ggml_type rhs_type = kernels->rhs_type; + const size_t block_len = rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : + rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : 0; + if (block_len == 0) { + continue; + } - size_t nr = ctx.kernels_q8->gemm.get_nr(); - size_t kr = ctx.kernels_q8->gemm.get_kr(); - size_t sr = ctx.kernels_q8->gemm.get_sr(); + const size_t packed_size = rhs_info->packed_size_ex(n, k, nr, kr, block_len); + const size_t aligned_cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + + uint8_t * dst_ptr = base_ptr + aligned_cursor; + + if (rhs_type == GGML_TYPE_Q4_0) { + struct kai_rhs_pack_qs4cxs1s0_param params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + rhs_info->pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, + static_cast(data), nullptr, nullptr, + dst_ptr, 0, ¶ms); + } else if (rhs_type == GGML_TYPE_Q8_0) { + struct kai_rhs_pack_qsi8cx_params params; + params.lhs_zero_point = 1; + params.scale_multiplier = 1.0f; + rhs_info->pack_func_ex(1, n, k, nr, kr, sr, 0, 0, + qdata.data(), nullptr, scales.data(), + dst_ptr, 0, ¶ms); + } else { + continue; + } + + header->offsets[header->slot_count] = aligned_cursor; + header->sizes[header->slot_count] = packed_size; + ++header->slot_count; - struct kai_rhs_pack_qsi8cx_params params; - params.lhs_zero_point = 1; - params.scale_multiplier = 1.0f; + cursor = aligned_cursor + packed_size; + } - ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0, - qdata.data(), nullptr, scales.data(), - tensor->data, 0, ¶ms); - GGML_UNUSED(data_size); - return 0; + if (header->slot_count == 0) { + header->magic = 0; + header->version = 0; + memcpy(tensor->data, data, data_size); } - GGML_UNUSED(data_size); - return -1; + return 0; } }; @@ -681,9 +1363,8 @@ static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t bu } static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) { - return "CPU_KLEIDIAI"; - GGML_UNUSED(buft); + return "CPU_KLEIDIAI"; } static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -702,49 +1383,78 @@ static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer( } static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return TENSOR_ALIGNMENT; - GGML_UNUSED(buft); + return TENSOR_ALIGNMENT; } static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { GGML_UNUSED(buft); + if (tensor->type != GGML_TYPE_Q4_0 && tensor->type != GGML_TYPE_Q8_0) { + return ggml_nbytes(tensor); + } + const size_t n = tensor->ne[1]; const size_t k = tensor->ne[0]; - ggml_kleidiai_kernels * kernels = nullptr; - size_t block_len = 0; - - if (tensor->type == GGML_TYPE_Q4_0) { - GGML_ASSERT(ctx.kernels_q4); - kernels = ctx.kernels_q4; - block_len = QK4_0; - } else if (tensor->type == GGML_TYPE_Q8_0) { - GGML_ASSERT(ctx.kernels_q8); - kernels = ctx.kernels_q8; - block_len = QK8_0; - } else { - return 0; + size_t cursor = sizeof(kleidiai_weight_header); + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + + std::array kernel_chain; + const bool want_q8 = tensor->type == GGML_TYPE_Q8_0; + const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain) + : kleidiai_collect_q4_chain(kernel_chain); + const bool allow_fallback = kleidiai_pack_fallback_allowed(); + + size_t slot_count = 0; + for (int slot = 0; slot < slot_total; ++slot) { + if (!allow_fallback && slot > 0) { + break; + } + ggml_kleidiai_kernels * kernels = kernel_chain[slot]; + if (!kernels) { + continue; + } + kernel_info * kernel = &kernels->gemm; + rhs_packing_info * rhs_info = &kernels->rhs_info; + if (!kernel || !rhs_info || !rhs_info->packed_size_ex) { + continue; + } + + const ggml_type rhs_type = kernels->rhs_type; + const size_t block_len = rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : + rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0; + if (block_len == 0) { + continue; + } + + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + cursor += rhs_info->packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), block_len); + ++slot_count; } - const size_t nr = kernels->gemm.get_nr(); - const size_t kr = kernels->gemm.get_kr(); - const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len); - const size_t raw = ggml_nbytes(tensor); + if (slot_count == 0) { + return ggml_nbytes(tensor); + } - return packed > raw ? packed : raw; + return std::max(cursor, ggml_nbytes(tensor)); } namespace ggml::cpu::kleidiai { class extra_buffer_type : ggml::cpu::extra_buffer_type { bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + std::array kernel_chain; + const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain); if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) && (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) && op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && - op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { - if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) { + op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && + slot_total > 0) { + if (op->src[0]->type == GGML_TYPE_Q4_0 && ctx.kernels_q4 == nullptr) { + return false; + } + if (op->src[0]->type == GGML_TYPE_Q8_0 && ctx.kernels_q8 == nullptr) { return false; } if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { @@ -762,14 +1472,17 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) { if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; - } - else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) { - if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || - (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { - return nullptr; + } else { + std::array kernel_chain; + const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain); + const bool has_kernel = slot_total > 0; + if (has_kernel && op->src[1]->ne[1] > 1) { + if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || + (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { + return nullptr; + } + return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL); } - - return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL); } } return nullptr; From fddedc5cbc09a7eccd3daf394fb13f9dec171413 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 10 Mar 2026 09:14:27 -0700 Subject: [PATCH 241/831] ggml webgpu: faster normal quant and some k-quant matrix operations, better shader parameter handling (llama/20173) * K quant speedup (llama/20) * Basic JIT compilation for mul_mat, get_rows, and scale (llama/17) * scale jit working * preliminary working jit for getrows and mulmat, needs refining * simplified mul_mat preprocessing switch statement * get_rows fixes, mul_mat refinement * formatted + last edits * removed some extraneous prints * fixed get_rows, fixed workgroup dispatch in mul_mat. no gibberish * small fix * some changes, working * get_rows and mul_mat jit fixed and working * Update formatting * formatting * Add header --------- Co-authored-by: Neha Abbas Co-authored-by: Reese Levine * Start work on all-encompassing shader library * refactor argmax, set_rows * Refactor all but flashattention, mat mul * no gibberish, all k quants added, merged * vec memory fix * q6_k matching metal on my machine, tests passing * Set tile size for q6_k separately * Separate out fast shaders --------- Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com> * Move towards writeBuffer for params * Move away from multiple buffers for set_rows errors, remove host buffer for parameter buffers, minor cleanups * Remove extra file * Formatting --------- Co-authored-by: neha-ha <137219201+neha-ha@users.noreply.github.com> --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 70 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 459 ++++++------ .../wgsl-shaders/mul_mat_decls.tmpl | 673 +++++++++++++++++- .../wgsl-shaders/mul_mat_reg_tile.wgsl | 1 + .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 290 +++++++- 5 files changed, 1237 insertions(+), 256 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 17c5e0fb51f..3c38b1a230f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -42,11 +42,20 @@ #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 // Matrix-vector multiplication parameters -#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 +#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 + // Must be multiple of 4 to work with vectorized paths, and must divide // mul_mat_vec wg size -#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_TILE_K 256 +#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64 +#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256 + +#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64 +#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256 + +// Requires 32 threads per output (wg_size/outputs_per_wg == 32) +#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8 +// Requires at least two (and multiple of 2) k-quant blocks per tile +#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512 // default size for legacy matrix multiplication #define WEBGPU_MUL_MAT_WG_SIZE 256 @@ -199,7 +208,8 @@ struct ggml_webgpu_binary_pipeline_key { bool src_overlap; bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { - return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap; + return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && + src_overlap == other.src_overlap; } }; @@ -749,29 +759,17 @@ class ggml_webgpu_shader_lib { std::vector defines; std::string variant = "mul_mat_vec"; - // src1 type (vector) - switch (context.src1->type) { - case GGML_TYPE_F32: - defines.push_back("SRC1_INNER_TYPE=f32"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("SRC1_INNER_TYPE=f16"); - variant += "_f16"; - break; - default: - GGML_ABORT("Unsupported src1 type for mul_mat_vec shader"); - } - // src0 type (matrix row) switch (context.src0->type) { case GGML_TYPE_F32: defines.push_back("SRC0_INNER_TYPE=f32"); defines.push_back("MUL_ACC_FLOAT"); + variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("SRC0_INNER_TYPE=f16"); defines.push_back("MUL_ACC_FLOAT"); + variant += "_f16"; break; default: { @@ -779,6 +777,7 @@ class ggml_webgpu_shader_lib { const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); std::string src0_name = src0_traits->type_name; std::string type_upper = src0_name; + variant += "_" + src0_name; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); defines.push_back("BYTE_HELPERS"); @@ -790,12 +789,35 @@ class ggml_webgpu_shader_lib { } } + // src1 type (vector) + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat_vec shader"); + } + // VEC/SCALAR controls defines.push_back(key.vectorized ? "VEC" : "SCALAR"); uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; - uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K; - uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K; + uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; + + if (key.src0_type >= GGML_TYPE_Q2_K) { + tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K; + outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q4_0) { + tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K; + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); @@ -1061,10 +1083,10 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_binary_pipeline_key key = { - .type = context.dst->type, - .op = context.dst->op, - .inplace = context.inplace, - .overlap = context.overlap, + .type = context.dst->type, + .op = context.dst->op, + .inplace = context.inplace, + .overlap = context.overlap, .src_overlap = context.src_overlap, }; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b2ef2d59010..ccc34cb153f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,7 +8,6 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-webgpu-shader-lib.hpp" -#include "pre_wgsl.hpp" #ifdef __EMSCRIPTEN__ # include @@ -20,12 +19,18 @@ #include #include #include -#include +#ifdef GGML_WEBGPU_GPU_PROFILE +# include +#endif +#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE) +# include +#endif #include #include #include #include #include +#include #include #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1)) @@ -70,22 +75,21 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim #endif // GGML_WEBGPU_CPU_PROFILE #ifdef GGML_WEBGPU_GPU_PROFILE -# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24 +# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32 # define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps #endif /* Constants */ -#define WEBGPU_NUM_PARAM_BUFS 48u -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u +#define WEBGPU_NUM_PARAM_BUFS 96u +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 // Maximum number of in-flight submissions per-thread, to avoid exhausting the // parameter buffer pool -#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE +#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters -#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 -#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 +#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 // For operations which process a row in parallel, this seems like a reasonable // default @@ -118,14 +122,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, wgpu::BufferUsage usage, const char * label); -struct webgpu_pool_bufs { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; -}; - // Holds a pool of parameter buffers for WebGPU operations struct webgpu_buf_pool { - std::vector free; + std::vector free; // The pool must be synchronized because // 1. The memset pool is shared globally by every ggml buffer, @@ -138,7 +137,6 @@ struct webgpu_buf_pool { size_t cur_pool_size; size_t max_pool_size; wgpu::Device device; - wgpu::BufferUsage host_buf_usage; wgpu::BufferUsage dev_buf_usage; size_t buf_size; bool should_grow; @@ -147,53 +145,47 @@ struct webgpu_buf_pool { int num_bufs, size_t buf_size, wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage, bool should_grow = false, size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) { - this->max_pool_size = max_pool_size; - this->cur_pool_size = num_bufs; - this->device = device; - this->host_buf_usage = host_buf_usage; - this->dev_buf_usage = dev_buf_usage; - this->buf_size = buf_size; - this->should_grow = should_grow; + this->max_pool_size = max_pool_size; + this->cur_pool_size = num_bufs; + this->device = device; + this->dev_buf_usage = dev_buf_usage; + this->buf_size = buf_size; + this->should_grow = should_grow; for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer host_buf; wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - free.push_back({ host_buf, dev_buf }); + free.push_back(dev_buf); } } - webgpu_pool_bufs alloc_bufs() { + wgpu::Buffer alloc_bufs() { std::unique_lock lock(mutex); if (!free.empty()) { - webgpu_pool_bufs bufs = free.back(); + wgpu::Buffer buf = free.back(); free.pop_back(); - return bufs; + return buf; } // Try growing the pool if no free buffers if (free.empty() && cur_pool_size < max_pool_size && should_grow) { cur_pool_size++; - wgpu::Buffer host_buf; wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - if (!(host_buf && dev_buf)) { + if (!dev_buf) { GGML_ABORT("webgpu_buf_pool: failed to allocate buffers"); } - return webgpu_pool_bufs{ host_buf, dev_buf }; + return dev_buf; } cv.wait(lock, [this] { return !free.empty(); }); - webgpu_pool_bufs bufs = free.back(); + wgpu::Buffer buf = free.back(); free.pop_back(); - return bufs; + return buf; } - void free_bufs(std::vector bufs) { + void free_bufs(std::vector bufs) { std::lock_guard lock(mutex); free.insert(free.end(), bufs.begin(), bufs.end()); cv.notify_all(); @@ -201,12 +193,9 @@ struct webgpu_buf_pool { void cleanup() { std::lock_guard lock(mutex); - for (auto & bufs : free) { - if (bufs.host_buf) { - bufs.host_buf.Destroy(); - } - if (bufs.dev_buf) { - bufs.dev_buf.Destroy(); + for (auto & buf : free) { + if (buf) { + buf.Destroy(); } } free.clear(); @@ -280,10 +269,9 @@ struct webgpu_gpu_profile_buf_pool { #endif struct webgpu_command { - uint32_t num_kernels; - wgpu::CommandBuffer commands; - std::vector params_bufs; - std::optional set_rows_error_bufs; + uint32_t num_kernels; + wgpu::CommandBuffer commands; + std::vector params_bufs; #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs timestamp_query_bufs; std::string pipeline_name; @@ -358,6 +346,13 @@ struct webgpu_global_context_struct { typedef std::shared_ptr webgpu_global_context; +struct webgpu_submission { + wgpu::FutureWaitInfo submit_done; +#ifdef GGML_WEBGPU_GPU_PROFILE + std::vector profile_futures; +#endif +}; + // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { // Points to global instances owned by ggml_backend_webgpu_reg_context @@ -366,7 +361,8 @@ struct webgpu_context_struct { std::unique_ptr shader_lib; webgpu_buf_pool param_buf_pool; - webgpu_buf_pool set_rows_error_buf_pool; + wgpu::Buffer set_rows_dev_error_buf; + wgpu::Buffer set_rows_host_error_buf; std::map> cpy_pipelines; // src_type, dst_type @@ -458,67 +454,105 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** End WebGPU object initializations */ /** WebGPU Actions */ -static void erase_completed(std::vector & futures) { + +static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) { + switch (status) { + case wgpu::WaitStatus::Success: + return true; + case wgpu::WaitStatus::TimedOut: + if (allow_timeout) { + return false; + } + GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n"); + return false; + case wgpu::WaitStatus::Error: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); + return false; + default: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); + return false; + } +} + +#ifdef GGML_WEBGPU_GPU_PROFILE +static void ggml_backend_webgpu_erase_completed_futures(std::vector & futures) { futures.erase(std::remove_if(futures.begin(), futures.end(), [](const wgpu::FutureWaitInfo & info) { return info.completed; }), futures.end()); } -// Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, - std::vector & futures, - bool block = true) { - // If we have too many in-flight submissions, wait on the oldest one first. +static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx, + std::vector & futures, + bool block) { if (futures.empty()) { return; } + uint64_t timeout_ms = block ? UINT64_MAX : 0; - while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { - auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX); - if (waitStatus == wgpu::WaitStatus::Error) { - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); + if (block) { + while (!futures.empty()) { + auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); + if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { + ggml_backend_webgpu_erase_completed_futures(futures); + } } - if (futures[0].completed) { - futures.erase(futures.begin()); + } else { + auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); + if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { + ggml_backend_webgpu_erase_completed_futures(futures); } } +} +#endif - if (futures.empty()) { +// Wait for the queue to finish processing all submitted work +static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, + std::vector & subs, + bool block = true) { + // If we have too many in-flight submissions, wait on the oldest one first. + if (subs.empty()) { + return; + } + while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { + auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX); + if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true); +#endif + subs.erase(subs.begin()); + } + } + + if (subs.empty()) { return; } if (block) { - while (!futures.empty()) { - auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); - switch (waitStatus) { - case wgpu::WaitStatus::Success: - // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished. - erase_completed(futures); - break; - case wgpu::WaitStatus::Error: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); - break; - default: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); - break; + for (auto & sub : subs) { + while (!sub.submit_done.completed) { + auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX); + ggml_backend_webgpu_handle_wait_status(waitStatus); } +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true); +#endif } + subs.clear(); } else { - // Poll once and return - auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); - switch (waitStatus) { - case wgpu::WaitStatus::Success: - // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished. - erase_completed(futures); - break; - case wgpu::WaitStatus::TimedOut: - break; - case wgpu::WaitStatus::Error: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); - break; - default: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); - break; + // Poll each submit future once and remove completed submissions. + for (auto sub = subs.begin(); sub != subs.end();) { + auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0); + ggml_backend_webgpu_handle_wait_status(waitStatus, true); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false); + if (sub->submit_done.completed && sub->profile_futures.empty()) { +#else + if (sub->submit_done.completed) { +#endif + sub = subs.erase(sub); + } else { + ++sub; + } } } } @@ -554,14 +588,12 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { } #endif -static std::vector ggml_backend_webgpu_submit( - webgpu_global_context ctx, - std::vector commands, - webgpu_buf_pool & param_buf_pool, - webgpu_buf_pool * set_rows_error_buf_pool = nullptr) { +static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx, + std::vector & commands, + webgpu_buf_pool & param_buf_pool) { std::vector command_buffers; - std::vector params_bufs; - std::vector set_rows_error_bufs; + std::vector params_bufs; + webgpu_submission submission; #ifdef GGML_WEBGPU_GPU_PROFILE std::vector> pipeline_name_and_ts_bufs; #endif @@ -569,14 +601,9 @@ static std::vector ggml_backend_webgpu_submit( for (const auto & command : commands) { command_buffers.push_back(command.commands); params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end()); - if (command.set_rows_error_bufs) { - set_rows_error_bufs.push_back(command.set_rows_error_bufs.value()); - } } ctx->queue.Submit(command_buffers.size(), command_buffers.data()); - std::vector futures; - wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( wgpu::CallbackMode::AllowSpontaneous, [¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { @@ -586,27 +613,7 @@ static std::vector ggml_backend_webgpu_submit( // Free the staged buffers param_buf_pool.free_bufs(params_bufs); }); - futures.push_back({ p_f }); - - for (const auto & bufs : set_rows_error_bufs) { - wgpu::Future f = bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, - [set_rows_error_buf_pool, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); - } else { - const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange(); - if (*error_data) { - GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); - } - // We can't unmap in here due to WebGPU reentrancy limitations. - if (set_rows_error_buf_pool) { - set_rows_error_buf_pool->free_bufs({ bufs }); - } - } - }); - futures.push_back({ f }); - } + submission.submit_done = { p_f }; #ifdef GGML_WEBGPU_GPU_PROFILE for (const auto & command : commands) { @@ -623,14 +630,14 @@ static std::vector ggml_backend_webgpu_submit( // WebGPU timestamps are in ns; convert to ms double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6; ctx->shader_gpu_time_ms[label] += elapsed_ms; - // We can't unmap in here due to WebGPU reentrancy limitations. - ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); } + // We can't unmap in here due to WebGPU reentrancy limitations. + ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); }); - futures.push_back({ f }); + submission.profile_futures.push_back({ f }); } #endif - return futures; + return submission; } static webgpu_command ggml_backend_webgpu_build_multi( @@ -639,32 +646,21 @@ static webgpu_command ggml_backend_webgpu_build_multi( const std::vector & pipelines, const std::vector> & params_list, const std::vector> & bind_group_entries_list, - const std::vector> & workgroups_list, - const std::optional & set_rows_error_bufs = std::nullopt) { + const std::vector> & workgroups_list) { GGML_ASSERT(pipelines.size() == params_list.size()); GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); GGML_ASSERT(pipelines.size() == workgroups_list.size()); - std::vector params_bufs_list; - std::vector bind_groups; + std::vector params_bufs_list; + std::vector bind_groups; for (size_t i = 0; i < pipelines.size(); i++) { - webgpu_pool_bufs params_bufs = param_buf_pool.alloc_bufs(); - - ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, - params_bufs.host_buf.GetSize()); - uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange(); - for (size_t j = 0; j < params_list[i].size(); j++) { - _params[j] = params_list[i][j]; - } - params_bufs.host_buf.Unmap(); + wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs(); std::vector entries = bind_group_entries_list[i]; uint32_t params_binding_num = entries.size(); - entries.push_back({ .binding = params_binding_num, - .buffer = params_bufs.dev_buf, - .offset = 0, - .size = params_bufs.dev_buf.GetSize() }); + entries.push_back( + { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() }); wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); @@ -677,15 +673,8 @@ static webgpu_command ggml_backend_webgpu_build_multi( } wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); - for (const auto & params_bufs : params_bufs_list) { - encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); - } - - // If there are SET_ROWS operations in this submission, copy their error - // buffers to the host. - if (set_rows_error_bufs) { - encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, - set_rows_error_bufs->host_buf.GetSize()); + for (size_t i = 0; i < params_bufs_list.size(); i++) { + ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); } #ifdef GGML_WEBGPU_GPU_PROFILE @@ -718,7 +707,6 @@ static webgpu_command ggml_backend_webgpu_build_multi( webgpu_command result = {}; result.commands = commands; result.params_bufs = params_bufs_list; - result.set_rows_error_bufs = set_rows_error_bufs; result.num_kernels = pipelines.size(); #ifdef GGML_WEBGPU_GPU_PROFILE result.timestamp_query_bufs = ts_bufs; @@ -734,13 +722,13 @@ static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & std::vector params, std::vector bind_group_entries, uint32_t wg_x, - uint32_t wg_y = 1, - std::optional set_rows_error_bufs = std::nullopt) { + uint32_t wg_y = 1) { return ggml_backend_webgpu_build_multi(ctx, param_buf_pool, { pipeline }, - { params }, { bind_group_entries }, { { wg_x, wg_y } }, set_rows_error_bufs); + { std::move(params) }, { std::move(bind_group_entries) }, + { { wg_x, wg_y } }); } static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, @@ -757,8 +745,9 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); - auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool); - ggml_backend_webgpu_wait(ctx, futures); + std::vector commands = { command }; + std::vector sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) }; + ggml_backend_webgpu_wait(ctx, sub); } /** End WebGPU Actions */ @@ -805,7 +794,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { std::cout << "\nggml_webgpu: gpu breakdown:\n"; for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0; - std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; + std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2) + << pct << "%)\n"; } #endif @@ -978,14 +968,6 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, auto * decisions = static_cast(pipeline.context.get()); - std::optional error_bufs = std::nullopt; - if (decisions->i64_idx) { - error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs(); - if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { - error_bufs->host_buf.Unmap(); - } - } - std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), @@ -1018,8 +1000,10 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, }; if (decisions->i64_idx) { - entries.push_back( - { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() }); + entries.push_back({ .binding = 3, + .buffer = ctx->set_rows_dev_error_buf, + .offset = 0, + .size = ctx->set_rows_dev_error_buf.GetSize() }); } uint32_t threads; @@ -1029,8 +1013,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1, - error_bufs); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1); } // Workgroup size is a common constant @@ -1108,12 +1091,26 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, use_fast = (src0->type == GGML_TYPE_F16); break; case GGML_TYPE_F32: + // TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_1: + case GGML_TYPE_Q6_K: use_fast = true; break; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + // we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat + use_fast = !is_vec; + break; default: break; } @@ -1187,17 +1184,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; if (use_fast && is_vec) { - auto decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); uint32_t total_wg = output_groups * batches; compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } else if (use_fast) { - auto decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); // Fast-path tiled/subgroup calculations - uint32_t wg_m, wg_n; + uint32_t wg_m; + uint32_t wg_n; if (decisions->use_subgroup_matrix) { uint32_t wg_m_sg_tile = decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m; @@ -1215,7 +1213,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } else { // legacy - auto decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_size = decisions->wg_size; uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); @@ -1514,10 +1512,10 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, } static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { - uint32_t ne = (uint32_t) ggml_nelements(dst); + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + uint32_t ne = (uint32_t) ggml_nelements(dst); uint32_t dim = (uint32_t) dst->op_params[0]; std::vector params = { @@ -1538,28 +1536,22 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], dim, - (uint32_t)src0->ne[dim] + (uint32_t) src0->ne[dim] }; std::vector entries = { - { - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) - }, - { - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) - }, - { - .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) - } + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -1569,9 +1561,9 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; - webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1623,7 +1615,12 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -2172,19 +2169,12 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_SOFT_MAX: return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); case GGML_OP_UNARY: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_CLAMP: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_FILL: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_LOG: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SQR: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SQRT: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SIN: - return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_COS: return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_PAD: @@ -2192,7 +2182,6 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_ARGMAX: return ggml_webgpu_argmax(ctx, src0, node); case GGML_OP_ARGSORT: - return ggml_webgpu_argsort(ctx, src0, node); case GGML_OP_TOP_K: // we reuse the same argsort implementation for top_k return ggml_webgpu_argsort(ctx, src0, node); @@ -2214,33 +2203,51 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); - std::vector commands; - std::vector futures; - uint32_t num_batched_kernels = 0; + std::vector commands; + std::vector subs; + uint32_t num_batched_kernels = 0; + bool contains_set_rows = false; + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { + contains_set_rows = true; + } if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { commands.push_back(*cmd); num_batched_kernels += cmd.value().num_kernels; } if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { - num_batched_kernels = 0; - std::vector compute_futures = ggml_backend_webgpu_submit( - ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool); - futures.insert(futures.end(), compute_futures.begin(), compute_futures.end()); + num_batched_kernels = 0; + subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); // Process events and check for completed submissions ctx->global_ctx->instance.ProcessEvents(); - ggml_backend_webgpu_wait(ctx->global_ctx, futures, false); + ggml_backend_webgpu_wait(ctx->global_ctx, subs, false); commands.clear(); } } if (!commands.empty()) { - auto new_futures = - ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool); - futures.insert(futures.end(), new_futures.begin(), new_futures.end()); + subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); + commands.clear(); + } + + // If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking. + if (contains_set_rows) { + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, + ctx->set_rows_host_error_buf.GetSize()); + wgpu::CommandBuffer set_rows_commands = encoder.Finish(); + ctx->global_ctx->queue.Submit(1, &set_rows_commands); + ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, + ctx->set_rows_host_error_buf.GetSize()); + const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); + if (*error_data) { + GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); + } + ctx->set_rows_host_error_buf.Unmap(); } - ggml_backend_webgpu_wait(ctx->global_ctx, futures); + ggml_backend_webgpu_wait(ctx->global_ctx, subs); WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -2859,10 +2866,12 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true); - webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, - WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 5c1074ebc10..de60ebbcf2b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -11,7 +11,7 @@ fn store_shmem(val: vec4, idx: u32) { shmem[idx + 2] = val.z; shmem[idx + 3] = val.w; } -#endif +#endif // VEC #ifdef SCALAR #define VEC_SIZE 1 @@ -23,7 +23,7 @@ fn store_shmem(val: vec4, idx: u32) { fn store_shmem(val: f16, idx: u32) { shmem[idx] = val; } -#endif +#endif // SCALAR #ifdef INIT_SRC0_SHMEM_FLOAT fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { @@ -40,7 +40,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 store_shmem(SHMEM_TYPE(src0_val), elem_idx); } } -#endif +#endif // INIT_SRC0_SHMEM_FLOAT #ifdef INIT_SRC1_SHMEM_FLOAT fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { @@ -57,7 +57,7 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3 store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx); } } -#endif +#endif // INIT_SRC1_SHMEM_FLOAT #ifdef INIT_SRC0_SHMEM_Q4_0 const BLOCK_SIZE = 32u; @@ -100,4 +100,667 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } } -#endif +#endif // INIT_SRC0_SHMEM_Q4_0 + +#ifdef INIT_SRC0_SHMEM_Q4_1 +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. +override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + let d = src0[scale_idx]; + let m = src0[scale_idx + 1u]; + + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 2u + block_offset + j]; + let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; + + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_lo = f16(q_byte & 0xF) * d + m; + let q_hi = f16((q_byte >> 4) & 0xF) * d + m; + shmem[shmem_idx + j * 2 + k] = q_lo; + shmem[shmem_idx + j * 2 + k + 16u] = q_hi; + } + } + } + } +} +#endif // INIT_SRC0_SHMEM_Q4_1 + +#ifdef INIT_SRC0_SHMEM_Q5_0 +// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. +// tile_k is defined as 32u, so blocks_k ends up being 1 always +override BLOCKS_K = TILE_K / BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx]; + let qh0 = src0[scale_idx + 1u]; + let qh1 = src0[scale_idx + 2u]; + let qh_packed = bitcast(vec2(qh0, qh1)); + + for (var j = 0u; j < 2; j++) { + let q_0 = src0[scale_idx + 3u + block_offset + (j*2)]; + let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u]; + + let q_packed = bitcast(vec2(q_0, q_1)); + + let j_adjusted = j + (block_offset / 2u); + + + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + + let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; + let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; + let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d; + + shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight + shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight + } + } + } + } +} +#endif // INIT_SRC0_SHMEM_Q5_0 + +#ifdef INIT_SRC0_SHMEM_Q5_1 +// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. +// tile_k is defined as 32u, so blocks_k ends up being 1 always +override BLOCKS_K = TILE_K / BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx]; + let m = src0[scale_idx + 1u]; + let qh0 = src0[scale_idx + 2u]; + let qh1 = src0[scale_idx + 3u]; + let qh_packed = bitcast(vec2(qh0, qh1)); + + for (var j = 0u; j < 2; j++) { + + let q_0 = src0[scale_idx + 4u + block_offset + (j*2)]; + let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u]; + + let q_packed = bitcast(vec2(q_0, q_1)); + + let j_adjusted = j + (block_offset / 2u); + + + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + + let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; + let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m; + let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; + let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m; + + shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight + shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight + } + } + } + } +} +#endif // INIT_SRC0_SHMEM_Q5_1 + +#ifdef INIT_SRC0_SHMEM_Q8_0 +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. +override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights +const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + let d = src0[scale_idx]; + + for (var j = 0u; j < F16_PER_THREAD; j+=2) { + let q_0 = src0[scale_idx + 1u + block_offset + j]; + let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; + + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + + let q_val = f16(q_byte) * d; + shmem[shmem_idx + j * 2 + k] = q_val; + } + } + } + } +} +#endif // INIT_SRC0_SHMEM_Q8_0 + +#ifdef INIT_SRC0_SHMEM_Q8_1 +const BLOCK_SIZE = 32u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. +override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights +const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + let d = src0[scale_idx]; + let m = src0[scale_idx + 1u]; + + for (var j = 0u; j < F16_PER_THREAD; j+=2) { + let q_0 = src0[scale_idx + 2u + block_offset + j]; + let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; + + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + + let q_val = f16(q_byte) * d + m; + shmem[shmem_idx + j * 2 + k] = q_val; + } + } + } + } +} +#endif // INIT_SRC0_SHMEM_Q8_1 + +#ifdef INIT_SRC0_SHMEM_Q2_K +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 42u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + // Use standard thread layout instead of lane/row_group + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx + 40u]; + let dmin = src0[scale_idx + 41u]; + + // Decode the element at position k_in_block + let block_of_32 = k_in_block / 32u; + let pos_in_32 = k_in_block % 32u; + + let q_b_idx = (block_of_32 / 4u) * 32u; + let shift = (block_of_32 % 4u) * 2u; + let k = (pos_in_32 / 16u) * 16u; + let l = pos_in_32 % 16u; + + let is = k_in_block / 16u; + + let sc_0 = src0[scale_idx + 2u * (is / 4u)]; + let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u]; + let sc_packed = bitcast(vec2(sc_0, sc_1)); + let sc = get_byte(sc_packed, is % 4u); + + let dl = d * f16(sc & 0xFu); + let ml = dmin * f16(sc >> 4u); + + let q_idx = q_b_idx + k + l; + let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; + let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u]; + let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte = get_byte(q_packed, q_idx % 4u); + let qs_val = (q_byte >> shift) & 3u; + + let q_val = f16(qs_val) * dl - ml; + shmem[elem_idx] = q_val; + } +} +#endif // INIT_SRC0_SHMEM_Q2_K + +#ifdef INIT_SRC0_SHMEM_Q3_K +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 55u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx + 54u]; + + // Load and unpack scales + let kmask1: u32 = 0x03030303u; + let kmask2: u32 = 0x0f0f0f0fu; + + var scale_vals: array; + for (var i: u32 = 0u; i < 4u; i++) { + let scale_0 = src0[scale_idx + 48u + (2u*i)]; + let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u]; + scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + } + + var tmp: u32 = scale_vals[2]; + scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u); + scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u); + scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u); + scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u); + + // Load hmask and qs arrays + var hmask_vals: array; + for (var i: u32 = 0u; i < 8u; i++) { + let hmask_0 = src0[scale_idx + (2u*i)]; + let hmask_1 = src0[scale_idx + (2u*i) + 1u]; + hmask_vals[i] = bitcast(vec2(hmask_0, hmask_1)); + } + + var qs_vals: array; + for (var i: u32 = 0u; i < 16u; i++) { + let qs_0 = src0[scale_idx + 16u + (2u*i)]; + let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u]; + qs_vals[i] = bitcast(vec2(qs_0, qs_1)); + } + + let half = k_in_block / 128u; // 0 or 1 + let pos_in_half = k_in_block % 128u; // 0-127 + let shift_group = pos_in_half / 32u; // 0-3 + let pos_in_32 = pos_in_half % 32u; // 0-31 + let k_group = pos_in_32 / 16u; // 0 or 1 + let l = pos_in_32 % 16u; // 0-15 + + let q_b_idx = half * 32u; // 0 or 32 + let shift = shift_group * 2u; // 0, 2, 4, 6 + let k = k_group * 16u; // 0 or 16 + let is = k_in_block / 16u; // 0-15 + + // m increments every 32 elements across entire 256 element block + let m_shift = k_in_block / 32u; // 0-7 + let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128 + + let sc = get_byte(scale_vals[is / 4u], is % 4u); + let dl = d * (f16(sc) - 32.0); + + let q_idx = q_b_idx + k + l; + let hm_idx = k + l; + + let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u); + let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u); + + let hm = select(4.0, 0.0, (hmask_byte & m) != 0); + let qs_val = (q_byte >> shift) & 3u; + + let q_val = (f16(qs_val) - f16(hm)) * dl; + shmem[elem_idx] = q_val; + } +} + +#endif // INIT_SRC0_SHMEM_Q3_K + +#ifdef INIT_SRC0_SHMEM_Q4_K +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 72u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx]; + let dmin = src0[scale_idx + 1u]; + + // Load packed scales + var scale_vals: array; + for (var i: u32 = 0u; i < 3u; i++) { + let scale_0 = src0[scale_idx + 2u + (2u*i)]; + let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u]; + scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + } + + // Map k_in_block to loop structure: + // Outer loop over 64-element groups (alternating q_b_idx) + // Inner loop over 2 shifts per group + let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx) + let pos_in_64 = k_in_block % 64u; // 0-63 + let shift_group = pos_in_64 / 32u; // 0 or 1 + let l = pos_in_64 % 32u; // 0-31 + + let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96 + let shift = shift_group * 4u; // 0 or 4 + let is = k_in_block / 32u; // 0-7 + + var sc: u32; + var mn: u32; + + if (is < 4u) { + let sc_byte = get_byte(scale_vals[is / 4u], is % 4u); + let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; + } else { + let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u); + let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u); + let min_hi = get_byte(scale_vals[is / 4u], is % 4u); + + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + } + + let dl = d * f16(sc); + let ml = dmin * f16(mn); + + let q_idx = q_b_idx + l; + let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; + let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u]; + let q_packed = bitcast(vec2(q_0, q_1)); + + let q_byte = get_byte(q_packed, q_idx % 4u); + let qs_val = (q_byte >> shift) & 0xFu; + + let q_val = f16(qs_val) * dl - ml; + shmem[elem_idx] = q_val; + } +} +#endif // INIT_SRC0_SHMEM_Q4_K + +#ifdef INIT_SRC0_SHMEM_Q5_K +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 88u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let d = src0[scale_idx]; + let dmin = src0[scale_idx + 1u]; + + // Load packed scales + var scale_vals: array; + for (var i: u32 = 0u; i < 3u; i++) { + let scale_0 = src0[scale_idx + 2u + (2u*i)]; + let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u]; + scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + } + + // The original loop processes elements in groups of 64 + // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4] + // But u increments EVERY 32 elements (after each l loop) + let group_of_64 = k_in_block / 64u; // 0-3 + let pos_in_64 = k_in_block % 64u; // 0-63 + let shift_group = pos_in_64 / 32u; // 0 or 1 + let l = pos_in_64 % 32u; // 0-31 + + let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96 + let shift = shift_group * 4u; // 0 or 4 + let is = k_in_block / 32u; // 0-7 + + // u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128) + let u_shift = k_in_block / 32u; // 0-7 + let u: u32 = 1u << u_shift; + + var sc: u32; + var mn: u32; + + if (is < 4u) { + let sc_byte = get_byte(scale_vals[is / 4u], is % 4u); + let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; + } else { + let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u); + let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u); + let min_hi = get_byte(scale_vals[is / 4u], is % 4u); + + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + } + + let dl = d * f16(sc); + let ml = dmin * f16(mn); + + let q_idx = q_b_idx + l; + let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)]; + let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u]; + let q_packed = bitcast(vec2(q_0, q_1)); + + let q_byte = get_byte(q_packed, q_idx % 4u); + + let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)]; + let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u]; + let qh_packed = bitcast(vec2(qh_0, qh_1)); + + let qh_byte = get_byte(qh_packed, l % 4u); + + let qs_val = (q_byte >> shift) & 0xFu; + let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); + + let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml; + shmem[elem_idx] = q_val; + } +} + +#endif // INIT_SRC0_SHMEM_Q5_K + +#ifdef INIT_SRC0_SHMEM_Q6_K +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 105u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let scale_idx = src0_idx * F16_PER_BLOCK; + + let half = k_in_block / 128u; + let pos_in_half = k_in_block % 128u; + let quarter = pos_in_half / 32u; + let l = pos_in_half % 32u; + + let ql_b_idx = half * 64u; + let qh_b_idx = half * 32u; + let sc_b_idx = half * 8u; + + // Load only ql13 word needed + let ql13_flat = ql_b_idx + l; + let ql13_word = ql13_flat / 4u; + let ql13 = bitcast(vec2( + src0[scale_idx + 2u * ql13_word], + src0[scale_idx + 2u * ql13_word + 1u] + )); + let ql13_b = get_byte(ql13, ql13_flat % 4u); + + // Load only ql24 word needed + let ql24_flat = ql_b_idx + l + 32u; + let ql24_word = ql24_flat / 4u; + let ql24 = bitcast(vec2( + src0[scale_idx + 2u * ql24_word], + src0[scale_idx + 2u * ql24_word + 1u] + )); + let ql24_b = get_byte(ql24, ql24_flat % 4u); + + // Load only qh word needed + let qh_flat = qh_b_idx + l; + let qh_word = qh_flat / 4u; + let qh = bitcast(vec2( + src0[scale_idx + 64u + 2u * qh_word], + src0[scale_idx + 64u + 2u * qh_word + 1u] + )); + let qh_b = get_byte(qh, qh_flat % 4u); + + let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); + let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0); + let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0); + let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0); + + // Load only the scale word needed + let is = l / 16u; + let sc_idx = sc_b_idx + is + quarter * 2u; + let sc_word = sc_idx / 4u; + let sc = bitcast(vec2( + src0[scale_idx + 96u + 2u * sc_word], + src0[scale_idx + 96u + 2u * sc_word + 1u] + )); + let sc_val = get_byte_i32(sc, sc_idx % 4u); + + let d = src0[scale_idx + 104u]; + + var q_val: f16; + if (quarter == 0u) { + q_val = q1; + } else if (quarter == 1u) { + q_val = q2; + } else if (quarter == 2u) { + q_val = q3; + } else { + q_val = q4; + } + + shmem[elem_idx] = d * f16(sc_val) * q_val; + } +} +#endif // INIT_SRC0_SHMEM_Q6_K diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl index 761e3017c14..b1da421a691 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -50,6 +50,7 @@ fn get_local_m(thread_id: u32) -> u32 { const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; + var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index f9ea95e07b9..94f4bae11f4 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -1,4 +1,3 @@ - enable f16; #include "common_decls.tmpl" @@ -84,6 +83,294 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { } #endif +#ifdef MUL_ACC_Q4_1 + +const BLOCK_SIZE = 32; +const NQ = 16u; // number of weights per thread +const F16_PER_BLOCK = 10u; +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let d = f32(src0[scale_idx]); + let m = f32(src0[scale_idx + 1u]); + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 2u + block_offset + j]; + let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32((q_byte >> 4) & 0xF) * d + m; + let q_lo = f32(q_byte & 0xF) * d + m; + local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; + local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + } + } + } + return local_sum; +} +#endif + +#ifdef MUL_ACC_Q5_0 + +const BLOCK_SIZE = 32; +const NQ = 16u; // number of weights per thread +const F16_PER_BLOCK = 11u; +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let d = f32(src0[scale_idx]); + let qh0 = src0[scale_idx + 1u]; + let qh1 = src0[scale_idx + 2u]; + let qh_packed = bitcast(vec2(qh0, qh1)); + + for (var j = 0u; j < 2; j++) { + let q_0 = src0[scale_idx + 3u + block_offset + (j*2)]; + let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u]; + let q_packed = bitcast(vec2(q_0, q_1)); + + let j_adjusted = j + (block_offset / 2u); + + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + + let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; + let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; + let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; + + local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; + local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; + } + + } + } + return local_sum; +} +#endif + + +#ifdef MUL_ACC_Q5_1 + +const BLOCK_SIZE = 32; +const NQ = 16u; // number of weights per thread +const F16_PER_BLOCK = 12u; +const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let d = f32(src0[scale_idx]); + let m = src0[scale_idx + 1u]; + let qh0 = src0[scale_idx + 2u]; + let qh1 = src0[scale_idx + 3u]; + let qh_packed = bitcast(vec2(qh0, qh1)); + + for (var j = 0u; j < 2; j++) { + let q_0 = src0[scale_idx + 4u + block_offset + (j*2)]; + let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u]; + let q_packed = bitcast(vec2(q_0, q_1)); + + let j_adjusted = j + (block_offset / 2u); + + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + + let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; + let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m); + let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; + let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m); + + local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; + local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; + } + + } + } + return local_sum; +} +#endif + + +#ifdef MUL_ACC_Q8_0 + +const BLOCK_SIZE = 32; +const NQ = 16u; // number of weights per thread +const F16_PER_BLOCK = 17u; +const WEIGHTS_PER_F16 = 2u; +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let d = f32(src0[scale_idx]); + + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 1 + block_offset + j]; + let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f32(q_byte) * d; + local_sum += q_val * shared_vector[shmem_idx + j * 2 + k]; + } + } + } + return local_sum; +} +#endif + + +#ifdef MUL_ACC_Q8_1 + +const BLOCK_SIZE = 32; +const NQ = 16u; // number of weights per thread +const F16_PER_BLOCK = 18u; +const WEIGHTS_PER_F16 = 2u; +const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; + +fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + var local_sum = 0.0; + for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; + let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let d = f32(src0[scale_idx]); + let m = src0[scale_idx + 1u]; + + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = src0[scale_idx + 2u + block_offset + j]; + let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f32(q_byte) * d + f32(m); + local_sum += q_val * shared_vector[shmem_idx + j * 2 + k]; + } + } + } + return local_sum; +} +#endif + +#ifdef MUL_ACC_Q6_K + +const BLOCK_SIZE = 256u; +const F16_PER_BLOCK = 105u; + +fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 { + let aligned = byte_offset & ~3u; + let idx = bbase + aligned / 2u; + return bitcast(vec2(src0[idx], src0[idx + 1u])); +} + +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} + +fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { + let tid = tig / 2u; + let ix = tig % 2u; + let ip = tid / 8u; + let il = tid % 8u; + let l0 = 4u * il; + let is = 8u * ip + l0 / 16u; + + let y_offset = 128u * ip + l0; + let q_offset_l = 64u * ip + l0; + let q_offset_h = 32u * ip + l0; + + let nb = tile_size / BLOCK_SIZE; + let k_block_start = k_outer / BLOCK_SIZE; + + // Aligned scale byte position (is can be odd) + let sc_base_byte = 192u + (is & ~3u); + let sc_byte_pos = is & 3u; + + var local_sum = 0.0; + + for (var i = ix; i < nb; i += 2u) { + let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK; + + let d_raw = load_u32_at(bbase, 208u); + let d = f32(bitcast>(d_raw)[0]); + + let ql1_u32 = load_u32_at(bbase, q_offset_l); + let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u); + let qh_u32 = load_u32_at(bbase, 128u + q_offset_h); + let sc_u32_0 = load_u32_at(bbase, sc_base_byte); + let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u); + + let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); + let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); + let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); + let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); + + var sums = vec4(0.0, 0.0, 0.0, 0.0); + + for (var l = 0u; l < 4u; l++) { + let y_base = i * BLOCK_SIZE + y_offset + l; + let yl0 = f32(shared_vector[y_base]); + let yl1 = f32(shared_vector[y_base + 32u]); + let yl2 = f32(shared_vector[y_base + 64u]); + let yl3 = f32(shared_vector[y_base + 96u]); + + let q1b = byte_of(ql1_u32, l); + let q2b = byte_of(ql2_u32, l); + let qhb = byte_of(qh_u32, l); + + let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); + let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); + let dq2 = f32(i32((q1b >> 4u) | ((qhb & 0x30u) )) - 32); + let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); + + sums[0] += yl0 * dq0; + sums[1] += yl1 * dq1; + sums[2] += yl2 * dq2; + sums[3] += yl3 * dq3; + } + + local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + + sums[2] * f32(sc4) + sums[3] * f32(sc6)); + } + + return local_sum; +} +#endif + struct MulMatParams { offset_src0: u32, offset_src1: u32, @@ -191,4 +478,3 @@ fn main( dst[dst_idx / VEC_SIZE] = store_val(group_base); } } - From 1e05b10d67f5833d7334e96ca7e527992555e71a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 10 Mar 2026 21:36:57 +0200 Subject: [PATCH 242/831] ggml : bump RPC version (llama/20330) --- ggml/include/ggml-rpc.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index df1ad2a5168..1c11495b66e 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -8,7 +8,12 @@ extern "C" { #define RPC_PROTO_MAJOR_VERSION 3 #define RPC_PROTO_MINOR_VERSION 6 -#define RPC_PROTO_PATCH_VERSION 0 +#define RPC_PROTO_PATCH_VERSION 1 + +#ifdef __cplusplus +static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); +#endif + #define GGML_RPC_MAX_SERVERS 16 // backend API From 72c7a2532db3a5479f68747dc54c90a45a5ef828 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Wed, 11 Mar 2026 09:53:05 +0800 Subject: [PATCH 243/831] fix for failed UT case: ACC, L2_NORM, UPSCALE, fused_glu, unary (llama/20283) --- ggml/src/ggml-sycl/common.hpp | 91 ++++++++++++++++++++ ggml/src/ggml-sycl/element_wise.cpp | 113 ++++++++++++------------ ggml/src/ggml-sycl/ggml-sycl.cpp | 3 +- ggml/src/ggml-sycl/norm.cpp | 128 ++++++++++++++-------------- 4 files changed, 213 insertions(+), 122 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 298fddc1038..9f0efb65359 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -874,4 +874,95 @@ static bool fast_fp16_available(const int cc) { return true; //Intel GPUs always support FP16. } +enum class block_reduce_method { + MAX, + SUM, +}; + +template +struct block_reduce_policy; + +template +inline constexpr bool is_any = (std::is_same_v || ...); + +template +inline constexpr bool ggml_sycl_dependent_false_v = false; + +#define WARP_32_SIZE 32 + +template struct block_reduce_policy { + static T reduce(T val) { + if constexpr (is_any) { + return warp_reduce_sum(val); + } else { + static_assert(ggml_sycl_dependent_false_v, "Unsupported type for block reduce sum"); + } + } + + static T sentinel() { + if constexpr (std::is_same_v) { + return 0.0f; + } else if constexpr (std::is_same_v) { + return sycl::float2(0.0f, 0.0f); + } else if constexpr (std::is_same_v) { + return sycl::half2(0.0f, 0.0f); + } else if constexpr (std::is_same_v) { + return 0; + } else { + static_assert(ggml_sycl_dependent_false_v, "Unsupported type for block reduce sum"); + } + } +}; + +template struct block_reduce_policy { + static T reduce(T val) { + if constexpr (is_any) { + return warp_reduce_max(val); + } else { + static_assert(ggml_sycl_dependent_false_v, "Unsupported type for block reduce max"); + } + } + + static T sentinel() { + if constexpr (std::is_same_v) { + return -INFINITY; + } else if constexpr (std::is_same_v) { + return sycl::half2(-INFINITY, -INFINITY); + } else { + static_assert(ggml_sycl_dependent_false_v, "Unsupported type for block reduce max"); + } + } +}; + + +template +static T block_reduce(T val, T * shared_vals, int block_size_template) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + val = block_reduce_policy::reduce(val); + const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template; + const int nthreads = item_ct1.get_local_range(2); + const int nwarps = nthreads / WARP_SIZE; + + if (block_size > warp_size) { + assert((block_size <= 1024) && (block_size % warp_size) == 0); + const int warp_id = item_ct1.get_local_id(2) / warp_size; + const int lane_id = item_ct1.get_local_id(2) % warp_size; + if (lane_id == 0) { + shared_vals[warp_id] = val; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + + size_t nreduce = nwarps / WARP_SIZE; + float tmp = 0.f; + if (lane_id < (static_cast(block_size) / warp_size)) { + for (size_t i = 0; i < nreduce; i += 1) + { + tmp += shared_vals[lane_id + i * WARP_SIZE]; + } + } + return block_reduce_policy::reduce(tmp); + } + return val; +} + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 00d54b83f82..acd51bf45b2 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -9,23 +9,32 @@ #define SYCL_LOCAL_ID_CALC(ITEM, IDX) \ (ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX)) +static void acc_f32(const float * x, const float * y, float * dst, const int64_t ne, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i = SYCL_LOCAL_ID_CALC(item_ct1, 2); -static void acc_f32(const float * x, const float * y, float * dst, const int ne, - const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) { - const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0); if (i >= ne) { return; } - int src1_idx = i - offset; - int oz = src1_idx / nb2; - int oy = (src1_idx - (oz * nb2)) / nb1; - int ox = src1_idx % nb1; - if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) { - dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11]; - } else { - dst[i] = x[i]; + + int64_t src1_idx = i - offset; + + int64_t tmp = src1_idx; + const int64_t i13 = tmp / s13; + tmp -= i13 * s13; + const int64_t i12 = tmp / s12; + tmp -= i12 * s12; + const int64_t i11 = tmp / s11; + tmp -= i11 * s11; + const int64_t i10 = tmp; + + float val = x[i]; + if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) { + val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10]; } + dst[i] = val; } /* Unary OP funcs */ @@ -364,18 +373,15 @@ static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const namespace ggml_sycl_detail { static void acc_f32_sycl(const float *x, const float *y, float *dst, - const int n_elements, const int ne10, const int ne11, - const int ne12, const int nb1, const int nb2, - const int offset, queue_ptr stream) { - int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE); - stream->parallel_for( - sycl::nd_range<1>(sycl::range<1>(num_blocks) * - sycl::range<1>(SYCL_ACC_BLOCK_SIZE), - sycl::range<1>(SYCL_ACC_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { - acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, - item_ct1); - }); + const int64_t n_elements, const int64_t ne10, const int64_t ne11, + const int64_t ne12, const int64_t ne13, const int64_t s1, const int64_t s2, const int64_t s3, + const int64_t offset, queue_ptr stream) { + const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset); + }); } template @@ -402,25 +408,19 @@ static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, template static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { -#if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); switch (dst->type) { -#if defined (GGML_SYCL_F16) case GGML_TYPE_F16: { auto data_pts = cast_data(dst); kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward(args)...); break; } -#endif case GGML_TYPE_F32: { auto data_pts = cast_data(dst); @@ -434,14 +434,10 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, template static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { -#if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); const ggml_tensor * src0 = dst->src[0]; @@ -463,7 +459,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c GGML_ASSERT(src0->type == src1->type); } switch (dst->type) { -#if defined (GGML_SYCL_F16) case GGML_TYPE_F16: { sycl::half * src0_p = (sycl::half *) src0_d; @@ -484,7 +479,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c std::forward(args)...); break; } -#endif case GGML_TYPE_F32: { float * src0_p = (float *) src0_d; @@ -513,13 +507,9 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c template static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { -#if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif + GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); @@ -530,7 +520,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2]; const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3]; switch (dst->type) { -#if defined (GGML_SYCL_F16) case GGML_TYPE_F16: { auto data_pts = cast_data(dst); @@ -539,7 +528,6 @@ static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx main_stream, std::forward(args)...); break; } -#endif case GGML_TYPE_F32: { auto data_pts = cast_data(dst); @@ -868,22 +856,31 @@ static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tens } static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast(dst->src[0]->data); - const float * src1_dd = static_cast(dst->src[1]->data); - float * dst_dd = static_cast(dst->data); - int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 - int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 - // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused - int offset = dst->op_params[3] / 4; // offset in bytes + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(dst->nb[0] == ggml_element_size(dst)); + GGML_ASSERT(ggml_is_contiguously_allocated(dst)); + + const int64_t s1 = dst->op_params[0] / sizeof(float); + const int64_t s2 = dst->op_params[1] / sizeof(float); + const int64_t s3 = dst->op_params[2] / sizeof(float); + const int64_t offset = dst->op_params[3] / sizeof(float); - ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream); + ggml_sycl_detail::acc_f32_sycl(src0_d, src1_d, dst_d, ggml_nelements(dst), + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + s1, s2, s3, offset, stream); } static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index dfacde0af33..66dfc4532c0 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4872,8 +4872,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g k > 0 && k <= 32; } case GGML_OP_POOL_2D: - case GGML_OP_ACC: return true; + case GGML_OP_ACC: + return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); case GGML_OP_PAD: // TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985 if (ggml_get_op_params_i32(op, 8) != 0) { diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index 00702b5d09c..09fce1280ad 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -202,47 +202,34 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6 } } -static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps, - const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { - const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + - item_ct1.get_local_id(1); - const int tid = item_ct1.get_local_id(2); - const int nthreads = item_ct1.get_local_range(2); - const int nwarps = nthreads / WARP_SIZE; +template +static void l2_norm_f32(const float * x, float * dst, const int ncols, + const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps, + const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) { + const int nrows = item_ct1.get_group_range(2); + const int nchannels = item_ct1.get_group_range(1); + + const int row = item_ct1.get_group(2); + const int channel = item_ct1.get_group(1); + const int sample = item_ct1.get_group(0); + const int tid = item_ct1.get_local_id(2); + + x += sample*stride_sample + channel*stride_channel + row*stride_row; + dst += ((sample*nchannels + channel)*nrows + row)*ncols; + float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row * ncols + col]; + const float xi = x[col]; tmp += xi * xi; } - // sum up partial sums - tmp = warp_reduce_sum(tmp, item_ct1); - if (block_size > WARP_SIZE) { - - int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; - int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - /* - DPCT1118:3: SYCL group functions and algorithms must be encountered in - converged control flow. You may need to adjust the code. - */ - item_ct1.barrier(sycl::access::fence_space::local_space); - size_t nreduce = nwarps / WARP_SIZE; - tmp = 0.f; - for (size_t i = 0; i < nreduce; i += 1) - { - tmp += s_sum[lane_id + i * WARP_SIZE]; - } - tmp = warp_reduce_sum(tmp, item_ct1); - } - - const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps)); + tmp = block_reduce(tmp, s_sum, block_size); + const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps)); for (int col = tid; col < ncols; col += block_size) { - dst[row * ncols + col] = scale * x[row * ncols + col]; + dst[col] = scale * x[col]; } } @@ -369,42 +356,50 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const } } -static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols, - const int nrows, const float eps, - queue_ptr stream, int device) { - // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); +template +static void l2_norm_f32_sycl(const float * x, + float * dst, + const int ncols, + const int nrows, + const int nchannels, + const int nsamples, + const int64_t stride_row, + const int64_t stride_channel, + const int64_t stride_sample, + const float eps, + queue_ptr stream, + int device) { + const dpct::dim3 blocks_num(nrows, nchannels, nsamples); + if (ncols < 1024) { - const sycl::range<3> block_dims(1, 1, WARP_SIZE); + const dpct::dim3 block_dims(warp_size, 1, 1); stream->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + sycl::nd_range<3>(blocks_num * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - l2_norm_f32(x, dst, ncols, eps, item_ct1, - nullptr, WARP_SIZE); + [[sycl::reqd_sub_group_size(warp_size)]] { + l2_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, + nullptr, warp_size); }); }); } else { const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; - assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + assert(work_group_size % (warp_size * warp_size) == 0); const sycl::range<3> block_dims(1, 1, work_group_size); - /* - DPCT1049:19: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ + int lsm_size = block_dims[2] > warp_size ? work_group_size / warp_size * sizeof(float): 0; stream->submit([&](sycl::handler& cgh) { - sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), + sycl::local_accessor s_sum_acc_ct1(sycl::range<1>(lsm_size), cgh); + cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + sycl::nd_range<3>(blocks_num * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - l2_norm_f32(x, dst, ncols, eps, item_ct1, - get_pointer(s_sum_acc_ct1), work_group_size); + [[sycl::reqd_sub_group_size(warp_size)]] { + l2_norm_f32(x, dst, ncols, stride_row, stride_channel, stride_sample, + eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); }); }); } @@ -634,21 +629,28 @@ void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * d } void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + dpct::queue_ptr stream = ctx.stream(); - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - const int64_t ne00 = dst->src[0]->ne[0]; - const int64_t nrows = ggml_nrows(dst->src[0]); - const float * src0_dd = static_cast(dst->src[0]->data); - float * dst_dd = static_cast(dst->data); + GGML_TENSOR_UNARY_OP_LOCALS; float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); - l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; + /*support both WARP_SIZE or WARP_32_SIZE in code + choose by hardware for better performance + */ + l2_norm_f32_sycl(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device); } From 286387ef0a3d1e79d18032ec14efbd0ad1481180 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Wed, 11 Mar 2026 09:53:34 +0800 Subject: [PATCH 244/831] fix op rope, add rope_back (llama/20293) --- ggml/src/ggml-sycl/convert.hpp | 6 + ggml/src/ggml-sycl/ggml-sycl.cpp | 4 + ggml/src/ggml-sycl/rope.cpp | 736 +++++++++++++++++++------------ ggml/src/ggml-sycl/rope.hpp | 6 + 4 files changed, 466 insertions(+), 286 deletions(-) diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp index f93bd0df7d7..6e621f2154d 100644 --- a/ggml/src/ggml-sycl/convert.hpp +++ b/ggml/src/ggml-sycl/convert.hpp @@ -39,6 +39,11 @@ template return sycl::ext::oneapi::bfloat16(float(x)); } else if constexpr (std::is_same_v) { return static_cast(x); + } else if constexpr (std::is_same_v && std::is_same_v) { + return x.template convert(); + } else if constexpr (std::is_same_v && + std::is_same_v>) { + return {x.x, x.y}; } else if constexpr(std::is_same_v) { return int32_t(x); } else { @@ -46,4 +51,5 @@ template } } + #endif // GGML_SYCL_CONVERT_HPP diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 66dfc4532c0..f887061b279 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4145,6 +4145,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_ROPE: ggml_sycl_rope(ctx, dst); break; + case GGML_OP_ROPE_BACK: + ggml_sycl_rope_back(ctx, dst); + break; case GGML_OP_IM2COL: ggml_sycl_im2col(ctx, dst); break; @@ -4851,6 +4854,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return max_bias == 0.0f; } case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: case GGML_OP_IM2COL: return true; case GGML_OP_UPSCALE: diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index aeaa58b95b3..9d83a1e9fa0 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -1,4 +1,5 @@ #include "rope.hpp" +#include "convert.hpp" #include "ggml-sycl/common.hpp" #include "ggml.h" @@ -15,366 +16,489 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y)); } -// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn -// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -static void rope_yarn( - float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta) { - // Get n-d rotational scaling corrected for extrapolation +template +static void rope_yarn(const float theta_extrap, const float freq_scale, + const rope_corr_dims corr_dims, const int64_t i0, + const float ext_factor, float mscale, float &cos_theta, + float &sin_theta) { float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor; + float ramp_mix = + rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor; theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - // Get n-d magnitude scaling corrected for interpolation mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale); } - *cos_theta = sycl::cos(theta) * mscale; - *sin_theta = sycl::sin(theta) * mscale; + cos_theta = sycl::cos(theta) * mscale; + sin_theta = sycl::sin(theta) * mscale; + if (!forward) { + sin_theta *= -1.0f; + } } -template -static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, - const int32_t * pos, float freq_scale, float ext_factor, float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, - const sycl::nd_item<3> & item_ct1) { - const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); - - if (i0 >= ne0) { +template +static void rope_norm(const T *x, D *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, + const int s03, const int s1, const int s2, const int s3, + const int n_dims, const int32_t *pos, + const float freq_scale, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float *freq_factors, + const int64_t *row_indices, const int set_rows_stride) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + + if (i0 >= ne00) { return; } - const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); - const int row0 = row % ne1; - const int channel0 = row / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - const int i = row * ne0 + i0; - const int i2 = channel0 * s2 + row0 * s1 + i0; + int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03; + + if (set_rows_stride != 0) { + idst = i1 * s1 + i0; + idst += row_indices[i2] * set_rows_stride; + } + const auto &store_coaelsced = [&](float x0, float x1) { + if constexpr (std::is_same_v) { + sycl::float2 v = sycl::float2(x0, x1); + ggml_sycl_memcpy_1<8>(dst + idst, &v); + } else if constexpr (std::is_same_v) { + sycl::half2 v = sycl::half2(x0, x1); + ggml_sycl_memcpy_1<4>(dst + idst, &v); + } + }; if (i0 >= n_dims) { - *reinterpret_cast *>(dst + i) = *reinterpret_cast *>(x + i2); + store_coaelsced(x[ix + 0], x[ix + 1]); return; } - const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); + const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f); const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; float cos_theta; float sin_theta; - rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, + ext_factor, attn_factor, cos_theta, sin_theta); - const float x0 = x[i2 + 0]; - const float x1 = x[i2 + 1]; + const float x0 = x[ix + 0]; + const float x1 = x[ix + 1]; - dst[i + 0] = x0 * cos_theta - x1 * sin_theta; - dst[i + 1] = x0 * sin_theta + x1 * cos_theta; + store_coaelsced(x0 * cos_theta - x1 * sin_theta, + x0 * sin_theta + x1 * cos_theta); } -template -static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, - const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, - const sycl::nd_item<3> & item_ct1) { - const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); - - if (i0 >= ne0) { +template +static void rope_neox(const T *x, D *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, + const int s03, const int s1, const int s2, const int s3, + const int n_dims, const int32_t *pos, + const float freq_scale, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float *freq_factors, + const int64_t *row_indices, const int set_rows_stride) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + + if (i0 >= ne00) { return; } - const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); - const int row0 = row % ne1; - const int channel0 = row / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - const int i = row * ne0 + i0 / 2; - const int i2 = channel0 * s2 + row0 * s1 + i0 / 2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + + if (set_rows_stride != 0) { + idst = i1 * s1 + i0 / 2; + idst += row_indices[i2] * set_rows_stride; + } if (i0 >= n_dims) { - *reinterpret_cast *>(dst + i + i0 / 2) = *reinterpret_cast *>(x + i2 + i0 / 2); + dst[idst + i0 / 2 + 0] = ggml_sycl_cast(x[ix + i0 / 2 + 0]); + dst[idst + i0 / 2 + 1] = ggml_sycl_cast(x[ix + i0 / 2 + 1]); + return; } - const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); + const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f); const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; float cos_theta; float sin_theta; - rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, + ext_factor, attn_factor, cos_theta, sin_theta); - const float x0 = x[i2 + 0]; - const float x1 = x[i2 + n_dims / 2]; + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims / 2]; - dst[i + 0] = x0 * cos_theta - x1 * sin_theta; - dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta; + dst[idst + 0] = ggml_sycl_cast(x0 * cos_theta - x1 * sin_theta); + dst[idst + n_dims / 2] = ggml_sycl_cast(x0 * sin_theta + x1 * cos_theta); } -template -static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, - const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale, - const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, - const float theta_scale, const float * freq_factors, const mrope_sections sections, - const bool is_imrope, const sycl::nd_item<3> & item_ct1) { - // get index pos - const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1)); - if (i0 >= ne0) { +template +static void rope_multi(const T *x, T *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, + const int s03, const int s1, const int s2, const int s3, + const int n_dims, const int32_t *pos, + const float freq_scale, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float *freq_factors, + const mrope_sections sections, const bool is_imrope) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + + if (i0 >= ne00) { return; } - const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; - const int idst = (row_dst * ne0) + (i0 / 2); - const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); + const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; + + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; if (i0 >= n_dims) { - *reinterpret_cast *>(dst + idst + i0 / 2) = *reinterpret_cast *>(x + i0 / 2 + ix); + dst[idst + i0 / 2 + 0] = x[ix + i0 / 2 + 0]; + dst[idst + i0 / 2 + 1] = x[ix + i0 / 2 + 1]; + return; } - const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; + const int sect_dims = + sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; - float theta_base = 0.0; if (is_imrope) { - if (sector % 3 == 1 && sector < 3 * sections.v[1]) { - theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f); - } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { - theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f); - } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { - theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f); + if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h + theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w + theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t + theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f); } else { - theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f); + theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f); } } else { if (sector < sections.v[0]) { - theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f); - } - else if (sector >= sections.v[0] && sector < sec_w) { - theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f); - } - else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f); + theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f); + } else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f); } } const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; - float cos_theta; - float sin_theta; - rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - const float x0 = x[ix + 0]; - const float x1 = x[ix + n_dims/2]; - // store results in dst - dst[idst + 0] = x0 * cos_theta - x1 * sin_theta; - dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta; -} + float cos_theta; + float sin_theta; + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, + ext_factor, attn_factor, cos_theta, sin_theta); + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims / 2]; + + dst[idst + 0] = x0 * cos_theta - x1 * sin_theta; + dst[idst + n_dims / 2] = x0 * sin_theta + x1 * cos_theta; +} -template -static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, - const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale, - const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, - const float theta_scale, const float * freq_factors, const mrope_sections sections, - const sycl::nd_item<3> & item_ct1) { - // get index pos - const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1)); - if (i0 >= ne0) { +template +static void rope_vision(const T *x, T *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, + const int s03, const int s1, const int s2, const int s3, + const int n_dims, const int32_t *pos, + const float freq_scale, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float *freq_factors, + const mrope_sections sections) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + + if (i0 >= ne00) { return; } - const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; - const int idst = (row_dst * ne0) + (i0 / 2); - const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); + + const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; + + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; const int sect_dims = sections.v[0] + sections.v[1]; - const int sector = (i0 / 2) % sect_dims; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; - float theta_base = 0.0f; + float theta_base = 0.0; if (sector < sections.v[0]) { const int p = sector; - theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p); - } else { + theta_base = pos[i2] * dpct::pow(theta_scale, p); + } else if (sector >= sections.v[0] && sector < sec_w) { const int p = sector - sections.v[0]; - theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p); + theta_base = pos[i2 + ne02] * dpct::pow(theta_scale, p); } const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; - float cos_theta; - float sin_theta; - rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + float cos_theta; + float sin_theta; + + rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, + ext_factor, attn_factor, cos_theta, sin_theta); + const float x0 = x[ix + 0]; const float x1 = x[ix + n_dims]; - // store results in dst - dst[idst + 0] = x0 * cos_theta - x1 * sin_theta; + dst[idst + 0] = x0 * cos_theta - x1 * sin_theta; dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta; } -template -static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, - const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base, - const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, - const float * freq_factors, queue_ptr stream) { - GGML_ASSERT(ne0 % 2 == 0); - const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); - const sycl::range<3> block_nums(1, num_blocks_x, nr); +template +static void +rope_norm_sycl(const T *x, D *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, const int s03, + const int s1, const int s2, const int s3, const int n_dims, + const int nr, const int32_t *pos, const float freq_scale, + const float freq_base, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float *freq_factors, const int64_t *row_indices, + const int set_rows_stride, dpct::queue_ptr stream) { + GGML_ASSERT(ne00 % 2 == 0); + const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = + (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); + const dpct::dim3 block_nums(nr, n_blocks_x, 1); const float theta_scale = powf(freq_base, -2.0f / n_dims); - dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - if (freq_factors == nullptr) { - /* - DPCT1049:40: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - rope_norm(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_norm( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, row_indices, set_rows_stride); + }); } else { - /* - DPCT1049:41: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - rope_norm(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_norm( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, row_indices, set_rows_stride); + }); } } -template -static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, - const int n_dims, const int nr, const int32_t * pos, const float freq_scale, - const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { - GGML_ASSERT(ne0 % 2 == 0); - const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); - const sycl::range<3> block_nums(1, num_blocks_x, nr); +template +static void +rope_neox_sycl(const T *x, D *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, const int s03, + const int s1, const int s2, const int s3, const int n_dims, + const int nr, const int32_t *pos, const float freq_scale, + const float freq_base, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float *freq_factors, const int64_t *row_indices, + const int set_rows_stride, dpct::queue_ptr stream) { + GGML_ASSERT(ne00 % 2 == 0); + const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = + (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); + const dpct::dim3 block_nums(nr, n_blocks_x, 1); const float theta_scale = powf(freq_base, -2.0f / n_dims); - dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - if (freq_factors == nullptr) { - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_neox( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, row_indices, set_rows_stride); + }); } else { - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - rope_neox(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_neox( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, row_indices, set_rows_stride); + }); } } -template -static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, - const size_t s2, const int n_dims, const int nr, const int32_t * pos, - const float freq_scale, const float freq_base, const float ext_factor, - const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors, - const mrope_sections sections, const bool is_imrope, queue_ptr stream) { - GGML_ASSERT(ne0 % 2 == 0); - const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); - const sycl::range<3> grid_dims(1, n_blocks_y, nr); - const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims); - - const float theta_scale = std::pow(freq_base, -2.0f / n_dims); - // Add FP16 capability check if T could be sycl::half - if constexpr (std::is_same_v) { - dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - } - // launch kernel +template +static void +rope_multi_sycl(const T *x, T *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, const int s03, + const int s1, const int s2, const int s3, const int n_dims, + const int nr, const int32_t *pos, const float freq_scale, + const float freq_base, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float *freq_factors, const mrope_sections sections, + const bool is_imrope, dpct::queue_ptr stream) { + GGML_ASSERT(ne00 % 2 == 0); + const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = + (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); + const dpct::dim3 block_nums(nr, n_blocks_x, 1); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + if (freq_factors == nullptr) { - stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { - rope_multi(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, - corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_multi( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, sections, is_imrope); + }); } else { - stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { - rope_multi(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, - corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_multi( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, sections, is_imrope); + }); } } +template +static void +rope_vision_sycl(const T *x, T *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, const int s03, + const int s1, const int s2, const int s3, const int n_dims, + const int nr, const int32_t *pos, const float freq_scale, + const float freq_base, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float *freq_factors, const mrope_sections sections, + dpct::queue_ptr stream) { + GGML_ASSERT(ne00 % 2 == 0); + const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = + (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); + const dpct::dim3 block_nums(nr, n_blocks_x, 1); + const float theta_scale = powf(freq_base, -2.0f / n_dims); - -// rope vision -template -static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, - const size_t s2, const int n_dims, const int nr, const int32_t * pos, - const float freq_scale, const float freq_base, const float ext_factor, - const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors, - const mrope_sections sections, queue_ptr stream) { - GGML_ASSERT(ne0 % 2 == 0); - const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); - const sycl::range<3> grid_dims(1, n_blocks_y, nr); - const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims); - - const float theta_scale = std::pow(freq_base, -2.0f / n_dims); - // Add FP16 capability check if T could be sycl::half - if constexpr (std::is_same_v) { - dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - } - // launch kernel if (freq_factors == nullptr) { - stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { - rope_vision(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, - corr_dims, theta_scale, freq_factors, sections, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_vision( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, sections); + }); } else { - stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { - rope_vision(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, - corr_dims, theta_scale, freq_factors, sections, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_vision( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, sections); + }); } } -inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +template +void ggml_sycl_op_rope_impl(ggml_backend_sycl_context &ctx, ggml_tensor *dst, + const ggml_tensor *set_rows = nullptr) { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + const ggml_tensor *src2 = dst->src[2]; + + const float *src0_d = (const float *)src0->data; + const float *src1_d = (const float *)src1->data; + + void *dst_d = dst->data; + const int64_t *row_indices = nullptr; + ggml_type dst_type = dst->type; + int set_rows_stride = 0; + + if (set_rows != nullptr) { + GGML_ASSERT(forward); + dst_d = set_rows->data; + row_indices = (const int64_t *)set_rows->src[1]->data; + dst_type = set_rows->type; + set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type); + } + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type || + (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16)); - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - GGML_ASSERT(dst->src[0]->type == dst->type); - const int64_t ne00 = dst->src[0]->ne[0]; // head dims - const int64_t ne01 = dst->src[0]->ne[1]; // num heads - const int64_t ne02 = dst->src[0]->ne[2]; // num heads - const int64_t nr = ggml_nrows(dst->src[0]); + const int64_t ne00 = src0->ne[0]; // head dims + const int64_t ne01 = src0->ne[1]; // num heads + const int64_t ne02 = src0->ne[2]; // num heads + const int64_t nr = ggml_nrows(src0); - const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type); - const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type); + const size_t s01 = src0->nb[1] / ggml_type_size(src0->type); + const size_t s02 = src0->nb[2] / ggml_type_size(src0->type); + const size_t s03 = src0->nb[3] / ggml_type_size(src0->type); + const size_t s1 = dst->nb[1] / ggml_type_size(dst->type); + const size_t s2 = dst->nb[2] / ggml_type_size(dst->type); + const size_t s3 = dst->nb[3] / ggml_type_size(dst->type); - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - //const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + const int n_dims = ((int32_t *)dst->op_params)[1]; + const int mode = ((int32_t *)dst->op_params)[2]; + const int n_ctx_orig = ((int32_t *)dst->op_params)[4]; mrope_sections sections; - // RoPE alteration for extended context float freq_base; float freq_scale; float ext_factor; @@ -382,13 +506,13 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) float beta_fast; float beta_slow; - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4); + memcpy(&freq_base, (int32_t *)dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *)dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *)dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *)dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *)dst->op_params + 10, sizeof(float)); + memcpy(§ions.v, (int32_t *)dst->op_params + 11, sizeof(int) * 4); const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; @@ -396,82 +520,122 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) const bool is_vision = mode == GGML_ROPE_TYPE_VISION; if (is_mrope) { - GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0); + GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || + sections.v[2] > 0); } if (is_vision) { - GGML_ASSERT(n_dims == ne00/2); + GGML_ASSERT(n_dims == ne00 / 2); } - const int32_t * pos = (const int32_t *) dst->src[1]->data; + const int32_t *pos = (const int32_t *)src1_d; - const float * freq_factors = nullptr; - if (dst->src[2] != nullptr) { - freq_factors = (const float *) dst->src[2]->data; + const float *freq_factors = nullptr; + if (src2 != nullptr) { + freq_factors = (const float *)src2->data; } rope_corr_dims corr_dims; - ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v); - - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, + beta_slow, corr_dims.v); // compute if (is_neox) { GGML_SYCL_DEBUG("%s: neox path\n", __func__); - if (dst->src[0]->type == GGML_TYPE_F32) { - rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream); - } else if (dst->src[0]->type == GGML_TYPE_F16) { - rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02, - n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, - main_stream); + if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { + rope_neox_sycl( + (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01, + s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); + } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { + rope_neox_sycl( + (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02, + s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + row_indices, set_rows_stride, stream); + } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { + rope_neox_sycl( + (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01, + ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + row_indices, set_rows_stride, stream); } else { - GGML_ABORT("fatal error"); + GGML_ABORT("Fatal error: Tensor type unsupported!"); } } else if (is_mrope && !is_vision) { GGML_SYCL_DEBUG("%s: mrope path\n", __func__); - if (dst->src[0]->type == GGML_TYPE_F16) { - rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01, - s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, sections, is_imrope, main_stream); - } else if (dst->src[0]->type == GGML_TYPE_F32) { - rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, - is_imrope, main_stream); + if (src0->type == GGML_TYPE_F32) { + rope_multi_sycl((const float *)src0_d, (float *)dst_d, + ne00, ne01, ne02, s01, s02, s03, s1, s2, + s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, + freq_factors, sections, is_imrope, stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_multi_sycl( + (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01, + ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + sections, is_imrope, stream); } else { GGML_ABORT("Fatal error: Tensor type unsupported!"); } } else if (is_vision) { GGML_SYCL_DEBUG("%s: vision path\n", __func__); - if (dst->src[0]->type == GGML_TYPE_F16) { - rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01, - s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, sections, main_stream); - } else if (dst->src[0]->type == GGML_TYPE_F32) { - rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, - main_stream); + if (src0->type == GGML_TYPE_F32) { + rope_vision_sycl( + (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01, + s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, sections, + stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_vision_sycl( + (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01, + ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + sections, stream); } else { GGML_ABORT("Fatal error: Tensor type unsupported!"); } } else { GGML_SYCL_DEBUG("%s: norm path\n", __func__); - if (dst->src[0]->type == GGML_TYPE_F32) { - rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream); - } else if (dst->src[0]->type == GGML_TYPE_F16) { - rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02, - n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, - main_stream); + if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { + rope_norm_sycl( + (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01, + s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); + } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { + rope_norm_sycl( + (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02, + s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + row_indices, set_rows_stride, stream); + } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { + rope_norm_sycl( + (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01, + ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + row_indices, set_rows_stride, stream); } else { - GGML_ABORT("fatal error"); + GGML_ABORT("Fatal error: Tensor type unsupported!"); } } } -void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +void ggml_sycl_rope(ggml_backend_sycl_context &ctx, ggml_tensor *dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); - ggml_sycl_op_rope(ctx, dst); + + ggml_sycl_op_rope_impl(ctx, dst); } +void ggml_sycl_rope_back(ggml_backend_sycl_context &ctx, ggml_tensor *dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); + ggml_sycl_op_rope_impl(ctx, dst); +} + +void ggml_sycl_rope_fused(ggml_backend_sycl_context &ctx, ggml_tensor *rope, + ggml_tensor *set_rows) { + scope_op_debug_print scope_dbg_print(__func__, rope, /*num_src=*/3); + ggml_sycl_op_rope_impl(ctx, rope, set_rows); +} diff --git a/ggml/src/ggml-sycl/rope.hpp b/ggml/src/ggml-sycl/rope.hpp index 8c7141aac5c..b95a585808b 100644 --- a/ggml/src/ggml-sycl/rope.hpp +++ b/ggml/src/ggml-sycl/rope.hpp @@ -15,6 +15,12 @@ #include "common.hpp" +#define SYCL_ROPE_BLOCK_SIZE 256 + void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst); +void ggml_sycl_rope_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_rope_fused(ggml_backend_sycl_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows); + #endif // GGML_SYCL_ROPE_HPP From 7c9a16c565797da6c071852204148e5711359ef7 Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 11 Mar 2026 06:04:32 +0100 Subject: [PATCH 245/831] cuda/hip: fix loop unrolling in ssm-conv (llama/20369) --- ggml/src/ggml-cuda/ssm-conv.cu | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 85e82b5a422..69985cd335c 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -76,7 +76,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, int row = tid / load_cols; int col = tid % load_cols; #pragma unroll - for (int idx = tid; idx < total_elems; idx += split_d_inner) { + for (int idx = 0; idx < total_elems; idx += split_d_inner) { if (row < (int)split_d_inner) { smem[row * n_cols + col] = x_block[row * stride_x + col]; } @@ -84,6 +84,9 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, col += split_d_inner; row += col / load_cols; col = col % load_cols; + if (idx >= total_elems - tid - split_d_inner) { + break; + } } __syncthreads(); From 8b335550cf5da785044594f6dff5182dc043d01d Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 11 Mar 2026 06:06:19 +0100 Subject: [PATCH 246/831] ggml-cuda: gdn use shared mem for HIP (llama/20366) Suggested-by: Aman Gupta --- ggml/src/ggml-cuda/gated_delta_net.cu | 85 ++++++++++++++++++--------- 1 file changed, 56 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index d8e81114559..c249bbc86d5 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -2,28 +2,29 @@ #include "ggml-cuda/common.cuh" template -__global__ void gated_delta_net_cuda(const float * q, - const float * k, - const float * v, - const float * g, - const float * beta, - const float * curr_state, - float * dst, - int64_t H, - int64_t n_tokens, - int64_t n_seqs, - int64_t sq1, - int64_t sq2, - int64_t sq3, - int64_t sv1, - int64_t sv2, - int64_t sv3, - int64_t sb1, - int64_t sb2, - int64_t sb3, - int64_t rq1, - int64_t rq3, - float scale) { +__global__ void __launch_bounds__(S_v, 1) +gated_delta_net_cuda(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * curr_state, + float * dst, + const int64_t H, + const int64_t n_tokens, + const int64_t n_seqs, + const int64_t sq1, + const int64_t sq2, + const int64_t sq3, + const int64_t sv1, + const int64_t sv2, + const int64_t sv3, + const int64_t sb1, + const int64_t sb2, + const int64_t sb3, + const int64_t rq1, + const int64_t rq3, + const float scale) { const int64_t h_idx = blockIdx.x; const int64_t sequence = blockIdx.y; const int col = threadIdx.x; // each thread owns one column @@ -40,8 +41,14 @@ __global__ void gated_delta_net_cuda(const float * q, curr_state += state_offset; attn_data += (sequence * n_tokens * H + h_idx) * S_v; - // Load state column into registers + // GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229 + // TODO: check optimal path for RDNA1 and RDNA2 devices. +#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA) + extern __shared__ float s_shared[]; + float * s = s_shared + col * S_v; +#else float s[S_v]; +#endif #pragma unroll for (int i = 0; i < S_v; i++) { s[i] = curr_state[i * S_v + col]; @@ -114,6 +121,15 @@ __global__ void gated_delta_net_cuda(const float * q, } } +static size_t calculate_smem(const int sv, int cc) +{ + size_t smem = 0; + if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) { + smem = sv * sv * sizeof(float); + } + return smem; +} + template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, @@ -129,25 +145,36 @@ static void launch_gated_delta_net( dim3 grid_dims(H, n_seqs, 1); dim3 block_dims(S_v, 1, 1); + int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + switch (S_v) { - case 32: - gated_delta_net_cuda<32, KDA><<>>( + case 32: { + constexpr int sv = 32; + size_t smem = calculate_smem(sv, cc); + gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, rq1, rq3, scale); break; - case 64: - gated_delta_net_cuda<64, KDA><<>>( + } + case 64: { + constexpr int sv = 64; + size_t smem = calculate_smem(sv, cc); + gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, rq1, rq3, scale); break; - case 128: - gated_delta_net_cuda<128, KDA><<>>( + } + case 128: { + constexpr int sv = 128; + size_t smem = calculate_smem(sv, cc); + gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, rq1, rq3, scale); break; + } default: GGML_ABORT("fatal error"); break; From c2e384f21ec242657f19e1659dd0adadb05ba187 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 11 Mar 2026 16:25:10 +0200 Subject: [PATCH 247/831] metal : add env var to trigger graph capture (llama/20398) --- ggml/src/ggml-metal/ggml-metal-context.m | 29 ++++++++++++++++++------ 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m index 1136ce99b09..a345cb1f689 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -47,7 +47,7 @@ uint64_t fuse_cnt[GGML_OP_COUNT]; // capture state - bool capture_next_compute; + int capture_compute; bool capture_started; id capture_scope; @@ -158,10 +158,17 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false"); GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false"); - res->capture_next_compute = false; + res->capture_compute = 0; res->capture_started = false; res->capture_scope = nil; + { + const char * val = getenv("GGML_METAL_CAPTURE_COMPUTE"); + if (val) { + res->capture_compute = atoi(val); + } + } + res->has_error = false; res->gf = nil; @@ -458,9 +465,13 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; - const bool use_capture = ctx->capture_next_compute; + if (ctx->capture_compute > 0) { + ctx->capture_compute--; + } + + const bool use_capture = ctx->capture_compute == 0; if (use_capture) { - ctx->capture_next_compute = false; + ctx->capture_compute = -1; // make sure all previous computations have finished before starting the capture if (ctx->cmd_buf_last) { @@ -469,6 +480,10 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * } if (!ctx->capture_started) { + NSString * path = [NSString stringWithFormat:@"/tmp/perf-metal-%d.gputrace", getpid()]; + + GGML_LOG_WARN("%s: capturing graph in %s\n", __func__, [path UTF8String]); + // create capture scope id device = ggml_metal_device_get_obj(ctx->dev); ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device]; @@ -476,7 +491,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; descriptor.captureObject = ctx->capture_scope; descriptor.destination = MTLCaptureDestinationGPUTraceDocument; - descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; + descriptor.outputURL = [NSURL fileURLWithPath:path]; NSError * error = nil; if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { @@ -683,7 +698,7 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) { idx_end, ctx->use_fusion, ctx->use_concurrency, - ctx->capture_next_compute, + ctx->capture_compute, ctx->debug_graph, ctx->debug_fusion); @@ -718,5 +733,5 @@ bool ggml_metal_supports_family(ggml_metal_t ctx, int family) { } void ggml_metal_capture_next_compute(ggml_metal_t ctx) { - ctx->capture_next_compute = true; + ctx->capture_compute = 1; } From 0e1e76f93bc72b35cd1659a67ab9bc52e4674d2f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 11 Mar 2026 16:25:27 +0200 Subject: [PATCH 248/831] metal : fix q5_k mul_mv register spill (llama/20399) --- ggml/src/ggml-metal/ggml-metal-impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index bf51055e367..99d64efc3b5 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -35,7 +35,7 @@ #define N_R0_Q4_K 2 #define N_SG_Q4_K 2 -#define N_R0_Q5_K 2 +#define N_R0_Q5_K 1 #define N_SG_Q5_K 2 #define N_R0_Q6_K 2 From e2aa5c73f369d0b483375786c44647a6d8e53d4b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 11 Mar 2026 18:38:22 +0200 Subject: [PATCH 249/831] metal : fix capture_compute counter logic (llama/20410) --- ggml/src/ggml-metal/ggml-metal-context.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m index a345cb1f689..855fd1adae8 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -465,7 +465,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; - if (ctx->capture_compute > 0) { + if (ctx->capture_compute >= 0) { ctx->capture_compute--; } From 5d3a5447c8e50a59149a08392d487b8eea6e73c3 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 11 Mar 2026 19:27:53 +0100 Subject: [PATCH 250/831] llama : add support for Nemotron 3 Super (llama/20411) * llama : add support for Nemotron 3 Super This commit adds support for the Nemotron 3 Super model (120B.A12B) enabling this model to be converted to GGUF format and run in llama.cpp. Co-authored-by: Georgi Gerganov Co-authored-by: Matt Clayton <156335168+mattjcly@users.noreply.github.com> --- ggml/src/ggml-metal/ggml-metal.metal | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 82ebbb4e409..29e4a245d5d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -9081,6 +9081,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_ template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>; +template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>; template kernel void kernel_mul_mm_id( From e4021d4071087871d4b1d45c10b749e3c64c363e Mon Sep 17 00:00:00 2001 From: Richard Davison Date: Wed, 11 Mar 2026 21:02:54 +0100 Subject: [PATCH 251/831] ggml : add NVFP4 quantization type support (llama/19769) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * WIP: add NVFP4 quantization support * tests * improve NVFP4 dot product implementation performance and fix bad super call * typo * Use nvfp4 kvalues * vulkan : fix NVFP4 shader compilation by including kvalues_mxfp4 lookup table * vulcal and perf fixes * wip * Fix metal * fix vulcan * Rename threshold & fix wrong scale * Fix MOE * Shelf backend implementations (CUDA, Metal, Vulkan, arch-specific SIMD) Remove NVFP4 support from GPU backends and architecture-specific optimized dot products. These should be added in separate PRs so backend specialists can review them independently. Reverted files: - ggml-cuda: common.cuh, convert.cu, mmq.cu/cuh, mmvq.cu, vecdotq.cuh, quantize.cu/cuh, mma.cuh, ggml-cuda.cu, fattn-tile.cuh - ggml-metal: ggml-metal.metal, ggml-metal-device.cpp, ggml-metal-impl.h, ggml-metal-ops.cpp - ggml-vulkan: ggml-vulkan.cpp, all vulkan-shaders/* - ggml-cpu arch: arm/quants.c, x86/quants.c, powerpc/quants.c, s390/quants.c Core NVFP4 support (type definition, CPU fallback dot product, quantization, dequantization, conversion) is retained. * Fix arch-fallback.h: add NVFP4 generic fallback for all platforms After shelving backend-specific SIMD implementations, the generic CPU dot product needs to be aliased on ARM, x86, PowerPC, and s390 platforms that previously relied on arch-specific versions. * quantize: add NVFP4 as a quantization type option * Fix ggml_fp32_to_ue4m3: handle subnormal values Previously, values with ue4m3_exp <= 0 were clamped to 0, causing all small scales to underflow. This made NVFP4 quantization via llama-quantize produce garbage (PPL = 5.8M) since typical transformer weights have amax/6.0 in the range 0.001-0.01, which falls in the UE4M3 subnormal range. Now subnormals are properly encoded as man * 2^-9 (exp=0, man=1..7), matching the decode path in ggml_ue4m3_to_fp32. Result: NVFP4 requantization now produces PPL = 15.25 (vs F16 = 14.33), comparable to Q4_1 (PPL = 15.81) at slightly lower BPW (4.70 vs 5.15). * Restore ARM NEON NVFP4 dot product implementation Restores the optimized ggml_vec_dot_nvfp4_q8_0 for ARM NEON using vqtbl1q_s8 lookup and ggml_vdotq_s32 dot products. tg128 performance: 4.37 t/s (generic) -> 13.66 t/s (NEON) = 3.1x speedup * Optimize ARM NEON NVFP4 dot product: LUT + vpaddq + vfmaq - Add ue4m3_scale_lut[128] to ggml-common.h replacing branch-heavy ggml_ue4m3_to_fp32() in the hot loop - Use vpaddq_s32 for pairwise int32 reduction instead of vaddvq_s32 - Accumulate with vfmaq_f32 into float32x4_t vector accumulators tg128: 8.1 -> 31.0 t/s (3.8x speedup, 77% of Q4_1 speed) * ARM NEON NVFP4: rearrange q8 to match nibble layout Alternative approach: rearrange q8 data to match the NVFP4 lo/hi nibble layout instead of rearranging the looked-up NVFP4 values. Eliminates vcombine_s8(vget_low, vget_low) shuffles. Performance is equivalent (~18.5 t/s) - the bottleneck is the 2x block overhead from QK=16 vs QK=32, not the shuffle instructions. * CPU only backend 64 super-block layout * cleanup * Remove unused LUT * int * exclude NVFP4 from unsupported ops in metal build * remove quantization for now * store scales as native UE4M3, preserve original model bits when possible * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret * correct comment * format * reduce duplication and cleanup * Address comments * move detection to prepare_tensors * Use math instead of const * Move * fix comment * Shelf quantize tests * Rebase and move check * cleanup * lint * Update gguf-py/gguf/scripts/gguf_convert_endian.py Co-authored-by: Sigbjørn Skjæret * Use fallback quant config * Simplify Co-authored-by: Sigbjørn Skjæret * organize * Refactor * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret * add quantize_nvfp4 (required for test_quants.py) * add quantize_nvfp4 (required for test_quants.py) * add quantize_nvfp4 (required for test_quants.py) * fix return type --------- Co-authored-by: Sigbjørn Skjæret --- ggml/include/ggml.h | 4 +- ggml/src/ggml-common.h | 11 ++++ ggml/src/ggml-cpu/arch-fallback.h | 8 +++ ggml/src/ggml-cpu/arch/arm/quants.c | 84 +++++++++++++++++++++++++ ggml/src/ggml-cpu/ggml-cpu.c | 6 ++ ggml/src/ggml-cpu/ops.cpp | 7 +++ ggml/src/ggml-cpu/quants.c | 40 ++++++++++++ ggml/src/ggml-cpu/quants.h | 3 + ggml/src/ggml-impl.h | 55 ++++++++++++++++ ggml/src/ggml-metal/ggml-metal-device.m | 4 +- ggml/src/ggml-quants.c | 72 +++++++++++++++++++++ ggml/src/ggml-quants.h | 3 + ggml/src/ggml.c | 10 +++ 13 files changed, 304 insertions(+), 3 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 566e2714790..3323f8e6c3f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -427,7 +427,8 @@ extern "C" { // GGML_TYPE_IQ4_NL_4_8 = 37, // GGML_TYPE_IQ4_NL_8_8 = 38, GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) - GGML_TYPE_COUNT = 40, + GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) + GGML_TYPE_COUNT = 41, }; // precision @@ -463,6 +464,7 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors + GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 93ab7ea446e..92cf739e7a7 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -102,6 +102,9 @@ typedef sycl::half2 ggml_half2; #define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4)) #define QR_MXFP4 2 +#define QI_NVFP4 (QK_NVFP4 / (4 * QR_NVFP4)) +#define QR_NVFP4 2 + #define QI5_0 (QK5_0 / (4 * QR5_0)) #define QR5_0 2 @@ -194,6 +197,14 @@ typedef struct { } block_mxfp4; static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding"); +#define QK_NVFP4 64 +#define QK_NVFP4_SUB 16 // sub-block size for per-group scales +typedef struct { + uint8_t d[QK_NVFP4/QK_NVFP4_SUB]; // UE4M3 scales (4 bytes, one per 16-element sub-block) + uint8_t qs[QK_NVFP4/2]; // packed 4-bit E2M1 values (32 bytes) +} block_nvfp4; +static_assert(sizeof(block_nvfp4) == sizeof(uint8_t)*(QK_NVFP4/QK_NVFP4_SUB) + QK_NVFP4/2, "wrong nvfp4 block size/padding"); + #define QK5_0 32 typedef struct { ggml_half d; // delta diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 48315610f2f..175aa4a4bb9 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -15,6 +15,7 @@ #define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 +#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -79,6 +80,8 @@ #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) +// quants.c +#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 @@ -108,6 +111,7 @@ // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679 // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K +#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K @@ -155,6 +159,7 @@ #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 +#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -201,6 +206,7 @@ #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 +#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 @@ -240,6 +246,7 @@ #elif defined(__s390x__) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K +#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -302,6 +309,7 @@ #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 +#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index a707d63985e..c1856201b31 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -650,6 +650,90 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo *s = sumf; } +void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_NVFP4 == 0); + + const block_nvfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + // Each NVFP4 super-block (64 elements) spans 2 q8_0 blocks + const int nb = n / QK_NVFP4; + + float sumf = 0; + +#if defined __ARM_NEON + const int8x16_t values = vld1q_s8(kvalues_mxfp4); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + float32x4_t acc = vdupq_n_f32(0.0f); + + for (int ib = 0; ib < nb; ++ib) { + const uint8x16_t q4bits_0 = vld1q_u8(x[ib].qs); + const uint8x16_t q4bits_1 = vld1q_u8(x[ib].qs + 16); + + const int8x16_t q4_lo_0 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_0, m4b)); + const int8x16_t q4_hi_0 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_0, 4)); + const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_1, m4b)); + const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4)); + + const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs); + const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16); + const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b)); + const int8x16_t q8_hi_0 = vcombine_s8(vget_high_s8(q8_0a), vget_high_s8(q8_0b)); + + const int8x16_t q8_1a = vld1q_s8(y[2*ib+1].qs); + const int8x16_t q8_1b = vld1q_s8(y[2*ib+1].qs + 16); + const int8x16_t q8_lo_1 = vcombine_s8(vget_low_s8(q8_1a), vget_low_s8(q8_1b)); + const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b)); + + const int32x4_t p0 = vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0), + ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0)); + const int32x4_t p1 = vaddq_s32( + ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1), + ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1)); + + const int32x4_t sums = vpaddq_s32(p0, p1); + + // Decode 4 UE4M3 scales to f32 and multiply with q8 scales + const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d); + const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d); + const float32x4_t nvsc = { + ggml_ue4m3_to_fp32(x[ib].d[0]), + ggml_ue4m3_to_fp32(x[ib].d[1]), + ggml_ue4m3_to_fp32(x[ib].d[2]), + ggml_ue4m3_to_fp32(x[ib].d[3]) + }; + const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1}); + + acc = vfmaq_f32(acc, vcvtq_f32_s32(sums), scales); + } + sumf = vaddvq_f32(acc); +#else + for (int ib = 0; ib < nb; ++ib) { + for (int si = 0; si < 4; ++si) { + const float d = ggml_ue4m3_to_fp32(x[ib].d[si]); + const int q8b = si / 2; + const int q8o = (si % 2) * QK_NVFP4_SUB; + const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8b].d); + + int sumi_lo = 0, sumi_hi = 0; + for (int j = 0; j < QK_NVFP4_SUB/2; ++j) { + const uint8_t qv = x[ib].qs[si*(QK_NVFP4_SUB/2) + j]; + sumi_lo += y[2*ib + q8b].qs[q8o + j + 0] * kvalues_mxfp4[qv & 0xf]; + sumi_hi += y[2*ib + q8b].qs[q8o + j + QK_NVFP4_SUB/2] * kvalues_mxfp4[qv >> 4]; + } + sumf += dy * d * (sumi_lo + sumi_hi); + } + } +#endif + *s = sumf; +} + void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index dc2b5ffaa77..8b323bd9b06 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -270,6 +270,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, + [GGML_TYPE_NVFP4] = { + .from_float = quantize_row_nvfp4, + .vec_dot = ggml_vec_dot_nvfp4_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, [GGML_TYPE_Q2_K] = { .from_float = quantize_row_q2_K, .vec_dot = ggml_vec_dot_q2_K_q8_K, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 331e071a267..f9c4ec16e4b 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -670,6 +670,7 @@ void ggml_compute_forward_add( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -1119,6 +1120,7 @@ void ggml_compute_forward_add1( case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -1247,6 +1249,7 @@ void ggml_compute_forward_acc( case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4334,6 +4337,7 @@ void ggml_compute_forward_out_prod( case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4609,6 +4613,7 @@ void ggml_compute_forward_set( case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4831,6 +4836,7 @@ void ggml_compute_forward_get_rows( case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -5555,6 +5561,7 @@ void ggml_compute_forward_clamp( case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 365cb36d2d7..7ebbb9c6f15 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -50,6 +50,10 @@ void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i quantize_row_mxfp4_ref(x, y, k); } +void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_nvfp4_ref(x, y, k); +} + // // 2-6 bit quantization in super-blocks // @@ -216,6 +220,42 @@ void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, *s = sumf; } +// NVFP4: super-block of 64 elements = 4 sub-blocks of 16 = 2 q8_0 blocks +void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_NVFP4 == 0); + + const block_nvfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK_NVFP4; + + float sumf = 0; + + for (int ib = 0; ib < nb; ++ib) { + for (int s_idx = 0; s_idx < 4; ++s_idx) { + const float d = ggml_ue4m3_to_fp32(x[ib].d[s_idx]); + const int q8_block = s_idx / 2; + const int q8_off = (s_idx % 2) * QK_NVFP4_SUB; + const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8_block].d); + + int sumi_lo = 0, sumi_hi = 0; + for (int j = 0; j < QK_NVFP4_SUB/2; ++j) { + const uint8_t qv = x[ib].qs[s_idx*(QK_NVFP4_SUB/2) + j]; + sumi_lo += y[2*ib + q8_block].qs[q8_off + j + 0] * kvalues_mxfp4[qv & 0xf]; + sumi_hi += y[2*ib + q8_block].qs[q8_off + j + QK_NVFP4_SUB/2] * kvalues_mxfp4[qv >> 4]; + } + + sumf += dy * d * (sumi_lo + sumi_hi); + } + } + *s = sumf; +} + void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index d83eb1b144d..3584aaa43e8 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -20,6 +20,7 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -42,6 +43,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -73,6 +75,7 @@ void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index e3714b38a6a..92568655956 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -491,6 +491,61 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) { #define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x) #define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x) +// UE4M3: unsigned, 4 exp bits (bias=7), 3 mantissa bits +// Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float) +static inline float ggml_ue4m3_to_fp32(uint8_t x) { + if (x == 0 || x == 0x7F) { + return 0.0f; + } + int exp = (x >> 3) & 0xF; + int man = x & 0x7; + float raw; + if (exp == 0) { + raw = ldexpf((float) man, -9); + } else { + raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7); + } + return raw * 0.5f; +} + +static inline uint8_t ggml_fp32_to_ue4m3(float x) { + if (!(x > 0.0f)) { + return 0; + } + if (x > 448.0f) { + x = 448.0f; + } + uint32_t bits; + memcpy(&bits, &x, 4); + int fp32_exp = ((bits >> 23) & 0xFF) - 127; + int fp32_man = (bits >> 20) & 0x7; + int ue4m3_exp = fp32_exp + 7; + if (ue4m3_exp <= 0) { + // subnormal: value = man * 2^-9, man = round(x * 2^9) + int man = (int) (x * 512.0f + 0.5f); + if (man > 7) { + man = 7; + } + if (man < 1) { + return 0; + } + return (uint8_t) man; + } + if (ue4m3_exp >= 15) { + return 0x7E; + } + int round_bit = (bits >> 19) & 1; + int ue4m3_man = fp32_man + round_bit; + if (ue4m3_man > 7) { + ue4m3_man = 0; + ue4m3_exp++; + if (ue4m3_exp >= 15) { + return 0x7E; + } + } + return (uint8_t) ((ue4m3_exp << 3) | ue4m3_man); +} + /** * Converts brain16 to float32. * diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 23bd2b2ab72..d42b8ab1eb1 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1158,7 +1158,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_SOLVE_TRI: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: - return has_simdgroup_reduction; + return has_simdgroup_reduction && op->src[0]->type != GGML_TYPE_NVFP4; case GGML_OP_SET: case GGML_OP_CPY: case GGML_OP_DUP: @@ -1216,7 +1216,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te }; } case GGML_OP_GET_ROWS: - return true; + return op->src[0]->type != GGML_TYPE_NVFP4; case GGML_OP_SET_ROWS: { if (op->src[0]->type != GGML_TYPE_F32) { diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index e8e25633fb8..cdaded865b1 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -304,6 +304,41 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE } } +void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k) { + static const int qk = QK_NVFP4; + static const int qk_sub = QK_NVFP4_SUB; + static const int n_sub = QK_NVFP4 / QK_NVFP4_SUB; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + for (int s = 0; s < n_sub; s++) { + const float * xb = x + i*qk + s*qk_sub; + + float amax = 0.0f; + for (int j = 0; j < qk_sub; j++) { + if (amax < fabsf(xb[j])) { + amax = fabsf(xb[j]); + } + } + + // UE4M3 scale: amax / 6.0 maps the max E2M1 value (6.0) to amax + const uint8_t ue = ggml_fp32_to_ue4m3(amax / 6.0f); + y[i].d[s] = ue; + const float d = ggml_ue4m3_to_fp32(ue); + + for (int j = 0; j < qk_sub/2; ++j) { + const uint8_t x0 = best_index_mxfp4(xb[0 + j], d); + const uint8_t x1 = best_index_mxfp4(xb[qk_sub/2 + j], d); + + y[i].qs[s*(qk_sub/2) + j] = x0 | (x1 << 4); + } + } + } +} + void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -434,6 +469,31 @@ void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_REST } } +void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK_NVFP4; + static const int qk_sub = QK_NVFP4_SUB; + static const int n_sub = QK_NVFP4 / QK_NVFP4_SUB; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + for (int s = 0; s < n_sub; s++) { + const float d = ggml_ue4m3_to_fp32(x[i].d[s]); + float * yb = y + i*qk + s*qk_sub; + + for (int j = 0; j < qk_sub/2; ++j) { + const int8_t v0 = kvalues_mxfp4[x[i].qs[s*(qk_sub/2) + j] & 0x0F]; + const int8_t v1 = kvalues_mxfp4[x[i].qs[s*(qk_sub/2) + j] >> 4]; + + yb[j + 0 ] = v0*d; + yb[j + qk_sub/2] = v1*d; + } + } + } +} + // // 2-6 bit quantization in super-blocks // @@ -2098,6 +2158,12 @@ size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row); } +size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_UNUSED(quant_weights); + quantize_row_nvfp4_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row); +} + // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) { @@ -5244,6 +5310,12 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb); } break; + case GGML_TYPE_NVFP4: + { + // UE4M3 scales are uint8_t — all byte values are valid + GGML_UNUSED(data); + GGML_UNUSED(nb); + } break; case GGML_TYPE_Q2_K: { VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin); diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 3b688f31c21..00604f75c0e 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -22,6 +22,7 @@ GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); @@ -48,6 +49,7 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG //GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -95,6 +97,7 @@ GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API void iq2xs_init_impl(enum ggml_type type); GGML_API void iq2xs_free_impl(enum ggml_type type); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index aeafc395d71..e5b83e14479 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -718,6 +718,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) dequantize_row_mxfp4, .from_float_ref = (ggml_from_float_t)quantize_row_mxfp4_ref, }, + [GGML_TYPE_NVFP4] = { + .type_name = "nvfp4", + .blck_size = QK_NVFP4, + .type_size = sizeof(block_nvfp4), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_nvfp4, + .from_float_ref = (ggml_from_float_t)quantize_row_nvfp4_ref, + }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", .blck_size = QK_K, @@ -1374,6 +1382,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break; + case GGML_FTYPE_MOSTLY_NVFP4: wtype = GGML_TYPE_NVFP4; break; case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; @@ -7641,6 +7650,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; From d73fe252677a9aa03c5f1c5d64340a57ca848a81 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 11 Mar 2026 22:46:40 +0200 Subject: [PATCH 252/831] llama : enable chunked fused GDN path (llama/20340) * llama : enable chunked fused GDN path * models : avoid Q and K repeats when using fused GDA * cont : fix comment Co-authored-by: Aman Gupta * cont : fix the fix Co-authored-by: Aman Gupta * cont : fix * metal : add GDN kernel (llama/20361) * metal : add Metal backend for GGML_OP_GATED_DELTA_NET Add a fused Metal kernel for the gated delta net recurrence op (#19504), enabling GPU-accelerated inference for DeltaNet-based models (Qwen3.5, etc.) on Apple Silicon. Supports both GDA (scalar gate) and KDA (per-row gate) modes with head_size 64 and 128. Unsupported configurations (head_size 32, non-contiguous tensors) gracefully fall back to CPU. Performance: Qwen3.5-0.8B Q4_K_M on M4 Max tg128: 170 -> 213 t/s (+25%) Co-Authored-By: Claude Opus 4.6 * metal : validate contiguity of all input tensors in supports_op Co-Authored-By: Claude Opus 4.6 * metal : add algorithm equivalence comment for GDA decay path Co-Authored-By: Claude Opus 4.6 * cont : unslop + optimize * cont : clean-up --------- Co-authored-by: Paul Flynn Co-authored-by: Claude Opus 4.6 * CUDA: AR gated delta net improvements (llama/20391) * Add FastDiv to gated_delta_net_cuda * Shard columns across warps This reduces register pressure (avoids spill for S_v = 128) and gives the warp-scheduler more CTAs to schedule (thus hiding data-access latencies). * Remove unneded include in gated_delta_net.cu * Improve comments * Apply code-formating * Make sharding HIP-compatible 1. Use ggml_cuda_get_physical_warp_size() to determine warp size flexibly 2. Add test with partial warp to test sum reduction on CUDA * Remove fastdiv_s64, as we can treat neqk1 and rq3 as uint32_t * Rename variables * Enable GDN also for prefill, move TODO for chunked_GDN * Actually remove the TODO from 206890897546bd16602c3b79394fd5ea09ef199f * Get warp size at runtime warp_size is not known at compile time in hip host code. * Don't expose ggml_cuda_get_physical_warp_size on host --------- Co-authored-by: uvos * llama : refactor llm_build_delta_net_base API --------- Co-authored-by: Aman Gupta Co-authored-by: Paul Flynn Co-authored-by: Claude Opus 4.6 Co-authored-by: Oliver Simons Co-authored-by: uvos --- ggml/include/ggml.h | 2 + ggml/src/ggml-cpu/ops.cpp | 11 +- ggml/src/ggml-cuda/gated_delta_net.cu | 183 ++++++++++-------- ggml/src/ggml-metal/ggml-metal-device.cpp | 35 ++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 2 + ggml/src/ggml-metal/ggml-metal-impl.h | 39 ++++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 79 ++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 221 ++++++++++++++++++++++ 10 files changed, 489 insertions(+), 85 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3323f8e6c3f..25f9601e9b5 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2466,6 +2466,8 @@ extern "C" { bool lower, bool uni); + // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] + // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index f9c4ec16e4b..fa9d27046b5 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10443,8 +10443,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const float * state_in_base = (const float *)src_state->data; - const int64_t rq1 = nev1 / neq1; - const int64_t rk1 = nev1 / nek1; + //const int64_t rq1 = nev1 / neq1; + //const int64_t rk1 = nev1 / nek1; const int64_t rq3 = nev3 / neq3; const int64_t rk3 = nev3 / nek3; @@ -10454,8 +10454,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const int64_t iv1 = ir % H; // head_index const int64_t iv3 = ir / H; // sequence - const int64_t iq1 = iv1 / rq1; - const int64_t ik1 = iv1 / rk1; + const int64_t iq1 = iv1 % neq1; + const int64_t ik1 = iv1 % nek1; const int64_t iq3 = iv3 / rq3; const int64_t ik3 = iv3 / rk3; @@ -10475,7 +10475,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1); const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1); - const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); + const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); if (kda) { for (int64_t i = 0; i < S_v; ++i) { @@ -10508,7 +10508,6 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( attn_data += S_v * H; // advance to next token } - } } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index c249bbc86d5..5f0fa8e58df 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,36 +1,36 @@ #include "gated_delta_net.cuh" -#include "ggml-cuda/common.cuh" template -__global__ void __launch_bounds__(S_v, 1) -gated_delta_net_cuda(const float * q, - const float * k, - const float * v, - const float * g, - const float * beta, - const float * curr_state, - float * dst, - const int64_t H, - const int64_t n_tokens, - const int64_t n_seqs, - const int64_t sq1, - const int64_t sq2, - const int64_t sq3, - const int64_t sv1, - const int64_t sv2, - const int64_t sv3, - const int64_t sb1, - const int64_t sb2, - const int64_t sb3, - const int64_t rq1, - const int64_t rq3, - const float scale) { - const int64_t h_idx = blockIdx.x; - const int64_t sequence = blockIdx.y; - const int col = threadIdx.x; // each thread owns one column - - const int64_t iq1 = h_idx / rq1; - const int64_t iq3 = sequence / rq3; +__global__ void gated_delta_net_cuda(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * curr_state, + float * dst, + int64_t H, + int64_t n_tokens, + int64_t n_seqs, + int64_t sq1, + int64_t sq2, + int64_t sq3, + int64_t sv1, + int64_t sv2, + int64_t sv3, + int64_t sb1, + int64_t sb2, + int64_t sb3, + const uint3 neqk1_magic, + const uint3 rq3_magic, + float scale) { + const uint32_t h_idx = blockIdx.x; + const uint32_t sequence = blockIdx.y; + // each warp owns one column, using warp-level primitives to reduce across rows + const int lane = threadIdx.x; + const int col = blockIdx.z * blockDim.y + threadIdx.y; + + const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic); + const uint32_t iq3 = fastdiv(sequence, rq3_magic); const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; float * attn_data = dst; @@ -41,17 +41,14 @@ gated_delta_net_cuda(const float * q, curr_state += state_offset; attn_data += (sequence * n_tokens * H + h_idx) * S_v; - // GCN and CDNA devices spill registers, we use shared mem for them. See https://github.com/ggml-org/llama.cpp/pull/20282#issuecomment-4025770229 - // TODO: check optimal path for RDNA1 and RDNA2 devices. -#if (defined(GGML_USE_HIP) && !defined(RDNA3) && !defined(RDNA4)) || defined(GGML_USE_MUSA) - extern __shared__ float s_shared[]; - float * s = s_shared + col * S_v; -#else - float s[S_v]; -#endif + constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; + static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size"); + constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size; + float s_shard[rows_per_lane]; #pragma unroll - for (int i = 0; i < S_v; i++) { - s[i] = curr_state[i * S_v + col]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = curr_state[i * S_v + col]; } for (int t = 0; t < n_tokens; t++) { @@ -69,46 +66,61 @@ gated_delta_net_cuda(const float * q, const float g_val = expf(*g_t); // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i] - float kv_col = 0.0f; + float kv_shard = 0.0f; #pragma unroll - for (int i = 0; i < S_v; i++) { - kv_col += s[i] * k_t[i]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + kv_shard += s_shard[r] * k_t[i]; } + float kv_col = warp_reduce_sum(kv_shard); // delta[col] = (v[col] - g * kv[col]) * beta float delta_col = (v_t[col] - g_val * kv_col) * beta_val; // fused: S[i][col] = g * S[i][col] + k[i] * delta[col] // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] - float attn_col = 0.0f; + float attn_partial = 0.0f; #pragma unroll - for (int i = 0; i < S_v; i++) { - s[i] = g_val * s[i] + k_t[i] * delta_col; - attn_col += s[i] * q_t[i]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col; + attn_partial += s_shard[r] * q_t[i]; } - attn_data[col] = attn_col * scale; + float attn_col = warp_reduce_sum(attn_partial); + + if (lane == 0) { + attn_data[col] = attn_col * scale; + } } else { // kv[col] = sum_i g[i] * S[i][col] * k[i] - float kv_col = 0.0f; + float kv_shard = 0.0f; #pragma unroll - for (int i = 0; i < S_v; i++) { - kv_col += expf(g_t[i]) * s[i] * k_t[i]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i]; } + float kv_col = warp_reduce_sum(kv_shard); + // delta[col] = (v[col] - kv[col]) * beta float delta_col = (v_t[col] - kv_col) * beta_val; // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col] // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] - float attn_col = 0.0f; + float attn_partial = 0.0f; #pragma unroll - for (int i = 0; i < S_v; i++) { - s[i] = expf(g_t[i]) * s[i] + k_t[i] * delta_col; - attn_col += s[i] * q_t[i]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col; + attn_partial += s_shard[r] * q_t[i]; } - attn_data[col] = attn_col * scale; + float attn_col = warp_reduce_sum(attn_partial); + + if (lane == 0) { + attn_data[col] = attn_col * scale; + } } attn_data += S_v * H; @@ -116,8 +128,9 @@ gated_delta_net_cuda(const float * q, // Write state back to global memory #pragma unroll - for (int i = 0; i < S_v; i++) { - state[i * S_v + col] = s[i]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[i * S_v + col] = s_shard[r]; } } @@ -135,35 +148,43 @@ static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d, float * dst_d, - int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs, - int64_t sq1, int64_t sq2, int64_t sq3, - int64_t sv1, int64_t sv2, int64_t sv3, - int64_t sb1, int64_t sb2, int64_t sb3, - int64_t rq1, int64_t rq3, + int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs, + int64_t sq1, int64_t sq2, int64_t sq3, + int64_t sv1, int64_t sv2, int64_t sv3, + int64_t sb1, int64_t sb2, int64_t sb3, + int64_t neqk1, int64_t rq3, float scale, cudaStream_t stream) { + //TODO: Add chunked kernel for even faster pre-fill + const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; + const int num_warps = 4; + dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps); + dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1); - dim3 grid_dims(H, n_seqs, 1); - dim3 block_dims(S_v, 1, 1); + const uint3 neqk1_magic = init_fastdiv_values(neqk1); + const uint3 rq3_magic = init_fastdiv_values(rq3); int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; switch (S_v) { - case 32: { - constexpr int sv = 32; - size_t smem = calculate_smem(sv, cc); - gated_delta_net_cuda<<>>( + case 16: + gated_delta_net_cuda<16, KDA><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + break; + case 32: + gated_delta_net_cuda<32, KDA><<>>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; - } case 64: { constexpr int sv = 64; size_t smem = calculate_smem(sv, cc); gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; } case 128: { @@ -172,7 +193,7 @@ static void launch_gated_delta_net( gated_delta_net_cuda<<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; } default: @@ -190,10 +211,12 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_tensor * src_state = dst->src[5]; GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); - GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb); + GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne); + GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb); GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); - GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); - GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); const int64_t S_v = nev0; const int64_t H = nev1; @@ -202,7 +225,9 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * const bool kda = (src_g->ne[0] == S_v); - const int64_t rq1 = nev1 / neq1; + GGML_ASSERT(neq1 == nek1); + const int64_t neqk1 = neq1; + const int64_t rq3 = nev3 / neq3; const float * q_d = (const float *) src_q->data; @@ -241,10 +266,10 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * if (kda) { launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale, stream); + sb1, sb2, sb3, neqk1, rq3, scale, stream); } else { launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, rq1, rq3, scale, stream); + sb1, sb2, sb3, neqk1, rq3, scale, stream); } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 169c63dd7a4..15ae2e517df 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -577,6 +577,41 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_ return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + // v is src[2], dimensions: S_v = ne[0], H = ne[1] + const int ne20 = op->src[2]->ne[0]; // S_v + const int ne21 = op->src[2]->ne[1]; // H + const int ne30 = op->src[3]->ne[0]; // G + + const int nsg = op->src[2]->ne[0]/32; + + GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(op->ne[0] == ne20 * ne21); + GGML_ASSERT(ne20 % 32 == 0); + + snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg); + snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0); + ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + res.nsg = nsg; + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 93d7f6a216f..fd2b3ddeb55 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -125,6 +125,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index d42b8ab1eb1..05b826a61b8 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1155,6 +1155,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: return true; + case GGML_OP_GATED_DELTA_NET: + return op->src[2]->ne[0] % 32 == 0; case GGML_OP_SOLVE_TRI: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 99d64efc3b5..53437b23cda 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -84,6 +84,7 @@ #define FC_BIN 1300 #define FC_SUM_ROWS 1400 #define FC_UPSCALE 1500 +#define FC_GATED_DELTA_NET 1600 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -793,6 +794,44 @@ typedef struct { uint64_t nb0; } ggml_metal_kargs_ssm_scan; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne20; + int32_t ne21; + int32_t ne22; + int32_t ne23; + uint64_t nb20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ns02; + int32_t ns12; + int32_t ns22; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_gated_delta_net; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 267755d08cc..306dbcf3660 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -333,6 +333,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_rwkv(ctx, idx); } break; + case GGML_OP_GATED_DELTA_NET: + { + n_fuse = ggml_metal_op_gated_delta_net(ctx, idx); + } break; case GGML_OP_SOLVE_TRI: { n_fuse = ggml_metal_op_solve_tri(ctx, idx); @@ -1562,6 +1566,81 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op); + + int ida = 0; + + ggml_metal_kargs_gated_delta_net args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne20 =*/ ne20, + /*.ne21 =*/ ne21, + /*.ne22 =*/ ne22, + /*.ne23 =*/ ne23, + /*.nb20 =*/ nb20, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ns02 =*/ (int32_t) (nb02/sizeof(float)), + /*.ns12 =*/ (int32_t) (nb12/sizeof(float)), + /*.ns22 =*/ (int32_t) (nb22/sizeof(float)), + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst + + const int nsg = pipeline.nsg; + + ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1); + + return 1; +} + int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index f3e38c7aa9d..019f2fec9ed 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -58,6 +58,7 @@ int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_gated_delta_net (ggml_metal_op_t ctx, int idx); int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 29e4a245d5d..0b77d5349b8 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2434,6 +2434,227 @@ kernel void kernel_rwkv_wkv7_f32( } } +constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; +constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; + +#if 1 +template +kernel void kernel_gated_delta_net_impl( + constant ggml_metal_kargs_gated_delta_net & args, + device const char * q, + device const char * k, + device const char * v, + device const char * g, + device const char * b, + device const char * s, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +#define S_v FC_gated_delta_net_ne20 +#define G FC_gated_delta_net_ne30 + + const uint tx = tpitg.x; + const uint ty = tpitg.y; + + const uint i23 = tgpig.z; // B + const uint i21 = tgpig.y; // H + const uint i20 = tgpig.x*NSG + ty; + + const uint i01 = i21 % args.ne01; + const uint i11 = i21 % args.ne11; + + const float scale = 1.0f / sqrt((float)S_v); + + device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20; + + float ls[NSG]; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] = s_ptr[is*S_v]; + } + + device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20; + + device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01); + device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11); + device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21); + + device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); + device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + + for (short t = 0; t < args.ne22; t++) { + float s_k = 0.0f; + + if (G == 1) { + const float g_exp = exp(g_ptr[0]); + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] *= g_exp; + + s_k += ls[j]*k_ptr[is]; + } + } else { + // KDA + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] *= exp(g_ptr[is]); + + s_k += ls[j]*k_ptr[is]; + } + } + + s_k = simd_sum(s_k); + + const float d = (v_ptr[i20] - s_k)*b_ptr[0]; + + float y = 0.0f; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] += k_ptr[is]*d; + + y += ls[j]*q_ptr[is]; + } + + y = simd_sum(y); + + if (tx == 0) { + dst_attn[t*args.ne21*S_v] = y*scale; + } + + q_ptr += args.ns02; + k_ptr += args.ns12; + v_ptr += args.ns22; + + b_ptr += args.ne21; + g_ptr += args.ne21*G; + } + + device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is*S_v] = ls[j]; + } + +#undef S_v +#undef G +} + +typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t; + +template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>; +template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>; +template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>; + +#else +// a simplified version of the above +// no performance improvement, so keep the above version for now + +template +kernel void kernel_gated_delta_net_impl( + constant ggml_metal_kargs_gated_delta_net & args, + device const char * q, + device const char * k, + device const char * v, + device const char * g, + device const char * b, + device const char * s, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +#define S_v FC_gated_delta_net_ne20 +#define G FC_gated_delta_net_ne30 + + const uint tx = tpitg.x; + const uint ty = tpitg.y; + + const uint i23 = tgpig.z; // B + const uint i21 = tgpig.y; // H + const uint i20 = tgpig.x*NSG + ty; + + const uint i01 = i21 % args.ne01; + const uint i11 = i21 % args.ne11; + + const float scale = 1.0f / sqrt((float)S_v); + + device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20; + + float lsf[NSG]; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + lsf[j] = s_ptr[is*S_v]; + } + + thread T * ls = (thread T *) (lsf); + + device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20; + + device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01); + device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11); + device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21); + + device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); + device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + + for (short t = 0; t < args.ne22; t++) { + device const T * qt_ptr = (device const T *) (q_ptr); + device const T * kt_ptr = (device const T *) (k_ptr); + device const T * gt_ptr = (device const T *) (g_ptr); + + if (G == 1) { + *ls *= exp(g_ptr[0]); + } else { + // KDA + *ls *= exp(gt_ptr[tx]); + } + + const float s_k = simd_sum(dot(*ls, kt_ptr[tx])); + + const float d = (v_ptr[i20] - s_k)*b_ptr[0]; + + *ls += kt_ptr[tx]*d; + + const float y = simd_sum(dot(*ls, qt_ptr[tx])); + + if (tx == 0) { + *dst_attn = y*scale; + } + + q_ptr += args.ns02; + k_ptr += args.ns12; + v_ptr += args.ns22; + + b_ptr += args.ne21; + g_ptr += args.ne21*G; + + dst_attn += args.ne21*S_v; + } + + device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20; + device T * dstt_state = (device T *) (dst_state); + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is*S_v] = lsf[j]; + } + +#undef S_v +#undef G +} + +typedef decltype(kernel_gated_delta_net_impl) kernel_gated_delta_net_t; + +template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl; +template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl; +template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl; +#endif + constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]]; constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]]; constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]]; From 5267523829c97912dbbcde80fe1664b6c8d45b56 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Thu, 12 Mar 2026 06:40:36 +0900 Subject: [PATCH 253/831] ggml-webgpu: Add supports for `GGML_OP_REPEAT` (llama/20230) * Add GGML_OP_REPEAT to webgpu backend. * Add i16 support for GGML_OP_REPEAT. --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 71 +++++++++++++++++-- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 55 ++++++++++++-- ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl | 67 +++++++++++++++++ 3 files changed, 183 insertions(+), 10 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 3c38b1a230f..3d7e59fddf3 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -198,6 +198,22 @@ struct ggml_webgpu_concat_pipeline_key_hash { } }; +/** Repeat **/ + +struct ggml_webgpu_repeat_pipeline_key { + int type; + + bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; } +}; + +struct ggml_webgpu_repeat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + return seed; + } +}; + /** Binary **/ struct ggml_webgpu_binary_pipeline_key { @@ -431,6 +447,8 @@ class ggml_webgpu_shader_lib { binary_pipelines; // type/op/inplace/overlap std::unordered_map concat_pipelines; // type + std::unordered_map + repeat_pipelines; // type std::unordered_map flash_attn_pipelines; std::unordered_map defines; - std::string variant = "concat"; + std::string variant = "concat"; switch (key.type) { case GGML_TYPE_F32: @@ -1164,15 +1182,56 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_concat, defines); - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; + auto processed = preprocessor.preprocess(wgsl_concat, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; - concat_pipelines[key] = pipeline; + pipeline.context = decisions; + concat_pipelines[key] = pipeline; return concat_pipelines[key]; } + webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_repeat_pipeline_key key = { + .type = context.dst->type, + }; + + auto it = repeat_pipelines.find(key); + if (it != repeat_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "repeat"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_I32: + defines.push_back("TYPE_I32"); + variant += "_i32"; + break; + case GGML_TYPE_I16: + defines.push_back("TYPE_I16"); + variant += "_i16"; + break; + default: + GGML_ABORT("Unsupported type for repeat shader"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_repeat, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + repeat_pipelines[key] = pipeline; + return repeat_pipelines[key]; + } + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool has_mask = context.src3 != nullptr; const bool has_sinks = context.src4 != nullptr; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index ccc34cb153f..128b7dc3de8 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1567,6 +1567,48 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } +static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) { + uint32_t ne = (uint32_t) ggml_nelements(dst); + + std::vector params = { ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / + ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src0->ne[0]), + (uint32_t) (src0->ne[1]), + (uint32_t) (src0->ne[2]), + (uint32_t) (src0->ne[3]), + (uint32_t) (dst->ne[0]), + (uint32_t) (dst->ne[1]), + (uint32_t) (dst->ne[2]) }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); +} + static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { int inplace = ggml_webgpu_tensor_equal(src, dst); @@ -2158,6 +2200,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return ggml_webgpu_binary_op(ctx, src0, src1, node); case GGML_OP_CONCAT: return ggml_webgpu_concat(ctx, src0, src1, node); + case GGML_OP_REPEAT: + return ggml_webgpu_repeat(ctx, src0, node); case GGML_OP_RMS_NORM: return ggml_webgpu_rms_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -2919,10 +2963,10 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm /* .iface = */ { /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, /* .alloc_buffer = */ - ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */ - ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */ - ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */ - ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false + ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */ + ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */ + ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */ + ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false }, /* .device = */ dev, @@ -3000,6 +3044,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_CONCAT: supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32); break; + case GGML_OP_REPEAT: + supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16); + break; case GGML_OP_CPY: case GGML_OP_CONT: supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl new file mode 100644 index 00000000000..6e2a1a8b614 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl @@ -0,0 +1,67 @@ +enable f16; + +struct Params { + ne: u32, + + offset_src0: u32, + offset_dst: u32, + + stride_src0_0: u32, + stride_src0_1: u32, + stride_src0_2: u32, + stride_src0_3: u32, + + a_ne0: u32, + a_ne1: u32, + a_ne2: u32, + a_ne3: u32, + + ne0: u32, + ne1: u32, + ne2: u32, +}; + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_I32 +#define DataType i32 +#endif +#ifdef TYPE_I16 +// same size (16-bit) is sufficient for repeat +#define DataType f16 +#endif + +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let a_i0 = i0 % params.a_ne0; + let a_i1 = i1 % params.a_ne1; + let a_i2 = i2 % params.a_ne2; + let a_i3 = i3 % params.a_ne3; + + let a_index = a_i0 * params.stride_src0_0 + + a_i1 * params.stride_src0_1 + + a_i2 * params.stride_src0_2 + + a_i3 * params.stride_src0_3; + + dst[params.offset_dst + gid.x] = src0[params.offset_src0 + a_index]; + } +} From f5ba86537886bbf9be621aeff1e43375f3de9db2 Mon Sep 17 00:00:00 2001 From: uvos Date: Thu, 12 Mar 2026 03:37:10 +0100 Subject: [PATCH 254/831] hip: compile debug builds with -O2 on hip to avoid a compiler bug (llama/20392) --- ggml/src/ggml-hip/CMakeLists.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 80037d24361..b44ed0f7215 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -11,6 +11,10 @@ endif() list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake") +if (NOT DEFINED CMAKE_HIP_FLAGS_DEBUG) + set(CMAKE_HIP_FLAGS_DEBUG "-g -O2") +endif() + # CMake on Windows doesn't support the HIP language yet if (WIN32) set(CXX_IS_HIPCC TRUE) From 193781cf0ebd2b299271106f22db91920e62caa4 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Wed, 11 Mar 2026 22:03:07 -0700 Subject: [PATCH 255/831] opencl: add cumsum op (llama/18981) * OpenCL: add CUMSUM op support * remove unused argument * opencl: refactor cumsum * opencl: refactor * opencl: refactor tmp buffer * opencl: adjust max number of subgroups * opencl: fix whitespace * opencl: fix global size when cumsum the tmp buffer --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 139 +++++++++++++++++++++++++ ggml/src/ggml-opencl/kernels/cumsum.cl | 139 +++++++++++++++++++++++++ 3 files changed, 279 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/cumsum.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 70802c9c001..1f8250934b0 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -132,6 +132,7 @@ set(GGML_OPENCL_KERNELS ssm_conv sub sum_rows + cumsum transpose concat tsembd diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 67e4b9277f5..46a95a19990 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -547,6 +547,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_im2col_f32, kernel_im2col_f16; cl_kernel kernel_argsort_f32_i32; cl_kernel kernel_sum_rows_f32, kernel_sum_rows_f32_4; + cl_kernel kernel_cumsum_blk, kernel_cumsum_add; cl_kernel kernel_repeat_f32; cl_kernel kernel_pad; cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc; @@ -1927,6 +1928,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // cumsum + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "cumsum.cl.h" + }; +#else + const std::string kernel_src = read_file("cumsum.cl"); +#endif + cl_program prog; + prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_cumsum_blk = clCreateKernel(prog, "kernel_cumsum_blk", &err), err)); + CL_CHECK((backend_ctx->kernel_cumsum_add = clCreateKernel(prog, "kernel_cumsum_add", &err), err)); + GGML_LOG_CONT("."); + CL_CHECK(clReleaseProgram(prog)); + } + // sigmoid { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3803,6 +3822,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32; } case GGML_OP_SUM_ROWS: + case GGML_OP_CUMSUM: + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); case GGML_OP_MEAN: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_FLASH_ATTN_EXT: @@ -11949,6 +11970,118 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_cl_cumsum(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + + cl_kernel kernel = backend_ctx->kernel_cumsum_blk; + + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + int nth = 1; + while (nth < ne00 && 2*nth <= max_workgroup_size) { + nth *= 2; + } + + GGML_ASSERT(ne00 <= nth*nth); + + const int net0 = CEIL_DIV(ne00, nth); + const int net1 = ne01; + const int net2 = ne02; + const int net3 = ne03; + + const cl_ulong nbt0 = sizeof(float); + const cl_ulong nbt1 = net0*nbt0; + const cl_ulong nbt2 = net1*nbt1; + const cl_ulong nbt3 = net2*nbt2; + + static ggml_cl_buffer tmp_buffer; + tmp_buffer.allocate(backend_ctx->context, net0*ne01*ne02*ne03*sizeof(float)); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &tmp_buffer.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &net0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &net1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &net2)); + + size_t global_work_size[] = { (size_t)(nth*net0*ne01), (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = { (size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + if(ne00 > nth) { + // if a single workgroup cannot handle an entire row, each workgroup + // computes a partial sum and stores to dst, tmp_buffer contains the sum + // of the each workgroup; cumsum this buffer and add to the partial sums in dst + cl_ulong offsett = 0; + kernel = backend_ctx->kernel_cumsum_blk; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &tmp_buffer.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offsett)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &tmp_buffer.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &tmp_buffer.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsett)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &net0)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nbt0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nbt1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nbt2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nbt3)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &net0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &net1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &net2)); + + size_t global_work_size_1[] = { (size_t)net1*nth, (size_t)net2, (size_t)net3}; + size_t local_work_size_1[] = { (size_t)nth, 1, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_1, local_work_size_1, dst); + + kernel = backend_ctx->kernel_cumsum_add; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &tmp_buffer.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &nbt0)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &nbt1)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &nbt2)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &nbt3)); + + size_t global_work_size_2[] = { (size_t)(nth*net0*ne01), (size_t)ne02, (size_t)ne03}; + size_t local_work_size_2[] = { (size_t)nth, 1, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_2, local_work_size_2, dst); + } +} + static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -12391,6 +12524,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_sum_rows; break; + case GGML_OP_CUMSUM: + if (!any_on_device) { + return false; + } + func = ggml_cl_cumsum; + break; case GGML_OP_FLASH_ATTN_EXT: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/cumsum.cl b/ggml/src/ggml-opencl/kernels/cumsum.cl new file mode 100644 index 00000000000..edfb74b7058 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/cumsum.cl @@ -0,0 +1,139 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +// max workgroup size is usually 1024, this covers various subgroups sizes +#define MAX_SUBGROUPS 128 + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_cumsum_blk( + global char * src0, + ulong offset0, + global char * tmp, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + uint net0, + uint net1, + uint net2 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int nth = get_local_size(0); + const int tid = get_local_id(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + const int ib = i1 / ne01; + const int i00 = ib * nth; + const int i01 = i1 % ne01; + const int i02 = i2; + const int i03 = i3; + + global const float * src0_row = (global const float *)(src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float * tmp_row = (global float *)tmp + net0 * i01 + net0 * net1 * i02 + net0 * net1 * net2 * i03; + global float * dst_row = (global float *)dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + __local float partial[MAX_SUBGROUPS]; + + float v = 0.0f; + if (i00 + tid < ne00) { + v = src0_row[i00 + tid]; + } + + float s = sub_group_scan_inclusive_add(v); + if (sg_lid == sg_size - 1) { + partial[sg_id] = s; + } + barrier(CLK_LOCAL_MEM_FENCE); + + // NB: subgroup size should be larger than number of subgroups + // assuming max workgroup size of 1024, subgroup size should be >= 32 + if (sg_id == 0) { + float x = 0.0f; + if (sg_lid < get_num_sub_groups()) { + x = partial[sg_lid]; + } + float ex = sub_group_scan_exclusive_add(x); + if (sg_lid < get_num_sub_groups()) { + partial[sg_lid] = ex; + } + } + barrier(CLK_LOCAL_MEM_FENCE); + + s += partial[sg_id]; + + if (i00 + tid < ne00) { + dst_row[i00 + tid] = s; + } + if (ne00 > nth && tid == nth - 1) { + tmp_row[ib] = s; + } +} + +kernel void kernel_cumsum_add( + global char * tmp, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + uint nbt0, + uint nbt1, + uint nbt2, + uint nbt3 +) { + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int nth = get_local_size(0); + const int tid = get_local_id(0); + + const int ib = i1 / ne01; + if (ib == 0) { + return; + } + const int i00 = ib * nth; + const int i01 = i1 % ne01; + const int i02 = i2; + const int i03 = i3; + + global float * tmp_row = (global float *)(tmp + nbt1 * i01 + nbt2 * i02 + nbt3 * i03); + global float * dst_row = (global float *)dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + if (i00 + tid < ne00) { + dst_row[i00 + tid] += tmp_row[ib - 1]; + } +} From d5772cf7b27ad19e48f0db2537fd07ac89ca468f Mon Sep 17 00:00:00 2001 From: lhez Date: Wed, 11 Mar 2026 22:03:27 -0700 Subject: [PATCH 256/831] opencl: use larger workgroup size for get_rows (llama/20316) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 29 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 46a95a19990..e1dca6b4b4d 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -5796,19 +5796,12 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const int ne00 = src0->ne[0]; - const cl_ulong nb01 = src0->nb[1]; - const cl_ulong nb02 = src0->nb[2]; - const cl_ulong nb03 = src0->nb[3]; - const int ne10 = src1->ne[0]; - const cl_ulong nb10 = src1->nb[0]; - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; - const cl_ulong nb11 = src1->nb[1]; - const cl_ulong nb12 = src1->nb[2]; - const cl_ulong nb1 = dst->nb[1]; - const cl_ulong nb2 = dst->nb[2]; - const cl_ulong nb3 = dst->nb[3]; + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne1, src1, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -5854,8 +5847,14 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2)); CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3)); - size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12}; - size_t local_work_size[] = {64, 1, 1}; + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + int nth = 1; + while (nth < ne00 && 2*nth <= max_workgroup_size) { + nth *= 2; + } + + size_t global_work_size[] = {(size_t)ne10*nth, (size_t)ne11, (size_t)ne12}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } From 26ee4f73623afa6a5f214780843e3a58e417fcf4 Mon Sep 17 00:00:00 2001 From: Masato Nakasaka Date: Wed, 11 Mar 2026 22:30:16 -0700 Subject: [PATCH 257/831] vulkan: Fix ErrorOutOfHostMemory on Intel GPU when loading large models with --no-mmap (llama/20059) * Changed to reuse command buffers to fix crashing on Intel GPU * Removed unused parameter * Fixed compile error and minor mistake * Fix logging * Changing to use usage flag per command buffer * fixed style * added buffer reset * Removed cmd_buffer_idx for reuse consistency * Fixed style --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 129 +++++++++++++++++---------- 1 file changed, 81 insertions(+), 48 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 61d112c50a7..8807c3e2b6e 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -27,6 +27,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include #include #include +#include #include #include #include @@ -188,6 +189,11 @@ struct ggml_backend_vk_buffer_type_context { struct vk_queue; +struct vk_command_buffer { + vk::CommandBuffer buf; + bool in_use = false; +}; + // Stores command pool/buffers. There's an instance of this // for each (context,queue) pair and for each (device,queue) pair. struct vk_command_pool { @@ -195,10 +201,16 @@ struct vk_command_pool { void destroy(vk::Device& device); vk::CommandPool pool; - uint32_t cmd_buffer_idx; - std::vector cmd_buffers; + // Using deque so the pointers to command buffers + // remain valid even if we add more + std::deque cmd_buffers; vk_queue *q; + + size_t buffers_in_use() const { + return std::count_if(cmd_buffers.begin(), cmd_buffers.end(), + [](const auto& cb) { return cb.in_use; }); + } }; // Prevent simultaneous submissions to the same queue. @@ -878,10 +890,12 @@ struct vk_device_struct { }; void vk_command_pool::init(vk_device& device, vk_queue *q_) { - cmd_buffer_idx = 0; + cmd_buffers.clear(); q = q_; - vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index); + vk::CommandPoolCreateInfo command_pool_create_info( + vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT | VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT), + q->queue_family_index); pool = device->device.createCommandPool(command_pool_create_info); } @@ -929,6 +943,7 @@ struct vk_subbuffer { struct vk_event { vk::Event event; vk::Fence fence; + vk_command_buffer* cmd_buffer = nullptr; }; struct vk_semaphore { @@ -937,7 +952,7 @@ struct vk_semaphore { }; struct vk_submission { - vk::CommandBuffer buffer; + vk_command_buffer* buffer = nullptr; std::vector wait_semaphores; std::vector signal_semaphores; }; @@ -2283,25 +2298,15 @@ static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx } } -static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) { +static vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) { VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()"); - - if (p.cmd_buffers.size() > p.cmd_buffer_idx) { - // Reuse command buffer - return p.cmd_buffers[p.cmd_buffer_idx++]; - } - vk::CommandBufferAllocateInfo command_buffer_alloc_info( p.pool, vk::CommandBufferLevel::ePrimary, 1); const std::vector cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info); - auto buf = cmd_buffers.front(); - - p.cmd_buffers.push_back(buf); - p.cmd_buffer_idx++; - - return buf; + p.cmd_buffers.push_back({ cmd_buffers.front(), true }); + return &p.cmd_buffers[p.cmd_buffers.size()-1]; } static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { @@ -2368,7 +2373,7 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { tl_wait_semaphores[idx].data(), stage_flags[idx].data(), 1, - &submission.buffer, + &submission.buffer->buf, (uint32_t) submission.signal_semaphores.size(), tl_signal_semaphores[idx].data(), }; @@ -2492,7 +2497,11 @@ static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) // Requires command buffers to be done device->device.resetCommandPool(p.pool); - p.cmd_buffer_idx = 0; + // Don't clear the command buffers and mark them as not in use. + // This allows us to reuse them + for (auto& cmd_buffer : p.cmd_buffers) { + cmd_buffer.in_use = false; + } } static void ggml_vk_queue_command_pools_cleanup(vk_device& device) { @@ -2501,10 +2510,10 @@ static void ggml_vk_queue_command_pools_cleanup(vk_device& device) { // Arbitrary frequency to cleanup/reuse command buffers static constexpr uint32_t cleanup_frequency = 10; - if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + if (device->compute_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) { ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool); } - if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + if (device->transfer_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) { ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool); } } @@ -2752,7 +2761,7 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; } - subctx->s->buffer.pipelineBarrier( + subctx->s->buffer->buf.pipelineBarrier( subctx->p->q->stage_flags, subctx->p->q->stage_flags, {}, @@ -2768,7 +2777,7 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) { VK_LOG_DEBUG("ggml_vk_set_event()"); - ctx->s->buffer.setEvent( + ctx->s->buffer->buf.setEvent( event, ctx->p->q->stage_flags ); @@ -2780,7 +2789,7 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events return; } - ctx->s->buffer.waitEvents( + ctx->s->buffer->buf.waitEvents( events, ctx->p->q->stage_flags, ctx->p->q->stage_flags, @@ -6348,13 +6357,24 @@ static vk_subbuffer ggml_vk_tensor_subbuffer( return vk_subbuffer{buffer, offset, size}; } +// Get a command buffer from pool. Create a new one if no reusable buffer is available +static vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) { + for (auto& cmd_buffer : pool.cmd_buffers) { + if (!cmd_buffer.in_use) { + cmd_buffer.in_use = true; + return &cmd_buffer; + } + } + return ggml_vk_create_cmd_buffer(device, pool); +} + static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) { vk_submission s; - s.buffer = ggml_vk_create_cmd_buffer(device, p); + s.buffer = ggml_vk_get_or_create_cmd_buffer(device, p); if (one_time) { - s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit }); + s.buffer->buf.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit }); } else { - s.buffer.begin({ vk::CommandBufferUsageFlags{} }); + s.buffer->buf.begin({ vk::CommandBufferUsageFlags{} }); } return s; @@ -6407,18 +6427,18 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {}); - subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants)); - subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline); - subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, + subctx->s->buffer->buf.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants)); + subctx->s->buffer->buf.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline); + subctx->s->buffer->buf.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipeline->layout, 0, { descriptor_set }, {}); - subctx->s->buffer.dispatch(wg0, wg1, wg2); + subctx->s->buffer->buf.dispatch(wg0, wg1, wg2); } static void ggml_vk_end_submission(vk_submission& s, std::vector wait_semaphores, std::vector signal_semaphores) { - s.buffer.end(); + s.buffer->buf.end(); s.wait_semaphores = std::move(wait_semaphores); s.signal_semaphores = std::move(signal_semaphores); @@ -6430,7 +6450,7 @@ static void ggml_vk_ctx_end(vk_context& ctx) { return; } - ctx->s->buffer.end(); + ctx->s->buffer->buf.end(); ctx->s = nullptr; } @@ -6584,7 +6604,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont } ggml_vk_sync_buffers(ctx, subctx); - subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); + subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices); return; } @@ -6599,7 +6619,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont VkBufferCopy buf_copy{ 0, offset, copy_size }; ggml_vk_sync_buffers(ctx, subctx); - vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); for (uint64_t i3 = 0; i3 < ne3; i3++) { for (uint64_t i2 = 0; i2 < ne2; i2++) { @@ -6648,7 +6668,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz } ggml_vk_sync_buffers(nullptr, subctx); - subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); + subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices); return true; } VK_LOG_DEBUG("STAGING"); @@ -6670,7 +6690,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz copy_size}; ggml_vk_sync_buffers(nullptr, subctx); - vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); if (width == spitch) { deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys); @@ -6756,7 +6776,7 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size if (buf != nullptr) { // Memory is pinned, use as staging buffer ggml_vk_sync_buffers(nullptr, subctx); - subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices); + subctx->s->buffer->buf.copyBuffer(src->buffer, buf->buffer, slices); return true; } @@ -6774,7 +6794,7 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size vk_buffer& staging_buffer = src->device->sync_staging; ggml_vk_sync_buffers(nullptr, subctx); - subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices); + subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, slices); deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); return true; @@ -6821,7 +6841,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds VkBufferCopy bc{ src_offset, dst_offset, size }; - vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc); + vkCmdCopyBuffer(ctx->s->buffer->buf, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc); } static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { @@ -6859,7 +6879,7 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t } // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers - ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); + ctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c); } static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { @@ -6874,7 +6894,7 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz std::lock_guard guard(dst->device->mutex); vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); - subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); + subctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c); ggml_vk_ctx_end(subctx); ggml_vk_submit(subctx, dst->device->fence); @@ -12682,7 +12702,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr if (vk_perf_logger_enabled && vk_perf_logger_concurrent) { ctx->query_node_idx[ctx->query_idx] = node_idx; - compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); } } // Add all fused nodes to the unsynchronized lists. @@ -13521,7 +13541,7 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor buffer_cpy.dstOffset = dst_offset; buffer_cpy.size = size; - cpy_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); + cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys); ggml_vk_synchronize(ctx); } @@ -13555,7 +13575,7 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ buffer_cpy.dstOffset = 0; buffer_cpy.size = size; - compute_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy }); + compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy }); deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys); ggml_vk_synchronize(ctx); } @@ -13633,8 +13653,12 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { } vk_context compute_ctx; + vk_command_buffer* cmd_buf = nullptr; if (do_transfer) { compute_ctx = ctx->compute_ctx.lock(); + if (compute_ctx->s) { + cmd_buf = compute_ctx->s->buffer; + } ggml_vk_ctx_end(compute_ctx); @@ -13668,6 +13692,9 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { } ggml_vk_wait_for_fence(ctx); ctx->submit_pending = false; + if (cmd_buf) { + cmd_buf->in_use = false; + } } if (do_transfer) { @@ -14157,7 +14184,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg GGML_ASSERT(ctx->compute_ctx.expired()); compute_ctx = ggml_vk_get_compute_ctx(ctx); ctx->query_idx = 0; - compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); } ctx->prealloc_y_last_pipeline_used = nullptr; @@ -14393,7 +14420,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg // track a single node/fusion for the current query ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i]; ctx->query_fusion_names[ctx->query_idx] = fusion_string; - compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); } else { // track a fusion string and number of fused ops for the current node_idx ctx->query_fusion_names[i] = fusion_string; @@ -14726,6 +14753,7 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev ggml_vk_submit_transfer_ctx(ctx); vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); + auto* cmd_buf = compute_ctx->s->buffer; // retrieve pointer before it gets reset // the backend interface doesn't have an explicit reset, so reset it here // before we record the command to set it @@ -14738,6 +14766,7 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev ggml_vk_submit(compute_ctx, {vkev->fence}); ctx->submit_pending = true; + vkev->cmd_buffer = cmd_buf; ctx->compute_ctx.reset(); } @@ -15557,6 +15586,10 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm vk_event *vkev = (vk_event *)event->context; VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize"); + // Finished using current command buffer so we flag for reuse + if (vkev->cmd_buffer) { + vkev->cmd_buffer->in_use = false; + } } static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) { From 6c5e3aac3eebec72887ff4ea10623d2c84430f52 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 12 Mar 2026 00:35:49 -0500 Subject: [PATCH 258/831] vulkan: fix OOB check in flash_attn_mask_opt (llama/20296) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 +- .../vulkan-shaders/flash_attn_mask_opt.comp | 98 +++++++++++-------- 2 files changed, 60 insertions(+), 40 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8807c3e2b6e..6574955cf10 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8840,7 +8840,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. - bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768; + bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16; vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc, mask != nullptr, use_mask_opt, logit_softcap != 0); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp index 8c92c1adcda..0e417708062 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp @@ -33,6 +33,61 @@ layout (push_constant) uniform parameter { shared float minsh[NUM_SUBGROUPS]; shared float maxsh[NUM_SUBGROUPS]; +float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF); + +void loadvec4(inout uint result, const uint i0, const uint i1, const uint i2, const uint i3, const bool need_bounds_check) { + const uint tid = gl_LocalInvocationIndex; + + [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { + float min_v = FLT_MAX_OVER_2; + float max_v = -FLT_MAX_OVER_2; + [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) { + uint j0 = (i + tid) % (Bc / 4); + uint j1 = (i + tid) / (Bc / 4); + + j0 *= 4; + j0 += (i0 * 16 + block_x) * Bc; + j1 += i1 * Br; + + if (!need_bounds_check || j0 + 3 < nem0) { + vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]); + [[unroll]] for (int c = 0; c < 4; ++c) { + min_v = min(min_v, f[c]); + max_v = max(max_v, f[c]); + } + } else { + [[unroll]] for (int c = 0; c < 4; ++c) { + if (j0 + c < nem0) { + float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]); + min_v = min(min_v, f); + max_v = max(max_v, f); + } + } + } + } + min_v = subgroupMin(min_v); + max_v = subgroupMax(max_v); + if (gl_SubgroupInvocationID == 0) { + minsh[gl_SubgroupID] = min_v; + maxsh[gl_SubgroupID] = max_v; + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) { + min_v = min(min_v, minsh[i]); + max_v = max(max_v, maxsh[i]); + } + if (max_v <= -FLT_MAX_OVER_2) { + result |= 1 << (2*block_x); + } + if (min_v == 0.0f && max_v == 0.0f) { + result |= 2 << (2*block_x); + } + } + barrier(); + } +} + // For each Br x Bc block of the mask (input) buffer, read all values and check // if it's all -inf or all zero. Write out a two-bit code indicating which it is // (or zero for neither). Each workgroup processes 16 tiles and writes out a @@ -48,50 +103,15 @@ void main() { const uint i2 = gl_WorkGroupID.z % nem2; const uint i3 = gl_WorkGroupID.z / nem2; - float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF); - uint result = 0; // Fast path for fully in-bounds blocks where we can do f16vec4 loads if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 && ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) { - [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { - float min_v = FLT_MAX_OVER_2; - float max_v = -FLT_MAX_OVER_2; - [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) { - uint j0 = (i + tid) % (Bc / 4); - uint j1 = (i + tid) / (Bc / 4); - - j0 *= 4; - j0 += (i0 * 16 + block_x) * Bc; - j1 += i1 * Br; - - vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]); - [[unroll]] for (int c = 0; c < 4; ++c) { - min_v = min(min_v, f[c]); - max_v = max(max_v, f[c]); - } - } - min_v = subgroupMin(min_v); - max_v = subgroupMax(max_v); - if (gl_SubgroupInvocationID == 0) { - minsh[gl_SubgroupID] = min_v; - maxsh[gl_SubgroupID] = max_v; - } - barrier(); - if (tid == 0) { - [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) { - min_v = min(min_v, minsh[i]); - max_v = max(max_v, maxsh[i]); - } - if (max_v <= -FLT_MAX_OVER_2) { - result |= 1 << (2*block_x); - } - if (min_v == 0.0f && max_v == 0.0f) { - result |= 2 << (2*block_x); - } - } - barrier(); + if ((i0 + 1) * 16 * Bc <= nem0) { + loadvec4(result, i0, i1, i2, i3, false); + } else { + loadvec4(result, i0, i1, i2, i3, true); } } else { [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { From 86e312d61de2ae001fed8966ca9ccfd1d4f0ce70 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 12 Mar 2026 00:39:41 -0500 Subject: [PATCH 259/831] vulkan: fix l2_norm epsilon handling (llama/20350) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 6574955cf10..ce3c85e7589 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -16061,7 +16061,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_arange(ggml_ctx, start, stop, step); } else if (tensor->op == GGML_OP_FILL) { const float value = ggml_get_op_params_f32(tensor, 0); - tensor_clone = ggml_fill(ggml_ctx, tensor_clone, value); + tensor_clone = ggml_fill(ggml_ctx, src_clone[0], value); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SQRT) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp index 7d0a1de0df9..f9af46744df 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -36,7 +36,7 @@ void main() { barrier(); } - const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1))); + const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1)); [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0])); From 7ccebd52640336b1ef011dd38d2ddf3a1245b6b7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 16 Mar 2026 07:12:37 +0200 Subject: [PATCH 260/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 769d0fb6684..4be1f07dcf0 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -4773cde162a55f0d10a6a6d7c2ea4378e30e0b01 +4136ece06e4d5748212bad90e2916f71ce3aeffb From b48ffe28fce816fef851d2b4d7bb931569b258c1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 16 Mar 2026 07:12:50 +0200 Subject: [PATCH 261/831] metal : avoid divisions in bin kernel (llama/20426) --- ggml/src/ggml-metal/ggml-metal-context.m | 4 +++- ggml/src/ggml-metal/ggml-metal-device.cpp | 4 +++- ggml/src/ggml-metal/ggml-metal-ops.cpp | 4 +--- ggml/src/ggml-metal/ggml-metal.metal | 11 +++++++---- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m index 855fd1adae8..32d97cd5d0a 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -554,7 +554,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * // enter here only when capturing in order to wait for all computation to finish // otherwise, we leave the graph to compute asynchronously - if (!use_capture && ctx->capture_started) { + if (use_capture && ctx->capture_started) { // wait for completion and check status of each command buffer // needed to detect if the device ran out-of-memory for example (#1881) { @@ -606,6 +606,8 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * [ctx->capture_scope endScope]; [[MTLCaptureManager sharedCaptureManager] stopCapture]; + + ctx->capture_started = false; } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 15ae2e517df..72ad876d5e4 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1470,10 +1470,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0); + const bool is_cb = op->src[0]->ne[0] != op->src[1]->ne[0]; const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536; snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : ""); - snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb); + snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d_cb=%d", base, op_num, n_fuse, is_rb, is_cb); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -1482,6 +1483,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_l ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1); ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2); + ggml_metal_cv_set_bool (cv, is_cb, FC_BIN + 3); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 306dbcf3660..c0bcad392b9 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3180,9 +3180,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, bid_dst, 3); if (pipeline.cnt) { - const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1); } else { const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0b77d5349b8..24a3092af22 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1111,6 +1111,7 @@ template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_un constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; constant short FC_bin_f [[function_constant(FC_BIN + 1)]]; constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]]; +constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]]; template kernel void kernel_bin_fuse_impl( @@ -1124,11 +1125,12 @@ kernel void kernel_bin_fuse_impl( #define FC_OP FC_bin_op #define FC_F FC_bin_f #define FC_RB FC_bin_rb +#define FC_CB FC_bin_cb if (FC_RB) { // row broadcast - const uint i0 = tgpig.x; - const uint i1 = i0%args.ne10; + const uint i0 = tgpig.y*args.ne00 + tgpig.x; + const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x; device const T0 * src0_row = (device const T0 *) (src0); device T * dst_row = (device T *) (dst); @@ -1200,7 +1202,7 @@ kernel void kernel_bin_fuse_impl( device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; + const int i10 = FC_CB ? i0%args.ne10 : i0; if (FC_OP == 0) { dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10]; @@ -1225,7 +1227,7 @@ kernel void kernel_bin_fuse_impl( } for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; + const int i10 = FC_CB ? i0%args.ne10 : i0; T res = src0_ptr[i0]; @@ -1261,6 +1263,7 @@ kernel void kernel_bin_fuse_impl( #undef FC_OP #undef FC_F #undef FC_RB +#undef FC_CB } typedef decltype(kernel_bin_fuse_impl) kernel_bin_fuse_t; From 7e816a99d21d0c9f0ce58a7004afe2d5c332f25a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 16 Mar 2026 07:13:14 +0200 Subject: [PATCH 262/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 4be1f07dcf0..444f031fa2e 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -4136ece06e4d5748212bad90e2916f71ce3aeffb +75e4e5b841fc483127ebf14b8eed9ce589a52c5a From 44c12c642e298672e90a970f33e57360d0a2f0ab Mon Sep 17 00:00:00 2001 From: ProgenyAlpha Date: Thu, 12 Mar 2026 05:03:18 -0400 Subject: [PATCH 263/831] vulkan: fix SSM_CONV PP scaling with large ubatch sizes (llama/20379) * vulkan: optimize SSM_CONV workgroup dispatch for large ubatch Tile tokens into 2D workgroups (32x16) to reduce workgroup launch overhead at large ubatch sizes. Add vec4 fast path for nc=4 (common d_conv size). Fixes PP performance degradation with ubatch > 512. Ref: ggml-org/llama.cpp#18725 Co-Authored-By: Claude Opus 4.6 * vulkan: remove unused shared memory declaration in SSM_CONV Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Progeny Alpha Co-authored-by: Claude Opus 4.6 --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 +- .../ggml-vulkan/vulkan-shaders/ssm_conv.comp | 26 ++++++++++++------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index ce3c85e7589..2a2f7f4f11c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4576,7 +4576,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); } - ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1); ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp index d62696bcfae..6802b1fc955 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp @@ -5,8 +5,9 @@ #include "types.glsl" layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(constant_id = 1) const uint TOKENS_PER_WG = 16; -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in; layout(binding = 0) readonly buffer Src0 { float src0[]; }; layout(binding = 1) readonly buffer Src1 { float src1[]; }; @@ -20,25 +21,30 @@ layout(push_constant) uniform PushConstants { }; void main() { - const uint global_thread_id = gl_GlobalInvocationID.x; - const uint i2 = gl_WorkGroupID.y; + const uint i1 = gl_GlobalInvocationID.x; + const uint i2 = gl_WorkGroupID.y * TOKENS_PER_WG + gl_LocalInvocationID.y; const uint i3 = gl_WorkGroupID.z; - if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) { + if (i1 >= nr || i2 >= n_t || i3 >= n_s) { return; } - const uint i1 = global_thread_id; const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4); const uint src1_base = i1 * (nb11 / 4); - const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1; float sum = 0.0; - [[unroll]] for (uint i0 = 0; i0 < nc; i0++) { - const uint src0_idx = src0_base + i0; - const uint src1_idx = src1_base + i0; - sum += src0[src0_idx] * src1[src1_idx]; + + if (nc == 4) { + sum = dot( + vec4(src0[src0_base], src0[src0_base + 1], src0[src0_base + 2], src0[src0_base + 3]), + vec4(src1[src1_base], src1[src1_base + 1], src1[src1_base + 2], src1[src1_base + 3]) + ); + } else { + [[unroll]] for (uint i0 = 0; i0 < nc; i0++) { + sum += src0[src0_base + i0] * src1[src1_base + i0]; + } } + const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1; dst[dst_idx] = sum; } From 245091966549dbe2b02f981be4b230dcc2df603a Mon Sep 17 00:00:00 2001 From: ProgenyAlpha Date: Thu, 12 Mar 2026 06:32:04 -0400 Subject: [PATCH 264/831] vulkan: add GATED_DELTA_NET op support (llama/20334) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * vulkan: add GATED_DELTA_NET op support Implements the fused gated delta net recurrence as a Vulkan compute shader with full support for scalar gate, KDA vector gate, GQA broadcast, multi-token sequences, and permuted (non-contiguous) q/k inputs. Specialization constants select head size (32/64/128) and KDA mode at pipeline creation time. Passes all 13 test-backend-ops cases on AMD Radeon 890M (RADV GFX1150). Co-Authored-By: Claude Opus 4.6 * vulkan: optimize GATED_DELTA_NET shader (Phase 1) - vec4 dot products on all inner loops (dp4 hardware intrinsic) - Cache exp(g) in shared memory for KDA path, eliminating ~32K redundant global reads and ~16K redundant exp() calls per token - vec4 fused decay + rank-1 update (3 vec4 ops vs 12 scalar ops) - Add perf benchmark cases for GATED_DELTA_NET to test-backend-ops KDA TG: +5.4% throughput. Non-KDA: no regressions. 13/13 test-backend-ops passing on AMD Radeon 890M (RADV GFX1150). Co-Authored-By: Claude Opus 4.6 * vulkan: address review feedback for GATED_DELTA_NET Pipeline array refactor [3][2], A_TYPE/D_TYPE/FLOAT_TYPE shader macros, scale in push constants, supports_op fix, dispatch restructuring. Co-Authored-By: Claude Opus 4.6 * vulkan: use FLOAT_TYPE for buffer/shared declarations, align formatting Co-Authored-By: Claude Opus 4.6 * vulkan: add explicit FLOAT_TYPE casts for buffer loads Wrap data_q, data_k, and data_g buffer reads with FLOAT_TYPE() casts to ensure correct behavior across all Vulkan configurations. Co-Authored-By: Claude Opus 4.6 * vulkan: fix Q/K broadcast for interleaved head layout Adapt to the interleaved broadcast convention from #20340: head_id / rq1 → head_id % neq1 Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Progeny Alpha Co-authored-by: Claude Opus 4.6 --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 119 ++++++++++++++++ .../vulkan-shaders/gated_delta_net.comp | 128 ++++++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 + 3 files changed, 249 insertions(+) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2a2f7f4f11c..3c81805b844 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -825,6 +825,8 @@ struct vk_device_struct { vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; + // [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128 + vk_pipeline pipeline_gated_delta_net[3][2]; vk_pipeline pipeline_ssm_scan_f32_d128; vk_pipeline pipeline_ssm_scan_f32_d256; vk_pipeline pipeline_ssm_conv_f32; @@ -1454,6 +1456,18 @@ struct vk_op_rwkv_wkv7_push_constants { uint32_t C; uint32_t H; }; +struct vk_op_gated_delta_net_push_constants { + uint32_t H; + uint32_t n_tokens; + uint32_t n_seqs; + uint32_t s_off; + uint32_t sq1, sq2, sq3; + uint32_t sv1, sv2, sv3; + uint32_t sb1, sb2, sb3; + uint32_t neq1, rq3; + float scale; +}; + struct vk_op_ssm_scan_push_constants { uint32_t nb02, nb03, nb12, nb13; uint32_t nb21, nb22, nb31; @@ -4568,6 +4582,23 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + { + const uint32_t gdn_sizes[] = {32, 64, 128}; + const char * gdn_names[][2] = { + {"gated_delta_net_f32_d32", "gated_delta_net_f32_d32_kda"}, + {"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"}, + {"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"}, + }; + for (uint32_t si = 0; si < 3; si++) { + for (uint32_t kda = 0; kda < 2; kda++) { + ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda], + gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data, + "main", 7, sizeof(vk_op_gated_delta_net_push_constants), + {1, 1, 1}, {gdn_sizes[si], kda}, 1); + } + } + } + if (device->subgroup_arithmetic && device->subgroup_require_full_support) { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true); ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true); @@ -9498,6 +9529,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rwkv_wkv7_f32; } return nullptr; + case GGML_OP_GATED_DELTA_NET: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + const uint32_t S_v = dst->src[2]->ne[0]; + const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0; + uint32_t si; + switch (S_v) { + case 32: si = 0; break; + case 64: si = 1; break; + case 128: si = 2; break; + default: return nullptr; + } + return ctx->device->pipeline_gated_delta_net[si][kda]; + } + return nullptr; case GGML_OP_SSM_SCAN: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { const uint32_t d_state = src0->ne[0]; @@ -10328,6 +10373,59 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ); } +static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { + const ggml_tensor * src_q = dst->src[0]; + const ggml_tensor * src_v = dst->src[2]; + const ggml_tensor * src_beta = dst->src[4]; + + GGML_ASSERT(dst->buffer != nullptr); + + const uint32_t S_v = (uint32_t)src_v->ne[0]; + const uint32_t H = (uint32_t)src_v->ne[1]; + const uint32_t n_tokens = (uint32_t)src_v->ne[2]; + const uint32_t n_seqs = (uint32_t)src_v->ne[3]; + + const uint32_t s_off = S_v * H * n_tokens * n_seqs; + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); + GGML_ASSERT(pipeline != nullptr); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + vk_subbuffer src_buf[6] = {}; + for (int i = 0; i < 6; i++) { + src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]); + } + + const uint32_t sq1 = (uint32_t)(src_q->nb[1] / sizeof(float)); + const uint32_t sq2 = (uint32_t)(src_q->nb[2] / sizeof(float)); + const uint32_t sq3 = (uint32_t)(src_q->nb[3] / sizeof(float)); + const uint32_t sv1 = (uint32_t)(src_v->nb[1] / sizeof(float)); + const uint32_t sv2 = (uint32_t)(src_v->nb[2] / sizeof(float)); + const uint32_t sv3 = (uint32_t)(src_v->nb[3] / sizeof(float)); + const uint32_t sb1 = (uint32_t)(src_beta->nb[1] / sizeof(float)); + const uint32_t sb2 = (uint32_t)(src_beta->nb[2] / sizeof(float)); + const uint32_t sb3 = (uint32_t)(src_beta->nb[3] / sizeof(float)); + + const uint32_t neq1 = (uint32_t)src_q->ne[1]; + const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]); + + const float scale = 1.0f / sqrtf((float)S_v); + const vk_op_gated_delta_net_push_constants pc = { + H, n_tokens, n_seqs, s_off, + sq1, sq2, sq3, + sv1, sv2, sv3, + sb1, sb2, sb3, + neq1, rq3, + scale + }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, + pc, { H, n_seqs, 1u }); +} + static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -13044,6 +13142,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; + case GGML_OP_GATED_DELTA_NET: + ggml_vk_gated_delta_net(ctx, compute_ctx, node); + + break; + case GGML_OP_SSM_SCAN: ggml_vk_ssm_scan(ctx, compute_ctx, node); @@ -15455,6 +15558,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: return true; // all inputs are contiguous, see ggml.c + case GGML_OP_GATED_DELTA_NET: + { + const uint32_t S_v = op->src[2]->ne[0]; + if (S_v != 32 && S_v != 64 && S_v != 128) { + return false; + } + for (int i = 0; i < 6; i++) { + if (op->src[i] == nullptr || op->src[i]->type != GGML_TYPE_F32) { + return false; + } + } + return op->type == GGML_TYPE_F32; + } case GGML_OP_SSM_SCAN: { for (int i = 0; i < 6; i++) { @@ -16332,6 +16448,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } else if (tensor->op == GGML_OP_RWKV_WKV7) { tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], src_clone[4], src_clone[5], src_clone[6]); + } else if (tensor->op == GGML_OP_GATED_DELTA_NET) { + tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4], src_clone[5]); } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { src_clone[0]->flags = tensor->src[0]->flags; tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp new file mode 100644 index 00000000000..1fdf889e824 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -0,0 +1,128 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +layout(constant_id = 0) const uint S_V = 128; +layout(constant_id = 1) const uint KDA = 0; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint H; + uint n_tokens; + uint n_seqs; + uint s_off; + uint sq1, sq2, sq3; + uint sv1, sv2, sv3; + uint sb1, sb2, sb3; + uint neq1, rq3; + float scale; +}; + +layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; }; +layout(binding = 1) readonly buffer KBuf { FLOAT_TYPE data_k[]; }; +layout(binding = 2) readonly buffer VBuf { FLOAT_TYPE data_v[]; }; +layout(binding = 3) readonly buffer GBuf { FLOAT_TYPE data_g[]; }; +layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; }; +layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; }; +layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; }; + +shared FLOAT_TYPE s_k[S_V]; +shared FLOAT_TYPE s_q[S_V]; +shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i]) + +void main() { + const uint head_id = gl_WorkGroupID.x; + const uint seq_id = gl_WorkGroupID.y; + const uint col = gl_LocalInvocationID.x; + + const uint iq1 = head_id % neq1; + const uint iq3 = seq_id / rq3; + + const uint state_size = S_V * S_V; + const uint state_base = (seq_id * H + head_id) * state_size; + + FLOAT_TYPE state[S_V]; + [[unroll]] for (uint i = 0; i < S_V; i++) { + state[i] = FLOAT_TYPE(data_state[state_base + i * S_V + col]); + } + + uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; + + for (uint t = 0; t < n_tokens; t++) { + const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1; + const uint k_off = q_off; + const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1; + + s_q[col] = FLOAT_TYPE(data_q[q_off + col]); + s_k[col] = FLOAT_TYPE(data_k[k_off + col]); + + const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1; + + if (KDA != 0) { + const uint g_base = gb_off * S_V; + s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col])); + } + + barrier(); + + const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]); + const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]); + + if (KDA == 0) { + const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off])); + + FLOAT_TYPE kv_col = 0.0; + [[unroll]] for (uint i = 0; i < S_V; i += 4) { + kv_col += dot( + vec4(state[i], state[i+1], state[i+2], state[i+3]), + vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]) + ); + } + + FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val; + + FLOAT_TYPE attn_col = 0.0; + [[unroll]] for (uint i = 0; i < S_V; i += 4) { + vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); + vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); + sv = g_val * sv + kv * delta_col; + state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w; + + attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3])); + } + + data_dst[attn_off + col] = attn_col * scale; + } else { + FLOAT_TYPE kv_col = 0.0; + [[unroll]] for (uint i = 0; i < S_V; i += 4) { + vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]); + vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); + vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); + kv_col += dot(gv * sv, kv); + } + + FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val; + + FLOAT_TYPE attn_col = 0.0; + [[unroll]] for (uint i = 0; i < S_V; i += 4) { + vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]); + vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); + vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); + sv = gv * sv + kv * delta_col; + state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w; + + attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3])); + } + + data_dst[attn_off + col] = attn_col * scale; + } + + attn_off += S_V * H; + barrier(); + } + + [[unroll]] for (uint i = 0; i < S_V; i++) { + data_dst[s_off + state_base + i * S_V + col] = state[i]; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index fb8941232bc..4b00ba3debb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -987,6 +987,8 @@ void process_shaders() { string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}})); + string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); From 2ed6dc0222cd99ba33004ed400c18018c6f80397 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 12 Mar 2026 21:04:13 +0200 Subject: [PATCH 265/831] llama : disable graph reuse with pipeline parallelism (llama/20463) --- ggml/src/ggml-backend.cpp | 14 +++++--------- ggml/src/ggml-cuda/ggml-cuda.cu | 14 ++++---------- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index bc57df20ba2..22c656996cc 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1455,10 +1455,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s int split_backend_id = split->backend_id; ggml_backend_t split_backend = sched->backends[split_backend_id]; - if (sched->events[split_backend_id][sched->cur_copy] == NULL) { - ggml_backend_synchronize(split_backend); - } - // copy the input tensors to the split backend for (int input_id = 0; input_id < split->n_inputs; input_id++) { ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]); @@ -1469,12 +1465,16 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); + } else { + ggml_backend_synchronize(split_backend); } - ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy); + ggml_backend_tensor_copy(input, input_cpy); } else { // wait for the split backend to finish using the input before overwriting it if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]); + } else { + ggml_backend_synchronize(split_backend); } // when offloading MoE weights, we can reduce the amount of data copied by copying only the experts that are used @@ -1578,10 +1578,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } } - if (sched->events[split_backend_id][sched->cur_copy] == NULL) { - ggml_backend_synchronize(split_backend); - } - if (!sched->callback_eval) { enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph); if (ec != GGML_STATUS_SUCCESS) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index cda275b8c58..9d2aacf4b2c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2823,14 +2823,11 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer; ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; - //enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA - bool copy_from_host = ggml_backend_buffer_is_host(buf_src) && ggml_backend_dev_type(backend_src->device) == GGML_BACKEND_DEVICE_TYPE_CPU; - - if (!(copy_from_host || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) { + if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) { return false; } - if (!(copy_from_host || ggml_backend_buffer_is_cuda(buf_src)) || !ggml_backend_buffer_is_cuda(dst->buffer)) { + if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) { return false; } @@ -2841,17 +2838,14 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; - if ((copy_from_host && cuda_ctx_dst->device != buf_ctx_dst->device) || - !copy_from_host && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) { + if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); #endif return false; } - if (copy_from_host) { - CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream())); - } else if (backend_src != backend_dst) { + if (backend_src != backend_dst) { // copy on src stream if (cuda_ctx_src->device == cuda_ctx_dst->device) { CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream())); From f1f5f43d6979036346ae64d1e8a518ff36480ace Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 13 Mar 2026 11:43:20 +0200 Subject: [PATCH 266/831] metal : fix l2 norm scale (llama/20493) --- ggml/src/ggml-metal/ggml-metal-device.m | 2 +- ggml/src/ggml-metal/ggml-metal.metal | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 05b826a61b8..b7d587f3bd9 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1156,7 +1156,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_RWKV_WKV7: return true; case GGML_OP_GATED_DELTA_NET: - return op->src[2]->ne[0] % 32 == 0; + return has_simdgroup_reduction && op->src[2]->ne[0] % 32 == 0; case GGML_OP_SOLVE_TRI: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 24a3092af22..107e7cf2ff3 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3006,7 +3006,7 @@ kernel void kernel_l2_norm_impl( sumf = shmem_f32[tiisg]; sumf = simd_sum(sumf); - const float scale = 1.0f/sqrt(max(sumf, args.eps)); + const float scale = 1.0f/max(sqrt(sumf), args.eps); for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { y[i00] = x[i00] * scale; From 9bfa81d262b67a61b67224dbc578aaca06e01078 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Fri, 13 Mar 2026 14:36:13 +0100 Subject: [PATCH 267/831] ggml : fix typo gmml (llama/20512) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- ggml/CMakeLists.txt | 2 +- ggml/src/ggml-cpu/ops.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 4323afe57b5..8f679e2fd35 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -253,7 +253,7 @@ option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increas option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON) option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON) set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING - "gmml: OpenCL API version to target") + "ggml: OpenCL API version to target") option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF) set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)") diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index fa9d27046b5..85db02d92f1 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9624,7 +9624,7 @@ void ggml_compute_forward_win_unpart( } } -//gmml_compute_forward_unary +//ggml_compute_forward_unary void ggml_compute_forward_unary( const ggml_compute_params * params, From 5905e8708f6ef43ceb8dc99197524c4deadeb5b0 Mon Sep 17 00:00:00 2001 From: rehan-10xengineer Date: Fri, 13 Mar 2026 20:36:04 +0500 Subject: [PATCH 268/831] ggml-cpu: add RVV vec dot kernels for quantization types (llama/18859) * ggml-cpu: add rvv quantize_row_q8_K kernel Co-authored-by: Rehan Qasim * ggml-cpu: add rvv vec_dot for iq4_nl, mxfp4, iq2_xxs Co-authored-by: Rehan Qasim * ggml-cpu: add rvv vec_dot for iq4_xs, refactor * ggml-cpu: remove ifunc for rvv vec dot * ggml-cpu: add vec_dot for iq2_xs, iq3_xxs Co-authored-by: Rehan Qasim * ggml-cpu: refactor quants.c --------- Co-authored-by: taimur-10x Co-authored-by: Rehan Qasim Co-authored-by: Rehan Qasim --- ggml/src/ggml-cpu/arch-fallback.h | 7 - ggml/src/ggml-cpu/arch/riscv/quants.c | 1555 +++++++++++++++++++------ 2 files changed, 1219 insertions(+), 343 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 175aa4a4bb9..41da829315b 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -199,13 +199,6 @@ #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__riscv) // quants.c -#define quantize_row_q8_K_generic quantize_row_q8_K -#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K -#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K -#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K -#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 -#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K -#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index bf9f4df1182..826055dd9a4 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -113,6 +113,104 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i #endif } +void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + block_q8_K * y_blocks = (block_q8_K *)y; + size_t nb = k / QK_K; + +#if defined(__riscv_v_intrinsic) + const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8(); + + for (size_t i = 0; i < nb; i++) { + const float* x_block = x + i * QK_K; + block_q8_K* y_block = &y_blocks[i]; + + // 1. Calculate Min/Max + vfloat32m8_t max_v = __riscv_vfmv_v_f_f32m8(-__builtin_inff(), vlmax_f32m8); + vfloat32m8_t min_v = __riscv_vfmv_v_f_f32m8(__builtin_inff(), vlmax_f32m8); + + size_t rem = QK_K; + size_t offset = 0; + while (rem > 0) { + size_t vl = __riscv_vsetvl_e32m8(rem); + vfloat32m8_t v_curr = __riscv_vle32_v_f32m8(x_block + offset, vl); + max_v = __riscv_vfmax_vv_f32m8(max_v, v_curr, vl); + min_v = __riscv_vfmin_vv_f32m8(min_v, v_curr, vl); + rem -= vl; + offset += vl; + } + + vfloat32m1_t v_init_max = __riscv_vfmv_s_f_f32m1(-__builtin_inff(), 1); + vfloat32m1_t v_init_min = __riscv_vfmv_s_f_f32m1(__builtin_inff(), 1); + + vfloat32m1_t v_scalar_max = __riscv_vfredmax_vs_f32m8_f32m1(max_v, v_init_max, vlmax_f32m8); + vfloat32m1_t v_scalar_min = __riscv_vfredmin_vs_f32m8_f32m1(min_v, v_init_min, vlmax_f32m8); + + float max_val = __riscv_vfmv_f_s_f32m1_f32(v_scalar_max); + float min_val = __riscv_vfmv_f_s_f32m1_f32(v_scalar_min); + + float amax = fabsf(max_val) > fabsf(min_val) ? fabsf(max_val) : fabsf(min_val); + + if (amax == 0.0f) { + y_block->d = 0.0f; + memset(y_block->qs, 0, QK_K); + memset(y_block->bsums, 0, sizeof(y_block->bsums)); + continue; + } + + const float iscale = -127.f / (fabsf(max_val) > fabsf(min_val) ? max_val : min_val); + y_block->d = 1.0f / iscale; + + // 2. Quantize and Calculate Sums + offset = 0; + rem = QK_K; + vint16m1_t v_zero_sum = __riscv_vmv_v_x_i16m1(0, 1); + + while (rem > 0) { + size_t vl = __riscv_vsetvl_e32m8(rem); + vfloat32m8_t v_f = __riscv_vle32_v_f32m8(x_block + offset, vl); + + v_f = __riscv_vfmul_vf_f32m8(v_f, iscale, vl); + + vint32m8_t v_i32 = __riscv_vfcvt_x_f_v_i32m8_rm(v_f, __RISCV_FRM_RNE, vl); + vint16m4_t v_i16 = __riscv_vnclip_wx_i16m4(v_i32, 0, __RISCV_VXRM_RNE, vl); + vint8m2_t v_q = __riscv_vnclip_wx_i8m2(v_i16, 0, __RISCV_VXRM_RNE, vl); + + __riscv_vse8_v_i8m2(y_block->qs + offset, v_q, vl); + + // first iteration clear + + int sum_idx; + vint8m1_t chunk_m1; + vint16m1_t v_sum; + sum_idx = offset / 16; + chunk_m1 = __riscv_vget_v_i8m2_i8m1(v_q, 0); + v_sum = __riscv_vwredsum_vs_i8m1_i16m1(chunk_m1, v_zero_sum, 16); + y_block->bsums[sum_idx] = (int16_t)__riscv_vmv_x_s_i16m1_i16(v_sum); + + // remaining iterations + vint8m2_t slid_q = v_q; + for (size_t k = 16; k < vl; k += 16) { + slid_q = __riscv_vslidedown_vx_i8m2(slid_q, 16, vl); + + sum_idx = (offset + k) / 16; + chunk_m1 = __riscv_vget_v_i8m2_i8m1(slid_q, 0); + + v_sum = __riscv_vwredsum_vs_i8m1_i16m1(chunk_m1, v_zero_sum, 16); + y_block->bsums[sum_idx] =(int16_t)__riscv_vmv_x_s_i16m1_i16(v_sum); + } + + rem -= vl; + offset += vl; + } + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_K_ref(x, y, k); +#endif +} + //===================================== Dot products ================================= void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { @@ -1954,151 +2052,283 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -static const uint8_t sign_gather_indices_arr[64] = { - 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3, - 4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7 -}; - -static const uint8_t sign_bit_masks_arr[64] = { - 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, - 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128 -}; - -static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); - UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_iq2_s * GGML_RESTRICT x = vx; + const block_iq1_s * GGML_RESTRICT x = vx; const block_q8_K * GGML_RESTRICT y = vy; const int nb = n / QK_K; - const uint64_t * grid64 = (const uint64_t *)iq2s_grid; - - // --- Pre-load Constants --- - uint16_t gather_qh_arr[8] = {0, 0, 0, 0, 1, 1, 1, 1}; - vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 8); - uint16_t shift_qh_arr[8] = {11, 9, 7, 5, 11, 9, 7, 5}; - vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 8); - - // Constants for sign extraction - vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64); - vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64); - - float sumf = 0.0f; + float sumf = 0; for (int i = 0; i < nb; ++i) { - const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * GGML_RESTRICT qs = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint8_t * GGML_RESTRICT scales = x[i].scales; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - const uint8_t * signs_ptr = qs + 32; - - float sum_block = 0.0f; - - for (int ib = 0; ib < 4; ++ib) { - // Combine low + high bits - vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 8); - qs += 8; - uint16_t qh_val; - memcpy(&qh_val, qh, 2); - qh += 2; - vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8((const uint8_t*)&qh_val, 2); - vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 2); - vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); - vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 8); - v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 8); - - // Mask: We want bits 11-12. 0x1800 = 0001 1000 0000 0000 - v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 8); - vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 8); - - // Multiply by 8 to get byte offset, instead of element offset - v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 8); - vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 8); + // Load qh once for the entire superblock. + vuint16mf2_t qh = __riscv_vle16_v_u16mf2(x[i].qh, 8); - // Lookup Grid using Byte Offsets - vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 8); + // Calculate ls. + vuint16mf2_t temp = __riscv_vsrl_vx_u16mf2(qh, 12, 8); + temp = __riscv_vand_vx_u16mf2(temp, 7, 8); + vint32m1_t ls = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vwmulu_vx_u32m1(temp, 2, 8)); + ls = __riscv_vadd_vx_i32m1(ls, 1, 8); - vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); - vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + // Calculate delta. + vbool32_t mask = __riscv_vmseq_vx_u16mf2_b32(__riscv_vand_vx_u16mf2(qh, 0x8000, 8), 0, 8); + vint32m1_t delta_neg = __riscv_vmv_v_x_i32m1(-1, 8); + vint32m1_t delta_pos = __riscv_vmv_v_x_i32m1(1, 8); + vint32m1_t delta = __riscv_vmerge_vvm_i32m1(delta_neg, delta_pos, mask, 8); - // Load signs and generate sign mask - vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 8); - signs_ptr += 8; + // Load qs. + vuint8m1_t qs = __riscv_vle8_v_u8m1(x[i].qs, 32); - vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); - vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64); + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m2_t qh_shift = __riscv_vreinterpret_v_u64m2_u16m2(__riscv_vmv_v_x_u64m2(shift, 8)); + vuint16m2_t qh_gather_index = __riscv_vreinterpret_v_i16m2_u16m2( + __riscv_vdiv_vx_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vid_v_u16m2(32)), 4, 32)); + vuint16m2_t qh_ext = __riscv_vlmul_ext_v_u16m1_u16m2(__riscv_vlmul_ext_v_u16mf2_u16m1(qh)); + vuint16m2_t qh_index = __riscv_vrgather_vv_u16m2(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m2(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m2(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m2(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m2(qh_index, __riscv_vzext_vf2_u16m2(qs, 32), 32); + vuint16m2_t index = __riscv_vsll_vx_u16m2(qh_index, 3, 32); - vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64); - vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64); + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); - vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); - q8 += 64; + // Sub-blocks 1-4 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m2_u16m1(index, 0); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 16)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 128); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 1), one_scalar, 32)); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 2), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 3), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 5-8 + { + vuint16m1_t grid_index1 = __riscv_vget_v_u16m2_u16m1(index, 1); + vint8m4_t grid1 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index1, 16)); + vint8m4_t q81 = __riscv_vle8_v_i8m4(&y[i].qs[128], 128); + vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(grid1, q81, 128); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 0), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 1), one_scalar, 32)); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 2), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 3), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + vint32m1_t lsums = __riscv_vle32_v_i32m1(&lsums_s[0], 8); - vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64); - vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 64); + // Calculate the bsums. + vint16m1_t bsums_0 = __riscv_vle16_v_i16m1(y[i].bsums, 16); + const vuint32m1_t bsums_i32 = __riscv_vreinterpret_v_u16m1_u32m1(__riscv_vreinterpret_v_i16m1_u16m1(bsums_0)); + const vint16mf2_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 0, 8)); + const vint16mf2_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 16, 8)); + const vint32m1_t bsums = __riscv_vwadd_vv_i32m1(bsums_i32_0, bsums_i32_1, 8); - vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + // Accumulation. + vint32m1_t sumi_v = __riscv_vmul_vv_i32m1(ls, lsums, 8); + vint32m1_t sumi1_v = __riscv_vmul_vv_i32m1(__riscv_vmul_vv_i32m1(ls, delta, 8), bsums, 8); - int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(v_dot, 0), v_zero, 16)); - int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(v_dot, 1), v_zero, 16)); - int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(v_dot, 2), v_zero, 16)); - int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(v_dot, 3), v_zero, 16)); + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } - uint8_t sc0 = scales[0]; - uint8_t sc1 = scales[1]; - scales += 2; + *s = sumf; +} - sum_block += s0 * (2 * (sc0 & 0xF) + 1); - sum_block += s1 * (2 * (sc0 >> 4) + 1); - sum_block += s2 * (2 * (sc1 & 0xF) + 1); - sum_block += s3 * (2 * (sc1 >> 4) + 1); - } - sumf += sum_block * combined_scale; +void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; } - *s = 0.125f * sumf; +#else + ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif } -static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); - UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const block_iq2_s * GGML_RESTRICT x = vx; + const block_iq1_m * GGML_RESTRICT x = vx; const block_q8_K * GGML_RESTRICT y = vy; const int nb = n / QK_K; - const uint64_t * grid64 = (const uint64_t *)iq2s_grid; - - // Pre-load Constants - vuint8m2_t v_ids = __riscv_vid_v_u8m2(32); - vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 32); - vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 32); - vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 32); - vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 32); - uint16_t shift_qh_arr[4] = {11, 9, 7, 5}; - vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 4); + iq1m_scale_t scale; float sumf = 0.0f; - for (int i = 0; i < nb; ++i) { - const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; - const uint8_t * GGML_RESTRICT qs = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const uint8_t * GGML_RESTRICT scales = x[i].scales; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - const uint8_t * signs_ptr = qs + 32; - float sum_block = 0.0f; + // Accumulators. + vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 16); - for (int ib = 0; ib < 8; ++ib) { + // We process 4 sub-blocks together. + for (int ib = 0; ib < QK_K/128; ib++) { + // Load qh for 4 sub-blocks. + const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 8); + const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 8); + const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 8); + const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1( + __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 8)), 16); + qh += 8; + + // Prepare grid indices. + const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 16), 16); + const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 8)); + vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 16), 0x700, 16), 16); + index = __riscv_vsll_vx_u16m1(index, 3, 16); + qs += 16; + + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, index, 16))); + + // Prepare the deltas. + const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16( + __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 8)), 16), 0, 16); + const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 16); + const vint64m4_t delta_neg = __riscv_vmv_v_x_i64m4(0xffffffffffffffff, 16); + const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4( + __riscv_vmerge_vvm_i64m4(delta_pos, delta_neg, mask, 16)); + + // Load q8 for sub-blocks. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); + q8 += 128; + + // Calculate the lsums. + const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 128); + const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 128); + + // Prepare the scales. + const int16_t ls_0_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_0_1 = 2*((sc[0] >> 3) & 0x7) + 1; + const int16_t ls_1_0 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_1_1 = 2*((sc[0] >> 9) & 0x7) + 1; + const int16_t ls_2_0 = 2*((sc[1] >> 0) & 0x7) + 1; + const int16_t ls_2_1 = 2*((sc[1] >> 3) & 0x7) + 1; + const int16_t ls_3_0 = 2*((sc[1] >> 6) & 0x7) + 1; + const int16_t ls_3_1 = 2*((sc[1] >> 9) & 0x7) + 1; + sc += 2; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16); + // + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16); + // + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16); + // + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16); + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 16)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 16)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + +void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static const uint8_t sign_gather_indices_arr[64] = { + 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3, + 4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7 +}; + +static const uint8_t sign_bit_masks_arr[64] = { + 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, + 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128 +}; + + +static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + // Pre-load Constants + vuint8m2_t v_ids = __riscv_vid_v_u8m2(32); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 32); + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 32); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 32); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 32); + uint16_t shift_qh_arr[4] = {11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 4); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + float sum_block = 0.0f; + + for (int ib = 0; ib < 8; ++ib) { // Load Low Bits [4 bytes] vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 4); @@ -2157,6 +2387,108 @@ static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t *s = 0.125f * sumf; } +static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + // --- Pre-load Constants --- + uint16_t gather_qh_arr[8] = {0, 0, 0, 0, 1, 1, 1, 1}; + vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 8); + uint16_t shift_qh_arr[8] = {11, 9, 7, 5, 11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 8); + + // Constants for sign extraction + vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64); + vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + for (int ib = 0; ib < 4; ++ib) { + // Combine low + high bits + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 8); + qs += 8; + uint16_t qh_val; + memcpy(&qh_val, qh, 2); + qh += 2; + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8((const uint8_t*)&qh_val, 2); + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 2); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 8); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 8); + + // Mask: We want bits 11-12. 0x1800 = 0001 1000 0000 0000 + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 8); + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 8); + + // Multiply by 8 to get byte offset, instead of element offset + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 8); + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 8); + + // Lookup Grid using Byte Offsets + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 8); + + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + // Load signs and generate sign mask + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 8); + signs_ptr += 8; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64); + + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); + q8 += 64; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 64); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 0), v_zero, 16)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 1), v_zero, 16)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 2), v_zero, 16)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 3), v_zero, 16)); + + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + scales += 2; + + sum_block += s0 * (2 * (sc0 & 0xF) + 1); + sum_block += s1 * (2 * (sc0 >> 4) + 1); + sum_block += s2 * (2 * (sc1 & 0xF) + 1); + sum_block += s3 * (2 * (sc1 >> 4) + 1); + } + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { @@ -2175,6 +2507,333 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +#if defined(__riscv_v) +static const int8_t keven_signs_q2xs[1024] = { + 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, + 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, + 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, + 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, + 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, + 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, + 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, + 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, + 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, + 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, + 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, + 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, + 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, + 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, + 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, + 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, + 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, + 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, + 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, + 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, + 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, + 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, + 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, + 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +}; +#endif + +static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xs_grid; + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT qs = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + + int32_t sum_int = 0; + + // Loop over 4 subblocks of 64 elements (QK_K = 256) + for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) { + // Load 8 uint16 indices (controls 64 values) + vuint16mf2_t v_qs = __riscv_vle16_v_u16mf2(qs, 8); + qs += 8; + + // Extract indices for grid (low 9 bits) and signs (high 7 bits) + // Multiply by 8 (<< 3) for byte offsets into the uint64 tables + vuint16mf2_t vidx_grid = __riscv_vsll_vx_u16mf2(__riscv_vand_vx_u16mf2(v_qs, 511, 8), 3, 8); + vuint16mf2_t vidx_sign = __riscv_vsll_vx_u16mf2(__riscv_vsrl_vx_u16mf2(v_qs, 9, 8), 3, 8); + + vuint64m2_t vq2_64 = __riscv_vluxei16_v_u64m2(grid64, vidx_grid, 8); + vuint64m2_t vs2_64 = __riscv_vluxei16_v_u64m2(signs64, vidx_sign, 8); + + vint8m2_t q2u = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64)); + vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64)); + + vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64); + + vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64); + q8 += 64; + + vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64); + + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 16)); + int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 16)); + int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 16)); + int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16)); + + const uint8_t scale_byte_1 = scales[0]; + const uint8_t scale_byte_2 = scales[1]; + scales += 2; + + sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1); + sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1); + sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1); + sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1); + } + + sumf += d * sum_int; + } + *s = 0.125f * sumf; +} + +void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xxs_grid; + + uint32_t shift_constants[4] = {0, 7, 14, 21}; + vuint32m1_t v_shifts = __riscv_vle32_v_u32m1(shift_constants, 4); + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum = 0.0f; + + #pragma GCC unroll 1 + for (int ib32 = 0; ib32 < QK_K / 32; ib32 += 2) { + vint8m2_t q8_1 = __riscv_vle8_v_i8m2(q8, 32); q8 += 32; + vint8m2_t q8_2 = __riscv_vle8_v_i8m2(q8, 32); q8 += 32; + + vuint8mf4_t v_raw_q2_1 = __riscv_vle8_v_u8mf4(q2_ptr, 4); + vuint8mf4_t v_raw_q2_2 = __riscv_vle8_v_u8mf4(q2_ptr + 8, 4); + + vuint16mf2_t vidx_q2_1 = __riscv_vwcvtu_x_x_v_u16mf2(v_raw_q2_1, 4); + vuint16mf2_t vidx_q2_2 = __riscv_vwcvtu_x_x_v_u16mf2(v_raw_q2_2, 4); + + vidx_q2_1 = __riscv_vsll_vx_u16mf2(vidx_q2_1, 3, 4); + vidx_q2_2 = __riscv_vsll_vx_u16mf2(vidx_q2_2, 3, 4); + + uint32_t s_packed_1, s_packed_2; + memcpy(&s_packed_1, q2_ptr + 4, 4); + memcpy(&s_packed_2, q2_ptr + 12, 4); + + vuint32m1_t v_s_1 = __riscv_vmv_v_x_u32m1(s_packed_1, 4); + vuint32m1_t v_s_2 = __riscv_vmv_v_x_u32m1(s_packed_2, 4); + v_s_1 = __riscv_vsrl_vv_u32m1(v_s_1, v_shifts, 4); + v_s_2 = __riscv_vsrl_vv_u32m1(v_s_2, v_shifts, 4); + + v_s_1 = __riscv_vand_vx_u32m1(v_s_1, 127, 4); + v_s_2 = __riscv_vand_vx_u32m1(v_s_2, 127, 4); + + vuint16mf2_t vidx_s2_1 = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_1, 4), 3, 4); + vuint16mf2_t vidx_s2_2 = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_2, 4), 3, 4); + + vuint64m2_t vq2_64_1 = __riscv_vluxei16_v_u64m2(grid64, vidx_q2_1, 4); + vuint64m2_t vq2_64_2 = __riscv_vluxei16_v_u64m2(grid64, vidx_q2_2, 4); + + vint8m2_t q2_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64_1)); + vint8m2_t q2_2 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64_2)); + + vuint64m2_t vs2_64_1 = __riscv_vluxei16_v_u64m2(signs64, vidx_s2_1, 4); + vuint64m2_t vs2_64_2 = __riscv_vluxei16_v_u64m2(signs64, vidx_s2_2, 4); + vint8m2_t s2_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64_1)); + vint8m2_t s2_2 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64_2)); + + vint8m2_t q8s_1 = __riscv_vmul_vv_i8m2(q8_1, s2_1, 32); + vint8m2_t q8s_2 = __riscv_vmul_vv_i8m2(q8_2, s2_2, 32); + + vint16m4_t dot1 = __riscv_vwmul_vv_i16m4(q8s_1, q2_1, 32); + vint16m4_t dot2 = __riscv_vwmul_vv_i16m4(q8s_2, q2_2, 32); + + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t sumv1 = __riscv_vwredsum_vs_i16m4_i32m1(dot1, zero_vec, 32); + vint32m1_t sumv2 = __riscv_vwredsum_vs_i16m4_i32m1(dot2, zero_vec, 32); + + int32_t scalar_sum1 = __riscv_vmv_x_s_i32m1_i32(sumv1); + int32_t scalar_sum2 = __riscv_vmv_x_s_i32m1_i32(sumv2); + + int16_t scale1 = 2 * ((s_packed_1 >> 28) & 0xF) + 1; + int16_t scale2 = 2 * ((s_packed_2 >> 28) & 0xF) + 1; + + sum += scalar_sum1 * scale1 + scalar_sum2 * scale2; + q2_ptr += 16; + } + sumf += sum * combined_scale; + } + *s = 0.125f * sumf; +} + +static void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xxs_grid; + + uint32_t shift_constants[4] = {0, 7, 14, 21}; + vuint32mf2_t v_shifts = __riscv_vle32_v_u32mf2(shift_constants, 4); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum = 0.0f; + + for (int ib32 = 0; ib32 < QK_K / 32; ib32 += 2) { + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8, 32); q8 += 32; + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8, 32); q8 += 32; + + vuint8mf8_t v_raw_q2_1 = __riscv_vle8_v_u8mf8(q2_ptr, 4); + vuint8mf8_t v_raw_q2_2 = __riscv_vle8_v_u8mf8(q2_ptr + 8, 4); + + vuint16mf4_t vidx_q2_1 = __riscv_vwcvtu_x_x_v_u16mf4(v_raw_q2_1, 4); + vuint16mf4_t vidx_q2_2 = __riscv_vwcvtu_x_x_v_u16mf4(v_raw_q2_2, 4); + + vidx_q2_1 = __riscv_vsll_vx_u16mf4(vidx_q2_1, 3, 4); + vidx_q2_2 = __riscv_vsll_vx_u16mf4(vidx_q2_2, 3, 4); + + uint32_t s_packed_1, s_packed_2; + memcpy(&s_packed_1, q2_ptr + 4, 4); + memcpy(&s_packed_2, q2_ptr + 12, 4); + + vuint32mf2_t v_s_1 = __riscv_vmv_v_x_u32mf2(s_packed_1, 4); + vuint32mf2_t v_s_2 = __riscv_vmv_v_x_u32mf2(s_packed_2, 4); + + v_s_1 = __riscv_vsrl_vv_u32mf2(v_s_1, v_shifts, 4); + v_s_2 = __riscv_vsrl_vv_u32mf2(v_s_2, v_shifts, 4); + + v_s_1 = __riscv_vand_vx_u32mf2(v_s_1, 127, 4); + v_s_2 = __riscv_vand_vx_u32mf2(v_s_2, 127, 4); + + // Narrow u32 -> u16 (vncvt) and Scale by 8 to get byte offsets + vuint16mf4_t vidx_s2_1 = __riscv_vsll_vx_u16mf4(__riscv_vncvt_x_x_w_u16mf4(v_s_1, 4), 3, 4); + vuint16mf4_t vidx_s2_2 = __riscv_vsll_vx_u16mf4(__riscv_vncvt_x_x_w_u16mf4(v_s_2, 4), 3, 4); + + // Load q2 values from lookup grid + vuint64m1_t vq2_64_1 = __riscv_vluxei16_v_u64m1(grid64, vidx_q2_1, 4); + vuint64m1_t vq2_64_2 = __riscv_vluxei16_v_u64m1(grid64, vidx_q2_2, 4); + vint8m1_t q2_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vq2_64_1)); + vint8m1_t q2_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vq2_64_2)); + + // Load sign values + vuint64m1_t vs2_64_1 = __riscv_vluxei16_v_u64m1(signs64, vidx_s2_1, 4); + vuint64m1_t vs2_64_2 = __riscv_vluxei16_v_u64m1(signs64, vidx_s2_2, 4); + vint8m1_t s2_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vs2_64_1)); + vint8m1_t s2_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vs2_64_2)); + + // Apply signs to q8 + vint8m1_t q8s_1 = __riscv_vmul_vv_i8m1(q8_1, s2_1, 32); + vint8m1_t q8s_2 = __riscv_vmul_vv_i8m1(q8_2, s2_2, 32); + + // multiplying q2 with q8 + vint16m2_t dot1 = __riscv_vwmul_vv_i16m2(q8s_1, q2_1, 32); + vint16m2_t dot2 = __riscv_vwmul_vv_i16m2(q8s_2, q2_2, 32); + + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t sumv1 = __riscv_vwredsum_vs_i16m2_i32m1(dot1, zero_vec, 32); + vint32m1_t sumv2 = __riscv_vwredsum_vs_i16m2_i32m1(dot2, zero_vec, 32); + int32_t scalar_sum1 = __riscv_vmv_x_s_i32m1_i32(sumv1); + int32_t scalar_sum2 = __riscv_vmv_x_s_i32m1_i32(sumv2); + int16_t scale1 = 2 * ((s_packed_1 >> 28) & 0xF) + 1; + int16_t scale2 = 2 * ((s_packed_2 >> 28) & 0xF) + 1; + + sum += scalar_sum1 * scale1 + scalar_sum2 * scale2; + q2_ptr += 16; + } + sumf += sum * combined_scale; + } + *s = 0.125f * sumf; +} + +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq2_xxs_q8_K(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); @@ -2231,57 +2890,389 @@ static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t v_qh_val = __riscv_vsll_vx_u16m1(v_qh_val, 10, 16); vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_val, 16); - // Grid value is 4xuint8 - vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2((const uint32_t *)grid64, v_grid_offsets, 16); - vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); - vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 8); - signs += 8; + // Grid value is 4xuint8 + vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2((const uint32_t *)grid64, v_grid_offsets, 16); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 8); + signs += 8; + + // Generate sign mask + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); + q8 += 64; + + // Apply Signs + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64); + vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 64); + + // Reduction + vint16m2_t v_dot_lo = __riscv_vget_v_i16m4_i16m2(v_dot, 0); + vint16m2_t v_dot_hi = __riscv_vget_v_i16m4_i16m2(v_dot, 1); + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t s_lo = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_lo, v_zero, 32)); + int32_t s_hi = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_hi, v_zero, 32)); + + // Apply sub-scales + uint8_t sc_byte = *scales++; + int sc_lo = (sc_byte & 0xF) * 2 + 1; + int sc_hi = (sc_byte >> 4) * 2 + 1; + + sum_block += s_lo * sc_lo + s_hi * sc_hi; + } + sumf += sum_block * combined_scale; + } + *s = sumf; +} + +void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + // constants for unpacking logic + const uint32_t shifts_val[8] = {0, 7, 14, 21, 0, 7, 14, 21}; + vuint32m1_t v_shifts = __riscv_vle32_v_u32m1(shifts_val, 8); + + const uint32_t gather_idx_val[8] = {0, 0, 0, 0, 1, 1, 1, 1}; + vuint32m1_t v_gather_idx = __riscv_vle32_v_u32m1(gather_idx_val, 8); + + uint32_t aux32[2]; + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float block_sum = 0.0f; + + for (int ib = 0; ib < QK_K / 64; ++ib) { + // Load q8 (64 bytes) + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); + q8 += 64; + + // load of metadata via memcpy + memcpy(aux32, metadata, 2 * sizeof(uint32_t)); + metadata += 2 * sizeof(uint32_t); + + // Load q3 indices and gather magnitudes + vuint8mf2_t v_q3_idx_u8 = __riscv_vle8_v_u8mf2(q3_indices, 16); + q3_indices += 16; + + vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_u8, 4, 16); + vuint32m2_t v_q3_magnitudes_u32 = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 16); + vint8m2_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u32m2_u8m2(v_q3_magnitudes_u32)); + + // --- Unpacking of Sign Indices --- + + // 1. Load the 2 auxiliary 32-bit integers into a vector + vuint32m1_t v_aux = __riscv_vle32_v_u32m1(aux32, 2); + + // 2. Broadcast/Gather: replicate aux[0] to first 4 lanes, aux[1] to next 4 lanes + vuint32m1_t v_aux_expanded = __riscv_vrgather_vv_u32m1(v_aux, v_gather_idx, 8); + + // 3. Apply Shifts and Mask: ((val >> shift) & 127) + vuint32m1_t v_s_vals_raw = __riscv_vand_vx_u32m1(__riscv_vsrl_vv_u32m1(v_aux_expanded, v_shifts, 8), 127, 8); + + // 4. Narrow to u16 (required for vluxei index) and multiply by 8 (byte offset for u64 table) + vuint16mf2_t sign_indices_byte_offset = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_vals_raw, 8), 3, 8); + + // 5. Gather Signs + vuint64m2_t v_s_vals_u64 = __riscv_vluxei16_v_u64m2(signs64, sign_indices_byte_offset, 8); + vint8m2_t v_s_vals = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(v_s_vals_u64)); + + vint8m2_t v_q3_signed = __riscv_vmul_vv_i8m2(v_q3_magnitudes, v_s_vals, 64); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_q8, v_q3_signed, 64); + + vint16m2_t v_dot_1 = __riscv_vget_v_i16m4_i16m2(v_dot, 0); + vint16m2_t v_dot_2 = __riscv_vget_v_i16m4_i16m2(v_dot, 1); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t v_sum_1 = __riscv_vwredsum_vs_i16m2_i32m1(v_dot_1, v_zero, 32); + vint32m1_t v_sum_2 = __riscv_vwredsum_vs_i16m2_i32m1(v_dot_2, v_zero, 32); + + int32_t sum1_i = __riscv_vmv_x_s_i32m1_i32(v_sum_1); + int32_t sum2_i = __riscv_vmv_x_s_i32m1_i32(v_sum_2); + + const float scale1_f = (float)(2 * (aux32[0] >> 28) + 1); + const float scale2_f = (float)(2 * (aux32[1] >> 28) + 1); + + block_sum += sum1_i * scale1_f + sum2_i * scale2_f; + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} + +void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 256: + ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + + // Load the lookup table once. + const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); + int acc1, acc2; + + // We process 2 blocks at once. + for (; ib + 1 < nb; ib += 2) { + // Weights and activations. + vuint8m1_t iq4_packed1 = __riscv_vle8_v_u8m1(x[ib + 0].qs, 16); + vint8m2_t q8b1 = __riscv_vle8_v_i8m2(y[ib + 0].qs, 32); + vuint8m1_t iq4_packed2 = __riscv_vle8_v_u8m1(x[ib + 1].qs, 16); + vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32); + + // Unpack the weight blocks. + vuint8m2_t iq4bits1; + iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 0, __riscv_vand_vx_u8m1(iq4_packed1, 0xf, 16)); + iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 1, __riscv_vsrl_vx_u8m1(iq4_packed1, 4, 16)); + vuint8m2_t iq4bits2; + iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 0, __riscv_vand_vx_u8m1(iq4_packed2, 0xf, 16)); + iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 1, __riscv_vsrl_vx_u8m1(iq4_packed2, 4, 16)); + + // Gather values from the lookup table. + vint8m2_t iq4b1 = __riscv_vrgather_vv_i8m2(values, iq4bits1, 32); + vint8m2_t iq4b2 = __riscv_vrgather_vv_i8m2(values, iq4bits2, 32); + + // Accumulation. + vint16m4_t sum1 = __riscv_vwmul_vv_i16m4(q8b1, iq4b1, 32); + vint16m4_t sum2 = __riscv_vwmul_vv_i16m4(q8b2, iq4b2, 32); + __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m4_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m4_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1)); + sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2)); + } + + *s = sumf; +} + +static void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + + // Load the lookup table once. + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + int acc1, acc2; + + // We process 2 blocks at once. + for (; ib + 1 < nb; ib += 2) { + // Weights and activations. + vuint8mf2_t iq4_packed1 = __riscv_vle8_v_u8mf2(x[ib + 0].qs, 16); + vint8mf2_t q8b_lo1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs, 16); + vint8mf2_t q8b_hi1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs + 16, 16); + vuint8mf2_t iq4_packed2 = __riscv_vle8_v_u8mf2(x[ib + 1].qs, 16); + vint8mf2_t q8b_lo2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs, 16); + vint8mf2_t q8b_hi2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs + 16, 16); + + // Unpack the weight blocks. + vuint8mf2_t iq4bits_lo1 = __riscv_vand_vx_u8mf2(iq4_packed1, 0xf, 16); + vuint8mf2_t iq4bits_hi1 = __riscv_vsrl_vx_u8mf2(iq4_packed1, 4, 16); + vuint8mf2_t iq4bits_lo2 = __riscv_vand_vx_u8mf2(iq4_packed2, 0xf, 16); + vuint8mf2_t iq4bits_hi2 = __riscv_vsrl_vx_u8mf2(iq4_packed2, 4, 16); + + // Gather values from the lookup table. + vint8mf2_t iq4b_lo1 = __riscv_vrgather_vv_i8mf2(values, iq4bits_lo1, 16); + vint8mf2_t iq4b_hi1 = __riscv_vrgather_vv_i8mf2(values, iq4bits_hi1, 16); + vint8mf2_t iq4b_lo2 = __riscv_vrgather_vv_i8mf2(values, iq4bits_lo2, 16); + vint8mf2_t iq4b_hi2 = __riscv_vrgather_vv_i8mf2(values, iq4bits_hi2, 16); + + // Accumulation. + vint16m1_t sum1 = __riscv_vwmul_vv_i16m1(q8b_lo1, iq4b_lo1, 16); + sum1 = __riscv_vwmacc_vv_i16m1(sum1, q8b_hi1, iq4b_hi1, 16); + vint16m1_t sum2 = __riscv_vwmul_vv_i16m1(q8b_lo2, iq4b_lo2, 16); + sum2 = __riscv_vwmacc_vv_i16m1(sum2, q8b_hi2, iq4b_hi2, 16); + __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m1_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 16), 1); + __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m1_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 16), 1); + sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1)); + sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2)); + } + + *s = sumf; +} + +void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v_intrinsic + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq4_nl_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq4_nl_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); - // Generate sign mask - vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); - vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64); - vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64); - vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64); + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); - q8 += 64; + const int nb = n / QK_K; - // Apply Signs - vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64); - vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 64); +#if defined __riscv_v_intrinsic + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); + float sumf = 0; + int acc[4]; + + // Indices for re-ordering IQ4 data. + uint64_t index[16] = { + 0, 1, 8, 9, + 2, 3, 10, 11, + 4, 5,12, 13, + 6, 7, 14, 15, + }; + vuint64m4_t i_vec = __riscv_vle64_v_u64m4(index, 16); - // Reduction - vint16m2_t v_dot_lo = __riscv_vget_v_i16m4_i16m2(v_dot, 0); - vint16m2_t v_dot_hi = __riscv_vget_v_i16m4_i16m2(v_dot, 1); - vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; - int32_t s_lo = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_lo, v_zero, 32)); - int32_t s_hi = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_hi, v_zero, 32)); + int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; - // Apply sub-scales - uint8_t sc_byte = *scales++; - int sc_lo = (sc_byte & 0xF) * 2 + 1; - int sc_hi = (sc_byte >> 4) * 2 + 1; + for (int ib = 0; ib < QK_K / 128; ++ib) { + // Weights and activations. + vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 64); + vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); + iq4 += 64; + q8 += 128; - sum_block += s_lo * sc_lo + s_hi * sc_hi; + // Unpack the weight blocks. + vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 64); + vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 64); + vuint8m4_t iq4bits; + iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 0, iq4bits_lo); + iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 1, iq4bits_hi); + vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgather_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 16)); + vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 128); + + // Multiply with activations. + vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 128); + + // Reduce separately. + __riscv_vse32_v_i32m1(&acc[0],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + __riscv_vse32_v_i32m1(&acc[1],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + __riscv_vse32_v_i32m1(&acc[2],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + __riscv_vse32_v_i32m1(&acc[3],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + + int ls1 = ((x[ibl].scales_l[ib * 2 + 0] & 0xf) | ((h << 4) & 0x30)) - 32; + int ls2 = ((x[ibl].scales_l[ib * 2 + 0] >> 4) | ((h << 2) & 0x30)) - 32; + int ls3 = ((x[ibl].scales_l[ib * 2 + 1] & 0xf) | ((h << 0) & 0x30)) - 32; + int ls4 = ((x[ibl].scales_l[ib * 2 + 1] >> 4) | ((h >> 2) & 0x30)) - 32; + h >>= 8; + + sumi1 += acc[0] * ls1; + sumi2 += acc[1] * ls2; + sumi3 += acc[2] * ls3; + sumi4 += acc[3] * ls4; } - sumf += sum_block * combined_scale; + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2 + sumi3 + sumi4); } - *s = 0.125f * sumf; + + *s = sumf; + +#else + UNUSED(x); + UNUSED(y); + UNUSED(nb); + ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif } -void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { case 256: - ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; default: - ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); break; } #else - ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } @@ -2492,235 +3483,127 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); +static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + assert(n % QK_MXFP4 == 0); + static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same"); - const block_iq1_s * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; + const block_mxfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; - const int nb = n / QK_K; + const int nb = n / QK_MXFP4; + int ib = 0; float sumf = 0; - for (int i = 0; i < nb; ++i) { - // Load qh once for the entire superblock. - vuint16mf2_t qh = __riscv_vle16_v_u16mf2(x[i].qh, 8); - - // Calculate ls. - vuint16mf2_t temp = __riscv_vsrl_vx_u16mf2(qh, 12, 8); - temp = __riscv_vand_vx_u16mf2(temp, 7, 8); - vint32m1_t ls = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vwmulu_vx_u32m1(temp, 2, 8)); - ls = __riscv_vadd_vx_i32m1(ls, 1, 8); - - // Calculate delta. - vbool32_t mask = __riscv_vmseq_vx_u16mf2_b32(__riscv_vand_vx_u16mf2(qh, 0x8000, 8), 0, 8); - vint32m1_t delta_neg = __riscv_vmv_v_x_i32m1(-1, 8); - vint32m1_t delta_pos = __riscv_vmv_v_x_i32m1(1, 8); - vint32m1_t delta = __riscv_vmerge_vvm_i32m1(delta_neg, delta_pos, mask, 8); - - // Load qs. - vuint8m1_t qs = __riscv_vle8_v_u8m1(x[i].qs, 32); - - // Prepare the indices. - const uint64_t shift = 0x0009000600030000; - vuint16m2_t qh_shift = __riscv_vreinterpret_v_u64m2_u16m2(__riscv_vmv_v_x_u64m2(shift, 8)); - vuint16m2_t qh_gather_index = __riscv_vreinterpret_v_i16m2_u16m2( - __riscv_vdiv_vx_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vid_v_u16m2(32)), 4, 32)); - vuint16m2_t qh_ext = __riscv_vlmul_ext_v_u16m1_u16m2(__riscv_vlmul_ext_v_u16mf2_u16m1(qh)); - vuint16m2_t qh_index = __riscv_vrgather_vv_u16m2(qh_ext, qh_gather_index, 32); - qh_index = __riscv_vsrl_vv_u16m2(qh_index, qh_shift, 32); - qh_index = __riscv_vand_vx_u16m2(qh_index, 7, 32); - qh_index = __riscv_vsll_vx_u16m2(qh_index, 8, 32); - qh_index = __riscv_vor_vv_u16m2(qh_index, __riscv_vzext_vf2_u16m2(qs, 32), 32); - vuint16m2_t index = __riscv_vsll_vx_u16m2(qh_index, 3, 32); - - // Final lsums. - int32_t lsums_s[8]; - vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); - - // Sub-blocks 1-4 - { - vuint16m1_t grid_index0 = __riscv_vget_v_u16m2_u16m1(index, 0); - vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 16)); - vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 128); - vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); - lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 0), one_scalar, 32)); - lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 1), one_scalar, 32)); - lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 2), one_scalar, 32)); - lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 3), one_scalar, 32)); - } - __asm__ __volatile__("" ::: "memory"); - // Sub-blocks 5-8 - { - vuint16m1_t grid_index1 = __riscv_vget_v_u16m2_u16m1(index, 1); - vint8m4_t grid1 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index1, 16)); - vint8m4_t q81 = __riscv_vle8_v_i8m4(&y[i].qs[128], 128); - vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(grid1, q81, 128); - lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 0), one_scalar, 32)); - lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 1), one_scalar, 32)); - lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 2), one_scalar, 32)); - lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 3), one_scalar, 32)); - } - __asm__ __volatile__("" ::: "memory"); - vint32m1_t lsums = __riscv_vle32_v_i32m1(&lsums_s[0], 8); - // Calculate the bsums. - vint16m1_t bsums_0 = __riscv_vle16_v_i16m1(y[i].bsums, 16); - const vuint32m1_t bsums_i32 = __riscv_vreinterpret_v_u16m1_u32m1(__riscv_vreinterpret_v_i16m1_u16m1(bsums_0)); - const vint16mf2_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 0, 8)); - const vint16mf2_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 16, 8)); - const vint32m1_t bsums = __riscv_vwadd_vv_i32m1(bsums_i32_0, bsums_i32_1, 8); + // Load the lookup table once. + const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_mxfp4, 16); + int acc1, acc2; + + // We process 2 blocks at once. + for (; ib + 1 < nb; ib += 2) { + // Weights and activations. + vuint8m1_t mx_packed1 = __riscv_vle8_v_u8m1(x[ib + 0].qs, 16); + vint8m2_t q8b1 = __riscv_vle8_v_i8m2(y[ib + 0].qs, 32); + vuint8m1_t mx_packed2 = __riscv_vle8_v_u8m1(x[ib + 1].qs, 16); + vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32); + + // Unpack the weight blocks. + vuint8m2_t mxbits1; + mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 0, __riscv_vand_vx_u8m1(mx_packed1, 0xf, 16)); + mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 1, __riscv_vsrl_vx_u8m1(mx_packed1, 4, 16)); + vuint8m2_t mxbits2; + mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 0, __riscv_vand_vx_u8m1(mx_packed2, 0xf, 16)); + mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 1, __riscv_vsrl_vx_u8m1(mx_packed2, 4, 16)); + + // Gather values from the lookup table. + vint8m2_t mxb1 = __riscv_vrgather_vv_i8m2(values, mxbits1, 32); + vint8m2_t mxb2 = __riscv_vrgather_vv_i8m2(values, mxbits2, 32); // Accumulation. - vint32m1_t sumi_v = __riscv_vmul_vv_i32m1(ls, lsums, 8); - vint32m1_t sumi1_v = __riscv_vmul_vv_i32m1(__riscv_vmul_vv_i32m1(ls, delta, 8), bsums, 8); - - // Update sumf. - int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); - int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); - sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + vint16m4_t sum1 = __riscv_vwmul_vv_i16m4(q8b1, mxb1, 32); + vint16m4_t sum2 = __riscv_vwmul_vv_i16m4(q8b2, mxb2, 32); + __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m4_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m4_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1)); + sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2)); } *s = sumf; } -void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic - switch (__riscv_vlenb() * 8) { - case 256: - ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); - break; - default: - ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); - break; - } -#else - ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif -} - -static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { - assert(n % QK_K == 0); +static void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + assert(n % QK_MXFP4 == 0); + static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same"); - const block_iq1_m * GGML_RESTRICT x = vx; - const block_q8_K * GGML_RESTRICT y = vy; - - const int nb = n / QK_K; - - iq1m_scale_t scale; - float sumf = 0.0f; - for (int i = 0; i < nb; ++i) { - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * qh = x[i].qh; - const uint16_t * sc = (const uint16_t *)x[i].scales; - - scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); - - // Accumulators. - vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 16); - vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 16); - - // We process 4 sub-blocks together. - for (int ib = 0; ib < QK_K/128; ib++) { - // Load qh for 4 sub-blocks. - const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 8); - const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 8); - const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 8); - const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1( - __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 8)), 16); - qh += 8; - - // Prepare grid indices. - const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 16), 16); - const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 8)); - vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 16), 0x700, 16), 16); - index = __riscv_vsll_vx_u16m1(index, 3, 16); - qs += 16; - - // Load the grid. - const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( - __riscv_vluxei16_v_u64m4(iq1s_grid, index, 16))); - - // Prepare the deltas. - const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16( - __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 8)), 16), 0, 16); - const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 16); - const vint64m4_t delta_neg = __riscv_vmv_v_x_i64m4(0xffffffffffffffff, 16); - const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4( - __riscv_vmerge_vvm_i64m4(delta_pos, delta_neg, mask, 16)); - - // Load q8 for sub-blocks. - const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); - q8 += 128; + const block_mxfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; - // Calculate the lsums. - const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 128); - const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 128); + const int nb = n / QK_MXFP4; - // Prepare the scales. - const int16_t ls_0_0 = 2*((sc[0] >> 0) & 0x7) + 1; - const int16_t ls_0_1 = 2*((sc[0] >> 3) & 0x7) + 1; - const int16_t ls_1_0 = 2*((sc[0] >> 6) & 0x7) + 1; - const int16_t ls_1_1 = 2*((sc[0] >> 9) & 0x7) + 1; - const int16_t ls_2_0 = 2*((sc[1] >> 0) & 0x7) + 1; - const int16_t ls_2_1 = 2*((sc[1] >> 3) & 0x7) + 1; - const int16_t ls_3_0 = 2*((sc[1] >> 6) & 0x7) + 1; - const int16_t ls_3_1 = 2*((sc[1] >> 9) & 0x7) + 1; - sc += 2; + int ib = 0; + float sumf = 0; - // Accumulate in acc0 and acc1 for each sub-block. - acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16); - acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16); - acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16); - acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16); - // - acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16); - acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16); - acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16); - acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16); - // - acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16); - acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16); - acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16); - acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16); - // - acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16); - acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16); - acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16); - acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16); - } + // Load the lookup table once. + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_mxfp4, 16); + int acc1, acc2; + + // We process 2 blocks at once. + for (; ib + 1 < nb; ib+=2) { + // Weights and activations. + vuint8mf2_t mx_packed1 = __riscv_vle8_v_u8mf2(x[ib + 0].qs, 16); + vint8mf2_t q8b_lo1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs, 16); + vint8mf2_t q8b_hi1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs + 16, 16); + vuint8mf2_t mx_packed2 = __riscv_vle8_v_u8mf2(x[ib + 1].qs, 16); + vint8mf2_t q8b_lo2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs, 16); + vint8mf2_t q8b_hi2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs + 16, 16); + + // Unpack the weight blocks. + vuint8mf2_t mxbits_lo1 = __riscv_vand_vx_u8mf2(mx_packed1, 0xf, 16); + vuint8mf2_t mxbits_hi1 = __riscv_vsrl_vx_u8mf2(mx_packed1, 4, 16); + vuint8mf2_t mxbits_lo2 = __riscv_vand_vx_u8mf2(mx_packed2, 0xf, 16); + vuint8mf2_t mxbits_hi2 = __riscv_vsrl_vx_u8mf2(mx_packed2, 4, 16); + + // Gather values from the lookup table. + vint8mf2_t mxb_lo1 = __riscv_vrgather_vv_i8mf2(values, mxbits_lo1, 16); + vint8mf2_t mxb_hi1 = __riscv_vrgather_vv_i8mf2(values, mxbits_hi1, 16); + vint8mf2_t mxb_lo2 = __riscv_vrgather_vv_i8mf2(values, mxbits_lo2, 16); + vint8mf2_t mxb_hi2 = __riscv_vrgather_vv_i8mf2(values, mxbits_hi2, 16); - // Reduce and accumulate in `sumf`. - vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); - int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 16)); - int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 16)); - sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + // Accumulation. + vint16m1_t sum1 = __riscv_vwmul_vv_i16m1(q8b_lo1, mxb_lo1, 16); + sum1 = __riscv_vwmacc_vv_i16m1(sum1, q8b_hi1, mxb_hi1, 16); + vint16m1_t sum2 = __riscv_vwmul_vv_i16m1(q8b_lo2, mxb_lo2, 16); + sum2 = __riscv_vwmacc_vv_i16m1(sum2, q8b_hi2, mxb_hi2, 16); + __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m1_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 16), 1); + __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m1_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 16), 1); + sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1)); + sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2)); } *s = sumf; } -void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { - case 256: - ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + case 128: + ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); break; default: - ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_mxfp4_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc); break; } #else - ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + return ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } From c7abcd577bdb0d351aed145455df7c8a173735a5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 13 Mar 2026 22:12:54 +0200 Subject: [PATCH 269/831] graph : remove redundant GDN state transposes (llama/20443) * ggml : transpose fused GDN state access for coalesced memory reads (llama/20436) The fused Gated Delta Net kernel accessed the [S_v, S_v] state matrix column-wise on row-major storage, causing strided reads (stride S_v = 128 floats = 512 bytes) that waste GPU cache bandwidth. This produced a 39% regression on Qwen3.5-9B (Metal, M4 Max) compared to the unfused path. Transpose the state indexing so threads read contiguously: - Metal: s_ptr[is*S_v] -> s_ptr[is] (stride 1 vs S_v) - CUDA: curr_state[i*S_v+col] -> curr_state[col*S_v+i] (coalesced) - CPU: restructured loops for row-wise transposed access Also add --fused-gdn [on|off|auto] CLI flag (mirrors --flash-attn) so users can control fused GDN independently of auto-detection. All GATED_DELTA_NET backend-ops tests pass. Co-Authored-By: Claude Opus 4.6 * ggml : use SIMD dot products in CPU GDN kernel, couple AR/chunked fused flags - Replace scalar inner loops with ggml_vec_dot_f32 for SIMD-optimized dot products in the CPU fused GDN kernel (delta and attention output) - Couple fused_gdn_ar and fused_gdn_ch flags in auto-detection: if one path lacks device support, disable both to prevent state layout mismatch between transposed (fused) and non-transposed (unfused) formats Co-Authored-By: Claude Opus 4.6 * llama : rever fgdn argument changes * graph : remove GDN state transposes * vulkan : adapt * cuda : remove obsolete smem code --------- Co-authored-by: Paul Flynn Co-authored-by: Claude Opus 4.6 Co-authored-by: Oliver Simons --- ggml/src/ggml-cpu/ops.cpp | 36 +++++++++++-------- ggml/src/ggml-cuda/gated_delta_net.cu | 24 ++++--------- ggml/src/ggml-metal/ggml-metal.metal | 9 ++--- .../vulkan-shaders/gated_delta_net.comp | 4 +-- 4 files changed, 34 insertions(+), 39 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 85db02d92f1..314cc1088a0 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10477,34 +10477,40 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1); const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); + // state is stored transposed: s_out[j*S_v + i] = S[i][j] + // so row j of s_out = column j of S (contiguous access) + if (kda) { + // precompute exp(g) into delta scratch (reused below) for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_scale_f32(S_v, &s_out[i * S_v], expf(g_d[i])); + delta[i] = expf(g_d[i]); + } + // S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i]) + for (int64_t j = 0; j < S_v; ++j) { + ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta); } } else { ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0])); } - // delta[j] = sum_i S[j][i] * k[i] - memset(delta, 0, S_v * sizeof(float)); - for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_mad_f32(S_v, delta, &s_out[i * S_v], k_d[i]); - } + // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k) for (int64_t j = 0; j < S_v; ++j) { - delta[j] = (v_d[j] - delta[j]) * beta_val; + float sum = 0.0f; + ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1); + delta[j] = (v_d[j] - sum) * beta_val; } - // outer product: S[j][i] += k[i] * delta[j] - for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_mad_f32(S_v, &s_out[i * S_v], delta, k_d[i]); + // outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i] + for (int64_t j = 0; j < S_v; ++j) { + ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]); } - // attn_out[j] = sum_i S[j][i] * q[i] - memset(attn_data, 0, S_v * sizeof(float)); - for (int64_t i = 0; i < S_v; ++i) { - ggml_vec_mad_f32(S_v, attn_data, &s_out[i * S_v], q_d[i]); + // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q) + for (int64_t j = 0; j < S_v; ++j) { + float sum = 0.0f; + ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1); + attn_data[j] = sum * scale; } - ggml_vec_scale_f32(S_v, attn_data, scale); attn_data += S_v * H; // advance to next token } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 5f0fa8e58df..1ce6d5f31b5 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -45,10 +45,11 @@ __global__ void gated_delta_net_cuda(const float * q, static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size"); constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size; float s_shard[rows_per_lane]; + // state is stored transposed: M[col][i] = S[i][col], row col is contiguous #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = curr_state[i * S_v + col]; + s_shard[r] = curr_state[col * S_v + i]; } for (int t = 0; t < n_tokens; t++) { @@ -126,23 +127,14 @@ __global__ void gated_delta_net_cuda(const float * q, attn_data += S_v * H; } - // Write state back to global memory + // Write state back to global memory (transposed layout) #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - state[i * S_v + col] = s_shard[r]; + state[col * S_v + i] = s_shard[r]; } } -static size_t calculate_smem(const int sv, int cc) -{ - size_t smem = 0; - if ((GGML_CUDA_CC_IS_AMD(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_RDNA4(cc)) || GGML_CUDA_CC_IS_MTHREADS(cc)) { - smem = sv * sv * sizeof(float); - } - return smem; -} - template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, @@ -179,18 +171,14 @@ static void launch_gated_delta_net( sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; case 64: { - constexpr int sv = 64; - size_t smem = calculate_smem(sv, cc); - gated_delta_net_cuda<<>>( + gated_delta_net_cuda<64, KDA><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); break; } case 128: { - constexpr int sv = 128; - size_t smem = calculate_smem(sv, cc); - gated_delta_net_cuda<<>>( + gated_delta_net_cuda<128, KDA><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 107e7cf2ff3..d4b129ed756 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2469,13 +2469,14 @@ kernel void kernel_gated_delta_net_impl( const float scale = 1.0f / sqrt((float)S_v); - device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20; + // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous + device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; float ls[NSG]; FOR_UNROLL (short j = 0; j < NSG; j++) { const short is = tx*NSG + j; - ls[j] = s_ptr[is*S_v]; + ls[j] = s_ptr[is]; } device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20; @@ -2536,11 +2537,11 @@ kernel void kernel_gated_delta_net_impl( g_ptr += args.ne21*G; } - device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20; + device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; FOR_UNROLL (short j = 0; j < NSG; j++) { const short is = tx*NSG + j; - dst_state[is*S_v] = ls[j]; + dst_state[is] = ls[j]; } #undef S_v diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index 1fdf889e824..f008859b99d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -44,7 +44,7 @@ void main() { FLOAT_TYPE state[S_V]; [[unroll]] for (uint i = 0; i < S_V; i++) { - state[i] = FLOAT_TYPE(data_state[state_base + i * S_V + col]); + state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]); } uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; @@ -123,6 +123,6 @@ void main() { } [[unroll]] for (uint i = 0; i < S_V; i++) { - data_dst[s_off + state_base + i * S_V + col] = state[i]; + data_dst[s_off + state_base + col * S_V + i] = state[i]; } } From a31600d8e34bc76185e4e5b2eddf68895131e223 Mon Sep 17 00:00:00 2001 From: lhez Date: Fri, 13 Mar 2026 22:18:52 -0700 Subject: [PATCH 270/831] opencl: fix l2_norm (llama/20480) --- ggml/src/ggml-opencl/kernels/l2_norm.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-opencl/kernels/l2_norm.cl b/ggml/src/ggml-opencl/kernels/l2_norm.cl index 39f400199fa..fb95355a679 100644 --- a/ggml/src/ggml-opencl/kernels/l2_norm.cl +++ b/ggml/src/ggml-opencl/kernels/l2_norm.cl @@ -63,7 +63,7 @@ kernel void kernel_l2_norm_f32( barrier(CLK_LOCAL_MEM_FENCE); - const float scale = 1.0f/sqrt(max(sum[0], eps)); + const float scale = 1.0f/max(sqrt(sum[0]), eps); for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { y[i00] = x[i00] * scale; From 46aad766f5e5e683b190799fcdd855852046cd1a Mon Sep 17 00:00:00 2001 From: Rail Chabdarov Date: Sat, 14 Mar 2026 06:19:44 +0100 Subject: [PATCH 271/831] Fix data race in CUDA's "cpy" kernel (influences GGML's DUP, CONT operations). (llama/20507) * Fix datarace in CUDA's "cpy" kernel. * Remove extra barrier by using more of shared memory. --- ggml/src/ggml-cuda/cpy.cu | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index ee84303ef0e..d208acf2d5f 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -56,7 +56,8 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y; - __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1]; + __shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1]; + int cur_tile_buf = 0; #pragma unroll for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) { @@ -70,7 +71,7 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const if(x < ne01 && y + j < ne00){ const int row = threadIdx.y+j; const int col = threadIdx.x * sizeof(float)/sizeof(T); - T *tile2 = reinterpret_cast(tile[row]); + T *tile2 = reinterpret_cast(tile[cur_tile_buf][row]); tile2[col] = src[imat*n + (y+j)*ne01 + x]; } } @@ -81,10 +82,12 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { if (ty + j < ne01 && tx < ne00) { const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T); - const T *tile2 = reinterpret_cast(tile[threadIdx.x]); + const T *tile2 = reinterpret_cast(tile[cur_tile_buf][threadIdx.x]); dst[imat*n + (ty+j)*ne00 + tx] = tile2[col]; } } + + cur_tile_buf = (cur_tile_buf + 1) % 2; } GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, From 96b163e874b30cfb74032864a91aed5a45169e84 Mon Sep 17 00:00:00 2001 From: Zijun Yu Date: Sat, 14 Mar 2026 13:56:55 +0800 Subject: [PATCH 272/831] ggml : add OpenVINO backend (llama/15307) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update build doc * Add cgraph tensor output name to OV op name * Update openvino build instructions * Add initial NPU support * draft NPU support version 2: prefill + kvcache * NPU support version 2: prefill + kvcache * Change due to ggml cgraph changes, not correct yet * Change due to ggml cgraph changes, llama-3.2 CPU work * Add AMD64 to CMakeLists * Change due to ggml cgraph changes, all device work * Refactor: clean, fix warning * Update clang-format * Statful transformation for CPU GPU * Add SwiGLU * Fuse to SDPA * Replace Concat with Broadcast in MulMat for GQA * Pull out indices creation for kv cache update * Refactor: remove past_token_len from extra_inputs * Fix Phi3 SwiGLU and SoftMax * Pull out sin cos from rope * Reduce memory: free ov weights node after graph conversion * Fix CPY due to cgraph change * Added OpenVINO CI/CD. Updated docs * Fix llama-cli * Fix Phi3 ROPE; Add test-backend-ops * Fix NPU * Fix llama-bench; Clang-format * Fix llama-perplexity * temp. changes for mark decomp * matmul in fp32 * mulmat input conversion fix * mulmat type conversion update * add mark decomp pass * Revert changes in fuse_to_sdpa * Update build.md * Fix test-backend-ops * Skip test-thread-safety; Run ctest only in ci/run.sh * Use CiD for NPU * Optimize tensor conversion, improve TTFT * Support op SET_ROWS * Fix NPU * Remove CPY * Fix test-backend-ops * Minor updates for raising PR * Perf: RMS fused to OV internal RMS op * Fix after rebasing - Layout of cache k and cache v are unified: [seq, n_head, head_size] - Add CPY and FLASH_ATTN_EXT, flash attn is not used yet - Skip test-backend-ops due to flash attn test crash - Add mutex around graph conversion to avoid test-thread-safety fali in the future - Update NPU config - Update GPU config to disable SDPA opt to make phi-3 run * Change openvino device_type to GPU; Enable flash_attn * Update supports_buft and supports_op for quantized models * Add quant weight conversion functions from genai gguf reader * Quant models run with accuracy issue * Fix accuracy: disable cpu_repack * Fix CI; Disable test-backend-ops * Fix Q4_1 * Fix test-backend-ops: Treat quantized tensors as weights * Add NPU Q4_0 support * NPU perf: eliminate zp * Dequantize q4_1 q4_k q6_k for NPU * Add custom quant type: q8_1_c, q4_0_128 * Set m_is_static=false as default in decoder * Simpilfy translation of get_rows * Fix after rebasing * Improve debug util; Eliminate nop ReshapeReshape * STYLE: make get_types_to_requant a function * Support BF16 model * Fix NPU compile * WA for npu 1st token acc issue * Apply EliminateZP only for npu * Add GeGLU * Fix Hunyuan * Support iSWA * Fix NPU accuracy * Fix ROPE accuracy when freq_scale != 1 * Minor: not add attention_size_swa for non-swa model * Minor refactor * Add Q5_K to support phi-3-q4_k_m * Requantize Q6_K (gs16) to gs32 on GPU * Fix after rebasing * Always apply Eliminate_ZP to fix GPU compile issue on some platforms * kvcachefusion support * env variable GGML_OPENVINO_DISABLE_SDPA_OPTIMIZATION added * Fix for Phi3 * Fix llama-cli (need to run with --no-warmup) * Fix add_sliced_mask; Revert mulmat, softmax; Remove input attention_size, iSWA model not working * fix after rebasing * Fix llama-3-8b and phi3-mini q4_0 NPU * Update to OV-2025.3 and CMakeLists.txt * Add OV CI cache * Apply CISC review and update CI to OV2025.3 * Update CI to run OV dep install before build * Update OV dockerfile to use OV2025.3 and update build docs * Style: use switch in supports_ops * Style: middle ptr and ref align, omit optional struct keyword * NPU Unify PD (llama/14) * Stateless. Fix llama-cli llama-server * Simplify broadcast op in attention * Replace get_output_tensor+memcpy with set_output_tensor * NPU unify PD. Unify dynamic and static dims * Clean placeholders in ggml-openvino.cpp * NPU unify PD (handled internally) * change graph to 4d, support multi sequences * Fix llama-bench * Fix NPU * Update ggml-decoder.cpp Hitting error while compiling on windows: error C3861: 'unsetenv': identifier not found Reason: unsetenv() is a POSIX function; it doesn’t exist on Windows. Visual Studio (MSVC) won’t recognize it. Proposed fix: Use _putenv_s() (Windows equivalent) This is supported by MSVC and achieves the same effect: it removes the environment variable from the process environment. This keeps cross-platform compatibility. * Update ggml-decoder.cpp * Update ggml-decoder.cpp * Update ggml-decoder.cpp * Update ggml-decoder.cpp * Update ggml-decoder.cpp * Remove the second decoder for node. Moving the function into the model decoder * Fix error for naive * NPU prefill chunking * NPU fix llama-bench * fallback naive run with accuracy issue * NPU support llma-perplexity -b 512 --no-warmup * Refactor: split ov_graph_compute for dynamic and static * remove unused API GgmlOvDecoder::get_output_stride(const std::string & name) * minor update due to ov 2025.4 * remove unused API GgmlOvDecoder::get_output_names() * remove unused API get_output_shape(const std::string & name) * Modified API GgmlOvDecoder::get_output_type(const std::string & name) * Removed API GgmlOvDecoder::get_output_op_params(const std::string & name) * Removed API get_output_ggml_tensor(const std::string & name) * Removed API m_outputs * Removed m_output_names * Removed API GgmlOvDecoder::get_input_names() * Removed API GgmlOvDecoder::get_input_stride(const std::string& name) * Removed API get_input_type * Removed API get_input_type * Removed API GgmlOvDecoder::get_input_shape(const std::string & name) * Removed API GgmlOvDecoder::get_input_op_params(const std::string & name) * Fix error for decoder cache * Reuse cached decoder * GPU remove Q6_K requantization * NPU fix wrong model output shape * NPU fix q4 perf regression * Remove unused variable nodes * Fix decoder can_reuse for llama-bench * Update build.md for Windows * backend buffer: allocate on host * Use shared_buffer for GPU NPU; Refactor * Add ov_backend_host_buffer; Use cached remote context * Put kvcache on GPU * Use ggml_aligned_malloc * only use remote tensor for kvcache * only use remote tensor for kvcache for GPU * FIX: use remote tensor from singleton * Update build.md to include OpenCL * NPU always requant to q4_0_128 * Optimize symmetric quant weight extraction: use single zp * Use Q8_0_C in token embd, lm_head, and for 5 and 6 bits quant * Update build.md * Support -ctk f32 * Initial stateful graph support * Update ggml/src/ggml-openvino/ggml-decoder.cpp Co-authored-by: Yamini Nimmagadda * code cleanup * npu perf fix * requant to f16 for Q6 embed on NPU * Update ggml/src/ggml-openvino/ggml-decoder.cpp * Update ggml/src/ggml-openvino/ggml-openvino-extra.cpp * Create OPENVINO.md in llama.cpp backend docs * Update OPENVINO.md * Update OPENVINO.md * Update OPENVINO.md * Update build.md * Update OPENVINO.md * Update OPENVINO.md * Update OPENVINO.md * kq_mask naming fix * Syntax correction for workflows build file * Change ov backend buffer is_host to false * Fix llama-bench -p -n where p<=256 * Fix --direct-io 0 * Don't put kvcache on GPU in stateful mode * Remove hardcode names * Fix stateful shapes * Simplification for stateful and update output shape processing * Remove hardcode names * Avoid re-compilation in llama-bench * Extract zp directly instead of bias * Refactor weight tensor processing * create_weight_node accept non-ov backend buffer * remove changes in llama-graph.cpp * stateful masking fix (llama/38) Fix for stateful accuracy issues and cl_out_of_resources error in stateful GPU with larger context sizes. * Fix test-backend-ops crash glu, get_rows, scale, rms_norm, add * hardcoded name handling for rope_freqs.weight * Suppress logging and add error handling to allow test-backend-ops to complete * Fix MUL_MAT with broadcast; Add unsupported MUL_MAT FLASH_ATTN cases * Use bias instead of zp in test-backend-ops * Update OV in CI, Add OV CI Tests in GH Actions * Temp fix for multithreading bug * Update OV CI, fix review suggestions. * fix editorconfig-checker, update docs * Fix tabs to spaces for editorconfig-checker * fix editorconfig-checker * Update docs * updated model link to be GGUF model links * Remove GGML_CPU_REPACK=OFF * Skip permuted ADD and MUL * Removed static variables from utils.cpp * Removed initializing non-existing variable * Remove unused structs * Fix test-backend-ops for OV GPU * unify api calling * Update utils.cpp * When the dim is dynamic, throw an error, need to is stastic forst * Add interface compute_model_outputs(), which get the model output through computing the node use count & status in the cgraph to avoid the flag using * No need to return * Fix test-backend-ops for OV GPU LNL * Fix test-thread-safety * use the shape from infer request of output tensor create to avoid issue * fix dynamic output shape issue * fix issue for the unused node in tests * Remove unused lock * Add comment * Update openvino docs * update to OV release version 2026.0 * add ci ov-gpu self hosted runner * fix editorconfig * Fix perplexity * Rewrite the model inputs finding mechanism (llama/54) * Rewrite the model inputs finding logistic * Put stateful shape handle in get input shape * Put the iteration logistic in func * Added ggml-ci-intel-openvino-gpu and doc update * .hpp files converted to .h * fix ggml-ci-x64-intel-openvino-gpu * Fix for stateful execution bug in llama-bench * Minor updates after stateful llama-bench fix * Update ggml/src/ggml-openvino/utils.cpp Co-authored-by: Yamini Nimmagadda * Remove multiple get_shape calls * Bring back mutex into compute * Fix VIEW op, which slice the input node * Added token_len_per_seq existence check before slicing masks and moved node retrieval inside guarded block to prevent missing-key access * Temp. fix for test requant errors * Update to OV ggml-ci to low-perf * ci : temporary disable "test-llama-archs" * ci : cache v4 -> v5, checkout v4 -> v6, fix runner tag * docs : update url * Fix OV link in docker and Update docs --------- Co-authored-by: Ravi Panchumarthy Co-authored-by: Cavus Mustafa Co-authored-by: Arshath Co-authored-by: XuejunZhai Co-authored-by: Yamini Nimmagadda Co-authored-by: Xuejun Zhai Co-authored-by: Georgi Gerganov --- ggml/CMakeLists.txt | 3 + ggml/include/ggml-openvino.h | 37 + ggml/src/CMakeLists.txt | 1 + ggml/src/ggml-backend-reg.cpp | 8 + ggml/src/ggml-openvino/.clang-format | 154 +++ ggml/src/ggml-openvino/CMakeLists.txt | 22 + ggml/src/ggml-openvino/ggml-decoder.cpp | 975 +++++++++++++++ ggml/src/ggml-openvino/ggml-decoder.h | 294 +++++ .../src/ggml-openvino/ggml-openvino-extra.cpp | 373 ++++++ ggml/src/ggml-openvino/ggml-openvino-extra.h | 182 +++ ggml/src/ggml-openvino/ggml-openvino.cpp | 1110 +++++++++++++++++ ggml/src/ggml-openvino/ggml-quants.cpp | 884 +++++++++++++ ggml/src/ggml-openvino/ggml-quants.h | 153 +++ ggml/src/ggml-openvino/openvino/decoder.h | 74 ++ ggml/src/ggml-openvino/openvino/frontend.cpp | 27 + ggml/src/ggml-openvino/openvino/frontend.h | 23 + .../ggml-openvino/openvino/input_model.cpp | 17 + ggml/src/ggml-openvino/openvino/input_model.h | 29 + .../src/ggml-openvino/openvino/node_context.h | 112 ++ ggml/src/ggml-openvino/openvino/op/cont.cpp | 48 + ggml/src/ggml-openvino/openvino/op/cpy.cpp | 21 + .../openvino/op/flash_attn_ext.cpp | 90 ++ .../ggml-openvino/openvino/op/get_rows.cpp | 69 + .../ggml-openvino/openvino/op/glu_geglu.cpp | 61 + .../ggml-openvino/openvino/op/glu_swiglu.cpp | 62 + ggml/src/ggml-openvino/openvino/op/mulmat.cpp | 90 ++ .../src/ggml-openvino/openvino/op/permute.cpp | 102 ++ .../src/ggml-openvino/openvino/op/reshape.cpp | 83 ++ .../ggml-openvino/openvino/op/rms_norm.cpp | 46 + ggml/src/ggml-openvino/openvino/op/rope.cpp | 123 ++ ggml/src/ggml-openvino/openvino/op/scale.cpp | 41 + .../ggml-openvino/openvino/op/set_rows.cpp | 76 ++ .../src/ggml-openvino/openvino/op/softmax.cpp | 89 ++ .../ggml-openvino/openvino/op/transpose.cpp | 23 + .../ggml-openvino/openvino/op/unary_silu.cpp | 27 + ggml/src/ggml-openvino/openvino/op/view.cpp | 53 + ggml/src/ggml-openvino/openvino/op_table.cpp | 46 + ggml/src/ggml-openvino/openvino/op_table.h | 39 + .../openvino/pass/eliminate_zp.cpp | 123 ++ .../openvino/pass/eliminate_zp.h | 17 + .../openvino/pass/fuse_to_sdpa.cpp | 60 + .../openvino/pass/fuse_to_sdpa.h | 17 + ...k_decompression_convert_constant_folding.h | 29 + .../openvino/pass/squeeze_matmul.cpp | 58 + .../openvino/pass/squeeze_matmul.h | 17 + .../openvino/translate_session.cpp | 293 +++++ .../openvino/translate_session.h | 28 + ggml/src/ggml-openvino/openvino/utils.cpp | 226 ++++ ggml/src/ggml-openvino/openvino/utils.h | 85 ++ ggml/src/ggml-openvino/utils.cpp | 823 ++++++++++++ ggml/src/ggml-openvino/utils.h | 123 ++ 51 files changed, 7566 insertions(+) create mode 100644 ggml/include/ggml-openvino.h create mode 100644 ggml/src/ggml-openvino/.clang-format create mode 100644 ggml/src/ggml-openvino/CMakeLists.txt create mode 100644 ggml/src/ggml-openvino/ggml-decoder.cpp create mode 100644 ggml/src/ggml-openvino/ggml-decoder.h create mode 100644 ggml/src/ggml-openvino/ggml-openvino-extra.cpp create mode 100644 ggml/src/ggml-openvino/ggml-openvino-extra.h create mode 100644 ggml/src/ggml-openvino/ggml-openvino.cpp create mode 100644 ggml/src/ggml-openvino/ggml-quants.cpp create mode 100644 ggml/src/ggml-openvino/ggml-quants.h create mode 100644 ggml/src/ggml-openvino/openvino/decoder.h create mode 100644 ggml/src/ggml-openvino/openvino/frontend.cpp create mode 100644 ggml/src/ggml-openvino/openvino/frontend.h create mode 100644 ggml/src/ggml-openvino/openvino/input_model.cpp create mode 100644 ggml/src/ggml-openvino/openvino/input_model.h create mode 100644 ggml/src/ggml-openvino/openvino/node_context.h create mode 100644 ggml/src/ggml-openvino/openvino/op/cont.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/cpy.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/get_rows.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/mulmat.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/permute.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/reshape.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/rms_norm.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/rope.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/scale.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/set_rows.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/softmax.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/transpose.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/unary_silu.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op/view.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op_table.cpp create mode 100644 ggml/src/ggml-openvino/openvino/op_table.h create mode 100644 ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp create mode 100644 ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h create mode 100644 ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp create mode 100644 ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h create mode 100644 ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h create mode 100644 ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp create mode 100644 ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h create mode 100644 ggml/src/ggml-openvino/openvino/translate_session.cpp create mode 100644 ggml/src/ggml-openvino/openvino/translate_session.h create mode 100644 ggml/src/ggml-openvino/openvino/utils.cpp create mode 100644 ggml/src/ggml-openvino/openvino/utils.h create mode 100644 ggml/src/ggml-openvino/utils.cpp create mode 100644 ggml/src/ggml-openvino/utils.h diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 8f679e2fd35..44e58a52761 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -248,6 +248,8 @@ set (GGML_SYCL_TARGET "INTEL" CACHE STRING set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING "ggml: sycl device architecture") +option(GGML_OPENVINO "ggml: use OPENVINO" OFF) + option(GGML_OPENCL "ggml: use OpenCL" OFF) option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increases overhead)" OFF) option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON) @@ -327,6 +329,7 @@ set(GGML_PUBLIC_HEADERS include/ggml-vulkan.h include/ggml-webgpu.h include/ggml-zendnn.h + include/ggml-openvino.h include/gguf.h) set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}") diff --git a/ggml/include/ggml-openvino.h b/ggml/include/ggml-openvino.h new file mode 100644 index 00000000000..c43beb07b6a --- /dev/null +++ b/ggml/include/ggml-openvino.h @@ -0,0 +1,37 @@ +#pragma once + +#include "ggml-backend.h" + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_OPENVINO_NAME "OPENVINO" + +// backend API +GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device); + +GGML_BACKEND_API bool ggml_backend_is_openvino(ggml_backend_t backend); + +GGML_BACKEND_API bool ggml_backend_buffer_is_openvino(ggml_backend_buffer_t buffer); + +GGML_BACKEND_API bool ggml_backend_buft_is_openvino(ggml_backend_buffer_type_t buft); + +GGML_BACKEND_API bool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft); + +GGML_BACKEND_API size_t ggml_backend_openvino_buffer_get_ctx_id(ggml_backend_buffer_t buffer); + +// device buffer +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(int device); + +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_host_buffer_type(int device); + +GGML_BACKEND_API int ggml_backend_openvino_get_device_count(void); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_openvino_reg(void); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 265023733e7..78853304d9f 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -460,6 +460,7 @@ ggml_add_backend(zDNN) ggml_add_backend(OpenCL) ggml_add_backend(Hexagon) ggml_add_backend(ZenDNN) +ggml_add_backend(OPENVINO) foreach (target ggml-base ggml) target_include_directories(${target} PUBLIC $ $) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 311fa5fe368..0587109212e 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -82,6 +82,10 @@ #include "ggml-zendnn.h" #endif +#ifdef GGML_USE_OPENVINO +#include "ggml-openvino.h" +#endif + namespace fs = std::filesystem; static std::string path_str(const fs::path & path) { @@ -154,6 +158,9 @@ struct ggml_backend_registry { #ifdef GGML_USE_RPC register_backend(ggml_backend_rpc_reg()); #endif +#ifdef GGML_USE_OPENVINO + register_backend(ggml_backend_openvino_reg()); +#endif #ifdef GGML_USE_CPU register_backend(ggml_backend_cpu_reg()); #endif @@ -557,6 +564,7 @@ void ggml_backend_load_all_from_path(const char * dir_path) { ggml_backend_load_best("opencl", silent, dir_path); ggml_backend_load_best("hexagon", silent, dir_path); ggml_backend_load_best("musa", silent, dir_path); + ggml_backend_load_best("openvino", silent, dir_path); ggml_backend_load_best("cpu", silent, dir_path); // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend const char * backend_path = std::getenv("GGML_BACKEND_PATH"); diff --git a/ggml/src/ggml-openvino/.clang-format b/ggml/src/ggml-openvino/.clang-format new file mode 100644 index 00000000000..a2a24d7d33a --- /dev/null +++ b/ggml/src/ggml-openvino/.clang-format @@ -0,0 +1,154 @@ +--- +# Override root .clang-format +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +Cpp11BracedListStyle: true +SpacesInContainerLiterals: false +BreakBeforeBraces: Attach +AccessModifierOffset: -4 +IndentCaseBlocks: false +IndentCaseLabels: false + +Language: Cpp +AlignAfterOpenBracket: Align +AlignArrayOfStructures: Left +AlignConsecutiveBitFields: AcrossComments +AlignConsecutiveMacros: AcrossComments +# AlignConsecutiveShortCaseStatements: AcrossComments +AlignEscapedNewlines: Left # LeftWithLastLine +AlignOperands: Align +AlignTrailingComments: + Kind: Always + OverEmptyLines: 1 +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: false +# AllowBreakBeforeNoexceptSpecifier: OnlyWithParen +AllowShortBlocksOnASingleLine: Never +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline +AllowShortIfStatementsOnASingleLine: Never +AllowShortLambdasOnASingleLine: Inline +AllowShortLoopsOnASingleLine: false +AlwaysBreakBeforeMultilineStrings: true +# Treat CUDA keywords/attributes as "attribute macros" and avoid breaking lines inside them +AttributeMacros: + - __host__ + - __device__ + - __global__ + - __forceinline__ + - __launch_bounds__ +BinPackArguments: true +BinPackParameters: false # OnePerLine +BitFieldColonSpacing: Both +# BreakAdjacentStringLiterals: true +BreakAfterAttributes: Never +BreakBeforeBinaryOperators: None +BreakBeforeInlineASMColon: OnlyMultiline +BreakBeforeTernaryOperators: false +# BreakBinaryOperations: Never +BreakConstructorInitializers: AfterColon +# BreakFunctionDefinitionParameters: false +BreakInheritanceList: AfterComma +BreakStringLiterals: true +# BreakTemplateDeclarations: Yes +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +DerivePointerAlignment: false +DisableFormat: false +EmptyLineBeforeAccessModifier: Leave +EmptyLineAfterAccessModifier: Never +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +IncludeBlocks: Regroup +IncludeCategories: + - Regex: '".*"' + Priority: 1 + SortPriority: 0 + - Regex: '^<.*\.h>' + Priority: 2 + SortPriority: 0 + - Regex: '^<.*' + Priority: 3 + SortPriority: 0 + - Regex: '.*' + Priority: 4 + SortPriority: 0 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IncludeIsMainSourceRegex: '' +IndentAccessModifiers: false +IndentExternBlock: NoIndent +IndentGotoLabels: false +IndentPPDirectives: AfterHash +IndentWidth: 4 +IndentWrappedFunctionNames: false +InsertBraces: true # NOTE: may lead to incorrect formatting +InsertNewlineAtEOF: true +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +LambdaBodyIndentation: Signature +LineEnding: LF +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 4 +ObjCSpaceAfterProperty: true +ObjCSpaceBeforeProtocolList: true +PPIndentWidth: -1 +PackConstructorInitializers: CurrentLine +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Middle +QualifierAlignment: Left +#QualifierOrder: ['static', 'inline', 'friend', 'constexpr', 'const', 'volatile', 'type', 'restrict'] +RawStringFormats: + - Language: Cpp + Delimiters: + - cc + - CC + - cpp + - Cpp + - CPP + - 'c++' + - 'C++' + CanonicalDelimiter: '' +ReferenceAlignment: Middle +ReflowComments: false # IndentOnly +SeparateDefinitionBlocks: Always +SortIncludes: CaseInsensitive +SortUsingDeclarations: LexicographicNumeric +SpaceAfterCStyleCast: true +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyBlock: false +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 2 +SpacesInAngles: Never +SpacesInLineCommentPrefix: + Minimum: 1 + Maximum: -1 +SpacesInParentheses: false +SpacesInSquareBrackets: false +SpaceBeforeSquareBrackets: false +Standard: c++17 +TabWidth: 4 +UseTab: Never +WhitespaceSensitiveMacros: ['STRINGIZE'] +... diff --git a/ggml/src/ggml-openvino/CMakeLists.txt b/ggml/src/ggml-openvino/CMakeLists.txt new file mode 100644 index 00000000000..175b585661d --- /dev/null +++ b/ggml/src/ggml-openvino/CMakeLists.txt @@ -0,0 +1,22 @@ +find_package(OpenVINO REQUIRED) +find_package(OpenCL REQUIRED) + +include("${OpenVINO_DIR}/../3rdparty/tbb/lib/cmake/TBB/TBBConfig.cmake") + +file(GLOB_RECURSE GGML_HEADERS_OPENVINO "*.h" "*.hpp") +file(GLOB_RECURSE GGML_SOURCES_OPENVINO "*.cpp") + +ggml_add_backend_library(ggml-openvino + ${GGML_SOURCES_OPENVINO} + ${GGML_HEADERS_OPENVINO} +) + +target_link_libraries(ggml-openvino PRIVATE openvino::runtime TBB::tbb OpenCL::OpenCL) + +if (GGML_OPENVINO) + if (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "amd64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64") + else() + message(FATAL_ERROR "OpenVINO: OpenVINO toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}") + endif() +endif() diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp new file mode 100644 index 00000000000..0938d2273e9 --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -0,0 +1,975 @@ +#include "ggml-decoder.h" + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-openvino-extra.h" +#include "ggml-openvino.h" +#include "ggml-quants.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, + ModelParams & model_params, + ComputeParams & compute_params, + std::map> & model_weights, + bool is_static, + bool is_stateful, + bool is_prefill, + int prefill_chunk_size) : + m_is_static(is_static), + m_is_stateful(is_stateful), + m_is_prefill(is_prefill), + m_naive(false), + m_prefill_chunk_size(prefill_chunk_size), + m_cgraph(cgraph), + m_model_weights(model_weights), + m_model_params(model_params), + m_compute_params(compute_params) { + if (auto * env = getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); env && std::string(env) != "0") { +#ifdef _WIN32 + _putenv_s("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS", ""); +#else + unsetenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); +#endif + print_tensor_address_map(cgraph); + } + + validate_cgraph(); + + set_input_output(); + compute_model_inputs(); + compute_model_outputs(); + + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { + m_node_info_list[node_n].node_op_case = compute_op_case(m_node_info_list[node_n].node); + m_node_info_list[node_n].node_op_type = compute_op_type(m_node_info_list[node_n].node); + } + + add_extra_inputs(); +} + +void GgmlOvDecoder::update_io(ggml_cgraph * cgraph) { + m_cgraph = cgraph; + m_model_inputs.clear(); + m_model_outputs.clear(); + m_node_info_list.clear(); + set_input_output(); + compute_model_inputs(); + compute_model_outputs(); +} + +GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, std::map> & model_weights) { + m_cgraph = cgraph; + m_model_weights = model_weights; + m_naive = true; + set_input_output(); + compute_model_inputs(); + compute_model_outputs(); + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { + m_node_info_list[node_n].node_op_case = compute_op_case(m_node_info_list[node_n].node); + m_node_info_list[node_n].node_op_type = compute_op_type(m_node_info_list[node_n].node); + } +} + +void GgmlOvDecoder::set_input_output() { + for (int node_n = 0; node_n < m_cgraph->n_nodes; node_n++) { + auto node = m_cgraph->nodes[node_n]; + + NodeInfo current_node_info; + auto node_name = std::string(node->name); + auto node_output_name = node_name; + auto * node_output = node; + if (node->op == GGML_OP_SET_ROWS) { + // SET_ROWS updates the tensor in place. For later ov op that uses the + // the view_src of SET_ROWS, we need to make sure they get the updated tensor + // by putting the view_src name in the tensor_map in + // /src/frontends/ggml/src/translate_session.cpp + node_output_name = std::string(node->view_src->name); + node_output = node->view_src; + } + + current_node_info.node = node; + current_node_info.node_name = node_name; + current_node_info.node_output = node_output; + current_node_info.node_output_name = node_output_name; + current_node_info.node_op_case = 0; + current_node_info.data_addr = node->data; + + for (int i = 0; i < GGML_MAX_SRC; i++) { + auto * src = node->src[i]; + if (src == nullptr) { + continue; + } + auto src_name = std::string(src->name); + if (src->flags & GGML_TENSOR_FLAG_INPUT) { + src_name = get_graph_input_ov_name(src, node); + } + current_node_info.node_inputs[src_name] = src; + current_node_info.node_inputs_names.push_back(src_name); + } + + m_node_info_list.push_back(current_node_info); + } +} + +int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const { + int op_case = 0; + switch (node->op) { + case GGML_OP_RESHAPE: { + auto * src = node->src[0]; + if (src->op == GGML_OP_RESHAPE && src->src[0]->ne[0] == node->ne[0] && src->src[0]->ne[1] == node->ne[1]) { + op_case = 4; + } else if (node->ne[0] * node->ne[1] == src->ne[0]) { + op_case = 1; + } else if (src->ne[0] * src->ne[1] == node->ne[0]) { + op_case = 2; + if (src->ne[2] * src->ne[3] == node->ne[1]) { + op_case = 5; + } + } else if (src->ne[0] * src->ne[1] == node->ne[1]) { + op_case = 3; + } else if (src->ne[1] * src->ne[2] == node->ne[1]) { + op_case = 6; + } + break; + } + case GGML_OP_CONT: { + if (node->src[0]->op == GGML_OP_PERMUTE) { + op_case = 1; + } else if (node->src[0]->op == GGML_OP_TRANSPOSE) { + op_case = 2; + } else if (node->src[0]->op == GGML_OP_VIEW) { + op_case = 3; + } + break; + } + case GGML_OP_PERMUTE: { + if (node->src[0]->op != GGML_OP_VIEW) { + op_case = 1; + } else if (node->src[0]->src[0]->op == GGML_OP_NONE) { + // kv cache tensor + std::string src_name(node->view_src->name); + int layer = extract_layer_from_name(src_name); + if (!is_swa_layer(layer)) { + op_case = 2; + } else { + op_case = 3; + } + } else { + // rope'ed query tensor + op_case = 4; + } + break; + } + case GGML_OP_MUL_MAT: { + if (node->src[0]->op == GGML_OP_CONT && node->src[0]->src[0]->op == GGML_OP_TRANSPOSE) { + op_case = 2; + } else if (node->src[0]->op == GGML_OP_VIEW && node->src[1]->op == GGML_OP_VIEW) { + op_case = 3; + } + break; + } + case GGML_OP_GET_ROWS: { + if (node->src[1]->op == GGML_OP_VIEW) { + op_case = 2; + } + break; + } + case GGML_OP_ROPE: { + if (node->src[0]->op == GGML_OP_VIEW) { + op_case = 2; + } + break; + } + case GGML_OP_VIEW: { + if (node->src[0]->op == GGML_OP_VIEW) { + auto * src = node->src[0]; + if (ggml_nelements(node) != ggml_nelements(src)) { + throw std::runtime_error("Unsupported VIEW case"); + } + op_case = 2; + } + { + auto * src = node->src[0]; + if ((ggml_nelements(node) != ggml_nelements(src)) && m_naive) { + // Compare each dimension of node and src, if only one dimension differs then op_case=3 + int diff_count = 0; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (node->ne[i] != src->ne[i]) { + diff_count++; + } + } + if (diff_count == 1) { + op_case = 3; + } + } + } + break; + } + default: + break; + } + return op_case; +} + +int extract_layer_from_name(const std::string & name) { + size_t pos1 = name.find("_l"); + assert(pos1 != std::string::npos); + pos1 += 2; + size_t pos2 = name.find(' ', pos1); + if (pos2 == std::string::npos) { + pos2 = name.length(); + } + std::string layer_str = name.substr(pos1, pos2 - pos1); + int layer = std::stoi(layer_str); + return layer; +} + +std::pair GgmlOvDecoder::compute_llm_params(ggml_cgraph * cgraph, bool is_static) { + ModelParams model_params; + ComputeParams compute_params; + for (int i = 0; i < cgraph->n_nodes; i++) { + auto * node = cgraph->nodes[i]; + std::string name = std::string(node->name); + if (node->op == GGML_OP_FLASH_ATTN_EXT) { + model_params.n_heads = node->src[0]->ne[2]; + model_params.n_heads_kv = node->src[1]->ne[2]; + model_params.head_size = node->src[0]->ne[0]; + compute_params.input_len = node->src[0]->ne[1]; + + auto * cache_k_perm = node->src[1]; + if (cache_k_perm->op == GGML_OP_CPY) { + cache_k_perm = cache_k_perm->src[0]; + } + assert(cache_k_perm->op == GGML_OP_PERMUTE); + auto * cache_k_view = cache_k_perm->src[0]; + assert(cache_k_view->op == GGML_OP_VIEW); + + auto * cache_k = cache_k_view->src[0]; + int layer = extract_layer_from_name(cache_k->name); + auto * mask = node->src[3]; + std::string mask_name(mask->name); + + model_params.kv_buffer_ctx_id = ggml_backend_openvino_buffer_get_ctx_id(cache_k->buffer); + if (mask_name.find("swa") != std::string::npos) { + model_params.swa_layers.push_back(layer); + model_params.ctx_per_seq_swa = cache_k->ne[1]; + } else { + model_params.ctx_per_seq = cache_k->ne[1]; + model_params.n_seq = cache_k->ne[2]; + } + + compute_params.n_seq_active = mask->ne[3]; + auto seq_size = cache_k->ne[0] * cache_k->ne[1] * ggml_type_size(cache_k->type); + size_t offset; + memcpy(&offset, cache_k_view->op_params, sizeof(size_t)); + compute_params.seq_active_start = offset / seq_size; + compute_params.token_len_per_seq = node->ne[2]; + + if (mask_name.find("swa") != std::string::npos) { + compute_params.attention_size_swa = mask->ne[0]; + } else { + compute_params.attention_size = mask->ne[0]; + } + if (is_static) { + compute_params.attention_size = model_params.ctx_per_seq; + compute_params.attention_size_swa = model_params.ctx_per_seq_swa; + compute_params.token_len_per_seq = 1; + } + break; + } + if (node->op == GGML_OP_ROPE) { + memcpy(model_params.rope_params, node->op_params, sizeof(int32_t) * 15); + } + } + auto * output_tensor = cgraph->nodes[cgraph->n_nodes - 1]; + compute_params.output_len = output_tensor->ne[1]; + // for NPU, output_len is always 1 except for llama-perplexity + if (is_static && compute_params.output_len == 0) { + compute_params.output_len = 1; + } + model_params.ctx = model_params.ctx_per_seq * model_params.n_seq; + model_params.ctx_swa = model_params.ctx_per_seq_swa * model_params.n_seq; + return {model_params, compute_params}; +} + +void GgmlOvDecoder::validate_cgraph() const { + if (m_model_params.n_seq > 1 && m_is_static == true) { + throw std::runtime_error("n_seq > 1 is not supported on NPU. Try setting -np 1."); + } +} + +ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const { + if (m_naive) { + return input!= nullptr ? ov::PartialShape{get_shape(input)} : ov::PartialShape{get_shape(op)}; + } + auto name = std::string(input->name); + ov::PartialShape input_shape; + + if (is_inp_tok(input, op) || is_inp_pos(input, op)) { + // tokens or positions + int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1) : -1; + input_shape = ov::PartialShape{1, 1, 1, len}; + + } else if (is_output_idx(input, op)) { + // output index + input_shape = ov::PartialShape{1, 1, 1, m_is_static ? m_compute_params.output_len : -1}; + + } else if (is_inp_mask(input, op)) { + // mask + if (m_is_static) { + input_shape = ov::PartialShape{1, 1, m_is_prefill ? m_prefill_chunk_size : 1, m_model_params.ctx}; + } else if (m_is_stateful) { + input_shape = ov::PartialShape{1, 1, -1, -1}; + } else { + input_shape = ov::PartialShape{-1, 1, -1, -1}; + } + + } else if (is_kvcache(input, op)) { + // kvcache + input_shape = ov::PartialShape{get_shape(input)}; + if (!m_is_static) { + // do not fix ctx size to make llama-bench work across test params + input_shape[2] = -1; + } + if (is_stateful()) { + // Convert stateless KV cache layout [1, 1, seq, n_heads_kv * head_size] + // to stateful layout [1, seq, n_heads_kv, head_size]. + assert(input_shape.size() == 4 && input_shape[0] == 1 && input_shape[1] == 1 && + input_shape[2].is_dynamic() && + input_shape[3] == (m_model_params.n_heads_kv * m_model_params.head_size)); + input_shape = {input_shape[0], ov::Dimension::dynamic(), m_model_params.n_heads_kv, + m_model_params.head_size}; + } + + } else if (is_kv_idx(input, op)) { + // kv update index + int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1) : -1; + input_shape = ov::PartialShape{1, 1, 1, len}; + + } else { + input_shape = ov::PartialShape{get_shape(input)}; + } + return input_shape; +} + +void GgmlOvDecoder::add_extra_inputs() { + // Extra inputs: + // 1. `attention_size`, used in FLASH_ATTN where the shape of the matmul's are 256 aligned, + // see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding. + // 2. `n_seq_active` and `seq_active_start`, used in FLASH_ATTN_EXT to indicate the active sequences in the batch + + auto create_1d_input = [this](const std::string & name, int64_t value) { + if (m_is_static) { + auto constant = + std::make_shared(ov::element::i64, ov::Shape{1}, std::vector{value}); + constant->set_friendly_name(name); + m_model_extra_inputs[name] = constant; + } else { + auto param_node = std::make_shared(ov::element::i64, ov::Shape{1}); + param_node->set_friendly_name(name); + param_node->output(0).get_tensor().set_names({name}); + m_model_extra_inputs[name] = param_node; + + auto tensor = std::make_shared(ov::element::i64, ov::Shape{1}); + *tensor->data() = value; + m_model_extra_input_values[name] = tensor; + } + }; + + create_1d_input("attention_size", m_compute_params.attention_size); + if (m_compute_params.attention_size_swa != -1) { + create_1d_input("attention_size_swa", m_compute_params.attention_size_swa); + } + create_1d_input("n_seq_active", m_compute_params.n_seq_active); + create_1d_input("seq_active_start", m_compute_params.seq_active_start); + create_1d_input("seq_active_end", m_compute_params.seq_active_start + m_compute_params.n_seq_active); + create_1d_input("token_len_per_seq", m_compute_params.token_len_per_seq); + // create_1d_input("token_len", m_token_len_per_seq * m_n_seq_active); +} + +bool GgmlOvDecoder::node_is_used_as_src(const int node_idx) { + ggml_tensor * node = m_cgraph->nodes[node_idx]; + for (int i = node_idx; i < m_cgraph->n_nodes; i++) { + ggml_tensor * other_node = m_cgraph->nodes[i]; + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (other_node->src[j] == node) { + return true; + } + } + } + return false; +} + +void GgmlOvDecoder::compute_model_inputs() { + m_model_inputs.clear(); + m_inputs.clear(); + for (int i = 0; i < m_cgraph->n_nodes; i++) { + ggml_tensor * node = m_cgraph->nodes[i]; + // the node op is NONE means this node maybe as input of later nodes, we should add it to model inputs for this node. + if (node->op == GGML_OP_NONE && node_is_used_as_src(i)) { + std::string node_name(node->name); + if (m_model_weights.find(node_name) == m_model_weights.end()) { + m_inputs[node_name] = node; + auto param_node = + std::make_shared(get_ov_type(node), get_graph_input_shape(node, nullptr)); + param_node->set_friendly_name(node_name); + param_node->output(0).get_tensor().set_names({node_name}); + m_model_inputs[node_name] = param_node; + } + continue; + } + for (int i = 0; i < GGML_MAX_SRC; i++) { + auto * src = node->src[i]; + if (src == nullptr) { + continue; + } + std::string src_name = std::string(src->name); + if (src->flags & GGML_TENSOR_FLAG_INPUT) { + src_name = get_graph_input_ov_name(src, node); + } + if (m_model_weights.find(src_name) != m_model_weights.end()) { + continue; + } + + bool is_intermediate_node = false; + for (const auto & node_info : m_node_info_list) { + if (node_info.node == src) { + is_intermediate_node = true; + break; + } + } + if (is_intermediate_node) { + continue; + } + if (m_model_inputs.find(src_name) != m_model_inputs.end()) { + continue; + } + + m_inputs[src_name] = src; + + ggml_backend_buffer * buffer = src->buffer; + // GGML_BACKEND_BUFFER_USAGE_ANY are kv caches + if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY) { + if (auto it = std::find(m_model_params.kv_names.begin(), m_model_params.kv_names.end(), src_name); + it == m_model_params.kv_names.end()) { + m_model_params.kv_names.push_back(src_name); + } + } + ov::PartialShape param_shape = get_graph_input_shape(node, src); + auto param_node = std::make_shared(get_ov_type(src), param_shape); + param_node->set_friendly_name(src_name); + param_node->output(0).get_tensor().set_names({src_name}); + m_model_inputs[src_name] = param_node; + } + } +} + +void GgmlOvDecoder::compute_model_outputs() { + m_model_outputs.clear(); + m_model_output_names.clear(); + for (int node_n = 0; node_n < m_cgraph->n_nodes; node_n++) { + auto * cur_node = m_cgraph->nodes[node_n]; + // if the node op is NONE means this node is not used at all, we can skip it directly without adding to model outputs. + if (cur_node->op == GGML_OP_NONE) { + continue; + } + auto cur_node_use_count = m_cgraph->use_counts[ggml_hash_find(&m_cgraph->visited_hash_set, cur_node)]; + if (cur_node_use_count == 0) { + // The output of SET_ROWS is the view_src tensor, which is updated in place. We should use the view_src name as the output name to make sure it can be correctly matched with the later ops that use the view_src. + if (cur_node != nullptr && cur_node->op == GGML_OP_SET_ROWS) { + cur_node = cur_node->view_src; + } + } else { + int input_use_count = 0; + for (int i = 0; i < m_cgraph->n_nodes; i++) { + ggml_tensor * node = m_cgraph->nodes[i]; + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] != NULL && node->src[j] == cur_node) { + input_use_count++; + } + } + } + if (input_use_count == cur_node_use_count) { + cur_node = nullptr; + } + } + if (cur_node != nullptr) { + std::string node_output_name(cur_node->name); + m_model_outputs[node_output_name] = cur_node; + m_model_output_names.push_back(node_output_name); + } + } +} + +const ggml_tensor * GgmlOvDecoder::get_tensor_used_op(const ggml_tensor * tensor) const { + if (tensor == nullptr) { + return nullptr; + } + for (int i = 0; i < m_cgraph->n_nodes; i++) { + const auto * node = m_cgraph->nodes[i]; + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] == tensor) { + return node; + } + } + } + return nullptr; +} + +const ggml_tensor * GgmlOvDecoder::get_tensor_from_name(const std::string & name) const { + for (int i = 0; i < m_cgraph->n_nodes; i++) { + const auto * node = m_cgraph->nodes[i]; + for (int j = 0; j < GGML_MAX_SRC; j++) { + const auto * src = node->src[j]; + if (src == nullptr) { + break; + } + if (std::string(src->name) == name) { + return src; + } + } + } + return nullptr; +} + +std::map GgmlOvDecoder::get_kv_param_res_names() const { + std::map kv_param_res_names; + for (const auto & name : m_model_params.kv_names) { + kv_param_res_names[name] = name; + } + return kv_param_res_names; +} + +std::map> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph, bool naive) { + static std::mutex weights_mutex; + std::lock_guard lock(weights_mutex); + + std::map> model_weights; + auto * nodes = cgraph->nodes; + auto n_nodes = cgraph->n_nodes; + for (int node_i = 0; node_i < n_nodes; node_i++) { + auto * node = nodes[node_i]; + for (int i = 0; i < GGML_MAX_SRC; i++) { + auto * src = node->src[i]; + if (src == nullptr) { + continue; + } + + std::string src_name(src->name); + if (is_rope_freqs_weight(src, node)) { + src_name = "rope_freqs.weight"; + } + if (!src->view_src) { + ggml_backend_buffer * buffer = src->buffer; + if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS || ggml_is_quantized(src->type)) { + if (model_weights.find(src_name) == model_weights.end()) { + auto weight_node = create_weight_node(src, naive); + weight_node->set_friendly_name(src_name); + model_weights[src_name] = weight_node; + } + } + } + } + } + return model_weights; +} + +std::shared_ptr GgmlOvDecoder::create_weight_node(ggml_tensor * tensor, bool naive) { + const bool is_ov_buffer = ggml_backend_buffer_is_openvino(tensor->buffer); + + // Check if we have a pre-built constant from the OpenVINO backend buffer + // This is set during ggml_backend_openvino_buffer_set_tensor + if (tensor->extra) { + OPENVINO_ASSERT(is_ov_buffer, "Unsupported weight tensor: " + std::string(tensor->name) + + " Possibly this is a cpu backend repacked quantized weights"); + // Cast to our extra base type and check the type + auto * extra_base = static_cast(tensor->extra); + + if (extra_base->type == ggml_openvino_extra_base::Type::WEIGHT) { + // F16/F32/BF16 weight with shared-memory constant + auto * weight_extra = static_cast(tensor->extra); + if (weight_extra->weight_node) { + // GGML_LOG_DEBUG("%s: using pre-built weight node for %s\n", __func__, tensor->name); + return weight_extra->weight_node; + } + } else if (extra_base->type == ggml_openvino_extra_base::Type::QUANTIZED_WEIGHT) { + // Quantized weight with pre-extracted data + auto * quant_extra = static_cast(tensor->extra); + if (quant_extra->weight_node) { + // GGML_LOG_DEBUG("%s: using pre-extracted quantized weight node for %s\n", __func__, tensor->name); + return quant_extra->weight_node; + } + } + } + + // There are three cases where we need to create a new weight node: + // 1. weights are in openvino_host_buffer. Weight loading to host buffer will not trigger backend_buffer_set_tensor + // 2. weights are in cpu/cpu_mapped buffer. On token_embd.weight goes to case 1 or 2, depending on whether mmap or direct_io is used + // 3. test-backend-ops. buffers in test-backend-ops does not set USAGE_WEIGHT so backend_buffer_set_tensor will not create weight node + + // GGML_LOG_DEBUG("%s: creating new weight node for %s\n", __func__, tensor->name); + static const std::set weight_types = {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, + GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, + GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K}; + if (weight_types.find(tensor->type) == weight_types.end()) { + throw std::runtime_error("Unexpected weight tensor type: " + std::string(tensor->name) + " with type " + + ggml_type_name(tensor->type)); + } + + OvWeight ov_weight; + if (ggml_is_quantized(tensor->type)) { + auto use_bias = naive; + if (is_ov_buffer) { + // For quantized weights, copy raw data to a temp buffer first because + // process_weight_tensor reads from data and writes extracted results + // (weights/scales/zp) to output_base_ptr — they would overlap if both + // point to tensor->data. + size_t raw_size = ggml_nbytes(tensor); + std::vector tmp(raw_size); + memcpy(tmp.data(), tensor->data, raw_size); + ov_weight = process_weight_tensor(tensor, tmp.data(), tensor->data, use_bias); + } else { + ov_weight = process_weight_tensor(tensor, tensor->data, nullptr, use_bias); + } + } else { + // For non-quantized weights (F16/F32/BF16), data is already in tensor->data. + // process_weight_tensor will create an ov::Tensor wrapping tensor->data directly. + ov_weight = process_weight_tensor(tensor, tensor->data, tensor->data); + } + + ov_weight.weight_node->set_friendly_name(tensor->name); + if (!is_ov_buffer) { + return ov_weight.weight_node; + } + + ggml_openvino_extra_base * extra; + if (ov_weight.is_quantized()) { + extra = new ggml_openvino_quantized_weight_extra(std::move(ov_weight.weights), std::move(ov_weight.scales), + std::move(ov_weight.zp), ov_weight.weight_node); + } else { + extra = new ggml_openvino_weight_extra(std::move(ov_weight.weights), ov_weight.weight_node); + } + ggml_openvino_buffer_register_extra(tensor, extra); + + return ov_weight.weight_node; +} + +void GgmlOvDecoder::dump_cgraph(const ggml_cgraph * cgraph, std::string & filename) { + std::ofstream file(filename); + if (!file.is_open()) { + std::cerr << "Failed to open file" << std::endl; + return; + } + + file << "=== GRAPH ===\n"; + + // clang-format off + file << "n_nodes = " << cgraph->n_nodes << "\n"; + file << " " << std::setw(3) << "nodes" + << std::setw(15) << "shape" + << std::setw(20) << "op" + << std::setw(20) << "name" + << std::setw(3) << " " + << std::setw(62) << "stride" + << std::setw(20) << "buffer_type" + << "\n"; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + // Get buffer type name + const char * buf_name = "none"; + ggml_backend_buffer_t buf = node->view_src ? node->view_src->buffer : node->buffer; + if (buf) { + buf_name = ggml_backend_buffer_name(buf); + } + + file << " - " << std::setw(3) << i << ": [ " + << std::setw(5) << node->ne[0] << ", " + << std::setw(5) << node->ne[1] << ", " + << std::setw(5) << node->ne[2] << ", " + << std::setw(5) << node->ne[3] << "] " + << std::left << std::setw(20) << ggml_op_name(node->op) << std::right << " " + << std::left << std::setw(45) << node->name << std::right + << std::setw(2) << "[ " + << std::setw(0) << node->nb[0] << ", " + << std::setw(5) << node->nb[1] << ", " + << std::setw(5) << node->nb[2] << ", " + << std::setw(5) << node->nb[3] << "] " + << std::right << std::setw(15) << buf_name << std::right + << "\n"; + + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (auto* src = node->src[i]) { + // Get buffer type name for source + const char * src_buf_name = "none"; + ggml_backend_buffer_t src_buf = src->view_src ? src->view_src->buffer : src->buffer; + if (src_buf) { + src_buf_name = ggml_backend_buffer_name(src_buf); + } + + file << std::setw(10) << " [ " + << std::setw(5) << src->ne[0] << ", " + << std::setw(5) << src->ne[1] << ", " + << std::setw(5) << src->ne[2] << ", " + << std::setw(5) << src->ne[3] << "] " + << std::setw(12) + << i << ": " << std::left << std::setw(12) << ggml_op_name(src->op) << std::right; + file << std::left << std::setw(30) << src->name << std::right + << std::setw(16) << "[ " + << std::setw(0) << src->nb[0] << ", " + << std::setw(5) << src->nb[1] << ", " + << std::setw(5) << src->nb[2] << ", " + << std::setw(5) << src->nb[3] << "] " + << std::right << std::setw(15) << src_buf_name << std::right + << "\n"; + } + } + } + + file << "n_leafs = " << cgraph->n_leafs << "\n"; + for (int i = 0; i < cgraph->n_leafs; i++) { + ggml_tensor * node = cgraph->leafs[i]; + + // Get buffer type name for leaf + const char * leaf_buf_name = "none"; + ggml_backend_buffer_t leaf_buf = node->view_src ? node->view_src->buffer : node->buffer; + if (leaf_buf) { + leaf_buf_name = ggml_backend_buffer_name(leaf_buf); + } + + file << " - " << std::setw(3) << i << ": [ " + << std::setw(5) << node->ne[0] << ", " + << std::setw(5) << node->ne[1] << "] " + << std::setw(8) << ggml_op_name(node->op) << " " + << std::setw(16) << ggml_get_name(node) + << std::setw(20) << leaf_buf_name << "\n"; + } + // clang-format on + file << "========================================\n"; + + file.close(); +} + +void print_tensor_address_map(const ggml_cgraph * cgraph) { + std::map> address_map; + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { + auto * node = cgraph->nodes[node_n]; + if (node->data) { + auto it = address_map.find(node->data); + if (it == address_map.end()) { + address_map[node->data] = std::vector(); + } + address_map[node->data].push_back(node->name); + } + } + for (const auto & pair : address_map) { + std::cout << "Address: " << pair.first << std::endl; + for (const auto & name : pair.second) { + std::cout << name << " ; "; + } + std::cout << std::endl << std::endl; + } +} + +ov::Shape GgmlOvDecoder::get_shape(const ggml_tensor * tensor) { + std::vector shape; + for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) { + shape.push_back(static_cast(tensor->ne[i])); + } + return shape; +} + +std::vector GgmlOvDecoder::get_stride(const ggml_tensor * tensor) { + std::vector stride; + for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) { + stride.push_back(static_cast(tensor->nb[i])); + } + return stride; +} + +ov::element::Type GgmlOvDecoder::get_ov_type(const ggml_tensor * tensor) { + switch (tensor->type) { + case GGML_TYPE_F64: + return ov::element::f64; + case GGML_TYPE_F32: + return ov::element::f32; + case GGML_TYPE_F16: + return ov::element::f16; + case GGML_TYPE_BF16: + return ov::element::bf16; + case GGML_TYPE_I8: + return ov::element::i8; + case GGML_TYPE_I16: + return ov::element::i16; + case GGML_TYPE_I32: + return ov::element::i32; + case GGML_TYPE_I64: + return ov::element::i64; + default: + return ov::element::dynamic; + } +} + +ov::PartialShape GgmlOvDecoder::get_input_shape(int node_idx, const std::string & name) const { + return ov::PartialShape(get_shape(m_node_info_list[node_idx].node_inputs.at(name))); +} + +std::vector GgmlOvDecoder::get_input_stride(int node_idx, const std::string & name) const { + return get_stride(m_node_info_list[node_idx].node_inputs.at(name)); +} + +ov::element::Type GgmlOvDecoder::get_input_type(int node_idx, const std::string & name) const { + return get_ov_type(m_node_info_list[node_idx].node_inputs.at(name)); +} + +size_t GgmlOvDecoder::get_input_size() const { + return m_model_inputs.size(); +} + +size_t GgmlOvDecoder::get_input_size(int node_idx) const { + return m_node_info_list[node_idx].node_inputs_names.size(); +} + +std::vector GgmlOvDecoder::get_input_names(int node_idx) const { + return m_node_info_list[node_idx].node_inputs_names; +} + +ov::PartialShape GgmlOvDecoder::get_output_shape(int node_idx) const { + auto * ggml_tensor = m_node_info_list[node_idx].node_output; + return ov::PartialShape(get_shape(ggml_tensor)); +} + +ov::element::Type GgmlOvDecoder::get_output_type(const int node_idx) const { + return get_ov_type(m_node_info_list[node_idx].node); +} + +std::vector GgmlOvDecoder::get_output_names(int node_idx) const { + return {m_node_info_list[node_idx].node_output_name}; +} + +const std::string & GgmlOvDecoder::get_op_name() const { + static const std::string unknown_name = "UNKNOWN_OP_NAME"; + return unknown_name; +} + +const std::string & GgmlOvDecoder::get_op_name(int node_idx) const { + return m_node_info_list[node_idx].node_name; +} + +int32_t * GgmlOvDecoder::get_input_op_params(int node_idx, const std::string & name) const { + return m_node_info_list[node_idx].node_inputs.at(name)->op_params; +} + +int32_t * GgmlOvDecoder::get_output_op_params(int node_idx) const { + return m_node_info_list[node_idx].node->op_params; +} + +void GgmlOvDecoder::visit_subgraph(std::function, int node_idx)> node_visitor) const { + for (int node_idx = 0; node_idx < m_cgraph->n_nodes; node_idx++) { + if (m_cgraph->nodes[node_idx]->op == GGML_OP_NONE) { + continue; + } + node_visitor(std::make_shared(*this), node_idx); + } +} + +std::string GgmlOvDecoder::compute_op_type(const ggml_tensor * node) { + static const std::map ops = { + {GGML_OP_NONE, "GGML_OP_NONE" }, + {GGML_OP_ACC, "GGML_OP_ACC" }, + {GGML_OP_ADD, "GGML_OP_ADD" }, + {GGML_OP_ADD1, "GGML_OP_ADD1" }, + {GGML_OP_CONT, "GGML_OP_CONT" }, + {GGML_OP_DIV, "GGML_OP_DIV" }, + {GGML_OP_DUP, "GGML_OP_DUP" }, + {GGML_OP_GET_ROWS, "GGML_OP_GET_ROWS" }, + {GGML_OP_MUL, "GGML_OP_MUL" }, + {GGML_OP_MUL_MAT, "GGML_OP_MUL_MAT" }, + {GGML_OP_PERMUTE, "GGML_OP_PERMUTE" }, + {GGML_OP_RESHAPE, "GGML_OP_RESHAPE" }, + {GGML_OP_RMS_NORM, "GGML_OP_RMS_NORM" }, + {GGML_OP_ROPE, "GGML_OP_ROPE" }, + {GGML_OP_SCALE, "GGML_OP_SCALE" }, + {GGML_OP_SOFT_MAX, "GGML_OP_SOFT_MAX" }, + {GGML_OP_SUB, "GGML_OP_SUB" }, + {GGML_OP_TRANSPOSE, "GGML_OP_TRANSPOSE" }, + {GGML_OP_VIEW, "GGML_OP_VIEW" }, + {GGML_OP_SET_ROWS, "GGML_OP_SET_ROWS" }, + {GGML_OP_CPY, "GGML_OP_CPY" }, + {GGML_OP_FLASH_ATTN_EXT, "GGML_OP_FLASH_ATTN_EXT"}, + }; + static const std::map unary_ops = { + {GGML_UNARY_OP_ABS, "GGML_UNARY_OP_ABS" }, + {GGML_UNARY_OP_SGN, "GGML_UNARY_OP_SGN" }, + {GGML_UNARY_OP_NEG, "GGML_UNARY_OP_NEG" }, + {GGML_UNARY_OP_STEP, "GGML_UNARY_OP_STEP" }, + {GGML_UNARY_OP_TANH, "GGML_UNARY_OP_TANH" }, + {GGML_UNARY_OP_ELU, "GGML_UNARY_OP_ELU" }, + {GGML_UNARY_OP_RELU, "GGML_UNARY_OP_RELU" }, + {GGML_UNARY_OP_SIGMOID, "GGML_UNARY_OP_SIGMOID" }, + {GGML_UNARY_OP_GELU, "GGML_UNARY_OP_GELU" }, + {GGML_UNARY_OP_GELU_QUICK, "GGML_UNARY_OP_GELU_QUICK" }, + {GGML_UNARY_OP_SILU, "GGML_UNARY_OP_SILU" }, + {GGML_UNARY_OP_HARDSWISH, "GGML_UNARY_OP_HARDSWISH" }, + {GGML_UNARY_OP_HARDSIGMOID, "GGML_UNARY_OP_HARDSIGMOID"}, + {GGML_UNARY_OP_EXP, "GGML_UNARY_OP_EXP" }, + {GGML_UNARY_OP_COUNT, "GGML_UNARY_OP_COUNT" } + }; + static const std::map glu_ops = { + {GGML_GLU_OP_SWIGLU, "GGML_GLU_OP_SWIGLU"}, + {GGML_GLU_OP_GEGLU, "GGML_GLU_OP_GEGLU" }, + {GGML_GLU_OP_REGLU, "GGML_GLU_OP_REGLU" } + }; + + switch (node->op) { + case GGML_OP_UNARY: + return unary_ops.at(ggml_get_unary_op(node)); + case GGML_OP_GLU: + return glu_ops.at(ggml_get_glu_op(node)); + default: + return ops.at(node->op); + } + static const std::string unknown_op = "UNKNOWN_GGML_OP"; + return unknown_op; +} + +const std::string & GgmlOvDecoder::get_op_type(int node_idx) const { + return m_node_info_list[node_idx].node_op_type; +} + +const std::string & GgmlOvDecoder::get_op_type() const { + static const std::string unknown_op = "UNKNOWN_GGML_OP"; + return unknown_op; +} diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h new file mode 100644 index 00000000000..3ae25ddda32 --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -0,0 +1,294 @@ +#pragma once + +#include "ggml-quants.h" +#include "ggml.h" +#include "openvino/decoder.h" + +#include +#include +#include +#include +#include +#include +#include + +struct ModelParams { + int ctx = -1; + int ctx_swa = -1; + int ctx_per_seq = -1; + int ctx_per_seq_swa = -1; + int n_seq = 1; + int n_heads = -1; + int n_heads_kv = -1; + int head_size = -1; + int32_t rope_params[15]; + std::vector swa_layers; + + std::vector kv_names; + size_t kv_buffer_ctx_id = 0; + + bool same_rope_params(const ModelParams & other) const { + return memcmp(rope_params, other.rope_params, sizeof(int32_t) * 15) == 0; + } + + bool can_reuse_dynamically(const ModelParams & other) const { return same_rope_params(other); } + + bool can_reuse_statically(const ModelParams & other) const { return same_rope_params(other) && ctx == other.ctx; } + + bool kv_buffer_changed(const ModelParams & other) const { return kv_buffer_ctx_id != other.kv_buffer_ctx_id; } +}; + +struct ComputeParams { + int n_seq_active = 1; + int seq_active_start = 0; + int attention_size = -1; + int attention_size_swa = -1; + int input_len = -1; + int token_len_per_seq = -1; + int past_kv_len = -1; + int output_len = 1; +}; + +class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { +public: + struct NodeInfo { + ggml_tensor * node; + std::string node_name; + std::string node_op_type; + std::map node_inputs; + std::vector node_inputs_names; + ggml_tensor * node_output; + std::string node_output_name; + int node_op_case = 0; + void * data_addr; + }; + // Graph decoder + GgmlOvDecoder(ggml_cgraph * cgraph, + ModelParams & model_params, + ComputeParams & compute_params, + std::map> & model_weights, + bool is_static, + bool is_stateful = false, + bool is_prefill = false, + int prefill_chunk_size = 256); + + // Naive graph decoder + GgmlOvDecoder(ggml_cgraph * cgraph, std::map> & model_weights); + + virtual ov::Any get_attribute(const std::string & name) const override { + return nullptr; + GGML_UNUSED(name); + } + + virtual ov::PartialShape get_input_shape(int node_idx, const std::string & name) const override; + + virtual std::vector get_input_stride(int node_idx, const std::string & name) const override; + + virtual ov::element::Type get_input_type(int node_idx, const std::string & name) const override; + + virtual size_t get_input_size() const override; + + virtual size_t get_input_size(int node_idx) const override; + + virtual void get_input_node(size_t input_port_idx, + std::string & producer_name, + std::string & producer_output_port_name, + size_t & producer_output_port_index) const override { + GGML_UNUSED(input_port_idx); + GGML_UNUSED(producer_name); + GGML_UNUSED(producer_output_port_name); + GGML_UNUSED(producer_output_port_index); + } + + virtual std::vector get_input_names(int node_idx) const override; + + virtual ov::PartialShape get_output_shape(int node_idx) const override; + + virtual ov::element::Type get_output_type(int node_idx) const override; + + virtual int32_t * get_input_op_params(int node_idx, const std::string & name) const override; + + virtual int32_t * get_output_op_params(int node_idx) const override; + + virtual std::vector get_output_names(int node_idx) const override; + + virtual const std::string & get_op_type() const override; + + virtual const std::string & get_op_type(int node_idx) const override; + + virtual const std::string & get_op_name() const override; + + virtual const std::string & get_op_name(int node_idx) const override; + + virtual void visit_subgraph(std::function, int node_idx)> node_visitor) const override; + + ggml_tensor * get_input_ggml_tensor(const std::string & name) const { return m_inputs.at(name); } + + virtual int get_op_case(int node_idx) const override { return m_node_info_list[node_idx].node_op_case; } + + virtual const std::map> & get_model_inputs() const override { + return m_model_inputs; + } + + virtual const std::map> & get_model_extra_inputs() const override { + return m_model_extra_inputs; + } + + virtual const std::map> & get_model_extra_input_values() const { + return m_model_extra_input_values; + } + + virtual const std::map> & get_model_weights() const override { + return m_model_weights; + } + + virtual std::vector get_model_output_names() const override { + return m_model_output_names; + } + + const std::map & get_model_outputs() const { return m_model_outputs; } + + virtual int get_ctx_size() const { return m_model_params.ctx; } + + virtual int get_ctx_swa_size() const { return m_model_params.ctx_swa; } + + virtual int get_ctx_per_seq() const { return m_model_params.ctx_per_seq; } + + virtual int get_ctx_per_seq_swa() const { return m_model_params.ctx_per_seq_swa; } + + virtual int get_n_seq() const { return m_model_params.n_seq; } + + virtual int is_swa_layer(int layer) const override { + return std::find(m_model_params.swa_layers.begin(), m_model_params.swa_layers.end(), layer) != + m_model_params.swa_layers.end(); + } + + int get_past_kv_len() const { return m_compute_params.past_kv_len; } + + int get_input_len() const { return m_compute_params.input_len; } + + virtual int32_t * get_rope_params() const override { return const_cast(m_model_params.rope_params); } + + virtual std::map get_kv_param_res_names() const override; + + virtual bool is_static() const override { return m_is_static; } + + virtual bool is_stateful() const override { return m_is_stateful; } + + ov::PartialShape get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const; + + static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename); + + static std::shared_ptr create_weight_node(ggml_tensor * tensor, bool naive = false); + + static std::map> create_weight_nodes(ggml_cgraph * cgraph, + bool naive = false); + + const ggml_tensor * get_tensor_used_op(const ggml_tensor * tensor) const; + + const ggml_tensor * get_tensor_from_name(const std::string & name) const; + + void clear_model_weights() { m_model_weights.clear(); } + + static std::pair compute_llm_params(ggml_cgraph * cgraph, bool is_static); + + ModelParams get_model_params() const { return m_model_params; } + + ComputeParams get_compute_params() const { return m_compute_params; } + + void set_model_params(const ModelParams & model_params) { m_model_params = model_params; } + + void set_compute_params(const ComputeParams & compute_params) { m_compute_params = compute_params; } + + bool m_is_static = false; + bool m_is_stateful = false; + bool m_is_prefill = false; + bool m_naive = false; + int m_prefill_chunk_size = 0; + + static ov::Shape get_shape(const ggml_tensor * tensor); + static std::vector get_stride(const ggml_tensor * tensor); + static ov::element::Type get_ov_type(const ggml_tensor * tensor); + static std::string compute_op_type(const ggml_tensor * node); + void add_extra_inputs(); + + void update_io(ggml_cgraph * cgraph); + + inline static bool is_inp_tok(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_GET_ROWS && tensor == op->src[1] && op->src[0]->op == GGML_OP_NONE; + } + + inline static bool is_inp_pos(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_ROPE && tensor == op->src[1]; + } + + inline static bool is_inp_emb(const ggml_tensor * tensor, const ggml_tensor * op) { + return tensor->op == GGML_OP_GET_ROWS && op->op == GGML_OP_RMS_NORM; + } + + inline static bool is_inp_mask(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_CPY || (op->op == GGML_OP_FLASH_ATTN_EXT && tensor == op->src[3]); + } + + inline static bool is_rope_freqs_weight(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_ROPE && tensor == op->src[2]; + } + + inline static bool is_kvcache(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_SET_ROWS && op->src[2] == tensor; + } + + inline static bool is_kv_idx(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_SET_ROWS && op->src[1] == tensor; + } + + inline static bool is_output_idx(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_GET_ROWS && tensor == op->src[1] && op->src[0]->op != GGML_OP_NONE; + } + + static std::string get_graph_input_ov_name(const ggml_tensor * tensor, const ggml_tensor * op) { + if (is_inp_tok(tensor, op)) { + return "inp_tokens"; + } + if (is_inp_pos(tensor, op)) { + return "inp_pos"; + } + if (is_inp_emb(tensor, op)) { + return "embd"; + } + if (is_output_idx(tensor, op)) { + return "inp_out_ids"; + } + if (is_inp_mask(tensor, op)) { + return std::string(tensor->name).find("swa") == std::string::npos ? "self_kq_mask" : "self_kq_mask_swa"; + } + return tensor->name; + } + +private: + void set_input_output(); + int compute_op_case(const ggml_tensor * node) const; + bool node_is_used_as_src(const int node_idx); + void compute_model_inputs(); + void compute_model_outputs(); + + void validate_cgraph() const; + + ggml_cgraph * m_cgraph = nullptr; + std::map m_inputs; + + std::map> m_model_inputs; + std::map> m_model_extra_inputs; + std::map> m_model_extra_input_values; + std::map> m_model_weights; + std::map m_model_outputs; + std::vector m_model_output_names; + std::vector m_node_info_list; + + ModelParams m_model_params; + ComputeParams m_compute_params; +}; + +void print_tensor_address_map(const ggml_cgraph * cgraph); + +int extract_layer_from_name(const std::string & name); diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp new file mode 100644 index 00000000000..cc3cb4583cd --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -0,0 +1,373 @@ +#include "ggml-openvino-extra.h" + +#include "ggml-impl.h" +#include "ggml.h" + +#include +#include +#include +#include + +ov::Core & ov_singleton_core() { + static ov::Core core; + return core; +} + +// ===================================================== +// Device Configuration Implementations +// ===================================================== + +void ggml_openvino_device_config::init() { + if (initialized) { + return; + } + device_name = getenv("GGML_OPENVINO_DEVICE") ? getenv("GGML_OPENVINO_DEVICE") : "CPU"; + auto available_devices = ov_singleton_core().get_available_devices(); + if (std::find(available_devices.begin(), available_devices.end(), device_name) == available_devices.end()) { + GGML_LOG_WARN("GGML OpenVINO Backend: device %s is not available, fallback to CPU\n", device_name.c_str()); + device_name = "CPU"; + } + is_npu = (device_name == "NPU"); + + auto * cache_dir = getenv("GGML_OPENVINO_CACHE_DIR"); + if (device_name == "NPU") { + compile_config = { + {"NPU_COMPILER_DYNAMIC_QUANTIZATION", "YES" }, + {"NPU_USE_NPUW", "YES" }, + {"NPUW_DEVICES", "NPU" }, + {"NPUW_FOLD", "YES" }, + {"NPUW_WEIGHTS_BANK", "shared"}, + {"NPUW_FUNCALL_FOR_ALL", "YES" }, + {"NPUW_FUNCALL_ASYNC", "YES" }, + {"NPUW_DQ", "YES" }, + {"NPUW_DQ_FULL", "NO" }, + }; + if (cache_dir) { + compile_config["NPUW_CACHE_DIR"] = cache_dir; + } + } else if (cache_dir) { + ov_singleton_core().set_property(ov::cache_dir(cache_dir)); + } + + // Initialize remote context with queue sharing for GPU + if (device_name == "GPU") { + // Create OpenCL context and queue + cl_int err; + cl_platform_id platform; + err = clGetPlatformIDs(1, &platform, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("Failed to get OpenCL platform: %d\n", err); + return; + } + + cl_device_id cl_device; + err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &cl_device, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("Failed to get OpenCL device: %d\n", err); + return; + } + + cl_context cl_ctx = clCreateContext(nullptr, 1, &cl_device, nullptr, nullptr, &err); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("Failed to create OpenCL context: %d\n", err); + return; + } + + cl_queue = clCreateCommandQueueWithProperties(cl_ctx, cl_device, nullptr, &err); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("Failed to create OpenCL command queue: %d\n", err); + clReleaseContext(cl_ctx); + return; + } + + // Create OpenVINO remote context with queue sharing + remote_context = ov::intel_gpu::ocl::ClContext(ov_singleton_core(), cl_queue); + + // Release the context (queue keeps a reference) + clReleaseContext(cl_ctx); + } else if (device_name == "NPU") { + // remote tensor is not used for NPU yet + // remote_context = ov_singleton_core().get_default_context(device_name); + } + + initialized = true; +} + +ggml_openvino_device_config::~ggml_openvino_device_config() { + if (cl_queue != nullptr) { + clReleaseCommandQueue(cl_queue); + cl_queue = nullptr; + } +} + +// Get the global device config singleton +ggml_openvino_device_config & ggml_openvino_get_device_config() { + static ggml_openvino_device_config config; + return config; +} + +// Initialize device config (call during backend init) +void ggml_openvino_init_device_config() { + ggml_openvino_get_device_config().init(); +} + +// Get the device name +const std::string & ggml_openvino_get_device_name() { + return ggml_openvino_get_device_config().device_name; +} + +// Check if running on NPU +bool ggml_openvino_is_npu() { + return ggml_openvino_get_device_config().is_npu; +} + +// Get the remote context for the current device (returns empty optional for CPU) +std::optional ggml_openvino_get_remote_context() { + return ggml_openvino_get_device_config().remote_context; +} + +// Get the compile config for the current device +const ov::AnyMap & ggml_openvino_get_compile_config() { + return ggml_openvino_get_device_config().compile_config; +} + +// Get the OpenCL command queue for GPU operations +cl_command_queue ggml_openvino_get_cl_queue() { + return ggml_openvino_get_device_config().cl_queue; +} + +// Get the clEnqueueMemFillINTEL function pointer (lazy load) +clEnqueueMemFillINTEL_fn ggml_openvino_get_clEnqueueMemFillINTEL() { + static clEnqueueMemFillINTEL_fn fn = nullptr; + static bool loaded = false; + if (!loaded) { + loaded = true; + cl_platform_id platform; + if (clGetPlatformIDs(1, &platform, nullptr) == CL_SUCCESS) { + fn = (clEnqueueMemFillINTEL_fn) clGetExtensionFunctionAddressForPlatform(platform, "clEnqueueMemFillINTEL"); + } + } + return fn; +} + +// Get the clEnqueueMemcpyINTEL function pointer (lazy load) +clEnqueueMemcpyINTEL_fn ggml_openvino_get_clEnqueueMemcpyINTEL() { + static clEnqueueMemcpyINTEL_fn fn = nullptr; + static bool loaded = false; + if (!loaded) { + loaded = true; + cl_platform_id platform; + if (clGetPlatformIDs(1, &platform, nullptr) == CL_SUCCESS) { + fn = (clEnqueueMemcpyINTEL_fn) clGetExtensionFunctionAddressForPlatform(platform, "clEnqueueMemcpyINTEL"); + } + } + return fn; +} + +// Get requantization type for a tensor type (returns nullopt if no requant needed) +std::optional ggml_openvino_get_requant_type(const ggml_tensor * tensor, bool no_requant) { + if (no_requant) { + return std::nullopt; + } + if (strncmp(tensor->name, "token_embd.weight", 17) == 0) { + return ((ggml_openvino_is_npu() && tensor->type == GGML_TYPE_Q6_K) ? ExtraQuantType::F16 : ExtraQuantType::Q8_0_C); + } + if (strncmp(tensor->name, "output.weight", 13) == 0) { + return ExtraQuantType::Q8_0_C; + } + if (ggml_openvino_is_npu()) { + return ExtraQuantType::Q4_0_128; + } + switch (tensor->type) { + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q5_K: + return ExtraQuantType::Q8_0_C; + default: + return std::nullopt; + } +} + +// ===================================================== +// Extracted Layout Calculation +// ===================================================== + +ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor, bool use_bias) { + ggml_openvino_extracted_layout layout = {}; + layout.is_symmetric = false; + + if (!ggml_is_quantized(tensor->type)) { + return layout; + } + + // Only handle 2D weight tensors + if (tensor->ne[2] != 1 || tensor->ne[3] != 1) { + return layout; + } + + int64_t n_elements = ggml_nelements(tensor); + const size_t alignment = 64; // Good for SIMD + + // Check if requantization is needed (NPU-specific) + auto requant_type = ggml_openvino_get_requant_type(tensor, use_bias); + if (requant_type.has_value()) { + layout.is_requant = true; + layout.requant_type = requant_type; + + // Special case: requant to F16 - just store F16 weights, no scales/zp + if (requant_type.value() == ExtraQuantType::F16) { + layout.weights_size = n_elements * sizeof(uint16_t); // F16 = 2 bytes + layout.total_size = layout.weights_size; + layout.weights_offset = 0; + // No scales/zp for F16 + return layout; + } + + // Requant to different quantized format (e.g., Q4_0_128) + switch (requant_type.value()) { + case ExtraQuantType::Q4_0_128: + layout.is_u4 = true; + layout.weights_per_block = 128; + layout.is_symmetric = true; + break; + case ExtraQuantType::Q4_0_C: + layout.is_u4 = true; + layout.weights_per_block = tensor->ne[0]; + layout.is_symmetric = true; + break; + case ExtraQuantType::Q8_0_32: + layout.is_u4 = false; + layout.weights_per_block = 32; + layout.is_symmetric = true; + break; + case ExtraQuantType::Q8_0_C: + layout.is_u4 = false; + layout.weights_per_block = tensor->ne[0]; + layout.is_symmetric = true; + break; + case ExtraQuantType::Q8_1_C: + layout.is_u4 = false; + layout.weights_per_block = tensor->ne[0]; + break; + default: + layout.weights_per_block = -1; + GGML_ABORT("Code of re-quantizing to channel-wise is not updated"); + break; + } + + if (layout.is_requant) { + // Calculate sizes for requantized format + layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements; + int64_t n_blocks = n_elements / layout.weights_per_block; + layout.scales_size = n_blocks * sizeof(uint16_t); + // For symmetric quantization, we only need one zp value (not one per block) + // Zero points are stored in U4 or U8 format matching the weight type + size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks; + layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements; + + layout.weights_offset = 0; + layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; + layout.zp_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; + layout.total_size = layout.zp_offset + layout.zp_size; + layout.total_size = std::max(layout.total_size, ggml_nbytes(tensor)); + return layout; + } + } + + // Normal extraction (no requant) - determine format based on tensor type + layout.is_u4 = false; + layout.weights_per_block = 32; + layout.is_symmetric = false; + + switch (tensor->type) { + case GGML_TYPE_Q4_0: + layout.is_u4 = true; + layout.is_symmetric = true; + break; + + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + layout.is_u4 = true; + break; + + case GGML_TYPE_Q8_0: + layout.is_symmetric = true; + break; + + case GGML_TYPE_Q6_K: + layout.weights_per_block = 16; + layout.is_symmetric = true; + break; + + case GGML_TYPE_Q5_K: + break; + + default: + // Unsupported quantization type + return layout; + } + + // Calculate sizes + // Weights: U4 = n_elements/2 bytes, U8 = n_elements bytes + layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements; + + // Scales: F16 per block + int64_t n_blocks = n_elements / layout.weights_per_block; + layout.scales_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes + // Zero points: U4 or U8 matching weight type + // For symmetric quantization, we only need one zp value (not one per block) + size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks; + layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements; + + // Layout in buffer: [weights | scales | zp] with alignment + layout.weights_offset = 0; + layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; + layout.zp_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; + layout.total_size = layout.zp_offset + layout.zp_size; + layout.total_size = std::max(layout.total_size, ggml_nbytes(tensor)); + + return layout; +} + +ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor, bool is_remote) { + ov::Shape shape; + for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) { + shape.push_back(static_cast(tensor->ne[i])); + } + + ov::element::Type element_type; + switch (tensor->type) { + case GGML_TYPE_F32: + element_type = ov::element::f32; + break; + case GGML_TYPE_F16: + element_type = ov::element::f16; + break; + case GGML_TYPE_BF16: + element_type = ov::element::bf16; + break; + case GGML_TYPE_I32: + element_type = ov::element::i32; + break; + case GGML_TYPE_I64: + element_type = ov::element::i64; + break; + default: + // GGML_LOG_WARN("%s: unsupported tensor type for ov::Tensor: %s\n", __func__, ggml_type_name(tensor->type)); + return nullptr; + } + + const auto & device_name = ggml_openvino_get_device_name(); + auto remote_context = ggml_openvino_get_remote_context(); + + std::shared_ptr ov_tensor; + if (is_remote) { + GGML_ASSERT(device_name == "GPU"); + auto gpu_context = remote_context->as(); + auto usm_tensor = gpu_context.create_tensor(element_type, shape, tensor->data); + ov_tensor = std::make_shared(std::move(usm_tensor)); + } else { + ov_tensor = std::make_shared(element_type, shape, tensor->data); + } + + return new ggml_openvino_tensor_extra(ov_tensor); +} diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.h b/ggml/src/ggml-openvino/ggml-openvino-extra.h new file mode 100644 index 00000000000..cd0baf4a681 --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.h @@ -0,0 +1,182 @@ +#pragma once + +#include "ggml.h" +#include "openvino/runtime/core.hpp" + +#define CL_TARGET_OPENCL_VERSION 300 +#include + +#include +#include +#include +#include +#include +#include +#include + +// ExtraQuantType enum - defines requantization target formats +enum class ExtraQuantType { F16, Q4_0_C, Q8_1_C, Q4_0_128, Q8_0_C, Q8_0_32 }; + +ov::Core & ov_singleton_core(); + +// Get the remote context for the current device (returns empty optional for CPU) +std::optional ggml_openvino_get_remote_context(); + +// Get the compile config for the current device +const ov::AnyMap & ggml_openvino_get_compile_config(); + +// Get the OpenCL command queue for GPU operations (returns nullptr for CPU/NPU) +cl_command_queue ggml_openvino_get_cl_queue(); + +// Intel USM extension function type +typedef cl_int(CL_API_CALL * clEnqueueMemFillINTEL_fn)(cl_command_queue queue, + void * dst_ptr, + const void * pattern, + size_t pattern_size, + size_t size, + cl_uint num_events_in_wait_list, + const cl_event * event_wait_list, + cl_event * event); + +typedef cl_int(CL_API_CALL * clEnqueueMemcpyINTEL_fn)(cl_command_queue queue, + cl_bool blocking, + void * dst_ptr, + const void * src_ptr, + size_t size, + cl_uint num_events_in_wait_list, + const cl_event * event_wait_list, + cl_event * event); + +// Get the clEnqueueMemFillINTEL function pointer (returns nullptr if not available) +clEnqueueMemFillINTEL_fn ggml_openvino_get_clEnqueueMemFillINTEL(); + +// Get the clEnqueueMemcpyINTEL function pointer (returns nullptr if not available) +clEnqueueMemcpyINTEL_fn ggml_openvino_get_clEnqueueMemcpyINTEL(); + +// ===================================================== +// Global Device Configuration (singleton) +// ===================================================== +// Initialized once during backend init from GGML_OPENVINO_DEVICE env var + +struct ggml_openvino_device_config { + std::string device_name = "CPU"; + bool is_npu = false; + bool initialized = false; + std::optional remote_context; + ov::AnyMap compile_config; + cl_command_queue cl_queue = nullptr; + + void init(); + ~ggml_openvino_device_config(); +}; + +// Get the global device config singleton +ggml_openvino_device_config & ggml_openvino_get_device_config(); + +// Initialize device config (call during backend init) +void ggml_openvino_init_device_config(); + +// Get the device name +const std::string & ggml_openvino_get_device_name(); + +// Check if running on NPU +bool ggml_openvino_is_npu(); + +// Get requantization type for a tensor type (returns nullopt if no requant needed) +std::optional ggml_openvino_get_requant_type(const ggml_tensor * tensor, bool no_requant = false); + +// ===================================================== +// OpenVINO Tensor Extra Types +// ===================================================== +// These types are stored in tensor->extra by the OpenVINO backend buffer. +// They allow: +// 1. Pre-built ov::Constant nodes for weights (avoiding memcpy during graph construction) +// 2. ov::Tensor wrappers for KV cache / compute tensors (for direct use with infer_request) + +// Base class for OpenVINO tensor extra data +struct ggml_openvino_extra_base { + enum class Type { WEIGHT, QUANTIZED_WEIGHT, TENSOR }; + Type type; + virtual ~ggml_openvino_extra_base() = default; +protected: + explicit ggml_openvino_extra_base(Type t) : type(t) {} +}; + +// Extra data for F16/F32/BF16 weight tensors - stores the pre-built weight node +struct ggml_openvino_weight_extra : public ggml_openvino_extra_base { + ov::Tensor weights; // The underlying weight data tensor + std::shared_ptr weight_node; // Pre-built OpenVINO weight node + + ggml_openvino_weight_extra(ov::Tensor w, std::shared_ptr n) : + ggml_openvino_extra_base(Type::WEIGHT), + weights(std::move(w)), + weight_node(std::move(n)) {} +}; + +// Extra data for quantized weight tensors - stores extracted weights/scales/zp and weight node +struct ggml_openvino_quantized_weight_extra : public ggml_openvino_extra_base { + ov::Tensor weights; // U4 or U8 extracted weights + ov::Tensor scales; // F16 scales + ov::Tensor zp; // U4 or U8 zero points (same type as weights) + std::shared_ptr weight_node; // Pre-built OpenVINO weight subgraph + + ggml_openvino_quantized_weight_extra(ov::Tensor w, ov::Tensor s, ov::Tensor z, std::shared_ptr n) : + ggml_openvino_extra_base(Type::QUANTIZED_WEIGHT), + weights(std::move(w)), + scales(std::move(s)), + zp(std::move(z)), + weight_node(std::move(n)) {} +}; + +// Extra data for KV cache / compute tensors - stores ov::Tensor for infer_request +struct ggml_openvino_tensor_extra : public ggml_openvino_extra_base { + std::shared_ptr tensor; // For direct use with infer_request + + explicit ggml_openvino_tensor_extra(std::shared_ptr t) + : ggml_openvino_extra_base(Type::TENSOR), tensor(std::move(t)) {} +}; + +// ===================================================== +// Extracted Size Calculation for Quantized Tensors +// ===================================================== +// For quantized tensors, we need extra space to store extracted weights, scales, and zero points. +// Returns the total size needed in the buffer for extracted data. + +struct ggml_openvino_extracted_layout { + size_t total_size = 0; // Total bytes needed + size_t weights_offset = 0; // Offset to weights in buffer + size_t weights_size = 0; // Size of weights in bytes + size_t scales_offset = 0; // Offset to scales in buffer + size_t scales_size = 0; // Size of scales in bytes + size_t zp_offset = 0; // Offset to zero points in buffer + size_t zp_size = 0; // Size of zero points in bytes (U4 or U8) + bool is_u4; // true for U4 weights, false for U8 + int64_t weights_per_block; // weights per scale/zp block + bool is_symmetric; // true for symmetric quantization + + // Requantization info + bool is_requant = false; // true if this tensor needs requantization + std::optional requant_type; // target requant type if is_requant +}; + +// Calculate the buffer layout for extracted quantized data +ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor, bool use_bias = false); + +ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor, bool is_remote); + +// Register an extra with the tensor's OpenVINO buffer context for proper lifetime management. +// This sets tensor->extra and tracks the extra in the buffer context for cleanup. +void ggml_openvino_buffer_register_extra(ggml_tensor * tensor, ggml_openvino_extra_base * extra); + +// ===================================================== +// OpenVINO Backend Context and Interface +// ===================================================== +struct ggml_backend_openvino_context { + int device = 0; + std::string name = "OpenVINO"; + std::string description = "OpenVINO Backend Context"; + + std::shared_ptr runtime_context = nullptr; + + ggml_backend_openvino_context() = default; +}; diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp new file mode 100644 index 00000000000..0031cb7369f --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -0,0 +1,1110 @@ +#include "ggml-openvino.h" + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "ggml-openvino-extra.h" +#include "ggml-openvino/utils.h" +#include "ggml-quants.h" +#include "ggml.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_WIN32) +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +#else +# include +#endif + +// ===================================================== +// OpenVINO Buffer Implementation using ov::Tensor +// ===================================================== +// +// Design: This implementation uses a hybrid approach: +// 1. For weight tensors: Store a pre-built ov::op::v0::Constant in tensor->extra +// - This avoids the memcpy during graph construction +// - For quantized weights, the constant is already converted to OpenVINO format +// 2. For KV cache / compute tensors: Store an ov::Tensor in tensor->extra +// - This can be directly passed to infer_request +// - Future: can be changed to ov::RemoteTensor for GPU/NPU +// +// This design is similar to: +// - CUDA split buffer: tensor->extra stores device pointers +// - CPU repack buffer: tensor->extra stores tensor_traits with repacked data +// ===================================================== + +// Buffer context that manages per-tensor allocations (no contiguous buffer for weights) +struct ggml_backend_openvino_buffer_context { + int device; + std::string name; + size_t id; + + // For non-weight buffers (KV cache, compute), we still use contiguous allocation + void * data; + size_t size; + bool is_remote; + + // Wrapping of the buffer + std::shared_ptr ov_buffer; + + // Track all extras for cleanup + std::map tensor_extras; + + // Used for re-allocation on device for kvcache + void * data_prev; + + ggml_backend_openvino_buffer_context(int device, size_t size, bool is_remote = false) : + device(device), + name(std::string(GGML_OPENVINO_NAME) + std::to_string(device)), + id([]() { + static std::atomic next_id{1}; + return next_id.fetch_add(1); + }()), + data(nullptr), + size(size), + is_remote(is_remote) { + if (size == 0) { + return; + } + + const auto & device_name = ggml_openvino_get_device_name(); + + if (is_remote) { + GGML_ASSERT(device_name == "GPU"); + auto remote_context = ggml_openvino_get_remote_context(); + auto gpu_context = remote_context->as(); + ov::intel_gpu::ocl::USMTensor usm_tensor = + gpu_context.create_usm_device_tensor(ov::element::u8, ov::Shape{size}); + data = usm_tensor.get(); + ov_buffer = std::make_shared(std::move(usm_tensor)); + } else { + data = ggml_aligned_malloc(size); + ov_buffer = std::make_shared(ov::element::u8, ov::Shape{size}, data); + } + + if (data == nullptr) { + GGML_LOG_ERROR("%s: failed to allocate %zu bytes\n", __func__, size); + return; + } + + if (reinterpret_cast(data) % TENSOR_ALIGNMENT != 0) { + GGML_LOG_ERROR("%s: %s buffer is not aligned to %d bytes\n", __func__, device_name.c_str(), + TENSOR_ALIGNMENT); + GGML_ABORT("fatal error"); + } + } + + ~ggml_backend_openvino_buffer_context() { + // Clean up all tensor extras + // GGML_LOG_DEBUG("Deleting OpenVINO buffer context #%zu for device %d, size %zu MB\n", id, device, + // size / 1024 / 1024); + for (auto & pair : tensor_extras) { + delete pair.second; + } + tensor_extras.clear(); + if (!is_remote && data != nullptr) { + ggml_aligned_free(data, size); + } + } +}; + +// Buffer type context (per-device) +struct ggml_backend_openvino_buffer_type_context { + int device; + std::string name; +}; + +// Buffer interface functions +static void ggml_backend_openvino_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + delete ctx; +} + +static void * ggml_backend_openvino_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + return ctx->data; +} + +static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + // Put kvcache on device memory for GPU (NPU memory is too small even for kvcache) + if (strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && ggml_openvino_get_device_name() == "GPU" && + !getenv("GGML_OPENVINO_STATEFUL_EXECUTION")) { + GGML_ASSERT(ctx->tensor_extras.empty()); + auto device = ctx->device; + auto size = ctx->size; + auto * data_prev = ctx->data; + delete ctx; + ctx = new ggml_backend_openvino_buffer_context(device, size, true); + buffer->context = ctx; + tensor->data = (char *) ctx->data + ((char *) tensor->data - (char *) data_prev); + } + + // Views share the extra from view_src + if (tensor->view_src != nullptr) { + GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); + if (tensor->view_src->extra != nullptr) { + tensor->extra = tensor->view_src->extra; + } + return GGML_STATUS_SUCCESS; + } + + ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + if (tensor->data != nullptr && !ggml_is_quantized(tensor->type)) { + ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor, ctx->is_remote); + if (extra != nullptr) { + auto it = ctx->tensor_extras.find(tensor); + if (it != ctx->tensor_extras.end()) { + delete it->second; + } + ctx->tensor_extras[tensor] = extra; + tensor->extra = extra; + } + } + + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_openvino_buffer_memset_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + uint8_t value, + size_t offset, + size_t size) { + // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); + GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + if (ctx->is_remote) { + // For remote (device) buffers, use OpenCL USM memfill + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_fill_fn = ggml_openvino_get_clEnqueueMemFillINTEL(); + if (queue != nullptr && mem_fill_fn != nullptr) { + uint8_t pattern = value; + cl_int err = mem_fill_fn(queue, (char *) tensor->data + offset, &pattern, sizeof(pattern), size, 0, nullptr, + nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemFillINTEL failed with error %d\n", __func__, err); + } + clFinish(queue); + } else { + GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemFillINTEL not available for GPU buffer\n", __func__); + } + } else { + memset((char *) tensor->data + offset, value, size); + } +} + +static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); + GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + // Check if this is a weight buffer (usage is set BEFORE set_tensor is called, except in test-backend-ops) + bool is_weight_buffer = (buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + // Full tensor set: offset=0, full size, not a view + bool is_full_tensor_set = (offset == 0 && size == ggml_nbytes(tensor) && tensor->view_src == nullptr); + // 2D tensor (typical weight shape) + bool is_2d = (tensor->ne[2] == 1 && tensor->ne[3] == 1); + + if (is_weight_buffer && is_full_tensor_set && is_2d) { + try { + auto result = process_weight_tensor(tensor, data, tensor->data); + result.weight_node->set_friendly_name(tensor->name); + + // const auto & layout = result.layout; + ggml_openvino_extra_base * extra; + + // Quantized path with extracted weight/scale/zp tensors + if (result.is_quantized()) { + extra = new ggml_openvino_quantized_weight_extra(std::move(result.weights), std::move(result.scales), + std::move(result.zp), result.weight_node); + + // if (layout.is_requant) { + // GGML_LOG_DEBUG("%s: requantized %s to %s (u%d, block_size=%ld)\n", __func__, tensor->name, + // extra_quant_type_name(layout.requant_type.value()), layout.is_u4 ? 4 : 8, + // layout.weights_per_block); + // } else { + // int64_t n_blocks = ggml_nelements(tensor) / layout.weights_per_block; + // GGML_LOG_DEBUG("%s: extracted quantized weight node for %s (u%d, %zu weights, %ld blocks)\n", + // __func__, tensor->name, layout.is_u4 ? 4 : 8, layout.weights_size, n_blocks); + // } + } else { + // F16/F32/BF16 weight or F16-requant + extra = new ggml_openvino_weight_extra(std::move(result.weights), result.weight_node); + + // if (layout.total_size > 0) { + // GGML_LOG_DEBUG("%s: requantized %s to F16\n", __func__, tensor->name); + // } else { + // GGML_LOG_DEBUG("%s: created shared-memory weight node for %s\n", __func__, tensor->name); + // } + } + + ctx->tensor_extras[tensor] = extra; + tensor->extra = extra; + + } catch (const std::exception & e) { + GGML_LOG_ERROR("%s: failed to process weight tensor for %s: %s\n", __func__, tensor->name, e.what()); + memcpy((char *) tensor->data + offset, data, size); + } + } else { + // Non-weight tensor (KV cache, activations, etc.) - copy data. test-backend-ops also goes here + if (ctx->is_remote) { + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); + if (queue != nullptr && mem_cpy_fn != nullptr) { + cl_int err = + mem_cpy_fn(queue, CL_TRUE, (char *) tensor->data + offset, data, size, 0, nullptr, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL failed with error %d\n", __func__, err); + } + } else { + GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\n", __func__); + } + } else { + memcpy((char *) tensor->data + offset, data, size); + } + + ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor, ctx->is_remote); + if (extra == nullptr) { + // GGML_LOG_ERROR("%s: failed to create tensor extra for %s\n", __func__, tensor->name); + return; + } + + auto it = ctx->tensor_extras.find(tensor); + if (it != ctx->tensor_extras.end()) { + delete it->second; + } + ctx->tensor_extras[tensor] = extra; + tensor->extra = extra; + } +} + +static void ggml_backend_openvino_buffer_get_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); + GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + if (ctx->is_remote) { + // For remote (device) buffers, use OpenCL USM memcpy (device-to-host) + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); + if (queue != nullptr && mem_cpy_fn != nullptr) { + cl_int err = + mem_cpy_fn(queue, CL_TRUE, data, (const char *) tensor->data + offset, size, 0, nullptr, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL failed with error %d\n", __func__, err); + } + } else { + GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\n", __func__); + } + } else { + memcpy(data, (const char *) tensor->data + offset, size); + } +} + +static bool ggml_backend_openvino_buffer_cpy_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * src, + ggml_tensor * dst) { + // GGML_LOG_DEBUG("%s: src tensor name=%s, dst tensor name=%s\n", __func__, src->name, dst->name); + GGML_ASSERT(src != nullptr && dst != nullptr); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + if (ctx->is_remote) { + // For remote (device) buffers, use OpenCL USM memcpy + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); + if (queue == nullptr || mem_cpy_fn == nullptr) { + GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\n", __func__); + return false; + } + // Can copy from host to device + if (ggml_backend_buffer_is_host(src->buffer)) { + cl_int err = mem_cpy_fn(queue, CL_TRUE, dst->data, src->data, ggml_nbytes(src), 0, nullptr, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL (host-to-device) failed with error %d\n", __func__, err); + return false; + } + return true; + } + // Can also copy from device to device if both are OpenVINO remote buffers + if (ggml_backend_buffer_is_openvino(src->buffer)) { + ggml_backend_openvino_buffer_context * src_ctx = + (ggml_backend_openvino_buffer_context *) src->buffer->context; + if (src_ctx->is_remote) { + cl_int err = + mem_cpy_fn(queue, CL_TRUE, dst->data, src->data, ggml_nbytes(src), 0, nullptr, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL (device-to-device) failed with error %d\n", __func__, + err); + return false; + } + return true; + } + } + return false; + } + + // Host buffer - can copy from any host buffer + if (ggml_backend_buffer_is_host(src->buffer)) { + memcpy(dst->data, src->data, ggml_nbytes(src)); + return true; + } + return false; +} + +static void ggml_backend_openvino_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + GGML_ASSERT(ctx->data != nullptr); + if (ctx->is_remote) { + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_fill_fn = ggml_openvino_get_clEnqueueMemFillINTEL(); + if (queue != nullptr && mem_fill_fn != nullptr) { + uint8_t pattern = value; + cl_int err = mem_fill_fn(queue, ctx->data, &pattern, sizeof(pattern), ctx->size, 0, nullptr, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_WARN("%s: clEnqueueMemFillINTEL failed with error %d\n", __func__, err); + } + clFinish(queue); + } else { + GGML_LOG_WARN("%s: no OpenCL queue or clEnqueueMemFillINTEL not available for GPU buffer clear\n", + __func__); + } + } else { + memset(ctx->data, value, ctx->size); + } +} + +static const ggml_backend_buffer_i ggml_backend_openvino_buffer_interface = { + /* .free_buffer = */ ggml_backend_openvino_buffer_free_buffer, + /* .get_base = */ ggml_backend_openvino_buffer_get_base, + /* .init_tensor = */ ggml_backend_openvino_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_openvino_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_openvino_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_openvino_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_openvino_buffer_cpy_tensor, + /* .clear = */ ggml_backend_openvino_buffer_clear, + /* .reset = */ NULL, +}; + +// Buffer type interface functions +static const char * ggml_backend_openvino_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + ggml_backend_openvino_buffer_type_context * ctx = (ggml_backend_openvino_buffer_type_context *) buft->context; + return ctx->name.c_str(); +} + +static ggml_backend_buffer_t ggml_backend_openvino_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) { + ggml_backend_openvino_buffer_type_context * buft_ctx = (ggml_backend_openvino_buffer_type_context *) buft->context; + + // Create buffer context with contiguous memory allocation + ggml_backend_openvino_buffer_context * ctx = new ggml_backend_openvino_buffer_context(buft_ctx->device, size); + + if (ctx->data == nullptr && size > 0) { + GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, size); + delete ctx; + return nullptr; + } + + return ggml_backend_buffer_init(buft, ggml_backend_openvino_buffer_interface, ctx, size); +} + +static size_t ggml_backend_openvino_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + return TENSOR_ALIGNMENT; +} + +static size_t ggml_backend_openvino_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + return SIZE_MAX; +} + +static size_t ggml_backend_openvino_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, + const ggml_tensor * tensor) { + GGML_UNUSED(buft); + + // For quantized 2D tensors (weights), we need extra space for extracted data + if (ggml_is_quantized(tensor->type) && tensor->ne[2] == 1 && tensor->ne[3] == 1) { + ggml_openvino_extracted_layout layout = ggml_openvino_get_extracted_layout(tensor); + if (layout.total_size > 0) { + // GGML_LOG_DEBUG("%s: tensor %s needs %zu bytes (original %zu, extracted: weights=%zu scales=%zu zp=%zu)\n", + // __func__, tensor->name, layout.total_size, ggml_nbytes(tensor), layout.weights_size, + // layout.scales_size, layout.zp_size); + return layout.total_size; + } + } + + return ggml_nbytes(tensor); +} + +static const ggml_backend_buffer_type_i ggml_backend_openvino_buffer_type_interface = { + /* .get_name = */ ggml_backend_openvino_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_openvino_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_openvino_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_openvino_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_openvino_buffer_type_get_alloc_size, + /* .is_host = */ nullptr, +}; + +// Get buffer type for a specific device +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(int device) { + GGML_ASSERT(device >= 0 && device < ggml_backend_openvino_get_device_count()); + + static std::mutex mutex; + std::lock_guard lock(mutex); + + static std::vector buffer_types; + static std::vector buffer_type_contexts; + + if (buffer_types.empty()) { + int device_count = ggml_backend_openvino_get_device_count(); + buffer_types.resize(device_count); + buffer_type_contexts.resize(device_count); + + for (int i = 0; i < device_count; i++) { + buffer_type_contexts[i].device = i; + buffer_type_contexts[i].name = std::string(GGML_OPENVINO_NAME) + std::to_string(i); + + buffer_types[i] = ggml_backend_buffer_type{ + /* .iface = */ ggml_backend_openvino_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), i), + /* .context = */ &buffer_type_contexts[i], + }; + } + } + + return &buffer_types[device]; +} + +// ===================================================== +// OpenVINO Host Buffer Implementation +// ===================================================== + +static const char * ggml_backend_openvino_host_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + ggml_backend_openvino_buffer_type_context * ctx = (ggml_backend_openvino_buffer_type_context *) buft->context; + static std::string name; + name = ctx->name + "_HOST"; + return name.c_str(); +} + +static bool ggml_backend_openvino_host_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + return true; +} + +static const ggml_backend_buffer_type_i ggml_backend_openvino_host_buffer_type_interface = { + /* .get_name = */ ggml_backend_openvino_host_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_openvino_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_openvino_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_openvino_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_openvino_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_openvino_host_buffer_type_is_host, +}; + +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_host_buffer_type(int device) { + GGML_ASSERT(device >= 0 && device < ggml_backend_openvino_get_device_count()); + + static std::mutex mutex; + std::lock_guard lock(mutex); + + static std::vector buffer_types; + static std::vector buffer_type_contexts; + + if (buffer_types.empty()) { + int device_count = ggml_backend_openvino_get_device_count(); + buffer_types.resize(device_count); + buffer_type_contexts.resize(device_count); + + for (int i = 0; i < device_count; i++) { + buffer_type_contexts[i].device = i; + buffer_type_contexts[i].name = std::string(GGML_OPENVINO_NAME) + std::to_string(i); + + buffer_types[i] = ggml_backend_buffer_type{ + /* .iface = */ ggml_backend_openvino_host_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), i), + /* .context = */ &buffer_type_contexts[i], + }; + } + } + + return &buffer_types[device]; +} + +bool ggml_backend_buffer_is_openvino(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_openvino_buffer_free_buffer; +} + +size_t ggml_backend_openvino_buffer_get_ctx_id(ggml_backend_buffer_t buffer) { + if (!ggml_backend_buffer_is_openvino(buffer)) { + return 0; + } + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + return ctx->id; +} + +void ggml_openvino_buffer_register_extra(ggml_tensor * tensor, ggml_openvino_extra_base * extra) { + GGML_ASSERT(tensor != nullptr); + GGML_ASSERT(tensor->buffer != nullptr); + GGML_ASSERT(ggml_backend_buffer_is_openvino(tensor->buffer)); + + auto * ctx = static_cast(tensor->buffer->context); + + auto it = ctx->tensor_extras.find(tensor); + if (it != ctx->tensor_extras.end()) { + delete it->second; + } + + ctx->tensor_extras[tensor] = extra; + tensor->extra = extra; +} + +bool ggml_backend_buft_is_openvino(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_openvino_buffer_type_get_name; +} + +bool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_openvino_host_buffer_type_get_name; +} + +static void ggml_backend_openvino_free(ggml_backend_t backend) { + ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *) backend->context; + delete ctx; + delete backend; +} + +static const char * ggml_backend_openvino_get_name(ggml_backend_t backend) { + return GGML_OPENVINO_NAME; + GGML_UNUSED(backend); +} + +static enum ggml_status ggml_backend_openvino_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + return ov_graph_compute(cgraph, backend); + GGML_UNUSED(backend); +} + +static const ggml_backend_i ggml_backend_openvino_interface = { + /* .get_name = */ ggml_backend_openvino_get_name, + /* .free = */ ggml_backend_openvino_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_openvino_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, +}; + +int ggml_backend_openvino_get_device_count() { + return 1; +} + +static ggml_guid_t ggml_backend_openvino_guid(void) { + static ggml_guid guid = {0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, + 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d}; + return &guid; +} + +static std::shared_ptr get_ov_runtime_context_ptr() { + static std::shared_ptr r_ctx = std::make_shared(); + return r_ctx; +} + +// backend API +GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device) { + if (device < 0 || device >= ggml_backend_openvino_get_device_count()) { + GGML_LOG_ERROR("%s: invalid device %d\n", __func__, device); + return nullptr; + } + + ggml_backend_openvino_context * ctx = new ggml_backend_openvino_context; + if (ctx == nullptr) { + GGML_LOG_ERROR("%s: failed to allocate context\n", __func__); + return nullptr; + } + + ctx->runtime_context = get_ov_runtime_context_ptr(); + if (ctx->runtime_context == nullptr) { + GGML_LOG_ERROR("%s: failed to allocate runtime context\n", __func__); + delete ctx; + return nullptr; + } + + std::shared_ptr r_ctx = std::static_pointer_cast(ctx->runtime_context); + r_ctx->device = ggml_openvino_get_device_name(); + r_ctx->stateful = getenv("GGML_OPENVINO_STATEFUL_EXECUTION") && !ggml_openvino_is_npu(); + + ggml_backend_t openvino_backend = new ggml_backend{ + /* .guid = */ ggml_backend_openvino_guid(), + /* .interface = */ ggml_backend_openvino_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), device), + /* .context = */ ctx, + }; + + return openvino_backend; +} + +GGML_BACKEND_API bool ggml_backend_is_openvino(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_openvino_guid()); +} + +struct ggml_backend_openvino_device_context { + int device; + std::string name; + std::string description; +}; + +static const char * ggml_backend_openvino_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context; + return ctx->name.c_str(); +} + +static const char * ggml_backend_openvino_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context; + return ctx->description.c_str(); +} + +static void ggml_backend_openvino_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { +#ifdef _WIN32 + MEMORYSTATUSEX status; + status.dwLength = sizeof(status); + GlobalMemoryStatusEx(&status); + *total = status.ullTotalPhys; + *free = status.ullAvailPhys; +#else + long pages = sysconf(_SC_PHYS_PAGES); + long page_size = sysconf(_SC_PAGE_SIZE); + *total = pages * page_size; + + // "free" system memory is ill-defined, for practical purposes assume that all of it is free: + *free = *total; +#endif // _WIN32 + + GGML_UNUSED(dev); +} + +static enum ggml_backend_dev_type ggml_backend_openvino_device_get_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_openvino_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_openvino_device_get_name(dev); + props->description = ggml_backend_openvino_device_get_description(dev); + props->type = ggml_backend_openvino_device_get_type(dev); + ggml_backend_openvino_device_get_memory(dev, &props->memory_free, &props->memory_total); + + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_openvino_device_init(ggml_backend_dev_t dev, const char * params) { + GGML_UNUSED(params); + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context; + return ggml_backend_openvino_init(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_openvino_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context; + return ggml_backend_openvino_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_openvino_device_get_host_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context; + return ggml_backend_openvino_host_buffer_type(ctx->device); +} + +static bool has_view_op_input(const ggml_tensor * op) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op->src[i] == nullptr) { + break; + } + if (op->src[i]->op == GGML_OP_VIEW) { + return true; + } + } + return false; +} + +static bool is_supported_flash_attn_pattern(const ggml_tensor * op) { + // pattern of q,k,v should be q->op==PERMUTE, q->src[0]->op==VIEW, q->src[0]->src[0]->view_src==nullptr + for (int i = 0; i < 3; i++) { + const ggml_tensor * src = op->src[i]; + if (src->op != GGML_OP_PERMUTE || src->src[0] == nullptr || src->src[0]->op != GGML_OP_VIEW || + src->src[0]->src[0] == nullptr || src->src[0]->src[0]->view_src != nullptr) { + return false; + } + } + return true; +} + +static bool is_op_unsupported_case(const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_GET_ROWS: + case GGML_OP_SET_ROWS: { + if (op->ne[3] != 1) { + return true; + } + break; + } + case GGML_OP_ADD: + case GGML_OP_MUL: { + if (op->src[1]->op == GGML_OP_PERMUTE) { + return true; + } + for (int i = 0; i < 4; i++) { + if (op->src[0]->ne[i] != op->src[1]->ne[i] && (op->src[0]->ne[i] != 1 && op->src[1]->ne[i] != 1)) { + return true; + } + } + break; + } + case GGML_OP_SOFT_MAX: { + if (op->src[2] != nullptr) { + // GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with sinks\n"); + return true; + } + float scale = 1.0f; + float max_bias = 0.0f; + const auto * op_params = op->op_params; + memcpy(&scale, (const float *) op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) op_params + 1, sizeof(float)); + if (max_bias > 0) { + // GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with max_bias > 0\n"); + return true; + } + break; + } + case GGML_OP_FLASH_ATTN_EXT: { + if (op->src[4] != nullptr) { + // GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with sinks\n"); + return true; + } + if (!is_supported_flash_attn_pattern(op)) { + return true; + } + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + const auto * op_params = op->op_params; + memcpy(&scale, (const float *) op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) op_params + 2, sizeof(float)); + if (max_bias > 0) { + // GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with max_bias > 0\n"); + return true; + } + if (logit_softcap != 0) { + // GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with logit_softcap != 0\n"); + return true; + } + break; + } + case GGML_OP_PERMUTE: { + if (op->type == GGML_TYPE_BF16) { + // err msg: [GPU] Could not find a suitable kernel for transpose + // GGML_LOG_WARN("OpenVINO backend does not support PERMUTE with BF16 type\n"); + return true; + } + break; + } + case GGML_OP_CPY: { + if (op->src[1] != op) { + // GGML_LOG_WARN("OpenVINO backend only supports CPY that is a cast\n"); + return true; + } + break; + } + case GGML_OP_MUL_MAT: { + if (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16) { + // Has accuracy issue, try enabling this and see `test-backend-ops -o "MUL_MAT"` + // GGML_LOG_WARN("OpenVINO backend does not support MUL_MAT with two F16 tensors\n"); + return true; + } + if (op->src[0]->ne[3] != op->src[1]->ne[3] && op->src[0]->ne[3] != 1 && op->src[1]->ne[3] != 1) { + return true; + } + if (op->src[0]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_PERMUTE) { + return true; + } + if (ggml_is_quantized(op->src[0]->type) && op->src[0]->ne[1] == 1) { + // MUL_MAT(type_a=q4_0,type_b=f32,m=1,n=2048,k=8192,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) + // triggers a bug in ov matmul_shape_inference.hpp + return true; + } + if (op->src[0]->op == GGML_OP_VIEW && op->src[1]->op == GGML_OP_VIEW) { + return true; + } + break; + } + case GGML_OP_ROPE: { + const int32_t * op_params = op->op_params; + const int n_dims = op_params[1]; + const int mode = op_params[2]; + if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) { + // GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode); + return true; + } + if (n_dims != 0.0f && n_dims != op->src[0]->ne[0]) { + // GGML_LOG_WARN("OpenVINO backend does not support ROPE with n_dims %d != src[0]->ne[0] %ld\n", n_dims, + // op->src[0]->ne[0]); + return true; + } + if (op->type != GGML_TYPE_F32) { + // GGML_LOG_WARN("OpenVINO backend does not support ROPE with type %s\n", ggml_type_name(op->type)); + return true; + } + float freq_scale; + float ext_factor; + memcpy(&freq_scale, op_params + 6, sizeof(float)); + memcpy(&ext_factor, op_params + 7, sizeof(float)); + if (ext_factor != 0.0f) { + // GGML_LOG_WARN("OpenVINO backend does not support ROPE with ext_factor %f != 0.0f\n", ext_factor); + return true; + } + if (op->src[0]->op == GGML_OP_VIEW) { + if (op->src[0]->view_src->ne[1] != op->src[0]->ne[2]) { + // GGML_LOG_WARN( + // "OpenVINO backend does not support ROPE with src[0]->view_src->ne[1] %ld != src[0]->ne[2] " + // "%ld\n", + // op->src[0]->view_src->ne[1], op->src[0]->ne[2]); + return true; + } + } + break; + } + default: + break; + } + if (op->op == GGML_OP_GET_ROWS) { + if (op->ne[0] == 256 && (op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q5_K)) { + // ERR = 0.000000306 > 0.000000100 GET_ROWS(type=q4_K,n=256,m=5,r=4,be1=1,be2=1,v=0) + // ERR = 0.000000197 > 0.000000100 GET_ROWS(type=q5_K,n=256,m=5,r=4,be1=1,be2=1,v=0) + return true; + } + } + return false; +} + +static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + GGML_ASSERT(dev->reg != nullptr); + + static std::set supported_types{GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_I64, + GGML_TYPE_I32, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_K, + GGML_TYPE_Q5_K, GGML_TYPE_Q8_0, GGML_TYPE_Q6_K}; + + static const std::set supported_ops{GGML_OP_NONE, GGML_OP_ADD, GGML_OP_MUL, GGML_OP_MUL_MAT, GGML_OP_VIEW, + /*GGML_OP_CONT,*/ GGML_OP_RESHAPE, GGML_OP_PERMUTE, GGML_OP_TRANSPOSE, + GGML_OP_GET_ROWS, GGML_OP_ROPE, GGML_OP_RMS_NORM, GGML_OP_SCALE, + // softmax is not updated due to replaced by flash_attn_ext + // GGML_OP_SOFT_MAX, + GGML_OP_SET_ROWS, GGML_OP_FLASH_ATTN_EXT, GGML_OP_CPY}; + static const std::set supported_unary_ops{ + GGML_UNARY_OP_SILU, + }; + static const std::set supported_glu_ops{ + GGML_GLU_OP_SWIGLU, + GGML_GLU_OP_GEGLU, + }; + + switch (op->op) { + case GGML_OP_UNARY: { + auto supported = supported_unary_ops.find(ggml_get_unary_op(op)) != supported_unary_ops.end(); + if (!supported) { + // GGML_LOG_WARN("OpenVINO backend does not support unary op %s\n", ggml_unary_op_name(ggml_get_unary_op(op))); + return false; + } + if (has_view_op_input(op)) { + // GGML_LOG_WARN("OpenVINO backend does not support unary op %s with view input\n", + // ggml_unary_op_name(ggml_get_unary_op(op))); + return false; + } + break; + } + case GGML_OP_GLU: { + auto supported = supported_glu_ops.find(ggml_get_glu_op(op)) != supported_glu_ops.end(); + if (!supported) { + // GGML_LOG_WARN("OpenVINO backend does not support GLU op %s\n", ggml_glu_op_name(ggml_get_glu_op(op))); + return false; + } + if (has_view_op_input(op)) { + // GGML_LOG_WARN("OpenVINO backend does not support unary op %s with view input\n", + // ggml_glu_op_name(ggml_get_glu_op(op))); + return false; + } + if (op->src[1] == nullptr && op->src[0]->ne[0] % 2 != 0) { + // triggers bug in ov gpu + return false; + } + break; + } + default: { + auto supported = supported_ops.find(op->op) != supported_ops.end(); + if (!supported) { + // GGML_LOG_WARN("OpenVINO backend does not support op %s\n", ggml_op_name(op->op)); + return false; + } + static std::set ops_not_support_view_input{ + GGML_OP_GET_ROWS, + GGML_OP_RMS_NORM, + }; + if (ops_not_support_view_input.find(op->op) != ops_not_support_view_input.end() && has_view_op_input(op)) { + // GGML_LOG_WARN("OpenVINO backend does not support op %s with view input\n", ggml_op_name(op->op)); + return false; + } + } + } + + if (supported_types.find(op->type) == supported_types.end()) { + // GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(op->type)); + return false; + } + for (int i = 0; i < GGML_MAX_SRC; i++) { + auto * src = op->src[i]; + if (src == nullptr) { + break; + } + if (supported_types.find(src->type) == supported_types.end()) { + // GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(src->type)); + return false; + } + if (ggml_is_quantized(src->type) && src->ne[2] != 1) { + // GGML_LOG_WARN("OpenVINO backend does not support 3D quantized tensors\n"); + return false; + } + } + + if (is_op_unsupported_case(op)) { + return false; + } + return true; +} + +static bool ggml_backend_openvino_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return ggml_backend_buft_is_openvino(buft) || ggml_backend_buft_is_host(buft); + GGML_UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_openvino_device_interface = { + /* .get_name = */ ggml_backend_openvino_device_get_name, + /* .get_description = */ ggml_backend_openvino_device_get_description, + /* .get_memory = */ ggml_backend_openvino_device_get_memory, + /* .get_type = */ ggml_backend_openvino_device_get_type, + /* .get_props = */ ggml_backend_openvino_device_get_props, + /* .init_backend = */ ggml_backend_openvino_device_init, + /* .get_buffer_type = */ ggml_backend_openvino_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_openvino_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_openvino_device_supports_op, + /* .supports_buft = */ ggml_backend_openvino_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +struct ggml_backend_openvino_reg_context { + std::vector devices; +}; + +static const char * ggml_backend_openvino_reg_get_name(ggml_backend_reg_t reg) { + return GGML_OPENVINO_NAME; + GGML_UNUSED(reg); +} + +static size_t ggml_backend_openvino_reg_get_device_count(ggml_backend_reg_t reg) { + GGML_UNUSED(reg); + return (size_t) ggml_backend_openvino_get_device_count(); +} + +static ggml_backend_dev_t ggml_backend_openvino_reg_get_device(ggml_backend_reg_t reg, size_t index) { + ggml_backend_openvino_reg_context * ctx = (ggml_backend_openvino_reg_context *) reg->context; + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; +} + +static const struct ggml_backend_reg_i ggml_backend_openvino_reg_interface = { + /* .get_name = */ ggml_backend_openvino_reg_get_name, + /* .get_device_count = */ ggml_backend_openvino_reg_get_device_count, + /* .get_device = */ ggml_backend_openvino_reg_get_device, + /* .get_proc_address = */ NULL, +}; + +static void ggml_openvino_init() { + // Initialize device config singleton from env var + ggml_openvino_init_device_config(); + GGML_LOG_INFO("OpenVINO: using device %s\n", ggml_openvino_get_device_name().c_str()); +} + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_openvino_reg(void) { + static ggml_backend_reg reg; + + static bool initialized = false; + { + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { + ggml_openvino_init(); + + ggml_backend_openvino_reg_context * ctx = new ggml_backend_openvino_reg_context; + + for (int i = 0; i < ggml_backend_openvino_get_device_count(); i++) { + ggml_backend_openvino_device_context * dev_ctx = new ggml_backend_openvino_device_context; + dev_ctx->device = i; + dev_ctx->name = GGML_OPENVINO_NAME + std::to_string(i); + + dev_ctx->description = ov::get_openvino_version().description; + + ggml_backend_dev_t dev = + new ggml_backend_device{/* .interface = */ ggml_backend_openvino_device_interface, + /* .reg = */ ®, + /* .context = */ dev_ctx}; + ctx->devices.push_back(dev); + } + + reg = ggml_backend_reg{/* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_openvino_reg_interface, + /* .context = */ ctx}; + } + + initialized = true; + } + + return ® +} diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp new file mode 100644 index 00000000000..dbf38646ddd --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -0,0 +1,884 @@ +#include "ggml-quants.h" + +#include "ggml-common.h" +#include "ggml-impl.h" +#include "ggml.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void unpack_32_4(const uint8_t * data, uint8_t * dst) { + std::fill_n(dst, 16, 0); + for (int j = 0; j < 16; ++j) { + uint8_t x = (data[j] & 0x0F); + uint8_t y = (data[j] >> 4); + if (j % 2 != 0) { + x <<= 4; + y <<= 4; + } + dst[j / 2] |= x; + dst[8 + j / 2] |= y; // Last 16 weights are in the higher bits + } +} + +// Extracts (weight, scales, zp) from Q4_0 tensors. +// Data layout is: |16 bit scale|32 x 4bit weights|. +void extract_q4_0_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr) { + const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights + + auto * data = static_cast(tensor->data); + auto * weights = static_cast(weights_arr.data()); + auto * scales = scales_arr.data::value_type>(); + auto * zp = static_cast(zp_arr.data()); + + bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization + + // For Q4_0, zero point is always 8 + if (is_scalar_zp) { + zp[0] = 8 | (8 << 4); // Pack two 4-bit values + } + + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); + // For asymmetric quantization, compute per-block zero points + if (!is_scalar_zp) { + // Pack two 4-bit zero points per byte + if (i % 2 == 0) { + zp[i / 2] = 8; // Lower nibble + } else { + zp[i / 2] |= (8 << 4); // Upper nibble + } + } + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); + }); +} + +// Extracts (weight, scales, zp) from Q4_1 tensors. +// Data layout is: |16 bit scale|16 bit min|32 x 4bit weights|. +void extract_q4_1_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + bool use_bias) { + const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes min, 32x0.5 byte weights + + auto * data = static_cast(tensor->data); + auto * weights = static_cast(weights_arr.data()); + auto * scales = scales_arr.data::value_type>(); + + if (use_bias) { + // Store bias (min) directly as f16 instead of computing u4 zero points + auto * bias = zp_arr.data::value_type>(); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + float scale = static_cast(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block)))); + float min = static_cast(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block + 2)))); + scales[i] = ov::float16(scale); + bias[i] = ov::float16(min); // bias = min, dequant: w*s + bias + unpack_32_4(data + i * bytes_per_block + 4, weights + i * 16); + }); + } else { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + float scale = static_cast(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block)))); + float min = static_cast(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block + 2)))); + scales[i] = ov::float16(scale); + // zp = -min / scale (bias = min, so zp = -bias/scale) + uint8_t zp_val = (scale != 0.0f) ? (uint8_t) std::round(-min / scale) : 0; + // Pack two 4-bit zero points per byte + if (i % 2 == 0) { + zp[i / 2] = zp_val & 0x0F; // Lower nibble + } else { + zp[i / 2] |= (zp_val << 4); // Upper nibble + } + unpack_32_4(data + i * bytes_per_block + 4, weights + i * 16); + }); + } +} + +// Extracts (weight, scales, zp) from Q8_0 tensors. +// Data layout is: |16 bit scale|32 x 8bit weights|. +void extract_q8_0_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr) { + const uint64_t weights_per_block = 32; + const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights + + auto * data = static_cast(tensor->data); + auto * weights = static_cast(weights_arr.data()); + auto * scales = scales_arr.data::value_type>(); + auto * zp = static_cast(zp_arr.data()); + + bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization + + // For Q8_0, zero point is always 128 + if (is_scalar_zp) { + zp[0] = 128; + } + + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); + // For asymmetric quantization, store per-block zero points + if (!is_scalar_zp) { + zp[i] = 128; + } + for (size_t j = 0; j < weights_per_block; ++j) { + uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. + // Original data is in int8_t, so we add a bias of -128 and invert the first bit. + x ^= 1 << 7; + weights[i * weights_per_block + j] = x; + } + }); +} + +void unpack_256_4(const uint8_t * data, uint8_t * dst) { + // Initialize the output array with zeros + std::fill_n(dst, 128, 0); + + for (size_t i = 0; i < 4; ++i) { + for (int j = 0; j < 32; ++j) { + uint8_t x = (data[i * 32 + j] & 0x0F); + uint8_t y = (data[i * 32 + j] >> 4); + if (j % 2 != 0) { + x <<= 4; + y <<= 4; + } + dst[i * 32 + j / 2] |= x; + dst[i * 32 + 16 + j / 2] |= y; // Last 16 weights are in the higher bits + } + } +} + +void extract_q4_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + bool use_bias) { + const uint64_t bytes_per_block = 2 + 2 + 12 + 128; + const uint64_t n_super_block = tensor->nb[3] / bytes_per_block; + + auto * data = static_cast(tensor->data); + auto * weights = static_cast(weights_arr.data()); + auto * scales = scales_arr.data::value_type>(); + + // For bias path, zp_arr holds f16 bias values; for zp path, it holds packed u4 zero points + auto * zp_u4 = use_bias ? nullptr : static_cast(zp_arr.data()); + auto * bias_f16 = use_bias ? zp_arr.data::value_type>() : nullptr; + + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + + // Extract scale factors and offsets + float scale_scales = static_cast(ov::float16::from_bits(*((uint16_t *) block_data))); + float scale_mins = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 1))); + + // Extract qs1 and qs2 + uint8_t * qs1 = block_data + 4; + + // Calculate scales + float scale_vals[8]; + scale_vals[0] = scale_scales * static_cast((*(qs1) & 0b111111)); + scale_vals[1] = scale_scales * static_cast((*(qs1 + 1) & 0b111111)); + scale_vals[2] = scale_scales * static_cast((*(qs1 + 2) & 0b111111)); + scale_vals[3] = scale_scales * static_cast((*(qs1 + 3) & 0b111111)); + scale_vals[4] = scale_scales * static_cast((*(qs1 + 8) & 0b00001111) | ((*(qs1) >> 6) << 4)); + scale_vals[5] = scale_scales * static_cast((*(qs1 + 9) & 0b00001111) | ((*(qs1 + 1) >> 6) << 4)); + scale_vals[6] = scale_scales * static_cast((*(qs1 + 10) & 0b00001111) | ((*(qs1 + 2) >> 6) << 4)); + scale_vals[7] = scale_scales * static_cast((*(qs1 + 11) & 0b00001111) | ((*(qs1 + 3) >> 6) << 4)); + + // Calculate min values (bias = -min) + float min_vals[8]; + min_vals[0] = scale_mins * static_cast((*(qs1 + 4) & 0b111111)); + min_vals[1] = scale_mins * static_cast((*(qs1 + 5) & 0b111111)); + min_vals[2] = scale_mins * static_cast((*(qs1 + 6) & 0b111111)); + min_vals[3] = scale_mins * static_cast((*(qs1 + 7) & 0b111111)); + min_vals[4] = scale_mins * static_cast((*(qs1 + 8) >> 4) | ((*(qs1 + 4) >> 6) << 4)); + min_vals[5] = scale_mins * static_cast((*(qs1 + 9) >> 4) | ((*(qs1 + 5) >> 6) << 4)); + min_vals[6] = scale_mins * static_cast((*(qs1 + 10) >> 4) | ((*(qs1 + 6) >> 6) << 4)); + min_vals[7] = scale_mins * static_cast((*(qs1 + 11) >> 4) | ((*(qs1 + 7) >> 6) << 4)); + + // Store scales and compute zero points or bias + for (int j = 0; j < 8; j++) { + scales[i * 8 + j] = ov::float16(scale_vals[j]); + if (use_bias) { + // Store bias = -min directly as f16, dequant: w*s + bias + bias_f16[i * 8 + j] = ov::float16(-min_vals[j]); + } else { + // zp = min / scale (since bias = -min and zp = -bias/scale) + uint8_t zp_val = (scale_vals[j] != 0.0f) ? (uint8_t) std::round(min_vals[j] / scale_vals[j]) : 0; + // Pack two 4-bit zero points per byte + size_t idx = i * 8 + j; + if (idx % 2 == 0) { + zp_u4[idx / 2] = zp_val & 0x0F; + } else { + zp_u4[idx / 2] |= (zp_val << 4); + } + } + } + unpack_256_4(block_data + 16, weights + i * 128); + }); +} + +void extract_q6_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr) { + const uint64_t bytes_per_block = 128 + 64 + 16 + 2; + const uint64_t n_super_block = tensor->nb[3] / bytes_per_block; + + auto * data = static_cast(tensor->data); + auto * weights = static_cast(weights_arr.data()); + auto * scales = scales_arr.data::value_type>(); + auto * zp = static_cast(zp_arr.data()); + + bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization + + // For Q6_K, zero point is always 32 + if (is_scalar_zp) { + zp[0] = 32; + } + + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + + float scale_factor = + static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); // (128+64+16)/2 + + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); + // For asymmetric quantization, store per-block zero points + if (!is_scalar_zp) { + zp[j + i * 16] = 32; + } + } + + uint8_t * ql = block_data; + uint8_t * qh = block_data + 128; + + for (int64_t j = 0; j < 32; ++j) { + weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4); + weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4); + weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4); + weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4); + weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4); + weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4); + weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4); + weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4); + } + }); +} + +static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +void extract_q5_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + bool use_bias) { + const uint64_t bytes_per_block = 4 + 12 + 32 + 128; + const uint64_t n_super_block = tensor->nb[3] / bytes_per_block; + + auto * data = static_cast(tensor->data); + auto * weights = static_cast(weights_arr.data()); + auto * scales = scales_arr.data::value_type>(); + + // For bias path, zp_arr holds f16 bias values; for zp path, it holds u8 zero points + auto * zp_u8 = use_bias ? nullptr : static_cast(zp_arr.data()); + auto * bias_f16 = use_bias ? zp_arr.data::value_type>() : nullptr; + + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + + const float d = static_cast(ov::float16::from_bits(*((uint16_t *) block_data))); + const float min_factor = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 1))); + + const uint8_t * scales_data = block_data + 4; // 12 bytes of scales + const uint8_t * qh = block_data + 4 + 12; // 32 bytes of high bits + const uint8_t * ql = block_data + 4 + 12 + 32; // 128 bytes of low bits + + int is = 0; + uint8_t u1 = 1; + uint8_t u2 = 2; + + // Process 2 blocks in one iteration + for (int j = 0; j < 256; j += 64) { // 256 = QK_K, so 4 iterations of 64 + uint8_t sc; + uint8_t m; + + // Get scale and min for first 32 elements + get_scale_min_k4(is + 0, scales_data, &sc, &m); + const float d1 = d * sc; + const float m1 = min_factor * m; + + // Get scale and min for second 32 elements + get_scale_min_k4(is + 1, scales_data, &sc, &m); + const float d2 = d * sc; + const float m2 = min_factor * m; + + scales[i * 8 + is] = ov::float16(d1); + scales[i * 8 + is + 1] = ov::float16(d2); + if (use_bias) { + // Store bias = -min directly as f16, dequant: w*s + bias + bias_f16[i * 8 + is] = ov::float16(-m1); + bias_f16[i * 8 + is + 1] = ov::float16(-m2); + } else { + // zp = min / scale (since bias = -min and zp = -bias/scale) + zp_u8[i * 8 + is] = (d1 != 0.0f) ? (uint8_t) std::round(m1 / d1) : 0; + zp_u8[i * 8 + is + 1] = (d2 != 0.0f) ? (uint8_t) std::round(m2 / d2) : 0; + } + + // Extract weights for first 32 elements (matching deq formula exactly) + for (int l = 0; l < 32; ++l) { + weights[i * 256 + j + l] = (ql[l] & 0xF) + ((qh[l] & u1) ? 16 : 0); + } + + // Extract weights for second 32 elements + for (int l = 0; l < 32; ++l) { + weights[i * 256 + j + l + 32] = (ql[l] >> 4) + ((qh[l] & u2) ? 16 : 0); + } + + ql += 32; + is += 2; + u1 <<= 2; + u2 <<= 2; + } + }); +} + +// TODO Reorder for make_intX_weights + +ov::Output make_int8_weights(ov::Tensor & weight, + ov::Tensor & scales, + ov::Tensor & zp, + size_t group_size, + bool use_bias) { + ov::Shape orig_shape = weight.get_shape(); + + // Expand dimensions for scales and zp/bias + auto scale_shape = scales.get_shape(); + auto zp_shape = zp.get_shape(); + bool is_scalar_zp = zp_shape.empty(); // Symmetric quantization + + ov::Shape packed_shape = {orig_shape[0], orig_shape[1] / group_size, group_size}; + + if (packed_shape[1] == 1) { + // Requantized channel-wise case + packed_shape.erase(packed_shape.begin() + 1); + } else { + scale_shape.push_back(1); + scales.set_shape(scale_shape); + // For symmetric quantization, zp remains scalar (don't resize) + if (!is_scalar_zp) { + zp_shape.push_back(1); + zp.set_shape(zp_shape); + } + } + + // Create graph nodes + auto weights_node = std::make_shared(ov::element::u8, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto scales_f16 = std::make_shared(scales); + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + + ov::Output result; + if (use_bias && !is_scalar_zp) { + // Bias path: w * s + b (zp tensor holds f16 bias values) + auto bias_f16 = std::make_shared(zp); + auto w_s = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Zero point path: (w - zp) * s + auto zero_point = std::make_shared(zp); + float zp_value; + if (ov::op::util::get_single_value(zero_point, zp_value)) { + zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + } + auto zero_point_f16 = std::make_shared(zero_point, ov::element::f16); + auto w_zp = + std::make_shared(weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); + } + + if (packed_shape.size() != 2) { + // If not requantized channel-wise case, reshape back to original shape + auto final_shape = + std::make_shared(ov::element::i64, ov::Shape{orig_shape.size()}, orig_shape); + result = std::make_shared(result, final_shape, false); + } + + return std::make_shared(result, ov::element::f32); +} + +ov::Output make_int4_weights(ov::Tensor & weight, + ov::Tensor & scales, + ov::Tensor & zp, + size_t group_size, + bool use_bias) { + ov::Shape orig_weight_shape = weight.get_shape(); + + // Expand dimensions for scales and zp/bias + ov::Shape scale_shape = scales.get_shape(); + auto zp_shape = zp.get_shape(); + bool is_scalar_zp = zp_shape.empty(); // Symmetric quantization + + // Create INT4 weight tensor + ov::Shape packed_shape = {orig_weight_shape[0], orig_weight_shape[1] / group_size, group_size}; + + if (packed_shape[1] == 1) { + // Requantized channel-wise case + packed_shape.erase(packed_shape.begin() + 1); + } else { + scale_shape.push_back(1); + scales.set_shape(scale_shape); + // For symmetric quantization, zp remains scalar (don't resize) + if (!is_scalar_zp) { + zp_shape.push_back(1); + zp.set_shape(zp_shape); + } + } + + auto weights_node = std::make_shared(ov::element::u4, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + auto scales_f16 = std::make_shared(scales); + + ov::Output result; + if (use_bias && !is_scalar_zp) { + // Bias path: w * s + b (zp tensor holds f16 bias values) + auto bias_f16 = std::make_shared(zp); + auto w_s = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Zero point path: (w - zp) * s + auto zero_points_node = std::make_shared(zp); + float zp_value; + if (ov::op::util::get_single_value(zero_points_node, zp_value)) { + zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); + } + auto zero_points_f16 = std::make_shared(zero_points_node, ov::element::f16); + auto w_zp = + std::make_shared(weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); + } + + if (packed_shape.size() != 2) { + // If not requantized channel-wise case, reshape back to original shape + auto final_shape = std::make_shared(ov::element::i64, ov::Shape{orig_weight_shape.size()}, + orig_weight_shape); + result = std::make_shared(result, final_shape, false); + } + + return std::make_shared(result, ov::element::f32); +} + +// Extract quantized weights from tensor and create weight subgraph +std::shared_ptr extract_quantized_weights(const ggml_tensor * tensor, + const void * data, + ov::Tensor & weights, + ov::Tensor & scales, + ov::Tensor & zp, + bool use_bias) { + // Create a temporary tensor for extraction functions that read from tensor->data + ggml_tensor temp_tensor = *tensor; + temp_tensor.data = const_cast(data); + + // Determine block size based on tensor type + int64_t weights_per_block; + bool is_u4; + switch (tensor->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + is_u4 = true; + weights_per_block = 32; + break; + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_K: + is_u4 = false; + weights_per_block = 32; + break; + case GGML_TYPE_Q6_K: + is_u4 = false; + weights_per_block = 16; + break; + default: + throw std::runtime_error("Unsupported quantized type for extraction: " + + std::string(ggml_type_name(tensor->type))); + } + + // Extract quantized data + switch (tensor->type) { + case GGML_TYPE_Q4_0: + extract_q4_0_data(&temp_tensor, weights, scales, zp); + break; + case GGML_TYPE_Q4_1: + extract_q4_1_data(&temp_tensor, weights, scales, zp, use_bias); + break; + case GGML_TYPE_Q4_K: + extract_q4_k_data(&temp_tensor, weights, scales, zp, use_bias); + break; + case GGML_TYPE_Q8_0: + extract_q8_0_data(&temp_tensor, weights, scales, zp); + break; + case GGML_TYPE_Q6_K: + extract_q6_k_data(&temp_tensor, weights, scales, zp); + break; + case GGML_TYPE_Q5_K: + extract_q5_k_data(&temp_tensor, weights, scales, zp, use_bias); + break; + default: + throw std::runtime_error("Unsupported quantized type: " + std::string(ggml_type_name(tensor->type))); + } + + // Create the OpenVINO weight subgraph + ov::Output weight_node; + if (is_u4) { + weight_node = make_int4_weights(weights, scales, zp, weights_per_block, use_bias); + } else { + weight_node = make_int8_weights(weights, scales, zp, weights_per_block, use_bias); + } + + auto result = weight_node.get_node_shared_ptr(); + result->set_friendly_name(tensor->name); + return result; +} + +// Requantize weights to target format, writing to provided buffers +std::shared_ptr requantize_to_buffers(const ggml_tensor * tensor, + const void * data, + ExtraQuantType requant_type, + int64_t block_size, + ov::Tensor & weights, + ov::Tensor & scales, + ov::Tensor & zp) { + int64_t n_elements = ggml_nelements(tensor); + + // First dequantize to F32 + std::vector weights_f32(n_elements); + ggml_get_type_traits(tensor->type)->to_float(data, weights_f32.data(), n_elements); + + // Handle F16 case - just convert and create constant + if (requant_type == ExtraQuantType::F16) { + ggml_get_type_traits(GGML_TYPE_F16)->from_float_ref(weights_f32.data(), weights.data(), n_elements); + auto result = std::make_shared(weights); + result->set_friendly_name(tensor->name); + return result; + } + + // Requantize to target quantized format + bool is_u4 = (requant_type == ExtraQuantType::Q4_0_C || requant_type == ExtraQuantType::Q4_0_128); + + if (is_u4) { + quantize_q4_0(weights_f32.data(), weights, scales, zp, n_elements, block_size); + } else if (requant_type == ExtraQuantType::Q8_1_C) { + quantize_q8_1(weights_f32.data(), weights, scales, zp, n_elements, block_size); + } else { + quantize_q8_0(weights_f32.data(), weights, scales, zp, n_elements, block_size); + } + + // Create the OpenVINO weight subgraph + ov::Output weight_node; + if (is_u4) { + weight_node = make_int4_weights(weights, scales, zp, block_size); + } else { + weight_node = make_int8_weights(weights, scales, zp, block_size); + } + + auto result = weight_node.get_node_shared_ptr(); + result->set_friendly_name(tensor->name); + return result; +} + +OvWeight process_weight_tensor(const ggml_tensor * tensor, const void * data, void * output_base_ptr, bool use_bias) { + GGML_ASSERT(tensor != nullptr); + GGML_ASSERT(data != nullptr); + + OvWeight result; + + // Get 2D shape for weights [rows, cols] + ov::Shape node_shape = {static_cast(tensor->ne[1]), static_cast(tensor->ne[0])}; + + // Handle F16/F32/BF16 weights + if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) { + ov::element::Type element_type; + switch (tensor->type) { + case GGML_TYPE_F32: + element_type = ov::element::f32; + break; + case GGML_TYPE_F16: + element_type = ov::element::f16; + break; + case GGML_TYPE_BF16: + element_type = ov::element::bf16; + break; + default: + OPENVINO_THROW("Unexpected tensor type in F16/F32/BF16 path"); + } + + if (output_base_ptr && output_base_ptr != data) { + // Using external buffer - copy data and create shared-memory constant + size_t tensor_bytes = ggml_nbytes(tensor); + memcpy(output_base_ptr, data, tensor_bytes); + result.weights = ov::Tensor(element_type, node_shape, output_base_ptr); + } else { + result.weights = ov::Tensor(element_type, node_shape, data); + } + result.weight_node = std::make_shared(result.weights); + return result; + } + + // Handle quantized weights + if (!ggml_is_quantized(tensor->type)) { + OPENVINO_THROW("Unsupported weight tensor type: ", ggml_type_name(tensor->type)); + } + + result.layout = ggml_openvino_get_extracted_layout(tensor, use_bias); + const auto & layout = result.layout; + if (layout.total_size == 0) { + OPENVINO_THROW("Unsupported quantized type: ", ggml_type_name(tensor->type)); + } + + if (use_bias) { + OPENVINO_ASSERT(!layout.is_requant, + "use_bias is only used for test-backend-ops, which should not have requantization"); + // bias node will be created on the fly and not use backend buffer + output_base_ptr = nullptr; + } + + // F16 requant path - no separate scales/zp needed in result + if (layout.is_requant && layout.requant_type.has_value() && layout.requant_type.value() == ExtraQuantType::F16) { + if (output_base_ptr) { + result.weights = ov::Tensor(ov::element::f16, node_shape, + static_cast(output_base_ptr) + layout.weights_offset); + } else { + result.weights = ov::Tensor(ov::element::f16, node_shape); + } + ov::Tensor dummy_scales, dummy_zp; // Not used for F16 + result.weight_node = + requantize_to_buffers(tensor, data, ExtraQuantType::F16, 0, result.weights, dummy_scales, dummy_zp); + return result; + } + + // Quantized path (normal extraction or quantized requant) + // Create weight/scale/zp tensors - shared between both paths + ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; + ov::Shape zp_shape = layout.is_symmetric ? ov::Shape{} : scale_shape; + + if (output_base_ptr) { + uint8_t * buf_base = static_cast(output_base_ptr); + result.weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); + result.scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); + result.zp = ov::Tensor(weight_type, zp_shape, buf_base + layout.zp_offset); + } else { + result.weights = ov::Tensor(weight_type, node_shape); + result.scales = ov::Tensor(ov::element::f16, scale_shape); + if (use_bias && !layout.is_symmetric) { + // bias only has effect for asymmetric quant + result.zp = ov::Tensor(ov::element::f16, zp_shape); + } else { + result.zp = ov::Tensor(weight_type, zp_shape); + } + } + + if (layout.is_requant && layout.requant_type.has_value()) { + result.weight_node = requantize_to_buffers(tensor, data, layout.requant_type.value(), layout.weights_per_block, + result.weights, result.scales, result.zp); + } else { + result.weight_node = + extract_quantized_weights(tensor, data, result.weights, result.scales, result.zp, use_bias); + } + + return result; +} + +void quantize_q4_0(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, + int64_t qk) { + assert(k % qk == 0); + const int nb = k / qk; + + auto * weights = static_cast(weights_arr.data()); + auto * scales = scales_arr.data::value_type>(); + auto * zp = static_cast(zp_arr.data()); + bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization + + // For Q4_0, zero point is always 8 + if (is_scalar_zp) { + zp[0] = 8 | (8 << 4); // Pack two 4-bit values + } + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -8; + + if (d == 0) { + scales[i] = ov::float16(1.0f); + // zp is already set to 8 for symmetric, or set per-block for asymmetric + if (!is_scalar_zp) { + if (i % 2 == 0) { + zp[i / 2] = 8; + } else { + zp[i / 2] |= (8 << 4); + } + } + memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2); + continue; + } + + const float id = 1.0f / d; + scales[i] = ov::float16(d); + // For asymmetric quantization, store per-block zero points + if (!is_scalar_zp) { + if (i % 2 == 0) { + zp[i / 2] = 8; + } else { + zp[i / 2] |= (8 << 4); + } + } + + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f)); + weights[i * qk / 2 + j] = xi0 | (xi1 << 4); + } + } +} + +void quantize_q8_0(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, + int64_t qk) { + assert(k % qk == 0); + const int nb = k / qk; + + auto * weights = static_cast(weights_arr.data()); + auto * scales = scales_arr.data::value_type>(); + auto * zp = static_cast(zp_arr.data()); + bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization + + // For Q8_0, zero point is always 128 + if (is_scalar_zp) { + zp[0] = 128; + } + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + } + } + + const float d = amax / 127.0f; + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); + // For asymmetric quantization, store per-block zero points + if (!is_scalar_zp) { + zp[i] = 128; + } + + for (int j = 0; j < qk; ++j) { + const float x0 = x[i * qk + j] * id; + const int8_t xi0 = roundf(x0); + weights[i * qk + j] = (uint8_t) (xi0 + 128); + } + } +} + +void quantize_q8_1(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, + int64_t qk) { + assert(k % qk == 0); + const int nb = k / qk; + + auto * weights = static_cast(weights_arr.data()); + auto * scales = scales_arr.data::value_type>(); + auto * zp = static_cast(zp_arr.data()); + for (int i = 0; i < nb; i++) { + float min = std::numeric_limits::max(); + float max = std::numeric_limits::lowest(); + + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + } + + const float d = (max - min) / ((1 << 8) - 1); + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); + // zp = -min / scale (Q8_1 is asymmetric) + zp[i] = (d != 0.0f) ? (uint8_t) std::round(-min / d) : 0; + + for (int j = 0; j < qk; ++j) { + const float x0 = (x[i * qk + j] - min) * id; + const uint8_t xi0 = roundf(x0); + weights[i * qk + j] = xi0; + } + } +} diff --git a/ggml/src/ggml-openvino/ggml-quants.h b/ggml/src/ggml-openvino/ggml-quants.h new file mode 100644 index 00000000000..e4a02297cae --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-quants.h @@ -0,0 +1,153 @@ +#pragma once +#include "ggml-openvino-extra.h" // For ExtraQuantType +#include "ggml.h" + +#include +#include +#include + +void unpack_32_4(const uint8_t* data, uint8_t* dst); + +void extract_q4_0_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr); + +void extract_q4_1_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + bool use_bias = false); + +void extract_q8_0_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr); + +void unpack_256_4(const uint8_t* data, uint8_t* dst); + +void extract_q4_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + bool use_bias = false); + +void extract_q5_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + bool use_bias = false); + +void extract_q6_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr); + +static constexpr size_t GGML_QUANTIZATION_GROUP_SIZE = 32; + +ov::Output make_int8_weights(ov::Tensor & weight, + ov::Tensor & scales, + ov::Tensor & zp, + size_t group_size = GGML_QUANTIZATION_GROUP_SIZE, + bool use_bias = false); + +ov::Output make_int4_weights(ov::Tensor & weight, + ov::Tensor & scales, + ov::Tensor & zp, + size_t group_size = GGML_QUANTIZATION_GROUP_SIZE, + bool use_bias = false); + +// Extract quantized weights from tensor and create weight subgraph +// If weights/scales/zp are provided (non-empty), uses them as output buffers +// Otherwise allocates new ov::Tensors internally +// Returns the weight node (make_int4_weights or make_int8_weights result) +std::shared_ptr extract_quantized_weights( + const ggml_tensor * tensor, + const void * data, // Source data pointer (may differ from tensor->data) + ov::Tensor & weights, + ov::Tensor & scales, + ov::Tensor & zp, + bool use_bias = false); // Use fp bias instead of quantized zero_point (for test-backend-ops) + +// Requantize weights from tensor to target format, writing to provided buffers +// For F16 target, only weights buffer is used (scales/zp ignored) +// Returns the weight node +std::shared_ptr requantize_to_buffers(const ggml_tensor * tensor, + const void * data, // Source data pointer + ExtraQuantType requant_type, + int64_t block_size, + ov::Tensor & weights, + ov::Tensor & scales, + ov::Tensor & zp); + +inline const char * extra_quant_type_name(ExtraQuantType t) { + switch (t) { + case ExtraQuantType::F16: + return "F16"; + case ExtraQuantType::Q4_0_C: + return "Q4_0_C"; + case ExtraQuantType::Q4_0_128: + return "Q4_0_128"; + case ExtraQuantType::Q8_0_C: + return "Q8_0_C"; + case ExtraQuantType::Q8_0_32: + return "Q8_0_32"; + case ExtraQuantType::Q8_1_C: + return "Q8_1_C"; + default: + return "unknown"; + } +} + +// Result from process_weight_tensor containing the weight node and tensors. +// For quantized weights, also contains the extracted layout and scale/zp tensors. +struct OvWeight { + std::shared_ptr weight_node; + ggml_openvino_extracted_layout layout; // Only meaningful for quantized (layout.total_size > 0) + ov::Tensor weights; + ov::Tensor scales; + ov::Tensor zp; + + bool is_quantized() const { return layout.scales_size > 0; } +}; + +// Process weight tensor and create an OpenVINO weight node +// Handles F16/F32/BF16 and quantized weights, with optional requantization +// If output_base_ptr is nullptr, allocates internal buffers (for decoder use) +// If output_base_ptr is provided, uses pre-allocated buffers at specified offsets (for backend buffer use) +// Returns OvWeight with the weight node and optional quantized tensors +OvWeight process_weight_tensor( + const ggml_tensor * tensor, + const void * data, // Source data pointer (may differ from tensor->data) + void * output_base_ptr = nullptr, // Base pointer for output buffers (or nullptr for internal allocation) + bool use_bias = false); // Use fp bias instead of quantized zero_point, only used in test-backend-ops + +void quantize_q4_0(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, + int64_t qk); +void quantize_q8_1(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, + int64_t qk); +void quantize_q8_0(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, + int64_t qk); + +namespace ov { +namespace op { +namespace util { +// From /src/common/transformations/include/transformations/utils/utils.hpp +bool get_single_value(const std::shared_ptr& const_node, + float& value, + bool check_value_range = true); +} // namespace util +} // namespace op +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/decoder.h b/ggml/src/ggml-openvino/openvino/decoder.h new file mode 100644 index 00000000000..3b8da2be5d2 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/decoder.h @@ -0,0 +1,74 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { + +class GgmlDecoder : public DecoderBase { +public: + virtual ov::Any get_attribute(const std::string& name) const = 0; + + virtual PartialShape get_input_shape(int node_idx, const std::string& name) const = 0; + + virtual std::vector get_input_stride(int node_idx, const std::string& name) const = 0; + + virtual element::Type get_input_type(int node_idx, const std::string& name) const = 0; + + virtual size_t get_input_size() const = 0; + + virtual size_t get_input_size(int node_idx) const = 0; + + virtual void get_input_node(size_t input_port_idx, + std::string& producer_name, + std::string& producer_output_port_name, + size_t& producer_output_port_index) const = 0; + + virtual std::vector get_input_names(int node_idx) const = 0; + + virtual PartialShape get_output_shape(int node_idx) const = 0; + + virtual element::Type get_output_type(const int node_idx) const = 0; + + virtual int32_t* get_input_op_params(int node_idx, const std::string& name) const = 0; + + virtual int32_t * get_output_op_params(int node_idx) const = 0; + + virtual std::vector get_output_names(int node_idx) const = 0; + + virtual const std::string& get_op_type() const = 0; + + virtual const std::string& get_op_type(int node_idx) const = 0; + + virtual const std::string& get_op_name() const = 0; + + virtual const std::string& get_op_name(int node_idx) const = 0; + + virtual void visit_subgraph(std::function, int node_idx)> node_visitor) const = 0; + + virtual int get_op_case(int node_idx) const = 0; + + virtual const std::map>& get_model_inputs() const = 0; + virtual const std::map>& get_model_extra_inputs() const = 0; + virtual const std::map>& get_model_weights() const = 0; + virtual std::vector get_model_output_names() const = 0; + + virtual int32_t* get_rope_params() const = 0; + + virtual std::map get_kv_param_res_names() const = 0; + + virtual bool is_static() const = 0; + + virtual bool is_stateful() const = 0; + + virtual int is_swa_layer(int layer) const = 0; +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/frontend.cpp b/ggml/src/ggml-openvino/openvino/frontend.cpp new file mode 100644 index 00000000000..c2ba14e66e6 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/frontend.cpp @@ -0,0 +1,27 @@ +#include "frontend.h" + +#include "input_model.h" +#include "op_table.h" +#include "translate_session.h" + +namespace ov { +namespace frontend { +namespace ggml { + +FrontEnd::FrontEnd() {} + +std::shared_ptr FrontEnd::convert(const InputModel::Ptr & model, bool naive) { + auto ggml_model = std::dynamic_pointer_cast(model); + FRONT_END_GENERAL_CHECK(ggml_model, "Invalid input model"); + std::shared_ptr converted_model; + const auto & supported_ops = get_supported_ops(); + { + TranslateSession translate_session(model, supported_ops, naive); + converted_model = translate_session.get_converted_model(); + } + return converted_model; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/frontend.h b/ggml/src/ggml-openvino/openvino/frontend.h new file mode 100644 index 00000000000..f1c6f0c3e3c --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/frontend.h @@ -0,0 +1,23 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace ov { +namespace frontend { +namespace ggml { + +class FrontEnd { +public: + using Ptr = std::shared_ptr; + FrontEnd(); + + static std::shared_ptr convert(const InputModel::Ptr& model, bool naive = false); +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/input_model.cpp b/ggml/src/ggml-openvino/openvino/input_model.cpp new file mode 100644 index 00000000000..39b004c9317 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/input_model.cpp @@ -0,0 +1,17 @@ +#include "input_model.h" + +#include "decoder.h" + +namespace ov { +namespace frontend { +namespace ggml { + +InputModel::InputModel(const std::shared_ptr & gdecoder) : m_decoder(gdecoder) {} + +const std::shared_ptr & InputModel::get_model_decoder() const { + return m_decoder; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/input_model.h b/ggml/src/ggml-openvino/openvino/input_model.h new file mode 100644 index 00000000000..ce8434426c9 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/input_model.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include "decoder.h" + +namespace ov { +namespace frontend { +namespace ggml { + +class FrontEnd; +class GgmlDecoder; +using ov::frontend::ggml::GgmlDecoder; + +class InputModel : public ov::frontend::InputModel { + friend class ::ov::frontend::ggml::FrontEnd; + +public: + explicit InputModel(const std::shared_ptr& gdecoder); + + const std::shared_ptr& get_model_decoder() const; + +private: + std::shared_ptr m_decoder; +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/node_context.h b/ggml/src/ggml-openvino/openvino/node_context.h new file mode 100644 index 00000000000..aa484128a95 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/node_context.h @@ -0,0 +1,112 @@ +#pragma once + +#include +#include +#include + +#include "decoder.h" + +namespace ov { +namespace frontend { +namespace ggml { + +class TranslateSession; + +typedef std::map> TensorMap; + +class NodeContext : public frontend::NodeContext { +public: + NodeContext(const std::shared_ptr& decoder, + std::shared_ptr& tensor_map, + int node_idx, + TranslateSession* translate_session = nullptr) + : ov::frontend::NodeContext(decoder->get_op_type(node_idx)), + m_decoder(decoder), + m_tensor_map(tensor_map), + m_node_idx(node_idx), + m_translate_session(translate_session) { + m_input_names = decoder->get_input_names(m_node_idx); + m_output_names = decoder->get_output_names(m_node_idx); + } + + TranslateSession* get_translate_session() const { + return m_translate_session; + } + + const std::vector& get_input_names() const { return m_input_names; } + + size_t get_input_size() const override { + return m_decoder->get_input_size(m_node_idx); + } + + ov::element::Type get_input_type(size_t index) const { + return m_decoder->get_input_type(m_node_idx, m_input_names[index]); + } + + PartialShape get_input_shape(size_t input_index) const { + return m_decoder->get_input_shape(m_node_idx, m_input_names[input_index]); + } + + std::vector get_input_stride(size_t index) const { + return m_decoder->get_input_stride(m_node_idx, m_input_names[index]); + } + + std::string get_output_name() const { return m_output_names[0]; } + + PartialShape get_output_shape() const { return m_decoder->get_output_shape(m_node_idx); } + + int32_t* get_input_op_params(size_t index) const { + return m_decoder->get_input_op_params(m_node_idx, m_input_names[index]); + } + + int32_t * get_output_op_params() const { return m_decoder->get_output_op_params(m_node_idx); } + + ov::element::Type get_output_type() const { + return m_decoder->get_output_type(m_node_idx); + } + + Output get_input(int idx) const override { + return m_tensor_map->at(m_input_names[idx]); + } + + Output get_input(const std::string& name) const override { + if (m_tensor_map->find(name) == m_tensor_map->end()) { + throw std::runtime_error("'" + name + "' not found in tensor map."); + } + return m_tensor_map->at(name); + } + + bool has_input(const std::string& name) const { + return m_tensor_map->find(name) != m_tensor_map->end(); + } + + const std::string& get_name() const override { + return m_decoder->get_op_name(m_node_idx); + } + + ov::Any get_attribute_as_any(const std::string& name) const override { + return m_decoder->get_attribute(name); + } + + int get_op_case() const { + return m_decoder->get_op_case(m_node_idx); + } + + bool is_static() const { return m_decoder->is_static(); } + + bool is_stateful() const { return m_decoder->is_stateful(); } + +private: + std::shared_ptr m_decoder; + std::shared_ptr& m_tensor_map; + int m_node_idx; + TranslateSession* m_translate_session; + std::vector m_input_names; + std::vector m_output_names; +}; + +using CreatorFunction = std::function; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/cont.cpp b/ggml/src/ggml-openvino/openvino/op/cont.cpp new file mode 100644 index 00000000000..6160dd74444 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/cont.cpp @@ -0,0 +1,48 @@ + +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_cont(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + int op_case = context.get_op_case(); + FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported CONT case"); + + auto src_shape = context.get_input_shape(0).to_shape(); + auto dst_shape = context.get_output_shape().to_shape(); + ov::Output res; + + if (op_case == 1) { + // The input comes from a PERMUTE + throw std::runtime_error("Code of this case might be outdated"); + dst_shape[1] = -1; + res = std::make_shared( + context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {dst_shape.size()}, dst_shape), false); + } else if (op_case == 2) { + // The input comes from a TRANSPOSE + return {context.get_input(0)}; + } else { + // The input comes from a VIEW + res = process_view_input(context, 0); + } + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/cpy.cpp b/ggml/src/ggml-openvino/openvino/op/cpy.cpp new file mode 100644 index 00000000000..831117208be --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/cpy.cpp @@ -0,0 +1,21 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_cpy(const NodeContext & context) { + auto res = std::make_shared(context.get_input(0), context.get_output_type()); + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp b/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp new file mode 100644 index 00000000000..42602a730a4 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp @@ -0,0 +1,90 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_flash_attn_ext(const NodeContext & context) { + num_inputs_check(context, 4, 4); + auto q_f32 = context.get_input(0); + auto k = context.get_input(1); + auto v = context.get_input(2); + auto mask = context.get_input(3); + + float * params = reinterpret_cast(context.get_output_op_params()); + float scale = params[0]; + // float max_bias = params[1]; + // float logit_softcap = params[2]; + + auto q = std::make_shared(q_f32, ov::element::f16); + auto scale_node = std::make_shared(ov::element::f16, ov::Shape{}, std::vector{scale}); + + ov::Output mask_sliced, res; + std::string mask_name = "KQ_mask_sliced"; + if (context.get_input_names()[3].find("swa") != std::string::npos) { + mask_name = "KQ_mask_swa_sliced"; + } + if (context.has_input(mask_name)) { + mask_sliced = context.get_input(mask_name); + } else { + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); + auto token_len = get_dimensions(q, {2}); + mask_sliced = std::make_shared(mask, zero, token_len, one, two); + } + + if (mask_sliced.get_element_type() != ov::element::f16) { + mask_sliced = std::make_shared(mask_sliced, ov::element::f16); + } + + auto tile_kv = [&](int64_t num_heads, int64_t num_heads_kv, int64_t head_size, ov::Output kv) { + int64_t factor = num_heads / num_heads_kv; + if (factor > 1 && num_heads_kv > 1) { + ov::Output kv_broadcast_shape, kv_unsqueezed, new_kv_shape; + auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2}); + kv_unsqueezed = std::make_shared(kv, unsqueeze_axes); + + kv_broadcast_shape = ov::op::v0::Constant::create( + ov::element::i64, {5}, {(int64_t) 1, (int64_t) 1, factor, (int64_t) 1, (int64_t) 1}); + new_kv_shape = + ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 0, num_heads, (int64_t) -1, head_size}); + + kv = std::make_shared(kv_unsqueezed, kv_broadcast_shape, + ov::op::BroadcastType::BIDIRECTIONAL); + kv = std::make_shared(kv, new_kv_shape, true); + } + return kv; + }; + + auto q_shape = context.get_input_shape(0).to_shape(); + auto k_shape = context.get_input_shape(1).to_shape(); + k = tile_kv(q_shape[1], k_shape[1], q_shape[3], k); + v = tile_kv(q_shape[1], k_shape[1], q_shape[3], v); + + auto sdpa = std::make_shared(q, k, v, mask_sliced, scale_node, false); + res = std::make_shared(sdpa, + ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3})); + res = std::make_shared(res, ov::element::f32); + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp new file mode 100644 index 00000000000..49f51b7ca3f --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp @@ -0,0 +1,69 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_get_rows(const NodeContext & context) { + num_inputs_check(context, 2, 2); + + int op_case = context.get_op_case(); + + Output res; + auto data = context.get_input(0); + auto indices = context.get_input(1); + + if (op_case == 2) { + // The input comes from a VIEW + indices = process_view_input(context, 1); + } + + // data[1,b,x,y] ind[1,1,b,x'] test-backend-ops case + // data[x,y] ind[1,1,1,x'] normal case + indices = + std::make_shared(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1})); + if (data.get_partial_shape().rank() == 4) { + if (!(data.get_partial_shape()[1].is_dynamic()) && data.get_partial_shape()[1].get_length() == 1) { + // Work-around for a bug in ov cpu plugin for test-backend-ops + data = std::make_shared(data, + ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1})); + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + res = std::make_shared(data, indices, axis); + } else { + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); + data = + std::make_shared(data, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + res = std::make_shared(data, indices, axis, 1); + } + } else if (context.is_stateful() && data.get_partial_shape().rank() == 3) { + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); + res = std::make_shared(data, indices, axis, 1); + } else { + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + res = std::make_shared(data, indices, axis); + } + + if (res.get_element_type() != context.get_output_type()) { + res = std::make_shared(res, context.get_output_type()); + } + if (!(context.is_stateful())) { + res = std::make_shared(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + } + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp b/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp new file mode 100644 index 00000000000..d9fa4c24367 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp @@ -0,0 +1,61 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_glu_geglu(const NodeContext & context) { + num_inputs_check(context, 1, 2); + + ov::Output src0; + ov::Output src1; + if (context.get_input_size() == 2) { + src0 = context.get_input(0); + src1 = context.get_input(1); + } else { + // GGML splits along ne[0] (OV last axis) using floor division: nc = ne[0] / 2. + // Both halves are nc elements; if the dimension is odd, the last element is dropped. + // Use Slice instead of Split to handle odd dimensions correctly. + auto combined = context.get_input(0); + auto combined_shape = combined.get_partial_shape(); + int64_t last_dim_val = combined_shape[combined_shape.rank().get_length() - 1].get_length(); + int64_t nc = last_dim_val / 2; + + auto axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); + auto step = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto start0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto stop0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc}); + auto start1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc}); + auto stop1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {2 * nc}); + + src0 = std::make_shared(combined, start0, stop0, step, axis); + src1 = std::make_shared(combined, start1, stop1, step, axis); + } + + int32_t * params = context.get_output_op_params(); + const int32_t swapped = params[1]; + if (swapped) { + std::swap(src0, src1); + } + + auto gelu = std::make_shared(src0); + auto res = std::make_shared(gelu, src1); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp b/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp new file mode 100644 index 00000000000..00ed7951a03 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp @@ -0,0 +1,62 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_glu_swiglu(const NodeContext & context) { + num_inputs_check(context, 1, 2); + + ov::Output src0; + ov::Output src1; + if (context.get_input_size() == 2) { + src0 = context.get_input(0); + src1 = context.get_input(1); + } else { + // GGML splits along ne[0] (OV last axis) using floor division: nc = ne[0] / 2. + // Both halves are nc elements; if the dimension is odd, the last element is dropped. + // Use Slice instead of Split to handle odd dimensions correctly. + auto combined = context.get_input(0); + auto combined_shape = combined.get_partial_shape(); + int64_t last_dim_val = combined_shape[combined_shape.rank().get_length() - 1].get_length(); + int64_t nc = last_dim_val / 2; + + auto axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); + auto step = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto start0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto stop0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc}); + auto start1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc}); + auto stop1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {2 * nc}); + + src0 = std::make_shared(combined, start0, stop0, step, axis); + src1 = std::make_shared(combined, start1, stop1, step, axis); + } + + int32_t * params = context.get_output_op_params(); + const int32_t swapped = params[1]; + if (swapped) { + std::swap(src0, src1); + } + + auto sigmoid = std::make_shared(src0); + auto silu = std::make_shared(src0, sigmoid); + auto res = std::make_shared(silu, src1); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/mulmat.cpp b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp new file mode 100644 index 00000000000..38edec85ddf --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp @@ -0,0 +1,90 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_mulmat(const NodeContext & context) { + num_inputs_check(context, 2, 2); + + int op_case = context.get_op_case(); + + ov::Output res; + ov::Output B = context.get_input(0); + ov::Output A = context.get_input(1); + + bool transpose_b = true; + if (op_case == 2) { + B = B.get_node_shared_ptr()->input_value(0); + transpose_b = false; + } else if (op_case == 3) { + B = process_view_input(context, 0); + A = process_view_input(context, 1); + } + if (A.get_element_type() != B.get_element_type()) { + B = std::make_shared(context.get_input(0), context.get_input_type(1)); + } + + auto B_shape = context.get_input_shape(0).to_shape(); + auto A_shape = context.get_input_shape(1).to_shape(); + int64_t A_batch = A_shape[1]; + int64_t B_batch = B_shape[1]; + + auto A_batch_larger = A_batch > B_batch; + auto batch_large = A_batch_larger ? A_batch : B_batch; + auto batch_small = A_batch_larger ? B_batch : A_batch; + + Output Z = A_batch_larger ? B : A; + int64_t factor = batch_large / batch_small; + if (factor > 1 && batch_small > 1) { + auto batch_large_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{batch_large}); + auto batch_small_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{batch_small}); + auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector{factor}); + + auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2}); + auto Z_unsqueezed = std::make_shared(Z, unsqueeze_axes); + + auto broadcast_shape = ov::op::v0::Constant::create( + ov::element::i64, {5}, {(int64_t) 1, (int64_t) 1, factor, (int64_t) 1, (int64_t) 1}); + auto new_Z_shape = ov::op::v0::Constant::create(ov::element::i64, {4}, + {(int64_t) 0, batch_large, (int64_t) -1, (int64_t) A_shape[3]}); + + auto Z_broadcasted = std::make_shared(Z_unsqueezed, broadcast_shape, + ov::op::BroadcastType::BIDIRECTIONAL); + Z = std::make_shared(Z_broadcasted, new_Z_shape, true); + } + if (A_batch_larger) { + B = Z; + } else { + A = Z; + } + + res = std::make_shared(A, B, false, transpose_b); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/permute.cpp b/ggml/src/ggml-openvino/openvino/op/permute.cpp new file mode 100644 index 00000000000..4c800f9ee4f --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/permute.cpp @@ -0,0 +1,102 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_permute(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + int op_case = context.get_op_case(); + FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4, + "Unsupported PERMUTE case"); + + ov::Output res; + auto src = context.get_input(0); + auto perm = ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}); + + if (op_case == 1 || context.is_stateful()) { + res = std::make_shared(src, perm); + } else if (op_case == 4) { + auto output_shape = context.get_output_shape().to_shape(); + auto n_heads = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[1]}); + auto head_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]}); + auto n_seq_active = context.has_input("n_seq_active") ? + context.get_input("n_seq_active") : + ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[0]}); + auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); + + auto new_shape = + std::make_shared(ov::OutputVector{n_seq_active, neg_one, n_heads, head_size}, 0); + + // // Alternative + // auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + // auto new_shape = std::make_shared(ov::OutputVector{n_seq_active, neg_one, zero, zero}, 0); + + auto reshaped = std::make_shared(src, new_shape, true); + res = std::make_shared(reshaped, perm); + } else { + auto cache_shape = src.get_partial_shape(); + auto output_shape = context.get_output_shape().to_shape(); + int64_t head_size = output_shape[3]; + int64_t n_heads = output_shape[1]; + int64_t ctx_per_seq = cache_shape[2].is_static() ? cache_shape[2].get_length() : -1; + int64_t n_seq = cache_shape[1].get_length(); + + Output attention_size; + if (!context.has_input("attention_size")) { + attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[2]}); + } else if (op_case == 2) { + attention_size = context.get_input("attention_size"); + } else { + attention_size = context.get_input("attention_size_swa"); + } + + Output seq_active_start; + Output seq_active_end; + if (context.has_input("seq_active_start")) { + seq_active_start = context.get_input("seq_active_start"); + seq_active_end = context.get_input("seq_active_end"); + } else { + int64_t n_seq_active = output_shape[0]; + size_t offset = *((size_t *) context.get_input_op_params(0)); + int64_t seq_active_start_val = offset / context.get_input_stride(0)[0]; + int64_t seq_active_end_val = seq_active_start_val + n_seq_active; + seq_active_start = ov::op::v0::Constant::create(ov::element::i64, {1}, {seq_active_start_val}); + seq_active_end = ov::op::v0::Constant::create(ov::element::i64, {1}, {seq_active_end_val}); + } + + // 1. reshape to [n_seq, ctx_per_seq, n_heads, head_size] + // 2. slice out the active sequences + // 3. slice out the attention part in each sequence + // 4. permute + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + + auto src_reshaped = std::make_shared( + src, ov::op::v0::Constant::create(ov::element::i64, {4}, {n_seq, ctx_per_seq, n_heads, head_size}), false); + auto slice1 = std::make_shared(src_reshaped, seq_active_start, seq_active_end, one, zero); + auto slice2 = std::make_shared(slice1, zero, attention_size, one, one); + res = std::make_shared(slice2, perm); + } + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/reshape.cpp b/ggml/src/ggml-openvino/openvino/op/reshape.cpp new file mode 100644 index 00000000000..efd9a5a860a --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/reshape.cpp @@ -0,0 +1,83 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_reshape(const NodeContext & context) { + num_inputs_check(context, 1, 1); + if (context.get_input_shape(0) == context.get_output_shape()) { + return {context.get_input(0)}; + } + + int op_case = context.get_op_case(); + FRONT_END_CHECK_IMPLEMENTED( + op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4 || op_case == 5 || op_case == 6, + "Unsupported RESHAPE case"); + + auto output_shape = context.get_output_shape().to_shape(); + std::shared_ptr new_shape_node; + if (op_case == 1) { + if (context.is_stateful()) { + new_shape_node = ov::op::v0::Constant::create( + ov::element::i64, {3}, + std::vector{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + } else { + new_shape_node = ov::op::v0::Constant::create( + ov::element::i64, {4}, + std::vector{(int64_t) output_shape[0], -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + } + } else if (op_case == 2) { + new_shape_node = ov::op::v0::Constant::create( + ov::element::i64, {4}, + std::vector{(int64_t) output_shape[0], (int64_t) output_shape[1], -1, (int64_t) output_shape[3]}); + + } else if (op_case == 3) { + throw std::runtime_error("might be outdated RESHAPE case"); + new_shape_node = ov::op::v0::Constant::create( + ov::element::i64, {4}, std::vector{(int64_t) output_shape[0], (int64_t) output_shape[1], -1, 1}); + + } else if (op_case == 4) { + return {context.get_input(0).get_node_shared_ptr()->input_value(0)}; + + } else if (op_case == 5) { + if (context.is_stateful()) { + std::vector shape_vec = {1, -1, (int64_t) context.get_output_shape().to_shape()[3]}; + new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {3}, shape_vec); + } else { + std::vector shape_vec = {1, 1, -1, (int64_t) context.get_output_shape().to_shape()[3]}; + new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, shape_vec); + } + + // // Alternative + // auto token_len = context.get_input("token_len"); + // auto emb_size = + // ov::op::v0::Constant::create(ov::element::i64, {1}, {(int64_t) context.get_output_shape().to_shape()[3]}); + // auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + // new_shape_node = std::make_shared(ov::OutputVector{one, one, token_len, emb_size}, 0); + + } else if (op_case == 6) { + new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, context.get_output_shape().to_shape()); + } + auto res = std::make_shared(context.get_input(0), new_shape_node, false); + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp b/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp new file mode 100644 index 00000000000..72cf92283e9 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp @@ -0,0 +1,46 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_rms_norm(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + auto input_node = context.get_input(0); + auto square = std::make_shared( + input_node, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {2.0f})); + + auto mean = std::make_shared( + square, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}), true); + + float eps; + memcpy(&eps, context.get_output_op_params(), sizeof(float)); + + auto rms = std::make_shared( + std::make_shared(mean, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {eps}))); + + auto reciprocal = + std::make_shared(ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {1.0f}), rms); + + auto res = std::make_shared(input_node, reciprocal); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp new file mode 100644 index 00000000000..26dc2d24f82 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp @@ -0,0 +1,123 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_rope(const NodeContext & context) { + num_inputs_check(context, 2, 3); + + int op_case = context.get_op_case(); + + ov::Output res; + + auto data_node = context.get_input(0).get_node_shared_ptr(); + auto output_shape = context.get_output_shape().to_shape(); + int32_t * op_params = context.get_output_op_params(); + + Output cos_theta_node; + Output sin_theta_node; + if (context.has_input("rope_cos")) { + cos_theta_node = context.get_input("rope_cos"); + sin_theta_node = context.get_input("rope_sin"); + } else { + auto inp_pos = context.get_input(1).get_node_shared_ptr(); + std::shared_ptr rope_freqs_weight; + if (context.get_input_size() == 3) { + rope_freqs_weight = context.get_input(2).get_node_shared_ptr(); + } + auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight); + sin_theta_node = sin_cos.first; + cos_theta_node = sin_cos.second; + } + + if (op_case == 2) { + // The input comes from a VIEW + int slice_len = output_shape[2] * output_shape[3]; + data_node = process_view_input(context, 0, slice_len).get_node_shared_ptr(); + if (context.is_stateful()) { + auto data_shape = ov::op::v0::Constant::create( + ov::element::i64, {3}, std::vector{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + data_node = std::make_shared(data_node, data_shape, false); + } else { + auto data_shape = ov::op::v0::Constant::create( + ov::element::i64, {4}, std::vector{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + data_node = std::make_shared(data_node, data_shape, false); + } + } + + const int mode = op_params[2]; + constexpr int ROPE_TYPE_NORMAL = 0; + constexpr int ROPE_TYPE_NEOX = 2; + + if (mode == ROPE_TYPE_NORMAL) { + auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); + auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]}); + Output even_slice; + Output odd_slice; + int32_t unsqueeze_dim = context.is_stateful() ? 3 : 4; + even_slice = std::make_shared(data_node, zero, end, two, neg_one); + odd_slice = std::make_shared(data_node, one, end, two, neg_one); + + Output first_half = + std::make_shared(std::make_shared(even_slice, cos_theta_node), + std::make_shared(odd_slice, sin_theta_node)); + Output second_half = + std::make_shared(std::make_shared(even_slice, sin_theta_node), + std::make_shared(odd_slice, cos_theta_node)); + + first_half = std::make_shared(first_half, + ov::op::v0::Constant::create(ov::element::i64, {1}, {unsqueeze_dim})); + second_half = std::make_shared(second_half, + ov::op::v0::Constant::create(ov::element::i64, {1}, {unsqueeze_dim})); + auto stack = std::make_shared(OutputVector{first_half, second_half}, unsqueeze_dim); + + auto data_shape = ov::op::v0::Constant::create( + ov::element::i64, {4}, std::vector{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + res = std::make_shared(stack, data_shape, false); + } else if (mode == ROPE_TYPE_NEOX) { + auto data_split = std::make_shared( + data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2); + Output slice_data_node_0 = data_split->outputs()[0]; + Output slice_data_node_1 = data_split->outputs()[1]; + + auto first_half_node = std::make_shared( + std::make_shared(slice_data_node_0, cos_theta_node), + std::make_shared(slice_data_node_1, sin_theta_node)); + + auto second_half_node = std::make_shared( + std::make_shared(slice_data_node_0, sin_theta_node), + std::make_shared(slice_data_node_1, cos_theta_node)); + + res = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, -1); + } + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/scale.cpp b/ggml/src/ggml-openvino/openvino/op/scale.cpp new file mode 100644 index 00000000000..0f3d800c199 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/scale.cpp @@ -0,0 +1,41 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_scale(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + float scale; + float bias; + memcpy(&scale, (float *) context.get_output_op_params() + 0, sizeof(float)); + memcpy(&bias, (float *) context.get_output_op_params() + 1, sizeof(float)); + + auto scale_node = std::make_shared(ov::element::f32, ov::Shape{}, std::vector{scale}); + auto scaled = std::make_shared(context.get_input(0), scale_node); + + std::shared_ptr res; + if (bias != 0.0f) { + auto bias_node = + std::make_shared(ov::element::f32, ov::Shape{}, std::vector{bias}); + res = std::make_shared(scaled, bias_node); + } else { + res = scaled; + } + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/set_rows.cpp b/ggml/src/ggml-openvino/openvino/op/set_rows.cpp new file mode 100644 index 00000000000..136e4265b42 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/set_rows.cpp @@ -0,0 +1,76 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_set_rows(const NodeContext & context) { + num_inputs_check(context, 3, 3); + + auto data = context.get_input(0); + auto indices = context.get_input(1); + auto dst = context.get_input(2); + + data = std::make_shared(data, context.get_output_type()); + + auto dst_shape = context.get_output_shape().to_shape(); + + auto ind_squeezed = + std::make_shared(indices, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 1, 2})); + auto data_reshaped = std::make_shared( + data, + ov::op::v0::Constant::create(ov::element::i64, {4}, + {(int64_t) 1, (int64_t) 1, (int64_t) -1, (int64_t) dst_shape[3]}), + false); + auto axes = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2}); + + Output res; + if (context.is_stateful()) { + int concat_axis = 1; + int64_t dim2 = dst.get_partial_shape()[2].get_length(); + int64_t dim3 = dst.get_partial_shape()[3].get_length(); + data = std::make_shared( + data, ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 1, (int64_t) -1, dim2, dim3}), false); + res = std::make_shared(OutputVector{dst, data}, concat_axis); + } else { + res = std::make_shared(dst, ind_squeezed, data_reshaped, axes); + } + + if (auto dst_reshape = std::dynamic_pointer_cast(dst.get_node_shared_ptr())) { + // Fix the case of multiple sequences, reshape back to original shape [1, n_seq, ctx_per_seq, emb] + // ctx_per_seq is not fixed due to llama-bench compatibility + auto dst_shape_partial = dst_reshape->get_input_partial_shape(0); + std::vector dst_shape = {dst_shape_partial[0].get_length(), dst_shape_partial[1].get_length(), + dst_shape_partial[2].is_static() ? dst_shape_partial[2].get_length() : -1, + dst_shape_partial[3].get_length()}; + res = std::make_shared(res, ov::op::v0::Constant::create(ov::element::i64, {4}, dst_shape), + false); + } + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/softmax.cpp b/ggml/src/ggml-openvino/openvino/op/softmax.cpp new file mode 100644 index 00000000000..9f6330862be --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/softmax.cpp @@ -0,0 +1,89 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_soft_max(const NodeContext & context) { + // TODO code is outdated + num_inputs_check(context, 1, 2); + + auto input_node = context.get_input(0).get_node_shared_ptr(); + ov::Output res; + + float scale = 1.0f; + float max_bias = 0.0f; + auto * op_params = context.get_output_op_params(); + memcpy(&scale, (float *) op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) op_params + 1, sizeof(float)); + auto src0_shape = context.get_input_shape(0).get_shape(); + const uint32_t h = src0_shape[2]; + const uint32_t n_head = src0_shape[0]; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + auto scale_node = std::make_shared(ov::element::f32, ov::Shape{}, std::vector{scale}); + auto scaled_input = std::make_shared(input_node, scale_node); + + if (context.get_input_size() < 2) { + res = std::make_shared(scaled_input, 2); + return rename_outputs_with_suffix({res}, context.get_name()); + } + + ov::Output mask_node_sliced; + if (context.has_input("KQ_mask_sliced")) { + mask_node_sliced = context.get_input("KQ_mask_sliced"); + } else { + auto token_len = get_dimensions(input_node, {1}); + auto mask_node = context.get_input(1); + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + mask_node_sliced = std::make_shared(mask_node, zero, token_len, one, one); + } + + if (mask_node_sliced.get_element_type() != context.get_output_type()) { + mask_node_sliced = std::make_shared(mask_node_sliced, context.get_output_type()); + } + + Output slope_mask; + if (slope != 1.0f) { + auto slope_node = + std::make_shared(ov::element::f32, ov::Shape{}, std::vector{slope}); + slope_mask = std::make_shared(mask_node_sliced, slope_node); + throw std::runtime_error("Slope != 1.0f in softmax has not been tested, verify it before use."); + } + slope_mask = mask_node_sliced; + + auto input_slope_mask_node = std::make_shared(scaled_input, slope_mask); + + res = std::make_shared(input_slope_mask_node, 2); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/transpose.cpp b/ggml/src/ggml-openvino/openvino/op/transpose.cpp new file mode 100644 index 00000000000..8e62e83c0d7 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/transpose.cpp @@ -0,0 +1,23 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_transpose(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + auto res = std::make_shared( + context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 1, 3, 2})); + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp b/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp new file mode 100644 index 00000000000..037e0b94df1 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp @@ -0,0 +1,27 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_unary_silu(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + auto input = context.get_input(0); + auto sigmoid = std::make_shared(input); + auto res = std::make_shared(input, sigmoid); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/view.cpp b/ggml/src/ggml-openvino/openvino/op/view.cpp new file mode 100644 index 00000000000..8528d252336 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/view.cpp @@ -0,0 +1,53 @@ +#include "../op_table.h" +#include "../utils.h" +#include +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_view(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + if (context.get_op_case() == 2) { + auto dst_shape = context.get_output_shape().to_shape(); + return rename_outputs_with_suffix({process_view_input(context, 0, dst_shape[2] * dst_shape[3])}, + context.get_name()); + } + // op_case 3 + if (context.get_op_case() == 3) { + auto input = context.get_input(0); + auto input_ov_shape = input.get_partial_shape(); + + auto input_llama_shape = context.get_input_shape(0).to_shape(); + + // if the input ov shape size is different from the input llama shape size, it means the input is already reshaped and we need to reshape it back to the original shape before slicing + if (input_ov_shape.size() != input_llama_shape.size()) { + input = std::make_shared(input, ov::op::v0::Constant::create(ov::element::i64, {input_llama_shape.size()}, input_llama_shape), false); + } + + auto dst_shape = context.get_output_shape().to_shape(); + + // find the index of dst_shape that is different from input shape, and use that index to slice the input + int slice_dim = -1; + for (size_t i = 0; i < dst_shape.size(); ++i) { + if (dst_shape[i] != input_llama_shape[i]) { + slice_dim = i; + break; + } + } + + auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {dst_shape[slice_dim]}); + auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_dim}); + auto sliced = std::make_shared(input, begin, end, stride, axes); + return {sliced}; + } + return {context.get_input(0)}; +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op_table.cpp b/ggml/src/ggml-openvino/openvino/op_table.cpp new file mode 100644 index 00000000000..beadafe8103 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op_table.cpp @@ -0,0 +1,46 @@ +#include "op_table.h" + +#include "utils.h" + +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { + +std::unordered_map get_supported_ops() { + using namespace ov::op; + return { + {"GGML_OP_ADD", op::translate_1to1_match_2_inputs }, + {"GGML_OP_ADD1", op::translate_1to1_match_2_inputs }, + {"GGML_OP_CONT", op::translate_cont }, + {"GGML_OP_DIV", op::translate_1to1_match_2_inputs }, + {"GGML_OP_GET_ROWS", op::translate_get_rows }, + {"GGML_OP_MUL", op::translate_1to1_match_2_inputs}, + {"GGML_OP_MUL_MAT", op::translate_mulmat }, + {"GGML_OP_PERMUTE", op::translate_permute }, + {"GGML_OP_RESHAPE", op::translate_reshape }, + {"GGML_OP_RMS_NORM", op::translate_rms_norm }, + {"GGML_OP_ROPE", op::translate_rope }, + {"GGML_OP_SCALE", op::translate_scale }, + {"GGML_OP_SOFT_MAX", op::translate_soft_max }, + {"GGML_OP_SUB", op::translate_1to1_match_2_inputs}, + {"GGML_OP_TRANSPOSE", op::translate_transpose }, + {"GGML_UNARY_OP_SILU", op::translate_unary_silu }, + {"GGML_OP_VIEW", op::translate_view }, + {"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu }, + {"GGML_GLU_OP_GEGLU", op::translate_glu_geglu }, + {"GGML_OP_SET_ROWS", op::translate_set_rows }, + {"GGML_OP_CPY", op::translate_cpy }, + {"GGML_OP_FLASH_ATTN_EXT", op::translate_flash_attn_ext }, + }; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op_table.h b/ggml/src/ggml-openvino/openvino/op_table.h new file mode 100644 index 00000000000..37f763117aa --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op_table.h @@ -0,0 +1,39 @@ +#pragma once + +#include "node_context.h" + +namespace ov { +namespace frontend { +namespace ggml { + +namespace op { + +#define GGML_OP_CONVERTER(op) OutputVector op(const NodeContext& context) + +GGML_OP_CONVERTER(translate_add); +GGML_OP_CONVERTER(translate_cont); +GGML_OP_CONVERTER(translate_get_rows); +GGML_OP_CONVERTER(translate_mul); +GGML_OP_CONVERTER(translate_mulmat); +GGML_OP_CONVERTER(translate_permute); +GGML_OP_CONVERTER(translate_reshape); +GGML_OP_CONVERTER(translate_rms_norm); +GGML_OP_CONVERTER(translate_rope); +GGML_OP_CONVERTER(translate_scale); +GGML_OP_CONVERTER(translate_unary_silu); +GGML_OP_CONVERTER(translate_soft_max); +GGML_OP_CONVERTER(translate_transpose); +GGML_OP_CONVERTER(translate_view); +GGML_OP_CONVERTER(translate_glu_swiglu); +GGML_OP_CONVERTER(translate_glu_geglu); +GGML_OP_CONVERTER(translate_set_rows); +GGML_OP_CONVERTER(translate_cpy); +GGML_OP_CONVERTER(translate_flash_attn_ext); + +} // namespace op + +std::unordered_map get_supported_ops(); + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp new file mode 100644 index 00000000000..ed2a3ab6d1b --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp @@ -0,0 +1,123 @@ +#include "eliminate_zp.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace pass { + +EliminateZeroPoints::EliminateZeroPoints() { + // Find pattern: + // (Multiply Any(scale) + // (Subtract (Convert Constant(data))) + // (Convert Constant(zero_point))) + // where zero_point is a scalar + // If data is u4 and zp value is 8 (q4_0), Replace the Subtract with an i4 Constant whose value is data - zp_val + // If data is u8 and zp value is 128 (q8_0) or 32 (q6_k), Replace the Subtract with an i8 Constant + + auto m_data_constant = ov::pass::pattern::wrap_type(); + auto m_data_convert = ov::pass::pattern::wrap_type({m_data_constant}); + + auto m_zp_constant = ov::pass::pattern::wrap_type(); + auto m_zp_convert = ov::pass::pattern::wrap_type({m_zp_constant}); + + auto m_subtract = ov::pass::pattern::wrap_type({m_data_convert, m_zp_convert}); + auto m_scale = ov::pass::pattern::any_input(); + auto m_multiply = ov::pass::pattern::wrap_type({m_scale, m_subtract}); + + const auto callback = [=](ov::pass::pattern::Matcher & m) { + const auto & pattern_map = m.get_pattern_value_map(); + + auto multiply_node = + std::dynamic_pointer_cast(pattern_map.at(m_multiply).get_node_shared_ptr()); + auto subtract_node = + std::dynamic_pointer_cast(pattern_map.at(m_subtract).get_node_shared_ptr()); + auto data_constant = + std::dynamic_pointer_cast(pattern_map.at(m_data_constant).get_node_shared_ptr()); + auto zp_constant = + std::dynamic_pointer_cast(pattern_map.at(m_zp_constant).get_node_shared_ptr()); + + if (!multiply_node || !subtract_node || !data_constant || !zp_constant) { + return false; + } + + if (ov::shape_size(zp_constant->get_shape()) != 1) { + return false; + } + + auto data_type = data_constant->get_element_type(); + auto zp_data = zp_constant->cast_vector(); + + if (zp_data.empty()) { + return false; + } + + int zp_value = zp_data[0]; + + bool should_eliminate = false; + ov::element::Type target_type; + + if (data_type == ov::element::u4 && zp_value == 8) { + should_eliminate = true; + target_type = ov::element::i4; + } else if (data_type == ov::element::u8 && (zp_value == 128 || zp_value == 32)) { + should_eliminate = true; + target_type = ov::element::i8; + } + + if (!should_eliminate) { + return false; + } + + auto data_shape = data_constant->get_shape(); + size_t total_elements = ov::shape_size(data_shape); + + std::shared_ptr new_constant; + + // TODO improve performance + if (data_type == ov::element::u4) { + auto data_values = data_constant->cast_vector(); + std::vector adjusted_values(total_elements); + + ov::parallel_for(total_elements, [&](size_t i) { + adjusted_values[i] = static_cast(static_cast(data_values[i]) - 8); + }); + + new_constant = std::make_shared(target_type, data_shape, adjusted_values); + } else if (data_type == ov::element::u8) { + auto data_values = data_constant->cast_vector(); + std::vector adjusted_values(total_elements); + + ov::parallel_for(total_elements, [&, zp_value](size_t i) { + adjusted_values[i] = static_cast(static_cast(data_values[i]) - zp_value); + }); + + new_constant = std::make_shared(target_type, data_shape, adjusted_values); + } + + auto new_convert = + std::make_shared(new_constant, subtract_node->get_output_element_type(0)); + ov::replace_node(subtract_node, new_convert); + + return true; + }; + + register_matcher( + std::make_shared(m_multiply, "ov::frontend::ggml::pass::EliminateZeroPoints"), + callback); +} + +} // namespace pass +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h new file mode 100644 index 00000000000..edd3cd718d9 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h @@ -0,0 +1,17 @@ +#include "openvino/pass/matcher_pass.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace pass { + +class EliminateZeroPoints : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::EliminateZeroPoints") + EliminateZeroPoints(); +}; + +} // namespace pass +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp new file mode 100644 index 00000000000..0671542ee38 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp @@ -0,0 +1,60 @@ +#include "fuse_to_sdpa.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace pass { + +FuseToSDPA::FuseToSDPA() { + // Not maintained since FLASH_ATTN_EXT has replaced this pattern + const auto m_k = ov::pass::pattern::any_input(); + const auto m_q = ov::pass::pattern::any_input(); + const auto m_qk = ov::pass::pattern::wrap_type({m_q, m_k}); + const auto m_qk_f32 = ov::pass::pattern::wrap_type({m_qk}); + const auto m_scale = ov::pass::pattern::any_input(); + const auto m_scaled_qk = ov::pass::pattern::wrap_type({m_qk_f32, m_scale}); + const auto m_mask = ov::pass::pattern::any_input(); + const auto m_masked_qk = ov::pass::pattern::wrap_type({m_scaled_qk, m_mask}); + const auto m_softmax_qk = ov::pass::pattern::wrap_type({m_masked_qk}); + const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type({m_softmax_qk}); + const auto m_v = ov::pass::pattern::any_input(); + const auto m_qkv = ov::pass::pattern::wrap_type({m_softmax_qk_f16, m_v}); + + const auto callback = [=](ov::pass::pattern::Matcher & m) { + auto & pattern_to_output = m.get_pattern_value_map(); + auto k = pattern_to_output[m_k]; + auto q = pattern_to_output[m_q]; + auto v = pattern_to_output[m_v]; + auto mask = pattern_to_output[m_mask]; + auto scale = pattern_to_output[m_scale]; + + auto mask_f16 = register_new_node(mask, ov::element::f16); + auto scale_f16 = register_new_node(scale, ov::element::f16); + auto sdpa = std::make_shared(q, k, v, mask_f16, scale_f16, false); + + ov::replace_node(m.get_match_root(), sdpa); + ov::copy_runtime_info(m.get_matched_nodes(), sdpa); + + return true; + }; + register_matcher(std::make_shared(m_qkv, "ov::frontend::ggml::pass::FuseToSDPA"), + callback); +} + +} // namespace pass +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h new file mode 100644 index 00000000000..8b5164d2329 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h @@ -0,0 +1,17 @@ +#include "openvino/pass/matcher_pass.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace pass { + +class FuseToSDPA : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::FuseToSDPA") + FuseToSDPA(); +}; + +} // namespace pass +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h b/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h new file mode 100644 index 00000000000..b95385611e8 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h @@ -0,0 +1,29 @@ +#pragma once + +#include "mark_decompression_convert_constant_folding.h" +#include "openvino/pass/matcher_pass.hpp" +#include "openvino/core/visibility.hpp" + +#ifdef OPENVINO_STATIC_LIBRARY +# define TRANSFORMATIONS_API +#else +# ifdef IMPLEMENT_OPENVINO_API +# define TRANSFORMATIONS_API OPENVINO_CORE_EXPORTS +# else +# define TRANSFORMATIONS_API OPENVINO_CORE_IMPORTS +# endif // IMPLEMENT_OPENVINO_API +#endif // OPENVINO_STATIC_LIBRARY + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API MarkCompressedFloatConstants; + +} // namespace pass +} // namespace ov + +class ov::pass::MarkCompressedFloatConstants : public MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("MarkCompressedFloatConstants") + MarkCompressedFloatConstants(); +}; diff --git a/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp b/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp new file mode 100644 index 00000000000..20a3a374934 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp @@ -0,0 +1,58 @@ +#include "squeeze_matmul.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace opp = ov::pass::pattern; + +namespace ov { +namespace frontend { +namespace ggml { +namespace pass { + +// For quantized models, NPUW expects the activation to be 3d in DQ(DynamicQuantization) opt, e.g. DQMatMulGQ2i +SqueezeMatmul::SqueezeMatmul() { + auto m_act = opp::any_input(); + auto m_wei = opp::any_input(); + auto m_matmul = opp::wrap_type({m_act, m_wei}); + + const auto callback = [=](ov::pass::pattern::Matcher & m) { + const auto & pattern_map = m.get_pattern_value_map(); + auto matmul_node = + std::dynamic_pointer_cast(pattern_map.at(m_matmul).get_node_shared_ptr()); + auto act = pattern_map.at(m_act); + auto wei = pattern_map.at(m_wei); + auto act_shape = act.get_partial_shape(); + auto wei_shape = wei.get_partial_shape(); + if (act_shape.rank().is_dynamic() || wei_shape.rank().is_dynamic()) { + return false; + } + if (act_shape.rank().get_length() == 4 && wei_shape.rank().get_length() == 2) { + auto axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); + auto squeezed_act = std::make_shared(act, axis); + auto new_matmul = std::make_shared(squeezed_act, wei, matmul_node->get_transpose_a(), + matmul_node->get_transpose_b()); + auto unsqueezed_output = std::make_shared(new_matmul, axis); + unsqueezed_output->set_friendly_name(matmul_node->get_friendly_name()); + ov::copy_runtime_info(matmul_node, {squeezed_act, new_matmul, unsqueezed_output}); + ov::replace_node(matmul_node, unsqueezed_output); + return true; + } + return false; + }; + + register_matcher(std::make_shared(m_matmul, "ov::frontend::ggml::pass::SqueezeMatmul"), + callback); +} + +} // namespace pass +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h b/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h new file mode 100644 index 00000000000..f8fbc69d546 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h @@ -0,0 +1,17 @@ +#include "openvino/pass/matcher_pass.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace pass { + +class SqueezeMatmul : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::SqueezeMatmul") + SqueezeMatmul(); +}; + +} // namespace pass +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp new file mode 100644 index 00000000000..23a1dea2496 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -0,0 +1,293 @@ +#include "translate_session.h" + +#include "ggml-openvino/openvino/node_context.h" +#include "ggml-openvino/openvino/utils.h" +#include "input_model.h" +#include "pass/eliminate_zp.h" +#include "pass/mark_decompression_convert_constant_folding.h" +#include "pass/squeeze_matmul.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { + +using namespace ov::op; + +namespace { + +ov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs( + const std::shared_ptr & model, + const std::map & kv_param_res_names) { + ov::pass::MakeStateful::ParamResPairs pairs; + const auto & params = model->get_parameters(); + const auto & results = model->get_results(); + + for (const auto & param_res : kv_param_res_names) { + const auto & param_name = param_res.first; + const auto & res_name = param_res.second; + + auto param_it = std::find_if(params.begin(), params.end(), [&](const std::shared_ptr & node) { + return node->get_friendly_name() == param_name; + }); + + OPENVINO_ASSERT(param_it != params.end(), "The tensor name ", param_name, + " is not associated with any of " + "Parameters in the network."); + + auto res_it = std::find_if(results.begin(), results.end(), [&](const std::shared_ptr & node) { + return node->get_friendly_name() == res_name; + }); + + OPENVINO_ASSERT(res_it != results.end(), "The tensor name ", res_name, + " is not associated with any of " + "Results in the network."); + + std::shared_ptr param = *param_it; + std::shared_ptr res = *res_it; + pairs.emplace_back(param, res); + } + return pairs; +} + +void add_sliced_mask(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) { + + auto create_sliced_mask = [&](const std::string & mask_name, const std::string & sliced_name, bool is_static) { + if ((tensor_map.find(mask_name) != tensor_map.end()) && + (tensor_map.find("token_len_per_seq") != tensor_map.end())) { + auto token_len_per_seq = tensor_map.at("token_len_per_seq").get_node_shared_ptr(); + auto mask = tensor_map.at(mask_name).get_node_shared_ptr(); + std::shared_ptr mask_sliced; + if (is_static) { + mask_sliced = mask; + } else if (ggml_model_decoder.is_stateful()) { + auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0}); + auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1}); + auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto three_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {3}); + auto neg_one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); + auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {-2,-1}); + auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr(); + auto gather_inp_pos = std::make_shared(inp_pos, neg_one_1d, three_1d); + auto reshaped_inp_pos = std::make_shared(gather_inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {1}), false); + auto inp_pos_incremented = std::make_shared(reshaped_inp_pos, ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {1})); + auto stop = std::make_shared(ov::OutputVector{token_len_per_seq, std::make_shared(inp_pos_incremented, token_len_per_seq)}, 0); + mask_sliced = + std::make_shared(mask, zero_2d, stop, one_2d, axes); + mask_sliced = std::make_shared(mask_sliced, ov::element::f16); + mask_sliced->set_friendly_name(sliced_name); + } else { + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); + mask_sliced = std::make_shared(mask, zero, token_len_per_seq, one, two); + mask_sliced = std::make_shared(mask_sliced, ov::element::f16); + mask_sliced->set_friendly_name(sliced_name); + } + tensor_map.insert({sliced_name, mask_sliced->output(0)}); + } + }; + + create_sliced_mask("self_kq_mask", "KQ_mask_sliced", ggml_model_decoder.is_static()); + create_sliced_mask("self_kq_mask_swa", "KQ_mask_swa_sliced", ggml_model_decoder.is_static()); +} + +void add_rope_sin_cos(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) { + int32_t * rope_params = ggml_model_decoder.get_rope_params(); + if (tensor_map.find("inp_pos") == tensor_map.end() || rope_params == nullptr) { + return; + } + auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr(); + std::shared_ptr rope_freqs_weight; + if (tensor_map.find("rope_freqs.weight") != tensor_map.end()) { + rope_freqs_weight = tensor_map.at("rope_freqs.weight").get_node_shared_ptr(); + } + + auto sin_cos = make_sin_cos(rope_params, inp_pos, rope_freqs_weight); + auto sin_theta = sin_cos.first; + auto cos_theta = sin_cos.second; + + cos_theta.get_node_shared_ptr()->set_friendly_name("rope_cos"); + sin_theta.get_node_shared_ptr()->set_friendly_name("rope_sin"); + tensor_map.insert({"rope_cos", cos_theta}); + tensor_map.insert({"rope_sin", sin_theta}); +} + +// Create common patterns +void preprocess(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) { + add_sliced_mask(tensor_map, ggml_model_decoder); + add_rope_sin_cos(tensor_map, ggml_model_decoder); +} + +} // namespace + +TranslateSession::TranslateSession(const frontend::InputModel::Ptr & input_model, + const std::unordered_map & translator_map, + bool naive) : + m_input_model(input_model), + m_translator_map(translator_map), + m_ov_model(nullptr), + m_naive(naive) {} + +std::shared_ptr TranslateSession::get_converted_model() { + if (m_ov_model) { + return m_ov_model; + } + m_ov_model = translate_graph(m_input_model); + return m_ov_model; +} + +std::shared_ptr TranslateSession::translate_graph(const frontend::InputModel::Ptr & input_model) { + ov::ParameterVector params; + ov::ResultVector results; + auto tensor_map = std::make_shared(); + std::shared_ptr resulting_model; + + const auto & ggml_model = std::dynamic_pointer_cast(input_model); + std::shared_ptr ggml_model_decoder = ggml_model->get_model_decoder(); + + for (const auto & it : ggml_model_decoder->get_model_inputs()) { + params.push_back(std::dynamic_pointer_cast(it.second)); + (*tensor_map)[it.first] = it.second; + } + + for (const auto & it : ggml_model_decoder->get_model_extra_inputs()) { + if (std::dynamic_pointer_cast(it.second)) { + params.push_back(std::dynamic_pointer_cast(it.second)); + } + (*tensor_map)[it.first] = it.second; + } + + for (const auto & it : ggml_model_decoder->get_model_weights()) { + (*tensor_map)[it.first] = it.second; + } + + auto node_visitor = [&](std::shared_ptr decoder, int node_idx) { + auto operation_type = decoder->get_op_type(node_idx); + if (operation_type == "GGML_OP_NONE") { + return; + } + + ov::OutputVector converted_outputs; + auto it = m_translator_map.find(operation_type); + FRONT_END_OP_CONVERSION_CHECK(it != m_translator_map.end(), "Translation for operation type ", operation_type, + " is not implemented."); + NodeContext node_context(decoder, tensor_map, node_idx, this); + converted_outputs = it->second(node_context); + + const auto & node_output_names = decoder->get_output_names(node_idx); + FRONT_END_OP_CONVERSION_CHECK(node_output_names.size() == converted_outputs.size(), "Number of ", + operation_type, " outputs greater than number of converted outputs, which are ", + node_output_names.size(), " and ", converted_outputs.size(), " respectively."); + + for (size_t i = 0; i < node_output_names.size(); ++i) { + auto output_name = node_output_names[i]; + if (i < converted_outputs.size() && converted_outputs[i].get_node_shared_ptr() != nullptr) { + (*tensor_map)[output_name] = converted_outputs[i]; + } + } + }; + + if (!m_naive) { + preprocess(*tensor_map, *ggml_model_decoder); + } + ggml_model_decoder->visit_subgraph(node_visitor); + + for (const auto & name : ggml_model_decoder->get_model_output_names()) { + FRONT_END_GENERAL_CHECK(tensor_map->find(name) != tensor_map->end(), + "Output name not found in tensor map: ", name); + auto result = std::make_shared(tensor_map->at(name)); + result->set_friendly_name(name); + results.push_back(result); + } + + ov::ParameterVector used_params; + for (const auto & param : params) { + if (!param->output(0).get_target_inputs().empty()) { + used_params.push_back(param); + } + } + // if (auto diff = params.size() - used_params.size()) { + // GGML_LOG_INFO("%zu parameters are not used in the model.", diff); + // } + resulting_model = std::make_shared(results, used_params); + + apply_transformations(resulting_model); + return resulting_model; +} + +std::shared_ptr TranslateSession::apply_transformations(std::shared_ptr model) { + auto ggml_model_decoder = std::dynamic_pointer_cast(m_input_model)->get_model_decoder(); + { + ov::pass::Manager manager; + manager.set_per_pass_validation(true); + manager.register_pass(); + + if (ggml_model_decoder->is_stateful()) { + const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names(); + const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names); + manager.register_pass(kv_param_res_pairs); + } + + if (ggml_model_decoder->is_static()) { + manager.register_pass(); + manager.register_pass(); + } + manager.run_passes(model); + if (ggml_model_decoder->is_stateful()) { + auto output_names = ggml_model_decoder->get_model_output_names(); + std::map model_output_indexes; + for (size_t i=0; iget_output_size(); i++) { + auto output_friendly_name = model->output(i).get_node_shared_ptr()->get_friendly_name(); + auto output_id = model_output_indexes[output_friendly_name]; + auto model_output_shape = model->output(i).get_partial_shape(); + auto decoder_output_shape = ggml_model_decoder->get_output_shape(output_id); + if (model_output_shape.rank().is_static() && decoder_output_shape.rank().is_static() + && model_output_shape.rank().get_length() + 1 == decoder_output_shape.rank().get_length() + && decoder_output_shape[0].is_static() && decoder_output_shape[0].get_length() == 1) { + ppp.output(i).postprocess().custom([](const ov::Output& node) { + auto axes = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {0}); + return std::make_shared(node, axes); + }); + } + } + model = ppp.build(); + } + } + return model; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/translate_session.h b/ggml/src/ggml-openvino/openvino/translate_session.h new file mode 100644 index 00000000000..56a14ae7c07 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/translate_session.h @@ -0,0 +1,28 @@ +#pragma once + +#include "input_model.h" +#include "node_context.h" + +namespace ov { +namespace frontend { +namespace ggml { + +class TranslateSession { +public: + TranslateSession(const frontend::InputModel::Ptr& input_model, + const std::unordered_map& translator_map, bool naive = false); + + std::shared_ptr get_converted_model(); + std::shared_ptr translate_graph(const frontend::InputModel::Ptr& input_model); + +private: + std::shared_ptr apply_transformations(std::shared_ptr model); + const frontend::InputModel::Ptr m_input_model; + const std::unordered_map& m_translator_map; + std::shared_ptr m_ov_model; + bool m_naive; +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp b/ggml/src/ggml-openvino/openvino/utils.cpp new file mode 100644 index 00000000000..65356a51b51 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/utils.cpp @@ -0,0 +1,226 @@ +#include "utils.h" + +#include "ggml-impl.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { + +std::string getCurrentTime() { + std::time_t now = std::time(nullptr); + char buf[100]; + std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", std::localtime(&now)); + return buf; +} + +void num_inputs_check(const NodeContext & context, size_t min_inputs, size_t max_inputs) { + auto input_size = context.get_input_size(); + FRONT_END_OP_CONVERSION_CHECK(input_size >= min_inputs, "Got less inputs than expected"); + FRONT_END_OP_CONVERSION_CHECK(input_size <= max_inputs, "Got more inputs than expected"); +} + +int non_cont_dim(std::vector ne, std::vector nb) { + int dim = nb.size() - 1; + size_t bytes = nb[dim]; + for (int i = dim; i > 0; i--) { + bytes *= ne[i]; + if (bytes != nb[i - 1]) { + return i; + } + } + return 0; +} + +std::shared_ptr get_dimensions(const std::shared_ptr & shape, + const std::vector & dims) { + using namespace ov::op; + const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); + return std::make_shared(shape, dims_const, zero); +} + +std::shared_ptr get_dimensions(const std::shared_ptr & node, const std::vector & dims) { + return get_dimensions(std::make_shared(node), dims); +} + +OutputVector rename_outputs_with_suffix(const OutputVector & outputs, const std::string & suffix) { + for (const auto & output : outputs) { + auto node = output.get_node_shared_ptr(); + std::string name = node->get_friendly_name(); + name += "_"; + name += suffix; + node->set_friendly_name(name); + // std::cout << name << " " << output.get_partial_shape() << std::endl; + } + return outputs; +} + +namespace { +ov::Output rope_yarn_ramp_mix(int n_dims, const float corr_dims[2], float ext_factor) { + int half_n_dims = n_dims / 2; + std::vector dim_ids_vec(half_n_dims); + std::iota(dim_ids_vec.begin(), dim_ids_vec.end(), 0); + auto dim_ids = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, (size_t) half_n_dims}, dim_ids_vec); + auto corr_low = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {corr_dims[0]}); + auto corr_high = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {corr_dims[1]}); + auto denom = std::make_shared( + std::make_shared(corr_high, corr_low), + ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {0.001f})); + auto ramp_y = + std::make_shared(std::make_shared(dim_ids, corr_low), denom); + auto ramp_clamped = std::make_shared(ramp_y, 0.0f, 1.0f); + auto ext_factor_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {ext_factor}); + auto ramp_mix = std::make_shared(ramp_clamped, ext_factor_node); + return ramp_mix; +} + +float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) { +#ifndef M_PI +# define M_PI 3.14159265358979323846 +#endif + return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float) M_PI)) / (2 * logf(base)); +} + +void ggml_rope_yarn_corr_dims(int n_dims, + int n_ctx_orig, + float freq_base, + float beta_fast, + float beta_slow, + float dims[2]) { + float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base)); + float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base)); + dims[0] = std::max(0.0f, start); + dims[1] = std::min(static_cast(n_dims - 1), end); +} +} // namespace + +std::pair, ov::Output> make_sin_cos(int32_t * rope_params, + std::shared_ptr inp_pos, + std::shared_ptr rope_freqs_weight, + bool stateful) { + if (stateful) { + inp_pos = std::make_shared(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + inp_pos = std::make_shared(inp_pos, ov::element::f32); + auto pos_perm = + std::make_shared(ov::element::i64, ov::Shape{3}, std::vector{2, 1, 0}); + inp_pos = std::make_shared(inp_pos, pos_perm); + } else { + inp_pos = std::make_shared(inp_pos, ov::element::f32); + auto pos_perm = + std::make_shared(ov::element::i64, ov::Shape{4}, std::vector{0, 3, 1, 2}); + inp_pos = std::make_shared(inp_pos, pos_perm); + } + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + const int n_dims = rope_params[1]; + const int n_ctx_orig = rope_params[4]; + memcpy(&freq_base, rope_params + 5, sizeof(float)); + memcpy(&freq_scale, rope_params + 6, sizeof(float)); + memcpy(&ext_factor, rope_params + 7, sizeof(float)); + memcpy(&attn_factor, rope_params + 8, sizeof(float)); + memcpy(&beta_fast, rope_params + 9, sizeof(float)); + memcpy(&beta_slow, rope_params + 10, sizeof(float)); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + std::vector factor(n_dims / 2); + factor[0] = 1.0f; + for (size_t i = 1; i < factor.size(); i++) { + factor[i] = theta_scale * factor[i - 1]; + } + + Output freq_factors; + if (stateful) { + freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor); + } else { + freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor); + } + if (rope_freqs_weight) { + freq_factors = std::make_shared(freq_factors, rope_freqs_weight); + } + + auto theta_extrap = std::make_shared(freq_factors, inp_pos); + auto theta_interp = std::make_shared( + theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale})); + + Output theta; + float mscale = attn_factor; + if (ext_factor == 0.0f) { + theta = theta_interp; + } else { + auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor); + Output one; + if (stateful) { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f}); + } else { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + } + auto one_minus_ramp = std::make_shared(one, ramp_mix); + + theta = std::make_shared(std::make_shared(theta_interp, one_minus_ramp), + std::make_shared(theta_extrap, ramp_mix)); + mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale)); + } + + Output cos_theta = std::make_shared(theta); + Output sin_theta = std::make_shared(theta); + + auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale}); + + cos_theta = std::make_shared(cos_theta, mscale_node); + sin_theta = std::make_shared(sin_theta, mscale_node); + return std::make_pair(sin_theta, cos_theta); +} + +ov::Output process_view_input(const NodeContext & context, int input_index, int slice_len) { + // Only works for VIEW operations that slice at the lowest dimension + // If the VIEW also reshape the result, `slice_len` should be provided + auto input = context.get_input(input_index); + auto * op_params = (size_t *) context.get_input_op_params(input_index); + auto src1_stride = context.get_input_stride(input_index); + + int64_t split_addr = op_params[0] / src1_stride[3]; + if (slice_len == 0) { + slice_len = context.get_input_shape(input_index)[3].get_length(); + } + int64_t slice_end = split_addr + slice_len; + + auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {split_addr}); + auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_end}); + auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {context.is_stateful() ? 2 : 3}); + auto sliced = std::make_shared(input, begin, end, stride, axes); + return sliced; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/utils.h b/ggml/src/ggml-openvino/openvino/utils.h new file mode 100644 index 00000000000..88dcad4c906 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/utils.h @@ -0,0 +1,85 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "node_context.h" + +namespace ov { +namespace frontend { +namespace ggml { + +std::string getCurrentTime(); + +void dump_ov_model(std::shared_ptr model); + +void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs); + +int non_cont_dim(std::vector ne, std::vector nb); + +template +std::vector argsort_descend(const std::vector& v) { + std::vector idx(v.size()); + std::iota(idx.begin(), idx.end(), 0); + std::sort(idx.begin(), idx.end(), [&v](int i1, int i2) { + return v[i1] > v[i2]; + }); + return idx; +} + +template +std::vector sorted_descend(std::vector v) { + std::sort(v.begin(), v.end(), [](T a, T b) { + return a > b; + }); + return v; +} + +template +bool is_permuted(const std::vector& strides) { + for (size_t i = 0; i < strides.size() - 1; ++i) { + if (strides[i] < strides[i + 1]) { + return true; + } + } + return false; +} + +template +std::vector permute(const std::vector& x, const std::vector& perm) { + std::vector result; + result.reserve(perm.size()); + for (int i : perm) { + result.push_back(x[i]); + } + return result; +} + +std::shared_ptr get_dimensions(const std::shared_ptr& shape, + const std::vector& dims); +std::shared_ptr get_dimensions(const std::shared_ptr& node, const std::vector& dims); + +OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std::string& suffix); + +std::pair, ov::Output> make_sin_cos(int32_t* rope_params, + std::shared_ptr inp_pos, + std::shared_ptr rope_freqs_weight = nullptr, + bool stateful = false); + +ov::Output process_view_input(const NodeContext& context, int input_index, int slice_len = 0); + +namespace op { +template +OutputVector translate_1to1_match_2_inputs(const NodeContext& context) { + num_inputs_check(context, 2, 2); + auto res = std::make_shared(context.get_input(0), context.get_input(1)); + return rename_outputs_with_suffix({res}, context.get_name()); +} +} // namespace op + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp new file mode 100644 index 00000000000..1b553a0de00 --- /dev/null +++ b/ggml/src/ggml-openvino/utils.cpp @@ -0,0 +1,823 @@ +#include "utils.h" + +#include "ggml-impl.h" +#include "ggml-openvino-extra.h" +#include "ggml-openvino/ggml-decoder.h" +#include "ggml.h" +#include "openvino/frontend.h" +#include "openvino/input_model.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Suppress deprecation warning for ov::Tensor::data() +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + +enum ggml_status ov_graph_compute(ggml_cgraph * cgraph, ggml_backend_t backend) { + ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *) backend->context; + try { + if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) { + std::string filename = "cgraph_ov.txt"; + GgmlOvDecoder::dump_cgraph(cgraph, filename); + } + + const auto is_static = ggml_openvino_is_npu(); + + GGML_ASSERT(ctx->runtime_context != nullptr); + std::shared_ptr r_ctx = std::static_pointer_cast(ctx->runtime_context); + + return is_static ? ov_graph_compute_static(cgraph, r_ctx) : ov_graph_compute_dynamic(cgraph, r_ctx); + } catch (const ov::Exception & e) { + GGML_LOG_ERROR("GGML OpenVINO backend ov::Exception: %s\n", e.what()); + return GGML_STATUS_FAILED; + } catch (const std::exception & e) { + GGML_LOG_ERROR("GGML OpenVINO backend std::exception: %s\n", e.what()); + return GGML_STATUS_FAILED; + } catch (...) { + GGML_LOG_ERROR("GGML OpenVINO backend unknown exception\n"); + return GGML_STATUS_FAILED; + } +} + +ov::Tensor create_ov_output_tensor(std::shared_ptr ggml_decoder, + std::shared_ptr infer_request, + int output_index, + const ggml_tensor * ggml_tensor) { + auto output_type = ggml_decoder->get_ov_type(ggml_tensor); + ov::Shape output_shape; + if (ggml_decoder->is_static()) { + output_shape = infer_request->get_output_tensor(output_index).get_shape(); + } else { + output_shape = ggml_decoder->get_shape(ggml_tensor); + } + + ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data); + return output_tensor; +} + +enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr r_ctx) { + auto & core = ov_singleton_core(); + const auto & config = ggml_openvino_get_compile_config(); + auto device = r_ctx->device; + bool stateful = r_ctx->stateful; + static auto is_static = false; + + if (is_naive(cgraph)) { + return naive_compute(cgraph, core, device, config); + } + + auto start_time = ggml_time_us(); + + std::shared_ptr ggml_decoder; + std::shared_ptr infer_request; + ModelParams m_params; + ComputeParams c_params; + std::tie(m_params, c_params) = GgmlOvDecoder::compute_llm_params(cgraph, is_static); + + graph_key key(cgraph); + bool cache_hit; + + int64_t decoder_end_time; + int64_t conversion_end_time; + int64_t compile_end_time; + int64_t infer_end_time; + + { + std::lock_guard lock(r_ctx->ov_compute_mutex); + + auto it = r_ctx->decoder_cache.find(key); + + cache_hit = it != r_ctx->decoder_cache.end(); + ModelParams old_m_params; + if (cache_hit) { + ggml_decoder = it->second; + old_m_params = ggml_decoder->get_model_params(); + cache_hit = old_m_params.can_reuse_dynamically(m_params); + } + + if (cache_hit) { + std::map> model_weights; + ggml_decoder->set_compute_params(c_params); + ggml_decoder->set_model_params(m_params); + if (old_m_params.kv_buffer_changed(m_params)) { + ggml_decoder->update_io(cgraph); + } + ggml_decoder->add_extra_inputs(); + infer_request = r_ctx->infer_request_cache.at(key); + + if (stateful) { + const auto * inp_pos = get_inp_pos_tensor(cgraph); + int32_t * pos_data = (int32_t *) inp_pos->data; + auto pos_shape = ggml_decoder->get_shape(inp_pos); + if (pos_data[0] == 0) { + infer_request->reset_state(); + r_ctx->stateful_kv_size = pos_shape[3]; + } else if (r_ctx->stateful_kv_size == static_cast(pos_data[0])) { + r_ctx->stateful_kv_size += pos_shape[3]; + } else { + auto states = infer_request->query_state(); + for (auto state : states) { + auto state_tensor = state.get_state(); + auto state_tensor_shape = state_tensor.get_shape(); + if (static_cast(pos_data[0]) > r_ctx->stateful_kv_size) { + std::string state_name; + try { + state_name = r_ctx->kv_state_input_name_map.at(state.get_name()); + } catch (...) { + GGML_LOG_ERROR("GGML OpenVINO backend stateful inference failed: no input found for the state\n"); + return GGML_STATUS_FAILED; + } + auto kv_tensor = get_ov_input_tensor(ggml_decoder, state_name); + kv_tensor.set_shape({state_tensor_shape[0], kv_tensor.get_shape()[2], + state_tensor_shape[2], state_tensor_shape[3]}); + state_tensor = kv_tensor; + state_tensor_shape = state_tensor.get_shape(); + } + ov::Coordinate begin = {0, 0, 0, 0}; + ov::Coordinate end = {state_tensor_shape[0], static_cast(pos_data[0]), + state_tensor_shape[2], state_tensor_shape[3]}; + ov::Tensor new_state_tensor(state_tensor, begin, end); + state.set_state(new_state_tensor); + } + r_ctx->stateful_kv_size = pos_data[0] + 1; + } + } + + decoder_end_time = ggml_time_us(); + conversion_end_time = decoder_end_time; + compile_end_time = decoder_end_time; + } else { + r_ctx->infer_request_cache.erase(key); + + std::shared_ptr model; + auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); + + ggml_decoder = std::make_shared(cgraph, m_params, c_params, model_weights, is_static, stateful); + decoder_end_time = ggml_time_us(); + + auto input_model = std::make_shared(ggml_decoder); + model = ov::frontend::ggml::FrontEnd::convert(input_model); + ggml_decoder->clear_model_weights(); + conversion_end_time = ggml_time_us(); + + if (getenv("GGML_OPENVINO_DUMP_IR")) { + char timestamped_filename[64]; + auto timestamp = (long long) ggml_time_us(); + snprintf(timestamped_filename, sizeof(timestamped_filename), "model_%lld.xml", timestamp); + ov::serialize(model, timestamped_filename); + } + + ov::CompiledModel compiled_model; + auto remote_context = ggml_openvino_get_remote_context(); + if (remote_context.has_value()) { + compiled_model = core.compile_model(model, remote_context.value(), config); + } else { + compiled_model = core.compile_model(model, device, config); + } + compile_end_time = ggml_time_us(); + infer_request = std::make_shared(compiled_model.create_infer_request()); + r_ctx->infer_request_cache[key] = infer_request; + r_ctx->decoder_cache[key] = ggml_decoder; + + std::vector ov_input_names; + std::vector ov_output_names; + for (const auto & ov_param : model->get_parameters()) { + ov_input_names.push_back(ov_param->get_friendly_name()); + } + for (const auto & ov_output : model->get_results()) { + ov_output_names.push_back(ov_output->get_friendly_name()); + } + r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); + r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + + if (stateful) { + const auto * inp_pos = get_inp_pos_tensor(cgraph); + auto pos_shape = ggml_decoder->get_shape(inp_pos); + r_ctx->stateful_kv_size = pos_shape[3]; + const auto kv_param_res_names = ggml_decoder->get_kv_param_res_names(); + for (const auto& pair : kv_param_res_names) { + r_ctx->kv_state_input_name_map[pair.first+pair.second] = pair.first; + } + } + } + + auto ov_input_names = r_ctx->ov_input_names_cache[key]; + auto ov_output_names = r_ctx->ov_output_names_cache[key]; + + for (size_t i = 0; i < ov_input_names.size(); i++) { + auto param_name = ov_input_names[i]; + auto input_tensor = get_ov_input_tensor(ggml_decoder, param_name); + infer_request->set_input_tensor(i, input_tensor); + + if (getenv("GGML_OPENVINO_DEBUG_INPUT")) { + print_input_tensor_info(param_name, input_tensor); + } + } + + for (size_t i = 0; i < ov_output_names.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names[i]); + auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); + infer_request->set_output_tensor(i, output_tensor); + } + + infer_request->infer(); + infer_end_time = ggml_time_us(); + + if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) { + for (size_t i = 0; i < ov_output_names.size(); i++) { + const auto output_tensor = infer_request->get_output_tensor(i); + print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data()); + } + } + + if (getenv("GGML_OPENVINO_PROFILING")) { + GGML_LOG_INFO("\nGGML OpenVINO Backend: \n"); + GGML_LOG_INFO(" - Graph decoder time: %ld ms \n", (decoder_end_time - start_time) / 1000); + if (!cache_hit) { + GGML_LOG_INFO(" - Graph conversion time: %ld ms \n", (conversion_end_time - decoder_end_time) / 1000); + GGML_LOG_INFO(" - Graph compile time: %ld ms \n", (compile_end_time - conversion_end_time) / 1000); + } + GGML_LOG_INFO(" - Graph inference time: %ld ms \n", (infer_end_time - compile_end_time) / 1000); + } + } + + return GGML_STATUS_SUCCESS; +} + +enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptr r_ctx) { + auto & core = ov_singleton_core(); + + auto get_prefill_chunk_size = [] { + const char * chunk_size_str = getenv("GGML_OPENVINO_PREFILL_CHUNK_SIZE"); + if (chunk_size_str && atoi(chunk_size_str) > 0) { + return atoi(chunk_size_str); + } + return 256; + }; + + static std::string device = "NPU"; + static auto is_static = true; + static auto stateful = false; + static auto prefill_chunk_size = get_prefill_chunk_size(); + const auto & config = ggml_openvino_get_compile_config(); + + if (is_naive(cgraph)) { + return naive_compute(cgraph, core, device, config); + } + + auto start_time = ggml_time_us(); + + std::shared_ptr ggml_decoder; + std::shared_ptr infer_request; + ModelParams m_params; + ComputeParams c_params; + std::tie(m_params, c_params) = GgmlOvDecoder::compute_llm_params(cgraph, is_static); + + const auto * inp_pos = get_inp_pos_tensor(cgraph); + const auto is_prefill = get_is_prefill(inp_pos); + graph_key key(cgraph); + bool cache_hit; + + int64_t decoder_end_time; + int64_t conversion_end_time; + int64_t compile_end_time; + int64_t infer_end_time; + + auto it = r_ctx->decoder_cache.find(key); + + cache_hit = it != r_ctx->decoder_cache.end(); + ModelParams old_m_params; + if (cache_hit) { + ggml_decoder = it->second; + old_m_params = ggml_decoder->get_model_params(); + cache_hit = old_m_params.can_reuse_statically(m_params); + } + + if (cache_hit) { + std::map> model_weights; + ggml_decoder->m_is_prefill = is_prefill; + ggml_decoder->set_model_params(m_params); + ggml_decoder->set_compute_params(c_params); + if (old_m_params.kv_buffer_changed(m_params)) { + ggml_decoder->update_io(cgraph); + } + ggml_decoder->add_extra_inputs(); + infer_request = is_prefill ? r_ctx->infer_request_cache_prefill.at(key) : r_ctx->infer_request_cache.at(key); + + decoder_end_time = ggml_time_us(); + conversion_end_time = decoder_end_time; + compile_end_time = decoder_end_time; + } else { + r_ctx->infer_request_cache.erase(key); + r_ctx->infer_request_cache_prefill.erase(key); + + std::shared_ptr model; + auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); + + auto ggml_decoder_prefill = std::make_shared(cgraph, m_params, c_params, model_weights, + is_static, stateful, true, prefill_chunk_size); + auto ggml_decoder_decode = std::make_shared(cgraph, m_params, c_params, model_weights, is_static, + stateful, false, prefill_chunk_size); + decoder_end_time = ggml_time_us(); + + auto input_model_prefill = std::make_shared(ggml_decoder_prefill); + auto input_model_decode = std::make_shared(ggml_decoder_decode); + + auto model_prefill = ov::frontend::ggml::FrontEnd::convert(input_model_prefill); + ggml_decoder_prefill->clear_model_weights(); + auto model_decode = ov::frontend::ggml::FrontEnd::convert(input_model_decode); + ggml_decoder_decode->clear_model_weights(); + conversion_end_time = ggml_time_us(); + + if (getenv("GGML_OPENVINO_DUMP_IR")) { + char timestamped_filename[64]; + auto timestamp = (long long) ggml_time_us(); + snprintf(timestamped_filename, sizeof(timestamped_filename), "model_prefill_%lld.xml", timestamp); + ov::serialize(model_prefill, timestamped_filename); + snprintf(timestamped_filename, sizeof(timestamped_filename), "model_decode_%lld.xml", timestamp); + ov::serialize(model_decode, timestamped_filename); + } + + ov::CompiledModel compiled_model_prefill; + ov::CompiledModel compiled_model_decode; + auto remote_context = ggml_openvino_get_remote_context(); + if (remote_context.has_value()) { + compiled_model_prefill = core.compile_model(model_prefill, remote_context.value(), config); + compiled_model_decode = core.compile_model(model_decode, remote_context.value(), config); + } else { + compiled_model_prefill = core.compile_model(model_prefill, device, config); + compiled_model_decode = core.compile_model(model_decode, device, config); + } + + r_ctx->infer_request_cache_prefill[key] = + std::make_shared(compiled_model_prefill.create_infer_request()); + r_ctx->infer_request_cache[key] = + std::make_shared(compiled_model_decode.create_infer_request()); + compile_end_time = ggml_time_us(); + + model = is_prefill ? model_prefill : model_decode; + ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode; + infer_request = is_prefill ? r_ctx->infer_request_cache_prefill[key] : r_ctx->infer_request_cache[key]; + r_ctx->decoder_cache[key] = ggml_decoder; + + std::vector ov_input_names; + std::vector ov_output_names; + for (const auto & ov_param : model->get_parameters()) { + ov_input_names.push_back(ov_param->get_friendly_name()); + } + for (const auto & ov_output : model->get_results()) { + ov_output_names.push_back(ov_output->get_friendly_name()); + } + r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); + r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + } + + auto ov_input_names = r_ctx->ov_input_names_cache[key]; + auto ov_output_names = r_ctx->ov_output_names_cache[key]; + + if (is_prefill) { + auto inp_len = inp_pos->ne[0]; + for (int chunk_index = 0; chunk_index * prefill_chunk_size < inp_len; chunk_index++) { + for (size_t i = 0; i < ov_input_names.size(); i++) { + auto param_name = ov_input_names[i]; + auto input_tensor = get_ov_input_tensor_static_prefill(ggml_decoder, param_name, chunk_index); + infer_request->set_input_tensor(i, input_tensor); + + if (getenv("GGML_OPENVINO_DEBUG_INPUT")) { + const auto input_tensor = infer_request->get_input_tensor(i); + print_input_tensor_info(param_name, input_tensor); + } + } + + for (size_t i = 0; i < ov_output_names.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names[i]); + auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); + infer_request->set_output_tensor(i, output_tensor); + } + + infer_request->infer(); + + if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) { + for (size_t i = 0; i < ov_output_names.size(); i++) { + const auto output_tensor = infer_request->get_output_tensor(i); + print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data()); + } + } + } + infer_end_time = ggml_time_us(); + } else { + for (size_t i = 0; i < ov_input_names.size(); i++) { + auto param_name = ov_input_names[i]; + auto input_tensor = get_ov_input_tensor_static_decode(ggml_decoder, param_name); + infer_request->set_input_tensor(i, input_tensor); + + if (getenv("GGML_OPENVINO_DEBUG_INPUT")) { + const auto input_tensor = infer_request->get_input_tensor(i); + print_input_tensor_info(param_name, input_tensor); + } + } + + for (size_t i = 0; i < ov_output_names.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names[i]); + auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); + infer_request->set_output_tensor(i, output_tensor); + } + + infer_request->infer(); + infer_end_time = ggml_time_us(); + + if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) { + for (size_t i = 0; i < ov_output_names.size(); i++) { + const auto output_tensor = infer_request->get_output_tensor(i); + print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data()); + } + } + } + + if (getenv("GGML_OPENVINO_PROFILING")) { + GGML_LOG_INFO("\nGGML OpenVINO Backend: \n"); + GGML_LOG_INFO(" - Graph decoder time: %ld ms \n", (decoder_end_time - start_time) / 1000); + if (!cache_hit) { + GGML_LOG_INFO(" - Graph conversion time: %ld ms \n", (conversion_end_time - decoder_end_time) / 1000); + GGML_LOG_INFO(" - Graph compile time: %ld ms \n", (compile_end_time - conversion_end_time) / 1000); + } + GGML_LOG_INFO(" - Graph inference time: %ld ms \n", (infer_end_time - compile_end_time) / 1000); + } + + return GGML_STATUS_SUCCESS; +} + +bool is_naive(ggml_cgraph * cgraph) { + constexpr int naive_graph_size_threshold = 20; + int count = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i]->op != GGML_OP_NONE) { + count++; + } + } + return count < naive_graph_size_threshold; +} + +enum ggml_status naive_compute(ggml_cgraph * cgraph, + ov::Core & core, + const std::string & device, + const ov::AnyMap & config) { + if (cgraph->n_nodes == 1 && (cgraph->nodes[0]->op == GGML_OP_NONE || cgraph->nodes[0]->op == GGML_OP_VIEW)) { + return GGML_STATUS_SUCCESS; + } + + bool naive = true; + auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph, naive); + auto decoder = std::make_shared(cgraph, model_weights); + auto input_model = std::make_shared(decoder); + auto model = ov::frontend::ggml::FrontEnd::convert(input_model, naive); + if (getenv("GGML_OPENVINO_DUMP_IR")) { + ov::serialize(model, "IR_naive.xml"); + } + + std::shared_ptr infer_request; + auto remote_context = ggml_openvino_get_remote_context(); + if (cgraph->nodes[0]->op == GGML_OP_MUL_MAT) { + // TODO ACCURACY hint triggers a bug in GPU plugin/driver on Lunar Lake. Remove once CVS-182166 is resolved + core.set_property(device, ov::hint::execution_mode(ov::hint::ExecutionMode::PERFORMANCE)); + } else { + core.set_property(device, ov::hint::execution_mode(ov::hint::ExecutionMode::ACCURACY)); + } + if (remote_context.has_value()) { + infer_request = std::make_shared( + core.compile_model(model, remote_context.value(), config).create_infer_request()); + } else { + infer_request = + std::make_shared(core.compile_model(model, device, config).create_infer_request()); + } + + auto ov_params = model->get_parameters(); + for (size_t i = 0; i < ov_params.size(); i++) { + auto param_name = ov_params[i]->get_friendly_name(); + auto input_tensor = get_ov_input_tensor(decoder, param_name); + infer_request->set_input_tensor(i, input_tensor); + } + + auto ov_results = model->get_results(); + for (size_t i = 0; i < ov_results.size(); i++) { + auto * ggml_tensor = decoder->get_model_outputs().at(ov_results[i]->get_friendly_name()); + auto output_tensor = create_ov_output_tensor(decoder, infer_request, i, ggml_tensor); + infer_request->set_output_tensor(i, output_tensor); + } + + infer_request->infer(); + return GGML_STATUS_SUCCESS; +} + +namespace { +ov::Tensor convert_ggml_input_to_ov(std::shared_ptr ggml_decoder, const std::string & name) { + const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(name); + + if (ggml_tensor->extra != nullptr) { + // GGML_LOG_DEBUG("Using ggml_tensor->extra as ov::Tensor for input: %s\n", name.c_str()); + auto * extra_base = static_cast(ggml_tensor->extra); + if (extra_base->type != ggml_openvino_extra_base::Type::TENSOR) { + throw std::runtime_error("ggml tensor extra is not of type TENSOR for input: " + name); + } + auto * tensor_extra = static_cast(extra_base); + return *tensor_extra->tensor; + } + + // GGML_LOG_DEBUG("Converting ggml tensor to ov::Tensor for input: %s\n", name.c_str()); + auto * input_data = ggml_tensor->data; + ov::Shape input_shape; + if (ggml_tensor->op == GGML_OP_VIEW) { + // This case is added to make test-backend-ops work + input_shape = ggml_decoder->get_shape(ggml_tensor->view_src); + } else { + input_shape = ggml_decoder->get_shape(ggml_tensor); + } + auto input_tensor = ov::Tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape, input_data); + return input_tensor; +} +} // namespace + +ov::Tensor get_ov_input_tensor(std::shared_ptr ggml_decoder, const std::string & param_name) { + ov::Tensor input_tensor; + if (ggml_decoder->get_model_extra_inputs().find(param_name) != ggml_decoder->get_model_extra_inputs().end()) { + input_tensor = *ggml_decoder->get_model_extra_input_values().at(param_name); + } else { + input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name); + } + return input_tensor; +} + +ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr ggml_decoder, + const std::string & param_name) { + // NPU decoding stage + const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name); + const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor); + + if (GgmlOvDecoder::is_inp_tok(ggml_tensor, op) || GgmlOvDecoder::is_inp_pos(ggml_tensor, op) || + GgmlOvDecoder::is_kv_idx(ggml_tensor, op)) { + assert(ggml_tensor->ne[0] == 1); + ov::Shape input_shape = {1, 1, 1, 1}; + ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape); + if (ggml_tensor->type == GGML_TYPE_I32) { + *input_tensor.data() = *((int32_t *) ggml_tensor->data); + } else if (ggml_tensor->type == GGML_TYPE_I64) { + *input_tensor.data() = *((int64_t *) ggml_tensor->data); + } else { + throw std::runtime_error("Unexpected tensor type for " + param_name); + } + return input_tensor; + } + + if (GgmlOvDecoder::is_output_idx(ggml_tensor, op)) { + ov::Shape input_shape = {1, 1, 1, 1}; + ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape); + int32_t inp_out_id = *((int32_t *) ggml_tensor->data); + assert(ggml_tensor->ne[0] == 1); + assert(inp_out_id == 0); + *input_tensor.data() = inp_out_id; + return input_tensor; + } + + if (GgmlOvDecoder::is_inp_mask(ggml_tensor, op)) { + size_t context_size = ggml_decoder->get_ctx_size(); + std::vector padded_data = pad_input(ggml_tensor, 1, context_size, -INFINITY); + ov::Tensor input_tensor(ov::element::f32, ov::Shape{1, 1, 1, context_size}); + auto * data_ptr = input_tensor.data(); + std::copy(padded_data.begin(), padded_data.begin() + context_size, data_ptr); + return input_tensor; + } + + return get_ov_input_tensor(ggml_decoder, param_name); +} + +ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr ggml_decoder, + const std::string & param_name, + int chunk_index) { + // NPU prompt processing stage + const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name); + const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor); + + const size_t input_len = ggml_decoder->get_input_len(); + const size_t chunk_size = ggml_decoder->m_prefill_chunk_size; + const size_t chunk_valid_size = std::min(chunk_size, input_len - chunk_index * chunk_size); + const size_t chunk_pad_size = chunk_size - chunk_valid_size; + + if (GgmlOvDecoder::is_inp_tok(ggml_tensor, op) || GgmlOvDecoder::is_inp_pos(ggml_tensor, op) || + GgmlOvDecoder::is_kv_idx(ggml_tensor, op)) { + ov::Shape input_shape = {1, 1, 1, chunk_size}; + ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape); + // copy the chunk_index-th chunk from ggml_tensor + size_t element_size = ggml_type_size(ggml_tensor->type); + void * input_data = (char *) ggml_tensor->data + chunk_index * chunk_size * element_size; + std::memcpy(input_tensor.data(), input_data, chunk_valid_size * element_size); + // pad the rest with last_value + 1, so that kv's of padded positions are inserted + // to the next row after the valids row in the kvcache + if (chunk_pad_size > 0) { + if (ggml_tensor->type == GGML_TYPE_I32) { + int32_t last_value = + *((int32_t *) ggml_tensor->data + (chunk_index * chunk_size + chunk_valid_size - 1)); + int32_t * output_data = input_tensor.data(); + std::fill(output_data + chunk_valid_size, output_data + chunk_size, last_value + 1); + } else if (ggml_tensor->type == GGML_TYPE_I64) { + int64_t last_value = + *((int64_t *) ggml_tensor->data + (chunk_index * chunk_size + chunk_valid_size - 1)); + int64_t * output_data = input_tensor.data(); + std::fill(output_data + chunk_valid_size, output_data + chunk_size, last_value + 1); + } else { + throw std::runtime_error("Unexpected tensor type for " + param_name); + } + } + return input_tensor; + } + + if (GgmlOvDecoder::is_output_idx(ggml_tensor, op)) { + size_t output_len = ggml_decoder->get_compute_params().output_len; + ov::Shape input_shape = {1, 1, 1, output_len}; + ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape); + if (ggml_tensor->ne[0] == 0) { + *input_tensor.data() = 0; + } else { + auto * data_addr = input_tensor.data(); + for (size_t i = 0; i < output_len; i++) { + data_addr[i] = ((int32_t *) ggml_tensor->data)[i] % chunk_size; + } + } + return input_tensor; + } + + if (GgmlOvDecoder::is_inp_mask(ggml_tensor, op)) { + size_t cols = ggml_tensor->ne[0]; + size_t rows = ggml_tensor->ne[1]; + float * ggml_data = (float *) ggml_tensor->data + chunk_index * chunk_size * cols; + size_t chunk_valid_rows = std::min(chunk_size, rows - chunk_index * chunk_size); + size_t context_size = ggml_decoder->get_ctx_size(); + std::vector padded_data = + pad_input(ggml_data, chunk_valid_rows, cols, chunk_size, context_size, -INFINITY); + set_zero_diagonal(padded_data, chunk_size, context_size); + ov::Tensor input_tensor(ov::element::f32, ov::Shape{1, 1, chunk_size, context_size}); + auto * data_ptr = input_tensor.data(); + std::copy(padded_data.begin(), padded_data.begin() + chunk_size * context_size, data_ptr); + return input_tensor; + } + + return get_ov_input_tensor(ggml_decoder, param_name); +} + +size_t checksum(const void * data, size_t size) { + const uint8_t * bytes = static_cast(data); + size_t sum = 0; + for (size_t i = 0; i < size; ++i) { + sum += (uint8_t) i; + sum += bytes[i]; + } + return sum; +} + +void print_input_tensor_info(const std::string & name, const ov::Tensor & tensor) { + std::cout << "Input name: " << name << ", Input shape: " << tensor.get_shape() << ", Address: " << tensor.data() + << std::endl; + switch (tensor.get_element_type()) { + case ov::element::f32: { + if (name.find("self_kq_mask") == std::string::npos) { + std::cout << *(tensor.data()) << std::endl; + } else { + size_t rows = tensor.get_shape()[2]; + size_t cols = tensor.get_shape()[3]; + auto * data = tensor.data(); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + float val = data[i * cols + j]; + if (std::isinf(val) && val < 0) { + std::cout << std::setw(5) << "-inf"; + } else { + std::cout << std::setw(5) << val; + } + } + std::cout << std::endl; + } + } + + break; + } + case ov::element::f16: + std::cout << *(tensor.data()) << std::endl; + break; + case ov::element::i32: + for (size_t i = 0; i < tensor.get_size(); ++i) { + std::cout << tensor.data()[i] << " "; + } + std::cout << std::endl; + break; + case ov::element::i64: + for (size_t i = 0; i < tensor.get_size(); ++i) { + std::cout << tensor.data()[i] << " "; + } + std::cout << std::endl; + break; + default: + break; + } +} + +void print_output_tensor_info(const std::string & name, const ov::Tensor & tensor, const void * output_dst) { + std::cout << "Output name: " << name << ", Output shape: " << tensor.get_shape() << ", Address: " << output_dst + << std::endl; + + auto print_float_stats = [](const std::string & type_name, size_t size, auto get_value) { + if (size == 0) { + return; + } + + float first = get_value(0); + float min = first; + float max = first; + double sum = first; + + for (size_t i = 1; i < size; ++i) { + float v = get_value(i); + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + sum += v; + } + double mean = sum / size; + + std::cout << std::right << std::setw(6) << type_name << std::right << std::setw(12) << "First" << std::setw(12) + << "Min" << std::setw(12) << "Max" << std::setw(12) << "Mean" << std::endl; + std::cout << std::right << std::setw(6) << "" << std::right << std::setw(12) << first << std::setw(12) << min + << std::setw(12) << max << std::setw(12) << mean << std::endl; + }; + + switch (tensor.get_element_type()) { + case ov::element::f32: { + const float * data = tensor.data(); + size_t size = tensor.get_size(); + print_float_stats("[f32]", size, [data](size_t i) { return data[i]; }); + break; + } + case ov::element::f16: { + const ov::float16 * data = tensor.data(); + size_t size = tensor.get_size(); + print_float_stats("[f16]", size, [data](size_t i) { return static_cast(data[i]); }); + break; + } + default: + break; + } +} + +void set_zero_diagonal(std::vector & matrix, size_t rows, size_t cols) { + for (size_t i = 0; i < rows; ++i) { + size_t diag_col = std::min(i, cols - 1); + matrix[i * cols + diag_col] = 0.0f; + } +} + +const ggml_tensor * get_inp_pos_tensor(ggml_cgraph * cgraph) { + for (int i = 0; i < cgraph->n_nodes; ++i) { + auto * op = cgraph->nodes[i]; + for (int j = 0; j < GGML_MAX_SRC; ++j) { + auto * src = op->src[j]; + if (src == nullptr) { + break; + } + if (GgmlOvDecoder::is_inp_pos(src, op)) { + return src; + } + } + } + GGML_LOG_ERROR("get_inp_pos_tensor: inp_pos not found in cgraph"); + throw std::runtime_error("get_inp_pos_tensor: inp_pos not found in cgraph"); +} + +bool get_is_prefill(const ggml_tensor * inp_pos) { + return inp_pos->ne[0] > 1; +} + +#pragma GCC diagnostic pop diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h new file mode 100644 index 00000000000..656573d1389 --- /dev/null +++ b/ggml/src/ggml-openvino/utils.h @@ -0,0 +1,123 @@ +#include "ggml-backend-impl.h" +#include "ggml-decoder.h" +#include "ggml-impl.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +struct graph_key { + int n_nodes; + std::string first_node_name; + std::string last_node_name; + + graph_key(const ggml_cgraph * cgraph) : n_nodes(cgraph->n_nodes) { + if (n_nodes > 0) { + first_node_name = cgraph->nodes[0]->name; + last_node_name = cgraph->nodes[n_nodes - 1]->name; + } + } + + bool operator==(const graph_key & other) const { + return n_nodes == other.n_nodes && first_node_name == other.first_node_name && + last_node_name == other.last_node_name; + } +}; + +struct graph_key_hash { + size_t operator()(const graph_key & key) const { + size_t h = std::hash{}(key.n_nodes); + if (key.n_nodes > 0) { + h ^= std::hash{}(key.first_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash{}(key.last_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2); + } + return h; + } +}; + +struct ov_runtime_context { + std::mutex ov_compute_mutex; + std::string device; + bool stateful; + std::unordered_map, graph_key_hash> decoder_cache; + std::unordered_map, graph_key_hash> infer_request_cache; + std::unordered_map, graph_key_hash> infer_request_cache_prefill; + std::unordered_map, graph_key_hash> ov_input_names_cache; + std::unordered_map, graph_key_hash> ov_output_names_cache; + //TODO: Stateful is only supported for single request at a time. + // Simultanous stateful inference request support to be added. + size_t stateful_kv_size; + std::map kv_state_input_name_map; + + ov_runtime_context() : + device("CPU"), + stateful(false), + stateful_kv_size(0) {} +}; + +enum ggml_status ov_graph_compute(struct ggml_cgraph * cgraph, ggml_backend_t backend); + +enum ggml_status ov_graph_compute_dynamic(struct ggml_cgraph * cgraph, std::shared_ptr r_ctx); +enum ggml_status ov_graph_compute_static(struct ggml_cgraph * cgraph, std::shared_ptr r_ctx); + +size_t checksum(const void * data, size_t size); + +void print_input_tensor_info(const std::string & name, const ov::Tensor & tensor); + +void print_output_tensor_info(const std::string & name, const ov::Tensor & tensor, const void * output_dst); + +template +std::vector pad_input(const T * data, + size_t rows, + size_t cols, + size_t padded_rows, + size_t padded_cols, + T pad_value) { + std::vector padded(padded_rows * padded_cols, pad_value); + + for (size_t i = 0; i < std::min(rows, padded_rows); ++i) { + for (size_t j = 0; j < std::min(cols, padded_cols); ++j) { + padded[i * padded_cols + j] = data[i * cols + j]; + } + } + + return padded; +} + +template +std::vector pad_input(const ggml_tensor * tensor, size_t padded_rows, size_t padded_cols, T pad_value) { + return pad_input(reinterpret_cast(tensor->data), + static_cast(tensor->ne[1]), // rows + static_cast(tensor->ne[0]), // cols + padded_rows, padded_cols, pad_value); +} + +void set_zero_diagonal(std::vector & matrix, size_t rows, size_t cols); + +const ggml_tensor * get_inp_pos_tensor(struct ggml_cgraph * cgraph); + +bool get_is_prefill(const ggml_tensor * inp_pos); + +ov::Tensor get_ov_input_tensor(std::shared_ptr ggml_decoder, const std::string & param_name); +ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr ggml_decoder, + const std::string & param_name); +ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr ggml_decoder, + const std::string & param_name, + int chunk_index); + +ov::Tensor create_ov_output_tensor(std::shared_ptr ggml_decoder, + std::shared_ptr infer_request, + int output_index, + const ggml_tensor * ggml_tensor); + +bool is_naive(struct ggml_cgraph * cgraph); + +enum ggml_status naive_compute(struct ggml_cgraph * cgraph, + ov::Core & core, + const std::string & device, + const ov::AnyMap & config); From 8ad5cb1e9d9c5d81c2d9ed64993625509b16a89f Mon Sep 17 00:00:00 2001 From: Wallentri Date: Sat, 14 Mar 2026 10:43:13 +0300 Subject: [PATCH 273/831] Use fp32 in cuBLAS V100 to avoid overflows, env variables to override cuBLAS compute type (llama/19959) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update ggml-cuda.cu * Update ggml-cuda.cu * Update build.md * Update build.md * Update ggml/src/ggml-cuda/ggml-cuda.cu Co-authored-by: Johannes Gäßler * Update ggml-cuda.cu * Update build.md * Update ggml/src/ggml-cuda/ggml-cuda.cu Co-authored-by: Johannes Gäßler * Update build.md * Update ggml-cuda.cu * Update ggml-cuda.cu --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/ggml-cuda.cu | 71 +++++++++++++++++++++++++-------- 1 file changed, 55 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 9d2aacf4b2c..ce7a80acde8 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1242,6 +1242,34 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( } } +struct cublas_force_compute_type { + bool fp32 = false; + bool fp16 = false; +}; + +static const cublas_force_compute_type & ggml_cuda_cublas_get_force_compute_type() { + static const cublas_force_compute_type compute_type = [] { + cublas_force_compute_type result; + + const bool ggml_cuda_force_cublas_compute_32f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F") != nullptr; + const bool ggml_cuda_force_cublas_compute_16f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F") != nullptr; + + GGML_ASSERT(ggml_cuda_force_cublas_compute_16f_env == false || ggml_cuda_force_cublas_compute_32f_env == false); + + if (ggml_cuda_force_cublas_compute_32f_env) { + GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F\n"); + result.fp32 = true; + } else if (ggml_cuda_force_cublas_compute_16f_env) { + GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F\n"); + result.fp16 = true; + } + + return result; + }(); + + return compute_type; +} + static void ggml_cuda_op_mul_mat_cublas( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, @@ -1324,7 +1352,13 @@ static void ggml_cuda_op_mul_mat_cublas( CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { + const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type(); + + if (!force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc) + || GGML_CUDA_CC_IS_RDNA4(cc) + || cc == GGML_CUDA_CC_VOLTA + || force_compute_type.fp32)) + { const float alpha = 1.0f; const float beta = 0.0f; CUBLAS_CHECK( @@ -1923,10 +1957,23 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct cudaDataType_t cu_data_type_b = traits::data_type; const void * alpha = traits::get_alpha(); const void * beta = traits::get_beta(); - const float alpha_f32 = 1.0f; - const float beta_f32 = 0.0f; - if (dst->op_params[0] == GGML_PREC_DEFAULT) { + const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type(); + + int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + static constexpr bool is_src0_type_f16 = src0_type == GGML_TYPE_F16; + + // bf16 and fp32 are already being computed in fp32 (ensure it using static_assert), + // so checking necessity of forced fp32 only for fp16 src0_type + static_assert(is_src0_type_f16 || traits::compute_type == CUBLAS_COMPUTE_32F); + + const bool need_compute_32f = is_src0_type_f16 && !force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc) + || GGML_CUDA_CC_IS_RDNA4(cc) + || cc == GGML_CUDA_CC_VOLTA + || force_compute_type.fp32); + + if (dst->op_params[0] == GGML_PREC_DEFAULT && !need_compute_32f) { if constexpr (src0_type == GGML_TYPE_F32) { dst_t = (char *) dst_ddf; // Direct F32 output } else { @@ -1936,18 +1983,10 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct } } else { dst_t = (char *) dst_ddf; - cu_compute_type = CUBLAS_COMPUTE_32F; - cu_data_type = CUDA_R_32F; - alpha = &alpha_f32; - beta = &beta_f32; - } - - int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { - cu_compute_type = CUBLAS_COMPUTE_32F; - alpha = &alpha_f32; - beta = &beta_f32; + cu_compute_type = batched_mul_mat_traits::compute_type; + cu_data_type = batched_mul_mat_traits::data_type; + alpha = batched_mul_mat_traits::get_alpha(); + beta = batched_mul_mat_traits::get_beta(); } GGML_ASSERT(ne12 % ne02 == 0); From 93d09fdb2376943ba5740729219ed984adaf31da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Sat, 14 Mar 2026 10:06:14 +0100 Subject: [PATCH 274/831] ggml : add native AVX512-FP16 support for F16 operations (llama/20529) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The overall benchmark speed remains almost the same because the CPU is now calculating faster than the RAM can deliver the data. (See perf stat results below showing 2.7 billion fewer instructions). Also note that this path will be only enabled for native build or with custom flags. now: ``` Performance counter stats for 'build/bin/llama-bench -m Qwen3-0.6B-f16.gguf -p 512 -n 128': 189,073.52 msec task-clock # 14.658 CPUs utilized 404 context-switches # 2.137 /sec 19 cpu-migrations # 0.100 /sec 372,390 page-faults # 1.970 K/sec 310,877,195,595 instructions # 0.54 insn per cycle 581,071,530,602 cycles # 3.073 GHz 19,352,107,994 branches # 102.352 M/sec 48,304,438 branch-misses # 0.25% of all branches 84,998,431,152 L1-dcache-loads # 449.552 M/sec 12,186,410,279 L1-dcache-load-misses # 14.34% of all L1-dcache accesses 12.899358742 seconds time elapsed 187.823044000 seconds user 1.253416000 seconds sys ``` before: ``` Performance counter stats for 'build/bin/llama-bench -m Qwen3-0.6B-f16.gguf -p 512 -n 128': 190,594.56 msec task-clock # 14.652 CPUs utilized 436 context-switches # 2.288 /sec 22 cpu-migrations # 0.115 /sec 372,782 page-faults # 1.956 K/sec 313,574,921,966 instructions # 0.54 insn per cycle 586,064,970,425 cycles # 3.075 GHz 19,585,778,563 branches # 102.761 M/sec 48,437,488 branch-misses # 0.25% of all branches 86,219,336,628 L1-dcache-loads # 452.370 M/sec 12,232,085,771 L1-dcache-load-misses # 14.19% of all L1-dcache accesses 13.007923164 seconds time elapsed 189.395316000 seconds user 1.202612000 seconds sys ``` Signed-off-by: Adrien Gallouët --- ggml/src/ggml-cpu/simd-mappings.h | 46 +++++++++++++++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 22de55700d4..0deda930985 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -479,13 +479,51 @@ do { \ // F16 AVX512 -// F16 AVX +#if defined(__AVX512FP16__) + +#define GGML_F16_STEP 128 +#define GGML_F16_EPR 32 + +#define GGML_F16x32 __m512h +#define GGML_F16x32_ZERO _mm512_setzero_ph() +#define GGML_F16x32_SET1(x) _mm512_set1_ph(__extension__(_Float16)(x)) +#define GGML_F16x32_LOAD(x) _mm512_loadu_ph(x) +#define GGML_F16x32_STORE(x, y) _mm512_storeu_ph(x, y) +#define GGML_F16x32_FMA(a, b, c) _mm512_fmadd_ph(b, c, a) +#define GGML_F16x32_ADD _mm512_add_ph +#define GGML_F16x32_MUL _mm512_mul_ph +#define GGML_F16x32_REDUCE(res, x) \ +do { \ + int offset = GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ph(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ph(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ph(x[i], x[offset+i]); \ + } \ + res = (ggml_float) _mm512_reduce_add_ph(x[0]); \ +} while (0) + +#define GGML_F16_VEC GGML_F16x32 +#define GGML_F16_VEC_ZERO GGML_F16x32_ZERO +#define GGML_F16_VEC_SET1 GGML_F16x32_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F16x32_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x32_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F16x32_FMA +#define GGML_F16_VEC_ADD GGML_F16x32_ADD +#define GGML_F16_VEC_MUL GGML_F16x32_MUL +#define GGML_F16_VEC_REDUCE GGML_F16x32_REDUCE + +#else // Fallback FP16 <-> FP32 #define GGML_F16_STEP 64 #define GGML_F16_EPR 16 -// AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead - #define GGML_F32Cx16 __m512 #define GGML_F32Cx16_ZERO _mm512_setzero_ps() #define GGML_F32Cx16_SET1(x) _mm512_set1_ps(x) @@ -525,6 +563,8 @@ do { \ #define GGML_F16_VEC_MUL GGML_F32Cx16_MUL #define GGML_F16_VEC_REDUCE GGML_F32Cx16_REDUCE + +#endif // __AVX512FP16__ #elif defined(__AVX__) #define GGML_SIMD From c5f9a49b51ac277b6bc0bd12cd5e3c3be64299ee Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Sat, 14 Mar 2026 22:01:57 +0800 Subject: [PATCH 275/831] add op gated_delta_net (llama/20455) --- ggml/src/ggml-sycl/common.hpp | 2 +- ggml/src/ggml-sycl/gated_delta_net.cpp | 309 +++++++++++++++++++++++++ ggml/src/ggml-sycl/gated_delta_net.hpp | 8 + ggml/src/ggml-sycl/ggml-sycl.cpp | 20 +- 4 files changed, 332 insertions(+), 7 deletions(-) create mode 100644 ggml/src/ggml-sycl/gated_delta_net.cpp create mode 100644 ggml/src/ggml-sycl/gated_delta_net.hpp diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 9f0efb65359..fcb0db99c6b 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -211,7 +211,7 @@ struct sycl_device_info { // number of compute units on a SYCL device. // size_t smpb; // max. shared memory per block size_t smpbo; // max. shared memory per block (with opt-in) - int warp_size; // max sub_group_size of SYCL + int warp_size; // WARP_SIZE(16)|WARP_32_SIZE(32)|WARP_16_SIZE(16). For Intel GPU, 16 is better in most cases. Some OP support 32 only. int max_wg_per_cu; // max work groups per compute unit - refer to // cudaOccupancyMaxActiveBlocksPerMultiprocessor bool vmm; // virtual memory support diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp new file mode 100644 index 00000000000..8c76afbd571 --- /dev/null +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -0,0 +1,309 @@ +#include +#include "dpct/helper.hpp" +#include "common.hpp" +#include "ggml.h" +#include "gated_delta_net.hpp" +#include + + +template +void gated_delta_net_sycl(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * curr_state, + float * dst, + int64_t H, + int64_t n_tokens, + int64_t n_seqs, + int64_t sq1, + int64_t sq2, + int64_t sq3, + int64_t sv1, + int64_t sv2, + int64_t sv3, + int64_t sb1, + int64_t sb2, + int64_t sb3, + const sycl::uint3 neqk1_magic, + const sycl::uint3 rq3_magic, + float scale) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const uint32_t h_idx = item_ct1.get_group(2); + const uint32_t sequence = item_ct1.get_group(1); + // each warp owns one column, using warp-level primitives to reduce across rows + const int lane = item_ct1.get_local_id(2); + const int col = item_ct1.get_group(0) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); + + const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic); + const uint32_t iq3 = fastdiv(sequence, rq3_magic); + + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + float * attn_data = dst; + float * state = dst + attn_score_elems; + + const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; + state += state_offset; + curr_state += state_offset; + attn_data += (sequence * n_tokens * H + h_idx) * S_v; + + constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v; + static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size"); + constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size; + float s_shard[rows_per_lane]; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = curr_state[i * S_v + col]; + } + + for (int t = 0; t < n_tokens; t++) { + const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1; + + const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1; + const float * beta_t = beta + gb_offset; + const float * g_t = g + gb_offset * (KDA ? S_v : 1); + + const float beta_val = *beta_t; + + if constexpr (!KDA) { + const float g_val = sycl::native::exp(*g_t); + + // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i] + float kv_shard = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + kv_shard += s_shard[r] * k_t[i]; + } + float kv_col = warp_reduce_sum(kv_shard); + + // delta[col] = (v[col] - g * kv[col]) * beta + float delta_col = (v_t[col] - g_val * kv_col) * beta_val; + + // fused: S[i][col] = g * S[i][col] + k[i] * delta[col] + // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] + float attn_partial = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col; + attn_partial += s_shard[r] * q_t[i]; + } + + float attn_col = warp_reduce_sum(attn_partial); + + if (lane == 0) { + attn_data[col] = attn_col * scale; + } + } else { + // kv[col] = sum_i g[i] * S[i][col] * k[i] + float kv_shard = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + kv_shard += sycl::native::exp(g_t[i]) * s_shard[r] * k_t[i]; + } + + float kv_col = warp_reduce_sum(kv_shard); + + // delta[col] = (v[col] - kv[col]) * beta + float delta_col = (v_t[col] - kv_col) * beta_val; + + // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col] + // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] + float attn_partial = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = sycl::native::exp(g_t[i]) * s_shard[r] + k_t[i] * delta_col; + attn_partial += s_shard[r] * q_t[i]; + } + + float attn_col = warp_reduce_sum(attn_partial); + + if (lane == 0) { + attn_data[col] = attn_col * scale; + } + } + + attn_data += S_v * H; + } + + // Write state back to global memory +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[i * S_v + col] = s_shard[r]; + } +} + +template +static void launch_gated_delta_net(const float * q_d, + const float * k_d, + const float * v_d, + const float * g_d, + const float * b_d, + const float * s_d, + float * dst_d, + int64_t S_v, + int64_t H, + int64_t n_tokens, + int64_t n_seqs, + int64_t sq1, + int64_t sq2, + int64_t sq3, + int64_t sv1, + int64_t sv2, + int64_t sv3, + int64_t sb1, + int64_t sb2, + int64_t sb3, + int64_t neqk1, + int64_t rq3, + float scale, + dpct::queue_ptr stream) { + //TODO: Add chunked kernel for even faster pre-fill + const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size; + + const int num_warps = 4; + dpct::dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps); + dpct::dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1); + + const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1); + const sycl::uint3 rq3_magic = init_fastdiv_values(rq3); + + int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc; + + switch (S_v) { + case 16: + { + constexpr int sv = 16; + stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, + n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, + sb3, neqk1_magic, rq3_magic, scale); + }); + } + break; + case 32: + { + constexpr int sv = 32; + stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, + n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, + sb3, neqk1_magic, rq3_magic, scale); + }); + } + break; + case 64: { + { + constexpr int sv = 64; + stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + gated_delta_net_sycl( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, + sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + }); + } + break; + } + case 128: { + { + constexpr int sv = 128; + stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + gated_delta_net_sycl( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, + sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + }); + } + break; + } + default: + GGML_ABORT("fatal error"); + break; + } +} + +void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_tensor * src_q = dst->src[0]; + ggml_tensor * src_k = dst->src[1]; + ggml_tensor * src_v = dst->src[2]; + ggml_tensor * src_g = dst->src[3]; + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + + GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); + GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne); + GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb); + GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); + + const int64_t S_v = nev0; + const int64_t H = nev1; + const int64_t n_tokens = nev2; + const int64_t n_seqs = nev3; + + const bool kda = (src_g->ne[0] == S_v); + + GGML_ASSERT(neq1 == nek1); + const int64_t neqk1 = neq1; + + const int64_t rq3 = nev3 / neq3; + + const float * q_d = (const float *) src_q->data; + const float * k_d = (const float *) src_k->data; + const float * v_d = (const float *) src_v->data; + const float * g_d = (const float *) src_g->data; + const float * b_d = (const float *) src_beta->data; + + const float * s_d = (const float *) src_state->data; + float * dst_d = (float *) dst->data; + + GGML_ASSERT(ggml_is_contiguous_rows(src_q)); + GGML_ASSERT(ggml_is_contiguous_rows(src_k)); + GGML_ASSERT(ggml_is_contiguous_rows(src_v)); + GGML_ASSERT(ggml_are_same_stride(src_q, src_k)); + GGML_ASSERT(src_g->ne[0] == 1 || kda); + GGML_ASSERT(ggml_is_contiguous(src_g)); + GGML_ASSERT(ggml_is_contiguous(src_beta)); + GGML_ASSERT(ggml_is_contiguous(src_state)); + + // strides in floats (beta strides used for both g and beta offset computation) + const int64_t sq1 = nbq1 / sizeof(float); + const int64_t sq2 = nbq2 / sizeof(float); + const int64_t sq3 = nbq3 / sizeof(float); + const int64_t sv1 = nbv1 / sizeof(float); + const int64_t sv2 = nbv2 / sizeof(float); + const int64_t sv3 = nbv3 / sizeof(float); + const int64_t sb1 = nbb1 / sizeof(float); + const int64_t sb2 = nbb2 / sizeof(float); + const int64_t sb3 = nbb3 / sizeof(float); + + const float scale = 1.0f / sqrtf((float) S_v); + + dpct::queue_ptr stream = ctx.stream(); + + if (kda) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, stream); + } +} + +void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6); + ggml_sycl_op_gated_delta_net(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/gated_delta_net.hpp b/ggml/src/ggml-sycl/gated_delta_net.hpp new file mode 100644 index 00000000000..a3308ee8763 --- /dev/null +++ b/ggml/src/ggml-sycl/gated_delta_net.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include +#include "dpct/helper.hpp" +#include "common.hpp" +#include "ggml.h" + +void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index f887061b279..12819705849 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -35,6 +35,7 @@ #endif #include +#include "ggml.h" #include "ggml-sycl.h" #include "ggml-impl.h" #include "ggml-backend-impl.h" @@ -43,17 +44,18 @@ #include "ggml-sycl/backend.hpp" #include "ggml-sycl/common.hpp" #include "ggml-sycl/element_wise.hpp" +#include "ggml-sycl/gated_delta_net.hpp" +#include "ggml-sycl/gemm.hpp" +#include "ggml-sycl/getrows.hpp" #include "ggml-sycl/norm.hpp" #include "ggml-sycl/presets.hpp" -#include "ggml-sycl/gemm.hpp" +#include "ggml-sycl/quantize.hpp" +#include "ggml-sycl/repeat_back.hpp" #include "ggml-sycl/set_rows.hpp" #include "ggml-sycl/set.hpp" -#include "ggml-sycl/sycl_hw.hpp" -#include "ggml-sycl/getrows.hpp" -#include "ggml-sycl/repeat_back.hpp" -#include "ggml-sycl/quantize.hpp" #include "ggml-sycl/ssm_conv.hpp" -#include "ggml.h" +#include "ggml-sycl/sycl_hw.hpp" + static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; @@ -99,6 +101,8 @@ static ggml_sycl_device_info ggml_sycl_init() { info.devices[i].nsm = prop.get_max_compute_units() / 16; //16: Number of Xe Cores info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu); info.devices[i].smpbo = prop.get_local_mem_size(); + info.devices[i].warp_size = WARP_SIZE; + info.max_work_group_sizes[i] = prop.get_max_work_group_size(); info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units(); @@ -4181,6 +4185,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_GATED_LINEAR_ATTN: ggml_sycl_op_gated_linear_attn(ctx, dst); break; + case GGML_OP_GATED_DELTA_NET: + ggml_sycl_gated_delta_net(ctx, dst); + break; case GGML_OP_SSM_CONV: ggml_sycl_ssm_conv(ctx, dst); break; @@ -4890,6 +4897,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_GATED_DELTA_NET: return true; case GGML_OP_SSM_CONV: return op->type == GGML_TYPE_F32 && From 55f8cfdaed7f44b1009293e4c032f075a973ded5 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Sat, 14 Mar 2026 11:09:08 -0700 Subject: [PATCH 276/831] hexagon: Q4_0 and MXFP4 repack fixes (llama/20527) * hexagon: fix tail corruption with rows sizes not multiple of 256 * hexagon: use different stride for repacking partial blocks * hex-mm: update repack and kernels to avoid shuffles for full 256-element blocks Previous commit changed the repacking to use even:odd (0:1,2:3,..) packing instead of the original (0:128,1:129,...) packing in order to fix tail corruption. Since the mm kernels already deal with partial tails we can use even:odd packing only for the last block. This avoid performance penalty of having to shuffle to zip the elements in the common case. * hex-mm: update rmpy x8 for better optimizations * hex-mm: tighten supported MUL_MAT checks to avoid spurios failures * hex-mm: use vzero to init accumulators * hex-mm: properly call partial rmpy_x8 --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 43 ++- ggml/src/ggml-hexagon/htp/matmul-ops.c | 415 +++++++++++++++---------- 2 files changed, 291 insertions(+), 167 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index d6e9776b878..19917cb1140 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -402,6 +402,7 @@ static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { static const int qk = QK_Q4_0x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers const int dblk_size = 8 * 2; // 8x __fp16 const int qblk_size = qk / 2; // int4 @@ -435,9 +436,11 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { unpack_q4_0_quants(qs, &x[i * 8 + 6], 6); unpack_q4_0_quants(qs, &x[i * 8 + 7], 7); + bool partial = (nloe && i == nb-1); + uint8_t * q = y_q + (i * qblk_size); for (int j = 0; j < qk / 2; j++) { - q[j] = (qs[j + 128] << 4) | qs[j]; + q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; } } @@ -467,6 +470,7 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { static const int qk = QK_Q4_0x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers const int dblk_size = 8 * 2; // 8x __fp16 const int qblk_size = qk / 2; // int4 @@ -485,10 +489,17 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { for (int i = 0; i < nb; i++) { uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + bool partial = (nloe && i == nb-1); + const uint8_t * q = y_q + (i * qblk_size); for (int j = 0; j < qk / 2; j++) { - qs[j] = q[j] & 0xf; - qs[j + 128] = q[j] >> 4; + if (partial) { + qs[j*2+0] = q[j] & 0xf; + qs[j*2+1] = q[j] >> 4; + } else { + qs[j+000] = q[j] & 0xf; + qs[j+128] = q[j] >> 4; + } } pack_q4_0_quants(&x[i * 8 + 0], qs, 0); @@ -1078,6 +1089,7 @@ static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) { static const int qk = QK_MXFP4x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers const int eblk_size = 8 * 1; // 8x E8M0 const int qblk_size = qk / 2; // int4 @@ -1112,9 +1124,11 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6); unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7); + bool partial = (nloe && i == nb-1); + uint8_t * q = y_q + (i * qblk_size); for (int j = 0; j < qk / 2; j++) { - q[j] = (qs[j + 128] << 4) | qs[j]; + q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; } } @@ -1144,6 +1158,7 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) { static const int qk = QK_MXFP4x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers const int eblk_size = 8 * 1; // 8x E8M0 const int qblk_size = qk / 2; // int4 @@ -1162,10 +1177,17 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) for (int i = 0; i < nb; i++) { uint8_t qs[QK_MXFP4x4x2]; // unpacked quants + bool partial = (nloe && i == nb-1); + const uint8_t * q = y_q + (i * qblk_size); for (int j = 0; j < qk / 2; j++) { - qs[j] = q[j] & 0xf; - qs[j + 128] = q[j] >> 4; + if (partial) { + qs[j*2+0] = q[j] & 0xf; + qs[j*2+1] = q[j] >> 4; + } else { + qs[j+000] = q[j] & 0xf; + qs[j+128] = q[j] >> 4; + } } pack_mxfp4_quants(&x[i * 8 + 0], qs, 0); @@ -1801,12 +1823,12 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s return false; } - if (src0->ne[1] > 16 * 1024) { + if (ggml_nrows(src0) > 16 * 1024) { return false; // typically the lm-head which would be too large for VTCM } - if ((src1->ne[2] != 1 || src1->ne[3] != 1)) { - return false; + if (ggml_nrows(src1) > 1024 || src1->ne[2] != 1 || src1->ne[3] != 1) { + return false; // no huge batches or broadcasting (for now) } // src0 (weights) must be repacked @@ -1820,6 +1842,9 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n"); return false; } + if (ggml_nrows(src1) > 1024) { + return false; // no huge batches (for now) + } break; default: diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 9ca74aedfef..73aaba79ebf 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -77,7 +77,7 @@ static inline size_t q8x4x2_row_size(uint32_t ne) { return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128); } -static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { +static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) @@ -88,9 +88,9 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ... HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 @@ -111,7 +111,41 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { return r; } -static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) { +static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i=0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i*2+0] = Q6_Vb_vsub_VbVb(v0, i8); + r.v[i*2+1] = Q6_Vb_vsub_VbVb(v1, i8); + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... + r.v[i*2+0] = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v0_1_p), i8); + r.v[i*2+1] = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v0_1_p), i8); + } + + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) @@ -144,7 +178,41 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) return r; } -static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) { +static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i=0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... + r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0); + r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0); + } + + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; HVX_Vector v0 = vptr[0]; // first 128 vals @@ -160,6 +228,10 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) { return r; } +static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_partial(const uint8_t * restrict ptr, uint32_t nloe) { + return hvx_vec_load_q8x4x8_full(ptr); +} + // Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors). // Accumulate each block into a single int32 value. // Return a single HVX vector with 32x int32 accumulators. @@ -167,14 +239,14 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) { // if() checks are optimized out at compile time -- make sure to pass N as a constexpr. static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { - HVX_Vector r0 = Q6_V_vsplat_R(0); - HVX_Vector r1 = Q6_V_vsplat_R(0); - HVX_Vector r2 = Q6_V_vsplat_R(0); - HVX_Vector r3 = Q6_V_vsplat_R(0); - HVX_Vector r4 = Q6_V_vsplat_R(0); - HVX_Vector r5 = Q6_V_vsplat_R(0); - HVX_Vector r6 = Q6_V_vsplat_R(0); - HVX_Vector r7 = Q6_V_vsplat_R(0); + HVX_Vector r0 = Q6_V_vzero(); + HVX_Vector r1 = Q6_V_vzero(); + HVX_Vector r2 = Q6_V_vzero(); + HVX_Vector r3 = Q6_V_vzero(); + HVX_Vector r4 = Q6_V_vzero(); + HVX_Vector r5 = Q6_V_vzero(); + HVX_Vector r6 = Q6_V_vzero(); + HVX_Vector r7 = Q6_V_vzero(); HVX_VectorPair p3; HVX_VectorPair p2; @@ -213,15 +285,42 @@ static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, uns } static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) { - return hvx_vec_rmpy_x8_n(x, y, 1024); + HVX_Vector r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); + HVX_Vector r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); + HVX_Vector r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); + HVX_Vector r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); + HVX_Vector r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); + HVX_Vector r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); + HVX_Vector r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); + HVX_Vector r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); + + HVX_VectorPair p0 = Q6_W_vdeal_VVR(r1, r0, -4); + HVX_VectorPair p1 = Q6_W_vdeal_VVR(r3, r2, -4); + HVX_VectorPair p2 = Q6_W_vdeal_VVR(r5, r4, -4); + HVX_VectorPair p3 = Q6_W_vdeal_VVR(r7, r6, -4); + + r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); + r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); + r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); + r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); + + p0 = Q6_W_vdeal_VVR(r1, r0, -4); + p1 = Q6_W_vdeal_VVR(r3, r2, -4); + + r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); + r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); + + p0 = Q6_W_vdeal_VVR(r1, r0, -4); + r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); + + return r0; } -// Handle most common cases of tensors not multiple of 1024. -static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { - if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); }; - if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); }; - if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); }; - return hvx_vec_rmpy_x8_n(x, y, 1024); +static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { + if (n >= 512) + return hvx_vec_rmpy_x8_full(x, y); + + return hvx_vec_rmpy_x8_partial(x, y, 512); } static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { @@ -246,7 +345,7 @@ static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). @@ -257,12 +356,12 @@ static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); @@ -272,19 +371,19 @@ static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks + // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r0_ia = Q6_V_vand_QV(bmask, r0_ia); @@ -326,8 +425,8 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). @@ -338,14 +437,14 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); @@ -359,23 +458,23 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks + // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); @@ -423,10 +522,10 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); const uint32_t nb = n / qk; // num full blocks const uint32_t nloe = n % qk; // num leftover elements @@ -434,12 +533,12 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * uint32_t i = 0; for (; i < nb; i++) { // Load src1 columns (reused across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); @@ -448,8 +547,8 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); // Load scales - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); @@ -473,18 +572,18 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * // Process leftovers if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); - - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); @@ -545,7 +644,7 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). @@ -556,12 +655,12 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); @@ -571,19 +670,19 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks + // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r0_ia = Q6_V_vand_QV(bmask, r0_ia); @@ -625,8 +724,8 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (qf32) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). @@ -637,14 +736,14 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); @@ -658,14 +757,14 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks + // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); @@ -674,7 +773,7 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); @@ -722,10 +821,10 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); const uint32_t nb = n / qk; // num full blocks const uint32_t nloe = n % qk; // num leftover elements @@ -733,12 +832,12 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * uint32_t i = 0; for (; i < nb; i++) { // Load src1 columns (reused across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); @@ -747,8 +846,8 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); // Load scales - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); @@ -772,18 +871,18 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * // Process leftovers if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); - - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); @@ -792,7 +891,7 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); @@ -844,7 +943,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). @@ -855,8 +954,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); @@ -887,12 +986,12 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving @@ -954,8 +1053,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). @@ -966,9 +1065,9 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); @@ -1007,14 +1106,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); @@ -1087,10 +1186,10 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); const uint32_t nb = n / qk; // num full blocks const uint32_t nloe = n % qk; // num leftover elements @@ -1098,12 +1197,12 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float uint32_t i = 0; for (; i < nb; i++) { // Load src1 columns (reused across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); @@ -1157,15 +1256,15 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float // Process leftovers if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial( y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial( y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); @@ -1234,7 +1333,7 @@ static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum_p = Q6_W_vzero(); uint32_t i = 0; @@ -1264,8 +1363,8 @@ static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, uint32_t nvec = n / VLEN_FP16; uint32_t nloe = n % VLEN_FP16; - HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); - HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum0_p = Q6_W_vzero(); + HVX_VectorPair rsum1_p = Q6_W_vzero(); uint32_t i = 0; @@ -1303,10 +1402,10 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res uint32_t nloe = n % VLEN_FP16; // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_VectorPair r0_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); - HVX_VectorPair r0_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); - HVX_VectorPair r1_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); - HVX_VectorPair r1_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair r0_c0_sum_p = Q6_W_vzero(); + HVX_VectorPair r0_c1_sum_p = Q6_W_vzero(); + HVX_VectorPair r1_c0_sum_p = Q6_W_vzero(); + HVX_VectorPair r1_c1_sum_p = Q6_W_vzero(); uint32_t i = 0; @@ -1358,7 +1457,7 @@ static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vzero(); uint32_t i = 0; @@ -1388,9 +1487,9 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_Vector zero = Q6_V_vsplat_R(0); + const HVX_Vector zero = Q6_V_vzero(); - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vzero(); uint32_t i = 0; @@ -1973,7 +2072,7 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric assert((unsigned long) y_q % 128 == 0); HVX_Vector * vx = (HVX_Vector *) x; - HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector zero = Q6_V_vzero(); // Use reduce max fp32 to find max(abs(e)) first HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); @@ -2034,7 +2133,7 @@ static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restric HVX_Vector * vx = (HVX_Vector *) x; // Load and convert into QF32 - HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector zero = Q6_V_vzero(); HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements @@ -2077,7 +2176,7 @@ static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restric HVX_Vector * vx = (HVX_Vector *) x; // Load and convert into QF32 - HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector zero = Q6_V_vzero(); HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements From b312018435f8576dcf32c3f7ca5acd6ce985b278 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 14 Mar 2026 23:15:47 +0200 Subject: [PATCH 277/831] metal : add FA specialization for HSK = 320, HSV = 256 (llama/20549) --- ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal.metal | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index b7d587f3bd9..82101f4714e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1142,6 +1142,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->src[0]->ne[0] != 128 && op->src[0]->ne[0] != 192 && op->src[0]->ne[0] != 256 && + op->src[0]->ne[0] != 320 && op->src[0]->ne[0] != 576) { return false; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index d4b129ed756..b2328605dd9 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -6176,6 +6176,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6190,6 +6191,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_HAS_BF16) @@ -6205,6 +6207,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif @@ -6220,6 +6223,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6234,6 +6238,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6248,6 +6253,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6262,6 +6268,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6276,6 +6283,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES @@ -6846,6 +6854,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) From cd02195b8fbfe1d0ac505ed43daeb548a912b279 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sun, 15 Mar 2026 08:18:54 +0100 Subject: [PATCH 278/831] vulkan: use graphics queue on AMD (llama/20551) * vulkan: use graphics queue on AMD for slightly better performance * disable async transfer queue on AMD --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3c81805b844..7092361d2ea 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4981,8 +4981,10 @@ static vk_device ggml_vk_get_device(size_t idx) { std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues - const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1); - const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1); + // On AMD, the graphics queue seems to be faster, so don't avoid it + const vk::QueueFlagBits graphics_flag = device->vendor_id == VK_VENDOR_ID_AMD ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics; + const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, graphics_flag, -1, 1); + const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | graphics_flag, compute_queue_family_index, 1); const float priorities[] = { 1.0f, 1.0f }; device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1; @@ -5441,13 +5443,11 @@ static vk_device ggml_vk_get_device(size_t idx) { ggml_vk_load_shaders(device); - const bool prefers_transfer_queue = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN; - if (!device->single_queue) { const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); - device->async_use_transfer_queue = prefers_transfer_queue || (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr); + device->async_use_transfer_queue = (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr); } else { // TODO: Use pointer or reference to avoid copy device->transfer_queue.copyFrom(device->compute_queue); From 55c66106afa1a7703af6af654ebc99fb4264251d Mon Sep 17 00:00:00 2001 From: PikaPikachu Date: Sun, 15 Mar 2026 15:33:39 +0800 Subject: [PATCH 279/831] cuda : add RDNA4-specific MMVQ parameter table for bs=1 decode (llama/19478) * mmvq: add RDNA3/RDNA4-specific parameter table (nwarps=8, rows=1) * mmvq: add dedicated RDNA3 parameter table * mmvq: exclude RDNA3.5 (gfx1150/1151) from RDNA3 table --- ggml/src/ggml-cuda/mmvq.cu | 89 ++++++++++++++++++++++++++------ ggml/src/ggml-cuda/vendors/hip.h | 8 +++ 2 files changed, 81 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index ce25ccf427c..632246e43fd 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -60,11 +60,17 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { enum mmvq_parameter_table_id { MMVQ_PARAMETERS_GENERIC = 0, MMVQ_PARAMETERS_GCN, - MMVQ_PARAMETERS_RDNA2 + MMVQ_PARAMETERS_RDNA2, + MMVQ_PARAMETERS_RDNA3_0, + MMVQ_PARAMETERS_RDNA4 }; static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { -#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4) +#if defined(RDNA4) + return MMVQ_PARAMETERS_RDNA4; +#elif defined(RDNA3_0) + return MMVQ_PARAMETERS_RDNA3_0; +#elif defined(RDNA2) || defined(RDNA3_5) return MMVQ_PARAMETERS_RDNA2; #elif defined(GCN) || defined(CDNA) return MMVQ_PARAMETERS_GCN; @@ -74,7 +80,13 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { } static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { - if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return MMVQ_PARAMETERS_RDNA4; + } + if (GGML_CUDA_CC_IS_RDNA3_0(cc)) { + return MMVQ_PARAMETERS_RDNA3_0; + } + if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc)) { return MMVQ_PARAMETERS_RDNA2; } if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { @@ -83,7 +95,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { return MMVQ_PARAMETERS_GENERIC; } -static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) { +static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) { if (table_id == MMVQ_PARAMETERS_GENERIC) { switch (ncols_dst) { case 1: @@ -114,6 +126,50 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet return 1; } } + if (table_id == MMVQ_PARAMETERS_RDNA4) { + // nwarps=8 benefits types with simple vec_dot on RDNA4 (ncols_dst=1). + // Types with complex vec_dot (Q3_K, IQ2_*, IQ3_*) regress due to register + // pressure and lookup table contention at higher thread counts. + if (ncols_dst == 1) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + return 8; + default: + return 1; + } + } + return 1; + } + if (table_id == MMVQ_PARAMETERS_RDNA3_0) { + // RDNA3 (W7900): stricter whitelist than RDNA4. + // Q2_K / Q5_K / IQ4_XS regress in full quant sweeps. + if (ncols_dst == 1) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + return 8; + default: + return 1; + } + } + return 1; + } return 1; } @@ -138,7 +194,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int } template -__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) +__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, @@ -151,7 +207,7 @@ static __global__ void mul_mat_vec_q( constexpr int qi = ggml_cuda_type_traits::qi; constexpr int vdr = get_vdr_mmvq(type); constexpr mmvq_parameter_table_id table_id = get_device_table_id(); - constexpr int nwarps = calc_nwarps(ncols_dst, table_id); + constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id); constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); @@ -355,12 +411,13 @@ static __global__ void mul_mat_vec_q( } } +template static std::pair calc_launch_params( const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, const int warp_size, const mmvq_parameter_table_id table_id) { const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); - const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1); + const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1); return {block_nums, block_dims}; } @@ -420,7 +477,7 @@ static void mul_mat_vec_q_switch_ncols_dst( if (has_ids && ncols_dst > 1) { // Multi-token MUL_MAT_ID path only - single-token goes through regular path below constexpr int c_ncols_dst = 1; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, @@ -431,7 +488,7 @@ static void mul_mat_vec_q_switch_ncols_dst( switch (ncols_dst) { case 1: { constexpr int c_ncols_dst = 1; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, @@ -439,7 +496,7 @@ static void mul_mat_vec_q_switch_ncols_dst( } break; case 2: { constexpr int c_ncols_dst = 2; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, @@ -447,7 +504,7 @@ static void mul_mat_vec_q_switch_ncols_dst( } break; case 3: { constexpr int c_ncols_dst = 3; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, @@ -455,7 +512,7 @@ static void mul_mat_vec_q_switch_ncols_dst( } break; case 4: { constexpr int c_ncols_dst = 4; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, @@ -463,7 +520,7 @@ static void mul_mat_vec_q_switch_ncols_dst( } break; case 5: { constexpr int c_ncols_dst = 5; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, @@ -471,7 +528,7 @@ static void mul_mat_vec_q_switch_ncols_dst( } break; case 6: { constexpr int c_ncols_dst = 6; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, @@ -479,7 +536,7 @@ static void mul_mat_vec_q_switch_ncols_dst( } break; case 7: { constexpr int c_ncols_dst = 7; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, @@ -487,7 +544,7 @@ static void mul_mat_vec_q_switch_ncols_dst( } break; case 8: { constexpr int c_ncols_dst = 8; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 5cc1b54319c..35d1e1a0639 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -207,6 +207,14 @@ #define RDNA3 #endif // defined(__GFX11__) +#if defined(__gfx1150__) || defined(__gfx1151__) +#define RDNA3_5 +#endif // defined(__gfx1150__) || defined(__gfx1151__) + +#if defined(RDNA3) && !defined(RDNA3_5) +#define RDNA3_0 +#endif // defined(RDNA3) && !defined(RDNA3_5) + #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) #define RDNA2 From 6770239830f6e99c31b808cf7acacb52edc96faf Mon Sep 17 00:00:00 2001 From: Bartowski <3266127+bartowski1182@users.noreply.github.com> Date: Sun, 15 Mar 2026 04:47:28 -0400 Subject: [PATCH 280/831] ggml : guard against sumq2 being 0 in IQ4_NL (llama/20460) --- ggml/src/ggml-quants.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index cdaded865b1..48695a61ea3 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -4767,7 +4767,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block sumqx += w*q*xb[j]; sumq2 += w*q*q; } - d = sumqx/sumq2; + d = sumq2 > 0 ? sumqx/sumq2 : 0.f; float best = d*sumqx; for (int itry = -ntry; itry <= ntry; ++itry) { id = (itry + values[0])/max; From b327a321a21899e17c8923671500241ed9f93670 Mon Sep 17 00:00:00 2001 From: MoonShadow Date: Mon, 16 Mar 2026 00:23:58 +0800 Subject: [PATCH 281/831] ggml/hip: fix APU compatibility - soft error handling for hipMemAdviseSetCoarseGrain (llama/20536) * ggml/hip: fix APU compatibility - soft error handling for hipMemAdviseSetCoarseGrain On AMD APU/iGPU devices (unified memory architecture), hipMemAdviseSetCoarseGrain returns hipErrorInvalidValue because the hint is not applicable to UMA systems. The previous CUDA_CHECK() call treated this as a fatal error, causing crashes on APU systems such as AMD Strix Halo (gfx1151). Fix: treat hipMemAdviseSetCoarseGrain as an optional performance hint - call it without error checking and clear any resulting error with hipGetLastError(). Also add pre-allocation debug logging (GGML_LOG_DEBUG) to help diagnose memory issues on APU systems, and store totalGlobalMem in device info. Context: AMD APUs on Windows are affected by a ROCm runtime bug that limits hipMallocManaged to ~64GB regardless of available system RAM. A fix has been submitted upstream: https://github.com/ROCm/rocm-systems/pull/4077 Co-Authored-By: Claude Sonnet 4.6 * ggml/hip: remove unrelated changes, keep only hipMemAdviseSetCoarseGrain fix --------- Co-authored-by: moonshadow-25 Co-authored-by: Claude Sonnet 4.6 --- ggml/src/ggml-cuda/ggml-cuda.cu | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ce7a80acde8..3886290c5ff 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -124,7 +124,10 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) err = cudaMallocManaged(ptr, size); #if defined(GGML_USE_HIP) if (err == hipSuccess) { - CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device)); + // hipMemAdviseSetCoarseGrain is an optional performance hint; + // ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs). + cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device); + (void)hipGetLastError(); // clear any error } // fall back to cudaMalloc if not supported (e.g. on Windows) From 2fb6aea8ad4453dba649937e914b28d194c55c53 Mon Sep 17 00:00:00 2001 From: Pascal Date: Sun, 15 Mar 2026 17:42:56 +0100 Subject: [PATCH 282/831] ggml: avoid creating CUDA context during device init (llama/20595) --- ggml/src/ggml-cuda/ggml-cuda.cu | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 3886290c5ff..5a0be4a472a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -254,11 +254,6 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].supports_cooperative_launch = false; #endif // !(GGML_USE_MUSA) - // cudaMemGetInfo returns info for the current device - size_t free_mem; - CUDA_CHECK(cudaSetDevice(id)); - CUDA_CHECK(cudaMemGetInfo(&free_mem, NULL)); - #if defined(GGML_USE_HIP) info.devices[id].smpbo = prop.sharedMemPerBlock; @@ -273,25 +268,25 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc += prop.minor * 0x10; } } - GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d, VRAM: %zu MiB (%zu MiB free)\n", + GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d, VRAM: %zu MiB\n", id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, device_vmm ? "yes" : "no", prop.warpSize, - (size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024)); + (size_t)(prop.totalGlobalMem / (1024 * 1024))); #elif defined(GGML_USE_MUSA) // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs. info.devices[id].warp_size = 32; info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100; info.devices[id].cc += prop.minor * 0x10; - GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB (%zu MiB free)\n", + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no", - (size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024)); + (size_t)(prop.totalGlobalMem / (1024 * 1024))); #else info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; - GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB (%zu MiB free)\n", + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no", - (size_t)(prop.totalGlobalMem / (1024 * 1024)), free_mem / (1024 * 1024)); + (size_t)(prop.totalGlobalMem / (1024 * 1024))); std::string device_name(prop.name); if (device_name == "NVIDIA GeForce MX450") { turing_devices_without_mma.push_back({ id, device_name }); @@ -306,6 +301,7 @@ static ggml_cuda_device_info ggml_cuda_init() { // TODO: Check for future drivers the default scheduling strategy and // remove this call again when cudaDeviceScheduleSpin is default. if (prop.major == 12 && prop.minor == 1) { + CUDA_CHECK(cudaSetDevice(id)); CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin)); } From d7926e62d40e9ac9e5e9610421564f2506f10d1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 15 Mar 2026 18:30:47 +0100 Subject: [PATCH 283/831] CUDA: limit number of FA stream-k CUDA blocks (llama/20586) --- ggml/src/ggml-cuda/fattn-common.cuh | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index b6a7460da83..e9abdf288c4 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -892,7 +892,7 @@ void launch_fattn( const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); const int gqa_ratio = Q->ne[2] / K->ne[2]; const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2); - const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3]; + const int ntiles_dst = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3]; // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or @@ -919,37 +919,37 @@ void launch_fattn( GGML_ASSERT(max_blocks_per_sm > 0); int parallel_blocks = max_blocks_per_sm; + const int ntiles_KV = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by KV cache length. + dim3 blocks_num; if (stream_k) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. const int max_blocks = max_blocks_per_sm*nsm; - const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; - const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); + const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks; + const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves); - const int nblocks_stream_k = max_blocks; + const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst); const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75; - blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; + blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst; blocks_num.y = 1; blocks_num.z = 1; - if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2))); } } else { - const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size. - // parallel_blocks must not be larger than what the tensor size allows: - parallel_blocks = std::min(parallel_blocks, ntiles_KQ); + parallel_blocks = std::min(parallel_blocks, ntiles_KV); // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects. // Test whether parallel_blocks can be set to a higher value for better efficiency. const int blocks_per_wave = nsm * max_blocks_per_sm; int nwaves_best = 0; int efficiency_percent_best = 0; - for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) { - const int nblocks_total = ntiles_total * parallel_blocks_test; + for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KV; ++parallel_blocks_test) { + const int nblocks_total = ntiles_dst * parallel_blocks_test; const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave; const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); @@ -1015,7 +1015,7 @@ void launch_fattn( CUDA_CHECK(cudaGetLastError()); if (stream_k) { - if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; From 81ea958719e3dc38808664585ecc26ae1370b0b5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 15 Mar 2026 19:56:19 +0200 Subject: [PATCH 284/831] common : add nvfp4 (ggml/0) --- examples/common-ggml.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/common-ggml.cpp b/examples/common-ggml.cpp index c42b644fedd..6f02a2504c5 100644 --- a/examples/common-ggml.cpp +++ b/examples/common-ggml.cpp @@ -73,6 +73,7 @@ bool ggml_common_quantize_0( case GGML_FTYPE_MOSTLY_IQ1_M: case GGML_FTYPE_MOSTLY_BF16: case GGML_FTYPE_MOSTLY_MXFP4: + case GGML_FTYPE_MOSTLY_NVFP4: { fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype); return false; @@ -213,6 +214,7 @@ bool ggml_common_quantize_0( case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_COUNT: { fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); From d4bc312169376690a5615c1817b3e01c491f8c1c Mon Sep 17 00:00:00 2001 From: David366AI <86212041+David366AI@users.noreply.github.com> Date: Sun, 15 Mar 2026 15:50:56 -0400 Subject: [PATCH 285/831] ggml : extend im2col f16 (ggml/1434) * examples/yolo: fix load_model memory leak * fix/issue-1433 ggml_compute_forward_im2col_f16 assert error * fix/issue-1433 --- ggml/src/ggml-cpu/ops.cpp | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 314cc1088a0..3f85e531daa 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -6205,7 +6205,7 @@ static void ggml_compute_forward_im2col_f16( const ggml_tensor * src1 = dst->src[1]; GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F16); GGML_TENSOR_BINARY_OP_LOCALS; @@ -6236,7 +6236,7 @@ static void ggml_compute_forward_im2col_f16( int ofs1 = is_2D ? nb12 : nb11; GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] { @@ -6249,7 +6249,12 @@ static void ggml_compute_forward_im2col_f16( // micro kernel ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] + const float * const src_data_f32 = src1->type == GGML_TYPE_F32 + ? (const float *)((const char *) src1->data + in*ofs0 + iic*ofs1) + : nullptr; // [IH, IW] + const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16 + ? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1) + : nullptr; // [IH, IW] for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 for (int64_t ikw = 0; ikw < KW; ikw++) { @@ -6259,7 +6264,11 @@ static void ggml_compute_forward_im2col_f16( if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; } else { - dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]); + if (src_data_f32 != nullptr) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data_f32[iih*IW + iiw]); + } else { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw]; + } } } } From ab1252c19e74edba35b092552987d0e3d4c73fbe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 16 Mar 2026 07:13:51 +0200 Subject: [PATCH 286/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 444f031fa2e..709d00d40b3 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -75e4e5b841fc483127ebf14b8eed9ce589a52c5a +9d0addf420778b42c257cd3837fbd38ca4599f3b From 2bc630f197d9b97f3502fc1fa38c7f0f37783237 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 16 Mar 2026 07:16:46 +0200 Subject: [PATCH 287/831] talk-llama : sync llama.cpp --- examples/talk-llama/llama-arch.cpp | 22 + examples/talk-llama/llama-arch.h | 9 + examples/talk-llama/llama-batch.cpp | 4 +- examples/talk-llama/llama-context.cpp | 153 +++- examples/talk-llama/llama-cparams.h | 3 + examples/talk-llama/llama-ext.h | 12 + examples/talk-llama/llama-grammar.cpp | 14 +- examples/talk-llama/llama-graph.cpp | 80 +- examples/talk-llama/llama-graph.h | 17 +- examples/talk-llama/llama-hparams.cpp | 32 +- examples/talk-llama/llama-hparams.h | 20 +- examples/talk-llama/llama-impl.cpp | 4 +- examples/talk-llama/llama-impl.h | 4 +- examples/talk-llama/llama-kv-cache.cpp | 54 +- examples/talk-llama/llama-kv-cache.h | 3 +- examples/talk-llama/llama-model-loader.cpp | 642 ++++++++++++--- examples/talk-llama/llama-model-loader.h | 40 +- examples/talk-llama/llama-model-saver.cpp | 103 +-- examples/talk-llama/llama-model-saver.h | 7 +- examples/talk-llama/llama-model.cpp | 753 +++++++----------- examples/talk-llama/llama-model.h | 35 +- examples/talk-llama/llama-quant.cpp | 747 ++++++++++------- examples/talk-llama/llama-vocab.cpp | 4 +- examples/talk-llama/llama.cpp | 32 +- examples/talk-llama/llama.h | 27 +- examples/talk-llama/models/afmoe.cpp | 5 +- examples/talk-llama/models/apertus.cpp | 6 +- examples/talk-llama/models/arcee.cpp | 6 +- examples/talk-llama/models/arctic.cpp | 9 +- examples/talk-llama/models/baichuan.cpp | 7 +- examples/talk-llama/models/bailingmoe.cpp | 3 +- examples/talk-llama/models/bailingmoe2.cpp | 8 +- examples/talk-llama/models/bert.cpp | 20 +- examples/talk-llama/models/bitnet.cpp | 33 +- examples/talk-llama/models/bloom.cpp | 4 +- examples/talk-llama/models/chameleon.cpp | 6 +- examples/talk-llama/models/chatglm.cpp | 4 +- examples/talk-llama/models/codeshell.cpp | 6 +- examples/talk-llama/models/cogvlm.cpp | 6 +- examples/talk-llama/models/cohere2-iswa.cpp | 4 +- examples/talk-llama/models/command-r.cpp | 4 +- examples/talk-llama/models/dbrx.cpp | 9 +- examples/talk-llama/models/deci.cpp | 6 +- examples/talk-llama/models/deepseek.cpp | 10 +- examples/talk-llama/models/deepseek2.cpp | 6 +- examples/talk-llama/models/delta-net-base.cpp | 109 ++- examples/talk-llama/models/dots1.cpp | 10 +- examples/talk-llama/models/dream.cpp | 6 +- examples/talk-llama/models/ernie4-5-moe.cpp | 10 +- examples/talk-llama/models/ernie4-5.cpp | 6 +- examples/talk-llama/models/eurobert.cpp | 4 +- examples/talk-llama/models/exaone-moe.cpp | 9 +- examples/talk-llama/models/exaone.cpp | 6 +- examples/talk-llama/models/exaone4.cpp | 6 +- examples/talk-llama/models/falcon-h1.cpp | 2 +- examples/talk-llama/models/falcon.cpp | 6 +- .../talk-llama/models/gemma-embedding.cpp | 2 +- examples/talk-llama/models/gemma.cpp | 2 +- examples/talk-llama/models/gemma2-iswa.cpp | 2 +- examples/talk-llama/models/gemma3.cpp | 2 +- examples/talk-llama/models/gemma3n-iswa.cpp | 2 +- examples/talk-llama/models/glm4-moe.cpp | 6 +- examples/talk-llama/models/glm4.cpp | 4 +- examples/talk-llama/models/gpt2.cpp | 4 +- examples/talk-llama/models/gptneox.cpp | 4 +- examples/talk-llama/models/granite-hybrid.cpp | 7 +- examples/talk-llama/models/granite.cpp | 9 +- examples/talk-llama/models/grok.cpp | 8 +- examples/talk-llama/models/grovemoe.cpp | 12 +- examples/talk-llama/models/hunyuan-dense.cpp | 6 +- examples/talk-llama/models/hunyuan-moe.cpp | 9 +- examples/talk-llama/models/internlm2.cpp | 6 +- examples/talk-llama/models/jais.cpp | 4 +- examples/talk-llama/models/jais2.cpp | 6 +- examples/talk-llama/models/jamba.cpp | 4 +- examples/talk-llama/models/kimi-linear.cpp | 21 +- examples/talk-llama/models/lfm2.cpp | 16 +- examples/talk-llama/models/llada-moe.cpp | 8 +- examples/talk-llama/models/llada.cpp | 6 +- examples/talk-llama/models/llama-iswa.cpp | 8 +- examples/talk-llama/models/llama.cpp | 29 +- examples/talk-llama/models/maincoder.cpp | 6 +- examples/talk-llama/models/mamba-base.cpp | 4 + examples/talk-llama/models/mimo2-iswa.cpp | 16 +- examples/talk-llama/models/minicpm3.cpp | 26 +- examples/talk-llama/models/minimax-m2.cpp | 9 +- examples/talk-llama/models/mistral3.cpp | 8 +- examples/talk-llama/models/models.h | 22 +- examples/talk-llama/models/modern-bert.cpp | 4 +- examples/talk-llama/models/mpt.cpp | 4 +- examples/talk-llama/models/nemotron-h.cpp | 28 +- examples/talk-llama/models/nemotron.cpp | 6 +- examples/talk-llama/models/neo-bert.cpp | 4 +- examples/talk-llama/models/olmo.cpp | 6 +- examples/talk-llama/models/olmo2.cpp | 6 +- examples/talk-llama/models/olmoe.cpp | 8 +- .../talk-llama/models/openai-moe-iswa.cpp | 2 +- examples/talk-llama/models/openelm.cpp | 4 +- examples/talk-llama/models/orion.cpp | 6 +- examples/talk-llama/models/paddleocr.cpp | 6 +- examples/talk-llama/models/pangu-embedded.cpp | 6 +- examples/talk-llama/models/phi2.cpp | 4 +- examples/talk-llama/models/phi3.cpp | 6 +- examples/talk-llama/models/plamo.cpp | 6 +- examples/talk-llama/models/plamo2.cpp | 10 +- examples/talk-llama/models/plamo3.cpp | 4 +- examples/talk-llama/models/plm.cpp | 28 +- examples/talk-llama/models/qwen.cpp | 4 +- examples/talk-llama/models/qwen2.cpp | 6 +- examples/talk-llama/models/qwen2moe.cpp | 8 +- examples/talk-llama/models/qwen2vl.cpp | 6 +- examples/talk-llama/models/qwen3.cpp | 21 +- examples/talk-llama/models/qwen35.cpp | 45 +- examples/talk-llama/models/qwen35moe.cpp | 62 +- examples/talk-llama/models/qwen3moe.cpp | 23 +- examples/talk-llama/models/qwen3next.cpp | 26 +- examples/talk-llama/models/qwen3vl-moe.cpp | 8 +- examples/talk-llama/models/qwen3vl.cpp | 6 +- examples/talk-llama/models/refact.cpp | 4 +- examples/talk-llama/models/rnd1.cpp | 8 +- examples/talk-llama/models/seed-oss.cpp | 6 +- examples/talk-llama/models/smallthinker.cpp | 8 +- examples/talk-llama/models/smollm3.cpp | 6 +- examples/talk-llama/models/stablelm.cpp | 4 +- examples/talk-llama/models/starcoder.cpp | 4 +- examples/talk-llama/models/starcoder2.cpp | 6 +- examples/talk-llama/models/step35-iswa.cpp | 9 +- examples/talk-llama/models/t5-dec.cpp | 4 +- examples/talk-llama/models/t5-enc.cpp | 4 +- examples/talk-llama/models/xverse.cpp | 6 +- examples/talk-llama/unicode.cpp | 2 +- 131 files changed, 2362 insertions(+), 1507 deletions(-) create mode 100644 examples/talk-llama/llama-ext.h diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 47e8d5278ac..799d16167ba 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -4,6 +4,7 @@ #include #include +#include static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_CLIP, "clip" }, // dummy, only used by llama-quantize @@ -184,6 +185,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_GROUP_SCALE, "%s.expert_group_scale" }, { LLM_KV_EXPERTS_PER_GROUP, "%s.experts_per_group" }, { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" }, + { LLM_KV_MOE_LATENT_SIZE, "%s.moe_latent_size" }, { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" }, { LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, @@ -229,11 +231,14 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + { LLM_KV_ATTENTION_KEY_LENGTH_SWA, "%s.attention.key_length_swa" }, + { LLM_KV_ATTENTION_VALUE_LENGTH_SWA, "%s.attention.value_length_swa" }, { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" }, { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, @@ -361,6 +366,8 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + { LLM_TENSOR_FFN_LATENT_DOWN, "blk.%d.ffn_latent_down" }, + { LLM_TENSOR_FFN_LATENT_UP, "blk.%d.ffn_latent_up" }, { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, @@ -1083,6 +1090,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_OUTPUT, + LLM_TENSOR_CLS_OUT, LLM_TENSOR_ATTN_NORM, LLM_TENSOR_ATTN_Q, LLM_TENSOR_ATTN_Q_NORM, @@ -1874,6 +1882,8 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_UP_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_FFN_LATENT_DOWN, + LLM_TENSOR_FFN_LATENT_UP, // MoE shared expert layer LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, @@ -2749,6 +2759,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // Nemotron 3 Super + {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} @@ -2786,6 +2799,15 @@ std::string LLM_TN_IMPL::str() const { return name; } +std::vector llm_arch_all() { + std::vector ret; + ret.reserve(LLM_ARCH_NAMES.size()); + for (const auto & [arch, _] : LLM_ARCH_NAMES) { + ret.push_back(arch); + } + return ret; +} + const char * llm_arch_name(llm_arch arch) { auto it = LLM_ARCH_NAMES.find(arch); if (it == LLM_ARCH_NAMES.end()) { diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 6d1b1df31c0..b1b1dcf1883 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -4,6 +4,7 @@ #include #include +#include // // gguf constants (sync with gguf.py) @@ -188,6 +189,7 @@ enum llm_kv { LLM_KV_EXPERT_GROUP_SCALE, LLM_KV_EXPERTS_PER_GROUP, LLM_KV_MOE_EVERY_N_LAYERS, + LLM_KV_MOE_LATENT_SIZE, LLM_KV_NEXTN_PREDICT_LAYERS, LLM_KV_NUM_DEEPSTACK_LAYERS, LLM_KV_POOLING_TYPE, @@ -233,11 +235,14 @@ enum llm_kv { LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_KEY_LENGTH_SWA, + LLM_KV_ATTENTION_VALUE_LENGTH_SWA, LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, LLM_KV_ATTENTION_INDEXER_TOP_K, LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_DIMENSION_COUNT_SWA, LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, LLM_KV_ROPE_FREQ_BASE_SWA, @@ -381,6 +386,8 @@ enum llm_tensor { LLM_TENSOR_FFN_GATE_CHEXPS, LLM_TENSOR_FFN_UP_CHEXPS, LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_FFN_LATENT_DOWN, + LLM_TENSOR_FFN_LATENT_UP, LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, @@ -608,6 +615,8 @@ struct llm_tensor_info { ggml_op op; }; +std::vector llm_arch_all(); + const char * llm_arch_name(llm_arch arch); llm_arch llm_arch_from_string(const std::string & name); diff --git a/examples/talk-llama/llama-batch.cpp b/examples/talk-llama/llama-batch.cpp index 386fab04ac9..6bf76939cdd 100644 --- a/examples/talk-llama/llama-batch.cpp +++ b/examples/talk-llama/llama-batch.cpp @@ -394,11 +394,13 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t clear(); split_reset(); + const int64_t n_pos_all = (int64_t) n_tokens*n_pos_per_embd; + auto udata = std::make_shared(); udata->token .resize(n_tokens); udata->embd .clear(); - udata->pos .resize(n_tokens); + udata->pos .resize(n_pos_all); udata->n_seq_id .resize(n_tokens); udata->seq_id .resize(n_tokens); udata->seq_id_unq.resize(0); diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 98d055d34ef..1f7a52d7895 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -7,6 +7,7 @@ #include "llama-memory.h" #include "llama-mmap.h" #include "llama-model.h" +#include "llama-ext.h" #include #include @@ -150,6 +151,10 @@ llama_context::llama_context( cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; + cparams.fused_gdn_ar = true; + cparams.fused_gdn_ch = true; + cparams.auto_fgdn = true; + // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; @@ -158,7 +163,7 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; - // intialized later + // initialized later cparams.pipeline_parallel = false; { @@ -337,6 +342,14 @@ llama_context::llama_context( if (cparams.pipeline_parallel) { LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__); + + if (!graph_reuse_disable) { + // TODO: figure out a way to make graph reuse work with pipeline parallelism + // ref: https://github.com/ggml-org/llama.cpp/pull/20463 + LLAMA_LOG_WARN("%s: graph reuse is currently not compatible with pipeline parallelism - disabling\n", __func__); + + graph_reuse_disable = true; + } } sched_reserve(); @@ -422,7 +435,7 @@ void llama_context::sched_reserve() { if (cparams.auto_fa) { auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); if (!gf) { - throw std::runtime_error("failed to split graph for Flash Attention check"); + throw std::runtime_error("failed to reserve graph for Flash Attention check"); } const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; @@ -432,8 +445,7 @@ void llama_context::sched_reserve() { if (n->op != GGML_OP_FLASH_ATTN_EXT) { continue; } - ggml_backend_dev_t device_fa = ggml_backend_get_device( - ggml_backend_sched_get_tensor_backend(sched.get(), n)); + ggml_backend_dev_t device_fa = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); @@ -448,6 +460,7 @@ void llama_context::sched_reserve() { break; } } + if (fa_device_mismatch) { cparams.flash_attn = false; LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); @@ -459,6 +472,88 @@ void llama_context::sched_reserve() { cparams.auto_fa = false; } + if (cparams.auto_fgdn) { + LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net support:\n", __func__); + + if (cparams.fused_gdn_ar) { + auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1; + bool gdn_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_GATED_DELTA_NET) { + continue; + } + ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_gdn != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn)); + gdn_device_mismatch = true; + break; + } + } + + if (gdn_device_mismatch) { + cparams.fused_gdn_ar = false; + LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__); + } else { + LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__); + } + } + + if (cparams.fused_gdn_ch) { + // more than one token in the batch per sequence in order to take the chunked path + // note: n_outputs must match n_tokens for embedding models with mean/rank pooling, + // because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies + // it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens, + // the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553). + const uint32_t n_tokens_ch = 16*n_seqs; + auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1; + bool gdn_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_GATED_DELTA_NET) { + continue; + } + ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_gdn != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn)); + gdn_device_mismatch = true; + break; + } + } + + if (gdn_device_mismatch) { + cparams.fused_gdn_ch = false; + LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__); + } else { + LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__); + } + } + + cparams.auto_fgdn = false; + } + // reserve worst-case graph int n_splits_pp = -1; int n_nodes_pp = -1; @@ -1039,11 +1134,15 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); - if (n_adapters != loras->size()) { - return false; - } + // Adapters with a zero scale are never added to `loras`, so also ignore them for the comparison. + size_t n_non_zero = 0; for (size_t i = 0; i < n_adapters; i ++) { + if (scales[i] == 0.0f) { + continue; + } + n_non_zero++; + auto it = loras->find(adapters[i]); if (it == loras->end() || it->second != scales[i]) { @@ -1051,6 +1150,10 @@ bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_ } } + if (n_non_zero != loras->size()) { + return false; + } + return true; } @@ -1114,6 +1217,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll { //const auto t_start_us = ggml_time_us(); + // FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated res->set_inputs(&ubatch); //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); @@ -1981,7 +2085,7 @@ ggml_cgraph * llama_context::graph_reserve( ggml_backend_sched_reset(sched.get()); - // when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that + // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that gf_res_prev->reset(); // store the n_outputs as it is, and restore it afterwards @@ -2831,19 +2935,23 @@ llama_context * llama_init_from_model( if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { const uint32_t blck_size = ggml_blck_size(params.type_k); - if (model->hparams.n_embd_head_k % blck_size != 0) { - LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", - __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k); - return nullptr; + for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { + if (model->hparams.n_embd_head_k(il) % blck_size != 0) { + LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", + __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il)); + return nullptr; + } } } if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { const uint32_t blck_size = ggml_blck_size(params.type_v); - if (model->hparams.n_embd_head_v % blck_size != 0) { - LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n", - __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v); - return nullptr; + for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { + if (model->hparams.n_embd_head_v(il) % blck_size != 0) { + LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n", + __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il)); + return nullptr; + } } } @@ -3035,6 +3143,19 @@ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) { return static_cast(ctx->get_sampled_probs_count(i)); } +struct ggml_cgraph * llama_graph_reserve( + struct llama_context * ctx, + uint32_t n_tokens, + uint32_t n_seqs, + uint32_t n_outputs) { + auto * memory = ctx->get_memory(); + llama_memory_context_ptr mctx; + if (memory) { + mctx = memory->init_full(); + } + return ctx->graph_reserve(n_tokens, n_seqs, n_outputs, mctx.get()); +} + // llama adapter API int32_t llama_set_adapters_lora( diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 2da3bbd6f94..9d359474132 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -31,6 +31,9 @@ struct llama_cparams { bool offload_kqv; bool flash_attn; bool auto_fa; + bool fused_gdn_ar; // use fused gated delta net (autoregressive) + bool fused_gdn_ch; // use fused gated delta net (chunked) + bool auto_fgdn; bool no_perf; bool warmup; bool op_offload; diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h new file mode 100644 index 00000000000..13ced783b42 --- /dev/null +++ b/examples/talk-llama/llama-ext.h @@ -0,0 +1,12 @@ +#pragma once + +#include "llama-context.h" +#include "ggml.h" +#include "stdint.h" + +// Reserve a new compute graph. It is valid until the next call to llama_graph_reserve. +LLAMA_API struct ggml_cgraph * llama_graph_reserve( + struct llama_context * ctx, + uint32_t n_tokens, + uint32_t n_seqs, + uint32_t n_outputs); diff --git a/examples/talk-llama/llama-grammar.cpp b/examples/talk-llama/llama-grammar.cpp index 2d55070cecc..aac0d41f2b4 100644 --- a/examples/talk-llama/llama-grammar.cpp +++ b/examples/talk-llama/llama-grammar.cpp @@ -601,7 +601,7 @@ const char * llama_grammar_parser::parse_sequence( throw std::runtime_error(std::string("expecting an int at ") + pos); } const char * int_end = parse_int(pos); - uint64_t min_times = std::stoul(std::string(pos, int_end - pos)); + uint64_t min_times = std::stoull(std::string(pos, int_end - pos)); pos = parse_space(int_end, is_nested); uint64_t max_times = UINT64_MAX; // default: no max limit @@ -614,7 +614,7 @@ const char * llama_grammar_parser::parse_sequence( if (is_digit_char(*pos)) { const char * int_end = parse_int(pos); - max_times = std::stoul(std::string(pos, int_end - pos)); + max_times = std::stoull(std::string(pos, int_end - pos)); pos = parse_space(int_end, is_nested); } @@ -1160,13 +1160,13 @@ struct llama_grammar * llama_grammar_init_impl( // if there is a grammar, parse it // rules will be empty (default) if there are parse errors if (!parser.parse(grammar_str) || parser.rules.empty()) { - fprintf(stderr, "%s: failed to parse grammar\n", __func__); + LLAMA_LOG_ERROR("failed to parse grammar\n"); return nullptr; } - // Ensure that there is a "root" node. - if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) { - fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); + // Ensure that the grammar contains the start symbol + if (parser.symbol_ids.find(grammar_root) == parser.symbol_ids.end()) { + LLAMA_LOG_ERROR("grammar does not contain a '%s' symbol\n", grammar_root); return nullptr; } @@ -1195,7 +1195,7 @@ struct llama_grammar * llama_grammar_init_impl( continue; } if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { - LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i); + LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu\n", i); return nullptr; } } diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 23a86ea2905..9a215bb77a0 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -250,7 +250,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { const bool last = ( cparams.pooling_type == LLAMA_POOLING_TYPE_LAST || - (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token + (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token ); for (int i = 0; i < n_tokens; ++i) { @@ -509,6 +509,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { float * data = (float *) cross_kq_mask->data; for (int i = 0; i < n_tokens; ++i) { + GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first"); for (int j = 0; j < n_enc; ++j) { float f = -INFINITY; @@ -848,13 +849,13 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : ubatch (params.ubatch), n_embd (hparams.n_embd), n_layer (hparams.n_layer), - n_rot (hparams.n_rot), + n_rot (hparams.n_rot()), n_ctx (cparams.n_ctx), n_head (hparams.n_head()), n_head_kv (hparams.n_head_kv()), - n_embd_head_k (hparams.n_embd_head_k), + n_embd_head_k (hparams.n_embd_head_k()), n_embd_k_gqa (hparams.n_embd_k_gqa()), - n_embd_head_v (hparams.n_embd_head_v), + n_embd_head_v (hparams.n_embd_head_v()), n_embd_v_gqa (hparams.n_embd_v_gqa()), n_expert (hparams.n_expert), n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used), @@ -899,7 +900,8 @@ ggml_tensor * llm_graph_context::build_cvec( ggml_tensor * llm_graph_context::build_lora_mm( ggml_tensor * w, - ggml_tensor * cur) const { + ggml_tensor * cur, + ggml_tensor * w_s) const { ggml_tensor * res = ggml_mul_mat(ctx0, w, cur); for (const auto & lora : *loras) { @@ -920,6 +922,10 @@ ggml_tensor * llm_graph_context::build_lora_mm( res = ggml_add(ctx0, res, ab_cur); } + if (w_s) { + res = ggml_mul(ctx0, res, w_s); + } + return res; } @@ -1161,12 +1167,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn( int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, - bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, int il, ggml_tensor * probs_in, - ggml_tensor * gate_up_exps) const { + ggml_tensor * gate_up_exps, + ggml_tensor * up_exps_s, + ggml_tensor * gate_exps_s, + ggml_tensor * down_exps_s) const { return build_moe_ffn( cur, gate_inp, /* gate_inp_b */ nullptr, @@ -1178,12 +1186,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn( n_expert_used, type_op, norm_w, - scale_w, w_scale, gating_op, il, probs_in, - gate_up_exps + gate_up_exps, + /* gate_up_exps_b */ nullptr, + up_exps_s, + gate_exps_s, + down_exps_s ); } @@ -1202,13 +1213,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn( int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, - bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, int il, ggml_tensor * probs_in, ggml_tensor * gate_up_exps, - ggml_tensor * gate_up_exps_b) const { + ggml_tensor * gate_up_exps_b, + ggml_tensor * up_exps_s, + ggml_tensor * gate_exps_s, + ggml_tensor * down_exps_s) const { const int64_t n_embd = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN @@ -1330,7 +1343,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); } - if (scale_w) { + if (w_scale != 0.0f && w_scale != 1.0f) { weights = ggml_scale(ctx0, weights, w_scale); cb(weights, "ffn_moe_weights_scaled", il); } @@ -1360,6 +1373,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(gate_up, "ffn_moe_gate_up_biased", il); } + // apply per-expert scale2 to merged gate_up (use up_exps_s since gate and up are fused) + if (up_exps_s) { + ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1); + s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1); + s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens] + gate_up = ggml_mul(ctx0, gate_up, s); + cb(gate_up, "ffn_moe_gate_up_scaled", il); + } + const int64_t n_ff = gate_up->ne[0] / 2; cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0); cb(cur, "ffn_moe_gate", il); @@ -1375,6 +1397,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(up, "ffn_moe_up_biased", il); } + // apply per-expert scale2 to up + if (up_exps_s) { + ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1); + s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1); + s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens] + up = ggml_mul(ctx0, up, s); + cb(up, "ffn_moe_up_scaled", il); + } + if (gate_exps) { cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(cur, "ffn_moe_gate", il); @@ -1386,6 +1417,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts); cb(cur, "ffn_moe_gate_biased", il); } + + // apply per-expert scale2 to gate + if (gate_exps_s) { + ggml_tensor * s = ggml_reshape_3d(ctx0, gate_exps_s, 1, n_expert, 1); + s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1); + s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens] + cur = ggml_mul(ctx0, cur, s); + cb(cur, "ffn_moe_gate_scaled", il); + } } const bool has_gate = gate_exps || gate_up_exps; @@ -1465,6 +1505,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(experts, "ffn_moe_down_biased", il); } + // apply per-expert scale2 to down + if (down_exps_s) { + ggml_tensor * s = ggml_reshape_3d(ctx0, down_exps_s, 1, n_expert, 1); + s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1); + s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens] + experts = ggml_mul(ctx0, experts, s); + cb(experts, "ffn_moe_down_scaled", il); + } + if (!weight_before_ffn) { experts = ggml_mul(ctx0, experts, weights); cb(cur, "ffn_moe_weighted", il); @@ -1607,6 +1656,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const { // this need to be 1x1xN for broadcasting cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens); ggml_set_input(cur); + ggml_set_name(cur, "attn_scale"); res->add_input(std::move(inp)); @@ -1616,7 +1666,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const { ggml_tensor * llm_graph_context::build_inp_out_ids() const { // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls, // but this would make the graph topology depend on the number of output tokens, which can interere with - // features that require constant topology such as pipline parallelism + // features that require constant topology such as pipeline parallelism // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471 //if (n_outputs < n_tokens) { // return nullptr; @@ -1779,7 +1829,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( if (v_mla) { #if 0 // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens. - // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient. + // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient. cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); cur = ggml_mul_mat(ctx0, v_mla, cur); #else @@ -2553,7 +2603,7 @@ void llm_graph_context::build_pooling( } // softmax for qwen3 reranker - if (arch == LLM_ARCH_QWEN3) { + if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) { cur = ggml_soft_max(ctx0, cur); } } break; diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index e8f006977d2..4855685ef71 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -764,10 +764,11 @@ struct llm_graph_context { ggml_tensor * cur, int il) const; - // do mat_mul, while optionally apply lora + // do mat_mul, while optionally apply lora and per-tensor scale ggml_tensor * build_lora_mm( ggml_tensor * w, - ggml_tensor * cur) const; + ggml_tensor * cur, + ggml_tensor * w_s = nullptr) const; // do mat_mul_id, while optionally apply lora ggml_tensor * build_lora_mm_id( @@ -810,12 +811,14 @@ struct llm_graph_context { int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, - bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, int il, ggml_tensor * probs_in = nullptr, - ggml_tensor * gate_up_exps = nullptr) const; + ggml_tensor * gate_up_exps = nullptr, + ggml_tensor * up_exps_s = nullptr, + ggml_tensor * gate_exps_s = nullptr, + ggml_tensor * down_exps_s = nullptr) const; ggml_tensor * build_moe_ffn( ggml_tensor * cur, @@ -832,13 +835,15 @@ struct llm_graph_context { int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, - bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, int il, ggml_tensor * probs_in = nullptr, ggml_tensor * gate_up_exps = nullptr, - ggml_tensor * gate_up_exps_b = nullptr) const; + ggml_tensor * gate_up_exps_b = nullptr, + ggml_tensor * up_exps_s = nullptr, + ggml_tensor * gate_exps_s = nullptr, + ggml_tensor * down_exps_s = nullptr) const; // // inputs diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 756dda1a7ab..002d15d415f 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -62,6 +62,14 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const { return n_head/n_head_kv; } +uint32_t llama_hparams::n_rot(uint32_t il) const { + if (il < n_layer) { + return is_swa(il) ? n_rot_swa : n_rot_full; + } + + GGML_ABORT("fatal error"); +} + uint32_t llama_hparams::n_embd_inp() const { uint32_t n_embd_inp = n_embd; @@ -76,16 +84,32 @@ uint32_t llama_hparams::n_embd_out() const { return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd; } +uint32_t llama_hparams::n_embd_head_k(uint32_t il) const { + if (il < n_layer) { + return is_swa(il) ? n_embd_head_k_swa : n_embd_head_k_full; + } + + GGML_ABORT("fatal error"); +} + +uint32_t llama_hparams::n_embd_head_v(uint32_t il) const { + if (il < n_layer) { + return is_swa(il) ? n_embd_head_v_swa : n_embd_head_v_full; + } + + GGML_ABORT("fatal error"); +} + uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const { const uint32_t n_head_kv = this->n_head_kv(il); - return n_embd_head_k * n_head_kv; + return n_embd_head_k(il) * n_head_kv; } uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { const uint32_t n_head_kv = this->n_head_kv(il); - return n_embd_head_v * n_head_kv; + return n_embd_head_v(il) * n_head_kv; } bool llama_hparams::is_n_embd_k_gqa_variable() const { @@ -197,11 +221,11 @@ bool llama_hparams::is_mla() const { } uint32_t llama_hparams::n_embd_head_k_mla() const { - return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k; + return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k(); } uint32_t llama_hparams::n_embd_head_v_mla() const { - return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v; + return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v(); } bool llama_hparams::has_kv(uint32_t il) const { diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index c4b2a99da5a..78c0bc27d4d 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -44,13 +44,20 @@ struct llama_hparams { uint32_t n_embd; uint32_t n_layer; int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache - uint32_t n_rot; - uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads - uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_expert = 0; uint32_t n_expert_used = 0; uint32_t n_rel_attn_bkts = 0; + // different head size for full_attention and SWA layers + uint32_t n_embd_head_k_full; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads + uint32_t n_embd_head_v_full; // dimension of values (d_v) aka n_embd_head + uint32_t n_embd_head_k_swa; + uint32_t n_embd_head_v_swa; + + // different RoPE dimensions for full_attention and SWA layers + uint32_t n_rot_full; + uint32_t n_rot_swa; + // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA uint32_t n_embd_head_k_mla_impl = 0; uint32_t n_embd_head_v_mla_impl = 0; @@ -82,6 +89,7 @@ struct llama_hparams { bool expert_weights_norm = false; uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; uint32_t moe_every_n_layers = 0; + uint32_t moe_latent_size = 0; uint32_t nextn_predict_layers = 0; float f_norm_eps; @@ -247,12 +255,18 @@ struct llama_hparams { uint32_t n_gqa(uint32_t il = 0) const; + uint32_t n_rot(uint32_t il = 0) const; + // dimension of main + auxiliary input embeddings uint32_t n_embd_inp() const; // dimension of output embeddings uint32_t n_embd_out() const; + // dimension of key/value embeddings for each head (per layer) + uint32_t n_embd_head_k(uint32_t il = 0) const; + uint32_t n_embd_head_v(uint32_t il = 0) const; + // dimension of key embeddings across all k-v heads uint32_t n_embd_k_gqa(uint32_t il = 0) const; diff --git a/examples/talk-llama/llama-impl.cpp b/examples/talk-llama/llama-impl.cpp index 710a5a1e08d..4c0188ee722 100644 --- a/examples/talk-llama/llama-impl.cpp +++ b/examples/talk-llama/llama-impl.cpp @@ -100,9 +100,9 @@ std::string format(const char * fmt, ...) { std::string llama_format_tensor_shape(const std::vector & ne) { char buf[256]; - snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0)); + snprintf(buf, sizeof(buf), "%6" PRId64, ne.at(0)); for (size_t i = 1; i < ne.size(); i++) { - snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i)); + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %6" PRId64, ne.at(i)); } return buf; } diff --git a/examples/talk-llama/llama-impl.h b/examples/talk-llama/llama-impl.h index dfd9fee9f44..e4f35c8e53d 100644 --- a/examples/talk-llama/llama-impl.h +++ b/examples/talk-llama/llama-impl.h @@ -70,4 +70,6 @@ std::string llama_format_tensor_shape(const struct ggml_tensor * t); std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i); -#define LLAMA_TENSOR_NAME_FATTN "__fattn__" +#define LLAMA_TENSOR_NAME_FATTN "__fattn__" +#define LLAMA_TENSOR_NAME_FGDN_AR "__fgdn_ar__" +#define LLAMA_TENSOR_NAME_FGDN_CH "__fgdn_ch__" diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 6b668ee9abd..01166fac9ce 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -583,7 +583,7 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vectortype, hparams.n_embd_head_k), + hparams.n_embd_head_k(il), hparams.n_head_kv(il), n_kv, ns, + ggml_row_size(k->type, hparams.n_embd_head_k(il)), ggml_row_size(k->type, n_embd_k_gqa), ggml_row_size(k->type, n_embd_k_gqa*kv_size), ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0); @@ -1056,8 +1056,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k if (!v_trans) { // note: v->nb[1] <= v->nb[2] return ggml_view_4d(ctx, v, - hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns, - ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + hparams.n_embd_head_v(il), hparams.n_head_kv(il), n_kv, ns, + ggml_row_size(v->type, hparams.n_embd_head_v(il)), // v->nb[1] ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2] ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0); @@ -1065,8 +1065,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k // note: v->nb[1] > v->nb[2] return ggml_view_4d(ctx, v, - n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns, - ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1] + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v(il), ns, + ggml_row_size(v->type, kv_size*hparams.n_embd_head_v(il)), // v->nb[1] ggml_row_size(v->type, kv_size), // v->nb[2] ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0); @@ -1293,7 +1293,7 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * } for (uint32_t s = 0; s < n_stream; ++s) { - // bookeeping of the KQ mask cells that could change for other tokens of the same sequence + // bookkeeping of the KQ mask cells that could change for other tokens of the same sequence std::unordered_map seq_srct; std::unordered_map> seq_idxs; @@ -1544,7 +1544,8 @@ ggml_tensor * llama_kv_cache::build_rope_shift( ggml_tensor * shift, ggml_tensor * factors, float freq_base, - float freq_scale) const { + float freq_scale, + uint32_t il) const { const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; const auto & yarn_ext_factor = cparams.yarn_ext_factor; @@ -1552,7 +1553,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift( const auto & yarn_beta_slow = cparams.yarn_beta_slow; const auto & yarn_attn_factor = cparams.yarn_attn_factor; - const auto & n_rot = hparams.n_rot; + const auto & n_rot = hparams.n_rot(il); const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE // @ngxson : this is a workaround // for M-RoPE, we want to rotate the whole vector when doing KV shift @@ -1606,13 +1607,6 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co auto * ctx = res->get_ctx(); auto * gf = res->get_gf(); - const auto & n_embd_head_k = hparams.n_embd_head_k; - //const auto & n_embd_head_v = hparams.n_embd_head_v; - - const auto & n_rot = hparams.n_rot; - - const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0; - auto inp = std::make_unique(this); inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); @@ -1626,6 +1620,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co const int64_t n_head_kv = hparams.n_head_kv(il); const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const auto n_rot = hparams.n_rot(il); + const auto n_embd_head_k = hparams.n_embd_head_k(il); + const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0; + const float freq_base_l = model.get_rope_freq_base (cparams, il); const float freq_scale_l = model.get_rope_freq_scale(cparams, il); @@ -1638,7 +1636,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co ggml_row_size(layer.k->type, n_embd_k_gqa), ggml_row_size(layer.k->type, n_embd_nope)); - ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il); ggml_build_forward_expand(gf, cur); } @@ -1760,8 +1758,10 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t io.write(&pos, sizeof(pos)); io.write(&n_seq_id, sizeof(n_seq_id)); - // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it - // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350 + if (hparams.n_pos_per_embd() > 1) { + const llama_kv_cell_ext ext = cells.ext_get(i); + io.write(&ext, sizeof(ext)); + } for (const auto & seq_id : seq_ids) { io.write(&seq_id, sizeof(seq_id)); @@ -1895,6 +1895,14 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 return false; } + if (hparams.n_pos_per_embd() > 1) { + llama_kv_cell_ext ext; + io.read_to(&ext, sizeof(ext)); + + ubatch.pos[i + ubatch.n_tokens] = ext.y; + ubatch.pos[i + ubatch.n_tokens*2] = ext.x; + } + // read the sequence id, but directly discard it - we will use dest_seq_id instead { llama_seq_id seq_id; @@ -1945,6 +1953,12 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 cells.pos_set(i, pos); + if (hparams.n_pos_per_embd() > 1) { + llama_kv_cell_ext ext; + io.read_to(&ext, sizeof(ext)); + cells.ext_set(i, ext); + } + for (uint32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id; io.read_to(&seq_id, sizeof(seq_id)); diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index e194bf3e26f..33c78c5f210 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -264,7 +264,8 @@ class llama_kv_cache : public llama_memory_i { ggml_tensor * shift, ggml_tensor * factors, float freq_base, - float freq_scale) const; + float freq_scale, + uint32_t il) const; ggml_cgraph * build_graph_shift( llm_graph_result * res, diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 1501e392ca8..413f34c2268 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -1,12 +1,17 @@ #include "llama-model-loader.h" +#include "ggml-alloc.h" #include "ggml.h" +#include "gguf.h" +#include "llama-hparams.h" #include #include #include +#include #include #include +#include static const size_t kiB = 1024; static const size_t MiB = 1024*kiB; @@ -37,6 +42,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return "MXFP4 MoE"; + case LLAMA_FTYPE_MOSTLY_NVFP4: return "NVFP4"; case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; @@ -263,7 +269,7 @@ namespace GGUFMeta { template typename std::enable_if::value, bool>::type llama_model_loader::get_arr_n(const std::string & key, T & result, bool required) { - const int kid = gguf_find_key(meta.get(), key.c_str()); + const int kid = gguf_find_key(metadata, key.c_str()); if (kid < 0) { if (required) { @@ -273,7 +279,7 @@ namespace GGUFMeta { } struct GGUFMeta::ArrayInfo arr_info = - GGUFMeta::GKV::get_kv(meta.get(), kid); + GGUFMeta::GKV::get_kv(metadata, kid); result = arr_info.length; @@ -290,7 +296,7 @@ namespace GGUFMeta { template bool llama_model_loader::get_arr(const std::string & key, std::vector & result, bool required) { - const gguf_context * ctx = meta.get(); + const gguf_context * ctx = metadata; const int kid = gguf_find_key(ctx, key.c_str()); if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) { @@ -331,7 +337,7 @@ namespace GGUFMeta { template bool llama_model_loader::get_arr(const std::string & key, std::array & result, bool required) { - const gguf_context * ctx = meta.get(); + const gguf_context * ctx = metadata; const int kid = gguf_find_key(ctx, key.c_str()); if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) { @@ -393,7 +399,7 @@ namespace GGUFMeta { const struct llama_model_kv_override * override = it != kv_overrides.end() ? &it->second : nullptr; - const bool found = GGUFMeta::GKV::set(meta.get(), key, result, override); + const bool found = GGUFMeta::GKV::set(metadata, key, result, override); if (required && !found) { throw std::runtime_error(format("key not found in model: %s", key.c_str())); @@ -427,7 +433,7 @@ namespace GGUFMeta { // get array of n <= N_MAX elements, or a single element repeated n times template bool llama_model_loader::get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required) { - const int kid = gguf_find_key(meta.get(), key.c_str()); + const int kid = gguf_find_key(metadata, key.c_str()); if (kid < 0) { if (required) { @@ -440,9 +446,9 @@ namespace GGUFMeta { throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str())); } - if (gguf_get_kv_type(meta.get(), kid) == GGUF_TYPE_ARRAY) { + if (gguf_get_kv_type(metadata, kid) == GGUF_TYPE_ARRAY) { struct GGUFMeta::ArrayInfo arr_info = - GGUFMeta::GKV::get_kv(meta.get(), kid); + GGUFMeta::GKV::get_kv(metadata, kid); if (n != arr_info.length) { throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length)); @@ -473,7 +479,7 @@ namespace GGUFMeta { bool llama_model_loader::get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required) { const std::string key = llm_kv(kid); - const int id = gguf_find_key(meta.get(), key.c_str()); + const int id = gguf_find_key(metadata, key.c_str()); if (id < 0) { if (required) { @@ -483,7 +489,7 @@ namespace GGUFMeta { } // throw and error if type is an array - if (gguf_get_kv_type(meta.get(), id) == GGUF_TYPE_ARRAY) { + if (gguf_get_kv_type(metadata, id) == GGUF_TYPE_ARRAY) { if (required) { throw std::runtime_error(format("expected scalar, found array for key: %s", key.c_str())); } @@ -500,6 +506,9 @@ namespace GGUFMeta { llama_model_loader::llama_model_loader( + struct gguf_context * meta, + llama_model_set_tensor_data_t set_tensor_data, + void * set_tensor_data_ud, const std::string & fname, std::vector & splits, bool use_mmap, @@ -507,7 +516,8 @@ llama_model_loader::llama_model_loader( bool check_tensors, bool no_alloc, const llama_model_kv_override * param_overrides_p, - const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) { + const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) + : metadata(meta), set_tensor_data(set_tensor_data), set_tensor_data_ud(set_tensor_data_ud) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -521,136 +531,142 @@ llama_model_loader::llama_model_loader( tensor_buft_overrides = param_tensor_buft_overrides_p; - // Load the main GGUF - struct ggml_context * ctx = NULL; - struct gguf_init_params params = { - /*.no_alloc = */ true, - /*.ctx = */ &ctx, - }; - - meta.reset(gguf_init_from_file(fname.c_str(), params)); - if (!meta) { - throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str())); - } - - get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); - llm_kv = LLM_KV(llm_arch_from_string(arch_name)); - - files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io)); - contexts.emplace_back(ctx); + if (!fname.empty()) { + // Load the main GGUF + struct ggml_context * ctx = NULL; + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; - if (use_mmap && use_direct_io) { - if (files.back()->has_direct_io()) { - LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); - use_mmap = false; - } else { - LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__); - use_direct_io = false; - - // reopen file using std::fopen for mmap - files.pop_back(); - files.emplace_back(new llama_file(fname.c_str(), "rb", false)); + metadata_ptr.reset(gguf_init_from_file(fname.c_str(), params)); + metadata = metadata_ptr.get(); + if (metadata == nullptr) { + throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str())); } - } - // Save tensors data offset of the main file. - // For subsidiary files, `meta` tensor data offset must not be used, - // so we build a unified tensors index for weights. - for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { - std::string tensor_name = std::string(cur->name); - // make sure there is no duplicated tensor names - if (weights_map.find(tensor_name) != weights_map.end()) { - throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); - } - n_elements += ggml_nelements(cur); - n_bytes += ggml_nbytes(cur); - weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, meta.get(), cur)); - } - uint16_t n_split = 0; - get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false); + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); - // Load additional GGML contexts - if (n_split > 1) { - // make sure the main file is loaded first - uint16_t idx = 0; - const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO); - get_key(kv_split_no, idx); - if (idx != 0) { - throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str())); - } + files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io)); + contexts.emplace_back(ctx); - // generate list of splits if needed - if (splits.empty()) { - splits = llama_get_list_splits(fname, idx, n_split); - } + if (use_mmap && use_direct_io) { + if (files.back()->has_direct_io()) { + LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); + use_mmap = false; + } else { + LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__); + use_direct_io = false; - // in case user give a custom list of splits, check if it matches the expected number - if (n_split != (uint16_t)splits.size()) { - throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split)); + // reopen file using std::fopen for mmap + files.pop_back(); + files.emplace_back(new llama_file(fname.c_str(), "rb", false)); + } } - if (trace > 0) { - LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); - } + // Save tensors data offset of the main file. + // For subsidiary files, `meta` tensor data offset must not be used, + // so we build a unified tensors index for weights. + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, metadata, cur)); + } + uint16_t n_split = 0; + get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false); + + // Load additional GGML contexts + if (n_split > 1) { + // make sure the main file is loaded first + uint16_t idx = 0; + const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO); + get_key(kv_split_no, idx); + if (idx != 0) { + throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str())); + } - // load other splits - for (idx = 1; idx < n_split; idx++) { - const char * fname_split = splits[idx].c_str(); + // generate list of splits if needed + if (splits.empty()) { + splits = llama_get_list_splits(fname, idx, n_split); + } - struct gguf_init_params split_params = { - /*.no_alloc = */ true, - /*.ctx = */ &ctx, - }; - gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) }; - if (!ctx_gguf) { - throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split)); + // in case user give a custom list of splits, check if it matches the expected number + if (n_split != (uint16_t)splits.size()) { + throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split)); } - // check idx - { - const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str()); - if (kid < 0) { - throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split)); + if (trace > 0) { + LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); + } + + // load other splits + for (idx = 1; idx < n_split; idx++) { + const char * fname_split = splits[idx].c_str(); + + struct gguf_init_params split_params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) }; + if (!ctx_gguf) { + throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split)); } - int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid); - if (idx_gguf != idx) { - throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx)); + + // check idx + { + const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str()); + if (kid < 0) { + throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split)); + } + int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid); + if (idx_gguf != idx) { + throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx)); + } } - } - files.emplace_back(new llama_file(fname_split, "rb", use_direct_io)); - contexts.emplace_back(ctx); + files.emplace_back(new llama_file(fname_split, "rb", use_direct_io)); + contexts.emplace_back(ctx); - // Save tensors data offset info of the shard. - for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { - std::string tensor_name = std::string(cur->name); - // make sure there is no duplicated tensor names - if (weights_map.find(tensor_name) != weights_map.end()) { - throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + // Save tensors data offset info of the shard. + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur)); } - n_elements += ggml_nelements(cur); - n_bytes += ggml_nbytes(cur); - weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur)); } - } - get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors); + get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors); - // sanity check - { - const int n_tensors_loaded = (int) weights_map.size(); - if (n_tensors != n_tensors_loaded) { - throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded)); + // sanity check + { + const int n_tensors_loaded = (int) weights_map.size(); + if (n_tensors != n_tensors_loaded) { + throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded)); + } } - } - LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1); + LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1); + } + } else { + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); } - n_kv = gguf_get_n_kv(meta.get()); + n_kv = gguf_get_n_kv(metadata); n_tensors = weights_map.size(); - fver = (enum llama_fver) gguf_get_version(meta.get()); + fver = (enum llama_fver) gguf_get_version(metadata); LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); @@ -709,6 +725,7 @@ llama_model_loader::llama_model_loader( case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; + case GGML_TYPE_NVFP4: ftype = LLAMA_FTYPE_MOSTLY_NVFP4; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); @@ -729,14 +746,14 @@ llama_model_loader::llama_model_loader( LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__); for (int i = 0; i < n_kv; i++) { - const char * name = gguf_get_key(meta.get(), i); - const enum gguf_type type = gguf_get_kv_type(meta.get(), i); + const char * name = gguf_get_key(metadata, i); + const enum gguf_type type = gguf_get_kv_type(metadata, i); const std::string type_name = type == GGUF_TYPE_ARRAY - ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i)) + ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(metadata, i)), gguf_get_arr_n(metadata, i)) : gguf_type_name(type); - std::string value = gguf_kv_to_str(meta.get(), i); + std::string value = gguf_kv_to_str(metadata, i); const size_t MAX_VALUE_LEN = 40; if (value.size() > MAX_VALUE_LEN) { value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()); @@ -838,15 +855,382 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri return cur; } -struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags) { - LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, name.c_str()); - const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); +// checks if the weight tensor can be used with the specified buffer type and device +static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + GGML_ASSERT(w != nullptr); + + if (op == GGML_OP_NONE) { + return true; + } + + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error(format("failed to create ggml context")); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + switch (op) { + case GGML_OP_GET_ROWS: + { + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_get_rows(ctx, w, b); + } break; + case GGML_OP_MUL_MAT: + { + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } break; + case GGML_OP_MUL_MAT_ID: + { + const int n_expert_used = hparams.n_expert_used; + GGML_ASSERT(n_expert_used > 0); + ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); + ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); + op_tensor = ggml_mul_mat_id(ctx, w, b, ids); + } break; + case GGML_OP_ADD: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_add(ctx, a, w); + } break; + case GGML_OP_ADD_ID: + { + const int n_expert_used = hparams.n_expert_used; + GGML_ASSERT(n_expert_used > 0); + ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); + ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); + op_tensor = ggml_add_id(ctx, a, w, c); + } break; + case GGML_OP_MUL: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_mul(ctx, a, w); + } break; + case GGML_OP_DIV: + { + ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]); + op_tensor = ggml_div(ctx, a, w); + } break; + case GGML_OP_ROPE: + { + const int n_embd_head = hparams.n_embd_head_v(); + const int n_head = hparams.n_head(); + ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512); + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_rope_ext( + ctx, a, b, w, + 0, 0, 0, 0, 0, + 0, 0, 0, 0 + ); + + } break; + case GGML_OP_SSM_CONV: + { + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 3; + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs); + op_tensor = ggml_ssm_conv(ctx, conv_x, w); + } break; + case GGML_OP_SSM_SCAN: + { + // w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2 + const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; + const int64_t n_head = w->ne[1]; + const int64_t head_dim = hparams.ssm_d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group ? hparams.ssm_n_group : 1; + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 3; + ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids); + } break; + case GGML_OP_RWKV_WKV6: + { + // FIXME + const int64_t S = 123; + const int64_t H = 123; + const int64_t n_tokens = 123; + const int64_t n_seqs = 123; + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * tf = w; + ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); + op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); + } break; + case GGML_OP_IM2COL: + { + const int n_embd_inp = hparams.n_embd_inp(); + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_inp, w->ne[1], 1, 1); + op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); + } break; + case GGML_OP_SCALE: + { + op_tensor = ggml_scale(ctx, w, 1.0f); + } break; + default: + GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); + } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + + return op_supported; +} + +// find the first buffer type in the list that can use the tensor +static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hparams, ggml_tensor * tensor, ggml_op op, const buft_list_t * buft_list) { + GGML_ASSERT(!buft_list->empty()); + for (const auto & cur : *buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (weight_buft_supported(hparams, tensor, op, cur_buft, cur_dev)) { + return cur_buft; + } + } + + return nullptr; +} + +struct ggml_tensor * llama_model_loader::create_tensor( + const llama_hparams & hparams, const buft_list_t * buft_list_cpu, const buft_list_t * buft_list_input, const buft_list_t * buft_list_output, + const buft_list_t * buft_list_layer, const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) { + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + // one ggml context per buffer type + int max_n_tensors = n_tensors; + max_n_tensors += 1; // duplicated output tensor + max_n_tensors += hparams.n_layer*2; // duplicated rope freq tensors + if (files.empty()) { + max_n_tensors += hparams.n_layer*256; // this should be well above what any model actually uses + } + const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; + + ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ctx_map.emplace(buft, ctx); + + return ctx; + } + return it->second.get(); + }; + + auto buft_for_tensor = [&](ggml_tensor * t_meta) -> ggml_backend_buffer_type_t { + if (!t_meta) { + if (flags & TENSOR_NOT_REQUIRED) { + return nullptr; + } + throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); + } + + // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops + // the tensor is duplicated + // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor + llm_tensor tn_tensor = tn.tensor; + if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && (flags & TENSOR_DUPLICATED)) { + tn_tensor = LLM_TENSOR_OUTPUT; + } + + llm_tensor_info info; + try { + info = llm_tensor_info_for(tn_tensor); + } catch (const std::out_of_range & e) { + throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str())); + } + + // skip unused tensors + if (info.op == GGML_OP_NONE || (flags & TENSOR_SKIP)) { + const size_t nbytes = ggml_nbytes(t_meta); + LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes); + + size_data -= nbytes; + n_created++; + + return nullptr; + } + + // tensors with "bias" suffix are always used with GGML_OP_ADD or GGML_OP_ADD_ID + ggml_op op; + bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0; + if (bias) { + if (info.op == GGML_OP_MUL_MAT_ID) { + op = GGML_OP_ADD_ID; + } else { + op = GGML_OP_ADD; + } + } else { + op = info.op; + } + + // sanity checks + if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { + if (tn.bid != -1) { + GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str()); + } + } else { + if (tn.bid == -1) { + GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str()); + } + } + + // select the buffer type for this tensor + const buft_list_t * buft_list; + switch (info.layer) { + case LLM_TENSOR_LAYER_INPUT: + buft_list = buft_list_input; + break; + case LLM_TENSOR_LAYER_OUTPUT: + buft_list = buft_list_output; + break; + case LLM_TENSOR_LAYER_REPEATING: + GGML_ASSERT(buft_list_layer != nullptr); + buft_list = buft_list_layer; + break; + default: + GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); + } + + ggml_backend_buffer_type_t buft = nullptr; + + // check overrides + if (tensor_buft_overrides) { + std::string tensor_name = tn.str(); + for (const auto * overrides = tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { + std::regex pattern(overrides->pattern); + if (std::regex_search(tensor_name, pattern)) { + if (overrides->buft == ggml_backend_cpu_buffer_type()) { + // when overriding to a CPU buffer, consider the extra buffer types + buft = select_weight_buft(hparams, t_meta, op, buft_list_cpu); + } else { + buft = overrides->buft; + } + + LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n", + tensor_name.c_str(), + ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type), + ggml_backend_buft_name(buft)); + break; + } + } + } + + if (!buft) { + buft = select_weight_buft(hparams, t_meta, op, buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); + } + } + + // avoid using a host buffer when using mmap + auto * buft_dev = ggml_backend_buft_get_device(buft); + if (use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error("no CPU backend found"); + } + buft = ggml_backend_dev_buffer_type(cpu_dev); + } + + if (buft != buft_list->front().second) { + if (n_tensors_moved == 0) { + first_tensor_moved_name = t_meta->name; + first_tensor_moved_type_name = ggml_type_name(t_meta->type); + first_moved_from_buft = buft_list->front().second; + first_moved_to_buft = buft; + } + n_tensors_moved++; + } + + return buft; + }; + + if (files.empty()) { + if (flags & TENSOR_SKIP_IF_VIRTUAL) { + return nullptr; + } + ggml_type type = GGML_TYPE_F32; + const int64_t tid = gguf_find_tensor(metadata, tn.str().c_str()); + if (tid != -1) { + type = gguf_get_tensor_type(metadata, tid); + } + + // for tensors that are not required some of the dimensions can be invalid: + if (flags & TENSOR_NOT_REQUIRED) { + for (size_t dim = 0; dim < ne.size(); dim++) { + if (ne.begin()[dim] <= 0) { + return nullptr; + } + } + } + + ggml_tensor t_meta; + memset(&t_meta, 0, sizeof(ggml_tensor)); + t_meta.type = type; + for (size_t dim = 0; dim < GGML_MAX_DIMS; dim++) { + t_meta.ne[dim] = dim < ne.size() ? ne.begin()[dim] : 1; + GGML_ASSERT(t_meta.ne[dim] >= 1); + t_meta.nb[dim] = dim == 0 ? ggml_type_size(type) : t_meta.ne[dim-1]*t_meta.nb[dim-1]; + GGML_ASSERT(t_meta.nb[dim] >= 1); + } + ggml_set_name(&t_meta, tn.str().c_str()); + + ggml_backend_buffer_type_t buft = buft_for_tensor(&t_meta); + GGML_ASSERT(buft != nullptr); + ggml_context * ctx = ctx_for_buft(buft); + ggml_tensor * ret = ggml_dup_tensor(ctx, &t_meta); + ggml_set_name(ret, tn.str().c_str()); + return ret; + } + + ggml_tensor * t_meta = get_tensor_meta(tn.str().c_str()); + ggml_backend_buffer_type_t buft = buft_for_tensor(t_meta); + if (buft == nullptr) { + return nullptr; // return type is ggml_tensor * + } + ggml_context * ctx = ctx_for_buft(buft); + + // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one + if (flags & TENSOR_DUPLICATED) { + ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str()); + if (t) { + return t; + } + } + + LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, tn.str().c_str()); + const struct ggml_tensor * cur = check_tensor_dims(tn.str(), ne, !(flags & TENSOR_NOT_REQUIRED)); if (cur == NULL) { return NULL; } - bool duplicated = flags & TENSOR_DUPLICATED; + const bool duplicated = flags & TENSOR_DUPLICATED; struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur); ggml_set_name(tensor, ggml_get_name(cur)); @@ -858,7 +1242,6 @@ struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx } return tensor; - } struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required) { @@ -893,6 +1276,11 @@ void llama_model_loader::done_getting_tensors() const { if (n_created != n_tensors) { throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); } + if (n_tensors_moved > 0) { + LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n", + __func__, first_tensor_moved_name.c_str(), first_tensor_moved_type_name.c_str(), n_tensors_moved - 1, + ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft)); + } } void llama_model_loader::init_mappings(bool prefetch, llama_mlocks * mlock_mmaps) { @@ -974,6 +1362,12 @@ bool llama_model_loader::load_all_data( llama_mlocks * lmlocks, llama_progress_callback progress_callback, void * progress_callback_user_data) { + if (files.empty()) { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + set_tensor_data(t, set_tensor_data_ud); + } + return true; + } GGML_ASSERT(size_data != 0 && "call init_mappings() first"); std::vector> read_buf; diff --git a/examples/talk-llama/llama-model-loader.h b/examples/talk-llama/llama-model-loader.h index 65953dd3d5a..ed5de729caf 100644 --- a/examples/talk-llama/llama-model-loader.h +++ b/examples/talk-llama/llama-model-loader.h @@ -4,17 +4,22 @@ #include "llama-impl.h" #include "llama-arch.h" +#include "llama-hparams.h" #include "llama-mmap.h" #include "ggml-cpp.h" #include +#include #include #include #include using llama_buf_map = std::unordered_map; +// lists of buffer types used for each layer +using buft_list_t = std::vector>; + enum llama_fver { GGUF_FILE_VERSION_V1 = 1, GGUF_FILE_VERSION_V2 = 2, @@ -58,9 +63,10 @@ struct llama_model_loader { } }; - static const int TENSOR_NOT_REQUIRED = 1 << 0; - static const int TENSOR_DUPLICATED = 1 << 1; - static const int TENSOR_SKIP = 1 << 2; + static const int TENSOR_NOT_REQUIRED = 1 << 0; + static const int TENSOR_DUPLICATED = 1 << 1; + static const int TENSOR_SKIP = 1 << 2; + static const int TENSOR_SKIP_IF_VIRTUAL = 1 << 3; int n_kv = 0; int n_tensors = 0; @@ -84,7 +90,10 @@ struct llama_model_loader { std::unordered_map kv_overrides; const llama_model_tensor_buft_override * tensor_buft_overrides; - gguf_context_ptr meta; + gguf_context_ptr metadata_ptr; + struct gguf_context * metadata; // either metadata_ptr.get() or externally set + llama_model_set_tensor_data_t set_tensor_data; + void * set_tensor_data_ud; std::vector contexts; std::string arch_name; @@ -94,7 +103,26 @@ struct llama_model_loader { size_t size_data = 0; std::vector> mmaps_used; + // define a comparator for the buft -> ctx map to ensure that the order is well-defined: + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + + std::map ctx_map; + + // track tensors that had to be moved for debugging: + size_t n_tensors_moved = 0; + std::string first_tensor_moved_name; + std::string first_tensor_moved_type_name; + ggml_backend_buffer_type_t first_moved_from_buft = nullptr; + ggml_backend_buffer_type_t first_moved_to_buft = nullptr; + llama_model_loader( + struct gguf_context * metadata, + llama_model_set_tensor_data_t set_tensor_data, + void * set_tensor_data_ud, const std::string & fname, std::vector & splits, // optional, only need if the split does not follow naming scheme bool use_mmap, @@ -149,7 +177,9 @@ struct llama_model_loader { const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector & ne, bool required) const; - struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list & ne, int flags = 0); + struct ggml_tensor * create_tensor( + const llama_hparams & hparams, const buft_list_t * buft_list_cpu, const buft_list_t * buft_list_input, const buft_list_t * buft_list_output, + const buft_list_t * buft_list_layer, const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags); struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true); diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index 676efeda709..6f6538aeccd 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -7,14 +7,19 @@ #include "llama-model.h" #include "llama-vocab.h" +#include #include -llama_model_saver::llama_model_saver(const struct llama_model & model) : model(model), llm_kv(model.arch) { - gguf_ctx = gguf_init_empty(); -} +llama_model_saver::llama_model_saver(const struct llama_model * model) : + gguf_ctx(gguf_init_empty()), gguf_ctx_owned(true), model(model), llm_kv(model->arch) {} + +llama_model_saver::llama_model_saver(enum llm_arch arch, struct gguf_context * gguf_ctx) : + gguf_ctx(gguf_ctx == nullptr ? gguf_init_empty() : gguf_ctx), gguf_ctx_owned(gguf_ctx == nullptr), model(nullptr), llm_kv(arch) {} llama_model_saver::~llama_model_saver() { - gguf_free(gguf_ctx); + if (gguf_ctx_owned) { + gguf_free(gguf_ctx); + } } void llama_model_saver::add_kv(const enum llm_kv key, const uint32_t value) { @@ -46,7 +51,8 @@ void llama_model_saver::add_kv(const enum llm_kv key, const char value) { template void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) { - const size_t n_values = per_layer ? size_t(model.hparams.n_layer) : value.size(); + GGML_ASSERT(model != nullptr || !per_layer); + const size_t n_values = per_layer ? size_t(model->hparams.n_layer) : value.size(); GGML_ASSERT(n_values <= value.size()); if (n_values == 0) { @@ -83,6 +89,8 @@ void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, c GGML_ABORT("fatal error"); } } +// instantiate for external usage: +template void llama_model_saver::add_kv>(const enum llm_kv, const std::vector &, const bool); void llama_model_saver::add_kv(const enum llm_kv key, const std::vector & value) { std::vector tmp(value.size()); @@ -104,37 +112,39 @@ void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) { } void llama_model_saver::add_kv_from_model() { - const llama_hparams & hparams = model.hparams; - const llama_vocab & vocab = model.vocab; + const llama_hparams & hparams = model->hparams; + const llama_vocab & vocab = model->vocab; const int32_t n_vocab = vocab.n_tokens(); std::vector tokens(n_vocab); std::vector scores(n_vocab); std::vector token_types(n_vocab); - for (int32_t id = 0; id < n_vocab; ++id) { - const llama_vocab::token_data & token_data = vocab.get_token_data(id); - - tokens[id] = token_data.text; - scores[id] = token_data.score; - - switch(token_data.attr) { - case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break; - case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break; - case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break; - case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break; - case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break; - case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break; - case LLAMA_TOKEN_ATTR_UNDEFINED: - default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break; + if (vocab.get_type() != LLAMA_VOCAB_TYPE_NONE) { + for (int32_t id = 0; id < n_vocab; ++id) { + const llama_vocab::token_data & token_data = vocab.get_token_data(id); + + tokens[id] = token_data.text; + scores[id] = token_data.score; + + switch(token_data.attr) { + case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break; + case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break; + case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break; + case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break; + case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break; + case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break; + case LLAMA_TOKEN_ATTR_UNDEFINED: + default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break; + } } } // add_kv(LLM_KV_GENERAL_TYPE, ???); - add_kv(LLM_KV_GENERAL_ARCHITECTURE, model.arch_name()); + add_kv(LLM_KV_GENERAL_ARCHITECTURE, model->arch_name()); // add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???); // add_kv(LLM_KV_GENERAL_ALIGNMENT, ???); - add_kv(LLM_KV_GENERAL_NAME, model.name); + add_kv(LLM_KV_GENERAL_NAME, model->name); // add_kv(LLM_KV_GENERAL_AUTHOR, ???); // add_kv(LLM_KV_GENERAL_VERSION, ???); // add_kv(LLM_KV_GENERAL_URL, ???); @@ -176,8 +186,10 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true); add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); - add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k); - add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); @@ -189,7 +201,8 @@ void llama_model_saver::add_kv_from_model() { const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train; - add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot); + add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full); + add_kv(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa); add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train); // add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train)); @@ -255,25 +268,25 @@ void llama_model_saver::add_kv_from_model() { } void llama_model_saver::add_tensors_from_model() { - if (std::string(model.output->name) != std::string(model.tok_embd->name)) { - add_tensor(model.tok_embd); // some models use the same tensor for tok_embd and output + if (std::string(model->output->name) != std::string(model->tok_embd->name)) { + add_tensor(model->tok_embd); // some models use the same tensor for tok_embd and output } - add_tensor(model.type_embd); - add_tensor(model.pos_embd); - add_tensor(model.tok_norm); - add_tensor(model.tok_norm_b); - add_tensor(model.output_norm); - add_tensor(model.output_norm_b); - add_tensor(model.output); - add_tensor(model.output_b); - add_tensor(model.output_norm_enc); - add_tensor(model.cls); - add_tensor(model.cls_b); - add_tensor(model.cls_out); - add_tensor(model.cls_out_b); - add_tensor(model.cls_norm); - - for (const struct llama_layer & layer : model.layers) { + add_tensor(model->type_embd); + add_tensor(model->pos_embd); + add_tensor(model->tok_norm); + add_tensor(model->tok_norm_b); + add_tensor(model->output_norm); + add_tensor(model->output_norm_b); + add_tensor(model->output); + add_tensor(model->output_b); + add_tensor(model->output_norm_enc); + add_tensor(model->cls); + add_tensor(model->cls_b); + add_tensor(model->cls_out); + add_tensor(model->cls_out_b); + add_tensor(model->cls_norm); + + for (const struct llama_layer & layer : model->layers) { for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { add_tensor(reinterpret_cast(&layer)[i]); } diff --git a/examples/talk-llama/llama-model-saver.h b/examples/talk-llama/llama-model-saver.h index a5a434c3069..2b3541ce6c5 100644 --- a/examples/talk-llama/llama-model-saver.h +++ b/examples/talk-llama/llama-model-saver.h @@ -1,5 +1,6 @@ #pragma once +#include "gguf.h" #include "llama.h" #include "llama-arch.h" @@ -7,10 +8,12 @@ struct llama_model_saver { struct gguf_context * gguf_ctx = nullptr; - const struct llama_model & model; + const bool gguf_ctx_owned; + const struct llama_model * model; const struct LLM_KV llm_kv; - llama_model_saver(const struct llama_model & model); + llama_model_saver(const struct llama_model * model); + llama_model_saver(enum llm_arch arch, struct gguf_context * gguf_ctx); ~llama_model_saver(); void add_kv(enum llm_kv key, uint32_t value); diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index dabf3b3086e..e8e1bbf1cd1 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -1,5 +1,6 @@ #include "llama-model.h" +#include "ggml.h" #include "llama-impl.h" #include "llama-mmap.h" #include "llama-cparams.h" @@ -18,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -61,6 +63,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_0_3B: return "0.3B"; case LLM_TYPE_0_5B: return "0.5B"; case LLM_TYPE_0_6B: return "0.6B"; + case LLM_TYPE_0_8B: return "0.8B"; case LLM_TYPE_1B: return "1B"; case LLM_TYPE_1_2B: return "1.2B"; case LLM_TYPE_1_3B: return "1.3B"; @@ -132,12 +135,15 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_102B_A12B: return "102B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; + case LLM_TYPE_120B_A12B: return "120B.A12B"; + case LLM_TYPE_122B_A10B: return "122B.A10B"; case LLM_TYPE_196B_A11B: return "196B.A11B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; + case LLM_TYPE_397B_A17B: return "397B.A17B"; case LLM_TYPE_744B_A40B: return "744B.A40B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; @@ -174,160 +180,6 @@ static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::st return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; } -// checks if the weight tensor can be used with the specified buffer type and device -static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { - GGML_ASSERT(w != nullptr); - - if (op == GGML_OP_NONE) { - return true; - } - - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead()*8, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - ggml_context_ptr ctx_ptr { ggml_init(params) }; - if (!ctx_ptr) { - throw std::runtime_error(format("failed to create ggml context")); - } - ggml_context * ctx = ctx_ptr.get(); - - ggml_tensor * op_tensor = nullptr; - - switch (op) { - case GGML_OP_GET_ROWS: - { - ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); - op_tensor = ggml_get_rows(ctx, w, b); - } break; - case GGML_OP_MUL_MAT: - { - ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]); - op_tensor = ggml_mul_mat(ctx, w, b); - } break; - case GGML_OP_MUL_MAT_ID: - { - int n_expert_used = hparams.n_expert_used; - ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); - ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); - op_tensor = ggml_mul_mat_id(ctx, w, b, ids); - } break; - case GGML_OP_ADD: - { - ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); - op_tensor = ggml_add(ctx, a, w); - } break; - case GGML_OP_ADD_ID: - { - int n_expert_used = hparams.n_expert_used; - ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); - ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); - op_tensor = ggml_add_id(ctx, a, w, c); - } break; - case GGML_OP_MUL: - { - ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); - op_tensor = ggml_mul(ctx, a, w); - } break; - case GGML_OP_DIV: - { - ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]); - op_tensor = ggml_div(ctx, a, w); - } break; - case GGML_OP_ROPE: - { - int n_embd_head = hparams.n_embd_head_v; - int n_head = hparams.n_head(); - ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512); - ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); - op_tensor = ggml_rope_ext( - ctx, a, b, w, - 0, 0, 0, 0, 0, - 0, 0, 0, 0 - ); - - } break; - case GGML_OP_SSM_CONV: - { - const int64_t n_seq_tokens = 512; - const int64_t n_seqs = 3; - ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs); - op_tensor = ggml_ssm_conv(ctx, conv_x, w); - } break; - case GGML_OP_SSM_SCAN: - { - // w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2 - const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; - const int64_t n_head = w->ne[1]; - const int64_t head_dim = hparams.ssm_d_inner / n_head; - const int64_t n_group = hparams.ssm_n_group ? hparams.ssm_n_group : 1; - const int64_t n_seq_tokens = 512; - const int64_t n_seqs = 3; - ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); - ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs); - ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); - ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); - op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids); - } break; - case GGML_OP_RWKV_WKV6: - { - // FIXME - const int64_t S = 123; - const int64_t H = 123; - const int64_t n_tokens = 123; - const int64_t n_seqs = 123; - ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); - ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); - ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); - ggml_tensor * tf = w; - ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); - ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); - op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); - } break; - case GGML_OP_IM2COL: - { - const int n_embd_inp = hparams.n_embd_inp(); - ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_inp, w->ne[1], 1, 1); - op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); - } break; - case GGML_OP_SCALE: - { - op_tensor = ggml_scale(ctx, w, 1.0f); - } break; - default: - GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); - } - - // create a temporary dummy buffer for the weight so that supports_op can check the buffer type - GGML_ASSERT(w->buffer == nullptr); - w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); - bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); - ggml_backend_buffer_free(w->buffer); - w->buffer = nullptr; - - return op_supported; -} - -// lists of buffer types used for each layer -using buft_list_t = std::vector>; - -// find the first buffer type in the list that can use the tensor -static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hparams, ggml_tensor * tensor, ggml_op op, const buft_list_t & buft_list) { - GGML_ASSERT(!buft_list.empty()); - for (const auto & cur : buft_list) { - ggml_backend_dev_t cur_dev = cur.first; - ggml_backend_buffer_type_t cur_buft = cur.second; - if (weight_buft_supported(hparams, tensor, op, cur_buft, cur_dev)) { - return cur_buft; - } - } - - return nullptr; -} - // CPU: ACCEL -> GPU host -> CPU extra -> CPU static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; @@ -493,7 +345,7 @@ void llama_model::load_arch(llama_model_loader & ml) { } void llama_model::load_hparams(llama_model_loader & ml) { - const gguf_context * ctx = ml.meta.get(); + const gguf_context * ctx = ml.metadata; // get metadata as string for (int i = 0; i < gguf_get_n_kv(ctx); i++) { @@ -608,26 +460,37 @@ void llama_model::load_hparams(llama_model_loader & ml) { // gpt-neox n_rot = rotary_pct * (n_embd / n_head) // gpt-j n_rot = rotary_dim - hparams.n_embd_head_k = hparams.n_embd / hparams.n_head(); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); + hparams.n_embd_head_k_full = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false); - hparams.n_embd_head_v = hparams.n_embd / hparams.n_head(); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); + hparams.n_embd_head_v_full = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false); // sanity check for n_rot (optional) - hparams.n_rot = hparams.n_embd_head_k; + hparams.n_rot_full = hparams.n_embd_head_k_full; - ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full, false); if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON || arch == LLM_ARCH_LLAMA_EMBED) { - if (hparams.n_rot != hparams.n_embd_head_k) { - throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); + if (hparams.n_rot_full != hparams.n_embd_head_k_full) { + throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot_full, hparams.n_embd_head_k_full)); } } } else { - hparams.n_rot = 0; - hparams.n_embd_head_k = 0; - hparams.n_embd_head_v = 0; + hparams.n_rot_full = 0; + hparams.n_embd_head_k_full = 0; + hparams.n_embd_head_v_full = 0; + } + + // head size and n_rot for SWA layers + { + hparams.n_embd_head_k_swa = hparams.n_embd_head_k_full; + hparams.n_embd_head_v_swa = hparams.n_embd_head_v_full; + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa, false); + + hparams.n_rot_swa = hparams.n_rot_full; + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa, false); } // for differentiating model types @@ -687,7 +550,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.n_attn_temp_floor_scale = 8192; hparams.f_attn_temp_scale = 0.1f; hparams.f_attn_temp_offset = 1.0f; - hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full + uint32_t swa_period = 4; // pattern: 3 chunked - 1 full + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -724,7 +589,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_AFMOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); @@ -736,7 +601,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { // Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4) if (hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(4); + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -881,7 +748,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { @@ -904,10 +771,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); if (found_swa && hparams.n_swa > 0) { - uint32_t swa_period = 3; hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + uint32_t swa_period = 3; ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); hparams.set_swa_pattern(swa_period, true); } else { @@ -915,7 +781,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { @@ -931,7 +797,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_JINA_BERT_V2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); hparams.f_max_alibi_bias = 8.0f; @@ -944,7 +810,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_JINA_BERT_V3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { @@ -957,8 +823,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_NOMIC_BERT_MOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); if (hparams.n_layer == 12 && hparams.n_embd == 768) { @@ -972,8 +838,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_NEO_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); if (hparams.n_layer == 28) { type = LLM_TYPE_250M; @@ -982,8 +848,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_EUROBERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); if (hparams.n_layer == 12) { type = LLM_TYPE_SMALL; // 0.2B @@ -1011,7 +877,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); - ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); switch (hparams.n_layer) { case 32: type = LLM_TYPE_7B; break; @@ -1260,19 +1126,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { break; default: type = LLM_TYPE_UNKNOWN; } - - // Load attention parameters - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); } break; case LLM_ARCH_PLAMO3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); if (found_swa && hparams.n_swa > 0) { - uint32_t swa_period = 8; hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + uint32_t swa_period = 8; ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); hparams.set_swa_pattern(swa_period); } else { @@ -1335,7 +1197,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; // default value of gemma 2 - hparams.set_swa_pattern(2); + uint32_t swa_period = 2; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.attn_soft_cap = true; hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -1356,14 +1220,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173 hparams.f_attention_scale = type == LLM_TYPE_27B ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) - : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + : 1.0f / std::sqrt(float(hparams.n_embd_head_k())); } break; case LLM_ARCH_GEMMA3: { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); if (found_swa && hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(6); + uint32_t swa_period = 6; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { @@ -1387,12 +1253,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289 hparams.f_attention_scale = type == LLM_TYPE_27B ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) - : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + : 1.0f / std::sqrt(float(hparams.n_embd_head_k())); } break; case LLM_ARCH_GEMMA3N: { + uint32_t swa_period = 5; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(5); + hparams.set_swa_pattern(swa_period); hparams.n_layer_kv_from_start = 20; hparams.f_attention_scale = 1.0f; @@ -1410,14 +1278,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_GEMMA_EMBEDDING: { hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; - hparams.set_swa_pattern(6); + uint32_t swa_period = 6; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.causal_attn = false; // embeddings do not use causal attention ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); //applied only if model converted with --sentence-transformers-dense-modules ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); @@ -1432,7 +1302,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 24: type = LLM_TYPE_0_3B; break; default: type = LLM_TYPE_UNKNOWN; } - hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k())); } break; case LLM_ARCH_STARCODER2: @@ -1524,7 +1394,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } switch (hparams.n_layer) { - // TODO: Jamba layers are a bit heterogenous, so naming this is hard. + // TODO: Jamba layers are a bit heterogeneous, so naming this is hard. case 12: // 900M 8x???M case 32: // 51B 16x?B default: type = LLM_TYPE_UNKNOWN; @@ -1542,7 +1412,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_COMMAND_R: { - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); switch (hparams.n_layer) { case 40: type = LLM_TYPE_35B; break; @@ -1552,7 +1422,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_COHERE2: { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(4); + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -1594,7 +1466,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); if (found_swa && hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(4); + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = 1.0; // See olmo2.cpp @@ -1701,10 +1575,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_DEEPSEEK: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); switch (hparams.n_ff_exp) { case 1408: type = LLM_TYPE_16B; break; @@ -1718,7 +1592,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); if (!is_lite) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); } @@ -1820,7 +1694,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); // Expert gating function (GLM-4.5 uses sigmoid) @@ -1853,7 +1727,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); // deepseek MLA parameters @@ -1939,7 +1813,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_JAIS: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); switch (hparams.n_layer) { case 24: type = LLM_TYPE_1_3B; break; @@ -1988,10 +1862,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_MOE_LATENT_SIZE, hparams.moe_latent_size, false); switch (hparams.n_layer) { case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B case 56: type = LLM_TYPE_9B; break; + case 88: type = LLM_TYPE_120B_A12B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -2009,7 +1885,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { if (hparams.n_layer == 64) { // 32B hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; - hparams.set_swa_pattern(4); + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -2029,7 +1907,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 128; - hparams.set_swa_pattern(4); + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -2042,7 +1922,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); @@ -2126,9 +2006,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); - ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, false); // Granite uses rope_finetuned as a switch for rope, so default to true bool rope_finetuned = true; @@ -2186,7 +2066,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default - ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm); + ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm, false); switch (hparams.n_layer) { case 32: type = LLM_TYPE_7B; break; @@ -2199,15 +2079,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); } break; case LLM_ARCH_BAILINGMOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); switch (hparams.n_layer) { @@ -2219,11 +2099,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_BAILINGMOE2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); @@ -2242,10 +2122,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_DOTS1: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); switch (hparams.n_layer) { @@ -2265,7 +2145,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); } switch (hparams.n_layer) { @@ -2310,7 +2190,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); switch (hparams.n_layer) { case 32: type = LLM_TYPE_A13B; break; @@ -2346,7 +2226,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(2); + uint32_t swa_period = 2; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -2384,7 +2266,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); @@ -2403,9 +2285,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa = 4096; - hparams.set_swa_pattern(4, true); + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period, true); hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -2428,7 +2312,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_GROVEMOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, hparams.n_ff_chexp); + ml.get_key(LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, hparams.n_ff_chexp, false); ml.get_key(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); ml.get_key(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -2528,7 +2412,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { } switch (hparams.n_layer) { - case 24: type = LLM_TYPE_2B; break; + case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; + case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; + case 64: type = LLM_TYPE_27B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -2557,8 +2443,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { } switch (hparams.n_layer) { - case 28: type = LLM_TYPE_35B_A3B; break; - case 48: type = LLM_TYPE_80B_A3B; break; + case 40: type = LLM_TYPE_35B_A3B; break; + case 48: type = LLM_TYPE_122B_A10B; break; + case 60: type = LLM_TYPE_397B_A17B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -2596,7 +2483,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); switch (hparams.n_layer) { @@ -2610,7 +2497,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot); ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); @@ -2626,8 +2512,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { // MoE parameters - Kimi uses moe_intermediate_size = 1024 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); switch (hparams.n_layer) { @@ -2641,6 +2527,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + // full_attention layer only use half of the RoPE dimensions + hparams.n_rot_full = hparams.n_rot_full / 2; + // MoE + SWA parameters ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); @@ -2654,7 +2543,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); @@ -2664,7 +2553,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; - default: throw std::runtime_error("unsupported model architecture"); + default: throw std::runtime_error("unsupported model architecture: " + arch_name()); } pimpl->n_bytes = ml.n_bytes; @@ -2771,44 +2660,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // assign the output layer pimpl->dev_output = get_layer_buft_list(n_layer); - // one ggml context per buffer type - int max_n_tensors = ml.n_tensors; - max_n_tensors += 1; // duplicated output tensor - max_n_tensors += n_layer*2; // duplicated rope freq tensors - const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; - - // define a comparator for the buft -> ctx map to ensure that the order is well-defined: - struct ggml_backend_buft_comparator { - bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { - return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; - } - }; - std::map ctx_map; - - auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { - auto it = ctx_map.find(buft); - if (it == ctx_map.end()) { - ggml_init_params params = { - /*.mem_size =*/ ctx_size, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ggml_context * ctx = ggml_init(params); - if (!ctx) { - throw std::runtime_error(format("failed to create ggml context")); - } - - ctx_map.emplace(buft, ctx); - - return ctx; - } - return it->second.get(); - }; - - const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; - const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; - const auto TENSOR_SKIP = llama_model_loader::TENSOR_SKIP; + const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; + const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; + const auto TENSOR_SKIP = llama_model_loader::TENSOR_SKIP; + const auto TENSOR_SKIP_IF_VIRTUAL = llama_model_loader::TENSOR_SKIP_IF_VIRTUAL; // create tensors for the weights { @@ -2818,13 +2673,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_embd = hparams.n_embd; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - const int64_t n_embd_head_k = hparams.n_embd_head_k; - const int64_t n_embd_head_v = hparams.n_embd_head_v; + const int64_t n_embd_head_k = hparams.n_embd_head_k(); + const int64_t n_embd_head_v = hparams.n_embd_head_v(); const int64_t n_ff = hparams.n_ff(); const int64_t n_embd_gqa = n_embd_v_gqa; const int64_t n_vocab = vocab.n_tokens(); const int64_t n_token_types = vocab.n_token_types(); - const int64_t n_rot = hparams.n_rot; + const int64_t n_rot = hparams.n_rot(); const int64_t n_expert = hparams.n_expert; const int64_t n_expert_used = hparams.n_expert_used; const int64_t n_ctx_train = hparams.n_ctx_train; @@ -2833,147 +2688,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { throw std::runtime_error("model has expert layers but no expert layers are used"); } - int n_moved_tensors = 0; - ggml_tensor * first_moved_tensor = nullptr; - ggml_backend_buffer_type_t first_moved_from_buft = nullptr; - ggml_backend_buffer_type_t first_moved_to_buft = nullptr; - auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) -> ggml_tensor * { - ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str()); - - if (!t_meta) { - if (flags & TENSOR_NOT_REQUIRED) { - return nullptr; - } - throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); - } - - // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops - // the tensor is duplicated - // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor - llm_tensor tn_tensor = tn.tensor; - if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && flags & TENSOR_DUPLICATED) { - tn_tensor = LLM_TENSOR_OUTPUT; - } - - llm_tensor_info info; - try { - info = llm_tensor_info_for(tn_tensor); - } catch (const std::out_of_range & e) { - throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str())); - } - - // skip unused tensors - if (info.op == GGML_OP_NONE || flags & TENSOR_SKIP) { - const size_t nbytes = ggml_nbytes(t_meta); - LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes); - - ml.size_data -= nbytes; - ml.n_created++; - - return nullptr; - } - - // tensors with "bias" suffix are always used with GGML_OP_ADD or GGML_OP_ADD_ID - ggml_op op; - bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0; - if (bias) { - if (info.op == GGML_OP_MUL_MAT_ID) { - op = GGML_OP_ADD_ID; - } else { - op = GGML_OP_ADD; - } - } else { - op = info.op; - } - - // sanity checks - if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { - if (tn.bid != -1) { - GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str()); - } - } else { - if (tn.bid == -1) { - GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str()); - } - } - - // select the buffer type for this tensor - buft_list_t * buft_list; - switch (info.layer) { - case LLM_TENSOR_LAYER_INPUT: - buft_list = pimpl->dev_input.buft_list; - break; - case LLM_TENSOR_LAYER_OUTPUT: - buft_list = pimpl->dev_output.buft_list; - break; - case LLM_TENSOR_LAYER_REPEATING: - buft_list = pimpl->dev_layer.at(tn.bid).buft_list; - break; - default: - GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); - } - - ggml_backend_buffer_type_t buft = nullptr; - - // check overrides - if (ml.tensor_buft_overrides) { - std::string tensor_name = tn.str(); - for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { - std::regex pattern(overrides->pattern); - if (std::regex_search(tensor_name, pattern)) { - if (overrides->buft == ggml_backend_cpu_buffer_type()) { - // when overriding to a CPU buffer, consider the extra buffer types - buft = select_weight_buft(hparams, t_meta, op, pimpl->cpu_buft_list); - } else { - buft = overrides->buft; - } - - LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n", - tensor_name.c_str(), - ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type), - ggml_backend_buft_name(buft)); - break; - } - } - } - - if (!buft) { - buft = select_weight_buft(hparams, t_meta, op, *buft_list); - if (!buft) { - throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); - } - } - - // avoid using a host buffer when using mmap - auto * buft_dev = ggml_backend_buft_get_device(buft); - if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) { - auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (!cpu_dev) { - throw std::runtime_error("no CPU backend found"); - } - buft = ggml_backend_dev_buffer_type(cpu_dev); - } - - if (buft != buft_list->front().second) { - n_moved_tensors++; - if (!first_moved_tensor) { - first_moved_tensor = t_meta; - first_moved_from_buft = buft_list->front().second; - first_moved_to_buft = buft; - } - } - - ggml_context * ctx = ctx_for_buft(buft); - - // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one - if (flags & TENSOR_DUPLICATED) { - ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str()); - if (t) { - return t; - } - } - return ml.create_tensor(ctx, tn, ne, flags); + const buft_list_t * buft_list_layer = tn.bid == -1 ? nullptr : pimpl->dev_layer.at(tn.bid).buft_list; + return ml.create_tensor( + hparams, &pimpl->cpu_buft_list, pimpl->dev_input.buft_list, pimpl->dev_output.buft_list, buft_list_layer, + tn, ne, flags); }; layers.resize(n_layer); @@ -3142,6 +2861,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_LLAMA4: { + if (n_expert == 0) { + throw std::runtime_error(arch_name() + " model cannot have zero experts"); + } tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output @@ -3154,7 +2876,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } for (int i = 0; i < n_layer; ++i) { - bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0; + const bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0; auto & layer = layers[i]; @@ -3170,7 +2892,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); if (is_moe_layer) { - int n_ff_exp = hparams.n_ff_exp; + const int64_t n_ff_exp = hparams.n_ff_exp; layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); @@ -3257,8 +2979,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_MINICPM3: { - const int64_t n_embd_head_qk_rope = hparams.n_rot; - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); const int64_t q_lora_rank = hparams.n_lora_q; const int64_t kv_lora_rank = hparams.n_lora_kv; @@ -3301,7 +3023,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_GROK: { if (n_expert == 0) { - throw std::runtime_error("Grok model cannot have zero experts"); + throw std::runtime_error(arch_name() + " model cannot have zero experts"); } tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -3473,6 +3195,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_JINA_BERT_V3: { + if (n_token_types == 0) { + throw std::runtime_error(arch_name() + " model needs to define token type count"); + } tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); @@ -3739,8 +3464,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + // FIXME test-llama-archs crashes if q_norm is created + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); @@ -4126,8 +3852,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); // attention parameters - const uint32_t qk_dim = hparams.n_embd_head_k; - const uint32_t v_dim = hparams.n_embd_head_v; + const uint32_t qk_dim = hparams.n_embd_head_k(); + const uint32_t v_dim = hparams.n_embd_head_v(); tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4187,8 +3913,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_PLAMO3: { - const int64_t head_dim_q = hparams.n_embd_head_k; - const int64_t head_dim_v = hparams.n_embd_head_v; + const int64_t head_dim_q = hparams.n_embd_head_k(); + const int64_t head_dim_v = hparams.n_embd_head_v(); tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4935,7 +4661,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_SEED_OSS: { - const uint32_t head_dim = hparams.n_embd_head_k; + const uint32_t head_dim = hparams.n_embd_head_k(); const int64_t n_qo_dim = n_head * head_dim; const int64_t n_kv_dim = n_head_kv * head_dim; @@ -5164,8 +4890,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); - const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_rope = hparams.n_rot(); const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + GGML_ASSERT(n_embd_head_qk_nope >= 1); const int64_t q_lora_rank = hparams.n_lora_q; const int64_t kv_lora_rank = hparams.n_lora_kv; @@ -5242,8 +4969,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_PLM: { - const int64_t n_embd_head_qk_rope = hparams.n_rot; - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); const int64_t kv_lora_rank = hparams.n_lora_kv; tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5283,23 +5010,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0); layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_scale = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); } } break; case LLM_ARCH_T5: @@ -5357,7 +5084,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}, 0); // this tensor seems to be unused in HF transformers implementation - layer.attn_rel_b_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + layer.attn_rel_b_cross = create_tensor( + tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); @@ -5680,7 +5408,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); - const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_rope = hparams.n_rot(); const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; const int64_t q_lora_rank = hparams.n_lora_q; @@ -5819,6 +5547,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_ssm_head = hparams.ssm_dt_rank; const int64_t n_group = hparams.ssm_n_group; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + const int64_t moe_n_embd = hparams.moe_latent_size > 0 ? hparams.moe_latent_size : n_embd; // embeddings tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5878,8 +5607,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); // MoE branch - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_latent_down = create_tensor(tn(LLM_TENSOR_FFN_LATENT_DOWN, "weight", i), {n_embd, moe_n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_latent_up = create_tensor(tn(LLM_TENSOR_FFN_LATENT_UP, "weight", i), {moe_n_embd, n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, moe_n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {moe_n_embd, n_ff_exp, n_expert}, 0); // Shared expert branch layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); @@ -5963,8 +5695,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_ff_exp = hparams.n_ff_exp; const int64_t n_expert = hparams.n_expert; const int64_t n_expert_used = hparams.n_expert_used; - const int64_t n_ff_shexp = hparams.n_ff_shexp; - const int64_t head_dim = hparams.n_embd_head_k; + const int64_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : n_ff_exp; + const int64_t head_dim = hparams.n_embd_head_k(); const int64_t n_qo_dim = n_head * head_dim; const int64_t n_kv_dim = n_head_kv * head_dim; @@ -6824,6 +6556,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; + const uint32_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : hparams.n_ff(i); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -6842,9 +6575,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); } } break; case LLM_ARCH_HUNYUAN_DENSE: @@ -7180,15 +6913,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; const int64_t ssm_d_conv = hparams.ssm_d_conv; - // Try loading KDA specific tensors (using SSM_ prefix) - // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) - // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] - layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); - if (!layer.ssm_q_conv) { - layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, TENSOR_NOT_REQUIRED); - } + if (hparams.is_recurrent(i)) { + // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) + // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_q_conv) { + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } - if (layer.ssm_q_conv) { // KDA Layer - Conv1d weights may be 3D or 4D layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); if (!layer.ssm_k_conv) { @@ -7252,10 +6984,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA) // Note: hparams.n_rot may be 72 (from conversion) but actual is 64 - const int64_t qk_rope_head_dim = hparams.n_rot; // From config: qk_rope_head_dim + const int64_t qk_rope_head_dim = hparams.n_rot(); // From config: qk_rope_head_dim layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0); // Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled) - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), + {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); if (!layer.wkv_b) { // MLA KV cache enabled layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0); layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); @@ -7375,6 +7108,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_QWEN3NEXT: { + if (n_expert == 0) { + throw std::runtime_error(arch_name() + " model cannot have zero experts"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -7403,6 +7140,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; + const uint32_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : hparams.n_ff(i); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); @@ -7438,9 +7176,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // Shared experts layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); } } break; case LLM_ARCH_QWEN35MOE: @@ -7617,7 +7355,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. uint32_t n_rot_max = 0; for (int i = 0; i < n_layer; ++i) { - n_rot_max = std::max(n_rot_max, hparams.n_rot); + n_rot_max = std::max(n_rot_max, hparams.n_rot(i)); } if (n_rot_max == 0) { n_rot_max = n_rot; @@ -7706,10 +7444,72 @@ bool llama_model::load_tensors(llama_model_loader & ml) { throw std::runtime_error("unknown architecture"); } - if (n_moved_tensors > 0) { - LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %d others) cannot be used with preferred buffer type %s, using %s instead\n", - __func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1, - ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft)); + // generic pass: load optional per-tensor/per-expert ".scale" tensors (e.g. NVFP4 scale2) + // this avoids having to add scale loading to every architecture + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // attention weight scales (per-tensor, shape {1}) + if (!layer.wq_s && layer.wq) { + layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wk_s && layer.wk) { + layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wv_s && layer.wv) { + layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wo_s && layer.wo) { + layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_s && layer.wqkv) { + layer.wqkv_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_gate_s && layer.wqkv_gate) { + layer.wqkv_gate_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + + // dense FFN weight scales (per-tensor, shape {1}) + if (!layer.ffn_gate_s && layer.ffn_gate) { + layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_s && layer.ffn_down) { + layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_s && layer.ffn_up) { + layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_shexp_s && layer.ffn_gate_shexp) { + layer.ffn_gate_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_shexp_s && layer.ffn_down_shexp) { + layer.ffn_down_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_shexp_s && layer.ffn_up_shexp) { + layer.ffn_up_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + + // MoE expert weight scales (per-expert, shape {n_expert}) + if (!layer.ffn_gate_exps_s && layer.ffn_gate_exps) { + layer.ffn_gate_exps_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_exps_s && layer.ffn_down_exps) { + layer.ffn_down_exps_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_exps_s && layer.ffn_up_exps) { + layer.ffn_up_exps_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + + // recurrent / linear-attention weight scales (per-tensor, shape {1}) + if (!layer.ssm_out_s && layer.ssm_out) { + layer.ssm_out_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_alpha_s && layer.ssm_alpha) { + layer.ssm_alpha_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_beta_s && layer.ssm_beta) { + layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } } } @@ -7720,13 +7520,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // create the backend buffers std::vector> ctx_buf_maps; - ctx_buf_maps.reserve(ctx_map.size()); + ctx_buf_maps.reserve(ml.ctx_map.size()); // Ensure we have enough capacity for the maximum backend buffer we will potentially create - const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size(); + const size_t n_max_backend_buffer = ml.ctx_map.size() * ml.files.size(); pimpl->ctxs_bufs.reserve(n_max_backend_buffer); - for (auto & [buft, ctx_ptr] : ctx_map) { + for (auto & [buft, ctx_ptr] : ml.ctx_map) { ggml_context * ctx = ctx_ptr.get(); // skip contexts without tensors @@ -7958,11 +7758,11 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full); LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); - LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); - LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); + LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full); + LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full); LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); @@ -7986,6 +7786,9 @@ void llama_model::print_info() const { if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa); LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa); + LLAMA_LOG_INFO("%s: n_embd_head_k_swa = %u\n", __func__, hparams.n_embd_head_k_swa); + LLAMA_LOG_INFO("%s: n_embd_head_v_swa = %u\n", __func__, hparams.n_embd_head_v_swa); + LLAMA_LOG_INFO("%s: n_rot_swa = %u\n", __func__, hparams.n_rot_swa); } LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index d7c3e7d1c1a..25bf892e7e2 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -54,6 +54,7 @@ enum llm_type { LLM_TYPE_0_3B, LLM_TYPE_0_5B, LLM_TYPE_0_6B, + LLM_TYPE_0_8B, LLM_TYPE_1B, LLM_TYPE_1_2B, LLM_TYPE_1_3B, @@ -125,12 +126,15 @@ enum llm_type { LLM_TYPE_100B_A6B, LLM_TYPE_102B_A12B, // Solar-Open LLM_TYPE_106B_A12B, // GLM-4.5-Air + LLM_TYPE_120B_A12B, // Nemotron 3 Super + LLM_TYPE_122B_A10B, // Qwen3.5 LLM_TYPE_196B_A11B, // Step3.5-Flash LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 + LLM_TYPE_397B_A17B, // Qwen3.5 LLM_TYPE_744B_A40B, // GLM-5 LLM_TYPE_E2B, LLM_TYPE_E4B, @@ -291,6 +295,15 @@ struct llama_layer { struct ggml_tensor * ffn_up_exps_b = nullptr; struct ggml_tensor * ffn_gate_up_exps_b = nullptr; + // ff MoE per-expert scales (NVFP4 per-tensor scale2) + struct ggml_tensor * ffn_gate_exps_s = nullptr; + struct ggml_tensor * ffn_down_exps_s = nullptr; + struct ggml_tensor * ffn_up_exps_s = nullptr; + + // ff MoE latent proj + struct ggml_tensor * ffn_latent_down = nullptr; + struct ggml_tensor * ffn_latent_up = nullptr; + // ff shared expert (shexp) struct ggml_tensor * ffn_gate_inp_shexp = nullptr; struct ggml_tensor * ffn_gate_shexp = nullptr; @@ -384,13 +397,21 @@ struct llama_layer { struct ggml_tensor * rope_freqs = nullptr; // bitnet scale - struct ggml_tensor * wq_scale = nullptr; - struct ggml_tensor * wk_scale = nullptr; - struct ggml_tensor * wv_scale = nullptr; - struct ggml_tensor * wo_scale = nullptr; - struct ggml_tensor * ffn_gate_scale = nullptr; - struct ggml_tensor * ffn_up_scale = nullptr; - struct ggml_tensor * ffn_down_scale = nullptr; + struct ggml_tensor * wq_s = nullptr; + struct ggml_tensor * wk_s = nullptr; + struct ggml_tensor * wv_s = nullptr; + struct ggml_tensor * wo_s = nullptr; + struct ggml_tensor * wqkv_s = nullptr; + struct ggml_tensor * wqkv_gate_s = nullptr; + struct ggml_tensor * ffn_gate_s = nullptr; + struct ggml_tensor * ffn_up_s = nullptr; + struct ggml_tensor * ffn_down_s = nullptr; + struct ggml_tensor * ffn_gate_shexp_s = nullptr; + struct ggml_tensor * ffn_up_shexp_s = nullptr; + struct ggml_tensor * ffn_down_shexp_s = nullptr; + struct ggml_tensor * ssm_out_s = nullptr; + struct ggml_tensor * ssm_alpha_s = nullptr; + struct ggml_tensor * ssm_beta_s = nullptr; // altup & laurel struct ggml_tensor * per_layer_inp_gate = nullptr; diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 24770430e1c..8e8ce231249 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -1,11 +1,11 @@ -#include "llama-quant.h" +#include "llama.h" #include "llama-impl.h" #include "llama-model.h" #include "llama-model-loader.h" -#include #include #include +#include #include #include #include @@ -13,10 +13,28 @@ #include #include -// Quantization types. Changes to this struct must be replicated in quantize.cpp -struct tensor_quantization { +// result of parsing --tensor-type option +// (changes to this struct must be reflected in tools/quantize/quantize.cpp) +struct tensor_type_option { std::string name; - ggml_type quant = GGML_TYPE_COUNT; + ggml_type type = GGML_TYPE_COUNT; +}; + +// tensor categorization - used to avoid repeated string matching in quantization logic. +// this is different from LLM_TN - we want broad categories, not specific tensor names per arch. +enum class tensor_category { + TOKEN_EMBD, + ATTENTION_Q, + ATTENTION_V, + ATTENTION_K, + ATTENTION_QKV, + ATTENTION_KV_B, + ATTENTION_OUTPUT, + FFN_UP, + FFN_GATE, + FFN_DOWN, + OUTPUT, + OTHER }; static void zeros(std::ofstream & file, size_t n) { @@ -54,7 +72,7 @@ static std::string remap_layer(const std::string & orig_name, const std::vector< return orig_name; } -static std::string remap_imatrix (const std::string & orig_name, const std::map & mapped) { +static std::string remap_imatrix(const std::string & orig_name, const std::map & mapped) { if (mapped.empty()) { return orig_name; } @@ -76,6 +94,73 @@ static std::string remap_imatrix (const std::string & orig_name, const std::map< return orig_name; } +// +// helper functions for tensor name matching +// + +static bool tensor_name_match_token_embd(const char * tensor_name) { + return std::strcmp(tensor_name, "token_embd.weight") == 0 || + std::strcmp(tensor_name, "per_layer_token_embd.weight") == 0; +} + +static bool tensor_name_match_output_weight(const char * tensor_name) { + return std::strcmp(tensor_name, "output.weight") == 0; +} + +// +// tensor categorization for quantization +// +// (this is different from LLM_TN - we want broad categories, not specific tensor names per arch) +// + +static tensor_category tensor_get_category(const std::string & tensor_name) { + if (tensor_name_match_output_weight(tensor_name.c_str())) { + return tensor_category::OUTPUT; + } + if (tensor_name_match_token_embd(tensor_name.c_str())) { + return tensor_category::TOKEN_EMBD; + } + if (tensor_name.find("attn_qkv.weight") != std::string::npos) { + return tensor_category::ATTENTION_QKV; + } + if (tensor_name.find("attn_kv_b.weight") != std::string::npos) { + return tensor_category::ATTENTION_KV_B; + } + if (tensor_name.find("attn_v.weight") != std::string::npos) { + return tensor_category::ATTENTION_V; + } + if (tensor_name.find("attn_k.weight") != std::string::npos) { + return tensor_category::ATTENTION_K; + } + if (tensor_name.find("attn_q.weight") != std::string::npos) { + return tensor_category::ATTENTION_Q; + } + if (tensor_name.find("attn_output.weight") != std::string::npos) { + return tensor_category::ATTENTION_OUTPUT; + } + if (tensor_name.find("ffn_up") != std::string::npos) { + return tensor_category::FFN_UP; + } + if (tensor_name.find("ffn_gate") != std::string::npos) { + return tensor_category::FFN_GATE; + } + if (tensor_name.find("ffn_down") != std::string::npos) { + return tensor_category::FFN_DOWN; + } + return tensor_category::OTHER; +} + +// check if category is for attention-v-like tensors (more sensitive to quantization) +static bool category_is_attn_v(tensor_category cat) { + return cat == tensor_category::ATTENTION_V || + cat == tensor_category::ATTENTION_QKV || + cat == tensor_category::ATTENTION_KV_B; +} + +// +// quantization state +// + struct quantize_state_impl { const llama_model & model; const llama_model_quantize_params * params; @@ -89,20 +174,42 @@ struct quantize_state_impl { int i_ffn_gate = 0; int i_ffn_up = 0; - int n_k_quantized = 0; int n_fallback = 0; bool has_imatrix = false; - // used to figure out if a model shares tok_embd with the output weight - bool has_output = false; + // used to figure out if a model has tied embeddings (tok_embd shares weights with output) + bool has_tied_embeddings = true; // assume tied until we see output.weight + + // tensor type override patterns (compiled once, used twice) + std::vector> tensor_type_patterns; + + quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params): + model(model), params(params) + { + // compile regex patterns once - they are expensive + if (params->tensor_types) { + const auto & tensor_types = *static_cast *>(params->tensor_types); + for (const auto & [tname, qtype] : tensor_types) { + tensor_type_patterns.emplace_back(std::regex(tname), qtype); + } + } + } +}; - quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params) - : model(model) - , params(params) - {} +// per-tensor metadata, computed in the preliminary loop and used in the main loop +struct tensor_metadata { + ggml_type target_type; + tensor_category category; + std::string remapped_imatrix_name; + bool allows_quantization; + bool requires_imatrix; }; +// +// dequantization +// + static void llama_tensor_dequantize_impl( ggml_tensor * tensor, std::vector> & output, std::vector & workers, const size_t nelements, const int nthread @@ -175,12 +282,132 @@ static void llama_tensor_dequantize_impl( workers.clear(); } -static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) { +// +// do we allow this tensor to be quantized? +// + +static bool tensor_allows_quantization(const llama_model_quantize_params * params, llm_arch arch, const ggml_tensor * tensor) { + // trivial checks first -- no string ops needed + if (params->only_copy) return false; + + // quantize only 2D and 3D tensors (experts) + if (ggml_n_dims(tensor) < 2) return false; + + const std::string name = ggml_get_name(tensor); + + // This used to be a regex, but has an extreme cost to compile times. + bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? + + // do not quantize norm tensors + quantize &= name.find("_norm.weight") == std::string::npos; + + quantize &= params->quantize_output_tensor || name != "output.weight"; + + // do not quantize expert gating tensors + // NOTE: can't use LLM_TN here because the layer number is not known + quantize &= name.find("ffn_gate_inp.weight") == std::string::npos; + + // these are very small (e.g. 4x4) + quantize &= name.find("altup") == std::string::npos; + quantize &= name.find("laurel") == std::string::npos; + + // these are not too big so keep them as it is + quantize &= name.find("per_layer_model_proj") == std::string::npos; + + // do not quantize positional embeddings and token types (BERT) + quantize &= name != LLM_TN(arch)(LLM_TENSOR_POS_EMBD, "weight"); + quantize &= name != LLM_TN(arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); + + // do not quantize Mamba/Kimi's small conv1d weights + // NOTE: can't use LLM_TN here because the layer number is not known + quantize &= name.find("ssm_conv1d") == std::string::npos; + quantize &= name.find("shortconv.conv.weight") == std::string::npos; + + // do not quantize RWKV's small yet 2D weights + quantize &= name.find("time_mix_first.weight") == std::string::npos; + quantize &= name.find("time_mix_w0.weight") == std::string::npos; + quantize &= name.find("time_mix_w1.weight") == std::string::npos; + quantize &= name.find("time_mix_w2.weight") == std::string::npos; + quantize &= name.find("time_mix_v0.weight") == std::string::npos; + quantize &= name.find("time_mix_v1.weight") == std::string::npos; + quantize &= name.find("time_mix_v2.weight") == std::string::npos; + quantize &= name.find("time_mix_a0.weight") == std::string::npos; + quantize &= name.find("time_mix_a1.weight") == std::string::npos; + quantize &= name.find("time_mix_a2.weight") == std::string::npos; + quantize &= name.find("time_mix_g1.weight") == std::string::npos; + quantize &= name.find("time_mix_g2.weight") == std::string::npos; + quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos; + quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos; + quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos; + + // do not quantize relative position bias (T5) + quantize &= name.find("attn_rel_b.weight") == std::string::npos; + + // do not quantize specific multimodal tensors + quantize &= name.find(".position_embd.") == std::string::npos; + + return quantize; +} + +// +// tensor type selection +// + +// incompatible tensor shapes are handled here - fallback to a compatible type +static ggml_type tensor_type_fallback(quantize_state_impl & qs, const ggml_tensor * t, const ggml_type target_type) { + ggml_type return_type = target_type; + + const int64_t ncols = t->ne[0]; + const int64_t qk_k = ggml_blck_size(target_type); + + if (ncols % qk_k != 0) { // this tensor's shape is incompatible with this quant + LLAMA_LOG_WARN("warning: %-36s - ncols %6" PRId64 " not divisible by %3" PRId64 " (required for type %7s) ", + t->name, ncols, qk_k, ggml_type_name(target_type)); + ++qs.n_fallback; + + switch (target_type) { + // types on the left: block size 256 + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: // types on the right: block size 32 + case GGML_TYPE_IQ4_XS: return_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: return_type = GGML_TYPE_Q4_0; break; + case GGML_TYPE_Q4_K: return_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_Q5_K: return_type = GGML_TYPE_Q5_1; break; + case GGML_TYPE_Q6_K: return_type = GGML_TYPE_Q8_0; break; + default: + throw std::runtime_error(format("no tensor type fallback is defined for type %s", + ggml_type_name(target_type))); + } + if (ncols % ggml_blck_size(return_type) != 0) { + // + // the fallback return type is still not compatible for this tensor! + // + // most likely, this tensor's first dimension is not divisible by 32. + // this is very rare. we can either abort the quantization, or + // fallback to F16 / F32. + // + LLAMA_LOG_WARN("(WARNING: must use F16 due to unusual shape) "); + return_type = GGML_TYPE_F16; + } + LLAMA_LOG_WARN("-> falling back to %7s\n", ggml_type_name(return_type)); + } + return return_type; +} + +// internal standard logic for selecting the target tensor type based on tensor category, ftype, and model arch +static ggml_type llama_tensor_get_type_impl(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype, tensor_category category) { const std::string name = ggml_get_name(tensor); // TODO: avoid hardcoded tensor names - use the TN_* constants const llm_arch arch = qs.model.arch; - const auto tn = LLM_TN(arch); auto use_more_bits = [](int i_layer, int n_layers) -> bool { return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2; @@ -204,7 +431,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings // with the quantization of the output tensor - if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) { + if (category == tensor_category::OUTPUT || (qs.has_tied_embeddings && category == tensor_category::TOKEN_EMBD)) { if (qs.params->output_tensor_type < GGML_TYPE_COUNT) { new_type = qs.params->output_tensor_type; } else { @@ -234,7 +461,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t } else { new_type = GGML_TYPE_Q8_0; } - } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") { + } else if (category == tensor_category::TOKEN_EMBD) { if (qs.params->token_embedding_type < GGML_TYPE_COUNT) { new_type = qs.params->token_embedding_type; } else { @@ -254,21 +481,21 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t } } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { - if (name.find("attn_v.weight") != std::string::npos) { + if (category_is_attn_v(category)) { if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K; else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; ++qs.i_attention_wv; } - else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) { + else if (qs.model.hparams.n_expert == 8 && category == tensor_category::ATTENTION_K) { new_type = GGML_TYPE_Q4_K; } - else if (name.find("ffn_down") != std::string::npos) { + else if (category == tensor_category::FFN_DOWN) { if (qs.i_ffn_down < qs.n_ffn_down/8) { new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; } ++qs.i_ffn_down; } - else if (name.find("attn_output.weight") != std::string::npos) { + else if (category == tensor_category::ATTENTION_OUTPUT) { if (qs.model.hparams.n_expert == 8) { new_type = GGML_TYPE_Q5_K; } else { @@ -276,7 +503,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S; } } - } else if (name.find("attn_v.weight") != std::string::npos) { + } else if (category_is_attn_v(category)) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) { new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K; } @@ -314,7 +541,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t new_type = GGML_TYPE_Q8_0; } ++qs.i_attention_wv; - } else if (name.find("attn_k.weight") != std::string::npos) { + } else if (category == tensor_category::ATTENTION_K) { if (qs.model.hparams.n_expert == 8) { // for the 8-expert model, bumping this to Q8_0 trades just ~128MB // TODO: explore better strategies @@ -326,14 +553,14 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = GGML_TYPE_IQ2_S; } - } else if (name.find("attn_q.weight") != std::string::npos) { + } else if (category == tensor_category::ATTENTION_Q) { if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) { new_type = GGML_TYPE_IQ3_XXS; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = GGML_TYPE_IQ2_S; } - } else if (name.find("ffn_down") != std::string::npos) { + } else if (category == tensor_category::FFN_DOWN) { auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str()); int i_layer = info.first, n_layer = info.second; if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; @@ -378,7 +605,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1; } ++qs.i_ffn_down; - } else if (name.find("attn_output.weight") != std::string::npos) { + } else if (category == tensor_category::ATTENTION_OUTPUT) { if (arch != LLM_ARCH_FALCON) { if (qs.model.hparams.n_expert == 8) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || @@ -398,14 +625,14 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; } } - else if (name.find("attn_qkv.weight") != std::string::npos) { + else if (category == tensor_category::ATTENTION_QKV) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) { new_type = GGML_TYPE_Q4_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K; } - else if (name.find("ffn_gate") != std::string::npos) { + else if (category == tensor_category::FFN_GATE) { auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str()); int i_layer = info.first, n_layer = info.second; if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) { @@ -413,7 +640,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t } ++qs.i_ffn_gate; } - else if (name.find("ffn_up") != std::string::npos) { + else if (category == tensor_category::FFN_UP) { auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str()); int i_layer = info.first, n_layer = info.second; if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) { @@ -425,6 +652,55 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t return new_type; } +// outer wrapper: determine the ggml_type that this tensor should be quantized to +static ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_model_quantize_params * params, const ggml_tensor * tensor, ggml_type default_type, const tensor_metadata & tm) { + if (!tensor_allows_quantization(params, qs.model.arch, tensor)) { + return tensor->type; + } + if (params->token_embedding_type < GGML_TYPE_COUNT && tm.category == tensor_category::TOKEN_EMBD) { + return params->token_embedding_type; + } + if (params->output_tensor_type < GGML_TYPE_COUNT && tm.category == tensor_category::OUTPUT) { + return params->output_tensor_type; + } + + ggml_type new_type = default_type; + + // get more optimal quantization type based on the tensor shape, layer, etc. + if (!params->pure && ggml_is_quantized(default_type)) { + // if the user provided tensor types - use those + bool manual = false; + if (!qs.tensor_type_patterns.empty()) { + const std::string tensor_name(tensor->name); + for (const auto & [pattern, qtype] : qs.tensor_type_patterns) { + if (std::regex_search(tensor_name, pattern)) { + if (qtype != new_type) { + LLAMA_LOG_WARN("%s: %-36s - applying manual override: %s -> %s\n", + __func__, tensor_name.c_str(), ggml_type_name(new_type), ggml_type_name(qtype)); + new_type = qtype; + manual = true; + break; + } + } + } + } + + // if not manual - use the standard logic for choosing the quantization type based on the selected mixture + if (!manual) { + new_type = llama_tensor_get_type_impl(qs, new_type, tensor, params->ftype, tm.category); + } + + // incompatible tensor shapes are handled here - fallback to a compatible type + new_type = tensor_type_fallback(qs, tensor, new_type); + } + + return new_type; +} + +// +// quantization implementation +// + static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector & workers, const int nthread) { if (nthread < 2) { // single-thread @@ -479,61 +755,85 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * return new_size; } -static bool tensor_type_requires_imatrix(const ggml_tensor * t, const ggml_type dst_type, const llama_ftype ftype) { - return ( - dst_type == GGML_TYPE_IQ2_XXS || dst_type == GGML_TYPE_IQ2_XS || - dst_type == GGML_TYPE_IQ3_XXS || dst_type == GGML_TYPE_IQ1_S || - dst_type == GGML_TYPE_IQ2_S || dst_type == GGML_TYPE_IQ1_M || - ( // Q2_K_S is the worst k-quant type - only allow it without imatrix for token embeddings - dst_type == GGML_TYPE_Q2_K && ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(t->name, "token_embd.weight") != 0 - ) - ); +// +// imatrix requirement check +// + +static bool tensor_requires_imatrix(const char * tensor_name, const ggml_type dst_type, const llama_ftype ftype) { + if (tensor_name_match_token_embd(tensor_name) || tensor_name_match_output_weight(tensor_name)) { + return false; + } + switch (dst_type) { + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ1_S: + return true; + case GGML_TYPE_Q2_K: + // as a general rule, the k-type quantizations don't require imatrix data. + // the only exception is Q2_K tensors that are part of a Q2_K_S file. + return ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S; + default: + return false; + } } -static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { - ggml_type default_type; - llama_ftype ftype = params->ftype; +// +// given a file type, get the default tensor type +// - switch (params->ftype) { - case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; - case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; - case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break; - case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break; - case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; - case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; - case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; - case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; +static ggml_type llama_ftype_get_default_type(llama_ftype ftype) { + switch (ftype) { + case LLAMA_FTYPE_MOSTLY_Q4_0: return GGML_TYPE_Q4_0; + case LLAMA_FTYPE_MOSTLY_Q4_1: return GGML_TYPE_Q4_1; + case LLAMA_FTYPE_MOSTLY_Q5_0: return GGML_TYPE_Q5_0; + case LLAMA_FTYPE_MOSTLY_Q5_1: return GGML_TYPE_Q5_1; + case LLAMA_FTYPE_MOSTLY_Q8_0: return GGML_TYPE_Q8_0; + case LLAMA_FTYPE_MOSTLY_F16: return GGML_TYPE_F16; + case LLAMA_FTYPE_MOSTLY_BF16: return GGML_TYPE_BF16; + case LLAMA_FTYPE_ALL_F32: return GGML_TYPE_F32; - case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: default_type = GGML_TYPE_MXFP4; break; + case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4; // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: - case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break; - case LLAMA_FTYPE_MOSTLY_IQ3_XS: default_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_Q2_K: return GGML_TYPE_Q2_K; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: return GGML_TYPE_IQ3_S; case LLAMA_FTYPE_MOSTLY_Q3_K_S: case LLAMA_FTYPE_MOSTLY_Q3_K_M: - case LLAMA_FTYPE_MOSTLY_Q3_K_L: default_type = GGML_TYPE_Q3_K; break; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return GGML_TYPE_Q3_K; case LLAMA_FTYPE_MOSTLY_Q4_K_S: - case LLAMA_FTYPE_MOSTLY_Q4_K_M: default_type = GGML_TYPE_Q4_K; break; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return GGML_TYPE_Q4_K; case LLAMA_FTYPE_MOSTLY_Q5_K_S: - case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break; - case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break; - case LLAMA_FTYPE_MOSTLY_TQ1_0: default_type = GGML_TYPE_TQ1_0; break; - case LLAMA_FTYPE_MOSTLY_TQ2_0: default_type = GGML_TYPE_TQ2_0; break; - case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break; - case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break; - case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break; - case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break; - case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break; - case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break; - case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break; - case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; - case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; - case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; - case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return GGML_TYPE_Q5_K; + case LLAMA_FTYPE_MOSTLY_Q6_K: return GGML_TYPE_Q6_K; + case LLAMA_FTYPE_MOSTLY_TQ1_0: return GGML_TYPE_TQ1_0; + case LLAMA_FTYPE_MOSTLY_TQ2_0: return GGML_TYPE_TQ2_0; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return GGML_TYPE_IQ2_XXS; + case LLAMA_FTYPE_MOSTLY_IQ2_XS: return GGML_TYPE_IQ2_XS; + case LLAMA_FTYPE_MOSTLY_IQ2_S: return GGML_TYPE_IQ2_XS; + case LLAMA_FTYPE_MOSTLY_IQ2_M: return GGML_TYPE_IQ2_S; + case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return GGML_TYPE_IQ3_XXS; + case LLAMA_FTYPE_MOSTLY_IQ1_S: return GGML_TYPE_IQ1_S; + case LLAMA_FTYPE_MOSTLY_IQ1_M: return GGML_TYPE_IQ1_M; + case LLAMA_FTYPE_MOSTLY_IQ4_NL: return GGML_TYPE_IQ4_NL; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: return GGML_TYPE_IQ4_XS; + case LLAMA_FTYPE_MOSTLY_IQ3_S: + case LLAMA_FTYPE_MOSTLY_IQ3_M: return GGML_TYPE_IQ3_S; default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } +} + +// +// main quantization driver +// + +static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { + ggml_type default_type; + llama_ftype ftype = params->ftype; int nthread = params->nthread; @@ -541,6 +841,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: nthread = std::thread::hardware_concurrency(); } + default_type = llama_ftype_get_default_type(ftype); + // mmap consistently increases speed on Linux, and also increases speed on Windows with // hot cache. It may cause a slowdown on macOS, possibly related to free memory. #if defined(__linux__) || defined(_WIN32) @@ -556,7 +858,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } std::vector splits = {}; - llama_model_loader ml(fname_inp, splits, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); + llama_model_loader ml(/*metadata*/ nullptr, /*set_tensor_data*/ nullptr, /*set_tensor_data_ud*/ nullptr, + fname_inp, splits, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); @@ -574,7 +877,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (params->imatrix) { imatrix_data = static_cast>*>(params->imatrix); if (imatrix_data) { - LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size())); + LLAMA_LOG_INFO("\n%s: have importance matrix data with %d entries\n", + __func__, (int)imatrix_data->size()); qs.has_imatrix = true; // check imatrix for nans or infs for (const auto & kv : *imatrix_data) { @@ -596,7 +900,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } // copy the KV pairs from the input file - gguf_set_kv (ctx_out.get(), ml.meta.get()); + gguf_set_kv (ctx_out.get(), ml.metadata); gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV @@ -657,35 +961,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: }); } - for (const auto * it : tensors) { - const struct ggml_tensor * tensor = it->tensor; - - const std::string name = ggml_get_name(tensor); - - // TODO: avoid hardcoded tensor names - use the TN_* constants - if (name.find("attn_v.weight") != std::string::npos || - name.find("attn_qkv.weight") != std::string::npos || - name.find("attn_kv_b.weight")!= std::string::npos) { - ++qs.n_attention_wv; - } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) { - qs.has_output = true; - } - } - - qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; - - size_t total_size_org = 0; - size_t total_size_new = 0; - - std::vector workers; - workers.reserve(nthread); - int idx = 0; - - std::vector> read_data; - std::vector> work; - std::vector> f32_conv_buf; - uint16_t n_split = 1; // Assume split index is continuous @@ -697,14 +973,68 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: std::vector ctx_outs(n_split); ctx_outs[0] = std::move(ctx_out); - // populate the original tensors so we get an initial meta data - for (const auto * it : tensors) { + // compute tensor metadata once and cache it + std::vector metadata(tensors.size()); + + // initialize quantization state before preliminary loop (counters for use_more_bits) + { + for (size_t i = 0; i < tensors.size(); ++i) { + const auto cat = tensor_get_category(tensors[i]->tensor->name); + if (category_is_attn_v(cat)) { + ++qs.n_attention_wv; + } + if (cat == tensor_category::OUTPUT) { + qs.has_tied_embeddings = false; + } + metadata[i].category = cat; // save and re-use the category while we're at it + } + // these also need to be set to n_layer by default + qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer; + } + + // flag for --dry-run + bool will_require_imatrix = false; + + // + // preliminary iteration over all weights + // + + for (size_t i = 0; i < tensors.size(); ++i) { + const auto * it = tensors[i]; + const struct ggml_tensor * tensor = it->tensor; + const std::string name = ggml_get_name(tensor); + uint16_t i_split = params->keep_split ? it->idx : 0; - ggml_tensor * tensor = it->tensor; if (!ctx_outs[i_split]) { ctx_outs[i_split].reset(gguf_init_empty()); } gguf_add_tensor(ctx_outs[i_split].get(), tensor); + + metadata[i].allows_quantization = tensor_allows_quantization(params, model.arch, tensor); + + if (metadata[i].allows_quantization) { + metadata[i].target_type = llama_tensor_get_type(qs, params, tensor, default_type, metadata[i]); + } else { + metadata[i].target_type = tensor->type; + } + + metadata[i].requires_imatrix = tensor_requires_imatrix(tensor->name, metadata[i].target_type, ftype); + + if (params->imatrix) { + metadata[i].remapped_imatrix_name = remap_imatrix(tensor->name, mapped); + } else if (metadata[i].allows_quantization && metadata[i].requires_imatrix) { + if (params->dry_run) { + will_require_imatrix = true; + } else { + LLAMA_LOG_ERROR("\n============================================================================\n" + " ERROR: this quantization requires an importance matrix!\n" + " - offending tensor: %s\n" + " - target type: %s\n" + "============================================================================\n\n", + name.c_str(), ggml_type_name(metadata[i].target_type)); + throw std::runtime_error("this quantization requires an imatrix!"); + } + } } // Set split info if needed @@ -716,6 +1046,16 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } } + size_t total_size_org = 0; + size_t total_size_new = 0; + + std::vector workers; + workers.reserve(nthread); + + std::vector> read_data; + std::vector> work; + std::vector> f32_conv_buf; + int cur_split = -1; std::ofstream fout; auto close_ofstream = [&]() { @@ -745,20 +1085,20 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: ::zeros(fout, meta_size); }; - const auto tn = LLM_TN(model.arch); - // no output file for --dry-run if (!params->dry_run) { new_ofstream(0); } - // flag for `--dry-run`, to let the user know if imatrix will be required for a real - // quantization, as a courtesy - bool will_require_imatrix = false; + // + // main loop: iterate over all weights + // - for (const auto * it : tensors) { - const auto & weight = *it; + for (size_t i = 0; i < tensors.size(); ++i) { + const auto & weight = *tensors[i]; + const auto & tm = metadata[i]; ggml_tensor * tensor = weight.tensor; + if (!params->dry_run && (weight.idx != cur_split && params->keep_split)) { close_ofstream(); new_ofstream(weight.idx); @@ -777,162 +1117,31 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: ml.load_data_for(tensor); } - LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", + LLAMA_LOG_INFO("[%4d/%4d] %-36s - [%s], type = %6s, ", ++idx, ml.n_tensors, ggml_get_name(tensor), llama_format_tensor_shape(tensor).c_str(), ggml_type_name(tensor->type)); - // This used to be a regex, but has an extreme cost to compile times. - bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? - - // quantize only 2D and 3D tensors (experts) - quantize &= (ggml_n_dims(tensor) >= 2); - - // do not quantize norm tensors - quantize &= name.find("_norm.weight") == std::string::npos; - - quantize &= params->quantize_output_tensor || name != "output.weight"; - quantize &= !params->only_copy; - - // do not quantize expert gating tensors - // NOTE: can't use LLM_TN here because the layer number is not known - quantize &= name.find("ffn_gate_inp.weight") == std::string::npos; - - // these are very small (e.g. 4x4) - quantize &= name.find("altup") == std::string::npos; - quantize &= name.find("laurel") == std::string::npos; - - // these are not too big so keep them as it is - quantize &= name.find("per_layer_model_proj") == std::string::npos; - - // do not quantize positional embeddings and token types (BERT) - quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight"); - quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); - - // do not quantize Mamba /Kimi's small conv1d weights - // NOTE: can't use LLM_TN here because the layer number is not known - quantize &= name.find("ssm_conv1d") == std::string::npos; - quantize &= name.find("shortconv.conv.weight") == std::string::npos; - - // do not quantize RWKV's small yet 2D weights - quantize &= name.find("time_mix_first.weight") == std::string::npos; - quantize &= name.find("time_mix_w0.weight") == std::string::npos; - quantize &= name.find("time_mix_w1.weight") == std::string::npos; - quantize &= name.find("time_mix_w2.weight") == std::string::npos; - quantize &= name.find("time_mix_v0.weight") == std::string::npos; - quantize &= name.find("time_mix_v1.weight") == std::string::npos; - quantize &= name.find("time_mix_v2.weight") == std::string::npos; - quantize &= name.find("time_mix_a0.weight") == std::string::npos; - quantize &= name.find("time_mix_a1.weight") == std::string::npos; - quantize &= name.find("time_mix_a2.weight") == std::string::npos; - quantize &= name.find("time_mix_g1.weight") == std::string::npos; - quantize &= name.find("time_mix_g2.weight") == std::string::npos; - quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos; - quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos; - quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos; - - // do not quantize relative position bias (T5) - quantize &= name.find("attn_rel_b.weight") == std::string::npos; - - // do not quantize specific multimodal tensors - quantize &= name.find(".position_embd.") == std::string::npos; - - ggml_type new_type; - void * new_data; - size_t new_size; - - if (quantize) { - new_type = default_type; - - // get more optimal quantization type based on the tensor shape, layer, etc. - if (!params->pure && ggml_is_quantized(default_type)) { - // if the user provided tensor types - use those - bool manual = false; - if (params->tensor_types) { - const std::vector & tensor_types = *static_cast *>(params->tensor_types); - const std::string tensor_name(tensor->name); - for (const auto & [tname, qtype] : tensor_types) { - if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { - if (qtype != new_type) { - LLAMA_LOG_WARN("(manual override: %s -> %s) ", ggml_type_name(new_type), ggml_type_name(qtype)); - new_type = qtype; // if two or more types are specified for the same tensor, the last match wins - manual = true; - break; - } - } - } - } - - // if not manual - use the standard logic for choosing the quantization type based on the selected mixture - if (!manual) { - new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); - } - - // incompatible tensor shapes are handled here - fallback to a compatible type - { - bool convert_incompatible_tensor = false; + const ggml_type cur_type = tensor->type; + const ggml_type new_type = tm.target_type; - const int64_t nx = tensor->ne[0]; - const int64_t ny = tensor->ne[1]; - const int64_t qk_k = ggml_blck_size(new_type); - - if (nx % qk_k != 0) { - LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); - convert_incompatible_tensor = true; - } else { - ++qs.n_k_quantized; - } - - if (convert_incompatible_tensor) { - switch (new_type) { - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; - case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; - case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; - case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; - default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); - } - if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { - new_type = GGML_TYPE_F16; - } - LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); - ++qs.n_fallback; - } - } - } - if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { - new_type = params->token_embedding_type; - } - if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) { - new_type = params->output_tensor_type; - } + // If we've decided to quantize to the same type the tensor is already + // in then there's nothing to do. + bool quantize = cur_type != new_type; - // If we've decided to quantize to the same type the tensor is already - // in then there's nothing to do. - quantize = tensor->type != new_type; - } + void * new_data; + size_t new_size; - // we have now decided on the target type for this tensor if (params->dry_run) { - // the --dry-run option calculates the final quantization size without quantizting + // the --dry-run option calculates the final quantization size without quantizing if (quantize) { new_size = ggml_nrows(tensor) * ggml_row_size(new_type, tensor->ne[0]); LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB (%s)\n", tensor_size/1024.0/1024.0, new_size/1024.0/1024.0, ggml_type_name(new_type)); - if (!will_require_imatrix && tensor_type_requires_imatrix(tensor, new_type, params->ftype)) { + if (!will_require_imatrix && tm.requires_imatrix) { will_require_imatrix = true; } } else { @@ -945,7 +1154,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } else { // no --dry-run, perform quantization if (!quantize) { - new_type = tensor->type; new_data = tensor->data; new_size = tensor_size; LLAMA_LOG_INFO("size = %8.3f MiB\n", tensor_size/1024.0/1024.0); @@ -954,7 +1162,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: const float * imatrix = nullptr; if (imatrix_data) { - auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped)); + auto it = imatrix_data->find(tm.remapped_imatrix_name); if (it == imatrix_data->end()) { LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); } else { @@ -968,14 +1176,14 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // this is a significant error and it may be good idea to abort the process if this happens, // since many people will miss the error and not realize that most of the model is being quantized without an imatrix // tok_embd should be ignored in this case, since it always causes this warning - if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { + if (!tensor_name_match_token_embd(tensor->name)) { throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s", int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name)); } } } } - if (!imatrix && tensor_type_requires_imatrix(tensor, new_type, params->ftype)) { + if (!imatrix && tm.requires_imatrix) { LLAMA_LOG_ERROR("\n\n============================================================\n"); LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); @@ -1020,29 +1228,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr; new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); - - // TODO: temporary sanity check that the F16 -> MXFP4 is lossless -#if 0 - if (new_type == GGML_TYPE_MXFP4) { - auto * x = f32_data_03; - - //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row); - std::vector deq(nrows*n_per_row); - const ggml_type_traits * qtype = ggml_get_type_traits(new_type); - qtype->to_float(new_data_03, deq.data(), deq.size()); - - double err = 0.0f; - for (int i = 0; i < (int) deq.size(); ++i) { - err += fabsf(deq[i] - x[i]); - //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) { - if (deq[i] != x[i]) { - LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]); - } - } - //LLAMA_LOG_INFO("err = %f\n", err); - GGML_ASSERT(err == 0.00000); - } -#endif } LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", tensor_size/1024.0/1024.0, new_size/1024.0/1024.0); } @@ -1058,7 +1243,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: fout.write((const char *) new_data, new_size); zeros(fout, GGML_PAD(new_size, align) - new_size); } // no --dry-run - } // iterate over tensors + } // main loop if (!params->dry_run) { close_ofstream(); @@ -1075,7 +1260,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (qs.n_fallback > 0) { LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n", - __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback); + __func__, qs.n_fallback, ml.n_tensors); } } diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 194eed238ec..68ba292d426 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -1719,7 +1719,7 @@ struct llama_vocab::impl { }; void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { - struct gguf_context * ctx = ml.meta.get(); + struct gguf_context * ctx = ml.metadata; // determine vocab type { @@ -1833,7 +1833,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap); #if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - // correct endiannes of data in precompiled_charsmap binary blob + // correct endianness of data in precompiled_charsmap binary blob uint32_t * xcda_blob_size = (uint32_t *) &precompiled_charsmap[0]; *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap); diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index 6da90d6f1f8..872e659edca 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -1,5 +1,6 @@ #include "llama.h" +#include "ggml-cpp.h" #include "llama-impl.h" #include "llama-chat.h" @@ -12,6 +13,7 @@ #include "ggml.h" #include "ggml-backend.h" +#include "gguf.h" #include #include @@ -825,7 +827,8 @@ int64_t llama_time_us(void) { } // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback -static int llama_model_load(const std::string & fname, std::vector & splits, llama_model & model, llama_model_params & params) { +static int llama_model_load(struct gguf_context * metadata, llama_model_set_tensor_data_t set_tensor_data, void * set_tensor_data_ud, + const std::string & fname, std::vector & splits, llama_model & model, llama_model_params & params) { // loading time will be recalculated after the first eval, so // we take page faults deferred by mmap() into consideration model.t_load_us = 0; @@ -834,7 +837,8 @@ static int llama_model_load(const std::string & fname, std::vector model.t_start_us = tm.t_start_us; try { - llama_model_loader ml(fname, splits, params.use_mmap, params.use_direct_io, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); + llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, params.use_mmap, params.use_direct_io, + params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); ml.print_info(); @@ -880,9 +884,13 @@ static int llama_model_load(const std::string & fname, std::vector } static struct llama_model * llama_model_load_from_file_impl( + struct gguf_context * metadata, + llama_model_set_tensor_data_t set_tensor_data, + void * set_tensor_data_ud, const std::string & path_model, std::vector & splits, struct llama_model_params params) { + GGML_ASSERT((metadata == nullptr) != path_model.empty() && "exactly one out of metadata and path_model needs to be defined"); ggml_time_init(); if (!params.vocab_only && ggml_backend_reg_count() == 0) { @@ -1003,7 +1011,7 @@ static struct llama_model * llama_model_load_from_file_impl( props.memory_free/1024/1024); } - const int status = llama_model_load(path_model, splits, *model, params); + const int status = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, *model, params); GGML_ASSERT(status <= 0); if (status < 0) { if (status == -1) { @@ -1019,6 +1027,18 @@ static struct llama_model * llama_model_load_from_file_impl( return model; } +struct llama_model * llama_model_init_from_user( + struct gguf_context * metadata, + llama_model_set_tensor_data_t set_tensor_data, + void * set_tensor_data_ud, + struct llama_model_params params) { + GGML_ASSERT(metadata != nullptr); + std::string path_model; + std::vector splits = {}; + params.use_mmap = false; + params.use_extra_bufts = false; + return llama_model_load_from_file_impl(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, params); +} // deprecated struct llama_model * llama_load_model_from_file( const char * path_model, @@ -1030,7 +1050,7 @@ struct llama_model * llama_model_load_from_file( const char * path_model, struct llama_model_params params) { std::vector splits = {}; - return llama_model_load_from_file_impl(path_model, splits, params); + return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, path_model, splits, params); } struct llama_model * llama_model_load_from_splits( @@ -1046,11 +1066,11 @@ struct llama_model * llama_model_load_from_splits( for (size_t i = 0; i < n_paths; ++i) { splits.push_back(paths[i]); } - return llama_model_load_from_file_impl(splits.front(), splits, params); + return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, splits.front(), splits, params); } void llama_model_save_to_file(const struct llama_model * model, const char * path_model) { - llama_model_saver ms(*model); + llama_model_saver ms(model); ms.add_kv_from_model(); ms.add_tensors_from_model(); ms.save(path_model); diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 077f66dc651..c6e102abe51 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -5,6 +5,7 @@ #include "ggml-cpu.h" #include "ggml-backend.h" #include "ggml-opt.h" +#include "gguf.h" #include #include @@ -152,6 +153,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors + LLAMA_FTYPE_MOSTLY_NVFP4 = 39, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; @@ -440,19 +442,30 @@ extern "C" { LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); + typedef void (*llama_model_set_tensor_data_t)(struct ggml_tensor * tensor, void * userdata); + + // Create a new model from GGUF metadata as well as a function to set the tensor data + // - tensors are created as GGML_TYPE_F32 by default, + // override by adding a tensor with the same name but a different name to the context + LLAMA_API struct llama_model * llama_model_init_from_user( + struct gguf_context * metadata, + llama_model_set_tensor_data_t set_tensor_data, // function to initialize tensor data with + void * set_tensor_data_ud, // userdata for function + struct llama_model_params params); + DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( const char * path_model, struct llama_model_params params), "use llama_model_load_from_file instead"); - // Load the model from a file + // Load a model from a file // If the file is split into multiple parts, the file name must follow this pattern: -%05d-of-%05d.gguf // If the split file name does not follow this pattern, use llama_model_load_from_splits LLAMA_API struct llama_model * llama_model_load_from_file( const char * path_model, struct llama_model_params params); - // Load the model from multiple splits (support custom naming scheme) + // Load a model from multiple splits (support custom naming scheme) // The paths must be in the correct order LLAMA_API struct llama_model * llama_model_load_from_splits( const char ** paths, @@ -973,7 +986,7 @@ extern "C" { // Logits for the ith token. For positive indices, Equivalent to: // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab - // Negative indicies can be used to access logits in reverse order, -1 is the last logit. + // Negative indices can be used to access logits in reverse order, -1 is the last logit. // returns NULL for invalid ids. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); @@ -988,7 +1001,7 @@ extern "C" { // Get the embeddings for the ith token. For positive indices, Equivalent to: // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd - // Negative indicies can be used to access embeddings in reverse order, -1 is the last embedding. + // Negative indices can be used to access embeddings in reverse order, -1 is the last embedding. // shape: [n_embd] (1-dimensional) // returns NULL for invalid ids. LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); @@ -1008,9 +1021,9 @@ extern "C" { // Returns LLAMA_TOKEN_NULL if no token was sampled. LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i); - // Get the backend sampled probabilites for the ith token + // Get the backend sampled probabilities for the ith token // The index matches llama_get_sampled_token_ith(). - // Returns NULL if no probabilites were generated. + // Returns NULL if no probabilities were generated. LLAMA_API float * llama_get_sampled_probs_ith (struct llama_context * ctx, int32_t i); LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i); @@ -1337,7 +1350,7 @@ extern "C" { float tau, float eta); - /// @details Intializes a GBNF grammar, see grammars/README.md for details. + /// @details Initializes a GBNF grammar, see grammars/README.md for details. /// @param vocab The vocabulary that this grammar will be used with. /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails. /// @param grammar_root The name of the start symbol for the grammar. diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index 6a752a403f6..9aabe25c965 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -1,8 +1,8 @@ #include "models.h" llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -127,7 +127,6 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, // norm_w (route_norm=True) - hparams.expert_weights_scale, // scale_w hparams.expert_weights_scale, // w_scale (route_scale=2.826) (llama_expert_gating_func_type) hparams.expert_gating_func, il); diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index 9af19c1bfe8..4d65614e466 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -3,10 +3,10 @@ llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index aa6167dba1e..20b9ffd49eb 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -2,10 +2,10 @@ llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index e8f028a723e..b712e08cbd3 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -1,11 +1,10 @@ #include "models.h" - llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -104,7 +103,7 @@ llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_pa nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index c04b0c98b0b..abd03cd0b97 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -2,10 +2,10 @@ llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -56,6 +56,7 @@ llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_grap ); break; case LLM_TYPE_13B: + case LLM_TYPE_UNKNOWN: break; default: GGML_ABORT("fatal error"); diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index ed56b9c4713..25e3369c313 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -97,7 +96,7 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_ nullptr, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - false, hparams.expert_weights_scale, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index fbf7b210c42..42098624663 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -1,13 +1,11 @@ #include "models.h" - - llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -90,7 +88,7 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, + hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/bert.cpp b/examples/talk-llama/models/bert.cpp index bca0e254fc5..87331791418 100644 --- a/examples/talk-llama/models/bert.cpp +++ b/examples/talk-llama/models/bert.cpp @@ -1,12 +1,10 @@ #include "models.h" - - llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -129,9 +127,17 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params // feed-forward network if (hparams.moe_every_n_layers > 0 && il % hparams.moe_every_n_layers == 1) { // MoE branch - cur = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, nullptr, - model.layers[il].ffn_down_exps, nullptr, hparams.n_expert, hparams.n_expert_used, - LLM_FFN_GELU, false, false, 0.0f, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + nullptr, + model.layers[il].ffn_down_exps, + nullptr, + hparams.n_expert, hparams.n_expert_used, + LLM_FFN_GELU, false, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); cb(cur, "ffn_moe_out", il); } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) { diff --git a/examples/talk-llama/models/bitnet.cpp b/examples/talk-llama/models/bitnet.cpp index 331a3f11197..ccf5bc8e82b 100644 --- a/examples/talk-llama/models/bitnet.cpp +++ b/examples/talk-llama/models/bitnet.cpp @@ -2,9 +2,9 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -29,10 +29,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].wq_scale) { - Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale); - } + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); @@ -40,10 +37,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa } // B1.K - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].wk_scale) { - Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale); - } + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); @@ -51,10 +45,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa } // B1.V - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].wv_scale) { - Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale); - } + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -90,10 +81,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa LLM_NORM_RMS, il); cb(cur, "attn_sub_norm", il); - cur = build_lora_mm(model.layers[il].wo, cur); - if (model.layers[il].wo_scale) { - cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); - } + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); if (model.layers[il].bo) { cur = ggml_add(ctx0, cur, model.layers[il].bo); } @@ -115,8 +103,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa cb(cur, "ffn_norm", il); cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale, - model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s, NULL, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); @@ -127,10 +115,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa LLM_NORM_RMS, il); cb(cur, "ffn_sub_norm", il); - cur = build_lora_mm(model.layers[il].ffn_down, cur); - if (model.layers[il].ffn_down_scale) { - cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); - } + cur = build_lora_mm(model.layers[il].ffn_down, cur, model.layers[il].ffn_down_s); cb(cur, "ffn_down", il); cur = ggml_add(ctx0, cur, ffn_inp); diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index 2c552d1d15e..b1c19bb58a2 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 184511aed4c..2f24105fa14 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -3,10 +3,10 @@ #include llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index 2685d4fbcbe..5887ed22e7e 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -2,10 +2,10 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index 0b3bdbff529..e8e13e143f2 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -1,11 +1,11 @@ #include "models.h" llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index 0ceae3aaeb5..2ef2b6e389b 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -2,11 +2,11 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * inpL; ggml_tensor * cur; diff --git a/examples/talk-llama/models/cohere2-iswa.cpp b/examples/talk-llama/models/cohere2-iswa.cpp index 9334b5e4263..7c71a59ae7f 100644 --- a/examples/talk-llama/models/cohere2-iswa.cpp +++ b/examples/talk-llama/models/cohere2-iswa.cpp @@ -1,9 +1,9 @@ #include "models.h" llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); const float f_logit_scale = hparams.f_logit_scale; diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index 4d3b643b444..ba1230f0419 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -4,9 +4,9 @@ llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); const float f_logit_scale = hparams.f_logit_scale; diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index 6d2a0ebf1b7..73eb5cd24e7 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -1,12 +1,11 @@ #include "models.h" - llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -89,7 +88,7 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index 7410a3a46d9..ac448bfcaa8 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -3,10 +3,10 @@ llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/deepseek.cpp b/examples/talk-llama/models/deepseek.cpp index 17866c0d88e..3432359e03a 100644 --- a/examples/talk-llama/models/deepseek.cpp +++ b/examples/talk-llama/models/deepseek.cpp @@ -1,13 +1,11 @@ #include "models.h" - - llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -100,7 +98,7 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, - false, hparams.expert_weights_scale, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp index b608396e50e..d437fe29e71 100644 --- a/examples/talk-llama/models/deepseek2.cpp +++ b/examples/talk-llama/models/deepseek2.cpp @@ -8,7 +8,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr const int64_t n_embd_head_k = hparams.n_embd_head_k_mla(); const int64_t n_embd_head_v = hparams.n_embd_head_v_mla(); - const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_rope = hparams.n_rot(); const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope; const uint32_t kv_lora_rank = hparams.n_lora_kv; @@ -146,7 +146,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cb(Qcur, "Qcur_attn_temp_scaled", il); } - // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) + // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group) cur = build_attn(inp_attn_k, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il); @@ -216,7 +216,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - hparams.expert_weights_scale, hparams.expert_weights_scale, + hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il, nullptr, diff --git a/examples/talk-llama/models/delta-net-base.cpp b/examples/talk-llama/models/delta-net-base.cpp index 99f1fdd9538..6bc989c9509 100644 --- a/examples/talk-llama/models/delta-net-base.cpp +++ b/examples/talk-llama/models/delta-net-base.cpp @@ -1,6 +1,6 @@ #include "models.h" -#define CHUNK_SIZE 64 +#include "llama-impl.h" // utility to get one slice from the third dimension // input dim: [x, y, c, b] @@ -57,7 +57,7 @@ std::pair llm_build_delta_net_base::build_delta_ne g = ggml_permute(ctx0, g, 0, 2, 1, 3); // [g_0, n_tokens, H_v, n_seqs] b = ggml_permute(ctx0, b, 0, 2, 1, 3); // [ 1, n_tokens, H_v, n_seqs] - const int CS = CHUNK_SIZE; + const int CS = kda ? 16 : 64; // chunk size const int pad = (CS - n_tokens % CS) % CS; const int n_chunks = (n_tokens + pad) / CS; @@ -225,9 +225,8 @@ std::pair llm_build_delta_net_base::build_delta_ne ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg)); cb(kg_t, "key_gdiff_t", il); - ggml_tensor * s_t = ggml_transpose(ctx0, s); - s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs); - cb(s_t, "dnet_add_ch_state", il); + s = ggml_reshape_4d(ctx0, s, S_v, S_v, 1, H_v * n_seqs); + cb(s, "dnet_add_ch_state", il); // [CS, S_v, n_chunks, H_v * n_seqs] ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v)); @@ -240,7 +239,7 @@ std::pair llm_build_delta_net_base::build_delta_ne ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs] // [CS, S_v, 1, H_v * n_seqs] - ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t); + ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s); cb(v_t_p, "v_prime", il); // [CS, S_v, 1, H_v * n_seqs] @@ -252,7 +251,7 @@ std::pair llm_build_delta_net_base::build_delta_ne cb(v_attn, "v_attn", il); // [S_v, CS, 1, H_v * n_seqs] - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp); + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s, ch_q_g_exp); cb(attn_inter, "attn_inter", il); // [S_v, CS, 1, H_v * n_seqs] @@ -268,13 +267,11 @@ std::pair llm_build_delta_net_base::build_delta_ne // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk); - s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t); - s_t = ggml_add(ctx0, s_t, kgv); - cb(s_t, "dnet_add_ch_state", il); + s = ggml_mul(ctx0, s, ch_g_last_exp_t); + s = ggml_add(ctx0, s, kgv); + cb(s, "dnet_add_ch_state", il); } - s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs); - // truncate padded tokens ggml_tensor * o = ggml_view_4d(ctx0, v, S_v, n_tokens, H_v, n_seqs, @@ -282,7 +279,7 @@ std::pair llm_build_delta_net_base::build_delta_ne ggml_row_size(v->type, S_v * CS * n_chunks), ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0); o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs] - s = ggml_transpose(ctx0, s_t); + s = ggml_reshape_4d(ctx0, s, S_v, S_v, H_v, n_seqs); cb(s, "output_state", il); return {o, s}; @@ -341,11 +338,9 @@ std::pair llm_build_delta_net_base::build_delta_ne g = ggml_exp(ctx0, g); s = ggml_mul(ctx0, s, g); - ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s)); - // [1, S_v, H_v, n_seqs] ggml_tensor * sk; - sk = ggml_mul (ctx0, s_t, k); + sk = ggml_mul (ctx0, s, k); sk = ggml_sum_rows(ctx0, sk); // [S_v, 1, H_v, n_seqs] @@ -362,15 +357,89 @@ std::pair llm_build_delta_net_base::build_delta_ne k = ggml_repeat(ctx0, k, s); kd = ggml_mul (ctx0, k, d_t); - s_t = ggml_add(ctx0, s_t, kd); + s = ggml_add(ctx0, s, kd); - cb(s_t, "dnet_add_ar_state", il); + cb(s, "dnet_add_ar_state", il); - ggml_tensor * s_q = ggml_mul (ctx0, s_t, q); + ggml_tensor * s_q = ggml_mul (ctx0, s, q); ggml_tensor * o = ggml_sum_rows(ctx0, s_q); o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs] - s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] return {o, s}; } + +std::pair llm_build_delta_net_base::build_delta_net_fused( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); + GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); + GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + if (n_tokens == 1) { + cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); + } else { + cb(result, LLAMA_TENSOR_NAME_FGDN_CH, il); + } + + ggml_tensor * output = ggml_view_4d(ctx0, result, + S_v, H_v, n_tokens, n_seqs, + ggml_row_size(result->type, S_v), + ggml_row_size(result->type, S_v * H_v), + ggml_row_size(result->type, S_v * H_v * n_tokens), 0); + + ggml_tensor * new_state = ggml_view_4d(ctx0, result, + S_v, S_v, H_v, n_seqs, + ggml_row_size(result->type, S_v), + ggml_row_size(result->type, S_v * S_v), + ggml_row_size(result->type, S_v * S_v * H_v), + ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs)); + + return {output, new_state}; +} + +std::pair llm_build_delta_net_base::build_delta_net( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const int64_t n_seq_tokens = q->ne[2]; + + if (n_seq_tokens == 1) { + if (cparams.fused_gdn_ar) { + return build_delta_net_fused(q, k, v, g, b, s, il); + } + return build_delta_net_autoregressive(q, k, v, g, b, s, il); + } + + if (cparams.fused_gdn_ch) { + return build_delta_net_fused(q, k, v, g, b, s, il); + } + + return build_delta_net_chunking(q, k, v, g, b, s, il); +} diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 09c36f82fe2..07236dd27c9 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -1,13 +1,11 @@ #include "models.h" - - llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -91,7 +89,7 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, + hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 2aafbae1397..4edc8530cb3 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -5,10 +5,10 @@ llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { //copied from qwen2 - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/ernie4-5-moe.cpp b/examples/talk-llama/models/ernie4-5-moe.cpp index 0d96d14e6fd..63baf152c40 100644 --- a/examples/talk-llama/models/ernie4-5-moe.cpp +++ b/examples/talk-llama/models/ernie4-5-moe.cpp @@ -1,13 +1,11 @@ #include "models.h" - - llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -103,7 +101,7 @@ llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index 99aead53283..d548de0547b 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -2,10 +2,10 @@ llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/eurobert.cpp b/examples/talk-llama/models/eurobert.cpp index 86e3176edc0..e8628d165d0 100644 --- a/examples/talk-llama/models/eurobert.cpp +++ b/examples/talk-llama/models/eurobert.cpp @@ -1,9 +1,9 @@ #include "models.h" llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp index bef5b2ad351..ea75701c528 100644 --- a/examples/talk-llama/models/exaone-moe.cpp +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -1,12 +1,11 @@ #include "models.h" - llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_k; + const int64_t n_embd_head = hparams.n_embd_head_k(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_v); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -100,7 +99,7 @@ llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_ model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, + hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index 62602b284de..d4eea58e2f1 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -4,10 +4,10 @@ llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 8b7e3dc06e5..755af3b747b 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -4,10 +4,10 @@ template llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_k; + const int64_t n_embd_head = hparams.n_embd_head_k(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_v); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index 785a7e5e662..ff842d93a41 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -2,7 +2,7 @@ llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index db1ccdb5008..9fcba508878 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -2,11 +2,11 @@ llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/gemma-embedding.cpp b/examples/talk-llama/models/gemma-embedding.cpp index 944c198bf95..98110d45e3b 100644 --- a/examples/talk-llama/models/gemma-embedding.cpp +++ b/examples/talk-llama/models/gemma-embedding.cpp @@ -2,7 +2,7 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_k; + const int64_t n_embd_head = hparams.n_embd_head_k(); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 4893d9af4b8..1869efd389a 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -2,7 +2,7 @@ llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/gemma2-iswa.cpp b/examples/talk-llama/models/gemma2-iswa.cpp index 7a9198193ac..3927ddd297b 100644 --- a/examples/talk-llama/models/gemma2-iswa.cpp +++ b/examples/talk-llama/models/gemma2-iswa.cpp @@ -1,7 +1,7 @@ #include "models.h" llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_k; + const int64_t n_embd_head = hparams.n_embd_head_k(); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index dec3fc4b8bc..bbb4d9a81e8 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -2,7 +2,7 @@ template llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_k; + const int64_t n_embd_head = hparams.n_embd_head_k(); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/gemma3n-iswa.cpp b/examples/talk-llama/models/gemma3n-iswa.cpp index 7db6d3bf4ec..8ce2ae39c2f 100644 --- a/examples/talk-llama/models/gemma3n-iswa.cpp +++ b/examples/talk-llama/models/gemma3n-iswa.cpp @@ -3,7 +3,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), - n_embd_head(model.hparams.n_embd_head_k), + n_embd_head(model.hparams.n_embd_head_k()), n_embd_altup(model.hparams.n_embd_altup), n_altup(model.hparams.n_altup), i_altup_act(model.hparams.i_altup_act) { diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 003f70f7396..7938545ed8a 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -1,9 +1,9 @@ #include "models.h" llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); @@ -128,7 +128,7 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, + hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(routed_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index bcd837b30d6..b6ad8febed3 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -3,10 +3,10 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index 60761c8e765..cb1238f2d34 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * pos; diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 2151b14e939..1c8fe6c836d 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -2,10 +2,10 @@ llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index 726ecdcca77..9b54a38c386 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -1,10 +1,9 @@ #include "models.h" - llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -160,7 +159,7 @@ ggml_tensor * llm_build_granite_hybrid::build_layer_ffn(ggml_tensor * cur, nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index 18748e9c26c..7a7e1664c29 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -1,15 +1,14 @@ #include "models.h" - llm_build_granite::llm_build_granite( const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -175,7 +174,7 @@ ggml_tensor * llm_build_granite::build_layer_ffn( nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index 3c54dfee636..580d63e36ae 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -99,7 +99,7 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params nullptr, n_expert, n_expert_used, LLM_FFN_GELU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index 56b6db9a3d0..aa60d3e9388 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -1,14 +1,12 @@ #include "models.h" - - llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_chunk_expert = n_expert / hparams.n_group_experts; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -90,7 +88,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, probs); @@ -106,7 +104,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap nullptr, n_chunk_expert, n_expert_used > n_chunk_expert ? n_chunk_expert : n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, probs); diff --git a/examples/talk-llama/models/hunyuan-dense.cpp b/examples/talk-llama/models/hunyuan-dense.cpp index 7d5dcc7828b..6a51707c85b 100644 --- a/examples/talk-llama/models/hunyuan-dense.cpp +++ b/examples/talk-llama/models/hunyuan-dense.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index 77e39de5b8b..806c30b3667 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -119,8 +119,7 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll n_expert, n_expert_used, LLM_FFN_SILU, true, // norm_topk_prob - false, - 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur_moe, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index 387e8211270..441d250268e 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index 3e3376e6a62..135bf288ba1 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp index a69fcaa3bb3..2cfe484eb52 100644 --- a/examples/talk-llama/models/jais2.cpp +++ b/examples/talk-llama/models/jais2.cpp @@ -3,10 +3,10 @@ // JAIS-2 model graph builder // Uses: LayerNorm (not RMSNorm), relu2 activation, separate Q/K/V, RoPE embeddings llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index ceab5817407..c0c89de187a 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -1,7 +1,7 @@ #include "models.h" llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); ggml_tensor * cur; ggml_tensor * inpL; @@ -76,7 +76,7 @@ llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_para nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp index 83d11241f8d..4d62f4e7159 100644 --- a/examples/talk-llama/models/kimi-linear.cpp +++ b/examples/talk-llama/models/kimi-linear.cpp @@ -1,5 +1,4 @@ #include "models.h" -#include "ggml.h" #include "llama-memory-recurrent.h" @@ -103,7 +102,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll const int64_t kv_lora_rank = hparams.n_lora_kv; // qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot // Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim] - const int64_t n_embd_head_qk_rope = hparams.n_rot; // config.qk_rope_head_dim + const int64_t n_embd_head_qk_rope = hparams.n_rot(); // config.qk_rope_head_dim const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; // 192 - 64 = 128 // Attention scale for MLA const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla); @@ -118,12 +117,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_build_forward_expand(gf, cur); - // Check layer type by checking which tensors exist - // KDA layers have ssm_a_log tensor, MLA layers have wkv_a_mqa tensor - bool is_kda = (layer.ssm_a != nullptr); - bool is_mla = (layer.wkv_a_mqa != nullptr); - - if (is_kda) { + if (hparams.is_recurrent(il)) { // === KDA Layer (Kimi Delta Attention) with Recurrent State === // Reference: vLLM kda.py const auto * mctx_cur = inp_rs->mctx; @@ -175,9 +169,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm); // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens - std::pair attn_out = n_seq_tokens == 1 ? - build_delta_net_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) : - build_delta_net_chunking(Qcur, Kcur, Vcur, g1, beta, state, il); + auto attn_out = build_delta_net(Qcur, Kcur, Vcur, g1, beta, state, il); ggml_tensor * output = ggml_cont(ctx0, attn_out.first); ggml_tensor * new_state = attn_out.second; @@ -211,7 +203,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll cur = ggml_mul_mat(ctx0, layer.wo, gated); cb(cur, "kda_out", il); - } else if (is_mla) { + } else { // === MLA Layer (Multi-head Latent Attention) without KV Cache === // Reference: vLLM mla.py // Step 1: Q projection and reshape @@ -310,9 +302,6 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll cur = build_attn(inp_attn_kv, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il); cb(cur, "mla_out", il); } - } else { - // Unknown layer type - this should not happen - GGML_ABORT("Kimi layer is neither KDA nor MLA - missing required tensors"); } // On last layer, select only the output tokens @@ -349,7 +338,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll hparams.n_expert, hparams.n_expert_used, LLM_FFN_SILU, true, - true, hparams.expert_weights_scale, + hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index cf01ad62557..dfa322166b1 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -23,17 +23,23 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_ }; auto build_moe_feed_forward = [&model, this](ggml_tensor * cur, int il) -> ggml_tensor * { return build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, - static_cast(hparams.expert_gating_func), il); + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + static_cast(hparams.expert_gating_func), + il); }; auto build_attn_block = [&model, this](ggml_tensor * cur, ggml_tensor * inp_pos, inp_attn_type * inp_attn, int il) -> ggml_tensor * { GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il)); - const auto n_embd_head = hparams.n_embd_head_v; + const auto n_embd_head = hparams.n_embd_head_v(); const auto n_head_kv = hparams.n_head_kv(il); auto * q = build_lora_mm(model.layers[il].wq, cur); diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index 5f64686f5fb..18de88fde1f 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -90,7 +90,7 @@ llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_gr nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index 857033660a0..0dac9d616ae 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -2,10 +2,10 @@ llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { // LLaDA is similar to LLaMA but uses non-causal attention for diffusion - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/llama-iswa.cpp b/examples/talk-llama/models/llama-iswa.cpp index 61dd2c179f1..67cb9a10ec5 100644 --- a/examples/talk-llama/models/llama-iswa.cpp +++ b/examples/talk-llama/models/llama-iswa.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -134,7 +134,7 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, il); diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index 42b5fcdf42e..e08ae0c0b0e 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -2,10 +2,10 @@ template llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -43,19 +43,19 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); cb(Qcur, "Qcur", il); } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); cb(Kcur, "Kcur", il); } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -91,6 +91,9 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + if (model.layers[il].wo_s) { + cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); + } cb(cur, "attn_out", il); } if (il == n_layer - 1 && inp_out_ids) { @@ -109,9 +112,9 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra cb(cur, "ffn_norm", il); cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); @@ -130,9 +133,13 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); + il, + nullptr, nullptr, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); cb(cur, "ffn_moe_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index da57308167e..a72b7790a1f 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/mamba-base.cpp b/examples/talk-llama/models/mamba-base.cpp index aaac9487dfa..9de587db55f 100644 --- a/examples/talk-llama/models/mamba-base.cpp +++ b/examples/talk-llama/models/mamba-base.cpp @@ -30,6 +30,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp, GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + GGML_ASSERT(d_inner % n_head == 0); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); @@ -167,6 +168,9 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp, GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + GGML_ASSERT(d_inner % n_head == 0); + GGML_ASSERT(d_inner % d_state == 0); + GGML_ASSERT(d_inner % n_group == 0); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); diff --git a/examples/talk-llama/models/mimo2-iswa.cpp b/examples/talk-llama/models/mimo2-iswa.cpp index edc87cc9f0d..06956915ea0 100644 --- a/examples/talk-llama/models/mimo2-iswa.cpp +++ b/examples/talk-llama/models/mimo2-iswa.cpp @@ -1,4 +1,3 @@ - #include "models.h" llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { @@ -88,10 +87,17 @@ llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_ cb(cur, "ffn_out", il); } else { // MoE branch - cur = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, - 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, il); + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + il); cb(cur, "ffn_moe_out", il); } diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index 297cc34ba58..89dd7105157 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -5,10 +5,10 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap const int64_t n_embd_base = 256; const float scale_embd = 12.0f; const float scale_depth = 1.4f; - const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k)); + const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k())); - const uint32_t n_embd_head_qk_rope = hparams.n_rot; - const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t n_embd_head_qk_rope = hparams.n_rot(); + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); const uint32_t kv_lora_rank = hparams.n_lora_kv; @@ -51,21 +51,21 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap LLM_NORM_RMS, il); cb(q, "q", il); - // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} + // {q_lora_rank, n_head * hparams.n_embd_head_k()} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k(), n_tokens} q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); cb(q, "q", il); // split into {n_head * n_embd_head_qk_nope, n_tokens} ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, hparams.n_embd_head_k()), + ggml_row_size(q->type, hparams.n_embd_head_k() * n_head), 0); cb(q_nope, "q_nope", il); // and {n_head * n_embd_head_qk_rope, n_tokens} ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, hparams.n_embd_head_k()), + ggml_row_size(q->type, hparams.n_embd_head_k() * n_head), ggml_row_size(q->type, n_embd_head_qk_nope)); cb(q_pe, "q_pe", il); @@ -97,15 +97,15 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap // split into {n_head * n_embd_head_qk_nope, n_tokens} ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v()), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v())), 0); cb(k_nope, "k_nope", il); // and {n_head * n_embd_head_v, n_tokens} - ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v(), n_head, n_tokens, + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v())), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v())*n_head), ggml_row_size(kv->type, (n_embd_head_qk_nope))); cb(v_states, "v_states", il); diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index f7001badf75..83d0916c08c 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -1,11 +1,10 @@ - #include "models.h" llm_build_minimax_m2::llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - // GGML_ASSERT(n_embd_head == hparams.n_rot); this is wrong in case of minimax, head_dim = 128, n_rot = 64 + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + // GGML_ASSERT(n_embd_head == n_rot); this is wrong in case of minimax, head_dim = 128, n_rot = 64 ggml_tensor * cur; ggml_tensor * inpL; @@ -91,7 +90,7 @@ llm_build_minimax_m2::llm_build_minimax_m2(const llama_model & model, const llm_ model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(cur, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 0b672235911..42a5117ff02 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -127,7 +127,7 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index 0712d03d8d9..a86b2b1ebd7 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -3,7 +3,7 @@ #include "llama-model.h" #include "llama-graph.h" -// note: almost all graphs require atleast sqrtf, so include cmath globally +// note: almost all graphs require at least sqrtf, so include cmath globally #include // @@ -44,6 +44,26 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * b, ggml_tensor * s, int il); + + // use the ggml_gated_delta_net fused operator + std::pair build_delta_net_fused( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); + + // choose one of two implementations above based on the number of tokens + std::pair build_delta_net( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); }; struct llm_build_rwkv6_base : public llm_graph_context { diff --git a/examples/talk-llama/models/modern-bert.cpp b/examples/talk-llama/models/modern-bert.cpp index 32066c712b4..26020584c6d 100644 --- a/examples/talk-llama/models/modern-bert.cpp +++ b/examples/talk-llama/models/modern-bert.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index 2328e027a74..ce44a805f5c 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -3,10 +3,10 @@ llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * pos; diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index d61d62a8c96..7af99174d16 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -2,8 +2,8 @@ llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -114,9 +114,18 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); } else { - ggml_tensor * ffn_inp = cur; + ggml_tensor * inp_emb = cur; + ggml_tensor * inp_latent = cur; + + if (model.layers[il].ffn_latent_down) { + inp_latent = ggml_mul_mat(ctx0, model.layers[il].ffn_latent_down, cur); + } + + ggml_tensor * router_logits = build_lora_mm(model.layers[il].ffn_gate_inp, cur); + cb(router_logits, "ffn_moe_logits", il); + ggml_tensor * moe_out = - build_moe_ffn(ffn_inp, + build_moe_ffn(inp_latent, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, nullptr, // no gate @@ -124,12 +133,17 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_RELU_SQR, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, - il); + il, + router_logits); cb(moe_out, "ffn_moe_out", il); - ggml_tensor * ffn_shexp = build_ffn(ffn_inp, + if (model.layers[il].ffn_latent_up) { + moe_out = ggml_mul_mat(ctx0, model.layers[il].ffn_latent_up, moe_out); + } + + ggml_tensor * ffn_shexp = build_ffn(inp_emb, model.layers[il].ffn_up_shexp, NULL, NULL, NULL /* no gate */ , NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index fcead041f0a..34aa6fa5ec4 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - //GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + //GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/neo-bert.cpp b/examples/talk-llama/models/neo-bert.cpp index 7c32bfca5f5..2fdf4a3692f 100644 --- a/examples/talk-llama/models/neo-bert.cpp +++ b/examples/talk-llama/models/neo-bert.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index bbd623f1112..26f4b6ee628 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_olmo::llm_build_olmo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 713552dab89..5076359e3f9 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -2,10 +2,10 @@ template llm_build_olmo2::llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index b8b6988f897..83a56a0b3b6 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_olmoe::llm_build_olmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -92,7 +92,7 @@ llm_build_olmoe::llm_build_olmoe(const llama_model & model, const llm_graph_para nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/openai-moe-iswa.cpp b/examples/talk-llama/models/openai-moe-iswa.cpp index dbe3ca1851f..403f130bc41 100644 --- a/examples/talk-llama/models/openai-moe-iswa.cpp +++ b/examples/talk-llama/models/openai-moe-iswa.cpp @@ -95,7 +95,7 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, nullptr, n_expert, n_expert_used, LLM_FFN_SWIGLU_OAI_MOE, false, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT, il); cb(cur, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index fbf682ec835..5df6fe3e3ce 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -1,9 +1,9 @@ #include "models.h" llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index bb02273bfe7..48c01efe368 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_orion::llm_build_orion(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/paddleocr.cpp b/examples/talk-llama/models/paddleocr.cpp index 39a368df53b..340455c2d5f 100644 --- a/examples/talk-llama/models/paddleocr.cpp +++ b/examples/talk-llama/models/paddleocr.cpp @@ -5,10 +5,10 @@ llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_gr // NOTE: same with qwen2vl.cpp, but bias tensors are optional - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/pangu-embedded.cpp b/examples/talk-llama/models/pangu-embedded.cpp index 664572a5001..1cf0938e68f 100644 --- a/examples/talk-llama/models/pangu-embedded.cpp +++ b/examples/talk-llama/models/pangu-embedded.cpp @@ -2,10 +2,10 @@ llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index 22dbf610767..32d40d71fb7 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -2,10 +2,10 @@ llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * attn_norm_output; diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index c8e5da33db7..3d11a9459c4 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -2,10 +2,10 @@ template llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -114,7 +114,7 @@ llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_ nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index 04ff709f9c6..b7a71211042 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_plamo::llm_build_plamo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index 3af236843bb..f02acbc1869 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -27,7 +27,7 @@ llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_pa cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); // check if this layer is Mamba or Attention - bool is_mamba_layer = hparams.is_recurrent(il); + const bool is_mamba_layer = hparams.is_recurrent(il); if (is_mamba_layer) { // PLaMo-2 Mamba layer @@ -106,9 +106,9 @@ ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv cb(qkv, "wqkv", il); // split QKV tensor into Q, K, V - const int64_t n_embd_head_q = hparams.n_embd_head_k; - const int64_t n_embd_head_k = hparams.n_embd_head_k; - const int64_t n_embd_head_v = hparams.n_embd_head_v; + const int64_t n_embd_head_q = hparams.n_embd_head_k(); + const int64_t n_embd_head_k = hparams.n_embd_head_k(); + const int64_t n_embd_head_v = hparams.n_embd_head_v(); int32_t n_head = hparams.n_head(il); int32_t n_head_kv = hparams.n_head_kv(il); @@ -171,6 +171,8 @@ ggml_tensor * llm_build_plamo2::build_plamo2_mamba_layer(llm_graph_input_rs * in GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + GGML_ASSERT(d_inner % n_head == 0); + GGML_ASSERT(n_group == 0); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 55c8064679e..32af6e04663 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -3,8 +3,8 @@ template llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t head_dim_q = hparams.n_embd_head_k; - const int64_t head_dim_v = hparams.n_embd_head_v; + const int64_t head_dim_q = hparams.n_embd_head_k(); + const int64_t head_dim_v = hparams.n_embd_head_v(); ggml_tensor * cur; ggml_tensor * inpL = build_inp_embd(model.tok_embd); diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index 612a487c564..bcb651ce543 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k)); + const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k())); - const uint32_t n_embd_head_qk_rope = hparams.n_rot; - const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t n_embd_head_qk_rope = hparams.n_rot(); + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); const uint32_t kv_lora_rank = hparams.n_lora_kv; @@ -38,15 +38,15 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & // split into {n_head * n_embd_head_qk_nope, n_tokens} ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, hparams.n_embd_head_k()), + ggml_row_size(q->type, hparams.n_embd_head_k() * n_head), 0); cb(q_nope, "q_nope", il); // and {n_head * n_embd_head_qk_rope, n_tokens} ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, hparams.n_embd_head_k()), + ggml_row_size(q->type, hparams.n_embd_head_k() * n_head), ggml_row_size(q->type, n_embd_head_qk_nope)); cb(q_pe, "q_pe", il); @@ -78,23 +78,23 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & // split into {n_head * n_embd_head_qk_nope, n_tokens} ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v()), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v())), 0); cb(k_nope, "k_nope", il); // and {n_head * n_embd_head_v, n_tokens} - ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v(), n_head, n_tokens, + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v())), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v())*n_head), ggml_row_size(kv->type, (n_embd_head_qk_nope))); cb(v_states, "v_states", il); v_states = ggml_cont(ctx0, v_states); cb(v_states, "v_states", il); - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v() * n_head, n_tokens, + ggml_row_size(kv->type, hparams.n_embd_head_v() * n_head), 0); cb(v_states, "v_states", il); diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index 31fd9b73763..7390f1320bf 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -2,9 +2,9 @@ llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index 3da4dea3c16..58c10622508 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 49142b71236..60761789dc9 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -94,7 +94,7 @@ llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_grap nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/qwen2vl.cpp b/examples/talk-llama/models/qwen2vl.cpp index 9be38675cf7..9004bab9db1 100644 --- a/examples/talk-llama/models/qwen2vl.cpp +++ b/examples/talk-llama/models/qwen2vl.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index a5cfffa5314..52081668477 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,13 +30,13 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); cb(Qcur, "Qcur", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); cb(Kcur, "Kcur", il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); cb(Vcur, "Vcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); @@ -68,6 +68,9 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + if (model.layers[il].wo_s) { + cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); + } } if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); @@ -83,9 +86,9 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para cb(cur, "ffn_norm", il); cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index bacf7a4c2ee..3108bf331ac 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -4,9 +4,9 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) : llm_build_delta_net_base(params), model(model) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); @@ -90,11 +90,11 @@ std::pair llm_build_qwen35::build_qkvz( const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input); + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input, model.layers[il].wqkv_s); qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); cb(qkv_mixed, "linear_attn_qkv_mixed", il); - ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input); + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input, model.layers[il].wqkv_gate_s); cb(z, "z", il); return { qkv_mixed, z }; @@ -117,13 +117,13 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( ggml_tensor * inp_pos, int * sections, int il) { - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention // Qwen3Next uses a single Q projection that outputs query + gate - ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ] + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] cb(Qcur_full, "Qcur_full", il); ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, @@ -135,10 +135,10 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); cb(Kcur, "Kcur", il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); cb(Vcur, "Vcur", il); // Apply K normalization @@ -186,7 +186,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( cur = ggml_mul(ctx0, cur, gate_sigmoid); cb(cur, "attn_gated", il); - cur = build_lora_mm(model.layers[il].wo, cur); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); cb(cur, "attn_output", il); return cur; @@ -217,13 +217,13 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( ggml_tensor * qkv_mixed = qkvz.first; ggml_tensor * z = qkvz.second; - ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); + ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur, model.layers[il].ssm_beta_s); beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); cb(beta, "beta", il); beta = ggml_sigmoid(ctx0, beta); - ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s); alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); cb(alpha, "alpha", il); @@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - if (num_k_heads != num_v_heads) { + // note: need explicit repeat only if we are not using the fused GDN + if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); - // TODO: try to avoid these explicit repeats by utilizing op broadcast q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); } @@ -332,13 +332,8 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens - std::pair attn_out; // pair of (output, new_state) - if (n_seq_tokens == 1) { - attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); - } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); - } + auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); + ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; cb(output, "attn_output", il); @@ -361,7 +356,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(final_output, "final_output", il); // Output projection - cur = build_lora_mm(model.layers[il].ssm_out, final_output); + cur = build_lora_mm(model.layers[il].ssm_out, final_output, model.layers[il].ssm_out_s); cb(cur, "linear_attn_out", il); // Reshape back to original dimensions @@ -375,9 +370,9 @@ ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il) GGML_ASSERT(model.layers[il].ffn_gate_inp == nullptr); cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index 22d708f2062..165e2412e56 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -4,9 +4,9 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) : llm_build_delta_net_base(params), model(model) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); @@ -90,11 +90,11 @@ std::pair llm_build_qwen35moe::build_qkvz( const int64_t n_seqs = ubatch.n_seqs; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input); + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input, model.layers[il].wqkv_s); qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); cb(qkv_mixed, "linear_attn_qkv_mixed", il); - ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input); + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input, model.layers[il].wqkv_gate_s); cb(z, "z", il); return { qkv_mixed, z }; @@ -117,13 +117,13 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( ggml_tensor * inp_pos, int * sections, int il) { - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention // Qwen3Next uses a single Q projection that outputs query + gate - ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ] + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] cb(Qcur_full, "Qcur_full", il); ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, @@ -135,10 +135,10 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); cb(Kcur, "Kcur", il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); cb(Vcur, "Vcur", il); // Apply K normalization @@ -186,7 +186,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( cur = ggml_mul(ctx0, cur, gate_sigmoid); cb(cur, "attn_gated", il); - cur = build_lora_mm(model.layers[il].wo, cur); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); cb(cur, "attn_output", il); return cur; @@ -217,13 +217,13 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( ggml_tensor * qkv_mixed = qkvz.first; ggml_tensor * z = qkvz.second; - ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); + ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur, model.layers[il].ssm_beta_s); beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); cb(beta, "beta", il); beta = ggml_sigmoid(ctx0, beta); - ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s); alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); cb(alpha, "alpha", il); @@ -321,9 +321,9 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - if (num_k_heads != num_v_heads) { + // note: need explicit repeat only if we are not using the fused GDN + if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); - // TODO: try to avoid these explicit repeats by utilizing op broadcast q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); } @@ -332,13 +332,8 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens - std::pair attn_out; // pair of (output, new_state) - if (n_seq_tokens == 1) { - attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); - } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); - } + auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); + ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; cb(output, "attn_output", il); @@ -361,7 +356,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(final_output, "final_output", il); // Output projection - cur = build_lora_mm(model.layers[il].ssm_out, final_output); + cur = build_lora_mm(model.layers[il].ssm_out, final_output, model.layers[il].ssm_out_s); cb(cur, "linear_attn_out", il); // Reshape back to original dimensions @@ -376,21 +371,28 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int ggml_tensor * moe_out = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, nullptr, - n_expert, n_expert_used, LLM_FFN_SILU, - true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, - nullptr, model.layers[il].ffn_gate_up_exps); + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, model.layers[il].ffn_gate_up_exps, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); cb(moe_out, "ffn_moe_out", il); // Add shared experts if present - following Qwen3Next reference implementation if (model.layers[il].ffn_up_shexp != nullptr) { ggml_tensor * ffn_shexp = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, + model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s, + model.layers[il].ffn_gate_shexp, NULL, model.layers[il].ffn_gate_shexp_s, + model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(ffn_shexp, "ffn_shexp", il); diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index 888534fb347..dba46618ff2 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,13 +30,13 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); cb(Qcur, "Qcur", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); cb(Kcur, "Kcur", il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); cb(Vcur, "Vcur", il); Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); @@ -68,6 +68,9 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + if (model.layers[il].wo_s) { + cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); + } } if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); @@ -91,9 +94,13 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); + il, + nullptr, nullptr, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); cb(moe_out, "ffn_moe_out", il); cur = moe_out; diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index f2621200f23..cc479dd075c 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -100,8 +100,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( ggml_tensor * cur, ggml_tensor * inp_pos, int il) { - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention @@ -406,6 +406,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes + // TODO: avoid repeats for fused GDN, needs broadcast configuration for GDN op [TAG_GGML_GDN_BCAST] if (num_k_heads != num_v_heads) { GGML_ASSERT(num_v_heads % num_k_heads == 0); int64_t repeat_factor = num_v_heads / num_k_heads; @@ -431,13 +432,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens - std::pair attn_out; // pair of (output, new_state) - if (n_seq_tokens == 1) { - attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); - } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); - } + auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); + ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; cb(output, "attn_output", il); @@ -475,11 +471,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int // MoE branch ggml_tensor * moe_out = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, nullptr, - n_expert, n_expert_used, LLM_FFN_SILU, - true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, nullptr, model.layers[il].ffn_gate_up_exps); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/qwen3vl-moe.cpp b/examples/talk-llama/models/qwen3vl-moe.cpp index e5e1a2150c8..195daea66c9 100644 --- a/examples/talk-llama/models/qwen3vl-moe.cpp +++ b/examples/talk-llama/models/qwen3vl-moe.cpp @@ -4,10 +4,10 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ const size_t n_deepstack_layers = hparams.n_deepstack_layers; const int64_t n_embd = hparams.n_embd; - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -99,7 +99,7 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index 0f8315b3240..bbd5f42ba5b 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -4,10 +4,10 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ const size_t n_deepstack_layers = hparams.n_deepstack_layers; const int64_t n_embd = hparams.n_embd; - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index ff5eb2841db..140700d9e2d 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -1,9 +1,9 @@ #include "models.h" llm_build_refact::llm_build_refact(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index 46b3dc3efca..c8e1f43400f 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -2,10 +2,10 @@ // RND1 is a Qwen3Moe AR model converted to diffusion model. llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -93,7 +93,7 @@ llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index 0dc33c50ba3..a4d0b75d846 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index 4c497ca76f4..e2155aacef4 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -2,10 +2,10 @@ template llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){ - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -93,7 +93,7 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, nullptr, n_expert, n_expert_used, LLM_FFN_RELU, true, - false, 0.0, + hparams.expert_weights_scale, static_cast(hparams.expert_gating_func), il, probs); diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index 97c30deed54..e267fd8f32f 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index bed1915c006..ff5aced93b3 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -1,9 +1,9 @@ #include "models.h" llm_build_stablelm::llm_build_stablelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index e197af4a8c6..941cee98219 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index e40ef2cb749..a5965aceb3b 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/step35-iswa.cpp b/examples/talk-llama/models/step35-iswa.cpp index f8737815a67..176209cd93e 100644 --- a/examples/talk-llama/models/step35-iswa.cpp +++ b/examples/talk-llama/models/step35-iswa.cpp @@ -52,7 +52,7 @@ llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const ll // RoPE (partial rotary factors per layer) const bool is_swa = hparams.is_swa(il); ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il); - const int64_t n_rot_l = is_swa ? hparams.n_rot : (hparams.n_rot / 2); + const int64_t n_rot_l = hparams.n_rot(il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, @@ -119,9 +119,6 @@ llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const ll cb(cur, "ffn_out", il); } else { // MoE routed experts - const bool norm_w = hparams.expert_weights_norm; - const float w_scale = hparams.expert_weights_scale; - const bool scale_w = w_scale != 0.0f; ggml_tensor * moe_out = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, @@ -129,8 +126,8 @@ llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const ll model.layers[il].ffn_down_exps, model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, - LLM_FFN_SILU, - norm_w, scale_w, w_scale, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/t5-dec.cpp b/examples/talk-llama/models/t5-dec.cpp index 297e450de76..8ca8372bd4c 100644 --- a/examples/talk-llama/models/t5-dec.cpp +++ b/examples/talk-llama/models/t5-dec.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); //const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/t5-enc.cpp b/examples/talk-llama/models/t5-enc.cpp index 70e1d80dcdd..395dfb51042 100644 --- a/examples/talk-llama/models/t5-enc.cpp +++ b/examples/talk-llama/models/t5-enc.cpp @@ -1,9 +1,9 @@ #include "models.h" llm_build_t5_enc::llm_build_t5_enc(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index 364797dd31b..3a8dfafcceb 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -1,10 +1,10 @@ #include "models.h" llm_build_xverse::llm_build_xverse(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index 1475b53b659..122c8ca04a5 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -773,7 +773,7 @@ static std::vector unicode_regex_split_custom(const std::string & text, // tiny_aya digit grouping pattern from tokenizer.json: // {"type": "Split", "pattern": {"Regex": "\\d{1,3}(?=(?:\\d{3})*\\b)"}, "behavior": "Isolated"} // Splits digits into groups of 3 from the right (e.g., 1234567 -> 1, 234, 567) - // TODO: Revisit this regex, incase there are any subtle tokenization differences with the original regex. + // TODO: Revisit this regex, in case there are any subtle tokenization differences with the original regex. bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets); } From 27fa20774a34b36bfe18edd7572ed42ef601bc3b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 16 Mar 2026 09:11:13 +0200 Subject: [PATCH 288/831] ggml : try fix arm build (#0) --- ggml/src/ggml-cpu/arch/arm/quants.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index c1856201b31..82b048bb3ae 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -666,7 +666,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo float sumf = 0; -#if defined __ARM_NEON +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) const int8x16_t values = vld1q_s8(kvalues_mxfp4); const uint8x16_t m4b = vdupq_n_u8(0x0f); float32x4_t acc = vdupq_n_f32(0.0f); From 136dc2eb1254e3f944caa1317e128283d97aa3a5 Mon Sep 17 00:00:00 2001 From: Igor Loskutov Date: Mon, 16 Mar 2026 07:33:06 -0400 Subject: [PATCH 289/831] server: return proper HTTP status codes for error responses (#3707) Several error paths in the /inference and /load endpoints returned HTTP 200 with a JSON error body, making it impossible for clients to distinguish errors from successful responses by status code. Set 400 for client errors (missing file field, unreadable audio, missing/invalid model) and 500 for server errors (ffmpeg conversion failure). The two existing status-code sites (499 for client disconnect, 500 for processing failure) are unchanged. --- examples/server/server.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c5354efc314..8ace43bf80e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -811,6 +811,7 @@ int main(int argc, char ** argv) { { fprintf(stderr, "error: no 'file' field in the request\n"); const std::string error_resp = "{\"error\":\"no 'file' field in the request\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); return; } @@ -837,6 +838,7 @@ int main(int argc, char ** argv) { std::string error_resp = "{\"error\":\"Failed to execute ffmpeg command.\"}"; const bool is_converted = convert_to_wav(temp_filename, error_resp); if (!is_converted) { + res.status = 500; res.set_content(error_resp, "application/json"); return; } @@ -846,6 +848,7 @@ int main(int argc, char ** argv) { { fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str()); const std::string error_resp = "{\"error\":\"failed to read WAV file\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); std::remove(temp_filename.c_str()); return; @@ -857,6 +860,7 @@ int main(int argc, char ** argv) { { fprintf(stderr, "error: failed to read audio data\n"); const std::string error_resp = "{\"error\":\"failed to read audio data\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); return; } @@ -1127,6 +1131,7 @@ int main(int argc, char ** argv) { { fprintf(stderr, "error: no 'model' field in the request\n"); const std::string error_resp = "{\"error\":\"no 'model' field in the request\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); return; } @@ -1135,6 +1140,7 @@ int main(int argc, char ** argv) { { fprintf(stderr, "error: 'model': %s not found!\n", model.c_str()); const std::string error_resp = "{\"error\":\"model not found!\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); return; } From 21665eab4c6a5a4c27459a205eaf643144831fe1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ga=C3=ABl=20James?= Date: Mon, 16 Mar 2026 12:33:56 +0100 Subject: [PATCH 290/831] examples : Allow max_len to be used for any output format (#3679) --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8ace43bf80e..f6a7a83181a 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -939,7 +939,7 @@ int main(int argc, char ** argv) { wparams.logprob_thold = params.logprob_thold; wparams.no_timestamps = params.no_timestamps; - wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format; + wparams.token_timestamps = !params.no_timestamps; wparams.no_context = params.no_context; wparams.suppress_nst = params.suppress_nst; From 975b979834f884a8d91dfdb0bf9e576dac34cc5b Mon Sep 17 00:00:00 2001 From: Aiudadadadf Date: Mon, 16 Mar 2026 12:41:54 +0100 Subject: [PATCH 291/831] py : replace deprecated openvino-dev with openvino>=2023.3.0 (#3678) * models: replace deprecated openvino-dev with openvino>=2023.3.0 for Python 3.12+ compat * models: remove unused openvino.tools.mo import from convert-whisper-to-openvino.py --- models/convert-whisper-to-openvino.py | 1 - models/requirements-openvino.txt | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/models/convert-whisper-to-openvino.py b/models/convert-whisper-to-openvino.py index 3124dd3d7cf..a17e535550d 100644 --- a/models/convert-whisper-to-openvino.py +++ b/models/convert-whisper-to-openvino.py @@ -2,7 +2,6 @@ import torch from whisper import load_model import os -from openvino.tools import mo from openvino.frontend import FrontEndManager from openvino.runtime import serialize import shutil diff --git a/models/requirements-openvino.txt b/models/requirements-openvino.txt index 5bfd95db88e..707fa58ab30 100644 --- a/models/requirements-openvino.txt +++ b/models/requirements-openvino.txt @@ -1,2 +1,2 @@ -openvino-dev[pytorch,onnx] -openai-whisper \ No newline at end of file +openvino>=2023.3.0 +openai-whisper From 79218f51d02ffe70575ef7fba3496dfc7adda027 Mon Sep 17 00:00:00 2001 From: Alan <103587817+Lumberj3ck@users.noreply.github.com> Date: Mon, 16 Mar 2026 12:44:18 +0100 Subject: [PATCH 292/831] go : handle EOF correctly in model download (#3671) --- bindings/go/examples/go-model-download/main.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/bindings/go/examples/go-model-download/main.go b/bindings/go/examples/go-model-download/main.go index 728c6df53d4..e72262eb7cb 100644 --- a/bindings/go/examples/go-model-download/main.go +++ b/bindings/go/examples/go-model-download/main.go @@ -282,13 +282,20 @@ func Download(ctx context.Context, p io.Writer, model, out string) (string, erro default: // Read body n, err := resp.Body.Read(data) + if n > 0 { + if m, err := w.Write(data[:n]); err != nil { + return path, err + } else { + count += int64(m) + } + } + if err != nil { - DownloadReport(p, pct, count, resp.ContentLength) - return path, err - } else if m, err := w.Write(data[:n]); err != nil { + if err == io.EOF { + DownloadReport(p, pct, count, resp.ContentLength) + return path, nil + } return path, err - } else { - count += int64(m) } } } From dc9611662265870df22a7230b7586176a99c1955 Mon Sep 17 00:00:00 2001 From: lohopupa <87423657+lohopupa@users.noreply.github.com> Date: Tue, 17 Mar 2026 12:19:08 +0600 Subject: [PATCH 293/831] fix: VAD time mapping timestamp drift caused by overlap samples (#3711) * whisper : fix VAD segment overlap boundary handling - Use original segment length (pre-overlap) for vad_end in the time mapping table, so segment boundaries are preserved accurately Claude Sonnet 4.6 (Low) * whisper : remove intermediate VAD time mapping points Now that segment boundaries are mapped accurately, the intermediate point interpolation is no longer necessary. --------- Co-authored-by: Lohopupa --- src/whisper.cpp | 34 ++++++---------------------------- 1 file changed, 6 insertions(+), 28 deletions(-) diff --git a/src/whisper.cpp b/src/whisper.cpp index 796bccfb45d..86bfafeaad8 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -6701,12 +6701,13 @@ static bool whisper_vad( int segment_start_samples = cs_to_samples(vad_segments->data[i].start); int segment_end_samples = cs_to_samples(vad_segments->data[i].end); - if (i < (int)vad_segments->data.size() - 1) { - segment_end_samples += overlap_samples; - } - segment_start_samples = std::min(segment_start_samples, n_samples - 1); segment_end_samples = std::min(segment_end_samples, n_samples - 1); + int original_segment_length = segment_end_samples - segment_start_samples; + + if (i < (int)vad_segments->data.size() - 1) { + segment_end_samples = std::min(segment_end_samples + overlap_samples, n_samples - 1); + } int segment_length = segment_end_samples - segment_start_samples; if (segment_length > 0) { whisper_state::vad_segment_info segment; @@ -6715,7 +6716,7 @@ static bool whisper_vad( segment.orig_end = vad_segments->data[i].end; segment.vad_start = samples_to_cs(offset); - segment.vad_end = samples_to_cs(offset + segment_length); + segment.vad_end = samples_to_cs(offset + original_segment_length); // Add segment boundaries to mapping table vad_time_mapping start_mapping = {segment.vad_start, segment.orig_start}; @@ -6724,29 +6725,6 @@ static bool whisper_vad( state->vad_mapping_table.push_back(start_mapping); state->vad_mapping_table.push_back(end_mapping); - // Add intermediate points for longer segments to improve interpolation accuracy - const int64_t min_segment_length = 100; // 1 second - const int64_t point_interval = 20; // Add a point every 200ms - - if (segment.vad_end - segment.vad_start > min_segment_length) { - int64_t segment_duration = segment.vad_end - segment.vad_start; - int num_points = (int)(segment_duration / point_interval) - 1; - - for (int j = 1; j <= num_points; j++) { - int64_t vad_time = segment.vad_start + j * point_interval; - - if (vad_time >= segment.vad_end) continue; - - int64_t vad_elapsed = vad_time - segment.vad_start; - int64_t vad_total = segment.vad_end - segment.vad_start; - int64_t orig_total = segment.orig_end - segment.orig_start; - int64_t orig_time = segment.orig_start + (vad_elapsed * orig_total) / vad_total; - - vad_time_mapping intermediate_mapping = {vad_time, orig_time}; - state->vad_mapping_table.push_back(intermediate_mapping); - } - } - WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n", __func__, segment.orig_start/100.0, segment.orig_end/100.0, segment.vad_start/100.0, segment.vad_end/100.0); ctx->state->vad_segments.push_back(segment); From 945d3151d9a77a2d2bd59f4ab817712fb7e1fdf9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 16 Mar 2026 20:09:25 +0200 Subject: [PATCH 294/831] ggml : restore ggml_type_sizef() to aboid major version bump (ggml/1441) --- ggml/include/ggml.h | 4 ++++ ggml/src/ggml.c | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 25f9601e9b5..669f66b650f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -733,6 +733,10 @@ extern "C" { GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row + GGML_DEPRECATED( + GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float + "use ggml_row_size() instead"); + GGML_API const char * ggml_type_name(enum ggml_type type); GGML_API const char * ggml_op_name (enum ggml_op op); GGML_API const char * ggml_op_symbol(enum ggml_op op); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e5b83e14479..4c0764a0ac5 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1294,6 +1294,12 @@ size_t ggml_row_size(enum ggml_type type, int64_t ne) { return ggml_type_size(type)*ne/ggml_blck_size(type); } +double ggml_type_sizef(enum ggml_type type) { + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); + return ((double)(type_traits[type].type_size))/type_traits[type].blck_size; +} + const char * ggml_type_name(enum ggml_type type) { assert(type >= 0); assert(type < GGML_TYPE_COUNT); From b2be16208dfc6d09b4383bcb1047d0379d8a7f2a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 16 Mar 2026 20:15:14 +0200 Subject: [PATCH 295/831] ggml : bump version to 0.9.8 (ggml/1442) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 44e58a52761..c780077acaa 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 7) +set(GGML_VERSION_PATCH 8) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) From f5b477ab09041467e3c630c2f3424a92969ad9b3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 18 Mar 2026 14:45:25 +0200 Subject: [PATCH 296/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 709d00d40b3..6557fb46cbe 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -9d0addf420778b42c257cd3837fbd38ca4599f3b +c044a8eeae2591faa0950c8b5e514cbc4bbfc4ca From 4bbce1e5b230300b375d44b9da40a76e1e4130ee Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 18 Mar 2026 22:34:51 +0200 Subject: [PATCH 297/831] benches : update --- examples/bench/bench.cpp | 41 ++++---- scripts/bench-all-gg.txt | 196 +++++++++++++++++++-------------------- scripts/bench-all.sh | 10 +- 3 files changed, 127 insertions(+), 120 deletions(-) diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index 2d967f2caf4..049473d4f32 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -85,33 +85,38 @@ static int whisper_bench_full(const whisper_params & params) { fprintf(stderr, "error: failed to set mel: %d\n", ret); return 3; } - // heat encoder - if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { - fprintf(stderr, "error: failed to encode: %d\n", ret); - return 4; - } whisper_token tokens[512]; memset(tokens, 0, sizeof(tokens)); - // prompt heat - if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { - fprintf(stderr, "error: failed to decode: %d\n", ret); - return 4; - } + // TODO: need 2 loops because of the current graph capture logic in the CUDA backend + // https://github.com/ggml-org/llama.cpp/pull/19754 + for (int h = 0; h < 2; ++h) { + // heat encoder + if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { + fprintf(stderr, "error: failed to encode: %d\n", ret); + return 4; + } - // text-generation heat - for (int i = 0; i < 256; i++) { - if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) { + // prompt heat + if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { fprintf(stderr, "error: failed to decode: %d\n", ret); return 4; } - } - // batched heat - if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) { - fprintf(stderr, "error: failed to decode: %d\n", ret); - return 4; + // text-generation heat + for (int i = 0; i < 256; i++) { + if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) { + fprintf(stderr, "error: failed to decode: %d\n", ret); + return 4; + } + } + + // batched heat + if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) { + fprintf(stderr, "error: failed to decode: %d\n", ret); + return 4; + } } whisper_reset_timings(ctx); diff --git a/scripts/bench-all-gg.txt b/scripts/bench-all-gg.txt index 32a0908306c..220bd4c98b8 100644 --- a/scripts/bench-all-gg.txt +++ b/scripts/bench-all-gg.txt @@ -111,61 +111,61 @@ make -j && ./scripts/bench-all.sh 1 1 0 | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M2 ULTRA | METAL | tiny | 1 | 0 | 8.80 | 1.13 | 0.28 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | tiny-q5_0 | 1 | 0 | 9.34 | 1.09 | 0.28 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | tiny-q5_1 | 1 | 0 | 9.29 | 1.09 | 0.29 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | tiny-q8_0 | 1 | 0 | 9.00 | 1.12 | 0.28 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | base | 1 | 0 | 15.92 | 1.60 | 0.43 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | base-q5_0 | 1 | 0 | 17.01 | 1.53 | 0.43 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | base-q5_1 | 1 | 0 | 17.02 | 1.53 | 0.44 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | base-q8_0 | 1 | 0 | 16.25 | 1.55 | 0.43 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | small | 1 | 0 | 47.83 | 3.09 | 0.91 | 0.05 | 47af2fb7 | -| M2 ULTRA | METAL | small-q5_0 | 1 | 0 | 52.85 | 2.98 | 0.94 | 0.06 | 47af2fb7 | -| M2 ULTRA | METAL | small-q5_1 | 1 | 0 | 52.92 | 2.97 | 0.94 | 0.06 | 47af2fb7 | -| M2 ULTRA | METAL | small-q8_0 | 1 | 0 | 49.05 | 2.89 | 0.90 | 0.06 | 47af2fb7 | -| M2 ULTRA | METAL | medium | 1 | 0 | 127.98 | 6.62 | 2.05 | 0.12 | 47af2fb7 | -| M2 ULTRA | METAL | medium-q5_0 | 1 | 0 | 145.42 | 6.09 | 2.12 | 0.14 | 47af2fb7 | -| M2 ULTRA | METAL | medium-q5_1 | 1 | 0 | 145.16 | 6.08 | 2.14 | 0.14 | 47af2fb7 | -| M2 ULTRA | METAL | medium-q8_0 | 1 | 0 | 132.72 | 6.10 | 2.07 | 0.13 | 47af2fb7 | -| M2 ULTRA | METAL | medium-dis | 1 | 0 | 115.09 | 0.91 | 0.25 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2 | 1 | 0 | 243.69 | 9.68 | 3.14 | 0.22 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 0 | 280.38 | 8.95 | 3.18 | 0.25 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 0 | 279.76 | 8.92 | 3.18 | 0.25 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 0 | 254.55 | 9.35 | 3.04 | 0.23 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-dis | 1 | 0 | 219.23 | 1.01 | 0.28 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | large-v3-turbo | 1 | 0 | 220.57 | 1.55 | 0.46 | 0.03 | 47af2fb7 | -| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 0 | 253.03 | 1.40 | 0.47 | 0.04 | 47af2fb7 | -| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 0 | 229.82 | 1.43 | 0.45 | 0.04 | 47af2fb7 | +| M2 ULTRA | METAL | tiny | 1 | 0 | 8.57 | 1.12 | 0.27 | 0.01 | f5b477ab | +| M2 ULTRA | METAL | tiny-q5_0 | 1 | 0 | 9.17 | 1.10 | 0.28 | 0.01 | f5b477ab | +| M2 ULTRA | METAL | tiny-q5_1 | 1 | 0 | 9.16 | 1.09 | 0.28 | 0.01 | f5b477ab | +| M2 ULTRA | METAL | tiny-q8_0 | 1 | 0 | 8.81 | 1.12 | 0.27 | 0.01 | f5b477ab | +| M2 ULTRA | METAL | base | 1 | 0 | 15.60 | 1.61 | 0.41 | 0.02 | f5b477ab | +| M2 ULTRA | METAL | base-q5_0 | 1 | 0 | 16.75 | 1.54 | 0.42 | 0.02 | f5b477ab | +| M2 ULTRA | METAL | base-q5_1 | 1 | 0 | 16.64 | 1.54 | 0.43 | 0.02 | f5b477ab | +| M2 ULTRA | METAL | base-q8_0 | 1 | 0 | 16.09 | 1.55 | 0.41 | 0.02 | f5b477ab | +| M2 ULTRA | METAL | small | 1 | 0 | 46.74 | 3.13 | 0.89 | 0.05 | f5b477ab | +| M2 ULTRA | METAL | small-q5_0 | 1 | 0 | 51.57 | 3.03 | 0.91 | 0.06 | f5b477ab | +| M2 ULTRA | METAL | small-q5_1 | 1 | 0 | 51.85 | 3.03 | 0.92 | 0.06 | f5b477ab | +| M2 ULTRA | METAL | small-q8_0 | 1 | 0 | 48.34 | 3.01 | 0.89 | 0.06 | f5b477ab | +| M2 ULTRA | METAL | medium | 1 | 0 | 125.82 | 6.46 | 2.01 | 0.12 | f5b477ab | +| M2 ULTRA | METAL | medium-q5_0 | 1 | 0 | 143.44 | 5.97 | 2.07 | 0.14 | f5b477ab | +| M2 ULTRA | METAL | medium-q5_1 | 1 | 0 | 143.41 | 5.97 | 2.09 | 0.14 | f5b477ab | +| M2 ULTRA | METAL | medium-q8_0 | 1 | 0 | 131.23 | 6.30 | 2.01 | 0.13 | f5b477ab | +| M2 ULTRA | METAL | medium-dis | 1 | 0 | 114.07 | 0.90 | 0.25 | 0.02 | f5b477ab | +| M2 ULTRA | METAL | large-v2 | 1 | 0 | 240.73 | 9.46 | 3.21 | 0.21 | f5b477ab | +| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 0 | 276.56 | 8.62 | 3.16 | 0.25 | f5b477ab | +| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 0 | 275.90 | 8.98 | 3.16 | 0.25 | f5b477ab | +| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 0 | 251.00 | 9.10 | 3.02 | 0.22 | f5b477ab | +| M2 ULTRA | METAL | large-v2-dis | 1 | 0 | 217.43 | 1.01 | 0.28 | 0.02 | f5b477ab | +| M2 ULTRA | METAL | large-v3-turbo | 1 | 0 | 218.39 | 1.55 | 0.47 | 0.03 | f5b477ab | +| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 0 | 249.41 | 1.39 | 0.47 | 0.04 | f5b477ab | +| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 0 | 227.54 | 1.43 | 0.45 | 0.03 | f5b477ab | make -j && ./scripts/bench-all.sh 1 1 1 | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M2 ULTRA | METAL | tiny | 1 | 1 | 6.19 | 0.93 | 0.21 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | tiny-q5_0 | 1 | 1 | 6.64 | 0.89 | 0.22 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | tiny-q5_1 | 1 | 1 | 6.65 | 0.91 | 0.23 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | tiny-q8_0 | 1 | 1 | 6.26 | 0.93 | 0.22 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | base | 1 | 1 | 10.89 | 1.31 | 0.32 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | base-q5_0 | 1 | 1 | 12.10 | 1.22 | 0.33 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | base-q5_1 | 1 | 1 | 12.05 | 1.22 | 0.33 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | base-q8_0 | 1 | 1 | 11.24 | 1.24 | 0.32 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | small | 1 | 1 | 32.06 | 2.41 | 0.64 | 0.04 | 47af2fb7 | -| M2 ULTRA | METAL | small-q5_0 | 1 | 1 | 37.20 | 2.32 | 0.67 | 0.04 | 47af2fb7 | -| M2 ULTRA | METAL | small-q5_1 | 1 | 1 | 37.13 | 2.30 | 0.67 | 0.04 | 47af2fb7 | -| M2 ULTRA | METAL | small-q8_0 | 1 | 1 | 33.63 | 2.28 | 0.64 | 0.04 | 47af2fb7 | -| M2 ULTRA | METAL | medium | 1 | 1 | 89.22 | 5.14 | 1.46 | 0.09 | 47af2fb7 | -| M2 ULTRA | METAL | medium-q5_0 | 1 | 1 | 106.82 | 4.83 | 1.49 | 0.11 | 47af2fb7 | -| M2 ULTRA | METAL | medium-q5_1 | 1 | 1 | 106.60 | 4.88 | 1.50 | 0.11 | 47af2fb7 | -| M2 ULTRA | METAL | medium-q8_0 | 1 | 1 | 94.48 | 4.93 | 1.43 | 0.09 | 47af2fb7 | -| M2 ULTRA | METAL | medium-dis | 1 | 1 | 77.85 | 0.80 | 0.20 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2 | 1 | 1 | 170.73 | 7.50 | 2.12 | 0.16 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 1 | 206.46 | 7.05 | 2.17 | 0.20 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 1 | 206.15 | 7.10 | 2.19 | 0.20 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 1 | 180.31 | 6.90 | 2.10 | 0.17 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-dis | 1 | 1 | 147.44 | 0.90 | 0.22 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | large-v3-turbo | 1 | 1 | 148.79 | 1.30 | 0.34 | 0.03 | 47af2fb7 | -| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 1 | 180.34 | 1.14 | 0.35 | 0.03 | 47af2fb7 | -| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 1 | 158.04 | 1.18 | 0.33 | 0.03 | 47af2fb7 | +| M2 ULTRA | METAL | tiny | 1 | 1 | 6.06 | 0.96 | 0.22 | 0.01 | f5b477ab | +| M2 ULTRA | METAL | tiny-q5_0 | 1 | 1 | 6.51 | 0.93 | 0.22 | 0.01 | f5b477ab | +| M2 ULTRA | METAL | tiny-q5_1 | 1 | 1 | 6.47 | 0.93 | 0.23 | 0.01 | f5b477ab | +| M2 ULTRA | METAL | tiny-q8_0 | 1 | 1 | 6.16 | 0.94 | 0.21 | 0.01 | f5b477ab | +| M2 ULTRA | METAL | base | 1 | 1 | 10.63 | 1.37 | 0.32 | 0.01 | f5b477ab | +| M2 ULTRA | METAL | base-q5_0 | 1 | 1 | 11.75 | 1.27 | 0.33 | 0.02 | f5b477ab | +| M2 ULTRA | METAL | base-q5_1 | 1 | 1 | 11.73 | 1.25 | 0.33 | 0.02 | f5b477ab | +| M2 ULTRA | METAL | base-q8_0 | 1 | 1 | 11.17 | 1.28 | 0.32 | 0.02 | f5b477ab | +| M2 ULTRA | METAL | small | 1 | 1 | 31.74 | 2.55 | 0.67 | 0.04 | f5b477ab | +| M2 ULTRA | METAL | small-q5_0 | 1 | 1 | 36.21 | 2.47 | 0.69 | 0.04 | f5b477ab | +| M2 ULTRA | METAL | small-q5_1 | 1 | 1 | 36.22 | 2.47 | 0.70 | 0.04 | f5b477ab | +| M2 ULTRA | METAL | small-q8_0 | 1 | 1 | 32.73 | 2.45 | 0.66 | 0.04 | f5b477ab | +| M2 ULTRA | METAL | medium | 1 | 1 | 86.94 | 5.21 | 1.49 | 0.09 | f5b477ab | +| M2 ULTRA | METAL | medium-q5_0 | 1 | 1 | 104.31 | 4.93 | 1.51 | 0.10 | f5b477ab | +| M2 ULTRA | METAL | medium-q5_1 | 1 | 1 | 104.09 | 4.98 | 1.51 | 0.10 | f5b477ab | +| M2 ULTRA | METAL | medium-q8_0 | 1 | 1 | 92.13 | 5.06 | 1.45 | 0.09 | f5b477ab | +| M2 ULTRA | METAL | medium-dis | 1 | 1 | 76.67 | 0.81 | 0.20 | 0.01 | f5b477ab | +| M2 ULTRA | METAL | large-v2 | 1 | 1 | 167.66 | 7.56 | 2.25 | 0.16 | f5b477ab | +| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 1 | 203.09 | 7.13 | 2.29 | 0.20 | f5b477ab | +| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 1 | 202.53 | 7.12 | 2.29 | 0.20 | f5b477ab | +| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 1 | 177.48 | 6.94 | 2.18 | 0.17 | f5b477ab | +| M2 ULTRA | METAL | large-v2-dis | 1 | 1 | 145.61 | 0.91 | 0.23 | 0.02 | f5b477ab | +| M2 ULTRA | METAL | large-v3-turbo | 1 | 1 | 146.95 | 1.33 | 0.36 | 0.03 | f5b477ab | +| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 1 | 178.57 | 1.17 | 0.36 | 0.03 | f5b477ab | +| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 1 | 156.19 | 1.21 | 0.34 | 0.03 | f5b477ab | ## M4 Max @@ -268,35 +268,35 @@ make -j && ./scripts/bench-all.sh 1 1 0 | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| RTX 5090 | CUDA | tiny | 1 | 0 | 2.12 | 0.51 | 0.13 | 0.00 | 47af2fb7 | -| RTX 5090 | CUDA | tiny-q8_0 | 1 | 0 | 2.50 | 0.52 | 0.14 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | base | 1 | 0 | 3.74 | 0.76 | 0.19 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | base-q8_0 | 1 | 0 | 4.38 | 0.74 | 0.20 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | small | 1 | 0 | 11.25 | 1.46 | 0.39 | 0.02 | 47af2fb7 | -| RTX 5090 | CUDA | small-q8_0 | 1 | 0 | 12.70 | 1.58 | 0.41 | 0.02 | 47af2fb7 | -| RTX 5090 | CUDA | medium | 1 | 0 | 31.16 | 3.07 | 0.80 | 0.04 | 47af2fb7 | -| RTX 5090 | CUDA | medium-q8_0 | 1 | 0 | 32.50 | 3.23 | 0.83 | 0.05 | 47af2fb7 | -| RTX 5090 | CUDA | large-v2 | 1 | 0 | 50.04 | 4.59 | 1.15 | 0.05 | 47af2fb7 | -| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 0 | 52.17 | 4.38 | 1.14 | 0.07 | 47af2fb7 | -| RTX 5090 | CUDA | large-v3-turbo | 1 | 0 | 46.88 | 0.70 | 0.17 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 48.49 | 0.64 | 0.16 | 0.01 | 47af2fb7 | +| RTX 5090 | CUDA | tiny | 1 | 0 | 2.20 | 0.51 | 0.13 | 0.01 | f5b477ab | +| RTX 5090 | CUDA | tiny-q8_0 | 1 | 0 | 2.35 | 0.52 | 0.14 | 0.01 | f5b477ab | +| RTX 5090 | CUDA | base | 1 | 0 | 3.97 | 0.77 | 0.20 | 0.01 | f5b477ab | +| RTX 5090 | CUDA | base-q8_0 | 1 | 0 | 4.20 | 0.73 | 0.20 | 0.01 | f5b477ab | +| RTX 5090 | CUDA | small | 1 | 0 | 11.87 | 1.48 | 0.40 | 0.02 | f5b477ab | +| RTX 5090 | CUDA | small-q8_0 | 1 | 0 | 12.40 | 1.59 | 0.42 | 0.02 | f5b477ab | +| RTX 5090 | CUDA | medium | 1 | 0 | 32.63 | 3.11 | 0.82 | 0.04 | f5b477ab | +| RTX 5090 | CUDA | medium-q8_0 | 1 | 0 | 31.80 | 3.23 | 0.84 | 0.05 | f5b477ab | +| RTX 5090 | CUDA | large-v2 | 1 | 0 | 52.22 | 4.66 | 1.18 | 0.06 | f5b477ab | +| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 0 | 51.11 | 4.37 | 1.15 | 0.07 | f5b477ab | +| RTX 5090 | CUDA | large-v3-turbo | 1 | 0 | 48.72 | 0.70 | 0.18 | 0.01 | f5b477ab | +| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 47.81 | 0.64 | 0.16 | 0.01 | f5b477ab | make -j && ./scripts/bench-all.sh 1 1 1 | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| RTX 5090 | CUDA | tiny | 1 | 1 | 1.42 | 0.44 | 0.11 | 0.00 | 47af2fb7 | -| RTX 5090 | CUDA | tiny-q8_0 | 1 | 1 | 1.83 | 0.45 | 0.12 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | base | 1 | 1 | 2.21 | 0.65 | 0.16 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | base-q8_0 | 1 | 1 | 2.85 | 0.62 | 0.17 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | small | 1 | 1 | 5.11 | 1.23 | 0.32 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | small-q8_0 | 1 | 1 | 6.50 | 1.35 | 0.34 | 0.02 | 47af2fb7 | -| RTX 5090 | CUDA | medium | 1 | 1 | 14.01 | 2.57 | 0.64 | 0.03 | 47af2fb7 | -| RTX 5090 | CUDA | medium-q8_0 | 1 | 1 | 15.34 | 2.72 | 0.67 | 0.04 | 47af2fb7 | -| RTX 5090 | CUDA | large-v2 | 1 | 1 | 21.70 | 3.96 | 0.97 | 0.04 | 47af2fb7 | -| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 1 | 23.57 | 3.70 | 0.94 | 0.05 | 47af2fb7 | -| RTX 5090 | CUDA | large-v3-turbo | 1 | 1 | 18.61 | 0.62 | 0.15 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 20.10 | 0.56 | 0.14 | 0.01 | 47af2fb7 | +| RTX 5090 | CUDA | tiny | 1 | 1 | 1.37 | 0.44 | 0.11 | 0.00 | f5b477ab | +| RTX 5090 | CUDA | tiny-q8_0 | 1 | 1 | 1.48 | 0.44 | 0.12 | 0.01 | f5b477ab | +| RTX 5090 | CUDA | base | 1 | 1 | 2.34 | 0.66 | 0.16 | 0.01 | f5b477ab | +| RTX 5090 | CUDA | base-q8_0 | 1 | 1 | 2.51 | 0.62 | 0.17 | 0.01 | f5b477ab | +| RTX 5090 | CUDA | small | 1 | 1 | 5.53 | 1.23 | 0.32 | 0.01 | f5b477ab | +| RTX 5090 | CUDA | small-q8_0 | 1 | 1 | 5.88 | 1.35 | 0.33 | 0.02 | f5b477ab | +| RTX 5090 | CUDA | medium | 1 | 1 | 15.09 | 2.55 | 0.65 | 0.03 | f5b477ab | +| RTX 5090 | CUDA | medium-q8_0 | 1 | 1 | 14.06 | 2.72 | 0.67 | 0.03 | f5b477ab | +| RTX 5090 | CUDA | large-v2 | 1 | 1 | 23.24 | 3.94 | 0.97 | 0.04 | f5b477ab | +| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 1 | 22.00 | 3.68 | 0.93 | 0.05 | f5b477ab | +| RTX 5090 | CUDA | large-v3-turbo | 1 | 1 | 19.81 | 0.62 | 0.15 | 0.01 | f5b477ab | +| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 18.62 | 0.56 | 0.14 | 0.01 | f5b477ab | # DGX Spark @@ -305,35 +305,35 @@ make -j && ./scripts/bench-all.sh 1 1 0 | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| DGX Spk. | CUDA | tiny | 1 | 0 | 9.42 | 0.85 | 0.22 | 0.01 | 47af2fb7 | -| DGX Spk. | CUDA | tiny-q8_0 | 1 | 0 | 9.69 | 0.81 | 0.20 | 0.01 | 47af2fb7 | -| DGX Spk. | CUDA | base | 1 | 0 | 18.81 | 1.36 | 0.33 | 0.02 | 47af2fb7 | -| DGX Spk. | CUDA | base-q8_0 | 1 | 0 | 18.11 | 1.20 | 0.30 | 0.02 | 47af2fb7 | -| DGX Spk. | CUDA | small | 1 | 0 | 59.83 | 3.01 | 0.74 | 0.04 | 47af2fb7 | -| DGX Spk. | CUDA | small-q8_0 | 1 | 0 | 59.12 | 2.66 | 0.67 | 0.05 | 47af2fb7 | -| DGX Spk. | CUDA | medium | 1 | 0 | 163.73 | 7.53 | 1.70 | 0.12 | 47af2fb7 | -| DGX Spk. | CUDA | medium-q8_0 | 1 | 0 | 157.54 | 5.98 | 1.48 | 0.13 | 47af2fb7 | -| DGX Spk. | CUDA | large-v2 | 1 | 0 | 279.83 | 12.26 | 2.77 | 0.21 | 47af2fb7 | -| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 0 | 273.05 | 9.31 | 2.33 | 0.22 | 47af2fb7 | -| DGX Spk. | CUDA | large-v3-turbo | 1 | 0 | 271.11 | 2.06 | 0.47 | 0.03 | 47af2fb7 | -| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 262.69 | 1.49 | 0.36 | 0.03 | 47af2fb7 | +| DGX Spk. | CUDA | tiny | 1 | 0 | 9.00 | 0.85 | 0.14 | 0.01 | f5b477ab | +| DGX Spk. | CUDA | tiny-q8_0 | 1 | 0 | 8.86 | 0.83 | 0.12 | 0.01 | f5b477ab | +| DGX Spk. | CUDA | base | 1 | 0 | 18.48 | 1.38 | 0.22 | 0.02 | f5b477ab | +| DGX Spk. | CUDA | base-q8_0 | 1 | 0 | 17.28 | 1.22 | 0.19 | 0.02 | f5b477ab | +| DGX Spk. | CUDA | small | 1 | 0 | 56.43 | 3.01 | 0.51 | 0.04 | f5b477ab | +| DGX Spk. | CUDA | small-q8_0 | 1 | 0 | 55.70 | 2.68 | 0.44 | 0.04 | f5b477ab | +| DGX Spk. | CUDA | medium | 1 | 0 | 160.20 | 7.52 | 1.25 | 0.11 | f5b477ab | +| DGX Spk. | CUDA | medium-q8_0 | 1 | 0 | 150.84 | 6.01 | 1.01 | 0.12 | f5b477ab | +| DGX Spk. | CUDA | large-v2 | 1 | 0 | 276.42 | 12.29 | 2.16 | 0.20 | f5b477ab | +| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 0 | 264.92 | 9.32 | 1.67 | 0.20 | f5b477ab | +| DGX Spk. | CUDA | large-v3-turbo | 1 | 0 | 264.90 | 2.03 | 0.37 | 0.03 | f5b477ab | +| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 253.56 | 1.48 | 0.27 | 0.03 | f5b477ab | make -j && ./scripts/bench-all.sh 1 1 1 | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| DGX Spk. | CUDA | tiny | 1 | 1 | 2.89 | 0.76 | 0.19 | 0.01 | 47af2fb7 | -| DGX Spk. | CUDA | tiny-q8_0 | 1 | 1 | 3.06 | 0.72 | 0.17 | 0.01 | 47af2fb7 | -| DGX Spk. | CUDA | base | 1 | 1 | 5.37 | 1.23 | 0.29 | 0.01 | 47af2fb7 | -| DGX Spk. | CUDA | base-q8_0 | 1 | 1 | 4.70 | 1.07 | 0.26 | 0.01 | 47af2fb7 | -| DGX Spk. | CUDA | small | 1 | 1 | 17.70 | 2.73 | 0.66 | 0.02 | 47af2fb7 | -| DGX Spk. | CUDA | small-q8_0 | 1 | 1 | 16.77 | 2.38 | 0.58 | 0.03 | 47af2fb7 | -| DGX Spk. | CUDA | medium | 1 | 1 | 56.22 | 6.98 | 1.53 | 0.06 | 47af2fb7 | -| DGX Spk. | CUDA | medium-q8_0 | 1 | 1 | 46.39 | 5.46 | 1.28 | 0.07 | 47af2fb7 | -| DGX Spk. | CUDA | large-v2 | 1 | 1 | 100.33 | 11.59 | 2.53 | 0.09 | 47af2fb7 | -| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 1 | 97.28 | 8.60 | 2.10 | 0.10 | 47af2fb7 | -| DGX Spk. | CUDA | large-v3-turbo | 1 | 1 | 92.59 | 2.00 | 0.44 | 0.02 | 47af2fb7 | -| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 85.96 | 1.40 | 0.33 | 0.02 | 47af2fb7 | +| DGX Spk. | CUDA | tiny | 1 | 1 | 2.63 | 0.76 | 0.13 | 0.01 | f5b477ab | +| DGX Spk. | CUDA | tiny-q8_0 | 1 | 1 | 2.46 | 0.73 | 0.11 | 0.01 | f5b477ab | +| DGX Spk. | CUDA | base | 1 | 1 | 4.96 | 1.24 | 0.20 | 0.01 | f5b477ab | +| DGX Spk. | CUDA | base-q8_0 | 1 | 1 | 4.23 | 1.08 | 0.17 | 0.01 | f5b477ab | +| DGX Spk. | CUDA | small | 1 | 1 | 16.26 | 2.73 | 0.47 | 0.02 | f5b477ab | +| DGX Spk. | CUDA | small-q8_0 | 1 | 1 | 14.94 | 2.38 | 0.39 | 0.02 | f5b477ab | +| DGX Spk. | CUDA | medium | 1 | 1 | 51.81 | 6.94 | 1.22 | 0.05 | f5b477ab | +| DGX Spk. | CUDA | medium-q8_0 | 1 | 1 | 41.51 | 5.44 | 0.93 | 0.05 | f5b477ab | +| DGX Spk. | CUDA | large-v2 | 1 | 1 | 98.54 | 11.53 | 2.05 | 0.08 | f5b477ab | +| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 1 | 91.61 | 8.49 | 1.55 | 0.08 | f5b477ab | +| DGX Spk. | CUDA | large-v3-turbo | 1 | 1 | 87.20 | 1.94 | 0.36 | 0.02 | f5b477ab | +| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 80.28 | 1.38 | 0.26 | 0.01 | f5b477ab | # V100 diff --git a/scripts/bench-all.sh b/scripts/bench-all.sh index a15a361c708..7a0d0c8764b 100755 --- a/scripts/bench-all.sh +++ b/scripts/bench-all.sh @@ -100,12 +100,14 @@ for model in "${models[@]}"; do if [[ $system_info == *"CUDA = 1"* ]]; then config="$config CUDA" + elif [[ $system_info == *"CUDA : ARCHS"* ]]; then + config="$config CUDA" fi - if [[ $system_info == *"METAL = 1"* ]]; then - config="$config METAL" - elif [[ $system_info == *"Metal : EMBED_LIBRARY = 1"* ]]; then - config="$config METAL" + if [[ $system_info == *"MTL = 1"* ]]; then + config="$config MTL" + elif [[ $system_info == *"MTL : EMBED_LIBRARY = 1"* ]]; then + config="$config MTL" fi commit=$(git rev-parse --short HEAD) From ef3463bb29ef90d25dfabfd1e75993111c52412d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 18 Mar 2026 22:43:38 +0200 Subject: [PATCH 298/831] ci : update workflows --- .github/workflows/build.yml | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8ce887fd111..fb115b22abb 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -860,7 +860,7 @@ jobs: echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 echo "CUDA_PATH_V11_8=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - + - name: Install Cuda Toolkit 12.4.0 if: ${{ matrix.cuda-toolkit == '12.4.0' }} run: | @@ -1478,7 +1478,7 @@ jobs: LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt ggml-ci-x64-nvidia-cuda: - runs-on: [self-hosted, Linux, X64, NVIDIA] + runs-on: [self-hosted, Linux, mnt-root, NVIDIA] steps: - name: Clone @@ -1492,7 +1492,7 @@ jobs: GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp ggml-ci-x64-nvidia-vulkan-cm: - runs-on: [self-hosted, Linux, X64, NVIDIA] + runs-on: [self-hosted, Linux, mnt-root, NVIDIA] steps: - name: Clone @@ -1506,7 +1506,7 @@ jobs: GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp ggml-ci-x64-nvidia-vulkan-cm2: - runs-on: [self-hosted, Linux, X64, NVIDIA, COOPMAT2] + runs-on: [self-hosted, Linux, mnt-root, NVIDIA, COOPMAT2] steps: - name: Clone @@ -1519,18 +1519,18 @@ jobs: vulkaninfo --summary GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp - ggml-ci-x64-cpu-amx: - runs-on: [self-hosted, Linux, X64, CPU, AMX] + #ggml-ci-x64-cpu-amx: + # runs-on: [self-hosted, Linux, X64, CPU, AMX] - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 + # steps: + # - name: Clone + # id: checkout + # uses: actions/checkout@v6 - - name: Test - id: ggml-ci - run: | - bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp + # - name: Test + # id: ggml-ci + # run: | + # bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp ggml-ci-mac-metal: runs-on: [self-hosted, macOS, ARM64] From 9386f239401074690479731c1e41683fbbeac557 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 19 Mar 2026 10:40:13 +0200 Subject: [PATCH 299/831] release : v1.8.4 --- CMakeLists.txt | 2 +- bindings/javascript/package.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 06577bf1181..a0f74041321 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.8.3) +project("whisper.cpp" VERSION 1.8.4) include(CheckIncludeFileCXX) set(SOVERSION 1) diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index 84139804314..074dfdda307 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.8.3", + "version": "1.8.4", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From 76684141a5d059be71cbe23dc2f0ed552213ba2d Mon Sep 17 00:00:00 2001 From: KITAITI Makoto Date: Sun, 22 Mar 2026 02:03:00 +0900 Subject: [PATCH 300/831] ruby : fix dangling pointers, memory leak, and SEGV on parallel transcription (#3715) * Prevent dangling pointers * Use proper free function * Free callback containers * Set default log callback when nil is passed to log_set * Raise error if callbacks set when parallel transcription * Bump version to 1.3.7 * Make tests follow spec change * Add note on parallel transcription and callbacks * Update signature of Whisper.log_set [skip ci] --- bindings/ruby/README.md | 2 + bindings/ruby/ext/ruby_whisper.c | 16 +++- bindings/ruby/ext/ruby_whisper.h | 1 + bindings/ruby/ext/ruby_whisper_context.c | 6 +- bindings/ruby/ext/ruby_whisper_params.c | 79 +++++++++++++++++-- bindings/ruby/ext/ruby_whisper_transcribe.cpp | 4 +- bindings/ruby/sig/whisper.rbs | 8 +- bindings/ruby/test/test_params.rb | 2 + bindings/ruby/test/test_whisper.rb | 29 +++++-- bindings/ruby/whispercpp.gemspec | 2 +- 10 files changed, 127 insertions(+), 22 deletions(-) diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index c6280a6926a..41e7b330d58 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -202,6 +202,8 @@ whisper.transcribe("path/to/audio.wav", params, n_processors: Etc.nprocessors) Note that transcription occasionally might be low accuracy when it works in parallel. +If n_processors is greater than 1, you cannot set any callbacks including new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, and log_callback set by Whisper.log_set. + ### Segments ### Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`: diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index ba71d4ba594..5f1917ee805 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -112,6 +112,10 @@ ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * return; } VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); + if (NIL_P(log_callback)) { + return; + } + VALUE udata = rb_iv_get(mWhisper, "user_data"); rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata); } @@ -129,10 +133,16 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d rb_iv_set(self, "log_callback", log_callback); rb_iv_set(self, "user_data", user_data); - VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback")); - rb_define_finalizer(log_callback, finalize_log_callback); + if (!NIL_P(log_callback)) { + VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback")); + rb_define_finalizer(log_callback, finalize_log_callback); + } - whisper_log_set(ruby_whisper_log_callback, NULL); + if (NIL_P(log_callback)) { + whisper_log_set(NULL, NULL); + } else { + whisper_log_set(ruby_whisper_log_callback, NULL); + } return Qnil; } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 8dfd103c17a..6b0b4df7214 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -2,6 +2,7 @@ #define RUBY_WHISPER_H #include +#include #include #include "whisper.h" diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index c39d43bd76c..6e38ead6321 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -22,7 +22,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type; extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self); extern VALUE rb_whisper_model_s_new(VALUE context); extern VALUE rb_whisper_segment_s_new(VALUE context, int index); -extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context); +extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors); ID transcribe_option_names[1]; @@ -436,7 +436,7 @@ full_body(VALUE rb_args) GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context); + prepare_transcription(rwp, args->context, 1); int result = whisper_full(rw->context, rwp->params, args->samples, args->n_samples); return INT2NUM(result); @@ -487,7 +487,7 @@ full_parallel_body(VALUE rb_args) GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context); + prepare_transcription(rwp, args->context, args->n_processors); int result = whisper_full_parallel(rw->context, rwp->params, args->samples, args->n_samples, args->n_processors); return INT2NUM(result); diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 61eb1733676..3e5dca9c1e1 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -29,6 +29,7 @@ extern VALUE cParams; extern VALUE cVADParams; +extern VALUE mWhisper; extern ID id_call; @@ -186,6 +187,35 @@ static bool abort_callback(void * user_data) { return false; } +static void +check_thread_safety(ruby_whisper_params *rwp, VALUE *context, int n_processors) +{ + if (n_processors == 1) { + return; + } + + if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { + rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription"); + } + + if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { + rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription"); + } + + if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) { + rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription"); + } + + if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { + rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription"); + } + + VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); + if (!NIL_P(log_callback)) { + rb_raise(rb_eRuntimeError, "log callback not supported for parallel transcription"); + } +} + static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { rwp->new_segment_callback_container->context = context; @@ -219,9 +249,13 @@ static void set_vad_params(ruby_whisper_params *rwp) rwp->params.vad_params = rwvp->params; } +/* + TODO: Set abort callback to trap SIGINT and SIGTERM +*/ void -prepare_transcription(ruby_whisper_params *rwp, VALUE *context) +prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors) { + check_thread_safety(rwp, context, n_processors); register_callbacks(rwp, context); set_vad_params(rwp); } @@ -240,6 +274,20 @@ rb_whisper_params_mark(void *p) void ruby_whisper_params_free(ruby_whisper_params *rwp) { + if (rwp->params.language) { + ruby_xfree((void *)rwp->params.language); + } + if (rwp->params.initial_prompt) { + ruby_xfree((void *)rwp->params.initial_prompt); + } + if (rwp->params.vad_model_path) { + ruby_xfree((void *)rwp->params.vad_model_path); + } + + xfree(rwp->new_segment_callback_container); + xfree(rwp->progress_callback_container); + xfree(rwp->encoder_begin_callback_container); + xfree(rwp->abort_callback_container); } void @@ -248,7 +296,7 @@ rb_whisper_params_free(void *p) ruby_whisper_params *rwp = (ruby_whisper_params *)p; // How to free user_data and callback only when not referred to by others? ruby_whisper_params_free(rwp); - free(rwp); + xfree(rwp); } static size_t @@ -276,6 +324,15 @@ ruby_whisper_params_allocate(VALUE klass) ruby_whisper_params *rwp; VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + if (rwp->params.language != NULL) { + rwp->params.language = ruby_strdup(rwp->params.language); + } + if (rwp->params.initial_prompt != NULL) { + rwp->params.initial_prompt = ruby_strdup(rwp->params.initial_prompt); + } + if (rwp->params.vad_model_path != NULL) { + rwp->params.vad_model_path = ruby_strdup(rwp->params.vad_model_path); + } rwp->diarize = false; rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params); rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); @@ -296,10 +353,12 @@ ruby_whisper_params_set_language(VALUE self, VALUE value) { ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + ruby_xfree((void *)rwp->params.language); + rwp->params.language = NULL; if (value == Qfalse || value == Qnil) { - rwp->params.language = "auto"; + rwp->params.language = ruby_strdup("auto"); } else { - rwp->params.language = StringValueCStr(value); + rwp->params.language = ruby_strdup(StringValueCStr(value)); } return value; } @@ -608,7 +667,13 @@ ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) { ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); - rwp->params.initial_prompt = StringValueCStr(value); + ruby_xfree((void *)rwp->params.initial_prompt); + rwp->params.initial_prompt = NULL; + if (NIL_P(value)) { + rwp->params.initial_prompt = NULL; + } else { + rwp->params.initial_prompt = ruby_strdup(StringValueCStr(value)); + } return value; } /* @@ -1103,12 +1168,14 @@ ruby_whisper_params_set_vad_model_path(VALUE self, VALUE value) { ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + ruby_xfree((void *)rwp->params.vad_model_path); + rwp->params.vad_model_path = NULL; if (NIL_P(value)) { rwp->params.vad_model_path = NULL; return value; } VALUE path = ruby_whisper_normalize_model_path(value); - rwp->params.vad_model_path = StringValueCStr(path); + rwp->params.vad_model_path = ruby_strdup(StringValueCStr(path)); return value; } diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index c00fbcd1def..3d00566009a 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -16,7 +16,7 @@ extern ID id_to_path; extern ID transcribe_option_names[1]; extern void -prepare_transcription(ruby_whisper_params * rwp, VALUE * self); +prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors); /* * transcribe a single file @@ -73,7 +73,7 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { // rwp->params.encoder_begin_callback_user_data = &is_aborted; // } - prepare_transcription(rwp, &self); + prepare_transcription(rwp, &self, n_processors); if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), n_processors) != 0) { fprintf(stderr, "failed to process audio\n"); diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 9ade451c6b2..3c59661975b 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -37,7 +37,7 @@ module Whisper def self.lang_id: (string name) -> Integer def self.lang_str: (Integer id) -> String def self.lang_str_full: (Integer id) -> String - def self.log_set: (log_callback, Object? user_data) -> log_callback + def self.log_set: (log_callback?, Object? user_data) -> log_callback def self.system_info_str: () -> String class Context @@ -52,6 +52,9 @@ module Whisper # puts text # end # + # If n_processors is greater than 1, you cannot set any callbacks including + # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, + # and log_callback set by Whisper.log_set def transcribe: (path, Params, ?n_processors: Integer) -> self | (path, Params, ?n_processors: Integer) { (String) -> void } -> self @@ -129,6 +132,9 @@ module Whisper # It seems this approach can offer some speedup in some cases. # However, the transcription accuracy can be worse at the beginning and end of each chunk. # + # If n_processors is greater than 1, you cannot set any callbacks including + # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, + # and log_callback set by Whisper.log_set def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self | (Params, _Samples, ?Integer n_samples) -> self | (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self diff --git a/bindings/ruby/test/test_params.rb b/bindings/ruby/test/test_params.rb index 094dba6f48e..ff5c28e9043 100644 --- a/bindings/ruby/test/test_params.rb +++ b/bindings/ruby/test/test_params.rb @@ -46,6 +46,8 @@ def setup def test_language @params.language = "en" assert_equal @params.language, "en" + GC.compact + assert_equal @params.language, "en" @params.language = "auto" assert_equal @params.language, "auto" end diff --git a/bindings/ruby/test/test_whisper.rb b/bindings/ruby/test/test_whisper.rb index 29071210072..f7e25239d5d 100644 --- a/bindings/ruby/test/test_whisper.rb +++ b/bindings/ruby/test/test_whisper.rb @@ -43,9 +43,20 @@ def test_transcribe_n_processors @whisper = Whisper::Context.new("base.en") params = Whisper::Params.new - @whisper.transcribe(AUDIO, params, n_processors: 4) {|text| - assert_match(/what you can do for your country/i, text) - } + without_log_callback do + @whisper.transcribe(AUDIO, params, n_processors: 4) {|text| + assert_match(/what you can do for your country/i, text) + } + end + end + + private + + def without_log_callback + Whisper.log_set nil, nil + yield + ensure + Whisper.log_set ->(level, buffer, user_data) {}, nil end sub_test_case "After transcription" do @@ -229,7 +240,9 @@ def test_full_with_memroy_view_gc def test_full_parallel nprocessors = 2 - @whisper.full_parallel(@params, @samples, @samples.length, nprocessors) + without_log_callback do + @whisper.full_parallel(@params, @samples, @samples.length, nprocessors) + end assert_equal nprocessors, @whisper.full_n_segments text = @whisper.each_segment.collect(&:text).join @@ -240,7 +253,9 @@ def test_full_parallel def test_full_parallel_with_memory_view nprocessors = 2 samples = JFKReader.new(AUDIO) - @whisper.full_parallel(@params, samples, nil, nprocessors) + without_log_callback do + @whisper.full_parallel(@params, samples, nil, nprocessors) + end assert_equal nprocessors, @whisper.full_n_segments text = @whisper.each_segment.collect(&:text).join @@ -259,7 +274,9 @@ def test_full_parallel_without_length_and_n_processors def test_full_parallel_without_length nprocessors = 2 - @whisper.full_parallel(@params, @samples, nil, nprocessors) + without_log_callback do + @whisper.full_parallel(@params, @samples, nil, nprocessors) + end assert_equal nprocessors, @whisper.full_n_segments text = @whisper.each_segment.collect(&:text).join diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 88b94e7eb8a..2d952222f29 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -3,7 +3,7 @@ require_relative "extsources" Gem::Specification.new do |s| s.name = "whispercpp" s.authors = ["Georgi Gerganov", "Todd A. Fisher"] - s.version = '1.3.6' + s.version = '1.3.7' s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby} s.email = 'todd.fisher@gmail.com' s.extra_rdoc_files = ['LICENSE', 'README.md'] From 1335dfa785af56c55bf510275f86f795db2fe474 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sun, 15 Mar 2026 19:10:15 +0100 Subject: [PATCH 301/831] sycl : fix for untransposed GDA recurrent state (llama/20583) --- ggml/src/ggml-sycl/gated_delta_net.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp index 8c76afbd571..648455c134b 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.cpp +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -55,7 +55,7 @@ void gated_delta_net_sycl(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = curr_state[i * S_v + col]; + s_shard[r] = curr_state[col * S_v + i]; } for (int t = 0; t < n_tokens; t++) { @@ -137,7 +137,7 @@ void gated_delta_net_sycl(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - state[i * S_v + col] = s_shard[r]; + state[col * S_v + i] = s_shard[r]; } } From dae7781052d858a38d5d57eb2b252364d6e2c6d0 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 16 Mar 2026 11:41:45 +0800 Subject: [PATCH 302/831] CUDA: GDN hide memory latency (llama/20537) --- ggml/src/ggml-cuda/gated_delta_net.cu | 32 ++++++++++++++++++--------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 1ce6d5f31b5..6b44bec7317 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,7 +1,8 @@ #include "gated_delta_net.cuh" template -__global__ void gated_delta_net_cuda(const float * q, +__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) +gated_delta_net_cuda(const float * q, const float * k, const float * v, const float * g, @@ -38,7 +39,7 @@ __global__ void gated_delta_net_cuda(const float * q, const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; state += state_offset; - curr_state += state_offset; + curr_state += state_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; @@ -46,10 +47,11 @@ __global__ void gated_delta_net_cuda(const float * q, constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size; float s_shard[rows_per_lane]; // state is stored transposed: M[col][i] = S[i][col], row col is contiguous + #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = curr_state[col * S_v + i]; + s_shard[r] = curr_state[i]; } for (int t = 0; t < n_tokens; t++) { @@ -63,6 +65,16 @@ __global__ void gated_delta_net_cuda(const float * q, const float beta_val = *beta_t; + // Cache k and q in registers + float k_reg[rows_per_lane]; + float q_reg[rows_per_lane]; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + k_reg[r] = k_t[i]; + q_reg[r] = q_t[i]; + } + if constexpr (!KDA) { const float g_val = expf(*g_t); @@ -70,8 +82,7 @@ __global__ void gated_delta_net_cuda(const float * q, float kv_shard = 0.0f; #pragma unroll for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - kv_shard += s_shard[r] * k_t[i]; + kv_shard += s_shard[r] * k_reg[r]; } float kv_col = warp_reduce_sum(kv_shard); @@ -83,9 +94,8 @@ __global__ void gated_delta_net_cuda(const float * q, float attn_partial = 0.0f; #pragma unroll for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col; - attn_partial += s_shard[r] * q_t[i]; + s_shard[r] = g_val * s_shard[r] + k_reg[r] * delta_col; + attn_partial += s_shard[r] * q_reg[r]; } float attn_col = warp_reduce_sum(attn_partial); @@ -99,7 +109,7 @@ __global__ void gated_delta_net_cuda(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - kv_shard += expf(g_t[i]) * s_shard[r] * k_t[i]; + kv_shard += expf(g_t[i]) * s_shard[r] * k_reg[r]; } float kv_col = warp_reduce_sum(kv_shard); @@ -113,8 +123,8 @@ __global__ void gated_delta_net_cuda(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = expf(g_t[i]) * s_shard[r] + k_t[i] * delta_col; - attn_partial += s_shard[r] * q_t[i]; + s_shard[r] = expf(g_t[i]) * s_shard[r] + k_reg[r] * delta_col; + attn_partial += s_shard[r] * q_reg[r]; } float attn_col = warp_reduce_sum(attn_partial); From 724ea71cf97e4887809093c7fc75b7bfe34506f4 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Mon, 16 Mar 2026 10:45:49 +0100 Subject: [PATCH 303/831] vulkan: fix flash attention dot product precision (llama/20589) --- ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index ec48f5b1152..11b7dce8578 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -245,7 +245,7 @@ void main() { #endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf)); + Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf)); } } } @@ -270,7 +270,7 @@ void main() { #endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf)); + Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf)); } } } From 9232af59ba3cf8ed9050d4a7229fd8a30e9eeeb8 Mon Sep 17 00:00:00 2001 From: Martin Klacer Date: Mon, 16 Mar 2026 19:25:54 +0000 Subject: [PATCH 304/831] kleidiai: add data type check to get_tensor_traits (llama/20639) * kleidiai: add data type check to get_tensor_traits * Added check for F16 data type into get_tensor_traits path with input data not in ggml_backend_cpu_kleidiai_buffer_type format (unsupported for Q4/8) Signed-off-by: Martin Klacer Change-Id: I9aca4b9b8d669d35db6f1dbcc4e080b1919b1de7 * updated ggml/src/ggml-cpu/kleidiai/kleidiai.cpp updated kleidiai.cpp file as per suggestion Co-authored-by: Georgi Gerganov --------- Signed-off-by: Martin Klacer Co-authored-by: Georgi Gerganov --- ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 9bcc18d442c..7a5924944a8 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -1473,10 +1473,12 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } else { + if (op->src[0]->type != GGML_TYPE_F16) { + return nullptr; + } std::array kernel_chain; const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain); - const bool has_kernel = slot_total > 0; - if (has_kernel && op->src[1]->ne[1] > 1) { + if (slot_total > 0 && op->src[1]->ne[1] > 1) { if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { return nullptr; From 64942511978ccb23975c17b44513fa7fe858b8a2 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Tue, 17 Mar 2026 10:01:52 +0800 Subject: [PATCH 305/831] ehance UPSCALE to support all UT cases (llama/20637) * [SYCL] ehance UPSCALE to support more cases * rm test case result of SYCL1 --- ggml/src/ggml-sycl/backend.hpp | 4 +- ggml/src/ggml-sycl/element_wise.cpp | 89 ------ ggml/src/ggml-sycl/element_wise.hpp | 2 - ggml/src/ggml-sycl/ggml-sycl.cpp | 4 +- ggml/src/ggml-sycl/upscale.cpp | 410 ++++++++++++++++++++++++++++ ggml/src/ggml-sycl/upscale.hpp | 9 + 6 files changed, 423 insertions(+), 95 deletions(-) create mode 100644 ggml/src/ggml-sycl/upscale.cpp create mode 100644 ggml/src/ggml-sycl/upscale.hpp diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index b30b7f2beb7..a526d8e58bc 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -24,6 +24,7 @@ #include "dmmv.hpp" #include "element_wise.hpp" #include "fattn.hpp" +#include "gated_delta_net.hpp" #include "gla.hpp" #include "im2col.hpp" #include "mmq.hpp" @@ -31,6 +32,7 @@ #include "norm.hpp" #include "outprod.hpp" #include "pad.hpp" +#include "pad_reflect_1d.hpp" #include "quantize.hpp" #include "quants.hpp" #include "roll.hpp" @@ -39,8 +41,8 @@ #include "ssm_conv.hpp" #include "softmax.hpp" #include "tsembd.hpp" +#include "upscale.hpp" #include "wkv.hpp" -#include "pad_reflect_1d.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index acd51bf45b2..ec0247528c4 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -294,30 +294,6 @@ static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl: } } -template -static void upscale(const T *x, T *dst, const int nb00, const int nb01, - const int nb02, const int nb03, const int ne10, const int ne11, - const int ne12, const int ne13, const float sf0, const float sf1, - const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) { - int index = item_ct1.get_local_id(0) + - item_ct1.get_group(0) * item_ct1.get_local_range(0); - if (index >= ne10 * ne11 * ne12 * ne13) { - return; - } - // operation - int i10 = index % ne10; - int i11 = (index / ne10) % ne11; - int i12 = (index / (ne10 * ne11)) % ne12; - int i13 = (index / (ne10 * ne11 * ne12)) % ne13; - - int i00 = static_cast(i10 / sf0); - int i01 = static_cast(i11 / sf1); - int i02 = static_cast(i12 / sf2); - int i03 = static_cast(i13 / sf3); - - dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); -} - template static void clamp(const T * x, T * dst, const float min, const float max, const int k, const sycl::nd_item<1> &item_ct1) { @@ -392,20 +368,6 @@ static void arange_kernel(T * dst, const int k, T start, T step, } } -template -static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, - const int nb02, const int nb03, const int ne10, const int ne11, - const int ne12, const int ne13, const float sf0, const float sf1, - const float sf2, const float sf3, queue_ptr stream) { - int dst_size = ne10 * ne11 * ne12 * ne13; - int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE); - sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); - stream->parallel_for( - sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { - upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); - }); -} - template static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); @@ -505,42 +467,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c } } -template -static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - - GGML_ASSERT(dst->src[0]->type == dst->type); - - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0]; - const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1]; - const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2]; - const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3]; - switch (dst->type) { - case GGML_TYPE_F16: - { - auto data_pts = cast_data(dst); - kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2], - (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3, - main_stream, std::forward(args)...); - break; - } - case GGML_TYPE_F32: - { - auto data_pts = cast_data(dst); - kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2], - (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3, - main_stream, std::forward(args)...); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - template static inline void ggml_sycl_op_unary( ggml_backend_sycl_context & ctx, ggml_tensor * dst, F func) { @@ -784,15 +710,6 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor }); } -static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst, - [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03, - int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3, - queue_ptr stream) { - ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream); - }); -} - static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { float min_val; float max_val; @@ -1131,12 +1048,6 @@ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_sqr(ctx, dst); } -void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); - ggml_sycl_op_upscale(ctx, dst); -} - - void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_clamp(ctx, dst); diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index 7c71974687a..997132166ab 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -71,8 +71,6 @@ void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst); -void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst); - void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 12819705849..2ec1421841b 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -44,7 +44,6 @@ #include "ggml-sycl/backend.hpp" #include "ggml-sycl/common.hpp" #include "ggml-sycl/element_wise.hpp" -#include "ggml-sycl/gated_delta_net.hpp" #include "ggml-sycl/gemm.hpp" #include "ggml-sycl/getrows.hpp" #include "ggml-sycl/norm.hpp" @@ -4863,9 +4862,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: case GGML_OP_IM2COL: - return true; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); + return true; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: diff --git a/ggml/src/ggml-sycl/upscale.cpp b/ggml/src/ggml-sycl/upscale.cpp new file mode 100644 index 00000000000..18c743de447 --- /dev/null +++ b/ggml/src/ggml-sycl/upscale.cpp @@ -0,0 +1,410 @@ +#include "upscale.hpp" + +static void upscale_f32(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int ne13, + const float sf0, const float sf1, const float sf2, const float sf3) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + int index = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (index >= ne10 * ne11 * ne12 * ne13) { + return; + } + + int i10 = index % ne10; + int i11 = (index / ne10) % ne11; + int i12 = (index / (ne10 * ne11)) % ne12; + int i13 = (index / (ne10 * ne11 * ne12)) % ne13; + + int i00 = i10 / sf0; + int i01 = i11 / sf1; + int i02 = i12 / sf2; + int i03 = i13 / sf3; + + dst[index] = *((const float*)((const char*)x + i03 * nb03 + i02 * nb02 + + i01 * nb01 + i00 * nb00)); +} + +static void upscale_f32_bilinear(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t index = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; + int y0_src = (int) sycl::floor((float) y_src_f); + int y1_src = y0_src + 1; + + y0_src = sycl::max(0, sycl::min(y0_src, ne01_src - 1)); + y1_src = sycl::max(0, sycl::min(y1_src, ne01_src - 1)); + + float dy = y_src_f - (float)y0_src; + dy = sycl::max(0.0f, sycl::min(dy, 1.0f)); + + float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + int x0_src = (int) sycl::floor(x_src_f); + int x1_src = x0_src + 1; + + x0_src = sycl::max(0, sycl::min(x0_src, ne00_src - 1)); + x1_src = sycl::max(0, sycl::min(x1_src, ne00_src - 1)); + + float dx = x_src_f - (float)x0_src; + dx = sycl::max(0.0f, sycl::min(dx, 1.0f)); + + const float* p_a = + (const float*)((const char*)x + (int64_t)x0_src * nb00 + + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + const float* p_b = + (const float*)((const char*)x + (int64_t)x1_src * nb00 + + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + const float* p_c = + (const float*)((const char*)x + (int64_t)x0_src * nb00 + + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + const float* p_d = + (const float*)((const char*)x + (int64_t)x1_src * nb00 + + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + + const float val_a = *p_a; + const float val_b = *p_b; + const float val_c = *p_c; + const float val_d = *p_d; + + float result = val_a * (1.0f - dx) * (1.0f - dy) + + val_b * dx * (1.0f - dy) + + val_c * (1.0f - dx) * dy + + val_d * dx * dy; + + dst[index] = result; +} + +// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) +// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp +static void upscale_f32_bilinear_antialias(const float * src0, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne00_src, + const int ne01_src, + const int ne10_dst, + const int ne11_dst, + const int ne12_dst, + const int ne13_dst, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + const float pixel_offset) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t index = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y = ((float)i11_dst + pixel_offset) / sf1; + const float x = ((float)i10_dst + pixel_offset) / sf0; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = sycl::max(1.0f / sf1, 1.0f); + const float invscale1 = 1.0f / support1; + const float support0 = sycl::max(1.0f / sf0, 1.0f); + const float invscale0 = 1.0f / support0; + + // the range of source pixels that contribute + const int64_t x_min = sycl::max(int64_t(0), int64_t(x - support0 + pixel_offset)); + const int64_t x_max = sycl::min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset)); + const int64_t y_min = sycl::max(int64_t(0), int64_t(y - support1 + pixel_offset)); + const int64_t y_max = sycl::min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset)); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + auto triangle_filter = [](float x) -> float { + return sycl::max(1.0f - sycl::fabs(x), 0.0f); + }; + + for (int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = + *(const float*)((const char*)src0 + sx * nb00 + sy * nb01 + + i02_src * nb02 + i03_src * nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + dst[index] = val; +} + +namespace bicubic_interpolation { +static float weight1(float x, const float &a) { return ((a + 2) * x - (a + 3)) * x * x + 1; }; +static float weight2(float x, const float &a) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; }; + +static float bicubic(float p0, float p1, float p2, float p3, float x, float a) { + const float w0 = weight2(x + 1, a); + const float w1 = weight1(x + 0, a); + const float w2 = weight1(1 - x, a); + const float w3 = weight2(2 - x, a); + return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3; +}; + +} + +static void upscale_f32_bicubic(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const float a = -0.75f; + using bicubic_interpolation::bicubic; + + const int64_t index = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + const int64_t dst_total_elements = + ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; + const int y0_src = (int) sycl::floor((float) y_src_f); + const float dy = y_src_f - (float)y0_src; + + const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + const int x0_src = (int) sycl::floor((float) x_src_f); + const float dx = x_src_f - (float)x0_src; + + const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03; + + auto load = [=](int x_off, int y_off) -> float { + int i00_src = sycl::max(0, sycl::min(x0_src + x_off, ne00_src - 1)); + int i01_src = sycl::max(0, sycl::min(y0_src + y_off, ne01_src - 1)); + return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01); + }; + + const float result = bicubic( + bicubic(load(-1, -1), load(0, -1), load(1, -1), load(2, -1), dx, a), + bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx, a), + bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx, a), + bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx, a), + dy, + a); + + dst[index] = result; +} + +static void upscale_f32_sycl(const float * x, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + dpct::queue_ptr stream) { + const int64_t dst_size = ne10 * ne11 * ne12 * ne13; + const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3); + }); +} + +static void upscale_f32_bilinear_sycl(const float * x, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne00_src, + const int ne01_src, + const int ne10_dst, + const int ne11_dst, + const int ne12_dst, + const int ne13_dst, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + const float pixel_offset, + bool antialias, + dpct::queue_ptr stream) { + const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + + if (antialias) { + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + upscale_f32_bilinear_antialias( + x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, + ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + }); + } else { + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + upscale_f32_bilinear( + x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, + ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + }); + } +} + +static void upscale_f32_bicubic_sycl(const float * x, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne00_src, + const int ne01_src, + const int ne10_dst, + const int ne11_dst, + const int ne12_dst, + const int ne13_dst, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + const float pixel_offset, + dpct::queue_ptr stream) { + const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + + { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + upscale_f32_bicubic( + x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, + ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + }); + }); + } +} + +void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int mode_flags = dst->op_params[0]; + const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF); + + float sf0 = (float)dst->ne[0]/src0->ne[0]; + float sf1 = (float)dst->ne[1]/src0->ne[1]; + float sf2 = (float)dst->ne[2]/src0->ne[2]; + const float sf3 = (float)dst->ne[3]/src0->ne[3]; + + float pixel_offset = 0.5f; + if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { + sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 + ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) + : sf0; + sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 + ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) + : sf1; + pixel_offset = 0.0f; + } + + if (mode == GGML_SCALE_MODE_NEAREST) { + upscale_f32_sycl( + src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS); + upscale_f32_bilinear_sycl( + src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + sf0, sf1, sf2, sf3, pixel_offset, antialias, stream); + } else if (mode == GGML_SCALE_MODE_BICUBIC) { + upscale_f32_bicubic_sycl( + src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + sf0, sf1, sf2, sf3, pixel_offset, stream); + } +} + +void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_upscale(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/upscale.hpp b/ggml/src/ggml-sycl/upscale.hpp new file mode 100644 index 00000000000..c36c1bdc970 --- /dev/null +++ b/ggml/src/ggml-sycl/upscale.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include +#include "dpct/helper.hpp" +#include "common.hpp" + +#define SYCL_UPSCALE_BLOCK_SIZE 256 + +void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst); From 49adc8b470cb1c09d01d313b6fc1859d43658158 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 17 Mar 2026 10:09:59 +0100 Subject: [PATCH 306/831] vulkan: allow graphics queue only through env var (llama/20599) * vulkan: avoid graphics queue on non-RADV AMD drivers * avoid graphics queues on small GPUs * change to only use graphics queue if overridden with env var GGML_VK_ALLOW_GRAPHICS_QUEUE * reenable transfer queue if graphics queue is not used --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7092361d2ea..e9b6778d628 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4981,8 +4981,9 @@ static vk_device ggml_vk_get_device(size_t idx) { std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues - // On AMD, the graphics queue seems to be faster, so don't avoid it - const vk::QueueFlagBits graphics_flag = device->vendor_id == VK_VENDOR_ID_AMD ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics; + // Allow overriding avoiding the graphics queue because it can increase performance on RADV + const bool allow_graphics_queue = (getenv("GGML_VK_ALLOW_GRAPHICS_QUEUE") != nullptr); + const vk::QueueFlagBits graphics_flag = allow_graphics_queue ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics; const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, graphics_flag, -1, 1); const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | graphics_flag, compute_queue_family_index, 1); @@ -5443,11 +5444,14 @@ static vk_device ggml_vk_get_device(size_t idx) { ggml_vk_load_shaders(device); + // Only use transfer queue on AMD non-GCN, when the graphics queue is not enabled + const bool prefers_transfer_queue = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !allow_graphics_queue; + if (!device->single_queue) { const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); - device->async_use_transfer_queue = (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr); + device->async_use_transfer_queue = prefers_transfer_queue || (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr); } else { // TODO: Use pointer or reference to avoid copy device->transfer_queue.copyFrom(device->compute_queue); From ab7d305b751ddd8c50f3beeddb95eaf02d19741a Mon Sep 17 00:00:00 2001 From: Justin Bradford Date: Tue, 17 Mar 2026 05:03:54 -0700 Subject: [PATCH 307/831] kleidiai : fix MUL_MAT support for batched (3D) inputs (llama/20620) * kleidiai : fix MUL_MAT support for batched (3D) inputs The supports_op() check incorrectly rejected MUL_MAT operations with 3D inputs (ne[2] > 1), but the actual compute_forward_qx() implementation handles batched inputs correctly via a loop over ne12. This caused models with Q4_0/Q8_0 weights to crash during graph scheduling when n_seq_max > 1, because weights were placed in KLEIDIAI buffers during loading (tested with 2D inputs) but the runtime used 3D inputs. Also relax the buffer check to allow supports_op() to be called during weight loading when src[0]->buffer is NULL. Fixes #20608 * Kleidiai support_ops should only return true for 3D inputs, not also 4D --- ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 7a5924944a8..0ecf7ae02ac 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -1461,7 +1461,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { return false; } if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) && - ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) { + ggml_ne(op->src[1], 3) == 1) { return true; } } From 0ad6ceef59777414829eb8167e4e12022c03be21 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 17 Mar 2026 14:27:23 +0100 Subject: [PATCH 308/831] vulkan: async and event fixes (llama/20518) * vulkan: fix event wait submission, event command buffer reset * fix event command buffer reset validation error * also reset command buffers before reuse * use timeline semaphores instead of fences for event_synchronize * don't use initializer list for semaphore wait info * use multiple events to avoid reset issues * fix event reuse issue with multiple vectors * add semaphore wait condition also if compute_ctx already exists * remove event pending stage --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 128 ++++++++++++++++++++------- 1 file changed, 95 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e9b6778d628..3d8ce10676e 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -191,6 +191,7 @@ struct vk_queue; struct vk_command_buffer { vk::CommandBuffer buf; + uint64_t use_counter = 0; bool in_use = false; }; @@ -938,19 +939,24 @@ struct vk_subbuffer { } }; -// vk_event is used for the event-related backend interfaces. It uses 'event' for -// event_wait and 'fence' for event_synchronize. Polling on an event for +struct vk_semaphore { + vk::Semaphore s; + uint64_t value; +}; + +// vk_event is used for the event-related backend interfaces. It uses vk::Events for +// event_wait and a timeline semaphore for event_synchronize. Polling on an event for // event_synchronize wouldn't be sufficient to wait for command buffers to complete, // and would lead to validation errors. struct vk_event { + std::vector events_free; // Events available for reuse + std::vector events_submitted; // Events that are fully submitted and can be reused on next synchronize vk::Event event; - vk::Fence fence; - vk_command_buffer* cmd_buffer = nullptr; -}; + bool has_event; -struct vk_semaphore { - vk::Semaphore s; - uint64_t value; + vk_semaphore tl_semaphore; + vk_command_buffer* cmd_buffer = nullptr; + uint64_t cmd_buffer_use_counter = 0; }; struct vk_submission { @@ -2319,7 +2325,7 @@ static vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, vk_comman vk::CommandBufferLevel::ePrimary, 1); const std::vector cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info); - p.cmd_buffers.push_back({ cmd_buffers.front(), true }); + p.cmd_buffers.push_back({ cmd_buffers.front(), 0, true }); return &p.cmd_buffers[p.cmd_buffers.size()-1]; } @@ -2788,6 +2794,15 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct ); } +static void ggml_vk_reset_event(vk_context& ctx, vk::Event& event) { + VK_LOG_DEBUG("ggml_vk_set_event()"); + + ctx->s->buffer->buf.resetEvent( + event, + ctx->p->q->stage_flags + ); +} + static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) { VK_LOG_DEBUG("ggml_vk_set_event()"); @@ -6396,6 +6411,7 @@ static vk_subbuffer ggml_vk_tensor_subbuffer( static vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) { for (auto& cmd_buffer : pool.cmd_buffers) { if (!cmd_buffer.in_use) { + cmd_buffer.use_counter++; cmd_buffer.in_use = true; return &cmd_buffer; } @@ -6500,14 +6516,15 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { } static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) { + vk_context result; if (!ctx->compute_ctx.expired()) { - return ctx->compute_ctx.lock(); - } - - vk_context result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + result = ctx->compute_ctx.lock(); + } else { + result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = result; - ggml_vk_ctx_begin(ctx->device, result); + ctx->compute_ctx = result; + ggml_vk_ctx_begin(ctx->device, result); + } if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) { result->s->wait_semaphores.push_back(ctx->transfer_semaphore); @@ -13801,6 +13818,7 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { ctx->submit_pending = false; if (cmd_buf) { cmd_buf->in_use = false; + cmd_buf->buf.reset(); } } @@ -14862,18 +14880,31 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); auto* cmd_buf = compute_ctx->s->buffer; // retrieve pointer before it gets reset - // the backend interface doesn't have an explicit reset, so reset it here - // before we record the command to set it - ctx->device->device.resetEvent(vkev->event); - ctx->device->device.resetFences({ vkev->fence }); + if (vkev->has_event) { + // Move existing event into submitted + vkev->events_submitted.push_back(vkev->event); + } + + // Grab the next event and record it, create one if necessary + if (vkev->events_free.empty()) { + vkev->event = ctx->device->device.createEvent({}); + } else { + vkev->event = vkev->events_free.back(); + vkev->events_free.pop_back(); + } + + vkev->has_event = true; ggml_vk_set_event(compute_ctx, vkev->event); + vkev->tl_semaphore.value++; + compute_ctx->s->signal_semaphores.push_back(vkev->tl_semaphore); ggml_vk_ctx_end(compute_ctx); - ggml_vk_submit(compute_ctx, {vkev->fence}); + ggml_vk_submit(compute_ctx, {}); ctx->submit_pending = true; vkev->cmd_buffer = cmd_buf; + vkev->cmd_buffer_use_counter = cmd_buf->use_counter; ctx->compute_ctx.reset(); } @@ -14884,9 +14915,10 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); - ggml_vk_wait_events(compute_ctx, {vkev->event}); - ggml_vk_ctx_end(compute_ctx); - ctx->compute_ctx.reset(); + if (vkev->has_event) { + // Wait for latest event + ggml_vk_wait_events(compute_ctx, { vkev->event }); + } } // TODO: enable async and synchronize @@ -15676,10 +15708,13 @@ static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t return nullptr; } - // The event/fence is expected to initially be in the signaled state. - vkev->event = device->device.createEvent({}); - vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled}); - device->device.setEvent(vkev->event); + // No events initially, they get created on demand + vkev->has_event = false; + + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + vkev->tl_semaphore = { device->device.createSemaphore(ci), 0 }; return new ggml_backend_event { /* .device = */ dev, @@ -15693,8 +15728,16 @@ static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backe vk_event *vkev = (vk_event *)event->context; - device->device.destroyFence(vkev->fence); - device->device.destroyEvent(vkev->event); + device->device.destroySemaphore(vkev->tl_semaphore.s); + for (auto& event : vkev->events_free) { + device->device.destroyEvent(event); + } + for (auto& event : vkev->events_submitted) { + device->device.destroyEvent(event); + } + if (vkev->has_event) { + device->device.destroyEvent(vkev->event); + } delete vkev; delete event; } @@ -15705,10 +15748,29 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm auto device = ggml_vk_get_device(ctx->device); vk_event *vkev = (vk_event *)event->context; - VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize"); - // Finished using current command buffer so we flag for reuse - if (vkev->cmd_buffer) { - vkev->cmd_buffer->in_use = false; + // Only do something if the event has actually been used + if (vkev->has_event) { + vk::Semaphore sem = vkev->tl_semaphore.s; + uint64_t val = vkev->tl_semaphore.value; + vk::SemaphoreWaitInfo swi{vk::SemaphoreWaitFlags{}, sem, val}; + VK_CHECK(device->device.waitSemaphores(swi, UINT64_MAX), "event_synchronize"); + + // Reset and move submitted events + for (auto& event : vkev->events_submitted) { + device->device.resetEvent(event); + } + vkev->events_free.insert(vkev->events_free.end(), vkev->events_submitted.begin(), vkev->events_submitted.end()); + vkev->events_submitted.clear(); + + // Finished using current command buffer so we flag for reuse + if (vkev->cmd_buffer) { + // Only flag for reuse if it hasn't been reused already + if (vkev->cmd_buffer_use_counter == vkev->cmd_buffer->use_counter) { + vkev->cmd_buffer->in_use = false; + vkev->cmd_buffer->buf.reset(); + } + vkev->cmd_buffer = nullptr; + } } } From c890a9d9b4f6ad9c9a75387a0b0d3c973ad7f4ca Mon Sep 17 00:00:00 2001 From: Taimur Ahmad Date: Tue, 17 Mar 2026 19:03:40 +0500 Subject: [PATCH 309/831] ggml-cpu: fix RVV checks in quants and repacking (llama/20682) * ggml-cpu: refactor quants.c; add rvv check * ggml-cpu: refactor; disable generic fallback --- ggml/src/ggml-cpu/arch/riscv/quants.c | 40 +++++++++++++++++-------- ggml/src/ggml-cpu/arch/riscv/repack.cpp | 40 ++++--------------------- ggml/src/ggml-cpu/repack.cpp | 3 ++ 3 files changed, 35 insertions(+), 48 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index 826055dd9a4..d7e9ba46348 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -115,10 +115,10 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { assert(k % QK_K == 0); - block_q8_K * y_blocks = (block_q8_K *)y; size_t nb = k / QK_K; #if defined(__riscv_v_intrinsic) + block_q8_K * y_blocks = (block_q8_K *)y; const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8(); for (size_t i = 0; i < nb; i++) { @@ -2052,6 +2052,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2147,6 +2148,7 @@ static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t *s = sumf; } +#endif void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -2163,6 +2165,7 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2269,6 +2272,7 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t *s = sumf; } +#endif void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -2285,6 +2289,7 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +#if defined __riscv_v_intrinsic static const uint8_t sign_gather_indices_arr[64] = { 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3, 4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7 @@ -2488,6 +2493,7 @@ static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t } *s = 0.125f * sumf; } +#endif void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -2507,7 +2513,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined(__riscv_v) +#if defined(__riscv_v_intrinsic) static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, @@ -2542,7 +2548,6 @@ static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, }; -#endif static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); @@ -2618,6 +2623,7 @@ static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ } *s = 0.125f * sumf; } +#endif void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -2634,6 +2640,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2818,6 +2825,7 @@ static void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size } *s = 0.125f * sumf; } +#endif void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -2830,10 +2838,11 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const break; } #else - ggml_vec_dot_iq2_xxs_q8_K(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); @@ -2928,6 +2937,7 @@ static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t } *s = sumf; } +#endif void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -2944,6 +2954,7 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -3036,6 +3047,7 @@ static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size } *s = 0.25f * sumf; } +#endif void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3052,6 +3064,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -3161,6 +3174,7 @@ static void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_ *s = sumf; } +#endif void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3177,6 +3191,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -3190,7 +3205,6 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ const int nb = n / QK_K; -#if defined __riscv_v_intrinsic const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); float sumf = 0; int acc[4]; @@ -3252,14 +3266,8 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ } *s = sumf; - -#else - UNUSED(x); - UNUSED(y); - UNUSED(nb); - ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif } +#endif void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3276,6 +3284,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -3381,6 +3390,7 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t *s = sumf; } +#endif void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3397,6 +3407,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -3467,6 +3478,7 @@ static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t *s = sumf; } +#endif void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3483,6 +3495,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } +#if defined __riscv_v_intrinsic static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -3592,6 +3605,7 @@ static void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t *s = sumf; } +#endif void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic @@ -3604,6 +3618,6 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo break; } #else - return ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index cd5807879ea..c37488cae54 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -107,8 +107,7 @@ void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR } #else UNUSED(nb); - UNUSED(y); - ggml_quantize_mat_q8_0_4x4_generic(x, vy, k); + ggml_quantize_mat_q8_0_4x8_generic(x, vy, k); #endif } @@ -203,6 +202,7 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -222,7 +222,6 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); @@ -256,9 +255,6 @@ void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); } - return; -#endif - ggml_gemv_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -280,7 +276,6 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic const block_q8_K * a_ptr = (const block_q8_K *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -392,9 +387,6 @@ void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); } - return; -#endif - ggml_gemv_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -416,7 +408,6 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -451,9 +442,6 @@ void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); } - return; -#endif - ggml_gemv_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -476,7 +464,6 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(blocklen); UNUSED(bs); -#if defined __riscv_v_intrinsic const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); @@ -505,9 +492,6 @@ void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); } - return; -#endif - ggml_gemv_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -679,9 +663,9 @@ void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } // End K-Block __riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl); - } } +#endif void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -909,6 +893,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -929,7 +914,6 @@ void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -994,9 +978,6 @@ void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); } } - return; -#endif - ggml_gemm_q4_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1019,7 +1000,6 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic for (int y = 0; y < nr / 4; y++) { const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -1267,9 +1247,6 @@ void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); } } - return; -#endif - ggml_gemm_q4_K_16x1_q8_K_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1292,7 +1269,6 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); @@ -1355,9 +1331,6 @@ void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); } } - return; -#endif - ggml_gemm_iq4_nl_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1380,7 +1353,6 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v UNUSED(ncols_interleaved); UNUSED(blocklen); -#if defined __riscv_v_intrinsic for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { @@ -1429,9 +1401,6 @@ void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); } } - return; -#endif - ggml_gemm_q8_0_16x1_q8_0_generic(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1731,3 +1700,4 @@ void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } } } +#endif diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 6b76ab3bfb1..f18758f16bb 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -1365,6 +1365,7 @@ void ggml_gemv_q8_0_4x8_q8_0_generic(int n, } } +// Only enable these for RISC-V. #if defined __riscv_zvfh void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; @@ -1568,6 +1569,7 @@ void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, assert(nc % 16 == 0); UNUSED(bs); + UNUSED(nr); const int nb = n / QK_K; const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; @@ -2381,6 +2383,7 @@ void ggml_gemm_q8_0_4x8_q8_0_generic(int n, } } +// Only enable these for RISC-V. #if defined __riscv_zvfh void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; From 906aef3da84894d4b59e4f94d5fe69bc6fd0f01b Mon Sep 17 00:00:00 2001 From: Kevin Hannon Date: Tue, 17 Mar 2026 13:16:49 -0400 Subject: [PATCH 310/831] ggml-blas: set mkl threads from thread context (llama/20602) * ggml blas: set mkl threads from thread context * add code to run blas locally --- ggml/src/ggml-blas/ggml-blas.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index 5de64b816fc..e7a1763b54d 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -121,6 +121,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg bli_thread_set_num_threads(ctx->n_threads); #elif defined(GGML_BLAS_USE_NVPL) nvpl_blas_set_num_threads(ctx->n_threads); +#elif defined(GGML_BLAS_USE_MKL) + mkl_set_num_threads(ctx->n_threads); #endif for (int64_t i13 = 0; i13 < ne13; i13++) { From 16ca5e6fb130cd68d1d499db8b59361e3aba0db6 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Tue, 17 Mar 2026 21:51:43 +0100 Subject: [PATCH 311/831] vulkan: disable mmvq on Intel Windows driver (llama/20672) * vulkan: disable mmvq on Intel Windows driver * improve comment --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3d8ce10676e..3e36435d166 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7646,20 +7646,14 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return true; } case VK_VENDOR_ID_INTEL: - if (k < 2048) { + if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) { + // Intel Windows proprietary driver MMVQ performance is worse than fp16, see + // https://github.com/ggml-org/llama.cpp/issues/17628 return false; } - if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) { - // Intel Windows proprietary driver tuning - switch (src0_type) { - case GGML_TYPE_MXFP4: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - return false; - default: - return true; - } + if (k < 2048) { + return false; } switch (src0_type) { From e222814fc4bef846054b071b48bb54e89fcc00c5 Mon Sep 17 00:00:00 2001 From: Krishna Sridhar <99914379+srikris-sridhar@users.noreply.github.com> Date: Tue, 17 Mar 2026 15:34:36 -0700 Subject: [PATCH 312/831] hexagon: add neg, exp, sigmoid, softplus ops, cont, repeat ops (llama/20701) Add element-wise unary ops needed by Qwen 3.5's DeltaNet linear attention layers. These ops follow the existing unary-ops pattern with VTCM DMA double-buffering. - neg: negate via scale by -1.0 - exp: uses existing hvx_exp_f32 HVX intrinsics - sigmoid: uses existing hvx_sigmoid_f32_aa HVX intrinsics - softplus: log(1 + exp(x)) scalar fallback - CONT reuses the existing CPY infrastructure since making a tensor contiguous is equivalent to a same-type copy. - REPEAT implements tiled memory copy with multi-threaded execution via the worker pool, supporting f32 and f16 types. The kernel parallelizes across output rows and uses memcpy for each tile. Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 152 ++++++++++++++++++++--- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/htp-msg.h | 5 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/hvx-base.h | 2 + ggml/src/ggml-hexagon/htp/hvx-exp.h | 17 +-- ggml/src/ggml-hexagon/htp/hvx-sigmoid.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 45 +++++++ ggml/src/ggml-hexagon/htp/repeat-ops.c | 148 ++++++++++++++++++++++ ggml/src/ggml-hexagon/htp/softmax-ops.c | 2 +- ggml/src/ggml-hexagon/htp/unary-ops.c | 95 ++++++++++++++ 11 files changed, 441 insertions(+), 28 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/repeat-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 19917cb1140..4b8a16c3635 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2362,6 +2362,27 @@ static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs, return n_bufs; } +static inline size_t init_cont_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + // CONT is just a contiguous copy — reuse CPY op + req->op = HTP_OP_CPY; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + +static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_REPEAT; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { req->op = HTP_OP_GET_ROWS; @@ -2449,12 +2470,33 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf break; case GGML_OP_UNARY: - if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { + switch (ggml_get_unary_op(t)) { + case GGML_UNARY_OP_SILU: req->op = HTP_OP_UNARY_SILU; supported = true; - } else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) { + break; + case GGML_UNARY_OP_GELU: req->op = HTP_OP_UNARY_GELU; supported = true; + break; + case GGML_UNARY_OP_SIGMOID: + req->op = HTP_OP_UNARY_SIGMOID; + supported = true; + break; + case GGML_UNARY_OP_NEG: + req->op = HTP_OP_UNARY_NEG; + supported = true; + break; + case GGML_UNARY_OP_EXP: + req->op = HTP_OP_UNARY_EXP; + supported = true; + break; + case GGML_UNARY_OP_SOFTPLUS: + req->op = HTP_OP_UNARY_SOFTPLUS; + supported = true; + break; + default: + break; } break; @@ -2640,16 +2682,28 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; case GGML_OP_UNARY: - if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) || - (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) { - ggml_hexagon_dispatch_op(sess, node, flags); + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: + break; } break; case GGML_OP_GLU: - if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) || - (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) || - (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) { - ggml_hexagon_dispatch_op(sess, node, flags); + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: + break; } break; case GGML_OP_SOFT_MAX: @@ -2676,6 +2730,14 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_CONT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + + case GGML_OP_REPEAT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + case GGML_OP_ARGSORT: ggml_hexagon_dispatch_op(sess, node, flags); break; @@ -3006,6 +3068,39 @@ static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess, return true; } +static bool ggml_hexagon_supported_cont(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + GGML_UNUSED(sess); + const struct ggml_tensor * src0 = op->src[0]; + + // CONT is same-type only, supports f32 and f16 + if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false; + + return true; +} + +static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + GGML_UNUSED(sess); + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + // Support f32 and f16 + if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false; + + // src and dst must be the same type + if (src0->type != dst->type) return false; + + // dst dims must be multiples of src dims + if (dst->ne[0] % src0->ne[0] != 0) return false; + if (dst->ne[1] % src0->ne[1] != 0) return false; + if (dst->ne[2] % src0->ne[2] != 0) return false; + if (dst->ne[3] % src0->ne[3] != 0) return false; + + // require contiguous tensors (no transposition) + if (ggml_is_transposed(src0) || ggml_is_transposed(dst)) return false; + + return true; +} + static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { auto sess = static_cast(dev->context); @@ -3063,21 +3158,32 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; case GGML_OP_UNARY: - { - const auto unary_op = ggml_get_unary_op(op); - if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) { + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_SOFTPLUS: + supp = ggml_hexagon_supported_unary(sess, op); + break; + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: supp = ggml_hexagon_supported_activations(sess, op); - } - break; + break; + default: + break; } + break; case GGML_OP_GLU: - { - const auto glu_op = ggml_get_glu_op(op); - if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) { + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU: supp = ggml_hexagon_supported_activations(sess, op); - } - break; + break; + default: + break; } + break; case GGML_OP_ROPE: supp = ggml_hexagon_supported_rope(sess, op); break; @@ -3098,6 +3204,14 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cpy(sess, op); break; + case GGML_OP_CONT: + supp = ggml_hexagon_supported_cont(sess, op); + break; + + case GGML_OP_REPEAT: + supp = ggml_hexagon_supported_repeat(sess, op); + break; + case GGML_OP_ARGSORT: supp = ggml_hexagon_supported_argsort(sess, op); break; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 02d07a503d5..a490a2ce9a1 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -30,6 +30,7 @@ add_library(${HTP_LIB} SHARED set-rows-ops.c get-rows-ops.c cpy-ops.c + repeat-ops.c argsort-ops.c ssm-conv.c ) diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index 52dcc36d8f7..56bc5b622c5 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -53,6 +53,10 @@ enum htp_op { HTP_OP_RMS_NORM, HTP_OP_UNARY_SILU, HTP_OP_UNARY_GELU, + HTP_OP_UNARY_SIGMOID, + HTP_OP_UNARY_EXP, + HTP_OP_UNARY_NEG, + HTP_OP_UNARY_SOFTPLUS, HTP_OP_GLU_SWIGLU, HTP_OP_GLU_SWIGLU_OAI, HTP_OP_GLU_GEGLU, @@ -69,6 +73,7 @@ enum htp_op { HTP_OP_SQRT, HTP_OP_SUM_ROWS, HTP_OP_SSM_CONV, + HTP_OP_REPEAT, INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 2ef20936f1b..f643fdc340d 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -57,6 +57,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx); int op_set_rows(struct htp_ops_context * octx); int op_get_rows(struct htp_ops_context * octx); int op_cpy(struct htp_ops_context * octx); +int op_repeat(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index 578ca288fb6..3e6a8579b1f 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -3,6 +3,8 @@ #include #include +#include +#include #include "hex-utils.h" #include "hvx-types.h" diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.h b/ggml/src/ggml-hexagon/htp/hvx-exp.h index 44dfe232a3d..84e4836dc92 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-exp.h +++ b/ggml/src/ggml-hexagon/htp/hvx-exp.h @@ -3,6 +3,7 @@ #include #include +#include #include "hvx-base.h" #include "hvx-floor.h" @@ -16,8 +17,8 @@ #define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805 #define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408 #define EXP_ONE (0x3f800000) // 1.0 -#define EXP_RANGE_R (0x41a00000) // 20.0 -#define EXP_RANGE_L (0xc1a00000) // -20.0 +#define EXP_RANGE_R (0x42B16666) // 88.7 +#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN)) static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { HVX_Vector z_qf32_v; @@ -47,12 +48,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { HVX_Vector temp_v = in_vec; - // Clamp inputs to (-20.0, 20.0) + // Clamp inputs to (-88.0, 88.0) to avoid overflow/underflow HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R)); HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec); in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v); - in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v); + in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), in_vec); epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec); epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v); @@ -69,12 +70,12 @@ static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { // normalize before every QFloat's vmpy x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v); + x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); + // z = x * x; z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v); z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v); - x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); - // y = E4 + E5 * x; E_const = Q6_V_vsplat_R(EXP_COEFF_5); y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v); @@ -145,7 +146,7 @@ static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max return Q6_V_vmux_QVV(pred0, inf, out); } -static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) { +static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems, bool negate) { int left_over = num_elems & (VLEN_FP32 - 1); int num_elems_whole = num_elems - left_over; @@ -162,7 +163,7 @@ static inline void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict HVX_Vector vec_out = Q6_V_vzero(); static const float kInf = INFINITY; - static const float kMaxExp = 88.02f; // log(INF) + static const float kMaxExp = 88.7f; const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); const HVX_Vector inf = hvx_vec_splat_f32(kInf); diff --git a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h index 095193277ea..37f3e7b6fae 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +++ b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h @@ -2,6 +2,7 @@ #define HVX_SIGMOID_H #include "hvx-base.h" +#include "hvx-inverse.h" #define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022 #define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777 diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 3f99dbb32c4..2a3f9e562b7 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -516,6 +516,39 @@ static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_repeat_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = op_repeat(&octx); + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { struct dspqueue_buffer rsp_bufs[1]; @@ -1090,6 +1123,10 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { case HTP_OP_SQR: case HTP_OP_SQRT: + case HTP_OP_UNARY_NEG: + case HTP_OP_UNARY_EXP: + case HTP_OP_UNARY_SIGMOID: + case HTP_OP_UNARY_SOFTPLUS: if (n_bufs != 2) { FARF(ERROR, "Bad unary-req buffer list"); continue; @@ -1175,6 +1212,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_cpy_req(ctx, &req, bufs); break; + case HTP_OP_REPEAT: + if (n_bufs != 2) { + FARF(ERROR, "Bad repeat-req buffer list"); + continue; + } + proc_repeat_req(ctx, &req, bufs); + break; + case HTP_OP_ARGSORT: if (n_bufs != 2) { FARF(ERROR, "Bad argsort-req buffer list"); diff --git a/ggml/src/ggml-hexagon/htp/repeat-ops.c b/ggml/src/ggml-hexagon/htp/repeat-ops.c new file mode 100644 index 00000000000..5db06c920e2 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/repeat-ops.c @@ -0,0 +1,148 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include + +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +struct htp_repeat_context { + struct htp_ops_context * octx; + + uint32_t nr0; + uint32_t nr1; + uint32_t nr2; + uint32_t nr3; + + uint32_t nrows_per_thread; + uint32_t total_dst_rows; // ne1 * ne2 * ne3 + + size_t type_size; +}; + +static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_repeat_context * rctx = (const struct htp_repeat_context *) data; + struct htp_ops_context * octx = rctx->octx; + const struct htp_tensor * src = &octx->src0; + const struct htp_tensor * dst = &octx->dst; + + const uint32_t ne00 = src->ne[0]; + const uint32_t ne01 = src->ne[1]; + const uint32_t ne02 = src->ne[2]; + const uint32_t ne03 = src->ne[3]; + + const uint32_t nb00 = src->nb[0]; + const uint32_t nb01 = src->nb[1]; + const uint32_t nb02 = src->nb[2]; + const uint32_t nb03 = src->nb[3]; + + const uint32_t ne0 = dst->ne[0]; + const uint32_t ne1 = dst->ne[1]; + const uint32_t ne2 = dst->ne[2]; + const uint32_t ne3 = dst->ne[3]; + + const uint32_t nb0 = dst->nb[0]; + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + const uint32_t nr0 = rctx->nr0; + const uint32_t nr1 = rctx->nr1; + const uint32_t nr2 = rctx->nr2; + const uint32_t nr3 = rctx->nr3; + + const size_t row_bytes = ne00 * rctx->type_size; + + const uint32_t row_start = rctx->nrows_per_thread * ith; + const uint32_t row_end = MIN(row_start + rctx->nrows_per_thread, rctx->total_dst_rows); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + // Decompose flat dst row index into (i1, i2, i3) + const uint32_t i1 = dst_row % ne1; + const uint32_t i2 = (dst_row / ne1) % ne2; + const uint32_t i3 = dst_row / (ne1 * ne2); + + // Map to source indices (tiling) + const uint32_t k1 = i1 % ne01; + const uint32_t k2 = i2 % ne02; + const uint32_t k3 = i3 % ne03; + + const uint8_t * src_row = (const uint8_t *) src->data + k1 * nb01 + k2 * nb02 + k3 * nb03; + uint8_t * dst_base = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + // Tile along dimension 0 + for (uint32_t i0 = 0; i0 < nr0; i0++) { + uint8_t * dst_ptr = dst_base + i0 * ne00 * nb0; + memcpy(dst_ptr, src_row, row_bytes); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "repeat %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_repeat(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = &octx->src0; + struct htp_tensor * dst = &octx->dst; + + // Validate that dst dims are multiples of src dims + if (dst->ne[0] % src0->ne[0] != 0 || + dst->ne[1] % src0->ne[1] != 0 || + dst->ne[2] % src0->ne[2] != 0 || + dst->ne[3] % src0->ne[3] != 0) { + FARF(ERROR, "repeat: dst dims must be multiples of src dims\n"); + return HTP_STATUS_INVAL_PARAMS; + } + + size_t type_size; + switch (src0->type) { + case HTP_TYPE_F32: type_size = 4; break; + case HTP_TYPE_F16: type_size = 2; break; + default: + FARF(ERROR, "repeat: unsupported type %u\n", src0->type); + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows); + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + struct htp_repeat_context rctx = { + .octx = octx, + .nr0 = dst->ne[0] / src0->ne[0], + .nr1 = dst->ne[1] / src0->ne[1], + .nr2 = dst->ne[2] / src0->ne[2], + .nr3 = dst->ne[3] / src0->ne[3], + .nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads, + .total_dst_rows = total_dst_rows, + .type_size = type_size, + }; + + FARF(HIGH, "repeat: (%ux%ux%ux%u) -> (%ux%ux%ux%u) nr=(%u,%u,%u,%u)\n", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + rctx.nr0, rctx.nr1, rctx.nr2, rctx.nr3); + + worker_pool_run_func(octx->ctx->worker_pool, repeat_job_per_thread, &rctx, n_threads); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 8dae7f1ed55..d6356b9506f 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -195,7 +195,7 @@ static float hvx_softmax_f32(const uint8_t * restrict src, const float max) { hvx_sub_scalar_f32(spad, src, max, num_elems); - hvx_exp_f32(spad, dst, num_elems, false); + hvx_exp_f32(dst, spad, num_elems, false); float sum = hvx_reduce_sum_f32(dst, num_elems); diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 5bbd5040d3d..3d0928d4dce 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -9,6 +9,8 @@ #include #include "hex-dma.h" +#include "hvx-exp.h" +#include "hvx-sigmoid.h" #include "hvx-utils.h" #define GGML_COMMON_DECL_C @@ -166,6 +168,75 @@ static void sqrt_f32(const float * restrict src, } } +static void neg_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_scale_f32_aa(dst_local, src_local, row_elems, -1.0f); + } +} + +static void exp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_exp_f32(dst_local, src_local, row_elems, false); + } +} + +static void sigmoid_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_sigmoid_f32_aa(dst_local, src_local, row_elems); + } +} + +static void softplus_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + // softplus(x) = log(1 + exp(x)) + // Match CPU reference: ggml_compute_softplus_f32() in ggml-impl.h + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size)); + float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size)); + + for (uint32_t i = 0; i < row_elems; i++) { + float x = src_f[i]; + // For x > 20: softplus(x) ≈ x (avoids exp overflow) + dst_f[i] = (x > 20.0f) ? x : logf(1.0f + expf(x)); + } + } +} + static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { const struct htp_unary_context * uctx = (const struct htp_unary_context *) data; struct htp_ops_context * octx = uctx->octx; @@ -247,6 +318,18 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_SQRT: sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_UNARY_NEG: + neg_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_EXP: + exp_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_SIGMOID: + sigmoid_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_SOFTPLUS: + softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; default: break; } @@ -295,6 +378,18 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_SQRT: op_type = "sqrt-f32"; break; + case HTP_OP_UNARY_NEG: + op_type = "neg-f32"; + break; + case HTP_OP_UNARY_EXP: + op_type = "exp-f32"; + break; + case HTP_OP_UNARY_SIGMOID: + op_type = "sigmoid-f32"; + break; + case HTP_OP_UNARY_SOFTPLUS: + op_type = "softplus-f32"; + break; default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); From 61c7cd024dd371952f3dae27266eaf7bf82f2f04 Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 18 Mar 2026 09:53:13 +0100 Subject: [PATCH 313/831] HIP : ignore return of hipMemAdvise [no ci] (llama/20696) --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 5a0be4a472a..a31e843e153 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -126,7 +126,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) if (err == hipSuccess) { // hipMemAdviseSetCoarseGrain is an optional performance hint; // ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs). - cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device); + (void)cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device); (void)hipGetLastError(); // clear any error } From 14caedfa18bfcf75888661221117db591897b40b Mon Sep 17 00:00:00 2001 From: Shaw Nguyen <49144872+mrshaw01@users.noreply.github.com> Date: Wed, 18 Mar 2026 23:45:06 +0700 Subject: [PATCH 314/831] ggml-cpu/x86: fix unused changemask warning in repack (llama/20692) --- ggml/src/ggml-cpu/arch/x86/repack.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index 33c6cb65098..af1cebad131 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -531,7 +531,6 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t UNUSED(bs); - __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); // Permute mask used for easier vector processing at later stages @@ -580,6 +579,7 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t if constexpr ( std::is_same_v || std::is_same_v) { + const __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); } else if constexpr (std::is_same_v) { // Load 8 E8M0 exponents and convert to float via LUT From d6a0f0d075a2732e30031408e843fbbb712a860f Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 18 Mar 2026 10:23:47 -0700 Subject: [PATCH 315/831] Move to no timeout for WaitAny in graph submission to avoid deadlocks in some cases on llvm-pipe backends (llama/20618) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 60 +++++++++++----------------- 1 file changed, 24 insertions(+), 36 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 128b7dc3de8..3976a171d16 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -509,50 +509,39 @@ static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, std::vector & subs, bool block = true) { - // If we have too many in-flight submissions, wait on the oldest one first. if (subs.empty()) { return; } - while (subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { - auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, UINT64_MAX); - if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { + + bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD; + while (blocking_wait) { + auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, 0); + if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { #ifdef GGML_WEBGPU_GPU_PROFILE ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true); #endif subs.erase(subs.begin()); } + blocking_wait = (block && !subs.empty()) || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD; } if (subs.empty()) { return; } - if (block) { - for (auto & sub : subs) { - while (!sub.submit_done.completed) { - auto waitStatus = ctx->instance.WaitAny(1, &sub.submit_done, UINT64_MAX); - ggml_backend_webgpu_handle_wait_status(waitStatus); - } -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, sub.profile_futures, true); -#endif - } - subs.clear(); - } else { - // Poll each submit future once and remove completed submissions. - for (auto sub = subs.begin(); sub != subs.end();) { - auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0); - ggml_backend_webgpu_handle_wait_status(waitStatus, true); + // Poll each submit future once and remove completed submissions. + for (auto sub = subs.begin(); sub != subs.end();) { + auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0); + bool success = ggml_backend_webgpu_handle_wait_status(waitStatus, true); #ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false); - if (sub->submit_done.completed && sub->profile_futures.empty()) { + ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false); + if (success && sub->profile_futures.empty()) { #else - if (sub->submit_done.completed) { + if (success) { #endif - sub = subs.erase(sub); - } else { - ++sub; - } + sub = subs.erase(sub); + } else { + ++sub; } } } @@ -2961,17 +2950,16 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { /* .iface = */ { - /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, - /* .alloc_buffer = */ - ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */ - ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */ - ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */ - ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false + /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size, + /* .is_host = */ NULL, // defaults to false }, /* .device = */ - dev, - /* .context = */ - NULL + dev, + /* .context = */ NULL }; return &ggml_backend_webgpu_buffer_type; From dfba84cb470ec2c4d750936b048460648aea7db6 Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Thu, 19 Mar 2026 11:02:42 +0800 Subject: [PATCH 316/831] CANN: support flash attention for head dim not multiple of 16, fix ALiBi slope offset (llama/20031) - Allow FLASH_ATTN_EXT when head dimension D is not a multiple of 16 by padding Q/K/V to D_padded = GGML_PAD(D, 16), running FusedInferAttentionScoreV2, then slicing the output back to D (ggml-cann.cpp + aclnn_ops.cpp). - Fix aclnn_get_slope second-part offset: use ggml_type_size(dtype) instead of sizeof(float) so ALiBi slopes are correct when dtype is F16 (e.g. GQA with 48 heads); fixes buffer overflow and large numerical errors in those cases. --- ggml/src/ggml-cann/aclnn_ops.cpp | 78 ++++++++++++++++++++++++++++---- ggml/src/ggml-cann/ggml-cann.cpp | 4 -- 2 files changed, 70 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index fc7c3e3b724..4b7aab1e72d 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1544,8 +1544,8 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, end = 2 * ((n_head - 1) - n_head_log2) + 1; step = 2; count = n_head - n_head_log2; - aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step, - dtype); + aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * ggml_type_size(dtype), m1, count, start, end + 1, + step, dtype); } } @@ -3599,6 +3599,44 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS); acl_v_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS); + // Step 2.5: Pad Q, K, V along head dimension if D is not a multiple of 16 + // (required by FusedInferAttentionScoreV2) + const int64_t D = src0->ne[0]; + const int64_t D_padded = GGML_PAD(D, 16); + const bool needs_padding = (D != D_padded); + + ggml_cann_pool_alloc q_pad_allocator(ctx.pool()); + ggml_cann_pool_alloc k_pad_allocator(ctx.pool()); + ggml_cann_pool_alloc v_pad_allocator(ctx.pool()); + + if (needs_padding) { + int64_t paddings[] = { 0, D_padded - D, 0, 0, 0, 0, 0, 0 }; + + auto pad_fa_tensor = [&](acl_tensor_ptr & tensor, const int64_t * bsnd_ne, + ggml_cann_pool_alloc & allocator) { + int64_t pad_ne[GGML_MAX_DIMS] = { D_padded, bsnd_ne[1], bsnd_ne[2], bsnd_ne[3] }; + size_t pad_nb[GGML_MAX_DIMS]; + pad_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + pad_nb[i] = pad_nb[i - 1] * pad_ne[i - 1]; + } + int64_t nelements = pad_ne[0] * pad_ne[1] * pad_ne[2] * pad_ne[3]; + void * buffer = allocator.alloc(nelements * faElemSize); + acl_tensor_ptr padded = + ggml_cann_create_tensor(buffer, faDataType, faElemSize, pad_ne, pad_nb, GGML_MAX_DIMS); + aclnn_pad(ctx, tensor.get(), padded.get(), paddings); + tensor = std::move(padded); + }; + + pad_fa_tensor(acl_q_tensor, src0_bsnd_ne, q_pad_allocator); + pad_fa_tensor(acl_k_tensor, src1_bsnd_ne, k_pad_allocator); + pad_fa_tensor(acl_v_tensor, src2_bsnd_ne, v_pad_allocator); + + src0_bsnd_ne[0] = D_padded; + src1_bsnd_ne[0] = D_padded; + src2_bsnd_ne[0] = D_padded; + } + // Step 3: create the PSEShift tensor if needed // this tensor is considered as mask (f16) in the llama.cpp acl_tensor_ptr bcast_pse_tensor; @@ -3688,17 +3726,16 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); acl_tensor_ptr fa_dst_tensor; - acl_tensor_ptr acl_dst_tensor; ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); - if (dst->type == GGML_TYPE_F32) { - void * out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); - + if (dst->type == GGML_TYPE_F32 || needs_padding) { int64_t * out_f16_ne = src0_bsnd_ne; size_t out_f16_nb[GGML_MAX_DIMS]; out_f16_nb[0] = faElemSize; for (int i = 1; i < GGML_MAX_DIMS; ++i) { out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; } + int64_t out_nelements = out_f16_ne[0] * out_f16_ne[1] * out_f16_ne[2] * out_f16_ne[3]; + void * out_f16_buffer = out_f16_allocator.alloc(out_nelements * faElemSize); fa_dst_tensor = ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS); @@ -3730,8 +3767,33 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst nullptr // softmaxLse ); - if (dst->type == GGML_TYPE_F32) { - // Step 6: post-processing, permute and cast to f32 + // Step 6: post-processing — slice padded output and/or cast to f32 + if (needs_padding) { + ggml_cann_pool_alloc sliced_f16_allocator(ctx.pool()); + + if (dst->type == GGML_TYPE_F32) { + int64_t sliced_ne[GGML_MAX_DIMS] = { D, src0_bsnd_ne[1], src0_bsnd_ne[2], src0_bsnd_ne[3] }; + size_t sliced_nb[GGML_MAX_DIMS]; + sliced_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + sliced_nb[i] = sliced_nb[i - 1] * sliced_ne[i - 1]; + } + int64_t sliced_nelements = sliced_ne[0] * sliced_ne[1] * sliced_ne[2] * sliced_ne[3]; + void * sliced_buffer = sliced_f16_allocator.alloc(sliced_nelements * faElemSize); + acl_tensor_ptr sliced_f16_tensor = ggml_cann_create_tensor(sliced_buffer, faDataType, faElemSize, + sliced_ne, sliced_nb, GGML_MAX_DIMS); + + GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(), + (int64_t) -1, (int64_t) 0, D, (int64_t) 1, sliced_f16_tensor.get()); + + acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); + aclnn_cast(ctx, sliced_f16_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type)); + } else { + acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(), + (int64_t) -1, (int64_t) 0, D, (int64_t) 1, acl_dst_tensor.get()); + } + } else if (dst->type == GGML_TYPE_F32) { acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); aclnn_cast(ctx, fa_dst_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type)); } diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 3f3de9f0bcb..a682746bb42 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2503,10 +2503,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten // different head sizes of K and V are not supported yet return false; } - if (op->src[0]->ne[0] % 16 != 0) { - // TODO: padding to support - return false; - } float logitSoftcap = 0.0f; memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float)); if (logitSoftcap != 0.0f) { From 12015a2174ad014cffdafddb0158875c3de8aed5 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Thu, 19 Mar 2026 13:08:35 +0900 Subject: [PATCH 317/831] ggml-webgpu: Add supports for `DIAG` and `TRI` (llama/20664) * Add supports for DIAG and TRI. * Remove extra ttype and add a comment for TRI op. --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 37 ++++++++++++++++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 8 ++++ ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 35 ++++++++++++++---- 3 files changed, 68 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 3d7e59fddf3..ad665e4de93 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -244,13 +244,15 @@ struct ggml_webgpu_binary_pipeline_key_hash { /** Unary **/ struct ggml_webgpu_unary_pipeline_key { - int type; - int op; - bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella - bool inplace; + int type; + int op; + bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella + bool inplace; + ggml_tri_type ttype; // only used for GGML_OP_TRI bool operator==(const ggml_webgpu_unary_pipeline_key & other) const { - return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace; + return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace && + ttype == other.ttype; } }; @@ -261,6 +263,7 @@ struct ggml_webgpu_unary_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.op); ggml_webgpu_hash_combine(seed, key.is_unary); ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.ttype); return seed; } }; @@ -1058,6 +1061,7 @@ class ggml_webgpu_shader_lib { .op = op, .is_unary = is_unary, .inplace = context.inplace, + .ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0), }; auto it = unary_pipelines.find(key); @@ -1088,6 +1092,29 @@ class ggml_webgpu_shader_lib { variant += "_inplace"; } + if (op == GGML_OP_TRI) { + switch (key.ttype) { + case GGML_TRI_TYPE_LOWER: + defines.push_back("TRI_TYPE_LOWER"); + variant += "_tri_type_lower"; + break; + case GGML_TRI_TYPE_LOWER_DIAG: + defines.push_back("TRI_TYPE_LOWER_DIAG"); + variant += "_tri_type_lower_diag"; + break; + case GGML_TRI_TYPE_UPPER: + defines.push_back("TRI_TYPE_UPPER"); + variant += "_tri_type_upper"; + break; + case GGML_TRI_TYPE_UPPER_DIAG: + defines.push_back("TRI_TYPE_UPPER_DIAG"); + variant += "_tri_upper_diag"; + break; + default: + GGML_ABORT("Unsupported ggml_tri_type for unary shader"); + } + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_unary, defines); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 3976a171d16..4b0eeac0f42 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2209,6 +2209,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_SQRT: case GGML_OP_SIN: case GGML_OP_COS: + case GGML_OP_DIAG: + case GGML_OP_TRI: return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_PAD: return ggml_webgpu_pad(ctx, src0, node); @@ -3201,6 +3203,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_COS: supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; + case GGML_OP_DIAG: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_TRI: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; case GGML_OP_PAD: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index feaf6d0ac29..21beb9bb94d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -5,7 +5,6 @@ enable f16; #define TYPE f32 #endif - @group(0) @binding(0) var src: array; @@ -57,12 +56,20 @@ fn main(@builtin(global_invocation_id) gid: vec3) { return; } var i = gid.x; - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; + let ne2 = params.ne2; +#ifdef DIAG + let ne1 = params.ne0; +#else + let ne1 = params.ne1; +#endif + let ne0 = params.ne0; + + let i3 = i / (ne2 * ne1 * ne0); + i = i % (ne2 * ne1 * ne0); + let i2 = i / (ne1 * ne0); + i = i % (ne1 * ne0); + let i1 = i / ne0; + let i0 = i % ne0; let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + i2 * params.stride_src2 + i3 * params.stride_src3; @@ -184,6 +191,20 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res_f32 = cos(f32(src[params.offset_src + src_idx])); let res = TYPE(res_f32); #endif +#ifdef DIAG + let res = select(0.0, src[params.offset_src + i0 + i2 * params.stride_src2 + i3 * params.stride_src3], i0 == i1); +#endif +#ifdef TRI +#ifdef TRI_TYPE_LOWER + let res = select(0.0, src[params.offset_src + src_idx], i0 < i1); +#elif TRI_TYPE_LOWER_DIAG + let res = select(0.0, src[params.offset_src + src_idx], i0 <= i1); +#elif TRI_TYPE_UPPER + let res = select(0.0, src[params.offset_src + src_idx], i0 > i1); +#elif TRI_TYPE_UPPER_DIAG + let res = select(0.0, src[params.offset_src + src_idx], i0 >= i1); +#endif +#endif #ifdef INPLACE src[params.offset_src + src_idx] = res; From 3d004fbf0af918d2bcca0b43b63371dc9666446a Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Sat, 28 Mar 2026 11:47:59 +0200 Subject: [PATCH 318/831] ggml-webgpu: Update the `RMS_NORM` preprocessor and add `L2_NORM` (llama/20665) * Update the preprocessor of RMS_NORM and add L2_NORM. * Fix the name of rms_norm to row_norm. --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 60 ++++++++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 30 +++--- .../ggml-webgpu/wgsl-shaders/row_norm.wgsl | 97 +++++++++++++++++++ 3 files changed, 171 insertions(+), 16 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index ad665e4de93..9d16abf20d7 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -151,6 +151,26 @@ struct ggml_webgpu_get_rows_pipeline_key_hash { } }; +/** Row Norm **/ + +struct ggml_webgpu_row_norm_pipeline_key { + ggml_op op; + bool inplace; + + bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const { + return op == other.op && inplace == other.inplace; + } +}; + +struct ggml_webgpu_row_norm_pipeline_key_hash { + size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + /** Pad **/ struct ggml_webgpu_pad_pipeline_key { bool circular; @@ -438,6 +458,8 @@ class ggml_webgpu_shader_lib { std::unordered_map argsort_pipelines; // key is order std::unordered_map argsort_merge_pipelines; // key is order std::unordered_map cumsum_pipelines; // key is fixed, no variants yet + std::unordered_map + row_norm_pipelines; // op/inplace std::unordered_map get_rows_pipelines; // src_type, vectorized std::unordered_map @@ -482,6 +504,44 @@ class ggml_webgpu_shader_lib { return sum_rows_pipelines[1]; } + webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_row_norm_pipeline_key key = { + .op = context.dst->op, + .inplace = context.inplace, + }; + + auto it = row_norm_pipelines.find(key); + if (it != row_norm_pipelines.end()) { + return it->second; + } + std::vector defines; + std::string variant; + + switch (key.op) { + case GGML_OP_RMS_NORM: + defines.push_back("OP_RMS_NORM"); + variant = "rms_norm"; + break; + case GGML_OP_L2_NORM: + defines.push_back("OP_L2_NORM"); + variant = "l2_norm"; + break; + default: + GGML_ABORT("Unsupported op for row_norm shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_row_norm, defines); + row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); + return row_norm_pipelines[key]; + } + webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) { bool vec4 = context.src0->ne[0] % 4 == 0; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 4b0eeac0f42..f7973df682a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -366,7 +366,6 @@ struct webgpu_context_struct { std::map> cpy_pipelines; // src_type, dst_type - std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace std::map>> glu_pipelines; // glu_op, type, split @@ -1598,8 +1597,8 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - int inplace = ggml_webgpu_tensor_equal(src, dst); +static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + bool inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), @@ -1630,8 +1629,15 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->rms_norm_pipelines[inplace], params, - entries, ggml_nrows(src)); + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .dst = dst, + .max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE, + .inplace = inplace, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src)); } static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, @@ -2192,7 +2198,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_REPEAT: return ggml_webgpu_repeat(ctx, src0, node); case GGML_OP_RMS_NORM: - return ggml_webgpu_rms_norm(ctx, src0, node); + case GGML_OP_L2_NORM: + return ggml_webgpu_row_norm(ctx, src0, node); case GGML_OP_ROPE: return ggml_webgpu_rope(ctx, src0, src1, src2, node); case GGML_OP_GLU: @@ -2616,15 +2623,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); } -static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - - webgpu_ctx->rms_norm_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rms_norm, "rms_norm", constants); - webgpu_ctx->rms_norm_pipelines[1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants); -} - static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); @@ -2909,7 +2907,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); - ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); ggml_webgpu_init_rope_pipeline(webgpu_ctx); ggml_webgpu_init_glu_pipeline(webgpu_ctx); ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); @@ -3120,6 +3117,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } case GGML_OP_RMS_NORM: + case GGML_OP_L2_NORM: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; case GGML_OP_ROPE: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl new file mode 100644 index 00000000000..7777944941c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl @@ -0,0 +1,97 @@ +#ifdef INPLACE +fn update(src_offset: u32, dst_offset: u32, scale: f32) { + src[dst_offset] = scale * src[src_offset]; +} + +@group(0) @binding(1) +var params: Params; +#else +fn update(src_offset: u32, dst_offset: u32, scale: f32) { + dst[dst_offset] = scale * src[src_offset]; +} + +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; +#endif + +struct Params { + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of src/dst + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, + + eps: f32 +}; + +@group(0) @binding(0) +var src: array; + +var scratch: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + + // one thread per row + var i = wid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; + + var sum = 0.0f; + var col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + sum += pow(src[i_src_row + col], 2.0); + col += WG_SIZE; + } + + scratch[lid.x] = sum; + workgroupBarrier(); + var offset: u32 = WG_SIZE / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + sum = scratch[0]; + +#ifdef OP_RMS_NORM + let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); +#elif OP_L2_NORM + let scale = 1.0/max(sqrt(sum), params.eps); +#endif + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + update(i_src_row + col, i_dst_row + col, scale); + col += WG_SIZE; + } +} From 2a6de29364870e524efc88fd0a470139ddf28332 Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Thu, 19 Mar 2026 14:05:01 +0800 Subject: [PATCH 319/831] CANN: handle in-place ROPE on non-contiguous f32 tensors (llama/20274) RotaryPositionEmbedding on CANN fails when src and dst share the same non-contiguous buffer (inplace + view), because the operator overwrites source data before it is fully read. Add a branch that detects this case and uses contiguous temporary buffers: copy src to temp, run ROPE into another temp, then copy back to the non-contiguous dst. Fixes 20 failing ROPE tests (f32, v=1, inplace=1). Signed-off-by: noemotiovon <757486878@qq.com> --- ggml/src/ggml-cann/aclnn_ops.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 4b7aab1e72d..9b736636def 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -2943,6 +2943,27 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // Rotate full tensor (no tail), using trans tensors GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(), acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get()); + } else if (src0->data == dst->data && !ggml_is_contiguous(src0)) { + // In-place on non-contiguous tensor: RotaryPositionEmbedding cannot safely + // read and write the same non-contiguous buffer. Use contiguous temporaries. + size_t contiguous_nb[GGML_MAX_DIMS]; + contiguous_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + contiguous_nb[i] = contiguous_nb[i - 1] * src0->ne[i - 1]; + } + int64_t total_elements = ggml_nelements(src0); + ggml_cann_pool_alloc inplace_src_alloc(ctx.pool(), total_elements * sizeof(float)); + ggml_cann_pool_alloc inplace_dst_alloc(ctx.pool(), total_elements * sizeof(float)); + + acl_tensor_ptr acl_src_contig = ggml_cann_create_tensor(inplace_src_alloc.get(), ACL_FLOAT, sizeof(float), + src0->ne, contiguous_nb, GGML_MAX_DIMS); + acl_tensor_ptr acl_dst_contig = ggml_cann_create_tensor(inplace_dst_alloc.get(), ACL_FLOAT, sizeof(float), + dst->ne, contiguous_nb, GGML_MAX_DIMS); + + cann_copy(ctx, acl_src.get(), acl_src_contig.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_contig.get(), acl_cos_reshape_tensor.get(), + acl_sin_reshape_tensor.get(), acl_mode, acl_dst_contig.get()); + cann_copy(ctx, acl_dst_contig.get(), acl_dst.get()); } else { // Rotate full tensor (no tail), using original tensors GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(), From fea629d00f46863f76a34a5e5f37c98cf7043524 Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Thu, 19 Mar 2026 09:14:48 +0100 Subject: [PATCH 320/831] cmake : fix build warning when kleidiai is enabled (llama/20457) * cmake : fix build warning when kleidiai is enabled * remove LLAMA_ARG_THREADS from KleidiAI backend --- ggml/src/ggml-cpu/CMakeLists.txt | 36 ++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 6ca3176a2f2..7c062a62995 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -570,24 +570,34 @@ function(ggml_add_cpu_backend_variant_impl tag_name) set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1") - if (POLICY CMP0135) - cmake_policy(SET CMP0135 NEW) - endif() - - # TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+ - # Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28 - FetchContent_Declare(KleidiAI_Download + set(KLEIDIAI_FETCH_ARGS URL ${KLEIDIAI_DOWNLOAD_URL} DOWNLOAD_EXTRACT_TIMESTAMP NEW - URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}) + URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5} + ) - FetchContent_GetProperties(KleidiAI_Download - SOURCE_DIR KLEIDIAI_SRC - POPULATED KLEIDIAI_POPULATED) + if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.28") + FetchContent_Declare(KleidiAI_Download + ${KLEIDIAI_FETCH_ARGS} + EXCLUDE_FROM_ALL + ) - if (NOT KLEIDIAI_POPULATED) - FetchContent_Populate(KleidiAI_Download) + FetchContent_MakeAvailable(KleidiAI_Download) FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC) + else() + FetchContent_Declare(KleidiAI_Download + ${KLEIDIAI_FETCH_ARGS} + ) + + FetchContent_GetProperties(KleidiAI_Download + SOURCE_DIR KLEIDIAI_SRC + POPULATED KLEIDIAI_POPULATED + ) + + if (NOT KLEIDIAI_POPULATED) + FetchContent_Populate(KleidiAI_Download) + FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC) + endif() endif() add_compile_definitions(GGML_USE_CPU_KLEIDIAI) From 43c7c0f86c09455bf2acdb7384bcaf351d35c564 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Thu, 19 Mar 2026 10:32:04 +0000 Subject: [PATCH 321/831] vulkan: dequantize iq4_xs 4 at a time (llama/20657) --- .../ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl | 13 +++++++------ .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index ce7f2d699a2..3f494eb4d5a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -444,19 +444,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint iq = 16 * ib32 + 2 * (idx % 8); + const uint ib = idx / 64; // 4 values per idx + const uint ib32 = (idx % 64) / 8; // 0..7 + const uint iq = 4 * ib32 + (idx % 4); const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; - const uint qshift = (idx & 8) >> 1; - u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy; + const uint qshift = idx & 4; + u8vec4 qs = unpack8((uint(data_a_packed32[ib].qs[iq]) >> qshift) & 0x0F0F0F0F); const float d = float(data_a[ib].d); - const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); + const vec4 v = d * float(int(sl | (sh << 4)) - 32) * vec4(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]); buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 4b00ba3debb..abd2a9c36fa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -554,7 +554,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string load_vec_quant = "2"; if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) + else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4")) load_vec_quant = "4"; if (tname == "bf16") { From 551bb8296008094dbab71b8450f6434e6045ba35 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Thu, 19 Mar 2026 08:45:28 -0700 Subject: [PATCH 322/831] ggml webgpu: ops support for qwen3.5 (SET, TRI_SOLVE, SSM_CONV, GATED_DELTA_NET) + GET_ROWS optimization (llama/20687) * Implement l2_norm, set, tri * Add DIAG/SOLVE_TRI * Add SSM_CONV * Better get_rows and gated_delta_net to support qwen3.5 * Clean up, update ops.md * Fix binding_index type for wasm * Fix read write annotations * cleanups --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 277 ++++++++++++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 314 +++++++++++++++++- .../wgsl-shaders/gated_delta_net.wgsl | 132 ++++++++ .../ggml-webgpu/wgsl-shaders/get_rows.wgsl | 31 +- .../ggml-webgpu/wgsl-shaders/row_norm.wgsl | 5 +- ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl | 109 ++++++ .../ggml-webgpu/wgsl-shaders/solve_tri.wgsl | 121 +++++++ .../ggml-webgpu/wgsl-shaders/ssm_conv.wgsl | 65 ++++ 8 files changed, 1034 insertions(+), 20 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 9d16abf20d7..59861ac16cc 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -95,6 +95,11 @@ struct ggml_webgpu_generic_shader_decisions { uint32_t wg_size = 0; }; +struct ggml_webgpu_ssm_conv_shader_decisions { + uint32_t block_size; + uint32_t tokens_per_wg; +}; + /** Argsort **/ struct ggml_webgpu_argsort_shader_lib_context { @@ -131,6 +136,26 @@ struct ggml_webgpu_set_rows_shader_decisions { uint32_t wg_size; }; +/** Set **/ + +struct ggml_webgpu_set_pipeline_key { + ggml_type type; + bool inplace; + + bool operator==(const ggml_webgpu_set_pipeline_key & other) const { + return type == other.type && inplace == other.inplace; + } +}; + +struct ggml_webgpu_set_pipeline_key_hash { + size_t operator()(const ggml_webgpu_set_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + /** Get Rows **/ struct ggml_webgpu_get_rows_pipeline_key { @@ -186,6 +211,67 @@ struct ggml_webgpu_pad_pipeline_key_hash { } }; +/** Solve Tri **/ +struct ggml_webgpu_solve_tri_pipeline_key { + int type; + int n; + int k; + + bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const { + return type == other.type && n == other.n && k == other.k; + } +}; + +struct ggml_webgpu_solve_tri_pipeline_key_hash { + size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.n); + ggml_webgpu_hash_combine(seed, key.k); + return seed; + } +}; + +/** SSM Conv **/ +struct ggml_webgpu_ssm_conv_pipeline_key { + int type; + int vectorized; + + bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const { + return type == other.type && vectorized == other.vectorized; + } +}; + +/** Gated Delta Net **/ +struct ggml_webgpu_gated_delta_net_pipeline_key { + int type; + int s_v; + int kda; + + bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const { + return type == other.type && s_v == other.s_v && kda == other.kda; + } +}; + +struct ggml_webgpu_gated_delta_net_pipeline_key_hash { + size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.s_v); + ggml_webgpu_hash_combine(seed, key.kda); + return seed; + } +}; + +struct ggml_webgpu_ssm_conv_pipeline_key_hash { + size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; + /** Scale **/ struct ggml_webgpu_scale_pipeline_key { @@ -466,14 +552,22 @@ class ggml_webgpu_shader_lib { unary_pipelines; // type/op/inplace std::unordered_map scale_pipelines; // inplace + std::unordered_map + solve_tri_pipelines; // type + std::unordered_map + ssm_conv_pipelines; // type/vectorized + std::unordered_map + gated_delta_net_pipelines; // type/S_v/kda std::unordered_map - pad_pipelines; // circular/non-circular + pad_pipelines; // circular/non-circular std::unordered_map - binary_pipelines; // type/op/inplace/overlap + binary_pipelines; // type/op/inplace/overlap std::unordered_map - concat_pipelines; // type + concat_pipelines; // type std::unordered_map - repeat_pipelines; // type + repeat_pipelines; // type std::unordered_map flash_attn_pipelines; std::unordered_map set_rows_pipelines; + std::unordered_map set_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -519,11 +614,11 @@ class ggml_webgpu_shader_lib { switch (key.op) { case GGML_OP_RMS_NORM: - defines.push_back("OP_RMS_NORM"); + defines.push_back("RMS_NORM"); variant = "rms_norm"; break; case GGML_OP_L2_NORM: - defines.push_back("OP_L2_NORM"); + defines.push_back("L2_NORM"); variant = "l2_norm"; break; default: @@ -535,8 +630,9 @@ class ggml_webgpu_shader_lib { variant += "_inplace"; } - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - + const uint32_t row_norm_wg_size = 128u; + uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); auto processed = preprocessor.preprocess(wgsl_row_norm, defines); row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); return row_norm_pipelines[key]; @@ -609,6 +705,46 @@ class ggml_webgpu_shader_lib { return set_rows_pipelines[key]; } + webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace }; + + auto it = set_pipelines.find(key); + if (it != set_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "set"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_I32: + defines.push_back("TYPE_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported type for set shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_set, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + set_pipelines[key] = pipeline; + return set_pipelines[key]; + } + webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) { auto it = cumsum_pipelines.find(1); if (it != cumsum_pipelines.end()) { @@ -695,6 +831,7 @@ class ggml_webgpu_shader_lib { switch (key.src_type) { case GGML_TYPE_F32: + defines.push_back("FLOAT_PARALLEL"); if (key.vectorized) { defines.push_back("F32_VEC"); defines.push_back("SRC_TYPE=vec4"); @@ -709,6 +846,7 @@ class ggml_webgpu_shader_lib { variant += "_f32"; break; case GGML_TYPE_F16: + defines.push_back("FLOAT_PARALLEL"); defines.push_back("F16"); defines.push_back("SRC_TYPE=f16"); defines.push_back("DST_TYPE=f32"); @@ -716,6 +854,7 @@ class ggml_webgpu_shader_lib { variant += "_f16"; break; case GGML_TYPE_I32: + defines.push_back("FLOAT_PARALLEL"); defines.push_back("I32"); defines.push_back("SRC_TYPE=i32"); defines.push_back("DST_TYPE=i32"); @@ -794,6 +933,128 @@ class ggml_webgpu_shader_lib { return scale_pipelines[key]; } + webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_solve_tri_pipeline_key key = { + .type = context.dst->type, + .n = (int) context.src0->ne[0], + .k = (int) context.src1->ne[0], + }; + + auto it = solve_tri_pipelines.find(key); + if (it != solve_tri_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "solve_tri"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for solve_tri shader"); + } + + const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size); + const uint32_t k_tile = wg_size; + const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES; + const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row); + + defines.push_back(std::string("N=") + std::to_string(key.n)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + defines.push_back(std::string("K_TILE=") + std::to_string(k_tile)); + defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n)); + + auto processed = preprocessor.preprocess(wgsl_solve_tri, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + solve_tri_pipelines[key] = pipeline; + return solve_tri_pipelines[key]; + } + + webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_ssm_conv_pipeline_key key = { + .type = context.dst->type, + .vectorized = context.src1->ne[0] == 4, + }; + + auto it = ssm_conv_pipelines.find(key); + if (it != ssm_conv_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "ssm_conv"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for ssm_conv shader"); + } + + if (key.vectorized) { + defines.push_back("VECTORIZED"); + variant += "_vec4"; + } + + constexpr uint32_t block_size = 32u; + constexpr uint32_t tokens_per_wg = 8u; + + defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u"); + defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u"); + + auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines); + auto decisions = std::make_shared(); + decisions->block_size = block_size; + decisions->tokens_per_wg = tokens_per_wg; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + ssm_conv_pipelines[key] = pipeline; + return ssm_conv_pipelines[key]; + } + + webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_gated_delta_net_pipeline_key key = { + .type = context.dst->type, + .s_v = (int) context.src2->ne[0], + .kda = context.src3->ne[0] == context.src2->ne[0], + }; + + auto it = gated_delta_net_pipelines.find(key); + if (it != gated_delta_net_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "gated_delta_net"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for gated_delta_net shader"); + } + + if (key.kda) { + defines.push_back("KDA"); + variant += "_kda"; + } + + defines.push_back("S_V=" + std::to_string(key.s_v) + "u"); + defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u"); + + auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + gated_delta_net_pipelines[key] = pipeline; + return gated_delta_net_pipelines[key]; + } + webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 }; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f7973df682a..5e16f84ddd2 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -880,6 +880,68 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g params, entries, wg_x); } +static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + const bool inplace = ggml_webgpu_tensor_equal(src0, dst); + + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = inplace, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst); + const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type); + + std::vector params = { + ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (((const int32_t *) dst->op_params)[3] / dst_type_size), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + 1u, + (uint32_t) (((const int32_t *) dst->op_params)[0] / dst_type_size), + (uint32_t) (((const int32_t *) dst->op_params)[1] / dst_type_size), + (uint32_t) (((const int32_t *) dst->op_params)[2] / dst_type_size), + + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], + }; + + std::vector entries; + uint32_t binding_index = 0; + if (!inplace) { + entries.push_back({ .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }); + binding_index++; + } + entries.push_back({ .binding = binding_index, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + entries.push_back({ .binding = binding_index + 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); +} + static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup @@ -935,6 +997,208 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } +static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src1->ne[0], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); + const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + const uint32_t token_tiles = CEIL_DIV((uint32_t) dst->ne[1], decisions->tokens_per_wg); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + + (uint32_t) src1->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + token_tiles, + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); + const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2]; + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .src2 = src2, + .src3 = src3, + .src4 = src4, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx); + + const uint32_t s_v = (uint32_t) src2->ne[0]; + const uint32_t h = (uint32_t) src2->ne[1]; + const uint32_t n_tokens = (uint32_t) src2->ne[2]; + const uint32_t n_seqs = (uint32_t) src2->ne[3]; + const float scale = 1.0f / sqrtf((float) s_v); + uint32_t scale_u32; + memcpy(&scale_u32, &scale, sizeof(scale_u32)); + + std::vector params = { + h, + n_tokens, + n_seqs, + s_v * h * n_tokens * n_seqs, + + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[2] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[3] / ggml_type_size(src2->type)), + + (uint32_t) (src4->nb[1] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[2] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[3] / ggml_type_size(src4->type)), + + (uint32_t) src0->ne[1], + (uint32_t) (src2->ne[3] / src0->ne[3]), + scale_u32, + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(src2), + .offset = ggml_webgpu_tensor_align_offset(ctx, src2), + .size = ggml_webgpu_tensor_binding_size(ctx, src2) }, + { .binding = 3, + .buffer = ggml_webgpu_tensor_buf(src3), + .offset = ggml_webgpu_tensor_align_offset(ctx, src3), + .size = ggml_webgpu_tensor_binding_size(ctx, src3) }, + { .binding = 4, + .buffer = ggml_webgpu_tensor_buf(src4), + .offset = ggml_webgpu_tensor_align_offset(ctx, src4), + .size = ggml_webgpu_tensor_binding_size(ctx, src4) }, + { .binding = 5, + .buffer = ggml_webgpu_tensor_buf(src5), + .offset = ggml_webgpu_tensor_align_offset(ctx, src5), + .size = ggml_webgpu_tensor_binding_size(ctx, src5) }, + { .binding = 6, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, h, n_seqs); +} + static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, @@ -1016,6 +1280,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { + const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; + ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .src1 = nullptr, @@ -1060,7 +1326,10 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size); + uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1)); + uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]); + uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows; + uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1632,7 +1901,7 @@ static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * s ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, - .max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, .inplace = inplace, }; @@ -2176,6 +2445,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_CPY: case GGML_OP_CONT: return ggml_webgpu_cpy(ctx, src0, node); + case GGML_OP_SET: + return ggml_webgpu_set(ctx, src0, src1, node); case GGML_OP_SET_ROWS: return ggml_webgpu_set_rows(ctx, src0, src1, node); case GGML_OP_GET_ROWS: @@ -2219,6 +2490,12 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_DIAG: case GGML_OP_TRI: return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_SOLVE_TRI: + return ggml_webgpu_solve_tri(ctx, src0, src1, node); + case GGML_OP_SSM_CONV: + return ggml_webgpu_ssm_conv(ctx, src0, src1, node); + case GGML_OP_GATED_DELTA_NET: + return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node); case GGML_OP_PAD: return ggml_webgpu_pad(ctx, src0, node); case GGML_OP_ARGMAX: @@ -2957,7 +3234,7 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm /* .is_host = */ NULL, // defaults to false }, /* .device = */ - dev, + dev, /* .context = */ NULL }; @@ -3040,6 +3317,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) || (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32); break; + case GGML_OP_SET: + supports_op = src0->type == src1->type && src0->type == op->type && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32); + break; case GGML_OP_SET_ROWS: supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); @@ -3180,6 +3461,27 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } } break; + case GGML_OP_TRI: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_DIAG: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_SOLVE_TRI: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; + break; + case GGML_OP_SSM_CONV: + supports_op = op->type == GGML_TYPE_F32; + break; + case GGML_OP_GATED_DELTA_NET: + { + const uint32_t s_v = (uint32_t) src2->ne[0]; + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && + src2->type == GGML_TYPE_F32 && op->src[3]->type == GGML_TYPE_F32 && + op->src[4]->type == GGML_TYPE_F32 && op->src[5]->type == GGML_TYPE_F32 && + s_v <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + } + break; case GGML_OP_CLAMP: supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; @@ -3201,12 +3503,6 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_COS: supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; - case GGML_OP_DIAG: - supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); - break; - case GGML_OP_TRI: - supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); - break; case GGML_OP_PAD: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl new file mode 100644 index 00000000000..f9d98fda40b --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl @@ -0,0 +1,132 @@ +@group(0) @binding(0) +var src_q: array; + +@group(0) @binding(1) +var src_k: array; + +@group(0) @binding(2) +var src_v: array; + +@group(0) @binding(3) +var src_g: array; + +@group(0) @binding(4) +var src_beta: array; + +@group(0) @binding(5) +var src_state: array; + +@group(0) @binding(6) +var dst: array; + +struct Params { + h: u32, + n_tokens: u32, + n_seqs: u32, + s_off: u32, + + sq1: u32, + sq2: u32, + sq3: u32, + + sv1: u32, + sv2: u32, + sv3: u32, + + sb1: u32, + sb2: u32, + sb3: u32, + + neq1: u32, + rq3: u32, + scale: f32, +}; + +@group(0) @binding(7) +var params: Params; + +var sh_k: array; +var sh_q: array; +var sh_g: array; + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(workgroup_id) workgroup_id: vec3, + @builtin(local_invocation_id) local_id: vec3 +) { + let head_id = workgroup_id.x; + let seq_id = workgroup_id.y; + let col = local_id.x; + + let iq1 = head_id % params.neq1; + let iq3 = seq_id / params.rq3; + + let state_size = S_V * S_V; + let state_base = (seq_id * params.h + head_id) * state_size; + + var state: array; + for (var i = 0u; i < S_V; i++) { + state[i] = src_state[state_base + col * S_V + i]; + } + + var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V; + + for (var t = 0u; t < params.n_tokens; t++) { + let q_off = iq3 * params.sq3 + t * params.sq2 + iq1 * params.sq1; + let k_off = q_off; + let v_off = seq_id * params.sv3 + t * params.sv2 + head_id * params.sv1; + let gb_off = seq_id * params.sb3 + t * params.sb2 + head_id * params.sb1; + + sh_q[col] = src_q[q_off + col]; + sh_k[col] = src_k[k_off + col]; + +#ifdef KDA + let g_base = gb_off * S_V; + sh_g[col] = exp(src_g[g_base + col]); +#endif + + workgroupBarrier(); + + let v_val = src_v[v_off + col]; + let beta_val = src_beta[gb_off]; + + var kv_col = 0.0; + var delta_col = 0.0; + var attn_col = 0.0; + +#ifdef KDA + for (var i = 0u; i < S_V; i++) { + kv_col += (sh_g[i] * state[i]) * sh_k[i]; + } + + delta_col = (v_val - kv_col) * beta_val; + + for (var i = 0u; i < S_V; i++) { + state[i] = sh_g[i] * state[i] + sh_k[i] * delta_col; + attn_col += state[i] * sh_q[i]; + } +#else + let g_val = exp(src_g[gb_off]); + + for (var i = 0u; i < S_V; i++) { + kv_col += state[i] * sh_k[i]; + } + + delta_col = (v_val - g_val * kv_col) * beta_val; + + for (var i = 0u; i < S_V; i++) { + state[i] = g_val * state[i] + sh_k[i] * delta_col; + attn_col += state[i] * sh_q[i]; + } +#endif + + dst[attn_off + col] = attn_col * params.scale; + attn_off += S_V * params.h; + + workgroupBarrier(); + } + + for (var i = 0u; i < S_V; i++) { + dst[params.s_off + state_base + col * S_V + i] = state[i]; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index b10800e36d2..d9eb6a3567e 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -640,6 +640,35 @@ var params: Params; @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { +#ifdef FLOAT_PARALLEL + let blocks_per_row = params.ne0 / BLOCK_SIZE; + let row_count = params.n_rows * params.ne2 * params.ne3; + + if (gid.x >= blocks_per_row * row_count) { + return; + } + + let block_idx = gid.x % blocks_per_row; + var row_idx = gid.x / blocks_per_row; + let i_dst3 = row_idx / (params.ne2 * params.n_rows); + + row_idx = row_idx % (params.ne2 * params.n_rows); + let i_dst2 = row_idx / params.n_rows; + let i_dst1 = row_idx % params.n_rows; + + let i_idx2 = i_dst3 % params.idx2; + let i_idx1 = i_dst2 % params.idx1; + let i_idx0 = i_dst1; + + let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + + let idx_val = u32(idx[i_idx]); + + let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3; + let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3; + + copy_elements(i_src_row, i_dst_row, block_idx); +#else if (gid.x >= params.n_rows * params.ne2 * params.ne3) { return; } @@ -664,5 +693,5 @@ fn main(@builtin(global_invocation_id) gid: vec3) { for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) { copy_elements(i_src_row, i_dst_row, i); } +#endif } - diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl index 7777944941c..bd8d32bded7 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl @@ -81,11 +81,12 @@ fn main(@builtin(workgroup_id) wid: vec3, } sum = scratch[0]; -#ifdef OP_RMS_NORM +#ifdef RMS_NORM let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); -#elif OP_L2_NORM +#elif defined(L2_NORM) let scale = 1.0/max(sqrt(sum), params.eps); #endif + col = lid.x; for (var j: u32 = 0; j < elems; j++) { if (col >= params.ne0) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl new file mode 100644 index 00000000000..0a7ae9bdb2c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl @@ -0,0 +1,109 @@ +#ifdef TYPE_I32 +#define TYPE i32 +#else +#define TYPE f32 +#endif + +#ifndef INPLACE +@group(0) @binding(0) +var src0: array; +#define SRC1_BINDING 1 +#else +#define SRC1_BINDING 0 +#endif + +#define DST_BINDING SRC1_BINDING + 1 +#define PARAMS_BINDING SRC1_BINDING + 2 + +@group(0) @binding(SRC1_BINDING) +var src1: array; + +@group(0) @binding(DST_BINDING) +var dst: array; + +struct Params { + ne: u32, + offset_src0: u32, + offset_src1: u32, + offset_view: u32, + + stride_src10: u32, + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst10: u32, + stride_dst11: u32, + stride_dst12: u32, + stride_dst13: u32, + + src1_ne0: u32, + src1_ne1: u32, + src1_ne2: u32, + src1_ne3: u32, +}; + +@group(0) @binding(PARAMS_BINDING) +var params: Params; + +fn decode_src1_coords(idx: u32) -> vec4 { + var i = idx; + let plane = params.src1_ne2 * params.src1_ne1 * params.src1_ne0; + let i3 = i / plane; + i = i % plane; + let row = params.src1_ne1 * params.src1_ne0; + let i2 = i / row; + i = i % row; + let i1 = i / params.src1_ne0; + let i0 = i % params.src1_ne0; + return vec4(i0, i1, i2, i3); +} + +fn decode_view_coords(rel: u32) -> vec4 { + let i3 = rel / params.stride_dst13; + let rem3 = rel % params.stride_dst13; + let i2 = rem3 / params.stride_dst12; + let rem2 = rem3 % params.stride_dst12; + let i1 = rem2 / params.stride_dst11; + let i0 = rem2 % params.stride_dst11; + return vec4(i0, i1, i2, i3); +} + +fn view_rel_from_coords(coords: vec4) -> u32 { + return coords.x * params.stride_dst10 + coords.y * params.stride_dst11 + + coords.z * params.stride_dst12 + coords.w * params.stride_dst13; +} + +fn src1_idx_from_coords(coords: vec4) -> u32 { + return coords.x * params.stride_src10 + coords.y * params.stride_src11 + + coords.z * params.stride_src12 + coords.w * params.stride_src13; +} + +fn in_set_view(rel: u32, coords: vec4) -> bool { + return view_rel_from_coords(coords) == rel; +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + +#ifdef INPLACE + let coords = decode_src1_coords(gid.x); + + let src1_idx = params.offset_src1 + src1_idx_from_coords(coords); + let dst_idx = params.offset_view + view_rel_from_coords(coords); + + dst[dst_idx] = src1[src1_idx]; +#else + let rel = select(params.ne, gid.x - params.offset_view, gid.x >= params.offset_view); + let coords = decode_view_coords(rel); + + if (rel < params.stride_dst13 * params.src1_ne3 && in_set_view(rel, coords)) { + dst[gid.x] = src1[params.offset_src1 + src1_idx_from_coords(coords)]; + } else { + dst[gid.x] = src0[params.offset_src0 + gid.x]; + } +#endif +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl new file mode 100644 index 00000000000..9d5d902cb1e --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl @@ -0,0 +1,121 @@ +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var src1: array; + +@group(0) @binding(2) +var dst: array; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src00: u32, + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src10: u32, + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + k: u32, + ne2: u32, + ne3: u32, +}; + +@group(0) @binding(3) +var params: Params; + +var shA: array; +var shB: array; + +fn src0_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_src0 + + col * params.stride_src00 + + row * params.stride_src01 + + i2 * params.stride_src02 + + i3 * params.stride_src03; +} + +fn src1_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_src1 + + col * params.stride_src10 + + row * params.stride_src11 + + i2 * params.stride_src12 + + i3 * params.stride_src13; +} + +fn dst_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_dst + + col * params.stride_dst0 + + row * params.stride_dst1 + + i2 * params.stride_dst2 + + i3 * params.stride_dst3; +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(workgroup_id) workgroup_id: vec3, + @builtin(local_invocation_id) local_id: vec3 +) { + let batch = workgroup_id.y; + let col = workgroup_id.x * WG_SIZE + local_id.x; + let i3 = batch / params.ne2; + let i2 = batch % params.ne2; + let active_lane = local_id.x < K_TILE; + let active_col = active_lane && col < params.k; + + var X: array; + + for (var row_base = 0u; row_base < N; row_base += BATCH_N) { + let cur_n = min(BATCH_N, N - row_base); + + for (var i = local_id.x; i < cur_n * N; i += WG_SIZE) { + let tile_row = i / N; + let tile_col = i % N; + shA[i] = src0[src0_idx(row_base + tile_row, tile_col, i2, i3)]; + } + + for (var i = local_id.x; i < cur_n * K_TILE; i += WG_SIZE) { + let tile_row = i / K_TILE; + let tile_col = i % K_TILE; + let global_col = workgroup_id.x * WG_SIZE + tile_col; + let sh_idx = tile_row * K_TILE + tile_col; + + if (global_col < params.k) { + shB[sh_idx] = src1[src1_idx(row_base + tile_row, global_col, i2, i3)]; + } else { + shB[sh_idx] = 0.0; + } + } + + workgroupBarrier(); + + if (active_col) { + for (var row_offset = 0u; row_offset < cur_n; row_offset++) { + let r = row_base + row_offset; + var b = shB[row_offset * K_TILE + local_id.x]; + let a_row = row_offset * N; + + for (var t = 0u; t < r; t++) { + b -= shA[a_row + t] * X[t]; + } + + let x = b / shA[a_row + r]; + X[r] = x; + dst[dst_idx(r, col, i2, i3)] = x; + } + } + + workgroupBarrier(); + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl new file mode 100644 index 00000000000..11511305ed8 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl @@ -0,0 +1,65 @@ +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var src1: array; + +@group(0) @binding(2) +var dst: array; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src01: u32, + stride_src02: u32, + stride_src11: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + + nc: u32, + nr: u32, + n_t: u32, + n_s: u32, + token_tiles: u32, +}; + +@group(0) @binding(3) +var params: Params; + +@compute @workgroup_size(BLOCK_SIZE, TOKENS_PER_WG) +fn main(@builtin(global_invocation_id) gid: vec3) { + let i1 = gid.x; + let tile_y = gid.y / TOKENS_PER_WG; + let local_token = gid.y % TOKENS_PER_WG; + let i3 = tile_y / params.token_tiles; + let token_tile = tile_y % params.token_tiles; + let i2 = token_tile * TOKENS_PER_WG + local_token; + + if (i1 >= params.nr || i2 >= params.n_t || i3 >= params.n_s) { + return; + } + + let src0_base = params.offset_src0 + i3 * params.stride_src02 + i2 + i1 * params.stride_src01; + let src1_base = params.offset_src1 + i1 * params.stride_src11; + + var sum = 0.0; + +#ifdef VECTORIZED + sum = + src0[src0_base + 0u] * src1[src1_base + 0u] + + src0[src0_base + 1u] * src1[src1_base + 1u] + + src0[src0_base + 2u] * src1[src1_base + 2u] + + src0[src0_base + 3u] * src1[src1_base + 3u]; +#else + for (var i0 = 0u; i0 < params.nc; i0++) { + sum += src0[src0_base + i0] * src1[src1_base + i0]; + } +#endif + + let dst_idx = params.offset_dst + i3 * params.stride_dst2 + i2 * params.stride_dst1 + i1 * params.stride_dst0; + dst[dst_idx] = sum; +} From 081dc773a5bbc9c119ccb7ec94a5fca332ccd0d5 Mon Sep 17 00:00:00 2001 From: uvos Date: Thu, 19 Mar 2026 17:05:44 +0100 Subject: [PATCH 323/831] ci : add hip quality check (llama/20430) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CI: add hip quality check * Update scripts/hip/gcn-cdna-vgpr-check.py Co-authored-by: Sigbjørn Skjæret * Update .github/workflows/hip-quality-check.yml Co-authored-by: Sigbjørn Skjæret * Update .github/workflows/hip-quality-check.yml Co-authored-by: Sigbjørn Skjæret * Update .github/workflows/hip-quality-check.yml Co-authored-by: Sigbjørn Skjæret * Update scripts/hip/gcn-cdna-vgpr-check.py Co-authored-by: Sigbjørn Skjæret * Update scripts/hip/gcn-cdna-vgpr-check.py Co-authored-by: Sigbjørn Skjæret * Update scripts/hip/gcn-cdna-vgpr-check.py Co-authored-by: Sigbjørn Skjæret * Update scripts/hip/gcn-cdna-vgpr-check.py Co-authored-by: Sigbjørn Skjæret * Revert "Update .github/workflows/hip-quality-check.yml" This reverts commit efa0bfcdb01dfac0feee674987a0482d50f46145. * scripts: gcn-cdna-vgpr-check.py: enforce int type for total_vgprs * scripts: gcn-cdna-vgpr-check.py: add flash attention instances to ignore list * Bump ccache version * Add mssing seperators to list --------- Co-authored-by: Sigbjørn Skjæret --- ggml/src/ggml-hip/CMakeLists.txt | 3 --- 1 file changed, 3 deletions(-) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index b44ed0f7215..c2357722629 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -53,9 +53,6 @@ endif() message(STATUS "HIP and hipBLAS found") -# Workaround old compilers -set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024") - file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") From 15f6b6ad76ef6aed7814d1212bfc17a8e05e0937 Mon Sep 17 00:00:00 2001 From: Yiwei Shao <44545837+njsyw1997@users.noreply.github.com> Date: Thu, 19 Mar 2026 09:11:06 -0700 Subject: [PATCH 324/831] hexagon: add Matrix Extensions (HMX) for Hexagon NPU backend (llama/20693) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * migrate(vtcm): unify VTCM management for HMX merge - Add HMX fields to htp_context (#ifdef HTP_HAS_HMX): hmx_enabled, hmx_dma, vtcm_scratch_size, exp2_table - Add HTP_VTCM_SESSION_HOLD CMake option (default ON): hold VTCM for entire session instead of per-op acquire/release - Add vtcm_op_acquire/vtcm_op_release inline wrappers: no-op in session-hold mode, delegate in per-op mode - Add VTCM tail reservation for precompute tables (256KB, 64KB aligned) in htp_iface_start under HTP_HAS_HMX - Add HMX init/cleanup hooks in htp_iface_start/stop - Add precompute table recovery in vtcm_acquire after VTCM preemption - Do NOT migrate vtcm_mgr from htp-ops-lib (replaced by tail reservation) * migrate(repack): replace x4x2 with HMX tile-permuted super-block format - Add hmx_block_q4_0/q8_0 struct definitions (scales-first + sequential quants) - Implement forward repack: repack_q4_0_to_hmx_superblock, repack_q8_0_to_hmx_superblock, repack_f16_to_tile_permuted - Implement inverse repack for get_tensor debug verification - Route set_tensor/get_tensor via opt_arch >= 73 to HMX path, else existing HVX x4x2 - MXFP4 on v73+ falls back to HVX x4x2 repack (not memcpy) - Extend supports_op: add IQ4_NL for v73+, F16 tile alignment checks - Tail blocks (K not multiple of 256): repack to x4x2 via pad-repack-truncate - Add CMake GGML_HEXAGON_HMX_TAIL_HVX option (default ON); OFF rejects non-256-aligned K in supports_op * migrate(dma): add dma_queue_push_1d() convenience wrapper for HMX ops Add 1D linear DMA transfer helper to hex-dma.h for upcoming HMX op migration. Reuses existing dma_queue_flush() for sync points instead of adding redundant dma_queue_drain(). * migrate(hmx): reorganize HMX files into htp/hmx/ and simplify HMX locking Move all 14 HMX-related files from htp/ to htp/hmx/ subdirectory for cleaner separation between HVX and HMX code. Simplify HMX hardware locking by replacing the two-level lock design (SHARED HAP lock + custom asm spin-lock) with direct HAP_compute_res_hmx_lock/unlock on the existing vtcm_rctx, which already has HMX capability. Key changes: - Create htp/hmx/ subdirectory with all HMX infrastructure and ops - Replace hmx_mgr_ctx_id + spin-lock with HAP_compute_res_hmx_lock(vtcm_rctx) - Remove hmx_manager_enable/disable_execution() (SHARED lock no longer needed) - Add hmx_set_vtcm_state() call in main.c (was missing, caused null globals) - Update main.c includes to use hmx/ prefix - Clean up duplicate declarations from hmx-worker-pool.h * migrate(hmx-infra): consolidate HMX infrastructure into htp_context - Remove hmx-mgr.c/h: eliminate global HMX state singleton, thread htp_context through all HMX ops - Remove hmx-worker-pool.c/h: replace separate HMX worker pool with main worker_pool API (worker_pool_run_func) - Replace hmx_unit_acquire/release with direct HAP_compute_res_hmx_lock/unlock on ctx->vtcm_rctx - Remove HTP_VTCM_SESSION_HOLD compile option: always use per-op vtcm_acquire/release - Remove hmx_dma from htp_context: HMX ops use ctx->dma[0] instead of separate DMA queue - Simplify main.c init/cleanup: remove hmx_manager_setup/reset and vtcm_op_acquire/release wrappers - Delete upstream llama.cpp AGENTS.md (not applicable to fork) * migrate(flash-attn): remove HTP_EXP2_TABLE_COPIES, use single exp2 table - Remove HTP_EXP2_TABLE_COPIES compile definition and CMake cache variable - Remove table duplication loop in precompute-table.c - Remove worker_index % N sub-table indexing in hmx-flash-attn-ops.c - Fix table_size to 65536 (single 64 KB copy) in main.c The exp2 lookup table is read-only; concurrent VTCM reads do not cause bank conflicts, so duplicating the table wastes 192 KB of VTCM for no benefit. * migrate(dsp-main): add HMX priority dispatch in packet_callback - Add proc_hmx_matmul_req() wrapper for HMX mat_mul (F16 and quantized types) - Add proc_hmx_flash_attn_req() wrapper for HMX simple_flash_attn (FP16 only, falls back to HVX for non-FP16) - Add proc_hmx_rms_norm_req() wrapper using hvx_rms_norm_f32 - Route MUL_MAT, FLASH_ATTN_EXT, RMS_NORM through HMX path when ctx->hmx_enabled - Split RMS_NORM and SCALE into separate case blocks for independent dispatch - All HMX wrappers guarded by #ifdef HTP_HAS_HMX * migrate(cmake-dsp): add HMX source files and -mhmx for v73+ skels Add HTP_VTCM_SESSION_HOLD option (default ON) and v73+ HMX build integration: compile hmx-matmul-ops, hmx-flash-attn-ops, hmx-rms-norm-ops and precompute-table into v73/v75/v79/v81 skels with -mhmx flag and HTP_HAS_HMX=1 definition. v68/v69 skels remain unchanged. * migrate(hmx-ops): fix compile errors in HMX ops for ggml struct compatibility - hmx-matmul-ops.c: include ggml-common.h for block_q4_0/block_q8_0 definitions - hmx-matmul-ops.c: rename quants->qs, scale->d to match upstream ggml field names - hmx-flash-attn-ops.c: suppress -Wunused-function/-Wunused-variable warnings - hmx-flash-attn-ops.c: inline ctx->n_threads, remove unused n_workers variable * hmx: set Q/O element type to fp16 for flash attention The llama.cpp integration passes fp16 Q/O tensors, so qo_fp32_element should be false to match the actual data layout. * hexagon: unify HMX weight format to x4x2, add IQ4_NL and DSP-side fallback Remove the v73+ HMX-specific super-block/tile-permuted weight format and unify all architectures on the HVX x4x2 packed format. The DSP now decides at runtime whether to use the HMX or HVX matmul path based on dimension constraints (M%32, N%32, K%256 alignment), rather than the host rejecting ops in supports_op. This simplifies the host repack logic, eliminates ~400 lines of HMX super-block code, and adds IQ4_NL quantization support across host and DSP. Key changes: - Remove hmx_block_q4_0/q8_0 types, repack functions, and F16 tile permutation (ggml-hexagon.cpp, hmx-quants.h) - Simplify set_tensor/get_tensor to always use x4x2 repack, add IQ4_NL - Force is_host=false so tensor copies go through format conversion - Add HTP_TYPE_IQ4_NL to DSP message protocol (htp-msg.h) - Rewrite DSP dequantizers to work directly on x4x2 layout (hmx-matmul-ops.c) - Fix mxclracc.hf placement: clear per output tile, not once globally - Move HMX eligibility checks to DSP proc_hmx_matmul_req (main.c) - Remove dma_queue_push_1d wrapper, use 2D DMA for weight sub-blocks - Add VTCM allocation overflow asserts - Remove GGML_HEXAGON_HMX_TAIL_HVX build option (CMakeLists.txt) * Enhance HMX debugging capabilities with new tile dumping functions - Introduced hmx_dump_tile_mem and hmx_dump_fp32_tile_region for improved memory layout visualization of tile data. - Updated hmx_dump_tile_rows to provide raw memory output for debugging. - Added debug logging for activation and weight tile pairs during processing to facilitate troubleshooting. - Refined existing macros for dumping HVX vector values to streamline debugging output. These changes aim to enhance the debugging experience for HMX matmul operations, ensuring better visibility into data handling and transformations. * OK for small mat mul * hexagon: fix UDMA roiwidth 16-bit overflow in HMX matmul DMA transfers The UDMA descriptor roiwidth field is 16-bit (max 65535), but large matrix DMA transfers (e.g. 32×2304 = 73728 bytes) exceeded this limit, causing truncated transfers and NaN results. Fix by using 2D DMA (per-row stride × n_rows) instead of 1D (total_size × 1) for all 4 DMA push calls in both x4x2 and fp16 weight paths. Also includes: - Use standard vlut16 instead of _nomatch variant for dequantization - Add per-tile vscatter drain barrier for correctness - Add compile-time HMX_DEBUG_TRACE_VALUES instrumentation (disabled by default) * hexagon: remove HMX RMS norm fallback and re-enable matmul pipeline Remove hmx-rms-norm-ops.c as the HVX RMS norm offers no benefit over the generic unary path. Re-enable DMA pipeline mode for QK matmul. * hexagon: guard all HMX matmul DMA transfers against UDMA 16-bit field overflow All UDMA type1 descriptor fields (roiwidth, roiheight, srcstride, dststride) are 16-bit (max 65535). Commit 40d2a9cc fixed roiwidth overflow in the non-pipeline path by switching from 1D to 2D DMA, but the pipeline path (3 call sites) was left unchanged and still used 1D DMA with chunk_size = n_cols * row_stride as roiwidth, which overflows for any practical matrix size when the pipeline is active. Add a local hmx_dma_push_safe() helper that transparently handles overflow: - Fast path (zero overhead): all params fit in 16 bits -> direct call. - Contiguous block: reshapes into a single 2D descriptor with sub_width that fits in 16 bits, preserving async DMA behavior. - Stride overflow: row-by-row fallback for future large-k models where per-row stride itself exceeds 65535. Convert all 8 external dma_queue_push calls in hmx-matmul-ops.c to use the safe helper, including the 3 pipeline sites (1D -> 2D fix), the FP16 and x4x2 weight paths, qweight_fetch sub-block DMA, and the output-stationary activation fetch. * hexagon: multithread activation/output transfer and add HMX matmul fallback - Replace single-threaded transfer_activation_chunk_fp32_to_fp16 with transfer_activation_chunk_multithread across all HMX matmul paths - Add multi-threaded transfer_output_chunk_multithread for FP16-to-FP32 output store, following the same worker pool pattern - Rename transfer_activation_chunk_no_prefetch back to transfer_activation_chunk_fp32_to_fp16 and clean up stale comments - Add HVX fallback in proc_hmx_matmul_req when HMX matmul returns error * [todo]: dynamic alloc vtcm, cause prefill regression. * hexagon: constrain HMX mxmem tile load region to avoid VTCM bank boundary faults Set activation/weight mxmem Rt to 2047 for single-tile loads and document the 4MB VTCM bank boundary constraint, preventing precise bus errors when dynamic VTCM allocation places tiles near bank edges. * hexagon: split unaligned-M HMX matmul into HMX+HVX phases - keep HMX for the 32-aligned head rows and process tail rows with HVX - force re-quantization for HVX tail after HMX phase to avoid stale VTCM state - preserve fallback behavior when N is unaligned or no aligned M rows exist * hexagon: batch-4 Q4_0 dequantize fast path and remove debug traces Add dequantize_x4x2_q4_0_x4groups_hvx() that processes 4 contiguous K-tiles with a single vmemu + vlut16 per row, reducing per-tile overhead. The dequantize loop now takes the batch-4 path when 4 aligned K-tiles are available within the same column tile, falling back to the original single-tile path otherwise. Also removes HMX_DEBUG_TRACE_VALUES instrumentation blocks that are no longer needed. * hexagon: abort on DSP error and fix HMX-to-HVX fallback quantize flag Promote DSP response error from log to GGML_ABORT for fail-fast behavior. Clear SKIP_QUANTIZE flag when falling back from HMX to HVX matmul so the HVX path correctly re-quantizes activations. * hexagon: support batch matmul. This fix perplexity issue The problem comes from Grouped-Query Attention(GQA). Strides between batches are not well respected TODO: optimize batch matmul to reuse weights between batches. * hexagon: reuse weights in fp16 batch matmul * hexagon: remove unused HMX flash attention operations and precomputation table, remove the log system for test * hexagon: remove unused HVX math helpers, debug infrastructure, and stale build options * hexagon: fix HMX not enabled due to missing force_hvx parameter in IDL * hexagon: remove the unnecessary changes not related to HMX * hexagon: bypass HMX by default * hexagon: add upstream repo link to htp-ops-lib ported file headers * hexagon: restore host buffer support * hexagon: add HMX=1 option for the adb scripts * hex-hmx: improve DMA pipelining * hex-hmx: further improvements to dma pipelining * hex-hmx: minor cleanup * hex-hmx: move hmx lock out of inner loops/calls * hex-hmx: remove unnecessary state and wrappers * hex-hmx: remove hmx dir and unify f32 to f16 conversions * hex-hmx: further unify hvx conversions * hex-hmx: revert f16 converter to the original for now * hex-hmx: minor cleanup for f16 to f32 converter * hex-mm: replace incorrect fp16-to-fp32 hmx converter and reformated related code * hex-dma: move chanied dma push into hex-dma.h header and update hmx-mm * hex-mm: use hex_is_aligned instead of a duplicated hmx_is_aligned * hex-mm: use hvx_vec_splat_f16 in the hmx code * hex-mm: use VLEN and HTP types in hmx-code * hex-mm: remove duplicate QK and defs * hexagon: pre-shuffle quants before vlut16 * hexagon: enable HMX by default * hex-mm: code indent fixes for hmx-matmul * hexagon: update hex-utils to include align/smin/etc helpers and use that in hmx mm * hex-mm: more formatting fixes * hex-mm: minor naming updates in hmx code * hex-mm: remove leftover from rebase conflict * Fix the incorrect indents --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 7 +- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 18 + ggml/src/ggml-hexagon/htp/hex-dma.h | 80 + ggml/src/ggml-hexagon/htp/hex-utils.h | 22 +- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 1528 ++++++++++++++++++++ ggml/src/ggml-hexagon/htp/hmx-ops.h | 72 + ggml/src/ggml-hexagon/htp/hmx-profile.h | 34 + ggml/src/ggml-hexagon/htp/hmx-utils.h | 88 ++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 6 + ggml/src/ggml-hexagon/htp/htp-msg.h | 19 +- ggml/src/ggml-hexagon/htp/htp_iface.idl | 2 +- ggml/src/ggml-hexagon/htp/hvx-base.h | 37 +- ggml/src/ggml-hexagon/htp/main.c | 246 +++- 13 files changed, 2142 insertions(+), 17 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/hmx-ops.h create mode 100644 ggml/src/ggml-hexagon/htp/hmx-profile.h create mode 100644 ggml/src/ggml-hexagon/htp/hmx-utils.h diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 4b8a16c3635..8bcf5291c11 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -45,6 +45,7 @@ static int opt_verbose = 0; static int opt_profile = 0; static int opt_hostbuf = 1; // hostbuf ON by default static int opt_experimental = 0; +static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only // Enable all stages by default static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE; @@ -1693,7 +1694,7 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { // Start the DSP-side service. We need to pass the queue ID to the // DSP in a FastRPC call; the DSP side will import the queue and start // listening for packets in a callback. - err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx); + err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx); if (err != 0) { GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err); throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); @@ -3372,6 +3373,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); const char * str_etm = getenv("GGML_HEXAGON_ETM"); const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); + const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX"); const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); const char * str_arch = getenv("GGML_HEXAGON_ARCH"); @@ -3381,8 +3383,9 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask; opt_opsync = str_opsync ? atoi(str_opsync) : 0; opt_profile = str_profile ? atoi(str_profile) : 0; - opt_etm = str_etm ? atoi(str_etm) : 0; + opt_etm = str_etm ? atoi(str_etm) : 0; opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; + opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index a490a2ce9a1..6ddfe4252f5 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -40,6 +40,24 @@ target_compile_definitions(${HTP_LIB} PRIVATE $,FARF_HIGH=1,> FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) +# HMX acceleration: available on v73+ architectures +set(HTP_HMX_VERSIONS v73 v75 v79 v81) +list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) + +if (_hmx_idx GREATER_EQUAL 0) + target_sources(${HTP_LIB} PRIVATE + hmx-matmul-ops.c + ) + + # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) + set_source_files_properties( + hmx-matmul-ops.c + PROPERTIES COMPILE_OPTIONS "-mhmx" + ) + + target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1) +endif() + build_idl(htp_iface.idl ${HTP_LIB}) set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h index 350ab9d966f..9811a07599f 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -175,6 +175,86 @@ static inline uint32_t dma_queue_capacity(dma_queue * q) { return q->capacity; } +// --------------------------------------------------------------------------- +// Overflow-safe DMA push: all UDMA type1 descriptor fields (roiwidth, +// roiheight, srcstride, dststride) are 16-bit, max 65535. This helper +// transparently handles values that exceed the 16-bit limit and submits +// chained DMA transtions. +// +// Case 1 (fast path): all params fit in 16 bits -> direct dma_queue_push. +// Case 2 (contiguous block): width == srcstride == dststride. Reshape the +// flat transfer into a 2D descriptor with sub_width <= 65535. Produces a +// single descriptor, preserving async DMA behavior. +// Case 3 (stride overflow): srcstride or dststride > 65535. Issue rows +// one at a time. The first N-1 rows are pushed+popped synchronously; +// the last row is left async so the caller can pop it. +// --------------------------------------------------------------------------- +#define UDMA_MAX_FIELD_VAL 65535u + +static inline bool dma_queue_push_chained(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t width, size_t nrows) { + // Fast path: everything fits in 16 bits. + if (__builtin_expect( + width <= UDMA_MAX_FIELD_VAL && + nrows <= UDMA_MAX_FIELD_VAL && + src_stride <= UDMA_MAX_FIELD_VAL && + dst_stride <= UDMA_MAX_FIELD_VAL, 1)) { + return dma_queue_push(q, dptr, dst_stride, src_stride, width, nrows); + } + + // Case 2: contiguous block (width == src_stride == dst_stride). + // Reshape total bytes into sub_width * sub_nrows where sub_width <= 65535. + if (width == src_stride && width == dst_stride) { + size_t total = width * nrows; + + // Pick the largest 128-byte-aligned sub_width that divides total evenly. + size_t sub_width = UDMA_MAX_FIELD_VAL & ~(size_t)127; // 65408 + while (sub_width > 0 && total % sub_width != 0) { + sub_width -= 128; + } + if (sub_width == 0) { + // Fallback: use original width (must fit) with adjusted nrows. + // This shouldn't happen for 128-aligned DMA sizes. + sub_width = width; + } + size_t sub_nrows = total / sub_width; + + // Handle sub_nrows > 65535 by issuing chunked descriptors. + const uint8_t *src = (const uint8_t *)dptr.src; + uint8_t *dst = (uint8_t *)dptr.dst; + size_t rows_done = 0; + while (rows_done < sub_nrows) { + size_t chunk = sub_nrows - rows_done; + if (chunk > UDMA_MAX_FIELD_VAL) chunk = UDMA_MAX_FIELD_VAL; + + dma_ptr p = dma_make_ptr(dst + rows_done * sub_width, src + rows_done * sub_width); + if (!dma_queue_push(q, p, sub_width, sub_width, sub_width, chunk)) + return false; + + rows_done += chunk; + // Complete all chunks without waiting except the last one, so the + // caller's single dma_queue_pop drains the final descriptor. + if (rows_done < sub_nrows) + dma_queue_pop_nowait(q); + } + return true; + } + + // Case 3: stride overflow — fall back to row-by-row. + { + const uint8_t *src = (const uint8_t *)dptr.src; + uint8_t *dst = (uint8_t *)dptr.dst; + for (size_t r = 0; r < nrows; ++r) { + dma_ptr p = dma_make_ptr(dst + r * dst_stride, + src + r * src_stride); + if (!dma_queue_push(q, p, 0, 0, width, 1)) + return false; + if (r + 1 < nrows) + dma_queue_pop_nowait(q); + } + return true; + } +} + #ifdef __cplusplus } // extern "C" #endif diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index fb8a25a3f20..8ed1456bc54 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -29,10 +29,22 @@ static inline uint64_t hex_get_pktcnt() { return pktcnt; } -static inline int32_t hex_is_aligned(void * addr, uint32_t align) { +static inline size_t hmx_ceil_div(size_t num, size_t den) { + return (num + den - 1) / den; +} + +static inline int32_t hex_is_aligned(const void * addr, uint32_t align) { return ((size_t) addr & (align - 1)) == 0; } +static inline size_t hex_align_up(size_t v, size_t align) { + return hmx_ceil_div(v, align) * align; +} + +static inline size_t hex_align_down(size_t v, size_t align) { + return (v / align) * align; +} + static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { uint32_t left_off = (size_t) addr & (chunk_size - 1); uint32_t right_off = left_off + n; @@ -43,6 +55,14 @@ static inline uint32_t hex_round_up(uint32_t n, uint32_t m) { return m * ((n + m - 1) / m); } +static inline size_t hex_smin(size_t a, size_t b) { + return a < b ? a : b; +} + +static inline size_t hex_smax(size_t a, size_t b) { + return a > b ? a : b; +} + static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) { const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); Q6_l2fetch_AP((void *) p, control); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c new file mode 100644 index 00000000000..c703a049426 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -0,0 +1,1528 @@ +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include +#include +#include + +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" + +#include "hex-dma.h" +#include "hvx-utils.h" +#include "hvx-dump.h" +#include "worker-pool.h" +#include "htp-ctx.h" +#include "htp-msg.h" + +#include "hmx-utils.h" +#include "hmx-ops.h" +#include "hmx-profile.h" + +static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, +}; + +static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0, + 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, +}; + +// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile. +// word[i] = i*128 maps K-row-pair i to byte offset i*128 in the tile. +// Column offset (n*4) is added at runtime. Only entries 0..15 are used (masked by predicate). +static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = { + 0*128, 1*128, 2*128, 3*128, 4*128, 5*128, 6*128, 7*128, + 8*128, 9*128, 10*128, 11*128, 12*128, 13*128, 14*128, 15*128, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; + +// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes +#define HMX_X4X2_SCALES_PER_BLK 8 +#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes + +static inline void swap_ptr(void **p1, void **p2) { + void *t = *p1; + *p1 = *p2; + *p2 = t; +} + +typedef struct { + uint8_t *dst; + const uint8_t *src; + dma_queue *dma; + size_t n_rows; + size_t src_stride; // DDR row stride (full row_stride) + size_t dst_stride; // VTCM sub-block row stride + size_t quant_off; // quant byte offset in each DDR row + size_t quant_width; // quant bytes to copy per row + size_t scale_off; // scale byte offset in each DDR row + size_t scale_width; // scale bytes to copy per row +} qweight_fetch_task_state_t; + +// Compute the byte stride of one row in x4x2 format. +// Numerically equals ggml_row_size(type, k) when k is 256-aligned, because +// x4x2 packing has the same density as block_q4_0 / block_q8_0. +// Layout per row: [quants: nb*128 (Q4) or nb*256 (Q8)][scales: nb*16 bytes] +// Total per row = nb * (128+16) = 144*nb (Q4) or nb * (256+16) = 272*nb (Q8). +// Callers must ensure k is a multiple of 256 (enforced by proc_hmx_matmul_req). +static inline size_t get_x4x2_row_stride(int weight_type, int k) { + int nb = (k + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; + switch (weight_type) { + case HTP_TYPE_Q4_0: + case HTP_TYPE_IQ4_NL: + return (size_t)nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb + case HTP_TYPE_Q8_0: + return (size_t)nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb + default: + return 0; + } +} + +// --- Overflow-safe arithmetic for VTCM budget calculation --- + +static inline bool hmx_mul_overflow(size_t a, size_t b, size_t *out) { + if (a != 0 && b > SIZE_MAX / a) return true; + *out = a * b; + return false; +} + +static inline bool hmx_add_overflow(size_t a, size_t b, size_t *out) { + if (a > SIZE_MAX - b) return true; + *out = a + b; + return false; +} + +// Search for optimal (mc, nc) chunk sizes that maximize mc * nc within VTCM budget. +// +// Cost model: total = nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead +// per_n_cost: bytes per nc column (weight + scratch buffers) +// per_m_cost: bytes per mc row (activation) +// per_mn_cost: bytes per mc*nc element (output) +// overhead: fixed bytes (scales 256B, eye_tile 2048B, etc.) +// +// Algorithm: nc sweeps from n_max down by 32, analytically solving for mc_max. +// Returns 0 on success, -1 if VTCM is insufficient. +static int hmx_compute_chunks( + size_t vtcm_total, size_t overhead, + size_t per_n_cost, size_t per_m_cost, size_t per_mn_cost, + int m, int n, + size_t *m_chunk_out, size_t *n_chunk_out, + size_t *total_out) +{ + if (m <= 0 || n <= 0) return -1; + if (vtcm_total <= overhead) return -1; + if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1; + + const size_t usable = vtcm_total - overhead; + size_t best_mn = 0, best_m = 0, best_n = 0; + + const size_t n_max = hex_align_down((size_t)n, HMX_FP16_TILE_N_COLS); + for (size_t nc = n_max; nc >= HMX_FP16_TILE_N_COLS; nc -= HMX_FP16_TILE_N_COLS) { + // Early exit: if nc * m_max cannot beat best, smaller nc won't either + if (nc * hex_align_down((size_t)m, HMX_FP16_TILE_N_ROWS) <= best_mn) + break; + + size_t n_fixed = 0, ncmn = 0, mc_denom = 0; + if (hmx_mul_overflow(nc, per_n_cost, &n_fixed)) continue; + if (n_fixed >= usable) goto next_nc; + + if (hmx_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc; + if (hmx_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc; + + { + size_t remain = usable - n_fixed; + size_t mc = remain / mc_denom; + mc = hex_align_down(mc, HMX_FP16_TILE_N_ROWS); + mc = hex_smin(mc, (size_t)m); + + if (mc > 0 && mc * nc > best_mn) { + best_mn = mc * nc; + best_m = mc; + best_n = nc; + } + } + +next_nc: + if (nc == HMX_FP16_TILE_N_COLS) break; // avoid size_t underflow + } + + if (best_m == 0 || best_n == 0) return -1; + + // Compute exact total (with overflow checks) + size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0; + if (hmx_mul_overflow(best_n, per_n_cost, &t0)) return -1; + if (hmx_mul_overflow(best_m, per_m_cost, &t1)) return -1; + if (hmx_mul_overflow(best_m, best_n, &mn)) return -1; + if (hmx_mul_overflow(mn, per_mn_cost, &t2)) return -1; + if (hmx_add_overflow(t0, t1, &total)) return -1; + if (hmx_add_overflow(total, t2, &total)) return -1; + if (hmx_add_overflow(total, overhead, &total)) return -1; + + *m_chunk_out = best_m; + *n_chunk_out = best_n; + *total_out = total; + return 0; +} + +// forward declaration – defined after transfer_activation_chunk_fp32_to_fp16 +void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride); + +// Scatter row-major FP16 weight (already in VTCM scratch) directly into transposed [K][N] tiles. +// vtcm_src: [n_cols][k] row-major fp16 in VTCM scratch buffer +// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16 +static void interleave_fp16_weight_chunk_to_tiles(__fp16 *restrict vtcm_dst, + const __fp16 *restrict vtcm_src, + int n_cols, int k) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + assert(k % HMX_FP16_TILE_N_COLS == 0); + + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const HVX_Vector v_scat_base = hvx_vmem(weight_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + for (int r = 0; r < n_cols; r += 2) { + int ct = r / HMX_FP16_TILE_N_ROWS; // N-dimension tile index + int local_r = r % HMX_FP16_TILE_N_ROWS; // intra-tile row index + const bool next_row_valid = (r + 1) < n_cols; + + // Offset vectors for N-columns local_r and local_r+1, reused across K-tiles. + HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); + HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); + + for (int c = 0; c < k; c += HMX_FP16_TILE_N_COLS) { + int kt = c / HMX_FP16_TILE_N_COLS; + int tile_idx = ct * n_k_tiles + kt; + __fp16 *tile_base = vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS; + + HVX_Vector v0 = hvx_vmemu(vtcm_src + r * k + c); + HVX_Vector v1 = next_row_valid ? hvx_vmemu(vtcm_src + (r + 1) * k + c) : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off0, v0); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off1, v1); + } + } +} + +// --- x4x2 format dequantizers --- + +// Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes. +// In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles +// of the same 32 packed bytes. +static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx( + const uint8_t *packed_32, bool upper_nibbles, + const __fp16 *scale, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_scales = hvx_vec_splat_f16(*scale); + // q4x4x2 stores two int4 values per byte. Keep only the selected nibble. + HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + // Shuffle before LUT + v_quants = Q6_Vb_vshuff_Vb(v_quants); + // Use standard vlut16 (not _nomatch) to avoid stale-register NaN. + // _nomatch retains the previous destination-register value for colliding + // indices, but the C intrinsic doesn't model the implicit read so the + // compiler may allocate a register containing garbage/NaN. + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_hf = Q6_V_lo_W(vp); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); +} + +// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using +// full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls. +// Output: out[0..3] each hold 32 FP16 values in the first 64 bytes. +static inline void dequantize_x4x2_q4_0_x4groups_hvx( + const uint8_t *packed_128, bool upper_nibbles, + const __fp16 *scales_4, const HVX_Vector vlut_cvt, + HVX_Vector out[4]) { + // Load all 128 packed bytes (4 contiguous 32-byte groups) + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + // Shuffle before LUT + v_quants = Q6_Vb_vshuff_Vb(v_quants); + + // Full-width vlut16: 128 byte lookups -> 128 fp16 results in a VectorPair + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_lo = Q6_V_lo_W(vp); // [group0: 32 fp16 | group1: 32 fp16] + HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16] + + // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b + HVX_VectorPred q64 = Q6_Q_vsetq_R(64); + HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[0]), hvx_vec_splat_f16(scales_4[1])); + HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[2]), hvx_vec_splat_f16(scales_4[3])); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); + + // Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter + out[0] = v_lo; // group0 already in [0:63] + out[1] = Q6_V_vror_VR(v_lo, 64); // group1 rotated to [0:63] + out[2] = v_hi; // group2 already in [0:63] + out[3] = Q6_V_vror_VR(v_hi, 64); // group3 rotated to [0:63] +} + +// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. +static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx( + const int8_t *quants_32, const __fp16 *scale) { + HVX_Vector vq = hvx_vmemu(quants_32); + HVX_Vector v_scales = hvx_vec_splat_f16(*scale); + HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq)); + HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); +} + +// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. +// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes. +// Output: vtcm_dst in tile-major FP16 layout. +static void dequantize_x4x2_weight_to_fp16_tiles_task( + __fp16 *restrict vtcm_dst, + const uint8_t *restrict vtcm_src, + int n_cols, int k_block, + size_t row_stride, int weight_type, + int start_tile, int end_tile) { + + const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; + const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); + const int qrow_size = is_q4 ? (k_block / 2) : k_block; + + const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) + ? hvx_vmem(iq4_nl_to_fp16_lut) : hvx_vmem(q4_0_to_fp16_lut); + + // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. + // Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128 + // maps to K-rows 2i and 2i+1. Column offset (n*4) added per row. + const HVX_Vector v_scat_base = hvx_vmem(weight_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes) + + for (int t = start_tile; t < end_tile; ) { + int ct = t / n_k_tiles; // column tile index + int kt = t % n_k_tiles; // K tile index + + // --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row --- + if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { + int blk_idx = (kt * 32) / QK_Q4_0x4x2; + int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 + bool upper = (sub_blk_base >= 4); + int packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes + int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + + sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales + + __fp16 *tile_bases[4]; + for (int g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; } + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + const uint8_t *r0 = vtcm_src + row0 * row_stride; + const uint8_t *r1 = vtcm_src + row1 * row_stride; + + HVX_Vector v0[4], v1[4]; + dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); + if (row1 < n_cols) { + dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt, v1); + } else { + v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero(); + } + + for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); } + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); } + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + + for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } + + t += 4; + continue; + } + + // --- Single-tile fallback --- + __fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS; + + if (is_q4) { + int blk_idx = (kt * 32) / QK_Q4_0x4x2; + int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; + bool upper = (sub_blk >= 4); + int byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; + int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; // reset to column 0 + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = vtcm_src + row0 * row_stride; + const uint8_t *r1 = vtcm_src + row1 * row_stride; + + HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx( + r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector v1 = (row1 < n_cols) + ? dequantize_x4x2_q4_0_group_hvx( + r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) + : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } else { + // Q8_0 + int blk_idx = (kt * 32) / QK_Q8_0x4x2; + int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32; + int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32; + int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; // reset to column 0 + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = vtcm_src + row0 * row_stride; + const uint8_t *r1 = vtcm_src + row1 * row_stride; + + HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx( + (const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off)); + HVX_Vector v1 = (row1 < n_cols) + ? dequantize_x4x2_q8_0_group_hvx( + (const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) + : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; + } + + // Drain HVX scatter write buffer: a vmem load on the same HW thread retires + // all pending scatter entries to VTCM. Without this, the main thread's HMX + // reads may see stale data because atomic_fetch_sub (release) only orders + // regular stores, not the HVX scatter buffer. + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(vtcm_dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +typedef struct { + __fp16 *dst; + const uint8_t *src; + int n_cols; + int k_block; + size_t row_stride; + int weight_type; + int n_tot_tiles; + int n_tiles_per_task; + int n_tasks; +} x4x2_dequantize_state_t; + +static void dequantize_x4x2_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + + dequantize_x4x2_weight_to_fp16_tiles_task( + state->dst, state->src, state->n_cols, state->k_block, + state->row_stride, state->weight_type, start, end); + } +} + +static void dequantize_x4x2_weight_chunk_to_fp16_tiles( + struct htp_context *ctx, __fp16 *vtcm_dst, + const void *vtcm_src, int n_cols, int k_block, + size_t row_stride, int weight_type) { + + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + assert(k_block % HMX_FP16_TILE_N_COLS == 0); + + int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; + int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; + int n_tot_tiles = n_col_tiles * n_k_tiles; + + size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads); + + x4x2_dequantize_state_t state; + state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; + state.n_tot_tiles = n_tot_tiles; + state.n_tiles_per_task = n_tiles_per_task; + state.dst = vtcm_dst; + state.src = (const uint8_t *)vtcm_src; + state.n_cols = n_cols; + state.k_block = k_block; + state.row_stride = row_stride; + state.weight_type = weight_type; + + worker_pool_run_func(ctx->worker_pool, dequantize_x4x2_worker_loop, &state, ctx->n_threads); +} + +// --- End x4x2 dequantizers --- + +// requires external HMX lock +static void core_dot_chunk_fp16(__fp16 *output, const __fp16 *activation, const __fp16 *weight, const __fp16 *scales, + int n_row_tiles, int n_col_tiles, int n_dot_tiles) { + hmx_set_output_scales(scales); + + for (int r = 0; r < n_row_tiles; ++r) { + for (int c = 0; c < n_col_tiles; ++c) { + Q6_mxclracc_hf(); + + const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS; + const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS; + + for (int k = 0; k < n_dot_tiles; ++k) { + int offset = k * HMX_FP16_TILE_N_ELMS; + hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset); + } + + __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS; + hmx_consume_accumulator_fp16(out_tile); + } + } +} + +static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + const int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; + + const HVX_Vector one = hvx_vec_splat_f16(1.0); + + for (int r = 0; r < n_rows; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; + int r1 = r % HMX_FP16_TILE_N_ROWS; + + #pragma unroll(4) + for (int c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) { + int c0 = c / HMX_FP16_TILE_N_COLS; + + const __fp16 *tile = vtcm_src + (r0 * n_col_tiles + c0) * HMX_FP16_TILE_N_ELMS; + + HVX_Vector v = ((const HVX_Vector *) tile)[r1 / 2]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (dst + (r * n + c + 0)); + volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (dst + (r * n + c + n)); // next row in global memory + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (r + 1 < n_rows) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); + } + } + } +} + +typedef struct { + const __fp16 *vtcm_src; + float *dst; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int n_cols; + int n; // DDR row stride (total output columns) +} output_transfer_task_state_t; + +static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_task_state_t *st = (output_transfer_task_state_t *) data; + + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + + float *dst = st->dst + chunk_idx * st->n; + const __fp16 *vtcm_src = st->vtcm_src + chunk_idx * st->n_cols; + transfer_output_chunk_fp16_to_fp32(dst, vtcm_src, chunk_size, st->n_cols, st->n); + } +} + +static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, + int n_rows, int n_cols, int n) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = 32; // must be multiple of HMX_FP16_TILE_N_ROWS (32) + + output_transfer_task_state_t state; + state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.vtcm_src = vtcm_src; + state.n_cols = n_cols; + state.n = n; + + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, ctx->n_threads); +} + +static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) { + return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; +} + +static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) { + return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; +} + +static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + const int r2 = hmx_matmul_batch_r2(params); + const int r3 = hmx_matmul_batch_r3(params); + return (const __fp16 *) ((const uint8_t *) params->permuted_weight + + (size_t) (dst_b2 / r2) * params->src0_nb2 + + (size_t) (dst_b3 / r3) * params->src0_nb3); +} + +static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (const float *) ((const uint8_t *) params->activation + + (size_t) dst_b2 * params->src1_nb2 + + (size_t) dst_b3 * params->src1_nb3); +} + +static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (float *) ((uint8_t *) params->dst + + (size_t) dst_b2 * params->dst_nb2 + + (size_t) dst_b3 * params->dst_nb3); +} + +static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx, + const hmx_matmul_w16a32_batched_params_t *params) { + int ret = 0; + for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { + for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { + ret = hmx_mat_mul_permuted_w16a32(ctx, + hmx_matmul_dst_batch_ptr(params, b2, b3), + hmx_matmul_activation_batch_ptr(params, b2, b3), + hmx_matmul_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride); + } + } + return ret; +} + +int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) { + if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } + if (!params->m || !params->k || !params->n) { return -1; } + if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } + if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } + if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } + if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } + + if (!hex_is_aligned(params->dst, VLEN) || + !hex_is_aligned(params->activation, VLEN) || + !hex_is_aligned(params->permuted_weight, VLEN)) { + return -1; + } + + const int group_size = hmx_matmul_batch_r2(params); + + if (group_size <= 1) { + FARF(MEDIUM, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + } + + // Grouped path: reuse interleaved weight across all q_heads sharing a + // kv_head. Each q_head gets its own activation buffer in VTCM (so + // activation is loaded once per m_chunk and reused across all n_chunks), + // and each q_head is computed individually to avoid tile-major packing + // issues. m_chunk_n_rows is always a multiple of 32 (from + // hmx_compute_chunks), so per-head tile arrays don't overlap. + const size_t vtcm_budget = ctx->vtcm_scratch_size; + const size_t vec_dot_size = params->k * sizeof(__fp16); + + // When the activation has a large stride (e.g. permuted Q tensor with + // act_stride >> k), HVX vector loads from strided DDR thrash L2 cache. + // Allocate an F32 scratch buffer in VTCM and use 2D DMA to gather + // strided rows into a contiguous block before the F32->F16 conversion. + const bool use_dma_activation = (params->act_stride > params->k); + const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, + /*per_n=*/3 * vec_dot_size, + /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, + /*per_mn=*/sizeof(__fp16), + params->m, params->n, + &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + } + + const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + + if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { + FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + + FARF(MEDIUM, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, params->m, params->k, params->n, group_size, params->ne13, + m_chunk_n_rows, n_chunk_n_cols, + (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); + + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + TIMER_DEFINE(total); + + TIMER_START(total); + + const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); + + for (int b3 = 0; b3 < params->ne13; ++b3) { + for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { + const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); + + for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); + + // Pre-load activations for all heads in the group (once per m_chunk). + // When the source is strided (permuted Q), use 2D DMA to gather + // contiguous rows into a VTCM scratch buffer first, then HVX + // converts from the contiguous VTCM buffer. This avoids L2 cache + // thrashing from HVX loads at large strides. + TIMER_START(activation_load); + for (int g = 0; g < group_size; ++g) { + const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; + __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + if (use_dma_activation) { + const size_t row_bytes = (size_t) params->k * sizeof(float); + const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); + dma_queue_push_chained(ctx->dma[0], + dma_make_ptr(vtcm_f32_act, activation_chunk), + row_bytes, stride_bytes, row_bytes, n_rows); + dma_queue_pop(ctx->dma[0]); + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + vtcm_f32_act, (int) n_rows, + params->k, params->k); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + activation_chunk, (int) n_rows, + params->k, params->act_stride); + } + } + TIMER_STOP(activation_load); + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + { + const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); + + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); + + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < (size_t) params->n) { + const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; + + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + } + + interleave_fp16_weight_chunk_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k); + swap_ptr(&buf_curr, &buf_next); + } + TIMER_STOP(weight_load); + + // Reuse the interleaved weight for every q_head in this GQA group + for (int g = 0; g < group_size; ++g) { + TIMER_START(hmx_core); + { + const __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + const int n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); + const int n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); + core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, + n_row_tiles, n_col_tiles, params->k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride); + } + TIMER_STOP(output_store); + } + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + } + } + } + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total), + params->m, params->k, params->n, group_size); + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); +#endif + + return 0; +} + +int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, + const __fp16 *restrict permuted_weight, int m, int k, int n, + int act_stride, int weight_stride) { + if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } + if (act_stride < k || weight_stride < k) { return -1; } + if (k % 32 != 0 || n % 32 != 0) { return -1; } + + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { + return -1; + } + + // --- Dynamic VTCM layout --- + const size_t vtcm_budget = ctx->vtcm_scratch_size; + const size_t vec_dot_size = k * sizeof(__fp16); + + // DMA-based activation gather for strided tensors (see batched path comment). + const bool use_dma_activation = (act_stride > k); + const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + if (hmx_compute_chunks(vtcm_budget, + /*overhead=*/ 256, + /*per_n=*/ 3 * vec_dot_size, // W + S0 + S1 + /*per_m=*/ vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch + /*per_mn=*/ sizeof(__fp16), // O + m, n, + &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); + return -1; + } + + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + + // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch] + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { + FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + return -1; + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + + FARF(MEDIUM, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + + TIMER_DEFINE(total); + TIMER_START(total); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + // transfer activation matrix chunk into VTCM + size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + + TIMER_START(activation_load); + { + const float *activation_chunk = activation + mr * act_stride; + if (use_dma_activation) { + const size_t row_bytes = (size_t) k * sizeof(float); + const size_t stride_bytes = (size_t) act_stride * sizeof(float); + dma_queue_push_chained(ctx->dma[0], + dma_make_ptr(vtcm_f32_act, activation_chunk), + row_bytes, stride_bytes, row_bytes, n_rows); + dma_queue_pop(ctx->dma[0]); + transfer_activation_chunk_threaded(ctx, vtcm_activation, + vtcm_f32_act, n_rows, k, k); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_activation, + activation_chunk, n_rows, k, act_stride); + } + } + TIMER_STOP(activation_load); + + const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16); + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + // issue async DMA for the first weight chunk + // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow. + // The source rows can be strided (e.g. KV-cache K after ggml_permute). + { + const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); + + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } + + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready + + // issue async DMA for the next weight chunk (double buffering) + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < n) { + const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; + + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + } + + // interleave row-major fp16 from scratch into tile-major in vtcm_weight + interleave_fp16_weight_chunk_to_tiles(vtcm_weight, (const __fp16 *)buf_curr, n_cols, k); + + swap_ptr(&buf_curr, &buf_next); + } + TIMER_STOP(weight_load); + + TIMER_START(hmx_core); + { + const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); + } + TIMER_STOP(output_store); + } + + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n); + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); + { + size_t weight_size = (size_t)k * n * sizeof(__fp16); + float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); + FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); + } +#endif + + return 0; +} + +int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m, + int k, int n, int w_type); + +int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, + const uint8_t *restrict permuted_weight, int m, int k, int n, + int weight_type) { + if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } + if (k % 32 != 0 || n % 32 != 0) { return -1; } + + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { + return -1; + } + + // for large m, k (e.g. prefill FFN Down), use out-stationary version + if (m >= 128 && k > n && n > 1024) { + FARF(MEDIUM, "hmx_matmul_qk: OUT-STATIONARY path m=%d k=%d n=%d type=%d (K_BLOCK=512, %d K-iters with fp16 intermediate)", + m, k, n, weight_type, (k + 511) / 512); + return mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); + } + + size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + FARF(MEDIUM, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); + + // --- Dynamic VTCM layout --- + const size_t vtcm_budget = ctx->vtcm_scratch_size; + const size_t vec_dot_size = k * sizeof(__fp16); + const bool use_pipeline = (m >= 128) && (k <= n); + + // Select cost parameters based on execution path + size_t per_n_cost, per_mn_cost; + if (use_pipeline) { + per_n_cost = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) + per_mn_cost = 2 * sizeof(__fp16); // O x 2 (output double buffer) + } else { + per_n_cost = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs) + per_mn_cost = sizeof(__fp16); // O x 1 + } + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, + per_n_cost, /*per_m=*/vec_dot_size, per_mn_cost, + m, n, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d pipe=%d budget=%zu)", + __func__, m, k, n, use_pipeline, vtcm_budget); + return -1; + } + + // Compute precise buffer sizes per execution path + const size_t weight_area_size = hex_align_up( + n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up( + m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + + size_t scratch0_size, scratch1_size, scratch2_size; + if (use_pipeline) { + scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 + scratch1_size = scratch0_size; // dequant buf 1 + scratch2_size = output_area_size; // output buf 1 + } else { + scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0 + scratch1_size = scratch0_size; // x4x2 DMA buf 1 + scratch2_size = 0; // unused + } + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); + void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { + FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + return -1; + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + + FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, m, k, n, weight_type, use_pipeline, + m_chunk_n_rows, n_chunk_n_cols, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + + TIMER_DEFINE(total); + TIMER_START(total); + + FARF(MEDIUM, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu", + use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + if (!use_pipeline) { + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + // transfer activation matrix chunk into VTCM + size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + + TIMER_START(activation_load); + { + const float *activation_chunk = activation + mr * k; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); + } + TIMER_STOP(activation_load); + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + // issue async DDR data transfer for the first weight chunk + // NOTE: use 2D DMA (n_cols rows x row_stride bytes) instead of 1D + // because UDMA roiwidth is 16-bit and total size can exceed 65535. + { + const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first); + } + + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); // wait until current weight chunk become ready + + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < n) { + const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); + + const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride; + + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next); + } + + // Dequant + vscatter writes directly to [K, N] transposed tiles. + // HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight. + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type); + + swap_ptr(&buf_curr, &buf_next); + } + TIMER_STOP(weight_load); + + TIMER_START(hmx_core); + { + const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); + } + TIMER_STOP(output_store); + } + } + } else { + // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) + // stage B and D (dequantize and store) are expected to be on the critical path + + // A --> B: vtcm_qweight, 1 buffer + // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers + // C --> D: vtcm_output0/vtcm_output1, 2 buffers + + // + // LD ||A3| | B3 || + // MM || C2 || + // ST || D1 | || + + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + + void *vtcm_qweight = vtcm_weight; + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + + // prologue: A0 + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + { + // Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow. + const uint8_t *qweight_chunk_A0 = permuted_weight; + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); + } + + { + const float *activation_chunk = activation + mr * k; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); + } + + // prologue: B0, A1, C0, B1 + { + // B0 + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); + + // A1 + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (1 < n_chunk_cnt) { + const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); + } + + // C0 + core_dot_chunk_fp16((__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + + // B1 + if (1 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); + } + } + + // main loop + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + + // issue A_{i+2} + if (i + 2 < n_chunk_cnt) { + const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); + } + + // wait for HMX (C_{i}) -- C_{i} is done + + // result of B_{i+1} (input of C_{i+1}) should be ready now + + // issue C_{i+1} + if (i + 1 < n_chunk_cnt) { + core_dot_chunk_fp16((__fp16 *) vtcm_output_bufs[(i + 1) % 2], (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + } + + // compute D_{i} + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); + + // wait for DMA (A_{i+2}), compute B_{i+2} + if (i + 2 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); + } + } + } + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d pipeline=%d", __func__, TIMER_US(total), m, k, n, use_pipeline); + if (!use_pipeline) { + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); + size_t weight_size = (size_t)n * row_stride; + float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); + FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); + } +#endif + + return 0; +} + +// C += AB +void core_mma_chunk_fp16(__fp16 *c, const __fp16 *a, const __fp16 *b, const __fp16 *col_scales, const __fp16 *eye_tile, + int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { + + hmx_set_output_scales(col_scales); + + for (int i = 0; i < n_row_tiles; ++i) { + for (int j = 0; j < n_col_tiles; ++j) { + Q6_mxclracc_hf(); + + const __fp16 *row_tiles = a + i * n_dot_tiles * HMX_FP16_TILE_N_ELMS; + const __fp16 *col_tiles = b + j * n_dot_tiles * HMX_FP16_TILE_N_ELMS; + + __fp16 *accum_tile = c + (i * n_col_tiles + j) * HMX_FP16_TILE_N_ELMS; + if (!zero_init) { + hmx_load_tile_pair_fp16(accum_tile, eye_tile); + } + + for (int k = 0; k < n_dot_tiles; ++k) { + int offset = k * HMX_FP16_TILE_N_ELMS; + hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset); + } + + hmx_consume_accumulator_fp16(accum_tile); + } + } +} + +static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, + int k_block, int k_stride) { + for (int r = 0; r < n_rows; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + const bool next_row_valid = (r + 1) < n_rows; + + const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); + const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + // compute output position + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +typedef struct { + __fp16 *dst; + const float *src; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int k_block; + int k_stride; +} activation_transfer_task_state_t; + +static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; + + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + // one chunk: one row + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + + __fp16 *dst = st->dst + chunk_idx * st->k_block; + const float *src = st->src + chunk_idx * st->k_stride; + transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride); + } +} + +void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) { + assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); + assert(VLEN == 32 * sizeof(float)); + + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = 32; // must be multiple of 32 to ensure correct destination address + + activation_transfer_task_state_t state; + state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.src = src; + state.k_block = k_block; + state.k_stride = k_stride; + + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); +} + +int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m, + int k, int n, int weight_type) { + // Runtime check -- k >= 16384 exceeds 2D DMA limit + if (k >= 16384) { + FARF(HIGH, "%s: k=%d exceeds 2D DMA limit", __func__, k); + return -1; + } + // assume k % 32 == 0 && n % 32 == 0 + const size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + const size_t vtcm_budget = ctx->vtcm_scratch_size; + + const size_t M_BLOCK_SIZE = 512; + const size_t N_BLOCK_SIZE = 512; + const size_t K_BLOCK_SIZE = 512; + + // Compute precise buffer sizes + const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); + const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE); + const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); + + const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; + if (total_vtcm > vtcm_budget) { + FARF(HIGH, "%s: VTCM too small: need %zu have %zu (m=%d k=%d n=%d)", __func__, total_vtcm, vtcm_budget, m, k, n); + return -1; + } + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size); + uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz); + uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz); + __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); + + FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu", + __func__, m, k, n, weight_type, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + // initialize eye tile (32x32 identity matrix) + { + HVX_Vector v; + v = Q6_V_vzero(); + v = Q6_Vw_vinsert_VwR(v, 0x3c000000); + v = Q6_V_vror_VR(v, VLEN - 4); + v = Q6_Vw_vinsert_VwR(v, 0x00003c00); + for (int i = 0; i < 16; ++i) { + ((HVX_Vector *) vtcm_eye_tile)[i] = v; + v = Q6_V_vror_VR(v, VLEN - 8); + } + } + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + + TIMER_DEFINE(fetch); + TIMER_DEFINE(act_load); + TIMER_DEFINE(wt_dequant); + TIMER_DEFINE(core); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) { + size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE); + for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) { + size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE); + + const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS); + const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); + + for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { + size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); + + TIMER_START(fetch); + // fetch activation block into VTCM + { + const float *activation_block = x + mr * k + kk; + + dma_queue_push_chained(ctx->dma[0], + dma_make_ptr(vtcm_scratch1, activation_block), + k_blk_sz * sizeof(float), + k * sizeof(float), + k_blk_sz * sizeof(float), + m_blk_sz); + } + + // fetch weight block into VTCM (x4x2 sub-block: quants + scales) + { + qweight_fetch_task_state_t s; + + const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); + const int blk_start = kk / QK_Q4_0x4x2; + const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; + const int full_qrow = is_q4 ? (k / 2) : k; + const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); + + s.dst = vtcm_scratch0; + s.src = w + nc * row_stride; + s.n_rows = n_blk_sz; + s.src_stride = row_stride; + s.dst_stride = sub_row_stride; + s.quant_off = is_q4 ? (blk_start * (QK_Q4_0x4x2 / 2)) : (blk_start * QK_Q8_0x4x2); + s.quant_width = is_q4 ? (nb_sub * (QK_Q4_0x4x2 / 2)) : (nb_sub * QK_Q8_0x4x2); + s.scale_off = full_qrow + blk_start * HMX_X4X2_DBLK_SIZE; + s.scale_width = nb_sub * HMX_X4X2_DBLK_SIZE; + + // 2D DMA: quants sub-range + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off), + s.dst_stride, s.src_stride, s.quant_width, s.n_rows); + // 2D DMA: scales sub-range + dma_queue_push_chained(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off), + s.dst_stride, s.src_stride, s.scale_width, s.n_rows); + } + TIMER_STOP(fetch); + + TIMER_START(act_load); + // load activation block + { + dma_queue_pop(ctx->dma[0]); // wait for act DNA + transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz); + } + TIMER_STOP(act_load); + + TIMER_START(wt_dequant); + // dequantize weight block + { + dma_queue_pop(ctx->dma[0]); + dma_queue_pop(ctx->dma[0]); + // vtcm_scratch0 is used to store the qweight chunk + // worker_pool_run_func already returned, so fetch is done + const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, + n_blk_sz, k_blk_sz, sub_row_stride, weight_type); + } + TIMER_STOP(wt_dequant); + + // core mma + TIMER_START(core); + { + core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles, + n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0); + } + TIMER_STOP(core); + } + + // store output block + { + float *output_block = out + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n); + } + } + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us", + TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core)); +#endif + return 0; +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h new file mode 100644 index 00000000000..b36c8d129ba --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -0,0 +1,72 @@ +// HMX operation entry-point declarations. +// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib) + +#ifndef HMX_OPS_H +#define HMX_OPS_H + +#include +#include + +#ifndef restrict +# define restrict __restrict +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +struct htp_context; // forward declaration + +typedef struct { + float *dst; + const float *activation; + const __fp16 *permuted_weight; + int m; + int k; + int n; + int act_stride; + int weight_stride; + int dst_stride; + int ne02; + int ne03; + int ne12; + int ne13; + size_t src0_nb2; + size_t src0_nb3; + size_t src1_nb2; + size_t src1_nb3; + size_t dst_nb2; + size_t dst_nb3; +} hmx_matmul_w16a32_batched_params_t; + +// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output +// act_stride: activation row stride in elements (= k for contiguous, or +// nb[1]/sizeof(float) for permuted tensors like attention Q). +// weight_stride: weight row stride in elements (= k for compact weights, or +// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK). +int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const __fp16 *permuted_weight, + int m, int k, int n, + int act_stride, + int weight_stride); + +// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32. +// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. +int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, + const hmx_matmul_w16a32_batched_params_t *params); + +// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL) +int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int weight_type); + +#ifdef __cplusplus +} +#endif + +#endif // HMX_OPS_H diff --git a/ggml/src/ggml-hexagon/htp/hmx-profile.h b/ggml/src/ggml-hexagon/htp/hmx-profile.h new file mode 100644 index 00000000000..01eece720c5 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-profile.h @@ -0,0 +1,34 @@ +// Conditional fine-grained profiling macros for HMX operations. +// +// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this +// header) to instrument sub-operation latencies with HAP qtimer. When the +// macro is not defined the TIMER_* helpers expand to nothing so there is zero +// overhead. +// +// Usage: +// TIMER_DEFINE(my_phase); // declare accumulator variable +// TIMER_START(my_phase); // snapshot start time +// ... work ... +// TIMER_STOP(my_phase); // accumulate elapsed ticks +// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase)); + +#ifndef HMX_PROFILE_H +#define HMX_PROFILE_H + +#include + +// #define ENABLE_PROFILE_TIMERS + +#if defined(ENABLE_PROFILE_TIMERS) +# define TIMER_DEFINE(name) int64_t name##_ticks = 0 +# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count() +# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0 +# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks) +#else +# define TIMER_DEFINE(name) +# define TIMER_START(name) +# define TIMER_STOP(name) +# define TIMER_US(name) 0LL +#endif + +#endif // HMX_PROFILE_H diff --git a/ggml/src/ggml-hexagon/htp/hmx-utils.h b/ggml/src/ggml-hexagon/htp/hmx-utils.h new file mode 100644 index 00000000000..aacfbcda287 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-utils.h @@ -0,0 +1,88 @@ +// HMX tile-level inline helpers (FP16 32x32 tile operations). +// Ported from htp-ops-lib/include/dsp/hmx_utils.h. (https://github.com/haozixu/htp-ops-lib) + +#ifndef HMX_UTILS_H +#define HMX_UTILS_H + +#include +#include + +#define HMX_FP16_TILE_N_ROWS 32 +#define HMX_FP16_TILE_N_COLS 32 +#define HMX_FP16_TILE_N_ELMS 1024 +#define HMX_FP16_TILE_SIZE 2048 + +#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline)) + +static HMX_INLINE_ALWAYS void hmx_set_output_scales(const void *scales) { + asm volatile("bias = mxmem2(%0)" :: "r"(scales)); +} + +// Initialise aligned 256-byte area with scale vector + zero padding. +static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) { + HVX_Vector *pv = (HVX_Vector *)out_scales; + *pv++ = v_scale; + *pv = Q6_V_vzero(); +} + +// Load multiple contiguous tiles with :deep streaming. +// Rt = total region size - 1; the hardware streams through [Rs, Rs + Rt]. +// IMPORTANT: the tile region [Rs, Rs + Rt] must NOT cross a VTCM 4 MB bank +// boundary, otherwise the mxmem instruction will raise a precise bus error. +// Callers must ensure their VTCM layout satisfies this constraint. +static HMX_INLINE_ALWAYS void hmx_load_tiles_fp16(const __fp16 *row_tiles, + const __fp16 *col_tiles, + size_t n_tiles) { + size_t limit = n_tiles * HMX_FP16_TILE_SIZE - 1; + asm volatile( + "{ activation.hf = mxmem(%0, %1):deep\n" + "weight.hf = mxmem(%2, %3) }\n" + :: "r"(row_tiles), "r"(limit), "r"(col_tiles), "r"(limit) + : "memory"); +} + +// Load a single activation+weight tile pair (no :deep streaming). +// Rt defines the accessible region [Rs, Rs+Rt]. Following the reference formula +// (limit = n_tiles * HMX_FP16_TILE_SIZE - 1), for a single tile Rt = 2047. +// The original code used Rt=0x7FFF (32 KB region); when dynamic VTCM allocation +// places a tile near a 4 MB bank boundary, the oversized region crosses it and +// triggers a precise bus error (0x2601). Rt=2047 confines accesses to exactly +// one 2048-byte tile while covering all 16 HVX vectors (offsets 0..2047). +static HMX_INLINE_ALWAYS void hmx_load_tile_pair_fp16(const __fp16 *act_tile, + const __fp16 *wt_tile) { + asm volatile( + "{ activation.hf = mxmem(%0, %1)\n" + "weight.hf = mxmem(%2, %3) }\n" + :: "r"(act_tile), "r"(2047), + "r"(wt_tile), "r"(2047) + : "memory"); +} + +static HMX_INLINE_ALWAYS void hmx_consume_accumulator_fp16(__fp16 *out) { + // Use the combined convert-and-store instruction (matches the reference + // Q6_mxmem_AR_after_hf intrinsic). The previous two-instruction sequence + // "cvt.hf = acc(2); mxmem = cvt" used an undocumented Rs=2 parameter. + asm volatile( + "mxmem(%0, %1):after.hf = acc\n" + :: "r"(out), "r"(0) + : "memory"); +} + +// Compute inner product of two vectors of tiles and store result. +static HMX_INLINE_ALWAYS void hmx_dot_fp16(__fp16 *out, + const __fp16 *row_tiles, + const __fp16 *col_tiles, + size_t n_tiles) { + hmx_load_tiles_fp16(row_tiles, col_tiles, n_tiles); + hmx_consume_accumulator_fp16(out); +} + +// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) --- + +static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) { + uint8_t *p = *vtcm_ptr; + *vtcm_ptr += size; + return p; +} + +#endif // HMX_UTILS_H diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index a707d98239c..a92acfa0a85 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -30,6 +30,12 @@ struct htp_context { atomic_bool vtcm_needs_release; uint32_t opmask; + + // HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX) +#ifdef HTP_HAS_HMX + int hmx_enabled; // Runtime flag: HMX initialisation succeeded + size_t vtcm_scratch_size; // Usable dynamic scratch (vtcm_size minus tail reservation) +#endif }; #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index 56bc5b622c5..391148be0e9 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -32,13 +32,14 @@ enum htp_status { // Duplicated here because we can't include full ggml.h in the htp build. // We have some static_asserts in the cpp code to ensure things are in sync. enum htp_data_type { - HTP_TYPE_F32 = 0, - HTP_TYPE_F16 = 1, - HTP_TYPE_Q4_0 = 2, - HTP_TYPE_Q8_0 = 8, - HTP_TYPE_I32 = 26, - HTP_TYPE_I64 = 27, - HTP_TYPE_MXFP4 = 39, + HTP_TYPE_F32 = 0, + HTP_TYPE_F16 = 1, + HTP_TYPE_Q4_0 = 2, + HTP_TYPE_Q8_0 = 8, + HTP_TYPE_IQ4_NL = 20, + HTP_TYPE_I32 = 26, + HTP_TYPE_I64 = 27, + HTP_TYPE_MXFP4 = 39, HTP_TYPE_COUNT }; @@ -87,6 +88,8 @@ static inline size_t htp_t_block_size(uint32_t t) { return QK4_0; case HTP_TYPE_Q8_0: return QK8_0; + case HTP_TYPE_IQ4_NL: + return QK4_NL; case HTP_TYPE_MXFP4: return QK_MXFP4; default: @@ -105,6 +108,8 @@ static inline size_t htp_type_nbytes(uint32_t t) { return sizeof(block_q4_0); case HTP_TYPE_Q8_0: return sizeof(block_q8_0); + case HTP_TYPE_IQ4_NL: + return sizeof(block_iq4_nl); case HTP_TYPE_MXFP4: return sizeof(block_mxfp4); default: diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index 9ebd937e46d..2dc716cb441 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -7,7 +7,7 @@ #include "remote.idl" interface htp_iface : remote_handle64 { - AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx); + AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx); AEEResult stop(); AEEResult enable_etm(); AEEResult disable_etm(); diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index 3e6a8579b1f..db05ab40d28 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -9,6 +9,9 @@ #include "hex-utils.h" #include "hvx-types.h" +#define hvx_vmem(A) *((HVX_Vector *)(A)) +#define hvx_vmemu(A) *((HVX_UVector *)(A)) + static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) { // Rotate as needed. v = Q6_V_vlalign_VVR(v, v, (size_t) dst); @@ -112,11 +115,15 @@ static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) { return Q6_Q_and_QQ(p_exp, p_frac); } -static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) { - const HVX_Vector zero = Q6_V_vsplat_R(0); +static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) { + const HVX_Vector zero = Q6_V_vzero(); HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero); HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero); - HVX_Vector v = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0))); + return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)); +} + +static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) { + HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1)); #if __HVX_ARCH__ < 79 // replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0) @@ -128,6 +135,30 @@ static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) { return v; } +#if __HVX_ARCH__ >= 79 +static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(v, one); + return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p)); +} +static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one); + return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p)); +} +#else +static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(v, one); + return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p))); +} +static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one); + return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p))); +} +#endif + /* Q6_Vsf_equals_Vw is only available on v73+.*/ #if __HVX_ARCH__ < 73 static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in) diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 2a3f9e562b7..ef9cba8ecc1 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -25,6 +25,10 @@ #include "htp-ops.h" #include "worker-pool.h" +#ifdef HTP_HAS_HMX +#include "hmx-ops.h" +#endif // HTP_HAS_HMX + AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { struct htp_context * ctx; int err = 0; @@ -163,6 +167,9 @@ static int vtcm_acquire(struct htp_context * ctx) { } ctx->vtcm_inuse = true; + + + return 0; } @@ -246,7 +253,7 @@ static void vtcm_free(struct htp_context * ctx) { static void htp_packet_callback(dspqueue_t queue, int error, void * context); static void htp_error_callback(dspqueue_t queue, int error, void * context); -AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) { +AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { @@ -280,6 +287,21 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que return AEE_ENOMEMORY; } +#ifdef HTP_HAS_HMX + if (use_hmx) { + ctx->vtcm_scratch_size = ctx->vtcm_size; + ctx->hmx_enabled = 1; + + FARF(HIGH, "HMX enabled: vtcm-scratch %zu", ctx->vtcm_scratch_size); + } else { + // HMX disabled: skip HMX initialisation so the + // dispatch loop falls through to the HVX compute paths. + ctx->hmx_enabled = 0; + ctx->vtcm_scratch_size = ctx->vtcm_size; + FARF(HIGH, "HMX disabled (use_hmx=0): vtcm-scratch %zu", ctx->vtcm_scratch_size); + } +#endif + qurt_sysenv_max_hthreads_t hw_threads; qurt_sysenv_get_max_hw_threads(&hw_threads); uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF; @@ -340,6 +362,12 @@ AEEResult htp_iface_stop(remote_handle64 handle) { for (int i = 0; i < ctx->n_threads; i++) { dma_queue_delete(ctx->dma[i]); } +#ifdef HTP_HAS_HMX + if (ctx->hmx_enabled) { + ctx->hmx_enabled = 0; + } +#endif + vtcm_free(ctx); @@ -375,8 +403,9 @@ static int send_htp_rsp(struct htp_context * c, struct dspqueue_buffer * bufs, size_t n_bufs, struct profile_data * prof) { - // Prep response struct + // Prep response struct (zero-init to clear cmp/unused union) struct htp_general_rsp rsp; + memset(&rsp, 0, sizeof(rsp)); rsp.op = op; rsp.status = status; rsp.prof_usecs = prof->usecs; @@ -1037,6 +1066,210 @@ static void proc_flash_attn_ext_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof); } +#ifdef HTP_HAS_HMX +// --------------------------------------------------------------------------- +// HMX operation wrappers — self-contained, bypass htp_ops_context / htp_spad. +// VTCM, DMA and thread dispatch are managed inside the HMX kernels. +// --------------------------------------------------------------------------- + +static void proc_hmx_matmul_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + size_t n_bufs) { + // HMX weight tile requires N to be 32-aligned. + if (req->src0.ne[1] % 32 != 0) { + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + + const bool is_batched = (req->src0.ne[2] * req->src0.ne[3] > 1 || + req->src1.ne[2] * req->src1.ne[3] > 1); + + // Quantised HMX kernels only handle flat 2D matmul (host already rejects + // batched quantised, but guard here too). F16 batched matmul is handled + // by the dedicated wrapper in hmx-matmul-ops.c. + if (is_batched && + req->src0.type != HTP_TYPE_F16) { + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + + // HMX assumes contiguous row-major layout. Fall back for permuted + // tensors where strides are non-monotonic (e.g. transposed KV cache). + if (req->src0.nb[0] > req->src0.nb[1] || + req->src1.nb[0] > req->src1.nb[1]) { + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + + // M alignment: when M > 32 but not 32-aligned, we split into + // HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows). + // When M <= 32 and not 32-aligned, fall back entirely to HVX. + const int m_total = (int) req->src1.ne[1]; + const int m_tail = m_total % 32; + const int m_hmx = m_total - m_tail; + + if (m_hmx == 0) { + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + + // HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights. + // Other types (e.g. MXFP4) fall back to HVX. + { + uint32_t wtype = req->src0.type; + if (wtype != HTP_TYPE_F16 && + wtype != HTP_TYPE_Q4_0 && + wtype != HTP_TYPE_Q8_0 && + wtype != HTP_TYPE_IQ4_NL) { + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + // Quantised HMX path requires K aligned to 256 (x4x2 super-block). + // F16 HMX path requires K aligned to 32 (tile width). + if (wtype != HTP_TYPE_F16 && req->src0.ne[0] % 256 != 0) { + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + if (wtype == HTP_TYPE_F16 && req->src0.ne[0] % 32 != 0) { + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + } + + (void) n_bufs; + + struct dspqueue_buffer rsp_bufs[1]; + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); + + // src0 = weights, src1 = activation, dst = output + void * wgt = (void *) bufs[0].ptr; + float * act = (float *) bufs[1].ptr; + float * dst = (float *) bufs[2].ptr; + + int k = (int) req->src0.ne[0]; // inner dimension + int n = (int) req->src0.ne[1]; // weight columns + + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + + // --- Phase 1: HMX on the first m_hmx (32-aligned) rows --- + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + int ret = -1; + + const int ne02 = (int) req->src0.ne[2]; + const int ne03 = (int) req->src0.ne[3]; + const int ne12 = (int) req->src1.ne[2]; + const int ne13 = (int) req->src1.ne[3]; + // Row strides in elements. For compact tensors these equal k; for + // permuted attention views they can be larger, so pass the real stride. + const int act_stride = (int)(req->src1.nb[1] / sizeof(float)); + const int weight_stride = (int)(req->src0.nb[1] / sizeof(__fp16)); + + switch (req->src0.type) { + case HTP_TYPE_F16: + if (is_batched) { + hmx_matmul_w16a32_batched_params_t batch_params = { + .dst = dst, + .activation = act, + .permuted_weight = (const __fp16 *) wgt, + .m = m_hmx, + .k = k, + .n = n, + .act_stride = act_stride, + .weight_stride = weight_stride, + .dst_stride = (int)(req->dst.nb[1] / sizeof(float)), + .ne02 = ne02, + .ne03 = ne03, + .ne12 = ne12, + .ne13 = ne13, + .src0_nb2 = req->src0.nb[2], + .src0_nb3 = req->src0.nb[3], + .src1_nb2 = req->src1.nb[2], + .src1_nb3 = req->src1.nb[3], + .dst_nb2 = req->dst.nb[2], + .dst_nb3 = req->dst.nb[3], + }; + ret = hmx_mat_mul_permuted_w16a32_batched(ctx, &batch_params); + } else { + ret = hmx_mat_mul_permuted_w16a32(ctx, dst, act, + (const __fp16 *) wgt, + m_hmx, k, n, + act_stride, + weight_stride); + } + break; + default: + ret = hmx_mat_mul_permuted_qk_0_d16a32(ctx, dst, act, + (const uint8_t *) wgt, + m_hmx, k, n, (int) req->src0.type); + break; + } + + if (ret == 0) { + rsp_status = HTP_STATUS_OK; + } else { + FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret); + vtcm_release(ctx); + req->flags &= ~HTP_OPFLAGS_SKIP_QUANTIZE; + proc_matmul_req(ctx, req, bufs, n_bufs); + return; + } + vtcm_release(ctx); + } + + // --- Phase 2: HVX on the remaining m_tail rows --- + if (m_tail > 0 && rsp_status == HTP_STATUS_OK) { + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; // weights: unchanged + octx.src1 = req->src1; + octx.src1.ne[1] = m_tail; // only tail rows + octx.dst = req->dst; + octx.dst.ne[1] = m_tail; // only tail rows + // Always re-quantize tail src1: HMX Phase 1 overwrites VTCM, + // so any previously cached quantized data (SKIP_QUANTIZE pipeline) + // is invalid. + octx.flags = req->flags & ~HTP_OPFLAGS_SKIP_QUANTIZE; + octx.op = req->op; + octx.n_threads = ctx->n_threads; + + // Offset activation and dst pointers past the HMX-processed rows. + // Use nb[1] (row stride in bytes) to compute the byte offset. + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t)((uint8_t *) bufs[1].ptr + (size_t) m_hmx * req->src1.nb[1]); + octx.dst.data = (uint32_t)((uint8_t *) bufs[2].ptr + (size_t) m_hmx * req->dst.nb[1]); + + FARF(HIGH, "proc_hmx_matmul: HVX tail m_tail=%d act=%p dst=%p", + m_tail, (void *)(uintptr_t) octx.src1.data, (void *)(uintptr_t) octx.dst.data); + + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + uint32_t hvx_ret = op_matmul(&octx); + vtcm_release(ctx); + if (hvx_ret != HTP_STATUS_OK) { + FARF(ERROR, "HVX tail matmul failed (ret=%u)", hvx_ret); + rsp_status = HTP_STATUS_INTERNAL_ERR; + } + } else { + rsp_status = HTP_STATUS_INTERNAL_ERR; + } + } + + profile_stop(&prof); + + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +#endif // HTP_HAS_HMX + static void htp_packet_callback(dspqueue_t queue, int error, void * context) { struct htp_context * ctx = (struct htp_context *) context; @@ -1089,7 +1322,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { FARF(ERROR, "Bad matmul-req buffer list"); continue; } - proc_matmul_req(ctx, &req, bufs, n_bufs); +#ifdef HTP_HAS_HMX + if (ctx->hmx_enabled) { + proc_hmx_matmul_req(ctx, &req, bufs, n_bufs); + } else +#endif + { + proc_matmul_req(ctx, &req, bufs, n_bufs); + } break; case HTP_OP_MUL_MAT_ID: From e1cdce46c5e795932ad9bc1470c38a31cf1bd05c Mon Sep 17 00:00:00 2001 From: Rail Chabdarov Date: Thu, 19 Mar 2026 19:14:08 +0100 Subject: [PATCH 325/831] hip: Avoid compiler bug in RDNA code generation during debug builds on Windows (llama/20655) --- ggml/src/ggml-hip/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index c2357722629..f96c6e09a9b 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -129,6 +129,11 @@ endif() if (CXX_IS_HIPCC) set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) + if (WIN32 AND CMAKE_BUILD_TYPE STREQUAL "Debug") + # CMake on Windows doesn't support the HIP language yet. + # Therefore we workaround debug build's failure on HIP backend this way. + set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES COMPILE_FLAGS "-O2 -g") + endif() target_link_libraries(ggml-hip PRIVATE hip::device) else() set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP) From 65d820a44a6c95b88ae121918202fea9b4ba0d10 Mon Sep 17 00:00:00 2001 From: Sundaram krishnan <104441812+sundaram123krishnan@users.noreply.github.com> Date: Fri, 20 Mar 2026 01:06:23 +0530 Subject: [PATCH 326/831] ggml: guard KleidiAI DOWNLOAD_EXTRACT_TIMESTAMP for cmake < 3.24 (llama/20767) --- ggml/src/ggml-cpu/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 7c062a62995..1a1bbc9f2be 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -572,9 +572,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name) set(KLEIDIAI_FETCH_ARGS URL ${KLEIDIAI_DOWNLOAD_URL} - DOWNLOAD_EXTRACT_TIMESTAMP NEW URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5} ) + if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") + list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW) + endif() if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.28") FetchContent_Declare(KleidiAI_Download From 46dcb35aa38f10eb5e1eb6f7c2de071e928a60bc Mon Sep 17 00:00:00 2001 From: hipudding Date: Fri, 20 Mar 2026 17:08:39 +0800 Subject: [PATCH 327/831] CANN: add BF16 support for core operators (llama/20152) * CANN: add BF16 support for core operators Add BF16 (bfloat16) type support to the CANN backend for the following operators: MUL_MAT, MUL_MAT_ID, GET_ROWS, SET_ROWS, CPY, CONT, and OUT_PROD. This enables BF16 models to run on Ascend NPUs. * CANN: skip NZ weight format for BF16 and add 310P compile guards NZ weight format conversion does not support BF16 tensors, skip it in set_tensor, get_alloc_size and mul_mat. Remove BF16 from MUL_MAT_ID and OUT_PROD as there are no BF16 use cases. Add #ifndef ASCEND_310P guards for all BF16 operator support since 310P does not support BF16. --- ggml/src/ggml-cann/aclnn_ops.cpp | 12 +++++++++--- ggml/src/ggml-cann/ggml-cann.cpp | 29 +++++++++++++++++++++++++---- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 9b736636def..b45774dde34 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1788,9 +1788,11 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; // src ggml_tensor * src1 = dst->src[1]; // index - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 + || dst->type == GGML_TYPE_BF16); switch (src0->type) { + case GGML_TYPE_BF16: case GGML_TYPE_F16: case GGML_TYPE_F32: if (src0->type == dst->type) { @@ -1881,6 +1883,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { break; } case GGML_TYPE_F16: + case GGML_TYPE_BF16: { acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t)); @@ -1891,7 +1894,7 @@ void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; } acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); + src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type)); aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1, dst->type); @@ -1965,7 +1968,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor * // Only check env once. static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); - if (weight_to_nz && is_matmul_weight(weight)) { + if (weight_to_nz && weight->type != GGML_TYPE_BF16 && is_matmul_weight(weight)) { acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ); } else { acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND); @@ -2146,6 +2149,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst) { switch (type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif ggml_cann_mat_mul_fp(ctx, dst); break; case GGML_TYPE_Q4_0: diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index a682746bb42..2f9c350789c 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1234,7 +1234,8 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer, static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); if (!need_transform(tensor->type)) { ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); - if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { + if (weight_to_nz && tensor->type != GGML_TYPE_BF16 + && is_matmul_weight((const ggml_tensor *) tensor)) { GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[3] == 1); weight_format_to_nz(tensor, offset, ctx->device); @@ -1443,7 +1444,8 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t if (ne0 % MATRIX_ROW_PADDING != 0) { size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); } - } else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { + } else if (weight_to_nz && tensor->type != GGML_TYPE_BF16 + && is_matmul_weight((const ggml_tensor *) tensor)) { // NZ format weight are not support quantized yet. // If ND tensor transform to NZ, size may changed. int64_t shape[] = { tensor->ne[1], tensor->ne[0] }; @@ -2283,6 +2285,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_MUL_MAT: { switch (op->src[0]->type) { +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif case GGML_TYPE_F16: case GGML_TYPE_F32: return true; @@ -2320,6 +2325,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten switch (op->src[0]->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif case GGML_TYPE_Q8_0: return true; default: @@ -2332,6 +2340,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten switch (op->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif return true; default: return false; @@ -2341,20 +2352,30 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_CPY: { ggml_tensor * src = op->src[0]; +#ifdef ASCEND_310P if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) || (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) { - // only support F32 and F16. + // only support F32 and F16 on 310P. return false; } +#else + if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_BF16) || + (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16 && src->type != GGML_TYPE_BF16)) { + // only support F32, F16 and BF16. + return false; + } +#endif return true; } break; case GGML_OP_CONT: { - // TODO: support GGML_TYPE_BF16 switch (op->src[0]->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif return true; default: return false; From 49b505bcc5c76b30584a26fcc2d1d6751bcc986c Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 20 Mar 2026 06:17:15 -0500 Subject: [PATCH 328/831] vulkan: change gated_delta_net to shard a column across a subgroup (llama/20662) * vulkan: change gated_delta_net to shard a column across a subgroup This is based on https://github.com/ggml-org/llama.cpp/pull/20391, I used an LLM to port the CUDA code to Vulkan, and guided to it to make various fixes to work with Vulkan (e.g. handling different subgroup sizes, unknown mapping of subgroup to invocation id, using subgroupAdd optionally, etc.). This fixes a perf regression from the transposing of the values in memory (!20443). * vulkan: Spread columns across fewer lanes to reduce the number of workgroups --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 38 +++- .../vulkan-shaders/gated_delta_net.comp | 165 +++++++++++------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 +- 3 files changed, 140 insertions(+), 67 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3e36435d166..566958b3a9d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4604,12 +4604,42 @@ static void ggml_vk_load_shaders(vk_device& device) { {"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"}, {"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"}, }; + const bool use_subgroup_reduce = device->subgroup_arithmetic; for (uint32_t si = 0; si < 3; si++) { + const uint32_t S_V = gdn_sizes[si]; + GGML_ASSERT(is_pow2(S_V)); + + uint32_t lanes_per_column; + if (S_V >= 128u && device->subgroup_clustered) { + lanes_per_column = 8u; + } else { + // Use largest power-of-two that divides both S_V and subgroup_size so that + // (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0. + // This means we don't need extra bounds checking logic in the shader. + lanes_per_column = std::min(S_V, device->subgroup_size); + } + + const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size); + size_t gdn_len; + const void * gdn_data; + if (use_subgroup_reduce && need_clustered_shader) { + gdn_len = gated_delta_net_f32_len; + gdn_data = (const void *)gated_delta_net_f32_data; + } else if (use_subgroup_reduce) { + gdn_len = gated_delta_net_f32_nocluster_len; + gdn_data = (const void *)gated_delta_net_f32_nocluster_data; + } else { + gdn_len = gated_delta_net_f32_shmem_len; + gdn_data = (const void *)gated_delta_net_f32_shmem_data; + } + + const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column; + const std::array wg_denoms = {1u, 1u, cols_per_wg}; + for (uint32_t kda = 0; kda < 2; kda++) { ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda], - gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data, - "main", 7, sizeof(vk_op_gated_delta_net_push_constants), - {1, 1, 1}, {gdn_sizes[si], kda}, 1); + gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), + wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size); } } } @@ -10438,7 +10468,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, - pc, { H, n_seqs, 1u }); + pc, { H, n_seqs, S_v }); } static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index f008859b99d..5e9f8308c1d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -1,11 +1,25 @@ #version 450 #extension GL_EXT_control_flow_attributes : require - +#extension GL_KHR_shader_subgroup_basic : enable +#if USE_SUBGROUP_CLUSTERED +#extension GL_KHR_shader_subgroup_clustered : enable +#endif +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif + +// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0, +// so no bounds checking is needed. layout(constant_id = 0) const uint S_V = 128; layout(constant_id = 1) const uint KDA = 0; +layout(constant_id = 2) const uint SUBGROUP_SIZE = 32; +layout(constant_id = 3) const uint LANES_PER_COLUMN = 32; + +const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN; +const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN; -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in; layout(push_constant) uniform Parameters { uint H; @@ -27,14 +41,61 @@ layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; }; layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; }; layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; }; -shared FLOAT_TYPE s_k[S_V]; -shared FLOAT_TYPE s_q[S_V]; -shared FLOAT_TYPE s_g[S_V]; // KDA only: cached exp(g[i]) +#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED +shared FLOAT_TYPE temp[SUBGROUP_SIZE]; + +// This does a reduction across groups of LANES_PER_COLUMN +FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) { + const uint lane = gl_SubgroupInvocationID; + temp[lane] = partial; + barrier(); + [[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) { + FLOAT_TYPE other = temp[lane ^ s]; + barrier(); + temp[lane] += other; + barrier(); + } + const FLOAT_TYPE result = temp[lane]; + barrier(); + return result; +} +#endif + +// clusterSize for subgroupClusteredAdd must be a compile-time constant; branch on spec constant +FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) { + switch (LANES_PER_COLUMN) { + case 1u: + return partial; +#if USE_SUBGROUP_CLUSTERED + // Workaround for GLSL requiring a literal constant for the cluster size. + // The branches should all fold away. + case 2u: + return subgroupClusteredAdd(partial, 2u); + case 4u: + return subgroupClusteredAdd(partial, 4u); + case 8u: + return subgroupClusteredAdd(partial, 8u); + case 16u: + return subgroupClusteredAdd(partial, 16u); + case 32u: + return subgroupClusteredAdd(partial, 32u); + case 64u: + return subgroupClusteredAdd(partial, 64u); +#endif + default: +#if USE_SUBGROUP_ADD + return subgroupAdd(partial); +#else + return reduce_add_shmem(partial); +#endif + } +} void main() { const uint head_id = gl_WorkGroupID.x; - const uint seq_id = gl_WorkGroupID.y; - const uint col = gl_LocalInvocationID.x; + const uint seq_id = gl_WorkGroupID.y; + const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN; + const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN); const uint iq1 = head_id % neq1; const uint iq3 = seq_id / rq3; @@ -42,9 +103,9 @@ void main() { const uint state_size = S_V * S_V; const uint state_base = (seq_id * H + head_id) * state_size; - FLOAT_TYPE state[S_V]; - [[unroll]] for (uint i = 0; i < S_V; i++) { - state[i] = FLOAT_TYPE(data_state[state_base + col * S_V + i]); + FLOAT_TYPE s_shard[ROWS_PER_LANE]; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]); } uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; @@ -53,76 +114,56 @@ void main() { const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1; const uint k_off = q_off; const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1; - - s_q[col] = FLOAT_TYPE(data_q[q_off + col]); - s_k[col] = FLOAT_TYPE(data_k[k_off + col]); - const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1; + const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]); - if (KDA != 0) { - const uint g_base = gb_off * S_V; - s_g[col] = exp(FLOAT_TYPE(data_g[g_base + col])); + FLOAT_TYPE k_reg[ROWS_PER_LANE]; + FLOAT_TYPE q_reg[ROWS_PER_LANE]; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + k_reg[r] = FLOAT_TYPE(data_k[k_off + i]); + q_reg[r] = FLOAT_TYPE(data_q[q_off + i]); } - barrier(); - - const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]); - const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]); - + FLOAT_TYPE g_exp[ROWS_PER_LANE]; if (KDA == 0) { const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off])); - - FLOAT_TYPE kv_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - kv_col += dot( - vec4(state[i], state[i+1], state[i+2], state[i+3]), - vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]) - ); - } - - FLOAT_TYPE delta_col = (v_val - g_val * kv_col) * beta_val; - - FLOAT_TYPE attn_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); - vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); - sv = g_val * sv + kv * delta_col; - state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w; - - attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3])); + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + g_exp[r] = g_val; } - - data_dst[attn_off + col] = attn_col * scale; } else { - FLOAT_TYPE kv_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]); - vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); - vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); - kv_col += dot(gv * sv, kv); + const uint g_base = gb_off * S_V; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i])); } + } + + const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]); - FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val; + FLOAT_TYPE kv_shard = 0.0; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + kv_shard += g_exp[r] * s_shard[r] * k_reg[r]; + } + FLOAT_TYPE kv_col = reduce_partial(kv_shard); - FLOAT_TYPE attn_col = 0.0; - [[unroll]] for (uint i = 0; i < S_V; i += 4) { - vec4 gv = vec4(s_g[i], s_g[i+1], s_g[i+2], s_g[i+3]); - vec4 sv = vec4(state[i], state[i+1], state[i+2], state[i+3]); - vec4 kv = vec4(s_k[i], s_k[i+1], s_k[i+2], s_k[i+3]); - sv = gv * sv + kv * delta_col; - state[i] = sv.x; state[i+1] = sv.y; state[i+2] = sv.z; state[i+3] = sv.w; + FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val; - attn_col += dot(sv, vec4(s_q[i], s_q[i+1], s_q[i+2], s_q[i+3])); - } + FLOAT_TYPE attn_partial = 0.0; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col; + attn_partial += s_shard[r] * q_reg[r]; + } + FLOAT_TYPE attn_col = reduce_partial(attn_partial); + if (lane == 0) { data_dst[attn_off + col] = attn_col * scale; } attn_off += S_V * H; - barrier(); } - [[unroll]] for (uint i = 0; i < S_V; i++) { - data_dst[s_off + state_base + col * S_V + i] = state[i]; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index abd2a9c36fa..8186dba36f6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -987,7 +987,9 @@ void process_shaders() { string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); - string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}})); + string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "1"}})); + string_to_spv("gated_delta_net_f32_nocluster", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "0"}})); + string_to_spv("gated_delta_net_f32_shmem", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "0"}, {"USE_SUBGROUP_CLUSTERED", "0"}})); string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); From ca5d565dcdb677bad03f719b7c85bf5939f18e39 Mon Sep 17 00:00:00 2001 From: shalinib-ibm Date: Sat, 21 Mar 2026 04:41:45 +0530 Subject: [PATCH 329/831] ggml-cpu: add always_inline to tinyBLAS_PPC accumulator saves (llama/20791) Explicitly mark save_acc and add_save_Acc with always_inline in tinyBLAS_PPC. This ensures the compiler keeps MMA accumulator disassembly within kernel's register context, preventing un-necessary stask spills. Signed-off-by: Shalini Salomi Bodapati --- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index c89e5076f26..63ceb635dea 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -3194,6 +3194,7 @@ class tinyBLAS_PPC { private: + __attribute__((always_inline)) inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) { vec_t vec_C[4]; __builtin_mma_disassemble_acc(vec_C, ACC); @@ -3204,6 +3205,7 @@ class tinyBLAS_PPC { } } + __attribute__((always_inline)) inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) { vec_t vec_C[4]; __builtin_mma_disassemble_acc(vec_C, ACC); From 22710fdb82e744a521b69214187d9889e244b404 Mon Sep 17 00:00:00 2001 From: Matt Corallo <649246+TheBlueMatt@users.noreply.github.com> Date: Sat, 21 Mar 2026 04:22:51 +0000 Subject: [PATCH 330/831] Add shader count for Intel Arc Pro B60 (llama/20818) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 566958b3a9d..221e6fa04e9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -16048,6 +16048,7 @@ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) case 0xE20C: // B570 return 18; case 0xE20B: // B580 + case 0xE211: // Pro B60 return 20; default: return 0; From 5f3428219a79f5c24604d9d34a3a4a0cbbc1e212 Mon Sep 17 00:00:00 2001 From: y198 <90976397+y198nt@users.noreply.github.com> Date: Sat, 21 Mar 2026 20:59:43 +0700 Subject: [PATCH 331/831] fix(rpc): prevent division by zero in deserialize_tensor (llama/20712) rpc : prevent division by zero in deserialize_tensor When receiving an RPC message with a deprecated tensor type (e.g., type 4 or 5 where `blck_size == 0`), `ggml_row_size()` will trigger a division by zero (SIGFPE) and crash the rpc-server. This patch adds a simple validation check in `deserialize_tensor` to return `nullptr` if the requested tensor type has a block size of 0. (Note: This was originally reported via Security Advisory and maintainer suggested dropping a patch here). * style: remove trailing whitespace --- ggml/src/ggml-rpc/ggml-rpc.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index d7c8ad8c168..5d8defad209 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1162,12 +1162,18 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp return nullptr; } + // Fix: Prevent division by zero if blck_size is 0 (e.g., deprecated types) + if (ggml_blck_size((enum ggml_type)tensor->type) == 0) { + GGML_LOG_ERROR("[%s] invalid tensor type received (blck_size is 0): %u\n", __func__, tensor->type); + return nullptr; + } + ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type if (result == nullptr) { - GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type); + GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\n", __func__, tensor->type); return nullptr; } From 77b635e9c4f0e3a8fe0252f2197f61b37a62a22c Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sun, 22 Mar 2026 14:19:35 +0530 Subject: [PATCH 332/831] Increase number of output elements per-thread block if the K-dimension is small (llama/20635) * Increase per-thread work if the K-dimension is small With tensor parallelism, the K-dimension of the FFN-down matrices is split, which makes it quite small, especially for MOEs. For example, Qwen3-30b-A3B has a K-dimension of 768, and Qwen3235B-A22B has k-dimension of 1536. The current heuristic uses a group of 4 warps irrespective of K-dimension size, resulting in some of the threads being idle. This results in poor performance for these matrices. This change increases the number of output elements per block for such cases. * Limit this change to ncols_dst = 1 * tab to space --- ggml/src/ggml-cuda/mmvq.cu | 56 +++++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 632246e43fd..024b3d8cf22 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -33,7 +33,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) } } -static constexpr __device__ int get_vdr_mmvq(ggml_type type) { +static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; @@ -173,11 +173,11 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d return 1; } -static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { +static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) { if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { switch (ncols_dst) { case 1: - return 1; + return small_k ? nwarps : 1; case 2: case 3: case 4: @@ -193,7 +193,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -template +template __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, @@ -208,7 +208,7 @@ static __global__ void mul_mat_vec_q( constexpr int vdr = get_vdr_mmvq(type); constexpr mmvq_parameter_table_id table_id = get_device_table_id(); constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id); - constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); @@ -414,14 +414,16 @@ static __global__ void mul_mat_vec_q( template static std::pair calc_launch_params( const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, - const int warp_size, const mmvq_parameter_table_id table_id) { - const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); + const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) { + const int nwarps = calc_nwarps(type, ncols_dst, table_id); + const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); + const int64_t nblocks = (nrows_x + rpb - 1) / rpb; const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); - const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1); + const dim3 block_dims(warp_size, nwarps, 1); return {block_nums, block_dims}; } -template +template static void mul_mat_vec_q_switch_fusion( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, @@ -434,7 +436,7 @@ static void mul_mat_vec_q_switch_fusion( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -444,7 +446,7 @@ static void mul_mat_vec_q_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -488,11 +490,33 @@ static void mul_mat_vec_q_switch_ncols_dst( switch (ncols_dst) { case 1: { constexpr int c_ncols_dst = 1; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + + // When K is small, increase rows_per_block to match nwarps so each warp has more work to do + // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_iter_1warp = vdr * warp_size / qi; + const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); + const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + if (use_small_k) { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, + warp_size, table_id, true); + mul_mat_vec_q_switch_fusion( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } else { + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, + warp_size, table_id); + mul_mat_vec_q_switch_fusion( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, + dims.first, dims.second, 0, ids_stride, stream); + } } break; case 2: { constexpr int c_ncols_dst = 2; From 69f0d907ee609091eaa3a552ccfd63c9390e55d6 Mon Sep 17 00:00:00 2001 From: Patrick Buckley Date: Sun, 22 Mar 2026 03:05:51 -0700 Subject: [PATCH 333/831] ggml-cuda: native bf16 flash attention for vec kernel (llama/20525) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-cuda: native bf16 flash attention for vec and tile kernels mma kernel still converts bf16 to fp16 before launch, native mma bf16 todo * ggml-cuda: address code owner review feedback reverted tile kernel changes to avoid larger refactor * fix ci failures on turing and hip * fix bf16 vec kernel compile on hip v_dot2 platforms * add comments --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/CMakeLists.txt | 11 ++--- ggml/src/ggml-cuda/convert.cuh | 6 +++ ggml/src/ggml-cuda/fattn-common.cuh | 48 +++++++++++++++++++ ggml/src/ggml-cuda/fattn-vec.cuh | 26 +++++++--- ggml/src/ggml-cuda/fattn.cu | 16 +++++++ .../fattn-vec-instance-bf16-bf16.cu | 7 +++ .../fattn-vec-instance-bf16-f16.cu | 7 +++ .../fattn-vec-instance-bf16-q4_0.cu | 7 +++ .../fattn-vec-instance-bf16-q4_1.cu | 7 +++ .../fattn-vec-instance-bf16-q5_0.cu | 7 +++ .../fattn-vec-instance-bf16-q5_1.cu | 7 +++ .../fattn-vec-instance-bf16-q8_0.cu | 7 +++ .../fattn-vec-instance-f16-bf16.cu | 7 +++ .../fattn-vec-instance-q4_0-bf16.cu | 7 +++ .../fattn-vec-instance-q4_1-bf16.cu | 7 +++ .../fattn-vec-instance-q5_0-bf16.cu | 7 +++ .../fattn-vec-instance-q5_1-bf16.cu | 7 +++ .../fattn-vec-instance-q8_0-bf16.cu | 7 +++ .../template-instances/generate_cu_files.py | 2 +- ggml/src/ggml-hip/CMakeLists.txt | 11 ++--- ggml/src/ggml-musa/CMakeLists.txt | 11 ++--- 21 files changed, 197 insertions(+), 25 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 262f88204e0..419862101d1 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -116,12 +116,11 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_SOURCES_CUDA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) + list(APPEND GGML_SOURCES_CUDA + template-instances/fattn-vec-instance-f16-f16.cu + template-instances/fattn-vec-instance-q4_0-q4_0.cu + template-instances/fattn-vec-instance-q8_0-q8_0.cu + template-instances/fattn-vec-instance-bf16-bf16.cu) endif() ggml_add_backend_library(ggml-cuda diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 09f9a33f909..b8caeacf094 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -41,6 +41,12 @@ template return __bfloat162float(x); } else if constexpr(std::is_same_v && std::is_same_v) { return __float22half2_rn(x); + } else if constexpr(std::is_same_v && std::is_same_v) { +#if !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __bfloat1622float2(x); +#else + return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x))); +#endif // !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 } else if constexpr(std::is_same_v && std::is_same_v) { // bypass compile error on cuda 12.0.1 #ifdef GGML_USE_HIP diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index e9abdf288c4..c59a4db3999 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -74,6 +74,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( return sum; } +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { + + const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { + __align__(16) nv_bfloat162 tmp[cpy_ne]; + ggml_cuda_memcpy_1(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { +#ifdef V_DOT2_F32_F16_AVAILABLE + // FIXME replace macros in vector FA kernel with templating and use FP32 for BF16 + ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1])); +#else + ggml_cuda_mad(sum, ggml_cuda_cast(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#endif // V_DOT2_F32_F16_AVAILABLE + } + } + + return sum; +} + template static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -321,6 +352,19 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_ } } +template +static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + static_assert(std::is_same_v, "BF16 V dequantization only supports float output"); + static_assert(ne % 2 == 0, "bad ne"); + __align__(16) nv_bfloat162 tmp[ne/2]; + ggml_cuda_memcpy_1(tmp, (const nv_bfloat16 *) vx + i0); + float2 * dst_f2 = (float2 *) dst; +#pragma unroll + for (int l = 0; l < ne/2; ++l) { + dst_f2[l] = ggml_cuda_cast(tmp[l]); + } +} + template static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -547,6 +591,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { return vec_dot_fattn_vec_KQ_q5_1; } else if constexpr (type_K == GGML_TYPE_Q8_0) { return vec_dot_fattn_vec_KQ_q8_0; + } else if constexpr (type_K == GGML_TYPE_BF16) { + return vec_dot_fattn_vec_KQ_bf16; } else { static_assert(type_K == -1, "bad type"); return nullptr; @@ -567,6 +613,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() { return dequantize_V_q5_1; } else if constexpr (type_V == GGML_TYPE_Q8_0) { return dequantize_V_q8_0; + } else if constexpr (type_V == GGML_TYPE_BF16) { + return dequantize_V_bf16; } else { static_assert(type_V == -1, "bad type"); return nullptr; diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 7cbe32633e5..f0bd42a5761 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -75,17 +75,17 @@ static __global__ void flash_attn_ext_vec( #endif // GGML_USE_HIP constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); - constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; - constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; + constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q; + constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q; static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); - constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; + constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4; constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ(); - constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; + constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16; #ifdef V_DOT2_F32_F16_AVAILABLE constexpr dequantize_V_t dequantize_V = get_dequantize_V(); #else @@ -323,8 +323,18 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { half2 tmp[V_rows_per_thread/2]; - dequantize_V(V + k*nb21, tmp, - 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + if constexpr (type_V == GGML_TYPE_BF16) { + float2 tmp_f[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp_f, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { + tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]); + } + } else { + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + } #pragma unroll for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { #pragma unroll @@ -563,6 +573,7 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) @@ -570,6 +581,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) @@ -577,6 +589,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) @@ -584,3 +597,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 85c177f496f..a25a890db6d 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -224,6 +224,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -231,6 +232,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) @@ -238,6 +240,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) @@ -245,6 +248,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) @@ -252,6 +256,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) @@ -259,10 +264,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #else FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #endif // GGML_CUDA_FA_ALL_QUANTS GGML_ABORT("fatal error"); @@ -355,6 +370,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const #endif // GGML_CUDA_FA_ALL_QUANTS case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_BF16: break; default: return BEST_FATTN_KERNEL_NONE; diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu new file mode 100644 index 00000000000..3a2fa99b05b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu new file mode 100644 index 00000000000..60f0f6f7952 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu new file mode 100644 index 00000000000..489e05f08c3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu new file mode 100644 index 00000000000..6fa3c26d309 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu new file mode 100644 index 00000000000..421027fb29d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu new file mode 100644 index 00000000000..abbc9434802 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu new file mode 100644 index 00000000000..d641f859d81 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu new file mode 100644 index 00000000000..d1071dc2438 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu new file mode 100644 index 00000000000..8afda314238 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu new file mode 100644 index 00000000000..506864ac18d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu new file mode 100644 index 00000000000..0bbda8371e6 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu new file mode 100644 index 00000000000..79be24daf9e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu new file mode 100644 index 00000000000..45636e5e70c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index e382df1ae20..3b5ab12fc40 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -5,7 +5,7 @@ HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576] -TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] +TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index f96c6e09a9b..291b4837455 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -71,12 +71,11 @@ if (GGML_CUDA_FA_ALL_QUANTS) list(APPEND GGML_SOURCES_ROCM ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) + list(APPEND GGML_SOURCES_ROCM + ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu) endif() ggml_add_backend_library(ggml-hip diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index d76cb51977f..cc53c812ce5 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -48,12 +48,11 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_SOURCES_MUSA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) + list(APPEND GGML_SOURCES_MUSA + ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu) endif() set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX) From 1d0f0285de5575194a9c42450a1c5293cf433b51 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Sun, 22 Mar 2026 22:06:27 +0800 Subject: [PATCH 334/831] support bf16 and quantized type (llama/20803) --- ggml/src/ggml-sycl/ggml-sycl.cpp | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 2ec1421841b..456b1699fa3 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4667,22 +4667,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g if (a->ne[3] != b->ne[3]) { return false; } - ggml_type a_type = a->type; - if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS || - a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S || - a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S || - a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M - ) { - if (b->ne[1] == 1 && ggml_nrows(b) > 1) { - return false; - } - } + ggml_type src0_type = op->src[0]->type; - if (src0_type == GGML_TYPE_BF16 ) { - // TODO: support GGML_TYPE_BF16 - // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added - return false; - } // TODO: The configuration below needs more work to be supported with oneDNN if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && From 607c92430f5ba00db92f18e1d4097de2212b9d6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 22 Mar 2026 17:53:33 +0100 Subject: [PATCH 335/831] CUDA: fix BF16 FA compilation (llama/20865) --- ggml/src/ggml-cuda/convert.cuh | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index b8caeacf094..f5d37c7b998 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -42,11 +42,15 @@ template } else if constexpr(std::is_same_v && std::is_same_v) { return __float22half2_rn(x); } else if constexpr(std::is_same_v && std::is_same_v) { -#if !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#ifdef GGML_USE_HIP + return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x))); +#else +#if __CUDA_ARCH__ >= 800 return __bfloat1622float2(x); #else - return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x))); -#endif // !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return make_float2(__bfloat162float(x.x), __bfloat162float(x.y)); +#endif // __CUDA_ARCH__ >= 800 +#endif // GGML_USE_HIP } else if constexpr(std::is_same_v && std::is_same_v) { // bypass compile error on cuda 12.0.1 #ifdef GGML_USE_HIP From c976b22d7bf197ab8a727a99c42d58472ba144b0 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Sun, 22 Mar 2026 22:45:11 -0700 Subject: [PATCH 336/831] opencl: add flattened Q4_K mv and general Q4_K mm (llama/20773) --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 289 ++++++++++++++++++ ggml/src/ggml-opencl/kernels/cvt.cl | 67 ++++ .../kernels/mul_mm_q4_k_f32_l4_lm.cl | 179 +++++++++++ .../kernels/mul_mv_q4_k_f32_flat.cl | 196 ++++++++++++ 5 files changed, 733 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 1f8250934b0..ae667b12d17 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -89,6 +89,7 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_1_f32 mul_mv_q4_1_f32_flat mul_mv_q4_k_f32 + mul_mv_q4_k_f32_flat mul_mv_q6_k_f32 mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 @@ -107,6 +108,7 @@ set(GGML_OPENCL_KERNELS mul_mm_q4_0_f32_l4_lm mul_mm_q4_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm + mul_mm_q4_k_f32_l4_lm mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 gemv_noshuffle_q4_1_f32 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index e1dca6b4b4d..c984e59b6b4 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -534,11 +534,13 @@ struct ggml_backend_opencl_context { cl_kernel kernel_restore_block_q4_0_noshuffle; cl_kernel kernel_convert_block_q4_1_noshuffle; cl_kernel kernel_restore_block_q4_1_noshuffle; + cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q4_1_f32; cl_kernel kernel_mul_mv_q4_1_f32_flat; cl_kernel kernel_mul_mv_q4_K_f32; + cl_kernel kernel_mul_mv_q4_K_f32_flat; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_mul_mv_q6_K_f32_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; @@ -578,6 +580,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_q4_0_f32_l4_lm; cl_kernel kernel_mul_mm_q4_1_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q4_k_f32_l4_lm; cl_kernel kernel_mul_mm_q6_k_f32_l4_lm; std::vector profiling_info; @@ -917,6 +920,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); GGML_LOG_CONT("."); @@ -1209,6 +1214,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_q4_k_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_k_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_k_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1482,6 +1504,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q4_k_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q4_k_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q4_k_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q4_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_k_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_q6_k_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3347,6 +3386,40 @@ struct ggml_tensor_extra_cl_q8_0 { } }; +struct ggml_tensor_extra_cl_q4_K { + // Quantized values + cl_mem q = nullptr; + // Scales for each super block. + cl_mem s = nullptr; + // Scales + cl_mem d = nullptr; + // Min + cl_mem dm = nullptr; + + ~ggml_tensor_extra_cl_q4_K() { + reset(); + } + + void reset() { + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (s != nullptr) { + CL_CHECK(clReleaseMemObject(s)); + s = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (dm != nullptr) { + CL_CHECK(clReleaseMemObject(dm)); + dm = nullptr; + } + } +}; + struct ggml_tensor_extra_cl_q6_K { // Lower 4 bits of quantized weights. cl_mem ql = nullptr; @@ -3956,6 +4029,12 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) { delete e; } + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K) { + delete e; + } + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K_in_use) { + delete e; + } for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K) { delete e; } @@ -4039,6 +4118,21 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_q4_K * ggml_opencl_alloc_temp_tensor_extra_q4_K() { + ggml_tensor_extra_cl_q4_K * extra; + if (temp_tensor_extras_q4_K.empty()) { + extra = new ggml_tensor_extra_cl_q4_K(); + } else { + extra = temp_tensor_extras_q4_K.back(); + temp_tensor_extras_q4_K.pop_back(); + } + + temp_tensor_extras_q4_K_in_use.push_back(extra); + + extra->reset(); + return extra; + } + ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() { ggml_tensor_extra_cl_q6_K * extra; if (temp_tensor_extras_q6_K.empty()) { @@ -4080,6 +4174,11 @@ struct ggml_backend_opencl_buffer_context { } temp_tensor_extras_q8_0_in_use.clear(); + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K_in_use) { + temp_tensor_extras_q4_K.push_back(e); + } + temp_tensor_extras_q4_K_in_use.clear(); + for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) { temp_tensor_extras_q6_K.push_back(e); } @@ -4101,6 +4200,8 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_mxfp4_in_use; std::vector temp_tensor_extras_q8_0; std::vector temp_tensor_extras_q8_0_in_use; + std::vector temp_tensor_extras_q4_K; + std::vector temp_tensor_extras_q4_K_in_use; std::vector temp_tensor_extras_q6_K; std::vector temp_tensor_extras_q6_K_in_use; @@ -4835,6 +4936,83 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } + if (tensor->type == GGML_TYPE_Q4_K) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q4_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_K(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(3 * ggml_blck_size(tensor->type) / 64); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_d + size_dm + size_s + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // Create subbuffer for d. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for mins. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_dm; + extra->dm = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for s. + region.origin = align_to(previous_origin + size_dm, backend_ctx->alignment); + region.size = size_s; + extra->s = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_s, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + return; + } if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -5245,6 +5423,34 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + if (tensor->type == GGML_TYPE_Q4_K) { + ggml_tensor_extra_cl_q4_K * extra = (ggml_tensor_extra_cl_q4_K *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; @@ -9357,6 +9563,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; #endif @@ -10005,6 +10212,50 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q4_K: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q4_k_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } case GGML_TYPE_Q6_K: { if (ne11 < 32) { break; @@ -10449,6 +10700,43 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: { +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_q4_K_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 2; + ndst = 16; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); +#else kernel = backend_ctx->kernel_mul_mv_q4_K_f32; if (backend_ctx->gpu_family == INTEL) { @@ -10482,6 +10770,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q break; } case GGML_TYPE_Q5_K: diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 78ef9c177f6..272d0ea23f0 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -28,6 +28,7 @@ #define QK8_0 32 #define QR8_0 1 #define QK_K 256 +#define K_SCALE_SIZE (3 * QK_K / 64) #define K_QUANTS_PER_ITERATION 2 typedef char int8_t; @@ -55,6 +56,16 @@ struct block_q4_1 { uchar qs[QK4_1 / 2]; // nibbles / quants }; +//------------------------------------------------------------------------------ +// block_q4_k +//------------------------------------------------------------------------------ +struct block_q4_K { + half d; // delta + half dm; // min + uchar s[K_SCALE_SIZE]; + uchar q[QK_K / 2]; // nibbles / quants +}; + //------------------------------------------------------------------------------ // block_q6_K //------------------------------------------------------------------------------ @@ -408,6 +419,62 @@ kernel void kernel_restore_block_q8_0_trans( } } +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_K +// Convert the block_q4_K format to 4 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +// Each thread processes a super block. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_K( + global struct block_q4_K * src0, + global uchar * dst_q, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm +) { + global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K/2; ++i) { + q[i] = b->q[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +// Restore block_q4_K from flattened arrays. +// Each thread processes a super block. +kernel void kernel_restore_block_q4_K( + global uchar * src_q, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q4_K * dst +) { + global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K/2; ++i) { + b->q[i] = q[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + //------------------------------------------------------------------------------ // kernel_convert_block_q6_K // Convert the block_q6_K format to 3 separate arrays (AOS -> SOA). diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl new file mode 100644 index 00000000000..2235b1ae838 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl @@ -0,0 +1,179 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 4 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q4_k_f32_l4_lm( + global uchar4 * src0_q, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 64; + int iqs = (idx % 64) * 2; + + int n = iqs / 32; + int b = (iqs % 32) / 16; + int is = 2 * n + b; + int qsi = n * 32 + (iqs % 16) * 2; + + char * scales = src0_s + ib * 12; + + int scidx0 = (is < 4) ? is : (is + 4); + int scidx1 = (is < 4) ? is : (is - 4); + int scidxmask1 = (is < 4) ? 0x30 : 0xC0; + int scidxshift1 = (is < 4) ? 0 : 2; + int mbidx0 = is + 4; + int mbidx1 = (is < 4) ? is + 4 : is; + int mbidxmask0 = (is < 4) ? 0xF : 0xF0; + int mbidxshift0 = (is < 4) ? 0 : 4; + int mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + int mbidxshift1 = (is < 4) ? 0 : 2; + + uchar sc = (scales[scidx0] & 0xF) | ((scales[scidx1] & scidxmask1) >> scidxshift1); + uchar mbyte = ((scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((scales[mbidx1] & mbidxmask1) >> mbidxshift1); + + float d = (float)src0_d[ib] * (float)sc; + float m = -(float)src0_dm[ib] * (float)mbyte; + + global uchar4 * qs = src0_q + ib*32 + (qsi >> 2); + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)((q.s0 >> (b * 4))&0x0F, (q.s1 >> (b * 4))&0x0F, (q.s2 >> (b * 4))&0x0F, (q.s3 >> (b * 4))&0x0F)))*d + m; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v1.s3; + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl new file mode 100644 index 00000000000..d92fb968904 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl @@ -0,0 +1,196 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// block_q4_K +//------------------------------------------------------------------------------ +#define QK_K 256 +#define BLOCK_Q4K_SIZE 144 +#define K_SCALE_SIZE 12 + +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qs[QK_K/2]; // 4-bit quants +} block_q4_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // number of rows each SIMD group works on +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#undef BLOCK_STRIDE +// number of (super) blocks each subgroup processes +// each thread in a subgroup processes a block (32 weights) +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_K_f32_flat( + global uchar * src0_q, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; + int it = get_sub_group_local_id()%8; + int iq = it/4; + int ir = it%4; + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = (first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03)/BLOCK_Q4K_SIZE; + uint blk = nb01 / BLOCK_Q4K_SIZE; + global uchar * blk_q = (global uchar *)src0_q + offset_src0*(QK_K/2); + global uchar * blk_s = (global uchar *)src0_s + offset_src0*K_SCALE_SIZE; + global half * blk_d = (global half *)src0_d + offset_src0; + global half * blk_dm = (global half *)src0_dm + offset_src0; + + int offset_src1 = r1*nb11 + (i12)*nb12 + (i13)*nb13; + global float * y = (global float *)(src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * q1 = (global ushort *)(blk_q + ib * (QK_K/2)) + (16 * iq + 4 * ir); + global ushort * sc = (global ushort *)(blk_s + ib * K_SCALE_SIZE) + iq; + global half * d = blk_d + ib; + global half * dm = blk_dm + ib; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F); + acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00); + acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0); + acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000); + acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F); + acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00); + acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0); + acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = *d; + float dmin = *dm; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += blk*64; + sc += blk*6; + d += blk; + dm += blk; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} From a0e41ec26111e76755d2dd8e912e0c88f4582135 Mon Sep 17 00:00:00 2001 From: Dan Hoffman <43101339+thedanhoffman@users.noreply.github.com> Date: Sun, 22 Mar 2026 23:05:37 -0700 Subject: [PATCH 337/831] fix(openvino): explicit memset in buffer_context allocation (llama/20857) * fix(openvino): explicit memset in buffer_context allocation * minor --------- Co-authored-by: Dan Hoffman Co-authored-by: Georgi Gerganov --- ggml/src/ggml-openvino/ggml-openvino.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 0031cb7369f..b3058b4af73 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -97,6 +97,8 @@ struct ggml_backend_openvino_buffer_context { ov_buffer = std::make_shared(std::move(usm_tensor)); } else { data = ggml_aligned_malloc(size); + GGML_ASSERT(data); + memset(data, 0, size); ov_buffer = std::make_shared(ov::element::u8, ov::Shape{size}, data); } From 54f5c02f29a6d5f20f94d4707dded5fcc0b2fdb0 Mon Sep 17 00:00:00 2001 From: Chenguang Li <757486878@qq.com> Date: Mon, 23 Mar 2026 15:24:06 +0800 Subject: [PATCH 338/831] CANN: add RoPE cache preload before ACL graph capture (llama/20747) ACL graph capture disallows host-to-device memcpy and device memory malloc/free on the captured stream. Pre-load the RoPE cache before capture so that: - Host-to-device copies and allocations run on the non-captured stream - Cache metadata is populated and memory pool is warmed up - During capture, only on-device computations are recorded; host-side and allocation branches are skipped --- ggml/src/ggml-cann/aclnn_ops.cpp | 52 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cann/aclnn_ops.h | 15 +++++++++ ggml/src/ggml-cann/common.h | 2 +- ggml/src/ggml-cann/ggml-cann.cpp | 13 ++++++++ 4 files changed, 81 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index b45774dde34..adb4d68e868 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -3011,6 +3011,58 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { } } +void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + GGML_TENSOR_UNARY_OP_LOCALS + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int) * 4); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; + const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_imrope || mrope_used) { + is_neox = true; + } + + int64_t rope_dims = n_dims; + if (is_vision) { + rope_dims = src0->ne[0]; + } + + // Run the full cache init on the non-captured stream. This performs all + // host-to-device memcpy, aclrtMalloc/Free, and on-device computations + // so that the memory pool is warmed up and cache metadata is populated. + aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, + mrope_used, is_imrope, is_vision, rope_dims); + + // Reset `cached` so that during graph capture the on-device computations + // (sin/cos, position multiply, repeat, etc.) still execute and get recorded + // into the captured graph. The cache metadata (theta_scale_length, + // theta_scale, sections, position_length, etc.) remains set, which causes + // all host-to-device copy and malloc/free branches to be skipped. + ctx.rope_cache.cached = false; +} + void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 3effa1c289c..7f5ba4d3302 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -543,6 +543,21 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst); */ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst); +/** + * @brief Pre-load the RoPE cache before ACL graph capture. + * + * This function must be called outside of graph capture to perform + * host-to-device memory copies and device memory allocations that are + * not allowed on a captured stream. After pre-loading, the rope cache + * metadata is updated so that the subsequent call to + * aclnn_rope_cache_init (inside graph capture) skips these operations + * and only records the on-device computations into the captured graph. + * + * @param ctx CANN backend context. + * @param dst A ROPE destination tensor from the computation graph. + */ +void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Computes the index of the maximum value along the specified dimension * of a ggml tensor using the CANN backend. diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 0120f0dfd1e..5f960548cd2 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -277,7 +277,7 @@ struct ggml_graph_node_properties { } } - if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) { + if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU || node->op == GGML_OP_ROPE){ return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; } return true; diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 2f9c350789c..6f26e91e046 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2225,6 +2225,19 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, // If no matching graph is found, add a new ACL graph. ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph); cann_ctx->graph_lru_cache.push(new_graph); + + // Pre-load rope cache before graph capture. During capture the + // stream cannot perform host-to-device memcpy or device memory + // malloc/free. Running the full cache init now populates the + // cache metadata so these branches are skipped during capture, + // while also warming up the memory pool. + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->op == GGML_OP_ROPE) { + ggml_cann_rope_cache_preload(*cann_ctx, node); + break; + } + } } } #else From c589dd77d4fe9df5ea6f1072d7913e870ed4da10 Mon Sep 17 00:00:00 2001 From: Rashid Ul Islam <33536561+Ra5hidIslam@users.noreply.github.com> Date: Mon, 23 Mar 2026 13:15:34 +0530 Subject: [PATCH 339/831] metal: add CONV_3D (llama/19927) * Apply suggestions from code review Co-authored-by: Georgi Gerganov * metal:add conv_3d backend Rebased with master and resolved conflicts. * Resolved issues related to changes in variable names * kernel void kernel_upscale_bilinear_f32 was missing in my branch, added back, should pass all tests now --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-metal/ggml-metal-device.cpp | 22 ++++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 5 ++ ggml/src/ggml-metal/ggml-metal-impl.h | 36 +++++++++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 75 ++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 92 +++++++++++++++++++++++ 7 files changed, 232 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 72ad876d5e4..9162342ee98 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1748,6 +1748,28 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_met return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_CONV_3D); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_conv_3d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_UPSCALE); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index fd2b3ddeb55..de43f819312 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -148,6 +148,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 82101f4714e..14144aab087 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1077,6 +1077,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_CONV_3D: + return ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && + op->src[1]->type == GGML_TYPE_F32; case GGML_OP_SUM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_TRI: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 53437b23cda..ea471090cd8 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -643,6 +643,42 @@ typedef struct { int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources } ggml_metal_kargs_im2col; +typedef struct { + int32_t IW; + int32_t IH; + int32_t ID; + int32_t OW; + int32_t OH; + int32_t OD; + int32_t KW; + int32_t KH; + int32_t KD; + int32_t s0; + int32_t s1; + int32_t s2; + int32_t p0; + int32_t p1; + int32_t p2; + int32_t d0; + int32_t d1; + int32_t d2; + int32_t IC; + int32_t N; + int32_t OC; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_conv_3d; + typedef struct{ int32_t ne00; uint64_t nb01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index c0bcad392b9..3cda21be43e 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -394,6 +394,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx); } break; + case GGML_OP_CONV_3D: + { + n_fuse = ggml_metal_op_conv_3d(ctx, idx); + } break; case GGML_OP_UPSCALE: { n_fuse = ggml_metal_op_upscale(ctx, idx); @@ -3697,6 +3701,77 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_conv_3d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + // 1. Extract standard dimensions and byte strides + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + // 2. Extract hyperparams from op_params + const int32_t s0 = ((const int32_t *)(op->op_params))[0]; + const int32_t s1 = ((const int32_t *)(op->op_params))[1]; + const int32_t s2 = ((const int32_t *)(op->op_params))[2]; + const int32_t p0 = ((const int32_t *)(op->op_params))[3]; + const int32_t p1 = ((const int32_t *)(op->op_params))[4]; + const int32_t p2 = ((const int32_t *)(op->op_params))[5]; + const int32_t d0 = ((const int32_t *)(op->op_params))[6]; + const int32_t d1 = ((const int32_t *)(op->op_params))[7]; + const int32_t d2 = ((const int32_t *)(op->op_params))[8]; + const int32_t IC = ((const int32_t *)(op->op_params))[9]; + const int32_t N = ((const int32_t *)(op->op_params))[10]; + const int32_t OC = ((const int32_t *)(op->op_params))[11]; + + // 3. Build the parameter struct using the macro-generated variables + ggml_metal_kargs_conv_3d args = { + /*.IW =*/ (int32_t)op->src[1]->ne[0], + /*.IH =*/ (int32_t)op->src[1]->ne[1], + /*.ID =*/ (int32_t)op->src[1]->ne[2], + /*.OW =*/ (int32_t)op->ne[0], + /*.OH =*/ (int32_t)op->ne[1], + /*.OD =*/ (int32_t)op->ne[2], + /*.KW =*/ (int32_t)op->src[0]->ne[0], + /*.KH =*/ (int32_t)op->src[0]->ne[1], + /*.KD =*/ (int32_t)op->src[0]->ne[2], + s0, s1, s2, + p0, p1, p2, + d0, d1, d2, + IC, N, OC, + nb00, nb01, nb02, nb03, // Weight strides + nb10, nb11, nb12, nb13, // Input strides + nb0, nb1, nb2, nb3 // Output strides + }; + + // 4. Fetch the JIT pipeline + auto pipeline = ggml_metal_library_get_pipeline_conv_3d(lib, op); + + // 5. Grid mapping + int nth0 = 32; // Standard SIMD width for Apple Silicon + int nth1 = 1; + int nth2 = 1; + + int64_t spatial_volume = args.OW * args.OH * args.OD; + + int ntg0 = (spatial_volume + nth0 - 1) / nth0; + int ntg1 = args.OC; + int ntg2 = args.N; + + // 6. Bind and Dispatch via the ggml C wrapper + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ntg0, ntg1, ntg2, nth0, nth1, nth2); + + return 1; +} + int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 019f2fec9ed..50e3c5c77a1 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -75,6 +75,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_2d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_conv_3d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index b2328605dd9..9c6b1c4f62b 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4883,6 +4883,98 @@ kernel void kernel_upscale_bilinear_f32( } } +template +kernel void kernel_conv_3d( + constant ggml_metal_kargs_conv_3d & args, + device const char * src0, // Weights [IC * OC, KD, KH, KW] + device const char * src1, // Inputs [IC * N, ID, IH, IW] + device char * dst, // Outputs [OC * N, OD, OH, OW] + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + + // 1. Un-flatten the spatial dimension from Grid X + int64_t spatial_idx = tgpig.x * 32 + tpitg.x; + + if (spatial_idx >= args.OW * args.OH * args.OD) { + return; // Thread falls outside the spatial volume + } + + int64_t od = spatial_idx / (args.OW * args.OH); + int64_t oh = (spatial_idx / args.OW) % args.OH; + int64_t ow = spatial_idx % args.OW; + + // 2. Map Y to Channels, Z to Batch + int64_t oc = tgpig.y; + int64_t batch_idx = tgpig.z; + + // 3. Calculate anchor coordinates in the Input volume + int64_t i_w_base = ow * args.s0 - args.p0; + int64_t i_h_base = oh * args.s1 - args.p1; + int64_t i_d_base = od * args.s2 - args.p2; + + float sum = 0.0f; + + // 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width) + for (int64_t ic = 0; ic < args.IC; ++ic) { + + // ggml packs batch and channel together in the 4th dimension + int64_t src_cn_idx = batch_idx * args.IC + ic; + int64_t w_cn_idx = oc * args.IC + ic; + + for (int64_t kz = 0; kz < args.KD; ++kz) { + int64_t id = i_d_base + kz * args.d2; + if (id < 0 || id >= args.ID) continue; // Boundary check (Padding) + + for (int64_t ky = 0; ky < args.KH; ++ky) { + int64_t ih = i_h_base + ky * args.d1; + if (ih < 0 || ih >= args.IH) continue; + + for (int64_t kx = 0; kx < args.KW; ++kx) { + int64_t iw = i_w_base + kx * args.d0; + if (iw < 0 || iw >= args.IW) continue; + + // Convert multi-dimensional coordinates to flat byte offsets + int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03; + int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13; + + // Dereference memory and cast weights to f32 if they were f16 + float w_val = (float)*(device const T*)((device const char*)src0 + w_idx); + float i_val = *(device const float*)((device const char*)src1 + i_idx); + + sum += w_val * i_val; + } + } + } + } + + // 5. Write the accumulated value out to RAM + int64_t dst_cn_idx = batch_idx * args.OC + oc; + int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3; + + *(device float*)(dst + d_idx) = sum; +} + +// Explicit instantiations so the JIT compiler can find them by name +template [[host_name("kernel_conv_3d_f32_f32")]] +kernel void kernel_conv_3d( + constant ggml_metal_kargs_conv_3d & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +// Explicit instantiation for f16 weights +template [[host_name("kernel_conv_3d_f16_f32")]] +kernel void kernel_conv_3d( + constant ggml_metal_kargs_conv_3d & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + + static inline float bicubic_weight1(float x) { const float a = -0.75f; return ((a + 2) * x - (a + 3)) * x * x + 1; From 37c0a52c1bf4cd2bd32eca4cfce26fa667ae5736 Mon Sep 17 00:00:00 2001 From: las7 <98077186+las7@users.noreply.github.com> Date: Mon, 23 Mar 2026 10:54:57 -0700 Subject: [PATCH 340/831] rpc : RCE patch (llama/20908) --- ggml/src/ggml-rpc/ggml-rpc.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 5d8defad209..0ed2c0dce60 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1443,7 +1443,9 @@ ggml_tensor * rpc_server::create_node(uint64_t id, const rpc_tensor * tensor = it_ptr->second; struct ggml_tensor * result = deserialize_tensor(ctx, tensor); - if (result == nullptr) { + if (result == nullptr || result->buffer == nullptr) { + GGML_LOG_ERROR("[%s] invalid tensor: null %s (id=%" PRIu64 ")\n", + __func__, result == nullptr ? "tensor" : "buffer", id); return nullptr; } tensor_map[id] = result; From 624be93425126001fdcaf830be9e0b719705c4b9 Mon Sep 17 00:00:00 2001 From: lhez Date: Mon, 23 Mar 2026 12:44:18 -0700 Subject: [PATCH 341/831] opencl: add q6_K gemm and gemv kernels for Adreno (llama/20089) * opencl: add q6_K noshuffle kernels, initial q6_K gemv, some host code * opencl: add q6_K transpose * opencl: fix cvt kernel name * opencl: add call to q6_K gemv * opencl: fix q6_K scale transpose * opencl: fix loading for gemv q6_K, refactor * opencl: fix transpose_8_buf kernel assignment, refactor * opencl: refactor q6_K transpose * opencl: add gemm_noshuffle_q6_k_f32 * opencl: fix qh loading * opencl: refactor q6_K gemv host side, release bufs and imgs * opencl: refactor * opencl: fix q6_K dequant and scale selection * opencl: workaround compiler bug, fix dump_tensor * opencl: refactor q6_K convert kernels * opencl: unpack transformed q6_K in get_tensor * opencl: refactor, handle non-uniform workgroups * opencl: support non-vector subgroup bcast --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 397 ++++++++++++++++-- ggml/src/ggml-opencl/kernels/cvt.cl | 128 +++++- .../kernels/gemm_noshuffle_q6_k_f32.cl | 140 ++++++ .../kernels/gemv_noshuffle_q6_k_f32.cl | 293 +++++++++++++ 5 files changed, 920 insertions(+), 40 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index ae667b12d17..af29f3b8f4c 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -114,6 +114,8 @@ set(GGML_OPENCL_KERNELS gemv_noshuffle_q4_1_f32 gemm_noshuffle_q4_1_f32 gemv_noshuffle_general_q8_0_f32 + gemv_noshuffle_q6_k_f32 + gemm_noshuffle_q6_k_f32 mul neg norm diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index c984e59b6b4..4dddcd82cfa 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -529,6 +529,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; + cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_restore_block_q4_0_noshuffle; @@ -716,6 +717,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemm_noshuffle_q4_1_f32; cl_kernel kernel_mul_mm_q8_0_f32_8x4; cl_kernel CL_mul_mat_vec_q8_0_f32; + cl_kernel kernel_gemv_noshuffle_q6_K_f32; + cl_kernel kernel_gemm_noshuffle_q6_K_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS void free() { @@ -924,6 +927,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K_noshuffle", &err), err)); GGML_LOG_CONT("."); } @@ -2642,6 +2647,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err)); GGML_LOG_CONT("."); } + + // gemv_noshuffle_q6_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q6_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q6_k_f32.cl"); +#endif + + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // gemm_noshuffle_q6_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q6_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q6_k_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } #endif // GGML_OPENCL_USE_ADRENO_KERNELS GGML_LOG_CONT("\n"); } @@ -5029,61 +5073,58 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, "Incorrect tensor size"); cl_int err; - cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, - ggml_nbytes(tensor), NULL, &err); - CL_CHECK(err); - CL_CHECK(clEnqueueWriteBuffer( - queue, data_device, CL_TRUE, 0, - ggml_nbytes(tensor), data, 0, NULL, NULL)); + cl_mem data_device; + CL_CHECK((data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err), err)); + CL_CHECK(clEnqueueWriteBuffer(queue, data_device, CL_TRUE, 0, ggml_nbytes(tensor), data, 0, NULL, NULL)); cl_buffer_region region; // Subbuffer for ql region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); region.size = size_ql; - extra->ql = clCreateSubBuffer( - extra_orig->data_device, CL_MEM_READ_WRITE, - CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); - CL_CHECK(err); + CL_CHECK((extra->ql = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); auto previous_origin = region.origin; // Subbuffer for qh region.origin = align_to(previous_origin + size_ql, backend_ctx->alignment); region.size = size_qh; - extra->qh = clCreateSubBuffer( - extra_orig->data_device, CL_MEM_READ_WRITE, - CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); - CL_CHECK(err); + CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); previous_origin = region.origin; // Subbuffer for scales region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment); region.size = size_s; - extra->s = clCreateSubBuffer( - extra_orig->data_device, CL_MEM_READ_WRITE, - CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); - CL_CHECK(err); + CL_CHECK((extra->s = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); previous_origin = region.origin; // Create subbuffer for d. region.origin = align_to(previous_origin + size_s, backend_ctx->alignment); region.size = size_d; - extra->d = clCreateSubBuffer( - extra_orig->data_device, CL_MEM_READ_WRITE, - CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); - CL_CHECK(err); + CL_CHECK((extra->d = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); previous_origin = region.origin; // Flatten the weights - cl_kernel kernel = backend_ctx->kernel_convert_block_q6_K; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); + cl_kernel kernel; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + kernel = backend_ctx->kernel_convert_block_q6_K; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q6_K_noshuffle; + } +#else + kernel = backend_ctx->kernel_convert_block_q6_K; +#endif // GGML_OPENCL_USE_ADRENO_KERNELS - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + cl_uchar mask = 0xff; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64)*64, 1, 1}; size_t local_work_size[] = {64, 1, 1}; cl_event evt; @@ -5097,6 +5138,29 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, extra->size_d = size_d; tensor->extra = extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + cl_int M = tensor->ne[1]; // ne01 + cl_int K = tensor->ne[0]; // ne00 + + // Transpose ql as ushort + transpose_2d_as_16b(backend_ctx, + extra->ql, extra->ql, size_ql, K/4, M); + + // Transpose qh as uchar + transpose_2d_as_8b(backend_ctx, + extra->qh, extra->qh, size_qh, K/4, M); + + // Transpose s as ushort + transpose_2d_as_16b(backend_ctx, + extra->s, extra->s, size_s, K/16/2, M); + + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, + extra->d, extra->d, size_d, K/256, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } #endif // GGML_OPENCL_SOA_Q @@ -5454,19 +5518,78 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + static ggml_cl_buffer buf_trans_ql; + static ggml_cl_buffer buf_trans_qh; + static ggml_cl_buffer buf_trans_s; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; // ne01 + cl_int K = tensor->ne[0]; // ne00 + + GGML_ASSERT(K % ggml_blck_size(tensor->type) == 0); + + size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4; + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16; + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) && "Incorrect tensor size"); + + buf_trans_ql.allocate(backend_ctx->context, size_ql); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_s.allocate(backend_ctx->context, size_s); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + // transpose ql, qh, s and d back + transpose_2d_as_16b(backend_ctx, extra->ql, buf_trans_ql.buffer, size_ql, M, K/4); + transpose_2d_as_8b(backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->s, buf_trans_s.buffer, size_s, M, K/16/2); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256); + + // unpack + cl_uchar mask = 0xFF; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_ql.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_s.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err); CL_CHECK(err); + cl_uchar mask = 0xFF; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); - - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; size_t local_work_size[] = {1, 1, 1}; cl_event evt; @@ -5759,6 +5882,8 @@ typedef struct { static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); +#define QK_MXFP4 32 + #include #ifdef __cplusplus #include "half.hpp" @@ -5802,7 +5927,7 @@ static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tenso buf_d = malloc(size_e); CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL)); - CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, extra->e, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL)); CL_CHECK(clFinish(queue)); } else { // Read out the tensor from GPU memory. @@ -9537,6 +9662,196 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t #endif } +static void ggml_cl_mul_mat_q6_K_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_buffer_region region; + cl_image_format img_fmt; + cl_image_desc img_desc; + + // subbuffer and image for activation + if (ne1 == 1) { + cl_mem ql_img = nullptr; + cl_mem qh_img = nullptr; + cl_mem b_sub_buffer = nullptr; + cl_mem b_img = nullptr; + + // image for ql + img_fmt.image_channel_order = CL_R; + img_fmt.image_channel_data_type = CL_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne01 * ne00 / 8; + img_desc.buffer = extra0_q6_K->ql; + CL_CHECK((ql_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // image for qh + img_fmt.image_channel_order = CL_R; + img_fmt.image_channel_data_type = CL_HALF_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne01 * ne00 / 8; + img_desc.buffer = extra0_q6_K->qh; + CL_CHECK((qh_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + region.origin = offset1; + region.size = ne00 * ne1 * sizeof(float); + CL_CHECK((b_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + img_fmt.image_channel_order = CL_RGBA; + img_fmt.image_channel_data_type = CL_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne00 * ne1 / 4; + img_desc.buffer = b_sub_buffer; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q6_K_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &ql_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qh_img)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(ql_img)); + CL_CHECK(clReleaseMemObject(qh_img)); + CL_CHECK(clReleaseMemObject(b_sub_buffer)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf; + cl_mem b_buf_trans; + cl_mem b_img; + cl_mem b_img_trans; + + // subbuffer for activation + region.origin = offset1; + region.size = ne00 * ne1 * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activation + img_fmt.image_channel_order = CL_RGBA; + img_fmt.image_channel_data_type = CL_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne00 * ne1 / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = ne1 % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activation + region.origin = 0; + region.size = ne00 * (ne1 + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activation + img_fmt.image_channel_order = CL_RGBA; + img_fmt.image_channel_data_type = CL_HALF_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne00 * (ne1 + padding) / 4; + img_desc.buffer = b_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activation + int height_B = ne1/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = ne00/4; + int padded_height_B = (ne1 + padding) / 4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_size_t[2] = { 1, 16 }; + size_t global_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q6_K_f32; + int padded_N = ne1 + padding; + + cl_ushort mask_f000 = 0xF000; + cl_uchar mask_c0 = 0xC0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ushort),&mask_f000)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_c0)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {2, 128, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -9673,6 +9988,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } + // q6_K x fp32 + if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst); + return; + } + // q4_0 x fp32 if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { // TODO: remove duplicate definitions of image description + format -- move to top diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 272d0ea23f0..34930dfbe6a 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -486,8 +486,13 @@ kernel void kernel_convert_block_q6_K( global uchar * dst_ql, global uchar * dst_qh, global char * dst_s, - global half * dst_d + global half * dst_d, + uchar mask_lsb_8, + ulong n_blk ) { + if (get_global_id(0) >= n_blk) { + return; + } global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0); global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); @@ -514,8 +519,13 @@ kernel void kernel_restore_block_q6_K( global uchar * dst_qh, global char * dst_s, global half * dst_d, - global struct block_q6_K * dst + global struct block_q6_K * dst, + uchar mask_lsb_8, + ulong n_blk ) { + if (get_global_id(0) >= n_blk) { + return; + } global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0); global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); @@ -534,3 +544,117 @@ kernel void kernel_restore_block_q6_K( b->scales[i] = s[i]; } } + +kernel void kernel_convert_block_q6_K_noshuffle( + global struct block_q6_K * src0, + global uchar * dst_ql, + global uchar * dst_qh, + global char * dst_s, + global half * dst_d, + uchar mask_lsb_8, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0); + global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); + global char * s = (global char *) dst_s + QK_K/16*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK_K/2/4; ++i) { + uchar x0 = b->ql[i*2 + 0] & mask_lsb_8; + uchar x1 = b->ql[i*2 + 1] & mask_lsb_8; + ql[i + 0] = (x0 & 0x0F) | ((x1 & 0x0F) << 4); + ql[i + 32] = ((x0 & 0xF0) >> 4) | (x1 & 0xF0); + + uchar x2 = b->ql[i*2 + 0 + 64] & mask_lsb_8; + uchar x3 = b->ql[i*2 + 1 + 64] & mask_lsb_8; + ql[i + 64] = (x2 & 0x0F) | ((x3 & 0x0F) << 4); + ql[i + 96] = ((x2 & 0xF0) >> 4) | (x3 & 0xF0); + } + + for (int i = 0; i < QK_K/4/8; ++i) { + uchar x0 = b->qh[i*4 + 0] & mask_lsb_8; + uchar x1 = b->qh[i*4 + 1] & mask_lsb_8; + uchar x2 = b->qh[i*4 + 2] & mask_lsb_8; + uchar x3 = b->qh[i*4 + 3] & mask_lsb_8; + qh[i + 0] = (x0 & 0x03) | ((x1 & 0x03) << 2) | ((x2 & 0x03) << 4) | ((x3 & 0x03) << 6); + qh[i + 8] = ((x0 & 0x0C) >> 2) | (x1 & 0x0C) | ((x2 & 0x0C) << 2) | ((x3 & 0x0C) << 4); + qh[i + 16] = ((x0 & 0x30) >> 4) | ((x1 & 0x30) >> 2) | (x2 & 0x30) | ((x3 & 0x30) << 2); + qh[i + 24] = ((x0 & 0xC0) >> 6) | ((x1 & 0xC0) >> 4) | ((x2 & 0xC0) >> 2) | (x3 & 0xC0); + + uchar x4 = b->qh[i*4 + 0 + 32] & mask_lsb_8; + uchar x5 = b->qh[i*4 + 1 + 32] & mask_lsb_8; + uchar x6 = b->qh[i*4 + 2 + 32] & mask_lsb_8; + uchar x7 = b->qh[i*4 + 3 + 32] & mask_lsb_8; + qh[i + 32] = (x4 & 0x03) | ((x5 & 0x03) << 2) | ((x6 & 0x03) << 4) | ((x7 & 0x03) << 6); + qh[i + 40] = ((x4 & 0x0C) >> 2) | (x5 & 0x0C) | ((x6 & 0x0C) << 2) | ((x7 & 0x0C) << 4); + qh[i + 48] = ((x4 & 0x30) >> 4) | ((x5 & 0x30) >> 2) | (x6 & 0x30) | ((x7 & 0x30) << 2); + qh[i + 56] = ((x4 & 0xC0) >> 6) | ((x5 & 0xC0) >> 4) | ((x6 & 0xC0) >> 2) | (x7 & 0xC0); + } + + for (int i = 0; i < QK_K/16; ++i) { + s[i] = b->scales[i]; + } +} + +kernel void kernel_restore_block_q6_K_noshuffle( + global uchar * src_ql, + global uchar * src_qh, + global char * src_s, + global half * src_d, + global struct block_q6_K * dst, + uchar mask_lsb_8, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0); + global uchar * ql = (global uchar *) src_ql + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) src_qh + QK_K/4*get_global_id(0); + global char * s = (global char *) src_s + QK_K/16*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + + for (int i = 0; i < QK_K/2/4; ++i) { + uchar x0 = ql[i + 0] & mask_lsb_8; + uchar x1 = ql[i + 32] & mask_lsb_8; + b->ql[i*2 + 0] = (x0 & 0x0F) | ((x1 & 0x0F) << 4); + b->ql[i*2 + 1] = ((x0 & 0xF0) >> 4) | (x1 & 0xF0); + + uchar x2 = ql[i + 64] & mask_lsb_8; + uchar x3 = ql[i + 96] & mask_lsb_8; + b->ql[i*2 + 0 + 64] = (x2 & 0x0F) | ((x3 & 0x0F) << 4); + b->ql[i*2 + 1 + 64] = ((x2 & 0xF0) >> 4) | (x3 & 0xF0); + } + + for (int i = 0; i < QK_K/4/8; ++i) { + uchar x0 = qh[i + 0] & mask_lsb_8; + uchar x1 = qh[i + 8] & mask_lsb_8; + uchar x2 = qh[i + 16] & mask_lsb_8; + uchar x3 = qh[i + 24] & mask_lsb_8; + b->qh[i*4 + 0] = (x0 & 0x03) | ((x1 & 0x03) << 2) | ((x2 & 0x03) << 4) | ((x3 & 0x03) << 6); + b->qh[i*4 + 1] = ((x0 & 0x0C) >> 2) | (x1 & 0x0C) | ((x2 & 0x0C) << 2) | ((x3 & 0x0C) << 4); + b->qh[i*4 + 2] = ((x0 & 0x30) >> 4) | ((x1 & 0x30) >> 2) | (x2 & 0x30) | ((x3 & 0x30) << 2); + b->qh[i*4 + 3] = ((x0 & 0xC0) >> 6) | ((x1 & 0xC0) >> 4) | ((x2 & 0xC0) >> 2) | (x3 & 0xC0); + + uchar x4 = qh[i + 0 + 32] & mask_lsb_8; + uchar x5 = qh[i + 8 + 32] & mask_lsb_8; + uchar x6 = qh[i + 16 + 32] & mask_lsb_8; + uchar x7 = qh[i + 24 + 32] & mask_lsb_8; + b->qh[i*4 + 0 + 32] = (x4 & 0x03) | ((x5 & 0x03) << 2) | ((x6 & 0x03) << 4) | ((x7 & 0x03) << 6); + b->qh[i*4 + 1 + 32] = ((x4 & 0x0C) >> 2) | (x5 & 0x0C) | ((x6 & 0x0C) << 2) | ((x7 & 0x0C) << 4); + b->qh[i*4 + 2 + 32] = ((x4 & 0x30) >> 4) | ((x5 & 0x30) >> 2) | (x6 & 0x30) | ((x7 & 0x30) << 2); + b->qh[i*4 + 3 + 32] = ((x4 & 0xC0) >> 6) | ((x5 & 0xC0) >> 4) | ((x6 & 0xC0) >> 2) | (x7 & 0xC0); + } + + for (int i = 0; i < QK_K/16; ++i) { + b->scales[i] = s[i]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl new file mode 100644 index 00000000000..3a9c624508a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl @@ -0,0 +1,140 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif +kernel void kernel_gemm_noshuffle_q6_K_f32( + global const ushort * src0_ql, + global const uchar * src0_qh, + global const ushort * src0_s, + global const half * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding, + ushort mask_f000, + uchar mask_c0 +) { + dst = (global float *)( (global char *)dst + offsetd ); + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); // n + int gx = get_global_id(1); // m + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * ptr_ql = src0_ql + gx_2; + global const uchar * ptr_qh = src0_qh + gx_2; + global const ushort * ptr_s = src0_s + gx_2; + global const half * ptr_d = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + // load 4x elements (ushort) of ql on M, each ushort contains 4 weights + // 4x ushort correspons to 4 rows on M + ushort4 bits4 = vload4(0, ptr_ql + (i/4)*m); // ql packed in 4s in ushort + uchar4 bits2 = vload4(0, ptr_qh + (i/4)*m); // qh packed in 4s in uchar + + // load 4 consecutive scales + char8 scale_s_8 = as_char8(vload4(0, ptr_s + (i/16/2)*m)); // 1 char scale every 16 elements, packed in 2s + char4 scale_s = ((i/16) % 2) == 0 ? scale_s_8.s0246 : scale_s_8.s1357; // transposed as ushort, 2 blocks + half4 scale_d = vload4(0, ptr_d + (i/256)*m); // 1 half scale every 256 elements + + // j=0 + // load 2x 4 elements of activations on N, corresponding to 8 rows on N + B.s0123 = read_imageh(src1, gy*2 + (i + 0)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 0)*n_4 + 1); + dequantized_weights.s0 = (convert_half((bits4.s0 & 0x000F) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((bits4.s1 & 0x000F) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((bits4.s2 & 0x000F) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((bits4.s3 & 0x000F) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i + 1)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 1)*n_4 + 1); + dequantized_weights.s0 = (convert_half((((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2))) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2))) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2))) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2))) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i + 2)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 2)*n_4 + 1); + dequantized_weights.s0 = (convert_half((((bits4.s0 & 0x0F00) >> 8) | (bits2.s0 & 0x30))) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((((bits4.s1 & 0x0F00) >> 8) | (bits2.s1 & 0x30))) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((((bits4.s2 & 0x0F00) >> 8) | (bits2.s2 & 0x30))) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((((bits4.s3 & 0x0F00) >> 8) | (bits2.s3 & 0x30))) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i + 3)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 3)*n_4 + 1); + dequantized_weights.s0 = (convert_half((((bits4.s0 & mask_f000) >> 12) | ((bits2.s0 & mask_c0) >> 2))) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((((bits4.s1 & mask_f000) >> 12) | ((bits2.s1 & mask_c0) >> 2))) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((((bits4.s2 & mask_f000) >> 12) | ((bits2.s2 & mask_c0) >> 2))) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((((bits4.s3 & mask_f000) >> 12) | ((bits2.s3 & mask_c0) >> 2))) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl new file mode 100644 index 00000000000..6f89cf968b9 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl @@ -0,0 +1,293 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantize_block_acc_bcast_8_hi(total_sum, bits4, bits2, scale_d, scale_s, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s7; \ + +#define dequantize_block_acc_bcast_8_lo(total_sum, bits4, bits2, scale_d, scale_s, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s7; \ + +#define dequantize_block_acc_bcast_1_hi(total_sum, bits4, bits2, scale_d, scale_s, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + +#define dequantize_block_acc_bcast_1_lo(total_sum, bits4, bits2, scale_d, scale_s, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + +#if defined(ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q6_K_f32( + read_only image1d_buffer_t src0_ql, + read_only image1d_buffer_t src0_qh, + global half2 * src0_s, + global half2 * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01 +) { + int grp = get_local_id(1); + int gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + int nb = ne00 / 32; + + uint4 reg_a_l; + ushort4 reg_a_h; + half2 reg_d; + char4 reg_s; + float8 reg_b; + + float2 total_sum = 0.0f; + + int line_stride_a = ne01 / 2; + int block_stride_a = NSUBGROUPS * ne01; + + for (int k = grp; k < nb; k += NSUBGROUPS) { + reg_d = src0_d[gid + k/8 * line_stride_a]; + reg_s = as_char4(src0_s[gid + k * line_stride_a]); + + if (slid < 4) { + reg_b.s0123 = read_imagef(src1, 0 + slid*2 + k*8); + reg_b.s4567 = read_imagef(src1, 1 + slid*2 + k*8); + } + + reg_a_l.s0 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*0).x; + reg_a_l.s1 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*1).x; + reg_a_l.s2 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*2).x; + reg_a_l.s3 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*3).x; + + reg_a_h.s0 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*0).x); + reg_a_h.s1 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*1).x); + reg_a_h.s2 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*2).x); + reg_a_h.s3 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*3).x); + +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantize_block_acc_bcast_8_hi(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#else + dequantize_block_acc_bcast_1_hi(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#endif // VECTOR_SUB_GROUP_BROADCAT + + reg_a_l.s0 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*4).x; + reg_a_l.s1 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*5).x; + reg_a_l.s2 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*6).x; + reg_a_l.s3 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*7).x; + + reg_a_h.s0 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*4).x); + reg_a_h.s1 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*5).x); + reg_a_h.s2 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*6).x); + reg_a_h.s3 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*7).x); + +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantize_block_acc_bcast_8_lo(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#else + dequantize_block_acc_bcast_1_lo(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#endif // VECTOR_SUB_GROUP_BROADCAT + } + + local float2 reduce_lm[SUBGROUP_SIZE * 3]; + if (grp == 1) { + reduce_lm[SUBGROUP_SIZE*0 + slid] = total_sum; + } + if (grp == 2) { + reduce_lm[SUBGROUP_SIZE*1 + slid] = total_sum; + } + if (grp == 3) { + reduce_lm[SUBGROUP_SIZE*2 + slid] = total_sum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (grp == 0) { + total_sum += reduce_lm[SUBGROUP_SIZE*0 + slid]; + } + if (grp == 0) { + total_sum += reduce_lm[SUBGROUP_SIZE*1 + slid]; + } + if (grp == 0) { + total_sum += reduce_lm[SUBGROUP_SIZE*2 + slid]; + } + + if (grp == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(total_sum, 0, &(dst[gid * 2])); + } +} From 116a9f6ab79b518babc4c036d1bb56ac60a3da58 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Mon, 23 Mar 2026 15:33:49 -0700 Subject: [PATCH 342/831] hexagon: general DMA and Binary Op fixes for large strides (llama/20918) * hex-dma: make chained dma the default to handle newer models This also includes some new instrumentation that we can remove later. * hexagon: add uint32 dump helper * hexagon: use single-page VTCM allocation to avoid issues with large gather ops in ssm-conv ssm-conv uses HVX gather instruction and that instruction cannot handle cases where the base+offset spans page boundaries. * hexagon: update ssm-conv to make base-addr compute a bit easier to read * hex-dma: use 1d mode for reshaping, it supports sizes up to 24-bits (>16MB) * hex-bin: fix incorrect stride logic * hexagon: make sure repack buffs are dumped for verbose > 2 * hex-bin: consistently use dma_queue_push even for dummy dst transactions * hex-dma: start using 2d-wide mode on v75 and up The removes the need to deal with the 16-bit limitaion for the strides. * hex-bin: cleanup kernel selection logic * hex-bin: cleanup binary op core and fix transposed tensor handling * snapdragon: update run-bench to use larger ubatch and fa-on --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 12 +- ggml/src/ggml-hexagon/htp/binary-ops.c | 307 ++++++++++----------- ggml/src/ggml-hexagon/htp/hex-dma.c | 4 +- ggml/src/ggml-hexagon/htp/hex-dma.h | 307 ++++++++++++--------- ggml/src/ggml-hexagon/htp/hex-dump.h | 9 + ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 28 +- ggml/src/ggml-hexagon/htp/hvx-utils.h | 8 - ggml/src/ggml-hexagon/htp/main.c | 4 +- ggml/src/ggml-hexagon/htp/ssm-conv.c | 18 +- 9 files changed, 368 insertions(+), 329 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 8bcf5291c11..9c1ce93cc69 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -461,7 +461,7 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { d[7] = x[i * 8 + 7].d; } - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q4x4x2(y, i, k); } @@ -480,7 +480,7 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { const uint8_t * y_q = y + 0; // quants first const uint8_t * y_d = y + qrow_size; // then scales - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q4x4x2(y, i, k); } @@ -796,7 +796,7 @@ static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) { d[7] = x[i * 8 + 7].d; } - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q8x4x2(y, i, k); } @@ -814,7 +814,7 @@ static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) { const uint8_t * y_q = y + 0; // quants first const uint8_t * y_d = y + qrow_size; // then scales - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q8x4x2(y, i, k); } @@ -1149,7 +1149,7 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) e[7] = x[i * 8 + 7].e; } - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_mxfp4x4x2(y, i, k); } @@ -1168,7 +1168,7 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) const uint8_t * y_q = y + 0; // quants first const uint8_t * y_e = y + qrow_size; // then scales - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_mxfp4x4x2(y, i, k); } diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c index ec90f22de52..1b0f97493bc 100644 --- a/ggml/src/ggml-hexagon/htp/binary-ops.c +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -24,28 +24,26 @@ // Context for binary operations struct htp_binary_context { struct htp_ops_context * octx; - struct fastdiv_values dim1_div; - struct fastdiv_values dim2_div; - struct fastdiv_values dim12_div; + + struct fastdiv_values src0_dim1_div; // ne01 + struct fastdiv_values src0_dim2_div; // ne02 + struct fastdiv_values src0_dim12_div;// ne03 struct fastdiv_values src1_dim1_div; // ne11 struct fastdiv_values src1_dim2_div; // ne12 struct fastdiv_values src1_dim3_div; // ne13 - uint32_t nrows_per_thread; - bool split_at_ne01; - bool split_at_ne02; - - // Precomputed values uint32_t block_max; + uint32_t nrows_per_thread; size_t src0_row_size_aligned; size_t src1_row_size_aligned; size_t dst_row_size_aligned; - uint32_t src1_fetch_rows; // 1 or block_max - uint32_t src1_dma_stride; // 0 or stride + + bool split_at_ne01; + bool split_at_ne02; }; -#define htp_binary_preamble \ +#define htp_binary_preamble \ const struct htp_tensor * src0 = &octx->src0; \ const struct htp_tensor * src1 = &octx->src1; \ struct htp_tensor * dst = &octx->dst; \ @@ -72,12 +70,11 @@ struct htp_binary_context { const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, - uint32_t ne01, uint32_t ne02) { +static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, uint32_t ne01, uint32_t ne02) { uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); + i03 = fastdiv(ir, &bctx->src0_dim12_div); rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; uint32_t rows_left = end_row - ir; @@ -191,6 +188,8 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; + FARF(HIGH, "binary-scalar: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; @@ -204,9 +203,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; @@ -215,7 +214,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -229,9 +228,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); + i03 = fastdiv(ir, &bctx->src0_dim12_div); rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; // src1 indices (broadcast/repeat) @@ -255,9 +254,9 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); + p02 = fastdiv(prem, &bctx->src0_dim1_div); p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; @@ -282,6 +281,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; + FARF(HIGH, "binary-same-shape: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); @@ -297,9 +298,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; uint32_t i13 = (ne13 == 1) ? 0 : i03; @@ -307,23 +308,23 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint32_t i11 = (ne11 == 1) ? 0 : i01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; - uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * src1_curr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); - dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, src1_curr), bctx->src1_row_size_aligned, nb11, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } for (uint32_t ir = start_row; ir < end_row; ) { uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); - uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst; @@ -335,9 +336,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi } uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); + i03 = fastdiv(ir, &bctx->src0_dim12_div); rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); + i02 = fastdiv(rem, &bctx->src0_dim1_div); i01 = rem - i02 * ne01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); @@ -345,9 +346,9 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); + p02 = fastdiv(prem, &bctx->src0_dim1_div); p01 = prem - p02 * ne01; uint32_t p13 = (ne13 == 1) ? 0 : p03; @@ -358,7 +359,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); - dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, nb11, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } @@ -373,15 +374,17 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src0.type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; - const uint32_t start_row = bctx->nrows_per_thread * ith; - const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; + FARF(HIGH, "binary-row-bcast: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); - uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; @@ -391,15 +394,14 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, uint32_t ir_prefetch = start_row; int spad_idx = 0; - void * s1_ptr = (void *) src1_spad; + void * s1_ptr = (void *) src1_spad_base; for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); - rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; @@ -407,7 +409,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -415,7 +417,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, for (uint32_t ir = start_row; ir < end_row; ) { uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); - uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; for (uint32_t r = 0; r < current_block_size; r++) { @@ -425,21 +427,19 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00); } - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); - rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); - prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); - p01 = prem - p02 * ne01; + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; @@ -458,14 +458,16 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * const uint32_t src0_type = octx->src0.type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; - const uint32_t start_row = bctx->nrows_per_thread * ith; - const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; + FARF(HIGH, "binary-complex: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; - size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; dma_queue * q = octx->ctx->dma[ith]; uint32_t ir_prefetch = start_row; @@ -473,11 +475,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); - rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; @@ -485,7 +486,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -496,11 +497,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); - rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; for (uint32_t r = 0; r < current_block_size; r++) { uint32_t r_i01 = i01 + r; @@ -521,11 +521,10 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); - prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); - p01 = prem - p02 * ne01; + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; @@ -545,14 +544,16 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); const uint32_t row_size_bytes = ne00 * elem_size_bytes;; const uint32_t total_rows = ne01 * ne02 * ne03; - const uint32_t start_row = bctx->nrows_per_thread * ith; - const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); if (start_row >= end_row) return; uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; - size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + FARF(HIGH, "binary-repeat: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); dma_queue * q = octx->ctx->dma[ith]; uint32_t ir_prefetch = start_row; @@ -560,11 +561,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); - rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; @@ -572,7 +572,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -583,11 +583,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); - rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; for (uint32_t r = 0; r < current_block_size; r++) { uint32_t r_i01 = i01 + r; @@ -612,11 +611,10 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); - prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); - p01 = prem - p02 * ne01; + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; @@ -646,6 +644,7 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { const uint32_t nb02 = src0->nb[2]; const uint32_t nb03 = src0->nb[3]; const uint32_t nb11 = src1->nb[1]; // src1 row stride + const uint32_t nb1 = dst->nb[1]; const uint32_t nb2 = dst->nb[2]; const uint32_t nb3 = dst->nb[3]; @@ -657,8 +656,8 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; - size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; dma_queue * q = octx->ctx->dma[ith]; uint32_t ir_prefetch = start_row; @@ -666,11 +665,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir_prefetch, &bctx->dim12_div); - rem = ir_prefetch - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; @@ -678,7 +676,7 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), 0); dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; @@ -689,11 +687,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; - uint32_t i03, i02, i01, rem; - i03 = fastdiv(ir, &bctx->dim12_div); - rem = ir - i03 * (ne02 * ne01); - i02 = fastdiv(rem, &bctx->dim1_div); - i01 = rem - i02 * ne01; + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; for (uint32_t r = 0; r < current_block_size; r++) { uint32_t r_i01 = i01 + r; // linear within block since we split at ne01 @@ -712,11 +709,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); - uint32_t p03, p02, p01, prem; - p03 = fastdiv(ir_prefetch, &bctx->dim12_div); - prem = ir_prefetch - p03 * (ne02 * ne01); - p02 = fastdiv(prem, &bctx->dim1_div); - p01 = prem - p02 * ne01; + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); ir_prefetch += next_block_size; @@ -739,40 +735,36 @@ static int execute_op_binary(struct htp_ops_context * octx) { const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); const size_t src0_row_size = src0->ne[0] * elem_size; const size_t src1_row_size = src1->ne[0] * elem_size; - const size_t dst_row_size = dst->ne[0] * elem_size; + const size_t dst_row_size = dst->ne[0] * elem_size; - // Align to VLEN - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); bool is_add_id = (octx->op == HTP_OP_ADD_ID); bool is_scalar = !is_add_id && (src1->ne[0] == 1); - // Determine which kernel we will use to alloc memory and dispatch - bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) && + bool is_transposed = (src0->nb[1] < src0_row_size || src1->nb[1] < src1_row_size || dst->nb[1] < dst_row_size); + + bool is_same_shape = !is_add_id && !is_scalar && !is_transposed && + (src1->ne[0] == src0->ne[0] && src0->ne[0] % VLEN == 0) && (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) && (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) && (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1); - bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1); - bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]); - bool use_repeat = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]); + bool is_row_bcast = is_same_shape && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1); + bool is_complex = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] == src0->ne[0]); + bool is_repeat = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] != src0->ne[0]); size_t spad_row_total; - if (is_scalar) { - spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); - } else if (is_row_bcast) { - spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); - } else if (use_vector_same) { + if (is_same_shape) { spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned); - } else if (is_add_id) { - spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly } else { spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); } size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total); + // Adjust for static src1 in row_bcast case if (is_row_bcast) { size_t needed_static = src1_row_size_aligned; @@ -782,28 +774,26 @@ static int execute_op_binary(struct htp_ops_context * octx) { } if (rows_per_buffer < 1) { - FARF(ERROR, "binary: VTCM too small\n"); - return HTP_STATUS_VTCM_TOO_SMALL; + FARF(ERROR, "binary: VTCM too small\n"); + return HTP_STATUS_VTCM_TOO_SMALL; } octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned; octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned; - if (is_scalar || use_complex || use_repeat || is_add_id) { - octx->src1_spad.size_per_thread = 0; - } else if (is_row_bcast) { + if (is_add_id || is_scalar || is_complex || is_repeat || is_row_bcast) { octx->src1_spad.size_per_thread = 0; } else { octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned; } + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; if (is_row_bcast) { octx->src1_spad.size = src1_row_size_aligned; } else { octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; } - octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) { return HTP_STATUS_VTCM_TOO_SMALL; @@ -823,46 +813,37 @@ static int execute_op_binary(struct htp_ops_context * octx) { } struct htp_binary_context bctx; - bctx.octx = octx; - bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; - bctx.block_max = rows_per_buffer; + bctx.octx = octx; + bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + bctx.block_max = rows_per_buffer; bctx.src0_row_size_aligned = src0_row_size_aligned; bctx.src1_row_size_aligned = src1_row_size_aligned; bctx.dst_row_size_aligned = dst_row_size_aligned; - bctx.dim1_div = init_fastdiv_values(src0->ne[1]); - bctx.dim2_div = init_fastdiv_values(src0->ne[2]); - bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]); + bctx.src0_dim1_div = init_fastdiv_values(src0->ne[1]); + bctx.src0_dim2_div = init_fastdiv_values(src0->ne[2]); + bctx.src0_dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]); - bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]); - bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]); - bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]); + bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]); + bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]); + bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]); bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]); - bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]); + bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]); bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]); - bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]); - - bctx.split_at_ne01 = (src0->ne[2] > 1) && - ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1); + bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]); - bctx.split_at_ne02 = (src0->ne[3] > 1) && - ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2); - - // Precompute specific kernel parameters - if (use_vector_same) { - bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1]; - bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer; - } + bctx.split_at_ne01 = (src0->ne[2] > 1) && ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1); + bctx.split_at_ne02 = (src0->ne[3] > 1) && ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2); worker_callback_t worker_func; - if (is_add_id) worker_func = binary_job_add_id; - else if (is_scalar) worker_func = binary_job_scalar; - else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast; - else if (use_vector_same) worker_func = binary_job_vector_same_shape; - else if (use_complex) worker_func = binary_job_vector_complex; - else worker_func = binary_job_element_repeat; + if (is_add_id) worker_func = binary_job_add_id; + else if (is_scalar) worker_func = binary_job_scalar; + else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast; + else if (is_same_shape) worker_func = binary_job_vector_same_shape; + else if (is_complex) worker_func = binary_job_vector_complex; + else worker_func = binary_job_element_repeat; if (is_row_bcast) { dma_queue_pop(q); diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.c b/ggml/src/ggml-hexagon/htp/hex-dma.c index 44e1be40c5d..b66e2d2603c 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.c +++ b/ggml/src/ggml-hexagon/htp/hex-dma.c @@ -31,8 +31,8 @@ dma_queue * dma_queue_create(size_t capacity) { q->capacity = capacity; q->idx_mask = capacity - 1; - q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t)); - memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t)); + q->desc = (dma_descriptor_2d *) memalign(64, capacity * sizeof(dma_descriptor_2d)); + memset(q->desc, 0, capacity * sizeof(dma_descriptor_2d)); q->dptr = (dma_ptr *) memalign(4, capacity * sizeof(dma_ptr)); memset(q->dptr, 0, capacity * sizeof(dma_ptr)); diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h index 9811a07599f..ff166cbcc7a 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -10,19 +10,84 @@ extern "C" { #endif +// Define the HW descriptor structs here since the ones in HexSDK are a bit out of date +typedef struct dma_descriptor_1d_s { + void * next; + uint32_t size:24; + uint32_t desc_size:2; + uint32_t dst_comp:1; + uint32_t src_comp:1; + uint32_t dst_bypass:1; + uint32_t src_bypass:1; + uint32_t order:1; + uint32_t done:1; + void * src; + void * dst; +} dma_descriptor_1d; + +#if __HVX_ARCH__ < 75 + +typedef struct dma_descriptor_2d_s { + void * next; + uint32_t reserved0:24; + uint32_t desc_size:2; + uint32_t dst_comp:1; + uint32_t src_comp:1; + uint32_t dst_bypass:1; + uint32_t src_bypass:1; + uint32_t order:1; + uint32_t done:1; + void * src; + void * dst; + uint32_t desc_type:8; + uint32_t reserved1:24; + uint32_t row_size:16; + uint32_t nrows:16; + uint32_t src_stride:16; + uint32_t dst_stride:16; + uint32_t src_offset:16; + uint32_t dst_offset:16; +} dma_descriptor_2d; + +#else + +typedef struct dma_descriptor_2d_s { + void * next; + uint32_t dst_stride:24; + uint32_t desc_size:2; + uint32_t dst_comp:1; + uint32_t src_comp:1; + uint32_t dst_bypass:1; + uint32_t src_bypass:1; + uint32_t order:1; + uint32_t done:1; + void * src; + void * dst; + uint32_t desc_type:8; + uint32_t reserved0:24; + uint32_t row_size:24; + uint32_t nrows_lo:8; + uint32_t nrows_hi:8; + uint32_t src_stride:24; + uint32_t offset:24; + uint32_t reserved1:8; +} dma_descriptor_2d; + +#endif + typedef struct { - void *dst; + void *dst; const void *src; } dma_ptr; typedef struct { - hexagon_udma_descriptor_type1_t * desc; // descriptor pointers - hexagon_udma_descriptor_type1_t * tail; // tail pointer - dma_ptr * dptr; // dst/src pointers - uint32_t push_idx; - uint32_t pop_idx; - uint32_t capacity; - uint32_t idx_mask; + dma_descriptor_2d * desc; // descriptor pointers + dma_descriptor_2d * tail; // tail pointer + dma_ptr * dptr; // dst/src pointers + uint32_t push_idx; + uint32_t pop_idx; + uint32_t capacity; + uint32_t idx_mask; } dma_queue; dma_queue * dma_queue_create(size_t capacity); @@ -59,71 +124,87 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src) return p; } -static inline bool dma_queue_push(dma_queue * q, - dma_ptr dptr, - size_t dst_row_size, - size_t src_row_size, - size_t width, // width in bytes. number of bytes to transfer per row - size_t nrows) { +#if __HVX_ARCH__ < 73 +static const uint32_t dma_src_l2_bypass_on = 1; +static const uint32_t dma_dst_l2_bypass_on = 0; +#else +static const uint32_t dma_src_l2_bypass_on = 1; +static const uint32_t dma_dst_l2_bypass_on = 1; +#endif + +static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) { if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { - FARF(ERROR, "dma-push: queue full\n"); + FARF(HIGH, "dma-push: queue full\n"); return false; } - hexagon_udma_descriptor_type1_t * desc = &q->desc[q->push_idx]; + dma_descriptor_1d * desc = (dma_descriptor_1d *) &q->desc[q->push_idx]; + desc->next = NULL; + desc->desc_size = 0; // 1D mode + desc->src_bypass = dma_src_l2_bypass_on; + desc->dst_bypass = dma_dst_l2_bypass_on; + desc->order = 1; + desc->done = 0; + desc->src = (void *) dptr.src; + desc->dst = (void *) dptr.dst; + desc->size = size; + + q->dptr[q->push_idx] = dptr; + + dmlink(q->tail, desc); + q->tail = (dma_descriptor_2d *) desc; + + // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src); + q->push_idx = (q->push_idx + 1) & q->idx_mask; + return true; +} + +static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) { + if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { + FARF(HIGH, "dma-push: queue full\n"); + return false; + } + + dma_descriptor_2d * desc = &q->desc[q->push_idx]; desc->next = NULL; - desc->length = 0; - desc->desctype = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1; - desc->dstbypass = 1; - desc->srcbypass = 1; -#if __HVX_ARCH__ >= 73 - desc->dstbypass = 1; - desc->srcbypass = 1; -#else - desc->dstbypass = 0; - desc->srcbypass = 1; -#endif - desc->order = 0; - desc->dstate = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE; + desc->reserved0 = 0; + desc->reserved1 = 0; + desc->desc_size = 1; // 2d mode + desc->src_bypass = dma_src_l2_bypass_on; + desc->dst_bypass = dma_dst_l2_bypass_on; + desc->src_comp = 0; + desc->dst_comp = 0; + desc->order = 1; + desc->done = 0; + desc->src_stride = src_stride; + desc->dst_stride = dst_stride; desc->src = (void *) dptr.src; desc->dst = (void *) dptr.dst; - desc->allocation = 0; - desc->padding = 0; - desc->roiwidth = width; - desc->roiheight = nrows; - desc->srcstride = src_row_size; - desc->dststride = dst_row_size; - desc->srcwidthoffset = 0; - desc->dstwidthoffset = 0; + desc->row_size = row_size; + +#if __HVX_ARCH__ < 75 + desc->desc_type = 0; // 2d (16-bit) mode + desc->nrows = nrows; + desc->src_offset = 0; + desc->dst_offset = 0; +#else + desc->desc_type = 9; // 2d (24-bit) mode + desc->nrows_lo = (nrows & 0xff); + desc->nrows_hi = (nrows >> 8); + desc->offset = 0; +#endif q->dptr[q->push_idx] = dptr; dmlink(q->tail, desc); q->tail = desc; - // FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src); + // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src); q->push_idx = (q->push_idx + 1) & q->idx_mask; return true; } -static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q, - dma_ptr dptr, - size_t dst_row_size, - size_t src_row_size, - size_t nrows) { - return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows); -} - - -static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, - dma_ptr dptr, - size_t dst_row_size, - size_t src_row_size, - size_t nrows) { - return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows); -} - static inline dma_ptr dma_queue_pop(dma_queue * q) { dma_ptr dptr = { NULL }; @@ -131,12 +212,12 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) { return dptr; } - hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx]; + dma_descriptor_2d * desc = &q->desc[q->pop_idx]; // Wait for desc to complete while (1) { dmpoll(); - if (desc->dstate == HEXAGON_UDMA_DESC_DSTATE_COMPLETE) { + if (desc->done) { break; } // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx); @@ -175,86 +256,62 @@ static inline uint32_t dma_queue_capacity(dma_queue * q) { return q->capacity; } -// --------------------------------------------------------------------------- -// Overflow-safe DMA push: all UDMA type1 descriptor fields (roiwidth, -// roiheight, srcstride, dststride) are 16-bit, max 65535. This helper -// transparently handles values that exceed the 16-bit limit and submits -// chained DMA transtions. -// -// Case 1 (fast path): all params fit in 16 bits -> direct dma_queue_push. -// Case 2 (contiguous block): width == srcstride == dststride. Reshape the -// flat transfer into a 2D descriptor with sub_width <= 65535. Produces a -// single descriptor, preserving async DMA behavior. -// Case 3 (stride overflow): srcstride or dststride > 65535. Issue rows -// one at a time. The first N-1 rows are pushed+popped synchronously; -// the last row is left async so the caller can pop it. -// --------------------------------------------------------------------------- -#define UDMA_MAX_FIELD_VAL 65535u - -static inline bool dma_queue_push_chained(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t width, size_t nrows) { - // Fast path: everything fits in 16 bits. - if (__builtin_expect( - width <= UDMA_MAX_FIELD_VAL && - nrows <= UDMA_MAX_FIELD_VAL && - src_stride <= UDMA_MAX_FIELD_VAL && - dst_stride <= UDMA_MAX_FIELD_VAL, 1)) { - return dma_queue_push(q, dptr, dst_stride, src_stride, width, nrows); - } +#if __HVX_ARCH__ < 75 - // Case 2: contiguous block (width == src_stride == dst_stride). - // Reshape total bytes into sub_width * sub_nrows where sub_width <= 65535. - if (width == src_stride && width == dst_stride) { - size_t total = width * nrows; +// Overflow-safe DMA push: all 2d descriptor fields (row_size, nrows, src_stride, dst_stride) are 16-bit, max 65535. +// This version transparently handles values that exceed the 16-bit limit and submits chained DMA transtions. - // Pick the largest 128-byte-aligned sub_width that divides total evenly. - size_t sub_width = UDMA_MAX_FIELD_VAL & ~(size_t)127; // 65408 - while (sub_width > 0 && total % sub_width != 0) { - sub_width -= 128; - } - if (sub_width == 0) { - // Fallback: use original width (must fit) with adjusted nrows. - // This shouldn't happen for 128-aligned DMA sizes. - sub_width = width; - } - size_t sub_nrows = total / sub_width; - - // Handle sub_nrows > 65535 by issuing chunked descriptors. - const uint8_t *src = (const uint8_t *)dptr.src; - uint8_t *dst = (uint8_t *)dptr.dst; - size_t rows_done = 0; - while (rows_done < sub_nrows) { - size_t chunk = sub_nrows - rows_done; - if (chunk > UDMA_MAX_FIELD_VAL) chunk = UDMA_MAX_FIELD_VAL; - - dma_ptr p = dma_make_ptr(dst + rows_done * sub_width, src + rows_done * sub_width); - if (!dma_queue_push(q, p, sub_width, sub_width, sub_width, chunk)) - return false; +#define DMA_MAX_FIELD_VAL 65535u - rows_done += chunk; - // Complete all chunks without waiting except the last one, so the - // caller's single dma_queue_pop drains the final descriptor. - if (rows_done < sub_nrows) - dma_queue_pop_nowait(q); - } - return true; +static inline bool dma_queue_push(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) { + // Fast path: everything fits in 16 bits + if (nrows == 0 || __builtin_expect( + row_size <= DMA_MAX_FIELD_VAL && + nrows <= DMA_MAX_FIELD_VAL && + src_stride <= DMA_MAX_FIELD_VAL && + dst_stride <= DMA_MAX_FIELD_VAL, 1)) { + return dma_queue_push_single_2d(q, dptr, dst_stride, src_stride, row_size, nrows); } - // Case 3: stride overflow — fall back to row-by-row. + // Contiguous block + // Use 1d DMA mode which supports sizes up to 24-bits (16MB) + if (nrows == 1 || (row_size == src_stride && row_size == dst_stride)) { + size_t total = row_size * nrows; + return dma_queue_push_single_1d(q, dptr, total); + } + + // Stride overflow — fall back to row-by-row. { - const uint8_t *src = (const uint8_t *)dptr.src; - uint8_t *dst = (uint8_t *)dptr.dst; + const uint8_t *src = (const uint8_t *) dptr.src; + uint8_t *dst = (uint8_t *) dptr.dst; for (size_t r = 0; r < nrows; ++r) { - dma_ptr p = dma_make_ptr(dst + r * dst_stride, - src + r * src_stride); - if (!dma_queue_push(q, p, 0, 0, width, 1)) - return false; - if (r + 1 < nrows) - dma_queue_pop_nowait(q); + dma_ptr p = dma_make_ptr(dst + r * dst_stride, src + r * src_stride); + if (!dma_queue_push_single_1d(q, p, row_size)) + return false; + if (r + 1 < nrows) + dma_queue_pop(q); } return true; } } +#else // HVX_ARCH >= 75 + +static inline bool dma_queue_push(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) { + // On v75 and up we always use 2d 24-bit mode + return dma_queue_push_single_2d(q, dptr, dst_stride, src_stride, row_size, nrows); +} + +#endif + +static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q, dma_ptr dptr, size_t dst_row_size, size_t src_row_size, size_t nrows) { + return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows); +} + +static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_t dst_row_size, size_t src_row_size, size_t nrows) { + return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows); +} + #ifdef __cplusplus } // extern "C" #endif diff --git a/ggml/src/ggml-hexagon/htp/hex-dump.h b/ggml/src/ggml-hexagon/htp/hex-dump.h index e3badb57f92..19d173c2232 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dump.h +++ b/ggml/src/ggml-hexagon/htp/hex-dump.h @@ -21,6 +21,15 @@ static inline void hex_dump_uint8_line(char * pref, const uint8_t * x, uint32_t FARF(HIGH, "%s\n", str); } +static inline void hex_dump_uint32_line(char * pref, const uint32_t * x, uint32_t n) { + char str[1024], *p = str, *p_end = str + sizeof(str); + p += snprintf(p, p_end - p, "%s: ", pref); + for (int i = 0; i < n; i++) { + p += snprintf(p, p_end - p, "%u, ", (unsigned int) x[i]); + } + FARF(HIGH, "%s\n", str); +} + static inline void hex_dump_int32_line(char * pref, const int32_t * x, uint32_t n) { char str[1024], *p = str, *p_end = str + sizeof(str); p += snprintf(p, p_end - p, "%s: ", pref); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index c703a049426..a56356bee9f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -727,7 +727,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu if (use_dma_activation) { const size_t row_bytes = (size_t) params->k * sizeof(float); const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); - dma_queue_push_chained(ctx->dma[0], + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_f32_act, activation_chunk), row_bytes, stride_bytes, row_bytes, n_rows); dma_queue_pop(ctx->dma[0]); @@ -747,7 +747,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu { const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); } @@ -765,7 +765,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); } @@ -891,7 +891,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co if (use_dma_activation) { const size_t row_bytes = (size_t) k * sizeof(float); const size_t stride_bytes = (size_t) act_stride * sizeof(float); - dma_queue_push_chained(ctx->dma[0], + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_f32_act, activation_chunk), row_bytes, stride_bytes, row_bytes, n_rows); dma_queue_pop(ctx->dma[0]); @@ -916,7 +916,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co { const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); } @@ -933,7 +933,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); } @@ -1104,7 +1104,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds // because UDMA roiwidth is 16-bit and total size can exceed 65535. { const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first); } for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { @@ -1120,7 +1120,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride; - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next); } // Dequant + vscatter writes directly to [K, N] transposed tiles. @@ -1173,7 +1173,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds { // Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow. const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); } { @@ -1191,7 +1191,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); if (1 < n_chunk_cnt) { const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); } // C0 @@ -1218,7 +1218,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds // issue A_{i+2} if (i + 2 < n_chunk_cnt) { const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); } // wait for HMX (C_{i}) -- C_{i} is done @@ -1443,7 +1443,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict { const float *activation_block = x + mr * k + kk; - dma_queue_push_chained(ctx->dma[0], + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_scratch1, activation_block), k_blk_sz * sizeof(float), k * sizeof(float), @@ -1472,10 +1472,10 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict s.scale_width = nb_sub * HMX_X4X2_DBLK_SIZE; // 2D DMA: quants sub-range - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off), + dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off), s.dst_stride, s.src_stride, s.quant_width, s.n_rows); // 2D DMA: scales sub-range - dma_queue_push_chained(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off), + dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off), s.dst_stride, s.src_stride, s.scale_width, s.n_rows); } TIMER_STOP(fetch); diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 08343798794..a518ad37331 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -15,12 +15,4 @@ #include "hvx-div.h" #include "hvx-base.h" -#ifndef GATHER_TYPE -# if defined(__hexagon__) -# define GATHER_TYPE(_a) (intptr_t) _a -# else -# define GATHER_TYPE(_a) (HVX_Vector *) _a -# endif -#endif - #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index ef9cba8ecc1..70ba9f9f4fe 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -214,7 +214,7 @@ static int vtcm_alloc(struct htp_context * ctx) { HAP_compute_res_attr_init(&attr); HAP_compute_res_attr_set_serialize(&attr, 0); HAP_compute_res_attr_set_cache_mode(&attr, 1); - HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, 0, vtcm_size); + HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, vtcm_size, vtcm_size); // single page HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx); HAP_compute_res_attr_set_hmx_param(&attr, 1); @@ -319,7 +319,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->n_threads = n_hvx; for (int i = 0; i < ctx->n_threads; i++) { // see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541 - ctx->dma[i] = dma_queue_create(64); + ctx->dma[i] = dma_queue_create(128); } // init worker pool diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index b3c1ef9572e..6b035810d57 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -151,7 +151,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void const int dr = scctx->nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = MIN(ir0 + dr, d_inner); - const int ir = ir1 - ir0; + const uint32_t ir = ir1 - ir0; if (ir0 >= ir1) { return; // No work for this thread @@ -205,10 +205,10 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void HVX_Vector acc_vec = Q6_V_vsplat_R(0); for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])), - src0_gather_len, (*(const HVX_Vector *) src0_offsets)); - Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)), - src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]); + uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float); + Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets)); + Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets)); HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); @@ -222,10 +222,10 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void HVX_Vector acc_vec = Q6_V_vsplat_R(0); for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])), - src0_gather_len, (*(const HVX_Vector *) src0_offsets)); - Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)), - src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]); + uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float); + Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets)); + Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets)); HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); From eef7422d4d6bd336a9343b0a04b20f94ad9c80a2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 24 Mar 2026 10:03:09 +0200 Subject: [PATCH 343/831] metal : add FA instantiations for HSK=512, HSV=512 (llama/20902) --- ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal.metal | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 14144aab087..2fbb274c5f9 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1148,6 +1148,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->src[0]->ne[0] != 192 && op->src[0]->ne[0] != 256 && op->src[0]->ne[0] != 320 && + op->src[0]->ne[0] != 512 && op->src[0]->ne[0] != 576) { return false; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9c6b1c4f62b..9286675189d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -6269,6 +6269,7 @@ template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f32_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6284,6 +6285,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_f16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #if defined(GGML_METAL_HAS_BF16) @@ -6300,6 +6302,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_bf16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #endif @@ -6316,6 +6319,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6331,6 +6335,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q4_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6346,6 +6351,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6361,6 +6367,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q5_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext; @@ -6376,6 +6383,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; +template [[host_name("kernel_flash_attn_ext_q8_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext; #undef FA_TYPES @@ -6957,6 +6965,17 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flas template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) From 9e4e4c2401e6fa73bbaedab1e86512c34aee052c Mon Sep 17 00:00:00 2001 From: nuri Date: Tue, 24 Mar 2026 17:13:07 +0900 Subject: [PATCH 344/831] metal : add FLOOR, CEIL, ROUND, TRUNC unary ops (llama/20930) Co-authored-by: nryoo --- ggml/src/ggml-metal/ggml-metal-device.cpp | 4 ++++ ggml/src/ggml-metal/ggml-metal-device.m | 4 ++++ ggml/src/ggml-metal/ggml-metal-impl.h | 4 ++++ ggml/src/ggml-metal/ggml-metal.metal | 16 ++++++++++++++++ 4 files changed, 28 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 9162342ee98..89539bd7615 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -246,6 +246,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break; case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break; case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break; + case GGML_UNARY_OP_FLOOR: op_num = OP_UNARY_NUM_FLOOR; break; + case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break; + case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break; + case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 2fbb274c5f9..cbef2fb4879 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1039,6 +1039,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_EXPM1: + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_TRUNC: return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index ea471090cd8..eb2253e029a 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -120,6 +120,10 @@ #define OP_UNARY_NUM_EXP 114 #define OP_UNARY_NUM_SOFTPLUS 115 #define OP_UNARY_NUM_EXPM1 116 +#define OP_UNARY_NUM_FLOOR 117 +#define OP_UNARY_NUM_CEIL 118 +#define OP_UNARY_NUM_ROUND 119 +#define OP_UNARY_NUM_TRUNC 120 #define OP_SUM_ROWS_NUM_SUM_ROWS 10 #define OP_SUM_ROWS_NUM_MEAN 11 diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9286675189d..2074211594c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1094,6 +1094,22 @@ kernel void kernel_unary_impl( // TODO: precise implementation dst_ptr[i0] = (T) (exp(x) - 1); } + + if (FC_OP == OP_UNARY_NUM_FLOOR) { + dst_ptr[i0] = (T) floor(x); + } + + if (FC_OP == OP_UNARY_NUM_CEIL) { + dst_ptr[i0] = (T) ceil(x); + } + + if (FC_OP == OP_UNARY_NUM_ROUND) { + dst_ptr[i0] = (T) round(x); + } + + if (FC_OP == OP_UNARY_NUM_TRUNC) { + dst_ptr[i0] = (T) trunc(x); + } } #undef FC_OP From f2a8e65ea7e4b58cf862a832c2c1fabd2e6ff63f Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Wed, 25 Mar 2026 17:48:37 +0800 Subject: [PATCH 345/831] sycl : fix wrong variable check by assert (llama/20903) * fix wrong variable check by assert * use GGML api --- ggml/src/ggml-sycl/add-id.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-sycl/add-id.cpp b/ggml/src/ggml-sycl/add-id.cpp index 8929017a999..e0adc4fe423 100644 --- a/ggml/src/ggml-sycl/add-id.cpp +++ b/ggml/src/ggml-sycl/add-id.cpp @@ -56,7 +56,7 @@ void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { float* dst_d = (float*)dst->data; const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; - assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + GGML_ASSERT(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); int threads = std::min((unsigned int)ne00, max_work_group_size); // cols From 3987857d2db803eddfc82d61d199913f2013dfab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 25 Mar 2026 11:53:16 +0100 Subject: [PATCH 346/831] llama: fix llama-model-saver (llama/20503) * llama : add fd-based model loading via llama_model_load_from_fd * llama : address review feedback for fd-based model loading * llama : use FILE pointer instead of fd in public API * llama : use FILE pointer consistently, address review feedback * fixup * fix tensor names * fix llama-model-saver * roundtrip tests * fixup * refactor tests * fix prints * fix model saving * fix CI, disable Chameleon * print seed --------- Co-authored-by: Siddhesh2377 --- ggml/include/gguf.h | 2 ++ ggml/src/ggml-impl.h | 1 - ggml/src/gguf.cpp | 33 +++++++++++++++++++++++---------- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h index 79ee202062b..02d5f221c03 100644 --- a/ggml/include/gguf.h +++ b/ggml/include/gguf.h @@ -77,6 +77,7 @@ extern "C" { }; GGML_API struct gguf_context * gguf_init_empty(void); + GGML_API struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params); GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); //GGML_API struct gguf_context * gguf_init_from_buffer(..); @@ -189,6 +190,7 @@ extern "C" { // // write the entire context to a binary file + GGML_API bool gguf_write_to_file_ptr(const struct gguf_context * ctx, FILE * file, bool only_meta); GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta); // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 92568655956..0639db362e7 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -773,6 +773,5 @@ inline bool ggml_check_edges(const struct ggml_cgraph * cgraph, // expose GGUF internals for test code GGML_API size_t gguf_type_size(enum gguf_type type); -GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, std::vector & buf, bool only_meta); #endif // __cplusplus diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index cbeedf6c4b6..ab3cc974867 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -394,7 +394,11 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector & bu gguf_write_out(ctx, gw, only_meta); } +bool gguf_write_to_file_ptr(const struct gguf_context * ctx, FILE * file, bool only_meta) { + GGML_ASSERT(file); + + try { + gguf_writer_file gw(file); + gguf_write_out(ctx, gw, only_meta); + } catch (const std::runtime_error& ex) { + GGML_LOG_ERROR("%s: failed to write GGUF data: %s\n", __func__, ex.what()); + return false; + } + return true; +} + bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { FILE * file = ggml_fopen(fname, "wb"); @@ -1516,17 +1533,13 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo return false; } - try { - gguf_writer_file gw(file); - gguf_write_out(ctx, gw, only_meta); - } catch (const std::runtime_error& ex) { - GGML_LOG_ERROR("%s: failed to write GGUF data into '%s': %s\n", __func__, fname, ex.what()); - fclose(file); - return false; + const bool success = gguf_write_to_file_ptr(ctx, file, only_meta); + if (!success) { + GGML_LOG_ERROR("%s: failed to write GGUF data into '%s'\n", __func__, fname); } fclose(file); - return true; + return success; } size_t gguf_get_meta_size(const struct gguf_context * ctx) { From 495b77aec29017b13a2dfe5d29b35eb677056d08 Mon Sep 17 00:00:00 2001 From: Saba Fallah <10401143+sfallah@users.noreply.github.com> Date: Wed, 25 Mar 2026 19:57:40 +0100 Subject: [PATCH 347/831] mtmd: Add DeepSeekOCR Support (llama/17400) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * mtmd: llama.cpp DeepSeekOCR support init commit * loading sam tensors * mtmd: fix vision model processing * deepseek-ocr clip-vit model impl * mtmd: add DeepSeek-OCR LM support with standard attention * mtmd: successfully runs DeepSeek-OCR LM in llama-cli * mtmd: Fix RoPE type for DeepSeek-OCR LM. * loading LM testing Vision model loading * sam warmup working * sam erroneous return corrected * clip-vit: corrected cls_embd concat * clip-vit: model convert qkv_proj split * corrected combining of image encoders' results * fix: update callback for ffn_moe_weighted and add callback for attn_out in deepseek2 model * concat image_newline and image_seperator tokens * visual_model warmup (technically) works * window partitioning using standard ggml ops * sam implementation without using CPU only ops * clip: fixed warnings * Merge branch 'sf/deepseek-ocr' of github.com:sfallah/llama.cpp into sf/deepseek-ocr * mtmd: fix get_rel_pos * mtmd: fixed the wrong scaler for get_rel_pos * image encoding technically works but the output can't be checked singe image decoding fails * mtmd: minor changed * mtmd: add native resolution support * - image encoding debugged - issues fixed mainly related wrong config like n_patches etc. - configs need to be corrected in the converter * mtmd: correct token order * - dynamic resizing - changes are concerning PR https://github.com/sfallah/llama.cpp/pull/4 * mtmd: quick fix token order * mtmd: fix danling pointer * mtmd: SAM numerically works * mtmd: debug CLIP-L (vit_pre_ln) * mtmd: debug CLIP-L & first working DeepSeek-OCR model * mtmd : add --dsocr-mode CLI argument for DeepSeek-OCR resolution control & all native resolution modes work * mtmd: simplify SAM patch embedding * mtmd: adapt Pillow image resizing function * mtmd: simplify DeepSeek-OCR dynamic resolution preprocessing * mtmd: remove --dsocr-mode argument * mtmd: refactor code & remove unused helper functions * mtmd: fix tensor names for image newlines and view separator * clean up * reverting automatically removed spaces * reverting automatically removed spaces * mtmd: fixed bad ocr check in Deepseek2 (LM) * mtmd: support combined QKV projection in buid_vit * using common build_attn in sam * corrected code-branch when flash-attn disabled enabling usage of --flash-attn option * mtmd: minor fix * minor formatting and style * fixed flake8 lint issues * minor editorconfig-check fixes * minor editorconfig-check fixes * mtmd: simplify get_rel_pos * mtmd: make sam hparams configurable * mtmd: add detailed comments for resize_bicubic_pillow * mtmd: fixed wrong input setting * mtmd: convert model in FP16 * mtmd: minor fix * mtmd: remove tweak to llama-mtmd-cli & deepseek-ocr template * fix: test-1.jpg ORC issue with small (640) resolution setting min-resolution base (1024) max large (1280) for dynamic-resolution * minor: editconfig-check fix * merge with changes from https://github.com/ggml-org/llama.cpp/pull/17909 added new opt to tests.sh to disable flash-attn * minor: editconfig-check fix * testing deepseek-ocr quick and dirty test script comparing results of Qwen2.5-VL vs DeepSeek-OCR * quick and (potential) dirty merge with https://github.com/ggml-org/llama.cpp/pull/17909 * refactoring, one single builder function and static helpers * added deepseek-ocr test to tests.sh * minor formatting fixes * check with fixed expected resutls * minor formatting * editorconfig-check fix * merge with changes from https://github.com/ggml-org/llama.cpp/pull/18042 * minor - added GLM-4.6V to big tests - added missing deps for python test * convert: minor fix * mtmd: format code * convert: quick fix * convert: quick fix * minor python formatting * fixed merge build issue * merge resolved - fixed issues in convert - tested several deepseek models * minor fix * minor * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret * - removed clip_is_deepseekocr - removed redundant RESIZE_ALGO_BICUBIC_PILLOW resize-algo - simplified image-preprocessing - removed/simplified debug functions * - cleaning commented out code * fixing instabilities issues reintroducing resize_bicubic_pillow * - use f16 model for deepseek-ocr test - ignore llama-arch test for deepseek-ocr * rename fc_w --> mm_fc_w * add links to OCR discussion * cleaner loading code * add missing .weight to some tensors * add default jinja template (to be used by server) * move test model to ggml-org * rolling back upscale change * Update convert_hf_to_gguf.py Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: bluebread Co-authored-by: Sigbjørn Skjæret Co-authored-by: Xuan Son Nguyen Co-authored-by: Xuan-Son Nguyen --- ggml/src/ggml.c | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 4c0764a0ac5..e9b6720c0af 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4962,6 +4962,7 @@ static struct ggml_tensor * ggml_interpolate_impl( GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT); // TODO: implement antialias for modes other than bilinear GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR); + GGML_ASSERT(a->type == GGML_TYPE_F32); struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); @@ -5307,6 +5308,7 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(q->ne[3] == v->ne[3]); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); From a050c7d1bf2aae985cba6896cdbb6644383f20bf Mon Sep 17 00:00:00 2001 From: Yihao Wang <42559837+AgainstEntropy@users.noreply.github.com> Date: Wed, 25 Mar 2026 19:19:14 -0700 Subject: [PATCH 348/831] CUDA & CPU: support F32 kernel type for `CONV_TRANSPOSE_2D` (llama/17094) * Refactor CUDA 2D transpose implementation to support multiple kernel types and improve parameter handling - Introduced a `conv2d_transpose_params` struct for better parameter management. - Updated `conv2d_transpose_kernel` to be templated for different kernel types (float and half). - Modified `ggml_cuda_conv_2d_transpose_p0` to handle both F16 and F32 kernel types. - Enhanced test cases to validate functionality for both kernel types. * Refactor test cases for 2D convolution transpose to support dynamic kernel types - Updated `test_conv_transpose_2d` structure to improve parameter handling by reordering constructor arguments. - Enhanced test case generation to iterate over kernel types, allowing for flexible testing of different configurations. - Removed hardcoded kernel type instances in favor of a loop for better maintainability and scalability. * Refactor ggml_compute_forward_conv_transpose_2d to support both F16 and F32 tensor types. * Refactor conv2d transpose kernel to use a template for kernel type, enhancing flexibility for different data types. Update test cases to include both F16 and F32 tensor types for comprehensive coverage. * Update ggml/src/ggml-cuda/conv2d-transpose.cu Co-authored-by: Aman Gupta * Update ggml/src/ggml-cpu/ggml-cpu.c Co-authored-by: Aman Gupta * Refactor conv2d transpose implementation by removing the conv2d_transpose_params struct and dispatching with direct kernel launch. * Enhance cpu conv2d transpose implementation by introducing a templated kernel type for improved flexibility with F16 and F32 data types. --------- Co-authored-by: Aman Gupta --- ggml/src/ggml-cpu/ggml-cpu.c | 8 ++- ggml/src/ggml-cpu/ops.cpp | 69 ++++++++++++++++++------- ggml/src/ggml-cuda/conv2d-transpose.cu | 66 +++++++++++++++-------- ggml/src/ggml-cuda/conv2d-transpose.cuh | 1 + 4 files changed, 102 insertions(+), 42 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 8b323bd9b06..df17cc55300 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2871,8 +2871,12 @@ struct ggml_cplan ggml_graph_plan( const int64_t ne11 = node->src[1]->ne[1]; // H const int64_t ne12 = node->src[1]->ne[2]; // Channels In - cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; - cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; + GGML_ASSERT(node->src[0]->type == GGML_TYPE_F16 || node->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(node->src[1]->type == GGML_TYPE_F32); + + cur += ggml_type_size(node->src[0]->type) * ne00 * ne01 * ne02 * ne03; + cur += ggml_type_size(node->src[0]->type) * ne10 * ne11 * ne12; + } break; case GGML_OP_TOP_K: { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 3f85e531daa..d950972c83e 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -6923,16 +6923,15 @@ void ggml_compute_forward_conv_3d( ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type); } -// ggml_compute_forward_conv_transpose_2d - -void ggml_compute_forward_conv_transpose_2d( - const ggml_compute_params * params, - ggml_tensor * dst) { +template +static void ggml_compute_forward_conv_transpose_2d_impl( + const ggml_compute_params * params, + ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -6943,7 +6942,7 @@ void ggml_compute_forward_conv_transpose_2d( const int nk = ne00*ne01*ne02*ne03; - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); GGML_ASSERT(nb10 == sizeof(float)); if (ith == 0) { @@ -6951,12 +6950,12 @@ void ggml_compute_forward_conv_transpose_2d( // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout) { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + kernel_t * const wdata = (kernel_t *) params->wdata + 0; for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02); - ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03; + const kernel_t * const src = (kernel_t *)((char *) src0->data + i03*nb03 + i02*nb02); + kernel_t * dst_data = wdata + i02*ne01*ne00*ne03; for (int64_t i01 = 0; i01 < ne01; i01++) { for (int64_t i00 = 0; i00 < ne00; i00++) { dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00]; @@ -6968,13 +6967,17 @@ void ggml_compute_forward_conv_transpose_2d( // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh) { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; + kernel_t * const wdata = (kernel_t *) params->wdata + nk; for (int i12 = 0; i12 < ne12; i12++) { for (int i11 = 0; i11 < ne11; i11++) { const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11); - ggml_fp16_t * dst_data = wdata + i11*ne10*ne12; + kernel_t * dst_data = wdata + i11*ne10*ne12; for (int i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]); + if constexpr (std::is_same_v) { + dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]); + } else { + dst_data[i10*ne12 + i12] = src[i10]; + } } } } @@ -6996,21 +6999,27 @@ void ggml_compute_forward_conv_transpose_2d( const int ip0 = dp*ith; const int ip1 = MIN(ip0 + dp, np); - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - ggml_fp16_t * const wdata_src = wdata + nk; + kernel_t * const wdata = (kernel_t *) params->wdata + 0; + kernel_t * const wdata_src = wdata + nk; for (int i2 = ip0; i2 < ip1; i2++) { // Cout float * dst_data = (float *)((char *) dst->data + i2*nb2); - ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; + kernel_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; for (int i11 = 0; i11 < ne11; i11++) { for (int i10 = 0; i10 < ne10; i10++) { const int i1n = i11*ne10*ne12 + i10*ne12; for (int i01 = 0; i01 < ne01; i01++) { for (int i00 = 0; i00 < ne00; i00++) { float v = 0; - ggml_vec_dot_f16(ne03, &v, 0, - wdata_src + i1n, 0, - wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); + if constexpr (std::is_same_v) { + ggml_vec_dot_f16(ne03, &v, 0, + wdata_src + i1n, 0, + wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); + } else { + ggml_vec_dot_f32(ne03, &v, 0, + wdata_src + i1n, 0, + wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); + } dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v; } } @@ -7019,6 +7028,28 @@ void ggml_compute_forward_conv_transpose_2d( } } +void ggml_compute_forward_conv_transpose_2d( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_transpose_2d_impl(params, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_transpose_2d_impl(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_conv_2d_dw struct ggml_conv_2d_dw_params { diff --git a/ggml/src/ggml-cuda/conv2d-transpose.cu b/ggml/src/ggml-cuda/conv2d-transpose.cu index 03224e404d3..6cbd6f879e6 100644 --- a/ggml/src/ggml-cuda/conv2d-transpose.cu +++ b/ggml/src/ggml-cuda/conv2d-transpose.cu @@ -1,12 +1,20 @@ -#include - #include "conv2d-transpose.cuh" -#include "ggml.h" - -__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel, - float * __restrict__ output, const int in_w, const int in_h, const int out_w, - const int out_h, const int kernel_w, const int kernel_h, const int stride, - const int c_in, const int c_out, const int batches) { +#include "convert.cuh" + +template +static __global__ void conv2d_transpose_kernel(const float * __restrict__ input, + const kernel_t * __restrict__ kernel, + float * __restrict__ output, + const int in_w, + const int in_h, + const int out_w, + const int out_h, + const int kernel_w, + const int kernel_h, + const int stride, + const int c_in, + const int c_out, + const int batches) { const int global_idx = blockIdx.x * blockDim.x + threadIdx.x; const int total_elements = out_w * out_h * c_out * batches; @@ -26,24 +34,32 @@ __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) { for (int kh = 0; kh < kernel_h; ++kh) { int in_y = out_y_idx - kh; - if (in_y < 0 || in_y % stride) continue; + if (in_y < 0 || in_y % stride) { + continue; + } in_y /= stride; - if (in_y >= in_h) continue; + if (in_y >= in_h) { + continue; + } for (int kw = 0; kw < kernel_w; ++kw) { int in_x = out_x_idx - kw; - if (in_x < 0 || in_x % stride) continue; + if (in_x < 0 || in_x % stride) { + continue; + } in_x /= stride; - if (in_x >= in_w) continue; + if (in_x >= in_w) { + continue; + } const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x; const int kernel_idx = (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw; - float input_val = input[input_idx]; - half kern_val = kernel[kernel_idx]; + float input_val = input[input_idx]; + kernel_t kern_val = kernel[kernel_idx]; - accumulator += input_val * (float) kern_val; + accumulator += input_val * ggml_cuda_cast(kern_val); } } } @@ -56,11 +72,12 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor const ggml_tensor * kernel = dst->src[0]; const ggml_tensor * input = dst->src[1]; - GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); + GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32); + GGML_ASSERT(input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); const float * input_data = (const float *) input->data; float * output_data = (float *) dst->data; - const half * kernel_data = (const half *) kernel->data; + const void * kernel_data = kernel->data; const int input_w = input->ne[0]; const int input_h = input->ne[1]; @@ -82,10 +99,17 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(ggml_is_contiguous(kernel)); GGML_ASSERT(ggml_is_contiguous(dst)); - const int total = (output_w * output_h * channels_out * batches); + const int total = output_w * output_h * channels_out * batches; const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE; - conv2d_transpose_kernel<<>>( - input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride, - channels_in, channels_out, batches); + if (kernel->type == GGML_TYPE_F16) { + conv2d_transpose_kernel<<>>( + input_data, (const half *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, + kernel_h, stride, channels_in, channels_out, batches); + + } else { + conv2d_transpose_kernel<<>>( + input_data, (const float *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, + kernel_h, stride, channels_in, channels_out, batches); + } } diff --git a/ggml/src/ggml-cuda/conv2d-transpose.cuh b/ggml/src/ggml-cuda/conv2d-transpose.cuh index c9430b24850..72889c5f0fa 100644 --- a/ggml/src/ggml-cuda/conv2d-transpose.cuh +++ b/ggml/src/ggml-cuda/conv2d-transpose.cuh @@ -1,4 +1,5 @@ #include "common.cuh" #define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256 + void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst); From eb747f3def7907b11b661a8af1450dd508fd6c9d Mon Sep 17 00:00:00 2001 From: Michael Wand Date: Thu, 26 Mar 2026 01:54:03 -0700 Subject: [PATCH 349/831] ggml-cuda: Add NVFP4 dp4a kernel (llama/20644) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added check for dst_t to cuda_cast template for float Restored ggml_cuda_ue4m3_to_fp32, changed vecdot ints to int32ts Added CUDART/HIP Check and HIP/fp8 include Added NVFP4 to Test-backend-ops Added hip_fp8_e4m3 to __nv_fp8_e4m3 typedef --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/common.cuh | 17 ++++++++++++ ggml/src/ggml-cuda/convert.cu | 43 +++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/ggml-cuda.cu | 10 ++++++- ggml/src/ggml-cuda/mmvq.cu | 8 ++++++ ggml/src/ggml-cuda/vecdotq.cuh | 32 +++++++++++++++++++++++ ggml/src/ggml-cuda/vendors/cuda.h | 5 ++-- ggml/src/ggml-cuda/vendors/hip.h | 6 +++++ 7 files changed, 118 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 36d8a3aaab2..9f93c70d21d 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -799,6 +799,16 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { #endif // CUDART_VERSION >= 12050 } +static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) { +#ifdef FP8_AVAILABLE + const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. + const __nv_fp8_e4m3 xf = *reinterpret_cast(&bits); + return static_cast(xf) / 2; +#else + NO_DEVICE_CODE; +#endif // FP8_AVAILABLE +} + __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) { const uint8_t sign_bit = (x < 0.0f) << 3; float ax = fabsf(x) * e; @@ -931,6 +941,13 @@ struct ggml_cuda_type_traits { static constexpr int qi = QI_MXFP4; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK_NVFP4; + static constexpr int qr = QR_NVFP4; + static constexpr int qi = QI_NVFP4; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK_K; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index b70492c7d6c..79ccfe568a2 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -617,6 +617,45 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_mxfp4<<>>(vx, y); } +template +static __global__ void dequantize_block_nvfp4( + const void * __restrict__ vx, + dst_t * __restrict__ yy, + const int64_t ne) { + const int64_t i = blockIdx.x; + const int tid = threadIdx.x; + + const int64_t base = i * QK_NVFP4; + if (base >= ne) { + return; + } + + const block_nvfp4 * x = (const block_nvfp4 *) vx; + const block_nvfp4 & xb = x[i]; + + const int sub = tid / (QK_NVFP4_SUB / 2); + const int j = tid % (QK_NVFP4_SUB / 2); + + const float d = ggml_cuda_ue4m3_to_fp32(xb.d[sub]); + const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j]; + + const int64_t y0 = base + sub * QK_NVFP4_SUB + j; + const int64_t y1 = y0 + QK_NVFP4_SUB / 2; + + yy[y0] = ggml_cuda_cast(d * kvalues_mxfp4[q & 0x0F]); + yy[y1] = ggml_cuda_cast(d * kvalues_mxfp4[q >> 4]); +} + +template +static void dequantize_row_nvfp4_cuda( + const void * vx, + dst_t * y, + const int64_t k, + cudaStream_t stream) { + GGML_ASSERT(k % QK_NVFP4 == 0); + const int nb = k / QK_NVFP4; + dequantize_block_nvfp4<<>>(vx, y, k); +} template static __global__ void convert_unary( const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, @@ -715,6 +754,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_cuda; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_cuda; case GGML_TYPE_F32: return convert_unary_cont_cuda; case GGML_TYPE_BF16: @@ -766,6 +807,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_cuda; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_cuda; case GGML_TYPE_F16: return convert_unary_cont_cuda; case GGML_TYPE_BF16: diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index a31e843e153..cc80eb3ffc2 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1297,7 +1297,12 @@ static void ggml_cuda_op_mul_mat_cublas( const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2); - const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; + const bool use_fp16 = + src0->type != GGML_TYPE_NVFP4 && + (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && + ggml_is_contiguous(src0) && + row_diff == src0->ne[1] && + dst->op_params[0] == GGML_PREC_DEFAULT; if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { ggml_cuda_pool_alloc src1_as_bf16(ctx.pool(id)); @@ -4781,6 +4786,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: +#ifdef FP8_AVAILABLE + case GGML_TYPE_NVFP4: +#endif // FP8_AVAILABLE case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 024b3d8cf22..66bd8beeae7 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -15,6 +15,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1; case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1; case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1; + case GGML_TYPE_NVFP4: return vec_dot_nvfp4_q8_1; case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1; case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1; case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1; @@ -41,6 +42,7 @@ static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ; case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ; case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ; + case GGML_TYPE_NVFP4: return VDR_NVFP4_Q8_1_MMVQ; case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ; case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ; case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ; @@ -626,6 +628,12 @@ static void mul_mat_vec_q_switch_type( nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; + case GGML_TYPE_NVFP4: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; case GGML_TYPE_Q2_K: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index ab803aca21b..40b2b41e7e8 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -322,6 +322,38 @@ static __device__ __forceinline__ float vec_dot_mxfp4_q8_1( return d * sumi; } +#define VDR_NVFP4_Q8_1_MMVQ 4 +#define VDR_NVFP4_Q8_1_MMQ 8 + +static __device__ __forceinline__ float vec_dot_nvfp4_q8_1( + const void * __restrict__ vbq, + const block_q8_1 * __restrict__ bq8_1, + const int32_t & kbx, + const int32_t & iqs) { + + const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq + kbx; + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) { + const int32_t iqs0 = iqs + 2*i; + const int32_t iqs1 = iqs0 + 1; + const int32_t is = iqs0 >> 1; + const int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4); + const int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4); + const block_q8_1 * bq8 = bq8_1 + (is >> 1); + const int32_t i8 = ((is & 1) << 2); + + int sumi = ggml_cuda_dp4a(v0.x, get_int_b4(bq8->qs, i8 + 0), 0); + sumi = ggml_cuda_dp4a(v0.y, get_int_b4(bq8->qs, i8 + 2), sumi); + sumi = ggml_cuda_dp4a(v1.x, get_int_b4(bq8->qs, i8 + 1), sumi); + sumi = ggml_cuda_dp4a(v1.y, get_int_b4(bq8->qs, i8 + 3), sumi); + + const float d = ggml_cuda_ue4m3_to_fp32(bq4->d[is]) * __low2float(bq8->ds); + sum += d * float(sumi); + } + + return sum; +} #define VDR_Q2_K_Q8_1_MMVQ 1 #define VDR_Q2_K_Q8_1_MMQ 4 diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index ba032cfab4b..07bc47df3b8 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -6,9 +6,10 @@ #include #include -#if CUDART_VERSION >= 12050 +#if CUDART_VERSION >= 11080 #include -#endif // CUDART_VERSION >= 12050 +#define FP8_AVAILABLE +#endif // CUDART_VERSION >= 11080 #if CUDART_VERSION >= 12080 #include diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 35d1e1a0639..9d9ba1ee219 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -235,6 +235,12 @@ typedef __hip_bfloat16 nv_bfloat16; typedef __hip_bfloat162 nv_bfloat162; +#if HIP_VERSION >= 60200000 +#include +typedef __hip_fp8_e4m3 __nv_fp8_e4m3; +#define FP8_AVAILABLE +#endif // HIP_VERSION >= 60200000 + typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); static __device__ __forceinline__ int __vsubss4(const int a, const int b) { From 07237ff99e479c364d778c1f4d1cad8729305dca Mon Sep 17 00:00:00 2001 From: ihb2032 <40718643+ihb2032@users.noreply.github.com> Date: Thu, 26 Mar 2026 19:08:41 +0800 Subject: [PATCH 350/831] fix(ggml): correct RISC-V ISA string canonical ordering for RVV in CMake (llama/20888) Signed-off-by: ihb2032 --- ggml/src/ggml-cpu/CMakeLists.txt | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 1a1bbc9f2be..beebc4760d2 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -460,6 +460,10 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() if(NOT GGML_CPU_ALL_VARIANTS) set(MARCH_STR "rv64gc") + if (GGML_RVV) + string(APPEND MARCH_STR "v") + endif() + if (GGML_RV_ZFH) string(APPEND MARCH_STR "_zfh") endif() @@ -467,7 +471,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_XTHEADVECTOR) string(APPEND MARCH_STR "_xtheadvector") elseif (GGML_RVV) - string(APPEND MARCH_STR "_v") if (GGML_RV_ZVFH) string(APPEND MARCH_STR "_zvfh") endif() @@ -475,12 +478,14 @@ function(ggml_add_cpu_backend_variant_impl tag_name) string(APPEND MARCH_STR "_zvfbfwma") endif() endif() + if (GGML_RV_ZICBOP) string(APPEND MARCH_STR "_zicbop") endif() if (GGML_RV_ZIHINTPAUSE) string(APPEND MARCH_STR "_zihintpause") endif() + list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) else() # Begin with the lowest baseline From 1848f994e324840cd9e1b67f9b2685868546debe Mon Sep 17 00:00:00 2001 From: lhez Date: Thu, 26 Mar 2026 08:52:21 -0700 Subject: [PATCH 351/831] opencl: allow large buffer for adreno (llama/20997) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 4dddcd82cfa..c40e1f2d391 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -394,6 +394,9 @@ struct ggml_backend_opencl_context { bool fp16_support; bool has_vector_subgroup_broadcast; bool disable_fusion; + + bool adreno_has_large_buffer; + bool adreno_use_large_buffer; ggml_cl_compiler_version adreno_cl_compiler_version; int adreno_wave_size; @@ -787,6 +790,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -cl-mad-enable -cl-unsafe-math-optimizations" " -cl-finite-math-only -cl-fast-relaxed-math"; + if (backend_ctx->adreno_use_large_buffer) { + compile_opts += " -qcom-enable-large-buffer "; + } + GGML_LOG_INFO("ggml_opencl: loading OpenCL kernels"); // add @@ -3020,6 +3027,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { // Check if ext_buffer contains cl_khr_fp16 backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); + // check Adreno large buffer support + backend_ctx->adreno_has_large_buffer = strstr(ext_buffer, "cl_qcom_large_buffer") != NULL; // fp16 is required if (!backend_ctx->fp16_support) { @@ -3086,6 +3095,18 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); #endif // GGML_OPENCL_USE_ADRENO_KERNELS + // determine whether to use large buffer for Adreno + backend_ctx->adreno_use_large_buffer = getenv("GGML_OPENCL_ADRENO_USE_LARGE_BUFFER") != nullptr && + backend_ctx->gpu_family == GPU_FAMILY::ADRENO; + if (backend_ctx->adreno_use_large_buffer) { + if (!backend_ctx->adreno_has_large_buffer) { + GGML_LOG_INFO("ggml_opencl: Adreno large buffer requested but not supported by driver, will use regular buffer\n"); + backend_ctx->adreno_use_large_buffer = false; + } else { + GGML_LOG_INFO("ggml_opencl: Adreno large buffer enabled\n"); + } + } + cl_int err; // A local ref of cl_context for convenience @@ -5660,6 +5681,11 @@ static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_b cl_int err; cl_mem mem = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, size, NULL, &err); + if (err != CL_SUCCESS && backend_ctx->adreno_use_large_buffer) { + cl_mem_properties props[] = { 0x41A6 /* CL_LARGE_BUFFER_QCOM */, 1, 0 }; + mem = clCreateBufferWithProperties(backend_ctx->context, props, CL_MEM_READ_WRITE, size, NULL, &err); + } + if (err != CL_SUCCESS) { GGML_LOG_INFO("%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0); return nullptr; From 45a708343104837e1bc74a983b94d1101fc6e13a Mon Sep 17 00:00:00 2001 From: uvos Date: Thu, 26 Mar 2026 23:06:33 +0100 Subject: [PATCH 352/831] hip: use fnuz fp8 for conversion on CDNA3 (llama/21040) --- ggml/src/ggml-cuda/common.cuh | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 9f93c70d21d..7d7f20af3a0 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -802,7 +802,13 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) { #ifdef FP8_AVAILABLE const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. +#if defined(GGML_USE_HIP) && defined(CDNA3) + // ROCm dose not support fp8 in software on devices with fp8 hardware, + // but CDNA3 supports only e4m3_fnuz (no inf). + const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast(&bits); +#else const __nv_fp8_e4m3 xf = *reinterpret_cast(&bits); +#endif // defined(GGML_USE_HIP) && defined(GGML_USE_HIP) return static_cast(xf) / 2; #else NO_DEVICE_CODE; From b564a99ed63abdf646a30f00b54bd5557238e900 Mon Sep 17 00:00:00 2001 From: ren <189031187+lathrys-at@users.noreply.github.com> Date: Fri, 27 Mar 2026 00:05:21 -0700 Subject: [PATCH 353/831] metal : Fix dimension constraint violation in matmul2d descriptor (llama/21048) Updates Metal tensor API test probe to fix the dimension constraint violation in the matmul2d descriptor (at least one value must be a multiple of 16). --- ggml/src/ggml-metal/ggml-metal-device.m | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index cbef2fb4879..17d51b11b6e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -690,7 +690,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { " auto tB = B.slice((int)tgid.x, 0); \n" " \n" " matmul2d< \n" - " matmul2d_descriptor(8, 8, dynamic_extent), \n" + " matmul2d_descriptor(16, 16, dynamic_extent), \n" " execution_simdgroups<4>> mm; \n" " \n" " auto cT = mm.get_destination_cooperative_tensor(); \n" @@ -740,7 +740,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { " auto tB = B.slice((int)tgid.x, 0); \n" " \n" " matmul2d< \n" - " matmul2d_descriptor(8, 8, dynamic_extent), \n" + " matmul2d_descriptor(16, 16, dynamic_extent), \n" " execution_simdgroups<4>> mm; \n" " \n" " auto cT = mm.get_destination_cooperative_tensor(); \n" From 7f466e237b02974aed2c1c49b13b3c847c1fa55b Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Fri, 27 Mar 2026 10:59:35 +0200 Subject: [PATCH 354/831] rpc : proper handling of data pointers to CPU buffers (llama/21030) The compute graph may contain tensors pointing to CPU buffers. In these cases the buffer address is serialized as 0 and sent over the wire. However, the data pointer is serialized as-is and this prevents proper validation on the server side. This patches fixes this by serializing the data pointer as 0 for non-RPC buffers and doing proper validation on the server side. closes: #21006 --- ggml/src/ggml-rpc/ggml-rpc.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 0ed2c0dce60..16f6abdffd6 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -589,8 +589,10 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { ggml_backend_buffer_t buffer = tensor->buffer; ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; result.buffer = ctx != nullptr ? ctx->remote_ptr : 0; + result.data = reinterpret_cast(tensor->data); } else { result.buffer = 0; + result.data = 0; } for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { result.ne[i] = tensor->ne[i]; @@ -606,7 +608,6 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { } result.view_src = reinterpret_cast(tensor->view_src); result.view_offs = tensor->view_offs; - result.data = reinterpret_cast(tensor->data); // Avoid sending uninitialized data over the wire memset(result.name, 0, sizeof(result.name)); @@ -1443,9 +1444,11 @@ ggml_tensor * rpc_server::create_node(uint64_t id, const rpc_tensor * tensor = it_ptr->second; struct ggml_tensor * result = deserialize_tensor(ctx, tensor); - if (result == nullptr || result->buffer == nullptr) { - GGML_LOG_ERROR("[%s] invalid tensor: null %s (id=%" PRIu64 ")\n", - __func__, result == nullptr ? "tensor" : "buffer", id); + if (result == nullptr) { + return nullptr; + } + if (result->buffer == nullptr && result->data != nullptr) { + GGML_LOG_ERROR("[%s] invalid data ptr", __func__); return nullptr; } tensor_map[id] = result; From 52699f6d193058353e0832caaf8c055c705e72a6 Mon Sep 17 00:00:00 2001 From: Yiwei Shao <44545837+njsyw1997@users.noreply.github.com> Date: Fri, 27 Mar 2026 09:22:41 -0700 Subject: [PATCH 355/831] hexagon: support for IQ4_NL and MXFP4 (llama/21018) * ggml-hexagon: add IQ4_NL and MXFP4 HMX matmul support - Add IQ4_NL quantization type support to Hexagon backend (buffer set/get tensor repack, mul_mat, mul_mat_id dispatch) - Implement HVX IQ4_NL vec_dot kernels (1x1, 2x1, 2x2) with LUT-based 4-bit index to int8 kvalue dequantization - Add MXFP4 HMX dequantization path with E8M0 scale conversion, including batch-4 fast path and single-tile fallback - Unify quantized row size / scale offset logic to handle Q4_0, Q8_0, IQ4_NL, and MXFP4 in the DMA fetch path * ggml-hexagon: fix SKIP_QUANTIZE src1 address mismatch in mixed-quant models * Fix the pragma indent --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 37 +- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 209 +++++++++++- ggml/src/ggml-hexagon/htp/htp-ctx.h | 6 + ggml/src/ggml-hexagon/htp/main.c | 10 +- ggml/src/ggml-hexagon/htp/matmul-ops.c | 380 +++++++++++++++++++++ 5 files changed, 619 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 9c1ce93cc69..dd604db4333 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1406,6 +1406,13 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, repack_q8_0_q8x4x2(tensor, data, size); break; + case GGML_TYPE_IQ4_NL: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + // IQ4_NL has identical block layout to Q4_0 (ggml_half d + uint8_t qs[16]) + repack_q4_0_q4x4x2(tensor, data, size); + break; + case GGML_TYPE_MXFP4: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1442,6 +1449,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, repack_q8x4x2_q8_0(data, tensor, size); break; + case GGML_TYPE_IQ4_NL: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4x4x2_q4_0(data, tensor, size); + break; + case GGML_TYPE_MXFP4: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1819,6 +1832,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s switch (src0->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: if (src0->ne[0] % 32) { return false; @@ -1868,6 +1882,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session switch (src0->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: if ((src0->ne[0] % 32)) { return false; @@ -2596,8 +2611,26 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) { delete backend; } +// Map weight type to its activation quantization family. +// Types in the same family produce identical Q8 formats in VTCM and can +// safely share quantized activation data via SKIP_QUANTIZE. +// When adding a new quantized type, assign it the correct family here. +static inline int act_quant_family(enum ggml_type wtype) { + switch (wtype) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + return 1; // Q8x4x2 + default: + return 0; // unknown / not quantized + } +} + static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) { - return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type)); + return (op0 && op0->src[1] == op1->src[1] && + act_quant_family(op0->src[0]->type) == act_quant_family(op1->src[0]->type) && + act_quant_family(op0->src[0]->type) != 0); } static inline bool is_compute_op(ggml_tensor *node) @@ -3364,6 +3397,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4, "please update hexagon_type to match ggml_type"); + static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL, + "please update hexagon_type to match ggml_type"); const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL"); const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index a56356bee9f..4ff2b36de96 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -30,6 +30,12 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, }; +// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value +// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 +static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + 0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0, +}; + static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0, 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, @@ -46,7 +52,8 @@ static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned // Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes #define HMX_X4X2_SCALES_PER_BLK 8 -#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes +#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL) +#define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4) static inline void swap_ptr(void **p1, void **p2) { void *t = *p1; @@ -78,9 +85,11 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) { switch (weight_type) { case HTP_TYPE_Q4_0: case HTP_TYPE_IQ4_NL: - return (size_t)nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb + return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb case HTP_TYPE_Q8_0: - return (size_t)nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb + return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb + case HTP_TYPE_MXFP4: + return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb default: return 0; } @@ -284,6 +293,87 @@ static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx( return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); } +// --- MXFP4 E8M0 scale conversion and dequantization --- +// +// HVX batch-convert 8 E8M0 bytes (one x4x2 block's scales) to __fp16[8] on stack. +// Scalar loads from the stack array execute on the scalar pipeline, in parallel +// with HVX vlut16/vmpy/vscatter — freeing HVX slots in the hot loop. +// Arithmetic: fp16_bits = clamp(e - 112, 0, 30) << 10 +// e=0..112 -> 0 (underflow), e=113..142 -> valid fp16, e>=143 -> clamped to 2^15. + +typedef struct { + __fp16 v[8] __attribute__((aligned(16))); +} mxfp4_scales_t; + +static inline mxfp4_scales_t mxfp4_convert_scales(const uint8_t * e8m0_8) { + mxfp4_scales_t s; + HVX_Vector v = hvx_vmemu(e8m0_8); + HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v)); + vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112)); + vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero()); + vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30)); + vh = Q6_Vh_vasl_VhR(vh, 10); + hvx_vec_store_u(s.v, 16, vh); + return s; +} + +static inline HVX_Vector mxfp4_extract_splat(mxfp4_scales_t scales, int idx) { + return hvx_vec_splat_f16(scales.v[idx]); +} + +// Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16. +static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed_32, + bool upper_nibbles, + int sub_blk, + const HVX_Vector vlut_cvt, + mxfp4_scales_t scales) { + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + HVX_Vector v_sc = mxfp4_extract_splat(scales, sub_blk); + + v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_hf = Q6_V_lo_W(vp); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_sc)); +} + +// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes). +static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, + bool upper_nibbles, + int sub_blk_base, + const HVX_Vector vlut_cvt, + mxfp4_scales_t scales, + HVX_Vector out[4]) { + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + v_quants = Q6_Vb_vshuff_Vb(v_quants); + + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_lo = Q6_V_lo_W(vp); + HVX_Vector v_hi = Q6_V_hi_W(vp); + + HVX_VectorPred q64 = Q6_Q_vsetq_R(64); + HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 0), + mxfp4_extract_splat(scales, sub_blk_base + 1)); + HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 2), + mxfp4_extract_splat(scales, sub_blk_base + 3)); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); + + out[0] = v_lo; + out[1] = Q6_V_vror_VR(v_lo, 64); + out[2] = v_hi; + out[3] = Q6_V_vror_VR(v_hi, 64); +} + // Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. // Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes. // Output: vtcm_dst in tile-major FP16 layout. @@ -295,11 +385,11 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( int start_tile, int end_tile) { const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; - const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); - const int qrow_size = is_q4 ? (k_block / 2) : k_block; + const int qrow_size = (weight_type == HTP_TYPE_Q8_0) ? k_block : (k_block / 2); - const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) - ? hvx_vmem(iq4_nl_to_fp16_lut) : hvx_vmem(q4_0_to_fp16_lut); + const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) : + (weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) : + hvx_vmem(q4_0_to_fp16_lut); // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. // Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128 @@ -312,8 +402,9 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( int ct = t / n_k_tiles; // column tile index int kt = t % n_k_tiles; // K tile index - // --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row --- - if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { + // --- Batch-4 fast path for Q4_0/IQ4_NL: process 4 contiguous K-tiles with one vlut16 per row --- + if ((weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) && (kt % 4 == 0) && (t + 4 <= end_tile) && + ((t + 3) / n_k_tiles == ct)) { int blk_idx = (kt * 32) / QK_Q4_0x4x2; int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 bool upper = (sub_blk_base >= 4); @@ -351,10 +442,60 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( continue; } + // --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales --- + if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { + int blk_idx = (kt * 32) / QK_MXFP4x4x2; + int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4 + bool upper = (sub_blk_base >= 4); + int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes + int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales + + __fp16 * tile_bases[4]; + for (int g = 0; g < 4; g++) { + tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; + } + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + const uint8_t * r0 = vtcm_src + row0 * row_stride; + const uint8_t * r1 = vtcm_src + row1 * row_stride; + + // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) + mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); + + HVX_Vector v0[4], v1[4]; + dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0); + if (row1 < n_cols) { + mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); + dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1); + } else { + v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero(); + } + + for (int g = 0; g < 4; g++) { + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); + } + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + for (int g = 0; g < 4; g++) { + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); + } + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + + for (int g = 0; g < 4; g++) { + (void) *(volatile HVX_Vector *) (tile_bases[g]); + } + + t += 4; + continue; + } + // --- Single-tile fallback --- __fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS; - if (is_q4) { + if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) { int blk_idx = (kt * 32) / QK_Q4_0x4x2; int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; bool upper = (sub_blk >= 4); @@ -382,6 +523,39 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } (void) *(volatile HVX_Vector *)(tile_base); + } else if (weight_type == HTP_TYPE_MXFP4) { + int blk_idx = (kt * 32) / QK_MXFP4x4x2; + int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32; + bool upper = (sub_blk >= 4); + int byte_off = blk_idx * (QK_MXFP4x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; + int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t * r0 = vtcm_src + row0 * row_stride; + const uint8_t * r1 = vtcm_src + row1 * row_stride; + + // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) + mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); + + HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8); + HVX_Vector v1; + if (row1 < n_cols) { + mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); + v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8); + } else { + v1 = Q6_V_vzero(); + } + + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *) (tile_base); } else { // Q8_0 int blk_idx = (kt * 32) / QK_Q8_0x4x2; @@ -1455,21 +1629,24 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict { qweight_fetch_task_state_t s; - const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); const int blk_start = kk / QK_Q4_0x4x2; const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; - const int full_qrow = is_q4 ? (k / 2) : k; + const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); + const int scale_blk_size = + (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; s.dst = vtcm_scratch0; s.src = w + nc * row_stride; s.n_rows = n_blk_sz; s.src_stride = row_stride; s.dst_stride = sub_row_stride; - s.quant_off = is_q4 ? (blk_start * (QK_Q4_0x4x2 / 2)) : (blk_start * QK_Q8_0x4x2); - s.quant_width = is_q4 ? (nb_sub * (QK_Q4_0x4x2 / 2)) : (nb_sub * QK_Q8_0x4x2); - s.scale_off = full_qrow + blk_start * HMX_X4X2_DBLK_SIZE; - s.scale_width = nb_sub * HMX_X4X2_DBLK_SIZE; + s.quant_off = + (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2)); + s.quant_width = + (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2)); + s.scale_off = full_qrow + blk_start * scale_blk_size; + s.scale_width = nb_sub * scale_blk_size; // 2D DMA: quants sub-range dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off), diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index a92acfa0a85..6f1917fa2cb 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -31,6 +31,12 @@ struct htp_context { uint32_t opmask; + // Cached src1 spad position from the last quantize pass. + // When SKIP_QUANTIZE is set the Q8 activation data is already in VTCM + // at this address; the matmul must read from here instead of recomputing + // the offset (which depends on the current op's src0 size). + uint8_t * prev_src1_spad; + // HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX) #ifdef HTP_HAS_HMX int hmx_enabled; // Runtime flag: HMX initialisation succeeded diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 70ba9f9f4fe..49f34b5f7d1 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -1114,14 +1114,12 @@ static void proc_hmx_matmul_req(struct htp_context * ctx, return; } - // HMX only supports F16, Q4_0, Q8_0, IQ4_NL weights. - // Other types (e.g. MXFP4) fall back to HVX. + // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. + // Other types fall back to HVX. { uint32_t wtype = req->src0.type; - if (wtype != HTP_TYPE_F16 && - wtype != HTP_TYPE_Q4_0 && - wtype != HTP_TYPE_Q8_0 && - wtype != HTP_TYPE_IQ4_NL) { + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && + wtype != HTP_TYPE_MXFP4) { proc_matmul_req(ctx, req, bufs, n_bufs); return; } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 73aaba79ebf..24b7bad6876 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -60,6 +60,16 @@ static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = { 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20, }; +// IQ4_NL dequantization LUT: maps 4-bit index (0-15) to int8 kvalue +// kvalues: -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 +static const uint8_t __attribute__((aligned(VLEN))) kvalues_iq4nl_lut[] = { + 0x81, 0, 0x98, 0, 0xAD, 0, 0xBF, 0, 0xCF, 0, 0xDD, 0, 0xEA, 0, 0xF6, 0, 0x01, 0, 0x0D, 0, 0x19, 0, 0x26, 0, + 0x35, 0, 0x45, 0, 0x59, 0, 0x71, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; + static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = { 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0, 0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -68,6 +78,73 @@ static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }; +static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_full(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) + HVX_Vector v2_3 = vptr[1]; // ... + HVX_Vector v4_5 = vptr[2]; // ... + HVX_Vector v6_7 = vptr[3]; // ... + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F + HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 + HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F + HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 + HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F + HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 + + v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); + v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); + v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); + v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); + v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); + v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... + r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0); + r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0); + } + + return r; +} + // q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales static inline size_t q8x4x2_row_size(uint32_t ne) { @@ -921,6 +998,293 @@ static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } +// ======== IQ4_NL x Q8_0 vec_dot kernels ======== +// Same structure as Q4_0 vec_dot but uses IQ4_NL LUT-based load (4-bit index -> int8 kvalue). +// Scale format is identical to Q4_0 (fp16 scales). + +static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n, + float * restrict s0, + const void * restrict vx0, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + HVX_Vector r0_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + + hvx_vec_store_u(s0, 4, r0_sum); +} + +static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, + float * restrict s0, + const void * restrict vx0, + const void * restrict vx1, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n, + float * restrict s0, + float * restrict s1, + const void * restrict vx0, + const void * restrict vx1, + const void * restrict vy0, + const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; + + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); +} + static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size assert((unsigned long) vx0 % 128 == 0); @@ -2393,6 +2757,12 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; return 0; + case HTP_TYPE_IQ4_NL: + mmctx->type = "iq4nlx4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2; + return 0; case HTP_TYPE_MXFP4: mmctx->type = "mxfp4x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; @@ -2556,6 +2926,13 @@ int op_matmul(struct htp_ops_context * octx) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + // Cache where src1 was written so subsequent SKIP_QUANTIZE ops can find it + octx->ctx->prev_src1_spad = octx->src1_spad.data; + } else { + // SKIP_QUANTIZE: Q8 data lives at the address written by the previous + // quantize pass. The current op may have a different src0 size (e.g. + // IQ4_NL vs MXFP4), so src1_spad.data computed above could be wrong. + octx->src1_spad.data = octx->ctx->prev_src1_spad; } if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { @@ -2659,6 +3036,9 @@ int op_matmul_id(struct htp_ops_context * octx) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + octx->ctx->prev_src1_spad = octx->src1_spad.data; + } else { + octx->src1_spad.data = octx->ctx->prev_src1_spad; } if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { From 759f0084b4172f891412fbfd22a1b20a4f25f2c1 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 28 Mar 2026 08:44:56 +0100 Subject: [PATCH 356/831] vulkan: add noncontiguous GLU support (llama/21081) * vulkan: add noncontiguous GLU support * fix compile issue --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 39 +++++++++++++------ .../ggml-vulkan/vulkan-shaders/glu_head.glsl | 10 +++++ .../ggml-vulkan/vulkan-shaders/glu_main.glsl | 22 ++++++++--- 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 221e6fa04e9..15ed5b2a79d 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1112,6 +1112,16 @@ struct vk_op_glu_push_constants { uint32_t mode; // 0: default, 1: swapped, 2: split float alpha; // for swiglu_oai float limit; + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t ne01; + uint32_t ne02; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t ne11; + uint32_t ne12; }; struct vk_op_unary_push_constants { @@ -5044,7 +5054,7 @@ static vk_device ggml_vk_get_device(size_t idx) { } else { device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); } - vk::DeviceCreateInfo device_create_info; + vk::DeviceCreateInfo device_create_info{}; std::vector device_extensions; vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures(); @@ -5413,12 +5423,10 @@ static vk_device ggml_vk_get_device(size_t idx) { #endif device->name = GGML_VK_NAME + std::to_string(idx); - device_create_info = { - vk::DeviceCreateFlags(), - device_queue_create_infos, - {}, - device_extensions - }; + device_create_info + .setFlags(vk::DeviceCreateFlags()) + .setQueueCreateInfos(device_queue_create_infos) + .setPEnabledExtensionNames(device_extensions); device_create_info.setPNext(&device_features2); device->device = device->physical_device.createDevice(device_create_info); @@ -11048,8 +11056,6 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const const float alpha = op_params_f[2]; const float limit = op_params_f[3]; - GGML_ASSERT(ggml_is_contiguous(src0)); - if (!split) { GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); } else { @@ -11067,7 +11073,17 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t)dst->ne[0], mode, alpha, - limit + limit, + (uint32_t)(src0->nb[1] / src0->nb[0]), + (uint32_t)(src0->nb[2] / src0->nb[0]), + (uint32_t)(src0->nb[3] / src0->nb[0]), + (uint32_t)src0->ne[1], + (uint32_t)src0->ne[2], + (uint32_t)(dst->nb[1] / dst->nb[0]), + (uint32_t)(dst->nb[2] / dst->nb[0]), + (uint32_t)(dst->nb[3] / dst->nb[0]), + (uint32_t)dst->ne[1], + (uint32_t)dst->ne[2] }); } @@ -15217,8 +15233,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: - return ggml_is_contiguous(op->src[0]) && - (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type); default: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl index 2168989340b..95298922d83 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl @@ -16,4 +16,14 @@ layout (push_constant) uniform parameter uint mode; float alpha; float limit; + uint nb01; + uint nb02; + uint nb03; + uint ne01; + uint ne02; + uint nb11; + uint nb12; + uint nb13; + uint ne11; + uint ne12; } p; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl index 85cf65a9eca..359461306a5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl @@ -8,22 +8,32 @@ void main() { const uint row = i / p.ne20; const uint col = i - row * p.ne20; + const uint i3 = row / (p.ne01 * p.ne02); + const uint i2 = (row % (p.ne01 * p.ne02)) / p.ne01; + const uint i1 = row % p.ne01; + const uint src_idx = i3 * p.nb03 + i2 * p.nb02 + i1 * p.nb01 + col; + + const uint dst_i3 = row / (p.ne11 * p.ne12); + const uint dst_i2 = (row % (p.ne11 * p.ne12)) / p.ne11; + const uint dst_i1 = row % p.ne11; + const uint dst_idx = dst_i3 * p.nb13 + dst_i2 * p.nb12 + dst_i1 * p.nb11 + col; + if (p.mode == 0) { // Default const uint offset = p.ne00 / 2; - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); } else if (p.mode == 1) { // Swapped const uint offset = p.ne00 / 2; - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); } else { // Split - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); } } From 95ea8f9bfb03a15db08a8989966fd1ae3361e20d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 29 Mar 2026 13:23:24 +0300 Subject: [PATCH 357/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 6557fb46cbe..58863dc6bbb 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -c044a8eeae2591faa0950c8b5e514cbc4bbfc4ca +404fcb9d7c96989569e68c9e7881ee3465a05c50 From 166c20b473d5f4d04052e699f992f625ea2a2fdd Mon Sep 17 00:00:00 2001 From: Daniel Worthington-Bodart Date: Fri, 17 Apr 2026 12:36:27 +0100 Subject: [PATCH 358/831] whisper : add stateless VAD detect + explicit state reset for streaming (#3677) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit whisper_vad_detect_speech resets LSTM state on every call, which is correct for batch processing but prevents temporal continuity when calling per-chunk in a streaming loop. Add whisper_vad_detect_speech_no_reset (skips buffer clear) and whisper_vad_reset_state (explicit clear between utterances). Existing whisper_vad_detect_speech is now a thin wrapper — zero behavior change for current callers. Co-authored-by: Claude Opus 4.6 (1M context) --- include/whisper.h | 10 ++++++++++ src/whisper.cpp | 17 +++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/include/whisper.h b/include/whisper.h index f4cc6bf7abd..b5dcdb2917a 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -695,6 +695,16 @@ extern "C" { const float * samples, int n_samples); + // Like whisper_vad_detect_speech, but does not reset LSTM state. + // Use for streaming: call whisper_vad_reset_state() between utterances. + WHISPER_API bool whisper_vad_detect_speech_no_reset( + struct whisper_vad_context * vctx, + const float * samples, + int n_samples); + + // Reset LSTM hidden/cell states to zero. + WHISPER_API void whisper_vad_reset_state(struct whisper_vad_context * vctx); + WHISPER_API int whisper_vad_n_probs(struct whisper_vad_context * vctx); WHISPER_API float * whisper_vad_probs (struct whisper_vad_context * vctx); diff --git a/src/whisper.cpp b/src/whisper.cpp index 86bfafeaad8..2f356da0f06 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -5083,7 +5083,11 @@ struct whisper_vad_context * whisper_vad_init_with_params( return vctx; } -bool whisper_vad_detect_speech( +void whisper_vad_reset_state(whisper_vad_context * vctx) { + ggml_backend_buffer_clear(vctx->buffer, 0); +} + +bool whisper_vad_detect_speech_no_reset( struct whisper_vad_context * vctx, const float * samples, int n_samples) { @@ -5095,9 +5099,6 @@ bool whisper_vad_detect_speech( WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples); WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks); - // Reset LSTM hidden/cell states - ggml_backend_buffer_clear(vctx->buffer, 0); - vctx->probs.resize(n_chunks); WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks); @@ -5165,6 +5166,14 @@ bool whisper_vad_detect_speech( return true; } +bool whisper_vad_detect_speech( + struct whisper_vad_context * vctx, + const float * samples, + int n_samples) { + whisper_vad_reset_state(vctx); + return whisper_vad_detect_speech_no_reset(vctx, samples, n_samples); +} + int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) { return segments->data.size(); } From fc674574ca27cac59a15e5b22a09b9d9ad62aafe Mon Sep 17 00:00:00 2001 From: jinweihan Date: Sun, 19 Apr 2026 22:12:57 -0700 Subject: [PATCH 359/831] bench : sync submit-results URL to ggml-org (#3769) The project moved from ggerganov/ to ggml-org/ and the README already references the new URL in both places it mentions issue #89 (README.md and examples/bench/README.md). Syncing the two remaining hardcoded URLs in examples/bench/bench.cpp and examples/bench.wasm/emscripten.cpp. The old URL still redirects, so this is cosmetic. --- examples/bench.wasm/emscripten.cpp | 2 +- examples/bench/bench.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/bench.wasm/emscripten.cpp b/examples/bench.wasm/emscripten.cpp index 083397db057..7e9f277f66e 100644 --- a/examples/bench.wasm/emscripten.cpp +++ b/examples/bench.wasm/emscripten.cpp @@ -45,7 +45,7 @@ void bench_main(size_t index) { fprintf(stderr, "\n"); fprintf(stderr, "If you wish, you can submit these results here:\n"); fprintf(stderr, "\n"); - fprintf(stderr, " https://github.com/ggerganov/whisper.cpp/issues/89\n"); + fprintf(stderr, " https://github.com/ggml-org/whisper.cpp/issues/89\n"); fprintf(stderr, "\n"); fprintf(stderr, "Please include the following information:\n"); fprintf(stderr, "\n"); diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index 049473d4f32..84915c56a8a 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -157,7 +157,7 @@ static int whisper_bench_full(const whisper_params & params) { fprintf(stderr, "\n"); fprintf(stderr, "If you wish, you can submit these results here:\n"); fprintf(stderr, "\n"); - fprintf(stderr, " https://github.com/ggerganov/whisper.cpp/issues/89\n"); + fprintf(stderr, " https://github.com/ggml-org/whisper.cpp/issues/89\n"); fprintf(stderr, "\n"); fprintf(stderr, "Please include the following information:\n"); fprintf(stderr, "\n"); From 763a4540521ae191c68e79397506b01e3d9c9d78 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 30 Mar 2026 18:34:29 +0300 Subject: [PATCH 360/831] ggml : bump version to 0.9.9 (ggml/1449) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index c780077acaa..a739cca4218 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 8) +set(GGML_VERSION_PATCH 9) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) From 9e96d390f7dc63544ebbdafe36902879c217104a Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Sun, 29 Mar 2026 06:40:13 -0700 Subject: [PATCH 361/831] hexagon: dma optimizations (mostly fixing regressions) (llama/21137) * hex-fa: add simple dma cache for Mask I noticed that we were refetch the mask rows over and over. This simple cache avoids that. * hex-dma: unset in-order desc bit which caused signficant perf regression We don't rely on true in order processing of the DMA descriptors anywhere. Turns out this mode caused significant regression of around 3-4 TPS during token gen. * hex-rope: update comment to clarify that we don't need in-order DMA completions --- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 12 ++-- ggml/src/ggml-hexagon/htp/hex-dma.h | 75 ++++++++++++++++++---- ggml/src/ggml-hexagon/htp/rope-ops.c | 4 +- 3 files changed, 74 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 6dc978dd68a..0c9bc785620 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -346,6 +346,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap); + dma_cache m_cache; + dma_cache_init(&m_cache, spad_m, factx->size_m_block, DMA_CACHE_MAX_SIZE); + for (uint32_t ir = ir0; ir < ir1; ++ir) { const uint32_t iq3 = fastdiv(ir, &factx->src0_div21); const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1); @@ -389,9 +392,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * // Mask if (mask) { const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start); - uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block; // Mask is 1D contiguous for this row - dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); + dma_cache_push(dma, &m_cache, m_src, current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); } // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", @@ -554,7 +556,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * // Mask if (mask) { const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start); - dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); + dma_cache_push(dma, &m_cache, m_src, next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); } // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", @@ -684,7 +686,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { octx->src0_spad.size_per_thread = size_q_block * 1; octx->src1_spad.size_per_thread = factx.size_k_block * 2; octx->src2_spad.size_per_thread = factx.size_v_block * 2; - octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0; + octx->src3_spad.size_per_thread = mask ? factx.size_m_block * DMA_CACHE_MAX_SIZE : 0; octx->dst_spad.size_per_thread = size_vkq_acc; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; @@ -705,6 +707,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; + // FARF(ERROR, "fa: qrows-per-thread %u", factx.qrows_per_thread); + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads); } diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h index ff166cbcc7a..7685473f463 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -143,7 +143,7 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t desc->desc_size = 0; // 1D mode desc->src_bypass = dma_src_l2_bypass_on; desc->dst_bypass = dma_dst_l2_bypass_on; - desc->order = 1; + desc->order = 0; desc->done = 0; desc->src = (void *) dptr.src; desc->dst = (void *) dptr.dst; @@ -151,8 +151,12 @@ static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t q->dptr[q->push_idx] = dptr; - dmlink(q->tail, desc); - q->tail = (dma_descriptor_2d *) desc; + if (size) { + dmlink(q->tail, desc); + q->tail = (dma_descriptor_2d *) desc; + } else { + desc->done = 1; + } // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src); q->push_idx = (q->push_idx + 1) & q->idx_mask; @@ -175,7 +179,7 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t desc->dst_bypass = dma_dst_l2_bypass_on; desc->src_comp = 0; desc->dst_comp = 0; - desc->order = 1; + desc->order = 0; desc->done = 0; desc->src_stride = src_stride; desc->dst_stride = dst_stride; @@ -197,8 +201,12 @@ static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t q->dptr[q->push_idx] = dptr; - dmlink(q->tail, desc); - q->tail = desc; + if (nrows) { + dmlink(q->tail, desc); + q->tail = desc; + } else { + desc->done = 1; + } // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src); q->push_idx = (q->push_idx + 1) & q->idx_mask; @@ -215,12 +223,9 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) { dma_descriptor_2d * desc = &q->desc[q->pop_idx]; // Wait for desc to complete - while (1) { - dmpoll(); - if (desc->done) { - break; - } + while (!desc->done) { // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx); + dmpoll(); } dptr = q->dptr[q->pop_idx]; @@ -312,6 +317,54 @@ static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_ return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows); } +#define DMA_CACHE_MAX_SIZE 64U + +typedef struct { + uint8_t *base; + uint32_t line_size; + uint32_t capacity; + uint32_t src[DMA_CACHE_MAX_SIZE]; + uint16_t age[DMA_CACHE_MAX_SIZE]; +} dma_cache; + +static inline void dma_cache_init(dma_cache *c, uint8_t *base, uint32_t line_size, uint32_t capacity) +{ + c->capacity = (capacity > DMA_CACHE_MAX_SIZE) ? DMA_CACHE_MAX_SIZE : capacity; + c->base = base; + c->line_size = line_size; + + for (unsigned i=0; i < c->capacity; i++) { + c->src[i] = 0; + c->age[i] = 0; + } +} + +static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * src, uint32_t dst_stride, uint32_t src_stride, uint32_t row_size, uint32_t nrows) +{ + uint32_t o_idx = 0; + uint16_t o_age = 0; + uint8_t * dst = 0; + + for (unsigned i=0; i < c->capacity; i++) { + if (c->src[i] == (uint32_t) src) { + c->age[i] = 0; + dst = c->base + (i * c->line_size); nrows = 0; // dummy dma + // FARF(ERROR, "dma-cache: found %p", src); + } else { + c->age[i]++; + if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; } + } + } + if (!dst) { + // FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src); + c->age[o_idx] = 0; + c->src[o_idx] = (uint32_t) src; + dst = c->base + o_idx * c->line_size; // normal nrows dma + } + + return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows); +} + #ifdef __cplusplus } // extern "C" #endif diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index be9469538f6..ecedadb0fea 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -333,8 +333,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); } - // Skip DMA transactions from prev block (if any) - // No need to wait for these since the DMA is setup for in-order processing + // Skip output DMA transactions from prev block (if any) + // No need to wait for those here since we're explicitly waiting for the latest prefecthes below. for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); } // Compute loop From 6b67c918797be49d4c9e67eda05efd490f8e123d Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sun, 29 Mar 2026 22:05:18 +0530 Subject: [PATCH 362/831] Optimize MOE GEMV kernel for BS > 1. (llama/20905) * Optimize MOE GEMV kernel for BS > 1. The previous MOE kernel for BS > 1 had too many thread blocks (nrows_x, nchannels_dst, ncols_dst), with very little work per block. block of (32, 4) was doing inner dot product for a single row. New mul_mat_vec_q_moe kernel is dedicated for MoE multi-token kernel with grid (ceil(nrows_x/rpb), nchannels_dst), block (warp_size, ncols_dst). Each warp handles two rows independently with warp-level reduction only (no shared memory sync). This change doesn't increase any compilation time as a single template instance is needed per type. This also simplifies the original GEMV kernel and gets rid of `is_multi_token_id` specialization. * Remove em-dashes * Cherry-pick changes from @am17an PR https://github.com/ggml-org/llama.cpp/pull/20885 to enable small_k optimization only for cases where it benefits Increase max batch size for MMVQ kernels for MUL_MAT_ID to 8 * Make the max batch size for MOE GEMV kernel configurable based on GPU arch and datatype --------- Co-authored-by: Aman Gupta --- ggml/src/ggml-cuda/ggml-cuda.cu | 19 +- ggml/src/ggml-cuda/mmvq.cu | 393 +++++++++++++++++++++++++++----- ggml/src/ggml-cuda/mmvq.cuh | 5 +- 3 files changed, 358 insertions(+), 59 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index cc80eb3ffc2..d1239b1c5f7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2343,7 +2343,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE); if (ne2 <= MMVQ_MAX_BATCH_SIZE) { if (ggml_is_quantized(src0->type)) { - if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) { + const int mmvq_mmid_max = get_mmvq_mmid_max_batch(src0->type, cc); + if (ne2 <= mmvq_mmid_max) { ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); return; } @@ -2946,14 +2947,18 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { } // [TAG_MUL_MAT_ID_CUDA_GRAPHS] - if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) { - // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs - // TODO: figure out a way to enable for larger batch sizes, without hurting performance - // ref: https://github.com/ggml-org/llama.cpp/pull/18958 - use_cuda_graph = false; + if (node->op == GGML_OP_MUL_MAT_ID) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc); + if (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max) { + // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs + // TODO: figure out a way to enable for larger batch sizes, without hurting performance + // ref: https://github.com/ggml-org/llama.cpp/pull/18958 + use_cuda_graph = false; #ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); #endif + } } if (!use_cuda_graph) { diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 66bd8beeae7..8d80d1dd9a7 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -97,6 +97,194 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { return MMVQ_PARAMETERS_GENERIC; } +// Per-architecture maximum batch size for which MMVQ should be used for MUL_MAT_ID. +// Returns a value <= MMVQ_MAX_BATCH_SIZE. Default is MMVQ_MAX_BATCH_SIZE. +// Check https://github.com/ggml-org/llama.cpp/pull/20905#issuecomment-4145835627 for details + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 6; + case GGML_TYPE_IQ1_M: return 6; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 5; + case GGML_TYPE_IQ2_XXS: return 5; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 5; + case GGML_TYPE_MXFP4: return 4; + case GGML_TYPE_Q2_K: return 4; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 6; + case GGML_TYPE_Q4_1: return 6; + case GGML_TYPE_Q4_K: return 5; + case GGML_TYPE_Q5_0: return 6; + case GGML_TYPE_Q5_1: return 6; + case GGML_TYPE_Q5_K: return 5; + case GGML_TYPE_Q6_K: return 4; + case GGML_TYPE_Q8_0: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 7; + case GGML_TYPE_IQ3_S: return 6; + case GGML_TYPE_IQ3_XXS: return 7; + case GGML_TYPE_MXFP4: return 7; + case GGML_TYPE_Q2_K: return 7; + case GGML_TYPE_Q3_K: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_gcn(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 5; + case GGML_TYPE_IQ1_M: return 5; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 4; + case GGML_TYPE_Q2_K: return 4; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 5; + case GGML_TYPE_Q4_1: return 5; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_K: return 4; + case GGML_TYPE_Q6_K: return 4; + case GGML_TYPE_Q8_0: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_cdna(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 5; + case GGML_TYPE_IQ2_XS: return 5; + case GGML_TYPE_IQ2_XXS: return 5; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna1_rdna2(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_Q2_K: return 7; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_K: return 5; + case GGML_TYPE_Q5_K: return 6; + case GGML_TYPE_Q6_K: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna3(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 6; + case GGML_TYPE_IQ1_M: return 6; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 6; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_K: return 4; + case GGML_TYPE_Q6_K: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 7; + case GGML_TYPE_IQ1_M: return 7; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 7; + case GGML_TYPE_IQ4_XS: return 5; + case GGML_TYPE_MXFP4: return 5; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 7; + case GGML_TYPE_Q4_1: return 7; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_0: return 7; + case GGML_TYPE_Q5_1: return 7; + case GGML_TYPE_Q5_K: return 5; + case GGML_TYPE_Q6_K: return 5; + case GGML_TYPE_Q8_0: return 7; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +// Host function: returns the max batch size for the current arch+type at runtime. +int get_mmvq_mmid_max_batch(ggml_type type, int cc) { + // NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID. + if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return MMVQ_MAX_BATCH_SIZE; + } + if (cc >= GGML_CUDA_CC_TURING) { + return get_mmvq_mmid_max_batch_turing_plus(type); + } + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + return get_mmvq_mmid_max_batch_pascal_older(type); + } + // AMD + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return get_mmvq_mmid_max_batch_rdna4(type); + } + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + return get_mmvq_mmid_max_batch_rdna3(type); + } + if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) { + return get_mmvq_mmid_max_batch_rdna1_rdna2(type); + } + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return get_mmvq_mmid_max_batch_cdna(type); + } + if (GGML_CUDA_CC_IS_GCN(cc)) { + return get_mmvq_mmid_max_batch_gcn(type); + } + return MMVQ_MAX_BATCH_SIZE; +} + +// Device constexpr: returns the max batch size for the current arch+type at compile time. +template +static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() { +#if defined(RDNA4) + return get_mmvq_mmid_max_batch_rdna4(type); +#elif defined(RDNA3) + return get_mmvq_mmid_max_batch_rdna3(type); +#elif defined(RDNA2) || defined(RDNA1) + return get_mmvq_mmid_max_batch_rdna1_rdna2(type); +#elif defined(CDNA) + return get_mmvq_mmid_max_batch_cdna(type); +#elif defined(GCN) + return get_mmvq_mmid_max_batch_gcn(type); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ >= GGML_CUDA_CC_ADA_LOVELACE) + return MMVQ_MAX_BATCH_SIZE; +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING + return get_mmvq_mmid_max_batch_turing_plus(type); +#else + return get_mmvq_mmid_max_batch_pascal_older(type); +#endif +} + static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) { if (table_id == MMVQ_PARAMETERS_GENERIC) { switch (ncols_dst) { @@ -195,7 +383,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -template +template __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, @@ -222,22 +410,13 @@ static __global__ void mul_mat_vec_q( const uint32_t channel_dst = blockIdx.y; - uint32_t token_idx = 0; uint32_t channel_x; uint32_t channel_y; uint32_t sample_dst; - if constexpr (is_multi_token_id) { - // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case - token_idx = blockIdx.z; - channel_x = ids[channel_dst + token_idx * ids_stride]; - channel_y = fastmodulo(channel_dst, nchannels_y); - sample_dst = 0; - } else { - channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); - channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; - sample_dst = blockIdx.z; - } + channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + sample_dst = blockIdx.z; const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); const uint32_t sample_y = sample_dst; @@ -294,9 +473,6 @@ static __global__ void mul_mat_vec_q( float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}}; const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; - if constexpr (is_multi_token_id) { - y += token_idx*stride_col_y; - } const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { @@ -350,10 +526,6 @@ static __global__ void mul_mat_vec_q( dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; - if constexpr (is_multi_token_id) { - dst += token_idx*stride_col_dst; - } - // sum up partial sums and write back result #pragma unroll for (int j = 0; j < ncols_dst; ++j) { @@ -413,6 +585,69 @@ static __global__ void mul_mat_vec_q( } } +// Dedicated MoE multi-token kernel. +// Grid: (ceil(nrows_x / c_rows_per_block), nchannels_dst) +// Block: (warp_size, ncols_dst) - each warp handles one token independently. +// No shared memory reduction needed since each warp works alone. +template +__launch_bounds__(get_mmvq_mmid_max_batch_for_device()*ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mul_mat_vec_q_moe( + const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, + float * __restrict__ dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x, + const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, + const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, + const uint32_t ncols_dst, const uint32_t ids_stride) { + + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); + + const uint32_t token_idx = threadIdx.y; + const int row0 = c_rows_per_block*blockIdx.x; + const int blocks_per_row_x = ncols_x / qk; + constexpr int blocks_per_iter = vdr * warp_size / qi; + + const uint32_t channel_dst = blockIdx.y; + + if (token_idx >= ncols_dst) { + return; + } + + const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride]; + const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y); + + const block_q8_1 * y = ((const block_q8_1 *) vy) + channel_y*stride_channel_y + token_idx*stride_col_y; + const int kbx_offset = channel_x*stride_channel_x + row0*stride_row_x; + + // partial sum for each thread + float tmp[c_rows_per_block] = {0.0f}; + + for (int kbx = threadIdx.x / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk/QK8_1); + const int kqs = vdr * (threadIdx.x % (qi/vdr)); + +#pragma unroll + for (int i = 0; i < c_rows_per_block; ++i) { + tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_offset + i*stride_row_x + kbx, kqs); + } + } + + // Warp-level reduction only - no shared memory needed +#pragma unroll + for (int i = 0; i < c_rows_per_block; ++i) { + tmp[i] = warp_reduce_sum(tmp[i]); + } + + // Write results + if (threadIdx.x < c_rows_per_block && (c_rows_per_block == 1 || uint32_t(row0 + threadIdx.x) < nrows_x)) { + dst[channel_dst*stride_channel_dst + token_idx*stride_col_dst + row0 + threadIdx.x] = tmp[threadIdx.x]; + } +} + template static std::pair calc_launch_params( const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, @@ -425,7 +660,7 @@ static std::pair calc_launch_params( return {block_nums, block_dims}; } -template +template static void mul_mat_vec_q_switch_fusion( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, @@ -438,7 +673,7 @@ static void mul_mat_vec_q_switch_fusion( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); @@ -448,12 +683,33 @@ static void mul_mat_vec_q_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> + mul_mat_vec_q<<>> (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } +template +static void mul_mat_vec_q_moe_launch( + const void * vx, const void * vy, const int32_t * ids, float * dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x, + const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, + const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, + const uint32_t ncols_dst, const uint32_t ids_stride, + const int warp_size, const int nchannels_dst, cudaStream_t stream) { + + constexpr int rows_per_block = 2; // 2 gives best perf based on tuning + const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block; + const dim3 block_nums(nblocks_rows, nchannels_dst); + const dim3 block_dims(warp_size, ncols_dst); + + mul_mat_vec_q_moe<<>>( + vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x, + stride_row_x, stride_col_y, stride_col_dst, + stride_channel_x, stride_channel_y, stride_channel_dst, + ncols_dst, ids_stride); +} + template static void mul_mat_vec_q_switch_ncols_dst( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, @@ -472,20 +728,62 @@ static void mul_mat_vec_q_switch_ncols_dst( const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); const int device = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[device].cc; const int warp_size = ggml_cuda_info().devices[device].warp_size; - const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); + const mmvq_parameter_table_id table_id = get_device_table_id(cc); const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const bool has_ids = ids != nullptr; + const auto should_use_small_k = [&](int c_ncols_dst) { + // When K is small, increase rows_per_block to match nwarps so each warp has more work to do + // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int qi = ggml_cuda_type_traits::qi; + constexpr int vdr = get_vdr_mmvq(type); + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_iter_1warp = vdr * warp_size / qi; + const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); + bool use = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + + constexpr std::array iq_slow_turing = { + GGML_TYPE_IQ3_XXS, + GGML_TYPE_IQ3_S, + }; + constexpr std::array iq_slow_other = { + GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, + GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, + }; + constexpr std::array slow_pascal = { + GGML_TYPE_IQ3_S, + GGML_TYPE_Q2_K, + GGML_TYPE_Q3_K, + }; + + const bool is_nvidia_turing_plus = GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_TURING; + const bool is_nvidia_pascal_older = GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA; + + if (is_nvidia_turing_plus) { + if (ncols_dst == 1 && + std::find(iq_slow_turing.begin(), iq_slow_turing.end(), type) != iq_slow_turing.end()) { + use = false; + } + } else if ((ncols_dst == 1 && std::find(iq_slow_other.begin(), iq_slow_other.end(), type) != iq_slow_other.end()) || + (is_nvidia_pascal_older && std::find(slow_pascal.begin(), slow_pascal.end(), type) != slow_pascal.end()) || + GGML_CUDA_CC_IS_RDNA(cc)) { + use = false; + } + + return use; + }; + if (has_ids && ncols_dst > 1) { - // Multi-token MUL_MAT_ID path only - single-token goes through regular path below - constexpr int c_ncols_dst = 1; - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + // Multi-token MUL_MAT_ID path - dedicated MoE kernel + mul_mat_vec_q_moe_launch( + vx, vy, ids, dst, ncols_x, nchannels_y_fd, nrows_x, + stride_row_x, stride_col_y, stride_col_dst, + stride_channel_x, stride_channel_y, stride_channel_dst, + ncols_dst, ids_stride, warp_size, nchannels_dst, stream); return; } @@ -493,31 +791,24 @@ static void mul_mat_vec_q_switch_ncols_dst( case 1: { constexpr int c_ncols_dst = 1; - // When K is small, increase rows_per_block to match nwarps so each warp has more work to do - // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. - constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int qi = ggml_cuda_type_traits::qi; - constexpr int vdr = get_vdr_mmvq(type); - const int blocks_per_row_x = ncols_x / qk; - const int blocks_per_iter_1warp = vdr * warp_size / qi; - const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); - const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + bool use_small_k = should_use_small_k(c_ncols_dst); + if (use_small_k) { - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, - warp_size, table_id, true); - mul_mat_vec_q_switch_fusion( + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, + nsamples_dst, warp_size, table_id, true); + mul_mat_vec_q_switch_fusion( vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, + stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride, + stream); } else { - std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, - warp_size, table_id); + std::pair dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, + nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion( vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, ids_stride, stream); + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, + stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride, + stream); } } break; case 2: { diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 8a154631f69..6bf0a8e8677 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -1,7 +1,10 @@ #include "common.cuh" #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. -#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID + +// Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID, +// based on the quantization type and GPU architecture (compute capability). +int get_mmvq_mmid_max_batch(ggml_type type, int cc); void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr); From 40ddc5a5b911851f4867207c80d8db2eb238388f Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Mon, 30 Mar 2026 17:05:11 +0300 Subject: [PATCH 363/831] rpc : fix misleading error log (llama/21184) When RPC is running with a remote backend which doesn't have init_tensor function (like CPU and Metal), the server log gets full with error messages saying that init_tensor is being called with null buffer which is incorrect. This patch fixes this. --- ggml/src/ggml-rpc/ggml-rpc.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 16f6abdffd6..1378ba9f5bf 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1340,7 +1340,9 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { if (buffer && buffer->iface.init_tensor) { buffer->iface.init_tensor(buffer, tensor); } else { - GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n"); + if (!buffer) { + GGML_LOG_ERROR("Tensor with null buffer passed to init_tensor function\n"); + } } if (tensor->extra != nullptr) { From 75b9543856158561584afe59772713ded1e82e95 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Mon, 30 Mar 2026 16:20:00 +0200 Subject: [PATCH 364/831] CUDA : Fix CUB's argsort when nrows % block_size == 0 CCCL < 3.1 (llama/21181) * CUDA: Fix CUB's argsort when nrows % block_size == 0 CCCL < 3.1 We wrongly calculated offset_grid as `ceildiv(nrows, block_size)`, while it must be `ceildiv(nrows + 1, block_size)`. As a consequence, we had uninitialized values in `offset_iterator[nrows]` for the case when `nrows % block_size == 0`. Fixes #21162 * Reduce nrows in test case to 256, don't need 768 --- ggml/src/ggml-cuda/argsort.cu | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 4896669c32a..38fdf3678c1 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -47,9 +47,11 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, #ifdef STRIDED_ITERATOR_AVAILABLE auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols); #else - ggml_cuda_pool_alloc offsets_alloc(pool, nrows + 1); + // offset_iterator needs to populate nrows + 1 elements, so we also have to ceildiv nrows + 1 by block_size + const int nrows_offset = nrows + 1; + ggml_cuda_pool_alloc offsets_alloc(pool, nrows_offset); int * offset_iterator = offsets_alloc.get(); - const dim3 offset_grid((nrows + block_size - 1) / block_size); + const dim3 offset_grid((nrows_offset + block_size - 1) / block_size); init_offsets<<>>(offset_iterator, ncols, nrows); #endif CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream)); From 6ac5a50005e7080d1c1b293c8e753ea135a9f325 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Mon, 30 Mar 2026 12:19:16 -0700 Subject: [PATCH 365/831] opencl: add q4_K gemm and gemv kernels for Adreno (llama/20919) * opencl: add q4_K gemm and gemv kernels for Adreno * opencl: fix whitespace * opencl: add workarounds for compiler bugs on older devices * opencl: handle fp16 denorm on X Elite * opencl: fix kernel build error * opencl: fix whitespace * opencl: make q4_K cvt kernels signature consistent --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 312 +++++++++++++++++ ggml/src/ggml-opencl/kernels/cvt.cl | 75 ++++- .../kernels/gemm_noshuffle_q4_k_f32.cl | 172 ++++++++++ .../kernels/gemv_noshuffle_q4_k_f32.cl | 318 ++++++++++++++++++ 5 files changed, 877 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index af29f3b8f4c..540942b195d 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -114,6 +114,8 @@ set(GGML_OPENCL_KERNELS gemv_noshuffle_q4_1_f32 gemm_noshuffle_q4_1_f32 gemv_noshuffle_general_q8_0_f32 + gemv_noshuffle_q4_k_f32 + gemm_noshuffle_q4_k_f32 gemv_noshuffle_q6_k_f32 gemm_noshuffle_q6_k_f32 mul diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index c40e1f2d391..0f6628c377d 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -538,6 +538,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_restore_block_q4_0_noshuffle; cl_kernel kernel_convert_block_q4_1_noshuffle; cl_kernel kernel_restore_block_q4_1_noshuffle; + cl_kernel kernel_convert_block_q4_K_noshuffle; + cl_kernel kernel_restore_block_q4_K_noshuffle; cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; @@ -720,6 +722,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemm_noshuffle_q4_1_f32; cl_kernel kernel_mul_mm_q8_0_f32_8x4; cl_kernel CL_mul_mat_vec_q8_0_f32; + cl_kernel kernel_gemv_noshuffle_q4_k_f32; + cl_kernel kernel_gemm_noshuffle_q4_k_f32; cl_kernel kernel_gemv_noshuffle_q6_K_f32; cl_kernel kernel_gemm_noshuffle_q6_K_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS @@ -932,6 +936,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err)); @@ -2619,6 +2625,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemm_noshuffle_q4_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q4_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q4_k_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_k_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q4_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q4_k_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q4_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q4_k_f32.cl"); +#endif + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_k_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std + " -cl-mad-enable " " -cl-fast-relaxed-math"; @@ -5060,12 +5105,25 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle; + } + #else + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; + #endif + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {64, 1, 1}; @@ -5076,6 +5134,20 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); tensor->extra = extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + // Transpose q, d, dm as ushort + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/256, M); + transpose_2d_as_16b(backend_ctx, extra->dm, extra->dm, size_dm, K/256, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } if (tensor->type == GGML_TYPE_Q6_K) { @@ -5516,12 +5588,60 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_nbytes(tensor), NULL, &err); CL_CHECK(err); + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_trans_dm; + + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_trans_dm.allocate(backend_ctx->context, size_dm); + + // Transpose q, d, dm back + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256); + transpose_2d_as_16b(backend_ctx, extra->dm, buf_trans_dm.buffer, size_dm, M, K/256); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_dm.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {1, 1, 1}; @@ -9688,6 +9808,192 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t #endif } +static void ggml_cl_mul_mat_q4_k_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_k = (ggml_tensor_extra_cl_q4_K *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + cl_uchar mask_d6 = 0x3F; + cl_uchar mask_d4 = 0x0F; + cl_uchar mask_hi2 = 0xC0; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q4_k->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q4_k_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_k->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_k->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_hi2)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q4_k_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_k->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_k->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_k->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_uchar), &mask_hi2)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + static void ggml_cl_mul_mat_q6_K_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); @@ -10014,6 +10320,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } + // q4_k x fp32 + if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q4_k_f32_adreno(backend, src0, src1, dst); + return; + } + // q6_K x fp32 if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) { ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst); diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 34930dfbe6a..81fe17fa10f 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -424,13 +424,17 @@ kernel void kernel_restore_block_q8_0_trans( // Convert the block_q4_K format to 4 separate arrays (AOS -> SOA). // This kernel does not deshuffle the bits. // Each thread processes a super block. +// Mask args are just to keep the signature consistent with the no-shuffle +// version and they are not used in this kernel. //------------------------------------------------------------------------------ kernel void kernel_convert_block_q4_K( global struct block_q4_K * src0, global uchar * dst_q, global uchar * dst_s, global half * dst_d, - global half * dst_dm + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 ) { global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0); global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0); @@ -451,12 +455,15 @@ kernel void kernel_convert_block_q4_K( // Restore block_q4_K from flattened arrays. // Each thread processes a super block. +// Mask args are just to keep the signature consistent with the no-shuffle ones. kernel void kernel_restore_block_q4_K( global uchar * src_q, global uchar * src_s, global half * src_d, global half * src_dm, - global struct block_q4_K * dst + global struct block_q4_K * dst, + uchar mask_0F, + uchar mask_F0 ) { global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0); global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0); @@ -475,6 +482,70 @@ kernel void kernel_restore_block_q4_K( } } +kernel void kernel_convert_block_q4_K_noshuffle( + global struct block_q4_K * src0, + global uchar * dst_q, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2 * get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->q[i*32 + 2*j]; + uchar x1 = b->q[i*32 + 2*j + 1]; + q[i*32 + j] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + q[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q4_K_noshuffle( + global uchar * src_q, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q4_K * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2 * get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = q[i*32 + j]; + uchar hi = q[i*32 + j + 16]; + b->q[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->q[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + //------------------------------------------------------------------------------ // kernel_convert_block_q6_K // Convert the block_q6_K format to 3 separate arrays (AOS -> SOA). diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl new file mode 100644 index 00000000000..99fd1fd7bf1 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl @@ -0,0 +1,172 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif +kernel void kernel_gemm_noshuffle_q4_k_f32( + global const ushort * src0_q, + global const uchar * src0_s, + global const half * src0_d, + global const half * src0_dm, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + dst = (global float *)((global char *)dst + offsetd); + int n_4 = n >> 2; + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + int num_blocks_K = k / QK_K; + + global const ushort * weight_ptr = src0_q + gx_2; + global const half * d_ptr = src0_d + gx_2; + global const half * dm_ptr = src0_dm + gx_2; + + for (int i = 0; i < k; i += 32) { + int sb_idx = i / QK_K; + int sub_idx = (i / 32) % 8; + + half4 d = vload4(0, d_ptr + sb_idx * m); + half4 dm = vload4(0, dm_ptr + sb_idx * m); + + global const uchar * sc0 = src0_s + (gx_2+0) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc1 = src0_s + (gx_2+1) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc2 = src0_s + (gx_2+2) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc3 = src0_s + (gx_2+3) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + + uchar sv0, mn0, sv1, mn1, sv2, mn2, sv3, mn3; + get_scale_min_k4(sub_idx, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc2, &sv2, &mn2, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc3, &sv3, &mn3, mask_d6, mask_d4, mask_hi2); + + half4 scale = convert_half4(convert_float4(d) * convert_float4((uchar4)(sv0, sv1, sv2, sv3))); + half4 mval = convert_half4(convert_float4(dm) * convert_float4((uchar4)(mn0, mn1, mn2, mn3))); + + for (int l = 0; l < 32; l += 4) { + int ki = i + l; + ushort4 bits4 = vload4(0, weight_ptr + (ki/4) * m); + + // j=0 + B.s0123 = read_imageh(src1, gy*2 + (ki+0) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+0) * n_4); + dequantized_weights.s0 = (bits4.s0 & 0x000F) * scale.s0 - mval.s0; + dequantized_weights.s1 = (bits4.s1 & 0x000F) * scale.s1 - mval.s1; + dequantized_weights.s2 = (bits4.s2 & 0x000F) * scale.s2 - mval.s2; + dequantized_weights.s3 = (bits4.s3 & 0x000F) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (ki+1) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+1) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0x00F0) >> 4) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0x00F0) >> 4) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0x00F0) >> 4) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0x00F0) >> 4) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (ki+2) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+2) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0x0F00) >> 8) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0x0F00) >> 8) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0x0F00) >> 8) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0x0F00) >> 8) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (ki+3) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+3) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0xF000) >> 12) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0xF000) >> 12) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0xF000) >> 12) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0xF000) >> 12) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + } + + int idx = (gy<<3)*m + (gx<<2); + + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl new file mode 100644 index 00000000000..dd1e2b55c0b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl @@ -0,0 +1,318 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK_K 256 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q4_k_f32( + read_only image1d_buffer_t src0_q, + global half2 * src0_d, + global half2 * src0_m, + global uchar * src0_s, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + uint scales_per_row = (K / QK_K) * 12; + + private uint4 regA; + private half2 regS; + private half2 regM; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / 32); k += NSUBGROUPS) { + uint sb = k / 8; + uint j = k % 8; + + half2 d = src0_d[gid + sb * LINE_STRIDE_A]; + half2 dm = src0_m[gid + sb * LINE_STRIDE_A]; + + global const uchar * sc0 = src0_s + 2 * gid * scales_per_row + sb * 12; + global const uchar * sc1 = src0_s + (2 * gid + 1) * scales_per_row + sb * 12; + + uchar sv0, mn0, sv1, mn1; + get_scale_min_k4(j, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(j, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + + regS = convert_half2(convert_float2(d) * convert_float2((uchar2)(sv0, sv1))); + regM = convert_half2(convert_float2(dm) * convert_float2((uchar2)(mn0, mn1))); + + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} From 952c66237de555d87a1ae3f39948fe6a1b6cdfb5 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Tue, 31 Mar 2026 18:31:50 +0800 Subject: [PATCH 366/831] sycl : enhance fattn perf (llama/21185) --- ggml/src/ggml-sycl/fattn-tile.hpp | 83 ++++++++++++++++--------------- 1 file changed, 43 insertions(+), 40 deletions(-) diff --git a/ggml/src/ggml-sycl/fattn-tile.hpp b/ggml/src/ggml-sycl/fattn-tile.hpp index 29fd0f8c9ec..c4d24613a55 100644 --- a/ggml/src/ggml-sycl/fattn-tile.hpp +++ b/ggml/src/ggml-sycl/fattn-tile.hpp @@ -70,6 +70,7 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 64, 64) return 0; } @@ -310,11 +311,11 @@ static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const sycl::half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; auto load = [&] (const int n) { - auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); const int stride_j = warp_size >> n; if (stride_j == 0) { @@ -455,7 +456,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp, flash_attn_tile_load_tile (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); #ifdef SYCL_FAST_FP16 static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); @@ -505,7 +506,7 @@ static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp, } if (k_KQ_0 + nbatch_K < DKQ) { - item_ct1.barrier(); // Sync not needed on last iteration. + item_ct1.barrier(sycl::access::fence_space::local_space); // Sync not needed on last iteration. } } @@ -545,7 +546,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, const int k_VKQ_max, const int col_Q_0, float * KQ_max_new_shared) { - auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); constexpr int cpy_ne = cpy_nb / 4; @@ -620,14 +621,14 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, } if constexpr (np == 1) { - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); } else { static_assert(cpw == 1, "bad cpw"); if (item_ct1.get_local_id(2) == 0) { KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0]; } - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np]; KQ_max_new[0] = warp_reduce_max(KQ_max_new[0]); } @@ -697,7 +698,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { flash_attn_tile_load_tile (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); #ifdef SYCL_FAST_FP16 #pragma unroll @@ -765,7 +766,7 @@ static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, } } #endif // SYCL_FAST_FP16 - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); } } @@ -972,7 +973,7 @@ static void flash_attn_tile(const char * Q, } } - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); // Main loop over KV cache: const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11; @@ -1051,7 +1052,7 @@ static void flash_attn_tile(const char * Q, return; } - item_ct1.barrier(); + item_ct1.barrier(sycl::access::fence_space::local_space); #pragma unroll for (int ip = 1; ip < np; ++ip) { @@ -1193,37 +1194,39 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm constexpr size_t nbytes_shared = 0; - if constexpr (DV <= 256) { - if (Q->ne[1] > 16/ncols2) { - constexpr int cols_per_block = 32; - const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn, warp_size> - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); - return; + if (DV < 512 && Q->ne[1] < 32) { + if constexpr (ncols2 <= 32) { + if (Q->ne[1] > 16/ncols2) { + constexpr int cols_per_block = 32; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } } - } - - if (Q->ne[1] > 8/ncols2) { - constexpr int cols_per_block = 16; - const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn, warp_size> - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); - return; - } - - if constexpr (ncols2 <= 8) { - if (Q->ne[1] > 4/ncols2) { - constexpr int cols_per_block = 8; - const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - launch_fattn, warp_size> - (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); - return; + if constexpr (ncols2 <= 16) { + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + } + if constexpr (ncols2 <= 8) { + if (Q->ne[1] > 4/ncols2) { + constexpr int cols_per_block = 8; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } } } From 5ffe58838dbb34db41f1f1c1db6c739c1b7fe83b Mon Sep 17 00:00:00 2001 From: hipudding Date: Tue, 31 Mar 2026 22:00:51 +0800 Subject: [PATCH 367/831] CANN: fix multi-thread set_tensor race conditions (llama/20151) * CANN: fix multi-thread set_tensor race conditions When ollama calls ggml_backend_tensor_set from multiple threads (each writing a different chunk of the same tensor), the CANN backend had three concurrency issues: 1. Quantized tensors (Q4_0/Q8_0) require a full-tensor format transform before uploading to device. Per-chunk transforms produced corrupt data. 2. ND-to-NZ weight conversion requires complete tensor data on device. Per-chunk conversion operated on incomplete data. 3. The global g_nz_workspaces array had unprotected concurrent access. Fix by introducing a TensorSetTracker that accumulates write progress per tensor. For quantized tensors, raw data is staged in a host buffer and the transform + upload is deferred until all chunks arrive. For NZ weights, chunks are uploaded directly but conversion is deferred. The tracker and its staging buffer are released immediately after post-processing completes. Add per-device mutex to g_nz_workspaces to prevent data races. * CANN: fix L2_NORM ignoring eps parameter The L2_NORM implementation was not using the eps parameter from op_params, causing incorrect results when eps is large (e.g. 10.0). The CPU reference computes scale = 1/fmaxf(norm, eps), so add a Clamp step to clamp the norm to at least eps before dividing. * ggml/cann: compare op_params for POOL_2D in ACL graph cache matching When ACL graph mode is enabled, the graph LRU cache checks whether a cached graph matches the current computation graph. Previously, GGML_OP_POOL_2D was not included in the op_params comparison, so two POOL_2D nodes with different pooling parameters (kernel size, stride, padding) but identical tensor shapes and addresses could incorrectly reuse a cached graph, leading to wrong results or aclnn errors. Add GGML_OP_POOL_2D to the list of ops that require op_params matching in ggml_graph_node_properties::has_matching_properties(). * cann: fix ACL graph cache matching by adding tensor type and unconditional op_params comparison The ACL graph LRU cache was incorrectly reusing cached graphs for operations with different tensor types or op_params, causing test failures for CPY (f16 vs bf16), POOL_2D, L2_NORM, NORM_MUL_ADD, RMS_NORM_MUL_ADD, and ADD_RMS_NORM. Changes: - Add node_type and src_type[] fields to ggml_graph_node_properties so the cache can distinguish tensors with different types but identical ne/nb (e.g. f16 and bf16 both have 2-byte elements) - Compare op_params unconditionally for all ops instead of only for SCALE/UNARY/GLU/ROPE/POOL_2D --- ggml/src/ggml-cann/aclnn_ops.cpp | 10 +++ ggml/src/ggml-cann/common.h | 30 ++++--- ggml/src/ggml-cann/ggml-cann.cpp | 129 +++++++++++++++++++++++++++---- 3 files changed, 145 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index adb4d68e868..a950475fc3b 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -434,6 +434,9 @@ void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src = dst->src[0]; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src); acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); @@ -456,6 +459,13 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { float p_value = 2.0f; acl_scalar_ptr p_scalar = ggml_cann_create_scalar(&p_value, aclDataType::ACL_FLOAT); GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_div.get()); + + // Clamp norm to at least eps: scale = 1/fmaxf(norm, eps) + acl_scalar_ptr acl_min = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT); + float flt_max = FLT_MAX; + acl_scalar_ptr acl_max = ggml_cann_create_scalar(&flt_max, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, Clamp, acl_div.get(), acl_min.get(), acl_max.get(), acl_div.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div.get(), acl_dst.get()); } diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 5f960548cd2..1c6e685c38c 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -216,14 +216,16 @@ struct ggml_cann_pool_alloc { #ifdef USE_ACL_GRAPH struct ggml_graph_node_properties { // dst tensor - void * node_address; - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; + void * node_address; + ggml_type node_type; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; // src tensor - void * src_address[GGML_MAX_SRC]; - int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; - size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; + void * src_address[GGML_MAX_SRC]; + ggml_type src_type[GGML_MAX_SRC]; + int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; + size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; // op ggml_op node_op; @@ -247,6 +249,10 @@ struct ggml_graph_node_properties { return false; } + if (node->type != this->node_type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { if (node->ne[i] != this->ne[i]) { return false; @@ -262,6 +268,10 @@ struct ggml_graph_node_properties { return false; } + if (node->src[i]->type != this->src_type[i]) { + return false; + } + for (int d = 0; d < GGML_MAX_DIMS; d++) { if (node->src[i]->ne[d] != this->src_ne[i][d]) { return false; @@ -277,10 +287,7 @@ struct ggml_graph_node_properties { } } - if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU || node->op == GGML_OP_ROPE){ - return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; - } - return true; + return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; } }; @@ -322,6 +329,7 @@ struct ggml_cann_graph { prop.node_address = node->data; prop.node_op = node->op; + prop.node_type = node->type; std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne); std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb); @@ -329,10 +337,12 @@ struct ggml_cann_graph { for (int src = 0; src < GGML_MAX_SRC; ++src) { if (node->src[src]) { prop.src_address[src] = node->src[src]->data; + prop.src_type[src] = node->src[src]->type; std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]); std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]); } else { prop.src_address[src] = nullptr; + prop.src_type[src] = GGML_TYPE_COUNT; std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0); std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0); } diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 6f26e91e046..40fe3d82ecc 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -36,10 +36,13 @@ #include #include #include +#include #include #include #include +#include #include +#include #define GGML_COMMON_DECL_C @@ -770,6 +773,21 @@ std::unique_ptr ggml_backend_cann_context::new_pool_for_device(i } // cann buffer + +/** + * @brief Tracks multi-threaded write progress for a single tensor. + * + * When multiple threads call set_tensor on different chunks of the same tensor, + * this tracker accumulates progress and defers post-processing (quantized format + * transform or ND-to-NZ conversion) until all data has been written. + */ +struct TensorSetTracker { + std::mutex mtx; ///< Protects concurrent access to this tracker + size_t bytes_written = 0; ///< Accumulated bytes written so far + size_t total_bytes = 0; ///< Target size (full tensor) + std::vector host_buffer; ///< Host staging buffer for quantized tensors +}; + /** * @brief Context for managing a CANN buffer associated with a specific device. * @@ -780,6 +798,9 @@ struct ggml_backend_cann_buffer_context { int32_t device; ///< The device ID associated with this buffer context. void * dev_ptr = nullptr; ///< Pointer to the device memory allocated for the buffer. + std::mutex tracker_mutex; ///< Protects the trackers map + std::unordered_map> trackers; + /** * @brief Constructor to initialize the CANN buffer context. * @@ -792,6 +813,31 @@ struct ggml_backend_cann_buffer_context { * @brief Destructor to free the device memory allocated for the buffer. */ ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); } + + /** + * @brief Get or create a tracker for the given tensor. + */ + TensorSetTracker * get_or_create_tracker(ggml_tensor * tensor) { + std::lock_guard lock(tracker_mutex); + auto key = tensor->data; + auto it = trackers.find(key); + if (it == trackers.end()) { + auto tracker = std::make_unique(); + tracker->total_bytes = ggml_nbytes(tensor); + auto * ptr = tracker.get(); + trackers[key] = std::move(tracker); + return ptr; + } + return it->second.get(); + } + + /** + * @brief Remove the tracker for the given tensor. + */ + void remove_tracker(ggml_tensor * tensor) { + std::lock_guard lock(tracker_mutex); + trackers.erase(tensor->data); + } }; // cann buffer type @@ -1124,6 +1170,7 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(ggml_backend_buffer * designed to be used with a global array, one per device. */ struct ggml_cann_nz_workspace { + std::mutex mtx; // Protects ptr/allocated from concurrent access void * ptr; // Pointer to allocated device buffer size_t allocated; // Size of currently allocated buffer in bytes @@ -1190,13 +1237,15 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES]; * @note The workspace buffer used in this function is managed globally and reused * across calls. This reduces overhead from repeated memory allocation and deallocation. */ -static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) { - acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset); +static void weight_format_to_nz(ggml_tensor * tensor, int device) { + acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, 0); uint64_t workspaceSize = 0; aclOpExecutor * executor; // TransMatmulWeight ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor)); + + std::lock_guard lock(g_nz_workspaces[device].mtx); // Avoid frequent malloc/free of the workspace. g_nz_workspaces[device].realloc(workspaceSize); @@ -1210,7 +1259,13 @@ static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) * @brief Set tensor data in a CANN buffer. * * This function sets tensor data in a CANN buffer, handling transformations - * if needed based on the tensor's type. + * if needed based on the tensor's type. It supports multi-threaded calls + * where different threads write different chunks of the same tensor. + * + * For quantized tensors (Q4_0/Q8_0), data is staged in a host buffer and + * the format transform is deferred until all chunks are written. + * For NZ weight tensors, chunks are uploaded directly but the ND-to-NZ + * conversion is deferred until all chunks are written. * * @param buffer The CANN buffer where the tensor data will be set. * @param tensor Pointer to the tensor whose data will be set. @@ -1226,26 +1281,72 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context; ggml_cann_set_device(ctx->device); - // TODO: refer to cann(#6017), it use thread's default stream. - // For acl, synchronous functions use this default stream. - // Why aclrtSynchronizeDevice? // Only check env once. static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); - if (!need_transform(tensor->type)) { + + bool is_quantized = need_transform(tensor->type); + bool is_nz = !is_quantized && tensor->type != GGML_TYPE_BF16 && weight_to_nz && + is_matmul_weight((const ggml_tensor *) tensor); + + // Plain tensor (not quantized, not NZ): direct copy, no tracking needed + if (!is_quantized && !is_nz) { ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); - if (weight_to_nz && tensor->type != GGML_TYPE_BF16 - && is_matmul_weight((const ggml_tensor *) tensor)) { + return; + } + + // Single-shot write (full tensor at once): handle directly without tracking overhead + if (offset == 0 && size == ggml_nbytes(tensor)) { + if (is_quantized) { + void * transform_buffer = malloc(size); + ggml_backend_cann_transform(tensor, data, transform_buffer); + ACL_CHECK(aclrtMemcpy(tensor->data, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE)); + free(transform_buffer); + } else { + // NZ weight GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[3] == 1); - weight_format_to_nz(tensor, offset, ctx->device); + ACL_CHECK(aclrtMemcpy(tensor->data, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); + weight_format_to_nz(tensor, ctx->device); } + return; + } + + // Chunked write: use tracker to accumulate progress and defer transform/conversion + TensorSetTracker * tracker = ctx->get_or_create_tracker(tensor); + std::unique_lock lock(tracker->mtx); + + if (is_quantized) { + // Stage data in host buffer; transform requires full tensor data + if (tracker->host_buffer.empty()) { + tracker->host_buffer.resize(tracker->total_bytes); + } + memcpy(tracker->host_buffer.data() + offset, data, size); } else { - void * transform_buffer = malloc(size); - ggml_backend_cann_transform(tensor, data, transform_buffer); + // NZ weight: upload chunk to device immediately, defer conversion + ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); + } - ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE)); - free(transform_buffer); + tracker->bytes_written += size; + + // All chunks received: perform deferred transform/conversion + if (tracker->bytes_written >= tracker->total_bytes) { + if (is_quantized) { + void * transform_buffer = malloc(tracker->total_bytes); + ggml_backend_cann_transform(tensor, tracker->host_buffer.data(), transform_buffer); + ACL_CHECK(aclrtMemcpy(tensor->data, tracker->total_bytes, transform_buffer, tracker->total_bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + free(transform_buffer); + } + + if (is_nz) { + GGML_ASSERT(tensor->ne[2] == 1); + GGML_ASSERT(tensor->ne[3] == 1); + weight_format_to_nz(tensor, ctx->device); + } + + // Unlock before removing tracker, as remove_tracker destroys the mutex + lock.unlock(); + ctx->remove_tracker(tensor); } } From 21b9dd6789eac3db4e152aca87c727874e2f0cf1 Mon Sep 17 00:00:00 2001 From: Abhijit Ramesh Date: Wed, 1 Apr 2026 12:58:53 +0300 Subject: [PATCH 368/831] ggml-webgpu: port all AOT operators to JIT (llama/20728) * port cpy pipeline to shader lib with JIT compilation * port glu pipeline to shader lib with JIT compilation * port rope pipeline to shader lib with JIT compilation * port soft_max pipeline to shader lib with JIT compilation * removed unused functions from embed_wgsl.py which were used for old AOT template expansion --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 325 ++++++++++++++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 224 ++++-------- ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl | 81 +++++ .../ggml-webgpu/wgsl-shaders/embed_wgsl.py | 107 +----- ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl | 155 +++++++++ ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl | 224 ++++++++++++ .../ggml-webgpu/wgsl-shaders/soft_max.wgsl | 245 +++++++++++++ 7 files changed, 1097 insertions(+), 264 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 59861ac16cc..97863f40412 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -535,6 +535,95 @@ struct ggml_webgpu_mul_mat_shader_decisions { uint32_t mul_mat_wg_size; }; +/** Cpy **/ + +struct ggml_webgpu_cpy_pipeline_key { + ggml_type src_type; + ggml_type dst_type; + + bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const { + return src_type == other.src_type && dst_type == other.dst_type; + } +}; + +struct ggml_webgpu_cpy_pipeline_key_hash { + size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + return seed; + } +}; + +/** Glu **/ + +struct ggml_webgpu_glu_pipeline_key { + ggml_glu_op glu_op; + ggml_type type; + bool split; + + bool operator==(const ggml_webgpu_glu_pipeline_key & other) const { + return glu_op == other.glu_op && type == other.type && split == other.split; + } +}; + +struct ggml_webgpu_glu_pipeline_key_hash { + size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.glu_op); + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.split); + return seed; + } +}; + +/** Rope **/ + +struct ggml_webgpu_rope_pipeline_key { + ggml_type type; + bool inplace; + bool has_ff; + + bool operator==(const ggml_webgpu_rope_pipeline_key & other) const { + return type == other.type && inplace == other.inplace && has_ff == other.has_ff; + } +}; + +struct ggml_webgpu_rope_pipeline_key_hash { + size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.has_ff); + return seed; + } +}; + +/** SoftMax **/ + +struct ggml_webgpu_soft_max_pipeline_key { + ggml_type mask_type; + bool has_mask; + bool has_sink; + bool inplace; + + bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const { + return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink && + inplace == other.inplace; + } +}; + +struct ggml_webgpu_soft_max_pipeline_key_hash { + size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.mask_type); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sink); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + class ggml_webgpu_shader_lib { wgpu::Device device; pre_wgsl::Preprocessor preprocessor; @@ -582,6 +671,12 @@ class ggml_webgpu_shader_lib { std::unordered_map set_rows_pipelines; std::unordered_map set_pipelines; + std::unordered_map cpy_pipelines; + std::unordered_map glu_pipelines; + std::unordered_map + rope_pipelines; + std::unordered_map + soft_max_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -1679,6 +1774,236 @@ class ggml_webgpu_shader_lib { return flash_attn_pipelines[key]; } + webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_cpy_pipeline_key key = { + .src_type = context.src0->type, + .dst_type = context.dst->type, + }; + + auto it = cpy_pipelines.find(key); + if (it != cpy_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "cpy"; + + switch (key.src_type) { + case GGML_TYPE_F32: + defines.push_back("SRC_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src type for cpy shader"); + } + + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + variant += "_f16"; + break; + case GGML_TYPE_I32: + defines.push_back("DST_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported dst type for cpy shader"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_cpy, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + cpy_pipelines[key] = pipeline; + return cpy_pipelines[key]; + } + + webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_glu_pipeline_key key = { + .glu_op = ggml_get_glu_op(context.dst), + .type = context.dst->type, + .split = (context.src1 != nullptr), + }; + + auto it = glu_pipelines.find(key); + if (it != glu_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "glu"; + + switch (key.glu_op) { + case GGML_GLU_OP_REGLU: + defines.push_back("OP_REGLU"); + variant += "_reglu"; + break; + case GGML_GLU_OP_GEGLU: + defines.push_back("OP_GEGLU"); + variant += "_geglu"; + break; + case GGML_GLU_OP_SWIGLU: + defines.push_back("OP_SWIGLU"); + variant += "_swiglu"; + break; + case GGML_GLU_OP_SWIGLU_OAI: + defines.push_back("OP_SWIGLU_OAI"); + variant += "_swiglu_oai"; + break; + case GGML_GLU_OP_GEGLU_ERF: + defines.push_back("OP_GEGLU_ERF"); + variant += "_geglu_erf"; + break; + case GGML_GLU_OP_GEGLU_QUICK: + defines.push_back("OP_GEGLU_QUICK"); + variant += "_geglu_quick"; + break; + default: + GGML_ABORT("Unsupported GLU op"); + } + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for GLU shader"); + } + + if (key.split) { + variant += "_split"; + } else { + defines.push_back("NO_SPLIT"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_glu, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + glu_pipelines[key] = pipeline; + return glu_pipelines[key]; + } + + webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_rope_pipeline_key key = { + .type = context.dst->type, + .inplace = context.inplace, + .has_ff = (context.src2 != nullptr), + }; + + auto it = rope_pipelines.find(key); + if (it != rope_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "rope"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for ROPE shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + if (key.has_ff) { + defines.push_back("FF_FUNC"); + variant += "_ff"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_rope, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + rope_pipelines[key] = pipeline; + return rope_pipelines[key]; + } + + webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_soft_max_pipeline_key key = { + .mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32, + .has_mask = (context.src1 != nullptr), + .has_sink = (context.src2 != nullptr), + .inplace = context.inplace, + }; + + auto it = soft_max_pipelines.find(key); + if (it != soft_max_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "soft_max"; + + if (key.has_mask) { + defines.push_back("HAS_MASK"); + switch (key.mask_type) { + case GGML_TYPE_F32: + defines.push_back("MASK_F32"); + variant += "_mask_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("MASK_F16"); + variant += "_mask_f16"; + break; + default: + GGML_ABORT("Unsupported type for SOFT_MAX shader"); + } + } + + if (key.has_sink) { + defines.push_back("HAS_SINK"); + variant += "_sink"; + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_soft_max, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + soft_max_pipelines[key] = pipeline; + return soft_max_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5e16f84ddd2..fa3c492a7a5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -364,13 +364,6 @@ struct webgpu_context_struct { wgpu::Buffer set_rows_dev_error_buf; wgpu::Buffer set_rows_host_error_buf; - std::map> cpy_pipelines; // src_type, dst_type - - std::map>> rope_pipelines; // type, ff, inplace - std::map>> glu_pipelines; // glu_op, type, split - - std::map>> soft_max_pipelines; // mask_type, has_sink, inplace - size_t memset_bytes_per_thread; }; @@ -849,6 +842,16 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 } static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { @@ -875,9 +878,8 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->cpy_pipelines[src->type][dst->type], - params, entries, wg_x); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { @@ -1914,6 +1916,19 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .src2 = src2, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = ggml_webgpu_tensor_equal(src0, dst), + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + const int inplace = ggml_webgpu_tensor_equal(src0, dst); const int has_freq_factor = (src2 != nullptr); @@ -1996,12 +2011,22 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace]; - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + const int split = (src1 != nullptr); std::vector params = { @@ -2048,8 +2073,7 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split]; - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -2109,9 +2133,20 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here - const int has_sink = (src2 != nullptr); + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .src2 = src2, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = ggml_webgpu_tensor_equal(src0, dst), + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); + + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int has_mask = (src1 != nullptr); + const int has_sink = (src2 != nullptr); float max_bias; memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); @@ -2120,15 +2155,15 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), - mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0, - mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0, - mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0, (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), @@ -2136,8 +2171,8 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], (uint32_t) src0->ne[2], - mask_type < 2 ? (uint32_t) src1->ne[2] : 0, - mask_type < 2 ? (uint32_t) src1->ne[3] : 0, + has_mask ? (uint32_t) src1->ne[2] : 0, + has_mask ? (uint32_t) src1->ne[3] : 0, *(uint32_t *) dst->op_params, // scale *(uint32_t *) &max_bias, *(uint32_t *) &n_head_log2, @@ -2152,7 +2187,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, src0) } }; uint32_t binding_num = 1; - if (mask_type < 2) { + if (has_mask) { entries.push_back({ .binding = binding_num, .buffer = ggml_webgpu_tensor_buf(src1), .offset = ggml_webgpu_tensor_align_offset(ctx, src1), @@ -2173,9 +2208,7 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, - ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries, - ggml_nrows(dst)); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(dst)); } static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { @@ -2885,139 +2918,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_I32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_i32, "cpy_f32_i32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); -} - -static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32, "rope_f32", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); - - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16, "rope_f16", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); -} - -static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - // REGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32, "reglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16, "reglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants); - - // GEGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32, "geglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16, "geglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants); - - // SWIGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants); - - // SWIGLU_OAI - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); - - // GEGLU_ERF - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); - - // GEGLU_QUICK - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); -} - -static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - - // f32 (no mask) - webgpu_ctx->soft_max_pipelines[2][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants); - webgpu_ctx->soft_max_pipelines[2][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[2][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); - - // f32 mask (mask_type = 0) - webgpu_ctx->soft_max_pipelines[0][0][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants); - webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[0][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, - "soft_max_f32_mask_f32_sink_inplace", constants); - - // f16 mask (mask_type = 1) - webgpu_ctx->soft_max_pipelines[1][0][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants); - webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); - webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); - webgpu_ctx->soft_max_pipelines[1][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, - "soft_max_f32_mask_f16_sink_inplace", constants); -} - static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { wgpu::RequestAdapterOptions options = {}; @@ -3183,10 +3083,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); - ggml_webgpu_init_cpy_pipeline(webgpu_ctx); - ggml_webgpu_init_rope_pipeline(webgpu_ctx); - ggml_webgpu_init_glu_pipeline(webgpu_ctx); - ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl new file mode 100644 index 00000000000..fa3bdf4e393 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl @@ -0,0 +1,81 @@ +enable f16; + +#ifdef SRC_F32 +#define SRC_TYPE f32 +#elif defined(SRC_F16) +#define SRC_TYPE f16 +#endif + +#ifdef DST_F32 +#define DST_TYPE f32 +#elif defined(DST_F16) +#define DST_TYPE f16 +#elif defined(DST_I32) +#define DST_TYPE i32 +#endif + +@group(0) @binding(0) +var src: array; + +@group(0) @binding(1) +var dst: array; + +struct Params{ + ne: u32, + offset_src: u32, + offset_dst: u32, + + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + +@group(0) @binding(2) +var params: Params; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); + i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); + let i2 = i / (params.src_ne1 * params.src_ne0); + i = i % (params.src_ne1 * params.src_ne0); + let i1 = i / params.src_ne0; + let i0 = i % params.src_ne0; + + var j = gid.x; + let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + let j2 = j / (params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne1 * params.dst_ne0); + let j1 = j / params.dst_ne0; + let j0 = j % params.dst_ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + + j2 * params.stride_dst2 + j3 * params.stride_dst3; + + dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx])); +} + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index 8b5cfe715e7..79a3a9597ab 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -1,41 +1,8 @@ import os import re -import ast import argparse -def extract_block(text, name): - pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)' - match = re.search(pattern, text, re.DOTALL) - if not match: - raise ValueError(f"Missing block: {name}") - return match.group(1).strip() - - -def parse_decls(decls_text): - decls = {} - for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL): - decls[name.strip()] = code.strip() - return decls - - -def replace_repl_placeholders(variant, template_map): - for repl, code in variant["REPLS"].items(): - for key, val in template_map.items(): - # Match "key" and avoid matching subsequences using by using \b - code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code) - variant["REPLS"][repl] = code - return variant - - -def replace_placeholders(shader_text, replacements): - for key, val in replacements.items(): - # Match {{KEY}} literally, where KEY is escaped - pattern = r'{{\s*' + re.escape(key) + r'\s*}}' - shader_text = re.sub(pattern, str(val), shader_text) - return shader_text - - def expand_includes(shader, input_dir): """ Replace #include "file" lines in the text with the contents of that file. @@ -98,84 +65,24 @@ def write_shader(shader_name, shader_code, output_dir, outfile, input_dir): outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n') -def generate_variants(fname, input_dir, output_dir, outfile): - shader_path = os.path.join(input_dir, fname) - shader_base_name = fname.split(".")[0] - - with open(shader_path, "r", encoding="utf-8") as f: - text = f.read() - - try: - variants = ast.literal_eval(extract_block(text, "VARIANTS")) - except ValueError: - write_shader(shader_base_name, text, output_dir, outfile, input_dir) - else: - try: - decls_map = parse_decls(extract_block(text, "DECLS")) - except ValueError: - decls_map = {} - try: - templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES")) - except ValueError: - templates_map = {} - - for fname in sorted(os.listdir(input_dir)): - if fname.endswith(".tmpl"): - tmpl_path = os.path.join(input_dir, fname) - with open(tmpl_path, "r", encoding="utf-8") as f_tmpl: - decls = f_tmpl.read() - decls_map.update(parse_decls(decls)) - - shader_template = extract_block(text, "SHADER") - for variant in variants: - if "DECLS" in variant: - decls = variant["DECLS"] - else: - decls = [] - decls_code = "" - for key in decls: - if key not in decls_map: - raise ValueError(f"DECLS key '{key}' not found.") - decls_code += decls_map[key] + "\n\n" - final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template) - if "REPLS" in variant: - variant = replace_repl_placeholders(variant, templates_map) - final_shader = replace_placeholders(final_shader, variant["REPLS"]) - # second run to expand placeholders in repl_template - final_shader = replace_placeholders(final_shader, variant["REPLS"]) - final_shader = expand_includes(final_shader, input_dir) - - if "SHADER_NAME" in variant: - output_name = variant["SHADER_NAME"] - elif "SHADER_SUFFIX" in variant: - output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"] - elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]]) - elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]]) - elif "REPLS" in variant and "TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] - else: - output_name = shader_base_name - write_shader(output_name, final_shader, output_dir, outfile, input_dir) - - def main(): parser = argparse.ArgumentParser() parser.add_argument("--input_dir", required=True) parser.add_argument("--output_file", required=True) - parser.add_argument("--output_dir") args = parser.parse_args() - if args.output_dir: - os.makedirs(args.output_dir, exist_ok=True) - with open(args.output_file, "w", encoding="utf-8") as out: out.write("// Auto-generated shader embedding\n") out.write("#include \n\n") for fname in sorted(os.listdir(args.input_dir)): if fname.endswith(".wgsl"): - generate_variants(fname, args.input_dir, args.output_dir, out) + shader_path = os.path.join(args.input_dir, fname) + shader_name = fname.replace(".wgsl", "") + + with open(shader_path, "r", encoding="utf-8") as f: + shader_code = f.read() + + write_shader(shader_name, shader_code, None, out, args.input_dir) if __name__ == "__main__": diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl new file mode 100644 index 00000000000..e6d7608cec5 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl @@ -0,0 +1,155 @@ +enable f16; + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_F16 +#define DataType f16 +#endif + +#ifdef OP_REGLU +fn op(a: DataType, b: DataType) -> DataType { + return max(a, 0) * b; +} +#endif + +#ifdef OP_GEGLU +const SQRT_2_OVER_PI: DataType = 0.79788456080286535587989211986876; +const GELU_COEF_A: DataType = 0.044715; + +fn op(a: DataType, b: DataType) -> DataType { + let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); + return 0.5 * a * (2.0 - 2.0/ (exp(2* val) + 1)) * b; +} +#endif + +#ifdef OP_SWIGLU +fn op(a: DataType, b: DataType) -> DataType { + return a / (1.0 + exp(-a)) * b; +} +#endif +#ifdef OP_SWIGLU_OAI +fn op(a: f32, b: f32) -> f32 { + let xi = min(a, params.limit); + let gi = max(min(b, params.limit), -params.limit); + var out_glu = xi / (1.0 + exp(-xi * params.alpha)); + out_glu = out_glu * (1.0 + gi); + return out_glu; +} +#endif +#ifdef OP_GEGLU_ERF +const p_erf: DataType = 0.3275911; +const a1_erf: DataType = 0.254829592; +const a2_erf: DataType = -0.284496736; +const a3_erf: DataType = 1.421413741; +const a4_erf: DataType = -1.453152027; +const a5_erf: DataType = 1.061405429; +const SQRT_2_INV: DataType = 0.7071067811865476; + +fn op(a: DataType, b: DataType) -> DataType { + let a_div_sqr2 = a * SQRT_2_INV; + let sign_x = sign(a_div_sqr2); + let x = abs(a_div_sqr2); + let t = 1.0 / (1.0 + p_erf * x); + let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); + let erf_approx = sign_x * y; + return 0.5 * a * (1.0 + erf_approx) * b; +} +#endif +#ifdef OP_GEGLU_QUICK +const GELU_QUICK_COEF: DataType = -1.702; + +fn op(a: DataType, b: DataType) -> DataType { + return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b; +} +#endif + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + swapped: u32, + alpha: f32, + limit: f32, +} + +@group(0) @binding(0) +var src0: array; + +#ifdef NO_SPLIT +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; + +fn a_value(base: u32) -> DataType { + let offset: u32 = select(0, params.ne0, params.swapped != 0); + return src0[base + offset]; +} + +fn b_value(base: u32) -> DataType { + let offset: u32 = select(params.ne0, 0, params.swapped != 0); + return src0[base + offset]; +} + +#else +@group(0) @binding(1) +var src1: array; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; + +fn a_value(base: u32) -> DataType { + return src0[base]; +} + +fn b_value(base: u32) -> DataType { + return src1[base]; +} + +#endif + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; + let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; + let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; + + dst[i_dst] = op(a_value(i_a), b_value(i_b)); +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl new file mode 100644 index 00000000000..1c874e14240 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl @@ -0,0 +1,224 @@ +enable f16; + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_F16 +#define DataType f16 +#endif + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_src2: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + n_threads: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + n_dims: u32, + mode: u32, + theta_scale: f32, + attn_factor: f32, + freq_scale: f32, + ext_factor: f32, + corr_dim0: f32, + corr_dim1: f32, + sections0: u32, + sections1: u32, + sections2: u32, + sections3: u32 +}; + +@group(0) @binding(0) +var src0: array; +@group(0) @binding(1) +var src1: array; + +#ifdef INPLACE + +#ifdef FF_FUNC + +@group(0) @binding(2) +var src2: array; + +@group(0) @binding(3) +var params: Params; + +#else + +@group(0) @binding(2) +var params: Params; + +#endif + +#else + +#ifdef FF_FUNC +@group(0) @binding(2) +var src2: array; + +@group(0) @binding(3) +var dst: array; + +@group(0) @binding(4) +var params: Params; + +#else +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; +#endif +#endif + +#ifdef FF_FUNC +fn freq_factor(i: u32) -> f32 { + return src2[params.offset_src2 + i/2]; +} + +#else +fn freq_factor(i: u32) -> f32 { + return 1.0f; +} +#endif +#ifdef INPLACE +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + src0[i_dst0] = DataType(out0); + src0[i_dst1] = DataType(out1); +} +#else +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + dst[i_dst0] = DataType(out0); + dst[i_dst1] = DataType(out1); +} +#endif + +fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 { + let y = (f32(i / 2) - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// returns vector of (cos_theta, sin_theta) +// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row +fn rope_yarn(theta_extrap: f32, i: u32) -> vec2 { + var mscale = params.attn_factor; + var theta = params.freq_scale * theta_extrap; + if (params.ext_factor != 0.0f) { + let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor; + theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix; + mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale); + } + return vec2(cos(theta) * mscale, sin(theta) * mscale); +} + +fn pair_base(i0: u32, div_2: bool) -> u32 { + if (div_2) { + return i0 / 2; + } else { + return i0; + } +} + +fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 { + if (is_vision) { + return params.n_dims; + } else if (is_neox || is_mrope) { + return params.n_dims / 2; + } else { + return 1; + } +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + // two elements per n_threads + if (gid.x >= params.n_threads) { + return; + } + + let is_neox = bool(params.mode & 2); + let is_mrope = bool(params.mode & 8); + let is_imrope = params.mode == 40; + let is_vision = params.mode == 24; + + var i = gid.x * 2; // start index for this thread + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + if (i0 >= params.n_dims && !is_vision) { + let i_src = i_src_row + i0; + let i_dst = i_dst_row + i0; + rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1])); + return; + } + + var theta_base_mult: u32 = 0; + var theta_scale_pwr: u32 = i0 / 2; + if (is_mrope) { + let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3; + let sec_w = params.sections1 + params.sections0; + let sec_e = params.sections2 + sec_w; + let sector = (i0 / 2) % sect_dims; + if (is_imrope) { + if (sector % 3 == 1 && sector < 3 * params.sections1) { + theta_base_mult = 1; + } else if (sector % 3 == 2 && sector < 3 * params.sections2) { + theta_base_mult = 2; + } else if (sector % 3 == 0 && sector < 3 * params.sections0) { + theta_base_mult = 0; + } else { + theta_base_mult = 3; + } + } else { + if (sector >= params.sections0 && sector < sec_w) { + theta_base_mult = 1; + if (is_vision) { + theta_scale_pwr = sector - params.sections0; + } + } else if (sector >= sec_w && sector < sec_e) { + theta_base_mult = 2; + if (is_vision) { + theta_scale_pwr = sector - sec_w; + } + } else if (sector >= sec_e) { + if (is_vision) { + theta_scale_pwr = sector - sec_e; + theta_scale_pwr = (i0 / 2) % sec_e; + } + theta_base_mult = 3; + } else if (is_vision) { + theta_scale_pwr = sector; + } + } + } + let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr)); + let thetas = rope_yarn(theta_base/freq_factor(i0), i0); + + let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision); + let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision); + + let x0 = f32(src0[i_src]); + let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]); + rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x); + +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl new file mode 100644 index 00000000000..10edf136048 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl @@ -0,0 +1,245 @@ +enable f16; + +#ifdef MASK_F32 +#define MaskType f32 +#endif +#ifdef MASK_F16 +#define MaskType f16 +#endif + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_sinks: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of src0/dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + // shape of src1 + ne12: u32, + ne13: u32, + + scale: f32, + max_bias: f32, + n_head_log2: f32, + m0: f32, + m1: f32, +}; + +@group(0) @binding(0) +var src: array; + +#ifdef HAS_MASK +#ifdef HAS_SINK +@group(0) @binding(1) +var mask: array; +@group(0) @binding(2) +var sinks: array; + +#ifdef INPLACE +@group(0) @binding(3) +var params: Params; + +#else +@group(0) @binding(3) +var dst: array; +@group(0) @binding(4) +var params: Params; +#endif + +#else +@group(0) @binding(1) +var mask: array; + +#ifdef INPLACE +@group(0) @binding(2) +var params: Params; + +#else +@group(0) @binding(2) +var dst: array; +@group(0) @binding(3) +var params: Params; +#endif +#endif + +#else +#ifdef HAS_SINK +@group(0) @binding(1) +var sinks: array; + +#ifdef INPLACE +@group(0) @binding(2) +var params: Params; + +#else +@group(0) @binding(2) +var dst: array; +@group(0) @binding(3) +var params: Params; +#endif + +#else +#ifdef INPLACE +@group(0) @binding(1) +var params: Params; +#else +@group(0) @binding(1) +var dst: array; +@group(0) @binding(2) +var params: Params; +#endif +#endif +#endif + +#ifdef INPLACE +fn inter_value(i: u32) -> f32 { + return src[i]; +} +fn update(i: u32, val: f32) { + src[i] = val; +} + +#else +fn inter_value(i: u32) -> f32 { + return dst[i]; +} +fn update(i: u32, val: f32) { + dst[i] = val; +} +#endif + +#ifdef HAS_MASK +fn mask_val(i: u32) -> f32 { + return f32(mask[i]); +} + +#else +fn mask_val(i: u32) -> f32 { + return 0.0; +} +#endif + +#ifdef HAS_SINK +fn lower_max_bound(i2: u32) -> f32 { + return sinks[params.offset_sinks + i2]; +} +fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { + return val + exp(sinks[params.offset_sinks + i2] - max_val); +} +#else +fn lower_max_bound(i2: u32) -> f32 { + return -1e30; +} +fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { + return val; +} +#endif + +const CACHE_SIZE: u32 = 16; +var scratch: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + + var i = wid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; + let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; + + let head = f32(i2); + let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0); + + var cache: array; + + var max_val = lower_max_bound(i2); + var col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col); + max_val = max(max_val, val); + if (col < CACHE_SIZE) { + cache[col] = val; + } + col += WG_SIZE; + } + + scratch[lid.x] = max_val; + workgroupBarrier(); + var offset: u32 = WG_SIZE / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]); + } + offset = offset / 2; + workgroupBarrier(); + } + let row_max = scratch[0]; + workgroupBarrier(); + + var sum = 0.0f; + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col), + cache[col], col < CACHE_SIZE); + let ex = exp(val - row_max); + sum += ex; + if (col < CACHE_SIZE) { + cache[col] = ex; + } else { + update(i_dst_row + col, ex); + } + col += WG_SIZE; + } + + scratch[lid.x] = sum; + workgroupBarrier(); + offset = WG_SIZE / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + let row_sum = add_sinks(scratch[0], i2, row_max); + + let sum_recip = 1.0 / row_sum; + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip); + col += WG_SIZE; + } +} + From 78f54d15d80aded8a603e12fb066539acdb32f49 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 31 Mar 2026 22:38:24 -0700 Subject: [PATCH 369/831] ggml webgpu: quantized buffers to u32 + wider browser/device support (llama/21046) * Work towards removing bitcast * Move rest of existing types over * Add timeout back to wait and remove synchronous set_tensor/memset_tensor * move to unpackf16 for wider compatibility * cleanup * Remove deadlock condition in free_bufs --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 10 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 17 +- .../wgsl-shaders/common_decls.tmpl | 24 +++ .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 81 ++++++-- .../wgsl-shaders/mul_mat_decls.tmpl | 196 +++++++----------- .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 103 ++++----- 6 files changed, 207 insertions(+), 224 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 97863f40412..a194ce84e25 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1219,9 +1219,8 @@ class ggml_webgpu_shader_lib { defines.push_back("BYTE_HELPERS"); defines.push_back("MUL_ACC_" + type_upper); - - // For fast path we always dequantize from f16 inside the shader - defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); break; } } @@ -1334,9 +1333,8 @@ class ggml_webgpu_shader_lib { defines.push_back("MUL_ACC_" + type_upper); defines.push_back("INIT_SRC0_SHMEM_" + type_upper); defines.push_back("INIT_SRC1_SHMEM_FLOAT"); - - // Use f16 inside the shader for quantized types - defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); variant += std::string("_") + src0_name; break; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index fa3c492a7a5..1aa15b0507c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -83,7 +83,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim #define WEBGPU_NUM_PARAM_BUFS 96u #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u -#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 +#define WEBGPU_WAIT_ANY_TIMEOUT_MS 100 // Maximum number of in-flight submissions per-thread, to avoid exhausting the // parameter buffer pool #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) @@ -171,6 +171,7 @@ struct webgpu_buf_pool { // Try growing the pool if no free buffers if (free.empty() && cur_pool_size < max_pool_size && should_grow) { cur_pool_size++; + lock.unlock(); // avoid deadlock between this lock and Dawn's internal locks when buffers are freed in callbacks wgpu::Buffer dev_buf; ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); @@ -507,7 +508,7 @@ static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD; while (blocking_wait) { - auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, 0); + auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, WEBGPU_WAIT_ANY_TIMEOUT_MS * 1e6); if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { #ifdef GGML_WEBGPU_GPU_PROFILE ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true); @@ -728,7 +729,6 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); std::vector commands = { command }; std::vector sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) }; - ggml_backend_webgpu_wait(ctx, sub); } /** End WebGPU Actions */ @@ -2694,17 +2694,6 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, // memset the remaining bytes ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), remaining_size); - } else { - // wait for WriteBuffer to complete - buf_ctx->global_ctx->instance.WaitAny(buf_ctx->global_ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", - std::string(message).c_str()); - } - }), - UINT64_MAX); } WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 9a5b18ebc07..feb0bca3f84 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -8,6 +8,30 @@ fn get_byte_i32(value: u32, index: u32) -> i32 { } #endif +#ifdef U32_DEQUANT_HELPERS +fn load_src0_u16_at(byte_offset: u32) -> u32 { + let word = src0[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_src0_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = src0[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = src0[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} + +fn load_src0_f16_at(byte_offset: u32) -> f16 { + let packed = unpack2x16float(load_src0_u16_at(byte_offset)); + return f16(packed[0]); +} +#endif + #ifdef Q4_0_T struct q4_0 { d: f16, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index b6822161464..8b76cecba91 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -6,6 +6,8 @@ enable chromium_experimental_subgroup_matrix; #ifdef KV_F32 #define KV_TYPE f32 +#elif defined(KV_Q4_0) || defined(KV_Q8_0) +#define KV_TYPE u32 #else #define KV_TYPE f16 #endif @@ -37,11 +39,13 @@ enable chromium_experimental_subgroup_matrix; #define NQ 16 // Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights #define F16_PER_BLOCK 9 +#define BLOCK_SIZE_BYTES 18u #define WEIGHTS_PER_F16 4 #elif defined(KV_Q8_0) #define NQ 8 // Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights #define F16_PER_BLOCK 17 +#define BLOCK_SIZE_BYTES 34u #define WEIGHTS_PER_F16 2 #endif #define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) @@ -55,6 +59,47 @@ fn get_byte_i32(value: u32, index: u32) -> i32 { return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; } +#if defined(KV_Q4_0) || defined(KV_Q8_0) +fn load_k_u16_at(byte_offset: u32) -> u32 { + let word = K[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_k_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = K[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = K[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} + +fn load_v_u16_at(byte_offset: u32) -> u32 { + let word = V[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_v_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = V[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = V[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} + +fn f16_from_u16(bits: u32) -> f16 { + let packed = unpack2x16float(bits); + return f16(packed[0]); +} +#endif + struct Params { offset_q: u32, offset_k: u32, @@ -254,12 +299,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_k_row < params.seq_len_kv) { let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_k_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_k_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -282,12 +326,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_k_row < params.seq_len_kv) { let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_k_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_k_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f16(q_byte) * d; @@ -459,12 +502,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_v_row < params.seq_len_kv) { let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_v_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_v_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -487,12 +529,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_v_row < params.seq_len_kv) { let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; // scale + let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_v_u16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_v_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f16(q_byte) * d; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index de60ebbcf2b..eb228537bad 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -61,10 +61,10 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q4_0 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 18u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -81,14 +81,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_src0_f16_at(block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1u + block_offset + j]; - let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -104,10 +102,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q4_1 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 20u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 10u; // 1 scale + 8 packed weights + 1 mean const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -124,15 +122,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; - let m = src0[scale_idx + 1u]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_src0_f16_at(block_byte_base); + let m = load_src0_f16_at(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_lo = f16(q_byte & 0xF) * d + m; @@ -149,11 +145,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q5_0 // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 22u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. // tile_k is defined as 32u, so blocks_k ends up being 1 always override BLOCKS_K = TILE_K / BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 11u; // 1 scale + 2 qh + 8 packed weights const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights @@ -171,18 +167,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let qh0 = src0[scale_idx + 1u]; - let qh1 = src0[scale_idx + 2u]; - let qh_packed = bitcast(vec2(qh0, qh1)); + let d = load_src0_f16_at(block_byte_base); + let qh_packed = load_src0_u32_at(block_byte_base + 2u); for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 3u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); + let q_packed = load_src0_u32_at(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -207,11 +199,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q5_1 // 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 24u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. // tile_k is defined as 32u, so blocks_k ends up being 1 always override BLOCKS_K = TILE_K / BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 12u; // 1 scale + 2 qh + 8 packed weights + 1 mean const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights @@ -229,20 +221,16 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let m = src0[scale_idx + 1u]; - let qh0 = src0[scale_idx + 2u]; - let qh1 = src0[scale_idx + 3u]; - let qh_packed = bitcast(vec2(qh0, qh1)); + let d = load_src0_f16_at(block_byte_base); + let m = load_src0_f16_at(block_byte_base + 2u); + let qh_packed = load_src0_u32_at(block_byte_base + 4u); for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 4u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); + let q_packed = load_src0_u32_at(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -266,10 +254,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q8_0 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 34u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 17u; // 1 scale + 16 in array of weights const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread @@ -286,14 +274,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_src0_f16_at(block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j+=2) { - let q_0 = src0[scale_idx + 1u + block_offset + j]; - let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -308,10 +294,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q8_1 const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 36u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 18u; // 1 scale + 1 mean + 8 32-bit values in array of weights const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block @@ -328,15 +314,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; - let m = src0[scale_idx + 1u]; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_src0_f16_at(block_byte_base); + let m = load_src0_f16_at(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j+=2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -351,7 +335,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q2_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 42u; +const BLOCK_SIZE_BYTES = 84u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { // Use standard thread layout instead of lane/row_group @@ -371,10 +355,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx + 40u]; - let dmin = src0[scale_idx + 41u]; + let d = load_src0_f16_at(block_byte_base + 80u); + let dmin = load_src0_f16_at(block_byte_base + 82u); // Decode the element at position k_in_block let block_of_32 = k_in_block / 32u; @@ -387,18 +371,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let is = k_in_block / 16u; - let sc_0 = src0[scale_idx + 2u * (is / 4u)]; - let sc_1 = src0[scale_idx + 2u * (is / 4u) + 1u]; - let sc_packed = bitcast(vec2(sc_0, sc_1)); + let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u)); let sc = get_byte(sc_packed, is % 4u); let dl = d * f16(sc & 0xFu); let ml = dmin * f16(sc >> 4u); let q_idx = q_b_idx + k + l; - let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; - let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 3u; @@ -410,7 +390,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q3_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 55u; +const BLOCK_SIZE_BYTES = 110u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -429,9 +409,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx + 54u]; + let d = load_src0_f16_at(block_byte_base + 108u); // Load and unpack scales let kmask1: u32 = 0x03030303u; @@ -439,9 +419,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var scale_vals: array; for (var i: u32 = 0u; i < 4u; i++) { - let scale_0 = src0[scale_idx + 48u + (2u*i)]; - let scale_1 = src0[scale_idx + 48u + (2u*i) + 1u]; - scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i); } var tmp: u32 = scale_vals[2]; @@ -453,16 +431,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load hmask and qs arrays var hmask_vals: array; for (var i: u32 = 0u; i < 8u; i++) { - let hmask_0 = src0[scale_idx + (2u*i)]; - let hmask_1 = src0[scale_idx + (2u*i) + 1u]; - hmask_vals[i] = bitcast(vec2(hmask_0, hmask_1)); + hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i); } var qs_vals: array; for (var i: u32 = 0u; i < 16u; i++) { - let qs_0 = src0[scale_idx + 16u + (2u*i)]; - let qs_1 = src0[scale_idx + 16u + (2u*i) + 1u]; - qs_vals[i] = bitcast(vec2(qs_0, qs_1)); + qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i); } let half = k_in_block / 128u; // 0 or 1 @@ -502,7 +476,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q4_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 72u; +const BLOCK_SIZE_BYTES = 144u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -521,17 +495,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let dmin = src0[scale_idx + 1u]; + let d = load_src0_f16_at(block_byte_base); + let dmin = load_src0_f16_at(block_byte_base + 2u); // Load packed scales var scale_vals: array; for (var i: u32 = 0u; i < 3u; i++) { - let scale_0 = src0[scale_idx + 2u + (2u*i)]; - let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u]; - scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i); } // Map k_in_block to loop structure: @@ -567,9 +539,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_0 = src0[scale_idx + 8u + 2u * (q_idx / 4u)]; - let q_1 = src0[scale_idx + 8u + 2u * (q_idx / 4u) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 0xFu; @@ -582,7 +552,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q5_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 88u; +const BLOCK_SIZE_BYTES = 176u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -601,17 +571,15 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = src0[scale_idx]; - let dmin = src0[scale_idx + 1u]; + let d = load_src0_f16_at(block_byte_base); + let dmin = load_src0_f16_at(block_byte_base + 2u); // Load packed scales var scale_vals: array; for (var i: u32 = 0u; i < 3u; i++) { - let scale_0 = src0[scale_idx + 2u + (2u*i)]; - let scale_1 = src0[scale_idx + 2u + (2u*i) + 1u]; - scale_vals[i] = bitcast(vec2(scale_0, scale_1)); + scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i); } // The original loop processes elements in groups of 64 @@ -651,15 +619,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_0 = src0[scale_idx + 24u + 2u * (q_idx / 4u)]; - let q_1 = src0[scale_idx + 24u + 2u * (q_idx / 4u) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); - let qh_0 = src0[scale_idx + 8u + 2u * (l / 4u)]; - let qh_1 = src0[scale_idx + 8u + 2u * (l / 4u) + 1u]; - let qh_packed = bitcast(vec2(qh_0, qh_1)); + let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u)); let qh_byte = get_byte(qh_packed, l % 4u); @@ -675,7 +639,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #ifdef INIT_SRC0_SHMEM_Q6_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 105u; +const BLOCK_SIZE_BYTES = 210u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { @@ -694,7 +658,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let k_in_block = global_k % BLOCK_SIZE; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let scale_idx = src0_idx * F16_PER_BLOCK; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let half = k_in_block / 128u; let pos_in_half = k_in_block % 128u; @@ -707,30 +671,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only ql13 word needed let ql13_flat = ql_b_idx + l; - let ql13_word = ql13_flat / 4u; - let ql13 = bitcast(vec2( - src0[scale_idx + 2u * ql13_word], - src0[scale_idx + 2u * ql13_word + 1u] - )); - let ql13_b = get_byte(ql13, ql13_flat % 4u); + let ql13 = load_src0_u32_at(block_byte_base + ql13_flat); + let ql13_b = get_byte(ql13, 0u); // Load only ql24 word needed let ql24_flat = ql_b_idx + l + 32u; - let ql24_word = ql24_flat / 4u; - let ql24 = bitcast(vec2( - src0[scale_idx + 2u * ql24_word], - src0[scale_idx + 2u * ql24_word + 1u] - )); - let ql24_b = get_byte(ql24, ql24_flat % 4u); + let ql24 = load_src0_u32_at(block_byte_base + ql24_flat); + let ql24_b = get_byte(ql24, 0u); // Load only qh word needed let qh_flat = qh_b_idx + l; - let qh_word = qh_flat / 4u; - let qh = bitcast(vec2( - src0[scale_idx + 64u + 2u * qh_word], - src0[scale_idx + 64u + 2u * qh_word + 1u] - )); - let qh_b = get_byte(qh, qh_flat % 4u); + let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat); + let qh_b = get_byte(qh, 0u); let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0); @@ -740,14 +692,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only the scale word needed let is = l / 16u; let sc_idx = sc_b_idx + is + quarter * 2u; - let sc_word = sc_idx / 4u; - let sc = bitcast(vec2( - src0[scale_idx + 96u + 2u * sc_word], - src0[scale_idx + 96u + 2u * sc_word + 1u] - )); - let sc_val = get_byte_i32(sc, sc_idx % 4u); - - let d = src0[scale_idx + 104u]; + let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx); + let sc_val = get_byte_i32(sc, 0u); + + let d = load_src0_f16_at(block_byte_base + 208u); var q_val: f16; if (quarter == 0u) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 94f4bae11f4..6525f23bdfc 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -52,8 +52,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q4_0 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 18u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -62,14 +62,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); + let d = f32(load_src0_f16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1 + block_offset + j]; - let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; @@ -86,8 +85,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q4_1 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 20u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 10u; const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -96,15 +95,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let m = f32(src0[scale_idx + 1u]); + let d = f32(load_src0_f16_at(block_byte_base)); + let m = f32(load_src0_f16_at(block_byte_base + 2u)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = f32((q_byte >> 4) & 0xF) * d + m; @@ -121,8 +119,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q5_0 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 22u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 11u; const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -131,18 +129,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let qh0 = src0[scale_idx + 1u]; - let qh1 = src0[scale_idx + 2u]; - let qh_packed = bitcast(vec2(qh0, qh1)); + let d = f32(load_src0_f16_at(block_byte_base)); + let qh_packed = load_src0_u32_at(block_byte_base + 2u); for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 3u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 3u + block_offset + (j*2) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); + let q_packed = load_src0_u32_at(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -168,8 +163,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q5_1 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 24u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 12u; const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -178,19 +173,16 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let m = src0[scale_idx + 1u]; - let qh0 = src0[scale_idx + 2u]; - let qh1 = src0[scale_idx + 3u]; - let qh_packed = bitcast(vec2(qh0, qh1)); + let d = f32(load_src0_f16_at(block_byte_base)); + let m = load_src0_f16_at(block_byte_base + 2u); + let qh_packed = load_src0_u32_at(block_byte_base + 4u); for (var j = 0u; j < 2; j++) { - let q_0 = src0[scale_idx + 4u + block_offset + (j*2)]; - let q_1 = src0[scale_idx + 4u + block_offset + (j*2) + 1u]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); + let q_packed = load_src0_u32_at(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -216,8 +208,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q8_0 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 34u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 17u; const WEIGHTS_PER_F16 = 2u; const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -226,15 +218,14 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); + let d = f32(load_src0_f16_at(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1 + block_offset + j]; - let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; @@ -250,8 +241,8 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q8_1 const BLOCK_SIZE = 32; +const BLOCK_SIZE_BYTES = 36u; const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 18u; const WEIGHTS_PER_F16 = 2u; const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; @@ -260,16 +251,15 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { let blck_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; + let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - let m = src0[scale_idx + 1u]; + let d = f32(load_src0_f16_at(block_byte_base)); + let m = load_src0_f16_at(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 2u + block_offset + j]; - let q_1 = src0[scale_idx + 2u + block_offset + j + 1]; - let q_packed = bitcast(vec2(q_0, q_1)); + let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + let q_packed = load_src0_u32_at(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d + f32(m); @@ -284,13 +274,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { #ifdef MUL_ACC_Q6_K const BLOCK_SIZE = 256u; -const F16_PER_BLOCK = 105u; - -fn load_u32_at(bbase: u32, byte_offset: u32) -> u32 { - let aligned = byte_offset & ~3u; - let idx = bbase + aligned / 2u; - return bitcast(vec2(src0[idx], src0[idx + 1u])); -} +const BLOCK_SIZE_BYTES = 210u; fn byte_of(v: u32, b: u32) -> u32 { return (v >> (b * 8u)) & 0xFFu; @@ -323,16 +307,15 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { var local_sum = 0.0; for (var i = ix; i < nb; i += 2u) { - let bbase = (idx_base + k_block_start + i) * F16_PER_BLOCK; + let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES; - let d_raw = load_u32_at(bbase, 208u); - let d = f32(bitcast>(d_raw)[0]); + let d = f32(load_src0_f16_at(bbase + 208u)); - let ql1_u32 = load_u32_at(bbase, q_offset_l); - let ql2_u32 = load_u32_at(bbase, q_offset_l + 32u); - let qh_u32 = load_u32_at(bbase, 128u + q_offset_h); - let sc_u32_0 = load_u32_at(bbase, sc_base_byte); - let sc_u32_1 = load_u32_at(bbase, sc_base_byte + 4u); + let ql1_u32 = load_src0_u32_at(bbase + q_offset_l); + let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u); + let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h); + let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte); + let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u); let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); From 933bd1f79c925f9f1d563854dd1fb4e40c288568 Mon Sep 17 00:00:00 2001 From: Anav Prasad Date: Wed, 1 Apr 2026 07:07:24 +0000 Subject: [PATCH 370/831] CUDA: Add Flash Attention Support for Head Dimension 512 (llama/20998) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * flash attention support for head dimension 512 added * FA D=512 - match 576 configs, limit ncols2, revert vec cap * fix HIP tile kernel build for D=512 * fix HIP tile kernel occupancy for D=512 on AMD * Apply suggestions from code review Co-authored-by: Johannes Gäßler * fix tile FA compilation --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 30 ++++++++++++++- ggml/src/ggml-cuda/fattn-tile.cu | 4 ++ ggml/src/ggml-cuda/fattn-tile.cuh | 37 +++++++++++++++---- ggml/src/ggml-cuda/fattn.cu | 11 ++++-- ...attn-mma-f16-instance-ncols1_1-ncols2_8.cu | 1 + ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 + ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 + ...attn-mma-f16-instance-ncols1_2-ncols2_8.cu | 1 + ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 + ...attn-mma-f16-instance-ncols1_4-ncols2_8.cu | 1 + ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 + ...attn-mma-f16-instance-ncols1_8-ncols2_8.cu | 1 + .../fattn-tile-instance-dkq512-dv512.cu | 5 +++ .../template-instances/generate_cu_files.py | 4 +- 14 files changed, 86 insertions(+), 13 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index fff70c8eb89..b613ae61fb8 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -66,6 +66,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); @@ -80,6 +85,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); @@ -89,6 +99,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); @@ -103,6 +118,10 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); @@ -1552,7 +1571,7 @@ static __global__ void flash_attn_ext_f16( #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { + if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) { NO_DEVICE_CODE; return; } @@ -1815,6 +1834,15 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64) DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64) DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) +extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); + // The number of viable configurations for Deepseek is very limited: extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index 3fcb09b7a2b..25b16e83cac 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); } break; + case 512: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst); + } break; case 576: { GGML_ASSERT(V->ne[0] == 512); ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index f3fa80ab23d..26721cc4c7d 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,6 +68,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) @@ -124,6 +128,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) @@ -187,6 +195,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 512, 1, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) @@ -251,6 +264,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) @@ -767,7 +785,7 @@ static __global__ void flash_attn_tile( #ifdef GGML_USE_WMMA_FATTN (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) || #endif // GGML_USE_WMMA_FATTN - (use_logit_softcap && !(DV == 128 || DV == 256)) + (use_logit_softcap && !(DV == 128 || DV == 256 || DV == 512)) ) { GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, @@ -1192,7 +1210,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; - if constexpr (DV == 512) { + if constexpr (DKQ == 576) { if (use_gqa_opt && gqa_ratio % 16 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1203,7 +1221,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm } } - if constexpr (DV <= 256) { + if constexpr (DKQ <= 512) { if (use_gqa_opt && gqa_ratio % 8 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1214,13 +1232,15 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm return; } - if (use_gqa_opt && gqa_ratio % 2 == 0) { - launch_fattn_tile_switch_ncols1(ctx, dst); + if constexpr (DV <= 256) { + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + launch_fattn_tile_switch_ncols1(ctx, dst); return; } - - launch_fattn_tile_switch_ncols1(ctx, dst); - return; } GGML_ABORT("fatal error"); } @@ -1255,4 +1275,5 @@ extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(512, 512); extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index a25a890db6d..a21c5361048 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -135,6 +135,10 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(V->ne[0] == 256); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); break; + case 512: + GGML_ASSERT(V->ne[0] == 512); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst); + break; case 576: { // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. GGML_ASSERT(V->ne[0] == 512); @@ -336,7 +340,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const case 128: case 112: case 256: - if (V->ne[0] != K->ne[0]) { + case 512: + if (!gqa_opt_applies) { return BEST_FATTN_KERNEL_NONE; } break; @@ -424,7 +429,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use the WMMA kernel if possible: - if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) { + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) { if (can_use_vector_kernel && Q->ne[1] <= 2) { return BEST_FATTN_KERNEL_VEC; } @@ -457,7 +462,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use MFMA flash attention for CDNA (MI100+): - if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) { + if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) { const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); // MMA vs tile crossover benchmarked on MI300X @ d32768: // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu index dc16829021f..22d383173f3 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index 517993cb068..d2415bfa957 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index 97b19c67ade..8eec1d74e29 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu index 163b1d939e4..84b674cd05a 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index 989626dfa5e..3475dfea08a 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu index bad296b4141..5906398db91 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 173de7aac7d..684cd25ce0d 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4); DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu index 680a13ca6de..4bc60d62f91 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu new file mode 100644 index 00000000000..7c61d8d2ecd --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(512, 512); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 3b5ab12fc40..b7b5832293e 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,7 +3,7 @@ from glob import glob import os -HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576] +HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576] TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] @@ -83,6 +83,8 @@ def get_short_name(long_quant_name): continue if head_size_kq == 72: continue + if head_size_kq == 512 and ncols2 not in (4, 8): + continue if head_size_kq != 576 and ncols2 in (16, 32): continue if head_size_kq == 576 and ncols2 not in (4, 16, 32): From 1b95f84550d32e59e9bbef4eaab0e0ce9240bf90 Mon Sep 17 00:00:00 2001 From: Taimur Ahmad Date: Wed, 1 Apr 2026 13:10:03 +0500 Subject: [PATCH 371/831] ggml-cpu: fix fallback for RVV kernels without zvfh (llama/21157) * ggml-cpu: refactor sgemm; fix rvv checks * ggml-cpu: refactor rvv kernels; set zvfbfwma default to off --- ggml/CMakeLists.txt | 19 +- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 147 +++++++------ ggml/src/ggml-cpu/vec.h | 292 +++++++++++++------------- 3 files changed, 239 insertions(+), 219 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index a739cca4218..ab558438e95 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -166,15 +166,16 @@ if (NOT MSVC) option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF) option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF) endif() -option(GGML_LASX "ggml: enable lasx" ON) -option(GGML_LSX "ggml: enable lsx" ON) -option(GGML_RVV "ggml: enable rvv" ON) -option(GGML_RV_ZFH "ggml: enable riscv zfh" ON) -option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON) -option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) -option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause " ON) -option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) -option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE}) +option(GGML_LASX "ggml: enable lasx" ON) +option(GGML_LSX "ggml: enable lsx" ON) +option(GGML_RVV "ggml: enable rvv" ON) +option(GGML_RV_ZFH "ggml: enable riscv zfh" ON) +option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON) +option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) +option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause" ON) +option(GGML_RV_ZVFBFWMA "ggml: enable riscv zvfbfwma" OFF) +option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) +option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE}) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 63ceb635dea..34e320e2f50 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -180,44 +180,49 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { } #endif +#if defined(__riscv_v_intrinsic) +template <> inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) { + return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); +} +template <> inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) { + return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); +} +template <> inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) { + return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); +} +template <> inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) { + return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); +} +#endif + #if defined(__riscv_zvfh) -template <> -inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) { +template <> inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) { return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); } -inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) { +template <> inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) { return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); } -inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) { +template <> inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) { return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); } -inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) { +template <> inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) { return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); } -inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) { - return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); -} -inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) { - return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); -} -inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) { - return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); -} -inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) { - return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); -} #endif #if defined(__riscv_zvfbfwma) -inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) { +template <> inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) { return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); } -inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) { +template <> inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) { return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); } -inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) { +template <> inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) { return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); } +template <> inline vfloat32m8_t madd(vbfloat16m4_t a, vbfloat16m4_t b, vfloat32m8_t c) { + return __riscv_vfwmaccbf16_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); +} #endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -272,7 +277,7 @@ inline float hsum(__m512 x) { } #endif // __AVX512F__ -#if defined(__riscv_zvfh) +#if defined(__riscv_v_intrinsic) inline float hsum(vfloat32m1_t x) { return __riscv_vfmv_f_s_f32m1_f32( __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1())); @@ -379,6 +384,21 @@ template <> inline __m256bh load(const float *p) { } #endif +#if defined(__riscv_v_intrinsic) +template <> inline vfloat32m1_t load(const float *p) { + return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1()); +} +template <> inline vfloat32m2_t load(const float *p) { + return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2()); +} +template <> inline vfloat32m4_t load(const float *p) { + return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4()); +} +template <> inline vfloat32m8_t load(const float *p) { + return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8()); +} +#endif + #if defined(__riscv_zvfh) template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) { return __riscv_vle16_v_f16mf2(reinterpret_cast(p), __riscv_vsetvlmax_e16mf2()); @@ -392,18 +412,6 @@ template <> inline vfloat16m2_t load(const ggml_fp16_t *p) { template <> inline vfloat16m4_t load(const ggml_fp16_t *p) { return __riscv_vle16_v_f16m4(reinterpret_cast(p), __riscv_vsetvlmax_e16m4()); } -template <> inline vfloat32m1_t load(const float *p) { - return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1()); -} -template <> inline vfloat32m2_t load(const float *p) { - return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2()); -} -template <> inline vfloat32m4_t load(const float *p) { - return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4()); -} -template <> inline vfloat32m8_t load(const float *p) { - return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8()); -} #endif #if defined(__riscv_zvfbfwma) @@ -416,23 +424,14 @@ template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) { template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) { return __riscv_vle16_v_bf16m2(reinterpret_cast(p), __riscv_vsetvlmax_e16m2()); } +template <> inline vbfloat16m4_t load(const ggml_bf16_t *p) { + return __riscv_vle16_v_bf16m4(reinterpret_cast(p), __riscv_vsetvlmax_e16m4()); +} #endif -#if defined(__riscv_zvfh) +#if defined(__riscv_v_intrinsic) template T set_zero(); -template <> inline vfloat16mf2_t set_zero() { - return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2()); -} -template <> inline vfloat16m1_t set_zero() { - return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1()); -} -template <> inline vfloat16m2_t set_zero() { - return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2()); -} -template <> inline vfloat16m4_t set_zero() { - return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4()); -} template <> inline vfloat32m1_t set_zero() { return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1()); } @@ -449,14 +448,22 @@ template <> inline vfloat32m8_t set_zero() { #if defined(__riscv_v_intrinsic) template size_t vlmax() { - if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16mf2(); } - else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m1(); } - else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m2(); } - else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m4(); } - else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m1(); } + if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m1(); } else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m2(); } else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m4(); } else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e32m8(); } + #if defined (__riscv_zvfh) + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16mf2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m1(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m4(); } + #endif + #if defined (__riscv_zvfbfwma) + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16mf2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m1(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m2(); } + else if constexpr (std::is_same_v) { return __riscv_vsetvlmax_e16m4(); } + #endif return 0; } #endif @@ -3740,7 +3747,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 params->ith, params->nth}; tb.matmul(m, n); return true; -#elif defined(__riscv_zvfh) +#elif defined(__riscv_v_intrinsic) #if LMUL == 1 tinyBLAS_RVV tb{ params, k, (const float *)A, lda, @@ -3804,23 +3811,25 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return true; } #elif defined(__riscv_zvfbfwma) - #if LMUL == 1 - tinyBLAS_RVV tb{ params, - k, (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc}; - #elif LMUL == 2 - tinyBLAS_RVV tb{ params, - k, (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc}; - #else // LMUL = 4 - tinyBLAS_RVV tb{ params, - k, (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc}; - #endif - return tb.matmul(m, n); + if (Btype == GGML_TYPE_BF16) { + #if LMUL == 1 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #elif LMUL == 2 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #else // LMUL = 4 + tinyBLAS_RVV tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #endif + return tb.matmul(m, n); + } #endif return false; } diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 3198b33b509..a0375a28de0 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -126,7 +126,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG const int ggml_f16_epr = sve_register_length / 16; // running when 16 const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers - const int np = (n & ~(ggml_f16_step - 1)); + int np = (n & ~(ggml_f16_step - 1)); svfloat16_t sum_00 = svdup_n_f16(0.0f); svfloat16_t sum_01 = svdup_n_f16(0.0f); @@ -224,71 +224,75 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG } GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03); GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13); + np = n; + #elif defined(__riscv_v_intrinsic) + #if defined(__riscv_zvfh) + size_t vl = __riscv_vsetvlmax_e32m4(); + + // initialize accumulators to all zeroes + vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + // calculate step size + const size_t epr = __riscv_vsetvlmax_e16m2(); + const size_t step = epr * 2; + int np = (n & ~(step - 1)); + + // unroll by 2 along the row dimension + for (int i = 0; i < np; i += step) { + vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr); + vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr); + vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr); + vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr); + vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr); + + vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr); + vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr); + vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr); + vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr); + vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr); + } - #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) - size_t vl = __riscv_vsetvlmax_e32m4(); - - // initialize accumulators to all zeroes - vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - - // calculate step size - const size_t epr = __riscv_vsetvlmax_e16m2(); - const size_t step = epr * 2; - const int np = (n & ~(step - 1)); - - // unroll by 2 along the row dimension - for (int i = 0; i < np; i += step) { - vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr); - vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr); - vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr); - vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr); - vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr); - - vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr); - vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr); - vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr); - vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr); - vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr); - } - - vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl); - vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl); - - // leftovers - for (int i = np; i < n; i += vl) { - vl = __riscv_vsetvl_e16m2(n - i); - vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl); - vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl); - vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl); + vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl); + vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl); - vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl); - vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl); - } + // leftovers + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m2(n - i); + vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl); + vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl); + vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl); - // reduce - vl = __riscv_vsetvlmax_e32m2(); - vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0), - __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl); - vl = __riscv_vsetvlmax_e32m1(); - vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0), - __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl); - vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1( - acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); - - vl = __riscv_vsetvlmax_e32m2(); - vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0), - __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl); - vl = __riscv_vsetvlmax_e32m1(); - vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0), - __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl); - vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1( - acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); - sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0); - sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1); + vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl); + vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl); + } + // reduce + vl = __riscv_vsetvlmax_e32m2(); + vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0), + __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl); + vl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0), + __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl); + vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1( + acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); + + vl = __riscv_vsetvlmax_e32m2(); + vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0), + __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl); + vl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0), + __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl); + vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1( + acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); + sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0); + sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1); + np = n; + #else + const int np = 0; + #endif #else const int np = (n & ~(GGML_F16_STEP - 1)); @@ -313,21 +317,17 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { GGML_F16_VEC_REDUCE(sumf[k], sum[k]); } - - // leftovers - for (int i = np; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); - } - } #endif #else - for (int i = 0; i < n; ++i) { + // scalar path + const int np = 0; +#endif + // scalar and leftovers + for (int i = np; i < n; ++i) { for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); } } -#endif for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { s[i] = (float)sumf[i]; @@ -532,40 +532,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, svst1_f16(pg, (__fp16 *)(y + np2), hy); } np = n; -#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic - const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); - const _Float16 scale = *(const _Float16*)(&s); - - // calculate step size - const int epr = __riscv_vsetvlmax_e16m4(); - const int step = epr * 2; - int np = (n & ~(step - 1)); - - // unroll by 2 - for (int i = 0; i < np; i += step) { - vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr); - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); - ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); - __asm__ __volatile__ ("" ::: "memory"); - - vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr); - vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); - ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); - __asm__ __volatile__ ("" ::: "memory"); - } +#elif defined(__riscv_v_intrinsic) // implies __riscv_v_intrinsic + #if defined (__riscv_zvfh) + const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); + const _Float16 scale = *(const _Float16*)(&s); - // leftovers - int vl; - for (int i = np; i < n; i += vl) { - vl = __riscv_vsetvl_e16m4(n - i); - vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl); - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); - ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); - } - np = n; + // calculate step size + const int epr = __riscv_vsetvlmax_e16m4(); + const int step = epr * 2; + int np = (n & ~(step - 1)); + + // unroll by 2 + for (int i = 0; i < np; i += step) { + vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); + ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr); + vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); + ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); + } + + // leftovers + int vl; + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m4(n - i); + vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); + ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); + } + np = n; + #else + // fall to scalar path + const int np = 0; + #endif #elif defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); @@ -584,10 +589,11 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, } } #else + // scalar path const int np = 0; #endif - // leftovers + // scalar and leftovers for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); } @@ -785,7 +791,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float const int ggml_f16_step = 2 * ggml_f16_epr; GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); - const int np = (n & ~(ggml_f16_step - 1)); + int np = (n & ~(ggml_f16_step - 1)); svfloat16_t ay1, ay2; for (int i = 0; i < np; i += ggml_f16_step) { @@ -805,36 +811,43 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float svfloat16_t out = svmul_f16_m(pg, hy, vx); svst1_f16(pg, (__fp16 *)(y + np), out); } -#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) - const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); - const _Float16 scale = *(const _Float16*)(&s); - - // calculate step size - const int epr = __riscv_vsetvlmax_e16m4(); - const int step = epr * 2; - const int np = (n & ~(step - 1)); + np = n; +#elif defined(__riscv_v_intrinsic) + #if defined(__riscv_zvfh) + const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); + const _Float16 scale = *(const _Float16*)(&s); - // unroll by 2 - for (int i = 0; i < np; i += step) { - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); - ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); - __asm__ __volatile__ ("" ::: "memory"); + // calculate step size + const int epr = __riscv_vsetvlmax_e16m4(); + const int step = epr * 2; + int np = (n & ~(step - 1)); - vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); - ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); - __asm__ __volatile__ ("" ::: "memory"); - } + // unroll by 2 + for (int i = 0; i < np; i += step) { + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); + ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); + ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); + } - // leftovers - int vl; - for (int i = np; i < n; i += vl) { - vl = __riscv_vsetvl_e16m4(n - i); - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); - ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); - } + // leftovers + int vl; + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m4(n - i); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); + ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); + } + np = n; + #else + // fall to scalar path + const int np = 0; + #endif #elif defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); @@ -850,17 +863,14 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); } } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); - } #else - // scalar - for (int i = 0; i < n; ++i) { + // scalar path + const int np = 0; +#endif + // scalar and leftovers + for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); } -#endif } inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } From 5c5b88eb779cbd37a32c209c2034e6b56b55c4fe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 1 Apr 2026 11:10:25 +0300 Subject: [PATCH 372/831] ggml : fix RWKV ops thread assignment (llama/21226) --- ggml/src/ggml-cpu/ggml-cpu.c | 6 +++++- ggml/src/ggml-cpu/ops.cpp | 30 +++++++++--------------------- 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index df17cc55300..7486acc2b5d 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2350,11 +2350,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: + { + n_tasks = n_threads; + } break; case GGML_OP_RWKV_WKV6: case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_RWKV_WKV7: { - n_tasks = n_threads; + const int64_t n_heads = node->src[1]->ne[1]; + n_tasks = MIN(n_threads, n_heads); } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index d950972c83e..765ce07f06c 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -9953,13 +9953,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -10170,13 +10166,9 @@ static void ggml_compute_forward_gla_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -10633,13 +10625,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * r = (float *) dst->src[0]->data; float * w = (float *) dst->src[1]->data; From 1971a362dc008312762ed208cd0296bc23717901 Mon Sep 17 00:00:00 2001 From: uvos Date: Wed, 1 Apr 2026 10:21:20 +0200 Subject: [PATCH 373/831] CUDA/HIP: Fix kernel slection for mmvq mmid kernel to align host selection with device launch bounds (llama/21238) The conditions cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE and cc >= GGML_CUDA_CC_TURING match all non-nvidia devices. This causes us to attempt to launch the kernel for batch sizes with larger configurations than our launch bounds on HIP devices. This pr fixes the conditionals in get_mmvq_mmid_max_batch. Fixes #21191 --- ggml/src/ggml-cuda/mmvq.cu | 43 ++++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 8d80d1dd9a7..07b10167bc4 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -235,30 +235,33 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type // Host function: returns the max batch size for the current arch+type at runtime. int get_mmvq_mmid_max_batch(ggml_type type, int cc) { // NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID. - if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) { - return MMVQ_MAX_BATCH_SIZE; - } - if (cc >= GGML_CUDA_CC_TURING) { - return get_mmvq_mmid_max_batch_turing_plus(type); - } if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return MMVQ_MAX_BATCH_SIZE; + } + if (cc >= GGML_CUDA_CC_TURING) { + return get_mmvq_mmid_max_batch_turing_plus(type); + } return get_mmvq_mmid_max_batch_pascal_older(type); } + // AMD - if (GGML_CUDA_CC_IS_RDNA4(cc)) { - return get_mmvq_mmid_max_batch_rdna4(type); - } - if (GGML_CUDA_CC_IS_RDNA3(cc)) { - return get_mmvq_mmid_max_batch_rdna3(type); - } - if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) { - return get_mmvq_mmid_max_batch_rdna1_rdna2(type); - } - if (GGML_CUDA_CC_IS_CDNA(cc)) { - return get_mmvq_mmid_max_batch_cdna(type); - } - if (GGML_CUDA_CC_IS_GCN(cc)) { - return get_mmvq_mmid_max_batch_gcn(type); + if (GGML_CUDA_CC_IS_AMD(cc)) { + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return get_mmvq_mmid_max_batch_rdna4(type); + } + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + return get_mmvq_mmid_max_batch_rdna3(type); + } + if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) { + return get_mmvq_mmid_max_batch_rdna1_rdna2(type); + } + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return get_mmvq_mmid_max_batch_cdna(type); + } + if (GGML_CUDA_CC_IS_GCN(cc)) { + return get_mmvq_mmid_max_batch_gcn(type); + } } return MMVQ_MAX_BATCH_SIZE; } From ace95aac6b32f6e0e57a45789d3ec82c8c89e9ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 1 Apr 2026 16:01:45 +0300 Subject: [PATCH 374/831] ggml : bump version to 0.9.10 (ggml/1454) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index ab558438e95..2ffc3b391fe 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 9) +set(GGML_VERSION_PATCH 10) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) From 981195be5aa6ec359b2e536b4194c0f7a7a3ee20 Mon Sep 17 00:00:00 2001 From: Michael Wand Date: Wed, 1 Apr 2026 03:04:58 -0700 Subject: [PATCH 375/831] ggml-cuda: Add generic NVFP4 MMQ kernel (llama/21074) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Introduced NVFP4 generic MMQ kernel * Added extra FP8 guard, hope to solve ci HIP failure * Rename tiles and use HIP_FP8_AVAILABLE * Removed remaning FP8 straggler and added const int * Const * Removed DECL_MMQ_CASE artifact * Removed newline * Removed space after else * Changed HIP FP8 NVFP4 conversion gate * Added new line to bottom of mmq.cu 270 * Removed extra spaces * Removed single space in front of else on line 814 * Added NVFP4 to generate cu script so HIP can see it, further tightened logic * Include generated mmq-instance-nvfp4.cu * Added NVFP4 mmq to HIP Check ignore list * Update ggml/src/ggml-cuda/mmq.cuh Changed to Q3_K tile to read MMQ_MMA_TILE_X_K_NVFP4 Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/mmq.cuh Changed to Q3_K tile to read MMQ_MMA_TILE_X_K_NVFP4 in tile assert Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/mmq.cuh Added function name ending for end if Co-authored-by: Johannes Gäßler * Added function names to closing endif Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/common.cuh | 27 ++++-- ggml/src/ggml-cuda/ggml-cuda.cu | 2 - ggml/src/ggml-cuda/mmq.cu | 5 +- ggml/src/ggml-cuda/mmq.cuh | 89 +++++++++++++++++-- .../template-instances/generate_cu_files.py | 2 +- .../template-instances/mmq-instance-nvfp4.cu | 5 ++ 6 files changed, 112 insertions(+), 18 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 7d7f20af3a0..9affe023403 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -800,19 +800,32 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { } static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) { -#ifdef FP8_AVAILABLE - const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. -#if defined(GGML_USE_HIP) && defined(CDNA3) - // ROCm dose not support fp8 in software on devices with fp8 hardware, +#if defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000 + // ROCm does not support fp8 in software on devices with fp8 hardware, // but CDNA3 supports only e4m3_fnuz (no inf). + const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast(&bits); + return static_cast(xf) / 2; #else +#if defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP) + const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. const __nv_fp8_e4m3 xf = *reinterpret_cast(&bits); -#endif // defined(GGML_USE_HIP) && defined(GGML_USE_HIP) return static_cast(xf) / 2; #else - NO_DEVICE_CODE; -#endif // FP8_AVAILABLE + if (x == 0 || (x == 0x7F && x != 0xFF)) { // Convert NaN to 0.0f + return 0.0f; + } + const int exp = (x >> 3) & 0xF; + const int man = x & 0x7; + float raw; + if (exp == 0) { + raw = ldexpf((float) man, -9); + } else { + raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7); + } + return static_cast(raw / 2); +#endif // defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP) +#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000 } __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index d1239b1c5f7..75b62129ade 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4791,9 +4791,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: -#ifdef FP8_AVAILABLE case GGML_TYPE_NVFP4: -#endif // FP8_AVAILABLE case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 9a69f41d159..27b4145ac9a 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -23,6 +23,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con case GGML_TYPE_MXFP4: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_NVFP4: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_Q2_K: mul_mat_q_case(ctx, args, stream); break; @@ -273,6 +276,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -362,5 +366,4 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t } return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; - } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 255e59f6fc6..51e8dad4ce7 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -68,6 +68,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_MXFP4: return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_NVFP4: + return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_Q2_K: return MMQ_Q8_1_DS_LAYOUT_D2S6; case GGML_TYPE_Q3_K: @@ -189,6 +191,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; @@ -206,12 +209,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml } } -#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) -#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) -#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) -#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) +#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 +#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 +#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) +#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); @@ -220,6 +224,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4"); +static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding."); + static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { switch (type) { @@ -230,6 +236,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; // tile sizes are the same for Q8_1 and FP4 for blackwell case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; + case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4; case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; @@ -826,6 +833,65 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr } } + +template +static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x, + int * __restrict__ x_tile, + const int kb0, + const int i_max, + const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4; + constexpr int rows_per_warp = warp_size / threads_per_row; + const int kbx = threadIdx.x % threads_per_row; + const int row_in_warp = threadIdx.x / threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) { + int i = i0 + threadIdx.y * rows_per_warp + row_in_warp; + + if constexpr (need_check) { + i = min(i, i_max); + } + + const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx; + const uint32_t * __restrict__ src_qs = reinterpret_cast(bxi->qs); + const int kqs = 16 * kbx; + const int ksc = 4 * kbx; + +#pragma unroll + for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) { + const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4); + const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y; + x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]); +#else + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y; + x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + } +} + template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { @@ -1229,7 +1295,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -// Used for Q3_K, IQ2_S, and IQ2_XS +// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { @@ -3261,6 +3327,14 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; @@ -4069,6 +4143,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); extern DECL_MMQ_CASE(GGML_TYPE_Q5_1); extern DECL_MMQ_CASE(GGML_TYPE_Q8_0); extern DECL_MMQ_CASE(GGML_TYPE_MXFP4); +extern DECL_MMQ_CASE(GGML_TYPE_NVFP4); extern DECL_MMQ_CASE(GGML_TYPE_Q2_K); extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index b7b5832293e..40d51f93fa4 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -35,7 +35,7 @@ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", - "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4" + "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4", "GGML_TYPE_NVFP4" ] SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu new file mode 100644 index 00000000000..2cb140d35a3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_NVFP4); From fab70d287e977d607247218e7e6e85b7f093adf3 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Wed, 1 Apr 2026 18:54:15 +0800 Subject: [PATCH 376/831] sycl : support nvfp4 type in mul_mat (llama/21227) --- ggml/src/ggml-sycl/common.hpp | 7 ++ ggml/src/ggml-sycl/convert.cpp | 18 +++++ ggml/src/ggml-sycl/dequantize.hpp | 32 +++++++++ ggml/src/ggml-sycl/mmvq.cpp | 22 +++++- ggml/src/ggml-sycl/type.hpp | 112 ++++++++++++++++++++++++++++++ ggml/src/ggml-sycl/vecdotq.hpp | 42 +++++++++++ 6 files changed, 232 insertions(+), 1 deletion(-) create mode 100644 ggml/src/ggml-sycl/type.hpp diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index fcb0db99c6b..fd84c917853 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -23,6 +23,7 @@ #include "ggml-impl.h" #include "ggml-sycl.h" #include "presets.hpp" +#include "type.hpp" #include "sycl_hw.hpp" namespace syclexp = sycl::ext::oneapi::experimental; @@ -965,4 +966,10 @@ static T block_reduce(T val, T * shared_vals, int block_size_template) { return val; } +static __dpct_inline__ float ggml_sycl_ue4m3_to_fp32(uint8_t x) { + const uint32_t bits = x * (x != 0x7F && x != 0xFF); + const __nv_fp8_e4m3 xf = *reinterpret_cast(&bits); + return static_cast(xf) / 2; +} + #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index d17aca2cac4..d7f60cbc9ea 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -482,6 +482,18 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t }); } +template +static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + GGML_ASSERT(k % QK_NVFP4 == 0); + const int nb = k / QK_NVFP4; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_nvfp4(vx, y, k); + }); +} + + template static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, @@ -641,6 +653,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return dequantize_row_iq4_nl_sycl; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_sycl; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_sycl; case GGML_TYPE_F32: return convert_unary_sycl; #ifdef GGML_SYCL_HAS_BF16 @@ -648,6 +662,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return convert_unary_sycl; #endif default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); return nullptr; } } @@ -708,6 +723,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return dequantize_row_iq4_nl_sycl; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_sycl; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_sycl; case GGML_TYPE_F16: return convert_unary_sycl; #ifdef GGML_SYCL_HAS_BF16 @@ -715,6 +732,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return convert_unary_sycl; #endif default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); return nullptr; } } diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index da2a605daa8..3272724f41b 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -838,4 +838,36 @@ static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restr } } + +template +static void dequantize_block_nvfp4( + const void * __restrict__ vx, + dst_t * __restrict__ yy, + const int64_t ne) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i = item_ct1.get_group(2); + const int tid = item_ct1.get_local_id(2); + + const int64_t base = i * QK_NVFP4; + if (base >= ne) { + return; + } + + const block_nvfp4 * x = (const block_nvfp4 *) vx; + const block_nvfp4 & xb = x[i]; + + const int sub = tid / (QK_NVFP4_SUB / 2); + const int j = tid % (QK_NVFP4_SUB / 2); + + const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]); + const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j]; + + const int64_t y0 = base + sub * QK_NVFP4_SUB + j; + const int64_t y1 = y0 + QK_NVFP4_SUB / 2; + + yy[y0] = ggml_sycl_cast(d * kvalues_mxfp4[q & 0x0F]); + yy[y1] = ggml_sycl_cast(d * kvalues_mxfp4[q >> 4]); +} + + #endif // GGML_SYCL_DEQUANTIZE_HPP diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 316aa0d0fb5..5abc50fabfe 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -613,6 +613,23 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float } } +static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_NVFP4 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, @@ -1145,8 +1162,11 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_MXFP4: mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; + case GGML_TYPE_NVFP4: + mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + break; default: - GGML_ABORT("fatal error"); + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type)); } } GGML_UNUSED(src1); diff --git a/ggml/src/ggml-sycl/type.hpp b/ggml/src/ggml-sycl/type.hpp new file mode 100644 index 00000000000..d7ff89d7d42 --- /dev/null +++ b/ggml/src/ggml-sycl/type.hpp @@ -0,0 +1,112 @@ +#pragma once + +#include +#include +#include + +inline uint8_t float_to_e4m3(float f) +{ + if (sycl::isnan(f)) { + return 0x7F; // Canonical NaN (positive) + } + + uint32_t bits = sycl::bit_cast(f); + uint32_t sign = (bits >> 31) & 0x1u; + uint32_t exp = (bits >> 23) & 0xFFu; + uint32_t mant = bits & 0x7FFFFFu; + + // Zero + if (exp == 0 && mant == 0) { + return static_cast(sign << 7); + } + + // Extract biased exponent and mantissa for FP8 + int e = static_cast(exp) - 127; // true exponent (IEEE bias 127) + uint32_t m = mant; + + // Handle very large values → NaN (NVIDIA behavior for E4M3) + if (e > 7) { // max exponent for E4M3 is 7 (biased 14) + return static_cast((sign << 7) | 0x7F); + } + + // Handle subnormals and normal numbers + if (e < -6) { // smallest normal exponent is -6 + // Subnormal in FP8: shift mantissa right + int shift = -6 - e; + m = (m | 0x800000u) >> (shift + 1); // +1 because we lose the implicit 1 position + if (shift > 23) m = 0; + } else { + // Normal number: adjust exponent bias from 127 to 7 + int new_exp = e + 7; + m = (m >> 20) & 0x7u; // take top 3 mantissa bits (after implicit 1) + m |= (static_cast(new_exp) << 3); + } + + // Round-to-nearest-even (simple guard + round bit) + // For better accuracy you can add sticky bit, but this is sufficient for most use cases + uint32_t round_bit = (mant >> 19) & 0x1u; // bit after the 3 mantissa bits + if (round_bit) { + m += 1; + // Carry into exponent if mantissa overflows + if ((m & 0x8u) != 0) { + m = (m & 0x7u) | ((m & 0x38u) << 1); // simple carry handling + // If exponent overflows after carry → NaN + if ((m >> 3) > 14) { + return static_cast((sign << 7) | 0x7F); + } + } + } + + uint8_t result = static_cast((sign << 7) | (m & 0x7F)); + return result; +} + +inline float e4m3_to_float(uint8_t x) +{ + if (x == 0) return 0.0f; + + uint8_t sign = (x >> 7) & 0x1u; + uint8_t exp = (x >> 3) & 0xFu; + uint8_t mant = x & 0x7u; + + // NaN (NVIDIA uses 0x7F / 0xFF as NaN) + if (exp == 0xF && mant != 0) { + return std::numeric_limits::quiet_NaN(); + } + if (exp == 0xF) { // 0x7F or 0xFF treated as NaN + return std::numeric_limits::quiet_NaN(); + } + + float val; + + if (exp == 0) { + // Subnormal + val = mant * (1.0f / 8.0f) * sycl::pow(2.0f, -6.0f); + } else { + // Normal: implicit leading 1 + bias 7 + val = (1.0f + mant / 8.0f) * sycl::pow(2.0f, static_cast(exp) - 7.0f); + } + + return sign ? -val : val; +} + +// The actual type definition +struct __nv_fp8_e4m3 { + uint8_t raw; + + __nv_fp8_e4m3() = default; + + explicit __nv_fp8_e4m3(float f) : raw(float_to_e4m3(f)) {} + explicit __nv_fp8_e4m3(sycl::half h) : raw(float_to_e4m3(static_cast(h))) {} + + operator float() const { return e4m3_to_float(raw); } + operator sycl::half() const { return static_cast(static_cast(*this)); } + + // Allow direct access for vector loads/stores + operator uint8_t&() { return raw; } + operator uint8_t() const { return raw; } +}; + +using __nv_fp8x2_e4m3 = sycl::vec<__nv_fp8_e4m3, 2>; +using __nv_fp8x4_e4m3 = sycl::vec<__nv_fp8_e4m3, 4>; + diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 9a267d85a0c..eab9850aed7 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -15,6 +15,7 @@ #include "dpct/helper.hpp" #include "ggml.h" +#include "type.hpp" #include "quants.hpp" typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, @@ -31,6 +32,18 @@ static __dpct_inline__ int get_int_b1(const void * x, const int & i32) { return x32; } +static __dpct_inline__ int get_int_b2(const void * x, const int & i32) { + const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment + + int x32 = x16[2*i32 + 0] << 0; + x32 |= x16[2*i32 + 1] << 16; + + return x32; +} + +static __dpct_inline__ int get_int_b4(const void * x, const int & i32) { + return ((const int *) x)[i32]; // assume at least 4 byte alignment +} static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) { const uint16_t* x16 = @@ -755,6 +768,35 @@ static __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq, return d * sumi; } +#define VDR_NVFP4_Q8_1_MMVQ 4 +#define VDR_NVFP4_Q8_1_MMQ 8 + +static __dpct_inline__ float vec_dot_nvfp4_q8_1(const void * __restrict__ vbq, + const block_q8_1 * __restrict__ bq8_1, + const int32_t & iqs) { + const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq; + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) { + const int32_t iqs0 = iqs + 2*i; + const int32_t iqs1 = iqs0 + 1; + const int32_t is = iqs0 >> 1; + const sycl::int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4); + const sycl::int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4); + const block_q8_1 * bq8 = bq8_1 + (is >> 1); + const int32_t i8 = ((is & 1) << 2); + + int sumi = ggml_sycl_dp4a(v0.x(), get_int_b4(bq8->qs, i8 + 0), 0); + sumi = ggml_sycl_dp4a(v0.y(), get_int_b4(bq8->qs, i8 + 2), sumi); + sumi = ggml_sycl_dp4a(v1.x(), get_int_b4(bq8->qs, i8 + 1), sumi); + sumi = ggml_sycl_dp4a(v1.y(), get_int_b4(bq8->qs, i8 + 3), sumi); + + const float d = ggml_sycl_ue4m3_to_fp32(bq4->d[is]) * (bq8->ds)[0]; + sum += d * float(sumi); + } + + return sum; +} static __dpct_inline__ float vec_dot_q5_0_q8_1(const void *__restrict__ vbq, From 9a40dd9365ac55c16a27200e8db3873dbb4c7cbd Mon Sep 17 00:00:00 2001 From: Aparna M P Date: Wed, 1 Apr 2026 21:13:08 +0530 Subject: [PATCH 377/831] hexagon: improve RMS_NORM and DIV accuracy (llama/21251) * hexagon-rms_norm: fix RMS_NORM for non-aligned tensor sizes Co-authored-by: Krishna Sridhar * hexagon-div: perform DIV in fp16 domain for lower dsp archs --------- Co-authored-by: Krishna Sridhar --- ggml/src/ggml-hexagon/htp/hvx-div.h | 86 ++++++++++++++++++++------- ggml/src/ggml-hexagon/htp/unary-ops.c | 41 ++++++++++--- 2 files changed, 97 insertions(+), 30 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/hvx-div.h b/ggml/src/ggml-hexagon/htp/hvx-div.h index 05cefea039f..53ee304e749 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-div.h +++ b/ggml/src/ggml-hexagon/htp/hvx-div.h @@ -16,8 +16,10 @@ #if __HVX_ARCH__ < 79 #define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#define HVX_OP_MUL_F16(a, b) Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b)) #else #define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#define HVX_OP_MUL_F16(a, b) Q6_Vhf_vmpy_VhfVhf(a, b) #endif // Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32. @@ -43,46 +45,67 @@ static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX return res; } -#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \ - do { \ - dst_type * restrict vdst = (dst_type *) dst; \ - src_type * restrict vsrc = (src_type *) src; \ - HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ - \ - const uint32_t nvec = n / VLEN_FP16; \ - const uint32_t nloe = n % VLEN_FP16; \ - \ - uint32_t i = 0; \ - \ - _Pragma("unroll(4)") \ - for (; i < nvec; i++) { \ - HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ - vdst[i] = res; \ - } \ - if (nloe) { \ - HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ - vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ - } \ +// Variant for =v79 +static inline HVX_Vector hvx_vec_hybrid_div_f16(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector f16_nan_inf_mask, HVX_Vector vec_hf_one_1_0) { +#if __HVX_ARCH__ < 79 + // For older architectures, use f16 reciprocal to avoid NaN/-inf issues + HVX_Vector vec2_inv = hvx_vec_inverse_f16_guard(vec2, f16_nan_inf_mask); + return HVX_OP_MUL_F16(vec1, vec2_inv); +#else + return hvx_vec_div_f16_using_f32(vec1, vec2, f32_nan_inf_mask, vec_hf_one_1_0); +#endif +} + #define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ src0_type * restrict vsrc0 = (src0_type *) src0; \ src1_type * restrict vsrc1 = (src1_type *) src1; \ \ - const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + const HVX_Vector f32_nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + const HVX_Vector f16_nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \ const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ \ const uint32_t nvec = n / VLEN_FP16; \ @@ -144,11 +179,15 @@ static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector v \ _Pragma("unroll(4)") \ for (; i < nvec; i++) { \ - HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \ + HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \ + f32_nan_inf_mask, f16_nan_inf_mask, \ + hf_one); \ vdst[i] = res; \ } \ if (nloe) { \ - HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \ + HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \ + f32_nan_inf_mask, f16_nan_inf_mask, \ + hf_one); \ vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ } \ } while(0) @@ -247,5 +286,6 @@ HVX_DIV_DISPATCHER(hvx_div_f32) HVX_DIV_DISPATCHER(hvx_div_f16) #undef HVX_OP_MUL_F32 +#undef HVX_OP_MUL_F16 #endif // HVX_DIV_H diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 3d0928d4dce..13d28317d5c 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -67,34 +67,61 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, uint8_t * restrict pad, const int num_elems, float epsilon) { + (void)pad; + const HVX_Vector * restrict v_src = (HVX_Vector *) src; HVX_Vector * restrict v_dst = (HVX_Vector *) dst; - HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares for full vectors + HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); - int step_of_1 = num_elems >> 5; #pragma unroll(4) - for (int i = 0; i < step_of_1; i++) { + for (int i = 0; i < nvec; i++) { HVX_Vector v1 = v_src[i]; HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); - sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); } - sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes + // Reduce HVX sum + sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v); HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v); + // Scale full vectors HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v)); #pragma unroll(4) - for (int i = 0; i < step_of_1; i++) { + for (int i = 0; i < nvec; i++) { HVX_Vector v1 = v_src[i]; HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); - v_dst[i] = Q6_Vsf_equals_Vqf32(v2); + v_dst[i] = Q6_Vsf_equals_Vqf32(v2); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + HVX_Vector result = Q6_Vsf_equals_Vqf32(v2); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, result); } } From 82bb26fba1b4de5180009ae5a2a20537efba8ee7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 1 Apr 2026 21:28:19 +0200 Subject: [PATCH 378/831] CUDA: fix FA kernel selection logic (llama/21271) --- ggml/src/ggml-cuda/fattn.cu | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index a21c5361048..addf93205ef 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -340,7 +340,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const case 128: case 112: case 256: + if (V->ne[0] != K->ne[0]) { + return BEST_FATTN_KERNEL_NONE; + } + break; case 512: + if (V->ne[0] != K->ne[0]) { + return BEST_FATTN_KERNEL_NONE; + } if (!gqa_opt_applies) { return BEST_FATTN_KERNEL_NONE; } From 08108512c7c3ae2610d1e5f36c80cd7d3a753987 Mon Sep 17 00:00:00 2001 From: lhez Date: Wed, 1 Apr 2026 12:54:58 -0700 Subject: [PATCH 379/831] opencl: fix leak in Adreno q8_0 path (llama/21212) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0f6628c377d..6f3fc5886d8 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -9612,6 +9612,9 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t cl_mem B_image1d; cl_mem B_sub_buffer; cl_mem S_image1d; + // for B transpose + cl_mem B_image1d_trans = nullptr; + cl_mem B_d = nullptr; cl_mem D_image1d; cl_mem D_sub_buffer; @@ -9703,9 +9706,6 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t global_work_size[2] = 1; } else { cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_mem B_image1d_trans = nullptr; - // for B transpose - cl_mem B_d = nullptr; int padding; //how many extra elements beyond multiple of 8 @@ -9800,6 +9800,12 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t CL_CHECK(clReleaseMemObject(S_image1d)); CL_CHECK(clReleaseMemObject(D_sub_buffer)); CL_CHECK(clReleaseMemObject(D_image1d)); + if (B_image1d_trans) { + CL_CHECK(clReleaseMemObject(B_image1d_trans)); + } + if (B_d) { + CL_CHECK(clReleaseMemObject(B_d)); + } #else GGML_UNUSED(backend); GGML_UNUSED(src0); From 444662bc8307fc7a5d49acde48ae32e3c51b280b Mon Sep 17 00:00:00 2001 From: Todor Boinovski Date: Wed, 1 Apr 2026 17:44:02 -0700 Subject: [PATCH 380/831] hexagon : add cumsum op support (llama/21246) * hexagon : add cumsum op support * hexagon: enable dma for cumsum op * Fix line-ending --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 34 +++ ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/cumsum-ops.c | 267 +++++++++++++++++++++++ ggml/src/ggml-hexagon/htp/htp-msg.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 43 ++++ 6 files changed, 347 insertions(+) create mode 100644 ggml/src/ggml-hexagon/htp/cumsum-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index dd604db4333..f91bc46552e 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2231,6 +2231,22 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * return true; } +static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + enum dspqbuf_type { DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0, DSPQBUF_TYPE_CPU_WRITE_DSP_READ, @@ -2399,6 +2415,16 @@ static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bu return n_bufs; } +static inline size_t init_cumsum_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_CUMSUM; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { req->op = HTP_OP_GET_ROWS; @@ -2780,6 +2806,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_CUMSUM: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); } @@ -3254,6 +3284,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_ssm_conv(sess, op); break; + case GGML_OP_CUMSUM: + supp = ggml_hexagon_supported_cumsum(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 6ddfe4252f5..2b60f427ada 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -33,6 +33,7 @@ add_library(${HTP_LIB} SHARED repeat-ops.c argsort-ops.c ssm-conv.c + cumsum-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/cumsum-ops.c b/ggml/src/ggml-hexagon/htp/cumsum-ops.c new file mode 100644 index 00000000000..ce51555a7fd --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/cumsum-ops.c @@ -0,0 +1,267 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hvx-utils.h" +#include "hex-dma.h" + +#define htp_cumsum_tensors_preamble \ + struct htp_tensor * restrict src0 = &octx->src0; \ + struct htp_tensor * restrict dst = &octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_cumsum_context { + struct htp_ops_context * octx; + size_t src_row_size; + size_t dst_row_size; + size_t src_row_size_aligned; + size_t dst_row_size_aligned; + uint32_t rows_per_thread; + uint32_t total_rows; +}; + +#define htp_cumsum_preamble \ + struct htp_cumsum_context * cctx = (struct htp_cumsum_context *) data; \ + struct htp_ops_context * octx = cctx->octx; \ + htp_cumsum_tensors_preamble; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; + +// --------------------------------------------------------------------------- +// HVX prefix scan helpers +// --------------------------------------------------------------------------- + +#if __HVX_ARCH__ > 75 +static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vadd_VsfVsf(a, b); +} +#else +static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)); +} +#endif // __HVX_ARCH__ > 75 + +static inline HVX_Vector hvx_prefix_scan_f32(HVX_Vector v, HVX_Vector carry_in) { + const HVX_Vector zero = Q6_V_vsplat_R(0); + + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 4)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 8)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 16)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 32)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 64)); + v = hvx_cumsum_vadd(v, carry_in); + + return v; +} + +static inline HVX_Vector hvx_splat_last_f32(HVX_Vector v) { + return hvx_vec_repl4(Q6_V_vror_VR(v, 124)); +} + +static inline void hvx_cumsum_row_f32(const float * restrict src, float * restrict dst, uint32_t n) { + const uint32_t nvec = n / VLEN_FP32; + const uint32_t nloe = n % VLEN_FP32; + + HVX_Vector carry = Q6_V_vsplat_R(0); + + for (uint32_t i = 0; i < nvec; i++) { + HVX_Vector v = *((const HVX_UVector *) (src + i * VLEN_FP32)); + v = hvx_prefix_scan_f32(v, carry); + hvx_vec_store_u(dst + i * VLEN_FP32, VLEN, v); + carry = hvx_splat_last_f32(v); + } + + if (nloe) { + float acc = hvx_vec_get_f32(carry); + const float * src_tail = src + nvec * VLEN_FP32; + float * dst_tail = dst + nvec * VLEN_FP32; + for (uint32_t i = 0; i < nloe; i++) { + acc += src_tail[i]; + dst_tail[i] = acc; + } + } +} + +// --------------------------------------------------------------------------- +// Per thread worker: Double-buffered DMA +// --------------------------------------------------------------------------- + +static void cumsum_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) { + htp_cumsum_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ir0 = cctx->rows_per_thread * ith; + const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows); + + if (ir0 >= ir1) { + return; + } + + const size_t src_row_size = cctx->src_row_size; + const size_t dst_row_size = cctx->dst_row_size; + const size_t src_row_size_aligned = cctx->src_row_size_aligned; + const size_t dst_row_size_aligned = cctx->dst_row_size_aligned; + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + uint8_t * src_spad = octx->src0_spad.data + (ith * src_row_size_aligned * 2); + uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned * 2); + + for (uint32_t ir = ir0, spad_idx = 0; ir < ir1 && spad_idx < 2; ir++, spad_idx++) { + // Dummy dst writeback to establish queue ordering + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_data, dst_spad + (spad_idx * dst_row_size_aligned)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src_spad + (spad_idx * src_row_size_aligned), + src_data + (ir * src_row_size)), + src_row_size_aligned, src_row_size, 1); + } + + for (uint32_t ir = ir0; ir < ir1; ir++) { + float * dst_spad_row = (float *) dma_queue_pop(dma_queue).src; + float * src_spad_row = (float *) dma_queue_pop(dma_queue).dst; + + hvx_cumsum_row_f32(src_spad_row, dst_spad_row, ne00); + + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_data + (ir * dst_row_size), (uint8_t *) dst_spad_row), + dst_row_size, dst_row_size_aligned, 1); + + const uint32_t next_row = ir + 2; + if (next_row < ir1) { + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr((uint8_t *) src_spad_row, src_data + (next_row * src_row_size)), + src_row_size_aligned, src_row_size, 1); + } + } + + dma_queue_flush(dma_queue); + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "cumsum-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// Per thread worker: Direct HVX (no DMA) +// --------------------------------------------------------------------------- + +static void cumsum_thread_f32(unsigned int nth, unsigned int ith, void * data) { + htp_cumsum_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + const uint32_t ir0 = cctx->rows_per_thread * ith; + const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows); + + for (uint32_t ir = ir0; ir < ir1; ir++) { + const float * restrict src_row = (const float *) (src_data + ir * cctx->src_row_size); + float * restrict dst_row = (float *) (dst_data + ir * cctx->dst_row_size); + hvx_cumsum_row_f32(src_row, dst_row, ne00); + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "cumsum-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_cumsum_f32(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * dst = &octx->dst; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t total_rows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_rows); + + const size_t src_row_size = src0->nb[1]; + const size_t dst_row_size = dst->nb[1]; + const size_t src_row_size_aligned = hex_round_up(src_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // 2 ping-pong buffers per thread for src and dst + const size_t spad_per_thread = 2 * (src_row_size_aligned + dst_row_size_aligned); + + octx->src0_spad.size_per_thread = src_row_size_aligned * 2; + octx->dst_spad.size_per_thread = dst_row_size_aligned * 2; + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + struct htp_cumsum_context cctx = { + .octx = octx, + .src_row_size = src_row_size, + .dst_row_size = dst_row_size, + .src_row_size_aligned = src_row_size_aligned, + .dst_row_size_aligned = dst_row_size_aligned, + .rows_per_thread = (total_rows + n_threads - 1) / n_threads, + .total_rows = total_rows, + }; + + if (octx->ctx->vtcm_size < spad_per_thread * n_threads) { + worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32, &cctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32_dma, &cctx, n_threads); + } + + return HTP_STATUS_OK; +} + +int op_cumsum(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + struct htp_tensor * dst = &octx->dst; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_cumsum_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index 391148be0e9..df0ea7ccbd6 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -75,6 +75,7 @@ enum htp_op { HTP_OP_SUM_ROWS, HTP_OP_SSM_CONV, HTP_OP_REPEAT, + HTP_OP_CUMSUM, INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index f643fdc340d..d35decaac20 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -60,5 +60,6 @@ int op_cpy(struct htp_ops_context * octx); int op_repeat(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); +int op_cumsum(struct htp_ops_context * octx); #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 49f34b5f7d1..6f37bf9d4b8 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -860,6 +860,41 @@ static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_cumsum_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We've written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_cumsum(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_activations_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs, @@ -1474,6 +1509,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_ssm_conv_req(ctx, &req, bufs); break; + case HTP_OP_CUMSUM: + if (n_bufs != 2) { + FARF(ERROR, "Bad cumsum-req buffer list"); + continue; + } + proc_cumsum_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; From 514eabc1e5c67a32d2cc5990bf729af0f9802be1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 2 Apr 2026 10:37:26 +0300 Subject: [PATCH 381/831] ggml : bump version to 0.9.11 (ggml/1456) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 2ffc3b391fe..5834e544b48 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 10) +set(GGML_VERSION_PATCH 11) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) From 7f6c0ac20f09ed85a3b00c4bb0665a2a091ed770 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Thu, 2 Apr 2026 15:08:32 +0800 Subject: [PATCH 382/831] sycl : fix llama_kv_cache hang when kv_cache is huge: 5GB (llama/21283) --- ggml/src/ggml-sycl/ggml-sycl.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 456b1699fa3..28be4939784 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -569,9 +569,15 @@ static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer, SYCL_CHECK( CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw())); - SYCL_CHECK(CHECK_TRY_ERROR((*stream) - .memset(ctx->dev_ptr, value, buffer->size) - .wait())); + constexpr size_t MAX_CHUNK = 2ULL << 30; // 2 GiB + for (size_t off = 0; off < buffer->size; off += MAX_CHUNK) { + size_t chunk = std::min(buffer->size - off, MAX_CHUNK); + SYCL_CHECK(CHECK_TRY_ERROR( + (*stream) + .memset(static_cast(ctx->dev_ptr) + off, value, chunk) + .wait() + )); + } } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ From c5a5e6528ec6002cd1d84f7a11c42255f4550044 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Thu, 2 Apr 2026 10:40:42 -0700 Subject: [PATCH 383/831] ggml-webgpu: add vectorized flash attention (llama/20709) * naive vectorized version * add vectorized flash attention * update vec version * remove unused path and shader * remove unused helper functions * add comments * remove pad path * ggml-webgpu: fix flash-attn vec nwg=1 path and tighten vec specialization * change back to vec4 * enable multi split * enable vec path when: - Q->ne[1] < 20 - Q->ne[0] % 32 == 0 - V->ne[0] % 4 == 0 - K->type == f16 * update flast_attn_vec_split.wgsl to reduce redundant workgroup barrier usage and use select * enable vec path for q4 and q8 * flash-attn vec nwg=1 fast path (skip tmp/reduce staging) * use packed f16 K loads in flash-attn vec split * use packed f16 K loads in flash-attn vec split on host side * tune flash-attn vec f16 VEC_NE by head dim * cleanup * cleanup * keep host side clean * cleanup host side * change back to original host wait/submit behavior * formatting * reverted param-buffer pool r ecfactor * add helper functions * ggml-webgpu: move flash-attn vec pipeline caching back into shader lib * ggml-webgpu: remove duplicate functions * ggml-webgpu: reserve flash-attn vec scratch in dst buffer allocation * ggml-webgpu: revert unrelated change * ggml-webgpu: revert deleted comment * disable uniformity check * remove unnecessary change * Update ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl * Update ggml/src/ggml-webgpu/ggml-webgpu.cpp --------- Co-authored-by: Reese Levine --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 230 +++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 323 +++++++- .../wgsl-shaders/flash_attn_vec_blk.wgsl | 105 +++ .../wgsl-shaders/flash_attn_vec_reduce.wgsl | 78 ++ .../wgsl-shaders/flash_attn_vec_split.wgsl | 729 ++++++++++++++++++ 5 files changed, 1412 insertions(+), 53 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index a194ce84e25..1c56c689312 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -95,6 +95,12 @@ struct ggml_webgpu_generic_shader_decisions { uint32_t wg_size = 0; }; +struct ggml_webgpu_processed_shader { + std::string wgsl; + std::string variant; + std::shared_ptr decisions; +}; + struct ggml_webgpu_ssm_conv_shader_decisions { uint32_t block_size; uint32_t tokens_per_wg; @@ -384,11 +390,12 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_mask; bool has_sinks; bool uses_logit_softcap; + bool use_vec; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap; + uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec; } }; @@ -402,6 +409,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.has_mask); ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); + ggml_webgpu_hash_combine(seed, key.use_vec); return seed; } }; @@ -421,6 +429,115 @@ struct ggml_webgpu_flash_attn_shader_decisions { uint32_t wg_size = 0; }; +inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { + // Keep conservative defaults unless this is the f16 vec-split shape family. + if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) { + return 1u; + } + + // Head-dim specializations used by the tuned vec f16 path. + switch (key.head_dim_qk) { + case 64: return 2u; + case 96: return 4u; + case 128: return 1u; + case 192: return 2u; + case 576: return 2u; + default: return 1u; + } +} + +struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { + uint32_t head_dim_v; + uint32_t wg_size; +}; + +struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.wg_size); + return seed; + } +}; + +inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs, + const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) { + return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size; +} + +struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context { + ggml_webgpu_flash_attn_vec_reduce_pipeline_key key; + uint32_t max_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { + std::vector defines; + std::string variant = "flash_attn_vec_reduce"; + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + variant += std::string("_wg") + std::to_string(context.max_wg_size); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + return result; +} + +struct ggml_webgpu_flash_attn_blk_pipeline_key { + uint32_t q_tile; + uint32_t kv_tile; + + bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { + return q_tile == other.q_tile && kv_tile == other.kv_tile; + } +}; + +struct ggml_webgpu_flash_attn_blk_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.q_tile); + ggml_webgpu_hash_combine(seed, key.kv_tile); + return seed; + } +}; + +struct ggml_webgpu_flash_attn_blk_shader_lib_context { + ggml_webgpu_flash_attn_blk_pipeline_key key; + uint32_t max_wg_size; +}; + +inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader( + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { + std::vector defines; + std::string variant = "flash_attn_vec_blk"; + + defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile)); + variant += std::string("_qt") + std::to_string(context.key.q_tile); + + defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile)); + variant += std::string("_kvt") + std::to_string(context.key.kv_tile); + + uint32_t wg_size = 1; + while ((wg_size << 1) <= context.max_wg_size) { + wg_size <<= 1; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + variant += std::string("_wg") + std::to_string(wg_size); + + ggml_webgpu_processed_shader result; + result.wgsl = preprocessor.preprocess(shader_src, defines); + result.variant = variant; + return result; +} + // This is exposed because it's necessary in supports_op inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t kv_tile, @@ -659,6 +776,14 @@ class ggml_webgpu_shader_lib { repeat_pipelines; // type std::unordered_map flash_attn_pipelines; + std::unordered_map + flash_attn_vec_reduce_pipelines; + std::unordered_map + flash_attn_blk_pipelines; std::unordered_map @@ -1673,24 +1798,8 @@ class ggml_webgpu_shader_lib { return repeat_pipelines[key]; } - webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { - const bool has_mask = context.src3 != nullptr; - const bool has_sinks = context.src4 != nullptr; - - bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) && - (context.src1->ne[1] % context.sg_mat_n == 0); - - ggml_webgpu_flash_attn_pipeline_key key = { - .kv_type = context.src1->type, - .head_dim_qk = (uint32_t) context.src0->ne[0], - .head_dim_v = (uint32_t) context.src2->ne[0], - .kv_direct = kv_direct, - .has_mask = has_mask, - .has_sinks = has_sinks, - .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f, - }; - - auto it = flash_attn_pipelines.find(key); + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) { + auto it = flash_attn_pipelines.find(context.key); if (it != flash_attn_pipelines.end()) { return it->second; } @@ -1698,7 +1807,7 @@ class ggml_webgpu_shader_lib { std::vector defines; std::string variant = "flash_attn"; - switch (key.kv_type) { + switch (context.key.kv_type) { case GGML_TYPE_F32: defines.push_back("KV_F32"); break; @@ -1714,41 +1823,52 @@ class ggml_webgpu_shader_lib { default: GGML_ABORT("Unsupported KV type for flash attention shader"); } - variant += std::string("_") + ggml_type_name(key.kv_type); + variant += std::string("_") + ggml_type_name(context.key.kv_type); - if (key.has_mask) { + if (context.key.has_mask) { defines.push_back("MASK"); variant += "_mask"; } - if (key.has_sinks) { + if (context.key.has_sinks) { defines.push_back("SINKS"); variant += "_sinks"; } - if (key.uses_logit_softcap) { + if (context.key.uses_logit_softcap) { defines.push_back("LOGIT_SOFTCAP"); variant += "_lgsc"; } - if (key.kv_direct) { + if (context.key.kv_direct) { defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } + if (context.key.has_mask && context.key.use_vec) { + defines.push_back("BLK"); + variant += "_blk"; + } - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk); - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(key.head_dim_v); + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - uint32_t q_tile = context.sg_mat_m; + uint32_t q_tile = context.sg_mat_m; uint32_t kv_tile = - std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k, - context.wg_mem_limit_bytes, context.max_subgroup_size }), + std::min(ggml_webgpu_flash_attn_max_kv_tile(context), context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (key.kv_direct) { + if (context.key.use_vec) { + q_tile = 1; + kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context))); + kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; + const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key); + defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); + } + if (context.key.kv_direct) { + GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { kv_tile -= context.sg_mat_n; } @@ -1757,19 +1877,51 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); - uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + uint32_t wg_size = 0; + if (context.key.use_vec) { + wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + } else { + wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - auto processed = preprocessor.preprocess(wgsl_flash_attn, defines); + const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn; + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); auto decisions = std::make_shared(); decisions->q_tile = q_tile; decisions->kv_tile = kv_tile; decisions->wg_size = wg_size; + pipeline.context = decisions; + flash_attn_pipelines[context.key] = pipeline; + return flash_attn_pipelines[context.key]; + } + + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { + auto it = flash_attn_blk_pipelines.find(context.key); + if (it != flash_attn_blk_pipelines.end()) { + return it->second; + } + + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); + flash_attn_blk_pipelines[context.key] = pipeline; + return flash_attn_blk_pipelines[context.key]; + } + + webgpu_pipeline get_flash_attn_vec_reduce_pipeline( + const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { + auto it = flash_attn_vec_reduce_pipelines.find(context.key); + if (it != flash_attn_vec_reduce_pipelines.end()) { + return it->second; + } - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; - flash_attn_pipelines[key] = pipeline; - return flash_attn_pipelines[key]; + ggml_webgpu_processed_shader processed = + ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); + flash_attn_vec_reduce_pipelines[context.key] = pipeline; + return flash_attn_vec_reduce_pipelines[context.key]; } webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1aa15b0507c..e53281bfbbd 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -658,7 +658,6 @@ static webgpu_command ggml_backend_webgpu_build_multi( for (size_t i = 0; i < params_bufs_list.size(); i++) { ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); } - #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { @@ -1481,7 +1480,6 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); } -#ifndef __EMSCRIPTEN__ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * Q, ggml_tensor * K, @@ -1565,30 +1563,248 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = Q, - .src1 = K, - .src2 = V, - .src3 = mask, - .src4 = sinks, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); + const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); + const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); + + const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned && + (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + + const bool kv_vec_type_supported = + K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && + (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); + const uint32_t vec_nwg_cap = + std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + const bool use_blk = use_vec && has_mask; + + ggml_webgpu_flash_attn_pipeline_key key = { + .kv_type = K->type, + .head_dim_qk = (uint32_t) Q->ne[0], + .head_dim_v = (uint32_t) V->ne[0], + .kv_direct = kv_direct, + .has_mask = static_cast(has_mask), + .has_sinks = static_cast(has_sinks), + .uses_logit_softcap = logit_softcap != 0.0f, + .use_vec = use_vec, + }; + + ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { + .key = key, .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, }; - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + + wgpu::Buffer blk_buf = {}; + uint64_t blk_size_bytes = 0; + uint32_t blk_nblk0 = 0; + uint32_t blk_nblk1 = 0; + uint32_t blk_batch_count = 0; + + if (use_vec) { + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); + while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { + nwg <<= 1; + } + nwg = std::min(nwg, vec_nwg_cap); + GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size); + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + const bool use_vec_reduce = nwg > 1u; + GGML_ASSERT(nrows <= UINT32_MAX); + + uint64_t tmp_stats_base = 0; + uint64_t tmp_size_bytes = 0; + wgpu::Buffer tmp_buf = {}; + uint64_t tmp_bind_offset = 0; + uint64_t tmp_bind_size = 0; + const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); + + if (use_vec_reduce) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + tmp_stats_base = tmp_data_elems; + tmp_size_bytes = + ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + GGML_ASSERT(tmp_stats_base <= UINT32_MAX); + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = scratch_offset; + tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); + } else { + // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); + tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); + } + + webgpu_pipeline blk_pipeline; + std::vector blk_params; + std::vector blk_entries; + if (use_blk) { + GGML_ASSERT(has_mask); + + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); + blk_buf = ggml_webgpu_tensor_buf(dst); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = { + .key = + { + .q_tile = decisions->q_tile, + .kv_tile = decisions->kv_tile, + }, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); + + blk_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) K->ne[1], // seq_len_kv + stride_mask3, // stride_mask3 + blk_nblk0, // nblk0 + blk_nblk1, // nblk1 + }; + blk_entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(mask), + .offset = ggml_webgpu_tensor_align_offset(ctx, mask), + .size = ggml_webgpu_tensor_binding_size(ctx, mask) }, + { .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes }, + }; + scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); + } + + std::vector split_params = params; + if (use_blk) { + split_params.push_back(0u); // blk_base + split_params.push_back(blk_nblk0); // blk_nblk0 + split_params.push_back(blk_nblk1); // blk_nblk1 + } + split_params.push_back(0u); // tmp_data_base + split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base + split_params.push_back(nwg); // nwg + + std::vector split_entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(Q), + .offset = ggml_webgpu_tensor_align_offset(ctx, Q), + .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(K), + .offset = ggml_webgpu_tensor_align_offset(ctx, K), + .size = ggml_webgpu_tensor_binding_size(ctx, K) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(V), + .offset = ggml_webgpu_tensor_align_offset(ctx, V), + .size = ggml_webgpu_tensor_binding_size(ctx, V) }, + }; + uint32_t split_binding_index = 3; + if (has_mask) { + split_entries.push_back({ .binding = split_binding_index++, + .buffer = ggml_webgpu_tensor_buf(mask), + .offset = ggml_webgpu_tensor_align_offset(ctx, mask), + .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); + } + if (has_sinks) { + split_entries.push_back({ .binding = split_binding_index++, + .buffer = ggml_webgpu_tensor_buf(sinks), + .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), + .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); + } + if (use_blk) { + split_entries.push_back( + { .binding = split_binding_index++, .buffer = blk_buf, .offset = blk_entries[1].offset, .size = blk_size_bytes }); + } + split_entries.push_back( + { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size }); + split_entries.push_back({ .binding = split_binding_index++, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + + webgpu_pipeline reduce_pipeline; + std::vector reduce_params; + std::vector reduce_entries; + if (use_vec_reduce) { + const uint32_t reduce_wg_size = std::max( + 32u, + std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = { + .key = + { + .head_dim_v = (uint32_t) V->ne[0], + .wg_size = reduce_wg_size, + }, + .max_wg_size = reduce_wg_size, + }; + reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); + + reduce_params = { + (uint32_t) nrows, // nrows + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) Q->ne[2], // n_heads + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst + nwg, // nwg + 0u, // tmp_data_base + (uint32_t) tmp_stats_base, // tmp_stats_base + }; + + reduce_entries = { + { .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + }; + } + + const uint64_t split_wg_total = (uint64_t) wg_x * nwg; + GGML_ASSERT(split_wg_total <= UINT32_MAX); + std::vector pipelines; + std::vector> params_list; + std::vector> entries_list; + std::vector> workgroups_list; + + if (use_blk) { + pipelines.push_back(blk_pipeline); + params_list.push_back(std::move(blk_params)); + entries_list.push_back(std::move(blk_entries)); + workgroups_list.push_back({ blk_nblk0, blk_nblk1 * blk_batch_count }); + } + pipelines.push_back(pipeline); + params_list.push_back(std::move(split_params)); + entries_list.push_back(std::move(split_entries)); + workgroups_list.push_back({ (uint32_t) split_wg_total, 1u }); + if (use_vec_reduce) { + pipelines.push_back(reduce_pipeline); + params_list.push_back(std::move(reduce_params)); + entries_list.push_back(std::move(reduce_entries)); + workgroups_list.push_back({ (uint32_t) nrows, 1u }); + } + + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, + entries_list, workgroups_list); + } + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } -#endif static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; @@ -2559,7 +2775,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str std::vector subs; uint32_t num_batched_kernels = 0; bool contains_set_rows = false; - for (int i = 0; i < cgraph->n_nodes; i++) { if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; @@ -2834,6 +3049,86 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const ggml_tensor * Q = tensor->src[0]; + const ggml_tensor * K = tensor->src[1]; + const ggml_tensor * V = tensor->src[2]; + const ggml_tensor * mask = tensor->src[3]; + const ggml_tensor * sinks = tensor->src[4]; + if (Q && K && V) { + GGML_UNUSED(sinks); + const bool kv_direct = (K->type == GGML_TYPE_F16) && + (Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + const bool kv_vec_type_supported = + K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + const bool use_vec = + (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && + (V->type == K->type); + if (use_vec) { + const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; + const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; + const size_t limit_bytes = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + const size_t q_tile = sg_mat_m; + const size_t base_q_bytes = + (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!kv_direct) { + bytes_per_kv += std::max(Q->ne[0], V->ne[0]); + } + if (mask != nullptr) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + uint32_t kv_tile = + ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n; + kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile)); + kv_tile = (kv_tile / sg_mat_n) * sg_mat_n; + if (kv_direct) { + GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= sg_mat_n; + } + } + + const uint32_t vec_nwg_cap = std::max( + 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); + while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { + nwg <<= 1; + } + nwg = std::min(nwg, vec_nwg_cap); + + const size_t align = ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + if (nwg > 1u) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + const size_t tmp_size_bytes = ROUNDUP_POW2( + (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + res += tmp_size_bytes + align; + } + if (mask != nullptr) { + const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); + const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); + const uint32_t stride_mask3 = + (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + const size_t blk_size_bytes = + ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + res += blk_size_bytes + align; + } + res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + } + break; default: break; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl new file mode 100644 index 00000000000..82d072be73a --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl @@ -0,0 +1,105 @@ +diagnostic(off, subgroup_uniformity); +enable f16; + +#define Q_TILE 1 +#define KV_TILE 32 +#define WG_SIZE 32 + +struct Params { + offset_mask: u32, + seq_len_q: u32, + seq_len_kv: u32, + stride_mask3: u32, + // Number of KV blocks and Q blocks per batch. + // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE). + nblk0: u32, + nblk1: u32, +}; + +@group(0) @binding(0) var mask: array; +@group(0) @binding(1) var blk: array; +@group(0) @binding(2) var params: Params; + +const MASK_MIN: f32 = -65504.0; +const MASK_MAX: f32 = 65504.0; +var wg_min: array; +var wg_max: array; +var wg_any: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3) { + // Dispatch mapping: + // - x indexes KV blocks + // - y flattens (batch_idx, q_blk) as y = batch_idx * nblk1 + q_blk + let kv_blk = wg_id.x; + let y = wg_id.y; + let q_blk = y % params.nblk1; + let batch_idx = y / params.nblk1; + if (kv_blk >= params.nblk0) { + return; + } + + let q_start = q_blk * Q_TILE; + let k_start = kv_blk * KV_TILE; + + let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u); + let mask_batch_base = params.offset_mask + mask_batch * params.stride_mask3; + + // We keep min/max to classify: + // - fully masked (max <= MASK_MIN) + // - all-zero mask (min == 0 && max == 0) + // - mixed/general mask + var local_min = MASK_MAX; + var local_max = -MASK_MAX; + var local_any = 0u; + + for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) { + let q_row = q_start + q_rel; + if (q_row >= params.seq_len_q) { + continue; + } + let row_base = mask_batch_base + q_row * params.seq_len_kv; + for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) { + let k_col = k_start + k_rel; + if (k_col >= params.seq_len_kv) { + continue; + } + let mv = f32(mask[row_base + k_col]); + local_min = min(local_min, mv); + local_max = max(local_max, mv); + local_any = 1u; + } + } + + wg_min[local_id.x] = local_min; + wg_max[local_id.x] = local_max; + wg_any[local_id.x] = local_any; + workgroupBarrier(); + + // Thread 0 writes one state per block. + if (local_id.x == 0u) { + var mmin = wg_min[0]; + var mmax = wg_max[0]; + var many = wg_any[0]; + for (var i = 1u; i < WG_SIZE; i += 1u) { + mmin = min(mmin, wg_min[i]); + mmax = max(mmax, wg_max[i]); + many = max(many, wg_any[i]); + } + + var state = 0u; + if (many != 0u) { + if (mmax <= MASK_MIN) { + state = 0u; + } else if (mmin == 0.0 && mmax == 0.0) { + state = 2u; + } else { + state = 1u; + } + } + + let blk_idx = (batch_idx * params.nblk1 + q_blk) * params.nblk0 + kv_blk; + blk[blk_idx] = state; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl new file mode 100644 index 00000000000..9a0de82a56a --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl @@ -0,0 +1,78 @@ +diagnostic(off, subgroup_uniformity); +enable f16; +enable subgroups; + +// Default values +#define HEAD_DIM_V 64 +#define WG_SIZE 128 + +struct Params { + nrows: u32, + seq_len_q: u32, + n_heads: u32, + offset_dst: u32, + nwg: u32, + tmp_data_base: u32, + tmp_stats_base: u32, +}; + +@group(0) @binding(0) var tmp: array; +@group(0) @binding(1) var dst: array>; +@group(0) @binding(2) var params: Params; + +const FLOAT_MIN: f32 = -1.0e9; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + let rid = wg_id.x; + if (rid >= params.nrows) { + return; + } + + let rows_per_batch = params.n_heads * params.seq_len_q; + let batch_idx = rid / rows_per_batch; + let rem = rid % rows_per_batch; + let head_idx = rem / params.seq_len_q; + let q_row = rem % params.seq_len_q; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + let row_base = params.offset_dst + batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V; + + let thread = sg_inv_id; + if (params.nwg > subgroup_size) { + return; + } + + let stats_base = params.tmp_stats_base + rid * (2u * params.nwg); + let active_thread = thread < params.nwg; + let si = select(0.0, tmp[stats_base + 2u * thread + 0u], active_thread); + let mi = select(FLOAT_MIN, tmp[stats_base + 2u * thread + 1u], active_thread); + let m = subgroupMax(mi); + let ms = select(0.0, exp(mi - m), active_thread); + let s = subgroupAdd(si * ms); + let inv_s = select(0.0, 1.0 / s, s != 0.0); + + let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg); + for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) { + var weighted = vec4(0.0, 0.0, 0.0, 0.0); + if (active_thread) { + let src = row_tmp_base + thread * HEAD_DIM_V + elem_base; + weighted = vec4(tmp[src + 0u], tmp[src + 1u], tmp[src + 2u], tmp[src + 3u]) * ms; + } + + let sum_x = subgroupAdd(weighted.x); + let sum_y = subgroupAdd(weighted.y); + let sum_z = subgroupAdd(weighted.z); + let sum_w = subgroupAdd(weighted.w); + + if (thread == 0u) { + let dst_vec_index = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = vec4(sum_x, sum_y, sum_z, sum_w) * inv_s; + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl new file mode 100644 index 00000000000..a52575871ae --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -0,0 +1,729 @@ +diagnostic(off, chromium.subgroup_matrix_uniformity); +diagnostic(off, subgroup_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; + +#ifdef KV_F32 +#define KV_TYPE f32 +#else +#define KV_TYPE f16 +#endif + +#define HEAD_DIM_QK 64 +#define HEAD_DIM_V 64 + + +#define SG_MAT_M 8 +#define SG_MAT_N 8 +#define SG_MAT_K 8 + +#define Q_TILE SG_MAT_M +#define KV_TILE 16 +#define WG_SIZE 64 +#ifndef VEC_NE +#define VEC_NE 4u +#endif + +#define KV_BLOCKS (KV_TILE / SG_MAT_N) + +#define BLOCK_SIZE 32 +#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) +#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) +#if defined(KV_Q4_0) +#define NQ 16 +#define F16_PER_BLOCK 9 +#define WEIGHTS_PER_F16 4 +#elif defined(KV_Q8_0) +#define NQ 8 +#define F16_PER_BLOCK 17 +#define WEIGHTS_PER_F16 2 +#endif +#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) + +fn get_byte(value: u32, index: u32) -> u32 { + return (value >> (index * 8)) & 0xFF; +} + +fn get_byte_i32(value: u32, index: u32) -> i32 { + return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; +} + +struct Params { + offset_q: u32, + offset_k: u32, + offset_v: u32, + offset_mask: u32, + offset_sinks: u32, + offset_dst: u32, + + // shapes of Q/K/V + n_heads: u32, + seq_len_q: u32, + seq_len_kv: u32, + + // strides (in elements) + stride_q1: u32, + stride_q2: u32, + stride_q3: u32, + stride_k1: u32, + stride_k2: u32, + stride_k3: u32, + stride_v1: u32, + stride_v2: u32, + stride_v3: u32, + stride_mask3: u32, + + // repeat factors for K/V, e.g., MHA vs. MQA vs. GQA + q_per_kv: u32, + + // softmax params + scale: f32, + max_bias: f32, + logit_softcap: f32, + n_head_log2: f32, + m0: f32, + m1: f32, + +#ifdef BLK + blk_base: u32, + blk_nblk0: u32, + blk_nblk1: u32, +#endif + + tmp_data_base: u32, + tmp_stats_base: u32, + nwg: u32, +}; + +@group(0) @binding(0) var Q: array; +#if defined(KV_Q4_0) || defined(KV_Q8_0) +@group(0) @binding(1) var K: array; +#else +@group(0) @binding(1) var K: array>; +#endif +#if defined(KV_Q4_0) || defined(KV_Q8_0) +@group(0) @binding(2) var V: array; +#else +@group(0) @binding(2) var V: array>; +#endif +#if defined(MASK) && defined(SINKS) +@group(0) @binding(3) var mask: array; +@group(0) @binding(4) var sinks: array; +#ifdef BLK +#define BLK_BINDING 5 +#define TMP_BINDING 6 +#define DST_BINDING 7 +#define PARAMS_BINDING 8 +#else +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#endif +#elif defined(MASK) +@group(0) @binding(3) var mask: array; +#ifdef BLK +#define BLK_BINDING 4 +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#else +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#elif defined(SINKS) +@group(0) @binding(3) var sinks: array; +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#else +#define TMP_BINDING 3 +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif + +#ifdef BLK +@group(0) @binding(BLK_BINDING) var blk: array; +#endif +@group(0) @binding(TMP_BINDING) var tmp: array; +@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(PARAMS_BINDING) var params: Params; + +// Just a very small float value. +const FLOAT_MIN: f32 = -1.0e9; + +var q_shmem: array; + +#ifndef KV_DIRECT +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); +// we can reuse the same shmem for K and V since we only need one at a time +var kv_shmem: array; +#endif + +var o_shmem: array; + +#ifdef MASK +// storage for mask values +var mask_shmem: array; +#endif + +// note that we reuse the same storage for both since we only need one at a time +var inter_shmem: array; + +// Storage for row max and exp sum during online softmax +var row_max_shmem: array; +var exp_sum_shmem: array; +var blk_state_wg: u32; + +fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 { + var v = select(FLOAT_MIN, + f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale, + kv_idx < KV_TILE); +#ifdef LOGIT_SOFTCAP + v = params.logit_softcap * tanh(v); +#endif +#ifdef MASK + if (apply_mask) { + var mask_val = select(0.0,f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE); + v += select(mask_val, slope * mask_val, has_bias); + } +#endif + return v; +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + + // initialize row max for online softmax + for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { + row_max_shmem[i] = FLOAT_MIN; + exp_sum_shmem[i] = 0.0; + } + + for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) { + o_shmem[i] = 0.0; + } + + // workgroups per head/batch + let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; + let wg_per_batch = wg_per_head * params.n_heads; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + + let iwg = wg_id.x % params.nwg; + let base_wg_id = wg_id.x / params.nwg; + + // batch index + let batch_idx = base_wg_id / wg_per_batch; + let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; + let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; + let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; + let wg_in_batch = base_wg_id % wg_per_batch; + + // head index + let head_idx = wg_in_batch / wg_per_head; + let q_head_offset = q_batch_offset + head_idx * params.stride_q2; + let k_head_idx = head_idx / params.q_per_kv; + let v_head_idx = k_head_idx; + let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; + let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2; + + // starting Q row for this workgroup + let wg_in_head = wg_in_batch % wg_per_head; + let q_row_start = wg_in_head * Q_TILE; + +#ifdef MASK + // mask offset + let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; +#endif + + let head = f32(head_idx); + let has_bias = params.max_bias > 0.0; + let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias); + + // load q tile into shared memory + for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let q_row = elem_idx / HEAD_DIM_QK; + let q_col = elem_idx % HEAD_DIM_QK; + let head_q_row = q_row_start + q_row; + let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; + q_shmem[elem_idx] = f16(select( + 0.0, + Q[global_q_row_offset + q_col], + head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); + } + + for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { +#ifdef BLK + let q_blk = q_row_start / Q_TILE; + let kv_blk = kv_tile / KV_TILE; + let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u); + let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk; + let blk_state_local = blk[blk_idx]; +#else + let blk_state_local = 1u; +#endif + if (local_id.x == 0u) { + blk_state_wg = blk_state_local; + } + workgroupBarrier(); + let blk_state = blk_state_wg; + let skip_tile = blk_state == 0u; + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + inter_shmem[elem_idx] = f16(0.0); + } + + // load k tile into shared memory +#if defined(KV_Q4_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + + if (global_k_row < params.seq_len_kv) { + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = K[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = K[base_idx + 1u + block_offset + j]; + let q_1 = K[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_lo; + kv_shmem[row_offset + idx + 16u] = q_hi; + } + } + } + } +#elif defined(KV_Q8_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + + if (global_k_row < params.seq_len_kv) { + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = K[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = K[base_idx + 1u + block_offset + j]; + let q_1 = K[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_val; + } + } + } + } +#elif defined(KV_DIRECT) + // Direct global loads for KV +#else + for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; + let vec_idx = (global_k_row_offset + k_col) >> 2u; + let k4 = select(vec4(0.0), K[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f16(k4.x); + kv_shmem[elem_idx + 1u] = f16(k4.y); + kv_shmem[elem_idx + 2u] = f16(k4.z); + kv_shmem[elem_idx + 3u] = f16(k4.w); + } +#endif + + workgroupBarrier(); + + // accumulate q block * k block into registers across the entire KV tile + if (!skip_tile) { + let num_of_threads = subgroup_size / VEC_NE; + let tx = sg_inv_id % num_of_threads; + let ty = sg_inv_id / num_of_threads; + for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + continue; + } + let local_q_row_offset = q_tile_row * HEAD_DIM_QK; + + for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) { + let kv_idx = kv_base + ty; + var partial_sum: f32 = 0.0; + let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv; + if (kv_valid) { + for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { + let q_off = local_q_row_offset + i * 4u; + + let qv = vec4( + f32(q_shmem[q_off + 0u]), + f32(q_shmem[q_off + 1u]), + f32(q_shmem[q_off + 2u]), + f32(q_shmem[q_off + 3u])); +#ifdef KV_DIRECT + let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); + let kv = vec4(K[idx >> 2u]); +#else + let idx = kv_idx * HEAD_DIM_QK + (i * 4u); + let kv = vec4( + f32(kv_shmem[idx + 0u]), + f32(kv_shmem[idx + 1u]), + f32(kv_shmem[idx + 2u]), + f32(kv_shmem[idx + 3u])); +#endif + partial_sum += dot(qv, kv); + } + } + var sum = partial_sum; + // Reduce over tx threads (NL) for this ty stripe. + var tx_delta = num_of_threads >> 1u; + loop { + if (tx_delta == 0u) { + break; + } + let sh = subgroupShuffleDown(sum, tx_delta); + if (tx < tx_delta) { + sum += sh; + } + tx_delta >>= 1u; + } + + let sum_bcast = subgroupShuffle(sum, num_of_threads * ty); + if (tx == 0u && kv_valid) { + let dst_idx = q_tile_row * KV_TILE + kv_idx; + inter_shmem[dst_idx] = f16(sum_bcast); + } + } + } + } + + +#ifdef MASK + let apply_mask = !skip_tile && (blk_state != 2u); + if (apply_mask) { + // load mask tile into shared memory for this KV block + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + let mask_row = elem_idx / KV_TILE; + let mask_col = elem_idx % KV_TILE; + let global_q_row = q_row_start + mask_row; + let global_k_col = kv_tile + mask_col; + let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv; + let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col; + mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); + } + } +#else + let apply_mask = false; +#endif + + workgroupBarrier(); + + // online softmax + if (!skip_tile) { + for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + var prev_max = row_max_shmem[q_tile_row]; + var final_max = prev_max; + // pass 1: compute final max across the full KV tile in chunks + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE; + let softmax_term = select(FLOAT_MIN, + calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask), + kv_valid); + final_max = subgroupMax(max(final_max, softmax_term)); + } + + var total_exp_term: f32 = 0.0; + // pass 2: compute exp sum and write P using final_max + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask); + let cur_p = select(0.0, + exp(softmax_term - final_max), + kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); + total_exp_term += subgroupAdd(cur_p); + if (kv_idx < KV_TILE) { + inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p); + } + } + + let cur_exp = exp(prev_max - final_max); + + if (sg_inv_id == 0) { + row_max_shmem[q_tile_row] = final_max; + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; + } + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp); + } + } + } + + // load v tile into shared memory +#if defined(KV_Q4_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + + if (global_v_row < params.seq_len_kv) { + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = V[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = V[base_idx + 1u + block_offset + j]; + let q_1 = V[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_lo; + kv_shmem[row_offset + idx + 16u] = q_hi; + } + } + } + } +#elif defined(KV_Q8_0) + for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + + if (global_v_row < params.seq_len_kv) { + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let base_idx = global_block_idx * F16_PER_BLOCK; + let d = V[base_idx]; + for (var j = 0u; j < F16_PER_THREAD; j += 2) { + let q_0 = V[base_idx + 1u + block_offset + j]; + let q_1 = V[base_idx + 1u + block_offset + j + 1]; + let q_packed = bitcast(vec2(q_0, q_1)); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d; + let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; + kv_shmem[row_offset + idx] = q_val; + } + } + } + } +#elif defined(KV_DIRECT) + // Direct global loads for KV +#else + for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; + let vec_idx = (global_v_row_offset + v_col) >> 2u; + let v4 = select(vec4(0.0), V[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f16(v4.x); + kv_shmem[elem_idx + 1u] = f16(v4.y); + kv_shmem[elem_idx + 2u] = f16(v4.z); + kv_shmem[elem_idx + 3u] = f16(v4.w); + } +#endif + + workgroupBarrier(); + + if (!skip_tile) { + // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem + // we want to compute O += P * V across the full KV tile + let ne_threads : u32 = VEC_NE; + let nl_threads = max(1u, subgroup_size / ne_threads); + let tx_pv = sg_inv_id % nl_threads; + let ty_pv = sg_inv_id / nl_threads; + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) { + var lo = vec4(0.0, 0.0, 0.0, 0.0); + for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) { + let kv_idx = cc * ne_threads + ty_pv; + let v_row = kv_tile + kv_idx; + if (v_row >= params.seq_len_kv) { + continue; + } + + let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]); +#ifdef KV_DIRECT + let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; + let v4 = vec4(V[v_idx >> 2u]); +#else + let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u; + let v4 = vec4( + f32(kv_shmem[v_idx + 0u]), + f32(kv_shmem[v_idx + 1u]), + f32(kv_shmem[v_idx + 2u]), + f32(kv_shmem[v_idx + 3u])); +#endif + lo += p * v4; + } + + var lo_x = lo.x; + var lo_y = lo.y; + var lo_z = lo.z; + var lo_w = lo.w; + // Reduce over ty threads (NE) for this tx thread. + var ty_delta = ne_threads >> 1u; + loop { + if (ty_delta == 0u) { + break; + } + let thread_delta = ty_delta * nl_threads; + let shx = subgroupShuffleDown(lo_x, thread_delta); + let shy = subgroupShuffleDown(lo_y, thread_delta); + let shz = subgroupShuffleDown(lo_z, thread_delta); + let shw = subgroupShuffleDown(lo_w, thread_delta); + if (ty_pv < ty_delta) { + lo_x += shx; + lo_y += shy; + lo_z += shz; + lo_w += shw; + } + ty_delta >>= 1u; + } + + if (ty_pv == 0u) { + let elem_base = vec_col * 4u; + let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base; + o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x); + o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y); + o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z); + o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w); + } + } + } + } + + workgroupBarrier(); + } + + +#ifdef SINKS + // Sinks are global terms and must be applied exactly once across split workgroups. + if (iwg == 0u) { + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { + break; + } + + var prev_max = row_max_shmem[q_tile_row]; + + // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum + let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0); + let new_max = subgroupMax(max(prev_max, sink_val)); + let max_exp = exp(prev_max - new_max); + let sink_exp = exp(sink_val - new_max); + + let sink_exp_sum = subgroupAdd(sink_exp); + + if (sg_inv_id == 0) { + row_max_shmem[q_tile_row] = new_max; + exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum; + } + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + let idx = q_tile_row * HEAD_DIM_V + elem_idx; + o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp); + } + } + workgroupBarrier(); + } +#endif + let rows_per_batch = params.n_heads * params.seq_len_q; + for (var q_tile_row = subgroup_id; + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { + + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { break; } + + if (params.nwg == 1u) { + let exp_sum = exp_sum_shmem[q_tile_row]; + let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); + let row_base: u32 = + params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V; + + for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) { + let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); + let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); + let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); + let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); + + let v = vec4( + f32(o_shmem[i0]) * scale, + f32(o_shmem[i1]) * scale, + f32(o_shmem[i2]) * scale, + f32(o_shmem[i3]) * scale + ); + + let dst_vec_index: u32 = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = v; + } + } else { + let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row; + let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V; + let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg; + + for (var elem_base = sg_inv_id * 4u; + elem_base < HEAD_DIM_V; + elem_base += subgroup_size * 4u) { + + let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); + let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); + let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); + let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); + + let tbase = tmp_row_data_base + elem_base; + tmp[tbase + 0u] = f32(o_shmem[i0]); + tmp[tbase + 1u] = f32(o_shmem[i1]); + tmp[tbase + 2u] = f32(o_shmem[i2]); + tmp[tbase + 3u] = f32(o_shmem[i3]); + } + + if (sg_inv_id == 0u) { + tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row]; + tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row]; + } + } + } +} From 321f62823902b890bc1eb5594f937e853c6afc3b Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Fri, 3 Apr 2026 10:28:09 +0300 Subject: [PATCH 384/831] rpc : reuse compute graph buffers (llama/21299) Reuse the buffer for the ggml context which is used for creating the compute graph on the server side. This partially addresses a memory leak created by the CUDA backend due to using buffer addresses as cache keys. ref: #21265 ref: #20315 --- ggml/src/ggml-rpc/ggml-rpc.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 1378ba9f5bf..4e2f1ab0f23 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1009,8 +1009,8 @@ class rpc_server { bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response); struct stored_graph { - ggml_context_ptr ctx_ptr; - ggml_cgraph * graph; + std::vector buffer; + ggml_cgraph * graph; }; private: @@ -1518,10 +1518,12 @@ bool rpc_server::graph_compute(const std::vector & input) { LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors); size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); - + if (stored_graphs[device].buffer.size() < buf_size) { + stored_graphs[device].buffer.resize(buf_size); + } struct ggml_init_params params = { /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ NULL, + /*.mem_buffer =*/ stored_graphs[device].buffer.data(), /*.no_alloc =*/ true, }; ggml_context_ptr ctx_ptr { ggml_init(params) }; @@ -1551,7 +1553,6 @@ bool rpc_server::graph_compute(const std::vector & input) { } ggml_status status = ggml_backend_graph_compute(backends[device], graph); GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC"); - stored_graphs[device].ctx_ptr.swap(ctx_ptr); stored_graphs[device].graph = graph; return true; } From 3f5117610b9053b1a7a7f9db66181645063ce4cf Mon Sep 17 00:00:00 2001 From: Vishal Singh Date: Fri, 3 Apr 2026 14:49:08 +0530 Subject: [PATCH 385/831] ggml-zendnn : add MUL_MAT_ID op support for MoE models (llama/21315) * ggml-zendnn : add MUL_MAT_ID op support for MoE models - Add MUL_MAT_ID op acceleration for Mixture-of-Experts models - MUL_MAT_ID op fallback to CPU backend if total experts > 32 - Point ZenDNN lib to latest bits ZenDNN-2026-WW13 * ggml-zendnn : add braces to sgemm failure condition for consistency Co-authored-by: Aaron Teo --------- Co-authored-by: Aaron Teo --- ggml/src/ggml-zendnn/CMakeLists.txt | 2 +- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 179 +++++++++++++++++++++++++++ 2 files changed, 180 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt index 9bdb4e836d3..4f321a25257 100644 --- a/ggml/src/ggml-zendnn/CMakeLists.txt +++ b/ggml/src/ggml-zendnn/CMakeLists.txt @@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") ExternalProject_Add( zendnn GIT_REPOSITORY https://github.com/amd/ZenDNN.git - GIT_TAG a18adf8c605fb5f5e52cefd7eda08a7b18febbaf # ZenDNN-2026-WW08 + GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13 PREFIX ${ZENDNN_PREFIX} SOURCE_DIR ${ZENDNN_SOURCE_DIR} BINARY_DIR ${ZENDNN_BUILD_DIR} diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index c8760304008..377303720c7 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -190,6 +190,170 @@ static void ggml_zendnn_compute_forward_mul_mat( } } +struct mmid_row_mapping { + int32_t i1; + int32_t i2; +}; + +static void ggml_zendnn_compute_forward_mul_mat_id( + ggml_backend_zendnn_context * ctx, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; // expert weights + const ggml_tensor * src1 = dst->src[1]; // inputs + const ggml_tensor * ids = dst->src[2]; // expert ids + + GGML_TENSOR_BINARY_OP_LOCALS + + // exit for no tokens to process + if (ne2 == 0 || ne11 == 0) { + return; + } + + ggml_type const vec_dot_type = src0->type; + ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(ne3 == 1); + + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_experts + + std::vector matrix_row_counts(n_as, 0); + std::vector> matrix_rows(n_as); + + int64_t max_rows = 0; + // group rows by expert (preprocessing step) + for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int id = 0; id < n_ids; ++id) { + const int32_t i02 = *(const int32_t *)((const char *)ids->data + iid1*ids->nb[1] + id*ids->nb[0]); + + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + matrix_rows[i02].push_back({id, iid1}); + matrix_row_counts[i02]++; + if (matrix_row_counts[i02] > max_rows) { + max_rows = matrix_row_counts[i02]; + } + } + } + + if (max_rows == 0) { + return; // no rows to process + } + + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + + // size for converting src1 rows to vec_dot_type if needed + const size_t nbw1 = row_size; + const size_t nbw2 = nbw1 * ne11; + const size_t nbw3 = nbw2 * ne12; + const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0; + + // size for MoE gather/scatter buffers + const size_t wdata_cur_size = max_rows * row_size; + const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01); + + // allocate single buffer for all needs + const size_t total_size = src1_conv_size + wdata_cur_size + dst_cur_size; + if (ctx->work_size < total_size) { + ctx->work_data.reset(new char[total_size]); + ctx->work_size = total_size; + } + + // partition the buffer + char * work_data = ctx->work_data.get(); + char * wdata_cur = work_data + src1_conv_size; + char * dst_cur = wdata_cur + wdata_cur_size; + + if (src1->type != vec_dot_type) { + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + #pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static) + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + const float * src1_f32 = (float *)((char *)src1->data + i11*nb11 + i12*nb12 + i13*nb13); + void * src1_conv = (char *)work_data + i11*nbw1 + i12*nbw2 + i13*nbw3; + from_float(src1_f32, src1_conv, ne10); + } + } + } + } + + const void * wdata = src1->type == vec_dot_type ? src1->data : work_data; + + // process each expert with gather -> gemm -> scatter pattern + for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + const char * src0_cur = (const char *) src0->data + cur_a*nb02; + + // gather input rows for this expert + #pragma omp parallel for num_threads(ctx->n_threads) schedule(static) + for (int64_t ir1 = 0; ir1 < cne1; ++ir1) { + const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1]; + const int64_t id = row_mapping.i1; + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; + + std::memcpy( + wdata_cur + ir1 * row_size, + (const char *) wdata + (i11 + i12*ne11) * row_size, + row_size + ); + } + + // batched gemm for all tokens in this expert + if (!ggml_zendnn_sgemm(ctx, + ne01, // m + cne1, // n + ne10, // k + src0_cur, + ne00, // lda + wdata_cur, + ne10, // ldb + dst_cur, + ne01, // ldc + src0->type, + vec_dot_type, + dst->type)) { + GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); + } + + // scatter output rows to destination + #pragma omp parallel for num_threads(ctx->n_threads) schedule(static) + for (int64_t ir1 = 0; ir1 < cne1; ++ir1) { + const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1]; + const int64_t id = row_mapping.i1; + const int64_t i1 = id; + const int64_t i2 = row_mapping.i2; + + std::memcpy( + (char *) dst->data + i1*nb1 + i2*nb2, + dst_cur + ir1 * ggml_row_size(dst->type, ne01), + ggml_row_size(dst->type, ne01) + ); + } + } +} + // backend interface static const char * ggml_backend_zendnn_get_name(ggml_backend_t backend) { @@ -218,6 +382,9 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm case GGML_OP_MUL_MAT: ggml_zendnn_compute_forward_mul_mat(ctx, node); break; + case GGML_OP_MUL_MAT_ID: + ggml_zendnn_compute_forward_mul_mat_id(ctx, node); + break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -361,6 +528,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const return true; case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: { const ggml_tensor * weights = op->src[0]; const ggml_tensor * inputs = op->src[1]; @@ -374,6 +542,17 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) { return false; } + // MUL_MAT_ID performs best with a moderate number of experts due to its + // gather + batched matmul + scatter approach. Future versions will leverage + // ZenDNN's grouped_gemm for better scalability with larger expert counts: + // https://github.com/amd/ZenDNN/blob/main/docs/operator/lowoha_group_gemm_operator.md + if (op->op == GGML_OP_MUL_MAT_ID) { + const int64_t n_experts = weights->ne[2]; + const int64_t max_experts = 32; + if (n_experts > max_experts) { + return false; + } + } switch (weights->type) { case GGML_TYPE_F32: case GGML_TYPE_BF16: From d6cfdc669cad5faa0171011b57df3ed7c1ed4911 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 3 Apr 2026 11:40:14 -0700 Subject: [PATCH 386/831] ggml-webgpu: move from parameter buffer pool to single buffer with offsets (llama/21278) * Work towards removing bitcast * Move rest of existing types over * Add timeout back to wait and remove synchronous set_tensor/memset_tensor * move to unpackf16 for wider compatibility * cleanup * Remove deadlock condition in free_bufs * Start work on removing parameter buffer pools * Simplify and optimize further * simplify profile futures * Fix stride * Try using a single command buffer per batch * formatting --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 43 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 758 ++++++++---------- 2 files changed, 379 insertions(+), 422 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 1c56c689312..669d2cd53a8 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -437,12 +437,18 @@ inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_ // Head-dim specializations used by the tuned vec f16 path. switch (key.head_dim_qk) { - case 64: return 2u; - case 96: return 4u; - case 128: return 1u; - case 192: return 2u; - case 576: return 2u; - default: return 1u; + case 64: + return 2u; + case 96: + return 4u; + case 128: + return 1u; + case 192: + return 2u; + case 576: + return 2u; + default: + return 1u; } } @@ -513,9 +519,9 @@ struct ggml_webgpu_flash_attn_blk_shader_lib_context { }; inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { + pre_wgsl::Preprocessor & preprocessor, + const char * shader_src, + const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { std::vector defines; std::string variant = "flash_attn_vec_blk"; @@ -1857,9 +1863,8 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = - std::min(ggml_webgpu_flash_attn_max_kv_tile(context), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), + context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); if (context.key.use_vec) { q_tile = 1; kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context))); @@ -1885,14 +1890,14 @@ class ggml_webgpu_shader_lib { } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn; + const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); - auto decisions = std::make_shared(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; - pipeline.context = decisions; + auto decisions = std::make_shared(); + decisions->q_tile = q_tile; + decisions->kv_tile = kv_tile; + decisions->wg_size = wg_size; + pipeline.context = decisions; flash_attn_pipelines[context.key] = pipeline; return flash_attn_pipelines[context.key]; } @@ -1905,7 +1910,7 @@ class ggml_webgpu_shader_lib { ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context); - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); flash_attn_blk_pipelines[context.key] = pipeline; return flash_attn_blk_pipelines[context.key]; } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e53281bfbbd..5c567dc0df0 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -81,12 +81,10 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim /* Constants */ -#define WEBGPU_NUM_PARAM_BUFS 96u -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u +#define WEBGPU_NUM_PARAM_SLOTS \ + (WEBGPU_COMMAND_SUBMIT_BATCH_SIZE + 10) // a few extra for safety, since some operations may need multiple slots #define WEBGPU_WAIT_ANY_TIMEOUT_MS 100 -// Maximum number of in-flight submissions per-thread, to avoid exhausting the -// parameter buffer pool -#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 @@ -122,87 +120,45 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, wgpu::BufferUsage usage, const char * label); -// Holds a pool of parameter buffers for WebGPU operations -struct webgpu_buf_pool { - std::vector free; - - // The pool must be synchronized because - // 1. The memset pool is shared globally by every ggml buffer, - // since allocating a pool per ggml buffer would consume too much memory. - // 2. For the per-thread buffer pools in webgpu_context, - // buffers are allocated and freed in Dawn callbacks, - // which can run on a different thread than the calling thread. - std::mutex mutex; - std::condition_variable cv; - size_t cur_pool_size; - size_t max_pool_size; - wgpu::Device device; - wgpu::BufferUsage dev_buf_usage; - size_t buf_size; - bool should_grow; - - void init(wgpu::Device device, - int num_bufs, - size_t buf_size, - wgpu::BufferUsage dev_buf_usage, - bool should_grow = false, - size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) { - this->max_pool_size = max_pool_size; - this->cur_pool_size = num_bufs; - this->device = device; - this->dev_buf_usage = dev_buf_usage; - this->buf_size = buf_size; - this->should_grow = should_grow; - for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - free.push_back(dev_buf); +// Slot-based parameter arena for compute graph encoding. Each encoded kernel +// gets a unique uniform-buffer slice within the current batch, and the slot +// cursor is reset immediately after that batch is submitted. +struct webgpu_param_arena { + wgpu::Buffer buffer; + size_t slot_stride = 0; + size_t slot_size = 0; + uint32_t slot_count = 0; + uint32_t next_slot = 0; + + void init(wgpu::Device device, size_t slot_size, uint32_t slot_count, size_t alignment) { + this->slot_stride = ROUNDUP_POW2(slot_size, alignment); + this->slot_size = slot_size; + this->slot_count = slot_count; + this->next_slot = 0; + + ggml_webgpu_create_buffer(device, buffer, this->slot_stride * slot_count, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, "ggml_webgpu_param_arena"); + } + + size_t alloc_slot(size_t size) { + GGML_ASSERT(size <= slot_size); + if (next_slot >= slot_count) { + GGML_ABORT("ggml_webgpu: parameter arena exhausted while encoding a batch"); } - } - wgpu::Buffer alloc_bufs() { - std::unique_lock lock(mutex); - if (!free.empty()) { - wgpu::Buffer buf = free.back(); - free.pop_back(); - return buf; - } - - // Try growing the pool if no free buffers - if (free.empty() && cur_pool_size < max_pool_size && should_grow) { - cur_pool_size++; - lock.unlock(); // avoid deadlock between this lock and Dawn's internal locks when buffers are freed in callbacks - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - - if (!dev_buf) { - GGML_ABORT("webgpu_buf_pool: failed to allocate buffers"); - } - return dev_buf; - } - cv.wait(lock, [this] { return !free.empty(); }); - wgpu::Buffer buf = free.back(); - free.pop_back(); - return buf; + return slot_stride * next_slot++; } - void free_bufs(std::vector bufs) { - std::lock_guard lock(mutex); - free.insert(free.end(), bufs.begin(), bufs.end()); - cv.notify_all(); - } + void reset() { next_slot = 0; } void cleanup() { - std::lock_guard lock(mutex); - for (auto & buf : free) { - if (buf) { - buf.Destroy(); - } + if (buffer) { + buffer.Destroy(); + buffer = nullptr; } - free.clear(); } - ~webgpu_buf_pool() { this->cleanup(); } + ~webgpu_param_arena() { this->cleanup(); } }; #ifdef GGML_WEBGPU_GPU_PROFILE @@ -269,10 +225,8 @@ struct webgpu_gpu_profile_buf_pool { }; #endif -struct webgpu_command { - uint32_t num_kernels; - wgpu::CommandBuffer commands; - std::vector params_bufs; +struct webgpu_encoded_op { + uint32_t num_kernels = 0; #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs timestamp_query_bufs; std::string pipeline_name; @@ -305,8 +259,8 @@ struct webgpu_global_context_struct { // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches. std::recursive_mutex mutex; - webgpu_buf_pool memset_buf_pool; - std::map memset_pipelines; // variant or type index + wgpu::Buffer memset_params_buf; + webgpu_pipeline memset_pipeline; #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) @@ -332,6 +286,10 @@ struct webgpu_global_context_struct { this->get_tensor_staging_buf.Destroy(); this->get_tensor_staging_buf = nullptr; } + if (this->memset_params_buf) { + this->memset_params_buf.Destroy(); + this->memset_params_buf = nullptr; + } #ifdef GGML_WEBGPU_DEBUG if (this->debug_host_buf) { this->debug_host_buf.Destroy(); @@ -347,13 +305,6 @@ struct webgpu_global_context_struct { typedef std::shared_ptr webgpu_global_context; -struct webgpu_submission { - wgpu::FutureWaitInfo submit_done; -#ifdef GGML_WEBGPU_GPU_PROFILE - std::vector profile_futures; -#endif -}; - // All the base objects needed to run operations on a WebGPU device struct webgpu_context_struct { // Points to global instances owned by ggml_backend_webgpu_reg_context @@ -361,9 +312,9 @@ struct webgpu_context_struct { std::unique_ptr shader_lib; - webgpu_buf_pool param_buf_pool; - wgpu::Buffer set_rows_dev_error_buf; - wgpu::Buffer set_rows_host_error_buf; + webgpu_param_arena param_arena; + wgpu::Buffer set_rows_dev_error_buf; + wgpu::Buffer set_rows_host_error_buf; size_t memset_bytes_per_thread; }; @@ -448,95 +399,34 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** WebGPU Actions */ -static bool ggml_backend_webgpu_handle_wait_status(wgpu::WaitStatus status, bool allow_timeout = false) { - switch (status) { - case wgpu::WaitStatus::Success: - return true; - case wgpu::WaitStatus::TimedOut: - if (allow_timeout) { - return false; - } - GGML_LOG_ERROR("ggml_webgpu: WaitAny timed out unexpectedly\n"); - return false; - case wgpu::WaitStatus::Error: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); - return false; - default: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); - return false; - } -} - #ifdef GGML_WEBGPU_GPU_PROFILE -static void ggml_backend_webgpu_erase_completed_futures(std::vector & futures) { - futures.erase(std::remove_if(futures.begin(), futures.end(), - [](const wgpu::FutureWaitInfo & info) { return info.completed; }), - futures.end()); -} - static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx, - std::vector & futures, - bool block) { + std::vector & futures) { if (futures.empty()) { return; } - uint64_t timeout_ms = block ? UINT64_MAX : 0; - if (block) { - while (!futures.empty()) { - auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); - if (ggml_backend_webgpu_handle_wait_status(waitStatus)) { - ggml_backend_webgpu_erase_completed_futures(futures); - } - } - } else { - auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); - if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { - ggml_backend_webgpu_erase_completed_futures(futures); - } + constexpr size_t max_futures_per_wait = 64; + + while (!futures.empty()) { + ctx->instance.WaitAny(std::min(max_futures_per_wait, futures.size()), futures.data(), UINT64_MAX); + futures.erase(std::remove_if(futures.begin(), futures.end(), + [](const wgpu::FutureWaitInfo & info) { return info.completed; }), + futures.end()); } } #endif -// Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, - std::vector & subs, - bool block = true) { - if (subs.empty()) { - return; - } - - bool blocking_wait = block || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD; - while (blocking_wait) { - auto waitStatus = ctx->instance.WaitAny(1, &subs[0].submit_done, WEBGPU_WAIT_ANY_TIMEOUT_MS * 1e6); - if (ggml_backend_webgpu_handle_wait_status(waitStatus, true)) { -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, subs[0].profile_futures, true); -#endif - subs.erase(subs.begin()); - } - blocking_wait = (block && !subs.empty()) || subs.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD; - } - - if (subs.empty()) { - return; - } - - // Poll each submit future once and remove completed submissions. - for (auto sub = subs.begin(); sub != subs.end();) { - auto waitStatus = ctx->instance.WaitAny(1, &sub->submit_done, 0); - bool success = ggml_backend_webgpu_handle_wait_status(waitStatus, true); -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx, sub->profile_futures, false); - if (success && sub->profile_futures.empty()) { -#else - if (success) { -#endif - sub = subs.erase(sub); - } else { - ++sub; - } - } +static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) { + ctx->instance.WaitAny( + ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, + [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + if (status != wgpu::QueueWorkDoneStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", + std::string(message).c_str()); + } + }), + UINT64_MAX); } static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, @@ -570,34 +460,10 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { } #endif -static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & ctx, - std::vector & commands, - webgpu_buf_pool & param_buf_pool) { - std::vector command_buffers; - std::vector params_bufs; - webgpu_submission submission; -#ifdef GGML_WEBGPU_GPU_PROFILE - std::vector> pipeline_name_and_ts_bufs; -#endif - - for (const auto & command : commands) { - command_buffers.push_back(command.commands); - params_bufs.insert(params_bufs.end(), command.params_bufs.begin(), command.params_bufs.end()); - } - ctx->queue.Submit(command_buffers.size(), command_buffers.data()); - - wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [¶m_buf_pool, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); - } - // Free the staged buffers - param_buf_pool.free_bufs(params_bufs); - }); - submission.submit_done = { p_f }; - #ifdef GGML_WEBGPU_GPU_PROFILE +static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx, + const std::vector & commands, + std::vector & futures) { for (const auto & command : commands) { auto label = command.pipeline_name; auto ts_bufs = command.timestamp_query_bufs; @@ -616,15 +482,15 @@ static webgpu_submission ggml_backend_webgpu_submit(webgpu_global_context & // We can't unmap in here due to WebGPU reentrancy limitations. ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); }); - submission.profile_futures.push_back({ f }); + futures.push_back({ f }); } -#endif - return submission; } +#endif -static webgpu_command ggml_backend_webgpu_build_multi( +static webgpu_encoded_op ggml_backend_webgpu_build_multi( webgpu_global_context & ctx, - webgpu_buf_pool & param_buf_pool, + webgpu_param_arena & param_arena, + wgpu::CommandEncoder & encoder, const std::vector & pipelines, const std::vector> & params_list, const std::vector> & bind_group_entries_list, @@ -633,16 +499,21 @@ static webgpu_command ggml_backend_webgpu_build_multi( GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); GGML_ASSERT(pipelines.size() == workgroups_list.size()); - std::vector params_bufs_list; + webgpu_encoded_op result = {}; std::vector bind_groups; + std::vector param_offsets; + result.num_kernels = pipelines.size(); for (size_t i = 0; i < pipelines.size(); i++) { - wgpu::Buffer params_bufs = param_buf_pool.alloc_bufs(); + const size_t param_size = params_list[i].size() * sizeof(uint32_t); + const size_t param_offset = param_arena.alloc_slot(param_size); std::vector entries = bind_group_entries_list[i]; uint32_t params_binding_num = entries.size(); - entries.push_back( - { .binding = params_binding_num, .buffer = params_bufs, .offset = 0, .size = params_bufs.GetSize() }); + entries.push_back({ .binding = params_binding_num, + .buffer = param_arena.buffer, + .offset = param_offset, + .size = param_arena.slot_size }); wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); @@ -650,13 +521,12 @@ static webgpu_command ggml_backend_webgpu_build_multi( bind_group_desc.entries = entries.data(); bind_group_desc.label = pipelines[i].name.c_str(); bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc)); - - params_bufs_list.push_back(params_bufs); + param_offsets.push_back(param_offset); } - wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); - for (size_t i = 0; i < params_bufs_list.size(); i++) { - ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t)); + for (size_t i = 0; i < param_offsets.size(); i++) { + ctx->queue.WriteBuffer(param_arena.buffer, param_offsets[i], params_list[i].data(), + params_list[i].size() * sizeof(uint32_t)); } #ifdef GGML_WEBGPU_GPU_PROFILE webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); @@ -682,29 +552,21 @@ static webgpu_command ggml_backend_webgpu_build_multi( #ifdef GGML_WEBGPU_GPU_PROFILE encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); -#endif - - wgpu::CommandBuffer commands = encoder.Finish(); - webgpu_command result = {}; - result.commands = commands; - result.params_bufs = params_bufs_list; - result.num_kernels = pipelines.size(); -#ifdef GGML_WEBGPU_GPU_PROFILE result.timestamp_query_bufs = ts_bufs; - // TODO: handle multiple pipeline names result.pipeline_name = pipelines.front().name; #endif return result; } -static webgpu_command ggml_backend_webgpu_build(webgpu_global_context & ctx, - webgpu_buf_pool & param_buf_pool, - webgpu_pipeline & pipeline, - std::vector params, - std::vector bind_group_entries, - uint32_t wg_x, - uint32_t wg_y = 1) { - return ggml_backend_webgpu_build_multi(ctx, param_buf_pool, +static webgpu_encoded_op ggml_backend_webgpu_build(webgpu_global_context & ctx, + webgpu_param_arena & param_arena, + wgpu::CommandEncoder & encoder, + webgpu_pipeline & pipeline, + std::vector params, + std::vector bind_group_entries, + uint32_t wg_x, + uint32_t wg_y = 1) { + return ggml_backend_webgpu_build_multi(ctx, param_arena, encoder, { pipeline }, @@ -724,10 +586,28 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); - webgpu_command command = - ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); - std::vector commands = { command }; - std::vector sub = { ggml_backend_webgpu_submit(ctx, commands, ctx->memset_buf_pool) }; + ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); + + entries.push_back( + { .binding = 1, .buffer = ctx->memset_params_buf, .offset = 0, .size = WEBGPU_PARAMS_BUF_SIZE_BYTES }); + + wgpu::BindGroupDescriptor bind_group_desc; + bind_group_desc.layout = ctx->memset_pipeline.pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = entries.size(); + bind_group_desc.entries = entries.data(); + bind_group_desc.label = ctx->memset_pipeline.name.c_str(); + wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); + + wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(ctx->memset_pipeline.pipeline); + pass.SetBindGroup(0, bind_group); + pass.DispatchWorkgroups(wg_x, 1, 1); + pass.End(); + + wgpu::CommandBuffer command = encoder.Finish(); + std::vector commands = { command }; + ctx->queue.Submit(commands.size(), commands.data()); } /** End WebGPU Actions */ @@ -840,7 +720,10 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 return flags; } -static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, @@ -878,10 +761,14 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { const bool inplace = ggml_webgpu_tensor_equal(src0, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -940,10 +827,13 @@ static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; @@ -995,13 +885,14 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1056,13 +947,14 @@ static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx, const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); } -static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1112,17 +1004,18 @@ static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx, const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2]; - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); -} - -static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * src3, - ggml_tensor * src4, - ggml_tensor * src5, - ggml_tensor * dst) { + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1197,13 +1090,14 @@ static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, h, n_seqs); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, h, n_seqs); } -static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { +static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { // For set rows specifically, we need to check if src and idx are empty // tensors. if (ggml_is_empty(src) || ggml_is_empty(idx)) { @@ -1266,7 +1160,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, 1); } // Workgroup size is a common constant @@ -1277,10 +1171,11 @@ static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_si return constants; } -static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -1332,13 +1227,14 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows; uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { // Determine if this is a mat-vec operation bool is_vec = (dst->ne[1] == 1); @@ -1477,16 +1373,18 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); } -static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst) { +#ifndef __EMSCRIPTEN__ +static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { float scale = *(float *) dst->op_params; float max_bias; memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); @@ -1575,9 +1473,8 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); - const uint32_t vec_nwg_cap = - std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); - const bool use_blk = use_vec && has_mask; + const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + const bool use_blk = use_vec && has_mask; ggml_webgpu_flash_attn_pipeline_key key = { .kv_type = K->type, @@ -1656,9 +1553,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, if (use_blk) { GGML_ASSERT(has_mask); - blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); - blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); - blk_buf = ggml_webgpu_tensor_buf(dst); + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); + blk_buf = ggml_webgpu_tensor_buf(dst); const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; @@ -1729,8 +1626,10 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); } if (use_blk) { - split_entries.push_back( - { .binding = split_binding_index++, .buffer = blk_buf, .offset = blk_entries[1].offset, .size = blk_size_bytes }); + split_entries.push_back({ .binding = split_binding_index++, + .buffer = blk_buf, + .offset = blk_entries[1].offset, + .size = blk_size_bytes }); } split_entries.push_back( { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size }); @@ -1799,14 +1698,18 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, workgroups_list.push_back({ (uint32_t) nrows, 1u }); } - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, entries_list, workgroups_list); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } +#endif // __EMSCRIPTEN__ -static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); @@ -1881,13 +1784,14 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -1983,13 +1887,14 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); uint32_t dim = (uint32_t) dst->op_params[0]; @@ -2039,10 +1944,13 @@ static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { ne, @@ -2081,10 +1989,13 @@ static webgpu_command ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { @@ -2124,14 +2035,16 @@ static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * s }; webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(src)); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, + ggml_nrows(src)); } -static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2228,10 +2141,14 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2290,10 +2207,13 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2341,14 +2261,15 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2424,10 +2345,14 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, ggml_nrows(dst)); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, + ggml_nrows(dst)); } -static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; @@ -2449,10 +2374,13 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nelements(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool is_top_k = dst->op == GGML_OP_TOP_K; ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2543,7 +2471,7 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr workgroups_list.push_back({ wg_x_init, wg_y_init }); if (merge_passes == 0) { - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, entries_list, workgroups_list); } @@ -2605,11 +2533,14 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr in_is_tmp = !in_is_tmp; } - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list, entries_list, - workgroups_list); + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, + entries_list, workgroups_list); } -static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; @@ -2634,10 +2565,13 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src, + ggml_tensor * dst) { bool total_sum = dst->op == GGML_OP_SUM; std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -2666,11 +2600,13 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); } // Returns the encoded command, or std::nullopt if the operation is a no-op -static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { +static std::optional ggml_webgpu_encode_node(webgpu_context ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * node) { if (ggml_is_empty(node)) { return std::nullopt; } @@ -2693,18 +2629,18 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, return std::nullopt; case GGML_OP_CPY: case GGML_OP_CONT: - return ggml_webgpu_cpy(ctx, src0, node); + return ggml_webgpu_cpy(ctx, encoder, src0, node); case GGML_OP_SET: - return ggml_webgpu_set(ctx, src0, src1, node); + return ggml_webgpu_set(ctx, encoder, src0, src1, node); case GGML_OP_SET_ROWS: - return ggml_webgpu_set_rows(ctx, src0, src1, node); + return ggml_webgpu_set_rows(ctx, encoder, src0, src1, node); case GGML_OP_GET_ROWS: - return ggml_webgpu_get_rows(ctx, src0, src1, node); + return ggml_webgpu_get_rows(ctx, encoder, src0, src1, node); case GGML_OP_MUL_MAT: - return ggml_webgpu_mul_mat(ctx, src0, src1, node); + return ggml_webgpu_mul_mat(ctx, encoder, src0, src1, node); case GGML_OP_FLASH_ATTN_EXT: #ifndef __EMSCRIPTEN__ - return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); + return ggml_webgpu_flash_attn(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node); #else return std::nullopt; #endif @@ -2712,22 +2648,22 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - return ggml_webgpu_binary_op(ctx, src0, src1, node); + return ggml_webgpu_binary_op(ctx, encoder, src0, src1, node); case GGML_OP_CONCAT: - return ggml_webgpu_concat(ctx, src0, src1, node); + return ggml_webgpu_concat(ctx, encoder, src0, src1, node); case GGML_OP_REPEAT: - return ggml_webgpu_repeat(ctx, src0, node); + return ggml_webgpu_repeat(ctx, encoder, src0, node); case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: - return ggml_webgpu_row_norm(ctx, src0, node); + return ggml_webgpu_row_norm(ctx, encoder, src0, node); case GGML_OP_ROPE: - return ggml_webgpu_rope(ctx, src0, src1, src2, node); + return ggml_webgpu_rope(ctx, encoder, src0, src1, src2, node); case GGML_OP_GLU: - return ggml_webgpu_glu(ctx, src0, src1, node); + return ggml_webgpu_glu(ctx, encoder, src0, src1, node); case GGML_OP_SCALE: - return ggml_webgpu_scale(ctx, src0, node); + return ggml_webgpu_scale(ctx, encoder, src0, node); case GGML_OP_SOFT_MAX: - return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); + return ggml_webgpu_soft_max(ctx, encoder, src0, src1, src2, node); case GGML_OP_UNARY: case GGML_OP_CLAMP: case GGML_OP_FILL: @@ -2738,26 +2674,27 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_COS: case GGML_OP_DIAG: case GGML_OP_TRI: - return ggml_webgpu_unary_op(ctx, src0, node); + return ggml_webgpu_unary_op(ctx, encoder, src0, node); case GGML_OP_SOLVE_TRI: - return ggml_webgpu_solve_tri(ctx, src0, src1, node); + return ggml_webgpu_solve_tri(ctx, encoder, src0, src1, node); case GGML_OP_SSM_CONV: - return ggml_webgpu_ssm_conv(ctx, src0, src1, node); + return ggml_webgpu_ssm_conv(ctx, encoder, src0, src1, node); case GGML_OP_GATED_DELTA_NET: - return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node); + return ggml_webgpu_gated_delta_net(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node->src[5], + node); case GGML_OP_PAD: - return ggml_webgpu_pad(ctx, src0, node); + return ggml_webgpu_pad(ctx, encoder, src0, node); case GGML_OP_ARGMAX: - return ggml_webgpu_argmax(ctx, src0, node); + return ggml_webgpu_argmax(ctx, encoder, src0, node); case GGML_OP_ARGSORT: case GGML_OP_TOP_K: // we reuse the same argsort implementation for top_k - return ggml_webgpu_argsort(ctx, src0, node); + return ggml_webgpu_argsort(ctx, encoder, src0, node); case GGML_OP_CUMSUM: - return ggml_webgpu_cumsum(ctx, src0, node); + return ggml_webgpu_cumsum(ctx, encoder, src0, node); case GGML_OP_SUM: case GGML_OP_SUM_ROWS: - return ggml_webgpu_sum_rows(ctx, src0, node); + return ggml_webgpu_sum_rows(ctx, encoder, src0, node); default: return std::nullopt; } @@ -2771,30 +2708,42 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); - std::vector commands; - std::vector subs; - uint32_t num_batched_kernels = 0; - bool contains_set_rows = false; + std::vector commands; +#ifdef GGML_WEBGPU_GPU_PROFILE + std::vector profile_futures; +#endif + uint32_t num_batched_kernels = 0; + bool contains_set_rows = false; + wgpu::CommandEncoder batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + for (int i = 0; i < cgraph->n_nodes; i++) { if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; } - if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { + if (auto cmd = ggml_webgpu_encode_node(ctx, batch_encoder, cgraph->nodes[i])) { commands.push_back(*cmd); num_batched_kernels += cmd.value().num_kernels; } if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { - num_batched_kernels = 0; - subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); - // Process events and check for completed submissions - ctx->global_ctx->instance.ProcessEvents(); - ggml_backend_webgpu_wait(ctx->global_ctx, subs, false); + num_batched_kernels = 0; + wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); + ctx->global_ctx->queue.Submit(1, &batch_commands); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); +#endif + ctx->param_arena.reset(); commands.clear(); + batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); } } if (!commands.empty()) { - subs.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool)); + wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); + ctx->global_ctx->queue.Submit(1, &batch_commands); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); +#endif + ctx->param_arena.reset(); commands.clear(); } @@ -2805,6 +2754,11 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->set_rows_host_error_buf.GetSize()); wgpu::CommandBuffer set_rows_commands = encoder.Finish(); ctx->global_ctx->queue.Submit(1, &set_rows_commands); + } + + ggml_backend_webgpu_wait_queue(ctx->global_ctx); + + if (contains_set_rows) { ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, ctx->set_rows_host_error_buf.GetSize()); const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); @@ -2814,7 +2768,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->set_rows_host_error_buf.Unmap(); } - ggml_backend_webgpu_wait(ctx->global_ctx, subs); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_wait_profile_futures(ctx->global_ctx, profile_futures); +#endif WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -3063,18 +3019,16 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = - (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && - (V->type == K->type); + const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && + kv_vec_type_supported && (V->type == K->type); if (use_vec) { const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const size_t q_tile = sg_mat_m; - const size_t base_q_bytes = - (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + const size_t q_tile = sg_mat_m; + const size_t base_q_bytes = (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; size_t bytes_per_kv = 0; if (!kv_direct) { bytes_per_kv += std::max(Q->ne[0], V->ne[0]); @@ -3084,10 +3038,9 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer } bytes_per_kv += q_tile; bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - uint32_t kv_tile = - ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n; - kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile)); - kv_tile = (kv_tile / sg_mat_n) * sg_mat_n; + uint32_t kv_tile = ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n; + kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile)); + kv_tile = (kv_tile / sg_mat_n) * sg_mat_n; if (kv_direct) { GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { @@ -3097,30 +3050,30 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const uint32_t vec_nwg_cap = std::max( 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { nwg <<= 1; } nwg = std::min(nwg, vec_nwg_cap); - const size_t align = ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const size_t align = + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; if (nwg > 1u) { const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; const uint64_t tmp_stats_elems = nrows * 2u * nwg; - const size_t tmp_size_bytes = ROUNDUP_POW2( + const size_t tmp_size_bytes = ROUNDUP_POW2( (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); res += tmp_size_bytes + align; } if (mask != nullptr) { - const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); - const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); - const uint32_t stride_mask3 = - (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); + const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; - const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; - const size_t blk_size_bytes = + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + const size_t blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); res += blk_size_bytes + align; } @@ -3195,11 +3148,11 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->capabilities.memset_bytes_per_thread = CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads); std::vector constants(2); - constants[0].key = "wg_size"; - constants[0].value = WEBGPU_MAX_WG_SIZE; - constants[1].key = "bytes_per_thread"; - constants[1].value = ctx->capabilities.memset_bytes_per_thread; - ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); + constants[0].key = "wg_size"; + constants[0].value = WEBGPU_MAX_WG_SIZE; + constants[1].key = "bytes_per_thread"; + constants[1].value = ctx->capabilities.memset_bytes_per_thread; + ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { @@ -3331,9 +3284,9 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr); ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx); - ctx->webgpu_global_ctx->memset_buf_pool.init(ctx->webgpu_global_ctx->device, 1, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, ctx->webgpu_global_ctx->memset_params_buf, + WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + "memset_params_buf"); ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue(); #ifdef GGML_WEBGPU_GPU_PROFILE @@ -3357,9 +3310,8 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); - webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true); + webgpu_ctx->param_arena.init(webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, WEBGPU_NUM_PARAM_SLOTS, + webgpu_ctx->global_ctx->capabilities.limits.minUniformBufferOffsetAlignment); ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf"); From c03104553133724ed8d8593cb87a6659f1399e68 Mon Sep 17 00:00:00 2001 From: Yarden Tal Date: Mon, 6 Apr 2026 04:30:25 +0300 Subject: [PATCH 387/831] hexagon: slight optimization for argosrt output init (llama/21463) --- ggml/src/ggml-hexagon/htp/argsort-ops.c | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index 170220e8f80..3ec26a4c1ac 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -164,6 +164,12 @@ static void quicksort_values_indices_desc(float * values, int32_t * indices, int if (i < right) quicksort_values_indices_desc(values, indices, i, right); } +// LUT for ramp initialization of argsort output (first 32 members) +int32_t argosrt_ramp_lut[32] __attribute__((aligned(VLEN))) = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 +}; + static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { struct htp_argsort_context * actx = (struct htp_argsort_context *)data; struct htp_ops_context * octx = actx->octx; @@ -205,8 +211,12 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { // Padded to 128 bytes. size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + size_t num_vec_ind_values = hmx_ceil_div(ne00, VLEN/(sizeof(int32_t))); float * values_buf = (float *) spad; int32_t * indices_buf = (int32_t *) (spad + values_size); + HVX_Vector * indices_buf_vec = (HVX_Vector *) (spad + values_size); + const HVX_Vector ind_init_vec = *(HVX_Vector *)argosrt_ramp_lut; + const HVX_Vector ind_diff_vec = Q6_V_vsplat_R(32); for (uint32_t r = start_row; r < end_row; r++) { uint32_t src_offset = r * nb01; @@ -218,9 +228,11 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1); hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00); - // Initialize indices - for (uint32_t j = 0; j < ne00; j++) { - indices_buf[j] = j; + // Initialize indices - Start with values 0..31, add 32 for additional vec iterations + HVX_Vector curr_ind_vec = ind_init_vec; + for (uint32_t j_vec = 0; j_vec < num_vec_ind_values; j_vec++) { + indices_buf_vec[j_vec] = curr_ind_vec; + curr_ind_vec = Q6_Vw_vadd_VwVw(curr_ind_vec, ind_diff_vec); } // Sort values and mirror swaps to indices From 42e4a28865c6909d8a5b6390a68740404005aa3f Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Mon, 6 Apr 2026 18:28:00 +0800 Subject: [PATCH 388/831] sycl : handle other FA case (llama/21377) --- ggml/src/ggml-sycl/fattn-tile.hpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ggml/src/ggml-sycl/fattn-tile.hpp b/ggml/src/ggml-sycl/fattn-tile.hpp index c4d24613a55..b4d4e0ae90e 100644 --- a/ggml/src/ggml-sycl/fattn-tile.hpp +++ b/ggml/src/ggml-sycl/fattn-tile.hpp @@ -1252,6 +1252,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggm return; } + { + constexpr int cols_per_block = ncols2*2; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + GGML_ABORT("fatal error"); } From 7b19b94c5dc822a21bdb2e574ece4a7b2316c436 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Tue, 7 Apr 2026 00:04:29 +0530 Subject: [PATCH 389/831] Write an optimized flash_attn_stream_k_fixup kernel (llama/21159) * Write an optimized flash_attn_stream_k_fixup kernel Write a specialized and more optimized kernel for cases where nblocks_stream_k is multiple of ntiles_dst. Make nblocks_stream_k to multiple of ntiles_dst if nblocks_stream_k > 2 * ntiles_dst * Use the new kernel only for nblocks_stream_k_raw > 4 * ntiles_dst to make sure we have enough concurrency on GPUs * Address review comments * Address review comments * Revert variable names to original --- ggml/src/ggml-cuda/fattn-common.cuh | 178 ++++++++++++++++++++++++---- 1 file changed, 153 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index c59a4db3999..beeb5238946 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -676,9 +676,96 @@ static __global__ void flash_attn_mask_to_KV_max( template // D == head size __launch_bounds__(D, 1) -static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, - const int ne11, const int ne12, const int nbatch_fa) { +static __global__ void flash_attn_stream_k_fixup_uniform( + float * __restrict__ dst, + const float2 * __restrict__ dst_fixup, + const int ne01, const int ne02, + const int ne12, const int nblocks_stream_k, + const int gqa_ratio, + const int blocks_per_tile, + const uint3 fd_iter_j_z_ne12, + const uint3 fd_iter_j_z, + const uint3 fd_iter_j) { + constexpr int ncols = ncols1*ncols2; + + const int tile_idx = blockIdx.x; // One block per output tile. + const int j = blockIdx.y; + const int c = blockIdx.z; + const int jc = j*ncols2 + c; + const int tid = threadIdx.x; + + // nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks. + const int b_first = tile_idx * blocks_per_tile; + const int b_last = b_first + blocks_per_tile - 1; + + const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols); + + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const uint2 dm0 = fast_div_modulo(tile_idx, fd_iter_j_z_ne12); + const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_j_z); + const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_j); + + const int sequence = dm0.x; + const int z_KV = dm1.x; + const int zt_gqa = dm2.x; + const int jt = dm2.y; + + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. + + if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) { + return; + } + + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; + + // Load the partial result that needs a fixup + float dst_val = *dst; + float max_val; + float rowsum; + { + const float2 tmp = dst_fixup[b_last*ncols + jc]; + max_val = tmp.x; + rowsum = tmp.y; + } + + // Combine with all previous blocks in this tile. + for (int bidx = b_last - 1; bidx >= b_first; --bidx) { + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; + + const float2 tmp = dst_fixup[(nblocks_stream_k + bidx)*ncols + jc]; + + const float max_val_new = fmaxf(max_val, tmp.x); + + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val*rowsum + scale_add*tmp.y; + + max_val = max_val_new; + } + + // Write back final result: + *dst = dst_val / rowsum; +} + +// General fixup kernel for the case where the number of blocks per tile is not uniform across tiles +// (blocks_num.x not a multiple of ntiles_dst) +template // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_stream_k_fixup_general( + float * __restrict__ dst, + const float2 * __restrict__ dst_fixup, + const int ne01, const int ne02, + const int gqa_ratio, + const int total_work, + const uint3 fd_iter_k_j_z_ne12, + const uint3 fd_iter_k_j_z, + const uint3 fd_iter_k_j, + const uint3 fd_iter_k) { constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -689,27 +776,26 @@ static __global__ void flash_attn_stream_k_fixup( const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - - const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; - - const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; - const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc0 = int64_t(bidx0 + 0)*total_work / gridDim.x; + const int kbc0_stop = int64_t(bidx0 + 1)*total_work / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; - const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; - const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + const bool wrote_beginning_of_tile = fastmodulo(kbc0, fd_iter_k) == 0; + const bool did_not_write_last = fastdiv(kbc0, fd_iter_k) == fastdiv(kbc0_stop, fd_iter_k) && fastmodulo(kbc0_stop, fd_iter_k) != 0; if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { return; } // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index - const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12); - const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); - const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); - const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; + const uint2 dm0 = fast_div_modulo(kbc0, fd_iter_k_j_z_ne12); + const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_k_j_z); + const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_k_j); + const uint2 dm3 = fast_div_modulo(dm2.y, fd_iter_k); + + const int sequence = dm0.x; + const int z_KV = dm1.x; + const int zt_gqa = dm2.x; + const int jt = dm3.x; const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. @@ -733,10 +819,11 @@ static __global__ void flash_attn_stream_k_fixup( // Iterate over previous blocks and compute the combined results. // All CUDA blocks that get here must have a previous block that needs a fixup. + const int tile_kbc0 = fastdiv(kbc0, fd_iter_k); int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc = int64_t(bidx)*total_work / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. bidx--; kbc_stop = kbc; @@ -762,7 +849,7 @@ static __global__ void flash_attn_stream_k_fixup( max_val = max_val_new; // If this block started in a previous tile we are done and don't need to combine additional partial results. - if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + if (fastmodulo(kbc, fd_iter_k) == 0 || fastdiv(kbc, fd_iter_k) < tile_kbc0) { break; } bidx--; @@ -976,14 +1063,28 @@ void launch_fattn( const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks; const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves); - const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst); - const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75; - blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst; + blocks_num.x = ntiles_dst; blocks_num.y = 1; blocks_num.z = 1; + if(use_stream_k) { + const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst); + // Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks (avoids fixup). + // Only do this if the occupancy loss from rounding is acceptable. + const int nblocks_stream_k_rounded = (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst; + const int max_efficiency_loss_percent = 5; + const int efficiency_loss_percent = nblocks_stream_k_rounded > 0 + ? 100 * (nblocks_stream_k_raw - nblocks_stream_k_rounded) / nblocks_stream_k_raw + : 100; + const int nblocks_stream_k = efficiency_loss_percent <= max_efficiency_loss_percent + ? nblocks_stream_k_rounded + : nblocks_stream_k_raw; + + blocks_num.x = nblocks_stream_k; + } + if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2))); } @@ -1063,13 +1164,40 @@ void launch_fattn( CUDA_CHECK(cudaGetLastError()); if (stream_k) { - if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + if ((int)blocks_num.x % ntiles_dst == 0 && (int)blocks_num.x > ntiles_dst) { + // Optimized fixup: nblocks_stream_k is a multiple of ntiles_dst, launch one block per tile. + const int nblocks_sk = (int)blocks_num.x; + const int bpt = nblocks_sk / ntiles_dst; + + const uint3 fd0 = init_fastdiv_values(ntiles_x * ntiles_z_gqa * K->ne[2]); + const uint3 fd1 = init_fastdiv_values(ntiles_x * ntiles_z_gqa); + const uint3 fd2 = init_fastdiv_values(ntiles_x); + + const dim3 block_dim_combine(DV, 1, 1); + const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2}; + + flash_attn_stream_k_fixup_uniform + <<>> + ((float *) KQV->data, dst_tmp_meta.ptr, + Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk, + gqa_ratio, bpt, fd0, fd1, fd2); + } else if (ntiles_dst % blocks_num.x != 0) { + // General fixup for the cases where nblocks_stream_k < ntiles_dst. + const int total_work = ntiles_KV * ntiles_dst; + + const uint3 fd_k_j_z_ne12 = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa * K->ne[2]); + const uint3 fd_k_j_z = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa); + const uint3 fd_k_j = init_fastdiv_values(ntiles_KV * ntiles_x); + const uint3 fd_k = init_fastdiv_values(ntiles_KV); + const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup + flash_attn_stream_k_fixup_general <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa); + ((float *) KQV->data, dst_tmp_meta.ptr, + Q->ne[1], Q->ne[2], gqa_ratio, total_work, + fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); From 0c2fbd4703a7a64a71dc07e60a17d89dac81d57b Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Mon, 6 Apr 2026 11:55:21 -0700 Subject: [PATCH 390/831] ggml: add Q1_0 1-bit quantization support (CPU) (llama/21273) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml: add Q1_0 and Q1_0_g128 1-bit quantization support (CPU) * add generic fallback for x86 * remove Q1_0 (group size 32) * rename Q1_0_g128 => Q1_0 * fix Q1_0 LlamaFileType Enum * Fix trailing spaces; add generic fallback for othre backends * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret * fix /r/n spacing + arch-fallback --------- Co-authored-by: Sigbjørn Skjæret --- ggml/include/ggml.h | 4 +- ggml/src/ggml-common.h | 11 +++ ggml/src/ggml-cpu/arch-fallback.h | 7 ++ ggml/src/ggml-cpu/arch/arm/quants.c | 103 ++++++++++++++++++++++ ggml/src/ggml-cpu/arch/loongarch/quants.c | 1 - ggml/src/ggml-cpu/arch/powerpc/quants.c | 1 - ggml/src/ggml-cpu/arch/s390/quants.c | 1 - ggml/src/ggml-cpu/arch/wasm/quants.c | 1 - ggml/src/ggml-cpu/ggml-cpu.c | 6 ++ ggml/src/ggml-cpu/ops.cpp | 2 + ggml/src/ggml-cpu/quants.c | 49 ++++++++++ ggml/src/ggml-cpu/quants.h | 3 + ggml/src/ggml-quants.c | 75 ++++++++++++++++ ggml/src/ggml-quants.h | 3 + ggml/src/ggml.c | 10 +++ 15 files changed, 272 insertions(+), 5 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 669f66b650f..3bb2faa2c66 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -428,7 +428,8 @@ extern "C" { // GGML_TYPE_IQ4_NL_8_8 = 38, GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) - GGML_TYPE_COUNT = 41, + GGML_TYPE_Q1_0 = 41, + GGML_TYPE_COUNT = 42, }; // precision @@ -465,6 +466,7 @@ extern "C" { GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors + GGML_FTYPE_MOSTLY_Q1_0 = 27, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 92cf739e7a7..f05683b44cd 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -93,6 +93,10 @@ typedef sycl::half2 ggml_half2; // QR = QK / number of values before dequantization // QI = number of 32 bit integers before dequantization +#define QI1_0 (QK1_0 / 32) +#define QR1_0 1 + + #define QI4_0 (QK4_0 / (4 * QR4_0)) #define QR4_0 2 @@ -170,6 +174,13 @@ typedef sycl::half2 ggml_half2; #define GGML_EXTENSION __extension__ #endif // _MSC_VER +#define QK1_0 128 +typedef struct { + ggml_half d; // delta + uint8_t qs[QK1_0 / 8]; // bits / quants +} block_q1_0; +static_assert(sizeof(block_q1_0) == sizeof(ggml_half) + QK1_0 / 8, "wrong q1_0 block size/padding"); + #define QK4_0 32 typedef struct { ggml_half d; // delta diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 41da829315b..c589a213e9d 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -16,6 +16,7 @@ #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -82,6 +83,7 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 @@ -112,6 +114,7 @@ // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K @@ -160,6 +163,7 @@ #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -200,6 +204,7 @@ #elif defined(__riscv) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 @@ -240,6 +245,7 @@ // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -303,6 +309,7 @@ #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 82b048bb3ae..e09db59cf22 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -137,6 +137,109 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in //===================================== Dot products ================================= +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; // 128 + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + float sumf = 0.0f; + +#if defined(__ARM_NEON) + float32x4_t sumv = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d); + + // Process 4 Q8_0 blocks (each has 32 elements) + for (int k = 0; k < 4; k++) { + const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k]; + const float d1 = GGML_CPU_FP16_TO_FP32(yb->d); + + // Get the 4 bytes of bits for this Q8_0 block (32 bits = 4 bytes) + // Bits are at offset k*4 bytes in x[i].qs + const uint8_t * bits = &x[i].qs[k * 4]; + + // Load 32 int8 values from y + const int8x16_t y0 = vld1q_s8(yb->qs); + const int8x16_t y1 = vld1q_s8(yb->qs + 16); + + // Byte 0-1: bits for y0[0..15] + const uint64_t expand0 = table_b2b_0[bits[0]]; + const uint64_t expand1 = table_b2b_0[bits[1]]; + // Byte 2-3: bits for y1[0..15] + const uint64_t expand2 = table_b2b_0[bits[2]]; + const uint64_t expand3 = table_b2b_0[bits[3]]; + + // Build the sign vectors by reinterpreting the table values + uint8x8_t e0 = vcreate_u8(expand0); + uint8x8_t e1 = vcreate_u8(expand1); + uint8x8_t e2 = vcreate_u8(expand2); + uint8x8_t e3 = vcreate_u8(expand3); + + // Shift right by 4 to get 0 or 1 + int8x8_t s0 = vreinterpret_s8_u8(vshr_n_u8(e0, 4)); + int8x8_t s1 = vreinterpret_s8_u8(vshr_n_u8(e1, 4)); + int8x8_t s2 = vreinterpret_s8_u8(vshr_n_u8(e2, 4)); + int8x8_t s3 = vreinterpret_s8_u8(vshr_n_u8(e3, 4)); + + // Convert 0/1 to -1/+1: sign = 2*val - 1 + int8x8_t one = vdup_n_s8(1); + s0 = vsub_s8(vadd_s8(s0, s0), one); // 2*s0 - 1 + s1 = vsub_s8(vadd_s8(s1, s1), one); + s2 = vsub_s8(vadd_s8(s2, s2), one); + s3 = vsub_s8(vadd_s8(s3, s3), one); + + // Combine into 16-element vectors + int8x16_t signs0 = vcombine_s8(s0, s1); + int8x16_t signs1 = vcombine_s8(s2, s3); + + // Multiply signs with y values and accumulate + // dot(signs, y) where signs are +1/-1 + int32x4_t p0 = ggml_vdotq_s32(vdupq_n_s32(0), signs0, y0); + int32x4_t p1 = ggml_vdotq_s32(p0, signs1, y1); + + // Scale by d1 and accumulate + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(p1), d0 * d1); + } + } + + sumf = vaddvq_f32(sumv); +#else + // Scalar fallback + for (int i = 0; i < nb; i++) { + const float d0 = GGML_FP16_TO_FP32(x[i].d); + + // Process 4 Q8_0 blocks + for (int k = 0; k < 4; k++) { + const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d); + + int sumi = 0; + for (int j = 0; j < QK8_0; j++) { + const int bit_index = k * QK8_0 + j; + const int byte_index = bit_index / 8; + const int bit_offset = bit_index % 8; + + const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1; + sumi += xi * y[i*4 + k].qs[j]; + } + sumf += d0 * d1 * sumi; + } + } +#endif + + *s = sumf; +} + + void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index f531e916b9e..74e0c086c6d 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -2156,4 +2156,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ggml/src/ggml-cpu/arch/powerpc/quants.c index d3dfd049eaf..644c380c738 100644 --- a/ggml/src/ggml-cpu/arch/powerpc/quants.c +++ b/ggml/src/ggml-cpu/arch/powerpc/quants.c @@ -2302,4 +2302,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index 34184ed8510..500857579a7 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -1463,4 +1463,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c index 74a359e6d12..648c6fcaba7 100644 --- a/ggml/src/ggml-cpu/arch/wasm/quants.c +++ b/ggml/src/ggml-cpu/arch/wasm/quants.c @@ -1218,4 +1218,3 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7486acc2b5d..2b3eb5b5ce6 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -217,6 +217,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_F16, .nrows = 1, }, + [GGML_TYPE_Q1_0] = { + .from_float = quantize_row_q1_0, + .vec_dot = ggml_vec_dot_q1_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, [GGML_TYPE_Q4_0] = { .from_float = quantize_row_q4_0, .vec_dot = ggml_vec_dot_q4_0_q8_0, diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 765ce07f06c..0b5d6c6df88 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4829,6 +4829,7 @@ void ggml_compute_forward_get_rows( const ggml_tensor * src0 = dst->src[0]; switch (src0->type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -5554,6 +5555,7 @@ void ggml_compute_forward_clamp( ggml_compute_forward_clamp_f16(params, dst); } break; case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 7ebbb9c6f15..f66127c2290 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -22,6 +22,10 @@ #define UNUSED GGML_UNUSED +void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q1_0_ref(x, y, k); +} + void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { quantize_row_q4_0_ref(x, y, k); } @@ -116,6 +120,51 @@ void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI //===================================== Dot products ================================= +void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_FP16_TO_FP32(x[i].d); + + float sumi = 0.0f; + + for (int k = 0; k < 4; k++) { + const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d); + + int sumi_block = 0; + + for (int j = 0; j < QK8_0; j++) { + const int bit_index = k * QK8_0 + j; + const int byte_index = bit_index / 8; + const int bit_offset = bit_index % 8; + + const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1; + sumi_block += xi * y[i*4 + k].qs[j]; + } + + sumi += d1 * sumi_block; + } + + sumf += d0 * sumi; + } + + *s = sumf; +} + + void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index 3584aaa43e8..d4bc87a1c05 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -12,6 +12,7 @@ extern "C" { #endif // Quantization +void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -36,6 +37,7 @@ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dot product +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -68,6 +70,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 48695a61ea3..15443aa554a 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -32,6 +32,41 @@ static inline int best_index_int8(int n, const int8_t * val, float x) { return x - val[mu-1] < val[mu] - x ? mu-1 : mu; } +// reference implementation for deterministic creation of model files +void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k) { + static const int qk = QK1_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float sum_abs = 0.0f; + for (int j = 0; j < qk; j++) { + sum_abs += fabsf(x[i*qk + j]); + } + const float d = sum_abs / qk; + + y[i].d = GGML_FP32_TO_FP16(d); + + // Clear all bits first + for (int j = 0; j < qk / 8; ++j) { + y[i].qs[j] = 0; + } + + // Just store sign of each weight directly (no normalization) + for (int j = 0; j < qk; ++j) { + const int bit_index = j; + const int byte_index = bit_index / 8; + const int bit_offset = bit_index % 8; + + if (x[i*qk + j] >= 0.0f) { + y[i].qs[byte_index] |= (1 << bit_offset); + } + } + } +} + // reference implementation for deterministic creation of model files void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -339,6 +374,26 @@ void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RE } } +void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK1_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + const float neg_d = -d; + + for (int j = 0; j < qk; ++j) { + const int byte_index = j / 8; + const int bit_offset = j % 8; + const uint8_t bit = (x[i].qs[byte_index] >> bit_offset) & 1; + y[i*qk + j] = bit ? d : neg_d; + } + } +} + void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -1978,6 +2033,22 @@ static void quantize_row_q4_0_impl(const float * GGML_RESTRICT x, block_q4_0 * G } } +size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q1_0_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q1_0, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q1_0, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q1_0_ref(src, (block_q1_0*)qrow, n_per_row); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + + size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row); @@ -5286,6 +5357,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte } } } break; + case GGML_TYPE_Q1_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q1_0, data, nb); + } break; case GGML_TYPE_Q4_0: { VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb); diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 00604f75c0e..d56c86da890 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -14,6 +14,7 @@ extern "C" { // NOTE: these functions are defined as GGML_API because they used by the CPU backend // Quantization +GGML_API void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); @@ -41,6 +42,7 @@ GGML_API void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_ GGML_API void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); // Dequantization +GGML_API void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -90,6 +92,7 @@ GGML_API size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e9b6720c0af..0142498d967 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -651,6 +651,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row, }, + [GGML_TYPE_Q1_0] = { + .type_name = "q1_0", + .blck_size = QK1_0, + .type_size = sizeof(block_q1_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q1_0, + .from_float_ref = (ggml_from_float_t) quantize_row_q1_0_ref, + }, [GGML_TYPE_Q4_0] = { .type_name = "q4_0", .blck_size = QK4_0, @@ -1384,6 +1392,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; + case GGML_FTYPE_MOSTLY_Q1_0: wtype = GGML_TYPE_Q1_0; break; case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; @@ -7652,6 +7661,7 @@ size_t ggml_quantize_chunk( size_t result = 0; switch (type) { + case GGML_TYPE_Q1_0: result = quantize_q1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; From 9cbc4b3acb70f1eabd916b7deacf0ba511185ee8 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Tue, 7 Apr 2026 05:08:46 +0900 Subject: [PATCH 391/831] ggml-webgpu: Add the support of `MUL_MAT_ID` (llama/21147) * Add mul_mat_id support to WebGPU * Apply suggestion from @reeselevine --------- Co-authored-by: Reese Levine --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 134 +++++++++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 202 ++++++++++++++++++ .../wgsl-shaders/mul_mat_decls.tmpl | 2 + .../ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl | 193 +++++++++++++++++ .../wgsl-shaders/mul_mat_id_gather.wgsl | 55 +++++ 5 files changed, 585 insertions(+), 1 deletion(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 669d2cd53a8..c10157766d9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -658,6 +658,26 @@ struct ggml_webgpu_mul_mat_shader_decisions { uint32_t mul_mat_wg_size; }; +/** MUL_MAT_ID **/ + +struct ggml_webgpu_mul_mat_id_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + + bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type; + } +}; + +struct ggml_webgpu_mul_mat_id_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_id_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + return seed; + } +}; + /** Cpy **/ struct ggml_webgpu_cpy_pipeline_key { @@ -797,7 +817,10 @@ class ggml_webgpu_shader_lib { std::unordered_map mul_mat_vec_pipelines; // fast mat-vec (n==1) std::unordered_map - mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + std::unordered_map mul_mat_id_gather_pipelines; // key is fixed + std::unordered_map + mul_mat_id_pipelines; // src0_type/src1_type std::unordered_map set_rows_pipelines; @@ -1598,6 +1621,115 @@ class ggml_webgpu_shader_lib { return mul_mat_legacy_pipelines[key]; } + webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) { + auto it = mul_mat_id_gather_pipelines.find(1); + if (it != mul_mat_id_gather_pipelines.end()) { + return it->second; + } + std::vector defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_mul_mat_id_gather, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, "mul_mat_id_gather"); + pipeline.context = decisions; + mul_mat_id_gather_pipelines[1] = pipeline; + return pipeline; + } + + webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_id_pipeline_key key = { + .src0_type = context.src0->type, + .src1_type = context.src1->type, + }; + + auto it = mul_mat_id_pipelines.find(key); + if (it != mul_mat_id_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "mul_mat_id"; + defines.push_back("MUL_MAT_ID"); + + // src1 type + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat fast shader"); + } + + // src0 type + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + const char * src0_name = src0_traits->type_name; + + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("FLOAT"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("FLOAT"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + variant += "_f16"; + break; + default: + { + std::string type_upper = src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("INIT_SRC0_SHMEM_" + type_upper); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); + + variant += std::string("_") + src0_name; + break; + } + } + + defines.push_back("SCALAR"); + + // Tiles + defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); + defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); + defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); + + defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); + defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); + + // variant suffix for src1 type + variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16"); + + auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines); + + auto decisions = std::make_shared(); + decisions->tile_k = WEBGPU_MUL_MAT_TILE_K; + decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; + decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; + decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M; + decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N; + decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_id_pipelines[key] = pipeline; + return mul_mat_id_pipelines[key]; + } + webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool is_unary = context.dst->op == GGML_OP_UNARY; const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5c567dc0df0..5b118393640 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1376,6 +1376,163 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); } +static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, + wgpu::CommandEncoder & encoder, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .src2 = src2, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + // Get or create pipeline + webgpu_pipeline gather_pipeline, main_pipeline; + + std::vector pipelines; + std::vector> params_list; + std::vector> entries_list; + std::vector> workgroups_list; + + gather_pipeline = ctx->shader_lib->get_mul_mat_id_gather_pipeline(shader_lib_ctx); + main_pipeline = ctx->shader_lib->get_mul_mat_id_pipeline(shader_lib_ctx); + + const uint32_t param_n_expert = (uint32_t) src0->ne[2]; + const uint32_t param_n_expert_used = (uint32_t) dst->ne[1]; + const uint32_t param_n_tokens = (uint32_t) dst->ne[2]; + + // params for mul_mat_id_gather.wgsl + std::vector gather_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + param_n_expert, + param_n_expert_used, + param_n_tokens, + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + }; + + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + const size_t gathered_buf_nbytes = src0->ne[2] * src1->ne[2] * sizeof(uint32_t); + + const size_t gathered_expert_used_align_offset = ROUNDUP_POW2( + dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t gathered_tokens_align_offset = + ROUNDUP_POW2(gathered_expert_used_align_offset + gathered_buf_nbytes, + ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t gathered_count_ids_align_offset = + ROUNDUP_POW2(gathered_tokens_align_offset + gathered_buf_nbytes, + ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + + const size_t gathered_binding_size = ROUNDUP_POW2(gathered_buf_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT); + const size_t gathered_count_ids_binding_size = + ROUNDUP_POW2(src0->ne[2] * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + + // bind group entries for mul_mat_id_gather.wgsl + std::vector gather_entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src2), + .offset = ggml_webgpu_tensor_align_offset(ctx, src2), + .size = ggml_webgpu_tensor_binding_size(ctx, src2) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = gathered_expert_used_align_offset, + .size = gathered_binding_size }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = gathered_tokens_align_offset, + .size = gathered_binding_size }, + { .binding = 3, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = gathered_count_ids_align_offset, + .size = gathered_count_ids_binding_size }, + }; + + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + + const uint32_t gather_total_wg = param_n_expert; + const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim); + const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x); + + pipelines.push_back(gather_pipeline); + params_list.push_back(std::move(gather_params)); + entries_list.push_back(std::move(gather_entries)); + workgroups_list.push_back({ gather_wg_x, gather_wg_y }); + + // params for mul_mat_id.wgsl + std::vector main_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + param_n_expert, + param_n_expert_used, + param_n_tokens, + (uint32_t) src1->ne[1], + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + }; + + // bind group entries for mul_mat_id.wgsl + std::vector main_entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + { .binding = 3, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = gathered_expert_used_align_offset, + .size = gathered_binding_size }, + { .binding = 4, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = gathered_tokens_align_offset, + .size = gathered_binding_size }, + { .binding = 5, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = gathered_count_ids_align_offset, + .size = gathered_count_ids_binding_size }, + }; + + // Calculate workgroup dimensions + uint32_t wg_x = 1; + uint32_t wg_y = 1; + + auto * main_decisions = static_cast(main_pipeline.context.get()); + + uint32_t wg_m; + + uint32_t tile_m_s = main_decisions->tile_m * main_decisions->wg_size_m; + uint32_t tile_n_s = main_decisions->tile_n * main_decisions->wg_size_n; + wg_m = CEIL_DIV(dst->ne[0], tile_m_s); + uint32_t total_gathered = dst->ne[1] * dst->ne[2]; + uint32_t max_active_experts = std::min((uint32_t) src0->ne[2], total_gathered); + uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts; + uint32_t total_wg = wg_m * max_wg_n; + + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + + pipelines.push_back(main_pipeline); + params_list.push_back(std::move(main_params)); + entries_list.push_back(std::move(main_entries)); + workgroups_list.push_back({ wg_x, wg_y }); + + return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, + entries_list, workgroups_list); +} + #ifndef __EMSCRIPTEN__ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, wgpu::CommandEncoder & encoder, @@ -2638,6 +2795,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context return ggml_webgpu_get_rows(ctx, encoder, src0, src1, node); case GGML_OP_MUL_MAT: return ggml_webgpu_mul_mat(ctx, encoder, src0, src1, node); + case GGML_OP_MUL_MAT_ID: + return ggml_webgpu_mul_mat_id(ctx, encoder, src0, src1, src2, node); case GGML_OP_FLASH_ATTN_EXT: #ifndef __EMSCRIPTEN__ return ggml_webgpu_flash_attn(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node); @@ -3082,6 +3241,20 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer } } break; + case GGML_OP_MUL_MAT_ID: + { + const ggml_tensor * src0 = tensor->src[0]; + const ggml_tensor * src1 = tensor->src[1]; + if (src0 && src1) { + const size_t gathered_size = sizeof(uint32_t) * tensor->src[0]->ne[2] * tensor->src[1]->ne[2]; + const size_t gathered_count_ids_size = sizeof(uint32_t) * tensor->src[0]->ne[2]; + res = ROUNDUP_POW2( + res + gathered_size * 2 + gathered_count_ids_size + + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment * 3, + WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + break; default: break; } @@ -3503,6 +3676,35 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } break; } + case GGML_OP_MUL_MAT_ID: + switch (src1->type) { + case GGML_TYPE_F16: + supports_op |= (src0->type == GGML_TYPE_F16); + break; + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + supports_op = true; + break; + default: + break; + } + break; + default: + break; + } + break; case GGML_OP_FLASH_ATTN_EXT: { #ifndef __EMSCRIPTEN__ diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index eb228537bad..ea91c13468f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -42,6 +42,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } #endif // INIT_SRC0_SHMEM_FLOAT +#ifndef MUL_MAT_ID #ifdef INIT_SRC1_SHMEM_FLOAT fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { @@ -58,6 +59,7 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3 } } #endif // INIT_SRC1_SHMEM_FLOAT +#endif #ifdef INIT_SRC0_SHMEM_Q4_0 const BLOCK_SIZE = 32u; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl new file mode 100644 index 00000000000..5f763a6400a --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl @@ -0,0 +1,193 @@ +enable f16; + +#include "common_decls.tmpl" +#include "mul_mat_decls.tmpl" + +#ifdef VEC +fn store_val(acc: array, TILE_N>, tn: u32, tm: u32) -> vec4 { + return vec4(f32(acc[tn][tm]), f32(acc[tn][tm + 1]), f32(acc[tn][tm + 2]), f32(acc[tn][tm + 3])); +} +#endif + +#ifdef SCALAR +fn store_val(acc: array, TILE_N>, tn: u32, tm: u32) -> f32 { + return f32(acc[tn][tm]); +} +#endif + +struct MulMatIdParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + k: u32, + m: u32, + n_expert: u32, + n_expert_used: u32, + n_tokens: u32, + b_ne1: u32, + + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, +}; + +@group(0) @binding(0) var src0: array; // [cols, rows, n_expert] +@group(0) @binding(1) var src1: array; // [cols, b_ne1, n_tokens] +@group(0) @binding(2) var dst: array; // [rows, n_expert_used, n_tokens] +@group(0) @binding(3) var global_gathered_expert_used: array; // [n_expert][n_tokens] +@group(0) @binding(4) var global_gathered_tokens: array; // [n_expert][n_tokens] +@group(0) @binding(5) var gathered_count_ids: array; // [n_expert] + +@group(0) @binding(6) var params: MulMatIdParams; + +fn get_local_n(thread_id: u32) -> u32 { + return thread_id / WORKGROUP_SIZE_M; +} +fn get_local_m(thread_id: u32) -> u32 { + return thread_id % WORKGROUP_SIZE_M; +} + +const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; +const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; +const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; + +var shmem: array; +var gathered_expert_used: array; +var gathered_tokens: array; + +#ifdef INIT_SRC1_SHMEM_FLOAT +fn init_shmem_id_src1(thread_id: u32, offset_src1: u32, rest_token_n: u32, k_outer: u32) { + for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { + let tile_n = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + if (tile_n < rest_token_n) { + let global_src10 = k_outer + tile_k; + let expert_used_idx = gathered_expert_used[tile_n] % params.b_ne1; + let token_idx = gathered_tokens[tile_n]; + let src1_idx = offset_src1 + token_idx * params.stride_12 + expert_used_idx * params.stride_11 + global_src10; + let src1_val = select( + SRC1_TYPE(0.0), + src1[src1_idx/VEC_SIZE], + global_src10 < params.k); + store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx); + } else { + store_shmem(SHMEM_TYPE(0.0), TILE_SRC0_SHMEM + elem_idx); + } + } +} +#endif // INIT_SRC1_SHMEM_FLOAT + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + + let thread_id = local_id.x; + let local_m = get_local_m(thread_id); + let local_n = get_local_n(thread_id); + + var expert_idx:u32 = 0xFFFFFFFFu; + var wg_in_batch:u32 = 0; + var wg_sum:u32 = 0; + let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + + for (var i = 0u;i < params.n_expert;i += 1) { + let wg_n_count = (gathered_count_ids[i] + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); + let wg_per_matrix = wg_m_count * wg_n_count; + if (wg_sum <= wg_linear && wg_linear < wg_sum + wg_per_matrix) { + expert_idx = i; + wg_in_batch = wg_linear - wg_sum; + break; + } + wg_sum += wg_per_matrix; + } + + let is_valid = expert_idx != 0xFFFFFFFFu; + + var wg_m: u32 = 0; + var wg_n: u32 = 0; + var offset_wg_m: u32 = 0; + var offset_wg_n: u32 = 0; + var rest_token_n: u32 = 0; + var src0_batch_offset: u32 = 0; + + wg_m = wg_in_batch % wg_m_count; + wg_n = wg_in_batch / wg_m_count; + + offset_wg_m = wg_m * WORKGROUP_SIZE_M * TILE_M; + offset_wg_n = wg_n * WORKGROUP_SIZE_N * TILE_N; + + if (is_valid) { + rest_token_n = gathered_count_ids[expert_idx] - offset_wg_n; + let global_gathered_base = expert_idx * params.n_tokens + offset_wg_n; + for (var i = thread_id; i < TILE_N * WORKGROUP_SIZE_N && offset_wg_n + i < gathered_count_ids[expert_idx]; i += TOTAL_WORKGROUP_SIZE) { + gathered_expert_used[i] = global_gathered_expert_used[global_gathered_base + i]; + gathered_tokens[i] = global_gathered_tokens[global_gathered_base + i]; + } + src0_batch_offset = params.offset_src0 + expert_idx * params.stride_02; + } + + workgroupBarrier(); + + let output_row_base = offset_wg_m + local_m * TILE_M; + let output_col_base = offset_wg_n + local_n * TILE_N; + + let dst2_stride = params.m * params.n_expert_used; + let dst1_stride = params.m; + + var acc: array, TILE_N>; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + if (is_valid) { + init_shmem_src0(thread_id, src0_batch_offset, offset_wg_m, k_outer); + init_shmem_id_src1(thread_id, params.offset_src1, rest_token_n, k_outer); + } + + workgroupBarrier(); + + if (is_valid) { + let k_end = min(TILE_K, params.k - k_outer); + + for (var k_inner = 0u; k_inner < k_end; k_inner++) { + var src0_tile: array; + for (var tm = 0u; tm < TILE_M; tm++) { + let src0_m = local_m * TILE_M + tm; + let src0_idx = k_inner + src0_m * TILE_K; + src0_tile[tm] = shmem[src0_idx]; + } + for (var tn = 0u; tn < TILE_N; tn++) { + let src1_n = local_n * TILE_N + tn; + let src1_idx = src1_n * TILE_K + k_inner; + let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; + for (var tm = 0u; tm < TILE_M; tm++) { + acc[tn][tm] += src0_tile[tm] * src1_val; + } + } + } + } + + workgroupBarrier(); + } + + if (is_valid) { + for (var tn = 0u; tn < TILE_N; tn++) { + let n_idx = output_col_base + tn; + if (n_idx < gathered_count_ids[expert_idx]) { + let dst1_idx = gathered_expert_used[n_idx - offset_wg_n]; + let dst2_idx = gathered_tokens[n_idx - offset_wg_n]; + let dst12_offset = params.offset_dst + dst2_idx * dst2_stride + dst1_idx * dst1_stride; + for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) { + let global_row = output_row_base + tm; + if (global_row < params.m) { + let dst_idx = dst12_offset + global_row; + dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm); + } + } + } + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl new file mode 100644 index 00000000000..d79d5f3f282 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl @@ -0,0 +1,55 @@ +enable f16; + +struct MulMatIdGatherParams { + offset_ids: u32, + + n_expert: u32, + n_expert_used: u32, + n_tokens: u32, + + stride_ids_1: u32, +}; + +@group(0) @binding(0) var ids: array; // [n_expert_used, n_tokens] +@group(0) @binding(1) var global_gathered_expert_used: array; // [n_expert][n_tokens] +@group(0) @binding(2) var global_gathered_tokens: array; // [n_expert][n_tokens] +@group(0) @binding(3) var gathered_count_ids: array; // [n_expert] + +@group(0) @binding(4) var params: MulMatIdGatherParams; + +var count:atomic; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + + let thread_id = local_id.x; + let own_expert = wg_id.y * num_wg.x + wg_id.x; // the expert assigned to this workgroup + + if (own_expert < params.n_expert) { + if (thread_id == 0u) { + atomicStore(&count, 0); + } + + workgroupBarrier(); + + for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) { + let row = i / params.n_expert_used; + let col = i % params.n_expert_used; + let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]); + if (own_expert == expert) { + let pos = atomicAdd(&count, 1u); + let gathered_id = own_expert * params.n_tokens + pos; + global_gathered_expert_used[gathered_id] = col; + global_gathered_tokens[gathered_id] = row; + } + } + + workgroupBarrier(); + + if (thread_id == 0u) { + gathered_count_ids[own_expert] = atomicLoad(&count); + } + } +} From 1ebf3cafa03bf94ae71795f2ceb4a3b2effc7cea Mon Sep 17 00:00:00 2001 From: PMZFX Date: Tue, 7 Apr 2026 04:12:49 -0400 Subject: [PATCH 392/831] Add Q8_0 reorder optimization (~3x tg speedup on Intel Arc) (llama/21527) Extend the existing reorder optimization to Q8_0. The reorder separates scale factors from weight data for coalesced memory access -- was implemented for Q4_0/Q4_K/Q6_K but Q8_0 was missing. On Arc Pro B70 (Xe2), Q8_0 tg goes from 4.88 to 15.24 t/s (3.1x) on Qwen3.5-27B. BW utilization: 21% -> 66%. The key fix beyond the kernels: Q8_0 was missing from the type check in ggml_backend_sycl_buffer_init_tensor() that allocates the extra struct carrying the reorder flag -- so the optimization was silently skipped. AI (Claude) was used to assist with root cause investigation and writing the kernel code. All code was human-reviewed and tested on real hardware. Fixes: #21517 --- ggml/src/ggml-sycl/dequantize.hpp | 16 +++++ ggml/src/ggml-sycl/dmmv.cpp | 104 +++++++++++++++++++++++++++++- ggml/src/ggml-sycl/ggml-sycl.cpp | 42 +++++++++++- ggml/src/ggml-sycl/mmvq.cpp | 27 +++++++- ggml/src/ggml-sycl/quants.hpp | 21 ++++++ ggml/src/ggml-sycl/vecdotq.hpp | 40 ++++++++++++ 6 files changed, 247 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 3272724f41b..f992db33b2d 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -143,6 +143,22 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib, #endif // GGML_SYCL_F16 } +static __dpct_inline__ void dequantize_q8_0_reorder(const void *d_ptr, const int64_t ib, const void *qs, + const int iqs, dfloat2 &v) { + const dfloat d = (const dfloat)*((const sycl::half*)d_ptr + ib); + + v.x() = ((const int8_t *)qs)[iqs + 0]; + v.y() = ((const int8_t *)qs)[iqs + 1]; + +#ifdef GGML_SYCL_F16 + v.s0() *= d; + v.s1() *= d; +#else + v.x() *= d; + v.y() *= d; +#endif // GGML_SYCL_F16 +} + static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q8_0 * x = (const block_q8_0 *) vx; diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 4f2760110c2..1c8b6f3771f 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -972,6 +972,103 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y, } } +static void dequantize_mul_mat_vec_q8_0_sycl_reorder(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + // Q8_0 reorder layout: [all qs (ncols*nrows bytes)][all d values] + // Cannot reuse dequantize_mul_mat_vec_reorder template because it has + // Q4_0-specific constants hardcoded (d_ptr offset and qs stride). + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row >= nrows) return; + + const int tid = item_ct1.get_local_id(2); + const int iter_stride = 8*2*GGML_SYCL_DMMV_X; + const int vals_per_iter = iter_stride / WARP_SIZE; + const int ncols_left = ncols % (QK8_0*WARP_SIZE); + const int ncols_align = ncols - ncols_left; + +#ifdef GGML_SYCL_F16 + sycl::half2 tmp = {0.0f, 0.0f}; +#else + float tmp = 0.0f; +#endif + const char *d_ptr = (const char*)vx + ncols*nrows; // d after all qs + + int i = 0; + for (i = 0; i < ncols_align; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/QK8_0; + const int iqs = col % QK8_0; + +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + dfloat2 v; + dequantize_q8_0_reorder((const void *)d_ptr, ib, (const void *)vx, + ib * QK8_0 + iqs + j, v); + +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[col + j + 0], y[col + j + 1]}; + tmp += v * t1; +#else + tmp += v.x() * y[col + j + 0]; + tmp += v.y() * y[col + j + 1]; +#endif + } + } + + // handle remaining columns + for (; i < ncols; i += iter_stride) { + if (tid >= ncols_left/QK8_0) continue; + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/QK8_0; + const int iqs = col % QK8_0; + +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + dfloat2 v; + dequantize_q8_0_reorder((const void *)d_ptr, ib, (const void *)vx, + ib * QK8_0 + iqs + j, v); + +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[col + j + 0], y[col + j + 1]}; + tmp += v * t1; +#else + tmp += v.x() * y[col + j + 0]; + tmp += v.y() * y[col + j + 1]; +#endif + } + } + + // reduce + const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2; + for (int mask = mask_start; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { +#ifdef GGML_SYCL_F16 + dst[row] = tmp.x() + tmp.y(); +#else + dst[row] = tmp; +#endif + } + }); + } +} + static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y, float *dst, const int ncols, const int nrows, @@ -1122,7 +1219,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q8_0: - dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q8_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_Q2_K: dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 28be4939784..e80ead9aea4 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -411,7 +411,7 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, assert(tensor->view_src->buffer->buft == buffer->buft); return GGML_STATUS_SUCCESS; } - if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) && + if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) && !g_ggml_sycl_disable_optimize) { ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; tensor->extra = extra; @@ -3254,6 +3254,7 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) { inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: return true; case GGML_TYPE_Q4_K: case GGML_TYPE_Q6_K: @@ -3266,6 +3267,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: return true; default: return false; @@ -3275,6 +3277,7 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_K: case GGML_TYPE_Q6_K: return true; @@ -3364,6 +3367,40 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr sycl_ext_free(stream, tmp_buf); } +static void reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, + dpct::queue_ptr stream) { + uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + + sycl::event copy_event; + SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); + if (!g_ggml_sycl_use_async_mem_op) { + copy_event.wait(); + } + + GGML_ASSERT((size % sizeof(block_q8_0) == 0)); + GGML_ASSERT((offset % sizeof(block_q8_0) == 0)); + int offset_blks = offset / sizeof(block_q8_0); + auto qs_ptr = data_device + offset_blks * QK8_0; + auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows) + offset_blks; + + auto reorder_event = stream->parallel_for( + size / sizeof(block_q8_0), + [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + const block_q8_0* x = (const block_q8_0*)tmp_buf; + const int ib = i; + + for (int j = 0; j < QK8_0; j++) + { + *((int8_t*)qs_ptr + ib * QK8_0 + j) = x[ib].qs[j]; + } + *(d_ptr + ib) = x[ib].d; + }); + if (!g_ggml_sycl_use_async_mem_op) { + reorder_event.wait_and_throw(); + } + sycl_ext_free(stream, tmp_buf); +} + static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q4_K) == 0); GGML_ASSERT(offset % sizeof(block_q4_K) == 0); @@ -3460,6 +3497,9 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { case GGML_TYPE_Q4_0: reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); break; + case GGML_TYPE_Q8_0: + reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream); + break; case GGML_TYPE_Q4_K: reorder_qw_q4_k(data_device, size, 0, stream); break; diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 5abc50fabfe..af22b98dddb 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -679,6 +679,25 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, } } +static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} + static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -1101,7 +1120,13 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q2_K: mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index 14490fea5be..1f5b62740a8 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -105,6 +105,27 @@ template <> struct block_q_t { static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK8_0; // 32 + static constexpr uint32_t qi = QI8_0; // 8 + static constexpr uint32_t qr = QR8_0; // 1 + static constexpr uint32_t vdr_mmvq = 4; + }; + + // Q8_0 reorder layout: [qs0|qs1|...|qsN][d0|d1|...|dN] + // Each block has 32 int8 weights (32 bytes) followed by all scales + static constexpr std::pair get_block_offset(const int block_index, const int /* nblocks */) { + return { block_index * QK8_0, 0 }; + } + + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + return { (ncols * nrows) + block_index * sizeof(ggml_half), 0 }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } // 1 +}; + } // namespace ggml_sycl_reordered #endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index eab9850aed7..9253168e5ea 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -351,6 +351,46 @@ template <> struct reorder_vec_dot_q_sycl { }; }; +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q8_0; + + using q8_0_block = ggml_sycl_reordered::block_q_t; + using q8_0_traits = typename q8_0_block::traits; + + __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const float & d8_0, const sycl::half2 & ds8) { + int sumi = 0; + +#pragma unroll + for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) { + // Q8_0 values are signed int8, no nibble extraction needed + // Direct dp4a: each int packs 4 int8 values + sumi = dpct::dp4a(v[i], u[i], sumi); + } + + const sycl::float2 ds8f = ds8.convert(); + + // Q8_0 has no bias term (values are signed), so just scale + return d8_0 * sumi * ds8f.x(); + } + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, + const std::pair d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const int8_t * bq8_0 = static_cast(vbq) + ibx_offset.first; + const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset.first)); + int v[q8_0_traits::vdr_mmvq]; + int u[q8_0_traits::vdr_mmvq]; + +#pragma unroll + for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) { + v[i] = get_int_from_int8(bq8_0, iqs + i); + u[i] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i); + } + + return vec_dot_q8_0_q8_1_impl(v, u, d, *q8_1_ds); + }; +}; + static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales, const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { From a1f76fb4cfd05ed08c96d2f569379551f6e6989f Mon Sep 17 00:00:00 2001 From: Antoine Viallon Date: Tue, 7 Apr 2026 12:18:55 +0200 Subject: [PATCH 393/831] ggml-cuda : fix CDNA2 compute capability constant for gfx90a (MI210) (llama/21519) GGML_CUDA_CC_CDNA2 was set to 0x910 Fix by setting the constant to 0x90a to match the actual gfx90a ISA. --- ggml/src/ggml-cuda/common.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 9affe023403..1c9233b4fc1 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -65,7 +65,7 @@ #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a #define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers -#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing +#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x90a) // MI210 (gfx90a), minimum acc register renaming #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 // RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 From 18c98ffaf7355917935915cfedd95414fccdc1a2 Mon Sep 17 00:00:00 2001 From: mkoker <132301062+mkoker@users.noreply.github.com> Date: Tue, 7 Apr 2026 07:41:29 -0400 Subject: [PATCH 394/831] vulkan: add FA dequant for q4_1, q5_0, q5_1, iq4_nl (llama/21029) Add dequantize4() implementations for Q4_1, Q5_0, Q5_1, and IQ4_NL in the flash attention base shader. Register them in the shader generator, pipeline creation, and enable in the scalar/coopmat1 FA support check. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 24 +++-- .../vulkan-shaders/flash_attn_base.glsl | 102 +++++++++++++++++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 +- 3 files changed, 118 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 15ed5b2a79d..19e7fbdaae7 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3447,11 +3447,19 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, ) } else { CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32) } #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { @@ -3459,6 +3467,10 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT1, _cm1) } #endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) @@ -15331,11 +15343,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: - // supported in scalar and coopmat2 paths - break; case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_IQ4_NL: + // supported in scalar and coopmat2 paths + break; // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently //case GGML_TYPE_Q2_K: //case GGML_TYPE_Q3_K: @@ -15350,12 +15363,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm //case GGML_TYPE_IQ3_XXS: //case GGML_TYPE_IQ3_S: //case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ4_NL: - // currently supported only in coopmat2 path - if (!coopmat2) { - return false; - } - break; + default: return false; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 172d38f034e..b30dee86871 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -110,7 +110,11 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 +#elif defined(DATA_A_Q4_1) +#define BLOCK_BYTE_SIZE 20 +#endif +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); @@ -119,7 +123,12 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q4_1 + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); +#endif } else { uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); @@ -127,11 +136,100 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q4_1 + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); +#endif } } #endif +#if defined(DATA_A_Q5_0) +#define BLOCK_BYTE_SIZE 22 +#elif defined(DATA_A_Q5_1) +#define BLOCK_BYTE_SIZE 24 +#endif + +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + +#ifdef DATA_A_Q5_1 + uint qh = k_packed.k_data_packed16[a_offset + ib].qh; +#else + uint qh = uint(k_packed.k_data_packed16[a_offset + ib].qh[0]) | (uint(k_packed.k_data_packed16[a_offset + ib].qh[1]) << 16); +#endif + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); + + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q5_1 + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); +#endif + } else { + uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + +#ifdef DATA_A_Q5_1 + uint qh = v_packed.v_data_packed16[a_offset + ib].qh; +#else + uint qh = uint(v_packed.v_data_packed16[a_offset + ib].qh[0]) | (uint(v_packed.v_data_packed16[a_offset + ib].qh[1]) << 16); +#endif + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); + + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); +#ifdef DATA_A_Q5_1 + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); +#else + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); +#endif + } +} +#endif + + +#if defined(DATA_A_IQ4_NL) +#define BLOCK_BYTE_SIZE 18 + +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( + kvalues_iq4nl[vui_lo & 0xF], + kvalues_iq4nl[(vui_lo >> 8) & 0xF], + kvalues_iq4nl[vui_hi & 0xF], + kvalues_iq4nl[(vui_hi >> 8) & 0xF]); + } else { + uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( + kvalues_iq4nl[vui_lo & 0xF], + kvalues_iq4nl[(vui_lo >> 8) & 0xF], + kvalues_iq4nl[vui_hi & 0xF], + kvalues_iq4nl[(vui_hi >> 8) & 0xF]); + } +} +#endif #if defined(DATA_A_Q8_0) #define BLOCK_BYTE_SIZE 34 FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 8186dba36f6..bf04f4822eb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -655,7 +655,7 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); @@ -666,7 +666,7 @@ void process_shaders() { if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); From 78b4fd85e13faf0dd25c3e94b3aecac3b4def041 Mon Sep 17 00:00:00 2001 From: Tom Overlund Date: Tue, 7 Apr 2026 07:54:55 -0400 Subject: [PATCH 395/831] ggml: Vulkan build, Linux -- output error string for errno on fork failure (#20868) (llama/20904) --- ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index bf04f4822eb..7afdcef7d22 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -137,6 +137,7 @@ void execute_command(std::vector& command, std::string& stdout_str, pid_t pid = fork(); if (pid < 0) { + std::cerr << strerror(errno) << "\n"; throw std::runtime_error("Failed to fork process"); } From f1d2b83db08b3ae11b72e0044015d1b196bbbdce Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 7 Apr 2026 15:28:27 +0300 Subject: [PATCH 396/831] ggml : deprecate GGML_OP_ADD1 (llama/21363) * ggml : deprecate GGML_OP_ADD1 * cont : remove tests * cont : re-enable vulkan check --- ggml/include/ggml.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3bb2faa2c66..11d3e8a8167 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -902,15 +902,17 @@ extern "C" { struct ggml_tensor * b, struct ggml_tensor * ids); - GGML_API struct ggml_tensor * ggml_add1( + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_add1( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b), + "use ggml_add instead"); - GGML_API struct ggml_tensor * ggml_add1_inplace( + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_add1_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b), + "use ggml_add_inplace instead"); // dst = a // view(dst, nb1, nb2, nb3, offset) += b From 5ef7aafa0678070ce3cb428162c0c51fafc54d51 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 8 Apr 2026 00:57:04 +0800 Subject: [PATCH 397/831] CUDA: check for buffer overlap before fusing (llama/21566) * CUDA: check for buffer overlap before fusing * use ggml_cuda_check_fusion_memory_ranges --- ggml/src/ggml-cuda/ggml-cuda.cu | 138 ++++++++++++++++---------------- 1 file changed, 71 insertions(+), 67 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 75b62129ade..25b904b7dc2 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3308,6 +3308,71 @@ static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int nod return true; } +// returns whether the write (out) nodes overwrite the read nodes in operation +static bool ggml_cuda_check_fusion_memory_ranges(const ggml_cgraph * cgraph, + const int node_idx, + const int node_count, + const int * out_nodes, + const int out_count, + const bool is_topk_moe = false) { + auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) { + const int64_t a_start = (int64_t) a->data; + const int64_t a_end = a_start + ggml_backend_buft_get_alloc_size(a->buffer->buft, a); + + const int64_t b_start = (int64_t) b->data; + const int64_t b_end = b_start + ggml_backend_buft_get_alloc_size(b->buffer->buft, b); + + if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) { + return true; + } + + return false; + }; + + bool is_ok = true; + // exception for topk-moe, as each row is read entirely before writing + if (ggml_nrows(cgraph->nodes[node_idx]) == 1 && is_topk_moe) { + return true; + } + + for (int i = 0; i < out_count; ++i) { + const ggml_tensor * dst = cgraph->nodes[out_nodes[i]]; + + for (int j = node_idx; j < node_idx + node_count; ++j) { + // Loop over all srcs of all nodes in the fusion. If the src overlaps + // the destination and the src is not an intermediate node that's being + // elided, then disable fusion. + + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + const ggml_tensor * src = cgraph->nodes[j]->src[src_idx]; + + if (!src || src->op == GGML_OP_NONE) { + continue; + } + + if (nodes_overlap(dst, src)) { + bool found = false; + + for (int k = node_idx; k < j; ++k) { + if (cgraph->nodes[k] == src) { + found = true; + break; + } + } + + if (!found) { + is_ok = false; + break; + } + } + } + } + } + + return is_ok; +} + + static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops, @@ -3337,7 +3402,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, const ggml_tensor * glu = cgraph->nodes[node_idx + 4]; if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) { - return true; + int out_nodes[] = { node_idx + 4 }; + return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1); } } @@ -3348,7 +3414,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, const ggml_tensor * glu = cgraph->nodes[node_idx + 2]; if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) { - return true; + int out_nodes[] = { node_idx + 2 }; + return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1); } } @@ -3474,69 +3541,6 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return false; } -// returns whether the write (out) nodes overwrite the read nodes in operation -static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph, - int node_idx, - int node_count, - int * out_nodes, - int out_count) { - auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) { - const int64_t a_start = (int64_t) a->data; - const int64_t a_end = a_start + ggml_nbytes(a); - - const int64_t b_start = (int64_t) b->data; - const int64_t b_end = b_start + ggml_nbytes(b); - - if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) { - return true; - } - - return false; - }; - - bool is_ok = true; - // for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok - if (ggml_nrows(cgraph->nodes[node_idx]) == 1) { - return true; - } - - for (int i = 0; i < out_count; ++i) { - const ggml_tensor * dst = cgraph->nodes[out_nodes[i]]; - - for (int j = node_idx; j < node_idx + node_count; ++j) { - // Loop over all srcs of all nodes in the fusion. If the src overlaps - // the destination and the src is not an intermediate node that's being - // elided, then disable fusion. - - for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { - const ggml_tensor * src = cgraph->nodes[j]->src[src_idx]; - - if (!src || src->op == GGML_OP_NONE) { - continue; - } - - if (nodes_overlap(dst, src)) { - bool found = false; - - for (int k = node_idx; k < j; ++k) { - if (cgraph->nodes[k] == src) { - found = true; - break; - } - } - - if (!found) { - is_ok = false; - break; - } - } - } - } - } - - return is_ok; -} - static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) { bool graph_evaluated_or_captured = false; @@ -3734,7 +3738,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && ggml_cuda_should_use_topk_moe(node, logits, weights, ids) && - ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) { + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) { ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); i += ops.size() - 1; continue; @@ -3750,7 +3754,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud int out_nodes[2] = { i + 1, i + 5 }; if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) && - ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) { + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) { ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); i += ops.size() - 1; continue; From d1456437e1867fa957eb298648c68e48261ee476 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 7 Apr 2026 10:30:01 -0700 Subject: [PATCH 398/831] ggml-webgpu: parameterize submission size and add iOS specific limits (llama/21533) * Work towards removing bitcast * Move rest of existing types over * Add timeout back to wait and remove synchronous set_tensor/memset_tensor * move to unpackf16 for wider compatibility * cleanup * Remove deadlock condition in free_bufs * Start work on removing parameter buffer pools * Simplify and optimize further * simplify profile futures * Fix stride * Try using a single command buffer per batch * formatting * Add parameters for different browsers in-flight submissions * Update handling of batch size too * Throttle ios as much as possible * Increase timeout for llvm-pipe testing --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 148 ++++++++++++++++++++------- 1 file changed, 113 insertions(+), 35 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5b118393640..3d038924b78 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -16,7 +16,6 @@ #include #include -#include #include #include #ifdef GGML_WEBGPU_GPU_PROFILE @@ -25,7 +24,6 @@ #if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE) # include #endif -#include #include #include #include @@ -81,13 +79,13 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim /* Constants */ -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 32u -#define WEBGPU_NUM_PARAM_SLOTS \ - (WEBGPU_COMMAND_SUBMIT_BATCH_SIZE + 10) // a few extra for safety, since some operations may need multiple slots -#define WEBGPU_WAIT_ANY_TIMEOUT_MS 100 -#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters -#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 -#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 +#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 32u +#define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u +#define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u +#define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6) +#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters +#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 +#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 // For operations which process a row in parallel, this seems like a reasonable // default @@ -252,6 +250,8 @@ struct webgpu_global_context_struct { wgpu::Adapter adapter; wgpu::Device device; wgpu::Queue queue; + uint32_t command_submit_batch_size = WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE; + uint32_t max_inflight_batches = UINT32_MAX; webgpu_capabilities capabilities; // Shared buffer to move data from device to host @@ -417,16 +417,72 @@ static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & } #endif +template +static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status, + T callback_status, + T success_status, + const char * wait_name, + const char * failure_name, + const char * callback_message) { + if (wait_status == wgpu::WaitStatus::TimedOut) { + GGML_ABORT("ggml_webgpu: %s timed out after %u ms\n", wait_name, WEBGPU_RUNTIME_WAIT_TIMEOUT_MS); + } + if (wait_status == wgpu::WaitStatus::Error) { + GGML_ABORT("ggml_webgpu: %s failed\n", wait_name); + } + if (callback_status != success_status) { + GGML_ABORT("ggml_webgpu: %s failed with status %d: %s\n", failure_name, static_cast(callback_status), + callback_message); + } +} + +#ifdef __EMSCRIPTEN__ +// iOS browsers seem to have very strict limits on the number of in-flight GPU commands, so we need to throttle to avoid failures. +EM_JS(int, ggml_webgpu_is_ios_browser, (), { + const ua = navigator.userAgent; + return (ua.includes('iPhone') || ua.includes('iPad')) ? 1 : 0; +}); +#endif + +static uint32_t ggml_backend_webgpu_get_max_inflight_batches(const wgpu::AdapterInfo & info) { +#ifdef __EMSCRIPTEN__ + if (ggml_webgpu_is_ios_browser()) { + return 1; + } +#else + GGML_UNUSED(info); +#endif + + return UINT32_MAX; +} + +static uint32_t ggml_backend_webgpu_get_command_submit_batch_size(const wgpu::AdapterInfo & info) { +#ifdef __EMSCRIPTEN__ + if (ggml_webgpu_is_ios_browser()) { + return 16; + } +#else + GGML_UNUSED(info); +#endif + + return WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE; +} + static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) { - ctx->instance.WaitAny( - ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", - std::string(message).c_str()); - } - }), - UINT64_MAX); + wgpu::QueueWorkDoneStatus callback_status = wgpu::QueueWorkDoneStatus::Error; + std::string callback_message; + + const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( + ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, + [&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + callback_status = status; + callback_message = std::string(message); + }), + WEBGPU_RUNTIME_WAIT_TIMEOUT_NS); + + ggml_backend_webgpu_check_wait_status(wait_status, callback_status, wgpu::QueueWorkDoneStatus::Success, + "Queue wait", "Queue work", callback_message.c_str()); } static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, @@ -434,14 +490,31 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, wgpu::MapMode mode, size_t offset, size_t size) { - ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", - message.data); - } - }), - UINT64_MAX); + wgpu::MapAsyncStatus callback_status = wgpu::MapAsyncStatus::Error; + std::string callback_message; + + const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( + buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, + [&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) { + callback_status = status; + callback_message = std::string(message); + }), + WEBGPU_RUNTIME_WAIT_TIMEOUT_NS); + + ggml_backend_webgpu_check_wait_status(wait_status, callback_status, wgpu::MapAsyncStatus::Success, + "Buffer map wait", "Buffer map", callback_message.c_str()); +} + +static void ggml_backend_webgpu_submit_commands(webgpu_context & ctx, + const wgpu::CommandBuffer commands, + uint32_t & num_inflight_batches) { + if (num_inflight_batches >= ctx->global_ctx->max_inflight_batches) { + ggml_backend_webgpu_wait_queue(ctx->global_ctx); + num_inflight_batches = 0; + } + + ctx->global_ctx->queue.Submit(1, &commands); + num_inflight_batches++; } #ifdef GGML_WEBGPU_DEBUG @@ -2871,9 +2944,10 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str #ifdef GGML_WEBGPU_GPU_PROFILE std::vector profile_futures; #endif - uint32_t num_batched_kernels = 0; - bool contains_set_rows = false; - wgpu::CommandEncoder batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + uint32_t num_batched_kernels = 0; + uint32_t num_inflight_batches = 0; + bool contains_set_rows = false; + wgpu::CommandEncoder batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); for (int i = 0; i < cgraph->n_nodes; i++) { if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { @@ -2884,10 +2958,10 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str num_batched_kernels += cmd.value().num_kernels; } - if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { + if (num_batched_kernels >= ctx->global_ctx->command_submit_batch_size) { num_batched_kernels = 0; wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); - ctx->global_ctx->queue.Submit(1, &batch_commands); + ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches); #ifdef GGML_WEBGPU_GPU_PROFILE ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); #endif @@ -2898,7 +2972,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str } if (!commands.empty()) { wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); - ctx->global_ctx->queue.Submit(1, &batch_commands); + ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches); #ifdef GGML_WEBGPU_GPU_PROFILE ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); #endif @@ -2912,7 +2986,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, ctx->set_rows_host_error_buf.GetSize()); wgpu::CommandBuffer set_rows_commands = encoder.Finish(); - ctx->global_ctx->queue.Submit(1, &set_rows_commands); + ggml_backend_webgpu_submit_commands(ctx, set_rows_commands, num_inflight_batches); } ggml_backend_webgpu_wait_queue(ctx->global_ctx); @@ -3363,6 +3437,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { } #endif ctx->webgpu_global_ctx->adapter.GetInfo(&info); + ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(info); + ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(info); wgpu::SupportedFeatures features; ctx->webgpu_global_ctx->adapter.GetFeatures(&features); // we require f16 support @@ -3483,8 +3559,10 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); - webgpu_ctx->param_arena.init(webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, WEBGPU_NUM_PARAM_SLOTS, - webgpu_ctx->global_ctx->capabilities.limits.minUniformBufferOffsetAlignment); + webgpu_ctx->param_arena.init( + webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, + webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN, + webgpu_ctx->global_ctx->capabilities.limits.minUniformBufferOffsetAlignment); ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf"); From d91d1e8e6c362ac503cc659d5be8cafd7c35ab86 Mon Sep 17 00:00:00 2001 From: iacopPBK Date: Tue, 7 Apr 2026 21:47:42 +0200 Subject: [PATCH 399/831] ggml-cuda: ds_read_b128 for q4_0 and q4_1 mmq kernels (llama/21168) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ds_read_b128 for q4_0 and q4_1 mmq kernels Current for loop generates ds_read_b32 instructions with hip compiler, the new solution generates ds_read_b128 instructions for the same operation, saving some LDS bandwidth. Tested on MI50 and RX6800XT, its faster on both. * Vectorized lds load update: used ggml_cuda_get_max_cpy_bytes and ggml_cuda_memcpy_1 functions for generic implementation * Explicit for loop in mmq, renamed vec into tmp * Fixed max_cpy usage in the loading loop * Fixed typo in q4_1 kernel * Update ggml/src/ggml-cuda/mmq.cuh Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/mmq.cuh Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/mmq.cuh Co-authored-by: Johannes Gäßler * Renoved trailing white line 500 * Update mmq.cuh removed other whitelines * Remove trailing whitespaces --------- Co-authored-by: iacopPBK Co-authored-by: Johannes Gäßler Co-authored-by: iacopPBK --- ggml/src/ggml-cuda/mmq.cuh | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 51e8dad4ce7..489d3616bb4 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -386,17 +386,25 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); int u[2*VDR_Q4_0_Q8_1_MMQ]; -#pragma unroll - for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; + constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); + constexpr int mcpy_int = max_cpy / sizeof(int); + static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ"); + + int tmp0[4], tmp1[4]; + + #pragma unroll + for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) { + ggml_cuda_memcpy_1(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] ); + ggml_cuda_memcpy_1(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0 + l0 * mcpy_int]); } + u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3]; + u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3]; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -489,17 +497,25 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); int u[2*VDR_Q4_1_Q8_1_MMQ]; -#pragma unroll - for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; + constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); + constexpr int mcpy_int = max_cpy / sizeof(int); + static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ"); + + int tmp0[4], tmp1[4]; + + #pragma unroll + for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) { + ggml_cuda_memcpy_1(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] ); + ggml_cuda_memcpy_1(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1 + l0 * mcpy_int]); } + u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3]; + u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3]; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -4170,3 +4186,4 @@ void ggml_cuda_op_mul_mat_q( const int64_t src1_padded_row_size, cudaStream_t stream); bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts); + From fa2eaa433bb9aba9bee5fefc0341a9a0b9d6091b Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 8 Apr 2026 09:05:51 +0800 Subject: [PATCH 400/831] CUDA: make cuda graphs props check faster (llama/21472) * CUDA: compute fast hash instead of expensive props check * use seen node * use memcp --- ggml/src/ggml-cuda/common.cuh | 21 +----- ggml/src/ggml-cuda/ggml-cuda.cu | 113 ++------------------------------ 2 files changed, 6 insertions(+), 128 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 1c9233b4fc1..a2960e5ae3c 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1157,19 +1157,6 @@ struct ggml_tensor_extra_gpu { #define USE_CUDA_GRAPH #endif -struct ggml_cuda_graph_node_properties { - void * node_data; - ggml_op node_op; - enum ggml_type node_type; - int32_t flags; - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; - void * src_data[GGML_MAX_SRC]; - int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; -}; - -static_assert(std::is_trivial::value, "ggml_cuda_graph_node_properties must be trivial"); - struct ggml_cuda_graph { #ifdef USE_CUDA_GRAPH ~ggml_cuda_graph() { @@ -1186,13 +1173,7 @@ struct ggml_cuda_graph { std::vector nodes; bool disable_due_to_gpu_arch = false; bool warmup_complete = false; - std::vector props; - - // these are extra tensors (inputs) that participate in the ggml graph but are not nodes - // they properties also have to match in order to be able to safely reuse a CUDA graph - // ref: https://github.com/ggml-org/llama.cpp/pull/18583 - // ref: https://github.com/ggml-org/llama.cpp/pull/19165 - std::vector extra; + std::vector nodes_copy; bool is_enabled() const { static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 25b904b7dc2..b21196bb4f3 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -82,7 +82,6 @@ #include #include #include -#include static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); @@ -2969,74 +2968,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { return use_cuda_graph; } -static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { - memset(props, 0, sizeof(ggml_cuda_graph_node_properties)); - props->node_data = node->data; - props->node_op = node->op; - props->node_type = node->type; - props->flags = node->flags; - for (int i = 0; i < GGML_MAX_DIMS; i++) { - props->ne[i] = node->ne[i]; - props->nb[i] = node->nb[i]; - } - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (!node->src[i]) { - continue; - } - - props->src_data[i] = node->src[i]->data; - } - memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); -} - -static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { - if (node->data != props->node_data && node->op != GGML_OP_VIEW) { - return false; - } - - if (node->op != props->node_op) { - return false; - } - - if (node->type != props->node_type) { - return false; - } - - for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (node->ne[i] != props->ne[i]) { - return false; - } - if (node->nb[i] != props->nb[i]) { - return false; - } - } - - if (node->op != GGML_OP_VIEW) { - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (!node->src[i]) { - if (props->src_data[i] != nullptr) { - return false; - } - continue; - } - - if (node->src[i]->data != props->src_data[i]) { - return false; - } - } - } - - if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { - return false; - } - - if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) { - return false; - } - - return true; -} - static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) { return cgraph->nodes[0]; } @@ -3048,52 +2979,18 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); // Check if the graph size has changed - if (graph->props.size() != (size_t)cgraph->n_nodes) { + if ((int)graph->nodes_copy.size() != cgraph->n_nodes) { res = true; - graph->props.resize(cgraph->n_nodes); + graph->nodes_copy.resize(cgraph->n_nodes); } - // Loop over nodes in GGML graph to determine if CUDA graph update is required - // and store properties to allow this comparison for the next token - std::unordered_set seen_node; - std::vector srcs_extra; for (int i = 0; i < cgraph->n_nodes; i++) { - bool props_match = true; - - seen_node.insert(cgraph->nodes[i]); - if (!res) { - props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]); - } - if (!props_match) { - res = true; - } - ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]); - - for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { - ggml_tensor * src = cgraph->nodes[i]->src[src_idx]; - if (src && seen_node.find(src) == seen_node.end()) { - srcs_extra.push_back(src); + if (memcmp(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) { + res = true; } } - } - - if (graph->extra.size() != (size_t) srcs_extra.size()) { - res = true; - graph->extra.resize(srcs_extra.size()); - } - - for (size_t i = 0; i < srcs_extra.size(); ++i) { - bool props_match = true; - - if (!res) { - props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]); - } - - if (!props_match) { - res = true; - } - ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]); + memcpy(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)); } return res; From 15deafa31ecdd0a44a9a41494f9f7785068fe822 Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Wed, 8 Apr 2026 06:07:47 -0700 Subject: [PATCH 401/831] metal: Q1_0 backend (llama/21528) * initial Q1_0 Metal backend * tuning q1_0 metal kernels * add Q1_0 to test-backend-ops * add Q1_0<->F32 copy test * Apply suggestions from code review Co-authored-by: Georgi Gerganov --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-metal/ggml-metal-device.cpp | 10 ++ ggml/src/ggml-metal/ggml-metal-device.m | 2 + ggml/src/ggml-metal/ggml-metal-impl.h | 3 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 1 + ggml/src/ggml-metal/ggml-metal.metal | 187 ++++++++++++++++++++++ 5 files changed, 203 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 89539bd7615..e8548b053e8 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -736,6 +736,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta suffix = ne00 % 4 == 0 ? "_4" : ""; } } break; + case GGML_TYPE_Q1_0: + { + nsg = N_SG_Q1_0; + nr0 = N_R0_Q1_0; + } break; case GGML_TYPE_Q4_0: { nsg = N_SG_Q4_0; @@ -948,6 +953,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m smem = 32*sizeof(float)*nr0; suffix = ne00 % 4 == 0 ? "_4" : ""; } break; + case GGML_TYPE_Q1_0: + { + nsg = N_SG_Q1_0; + nr0 = N_R0_Q1_0; + } break; case GGML_TYPE_Q4_0: { nsg = N_SG_Q4_0; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 17d51b11b6e..40cacb46520 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1184,6 +1184,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_TYPE_F16: case GGML_TYPE_BF16: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1210,6 +1211,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te default: return false; } + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index eb2253e029a..62b028f4a4a 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -8,6 +8,9 @@ // // TODO: for optimal performance, become function of the device and work size +#define N_R0_Q1_0 8 +#define N_SG_Q1_0 2 + #define N_R0_Q4_0 4 #define N_SG_Q4_0 2 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 3cda21be43e..846225d9077 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2047,6 +2047,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16 || + op->src[0]->type == GGML_TYPE_Q1_0 || op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_Q5_0 || diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2074211594c..f28bfa0b95b 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -118,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg } #endif +template +void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) { + device const uint8_t * qs = xb->qs; + const float d = xb->d; + const float neg_d = -d; + + const int byte_offset = il * 2; // il*16 bits = il*2 bytes + const uint8_t b0 = qs[byte_offset]; + const uint8_t b1 = qs[byte_offset + 1]; + + float4x4 reg_f; + + reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01)); + reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02)); + reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04)); + reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08)); + reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10)); + reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20)); + reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40)); + reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80)); + + reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01)); + reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02)); + reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04)); + reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08)); + reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10)); + reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20)); + reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40)); + reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80)); + + reg = (type4x4) reg_f; +} + +template +void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) { + const float d = xb->d; + const float neg_d = -d; + const int base = il * 4; + const uint8_t byte = xb->qs[base / 8]; + const int s = base % 8; + + float4 reg_f; + reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1)); + reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1)); + reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1)); + reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1)); + + reg = (type4) reg_f; +} + template void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 1); @@ -152,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r } } +void quantize_q1_0(device const float * src, device block_q1_0 & dst) { + float sum_abs = 0.0f; + for (int j = 0; j < QK1_0; j++) { + sum_abs += fabs(src[j]); + } + dst.d = sum_abs / QK1_0; + + for (int j = 0; j < QK1_0 / 8; j++) { + dst.qs[j] = 0; + } + for (int j = 0; j < QK1_0; j++) { + if (src[j] >= 0.0f) { + dst.qs[j / 8] |= (1 << (j % 8)); + } + } +} + void quantize_q4_0(device const float * src, device block_q4_0 & dst) { #pragma METAL fp math_mode(safe) float amax = 0.0f; // absolute max @@ -3116,6 +3183,35 @@ kernel void kernel_group_norm_f32( } } +// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy) +inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) { + device const uint8_t * qs = qb_curr->qs + il / 8; + const uint8_t b0 = qs[0]; + const uint8_t b1 = qs[1]; + + float acc = 0.0f; + + acc += select(0.0f, yl[ 0], bool(b0 & 0x01)); + acc += select(0.0f, yl[ 1], bool(b0 & 0x02)); + acc += select(0.0f, yl[ 2], bool(b0 & 0x04)); + acc += select(0.0f, yl[ 3], bool(b0 & 0x08)); + acc += select(0.0f, yl[ 4], bool(b0 & 0x10)); + acc += select(0.0f, yl[ 5], bool(b0 & 0x20)); + acc += select(0.0f, yl[ 6], bool(b0 & 0x40)); + acc += select(0.0f, yl[ 7], bool(b0 & 0x80)); + + acc += select(0.0f, yl[ 8], bool(b1 & 0x01)); + acc += select(0.0f, yl[ 9], bool(b1 & 0x02)); + acc += select(0.0f, yl[10], bool(b1 & 0x04)); + acc += select(0.0f, yl[11], bool(b1 & 0x08)); + acc += select(0.0f, yl[12], bool(b1 & 0x10)); + acc += select(0.0f, yl[13], bool(b1 & 0x20)); + acc += select(0.0f, yl[14], bool(b1 & 0x40)); + acc += select(0.0f, yl[15], bool(b1 & 0x80)); + + return qb_curr->d * (2.0f * acc - sumy); +} + // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) // il indicates where the q4 quants begin (0 or QK4_0/4) // we assume that the yl's have been multiplied with the appropriate scale factor @@ -3337,6 +3433,85 @@ void mul_vec_q_n_f32_impl( } } +template +void kernel_mul_mv_q1_0_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + const int nb = args.ne00/QK1_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * NSG + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13; + + device const float * y = (device const float *) (src1 + offset1); + + device const block_q1_0 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0); + } + + float yl[16]; + float sumf[nr0] = {0.f}; + + const short ix = (tiisg/8); + const short il = (tiisg%8)*16; + + device const float * yb = y + ix*QK1_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) { + float sumy = 0.f; + + FOR_UNROLL (short i = 0; i < 16; i++) { + yl[i] = yb[i]; + sumy += yb[i]; + } + + FOR_UNROLL (short row = 0; row < nr0; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il); + } + + yb += QK1_0 * (N_SIMDWIDTH/8); + } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < nr0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_q1_0_f32")]] +kernel void kernel_mul_mv_q1_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q1_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + kernel void kernel_mul_mv_q4_0_f32( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3729,6 +3904,11 @@ template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4 template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>; #endif +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>; + template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>; @@ -7133,6 +7313,7 @@ kernel void kernel_cpy_f32_q( typedef decltype(kernel_cpy_f32_q) cpy_f_q_t; template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; @@ -7173,12 +7354,14 @@ kernel void kernel_cpy_q_f32( typedef decltype(kernel_cpy_q_f32) cpy_q_f_t; +template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32; +template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32; @@ -9776,6 +9959,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro typedef decltype(kernel_get_rows_q) get_rows_q_t; +template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q; @@ -9838,6 +10022,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m #if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm; #endif +template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm; @@ -9861,6 +10046,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm; @@ -10070,6 +10256,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4 template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; From e70c0d43f4a8dba5dc0ba7faa387f88d7b41a74c Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 8 Apr 2026 06:08:29 -0700 Subject: [PATCH 402/831] webgpu : Query for adapter support when registering WebGPU backend (llama/21579) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 39 +++++++++++++++++++--------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 3d038924b78..b8df0f4dd05 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -4033,8 +4033,14 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); static ggml_backend_webgpu_reg_context ctx; + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_webgpu_reg_i, + /* .context = */ &ctx, + }; + ctx.name = GGML_WEBGPU_NAME; - ctx.device_count = 1; + ctx.device_count = 0; wgpu::InstanceDescriptor instance_descriptor{}; std::vector instance_features = { wgpu::InstanceFeatureName::TimedWaitAny }; @@ -4053,19 +4059,28 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); ctx.webgpu_global_ctx->instance = std::move(inst); -#ifdef __EMSCRIPTEN__ - if (ctx.webgpu_global_ctx->instance == nullptr) { - GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n"); - return nullptr; + wgpu::Adapter adapter; + if (ctx.webgpu_global_ctx->instance != nullptr) { + wgpu::RequestAdapterOptions options = {}; + + // probe for adapter support + ctx.webgpu_global_ctx->instance.WaitAny( + ctx.webgpu_global_ctx->instance.RequestAdapter( + &options, wgpu::CallbackMode::AllowSpontaneous, + [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + adapter = std::move(_adapter); + }), + UINT64_MAX); + } + + if (adapter != nullptr) { + ctx.device_count = 1; } -#endif - GGML_ASSERT(ctx.webgpu_global_ctx->instance != nullptr); - static ggml_backend_reg reg = { - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_webgpu_reg_i, - /* .context = */ &ctx, - }; return ® } From 16dd1716204773a2f99a29ecf46748fa29d5f2b9 Mon Sep 17 00:00:00 2001 From: RealOrko <45273739+RealOrko@users.noreply.github.com> Date: Wed, 8 Apr 2026 16:40:15 +0100 Subject: [PATCH 403/831] fix: free ctx_copy in ggml_opt_free to plug per-training-session leak (llama/21592) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: free ctx_copy in ggml_opt_free to plug per-training-session leak ggml_opt_alloc populates opt_ctx->ctx_copy via a free+init pair every time the allocated graph shape changes. The last ctx_copy from the final ggml_opt_alloc call survives until ggml_opt_free is invoked, but ggml_opt_free was only freeing ctx_static and ctx_cpu, never ctx_copy. Each opt_ctx lifetime therefore leaks the final per-batch context — ~900 KB for a typical GNN training session in sindarin-pkg-tensor, surfaced via AddressSanitizer. ctx_copy is nullptr-initialized and ggml_free() handles NULL safely, so the new release is guard-free. * Update ggml/src/ggml-opt.cpp Co-authored-by: Johannes Gäßler --------- Co-authored-by: realorko Co-authored-by: Johannes Gäßler --- ggml/src/ggml-opt.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp index e078ad14a39..53903defa8f 100644 --- a/ggml/src/ggml-opt.cpp +++ b/ggml/src/ggml-opt.cpp @@ -589,6 +589,7 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) { ggml_backend_buffer_free(opt_ctx->buf_cpu); ggml_free(opt_ctx->ctx_static); ggml_free(opt_ctx->ctx_cpu); + ggml_free(opt_ctx->ctx_copy); delete opt_ctx; } From 2c7472939fd6d29bc14a8feabdace940d86aecd0 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 9 Apr 2026 01:01:56 +0800 Subject: [PATCH 404/831] CUDA: also store `node->src->data` ptrs for equality check (llama/21635) * CUDA: also store node->src->data ptrs for equality check * address review comments --- ggml/src/ggml-cuda/common.cuh | 6 +++++- ggml/src/ggml-cuda/ggml-cuda.cu | 21 ++++++++++++++------- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a2960e5ae3c..65d7a6e22ae 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1173,7 +1173,11 @@ struct ggml_cuda_graph { std::vector nodes; bool disable_due_to_gpu_arch = false; bool warmup_complete = false; - std::vector nodes_copy; + struct node_properties { + ggml_tensor node; + void * node_src_data_ptrs[GGML_MAX_SRC]; + }; + std::vector node_props; bool is_enabled() const { static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b21196bb4f3..648124c0d31 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2979,18 +2979,25 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); // Check if the graph size has changed - if ((int)graph->nodes_copy.size() != cgraph->n_nodes) { + if ((int)graph->node_props.size() != cgraph->n_nodes) { res = true; - graph->nodes_copy.resize(cgraph->n_nodes); + graph->node_props.resize(cgraph->n_nodes); } for (int i = 0; i < cgraph->n_nodes; i++) { - if (!res) { - if (memcmp(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) { - res = true; - } + ggml_cuda_graph::node_properties prop = {}; + memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor)); + + // if the backend scheduler is making copies of CPU tensors, the src pointers can be the same but with different data, see: + // https://github.com/ggml-org/llama.cpp/pull/21472#discussion_r3052235188 + for (int j = 0; j < GGML_MAX_SRC; ++j) { + prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j] ? cgraph->nodes[i]->src[j]->data : nullptr; + } + + if (!res && memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) { + res = true; } - memcpy(&graph->nodes_copy[i], cgraph->nodes[i], sizeof(ggml_tensor)); + graph->node_props[i] = prop; } return res; From 1d555510dedb4656ad7dedb2279201dccd9f5858 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 9 Apr 2026 07:31:51 +0200 Subject: [PATCH 405/831] vulkan: unify type macros to use Vx instead of _VECx (llama/21605) --- .../vulkan-shaders/mul_mat_vec_iface.glsl | 12 +- .../vulkan-shaders/mul_mat_vec_q2_k.comp | 2 +- .../vulkan-shaders/mul_mat_vec_q4_k.comp | 2 +- .../vulkan-shaders/mul_mat_vec_q5_k.comp | 2 +- .../vulkan-shaders/mul_mat_vecq_funcs.glsl | 12 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 16 +- .../vulkan-shaders/mul_mm_funcs.glsl | 192 +++++++++--------- .../vulkan-shaders/mul_mmq_funcs.glsl | 16 +- .../vulkan-shaders/mul_mmq_shmem_types.glsl | 16 +- .../vulkan-shaders/vulkan-shaders-gen.cpp | 64 +++--- 10 files changed, 167 insertions(+), 167 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl index 337dbd796ad..e8d053cdd43 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl @@ -6,8 +6,8 @@ #define MAT_VEC_FUSION_FLAGS_SCALE1 0x8 layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -#if defined(A_TYPE_VEC4) -layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +#if defined(A_TYPEV4) +layout (binding = 0) readonly buffer AV4 {A_TYPEV4 data_a_v4[];}; #endif #if defined(A_TYPE_PACKED16) layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; @@ -17,11 +17,11 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32 #endif layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -#ifdef B_TYPE_VEC2 -layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; +#ifdef B_TYPEV2 +layout (binding = 1) readonly buffer BV2 {B_TYPEV2 data_b_v2[];}; #endif -#ifdef B_TYPE_VEC4 -layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; +#ifdef B_TYPEV4 +layout (binding = 1) readonly buffer BV4 {B_TYPEV4 data_b_v4[];}; #endif layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 619de054cb8..975cec8013f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -41,7 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); - const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm); + const FLOAT_TYPEV2 dm = vec2(data_a[ib0 + i].dm); [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 6af5a81587d..93fbacc6282 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); + const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 3695b47b98d..54d7e1bcdca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); + const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index 6ddbed309d7..e99108dc50c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -11,8 +11,8 @@ FLOAT_TYPE get_dm(uint ib) { #endif #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) -FLOAT_TYPE_VEC2 get_dm(uint ib) { - return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); +FLOAT_TYPEV2 get_dm(uint ib) { + return FLOAT_TYPEV2(data_a_packed32[ib].dm); } #endif @@ -23,9 +23,9 @@ FLOAT_TYPE get_dm(uint ib) { #endif #if defined(DATA_A_Q2_K) -FLOAT_TYPE_VEC2 get_dm(uint ib) { +FLOAT_TYPEV2 get_dm(uint ib) { const uint ib_k = ib / 8; - return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + return FLOAT_TYPEV2(data_a_packed32[ib_k].dm); } #endif @@ -304,7 +304,7 @@ vec2 get_dm_scale(uint ib, uint iqs) { (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); } - return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm); + return FLOAT_TYPEV2(data_a_packed32[ib_k].dm) * FLOAT_TYPEV2(scale_dm); } FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { @@ -422,7 +422,7 @@ vec2 get_dm(uint ib, uint iqs) { const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); // the -1 cancels out the bias in iq1s_grid_gpu - return FLOAT_TYPE_VEC2(dl, dl * (delta - 1)); + return FLOAT_TYPEV2(dl, dl * (delta - 1)); } FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 23f3bd8d6d0..89346e48e06 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -125,8 +125,8 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit #define SHMEM_STRIDE (BK / 2 + 1) #endif -shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE]; -shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE]; +shared FLOAT_TYPEV2 buf_a[BM * SHMEM_STRIDE]; +shared FLOAT_TYPEV2 buf_b[BN * SHMEM_STRIDE]; #define NUM_WARPS (BLOCK_SIZE / WARP) @@ -258,17 +258,17 @@ void main() { sums[i] = coopmat(0.0f); } #else - ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2]; + ACC_TYPEV2 sums[WMITER * TM * WNITER * TN/2]; #if defined(DATA_A_F32) || defined(DATA_A_F16) - FLOAT_TYPE_VEC4 cache_a[WMITER * TM]; - FLOAT_TYPE_VEC4 cache_b; + FLOAT_TYPEV4 cache_a[WMITER * TM]; + FLOAT_TYPEV4 cache_b; #else - FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; - FLOAT_TYPE_VEC2 cache_b; + FLOAT_TYPEV2 cache_a[WMITER * TM]; + FLOAT_TYPEV2 cache_b; #endif [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { - sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f); + sums[i] = ACC_TYPEV2(0.0f, 0.0f); } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 3f494eb4d5a..9b769bbc887 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -3,7 +3,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #if LOAD_VEC_A == 8 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]); + FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]); buf_a[buf_idx ] = aa[0].xy; buf_a[buf_idx + 1] = aa[0].zw; buf_a[buf_idx + 2] = aa[1].xy; @@ -11,38 +11,38 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #elif LOAD_VEC_A == 4 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]); + FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]); buf_a[buf_idx ] = aa.xy; buf_a[buf_idx + 1] = aa.zw; #else // LOAD_VEC_BATCH_A == 2 const uint idx = pos_a + col * p.stride_a + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], - data_a[idx + 1]); + buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], + data_a[idx + 1]); } else if (idx_m < p.M && block + row * 2 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f); } else { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif #elif defined(DATA_A_BF16) #if LOAD_VEC_A == 4 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx])); + FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx])); buf_a[buf_idx ] = aa.xy; buf_a[buf_idx + 1] = aa.zw; #else // LOAD_VEC_BATCH_A == 2 const uint idx = pos_a + col * p.stride_a + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), - TO_FLOAT_TYPE(data_a[idx + 1])); + buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), + TO_FLOAT_TYPE(data_a[idx + 1])); } else if (idx_m < p.M && block + row * 2 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); } else { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif #elif defined(DATA_A_Q4_0) @@ -57,10 +57,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy); - buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v0.zw); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(v1.xy); + buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.zw); #elif defined(DATA_A_Q4_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -73,10 +73,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y; const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); - buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw); - buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy); - buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xy); + buf_a[buf_idx + 1 ] = FLOAT_TYPEV2(v0.zw); + buf_a[buf_idx + 8 ] = FLOAT_TYPEV2(v1.xy); + buf_a[buf_idx + 9 ] = FLOAT_TYPEV2(v1.zw); #elif defined(DATA_A_Q5_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -92,8 +92,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint vui = uint(data_a_packed16[ib].qs[iqs]); const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xz); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(v.yw); #elif defined(DATA_A_Q5_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -112,10 +112,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y; const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw); - buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xz); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v1.xz); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(v0.yw); + buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.yw); #elif defined(DATA_A_Q8_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -128,8 +128,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); #elif defined(DATA_A_Q2_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -147,8 +147,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); #elif defined(DATA_A_Q3_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -171,8 +171,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 qs = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> qsshift) & 0x0303).xy); const vec2 hm = vec2(unpack8(((uint(data_a_packed16[ib].hmask[hmi / 2]) >> (4 * n + halfsplit)) & 0x0101 ^ 0x0101) << 2).xy); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * (qs.x - hm.x), - dl * (qs.y - hm.y)); + buf_a[buf_idx] = FLOAT_TYPEV2(dl * (qs.x - hm.x), + dl * (qs.y - hm.y)); #elif defined(DATA_A_Q4_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -206,8 +206,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); + buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q5_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -244,8 +244,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4; const vec4 q = vec4(unpack8(qs | qh)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); + buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q6_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -267,7 +267,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint qh = (uint(data_a_packed16[ib].qh[qhi]) >> qhshift) & 0x0303; const vec2 q = (vec2(unpack8(ql | (qh << 4)).xy) - 32) * dscale; - buf_a[buf_idx] = FLOAT_TYPE_VEC2(q.x, q.y); + buf_a[buf_idx] = FLOAT_TYPEV2(q.x, q.y); #elif defined(DATA_A_IQ1_S) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -284,8 +284,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); [[unroll]] for (int k = 0; k < 4; ++k) { - buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), - dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + buf_a[buf_idx + k] = FLOAT_TYPEV2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); } #elif defined(DATA_A_IQ1_M) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; @@ -306,8 +306,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); [[unroll]] for (int k = 0; k < 4; ++k) { - buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), - dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + buf_a[buf_idx + k] = FLOAT_TYPEV2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); } #elif defined(DATA_A_IQ2_XXS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; @@ -332,14 +332,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 grid0 = vec4(unpack8(grid.x)); const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, - (sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, - (sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, - (sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, - (sign & 128) != 0 ? -grid1.w : grid1.w); + buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_XS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -358,14 +358,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 grid0 = vec4(unpack8(grid.x)); const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, - (sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, - (sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, - (sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, - (sign & 128) != 0 ? -grid1.w : grid1.w); + buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_S) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -386,14 +386,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 grid0 = vec4(unpack8(grid.x)); const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, - (sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, - (sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, - (sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, - (sign & 128) != 0 ? -grid1.w : grid1.w); + buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ3_XXS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -414,10 +414,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint grid = iq3xxs_grid[qs]; const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, - (sign & 2) != 0 ? -v.y : v.y); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, - (sign & 8) != 0 ? -v.w : v.w); + buf_a[buf_idx ] = FLOAT_TYPEV2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ3_S) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -436,10 +436,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, - (sign & 2) != 0 ? -v.y : v.y); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, - (sign & 8) != 0 ? -v.w : v.w); + buf_a[buf_idx ] = FLOAT_TYPEV2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ4_XS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -456,8 +456,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = float(data_a[ib].d); const vec4 v = d * float(int(sl | (sh << 4)) - 32) * vec4(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -468,10 +468,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); const uint vui = uint(data_a_packed16[ib].qs[iqs]); - buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF], - kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]); - buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)], - kvalues_iq4nl[vui >> 12]); + buf_a[buf_idx ] = d * FLOAT_TYPEV2(kvalues_iq4nl[vui & 0xF], + kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]); + buf_a[buf_idx + 8] = d * FLOAT_TYPEV2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)], + kvalues_iq4nl[vui >> 12]); #elif defined(DATA_A_MXFP4) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -483,10 +483,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint vui = uint(data_a[ib].qs[iqs]); const uint vui2 = uint(data_a[ib].qs[iqs+1]); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d, - kvalues_mxfp4[vui2 & 0xF] * d); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d, - kvalues_mxfp4[vui2 >> 4] * d); + buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d, + kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d, + kvalues_mxfp4[vui2 >> 4] * d); #endif } @@ -496,7 +496,7 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin // Not supported for b_type bf16 because bf16mat2x4 does not exist const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; - FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); buf_b[buf_idx + 0] = bb[0].xy; buf_b[buf_idx + 1] = bb[0].zw; buf_b[buf_idx + 2] = bb[1].xy; @@ -505,9 +505,9 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; #if defined(DATA_B_BF16) - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); #else - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); #endif buf_b[buf_idx + 0] = bb.xy; buf_b[buf_idx + 1] = bb.zw; @@ -515,12 +515,12 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const uint idx = pos_b + col * p.stride_b + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_n < p.N && block + row * 2 + 1 < end_k) { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), - TO_FLOAT_TYPE(data_b[idx + 1])); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); } else if (idx_n < p.N && block + row * 2 < end_k) { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); } else { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif } @@ -531,7 +531,7 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; - FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); buf_b[buf_idx + 0] = bb[0].xy; buf_b[buf_idx + 1] = bb[0].zw; buf_b[buf_idx + 2] = bb[1].xy; @@ -541,9 +541,9 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; #if defined(DATA_B_BF16) - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); #else - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); #endif buf_b[buf_idx + 0] = bb.xy; buf_b[buf_idx + 1] = bb.zw; @@ -553,14 +553,14 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin if (row_i < _ne1 && block + row * 2 + 1 < end_k) { const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), - TO_FLOAT_TYPE(data_b[idx + 1])); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); } else if (row_i < _ne1 && block + row * 2 < end_k) { const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); } else { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index 9c297d1c60d..59931b04b94 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -21,7 +21,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; if (iqs == 0) { - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm); } #endif } @@ -72,7 +72,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; if (iqs == 0) { - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm); buf_a[buf_ib].qh = data_a_packed32[ib].qh; } #endif @@ -203,7 +203,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6); if (iqs == 0) { - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib_k].dm); buf_a[buf_ib].scales = unpack8(uint32_t(data_a_packed16[ib_k].scales[iqs_k / 8])).xy; // vec4 used due to #12147 } } @@ -264,7 +264,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) | (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147 - buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales - 32)); + buf_a[buf_ib].d_scales = FLOAT_TYPEV2(float(data_a_packed16[ib_k].d) * vec2(scales - 32)); } } @@ -334,7 +334,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); } - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm)); + buf_a[buf_ib].dm = FLOAT_TYPEV2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm)); } } @@ -385,7 +385,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint is = iqs_k / 4; const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy; - buf_a[buf_ib].d_scales = FLOAT_TYPE_VEC2(float(data_a_packed16[ib_k].d) * vec2(scales)); + buf_a[buf_ib].d_scales = FLOAT_TYPEV2(float(data_a_packed16[ib_k].d) * vec2(scales)); } } @@ -426,7 +426,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bo const uint ib_inner = ib % 4; if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); + buf_b[buf_ib].ds = FLOAT_TYPEV2(data_b[ib_outer].ds[ib_inner]); } const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; @@ -436,7 +436,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bo buf_b[buf_ib].qs[iqs * 4 + 3] = values.w; } else { if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f); + buf_b[buf_ib].ds = FLOAT_TYPEV2(0.0f); } buf_b[buf_ib].qs[iqs * 4 ] = 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl index 1c0f5306f38..c700f6e3f25 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl @@ -8,7 +8,7 @@ struct block_a_cache { #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[16/4]; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q5_0) #define QUANT_R_MMQ 2 @@ -22,7 +22,7 @@ struct block_a_cache { struct block_a_cache { uint32_t qs[16/4]; uint32_t qh; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q8_0) #define QUANT_R_MMQ 1 @@ -43,36 +43,36 @@ struct block_a_cache { struct block_a_cache { uint32_t qs[2]; u8vec2 scales; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q3_K) #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[4]; - FLOAT_TYPE_VEC2 d_scales; + FLOAT_TYPEV2 d_scales; }; #elif defined(DATA_A_Q4_K) #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[4]; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q5_K) #define QUANT_R_MMQ 1 struct block_a_cache { int32_t qs[8]; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q6_K) #define QUANT_R_MMQ 1 struct block_a_cache { int32_t qs[8]; - FLOAT_TYPE_VEC2 d_scales; + FLOAT_TYPEV2 d_scales; }; #endif struct block_b_cache { int32_t qs[8]; - FLOAT_TYPE_VEC2 ds; + FLOAT_TYPEV2 ds; }; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 7afdcef7d22..11385f93378 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -446,8 +446,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c base_dict["FLOAT16"] = "1"; } - base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float"; - base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2"; + base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float"; + base_dict["ACC_TYPEV2"] = f16acc ? "f16vec2" : "vec2"; if (f16acc) { base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; } @@ -514,10 +514,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c }; const std::map float_type_dict_f16 = { - {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")}, - {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")}, - {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")}, - {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")}, + {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")}, + {"FLOAT_TYPEV2", FLOAT_TYPE(2, "f16")}, + {"FLOAT_TYPEV4", FLOAT_TYPE(4, "f16")}, + {"FLOAT_TYPEV8", FLOAT_TYPE(8, "f16")}, }; // Shaders with f16 B_TYPE @@ -536,9 +536,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32"; const std::map float_type_dict_bf16 = { - {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")}, - {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")}, - {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")}, + {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")}, + {"FLOAT_TYPEV2", FLOAT_TYPE(2, "bf16")}, + {"FLOAT_TYPEV4", FLOAT_TYPE(4, "bf16")}, }; // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader @@ -569,10 +569,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; const std::map float_type_dict = { - {"FLOAT_TYPE", FLOAT_TYPE(1, tname)}, - {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)}, - {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)}, - {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)}, + {"FLOAT_TYPE", FLOAT_TYPE(1, tname)}, + {"FLOAT_TYPEV2", FLOAT_TYPE(2, tname)}, + {"FLOAT_TYPEV4", FLOAT_TYPE(4, tname)}, + {"FLOAT_TYPEV8", FLOAT_TYPE(8, tname)}, }; // don't generate f32 variants for coopmat2 @@ -676,36 +676,36 @@ void process_shaders() { } } - std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}}; + std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}}; for (const auto& tname : type_names) { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; - string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); // mul mat vec with integer dot product #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") { - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); } #endif @@ -726,9 +726,9 @@ void process_shaders() { string_to_spv("get_rows_i32", "get_rows.comp", {{"TEMP_TYPE", "uint"}, {"A_TYPE", "uint"}, {"B_TYPE", "int"}, {"D_TYPE", "uint"}}); - string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); - string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); - string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); + string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}); // Norms string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); From 4598eb080b513752b90663fe37b9c92ecbf9e5b1 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Thu, 9 Apr 2026 12:06:48 +0530 Subject: [PATCH 406/831] sycl : add flash-attn support for head size 512 (llama/21654) * sycl : add flash-attn support for head size 512 This patch extends the SYCL Flash Attention implementation to support head sizes (DKQ/DV) of 512. Changes: - Added DKQ/DV 512 cases to both tile and vector Flash Attention kernels. - Updated kernel selection logic to allow vector kernels for head sizes up to 512 (previously 256). - Removed unused/redundant AMD and RDNA-specific configuration functions in `fattn-tile.hpp`. - Refactored `ggml_backend_sycl_buffer_init_tensor` to use a switch statement for clearer tensor extra buffer initialization. - Added necessary template instances for the new 512 head size across various quantization types. * remove defunct mxfp4 reorder from setting buffer type --- ggml/src/ggml-sycl/fattn-tile.cpp | 4 + ggml/src/ggml-sycl/fattn-tile.hpp | 151 +++--------------- ggml/src/ggml-sycl/fattn-vec.hpp | 7 + ggml/src/ggml-sycl/fattn.cpp | 4 +- ggml/src/ggml-sycl/ggml-sycl.cpp | 21 ++- .../fattn-tile-instance-dkq512-dv512.cpp | 6 + .../fattn-vec-instance-f16-f16.cpp | 1 + .../fattn-vec-instance-f16-q4_0.cpp | 1 + .../fattn-vec-instance-f16-q4_1.cpp | 1 + .../fattn-vec-instance-f16-q5_0.cpp | 1 + .../fattn-vec-instance-f16-q5_1.cpp | 1 + .../fattn-vec-instance-f16-q8_0.cpp | 1 + .../fattn-vec-instance-q4_0-f16.cpp | 1 + .../fattn-vec-instance-q4_0-q4_0.cpp | 1 + .../fattn-vec-instance-q4_0-q4_1.cpp | 1 + .../fattn-vec-instance-q4_0-q5_0.cpp | 1 + .../fattn-vec-instance-q4_0-q5_1.cpp | 1 + .../fattn-vec-instance-q4_0-q8_0.cpp | 1 + .../fattn-vec-instance-q4_1-f16.cpp | 1 + .../fattn-vec-instance-q4_1-q4_0.cpp | 1 + .../fattn-vec-instance-q4_1-q4_1.cpp | 1 + .../fattn-vec-instance-q4_1-q5_0.cpp | 1 + .../fattn-vec-instance-q4_1-q5_1.cpp | 1 + .../fattn-vec-instance-q4_1-q8_0.cpp | 1 + .../fattn-vec-instance-q5_0-f16.cpp | 1 + .../fattn-vec-instance-q5_0-q4_0.cpp | 1 + .../fattn-vec-instance-q5_0-q4_1.cpp | 1 + .../fattn-vec-instance-q5_0-q5_0.cpp | 1 + .../fattn-vec-instance-q5_0-q5_1.cpp | 1 + .../fattn-vec-instance-q5_0-q8_0.cpp | 1 + .../fattn-vec-instance-q5_1-f16.cpp | 1 + .../fattn-vec-instance-q5_1-q4_0.cpp | 1 + .../fattn-vec-instance-q5_1-q4_1.cpp | 1 + .../fattn-vec-instance-q5_1-q5_0.cpp | 1 + .../fattn-vec-instance-q5_1-q5_1.cpp | 1 + .../fattn-vec-instance-q5_1-q8_0.cpp | 1 + .../fattn-vec-instance-q8_0-f16.cpp | 1 + .../fattn-vec-instance-q8_0-q4_0.cpp | 1 + .../fattn-vec-instance-q8_0-q4_1.cpp | 1 + .../fattn-vec-instance-q8_0-q5_0.cpp | 1 + .../fattn-vec-instance-q8_0-q5_1.cpp | 1 + .../fattn-vec-instance-q8_0-q8_0.cpp | 1 + 42 files changed, 95 insertions(+), 134 deletions(-) create mode 100644 ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp diff --git a/ggml/src/ggml-sycl/fattn-tile.cpp b/ggml/src/ggml-sycl/fattn-tile.cpp index 9d4f019cf51..9449d75784d 100644 --- a/ggml/src/ggml-sycl/fattn-tile.cpp +++ b/ggml/src/ggml-sycl/fattn-tile.cpp @@ -44,6 +44,10 @@ void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst); } break; + case 512: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case<512, 512>(ctx, dst); + } break; case 576: { GGML_ASSERT(V->ne[0] == 512); ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst); diff --git a/ggml/src/ggml-sycl/fattn-tile.hpp b/ggml/src/ggml-sycl/fattn-tile.hpp index b4d4e0ae90e..9ba5296968d 100644 --- a/ggml/src/ggml-sycl/fattn-tile.hpp +++ b/ggml/src/ggml-sycl/fattn-tile.hpp @@ -67,6 +67,12 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, co GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) @@ -124,6 +130,12 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) @@ -131,134 +143,6 @@ static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, co return 0; } -static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) { - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) - - return 0; -} - -static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) { - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) - GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) - - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) - GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) - - return 0; -} - static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) { if(fast_fp16_available(cc)) return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols); @@ -1293,6 +1177,16 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_sycl_context & ctx, ggm launch_fattn_tile_switch_ncols1(ctx, dst); return; } + // ncols2=2 and ncols2=1 fallbacks only for cases where ncols=2 config exists (DKQ == DV). + // For DKQ == 576, DV == 512 only GQA-optimized variants are implemented. + if constexpr (DKQ == DV) { + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } } if constexpr (DV <= 256) { @@ -1347,5 +1241,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(512, 512); extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-sycl/fattn-vec.hpp b/ggml/src/ggml-sycl/fattn-vec.hpp index 48c389052f4..8031acfdff8 100644 --- a/ggml/src/ggml-sycl/fattn-vec.hpp +++ b/ggml/src/ggml-sycl/fattn-vec.hpp @@ -664,4 +664,11 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q8_0) + #endif // GGML_SYCL_FATTN_VEC_HPP diff --git a/ggml/src/ggml-sycl/fattn.cpp b/ggml/src/ggml-sycl/fattn.cpp index c276ed89827..7c6e6112fdc 100644 --- a/ggml/src/ggml-sycl/fattn.cpp +++ b/ggml/src/ggml-sycl/fattn.cpp @@ -34,6 +34,7 @@ FATTN_VEC_CASE( 64, type_K, type_V) \ FATTN_VEC_CASE(128, type_K, type_V) \ FATTN_VEC_CASE(256, type_K, type_V) \ + FATTN_VEC_CASE(512, type_K, type_V) \ static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_tensor * Q = dst->src[0]; @@ -141,6 +142,7 @@ static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const case 128: case 112: case 256: + case 512: if (V->ne[0] != K->ne[0]) { return BEST_FATTN_KERNEL_NONE; } @@ -185,7 +187,7 @@ static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const } // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: - const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + const bool can_use_vector_kernel = Q->ne[0] <= 512 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; // Todo: Use the XMX kernel if possible: diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index e80ead9aea4..7f9b2df524e 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -411,11 +411,22 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, assert(tensor->view_src->buffer->buft == buffer->buft); return GGML_STATUS_SUCCESS; } - if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) && - !g_ggml_sycl_disable_optimize) { - ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; - tensor->extra = extra; - ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx. + + if (!g_ggml_sycl_disable_optimize) { + // set reorder extra buffer based on supported type + switch (tensor->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K:{ + ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; + tensor->extra = extra; + ctx->tensor_extras.push_back(extra); + break; + } + default: + break; + } } if (ggml_is_quantized(tensor->type)) { diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp new file mode 100644 index 00000000000..9a6a1877566 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp @@ -0,0 +1,6 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(512, 512); + diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp index 32cf4f2859b..43ef94c118c 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp index a61a19021bb..9404061d456 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp index 63b74fb347a..a8bb9f52d0c 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp index 46e2d9853c5..7d61f6ab0af 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp index 7aabb6ff6e4..753bae09f83 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp index 148ea217f62..546a93b2570 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp index 4b169dbcdbc..53c8c2f2654 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp index 79f530b1815..5b409c55f21 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp index 2f7db51ce82..8c4ef588d63 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp index 9e3bf0b14a1..83f0a07552e 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp index 18081879cec..9df9b03bba4 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp index 1c387b0d87c..6980c2a65bb 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp index f005b3762cc..bd61bc1dc2b 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp index 3553b1cdd16..492e229a58e 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp index 687ec567115..30f88a2ebd5 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp index 2663bfe7466..db76663604e 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp index 641b7c7ae2a..1dbcc8a85a8 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp index 3d3181d4719..d30996a6259 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp index 85d5026ad4f..bc0f635d922 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp index 1e81401a2c9..9e0378107cb 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp index 54251473f97..a8535ac9156 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp index d418c1fb21e..43d4fae9a61 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp index 0f26cfabd09..23335a41640 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp index 4fb98723519..52550a33757 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp index 85b79cd1976..4651f14c050 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp index 7348323b28b..2310fd8792c 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp index f19af2aa0ba..d2494048bc1 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp index d7075bac600..be3a1fe97f5 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp index 627f9a57755..be0a89409ca 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp index 23304eecd35..6781efcb0d2 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp index 95acb5d4fbf..43a70ae3543 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp index 5e88f4bab8a..fa7eb8163ca 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp index 69f297feb0c..79d9cfbee96 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp index 455842a9421..86befd5d327 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp index f7ef7391571..c2f619b0b16 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp index 1c633bdf2fa..7cf31f8b8a1 100644 --- a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp @@ -5,3 +5,4 @@ DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); From f0ee409f7b3c1ae1e0b3c1139ef8e7e02b0bb3b3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 9 Apr 2026 10:54:00 +0300 Subject: [PATCH 407/831] metal : add missing mm-id specializations for q1_0 (llama/21662) --- ggml/src/ggml-metal/ggml-metal.metal | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f28bfa0b95b..f67c5cd8a1d 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -10079,6 +10079,7 @@ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_m #if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id; #endif +template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id; @@ -10102,6 +10103,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id; From c4c6e143a7731627be5f6d72c4738ac3ca066bd6 Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Thu, 9 Apr 2026 15:17:11 +0200 Subject: [PATCH 408/831] ggml : check return value of CUB calls used in argsort and top-k (they all return cudaError_t) (llama/21676) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Stanisław Szymczyk --- ggml/src/ggml-cuda/argsort.cu | 32 ++++++++++++++++---------------- ggml/src/ggml-cuda/top-k.cu | 8 ++++---- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 38fdf3678c1..ed4e5de70f5 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -60,24 +60,24 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, if (order == GGML_SORT_ORDER_ASC) { if (nrows == 1) { - DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + ncols, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) temp_indices, dst, // values (indices) ncols * nrows, nrows, // num items, num segments - offset_iterator, offset_iterator + 1, stream); + offset_iterator, offset_iterator + 1, stream)); } } else { if (nrows == 1) { - DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + ncols, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, + CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1, - stream); + stream)); } } @@ -86,22 +86,22 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, if (order == GGML_SORT_ORDER_ASC) { if (nrows == 1) { - DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + ncols, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, - ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream); + CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, + ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream)); } } else { if (nrows == 1) { - DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + ncols, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows, offset_iterator, - offset_iterator + 1, stream); + offset_iterator + 1, stream)); } } } diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu index 785a18389f2..59ce36fb1c9 100644 --- a/ggml/src/ggml-cuda/top-k.cu +++ b/ggml/src/ggml-cuda/top-k.cu @@ -25,14 +25,14 @@ static void top_k_cub(ggml_cuda_pool & pool, auto indexes_in = cuda::make_counting_iterator(0); size_t temp_storage_bytes = 0; - DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k, - env); + CUDA_CHECK(DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k, + env)); ggml_cuda_pool_alloc temp_storage_alloc(pool, temp_storage_bytes); void * d_temp_storage = temp_storage_alloc.get(); - DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, - ncols, k, env); + CUDA_CHECK(DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, + ncols, k, env)); } #elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE From bb895c843d249ee4a15dcfa19caf2d78ad5e2aa0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 9 Apr 2026 16:42:19 +0200 Subject: [PATCH 409/831] ggml: backend-agnostic tensor parallelism (experimental) (llama/19378) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml: backend-agnostic tensor parallelism * support for GPT-OSS, Qwen 3 MoE * partial Vulkan fix * add support for 4/8 GPUs * unconditional peer access * re-use buffers + ggml contexts * fix output pattern * NCCL support * GGML: HIP: add RCCL support * Remove shfl and AllReduce from backend interface * move allocation workaround out of ggml-alloc.c * 2d tensor set/get support * Fix the seg fault without NCCL * Apply suggestion from JohannesGaessler * support for tensor dims % n_devs != 0 * fix view_offs scaling * arbitrary num. of GPUs/tensor split * fix compilation * better granularity estimate * Support device-specific host buffer types if all underlying backends expose the same type. This allows using pinned memory instead of pageable memory for CUDA. Fix compilation errors. * partial Qwen 3 Next support * Fix qwen3 30b (llama/8) * Fix crash with Qwen-30B-A3B Q4_0 Qwen-30B-A3B Q4_0 has an intermediate dimension of 768. Using a granularity of 256 forces an uneven split between GPUs, which is not supported by the current implementation. * Decide block size based on tensor quantization type * Fix crashes due to KV cache serialization (llama/9) KV cache serialization requires non-zero offsets on the tensor. Add support in the meta backend to set/get a tensor with a non-zero offset. * metal : fix build (llama/7) * static memory allocations, fix usage count * fix tensor granularity * more even memory distribution * use BF16 for allreduce * rebase fixup * better error message for unsupported architectures * Fix device mismatch during scatter of allReduce. (llama/11) There is a mismatch between the dst buffer device and the backend device, causing the use of sync copies * Enable the previous allreduce implementation. It is better in both perf and stability (llama/12) * delay AllReduce for Moe for less I/O * build : clean-up compile warnings * backend : move most of the meta backend API to ggml-backend-impl.h * cont : hide unused public API in the implementation * llama : use llama_device + remove ggml_backend_dev_is_meta() * ggml-backend : remove unused alloc include * minor : remove regex include * ggml : introduce ggml-ext.h for staging new APIs * rebase fixup * fix tests * llama : more robust logic for determining Meta devices (llama/16) * llama : more robust logic for determining Meta devices * cont : fix devs size check Co-authored-by: Johannes Gäßler * cont : fix log type Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler * disable roundtrip for meta backend * fix arch selection * Qwen 3.5 support * fix Gemma 4 MoE * fix OpenVino, SYCL * fix test-llama-archs for CPU-only builds * Fix Qwen 3.5 MoE * disable meta backend tests for WebGPU * tests : filter CPU-based devices from the Meta backend tests (llama/17) * meta : formatting, naming, indentation (llama/18) * formatting : llama-model.cpp * formatting : ggml-ext.h * formatting : ggml-backend-meta.cpp * meta : add TODO * add documentation * better error messages * fix GPT-OSS --------- Co-authored-by: Carl Philipp Klemm Co-authored-by: Gaurav Garg Co-authored-by: Georgi Gerganov --- ggml/CMakeLists.txt | 4 + ggml/include/ggml-backend.h | 26 +- ggml/include/ggml-cuda.h | 3 + ggml/src/CMakeLists.txt | 1 + ggml/src/ggml-alloc.c | 3 + ggml/src/ggml-backend-impl.h | 24 +- ggml/src/ggml-backend-meta.cpp | 1923 +++++++++++++++++ ggml/src/ggml-backend.cpp | 110 +- ggml/src/ggml-blas/ggml-blas.cpp | 2 + ggml/src/ggml-cann/ggml-cann.cpp | 4 + ggml/src/ggml-cpu/amx/amx.cpp | 2 + ggml/src/ggml-cpu/ggml-cpu.cpp | 2 + ggml/src/ggml-cuda/CMakeLists.txt | 10 + ggml/src/ggml-cuda/common.cuh | 8 + ggml/src/ggml-cuda/ggml-cuda.cu | 245 ++- ggml/src/ggml-cuda/vendors/cuda.h | 4 + ggml/src/ggml-cuda/vendors/hip.h | 6 + ggml/src/ggml-hexagon/ggml-hexagon.cpp | 4 + ggml/src/ggml-hip/CMakeLists.txt | 12 + ggml/src/ggml-metal/ggml-metal.cpp | 24 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 4 + ggml/src/ggml-openvino/ggml-openvino.cpp | 4 + ggml/src/ggml-rpc/ggml-rpc.cpp | 4 + ggml/src/ggml-sycl/ggml-sycl.cpp | 6 + ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp | 4 + ggml/src/ggml-virtgpu/ggml-backend.cpp | 2 + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 + ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 + ggml/src/ggml-zdnn/ggml-zdnn.cpp | 32 +- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 2 + 30 files changed, 2362 insertions(+), 121 deletions(-) create mode 100644 ggml/src/ggml-backend-meta.cpp diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 5834e544b48..6bf15723b3c 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -7,6 +7,8 @@ set(GGML_VERSION_MINOR 9) set(GGML_VERSION_PATCH 11) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") + find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) if(GIT_EXE) # Get current git commit hash @@ -204,12 +206,14 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON) option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF) option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT}) +option(GGML_CUDA_NCCL "ggml: use NVIDIA Collective Comm. Library" ON) set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING "ggml: cuda link binary compression mode; requires cuda 12.8+") set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size") option(GGML_HIP "ggml: use HIP" OFF) option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) +option(GGML_HIP_RCCL "ggml: use ROCm Collective Comm. Library" OFF) option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 9fd3f7f32a0..3c06aeaffb1 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -68,7 +68,7 @@ extern "C" { GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer); // tensor copy between different backends - GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); + GGML_API void ggml_backend_tensor_copy(const struct ggml_tensor * src, struct ggml_tensor * dst); // // Backend (stream) @@ -83,13 +83,17 @@ extern "C" { GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend); - GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_async (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get_async (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); // "offset" refers to the offset in tensor->data for setting/getting data - GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set ( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get (const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_2d( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); GGML_API void ggml_backend_synchronize(ggml_backend_t backend); @@ -109,7 +113,7 @@ extern "C" { // the copy is performed after all the currently queued operations in backend_src // backend_dst will wait for the copy to complete before performing other operations // automatic fallback to sync copy if async is not supported - GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst); + GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend); @@ -135,7 +139,9 @@ extern "C" { // integrated GPU device using host memory GGML_BACKEND_DEVICE_TYPE_IGPU, // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX) - GGML_BACKEND_DEVICE_TYPE_ACCEL + GGML_BACKEND_DEVICE_TYPE_ACCEL, + // "meta" device wrapping multiple other devices for tensor parallelism + GGML_BACKEND_DEVICE_TYPE_META, }; // functionality supported by the device @@ -196,7 +202,9 @@ extern "C" { // Common functions that may be obtained using ggml_backend_reg_get_proc_address - // Split buffer type for tensor parallelism + // AllReduce operation for tensor parallelism (meta backend) + typedef bool (*ggml_backend_allreduce_tensor_t)(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends); + // Split buffer type for tensor parallelism (old) typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split); // Set the number of threads for the backend typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads); diff --git a/ggml/include/ggml-cuda.h b/ggml/include/ggml-cuda.h index 22ad2c00963..5436c7ef579 100644 --- a/ggml/include/ggml-cuda.h +++ b/ggml/include/ggml-cuda.h @@ -27,6 +27,9 @@ GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend); // device buffer GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device); +// conduct allreduce operation between devices +GGML_BACKEND_API bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends); + // split tensor buffer that splits matrices by rows across multiple devices GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split); diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 78853304d9f..48fbe208d90 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -200,6 +200,7 @@ add_library(ggml-base ggml.cpp ggml-alloc.c ggml-backend.cpp + ggml-backend-meta.cpp ggml-opt.cpp ggml-threading.cpp ggml-threading.h diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 7f414b2311c..e9b70398ffc 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -1236,6 +1236,9 @@ size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { size_t nbytes_total = 0; + if (ggml_backend_buft_is_meta(buft)) { + return ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx, buft); + } return ggml_backend_alloc_ctx_tensors_from_buft_impl(ctx, buft, &nbytes_total, /*no_alloc =*/ false); } diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index 59190b7c465..9c56ec30c5f 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -49,6 +49,10 @@ extern "C" { void (*memset_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + // (optional) 2d data copies + void (*set_tensor_2d)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + void (*get_tensor_2d)(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + // (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported) bool (*cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // clear the entire buffer @@ -80,6 +84,20 @@ extern "C" { GGML_API bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer); GGML_API void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); + // + // Backend (meta) + // + + GGML_API bool ggml_backend_is_meta (ggml_backend_t backend); + GGML_API bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf); + GGML_API bool ggml_backend_buft_is_meta (ggml_backend_buffer_type_t buft); + + GGML_API size_t ggml_backend_meta_n_backends (ggml_backend_t meta_backend); + GGML_API ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index); + + // temporary workaround to statically allocate tensors from a context in a deduplicated way: + GGML_API struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); + // // Backend (stream) // @@ -90,8 +108,10 @@ extern "C" { void (*free)(ggml_backend_t backend); // (optional) asynchronous tensor data access - void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + void (*set_tensor_async) (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor_async) (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + void (*set_tensor_2d_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + void (*get_tensor_2d_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); // (optional) complete all pending operations (required if the backend supports async operations) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp new file mode 100644 index 00000000000..a2ab8872c4a --- /dev/null +++ b/ggml/src/ggml-backend-meta.cpp @@ -0,0 +1,1923 @@ +#include "ggml.h" +#include "ggml-impl.h" +#include "ggml-backend.h" +#include "ggml-backend-impl.h" +#include "ggml-alloc.h" +#include "ggml-cpp.h" + +// TODO: tmp +#include "ggml-ext.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct ggml_backend_meta_device; +struct ggml_backend_meta_buffer_type; +struct ggml_backend_meta_buffer; +struct ggml_backend_meta; + +const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis) { + switch (split_axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + return "0"; + case GGML_BACKEND_SPLIT_AXIS_1: + return "1"; + case GGML_BACKEND_SPLIT_AXIS_2: + return "2"; + case GGML_BACKEND_SPLIT_AXIS_3: + return "3"; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + return "MIRRORED"; + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: + return "PARTIAL"; + case GGML_BACKEND_SPLIT_AXIS_NONE: + return "NONE"; + case GGML_BACKEND_SPLIT_AXIS_UNKNOWN: + return "UNKNOWN"; + default: + GGML_ABORT("fatal error"); + } +} + +// +// meta backend device +// + +struct ggml_backend_meta_device_context { + std::vector simple_devs; + ggml_backend_meta_get_split_state_t get_split_state; + void * get_split_state_ud; + + std::string name; + std::string description; + + ggml_backend_meta_device_context( + std::vector simple_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) : + simple_devs(std::move(simple_devs)), get_split_state(get_split_state), get_split_state_ud(get_split_state_ud) { + name = std::string("Meta("); + description = std::string("Meta("); + for (size_t i = 0; i < simple_devs.size(); i++) { + if (i > 0) { + name += ","; + description += ","; + } + name += ggml_backend_dev_name (simple_devs[i]); + description += ggml_backend_dev_description(simple_devs[i]); + } + name += ")"; + description += ")"; + } + + bool operator<(const ggml_backend_meta_device_context & other) const { + return std::tie(simple_devs, get_split_state, get_split_state_ud) + < std::tie(other.simple_devs, other.get_split_state, other.get_split_state_ud); + } +}; + +static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev); + +static const char * ggml_backend_meta_device_get_name(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return meta_dev_ctx->name.c_str(); +} + +static const char * ggml_backend_meta_device_get_description(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return meta_dev_ctx->description.c_str(); +} + +static void ggml_backend_meta_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + *free = 0; + *total = 0; + for (ggml_backend_dev_t dev : meta_dev_ctx->simple_devs) { + size_t tmp_free, tmp_total; + ggml_backend_dev_memory(dev, &tmp_free, &tmp_total); + *free += tmp_free; + *total += tmp_total; + } +} + +static enum ggml_backend_dev_type ggml_backend_meta_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_META; + + GGML_UNUSED(dev); +} + +static void ggml_backend_meta_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + + // TODO replace placeholders + props->name = ggml_backend_meta_device_get_name(dev); + props->description = ggml_backend_meta_device_get_description(dev); + props->type = ggml_backend_meta_device_get_type(dev); + props->device_id = 0; + + ggml_backend_meta_device_get_memory(dev, &props->memory_free, &props->memory_total); + + props->caps = { + /* .async = */ true, + /* .host_buffer = */ false, // Not implemented. + /* .buffer_from_host_ptr = */ false, // Not implemented. + /* .events = */ false, // Not implemented. + }; + for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) { + ggml_backend_dev_props tmp_props; + ggml_backend_dev_get_props(simple_dev, &tmp_props); + props->caps.async = props->caps.async && tmp_props.caps.async; + props->caps.host_buffer = props->caps.host_buffer && tmp_props.caps.host_buffer; + props->caps.buffer_from_host_ptr = props->caps.buffer_from_host_ptr && tmp_props.caps.buffer_from_host_ptr; + props->caps.events = props->caps.events && tmp_props.caps.events; + } +} + +static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params); + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev); + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev); + +static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return std::all_of(meta_dev_ctx->simple_devs.begin(), meta_dev_ctx->simple_devs.end(), + [op](ggml_backend_dev_t simple_dev) { return ggml_backend_dev_supports_op(simple_dev, op); }); +} + +static bool ggml_backend_meta_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + ggml_backend_dev_t dev_buft = ggml_backend_buft_get_device(buft); + if (!ggml_backend_dev_is_meta(dev_buft)) { + return false; + } + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + const ggml_backend_meta_device_context * meta_buft_dev_ctx = (const ggml_backend_meta_device_context *) dev_buft->context; + if (meta_dev_ctx->simple_devs.size() != meta_buft_dev_ctx->simple_devs.size()) { + return false; + } + for (size_t i = 0; i < meta_dev_ctx->simple_devs.size(); i++) { + if (meta_dev_ctx->simple_devs[i] != meta_buft_dev_ctx->simple_devs[i]) { + return false; + } + } + return true; +} + +static const ggml_backend_device_i ggml_backend_meta_device_iface = { + /* .get_name = */ ggml_backend_meta_device_get_name, + /* .get_description = */ ggml_backend_meta_device_get_description, + /* .get_memory = */ ggml_backend_meta_device_get_memory, + /* .get_type = */ ggml_backend_meta_device_get_type, + /* .get_props = */ ggml_backend_meta_device_get_props, + /* .init_backend = */ ggml_backend_meta_device_init_backend, + /* .get_buffer_type = */ ggml_backend_meta_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_meta_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ nullptr, + /* .supports_op = */ ggml_backend_meta_device_supports_op, + /* .supports_buft = */ ggml_backend_meta_device_supports_buft, + /* .offload_op = */ nullptr, + /* .event_new = */ nullptr, + /* .event_free = */ nullptr, + /* .event_synchronize = */ nullptr, +}; + +static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev) { + return dev != nullptr && dev->iface.get_name == ggml_backend_meta_device_iface.get_name; +} + +static size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context; + return meta_dev_ctx->simple_devs.size(); +} + +static ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev, size_t index) { + GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context; + GGML_ASSERT(index < meta_dev_ctx->simple_devs.size()); + return meta_dev_ctx->simple_devs[index]; +} + +ggml_backend_dev_t ggml_backend_meta_device( + ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) { + GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES); + // TODO: this is not thread-safe - needs to be fixed + static std::vector> ctxs; + static std::map meta_devs; + + std::vector simple_devs; + simple_devs.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + simple_devs.push_back(devs[i]); + } + ggml_backend_meta_device_context ctx(simple_devs, get_split_state, get_split_state_ud); + + { + auto it = meta_devs.find(ctx); + if (it != meta_devs.end()) { + return &it->second; + } + } + ctxs.push_back(std::make_unique(ctx)); + + struct ggml_backend_device meta_dev = { + /*iface =*/ ggml_backend_meta_device_iface, + /*reg =*/ nullptr, + /*ctx =*/ ctxs.back().get(), + }; + + auto result = meta_devs.emplace(*ctxs.back(), meta_dev); + return &result.first->second; +} + +// +// meta backend buffer type +// + +struct ggml_backend_meta_buffer_type_context { + std::vector simple_bufts; + + std::string name; + + ggml_backend_meta_buffer_type_context(std::vector simple_bufts) : simple_bufts(std::move(simple_bufts)) { + name = "Meta("; + for (size_t i = 0; i < simple_bufts.size(); i++) { + if (i > 0) { + name += ","; + } + name += ggml_backend_buft_name(simple_bufts[i]); + } + name += ")"; + } + + bool operator<(const ggml_backend_meta_buffer_type_context & other) const { + return simple_bufts < other.simple_bufts; + } +}; + +static size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) { + GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context; + return meta_buft_ctx->simple_bufts.size(); +} + +static const char * ggml_backend_meta_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(ggml_backend_buft_is_meta(buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) buft->context; + return meta_buft_ctx->name.c_str(); +} + +static ggml_backend_buffer_type_t ggml_backend_meta_buft_simple_buft(ggml_backend_buffer_type_t meta_buft, size_t index) { + GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context; + GGML_ASSERT(index < meta_buft_ctx->simple_bufts.size()); + return meta_buft_ctx->simple_bufts[index]; +} + +static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); + +static size_t ggml_backend_meta_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_alignment = 1; + for (size_t i = 0; i < n_simple_bufts; i++) { + const size_t alignment = ggml_backend_buft_get_alignment(ggml_backend_meta_buft_simple_buft(buft, i)); + max_alignment = std::max(max_alignment, alignment); + GGML_ASSERT(max_alignment % alignment == 0); + } + return max_alignment; +} + +static size_t ggml_backend_meta_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_size = SIZE_MAX; + for (size_t i = 0; i < n_simple_bufts; i++) { + max_size = std::min(max_size, ggml_backend_buft_get_max_size(ggml_backend_meta_buft_simple_buft(buft, i))); + } + return max_size; +} + +static size_t ggml_backend_meta_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_alloc_size = 0; + for (size_t i = 0; i < n_simple_bufts; i++) { + const size_t alloc_size = ggml_backend_buft_get_alloc_size(ggml_backend_meta_buft_simple_buft(buft, i), tensor); + max_alloc_size = std::max(max_alloc_size, alloc_size); + } + return max_alloc_size; +} + +static bool ggml_backend_meta_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + for (size_t i = 0; i < n_simple_bufts; i++) { + if (!ggml_backend_buft_is_host(ggml_backend_meta_buft_simple_buft(buft, i))) { + return false; + } + } + return true; +} + +static const struct ggml_backend_buffer_type_i ggml_backend_meta_buffer_type_iface = { + /* .get_name = */ ggml_backend_meta_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_meta_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_meta_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_meta_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_meta_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_meta_buffer_type_is_host, +}; + +bool ggml_backend_buft_is_meta(ggml_backend_buffer_type_t buft) { + return buft != nullptr && buft->iface.get_name == ggml_backend_meta_buffer_type_iface.get_name; +} + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev) { + static std::map meta_bufts; + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + { + auto it = meta_bufts.find(dev); + if (it != meta_bufts.end()) { + return &it->second; + } + } + + const size_t n_devs = ggml_backend_meta_dev_n_devs(dev); + std::vector simple_bufts; + simple_bufts.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + simple_bufts.push_back(ggml_backend_dev_buffer_type(ggml_backend_meta_dev_simple_dev(dev, i))); + } + ggml_backend_meta_buffer_type_context * buft_ctx = new ggml_backend_meta_buffer_type_context(simple_bufts); + + struct ggml_backend_buffer_type meta_buft = { + /*iface =*/ ggml_backend_meta_buffer_type_iface, + /*device =*/ dev, + /*ctx =*/ buft_ctx, + }; + auto result = meta_bufts.emplace(dev, meta_buft); + return &result.first->second; +} + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + + ggml_backend_buffer_type_t host_buft = nullptr; + for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) { + ggml_backend_buffer_type_t simple_host_buft = ggml_backend_dev_host_buffer_type(simple_dev); + if (simple_host_buft == nullptr) { + return nullptr; + } + if (host_buft == nullptr) { + host_buft = simple_host_buft; + } else if (host_buft != simple_host_buft) { + // if different simple devices have different host buffer types, + // we cannot provide a single host buffer type for the meta device + return nullptr; + } + } + return host_buft; +} + +// +// meta backend buffer +// + +struct ggml_backend_meta_buffer_context { + static constexpr size_t nbtc = GGML_TENSOR_SIZE - sizeof(ggml_tensor::padding); + + std::map, std::pair> split_state_cache; + std::map< const ggml_tensor *, std::vector> simple_tensors; + + struct buffer_config { + ggml_context * ctx; + ggml_backend_buffer_t buf; + + buffer_config(ggml_context * ctx, ggml_backend_buffer_t buf) : ctx(ctx), buf(buf) {} + }; + std::vector buf_configs; + + int debug; + + ggml_backend_meta_buffer_context() { + const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG"); + debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0; + } +}; + +static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + for (auto & [ctx, buf] : buf_ctx->buf_configs) { + ggml_backend_buffer_free(buf); + ggml_free(ctx); + } + delete buf_ctx; +} + +static size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf) { + GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; + return buf_ctx->buf_configs.size(); +} + +static ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index) { + GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; + GGML_ASSERT(index < buf_ctx->buf_configs.size()); + return buf_ctx->buf_configs[index].buf; +} + +static struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index) { + GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + GGML_ASSERT(index < buf_ctx->buf_configs.size()); + + auto it = buf_ctx->simple_tensors.find(tensor); + if (it == buf_ctx->simple_tensors.end()) { + return nullptr; + } + return it->second[index]; +} + +static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + + auto split_states_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) -> bool { + if (a.axis != b.axis) { + return false; + } + for (size_t j = 0; j < n_bufs; j++) { + int64_t sum_a = 0; + for (size_t s = 0; s < a.n_segments; s++) { + sum_a += a.ne[s*n_bufs + j]; + } + int64_t sum_b = 0; + for (size_t s = 0; s < b.n_segments; s++) { + sum_b += b.ne[s*n_bufs + j]; + } + if (sum_a != sum_b) { + return false; + } + } + return true; + }; + + auto handle_generic = [&](const std::vector & src_ss, bool scalar_only) -> ggml_backend_meta_split_state { + ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}; + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { + continue; + } + if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { + ret = src_ss[i]; + } else if (!split_states_equal(src_ss[i], ret)) { + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + break; + } + } + if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + if (scalar_only && ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + return ret; + }; + + // Some ops process data on a per-row bases: + auto handle_per_row = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_0); + return src_ss[0]; + }; + + // Some ops broadcast the src1 data across src0: + auto handle_bin_bcast = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS && + tensor->src[1]->ne[src_ss[0].axis] == 1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_ss[0]; + } + if (src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[0].axis == src_ss[1].axis || + (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL)))) { + return src_ss[0]; // GGML_OP_ADD_ID + } + GGML_ASSERT(tensor->src[2] == nullptr || src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + + auto handle_concat = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + const ggml_backend_meta_split_axis concat_axis = ggml_backend_meta_split_axis(ggml_get_op_params_i32(tensor, 0)); + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis >= 0 && src_ss[1].axis < GGML_MAX_DIMS) { + GGML_ASSERT(concat_axis != src_ss[1].axis); + return src_ss[1]; + } + if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + GGML_ASSERT(concat_axis != src_ss[0].axis); + return src_ss[0]; + } + if (src_ss[0].axis == src_ss[1].axis && src_ss[0].axis != concat_axis) { + return src_ss[0]; + } + return handle_generic(src_ss, /*scalar_only =*/ true); + }; + + auto handle_mul_mat = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + ggml_backend_meta_split_state ret = src_ss[0]; + ret.axis = GGML_BACKEND_SPLIT_AXIS_0; + ret.n_segments = 1; + return ret; + } + if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + ggml_backend_meta_split_state ret = src_ss[1]; + ret.n_segments = 1; + return ret; + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) { + GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1])); + return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, 1}; + } + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + }; + + auto handle_cpy = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + int64_t ne_split_src = tensor->src[0]->ne[0]; + for (int dim = 1; dim <= src_ss[0].axis; dim++) { + ne_split_src *= tensor->src[0]->ne[dim]; + } + int64_t ne_split_dst = 1; + for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { + ne_split_dst *= tensor->ne[dim]; + if (ne_split_dst == ne_split_src) { + return {ggml_backend_meta_split_axis(dim), {0}, 1}; + } + } + } + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + + auto handle_reshape = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + switch (src_ss[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: { + GGML_ASSERT(!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0])); + if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1) { + return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, 1}; + } + std::vector base_ne_in; + base_ne_in.reserve(GGML_MAX_DIMS - src_ss[0].axis); + { + base_ne_in.push_back(1); + int dim = 0; + for (; dim <= src_ss[0].axis; dim++) { + base_ne_in[0] *= tensor->src[0]->ne[dim]; + } + for (; dim <= GGML_MAX_DIMS; dim++) { + base_ne_in.push_back(base_ne_in.back() * tensor->src[0]->ne[dim]); + } + } + int64_t base_ne_out = 1; + for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { + const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim]; + for (const int64_t & bni : base_ne_in) { + if (bni == base_ne_out_next) { + return {ggml_backend_meta_split_axis(dim), {0}, 1}; + } + } + if (base_ne_out_next > base_ne_in[0]) { + GGML_ASSERT(dim + 1 < GGML_MAX_DIMS); + return {ggml_backend_meta_split_axis(dim + 1), {0}, 1}; + } + base_ne_out = base_ne_out_next; + } + GGML_ABORT("shape mismatch for %s", ggml_op_name(tensor->op)); + } + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + return src_ss[0]; + } + default: { + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + } + }; + + auto handle_view = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) { + return handle_reshape(src_ss); + } + const int axis = src_ss[0].axis; + { + bool all_strides_the_same = true; + for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { + if (tensor->ne[dim] == 1 && tensor->src[0]->ne[dim] == 1) { + continue; + } + if (tensor->nb[dim] != tensor->src[0]->nb[dim]) { + all_strides_the_same = false; + break; + } + } + if (all_strides_the_same) { + return src_ss[0]; + } + } + if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) { + for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) { + if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) { + return {ggml_backend_meta_split_axis(dim), {0}, 1}; + } + } + GGML_ABORT("fatal error"); + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { + return src_ss[0]; + } + GGML_ABORT("view of permuted tensor not implemented"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + }; + + auto handle_permute = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + switch (src_ss[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: { + return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, 1}; + } + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + return src_ss[0]; + } + default: { + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + } + }; + + auto handle_transpose = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + switch (src_ss[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: { + return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, 1}; + } + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + return src_ss[0]; + } + default: { + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + } + }; + + auto handle_get_rows = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_ss[0]; + } + return handle_generic(src_ss, /*scalar_only =*/ true); + }; + + auto handle_set_rows = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + GGML_ASSERT(split_states_equal(src_ss[0], src_ss[2])); + return src_ss[0]; + }; + + auto handle_rope = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + return src_ss[0]; + }; + + auto handle_pad = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 0] == 0); + GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 1] == 0); + } + return src_ss[0]; + }; + + auto handle_flash_attn_ext = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT( src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT( src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT( src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT(tensor->src[4] == nullptr || src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + GGML_ASSERT(tensor->src[4] == nullptr || src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_0); + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + }; + + auto handle_ssm_conv = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == src_ss[1].axis) { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) { + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) { + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + } + } + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + + auto handle_gated_delta_net = [&](const std::vector & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && + src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && + src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_ss[0]; + } + GGML_ASSERT(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2); + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + }; + + auto calculate_split_state = [&]() -> ggml_backend_meta_split_state { + if (ggml_nelements(tensor) == 0) { + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } + if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) { + ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer)); + const ggml_backend_meta_device_context * dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + ggml_backend_meta_split_state ret = dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud); + if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) { + const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1; + int64_t ne_sum = 0; + for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) { + GGML_ASSERT(ret.ne[sj] % granularity == 0); + ne_sum += ret.ne[sj]; + } + GGML_ASSERT(ne_sum == tensor->ne[ret.axis]); + } + return ret; + } + + std::vector src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}); + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { + src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + continue; + } + src_ss[i] = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true); + GGML_ASSERT(src_ss[i].axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + } + + ggml_backend_meta_split_state split_state; + switch (tensor->op) { + case GGML_OP_NONE: { + split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; + } break; + case GGML_OP_DUP: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_ADD: + case GGML_OP_ADD_ID: { + split_state = handle_bin_bcast(src_ss); + } break; + case GGML_OP_ADD1: + case GGML_OP_ACC: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: { + split_state = handle_bin_bcast(src_ss); + } break; + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_LOG: + case GGML_OP_SIN: + case GGML_OP_COS: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_SUM: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SUM_ROWS: + case GGML_OP_CUMSUM: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_CONCAT: { + split_state = handle_concat(src_ss); + } break; + case GGML_OP_SILU_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_GROUP_NORM: + case GGML_OP_L2_NORM: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: { + split_state = handle_mul_mat(src_ss); + } break; + case GGML_OP_OUT_PROD: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SCALE: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_SET: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_CPY: { + split_state = handle_cpy(src_ss); + } break; + case GGML_OP_CONT: + case GGML_OP_RESHAPE: { + split_state = handle_reshape(src_ss); + } break; + case GGML_OP_VIEW: { + split_state = handle_view(src_ss); + } break; + case GGML_OP_PERMUTE: { + split_state = handle_permute(src_ss); + } break; + case GGML_OP_TRANSPOSE: { + split_state = handle_transpose(src_ss); + } break; + case GGML_OP_GET_ROWS: { + split_state = handle_get_rows(src_ss); + } break; + case GGML_OP_GET_ROWS_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SET_ROWS: { + split_state = handle_set_rows(src_ss); + } break; + case GGML_OP_DIAG: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_DIAG_MASK_ZERO: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_ROPE: { + split_state = handle_rope(src_ss); + } break; + case GGML_OP_ROPE_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_CLAMP: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_BACK: + case GGML_OP_IM2COL_3D: + case GGML_OP_CONV_2D: + case GGML_OP_CONV_3D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_CONV_TRANSPOSE_2D: + case GGML_OP_POOL_1D: + case GGML_OP_POOL_2D: + case GGML_OP_POOL_2D_BACK: + case GGML_OP_UPSCALE: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_PAD: { + split_state = handle_pad(src_ss); + } break; + case GGML_OP_PAD_REFLECT_1D: + case GGML_OP_ROLL: + case GGML_OP_ARANGE: + case GGML_OP_TIMESTEP_EMBEDDING: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_LEAKY_RELU: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_TRI: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_FILL: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_FLASH_ATTN_EXT: { + split_state = handle_flash_attn_ext(src_ss); + } break; + case GGML_OP_FLASH_ATTN_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SSM_CONV: { + split_state = handle_ssm_conv(src_ss); + } break; + case GGML_OP_SSM_SCAN: + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: + case GGML_OP_GET_REL_POS: + case GGML_OP_ADD_REL_POS: + case GGML_OP_RWKV_WKV6: + case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_RWKV_WKV7: + case GGML_OP_SOLVE_TRI: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_GATED_DELTA_NET: { + split_state = handle_gated_delta_net(src_ss); + } break; + case GGML_OP_UNARY: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_MAP_CUSTOM1: + case GGML_OP_MAP_CUSTOM2: + case GGML_OP_MAP_CUSTOM3: + case GGML_OP_CUSTOM: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + case GGML_OP_GLU: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + default: { + GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op)); + split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + } break; + } + if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { + bool first_src_split_by_axis = true; + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); + + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || src_ss[i].axis < 0 || src_ss[i].axis >= GGML_MAX_DIMS) { + continue; + } + if (first_src_split_by_axis) { + for (size_t j = 0; j < n_bufs; j++) { + // Take over ratio from src: + for (size_t s = 0; s < src_ss[i].n_segments; s++) { + split_state.ne[s*n_bufs + j] = 0; + } + for (size_t s = 0; s < src_ss[i].n_segments; s++) { + split_state.ne[j] += src_ss[i].ne[s*n_bufs + j]; + } + split_state.ne[j] *= tensor->ne[split_state.axis]; + if (split_state.ne[j] != 0 || tensor->src[i]->ne[src_ss[i].axis] != 0) { + GGML_ASSERT(split_state.ne[j] % tensor->src[i]->ne[src_ss[i].axis] == 0); + split_state.ne[j] /= tensor->src[i]->ne[src_ss[i].axis]; + } + } + } else { + for (size_t j = 0; j < n_bufs; j++) { + int64_t sum = 0; + for (size_t s = 0; s < src_ss[i].n_segments; s++) { + sum += src_ss[i].ne[s*n_bufs + j]; + } + // Assert that ratio is consistent: + GGML_ASSERT(split_state.ne[j] * tensor->src[i]->ne[src_ss[i].axis] + == sum * tensor->ne[split_state.axis]); + } + } + first_src_split_by_axis = false; + } + GGML_ASSERT(!first_src_split_by_axis); + } + return split_state; + }; + + const std::pair key = std::make_pair(tensor, assume_sync); + auto it = buf_ctx->split_state_cache.find(key); + if (it != buf_ctx->split_state_cache.end() && memcmp(it->second.second, (const char *) tensor, sizeof(it->second.second)) != 0) { + buf_ctx->split_state_cache.clear(); + it = buf_ctx->split_state_cache.end(); + } + + if (it == buf_ctx->split_state_cache.end()) { + buf_ctx->split_state_cache[key].first = calculate_split_state(); + memcpy(buf_ctx->split_state_cache[key].second, tensor, sizeof(buf_ctx->split_state_cache[key].second)); + if (buf_ctx->debug > 0) { + std::string srcs_info; + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr) { + continue; + } + if (!srcs_info.empty()) { + srcs_info += ", "; + } + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true); + const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis); + std::string ne_info; + for (size_t j = 0; j < n_bufs; j++) { + if (!ne_info.empty()) { + ne_info += ", "; + } + ne_info += std::to_string(split_state.ne[j]); + } + srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]"; + } + std::string ne_info; + for (size_t j = 0; j < n_bufs; j++) { + if (!ne_info.empty()) { + ne_info += ", "; + } + ne_info += std::to_string(buf_ctx->split_state_cache[key].first.ne[j]); + } + GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op), + ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].first.axis), ne_info.c_str()); + } + } + + ggml_backend_meta_split_state ret = buf_ctx->split_state_cache[key].first; + GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_NONE); +#ifndef NDEBUG + if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { + int64_t ne_ret = 0; + for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) { + ne_ret += ret.ne[sj]; + } + assert(ne_ret == tensor->ne[int(ret.axis)]); + } +#endif // NDEBUG + return ret; +} + +static void * ggml_backend_meta_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_UNUSED(buffer); + return (void *) 0x1000000000000000; // FIXME +} + +static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(buffer); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ true); + GGML_ASSERT(ggml_nelements(tensor) == 0 || split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + GGML_ASSERT(split_state.n_segments <= 16); + + int split_dim = split_state.axis; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + for (size_t k = 0; k < GGML_MAX_DIMS; k++) { + ne[k] = tensor->ne[k]; + nb[k] = tensor->nb[k]; + } + + std::vector simple_tensors; + simple_tensors.reserve(n_simple_bufs); + for (size_t j = 0; j < n_simple_bufs; j++) { + ggml_context * simple_ctx = buf_ctx->buf_configs[j].ctx; + ggml_backend_buffer_t simple_buf = buf_ctx->buf_configs[j].buf; + + if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) { + // TODO: the following assert fails for llama-parallel even though the results are correct: + // GGML_ASSERT(ggml_is_contiguously_allocated(tensor)); + ne[split_dim] = 0; + for (size_t s = 0; s < split_state.n_segments; s++) { + ne[split_dim] += split_state.ne[s*n_simple_bufs + j]; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (tensor->nb[i] > tensor->nb[split_dim]) { + nb[i] = tensor->nb[i] * ne[split_dim]/tensor->ne[split_dim]; + } + } + } + + ggml_tensor * t_ij = ggml_new_tensor(simple_ctx, tensor->type, GGML_MAX_DIMS, ne); + t_ij->op = tensor->op; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + t_ij->nb[i] = nb[i]; + } + t_ij->flags = tensor->flags; + memcpy(t_ij->op_params, tensor->op_params, sizeof(tensor->op_params)); + ggml_set_name(t_ij, tensor->name); + t_ij->buffer = simple_buf; + t_ij->view_src = tensor->view_src; + t_ij->view_offs = tensor->view_offs; + if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta(t_ij->view_src->buffer)) { + t_ij->view_src = ggml_backend_meta_buffer_simple_tensor(tensor->view_src, j); + if (t_ij->view_offs > 0 && split_dim >= 0 && split_dim < GGML_MAX_DIMS) { + GGML_ASSERT(ne[split_dim] != 0 && tensor->ne[split_dim] != 0); + const int split_dim_view_src = ggml_backend_meta_get_split_state(tensor->view_src, /*assume_sync =*/ true).axis; + GGML_ASSERT(split_dim_view_src >= 0 && split_dim_view_src < GGML_MAX_DIMS); + + // The offset can be internal to the data split, in those cases the view offset should not be scaled. + // If however, the offset is larger than the data split then it needs to be scaled proportionally. + bool split_internal_offset = t_ij->view_offs <= tensor->view_src->nb[split_dim_view_src]; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + const size_t dim_size = tensor->ne[i] * tensor->nb[i]; + if (tensor->view_offs <= dim_size && dim_size < tensor->nb[split_dim]) { + split_internal_offset = true; + break; + } + } + if (!split_internal_offset) { + t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim]; + } + } + } + if (t_ij->view_src != nullptr) { + t_ij->data = (char *) t_ij->view_src->data + t_ij->view_offs; + } else if (simple_buf != nullptr) { + t_ij->data = (char *) ggml_backend_buffer_get_base(simple_buf) + + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(buffer)); + } + t_ij->extra = tensor->extra; + for (int i = 0; i < GGML_MAX_SRC; i++) { + t_ij->src[i] = tensor->src[i]; + if (tensor->src[i] == tensor) { + t_ij->src[i] = t_ij; + } else if (t_ij->src[i] != nullptr && ggml_backend_buffer_is_meta(t_ij->src[i]->buffer)) { + t_ij->src[i] = ggml_backend_meta_buffer_simple_tensor(tensor->src[i], j); + } + } + + simple_tensors.push_back(t_ij); + } + buf_ctx->simple_tensors[tensor] = simple_tensors; + + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + + if (split_state.n_segments != 1) { + GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + GGML_ASSERT(tensor->ne[3] == 1); + size_t offset_data = 0; + std::vector simple_offsets(n_bufs, 0); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { + GGML_ASSERT(tensor->ne[2] == 1); + const int64_t blck_size = ggml_blck_size(tensor->type); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes, + tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*tensor->ne[1] == size); + return; + } + GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes, + tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*tensor->ne[2] == size); + return; + } + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + const size_t simple_offset = i_start * chunk_size_j; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + ggml_backend_tensor_set(simple_tensor, data, offset, size); + } + } break; + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + GGML_ASSERT(tensor->type == GGML_TYPE_F32); + const int64_t ne = ggml_nelements(tensor); + std::vector tmp; + tmp.reserve(ne); + for (int64_t i = 0; i < ne; i++) { + tmp.push_back(((const float *) data)[i] / n_bufs); + } + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + ggml_backend_tensor_set(simple_tensor, tmp.data(), offset, size); + } + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + GGML_ASSERT(split_state.n_segments == 1); + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_bufs; j++){ + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + const size_t simple_offset = i_start * chunk_size_j; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + // TODO other simple backend may be better + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0); + ggml_backend_tensor_get(simple_tensor, data, offset, size); + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer); + for (size_t i = 0; i < n_buffers; i++) { + ggml_backend_buffer_clear(ggml_backend_meta_buffer_simple_buffer(buffer, i), value); + } +} + +static void ggml_backend_meta_buffer_reset(ggml_backend_buffer_t buffer) { + const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer); + for (size_t i = 0; i < n_buffers; i++) { + ggml_backend_buffer_reset(ggml_backend_meta_buffer_simple_buffer(buffer, i)); + } +} + +static const ggml_backend_buffer_i ggml_backend_meta_buffer_iface = { + /* .free_buffer = */ ggml_backend_meta_buffer_free_buffer, + /* .get_base = */ ggml_backend_meta_buffer_get_base, + /* .init_tensor = */ ggml_backend_meta_buffer_init_tensor, + /* .memset_tensor = */ nullptr, // TODO implement + /* .set_tensor = */ ggml_backend_meta_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_meta_buffer_get_tensor, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, + /* .cpy_tensor = */ nullptr, + /* .clear = */ ggml_backend_meta_buffer_clear, + /* .reset = */ ggml_backend_meta_buffer_reset, +}; + +bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf) { + return buf != nullptr && buf->iface.free_buffer == ggml_backend_meta_buffer_iface.free_buffer; +} + +static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + + ggml_init_params params = { + /*.mem_size =*/ 1024*1024*1024, // FIXME + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context(); + size_t max_size = 0; + buf_ctx->buf_configs.reserve(n_simple_bufts); + for (size_t i = 0; i < n_simple_bufts; i++) { + ggml_backend_buffer_t simple_buf = ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size); + max_size = std::max(max_size, ggml_backend_buffer_get_size(simple_buf)); + buf_ctx->buf_configs.emplace_back(ggml_init(params), simple_buf); + } + + return ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, buf_ctx, max_size); +} + +struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + + ggml_init_params params = { + /*.mem_size =*/ 1024*1024*1024, // FIXME + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(); + meta_buf_ctx->buf_configs.reserve(n_simple_bufts); + for (size_t i = 0; i < n_simple_bufts; i++) { + meta_buf_ctx->buf_configs.emplace_back(ggml_init(params), nullptr); + } + + ggml_backend_buffer_t meta_buf = ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_buf_ctx, 0); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + t->buffer = meta_buf; + ggml_backend_meta_buffer_init_tensor(meta_buf, t); + t->data = (void *) 0x2000000000000000; // FIXME + } + for (size_t i = 0; i < n_simple_bufts; i++) { + meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft( + meta_buf_ctx->buf_configs[i].ctx, ggml_backend_meta_buft_simple_buft(buft, i)); + meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->buf_configs[i].buf)); + } + return meta_buf; +} + +// +// meta backend +// + +static ggml_guid_t ggml_backend_meta_guid() { + static ggml_guid guid = {0xf1, 0x0e, 0x34, 0xcf, 0x9c, 0x6f, 0x43, 0xcb, 0x96, 0x92, 0xbe, 0x8e, 0xbb, 0x71, 0x3f, 0xda}; + return &guid; +} + +struct ggml_backend_meta_context { + struct cgraph_config { + ggml_cgraph * cgraph_main = nullptr; + int offset = 0; // Node offset vs. original graph + + std::vector cgraphs_aux; + }; + struct backend_config { + ggml_backend_t backend; + + std::vector cgraphs; + std::vector nodes; + ggml_backend_buffer_ptr buf; + + backend_config(ggml_backend_t backend) : backend(backend) {} + }; + std::string name; + std::vector backend_configs; + ggml_context_ptr ctx; + std::vector cgraphs_aux; + std::vector nodes_aux; + int max_nnodes = 0; + size_t max_tmp_size = 0; + size_t max_subgraphs = 0; + + ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) { + const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev); + name = "Meta("; + backend_configs.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i); + if (i > 0) { + name += ","; + } + name += ggml_backend_dev_name(simple_dev); + backend_configs.emplace_back(ggml_backend_dev_init(simple_dev, params)); + } + name += ")"; + } + + ~ggml_backend_meta_context() { + for (auto & bc : backend_configs) { + ggml_backend_free(bc.backend); + } + } + + size_t n_reduce_steps() const { + return std::ceil(std::log2(backend_configs.size())); + } +}; + +static const char * ggml_backend_meta_get_name(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_meta(backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) backend->context; + return backend_ctx->name.c_str(); +} + +static void ggml_backend_meta_free(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_meta(backend)); + ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; + delete backend_ctx; + delete backend; +} + +static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + GGML_ASSERT(offset == 0); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + GGML_ASSERT(split_state.n_segments == 1); + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_backends; j++){ + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j, + i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + for (size_t j = 0; j < n_backends; j++) { + ggml_backend_tensor_set_async( + ggml_backend_meta_simple_backend(backend, j), ggml_backend_meta_buffer_simple_tensor(tensor, j), data, offset, size); + } + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + GGML_ASSERT(offset == 0); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + GGML_ASSERT(split_state.n_segments == 1); + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_backends; j++){ + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j, + i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + // TODO other simple backend may be better + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, 0); + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0); + ggml_backend_tensor_get_async(simple_backend, simple_tensor, data, offset, size); + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_synchronize(ggml_backend_t backend) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + for (size_t i = 0; i < n_backends; i++) { + ggml_backend_synchronize(ggml_backend_meta_simple_backend(backend, i)); + } +} + +static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(cgraph->grads == nullptr); + const size_t n_backends = ggml_backend_meta_n_backends(backend); + ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; + + bool max_nnodes_raised = false; + if (cgraph->n_nodes > backend_ctx->max_nnodes) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.nodes.resize(cgraph->n_nodes); + bcj.cgraphs.resize(cgraph->n_nodes); + } + backend_ctx->max_nnodes = cgraph->n_nodes; + max_nnodes_raised = true; + } + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes. + // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash. + bcj.nodes[i] = node; + continue; + } + bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j); + GGML_ASSERT(bcj.nodes[i]); + } + } + + size_t n_subgraphs = 0; + size_t max_tmp_size = 0; + { + // For MoE models it may make sense to delay the AllReduce in order to reduce I/O: + auto get_i_delayed = [&](const int i) -> int { + int id = i; // i_delayed + int idr = i; // i_delayed return, last safe return value + + ggml_tensor * node = cgraph->nodes[id]; + int32_t n_used = ggml_node_get_use_count(cgraph, id); + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_ADD_ID && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL && + ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } + } + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_MUL && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } + } + + if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) { + return idr; + } + for (int32_t k = 0; k < n_used; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] || + next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] || + ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + for (int32_t k = 0; k < n_used - 2; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + idr = id; + return idr; + }; + + int i_start = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + continue; + } + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { + max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); + } + const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; + if (!new_subgraph) { + continue; + } + + i = get_i_delayed(i); + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.cgraphs[n_subgraphs].offset = i_start; + } + n_subgraphs++; + i_start = i + 1; + } + GGML_ASSERT(i_start == cgraph->n_nodes); + } + + if (max_tmp_size > backend_ctx->max_tmp_size) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.buf.reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); + } + backend_ctx->max_tmp_size = max_tmp_size; + } + + + if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { + backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); + const size_t n_reduce_steps = backend_ctx->n_reduce_steps(); + const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step + const size_t n_cgraphs_per_device = n_reduce_steps; // 1 ADD graph per step + const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); + const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); + const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); + ggml_init_params params = { + /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + backend_ctx->ctx.reset(ggml_init(params)); + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i = 0; i < n_subgraphs; i++) { + bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false); + } + } + backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) { + backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads); + } + backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) { + backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1); + } + } + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) { + ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main; + const size_t i_node_start = bcj.cgraphs[i_graph].offset; + const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes; + cgraph_ij->n_nodes = i_node_stop - i_node_start; + ggml_hash_set_reset(&cgraph_ij->visited_hash_set); + for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) { + ggml_tensor * node_ij = bcj.nodes[i_node]; + cgraph_ij->nodes[i_node - i_node_start] = node_ij; + const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]); + const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij); + cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig]; + } + } + } + + size_t iga = 0; // i graph aux + size_t ina = 0; // i node aux + + // FIXME usage_counts + auto get_cgraph_aux = [&]() -> ggml_cgraph * { + ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++]; + return ret; + }; + auto get_node_aux = [&](ggml_tensor * t) -> ggml_tensor * { + ggml_tensor * ret = backend_ctx->nodes_aux[ina++]; + memset(ret, 0, sizeof(ggml_tensor)); + ret->op = GGML_OP_NONE; + ret->type = t->type; + for (size_t k = 0; k < GGML_MAX_DIMS; k++) { + ret->ne[k] = t->ne[k]; + ret->nb[k] = t->nb[k]; + } + return ret; + }; + + // Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable: + auto allreduce_fallback = [&](size_t i) -> ggml_status { + std::vector step_cgraphs(n_backends, nullptr); + + for (size_t offset_j = 1; offset_j < n_backends; offset_j *= 2) { + std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr); + + for (size_t j = 0; j < n_backends; j++) { + const size_t j_other = j ^ offset_j; + if (j_other > j) { + continue; + } + + auto & bcj1 = backend_ctx->backend_configs[j]; + auto & bcj2 = backend_ctx->backend_configs[j_other]; + + ggml_tensor * node1 = bcj1.cgraphs[i].cgraph_main->nodes[bcj1.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_tensor * node2 = bcj2.cgraphs[i].cgraph_main->nodes[bcj2.cgraphs[i].cgraph_main->n_nodes - 1]; + GGML_ASSERT(ggml_is_contiguous(node1)); + GGML_ASSERT(ggml_is_contiguous(node2)); + + // Tmp tensors to receive P2P copies + ggml_tensor * node_tmp_1 = get_node_aux(node1); + node_tmp_1->buffer = bcj1.buf.get(); + node_tmp_1->data = ggml_backend_buffer_get_base(bcj1.buf.get()); + + ggml_tensor * node_tmp_2 = get_node_aux(node2); + node_tmp_2->buffer = bcj2.buf.get(); + node_tmp_2->data = ggml_backend_buffer_get_base(bcj2.buf.get()); + + // 2 P2P copies: exchange full buffers + ggml_backend_tensor_copy_async(bcj1.backend, bcj2.backend, node1, node_tmp_2); + ggml_backend_tensor_copy_async(bcj2.backend, bcj1.backend, node2, node_tmp_1); + + // Local ADD: node1 += tmp1 (in-place via view) + ggml_tensor * node_red_1 = get_node_aux(node1); + node_red_1->view_src = node1->view_src == nullptr ? node1 : node1->view_src; + node_red_1->view_offs = node1->view_offs; + node_red_1->op = GGML_OP_ADD; + node_red_1->src[0] = node1; + node_red_1->src[1] = node_tmp_1; + node_red_1->flags |= GGML_TENSOR_FLAG_COMPUTE; + ggml_backend_view_init(node_red_1); + + // Local ADD: node2 += tmp2 (in-place via view) + ggml_tensor * node_red_2 = get_node_aux(node2); + node_red_2->view_src = node2->view_src == nullptr ? node2 : node2->view_src; + node_red_2->view_offs = node2->view_offs; + node_red_2->op = GGML_OP_ADD; + node_red_2->src[0] = node2; + node_red_2->src[1] = node_tmp_2; + node_red_2->flags |= GGML_TENSOR_FLAG_COMPUTE; + ggml_backend_view_init(node_red_2); + + // Build 1-node cgraphs for the ADD ops + ggml_cgraph * cgraph_aux_1 = get_cgraph_aux(); + cgraph_aux_1->nodes[0] = node_red_1; + cgraph_aux_1->n_nodes = 1; + step_cgraphs[j] = cgraph_aux_1; + + ggml_cgraph * cgraph_aux_2 = get_cgraph_aux(); + cgraph_aux_2->nodes[0] = node_red_2; + cgraph_aux_2->n_nodes = 1; + step_cgraphs[j_other] = cgraph_aux_2; + } + + // Execute local ADDs for this step + for (size_t j = 0; j < n_backends; j++) { + if (step_cgraphs[j] == nullptr) { + continue; + } + auto & bcj = backend_ctx->backend_configs[j]; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + } + return GGML_STATUS_SUCCESS; + }; + + + for (size_t i = 0; i < n_subgraphs; i++) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, bcj.cgraphs[i].cgraph_main); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + + if (n_backends > 1 && i < n_subgraphs - 1) { + bool backend_allreduce_success = false; + ggml_backend_allreduce_tensor_t allreduce_tensor = (ggml_backend_allreduce_tensor_t) ggml_backend_reg_get_proc_address( + ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_ctx->backend_configs[0].backend)), "ggml_backend_allreduce_tensor"); + if (allreduce_tensor) { + std::vector backends; + backends.reserve(n_backends); + std::vector nodes; + nodes.reserve(n_backends); + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + backends.push_back(bcj.backend); + ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main; + nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]); + } + backend_allreduce_success = allreduce_tensor(backends.data(), nodes.data(), n_backends); + } + + if (!backend_allreduce_success) { + const ggml_status status = allreduce_fallback(i); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + } + } + return GGML_STATUS_SUCCESS; +} + +static const ggml_backend_i ggml_backend_meta_i = { + /* .get_name = */ ggml_backend_meta_get_name, + /* .free = */ ggml_backend_meta_free, + /* .set_tensor_async = */ ggml_backend_meta_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_meta_get_tensor_async, + /* .get_tensor_2d_async = */ nullptr, + /* .set_tensor_2d_async = */ nullptr, + /* .cpy_tensor_async = */ nullptr, + /* .synchronize = */ ggml_backend_meta_synchronize, + /* .graph_plan_create = */ nullptr, + /* .graph_plan_free = */ nullptr, + /* .graph_plan_update = */ nullptr, + /* .graph_plan_compute = */ nullptr, + /* .graph_compute = */ ggml_backend_meta_graph_compute, + /* .event_record = */ nullptr, + /* .event_wait = */ nullptr, + /* .graph_optimize = */ nullptr, +}; + +bool ggml_backend_is_meta(ggml_backend_t backend) { + return backend != nullptr && backend->iface.get_name == ggml_backend_meta_i.get_name; +} + +static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params) { + ggml_backend_meta_context * backend_ctx = new ggml_backend_meta_context(dev, params); + + ggml_backend_t backend = new struct ggml_backend; + backend->guid = ggml_backend_meta_guid(); + backend->iface = ggml_backend_meta_i; + backend->device = dev; + backend->context = backend_ctx; + return backend; +} + +size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend) { + GGML_ASSERT(ggml_backend_is_meta(meta_backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; + return backend_ctx->backend_configs.size(); +} + +ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index) { + GGML_ASSERT(ggml_backend_is_meta(meta_backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; + return backend_ctx->backend_configs[index].backend; +} + diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 22c656996cc..1a555bf2a4d 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -123,7 +123,7 @@ size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { GGML_ASSERT(buffer); // get_base is optional if the buffer is zero-sized - if (buffer->size == 0) { + if (!ggml_backend_buffer_is_meta(buffer) && buffer->size == 0) { return NULL; } @@ -279,15 +279,57 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten } } +void ggml_backend_tensor_set_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(backend); + GGML_ASSERT(tensor); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + if (n_copies <= 1 || backend->iface.set_tensor_2d_async == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_set_async(backend, tensor, (const char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + backend->iface.set_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + +void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(backend); + GGML_ASSERT(tensor); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + if (n_copies <= 1 || backend->iface.set_tensor_2d_async == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_get_async(backend, tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + backend->iface.get_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); if (size == 0) { return; } - GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); @@ -297,18 +339,62 @@ void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, siz void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); if (size == 0) { return; } - GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); buf->iface.get_tensor(buf, tensor, data, offset, size); } +void ggml_backend_tensor_set_2d(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(tensor); + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + + if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_set(tensor, (const char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + + buf->iface.set_tensor_2d(buf, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + +void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(tensor); + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + + if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_get(tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + + buf->iface.get_tensor_2d(buf, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; @@ -388,7 +474,7 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) { // backend copy -void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) { +void ggml_backend_tensor_copy(const struct ggml_tensor * src, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); if (src == dst) { @@ -402,7 +488,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst } else if (!ggml_backend_buffer_copy_tensor(src, dst)) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer)); -#endif +#endif // NDEBUG size_t nbytes = ggml_nbytes(src); void * data = malloc(nbytes); ggml_backend_tensor_get(src, data, 0, nbytes); @@ -411,7 +497,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst } } -void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) { +void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); if (src == dst) { @@ -500,6 +586,7 @@ enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) { } void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) { + GGML_ASSERT(device); memset(props, 0, sizeof(*props)); device->iface.get_props(device, props); } @@ -610,6 +697,8 @@ static const struct ggml_backend_buffer_i ggml_backend_multi_buffer_i = { /* .memset_tensor = */ NULL, /* .set_tensor = */ NULL, /* .get_tensor = */ NULL, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_multi_buffer_clear, /* .reset = */ NULL, @@ -1899,8 +1988,9 @@ enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct GGML_ASSERT(tensor->data == NULL); GGML_ASSERT(tensor->view_src == NULL); GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer)); - GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <= - (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer)); + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer) || + (char *) addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <= + (char *) ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer)); tensor->buffer = buffer; tensor->data = addr; @@ -2174,6 +2264,8 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, /* .clear = */ ggml_backend_cpu_buffer_clear, /* .reset = */ NULL, @@ -2186,6 +2278,8 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = { /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, /* .clear = */ ggml_backend_cpu_buffer_clear, /* .reset = */ NULL, diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index e7a1763b54d..05245b69807 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -262,6 +262,8 @@ static struct ggml_backend_i blas_backend_i = { /* .get_name = */ ggml_backend_blas_get_name, /* .free = */ ggml_backend_blas_free, /* .set_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .get_tensor_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 40fe3d82ecc..5fc484b342b 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1457,6 +1457,8 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor, /* .clear = */ ggml_backend_cann_buffer_clear, /* .reset = */ NULL, @@ -2698,6 +2700,8 @@ static const ggml_backend_i ggml_backend_cann_interface = { /* .free = */ ggml_backend_cann_free, /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async, /* .synchronize = */ ggml_backend_cann_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 9baf3e025e6..1118f7169c9 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -111,6 +111,8 @@ static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = { /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_amx_buffer_set_tensor, /* .get_tensor = */ nullptr, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, /* .cpy_tensor = */ nullptr, /* .clear = */ ggml_backend_amx_buffer_clear, /* .reset = */ nullptr, diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index ddf1737a317..49f840be207 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -195,6 +195,8 @@ static const struct ggml_backend_i ggml_backend_cpu_i = { /* .free = */ ggml_backend_cpu_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 419862101d1..b54d4a6b107 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -181,6 +181,16 @@ if (CUDAToolkit_FOUND) target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver) endif() + if (GGML_CUDA_NCCL) + find_package(NCCL) + if (NCCL_FOUND) + add_compile_definitions(GGML_USE_NCCL) + target_link_libraries(ggml-cuda PRIVATE NCCL::NCCL) + else() + message(STATUS "Warning: NCCL not found, performance for multiple CUDA GPUs will be suboptimal") + endif() + endif() + set(CUDA_CXX_FLAGS "") set(CUDA_FLAGS -use_fast_math -extended-lambda) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 65d7a6e22ae..64b91811c39 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -186,6 +186,10 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str) +#ifdef GGML_USE_NCCL +#define NCCL_CHECK(err) CUDA_CHECK_GEN(err, ncclSuccess, ncclGetErrorString) +#endif // GGML_USE_NCCL + #if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) static const char * cu_get_error_str(CUresult err) { const char * err_str; @@ -1086,6 +1090,10 @@ struct ggml_cuda_device_info { cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {}; std::array default_tensor_split = {}; + +#ifdef GGML_USE_NCCL + ncclComm_t comms[GGML_CUDA_MAX_DEVICES]; +#endif // GGML_USE_NCCL }; const ggml_cuda_device_info & ggml_cuda_info(); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 648124c0d31..841af0726b6 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -324,6 +324,28 @@ static ggml_cuda_device_info ggml_cuda_init() { // configure logging to stdout // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); + for (int id = 0; id < info.device_count; ++id) { + ggml_cuda_set_device(id); + for (int id_other = 0; id_other < info.device_count; ++id_other) { + if (id == id_other) { + continue; + } + int can_access_peer; + CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); + if (can_access_peer) { + CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0)); + } + } + } + +#ifdef GGML_USE_NCCL + int dev_ids[GGML_CUDA_MAX_DEVICES]; + for (int id = 0; id < info.device_count; ++id) { + dev_ids[id] = id; + } + NCCL_CHECK(ncclCommInitAll(info.comms, info.device_count, dev_ids)); +#endif // GGML_USE_NCCL + return info; } @@ -632,26 +654,46 @@ static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer } static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { - ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread)); + CUDA_CHECK(cudaMemsetAsync((char *) tensor->data + offset, value, size, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread)); + CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static void ggml_backend_cuda_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemcpy2DAsync( + (char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static void ggml_backend_cuda_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread)); + CUDA_CHECK(cudaMemcpy2DAsync( + data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } @@ -691,6 +733,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = { /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor, + /* .set_tensor_2d = */ ggml_backend_cuda_buffer_set_tensor_2d, + /* .get_tensor_2d = */ ggml_backend_cuda_buffer_get_tensor_2d, /* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor, /* .clear = */ ggml_backend_cuda_buffer_clear, /* .reset = */ NULL, @@ -1003,6 +1047,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_cuda_split_buffer_clear, /* .reset = */ NULL, @@ -1079,6 +1125,83 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host, }; +bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends) { +#ifdef GGML_USE_NCCL + const int64_t ne = ggml_nelements(tensors[0]); + // FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0 + // This then causes a crash in this function + if (ne == 0) { + return true; + } + for (size_t i = 0; i < n_backends; ++i) { + GGML_ASSERT(tensors[i] != nullptr); + GGML_ASSERT(ggml_nelements(tensors[i]) == ne); + GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i])); + } + + const ggml_cuda_device_info info = ggml_cuda_info(); + + // For small tensors, simply reduce them as FP32. + // The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0. + if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) { + NCCL_CHECK(ncclGroupStart()); + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream())); + } + NCCL_CHECK(ncclGroupEnd()); + + return true; + } + + // For large tensors it's faster to compress them to BF16 for the reduction: + to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(GGML_TYPE_F32); + to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_BF16); + + ggml_cuda_pool_alloc tmp[GGML_CUDA_MAX_DEVICES]; + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + tmp[i].pool = &cuda_ctx->pool(); + tmp[i].alloc(ne); + + ggml_cuda_set_device(i); + to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream()); + CUDA_CHECK(cudaGetLastError()); + } + + NCCL_CHECK(ncclGroupStart()); + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream())); + } + NCCL_CHECK(ncclGroupEnd()); + + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + + ggml_cuda_set_device(i); + to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream()); + CUDA_CHECK(cudaGetLastError()); + } + + return true; +#else + // If NCCL is installed it is used by default for optimal performance. + // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package. + // RCCL is disabled by default, users are explicitly opting in. + // Therefore print no warning for RCCL. +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + static bool warning_printed = false; + if (!warning_printed) { + GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__); + warning_printed = true; + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + GGML_UNUSED_VARS(backends, tensors, n_backends); + return false; +#endif // GGML_USE_NCCL +} + ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) { static std::mutex mutex; std::lock_guard lock(mutex); @@ -1425,64 +1548,6 @@ static void ggml_cuda_op_mul_mat_cublas( GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size); } -static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { - static bool peer_access_enabled = false; - - const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE; - - if (peer_access_enabled == enable_peer_access) { - return; - } - -#ifdef NDEBUG - for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { - ggml_cuda_set_device(id); - CUDA_CHECK(cudaDeviceSynchronize()); - } - - for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { - ggml_cuda_set_device(id); - - for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) { - if (id == id_other) { - continue; - } - if (id != main_device && id_other != main_device) { - continue; - } - - int can_access_peer; - CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); - if (can_access_peer) { - if (enable_peer_access) { - cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0); - if (err != cudaErrorPeerAccessAlreadyEnabled) { - CUDA_CHECK(err); - } else { - // reset the error - (void)cudaGetLastError(); - } - } else { - cudaError_t err = cudaDeviceDisablePeerAccess(id_other); - if (err != cudaErrorPeerAccessNotEnabled) { - CUDA_CHECK(err); - } else { - // reset the error - (void)cudaGetLastError(); - } - } - } - } - } - - ggml_cuda_set_device(main_device); -#endif // NDEBUG - - peer_access_enabled = enable_peer_access; - - GGML_UNUSED(main_device); -} - static cudaError_t ggml_cuda_Memcpy2DPeerAsync( void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { @@ -2483,11 +2548,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) { - // why is this here instead of mul_mat? - if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) { - ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); - } - switch (dst->op) { case GGML_OP_ARGMAX: ggml_cuda_argmax(ctx, dst); @@ -2845,21 +2905,43 @@ static void ggml_backend_cuda_free(ggml_backend_t backend) { } static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); - CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream())); + CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream())); } static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); - CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream())); + CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream())); +} + +static void ggml_backend_cuda_set_tensor_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + + CUDA_CHECK(cudaMemcpy2DAsync( + (char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cuda_ctx->stream())); +} + +static void ggml_backend_cuda_get_tensor_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + + CUDA_CHECK(cudaMemcpy2DAsync( + data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cuda_ctx->stream())); } static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { @@ -2870,21 +2952,21 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ return false; } - if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) { + if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) { return false; } // device -> device copy - ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context; - ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context; + ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *) backend_src->context; + ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *) backend_dst->context; - ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; - ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; + ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context; + ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context; if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); -#endif +#endif // NDEBUG return false; } @@ -2897,7 +2979,7 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ return false; #else CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream())); -#endif +#endif // GGML_CUDA_NO_PEER_COPY } // record event on src stream after the copy @@ -4343,6 +4425,8 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .free = */ ggml_backend_cuda_free, /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, + /* .get_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async, + /* .set_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async, /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async, /* .synchronize = */ ggml_backend_cuda_synchronize, /* .graph_plan_create = */ NULL, @@ -5130,6 +5214,9 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { GGML_UNUSED(reg); + if (strcmp(name, "ggml_backend_allreduce_tensor") == 0) { + return (void *)ggml_backend_cuda_allreduce_tensor; + } if (strcmp(name, "ggml_backend_split_buffer_type") == 0) { return (void *)ggml_backend_cuda_split_buffer_type; } diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index 07bc47df3b8..323c9801934 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -6,6 +6,10 @@ #include #include +#ifdef GGML_USE_NCCL +#include +#endif // GGML_USE_NCCL + #if CUDART_VERSION >= 11080 #include #define FP8_AVAILABLE diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 9d9ba1ee219..d146e018d94 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -10,6 +10,11 @@ #include #endif // defined(GGML_HIP_ROCWMMA_FATTN) +#ifdef GGML_USE_NCCL +#include +#endif // GGML_USE_NCCL + + #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_OP_N HIPBLAS_OP_N @@ -28,6 +33,7 @@ #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} +#define NCCL_CHECK(fn) {ncclResult_t err = fn; if(err != ncclSuccess) { GGML_ABORT("RCCL Failure RCCL returned: %i\n", err); }} #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) #define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index f91bc46552e..ac5baa2acaf 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1491,6 +1491,8 @@ static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_hexagon_buffer_set_tensor, /* .get_tensor = */ ggml_backend_hexagon_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_hexagon_buffer_cpy_tensor, /* .clear = */ ggml_backend_hexagon_buffer_clear, /* .reset = */ NULL, @@ -3002,6 +3004,8 @@ static struct ggml_backend_i hexagon_backend_i = { /* .free = */ ggml_backend_hexagon_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ ggml_backend_hexagon_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 291b4837455..a7d4e0ea2b5 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -47,6 +47,10 @@ find_package(hip REQUIRED) find_package(hipblas REQUIRED) find_package(rocblas REQUIRED) +if (GGML_HIP_RCCL) + find_package(rccl REQUIRED) +endif() + if (${hip_VERSION} VERSION_LESS 6.1) message(FATAL_ERROR "At least ROCM/HIP V6.1 is required") endif() @@ -118,6 +122,10 @@ if (NOT GGML_HIP_MMQ_MFMA) add_compile_definitions(GGML_HIP_NO_MMQ_MFMA) endif() +if (GGML_HIP_RCCL) + add_compile_definitions(GGML_USE_NCCL) # RCCL has the same interface as NCCL. +endif() + if (GGML_HIP_EXPORT_METRICS) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps") endif() @@ -142,4 +150,8 @@ if (GGML_STATIC) message(FATAL_ERROR "Static linking not supported for HIP/ROCm") endif() +if (GGML_HIP_RCCL) + target_link_libraries(ggml-hip PRIVATE ggml-base roc::rccl) +endif() + target_link_libraries(ggml-hip PRIVATE ggml-base hip::host roc::rocblas roc::hipblas) diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 9382ce53b36..4dbf8e6fea9 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -90,6 +90,8 @@ static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = { /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, /* .clear = */ ggml_backend_metal_buffer_shared_clear, /* .reset = */ NULL, @@ -158,15 +160,17 @@ static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer } static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { - /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_private_get_base, - /* .init_tensor = */ NULL, - /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, - /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, - /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, - /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, - /* .clear = */ ggml_backend_metal_buffer_private_clear, - /* .reset = */ NULL, + /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_private_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_private_clear, + /* .reset = */ NULL, }; static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) { @@ -563,6 +567,8 @@ static ggml_backend_i ggml_backend_metal_i = { /* .free = */ ggml_backend_metal_free, /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async, /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups /* .synchronize = */ ggml_backend_metal_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 6f3fc5886d8..f1a28a7f4cd 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4063,6 +4063,8 @@ static ggml_backend_i ggml_backend_opencl_i = { /* .set_tensor_async = */ NULL, /* ggml_backend_opencl_set_tensor_async */ /* .get_tensor_async = */ NULL, /* ggml_backend_opencl_get_tensor_async */ /* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */ + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .synchronize = */ ggml_backend_opencl_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, @@ -5778,6 +5780,8 @@ static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor, /* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_opencl_buffer_clear, /* .reset = */ ggml_backend_opencl_buffer_reset, diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index b3058b4af73..0c8d3508e87 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -412,6 +412,8 @@ static const ggml_backend_buffer_i ggml_backend_openvino_buffer_interface = { /* .memset_tensor = */ ggml_backend_openvino_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_openvino_buffer_set_tensor, /* .get_tensor = */ ggml_backend_openvino_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_openvino_buffer_cpy_tensor, /* .clear = */ ggml_backend_openvino_buffer_clear, /* .reset = */ NULL, @@ -617,6 +619,8 @@ static const ggml_backend_i ggml_backend_openvino_interface = { /* .free = */ ggml_backend_openvino_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 4e2f1ab0f23..61bfcc5a675 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -706,6 +706,8 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor, /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor, /* .clear = */ ggml_backend_rpc_buffer_clear, /* .reset = */ NULL, @@ -894,6 +896,8 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, /* .cpy_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .synchronize = */ ggml_backend_rpc_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 7f9b2df524e..989c91a6abb 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -638,6 +638,8 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor, /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor, /* .clear = */ ggml_backend_sycl_buffer_clear, /* .reset = */ ggml_backend_sycl_buffer_reset, @@ -1084,6 +1086,8 @@ static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_sycl_split_buffer_set_tensor, /* .get_tensor = */ ggml_backend_sycl_split_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_sycl_split_buffer_clear, /* .reset = */ NULL, @@ -4553,6 +4557,8 @@ static ggml_backend_i ggml_backend_sycl_interface = { /* .free = */ ggml_backend_sycl_free, /* .set_tensor_async = */ ggml_backend_sycl_set_tensor_async, /* .get_tensor_async = */ ggml_backend_sycl_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, // ggml_backend_sycl_cpy_tensor_async, // // TODO: update for the new // interface diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp index 6b95362dd80..b6c561cd61e 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp @@ -101,6 +101,8 @@ const ggml_backend_buffer_i ggml_backend_remoting_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor, /* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor, /* .clear = */ ggml_backend_remoting_buffer_clear, /* .reset = */ NULL, @@ -113,6 +115,8 @@ const ggml_backend_buffer_i ggml_backend_remoting_buffer_from_ptr_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor_from_ptr, /* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor_from_ptr, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor, /* .clear = */ ggml_backend_remoting_buffer_clear, /* .reset = */ NULL, diff --git a/ggml/src/ggml-virtgpu/ggml-backend.cpp b/ggml/src/ggml-virtgpu/ggml-backend.cpp index a63ee2b9d2f..2b978556228 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend.cpp @@ -34,6 +34,8 @@ static ggml_backend_i ggml_backend_remoting_interface = { /* .free = */ ggml_backend_remoting_free, /* .set_tensor_async = */ NULL, // ggml_backend_remoting_set_tensor_async, /* .get_tensor_async = */ NULL, // ggml_backend_remoting_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, // ggml_backend_remoting_cpy_tensor_async, /* .synchronize = */ NULL, // ggml_backend_remoting_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 19e7fbdaae7..20a4d30d5eb 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -13521,6 +13521,8 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, /* .clear = */ ggml_backend_vk_buffer_clear, /* .reset = */ NULL, @@ -14979,6 +14981,8 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .free = */ ggml_backend_vk_free, /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async, /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async, /* .synchronize = */ ggml_backend_vk_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b8df0f4dd05..edfc6579171 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3013,6 +3013,8 @@ static ggml_backend_i ggml_backend_webgpu_i = { /* .free = */ ggml_backend_webgpu_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, @@ -3170,6 +3172,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { /* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, // TODO: optional, implement this /* .clear = */ ggml_backend_webgpu_buffer_clear, /* .reset = */ NULL, // TODO: optional, think it coordinates with diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index 9b6938abf7e..e6b6fc24fd7 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -313,6 +313,8 @@ static ggml_backend_buffer_i ggml_backend_zdnn_buffer_i = { /* .memset_tensor = */ ggml_backend_zdnn_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_zdnn_buffer_set_tensor, /* .get_tensor = */ ggml_backend_zdnn_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_zdnn_buffer_clear, /* .reset = */ NULL, @@ -417,20 +419,22 @@ static enum ggml_status ggml_backend_zdnn_graph_compute(ggml_backend_t backend, } static ggml_backend_i ggml_backend_zdnn_i = { - /* .get_name = */ ggml_backend_zdnn_name, - /* .free = */ ggml_backend_zdnn_free, - /* .set_tensor_async = */ NULL, - /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_zdnn_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, - /* .graph_optimize = */ NULL, + /* .get_name = */ ggml_backend_zdnn_name, + /* .free = */ ggml_backend_zdnn_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_zdnn_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, }; static ggml_guid_t ggml_backend_zdnn_guid(void) { diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 377303720c7..fc1df4dbef4 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -407,6 +407,8 @@ static struct ggml_backend_i ggml_backend_zendnn_i = { /* .free = */ ggml_backend_zendnn_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, From c77a33df06f64eda3cff5dd54a99e7b3fdbb152c Mon Sep 17 00:00:00 2001 From: andyluo7 <43718156+andyluo7@users.noreply.github.com> Date: Thu, 9 Apr 2026 22:13:32 +0300 Subject: [PATCH 410/831] HIP: add CDNA4 (gfx950) architecture support for MI350X/MI355X (llama/21570) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add AMD Instinct MI350X/MI355X (gfx950, CDNA4) support: - vendors/hip.h: Add CDNA4 preprocessor define for __gfx950__ - common.cuh: Add GGML_CUDA_CC_CDNA4 and GGML_CUDA_CC_IS_CDNA4 macros - mma.cuh: Route CDNA4 to compatible MFMA instructions: * f32 matmul: mfma_f32_16x16x4f32 (xf32 variant unavailable on gfx950) * bf16 matmul: mfma_f32_16x16x16bf16_1k (same as CDNA3) * int8 matmul: mfma_i32_16x16x32_i8/32x32x16 (same as CDNA3) - mmq.cuh: Include CDNA4 in stream-k kernel dispatch CDNA4 is largely compatible with CDNA3 except: - No xf32 MFMA (mfma_f32_16x16x8_xf32) — routes to f32 path - Different FP8 format (e4m3fn vs e4m3_fnuz) — not changed here Tested on AMD Instinct MI355X (gfx950), ROCm 7.0.1: - Build: compiles cleanly with -DAMDGPU_TARGETS=gfx950 - llama-bench (Qwen2.5-1.5B Q4_K_M, single GPU): * f16+FA: 40,013 tok/s prefill, 254 tok/s decode * q8_0+FA: functional - Flash attention: works correctly - MMQ: works correctly with stream-k dispatch Co-authored-by: Andy Luo --- ggml/src/ggml-cuda/common.cuh | 4 +++- ggml/src/ggml-cuda/mma.cuh | 17 +++++++++-------- ggml/src/ggml-cuda/mmq.cuh | 2 +- ggml/src/ggml-cuda/vendors/hip.h | 8 ++++++-- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 64b91811c39..56a67f1edc8 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -67,6 +67,7 @@ #define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x90a) // MI210 (gfx90a), minimum acc register renaming #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 +#define GGML_CUDA_CC_CDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x950) // MI350X/MI355X // RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 @@ -87,7 +88,8 @@ #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) #define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2) #define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3) -#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_CDNA4) +#define GGML_CUDA_CC_IS_CDNA4(cc) (cc >= GGML_CUDA_CC_CDNA4 && cc < GGML_CUDA_CC_RDNA1) // Moore Threads #define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 5d1dadd3e4f..c91dd2d9ad6 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -1025,7 +1025,8 @@ namespace ggml_cuda_mma { const floatx2_t& a_frag = reinterpret_cast(A.x[0]); const floatx2_t& b_frag = reinterpret_cast(B.x[0]); acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0); -#elif defined(CDNA2) || defined(CDNA1) +#elif defined(CDNA4) || defined(CDNA2) || defined(CDNA1) + // CDNA4 (gfx950) does not support xf32 MFMA, use f32 path like CDNA2/CDNA1 #pragma unroll for (int i = 0; i < 2; ++i) { acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0); @@ -1187,7 +1188,7 @@ namespace ggml_cuda_mma { #elif defined(AMD_MFMA_AVAILABLE) using floatx4_t = __attribute__((ext_vector_type(4))) float; floatx4_t& acc_frag = reinterpret_cast(D.x[0]); -#if defined(CDNA3) || defined(CDNA2) +#if defined(CDNA4) || defined(CDNA3) || defined(CDNA2) using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16; const bf16x4_t& a_frag = reinterpret_cast(A.x[0]); const bf16x4_t& b_frag = reinterpret_cast(B.x[0]); @@ -1216,12 +1217,12 @@ namespace ggml_cuda_mma { #if defined(AMD_MFMA_AVAILABLE) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * acc = (int32x4_t *) D.x; -#if defined(CDNA3) +#if defined(CDNA4) || defined(CDNA3) acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); -#elif defined(CDNA2) || defined(CDNA) +#elif defined(CDNA2) || defined(CDNA1) acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], @@ -1230,7 +1231,7 @@ namespace ggml_cuda_mma { B.x[1], acc[0], 0, 0, 0); -#endif // defined(CDNA3) +#endif // defined(CDNA4) || defined(CDNA3) #elif defined(AMD_WMMA_AVAILABLE) @@ -1295,12 +1296,12 @@ namespace ggml_cuda_mma { #if defined(AMD_MFMA_AVAILABLE) using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; int32x16_t * acc = (int32x16_t *) D.x; -#if defined(CDNA3) +#if defined(CDNA4) || defined(CDNA3) acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); -#elif defined(CDNA2) || defined(CDNA) +#elif defined(CDNA2) || defined(CDNA1) acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], B.x[0], acc[0], @@ -1309,7 +1310,7 @@ namespace ggml_cuda_mma { B.x[1], acc[0], 0, 0, 0); -#endif // defined(CDNA3) +#endif // defined(CDNA4) || defined(CDNA3) #else GGML_UNUSED_VARS(D, A, B); diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 489d3616bb4..18911141472 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3645,7 +3645,7 @@ static __global__ void mul_mat_q( tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); return; } -#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA +#endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA constexpr int ITER_K = get_iter_k(type); diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index d146e018d94..898fec31e36 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -189,6 +189,10 @@ #define GCN #endif // defined(GCN5) || defined(GCN4) +#if defined(__gfx950__) +#define CDNA4 +#endif // defined(__gfx950__) + #if defined(__gfx942__) #define CDNA3 #endif // defined(__gfx942__) @@ -201,9 +205,9 @@ #define CDNA1 #endif // defined(__gfx908__) -#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#if defined(CDNA4) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1) #define CDNA // For the entire family -#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#endif // defined(CDNA4) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1) #if defined(__GFX12__) #define RDNA4 From 28347201fcd8771fdec88fbcad39eff597ee7866 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 10 Apr 2026 10:24:09 +0800 Subject: [PATCH 411/831] CUDA: fuse muls (llama/21665) --- ggml/src/ggml-cuda/binbcast.cu | 30 ++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/binbcast.cuh | 1 + ggml/src/ggml-cuda/ggml-cuda.cu | 18 +++++++++++------- 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 7339fe0c070..adb4d5f0cb9 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -472,6 +472,36 @@ void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, } } +void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) { + GGML_ASSERT(2 <= n_fuse && n_fuse <= 8); + + switch (n_fuse) { + case 2: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 3: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 4: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 5: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 6: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 7: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + case 8: + ggml_cuda_op_fused_binbcast_impl(ctx, dst); + break; + default: + GGML_ASSERT(false && "Unsupported n_fuse value"); + } +} + void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh index 62bc950111b..12624785b44 100644 --- a/ggml/src/ggml-cuda/binbcast.cuh +++ b/ggml/src/ggml-cuda/binbcast.cuh @@ -9,3 +9,4 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse); +void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 841af0726b6..8613d20b9f9 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3758,10 +3758,10 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } - if (node->op == GGML_OP_ADD) { + if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) { int n_fuse = 0; ggml_op ops[8]; - std::fill(ops, ops + 8, GGML_OP_ADD); + std::fill(ops, ops + 8, node->op); for (; n_fuse <= 6; ++n_fuse){ if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { @@ -3778,13 +3778,17 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud n_fuse++; if (n_fuse > 1) { - ggml_tensor fused_add_node; - memcpy(&fused_add_node, node, sizeof(ggml_tensor)); + ggml_tensor fused_node; + memcpy(&fused_node, node, sizeof(ggml_tensor)); for (int j = 0; j < n_fuse - 1; ++j) { - fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + } + fused_node.data = cgraph->nodes[i + n_fuse - 1]->data; + if (node->op == GGML_OP_ADD) { + ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse); + } else { + ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse); } - fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data; - ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse); i += n_fuse - 1; continue; From 458ad1d93ec9c5c08752cc409cebf09a06ddd8ea Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 10 Apr 2026 01:35:27 -0500 Subject: [PATCH 412/831] vulkan: Support Q1_0 (llama/21539) * vulkan: Support Q1_0 * use get_dm --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 33 +++++++++++++++++++ .../vulkan-shaders/copy_to_quant.comp | 25 ++++++++++++++ .../vulkan-shaders/dequant_funcs.glsl | 24 ++++++++++++++ .../vulkan-shaders/dequant_funcs_cm2.glsl | 16 ++++++++- .../vulkan-shaders/dequant_q1_0.comp | 29 ++++++++++++++++ .../vulkan-shaders/mul_mm_funcs.glsl | 14 ++++++++ .../src/ggml-vulkan/vulkan-shaders/types.glsl | 16 +++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 7 ++-- 8 files changed, 160 insertions(+), 4 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 20a4d30d5eb..977aff62d81 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3512,6 +3512,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) } #endif + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q1_0], matmul_q1_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) @@ -3541,6 +3542,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5) } #endif + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) @@ -3602,6 +3604,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif if (device->coopmat_acc_f16_support) { + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3624,6 +3627,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3658,6 +3662,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } #endif + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); @@ -3721,6 +3726,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -3767,6 +3773,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -3811,6 +3818,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -3884,6 +3892,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -3928,6 +3937,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0].f32acc, matmul_id_subgroup_q1_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -3954,6 +3964,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0].f32acc, matmul_id_q1_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -4051,6 +4062,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q1_0][i], "mul_mat_vec_q1_0_f32_f32", arr_dmmv_q1_0_f32_f32_len[reduc], arr_dmmv_q1_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); @@ -4075,6 +4087,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q1_0][i], "mul_mat_vec_q1_0_f16_f32", arr_dmmv_q1_0_f16_f32_len[reduc], arr_dmmv_q1_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); @@ -4125,6 +4138,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q1_0], "mul_mat_vec_id_q1_0_f32", arr_dmmv_id_q1_0_f32_f32_len[reduc], arr_dmmv_id_q1_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); @@ -4179,6 +4193,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q1_0], "dequant_q1_0", dequant_q1_0_len, dequant_q1_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 8, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -4204,6 +4219,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q1_0], "get_rows_q1_0", get_rows_q1_0_len, get_rows_q1_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -4229,6 +4245,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q1_0], "get_rows_q1_0_f32", get_rows_q1_0_f32_len, get_rows_q1_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -4310,6 +4327,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_rte_len, cpy_f32_q1_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); @@ -4317,6 +4335,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); } else { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); @@ -4329,6 +4348,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## rte ## _len, set_rows_q1_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ @@ -4346,6 +4366,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #undef SET_ROWS + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q1_0], "cpy_q1_0_f32", cpy_q1_0_f32_len, cpy_q1_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q1_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); @@ -6022,6 +6043,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type VK_LOG_DEBUG("ggml_vk_get_to_fp16()"); switch (type) { case GGML_TYPE_F32: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6093,6 +6115,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte } switch (src0_type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6158,6 +6181,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6248,6 +6272,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); switch (src0_type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6316,6 +6341,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -7263,6 +7289,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const } if (src->type == GGML_TYPE_F32) { switch (to) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -7277,6 +7304,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const if (to == GGML_TYPE_F32) { switch (src->type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15269,6 +15297,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15383,6 +15412,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15415,6 +15445,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15438,6 +15469,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -15452,6 +15484,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (src1_type == GGML_TYPE_F32) { switch (src0_type) { case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index b8c40eec102..4ffa45485c9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -184,6 +184,31 @@ void quantize(uint dst_idx, uint src_idx) } #endif +#if defined(DATA_A_Q1_0) +void quantize(uint dst_idx, uint src_idx) +{ + float sum_abs = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q1_0; j++) { + sum_abs += abs(data_s[src_idx + j]); + } + + const float d = sum_abs / QUANT_K_Q1_0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q1_0 / 8; ++j) { + data_q[dst_idx].qs[j] = uint8_t(0); + } + + [[unroll]] for (int j = 0; j < QUANT_K_Q1_0; ++j) { + if (data_s[src_idx + j] >= 0.0) { + data_q[dst_idx].qs[j / 8] |= uint8_t(1 << (j % 8)); + } + } +} +#endif + #if defined(DATA_A_IQ4_NL) uint best_index(float x) { if (x <= kvalues_iq4nl[0]) return 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 7865a6bda79..ede1275cfc2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -87,6 +87,23 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_Q1_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint bits = uint(data_a[a_offset + ib].qs[iqs / 8u]) >> (iqs % 8u); + return vec2( + (bits & 1u) != 0u ? 1.0f : -1.0f, + (bits & 2u) != 0u ? 1.0f : -1.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint bits = uint(data_a[a_offset + ib].qs[iqs / 8u]) >> (iqs % 8u); + return vec4( + (bits & 1u) != 0u ? 1.0f : -1.0f, + (bits & 2u) != 0u ? 1.0f : -1.0f, + (bits & 4u) != 0u ? 1.0f : -1.0f, + (bits & 8u) != 0u ? 1.0f : -1.0f); +} +#endif + #if defined(DATA_A_IQ1_S) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint ib32 = iqs / 32; @@ -454,6 +471,13 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif +#if defined(DATA_A_Q1_0) +vec2 get_dm(uint ib, uint a_offset) { + const float d = float(data_a[a_offset + ib].d); + return vec2(d, 0); +} +#endif + #if defined(DATA_A_MXFP4) vec2 get_dm(uint ib, uint a_offset) { return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 8ac6482dc94..03035f28120 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -13,6 +13,18 @@ float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], return vf16[idx]; } +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ1_0 { + block_q1_0 block; +}; + +float16_t dequantFuncQ1_0(const in decodeBufQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint bit = (uint(bl.block.qs[(idx & 0x78) >> 3]) >> (idx & 0x7)) & 1u; + return bit != 0u ? d : -d; +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { block_q4_0_packed16 block; }; @@ -685,7 +697,9 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords } #endif -#if defined(DATA_A_Q4_0) +#if defined(DATA_A_Q1_0) +#define dequantFuncA dequantFuncQ1_0 +#elif defined(DATA_A_Q4_0) #define dequantFuncA dequantFuncQ4_0 #elif defined(DATA_A_Q4_1) #define dequantFuncA dequantFuncQ4_1 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp new file mode 100644 index 00000000000..ca0bdbc63e0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q1_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid / 4; + const uint ir = tid % 4; + const uint ib = 4*i + ir; + if (ib >= p.nel / 128) { + return; + } + + const uint b_idx = 512*i + 128*ir + 8*il; + + const float d = float(data_a[ib].d); + const uint bits = uint(data_a[ib].qs[il]); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l] = D_TYPE((bits & (1u << l)) != 0u ? d : -d); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 9b769bbc887..219bd608035 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -130,6 +130,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); +#elif defined(DATA_A_Q1_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 16; + const uint iqs = idx & 0xfu; + + const float d = float(data_a[ib].d); + const uint bits = uint(data_a[ib].qs[iqs]); + + buf_a[buf_idx ] = FLOAT_TYPEV2((bits & 0x01u) != 0u ? d : -d, (bits & 0x02u) != 0u ? d : -d); + buf_a[buf_idx + 1] = FLOAT_TYPEV2((bits & 0x04u) != 0u ? d : -d, (bits & 0x08u) != 0u ? d : -d); + buf_a[buf_idx + 2] = FLOAT_TYPEV2((bits & 0x10u) != 0u ? d : -d, (bits & 0x20u) != 0u ? d : -d); + buf_a[buf_idx + 3] = FLOAT_TYPEV2((bits & 0x40u) != 0u ? d : -d, (bits & 0x80u) != 0u ? d : -d); #elif defined(DATA_A_Q2_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index bdb2c09259b..4239070af5e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -188,6 +188,22 @@ struct block_q8_0_packed16 #define DATA_A_QUANT_LEGACY #endif +#define QUANT_K_Q1_0 128 +#define QUANT_R_Q1_0 1 + +struct block_q1_0 +{ + float16_t d; + uint8_t qs[QUANT_K_Q1_0 / 8]; +}; + +#if defined(DATA_A_Q1_0) +#define QUANT_K QUANT_K_Q1_0 +#define QUANT_R QUANT_R_Q1_0 +#define QUANT_AUXF 1 +#define A_TYPE block_q1_0 +#endif + #define QUANT_K_Q8_1 32 #define QUANT_R_Q8_1 1 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 11385f93378..77a55ea812b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -45,6 +45,7 @@ std::string target_cpp = ""; const std::vector type_names = { "f32", "f16", + "q1_0", "q4_0", "q4_1", "q5_0", @@ -553,7 +554,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) + if ((tname == "q1_0") || (tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4")) load_vec_quant = "4"; @@ -758,13 +759,13 @@ void process_shaders() { string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}); - for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + for (std::string t : {"q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } - for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); From 3fc738a8c2c798deef3371c4a5da95aaa251379c Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Fri, 10 Apr 2026 13:52:01 -0400 Subject: [PATCH 413/831] ggml-webgpu: address quantization precision and backend lifecycle managment (llama/21521) * ggml(webgpu): fix the busy-polls in Emscripten in the waitAny after #20618, and remove the busy webgpu log * Merge with upstream * Fix GET_ROWS packed integer NaN when using f16 as memory buffer in shader quants * Update Unary wgsl EXP and EXPM1 for f16 stability * Fix GET_ROWS IQ4_XS strcut for NaN f16 canonicalization * Fix numerical percision for unary sqrt when working with f16 * Fix NaN canonicalization for packed integers using f16 * Update err threshold for binary div ops when using f16 * backend: Keep one Dawn/WebGPU instance alive for the lifetime of the static backend * clean: uncomment existing code logs * clean: clean the unncessary debug info * Refactor and generalize dequant helpers * Remove deprecated quant structs * Refactor shader defines to reduce repetition * Remove error override for F16 type * fix: fix the accidential removal of the proper initialization of ctx * clean: clean legacy and format code * fix: did not modify tests ops --------- Co-authored-by: Jeremy J. Hartmann --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 55 ++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 37 +++- .../wgsl-shaders/common_decls.tmpl | 139 +++---------- .../ggml-webgpu/wgsl-shaders/get_rows.wgsl | 189 +++++++++++------- .../src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl | 161 ++++++++------- .../wgsl-shaders/mul_mat_decls.tmpl | 78 ++++---- .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 46 ++--- ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 8 +- 8 files changed, 383 insertions(+), 330 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index c10157766d9..3de6258c74d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1115,6 +1115,32 @@ class ggml_webgpu_shader_lib { std::string type_upper = type_str; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + switch (key.src_type) + { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: + { + // Quantized types using u32 buffers for portability. + defines.push_back("SRC_TYPE=u32"); + defines.push_back("U32_DEQUANT_HELPERS"); + break; + } + default: + { + defines.push_back(std::string("SRC_TYPE=") + type_str); + } + } + defines.push_back("BYTE_HELPERS"); defines.push_back(type_upper + "_T"); defines.push_back(type_upper); @@ -1125,7 +1151,6 @@ class ggml_webgpu_shader_lib { variant += "_"; variant += type_str; - defines.push_back(std::string("SRC_TYPE=") + type_str); defines.push_back("DST_TYPE=f32"); if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || @@ -1593,11 +1618,35 @@ class ggml_webgpu_shader_lib { break; default: { - // quantized types std::string type_upper = src0_name; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - defines.push_back(std::string("SRC0_TYPE=") + src0_name); + switch (context.src0->type) + { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: + { + // Quantized types using u32 buffers for portability. + defines.push_back("SRC0_TYPE=u32"); + defines.push_back("U32_DEQUANT_HELPERS"); + break; + } + default: + { + defines.push_back(std::string("SRC0_TYPE=") + src0_name); + } + } + defines.push_back("BYTE_HELPERS"); defines.push_back(type_upper + "_T"); defines.push_back(type_upper); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index edfc6579171..3b894a9b9cc 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -97,6 +97,14 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim /* End Constants */ +static inline wgpu::CallbackMode ggml_webgpu_callback_mode() { +#ifdef __EMSCRIPTEN__ + return wgpu::CallbackMode::AllowProcessEvents; +#else + return wgpu::CallbackMode::AllowSpontaneous; +#endif +} + // This is a "fake" base pointer, since WebGPU buffers do not have pointers to // their locations. static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT @@ -474,7 +482,7 @@ static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) { const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, + ggml_webgpu_callback_mode(), [&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { callback_status = status; callback_message = std::string(message); @@ -494,7 +502,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, std::string callback_message; const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( - buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, + buffer.MapAsync(mode, offset, size, ggml_webgpu_callback_mode(), [&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) { callback_status = status; callback_message = std::string(message); @@ -526,7 +534,11 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); ctx->queue.Submit(1, &commands); - ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize()); + if (!ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, + ctx->debug_host_buf.GetSize())) { + GGML_LOG_ERROR("ggml_webgpu: Debug buffer map failed\n"); + return; + } const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange(); std::cout << "debug[0]: " << debug_data[0] << "\n"; ctx->debug_host_buf.Unmap(); @@ -542,7 +554,7 @@ static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & auto ts_bufs = command.timestamp_query_bufs; wgpu::Future f = ts_bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, + wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), ggml_webgpu_callback_mode(), [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) { if (status != wgpu::MapAsyncStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str()); @@ -3420,7 +3432,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->instance.WaitAny( ctx->webgpu_global_ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, + &options, ggml_webgpu_callback_mode(), [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { if (status != wgpu::RequestAdapterStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); @@ -3491,8 +3503,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { dev_desc.requiredFeatures = required_features.data(); dev_desc.requiredFeatureCount = required_features.size(); dev_desc.SetDeviceLostCallback( - wgpu::CallbackMode::AllowSpontaneous, - [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + ggml_webgpu_callback_mode(), + [ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { if (reason == wgpu::DeviceLostReason::Destroyed) { return; } @@ -3525,7 +3537,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->instance.WaitAny( ctx->webgpu_global_ctx->adapter.RequestDevice( - &dev_desc, wgpu::CallbackMode::AllowSpontaneous, + &dev_desc, ggml_webgpu_callback_mode(), [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { if (status != wgpu::RequestDeviceStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str()); @@ -4046,6 +4058,13 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ctx.name = GGML_WEBGPU_NAME; ctx.device_count = 0; + // Keep one Dawn/WebGPU instance alive for the lifetime of the static backend + // registry. Recreating it on repeated registry lookups can invalidate + // adapter/device references that are still held by the backend/device layer. + if (ctx.webgpu_global_ctx != nullptr && ctx.webgpu_global_ctx->instance != nullptr) { + return ® + } + wgpu::InstanceDescriptor instance_descriptor{}; std::vector instance_features = { wgpu::InstanceFeatureName::TimedWaitAny }; instance_descriptor.requiredFeatures = instance_features.data(); @@ -4063,11 +4082,11 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); ctx.webgpu_global_ctx->instance = std::move(inst); + // Probe for adapter support wgpu::Adapter adapter; if (ctx.webgpu_global_ctx->instance != nullptr) { wgpu::RequestAdapterOptions options = {}; - // probe for adapter support ctx.webgpu_global_ctx->instance.WaitAny( ctx.webgpu_global_ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index feb0bca3f84..0d3501c34a2 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -9,36 +9,44 @@ fn get_byte_i32(value: u32, index: u32) -> i32 { #endif #ifdef U32_DEQUANT_HELPERS -fn load_src0_u16_at(byte_offset: u32) -> u32 { - let word = src0[byte_offset / 4u]; - let shift = (byte_offset & 2u) * 8u; - return (word >> shift) & 0xFFFFu; +fn load_u16_at( + buf: ptr, read_write>, + byte_offset: u32) -> u32 { + let word = buf[byte_offset / 4]; + let shift = (byte_offset & 0x2) * 8; + return (word >> shift) & 0xFFFF; } -fn load_src0_u32_at(byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4u; - let shift = (byte_offset & 3u) * 8u; - let lo = src0[word_idx]; - if (shift == 0u) { - return lo; - } - let hi = src0[word_idx + 1u]; - return (lo >> shift) | (hi << (32u - shift)); +fn load_u32_at( + buf: ptr, read_write>, + byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4; + let shift = (byte_offset & 0x3) * 8; + let lo = buf[word_idx]; + let hi = buf[word_idx + 1]; + let shifted = (lo >> shift) | (hi << (32 - shift)); + return select(shifted, lo, shift == 0); } -fn load_src0_f16_at(byte_offset: u32) -> f16 { - let packed = unpack2x16float(load_src0_u16_at(byte_offset)); +fn load_f16_at( + buf: ptr, read_write>, + byte_offset: u32) -> f16 { + let packed = unpack2x16float(load_u16_at(buf, byte_offset)); return f16(packed[0]); } -#endif -#ifdef Q4_0_T -struct q4_0 { - d: f16, - qs: array -}; +fn load_f16_as_f32_at( + buf: ptr, read_write>, + byte_offset: u32) -> f32 { + let word = buf[byte_offset / 4]; + let shift = (byte_offset & 0x2) * 8; + let d_bits = (word >> shift) & 0xFFFF; + return unpack2x16float(d_bits)[0]; +} #endif + + #ifdef Q4_1_T struct q4_1 { d: f16, @@ -47,13 +55,6 @@ struct q4_1 { }; #endif -#ifdef Q5_0_T -struct q5_0 { - d: f16, - qh: array, - qs: array -}; -#endif #ifdef Q5_1_T struct q5_1 { @@ -64,12 +65,6 @@ struct q5_1 { }; #endif -#ifdef Q8_0_T -struct q8_0 { - d: f16, - qs: array -}; -#endif #ifdef Q8_1_T struct q8_1 { @@ -88,14 +83,6 @@ struct q2_K { }; #endif -#ifdef Q3_K_T -struct q3_K { - hmask: array, - qs: array, - scales: array, - d: f16 -}; -#endif #if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN) fn get_scale_min(is: u32, scales: array) -> vec2 { @@ -132,64 +119,6 @@ struct q5_K { }; #endif -#ifdef Q6_K_T -struct q6_K { - ql: array, - qh: array, - scales: array, - d: f16 -}; -#endif - -#ifdef IQ2_XXS_T -struct iq2_xxs { - d: f16, - qs: array -}; -#endif - -#ifdef IQ2_XS_T -struct iq2_xs { - d: f16, - qs: array, - scales: array -}; -#endif - -#ifdef IQ2_S_T -struct iq2_s { - d: f16, - qs: array, - qh: array, - scales: array -}; -#endif - -#ifdef IQ3_XXS_T -struct iq3_xxs { - d: f16, - qs: array -}; -#endif - -#ifdef IQ3_S_T -struct iq3_s { - d: f16, - qs: array, - qh: array, - signs: array, - scales: array -}; -#endif - -#ifdef IQ1_S_T -struct iq1_s { - d: f16, - qs: array, - qh: array -}; -#endif - #ifdef IQ1_M_T struct iq1_m { qs: array, @@ -198,17 +127,9 @@ struct iq1_m { }; #endif -#ifdef IQ4_NL_T -struct iq4_nl { - d: f16, - qs: array, -}; -#endif - #ifdef IQ4_XS_T struct iq4_xs { - d: f16, - scales_h: f16, + d_scales_h: u32, scales_l: u32, qs: array }; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index d9eb6a3567e..3c8b84c9ac3 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -27,17 +27,18 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q4_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q4_0 = src[src_base + offset]; - let d = f32(block_q4_0.d); - for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); + let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); + for (var j: u32 = 0u; j < 4; j++) { + let q_byte_offset = block_byte_base + 2 + j * 4; + let q_packed = load_u32_at(&src, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; let dst_offset = dst_base + offset * 32 + j * 4 + k; dst[dst_offset] = q_lo; - dst[dst_offset + 16] = q_hi; + dst[dst_offset + 16u] = q_hi; } } } @@ -64,17 +65,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q5_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q5_0 = src[src_base + offset]; - let d = f32(block_q5_0.d); - let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); + let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); + let qh_packed = load_u32_at(&src, block_byte_base + 2); for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); + let q_byte_offset = block_byte_base + 6 + j * 4; + let q_packed = load_u32_at(&src, q_byte_offset); + for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); + let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; + let dst_offset = dst_base + offset * 32 + j * 4 + k; dst[dst_offset] = q_lo; dst[dst_offset + 16] = q_hi; @@ -106,14 +112,15 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q8_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q8_0 = src[src_base + offset]; - let d = f32(block_q8_0.d); - for (var j: u32 = 0; j < 8; j++) { - let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { + let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); + for (var j: u32 = 0u; j < 8u; j++) { + let q_byte_offset = block_byte_base + 2u + j * 4u; + let q_packed = load_u32_at(&src, q_byte_offset); + for (var k: u32 = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; - let dst_offset = dst_base + offset * 32 + j * 4 + k; + let dst_offset = dst_base + offset * 32u + j * 4u + k; dst[dst_offset] = q_val; } } @@ -152,36 +159,42 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q3_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes - // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, - // and 2-bits from the last 4 bytes + // Bytes 108-109: f16 scale 'd' + let d = load_f16_as_f32_at(&src, block_byte_base + 108); + + // Bytes 96-107: 12 bytes of scales (3 u32s) let kmask1: u32 = 0x03030303; let kmask2: u32 = 0x0f0f0f0f; + var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } + scale_vals[0] = load_u32_at(&src, block_byte_base + 96); + scale_vals[1] = load_u32_at(&src, block_byte_base + 100); + scale_vals[2] = load_u32_at(&src, block_byte_base + 104); + var tmp: u32 = scale_vals[2]; scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - // convert arrays of f16 -> u32 + // Bytes 0-31: 32 bytes of hmask (8 u32s) var hmask_vals: array; for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); + hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4); } + + // Bytes 32-95: 64 bytes of qs (16 u32s) var qs_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); + for (var i: u32 = 0u; i < 16; i++) { + qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4); } var dst_i = dst_base + offset * 256; var is: u32 = 0; var m: u32 = 1; + // 2 halves of the block (128 elements each) for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { // 4 groups (each group has 2 blocks of 16 elements) @@ -191,11 +204,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let sc = get_byte(scale_vals[is / 4], is % 4); is++; let dl = d * (f32(sc) - 32.0); - for (var l: u32 = 0u; l < 16u; l++) { + + for (var l: u32 = 0; l < 16; l++) { let q_idx = q_b_idx + k + l; let hm_idx = k + l; let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); + let hm = select(4.0, 0.0, (hmask_byte & m) != 0); let qs_val = (q_byte >> shift) & 3; dst[dst_i] = (f32(qs_val) - hm) * dl; @@ -268,21 +283,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q6_K // 16 blocks of 16 elements each fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes - // convert arrays of f16 -> u32 + // Bytes 208-209: f16 scale 'd' + let d = load_f16_as_f32_at(&src, block_byte_base + 208); + + // Bytes 0-127: 128 bytes of ql (32 u32s) var ql_vals: array; for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); + ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4); } + + // Bytes 128-191: 64 bytes of qh (16 u32s) var qh_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); + for (var i: u32 = 0; i < 16u; i++) { + qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u); } + + // Bytes 192-207: 16 bytes of scales (4 u32s) var scale_vals: array; for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4); } var dst_i = dst_base + offset * 256; @@ -323,12 +344,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ2_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); - let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); + let aux0_offset = block_byte_base + 2 + ib * 2; + let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; + let aux0 = load_u32_at(&src, aux0_offset); + let aux1 = load_u32_at(&src, aux1_offset); let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; for (var l: u32 = 0; l < 4; l++) { let ig = get_byte(aux0, l) * 8; @@ -345,15 +368,19 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } #endif + + #ifdef IQ2_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; + var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) + load_u32_at(&src, block_byte_base + 66), + load_u32_at(&src, block_byte_base + 70) ); + for (var ib: u32 = 0; ib < 32; ib += 4) { let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); let db = array( @@ -361,7 +388,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { d * (0.5 + f32(s >> 4)) * 0.25 ); for (var l: u32 = 0; l < 4; l++) { - let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); + let qs_offset = block_byte_base + 2 + (ib + l) * 2; + let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF; let ig = (qs_val & 511) * 8; let is = qs_val >> 9; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); @@ -379,21 +407,23 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ2_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; + var qs_vals : array; for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4); } - var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) - ); - var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) - ); + + var qh_vals: array; + qh_vals[0] = load_u32_at(&src, block_byte_base + 66); + qh_vals[1] = load_u32_at(&src, block_byte_base + 70); + + var scale_vals: array; + scale_vals[0] = load_u32_at(&src, block_byte_base + 74); + scale_vals[1] = load_u32_at(&src, block_byte_base + 78); + for (var ib: u32 = 0; ib < 8; ib ++) { let s = get_byte(scale_vals[ib / 4], ib % 4); let db = array( @@ -419,16 +449,17 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ3_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); + let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; + let sc_sign = load_u32_at(&src, sc_sign_offset); let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; for (var l: u32 = 0; l < 4; l++) { let is = (sc_sign >> (7 * l)) & 127; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); + let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0); let ig2 = get_byte(ig_val, 1); for (var j: u32 = 0; j < 4; j++) { @@ -448,18 +479,22 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ3_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; + var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) + load_u32_at(&src, block_byte_base + 66), + load_u32_at(&src, block_byte_base + 70) ); + var sign_vals: array; for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); + sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4); } - var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); + + var scale_vals = load_u32_at(&src, block_byte_base + 106); + for (var ib: u32 = 0; ib < 4; ib++) { let s = get_byte(scale_vals, ib); let db = array( @@ -472,7 +507,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let sign_w = sign_vals[ib * 2 + k]; for (var l: u32 = 0; l < 4; l++) { let signs = get_byte(sign_w, l); - let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); + let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); for (var j: u32 = 0; j < 4; j++) { @@ -493,14 +528,14 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ1_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 8; ib++) { - let qh = bitcast(vec2(block.qh[ib], 0.0)); - let dl = d * (2 * f32((qh >> 12) & 7) + 1); + let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF; + let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); + let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4); for (var l: u32 = 0; l < 4; l++) { let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; for (var j: u32 = 0; j < 8; j++) { @@ -560,12 +595,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ4_NL fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at(&src, block_byte_base); var dst_i = dst_base + offset * 32; var qs: array; for (var i: u32 = 0; i < 4; i++) { - qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4); } for (var j: u32 = 0; j < 16; j++) { let qsb = get_byte(qs[j / 4], j % 4); @@ -579,8 +614,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ4_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; - let d = f32(block.d); - let scales_h = bitcast(vec2(block.scales_h, 0.0)); + let d = unpack2x16float(block.d_scales_h)[0]; + let scales_h = block.d_scales_h >> 16; var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 8; ib++) { let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl index 5b9f5b36224..fdabaf09b2e 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -20,11 +20,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q4_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q4_0 = src0[src0_idx_base + offset]; - let d = f32(block_q4_0.d); + let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var sum: f32 = 0.0; for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); + let q_byte_offset = block_byte_base + 2 + j * 4; + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; @@ -61,12 +62,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q5_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q5_0 = src0[src0_idx_base + offset]; - let d = f32(block_q5_0.d); + let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var sum: f32 = 0.0; - let qh_packed = bitcast(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); + let qh_packed = load_u32_at(&src0, block_byte_base + 2); for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); + let q_byte_offset = block_byte_base + 6 + j * 4; + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; @@ -107,12 +109,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q8_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q8_0 = src0[src0_idx_base + offset]; - let d = f32(block_q8_0.d); + let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var sum: f32 = 0.0; for (var j: u32 = 0; j < 8; j++) { - let q_packed = bitcast(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { + let q_byte_offset = block_byte_base + 2 + j * 4; + let q_packed = load_u32_at(&src0, q_byte_offset); + for (var k: u32 = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; @@ -178,31 +181,37 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q3_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes + + // Bytes 108-109: f16 scale 'd' + let d = load_f16_as_f32_at(&src0, block_byte_base + 108); // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, // and 2-bits from the last 4 bytes + // Bytes 96-107: 12 bytes of scales (3 u32s) let kmask1: u32 = 0x03030303; let kmask2: u32 = 0x0f0f0f0f; var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } + scale_vals[0] = load_u32_at(&src0, block_byte_base + 96); + scale_vals[1] = load_u32_at(&src0, block_byte_base + 100); + scale_vals[2] = load_u32_at(&src0, block_byte_base + 104); + var tmp: u32 = scale_vals[2]; scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - // convert arrays of f16 -> u32 + // Bytes 0-31: 32 bytes of hmask (8 u32s) var hmask_vals: array; for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = bitcast(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); + hmask_vals[i] = load_u32_at(&src0, block_byte_base + i * 4); } + + // Bytes 32-95: 64 bytes of qs (16 u32s) var qs_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[2 * i], block.qs[2 * i + 1])); + for (var i: u32 = 0u; i < 16; i++) { + qs_vals[i] = load_u32_at(&src0, block_byte_base + 32 + i * 4); } var sum = 0.0; @@ -301,21 +310,27 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q6_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes - // convert arrays of f16 -> u32 + // Bytes 208-209: f16 scale 'd' + let d = load_f16_as_f32_at(&src0, block_byte_base + 208); + + // Bytes 0-127: 128 bytes of ql (32 u32s) var ql_vals: array; for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = bitcast(vec2(block.ql[2 * i], block.ql[2 * i + 1])); + ql_vals[i] = load_u32_at(&src0, block_byte_base + i * 4); } + + // Bytes 128-191: 64 bytes of qh (16 u32s) var qh_vals: array; for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = bitcast(vec2(block.qh[2 * i], block.qh[2 * i + 1])); + qh_vals[i] = load_u32_at(&src0, block_byte_base + 128 + i * 4); } + + // Bytes 192-207: 16 bytes of scales (4 u32s) var scale_vals: array; for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + scale_vals[i] = load_u32_at(&src0, block_byte_base + 192 + i * 4); } var sum = 0.0; @@ -358,13 +373,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0 = bitcast(vec2(block.qs[ib], block.qs[ib + 1])); - let aux1 = bitcast(vec2(block.qs[ib + 2], block.qs[ib + 3])); + let aux0_offset = block_byte_base + 2 + ib * 2; + let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; + let aux0 = load_u32_at(&src0, aux0_offset); + let aux1 = load_u32_at(&src0, aux1_offset); let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; for (var l: u32 = 0; l < 4; l++) { let ig = get_byte(aux0, l) * 8; @@ -384,13 +401,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; + var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) + load_u32_at(&src0, block_byte_base + 66), + load_u32_at(&src0, block_byte_base + 70) ); + var sum = 0.0; for (var ib: u32 = 0; ib < 32; ib += 4) { let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); @@ -399,7 +418,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { d * (0.5 + f32(s >> 4)) * 0.25 ); for (var l: u32 = 0; l < 4; l++) { - let qs_val = bitcast(vec2(block.qs[ib + l], 0.0)); + let qs_offset = block_byte_base + 2 + (ib + l) * 2; + let qs_val = load_u32_at(&src0, qs_offset) & 0xFFFF; let ig = (qs_val & 511) * 8; let is = qs_val >> 9; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); @@ -418,21 +438,23 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; + var qs_vals : array; for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs_vals[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4); } - var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) - ); - var scale_vals = array( - bitcast(vec2(block.scales[0], block.scales[1])), - bitcast(vec2(block.scales[2], block.scales[3])) - ); + + var qh_vals: array; + qh_vals[0] = load_u32_at(&src0, block_byte_base + 66); + qh_vals[1] = load_u32_at(&src0, block_byte_base + 70); + + var scale_vals: array; + scale_vals[0] = load_u32_at(&src0, block_byte_base + 74); + scale_vals[1] = load_u32_at(&src0, block_byte_base + 78); + var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib ++) { let s = get_byte(scale_vals[ib / 4], ib % 4); @@ -460,17 +482,18 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ3_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign = bitcast(vec2(block.qs[ib + 32], block.qs[ib + 33])); + let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; + let sc_sign = load_u32_at(&src0, sc_sign_offset); let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; for (var l: u32 = 0; l < 4; l++) { let is = (sc_sign >> (7 * l)) & 127; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = bitcast(vec2(block.qs[ib * 2 + l], 0.0)); + let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0); let ig2 = get_byte(ig_val, 1); for (var j: u32 = 0; j < 4; j++) { @@ -491,18 +514,22 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ3_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; + var qh_vals = array( - bitcast(vec2(block.qh[0], block.qh[1])), - bitcast(vec2(block.qh[2], block.qh[3])) + load_u32_at(&src0, block_byte_base + 66), + load_u32_at(&src0, block_byte_base + 70) ); + var sign_vals: array; for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = bitcast(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); + sign_vals[i] = load_u32_at(&src0, block_byte_base + 74 + i * 4); } - var scale_vals = bitcast(vec2(block.scales[0], block.scales[1])); + + var scale_vals = load_u32_at(&src0, block_byte_base + 106); + var sum = 0.0; for (var ib: u32 = 0; ib < 4; ib++) { let s = get_byte(scale_vals, ib); @@ -516,7 +543,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let sign_w = sign_vals[ib * 2 + k]; for (var l: u32 = 0; l < 4; l++) { let signs = get_byte(sign_w, l); - let ig_val = bitcast(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); + let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); for (var j: u32 = 0; j < 4; j++) { @@ -538,15 +565,15 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ1_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib++) { - let qh = bitcast(vec2(block.qh[ib], 0.0)); - let dl = d * (2 * f32((qh >> 12) & 7) + 1); + let qh = load_u32_at(&src0, block_byte_base + 34 + ib * 2) & 0xFFFF; + let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = bitcast(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); + let qs_w = load_u32_at(&src0, block_byte_base + 2 + ib * 4); for (var l: u32 = 0; l < 4; l++) { let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; for (var j: u32 = 0; j < 8; j++) { @@ -610,13 +637,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ4_NL fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); + let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at(&src0, block_byte_base); var src1_i = src1_idx_base + offset * 32; var sum = 0.0; var qs: array; for (var i: u32 = 0; i < 4; i++) { - qs[i] = bitcast(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4); } for (var j: u32 = 0; j < 16; j++) { let qsb = get_byte(qs[j / 4], j % 4); @@ -631,8 +658,8 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ4_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let scales_h = bitcast(vec2(block.scales_h, 0.0)); + let d = unpack2x16float(block.d_scales_h)[0]; + let scales_h = block.d_scales_h >> 16; var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib++) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index ea91c13468f..374137ff8e8 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -84,11 +84,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); + let d = load_f16_at(&src0, block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -125,12 +125,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let m = load_src0_f16_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let m = load_f16_at(&src0, block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_lo = f16(q_byte & 0xF) * d + m; @@ -171,12 +171,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let qh_packed = load_src0_u32_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let qh_packed = load_u32_at(&src0, block_byte_base + 2u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -225,14 +225,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let m = load_src0_f16_at(block_byte_base + 2u); - let qh_packed = load_src0_u32_at(block_byte_base + 4u); + let d = load_f16_at(&src0, block_byte_base); + let m = load_f16_at(&src0, block_byte_base + 2u); + let qh_packed = load_u32_at(&src0, block_byte_base + 4u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -277,11 +277,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); + let d = load_f16_at(&src0, block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j+=2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -317,12 +317,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let m = load_src0_f16_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let m = load_f16_at(&src0, block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j+=2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -359,8 +359,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base + 80u); - let dmin = load_src0_f16_at(block_byte_base + 82u); + let d = load_f16_at(&src0, block_byte_base + 80u); + let dmin = load_f16_at(&src0, block_byte_base + 82u); // Decode the element at position k_in_block let block_of_32 = k_in_block / 32u; @@ -373,14 +373,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let is = k_in_block / 16u; - let sc_packed = load_src0_u32_at(block_byte_base + 4u * (is / 4u)); + let sc_packed = load_u32_at(&src0, block_byte_base + 4u * (is / 4u)); let sc = get_byte(sc_packed, is % 4u); let dl = d * f16(sc & 0xFu); let ml = dmin * f16(sc >> 4u); let q_idx = q_b_idx + k + l; - let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 3u; @@ -413,7 +413,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base + 108u); + let d = load_f16_at(&src0, block_byte_base + 108u); // Load and unpack scales let kmask1: u32 = 0x03030303u; @@ -421,7 +421,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var scale_vals: array; for (var i: u32 = 0u; i < 4u; i++) { - scale_vals[i] = load_src0_u32_at(block_byte_base + 96u + 4u * i); + scale_vals[i] = load_u32_at(&src0, block_byte_base + 96u + 4u * i); } var tmp: u32 = scale_vals[2]; @@ -433,12 +433,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load hmask and qs arrays var hmask_vals: array; for (var i: u32 = 0u; i < 8u; i++) { - hmask_vals[i] = load_src0_u32_at(block_byte_base + 4u * i); + hmask_vals[i] = load_u32_at(&src0, block_byte_base + 4u * i); } var qs_vals: array; for (var i: u32 = 0u; i < 16u; i++) { - qs_vals[i] = load_src0_u32_at(block_byte_base + 32u + 4u * i); + qs_vals[i] = load_u32_at(&src0, block_byte_base + 32u + 4u * i); } let half = k_in_block / 128u; // 0 or 1 @@ -499,13 +499,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let dmin = load_src0_f16_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let dmin = load_f16_at(&src0, block_byte_base + 2u); // Load packed scales var scale_vals: array; for (var i: u32 = 0u; i < 3u; i++) { - scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i); + scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i); } // Map k_in_block to loop structure: @@ -541,7 +541,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 0xFu; @@ -575,13 +575,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_src0_f16_at(block_byte_base); - let dmin = load_src0_f16_at(block_byte_base + 2u); + let d = load_f16_at(&src0, block_byte_base); + let dmin = load_f16_at(&src0, block_byte_base + 2u); // Load packed scales var scale_vals: array; for (var i: u32 = 0u; i < 3u; i++) { - scale_vals[i] = load_src0_u32_at(block_byte_base + 4u + 4u * i); + scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i); } // The original loop processes elements in groups of 64 @@ -621,11 +621,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_packed = load_src0_u32_at(block_byte_base + 48u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at(&src0, block_byte_base + 48u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); - let qh_packed = load_src0_u32_at(block_byte_base + 16u + 4u * (l / 4u)); + let qh_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (l / 4u)); let qh_byte = get_byte(qh_packed, l % 4u); @@ -673,17 +673,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only ql13 word needed let ql13_flat = ql_b_idx + l; - let ql13 = load_src0_u32_at(block_byte_base + ql13_flat); + let ql13 = load_u32_at(&src0, block_byte_base + ql13_flat); let ql13_b = get_byte(ql13, 0u); // Load only ql24 word needed let ql24_flat = ql_b_idx + l + 32u; - let ql24 = load_src0_u32_at(block_byte_base + ql24_flat); + let ql24 = load_u32_at(&src0, block_byte_base + ql24_flat); let ql24_b = get_byte(ql24, 0u); // Load only qh word needed let qh_flat = qh_b_idx + l; - let qh = load_src0_u32_at(block_byte_base + 128u + qh_flat); + let qh = load_u32_at(&src0, block_byte_base + 128u + qh_flat); let qh_b = get_byte(qh, 0u); let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); @@ -694,10 +694,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only the scale word needed let is = l / 16u; let sc_idx = sc_b_idx + is + quarter * 2u; - let sc = load_src0_u32_at(block_byte_base + 192u + sc_idx); + let sc = load_u32_at(&src0, block_byte_base + 192u + sc_idx); let sc_val = get_byte_i32(sc, 0u); - let d = load_src0_f16_at(block_byte_base + 208u); + let d = load_f16_at(&src0, block_byte_base + 208u); var q_val: f16; if (quarter == 0u) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 6525f23bdfc..6f6bcaf7940 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -65,10 +65,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); + let d = f32(load_f16_at(&src0, block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; @@ -98,11 +98,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - let m = f32(load_src0_f16_at(block_byte_base + 2u)); + let d = f32(load_f16_at(&src0, block_byte_base)); + let m = f32(load_f16_at(&src0, block_byte_base + 2u)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = f32((q_byte >> 4) & 0xF) * d + m; @@ -132,12 +132,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - let qh_packed = load_src0_u32_at(block_byte_base + 2u); + let d = f32(load_f16_at(&src0, block_byte_base)); + let qh_packed = load_u32_at(&src0, block_byte_base + 2u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -176,13 +176,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - let m = load_src0_f16_at(block_byte_base + 2u); - let qh_packed = load_src0_u32_at(block_byte_base + 4u); + let d = f32(load_f16_at(&src0, block_byte_base)); + let m = load_f16_at(&src0, block_byte_base + 2u); + let qh_packed = load_u32_at(&src0, block_byte_base + 4u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -221,11 +221,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); + let d = f32(load_f16_at(&src0, block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; @@ -254,12 +254,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_src0_f16_at(block_byte_base)); - let m = load_src0_f16_at(block_byte_base + 2u); + let d = f32(load_f16_at(&src0, block_byte_base)); + let m = load_f16_at(&src0, block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_src0_u32_at(q_byte_offset); + let q_packed = load_u32_at(&src0, q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d + f32(m); @@ -309,13 +309,13 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = ix; i < nb; i += 2u) { let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES; - let d = f32(load_src0_f16_at(bbase + 208u)); + let d = f32(load_f16_at(&src0, bbase + 208u)); - let ql1_u32 = load_src0_u32_at(bbase + q_offset_l); - let ql2_u32 = load_src0_u32_at(bbase + q_offset_l + 32u); - let qh_u32 = load_src0_u32_at(bbase + 128u + q_offset_h); - let sc_u32_0 = load_src0_u32_at(bbase + sc_base_byte); - let sc_u32_1 = load_src0_u32_at(bbase + sc_base_byte + 4u); + let ql1_u32 = load_u32_at(&src0, bbase + q_offset_l); + let ql2_u32 = load_u32_at(&src0, bbase + q_offset_l + 32u); + let qh_u32 = load_u32_at(&src0, bbase + 128u + q_offset_h); + let sc_u32_0 = load_u32_at(&src0, bbase + sc_base_byte); + let sc_u32_1 = load_u32_at(&src0, bbase + sc_base_byte + 4u); let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index 21beb9bb94d..8c334817ccd 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -107,7 +107,8 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx])); #endif #ifdef EXP - let res = exp(src[params.offset_src + src_idx]); + let src_f32 = f32(src[params.offset_src + src_idx]); + let res = TYPE(exp(src_f32)); #endif #ifdef LOG let res = TYPE(log(f32(src[params.offset_src + src_idx]))); @@ -161,7 +162,8 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0)); #endif #ifdef EXPM1 - let res = exp(src[params.offset_src + src_idx]) - 1.0; + let src_f32 = f32(src[params.offset_src + src_idx]); + let res = TYPE(exp(src_f32) - 1.0); #endif #ifdef FLOOR let res = floor(src[params.offset_src + src_idx]); @@ -181,7 +183,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx]; #endif #ifdef SQRT - let res = sqrt(src[params.offset_src + src_idx]); + let res = TYPE(sqrt(f32(src[params.offset_src + src_idx]))); #endif #ifdef SIN let res_f32 = sin(f32(src[params.offset_src + src_idx])); From 2580cfc70360cddcc09de271c84c90c57771e30c Mon Sep 17 00:00:00 2001 From: Rithik Sharma Date: Fri, 10 Apr 2026 10:52:38 -0700 Subject: [PATCH 414/831] ggml-webgpu: support non-square subgroup matrix configs for Intel GPUs (llama/21669) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 13 +++++-- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 34 +++++++++---------- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 3b894a9b9cc..e979783f020 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3461,13 +3461,15 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); #ifndef __EMSCRIPTEN__ - // Only support square f16 matrices of size 8 or 16 for now + // Accept f16 subgroup matrix configurations (square or non-square). + // NVIDIA GPUs typically report square configs (e.g. 16x16x16), + // while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16). + // The shaders are already parameterized to handle any M/N/K dimensions. bool valid_subgroup_matrix_config = false; if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; - if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && - config.componentType == wgpu::SubgroupMatrixComponentType::F16 && + if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 && config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M; ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N; @@ -3805,6 +3807,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { break; } + // Head dimensions must be divisible by subgroup matrix dimensions + if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 || + src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) { + break; + } // Head dimensions must fit in workgroup memory with minimum tile sizes size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; const bool has_mask = op->src[3] != nullptr; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 8b76cecba91..aa2d2e54db9 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -369,35 +369,35 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #endif for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) { let inter_offset = kv_block * SG_MAT_N; - var acc: subgroup_matrix_result = subgroupMatrixLoad>(&inter_shmem, inter_offset, false, KV_TILE); + var acc: subgroup_matrix_result = subgroupMatrixLoad>(&inter_shmem, inter_offset, false, KV_TILE); - var q_cur = subgroupMatrixLoad>(&q_shmem, 0u, false, HEAD_DIM_QK); + var q_cur = subgroupMatrixLoad>(&q_shmem, 0u, false, HEAD_DIM_QK); #ifdef KV_DIRECT - var k_cur = subgroupMatrixLoad>(&K, k_global_offset + 0u, true, params.stride_k1); + var k_cur = subgroupMatrixLoad>(&K, k_global_offset + 0u, true, params.stride_k1); #else - var k_cur = subgroupMatrixLoad>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK); + var k_cur = subgroupMatrixLoad>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK); #endif var t: u32 = 1u; for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) { let h0 = t * SG_MAT_K; - var q0 = subgroupMatrixLoad>(&q_shmem, h0, false, HEAD_DIM_QK); + var q0 = subgroupMatrixLoad>(&q_shmem, h0, false, HEAD_DIM_QK); #ifdef KV_DIRECT - var k0 = subgroupMatrixLoad>(&K, k_global_offset + h0, true, params.stride_k1); + var k0 = subgroupMatrixLoad>(&K, k_global_offset + h0, true, params.stride_k1); #else - var k0 = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK); + var k0 = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK); #endif acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); q_cur = q0; k_cur = k0; let h1 = (t + 1u) * SG_MAT_K; - var q1g = subgroupMatrixLoad>(&q_shmem, h1, false, HEAD_DIM_QK); + var q1g = subgroupMatrixLoad>(&q_shmem, h1, false, HEAD_DIM_QK); #ifdef KV_DIRECT - var k1g = subgroupMatrixLoad>(&K, k_global_offset + h1, true, params.stride_k1); + var k1g = subgroupMatrixLoad>(&K, k_global_offset + h1, true, params.stride_k1); #else - var k1g = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK); + var k1g = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK); #endif acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); q_cur = q1g; @@ -407,11 +407,11 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // handle odd tail if (t < HEAD_DIM_QK / SG_MAT_K) { let h = t * SG_MAT_K; - var qn = subgroupMatrixLoad>(&q_shmem, h, false, HEAD_DIM_QK); + var qn = subgroupMatrixLoad>(&q_shmem, h, false, HEAD_DIM_QK); #ifdef KV_DIRECT - var kn = subgroupMatrixLoad>(&K, k_global_offset + h, true, params.stride_k1); + var kn = subgroupMatrixLoad>(&K, k_global_offset + h, true, params.stride_k1); #else - var kn = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK); + var kn = subgroupMatrixLoad>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK); #endif acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); q_cur = qn; @@ -566,7 +566,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, head_dim_block < HEAD_DIM_V; head_dim_block += num_subgroups * SG_MAT_N) { // load O submatrix from shared memory - var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( + var o_sg_mat: subgroup_matrix_result = subgroupMatrixLoad>( &o_shmem, head_dim_block, false, @@ -574,7 +574,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, ); for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) { let p_offset = kv_block * SG_MAT_N; - var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( + var p_sg_mat: subgroup_matrix_left = subgroupMatrixLoad>( &inter_shmem, p_offset, false, @@ -585,7 +585,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #ifdef KV_DIRECT let v_block_row = kv_tile + kv_block * SG_MAT_N; let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block; - var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( &V, v_global_offset, false, @@ -593,7 +593,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, ); #else let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V; - var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( + var v_sg_mat: subgroup_matrix_right = subgroupMatrixLoad>( &kv_shmem, v_block_offset + head_dim_block, false, From 28ce072f59523b0a3a1752ceab7516e6e5d9a86d Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Fri, 10 Apr 2026 15:47:43 -0700 Subject: [PATCH 415/831] hexagon: improved Op queuing, buffer and cache management (llama/21705) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * hexagon: introduce op request batching and rewrite buffer managment The host now prepares batches of requests and dispatches them via a single dspqueue message. Buffers are mapped explicitly by NPU while processing batches. * hex-dma: disable l2 bypass since to work around new issue due to no flushes between Ops * hex-utils: add explicit l2flush and l2clear helpers * hex-opreq: use fine-grain per tensor l2 management * hex-opreq: avoid redundant invalidates for tensors we already flushed * hex-opreq: update debug messages * htp-opreq: reuse ops_context * hex-opreq: do not flush or invalidate cache lines beyond buffer boundry * hex-opreq: fix errors in log message * Revert "hex-opreq: do not flush or invalidate cache lines beyond buffer boundry" This reverts commit 8b7f0a55a750a6430ce4eb1874c7feb3d720056d. * hexagon: limit l2 flushes to 1MB which covers l2 cache * hex-opreq: limit cache flush to 4MB Looks like 4MB cont. vitual space should cover the 1MB cache. * hexagon: drop cache flush size to 2MB * hex-opreq: start reworking opreq packing * hex-opreq: introduce new way of packing opbatch where tensors are stored separately * hex-opreq: add a simple fastrpc call to force unmap all buffers * hex-l2flush: somehow 2MB does not seem robust, also cleanup step size to use line-size * hex-opreq: bump opreq batch size to 256 * hex-mm: place src1 spad at the top of vtcm for easy reuse * hex-ops: introduce internal types and disable src1 reuse for now Nothing new just formalizing the repack / qyn.quant types we've been using. * htp-opreq: use tensor pointers instead of copies * hex-opreq: introduce more robust way for tracking vtcm/spad reuse This removes the SKIP_QUANTIZE flag that became fragile with the addition of HMX and other ops. * hex-cumsum: fix error post opreq merge * hex-opreq: move request batch handling into the session Prepping everything for using dspqueue buffers and doing that inside the session is much cleaner. * hex-mm: yet another fix for src1 reuse when we're mixing hmx/hvx * hex-bufs: introduce pinned mmapings and use non-pinned ones for model buffers * hex-buf: add support for allocating shared/pinned buffer for opreqs * hex-opbatch: make opbatches configurable * hex-naming: better name for ggml_hexagon_shared_buffer * hex-naming: add session->c_name() helper * hex-opbatch: start using shm but still copy for now * hex-opbatch: use shared buffer for packing opbatch * hex-opbatch: beter naming for opbatch related classes and code * hex-opbatch: reuse batched tensors with same data/dims/strides * hex-opbatch: update logging * hex-opbatch: add support for vmem limit for op batching * hex-opbatch: update htp side to properly support dynamic mmap/unmap * hex-opbatch: add OB and OQ params for run-completion script and fix the asserts in batch processing * hex-opbatch: fixed src1 handling in act ops * hex-act: fix empty src1 handling in swiglu and friends Simplify preamble macro while at it * hex-mm: minor fix vtcm and dma handling in matmul cleaning up some left-overs from merges * hex-opbatch: allocate extra 1KB for dspqueue overhead * hexagon: fix softmax for non-aligned tensors and cleanup vtcm alloc * hex-mm: properly handle hmx_disabled flag * hex-ops: update comments * hex-ops: add debug output for get/set-rows * hex-mmap: optimize un/mapping of buffers * hex-opreq: global cache flush and invalidate beyond 128KB threshold * hex-ops: add super simple opfilter regex for debugging If an Op matches the regex hex backend will reject it. * hex-opbatch: wireup newer ops missed in merge and update main switch to detect this in future * hexagon: improved vtcm acquision to remove inter-op overhead Fully compatible with QNN-HTP coex * hex-mm: fixed hvx fallback path * hex-mm: lower the vmem threshold a bit further to ~3GB * hexagon: update debug & error logs This also fixes an issue with newer llvm merging repack and non-repack functions. We use those pointer to distinguish between buffer types. * hexagon: move ops context into main context Just a cleanup. We don't need separate contexts at this point. * hex-opbatch: cleanup naming and headers for opbatch and related descriptors * hex-fa: it's now better to enable FA during TG to reduce graph splits * hexagon: remove GGML_HEXAGON_EXPERIMENTAL env var It's no longer useful. Please use more flexible GGML_HEXAGON_OPFILTER to disable Ops if needed for debugging or validation. * hexagon: fixed editorconfig check * Update ggml/src/ggml-hexagon/ggml-hexagon.cpp Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: Trivikram Reddy Co-authored-by: Sigbjørn Skjæret --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 1343 +++++++++--------- ggml/src/ggml-hexagon/htp/act-ops.c | 137 +- ggml/src/ggml-hexagon/htp/argsort-ops.c | 18 +- ggml/src/ggml-hexagon/htp/binary-ops.c | 46 +- ggml/src/ggml-hexagon/htp/cpy-ops.c | 10 +- ggml/src/ggml-hexagon/htp/cumsum-ops.c | 25 +- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 36 +- ggml/src/ggml-hexagon/htp/get-rows-ops.c | 74 +- ggml/src/ggml-hexagon/htp/hex-utils.h | 21 + ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 25 +- ggml/src/ggml-hexagon/htp/hmx-ops.h | 6 +- ggml/src/ggml-hexagon/htp/htp-ctx.h | 106 +- ggml/src/ggml-hexagon/htp/htp-msg.h | 166 --- ggml/src/ggml-hexagon/htp/htp-ops.h | 183 ++- ggml/src/ggml-hexagon/htp/htp_iface.idl | 2 + ggml/src/ggml-hexagon/htp/main.c | 1418 +++++--------------- ggml/src/ggml-hexagon/htp/matmul-ops.c | 229 +++- ggml/src/ggml-hexagon/htp/repeat-ops.c | 10 +- ggml/src/ggml-hexagon/htp/rope-ops.c | 31 +- ggml/src/ggml-hexagon/htp/set-rows-ops.c | 90 +- ggml/src/ggml-hexagon/htp/softmax-ops.c | 252 ++-- ggml/src/ggml-hexagon/htp/ssm-conv.c | 21 +- ggml/src/ggml-hexagon/htp/sum-rows-ops.c | 12 +- ggml/src/ggml-hexagon/htp/unary-ops.c | 12 +- 24 files changed, 1732 insertions(+), 2541 deletions(-) delete mode 100644 ggml/src/ggml-hexagon/htp/htp-msg.h diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index ac5baa2acaf..3d68b80048f 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -7,10 +7,14 @@ #include #include -#include #include +#include +#include #include #include +#include +#include +#include #ifdef _WIN32 # include @@ -33,7 +37,7 @@ #include "ggml-impl.h" #include "ggml-quants.h" #include "op-desc.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp_iface.h" #include "htp-drv.h" @@ -44,12 +48,14 @@ static int opt_etm = 0; static int opt_verbose = 0; static int opt_profile = 0; static int opt_hostbuf = 1; // hostbuf ON by default -static int opt_experimental = 0; static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only // Enable all stages by default -static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE; -static int opt_opsync = 0; // synchronous ops +static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_COMPUTE; +static int opt_opsync = 0; // synchronous ops +static int opt_opbatch = 1024; // max number of ops in a batch +static int opt_opqueue = 16; // max number of pending batches +static std::regex* opt_opfilter = NULL; // regex of ops to not claim #define HEX_VERBOSE(...) \ if (opt_verbose) GGML_LOG_DEBUG(__VA_ARGS__) @@ -86,7 +92,7 @@ static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const ggml_t op_desc desc(op); GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(), - ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags); + ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags); } static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) { @@ -94,7 +100,7 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct op_desc desc(op); GGML_LOG_DEBUG("ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), - ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no"); + ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no"); } static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op, @@ -103,25 +109,16 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t op_desc desc(op); GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : %s : op-usec %u op-cycles %u op-pkts %u (%f) call-usec %llu\n", sess_name.c_str(), - ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, + ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, op_usec, op_cycles, op_pkts, (float) op_cycles / op_pkts, (unsigned long long) call_usec); } // ** backend sessions -struct ggml_hexagon_session { - ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); - ~ggml_hexagon_session() noexcept(true); - - void allocate(int dev_id) noexcept(false); - void release() noexcept(true); - - void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false); - void flush(); - - ggml_backend_buffer_type buffer_type = {}; - ggml_backend_buffer_type repack_buffer_type = {}; +struct ggml_hexagon_opbatch; +struct ggml_hexagon_opshm; +struct ggml_hexagon_session { std::string name; remote_handle64 handle; dspqueue_t queue; @@ -133,87 +130,28 @@ struct ggml_hexagon_session { bool valid_handle; bool valid_queue; bool valid_iface; - std::atomic op_pending; - uint32_t prof_usecs; - uint32_t prof_cycles; - uint32_t prof_pkts; -}; - -void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) { - // Bump pending flag (cleared in the session::flush once we get the response) - this->op_pending++; // atomic inc - - int err = dspqueue_write(this->queue, - 0, // flags - the framework will autoset this - n_bufs, // number of buffers - bufs, // buffer references - sizeof(req), // Message length - (const uint8_t *) &req, // Message - DSPQUEUE_TIMEOUT // Timeout - ); - - if (err != 0) { - GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->name.c_str(), (unsigned) err); - } - - if (sync) { - flush(); - } -} - -// Flush HTP response queue i.e wait for all outstanding requests to complete -void ggml_hexagon_session::flush() { - dspqueue_t q = this->queue; - - // Repeatedly read packets from the queue until it's empty. We don't - // necessarily get a separate callback for each packet, and new packets - // may arrive while we're processing the previous one. - - while (this->op_pending) { - struct htp_general_rsp rsp; - uint32_t rsp_size; - uint32_t flags; - struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; - uint32_t n_bufs; + std::atomic op_pending; + ggml_hexagon_opbatch *op_batch; + ggml_hexagon_opshm *op_shm; - // Read response packet from queue - int err = dspqueue_read(q, &flags, - HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references - &n_bufs, // Number of buffer references - bufs, // Buffer references - sizeof(rsp), // Max message length - &rsp_size, // Message length - (uint8_t *) &rsp, // Message - DSPQUEUE_TIMEOUT); // Timeout - - if (err == AEE_EEXPIRED) { - // TODO: might need to bail out if the HTP is stuck on something - continue; - } + ggml_backend_buffer_type buffer_type = {}; + ggml_backend_buffer_type repack_buffer_type = {}; - if (err != 0) { - GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err); - } + ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); + ~ggml_hexagon_session() noexcept(true); - // Basic sanity checks - if (rsp_size != sizeof(rsp)) { - GGML_ABORT("ggml-hex: dspcall : bad response (size)\n"); - } + const char* c_name() const { return name.c_str(); } - if (rsp.status != HTP_STATUS_OK) { - GGML_LOG_ERROR("ggml-hex: dspcall : dsp-rsp: %s\n", status_to_str(rsp.status)); - // TODO: handle errors - } + void allocate(int dev_id) noexcept(false); + void release() noexcept(true); - // TODO: update profiling implementation, currently only works for opt_opsync mode - this->prof_usecs = rsp.prof_usecs; - this->prof_cycles = rsp.prof_cycles; - this->prof_pkts = rsp.prof_pkts; + void enqueue_op(htp_op_code opcode, const ggml_tensor *op); + void flush(bool all = true); - this->op_pending--; // atomic dec - } -} + void flush_pending(bool all = false); + void flush_batch(); +}; // ** backend buffers @@ -227,82 +165,99 @@ struct ggml_backend_hexagon_buffer_type_context { std::string name; }; -struct ggml_backend_hexagon_buffer_context { - bool mmap_to(ggml_hexagon_session * s) { - HEX_VERBOSE("ggml-hex: %s mmaping buffer: base %p domain-id %d session-id %d size %zu fd %d repack %d\n", - s->name.c_str(), (void *) this->base, s->domain_id, s->session_id, this->size, this->fd, - (int) this->repack); +struct ggml_hexagon_shared_buffer { + ggml_hexagon_session * sess; + uint8_t * base; + size_t size; + int fd; + bool mapped; + bool pinned; - int err = fastrpc_mmap(s->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD); + void mmap(bool pinned = false) { + int err = fastrpc_mmap(sess->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD_DELAYED); if (err != 0) { - GGML_LOG_ERROR("ggml-hex: buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n", - s->domain_id, this->size, this->fd, (unsigned) err); - return false; + GGML_LOG_ERROR("ggml-hex: %s buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(), + sess->domain_id, this->size, this->fd, (unsigned) err); + throw std::runtime_error("ggml-hex: fastrpc_mmap failed (see log for details)"); } - return true; - } - - bool mmap() { - if (this->mapped) { - return true; - } - if (!mmap_to(this->sess)) { - return false; + if (pinned) { + err = htp_iface_mmap(sess->handle, this->fd, this->size, pinned); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: %s buffer pinning failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(), + sess->domain_id, this->size, this->fd, (unsigned) err); + throw std::runtime_error("ggml-hex: htp_iface_mmap failed (see log for details)"); + } } + this->mapped = true; - return true; + this->pinned = pinned; + HEX_VERBOSE("ggml-hex: %s mapped buffer: base %p size %zu fd %d pinned %u\n", + sess->c_name(), (void *) this->base, this->size, this->fd, pinned); } - void munmap() { - if (!this->mapped) { - return; - } + void unmap() { + if (!this->mapped) return; + + htp_iface_munmap(sess->handle, this->fd); + fastrpc_munmap(sess->domain_id, this->fd, (void *) this->base, this->size); + + HEX_VERBOSE("ggml-hex: %s unmapped buffer: base %p size %zu fd %d\n", sess->c_name(), + (void *) this->base, size, this->fd); - fastrpc_munmap(this->sess->domain_id, this->fd, this->base, this->size); this->mapped = false; + this->fd = -1; } - ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) { - size += 4 * 1024; // extra page for padding + void alloc(size_t size, bool pinned = false) { + if (this->base) return; - this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); + this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, size); if (!this->base) { - GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size); + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->c_name(), size); throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)"); } this->fd = rpcmem_to_fd(this->base); if (this->fd < 0) { - GGML_LOG_ERROR("ggml-hex: %s failed to get FD for buffer %p\n", sess->name.c_str(), (void *) this->base); - rpcmem_free(this->base); - this->base = NULL; + GGML_LOG_ERROR("ggml-hex: %s failed to get FD for buffer %p\n", sess->c_name(), (void *) this->base); throw std::runtime_error("ggml-hex: rpcmem_to_fd failed (see log for details)"); } + this->size = size; + + HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d pinned %d\n", sess->c_name(), + (void *) this->base, this->size, this->fd, (int) pinned); + + mmap(pinned); + } + + void free() { + if (!this->base) return; - HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d repack %d\n", sess->name.c_str(), - (void *) this->base, size, this->fd, (int) repack); + unmap(); + rpcmem_free(this->base); + + HEX_VERBOSE("ggml-hex: %s freed buffer: base %p size %zu fd %d\n", sess->c_name(), + (void *) this->base, size, this->fd); + + this->base = NULL; + } + + ggml_hexagon_shared_buffer(ggml_hexagon_session * sess, size_t size, bool pinned = false) { + size += 4 * 1024; // extra page for padding this->sess = sess; - this->size = size; + this->size = 0; + this->base = nullptr; + this->fd = -1; this->mapped = false; - this->repack = repack; - } - ~ggml_backend_hexagon_buffer_context() { - munmap(); - if (this->base) { - rpcmem_free(this->base); - this->base = NULL; - } + alloc(size, pinned); } - ggml_hexagon_session * sess; // primary session - uint8_t * base; - size_t size; - int fd; - bool mapped; // mmap is done - bool repack; // repacked buffer + ~ggml_hexagon_shared_buffer() { + free(); + } }; static ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_buffer_t buffer) { @@ -310,30 +265,26 @@ static ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_ } static void ggml_backend_hexagon_buffer_free_buffer(ggml_backend_buffer_t buffer) { - auto ctx = static_cast(buffer->context); - delete ctx; + auto sbuf = static_cast(buffer->context); + delete sbuf; } static void * ggml_backend_hexagon_buffer_get_base(ggml_backend_buffer_t buffer) { - auto ctx = static_cast(buffer->context); - return ctx->base; + auto sbuf = static_cast(buffer->context); + return sbuf->base; } static enum ggml_status ggml_backend_hexagon_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { - auto ctx = static_cast(buffer->context); - auto sess = ctx->sess; + auto sbuf = static_cast(buffer->context); + auto sess = sbuf->sess; - HEX_VERBOSE("ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d repack %d\n", sess->name.c_str(), - tensor->name, (void *) ctx->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage, - (int) ctx->repack); + HEX_VERBOSE("ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d\n", sess->c_name(), + tensor->name, (void *) sbuf->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage); if (tensor->view_src != NULL && tensor->view_offs == 0) { - ; // nothing to do for the view - } else { - if (!ctx->mapped) { - ctx->mmap(); - } + return GGML_STATUS_SUCCESS; // nothing to do for the view } + return GGML_STATUS_SUCCESS; } @@ -1387,11 +1338,10 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, const void * data, size_t offset, size_t size) { - auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; - auto sess = ctx->sess; + auto sbuf = (ggml_hexagon_shared_buffer *) buffer->context; + auto sess = sbuf->sess; - HEX_VERBOSE("ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data, - offset, size); + HEX_VERBOSE("ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\n", sess->c_name(), tensor->name, data, offset, size); switch (tensor->type) { case GGML_TYPE_Q4_0: @@ -1430,11 +1380,10 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, void * data, size_t offset, size_t size) { - auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; - auto sess = ctx->sess; + auto sbuf = (ggml_hexagon_shared_buffer *) buffer->context; + auto sess = sbuf->sess; - HEX_VERBOSE("ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data, - offset, size); + HEX_VERBOSE("ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\n", sess->c_name(), tensor->name, data, offset, size); switch (tensor->type) { case GGML_TYPE_Q4_0: @@ -1478,10 +1427,10 @@ static bool ggml_backend_hexagon_buffer_cpy_tensor(ggml_backend_buffer_t bu } static void ggml_backend_hexagon_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; - auto sess = ctx->sess; - HEX_VERBOSE("ggml-hex: %s clear-buff base %p size %zu\n", sess->name.c_str(), (void *) ctx->base, ctx->size); - memset(ctx->base, value, ctx->size); + auto sbuf = (ggml_hexagon_shared_buffer *) buffer->context; + auto sess = sbuf->sess; + HEX_VERBOSE("ggml-hex: %s clear-buff base %p size %zu\n", sess->c_name(), (void *) sbuf->base, sbuf->size); + memset(sbuf->base, value, sbuf->size); } static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = { @@ -1508,10 +1457,10 @@ static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer( ggml_backend_buffer_type_t buffer_type, size_t size) { auto sess = static_cast(buffer_type->context)->sess; try { - ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, false /*repack*/); - return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size); + ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size); + return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size); } catch (const std::exception & exc) { - GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what()); + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context (host): %s\n", sess->c_name(), exc.what()); return nullptr; } } @@ -1520,10 +1469,10 @@ static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffe ggml_backend_buffer_type_t buffer_type, size_t size) { auto sess = static_cast(buffer_type->context)->sess; try { - ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, true /*repack*/); - return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size); + ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size); + return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size); } catch (const std::exception & exc) { - GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what()); + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context (repack): %s\n", sess->c_name(), exc.what()); return nullptr; } } @@ -1538,7 +1487,7 @@ static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffe } static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { - return 1 * 1024 * 1024 * 1024; // 1GB per buffer + return 1UL * 1024 * 1024 * 1024; // 1GB per buffer GGML_UNUSED(buffer_type); } @@ -1570,6 +1519,373 @@ static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interf /* .is_host = */ ggml_backend_hexagon_repack_buffer_type_is_host, }; +// Backend session implementation + +struct ggml_hexagon_opshm { + ggml_hexagon_shared_buffer *sbuf; + + std::vector block_mask; + size_t block_size; + + uint8_t * base() const { return this->sbuf->base; } + int fd() const { return this->sbuf->fd; } + size_t n_blocks() const { return this->block_mask.size(); } + + ggml_hexagon_opshm(ggml_hexagon_session *sess, size_t max_batch, size_t max_pending) { + size_t n_bufs = HTP_OP_MAX_BUFS; + size_t n_ops = max_batch; + size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS; + + block_mask.resize(max_pending, true); + + block_size = sizeof(htp_buf_desc) * n_bufs + + sizeof(htp_tensor) * n_tensors + + sizeof(htp_op_desc) * n_ops; + + sbuf = new ggml_hexagon_shared_buffer(sess, block_size * block_mask.size(), true /* pinned */); + + if (opt_verbose) { + GGML_LOG_INFO("ggml-hex: %s allocated shared buf %zu : block-size %zu max-batch %zu max-pending %zu\n", + sess->c_name(), (size_t) sbuf->size, block_size, max_batch, max_pending); + } + } + + ~ggml_hexagon_opshm() { + delete sbuf; + } + + uint8_t * allocate() { + auto it = std::find(block_mask.begin(), block_mask.end(), true); + if (it == block_mask.end()) + return nullptr; + + unsigned int i = std::distance(block_mask.begin(), it); + uint8_t* addr = sbuf->base + (i * block_size); + block_mask[i] = false; + + HEX_VERBOSE("ggml-hex: %s allocated op shm #%u %p\n", sbuf->sess->c_name(), i, (void*) addr); + return addr; + } + + void release(uint8_t * addr) { + int i = (addr - sbuf->base) / block_size; + block_mask[i] = true; + HEX_VERBOSE("ggml-hex: %s released op shm #%u %p\n", sbuf->sess->c_name(), i, (void*) addr); + } +}; + +struct ggml_hexagon_opbatch { + const char* name; + + std::vector buffers; + std::vector tensors; + std::vector ops; + + std::unordered_map b_map; // buffer fd to index + std::unordered_map t_map; // tensor ptr to index + std::unordered_multimap d_map; // tensor data to index + + unsigned int n_bufs; // num buffers in the batch + unsigned int n_tens; // num tensors ... + unsigned int n_ops; // num ops ... + size_t b_vmem; // sum of all buffer sizes + + unsigned int n_bufs_max; + unsigned int n_tens_max; + unsigned int n_ops_max; + size_t b_vmem_max; + + void reset() { + n_bufs = 0; + n_tens = 0; + n_ops = 0; + b_vmem = 0; + + b_map.clear(); + t_map.clear(); + d_map.clear(); + } + + ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t max_batch) { + name = sess->c_name(); + + n_bufs_max = HTP_OP_MAX_BUFS; + n_ops_max = max_batch; + n_tens_max = n_ops_max + n_ops_max * HTP_OP_MAX_INPUTS; + + b_vmem_max = HTP_OP_MAX_VMEM; + + buffers.resize(n_bufs_max); + tensors.resize(n_tens_max); + ops.resize(n_ops_max); + + b_map.reserve(n_bufs_max); + t_map.reserve(n_tens_max); + d_map.reserve(n_tens_max); + + reset(); + } + + bool empty() const { return n_ops == 0; } + + // add buffer and return its index + int add_buffer(ggml_hexagon_shared_buffer * sbuf) { + // Lookup by fd + auto it = b_map.find(sbuf->fd); + if (it != b_map.end()) { return it->second; } + + // Add new buffer to the batch + int bi = n_bufs++; + GGML_ASSERT(n_bufs < HTP_OP_MAX_BUFS); + + b_map.insert({sbuf->fd, bi}); + + htp_buf_desc &b = buffers[bi]; + b.base = (uint64_t) sbuf->base; + b.fd = sbuf->fd; + b.size = sbuf->size; + + b_vmem += b.size; + + HEX_VERBOSE("ggml-hex: add-buffer #%u : fd %d base %p size %zu : vmem %zu\n", bi, b.fd, (void*) sbuf->base, (size_t) b.size, b_vmem); + + return bi; + } + + bool same_shape(const htp_tensor * h, const ggml_tensor * t) const { + return (h->ne[0] == t->ne[0]) && (h->ne[1] == t->ne[1]) && (h->ne[2] == t->ne[2]) && (h->ne[3] == t->ne[3]) && + (h->nb[0] == t->nb[0]) && (h->nb[1] == t->nb[1]) && (h->nb[2] == t->nb[2]) && (h->nb[3] == t->nb[3]); + } + + // add tensor and return its index + int add_tensor(const ggml_tensor * t) { + auto sbuf = static_cast(t->buffer->context); + + // First lookup by tensor data + auto range = d_map.equal_range(t->data); + for (auto it = range.first; it != range.second; ++it) { + htp_tensor * h = &tensors[it->second]; + if (same_shape(h, t)) { return it->second; } + } + + // Lookup by tensor ptr + auto it = t_map.find(t); + if (it != t_map.end()) { return it->second; } + + // Add new tensor to the batch + int ti = n_tens++; + GGML_ASSERT(n_tens <= n_tens_max); + + t_map.insert({t, ti}); + d_map.insert({t->data, ti}); + + uint64_t t_offset = (uint8_t *) t->data - sbuf->base; + size_t t_size = ggml_nbytes(t); + + htp_tensor &h = tensors[ti]; + h.bi = add_buffer(sbuf); + h.data = t_offset; + h.size = t_size; + h.type = t->type; + h.ne[0] = t->ne[0]; h.ne[1] = t->ne[1]; h.ne[2] = t->ne[2]; h.ne[3] = t->ne[3]; + h.nb[0] = t->nb[0]; h.nb[1] = t->nb[1]; h.nb[2] = t->nb[2]; h.nb[3] = t->nb[3]; + + h.flags = 0; + if (ggml_backend_buffer_get_usage(t->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { + h.flags |= HTP_TENSOR_COMPUTE; + } + + HEX_VERBOSE("ggml-hex: add-tensor #%u %s : bi %d data %p offset %zu size %zu flags 0x%x : %zu:%zu:%zu:%zu\n", + ti, t->name, h.bi, (void*) t->data, (size_t) t_offset, t_size, h.flags, + (size_t) t->ne[0], (size_t) t->ne[1], (size_t) t->ne[2], (size_t) t->ne[3]); + + return ti; + } + + bool fit_op(const struct ggml_tensor *t) const { + if (n_ops >= n_ops_max ) return false; + + // check how much extras we will need + size_t extra_bufs = 0; + size_t extra_vmem = 0; + size_t extra_tens = 0; + + auto fit_tensor = [&](const ggml_tensor *t) { + if (!t_map.count(t)) { + extra_tens++; + + auto sbuf = static_cast(t->buffer->context); + if (!b_map.count(sbuf->fd)) { + extra_vmem += sbuf->size; + extra_bufs += 1; + } + } + }; + + for (unsigned int i=0; i < HTP_OP_MAX_INPUTS && t->src[i]; i++) { + fit_tensor(t->src[i]); + } + fit_tensor(t); + + if ((extra_bufs + n_bufs) > n_bufs_max) return false; + if ((extra_tens + n_tens) > n_tens_max) return false; + if ((extra_vmem + b_vmem) > b_vmem_max) return false; + + return true; + } + + // assumes that fit_op() was called first and returned true + void add_op(htp_op_code opcode, const struct ggml_tensor * t) { + // Add new op + htp_op_desc &o = ops[n_ops++]; + GGML_ASSERT(n_ops <= n_ops_max); + + memcpy(&o.params, &t->op_params, sizeof(t->op_params)); + o.opcode = opcode; + o.flags = 0; + + if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { + o.flags |= HTP_OPFLAGS_SKIP_COMPUTE; + } + + ggml_hexagon_dump_op_exec(name, t, o.flags); + + for (unsigned int i=0; i < HTP_OP_MAX_INPUTS; i++) { + o.src[i] = t->src[i] ? add_tensor(t->src[i]) : 0xffff; + } + o.dst = add_tensor(t); + } + + size_t flush(uint8_t * mem_addr, size_t mem_size) { + static_assert(sizeof(htp_buf_desc) % 8 == 0, "sizeof(htp_buf_desc) must be multiple of 8"); + static_assert(sizeof(htp_tensor) % 8 == 0, "sizeof(htp_tensor) must be multiple of 8"); + static_assert(sizeof(htp_op_desc) % 8 == 0, "sizeof(htp_op_desc) must be multiple of 8"); + + const size_t b_size = sizeof(htp_buf_desc) * n_bufs; + const size_t t_size = sizeof(htp_tensor) * n_tens; + const size_t o_size = sizeof(htp_op_desc) * n_ops; + + const size_t m_size = b_size + t_size + o_size; + GGML_ASSERT(m_size <= mem_size); + + uint8_t * b_ptr = (uint8_t *) mem_addr; + uint8_t * t_ptr = (uint8_t *) b_ptr + b_size; + uint8_t * o_ptr = (uint8_t *) t_ptr + t_size; + + memcpy(b_ptr, (void *) buffers.data(), b_size); + memcpy(t_ptr, (void *) tensors.data(), t_size); + memcpy(o_ptr, (void *) ops.data(), o_size); + + HEX_VERBOSE("ggml-hex: %s flush-opbatch : n-bufs %u n-tensors %u n-ops %u vmem %zu : b-size %zu t-size %zu o-size %zu\n", + name, n_bufs, n_tens, n_ops, b_vmem, b_size, t_size, o_size); + + if (opt_verbose > 1) { + htp_buf_desc *b = (htp_buf_desc*) b_ptr; + for (unsigned int i=0; i < n_bufs; i++) { + GGML_LOG_DEBUG("ggml-hex: %s htp-buf #%u : fd %d base %p size %zu\n", name, i, + b[i].fd, (void *) b[i].base, (size_t) b[i].size); + } + htp_tensor *t = (htp_tensor*) t_ptr; + for (unsigned int i=0; i < n_tens; i++) { + GGML_LOG_DEBUG("ggml-hex: %s htp-tensor #%u : bi %u offset %u size %u : %zu:%zu:%zu:%zu\n", + name, i, t[i].bi, t[i].data, t[i].size, + (size_t) t[i].ne[0], (size_t) t[i].ne[1], (size_t) t[i].ne[2], (size_t) t[i].ne[3]); + } + } + + reset(); + + return m_size; + } +}; + +// Flush HTP response queue i.e wait for all outstanding requests to complete +void ggml_hexagon_session::flush_pending(bool all) { + while (this->op_pending) { + struct htp_opbatch_rsp rsp; + uint32_t rsp_size; + uint32_t flags; + + struct dspqueue_buffer dbuf; + uint32_t n_dbufs; + + // Read response packet from queue + int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, DSPQUEUE_TIMEOUT); + if (err == AEE_EEXPIRED) { + continue; + } + + if (err != 0) { + GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err); + } + + // Basic sanity checks + if (rsp_size != sizeof(rsp) || n_dbufs != 1) { + GGML_ABORT("ggml-hex: %s dspcall : bad response : size %u dspbufs %u\n", this->c_name(), rsp_size, n_dbufs); + } + + op_shm->release((uint8_t*) dbuf.ptr); + + if (rsp.status != HTP_STATUS_OK) { + GGML_LOG_ERROR("ggml-hex: %s dspcall : dsp-rsp: %s\n", this->c_name(), status_to_str(rsp.status)); + // TODO: handle errors + } + + // FIXME: profile will be per opreq + // this->prof_usecs = rsp.prof_usecs; + // this->prof_cycles = rsp.prof_cycles; + // this->prof_pkts = rsp.prof_pkts; + + this->op_pending--; // atomic dec + + if (!all) break; + } +} + +void ggml_hexagon_session::flush_batch() { + if (op_batch->empty()) { return; } + + htp_opbatch_req req; + req.n_bufs = op_batch->n_bufs; + req.n_tensors = op_batch->n_tens; + req.n_ops = op_batch->n_ops; + + dspqueue_buffer dbuf; + dbuf.fd = op_shm->fd(); + dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; + dbuf.ptr = op_shm->allocate(); + if (!dbuf.ptr) { + flush_pending(false); + dbuf.ptr = op_shm->allocate(); + } + + dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) op_shm->base(); + dbuf.size = op_batch->flush((uint8_t*) dbuf.ptr, op_shm->block_size); + + // Bump pending flag (cleared in the session::flush once we get the response) + this->op_pending++; // atomic inc + + HEX_VERBOSE("ggml-hex: %s: queue-opbatch : %p size %u\n", this->c_name(), dbuf.ptr, dbuf.size); + + int err = dspqueue_write(this->queue, 0, 1, &dbuf, sizeof(req), (const uint8_t*) &req, DSPQUEUE_TIMEOUT); + if (err != 0) { + GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->c_name(), (unsigned) err); + } +} + +void ggml_hexagon_session::enqueue_op(htp_op_code opcode, const ggml_tensor *op) { + if (!op_batch->fit_op(op)) { + flush_batch(); + } + op_batch->add_op(opcode, op); +} + +// Flush HTP response queue i.e wait for all outstanding requests to complete +void ggml_hexagon_session::flush(bool all) { + flush_batch(); + flush_pending(all); +} + void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->valid_session = false; this->valid_handle = false; @@ -1582,9 +1898,6 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->name = std::string("HTP") + std::to_string(dev_id); this->op_pending = 0; - this->prof_usecs = 0; - this->prof_cycles = 0; - this->prof_pkts = 0; GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str()); @@ -1676,11 +1989,14 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } } + const size_t req_q_size = (sizeof(htp_opbatch_req) * opt_opqueue * 2) + 1024; + const size_t rsp_q_size = (sizeof(htp_opbatch_rsp) * opt_opqueue * 2) + 1024; + // Now let's setup the DSP queue err = dspqueue_create(this->domain_id, 0, // Flags - 128 * 1024, // Request queue size (in bytes) - 64 * 1024, // Response queue size (in bytes) + req_q_size, // Request queue size (in bytes) + rsp_q_size, // Response queue size (in bytes) nullptr, // Read packet callback (we handle reads explicitly) nullptr, // Error callback (we handle errors during reads) (void *) this, // Callback context @@ -1715,6 +2031,10 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); } this->valid_iface = true; + + // Allocate buffers and state for op batching + this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch); + this->op_shm = new ggml_hexagon_opshm(this, opt_opbatch, opt_opqueue); } void ggml_hexagon_session::release() noexcept(true) { @@ -1722,6 +2042,9 @@ void ggml_hexagon_session::release() noexcept(true) { int err; + delete this->op_batch; + delete this->op_shm; + // Stop the DSP-side service and close the queue if (this->valid_iface) { err = htp_iface_stop(this->handle); @@ -1753,6 +2076,9 @@ ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) n buffer_type.device = dev; repack_buffer_type.device = dev; + op_batch = nullptr; + op_shm = nullptr; + try { allocate(dev_id); @@ -1815,9 +2141,13 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess return false; } - return opt_experimental; -} + if (dst->ne[2] != 1 || dst->ne[3] != 1) { + // FA during prompt still needs work + return false; + } + return true; +} static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; @@ -2082,6 +2412,23 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s } } + // Reject non-HVX-aligned sizes when ne[0] > HVX_F32_LANES + // The HVX softmax implementation has issues with tail handling for larger non-aligned sizes + // Small sizes (ne[0] <= 32) work correctly with tail-only processing + const int64_t ne0 = src0->ne[0]; + if (ne0 > 32 && (ne0 & (32 - 1)) != 0) { + return false; + } + + // HVX vector size constraints for softmax + #define SOFTMAX_MAX_ROW_SIZE 131072 // 128K elements max for numerical precision + + // Reject very large row sizes to avoid numerical precision issues + // Softmax accumulation over many elements can lead to precision loss + if (ne0 > SOFTMAX_MAX_ROW_SIZE) { + return false; + } + return true; } @@ -2249,571 +2596,85 @@ static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * se return true; } -enum dspqbuf_type { - DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0, - DSPQBUF_TYPE_CPU_WRITE_DSP_READ, - DSPQBUF_TYPE_CONSTANT, -}; - -static void dspqbuf_dump(dspqueue_buffer * d, const struct ggml_tensor * t, dspqbuf_type type) { - if (opt_verbose < 2) return; - - auto buf = static_cast(t->buffer->context); - auto sess = buf->sess; - - GGML_LOG_DEBUG("ggml-hex: %s dspqbuf : %s base-addr %p base-size %zu data %p offset %u size %u\n", sess->name.c_str(), - t->name, (void *) buf->base, buf->size, (void *) d->ptr, (unsigned int) d->offset, - (unsigned int) d->size); -} - -// Init hexagon tensor from GGML tensor and Hexagon buffer -static void htp_req_tensor_init(htp_tensor * h, const ggml_tensor * t) { - h->data = 0; // updated by the receiver - h->type = t->type; - h->ne[0] = t->ne[0]; - h->ne[1] = t->ne[1]; - h->ne[2] = t->ne[2]; - h->ne[3] = t->ne[3]; - h->nb[0] = t->nb[0]; - h->nb[1] = t->nb[1]; - h->nb[2] = t->nb[2]; - h->nb[3] = t->nb[3]; -} - -static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_tensor * t, dspqbuf_type type) { - if (!t) { - return 0; - } - - auto buf = static_cast(t->buffer->context); - - memset(d, 0, sizeof(*d)); - d->fd = buf->fd; - d->ptr = t->data; - d->offset = (uint8_t *) t->data - buf->base; - d->size = ggml_nbytes(t); - - if (!d->size) { - // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty - d->size = 64; - } - - switch (type) { - case DSPQBUF_TYPE_DSP_WRITE_CPU_READ: - // Flush CPU - d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER; - break; - case DSPQBUF_TYPE_CPU_WRITE_DSP_READ: - // Flush CPU, Invalidate DSP - d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; - break; - default: - // Constant buffer, no cache maintenance - d->flags = 0; - break; - } - - htp_req_tensor_init(h, t); - - dspqbuf_dump(d, t, type); - - return 1; -} - -typedef size_t (*htp_req_init_func_t)(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * op); - -template -static inline void ggml_hexagon_dispatch_op(ggml_hexagon_session *sess, const struct ggml_tensor * op, uint32_t flags) { - uint64_t t = ggml_time_us(); - - // Construct HTP request - htp_general_req req; - memset(&req, 0, sizeof(req)); - - req.flags = flags; - if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { - req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; - } - if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { - req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; - } - - ggml_hexagon_dump_op_exec(sess->name, op, req.flags); - - if ((opt_opmask & HTP_OPMASK_QUEUE)) { - dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; - size_t n_bufs = _init_req_func(&req, bufs, op); - sess->enqueue(req, bufs, n_bufs, opt_opsync); - } - - t = ggml_time_us() - t; - - ggml_hexagon_dump_op_prof(sess->name, op, sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, t); -} - -template -static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - switch (t->op) { - case GGML_OP_MUL_MAT: - req->op = HTP_OP_MUL_MAT; - break; - case GGML_OP_MUL: - req->op = HTP_OP_MUL; - break; - case GGML_OP_ADD: - req->op = HTP_OP_ADD; - break; - case GGML_OP_SUB: - req->op = HTP_OP_SUB; - break; - case GGML_OP_DIV: - req->op = HTP_OP_DIV; - break; - default: - GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op); - break; - } - - // src0: Weights (mulmat) or First Operand (binary op). - // If constant (e.g. weights), no cache management is needed. - // src1: Input Activations (mulmat) or Second Operand (binary op). - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_cpy_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_CPY; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_cont_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - // CONT is just a contiguous copy — reuse CPY op - req->op = HTP_OP_CPY; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_repeat_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_REPEAT; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_cumsum_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_CUMSUM; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_GET_ROWS; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_ARGSORT; - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -template -static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - switch (t->op) { - case GGML_OP_MUL_MAT_ID: - req->op = HTP_OP_MUL_MAT_ID; - break; - case GGML_OP_ADD_ID: - req->op = HTP_OP_ADD_ID; - break; - default: - GGML_ABORT("ggml-hex: unsupported op: %d\n", t->op); - } - - // src0: Weights (mulmat) or Input Activations (other op). - // If constant, no cache management is needed. - // src1: Input Activations (mulmat) or Second Operand (binary op). - // src2: Expert IDs (mulmat) or Activated Experts (other op). - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; +static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { + auto sess = static_cast(backend->context); + return sess->c_name(); } -static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_SET_ROWS; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; +static void ggml_backend_hexagon_free(ggml_backend_t backend) { + // we just need to delete the backend here + // the sessions are allocated & freed as part of the registry + delete backend; } -static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - - bool supported = false; - +static htp_op_code op_remap_to_htp(const ggml_tensor * t) { switch (t->op) { - case GGML_OP_RMS_NORM: - req->op = HTP_OP_RMS_NORM; - supported = true; - break; - - case GGML_OP_SCALE: - req->op = HTP_OP_SCALE; - supported = true; - break; - - case GGML_OP_SQR: - req->op = HTP_OP_SQR; - supported = true; - break; - - case GGML_OP_SQRT: - req->op = HTP_OP_SQRT; - supported = true; - break; + case GGML_OP_FLASH_ATTN_EXT: return HTP_OP_FLASH_ATTN_EXT; + case GGML_OP_MUL_MAT: return HTP_OP_MUL_MAT; + case GGML_OP_MUL_MAT_ID: return HTP_OP_MUL_MAT_ID; + case GGML_OP_MUL: return HTP_OP_MUL; + case GGML_OP_ADD: return HTP_OP_ADD; + case GGML_OP_ADD_ID: return HTP_OP_ADD_ID; + case GGML_OP_SUB: return HTP_OP_SUB; + case GGML_OP_DIV: return HTP_OP_DIV; + case GGML_OP_CPY: return HTP_OP_CPY; + case GGML_OP_CONT: return HTP_OP_CPY; + case GGML_OP_GET_ROWS: return HTP_OP_GET_ROWS; + case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; + case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; + case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; + case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; + case GGML_OP_SCALE: return HTP_OP_SCALE; + case GGML_OP_SQR: return HTP_OP_SQR; + case GGML_OP_SQRT: return HTP_OP_SQRT; + case GGML_OP_SOFT_MAX: return HTP_OP_SOFTMAX; + case GGML_OP_SSM_CONV: return HTP_OP_SSM_CONV; + case GGML_OP_ROPE: return HTP_OP_ROPE; + case GGML_OP_REPEAT: return HTP_OP_REPEAT; + case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { - case GGML_UNARY_OP_SILU: - req->op = HTP_OP_UNARY_SILU; - supported = true; - break; - case GGML_UNARY_OP_GELU: - req->op = HTP_OP_UNARY_GELU; - supported = true; - break; - case GGML_UNARY_OP_SIGMOID: - req->op = HTP_OP_UNARY_SIGMOID; - supported = true; - break; - case GGML_UNARY_OP_NEG: - req->op = HTP_OP_UNARY_NEG; - supported = true; - break; - case GGML_UNARY_OP_EXP: - req->op = HTP_OP_UNARY_EXP; - supported = true; - break; - case GGML_UNARY_OP_SOFTPLUS: - req->op = HTP_OP_UNARY_SOFTPLUS; - supported = true; - break; + case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; + case GGML_UNARY_OP_GELU: return HTP_OP_UNARY_GELU; + case GGML_UNARY_OP_SIGMOID: return HTP_OP_UNARY_SIGMOID; + case GGML_UNARY_OP_NEG: return HTP_OP_UNARY_NEG; + case GGML_UNARY_OP_EXP: return HTP_OP_UNARY_EXP; + case GGML_UNARY_OP_SOFTPLUS: return HTP_OP_UNARY_SOFTPLUS; default: break; } break; case GGML_OP_GLU: - if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU) { - req->op = HTP_OP_GLU_SWIGLU; - supported = true; - } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) { - req->op = HTP_OP_GLU_SWIGLU_OAI; - supported = true; - } else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) { - req->op = HTP_OP_GLU_GEGLU; - supported = true; + switch (ggml_get_glu_op(t)) { + case GGML_GLU_OP_SWIGLU: return HTP_OP_GLU_SWIGLU; + case GGML_GLU_OP_SWIGLU_OAI: return HTP_OP_GLU_SWIGLU_OAI; + case GGML_GLU_OP_GEGLU: return HTP_OP_GLU_GEGLU; + default: break; } break; - case GGML_OP_SOFT_MAX: - req->op = HTP_OP_SOFTMAX; - supported = true; - break; - default: - break; - } - - if (!supported) { - GGML_ABORT("ggml-hex: unary : unsupported op: %d\n", t->op); + GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(t)); } - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - req->op = HTP_OP_SUM_ROWS; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - req->op = HTP_OP_ROPE; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; + return HTP_OP_INVALID; } -static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - req->op = HTP_OP_FLASH_ATTN_EXT; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static inline size_t init_ssm_conv_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_SSM_CONV; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CONSTANT); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} - -static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { - auto sess = static_cast(backend->context); - return sess->name.c_str(); -} - -static void ggml_backend_hexagon_free(ggml_backend_t backend) { - // we just need to delete the backend here - // the sessions are allocated & freed as part of the registry - delete backend; -} - -// Map weight type to its activation quantization family. -// Types in the same family produce identical Q8 formats in VTCM and can -// safely share quantized activation data via SKIP_QUANTIZE. -// When adding a new quantized type, assign it the correct family here. -static inline int act_quant_family(enum ggml_type wtype) { - switch (wtype) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q8_0: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_MXFP4: - return 1; // Q8x4x2 - default: - return 0; // unknown / not quantized - } -} - -static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) { - return (op0 && op0->src[1] == op1->src[1] && - act_quant_family(op0->src[0]->type) == act_quant_family(op1->src[0]->type) && - act_quant_family(op0->src[0]->type) != 0); -} - -static inline bool is_compute_op(ggml_tensor *node) +static inline bool op_is_compute(ggml_tensor *node) { return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE); } -// scan the graph and figure out last compute op index -static inline int last_compute_op(ggml_cgraph * graph) { - int last = 0; - for (int i = 0; i < graph->n_nodes; ++i) { - if (is_compute_op(graph->nodes[i])) { - last = i; - } - } - - return last; -} - static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) { auto sess = static_cast(backend->context); - HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->name.c_str(), graph->n_nodes); - - const int last = last_compute_op(graph); - - const struct ggml_tensor * prev_op = nullptr; // prev executed op + HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->c_name(), graph->n_nodes); for (int i = 0; i < graph->n_nodes; ++i) { - ggml_tensor * node = graph->nodes[i]; - - if (!is_compute_op(node)) { - continue; - } - - uint32_t flags = 0; - - // skip quantizer if src1 is reused - if (op_reuse_src1(node, prev_op)) { - flags |= HTP_OPFLAGS_SKIP_QUANTIZE; - } - - prev_op = node; - - // ask for early notification for the last Op - if (i == last) { - flags |= HTP_OPFLAGS_EARLY_WAKEUP; - } - - switch (node->op) { - case GGML_OP_MUL_MAT: - if (ggml_is_quantized(node->src[0]->type)) { - ggml_hexagon_dispatch_op>(sess, node, flags); - } else { - ggml_hexagon_dispatch_op>(sess, node, flags); - } - break; - case GGML_OP_MUL_MAT_ID: - if (ggml_is_quantized(node->src[0]->type)) { - ggml_hexagon_dispatch_op>(sess, node, flags); - } else { - ggml_hexagon_dispatch_op>(sess, node, flags); - } - break; - case GGML_OP_MUL: - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_DIV: - ggml_hexagon_dispatch_op>(sess, node, flags); - break; - case GGML_OP_ADD_ID: - ggml_hexagon_dispatch_op>(sess, node, flags); - break; - case GGML_OP_RMS_NORM: - case GGML_OP_SCALE: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - case GGML_OP_SQR: - case GGML_OP_SQRT: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - case GGML_OP_SUM_ROWS: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - case GGML_OP_UNARY: - switch (ggml_get_unary_op(node)) { - case GGML_UNARY_OP_NEG: - case GGML_UNARY_OP_EXP: - case GGML_UNARY_OP_SIGMOID: - case GGML_UNARY_OP_SOFTPLUS: - case GGML_UNARY_OP_SILU: - case GGML_UNARY_OP_GELU: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - default: - break; - } - break; - case GGML_OP_GLU: - switch (ggml_get_glu_op(node)) { - case GGML_GLU_OP_SWIGLU: - case GGML_GLU_OP_SWIGLU_OAI: - case GGML_GLU_OP_GEGLU: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - default: - break; - } - break; - case GGML_OP_SOFT_MAX: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_ROPE: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_FLASH_ATTN_EXT: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_SET_ROWS: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_GET_ROWS: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_CPY: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_CONT: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_REPEAT: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_ARGSORT: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_SSM_CONV: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - case GGML_OP_CUMSUM: - ggml_hexagon_dispatch_op(sess, node, flags); - break; - - default: - GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); + ggml_tensor * n = graph->nodes[i]; + if (op_is_compute(n)) { + sess->enqueue_op(op_remap_to_htp(n), n); } } @@ -2826,7 +2687,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) { auto sess = static_cast(backend->context); - HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->name.c_str()); + HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->c_name()); // Wait until all pending ops complete sess->flush(); @@ -3045,7 +2906,7 @@ static ggml_backend_t ggml_backend_hexagon_device_init(ggml_backend_dev_t dev, c static const char * ggml_backend_hexagon_device_get_name(ggml_backend_dev_t dev) { auto sess = static_cast(dev->context); - return sess->name.c_str(); + return sess->c_name(); GGML_UNUSED(dev); } @@ -3056,8 +2917,7 @@ static const char * ggml_backend_hexagon_device_get_description(ggml_backend_dev } static void ggml_backend_hexagon_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - // ~2GB per session for now - *free = 2ULL * 1024 * 1024 * 1024; + *free = 0; *total = *free; GGML_UNUSED(dev); @@ -3172,6 +3032,11 @@ static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * se static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { auto sess = static_cast(dev->context); + // reject ops that match the filter + if (opt_opfilter && std::regex_match(ggml_op_desc(op), *opt_opfilter)) { + return false; + } + // all srcs & dsts must be mapped to the same session if (!ggml_hexagon_supported_buffers(sess, op)) { ggml_hexagon_dump_op_supp(sess->name, op, false); @@ -3188,6 +3053,13 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = true; break; + case GGML_OP_MUL: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_DIV: + supp = ggml_hexagon_supported_binary(sess, op); + break; + case GGML_OP_MUL_MAT: supp = ggml_hexagon_supported_mul_mat(sess, op); break; @@ -3196,13 +3068,6 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_mul_mat_id(sess, op); break; - case GGML_OP_MUL: - case GGML_OP_ADD: - case GGML_OP_SUB: - case GGML_OP_DIV: - supp = ggml_hexagon_supported_binary(sess, op); - break; - case GGML_OP_ADD_ID: supp = ggml_hexagon_supported_add_id(sess, op); break; @@ -3241,6 +3106,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; } break; + case GGML_OP_GLU: switch (ggml_get_glu_op(op)) { case GGML_GLU_OP_SWIGLU: @@ -3252,6 +3118,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; } break; + case GGML_OP_ROPE: supp = ggml_hexagon_supported_rope(sess, op); break; @@ -3438,11 +3305,13 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL, "please update hexagon_type to match ggml_type"); - const char * str_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL"); const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); const char * str_opmask = getenv("GGML_HEXAGON_OPMASK"); const char * str_opsync = getenv("GGML_HEXAGON_OPSYNC"); + const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); + const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); + const char * str_opfilter= getenv("GGML_HEXAGON_OPFILTER"); const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); const char * str_etm = getenv("GGML_HEXAGON_ETM"); const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); @@ -3450,16 +3319,21 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); const char * str_arch = getenv("GGML_HEXAGON_ARCH"); - opt_experimental = str_experimental ? atoi(str_experimental) : 0; + auto RE_ICASE = std::regex_constants::icase; + + opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; opt_verbose = str_verbose ? atoi(str_verbose) : 0; opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; - opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask; - opt_opsync = str_opsync ? atoi(str_opsync) : 0; + opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask; + opt_opsync = str_opsync ? atoi(str_opsync) : opt_opsync; + opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; + opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; opt_profile = str_profile ? atoi(str_profile) : 0; opt_etm = str_etm ? atoi(str_etm) : 0; opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { opt_ndev = GGML_HEXAGON_MAX_SESSIONS; @@ -3472,12 +3346,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { opt_arch = strtoul(str_arch, NULL, 0); } - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1; - reg->context = new ggml_hexagon_registry(reg); - - HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req), - sizeof(struct htp_general_rsp)); } static const struct ggml_backend_reg_i ggml_backend_hexagon_reg_i = { diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index d8b924981e0..6416d2dfbc3 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -14,59 +14,42 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" -#define htp_act_preamble3 \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne10 = src1->ne[0]; \ - const uint32_t ne11 = src1->ne[1]; \ - const uint32_t ne12 = src1->ne[2]; \ - const uint32_t ne13 = src1->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb10 = src1->nb[0]; \ - const uint32_t nb11 = src1->nb[1]; \ - const uint32_t nb12 = src1->nb[2]; \ - const uint32_t nb13 = src1->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ - const uint32_t nb3 = dst->nb[3]; - -#define htp_act_preamble2 \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#define htp_act_preamble \ + const struct htp_tensor * src0 = actx->octx->src[0]; \ + const struct htp_tensor * src1 = actx->octx->src[1]; \ + const struct htp_tensor * dst = actx->octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne10 = src1 ? src1->ne[0] : 0; \ + const uint32_t ne11 = src1 ? src1->ne[1] : 0; \ + const uint32_t ne12 = src1 ? src1->ne[2] : 0; \ + const uint32_t ne13 = src1 ? src1->ne[3] : 0; \ + \ + const uint32_t nb10 = src1 ? src1->nb[0] : 0; \ + const uint32_t nb11 = src1 ? src1->nb[1] : 0; \ + const uint32_t nb12 = src1 ? src1->nb[2] : 0; \ + const uint32_t nb13 = src1 ? src1->nb[3] : 0; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; struct htp_act_context { @@ -97,10 +80,7 @@ struct htp_act_context { static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * src1 = &actx->octx->src1; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble3; + htp_act_preamble; size_t src0_row_size = actx->src0_row_size; size_t src1_row_size = actx->src1_row_size; @@ -207,10 +187,7 @@ static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * src1 = &actx->octx->src1; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble3; + htp_act_preamble; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -332,9 +309,7 @@ static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, vo static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble2; + htp_act_preamble; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -433,9 +408,7 @@ static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble2; + htp_act_preamble; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -533,10 +506,7 @@ static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_act_context * actx = (struct htp_act_context *) data; - const struct htp_tensor * src0 = &actx->octx->src0; - const struct htp_tensor * src1 = &actx->octx->src1; - const struct htp_tensor * dst = &actx->octx->dst; - htp_act_preamble3; + htp_act_preamble; size_t src0_row_size = actx->src0_row_size; size_t src1_row_size = actx->src1_row_size; @@ -652,9 +622,9 @@ static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * } static int execute_op_activations_f32(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; if (((src0->ne[0] * SIZEOF_FP32) != src0->nb[1]) || ((dst->ne[0] * SIZEOF_FP32) != dst->nb[1])) { FARF(ERROR, "Non-contiguous tensors are not supported at this time \n"); @@ -697,25 +667,20 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); size_t src0_row_size = src0->nb[1]; - size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used + size_t src1_row_size = src1 ? src1->nb[1] : src0->nb[1]; size_t dst_row_size = dst->nb[1]; - const bool src1_valid = src1->ne[0]; - if (!src1_valid) { - src1_row_size = src0_row_size; - } - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + // VTCM scratchpads for all tensors // N rows per thread, padded to HVX vector size - size_t spad_size_per_row = (src0_row_size_aligned + src1_row_size_aligned) + dst_row_size_aligned; size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads* spad_size_per_row); // Make sure the reserved vtcm size is sufficient - if(vtcm_row_per_thread ==0){ + if (vtcm_row_per_thread == 0) { FARF(ERROR, "act-%s : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", op_type, octx->ctx->vtcm_size, spad_size_per_row * n_threads); return HTP_STATUS_VTCM_TOO_SMALL; @@ -733,7 +698,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; - if (src1->ne[0]) { + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + octx->dst_spad.src = NULL; + + if (src1) { FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, @@ -773,9 +742,9 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { // Pointers and GLU logic const uint8_t * data_src0 = (const uint8_t *) src0->data; - const uint8_t * data_src1 = (const uint8_t *) src1->data; + const uint8_t * data_src1 = src1 ? (const uint8_t *) src1->data : NULL; - if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) { + if (!src1 && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) { const int32_t swapped = octx->op_params[1]; data_src1 = data_src0; actx.src1_row_size = actx.src0_row_size; @@ -799,7 +768,7 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { int op_activations(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_activations_f32(octx); break; diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index 3ec26a4c1ac..bdd0623615d 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -12,7 +12,7 @@ #include "hex-dma.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #ifndef MIN @@ -175,8 +175,8 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = actx->octx; // Unpack context - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; // Scratchpad memory uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i; @@ -249,16 +249,16 @@ static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { int op_argsort(struct htp_ops_context * octx) { // Check supported types - if (octx->src0.type != HTP_TYPE_F32) { + if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; + const uint32_t total_rows = octx->src[0]->ne[1] * octx->src[0]->ne[2] * octx->src[0]->ne[3]; const uint32_t n_threads = MIN(total_rows, octx->n_threads); // Allocate scratchpad // We need 1 row of float + 1 row of int32 per thread. - uint32_t ne00 = octx->src0.ne[0]; + uint32_t ne00 = octx->src[0]->ne[0]; size_t values_size = hex_round_up(ne00 * sizeof(float), 128); size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128); size_t spad_per_thread = values_size + indices_size; @@ -278,9 +278,9 @@ int op_argsort(struct htp_ops_context * octx) { octx->src0_spad.size_per_thread = spad_per_thread; FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", - octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3], - octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3], - octx->src0.data, octx->dst.data); + octx->src[0]->ne[0], octx->src[0]->ne[1], octx->src[0]->ne[2], octx->src[0]->ne[3], + octx->dst->ne[0], octx->dst->ne[1], octx->dst->ne[2], octx->dst->ne[3], + octx->src[0]->data, octx->dst->data); struct htp_argsort_context actx; actx.octx = octx; diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c index 1b0f97493bc..52013ad0fec 100644 --- a/ggml/src/ggml-hexagon/htp/binary-ops.c +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -14,7 +14,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #ifndef MIN @@ -43,10 +43,10 @@ struct htp_binary_context { bool split_at_ne02; }; -#define htp_binary_preamble \ - const struct htp_tensor * src0 = &octx->src0; \ - const struct htp_tensor * src1 = &octx->src1; \ - struct htp_tensor * dst = &octx->dst; \ +#define htp_binary_preamble \ + const struct htp_tensor * src0 = octx->src[0]; \ + const struct htp_tensor * src1 = octx->src[1]; \ + const struct htp_tensor * dst = octx->dst; \ \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ @@ -181,7 +181,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; @@ -274,7 +274,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; @@ -374,7 +374,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; @@ -455,7 +455,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; @@ -540,7 +540,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); const uint32_t row_size_bytes = ne00 * elem_size_bytes;; const uint32_t total_rows = ne01 * ne02 * ne03; @@ -629,10 +629,10 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { struct htp_binary_context * bctx = (struct htp_binary_context *) data; struct htp_ops_context * octx = bctx->octx; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * src2 = &octx->src2; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * src2 = octx->src[2]; + const struct htp_tensor * dst = octx->dst; const uint32_t ne00 = src0->ne[0]; const uint32_t ne01 = src0->ne[1]; @@ -723,15 +723,15 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { } static int execute_op_binary(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); // Use packed row sizes for VTCM allocation - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); const size_t src0_row_size = src0->ne[0] * elem_size; const size_t src1_row_size = src1->ne[0] * elem_size; @@ -799,9 +799,9 @@ static int execute_op_binary(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { return HTP_STATUS_OK; @@ -857,12 +857,12 @@ static int execute_op_binary(struct htp_ops_context * octx) { int op_binary(struct htp_ops_context * octx) { // Does not support permutations of src1 - const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * src1 = octx->src[1]; if (src1->nb[1] < src1->nb[0]) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t src0_type = octx->src0.type; + const uint32_t src0_type = octx->src[0]->type; if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) { return execute_op_binary(octx); } diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c index a40d866b9c3..e5b9d350fd7 100644 --- a/ggml/src/ggml-hexagon/htp/cpy-ops.c +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -11,7 +11,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #include "hvx-utils.h" @@ -32,10 +32,10 @@ struct htp_copy_context { void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith); }; -#define cpy_preamble \ - struct htp_tensor *src0 = &octx->src0; \ - struct htp_tensor *dst = &octx->dst; \ - \ +#define cpy_preamble \ + const struct htp_tensor *src0 = octx->src[0]; \ + const struct htp_tensor *dst = octx->dst; \ + \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ const uint32_t ne02 = src0->ne[2]; \ diff --git a/ggml/src/ggml-hexagon/htp/cumsum-ops.c b/ggml/src/ggml-hexagon/htp/cumsum-ops.c index ce51555a7fd..2ced1971236 100644 --- a/ggml/src/ggml-hexagon/htp/cumsum-ops.c +++ b/ggml/src/ggml-hexagon/htp/cumsum-ops.c @@ -13,9 +13,9 @@ #include "hvx-utils.h" #include "hex-dma.h" -#define htp_cumsum_tensors_preamble \ - struct htp_tensor * restrict src0 = &octx->src0; \ - struct htp_tensor * restrict dst = &octx->dst; \ +#define htp_cumsum_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict dst = octx->dst; \ \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ @@ -206,8 +206,8 @@ static void cumsum_thread_f32(unsigned int nth, unsigned int ith, void * data) { } int op_cumsum_f32(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { return HTP_STATUS_OK; @@ -226,10 +226,12 @@ int op_cumsum_f32(struct htp_ops_context * octx) { octx->src0_spad.size_per_thread = src_row_size_aligned * 2; octx->dst_spad.size_per_thread = dst_row_size_aligned * 2; - octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; - octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; struct htp_cumsum_context cctx = { .octx = octx, @@ -251,8 +253,9 @@ int op_cumsum_f32(struct htp_ops_context * octx) { } int op_cumsum(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; switch (dst->type) { case HTP_TYPE_F32: diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 0c9bc785620..d296a322589 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -15,7 +15,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" // Must be multiple of 32 @@ -278,12 +278,12 @@ static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_fa_context * factx = (struct htp_fa_context *) data; const struct htp_ops_context * octx = factx->octx; - const struct htp_tensor * q = &octx->src0; - const struct htp_tensor * k = &octx->src1; - const struct htp_tensor * v = &octx->src2; - const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; - const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * mask = octx->src[3]; + const struct htp_tensor * sinks = octx->src[4]; + const struct htp_tensor * dst = octx->dst; const uint32_t neq0 = q->ne[0]; const uint32_t neq1 = q->ne[1]; @@ -610,11 +610,11 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * } int op_flash_attn_ext(struct htp_ops_context * octx) { - const struct htp_tensor * q = &octx->src0; - const struct htp_tensor * k = &octx->src1; - const struct htp_tensor * v = &octx->src2; - const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * mask = octx->src[3]; + const struct htp_tensor * dst = octx->dst; // Check support if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) { @@ -701,13 +701,11 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; - octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; - octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; - - // FARF(ERROR, "fa: qrows-per-thread %u", factx.qrows_per_thread); + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->src2_spad.src = NULL; + octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; octx->src3_spad.src = NULL; + octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; octx->dst_spad.src = NULL; if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads); diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index 047d2850aaa..5a1dc933860 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -11,7 +11,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #include "hvx-utils.h" @@ -23,27 +23,33 @@ struct get_rows_context { }; #define get_rows_preamble \ - const uint32_t ne00 = octx->src0.ne[0]; \ - const uint32_t ne01 = octx->src0.ne[1]; \ - const uint32_t ne02 = octx->src0.ne[2]; \ - const uint32_t ne03 = octx->src0.ne[3]; \ - \ - const uint32_t ne10 = octx->src1.ne[0]; \ - const uint32_t ne11 = octx->src1.ne[1]; \ - const uint32_t ne12 = octx->src1.ne[2]; \ - \ - const uint32_t nb01 = octx->src0.nb[1]; \ - const uint32_t nb02 = octx->src0.nb[2]; \ - const uint32_t nb03 = octx->src0.nb[3]; \ - \ - const uint32_t nb10 = octx->src1.nb[0]; \ - const uint32_t nb11 = octx->src1.nb[1]; \ - const uint32_t nb12 = octx->src1.nb[2]; \ - \ - const uint32_t nb1 = octx->dst.nb[1]; \ - const uint32_t nb2 = octx->dst.nb[2]; \ - const uint32_t nb3 = octx->dst.nb[3]; \ - \ + const uint32_t ne00 = octx->src[0]->ne[0]; \ + const uint32_t ne01 = octx->src[0]->ne[1]; \ + const uint32_t ne02 = octx->src[0]->ne[2]; \ + const uint32_t ne03 = octx->src[0]->ne[3]; \ + \ + const uint32_t ne10 = octx->src[1]->ne[0]; \ + const uint32_t ne11 = octx->src[1]->ne[1]; \ + const uint32_t ne12 = octx->src[1]->ne[2]; \ + const uint32_t ne13 = octx->src[1]->ne[3]; \ + \ + const uint32_t ne0 = octx->dst->ne[0]; \ + const uint32_t ne1 = octx->dst->ne[1]; \ + const uint32_t ne2 = octx->dst->ne[2]; \ + const uint32_t ne3 = octx->dst->ne[3]; \ + \ + const uint32_t nb01 = octx->src[0]->nb[1]; \ + const uint32_t nb02 = octx->src[0]->nb[2]; \ + const uint32_t nb03 = octx->src[0]->nb[3]; \ + \ + const uint32_t nb10 = octx->src[1]->nb[0]; \ + const uint32_t nb11 = octx->src[1]->nb[1]; \ + const uint32_t nb12 = octx->src[1]->nb[2]; \ + \ + const uint32_t nb1 = octx->dst->nb[1]; \ + const uint32_t nb2 = octx->dst->nb[2]; \ + const uint32_t nb3 = octx->dst->nb[3]; \ + \ const uint32_t nr = ne10 * ne11 * ne12; static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { @@ -51,12 +57,14 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da struct htp_ops_context * octx = grctx->octx; get_rows_preamble; + uint64_t qt = HAP_perf_get_qtimer_count(); + // parallelize by src1 elements (which correspond to dst rows) const uint32_t dr = grctx->src1_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; - const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); for (uint32_t i = ir0; i < ir1; ++i) { const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11); @@ -64,7 +72,7 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10); const uint32_t i10 = rem - i11 * ne10; - const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; @@ -73,10 +81,14 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da continue; } - const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03; - const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3; + const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03; + const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3; hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); } + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "get-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } int op_get_rows(struct htp_ops_context * octx) { @@ -84,15 +96,15 @@ int op_get_rows(struct htp_ops_context * octx) { const uint32_t n_threads = MIN(nr, octx->n_threads); - if (octx->src0.type != HTP_TYPE_F32) { + if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - if (octx->dst.type != HTP_TYPE_F32) { + if (octx->dst->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + if (octx->src[1]->type != HTP_TYPE_I32 && octx->src[1]->type != HTP_TYPE_I64) { return HTP_STATUS_NO_SUPPORT; } @@ -102,8 +114,8 @@ int op_get_rows(struct htp_ops_context * octx) { struct get_rows_context grctx; grctx.octx = octx; - grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); - grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); + grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src[1]->ne[0]); + grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src[1]->ne[0] * octx->src[1]->ne[1]); grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads; diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index 8ed1456bc54..fe0b661e309 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -3,8 +3,10 @@ #include #include +#include #include "hexagon_types.h" +#include "hexagon_protos.h" #include "hex-fastdiv.h" #include "hex-dump.h" @@ -68,4 +70,23 @@ static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, Q6_l2fetch_AP((void *) p, control); } +#define HEX_L2_LINE_SIZE 64 +#define HEX_L2_FLUSH_SIZE (128 * 1024) + +static inline void hex_l2flush(void * addr, size_t size) +{ + if (size > HEX_L2_FLUSH_SIZE) { + qurt_mem_cache_clean((qurt_addr_t) 0, 0, QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE); + } else { + const uint32_t s = (uint32_t) addr; + const uint32_t e = s + size; + for (uint32_t i = s; i < e; i += HEX_L2_LINE_SIZE * 4) { + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 0); + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 1); + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 2); + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 3); + } + } +} + #endif /* HEX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 4ff2b36de96..ec191c14981 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -20,7 +20,7 @@ #include "hvx-dump.h" #include "worker-pool.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "hmx-utils.h" #include "hmx-ops.h" @@ -821,7 +821,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu // and each q_head is computed individually to avoid tile-major packing // issues. m_chunk_n_rows is always a multiple of 32 (from // hmx_compute_chunks), so per-head tile arrays don't overlap. - const size_t vtcm_budget = ctx->vtcm_scratch_size; + const size_t vtcm_budget = ctx->vtcm_size; const size_t vec_dot_size = params->k * sizeof(__fp16); // When the activation has a large stride (e.g. permuted Q tensor with @@ -998,7 +998,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co } // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_scratch_size; + const size_t vtcm_budget = ctx->vtcm_size; const size_t vec_dot_size = k * sizeof(__fp16); // DMA-based activation gather for strided tensors (see batched path comment). @@ -1182,7 +1182,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds FARF(MEDIUM, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_scratch_size; + const size_t vtcm_budget = ctx->vtcm_size; const size_t vec_dot_size = k * sizeof(__fp16); const bool use_pipeline = (m >= 128) && (k <= n); @@ -1273,9 +1273,6 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds void *buf_curr = vtcm_scratch0; void *buf_next = vtcm_scratch1; - // issue async DDR data transfer for the first weight chunk - // NOTE: use 2D DMA (n_cols rows x row_stride bytes) instead of 1D - // because UDMA roiwidth is 16-bit and total size can exceed 65535. { const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first); @@ -1533,20 +1530,15 @@ void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, co worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); } -int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m, - int k, int n, int weight_type) { - // Runtime check -- k >= 16384 exceeds 2D DMA limit - if (k >= 16384) { - FARF(HIGH, "%s: k=%d exceeds 2D DMA limit", __func__, k); - return -1; - } +int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, + int m, int k, int n, int weight_type) { // assume k % 32 == 0 && n % 32 == 0 const size_t row_stride = get_x4x2_row_stride(weight_type, k); if (row_stride == 0) { return -1; } - const size_t vtcm_budget = ctx->vtcm_scratch_size; + const size_t vtcm_budget = ctx->vtcm_size; const size_t M_BLOCK_SIZE = 512; const size_t N_BLOCK_SIZE = 512; @@ -1576,8 +1568,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); - FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu", - __func__, m, k, n, weight_type, + FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu", __func__, m, k, n, weight_type, (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); // initialize eye tile (32x32 identity matrix) diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index b36c8d129ba..fb95d36f5a9 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -7,16 +7,12 @@ #include #include -#ifndef restrict -# define restrict __restrict -#endif +#include "htp-ops.h" #ifdef __cplusplus extern "C" { #endif -struct htp_context; // forward declaration - typedef struct { float *dst; const float *activation; diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 6f1917fa2cb..4c36a6ea0c2 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -2,6 +2,7 @@ #define HTP_CTX_H #include "hex-dma.h" +#include "htp-ops.h" #include "worker-pool.h" #include @@ -10,38 +11,85 @@ #include #define HTP_MAX_NTHREADS 10 +#define HTP_MAX_MMAPS 16 + +// Memory mapping +struct htp_mmap { + uint64_t size; + uint64_t base; + uint32_t fd; + uint32_t pinned; +}; + +// Scratchpad state +struct htp_spad { + const struct htp_tensor * src; // original src of the data (for reuse) + uint8_t * data; // pointer to an area in vtcm + uint32_t stride; // stride used inside this spad + uint32_t size; // total size + uint32_t size_per_thread; // size per thread +}; + +// Context while processing an Op +// TODO: fold this into the main context +struct htp_ops_context { + struct htp_context * ctx; + + enum htp_op_code op; // FIXME: rename to opcode + int32_t op_params[HTP_OP_MAX_PARAMS]; + + const struct htp_tensor * src[HTP_OP_MAX_INPUTS]; + const struct htp_tensor * dst; + + // TODO convert these to an array + struct htp_spad src0_spad; + struct htp_spad src1_spad; + struct htp_spad src2_spad; + struct htp_spad src3_spad; + struct htp_spad dst_spad; + + uint32_t n_threads; + uint32_t flags; +}; // Main context for htp DSP backend struct htp_context { - dspqueue_t queue; - dma_queue * dma[HTP_MAX_NTHREADS]; - worker_pool_context_t worker_pool; - uint32_t n_threads; - - int thread_id; - int thread_prio; - - uint8_t * vtcm_base; - size_t vtcm_size; - uint32_t vtcm_rctx; - - atomic_bool vtcm_valid; - atomic_bool vtcm_inuse; - atomic_bool vtcm_needs_release; - - uint32_t opmask; - - // Cached src1 spad position from the last quantize pass. - // When SKIP_QUANTIZE is set the Q8 activation data is already in VTCM - // at this address; the matmul must read from here instead of recomputing - // the offset (which depends on the current op's src0 size). - uint8_t * prev_src1_spad; - - // HMX acceleration fields (v73+, enabled by compile-time HTP_HAS_HMX) -#ifdef HTP_HAS_HMX - int hmx_enabled; // Runtime flag: HMX initialisation succeeded - size_t vtcm_scratch_size; // Usable dynamic scratch (vtcm_size minus tail reservation) -#endif + dspqueue_t queue; + dma_queue * dma[HTP_MAX_NTHREADS]; + struct htp_mmap mmap[HTP_MAX_MMAPS]; + worker_pool_context_t worker_pool; + uint32_t n_threads; + + int thread_id; + int thread_prio; + + int hmx_enabled; + + uint8_t * vtcm_base; + size_t vtcm_size; + uint32_t vtcm_rctx; + atomic_bool vtcm_valid; + atomic_bool vtcm_needs_release; + + struct htp_ops_context octx; }; +int op_matmul(struct htp_ops_context * octx); +int op_matmul_id(struct htp_ops_context * octx); +int op_binary(struct htp_ops_context * octx); +int op_unary(struct htp_ops_context * octx); +int op_sum_rows(struct htp_ops_context * octx); +int op_activations(struct htp_ops_context * octx); +int op_softmax(struct htp_ops_context * octx); +int op_add_id(struct htp_ops_context * octx); +int op_rope(struct htp_ops_context * octx); +int op_flash_attn_ext(struct htp_ops_context * octx); +int op_set_rows(struct htp_ops_context * octx); +int op_get_rows(struct htp_ops_context * octx); +int op_cpy(struct htp_ops_context * octx); +int op_repeat(struct htp_ops_context * octx); +int op_argsort(struct htp_ops_context * octx); +int op_ssm_conv(struct htp_ops_context * octx); +int op_cumsum(struct htp_ops_context * octx); + #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h deleted file mode 100644 index df0ea7ccbd6..00000000000 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ /dev/null @@ -1,166 +0,0 @@ -#ifndef HTP_MSG_H -#define HTP_MSG_H - -#include - -// ggml-common.h must be included prio to this header - -// Mask to enable various stages of the Ops. -// Used for debugging and profiling. -enum { - HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP) - HTP_OPMASK_QUANTIZE = (1 << 1), // Enable Quantize - HTP_OPMASK_COMPUTE = (1 << 2), // Enable Compute -}; - -// Op flags -enum { - HTP_OPFLAGS_SKIP_QUANTIZE = (1 << 0), // Skip dynamic quantization (reuse quantized tensors) - HTP_OPFLAGS_SKIP_COMPUTE = (1 << 1), // Skip actual computation (used for profiling) - HTP_OPFLAGS_EARLY_WAKEUP = (1 << 2) // Send early wakeup notification -}; - -enum htp_status { - HTP_STATUS_OK = 1, - HTP_STATUS_INTERNAL_ERR = 2, - HTP_STATUS_NO_SUPPORT = 3, - HTP_STATUS_INVAL_PARAMS = 4, - HTP_STATUS_VTCM_TOO_SMALL = 5, -}; - -// The values must match the ggml_type. -// Duplicated here because we can't include full ggml.h in the htp build. -// We have some static_asserts in the cpp code to ensure things are in sync. -enum htp_data_type { - HTP_TYPE_F32 = 0, - HTP_TYPE_F16 = 1, - HTP_TYPE_Q4_0 = 2, - HTP_TYPE_Q8_0 = 8, - HTP_TYPE_IQ4_NL = 20, - HTP_TYPE_I32 = 26, - HTP_TYPE_I64 = 27, - HTP_TYPE_MXFP4 = 39, - HTP_TYPE_COUNT -}; - -// Do not reorder first 4 (used as an index) -enum htp_op { - HTP_OP_MUL = 0, - HTP_OP_ADD = 1, - HTP_OP_SUB = 2, - HTP_OP_DIV = 3, - HTP_OP_MUL_MAT, - HTP_OP_MUL_MAT_ID, - HTP_OP_RMS_NORM, - HTP_OP_UNARY_SILU, - HTP_OP_UNARY_GELU, - HTP_OP_UNARY_SIGMOID, - HTP_OP_UNARY_EXP, - HTP_OP_UNARY_NEG, - HTP_OP_UNARY_SOFTPLUS, - HTP_OP_GLU_SWIGLU, - HTP_OP_GLU_SWIGLU_OAI, - HTP_OP_GLU_GEGLU, - HTP_OP_SOFTMAX, - HTP_OP_ADD_ID, - HTP_OP_ROPE, - HTP_OP_FLASH_ATTN_EXT, - HTP_OP_SET_ROWS, - HTP_OP_GET_ROWS, - HTP_OP_SCALE, - HTP_OP_CPY, - HTP_OP_ARGSORT, - HTP_OP_SQR, - HTP_OP_SQRT, - HTP_OP_SUM_ROWS, - HTP_OP_SSM_CONV, - HTP_OP_REPEAT, - HTP_OP_CUMSUM, - INVALID -}; - -static inline size_t htp_t_block_size(uint32_t t) { - switch (t) { - case HTP_TYPE_F32: - return 1; - case HTP_TYPE_F16: - return 1; - case HTP_TYPE_Q4_0: - return QK4_0; - case HTP_TYPE_Q8_0: - return QK8_0; - case HTP_TYPE_IQ4_NL: - return QK4_NL; - case HTP_TYPE_MXFP4: - return QK_MXFP4; - default: - assert(0 && "unsupported HTP data type"); - } - return 0; -} - -static inline size_t htp_type_nbytes(uint32_t t) { - switch (t) { - case HTP_TYPE_F32: - return 4; - case HTP_TYPE_F16: - return 2; - case HTP_TYPE_Q4_0: - return sizeof(block_q4_0); - case HTP_TYPE_Q8_0: - return sizeof(block_q8_0); - case HTP_TYPE_IQ4_NL: - return sizeof(block_iq4_nl); - case HTP_TYPE_MXFP4: - return sizeof(block_mxfp4); - default: - assert(0 && "unsupported HTP data type"); - } - return 0; -} - -// Internal types -#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) -#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks -#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks - -#define HTP_MAX_DIMS 4 - -struct htp_tensor { - uint32_t data; // Buffer offset in the messages, and data pointer on the NSP - uint32_t type; // Data type - uint32_t ne[HTP_MAX_DIMS]; // Number of elements - uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor) -}; - -#define HTP_MAX_OP_PARAMS 64 - -struct htp_general_req { - uint32_t op; // GGML/HTP Op - int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)]; - // Params for the op, e.g. epsilon of RMS norm - uint32_t flags; // Request flags - - struct htp_tensor src0; // Input0 tensor - struct htp_tensor src1; // Input1 tensor - struct htp_tensor src2; // Input2 tensor - struct htp_tensor src3; // Input3 tensor - struct htp_tensor src4; // Input4 tensor - struct htp_tensor dst; // Output tensor - - // should be multiple of 64 bytes (cacheline) -}; - -struct htp_general_rsp { - uint32_t op; // GGML/HTP Op - uint32_t status; // HTP_STATUS_... - uint32_t prof_usecs; // Number of usec per request - uint32_t prof_cycles; // Number of cycles per request - uint32_t prof_pkts; // Number of instruction packets per request - uint8_t unused[44]; // Pad to 64 bytes -}; - -#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req) -#define HTP_MAX_PACKET_BUFFERS 8 - -#endif /* HTP_MSG_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index d35decaac20..44a6ab4f737 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -1,65 +1,154 @@ #ifndef HTP_OPS_H #define HTP_OPS_H -#include "htp-ctx.h" -#include "htp-msg.h" -#include "worker-pool.h" - #include -#include -#include +// ggml-common.h must be included prio to this header + +enum htp_status { + HTP_STATUS_OK = 1, + HTP_STATUS_INTERNAL_ERR = 2, + HTP_STATUS_NO_SUPPORT = 3, + HTP_STATUS_INVAL_PARAMS = 4, + HTP_STATUS_VTCM_TOO_SMALL = 5, +}; + +// First set of values must match the ggml_type. +// Duplicated here because we can't include full ggml.h in the htp build. +// We have some static_asserts in the cpp code to ensure things are in sync. +enum htp_data_type { + HTP_TYPE_F32 = 0, + HTP_TYPE_F16 = 1, + HTP_TYPE_Q4_0 = 2, + HTP_TYPE_Q8_0 = 8, + HTP_TYPE_IQ4_NL = 20, + HTP_TYPE_I32 = 26, + HTP_TYPE_I64 = 27, + HTP_TYPE_MXFP4 = 39, + + // types used internally for repack, dyn.quant, etc + HTP_TYPE_Q4_0x4x2 = 200, + HTP_TYPE_Q8_0x4x2, + HTP_TYPE_MXFP4x4x2, + + HTP_TYPE_INVALID +}; + +// Constats for internal types +#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) +#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks +#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks + + +// Mask to enable various stages of the Ops. +// Used for debugging and profiling. +enum htp_op_mask { + HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP) + HTP_OPMASK_COMPUTE = (1 << 1), // Enable Compute +}; + +// Do not reorder first 4 (used as an index) +enum htp_op_code { + HTP_OP_MUL = 0, + HTP_OP_ADD = 1, + HTP_OP_SUB = 2, + HTP_OP_DIV = 3, + HTP_OP_MUL_MAT, + HTP_OP_MUL_MAT_ID, + HTP_OP_RMS_NORM, + HTP_OP_UNARY_SILU, + HTP_OP_UNARY_GELU, + HTP_OP_UNARY_SIGMOID, + HTP_OP_UNARY_EXP, + HTP_OP_UNARY_NEG, + HTP_OP_UNARY_SOFTPLUS, + HTP_OP_GLU_SWIGLU, + HTP_OP_GLU_SWIGLU_OAI, + HTP_OP_GLU_GEGLU, + HTP_OP_SOFTMAX, + HTP_OP_ADD_ID, + HTP_OP_ROPE, + HTP_OP_FLASH_ATTN_EXT, + HTP_OP_SET_ROWS, + HTP_OP_GET_ROWS, + HTP_OP_SCALE, + HTP_OP_CPY, + HTP_OP_ARGSORT, + HTP_OP_SQR, + HTP_OP_SQRT, + HTP_OP_SUM_ROWS, + HTP_OP_SSM_CONV, + HTP_OP_REPEAT, + HTP_OP_CUMSUM, + + HTP_OP_INVALID +}; + +#define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS +#define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS +#define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS -// ggml-common.h must be included prior to this header +#define HTP_OP_MAX_BUFS 8 +#define HTP_OP_MAX_REQS 256 +#define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS) +#define HTP_OP_MAX_VMEM (3221225472u) -struct htp_spad { - uint8_t * data; - size_t stride; - size_t size; - size_t size_per_thread; +enum htp_tensor_flags { + HTP_TENSOR_COMPUTE = (1U << 0), // Tensor buffer temporal compute data (not weights) + HTP_TENSOR_FLUSHED = (1U << 1) // Tensor buffer has been flushed (set by the NPU) }; -struct htp_ops_context { - struct htp_context * ctx; +// Tensor descriptor +struct htp_tensor { + uint32_t data; // Buffer offset in the messages, and data pointer on the NPU + uint32_t size; // Data size in bytes + uint32_t flags; // Buffer / tensor flags + uint16_t type; // Data type + uint16_t bi; // Buffer index + uint32_t ne[HTP_OP_MAX_DIMS]; // Number of elements + uint32_t nb[HTP_OP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor) +}; - enum htp_op op; - int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)]; +// Buffer descriptor +struct htp_buf_desc { + uint64_t base; // base address + uint64_t size; // total size + uint32_t flags; // buffer flags (unused) + uint32_t fd; // file descriptor +}; - struct htp_tensor src0; - struct htp_tensor src1; - struct htp_tensor src2; - struct htp_tensor src3; - struct htp_tensor src4; - struct htp_tensor dst; +enum htp_op_flags { + HTP_OPFLAGS_SKIP_COMPUTE = (1U << 0), // Skip actual computation (used for profiling) +}; - struct htp_spad src0_spad; - struct htp_spad src1_spad; - struct htp_spad src2_spad; - struct htp_spad src3_spad; - struct htp_spad dst_spad; +// Op descriptor +struct htp_op_desc { + uint32_t opcode; // GGML/HTP Op + uint32_t flags; // Op flags + int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm + uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices + uint16_t dst; // Output tensor index - worker_pool_context_t * wpool; // worker pool - uint32_t n_threads; // num threads + // the rest is filled in-place by the NPU + uint32_t prof_usecs; // Number of usec per request + uint32_t prof_cycles; // Number of cycles per request + uint32_t prof_pkts; // Number of instruction packets per request + uint32_t unused; +}; - uint32_t flags; +struct htp_opbatch_req { + uint32_t n_bufs; // Number of buffers + uint32_t n_tensors; // Number of tensors + uint32_t n_ops; // Number of ops + uint32_t flags; // unused + // struct htp_buf_desc bufs[]; -- dspqueue buf 0 + // struct htp_tensor tensors[]; -- dspqueue buf 0 + // struct htp_op_desc ops[]; -- dspqueue buf 0 }; -int op_matmul(struct htp_ops_context * octx); -int op_matmul_id(struct htp_ops_context * octx); -int op_binary(struct htp_ops_context * octx); -int op_unary(struct htp_ops_context * octx); -int op_sum_rows(struct htp_ops_context * octx); -int op_activations(struct htp_ops_context * octx); -int op_softmax(struct htp_ops_context * octx); -int op_add_id(struct htp_ops_context * octx); -int op_rope(struct htp_ops_context * octx); -int op_flash_attn_ext(struct htp_ops_context * octx); -int op_set_rows(struct htp_ops_context * octx); -int op_get_rows(struct htp_ops_context * octx); -int op_cpy(struct htp_ops_context * octx); -int op_repeat(struct htp_ops_context * octx); -int op_argsort(struct htp_ops_context * octx); -int op_ssm_conv(struct htp_ops_context * octx); -int op_cumsum(struct htp_ops_context * octx); +struct htp_opbatch_rsp { + uint32_t status; // HTP_STATUS_... + // struct htp_op_req ops[]; -- dspqueue buf 0 +}; #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index 2dc716cb441..3eb5d5a6912 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -9,6 +9,8 @@ interface htp_iface : remote_handle64 { AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx); AEEResult stop(); + AEEResult mmap(in uint32 fd, in uint32 size, in uint32 pinned); + AEEResult munmap(in uint32 fd); AEEResult enable_etm(); AEEResult disable_etm(); }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 6f37bf9d4b8..8b347039428 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -1,5 +1,7 @@ #pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" #pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" #include #include @@ -12,6 +14,7 @@ #include #include #include +#include #include #include @@ -21,14 +24,10 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #include "worker-pool.h" -#ifdef HTP_HAS_HMX -#include "hmx-ops.h" -#endif // HTP_HAS_HMX - AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { struct htp_context * ctx; int err = 0; @@ -38,7 +37,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { return AEE_ENOMEMORY; } - // Use the context structure as a handle + // Use the context structure as the handle *handle = (remote_handle64) ctx; // Enable FARF logs @@ -115,6 +114,16 @@ AEEResult htp_iface_close(remote_handle64 handle) { return AEE_EITEMBUSY; } + // release the mmaps (if any) + for (uint32_t i=0; immap[i].size) { + HAP_munmap2((void *) ctx->mmap[i].base, ctx->mmap[i].size); + ctx->mmap[i].size = 0; + ctx->mmap[i].base = NULL; + ctx->mmap[i].fd = -1; + } + } + free(ctx); return AEE_SUCCESS; } @@ -143,66 +152,93 @@ AEEResult htp_iface_disable_etm(remote_handle64 handle) { return err; } -static int vtcm_acquire(struct htp_context * ctx) { - int err; - if (!ctx->vtcm_valid) { - // Temporarily bump thread priority to make sure it's higher than other sessions. - // This way the resource manager will notify the other thread to release VTCM. - // Note that we need to reaquire VTCM at normal priority for this to work next time. - qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio - 10); - err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); - if (err != 0) { - FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err); - abort(); - } - HAP_compute_res_release_cached(ctx->vtcm_rctx); - qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio); +AEEResult htp_iface_mmap(remote_handle64 handle, int fd, uint32_t size, uint32_t pinned) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } - err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); - if (err != 0) { - FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err); - abort(); + // See if we already have this mapping + for (uint32_t i=0; immap[i]; + if (m->fd == fd) { + m->pinned = pinned; + return AEE_SUCCESS; } - ctx->vtcm_valid = true; } - ctx->vtcm_inuse = true; + // Add new mapping + for (uint32_t i=0; immap[i]; + if (!m->size) { + FARF(HIGH, "mmap : fd %u size %u pinned %u", fd, size, pinned); + void *va = HAP_mmap2(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0); + if (va == (void*)-1) { + FARF(ERROR, "mmap failed : va %p fd %u size %u", va, fd, (uint32_t) size); + return AEE_EFAILED; + } + m->base = (uint64_t) va; + m->fd = fd; + m->size = size; + m->pinned = pinned; - return 0; + return AEE_SUCCESS; + } + } + + return AEE_ENOMEMORY; } -static int vtcm_release(struct htp_context * ctx) { - ctx->vtcm_inuse = false; +AEEResult htp_iface_munmap(remote_handle64 handle, int fd) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } - if (ctx->vtcm_valid && ctx->vtcm_needs_release) { - ctx->vtcm_valid = false; - ctx->vtcm_needs_release = false; - HAP_compute_res_release_cached(ctx->vtcm_rctx); + for (uint32_t i=0; immap[i]; + if (fd < 0 || m->fd == fd) { + FARF(HIGH, "unmmap : base %p fd %u size %u", (void*) m->base, m->fd, (uint32_t) m->size); + HAP_munmap2((void *) m->base, m->size); + m->size = 0; + m->base = NULL; + m->fd = -1; + m->pinned = 0; + } } - return 0; + return AEE_SUCCESS; } -static int vtcm_release_callback(unsigned int rctx, void * state) { - struct htp_context * ctx = (struct htp_context *) state; - - if (!ctx || ctx->vtcm_rctx != rctx) { - return AEE_EBADPARM; - } +static void vtcm_acquire(struct htp_context * ctx) { + if (!ctx->vtcm_valid) { + int err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000u); + if (err != 0) { + FARF(ERROR, "ggml-hex: failed to acquire VTCM: 0x%08x", (unsigned)err); + abort(); + } - // If VTCM is not inuse (not processing Ops) release it right here - // otherwise we'll release it once we're done with the current Op. + ctx->vtcm_needs_release = false; + ctx->vtcm_valid = true; - if (ctx->vtcm_inuse) { - ctx->vtcm_needs_release = true; - return 0; + // Drop the priority to make sure we get the release callback from other GGML-HTP and QNN-HTP sessions + HAP_compute_res_update_priority(ctx->vtcm_rctx, ctx->thread_prio + 10); } +} - ctx->vtcm_valid = false; - HAP_compute_res_release_cached(ctx->vtcm_rctx); +static void vtcm_release(struct htp_context * ctx) { + if (ctx->vtcm_valid) { + ctx->vtcm_valid = false; + ctx->vtcm_needs_release = false; + HAP_compute_res_release_cached(ctx->vtcm_rctx); + } +} +static int vtcm_release_callback(unsigned int rctx, void * state) { + struct htp_context * ctx = (struct htp_context *) state; + ctx->vtcm_needs_release = true; return 0; } @@ -236,7 +272,6 @@ static int vtcm_alloc(struct htp_context * ctx) { ctx->vtcm_size = vtcm_size; ctx->vtcm_rctx = rctx; ctx->vtcm_valid = false; - ctx->vtcm_inuse = false; ctx->vtcm_needs_release = false; return 0; @@ -288,18 +323,8 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que } #ifdef HTP_HAS_HMX - if (use_hmx) { - ctx->vtcm_scratch_size = ctx->vtcm_size; - ctx->hmx_enabled = 1; - - FARF(HIGH, "HMX enabled: vtcm-scratch %zu", ctx->vtcm_scratch_size); - } else { - // HMX disabled: skip HMX initialisation so the - // dispatch loop falls through to the HVX compute paths. - ctx->hmx_enabled = 0; - ctx->vtcm_scratch_size = ctx->vtcm_size; - FARF(HIGH, "HMX disabled (use_hmx=0): vtcm-scratch %zu", ctx->vtcm_scratch_size); - } + ctx->hmx_enabled = use_hmx; + FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx); #endif qurt_sysenv_max_hthreads_t hw_threads; @@ -362,13 +387,11 @@ AEEResult htp_iface_stop(remote_handle64 handle) { for (int i = 0; i < ctx->n_threads; i++) { dma_queue_delete(ctx->dma[i]); } + #ifdef HTP_HAS_HMX - if (ctx->hmx_enabled) { - ctx->hmx_enabled = 0; - } + ctx->hmx_enabled = 0; #endif - vtcm_free(ctx); return AEE_SUCCESS; @@ -397,1129 +420,320 @@ static inline void profile_stop(struct profile_data * d) { d->pkts = hex_get_pktcnt() - d->pkts; } -static int send_htp_rsp(struct htp_context * c, - uint32_t op, - uint32_t status, - struct dspqueue_buffer * bufs, - size_t n_bufs, - struct profile_data * prof) { - // Prep response struct (zero-init to clear cmp/unused union) - struct htp_general_rsp rsp; - memset(&rsp, 0, sizeof(rsp)); - rsp.op = op; - rsp.status = status; - rsp.prof_usecs = prof->usecs; - rsp.prof_cycles = prof->cycles; - rsp.prof_pkts = prof->pkts; - - int err = dspqueue_write(c->queue, - 0, // Flags - n_bufs, - bufs, // Buffer references - sizeof(rsp), - (const uint8_t *) &rsp, // Message - DSPQUEUE_TIMEOUT_NONE); +static int execute_op(struct htp_ops_context * octx) { + switch (octx->op) { + case HTP_OP_MUL_MAT: + return op_matmul(octx); - if (err != 0) { - FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err); - } + case HTP_OP_MUL_MAT_ID: + return op_matmul_id(octx); - return err; -} + case HTP_OP_MUL: + case HTP_OP_ADD: + case HTP_OP_SUB: + case HTP_OP_DIV: + case HTP_OP_ADD_ID: + return op_binary(octx); -static void proc_matmul_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - size_t n_bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_matmul(&octx); - vtcm_release(ctx); - } + case HTP_OP_RMS_NORM: + case HTP_OP_SCALE: + case HTP_OP_SQR: + case HTP_OP_SQRT: + case HTP_OP_UNARY_SOFTPLUS: + case HTP_OP_UNARY_SIGMOID: + case HTP_OP_UNARY_NEG: + case HTP_OP_UNARY_EXP: + return op_unary(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_UNARY_SILU: + case HTP_OP_UNARY_GELU: + case HTP_OP_GLU_SWIGLU: + case HTP_OP_GLU_SWIGLU_OAI: + case HTP_OP_GLU_GEGLU: + return op_activations(octx); -static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_argsort(&octx); - vtcm_release(ctx); - } + case HTP_OP_SOFTMAX: + return op_softmax(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_ROPE: + return op_rope(octx); -static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_cpy(&octx); - vtcm_release(ctx); - } + case HTP_OP_FLASH_ATTN_EXT: + return op_flash_attn_ext(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_SET_ROWS: + return op_set_rows(octx); -static void proc_repeat_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = op_repeat(&octx); - - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_GET_ROWS: + return op_get_rows(octx); -static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_get_rows(&octx); - vtcm_release(ctx); - } + case HTP_OP_SUM_ROWS: + return op_sum_rows(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_CPY: + return op_cpy(octx); -static void proc_matmul_id_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - size_t n_bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[3].fd; - rsp_bufs[0].ptr = bufs[3].ptr; - rsp_bufs[0].size = bufs[3].size; - rsp_bufs[0].offset = bufs[3].offset; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.src2 = req->src2; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.src2.data = (uint32_t) bufs[2].ptr; - octx.dst.data = (uint32_t) bufs[3].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_matmul_id(&octx); - vtcm_release(ctx); - } + case HTP_OP_REPEAT: + return op_repeat(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_ARGSORT: + return op_argsort(octx); -static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_binary(&octx); - vtcm_release(ctx); - } + case HTP_OP_SSM_CONV: + return op_ssm_conv(octx); - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_CUMSUM: + return op_cumsum(octx); -static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[3].fd; - rsp_bufs[0].ptr = bufs[3].ptr; - rsp_bufs[0].offset = bufs[3].offset; - rsp_bufs[0].size = bufs[3].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.src2 = req->src2; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.src2.data = (uint32_t) bufs[2].ptr; - octx.dst.data = (uint32_t) bufs[3].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_binary(&octx); - vtcm_release(ctx); - } + case HTP_OP_INVALID: + break; - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} - -static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_unary(&octx); - vtcm_release(ctx); + // No default to catch missing cases } - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); + FARF(ERROR, "Unknown Op %u", octx->op); + return -1; } -static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_sum_rows(&octx); - vtcm_release(ctx); - } +static inline bool reuse_buf(struct htp_context *ctx, uint32_t *m_reuse, struct htp_buf_desc *b) { + b->base = NULL; - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} - -static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - // We've written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup OP context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_ssm_conv(&octx); - vtcm_release(ctx); + for (uint32_t i=0; immap + i; + if (m->size && m->fd == b->fd) { + b->base = m->base; + *m_reuse |= (1 << i); + return true; + } } - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); + return false; } -static void proc_cumsum_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We've written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_cumsum(&octx); - vtcm_release(ctx); +static inline void drop_mmap(struct htp_context *ctx, struct htp_mmap *m) { + if (m->size && !m->pinned) { + FARF(HIGH, "unmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned); + HAP_munmap2((void *) m->base, m->size); + m->size = 0; + m->base = 0; + m->fd = -1; } - - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } -static void proc_activations_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - uint32_t n_bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - int write_idx = (n_bufs == 3) ? 2 : 1; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[write_idx].fd; - rsp_bufs[0].ptr = bufs[write_idx].ptr; - rsp_bufs[0].offset = bufs[write_idx].offset; - rsp_bufs[0].size = bufs[write_idx].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - if (3 == n_bufs) { - octx.src1 = req->src1; - } - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - if (3 == n_bufs) { - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - } else { - octx.dst.data = (uint32_t) bufs[1].ptr; - } - octx.n_threads = ctx->n_threads; +static inline void mmap_buf(struct htp_context *ctx, struct htp_buf_desc *b) { + if (b->base) return; // already mapped - struct profile_data prof; - profile_start(&prof); + // find unused mapping + for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) { + struct htp_mmap *m = &ctx->mmap[i]; + if (!m->size) { + void *va = HAP_mmap2(NULL, b->size, HAP_PROT_READ | HAP_PROT_WRITE, 0, b->fd, 0); + if (va == (void*)-1) { + FARF(ERROR, "mmap failed : va %p fd %u size %u", va, b->fd, (uint32_t) b->size); + abort(); // can't do much else at this point + } - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - if (octx.op == HTP_OP_SOFTMAX) { - rsp_status = op_softmax(&octx); - } else { - rsp_status = op_activations(&octx); + m->base = b->base = (uint64_t) va; + m->fd = b->fd; + m->size = b->size; + m->pinned = 0; + + FARF(HIGH, "mmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned); + return; } - vtcm_release(ctx); } - - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } -static void proc_rope_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - uint32_t n_bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - int write_idx = n_bufs - 1; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[write_idx].fd; - rsp_bufs[0].ptr = bufs[write_idx].ptr; - rsp_bufs[0].offset = bufs[write_idx].offset; - rsp_bufs[0].size = bufs[write_idx].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - if (4 == n_bufs) { - octx.src2 = req->src2; - } - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - if (4 == n_bufs) { - octx.src2.data = (uint32_t) bufs[2].ptr; - octx.dst.data = (uint32_t) bufs[3].ptr; - } else { - octx.dst.data = (uint32_t) bufs[2].ptr; - } - octx.n_threads = ctx->n_threads; +static void prep_op_bufs(struct htp_context *ctx, struct htp_buf_desc *bufs, uint32_t n_bufs) { + uint32_t m_reuse = 0; // mmap reuse mask (index from ctx->mmap array) + uint32_t b_reuse = 0; // buf reuse count - struct profile_data prof; - profile_start(&prof); + size_t m_vmem = 0; // mapped vmem + size_t e_vmem = 0; // extra vmem - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_rope(&octx); - vtcm_release(ctx); + // See what we can reuse + for (uint32_t i=0; i < n_bufs; i++) { + struct htp_buf_desc *b = bufs + i; + if (reuse_buf(ctx, &m_reuse, b)) { b_reuse++; } else { e_vmem += b->size; } + FARF(HIGH, "prep-buf #%u : pass0 fd %u base %p size %u flags 0x%x", i, b->fd, (void*) b->base, (uint32_t) b->size, b->flags); } - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + if (b_reuse == n_bufs) return; // all bufs reuse existing mappings -static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_set_rows(&octx); - vtcm_release(ctx); - } + // See how much vmem we have mmaped right now + for (uint32_t i=0; immap[i].size; } - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} - -static void proc_flash_attn_ext_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - uint32_t n_bufs) { - // Setup Op context - struct htp_ops_context octx; - memset(&octx, 0, sizeof(octx)); - - octx.ctx = ctx; - octx.n_threads = ctx->n_threads; - - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.src2 = req->src2; - octx.src3 = req->src3; - octx.src4 = req->src4; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.src2.data = (uint32_t) bufs[2].ptr; - - int last_buf = 3; - - if (octx.src3.ne[0]) { - octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid - } + FARF(HIGH, "prep-bufs : pass1 mmap-vmem %zu extra-vmem %zu n-bufs %u b-reuse %u", m_vmem, e_vmem, n_bufs, b_reuse); - if (octx.src4.ne[0]) { - octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid + if ((m_vmem + e_vmem) > HTP_OP_MAX_VMEM) { + // Drop unused mappings + for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) { + bool used = m_reuse & (1<mmap + i); } + } } - octx.dst.data = (uint32_t) bufs[last_buf].ptr; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_flash_attn_ext(&octx); - vtcm_release(ctx); + // Create missing mappings + for (uint32_t i=0; i < n_bufs; i++) { + struct htp_buf_desc *b = bufs + i; + mmap_buf(ctx, b); + FARF(HIGH, "prep-buf #%u : pass1 fd %u base %p size %u flags 0x%x", i, b->fd, (void*) b->base, (uint32_t) b->size, b->flags); } +} - profile_stop(&prof); +static void prep_tensor(struct htp_context *ctx, struct htp_buf_desc *bufs, uint32_t idx, struct htp_tensor *t) { + uint32_t offset = t->data; + uint32_t size = t->size; + uint32_t bi = t->bi; - struct dspqueue_buffer rsp_buf = bufs[last_buf]; - rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + t->data = bufs[bi].base + offset; // update data to the actual pointer - send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof); + FARF(HIGH, "prep-tensor #%u: bi %u offset %u size %u data %p : %u:%u:%u:%u", idx, t->bi, offset, t->size, (void*) t->data, + t->ne[0], t->ne[1], t->ne[3], t->ne[3]); } -#ifdef HTP_HAS_HMX -// --------------------------------------------------------------------------- -// HMX operation wrappers — self-contained, bypass htp_ops_context / htp_spad. -// VTCM, DMA and thread dispatch are managed inside the HMX kernels. -// --------------------------------------------------------------------------- - -static void proc_hmx_matmul_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - size_t n_bufs) { - // HMX weight tile requires N to be 32-aligned. - if (req->src0.ne[1] % 32 != 0) { - proc_matmul_req(ctx, req, bufs, n_bufs); - return; +static void prep_tensors(struct htp_context *ctx, struct htp_buf_desc *bufs, struct htp_tensor *tens, uint32_t n_tens) { + for (uint32_t i=0; i < n_tens; i++) { + prep_tensor(ctx, bufs, i, tens + i); } +} - const bool is_batched = (req->src0.ne[2] * req->src0.ne[3] > 1 || - req->src1.ne[2] * req->src1.ne[3] > 1); +static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) { + memcpy(octx->op_params, op->params, sizeof(octx->op_params)); + octx->flags = op->flags; + octx->op = op->opcode; - // Quantised HMX kernels only handle flat 2D matmul (host already rejects - // batched quantised, but guard here too). F16 batched matmul is handled - // by the dedicated wrapper in hmx-matmul-ops.c. - if (is_batched && - req->src0.type != HTP_TYPE_F16) { - proc_matmul_req(ctx, req, bufs, n_bufs); - return; - } + FARF(HIGH, "proc-op #%u: opcode %u flags 0x%x", idx, octx->op, octx->flags); - // HMX assumes contiguous row-major layout. Fall back for permuted - // tensors where strides are non-monotonic (e.g. transposed KV cache). - if (req->src0.nb[0] > req->src0.nb[1] || - req->src1.nb[0] > req->src1.nb[1]) { - proc_matmul_req(ctx, req, bufs, n_bufs); - return; - } + // Prep input tensors + for (uint32_t i=0; isrc[i] == 0xffff ? NULL : tens + op->src[i]; - // M alignment: when M > 32 but not 32-aligned, we split into - // HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows). - // When M <= 32 and not 32-aligned, fall back entirely to HVX. - const int m_total = (int) req->src1.ne[1]; - const int m_tail = m_total % 32; - const int m_hmx = m_total - m_tail; + octx->src[i] = src; + if (!src) continue; - if (m_hmx == 0) { - proc_matmul_req(ctx, req, bufs, n_bufs); - return; - } - - // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. - // Other types fall back to HVX. - { - uint32_t wtype = req->src0.type; - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && - wtype != HTP_TYPE_MXFP4) { - proc_matmul_req(ctx, req, bufs, n_bufs); - return; - } - // Quantised HMX path requires K aligned to 256 (x4x2 super-block). - // F16 HMX path requires K aligned to 32 (tile width). - if (wtype != HTP_TYPE_F16 && req->src0.ne[0] % 256 != 0) { - proc_matmul_req(ctx, req, bufs, n_bufs); - return; - } - if (wtype == HTP_TYPE_F16 && req->src0.ne[0] % 32 != 0) { - proc_matmul_req(ctx, req, bufs, n_bufs); - return; + if (!(src->flags & HTP_TENSOR_FLUSHED) && (src->flags & HTP_TENSOR_COMPUTE)) { + // flush compute buffers on input + hex_l2flush((void *) src->data, src->size); } + + FARF(HIGH, "prep-src #%u: data %p size %u : %u:%u:%u:%u", op->src[i], (void*) src->data, src->size, + src->ne[0], src->ne[1], src->ne[3], src->ne[3]); } - (void) n_bufs; - - struct dspqueue_buffer rsp_bufs[1]; - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); - - // src0 = weights, src1 = activation, dst = output - void * wgt = (void *) bufs[0].ptr; - float * act = (float *) bufs[1].ptr; - float * dst = (float *) bufs[2].ptr; - - int k = (int) req->src0.ne[0]; // inner dimension - int n = (int) req->src0.ne[1]; // weight columns - - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - - // --- Phase 1: HMX on the first m_hmx (32-aligned) rows --- - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - int ret = -1; - - const int ne02 = (int) req->src0.ne[2]; - const int ne03 = (int) req->src0.ne[3]; - const int ne12 = (int) req->src1.ne[2]; - const int ne13 = (int) req->src1.ne[3]; - // Row strides in elements. For compact tensors these equal k; for - // permuted attention views they can be larger, so pass the real stride. - const int act_stride = (int)(req->src1.nb[1] / sizeof(float)); - const int weight_stride = (int)(req->src0.nb[1] / sizeof(__fp16)); - - switch (req->src0.type) { - case HTP_TYPE_F16: - if (is_batched) { - hmx_matmul_w16a32_batched_params_t batch_params = { - .dst = dst, - .activation = act, - .permuted_weight = (const __fp16 *) wgt, - .m = m_hmx, - .k = k, - .n = n, - .act_stride = act_stride, - .weight_stride = weight_stride, - .dst_stride = (int)(req->dst.nb[1] / sizeof(float)), - .ne02 = ne02, - .ne03 = ne03, - .ne12 = ne12, - .ne13 = ne13, - .src0_nb2 = req->src0.nb[2], - .src0_nb3 = req->src0.nb[3], - .src1_nb2 = req->src1.nb[2], - .src1_nb3 = req->src1.nb[3], - .dst_nb2 = req->dst.nb[2], - .dst_nb3 = req->dst.nb[3], - }; - ret = hmx_mat_mul_permuted_w16a32_batched(ctx, &batch_params); - } else { - ret = hmx_mat_mul_permuted_w16a32(ctx, dst, act, - (const __fp16 *) wgt, - m_hmx, k, n, - act_stride, - weight_stride); - } - break; - default: - ret = hmx_mat_mul_permuted_qk_0_d16a32(ctx, dst, act, - (const uint8_t *) wgt, - m_hmx, k, n, (int) req->src0.type); - break; - } + // Prep output tensor + struct htp_tensor *dst = tens + op->dst; - if (ret == 0) { - rsp_status = HTP_STATUS_OK; - } else { - FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret); - vtcm_release(ctx); - req->flags &= ~HTP_OPFLAGS_SKIP_QUANTIZE; - proc_matmul_req(ctx, req, bufs, n_bufs); - return; - } - vtcm_release(ctx); - } + octx->dst = dst; - // --- Phase 2: HVX on the remaining m_tail rows --- - if (m_tail > 0 && rsp_status == HTP_STATUS_OK) { - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; // weights: unchanged - octx.src1 = req->src1; - octx.src1.ne[1] = m_tail; // only tail rows - octx.dst = req->dst; - octx.dst.ne[1] = m_tail; // only tail rows - // Always re-quantize tail src1: HMX Phase 1 overwrites VTCM, - // so any previously cached quantized data (SKIP_QUANTIZE pipeline) - // is invalid. - octx.flags = req->flags & ~HTP_OPFLAGS_SKIP_QUANTIZE; - octx.op = req->op; - octx.n_threads = ctx->n_threads; - - // Offset activation and dst pointers past the HMX-processed rows. - // Use nb[1] (row stride in bytes) to compute the byte offset. - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t)((uint8_t *) bufs[1].ptr + (size_t) m_hmx * req->src1.nb[1]); - octx.dst.data = (uint32_t)((uint8_t *) bufs[2].ptr + (size_t) m_hmx * req->dst.nb[1]); - - FARF(HIGH, "proc_hmx_matmul: HVX tail m_tail=%d act=%p dst=%p", - m_tail, (void *)(uintptr_t) octx.src1.data, (void *)(uintptr_t) octx.dst.data); - - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - uint32_t hvx_ret = op_matmul(&octx); - vtcm_release(ctx); - if (hvx_ret != HTP_STATUS_OK) { - FARF(ERROR, "HVX tail matmul failed (ret=%u)", hvx_ret); - rsp_status = HTP_STATUS_INTERNAL_ERR; - } - } else { - rsp_status = HTP_STATUS_INTERNAL_ERR; - } - } + FARF(HIGH, "prep-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size, + dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]); + + (void) execute_op(octx); - profile_stop(&prof); + // flush buffers on output + hex_l2flush((void *) dst->data, dst->size); + dst->flags |= HTP_TENSOR_FLUSHED; - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); + FARF(HIGH, "post-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size, + dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]); } -#endif // HTP_HAS_HMX +#define DSPQUEUE_POLL_TIMEOUT_USEC 100 +#define DSPQUEUE_POLL_COUNT 100 static void htp_packet_callback(dspqueue_t queue, int error, void * context) { struct htp_context * ctx = (struct htp_context *) context; - // Repeatedly read packets from the queue until it's empty. We don't - // necessarily get a separate callback for each packet, and new packets - // may arrive while we're processing the previous one. This ensures we - // keep the DSP busy as much as possible and avoid waiting for the CPU. + int err; + + uint32_t poll_count = DSPQUEUE_POLL_COUNT; - while (1) { - struct htp_general_req req; - uint32_t req_size; + vtcm_acquire(ctx); - struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; - uint32_t n_bufs; - uint32_t flags; + while (!ctx->vtcm_needs_release) { + struct htp_opbatch_req req; + uint32_t r_size = sizeof(req); - // Read packet from queue - int err = dspqueue_read_noblock(queue, &flags, - HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references - &n_bufs, // Number of buffer references - bufs, // Buffer references - sizeof(req), // Max message length - &req_size, // Message length - (uint8_t *) &req); // Message + struct dspqueue_buffer dbuf; + uint32_t n_dbufs = 1; + uint32_t flags = 0; + err = dspqueue_read_noblock(queue, &flags, n_dbufs, &n_dbufs, &dbuf, r_size, &r_size, (uint8_t *) &req); if (err == AEE_EWOULDBLOCK) { - // Consumed all packets available for now - return; + if (--poll_count) { + qurt_sleep(DSPQUEUE_POLL_TIMEOUT_USEC); + continue; + } + break; } if (err != 0) { FARF(ERROR, "dspqueue_read_noblock failed: 0x%08x", (unsigned) err); - return; + break; } - if (req_size != sizeof(req)) { - FARF(ERROR, "Invalid request size"); + if (r_size < sizeof(req) || n_dbufs != 1) { + FARF(ERROR, "invalid request : size %u n-dbufs %u", r_size, n_dbufs); continue; } - if (req.flags & HTP_OPFLAGS_EARLY_WAKEUP) { - // Host wants early notification - dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); + const uint32_t n_bufs = req.n_bufs; + const uint32_t n_tens = req.n_tensors; + const uint32_t n_ops = req.n_ops; + + const uint32_t b_size = sizeof(struct htp_buf_desc) * n_bufs; + const uint32_t t_size = sizeof(struct htp_tensor) * n_tens; + const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops; + + if (dbuf.size < b_size + t_size + o_size) { + FARF(ERROR, "invalid opbatch memory block size %u", dbuf.size); + break; } - // Process packet based on its message type - switch (req.op) { - case HTP_OP_MUL_MAT: - if (n_bufs != 3) { - FARF(ERROR, "Bad matmul-req buffer list"); - continue; - } -#ifdef HTP_HAS_HMX - if (ctx->hmx_enabled) { - proc_hmx_matmul_req(ctx, &req, bufs, n_bufs); - } else -#endif - { - proc_matmul_req(ctx, &req, bufs, n_bufs); - } - break; - - case HTP_OP_MUL_MAT_ID: - if (n_bufs != 4) { - FARF(ERROR, "Bad matmul-id-req buffer list"); - continue; - } - proc_matmul_id_req(ctx, &req, bufs, n_bufs); - break; - - case HTP_OP_MUL: - case HTP_OP_ADD: - case HTP_OP_SUB: - case HTP_OP_DIV: - if (n_bufs != 3) { - FARF(ERROR, "Bad binary-req buffer list"); - continue; - } - proc_binary_req(ctx, &req, bufs); - break; - - case HTP_OP_RMS_NORM: - case HTP_OP_SCALE: - if (n_bufs != 2) { - FARF(ERROR, "Bad unary-req buffer list"); - continue; - } - - proc_unary_req(ctx, &req, bufs); - break; - - case HTP_OP_SQR: - case HTP_OP_SQRT: - case HTP_OP_UNARY_NEG: - case HTP_OP_UNARY_EXP: - case HTP_OP_UNARY_SIGMOID: - case HTP_OP_UNARY_SOFTPLUS: - if (n_bufs != 2) { - FARF(ERROR, "Bad unary-req buffer list"); - continue; - } - - proc_unary_req(ctx, &req, bufs); - break; - - case HTP_OP_SUM_ROWS: - if (n_bufs != 2) { - FARF(ERROR, "Bad unary-req buffer list"); - continue; - } - - proc_sum_rows_req(ctx, &req, bufs); - break; - - case HTP_OP_UNARY_SILU: - case HTP_OP_UNARY_GELU: - if (n_bufs != 2) { - FARF(ERROR, "Bad act-req buffer list"); - continue; - } - proc_activations_req(ctx, &req, bufs, n_bufs); - break; - - case HTP_OP_GLU_SWIGLU: - case HTP_OP_GLU_SWIGLU_OAI: - case HTP_OP_SOFTMAX: - case HTP_OP_GLU_GEGLU: - if ((n_bufs != 2) && (n_bufs != 3)) { - FARF(ERROR, "Bad act-req buffer list"); - continue; - } - proc_activations_req(ctx, &req, bufs, n_bufs); - break; - - case HTP_OP_ADD_ID: - if (n_bufs != 4) { - FARF(ERROR, "Bad add-id-req buffer list"); - continue; - } - proc_add_id_req(ctx, &req, bufs); - break; - - case HTP_OP_ROPE: - if ((n_bufs != 3) && (n_bufs != 4)) { - FARF(ERROR, "Bad rope-req buffer list"); - continue; - } - proc_rope_req(ctx, &req, bufs, n_bufs); - break; - - case HTP_OP_FLASH_ATTN_EXT: - if (!(n_bufs >= 4 && n_bufs <= 6)) { - FARF(ERROR, "Bad flash-attn-ext-req buffer list"); - continue; - } - proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs); - break; - - case HTP_OP_SET_ROWS: - if (n_bufs != 3) { - FARF(ERROR, "Bad set-rows-req buffer list"); - continue; - } - proc_set_rows_req(ctx, &req, bufs); - break; - - case HTP_OP_GET_ROWS: - if (n_bufs != 3) { - FARF(ERROR, "Bad get-rows-req buffer list"); - continue; - } - proc_get_rows_req(ctx, &req, bufs); - break; - - case HTP_OP_CPY: - if (n_bufs != 2) { - FARF(ERROR, "Bad cpy-req buffer list"); - continue; - } - proc_cpy_req(ctx, &req, bufs); - break; - - case HTP_OP_REPEAT: - if (n_bufs != 2) { - FARF(ERROR, "Bad repeat-req buffer list"); - continue; - } - proc_repeat_req(ctx, &req, bufs); - break; - - case HTP_OP_ARGSORT: - if (n_bufs != 2) { - FARF(ERROR, "Bad argsort-req buffer list"); - continue; - } - proc_argsort_req(ctx, &req, bufs); - break; - - case HTP_OP_SSM_CONV: - if (n_bufs != 3) { - FARF(ERROR, "Bad ssm-conv-req buffer list"); - continue; - } - proc_ssm_conv_req(ctx, &req, bufs); - break; - - case HTP_OP_CUMSUM: - if (n_bufs != 2) { - FARF(ERROR, "Bad cumsum-req buffer list"); - continue; - } - proc_cumsum_req(ctx, &req, bufs); - break; - - default: - FARF(ERROR, "Unknown Op %u", req.op); - break; + // Reset poll count for valid requests + poll_count = DSPQUEUE_POLL_COUNT; + + uint8_t * m_ptr = dbuf.ptr; + struct htp_buf_desc* bufs = (struct htp_buf_desc*) m_ptr; m_ptr += b_size; + struct htp_tensor* tens = (struct htp_tensor*) m_ptr; m_ptr += t_size; + struct htp_op_desc* ops = (struct htp_op_desc*) m_ptr; + + FARF(HIGH, "processing opbatch: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", + n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size); + + prep_op_bufs(ctx, bufs, n_bufs); + prep_tensors(ctx, bufs, tens, n_tens); + + struct htp_ops_context *octx = &ctx->octx; + memset(octx, 0, sizeof(*octx)); + octx->n_threads = ctx->n_threads; + octx->ctx = ctx; + + for (uint32_t i=0; i < n_ops; i++) { + struct profile_data prof; + profile_start(&prof); + + proc_op_req(octx, tens, i, &ops[i]); + + profile_stop(&prof); + ops[i].prof_usecs = prof.usecs; + ops[i].prof_cycles = prof.cycles; + ops[i].prof_pkts = prof.pkts; + } + + // dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); + + struct htp_opbatch_rsp rsp; + rsp.status = HTP_STATUS_OK; // FIXME + + dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; + err = dspqueue_write(queue, 0, 1, &dbuf, sizeof(rsp), (const uint8_t *) &rsp, DSPQUEUE_TIMEOUT_NONE); + if (err != 0) { + FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err); + break; } } + + vtcm_release(ctx); } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 24b7bad6876..bac06693d81 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -16,8 +16,9 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" #include "htp-ops.h" +#include "htp-ops.h" +#include "hmx-ops.h" #define MM_SPAD_SRC0_NROWS 16 #define MM_SPAD_SRC1_NROWS 16 @@ -1897,11 +1898,11 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(&s[0], 4, rsum); } -#define htp_matmul_tensors_preamble \ - struct htp_tensor * restrict src0 = &octx->src0; \ - struct htp_tensor * restrict src1 = &octx->src1; \ - struct htp_tensor * restrict src2 = &octx->src2; \ - struct htp_tensor * restrict dst = &octx->dst; \ +#define htp_matmul_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict src2 = octx->src[2]; \ + const struct htp_tensor * restrict dst = octx->dst; \ struct htp_spad * restrict src0_spad = &octx->src0_spad; \ struct htp_spad * restrict src1_spad = &octx->src1_spad; \ struct htp_spad * restrict dst_spad = &octx->dst_spad; \ @@ -2223,8 +2224,8 @@ struct mmid_row_mapping { static void matmul_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_tensor * restrict ids = &octx->src2; - struct htp_spad * restrict src2_spad = &octx->src2_spad; + const struct htp_tensor * restrict ids = octx->src[2]; + struct htp_spad * restrict src2_spad = &octx->src2_spad; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -2342,8 +2343,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { static void matvec_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_tensor * restrict ids = &octx->src2; - struct htp_spad * restrict src2_spad = &octx->src2_spad; + const struct htp_tensor * restrict ids = octx->src[2]; + struct htp_spad * restrict src2_spad = &octx->src2_spad; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -2612,7 +2613,7 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; - const struct htp_tensor * src = &octx->src1; + const struct htp_tensor * src = octx->src[1]; uint8_t * restrict dst = octx->src1_spad.data; struct htp_spad * spad = &octx->src0_spad; uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; @@ -2659,7 +2660,7 @@ static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; - const struct htp_tensor * src = &octx->src1; + const struct htp_tensor * src = octx->src[1]; uint8_t * restrict dst = octx->src1_spad.data; uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; uint32_t dst_stride = octx->src1_spad.stride; @@ -2701,7 +2702,7 @@ static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) { struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; - const struct htp_tensor * src = &octx->src1; + const struct htp_tensor * src = octx->src[1]; uint8_t * restrict dst = octx->src1_spad.data; uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; uint32_t dst_stride = octx->src1_spad.stride; @@ -2800,7 +2801,7 @@ static void htp_mminit_spad(struct htp_ops_context * octx, octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; } -int op_matmul(struct htp_ops_context * octx) { +static int op_matmul_hvx(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; struct htp_matmul_context mmctx_struct = {0}; @@ -2824,7 +2825,7 @@ int op_matmul(struct htp_ops_context * octx) { worker_callback_t quant_job_func; worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d; - bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE); + bool need_quant = true; if (src0->type == HTP_TYPE_F16) { // Try optimized f16-f16 path first (src1 in VTCM) @@ -2838,7 +2839,7 @@ int op_matmul(struct htp_ops_context * octx) { // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. const bool is_batched = (ne02 > 1) || (ne03 > 1); - const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); + const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { // Optimized path @@ -2915,32 +2916,170 @@ int op_matmul(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + // Place src1 spad first. We use it for dyn.quant and may reuse between ops + octx->src1_spad.data = octx->ctx->vtcm_base; + octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL; + octx->src0_spad.src = NULL; + octx->dst_spad.src = NULL; octx->src0_spad.stride = src0_row_size_padded; octx->src1_spad.stride = src1_row_size; - if (need_quant) { + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + return HTP_STATUS_OK; + + if (need_quant && !octx->src1_spad.src) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); - // Cache where src1 was written so subsequent SKIP_QUANTIZE ops can find it - octx->ctx->prev_src1_spad = octx->src1_spad.data; + octx->src1_spad.src = src1; + } + + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); + + return HTP_STATUS_OK; +} + +int op_matmul(struct htp_ops_context * octx) { + htp_matmul_tensors_preamble; + +#ifndef HTP_HAS_HMX + return op_matmul_hvx(octx); +#else + if (!octx->ctx->hmx_enabled) { + return op_matmul_hvx(octx); + } + + // HMX weight tile requires N to be 32-aligned. + if (src0->ne[1] % 32 != 0) { + return op_matmul_hvx(octx); + } + + // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. + // Other types fall back to HVX. + uint32_t wtype = src0->type; + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { + return op_matmul_hvx(octx); + } + + // Quantised HMX path requires K aligned to 256 (x4x2 super-block). + // F16 HMX path requires K aligned to 32 (tile width). + if (wtype != HTP_TYPE_F16 && src0->ne[0] % 256 != 0) { + return op_matmul_hvx(octx); + } + + if (wtype == HTP_TYPE_F16 && src0->ne[0] % 32 != 0) { + return op_matmul_hvx(octx); + } + + const bool is_batched = (src0->ne[2] * src0->ne[3] > 1 || src1->ne[2] * src1->ne[3] > 1); + + // Quantised HMX kernels only handle flat 2D matmul (host already rejects + // batched quantised, but guard here too). F16 batched matmul is handled + // by the dedicated wrapper in hmx-matmul-ops.c. + if (is_batched && src0->type != HTP_TYPE_F16) { + return op_matmul_hvx(octx); + } + + // HMX assumes contiguous row-major layout. Fall back for permuted + // tensors where strides are non-monotonic (e.g. transposed KV cache). + if (src0->nb[0] > src0->nb[1] || src1->nb[0] > src1->nb[1]) { + return op_matmul_hvx(octx); + } + + // M alignment: when M > 32 but not 32-aligned, we split into + // HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows). + // When M <= 32 and not 32-aligned, fall back entirely to HVX. + const int m_total = (int) src1->ne[1]; + const int m_tail = m_total % 32; + const int m_hmx = m_total - m_tail; + + if (m_hmx == 0) { + return op_matmul_hvx(octx); + } + + // Always re-quantize src1 since HMX kernel overwrites vtcm/spad, + // so any previously cached quantized data is invalid. + octx->src1_spad.src = NULL; + + int k = (int) src0->ne[0]; // inner dimension + int n = (int) src0->ne[1]; // weight columns + + // --- Phase 1: HMX on the first m_hmx (32-aligned) rows --- + int ret = -1; + + // Row strides in elements. For compact tensors these equal k; for + // permuted attention views they can be larger, so pass the real stride. + const int act_stride = (int)(src1->nb[1] / sizeof(float)); + const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16)); + + if (src0->type == HTP_TYPE_F16) { + if (is_batched) { + hmx_matmul_w16a32_batched_params_t batch_params = { + .dst = (float *) dst->data, + .activation = (float *) src1->data, + .permuted_weight = (const __fp16 *) src0->data, + .m = m_hmx, + .k = k, + .n = n, + .act_stride = act_stride, + .weight_stride = wgt_stride, + .dst_stride = (int) (dst->nb[1] / sizeof(float)), + .ne02 = ne02, + .ne03 = ne03, + .ne12 = ne12, + .ne13 = ne13, + .src0_nb2 = src0->nb[2], + .src0_nb3 = src0->nb[3], + .src1_nb2 = src1->nb[2], + .src1_nb3 = src1->nb[3], + .dst_nb2 = dst->nb[2], + .dst_nb3 = dst->nb[3], + }; + ret = hmx_mat_mul_permuted_w16a32_batched(octx->ctx, &batch_params); + } else { + ret = hmx_mat_mul_permuted_w16a32(octx->ctx, + (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, + m_hmx, k, n, act_stride, wgt_stride); + } } else { - // SKIP_QUANTIZE: Q8 data lives at the address written by the previous - // quantize pass. The current op may have a different src0 size (e.g. - // IQ4_NL vs MXFP4), so src1_spad.data computed above could be wrong. - octx->src1_spad.data = octx->ctx->prev_src1_spad; + ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx, + (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + m_hmx, k, n, (int) src0->type); } - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); + if (ret != 0) { + FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret); + return op_matmul(octx); } - return HTP_STATUS_OK; + // --- Phase 2: HVX on the remaining m_tail rows --- + if (m_tail > 0) { + // copy of src1 and dst + struct htp_tensor src1_tail = *src1; + struct htp_tensor dst_tail = *dst; + + src1_tail.ne[1] = m_tail; // only tail rows + dst_tail.ne[1] = m_tail; // only tail rows + + // Offset activation and dst pointers past the HMX-processed rows. + // Use nb[1] (row stride in bytes) to compute the byte offset. + src1_tail.data += (uint32_t) m_hmx * src1->nb[1]; + dst_tail.data += (uint32_t) m_hmx * dst->nb[1]; + + octx->src[1] = &src1_tail; + octx->dst = &dst_tail; + + FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data); + return op_matmul_hvx(octx); + } + + return 0; +#endif // HTP_HAS_HMX } int op_matmul_id(struct htp_ops_context * octx) { @@ -2950,7 +3089,7 @@ int op_matmul_id(struct htp_ops_context * octx) { struct htp_matmul_context * mmctx = &mmctx_struct; mmctx->octx = octx; - struct htp_tensor * restrict ids = &octx->src2; + const struct htp_tensor * restrict ids = octx->src[2]; const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; @@ -3003,11 +3142,17 @@ int op_matmul_id(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; + // Place src1 spad first. We use it for dyn.quant and may reuse in subseq ops. + octx->src1_spad.data = octx->ctx->vtcm_base; + octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src2_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size; + octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL; + octx->src0_spad.src = NULL; + octx->src2_spad.src = NULL; + octx->dst_spad.src = NULL; + octx->src0_spad.stride = src0_row_size_padded; octx->src1_spad.stride = src1_row_size; @@ -3031,20 +3176,18 @@ int op_matmul_id(struct htp_ops_context * octx) { } } - // Setup worker pool callbacks - if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) { + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + return HTP_STATUS_OK; + + if (octx->src1_spad.src != src1) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); - octx->ctx->prev_src1_spad = octx->src1_spad.data; - } else { - octx->src1_spad.data = octx->ctx->prev_src1_spad; + octx->src1_spad.src = src1; } - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); - } + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/repeat-ops.c b/ggml/src/ggml-hexagon/htp/repeat-ops.c index 5db06c920e2..a6f2f0ed5f3 100644 --- a/ggml/src/ggml-hexagon/htp/repeat-ops.c +++ b/ggml/src/ggml-hexagon/htp/repeat-ops.c @@ -12,7 +12,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" struct htp_repeat_context { @@ -32,8 +32,8 @@ struct htp_repeat_context { static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * data) { const struct htp_repeat_context * rctx = (const struct htp_repeat_context *) data; struct htp_ops_context * octx = rctx->octx; - const struct htp_tensor * src = &octx->src0; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src = octx->src[0]; + const struct htp_tensor * dst = octx->dst; const uint32_t ne00 = src->ne[0]; const uint32_t ne01 = src->ne[1]; @@ -98,8 +98,8 @@ static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * dat } int op_repeat(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; // Validate that dst dims are multiples of src dims if (dst->ne[0] % src0->ne[0] != 0 || diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index ecedadb0fea..1d8b0796bc9 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -15,7 +15,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" // Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h @@ -253,10 +253,10 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { struct htp_rope_context * rctx = (struct htp_rope_context *) data; struct htp_ops_context * octx = rctx->octx; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * src2 = &octx->src2; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * src2 = octx->src[2]; + const struct htp_tensor * dst = octx->dst; htp_rope_preamble; @@ -284,7 +284,7 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { dma_queue * dma_queue = octx->ctx->dma[ith]; const int32_t * pos = (const int32_t *) src1->data; - const float * freq_factors = src2->data ? (const float *) src2->data : NULL; + const float * freq_factors = src2 ? (const float *) src2->data : NULL; uint32_t ir = 0; uint32_t prev_i2 = (uint32_t) -1; @@ -384,10 +384,10 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { static int execute_op_rope_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * src2 = &octx->src2; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * src2 = octx->src[2]; + const struct htp_tensor * dst = octx->dst; const char * op_type = "rope-f32"; @@ -424,19 +424,16 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - // Assign sizes octx->src0_spad.size_per_thread = src0_spad_per_thread; octx->dst_spad.size_per_thread = dst_spad_per_thread; octx->src0_spad.size = n_threads * src0_spad_per_thread; octx->dst_spad.size = n_threads * dst_spad_per_thread; octx->src1_spad.size = 0; - // Assign pointers - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = NULL; - octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = NULL; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; - // Fill context struct htp_rope_context rctx; memset(&rctx, 0, sizeof(struct htp_rope_context)); @@ -483,7 +480,7 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { int op_rope(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_rope_f32(octx); break; diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index 4b6967749f8..0def7b408bf 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -14,33 +14,37 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" -#define set_rows_preamble \ - const uint32_t ne00 = octx->src0.ne[0]; \ - const uint32_t ne01 = octx->src0.ne[1]; \ - const uint32_t ne02 = octx->src0.ne[2]; \ - const uint32_t ne03 = octx->src0.ne[3]; \ - \ - const uint32_t ne10 = octx->src1.ne[0]; \ - const uint32_t ne11 = octx->src1.ne[1]; \ - const uint32_t ne12 = octx->src1.ne[2]; \ - \ - const uint32_t nb01 = octx->src0.nb[1]; \ - const uint32_t nb02 = octx->src0.nb[2]; \ - const uint32_t nb03 = octx->src0.nb[3]; \ - \ - const uint32_t nb10 = octx->src1.nb[0]; \ - const uint32_t nb11 = octx->src1.nb[1]; \ - const uint32_t nb12 = octx->src1.nb[2]; \ - \ - const uint32_t nb1 = octx->dst.nb[1]; \ - const uint32_t nb2 = octx->dst.nb[2]; \ - const uint32_t nb3 = octx->dst.nb[3]; \ - \ - const uint32_t ne1 = octx->dst.ne[1]; \ - \ +#define set_rows_preamble \ + const uint32_t ne00 = octx->src[0]->ne[0]; \ + const uint32_t ne01 = octx->src[0]->ne[1]; \ + const uint32_t ne02 = octx->src[0]->ne[2]; \ + const uint32_t ne03 = octx->src[0]->ne[3]; \ + \ + const uint32_t ne10 = octx->src[1]->ne[0]; \ + const uint32_t ne11 = octx->src[1]->ne[1]; \ + const uint32_t ne12 = octx->src[1]->ne[2]; \ + const uint32_t ne13 = octx->src[1]->ne[3]; \ + \ + const uint32_t nb01 = octx->src[0]->nb[1]; \ + const uint32_t nb02 = octx->src[0]->nb[2]; \ + const uint32_t nb03 = octx->src[0]->nb[3]; \ + \ + const uint32_t nb10 = octx->src[1]->nb[0]; \ + const uint32_t nb11 = octx->src[1]->nb[1]; \ + const uint32_t nb12 = octx->src[1]->nb[2]; \ + \ + const uint32_t nb1 = octx->dst->nb[1]; \ + const uint32_t nb2 = octx->dst->nb[2]; \ + const uint32_t nb3 = octx->dst->nb[3]; \ + \ + const uint32_t ne0 = octx->dst->ne[0]; \ + const uint32_t ne1 = octx->dst->ne[1]; \ + const uint32_t ne2 = octx->dst->ne[2]; \ + const uint32_t ne3 = octx->dst->ne[3]; \ + \ const uint32_t nr = ne01; struct htp_set_rows_context { @@ -56,12 +60,14 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da set_rows_preamble; + uint64_t qt = HAP_perf_get_qtimer_count(); + // parallelize by rows of src0 const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; - const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); for (uint32_t i03 = 0; i03 < ne03; ++i03) { for (uint32_t i02 = 0; i02 < ne02; ++i02) { @@ -70,7 +76,7 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11); const uint32_t i10 = i; - const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; if (i1 >= ne1) { @@ -78,14 +84,18 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da continue; } - const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; - const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + const uintptr_t src0_ptr = octx->src[0]->data + i*nb01 + i02*nb02 + i03*nb03; + const uintptr_t dst_ptr = octx->dst->data + i1*nb1 + i02*nb2 + i03*nb3; // copy row hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); } } } + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "set-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) { @@ -94,12 +104,14 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da set_rows_preamble; + uint64_t qt = HAP_perf_get_qtimer_count(); + // parallelize by rows of src0 const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; - const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); for (uint32_t i03 = 0; i03 < ne03; ++i03) { for (uint32_t i02 = 0; i02 < ne02; ++i02) { @@ -108,7 +120,7 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11); const uint32_t i10 = i; - const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; if (i1 >= ne1) { @@ -116,13 +128,17 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da continue; } - const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; - uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + const uint8_t* src0_ptr = (const uint8_t *) octx->src[0]->data + i*nb01 + i02*nb02 + i03*nb03; + uint8_t* dst_ptr = (uint8_t *) octx->dst->data + i1*nb1 + i02*nb2 + i03*nb3; hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00); } } } + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "set-rows-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } int op_set_rows(struct htp_ops_context * octx) { @@ -130,15 +146,15 @@ int op_set_rows(struct htp_ops_context * octx) { const uint32_t n_threads = MIN(nr, octx->n_threads); - if (octx->src0.type != HTP_TYPE_F32) { + if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) { + if (octx->dst->type != HTP_TYPE_F32 && octx->dst->type != HTP_TYPE_F16) { return HTP_STATUS_NO_SUPPORT; } - if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + if (octx->src[1]->type != HTP_TYPE_I32 && octx->src[1]->type != HTP_TYPE_I64) { return HTP_STATUS_NO_SUPPORT; } @@ -153,7 +169,7 @@ int op_set_rows(struct htp_ops_context * octx) { srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; - switch(octx->dst.type) { + switch(octx->dst->type) { case HTP_TYPE_F32: worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads); break; diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index d6356b9506f..d78bcc0eb24 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -15,68 +15,89 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" -#define htp_softmax_preamble3 \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t ne10 = (src1->ne[0]) ? src1->ne[0] : 1; \ - const uint32_t ne11 = (src1->ne[0]) ? src1->ne[1] : 1; \ - const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; \ - const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; \ - \ - const uint32_t nb10 = (src1->ne[0]) ? src1->nb[0] : 1; \ - const uint32_t nb11 = (src1->ne[0]) ? src1->nb[1] : 1; \ - const uint32_t nb12 = (src1->ne[0]) ? src1->nb[2] : 1; \ - const uint32_t nb13 = (src1->ne[0]) ? src1->nb[3] : 1; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#define htp_softmax_preamble3 \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne10 = src1 ? src1->ne[0] : 1; \ + const uint32_t ne11 = src1 ? src1->ne[1] : 1; \ + const uint32_t ne12 = src1 ? src1->ne[2] : 1; \ + const uint32_t ne13 = src1 ? src1->ne[3] : 1; \ + \ + const uint32_t nb10 = src1 ? src1->nb[0] : 1; \ + const uint32_t nb11 = src1 ? src1->nb[1] : 1; \ + const uint32_t nb12 = src1 ? src1->nb[2] : 1; \ + const uint32_t nb13 = src1 ? src1->nb[3] : 1; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; struct htp_softmax_context { + struct htp_ops_context * octx; + bool use_f16; bool use_src1; + uint32_t n_head; uint32_t n_head_log2; - float scale; - float max_bias; - float m0; - float m1; + float scale; + float max_bias; + float m0; + float m1; - uint32_t src0_nrows_per_thread; struct fastdiv_values fastdiv_ne01; struct fastdiv_values fastdiv_ne02; struct fastdiv_values fastdiv_ne12; // For mask broadcasting struct fastdiv_values fastdiv_ne13; // For mask broadcasting - size_t spad_stride; - struct htp_ops_context * octx; + uint32_t src0_nrows_per_thread; }; +static void apply_mask(float * restrict wp0, + const float * restrict mp_f32, + const __fp16 * restrict mp_f16, + uint32_t ne00, + float slope, + bool use_f16) { + if (!mp_f32) { + return; + } + if (use_f16) { + for (uint32_t i = 0; i < ne00; ++i) { + wp0[i] += slope * (float) mp_f16[i]; + } + } else { + for (uint32_t i = 0; i < ne00; ++i) { + wp0[i] += slope * mp_f32[i]; + } + } +} + static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; memset(smctx, 0, sizeof(struct htp_softmax_context)); - memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float)); + memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float)); memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float)); smctx->n_head = src0->ne[2]; @@ -85,8 +106,8 @@ static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_ smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2); smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2); - smctx->use_src1 = (src1->ne[0] != 0); - smctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16); + smctx->use_src1 = (src1 != 0); + smctx->use_f16 = (src1 != 0) && (src1->type == HTP_TYPE_F16); smctx->octx = octx; @@ -97,8 +118,8 @@ static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_ if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01); if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02); - const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; - const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; + const uint32_t ne12 = src1 ? src1->ne[2] : 1; + const uint32_t ne13 = src1 ? src1->ne[3] : 1; if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12); if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13); @@ -139,10 +160,7 @@ static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src, } } -static void hvx_fast_softmax_f32(const uint8_t * restrict src, - uint8_t * restrict dst, - uint8_t * restrict pad, - const int num_elems) { +static void hvx_fast_softmax_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict pad, const int num_elems) { const HVX_Vector * restrict v_src = (HVX_Vector *) src; HVX_Vector * restrict v_pad = (HVX_Vector *) pad; HVX_Vector * restrict v_dst = (HVX_Vector *) dst; @@ -188,27 +206,20 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, } } -static float hvx_softmax_f32(const uint8_t * restrict src, - uint8_t * restrict dst, - uint8_t * restrict spad, - const int num_elems, - const float max) { +static float hvx_softmax_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict spad, const int num_elems, const float max) { hvx_sub_scalar_f32(spad, src, max, num_elems); hvx_exp_f32(dst, spad, num_elems, false); - - float sum = hvx_reduce_sum_f32(dst, num_elems); - - return sum; + return hvx_reduce_sum_f32(dst, num_elems); } static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { struct htp_softmax_context * smctx = (struct htp_softmax_context *) data; struct htp_ops_context * octx = smctx->octx; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; htp_softmax_preamble3; @@ -223,22 +234,26 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { return; } - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + uint64_t qt = HAP_perf_get_qtimer_count(); int is_aligned = 1; int opt_path = 0; + if (!hex_is_aligned((void *) src0->data, VLEN) || !hex_is_aligned((void *) dst->data, VLEN)) { is_aligned = 0; FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n"); } + + // Only use the fast path when aligned AND row size is multiple of VLEN (128 bytes) + // The fast path (hvx_fast_softmax_f32) doesn't handle tail elements + // The non-opt path uses hvx_softmax_f32 which properly handles all sizes via its helper functions if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { opt_path = 1; } - uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride); - uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride); - uint8_t * dst_spad_data = octx->dst_spad.data + (ith * smctx->spad_stride); + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); float * wp0 = (float *) src0_spad_data; float * wp1 = (float *) src1_spad_data; @@ -278,47 +293,29 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { // ALiBi if (i2 != prev_i2) { const uint32_t h = i2; // head - - slope = (smctx->max_bias > 0.0f) ? - h < smctx->n_head_log2 ? - powf(smctx->m0, h + 1) : - powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) : - 1.0f; + slope = (smctx->max_bias > 0.0f) ? h < smctx->n_head_log2 ? powf(smctx->m0, h + 1) : powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) : 1.0f; prev_i2 = i2; } - float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03); - float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3); + float * sp = (float *) ((char *) src0->data + i1 * nb01 + i2 * nb02 + i3 * nb03); + float * dp = (float *) ((char *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3); // broadcast the mask across rows - __fp16 * mp_f16 = (smctx->use_src1) ? - (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : - NULL; - float * mp_f32 = (smctx->use_src1) ? - (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : - NULL; + __fp16 * mp_f16 = (smctx->use_src1) ? (__fp16 *) ((char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13) : NULL; + float * mp_f32 = (smctx->use_src1) ? (float *) ((char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13) : NULL; if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) { - hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale, - (const uint8_t *) mp_f32, slope); - } else { + hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale, (const uint8_t *) mp_f32, slope); + hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); + } else if (1 == opt_path) { hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale); - if (mp_f32) { - if (smctx->use_f16) { - for (int i = 0; i < ne00; ++i) { - wp0[i] += slope * (float) mp_f16[i]; - } - } else { - for (int i = 0; i < ne00; ++i) { - wp0[i] += slope * mp_f32[i]; - } - } - } - } - - if (1 == opt_path) { + apply_mask(wp0, mp_f32, mp_f16, ne00, slope, smctx->use_f16); hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); } else { + // Non-optimized path: uses HVX helper functions that properly handle all tensor sizes + // including non-multiples of 32 (the HVX vector lane count for f32) + hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale); + apply_mask(wp0, mp_f32, mp_f16, ne00, slope, smctx->use_f16); float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00); float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); sum = sum > 0.0 ? (1.0 / sum) : 1; @@ -326,54 +323,47 @@ static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { } } - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, - smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, - ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "softmax-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u : opt %u f16 %u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, + ne0, ne1, ne2, ne3, opt_path, smctx->use_f16, (unsigned) qt); } static int execute_op_softmax_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; struct htp_softmax_context smctx; const char * op_type = "softmax-f32"; - switch (octx->op) { - case HTP_OP_SOFTMAX: - init_softmax_ctx(&smctx, octx); - break; - - default: - FARF(ERROR, "Unsupported Op %u\n", octx->op); - return HTP_STATUS_NO_SUPPORT; - } + init_softmax_ctx(&smctx, octx); const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); + smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + const size_t src0_row_size = src0->nb[1]; const size_t src1_row_size = src0_row_size; const size_t dst_row_size = dst->nb[1]; // VTCM scratchpads for all tensors - // N rows per thread, padded to HVX vector size - octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; + // 4 rows per thread, padded to HVX vector size + octx->src0_spad.size_per_thread = hex_round_up(4 * src0_row_size, 128); + octx->src1_spad.size_per_thread = hex_round_up(4 * src1_row_size, 128); + octx->dst_spad.size_per_thread = hex_round_up(4 * dst_row_size, 128); - // Use stride for calculating offset - smctx.spad_stride = hex_round_up(src0_row_size, 128); + octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads; size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; - if (src1->ne[0]) { - FARF(HIGH, - "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", + if (src1) { + FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); @@ -385,19 +375,17 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, - spad_size); + FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, spad_size); return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; - worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads); - } + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) return err; + + worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads); return err; } @@ -405,7 +393,7 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { int op_softmax(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_softmax_f32(octx); break; diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index 6b035810d57..a28fd03e978 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -16,14 +16,14 @@ #include "ggml-common.h" #include "htp-ctx.h" #include "hex-dma.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #include "hvx-utils.h" -#define htp_ssm_conv_tensors_preamble \ - struct htp_tensor * restrict src0 = &octx->src0; \ - struct htp_tensor * restrict src1 = &octx->src1; \ - struct htp_tensor * restrict dst = &octx->dst; \ +#define htp_ssm_conv_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict dst = octx->dst; \ struct htp_spad * restrict src0_spad = &octx->src0_spad; \ struct htp_spad * restrict src1_spad = &octx->src1_spad; \ struct htp_spad * restrict dst_spad = &octx->dst_spad; \ @@ -289,9 +289,9 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) { // Compute gather scratchpad size for src0 and src1 const size_t gather_spad_size = n_threads * VLEN * 2; - octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n", gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, @@ -323,8 +323,9 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) { } int op_ssm_conv(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; switch (dst->type) { case HTP_TYPE_F32: diff --git a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c index 352650b689b..874c41ab2ac 100644 --- a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c @@ -14,13 +14,13 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" -#define sum_rows_preamble \ - struct htp_tensor *src0 = &octx->src0;\ - struct htp_tensor *dst = &octx->dst; \ - \ +#define sum_rows_preamble \ + const struct htp_tensor *src0 = octx->src[0]; \ + const struct htp_tensor *dst = octx->dst; \ + \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ const uint32_t ne02 = src0->ne[2]; \ @@ -94,7 +94,7 @@ static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) int op_sum_rows(struct htp_ops_context * octx) { sum_rows_preamble; - if (octx->src0.type != HTP_TYPE_F32) { + if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 13d28317d5c..03eccfd55e3 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -16,7 +16,7 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" struct htp_unary_context { @@ -267,8 +267,8 @@ static void softplus_f32(const float * restrict src, static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { const struct htp_unary_context * uctx = (const struct htp_unary_context *) data; struct htp_ops_context * octx = uctx->octx; - const struct htp_tensor * src = &octx->src0; - const struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src = octx->src[0]; + const struct htp_tensor * dst = octx->dst; htp_unary_preamble; @@ -387,8 +387,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * static int execute_op_unary_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; const char * op_type = NULL; @@ -490,7 +490,7 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { int op_unary(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_unary_f32(octx); break; From 3af7c879bc3317337fb46f5b00ca1702243b8a56 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 11 Apr 2026 10:30:30 +0800 Subject: [PATCH 416/831] CUDA: also store node->src ne/nb for graph equality (llama/21736) --- ggml/src/ggml-cuda/common.cuh | 4 +++- ggml/src/ggml-cuda/ggml-cuda.cu | 12 +++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 56a67f1edc8..8a4246223b5 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1185,7 +1185,9 @@ struct ggml_cuda_graph { bool warmup_complete = false; struct node_properties { ggml_tensor node; - void * node_src_data_ptrs[GGML_MAX_SRC]; + void * node_src_data_ptrs[GGML_MAX_SRC]; + int64_t node_src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; + size_t node_src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; }; std::vector node_props; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 8613d20b9f9..3113de017f0 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3070,16 +3070,18 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx ggml_cuda_graph::node_properties prop = {}; memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor)); - // if the backend scheduler is making copies of CPU tensors, the src pointers can be the same but with different data, see: - // https://github.com/ggml-org/llama.cpp/pull/21472#discussion_r3052235188 for (int j = 0; j < GGML_MAX_SRC; ++j) { - prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j] ? cgraph->nodes[i]->src[j]->data : nullptr; + if (cgraph->nodes[i]->src[j]) { + prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j]->data; + memcpy(prop.node_src_ne[j], cgraph->nodes[i]->src[j]->ne, sizeof(prop.node_src_ne[j])); + memcpy(prop.node_src_nb[j], cgraph->nodes[i]->src[j]->nb, sizeof(prop.node_src_nb[j])); + } } - if (!res && memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) { + if (res || memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) { + graph->node_props[i] = prop; res = true; } - graph->node_props[i] = prop; } return res; From 34381b01c44c2fc0c40a00e2086fd6318bd7f570 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sat, 11 Apr 2026 08:45:00 +0200 Subject: [PATCH 417/831] ggml : fix a few instances of missing GGML_TYPE_Q1_0 cases (llama/21716) --- ggml/src/ggml-cpu/ops.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 0b5d6c6df88..a9bc21da6f0 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -664,6 +664,7 @@ void ggml_compute_forward_add( { ggml_compute_forward_add_non_quantized(params, dst); } break; + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1113,6 +1114,7 @@ void ggml_compute_forward_add1( GGML_ABORT("fatal error"); } } break; + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1242,6 +1244,7 @@ void ggml_compute_forward_acc( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4331,6 +4334,7 @@ void ggml_compute_forward_out_prod( const ggml_tensor * src0 = dst->src[0]; switch (src0->type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4606,6 +4610,7 @@ void ggml_compute_forward_set( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: From e0c8e505e995a3936998ca36feacaf0cf1950133 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Sat, 11 Apr 2026 01:46:19 -0700 Subject: [PATCH 418/831] opencl: add basic support for q5_k (llama/21593) * opencl: add general q5_k mv * opencl: add flattened Q5_K mv and general Q5_K mm * opencl: fix Q5_K unit tests --- ggml/src/ggml-opencl/CMakeLists.txt | 3 + ggml/src/ggml-opencl/ggml-opencl.cpp | 384 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 76 ++++ .../kernels/mul_mm_q5_k_f32_l4_lm.cl | 192 +++++++++ .../ggml-opencl/kernels/mul_mv_q5_k_f32.cl | 187 +++++++++ .../kernels/mul_mv_q5_k_f32_flat.cl | 203 +++++++++ 6 files changed, 1043 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 540942b195d..112c2afe821 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -90,6 +90,8 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_1_f32_flat mul_mv_q4_k_f32 mul_mv_q4_k_f32_flat + mul_mv_q5_k_f32 + mul_mv_q5_k_f32_flat mul_mv_q6_k_f32 mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 @@ -109,6 +111,7 @@ set(GGML_OPENCL_KERNELS mul_mm_q4_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm mul_mm_q4_k_f32_l4_lm + mul_mm_q5_k_f32_l4_lm mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 gemv_noshuffle_q4_1_f32 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index f1a28a7f4cd..a581402300a 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -541,12 +541,15 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_K_noshuffle; cl_kernel kernel_restore_block_q4_K_noshuffle; cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K; + cl_kernel kernel_convert_block_q5_K, kernel_restore_block_q5_K; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q4_1_f32; cl_kernel kernel_mul_mv_q4_1_f32_flat; cl_kernel kernel_mul_mv_q4_K_f32; cl_kernel kernel_mul_mv_q4_K_f32_flat; + cl_kernel kernel_mul_mv_q5_K_f32; + cl_kernel kernel_mul_mv_q5_K_f32_flat; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_mul_mv_q6_K_f32_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; @@ -587,6 +590,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_q4_1_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; cl_kernel kernel_mul_mm_q4_k_f32_l4_lm; + cl_kernel kernel_mul_mm_q5_k_f32_l4_lm; cl_kernel kernel_mul_mm_q6_k_f32_l4_lm; std::vector profiling_info; @@ -938,6 +942,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_K", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err)); @@ -1249,6 +1255,39 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_q5_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_k_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q5_K_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q5_k_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_k_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_k_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q5_K_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + } + // mul_mv_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1556,6 +1595,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q5_k_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q5_k_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q5_k_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q5_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q5_k_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_f16_f32_kq_kqv { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3530,6 +3586,58 @@ struct ggml_tensor_extra_cl_q4_K { } }; +struct ggml_tensor_extra_cl_q5_K { + // Lower 4 bits of quantized weights. + cl_mem q = nullptr; + // Upper 1 bit of quantized weights. + cl_mem qh = nullptr; + // Scales for each block. + cl_mem s = nullptr; + // Scales for each super block. + cl_mem d = nullptr; + // Min for each super block. + cl_mem dm = nullptr; + + size_t size_q = 0; + size_t size_qh = 0; + size_t size_s = 0; + size_t size_d = 0; + size_t size_dm = 0; + + ~ggml_tensor_extra_cl_q5_K() { + reset(); + } + + void reset() { + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (qh != nullptr) { + CL_CHECK(clReleaseMemObject(qh)); + qh = nullptr; + } + if (s != nullptr) { + CL_CHECK(clReleaseMemObject(s)); + s = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (dm != nullptr) { + CL_CHECK(clReleaseMemObject(dm)); + dm = nullptr; + } + + size_q = 0; + size_qh = 0; + size_s = 0; + size_d = 0; + size_dm = 0; + } +}; + struct ggml_tensor_extra_cl_q6_K { // Lower 4 bits of quantized weights. cl_mem ql = nullptr; @@ -3945,6 +4053,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_MXFP4 || op->src[0]->type == GGML_TYPE_Q4_K || + op->src[0]->type == GGML_TYPE_Q5_K || op->src[0]->type == GGML_TYPE_Q6_K) { return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } else if (op->src[0]->type == GGML_TYPE_Q8_0) { @@ -4153,6 +4262,12 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) { delete e; } + for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K) { + delete e; + } + for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K_in_use) { + delete e; + } } ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() { @@ -4245,6 +4360,21 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_q5_K * ggml_opencl_alloc_temp_tensor_extra_q5_K() { + ggml_tensor_extra_cl_q5_K * extra; + if (temp_tensor_extras_q5_K.empty()) { + extra = new ggml_tensor_extra_cl_q5_K(); + } else { + extra = temp_tensor_extras_q5_K.back(); + temp_tensor_extras_q5_K.pop_back(); + } + + temp_tensor_extras_q5_K_in_use.push_back(extra); + + extra->reset(); + return extra; + } + ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() { ggml_tensor_extra_cl_q6_K * extra; if (temp_tensor_extras_q6_K.empty()) { @@ -4291,6 +4421,11 @@ struct ggml_backend_opencl_buffer_context { } temp_tensor_extras_q4_K_in_use.clear(); + for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K_in_use) { + temp_tensor_extras_q5_K.push_back(e); + } + temp_tensor_extras_q5_K_in_use.clear(); + for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) { temp_tensor_extras_q6_K.push_back(e); } @@ -4314,6 +4449,8 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_q8_0_in_use; std::vector temp_tensor_extras_q4_K; std::vector temp_tensor_extras_q4_K_in_use; + std::vector temp_tensor_extras_q5_K; + std::vector temp_tensor_extras_q5_K_in_use; std::vector temp_tensor_extras_q6_K; std::vector temp_tensor_extras_q6_K_in_use; @@ -5152,6 +5289,97 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } + if (tensor->type == GGML_TYPE_Q5_K) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q5_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q5_K(); + + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/8; + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(3*ggml_blck_size(tensor->type)/64); + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + GGML_ASSERT(size_q + size_qh + size_s + size_d + size_dm == ggml_nbytes(tensor) && + "Incorrect tensor size"); + + cl_int err; + cl_mem data_device; + CL_CHECK((data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err), err)); + CL_CHECK(clEnqueueWriteBuffer(queue, data_device, CL_TRUE, 0, ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // Create subbuffer for d. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for dm. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_dm; + extra->dm = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for s. + region.origin = align_to(previous_origin + size_dm, backend_ctx->alignment); + region.size = size_s; + extra->s = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for q (lower 4 bits) + region.origin = align_to(previous_origin + size_s, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for qh (upper 1 bit) + region.origin = align_to(previous_origin + size_q, backend_ctx->alignment); + region.size = size_qh; + CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra->dm)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + extra->size_q = size_q; + extra->size_qh = size_qh; + extra->size_s = size_s; + extra->size_d = size_d; + extra->size_dm = size_dm; + + tensor->extra = extra; + return; + } if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -5658,6 +5886,35 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + if (tensor->type == GGML_TYPE_Q5_K) { + ggml_tensor_extra_cl_q5_K * extra = (ggml_tensor_extra_cl_q5_K *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_K; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; @@ -10221,6 +10478,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra; ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; #endif @@ -10925,6 +11183,51 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q5_K: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q5_k_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } case GGML_TYPE_Q6_K: { if (ne11 < 32) { break; @@ -11442,7 +11745,81 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #endif // GGML_OPENCL_SOA_Q break; } - case GGML_TYPE_Q5_K: + case GGML_TYPE_Q5_K: { +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_q5_K_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 2; + ndst = 16; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &r3)); +#else + kernel = backend_ctx->kernel_mul_mv_q5_K_f32; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } case GGML_TYPE_Q6_K: #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_q6_K_f32_flat; @@ -11610,7 +11987,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } else if (src0t == GGML_TYPE_Q3_K) { GGML_ASSERT(false && "not implemented"); } else if (src0t == GGML_TYPE_Q5_K) { - GGML_ASSERT(false && "not implemented"); + size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } else if (src0t == GGML_TYPE_Q6_K) { size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 81fe17fa10f..1bd83d29b3d 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -66,6 +66,17 @@ struct block_q4_K { uchar q[QK_K / 2]; // nibbles / quants }; +//------------------------------------------------------------------------------ +// block_q5_k +//------------------------------------------------------------------------------ +struct block_q5_K { + half d; // delta + half dm; // min + uchar s[K_SCALE_SIZE]; + uchar qh[QK_K / 8]; + uchar qs[QK_K / 2]; // nibbles / quants +}; + //------------------------------------------------------------------------------ // block_q6_K //------------------------------------------------------------------------------ @@ -546,6 +557,71 @@ kernel void kernel_restore_block_q4_K_noshuffle( } } +//------------------------------------------------------------------------------ +// kernel_convert_block_q5_K +// Convert the block_q5_K format to 5 separate arrays (AOS -> SOA). +// Each thread processes a super block. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q5_K( + global struct block_q5_K * src0, + global uchar * dst_q, + global uchar * dst_qh, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm +) { + global struct block_q5_K * b = (global struct block_q5_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/8*get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K/2; ++i) { + q[i] = b->qs[i]; + } + for (int i = 0; i < QK_K/8; ++i) { + qh[i] = b->qh[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +// Restore block_q5_K from flattened arrays. +// Each thread processes a super block. +kernel void kernel_restore_block_q5_K( + global uchar * src_q, + global uchar * src_qh, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q5_K * dst +) { + global struct block_q5_K * b = (global struct block_q5_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) src_qh + QK_K/8*get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K/2; ++i) { + b->qs[i] = q[i]; + } + for (int i = 0; i < QK_K/8; ++i) { + b->qh[i] = qh[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + //------------------------------------------------------------------------------ // kernel_convert_block_q6_K // Convert the block_q6_K format to 3 separate arrays (AOS -> SOA). diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl new file mode 100644 index 00000000000..8e191f57e83 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl @@ -0,0 +1,192 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 4 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q5_k_f32_l4_lm( + global uchar4 * src0_q, + global uchar * src0_qh, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 64; + int iqs = (idx % 64) * 2; + + int n = iqs / 32; + int b = (iqs % 32) / 16; + int is = 2 * n + b; + int qsi = n * 32 + (iqs % 16) * 2; + + global uchar * scales = src0_s + ib * 12; + + int scidx0 = (is < 4) ? is : (is + 4); + int scidx1 = (is < 4) ? is : (is - 4); + int scidxmask1 = (is < 4) ? 0x30 : 0xC0; + int scidxshift1 = (is < 4) ? 0 : 2; + int mbidx0 = is + 4; + int mbidx1 = (is < 4) ? is + 4 : is; + int mbidxmask0 = (is < 4) ? 0xF : 0xF0; + int mbidxshift0 = (is < 4) ? 0 : 4; + int mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + int mbidxshift1 = (is < 4) ? 0 : 2; + + uchar sc = (scales[scidx0] & 0xF) | ((scales[scidx1] & scidxmask1) >> scidxshift1); + uchar mbyte = ((scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((scales[mbidx1] & mbidxmask1) >> mbidxshift1); + + float d = (float)src0_d[ib] * (float)sc; + float m = -(float)src0_dm[ib] * (float)mbyte; + + int qh_base = (iqs % 16) * 2; + int bit_pos = 2*n + b; + uchar h0 = (src0_qh[ib*32 + qh_base + 0] >> bit_pos) & 1; + uchar h1 = (src0_qh[ib*32 + qh_base + 1] >> bit_pos) & 1; + uchar h2 = (src0_qh[ib*32 + qh_base + 2] >> bit_pos) & 1; + uchar h3 = (src0_qh[ib*32 + qh_base + 3] >> bit_pos) & 1; + + global uchar4 * qs = src0_q + ib*32 + (qsi >> 2); + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)( + ((q.s0 >> (b * 4))&0x0F) | (h0 << 4), + ((q.s1 >> (b * 4))&0x0F) | (h1 << 4), + ((q.s2 >> (b * 4))&0x0F) | (h2 << 4), + ((q.s3 >> (b * 4))&0x0F) | (h3 << 4) + )))*d + m; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v1.s3; + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl new file mode 100644 index 00000000000..b2058abc1b6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl @@ -0,0 +1,187 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK_K 256 +#define K_SCALE_SIZE 12 + +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qh[QK_K/8]; // quants, high bit (1 bit per value, packed 8 per byte) + uchar qs[QK_K/2]; // quants, low 4 bits (2 values per byte) +} block_q5_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 16 +#elif defined(ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_K_f32( + global char * src0, + int offset0, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; // super block index + int it = get_sub_group_local_id()%8; // block index (inside super block) + int iq = it/4; // 0 or 1 - first or second half of the super block + int ir = it%4; // 0...3 - block index in the half super block + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global block_q5_K * x = (global block_q5_K *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + uchar u1_lo = (uchar)(1 << (2*iq)); + uchar u2_lo = (uchar)(2 << (2*iq)); + uchar u1_hi = (uchar)(1 << (2*iq + 4)); + uchar u2_hi = (uchar)(2 << (2*iq + 4)); + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * sc = (global ushort *)x[ib].scales + iq; + global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir; + global uchar * qh = x[ib].qh + 8 * ir; + global half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * ((q1[i/2] & 0x000F) + (qh[i+0] & u1_lo ? 16.f : 0.f)); + acc1.s1 += yl[i+1] * ((q1[i/2] & 0x0F00) + (qh[i+1] & u1_lo ? 16.f*256.f : 0.f)); + acc1.s2 += yl[i+8] * ((q1[i/2] & 0x00F0) + (qh[i+0] & u2_lo ? 16.f*16.f : 0.f)); + acc1.s3 += yl[i+9] * ((q1[i/2] & 0xF000) + (qh[i+1] & u2_lo ? 16.f*4096.f: 0.f)); + acc2.s0 += yh[i+0] * ((q2[i/2] & 0x000F) + (qh[i+0] & u1_hi ? 16.f : 0.f)); + acc2.s1 += yh[i+1] * ((q2[i/2] & 0x0F00) + (qh[i+1] & u1_hi ? 16.f*256.f : 0.f)); + acc2.s2 += yh[i+8] * ((q2[i/2] & 0x00F0) + (qh[i+0] & u2_hi ? 16.f*16.f : 0.f)); + acc2.s3 += yh[i+9] * ((q2[i/2] & 0xF000) + (qh[i+1] & u2_hi ? 16.f*4096.f: 0.f)); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += nb01/2; + sc += nb01/2; + dh += nb01/2; + qh += nb01; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl new file mode 100644 index 00000000000..e353a72be70 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl @@ -0,0 +1,203 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// block_q5_K +//------------------------------------------------------------------------------ +#define QK_K 256 +#define BLOCK_Q5K_SIZE 176 +#define K_SCALE_SIZE 12 + +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qh[QK_K/8]; // quants, high bit (1 bit per value, packed 8 per byte) + uchar qs[QK_K/2]; // quants, low 4 bits (2 values per byte) +} block_q5_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 16 +#elif defined(ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#undef BLOCK_STRIDE +// number of (super) blocks each subgroup processes +// each thread in a subgroup processes a block (32 weights) +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_K_f32_flat( + global uchar * src0_q, + global uchar * src0_qh, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; + int it = get_sub_group_local_id()%8; + int iq = it/4; + int ir = it%4; + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = (first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03)/BLOCK_Q5K_SIZE; + uint blk = nb01 / BLOCK_Q5K_SIZE; + global uchar * blk_q = (global uchar *)src0_q + offset_src0*(QK_K/2); + global uchar * blk_qh = (global uchar *)src0_qh + offset_src0*(QK_K/8); + global uchar * blk_s = (global uchar *)src0_s + offset_src0*K_SCALE_SIZE; + global half * blk_d = (global half *)src0_d + offset_src0; + global half * blk_dm = (global half *)src0_dm + offset_src0; + + int offset_src1 = r1*nb11 + (i12)*nb12 + (i13)*nb13; + global float * y = (global float *)(src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + uchar u1_lo = (uchar)(1 << (2*iq)); + uchar u2_lo = (uchar)(2 << (2*iq)); + uchar u1_hi = (uchar)(1 << (2*iq + 4)); + uchar u2_hi = (uchar)(2 << (2*iq + 4)); + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * q1 = (global ushort *)(blk_q + ib * (QK_K/2)) + (16 * iq + 4 * ir); + global uchar * qh = (global uchar *)(blk_qh + ib * (QK_K/8)) + 8 * ir; + global ushort * sc = (global ushort *)(blk_s + ib * K_SCALE_SIZE) + iq; + global half * d = blk_d + ib; + global half * dm = blk_dm + ib; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * ((q1[i/2] & 0x000F) + (qh[i+0] & u1_lo ? 16.f : 0.f)); + acc1.s1 += yl[i+1] * ((q1[i/2] & 0x0F00) + (qh[i+1] & u1_lo ? 16.f*256.f : 0.f)); + acc1.s2 += yl[i+8] * ((q1[i/2] & 0x00F0) + (qh[i+0] & u2_lo ? 16.f*16.f : 0.f)); + acc1.s3 += yl[i+9] * ((q1[i/2] & 0xF000) + (qh[i+1] & u2_lo ? 16.f*4096.f: 0.f)); + acc2.s0 += yh[i+0] * ((q2[i/2] & 0x000F) + (qh[i+0] & u1_hi ? 16.f : 0.f)); + acc2.s1 += yh[i+1] * ((q2[i/2] & 0x0F00) + (qh[i+1] & u1_hi ? 16.f*256.f : 0.f)); + acc2.s2 += yh[i+8] * ((q2[i/2] & 0x00F0) + (qh[i+0] & u2_hi ? 16.f*16.f : 0.f)); + acc2.s3 += yh[i+9] * ((q2[i/2] & 0xF000) + (qh[i+1] & u2_hi ? 16.f*4096.f: 0.f)); + } + + float dall = *d; + float dmin = *dm; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += blk*64; + qh += blk*32; + sc += blk*6; + d += blk; + dm += blk; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} From c0b46c2f8f3eac135f6f8d32dc3863e0c8898e8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 11 Apr 2026 18:52:11 +0200 Subject: [PATCH 419/831] CUDA: skip compilation of superfluous FA kernels (llama/21768) --- ggml/src/ggml-cuda/fattn.cu | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index addf93205ef..ea6607cd337 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -75,13 +75,17 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con return; } - if (use_gqa_opt && gqa_ratio % 2 == 0) { - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + if constexpr (DKQ <= 256) { + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); return; + } else { + GGML_ABORT("fatal error"); } - - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); - return; } if (use_gqa_opt && gqa_ratio > 4) { @@ -94,12 +98,16 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con return; } - if (use_gqa_opt && gqa_ratio > 1) { - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); - return; - } + if constexpr (DKQ <= 256) { + if (use_gqa_opt && gqa_ratio > 1) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + return; + } - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ctx, dst); + } else { + GGML_ABORT("fatal error"); + } } static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { From b9072073128ebd7bdb98c6d328dce6e38f983109 Mon Sep 17 00:00:00 2001 From: Stephen Cox Date: Mon, 13 Apr 2026 00:15:26 +1200 Subject: [PATCH 420/831] mtmd: add Gemma 4 audio conformer encoder support (llama/21421) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * mtmd: add Gemma 4 audio conformer encoder support Add audio processing for Gemma 4 E2B/E4B via a USM-style Conformer. Architecture: - 12-layer Conformer: FFN → Self-Attention → Causal Conv1D → FFN → Norm - Subsampling Conv Projection: 2x Conv2D(stride=2) with LayerNorm - Full self-attention with sinusoidal RPE and sliding window mask (24) - Logit softcapping at 50.0, ClippableLinear clamping - Output: 1024 → 1536 → RMSNorm → multimodal embedder Mel preprocessing (dedicated mtmd_audio_preprocessor_gemma4a): - HTK mel scale, 128 bins, magnitude STFT, mel_floor=1e-3 - Standard periodic Hann window (320 samples), zero-padded to FFT size - Semicausal left-padding (frame_length/2 samples) - Frame count matched to PyTorch (unfold formula) - No pre-emphasis, no Whisper-style normalization - Mel cosine similarity vs PyTorch: 0.9998 Key fixes: - Tensor loading dedup: prevent get_tensor() from creating duplicate entries in ctx_data. Fixed with std::set guard. - ClippableLinear clamp_info loading moved after per-layer tensors. - Sliding window mask (24 positions) matching PyTorch context_size. - Skip Whisper normalization for Gemma4 mel output. Tested on E2B and E4B with CPU and Vulkan backends. Transcribes: "Glad to see things are going well and business is starting to pick up" (matching ground truth). Ref: #21325 --- ggml/src/ggml-cuda/ssm-conv.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 69985cd335c..b77cdc1c137 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -134,8 +134,9 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int switch (nc) { case 3: launch_kernel(std::integral_constant{}); break; case 4: launch_kernel(std::integral_constant{}); break; + case 5: launch_kernel(std::integral_constant{}); break; case 9: launch_kernel(std::integral_constant{}); break; - default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now."); + default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now."); } } From 655072cd78a989c0696efaa1e54e616b6b8c2678 Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Mon, 13 Apr 2026 07:14:58 +0530 Subject: [PATCH 421/831] sycl: disable Q1_0 in backend and cleanup unused variables (llama/21807) --- ggml/src/ggml-sycl/convert.cpp | 2 +- ggml/src/ggml-sycl/dequantize.hpp | 1 + ggml/src/ggml-sycl/element_wise.cpp | 2 +- ggml/src/ggml-sycl/gated_delta_net.cpp | 10 ++++------ ggml/src/ggml-sycl/ggml-sycl.cpp | 7 +++++++ ggml/src/ggml-sycl/upscale.cpp | 8 ++++---- 6 files changed, 18 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index d7f60cbc9ea..f12419426ae 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -488,7 +488,7 @@ static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t const int nb = k / QK_NVFP4; stream->parallel_for( sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> /*item_ct1*/) { dequantize_block_nvfp4(vx, y, k); }); } diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index f992db33b2d..68c3db30613 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -14,6 +14,7 @@ #define GGML_SYCL_DEQUANTIZE_HPP #include "common.hpp" +#include "convert.hpp" typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs, diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index ec0247528c4..249e80c826e 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -355,7 +355,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst, const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> /*item_ct1*/) { acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset); }); } diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp index 648455c134b..ebc587524bf 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.cpp +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -176,14 +176,12 @@ static void launch_gated_delta_net(const float * q_d, const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1); const sycl::uint3 rq3_magic = init_fastdiv_values(rq3); - int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc; - switch (S_v) { case 16: { constexpr int sv = 16; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); @@ -194,7 +192,7 @@ static void launch_gated_delta_net(const float * q_d, { constexpr int sv = 32; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); @@ -205,7 +203,7 @@ static void launch_gated_delta_net(const float * q_d, { constexpr int sv = 64; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_delta_net_sycl( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); @@ -217,7 +215,7 @@ static void launch_gated_delta_net(const float * q_d, { constexpr int sv = 128; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_delta_net_sycl( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 989c91a6abb..ea79d2538c1 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4727,12 +4727,19 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g struct ggml_tensor * a = op->src[0]; struct ggml_tensor * b = op->src[1]; + // disable Q1_0 until implementation + if (a->type == GGML_TYPE_Q1_0 || b->type == GGML_TYPE_Q1_0) { + return false; + } + if (a->ne[3] != b->ne[3]) { return false; } ggml_type src0_type = op->src[0]->type; + + // TODO: The configuration below needs more work to be supported with oneDNN if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) { diff --git a/ggml/src/ggml-sycl/upscale.cpp b/ggml/src/ggml-sycl/upscale.cpp index 18c743de447..e42cb419d83 100644 --- a/ggml/src/ggml-sycl/upscale.cpp +++ b/ggml/src/ggml-sycl/upscale.cpp @@ -272,7 +272,7 @@ static void upscale_f32_sycl(const float * x, sycl::nd_range<3>( sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> /*item_ct1*/) { upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3); }); } @@ -304,7 +304,7 @@ static void upscale_f32_bilinear_sycl(const float * x, sycl::nd_range<3>( sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> /*item_ct1*/) { upscale_f32_bilinear_antialias( x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); @@ -314,7 +314,7 @@ static void upscale_f32_bilinear_sycl(const float * x, sycl::nd_range<3>( sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> /*item_ct1*/) { upscale_f32_bilinear( x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); @@ -349,7 +349,7 @@ static void upscale_f32_bicubic_sycl(const float * x, sycl::nd_range<3>( sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> /*item_ct1*/) { upscale_f32_bicubic( x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); From 36b7bb3d9576e670037228bf58c768cdeb5ed450 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Mon, 13 Apr 2026 12:13:04 +0900 Subject: [PATCH 422/831] Remove extra conditional check on debug mode. (llama/21798) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e979783f020..634201bc64d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -534,11 +534,7 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); ctx->queue.Submit(1, &commands); - if (!ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, - ctx->debug_host_buf.GetSize())) { - GGML_LOG_ERROR("ggml_webgpu: Debug buffer map failed\n"); - return; - } + ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize()); const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange(); std::cout << "debug[0]: " << debug_data[0] << "\n"; ctx->debug_host_buf.Unmap(); From d9ed371c2c50bdaef8134a05694b826e1cb7f7c6 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Mon, 13 Apr 2026 11:14:06 +0200 Subject: [PATCH 423/831] CUDA: Limit DeviceSegmentedSort to immediate mode (llama/21718) * CUDA: Limit DeviceSegmentedSort to immediate mode DeviceSegmentedSort is currently not capturable in a cuda graph. Hence, we have to go for the slower DeviceSegmentedRadixSort in that case. Perf numbers on RTX Pro 6000 Blackwell Max-Q: DeviceSegmentedRadixSort in graph mode (i.e. CUDA Graphs) ARGSORT(type=f32,ne=[2048,512,1,1],order=1): 12291 runs - 105.94 us/run - 8192 kB/run - 73.75 GB/s ARGSORT(type=f32,ne=[4096,512,1,1],order=1): 10245 runs - 115.08 us/run - 16384 kB/run - 135.77 GB/s ARGSORT(type=f32,ne=[8192,512,1,1],order=1): 5125 runs - 221.22 us/run - 32768 kB/run - 141.26 GB/s ARGSORT(type=f32,ne=[16384,512,1,1],order=1): 2565 runs - 430.98 us/run - 65536 kB/run - 145.02 GB/s ARGSORT(type=f32,ne=[32768,512,1,1],order=1): 1028 runs - 1185.83 us/run - 131072 kB/run - 105.41 GB/s ARGSORT(type=f32,ne=[65536,512,1,1],order=1): 387 runs - 2748.62 us/run - 262144 kB/run - 90.95 GB/s DeviceSegmentedSort in immediate mode ARGSORT(type=f32,ne=[2048,512,1,1],order=1): 16388 runs - 71.17 us/run - 8192 kB/run - 109.78 GB/s ARGSORT(type=f32,ne=[4096,512,1,1],order=1): 12294 runs - 81.38 us/run - 16384 kB/run - 192.00 GB/s ARGSORT(type=f32,ne=[8192,512,1,1],order=1): 5125 runs - 240.81 us/run - 32768 kB/run - 129.77 GB/s ARGSORT(type=f32,ne=[16384,512,1,1],order=1): 2565 runs - 406.60 us/run - 65536 kB/run - 153.71 GB/s ARGSORT(type=f32,ne=[32768,512,1,1],order=1): 1285 runs - 873.23 us/run - 131072 kB/run - 143.15 GB/s ARGSORT(type=f32,ne=[65536,512,1,1],order=1): 516 runs - 2288.46 us/run - 262144 kB/run - 109.24 GB/s * Add test case for dispatch to DeviceSegmentedRadixSort We currently lack a way to force graph mode in CUDA, patch callback to invoke ggml_backend_compare_graph_backend twice to enforce each test to run in graph mode --- ggml/src/ggml-cuda/argsort.cu | 79 +++++++++++++++++++++++++---------- 1 file changed, 56 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index ed4e5de70f5..0f3f017b534 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -58,26 +58,48 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, size_t temp_storage_bytes = 0; + bool is_capturing = false; +#ifdef USE_CUDA_GRAPH + // Currently (confirmed for CCCL <= 3.2) DeviceSegmentedSort does not support stream capture, while DeviceSegmentedRadixSort does. + // See https://github.com/NVIDIA/cccl/issues/5661#issuecomment-3229037149 + // TODO: constrain this to the CCCL versions that have this issue once it's resolved in a future CCCL release. + cudaStreamCaptureStatus capture_status; + CUDA_CHECK(cudaStreamIsCapturing(stream, &capture_status)); + is_capturing = (capture_status != cudaStreamCaptureStatusNone); +#endif // USE_CUDA_GRAPH + if (order == GGML_SORT_ORDER_ASC) { if (nrows == 1) { CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream)); + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs( + nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols * nrows, nrows, // num items, num segments - offset_iterator, offset_iterator + 1, stream)); + CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + offset_iterator, offset_iterator + 1, stream)); } } else { if (nrows == 1) { - CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream)); + CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending( + nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, - dst, ncols * nrows, nrows, offset_iterator, offset_iterator + 1, - stream)); + CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, stream)); } } @@ -86,22 +108,33 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, if (order == GGML_SORT_ORDER_ASC) { if (nrows == 1) { - CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream)); + CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, offset_iterator, + offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, - ncols * nrows, nrows, offset_iterator, offset_iterator + 1, stream)); + CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, offset_iterator, + offset_iterator + 1, stream)); } } else { if (nrows == 1) { - CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream)); + CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, - temp_indices, dst, ncols * nrows, nrows, offset_iterator, - offset_iterator + 1, stream)); + CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, stream)); } } } From 0f99a47177a887b3771c876fcd5b8de4711c9fe3 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Mon, 13 Apr 2026 14:21:31 +0200 Subject: [PATCH 424/831] vulkan: Flash Attention DP4A shader for quantized KV cache (llama/20797) * use integer dot product for quantized KV flash attention * small improvements * fix SHMEM_STAGING indexing * add missing KV type quants * fixes * add supported quants to FA tests * readd fast paths for <8bit quants * fix mmq gate and shmem checks --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 97 +++++++-- .../vulkan-shaders/flash_attn.comp | 184 +++++++++++++++++- .../vulkan-shaders/flash_attn_base.glsl | 5 + .../vulkan-shaders/flash_attn_mmq_funcs.glsl | 149 ++++++++++++++ .../vulkan-shaders/mul_mmq_shmem_types.glsl | 6 + .../src/ggml-vulkan/vulkan-shaders/types.glsl | 1 + .../vulkan-shaders/vulkan-shaders-gen.cpp | 15 +- 7 files changed, 430 insertions(+), 27 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 977aff62d81..1bee3e187cf 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2858,11 +2858,10 @@ struct vk_fa_tuning_params { } }; -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type); static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { - GGML_UNUSED(kv_type); vk_fa_tuning_params result{}; result.path = FA_SCALAR; @@ -2914,7 +2913,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; - if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) { + if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) { result.block_rows /= 2; } @@ -3445,21 +3444,47 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->fp16) { CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, ) + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product && device->subgroup_clustered) { + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8) + } else +#endif + { + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, ) + } } else { CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32) CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32) + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product && device->subgroup_clustered) { + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8) + } else +#endif + { + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32) + } } #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { @@ -8780,7 +8805,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { GGML_UNUSED(f32acc); // Needs to be kept up to date on shader changes const uint32_t wg_size = params.workgroup_size; @@ -8789,21 +8814,51 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + const bool mmq = device->integer_dot_product && device->subgroup_clustered && + (kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 || + kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 || + kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL); + // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); const uint32_t tmpshv4 = wg_size * 4 * float_type_size; const uint32_t masksh = Bc * (Br + 1) * float_type_size; - const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; + uint32_t Qf, kvsh, kblocksh_size; + if (mmq) { + // block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds + const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size; + Qf = Br * (hsk / 32) * block_b_size; + + // kvsh uses D = HSV (K goes through kblocksh instead) + kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size; + + // block_a_cache size depends on quant type + uint32_t block_a_size; + switch (kv_type) { + case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break; + case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break; + case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break; + case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break; + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break; + default: block_a_size = 0; break; + } + kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size; + } else { + Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; + + const uint32_t D = std::max(hsk, hsv); + kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size; - const uint32_t D = std::max(hsk, hsv); - const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size; + kblocksh_size = 0; + } - const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh; + const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported); return supported; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 11b7dce8578..6e6bdabc92e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -10,6 +10,13 @@ #extension GL_EXT_shader_subgroup_extended_types_float16 : require #endif +#ifdef MMQ +#extension GL_EXT_integer_dot_product : require +#extension GL_KHR_shader_subgroup_clustered : require + +#include "mul_mmq_shmem_types.glsl" +#endif + #extension GL_KHR_shader_subgroup_shuffle : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -41,15 +48,34 @@ shared FLOAT_TYPEV4 tmpshv4[tmpsh_size]; const uint32_t masksh_stride = Br + 1; shared FLOAT_TYPE masksh[Bc * masksh_stride]; +#ifndef MMQ const uint32_t qf_stride = HSK / 4 + 1; shared FLOAT_TYPEV4 Qf[Br * qf_stride]; +#else +const uint32_t qf_stride = HSK / 32; +shared block_b_cache Qf[Br * qf_stride]; +#endif + +#ifndef MMQ const uint32_t D = HSK > HSV ? HSK : HSV; +#else +const uint32_t D = HSV; +#endif const uint32_t kvsh_stride = D / 4 + 1; shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1]; +#ifdef MMQ + +shared block_a_cache kblocksh[SHMEM_STAGING != 0 ? Bc * qf_stride : 1]; +#endif + shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1]; +#ifdef MMQ +#include "flash_attn_mmq_funcs.glsl" +#endif + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -82,10 +108,39 @@ void main() { [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); uint32_t r = (idx + tid) / (HSK / 4); - if (r < Br && d < HSK / 4 && - i * Br + r < N) { + const bool is_in_bounds = r < Br && d < HSK / 4 && i * Br + r < N; +#ifndef MMQ + if (is_in_bounds) { Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } +#else + const uint buf_ib = r * qf_stride + d / 8; + const uint buf_iqs = d % 8; + + FLOAT_TYPEV4 vals = is_in_bounds ? FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale) : FLOAT_TYPEV4(0.0f); + const FLOAT_TYPEV4 abs_vals = abs(vals); + + const FLOAT_TYPE thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); + const FLOAT_TYPE amax = subgroupClusteredMax(thread_max, 8); + const FLOAT_TYPE qd = amax / FLOAT_TYPE(127.0); + const FLOAT_TYPE qd_inv = qd != FLOAT_TYPE(0.0) ? FLOAT_TYPE(1.0) / qd : FLOAT_TYPE(0.0); + vals = round(vals * qd_inv); + + Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals)); + +#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0); + } +#else // Q4_0, Q4_1, Q5_0, Q5_1 + const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w; + const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8); + + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd); + } +#endif +#endif } barrier(); @@ -195,6 +250,7 @@ void main() { if (SHMEM_STAGING != 0) { barrier(); +#ifndef MMQ [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); uint32_t c = (idx + tid) / (HSK / 4); @@ -214,9 +270,29 @@ void main() { kvsh[c * kvsh_stride + d] = K_Tf; } } +#else // MMQ + const uint ints_per_block = 8 / QUANT_R_MMQ; + const uint quant_iters = Bc * HSK / 32 * ints_per_block; + [[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) { + const uint32_t iqs = (idx + tid) % ints_per_block; + const uint32_t ib = (idx + tid) / ints_per_block; + const uint32_t c = ib / (HSK / 32); + const uint32_t block = ib % (HSK / 32); + if (idx + gl_WorkGroupSize.x <= quant_iters || c < Bc) { + const uint buf_ib = c * qf_stride + block; + if (!KV_bounds_check || j * Bc + c < KV) { + const uint global_ib = (j * Bc + c) * k_stride + block; + k_block_to_shmem(buf_ib, global_ib, iqs, k_offset); + } else { + k_block_to_shmem_zero(buf_ib, iqs); + } + } + } +#endif // MMQ barrier(); } +#ifndef MMQ // More d iterations means Q register caching becomes relevant // Few iterations means the additional registers needed are worse than the speed-up from caching if (HSK_per_thread / 4 > 4) { @@ -275,6 +351,110 @@ void main() { } } } +#else // MMQ + const uint hsk4 = HSK_per_thread / 4; + const uint d_per_step = (hsk4 % 8 == 0) ? 8 : + (hsk4 % 4 == 0) ? 4 : + (hsk4 % 2 == 0) ? 2 : 1; + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + + [[unroll]] for (uint32_t d_block = 0; d_block < HSK_per_thread / 4; d_block += d_per_step) { + int32_t k_quants[d_per_step]; + ACC_TYPEV2 k_dm; + + if (SHMEM_STAGING != 0) { + const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8; + const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx; +#if QUANT_AUXF == 1 + k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0); +#else + k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm); +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + if (d_per_step == 8) { + [[unroll]] for (uint32_t d = 0; d < 4; d++) { + uint vui = kblocksh[buf_ib].qs[d]; + k_quants[d ] = int32_t( vui & 0x0F0F0F0F); + k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF; + uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF; + k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); +#endif + } + } else +#endif + { + [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { + k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d); + } + } + } else { + const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block); + const uint ib = coord / BLOCK_SIZE; + const uint iqs = (coord % BLOCK_SIZE); + +#if QUANT_AUXF == 1 + k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0); +#else + k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset)); +#endif +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + if (d_per_step == 8) { +#if defined(DATA_A_Q5_0) + uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0], + k_packed.k_data_packed16[k_offset + ib].qh[1])); +#elif defined(DATA_A_Q5_1) + uint qh = k_packed.k_data_packed16[k_offset + ib].qh; +#endif + [[unroll]] for (uint32_t d = 0; d < 4; d++) { +#if defined(A_TYPE_PACKED32) + uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d]; +#else + uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0], + k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1])); +#endif + k_quants[d ] = int32_t( vui & 0x0F0F0F0F); + k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + uint qh_lo = (qh >> (d * 4)) & 0xF; + uint qh_hi = (qh >> (d * 4 + 16)) & 0xF; + k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); +#endif + } + } else +#endif + { + [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { + k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset); + } + } + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint qib = tile_row(r) * qf_stride + (d_tid * (HSK_per_thread / 4) + d_block) / 8; + const uint qiqs = (d_tid * (HSK_per_thread / 4) + d_block) % 8; + + int32_t acc = 0; + [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { + acc += dotPacked4x8EXT(Qf[qib].qs[qiqs + d], k_quants[d]); + } + + Sf[r][c] += ACC_TYPE(acc) * ACC_TYPE(Qf[qib].ds.x) * k_dm.x; + if ((d_tid * (HSK_per_thread / 4) + d_block) % 8 == 0) { + Sf[r][c] += k_dot_correction(qib, k_dm); + } + } + } + } +#endif // MMQ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { // Compute sum across the D_split diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index b30dee86871..6f349246915 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -89,6 +89,11 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16 layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; #endif +#if defined(A_TYPE_PACKED32) +layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32; +layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32; +#endif + #ifndef BLOCK_SIZE #define BLOCK_SIZE 1 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl new file mode 100644 index 00000000000..e14e62d546a --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl @@ -0,0 +1,149 @@ +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) +int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { +#ifdef DATA_A_Q4_0 + uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); +#else + uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4]; +#endif + + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + + return int32_t(vui & 0x0F0F0F0F); +} +#endif + +#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) +int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { +#ifdef DATA_A_Q5_0 + uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0], + k_packed.k_data_packed16[a_offset + ib].qh[1])); +#else + uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4]; + uint qh = k_packed.k_data_packed16[a_offset + ib].qh; +#endif + + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + + uint qh_bits = (qh >> iqs) & 0xF; + return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); +} +#endif + +#if defined(DATA_A_Q8_0) +int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { + return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])); +} +#endif + +#if defined(DATA_A_IQ4_NL) +int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { + uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + + u8vec4 idx = unpack8(vui & 0x0F0F0F0F); + return pack32(i8vec4(kvalues_iq4nl_const[idx.x], + kvalues_iq4nl_const[idx.y], + kvalues_iq4nl_const[idx.z], + kvalues_iq4nl_const[idx.w])); +} +#endif + +#if QUANT_AUXF == 1 +FLOAT_TYPE get_k_d(uint ib, uint a_offset) { + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d); +} +#else +FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) { + return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm); +} +#endif + +void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) { +#if defined(DATA_A_Q4_0) + kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], + k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); +#elif defined(DATA_A_Q4_1) + kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs]; +#elif defined(DATA_A_Q5_0) + kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], + k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); + if (iqs == 0) { + kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0], + k_packed.k_data_packed16[a_offset + global_ib].qh[1])); + } +#elif defined(DATA_A_Q5_1) + kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs]; + if (iqs == 0) { + kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh; + } +#elif defined(DATA_A_Q8_0) + kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], + k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); +#elif defined(DATA_A_IQ4_NL) + const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], + k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); + const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); + const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); + kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y], + kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w])); + kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y], + kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w])); +#endif + + if (iqs == 0) { +#if QUANT_AUXF == 1 + kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d); +#else + kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm); +#endif + } +} + +int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) { +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4 : 0; + return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F); +#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4 : 0; + int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F); + uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF; + return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u); +#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) + return kblocksh[buf_ib].qs[pos]; +#endif +} + +ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) { +#if defined(DATA_A_Q4_0) + return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; +#elif defined(DATA_A_Q5_0) + return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; +#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) + return ACC_TYPE(Qf[qib].ds.y) * k_dm.y; +#else + return ACC_TYPE(0.0); +#endif +} + +void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) { + kblocksh[buf_ib].qs[iqs] = 0; +#if defined(DATA_A_IQ4_NL) + kblocksh[buf_ib].qs[iqs + 4] = 0; +#endif + if (iqs == 0) { +#if QUANT_AUXF == 1 + kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f); +#else + kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f); +#endif + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl index c700f6e3f25..10552d013a2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl @@ -32,6 +32,12 @@ struct block_a_cache { int32_t qs[32/4]; FLOAT_TYPE dm; }; +#elif defined(DATA_A_IQ4_NL) +#define QUANT_R_MMQ 2 +struct block_a_cache { + int32_t qs[8]; + FLOAT_TYPE dm; +}; #elif defined(DATA_A_MXFP4) #define QUANT_R_MMQ 2 struct block_a_cache { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 4239070af5e..1fb592fb84b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -1692,6 +1692,7 @@ struct block_iq4_nl_packed16 #if defined(DATA_A_IQ4_NL) #define QUANT_K QUANT_K_IQ4_NL #define QUANT_R QUANT_R_IQ4_NL +#define QUANT_AUXF 1 #define A_TYPE block_iq4_nl #define A_TYPE_PACKED16 block_iq4_nl_packed16 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 77a55ea812b..607eef7d0d6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -406,8 +406,8 @@ std::map merge_maps(const std::map> compiles; -void string_to_spv(std::string name, const std::string& source, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); +void string_to_spv(std::string name, const std::string& source, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false, const std::string& suffix = "") { + name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")) + suffix; std::string out_path = join_paths(output_dir, name + ".spv"); if (input_filepath == "") { @@ -625,15 +625,16 @@ void process_shaders() { for (const bool& fp16 : {false, true}) { std::map base_dict; if (fp16) { - base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}}; + base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV2", "f16vec2"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}}; } else { - base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}}; + base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"FLOAT_TYPEV4", "vec4"}}; } // flash attention for (const bool& f16acc : {false, true}) { std::map fa_base_dict = base_dict; fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float"; + fa_base_dict["ACC_TYPEV2"] = fp16 && f16acc ? "f16vec2" : "vec2"; fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4"; if (fp16 && f16acc) { fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; @@ -672,6 +673,12 @@ void process_shaders() { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (tname != "f32") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8"); + } +#endif } } } From cdeaa341742c4d558d7020079ef0e282803511a6 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 14 Apr 2026 11:34:23 +0200 Subject: [PATCH 425/831] vulkan: Support GGML_TYPE_NVFP4 (llama/21455) This adds nvfp4 support for get_rows, dequant, and mul_mat(_id). For mul_mat, it does not add support for the dp4/q8_1 path, it's all via fp16/fp32. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 28 +++++++++++ .../vulkan-shaders/copy_from_quant.comp | 2 +- .../vulkan-shaders/dequant_funcs.glsl | 25 ++++++++++ .../vulkan-shaders/dequant_funcs_cm2.glsl | 20 ++++++++ .../vulkan-shaders/dequant_nvfp4.comp | 32 +++++++++++++ .../vulkan-shaders/mul_mm_funcs.glsl | 17 +++++++ .../src/ggml-vulkan/vulkan-shaders/types.glsl | 47 ++++++++++++++++++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 3 +- 8 files changed, 171 insertions(+), 3 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 1bee3e187cf..b353d041421 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3079,6 +3079,10 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec case GGML_TYPE_MXFP4: lut_size = 4*16; break; + case GGML_TYPE_NVFP4: + // Same kvalues budget as MXFP4 plus ue4m3_fp32_lut[128] (types.glsl, DATA_A_NVFP4). + lut_size = 4*16 + 128u * (uint32_t)sizeof(float); + break; default: break; } @@ -3558,6 +3562,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_NVFP4], matmul_nvfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) GGML_ASSERT(device->subgroup_ballot); @@ -3588,6 +3593,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) #undef CREATE_MM #undef CREATE_MM2 } else @@ -3651,6 +3657,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3674,6 +3681,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } GGML_ASSERT(device->subgroup_ballot); @@ -3708,6 +3716,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); #undef CREATE_MM2 #undef CREATE_MM } else @@ -3773,6 +3782,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3819,6 +3829,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3864,6 +3875,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3939,6 +3951,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3983,6 +3996,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_subgroup_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); } else { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -4010,6 +4024,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } } // reusing CREATE_MM from the fp32 path @@ -4108,6 +4123,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f32_f32", arr_dmmv_nvfp4_f32_f32_len[reduc16], arr_dmmv_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); @@ -4133,6 +4149,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f16_f32", arr_dmmv_nvfp4_f16_f32_len[reduc16], arr_dmmv_nvfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -4184,6 +4201,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_NVFP4], "mul_mat_vec_id_nvfp4_f32", arr_dmmv_id_nvfp4_f32_f32_len[reduc16], arr_dmmv_id_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -4239,6 +4257,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_NVFP4], "dequant_nvfp4", dequant_nvfp4_len, dequant_nvfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -4265,6 +4284,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_NVFP4], "get_rows_nvfp4", get_rows_nvfp4_len, get_rows_nvfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); @@ -4291,6 +4311,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); @@ -6089,6 +6110,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6161,6 +6183,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6227,6 +6250,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6318,6 +6342,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6387,6 +6412,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -15373,6 +15399,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return false; @@ -15488,6 +15515,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_I32: return true; default: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index 06df5095258..6a692147478 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -4,7 +4,7 @@ #include "generic_unary_head.glsl" #include "dequant_funcs.glsl" -#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4) // 16 invocations needed for init_iq_shmem layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; #else diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index ede1275cfc2..88d07d2dfd5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -450,6 +450,25 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_NVFP4) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint sub = iqs >> 4; + const float d = ue4m3_to_fp32(data_a[a_offset + ib].d[sub]); + const uint j = iqs & 7; + const uint shift = (iqs & 8) >> 1; // 0 or 4 + const uint vui0 = uint(data_a[a_offset + ib].qs[sub * 8u + j]); + const uint vui1 = uint(data_a[a_offset + ib].qs[sub * 8u + j + 1]); + const uint qs0 = (vui0 >> shift) & 0xF; + const uint qs1 = (vui1 >> shift) & 0xF; + return vec2(float(kvalues_mxfp4[qs0]), float(kvalues_mxfp4[qs1])) * d * 0.5; +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const vec2 v1 = dequantize(ib, iqs + 2u, a_offset); + return vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) vec2 get_dm(uint ib, uint a_offset) { return vec2(0, 0); @@ -484,6 +503,12 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif +#if defined(DATA_A_NVFP4) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1.0, 0.0); +} +#endif + #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) vec2 get_dm(uint ib, uint a_offset) { const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 03035f28120..c582aba87dc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -697,6 +697,24 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords } #endif +#if defined(DATA_A_NVFP4) +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4 { + block_nvfp4 block; +}; + +float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + const uint sub = (idx & 0x30) >> 4; + const uint iqs = ((idx & 0x30) >> 1) + (idx & 0x7); + const uint shift = (idx & 0x8) >> 1; + const float d = ue4m3_to_fp32(bl.block.d[sub]); + uint qs = uint(bl.block.qs[iqs]); + qs = (qs >> shift) & 0xF; + return float16_t(kvalues_mxfp4[qs] * d * 0.5); +} +#endif + #if defined(DATA_A_Q1_0) #define dequantFuncA dequantFuncQ1_0 #elif defined(DATA_A_Q4_0) @@ -743,6 +761,8 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords #define dequantFuncA dequantFuncIQ4_NL #elif defined(DATA_A_MXFP4) #define dequantFuncA dequantFuncMXFP4 +#elif defined(DATA_A_NVFP4) +#define dequantFuncA dequantFuncNVFP4 #elif defined(DATA_A_F32) #define dequantFuncA dequantFuncF32 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp new file mode 100644 index 00000000000..689089160b7 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_nvfp4 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq_shmem(gl_WorkGroupSize); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint sub = tid / 16; + const uint ir = tid % 16; + const uint ib = 16 * i + ir; + if (ib >= p.nel / 64) { + return; + } + + const uint q_idx = 8 * sub; + const uint b_idx = 1024 * i + 64 * ir + 16 * sub; + + const float d = ue4m3_to_fp32(data_a[ib].d[sub]); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF])); + data_b[b_idx + l + 8] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 219bd608035..6e4a29d2fdd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -501,6 +501,23 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin kvalues_mxfp4[vui2 & 0xF] * d); buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d, kvalues_mxfp4[vui2 >> 4] * d); +#elif defined(DATA_A_NVFP4) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + // lo and hi nibbles are 8 elements apart, which doesn't quite line up with + // how the thread mapping and buf_idx calculation works for other types. + const uint buf_idx = col * SHMEM_STRIDE + (row & 3) + (row & ~3) * 2; + + const uint ib = idx / 16u; + const uint sub = (idx & 0xC) >> 2; + const uint iqs = (idx & 0xF) * 2; + const float d = ue4m3_to_fp32(data_a[ib].d[sub]) * 0.5; + const uint vui = uint(data_a[ib].qs[iqs]); + const uint vui2 = uint(data_a[ib].qs[iqs+1]); + + buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d, + kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 4] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d, + kvalues_mxfp4[vui2 >> 4] * d); #endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 1fb592fb84b..4bcd97756fd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -1713,6 +1713,22 @@ struct block_mxfp4 #define A_TYPE block_mxfp4 #endif +#define QUANT_K_NVFP4 64 +#define QUANT_R_NVFP4 1 + +struct block_nvfp4 +{ + uint8_t d[QUANT_K_NVFP4 / 16]; + uint8_t qs[QUANT_K_NVFP4 / 2]; +}; + +#if defined(DATA_A_NVFP4) +#define QUANT_K QUANT_K_NVFP4 +#define QUANT_R QUANT_R_NVFP4 +#define QUANT_AUXF 1 +#define A_TYPE block_nvfp4 +#endif + #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) const int8_t kvalues_iq4nl_const[16] = { int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), @@ -1732,7 +1748,7 @@ void init_iq_shmem(uvec3 wgsize) } #endif -#if defined(DATA_A_MXFP4) +#if defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4) const int8_t kvalues_mxfp4_const[16] = { int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12), int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12), @@ -1740,6 +1756,24 @@ const int8_t kvalues_mxfp4_const[16] = { shared int8_t kvalues_mxfp4[16]; +#if defined(DATA_A_NVFP4) +// UE4M3 scale in NVFP4 blocks use only 7 bits; sign (bit 7) is always zero. +shared float ue4m3_fp32_lut[128]; + +float ue4m3_to_fp32_build(uint u) { + if (u == 0u || u == 127u) { + return 0.0; + } + const uint exp = (u >> 3) & 15u; + const uint man = u & 7u; + if (exp == 0u) { + return float(man) * (1.0 / 512.0); + } + const uint bits = (exp + 120u) << 23 | (man << 20); + return uintBitsToFloat(bits); +} +#endif + #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { @@ -1747,6 +1781,11 @@ void init_iq_shmem(uvec3 wgsize) for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) { kvalues_mxfp4[i] = kvalues_mxfp4_const[i]; } +#if defined(DATA_A_NVFP4) + for (uint i = gl_LocalInvocationIndex.x; i < 128u; i += wgsize.x) { + ue4m3_fp32_lut[i] = ue4m3_to_fp32_build(i); + } +#endif barrier(); } #endif @@ -1783,6 +1822,12 @@ float e8m0_to_fp32(uint8_t x) { return uintBitsToFloat(bits); } +#if defined(DATA_A_NVFP4) +float ue4m3_to_fp32(uint8_t x) { + return ue4m3_fp32_lut[uint(x)]; +} +#endif + #if BDA #extension GL_EXT_buffer_reference : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 607eef7d0d6..b232927658b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -66,6 +66,7 @@ const std::vector type_names = { "iq4_xs", "iq4_nl", "mxfp4", + "nvfp4", "bf16", }; @@ -556,7 +557,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string load_vec_quant = "2"; if ((tname == "q1_0") || (tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4")) + else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4") || (tname == "nvfp4")) load_vec_quant = "4"; if (tname == "bf16") { From b732f4d9b5429c72f7e50ed7001588a4aa847380 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 14 Apr 2026 03:46:41 -0700 Subject: [PATCH 426/831] ggml-webgpu: Update register tiling matmul to use f32 accumulation (llama/21644) * Update register tiling matmul to use f32 accumulation * fix profiling code * Fix register tiling matmul for chrome, i'm blaming dawn * Update batch tuning value for iOS * compile fix * Fix use of new load function --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 51 +++++++------------ .../wgsl-shaders/mul_mat_decls.tmpl | 35 +++++-------- .../wgsl-shaders/mul_mat_reg_tile.wgsl | 12 ++--- .../wgsl-shaders/mul_mat_subgroup_matrix.wgsl | 3 ++ 4 files changed, 40 insertions(+), 61 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 634201bc64d..8d0e109365f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -79,7 +79,7 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim /* Constants */ -#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 32u +#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 64u #define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u #define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u #define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6) @@ -97,14 +97,6 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim /* End Constants */ -static inline wgpu::CallbackMode ggml_webgpu_callback_mode() { -#ifdef __EMSCRIPTEN__ - return wgpu::CallbackMode::AllowProcessEvents; -#else - return wgpu::CallbackMode::AllowSpontaneous; -#endif -} - // This is a "fake" base pointer, since WebGPU buffers do not have pointers to // their locations. static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT @@ -445,34 +437,25 @@ static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status, } #ifdef __EMSCRIPTEN__ -// iOS browsers seem to have very strict limits on the number of in-flight GPU commands, so we need to throttle to avoid failures. EM_JS(int, ggml_webgpu_is_ios_browser, (), { const ua = navigator.userAgent; return (ua.includes('iPhone') || ua.includes('iPad')) ? 1 : 0; }); #endif -static uint32_t ggml_backend_webgpu_get_max_inflight_batches(const wgpu::AdapterInfo & info) { +// TODO: these next two functions may want tuning across different platforms and workloads, +static uint32_t ggml_backend_webgpu_get_max_inflight_batches() { #ifdef __EMSCRIPTEN__ + // iOS has very strict limits on the number of in-flight GPU commands, + // so we need to throttle to avoid failures. if (ggml_webgpu_is_ios_browser()) { return 1; } -#else - GGML_UNUSED(info); #endif - return UINT32_MAX; } -static uint32_t ggml_backend_webgpu_get_command_submit_batch_size(const wgpu::AdapterInfo & info) { -#ifdef __EMSCRIPTEN__ - if (ggml_webgpu_is_ios_browser()) { - return 16; - } -#else - GGML_UNUSED(info); -#endif - +static uint32_t ggml_backend_webgpu_get_command_submit_batch_size() { return WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE; } @@ -482,7 +465,7 @@ static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) { const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( ctx->queue.OnSubmittedWorkDone( - ggml_webgpu_callback_mode(), + wgpu::CallbackMode::AllowSpontaneous, [&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { callback_status = status; callback_message = std::string(message); @@ -502,7 +485,7 @@ static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, std::string callback_message; const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( - buffer.MapAsync(mode, offset, size, ggml_webgpu_callback_mode(), + buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, [&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) { callback_status = status; callback_message = std::string(message); @@ -542,15 +525,15 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { #endif #ifdef GGML_WEBGPU_GPU_PROFILE -static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx, - const std::vector & commands, - std::vector & futures) { +static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx, + const std::vector & commands, + std::vector & futures) { for (const auto & command : commands) { auto label = command.pipeline_name; auto ts_bufs = command.timestamp_query_bufs; wgpu::Future f = ts_bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), ggml_webgpu_callback_mode(), + wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) { if (status != wgpu::MapAsyncStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str()); @@ -3428,7 +3411,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->instance.WaitAny( ctx->webgpu_global_ctx->instance.RequestAdapter( - &options, ggml_webgpu_callback_mode(), + &options, wgpu::CallbackMode::AllowSpontaneous, [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { if (status != wgpu::RequestAdapterStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); @@ -3449,8 +3432,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { } #endif ctx->webgpu_global_ctx->adapter.GetInfo(&info); - ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(info); - ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(info); + ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(); + ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(); wgpu::SupportedFeatures features; ctx->webgpu_global_ctx->adapter.GetFeatures(&features); // we require f16 support @@ -3501,7 +3484,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { dev_desc.requiredFeatures = required_features.data(); dev_desc.requiredFeatureCount = required_features.size(); dev_desc.SetDeviceLostCallback( - ggml_webgpu_callback_mode(), + wgpu::CallbackMode::AllowSpontaneous, [ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { if (reason == wgpu::DeviceLostReason::Destroyed) { return; @@ -3535,7 +3518,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->instance.WaitAny( ctx->webgpu_global_ctx->adapter.RequestDevice( - &dev_desc, ggml_webgpu_callback_mode(), + &dev_desc, wgpu::CallbackMode::AllowSpontaneous, [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { if (status != wgpu::RequestDeviceStatus::Success) { GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str()); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 374137ff8e8..56a76a6e6c4 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -502,12 +502,6 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let d = load_f16_at(&src0, block_byte_base); let dmin = load_f16_at(&src0, block_byte_base + 2u); - // Load packed scales - var scale_vals: array; - for (var i: u32 = 0u; i < 3u; i++) { - scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i); - } - // Map k_in_block to loop structure: // Outer loop over 64-element groups (alternating q_b_idx) // Inner loop over 2 shifts per group @@ -523,15 +517,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var sc: u32; var mn: u32; + let scale_base = block_byte_base + 4u; + if (is < 4u) { - let sc_byte = get_byte(scale_vals[is / 4u], is % 4u); - let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u); + let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u); + let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); sc = sc_byte & 63u; mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u); - let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u); - let min_hi = get_byte(scale_vals[is / 4u], is % 4u); + let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u); + let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u); + let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); @@ -578,11 +574,6 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let d = load_f16_at(&src0, block_byte_base); let dmin = load_f16_at(&src0, block_byte_base + 2u); - // Load packed scales - var scale_vals: array; - for (var i: u32 = 0u; i < 3u; i++) { - scale_vals[i] = load_u32_at(&src0, block_byte_base + 4u + 4u * i); - } // The original loop processes elements in groups of 64 // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4] @@ -603,15 +594,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var sc: u32; var mn: u32; + let scale_base = block_byte_base + 4u; + if (is < 4u) { - let sc_byte = get_byte(scale_vals[is / 4u], is % 4u); - let min_byte = get_byte(scale_vals[(is + 4u) / 4u], is % 4u); + let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u); + let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); sc = sc_byte & 63u; mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(scale_vals[(is + 4u) / 4u], (is + 4u) % 4u); - let sc_hi = get_byte(scale_vals[(is - 4u) / 4u], (is - 4u) % 4u); - let min_hi = get_byte(scale_vals[is / 4u], is % 4u); + let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u); + let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u); + let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl index b1da421a691..ee37e6d249c 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -4,14 +4,14 @@ enable f16; #include "mul_mat_decls.tmpl" #ifdef VEC -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { - return vec4(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { + return vec4(acc[tm][tn], acc[tm + 1][tn], acc[tm + 2][tn], acc[tm + 3][tn]); } #endif #ifdef SCALAR -fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { - return f32(acc[tm][tn]); +fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { + return acc[tm][tn]; } #endif @@ -98,7 +98,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M; let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N; - var acc: array, TILE_M>; + var acc: array, TILE_M>; for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { @@ -122,7 +122,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let src1_idx = src1_n * TILE_K + k_inner; let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; for (var tm = 0u; tm < TILE_M; tm++) { - acc[tm][tn] += src0_tile[tm] * src1_val; + acc[tm][tn] += f32(src0_tile[tm]) * f32(src1_val); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl index 9f9ef279f29..4151ce430b0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl @@ -6,6 +6,9 @@ enable chromium_experimental_subgroup_matrix; #include "common_decls.tmpl" #include "mul_mat_decls.tmpl" +// TODO: this shader path does not work with some models like qwen2.5 on Metal devices, f16 accumulation causes NaNs. +// See https://github.com/ggml-org/llama.cpp/issues/21602 + #ifdef VEC fn store_dst(shmem_idx: u32, dst_idx: u32) { dst[dst_idx] = vec4( From bfdcd4a92c0302905f8c6010642e0e87685d53b1 Mon Sep 17 00:00:00 2001 From: texasich <101962694+texasich@users.noreply.github.com> Date: Tue, 14 Apr 2026 05:47:56 -0500 Subject: [PATCH 427/831] cmake: fix CMP0194 warning on Windows with MSVC (llama/21630) * cmake: fix CMP0194 warning on Windows with MSVC Set CMP0194 policy to NEW before project() call in ggml/CMakeLists.txt to suppress the "MSVC is not an assembler for language ASM" warning introduced in CMake 4.1. The ggml project enables ASM globally for Metal (macOS) and KleidiAI (ARM) backends. On Windows/MSVC, no assembler sources are used, but CMake 4.1+ warns because cl.exe is not a valid ASM compiler. This follows the same pattern used in ggml-vulkan (CMP0114, CMP0147). Closes ggml-org/llama.cpp#20311 * cmake: apply cisc's formatting suggestion --------- Co-authored-by: texasich --- ggml/CMakeLists.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 6bf15723b3c..8454eecde6e 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -1,4 +1,11 @@ cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories. + +# ref: https://cmake.org/cmake/help/latest/policy/CMP0194.html +# MSVC is not a valid assembler for the ASM language. +# Set to NEW to avoid a warning on CMake 4.1+ with MSVC. +if (POLICY CMP0194) + cmake_policy(SET CMP0194 NEW) +endif() project("ggml" C CXX ASM) ### GGML Version From 80f7be74bb45e575f1cf2ab35e1ba8553358694a Mon Sep 17 00:00:00 2001 From: Richard Davison Date: Tue, 14 Apr 2026 13:23:45 +0200 Subject: [PATCH 428/831] ggml : fix ARM NEON nvfp4 dot product on non-dotprod targets (llama/21559) --- ggml/src/ggml-cpu/arch/arm/quants.c | 40 ++++++++++++++++++++++++----- ggml/src/ggml-cpu/ggml-cpu-impl.h | 10 ++++++++ 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index e09db59cf22..64d811fafe7 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -783,6 +783,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_1, m4b)); const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4)); +#if defined(__ARM_FEATURE_DOTPROD) const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs); const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16); const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b)); @@ -794,15 +795,40 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b)); const int32x4_t p0 = vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0), - ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0)); + vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0), + vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0)); const int32x4_t p1 = vaddq_s32( - ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1), - ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1)); + vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1), + vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1)); - const int32x4_t sums = vpaddq_s32(p0, p1); + const int32x4_t sumi = vpaddq_s32(p0, p1); +#else + const int8x8_t q4_0_lo = vget_low_s8(q4_lo_0); + const int8x8_t q4_0_hi = vget_low_s8(q4_hi_0); + const int8x8_t q4_1_lo = vget_high_s8(q4_lo_0); + const int8x8_t q4_1_hi = vget_high_s8(q4_hi_0); + const int8x8_t q4_2_lo = vget_low_s8(q4_lo_1); + const int8x8_t q4_2_hi = vget_low_s8(q4_hi_1); + const int8x8_t q4_3_lo = vget_high_s8(q4_lo_1); + const int8x8_t q4_3_hi = vget_high_s8(q4_hi_1); + + const int8x8_t q8_0_lo = vld1_s8(y[2*ib].qs); + const int8x8_t q8_0_hi = vld1_s8(y[2*ib].qs + 8); + const int8x8_t q8_1_lo = vld1_s8(y[2*ib].qs + 16); + const int8x8_t q8_1_hi = vld1_s8(y[2*ib].qs + 24); + const int8x8_t q8_2_lo = vld1_s8(y[2*ib+1].qs); + const int8x8_t q8_2_hi = vld1_s8(y[2*ib+1].qs + 8); + const int8x8_t q8_3_lo = vld1_s8(y[2*ib+1].qs + 16); + const int8x8_t q8_3_hi = vld1_s8(y[2*ib+1].qs + 24); + + const int32x4_t sumi = (int32x4_t){ + vaddvq_s32(ggml_nvfp4_dot8(q4_0_lo, q8_0_lo, q4_0_hi, q8_0_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_1_lo, q8_1_lo, q4_1_hi, q8_1_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_2_lo, q8_2_lo, q4_2_hi, q8_2_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_3_lo, q8_3_lo, q4_3_hi, q8_3_hi)), + }; +#endif - // Decode 4 UE4M3 scales to f32 and multiply with q8 scales const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d); const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d); const float32x4_t nvsc = { @@ -813,7 +839,7 @@ void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo }; const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1}); - acc = vfmaq_f32(acc, vcvtq_f32_s32(sums), scales); + acc = vfmaq_f32(acc, vcvtq_f32_s32(sumi), scales); } sumf = vaddvq_f32(acc); #else diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 88a9c9ec057..5d1ca5ffcc3 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -306,6 +306,7 @@ inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { #if !defined(__ARM_FEATURE_DOTPROD) +// NOTE: this fallback produces the same total sum as native vdotq_s32 but with different per-lane grouping — do not use when individual lane values matter. inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); @@ -319,6 +320,15 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #endif // !defined(__ARM_FEATURE_DOTPROD) +static inline int32x4_t ggml_nvfp4_dot8(const int8x8_t q4_lo, const int8x8_t q8_lo, + const int8x8_t q4_hi, const int8x8_t q8_hi) { + const int16x8_t p_lo = vmull_s8(q4_lo, q8_lo); + const int16x8_t p_hi = vmull_s8(q4_hi, q8_hi); + const int32x4_t sum_lo = vpaddlq_s16(p_lo); + const int32x4_t sum_hi = vpaddlq_s16(p_hi); + return vaddq_s32(sum_lo, sum_hi); +} + #endif // defined(__ARM_NEON) #ifdef __wasm_simd128__ From 691b1d0826e9a1eceb955b527591aa23c287ebb0 Mon Sep 17 00:00:00 2001 From: Seyoung Jeong Date: Tue, 14 Apr 2026 21:43:59 +0900 Subject: [PATCH 429/831] metal : add XIELU unary op (llama/20802) --- ggml/src/ggml-metal/ggml-metal-device.cpp | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal-impl.h | 1 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 7 +++++++ ggml/src/ggml-metal/ggml-metal.metal | 9 +++++++++ 5 files changed, 19 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index e8548b053e8..8e0836c0beb 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -250,6 +250,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break; case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break; case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break; + case GGML_UNARY_OP_XIELU: op_num = OP_UNARY_NUM_XIELU; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 40cacb46520..4c192da650f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1043,6 +1043,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_XIELU: return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 62b028f4a4a..e7433f2a658 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -127,6 +127,7 @@ #define OP_UNARY_NUM_CEIL 118 #define OP_UNARY_NUM_ROUND 119 #define OP_UNARY_NUM_TRUNC 120 +#define OP_UNARY_NUM_XIELU 121 #define OP_SUM_ROWS_NUM_SUM_ROWS 10 #define OP_SUM_ROWS_NUM_MEAN 11 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 846225d9077..5b426be103f 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -787,6 +787,13 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { args.max = ggml_get_op_params_f32(op, 1); } + if (op->op == GGML_OP_UNARY && ggml_get_unary_op(op) == GGML_UNARY_OP_XIELU) { + args.slope = ggml_get_op_params_f32(op, 1); // alpha_n + args.scale = ggml_get_op_params_f32(op, 2); // alpha_p + args.bias = ggml_get_op_params_f32(op, 3); // beta + args.val = ggml_get_op_params_f32(op, 4); // eps + } + auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); if (pipeline.c4) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f67c5cd8a1d..445a4deca83 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1177,6 +1177,15 @@ kernel void kernel_unary_impl( if (FC_OP == OP_UNARY_NUM_TRUNC) { dst_ptr[i0] = (T) trunc(x); } + + if (FC_OP == OP_UNARY_NUM_XIELU) { + const TC xi = x; + const TC gate = TC(xi > TC(0.0f)); + const TC clamped = fmin(xi, TC(args.val)); + const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi; + const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi; + dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg); + } } #undef FC_OP From 7024f7e5c12e7b0c42f5edddf69ed3210caf497a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 14 Apr 2026 15:58:09 +0300 Subject: [PATCH 430/831] ci : re-enable mac workflows (llama/21894) * ci : re-enable mac workflows * vulkan : fix compile warning --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 8d0e109365f..aa3fe06d5a9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3485,7 +3485,7 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { dev_desc.requiredFeatureCount = required_features.size(); dev_desc.SetDeviceLostCallback( wgpu::CallbackMode::AllowSpontaneous, - [ctx](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { if (reason == wgpu::DeviceLostReason::Destroyed) { return; } From 45365fa1116f13586a89b9b6ed67e956e5f7399b Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 14 Apr 2026 15:17:45 +0200 Subject: [PATCH 431/831] vulkan: Programmatically add RoundingModeRTE to all shaders when the device supports it (llama/21572) * vulkan: Programmatically add RoundingModeRTE to all shaders when the device supports it * use FetchContent to get SPIRV-Headers * Fetch spirv-headers unconditionally * remove fetchcontent, rely on installed headers * fix ubuntu job * Update docs/build.md --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 201 ++++++++++-------- .../vulkan-shaders/copy_to_quant.comp | 1 - ggml/src/ggml-vulkan/vulkan-shaders/diag.comp | 1 - ggml/src/ggml-vulkan/vulkan-shaders/exp.comp | 1 - .../vulkan-shaders/generic_binary_head.glsl | 1 - .../ggml-vulkan/vulkan-shaders/glu_head.glsl | 1 - .../ggml-vulkan/vulkan-shaders/im2col.comp | 1 - .../ggml-vulkan/vulkan-shaders/im2col_3d.comp | 1 - ggml/src/ggml-vulkan/vulkan-shaders/log.comp | 1 - .../ggml-vulkan/vulkan-shaders/multi_add.comp | 1 - .../ggml-vulkan/vulkan-shaders/rope_head.glsl | 1 - .../vulkan-shaders/rope_params.glsl | 2 - ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl | 5 - ggml/src/ggml-vulkan/vulkan-shaders/tri.comp | 1 - .../vulkan-shaders/vulkan-shaders-gen.cpp | 91 +++----- 15 files changed, 138 insertions(+), 172 deletions(-) delete mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b353d041421..b2a54bd85d0 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -20,6 +20,13 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher() #include +// SPIRV-Headers: LunarG Windows SDK uses Include/spirv-headers/spirv.hpp (not spirv/unified1/). MinGW/MSYS2 and +// Linux packages use Khronos layout spirv/unified1/spirv.hpp. See docs/build.md#vulkan. +#if defined(_WIN32) && !defined(__MINGW32__) +#include +#else +#include +#endif #include #include @@ -2131,6 +2138,66 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); + + // Patch SPIR-V to enable RTE rounding for FP16, avoiding the need for + // separate shader variants compiled with -DRTE16. + std::vector spv; + if (device->float_controls_rte_fp16) { + const uint32_t* spv_words = reinterpret_cast(spv_data); + size_t word_count = spv_size / sizeof(uint32_t); + spv.assign(spv_words, spv_words + word_count); + + // Find insertion points respecting SPIR-V layout order: + // Header(5) -> OpCapability -> OpExtension -> ... -> OpEntryPoint -> OpExecutionMode -> ... + size_t pos = 5; // skip header + size_t cap_insert_pos = pos; + size_t ext_insert_pos = pos; + size_t exec_insert_pos = pos; + uint32_t entry_point_id = 0; + + while (pos < spv.size()) { + uint32_t opcode = spv[pos] & spv::OpCodeMask; + uint32_t len = spv[pos] >> spv::WordCountShift; + if (len == 0) break; + + if (opcode == spv::OpCapability) { + cap_insert_pos = pos + len; + ext_insert_pos = pos + len; + } else if (opcode == spv::OpExtension) { + ext_insert_pos = pos + len; + } else if (opcode == spv::OpEntryPoint) { + entry_point_id = spv[pos + 2]; + exec_insert_pos = pos + len; + } else if (opcode == spv::OpExecutionMode || opcode == spv::OpExecutionModeId) { + exec_insert_pos = pos + len; + } else if (entry_point_id != 0) { + break; + } + + pos += len; + } + + // Insert from latest position first so earlier indices stay valid. + + // OpExecutionMode %entrypoint RoundingModeRTE 16 + uint32_t exec_mode[] = { (4u << spv::WordCountShift) | spv::OpExecutionMode, entry_point_id, spv::ExecutionModeRoundingModeRTE, 16 }; + spv.insert(spv.begin() + exec_insert_pos, std::begin(exec_mode), std::end(exec_mode)); + + // OpExtension "SPV_KHR_float_controls" + const char ext_str[] = "SPV_KHR_float_controls"; + size_t ext_str_words = CEIL_DIV(sizeof(ext_str), sizeof(uint32_t)); + std::vector extension(1 + ext_str_words, 0); + extension[0] = (uint32_t)((1 + ext_str_words) << spv::WordCountShift) | spv::OpExtension; + memcpy(&extension[1], ext_str, sizeof(ext_str)); + spv.insert(spv.begin() + ext_insert_pos, extension.begin(), extension.end()); + + // OpCapability RoundingModeRTE + uint32_t capability[] = { (2u << spv::WordCountShift) | spv::OpCapability, spv::CapabilityRoundingModeRTE }; + spv.insert(spv.begin() + cap_insert_pos, std::begin(capability), std::end(capability)); + + shader_module_create_info = vk::ShaderModuleCreateInfo({}, spv.size() * sizeof(uint32_t), spv.data()); + } + pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); vk::PushConstantRange pcr( @@ -4344,10 +4411,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); - if (device->float_controls_rte_fp16 && - sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) { + if (sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) { ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_len, rms_norm_mul_rope_f32_f16_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); @@ -4372,43 +4438,28 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); - if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_rte_len, cpy_f32_q1_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - } else { - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - } - -#define SET_ROWS(itype, rte) \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## rte ## _len, set_rows_q1_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - - if (device->float_controls_rte_fp16) { - SET_ROWS(_i32, _rte) - SET_ROWS(_i64, _rte) - } else { - SET_ROWS(_i32, ) - SET_ROWS(_i64, ) - } + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + +#define SET_ROWS(itype) \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## _len, set_rows_f32 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## _len, set_rows_f16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## _len, set_rows_bf16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## _len, set_rows_q1_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## _len, set_rows_q4_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## _len, set_rows_q4_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## _len, set_rows_q5_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## _len, set_rows_q5_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## _len, set_rows_q8_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## _len, set_rows_iq4_nl ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + + SET_ROWS(_i32) + SET_ROWS(_i64) #undef SET_ROWS @@ -4428,11 +4479,10 @@ static void ggml_vk_load_shaders(vk_device& device) { return s; }; - bool rte = device->float_controls_rte_fp16; #define CREATE_BINARY(name, namemod, spec, bindings) \ for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ - #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \ + #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \ "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); CREATE_BINARY(add, , {0}, 4) @@ -4475,13 +4525,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32_rte", log_f32_rte_len, log_f32_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - } else { - ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - } + ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -4522,19 +4567,9 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(floor) CREATE_UNARY(trunc) CREATE_UNARY(sgn) + CREATE_UNARY(exp) #undef CREATE_UNARY -#define CREATE_UNARY_RTE(name) \ - if (device->float_controls_rte_fp16) { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - } else { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - } - CREATE_UNARY_RTE(exp) -#undef CREATE_UNARY_RTE - ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -4544,13 +4579,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); #define CREATE_GLU(name) \ - if (device->float_controls_rte_fp16) { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - } else { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - } + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); CREATE_GLU(geglu) CREATE_GLU(reglu) @@ -4583,25 +4613,14 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - } else { - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - } + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2); @@ -4663,13 +4682,8 @@ static void ggml_vk_load_shaders(vk_device& device) { #define IM2COL(bda) \ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ - if (device->float_controls_rte_fp16) { \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ - } else { \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ - } + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); if (device->shader_int64 && device->buffer_device_address) { IM2COL(_bda) } else { @@ -14343,8 +14357,7 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co } // conditions for pipeline creation - if (!(ctx->device->float_controls_rte_fp16 && - sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) { + if (sizeof(vk_op_rms_norm_mul_rope_push_constants) > ctx->device->properties.limits.maxPushConstantsSize) { return false; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index 4ffa45485c9..710c15296da 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #if defined(SET_ROWS) && QUANT_K == 1 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp index cd3f42f4911..79761324f55 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp index b69d4ddb096..c7cf5ec68f7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "generic_head.glsl" #include "types.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl index ba7909c4d38..dc657f3c708 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -1,7 +1,6 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require -#include "rte.glsl" #include "utils.glsl" #if RMS_NORM_ROPE_FUSION #include "rope_params.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl index 95298922d83..d8fdd8f7b5e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl @@ -1,6 +1,5 @@ #extension GL_EXT_shader_16bit_storage : require -#include "rte.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index db14f5a3cf3..674f91e5ed2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -3,7 +3,6 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require -#include "rte.glsl" #include "types.glsl" layout (push_constant) uniform parameter diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp index 4bf8b4ca046..93f61fd8543 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp @@ -4,7 +4,6 @@ #extension GL_EXT_control_flow_attributes : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "rte.glsl" #include "types.glsl" layout (push_constant) uniform parameter diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/log.comp b/ggml/src/ggml-vulkan/vulkan-shaders/log.comp index ff2812d3d75..3cda6a63c45 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/log.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp index 10cf5202a4a..26d194e9e8d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp @@ -8,7 +8,6 @@ #extension GL_KHR_shader_subgroup_basic : enable #endif -#include "rte.glsl" #include "types.glsl" #include "utils.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl index d9b4d4c03f3..51a127bcd87 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl @@ -2,7 +2,6 @@ #extension GL_EXT_shader_16bit_storage : require -#include "rte.glsl" #include "rope_params.glsl" layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl index ec6ceaca9bd..2e2a7e14c66 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -1,8 +1,6 @@ #if !defined(GGML_ROPE_PARAMS) #define GGML_ROPE_PARAMS -#include "rte.glsl" - struct rope_params { uint rope_mode; uint nrows; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl deleted file mode 100644 index ad51c1e80b8..00000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +++ /dev/null @@ -1,5 +0,0 @@ - -#if RTE16 -#extension GL_EXT_spirv_intrinsics : enable -spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -#endif // RTE16 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp index e18d0ffa307..f9b78f96072 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index b232927658b..54b9b327333 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -745,7 +745,7 @@ void process_shaders() { string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}})); - string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}})); + string_to_spv("rms_norm_mul_rope_f32_f16", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -769,15 +769,12 @@ void process_shaders() { for (std::string t : {"q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { - string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); - string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } auto get_type_str = [](bool f16) { @@ -794,12 +791,10 @@ void process_shaders() { for (auto src0_f16 : {false, true}) { for (auto src1_f16 : {false, true}) { for (auto dst_f16 : {false, true}) { - for (auto rte : {false, true}) { auto source = op == "add_rms" ? std::string("add") : op; - auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : ""); + auto name = op + get_suffix(src0_f16, src1_f16, dst_f16); auto add_rms = op == "add_rms" ? "1" : "0"; - string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}}); - } + string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , add_rms}}); } } } @@ -847,14 +842,11 @@ void process_shaders() { string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - for (auto rte : {false, true}) { - std::string suffix = rte ? "_rte" : ""; - string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}}); + string_to_spv("exp_f16", "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("exp_f32", "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("log_f16" + suffix, "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("log_f32" + suffix, "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - } + string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); @@ -908,21 +900,18 @@ void process_shaders() { string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - for (auto rte : {false, true}) { - std::string suffix = rte ? "_rte" : ""; - string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - } + string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("swiglu_oai_f16", "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("swiglu_oai_f32", "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -942,25 +931,18 @@ void process_shaders() { string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_multi_f32_f16", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_multi_f32_f16_rte", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}}); @@ -983,7 +965,6 @@ void process_shaders() { std::string bda_def = bda ? "1" : "0"; string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}})); string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}})); - string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}})); } } @@ -1036,8 +1017,8 @@ void process_shaders() { string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}}); - string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}}); + string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , "0"}}); + string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , "1"}}); string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}}); string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); @@ -1090,8 +1071,8 @@ void write_output_files() { std::string suffixes[2] = {"_f32", "_f16"}; for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) { - hdr << "extern const void * " << op << "_data[2][2][2][2];\n"; - hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n"; + hdr << "extern const void * " << op << "_data[2][2][2];\n"; + hdr << "extern const uint64_t " << op << "_len[2][2][2];\n"; std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp"; if (basename(input_filepath) != op_file) { @@ -1099,8 +1080,8 @@ void write_output_files() { } std::stringstream data = make_generic_stringstream(); std::stringstream len = make_generic_stringstream(); - data << "const void * " << op << "_data[2][2][2][2] = "; - len << "const uint64_t " << op << "_len[2][2][2][2] = "; + data << "const void * " << op << "_data[2][2][2] = "; + len << "const uint64_t " << op << "_len[2][2][2] = "; for (uint32_t t0 = 0; t0 < 2; ++t0) { if (t0 == 0) { data << "{"; @@ -1116,20 +1097,10 @@ void write_output_files() { data << "{"; len << "{"; } - for (uint32_t rte = 0; rte < 2; ++rte) { - if (rte == 0) { - data << "{"; - len << "{"; - } - data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); - len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); - data << "_data,"; - len << "_len,"; - if (rte == 1) { - data << "}, "; - len << "}, "; - } - } + data << op << suffixes[t0] << suffixes[t1] << suffixes[t2]; + len << op << suffixes[t0] << suffixes[t1] << suffixes[t2]; + data << "_data,"; + len << "_len,"; if (t2 == 1) { data << "}, "; len << "}, "; From 08e412c862ca2274aedb16f314376a81cd32b9a6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 14 Apr 2026 17:32:29 +0300 Subject: [PATCH 432/831] metal : fix FA support logic (llama/21898) --- ggml/src/ggml-metal/ggml-metal-device.m | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 4c192da650f..effe666a691 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1160,6 +1160,23 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te if (op->src[1]->type != op->src[2]->type) { return false; } + switch (op->src[1]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + break; + case GGML_TYPE_BF16: + if (!has_bfloat) { + return false; + } + break; + default: + return false; + } return has_simdgroup_mm; // TODO: over-restricted for vec-kernels case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: From 44d86c4921c1d8ba48e946c1885983311e6055b1 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Tue, 14 Apr 2026 16:32:58 +0200 Subject: [PATCH 433/831] ggml : remove ggml-ext.h (llama/21869) * ggml: correct placement of ggml-ext.h * ggml : remove ggml-ext.h --------- Co-authored-by: Georgi Gerganov --- ggml/include/ggml-backend.h | 47 ++++++++++++++++++++++++++++++++++ ggml/src/ggml-alloc.c | 1 + ggml/src/ggml-backend-meta.cpp | 3 --- 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 3c06aeaffb1..4a8f6d4287d 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -348,6 +348,53 @@ extern "C" { // Set a callback to be called for each resulting node during graph compute GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data); + // + // Meta backend + // + +#define GGML_BACKEND_META_MAX_DEVICES 16 + + enum ggml_backend_meta_split_axis { + // tensor split by tensor dimensions: + GGML_BACKEND_SPLIT_AXIS_0 = 0, + GGML_BACKEND_SPLIT_AXIS_1 = 1, + GGML_BACKEND_SPLIT_AXIS_2 = 2, + GGML_BACKEND_SPLIT_AXIS_3 = 3, + + GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends + GGML_BACKEND_SPLIT_AXIS_PARTIAL = 11, // each backend has a partial sum + + // for internal bookkeeping only: + GGML_BACKEND_SPLIT_AXIS_NONE = 98, + GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99, + }; + GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis); + + struct ggml_backend_meta_split_state { + enum ggml_backend_meta_split_axis axis; + + // for tensors with axis >= 0 && axis < GGML_MAX_DIMS: + // - each device has a slice of the tensor along the split axis + // - most tensors have n_segments == 1 and a contiguous slice of the tensor data + // - some tensors have an inhomogenenous data layout along the split axis, + // those tensors are divided into segments which are each individually split across devices + // - ne has one entry per segment and device that add up to ggml_tensor::ne for that axis, + // the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1], + // - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments + // that each need to be split individually across devices so that each device gets a slice of Q, K, and V + int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES]; + uint32_t n_segments; + }; + + // function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible: + typedef struct ggml_backend_meta_split_state(*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata); + + // create a new meta device from "simple" devices, meta buffer type/buffer/backend is then derived from this: + // TODO: this looks a bit strange - a backend API creates a device. I think we should try + // express this as a backend registry functionality instead + GGML_API ggml_backend_dev_t ggml_backend_meta_device( + ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud); + // // Utils // diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index e9b70398ffc..a4b01ccf8a1 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -2,6 +2,7 @@ #include "ggml-backend-impl.h" #include "ggml.h" #include "ggml-impl.h" + #include #include #include diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index a2ab8872c4a..0a8eea4e945 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -5,9 +5,6 @@ #include "ggml-alloc.h" #include "ggml-cpp.h" -// TODO: tmp -#include "ggml-ext.h" - #include #include #include From 24cc89e477bea0336e911316d036c24dee5258a8 Mon Sep 17 00:00:00 2001 From: Yiwei Shao <44545837+njsyw1997@users.noreply.github.com> Date: Tue, 14 Apr 2026 14:09:03 -0700 Subject: [PATCH 434/831] hexagon: optimization for HMX mat_mul (llama/21554) * hexagon: add async HMX worker Introduce hmx-worker (dedicated thread for HMX compute) to overlap HMX matmul with HVX dequant/DMA stages in the pipeline path, replacing the previous synchronous HMX calls that blocked the main thread. * hexagon: cost-based VTCM chunk search for out-stationary matmul * hexagon: fix futex race in hmx_worker_drain Store the boolean to local variable avoid atomic load twice * hex-mm: hmx optimize scatter/transpose and use HMX intrinsics * hex-vmem: drop vmem limit a touch under 3GB on v73 * hexagon: add fwd declaration of htp_context * hex-hmx: replace hmx-worker with hmx-queue that mimics dma-queue interface Simplifies the overall implemantion, reduces thread wakeup roundtrips. * hex-mm: add debug log to hmx work func called from hmx-queue * Update hmx-queue.h Co-authored-by: Max Krasnyansky --------- Co-authored-by: Kim-Chyan Gan Co-authored-by: Max Krasnyansky Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/hex-utils.h | 15 +- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 388 +++++++++++++-------- ggml/src/ggml-hexagon/htp/hmx-queue.c | 158 +++++++++ ggml/src/ggml-hexagon/htp/hmx-queue.h | 134 +++++++ ggml/src/ggml-hexagon/htp/hmx-utils.h | 56 --- ggml/src/ggml-hexagon/htp/htp-ctx.h | 7 + ggml/src/ggml-hexagon/htp/htp-ops.h | 5 + ggml/src/ggml-hexagon/htp/hvx-base.h | 5 + ggml/src/ggml-hexagon/htp/main.c | 17 +- 10 files changed, 589 insertions(+), 197 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/hmx-queue.c create mode 100644 ggml/src/ggml-hexagon/htp/hmx-queue.h diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 2b60f427ada..9ca759459d4 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -47,6 +47,7 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) if (_hmx_idx GREATER_EQUAL 0) target_sources(${HTP_LIB} PRIVATE + hmx-queue.c hmx-matmul-ops.c ) diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index fe0b661e309..f6713c5cf8f 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -31,6 +31,14 @@ static inline uint64_t hex_get_pktcnt() { return pktcnt; } +static inline uint32_t hex_ceil_pow2(uint32_t x) { + if (x <= 1) { return 1; } + int p = 2; + x--; + while (x >>= 1) { p <<= 1; } + return p; +} + static inline size_t hmx_ceil_div(size_t num, size_t den) { return (num + den - 1) / den; } @@ -73,8 +81,7 @@ static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, #define HEX_L2_LINE_SIZE 64 #define HEX_L2_FLUSH_SIZE (128 * 1024) -static inline void hex_l2flush(void * addr, size_t size) -{ +static inline void hex_l2flush(void * addr, size_t size) { if (size > HEX_L2_FLUSH_SIZE) { qurt_mem_cache_clean((qurt_addr_t) 0, 0, QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE); } else { @@ -89,4 +96,8 @@ static inline void hex_l2flush(void * addr, size_t size) } } +static inline void hex_pause() { + asm volatile(" pause(#255)\n"); +} + #endif /* HEX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index ec191c14981..485ec3f1aa9 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -16,14 +16,16 @@ #include "ggml-common.h" #include "hex-dma.h" +#include "worker-pool.h" + #include "hvx-utils.h" #include "hvx-dump.h" -#include "worker-pool.h" #include "htp-ctx.h" #include "htp-ops.h" -#include "hmx-utils.h" #include "hmx-ops.h" +#include "hmx-utils.h" +#include "hmx-queue.h" #include "hmx-profile.h" static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { @@ -47,7 +49,8 @@ static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = { 0*128, 1*128, 2*128, 3*128, 4*128, 5*128, 6*128, 7*128, 8*128, 9*128, 10*128, 11*128, 12*128, 13*128, 14*128, 15*128, - 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 16*128, 17*128, 18*128, 19*128, 20*128, 21*128, 22*128, 23*128, + 24*128, 25*128, 26*128, 27*128, 28*128, 29*128, 30*128, 31*128 }; // Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes @@ -109,36 +112,45 @@ static inline bool hmx_add_overflow(size_t a, size_t b, size_t *out) { return false; } -// Search for optimal (mc, nc) chunk sizes that maximize mc * nc within VTCM budget. +// Search for optimal (mc, nc) chunk sizes within VTCM budget. +// +// VTCM model: nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead // -// Cost model: total = nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead -// per_n_cost: bytes per nc column (weight + scratch buffers) -// per_m_cost: bytes per mc row (activation) -// per_mn_cost: bytes per mc*nc element (output) -// overhead: fixed bytes (scales 256B, eye_tile 2048B, etc.) +// Minimize ceil(m/mc) * m_block_cost + ceil(n/nc) * n_block_cost. +// All matmul paths repeat weight processing per M-block and activation loading +// per N-block, so discrete block counts drive total overhead. +// Tie-break: when cost is equal, prefer larger mc * nc. +// +// Caller-provided coefficients: +// m_block_cost: penalty per extra M-block (weight redundancy, scales with n). +// n_block_cost: penalty per extra N-block (activation redundancy, scales with m). // // Algorithm: nc sweeps from n_max down by 32, analytically solving for mc_max. // Returns 0 on success, -1 if VTCM is insufficient. -static int hmx_compute_chunks( - size_t vtcm_total, size_t overhead, - size_t per_n_cost, size_t per_m_cost, size_t per_mn_cost, - int m, int n, - size_t *m_chunk_out, size_t *n_chunk_out, - size_t *total_out) -{ +static int hmx_compute_chunks(size_t vtcm_total, + size_t overhead, + size_t per_n_cost, + size_t per_m_cost, + size_t per_mn_cost, + int m, + int n, + size_t m_block_cost, + size_t n_block_cost, + size_t * m_chunk_out, + size_t * n_chunk_out, + size_t * total_out) { if (m <= 0 || n <= 0) return -1; if (vtcm_total <= overhead) return -1; if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1; const size_t usable = vtcm_total - overhead; - size_t best_mn = 0, best_m = 0, best_n = 0; + + size_t best_cost = SIZE_MAX; + size_t best_mn = 0; + size_t best_m = 0, best_n = 0; const size_t n_max = hex_align_down((size_t)n, HMX_FP16_TILE_N_COLS); for (size_t nc = n_max; nc >= HMX_FP16_TILE_N_COLS; nc -= HMX_FP16_TILE_N_COLS) { - // Early exit: if nc * m_max cannot beat best, smaller nc won't either - if (nc * hex_align_down((size_t)m, HMX_FP16_TILE_N_ROWS) <= best_mn) - break; - size_t n_fixed = 0, ncmn = 0, mc_denom = 0; if (hmx_mul_overflow(nc, per_n_cost, &n_fixed)) continue; if (n_fixed >= usable) goto next_nc; @@ -152,10 +164,19 @@ static int hmx_compute_chunks( mc = hex_align_down(mc, HMX_FP16_TILE_N_ROWS); mc = hex_smin(mc, (size_t)m); - if (mc > 0 && mc * nc > best_mn) { - best_mn = mc * nc; - best_m = mc; - best_n = nc; + if (mc == 0) { + goto next_nc; + } + + size_t mblocks = ((size_t) m + mc - 1) / mc; + size_t nblocks = ((size_t) n + nc - 1) / nc; + size_t cost = mblocks * m_block_cost + nblocks * n_block_cost; + size_t mn = mc * nc; + if (cost < best_cost || (cost == best_cost && mn > best_mn)) { + best_cost = cost; + best_mn = mn; + best_m = mc; + best_n = nc; } } @@ -233,7 +254,7 @@ static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx( const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); HVX_Vector v_scales = hvx_vec_splat_f16(*scale); // q4x4x2 stores two int4 values per byte. Keep only the selected nibble. - HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); // Shuffle before LUT v_quants = Q6_Vb_vshuff_Vb(v_quants); @@ -257,7 +278,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( // Load all 128 packed bytes (4 contiguous 32-byte groups) HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); // Shuffle before LUT @@ -277,10 +298,8 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); // Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter - out[0] = v_lo; // group0 already in [0:63] - out[1] = Q6_V_vror_VR(v_lo, 64); // group1 rotated to [0:63] - out[2] = v_hi; // group2 already in [0:63] - out[3] = Q6_V_vror_VR(v_hi, 64); // group3 rotated to [0:63] + out[0] = v_lo; // group0 already in [0:63] + out[1] = v_hi; // group2 already in [0:63] } // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. @@ -384,8 +403,9 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( size_t row_stride, int weight_type, int start_tile, int end_tile) { - const int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; - const int qrow_size = (weight_type == HTP_TYPE_Q8_0) ? k_block : (k_block / 2); + const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS; + const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); + const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block; const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) : (weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) : @@ -398,47 +418,46 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes) - for (int t = start_tile; t < end_tile; ) { - int ct = t / n_k_tiles; // column tile index - int kt = t % n_k_tiles; // K tile index + unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index + unsigned kt = (unsigned)start_tile % n_k_tiles; // K tile index + for (unsigned t = start_tile; t < end_tile; ) { + if (kt >= n_k_tiles) { kt = 0; ct++; } - // --- Batch-4 fast path for Q4_0/IQ4_NL: process 4 contiguous K-tiles with one vlut16 per row --- - if ((weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) && (kt % 4 == 0) && (t + 4 <= end_tile) && - ((t + 3) / n_k_tiles == ct)) { - int blk_idx = (kt * 32) / QK_Q4_0x4x2; - int sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 - bool upper = (sub_blk_base >= 4); - int packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes - int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE - + sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales + // --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row --- + if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { + unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; + unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 + bool upper = (sub_blk_base >= 4); + unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes + unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + + sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales __fp16 *tile_bases[4]; - for (int g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; } + for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; } HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - const uint8_t *r0 = vtcm_src + row0 * row_stride; - const uint8_t *r1 = vtcm_src + row1 * row_stride; - HVX_Vector v0[4], v1[4]; - dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); - if (row1 < n_cols) { - dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt, v1); - } else { - v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero(); - } + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; + unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; - for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); } + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + HVX_Vector v0[2]; + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - for (int g = 0; g < 4; g++) { Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); } + + + r0 = vtcm_src + row_offset; row_offset += row_stride; + dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } - - t += 4; + t += 4; kt += 4; continue; } @@ -495,20 +514,19 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( // --- Single-tile fallback --- __fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS; - if (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL) { - int blk_idx = (kt * 32) / QK_Q4_0x4x2; - int sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; - bool upper = (sub_blk >= 4); - int byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; - int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + if (is_q4) { + unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; + unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; + bool upper = (sub_blk >= 4); + unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; + unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); HVX_Vector v_off = v_scat_base; // reset to column 0 - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - - const uint8_t *r0 = vtcm_src + row0 * row_stride; - const uint8_t *r1 = vtcm_src + row1 * row_stride; + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; + unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx( r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); @@ -585,7 +603,7 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( } (void) *(volatile HVX_Vector *)(tile_base); } - ++t; + ++t; ++kt; } // Drain HVX scatter write buffer: a vmem load on the same HW thread retires @@ -653,9 +671,13 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( // --- End x4x2 dequantizers --- // requires external HMX lock -static void core_dot_chunk_fp16(__fp16 *output, const __fp16 *activation, const __fp16 *weight, const __fp16 *scales, +static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales, int n_row_tiles, int n_col_tiles, int n_dot_tiles) { - hmx_set_output_scales(scales); + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *)scales); for (int r = 0; r < n_row_tiles; ++r) { for (int c = 0; c < n_col_tiles; ++c) { @@ -665,16 +687,55 @@ static void core_dot_chunk_fp16(__fp16 *output, const __fp16 *activation, const const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS; for (int k = 0; k < n_dot_tiles; ++k) { - int offset = k * HMX_FP16_TILE_N_ELMS; - hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset); + Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; } __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS; - hmx_consume_accumulator_fp16(out_tile); + Q6_mxmem_AR_after_hf(out_tile, 0); } } } +// --- Async HMX matmul job (for pipeline overlap) --- + +typedef struct { + __fp16 * output; + const __fp16 * activation; + const __fp16 * weight; + const __fp16 * scales; + uint32_t n_row_tiles; + uint32_t n_col_tiles; + uint32_t n_dot_tiles; +} hmx_matmul_job_t; + +static void hmx_matmul_worker_fn(void * data) { + hmx_matmul_job_t * job = (hmx_matmul_job_t *) data; + FARF(HIGH, "hmx-mm-job: n_row_tiles %u n_col_tiles %u n_dot_tiles %u", job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); + core_dot_chunk_fp16(job->output, job->activation, job->weight, job->scales, job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); +} + +static inline void hmx_matmul_job_init(hmx_matmul_job_t * job, + __fp16 * output, + const __fp16 * activation, + const __fp16 * weight, + const __fp16 * scales, + int n_row_tiles, + int n_col_tiles, + int n_dot_tiles) { + job->output = output; + job->activation = activation; + job->weight = weight; + job->scales = scales; + job->n_row_tiles = n_row_tiles; + job->n_col_tiles = n_col_tiles; + job->n_dot_tiles = n_dot_tiles; +} + +// --- End async HMX matmul job --- + static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); const int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; @@ -832,12 +893,13 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + // FP16 weight: interleave and activation load have similar per-element cost. if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, - /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, - /*per_mn=*/sizeof(__fp16), - params->m, params->n, - &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + /*per_n=*/3 * vec_dot_size, + /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, + /*per_mn=*/sizeof(__fp16), params->m, params->n, + /*m_block_cost=*/(size_t) params->n, + /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); } @@ -1006,13 +1068,15 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + // FP16 weight: interleave and activation load have similar per-element cost. if (hmx_compute_chunks(vtcm_budget, - /*overhead=*/ 256, - /*per_n=*/ 3 * vec_dot_size, // W + S0 + S1 - /*per_m=*/ vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch - /*per_mn=*/ sizeof(__fp16), // O - m, n, - &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + /*overhead=*/256, + /*per_n=*/3 * vec_dot_size, // W + S0 + S1 + /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch + /*per_mn=*/sizeof(__fp16), // O + m, n, + /*m_block_cost=*/(size_t) n, + /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); return -1; } @@ -1157,6 +1221,8 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m, int k, int n, int w_type); +#define FALLBACK_TO_STANDARD 1 + int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const uint8_t *restrict permuted_weight, int m, int k, int n, int weight_type) { @@ -1169,9 +1235,12 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds // for large m, k (e.g. prefill FFN Down), use out-stationary version if (m >= 128 && k > n && n > 1024) { - FARF(MEDIUM, "hmx_matmul_qk: OUT-STATIONARY path m=%d k=%d n=%d type=%d (K_BLOCK=512, %d K-iters with fp16 intermediate)", - m, k, n, weight_type, (k + 511) / 512); - return mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); + int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); + if (rc != FALLBACK_TO_STANDARD) { + return rc; // 0 success, -1 error + } + FARF(MEDIUM, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n); + // fall through to standard path } size_t row_stride = get_x4x2_row_stride(weight_type, k); @@ -1197,9 +1266,10 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds } size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, - per_n_cost, /*per_m=*/vec_dot_size, per_mn_cost, - m, n, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + // Quantized weight: dequant ~1.5x more expensive per element than activation load. + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, per_n_cost, /*per_m=*/vec_dot_size, per_mn_cost, m, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d pipe=%d budget=%zu)", __func__, m, k, n, use_pipeline, vtcm_budget); return -1; @@ -1256,9 +1326,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - if (!use_pipeline) { + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { // transfer activation matrix chunk into VTCM size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); @@ -1318,20 +1387,22 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds TIMER_STOP(output_store); } } + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); } else { // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) - // stage B and D (dequantize and store) are expected to be on the critical path + // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). // A --> B: vtcm_qweight, 1 buffer // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers // C --> D: vtcm_output0/vtcm_output1, 2 buffers - // - // LD ||A3| | B3 || - // MM || C2 || - // ST || D1 | || + // Async timeline (C overlaps B+D): + // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] + // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); @@ -1352,31 +1423,34 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); } - // prologue: B0, A1, C0, B1 + // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) { - // B0 + // B0: wait for DMA, dequant weight chunk 0 dma_queue_pop(ctx->dma[0]); dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); - // A1 + // A1: issue DMA for weight chunk 1 const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); if (1 < n_chunk_cnt) { const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); } - // C0 - core_dot_chunk_fp16((__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + // submit C0 (non-blocking — HMX worker executes in parallel) + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - // B1 + // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) if (1 < n_chunk_cnt) { dma_queue_pop(ctx->dma[0]); dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); } } - // main loop + // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) for (int i = 0; i < n_chunk_cnt; ++i) { const size_t nc = i * n_chunk_n_cols; const size_t nc_p1 = nc + 1 * n_chunk_n_cols; @@ -1386,36 +1460,41 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); - // issue A_{i+2} + // issue A_{i+2}: DMA push (non-blocking) if (i + 2 < n_chunk_cnt) { const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); } - // wait for HMX (C_{i}) -- C_{i} is done - - // result of B_{i+1} (input of C_{i+1}) should be ready now + // wait C_i: block until prologue/previous C completes + hmx_queue_pop(ctx->hmx_queue); - // issue C_{i+1} + // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) + // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's + // counterpart — and (i+1)%2 was last used by C_{i-1} which completed + // before C_i was submitted. if (i + 1 < n_chunk_cnt) { - core_dot_chunk_fp16((__fp16 *) vtcm_output_bufs[(i + 1) % 2], (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); } - // compute D_{i} + // D_i: store output (multi-thread HVX, parallel with C_{i+1}) float *output_chunk = dst + (mr * n + nc); transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); - // wait for DMA (A_{i+2}), compute B_{i+2} + // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) if (i + 2 < n_chunk_cnt) { dma_queue_pop(ctx->dma[0]); dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); } } } - } - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + hmx_queue_suspend(ctx->hmx_queue); + } TIMER_STOP(total); @@ -1434,10 +1513,13 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds } // C += AB -void core_mma_chunk_fp16(__fp16 *c, const __fp16 *a, const __fp16 *b, const __fp16 *col_scales, const __fp16 *eye_tile, +void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); - hmx_set_output_scales(col_scales); + Q6_bias_mxmem2_A((void *)col_scales); for (int i = 0; i < n_row_tiles; ++i) { for (int j = 0; j < n_col_tiles; ++j) { @@ -1448,15 +1530,17 @@ void core_mma_chunk_fp16(__fp16 *c, const __fp16 *a, const __fp16 *b, const __fp __fp16 *accum_tile = c + (i * n_col_tiles + j) * HMX_FP16_TILE_N_ELMS; if (!zero_init) { - hmx_load_tile_pair_fp16(accum_tile, eye_tile); + Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); } for (int k = 0; k < n_dot_tiles; ++k) { - int offset = k * HMX_FP16_TILE_N_ELMS; - hmx_load_tile_pair_fp16(row_tiles + offset, col_tiles + offset); + Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; } - - hmx_consume_accumulator_fp16(accum_tile); + Q6_mxmem_AR_after_hf(accum_tile, 0); } } } @@ -1540,12 +1624,41 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict const size_t vtcm_budget = ctx->vtcm_size; - const size_t M_BLOCK_SIZE = 512; - const size_t N_BLOCK_SIZE = 512; - const size_t K_BLOCK_SIZE = 512; + const size_t K_BLOCK_SIZE = 1024; - // Compute precise buffer sizes + // Fallback: if k doesn't need K-blocking, out-stationary has no advantage + const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE; + if (k_iters_check <= 1) { + FARF(MEDIUM, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k); + return FALLBACK_TO_STANDARD; + } + + // Dynamic M,N search via hmx_compute_chunks const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); + const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32) + + K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles) + const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant) + + K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles) + const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary) + // Alignment margin: hex_align_up can add up to 2047 bytes per buffer; + // scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin + const size_t align_margin = 4 * HMX_FP16_TILE_SIZE; + const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment + + size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used; + // Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost. + // From profiling: wt_dequant per element ≈ 1.5× activation load per element. + // m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive). + // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). + const size_t m_block_cost = (size_t) n * 3; + const size_t n_block_cost = (size_t) m * 2; + if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE, + &N_BLOCK_SIZE, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); + return -1; + } + + // Compute precise buffer sizes from searched M,N and fixed K const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); @@ -1554,7 +1667,8 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; if (total_vtcm > vtcm_budget) { - FARF(HIGH, "%s: VTCM too small: need %zu have %zu (m=%d k=%d n=%d)", __func__, total_vtcm, vtcm_budget, m, k, n); + FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm, + vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE); return -1; } @@ -1568,8 +1682,8 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); - FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d vtcm=%zu/%zu", __func__, m, k, n, weight_type, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", __func__, m, k, n, weight_type, + M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); // initialize eye tile (32x32 identity matrix) { diff --git a/ggml/src/ggml-hexagon/htp/hmx-queue.c b/ggml/src/ggml-hexagon/htp/hmx-queue.c new file mode 100644 index 00000000000..5b1d83a0cbf --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-queue.c @@ -0,0 +1,158 @@ +#pragma clang diagnostic ignored "-Wunused-function" + +#include +#include +#include + +#include +#include + +#include + +#include "hmx-queue.h" + +#define QURT_LOWEST_PRIO (254) + +static inline void hmx_lock(struct hmx_queue *q) +{ + if (!q->hmx_locked) { + HAP_compute_res_hmx_lock(q->hap_rctx); + q->hmx_locked = true; + } +} + +static inline void hmx_unlock(struct hmx_queue *q) +{ + if (q->hmx_locked) { + HAP_compute_res_hmx_unlock(q->hap_rctx); + q->hmx_locked = false; + } +} + +static inline void hmx_queue_process(struct hmx_queue *q, bool* killed) { + unsigned int ir = atomic_load(&q->idx_read); + + while (ir != atomic_load(&q->idx_write)) { + struct hmx_queue_desc *d = &q->desc[ir]; + if (!d->done) { + FARF(HIGH, "hmx-queue-process: ir %u func %p data %p", ir, d->func, d->data); + + enum hmx_queue_signal sig = (enum hmx_queue_signal) (unsigned int) d->func; + switch (sig) { + case HMX_QUEUE_NOOP: /* noop */; break; + case HMX_QUEUE_KILL: *killed = true; break; + case HMX_QUEUE_SUSPEND: hmx_unlock(q); break; + default: + hmx_lock(q); + d->func(d->data); + break; + } + + atomic_fetch_add(&d->done, 1); + } + + ir = (ir + 1) & q->idx_mask; + atomic_store(&q->idx_read, ir); + } +} + +static void hmx_queue_thread(void * arg) { + struct hmx_queue * q = (struct hmx_queue *) arg; + + FARF(HIGH, "hmx-queue-thread: started"); + + bool killed = false; + + unsigned int poll_cnt = HMX_QUEUE_POLL_COUNT; + unsigned int prev_seqn = 0; + while (!killed) { + unsigned int seqn = atomic_load(&q->seqn); + if (seqn == prev_seqn) { + if (--poll_cnt) { hex_pause(); continue; } + FARF(HIGH, "hmx-queue-thread: sleeping"); + qurt_futex_wait(&q->seqn, prev_seqn); + continue; + } + prev_seqn = seqn; + poll_cnt = HMX_QUEUE_POLL_COUNT; + + FARF(HIGH, "hmx-queue-thread: new work"); + + hmx_queue_process(q, &killed); + } + + FARF(HIGH, "hmx-queue-thread: stopped"); +} + +struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx) { + capacity = hex_ceil_pow2(capacity); + + struct hmx_queue * q = (struct hmx_queue *) memalign(32, sizeof(struct hmx_queue)); + if (q == NULL) { + FARF(ERROR, "%s: failed to allocate DMA queue\n", __FUNCTION__); + return NULL; + } + memset(q, 0, sizeof(struct hmx_queue)); + q->capacity = capacity; + q->idx_mask = capacity - 1; + q->hap_rctx = hap_rctx; + + q->desc = (struct hmx_queue_desc *) memalign(64, capacity * sizeof(struct hmx_queue_desc)); + if (!q->desc) { + FARF(ERROR, "hmx-queue: failed to allocate HMX queue descriptors\n"); + return NULL; + } + memset(q->desc, 0, capacity * sizeof(struct hmx_queue_desc)); + + const size_t stack_size = HMX_QUEUE_THREAD_STACK_SIZE; + q->stack = (unsigned char *) memalign(64, stack_size); + if (!q->stack) { + FARF(ERROR, "hmx-queue: thread stack allocation failed (%zu bytes)", stack_size); + return NULL; + } + memset(q->stack, 0, stack_size); + + // Match caller thread priority (same pattern as worker-pool.c). + int prio = qurt_thread_get_priority(qurt_thread_get_id()); + if (prio < 1) { + prio = 1; + } + if (prio > QURT_LOWEST_PRIO) { + prio = QURT_LOWEST_PRIO; + } + + qurt_thread_attr_t attr; + qurt_thread_attr_init(&attr); + qurt_thread_attr_set_stack_addr(&attr, q->stack); + qurt_thread_attr_set_stack_size(&attr, stack_size); + qurt_thread_attr_set_priority(&attr, prio); + qurt_thread_attr_set_name(&attr, "hmx-queue"); + + int err = qurt_thread_create(&q->thread, &attr, hmx_queue_thread, q); + if (err) { + FARF(ERROR, "hmx-worker: thread create failed (%d)", err); + return NULL; + } + + FARF(HIGH, "hmx-queue: capacity %u\n", capacity); + + return q; +} + +void hmx_queue_delete(struct hmx_queue * q) { + if (!q) { + return; + } + + // Tell the worker to exit. + hmx_queue_flush(q); + hmx_queue_signal(q, HMX_QUEUE_KILL); + hmx_queue_flush(q); + + int status; + qurt_thread_join(q->thread, &status); + + free(q->desc); + free(q->stack); + free(q); +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-queue.h b/ggml/src/ggml-hexagon/htp/hmx-queue.h new file mode 100644 index 00000000000..0d48c280f52 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-queue.h @@ -0,0 +1,134 @@ +#ifndef HMX_QUEUE_H +#define HMX_QUEUE_H + +#include +#include +#include + +#include +#include +#include +#include + +#include "hex-utils.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define HMX_QUEUE_THREAD_STACK_SIZE (16 * 1024) +#define HMX_QUEUE_POLL_COUNT 2000 + +typedef void (*hmx_queue_func)(void *); + +// Dummy funcs used as signals +enum hmx_queue_signal { + HMX_QUEUE_NOOP = 0, // aka NULL + HMX_QUEUE_SUSPEND, + HMX_QUEUE_KILL +}; + +struct hmx_queue_desc { + hmx_queue_func func; + void * data; + atomic_uint done; +}; + +struct hmx_queue { + struct hmx_queue_desc * desc; + atomic_uint idx_write; // updated by producer (push) + atomic_uint idx_read; // updated by consumer (process) + unsigned int idx_pop; // updated by producer (pop) + uint32_t idx_mask; + uint32_t capacity; + + atomic_uint seqn; // incremented for all pushes, used with futex + qurt_thread_t thread; + void * stack; + uint32_t hap_rctx; + bool hmx_locked; +}; + +struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx); +void hmx_queue_delete(struct hmx_queue * q); + +static inline struct hmx_queue_desc hmx_queue_make_desc(hmx_queue_func func, void * data) { + struct hmx_queue_desc d = { func, data }; + return d; +} + +static inline bool hmx_queue_push(struct hmx_queue * q, struct hmx_queue_desc d) { + unsigned int ir = atomic_load(&q->idx_read); + unsigned int iw = q->idx_write; + + if (((iw + 1) & q->idx_mask) == ir) { + FARF(HIGH, "hmx-queue-push: queue is full\n"); + return false; + } + + atomic_store(&d.done, 0); + + FARF(HIGH, "hmx-queue-push: iw %u func %p data %p\n", iw, d.func, d.data); + + q->desc[iw] = d; + atomic_store(&q->idx_write, (iw + 1) & q->idx_mask); + // wake up our thread + atomic_fetch_add(&q->seqn, 1); + qurt_futex_wake(&q->seqn, 1); + + return true; +} + +static inline bool hmx_queue_signal(struct hmx_queue *q, enum hmx_queue_signal sig) { + return hmx_queue_push(q, hmx_queue_make_desc((hmx_queue_func) sig, NULL)); +} + +static inline bool hmx_queue_empty(struct hmx_queue * q) { + return q->idx_pop == q->idx_write; +} + +static inline uint32_t hmx_queue_depth(struct hmx_queue * q) { + return (q->idx_read - q->idx_read) & q->idx_mask; +} + +static inline uint32_t hmx_queue_capacity(struct hmx_queue * q) { + return q->capacity; +} + +static inline struct hmx_queue_desc hmx_queue_pop(struct hmx_queue * q) { + unsigned int ip = q->idx_pop; + unsigned int iw = q->idx_write; + + struct hmx_queue_desc rd = { NULL, NULL }; + if (ip == iw) { + return rd; + } + + // Wait for desc to complete + struct hmx_queue_desc * d = &q->desc[ip]; + while (!atomic_load(&d->done)) { + FARF(HIGH, "hmx-queue-pop: waiting for HMX queue : %u\n", ip); + hex_pause(); + } + + rd = *d; + q->idx_pop = (ip + 1) & q->idx_mask; + + FARF(HIGH, "hmx-queue-pop: ip %u func %p data %p\n", ip, rd.func, rd.data); + return rd; +} + +static inline void hmx_queue_flush(struct hmx_queue * q) { + while (hmx_queue_pop(q).func != NULL) ; +} + +static inline void hmx_queue_suspend(struct hmx_queue *q) { + hmx_queue_signal(q, HMX_QUEUE_SUSPEND); + hmx_queue_flush(q); +} + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* HMX_QUEUE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hmx-utils.h b/ggml/src/ggml-hexagon/htp/hmx-utils.h index aacfbcda287..af04619cebb 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hmx-utils.h @@ -14,10 +14,6 @@ #define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline)) -static HMX_INLINE_ALWAYS void hmx_set_output_scales(const void *scales) { - asm volatile("bias = mxmem2(%0)" :: "r"(scales)); -} - // Initialise aligned 256-byte area with scale vector + zero padding. static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) { HVX_Vector *pv = (HVX_Vector *)out_scales; @@ -25,58 +21,6 @@ static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vecto *pv = Q6_V_vzero(); } -// Load multiple contiguous tiles with :deep streaming. -// Rt = total region size - 1; the hardware streams through [Rs, Rs + Rt]. -// IMPORTANT: the tile region [Rs, Rs + Rt] must NOT cross a VTCM 4 MB bank -// boundary, otherwise the mxmem instruction will raise a precise bus error. -// Callers must ensure their VTCM layout satisfies this constraint. -static HMX_INLINE_ALWAYS void hmx_load_tiles_fp16(const __fp16 *row_tiles, - const __fp16 *col_tiles, - size_t n_tiles) { - size_t limit = n_tiles * HMX_FP16_TILE_SIZE - 1; - asm volatile( - "{ activation.hf = mxmem(%0, %1):deep\n" - "weight.hf = mxmem(%2, %3) }\n" - :: "r"(row_tiles), "r"(limit), "r"(col_tiles), "r"(limit) - : "memory"); -} - -// Load a single activation+weight tile pair (no :deep streaming). -// Rt defines the accessible region [Rs, Rs+Rt]. Following the reference formula -// (limit = n_tiles * HMX_FP16_TILE_SIZE - 1), for a single tile Rt = 2047. -// The original code used Rt=0x7FFF (32 KB region); when dynamic VTCM allocation -// places a tile near a 4 MB bank boundary, the oversized region crosses it and -// triggers a precise bus error (0x2601). Rt=2047 confines accesses to exactly -// one 2048-byte tile while covering all 16 HVX vectors (offsets 0..2047). -static HMX_INLINE_ALWAYS void hmx_load_tile_pair_fp16(const __fp16 *act_tile, - const __fp16 *wt_tile) { - asm volatile( - "{ activation.hf = mxmem(%0, %1)\n" - "weight.hf = mxmem(%2, %3) }\n" - :: "r"(act_tile), "r"(2047), - "r"(wt_tile), "r"(2047) - : "memory"); -} - -static HMX_INLINE_ALWAYS void hmx_consume_accumulator_fp16(__fp16 *out) { - // Use the combined convert-and-store instruction (matches the reference - // Q6_mxmem_AR_after_hf intrinsic). The previous two-instruction sequence - // "cvt.hf = acc(2); mxmem = cvt" used an undocumented Rs=2 parameter. - asm volatile( - "mxmem(%0, %1):after.hf = acc\n" - :: "r"(out), "r"(0) - : "memory"); -} - -// Compute inner product of two vectors of tiles and store result. -static HMX_INLINE_ALWAYS void hmx_dot_fp16(__fp16 *out, - const __fp16 *row_tiles, - const __fp16 *col_tiles, - size_t n_tiles) { - hmx_load_tiles_fp16(row_tiles, col_tiles, n_tiles); - hmx_consume_accumulator_fp16(out); -} - // --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) --- static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) { diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 4c36a6ea0c2..8b5e47adef8 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -2,6 +2,7 @@ #define HTP_CTX_H #include "hex-dma.h" +#include "hmx-queue.h" #include "htp-ops.h" #include "worker-pool.h" @@ -30,6 +31,8 @@ struct htp_spad { uint32_t size_per_thread; // size per thread }; +struct htp_context; + // Context while processing an Op // TODO: fold this into the main context struct htp_ops_context { @@ -72,6 +75,10 @@ struct htp_context { atomic_bool vtcm_needs_release; struct htp_ops_context octx; + +#ifdef HTP_HAS_HMX + struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap +#endif }; int op_matmul(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 44a6ab4f737..fa84b674cd2 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -91,7 +91,12 @@ enum htp_op_code { #define HTP_OP_MAX_BUFS 8 #define HTP_OP_MAX_REQS 256 #define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS) + +#if __HVX_ARCH__ < 75 +#define HTP_OP_MAX_VMEM (3167538380u) +#else #define HTP_OP_MAX_VMEM (3221225472u) +#endif enum htp_tensor_flags { HTP_TENSOR_COMPUTE = (1U << 0), // Tensor buffer temporal compute data (not weights) diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index db05ab40d28..ed6026e762a 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -116,9 +116,14 @@ static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) { } static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) { +#if __HVX_ARCH__ >= 81 + HVX_Vector q0 = Q6_Vqf32_equals_Vsf(v0); + HVX_Vector q1 = Q6_Vqf32_equals_Vsf(v1); +#else const HVX_Vector zero = Q6_V_vzero(); HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero); HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero); +#endif return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)); } diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 8b347039428..d71c97ed292 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -18,8 +18,9 @@ #include #include -#include "hex-dma.h" #include "hex-utils.h" +#include "hex-dma.h" +#include "hmx-queue.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -324,6 +325,14 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que #ifdef HTP_HAS_HMX ctx->hmx_enabled = use_hmx; + ctx->hmx_queue = NULL; + if (use_hmx) { + ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx); + if (!ctx->hmx_queue) { + FARF(ERROR, "hmx-queue-create failed"); + ctx->hmx_enabled = false; + } + } FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx); #endif @@ -389,7 +398,11 @@ AEEResult htp_iface_stop(remote_handle64 handle) { } #ifdef HTP_HAS_HMX - ctx->hmx_enabled = 0; + if (ctx->hmx_queue) { + hmx_queue_delete(ctx->hmx_queue); + ctx->hmx_queue = NULL; + } + ctx->hmx_enabled = false; #endif vtcm_free(ctx); From 86d94cd95bb043772f6153d0add5bf6a204e066d Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Wed, 15 Apr 2026 14:45:16 +0200 Subject: [PATCH 435/831] docs: more extensive RoPE documentation [no ci] (llama/21953) * more extensive ggml_rope documentation * add more docs * nits --- ggml/include/ggml.h | 56 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 11d3e8a8167..703e3783136 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1773,8 +1773,32 @@ extern "C" { int n_dims, int mode); - // custom RoPE + // RoPE operations with extended options + // a is the input tensor to apply RoPE to, shape [n_embd, n_head, n_token] + // b is an int32 vector with size n_token // c is freq factors (e.g. phi3-128k), (optional) + // mode can be GGML_ROPE_TYPE_NORMAL or NEOX; for MROPE and VISION mode, use ggml_rope_multi + // + // pseudo-code for computing theta: + // for i in [0, n_dims/2): + // theta[i] = b[i] * powf(freq_base, -2.0 * i / n_dims); + // theta[i] = theta[i] / c[i]; # if c is provided, divide theta by c + // theta[i] = rope_yarn(theta[i], ...); # note: theta = theta * freq_scale is applied here + // + // other params are used by YaRN RoPE scaling, these default values will disable YaRN: + // freq_scale = 1.0f + // ext_factor = 0.0f + // attn_factor = 1.0f + // beta_fast = 0.0f + // beta_slow = 0.0f + // + // example: + // (marking: c = cos, s = sin, 0 = unrotated) + // given a single head with size = 8 --> [00000000] + // GGML_ROPE_TYPE_NORMAL n_dims = 4 --> [cscs0000] + // GGML_ROPE_TYPE_NORMAL n_dims = 8 --> [cscscscs] + // GGML_ROPE_TYPE_NEOX n_dims = 4 --> [ccss0000] + // GGML_ROPE_TYPE_NEOX n_dims = 8 --> [ccccssss] GGML_API struct ggml_tensor * ggml_rope_ext( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1790,6 +1814,36 @@ extern "C" { float beta_fast, float beta_slow); + // multi-dimensional RoPE, for Qwen-VL and similar vision models + // mode can be either VISION, MROPE, IMROPE, cannot be combined with NORMAL or NEOX + // sections specify how many dimensions to rotate in each section: + // section length is equivalent to number of cos/sin pairs, NOT the number of dims + // (i.e. sum of 4 sections are expected to be n_dims/2) + // last sections can be 0, means ignored + // all other options are identical to ggml_rope_ext + // + // important note: + // - NEOX ordering is automatically applied and cannot be disabled for MROPE and VISION + // if you need normal ordering, there are 2 methods: + // (1) split the tensor manually using ggml_view + // (2) permute the weight upon conversion + // - for VISION, n_dims must be head_size/2 + // + // example M-RoPE: + // given sections = [t=4, y=2, x=2, 0] + // given a single head with size = 18 --> [000000000000000000] + // GGML_ROPE_TYPE_MROPE n_dims = 16 --> [ttttyyxxttttyyxx00] (cos/sin are applied in NEOX ordering) + // GGML_ROPE_TYPE_IMROPE n_dims = 16 --> [ttyxttyxttyxttyx00] (interleaved M-RoPE, still NEOX ordering) + // note: the theta for each dim is computed the same way as ggml_rope_ext, no matter the section + // in other words, idx used for theta: [0123456789... until n_dims/2], not reset for each section + // + // example vision RoPE: + // given sections = [y=4, x=4, 0, 0] (last 2 sections are ignored) + // given a single head with size = 8 --> [00000000] + // GGML_ROPE_TYPE_VISION n_dims = 4 --> [yyyyxxxx] + // other values of n_dims are untested and is undefined behavior + // note: unlike MROPE, the theta for each dim is computed differently for each section + // in other words, idx used for theta: [0123] for y section, then [0123] for x section GGML_API struct ggml_tensor * ggml_rope_multi( struct ggml_context * ctx, struct ggml_tensor * a, From 182db04cb2e6ce68b5bfa17571222b179f3840ae Mon Sep 17 00:00:00 2001 From: Valeriy Dubov Date: Wed, 15 Apr 2026 16:44:02 +0300 Subject: [PATCH 436/831] rpc : add native RDMA transport for RPC backend (RoCEv2) (llama/20590) --- ggml/include/ggml-rpc.h | 6 +- ggml/src/ggml-rpc/CMakeLists.txt | 23 ++ ggml/src/ggml-rpc/ggml-rpc.cpp | 610 +++++++++++++++++++++++++++++-- 3 files changed, 601 insertions(+), 38 deletions(-) diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 1c11495b66e..6fcf5a43393 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -6,9 +6,9 @@ extern "C" { #endif -#define RPC_PROTO_MAJOR_VERSION 3 -#define RPC_PROTO_MINOR_VERSION 6 -#define RPC_PROTO_PATCH_VERSION 1 +#define RPC_PROTO_MAJOR_VERSION 4 +#define RPC_PROTO_MINOR_VERSION 0 +#define RPC_PROTO_PATCH_VERSION 0 #ifdef __cplusplus static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); diff --git a/ggml/src/ggml-rpc/CMakeLists.txt b/ggml/src/ggml-rpc/CMakeLists.txt index f5acb8ec2cb..8671ce5ceaf 100644 --- a/ggml/src/ggml-rpc/CMakeLists.txt +++ b/ggml/src/ggml-rpc/CMakeLists.txt @@ -7,3 +7,26 @@ ggml_add_backend_library(ggml-rpc if (WIN32) target_link_libraries(ggml-rpc PRIVATE ws2_32) endif() + +# RDMA auto-detection (Linux only, requires libibverbs) +if (NOT WIN32 AND NOT APPLE) + find_library(IBVERBS_LIB ibverbs) + if (IBVERBS_LIB) + option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" ON) + else() + option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" OFF) + endif() +else() + set(GGML_RPC_RDMA OFF CACHE BOOL "RDMA not available on this platform" FORCE) +endif() + +if (GGML_RPC_RDMA) + if (NOT IBVERBS_LIB) + find_library(IBVERBS_LIB ibverbs REQUIRED) + endif() + target_compile_definitions(ggml-rpc PRIVATE GGML_RPC_RDMA) + target_link_libraries(ggml-rpc PRIVATE ${IBVERBS_LIB}) + message(STATUS " RDMA transport enabled (auto-detected)") +else() + message(STATUS " RDMA transport disabled") +endif() diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 61bfcc5a675..017ef0af360 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -3,7 +3,9 @@ #include "ggml-backend-impl.h" #include "ggml-cpp.h" +#include #include +#include #include #include #include @@ -31,6 +33,14 @@ #include #include +#ifdef GGML_RPC_RDMA +# include +# include +# ifndef _WIN32 +# include +# endif +#endif // GGML_RPC_RDMA + static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); #define LOG_DBG(...) \ @@ -49,17 +59,116 @@ typedef int sockfd_t; #endif // cross-platform socket + +#ifdef GGML_RPC_RDMA +static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock) +static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB +static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes +using rdma_gid_t = std::array; + +struct rdma_conn { + struct ibv_context * ctx = nullptr; + struct ibv_pd * pd = nullptr; + struct ibv_cq * scq = nullptr; // send completions + struct ibv_cq * rcq = nullptr; // recv completions + struct ibv_qp * qp = nullptr; + + void * tx_buf = nullptr; + struct ibv_mr * tx_mr = nullptr; + + void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous + struct ibv_mr * rx_mr = nullptr; + int rx_head = 0; + + uint32_t max_inline = 0; + + uint8_t * rx_slot(int i) const { + return static_cast(rx_buf) + static_cast(i) * RDMA_CHUNK; + } + + bool post_rx(int i) { + struct ibv_sge sge = {}; + sge.addr = (uintptr_t)rx_slot(i); + sge.length = RDMA_CHUNK; + sge.lkey = rx_mr->lkey; + struct ibv_recv_wr wr = {}, * bad = nullptr; + wr.wr_id = (uint64_t)i; + wr.sg_list = &sge; + wr.num_sge = 1; + return ibv_post_recv(qp, &wr, &bad) == 0; + } + + ~rdma_conn() { + if (tx_mr) ibv_dereg_mr(tx_mr); + if (rx_mr) ibv_dereg_mr(rx_mr); + free(tx_buf); + free(rx_buf); + if (qp) ibv_destroy_qp(qp); + if (scq) ibv_destroy_cq(scq); + if (rcq) ibv_destroy_cq(rcq); + if (pd) ibv_dealloc_pd(pd); + if (ctx) ibv_close_device(ctx); + } +}; + +// Local RDMA parameters captured during the probe phase and later consumed +// by rdma_activate() after the remote side's caps arrive via HELLO. +struct rdma_local_info { + uint32_t qpn = 0; + uint32_t psn = 0; + uint8_t gid[RDMA_GID_SIZE] = {}; + uint8_t ib_port = 0; + int gid_idx = 0; + enum ibv_mtu path_mtu = IBV_MTU_1024; +}; +#endif // GGML_RPC_RDMA + +// conn_caps size for transport-agnostic capability exchange +static constexpr size_t RPC_CONN_CAPS_SIZE = 24; + +// conn_caps RDMA layout helper +#ifdef GGML_RPC_RDMA +struct rdma_caps { + uint32_t qpn; + uint32_t psn; + uint8_t gid[RDMA_GID_SIZE]; +}; +static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size"); +#endif // GGML_RPC_RDMA + +// Forward declarations for transport function pointers +struct socket_t; +static bool tcp_send_impl(socket_t * sock, const void * data, size_t size); +static bool tcp_recv_impl(socket_t * sock, void * data, size_t size); + struct socket_t { sockfd_t fd; + bool (*fn_send)(socket_t *, const void *, size_t) = tcp_send_impl; + bool (*fn_recv)(socket_t *, void *, size_t) = tcp_recv_impl; +#ifdef GGML_RPC_RDMA + std::unique_ptr rdma; + rdma_local_info rdma_local = {}; +#endif // GGML_RPC_RDMA socket_t(sockfd_t fd) : fd(fd) {} ~socket_t() { +#ifdef GGML_RPC_RDMA + rdma.reset(); +#endif // GGML_RPC_RDMA LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); #ifdef _WIN32 - closesocket(this->fd); + if (fd != INVALID_SOCKET) closesocket(this->fd); #else - close(this->fd); + if (fd >= 0) close(this->fd); #endif } + + // Advertise local transport capabilities into conn_caps. + // May probe RDMA and store the probe on this socket for update_caps. + void get_caps(uint8_t * caps); + + // Activate transport upgrade based on remote conn_caps using the probe + // previously stored by get_caps. + void update_caps(const uint8_t * remote_caps); }; // macro for nicer error messages on server crash @@ -115,10 +224,16 @@ static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14"); // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold const size_t HASH_THRESHOLD = 10 * 1024 * 1024; +struct rpc_msg_hello_req { + uint8_t conn_caps[RPC_CONN_CAPS_SIZE]; +}; + struct rpc_msg_hello_rsp { uint8_t major; uint8_t minor; uint8_t patch; + uint8_t padding; + uint8_t conn_caps[RPC_CONN_CAPS_SIZE]; }; struct rpc_msg_device_count_rsp { @@ -414,27 +529,414 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) { return true; } -static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) { - if (!send_data(sockfd, &msg_size, sizeof(msg_size))) { +// TCP transport implementations (for function-pointer dispatch) + +static bool tcp_send_impl(socket_t * sock, const void * data, size_t size) { + return send_data(sock->fd, data, size); +} + +static bool tcp_recv_impl(socket_t * sock, void * data, size_t size) { + return recv_data(sock->fd, data, size); +} + +// RDMA transport (performance-optimized, auto-negotiated) + +#ifdef GGML_RPC_RDMA + +static bool rdma_send_impl(socket_t * sock, const void * data, size_t size); +static bool rdma_recv_impl(socket_t * sock, void * data, size_t size); + +static inline bool tcp_peer_closed(int fd) { + if (fd < 0) return false; +#ifndef _WIN32 + struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 }; + int r = poll(&pfd, 1, 0); + return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP)); +#else + return false; +#endif +} + +static inline bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc, int tcp_fd) { + for (uint64_t s = 0; ; s++) { + int n = ibv_poll_cq(cq, 1, wc); + if (n > 0) { + if (wc->status != IBV_WC_SUCCESS) { + GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n", + wc->status, ibv_wc_status_str(wc->status), wc->vendor_err); + } + return wc->status == IBV_WC_SUCCESS; + } + if (n < 0) return false; + if ((s & 0xFFFFF) == 0 && s > 0) { + if (tcp_peer_closed(tcp_fd)) { + return false; + } + } + } +} + +static bool rdma_send(rdma_conn * c, const void * data, size_t size, int tcp_fd) { + const uint8_t * src = (const uint8_t *)data; + size_t rem = size; + while (rem > 0) { + size_t chunk = std::min(rem, RDMA_CHUNK); + + struct ibv_sge sge = {}; + struct ibv_send_wr wr = {}, * bad = nullptr; + wr.opcode = IBV_WR_SEND; + wr.sg_list = &sge; + wr.num_sge = 1; + + if (chunk <= c->max_inline) { + sge.addr = (uintptr_t)src; + sge.length = chunk; + wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE; + } else { + memcpy(c->tx_buf, src, chunk); + sge.addr = (uintptr_t)c->tx_buf; + sge.length = chunk; + sge.lkey = c->tx_mr->lkey; + wr.send_flags = IBV_SEND_SIGNALED; + } + + if (ibv_post_send(c->qp, &wr, &bad) != 0) return false; + struct ibv_wc wc; + if (!rdma_poll(c->scq, &wc, tcp_fd)) return false; + + src += chunk; + rem -= chunk; + } + return true; +} + + +static bool rdma_recv(rdma_conn * c, void * data, size_t size, int tcp_fd) { + uint8_t * dst = (uint8_t *)data; + size_t rem = size; + while (rem > 0) { + struct ibv_wc wc; + if (!rdma_poll(c->rcq, &wc, tcp_fd)) return false; + + int slot = (int)wc.wr_id; + size_t got = wc.byte_len; + memcpy(dst, c->rx_slot(slot), got); + + if (!c->post_rx(slot)) return false; + + dst += got; + rem -= got; + } + return true; +} + +static bool rdma_send_impl(socket_t * sock, const void * data, size_t size) { + return rdma_send(sock->rdma.get(), data, size, sock->fd); +} + +static bool rdma_recv_impl(socket_t * sock, void * data, size_t size) { + return rdma_recv(sock->rdma.get(), data, size, sock->fd); +} + +// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address. +// Used to match the socket's local IP against the kernel's GID table so that +// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly: +// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4) +// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape) +// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is +// Returns std::nullopt on unsupported family or getsockname failure. +static std::optional rdma_build_target_gid(sockfd_t tcp_fd) { + sockaddr_storage addr = {}; + socklen_t addr_len = sizeof(addr); + if (getsockname(tcp_fd, reinterpret_cast(&addr), &addr_len) != 0) { + return std::nullopt; + } + rdma_gid_t target = {}; + if (addr.ss_family == AF_INET) { + const auto * a = reinterpret_cast(&addr); + target[10] = 0xff; + target[11] = 0xff; + memcpy(&target[12], &a->sin_addr, 4); + return target; + } + if (addr.ss_family == AF_INET6) { + const auto * a = reinterpret_cast(&addr); + memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE); + return target; + } + return std::nullopt; +} + +static rdma_conn * rdma_probe(sockfd_t tcp_fd, rdma_local_info * out) { + const char * dev_env = std::getenv("GGML_RDMA_DEV"); + const char * gid_env = std::getenv("GGML_RDMA_GID"); + + auto target_gid = rdma_build_target_gid(tcp_fd); + if (!target_gid) { + return nullptr; + } + + const uint8_t ib_port = 1; + int num_devs = 0; + ibv_device ** devs = ibv_get_device_list(&num_devs); + if (!devs || num_devs == 0) return nullptr; + + ibv_context * ibctx = nullptr; + const char * matched_dev = nullptr; + int gid_idx = gid_env ? atoi(gid_env) : -1; + int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB + + for (int d = 0; d < num_devs; d++) { + const char * dn = ibv_get_device_name(devs[d]); + if (dev_env && strcmp(dev_env, dn) != 0) continue; + + ibv_context * ctx = ibv_open_device(devs[d]); + if (!ctx) continue; + + ibv_port_attr pa; + if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; } + + int found_gid = gid_idx; + int found_version = IBV_GID_TYPE_IB; + if (found_gid < 0) { + // Find a GID on this port whose bytes equal the local TCP address + // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1 + // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths + // are avoided. ibv_query_gid_ex returns gid+type in one call. + int v2_idx = -1; + int v1_idx = -1; + for (int i = 0; i < pa.gid_tbl_len; i++) { + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue; + if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue; + if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) { + v2_idx = i; + } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) { + v1_idx = i; + } + } + if (v2_idx >= 0) { + found_gid = v2_idx; + found_version = IBV_GID_TYPE_ROCE_V2; + } else if (v1_idx >= 0) { + found_gid = v1_idx; + found_version = IBV_GID_TYPE_ROCE_V1; + } + } else { + // Explicit GID index from GGML_RDMA_GID — fetch its type for logging. + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) { + found_version = entry.gid_type; + } + } + if (found_gid >= 0) { + ibctx = ctx; + gid_idx = found_gid; + gid_version = found_version; + matched_dev = dn; + out->path_mtu = pa.active_mtu; + break; + } + ibv_close_device(ctx); + } + ibv_free_device_list(devs); + if (!ibctx) return nullptr; + + out->ib_port = ib_port; + out->gid_idx = gid_idx; + + // unique_ptr owns ibctx and every subsequent resource via ~rdma_conn(), + // so each failure path is a plain `return nullptr;`. + auto c = std::make_unique(); + c->ctx = ibctx; + + c->pd = ibv_alloc_pd(ibctx); + if (!c->pd) return nullptr; + + c->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0); + c->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0); + if (!c->scq || !c->rcq) return nullptr; + + ibv_qp_init_attr qia = {}; + qia.send_cq = c->scq; + qia.recv_cq = c->rcq; + qia.qp_type = IBV_QPT_RC; + qia.cap.max_send_wr = 4; + qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4; + qia.cap.max_send_sge = 1; + qia.cap.max_recv_sge = 1; + qia.cap.max_inline_data = 256; + + c->qp = ibv_create_qp(c->pd, &qia); + if (!c->qp) return nullptr; + c->max_inline = qia.cap.max_inline_data; + + c->tx_buf = aligned_alloc(4096, RDMA_CHUNK); + c->rx_buf = aligned_alloc(4096, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK); + if (!c->tx_buf || !c->rx_buf) return nullptr; + + c->tx_mr = ibv_reg_mr(c->pd, c->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE); + c->rx_mr = ibv_reg_mr(c->pd, c->rx_buf, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + if (!c->tx_mr || !c->rx_mr) return nullptr; + + ibv_gid local_gid; + if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return nullptr; + + out->qpn = c->qp->qp_num; + out->psn = c->qp->qp_num & 0xffffff; + memcpy(out->gid, &local_gid, RDMA_GID_SIZE); + + const char * ver_str = ""; + if (gid_version == IBV_GID_TYPE_ROCE_V2) { + ver_str = " RoCEv2"; + } else if (gid_version == IBV_GID_TYPE_ROCE_V1) { + ver_str = " RoCEv1"; + } + GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n", + matched_dev, gid_idx, ver_str, out->qpn, c->max_inline); + return c.release(); +} + +// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS. +// On success, the connection is live and ready for rdma_send/rdma_recv. +static bool rdma_activate(rdma_conn * c, const rdma_local_info * local, + uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) { + // RESET -> INIT + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_INIT; + a.port_num = local->ib_port; + a.pkey_index = 0; + a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; + if (ibv_modify_qp(c->qp, &a, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + return false; + } + } + + for (int i = 0; i < RDMA_RX_DEPTH; i++) { + if (!c->post_rx(i)) return false; + } + + // INIT -> RTR + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTR; + a.path_mtu = local->path_mtu; + a.dest_qp_num = remote_qpn; + a.rq_psn = remote_psn; + a.max_dest_rd_atomic = 1; + a.min_rnr_timer = 1; + a.ah_attr.is_global = 1; + memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE); + a.ah_attr.grh.hop_limit = 1; + a.ah_attr.grh.sgid_index = local->gid_idx; + a.ah_attr.dlid = 0; + a.ah_attr.port_num = local->ib_port; + if (ibv_modify_qp(c->qp, &a, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | + IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) { + return false; + } + } + + // RTR -> RTS + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTS; + a.timeout = 14; + a.retry_cnt = 7; + a.rnr_retry = 7; + a.sq_psn = local->psn; + a.max_rd_atomic = 1; + if (ibv_modify_qp(c->qp, &a, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | + IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) { + return false; + } + } + + GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n", + local->qpn, remote_qpn, 128 << local->path_mtu, RDMA_RX_DEPTH); + return true; +} + +#endif // GGML_RPC_RDMA + +// --------------------------------------------------------------------------- +// socket_t transport capability methods +// --------------------------------------------------------------------------- + +void socket_t::get_caps(uint8_t * caps) { + memset(caps, 0, RPC_CONN_CAPS_SIZE); +#ifdef GGML_RPC_RDMA + rdma_local = {}; + rdma.reset(rdma_probe(fd, &rdma_local)); + if (rdma) { + rdma_caps rc = {}; + rc.qpn = rdma_local.qpn; + rc.psn = rdma_local.psn; + memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE); + memcpy(caps, &rc, sizeof(rc)); + } +#endif // GGML_RPC_RDMA +} + +void socket_t::update_caps(const uint8_t * remote_caps) { +#ifdef GGML_RPC_RDMA + if (!rdma) { + return; + } + rdma_caps rc = {}; + memcpy(&rc, remote_caps, sizeof(rc)); + if (rc.qpn == 0) { + rdma.reset(); + return; + } + if (rdma_activate(rdma.get(), &rdma_local, rc.qpn, rc.psn, rc.gid)) { + fn_send = rdma_send_impl; + fn_recv = rdma_recv_impl; + } else { + GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n"); + rdma.reset(); + } +#else + (void)remote_caps; +#endif // GGML_RPC_RDMA +} + +// unified transport dispatch (via function pointers) + +static bool send_data(socket_t * sock, const void * data, size_t size) { + return sock->fn_send(sock, data, size); +} + +static bool recv_data(socket_t * sock, void * data, size_t size) { + return sock->fn_recv(sock, data, size); +} + +static bool send_msg(socket_t * sock, const void * msg, size_t msg_size) { + if (!send_data(sock, &msg_size, sizeof(msg_size))) { return false; } - return send_data(sockfd, msg, msg_size); + return send_data(sock, msg, msg_size); } -static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) { +static bool recv_msg(socket_t * sock, void * msg, size_t msg_size) { uint64_t size; - if (!recv_data(sockfd, &size, sizeof(size))) { + if (!recv_data(sock, &size, sizeof(size))) { return false; } if (size != msg_size) { return false; } - return recv_data(sockfd, msg, msg_size); + return recv_data(sock, msg, msg_size); } -static bool recv_msg(sockfd_t sockfd, std::vector & input) { +static bool recv_msg(socket_t * sock, std::vector & input) { uint64_t size; - if (!recv_data(sockfd, &size, sizeof(size))) { + if (!recv_data(sock, &size, sizeof(size))) { return false; } try { @@ -443,7 +945,7 @@ static bool recv_msg(sockfd_t sockfd, std::vector & input) { GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size); return false; } - return recv_data(sockfd, input.data(), size); + return recv_data(sock, input.data(), size); } static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { @@ -452,7 +954,11 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int return false; } host = endpoint.substr(0, pos); - port = std::stoi(endpoint.substr(pos + 1)); + try { + port = std::stoi(endpoint.substr(pos + 1)); + } catch (...) { + return false; + } return true; } @@ -460,13 +966,13 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int // No response static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size) { uint8_t cmd_byte = cmd; - if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) { + if (!send_data(sock.get(), &cmd_byte, sizeof(cmd_byte))) { return false; } - if (!send_data(sock->fd, &input_size, sizeof(input_size))) { + if (!send_data(sock.get(), &input_size, sizeof(input_size))) { return false; } - if (!send_data(sock->fd, input, input_size)) { + if (!send_data(sock.get(), input, input_size)) { return false; } return true; @@ -478,16 +984,14 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm if (!send_rpc_cmd(sock, cmd, input, input_size)) { return false; } - // TODO: currently the output_size is always known, do we need support for commands with variable output size? - // even if we do, we can skip sending output_size from the server for commands with known output size uint64_t out_size; - if (!recv_data(sock->fd, &out_size, sizeof(out_size))) { + if (!recv_data(sock.get(), &out_size, sizeof(out_size))) { return false; } if (out_size != output_size) { return false; } - if (!recv_data(sock->fd, output, output_size)) { + if (!recv_data(sock.get(), output, output_size)) { return false; } return true; @@ -495,17 +999,25 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm // RPC client-side implementation -static bool check_server_version(const std::shared_ptr & sock) { - rpc_msg_hello_rsp response; - bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response)); +// Performs HELLO handshake with transport auto-negotiation. +// Advertises local capabilities via conn_caps; if the server responds with +// matching capabilities, the socket is upgraded transparently. +static bool negotiate_hello(const std::shared_ptr & sock) { + rpc_msg_hello_req request = {}; + rpc_msg_hello_rsp response = {}; + + sock->get_caps(request.conn_caps); + + bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, &request, sizeof(request), &response, sizeof(response)); RPC_STATUS_ASSERT(status); + if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) { - GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); + GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", + response.major, response.minor, response.patch); return false; } - if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) { - GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); - } + + sock->update_caps(response.conn_caps); return true; } @@ -527,6 +1039,7 @@ static std::shared_ptr get_socket(const std::string & endpoint) { GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str()); return nullptr; } + #ifdef _WIN32 if (!initialized) { WSADATA wsaData; @@ -543,10 +1056,10 @@ static std::shared_ptr get_socket(const std::string & endpoint) { if (sock == nullptr) { return nullptr; } - if (!check_server_version(sock)) { + if (!negotiate_hello(sock)) { return nullptr; } - LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd); + LOG_DBG("[%s] connected to %s\n", __func__, endpoint.c_str()); sockets[endpoint] = sock; return sock; } @@ -1597,25 +2110,46 @@ rpc_server::~rpc_server() { } static void rpc_serve_client(const std::vector & backends, const char * cache_dir, - sockfd_t sockfd) { + socket_t * sockfd) { rpc_server server(backends, cache_dir); uint8_t cmd; if (!recv_data(sockfd, &cmd, 1)) { return; } - // the first command sent by the client must be HELLO if (cmd != RPC_CMD_HELLO) { GGML_LOG_ERROR("Expected HELLO command, update client\n"); return; } - if (!recv_msg(sockfd, nullptr, 0)) { + + // Read input_size and validate protocol version + uint64_t hello_input_size; + if (!recv_data(sockfd, &hello_input_size, sizeof(hello_input_size))) { return; } - rpc_msg_hello_rsp response; - server.hello(response); - if (!send_msg(sockfd, &response, sizeof(response))) { + + if (hello_input_size != sizeof(rpc_msg_hello_req)) { + GGML_LOG_ERROR("HELLO request size mismatch (%zu vs %zu) — client needs upgrade to protocol v%d.x\n", + (size_t)hello_input_size, sizeof(rpc_msg_hello_req), RPC_PROTO_MAJOR_VERSION); + return; + } + + rpc_msg_hello_req req = {}; + if (!recv_data(sockfd, &req, sizeof(req))) { return; } + + rpc_msg_hello_rsp rsp = {}; + server.hello(rsp); + + // Advertise server transport capabilities based on client's caps + sockfd->get_caps(rsp.conn_caps); + + if (!send_msg(sockfd, &rsp, sizeof(rsp))) { + return; + } + + // Activate transport upgrade using client's caps + sockfd->update_caps(req.conn_caps); while (true) { if (!recv_data(sockfd, &cmd, 1)) { break; @@ -1884,6 +2418,12 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir if (!parse_endpoint(endpoint, host, port)) { return; } + +#ifdef GGML_RPC_RDMA + printf(" transport : TCP (RDMA auto-negotiate enabled)\n"); +#else + printf(" transport : TCP\n"); +#endif // GGML_RPC_RDMA #ifdef _WIN32 { WSADATA wsaData; @@ -1907,7 +2447,7 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir } printf("Accepted client connection\n"); fflush(stdout); - rpc_serve_client(backends, cache_dir, client_socket->fd); + rpc_serve_client(backends, cache_dir, client_socket.get()); printf("Client connection closed\n"); fflush(stdout); } From 7e57b20d533b2854738e005db9c4c8aa510d67bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 15 Apr 2026 15:58:40 +0200 Subject: [PATCH 437/831] CUDA: manage NCCL communicators in context (llama/21891) * CUDA: manage NCCL communicators in context * add check that all backends are CUDA * remove unused vector, limit init to > 1 GPUs * fix warnings * fix cuda device, cache allreduce --- ggml/include/ggml-backend.h | 7 +- ggml/src/ggml-backend-meta.cpp | 37 +++++++--- ggml/src/ggml-cuda/common.cuh | 4 -- ggml/src/ggml-cuda/ggml-cuda.cu | 118 +++++++++++++++++++++++--------- 4 files changed, 119 insertions(+), 47 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 4a8f6d4287d..d0c7e5a1be0 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -202,8 +202,11 @@ extern "C" { // Common functions that may be obtained using ggml_backend_reg_get_proc_address - // AllReduce operation for tensor parallelism (meta backend) - typedef bool (*ggml_backend_allreduce_tensor_t)(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends); + // Context management and operations for faster communication between backends, used for tensor parallelism (meta backend) + typedef void * (*ggml_backend_comm_init_t)(ggml_backend_t * backends, size_t n_backends); + typedef void (*ggml_backend_comm_free_t)(void * comm_ctx); + typedef bool (*ggml_backend_comm_allreduce_tensor_t)(void * comm_ctx, struct ggml_tensor ** tensors); + // Split buffer type for tensor parallelism (old) typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split); // Set the number of threads for the backend diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 0a8eea4e945..1ee3eeb4d96 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1419,22 +1419,48 @@ struct ggml_backend_meta_context { size_t max_tmp_size = 0; size_t max_subgraphs = 0; + void * comm_ctx = nullptr; + ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr; + ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) { const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev); name = "Meta("; + std::vector simple_backends; backend_configs.reserve(n_devs); + simple_backends.reserve(n_devs); for (size_t i = 0; i < n_devs; i++) { ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i); if (i > 0) { name += ","; } name += ggml_backend_dev_name(simple_dev); - backend_configs.emplace_back(ggml_backend_dev_init(simple_dev, params)); + simple_backends.push_back(ggml_backend_dev_init(simple_dev, params)); + backend_configs.emplace_back(simple_backends.back()); } name += ")"; + + if (n_devs > 1) { + ggml_backend_comm_init_t comm_init = (ggml_backend_comm_init_t) ggml_backend_reg_get_proc_address( + ggml_backend_dev_backend_reg(ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_init"); + if (comm_init != nullptr) { + comm_ctx = comm_init(simple_backends.data(), simple_backends.size()); + } + } + if (comm_ctx != nullptr) { + comm_allreduce = (ggml_backend_comm_allreduce_tensor_t) + ggml_backend_reg_get_proc_address(ggml_backend_dev_backend_reg( + ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_allreduce_tensor"); + GGML_ASSERT(comm_allreduce != nullptr); + } } ~ggml_backend_meta_context() { + if (comm_ctx != nullptr) { + ggml_backend_comm_free_t comm_free = (ggml_backend_comm_free_t) ggml_backend_reg_get_proc_address( + ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_configs[0].backend)), "ggml_backend_comm_free"); + GGML_ASSERT(comm_free != nullptr); + comm_free(comm_ctx); + } for (auto & bc : backend_configs) { ggml_backend_free(bc.backend); } @@ -1845,20 +1871,15 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, if (n_backends > 1 && i < n_subgraphs - 1) { bool backend_allreduce_success = false; - ggml_backend_allreduce_tensor_t allreduce_tensor = (ggml_backend_allreduce_tensor_t) ggml_backend_reg_get_proc_address( - ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_ctx->backend_configs[0].backend)), "ggml_backend_allreduce_tensor"); - if (allreduce_tensor) { - std::vector backends; - backends.reserve(n_backends); + if (backend_ctx->comm_ctx) { std::vector nodes; nodes.reserve(n_backends); for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; - backends.push_back(bcj.backend); ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main; nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]); } - backend_allreduce_success = allreduce_tensor(backends.data(), nodes.data(), n_backends); + backend_allreduce_success = backend_ctx->comm_allreduce(backend_ctx->comm_ctx, nodes.data()); } if (!backend_allreduce_success) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 8a4246223b5..2e5eaff9bf4 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1092,10 +1092,6 @@ struct ggml_cuda_device_info { cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {}; std::array default_tensor_split = {}; - -#ifdef GGML_USE_NCCL - ncclComm_t comms[GGML_CUDA_MAX_DEVICES]; -#endif // GGML_USE_NCCL }; const ggml_cuda_device_info & ggml_cuda_info(); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 3113de017f0..5d81befec32 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -338,14 +338,6 @@ static ggml_cuda_device_info ggml_cuda_init() { } } -#ifdef GGML_USE_NCCL - int dev_ids[GGML_CUDA_MAX_DEVICES]; - for (int id = 0; id < info.device_count; ++id) { - dev_ids[id] = id; - } - NCCL_CHECK(ncclCommInitAll(info.comms, info.device_count, dev_ids)); -#endif // GGML_USE_NCCL - return info; } @@ -1125,7 +1117,69 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host, }; -bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends) { +#ifdef GGML_USE_NCCL +struct ggml_backend_cuda_comm_context { + std::vector backends; + std::vector comms; + + ~ggml_backend_cuda_comm_context() { + for (ncclComm_t comm : comms) { + NCCL_CHECK(ncclCommDestroy(comm)); + } + } +}; +#endif // GGML_USE_NCCL + +static void ggml_backend_cuda_comm_free(void * comm_ctx_v) { +#ifdef GGML_USE_NCCL + if (comm_ctx_v == nullptr) { + return; + } + ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v; + delete comm_ctx; +#else + GGML_UNUSED(comm_ctx_v); +#endif // GGML_USE_NCCL +} + +static void * ggml_backend_cuda_comm_init(ggml_backend_t * backends, size_t n_backends) { +#ifdef GGML_USE_NCCL + for (size_t i = 0; i < n_backends; i++) { + if (!ggml_backend_is_cuda(backends[i])) { + return nullptr; + } + } + ggml_backend_cuda_comm_context * ret = new ggml_backend_cuda_comm_context; + std::vector dev_ids; + ret->backends.reserve(n_backends); + dev_ids.reserve(n_backends); + for (size_t i = 0; i < n_backends; i++) { + ret->backends.push_back(backends[i]); + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + dev_ids.push_back(cuda_ctx->device); + } + + ret->comms.resize(n_backends); + NCCL_CHECK(ncclCommInitAll(ret->comms.data(), n_backends, dev_ids.data())); + return ret; +#else + // If NCCL is installed it is used by default for optimal performance. + // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package. + // RCCL is disabled by default, users are explicitly opting in. + // Therefore print no warning for RCCL. +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + static bool warning_printed = false; + if (!warning_printed) { + GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__); + warning_printed = true; + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + GGML_UNUSED_VARS(backends, n_backends); + return nullptr; +#endif // GGML_USE_NCCL +} + +static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) { #ifdef GGML_USE_NCCL const int64_t ne = ggml_nelements(tensors[0]); // FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0 @@ -1133,21 +1187,24 @@ bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_t if (ne == 0) { return true; } + + GGML_ASSERT(comm_ctx_v != nullptr); + ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v; + const size_t n_backends = comm_ctx->backends.size(); + for (size_t i = 0; i < n_backends; ++i) { GGML_ASSERT(tensors[i] != nullptr); GGML_ASSERT(ggml_nelements(tensors[i]) == ne); GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i])); } - const ggml_cuda_device_info info = ggml_cuda_info(); - // For small tensors, simply reduce them as FP32. // The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0. if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) { NCCL_CHECK(ncclGroupStart()); for (size_t i = 0; i < n_backends; ++i) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; - NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream())); + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, comm_ctx->comms[i], cuda_ctx->stream())); } NCCL_CHECK(ncclGroupEnd()); @@ -1160,44 +1217,33 @@ bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_t ggml_cuda_pool_alloc tmp[GGML_CUDA_MAX_DEVICES]; for (size_t i = 0; i < n_backends; ++i) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; tmp[i].pool = &cuda_ctx->pool(); tmp[i].alloc(ne); - ggml_cuda_set_device(i); + ggml_cuda_set_device(cuda_ctx->device); to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream()); CUDA_CHECK(cudaGetLastError()); } NCCL_CHECK(ncclGroupStart()); for (size_t i = 0; i < n_backends; ++i) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; - NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, info.comms[cuda_ctx->device], cuda_ctx->stream())); + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, comm_ctx->comms[i], cuda_ctx->stream())); } NCCL_CHECK(ncclGroupEnd()); for (size_t i = 0; i < n_backends; ++i) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; - ggml_cuda_set_device(i); + ggml_cuda_set_device(cuda_ctx->device); to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream()); CUDA_CHECK(cudaGetLastError()); } return true; #else - // If NCCL is installed it is used by default for optimal performance. - // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package. - // RCCL is disabled by default, users are explicitly opting in. - // Therefore print no warning for RCCL. -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - static bool warning_printed = false; - if (!warning_printed) { - GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__); - warning_printed = true; - } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - GGML_UNUSED_VARS(backends, tensors, n_backends); + GGML_UNUSED_VARS(comm_ctx_v, tensors); return false; #endif // GGML_USE_NCCL } @@ -5220,8 +5266,14 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { GGML_UNUSED(reg); - if (strcmp(name, "ggml_backend_allreduce_tensor") == 0) { - return (void *)ggml_backend_cuda_allreduce_tensor; + if (strcmp(name, "ggml_backend_comm_init") == 0) { + return (void *)ggml_backend_cuda_comm_init; + } + if (strcmp(name, "ggml_backend_comm_free") == 0) { + return (void *)ggml_backend_cuda_comm_free; + } + if (strcmp(name, "ggml_backend_comm_allreduce_tensor") == 0) { + return (void *)ggml_backend_cuda_comm_allreduce_tensor; } if (strcmp(name, "ggml_backend_split_buffer_type") == 0) { return (void *)ggml_backend_cuda_split_buffer_type; From 9638e29657e7c547212284a2b473335c31a86a05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 15 Apr 2026 16:01:46 +0200 Subject: [PATCH 438/831] CUDA: require explicit opt-in for P2P access (llama/21910) --- ggml/src/ggml-cuda/ggml-cuda.cu | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 5d81befec32..c17db3875ad 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -324,16 +324,18 @@ static ggml_cuda_device_info ggml_cuda_init() { // configure logging to stdout // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); - for (int id = 0; id < info.device_count; ++id) { - ggml_cuda_set_device(id); - for (int id_other = 0; id_other < info.device_count; ++id_other) { - if (id == id_other) { - continue; - } - int can_access_peer; - CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); - if (can_access_peer) { - CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0)); + if (getenv("GGML_CUDA_P2P") != nullptr) { + for (int id = 0; id < info.device_count; ++id) { + ggml_cuda_set_device(id); + for (int id_other = 0; id_other < info.device_count; ++id_other) { + if (id == id_other) { + continue; + } + int can_access_peer; + CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); + if (can_access_peer) { + CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0)); + } } } } From 2a785c596944da4cc67d15c8600f606b5021e7bb Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 15 Apr 2026 09:14:40 -0700 Subject: [PATCH 439/831] ggml-webgpu: Fix dequantization helpers to not pass in pointers (llama/21872) * Fix dequantization helpers to not pass in pointers * Increase XIELU precision --- .../wgsl-shaders/common_decls.tmpl | 73 +++++++++----- .../ggml-webgpu/wgsl-shaders/get_rows.wgsl | 90 +++++++++--------- .../src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl | 90 +++++++++--------- .../wgsl-shaders/mul_mat_decls.tmpl | 94 +++++++++---------- .../ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl | 2 + .../wgsl-shaders/mul_mat_reg_tile.wgsl | 2 + .../wgsl-shaders/mul_mat_subgroup_matrix.wgsl | 3 +- .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 48 +++++----- ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 13 +-- 9 files changed, 223 insertions(+), 192 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 0d3501c34a2..62fe72ee3b1 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -9,42 +9,65 @@ fn get_byte_i32(value: u32, index: u32) -> i32 { #endif #ifdef U32_DEQUANT_HELPERS -fn load_u16_at( - buf: ptr, read_write>, - byte_offset: u32) -> u32 { - let word = buf[byte_offset / 4]; - let shift = (byte_offset & 0x2) * 8; - return (word >> shift) & 0xFFFF; +#ifdef DECLARE_BYTE_LOADERS_SRC +fn load_u16_at_src(byte_offset: u32) -> u32 { + let word = src[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + return (word >> shift) & 0xFFFFu; } -fn load_u32_at( - buf: ptr, read_write>, - byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4; - let shift = (byte_offset & 0x3) * 8; - let lo = buf[word_idx]; - let hi = buf[word_idx + 1]; - let shifted = (lo >> shift) | (hi << (32 - shift)); - return select(shifted, lo, shift == 0); +fn load_u32_at_src(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 0x3u) * 8u; + let lo = src[word_idx]; + let hi = src[word_idx + 1u]; + let shifted = (lo >> shift) | (hi << (32u - shift)); + return select(shifted, lo, shift == 0u); } -fn load_f16_at( - buf: ptr, read_write>, - byte_offset: u32) -> f16 { - let packed = unpack2x16float(load_u16_at(buf, byte_offset)); +fn load_f16_at_src(byte_offset: u32) -> f16 { + let packed = unpack2x16float(load_u16_at_src(byte_offset)); return f16(packed[0]); } -fn load_f16_as_f32_at( - buf: ptr, read_write>, - byte_offset: u32) -> f32 { - let word = buf[byte_offset / 4]; - let shift = (byte_offset & 0x2) * 8; - let d_bits = (word >> shift) & 0xFFFF; +fn load_f16_as_f32_at_src(byte_offset: u32) -> f32 { + let word = src[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + let d_bits = (word >> shift) & 0xFFFFu; return unpack2x16float(d_bits)[0]; } #endif +#ifdef DECLARE_BYTE_LOADERS_SRC0 +fn load_u16_at_src0(byte_offset: u32) -> u32 { + let word = src0[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_u32_at_src0(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 0x3u) * 8u; + let lo = src0[word_idx]; + let hi = src0[word_idx + 1u]; + let shifted = (lo >> shift) | (hi << (32u - shift)); + return select(shifted, lo, shift == 0u); +} + +fn load_f16_at_src0(byte_offset: u32) -> f16 { + let packed = unpack2x16float(load_u16_at_src0(byte_offset)); + return f16(packed[0]); +} + +fn load_f16_as_f32_at_src0(byte_offset: u32) -> f32 { + let word = src0[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + let d_bits = (word >> shift) & 0xFFFFu; + return unpack2x16float(d_bits)[0]; +} +#endif +#endif + #ifdef Q4_1_T diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index 3c8b84c9ac3..1415798fa6b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -1,6 +1,8 @@ enable f16; +#define DECLARE_BYTE_LOADERS_SRC #include "common_decls.tmpl" + #ifdef F32_VEC fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset]; @@ -28,10 +30,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q4_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); for (var j: u32 = 0u; j < 4; j++) { let q_byte_offset = block_byte_base + 2 + j * 4; - let q_packed = load_u32_at(&src, q_byte_offset); + let q_packed = load_u32_at_src(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; @@ -66,11 +68,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q5_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); - let qh_packed = load_u32_at(&src, block_byte_base + 2); + let d = load_f16_as_f32_at_src(block_byte_base); + let qh_packed = load_u32_at_src(block_byte_base + 2); for (var j: u32 = 0; j < 4; j++) { let q_byte_offset = block_byte_base + 6 + j * 4; - let q_packed = load_u32_at(&src, q_byte_offset); + let q_packed = load_u32_at_src(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); @@ -113,10 +115,10 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef Q8_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); for (var j: u32 = 0u; j < 8u; j++) { let q_byte_offset = block_byte_base + 2u + j * 4u; - let q_packed = load_u32_at(&src, q_byte_offset); + let q_packed = load_u32_at_src(q_byte_offset); for (var k: u32 = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; @@ -162,16 +164,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes // Bytes 108-109: f16 scale 'd' - let d = load_f16_as_f32_at(&src, block_byte_base + 108); + let d = load_f16_as_f32_at_src(block_byte_base + 108); // Bytes 96-107: 12 bytes of scales (3 u32s) let kmask1: u32 = 0x03030303; let kmask2: u32 = 0x0f0f0f0f; var scale_vals: array; - scale_vals[0] = load_u32_at(&src, block_byte_base + 96); - scale_vals[1] = load_u32_at(&src, block_byte_base + 100); - scale_vals[2] = load_u32_at(&src, block_byte_base + 104); + scale_vals[0] = load_u32_at_src(block_byte_base + 96); + scale_vals[1] = load_u32_at_src(block_byte_base + 100); + scale_vals[2] = load_u32_at_src(block_byte_base + 104); var tmp: u32 = scale_vals[2]; scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); @@ -182,13 +184,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { // Bytes 0-31: 32 bytes of hmask (8 u32s) var hmask_vals: array; for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = load_u32_at(&src, block_byte_base + i * 4); + hmask_vals[i] = load_u32_at_src(block_byte_base + i * 4); } // Bytes 32-95: 64 bytes of qs (16 u32s) var qs_vals: array; for (var i: u32 = 0u; i < 16; i++) { - qs_vals[i] = load_u32_at(&src, block_byte_base + 32 + i * 4); + qs_vals[i] = load_u32_at_src(block_byte_base + 32 + i * 4); } var dst_i = dst_base + offset * 256; @@ -286,24 +288,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes // Bytes 208-209: f16 scale 'd' - let d = load_f16_as_f32_at(&src, block_byte_base + 208); + let d = load_f16_as_f32_at_src(block_byte_base + 208); // Bytes 0-127: 128 bytes of ql (32 u32s) var ql_vals: array; for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = load_u32_at(&src, block_byte_base + i * 4); + ql_vals[i] = load_u32_at_src(block_byte_base + i * 4); } // Bytes 128-191: 64 bytes of qh (16 u32s) var qh_vals: array; for (var i: u32 = 0; i < 16u; i++) { - qh_vals[i] = load_u32_at(&src, block_byte_base + 128 + i * 4u); + qh_vals[i] = load_u32_at_src(block_byte_base + 128 + i * 4u); } // Bytes 192-207: 16 bytes of scales (4 u32s) var scale_vals: array; for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = load_u32_at(&src, block_byte_base + 192 + i * 4); + scale_vals[i] = load_u32_at_src(block_byte_base + 192 + i * 4); } var dst_i = dst_base + offset * 256; @@ -345,13 +347,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ2_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 32; ib += 4) { let aux0_offset = block_byte_base + 2 + ib * 2; let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; - let aux0 = load_u32_at(&src, aux0_offset); - let aux1 = load_u32_at(&src, aux1_offset); + let aux0 = load_u32_at_src(aux0_offset); + let aux1 = load_u32_at_src(aux1_offset); let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; for (var l: u32 = 0; l < 4; l++) { let ig = get_byte(aux0, l) * 8; @@ -373,12 +375,12 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ2_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; var scale_vals = array( - load_u32_at(&src, block_byte_base + 66), - load_u32_at(&src, block_byte_base + 70) + load_u32_at_src(block_byte_base + 66), + load_u32_at_src(block_byte_base + 70) ); for (var ib: u32 = 0; ib < 32; ib += 4) { @@ -389,7 +391,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { ); for (var l: u32 = 0; l < 4; l++) { let qs_offset = block_byte_base + 2 + (ib + l) * 2; - let qs_val = load_u32_at(&src, qs_offset) & 0xFFFF; + let qs_val = load_u32_at_src(qs_offset) & 0xFFFF; let ig = (qs_val & 511) * 8; let is = qs_val >> 9; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); @@ -408,21 +410,21 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ2_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; var qs_vals : array; for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = load_u32_at(&src, block_byte_base + 2 + i * 4); + qs_vals[i] = load_u32_at_src(block_byte_base + 2 + i * 4); } var qh_vals: array; - qh_vals[0] = load_u32_at(&src, block_byte_base + 66); - qh_vals[1] = load_u32_at(&src, block_byte_base + 70); + qh_vals[0] = load_u32_at_src(block_byte_base + 66); + qh_vals[1] = load_u32_at_src(block_byte_base + 70); var scale_vals: array; - scale_vals[0] = load_u32_at(&src, block_byte_base + 74); - scale_vals[1] = load_u32_at(&src, block_byte_base + 78); + scale_vals[0] = load_u32_at_src(block_byte_base + 74); + scale_vals[1] = load_u32_at_src(block_byte_base + 78); for (var ib: u32 = 0; ib < 8; ib ++) { let s = get_byte(scale_vals[ib / 4], ib % 4); @@ -450,16 +452,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ3_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 16; ib += 2) { let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; - let sc_sign = load_u32_at(&src, sc_sign_offset); + let sc_sign = load_u32_at_src(sc_sign_offset); let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; for (var l: u32 = 0; l < 4; l++) { let is = (sc_sign >> (7 * l)) & 127; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; + let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0); let ig2 = get_byte(ig_val, 1); for (var j: u32 = 0; j < 4; j++) { @@ -480,20 +482,20 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ3_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; var qh_vals = array( - load_u32_at(&src, block_byte_base + 66), - load_u32_at(&src, block_byte_base + 70) + load_u32_at_src(block_byte_base + 66), + load_u32_at_src(block_byte_base + 70) ); var sign_vals: array; for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = load_u32_at(&src, block_byte_base + 74 + i * 4); + sign_vals[i] = load_u32_at_src(block_byte_base + 74 + i * 4); } - var scale_vals = load_u32_at(&src, block_byte_base + 106); + var scale_vals = load_u32_at_src(block_byte_base + 106); for (var ib: u32 = 0; ib < 4; ib++) { let s = get_byte(scale_vals, ib); @@ -507,7 +509,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let sign_w = sign_vals[ib * 2 + k]; for (var l: u32 = 0; l < 4; l++) { let signs = get_byte(sign_w, l); - let ig_val = load_u32_at(&src, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; + let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); for (var j: u32 = 0; j < 4; j++) { @@ -529,13 +531,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ1_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 8; ib++) { - let qh = load_u32_at(&src, block_byte_base + 34 + ib * 2) & 0xFFFF; + let qh = load_u32_at_src(block_byte_base + 34 + ib * 2) & 0xFFFF; let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = load_u32_at(&src, block_byte_base + 2 + ib * 4); + let qs_w = load_u32_at_src(block_byte_base + 2 + ib * 4); for (var l: u32 = 0; l < 4; l++) { let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; for (var j: u32 = 0; j < 8; j++) { @@ -596,11 +598,11 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { #ifdef IQ4_NL fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at(&src, block_byte_base); + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 32; var qs: array; for (var i: u32 = 0; i < 4; i++) { - qs[i] = load_u32_at(&src, block_byte_base + 2 + i * 4); + qs[i] = load_u32_at_src(block_byte_base + 2 + i * 4); } for (var j: u32 = 0; j < 16; j++) { let qsb = get_byte(qs[j / 4], j % 4); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl index fdabaf09b2e..fcbefdeb802 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -1,7 +1,9 @@ enable f16; +#define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" + #ifdef FLOAT const BLOCK_SIZE = 1u; @@ -21,11 +23,11 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q4_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var sum: f32 = 0.0; for (var j: u32 = 0; j < 4; j++) { let q_byte_offset = block_byte_base + 2 + j * 4; - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; @@ -63,12 +65,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q5_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var sum: f32 = 0.0; - let qh_packed = load_u32_at(&src0, block_byte_base + 2); + let qh_packed = load_u32_at_src0(block_byte_base + 2); for (var j: u32 = 0; j < 4; j++) { let q_byte_offset = block_byte_base + 6 + j * 4; - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; @@ -110,11 +112,11 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef Q8_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var sum: f32 = 0.0; for (var j: u32 = 0; j < 8; j++) { let q_byte_offset = block_byte_base + 2 + j * 4; - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; @@ -184,7 +186,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes // Bytes 108-109: f16 scale 'd' - let d = load_f16_as_f32_at(&src0, block_byte_base + 108); + let d = load_f16_as_f32_at_src0(block_byte_base + 108); // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, // and 2-bits from the last 4 bytes @@ -192,9 +194,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let kmask1: u32 = 0x03030303; let kmask2: u32 = 0x0f0f0f0f; var scale_vals: array; - scale_vals[0] = load_u32_at(&src0, block_byte_base + 96); - scale_vals[1] = load_u32_at(&src0, block_byte_base + 100); - scale_vals[2] = load_u32_at(&src0, block_byte_base + 104); + scale_vals[0] = load_u32_at_src0(block_byte_base + 96); + scale_vals[1] = load_u32_at_src0(block_byte_base + 100); + scale_vals[2] = load_u32_at_src0(block_byte_base + 104); var tmp: u32 = scale_vals[2]; scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); @@ -205,13 +207,13 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { // Bytes 0-31: 32 bytes of hmask (8 u32s) var hmask_vals: array; for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = load_u32_at(&src0, block_byte_base + i * 4); + hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4); } // Bytes 32-95: 64 bytes of qs (16 u32s) var qs_vals: array; for (var i: u32 = 0u; i < 16; i++) { - qs_vals[i] = load_u32_at(&src0, block_byte_base + 32 + i * 4); + qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4); } var sum = 0.0; @@ -313,24 +315,24 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes // Bytes 208-209: f16 scale 'd' - let d = load_f16_as_f32_at(&src0, block_byte_base + 208); + let d = load_f16_as_f32_at_src0(block_byte_base + 208); // Bytes 0-127: 128 bytes of ql (32 u32s) var ql_vals: array; for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = load_u32_at(&src0, block_byte_base + i * 4); + ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4); } // Bytes 128-191: 64 bytes of qh (16 u32s) var qh_vals: array; for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = load_u32_at(&src0, block_byte_base + 128 + i * 4); + qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4); } // Bytes 192-207: 16 bytes of scales (4 u32s) var scale_vals: array; for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = load_u32_at(&src0, block_byte_base + 192 + i * 4); + scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4); } var sum = 0.0; @@ -374,14 +376,14 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 32; ib += 4) { let aux0_offset = block_byte_base + 2 + ib * 2; let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; - let aux0 = load_u32_at(&src0, aux0_offset); - let aux1 = load_u32_at(&src0, aux1_offset); + let aux0 = load_u32_at_src0(aux0_offset); + let aux1 = load_u32_at_src0(aux1_offset); let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; for (var l: u32 = 0; l < 4; l++) { let ig = get_byte(aux0, l) * 8; @@ -402,12 +404,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var scale_vals = array( - load_u32_at(&src0, block_byte_base + 66), - load_u32_at(&src0, block_byte_base + 70) + load_u32_at_src0(block_byte_base + 66), + load_u32_at_src0(block_byte_base + 70) ); var sum = 0.0; @@ -419,7 +421,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { ); for (var l: u32 = 0; l < 4; l++) { let qs_offset = block_byte_base + 2 + (ib + l) * 2; - let qs_val = load_u32_at(&src0, qs_offset) & 0xFFFF; + let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF; let ig = (qs_val & 511) * 8; let is = qs_val >> 9; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); @@ -439,21 +441,21 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ2_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var qs_vals : array; for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4); + qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4); } var qh_vals: array; - qh_vals[0] = load_u32_at(&src0, block_byte_base + 66); - qh_vals[1] = load_u32_at(&src0, block_byte_base + 70); + qh_vals[0] = load_u32_at_src0(block_byte_base + 66); + qh_vals[1] = load_u32_at_src0(block_byte_base + 70); var scale_vals: array; - scale_vals[0] = load_u32_at(&src0, block_byte_base + 74); - scale_vals[1] = load_u32_at(&src0, block_byte_base + 78); + scale_vals[0] = load_u32_at_src0(block_byte_base + 74); + scale_vals[1] = load_u32_at_src0(block_byte_base + 78); var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib ++) { @@ -483,17 +485,17 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ3_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 16; ib += 2) { let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; - let sc_sign = load_u32_at(&src0, sc_sign_offset); + let sc_sign = load_u32_at_src0(sc_sign_offset); let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; for (var l: u32 = 0; l < 4; l++) { let is = (sc_sign >> (7 * l)) & 127; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; + let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0); let ig2 = get_byte(ig_val, 1); for (var j: u32 = 0; j < 4; j++) { @@ -515,20 +517,20 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ3_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var qh_vals = array( - load_u32_at(&src0, block_byte_base + 66), - load_u32_at(&src0, block_byte_base + 70) + load_u32_at_src0(block_byte_base + 66), + load_u32_at_src0(block_byte_base + 70) ); var sign_vals: array; for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = load_u32_at(&src0, block_byte_base + 74 + i * 4); + sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4); } - var scale_vals = load_u32_at(&src0, block_byte_base + 106); + var scale_vals = load_u32_at_src0(block_byte_base + 106); var sum = 0.0; for (var ib: u32 = 0; ib < 4; ib++) { @@ -543,7 +545,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let sign_w = sign_vals[ib * 2 + k]; for (var l: u32 = 0; l < 4; l++) { let signs = get_byte(sign_w, l); - let ig_val = load_u32_at(&src0, block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; + let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); for (var j: u32 = 0; j < 4; j++) { @@ -566,14 +568,14 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ1_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 256; var sum = 0.0; for (var ib: u32 = 0; ib < 8; ib++) { - let qh = load_u32_at(&src0, block_byte_base + 34 + ib * 2) & 0xFFFF; + let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF; let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = load_u32_at(&src0, block_byte_base + 2 + ib * 4); + let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4); for (var l: u32 = 0; l < 4; l++) { let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; for (var j: u32 = 0; j < 8; j++) { @@ -638,12 +640,12 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { #ifdef IQ4_NL fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at(&src0, block_byte_base); + let d = load_f16_as_f32_at_src0(block_byte_base); var src1_i = src1_idx_base + offset * 32; var sum = 0.0; var qs: array; for (var i: u32 = 0; i < 4; i++) { - qs[i] = load_u32_at(&src0, block_byte_base + 2 + i * 4); + qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4); } for (var j: u32 = 0; j < 16; j++) { let qsb = get_byte(qs[j / 4], j % 4); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 56a76a6e6c4..5a323818260 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -84,11 +84,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); + let d = load_f16_at_src0(block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; @@ -125,12 +125,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let m = load_f16_at(&src0, block_byte_base + 2u); + let d = load_f16_at_src0(block_byte_base); + let m = load_f16_at_src0(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); let q_lo = f16(q_byte & 0xF) * d + m; @@ -171,12 +171,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let qh_packed = load_u32_at(&src0, block_byte_base + 2u); + let d = load_f16_at_src0(block_byte_base); + let qh_packed = load_u32_at_src0(block_byte_base + 2u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -225,14 +225,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let m = load_f16_at(&src0, block_byte_base + 2u); - let qh_packed = load_u32_at(&src0, block_byte_base + 4u); + let d = load_f16_at_src0(block_byte_base); + let m = load_f16_at_src0(block_byte_base + 2u); + let qh_packed = load_u32_at_src0(block_byte_base + 4u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -277,11 +277,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); + let d = load_f16_at_src0(block_byte_base); for (var j = 0u; j < F16_PER_THREAD; j+=2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -317,12 +317,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let m = load_f16_at(&src0, block_byte_base + 2u); + let d = load_f16_at_src0(block_byte_base); + let m = load_f16_at_src0(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j+=2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); @@ -359,8 +359,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base + 80u); - let dmin = load_f16_at(&src0, block_byte_base + 82u); + let d = load_f16_at_src0(block_byte_base + 80u); + let dmin = load_f16_at_src0(block_byte_base + 82u); // Decode the element at position k_in_block let block_of_32 = k_in_block / 32u; @@ -373,14 +373,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let is = k_in_block / 16u; - let sc_packed = load_u32_at(&src0, block_byte_base + 4u * (is / 4u)); + let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u)); let sc = get_byte(sc_packed, is % 4u); let dl = d * f16(sc & 0xFu); let ml = dmin * f16(sc >> 4u); let q_idx = q_b_idx + k + l; - let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 3u; @@ -413,7 +413,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base + 108u); + let d = load_f16_at_src0(block_byte_base + 108u); // Load and unpack scales let kmask1: u32 = 0x03030303u; @@ -421,7 +421,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 var scale_vals: array; for (var i: u32 = 0u; i < 4u; i++) { - scale_vals[i] = load_u32_at(&src0, block_byte_base + 96u + 4u * i); + scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i); } var tmp: u32 = scale_vals[2]; @@ -433,12 +433,12 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load hmask and qs arrays var hmask_vals: array; for (var i: u32 = 0u; i < 8u; i++) { - hmask_vals[i] = load_u32_at(&src0, block_byte_base + 4u * i); + hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i); } var qs_vals: array; for (var i: u32 = 0u; i < 16u; i++) { - qs_vals[i] = load_u32_at(&src0, block_byte_base + 32u + 4u * i); + qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i); } let half = k_in_block / 128u; // 0 or 1 @@ -499,8 +499,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let dmin = load_f16_at(&src0, block_byte_base + 2u); + let d = load_f16_at_src0(block_byte_base); + let dmin = load_f16_at_src0(block_byte_base + 2u); // Map k_in_block to loop structure: // Outer loop over 64-element groups (alternating q_b_idx) @@ -520,14 +520,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let scale_base = block_byte_base + 4u; if (is < 4u) { - let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u); - let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); + let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); + let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = sc_byte & 63u; mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u); - let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u); - let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); + let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); @@ -537,7 +537,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); let qs_val = (q_byte >> shift) & 0xFu; @@ -571,8 +571,8 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at(&src0, block_byte_base); - let dmin = load_f16_at(&src0, block_byte_base + 2u); + let d = load_f16_at_src0(block_byte_base); + let dmin = load_f16_at_src0(block_byte_base + 2u); // The original loop processes elements in groups of 64 @@ -597,14 +597,14 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let scale_base = block_byte_base + 4u; if (is < 4u) { - let sc_byte = get_byte(load_u32_at(&src0, scale_base), is % 4u); - let min_byte = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); + let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); + let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = sc_byte & 63u; mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(load_u32_at(&src0, scale_base + 8), (is + 4u) % 4u); - let sc_hi = get_byte(load_u32_at(&src0, scale_base), (is - 4u) % 4u); - let min_hi = get_byte(load_u32_at(&src0, scale_base + 4), is % 4u); + let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); @@ -614,11 +614,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); let q_idx = q_b_idx + l; - let q_packed = load_u32_at(&src0, block_byte_base + 48u + 4u * (q_idx / 4u)); + let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u)); let q_byte = get_byte(q_packed, q_idx % 4u); - let qh_packed = load_u32_at(&src0, block_byte_base + 16u + 4u * (l / 4u)); + let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u)); let qh_byte = get_byte(qh_packed, l % 4u); @@ -666,17 +666,17 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only ql13 word needed let ql13_flat = ql_b_idx + l; - let ql13 = load_u32_at(&src0, block_byte_base + ql13_flat); + let ql13 = load_u32_at_src0(block_byte_base + ql13_flat); let ql13_b = get_byte(ql13, 0u); // Load only ql24 word needed let ql24_flat = ql_b_idx + l + 32u; - let ql24 = load_u32_at(&src0, block_byte_base + ql24_flat); + let ql24 = load_u32_at_src0(block_byte_base + ql24_flat); let ql24_b = get_byte(ql24, 0u); // Load only qh word needed let qh_flat = qh_b_idx + l; - let qh = load_u32_at(&src0, block_byte_base + 128u + qh_flat); + let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat); let qh_b = get_byte(qh, 0u); let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); @@ -687,10 +687,10 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 // Load only the scale word needed let is = l / 16u; let sc_idx = sc_b_idx + is + quarter * 2u; - let sc = load_u32_at(&src0, block_byte_base + 192u + sc_idx); + let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx); let sc_val = get_byte_i32(sc, 0u); - let d = load_f16_at(&src0, block_byte_base + 208u); + let d = load_f16_at_src0(block_byte_base + 208u); var q_val: f16; if (quarter == 0u) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl index 5f763a6400a..91039ff2546 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl @@ -1,6 +1,8 @@ enable f16; +#define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" + #include "mul_mat_decls.tmpl" #ifdef VEC diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl index ee37e6d249c..98bbdeb83ba 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -1,6 +1,8 @@ enable f16; +#define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" + #include "mul_mat_decls.tmpl" #ifdef VEC diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl index 4151ce430b0..d86a72ce6e0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl @@ -3,7 +3,9 @@ enable f16; enable subgroups; enable chromium_experimental_subgroup_matrix; +#define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" + #include "mul_mat_decls.tmpl" // TODO: this shader path does not work with some models like qwen2.5 on Metal devices, f16 accumulation causes NaNs. @@ -196,4 +198,3 @@ fn main(@builtin(workgroup_id) wg_id: vec3, } } } - diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 6f6bcaf7940..9f7b3e32eca 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -1,7 +1,9 @@ enable f16; +#define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" + #ifdef VEC #define VEC_SIZE 4 @@ -65,10 +67,10 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at(&src0, block_byte_base)); + let d = f32(load_f16_at_src0(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; @@ -98,11 +100,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at(&src0, block_byte_base)); - let m = f32(load_f16_at(&src0, block_byte_base + 2u)); + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); let q_hi = f32((q_byte >> 4) & 0xF) * d + m; @@ -132,12 +134,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at(&src0, block_byte_base)); - let qh_packed = load_u32_at(&src0, block_byte_base + 2u); + let d = f32(load_f16_at_src0(block_byte_base)); + let qh_packed = load_u32_at_src0(block_byte_base + 2u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -176,13 +178,13 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at(&src0, block_byte_base)); - let m = load_f16_at(&src0, block_byte_base + 2u); - let qh_packed = load_u32_at(&src0, block_byte_base + 4u); + let d = f32(load_f16_at_src0(block_byte_base)); + let m = load_f16_at_src0(block_byte_base + 2u); + let qh_packed = load_u32_at_src0(block_byte_base + 4u); for (var j = 0u; j < 2; j++) { let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); let j_adjusted = j + (block_offset / 2u); @@ -221,11 +223,11 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at(&src0, block_byte_base)); + let d = f32(load_f16_at_src0(block_byte_base)); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; @@ -254,12 +256,12 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at(&src0, block_byte_base)); - let m = load_f16_at(&src0, block_byte_base + 2u); + let d = f32(load_f16_at_src0(block_byte_base)); + let m = load_f16_at_src0(block_byte_base + 2u); for (var j = 0u; j < F16_PER_THREAD; j += 2) { let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at(&src0, q_byte_offset); + let q_packed = load_u32_at_src0(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d + f32(m); @@ -309,13 +311,13 @@ fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { for (var i = ix; i < nb; i += 2u) { let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at(&src0, bbase + 208u)); + let d = f32(load_f16_at_src0(bbase + 208u)); - let ql1_u32 = load_u32_at(&src0, bbase + q_offset_l); - let ql2_u32 = load_u32_at(&src0, bbase + q_offset_l + 32u); - let qh_u32 = load_u32_at(&src0, bbase + 128u + q_offset_h); - let sc_u32_0 = load_u32_at(&src0, bbase + sc_base_byte); - let sc_u32_1 = load_u32_at(&src0, bbase + sc_base_byte + 4u); + let ql1_u32 = load_u32_at_src0(bbase + q_offset_l); + let ql2_u32 = load_u32_at_src0(bbase + q_offset_l + 32u); + let qh_u32 = load_u32_at_src0(bbase + 128u + q_offset_h); + let sc_u32_0 = load_u32_at_src0(bbase + sc_base_byte); + let sc_u32_1 = load_u32_at_src0(bbase + sc_base_byte + 4u); let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index 8c334817ccd..b8f1bca1284 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -147,15 +147,12 @@ fn main(@builtin(global_invocation_id) gid: vec3) { -9.010913, 9.010913))); #endif #ifdef XIELU + let val = f32(src[params.offset_src + src_idx]); let res = - select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) - - src[params.offset_src + src_idx]) * - TYPE(params.alpha_n) + - TYPE(params.beta) * src[params.offset_src + src_idx], - TYPE(params.alpha_p) * src[params.offset_src + src_idx] * - src[params.offset_src + src_idx] + - TYPE(params.beta) * src[params.offset_src + src_idx], - src[params.offset_src + src_idx] > 0.0); + TYPE(select( + ((exp(min(val, params.eps)) - 1.0) - val) * params.alpha_n + params.beta * val, + params.alpha_p * val * val + params.beta * val, + val > 0.0)); #endif #ifdef SOFTPLUS let src_f32 = f32(src[params.offset_src + src_idx]); From c6d1fbf31f3f8c611772e3a6bb3d3b35ac5f01eb Mon Sep 17 00:00:00 2001 From: Pasha Khosravi Date: Wed, 15 Apr 2026 09:38:38 -0700 Subject: [PATCH 440/831] cuda: Q1_0 initial backend (llama/21629) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [cuda] initial Q1_0 backend * remove unused code, fix AMD MMA guard * attempt to support dp4a * Apply suggestions from code review Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/common.cuh | 7 ++ ggml/src/ggml-cuda/convert.cu | 10 ++ ggml/src/ggml-cuda/dequantize.cuh | 22 +++++ ggml/src/ggml-cuda/getrows.cu | 4 + ggml/src/ggml-cuda/ggml-cuda.cu | 2 + ggml/src/ggml-cuda/mmq.cu | 4 + ggml/src/ggml-cuda/mmq.cuh | 93 +++++++++++++++++++ ggml/src/ggml-cuda/mmvq.cu | 8 ++ .../template-instances/generate_cu_files.py | 1 + .../template-instances/mmq-instance-q1_0.cu | 5 + ggml/src/ggml-cuda/vecdotq.cuh | 48 ++++++++++ 11 files changed, 204 insertions(+) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 2e5eaff9bf4..ad30ecd8fa5 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -924,6 +924,13 @@ struct ggml_cuda_type_traits { static constexpr int qr = 1; }; +template<> +struct ggml_cuda_type_traits { + static constexpr int qk = QK1_0; + static constexpr int qr = QR1_0; + static constexpr int qi = QI1_0; +}; + template<> struct ggml_cuda_type_traits { static constexpr int qk = QK4_0; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 79ccfe568a2..61630a35a29 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -711,6 +711,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: + return dequantize_block_cont_cuda; case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; case GGML_TYPE_Q4_1: @@ -767,6 +769,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: + return dequantize_block_cont_cuda; case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; case GGML_TYPE_Q4_1: @@ -822,6 +826,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_cuda; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda; case GGML_TYPE_Q4_0: return dequantize_block_cuda; case GGML_TYPE_Q4_1: @@ -843,6 +849,8 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_cuda; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda; case GGML_TYPE_Q4_0: return dequantize_block_cuda; case GGML_TYPE_Q4_1: @@ -864,6 +872,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F16: return convert_unary_cuda; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda; case GGML_TYPE_Q4_0: return dequantize_block_cuda; case GGML_TYPE_Q4_1: diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh index e060fb29fdc..9ae1342fc0e 100644 --- a/ggml/src/ggml-cuda/dequantize.cuh +++ b/ggml/src/ggml-cuda/dequantize.cuh @@ -1,5 +1,27 @@ #include "common.cuh" +static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ + const block_q1_0 * x = (const block_q1_0 *) vx; + + const float d = x[ib].d; + + const int bit_index_0 = iqs; + const int bit_index_1 = iqs + 1; + + const int byte_index_0 = bit_index_0 / 8; + const int bit_offset_0 = bit_index_0 % 8; + + const int byte_index_1 = bit_index_1 / 8; + const int bit_offset_1 = bit_index_1 % 8; + + // Extract bits: 1 = +d, 0 = -d (branchless) + const int bit_0 = (x[ib].qs[byte_index_0] >> bit_offset_0) & 1; + const int bit_1 = (x[ib].qs[byte_index_1] >> bit_offset_1) & 1; + + v.x = (2*bit_0 - 1) * d; + v.y = (2*bit_1 - 1) * d; +} + static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 2fab33243dd..e99cba63d34 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -179,6 +179,10 @@ static void ggml_cuda_get_rows_switch_src0_type( get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; + case GGML_TYPE_Q1_0: + get_rows_cuda_q(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; case GGML_TYPE_Q4_0: get_rows_cuda_q(src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c17db3875ad..790f53cead7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4831,6 +4831,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g switch (a->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4868,6 +4869,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_F32: case GGML_TYPE_BF16: case GGML_TYPE_I32: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 27b4145ac9a..3f01ff5bfb0 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -5,6 +5,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { switch (args.type_x) { + case GGML_TYPE_Q1_0: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_Q4_0: mul_mat_q_case(ctx, args, stream); break; @@ -270,6 +273,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t bool mmq_supported; switch (type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 18911141472..28b662df925 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -57,6 +57,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { switch (type_x) { + case GGML_TYPE_Q1_0: + return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: return MMQ_Q8_1_DS_LAYOUT_DS4; @@ -185,6 +187,7 @@ static constexpr __device__ int get_mmq_y_device() { static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { switch (type) { + case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0; case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; @@ -229,6 +232,7 @@ static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding."); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; @@ -302,6 +306,87 @@ static constexpr __device__ int mmq_get_nwarps_device() { // ------------------------------------------------------------ +template static __device__ __forceinline__ void load_tiles_q1_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0; + constexpr int threads_per_row = blocks_per_iter * QI1_0; + constexpr int nrows = warp_size / threads_per_row; + constexpr int scale_entries_per_block = QK1_0 / QK8_1; + constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block; + + const int txi = threadIdx.x % threads_per_row; + const int kbx = txi / QI1_0; + const int kqsx = txi % QI1_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx; + const int qs_offset = 4*kqsx; + const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) | + (bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24); + + int unpacked_bytes[8]; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int shift = j * 4; + const int bits4 = (qs0 >> shift) & 0x0F; + const int b0 = (bits4 & 0x01) ? 1 : -1; + const int b1 = (bits4 & 0x02) ? 1 : -1; + const int b2 = (bits4 & 0x04) ? 1 : -1; + const int b3 = (bits4 & 0x08) ? 1 : -1; + unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24); + } + + const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0; +#pragma unroll + for (int j = 0; j < 8; ++j) { +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j]; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j]; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + } + + const int ksx = threadIdx.x % scale_entries_per_row; + const int scale_block = ksx / scale_entries_per_block; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d; +#else + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + template static __device__ __forceinline__ void load_tiles_q4_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); @@ -3290,6 +3375,14 @@ static __device__ __forceinline__ void mmq_write_back_mma( template struct mmq_type_traits; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 07b10167bc4..8f55cace1a1 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -9,6 +9,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return vec_dot_q1_0_q8_1; case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1; case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1; case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1; @@ -36,6 +37,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return VDR_Q1_0_Q8_1_MMVQ; case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ; @@ -886,6 +888,12 @@ static void mul_mat_vec_q_switch_type( const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const int ids_stride, cudaStream_t stream) { switch (type_x) { + case GGML_TYPE_Q1_0: + mul_mat_vec_q_switch_ncols_dst + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; case GGML_TYPE_Q4_0: mul_mat_vec_q_switch_ncols_dst (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 40d51f93fa4..841059c15b5 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -32,6 +32,7 @@ SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n" TYPES_MMQ = [ + "GGML_TYPE_Q1_0", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu new file mode 100644 index 00000000000..f0686b0d0d8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q1_0); diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 40b2b41e7e8..d1741cc8d7b 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -106,6 +106,9 @@ static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) { // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q +#define VDR_Q1_0_Q8_1_MMVQ 1 // Process one 32-element chunk at a time for parallelism +#define VDR_Q1_0_Q8_1_MMQ 4 // Q1_0 has 128 bits (4 ints) per block + #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4 @@ -669,6 +672,51 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( return d6 * sumf_d; } +static __device__ __forceinline__ float vec_dot_q1_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q1_0 * bq1_0 = (const block_q1_0 *) vbq + kbx; + + // Q1_0: 128 elements with ONE scale + // Q8_1: 32 elements per block with individual scales + // iqs selects which of the 4 chunks of 32 elements to process (0-3) + + const float d1 = bq1_0->d; + + // Process only the chunk specified by iqs + const block_q8_1 * bq8_1_chunk = bq8_1 + iqs; + + // Load 32 bits (4 bytes) for this chunk from Q1_0 + const int offset = iqs * 4; + const int v = bq1_0->qs[offset + 0] | (bq1_0->qs[offset + 1] << 8) | + (bq1_0->qs[offset + 2] << 16) | (bq1_0->qs[offset + 3] << 24); + + // Unpack 32 bits into 32 signed values (-1 or +1) + int vi_bytes[8]; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int shift = j * 4; + const int bits4 = (v >> shift) & 0x0F; + const int b0 = (bits4 & 0x01) ? 1 : -1; + const int b1 = (bits4 & 0x02) ? 1 : -1; + const int b2 = (bits4 & 0x04) ? 1 : -1; + const int b3 = (bits4 & 0x08) ? 1 : -1; + vi_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24); + } + + // Compute dot product for this 32-element chunk + int sumi = 0; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int u = get_int_b4(bq8_1_chunk->qs, j); + sumi = ggml_cuda_dp4a(vi_bytes[j], u, sumi); + } + + // Apply Q1_0's single scale and this chunk's Q8_1 scale + const float d8 = __low2float(bq8_1_chunk->ds); + return d1 * d8 * sumi; +} + static __device__ __forceinline__ float vec_dot_q4_0_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { From 7fe6b8e171d23fe12847dbf42309d46144ea6407 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Wed, 15 Apr 2026 19:04:51 +0200 Subject: [PATCH 441/831] vulkan: optimize im2col (llama/21713) * vulkan: improve im2col memory write layout * cap workgroups * minimal device tuning * use vendor_id instead of subgroup size --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 13 ++- .../ggml-vulkan/vulkan-shaders/im2col.comp | 96 +++++++------------ 2 files changed, 46 insertions(+), 63 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b2a54bd85d0..702a249d754 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1394,7 +1394,7 @@ struct vk_op_im2col_push_constants { uint32_t IW; uint32_t IH; uint32_t OW; uint32_t OH; uint32_t KW; uint32_t KH; - uint32_t pelements; + uint32_t OH_batch; uint32_t CHW; int32_t s0; int32_t s1; int32_t p0; int32_t p1; @@ -10064,7 +10064,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t batch = src1->ne[is_2D ? 3 : 2]; - elements = { OW * KW * KH, OH, batch * IC }; + const uint32_t CHW = IC * KH * KW; + // Cap X workgroups to limit concurrent IC channel reads. + // The shader loops over X to cover the full CHW dimension. + // AMD prefers a lower limit + const uint32_t min_cap = ctx->device->vendor_id == VK_VENDOR_ID_AMD ? 512u : 4096u; + const uint32_t x_elements = std::min(CHW, std::max(min_cap, OW * KH * KW)); + elements = { x_elements, OW, OH * batch }; elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); } break; @@ -11727,7 +11733,6 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 - const uint32_t pelements = OW * KW * KH; const uint32_t batch = src1->ne[is_2D ? 3 : 2]; const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; @@ -11739,7 +11744,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co dst_addr, batch_offset, offset_delta, IC, IW, IH, OW, OH, KW, KH, - pelements, + OH * batch, IC * KH * KW, s0, s1, p0, p1, d0, d1, batch * IC }); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index 674f91e5ed2..ba4c2103f0c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -13,7 +13,7 @@ layout (push_constant) uniform parameter uint IW; uint IH; uint OW; uint OH; uint KW; uint KH; - uint pelements; + uint OH_batch; uint CHW; int s0; int s1; int p0; int p1; @@ -34,82 +34,60 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (buffer_reference) buffer D_ptr {D_TYPE d;}; #endif -void im2col(const uint y, const uint z) { - const uint gidx = gl_GlobalInvocationID.x; +void im2col(const uint ow, const uint z_idx) { + const uint oh = z_idx % p.OH; + const uint batch_idx = z_idx / p.OH; - const uint oh = y; - const uint batch = z / p.IC; - const uint ic = z % p.IC; + const uint gidx = gl_LocalInvocationID.x; + const uint src_batch = batch_idx * p.batch_offset; + const BDA_OFFSET_T dst_row = ((BDA_OFFSET_T(batch_idx) * p.OH + oh) * p.OW + ow) * p.CHW; - const uint src_base = ic * p.offset_delta + batch * p.batch_offset; - const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH); - const int oh_s1 = int(oh) * p.s1; - const uint ksize = p.OW * p.KH; + const uint KHKW = p.KH * p.KW; - const uint base_linear_idx = gidx * NUM_ITER; + uint wg_x = gl_WorkGroupID.x; + do { + const uint wg_offset = wg_x * 512; - uint current_kx = base_linear_idx / ksize; - const uint rem = base_linear_idx - (current_kx * ksize); - uint current_ky = rem / p.OW; - uint current_ix = rem % p.OW; + [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { + const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE; - A_TYPE values[NUM_ITER]; - BDA_OFFSET_T offset_dst[NUM_ITER]; - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - values[idx] = A_TYPE(0); - } - - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - - const uint linear_idx = base_linear_idx + idx; - - if (linear_idx >= p.pelements) { - continue; - } - - const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0; - const uint iih = oh_s1 + current_ky * p.d1 - p.p1; - - offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx; - - if ((iih < p.IH) && (iiw < p.IW)) { - values[idx] = data_a[src_base + iih * p.IW + iiw]; - } - - if (++current_ix == p.OW) { - current_ix = 0; - if (++current_ky == p.KH) { - current_ky = 0; - current_kx++; + if (chw_idx >= p.CHW) { + return; } - } - } - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + const uint ic = chw_idx / KHKW; + const uint rem = chw_idx - ic * KHKW; + const uint ky = rem / p.KW; + const uint kx = rem - ky * p.KW; - const uint linear_idx = base_linear_idx + idx; + const uint iiw = ow * p.s0 + kx * p.d0 - p.p0; + const uint iih = oh * p.s1 + ky * p.d1 - p.p1; - if (linear_idx >= p.pelements) { - continue; - } + A_TYPE val = A_TYPE(0); + if (iih < p.IH && iiw < p.IW) { + val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw]; + } #if BDA - D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]); - dst_addr.d = D_TYPE(values[idx]); + D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx)); + out_ptr.d = D_TYPE(val); #else - data_d[offset_dst[idx]] = D_TYPE(values[idx]); + data_d[dst_row + chw_idx] = D_TYPE(val); #endif - } + } + + wg_x += gl_NumWorkGroups.x; + } while (wg_x * 512 < p.CHW); } void main() { - uint y = gl_GlobalInvocationID.y; - while (y < p.OH) { + uint ow = gl_GlobalInvocationID.y; + while (ow < p.OW) { uint z = gl_GlobalInvocationID.z; - while (z < p.batch_IC) { - im2col(y, z); + while (z < p.OH_batch) { + im2col(ow, z); z += gl_NumWorkGroups.z; } - y += gl_NumWorkGroups.y; + ow += gl_NumWorkGroups.y; } } From f62bb133207f47e9975dfb511b119a304f23622d Mon Sep 17 00:00:00 2001 From: Katostrofik Date: Thu, 16 Apr 2026 01:34:05 -0400 Subject: [PATCH 442/831] Fix Q8_0 reorder: garbage on 2nd prompt + crash on full VRAM (llama/21638) * [SYCL] Fix Q8_0 reorder: add missing dequantize path for GEMM The Q8_0 reorder optimization (#21527) was missing a reorder-aware dequantizer for the GEMM code path used during prompt processing. After token generation reordered Q8_0 weights (via DMMV/MMVQ), the next prompt processing pass would read them with the standard dequantizer, producing garbage output. Add dequantize_block_q8_0_reorder() and wire it into both ggml_get_to_fp16_sycl() and ggml_get_to_fp32_sycl(), matching the pattern already used by Q4_0, Q4_K, and Q6_K. Fixes #21589 AI (Claude) was used to assist with root cause investigation and writing the kernel code. All code was human-reviewed and tested on real hardware. * SYCL: fix reorder crash when device memory is full The reorder optimization allocates a temporary buffer the full size of the weight tensor on the device. When VRAM is nearly full (large models on a single GPU), this allocation fails and the subsequent memcpy crashes on a NULL pointer. Fix: try device allocation first, fall back to host memory if device memory is full. The reorder kernel still works correctly reading from host memory over PCIe. This is slower for the one-time reorder (~21 t/s vs ~38 t/s on Intel Arc Pro B70), but the optimization is preserved for all subsequent inference. If both device and host allocation fail, skip the reorder and fall back to the unoptimized kernel path. Also fixes a bug where opt_for_reorder() marked tensors as reordered even when the reorder was skipped due to allocation failure. This caused DMMV/MMVQ kernels to read the original AoS data as if it were SoA, producing garbage output or NaN results. Tested on Intel Arc Pro B70 (32GB) with Q8_0, Q4_K_M models. Coding was AI-assisted (Claude), reviewed and tested on hardware by a human. Fixes #20478 * SYCL: add RAII temp buffer class + macro guard for host fallback Replace sycl_ext_malloc_with_fallback/sycl_ext_free_fallback free functions with sycl_reorder_temp_buffer RAII class. The host_fallback bool is now a private member, and cleanup happens automatically at scope exit. Add GGML_SYCL_HOST_MEM_FALLBACK cmake option (default ON) to guard the host memory fallback code path. Device access to host memory requires Linux kernel 6.8+ (Ubuntu 26.04+); users on older kernels can set -DGGML_SYCL_HOST_MEM_FALLBACK=OFF to disable it. Addresses arthw's review on PR #21638. Co-Authored-By: Claude Opus 4.6 (1M context) * SYCL: document GGML_SYCL_HOST_MEM_FALLBACK build option in SYCL.md Co-Authored-By: Claude Opus 4.6 (1M context) * SYCL: add reorder-aware DMMV dequantizers for Q4_K and Q6_K Q4_K and Q6_K had reorder support for MMVQ and GEMM paths but not DMMV. When the DMMV path encountered reordered data it would abort. Add DMMV kernels that read from the SOA reorder layout for both types. Same math as the non-reorder versions, different memory access pattern. Co-Authored-By: Claude Opus 4.6 (1M context) --------- Co-authored-by: Claude Opus 4.6 (1M context) --- ggml/CMakeLists.txt | 1 + ggml/src/ggml-sycl/CMakeLists.txt | 5 + ggml/src/ggml-sycl/convert.cpp | 33 ++- ggml/src/ggml-sycl/dequantize.hpp | 28 +++ ggml/src/ggml-sycl/dmmv.cpp | 321 +++++++++++++++++++++++++++++- ggml/src/ggml-sycl/ggml-sycl.cpp | 106 +++++++--- 6 files changed, 465 insertions(+), 29 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 8454eecde6e..6b65ecd6e5c 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -254,6 +254,7 @@ option(GGML_RPC "ggml: use RPC" option(GGML_SYCL "ggml: use SYCL" OFF) option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF) option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON) +option(GGML_SYCL_HOST_MEM_FALLBACK "ggml: allow host memory fallback in SYCL reorder (requires kernel 6.8+)" ON) option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON) set (GGML_SYCL_TARGET "INTEL" CACHE STRING "ggml: sycl target device") diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 7b07b227874..8e589fa238d 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -154,6 +154,11 @@ if (GGML_SYCL_GRAPH) target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH) endif() +if (GGML_SYCL_HOST_MEM_FALLBACK) + message(STATUS "find GGML_SYCL_HOST_MEM_FALLBACK") + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_HOST_MEM_FALLBACK) +endif() + if (GGML_SYCL_DEVICE_ARCH) target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index f12419426ae..f3c521b45f6 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -151,6 +151,25 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int } +template +static void dequantize_row_q8_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + int constexpr WARP_K = WARP_SIZE * QK8_0; + const int n_warp = (k + WARP_K - 1) / WARP_K; + GGML_ASSERT(k % QK8_0 == 0); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * + sycl::range<3>(1, 1, WARP_SIZE), + sycl::range<3>(1, 1, WARP_SIZE)), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{ + dequantize_block_q8_0_reorder(vx, y, k, item_ct1); + }); + +} + template static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -614,7 +633,12 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { case GGML_TYPE_Q5_1: return dequantize_block_sycl; case GGML_TYPE_Q8_0: - return dequantize_block_sycl; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q8_0_sycl_reorder; + } else { + return dequantize_block_sycl; + } case GGML_TYPE_Q2_K: return dequantize_row_q2_K_sycl; case GGML_TYPE_Q3_K: @@ -683,7 +707,12 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { case GGML_TYPE_Q5_1: return dequantize_block_sycl; case GGML_TYPE_Q8_0: - return dequantize_block_sycl; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q8_0_sycl_reorder; + } else { + return dequantize_block_sycl; + } case GGML_TYPE_Q2_K: return dequantize_row_q2_K_sycl; case GGML_TYPE_Q3_K: diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 68c3db30613..19fa88680d6 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -239,6 +239,34 @@ static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * } +// Dequantize Q8_0 from reorder layout: [all qs (k bytes)][all d values] +// Each thread handles one block of QK8_0 elements. +template +static void dequantize_block_q8_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t k, + const sycl::nd_item<3> &item_ct1) { + + const int64_t i = item_ct1.get_group(2); + const int64_t tid = item_ct1.get_local_id(2); + const int lane_ib = i * WARP_SIZE + tid; + + if (lane_ib >= k / QK8_0) { + return; + } + + dst_t * y_ptr = yy + lane_ib * QK8_0; + + auto qs = (const int8_t*)vx + lane_ib * QK8_0; + auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k) + lane_ib; + + const float d = float(*s_ptr); + +#pragma unroll + for (int l = 0; l < QK8_0; ++l) { + y_ptr[l] = d * qs[l]; + } + +} + template static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, const sycl::nd_item<3> &item_ct1) { diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 1c8b6f3771f..5577bf73b28 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -615,6 +615,162 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx, } } +static void dequantize_mul_mat_vec_q4_k_reorder(const void *__restrict__ vx, + const float *__restrict__ yy, + float *__restrict__ dst, + const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + // SOA base pointers for the reordered layout: + // [qs: nb * QK_K/2] [scales: nb * K_SCALE_SIZE] [dm: nb * sizeof(half2)] + const int nb = nrows * num_blocks_per_row; + const uint8_t * qs_base = (const uint8_t *)vx; + const uint8_t * scales_base = qs_base + (size_t)nb * (QK_K / 2); + const sycl::half2 * dm_base = (const sycl::half2 *)(scales_base + (size_t)nb * K_SCALE_SIZE); + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 + + const int il = tid/step; // 0...3 + const int ir = tid - step*il; // 0...7 or 0...3 + const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + +#if K_QUANTS_PER_ITERATION == 2 + uint32_t q32[4]; + const uint8_t * q4 = (const uint8_t *)q32; +#else + uint16_t q16[4]; + const uint8_t * q4 = (const uint8_t *)q16; +#endif + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const sycl::half2 dm_val = dm_base[bi]; + const float dall = dm_val[0]; + const float dmin = dm_val[1]; + + const uint16_t * a = (const uint16_t *)(scales_base + bi * K_SCALE_SIZE); + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + +#if K_QUANTS_PER_ITERATION == 2 + const uint32_t * q1 = (const uint32_t *)(qs_base + bi * (QK_K / 2) + q_offset); + const uint32_t * q2 = q1 + 16; + + q32[0] = q1[0] & 0x0f0f0f0f; + q32[1] = q1[0] & 0xf0f0f0f0; + q32[2] = q2[0] & 0x0f0f0f0f; + q32[3] = q2[0] & 0xf0f0f0f0; + + sycl::float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 4; ++l) { + s.x() += y1[l] * q4[l + 0]; s.y() += y1[l + 32] * q4[l + 4]; + s.z() += y2[l] * q4[l + 8]; s.w() += y2[l + 32] * q4[l + 12]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f + + s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) - + dmin * smin; +#else + const uint16_t * q1 = (const uint16_t *)(qs_base + bi * (QK_K / 2) + q_offset); + const uint16_t * q2 = q1 + 32; + + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[0] & 0xf0f0; + q16[2] = q2[0] & 0x0f0f; + q16[3] = q2[0] & 0xf0f0; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 2; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; + s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#endif + + } +#else + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); + + const int step = tid * K_QUANTS_PER_ITERATION; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + float tmp = 0; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const uint8_t * q = qs_base + bi * (QK_K / 2) + step; + const float * y = yy + i*QK_K + step; + const uint16_t * a = (const uint16_t *)(scales_base + bi * K_SCALE_SIZE); + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + const sycl::half2 dm_val = dm_base[bi]; + const float d = (float)dm_val[0]; + const float m = (float)dm_val[1]; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) + + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) + + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) + + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]); + } + tmp += sum; + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + /* DPCT1110:7: The total declared local variable size in device function dequantize_mul_mat_vec_q5_k exceeds 128 bytes and may cause high register @@ -864,6 +1020,129 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa } } +static void dequantize_mul_mat_vec_q6_k_reorder(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + // SOA base pointers for the reordered layout: + // [ql: nb * QK_K/2] [qh: nb * QK_K/4] [scales: nb * QK_K/16] [d: nb * sizeof(half)] + const int nb = nrows * num_blocks_per_row; + const uint8_t * ql_base = (const uint8_t *)vx; + const uint8_t * qh_base = ql_base + (size_t)nb * (QK_K / 2); + const int8_t * scales_base = (const int8_t *)(qh_base + (size_t)nb * (QK_K / 4)); + const sycl::half * d_base = (const sycl::half *)((const uint8_t *)scales_base + (size_t)nb * (QK_K / 16)); + +#if QK_K == 256 + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1 + + const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + +#if K_QUANTS_PER_ITERATION == 1 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 + const int is = 0; +#else + const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int is = in / 4; +#endif + const int ql_offset = 64*im + l0; + const int qh_offset = 32*im + l0; + const int s_offset = 8*im + is; + const int y_offset = 128*im + l0; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * ql = ql_base + bi * (QK_K / 2) + ql_offset; + const uint8_t * qh = qh_base + bi * (QK_K / 4) + qh_offset; + const int8_t * s = scales_base + bi * (QK_K / 16) + s_offset; + + const float d = d_base[bi]; + +#if K_QUANTS_PER_ITERATION == 1 + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) + + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) + + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) + + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) + +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); + tmp += sum; +#else + float sum = 0; + for (int l = 0; l < 4; ++l) { + sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + } + tmp += sum; +#endif + + } + +#else + + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...7 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...3 + + const int step = tid * K_QUANTS_PER_ITERATION; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y = yy + i * QK_K + step; + const uint8_t * ql = ql_base + bi * (QK_K / 2) + step; + const uint8_t * qh = qh_base + bi * (QK_K / 4) + step; + const int8_t * s = scales_base + bi * (QK_K / 16); + + const float d = d_base[bi]; + + float sum = 0; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) + + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) + + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) + + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); + } + tmp += sum; + + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y, float *dst, const int ncols, const int nrows, @@ -1167,6 +1446,38 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y, }); } +static void dequantize_mul_mat_vec_q4_K_sycl_reorder(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q4_k_reorder(vx, y, dst, ncols, nrows, item_ct1); + }); +} + +static void dequantize_mul_mat_vec_q6_K_sycl_reorder(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q6_k_reorder(vx, y, dst, ncols, nrows, item_ct1); + }); +} + void ggml_sycl_op_dequantize_mul_mat_vec( ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, @@ -1235,8 +1546,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec( case GGML_TYPE_Q4_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - // reorder is currently not supported for dmmv - GGML_ABORT("Unimplemented dequantize case case for q4_k reorder"); + dequantize_mul_mat_vec_q4_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); } else { dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); } @@ -1245,7 +1555,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q6_K: - dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q6_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_F16: convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index ea79d2538c1..c02a41ad862 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3348,9 +3348,55 @@ static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { sycl::free(ptr, *stream); } -static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, +// RAII wrapper for temporary reorder buffers with optional host memory fallback. +// When device allocation fails and GGML_SYCL_HOST_MEM_FALLBACK is enabled, +// falls back to host memory so the reorder kernel can still run (over PCIe). +// Device access to host memory requires Linux kernel 6.8+ (Ubuntu 26.04+). +struct sycl_reorder_temp_buffer { + void * ptr = nullptr; + dpct::queue_ptr stream; + + sycl_reorder_temp_buffer(dpct::queue_ptr stream, size_t size) : stream(stream) { + ptr = sycl_ext_malloc_device(stream, size); +#ifdef GGML_SYCL_HOST_MEM_FALLBACK + if (!ptr) { + ptr = sycl::malloc_host(size, *stream); + if (ptr) { + host_fallback = true; + GGML_LOG_WARN("%s: device alloc of %zu bytes failed, using host memory fallback\n", __func__, size); + } + } +#endif + } + + ~sycl_reorder_temp_buffer() { + if (!ptr) { + return; + } + if (host_fallback) { + sycl::free(ptr, *stream); + } else { + sycl_ext_free(stream, ptr); + } + } + + explicit operator bool() const { return ptr != nullptr; } + + sycl_reorder_temp_buffer(const sycl_reorder_temp_buffer &) = delete; + sycl_reorder_temp_buffer & operator=(const sycl_reorder_temp_buffer &) = delete; + +private: + bool host_fallback = false; +}; + +static bool reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, dpct::queue_ptr stream) { - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3379,12 +3425,17 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } -static void reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, +static bool reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, dpct::queue_ptr stream) { - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3413,16 +3464,21 @@ static void reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nr if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } -static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { +static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q4_K) == 0); GGML_ASSERT(offset % sizeof(block_q4_K) == 0); const int nblocks = size / sizeof(block_q4_K); - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3451,16 +3507,21 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } -static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { +static bool reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q6_K) == 0); GGML_ASSERT(offset % sizeof(block_q6_K) == 0); const int nblocks = size / sizeof(block_q6_K); - uint8_t * tmp_buf = static_cast(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3499,10 +3560,10 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } -static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { +static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { uint8_t * data_device = (uint8_t *) src0->data; size_t ncols = src0->ne[0]; size_t nrows = src0->ne[1]; @@ -3510,20 +3571,16 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { switch (src0->type) { case GGML_TYPE_Q4_0: - reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); - break; + return reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); case GGML_TYPE_Q8_0: - reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream); - break; + return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream); case GGML_TYPE_Q4_K: - reorder_qw_q4_k(data_device, size, 0, stream); - break; + return reorder_qw_q4_k(data_device, size, 0, stream); case GGML_TYPE_Q6_K: - reorder_qw_q6_k(data_device, size, 0, stream); - break; + return reorder_qw_q6_k(data_device, size, 0, stream); default: GGML_ABORT("reorder_qw() called with unsupported type"); - break; + return false; } } @@ -3563,8 +3620,9 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * break; } - reorder_qw(src0, ctx->stream()); - extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering + if (reorder_qw(src0, ctx->stream())) { + extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering + } } From 092330b474ed34f80ed854ae7b64034a94a6f79a Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Thu, 16 Apr 2026 01:12:19 -0700 Subject: [PATCH 443/831] ggml-webgpu: compute pass batching and removing profiling overhead (llama/21873) * Update register tiling matmul to use f32 accumulation * fix profiling code * Fix register tiling matmul for chrome, i'm blaming dawn * Update batch tuning value for iOS * compile fix * Fix use of new load function * Move to a single query set for GPU profiling * Move to batching compute passes when not profiling * Refactor build_multi * remove iOS throttling now that we're batching compute passes --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 799 ++++++++++++--------------- 1 file changed, 348 insertions(+), 451 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index aa3fe06d5a9..01637e2ddab 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -73,8 +73,8 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim #endif // GGML_WEBGPU_CPU_PROFILE #ifdef GGML_WEBGPU_GPU_PROFILE -# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 32 -# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps +# define WEBGPU_MAX_PROFILE_QUERY_COUNT 4096u +# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES (WEBGPU_MAX_PROFILE_QUERY_COUNT * sizeof(uint64_t)) #endif /* Constants */ @@ -159,78 +159,20 @@ struct webgpu_param_arena { ~webgpu_param_arena() { this->cleanup(); } }; -#ifdef GGML_WEBGPU_GPU_PROFILE -struct webgpu_gpu_profile_bufs { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - wgpu::QuerySet query_set; -}; - -// Holds a pool of parameter buffers for WebGPU operations -struct webgpu_gpu_profile_buf_pool { - std::vector free; - - std::mutex mutex; - - std::condition_variable cv; - - void init(wgpu::Device device, - int num_bufs, - size_t buf_size, - wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage) { - for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf"); - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf"); - // Create a query set for 2 timestamps - wgpu::QuerySetDescriptor ts_query_set_desc = {}; - - ts_query_set_desc.type = wgpu::QueryType::Timestamp; - ts_query_set_desc.count = 2; - wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc); - - free.push_back({ host_buf, dev_buf, ts_query_set }); - } - } - - webgpu_gpu_profile_bufs alloc_bufs() { - std::unique_lock lock(mutex); - cv.wait(lock, [this] { return !free.empty(); }); - webgpu_gpu_profile_bufs bufs = free.back(); - free.pop_back(); - return bufs; - } - - void free_bufs(std::vector bufs) { - std::lock_guard lock(mutex); - free.insert(free.end(), bufs.begin(), bufs.end()); - cv.notify_all(); - } - - void cleanup() { - std::lock_guard lock(mutex); - for (auto & bufs : free) { - bufs.host_buf.Destroy(); - bufs.dev_buf.Destroy(); - bufs.query_set.Destroy(); - } - free.clear(); - } - - ~webgpu_gpu_profile_buf_pool() { this->cleanup(); } -}; -#endif - struct webgpu_encoded_op { uint32_t num_kernels = 0; #ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs timestamp_query_bufs; - std::string pipeline_name; + std::vector pipeline_names; #endif }; +struct webgpu_dispatch_desc { + webgpu_pipeline pipeline; + std::vector params; + std::vector bind_group_entries; + std::pair workgroups = { 1, 1 }; +}; + struct webgpu_capabilities { wgpu::Limits limits; bool supports_subgroup_matrix = false; @@ -256,7 +198,7 @@ struct webgpu_global_context_struct { webgpu_capabilities capabilities; // Shared buffer to move data from device to host wgpu::Buffer get_tensor_staging_buf; - // Global mutex for pipeline and staging buffer, will be refactored to exclude pipeline caches. + // Global mutex for get_tensor std::recursive_mutex mutex; wgpu::Buffer memset_params_buf; @@ -272,8 +214,6 @@ struct webgpu_global_context_struct { #ifdef GGML_WEBGPU_GPU_PROFILE // Profiling: per-shader GPU time in ms std::unordered_map shader_gpu_time_ms; - // Profiling: pool of timestamp query buffers (one per operation) - webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; #endif #ifdef GGML_WEBGPU_DEBUG @@ -312,11 +252,45 @@ struct webgpu_context_struct { std::unique_ptr shader_lib; - webgpu_param_arena param_arena; - wgpu::Buffer set_rows_dev_error_buf; - wgpu::Buffer set_rows_host_error_buf; + webgpu_param_arena param_arena; + wgpu::Buffer set_rows_dev_error_buf; + wgpu::Buffer set_rows_host_error_buf; + wgpu::CommandEncoder active_command_encoder; + wgpu::ComputePassEncoder active_compute_pass; size_t memset_bytes_per_thread; + +#ifdef GGML_WEBGPU_GPU_PROFILE + wgpu::Buffer profile_timestamp_dev_buf; + wgpu::Buffer profile_timestamp_host_buf; + wgpu::QuerySet profile_timestamp_query_set; + uint32_t profile_timestamp_query_count = 0; +#endif + + ~webgpu_context_struct() { +#ifdef GGML_WEBGPU_GPU_PROFILE + if (this->profile_timestamp_host_buf) { + this->profile_timestamp_host_buf.Destroy(); + this->profile_timestamp_host_buf = nullptr; + } + if (this->profile_timestamp_dev_buf) { + this->profile_timestamp_dev_buf.Destroy(); + this->profile_timestamp_dev_buf = nullptr; + } + if (this->profile_timestamp_query_set) { + this->profile_timestamp_query_set.Destroy(); + this->profile_timestamp_query_set = nullptr; + } +#endif + if (this->set_rows_host_error_buf) { + this->set_rows_host_error_buf.Destroy(); + this->set_rows_host_error_buf = nullptr; + } + if (this->set_rows_dev_error_buf) { + this->set_rows_dev_error_buf.Destroy(); + this->set_rows_dev_error_buf = nullptr; + } + } }; typedef std::shared_ptr webgpu_context; @@ -399,24 +373,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** WebGPU Actions */ -#ifdef GGML_WEBGPU_GPU_PROFILE -static void ggml_backend_webgpu_wait_profile_futures(webgpu_global_context & ctx, - std::vector & futures) { - if (futures.empty()) { - return; - } - - constexpr size_t max_futures_per_wait = 64; - - while (!futures.empty()) { - ctx->instance.WaitAny(std::min(max_futures_per_wait, futures.size()), futures.data(), UINT64_MAX); - futures.erase(std::remove_if(futures.begin(), futures.end(), - [](const wgpu::FutureWaitInfo & info) { return info.completed; }), - futures.end()); - } -} -#endif - template static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status, T callback_status, @@ -436,22 +392,8 @@ static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status, } } -#ifdef __EMSCRIPTEN__ -EM_JS(int, ggml_webgpu_is_ios_browser, (), { - const ua = navigator.userAgent; - return (ua.includes('iPhone') || ua.includes('iPad')) ? 1 : 0; -}); -#endif - // TODO: these next two functions may want tuning across different platforms and workloads, static uint32_t ggml_backend_webgpu_get_max_inflight_batches() { -#ifdef __EMSCRIPTEN__ - // iOS has very strict limits on the number of in-flight GPU commands, - // so we need to throttle to avoid failures. - if (ggml_webgpu_is_ios_browser()) { - return 1; - } -#endif return UINT32_MAX; } @@ -524,118 +466,77 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { } #endif -#ifdef GGML_WEBGPU_GPU_PROFILE -static void ggml_backend_webgpu_collect_profile_futures(webgpu_global_context & ctx, - const std::vector & commands, - std::vector & futures) { - for (const auto & command : commands) { - auto label = command.pipeline_name; - auto ts_bufs = command.timestamp_query_bufs; - - wgpu::Future f = ts_bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, - [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str()); - } else { - const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange(); - // WebGPU timestamps are in ns; convert to ms - double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6; - ctx->shader_gpu_time_ms[label] += elapsed_ms; - } - // We can't unmap in here due to WebGPU reentrancy limitations. - ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); - }); - futures.push_back({ f }); - } -} -#endif - -static webgpu_encoded_op ggml_backend_webgpu_build_multi( - webgpu_global_context & ctx, - webgpu_param_arena & param_arena, - wgpu::CommandEncoder & encoder, - const std::vector & pipelines, - const std::vector> & params_list, - const std::vector> & bind_group_entries_list, - const std::vector> & workgroups_list) { - GGML_ASSERT(pipelines.size() == params_list.size()); - GGML_ASSERT(pipelines.size() == bind_group_entries_list.size()); - GGML_ASSERT(pipelines.size() == workgroups_list.size()); - +static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & ctx, + const std::vector & dispatches) { webgpu_encoded_op result = {}; std::vector bind_groups; std::vector param_offsets; - result.num_kernels = pipelines.size(); + result.num_kernels = dispatches.size(); - for (size_t i = 0; i < pipelines.size(); i++) { - const size_t param_size = params_list[i].size() * sizeof(uint32_t); - const size_t param_offset = param_arena.alloc_slot(param_size); + for (size_t i = 0; i < dispatches.size(); i++) { + const webgpu_dispatch_desc & dispatch = dispatches[i]; + const size_t param_size = dispatch.params.size() * sizeof(uint32_t); + const size_t param_offset = ctx->param_arena.alloc_slot(param_size); - std::vector entries = bind_group_entries_list[i]; + std::vector entries = dispatch.bind_group_entries; uint32_t params_binding_num = entries.size(); entries.push_back({ .binding = params_binding_num, - .buffer = param_arena.buffer, + .buffer = ctx->param_arena.buffer, .offset = param_offset, - .size = param_arena.slot_size }); + .size = ctx->param_arena.slot_size }); wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = pipelines[i].pipeline.GetBindGroupLayout(0); + bind_group_desc.layout = dispatch.pipeline.pipeline.GetBindGroupLayout(0); bind_group_desc.entryCount = entries.size(); bind_group_desc.entries = entries.data(); - bind_group_desc.label = pipelines[i].name.c_str(); - bind_groups.push_back(ctx->device.CreateBindGroup(&bind_group_desc)); + bind_group_desc.label = dispatch.pipeline.name.c_str(); + bind_groups.push_back(ctx->global_ctx->device.CreateBindGroup(&bind_group_desc)); param_offsets.push_back(param_offset); } for (size_t i = 0; i < param_offsets.size(); i++) { - ctx->queue.WriteBuffer(param_arena.buffer, param_offsets[i], params_list[i].data(), - params_list[i].size() * sizeof(uint32_t)); + ctx->global_ctx->queue.WriteBuffer(ctx->param_arena.buffer, param_offsets[i], dispatches[i].params.data(), + dispatches[i].params.size() * sizeof(uint32_t)); } + #ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); - if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { - ts_bufs.host_buf.Unmap(); + for (size_t i = 0; i < dispatches.size(); i++) { + GGML_ASSERT(ctx->profile_timestamp_query_count + 2 <= WEBGPU_MAX_PROFILE_QUERY_COUNT); + const uint32_t query_begin = ctx->profile_timestamp_query_count++; + const uint32_t query_end = ctx->profile_timestamp_query_count++; + wgpu::PassTimestampWrites ts_writes = { .querySet = ctx->profile_timestamp_query_set, + .beginningOfPassWriteIndex = query_begin, + .endOfPassWriteIndex = query_end }; + wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; + wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); + + pass.SetPipeline(dispatches[i].pipeline.pipeline); + pass.SetBindGroup(0, bind_groups[i]); + pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); + pass.End(); + result.pipeline_names.push_back(dispatches[i].pipeline.name); } - - wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, - .beginningOfPassWriteIndex = 0, - .endOfPassWriteIndex = 1 }; - wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); #else - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); -#endif - for (size_t i = 0; i < pipelines.size(); i++) { - pass.SetPipeline(pipelines[i].pipeline); - pass.SetBindGroup(0, bind_groups[i]); - pass.DispatchWorkgroups(workgroups_list[i].first, workgroups_list[i].second, 1); + for (size_t i = 0; i < dispatches.size(); i++) { + ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline); + ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]); + ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); } - pass.End(); - -#ifdef GGML_WEBGPU_GPU_PROFILE - encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); - encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); - result.timestamp_query_bufs = ts_bufs; - result.pipeline_name = pipelines.front().name; #endif + return result; } -static webgpu_encoded_op ggml_backend_webgpu_build(webgpu_global_context & ctx, - webgpu_param_arena & param_arena, - wgpu::CommandEncoder & encoder, +static webgpu_encoded_op ggml_backend_webgpu_build(webgpu_context & ctx, webgpu_pipeline & pipeline, std::vector params, std::vector bind_group_entries, uint32_t wg_x, uint32_t wg_y = 1) { - return ggml_backend_webgpu_build_multi(ctx, param_arena, encoder, - { - pipeline - }, - { std::move(params) }, { std::move(bind_group_entries) }, - { { wg_x, wg_y } }); + return ggml_backend_webgpu_build_multi( + ctx, { + { pipeline, std::move(params), std::move(bind_group_entries), { wg_x, wg_y } }, + }); } static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, @@ -784,10 +685,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 return flags; } -static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, @@ -825,14 +723,13 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { const bool inplace = ggml_webgpu_tensor_equal(src0, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -891,13 +788,10 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; @@ -949,14 +843,13 @@ static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1011,14 +904,13 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1068,18 +960,17 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2]; - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * src3, - ggml_tensor * src4, - ggml_tensor * src5, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1154,14 +1045,13 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, h, n_seqs); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, h, n_seqs); } -static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { +static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { // For set rows specifically, we need to check if src and idx are empty // tensors. if (ggml_is_empty(src) || ggml_is_empty(idx)) { @@ -1224,7 +1114,7 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, 1); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1); } // Workgroup size is a common constant @@ -1235,11 +1125,10 @@ static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_si return constants; } -static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -1291,14 +1180,13 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows; uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { // Determine if this is a mat-vec operation bool is_vec = (dst->ne[1] == 1); @@ -1437,15 +1325,14 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -1457,10 +1344,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, // Get or create pipeline webgpu_pipeline gather_pipeline, main_pipeline; - std::vector pipelines; - std::vector> params_list; - std::vector> entries_list; - std::vector> workgroups_list; + std::vector dispatches; gather_pipeline = ctx->shader_lib->get_mul_mat_id_gather_pipeline(shader_lib_ctx); main_pipeline = ctx->shader_lib->get_mul_mat_id_pipeline(shader_lib_ctx); @@ -1520,10 +1404,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim); const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x); - pipelines.push_back(gather_pipeline); - params_list.push_back(std::move(gather_params)); - entries_list.push_back(std::move(gather_entries)); - workgroups_list.push_back({ gather_wg_x, gather_wg_y }); + dispatches.push_back({ + gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, gather_wg_y } + }); // params for mul_mat_id.wgsl std::vector main_params = { @@ -1588,24 +1471,21 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); - pipelines.push_back(main_pipeline); - params_list.push_back(std::move(main_params)); - entries_list.push_back(std::move(main_entries)); - workgroups_list.push_back({ wg_x, wg_y }); + dispatches.push_back({ + main_pipeline, std::move(main_params), std::move(main_entries), { wg_x, wg_y } + }); - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, - entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } #ifndef __EMSCRIPTEN__ -static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { float scale = *(float *) dst->op_params; float max_bias; memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); @@ -1897,40 +1777,33 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, const uint64_t split_wg_total = (uint64_t) wg_x * nwg; GGML_ASSERT(split_wg_total <= UINT32_MAX); - std::vector pipelines; - std::vector> params_list; - std::vector> entries_list; - std::vector> workgroups_list; + std::vector dispatches; if (use_blk) { - pipelines.push_back(blk_pipeline); - params_list.push_back(std::move(blk_params)); - entries_list.push_back(std::move(blk_entries)); - workgroups_list.push_back({ blk_nblk0, blk_nblk1 * blk_batch_count }); + dispatches.push_back({ + blk_pipeline, + std::move(blk_params), + std::move(blk_entries), + { blk_nblk0, blk_nblk1 * blk_batch_count } + }); } - pipelines.push_back(pipeline); - params_list.push_back(std::move(split_params)); - entries_list.push_back(std::move(split_entries)); - workgroups_list.push_back({ (uint32_t) split_wg_total, 1u }); + dispatches.push_back({ + pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + }); if (use_vec_reduce) { - pipelines.push_back(reduce_pipeline); - params_list.push_back(std::move(reduce_params)); - entries_list.push_back(std::move(reduce_entries)); - workgroups_list.push_back({ (uint32_t) nrows, 1u }); + dispatches.push_back({ + reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } + }); } - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, - entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } #endif // __EMSCRIPTEN__ -static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); @@ -2005,14 +1878,13 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2108,14 +1980,13 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); uint32_t dim = (uint32_t) dst->op_params[0]; @@ -2165,13 +2036,10 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector params = { ne, @@ -2210,13 +2078,10 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { @@ -2256,16 +2121,14 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, }; webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, - ggml_nrows(src)); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src)); } -static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2362,14 +2225,13 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2428,13 +2290,10 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2482,15 +2341,14 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src0, .src1 = src1, @@ -2566,14 +2424,10 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, - ggml_nrows(dst)); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); } -static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; @@ -2595,13 +2449,10 @@ static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nelements(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_top_k = dst->op == GGML_OP_TOP_K; ggml_webgpu_shader_lib_context shader_lib_ctx = { @@ -2659,10 +2510,7 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1]; const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2]; - std::vector pipelines; - std::vector> params_list; - std::vector> entries_list; - std::vector> workgroups_list; + std::vector dispatches; const uint32_t init_offset = start_in_tmp ? offset_tmp : offset_dst; const size_t init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst); @@ -2686,14 +2534,12 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size } }; - pipelines.push_back(argsort_pipeline); - params_list.push_back(std::move(init_params)); - entries_list.push_back(std::move(init_entries)); - workgroups_list.push_back({ wg_x_init, wg_y_init }); + dispatches.push_back({ + argsort_pipeline, std::move(init_params), std::move(init_entries), { wg_x_init, wg_y_init } + }); if (merge_passes == 0) { - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, - entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } bool in_is_tmp = start_in_tmp; @@ -2745,23 +2591,18 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, const uint32_t total_wg_merge = nm * nrows; const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg); const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge); - workgroups_list.push_back({ wg_x_merge, wg_y_merge }); - pipelines.push_back(argsort_merge_pipeline); - params_list.push_back(std::move(merge_params)); - entries_list.push_back(std::move(merge_entries)); + dispatches.push_back({ + argsort_merge_pipeline, std::move(merge_params), std::move(merge_entries), { wg_x_merge, wg_y_merge } + }); len <<= 1; in_is_tmp = !in_is_tmp; } - return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_arena, encoder, pipelines, params_list, - entries_list, workgroups_list); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } -static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; @@ -2786,13 +2627,10 @@ static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * src, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool total_sum = dst->op == GGML_OP_SUM; std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -2821,13 +2659,11 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_arena, encoder, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } // Returns the encoded command, or std::nullopt if the operation is a no-op -static std::optional ggml_webgpu_encode_node(webgpu_context ctx, - wgpu::CommandEncoder & encoder, - ggml_tensor * node) { +static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { return std::nullopt; } @@ -2850,20 +2686,20 @@ static std::optional ggml_webgpu_encode_node(webgpu_context return std::nullopt; case GGML_OP_CPY: case GGML_OP_CONT: - return ggml_webgpu_cpy(ctx, encoder, src0, node); + return ggml_webgpu_cpy(ctx, src0, node); case GGML_OP_SET: - return ggml_webgpu_set(ctx, encoder, src0, src1, node); + return ggml_webgpu_set(ctx, src0, src1, node); case GGML_OP_SET_ROWS: - return ggml_webgpu_set_rows(ctx, encoder, src0, src1, node); + return ggml_webgpu_set_rows(ctx, src0, src1, node); case GGML_OP_GET_ROWS: - return ggml_webgpu_get_rows(ctx, encoder, src0, src1, node); + return ggml_webgpu_get_rows(ctx, src0, src1, node); case GGML_OP_MUL_MAT: - return ggml_webgpu_mul_mat(ctx, encoder, src0, src1, node); + return ggml_webgpu_mul_mat(ctx, src0, src1, node); case GGML_OP_MUL_MAT_ID: - return ggml_webgpu_mul_mat_id(ctx, encoder, src0, src1, src2, node); + return ggml_webgpu_mul_mat_id(ctx, src0, src1, src2, node); case GGML_OP_FLASH_ATTN_EXT: #ifndef __EMSCRIPTEN__ - return ggml_webgpu_flash_attn(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node); + return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); #else return std::nullopt; #endif @@ -2871,22 +2707,22 @@ static std::optional ggml_webgpu_encode_node(webgpu_context case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - return ggml_webgpu_binary_op(ctx, encoder, src0, src1, node); + return ggml_webgpu_binary_op(ctx, src0, src1, node); case GGML_OP_CONCAT: - return ggml_webgpu_concat(ctx, encoder, src0, src1, node); + return ggml_webgpu_concat(ctx, src0, src1, node); case GGML_OP_REPEAT: - return ggml_webgpu_repeat(ctx, encoder, src0, node); + return ggml_webgpu_repeat(ctx, src0, node); case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: - return ggml_webgpu_row_norm(ctx, encoder, src0, node); + return ggml_webgpu_row_norm(ctx, src0, node); case GGML_OP_ROPE: - return ggml_webgpu_rope(ctx, encoder, src0, src1, src2, node); + return ggml_webgpu_rope(ctx, src0, src1, src2, node); case GGML_OP_GLU: - return ggml_webgpu_glu(ctx, encoder, src0, src1, node); + return ggml_webgpu_glu(ctx, src0, src1, node); case GGML_OP_SCALE: - return ggml_webgpu_scale(ctx, encoder, src0, node); + return ggml_webgpu_scale(ctx, src0, node); case GGML_OP_SOFT_MAX: - return ggml_webgpu_soft_max(ctx, encoder, src0, src1, src2, node); + return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); case GGML_OP_UNARY: case GGML_OP_CLAMP: case GGML_OP_FILL: @@ -2897,32 +2733,80 @@ static std::optional ggml_webgpu_encode_node(webgpu_context case GGML_OP_COS: case GGML_OP_DIAG: case GGML_OP_TRI: - return ggml_webgpu_unary_op(ctx, encoder, src0, node); + return ggml_webgpu_unary_op(ctx, src0, node); case GGML_OP_SOLVE_TRI: - return ggml_webgpu_solve_tri(ctx, encoder, src0, src1, node); + return ggml_webgpu_solve_tri(ctx, src0, src1, node); case GGML_OP_SSM_CONV: - return ggml_webgpu_ssm_conv(ctx, encoder, src0, src1, node); + return ggml_webgpu_ssm_conv(ctx, src0, src1, node); case GGML_OP_GATED_DELTA_NET: - return ggml_webgpu_gated_delta_net(ctx, encoder, src0, src1, src2, node->src[3], node->src[4], node->src[5], - node); + return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node); case GGML_OP_PAD: - return ggml_webgpu_pad(ctx, encoder, src0, node); + return ggml_webgpu_pad(ctx, src0, node); case GGML_OP_ARGMAX: - return ggml_webgpu_argmax(ctx, encoder, src0, node); + return ggml_webgpu_argmax(ctx, src0, node); case GGML_OP_ARGSORT: case GGML_OP_TOP_K: // we reuse the same argsort implementation for top_k - return ggml_webgpu_argsort(ctx, encoder, src0, node); + return ggml_webgpu_argsort(ctx, src0, node); case GGML_OP_CUMSUM: - return ggml_webgpu_cumsum(ctx, encoder, src0, node); + return ggml_webgpu_cumsum(ctx, src0, node); case GGML_OP_SUM: case GGML_OP_SUM_ROWS: - return ggml_webgpu_sum_rows(ctx, encoder, src0, node); + return ggml_webgpu_sum_rows(ctx, src0, node); default: return std::nullopt; } } +#ifdef GGML_WEBGPU_GPU_PROFILE +static void ggml_backend_webgpu_collect_profile_results(webgpu_context & ctx, + const std::vector & pipeline_names, + uint32_t & num_inflight_batches) { + if (pipeline_names.empty()) { + return; + } + + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); + encoder.ResolveQuerySet(ctx->profile_timestamp_query_set, 0, ctx->profile_timestamp_query_count, + ctx->profile_timestamp_dev_buf, 0); + encoder.CopyBufferToBuffer(ctx->profile_timestamp_dev_buf, 0, ctx->profile_timestamp_host_buf, 0, + ctx->profile_timestamp_query_count * sizeof(uint64_t)); + + wgpu::CommandBuffer profile_commands = encoder.Finish(); + ggml_backend_webgpu_submit_commands(ctx, profile_commands, num_inflight_batches); + + const size_t mapped_size = ctx->profile_timestamp_query_count * sizeof(uint64_t); + GGML_ASSERT(ctx->profile_timestamp_query_count == 2 * pipeline_names.size()); + + ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->profile_timestamp_host_buf, wgpu::MapMode::Read, 0, + mapped_size); + const uint64_t * ts_data = (const uint64_t *) ctx->profile_timestamp_host_buf.GetConstMappedRange(0, mapped_size); + + for (size_t i = 0; i < pipeline_names.size(); ++i) { + // WebGPU timestamps are in ns; convert to ms. + const double elapsed_ms = double(ts_data[2 * i + 1] - ts_data[2 * i]) * 1e-6; + ctx->global_ctx->shader_gpu_time_ms[pipeline_names[i]] += elapsed_ms; + } + + ctx->profile_timestamp_host_buf.Unmap(); +} +#endif + +static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t & num_inflight_batches) { + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, + ctx->set_rows_host_error_buf.GetSize()); + wgpu::CommandBuffer commands = encoder.Finish(); + ggml_backend_webgpu_submit_commands(ctx, commands, num_inflight_batches); + ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, + ctx->set_rows_host_error_buf.GetSize()); + const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); + if (*error_data) { + GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); + } + ctx->set_rows_host_error_buf.Unmap(); +} + static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); @@ -2932,69 +2816,77 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); std::vector commands; + + uint32_t num_batched_kernels = 0; + uint32_t num_inflight_batches = 0; + bool contains_set_rows = false; + bool batch_compute_passes = true; + #ifdef GGML_WEBGPU_GPU_PROFILE - std::vector profile_futures; + ctx->profile_timestamp_query_count = 0; + batch_compute_passes = false; + std::vector profile_pipeline_names; #endif - uint32_t num_batched_kernels = 0; - uint32_t num_inflight_batches = 0; - bool contains_set_rows = false; - wgpu::CommandEncoder batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + + ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + if (batch_compute_passes) { + ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); + } for (int i = 0; i < cgraph->n_nodes; i++) { if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; } - if (auto cmd = ggml_webgpu_encode_node(ctx, batch_encoder, cgraph->nodes[i])) { + if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { commands.push_back(*cmd); num_batched_kernels += cmd.value().num_kernels; +#ifdef GGML_WEBGPU_GPU_PROFILE + profile_pipeline_names.insert(profile_pipeline_names.end(), cmd->pipeline_names.begin(), + cmd->pipeline_names.end()); +#endif } if (num_batched_kernels >= ctx->global_ctx->command_submit_batch_size) { + if (ctx->active_compute_pass) { + ctx->active_compute_pass.End(); + } num_batched_kernels = 0; - wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); + wgpu::CommandBuffer batch_commands = ctx->active_command_encoder.Finish(); ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches); -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); -#endif + + // reset state for next batch + ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + if (batch_compute_passes) { + ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); + } ctx->param_arena.reset(); commands.clear(); - batch_encoder = ctx->global_ctx->device.CreateCommandEncoder(); } } - if (!commands.empty()) { - wgpu::CommandBuffer batch_commands = batch_encoder.Finish(); + + if (ctx->active_compute_pass) { + ctx->active_compute_pass.End(); + ctx->active_compute_pass = nullptr; + } + + if (num_batched_kernels > 0) { + wgpu::CommandBuffer batch_commands = ctx->active_command_encoder.Finish(); ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches); -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_collect_profile_futures(ctx->global_ctx, commands, profile_futures); -#endif ctx->param_arena.reset(); commands.clear(); } + ctx->active_command_encoder = nullptr; + +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_collect_profile_results(ctx, profile_pipeline_names, num_inflight_batches); +#endif - // If there are SET_ROWS operations in this graph, copy the error buffers to the host for checking. if (contains_set_rows) { - wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, - ctx->set_rows_host_error_buf.GetSize()); - wgpu::CommandBuffer set_rows_commands = encoder.Finish(); - ggml_backend_webgpu_submit_commands(ctx, set_rows_commands, num_inflight_batches); + ggml_backend_webgpu_check_set_rows(ctx, num_inflight_batches); } ggml_backend_webgpu_wait_queue(ctx->global_ctx); - if (contains_set_rows) { - ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, - ctx->set_rows_host_error_buf.GetSize()); - const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); - if (*error_data) { - GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); - } - ctx->set_rows_host_error_buf.Unmap(); - } - -#ifdef GGML_WEBGPU_GPU_PROFILE - ggml_backend_webgpu_wait_profile_futures(ctx->global_ctx, profile_futures); -#endif WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -3535,14 +3427,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { "memset_params_buf"); ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue(); -#ifdef GGML_WEBGPU_GPU_PROFILE - // Initialize buffer pool for timestamp queries, used for profiling - ctx->webgpu_global_ctx->timestamp_query_buf_pool.init( - ctx->webgpu_global_ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, - wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, - wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); -#endif - GGML_LOG_INFO( "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " "device_desc: %s\n", @@ -3567,6 +3451,19 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_webgpu_create_buffer( + webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_dev_buf, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, "profile_timestamp_dev_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_host_buf, + WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "profile_timestamp_host_buf"); + wgpu::QuerySetDescriptor query_set_desc = {}; + query_set_desc.type = wgpu::QueryType::Timestamp; + query_set_desc.count = WEBGPU_MAX_PROFILE_QUERY_COUNT; + webgpu_ctx->profile_timestamp_query_set = webgpu_ctx->global_ctx->device.CreateQuerySet(&query_set_desc); +#endif + #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf, From 07c181b57f6d59027ea6fe3931967993e4f870a6 Mon Sep 17 00:00:00 2001 From: rehan-10xengineer Date: Thu, 16 Apr 2026 13:14:26 +0500 Subject: [PATCH 444/831] ggml : implemented simd_gemm kernel for riscv vector extension (llama/20627) Co-authored-by: Rehan Qasim --- ggml/src/ggml-cpu/simd-gemm.h | 90 +++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/ggml/src/ggml-cpu/simd-gemm.h b/ggml/src/ggml-cpu/simd-gemm.h index 78d663e593e..4119d04f895 100644 --- a/ggml/src/ggml-cpu/simd-gemm.h +++ b/ggml/src/ggml-cpu/simd-gemm.h @@ -109,6 +109,96 @@ static void simd_gemm( C += N; } } +#elif defined(GGML_SIMD) && defined(__riscv_v_intrinsic) +// RM accumulators + 1 B vector = RM + 1 <= 8 => RM <= 7 +// Microkernel: C[RM x vl] += A[RM x K] * B[K x N] +template +static inline void rvv_simd_gemm_ukernel( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int K, int N, size_t vl) +{ + static_assert(RM >= 1 && RM <= 7, "RM must be 1..7 for LMUL=4"); + + vfloat32m4_t acc_0 = __riscv_vle32_v_f32m4(C + 0 * N, vl); + vfloat32m4_t acc_1, acc_2, acc_3, acc_4, acc_5, acc_6; + if constexpr (RM > 1) acc_1 = __riscv_vle32_v_f32m4(C + 1 * N, vl); + if constexpr (RM > 2) acc_2 = __riscv_vle32_v_f32m4(C + 2 * N, vl); + if constexpr (RM > 3) acc_3 = __riscv_vle32_v_f32m4(C + 3 * N, vl); + if constexpr (RM > 4) acc_4 = __riscv_vle32_v_f32m4(C + 4 * N, vl); + if constexpr (RM > 5) acc_5 = __riscv_vle32_v_f32m4(C + 5 * N, vl); + if constexpr (RM > 6) acc_6 = __riscv_vle32_v_f32m4(C + 6 * N, vl); + + for (int kk = 0; kk < K; kk++) { + vfloat32m4_t b_0 = __riscv_vle32_v_f32m4(B + kk * N, vl); + + acc_0 = __riscv_vfmacc_vf_f32m4(acc_0, A[0 * K + kk], b_0, vl); + if constexpr (RM > 1) acc_1 = __riscv_vfmacc_vf_f32m4(acc_1, A[1 * K + kk], b_0, vl); + if constexpr (RM > 2) acc_2 = __riscv_vfmacc_vf_f32m4(acc_2, A[2 * K + kk], b_0, vl); + if constexpr (RM > 3) acc_3 = __riscv_vfmacc_vf_f32m4(acc_3, A[3 * K + kk], b_0, vl); + if constexpr (RM > 4) acc_4 = __riscv_vfmacc_vf_f32m4(acc_4, A[4 * K + kk], b_0, vl); + if constexpr (RM > 5) acc_5 = __riscv_vfmacc_vf_f32m4(acc_5, A[5 * K + kk], b_0, vl); + if constexpr (RM > 6) acc_6 = __riscv_vfmacc_vf_f32m4(acc_6, A[6 * K + kk], b_0, vl); + } + + __riscv_vse32_v_f32m4(C + 0 * N, acc_0, vl); + if constexpr (RM > 1) __riscv_vse32_v_f32m4(C + 1 * N, acc_1, vl); + if constexpr (RM > 2) __riscv_vse32_v_f32m4(C + 2 * N, acc_2, vl); + if constexpr (RM > 3) __riscv_vse32_v_f32m4(C + 3 * N, acc_3, vl); + if constexpr (RM > 4) __riscv_vse32_v_f32m4(C + 4 * N, acc_4, vl); + if constexpr (RM > 5) __riscv_vse32_v_f32m4(C + 5 * N, acc_5, vl); + if constexpr (RM > 6) __riscv_vse32_v_f32m4(C + 6 * N, acc_6, vl); +} + +template +static inline void rvv_simd_gemm_dispatch_tail( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int K, int N, int KN, int remaining_rows) +{ + if constexpr (RM > 0) { + if (remaining_rows == RM) { + int64_t jj = 0; + for (; jj + KN <= N; jj += KN) { + rvv_simd_gemm_ukernel(C + jj, A, B + jj, K, N, KN); + } + if (jj < N) { + rvv_simd_gemm_ukernel(C + jj, A, B + jj, K, N, N - jj); + } + } else { + rvv_simd_gemm_dispatch_tail(C, A, B, K, N, KN, remaining_rows); + } + } +} + +static constexpr int GEMM_RM = 7; + +// C[M x N] += A[M x K] * B[K x N] +static void simd_gemm( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int M, int K, int N) +{ + const int KN = (int)__riscv_vlenb(); + int64_t ii = 0; + for (; ii + GEMM_RM <= M; ii += GEMM_RM) { + int64_t jj = 0; + for (; jj + KN <= N; jj += KN) { + rvv_simd_gemm_ukernel(C + jj, A, B + jj, K, N, KN); + } + if (jj < N) { + rvv_simd_gemm_ukernel(C + jj, A, B + jj, K, N, N - jj); + } + A += GEMM_RM * K; + C += GEMM_RM * N; + } + + int remaining_rows = M - ii; + rvv_simd_gemm_dispatch_tail(C, A, B, K, N, KN, remaining_rows); +} #if defined(__GNUC__) && !defined(__clang__) #pragma GCC diagnostic pop From 94d6d0b743206b10a3074ea805c385b18fcd1498 Mon Sep 17 00:00:00 2001 From: rehan-10xengineer Date: Thu, 16 Apr 2026 13:15:15 +0500 Subject: [PATCH 445/831] ggml-cpu: add 128-bit RVV implementation for Quantization Vector Dot (llama/20633) * ggml-cpu: add 128-bit impls for i-quants, ternary quants * ggml-cpu: add 128-bit impls for iq2_xs, iq3_s, iq3_xxs, tq2_0 Co-authored-by: Rehan Qasim * ggml-cpu: refactor; add rvv checks --------- Co-authored-by: taimur-10x Co-authored-by: Rehan Qasim --- ggml/src/ggml-cpu/arch/riscv/quants.c | 972 ++++++++++++++++++++++++-- 1 file changed, 902 insertions(+), 70 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index d7e9ba46348..d3278d6489f 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -15,6 +15,12 @@ #include // for qsort #include // for GGML_ASSERT +#ifdef _MSC_VER +#define NOINLINE __declspec(noinline) +#else +#define NOINLINE __attribute__((__noinline__)) +#endif + #define GROUP_MAX_EPS 1e-15f #define GROUP_MAX_EPS_IQ3_XXS 1e-8f #define GROUP_MAX_EPS_IQ2_S 1e-8f @@ -117,7 +123,7 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in assert(k % QK_K == 0); size_t nb = k / QK_K; -#if defined(__riscv_v_intrinsic) +#if defined __riscv_v_intrinsic block_q8_K * y_blocks = (block_q8_K *)y; const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8(); @@ -2053,7 +2059,119 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16m1_t qh = __riscv_vle16_v_u16m1(x[i].qh, 8); + + // Calculate ls. + vuint16m1_t temp = __riscv_vsrl_vx_u16m1(qh, 12, 8); + temp = __riscv_vand_vx_u16m1(temp, 7, 8); + vint32m2_t ls = __riscv_vreinterpret_v_u32m2_i32m2(__riscv_vwmulu_vx_u32m2(temp, 2, 8)); + ls = __riscv_vadd_vx_i32m2(ls, 1, 8); + + // Calculate delta. + vbool16_t mask = __riscv_vmseq_vx_u16m1_b16(__riscv_vand_vx_u16m1(qh, 0x8000, 8), 0, 8); + vint32m2_t delta_neg = __riscv_vmv_v_x_i32m2(-1, 8); + vint32m2_t delta_pos = __riscv_vmv_v_x_i32m2(1, 8); + vint32m2_t delta = __riscv_vmerge_vvm_i32m2(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8m2_t qs = __riscv_vle8_v_u8m2(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m4_t qh_shift = __riscv_vreinterpret_v_u64m4_u16m4(__riscv_vmv_v_x_u64m4(shift, 8)); + vuint16m4_t qh_gather_index = __riscv_vreinterpret_v_i16m4_u16m4( + __riscv_vdiv_vx_i16m4(__riscv_vreinterpret_v_u16m4_i16m4(__riscv_vid_v_u16m4(32)), 4, 32)); + vuint16m4_t qh_ext = __riscv_vlmul_ext_v_u16m2_u16m4(__riscv_vlmul_ext_v_u16m1_u16m2(qh)); + vuint16m4_t qh_index = __riscv_vrgather_vv_u16m4(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m4(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m4(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m4(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m4(qh_index, __riscv_vzext_vf2_u16m4(qs, 32), 32); + vuint16m4_t index = __riscv_vsll_vx_u16m4(qh_index, 3, 32); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-2 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 0); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[0], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 3-4 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 1); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[64], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 5-6 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 2); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[128], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 7-8 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 3); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[192], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + vint32m2_t lsums = __riscv_vle32_v_i32m2(&lsums_s[0], 8); + + // Calculate the bsums. + vint16m2_t bsums_0 = __riscv_vle16_v_i16m2(y[i].bsums, 16); + const vuint32m2_t bsums_i32 = __riscv_vreinterpret_v_u16m2_u32m2(__riscv_vreinterpret_v_i16m2_u16m2(bsums_0)); + const vint16m1_t bsums_i32_0 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(bsums_i32, 0, 8)); + const vint16m1_t bsums_i32_1 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(bsums_i32, 16, 8)); + const vint32m2_t bsums = __riscv_vwadd_vv_i32m2(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32m2_t sumi_v = __riscv_vmul_vv_i32m2(ls, lsums, 8); + vint32m2_t sumi1_v = __riscv_vmul_vv_i32m2(__riscv_vmul_vv_i32m2(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2153,6 +2271,9 @@ static void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq1_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2166,7 +2287,174 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m4_t acc1 = __riscv_vmv_v_x_i32m4(0, 16); + vint32m4_t acc2 = __riscv_vmv_v_x_i32m4(0, 16); + + // We process 8 16-element sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/128; ib++) { + // Load qh for 8 sub-blocks. + const vuint8mf2_t qh_8 = __riscv_vle8_v_u8mf2(qh, 8); + const vuint16m1_t qh_16_lo = __riscv_vzext_vf2_u16m1(qh_8, 8); + const vuint16m1_t qh_16_hi = __riscv_vsll_vx_u16m1(qh_16_lo, 8, 8); + const vuint16m2_t qhb = __riscv_vzext_vf2_u16m2( + __riscv_vreinterpret_v_u16m1_u8m1(__riscv_vor_vv_u16m1(qh_16_lo, qh_16_hi, 8)), 16); + qh += 8; + + // Prepare grid indices. + const vuint16m2_t qsb = __riscv_vzext_vf2_u16m2(__riscv_vle8_v_u8m1(&qs[0], 16), 16); + const vuint16m2_t shift = __riscv_vreinterpret_v_u32m2_u16m2(__riscv_vmv_v_x_u32m2(0x00040008, 8)); + vuint16m2_t index = __riscv_vor_vv_u16m2(qsb, __riscv_vand_vx_u16m2(__riscv_vsll_vv_u16m2(qhb, shift, 16), 0x700, 16), 16); + index = __riscv_vsll_vx_u16m2(index, 3, 16); + qs += 16; + + // Prepare the deltas. + const vbool8_t mask = __riscv_vmsgtu_vx_u16m2_b8( + __riscv_vand_vv_u16m2(qhb, __riscv_vreinterpret_v_u32m2_u16m2(__riscv_vmv_v_x_u32m2(0x00800008, 8)), 16), 0, 16); + const vint64m8_t delta_pos = __riscv_vmv_v_x_i64m8(0x0101010101010101, 16); + const vint8m8_t delta = __riscv_vreinterpret_v_i64m8_i8m8( + __riscv_vmerge_vxm_i64m8(delta_pos, 0xffffffffffffffff, mask, 16)); + + // Sub-blocks 0-3 + { + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, __riscv_vget_v_u16m2_u16m1(index, 0), 8))); + + // Calculate the lsums. + // + // Sub-block 0, 1 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 0), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 0), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-block 2, 3 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 1), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 1), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 9) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + sc += 1; + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 4-7 + { + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, __riscv_vget_v_u16m2_u16m1(index, 1), 8))); + + // Calculate the lsums. + // + // Sub-block 4, 5 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 0), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 2), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-block 6, 7 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 1), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 3), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 9) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + sc += 1; + } + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(acc1, one, 16)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(acc2, one, 16)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2193,9 +2481,10 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 16); vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 16); - // We process 4 sub-blocks together. + // We process 8 16-element sub-blocks together. + #pragma GCC unroll 1 for (int ib = 0; ib < QK_K/128; ib++) { - // Load qh for 4 sub-blocks. + // Load qh for 8 sub-blocks. const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 8); const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 8); const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 8); @@ -2203,6 +2492,8 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 8)), 16); qh += 8; + __asm__ __volatile__("" ::: "memory"); + // Prepare grid indices. const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 16), 16); const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 8)); @@ -2210,6 +2501,8 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t index = __riscv_vsll_vx_u16m1(index, 3, 16); qs += 16; + __asm__ __volatile__("" ::: "memory"); + // Load the grid. const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( __riscv_vluxei16_v_u64m4(iq1s_grid, index, 16))); @@ -2218,9 +2511,8 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16( __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 8)), 16), 0, 16); const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 16); - const vint64m4_t delta_neg = __riscv_vmv_v_x_i64m4(0xffffffffffffffff, 16); const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4( - __riscv_vmerge_vvm_i64m4(delta_pos, delta_neg, mask, 16)); + __riscv_vmerge_vxm_i64m4(delta_pos, 0xffffffffffffffff, mask, 16)); // Load q8 for sub-blocks. const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); @@ -2261,6 +2553,8 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16); acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16); acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16); + + __asm__ __volatile__("" ::: "memory"); } // Reduce and accumulate in `sumf`. @@ -2277,6 +2571,9 @@ static void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq1_m_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2300,8 +2597,7 @@ static const uint8_t sign_bit_masks_arr[64] = { 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128 }; - -static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -2392,7 +2688,7 @@ static void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t *s = 0.125f * sumf; } -static void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -2513,7 +2809,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined(__riscv_v_intrinsic) +#if defined __riscv_v_intrinsic static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, @@ -2549,7 +2845,84 @@ static const int8_t keven_signs_q2xs[1024] = { 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, }; -static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xs_grid; + + float sumf = 0.0f; +#pragma GCC unroll 1 + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT qs = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + + int32_t sum_int = 0; + + // Loop over 4 subblocks of 64 elements + for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) { + + // Load indices. + vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 8); + qs += 8; + + // Prepare offsets + vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 8), 3, 8); + vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 8), 3, 8); + + // load values and signs from the lookup tables + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 8); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 8); + vint8m4_t q2u = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t q2s = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + vint8m4_t q2_final = __riscv_vmul_vv_i8m4(q2u, q2s, 64); + asm volatile("" ::: "memory"); + vint8m4_t q8v = __riscv_vle8_v_i8m4(q8, 64); + q8 += 64; + + vint16m8_t prod = __riscv_vwmul_vv_i16m8(q2_final, q8v, 64); + asm volatile("" ::: "memory"); + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 0), zero_vec, 16)); + + int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 1), zero_vec, 16)); + + int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 2), zero_vec, 16)); + + int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 3), zero_vec, 16)); + + const uint8_t scale_byte_1 = scales[0]; + const uint8_t scale_byte_2 = scales[1]; + scales += 2; + + sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1); + sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1); + sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1); + sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1); + } + + sumf += d * sum_int; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2628,6 +3001,9 @@ static void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq2_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2641,7 +3017,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2732,7 +3108,7 @@ static void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size *s = 0.125f * sumf; } -static void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -2833,7 +3209,7 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const case 128: ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - default: + default: // 256 and above ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; } @@ -2843,7 +3219,102 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint32_t * grid32 = (const uint32_t *)iq3s_grid; + + vuint8mf2_t v_id_8 = __riscv_vid_v_u8mf2(8); + vuint8m2_t v_id_32 = __riscv_vid_v_u8m2(32); + + // Keeping these in a tight scope to hint they're only needed for the mask computation. + vuint8m2_t v_sign_gather_indices, v_sign_masks; + { + vuint8m2_t v_shifts = __riscv_vand_vx_u8m2(v_id_32, 7, 32); + vuint8m2_t v_one_32 = __riscv_vmv_v_x_u8m2(1, 32); + v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_id_32, 3, 32); + v_sign_masks = __riscv_vsll_vv_u8m2(v_one_32, v_shifts, 32); + } + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d); + const float combined_scale = d * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum_block = 0.0f; + + for (int ib = 0; ib < 8; ++ib) { + + // Grid lookup + vuint8m2_t v_grid_u8; + { + vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 8); + qs += 8; + + uint8_t qh_val = *qh++; + vuint8mf2_t v_qh_val = __riscv_vmv_v_x_u8mf2(qh_val, 8); + v_qh_val = __riscv_vsrl_vv_u8mf2(v_qh_val, v_id_8, 8); + v_qh_val = __riscv_vand_vx_u8mf2(v_qh_val, 1, 8); + + vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 8); + v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 8); + + vuint16m1_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qh_val, 8); + v_qh_u16 = __riscv_vsll_vx_u16m1(v_qh_u16, 10, 8); + + vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_u16, 8); + + vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2(grid32, v_grid_offsets, 8); + v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); + } + __asm__ volatile ("" ::: "memory"); + + //Sign application and dot product + int32_t s_val; + { + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 4); + signs += 4; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 32); + vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 32); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + s_val = __riscv_vmv_x_s_i32m1_i32( + __riscv_vwredsum_vs_i16m4_i32m1(v_dot, v_zero, 32)); + } + __asm__ volatile ("" ::: "memory"); + { + uint8_t sc_byte = scales[ib >> 1]; + int sc_val = (ib & 1) ? (sc_byte >> 4) : (sc_byte & 0xF); + sc_val = sc_val * 2 + 1; + sum_block += (float)(s_val * sc_val); + } + } + sumf += sum_block * combined_scale; + } + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); @@ -2942,6 +3413,9 @@ static void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2955,7 +3429,100 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + // constants for unpacking logic + const uint32_t shifts_val[8] = {0, 7, 14, 21, 0, 7, 14, 21}; + vuint32m2_t v_shifts = __riscv_vle32_v_u32m2(shifts_val, 8); + + const uint32_t gather_idx_val[8] = {0, 0, 0, 0, 1, 1, 1, 1}; + vuint32m2_t v_gather_idx = __riscv_vle32_v_u32m2(gather_idx_val, 8); + + uint32_t aux32[2]; + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float block_sum = 0.0f; + + // Process 64 weights per loop + for (int ib = 0; ib < QK_K / 64; ++ib) { + + // load of metadata via memcpy + memcpy(aux32, metadata, 2 * sizeof(uint32_t)); + metadata += 2 * sizeof(uint32_t); + + vuint8m1_t v_q3_idx_u8 = __riscv_vle8_v_u8m1(q3_indices, 16); + q3_indices += 16; + + vuint16m2_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m2(v_q3_idx_u8, 4, 16); + + vuint32m4_t v_q3_magnitudes_u32 = __riscv_vluxei16_v_u32m4(grid32, v_q3_idx_u16, 16); + + vint8m4_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m4_i8m4( + __riscv_vreinterpret_v_u32m4_u8m4(v_q3_magnitudes_u32)); + + vuint32m2_t v_aux = __riscv_vle32_v_u32m2(aux32, 2); + + vuint32m2_t v_aux_expanded = __riscv_vrgather_vv_u32m2(v_aux, v_gather_idx, 8); + + vuint32m2_t v_s_vals_raw = __riscv_vand_vx_u32m2( + __riscv_vsrl_vv_u32m2(v_aux_expanded, v_shifts, 8), 127, 8); + + vuint16m1_t sign_indices_byte_offset = __riscv_vsll_vx_u16m1( + __riscv_vncvt_x_x_w_u16m1(v_s_vals_raw, 8), 3, 8); + + vuint64m4_t v_s_vals_u64 = __riscv_vluxei16_v_u64m4(signs64, sign_indices_byte_offset, 8); + + vint8m4_t v_s_vals = __riscv_vreinterpret_v_u8m4_i8m4( + __riscv_vreinterpret_v_u64m4_u8m4(v_s_vals_u64)); + + vint8m4_t v_q3_signed = __riscv_vmul_vv_i8m4(v_q3_magnitudes, v_s_vals, 64); + asm volatile("" ::: "memory"); + vint8m4_t v_q8 = __riscv_vle8_v_i8m4(q8, 64); + q8 += 64; + + vint16m8_t v_dot = __riscv_vwmul_vv_i16m8(v_q8, v_q3_signed, 64); + + asm volatile("" ::: "memory"); + + vint16m4_t v_dot_1 = __riscv_vget_v_i16m8_i16m4(v_dot, 0); + vint16m4_t v_dot_2 = __riscv_vget_v_i16m8_i16m4(v_dot, 1); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + vint32m1_t v_sum_1 = __riscv_vwredsum_vs_i16m4_i32m1(v_dot_1, v_zero, 32); + vint32m1_t v_sum_2 = __riscv_vwredsum_vs_i16m4_i32m1(v_dot_2, v_zero, 32); + + int32_t sum1_i = __riscv_vmv_x_s_i32m1_i32(v_sum_1); + int32_t sum2_i = __riscv_vmv_x_s_i32m1_i32(v_sum_2); + + const float scale1_f = (float)(2 * (aux32[0] >> 28) + 1); + const float scale2_f = (float)(2 * (aux32[1] >> 28) + 1); + + block_sum += sum1_i * scale1_f + sum2_i * scale2_f; + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -3052,6 +3619,9 @@ static void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq3_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -3065,7 +3635,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3095,12 +3665,14 @@ static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_ vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32); // Unpack the weight blocks. - vuint8m2_t iq4bits1; - iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 0, __riscv_vand_vx_u8m1(iq4_packed1, 0xf, 16)); - iq4bits1 = __riscv_vset_v_u8m1_u8m2(iq4bits1, 1, __riscv_vsrl_vx_u8m1(iq4_packed1, 4, 16)); - vuint8m2_t iq4bits2; - iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 0, __riscv_vand_vx_u8m1(iq4_packed2, 0xf, 16)); - iq4bits2 = __riscv_vset_v_u8m1_u8m2(iq4bits2, 1, __riscv_vsrl_vx_u8m1(iq4_packed2, 4, 16)); + vuint8m2_t iq4bits1 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(iq4_packed1, 0xf, 16), + __riscv_vsrl_vx_u8m1(iq4_packed1, 4, 16) + ); + vuint8m2_t iq4bits2 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(iq4_packed2, 0xf, 16), + __riscv_vsrl_vx_u8m1(iq4_packed2, 4, 16) + ); // Gather values from the lookup table. vint8m2_t iq4b1 = __riscv_vrgather_vv_i8m2(values, iq4bits1, 32); @@ -3118,7 +3690,7 @@ static void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_ *s = sumf; } -static void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3182,7 +3754,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v case 128: ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - default: + default: // 256 and above ggml_vec_dot_iq4_nl_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc); break; } @@ -3192,7 +3764,73 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); + float sumf = 0; + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + // We process 2 sub-blocks together. + int sumi1 = 0, sumi2 = 0; + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K / 64; ++ib) { + // Load the packed weights. + const vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 32); + iq4 += 32; + + // Unpack the weight blocks. + const vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 32); + const vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 32); + const vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi); + const vuint8m4_t iq4bits_reorder = __riscv_vcreate_v_u8m1_u8m4( + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 0), 16), + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 2), 16), + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 1), 16), + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 3), 16) + ); + const vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 64); + + // Multiply with activations. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 64); + q8 += 64; + const vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 64); + + // Reduce separately. + const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + + const int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + h >>= 4; + + sumi1 += acc0 * ls1; + sumi2 += acc1 * ls2; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3207,16 +3845,15 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); float sumf = 0; - int acc[4]; // Indices for re-ordering IQ4 data. - uint64_t index[16] = { + uint16_t index[16] = { 0, 1, 8, 9, 2, 3, 10, 11, 4, 5,12, 13, 6, 7, 14, 15, }; - vuint64m4_t i_vec = __riscv_vle64_v_u64m4(index, 16); + vuint16m1_t i_vec = __riscv_vle16_v_u16m1(index, 16); for (int ibl = 0; ibl < nb; ++ibl) { const int8_t * q8 = y[ibl].qs; @@ -3225,30 +3862,33 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + #pragma GCC unroll 1 for (int ib = 0; ib < QK_K / 128; ++ib) { // Weights and activations. vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 64); - vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); iq4 += 64; - q8 += 128; // Unpack the weight blocks. vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 64); vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 64); - vuint8m4_t iq4bits; - iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 0, iq4bits_lo); - iq4bits = __riscv_vset_v_u8m2_u8m4(iq4bits, 1, iq4bits_hi); - vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgather_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 16)); + vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi); + vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgatherei16_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 16)); vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 128); + __asm__ __volatile__("" ::: "memory"); + // Multiply with activations. + vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 128); + q8 += 128; + + __asm__ __volatile__("" ::: "memory"); // Reduce separately. - __riscv_vse32_v_i32m1(&acc[0],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); - __riscv_vse32_v_i32m1(&acc[1],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); - __riscv_vse32_v_i32m1(&acc[2],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); - __riscv_vse32_v_i32m1(&acc[3],__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32)); + int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32)); int ls1 = ((x[ibl].scales_l[ib * 2 + 0] & 0xf) | ((h << 4) & 0x30)) - 32; int ls2 = ((x[ibl].scales_l[ib * 2 + 0] >> 4) | ((h << 2) & 0x30)) - 32; @@ -3256,10 +3896,12 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ int ls4 = ((x[ibl].scales_l[ib * 2 + 1] >> 4) | ((h >> 2) & 0x30)) - 32; h >>= 8; - sumi1 += acc[0] * ls1; - sumi2 += acc[1] * ls2; - sumi3 += acc[2] * ls3; - sumi4 += acc[3] * ls4; + sumi1 += acc0 * ls1; + sumi2 += acc1 * ls2; + sumi3 += acc2 * ls3; + sumi4 += acc3 * ls4; + + __asm__ __volatile__("" ::: "memory"); } sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2 + sumi3 + sumi4); @@ -3272,6 +3914,9 @@ static void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq4_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -3285,7 +3930,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3301,8 +3946,107 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; for (int i = 0; i < nb; i++) { + const uint8_t * GGML_RESTRICT tq = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + // First loop. - vint32m4_t suml1; + vint16m4_t suml1; + { + const int vl = 32; + const vuint8m2_t tqb = __riscv_vle8_v_u8m2(tq, vl); + tq += 32; + + { + const vuint16m4_t tq0 = __riscv_vsrl_vx_u16m4(__riscv_vwmulu_vx_u16m4(tqb, 3, vl), 8, vl); + const vint16m4_t q80 = __riscv_vwcvt_x_x_v_i16m4(__riscv_vle8_v_i8m2(q8, vl), vl); + suml1 = __riscv_vmul_vv_i16m4(__riscv_vreinterpret_v_u16m4_i16m4(__riscv_vsub_vx_u16m4(tq0, 1, vl)), q80, vl); + q8 += 32; + } + + uint8_t pow3 = 3; + #pragma GCC unroll 1 + for (int t = 0; t < 4; t++) { + const vuint16m4_t tqn = __riscv_vsrl_vx_u16m4(__riscv_vwmulu_vx_u16m4(__riscv_vmul_vx_u8m2(tqb, pow3, vl), 3, vl), 8, vl); + const vint16m4_t q8n = __riscv_vwcvt_x_x_v_i16m4(__riscv_vle8_v_i8m2(q8, vl), vl); + suml1 = __riscv_vmacc_vv_i16m4(suml1, __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vsub_vx_u16m4(tqn, 1, vl)), q8n, vl); + pow3 *= 3; + q8 += 32; + } + } + + // Second loop. + vint16m2_t suml2; + { + const int vl = 16; + const vuint8m1_t tqb = __riscv_vle8_v_u8m1(tq, vl); + + { + const vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tqb, 3, vl), 8, vl); + const vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(q8, vl), vl); + suml2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl); + q8 += 16; + } + + uint8_t pow3 = 3; + #pragma GCC unroll 1 + for (int t = 0; t < 4; t++) { + const vuint16m2_t tqn = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tqb, pow3, vl), 3, vl), 8, vl); + const vint16m2_t q8n = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(q8, vl), vl); + suml2 = __riscv_vmacc_vv_i16m2(suml2, __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tqn, 1, vl)), q8n, vl); + pow3 *= 3; + q8 += 16; + } + } + + // Third loop. + vint16m2_t suml3; + { + const int vl = 16; + + uint32_t qh; + memcpy(&qh, &x[i].qh[0], 4); + // Prevent fusion with vmv. + __asm__ __volatile__("" : "+r"(qh)); + const vuint8m1_t tqb = __riscv_vreinterpret_v_u32m1_u8m1(__riscv_vmv_v_x_u32m1(qh, vl / 4)); + + const vuint8m1_t p = __riscv_vle8_v_u8m1(pow, vl); + + const vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vv_u8m1(tqb, p, vl), 3, vl), 8, vl); + + const vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(q8, vl), vl); + + suml3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl); + } + + vint16m2_t sumb = __riscv_vadd_vv_i16m2(__riscv_vget_v_i16m4_i16m2(suml1, 0), __riscv_vget_v_i16m4_i16m2(suml1, 1), 16); + sumb = __riscv_vadd_vv_i16m2(sumb, suml2, 16); + sumb = __riscv_vadd_vv_i16m2(sumb, suml3, 16); + + vint32m1_t sum = __riscv_vwredsum_vs_i16m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); + sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + for (int i = 0; i < nb; i++) { + // First loop. + vint16m2_t suml1; { const int vl = 32; vuint8m1_t tq = __riscv_vle8_v_u8m1(x[i].qs, vl); @@ -3325,13 +4069,13 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t vint16m2_t sum3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq3, 1, vl)), q83, vl); vint16m2_t sum4 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq4, 1, vl)), q84, vl); - vint32m4_t sumi0 = __riscv_vwadd_vv_i32m4(sum0, sum1, vl); - vint32m4_t sumi1 = __riscv_vwadd_vv_i32m4(sum2, sum3, vl); - suml1 = __riscv_vadd_vv_i32m4(__riscv_vwcvt_x_x_v_i32m4(sum4, vl), __riscv_vadd_vv_i32m4(sumi0, sumi1, vl), vl); + vint16m2_t sumi0 = __riscv_vadd_vv_i16m2(sum0, sum1, vl); + vint16m2_t sumi1 = __riscv_vadd_vv_i16m2(sum2, sum3, vl); + suml1 = __riscv_vadd_vv_i16m2(sum4, __riscv_vadd_vv_i16m2(sumi0, sumi1, vl), vl); } // Second loop. - vint32m2_t suml2; + vint16m1_t suml2; { const int vl = 16; vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs + 32, vl); @@ -3354,13 +4098,13 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl); vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl); - vint32m2_t sumi0 = __riscv_vwadd_vv_i32m2(sum0, sum1, vl); - vint32m2_t sumi1 = __riscv_vwadd_vv_i32m2(sum2, sum3, vl); - suml2 = __riscv_vadd_vv_i32m2(__riscv_vwcvt_x_x_v_i32m2(sum4, vl), __riscv_vadd_vv_i32m2(sumi0, sumi1, vl), vl); + vint16m1_t sumi0 = __riscv_vadd_vv_i16m1(sum0, sum1, vl); + vint16m1_t sumi1 = __riscv_vadd_vv_i16m1(sum2, sum3, vl); + suml2 = __riscv_vadd_vv_i16m1(sum4, __riscv_vadd_vv_i16m1(sumi0, sumi1, vl), vl); } // Third loop. - vint32m2_t suml3; + vint16m1_t suml3; { const int vl = 16; @@ -3376,15 +4120,13 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 240, vl), vl); - vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); - suml3 = __riscv_vwcvt_x_x_v_i32m2(sum0, vl); + suml3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); } - vint32m2_t sumb = __riscv_vadd_vv_i32m2(__riscv_vget_v_i32m4_i32m2(suml1, 0), __riscv_vget_v_i32m4_i32m2(suml1, 1), 16); - sumb = __riscv_vadd_vv_i32m2(sumb, suml2, 16); - sumb = __riscv_vadd_vv_i32m2(sumb, suml3, 16); + vint16m1_t sumb = __riscv_vadd_vv_i16m1(__riscv_vget_v_i16m2_i16m1(suml1, 0), __riscv_vget_v_i16m2_i16m1(suml1, 1), 16); + sumb = __riscv_vadd_vv_i16m1(sumb, __riscv_vadd_vv_i16m1(suml2, suml3, 16), 16); - vint32m1_t sum = __riscv_vredsum_vs_i32m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); + vint32m1_t sum = __riscv_vwredsum_vs_i16m1_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); } @@ -3395,6 +4137,9 @@ static void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_tq1_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -3408,7 +4153,89 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl128(const int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x[0].qs); j += 32) { + const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32]; + const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32]; + const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32]; + const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32]; + const uint8_t* px = &x[i].qs[j]; + + size_t vl = __riscv_vsetvl_e16m4(32); + vint16m4_t vacc16 = __riscv_vmv_v_x_i16m4(0, vl); + + // Load Raw Packed elements + vl = __riscv_vsetvl_e8m2(32); + vuint8m2_t vx_u8 = __riscv_vle8_v_u8m2(px, vl); + + // Process bits 1:0 + { + // Unpack + vuint8m2_t t0 = __riscv_vand_vx_u8m2(vx_u8, 0x03, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t0), 1, vl); + vint8m2_t vy = __riscv_vle8_v_i8m2(py0, vl); + // Accumulate + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + // Process bits 3:2 + { + vuint8m2_t t1 = __riscv_vsrl_vx_u8m2(vx_u8, 2, vl); + t1 = __riscv_vand_vx_u8m2(t1, 0x03, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t1), 1, vl); + + vint8m2_t vy = __riscv_vle8_v_i8m2(py1, vl); + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + // Process bits 5:4 + { + vuint8m2_t t2 = __riscv_vsrl_vx_u8m2(vx_u8, 4, vl); + t2 = __riscv_vand_vx_u8m2(t2, 0x03, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t2), 1, vl); + + vint8m2_t vy = __riscv_vle8_v_i8m2(py2, vl); + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + // Process bits 7:6 + { + vuint8m2_t t3 = __riscv_vsrl_vx_u8m2(vx_u8, 6, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t3), 1, vl); + + vint8m2_t vy = __riscv_vle8_v_i8m2(py3, vl); + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + vl = __riscv_vsetvl_e16m4(32); + vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t vred32 = __riscv_vwredsum_vs_i16m4_i32m1(vacc16, vzero32, vl); + sumi += __riscv_vmv_x_s_i32m1_i32(vred32); + } + + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + sumf += (float)sumi * d; + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -3483,6 +4310,9 @@ static void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { #if defined __riscv_v_intrinsic switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_tq2_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; case 256: ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; @@ -3496,7 +4326,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo } #if defined __riscv_v_intrinsic -static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3526,12 +4356,14 @@ static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32); // Unpack the weight blocks. - vuint8m2_t mxbits1; - mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 0, __riscv_vand_vx_u8m1(mx_packed1, 0xf, 16)); - mxbits1 = __riscv_vset_v_u8m1_u8m2(mxbits1, 1, __riscv_vsrl_vx_u8m1(mx_packed1, 4, 16)); - vuint8m2_t mxbits2; - mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 0, __riscv_vand_vx_u8m1(mx_packed2, 0xf, 16)); - mxbits2 = __riscv_vset_v_u8m1_u8m2(mxbits2, 1, __riscv_vsrl_vx_u8m1(mx_packed2, 4, 16)); + vuint8m2_t mxbits1 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(mx_packed1, 0xf, 16), + __riscv_vsrl_vx_u8m1(mx_packed1, 4, 16) + ); + vuint8m2_t mxbits2 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(mx_packed2, 0xf, 16), + __riscv_vsrl_vx_u8m1(mx_packed2, 4, 16) + ); // Gather values from the lookup table. vint8m2_t mxb1 = __riscv_vrgather_vv_i8m2(values, mxbits1, 32); @@ -3549,7 +4381,7 @@ static void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t *s = sumf; } -static void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -3613,7 +4445,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo case 128: ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - default: + default: // 256 and above ggml_vec_dot_mxfp4_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc); break; } From 655c0750f5a027817fa3038f17232dd3bf717480 Mon Sep 17 00:00:00 2001 From: Kusha Gharahi <3326002+kushagharahi@users.noreply.github.com> Date: Thu, 16 Apr 2026 03:54:37 -0500 Subject: [PATCH 446/831] metal: Implement ROLL op (llama/21946) * nix: support unified apple-sdk * Impl roll op for Metal * Revert "nix: support unified apple-sdk" This reverts commit abfa473360471532c547de8b202c780507924d4b. * update ops.md * update op docs --- ggml/src/ggml-metal/ggml-metal-device.cpp | 17 +++++++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 1 + ggml/src/ggml-metal/ggml-metal-impl.h | 23 +++++++++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 57 +++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 34 ++++++++++++++ 7 files changed, 134 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 8e0836c0beb..07d016d2227 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1819,6 +1819,23 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_met return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_roll(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_ROLL); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_roll_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_PAD); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index de43f819312..b423501358e 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -152,6 +152,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_roll (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index effe666a691..27cb1683518 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1138,6 +1138,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_ARGSORT: case GGML_OP_TOP_K: case GGML_OP_ARANGE: + case GGML_OP_ROLL: return true; case GGML_OP_FLASH_ATTN_EXT: // for new head sizes, add checks here diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index e7433f2a658..379a8b33a14 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -1017,6 +1017,29 @@ typedef struct { int32_t p1; } ggml_metal_kargs_pad_reflect_1d; +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t s0; + int32_t s1; + int32_t s2; + int32_t s3; +} ggml_metal_kargs_roll; + typedef struct { uint64_t nb1; int dim; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 5b426be103f..e173527909a 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -410,6 +410,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx); } break; + case GGML_OP_ROLL: + { + n_fuse = ggml_metal_op_roll(ctx, idx); + } break; case GGML_OP_ARANGE: { n_fuse = ggml_metal_op_arange(ctx, idx); @@ -3945,6 +3949,59 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_roll(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t s0 = ggml_get_op_params_i32(op, 0); + const int32_t s1 = ggml_get_op_params_i32(op, 1); + const int32_t s2 = ggml_get_op_params_i32(op, 2); + const int32_t s3 = ggml_get_op_params_i32(op, 3); + + ggml_metal_kargs_roll args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.s0 =*/ s0, + /*.s1 =*/ s1, + /*.s2 =*/ s2, + /*.s3 =*/ s3 + }; + + auto pipeline = ggml_metal_library_get_pipeline_roll(lib, op); + + const int nth = std::min(1024, ne0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 50e3c5c77a1..36c61071b4f 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -81,6 +81,7 @@ int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_roll (ggml_metal_op_t ctx, int idx); int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx); int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 445a4deca83..9f38c9d2968 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -5247,6 +5247,40 @@ kernel void kernel_upscale_bicubic_f32( } } +kernel void kernel_roll_f32( + constant ggml_metal_kargs_roll & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + device const float * src0_ptr = (device const float *) src0; + device float * dst_ptr = (device float *) dst; + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + // apply shifts and wrap around + int64_t i00 = i0 - args.s0; + int64_t i01 = i1 - args.s1; + int64_t i02 = i2 - args.s2; + int64_t i03 = i3 - args.s3; + + if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; } + if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; } + if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; } + if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; } + + int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00; + int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0; + + dst_ptr[dst_idx] = src0_ptr[src_idx]; + } +} + kernel void kernel_pad_f32( constant ggml_metal_kargs_pad & args, device const char * src0, From 820438ae2c60aa9c5fd9a20edb085b414139440d Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 16 Apr 2026 17:21:28 +0800 Subject: [PATCH 447/831] ggml: add graph_reused (llama/21764) * ggml: add graph_reused * use versioning instead of reuse flag * increment version with atomic * use top bits for split numbering * add assert * move counter to ggml.c * set uid in split_graph only * fix windows * address further review comments * get next_uid rather than doing bit manipulation * rename + add comment about uid --- ggml/src/ggml-backend.cpp | 7 +++++++ ggml/src/ggml-cuda/common.cuh | 1 + ggml/src/ggml-cuda/ggml-cuda.cu | 9 +++++++++ ggml/src/ggml-impl.h | 6 ++++++ ggml/src/ggml.c | 12 ++++++++++++ 5 files changed, 35 insertions(+) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 1a555bf2a4d..d9f8aaec52f 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1030,6 +1030,8 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra GGML_ABORT("%s: failed to initialize context\n", __func__); } + graph->uid = ggml_graph_next_uid(); + // pass 1: assign backends to ops with pre-allocated inputs for (int i = 0; i < graph->n_leafs; i++) { struct ggml_tensor * leaf = graph->leafs[i]; @@ -1477,6 +1479,11 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra assert(graph_copy->size > graph_copy->n_leafs); graph_copy->leafs[graph_copy->n_leafs++] = leaf; } + + // set ids for all splits + for (int i = 0; i < sched->n_splits; ++i) { + sched->splits[i].graph.uid = ggml_graph_next_uid(); + } } static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ad30ecd8fa5..66ed02d2923 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1186,6 +1186,7 @@ struct ggml_cuda_graph { std::vector nodes; bool disable_due_to_gpu_arch = false; bool warmup_complete = false; + uint64_t uid = 0; struct node_properties { ggml_tensor node; void * node_src_data_ptrs[GGML_MAX_SRC]; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 790f53cead7..de579d2ed50 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3108,6 +3108,15 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx const void * graph_key = ggml_cuda_graph_get_key(cgraph); ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + if (cgraph->uid != 0 && + cgraph->uid == graph->uid) { + GGML_LOG_DEBUG("CUDA Graph id %zu reused\n", cgraph->uid); + GGML_ASSERT((int)graph->node_props.size() == cgraph->n_nodes); + return false; + } + + graph->uid = cgraph->uid; + // Check if the graph size has changed if ((int)graph->node_props.size() != cgraph->n_nodes) { res = true; diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 0639db362e7..62b76abbcec 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -30,6 +30,8 @@ extern "C" { void ggml_print_backtrace(void); +uint64_t ggml_graph_next_uid(void); + #ifndef MIN # define MIN(a, b) ((a) < (b) ? (a) : (b)) #endif @@ -338,6 +340,10 @@ struct ggml_cgraph { struct ggml_hash_set visited_hash_set; enum ggml_cgraph_eval_order order; + + // an optional identifier that can be utilized to recognize same graphs if two non-zero values match + // a value of 0 means it is not set and should be ignored + uint64_t uid; }; // returns a slice of cgraph with nodes [i0, i1) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 0142498d967..eda041f4518 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -53,6 +53,16 @@ #define UNUSED GGML_UNUSED +uint64_t ggml_graph_next_uid(void) { +#ifdef _MSC_VER + static volatile long long counter = 1; + return (uint64_t) _InterlockedIncrement64(&counter) - 1; +#else + static uint64_t counter = 1; + return __atomic_fetch_add(&counter, 1, __ATOMIC_RELAXED); +#endif +} + // Needed for ggml_fp32_to_bf16_row() #if defined(__AVX512BF16__) #if defined(_MSC_VER) @@ -7098,6 +7108,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz /*.use_counts =*/ use_counts_ptr, /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, + /*.uid =*/ 0, }; ggml_hash_set_reset(&cgraph->visited_hash_set); @@ -7125,6 +7136,7 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) /*.use_counts =*/ cgraph0->use_counts, /*.visited_hash_set =*/ cgraph0->visited_hash_set, /*.order =*/ cgraph0->order, + /*.uid =*/ 0 }; return cgraph; From 57a48a485084a7daa2b61b760968dc384a54b354 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Thu, 16 Apr 2026 12:08:33 -0700 Subject: [PATCH 448/831] opencl: add q5_K gemm and gemv kernels for Adreno (llama/21595) --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 326 ++++++++++++++++++ ggml/src/ggml-opencl/kernels/cvt.cl | 94 ++++- .../kernels/gemm_noshuffle_q5_k_f32.cl | 176 ++++++++++ .../kernels/gemv_noshuffle_q5_k_f32.cl | 326 ++++++++++++++++++ 5 files changed, 922 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 112c2afe821..772fc537494 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -121,6 +121,8 @@ set(GGML_OPENCL_KERNELS gemm_noshuffle_q4_k_f32 gemv_noshuffle_q6_k_f32 gemm_noshuffle_q6_k_f32 + gemv_noshuffle_q5_k_f32 + gemm_noshuffle_q5_k_f32 mul neg norm diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index a581402300a..b27fbb13a3a 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -542,6 +542,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_restore_block_q4_K_noshuffle; cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K; cl_kernel kernel_convert_block_q5_K, kernel_restore_block_q5_K; + cl_kernel kernel_convert_block_q5_K_noshuffle; + cl_kernel kernel_restore_block_q5_K_noshuffle; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q4_1_f32; @@ -730,6 +732,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemm_noshuffle_q4_k_f32; cl_kernel kernel_gemv_noshuffle_q6_K_f32; cl_kernel kernel_gemm_noshuffle_q6_K_f32; + cl_kernel kernel_gemv_noshuffle_q5_k_f32; + cl_kernel kernel_gemm_noshuffle_q5_k_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS void free() { @@ -944,6 +948,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_K", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_K_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err)); @@ -2794,6 +2800,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q6_K_f32", &err), err)); GGML_LOG_CONT("."); } + + // gemv_noshuffle_q5_k_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q5_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q5_k_f32.cl"); +#endif + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q5_k_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q5_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_noshuffle_q5_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q5_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q5_k_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q5_k_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q5_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } #endif // GGML_OPENCL_USE_ADRENO_KERNELS GGML_LOG_CONT("\n"); } @@ -5354,7 +5399,17 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); CL_CHECK(err); + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q5_K_noshuffle; + } + #else cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; + #endif + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); @@ -5362,6 +5417,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_F0)); size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {64, 1, 1}; @@ -5378,6 +5435,21 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, extra->size_dm = size_dm; tensor->extra = extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + // Transpose q, d, dm as ushort, qh as uchar + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + transpose_2d_as_8b (backend_ctx, extra->qh, extra->qh, size_qh, K/8, M); + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/256, M); + transpose_2d_as_16b(backend_ctx, extra->dm, extra->dm, size_dm, K/256, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } if (tensor->type == GGML_TYPE_Q6_K) { @@ -5894,6 +5966,57 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_nbytes(tensor), NULL, &err); CL_CHECK(err); + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + size_t size_q = extra->size_q; + size_t size_qh = extra->size_qh; + size_t size_d = extra->size_d; + size_t size_dm = extra->size_dm; + + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_qh; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_trans_dm; + + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_trans_dm.allocate(backend_ctx->context, size_dm); + + // Reverse transpose q, qh, d, dm + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_8b (backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/8); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256); + transpose_2d_as_16b(backend_ctx, extra->dm, buf_trans_dm.buffer, size_dm, M, K/256); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_K_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &buf_trans_dm.buffer)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_K; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); @@ -5901,6 +6024,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_F0)); size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {1, 1, 1}; @@ -10451,6 +10576,201 @@ static void ggml_cl_mul_mat_q6_K_f32_adreno(ggml_backend_t backend, const ggml_t #endif } +static void ggml_cl_mul_mat_q5_K_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_k = (ggml_tensor_extra_cl_q5_K *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + cl_uchar mask_d6 = 0x3F; + cl_uchar mask_d4 = 0x0F; + cl_uchar mask_hi2 = 0xC0; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem qh_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q (CL_R, CL_UNSIGNED_INT32): width = M*K/2/4 + img_fmt = {CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q5_k->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // image for qh (CL_R, CL_HALF_FLOAT): width = M*K/16 + img_fmt = {CL_R, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 16; + img_desc.buffer = extra0_q5_k->qh; + CL_CHECK((qh_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations (CL_RGBA, CL_FLOAT): width = K*N/4 + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q5_k_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qh_img)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_k->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_k->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_hi2)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(qh_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0) { + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float) / 2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N / 4; + if (height_B == 0) height_B = 1; + int width_B = K / 4; + int padded_height_B = (N + padding) / 4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = {1, 16}; + size_t global_work_size_t[2] = {(size_t)width_B, (size_t)padded_height_B}; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q5_k_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_k->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_k->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_k->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_k->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_uchar), &mask_hi2)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -10600,6 +10920,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } + // q5_K x fp32 + if (src0t == GGML_TYPE_Q5_K && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q5_K_f32_adreno(backend, src0, src1, dst); + return; + } + // q4_0 x fp32 if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { // TODO: remove duplicate definitions of image description + format -- move to top diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 1bd83d29b3d..39af32d282b 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -568,7 +568,9 @@ kernel void kernel_convert_block_q5_K( global uchar * dst_qh, global uchar * dst_s, global half * dst_d, - global half * dst_dm + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 ) { global struct block_q5_K * b = (global struct block_q5_K *) src0 + get_global_id(0); global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0); @@ -599,7 +601,9 @@ kernel void kernel_restore_block_q5_K( global uchar * src_s, global half * src_d, global half * src_dm, - global struct block_q5_K * dst + global struct block_q5_K * dst, + uchar mask_0F, + uchar mask_F0 ) { global struct block_q5_K * b = (global struct block_q5_K *) dst + get_global_id(0); global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0); @@ -622,6 +626,92 @@ kernel void kernel_restore_block_q5_K( } } +kernel void kernel_convert_block_q5_K_noshuffle( + global struct block_q5_K * src0, + global uchar * dst_q, + global uchar * dst_qh, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_K * b = (global struct block_q5_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2 * get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/8 * get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->qs[i*32 + 2*j]; + uchar x1 = b->qs[i*32 + 2*j + 1]; + q[i*32 + j] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + q[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + for (int l = 0; l < QK_K/8; ++l) { + uchar x0 = 0; + for (int i = 0; i < 8; ++i) { + x0 |= ((b->qh[(l%4)*8+i] >> (l/4)) & 0x01) << i; + } + qh[l] = x0; + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q5_K_noshuffle( + global uchar * src_q, + global uchar * src_qh, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q5_K * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_K * b = (global struct block_q5_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2 * get_global_id(0); + global uchar * qh = (global uchar *) src_qh + QK_K/8 * get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = q[i*32 + j]; + uchar hi = q[i*32 + j + 16]; + b->qs[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->qs[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } + + for (int g = 0; g < 4; ++g) { + for (int i = 0; i < 8; ++i) { + uchar x0 = 0; + for (int k = 0; k < 8; ++k) { + x0 |= ((qh[4*k+g] >> i) & 0x01) << k; + } + b->qh[g*8+i] = x0; + } + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + //------------------------------------------------------------------------------ // kernel_convert_block_q6_K // Convert the block_q6_K format to 3 separate arrays (AOS -> SOA). diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl new file mode 100644 index 00000000000..058c0f7edc6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl @@ -0,0 +1,176 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif +kernel void kernel_gemm_noshuffle_q5_k_f32( + global const ushort * src0_q, + global const uchar * src0_qh, + global const uchar * src0_s, + global const half * src0_d, + global const half * src0_dm, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + dst = (global float *)((global char *)dst + offsetd); + int n_4 = n >> 2; + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + int num_blocks_K = k / QK_K; + + global const ushort * weight_ptr = src0_q + gx_2; + global const uchar * qh_ptr = src0_qh + gx_2; + global const half * d_ptr = src0_d + gx_2; + global const half * dm_ptr = src0_dm + gx_2; + + for (int i = 0; i < k; i += 32) { + int sb_idx = i / QK_K; + int sub_idx = (i / 32) % 8; + + half4 d = vload4(0, d_ptr + sb_idx * m); + half4 dm = vload4(0, dm_ptr + sb_idx * m); + + global const uchar * sc0 = src0_s + (gx_2+0) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc1 = src0_s + (gx_2+1) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc2 = src0_s + (gx_2+2) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc3 = src0_s + (gx_2+3) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + + uchar sv0, mn0, sv1, mn1, sv2, mn2, sv3, mn3; + get_scale_min_k4(sub_idx, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc2, &sv2, &mn2, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc3, &sv3, &mn3, mask_d6, mask_d4, mask_hi2); + + half4 scale = convert_half4(convert_float4(d) * convert_float4((uchar4)(sv0, sv1, sv2, sv3))); + half4 mval = convert_half4(convert_float4(dm) * convert_float4((uchar4)(mn0, mn1, mn2, mn3))); + + for (int l = 0; l < 32; l += 4) { + int ki = i + l; + ushort4 bits4 = vload4(0, weight_ptr + (ki/4) * m); + uchar4 qh_bits = vload4(0, qh_ptr + (ki/8) * m); + int qh_shift = ki % 8; + + // j=0 + B.s0123 = read_imageh(src1, gy*2 + (ki+0) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+0) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0x000F) | (((qh_bits.s0 >> (qh_shift+0)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0x000F) | (((qh_bits.s1 >> (qh_shift+0)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0x000F) | (((qh_bits.s2 >> (qh_shift+0)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0x000F) | (((qh_bits.s3 >> (qh_shift+0)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (ki+1) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+1) * n_4); + dequantized_weights.s0 = (((bits4.s0 & 0x00F0) >> 4) | (((qh_bits.s0 >> (qh_shift+1)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = (((bits4.s1 & 0x00F0) >> 4) | (((qh_bits.s1 >> (qh_shift+1)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = (((bits4.s2 & 0x00F0) >> 4) | (((qh_bits.s2 >> (qh_shift+1)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = (((bits4.s3 & 0x00F0) >> 4) | (((qh_bits.s3 >> (qh_shift+1)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (ki+2) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+2) * n_4); + dequantized_weights.s0 = (((bits4.s0 & 0x0F00) >> 8) | (((qh_bits.s0 >> (qh_shift+2)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = (((bits4.s1 & 0x0F00) >> 8) | (((qh_bits.s1 >> (qh_shift+2)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = (((bits4.s2 & 0x0F00) >> 8) | (((qh_bits.s2 >> (qh_shift+2)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = (((bits4.s3 & 0x0F00) >> 8) | (((qh_bits.s3 >> (qh_shift+2)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (ki+3) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+3) * n_4); + dequantized_weights.s0 = (((bits4.s0 & 0xF000) >> 12) | (((qh_bits.s0 >> (qh_shift+3)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = (((bits4.s1 & 0xF000) >> 12) | (((qh_bits.s1 >> (qh_shift+3)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = (((bits4.s2 & 0xF000) >> 12) | (((qh_bits.s2 >> (qh_shift+3)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = (((bits4.s3 & 0xF000) >> 12) | (((qh_bits.s3 >> (qh_shift+3)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + } + + int idx = (gy<<3)*m + (gx<<2); + + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl new file mode 100644 index 00000000000..c40db166638 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl @@ -0,0 +1,326 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK_K 256 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, bits1, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s0 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s1 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s2 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s3 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s4 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s5 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s6 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s7 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, bits1, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s0 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s1 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s2 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s3 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s4 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s5 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s6 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s7 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q5_k_f32( + read_only image1d_buffer_t src0_q, + read_only image1d_buffer_t src0_qh, + global half2 * src0_d, + global half2 * src0_m, + global uchar * src0_s, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + uint LINE_STRIDE_A_QH = M / 2; + uint BLOCK_STRIDE_A_QH = NSUBGROUPS * M / 2; + uint scales_per_row = (K / QK_K) * 12; + + private uint4 regA; + private ushort4 regH; + private half2 regS; + private half2 regM; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / 32); k += NSUBGROUPS) { + uint sb = k / 8; + uint j = k % 8; + + half2 d = src0_d[gid + sb * LINE_STRIDE_A]; + half2 dm = src0_m[gid + sb * LINE_STRIDE_A]; + + global const uchar * sc0 = src0_s + 2 * gid * scales_per_row + sb * 12; + global const uchar * sc1 = src0_s + (2 * gid + 1) * scales_per_row + sb * 12; + + uchar sv0, mn0, sv1, mn1; + get_scale_min_k4(j, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(j, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + + regS = convert_half2(convert_float2(d) * convert_float2((uchar2)(sv0, sv1))); + regM = convert_half2(convert_float2(dm) * convert_float2((uchar2)(mn0, mn1))); + + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + regH.s0 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 0)).x); + regH.s1 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 1)).x); + regH.s2 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 2)).x); + regH.s3 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 3)).x); + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } +} From b25d5d050b53463ab415e4a6c7039c43bedee571 Mon Sep 17 00:00:00 2001 From: nullname Date: Fri, 17 Apr 2026 04:48:34 +0800 Subject: [PATCH 449/831] hexagon: optimize HMX matmul operations (llama/21071) * optimize hmx_mat_mul functions by calculating row and column tiles upfront * refactor core_dot_chunk_fp16 to use size_t for tile counts and improve readability * wip * set scale outside of loop * wip * refactor core_mma_chunk_fp16 and mat_mul_qk_0_d16a32 to use size_t for tile counts * wip * wip * refactor transfer_output_chunk_fp16_to_fp32 to use size_t for dimensions * refactor core_dot_chunk_fp16 to use size_t for tile row stride calculation * wip * refactor hmx_mat_mul functions to use hvx_vec_splat_f16 for column scales initialization * refactor hmx_mat_mul_permuted_w16a32_batched to streamline scale setting and locking * refactor core_dot_chunk_fp16 to improve tile stride calculations for output * refactor hmx_mat_mul functions to use Q6_V_vsplat_R for column scales initialization * fix compiling error * wip * optimize row and column tile indexing in core_mma_chunk_fp16 function * wip * Revert "wip" This reverts commit cde679eff79c4a28dd2d89d32f710015e09592b6. * Add size limit check for HAP_mmap in htp_iface_mmap and drop_mmap functions * wip --- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 96 +++++++++++----------- ggml/src/ggml-hexagon/htp/htp-ops.h | 2 + ggml/src/ggml-hexagon/htp/main.c | 31 ++++++- 3 files changed, 80 insertions(+), 49 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 485ec3f1aa9..dbca8220fab 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -648,9 +648,9 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( assert(n_cols % HMX_FP16_TILE_N_COLS == 0); assert(k_block % HMX_FP16_TILE_N_COLS == 0); - int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; - int n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; - int n_tot_tiles = n_col_tiles * n_k_tiles; + size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; + size_t n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; + size_t n_tot_tiles = n_col_tiles * n_k_tiles; size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads); @@ -678,9 +678,8 @@ static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict __builtin_assume(n_dot_tiles > 0); Q6_bias_mxmem2_A((void *)scales); - for (int r = 0; r < n_row_tiles; ++r) { - for (int c = 0; c < n_col_tiles; ++c) { + for (size_t c = 0; c < n_col_tiles; ++c) { Q6_mxclracc_hf(); const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS; @@ -738,25 +737,25 @@ static inline void hmx_matmul_job_init(hmx_matmul_job_t * job, static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - const int n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; + const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; const HVX_Vector one = hvx_vec_splat_f16(1.0); - for (int r = 0; r < n_rows; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; - int r1 = r % HMX_FP16_TILE_N_ROWS; + for (size_t r = 0; r < n_rows; r += 2) { + const size_t r0 = r / HMX_FP16_TILE_N_ROWS; + const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; + float *output_row_base = dst + r * n; // global memory row base for row r (and r+1) #pragma unroll(4) - for (int c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) { - int c0 = c / HMX_FP16_TILE_N_COLS; - - const __fp16 *tile = vtcm_src + (r0 * n_col_tiles + c0) * HMX_FP16_TILE_N_ELMS; - - HVX_Vector v = ((const HVX_Vector *) tile)[r1 / 2]; + for (size_t c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) { + const size_t c0 = c / HMX_FP16_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); - volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (dst + (r * n + c + 0)); - volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (dst + (r * n + c + n)); // next row in global memory + volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row_base + c + 0); + volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (output_row_base + c + n); // next row in global memory *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); if (r + 1 < n_rows) { @@ -794,7 +793,7 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, assert(n_cols % HMX_FP16_TILE_N_COLS == 0); size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = 32; // must be multiple of HMX_FP16_TILE_N_ROWS (32) + size_t n_chunks_per_task = HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) output_transfer_task_state_t state; state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; @@ -926,7 +925,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 FARF(MEDIUM, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", __func__, params->m, params->k, params->n, group_size, params->ne13, @@ -944,12 +943,15 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + for (int b3 = 0; b3 < params->ne13; ++b3) { for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); // Pre-load activations for all heads in the group (once per m_chunk). // When the source is strided (permuted Q), use 2D DMA to gather @@ -987,10 +989,9 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); } - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); TIMER_START(weight_load); { @@ -1014,11 +1015,9 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu for (int g = 0; g < group_size; ++g) { TIMER_START(hmx_core); { - const __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; - const int n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); - const int n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); - core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, - n_row_tiles, n_col_tiles, params->k / 32); + const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, + params->k / 32); } TIMER_STOP(hmx_core); @@ -1030,12 +1029,12 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu TIMER_STOP(output_store); } } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); } } } + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + TIMER_STOP(total); #if defined(ENABLE_PROFILE_TIMERS) @@ -1103,7 +1102,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co return -1; } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 FARF(MEDIUM, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, @@ -1121,7 +1120,8 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { // transfer activation matrix chunk into VTCM - size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); TIMER_START(activation_load); { @@ -1159,7 +1159,8 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co } for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); TIMER_START(weight_load); { @@ -1184,8 +1185,6 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co TIMER_START(hmx_core); { - const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); } TIMER_STOP(hmx_core); @@ -1307,7 +1306,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds return -1; } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", __func__, m, k, n, weight_type, use_pipeline, @@ -1330,7 +1329,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds HAP_compute_res_hmx_lock(ctx->vtcm_rctx); for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { // transfer activation matrix chunk into VTCM - size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); TIMER_START(activation_load); { @@ -1348,7 +1348,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds } for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); TIMER_START(weight_load); { @@ -1373,8 +1374,6 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds TIMER_START(hmx_core); { - const int n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - const int n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); } TIMER_STOP(hmx_core); @@ -1521,14 +1520,16 @@ void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __f Q6_bias_mxmem2_A((void *)col_scales); - for (int i = 0; i < n_row_tiles; ++i) { - for (int j = 0; j < n_col_tiles; ++j) { + const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS; + for (size_t i = 0; i < n_row_tiles; ++i) { + const __fp16 *row_base = a + i * dot_tile_stride; + __fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS; + for (size_t j = 0; j < n_col_tiles; ++j) { Q6_mxclracc_hf(); - const __fp16 *row_tiles = a + i * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - const __fp16 *col_tiles = b + j * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - - __fp16 *accum_tile = c + (i * n_col_tiles + j) * HMX_FP16_TILE_N_ELMS; + const __fp16 *col_tiles = b + j * dot_tile_stride; + const __fp16 *row_tiles = row_base; + __fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS; if (!zero_init) { Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); @@ -1697,7 +1698,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict v = Q6_V_vror_VR(v, VLEN - 8); } } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // fp16: 1.0 + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 TIMER_DEFINE(fetch); TIMER_DEFINE(act_load); @@ -1715,7 +1716,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { - size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); + const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); TIMER_START(fetch); // fetch activation block into VTCM @@ -1731,13 +1732,13 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict } // fetch weight block into VTCM (x4x2 sub-block: quants + scales) + const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); { qweight_fetch_task_state_t s; const int blk_start = kk / QK_Q4_0x4x2; const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); - const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); const int scale_blk_size = (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; @@ -1777,7 +1778,6 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict dma_queue_pop(ctx->dma[0]); // vtcm_scratch0 is used to store the qweight chunk // worker_pool_run_func already returned, so fetch is done - const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, n_blk_sz, k_blk_sz, sub_row_stride, weight_type); } diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index fa84b674cd2..79b5ecd2270 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -98,6 +98,8 @@ enum htp_op_code { #define HTP_OP_MAX_VMEM (3221225472u) #endif +#define HTP_MMAP_MAX_VMEM (2147483648u) + enum htp_tensor_flags { HTP_TENSOR_COMPUTE = (1U << 0), // Tensor buffer temporal compute data (not weights) HTP_TENSOR_FLUSHED = (1U << 1) // Tensor buffer has been flushed (set by the NPU) diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index d71c97ed292..5091623a653 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -118,7 +118,11 @@ AEEResult htp_iface_close(remote_handle64 handle) { // release the mmaps (if any) for (uint32_t i=0; immap[i].size) { +#if __HVX_ARCH__ > 73 HAP_munmap2((void *) ctx->mmap[i].base, ctx->mmap[i].size); +#else + HAP_munmap((void *) ctx->mmap[i].base, ctx->mmap[i].size); +#endif ctx->mmap[i].size = 0; ctx->mmap[i].base = NULL; ctx->mmap[i].fd = -1; @@ -173,8 +177,16 @@ AEEResult htp_iface_mmap(remote_handle64 handle, int fd, uint32_t size, uint32_t struct htp_mmap *m = &ctx->mmap[i]; if (!m->size) { FARF(HIGH, "mmap : fd %u size %u pinned %u", fd, size, pinned); - +#if __HVX_ARCH__ > 73 void *va = HAP_mmap2(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0); +#else + if (size > HTP_MMAP_MAX_VMEM) { // HAP_mmap has a size limit of 2GB + FARF(ERROR, "mmap failed : size %u exceeds 2GB limit for HAP_mmap", (uint32_t) size); + abort(); // can't do much else at this point + } + + void *va = HAP_mmap(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0); +#endif if (va == (void*)-1) { FARF(ERROR, "mmap failed : va %p fd %u size %u", va, fd, (uint32_t) size); return AEE_EFAILED; @@ -202,7 +214,11 @@ AEEResult htp_iface_munmap(remote_handle64 handle, int fd) { struct htp_mmap *m = &ctx->mmap[i]; if (fd < 0 || m->fd == fd) { FARF(HIGH, "unmmap : base %p fd %u size %u", (void*) m->base, m->fd, (uint32_t) m->size); +#if __HVX_ARCH__ > 73 HAP_munmap2((void *) m->base, m->size); +#else + HAP_munmap((void *) m->base, m->size); +#endif m->size = 0; m->base = NULL; m->fd = -1; @@ -526,7 +542,11 @@ static inline bool reuse_buf(struct htp_context *ctx, uint32_t *m_reuse, struct static inline void drop_mmap(struct htp_context *ctx, struct htp_mmap *m) { if (m->size && !m->pinned) { FARF(HIGH, "unmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned); +#if __HVX_ARCH__ > 73 HAP_munmap2((void *) m->base, m->size); +#else + HAP_munmap((void *) m->base, m->size); +#endif m->size = 0; m->base = 0; m->fd = -1; @@ -540,7 +560,16 @@ static inline void mmap_buf(struct htp_context *ctx, struct htp_buf_desc *b) { for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) { struct htp_mmap *m = &ctx->mmap[i]; if (!m->size) { +#if __HVX_ARCH__ > 73 void *va = HAP_mmap2(NULL, b->size, HAP_PROT_READ | HAP_PROT_WRITE, 0, b->fd, 0); +#else + if (b->size > HTP_MMAP_MAX_VMEM) { // HAP_mmap has a size limit of 2GB + FARF(ERROR, "mmap failed : size %u exceeds 2GB limit for HAP_mmap", (uint32_t) b->size); + abort(); // can't do much else at this point + } + + void *va = HAP_mmap(NULL, b->size, HAP_PROT_READ | HAP_PROT_WRITE, 0, b->fd, 0); +#endif if (va == (void*)-1) { FARF(ERROR, "mmap failed : va %p fd %u size %u", va, b->fd, (uint32_t) b->size); abort(); // can't do much else at this point From 77c0630ce64e63a63f16e30d9982608b5f6474fa Mon Sep 17 00:00:00 2001 From: lhez Date: Thu, 16 Apr 2026 22:28:33 -0700 Subject: [PATCH 450/831] opencl: refactor q8_0 set_tensor and mul_mat host side dispatch for Adreno (llama/21938) * opencl: refactor q8_0 gemm/gemv Adreno dispatch * opencl: refactor q8_0 set_tensor * opencl: fix whitespace --- ggml/src/ggml-opencl/ggml-opencl.cpp | 361 ++++++++------------------- 1 file changed, 99 insertions(+), 262 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index b27fbb13a3a..8bc7ae65a6d 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -5116,115 +5116,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[3] == 1); - // Transpose weights - size_t q_size_bytes = K * M / 4 * sizeof(float); - cl_buffer_region region; - region.origin = 0; - region.size = q_size_bytes; - cl_mem qT_d = clCreateSubBuffer( - backend_ctx->prealloc_quant_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &err); - CL_CHECK(err); - - cl_mem q_d_image1D; - cl_mem qT_d_image1D; - - cl_image_format img_fmt_1d; - cl_image_desc img_desc_1d; - - img_fmt_1d = { CL_RGBA, CL_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4 / 4; - img_desc_1d.buffer = extra->q; - q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - img_fmt_1d = { CL_RGBA, CL_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4 / 4; - img_desc_1d.buffer = qT_d; - qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - int height_q = M / 4; - int width_q = K / 4 / 4; - kernel = backend_ctx->kernel_transpose_32; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_q)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_q)); - - size_t local_size_q[3] = {4, 16, 1}; - size_t global_size_q[3] = {static_cast(width_q), static_cast(height_q), 1}; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - - // Transpose scales - size_t d_size_bytes = M * (K / 32) * 2; - region.origin = 0; - region.size = d_size_bytes; - cl_mem dT_d = clCreateSubBuffer( - backend_ctx->prealloc_scales_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &err); - CL_CHECK(err); - - cl_mem d_d_image1D; - cl_mem dT_d_image1D; - - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_fmt_1d = { CL_R, CL_HALF_FLOAT }; - img_desc_1d.image_width = M * K / 32; - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.buffer = extra->d; - d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32 / 4; - img_desc_1d.buffer = dT_d; - dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - int height_s = M / 4; - int width_s = K / 32; - - kernel = backend_ctx->kernel_transpose_16_4x1; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s)); - - size_t local_size_s[3] = {4, 16, 1}; - size_t global_size_s[3] = {static_cast(width_s), static_cast(height_s), 1}; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - - // copy transposed buffer contents to original buffers - CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - - CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - - CL_CHECK(clReleaseMemObject(qT_d)); - CL_CHECK(clReleaseMemObject(dT_d)); - - CL_CHECK(clReleaseMemObject(q_d_image1D)); - CL_CHECK(clReleaseMemObject(d_d_image1D)); - CL_CHECK(clReleaseMemObject(qT_d_image1D)); - CL_CHECK(clReleaseMemObject(dT_d_image1D)); + transpose_2d_as_32b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); } // end transpose #endif // GGML_OPENCL_USE_ADRENO_KERNELS @@ -9956,19 +9849,18 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const enum ggml_type src0t = src0->type; - const enum ggml_type src1t = src1->type; - - GGML_ASSERT(src0t == GGML_TYPE_Q8_0); - GGML_ASSERT(src1t == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_Q8_0); + GGML_ASSERT(src1->type == GGML_TYPE_F32); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + GGML_ASSERT(src1->view_offs == 0); GGML_ASSERT(dst->view_offs == 0); @@ -9989,148 +9881,112 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t cl_context context = backend_ctx->context; cl_kernel kernel; - // init CL objects - cl_int status; - cl_image_format img_fmt_1d; - cl_image_desc img_desc_1d; + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; cl_buffer_region region; - cl_mem A_image1d; - cl_mem B_image1d; - cl_mem B_sub_buffer; - cl_mem S_image1d; - // for B transpose - cl_mem B_image1d_trans = nullptr; - cl_mem B_d = nullptr; - - cl_mem D_image1d; - cl_mem D_sub_buffer; int M = ne01; int N = ne1; int K = ne00; - // create an image for A - img_fmt_1d = { CL_R, CL_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4; // Divide by 4 for char -> float - img_desc_1d.buffer = extra0_q8_0->q; - A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); - - // create an image for Scale - img_fmt_1d = { CL_R, CL_HALF_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32; // Block size is 32 - img_desc_1d.buffer = extra0_q8_0->d; - S_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); - - // create a sub_buffer for B - region.origin = (extra1->offset); // + src1->view_offs); - region.size = K * N * sizeof(float); - B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - - // create an image for B from sub_buffer: RGBA (OCL) - img_fmt_1d = {CL_RGBA, CL_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = K * N / 4; - img_desc_1d.buffer = B_sub_buffer; - B_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; - // Create subbuffer and image1d_buffer for dst - region.origin = (extrad->offset); // + dst->view_offs; - region.size = M * N * sizeof(float); - D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 4; + img_desc.buffer = extra0_q8_0->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - img_fmt_1d = {CL_R, CL_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * N; - img_desc_1d.buffer = D_sub_buffer; - D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); + // create a sub_buffer for B + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); - size_t local_work_size[3] = {1, 1, 1}; - size_t global_work_size[3] = {1, 1, 1}; + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - if (N == 1) { kernel = backend_ctx->CL_mul_mat_vec_q8_0_f32; int r2 = 1; int r3 = 1; - cl_uint k_arg = 0; - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extra0_q8_0->d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extra1->offset)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extrad->offset)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne1)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r2)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r3)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &extra1->offset)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &extrad->offset)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); size_t wavesize = backend_ctx->adreno_wave_size; - local_work_size[0] = wavesize; - local_work_size[1] = 4; // reduce factor - local_work_size[2] = 1; + size_t local_work_size[] = { wavesize, 4, 1 }; + size_t global_work_size[] = { CEIL_DIV(M, wavesize)*wavesize, 4, 1 }; - global_work_size[0] = ((M + wavesize - 1) / wavesize) * wavesize; - global_work_size[1] = 4; // reduce factor - global_work_size[2] = 1; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); } else { - cl_ulong offsetd = extrad->offset + dst->view_offs; - int padding; + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; - //how many extra elements beyond multiple of 8 - int extra_elements = N % 8; + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - //how much padding to add - padding = 0; + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; if (extra_elements > 0){ padding = 8 - extra_elements; } - // Specify the starting offset (in bytes) + // subbuffer for transposed activations region.origin = 0; - // Specify the size of the sub-buffer (divide by 2 for FP16) region.size = K * (N + padding) * sizeof(float)/2; backend_ctx->prealloc_act_trans.allocate(context, region.size); - B_d = clCreateSubBuffer( - backend_ctx->prealloc_act_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &status); - CL_CHECK(status); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); - cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16) - cl_image_desc image_desc_B_d_output = { - CL_MEM_OBJECT_IMAGE1D_BUFFER, - static_cast(K * (N + padding)/4), - 0, 0, 0, 0, 0, 0, 0, { B_d } - }; - B_image1d_trans = clCreateImage( - context, - 0, - &image_format_B_d_output, - &image_desc_B_d_output, - NULL, - &status); - CL_CHECK(status); + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + // transpose activations int height_B = N/4; if (height_B == 0) { height_B = 1; @@ -10139,58 +9995,39 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t int padded_height_B = (N + padding)/4; kernel = backend_ctx->kernel_transpose_32_16; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_image1d)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d_trans)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); - size_t local_size_t[2] = { 1, 16 }; - size_t global_size_t[2] = { - static_cast(width_B), - static_cast(padded_height_B) - }; - - backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst); + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + // gemm kernel = backend_ctx->kernel_mul_mm_q8_0_f32_8x4; - - int N_with_padding = N + padding; + int padded_N = N + padding; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &K)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &M)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &N_with_padding)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &padded_N)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &N)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd)); - global_work_size[0] = (size_t)(N + 7) / 8; - global_work_size[1] = (size_t)(M + 3) / 4; - global_work_size[2] = 1; - - local_work_size[0] = 2; - local_work_size[1] = 128; - local_work_size[2] = 1; - } + size_t global_work_size[] = { (size_t)CEIL_DIV(N, 8), (size_t)CEIL_DIV(M, 4), 1 }; + size_t local_work_size[] = { 2, 128, 1 }; - // enqueue kernel with profiling - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - // deallocate sub buffers and images - CL_CHECK(clReleaseMemObject(A_image1d)); - CL_CHECK(clReleaseMemObject(B_sub_buffer)); - CL_CHECK(clReleaseMemObject(B_image1d)); - CL_CHECK(clReleaseMemObject(S_image1d)); - CL_CHECK(clReleaseMemObject(D_sub_buffer)); - CL_CHECK(clReleaseMemObject(D_image1d)); - if (B_image1d_trans) { - CL_CHECK(clReleaseMemObject(B_image1d_trans)); - } - if (B_d) { - CL_CHECK(clReleaseMemObject(B_d)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); } #else GGML_UNUSED(backend); From 918e0ad20954beaf4a57675749e0a54f12f4233b Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 17 Apr 2026 23:24:21 +0800 Subject: [PATCH 451/831] CUDA: use LRU based eviction for cuda graphs (llama/21611) * CUDA: use a ring-buffer for cuda graphs * bump limit to 128 * use LRU eviction * better naming * do periodic clean-up --- ggml/src/ggml-cuda/common.cuh | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 66ed02d2923..ddf50baf495 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1187,6 +1187,7 @@ struct ggml_cuda_graph { bool disable_due_to_gpu_arch = false; bool warmup_complete = false; uint64_t uid = 0; + int64_t last_used_time = 0; struct node_properties { ggml_tensor node; void * node_src_data_ptrs[GGML_MAX_SRC]; @@ -1368,12 +1369,28 @@ struct ggml_backend_cuda_context { // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe) std::unordered_map> cuda_graphs; + int64_t last_graph_eviction_sweep = 0; + ggml_cuda_graph * cuda_graph(const void * first_node_ptr) { + const int64_t time_now = ggml_time_us(); + + // sweep every 5s, evicting cuda graphs unused for >=10s + if (time_now - last_graph_eviction_sweep >= 5'000'000) { + last_graph_eviction_sweep = time_now; + for (auto it = cuda_graphs.begin(); it != cuda_graphs.end(); ) { + if (time_now - it->second->last_used_time >= 10'000'000) { + it = cuda_graphs.erase(it); + } else { + ++it; + } + } + } + auto it = cuda_graphs.find(first_node_ptr); if (it == cuda_graphs.end()) { - cuda_graphs[first_node_ptr] = std::make_unique(); - return cuda_graphs[first_node_ptr].get(); + it = cuda_graphs.emplace(first_node_ptr, std::make_unique()).first; } + it->second->last_used_time = time_now; return it->second.get(); } From cbbe935765b0a4d5f301ff8e8f4636f3b6e94c98 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 17 Apr 2026 09:17:11 -0700 Subject: [PATCH 452/831] ggml-webgpu: fix compiler warnings and refactor FlashAttention encoding (llama/21052) * Update workflows to remove dependence on llvmpipe * Try setting Dawn_DIR * remove c++20 initializers * Move to proper guid * Try avoiding segfaults on vulkan backend process exit * Remove compiler warnings on parameter casting * Fix soft_max and update reg_tile accumulation to f32 for better precision * Refactor flash_attn a bit * remove c++20 initializers and format * Increase div precision for NVIDIA * revert div precision and comment out ggml-ci node for now * Formatting * Try debugging on a failing CI node * Revert "Try debugging on a failing CI node" This reverts commit 1971e33cba919915e12bcfd5828abfbd54ca942e. --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 585 ++++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 1498 +++++++---------- .../wgsl-shaders/flash_attn_vec_blk.wgsl | 12 +- 3 files changed, 918 insertions(+), 1177 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 3de6258c74d..7d9a4403fab 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -390,12 +390,11 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_mask; bool has_sinks; bool uses_logit_softcap; - bool use_vec; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec; + uses_logit_softcap == other.uses_logit_softcap; } }; @@ -409,47 +408,37 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.has_mask); ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); - ggml_webgpu_hash_combine(seed, key.use_vec); return seed; } }; -struct ggml_webgpu_flash_attn_shader_lib_context { - ggml_webgpu_flash_attn_pipeline_key key; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; - size_t wg_mem_limit_bytes; - uint32_t max_subgroup_size; +struct ggml_webgpu_flash_attn_decisions { + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; }; -struct ggml_webgpu_flash_attn_shader_decisions { - uint32_t q_tile = 0; +struct ggml_webgpu_flash_attn_vec_decisions { uint32_t kv_tile = 0; uint32_t wg_size = 0; }; -inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { - // Keep conservative defaults unless this is the f16 vec-split shape family. - if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) { - return 1u; - } - - // Head-dim specializations used by the tuned vec f16 path. - switch (key.head_dim_qk) { - case 64: - return 2u; - case 96: - return 4u; - case 128: - return 1u; - case 192: - return 2u; - case 576: - return 2u; - default: - return 1u; - } +inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( + const ggml_webgpu_shader_lib_context & context) { + const bool has_mask = context.src3 != nullptr; + const bool has_sinks = context.src4 != nullptr; + const bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) && + (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + + ggml_webgpu_flash_attn_pipeline_key key = {}; + key.kv_type = context.src1->type; + key.head_dim_qk = (uint32_t) context.src0->ne[0]; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.kv_direct = kv_direct; + key.has_mask = has_mask; + key.has_sinks = has_sinks; + key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; + return key; } struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { @@ -471,79 +460,20 @@ inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lh return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size; } -struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context { - ggml_webgpu_flash_attn_vec_reduce_pipeline_key key; - uint32_t max_wg_size; -}; - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { - std::vector defines; - std::string variant = "flash_attn_vec_reduce"; - - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); - - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - variant += std::string("_wg") + std::to_string(context.max_wg_size); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - return result; -} - struct ggml_webgpu_flash_attn_blk_pipeline_key { - uint32_t q_tile; uint32_t kv_tile; - bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { - return q_tile == other.q_tile && kv_tile == other.kv_tile; - } + bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { return kv_tile == other.kv_tile; } }; struct ggml_webgpu_flash_attn_blk_pipeline_key_hash { size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const { size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.q_tile); ggml_webgpu_hash_combine(seed, key.kv_tile); return seed; } }; -struct ggml_webgpu_flash_attn_blk_shader_lib_context { - ggml_webgpu_flash_attn_blk_pipeline_key key; - uint32_t max_wg_size; -}; - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { - std::vector defines; - std::string variant = "flash_attn_vec_blk"; - - defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile)); - variant += std::string("_qt") + std::to_string(context.key.q_tile); - - defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile)); - variant += std::string("_kvt") + std::to_string(context.key.kv_tile); - - uint32_t wg_size = 1; - while ((wg_size << 1) <= context.max_wg_size) { - wg_size <<= 1; - } - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - variant += std::string("_wg") + std::to_string(wg_size); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - return result; -} - // This is exposed because it's necessary in supports_op inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t kv_tile, @@ -568,6 +498,41 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } +inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context, + const ggml_webgpu_flash_attn_pipeline_key & key) { + const size_t limit_bytes = context.wg_mem_limit_bytes; + const size_t q_tile = context.sg_mat_m; + const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!key.kv_direct) { + bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); + } + if (key.has_mask) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; +} + +inline uint32_t ggml_webgpu_flash_attn_vec_get_kv_tile(const ggml_webgpu_shader_lib_context & context) { + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); + const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); + uint32_t kv_tile = std::max(context.sg_mat_n, std::min(32u, min_kv_tile)); + kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; + + if (key.kv_direct) { + kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= context.sg_mat_n; + } + } + + return kv_tile; +} + /** Matrix Multiplication **/ struct ggml_webgpu_legacy_mul_mat_pipeline_key { @@ -802,6 +767,8 @@ class ggml_webgpu_shader_lib { repeat_pipelines; // type std::unordered_map flash_attn_pipelines; + std::unordered_map + flash_attn_vec_pipelines; std::unordered_map @@ -849,10 +816,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_row_norm_pipeline_key key = { - .op = context.dst->op, - .inplace = context.inplace, - }; + ggml_webgpu_row_norm_pipeline_key key = {}; + key.op = context.dst->op; + key.inplace = context.inplace; auto it = row_norm_pipelines.find(key); if (it != row_norm_pipelines.end()) { @@ -908,9 +874,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type, - .vec4 = context.src0->ne[0] % 4 == 0, - .i64_idx = context.src1->type == GGML_TYPE_I64 }; + ggml_webgpu_set_rows_pipeline_key key = {}; + key.dst_type = context.dst->type; + key.vec4 = context.src0->ne[0] % 4 == 0; + key.i64_idx = context.src1->type == GGML_TYPE_I64; auto it = set_rows_pipelines.find(key); if (it != set_rows_pipelines.end()) { @@ -955,7 +922,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace }; + ggml_webgpu_set_pipeline_key key = {}; + key.type = context.dst->type; + key.inplace = context.inplace; auto it = set_pipelines.find(key); if (it != set_pipelines.end()) { @@ -1062,10 +1031,9 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0; - ggml_webgpu_get_rows_pipeline_key key = { - .src_type = context.src0->type, - .vectorized = (int) vectorized, - }; + ggml_webgpu_get_rows_pipeline_key key = {}; + key.src_type = context.src0->type; + key.vectorized = (int) vectorized; auto it = get_rows_pipelines.find(key); if (it != get_rows_pipelines.end()) { @@ -1115,8 +1083,7 @@ class ggml_webgpu_shader_lib { std::string type_upper = type_str; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - switch (key.src_type) - { + switch (key.src_type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: @@ -1136,9 +1103,9 @@ class ggml_webgpu_shader_lib { break; } default: - { - defines.push_back(std::string("SRC_TYPE=") + type_str); - } + { + defines.push_back(std::string("SRC_TYPE=") + type_str); + } } defines.push_back("BYTE_HELPERS"); @@ -1181,7 +1148,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace }; + ggml_webgpu_scale_pipeline_key key = {}; + key.inplace = context.inplace; auto it = scale_pipelines.find(key); if (it != scale_pipelines.end()) { @@ -1208,11 +1176,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_solve_tri_pipeline_key key = { - .type = context.dst->type, - .n = (int) context.src0->ne[0], - .k = (int) context.src1->ne[0], - }; + ggml_webgpu_solve_tri_pipeline_key key = {}; + key.type = context.dst->type; + key.n = (int) context.src0->ne[0]; + key.k = (int) context.src1->ne[0]; auto it = solve_tri_pipelines.find(key); if (it != solve_tri_pipelines.end()) { @@ -1250,10 +1217,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_ssm_conv_pipeline_key key = { - .type = context.dst->type, - .vectorized = context.src1->ne[0] == 4, - }; + ggml_webgpu_ssm_conv_pipeline_key key = {}; + key.type = context.dst->type; + key.vectorized = context.src1->ne[0] == 4; auto it = ssm_conv_pipelines.find(key); if (it != ssm_conv_pipelines.end()) { @@ -1293,11 +1259,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_gated_delta_net_pipeline_key key = { - .type = context.dst->type, - .s_v = (int) context.src2->ne[0], - .kda = context.src3->ne[0] == context.src2->ne[0], - }; + ggml_webgpu_gated_delta_net_pipeline_key key = {}; + key.type = context.dst->type; + key.s_v = (int) context.src2->ne[0]; + key.kda = context.src3->ne[0] == context.src2->ne[0]; auto it = gated_delta_net_pipelines.find(key); if (it != gated_delta_net_pipelines.end()) { @@ -1330,7 +1295,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 }; + ggml_webgpu_pad_pipeline_key key = {}; + key.circular = ggml_get_op_params_i32(context.dst, 8) != 0; auto it = pad_pipelines.find(key); if (it != pad_pipelines.end()) { @@ -1357,15 +1323,13 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_mul_mat_vec_pipeline_key key = { - .src0_type = context.src0->type, - .src1_type = context.src1->type, - // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float - .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0, - }; + ggml_webgpu_mul_mat_vec_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { @@ -1451,15 +1415,14 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_mul_mat_pipeline_key key = { - .src0_type = context.src0->type, - .src1_type = context.src1->type, - .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0, - .use_subgroup_matrix = context.supports_subgroup_matrix - }; + ggml_webgpu_mul_mat_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + key.use_subgroup_matrix = context.supports_subgroup_matrix; auto it = mul_mat_fast_pipelines.find(key); if (it != mul_mat_fast_pipelines.end()) { @@ -1578,8 +1541,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type, - .src1_type = context.src1->type }; + ggml_webgpu_legacy_mul_mat_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; auto it = mul_mat_legacy_pipelines.find(key); if (it != mul_mat_legacy_pipelines.end()) { @@ -1621,8 +1585,7 @@ class ggml_webgpu_shader_lib { std::string type_upper = src0_name; std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - switch (context.src0->type) - { + switch (context.src0->type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: @@ -1642,9 +1605,9 @@ class ggml_webgpu_shader_lib { break; } default: - { - defines.push_back(std::string("SRC0_TYPE=") + src0_name); - } + { + defines.push_back(std::string("SRC0_TYPE=") + src0_name); + } } defines.push_back("BYTE_HELPERS"); @@ -1689,10 +1652,9 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_mul_mat_id_pipeline_key key = { - .src0_type = context.src0->type, - .src1_type = context.src1->type, - }; + ggml_webgpu_mul_mat_id_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; auto it = mul_mat_id_pipelines.find(key); if (it != mul_mat_id_pipelines.end()) { @@ -1782,13 +1744,12 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool is_unary = context.dst->op == GGML_OP_UNARY; const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; - ggml_webgpu_unary_pipeline_key key = { - .type = context.dst->type, - .op = op, - .is_unary = is_unary, - .inplace = context.inplace, - .ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0), - }; + ggml_webgpu_unary_pipeline_key key = {}; + key.type = context.dst->type; + key.op = op; + key.is_unary = is_unary; + key.inplace = context.inplace; + key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0); auto it = unary_pipelines.find(key); if (it != unary_pipelines.end()) { @@ -1853,13 +1814,12 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_binary_pipeline_key key = { - .type = context.dst->type, - .op = context.dst->op, - .inplace = context.inplace, - .overlap = context.overlap, - .src_overlap = context.src_overlap, - }; + ggml_webgpu_binary_pipeline_key key = {}; + key.type = context.dst->type; + key.op = context.dst->op; + key.inplace = context.inplace; + key.overlap = context.overlap; + key.src_overlap = context.src_overlap; auto it = binary_pipelines.find(key); if (it != binary_pipelines.end()) { @@ -1908,9 +1868,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_concat_pipeline_key key = { - .type = context.dst->type, - }; + ggml_webgpu_concat_pipeline_key key = {}; + key.type = context.dst->type; auto it = concat_pipelines.find(key); if (it != concat_pipelines.end()) { @@ -1945,9 +1904,8 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_repeat_pipeline_key key = { - .type = context.dst->type, - }; + ggml_webgpu_repeat_pipeline_key key = {}; + key.type = context.dst->type; auto it = repeat_pipelines.find(key); if (it != repeat_pipelines.end()) { @@ -1985,16 +1943,16 @@ class ggml_webgpu_shader_lib { return repeat_pipelines[key]; } - webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) { - auto it = flash_attn_pipelines.find(context.key); + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); + auto it = flash_attn_pipelines.find(key); if (it != flash_attn_pipelines.end()) { return it->second; } - std::vector defines; std::string variant = "flash_attn"; - switch (context.key.kv_type) { + switch (key.kv_type) { case GGML_TYPE_F32: defines.push_back("KV_F32"); break; @@ -2010,111 +1968,206 @@ class ggml_webgpu_shader_lib { default: GGML_ABORT("Unsupported KV type for flash attention shader"); } - variant += std::string("_") + ggml_type_name(context.key.kv_type); + variant += std::string("_") + ggml_type_name(key.kv_type); - if (context.key.has_mask) { + if (key.has_mask) { defines.push_back("MASK"); variant += "_mask"; } - if (context.key.has_sinks) { + if (key.has_sinks) { defines.push_back("SINKS"); variant += "_sinks"; } - if (context.key.uses_logit_softcap) { + if (key.uses_logit_softcap) { defines.push_back("LOGIT_SOFTCAP"); variant += "_lgsc"; } - if (context.key.kv_direct) { + if (key.kv_direct) { defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } - if (context.key.has_mask && context.key.use_vec) { - defines.push_back("BLK"); - variant += "_blk"; - } - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk); + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (context.key.use_vec) { - q_tile = 1; - kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context))); - kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; - const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key); - defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); - } - if (context.key.kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); + auto decisions = std::make_shared(); + decisions->q_tile = context.sg_mat_m; + + const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); + uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + + if (key.kv_direct) { + kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { kv_tile -= context.sg_mat_n; } } - defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + decisions->kv_tile = kv_tile; + decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - uint32_t wg_size = 0; - if (context.key.use_vec) { - wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); - } else { - wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size)); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant); + pipeline.context = decisions; + flash_attn_pipelines[key] = pipeline; + return flash_attn_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); + auto it = flash_attn_vec_pipelines.find(key); + if (it != flash_attn_vec_pipelines.end()) { + return it->second; } - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn; + std::vector defines; + std::string variant = "flash_attn_vec"; + + switch (key.kv_type) { + case GGML_TYPE_F32: + defines.push_back("KV_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("KV_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("KV_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("KV_Q8_0"); + break; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); + } + variant += std::string("_") + ggml_type_name(key.kv_type); + + if (key.has_mask) { + defines.push_back("MASK"); + defines.push_back("BLK"); + variant += "_mask_blk"; + } + if (key.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (key.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + if (key.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + defines.push_back("Q_TILE=1"); + + auto decisions = std::make_shared(); + decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context); + decisions->wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + uint32_t vec_ne = 1u; + + // Keep conservative defaults unless this is the f16 vec-split shape family. + if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) { + switch (key.head_dim_qk) { + case 64: + case 192: + case 576: + vec_ne = 2u; + break; + case 96: + vec_ne = 4u; + break; + default: + break; + } + } + + defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size)); + defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); + webgpu_pipeline pipeline = - ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); - auto decisions = std::make_shared(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; - pipeline.context = decisions; - flash_attn_pipelines[context.key] = pipeline; - return flash_attn_pipelines[context.key]; - } - - webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) { - auto it = flash_attn_blk_pipelines.find(context.key); + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant); + pipeline.context = decisions; + flash_attn_vec_pipelines[key] = pipeline; + return flash_attn_vec_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_blk_pipeline_key key = {}; + key.kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context); + auto it = flash_attn_blk_pipelines.find(key); if (it != flash_attn_blk_pipelines.end()) { return it->second; } - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context); - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); - flash_attn_blk_pipelines[context.key] = pipeline; - return flash_attn_blk_pipelines[context.key]; + std::vector defines; + std::string variant = "flash_attn_vec_blk"; + + defines.push_back(std::string("KV_TILE=") + std::to_string(key.kv_tile)); + variant += std::string("_kvt") + std::to_string(key.kv_tile); + + uint32_t wg_size = 1; + while ((wg_size << 1) <= context.max_wg_size) { + wg_size <<= 1; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + variant += std::string("_wg") + std::to_string(wg_size); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_blk, defines), variant); + flash_attn_blk_pipelines[key] = pipeline; + return flash_attn_blk_pipelines[key]; } - webgpu_pipeline get_flash_attn_vec_reduce_pipeline( - const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) { - auto it = flash_attn_vec_reduce_pipelines.find(context.key); + webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {}; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.wg_size = context.max_wg_size; + auto it = flash_attn_vec_reduce_pipelines.find(key); if (it != flash_attn_vec_reduce_pipelines.end()) { return it->second; } - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context); - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant); - flash_attn_vec_reduce_pipelines[context.key] = pipeline; - return flash_attn_vec_reduce_pipelines[context.key]; + std::vector defines; + std::string variant = "flash_attn_vec_reduce"; + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + variant += std::string("_wg") + std::to_string(context.max_wg_size); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_reduce, defines), variant); + flash_attn_vec_reduce_pipelines[key] = pipeline; + return flash_attn_vec_reduce_pipelines[key]; } webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_cpy_pipeline_key key = { - .src_type = context.src0->type, - .dst_type = context.dst->type, - }; + ggml_webgpu_cpy_pipeline_key key = {}; + key.src_type = context.src0->type; + key.dst_type = context.dst->type; auto it = cpy_pipelines.find(key); if (it != cpy_pipelines.end()) { @@ -2166,11 +2219,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_glu_pipeline_key key = { - .glu_op = ggml_get_glu_op(context.dst), - .type = context.dst->type, - .split = (context.src1 != nullptr), - }; + ggml_webgpu_glu_pipeline_key key = {}; + key.glu_op = ggml_get_glu_op(context.dst); + key.type = context.dst->type; + key.split = (context.src1 != nullptr); auto it = glu_pipelines.find(key); if (it != glu_pipelines.end()) { @@ -2239,11 +2291,10 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_rope_pipeline_key key = { - .type = context.dst->type, - .inplace = context.inplace, - .has_ff = (context.src2 != nullptr), - }; + ggml_webgpu_rope_pipeline_key key = {}; + key.type = context.dst->type; + key.inplace = context.inplace; + key.has_ff = (context.src2 != nullptr); auto it = rope_pipelines.find(key); if (it != rope_pipelines.end()) { @@ -2288,12 +2339,11 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_soft_max_pipeline_key key = { - .mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32, - .has_mask = (context.src1 != nullptr), - .has_sink = (context.src2 != nullptr), - .inplace = context.inplace, - }; + ggml_webgpu_soft_max_pipeline_key key = {}; + key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32; + key.has_mask = (context.src1 != nullptr); + key.has_sink = (context.src2 != nullptr); + key.inplace = context.inplace; auto it = soft_max_pipelines.find(key); if (it != soft_max_pipelines.end()) { @@ -2359,25 +2409,6 @@ class ggml_webgpu_shader_lib { pipeline_desc.layout = nullptr; // nullptr means auto layout return { device.CreateComputePipeline(&pipeline_desc), label }; } - - static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t q_tile = context.sg_mat_m; - const size_t base_q_bytes = - (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!context.key.kv_direct) { - bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); - } - if (context.key.has_mask) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; - } }; #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 01637e2ddab..e7bda817a28 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -41,6 +41,12 @@ static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim wg_x = CEIL_DIV(total_wg, wg_y); } +static inline uint32_t ggml_webgpu_u32_from_f32(float value) { + uint32_t bits; + memcpy(&bits, &value, sizeof(bits)); + return bits; +} + #ifdef GGML_WEBGPU_DEBUG # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl # define WEBGPU_DEBUG_BUF_ELEMS 512 @@ -369,6 +375,96 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, buffer = device.CreateBuffer(&buffer_desc); } +static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { + return webgpu_tensor_offset(tensor) + tensor->view_offs; +} + +static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { + ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + return ctx->buffer; +} + +static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); +} + +static bool ggml_webgpu_flash_attn_use_vec(webgpu_global_context & global_ctx, + const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V) { + const size_t alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const uint32_t k_offset_elems = + (uint32_t) ((ggml_webgpu_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); + const uint32_t v_offset_elems = + (uint32_t) ((ggml_webgpu_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); + const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); + const bool kv_vec_type_supported = + K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + + return (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && + (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); +} + +static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); +} + +static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { + return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); +} + +// Used to determine if two tensors are the same for in-place operations +static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { + return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && + (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); +} + +// Used to determine if two tensors share the same buffer and their byte ranges overlap, +static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { + return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && + ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && + ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); +} + +struct binary_overlap_flags { + bool inplace; // src0 == dst + bool overlap; // src1 == dst + bool src_overlap; +}; + +static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + binary_overlap_flags flags = {}; + flags.inplace = ggml_webgpu_tensor_equal(src0, dst); + flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); + flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); + + return flags; +} + +static wgpu::BindGroupEntry ggml_webgpu_make_bind_group_entry(uint32_t binding, + wgpu::Buffer buffer, + uint64_t offset, + uint64_t size) { + wgpu::BindGroupEntry entry = {}; + entry.binding = binding; + entry.buffer = std::move(buffer); + entry.offset = offset; + entry.size = size; + return entry; +} + +static wgpu::BindGroupEntry ggml_webgpu_make_tensor_bind_group_entry(webgpu_context & ctx, + uint32_t binding, + ggml_tensor * tensor) { + return ggml_webgpu_make_bind_group_entry(binding, ggml_webgpu_tensor_buf(tensor), + ggml_webgpu_tensor_align_offset(ctx, tensor), + ggml_webgpu_tensor_binding_size(ctx, tensor)); +} + /** End WebGPU object initializations */ /** WebGPU Actions */ @@ -480,10 +576,8 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & std::vector entries = dispatch.bind_group_entries; uint32_t params_binding_num = entries.size(); - entries.push_back({ .binding = params_binding_num, - .buffer = ctx->param_arena.buffer, - .offset = param_offset, - .size = ctx->param_arena.slot_size }); + entries.push_back(ggml_webgpu_make_bind_group_entry(params_binding_num, ctx->param_arena.buffer, param_offset, + ctx->param_arena.slot_size)); wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = dispatch.pipeline.pipeline.GetBindGroupLayout(0); @@ -502,13 +596,17 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & #ifdef GGML_WEBGPU_GPU_PROFILE for (size_t i = 0; i < dispatches.size(); i++) { GGML_ASSERT(ctx->profile_timestamp_query_count + 2 <= WEBGPU_MAX_PROFILE_QUERY_COUNT); - const uint32_t query_begin = ctx->profile_timestamp_query_count++; - const uint32_t query_end = ctx->profile_timestamp_query_count++; - wgpu::PassTimestampWrites ts_writes = { .querySet = ctx->profile_timestamp_query_set, - .beginningOfPassWriteIndex = query_begin, - .endOfPassWriteIndex = query_end }; - wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; - wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); + const uint32_t query_begin = ctx->profile_timestamp_query_count++; + const uint32_t query_end = ctx->profile_timestamp_query_count++; + + wgpu::PassTimestampWrites ts_writes = {}; + ts_writes.querySet = ctx->profile_timestamp_query_set; + ts_writes.beginningOfPassWriteIndex = query_begin; + ts_writes.endOfPassWriteIndex = query_end; + wgpu::ComputePassDescriptor pass_desc = {}; + pass_desc.timestampWrites = &ts_writes; + + wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); pass.SetPipeline(dispatches[i].pipeline.pipeline); pass.SetBindGroup(0, bind_groups[i]); @@ -544,17 +642,19 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, uint32_t value, size_t offset, size_t size) { - std::vector params = { (uint32_t) offset, (uint32_t) size, value }; - std::vector entries = { - { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() } - }; - size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; - uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); + std::vector params = { (uint32_t) offset, (uint32_t) size, value }; + std::vector entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; + size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; + uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); - entries.push_back( - { .binding = 1, .buffer = ctx->memset_params_buf, .offset = 0, .size = WEBGPU_PARAMS_BUF_SIZE_BYTES }); + wgpu::BindGroupEntry params_entry = {}; + params_entry.binding = 1; + params_entry.buffer = ctx->memset_params_buf; + params_entry.offset = 0; + params_entry.size = WEBGPU_PARAMS_BUF_SIZE_BYTES; + entries.push_back(params_entry); wgpu::BindGroupDescriptor bind_group_desc; bind_group_desc.layout = ctx->memset_pipeline.pipeline.GetBindGroupLayout(0); @@ -632,65 +732,11 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { delete backend; } -static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { - return webgpu_tensor_offset(tensor) + tensor->view_offs; -} - -static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { - ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; - return ctx->buffer; -} - -static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { - size_t offset = ggml_webgpu_tensor_offset(t); - return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); -} - -static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { - size_t offset = ggml_webgpu_tensor_offset(t); - return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); -} - -static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { - return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); -} - -// Used to determine if two tensors are the same for in-place operations -static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); -} - -// Used to determine if two tensors share the same buffer and their byte ranges overlap, -static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && - ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); -} - -struct binary_overlap_flags { - bool inplace; // src0 == dst - bool overlap; // src1 == dst - bool src_overlap; -}; - -static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { - binary_overlap_flags flags = {}; - flags.inplace = ggml_webgpu_tensor_equal(src0, dst); - flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); - flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); - - return flags; -} - static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx); @@ -712,14 +758,8 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -732,13 +772,12 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * dst) { const bool inplace = ggml_webgpu_tensor_equal(src0, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx); @@ -772,29 +811,21 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, std::vector entries; uint32_t binding_index = 0; if (!inplace) { - entries.push_back({ .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); binding_index++; } - entries.push_back({ .binding = binding_index, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); - entries.push_back({ .binding = binding_index + 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index + 1, dst)); uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx); @@ -832,14 +863,8 @@ static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -850,13 +875,12 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx); @@ -888,18 +912,9 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); @@ -911,12 +926,11 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -944,18 +958,9 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); @@ -971,15 +976,14 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, ggml_tensor * src4, ggml_tensor * src5, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .src3 = src3, - .src4 = src4, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.src3 = src3; + shader_lib_ctx.src4 = src4; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx); @@ -1015,34 +1019,10 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(src3), - .offset = ggml_webgpu_tensor_align_offset(ctx, src3), - .size = ggml_webgpu_tensor_binding_size(ctx, src3) }, - { .binding = 4, - .buffer = ggml_webgpu_tensor_buf(src4), - .offset = ggml_webgpu_tensor_align_offset(ctx, src4), - .size = ggml_webgpu_tensor_binding_size(ctx, src4) }, - { .binding = 5, - .buffer = ggml_webgpu_tensor_buf(src5), - .offset = ggml_webgpu_tensor_align_offset(ctx, src5), - .size = ggml_webgpu_tensor_binding_size(ctx, src5) }, - { .binding = 6, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, dst), }; return ggml_backend_webgpu_build(ctx, pipeline, params, entries, h, n_seqs); @@ -1058,12 +1038,11 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ct return std::nullopt; } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = idx, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = idx; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx); @@ -1086,25 +1065,14 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ct }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(idx), - .offset = ggml_webgpu_tensor_align_offset(ctx, idx), - .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, idx), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; if (decisions->i64_idx) { - entries.push_back({ .binding = 3, - .buffer = ctx->set_rows_dev_error_buf, - .offset = 0, - .size = ctx->set_rows_dev_error_buf.GetSize() }); + entries.push_back(ggml_webgpu_make_bind_group_entry(3, ctx->set_rows_dev_error_buf, 0, + ctx->set_rows_dev_error_buf.GetSize())); } uint32_t threads; @@ -1131,12 +1099,11 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * dst) { const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = WEBGPU_MAX_WG_SIZE, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = WEBGPU_MAX_WG_SIZE; webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -1160,20 +1127,9 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2]) }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(idx), - .offset = ggml_webgpu_tensor_align_offset(ctx, idx), - .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, idx), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst) }; uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1)); uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]); @@ -1225,17 +1181,16 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, break; } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; // Get or create pipeline webgpu_pipeline pipeline; @@ -1270,18 +1225,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, // Build bind group entries std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; // Calculate workgroup dimensions @@ -1333,13 +1279,12 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; // Get or create pipeline webgpu_pipeline gather_pipeline, main_pipeline; @@ -1380,22 +1325,14 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, // bind group entries for mul_mat_id_gather.wgsl std::vector gather_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_expert_used_align_offset, - .size = gathered_binding_size }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_tokens_align_offset, - .size = gathered_binding_size }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_count_ids_align_offset, - .size = gathered_count_ids_binding_size }, + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src2), ggml_webgpu_tensor_align_offset(ctx, src2), + ggml_webgpu_tensor_binding_size(ctx, src2)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), gathered_expert_used_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), gathered_tokens_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), gathered_count_ids_align_offset, + gathered_count_ids_binding_size), }; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; @@ -1427,30 +1364,18 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, // bind group entries for mul_mat_id.wgsl std::vector main_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - { .binding = 3, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_expert_used_align_offset, - .size = gathered_binding_size }, - { .binding = 4, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_tokens_align_offset, - .size = gathered_binding_size }, - { .binding = 5, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = gathered_count_ids_align_offset, - .size = gathered_count_ids_binding_size }, + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), gathered_expert_used_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(4, ggml_webgpu_tensor_buf(dst), gathered_tokens_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(5, ggml_webgpu_tensor_buf(dst), gathered_count_ids_align_offset, + gathered_count_ids_binding_size), }; // Calculate workgroup dimensions @@ -1486,11 +1411,9 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * mask, ggml_tensor * sinks, ggml_tensor * dst) { - float scale = *(float *) dst->op_params; - float max_bias; - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - float logit_softcap; - memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + float scale = ggml_get_op_params_f32(dst, 0); + float max_bias = ggml_get_op_params_f32(dst, 1); + float logit_softcap = ggml_get_op_params_f32(dst, 2); if (logit_softcap != 0.0f) { scale /= logit_softcap; } @@ -1522,86 +1445,53 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) - *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap) - *(uint32_t *) &max_bias, - *(uint32_t *) &logit_softcap, - *(uint32_t *) &n_head_log2, - *(uint32_t *) &m0, - *(uint32_t *) &m1 + ggml_webgpu_u32_from_f32(scale), // scale (possibly adjusted for logit softcap) + ggml_webgpu_u32_from_f32(max_bias), + ggml_webgpu_u32_from_f32(logit_softcap), + ggml_webgpu_u32_from_f32(n_head_log2), + ggml_webgpu_u32_from_f32(m0), + ggml_webgpu_u32_from_f32(m1) }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(Q), - .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(K), - .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(V), - .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V), }; uint32_t binding_index = 3; if (has_mask) { - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); } if (has_sinks) { - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(sinks), - .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), - .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); - } - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - - const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); - const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); - const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); - - const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned && - (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && - (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); - const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); - const bool use_blk = use_vec && has_mask; - - ggml_webgpu_flash_attn_pipeline_key key = { - .kv_type = K->type, - .head_dim_qk = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .kv_direct = kv_direct, - .has_mask = static_cast(has_mask), - .has_sinks = static_cast(has_sinks), - .uses_logit_softcap = logit_softcap != 0.0f, - .use_vec = use_vec, - }; - - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { - .key = key, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, - }; - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); - - auto * decisions = static_cast(pipeline.context.get()); - - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); - uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); + } + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = Q; + shader_lib_ctx.src1 = K; + shader_lib_ctx.src2 = V; + shader_lib_ctx.src3 = mask; + shader_lib_ctx.src4 = sinks; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + const bool use_vec = ggml_webgpu_flash_attn_use_vec(ctx->global_ctx, Q, K, V); + webgpu_pipeline pipeline = use_vec ? ctx->shader_lib->get_flash_attn_vec_pipeline(shader_lib_ctx) : + ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); + + if (!use_vec) { + auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); + uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + } + + auto * decisions = static_cast(pipeline.context.get()); wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; @@ -1609,197 +1499,162 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_nblk1 = 0; uint32_t blk_batch_count = 0; - if (use_vec) { - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); - while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { - nwg <<= 1; - } - nwg = std::min(nwg, vec_nwg_cap); - GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size); - const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; - const bool use_vec_reduce = nwg > 1u; - GGML_ASSERT(nrows <= UINT32_MAX); - - uint64_t tmp_stats_base = 0; - uint64_t tmp_size_bytes = 0; - wgpu::Buffer tmp_buf = {}; - uint64_t tmp_bind_offset = 0; - uint64_t tmp_bind_size = 0; - const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; - const size_t dst_offset = ggml_webgpu_tensor_offset(dst); - size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); - - if (use_vec_reduce) { - const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; - const uint64_t tmp_stats_elems = nrows * 2u * nwg; - tmp_stats_base = tmp_data_elems; - tmp_size_bytes = - ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); - GGML_ASSERT(tmp_stats_base <= UINT32_MAX); - tmp_buf = ggml_webgpu_tensor_buf(dst); - tmp_bind_offset = scratch_offset; - tmp_bind_size = tmp_size_bytes; - scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); - } else { - // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. - tmp_buf = ggml_webgpu_tensor_buf(dst); - tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); - tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); - } - - webgpu_pipeline blk_pipeline; - std::vector blk_params; - std::vector blk_entries; - if (use_blk) { - GGML_ASSERT(has_mask); - - blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); - blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile); - blk_buf = ggml_webgpu_tensor_buf(dst); - const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); - blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; - const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; - blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = { - .key = - { - .q_tile = decisions->q_tile, - .kv_tile = decisions->kv_tile, - }, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; - blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); - - blk_params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask - (uint32_t) Q->ne[1], // seq_len_q - (uint32_t) K->ne[1], // seq_len_kv - stride_mask3, // stride_mask3 - blk_nblk0, // nblk0 - blk_nblk1, // nblk1 - }; - blk_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }, - { .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes }, - }; - scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); - } + const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); + while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { + nwg <<= 1; + } + nwg = std::min(nwg, vec_nwg_cap); + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + const bool use_vec_reduce = nwg > 1u; + GGML_ASSERT(nrows <= UINT32_MAX); + + uint64_t tmp_stats_base = 0; + uint64_t tmp_size_bytes = 0; + wgpu::Buffer tmp_buf = {}; + uint64_t tmp_bind_offset = 0; + uint64_t tmp_bind_size = 0; + const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); + + if (use_vec_reduce) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + tmp_stats_base = tmp_data_elems; + tmp_size_bytes = + ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + GGML_ASSERT(tmp_stats_base <= UINT32_MAX); + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = scratch_offset; + tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); + } else { + // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); + tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); + } - std::vector split_params = params; - if (use_blk) { - split_params.push_back(0u); // blk_base - split_params.push_back(blk_nblk0); // blk_nblk0 - split_params.push_back(blk_nblk1); // blk_nblk1 - } - split_params.push_back(0u); // tmp_data_base - split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base - split_params.push_back(nwg); // nwg - - std::vector split_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(Q), - .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(K), - .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(V), - .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) }, + webgpu_pipeline blk_pipeline; + std::vector blk_params; + std::vector blk_entries; + if (has_mask) { + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = (uint32_t) Q->ne[1]; + blk_buf = ggml_webgpu_tensor_buf(dst); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx; + blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); + + blk_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) K->ne[1], // seq_len_kv + stride_mask3, // stride_mask3 + blk_nblk0, // nblk0 + blk_nblk1, // nblk1 }; - uint32_t split_binding_index = 3; - if (has_mask) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); - } - if (has_sinks) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(sinks), - .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), - .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); - } - if (use_blk) { - split_entries.push_back({ .binding = split_binding_index++, - .buffer = blk_buf, - .offset = blk_entries[1].offset, - .size = blk_size_bytes }); - } + blk_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(mask), + ggml_webgpu_tensor_align_offset(ctx, mask), + ggml_webgpu_tensor_binding_size(ctx, mask)), + ggml_webgpu_make_bind_group_entry(1, blk_buf, scratch_offset, blk_size_bytes), + }; + scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); + } + + std::vector split_params = params; + if (has_mask) { + split_params.push_back(0u); // blk_base + split_params.push_back(blk_nblk0); // blk_nblk0 + split_params.push_back(blk_nblk1); // blk_nblk1 + } + split_params.push_back(0u); // tmp_data_base + split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base + split_params.push_back(nwg); // nwg + + std::vector split_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q), + ggml_webgpu_tensor_binding_size(ctx, Q)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K), + ggml_webgpu_tensor_binding_size(ctx, K)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), ggml_webgpu_tensor_align_offset(ctx, V), + ggml_webgpu_tensor_binding_size(ctx, V)), + }; + uint32_t split_binding_index = 3; + if (has_mask) { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask), + ggml_webgpu_tensor_align_offset(ctx, mask), + ggml_webgpu_tensor_binding_size(ctx, mask))); + } + if (has_sinks) { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(sinks), + ggml_webgpu_tensor_align_offset(ctx, sinks), + ggml_webgpu_tensor_binding_size(ctx, sinks))); + } + if (has_mask) { split_entries.push_back( - { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size }); - split_entries.push_back({ .binding = split_binding_index++, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - - webgpu_pipeline reduce_pipeline; - std::vector reduce_params; - std::vector reduce_entries; - if (use_vec_reduce) { - const uint32_t reduce_wg_size = std::max( - 32u, - std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); - ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = { - .key = - { - .head_dim_v = (uint32_t) V->ne[0], - .wg_size = reduce_wg_size, - }, - .max_wg_size = reduce_wg_size, - }; - reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); - - reduce_params = { - (uint32_t) nrows, // nrows - (uint32_t) Q->ne[1], // seq_len_q - (uint32_t) Q->ne[2], // n_heads - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst - nwg, // nwg - 0u, // tmp_data_base - (uint32_t) tmp_stats_base, // tmp_stats_base - }; - - reduce_entries = { - { .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - }; - } + ggml_webgpu_make_bind_group_entry(split_binding_index++, blk_buf, blk_entries[1].offset, blk_size_bytes)); + } + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(split_binding_index++, tmp_buf, tmp_bind_offset, tmp_bind_size)); + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(dst), + ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst))); + + webgpu_pipeline reduce_pipeline; + std::vector reduce_params; + std::vector reduce_entries; + if (use_vec_reduce) { + const uint32_t reduce_wg_size = std::max( + 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; + reduce_shader_ctx.max_wg_size = reduce_wg_size; + reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); + + reduce_params = { + (uint32_t) nrows, // nrows + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) Q->ne[2], // n_heads + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst + nwg, // nwg + 0u, // tmp_data_base + (uint32_t) tmp_stats_base, // tmp_stats_base + }; - const uint64_t split_wg_total = (uint64_t) wg_x * nwg; - GGML_ASSERT(split_wg_total <= UINT32_MAX); - std::vector dispatches; - - if (use_blk) { - dispatches.push_back({ - blk_pipeline, - std::move(blk_params), - std::move(blk_entries), - { blk_nblk0, blk_nblk1 * blk_batch_count } - }); - } + reduce_entries = { + ggml_webgpu_make_bind_group_entry(0, tmp_buf, tmp_bind_offset, tmp_size_bytes), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + }; + } + + uint32_t wg_x = Q->ne[1] * Q->ne[2] * Q->ne[3]; + const uint64_t split_wg_total = (uint64_t) wg_x * nwg; + GGML_ASSERT(split_wg_total <= UINT32_MAX); + + std::vector dispatches; + + if (has_mask) { dispatches.push_back({ - pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count } + }); + } + dispatches.push_back({ + pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + }); + if (use_vec_reduce) { + dispatches.push_back({ + reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } }); - if (use_vec_reduce) { - dispatches.push_back({ - reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } - }); - } - - return ggml_backend_webgpu_build_multi(ctx, dispatches); } - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return ggml_backend_webgpu_build_multi(ctx, dispatches); } #endif // __EMSCRIPTEN__ @@ -1807,13 +1662,12 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor bool is_unary = dst->op == GGML_OP_UNARY; bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx); @@ -1844,10 +1698,10 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor float alpha_p = ggml_get_op_params_f32(dst, 2); float beta = ggml_get_op_params_f32(dst, 3); float eps = ggml_get_op_params_f32(dst, 4); - params.push_back(*reinterpret_cast(&alpha_n)); - params.push_back(*reinterpret_cast(&alpha_p)); - params.push_back(*reinterpret_cast(&beta)); - params.push_back(*reinterpret_cast(&eps)); + params.push_back(ggml_webgpu_u32_from_f32(alpha_n)); + params.push_back(ggml_webgpu_u32_from_f32(alpha_p)); + params.push_back(ggml_webgpu_u32_from_f32(beta)); + params.push_back(ggml_webgpu_u32_from_f32(eps)); break; } default: @@ -1856,25 +1710,19 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor } else if (dst->op == GGML_OP_CLAMP) { float clamp_min = ggml_get_op_params_f32(dst, 0); float clamp_max = ggml_get_op_params_f32(dst, 1); - params.push_back(*reinterpret_cast(&clamp_min)); - params.push_back(*reinterpret_cast(&clamp_max)); + params.push_back(ggml_webgpu_u32_from_f32(clamp_min)); + params.push_back(ggml_webgpu_u32_from_f32(clamp_max)); } else if (dst->op == GGML_OP_FILL) { float fill_val = ggml_get_op_params_f32(dst, 0); - params.push_back(*reinterpret_cast(&fill_val)); + params.push_back(ggml_webgpu_u32_from_f32(fill_val)); effective_src = dst; // fill simply fills dst } std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(effective_src), - .offset = ggml_webgpu_tensor_align_offset(ctx, effective_src), - .size = ggml_webgpu_tensor_binding_size(ctx, effective_src) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, effective_src), }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); @@ -1887,15 +1735,14 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * dst) { binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = flags.inplace, - .overlap = flags.overlap, - .src_overlap = flags.src_overlap, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = flags.inplace; + shader_lib_ctx.overlap = flags.overlap; + shader_lib_ctx.src_overlap = flags.src_overlap; webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); @@ -1944,38 +1791,18 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0), src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1)); - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = merged_offset, - .size = merged_end - merged_offset, - }); - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst), - }); + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, + merged_end - merged_offset)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } else { - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = src0_webgpu_tensor_align_offset, - .size = ggml_webgpu_tensor_binding_size(ctx, src0), - }); - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = src1_webgpu_tensor_align_offset, - .size = ggml_webgpu_tensor_binding_size(ctx, src1), - }); + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), + src0_webgpu_tensor_align_offset, + ggml_webgpu_tensor_binding_size(ctx, src0))); + entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), + src1_webgpu_tensor_align_offset, + ggml_webgpu_tensor_binding_size(ctx, src1))); if (!flags.inplace && !flags.overlap) { - entries.push_back({ - .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst), - }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); } } @@ -2012,26 +1839,16 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2059,21 +1876,14 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * (uint32_t) (dst->ne[2]) }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2097,28 +1907,19 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3], - *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)) // epsilon, treated as f32 in the shader }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src)); @@ -2129,14 +1930,13 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = ggml_webgpu_tensor_equal(src0, dst), - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst); webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx); @@ -2187,41 +1987,27 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, (uint32_t) src0->ne[2], (uint32_t) n_dims, (uint32_t) mode, - *(uint32_t *) &theta_scale, - *(uint32_t *) &attn_factor, - *(uint32_t *) &freq_scale, - *(uint32_t *) &ext_factor, - *(uint32_t *) &corr_dims[0], - *(uint32_t *) &corr_dims[1], + ggml_webgpu_u32_from_f32(theta_scale), + ggml_webgpu_u32_from_f32(attn_factor), + ggml_webgpu_u32_from_f32(freq_scale), + ggml_webgpu_u32_from_f32(ext_factor), + ggml_webgpu_u32_from_f32(corr_dims[0]), + ggml_webgpu_u32_from_f32(corr_dims[1]), (uint32_t) sections[0], (uint32_t) sections[1], (uint32_t) sections[2], (uint32_t) sections[3] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) } - }; - uint32_t dst_binding = 2; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1) }; + uint32_t dst_binding = 2; if (has_freq_factor) { dst_binding = 3; - entries.push_back({ .binding = 2, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); } if (!inplace) { - entries.push_back({ .binding = dst_binding, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); @@ -2232,12 +2018,11 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx); @@ -2265,29 +2050,20 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], - (uint32_t) ((int32_t *) dst->op_params)[1], // swapped - *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai - *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai + (uint32_t) ((int32_t *) dst->op_params)[1], // swapped + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 2)), // alpha, for swiglu_oai + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 3)), // limit, for swiglu_oai }; std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), }; uint32_t dst_binding = 1; if (split) { dst_binding = 2; - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); } - entries.push_back({ .binding = dst_binding, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); @@ -2296,13 +2072,12 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = inplace, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2321,23 +2096,15 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], - *(uint32_t *) dst->op_params, // scale - *(uint32_t *) &dst->op_params[1] // bias + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)), // scale + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 1)) // bias }; // bindgroups unchanged - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); @@ -2349,25 +2116,23 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src0, - .src1 = src1, - .src2 = src2, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .inplace = ggml_webgpu_tensor_equal(src0, dst), - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst); webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int has_mask = (src1 != nullptr); - const int has_sink = (src2 != nullptr); - float max_bias; - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); - float m0 = powf(2.0f, -(max_bias) / n_head_log2); - float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int has_mask = (src1 != nullptr); + const int has_sink = (src2 != nullptr); + float max_bias = ggml_get_op_params_f32(dst, 1); + float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -2389,39 +2154,29 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, (uint32_t) src0->ne[2], has_mask ? (uint32_t) src1->ne[2] : 0, has_mask ? (uint32_t) src1->ne[3] : 0, - *(uint32_t *) dst->op_params, // scale - *(uint32_t *) &max_bias, - *(uint32_t *) &n_head_log2, - *(uint32_t *) &m0, - *(uint32_t *) &m1 + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)), // scale + ggml_webgpu_u32_from_f32(max_bias), + ggml_webgpu_u32_from_f32(n_head_log2), + ggml_webgpu_u32_from_f32(m0), + ggml_webgpu_u32_from_f32(m1) }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) } - }; - uint32_t binding_num = 1; + std::vector entries = { ggml_webgpu_make_bind_group_entry( + 0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)) }; + uint32_t binding_num = 1; if (has_mask) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + entries.push_back(ggml_webgpu_make_bind_group_entry(binding_num, ggml_webgpu_tensor_buf(src1), + ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1))); binding_num++; } if (has_sink) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_num, src2)); binding_num++; } if (!inplace) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_num, dst)); } return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); @@ -2432,20 +2187,13 @@ static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nelements(dst); @@ -2455,13 +2203,12 @@ static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_top_k = dst->op == GGML_OP_TOP_K; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx); auto * argsort_decisions = static_cast(argsort_pipeline.context.get()); @@ -2527,11 +2274,8 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * const uint32_t wg_x_init = std::min(total_wg_init, max_wg); const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); std::vector init_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = init_align_offset, .size = init_binding_size } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), init_align_offset, init_binding_size) }; dispatches.push_back({ @@ -2580,12 +2324,9 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * nrows }; std::vector merge_entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_in, .size = size_in }, - { .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = align_out, .size = size_out } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), align_in, size_in), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), align_out, size_out) }; const uint32_t total_wg_merge = nm * nrows; @@ -2607,23 +2348,14 @@ static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) src->ne[0] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, - .src1 = nullptr, - .dst = dst, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); uint32_t wg_x = ggml_nrows(dst); @@ -2641,20 +2373,13 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor total_sum ? 1 : (uint32_t) src->ne[1], total_sum ? 1 : (uint32_t) src->ne[2] }; - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } - }; + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; - ggml_webgpu_shader_lib_context shader_lib_ctx = { - .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup - }; + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); @@ -3133,40 +2858,24 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const ggml_tensor * mask = tensor->src[3]; const ggml_tensor * sinks = tensor->src[4]; if (Q && K && V) { - GGML_UNUSED(sinks); - const bool kv_direct = (K->type == GGML_TYPE_F16) && - (Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && - kv_vec_type_supported && (V->type == K->type); - if (use_vec) { - const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; - const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; - const size_t limit_bytes = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const size_t q_tile = sg_mat_m; - const size_t base_q_bytes = (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!kv_direct) { - bytes_per_kv += std::max(Q->ne[0], V->ne[0]); - } - if (mask != nullptr) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - uint32_t kv_tile = ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n; - kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile)); - kv_tile = (kv_tile / sg_mat_n) * sg_mat_n; - if (kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= sg_mat_n; - } - } + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = const_cast(Q); + shader_lib_ctx.src1 = const_cast(K); + shader_lib_ctx.src2 = const_cast(V); + shader_lib_ctx.src3 = const_cast(mask); + shader_lib_ctx.src4 = const_cast(sinks); + shader_lib_ctx.dst = const_cast(tensor); + shader_lib_ctx.max_wg_size = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; + + if (ggml_webgpu_flash_attn_use_vec(ctx->webgpu_global_ctx, Q, K, V)) { + const uint32_t kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(shader_lib_ctx); const uint32_t vec_nwg_cap = std::max( 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); @@ -3271,8 +2980,9 @@ static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct } static ggml_guid_t ggml_backend_webgpu_guid(void) { - static const char * guid_str = "__ggml_webgpu :)"; - return reinterpret_cast((void *) guid_str); + static ggml_guid guid = { 0x67, 0xc7, 0xa4, 0xb1, 0x78, 0x74, 0x4f, 0x51, + 0x9d, 0x65, 0x44, 0x6d, 0xe4, 0x1b, 0x82, 0x9a }; + return &guid; } static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { @@ -3931,20 +3641,23 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { ggml_backend_reg_t ggml_backend_webgpu_reg() { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); - static ggml_backend_webgpu_reg_context ctx; - static ggml_backend_reg reg = { + // Intentionally leak the global registry context to avoid crashing inside + // Dawn/Vulkan static teardown during process exit. + static ggml_backend_webgpu_reg_context * ctx = new ggml_backend_webgpu_reg_context(); + + static ggml_backend_reg reg = { /* .api_version = */ GGML_BACKEND_API_VERSION, /* .iface = */ ggml_backend_webgpu_reg_i, - /* .context = */ &ctx, + /* .context = */ ctx, }; - ctx.name = GGML_WEBGPU_NAME; - ctx.device_count = 0; + ctx->name = GGML_WEBGPU_NAME; + ctx->device_count = 0; // Keep one Dawn/WebGPU instance alive for the lifetime of the static backend // registry. Recreating it on repeated registry lookups can invalidate // adapter/device references that are still held by the backend/device layer. - if (ctx.webgpu_global_ctx != nullptr && ctx.webgpu_global_ctx->instance != nullptr) { + if (ctx->webgpu_global_ctx != nullptr && ctx->webgpu_global_ctx->instance != nullptr) { return ® } @@ -3961,17 +3674,18 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { instance_descriptor.nextInChain = &instanceTogglesDesc; #endif - wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor); - ctx.webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); - ctx.webgpu_global_ctx->instance = std::move(inst); + wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor); + ctx->webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); + ctx->webgpu_global_ctx->instance = std::move(inst); // Probe for adapter support wgpu::Adapter adapter; - if (ctx.webgpu_global_ctx->instance != nullptr) { + if (ctx->webgpu_global_ctx->instance != nullptr) { wgpu::RequestAdapterOptions options = {}; - ctx.webgpu_global_ctx->instance.WaitAny( - ctx.webgpu_global_ctx->instance.RequestAdapter( + // probe for adapter support + ctx->webgpu_global_ctx->instance.WaitAny( + ctx->webgpu_global_ctx->instance.RequestAdapter( &options, wgpu::CallbackMode::AllowSpontaneous, [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { if (status != wgpu::RequestAdapterStatus::Success) { @@ -3984,7 +3698,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { } if (adapter != nullptr) { - ctx.device_count = 1; + ctx->device_count = 1; } return ® diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl index 82d072be73a..61107c6a985 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl @@ -1,7 +1,6 @@ diagnostic(off, subgroup_uniformity); enable f16; -#define Q_TILE 1 #define KV_TILE 32 #define WG_SIZE 32 @@ -11,7 +10,7 @@ struct Params { seq_len_kv: u32, stride_mask3: u32, // Number of KV blocks and Q blocks per batch. - // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE). + // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = seq_len_q. nblk0: u32, nblk1: u32, }; @@ -40,7 +39,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, return; } - let q_start = q_blk * Q_TILE; + let q_start = q_blk; let k_start = kv_blk * KV_TILE; let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u); @@ -54,11 +53,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, var local_max = -MASK_MAX; var local_any = 0u; - for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) { - let q_row = q_start + q_rel; - if (q_row >= params.seq_len_q) { - continue; - } + let q_row = q_start; + if (q_row < params.seq_len_q) { let row_base = mask_batch_base + q_row * params.seq_len_kv; for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) { let k_col = k_start + k_rel; From a899e4bdcbda94e099c7a2ac40ff26b490419cca Mon Sep 17 00:00:00 2001 From: SamareshSingh <97642706+ssam18@users.noreply.github.com> Date: Sat, 18 Apr 2026 03:04:51 -0500 Subject: [PATCH 453/831] ggml-backend-meta: add multi-segment read support in get_tensor (llama/22063) --- ggml/src/ggml-backend-meta.cpp | 40 +++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 1ee3eeb4d96..24f6bc0639d 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1270,7 +1270,45 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co GGML_ASSERT(ggml_is_contiguous(tensor)); const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - GGML_ASSERT(split_state.n_segments == 1); + + if (split_state.n_segments != 1) { + GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + GGML_ASSERT(tensor->ne[3] == 1); + size_t offset_data = 0; + std::vector simple_offsets(n_bufs, 0); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { + GGML_ASSERT(tensor->ne[2] == 1); + const int64_t blck_size = ggml_blck_size(tensor->type); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes, + tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*tensor->ne[1] == size); + return; + } + GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes, + tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + GGML_ASSERT(offset_data*tensor->ne[2] == size); + return; + } switch (split_state.axis) { case GGML_BACKEND_SPLIT_AXIS_0: From 32789b9e07afc115eec3be81a76a34453e90ae67 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Sun, 19 Apr 2026 10:21:53 +0300 Subject: [PATCH 454/831] rpc : refactor the RPC transport (llama/21998) * rpc : refactor the RPC transport Move all transport related code into a separate file and use the socket_t interface to hide all transport implementation details. * fix win32 * better socket_t construction --- ggml/src/ggml-rpc/CMakeLists.txt | 1 + ggml/src/ggml-rpc/ggml-rpc.cpp | 806 +++---------------------------- ggml/src/ggml-rpc/transport.cpp | 683 ++++++++++++++++++++++++++ ggml/src/ggml-rpc/transport.h | 34 ++ 4 files changed, 782 insertions(+), 742 deletions(-) create mode 100644 ggml/src/ggml-rpc/transport.cpp create mode 100644 ggml/src/ggml-rpc/transport.h diff --git a/ggml/src/ggml-rpc/CMakeLists.txt b/ggml/src/ggml-rpc/CMakeLists.txt index 8671ce5ceaf..40e11fead63 100644 --- a/ggml/src/ggml-rpc/CMakeLists.txt +++ b/ggml/src/ggml-rpc/CMakeLists.txt @@ -2,6 +2,7 @@ message(STATUS "Using RPC backend") ggml_add_backend_library(ggml-rpc ggml-rpc.cpp + transport.cpp ) if (WIN32) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 017ef0af360..2ded7397868 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -2,6 +2,7 @@ #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-cpp.h" +#include "transport.h" #include #include @@ -12,35 +13,11 @@ #include #include #include -#ifdef _WIN32 -# define WIN32_LEAN_AND_MEAN -# ifndef NOMINMAX -# define NOMINMAX -# endif -# include -# include -#else -# include -# include -# include -# include -# include -# include -# include -#endif #include #include #include #include -#ifdef GGML_RPC_RDMA -# include -# include -# ifndef _WIN32 -# include -# endif -#endif // GGML_RPC_RDMA - static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); #define LOG_DBG(...) \ @@ -49,128 +26,6 @@ static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); namespace fs = std::filesystem; -static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB - -#ifdef _WIN32 -typedef SOCKET sockfd_t; -using ssize_t = __int64; -#else -typedef int sockfd_t; -#endif - -// cross-platform socket - -#ifdef GGML_RPC_RDMA -static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock) -static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB -static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes -using rdma_gid_t = std::array; - -struct rdma_conn { - struct ibv_context * ctx = nullptr; - struct ibv_pd * pd = nullptr; - struct ibv_cq * scq = nullptr; // send completions - struct ibv_cq * rcq = nullptr; // recv completions - struct ibv_qp * qp = nullptr; - - void * tx_buf = nullptr; - struct ibv_mr * tx_mr = nullptr; - - void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous - struct ibv_mr * rx_mr = nullptr; - int rx_head = 0; - - uint32_t max_inline = 0; - - uint8_t * rx_slot(int i) const { - return static_cast(rx_buf) + static_cast(i) * RDMA_CHUNK; - } - - bool post_rx(int i) { - struct ibv_sge sge = {}; - sge.addr = (uintptr_t)rx_slot(i); - sge.length = RDMA_CHUNK; - sge.lkey = rx_mr->lkey; - struct ibv_recv_wr wr = {}, * bad = nullptr; - wr.wr_id = (uint64_t)i; - wr.sg_list = &sge; - wr.num_sge = 1; - return ibv_post_recv(qp, &wr, &bad) == 0; - } - - ~rdma_conn() { - if (tx_mr) ibv_dereg_mr(tx_mr); - if (rx_mr) ibv_dereg_mr(rx_mr); - free(tx_buf); - free(rx_buf); - if (qp) ibv_destroy_qp(qp); - if (scq) ibv_destroy_cq(scq); - if (rcq) ibv_destroy_cq(rcq); - if (pd) ibv_dealloc_pd(pd); - if (ctx) ibv_close_device(ctx); - } -}; - -// Local RDMA parameters captured during the probe phase and later consumed -// by rdma_activate() after the remote side's caps arrive via HELLO. -struct rdma_local_info { - uint32_t qpn = 0; - uint32_t psn = 0; - uint8_t gid[RDMA_GID_SIZE] = {}; - uint8_t ib_port = 0; - int gid_idx = 0; - enum ibv_mtu path_mtu = IBV_MTU_1024; -}; -#endif // GGML_RPC_RDMA - -// conn_caps size for transport-agnostic capability exchange -static constexpr size_t RPC_CONN_CAPS_SIZE = 24; - -// conn_caps RDMA layout helper -#ifdef GGML_RPC_RDMA -struct rdma_caps { - uint32_t qpn; - uint32_t psn; - uint8_t gid[RDMA_GID_SIZE]; -}; -static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size"); -#endif // GGML_RPC_RDMA - -// Forward declarations for transport function pointers -struct socket_t; -static bool tcp_send_impl(socket_t * sock, const void * data, size_t size); -static bool tcp_recv_impl(socket_t * sock, void * data, size_t size); - -struct socket_t { - sockfd_t fd; - bool (*fn_send)(socket_t *, const void *, size_t) = tcp_send_impl; - bool (*fn_recv)(socket_t *, void *, size_t) = tcp_recv_impl; -#ifdef GGML_RPC_RDMA - std::unique_ptr rdma; - rdma_local_info rdma_local = {}; -#endif // GGML_RPC_RDMA - socket_t(sockfd_t fd) : fd(fd) {} - ~socket_t() { -#ifdef GGML_RPC_RDMA - rdma.reset(); -#endif // GGML_RPC_RDMA - LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); -#ifdef _WIN32 - if (fd != INVALID_SOCKET) closesocket(this->fd); -#else - if (fd >= 0) close(this->fd); -#endif - } - - // Advertise local transport capabilities into conn_caps. - // May probe RDMA and store the probe on this socket for update_caps. - void get_caps(uint8_t * caps); - - // Activate transport upgrade based on remote conn_caps using the probe - // previously stored by get_caps. - void update_caps(const uint8_t * remote_caps); -}; - // macro for nicer error messages on server crash #define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response") @@ -403,540 +258,27 @@ static uint64_t fnv_hash(const uint8_t * data, size_t len) { return hash; } -static std::shared_ptr make_socket(sockfd_t fd) { -#ifdef _WIN32 - if (fd == INVALID_SOCKET) { - return nullptr; - } -#else - if (fd < 0) { - return nullptr; - } -#endif - return std::make_shared(fd); -} - -static bool set_no_delay(sockfd_t sockfd) { - int flag = 1; - // set TCP_NODELAY to disable Nagle's algorithm - int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); - return ret == 0; -} - -static bool set_reuse_addr(sockfd_t sockfd) { - int flag = 1; - int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); - return ret == 0; -} - -static std::shared_ptr socket_connect(const char * host, int port) { - struct sockaddr_in addr; - auto sockfd = socket(AF_INET, SOCK_STREAM, 0); - auto sock_ptr = make_socket(sockfd); - if (sock_ptr == nullptr) { - return nullptr; - } - if (!set_no_delay(sockfd)) { - GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); - return nullptr; - } - addr.sin_family = AF_INET; - addr.sin_port = htons(port); - struct hostent * server = gethostbyname(host); - if (server == NULL) { - GGML_LOG_ERROR("Cannot resolve host '%s'\n", host); - return nullptr; - } - memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); - if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { - return nullptr; - } - return sock_ptr; -} - -static std::shared_ptr socket_accept(sockfd_t srv_sockfd) { - auto client_socket_fd = accept(srv_sockfd, NULL, NULL); - auto client_socket = make_socket(client_socket_fd); - if (client_socket == nullptr) { - return nullptr; - } - if (!set_no_delay(client_socket_fd)) { - GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); - return nullptr; - } - return client_socket; -} - -static std::shared_ptr create_server_socket(const char * host, int port) { - auto sockfd = socket(AF_INET, SOCK_STREAM, 0); - auto sock = make_socket(sockfd); - if (sock == nullptr) { - return nullptr; - } - if (!set_reuse_addr(sockfd)) { - GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n"); - return nullptr; - } - if (inet_addr(host) == INADDR_NONE) { - GGML_LOG_ERROR("Invalid host address: %s\n", host); - return nullptr; - } - struct sockaddr_in serv_addr; - serv_addr.sin_family = AF_INET; - serv_addr.sin_addr.s_addr = inet_addr(host); - serv_addr.sin_port = htons(port); - - if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { - return nullptr; - } - if (listen(sockfd, 1) < 0) { - return nullptr; - } - return sock; -} - -static bool send_data(sockfd_t sockfd, const void * data, size_t size) { - size_t bytes_sent = 0; - while (bytes_sent < size) { - size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE); - ssize_t n = send(sockfd, (const char *)data + bytes_sent, size_to_send, 0); - if (n < 0) { - GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n", - bytes_sent, size_to_send); - return false; - } - bytes_sent += (size_t)n; - } - return true; -} - -static bool recv_data(sockfd_t sockfd, void * data, size_t size) { - size_t bytes_recv = 0; - while (bytes_recv < size) { - size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE); - ssize_t n = recv(sockfd, (char *)data + bytes_recv, size_to_recv, 0); - if (n < 0) { - GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n", - bytes_recv, size_to_recv); - return false; - } - if (n == 0) { - LOG_DBG("recv returned 0 (peer closed?)\n"); - return false; - } - bytes_recv += (size_t)n; - } - return true; -} - -// TCP transport implementations (for function-pointer dispatch) - -static bool tcp_send_impl(socket_t * sock, const void * data, size_t size) { - return send_data(sock->fd, data, size); -} - -static bool tcp_recv_impl(socket_t * sock, void * data, size_t size) { - return recv_data(sock->fd, data, size); -} - -// RDMA transport (performance-optimized, auto-negotiated) - -#ifdef GGML_RPC_RDMA - -static bool rdma_send_impl(socket_t * sock, const void * data, size_t size); -static bool rdma_recv_impl(socket_t * sock, void * data, size_t size); - -static inline bool tcp_peer_closed(int fd) { - if (fd < 0) return false; -#ifndef _WIN32 - struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 }; - int r = poll(&pfd, 1, 0); - return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP)); -#else - return false; -#endif -} - -static inline bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc, int tcp_fd) { - for (uint64_t s = 0; ; s++) { - int n = ibv_poll_cq(cq, 1, wc); - if (n > 0) { - if (wc->status != IBV_WC_SUCCESS) { - GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n", - wc->status, ibv_wc_status_str(wc->status), wc->vendor_err); - } - return wc->status == IBV_WC_SUCCESS; - } - if (n < 0) return false; - if ((s & 0xFFFFF) == 0 && s > 0) { - if (tcp_peer_closed(tcp_fd)) { - return false; - } - } - } -} - -static bool rdma_send(rdma_conn * c, const void * data, size_t size, int tcp_fd) { - const uint8_t * src = (const uint8_t *)data; - size_t rem = size; - while (rem > 0) { - size_t chunk = std::min(rem, RDMA_CHUNK); - - struct ibv_sge sge = {}; - struct ibv_send_wr wr = {}, * bad = nullptr; - wr.opcode = IBV_WR_SEND; - wr.sg_list = &sge; - wr.num_sge = 1; - - if (chunk <= c->max_inline) { - sge.addr = (uintptr_t)src; - sge.length = chunk; - wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE; - } else { - memcpy(c->tx_buf, src, chunk); - sge.addr = (uintptr_t)c->tx_buf; - sge.length = chunk; - sge.lkey = c->tx_mr->lkey; - wr.send_flags = IBV_SEND_SIGNALED; - } - - if (ibv_post_send(c->qp, &wr, &bad) != 0) return false; - struct ibv_wc wc; - if (!rdma_poll(c->scq, &wc, tcp_fd)) return false; - - src += chunk; - rem -= chunk; - } - return true; -} - - -static bool rdma_recv(rdma_conn * c, void * data, size_t size, int tcp_fd) { - uint8_t * dst = (uint8_t *)data; - size_t rem = size; - while (rem > 0) { - struct ibv_wc wc; - if (!rdma_poll(c->rcq, &wc, tcp_fd)) return false; - - int slot = (int)wc.wr_id; - size_t got = wc.byte_len; - memcpy(dst, c->rx_slot(slot), got); - - if (!c->post_rx(slot)) return false; - - dst += got; - rem -= got; - } - return true; -} - -static bool rdma_send_impl(socket_t * sock, const void * data, size_t size) { - return rdma_send(sock->rdma.get(), data, size, sock->fd); -} - -static bool rdma_recv_impl(socket_t * sock, void * data, size_t size) { - return rdma_recv(sock->rdma.get(), data, size, sock->fd); -} - -// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address. -// Used to match the socket's local IP against the kernel's GID table so that -// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly: -// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4) -// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape) -// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is -// Returns std::nullopt on unsupported family or getsockname failure. -static std::optional rdma_build_target_gid(sockfd_t tcp_fd) { - sockaddr_storage addr = {}; - socklen_t addr_len = sizeof(addr); - if (getsockname(tcp_fd, reinterpret_cast(&addr), &addr_len) != 0) { - return std::nullopt; - } - rdma_gid_t target = {}; - if (addr.ss_family == AF_INET) { - const auto * a = reinterpret_cast(&addr); - target[10] = 0xff; - target[11] = 0xff; - memcpy(&target[12], &a->sin_addr, 4); - return target; - } - if (addr.ss_family == AF_INET6) { - const auto * a = reinterpret_cast(&addr); - memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE); - return target; - } - return std::nullopt; -} - -static rdma_conn * rdma_probe(sockfd_t tcp_fd, rdma_local_info * out) { - const char * dev_env = std::getenv("GGML_RDMA_DEV"); - const char * gid_env = std::getenv("GGML_RDMA_GID"); - - auto target_gid = rdma_build_target_gid(tcp_fd); - if (!target_gid) { - return nullptr; - } - - const uint8_t ib_port = 1; - int num_devs = 0; - ibv_device ** devs = ibv_get_device_list(&num_devs); - if (!devs || num_devs == 0) return nullptr; - - ibv_context * ibctx = nullptr; - const char * matched_dev = nullptr; - int gid_idx = gid_env ? atoi(gid_env) : -1; - int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB - - for (int d = 0; d < num_devs; d++) { - const char * dn = ibv_get_device_name(devs[d]); - if (dev_env && strcmp(dev_env, dn) != 0) continue; - - ibv_context * ctx = ibv_open_device(devs[d]); - if (!ctx) continue; - - ibv_port_attr pa; - if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; } - - int found_gid = gid_idx; - int found_version = IBV_GID_TYPE_IB; - if (found_gid < 0) { - // Find a GID on this port whose bytes equal the local TCP address - // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1 - // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths - // are avoided. ibv_query_gid_ex returns gid+type in one call. - int v2_idx = -1; - int v1_idx = -1; - for (int i = 0; i < pa.gid_tbl_len; i++) { - ibv_gid_entry entry = {}; - if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue; - if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue; - if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) { - v2_idx = i; - } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) { - v1_idx = i; - } - } - if (v2_idx >= 0) { - found_gid = v2_idx; - found_version = IBV_GID_TYPE_ROCE_V2; - } else if (v1_idx >= 0) { - found_gid = v1_idx; - found_version = IBV_GID_TYPE_ROCE_V1; - } - } else { - // Explicit GID index from GGML_RDMA_GID — fetch its type for logging. - ibv_gid_entry entry = {}; - if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) { - found_version = entry.gid_type; - } - } - if (found_gid >= 0) { - ibctx = ctx; - gid_idx = found_gid; - gid_version = found_version; - matched_dev = dn; - out->path_mtu = pa.active_mtu; - break; - } - ibv_close_device(ctx); - } - ibv_free_device_list(devs); - if (!ibctx) return nullptr; - - out->ib_port = ib_port; - out->gid_idx = gid_idx; - - // unique_ptr owns ibctx and every subsequent resource via ~rdma_conn(), - // so each failure path is a plain `return nullptr;`. - auto c = std::make_unique(); - c->ctx = ibctx; - - c->pd = ibv_alloc_pd(ibctx); - if (!c->pd) return nullptr; - - c->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0); - c->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0); - if (!c->scq || !c->rcq) return nullptr; - - ibv_qp_init_attr qia = {}; - qia.send_cq = c->scq; - qia.recv_cq = c->rcq; - qia.qp_type = IBV_QPT_RC; - qia.cap.max_send_wr = 4; - qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4; - qia.cap.max_send_sge = 1; - qia.cap.max_recv_sge = 1; - qia.cap.max_inline_data = 256; - - c->qp = ibv_create_qp(c->pd, &qia); - if (!c->qp) return nullptr; - c->max_inline = qia.cap.max_inline_data; - - c->tx_buf = aligned_alloc(4096, RDMA_CHUNK); - c->rx_buf = aligned_alloc(4096, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK); - if (!c->tx_buf || !c->rx_buf) return nullptr; - - c->tx_mr = ibv_reg_mr(c->pd, c->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE); - c->rx_mr = ibv_reg_mr(c->pd, c->rx_buf, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK, - IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); - if (!c->tx_mr || !c->rx_mr) return nullptr; - - ibv_gid local_gid; - if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return nullptr; - - out->qpn = c->qp->qp_num; - out->psn = c->qp->qp_num & 0xffffff; - memcpy(out->gid, &local_gid, RDMA_GID_SIZE); - - const char * ver_str = ""; - if (gid_version == IBV_GID_TYPE_ROCE_V2) { - ver_str = " RoCEv2"; - } else if (gid_version == IBV_GID_TYPE_ROCE_V1) { - ver_str = " RoCEv1"; - } - GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n", - matched_dev, gid_idx, ver_str, out->qpn, c->max_inline); - return c.release(); -} - -// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS. -// On success, the connection is live and ready for rdma_send/rdma_recv. -static bool rdma_activate(rdma_conn * c, const rdma_local_info * local, - uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) { - // RESET -> INIT - { - struct ibv_qp_attr a = {}; - a.qp_state = IBV_QPS_INIT; - a.port_num = local->ib_port; - a.pkey_index = 0; - a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; - if (ibv_modify_qp(c->qp, &a, - IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { - return false; - } - } - - for (int i = 0; i < RDMA_RX_DEPTH; i++) { - if (!c->post_rx(i)) return false; - } - - // INIT -> RTR - { - struct ibv_qp_attr a = {}; - a.qp_state = IBV_QPS_RTR; - a.path_mtu = local->path_mtu; - a.dest_qp_num = remote_qpn; - a.rq_psn = remote_psn; - a.max_dest_rd_atomic = 1; - a.min_rnr_timer = 1; - a.ah_attr.is_global = 1; - memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE); - a.ah_attr.grh.hop_limit = 1; - a.ah_attr.grh.sgid_index = local->gid_idx; - a.ah_attr.dlid = 0; - a.ah_attr.port_num = local->ib_port; - if (ibv_modify_qp(c->qp, &a, - IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | - IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) { - return false; - } - } - - // RTR -> RTS - { - struct ibv_qp_attr a = {}; - a.qp_state = IBV_QPS_RTS; - a.timeout = 14; - a.retry_cnt = 7; - a.rnr_retry = 7; - a.sq_psn = local->psn; - a.max_rd_atomic = 1; - if (ibv_modify_qp(c->qp, &a, - IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | - IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) { - return false; - } - } - - GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n", - local->qpn, remote_qpn, 128 << local->path_mtu, RDMA_RX_DEPTH); - return true; -} - -#endif // GGML_RPC_RDMA - -// --------------------------------------------------------------------------- -// socket_t transport capability methods -// --------------------------------------------------------------------------- - -void socket_t::get_caps(uint8_t * caps) { - memset(caps, 0, RPC_CONN_CAPS_SIZE); -#ifdef GGML_RPC_RDMA - rdma_local = {}; - rdma.reset(rdma_probe(fd, &rdma_local)); - if (rdma) { - rdma_caps rc = {}; - rc.qpn = rdma_local.qpn; - rc.psn = rdma_local.psn; - memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE); - memcpy(caps, &rc, sizeof(rc)); - } -#endif // GGML_RPC_RDMA -} - -void socket_t::update_caps(const uint8_t * remote_caps) { -#ifdef GGML_RPC_RDMA - if (!rdma) { - return; - } - rdma_caps rc = {}; - memcpy(&rc, remote_caps, sizeof(rc)); - if (rc.qpn == 0) { - rdma.reset(); - return; - } - if (rdma_activate(rdma.get(), &rdma_local, rc.qpn, rc.psn, rc.gid)) { - fn_send = rdma_send_impl; - fn_recv = rdma_recv_impl; - } else { - GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n"); - rdma.reset(); - } -#else - (void)remote_caps; -#endif // GGML_RPC_RDMA -} - -// unified transport dispatch (via function pointers) - -static bool send_data(socket_t * sock, const void * data, size_t size) { - return sock->fn_send(sock, data, size); -} - -static bool recv_data(socket_t * sock, void * data, size_t size) { - return sock->fn_recv(sock, data, size); -} - -static bool send_msg(socket_t * sock, const void * msg, size_t msg_size) { - if (!send_data(sock, &msg_size, sizeof(msg_size))) { +static bool send_msg(socket_ptr sock, const void * msg, size_t msg_size) { + if (!sock->send_data(&msg_size, sizeof(msg_size))) { return false; } - return send_data(sock, msg, msg_size); + return sock->send_data(msg, msg_size); } -static bool recv_msg(socket_t * sock, void * msg, size_t msg_size) { +static bool recv_msg(socket_ptr sock, void * msg, size_t msg_size) { uint64_t size; - if (!recv_data(sock, &size, sizeof(size))) { + if (!sock->recv_data(&size, sizeof(size))) { return false; } if (size != msg_size) { return false; } - return recv_data(sock, msg, msg_size); + return sock->recv_data(msg, msg_size); } -static bool recv_msg(socket_t * sock, std::vector & input) { +static bool recv_msg(socket_ptr sock, std::vector & input) { uint64_t size; - if (!recv_data(sock, &size, sizeof(size))) { + if (!sock->recv_data(&size, sizeof(size))) { return false; } try { @@ -945,7 +287,7 @@ static bool recv_msg(socket_t * sock, std::vector & input) { GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size); return false; } - return recv_data(sock, input.data(), size); + return sock->recv_data(input.data(), size); } static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { @@ -964,15 +306,15 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // No response -static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size) { +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size) { uint8_t cmd_byte = cmd; - if (!send_data(sock.get(), &cmd_byte, sizeof(cmd_byte))) { + if (!sock->send_data(&cmd_byte, sizeof(cmd_byte))) { return false; } - if (!send_data(sock.get(), &input_size, sizeof(input_size))) { + if (!sock->send_data(&input_size, sizeof(input_size))) { return false; } - if (!send_data(sock.get(), input, input_size)) { + if (!sock->send_data(input, input_size)) { return false; } return true; @@ -980,18 +322,18 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // RPC response: | response_size (8 bytes) | response_data (response_size bytes) | -static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { if (!send_rpc_cmd(sock, cmd, input, input_size)) { return false; } uint64_t out_size; - if (!recv_data(sock.get(), &out_size, sizeof(out_size))) { + if (!sock->recv_data(&out_size, sizeof(out_size))) { return false; } if (out_size != output_size) { return false; } - if (!recv_data(sock.get(), output, output_size)) { + if (!sock->recv_data(output, output_size)) { return false; } return true; @@ -1025,7 +367,6 @@ static std::shared_ptr get_socket(const std::string & endpoint) { static std::mutex mutex; std::lock_guard lock(mutex); static std::unordered_map> sockets; - static bool initialized = false; auto it = sockets.find(endpoint); if (it != sockets.end()) { @@ -1040,19 +381,10 @@ static std::shared_ptr get_socket(const std::string & endpoint) { return nullptr; } -#ifdef _WIN32 - if (!initialized) { - WSADATA wsaData; - int res = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (res != 0) { - return nullptr; - } - initialized = true; + if (!rpc_transport_init()) { + return nullptr; } -#else - GGML_UNUSED(initialized); -#endif - auto sock = socket_connect(host.c_str(), port); + auto sock = socket_t::connect(host.c_str(), port); if (sock == nullptr) { return nullptr; } @@ -2110,10 +1442,10 @@ rpc_server::~rpc_server() { } static void rpc_serve_client(const std::vector & backends, const char * cache_dir, - socket_t * sockfd) { + socket_ptr sock) { rpc_server server(backends, cache_dir); uint8_t cmd; - if (!recv_data(sockfd, &cmd, 1)) { + if (!sock->recv_data(&cmd, 1)) { return; } if (cmd != RPC_CMD_HELLO) { @@ -2123,7 +1455,7 @@ static void rpc_serve_client(const std::vector & backends, const // Read input_size and validate protocol version uint64_t hello_input_size; - if (!recv_data(sockfd, &hello_input_size, sizeof(hello_input_size))) { + if (!sock->recv_data(&hello_input_size, sizeof(hello_input_size))) { return; } @@ -2134,24 +1466,22 @@ static void rpc_serve_client(const std::vector & backends, const } rpc_msg_hello_req req = {}; - if (!recv_data(sockfd, &req, sizeof(req))) { + if (!sock->recv_data(&req, sizeof(req))) { return; } rpc_msg_hello_rsp rsp = {}; server.hello(rsp); - // Advertise server transport capabilities based on client's caps - sockfd->get_caps(rsp.conn_caps); - - if (!send_msg(sockfd, &rsp, sizeof(rsp))) { + sock->get_caps(rsp.conn_caps); + if (!send_msg(sock, &rsp, sizeof(rsp))) { return; } // Activate transport upgrade using client's caps - sockfd->update_caps(req.conn_caps); + sock->update_caps(req.conn_caps); while (true) { - if (!recv_data(sockfd, &cmd, 1)) { + if (!sock->recv_data(&cmd, 1)) { break; } if (cmd >= RPC_CMD_COUNT) { @@ -2165,115 +1495,115 @@ static void rpc_serve_client(const std::vector & backends, const return; } case RPC_CMD_DEVICE_COUNT: { - if (!recv_msg(sockfd, nullptr, 0)) { + if (!recv_msg(sock, nullptr, 0)) { return; } rpc_msg_device_count_rsp response; response.device_count = backends.size(); - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_ALLOC_BUFFER: { rpc_msg_alloc_buffer_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_alloc_buffer_rsp response; if (!server.alloc_buffer(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_ALLOC_SIZE: { rpc_msg_get_alloc_size_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_alloc_size_rsp response; if (!server.get_alloc_size(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_ALIGNMENT: { rpc_msg_get_alignment_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_alignment_rsp response; if (!server.get_alignment(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_MAX_SIZE: { rpc_msg_get_max_size_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_max_size_rsp response; if (!server.get_max_size(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_BUFFER_GET_BASE: { rpc_msg_buffer_get_base_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_buffer_get_base_rsp response; if (!server.buffer_get_base(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_FREE_BUFFER: { rpc_msg_free_buffer_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.free_buffer(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_BUFFER_CLEAR: { rpc_msg_buffer_clear_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.buffer_clear(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_SET_TENSOR: { std::vector input; - if (!recv_msg(sockfd, input)) { + if (!recv_msg(sock, input)) { return; } if (!server.set_tensor(input)) { @@ -2283,62 +1613,62 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_SET_TENSOR_HASH: { rpc_msg_set_tensor_hash_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_set_tensor_hash_rsp response; if (!server.set_tensor_hash(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_INIT_TENSOR: { rpc_msg_init_tensor_req request; - if (!recv_msg(sockfd, &request,sizeof(request))) { + if (!recv_msg(sock, &request,sizeof(request))) { return; } if (!server.init_tensor(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_GET_TENSOR: { rpc_msg_get_tensor_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } std::vector response; if (!server.get_tensor(request, response)) { return; } - if (!send_msg(sockfd, response.data(), response.size())) { + if (!send_msg(sock, response.data(), response.size())) { return; } break; } case RPC_CMD_COPY_TENSOR: { rpc_msg_copy_tensor_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_copy_tensor_rsp response; if (!server.copy_tensor(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GRAPH_COMPUTE: { std::vector input; - if (!recv_msg(sockfd, input)) { + if (!recv_msg(sock, input)) { return; } if (!server.graph_compute(input)) { @@ -2348,7 +1678,7 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_GRAPH_RECOMPUTE: { rpc_msg_graph_recompute_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.graph_recompute(request)) { @@ -2358,14 +1688,14 @@ static void rpc_serve_client(const std::vector & backends, const } case RPC_CMD_GET_DEVICE_MEMORY: { rpc_msg_get_device_memory_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_device_memory_rsp response; if (!server.get_device_memory(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; @@ -2424,36 +1754,28 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir #else printf(" transport : TCP\n"); #endif // GGML_RPC_RDMA -#ifdef _WIN32 - { - WSADATA wsaData; - int res = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (res != 0) { - fprintf(stderr, "WSAStartup failed: %d\n", res); - return; - } + if (!rpc_transport_init()) { + fprintf(stderr, "Failed to initialize RPC transport\n"); + return; } -#endif - auto server_socket = create_server_socket(host.c_str(), port); + auto server_socket = socket_t::create_server(host.c_str(), port); if (server_socket == nullptr) { fprintf(stderr, "Failed to create server socket\n"); return; } while (true) { - auto client_socket = socket_accept(server_socket->fd); + auto client_socket = server_socket->accept(); if (client_socket == nullptr) { fprintf(stderr, "Failed to accept client connection\n"); return; } printf("Accepted client connection\n"); fflush(stdout); - rpc_serve_client(backends, cache_dir, client_socket.get()); + rpc_serve_client(backends, cache_dir, client_socket); printf("Client connection closed\n"); fflush(stdout); } -#ifdef _WIN32 - WSACleanup(); -#endif + rpc_transport_shutdown(); for (auto backend : backends) { ggml_backend_free(backend); } diff --git a/ggml/src/ggml-rpc/transport.cpp b/ggml/src/ggml-rpc/transport.cpp new file mode 100644 index 00000000000..a728152421f --- /dev/null +++ b/ggml/src/ggml-rpc/transport.cpp @@ -0,0 +1,683 @@ +#include "transport.h" +#include "ggml-impl.h" + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include +# include +#else +# include +# include +# include +# include +# include +# include +# include +#endif +#include +#include +#include + +#ifdef GGML_RPC_RDMA +# include +# include +# ifndef _WIN32 +# include +# endif +#endif // GGML_RPC_RDMA + +#ifdef _WIN32 +typedef SOCKET sockfd_t; +using ssize_t = __int64; +#else +typedef int sockfd_t; +#endif + +static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); + +#define LOG_DBG(...) \ + do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0) + +#ifdef GGML_RPC_RDMA +static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock) +static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB +static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes +using rdma_gid_t = std::array; + +struct rdma_conn { + struct ibv_context * ctx = nullptr; + struct ibv_pd * pd = nullptr; + struct ibv_cq * scq = nullptr; // send completions + struct ibv_cq * rcq = nullptr; // recv completions + struct ibv_qp * qp = nullptr; + + void * tx_buf = nullptr; + struct ibv_mr * tx_mr = nullptr; + + void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous + struct ibv_mr * rx_mr = nullptr; + int rx_head = 0; + + uint32_t max_inline = 0; + + uint8_t * rx_slot(int i) const { + return static_cast(rx_buf) + static_cast(i) * RDMA_CHUNK; + } + + bool post_rx(int i) { + struct ibv_sge sge = {}; + sge.addr = (uintptr_t)rx_slot(i); + sge.length = RDMA_CHUNK; + sge.lkey = rx_mr->lkey; + struct ibv_recv_wr wr = {}, * bad = nullptr; + wr.wr_id = (uint64_t)i; + wr.sg_list = &sge; + wr.num_sge = 1; + return ibv_post_recv(qp, &wr, &bad) == 0; + } + + ~rdma_conn() { + if (tx_mr) ibv_dereg_mr(tx_mr); + if (rx_mr) ibv_dereg_mr(rx_mr); + free(tx_buf); + free(rx_buf); + if (qp) ibv_destroy_qp(qp); + if (scq) ibv_destroy_cq(scq); + if (rcq) ibv_destroy_cq(rcq); + if (pd) ibv_dealloc_pd(pd); + if (ctx) ibv_close_device(ctx); + } +}; + +// Local RDMA parameters captured during the probe phase and later consumed +// by rdma_activate() after the remote side's caps arrive via HELLO. +struct rdma_local_info { + uint32_t qpn = 0; + uint32_t psn = 0; + uint8_t gid[RDMA_GID_SIZE] = {}; + uint8_t ib_port = 0; + int gid_idx = 0; + enum ibv_mtu path_mtu = IBV_MTU_1024; +}; + +struct rdma_caps { + uint32_t qpn; + uint32_t psn; + uint8_t gid[RDMA_GID_SIZE]; +}; + +static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size"); + +#endif // GGML_RPC_RDMA + +struct socket_t::impl { + impl(sockfd_t fd) : use_rdma(false), fd(fd) {} + ~impl(); + bool send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + +#ifdef GGML_RPC_RDMA + bool tcp_peer_closed(); + std::optional rdma_build_target_gid(); + bool rdma_probe(); + bool rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid); + bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc); + bool rdma_send(const void * data, size_t size); + bool rdma_recv(void * data, size_t size); + + std::unique_ptr rdma; + rdma_local_info rdma_local = {}; +#endif // GGML_RPC_RDMA + bool use_rdma; + sockfd_t fd; +}; + +socket_t::impl::~impl() { +#ifdef GGML_RPC_RDMA + rdma.reset(); +#endif // GGML_RPC_RDMA + LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); +#ifdef _WIN32 + if (fd != INVALID_SOCKET) closesocket(this->fd); +#else + if (fd >= 0) close(this->fd); +#endif +} + +#ifdef GGML_RPC_RDMA + +bool socket_t::impl::tcp_peer_closed() { + if (fd < 0) return false; +#ifndef _WIN32 + struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 }; + int r = poll(&pfd, 1, 0); + return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP)); +#else + return false; +#endif +} + +// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address. +// Used to match the socket's local IP against the kernel's GID table so that +// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly: +// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4) +// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape) +// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is +// Returns std::nullopt on unsupported family or getsockname failure. +std::optional socket_t::impl::rdma_build_target_gid() { + sockaddr_storage addr = {}; + socklen_t addr_len = sizeof(addr); + if (getsockname(fd, reinterpret_cast(&addr), &addr_len) != 0) { + return std::nullopt; + } + rdma_gid_t target = {}; + if (addr.ss_family == AF_INET) { + const auto * a = reinterpret_cast(&addr); + target[10] = 0xff; + target[11] = 0xff; + memcpy(&target[12], &a->sin_addr, 4); + return target; + } + if (addr.ss_family == AF_INET6) { + const auto * a = reinterpret_cast(&addr); + memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE); + return target; + } + return std::nullopt; +} + +bool socket_t::impl::rdma_probe() { + const char * dev_env = std::getenv("GGML_RDMA_DEV"); + const char * gid_env = std::getenv("GGML_RDMA_GID"); + + auto target_gid = rdma_build_target_gid(); + if (!target_gid) { + return false; + } + + const uint8_t ib_port = 1; + int num_devs = 0; + ibv_device ** devs = ibv_get_device_list(&num_devs); + if (!devs || num_devs == 0) return false; + + ibv_context * ibctx = nullptr; + const char * matched_dev = nullptr; + int gid_idx = gid_env ? atoi(gid_env) : -1; + int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB + + for (int d = 0; d < num_devs; d++) { + const char * dn = ibv_get_device_name(devs[d]); + if (dev_env && strcmp(dev_env, dn) != 0) continue; + + ibv_context * ctx = ibv_open_device(devs[d]); + if (!ctx) continue; + + ibv_port_attr pa; + if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; } + + int found_gid = gid_idx; + int found_version = IBV_GID_TYPE_IB; + if (found_gid < 0) { + // Find a GID on this port whose bytes equal the local TCP address + // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1 + // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths + // are avoided. ibv_query_gid_ex returns gid+type in one call. + int v2_idx = -1; + int v1_idx = -1; + for (int i = 0; i < pa.gid_tbl_len; i++) { + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue; + if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue; + if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) { + v2_idx = i; + } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) { + v1_idx = i; + } + } + if (v2_idx >= 0) { + found_gid = v2_idx; + found_version = IBV_GID_TYPE_ROCE_V2; + } else if (v1_idx >= 0) { + found_gid = v1_idx; + found_version = IBV_GID_TYPE_ROCE_V1; + } + } else { + // Explicit GID index from GGML_RDMA_GID — fetch its type for logging. + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) { + found_version = entry.gid_type; + } + } + if (found_gid >= 0) { + ibctx = ctx; + gid_idx = found_gid; + gid_version = found_version; + matched_dev = dn; + rdma_local.path_mtu = pa.active_mtu; + break; + } + ibv_close_device(ctx); + } + ibv_free_device_list(devs); + if (!ibctx) return false; + + rdma_local.ib_port = ib_port; + rdma_local.gid_idx = gid_idx; + + rdma = std::make_unique(); + rdma->ctx = ibctx; + + rdma->pd = ibv_alloc_pd(ibctx); + if (!rdma->pd) return false; + + rdma->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0); + rdma->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0); + if (!rdma->scq || !rdma->rcq) return false; + + ibv_qp_init_attr qia = {}; + qia.send_cq = rdma->scq; + qia.recv_cq = rdma->rcq; + qia.qp_type = IBV_QPT_RC; + qia.cap.max_send_wr = 4; + qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4; + qia.cap.max_send_sge = 1; + qia.cap.max_recv_sge = 1; + qia.cap.max_inline_data = 256; + + rdma->qp = ibv_create_qp(rdma->pd, &qia); + if (!rdma->qp) return false; + rdma->max_inline = qia.cap.max_inline_data; + + rdma->tx_buf = aligned_alloc(4096, RDMA_CHUNK); + rdma->rx_buf = aligned_alloc(4096, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK); + if (!rdma->tx_buf || !rdma->rx_buf) return false; + + rdma->tx_mr = ibv_reg_mr(rdma->pd, rdma->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE); + rdma->rx_mr = ibv_reg_mr(rdma->pd, rdma->rx_buf, static_cast(RDMA_RX_DEPTH) * RDMA_CHUNK, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + if (!rdma->tx_mr || !rdma->rx_mr) return false; + + ibv_gid local_gid; + if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return false; + + rdma_local.qpn = rdma->qp->qp_num; + rdma_local.psn = rdma->qp->qp_num & 0xffffff; + memcpy(&rdma_local.gid, &local_gid, RDMA_GID_SIZE); + + const char * ver_str = ""; + if (gid_version == IBV_GID_TYPE_ROCE_V2) { + ver_str = " RoCEv2"; + } else if (gid_version == IBV_GID_TYPE_ROCE_V1) { + ver_str = " RoCEv1"; + } + GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n", + matched_dev, gid_idx, ver_str, rdma_local.qpn, rdma->max_inline); + return true; +} + +// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS. +// On success, the connection is live and ready for rdma_send/rdma_recv. +bool socket_t::impl::rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) { + // RESET -> INIT + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_INIT; + a.port_num = rdma_local.ib_port; + a.pkey_index = 0; + a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + return false; + } + } + + for (int i = 0; i < RDMA_RX_DEPTH; i++) { + if (!rdma->post_rx(i)) return false; + } + + // INIT -> RTR + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTR; + a.path_mtu = rdma_local.path_mtu; + a.dest_qp_num = remote_qpn; + a.rq_psn = remote_psn; + a.max_dest_rd_atomic = 1; + a.min_rnr_timer = 1; + a.ah_attr.is_global = 1; + memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE); + a.ah_attr.grh.hop_limit = 1; + a.ah_attr.grh.sgid_index = rdma_local.gid_idx; + a.ah_attr.dlid = 0; + a.ah_attr.port_num = rdma_local.ib_port; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | + IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) { + return false; + } + } + + // RTR -> RTS + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTS; + a.timeout = 14; + a.retry_cnt = 7; + a.rnr_retry = 7; + a.sq_psn = rdma_local.psn; + a.max_rd_atomic = 1; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | + IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) { + return false; + } + } + + GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n", + rdma_local.qpn, remote_qpn, 128 << rdma_local.path_mtu, RDMA_RX_DEPTH); + return true; +} + +bool socket_t::impl::rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc) { + for (uint64_t s = 0; ; s++) { + int n = ibv_poll_cq(cq, 1, wc); + if (n > 0) { + if (wc->status != IBV_WC_SUCCESS) { + GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n", + wc->status, ibv_wc_status_str(wc->status), wc->vendor_err); + } + return wc->status == IBV_WC_SUCCESS; + } + if (n < 0) return false; + if ((s & 0xFFFFF) == 0 && s > 0) { + if (tcp_peer_closed()) { + return false; + } + } + } +} + +bool socket_t::impl::rdma_send(const void * data, size_t size) { + rdma_conn * c = rdma.get(); + const uint8_t * src = (const uint8_t *)data; + size_t rem = size; + while (rem > 0) { + size_t chunk = std::min(rem, RDMA_CHUNK); + + struct ibv_sge sge = {}; + struct ibv_send_wr wr = {}, * bad = nullptr; + wr.opcode = IBV_WR_SEND; + wr.sg_list = &sge; + wr.num_sge = 1; + + if (chunk <= c->max_inline) { + sge.addr = (uintptr_t)src; + sge.length = chunk; + wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE; + } else { + memcpy(c->tx_buf, src, chunk); + sge.addr = (uintptr_t)c->tx_buf; + sge.length = chunk; + sge.lkey = c->tx_mr->lkey; + wr.send_flags = IBV_SEND_SIGNALED; + } + + if (ibv_post_send(c->qp, &wr, &bad) != 0) return false; + struct ibv_wc wc; + if (!rdma_poll(c->scq, &wc)) return false; + + src += chunk; + rem -= chunk; + } + return true; +} + +bool socket_t::impl::rdma_recv(void * data, size_t size) { + rdma_conn * c = rdma.get(); + uint8_t * dst = (uint8_t *)data; + size_t rem = size; + while (rem > 0) { + struct ibv_wc wc; + if (!rdma_poll(c->rcq, &wc)) return false; + + int slot = (int)wc.wr_id; + size_t got = wc.byte_len; + memcpy(dst, c->rx_slot(slot), got); + + if (!c->post_rx(slot)) return false; + + dst += got; + rem -= got; + } + return true; +} + +#endif // GGML_RPC_RDMA + +bool socket_t::impl::send_data(const void * data, size_t size) { +#ifdef GGML_RPC_RDMA + if (use_rdma) { + return rdma_send(data, size); + } +#endif + size_t bytes_sent = 0; + while (bytes_sent < size) { + size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE); + ssize_t n = send(fd, (const char *)data + bytes_sent, size_to_send, 0); + if (n < 0) { + GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n", + bytes_sent, size_to_send); + return false; + } + bytes_sent += (size_t)n; + } + return true; +} + +bool socket_t::impl::recv_data(void * data, size_t size) { +#ifdef GGML_RPC_RDMA + if (use_rdma) { + return rdma_recv(data, size); + } +#endif + size_t bytes_recv = 0; + while (bytes_recv < size) { + size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE); + ssize_t n = recv(fd, (char *)data + bytes_recv, size_to_recv, 0); + if (n < 0) { + GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n", + bytes_recv, size_to_recv); + return false; + } + if (n == 0) { + LOG_DBG("recv returned 0 (peer closed?)\n"); + return false; + } + bytes_recv += (size_t)n; + } + return true; +} + +void socket_t::impl::get_caps(uint8_t * local_caps) { + memset(local_caps, 0, RPC_CONN_CAPS_SIZE); +#ifdef GGML_RPC_RDMA + rdma_local = {}; + if (rdma_probe()) { + rdma_caps rc = {}; + rc.qpn = rdma_local.qpn; + rc.psn = rdma_local.psn; + memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE); + memcpy(local_caps, &rc, sizeof(rc)); + } else { + rdma.reset(); + } +#endif // GGML_RPC_RDMA +} + +void socket_t::impl::update_caps(const uint8_t * remote_caps) { +#ifdef GGML_RPC_RDMA + if (!rdma) { + return; + } + rdma_caps rc = {}; + memcpy(&rc, remote_caps, sizeof(rc)); + if (rc.qpn == 0) { + rdma.reset(); + return; + } + if (rdma_activate(rc.qpn, rc.psn, rc.gid)) { + use_rdma = true; + } else { + GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n"); + rdma.reset(); + } +#else + (void)remote_caps; +#endif // GGML_RPC_RDMA +} + + +///////////////////////////////////////////////////////////////////////////// + +socket_t::socket_t(std::unique_ptr p) : pimpl(std::move(p)) {} + +socket_t::~socket_t() = default; + +bool socket_t::send_data(const void * data, size_t size) { + return pimpl->send_data(data, size); +} + +bool socket_t::recv_data(void * data, size_t size) { + return pimpl->recv_data(data, size); +} + +void socket_t::get_caps(uint8_t * local_caps) { + return pimpl->get_caps(local_caps); +} + +void socket_t::update_caps(const uint8_t * remote_caps) { + return pimpl->update_caps(remote_caps); +} + +static bool is_valid_fd(sockfd_t sockfd) { +#ifdef _WIN32 + return sockfd != INVALID_SOCKET; +#else + return sockfd >= 0; +#endif +} + +static bool set_no_delay(sockfd_t sockfd) { + int flag = 1; + // set TCP_NODELAY to disable Nagle's algorithm + int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); + return ret == 0; +} + +static bool set_reuse_addr(sockfd_t sockfd) { + int flag = 1; + int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); + return ret == 0; +} + +socket_ptr socket_t::accept() { + auto client_socket_fd = ::accept(pimpl->fd, NULL, NULL); + if (!is_valid_fd(client_socket_fd)) { + return nullptr; + } + if (!set_no_delay(client_socket_fd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(client_socket_fd))); +} + +socket_ptr socket_t::create_server(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_reuse_addr(sockfd)) { + GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n"); + return nullptr; + } + if (inet_addr(host) == INADDR_NONE) { + GGML_LOG_ERROR("Invalid host address: %s\n", host); + return nullptr; + } + struct sockaddr_in serv_addr; + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = inet_addr(host); + serv_addr.sin_port = htons(port); + + if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { + return nullptr; + } + if (listen(sockfd, 1) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(sockfd))); +} + +socket_ptr socket_t::connect(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_no_delay(sockfd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + struct hostent * server = gethostbyname(host); + if (server == NULL) { + GGML_LOG_ERROR("Cannot resolve host '%s'\n", host); + return nullptr; + } + memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); + if (::connect(sockfd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique(sockfd))); +} + +#ifdef _WIN32 +static std::mutex g_rpc_transport_mu; +static bool g_rpc_transport_wsa_started = false; +#endif + +bool rpc_transport_init() { +#ifdef _WIN32 + std::lock_guard lock(g_rpc_transport_mu); + if (g_rpc_transport_wsa_started) { + return true; + } + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + return false; + } + g_rpc_transport_wsa_started = true; + return true; +#else + return true; +#endif +} + +void rpc_transport_shutdown() { +#ifdef _WIN32 + std::lock_guard lock(g_rpc_transport_mu); + if (!g_rpc_transport_wsa_started) { + return; + } + WSACleanup(); + g_rpc_transport_wsa_started = false; +#endif +} diff --git a/ggml/src/ggml-rpc/transport.h b/ggml/src/ggml-rpc/transport.h new file mode 100644 index 00000000000..73b85cc530a --- /dev/null +++ b/ggml/src/ggml-rpc/transport.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include +#include + +struct socket_t; +typedef std::shared_ptr socket_ptr; + +static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB +static constexpr size_t RPC_CONN_CAPS_SIZE = 24; + +struct socket_t { + ~socket_t(); + + bool send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + + socket_ptr accept(); + + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + + static socket_ptr create_server(const char * host, int port); + static socket_ptr connect(const char * host, int port); + +private: + struct impl; + explicit socket_t(std::unique_ptr p); + std::unique_ptr pimpl; +}; + +bool rpc_transport_init(); +void rpc_transport_shutdown(); From 171f037fbaef10c7901018a2be91e85764d581c2 Mon Sep 17 00:00:00 2001 From: texasich <101962694+texasich@users.noreply.github.com> Date: Sun, 19 Apr 2026 02:25:05 -0500 Subject: [PATCH 455/831] cmake: remove CMP0194 policy to restore MSVC builds (llama/21934) #21630 added the CMP0194 NEW policy to silence a CMake warning, but on Windows runners it caused CMake to prefer the MinGW toolchain for ASM and broke MSVC builds. Reverting only that policy block restores the previous working behavior. The CMake 4.1+ warning comes back, but that is cosmetic and does not break any platform. Reported-by: oobabooga Refs: #21630 Co-authored-by: texasich --- ggml/CMakeLists.txt | 6 ------ 1 file changed, 6 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 6b65ecd6e5c..a0eb9204eab 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -1,11 +1,5 @@ cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories. -# ref: https://cmake.org/cmake/help/latest/policy/CMP0194.html -# MSVC is not a valid assembler for the ASM language. -# Set to NEW to avoid a warning on CMake 4.1+ with MSVC. -if (POLICY CMP0194) - cmake_policy(SET CMP0194 NEW) -endif() project("ggml" C CXX ASM) ### GGML Version From 671fd1527a4aeb1b186d54302d42a8f5451feb82 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sun, 19 Apr 2026 15:18:35 +0530 Subject: [PATCH 456/831] ggml : reduce CPU overhead in meta backend (llama/22041) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cache subgraph splits when cgraph is unchanged Skip per-call subgraph construction in ggml_backend_meta_graph_compute when the same ggml_cgraph is used consecutively. Assign uid to every sub-graph so that CUDA's fast uid check path hits too. * Address review comments * Keep the scope as is * Rename last_uid and last_n_subgraphs field. Remove last_max_tmp_size field. Refactor code. * Address review comments * Update ggml/src/ggml-backend-meta.cpp Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-backend-meta.cpp Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-backend-meta.cpp | 307 +++++++++++++++++---------------- 1 file changed, 160 insertions(+), 147 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 24f6bc0639d..39651adc1c1 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1456,6 +1456,8 @@ struct ggml_backend_meta_context { int max_nnodes = 0; size_t max_tmp_size = 0; size_t max_subgraphs = 0; + size_t n_subgraphs = 0; + uint64_t uid = 0; void * comm_ctx = nullptr; ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr; @@ -1616,6 +1618,9 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, const size_t n_backends = ggml_backend_meta_n_backends(backend); ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; + // If the previous cgraph had a defined UID it can be used to skip rebuilding the subgraphs per simple backend. + const bool needs_rebuild = (cgraph->uid == 0) || (cgraph->uid != backend_ctx->uid); + bool max_nnodes_raised = false; if (cgraph->n_nodes > backend_ctx->max_nnodes) { for (size_t j = 0; j < n_backends; j++) { @@ -1625,173 +1630,181 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, } backend_ctx->max_nnodes = cgraph->n_nodes; max_nnodes_raised = true; + assert(needs_rebuild); } - for (size_t j = 0; j < n_backends; j++) { - auto & bcj = backend_ctx->backend_configs[j]; - - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { - // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes. - // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash. - bcj.nodes[i] = node; - continue; + + if (needs_rebuild) { + size_t n_subgraphs = 0; + size_t max_tmp_size = 0; + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes. + // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash. + bcj.nodes[i] = node; + continue; + } + bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j); + GGML_ASSERT(bcj.nodes[i]); } - bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j); - GGML_ASSERT(bcj.nodes[i]); } - } - size_t n_subgraphs = 0; - size_t max_tmp_size = 0; - { - // For MoE models it may make sense to delay the AllReduce in order to reduce I/O: - auto get_i_delayed = [&](const int i) -> int { - int id = i; // i_delayed - int idr = i; // i_delayed return, last safe return value - - ggml_tensor * node = cgraph->nodes[id]; - int32_t n_used = ggml_node_get_use_count(cgraph, id); - if (id + 1 >= cgraph->n_nodes) { - return idr; - } - { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op == GGML_OP_ADD_ID && next->src[0] == node && - ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL && - ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - node = next; + { + // For MoE models it may make sense to delay the AllReduce in order to reduce I/O: + auto get_i_delayed = [&](const int i) -> int { + int id = i; // i_delayed + int idr = i; // i_delayed return, last safe return value + + ggml_tensor * node = cgraph->nodes[id]; + int32_t n_used = ggml_node_get_use_count(cgraph, id); + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_ADD_ID && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL && + ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } + } + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_MUL && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } + } + + if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) { + return idr; + } + for (int32_t k = 0; k < n_used; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] || + next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] || + ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } id++; - idr = id; - n_used = ggml_node_get_use_count(cgraph, id); } - } - if (id + 1 >= cgraph->n_nodes) { - return idr; - } - { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op == GGML_OP_MUL && next->src[0] == node && - ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - node = next; + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } id++; - idr = id; - n_used = ggml_node_get_use_count(cgraph, id); } - } - - if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) { + for (int32_t k = 0; k < n_used - 2; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + idr = id; return idr; - } - for (int32_t k = 0; k < n_used; k++) { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] || - next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] || - ggml_node_get_use_count(cgraph, id+1) != 1) { - return idr; + }; + + int i_start = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + continue; } - id++; - } - { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] || - next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { - return idr; + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { + max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); } - id++; - } - for (int32_t k = 0; k < n_used - 2; k++) { - ggml_tensor * next = cgraph->nodes[id+1]; - if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] || - next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { - return idr; + const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; + if (!new_subgraph) { + continue; } - id++; - } - idr = id; - return idr; - }; - - int i_start = 0; - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; - if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { - continue; - } - const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); - if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { - max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); - } - const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; - if (!new_subgraph) { - continue; + + i = get_i_delayed(i); + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.cgraphs[n_subgraphs].offset = i_start; + } + n_subgraphs++; + i_start = i + 1; } + GGML_ASSERT(i_start == cgraph->n_nodes); + } - i = get_i_delayed(i); + backend_ctx->uid = cgraph->uid; + backend_ctx->n_subgraphs = n_subgraphs; + if (max_tmp_size > backend_ctx->max_tmp_size) { for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; - bcj.cgraphs[n_subgraphs].offset = i_start; + bcj.buf.reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); + } + backend_ctx->max_tmp_size = max_tmp_size; + } + + if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { + backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); + const size_t n_reduce_steps = backend_ctx->n_reduce_steps(); + const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step + const size_t n_cgraphs_per_device = n_reduce_steps; // 1 ADD graph per step + const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); + const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); + const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); + ggml_init_params params = { + /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + backend_ctx->ctx.reset(ggml_init(params)); + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i = 0; i < n_subgraphs; i++) { + bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false); + } + } + backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) { + backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads); + } + backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) { + backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1); } - n_subgraphs++; - i_start = i + 1; } - GGML_ASSERT(i_start == cgraph->n_nodes); - } - if (max_tmp_size > backend_ctx->max_tmp_size) { for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; - bcj.buf.reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); - } - backend_ctx->max_tmp_size = max_tmp_size; - } - - - if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { - backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); - const size_t n_reduce_steps = backend_ctx->n_reduce_steps(); - const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step - const size_t n_cgraphs_per_device = n_reduce_steps; // 1 ADD graph per step - const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); - const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); - const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); - ggml_init_params params = { - /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - backend_ctx->ctx.reset(ggml_init(params)); - for (size_t j = 0; j < n_backends; j++) { - auto & bcj = backend_ctx->backend_configs[j]; - for (size_t i = 0; i < n_subgraphs; i++) { - bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false); - } - } - backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs); - for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) { - backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads); - } - backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs); - for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) { - backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1); - } - } - - for (size_t j = 0; j < n_backends; j++) { - auto & bcj = backend_ctx->backend_configs[j]; - for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) { - ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main; - const size_t i_node_start = bcj.cgraphs[i_graph].offset; - const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes; - cgraph_ij->n_nodes = i_node_stop - i_node_start; - ggml_hash_set_reset(&cgraph_ij->visited_hash_set); - for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) { - ggml_tensor * node_ij = bcj.nodes[i_node]; - cgraph_ij->nodes[i_node - i_node_start] = node_ij; - const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]); - const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij); - cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig]; + for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) { + ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main; + const size_t i_node_start = bcj.cgraphs[i_graph].offset; + const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes; + cgraph_ij->n_nodes = i_node_stop - i_node_start; + ggml_hash_set_reset(&cgraph_ij->visited_hash_set); + for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) { + ggml_tensor * node_ij = bcj.nodes[i_node]; + cgraph_ij->nodes[i_node - i_node_start] = node_ij; + const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]); + const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij); + cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig]; + } + cgraph_ij->uid = ggml_graph_next_uid(); } } } @@ -1898,7 +1911,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, }; - for (size_t i = 0; i < n_subgraphs; i++) { + for (size_t i = 0; i < backend_ctx->n_subgraphs; i++) { for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, bcj.cgraphs[i].cgraph_main); @@ -1907,7 +1920,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, } } - if (n_backends > 1 && i < n_subgraphs - 1) { + if (n_backends > 1 && i < backend_ctx->n_subgraphs - 1) { bool backend_allreduce_success = false; if (backend_ctx->comm_ctx) { std::vector nodes; From 945746b40c2fe7983b02f00fb0c653476334a7bc Mon Sep 17 00:00:00 2001 From: uvos Date: Sun, 19 Apr 2026 12:59:44 +0200 Subject: [PATCH 457/831] HIP: Remove unesscary NCCL_CHECK (llama/21914) --- ggml/src/ggml-cuda/vendors/hip.h | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 898fec31e36..52c38908e06 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -33,7 +33,6 @@ #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} -#define NCCL_CHECK(fn) {ncclResult_t err = fn; if(err != ncclSuccess) { GGML_ABORT("RCCL Failure RCCL returned: %i\n", err); }} #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) #define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) From b8f57c9c50e389bd4e21e3ebc0c9db4506bf2e2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 19 Apr 2026 18:26:59 +0200 Subject: [PATCH 458/831] CUDA: refactor mma data loading for AMD (llama/22051) * CUDA: refactor mma data loading for AMD * fix CDNA MMQ occupancy * fix CDNA3 mma * fix RDNA3 compile --- ggml/src/ggml-cuda/common.cuh | 4 - ggml/src/ggml-cuda/fattn-mma-f16.cuh | 57 ++----- ggml/src/ggml-cuda/mma.cuh | 245 +++++++++------------------ ggml/src/ggml-cuda/mmq.cuh | 201 ++-------------------- 4 files changed, 112 insertions(+), 395 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ddf50baf495..3aec1742ee1 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -269,10 +269,6 @@ static const char * cu_get_error_str(CUresult err) { #define FLASH_ATTN_AVAILABLE #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) -#if defined(TURING_MMA_AVAILABLE) -#define LDMATRIX_TRANS_AVAILABLE -#endif // defined(TURING_MMA_AVAILABLE) - static bool fp16_available(const int cc) { return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index b613ae61fb8..e185449d491 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -305,12 +305,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { constexpr int warp_size = ggml_cuda_get_physical_warp_size(); // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. + // The minimum granularity is 16 bytes. + constexpr int h2_per_chunk = 16/sizeof(half2); + const int chunks_per_row = D2 / h2_per_chunk; if constexpr (use_cp_async) { + static_assert(warp_size == 32, "bad warp_size"); static_assert(!oob_check, "OOB check not compatible with cp_async"); constexpr int preload = 64; - constexpr int h2_per_chunk = 16/sizeof(half2); - const int chunks_per_row = D2 / h2_per_chunk; const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); @@ -348,11 +349,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( // 6: max 1*16= 16 bytes, 8 half ggml_cuda_unroll<6>{}(load); } else { - // TODO use ggml_cuda_memcpy_1 + const half2 zero[4] = {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}}; auto load = [&] __device__ (const int n) { - const int stride_k = warp_size >> n; - const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k); - const int k0_stop = D2 - D2 % (1*stride_k); + const int stride_k = 32 >> n; + const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); const int stride_i = warp_size / stride_k; if (k0_start == k0_stop) { @@ -371,15 +372,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); + ggml_cuda_memcpy_1<16>(tile_KV + i*stride_tile + k*4, + !oob_check || i < i_sup ? KV + i*stride_KV + k*h2_per_chunk : zero); } } }; - // 1: max 32* 4=128 bytes, 64 half - // 2: max 16* 4= 64 bytes, 32 half - // 3: max 8* 4= 32 bytes, 16 half - // 4: max 4* 4= 16 bytes, 8 half - ggml_cuda_unroll<4>{}(load); + // 1: max 32*16=512 bytes, 256 half + // 2: max 16*16=256 bytes, 128 half + // 3: max 8*16=128 bytes, 64 half + // 4: max 4*16= 64 bytes, 32 half + // 5: max 2*16= 32 bytes, 16 half + // 6: max 1*16= 16 bytes, 8 half + ggml_cuda_unroll<6>{}(load); } } @@ -862,11 +866,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } -#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) - T_A_VKQ A_identity; - make_identity_mat(A_identity); -#endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) - // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) { @@ -897,29 +896,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J; T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. -#if defined(LDMATRIX_TRANS_AVAILABLE) load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); -#elif defined(AMD_MFMA_AVAILABLE) - // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg]. - // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T. - // Load with transposed addressing: 4 strided half loads. - { - const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2; - const half * xs0_h = (const half *) xs0; - const int stride_h = stride_tile_V * 2; // stride in half units - half * A_h = (half *) A.x; -#pragma unroll - for (int l = 0; l < 4; ++l) { - A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16]; - } - } -#else - // TODO: Try to transpose tile_V when loading gmem to smem. - // Use mma to transpose T_A_VKQ for RDNA. - T_A_VKQ A_trans; - load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); - mma(A, A_trans, A_identity); -#endif // defined(LDMATRIX_TRANS_AVAILABLE) if constexpr (T_B_KQ::I == 8) { mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index c91dd2d9ad6..b0f674635f1 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -86,17 +86,12 @@ namespace ggml_cuda_mma { // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR - static constexpr bool is_i_major(const data_layout dl) { - return dl == DATA_LAYOUT_I_MAJOR || - dl == DATA_LAYOUT_I_MAJOR_MIRRORED; - } - static constexpr __device__ data_layout get_input_data_layout() { -#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE) return DATA_LAYOUT_I_MAJOR_MIRRORED; #else return DATA_LAYOUT_I_MAJOR; -#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE) } template @@ -113,7 +108,6 @@ namespace ggml_cuda_mma { T x[ne] = {0}; static constexpr __device__ bool supported() { - if (I == 64 && J == 2) return true; if (I == 16 && J == 8) return true; if (I == 32 && J == 4) return true; if (I == 16 && J == 16) return true; @@ -122,7 +116,7 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + if constexpr (I == 16 && J == 4) { return threadIdx.x % 16; } else if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; @@ -139,8 +133,8 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> - return (2 * ((threadIdx.x / 16) % 2) + l); + if constexpr (I == 16 && J == 4) { + return threadIdx.x / 16; } else if constexpr (I == 16 && J == 8) { return 2 * (threadIdx.x / 16) + l; } else if constexpr (I == 32 && J == 4) { @@ -154,7 +148,7 @@ namespace ggml_cuda_mma { return -1; } } -#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#elif defined(VOLTA_MMA_AVAILABLE) static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -283,7 +277,7 @@ namespace ggml_cuda_mma { static constexpr int J = J_; static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; @@ -407,7 +401,7 @@ namespace ggml_cuda_mma { return -1; } } -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) }; template @@ -701,57 +695,12 @@ namespace ggml_cuda_mma { } #endif // defined(TURING_MMA_AVAILABLE) - static __device__ __forceinline__ void make_identity_mat(tile<16, 8, half2> & t) { -#if defined(RDNA4) - const int row = t.get_i(0); - const int left_right = t.get_j(0) / 4; - const int up_down = row / 8; - const int idx = row % 8; - reinterpret_cast(t.x)[idx] = left_right == up_down ? 1.0f : 0.0f; -#else - GGML_UNUSED_VARS(t); - NO_DEVICE_CODE; -#endif // defined(RDNA4) - } - template static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { -#if defined(AMD_MFMA_AVAILABLE) - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> -#pragma unroll - for (int l = 0; l < t.ne; ++l) { - t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; - } - } else { - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - } -#elif defined(AMD_WMMA_AVAILABLE) - // All wmma layout has contiguous data when i-major. - if constexpr (is_i_major(dl)) { - // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes() - constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes(); - if constexpr (sizeof(t.x) > aligned_copy_bytes) { - static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size"); - constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes; -#pragma unroll - for (int i = 0; i < aligned_copy_count; ++i) { - ggml_cuda_memcpy_1(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i)); - } - } else { - ggml_cuda_memcpy_1(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - } - } else { -#pragma unroll - for (int l = 0; l < t.ne; ++l) { - t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; - } - } -#else #pragma unroll for (int l = 0; l < t.ne; ++l) { t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; } -#endif // defined(AMD_MFMA_AVAILABLE) } template @@ -764,26 +713,37 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); #else - load_generic(t, xs0, stride); + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } - template + template static __device__ __forceinline__ void load_ldmatrix( - tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { + tile<16, 4, T, dl> & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 + static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + static_assert(sizeof(t.x) == 16, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<8>(t.x + 2, xs0 + t.get_i(0)*stride + 2); +#else + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + static_assert(sizeof(t.x) == 8, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) + static_assert(sizeof(t.x) == 4, "bad ne"); + ggml_cuda_memcpy_1<4>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); #else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; -#else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #endif // TURING_MMA_AVAILABLE } @@ -796,19 +756,26 @@ namespace ggml_cuda_mma { asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) : "l"(xs)); -#else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA -#if 1 - // TODO: more generic handling - static_assert(sizeof(T) == 4, "bad type size"); +#elif defined(VOLTA_MMA_AVAILABLE) ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0); ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4); +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 + static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + static_assert(sizeof(t.x) == 32, "bad ne"); + ggml_cuda_memcpy_1<16>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<16>(t.x + 4, xs0 + t.get_i(0)*stride + 4); #else - load_generic(t, xs0, stride); -#endif // 1 + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + static_assert(sizeof(t.x) == 16, "bad ne"); + ggml_cuda_memcpy_1<16>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) + static_assert(sizeof(t.x) == 8, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); #else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -827,23 +794,30 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void load_ldmatrix( tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); #else GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } template static __device__ __forceinline__ void load_ldmatrix_trans( tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE - int * xi = (int * ) t.x; + int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); +#elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + half * xh = (half *) t.x; +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; + xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + } #else GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; @@ -1218,73 +1192,27 @@ namespace ggml_cuda_mma { using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * acc = (int32x4_t *) D.x; #if defined(CDNA4) || defined(CDNA3) - acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], - ((int64_t *) B.x)[0], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); #elif defined(CDNA2) || defined(CDNA1) - acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], - B.x[0], - acc[0], - 0, 0, 0); - acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], - B.x[1], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], B.x[1], acc[0], 0, 0, 0); #endif // defined(CDNA4) || defined(CDNA3) - #elif defined(AMD_WMMA_AVAILABLE) - using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; - #if defined(RDNA4) using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; int32x2_t * a_vec = (int32x2_t *) A.x; int32x2_t * b_vec = (int32x2_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - true - ); - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[1], - true, - b_vec[1], - acc[0], - true - ); - + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], true); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[1], true, b_vec[1], acc[0], true); #elif defined(RDNA3) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * a_vec = (int32x4_t *) A.x; int32x4_t * b_vec = (int32x4_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - true - ); - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[1], - true, - b_vec[1], - acc[0], - true - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], true); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[1], true, b_vec[1], acc[0], true); #endif // RDNA4 - #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -1297,19 +1225,10 @@ namespace ggml_cuda_mma { using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; int32x16_t * acc = (int32x16_t *) D.x; #if defined(CDNA4) || defined(CDNA3) - acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], - ((int64_t *) B.x)[0], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); #elif defined(CDNA2) || defined(CDNA1) - acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], - B.x[0], - acc[0], - 0, 0, 0); - acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], - B.x[1], - acc[0], - 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], B.x[0], acc[0], 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], B.x[1], acc[0], 0, 0, 0); #endif // defined(CDNA4) || defined(CDNA3) #else @@ -1329,7 +1248,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void mma( tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -1344,12 +1263,12 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } static __device__ __forceinline__ void mma( tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -1364,41 +1283,35 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } template static __device__ __forceinline__ void mma( tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) + using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * acc = (int32x4_t *) D.x; +#if defined(CDNA4) || defined(CDNA3) + const int64_t xA = uint32_t(A.x[0]); + const int64_t xB = uint32_t(B.x[0]); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(xA, xB, acc[0], 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA1) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0); +#endif // defined(CDNA4) || defined(CDNA3) +#elif defined(AMD_WMMA_AVAILABLE) using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; #if defined(RDNA4) using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; int32x2_t * a_vec = (int32x2_t *) A.x; int32x2_t * b_vec = (int32x2_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], false); #elif defined(RDNA3) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * a_vec = (int32x4_t *) A.x; int32x4_t * b_vec = (int32x4_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], false); #endif // RDNA4 #else GGML_UNUSED(D); diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 28b662df925..b1a319de9be 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -104,7 +104,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : + return (turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? #ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -114,9 +114,9 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 128; -#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#else // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) #if defined(GGML_USE_HIP) return 64; @@ -1054,13 +1054,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); float dB; const int j = j0 + tile_C::get_j(0); @@ -1295,13 +1295,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -1435,57 +1435,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -1510,13 +1460,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; @@ -1742,74 +1692,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; - const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 - : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y - : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); - - tile_C Cm; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tile_A A1; - A1.x[0] = 0x01010101; - A1.x[1] = 0x01010101; - mma(Cm, A1, B[0]); - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C Cd; - mma(Cd, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); - float tmp = Cd.x[l]*dm.x; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tmp -= Cm.x[l]*dm.y; - } - sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; - sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -1834,13 +1717,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y; @@ -2573,59 +2456,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -2651,13 +2482,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; From 931cf2f3a81af3f347cc01bed686d374c851a2e4 Mon Sep 17 00:00:00 2001 From: Katostrofik Date: Mon, 20 Apr 2026 01:39:45 -0400 Subject: [PATCH 459/831] Fix reorder MMVQ assert on unaligned vocab sizes (llama/22035) * [SYCL] Fix reorder MMVQ assert on unaligned vocab sizes The reorder mul_mat_vec_q dispatchers for Q4_0, Q8_0, Q4_K, and Q6_K asserted that block_num_y was a multiple of 16 subgroups. Models with a vocab size not divisible by 16 (for example HY-MT at 120818) aborted on model load when the output projection tripped the assert. I replaced the assert with padding: block_num_y now rounds up to a whole number of subgroup-sized workgroups. The kernel already has the row bounds check (`if (row >= nrows) return;`) so the extra padded threads early-exit cleanly. Row values are uniform across a subgroup so the collective reduce stays safe. For aligned vocab sizes the padded block_num_y equals the old value, so the kernel launch is identical and there is no regression. Thanks to @arthw for flagging the relationship to #21527. Fixes #22020. AI assisted coding, tested on Intel B70 hardware. * sycl: use WARP_SIZE for num_subgroups in reorder MMVQ launches Replaces the hardcoded 16 with WARP_SIZE in the four reorder_mul_mat_vec launch helpers (Q4_0, Q8_0, Q4_K, Q6_K). Compile-time no-op on the Intel target where WARP_SIZE is 16, but makes the relationship to subgroup size explicit. Per review by @NeoZhangJianyu on #22035. Assisted by Claude. --- ggml/src/ggml-sycl/mmvq.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index af22b98dddb..3a4577ecbbc 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -537,9 +537,9 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK4_0 == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -682,9 +682,9 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK8_0 == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -798,9 +798,9 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -842,9 +842,9 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); From 5f21fdcbb9400721732c161e4b04174e5d8625db Mon Sep 17 00:00:00 2001 From: neha-ha <137219201+neha-ha@users.noreply.github.com> Date: Mon, 20 Apr 2026 07:37:17 -0700 Subject: [PATCH 460/831] ggml-webgpu: updated matrix-vector multiplication (llama/21738) * merged properly, but slow q3_k and q5_k with u32 indexing * Start on new mat-vec * New format float paths working * Working q4_0 * Work on remaining legacy q-types * port k-quants to new matvec * remove old shader * Remove old constants, format * remove accidental file --------- Co-authored-by: Neha Abbas Co-authored-by: Reese Levine --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 34 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 28 +- .../wgsl-shaders/common_decls.tmpl | 7 + .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 1102 +++++++++++------ 4 files changed, 788 insertions(+), 383 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 7d9a4403fab..9d88f98050e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -44,18 +44,9 @@ // Matrix-vector multiplication parameters #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 -// Must be multiple of 4 to work with vectorized paths, and must divide -// mul_mat_vec wg size -#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K 256 - -#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K 256 - -// Requires 32 threads per output (wg_size/outputs_per_wg == 32) -#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 8 -// Requires at least two (and multiple of 2) k-quant blocks per tile -#define WEBGPU_MUL_MAT_VEC_K_Q_TILE_K 512 +#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4 +#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4 +#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4 // default size for legacy matrix multiplication #define WEBGPU_MUL_MAT_WG_SIZE 256 @@ -78,6 +69,7 @@ struct ggml_webgpu_shader_lib_context { bool inplace = false; bool overlap = false; bool src_overlap = false; + bool supports_subgroups = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; @@ -575,7 +567,6 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { struct ggml_webgpu_mul_mat_vec_shader_decisions { uint32_t wg_size; - uint32_t tile_k; uint32_t outputs_per_wg; uint32_t vec_size; }; @@ -1326,7 +1317,7 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_vec_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0; @@ -1337,7 +1328,8 @@ class ggml_webgpu_shader_lib { } std::vector defines; - std::string variant = "mul_mat_vec"; + std::string variant = "mul_mat_vec"; + const char * shader_src = wgsl_mul_mat_vec; // src0 type (matrix row) switch (context.src0->type) { @@ -1386,25 +1378,25 @@ class ggml_webgpu_shader_lib { defines.push_back(key.vectorized ? "VEC" : "SCALAR"); uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; - uint32_t tile_k = WEBGPU_MUL_MAT_VEC_FLOAT_TILE_K; uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; if (key.src0_type >= GGML_TYPE_Q2_K) { - tile_k = WEBGPU_MUL_MAT_VEC_K_Q_TILE_K; outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; } else if (key.src0_type >= GGML_TYPE_Q4_0) { - tile_k = WEBGPU_MUL_MAT_VEC_LEGACY_Q_TILE_K; outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; } defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; + if (key.vectorized) { + variant += "_vectorized"; + } - auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines); + auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); decisions->wg_size = wg_size; - decisions->tile_k = tile_k; decisions->outputs_per_wg = outputs_per_wg; decisions->vec_size = key.vectorized ? 4 : 1; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e7bda817a28..aa20a745e0a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -181,6 +181,7 @@ struct webgpu_dispatch_desc { struct webgpu_capabilities { wgpu::Limits limits; + bool supports_subgroups = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; @@ -1164,14 +1165,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_Q6_K: - use_fast = true; - break; - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: - // we don't have fast mat-vec for these types, but we do have (semi) fast mat-mat - use_fast = !is_vec; + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q2_K: + use_fast = true; break; default: break; @@ -1182,10 +1180,12 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, } ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = src0; - shader_lib_ctx.src1 = src1; - shader_lib_ctx.dst = dst; + + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; @@ -1287,7 +1287,8 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; // Get or create pipeline - webgpu_pipeline gather_pipeline, main_pipeline; + webgpu_pipeline gather_pipeline; + webgpu_pipeline main_pipeline; std::vector dispatches; @@ -3040,6 +3041,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->adapter.GetFeatures(&features); // we require f16 support GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); + ctx->webgpu_global_ctx->capabilities.supports_subgroups = + ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups); #ifndef __EMSCRIPTEN__ // Accept f16 subgroup matrix configurations (square or non-square). @@ -3072,11 +3075,14 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { #ifndef __EMSCRIPTEN__ required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { - required_features.push_back(wgpu::FeatureName::Subgroups); required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); } #endif + if (ctx->webgpu_global_ctx->capabilities.supports_subgroups) { + required_features.push_back(wgpu::FeatureName::Subgroups); + } + #ifdef GGML_WEBGPU_GPU_PROFILE required_features.push_back(wgpu::FeatureName::TimestampQuery); #endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 62fe72ee3b1..14c045b0ba6 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -45,6 +45,13 @@ fn load_u16_at_src0(byte_offset: u32) -> u32 { return (word >> shift) & 0xFFFFu; } +// Always reads the 4-byte-aligned word containing byte_offset. +// Caller extracts the 16-bit half it needs via & 0xFFFFu or >> 16u. +// this is used in k-quants for better performance +fn load_u32_at_src0_aligned(byte_offset: u32) -> u32 { + return src0[(byte_offset & ~3u) / 4u]; +} + fn load_u32_at_src0(byte_offset: u32) -> u32 { let word_idx = byte_offset / 4u; let shift = (byte_offset & 0x3u) * 8u; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 9f7b3e32eca..97c9f6d7a09 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -1,465 +1,865 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif enable f16; #define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" +#ifdef U32_DEQUANT_HELPERS +#define SRC0_TYPE u32 -#ifdef VEC +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} +#endif -#define VEC_SIZE 4 -#define DST_TYPE vec4 +#ifdef VEC +#define VEC_SIZE 4u #define SRC0_TYPE vec4 #define SRC1_TYPE vec4 fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { return f32(dot(SRC1_TYPE(src0_val), src1_val)); } - -fn store_val(group_base: u32) -> vec4 { - return vec4(partial_sums[group_base], - partial_sums[group_base + THREADS_PER_OUTPUT], - partial_sums[group_base + THREADS_PER_OUTPUT * 2], - partial_sums[group_base + THREADS_PER_OUTPUT * 3]); -} #endif #ifdef SCALAR - -#define VEC_SIZE 1 -#define DST_TYPE f32 +#define VEC_SIZE 1u #define SRC0_TYPE SRC0_INNER_TYPE #define SRC1_TYPE SRC1_INNER_TYPE fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { return f32(src0_val) * f32(src1_val); } +#endif + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; -fn store_val(group_base: u32) -> f32 { - return partial_sums[group_base]; +@group(0) @binding(0) var src0: array; +@group(0) @binding(1) var src1: array; +@group(0) @binding(2) var dst: array; + +@group(0) @binding(3) var params: MulMatParams; + +// Flattened as [row][thread] to keep each row's reduction contiguous in memory. +var partial_sums: array; + +fn partial_index(row: u32, thread: u32) -> u32 { + return row * WG_SIZE + thread; } + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3 +#ifdef USE_SUBGROUP_REDUCTION + , @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32 #endif +) { + let thread_id = local_id.x; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let batch_idx = wg_linear / output_groups; + if (batch_idx >= total_batches) { + return; + } + + let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG; + + let dst2_stride = params.m * params.n; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; + + var acc: array; #ifdef MUL_ACC_FLOAT -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) { - let a = src0[(idx_base + k_outer + i) / VEC_SIZE]; - let b = shared_vector[i / VEC_SIZE]; - local_sum += inner_dot(a, b); + let k_vec = params.k / VEC_SIZE; + let src1_idx_base_vec = src1_idx_base / VEC_SIZE; + + // Each thread walks K, loads from the vector, and updates + // a small block of output rows held in registers. + for (var k = thread_id; k < k_vec; k += WG_SIZE) { + let x = src1[src1_idx_base_vec + k]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; + acc[row] += inner_dot(src0[src0_idx], x); + } + } } - return local_sum; -} #endif #ifdef MUL_ACC_Q4_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % 4; + for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 18u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0) * d; - local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; + let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; } } } - return local_sum; -} #endif #ifdef MUL_ACC_Q4_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 20 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 20u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = f32(load_f16_at_src0(block_byte_base + 2u)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32((q_byte >> 4) & 0xF) * d + m; - let q_lo = f32(q_byte & 0xF) * d + m; - local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(q_byte & 0xFu) * d + m; + let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; } } } - return local_sum; -} #endif #ifdef MUL_ACC_Q5_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 22 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 22u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - let qh_packed = load_u32_at_src0(block_byte_base + 2u); - - for (var j = 0u; j < 2; j++) { - let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); - let q_packed = load_u32_at_src0(q_byte_offset); - - let j_adjusted = j + (block_offset / 2u); - - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; - let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; - let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; - let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; - - local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qh_packed = load_u32_at_src0(block_byte_base + 2u); + let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; + let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; } - } } - return local_sum; -} #endif - #ifdef MUL_ACC_Q5_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 24 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 24u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = load_f16_at_src0(block_byte_base + 2u); - let qh_packed = load_u32_at_src0(block_byte_base + 4u); - - for (var j = 0u; j < 2; j++) { - let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); - let q_packed = load_u32_at_src0(q_byte_offset); - - let j_adjusted = j + (block_offset / 2u); - - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; - let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + f32(m); - let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; - let q_lo = f32((q_byte & 0xF) | qh_lo) * d + f32(m); - - local_sum += q_lo * shared_vector[shmem_idx + j * 4 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 4 + k + 16]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + let qh_packed = load_u32_at_src0(block_byte_base + 4u); + let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; + let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; } - } } - return local_sum; -} #endif - #ifdef MUL_ACC_Q8_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 34 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 34u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 2u; -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d; - local_sum += q_val * shared_vector[shmem_idx + j * 2 + k]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; } } } - return local_sum; -} #endif - #ifdef MUL_ACC_Q8_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 36 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } -const BLOCK_SIZE = 32; -const BLOCK_SIZE_BYTES = 36u; -const NQ = 16u; // number of weights per thread -const WEIGHTS_PER_F16 = 2u; -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let block_byte_base = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * BLOCK_SIZE_BYTES; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = load_f16_at_src0(block_byte_base + 2u); - - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d + f32(m); - local_sum += q_val * shared_vector[shmem_idx + j * 2 + k]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; } } } - return local_sum; -} #endif -#ifdef MUL_ACC_Q6_K - -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 210u; - -fn byte_of(v: u32, b: u32) -> u32 { - return (v >> (b * 8u)) & 0xFFu; -} +#ifdef MUL_ACC_Q2_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 84 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let lane = tid / 2u; + let phase = tid % 2u; + let iq = lane / 4u; + let ir = lane % 4u; + let is = ir / 2u; + + let y_offset = 128u * iq + 8u * ir + 4u * phase; + let sc0_byte = 8u * iq + is; + let sc2_byte = 8u * iq + is + 2u; + let sc4_byte = 8u * iq + is + 4u; + let sc6_byte = 8u * iq + is + 6u; + let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 64u + i]); + x_block[i + 12u] = f32(src1[x_base + 96u + i]); + } -fn sbyte_of(v: u32, b: u32) -> i32 { - let raw = i32((v >> (b * 8u)) & 0xFFu); - return select(raw, raw - 256, raw >= 128); -} + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let dall = f32(load_f16_at_src0(block_byte_base + 80u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 82u)) * (1.0 / 16.0); + + let sc0 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc0_byte), sc0_byte & 3u); + let sc2 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc2_byte), sc2_byte & 3u); + let sc4 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc4_byte), sc4_byte & 3u); + let sc6 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc6_byte), sc6_byte & 3u); + + let q_u32 = load_u32_at_src0_aligned(block_byte_base + qs_byte); + let qs0 = q_u32 & 0xFFFFu; + let qs1 = q_u32 >> 16u; + + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + + sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; + sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; + sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; + sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; + + acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); + acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); + acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); + acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); + acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); + acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); + acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); + acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); + + acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } + } + } +#endif -fn mul_acc(tig: u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - let tid = tig / 2u; - let ix = tig % 2u; - let ip = tid / 8u; - let il = tid % 8u; - let l0 = 4u * il; - let is = 8u * ip + l0 / 16u; - let y_offset = 128u * ip + l0; - let q_offset_l = 64u * ip + l0; - let q_offset_h = 32u * ip + l0; +#ifdef MUL_ACC_Q3_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 - let nb = tile_size / BLOCK_SIZE; - let k_block_start = k_outer / BLOCK_SIZE; + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - // Aligned scale byte position (is can be odd) - let sc_base_byte = 192u + (is & ~3u); - let sc_byte_pos = is & 3u; + let lane = tid / 2u; + let phase = tid % 2u; + let ip = lane / 4u; + let il = 2u * ((lane % 4u) / 2u); + let ir = lane % 2u; + let l0 = 8u * ir; - var local_sum = 0.0; + let q_byte = 32u + 32u * ip + l0 + 16u * phase; + let h_byte = l0 + 16u * phase; + let y_offset = 128u * ip + 32u * il + l0 + 16u * phase; - for (var i = ix; i < nb; i += 2u) { - let bbase = (idx_base + k_block_start + i) * BLOCK_SIZE_BYTES; + let s_shift1 = 4u * ip; + let s_shift2 = s_shift1 + il; - let d = f32(load_f16_at_src0(bbase + 208u)); + let v1 = select(64.0, 4.0, il == 0u); + let v2 = 4.0 * v1; + let shift = 2u * il; - let ql1_u32 = load_u32_at_src0(bbase + q_offset_l); - let ql2_u32 = load_u32_at_src0(bbase + q_offset_l + 32u); - let qh_u32 = load_u32_at_src0(bbase + 128u + q_offset_h); - let sc_u32_0 = load_u32_at_src0(bbase + sc_base_byte); - let sc_u32_1 = load_u32_at_src0(bbase + sc_base_byte + 4u); + var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32; + if (il == 0u) { + qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u; + } else { + qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u; + } - let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); - let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); - let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); - let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); + let mm_idx = 2u * ip + il / 2u; + var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32; + switch (mm_idx) { + case 0u: { hm0=0x0001u; hm1=0x0100u; hm2=0x0002u; hm3=0x0200u; } + case 1u: { hm0=0x0004u; hm1=0x0400u; hm2=0x0008u; hm3=0x0800u; } + case 2u: { hm0=0x0010u; hm1=0x1000u; hm2=0x0020u; hm3=0x2000u; } + default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; } + } - var sums = vec4(0.0, 0.0, 0.0, 0.0); + let num_blocks = params.k / BLOCK_SIZE; - for (var l = 0u; l < 4u; l++) { - let y_base = i * BLOCK_SIZE + y_offset + l; - let yl0 = f32(shared_vector[y_base]); - let yl1 = f32(shared_vector[y_base + 32u]); - let yl2 = f32(shared_vector[y_base + 64u]); - let yl3 = f32(shared_vector[y_base + 96u]); - - let q1b = byte_of(ql1_u32, l); - let q2b = byte_of(ql2_u32, l); - let qhb = byte_of(qh_u32, l); - - let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); - let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); - let dq2 = f32(i32((q1b >> 4u) | ((qhb & 0x30u) )) - 32); - let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); - - sums[0] += yl0 * dq0; - sums[1] += yl1 * dq1; - sums[2] += yl2 * dq2; - sums[3] += yl3 * dq3; + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 8u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 8u] = f32(src1[x_base + 32u + i]); } - local_sum += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + - sums[2] * f32(sc4) + sums[3] * f32(sc6)); + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 108u)); + let a_base = 96u; + let a_il0 = load_u16_at_src0(block_byte_base + a_base + il * 2u); + let a_il1 = load_u16_at_src0(block_byte_base + a_base + (il + 1u) * 2u); + let a_4 = load_u16_at_src0(block_byte_base + a_base + 8u); + let a_5 = load_u16_at_src0(block_byte_base + a_base + 10u); + + var scales32 = a_4 | (a_5 << 16u); + let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u; + scales32 = a_il0 | (a_il1 << 16u); + scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32; + + let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32); + let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32); + + let q_u32_0 = load_u32_at_src0(block_byte_base + q_byte + 0u); + let q_u32_1 = load_u32_at_src0(block_byte_base + q_byte + 4u); + let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u); + let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); + + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + + s1 += x_block[l + 0u] * f32(qs & qm0); + s2 += x_block[l + 1u] * f32(qs & qm1); + s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + + select(0.0, x_block[l + 1u], (hv & hm1) == 0u); + s4 += x_block[l + 8u] * f32(qs & qm2); + s5 += x_block[l + 9u] * f32(qs & qm3); + s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + + select(0.0, x_block[l + 9u], (hv & hm3) == 0u); + } + + let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); + let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); + acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); + } + } } - - return local_sum; -} #endif -struct MulMatParams { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - m: u32, - n: u32, - k: u32, - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included -@group(0) @binding(0) var src0: array; // M rows, K columns -@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) - -@group(0) @binding(3) var params: MulMatParams; - -const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG; +#ifdef MUL_ACC_Q4_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 144 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 32u * im + l0; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } -// Shared memory for collaborative loading and reduction -var shared_vector: array; // Cache vector tile -var partial_sums: array; // For reduction + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 0u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); + + let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let scale0 = f32(sc16_0 & 0xFFu); + let scale1 = f32((sc16_0 >> 8u) & 0xFFu); + let min0 = f32(sc16_1 & 0xFFu); + let min1 = f32((sc16_1 >> 8u) & 0xFFu); + let scale2 = f32(sc16_2 & 0xFFu); + let scale3 = f32((sc16_2 >> 8u) & 0xFFu); + let min2 = f32(sc16_3 & 0xFFu); + let min3 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); + let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); + + var dot = vec4(0.0, 0.0, 0.0, 0.0); + var sumx = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + dot[0] += x_block[i] * f32(q1b & 0x0Fu); + dot[1] += x_block[i + 4u] * f32(q1b >> 4u); + dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); + dot[3] += x_block[i + 12u] * f32(q2b >> 4u); + sumx[0] += x_block[i]; + sumx[1] += x_block[i + 4u]; + sumx[2] += x_block[i + 8u]; + sumx[3] += x_block[i + 12u]; + } + + acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) + - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + } + } + } +#endif -@compute @workgroup_size(WG_SIZE) -fn main( - @builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) wg_id: vec3, - @builtin(num_workgroups) num_wg: vec3) { - let thread_id = local_id.x; +#ifdef MUL_ACC_Q5_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 176 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 48u + 32u * im + l0; + let qh_offset = 16u + 8u * ir + 4u * in; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let hm1 = 1u << (2u * im); + let hm2 = hm1 << 1u; + let hm3 = hm1 << 4u; + let hm4 = hm2 << 4u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } - // Handle batch dimensions - let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; - let batch_idx = wg_linear / output_groups; - if (batch_idx >= total_batches) { - return; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 0u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); + + let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let f0 = f32(sc16_0 & 0xFFu); + let f1 = f32((sc16_0 >> 8u) & 0xFFu); + let m0 = f32(sc16_1 & 0xFFu); + let m1 = f32((sc16_1 >> 8u) & 0xFFu); + let f4 = f32(sc16_2 & 0xFFu); + let f5 = f32((sc16_2 >> 8u) & 0xFFu); + let m4 = f32(sc16_3 & 0xFFu); + let m5 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset); + let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); + let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); + + var vals = vec4(0.0, 0.0, 0.0, 0.0); + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + let qhb = byte_of(qh_u32, i); + + let yl0 = x_block[i]; + let yl8 = x_block[i + 4u]; + let yh0 = x_block[i + 8u]; + let yh8 = x_block[i + 12u]; + + sumy[0] += yl0; + sumy[1] += yl8; + sumy[2] += yh0; + sumy[3] += yh8; + + let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); + let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); + let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); + let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); + + vals[0] += yl0 * q0; + vals[1] += yl8 * q1; + vals[2] += yh0 * q2; + vals[3] += yh8 * q3; + } + + acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) + - dmin * (sumy[0] * m0 + sumy[1] * m1 + + sumy[2] * m4 + sumy[3] * m5); + } + } } +#endif - // Which of the outputs does this thread belong to? - let thread_group = thread_id / THREADS_PER_OUTPUT; - let thread_in_group = thread_id % THREADS_PER_OUTPUT; +#ifdef MUL_ACC_Q6_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 210 +#define THREADS_PER_BLOCK 16 - // Each workgroup computes OUTPUTS_PER_WG consecutive outputs - let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group; + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - let dst2_stride = params.m * params.n; - let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); - let src03_idx = dst3_idx / params.broadcast3; - let src13_idx = dst3_idx; - let src02_idx = dst2_idx / params.broadcast2; - let src12_idx = dst2_idx; + let ip = tid / 8u; + let il = tid % 8u; + let l0 = 4u * il; + let is = 8u * ip + l0 / 16u; - let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row; + let y_offset = 128u * ip + l0; + let q_offset_l = 64u * ip + l0; + let q_offset_h = 32u * ip + l0; - var local_sum = 0.0; + let num_blocks = params.k / BLOCK_SIZE; + let sc_base_byte = 192u + (is & ~3u); + let sc_byte_pos = is & 3u; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var l = 0u; l < 4u; l++) { + x_block[l] = f32(src1[x_base + l]); + x_block[l + 4u] = f32(src1[x_base + 32u + l]); + x_block[l + 8u] = f32(src1[x_base + 64u + l]); + x_block[l + 12u] = f32(src1[x_base + 96u + l]); + } - // Each thread processes multiple K elements and accumulates - for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) { - let tile_size = min(TILE_K, params.k - k_tile); + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 208u)); + let ql1_u32 = load_u32_at_src0(block_byte_base + q_offset_l); + let ql2_u32 = load_u32_at_src0(block_byte_base + q_offset_l + 32u); + let qh_u32 = load_u32_at_src0(block_byte_base + 128u + q_offset_h); + let sc_u32_0 = load_u32_at_src0(block_byte_base + sc_base_byte); + let sc_u32_1 = load_u32_at_src0(block_byte_base + sc_base_byte + 4u); + + let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); + let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); + let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); + let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); + + var sums = vec4(0.0, 0.0, 0.0, 0.0); + + for (var l = 0u; l < 4u; l++) { + let q1b = byte_of(ql1_u32, l); + let q2b = byte_of(ql2_u32, l); + let qhb = byte_of(qh_u32, l); + + let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); + let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); + let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); + let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); + + sums[0] += x_block[l] * dq0; + sums[1] += x_block[l + 4u] * dq1; + sums[2] += x_block[l + 8u] * dq2; + sums[3] += x_block[l + 12u] * dq3; + } + + acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + + sums[2] * f32(sc4) + sums[3] * f32(sc6)); + } + } + } +#endif - // Cooperatively load vector tile into shared memory (all threads) - for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) { - shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE]; +#ifdef USE_SUBGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; } + } - workgroupBarrier(); + workgroupBarrier(); - if (output_row < params.m) { - local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile); + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; } + let row_total = subgroupAdd(row_acc); + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + row] = row_total; + } + } +#endif - workgroupBarrier(); +#ifdef USE_WORKGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[row]; } - // Store partial sums and reduce within each partition - partial_sums[thread_id] = local_sum; workgroupBarrier(); - let group_base = thread_group * THREADS_PER_OUTPUT; - let thread_base = group_base + thread_in_group; - var offset: u32 = THREADS_PER_OUTPUT / 2; - while (offset > 0) { - if (thread_in_group < offset) { - partial_sums[thread_base] += partial_sums[thread_base + offset]; + + var stride = WG_SIZE / 2u; + + while (stride > 0) { + if (thread_id < stride) { + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } } - offset = offset / 2; + workgroupBarrier(); + stride = stride / 2; } - // Store back to global memory - if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) { - dst[dst_idx / VEC_SIZE] = store_val(group_base); + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; + } } +#endif } From 2b9fb0be770f188e9d6b506403e3f3606f8a66dc Mon Sep 17 00:00:00 2001 From: pl752 Date: Mon, 20 Apr 2026 21:02:54 +0500 Subject: [PATCH 461/831] ggml-cpu: Optimized x86 and generic cpu q1_0 dot (follow up) (llama/21636) * Implemented optimized q1_0 dot for x86 and generic * Removed redundant helper definition * Removed two redundant instructions from AVX q1_0 dot * Fixed inconsistency with fp16 conversion for generic q1_0 dot and deduplicated generic fallback * Style cleanup around AVX q1_0 dot * Replaced explicitly unrolled blocks with inner for loop for q1_0 * Replaced scalar ARM q1_0 impl with new generic one --- ggml/src/ggml-cpu/arch-fallback.h | 1 - ggml/src/ggml-cpu/arch/arm/quants.c | 30 +----- ggml/src/ggml-cpu/arch/x86/quants.c | 158 ++++++++++++++++++++++++++++ ggml/src/ggml-cpu/quants.c | 26 +++-- 4 files changed, 179 insertions(+), 36 deletions(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index c589a213e9d..595ded09f03 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -83,7 +83,6 @@ #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 -#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index 64d811fafe7..fe621332970 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -151,8 +151,6 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi const block_q1_0 * GGML_RESTRICT x = vx; const block_q8_0 * GGML_RESTRICT y = vy; - float sumf = 0.0f; - #if defined(__ARM_NEON) float32x4_t sumv = vdupq_n_f32(0.0f); @@ -212,31 +210,13 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi } } - sumf = vaddvq_f32(sumv); + *s = vaddvq_f32(sumv); #else - // Scalar fallback - for (int i = 0; i < nb; i++) { - const float d0 = GGML_FP16_TO_FP32(x[i].d); - - // Process 4 Q8_0 blocks - for (int k = 0; k < 4; k++) { - const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d); - - int sumi = 0; - for (int j = 0; j < QK8_0; j++) { - const int bit_index = k * QK8_0 + j; - const int byte_index = bit_index / 8; - const int bit_offset = bit_index % 8; - - const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1; - sumi += xi * y[i*4 + k].qs[j]; - } - sumf += d0 * d1 * sumi; - } - } + UNUSED(nb); + UNUSED(x); + UNUSED(y); + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); #endif - - *s = sumf; } diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 74d699f633d..0a3e071e57c 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -274,6 +274,18 @@ static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const } #endif #elif defined(__SSSE3__) +static inline __m128i bytes_from_bits_16(const uint8_t * x) { + uint16_t x16; + memcpy(&x16, x, sizeof(uint16_t)); + + const __m128i shuf_mask = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + __m128i bytes = _mm_shuffle_epi8(_mm_set1_epi16((short) x16), shuf_mask); + const __m128i bit_mask = _mm_set_epi64x(0x7fbfdfeff7fbfdfe, 0x7fbfdfeff7fbfdfe); + bytes = _mm_or_si128(bytes, bit_mask); + + return _mm_cmpeq_epi8(bytes, _mm_set1_epi64x(-1)); +} + // horizontally add 4x4 floats static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { __m128 res_0 =_mm_hadd_ps(a, b); @@ -540,6 +552,152 @@ static inline __m128i get_scale_shuffle(int i) { } #endif +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__AVX2__) + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i byte_shuf = _mm256_setr_epi8( + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3); + const __m256i bit_masks = _mm256_setr_epi8( + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128, + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128); + const __m256i zero = _mm256_setzero_si256(); + __m256 acc = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + const uint32_t * GGML_RESTRICT qs32 = (const uint32_t *) x[ib].qs; + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + + __m256 acc_block; + { + const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[0].qs); + const __m256i sm = _mm256_cmpeq_epi8( + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[0]), byte_shuf), bit_masks), zero); + const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); + const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); + acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), _mm256_cvtepi32_ps(s32)); + } + for (int K = 1; K < 4; ++K) { + const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[K].qs); + const __m256i sm = _mm256_cmpeq_epi8( + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[K]), byte_shuf), bit_masks), zero); + const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); + const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); + acc_block = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)), _mm256_cvtepi32_ps(s32), acc_block); + } + acc = _mm256_fmadd_ps(_mm256_set1_ps(d0), acc_block, acc); + } + + *s = hsum_float_8(acc); +#elif defined(__AVX__) + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + __m256 acc = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + __m256 acc_block; + { + const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[0]); + const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); + const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[0].qs[0]); + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[0].qs[16]); + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); + const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); + const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); + const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); + const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); + const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); + acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), q); + } + for(int K = 1; K < 4; ++K) { + const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[(K) * 4]); + const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); + const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[0]); + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[16]); + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); + const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); + const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); + const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); + const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); + const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); + acc_block = _mm256_add_ps(acc_block, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(K)].d)), q)); + } +#undef Q1_AVX_BLOCK + + acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_set1_ps(d0), acc_block)); + } + + *s = hsum_float_8(acc); +#elif defined(__SSSE3__) + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const __m128 d0 = _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d)); + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + +#define Q1_SSSE3_BLOCK(QS_OFF, Y_IDX, ACC) \ + { \ + const __m128i bit_mask_0 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 0]); \ + const __m128i bit_mask_1 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 2]); \ + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[0]); \ + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[16]); \ + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \ + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \ + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \ + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \ + const __m128i sum_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_0), ones_16); \ + const __m128i sum_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_1), ones_16); \ + const __m128 q = _mm_cvtepi32_ps(_mm_add_epi32(sum_0, sum_1)); \ + (ACC) = _mm_add_ps((ACC), _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(Y_IDX)].d))), q)); \ + } + Q1_SSSE3_BLOCK(0, 0, acc_0) + Q1_SSSE3_BLOCK(4, 1, acc_1) + Q1_SSSE3_BLOCK(8, 2, acc_2) + Q1_SSSE3_BLOCK(12, 3, acc_3) +#undef Q1_SSSE3_BLOCK + } + + *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#else + UNUSED(nb); + UNUSED(x); + UNUSED(y); + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index f66127c2290..e5f9a4083f9 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -137,22 +137,28 @@ void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, c float sumf = 0.0; for (int i = 0; i < nb; i++) { - const float d0 = GGML_FP16_TO_FP32(x[i].d); + const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d); float sumi = 0.0f; for (int k = 0; k < 4; k++) { - const float d1 = GGML_FP16_TO_FP32(y[i*4 + k].d); - + const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k]; + const float d1 = GGML_CPU_FP16_TO_FP32(yb->d); int sumi_block = 0; - for (int j = 0; j < QK8_0; j++) { - const int bit_index = k * QK8_0 + j; - const int byte_index = bit_index / 8; - const int bit_offset = bit_index % 8; - - const int xi = ((x[i].qs[byte_index] >> bit_offset) & 1) ? 1 : -1; - sumi_block += xi * y[i*4 + k].qs[j]; + const uint8_t * GGML_RESTRICT bits = &x[i].qs[k * 4]; + const int8_t * GGML_RESTRICT qy = yb->qs; + + for (int b = 0; b < 4; ++b, qy += 8) { + const unsigned mask = bits[b]; + sumi_block += ((mask & 0x01) ? qy[0] : -qy[0]) + + ((mask & 0x02) ? qy[1] : -qy[1]) + + ((mask & 0x04) ? qy[2] : -qy[2]) + + ((mask & 0x08) ? qy[3] : -qy[3]) + + ((mask & 0x10) ? qy[4] : -qy[4]) + + ((mask & 0x20) ? qy[5] : -qy[5]) + + ((mask & 0x40) ? qy[6] : -qy[6]) + + ((mask & 0x80) ? qy[7] : -qy[7]); } sumi += d1 * sumi_block; From 6429023e5f48c37b03e4903bf2bab8ef875b244f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 20 Apr 2026 18:09:39 +0200 Subject: [PATCH 462/831] TP: fix 0-sized tensor slices, AllReduce fallback (llama/21808) * TP: fix 0-sized tensor slices, AllReduce fallback * fix layer structure <-> GPU count aliasing * add missing std::fill * fix CUDA device set, max ggml ctx size --- ggml/src/ggml-backend-meta.cpp | 218 +++++++++++++++++++++----------- ggml/src/ggml-cuda/ggml-cuda.cu | 13 +- 2 files changed, 154 insertions(+), 77 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 39651adc1c1..4bf90c6a98b 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1133,7 +1133,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta(t_ij->view_src->buffer)) { t_ij->view_src = ggml_backend_meta_buffer_simple_tensor(tensor->view_src, j); if (t_ij->view_offs > 0 && split_dim >= 0 && split_dim < GGML_MAX_DIMS) { - GGML_ASSERT(ne[split_dim] != 0 && tensor->ne[split_dim] != 0); + GGML_ASSERT(tensor->ne[split_dim] != 0); const int split_dim_view_src = ggml_backend_meta_get_split_state(tensor->view_src, /*assume_sync =*/ true).axis; GGML_ASSERT(split_dim_view_src >= 0 && split_dim_view_src < GGML_MAX_DIMS); @@ -1170,6 +1170,28 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer simple_tensors.push_back(t_ij); } + + // If one of the sources has a zero-sized slice, disable the computation: + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || !ggml_backend_buffer_is_meta(tensor->src[i]->buffer)) { + continue; + } + + const ggml_backend_meta_split_state split_state_src = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true); + if (split_state_src.axis < 0 || split_state_src.axis >= GGML_MAX_DIMS) { + continue; + } + for (size_t j = 0; j < n_simple_bufs; j++) { + int64_t ne_sum = 0; + for (size_t s = 0; s < split_state_src.n_segments; s++) { + ne_sum += split_state_src.ne[s*n_simple_bufs + j]; + } + if (ne_sum == 0) { + simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; + } + } + } + buf_ctx->simple_tensors[tensor] = simple_tensors; return GGML_STATUS_SUCCESS; @@ -1442,17 +1464,20 @@ struct ggml_backend_meta_context { struct backend_config { ggml_backend_t backend; - std::vector cgraphs; - std::vector nodes; - ggml_backend_buffer_ptr buf; + std::vector cgraphs; + std::vector nodes; + std::vector bufs; - backend_config(ggml_backend_t backend) : backend(backend) {} + backend_config(ggml_backend_t backend, const size_t n_reduce_steps) : backend(backend) { + bufs.resize(n_reduce_steps); + } }; std::string name; std::vector backend_configs; ggml_context_ptr ctx; std::vector cgraphs_aux; std::vector nodes_aux; + size_t n_reduce_steps; int max_nnodes = 0; size_t max_tmp_size = 0; size_t max_subgraphs = 0; @@ -1464,6 +1489,7 @@ struct ggml_backend_meta_context { ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) { const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev); + n_reduce_steps = std::ceil(std::log2(n_devs)); name = "Meta("; std::vector simple_backends; backend_configs.reserve(n_devs); @@ -1475,7 +1501,7 @@ struct ggml_backend_meta_context { } name += ggml_backend_dev_name(simple_dev); simple_backends.push_back(ggml_backend_dev_init(simple_dev, params)); - backend_configs.emplace_back(simple_backends.back()); + backend_configs.emplace_back(simple_backends.back(), n_reduce_steps); } name += ")"; @@ -1505,10 +1531,6 @@ struct ggml_backend_meta_context { ggml_backend_free(bc.backend); } } - - size_t n_reduce_steps() const { - return std::ceil(std::log2(backend_configs.size())); - } }; static const char * ggml_backend_meta_get_name(ggml_backend_t backend) { @@ -1754,16 +1776,17 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, if (max_tmp_size > backend_ctx->max_tmp_size) { for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; - bcj.buf.reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); + for (size_t i = 0; i < backend_ctx->n_reduce_steps; i++) { + bcj.bufs[i].reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); + } } backend_ctx->max_tmp_size = max_tmp_size; } if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); - const size_t n_reduce_steps = backend_ctx->n_reduce_steps(); - const size_t n_nodes_per_device = 2 * n_reduce_steps; // tmp + ADD per step - const size_t n_cgraphs_per_device = n_reduce_steps; // 1 ADD graph per step + const size_t n_nodes_per_device = 3 * backend_ctx->n_reduce_steps; // tmp + ADD (+zeroing) graph per step and device + const size_t n_cgraphs_per_device = 2 * backend_ctx->n_reduce_steps; // ADD ( + zeroing) graph per step and device const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); @@ -1812,11 +1835,6 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, size_t iga = 0; // i graph aux size_t ina = 0; // i node aux - // FIXME usage_counts - auto get_cgraph_aux = [&]() -> ggml_cgraph * { - ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++]; - return ret; - }; auto get_node_aux = [&](ggml_tensor * t) -> ggml_tensor * { ggml_tensor * ret = backend_ctx->nodes_aux[ina++]; memset(ret, 0, sizeof(ggml_tensor)); @@ -1828,75 +1846,110 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, } return ret; }; + auto set_tmp_data = [&](ggml_tensor * tensor, const size_t j, const size_t i_buf) { + auto & bcj = backend_ctx->backend_configs[j]; + ggml_backend_buffer_ptr & buf_ptr = bcj.bufs[i_buf]; + if (!buf_ptr || ggml_backend_buffer_get_size(buf_ptr.get()) < backend_ctx->max_tmp_size) { + buf_ptr.reset(ggml_backend_alloc_buffer(bcj.backend, backend_ctx->max_tmp_size)); + } + tensor->buffer = buf_ptr.get(); + tensor->data = ggml_backend_buffer_get_base(buf_ptr.get()); + }; + // FIXME usage_counts + auto get_cgraph_aux = [&]() -> ggml_cgraph * { + ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++]; + return ret; + }; // Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable: auto allreduce_fallback = [&](size_t i) -> ggml_status { std::vector step_cgraphs(n_backends, nullptr); - for (size_t offset_j = 1; offset_j < n_backends; offset_j *= 2) { + // Zero out nodes that were disabled due to having a zero-sized slice: + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + ggml_tensor * node = bcj.cgraphs[i].cgraph_main->nodes[bcj.cgraphs[i].cgraph_main->n_nodes - 1]; + if (node->flags & GGML_TENSOR_FLAG_COMPUTE) { + continue; + } + ggml_tensor * node_zero = get_node_aux(node); + node_zero->op = GGML_OP_SCALE; // FIXME 0.0f * NaN == NaN + node_zero->src[0] = node; + ggml_set_op_params_f32(node_zero, 0, 0.0f); + node_zero->data = node->data; + node_zero->flags |= GGML_TENSOR_FLAG_COMPUTE; + + step_cgraphs[j] = get_cgraph_aux(); + step_cgraphs[j]->nodes[0] = node_zero; + step_cgraphs[j]->n_nodes = 1; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr); + + auto push_data = [&](const size_t j_src, const size_t j_dst, const size_t i_buf) { + assert(step_cgraphs[j_dst] == nullptr); + auto & bcj_src = backend_ctx->backend_configs[j_src]; + auto & bcj_dst = backend_ctx->backend_configs[j_dst]; + + ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1]; + GGML_ASSERT(ggml_is_contiguous(node_src)); + GGML_ASSERT(ggml_is_contiguous(node_dst)); + + ggml_tensor * node_tmp = get_node_aux(node_dst); + set_tmp_data(node_tmp, j_dst, i_buf); + + ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_tmp); + + ggml_tensor * node_red = get_node_aux(node_dst); + node_red->view_src = node_dst->view_src == nullptr ? node_dst : node_dst->view_src; + node_red->view_offs = node_dst->view_offs; + node_red->op = GGML_OP_ADD; + node_red->src[0] = node_dst; + node_red->src[1] = node_tmp; + node_red->flags |= GGML_TENSOR_FLAG_COMPUTE; + ggml_backend_view_init(node_red); + + ggml_cgraph * cgraph_aux = get_cgraph_aux(); + cgraph_aux->nodes[0] = node_red; + cgraph_aux->n_nodes = 1; + step_cgraphs[j_dst] = cgraph_aux; + }; + + size_t offset_j = n_backends/2; + while ((offset_j & (offset_j - 1)) != 0) { + offset_j--; + } + const size_t offset_j_max = offset_j; + size_t i_buf = 0; + + // If n_backends is not a power of 2, fold in the excess prior to butterfly reduction: + for (size_t j_src = 2*offset_j_max; j_src < n_backends; j_src++) { + const size_t j_dst = j_src - 2*offset_j_max; + push_data(j_src, j_dst, i_buf); + const ggml_status status = ggml_backend_graph_compute_async(backend_ctx->backend_configs[j_dst].backend, step_cgraphs[j_dst]); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + i_buf = 1; + } + + // Butterfly reduction: + for (; offset_j >= 1; offset_j /= 2) { std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr); - for (size_t j = 0; j < n_backends; j++) { + for (size_t j = 0; j < 2*offset_j_max; j++) { const size_t j_other = j ^ offset_j; - if (j_other > j) { + if (j_other >= n_backends) { continue; } - - auto & bcj1 = backend_ctx->backend_configs[j]; - auto & bcj2 = backend_ctx->backend_configs[j_other]; - - ggml_tensor * node1 = bcj1.cgraphs[i].cgraph_main->nodes[bcj1.cgraphs[i].cgraph_main->n_nodes - 1]; - ggml_tensor * node2 = bcj2.cgraphs[i].cgraph_main->nodes[bcj2.cgraphs[i].cgraph_main->n_nodes - 1]; - GGML_ASSERT(ggml_is_contiguous(node1)); - GGML_ASSERT(ggml_is_contiguous(node2)); - - // Tmp tensors to receive P2P copies - ggml_tensor * node_tmp_1 = get_node_aux(node1); - node_tmp_1->buffer = bcj1.buf.get(); - node_tmp_1->data = ggml_backend_buffer_get_base(bcj1.buf.get()); - - ggml_tensor * node_tmp_2 = get_node_aux(node2); - node_tmp_2->buffer = bcj2.buf.get(); - node_tmp_2->data = ggml_backend_buffer_get_base(bcj2.buf.get()); - - // 2 P2P copies: exchange full buffers - ggml_backend_tensor_copy_async(bcj1.backend, bcj2.backend, node1, node_tmp_2); - ggml_backend_tensor_copy_async(bcj2.backend, bcj1.backend, node2, node_tmp_1); - - // Local ADD: node1 += tmp1 (in-place via view) - ggml_tensor * node_red_1 = get_node_aux(node1); - node_red_1->view_src = node1->view_src == nullptr ? node1 : node1->view_src; - node_red_1->view_offs = node1->view_offs; - node_red_1->op = GGML_OP_ADD; - node_red_1->src[0] = node1; - node_red_1->src[1] = node_tmp_1; - node_red_1->flags |= GGML_TENSOR_FLAG_COMPUTE; - ggml_backend_view_init(node_red_1); - - // Local ADD: node2 += tmp2 (in-place via view) - ggml_tensor * node_red_2 = get_node_aux(node2); - node_red_2->view_src = node2->view_src == nullptr ? node2 : node2->view_src; - node_red_2->view_offs = node2->view_offs; - node_red_2->op = GGML_OP_ADD; - node_red_2->src[0] = node2; - node_red_2->src[1] = node_tmp_2; - node_red_2->flags |= GGML_TENSOR_FLAG_COMPUTE; - ggml_backend_view_init(node_red_2); - - // Build 1-node cgraphs for the ADD ops - ggml_cgraph * cgraph_aux_1 = get_cgraph_aux(); - cgraph_aux_1->nodes[0] = node_red_1; - cgraph_aux_1->n_nodes = 1; - step_cgraphs[j] = cgraph_aux_1; - - ggml_cgraph * cgraph_aux_2 = get_cgraph_aux(); - cgraph_aux_2->nodes[0] = node_red_2; - cgraph_aux_2->n_nodes = 1; - step_cgraphs[j_other] = cgraph_aux_2; + push_data(j, j_other, i_buf); } - // Execute local ADDs for this step - for (size_t j = 0; j < n_backends; j++) { + for (size_t j = 0; j < 2*offset_j_max; j++) { if (step_cgraphs[j] == nullptr) { continue; } @@ -1906,7 +1959,20 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, return status; } } + i_buf++; } + assert(i_buf == backend_ctx->n_reduce_steps); + + // If n_backends is not a power of 2, copy back the reduced tensors to the excess: + for (size_t j = 2*offset_j_max; j < n_backends; j++) { + auto & bcj_src = backend_ctx->backend_configs[j - 2*offset_j_max]; + auto & bcj_dst = backend_ctx->backend_configs[j]; + + ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_dst); + } + return GGML_STATUS_SUCCESS; }; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index de579d2ed50..ecd12b80dfe 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1203,6 +1203,13 @@ static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct gg // For small tensors, simply reduce them as FP32. // The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0. if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) { + for (size_t i = 0; i < n_backends; ++i) { + if ((tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + ggml_cuda_set_device(cuda_ctx->device); + CUDA_CHECK(cudaMemsetAsync(tensors[i]->data, 0, ggml_nbytes(tensors[i]), cuda_ctx->stream())); + } + } NCCL_CHECK(ncclGroupStart()); for (size_t i = 0; i < n_backends; ++i) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; @@ -1224,7 +1231,11 @@ static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct gg tmp[i].alloc(ne); ggml_cuda_set_device(cuda_ctx->device); - to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream()); + if (tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) { + to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream()); + } else { + CUDA_CHECK(cudaMemsetAsync(tmp[i].get(), 0, ne * sizeof(nv_bfloat16), cuda_ctx->stream())); + } CUDA_CHECK(cudaGetLastError()); } From 239c5c86c30d36249e3479459914c4eb24958f19 Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Mon, 20 Apr 2026 21:55:39 +0530 Subject: [PATCH 463/831] Tensor-parallel: Fix delayed AllReduce on Gemma-4 MoE (llama/22129) * Fix delayed AllReduce on Gemma-4 MoE Skip forward past nodes that don't consume the current one, and allow a chain of MULs. * Check for all sources before skipping nodes * Address review comments --- ggml/src/ggml-backend-meta.cpp | 42 ++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 4bf90c6a98b..6d22f3421b1 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1683,6 +1683,36 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, ggml_tensor * node = cgraph->nodes[id]; int32_t n_used = ggml_node_get_use_count(cgraph, id); + + // Skip MIRRORED nodes that don't consume node + auto skip_unrelated = [&]() { + while (id + 1 < cgraph->n_nodes) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (ggml_backend_meta_get_split_state(next, false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + break; + } + bool safe = true; + for (int s = 0; s < GGML_MAX_SRC; s++) { + if (next->src[s] == nullptr) { + continue; + } + if (next->src[s] == node) { + safe = false; + break; + } + if (ggml_backend_meta_get_split_state(next->src[s], false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + safe = false; + break; + } + } + if (!safe) { + break; + } + id++; + } + }; + + skip_unrelated(); if (id + 1 >= cgraph->n_nodes) { return idr; } @@ -1697,10 +1727,12 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, n_used = ggml_node_get_use_count(cgraph, id); } } - if (id + 1 >= cgraph->n_nodes) { - return idr; - } - { + // Chain of MULs with MIRRORED src[1] + while (true) { + skip_unrelated(); + if (id + 1 >= cgraph->n_nodes) { + return idr; + } ggml_tensor * next = cgraph->nodes[id+1]; if (next->op == GGML_OP_MUL && next->src[0] == node && ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { @@ -1708,6 +1740,8 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, id++; idr = id; n_used = ggml_node_get_use_count(cgraph, id); + } else { + break; } } From b13deaabaec1d52fc228195e56f96f1d83b7d2c0 Mon Sep 17 00:00:00 2001 From: leonardHONG <2695316095@qq.com> Date: Tue, 21 Apr 2026 05:30:38 +0800 Subject: [PATCH 464/831] ggml-cuda: flush legacy pool on OOM and retry (llama/22155) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-cuda: flush legacy pool on OOM and retry Signed-off-by: 梁厚宏 <2695316095@qq.com> * Address review comments: add explicit sync, update destructor, clean up MUSA macros Signed-off-by: 梁厚宏 <2695316095@qq.com> --------- Signed-off-by: 梁厚宏 <2695316095@qq.com> --- ggml/src/ggml-cuda/ggml-cuda.cu | 23 +++++++++++++++++++++-- ggml/src/ggml-cuda/vendors/hip.h | 1 + ggml/src/ggml-cuda/vendors/musa.h | 1 + 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ecd12b80dfe..185956317e0 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -368,15 +368,21 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { } ~ggml_cuda_pool_leg() { + clear_pool(); + GGML_ASSERT(pool_size == 0); + } + + void clear_pool() { ggml_cuda_set_device(device); for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cuda_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { CUDA_CHECK(cudaFree(b.ptr)); pool_size -= b.size; + b.ptr = nullptr; + b.size = 0; } } - GGML_ASSERT(pool_size == 0); } void * alloc(size_t size, size_t * actual_size) override { @@ -421,7 +427,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); ggml_cuda_set_device(device); - CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); + cudaError_t err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); + if (err == cudaErrorMemoryAllocation) { + (void)cudaGetLastError(); + const size_t cached_bytes = pool_size; + GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: alloc of %.2f MiB failed, flushing %.2f MiB of cached buffers and retrying\n", + device, look_ahead_size/1024.0/1024.0, cached_bytes/1024.0/1024.0); + CUDA_CHECK(cudaDeviceSynchronize()); + clear_pool(); + err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); + if (err == cudaSuccess) { + GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: retry succeeded\n", device); + } + } + CUDA_CHECK(err); *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 52c38908e06..78ca364d38f 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -58,6 +58,7 @@ #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t +#define cudaErrorMemoryAllocation hipErrorOutOfMemory #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags hipEventCreateWithFlags diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 1abb8acfd4b..8aa056e9174 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -42,6 +42,7 @@ #define cudaDeviceProp musaDeviceProp #define cudaDeviceSynchronize musaDeviceSynchronize #define cudaError_t musaError_t +#define cudaErrorMemoryAllocation musaErrorMemoryAllocation #define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags musaEventCreateWithFlags From e7cffdbd0bc1ea97a605ab361907ef771b993bea Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 21 Apr 2026 11:02:56 +0300 Subject: [PATCH 465/831] ggml : bump version to 0.10.0 (ggml/1463) --- ggml/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index a0eb9204eab..2effd587b41 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,8 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 11) +set(GGML_VERSION_MINOR 10) +set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 85bbc822092a88d4feb0b2f8ddad0bb2de04488e Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 21 Apr 2026 11:01:56 +0200 Subject: [PATCH 466/831] vulkan: Support F16 OP_FILL (llama/22177) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 +++++++- .../src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 702a249d754..d4acee8b1df 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -792,6 +792,7 @@ struct vk_device_struct { vk_pipeline pipeline_arange_f32; vk_pipeline pipeline_fill_f32; + vk_pipeline pipeline_fill_f16; vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_reglu[2]; @@ -4577,6 +4578,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_fill_f16, "fill_f16", fill_f16_len, fill_f16_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); #define CREATE_GLU(name) \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ @@ -9844,6 +9846,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_fill_f32; } + if (dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_fill_f16; + } return nullptr; default: return nullptr; @@ -15713,8 +15718,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32) || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16); case GGML_OP_ARANGE: - case GGML_OP_FILL: return op->type == GGML_TYPE_F32; + case GGML_OP_FILL: + return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; case GGML_OP_SCALE: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 54b9b327333..ff836615330 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -889,6 +889,7 @@ void process_shaders() { string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("fill_f16", "fill.comp", {{"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); From 150cef5a5f5c5272444eafbf090083476c8b1ccf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 21 Apr 2026 17:24:55 +0300 Subject: [PATCH 467/831] metal : workaround macOS GPU interactivity watchdog (llama/22216) --- ggml/src/ggml-metal/ggml-metal.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 4dbf8e6fea9..6a836e45908 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -918,6 +918,10 @@ ggml_backend_reg_t ggml_backend_metal_reg(void) { static std::vector devs; if (!initialized) { + // workaround macOS limitation (kIOGPUCommandBufferCallbackErrorImpactingInteractivity) until proper fix becomes possible + // ref: https://github.com/ggml-org/llama.cpp/issues/20141#issuecomment-4272947703 + setenv("AGX_RELAX_CDM_CTXSTORE_TIMEOUT", "1", true); + static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init()); for (int i = 0; i < g_devices; ++i) { From 3a73f9cf0b3dc2a221000fd865545c783d3e978d Mon Sep 17 00:00:00 2001 From: Zijun Yu Date: Tue, 21 Apr 2026 23:58:34 +0800 Subject: [PATCH 468/831] openvino: driver setup, CI split, thread safety, and NPU optimizations (llama/21944) * Thread safety per request only * Fix ROPE yarn case * Fix sticky stateful config * Use i4/i8 directly for symmetric quant * Use weightless caching * Add WeightlessCacheAttribute to reduce NPU memory usage * Gelu tanh support (llama/125) * Imrope support (llama/126) * fix(openvino): explicit ov::Tensor frees in ggml_backend_openvino_free * add GPU,NPU support in OV Dockerfile * add build-openvino.yml ci * Fix sticky stateful config * add concurrency to ov-gpu ci runs. Move OV CI to build-openvino.yml * fix thread-safety of shared runtime context * rope type abstraction for frontend translations * fix editorconfig --------- Co-authored-by: Mustafa Cavus Co-authored-by: Dan Hoffman Co-authored-by: Ravi Panchumarthy --- ggml/src/ggml-openvino/ggml-decoder.cpp | 20 +- .../src/ggml-openvino/ggml-openvino-extra.cpp | 29 +- ggml/src/ggml-openvino/ggml-openvino.cpp | 42 +- ggml/src/ggml-openvino/ggml-quants.cpp | 456 ++++++++++-------- ggml/src/ggml-openvino/openvino/op/rope.cpp | 40 +- .../ggml-openvino/openvino/op/unary_gelu.cpp | 25 + ggml/src/ggml-openvino/openvino/op_table.cpp | 1 + ggml/src/ggml-openvino/openvino/op_table.h | 1 + .../openvino/pass/eliminate_zp.cpp | 123 ----- .../openvino/pass/eliminate_zp.h | 17 - .../rt_info/weightless_caching_attributes.hpp | 41 ++ .../openvino/translate_session.cpp | 30 +- ggml/src/ggml-openvino/openvino/utils.cpp | 103 ++-- ggml/src/ggml-openvino/openvino/utils.h | 1 + ggml/src/ggml-openvino/utils.cpp | 145 ++++-- ggml/src/ggml-openvino/utils.h | 26 +- 16 files changed, 646 insertions(+), 454 deletions(-) create mode 100644 ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp delete mode 100644 ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp delete mode 100644 ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h create mode 100644 ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 0938d2273e9..5095e799849 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include @@ -207,8 +206,22 @@ int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const { break; } case GGML_OP_ROPE: { + const int mode = node->op_params[2]; + switch (mode) { + case GGML_ROPE_TYPE_NEOX: { + op_case = 0x00010000; + break; + } + case GGML_ROPE_TYPE_IMROPE: { + op_case = 0x00020000; + break; + } + default: + op_case = 0x00000000; + break; + } if (node->src[0]->op == GGML_OP_VIEW) { - op_case = 2; + op_case = (op_case | 0x00000002); } break; } @@ -573,9 +586,6 @@ std::map GgmlOvDecoder::get_kv_param_res_names() const } std::map> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph, bool naive) { - static std::mutex weights_mutex; - std::lock_guard lock(weights_mutex); - std::map> model_weights; auto * nodes = cgraph->nodes; auto n_nodes = cgraph->n_nodes; diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp index cc3cb4583cd..4140136aca2 100644 --- a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include ov::Core & ov_singleton_core() { @@ -42,11 +43,13 @@ void ggml_openvino_device_config::init() { {"NPUW_DQ", "YES" }, {"NPUW_DQ_FULL", "NO" }, }; - if (cache_dir) { + if (cache_dir && strlen(cache_dir) > 0) { compile_config["NPUW_CACHE_DIR"] = cache_dir; + compile_config.insert(ov::cache_mode(ov::CacheMode::OPTIMIZE_SIZE)); } - } else if (cache_dir) { - ov_singleton_core().set_property(ov::cache_dir(cache_dir)); + } else if (cache_dir && strlen(cache_dir) > 0) { + compile_config.insert(ov::cache_dir(cache_dir)); + compile_config.insert(ov::cache_mode(ov::CacheMode::OPTIMIZE_SIZE)); } // Initialize remote context with queue sharing for GPU @@ -259,10 +262,12 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements; int64_t n_blocks = n_elements / layout.weights_per_block; layout.scales_size = n_blocks * sizeof(uint16_t); - // For symmetric quantization, we only need one zp value (not one per block) - // Zero points are stored in U4 or U8 format matching the weight type - size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks; - layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements; + // For symmetric quantization, no zp needed (weights stored as signed) + if (layout.is_symmetric) { + layout.zp_size = 0; + } else { + layout.zp_size = layout.is_u4 ? ((n_blocks + 1) / 2) : n_blocks; + } layout.weights_offset = 0; layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; @@ -313,10 +318,12 @@ ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_ten // Scales: F16 per block int64_t n_blocks = n_elements / layout.weights_per_block; layout.scales_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes - // Zero points: U4 or U8 matching weight type - // For symmetric quantization, we only need one zp value (not one per block) - size_t n_zp_elements = layout.is_symmetric ? 1 : n_blocks; - layout.zp_size = layout.is_u4 ? ((n_zp_elements + 1) / 2) : n_zp_elements; + // For symmetric quantization, no zp needed (weights stored as signed) + if (layout.is_symmetric) { + layout.zp_size = 0; + } else { + layout.zp_size = layout.is_u4 ? ((n_blocks + 1) / 2) : n_blocks; + } // Layout in buffer: [weights | scales | zp] with alignment layout.weights_offset = 0; diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 0c8d3508e87..4f3ebf2536b 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -145,13 +145,18 @@ static void * ggml_backend_openvino_buffer_get_base(ggml_backend_buffer_t buffer return ctx->data; } +static bool is_stateful_enabled() { + static const auto * stateful = getenv("GGML_OPENVINO_STATEFUL_EXECUTION"); + return stateful && *stateful != '\0' && strcmp(stateful, "0") != 0; +} + static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; // Put kvcache on device memory for GPU (NPU memory is too small even for kvcache) if (strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && ggml_openvino_get_device_name() == "GPU" && - !getenv("GGML_OPENVINO_STATEFUL_EXECUTION")) { + !is_stateful_enabled()) { GGML_ASSERT(ctx->tensor_extras.empty()); auto device = ctx->device; auto size = ctx->size; @@ -600,6 +605,14 @@ bool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft) { static void ggml_backend_openvino_free(ggml_backend_t backend) { ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *) backend->context; + + if (ctx->runtime_context) { + auto r_ctx = std::static_pointer_cast(ctx->runtime_context); + if (--r_ctx->backend_count == 0) { + r_ctx->clear_caches(); + } + } + delete ctx; delete backend; } @@ -644,7 +657,12 @@ static ggml_guid_t ggml_backend_openvino_guid(void) { } static std::shared_ptr get_ov_runtime_context_ptr() { - static std::shared_ptr r_ctx = std::make_shared(); + static std::shared_ptr r_ctx = [] { + auto ctx = std::make_shared(); + ctx->device = ggml_openvino_get_device_name(); + ctx->stateful = is_stateful_enabled() && !ggml_openvino_is_npu(); + return ctx; + }(); return r_ctx; } @@ -669,8 +687,7 @@ GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device) { } std::shared_ptr r_ctx = std::static_pointer_cast(ctx->runtime_context); - r_ctx->device = ggml_openvino_get_device_name(); - r_ctx->stateful = getenv("GGML_OPENVINO_STATEFUL_EXECUTION") && !ggml_openvino_is_npu(); + r_ctx->backend_count++; ggml_backend_t openvino_backend = new ggml_backend{ /* .guid = */ ggml_backend_openvino_guid(), @@ -883,7 +900,7 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { const int32_t * op_params = op->op_params; const int n_dims = op_params[1]; const int mode = op_params[2]; - if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) { + if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_IMROPE) { // GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode); return true; } @@ -896,14 +913,6 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { // GGML_LOG_WARN("OpenVINO backend does not support ROPE with type %s\n", ggml_type_name(op->type)); return true; } - float freq_scale; - float ext_factor; - memcpy(&freq_scale, op_params + 6, sizeof(float)); - memcpy(&ext_factor, op_params + 7, sizeof(float)); - if (ext_factor != 0.0f) { - // GGML_LOG_WARN("OpenVINO backend does not support ROPE with ext_factor %f != 0.0f\n", ext_factor); - return true; - } if (op->src[0]->op == GGML_OP_VIEW) { if (op->src[0]->view_src->ne[1] != op->src[0]->ne[2]) { // GGML_LOG_WARN( @@ -913,6 +922,12 @@ static bool is_op_unsupported_case(const ggml_tensor * op) { return true; } } + if (mode == GGML_ROPE_TYPE_IMROPE && + (op->src[2] != 0 || ((const float *) op_params)[6] != 1 || ((const float *) op_params)[7] != 0 || + ((const float *) op_params)[8] != 1)) { + // GGML_LOG_WARN("OpenVINO backend does not support IMROPE with freq_factors, freq_scale, ext_factor, and attn_factor\n"); + return true; + } break; } default: @@ -942,6 +957,7 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con // GGML_OP_SOFT_MAX, GGML_OP_SET_ROWS, GGML_OP_FLASH_ATTN_EXT, GGML_OP_CPY}; static const std::set supported_unary_ops{ + GGML_UNARY_OP_GELU, GGML_UNARY_OP_SILU, }; static const std::set supported_glu_ops{ diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp index dbf38646ddd..57d66df4f01 100644 --- a/ggml/src/ggml-openvino/ggml-quants.cpp +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -46,6 +46,7 @@ void unpack_32_4(const uint8_t * data, uint8_t * dst) { // Extracts (weight, scales, zp) from Q4_0 tensors. // Data layout is: |16 bit scale|32 x 4bit weights|. +// When zp_arr is empty (symmetric), weights are stored as signed i4 (value - 8). void extract_q4_0_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, @@ -55,28 +56,32 @@ void extract_q4_0_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - // For Q4_0, zero point is always 8 - if (is_scalar_zp) { - zp[0] = 8 | (8 << 4); // Pack two 4-bit values - } + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i4); // Signed i4 path - ov::parallel_for(scales_arr.get_size(), [&](size_t i) { - scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); - // For asymmetric quantization, compute per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); // Pack two 4-bit zero points per byte if (i % 2 == 0) { zp[i / 2] = 8; // Lower nibble } else { zp[i / 2] |= (8 << 4); // Upper nibble } - } - unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); - }); + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); + }); + } else { + // Symmetric: unpack as u4 then convert to i4 by subtracting 8 (XOR each nibble) + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); + // Convert u4 to i4: subtract 8 from each nibble. XOR 0x88 flips each nibble by 8. + for (int j = 0; j < 16; ++j) { + weights[i * 16 + j] ^= 0x88; + } + }); + } } // Extracts (weight, scales, zp) from Q4_1 tensors. @@ -123,6 +128,7 @@ void extract_q4_1_data(const ggml_tensor * tensor, // Extracts (weight, scales, zp) from Q8_0 tensors. // Data layout is: |16 bit scale|32 x 8bit weights|. +// When zp_arr is empty (symmetric), weights are stored as signed i8 directly. void extract_q8_0_data(const ggml_tensor * tensor, ov::Tensor & weights_arr, ov::Tensor & scales_arr, @@ -133,29 +139,30 @@ void extract_q8_0_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - // For Q8_0, zero point is always 128 - if (is_scalar_zp) { - zp[0] = 128; - } + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - ov::parallel_for(scales_arr.get_size(), [&](size_t i) { - uint8_t * block_data = data + i * bytes_per_block; - scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); zp[i] = 128; - } - for (size_t j = 0; j < weights_per_block; ++j) { - uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. - // Original data is in int8_t, so we add a bias of -128 and invert the first bit. - x ^= 1 << 7; - weights[i * weights_per_block + j] = x; - } - }); + for (size_t j = 0; j < weights_per_block; ++j) { + uint8_t x = block_data[j + 2]; + x ^= 1 << 7; // Convert int8 to uint8 by flipping sign bit + weights[i * weights_per_block + j] = x; + } + }); + } else { + // Symmetric: store original int8 values directly (no unsigned bias) + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); + // Copy int8 weights as-is (the tensor element type is i8) + memcpy(weights + i * weights_per_block, block_data + 2, weights_per_block); + }); + } } void unpack_256_4(const uint8_t * data, uint8_t * dst) { @@ -256,44 +263,62 @@ void extract_q6_k_data(const ggml_tensor * tensor, auto * data = static_cast(tensor->data); auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - - // For Q6_K, zero point is always 32 - if (is_scalar_zp) { - zp[0] = 32; - } - - ov::parallel_for(n_super_block, [&](size_t i) { - uint8_t * block_data = data + i * bytes_per_block; - float scale_factor = - static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); // (128+64+16)/2 + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - for (size_t j = 0; j < 16; j++) { - scales[j + i * 16] = - ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + float scale_factor = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); zp[j + i * 16] = 32; } - } - - uint8_t * ql = block_data; - uint8_t * qh = block_data + 128; - - for (int64_t j = 0; j < 32; ++j) { - weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4); - weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4); - weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4); - weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4); - weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4); - weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4); - weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4); - weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4); - } - }); + uint8_t * ql = block_data; + uint8_t * qh = block_data + 128; + for (int64_t j = 0; j < 32; ++j) { + weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4); + weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4); + weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4); + weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4); + weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4); + weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4); + weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4); + weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4); + } + }); + } else { + // Symmetric: subtract 32 from each weight to store as signed i8 + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + float scale_factor = static_cast(ov::float16::from_bits(*((uint16_t *) block_data + 104))); + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast(*((int8_t *) (block_data + 128 + 64 + j)))); + } + uint8_t * ql = block_data; + uint8_t * qh = block_data + 128; + auto * signed_weights = reinterpret_cast(weights); + for (int64_t j = 0; j < 32; ++j) { + signed_weights[i * 256 + j] = static_cast((ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 32] = + static_cast((ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 64] = static_cast((ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 96] = + static_cast((ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 128] = + static_cast((ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 160] = + static_cast((ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 192] = + static_cast((ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 224] = + static_cast((ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4)) - 32; + } + }); + } } static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) { @@ -389,11 +414,10 @@ ov::Output make_int8_weights(ov::Tensor & weight, size_t group_size, bool use_bias) { ov::Shape orig_shape = weight.get_shape(); + bool is_signed = (weight.get_element_type() == ov::element::i8); // Symmetric: signed weights, no ZP // Expand dimensions for scales and zp/bias auto scale_shape = scales.get_shape(); - auto zp_shape = zp.get_shape(); - bool is_scalar_zp = zp_shape.empty(); // Symmetric quantization ov::Shape packed_shape = {orig_shape[0], orig_shape[1] / group_size, group_size}; @@ -403,37 +427,48 @@ ov::Output make_int8_weights(ov::Tensor & weight, } else { scale_shape.push_back(1); scales.set_shape(scale_shape); - // For symmetric quantization, zp remains scalar (don't resize) - if (!is_scalar_zp) { + if (!is_signed && zp.get_size() > 0) { + auto zp_shape = zp.get_shape(); zp_shape.push_back(1); zp.set_shape(zp_shape); } } - // Create graph nodes - auto weights_node = std::make_shared(ov::element::u8, packed_shape, - static_cast(weight.data()), nullptr); - weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; auto scales_f16 = std::make_shared(scales); - auto weights_f16 = std::make_shared(weights_node, ov::element::f16); ov::Output result; - if (use_bias && !is_scalar_zp) { - // Bias path: w * s + b (zp tensor holds f16 bias values) - auto bias_f16 = std::make_shared(zp); - auto w_s = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + if (is_signed) { + // Signed path: q * s (no zero point subtraction needed) + auto weights_node = std::make_shared(ov::element::i8, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + result = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); } else { - // Zero point path: (w - zp) * s - auto zero_point = std::make_shared(zp); - float zp_value; - if (ov::op::util::get_single_value(zero_point, zp_value)) { - zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + // Unsigned path + auto weights_node = std::make_shared(ov::element::u8, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + + if (use_bias && zp.get_size() > 0) { + // Bias path: w * s + b (zp tensor holds f16 bias values) + auto bias_f16 = std::make_shared(zp); + auto w_s = + std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Zero point path: (w - zp) * s + auto zero_point = std::make_shared(zp); + float zp_value; + if (ov::op::util::get_single_value(zero_point, zp_value)) { + zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + } + auto zero_point_f16 = std::make_shared(zero_point, ov::element::f16); + auto w_zp = + std::make_shared(weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } - auto zero_point_f16 = std::make_shared(zero_point, ov::element::f16); - auto w_zp = - std::make_shared(weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } if (packed_shape.size() != 2) { @@ -452,11 +487,10 @@ ov::Output make_int4_weights(ov::Tensor & weight, size_t group_size, bool use_bias) { ov::Shape orig_weight_shape = weight.get_shape(); + bool is_signed = (weight.get_element_type() == ov::element::i4); // Symmetric: signed weights, no ZP // Expand dimensions for scales and zp/bias ov::Shape scale_shape = scales.get_shape(); - auto zp_shape = zp.get_shape(); - bool is_scalar_zp = zp_shape.empty(); // Symmetric quantization // Create INT4 weight tensor ov::Shape packed_shape = {orig_weight_shape[0], orig_weight_shape[1] / group_size, group_size}; @@ -467,36 +501,48 @@ ov::Output make_int4_weights(ov::Tensor & weight, } else { scale_shape.push_back(1); scales.set_shape(scale_shape); - // For symmetric quantization, zp remains scalar (don't resize) - if (!is_scalar_zp) { + if (!is_signed && zp.get_size() > 0) { + auto zp_shape = zp.get_shape(); zp_shape.push_back(1); zp.set_shape(zp_shape); } } - auto weights_node = std::make_shared(ov::element::u4, packed_shape, - static_cast(weight.data()), nullptr); - weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; - auto weights_f16 = std::make_shared(weights_node, ov::element::f16); auto scales_f16 = std::make_shared(scales); ov::Output result; - if (use_bias && !is_scalar_zp) { - // Bias path: w * s + b (zp tensor holds f16 bias values) - auto bias_f16 = std::make_shared(zp); - auto w_s = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + if (is_signed) { + // Signed path: q * s (no zero point subtraction needed) + auto weights_node = std::make_shared(ov::element::i4, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + result = std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); } else { - // Zero point path: (w - zp) * s - auto zero_points_node = std::make_shared(zp); - float zp_value; - if (ov::op::util::get_single_value(zero_points_node, zp_value)) { - zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); + // Unsigned path + auto weights_node = std::make_shared(ov::element::u4, packed_shape, + static_cast(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared(weights_node, ov::element::f16); + + if (use_bias && zp.get_size() > 0) { + // Bias path: w * s + b (zp tensor holds f16 bias values) + auto bias_f16 = std::make_shared(zp); + auto w_s = + std::make_shared(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Zero point path: (w - zp) * s + auto zero_points_node = std::make_shared(zp); + float zp_value; + if (ov::op::util::get_single_value(zero_points_node, zp_value)) { + zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); + } + auto zero_points_f16 = std::make_shared(zero_points_node, ov::element::f16); + auto w_zp = + std::make_shared(weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } - auto zero_points_f16 = std::make_shared(zero_points_node, ov::element::f16); - auto w_zp = - std::make_shared(weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); - result = std::make_shared(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); } if (packed_shape.size() != 2) { @@ -699,24 +745,32 @@ OvWeight process_weight_tensor(const ggml_tensor * tensor, const void * data, vo // Quantized path (normal extraction or quantized requant) // Create weight/scale/zp tensors - shared between both paths - ov::element::Type weight_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + // For symmetric quantization, use signed types (i4/i8) and no ZP tensor + ov::element::Type weight_type = layout.is_symmetric ? (layout.is_u4 ? ov::element::i4 : ov::element::i8) : + (layout.is_u4 ? ov::element::u4 : ov::element::u8); ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; - ov::Shape zp_shape = layout.is_symmetric ? ov::Shape{} : scale_shape; if (output_base_ptr) { uint8_t * buf_base = static_cast(output_base_ptr); result.weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); result.scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); - result.zp = ov::Tensor(weight_type, zp_shape, buf_base + layout.zp_offset); + if (!layout.is_symmetric) { + ov::element::Type zp_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + result.zp = ov::Tensor(zp_type, scale_shape, buf_base + layout.zp_offset); + } + // else: result.zp remains default-constructed (empty) for symmetric } else { result.weights = ov::Tensor(weight_type, node_shape); result.scales = ov::Tensor(ov::element::f16, scale_shape); - if (use_bias && !layout.is_symmetric) { - // bias only has effect for asymmetric quant - result.zp = ov::Tensor(ov::element::f16, zp_shape); - } else { - result.zp = ov::Tensor(weight_type, zp_shape); + if (!layout.is_symmetric) { + if (use_bias) { + result.zp = ov::Tensor(ov::element::f16, scale_shape); + } else { + ov::element::Type zp_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + result.zp = ov::Tensor(zp_type, scale_shape); + } } + // else: result.zp remains default-constructed (empty) for symmetric } if (layout.is_requant && layout.requant_type.has_value()) { @@ -741,59 +795,75 @@ void quantize_q4_0(const float * x, auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - - // For Q4_0, zero point is always 8 - if (is_scalar_zp) { - zp[0] = 8 | (8 << 4); // Pack two 4-bit values - } + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i4); // Signed i4 path - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < qk; j++) { - const float v = x[i * qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - max = v; + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + float max = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } } - } - - const float d = max / -8; - - if (d == 0) { - scales[i] = ov::float16(1.0f); - // zp is already set to 8 for symmetric, or set per-block for asymmetric - if (!is_scalar_zp) { + const float d = max / -8; + if (d == 0) { + scales[i] = ov::float16(1.0f); if (i % 2 == 0) { zp[i / 2] = 8; } else { zp[i / 2] |= (8 << 4); } + memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2); + continue; } - memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2); - continue; - } - - const float id = 1.0f / d; - scales[i] = ov::float16(d); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + const float id = 1.0f / d; + scales[i] = ov::float16(d); if (i % 2 == 0) { zp[i / 2] = 8; } else { zp[i / 2] |= (8 << 4); } + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f)); + weights[i * qk / 2 + j] = xi0 | (xi1 << 4); + } } - - for (int j = 0; j < qk / 2; ++j) { - const float x0 = x[i * qk + 2 * j] * id; - const float x1 = x[i * qk + 2 * j + 1] * id; - const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f)); - const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f)); - weights[i * qk / 2 + j] = xi0 | (xi1 << 4); + } else { + // Symmetric: produce signed i4 values in [-8, 7] + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + float max = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + const float d = max / -8; + if (d == 0) { + scales[i] = ov::float16(1.0f); + // i4 value 0 packed: 0x00 + memset(weights + i * qk / 2, 0, qk / 2); + continue; + } + const float id = 1.0f / d; + scales[i] = ov::float16(d); + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + // Signed i4: range [-8, 7]. Quantize as round(x*id), then pack as 4-bit two's complement. + int8_t si0 = (int8_t) std::max(-8, std::min(7, (int) roundf(x0))); + int8_t si1 = (int8_t) std::max(-8, std::min(7, (int) roundf(x1))); + weights[i * qk / 2 + j] = (si0 & 0x0F) | ((si1 & 0x0F) << 4); + } } } } @@ -809,36 +879,42 @@ void quantize_q8_0(const float * x, auto * weights = static_cast(weights_arr.data()); auto * scales = scales_arr.data::value_type>(); - auto * zp = static_cast(zp_arr.data()); - bool is_scalar_zp = (zp_arr.get_size() == 1); // Symmetric quantization - - // For Q8_0, zero point is always 128 - if (is_scalar_zp) { - zp[0] = 128; - } - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path - for (int j = 0; j < qk; j++) { - const float v = x[i * qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); + if (!is_symmetric) { + auto * zp = static_cast(zp_arr.data()); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + amax = std::max(amax, fabsf(v)); } - } - - const float d = amax / 127.0f; - const float id = d ? 1.0f / d : 0.0f; - scales[i] = ov::float16(d); - // For asymmetric quantization, store per-block zero points - if (!is_scalar_zp) { + const float d = amax / 127.0f; + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); zp[i] = 128; + for (int j = 0; j < qk; ++j) { + const float x0 = x[i * qk + j] * id; + const int8_t xi0 = roundf(x0); + weights[i * qk + j] = (uint8_t) (xi0 + 128); + } } - - for (int j = 0; j < qk; ++j) { - const float x0 = x[i * qk + j] * id; - const int8_t xi0 = roundf(x0); - weights[i * qk + j] = (uint8_t) (xi0 + 128); + } else { + // Symmetric: store signed int8 values directly + auto * signed_weights = reinterpret_cast(weights); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + amax = std::max(amax, fabsf(v)); + } + const float d = amax / 127.0f; + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); + for (int j = 0; j < qk; ++j) { + const float x0 = x[i * qk + j] * id; + signed_weights[i * qk + j] = (int8_t) roundf(x0); + } } } } @@ -861,12 +937,8 @@ void quantize_q8_1(const float * x, for (int j = 0; j < qk; j++) { const float v = x[i * qk + j]; - if (v < min) { - min = v; - } - if (v > max) { - max = v; - } + min = std::min(v, min); + max = std::max(v, max); } const float d = (max - min) / ((1 << 8) - 1); diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp index 26dc2d24f82..a8db9b38930 100644 --- a/ggml/src/ggml-openvino/openvino/op/rope.cpp +++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp @@ -9,12 +9,17 @@ #include #include #include +#include +#include +#include #include #include #include +#include #include #include #include +#include #include #include @@ -33,6 +38,12 @@ OutputVector translate_rope(const NodeContext & context) { auto data_node = context.get_input(0).get_node_shared_ptr(); auto output_shape = context.get_output_shape().to_shape(); int32_t * op_params = context.get_output_op_params(); + const int mode = (op_case & 0xFFFF0000) >> 16; + op_case = (op_case & 0x0000FFFF); + + constexpr int TYPE_NORMAL = 0; + constexpr int TYPE_NEOX = 1; + constexpr int TYPE_IMROPE = 2; Output cos_theta_node; Output sin_theta_node; @@ -45,7 +56,7 @@ OutputVector translate_rope(const NodeContext & context) { if (context.get_input_size() == 3) { rope_freqs_weight = context.get_input(2).get_node_shared_ptr(); } - auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight); + auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight, mode == TYPE_IMROPE); sin_theta_node = sin_cos.first; cos_theta_node = sin_cos.second; } @@ -65,11 +76,7 @@ OutputVector translate_rope(const NodeContext & context) { } } - const int mode = op_params[2]; - constexpr int ROPE_TYPE_NORMAL = 0; - constexpr int ROPE_TYPE_NEOX = 2; - - if (mode == ROPE_TYPE_NORMAL) { + if (mode == TYPE_NORMAL) { auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); @@ -97,7 +104,7 @@ OutputVector translate_rope(const NodeContext & context) { auto data_shape = ov::op::v0::Constant::create( ov::element::i64, {4}, std::vector{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); res = std::make_shared(stack, data_shape, false); - } else if (mode == ROPE_TYPE_NEOX) { + } else if (mode == TYPE_NEOX) { auto data_split = std::make_shared( data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2); Output slice_data_node_0 = data_split->outputs()[0]; @@ -112,6 +119,25 @@ OutputVector translate_rope(const NodeContext & context) { std::make_shared(slice_data_node_1, cos_theta_node)); res = std::make_shared(ov::OutputVector{first_half_node, second_half_node}, -1); + } else if (mode == TYPE_IMROPE) { + int64_t n_dims = data_node->get_shape()[3]; + auto cos_sin_shape = std::make_shared(ov::element::i64, ov::Shape{4}, std::vector{1,-1,1,(n_dims >> 1)}); + auto cos_reshaped = std::make_shared(cos_theta_node, cos_sin_shape, true); + auto sin_reshaped = std::make_shared(sin_theta_node, cos_sin_shape, true); + + auto split_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3}); + auto split_a = std::make_shared(data_node, split_axis, 2); + auto x0 = split_a->output(0); + auto x1 = split_a->output(1); + auto mul_a = std::make_shared(x0, cos_reshaped); + auto mul_b = std::make_shared(x1, sin_reshaped); + auto sub = std::make_shared(mul_a, mul_b); + + auto mul_c = std::make_shared(x0, sin_reshaped); + auto mul_d = std::make_shared(x1, cos_reshaped); + auto add = std::make_shared(mul_c, mul_d); + + res = std::make_shared(ov::OutputVector{sub, add}, 3); } return rename_outputs_with_suffix({res}, context.get_name()); diff --git a/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp b/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp new file mode 100644 index 00000000000..d1e9efc33a5 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp @@ -0,0 +1,25 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include +#include + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_unary_gelu(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + auto input = context.get_input(0); + auto res = std::make_shared(input); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op_table.cpp b/ggml/src/ggml-openvino/openvino/op_table.cpp index beadafe8103..1385539279c 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.cpp +++ b/ggml/src/ggml-openvino/openvino/op_table.cpp @@ -31,6 +31,7 @@ std::unordered_map get_supported_ops() { {"GGML_OP_SOFT_MAX", op::translate_soft_max }, {"GGML_OP_SUB", op::translate_1to1_match_2_inputs}, {"GGML_OP_TRANSPOSE", op::translate_transpose }, + {"GGML_UNARY_OP_GELU", op::translate_unary_gelu }, {"GGML_UNARY_OP_SILU", op::translate_unary_silu }, {"GGML_OP_VIEW", op::translate_view }, {"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu }, diff --git a/ggml/src/ggml-openvino/openvino/op_table.h b/ggml/src/ggml-openvino/openvino/op_table.h index 37f763117aa..f546796d2ee 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.h +++ b/ggml/src/ggml-openvino/openvino/op_table.h @@ -21,6 +21,7 @@ GGML_OP_CONVERTER(translate_rms_norm); GGML_OP_CONVERTER(translate_rope); GGML_OP_CONVERTER(translate_scale); GGML_OP_CONVERTER(translate_unary_silu); +GGML_OP_CONVERTER(translate_unary_gelu); GGML_OP_CONVERTER(translate_soft_max); GGML_OP_CONVERTER(translate_transpose); GGML_OP_CONVERTER(translate_view); diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp deleted file mode 100644 index ed2a3ab6d1b..00000000000 --- a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +++ /dev/null @@ -1,123 +0,0 @@ -#include "eliminate_zp.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace ov { -namespace frontend { -namespace ggml { -namespace pass { - -EliminateZeroPoints::EliminateZeroPoints() { - // Find pattern: - // (Multiply Any(scale) - // (Subtract (Convert Constant(data))) - // (Convert Constant(zero_point))) - // where zero_point is a scalar - // If data is u4 and zp value is 8 (q4_0), Replace the Subtract with an i4 Constant whose value is data - zp_val - // If data is u8 and zp value is 128 (q8_0) or 32 (q6_k), Replace the Subtract with an i8 Constant - - auto m_data_constant = ov::pass::pattern::wrap_type(); - auto m_data_convert = ov::pass::pattern::wrap_type({m_data_constant}); - - auto m_zp_constant = ov::pass::pattern::wrap_type(); - auto m_zp_convert = ov::pass::pattern::wrap_type({m_zp_constant}); - - auto m_subtract = ov::pass::pattern::wrap_type({m_data_convert, m_zp_convert}); - auto m_scale = ov::pass::pattern::any_input(); - auto m_multiply = ov::pass::pattern::wrap_type({m_scale, m_subtract}); - - const auto callback = [=](ov::pass::pattern::Matcher & m) { - const auto & pattern_map = m.get_pattern_value_map(); - - auto multiply_node = - std::dynamic_pointer_cast(pattern_map.at(m_multiply).get_node_shared_ptr()); - auto subtract_node = - std::dynamic_pointer_cast(pattern_map.at(m_subtract).get_node_shared_ptr()); - auto data_constant = - std::dynamic_pointer_cast(pattern_map.at(m_data_constant).get_node_shared_ptr()); - auto zp_constant = - std::dynamic_pointer_cast(pattern_map.at(m_zp_constant).get_node_shared_ptr()); - - if (!multiply_node || !subtract_node || !data_constant || !zp_constant) { - return false; - } - - if (ov::shape_size(zp_constant->get_shape()) != 1) { - return false; - } - - auto data_type = data_constant->get_element_type(); - auto zp_data = zp_constant->cast_vector(); - - if (zp_data.empty()) { - return false; - } - - int zp_value = zp_data[0]; - - bool should_eliminate = false; - ov::element::Type target_type; - - if (data_type == ov::element::u4 && zp_value == 8) { - should_eliminate = true; - target_type = ov::element::i4; - } else if (data_type == ov::element::u8 && (zp_value == 128 || zp_value == 32)) { - should_eliminate = true; - target_type = ov::element::i8; - } - - if (!should_eliminate) { - return false; - } - - auto data_shape = data_constant->get_shape(); - size_t total_elements = ov::shape_size(data_shape); - - std::shared_ptr new_constant; - - // TODO improve performance - if (data_type == ov::element::u4) { - auto data_values = data_constant->cast_vector(); - std::vector adjusted_values(total_elements); - - ov::parallel_for(total_elements, [&](size_t i) { - adjusted_values[i] = static_cast(static_cast(data_values[i]) - 8); - }); - - new_constant = std::make_shared(target_type, data_shape, adjusted_values); - } else if (data_type == ov::element::u8) { - auto data_values = data_constant->cast_vector(); - std::vector adjusted_values(total_elements); - - ov::parallel_for(total_elements, [&, zp_value](size_t i) { - adjusted_values[i] = static_cast(static_cast(data_values[i]) - zp_value); - }); - - new_constant = std::make_shared(target_type, data_shape, adjusted_values); - } - - auto new_convert = - std::make_shared(new_constant, subtract_node->get_output_element_type(0)); - ov::replace_node(subtract_node, new_convert); - - return true; - }; - - register_matcher( - std::make_shared(m_multiply, "ov::frontend::ggml::pass::EliminateZeroPoints"), - callback); -} - -} // namespace pass -} // namespace ggml -} // namespace frontend -} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h b/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h deleted file mode 100644 index edd3cd718d9..00000000000 --- a/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +++ /dev/null @@ -1,17 +0,0 @@ -#include "openvino/pass/matcher_pass.hpp" - -namespace ov { -namespace frontend { -namespace ggml { -namespace pass { - -class EliminateZeroPoints : public ov::pass::MatcherPass { -public: - OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::EliminateZeroPoints") - EliminateZeroPoints(); -}; - -} // namespace pass -} // namespace ggml -} // namespace frontend -} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp b/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp new file mode 100644 index 00000000000..f051891c481 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp @@ -0,0 +1,41 @@ +// Copyright (C) 2018-2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +namespace ov { + +/** + * @brief Holds weightless caching attributes of a single constant. + * + * WeightlessCacheAttribute class represents runtime info attribute that holds + * the values of original size of the constant in bytes and the binary offset of the + * constant's data in the weights file used by the weightless caching mechanism. It's + * not copyable in case the data was changed (the original node was replaced by a new + * one produced during the tranformation pipeline) - in that case weightless caching + * can't be used for that constant. + */ +class OPENVINO_API WeightlessCacheAttribute : public RuntimeAttribute { +public: + OPENVINO_RTTI("WeightlessCacheAttribute", "0", RuntimeAttribute) + + WeightlessCacheAttribute() = delete; + + WeightlessCacheAttribute(size_t original_size, size_t bin_offset, ov::element::Type original_dtype) + : original_size(original_size), + bin_offset(bin_offset), + original_dtype(original_dtype) {} + + bool is_copyable() const override; + + size_t original_size; + size_t bin_offset; + ov::element::Type original_dtype; +}; + +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp index 23a1dea2496..0f68a1f5062 100644 --- a/ggml/src/ggml-openvino/openvino/translate_session.cpp +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -3,15 +3,16 @@ #include "ggml-openvino/openvino/node_context.h" #include "ggml-openvino/openvino/utils.h" #include "input_model.h" -#include "pass/eliminate_zp.h" #include "pass/mark_decompression_convert_constant_folding.h" #include "pass/squeeze_matmul.h" +#include "rt_info/weightless_caching_attributes.hpp" #include #include #include #include #include +#include #include #include #include @@ -33,7 +34,6 @@ #include #include #include -#include namespace ov { namespace frontend { @@ -240,6 +240,31 @@ std::shared_ptr TranslateSession::translate_graph(const frontend::InputMo resulting_model = std::make_shared(results, used_params); apply_transformations(resulting_model); + + // Set WeightlessCacheAttribute on large constants to avoid unnecessary memory copies + // in the NPUW plugin. Without this attribute, NPUW's LazyTensor constructor + // (lazy_tensor.cpp, op::Const::Const) will memcpy every constant "in case export + // occurs", doubling memory usage per compile_model call. + // + // The bin_offset field serves as a unique key (not a real file offset) — this is + // the same convention the GPU plugin uses for non-IR models (see + // Plugin::set_weightless_cache_attributes in intel_gpu/src/plugin/plugin.cpp). + // Each constant must have a distinct bin_offset, otherwise GPU's weightless cache + // import will map multiple constants to the same data. + // + // Small constants (< 16 elements) are excluded since they may be introduced by + // optimization patterns and the overhead is negligible. + size_t offset = 0; + for (auto & node : resulting_model->get_ordered_ops()) { + if (auto cnst = ov::as_type_ptr(node); + cnst && cnst->get_byte_size() / cnst->get_element_type().size() >= 16) { + auto & rt_info = cnst->get_rt_info(); + if (rt_info.find(ov::WeightlessCacheAttribute::get_type_info_static()) == rt_info.end()) { + rt_info[ov::WeightlessCacheAttribute::get_type_info_static()] = + ov::WeightlessCacheAttribute(cnst->get_byte_size(), offset++, cnst->get_element_type()); + } + } + } return resulting_model; } @@ -257,7 +282,6 @@ std::shared_ptr TranslateSession::apply_transformations(std::shared_ptris_static()) { - manager.register_pass(); manager.register_pass(); } manager.run_passes(model); diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp b/ggml/src/ggml-openvino/openvino/utils.cpp index 65356a51b51..0baaf88e17a 100644 --- a/ggml/src/ggml-openvino/openvino/utils.cpp +++ b/ggml/src/ggml-openvino/openvino/utils.cpp @@ -2,6 +2,7 @@ #include "ggml-impl.h" +#include #include #include #include @@ -13,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -87,8 +89,11 @@ ov::Output rope_yarn_ramp_mix(int n_dims, const float corr_dims[2], fl auto ramp_y = std::make_shared(std::make_shared(dim_ids, corr_low), denom); auto ramp_clamped = std::make_shared(ramp_y, 0.0f, 1.0f); + // rope_yarn_ramp returns (1 - clamp(y)), so invert before scaling + auto one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + auto ramp_inverted = std::make_shared(one, ramp_clamped); auto ext_factor_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {ext_factor}); - auto ramp_mix = std::make_shared(ramp_clamped, ext_factor_node); + auto ramp_mix = std::make_shared(ramp_inverted, ext_factor_node); return ramp_mix; } @@ -115,6 +120,7 @@ void ggml_rope_yarn_corr_dims(int n_dims, std::pair, ov::Output> make_sin_cos(int32_t * rope_params, std::shared_ptr inp_pos, std::shared_ptr rope_freqs_weight, + bool imrope, bool stateful) { if (stateful) { inp_pos = std::make_shared(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); @@ -122,6 +128,13 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params auto pos_perm = std::make_shared(ov::element::i64, ov::Shape{3}, std::vector{2, 1, 0}); inp_pos = std::make_shared(inp_pos, pos_perm); + } else if (imrope) { + inp_pos = std::make_shared(inp_pos, ov::element::f32); + auto pos_shape = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{5}, {0, 0, 0, 4, -1}); + inp_pos = std::make_shared(inp_pos, pos_shape, true); + auto pos_transpose_shape = + std::make_shared(ov::element::i64, ov::Shape{5}, std::vector{0, 1, 2, 4, 3}); + inp_pos = std::make_shared(inp_pos, pos_transpose_shape); } else { inp_pos = std::make_shared(inp_pos, ov::element::f32); auto pos_perm = @@ -136,6 +149,7 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params float beta_fast; float beta_slow; const int n_dims = rope_params[1]; + const size_t n_dims_half = n_dims >> 1; const int n_ctx_orig = rope_params[4]; memcpy(&freq_base, rope_params + 5, sizeof(float)); memcpy(&freq_scale, rope_params + 6, sizeof(float)); @@ -146,57 +160,74 @@ std::pair, ov::Output> make_sin_cos(int32_t * rope_params const float theta_scale = powf(freq_base, -2.0f / n_dims); - float corr_dims[2]; - ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); - - std::vector factor(n_dims / 2); - factor[0] = 1.0f; - for (size_t i = 1; i < factor.size(); i++) { - factor[i] = theta_scale * factor[i - 1]; - } + std::vector factor(n_dims_half); Output freq_factors; - if (stateful) { - freq_factors = - std::make_shared(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor); - } else { - freq_factors = - std::make_shared(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor); - } - if (rope_freqs_weight) { - freq_factors = std::make_shared(freq_factors, rope_freqs_weight); - } - - auto theta_extrap = std::make_shared(freq_factors, inp_pos); - auto theta_interp = std::make_shared( - theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale})); Output theta; float mscale = attn_factor; - if (ext_factor == 0.0f) { - theta = theta_interp; + if (imrope) { + std::vector gather_indices(n_dims_half); + for (size_t j = 0; j < n_dims_half; j++) { + gather_indices[j] = j % 3; + factor[j] = std::pow(theta_scale, j); + } + auto gather_indices_const = + std::make_shared(ov::element::i64, ov::Shape{n_dims_half}, gather_indices); + auto gather_axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {4}); + inp_pos = std::make_shared(inp_pos, gather_indices_const, gather_axis); + auto factor_const = std::make_shared(ov::element::f32, ov::Shape{n_dims_half}, factor); + theta = std::make_shared(inp_pos, factor_const); } else { - auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor); - Output one; + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + factor[0] = 1.0f; + for (size_t i = 1; i < factor.size(); i++) { + factor[i] = theta_scale * factor[i - 1]; + } if (stateful) { - one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f}); + freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor); } else { - one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + freq_factors = + std::make_shared(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor); + } + if (rope_freqs_weight) { + freq_factors = std::make_shared(freq_factors, rope_freqs_weight); } - auto one_minus_ramp = std::make_shared(one, ramp_mix); - theta = std::make_shared(std::make_shared(theta_interp, one_minus_ramp), - std::make_shared(theta_extrap, ramp_mix)); - mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale)); + auto theta_extrap = std::make_shared(freq_factors, inp_pos); + auto theta_interp = std::make_shared( + theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale})); + + if (ext_factor == 0.0f) { + theta = theta_interp; + } else { + auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor); + Output one; + if (stateful) { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f}); + } else { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + } + auto one_minus_ramp = std::make_shared(one, ramp_mix); + + theta = std::make_shared(std::make_shared(theta_interp, one_minus_ramp), + std::make_shared(theta_extrap, ramp_mix)); + mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale)); + } } Output cos_theta = std::make_shared(theta); Output sin_theta = std::make_shared(theta); - auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale}); + if (!imrope) { + auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale}); + + cos_theta = std::make_shared(cos_theta, mscale_node); + sin_theta = std::make_shared(sin_theta, mscale_node); + } - cos_theta = std::make_shared(cos_theta, mscale_node); - sin_theta = std::make_shared(sin_theta, mscale_node); return std::make_pair(sin_theta, cos_theta); } diff --git a/ggml/src/ggml-openvino/openvino/utils.h b/ggml/src/ggml-openvino/openvino/utils.h index 88dcad4c906..767dd4c53ea 100644 --- a/ggml/src/ggml-openvino/openvino/utils.h +++ b/ggml/src/ggml-openvino/openvino/utils.h @@ -67,6 +67,7 @@ OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std:: std::pair, ov::Output> make_sin_cos(int32_t* rope_params, std::shared_ptr inp_pos, std::shared_ptr rope_freqs_weight = nullptr, + bool imrope = false, bool stateful = false); ov::Output process_view_input(const NodeContext& context, int input_index, int slice_len = 0); diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp index 1b553a0de00..998ef7c9eb4 100644 --- a/ggml/src/ggml-openvino/utils.cpp +++ b/ggml/src/ggml-openvino/utils.cpp @@ -81,8 +81,8 @@ ov::Tensor create_ov_output_tensor(std::shared_ptr ggml_decoder, enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr r_ctx) { auto & core = ov_singleton_core(); const auto & config = ggml_openvino_get_compile_config(); - auto device = r_ctx->device; - bool stateful = r_ctx->stateful; + const auto & device = r_ctx->device; + const auto & stateful = r_ctx->stateful; static auto is_static = false; if (is_naive(cgraph)) { @@ -106,14 +106,26 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< int64_t infer_end_time; { - std::lock_guard lock(r_ctx->ov_compute_mutex); + std::shared_ptr entry; + ModelParams old_m_params; - auto it = r_ctx->decoder_cache.find(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + auto it = r_ctx->decoder_cache.find(key); + cache_hit = it != r_ctx->decoder_cache.end(); + if (cache_hit) { + entry = it->second; + } else { + auto mutex = std::make_shared(); + entry = std::make_shared(mutex); + r_ctx->decoder_cache[key] = entry; + } + } + + std::lock_guard lock(*(entry->mutex)); - cache_hit = it != r_ctx->decoder_cache.end(); - ModelParams old_m_params; if (cache_hit) { - ggml_decoder = it->second; + ggml_decoder = entry->ptr; old_m_params = ggml_decoder->get_model_params(); cache_hit = old_m_params.can_reuse_dynamically(m_params); } @@ -126,7 +138,10 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< ggml_decoder->update_io(cgraph); } ggml_decoder->add_extra_inputs(); - infer_request = r_ctx->infer_request_cache.at(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + infer_request = r_ctx->infer_request_cache.at(key); + } if (stateful) { const auto * inp_pos = get_inp_pos_tensor(cgraph); @@ -170,7 +185,10 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< conversion_end_time = decoder_end_time; compile_end_time = decoder_end_time; } else { - r_ctx->infer_request_cache.erase(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache.erase(key); + } std::shared_ptr model; auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); @@ -199,8 +217,7 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< } compile_end_time = ggml_time_us(); infer_request = std::make_shared(compiled_model.create_infer_request()); - r_ctx->infer_request_cache[key] = infer_request; - r_ctx->decoder_cache[key] = ggml_decoder; + entry->ptr = ggml_decoder; std::vector ov_input_names; std::vector ov_output_names; @@ -210,8 +227,13 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< for (const auto & ov_output : model->get_results()) { ov_output_names.push_back(ov_output->get_friendly_name()); } - r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); - r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache[key] = infer_request; + r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); + r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + } if (stateful) { const auto * inp_pos = get_inp_pos_tensor(cgraph); @@ -224,8 +246,13 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr< } } - auto ov_input_names = r_ctx->ov_input_names_cache[key]; - auto ov_output_names = r_ctx->ov_output_names_cache[key]; + std::vector ov_input_names; + std::vector ov_output_names; + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + ov_input_names = r_ctx->ov_input_names_cache[key]; + ov_output_names = r_ctx->ov_output_names_cache[key]; + } for (size_t i = 0; i < ov_input_names.size(); i++) { auto param_name = ov_input_names[i]; @@ -306,12 +333,26 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrdecoder_cache.find(key); - - cache_hit = it != r_ctx->decoder_cache.end(); + std::shared_ptr entry; ModelParams old_m_params; + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + auto it = r_ctx->decoder_cache.find(key); + cache_hit = it != r_ctx->decoder_cache.end(); + if (cache_hit) { + entry = it->second; + } else { + auto mutex = std::make_shared(); + entry = std::make_shared(mutex); + r_ctx->decoder_cache[key] = entry; + } + } + + std::lock_guard lock(*(entry->mutex)); + if (cache_hit) { - ggml_decoder = it->second; + ggml_decoder = entry->ptr; old_m_params = ggml_decoder->get_model_params(); cache_hit = old_m_params.can_reuse_statically(m_params); } @@ -325,14 +366,21 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrupdate_io(cgraph); } ggml_decoder->add_extra_inputs(); - infer_request = is_prefill ? r_ctx->infer_request_cache_prefill.at(key) : r_ctx->infer_request_cache.at(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + infer_request = + is_prefill ? r_ctx->infer_request_cache_prefill.at(key) : r_ctx->infer_request_cache.at(key); + } decoder_end_time = ggml_time_us(); conversion_end_time = decoder_end_time; compile_end_time = decoder_end_time; } else { - r_ctx->infer_request_cache.erase(key); - r_ctx->infer_request_cache_prefill.erase(key); + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache.erase(key); + r_ctx->infer_request_cache_prefill.erase(key); + } std::shared_ptr model; auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); @@ -372,16 +420,14 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrinfer_request_cache_prefill[key] = - std::make_shared(compiled_model_prefill.create_infer_request()); - r_ctx->infer_request_cache[key] = - std::make_shared(compiled_model_decode.create_infer_request()); + auto infer_request_prefill = std::make_shared(compiled_model_prefill.create_infer_request()); + auto infer_request_decode = std::make_shared(compiled_model_decode.create_infer_request()); compile_end_time = ggml_time_us(); model = is_prefill ? model_prefill : model_decode; ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode; - infer_request = is_prefill ? r_ctx->infer_request_cache_prefill[key] : r_ctx->infer_request_cache[key]; - r_ctx->decoder_cache[key] = ggml_decoder; + infer_request = is_prefill ? infer_request_prefill : infer_request_decode; + entry->ptr = ggml_decoder; std::vector ov_input_names; std::vector ov_output_names; @@ -391,18 +437,29 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_results()) { ov_output_names.push_back(ov_output->get_friendly_name()); } - r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); - r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache_prefill[key] = infer_request_prefill; + r_ctx->infer_request_cache[key] = infer_request_decode; + r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); + r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + } } - auto ov_input_names = r_ctx->ov_input_names_cache[key]; - auto ov_output_names = r_ctx->ov_output_names_cache[key]; + std::vector ov_input_names_local; + std::vector ov_output_names_local; + { + std::lock_guard map_lock(r_ctx->ctx_mutex); + ov_input_names_local = r_ctx->ov_input_names_cache[key]; + ov_output_names_local = r_ctx->ov_output_names_cache[key]; + } if (is_prefill) { auto inp_len = inp_pos->ne[0]; for (int chunk_index = 0; chunk_index * prefill_chunk_size < inp_len; chunk_index++) { - for (size_t i = 0; i < ov_input_names.size(); i++) { - auto param_name = ov_input_names[i]; + for (size_t i = 0; i < ov_input_names_local.size(); i++) { + auto param_name = ov_input_names_local[i]; auto input_tensor = get_ov_input_tensor_static_prefill(ggml_decoder, param_name, chunk_index); infer_request->set_input_tensor(i, input_tensor); @@ -412,8 +469,8 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_model_outputs().at(ov_output_names[i]); + for (size_t i = 0; i < ov_output_names_local.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names_local[i]); auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); infer_request->set_output_tensor(i, output_tensor); } @@ -421,16 +478,16 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrinfer(); if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) { - for (size_t i = 0; i < ov_output_names.size(); i++) { + for (size_t i = 0; i < ov_output_names_local.size(); i++) { const auto output_tensor = infer_request->get_output_tensor(i); - print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data()); + print_output_tensor_info(ov_output_names_local[i], output_tensor, output_tensor.data()); } } } infer_end_time = ggml_time_us(); } else { - for (size_t i = 0; i < ov_input_names.size(); i++) { - auto param_name = ov_input_names[i]; + for (size_t i = 0; i < ov_input_names_local.size(); i++) { + auto param_name = ov_input_names_local[i]; auto input_tensor = get_ov_input_tensor_static_decode(ggml_decoder, param_name); infer_request->set_input_tensor(i, input_tensor); @@ -440,8 +497,8 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_model_outputs().at(ov_output_names[i]); + for (size_t i = 0; i < ov_output_names_local.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names_local[i]); auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); infer_request->set_output_tensor(i, output_tensor); } @@ -450,9 +507,9 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptrget_output_tensor(i); - print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data()); + print_output_tensor_info(ov_output_names_local[i], output_tensor, output_tensor.data()); } } } diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h index 656573d1389..2c72e33c352 100644 --- a/ggml/src/ggml-openvino/utils.h +++ b/ggml/src/ggml-openvino/utils.h @@ -3,12 +3,15 @@ #include "ggml-impl.h" #include +#include #include #include +#include #include #include #include #include +#include #include struct graph_key { @@ -40,11 +43,17 @@ struct graph_key_hash { } }; +struct decoder_runtime_ctx { + decoder_runtime_ctx(std::shared_ptr mutex) : mutex(std::move(mutex)) {} + std::shared_ptr mutex; + std::shared_ptr ptr; +}; + struct ov_runtime_context { - std::mutex ov_compute_mutex; + mutable std::mutex ctx_mutex; std::string device; bool stateful; - std::unordered_map, graph_key_hash> decoder_cache; + std::unordered_map, graph_key_hash> decoder_cache; std::unordered_map, graph_key_hash> infer_request_cache; std::unordered_map, graph_key_hash> infer_request_cache_prefill; std::unordered_map, graph_key_hash> ov_input_names_cache; @@ -53,11 +62,22 @@ struct ov_runtime_context { // Simultanous stateful inference request support to be added. size_t stateful_kv_size; std::map kv_state_input_name_map; + std::atomic backend_count; ov_runtime_context() : device("CPU"), stateful(false), - stateful_kv_size(0) {} + stateful_kv_size(0), + backend_count(0) {} + + void clear_caches() { + std::lock_guard lock(ctx_mutex); + decoder_cache.clear(); + infer_request_cache.clear(); + infer_request_cache_prefill.clear(); + ov_input_names_cache.clear(); + ov_output_names_cache.clear(); + } }; enum ggml_status ov_graph_compute(struct ggml_cgraph * cgraph, ggml_backend_t backend); From e2014d6959fd6194d434daf1ea199715b427beba Mon Sep 17 00:00:00 2001 From: Mengsheng Wu Date: Wed, 22 Apr 2026 04:53:44 +0800 Subject: [PATCH 469/831] hexagon: fix missing v79 entry in libggml-htp.inf (llama/22194) --- ggml/src/ggml-hexagon/libggml-htp.inf | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-hexagon/libggml-htp.inf b/ggml/src/ggml-hexagon/libggml-htp.inf index 656d2d9ab26..360d8b1228e 100644 --- a/ggml/src/ggml-hexagon/libggml-htp.inf +++ b/ggml/src/ggml-hexagon/libggml-htp.inf @@ -18,6 +18,7 @@ libggml-htp-v68.so = 1 libggml-htp-v69.so = 1 libggml-htp-v73.so = 1 libggml-htp-v75.so = 1 +libggml-htp-v79.so = 1 libggml-htp-v81.so = 1 [ControlFlags] @@ -31,6 +32,7 @@ libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v81.so,,,0x10 ;COPYFLG_NO_OVERWRITE [Strings] From 84a6b5c03903504ccfd7bfad321d8c6dc9fbd708 Mon Sep 17 00:00:00 2001 From: Shreya Jain Date: Tue, 21 Apr 2026 14:16:04 -0700 Subject: [PATCH 470/831] Hexagon: DAIG op (llama/22195) * hexagon: Add DIAG op * hexagon: add HVX support and DMA double buffering * hexagon: fix fatal error * hexagon: remove as many pragma(s) as possible --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 28 +++ ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/diag-ops.c | 216 +++++++++++++++++++++++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + 6 files changed, 250 insertions(+) create mode 100644 ggml/src/ggml-hexagon/htp/diag-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 3d68b80048f..5e206c5e9de 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2596,6 +2596,29 @@ static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * se return true; } +static bool ggml_hexagon_supported_diag(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + // diag only supports F32 currently + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + // Input must have ne[1] == 1 (vector input) + if (src0->ne[1] != 1) { + return false; + } + + // Output must be square in first two dimensions + if (dst->ne[0] != dst->ne[1] || dst->ne[0] != src0->ne[0]) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->c_name(); @@ -2632,6 +2655,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_ROPE: return HTP_OP_ROPE; case GGML_OP_REPEAT: return HTP_OP_REPEAT; case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; + case GGML_OP_DIAG: return HTP_OP_DIAG; case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { @@ -3159,6 +3183,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cumsum(sess, op); break; + case GGML_OP_DIAG: + supp = ggml_hexagon_supported_diag(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 9ca759459d4..82c10b57bbf 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -34,6 +34,7 @@ add_library(${HTP_LIB} SHARED argsort-ops.c ssm-conv.c cumsum-ops.c + diag-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/diag-ops.c b/ggml/src/ggml-hexagon/htp/diag-ops.c new file mode 100644 index 00000000000..9b3194d9084 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/diag-ops.c @@ -0,0 +1,216 @@ +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hex-utils.h" +#include "hvx-copy.h" +#include "hex-dma.h" + +#define htp_diag_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + \ + const uint32_t ne02 = src0->ne[2]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_diag_context { + struct htp_ops_context * octx; + size_t src_batch_size; + size_t dst_row_size; + size_t src_batch_size_aligned; + size_t dst_row_size_aligned; + uint32_t batches_per_thread; + uint32_t total_batches; +}; + +#define htp_diag_preamble \ + struct htp_diag_context * dctx = (struct htp_diag_context *) data; \ + struct htp_ops_context * octx = dctx->octx; \ + htp_diag_tensors_preamble; + +static inline void hvx_diag_row_f32(const float * restrict src, float * restrict dst, + uint32_t row_idx, uint32_t n) { + hvx_splat_f32_a((uint8_t *) dst, 0.0f, n); + dst[row_idx] = src[row_idx]; +} + +// --------------------------------------------------------------------------- +// Per thread worker: DMA src fetch, compute in VTCM, DMA dst writeback +// --------------------------------------------------------------------------- + +static void diag_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) { + htp_diag_preamble; + dma_queue * dma_queue = octx->ctx->dma[ith]; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ib0 = dctx->batches_per_thread * ith; + const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches); + + if (ib0 >= ib1) { + return; + } + + const size_t src_batch_size = dctx->src_batch_size; + const size_t dst_row_size = dctx->dst_row_size; + const size_t src_batch_size_aligned = dctx->src_batch_size_aligned; + const size_t dst_row_size_aligned = dctx->dst_row_size_aligned; + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + // 1 src buffer + 1 dst row buffer per thread in VTCM + uint8_t * src_spad = octx->src0_spad.data + (ith * src_batch_size_aligned); + uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned); + + for (uint32_t ib = ib0; ib < ib1; ib++) { + const uint32_t i3 = ib / ne02; + const uint32_t i2 = ib % ne02; + + const uint8_t * src_batch = src_data + i3 * nb03 + i2 * nb02; + + // Fetch source vector into VTCM + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src_spad, src_batch), + src_batch_size_aligned, src_batch_size, 1); + dma_queue_flush(dma_queue); + + const float * src_spad_f32 = (const float *) src_spad; + float * dst_spad_f32 = (float *) dst_spad; + + for (uint32_t i1 = 0; i1 < ne1; i1++) { + // Compute row in VTCM + hvx_diag_row_f32(src_spad_f32, dst_spad_f32, i1, ne0); + + // Write completed row back to DDR + uint8_t * dst_row = dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1; + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_row, dst_spad), + dst_row_size, dst_row_size_aligned, 1); + dma_queue_flush(dma_queue); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "diag-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// Per thread worker: Direct HVX (no DMA) +// --------------------------------------------------------------------------- + +static void diag_thread_f32(unsigned int nth, unsigned int ith, void * data) { + htp_diag_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + const uint32_t ib0 = dctx->batches_per_thread * ith; + const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches); + + for (uint32_t ib = ib0; ib < ib1; ib++) { + const uint32_t i3 = ib / ne02; + const uint32_t i2 = ib % ne02; + + const float * restrict src_batch = (const float *)(src_data + i3 * nb03 + i2 * nb02); + + for (uint32_t i1 = 0; i1 < ne1; i1++) { + float * restrict dst_row = (float *)(dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1); + hvx_diag_row_f32(src_batch, dst_row, i1, ne0); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "diag-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_diag_f32(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t total_batches = src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_batches); + + const size_t src_batch_size = src0->ne[0] * sizeof(float); + const size_t dst_row_size = dst->ne[0] * sizeof(float); + const size_t src_batch_size_aligned = hex_round_up(src_batch_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // 1 src buffer + 1 dst row buffer per thread + const size_t spad_per_thread = src_batch_size_aligned + dst_row_size_aligned; + + octx->src0_spad.size_per_thread = src_batch_size_aligned; + octx->dst_spad.size_per_thread = dst_row_size_aligned; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; + + struct htp_diag_context dctx = { + .octx = octx, + .src_batch_size = src_batch_size, + .dst_row_size = dst_row_size, + .src_batch_size_aligned = src_batch_size_aligned, + .dst_row_size_aligned = dst_row_size_aligned, + .batches_per_thread = (total_batches + n_threads - 1) / n_threads, + .total_batches = total_batches, + }; + + if (octx->ctx->vtcm_size < spad_per_thread * n_threads) { + worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32, &dctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32_dma, &dctx, n_threads); + } + + return HTP_STATUS_OK; +} + +int op_diag(struct htp_ops_context * octx) { + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_diag_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 8b5e47adef8..038941af0f2 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -98,5 +98,6 @@ int op_repeat(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); +int op_diag(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 79b5ecd2270..002dd1c12d2 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -80,6 +80,7 @@ enum htp_op_code { HTP_OP_SSM_CONV, HTP_OP_REPEAT, HTP_OP_CUMSUM, + HTP_OP_DIAG, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 5091623a653..d633145c909 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -514,6 +514,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_CUMSUM: return op_cumsum(octx); + case HTP_OP_DIAG: + return op_diag(octx); + case HTP_OP_INVALID: break; From 2e5eb6e9512a51129827698576307c7d4f5148d4 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Wed, 22 Apr 2026 08:05:21 +0900 Subject: [PATCH 471/831] ggml-webgpu: reset CPU/GPU profiling time when freeing context (llama/22050) * Reset the CPU/GPU profiling time when freeing context. * move GPU profiling time from global context to webgpu_context. --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index aa20a745e0a..a2923145230 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -211,6 +211,7 @@ struct webgpu_global_context_struct { wgpu::Buffer memset_params_buf; webgpu_pipeline memset_pipeline; + // TODO: We should rework the CPU profiling time handling to make it more useful. ref: https://github.com/ggml-org/llama.cpp/pull/22050 #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) std::unordered_map cpu_time_ms; @@ -218,11 +219,6 @@ struct webgpu_global_context_struct { std::unordered_map cpu_detail_ms; #endif -#ifdef GGML_WEBGPU_GPU_PROFILE - // Profiling: per-shader GPU time in ms - std::unordered_map shader_gpu_time_ms; -#endif - #ifdef GGML_WEBGPU_DEBUG wgpu::Buffer debug_host_buf; wgpu::Buffer debug_dev_buf; @@ -268,10 +264,12 @@ struct webgpu_context_struct { size_t memset_bytes_per_thread; #ifdef GGML_WEBGPU_GPU_PROFILE - wgpu::Buffer profile_timestamp_dev_buf; - wgpu::Buffer profile_timestamp_host_buf; - wgpu::QuerySet profile_timestamp_query_set; - uint32_t profile_timestamp_query_count = 0; + // Profiling: per-shader GPU time in ms + std::unordered_map shader_gpu_time_ms; + wgpu::Buffer profile_timestamp_dev_buf; + wgpu::Buffer profile_timestamp_host_buf; + wgpu::QuerySet profile_timestamp_query_set; + uint32_t profile_timestamp_query_count = 0; #endif ~webgpu_context_struct() { @@ -713,12 +711,12 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { #ifdef GGML_WEBGPU_GPU_PROFILE std::cout << "\n[ggml_webgpu gpu profiling summary]\n"; double total_gpu = 0.0; - for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { total_gpu += kv.second; } std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n"; std::cout << "\nggml_webgpu: gpu breakdown:\n"; - for (const auto & kv : ctx->webgpu_ctx->global_ctx->shader_gpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0; std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2) << pct << "%)\n"; @@ -2511,7 +2509,7 @@ static void ggml_backend_webgpu_collect_profile_results(webgpu_context & for (size_t i = 0; i < pipeline_names.size(); ++i) { // WebGPU timestamps are in ns; convert to ms. const double elapsed_ms = double(ts_data[2 * i + 1] - ts_data[2 * i]) * 1e-6; - ctx->global_ctx->shader_gpu_time_ms[pipeline_names[i]] += elapsed_ms; + ctx->shader_gpu_time_ms[pipeline_names[i]] += elapsed_ms; } ctx->profile_timestamp_host_buf.Unmap(); From d6a417408c5a764ff484a0210d5d99a55af9d8c9 Mon Sep 17 00:00:00 2001 From: Aparna M P Date: Wed, 22 Apr 2026 04:54:20 +0530 Subject: [PATCH 472/831] hexagon: add support for FILL op (llama/22198) Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 16 +++ ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/fill-ops.c | 123 +++++++++++++++++++++++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + 6 files changed, 145 insertions(+) create mode 100644 ggml/src/ggml-hexagon/htp/fill-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 5e206c5e9de..cdd9fcf5928 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2655,6 +2655,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_ROPE: return HTP_OP_ROPE; case GGML_OP_REPEAT: return HTP_OP_REPEAT; case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; + case GGML_OP_FILL: return HTP_OP_FILL; case GGML_OP_DIAG: return HTP_OP_DIAG; case GGML_OP_UNARY: @@ -3053,6 +3054,17 @@ static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * se return true; } +static bool ggml_hexagon_supported_fill(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * dst = op; + + if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { auto sess = static_cast(dev->context); @@ -3183,6 +3195,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cumsum(sess, op); break; + case GGML_OP_FILL: + supp = ggml_hexagon_supported_fill(sess, op); + break; + case GGML_OP_DIAG: supp = ggml_hexagon_supported_diag(sess, op); break; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 82c10b57bbf..b1ae60a9c43 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -34,6 +34,7 @@ add_library(${HTP_LIB} SHARED argsort-ops.c ssm-conv.c cumsum-ops.c + fill-ops.c diag-ops.c ) diff --git a/ggml/src/ggml-hexagon/htp/fill-ops.c b/ggml/src/ggml-hexagon/htp/fill-ops.c new file mode 100644 index 00000000000..3ccfbe74ee4 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/fill-ops.c @@ -0,0 +1,123 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include + +#include "hvx-copy.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" + +// ggml op_params layout for FILL: +// op_params[0] (as float) - the scalar fill value + +#define fill_preamble \ + const struct htp_tensor * dst = octx->dst; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ + const uint32_t nr = ne1 * ne2 * ne3; + +struct htp_fill_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; + uint32_t total_rows; // ne1 * ne2 * ne3 + bool opt_path; + HVX_Vector splat_vec; + uint32_t elem_size; +}; + +static void fill_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_fill_context * fctx = (const struct htp_fill_context *) data; + struct htp_ops_context * octx = fctx->octx; + fill_preamble; + + // Parallelise over the flat row index spanning ne1*ne2*ne3 + const uint32_t ir0 = fctx->nrows_per_thread * ith; + const uint32_t ir1 = MIN(ir0 + fctx->nrows_per_thread, fctx->total_rows); + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + if (fctx->opt_path) { + // Opt path: tensor is fully contiguous, treat as flat array + const uint32_t elem_start = ir0 * ne0; + const uint32_t elem_end = ir1 * ne0; + uint8_t * dst_ptr = (uint8_t *) dst->data + elem_start * fctx->elem_size; + hvx_splat_u(dst_ptr, fctx->splat_vec, elem_end - elem_start, fctx->elem_size); + } else { + // Non-contiguous path: must respect strides + for (uint32_t ir = ir0; ir < ir1; ++ir) { + const uint32_t i1 = ir % ne1; + const uint32_t i2 = (ir / ne1) % ne2; + const uint32_t i3 = ir / (ne1 * ne2); + uint8_t * dst_ptr = (uint8_t *) dst->data + i1*nb1 + i2*nb2 + i3*nb3; + hvx_splat_u(dst_ptr, fctx->splat_vec, ne0, fctx->elem_size); + } + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + FARF(HIGH, "fill %u/%u: rows %u:%u usec %u\n", + ith, nth, ir0, ir1, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_fill(struct htp_ops_context * octx) { + fill_preamble; + + if (dst->type != HTP_TYPE_F32 && dst->type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + // nr = ne1*ne2*ne3 (flat row count across all outer dims); parallelise over it. + const uint32_t n_threads = MIN(nr, octx->n_threads); + + // Optimize if fully contiguous: skip stride arithmetic, treat as flat array + const bool opt_path = (nb2 == nb1 * ne1) && (nb3 == nb2 * ne2); + + FARF(HIGH, "fill: (%ux%ux%ux%u) type=%u opt=%d\n", + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->type, (int) opt_path); + + float val_f32 = 0.f; + memcpy(&val_f32, &octx->op_params[0], sizeof(float)); + + struct htp_fill_context fctx = { + .octx = octx, + .nrows_per_thread = (nr + n_threads - 1) / n_threads, + .total_rows = nr, + .opt_path = opt_path, + }; + + switch (dst->type) { + case HTP_TYPE_F32: + fctx.splat_vec = hvx_vec_splat_f32(val_f32); + fctx.elem_size = sizeof(float); + break; + case HTP_TYPE_F16: + fctx.splat_vec = hvx_vec_splat_f16((_Float16) val_f32); + fctx.elem_size = sizeof(_Float16); + break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + worker_pool_run_func(octx->ctx->worker_pool, fill_thread, &fctx, n_threads); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 038941af0f2..78455e6b071 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -98,6 +98,7 @@ int op_repeat(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); +int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 002dd1c12d2..62d6ec02241 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -80,6 +80,7 @@ enum htp_op_code { HTP_OP_SSM_CONV, HTP_OP_REPEAT, HTP_OP_CUMSUM, + HTP_OP_FILL, HTP_OP_DIAG, HTP_OP_INVALID diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index d633145c909..9185c9ffe15 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -514,6 +514,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_CUMSUM: return op_cumsum(octx); + case HTP_OP_FILL: + return op_fill(octx); + case HTP_OP_DIAG: return op_diag(octx); From 447be522e91bc83679fa714eb40f3e994c2aaa73 Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Tue, 21 Apr 2026 23:18:57 -0400 Subject: [PATCH 473/831] ggml-webgpu(shader): support conv2d kernels. (llama/21964) * ggml(webgpu): fix the busy-polls in Emscripten in the waitAny after #20618, and remove the busy webgpu log * Merge with upstream * Fix GET_ROWS packed integer NaN when using f16 as memory buffer in shader quants * Update Unary wgsl EXP and EXPM1 for f16 stability * Fix GET_ROWS IQ4_XS strcut for NaN f16 canonicalization * Fix numerical percision for unary sqrt when working with f16 * Fix NaN canonicalization for packed integers using f16 * Update err threshold for binary div ops when using f16 * backend: Keep one Dawn/WebGPU instance alive for the lifetime of the static backend * clean: uncomment existing code logs * clean: clean the unncessary debug info * Refactor and generalize dequant helpers * Remove deprecated quant structs * Refactor shader defines to reduce repetition * Remove error override for F16 type * fix: fix the accidential removal of the proper initialization of ctx * clean: clean legacy and format code * fix: did not modify tests ops * shader(conv2d): add conv2d shader kernels and pass f32 and f16 tests * shader(conv2d): fix the out of bounds memory access in the weight indexing * shader(conv2d): clean unused variables and optimize the computation * merge: use the new entries function * clean: address the formatting issues * clean: address the warning issues * clear: clean the shader editorconfig-checker issues * clear: clean the shader editorconfig-checker with utf-8 --------- Co-authored-by: Jeremy J. Hartmann --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 63 +++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 89 ++++++++++ ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl | 165 ++++++++++++++++++ 3 files changed, 317 insertions(+) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 9d88f98050e..f84dfee9d39 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -240,6 +240,27 @@ struct ggml_webgpu_ssm_conv_pipeline_key { } }; +/** CONV 2D */ +struct ggml_webgpu_conv2d_pipeline_key { + ggml_type weight_type; + ggml_type input_type; + ggml_type output_type; + + bool operator==(const ggml_webgpu_conv2d_pipeline_key & other) const { + return weight_type == other.weight_type && input_type == other.input_type && output_type == other.output_type; + } +}; + +struct ggml_webgpu_conv2d_pipeline_key_hash { + size_t operator()(const ggml_webgpu_conv2d_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.weight_type); + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + return seed; + } +}; + /** Gated Delta Net **/ struct ggml_webgpu_gated_delta_net_pipeline_key { int type; @@ -789,6 +810,8 @@ class ggml_webgpu_shader_lib { rope_pipelines; std::unordered_map soft_max_pipelines; + std::unordered_map + conv2d_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -2382,6 +2405,46 @@ class ggml_webgpu_shader_lib { return soft_max_pipelines[key]; } + webgpu_pipeline get_conv2d_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_conv2d_pipeline_key key = {}; + key.weight_type = context.src0->type; + key.input_type = context.src1->type; + key.output_type = context.dst->type; + + auto it = conv2d_pipelines.find(key); + if (it != conv2d_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "conv_2d"; + + auto push_type_defines = [&](const char * prefix, ggml_type type) { + std::string s_prefix = prefix; + if (type == GGML_TYPE_F32) { + defines.push_back(s_prefix + "_F32"); + } else if (type == GGML_TYPE_F16) { + defines.push_back(s_prefix + "_F16"); + } else { + GGML_ABORT("Unsupported type for CONV_2D shader"); + } + }; + + push_type_defines("WEIGHT", key.weight_type); + push_type_defines("INPUT", key.input_type); + push_type_defines("OUTPUT", key.output_type); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_conv2d, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + conv2d_pipelines[key] = pipeline; + return conv2d_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index a2923145230..551586751c0 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,6 +8,7 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-webgpu-shader-lib.hpp" +#include "ggml.h" #ifdef __EMSCRIPTEN__ # include @@ -921,6 +922,87 @@ static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } +static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t p0 = ggml_get_op_params_i32(dst, 2); + const int32_t p1 = ggml_get_op_params_i32(dst, 3); + const int32_t d0 = ggml_get_op_params_i32(dst, 4); + const int32_t d1 = ggml_get_op_params_i32(dst, 5); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], + + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + + (uint32_t) s0, + (uint32_t) s1, + (uint32_t) p0, + (uint32_t) p1, + (uint32_t) d0, + (uint32_t) d1, + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), + }; + + uint32_t max_wg_size = + std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX); + uint32_t wg_size = + std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size); + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = wg_size; + + webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + uint32_t n_out = ggml_nelements(dst); + uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size); + uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + uint32_t wg_x = std::min(total_wg, max_wg); + uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, @@ -2477,6 +2559,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context c case GGML_OP_SUM: case GGML_OP_SUM_ROWS: return ggml_webgpu_sum_rows(ctx, src0, node); + case GGML_OP_CONV_2D: + return ggml_webgpu_conv_2d(ctx, src0, src1, node); default: return std::nullopt; } @@ -3495,6 +3579,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SOLVE_TRI: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; break; + case GGML_OP_CONV_2D: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) && + (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + break; case GGML_OP_SSM_CONV: supports_op = op->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl new file mode 100644 index 00000000000..9eb131dc221 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl @@ -0,0 +1,165 @@ +#include "common_decls.tmpl" +enable f16; + +@group(0) @binding(0) +#if defined(WEIGHT_F32) +var weights: array; +#elif defined(WEIGHT_F16) +var weights: array; +#endif + +@group(0) @binding(1) +#if defined(INPUT_F32) +var input: array; +#elif defined(INPUT_F16) +var input: array; +#endif + +@group(0) @binding(2) +#if defined(OUTPUT_F32) +var output: array; +#elif defined(OUTPUT_F16) +var output: array; +#endif + +struct Params { + offset_w: u32, + offset_i: u32, + offset_o: u32, + + // element strides + sw0: u32, sw1: u32, sw2: u32, sw3: u32, + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + // kernel dimensions + KW: u32, KH: u32, IC: u32, + // input dimensions + IW: u32, IH: u32, + // output dimensions + OW: u32, OH: u32, OC_out: u32, N_out: u32, + + // stride + s0: u32, s1: u32, + // padding + p0: u32, p1: u32, + // dilation + d0: u32, d1: u32, +}; + +@group(0) @binding(3) +var params: Params; + +fn load_weight(idx: u32) -> f32 { + #if defined(WEIGHT_F32) + return weights[idx]; + #elif defined(WEIGHT_F16) + return f32(weights[idx]); + #endif +} + +fn load_input(idx: u32) -> f32 { + #if defined(INPUT_F32) + return input[idx]; + #elif defined(INPUT_F16) + return f32(input[idx]); + #endif +} + +fn store_output(idx: u32, val: f32) { + #if defined(OUTPUT_F32) + output[idx] = val; + #elif defined(OUTPUT_F16) + output[idx] = f16(val); + #endif +} + +fn ceil_div_u32(x: u32, y: u32) -> u32 { + return (x + y - 1) / y; +} + +// returns the first valid kernel index k such that base + k * step >= 0 +fn first_valid_k(base: i32, step: u32) -> u32 { + if (base >= 0) { + return 0; + } + + return ceil_div_u32(u32(-base), step); +} + +// returns the first invalid kernel index k such that base + k * step >= limit so valid k are in [0, end_valid_k) +fn end_valid_k(base: i32, step: u32, limit: u32, k_max: u32) -> u32 { + let remaining = i32(limit) - base; + if (remaining <= 0) { + return 0; + } + + return min(k_max, ceil_div_u32(u32(remaining), step)); +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + + let threads_per_group = u32(WG_SIZE); + let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y; + let n_out = params.OW * params.OH * params.OC_out * params.N_out; + + var sum: f32 = 0.0; + if (i_out >= n_out) { + return; + } + + // Kernel layout: [KW, KH, IC, ..] + // Input layout: [IW, IH, .., ..] + // Output layout: [OW, OH, OC, N] + + var i = i_out; + let n = i / (params.OC_out * params.OH * params.OW); + i = i % (params.OC_out * params.OH * params.OW); + let oc = i / (params.OH * params.OW); + i = i % (params.OH * params.OW); + let oh = i / params.OW; + let ow = i % params.OW; + + let ow_base = i32(ow * params.s0) - i32(params.p0); + let oh_base = i32(oh * params.s1) - i32(params.p1); + + // clip the valid kernel window once + let kw_begin = first_valid_k(ow_base, params.d0); + let kw_end = end_valid_k(ow_base, params.d0, params.IW, params.KW); + let kh_begin = first_valid_k(oh_base, params.d1); + let kh_end = end_valid_k(oh_base, params.d1, params.IH, params.KH); + + // entire receptive field is out of bounds + if (kw_begin >= kw_end || kh_begin >= kh_end) { + let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3; + store_output(out_idx, 0.0); + return; + } + + let weight_oc_base = params.offset_w + oc * params.sw3; + let input_n_base = params.offset_i + n * params.si3; + + for (var ic: u32 = 0; ic < params.IC; ic += 1) { + let w_base_ic = ic * params.sw2 + weight_oc_base; + let in_base = ic * params.si2 + input_n_base; + + for (var kh: u32 = kh_begin; kh < kh_end; kh += 1) { + let ih = u32(oh_base + i32(kh * params.d1)); + let w_row_base = w_base_ic + kh * params.sw1; + let in_row_base = in_base + ih * params.si1; + for (var kw: u32 = kw_begin; kw < kw_end; kw += 1) { + let iw = u32(ow_base + i32(kw * params.d0)); + let w_idx = w_row_base + kw * params.sw0; + let in_idx = in_row_base + iw * params.si0; + sum += load_weight(w_idx) * load_input(in_idx); + } + } + } + + let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3; + store_output(out_idx, sum); +} From c5bb7c0078d94cbf6f85caa5a7bc19cf310d846f Mon Sep 17 00:00:00 2001 From: Akarshan Biswas Date: Wed, 22 Apr 2026 18:02:56 +0530 Subject: [PATCH 474/831] sycl: Improve mul_mat_id memory efficiency and add BF16 fast path (llama/22119) * sycl: size mul_mat_id staging buffers by routed rows Previously src1_contiguous/dst_contiguous in ggml_sycl_mul_mat_id were sized to ggml_nelements(src1/dst), which over-allocates when ne12 > 1 and can fail with UR_RESULT_ERROR_OUT_OF_HOST_MEMORY on Level Zero for MoE models (notably with --cpu-moe). Size them by the actual number of routed rows (ids->ne[1] * n_ids) instead. * sycl: add bf16 mul_mat fast path via DNNL When src0 is BF16 (commonly the case for lm_head / output.weight), the existing f16 path is skipped because bf16 isn't covered, and the f32 fallback dequantizes the entire src0 slab to f32 in a single pool alloc (row_diff*ne00 floats). For large-vocab models this can reach several GB and fail with UR_RESULT_ERROR_OUT_OF_HOST_MEMORY on Level Zero. Add a bf16xbf16 -> f32 DNNL matmul fast path that uses the bf16 storage in place and only materializes a small src1 bf16 conversion buffer. bf16 matmul accumulates in f32, so it's correct even when the op requests GGML_PREC_F32 (as lm_head does). - gemm.hpp: map bfloat16 to dnnl::memory::data_type::bf16. - convert.{hpp,cpp}: expose ggml_get_to_bf16_sycl for f32/f16/bf16 -> bf16. - ggml-sycl.cpp: take the bf16 path early in ggml_sycl_op_mul_mat_sycl when DNNL and GGML_SYCL_HAS_BF16 are both available. --- ggml/src/ggml-sycl/common.hpp | 7 +++++++ ggml/src/ggml-sycl/convert.cpp | 23 ++++++++++++++++------- ggml/src/ggml-sycl/convert.hpp | 9 +++++++++ ggml/src/ggml-sycl/gemm.hpp | 3 +++ ggml/src/ggml-sycl/ggml-sycl.cpp | 30 ++++++++++++++++++++++++++++-- ggml/src/ggml-sycl/set_rows.cpp | 8 +++++++- 6 files changed, 70 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index fd84c917853..0101b27640a 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -28,6 +28,13 @@ namespace syclexp = sycl::ext::oneapi::experimental; +#if defined(__INTEL_LLVM_COMPILER) && __has_include() + #include + #ifndef GGML_SYCL_HAS_BF16 + #define GGML_SYCL_HAS_BF16 + #endif +#endif + #if GGML_SYCL_DNNL #include "dnnl.hpp" #include "dnnl_sycl.hpp" diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index f3c521b45f6..67b9c06f3e4 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -2,13 +2,6 @@ #include "dequantize.hpp" #include "presets.hpp" -#if defined(__INTEL_LLVM_COMPILER) - #if __has_include() - #include - #define GGML_SYCL_HAS_BF16 - #endif -#endif - template static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, const sycl::nd_item<3> &item_ct1) { @@ -767,6 +760,22 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { } +#ifdef GGML_SYCL_HAS_BF16 +to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * /*dst*/) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_sycl; + case GGML_TYPE_F16: + return convert_unary_sycl; + case GGML_TYPE_BF16: + return convert_unary_sycl; + default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); + return nullptr; + } +} +#endif + to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) { switch (type) { case GGML_TYPE_F32: diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp index 6e621f2154d..8de79d10ff6 100644 --- a/ggml/src/ggml-sycl/convert.hpp +++ b/ggml/src/ggml-sycl/convert.hpp @@ -23,6 +23,11 @@ typedef to_t_sycl_t to_fp16_sycl_t; to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst); to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst); +#ifdef GGML_SYCL_HAS_BF16 +typedef to_t_sycl_t to_bf16_sycl_t; +to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * dst); +#endif + // Nc = Non-contiguous template using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, @@ -35,15 +40,19 @@ template inline dst_t ggml_sycl_cast(src_t x) { if constexpr (std::is_same_v) { return x; +#ifdef GGML_SYCL_HAS_BF16 } else if constexpr (std::is_same_v) { return sycl::ext::oneapi::bfloat16(float(x)); } else if constexpr (std::is_same_v) { return static_cast(x); +#endif } else if constexpr (std::is_same_v && std::is_same_v) { return x.template convert(); +#ifdef GGML_SYCL_HAS_BF16 } else if constexpr (std::is_same_v && std::is_same_v>) { return {x.x, x.y}; +#endif } else if constexpr(std::is_same_v) { return int32_t(x); } else { diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp index dcf6c7aeeb4..c202da110be 100644 --- a/ggml/src/ggml-sycl/gemm.hpp +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -29,6 +29,9 @@ class DnnlGemmWrapper { static constexpr dt to_dt() { if constexpr (std::is_same_v) return dt::f32; else if constexpr (std::is_same_v) return dt::f16; +#ifdef GGML_SYCL_HAS_BF16 + else if constexpr (std::is_same_v) return dt::bf16; +#endif else static_assert(0); } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index c02a41ad862..3829da87903 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2176,6 +2176,31 @@ inline void ggml_sycl_op_mul_mat_sycl( #else bool use_fp16 = false; #endif + +#if GGML_SYCL_DNNL && defined(GGML_SYCL_HAS_BF16) + // Fast path for bf16 src0 + if (src0->type == GGML_TYPE_BF16 && !g_ggml_sycl_disable_dnn && ggml_is_contiguous(src0) && + row_diff == src0->ne[1]) { + using bf16_t = sycl::ext::oneapi::bfloat16; + ggml_sycl_pool_alloc src1_as_bf16(ctx.pool(), src1_ncols*ne10); + if (src1->type != GGML_TYPE_BF16) { + const to_bf16_sycl_t to_bf16_sycl = ggml_get_to_bf16_sycl(src1->type, dst); + GGML_ASSERT(to_bf16_sycl != nullptr); + to_bf16_sycl(src1_ddf_i, src1_as_bf16.get(), src1_ncols*ne10, stream); + } else { + stream->memcpy(src1_as_bf16.get(), src1_ddf_i, src1_ncols*ne10*sizeof(bf16_t)); + } + DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, + src0_dd_i, DnnlGemmWrapper::to_dt(), + src1_as_bf16.get(), DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_padded_row_size); + return; + } +#endif + if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { ggml_sycl_pool_alloc src0_as_f16(ctx.pool()); @@ -3848,8 +3873,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, } } } else { - ggml_sycl_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); - ggml_sycl_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + const int64_t n_routed_rows = ids->ne[1] * n_ids; + ggml_sycl_pool_alloc src1_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne10); + ggml_sycl_pool_alloc dst_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne0); src1_row.data = src1_contiguous.get(); dst_row.data = dst_contiguous.get(); diff --git a/ggml/src/ggml-sycl/set_rows.cpp b/ggml/src/ggml-sycl/set_rows.cpp index a641c100913..8fb41943525 100644 --- a/ggml/src/ggml-sycl/set_rows.cpp +++ b/ggml/src/ggml-sycl/set_rows.cpp @@ -4,7 +4,11 @@ namespace utils { template static constexpr bool is_arithmetic_v() { - return std::is_arithmetic_v || std::is_same_v || std::is_same_v; + return std::is_arithmetic_v || std::is_same_v +#ifdef GGML_SYCL_HAS_BF16 + || std::is_same_v +#endif + ; } } @@ -181,6 +185,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s stream ); break; +#ifdef GGML_SYCL_HAS_BF16 case GGML_TYPE_BF16: set_rows_sycl( src0_d, src1_d, (char *)dst->data, @@ -193,6 +198,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s stream ); break; +#endif case GGML_TYPE_Q8_0: set_rows_sycl_q(src0_d, src1_d, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream); break; From 0fbe4c4ca7a4c42c217a8f2ada04677c33a20b2c Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Thu, 23 Apr 2026 02:51:40 +0900 Subject: [PATCH 475/831] ggml-webgpu: Add fused RMS_NORM + MUL (llama/21983) * fused rms_norm_mul + mul * Add GGML_WEBGPU_DISABLE_FUSION for being able to disable kernel fusion. * Decouple num_fused_ops from webgpu_context; misc cleanup * Fix eps handling and remove disable_fusion. * Fix not to use c++20 initializers. --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 71 +++++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 157 ++++++++++++++++-- .../wgsl-shaders/rms_norm_mul.wgsl | 139 ++++++++++++++++ 3 files changed, 349 insertions(+), 18 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index f84dfee9d39..6593a9fe16b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -194,6 +194,26 @@ struct ggml_webgpu_row_norm_pipeline_key_hash { } }; +/** RMS_NORM + MUL **/ + +struct ggml_webgpu_rms_norm_mul_pipeline_key { + bool inplace; + bool src_overlap; + + bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const { + return inplace == other.inplace && src_overlap == other.src_overlap; + } +}; + +struct ggml_webgpu_rms_norm_mul_pipeline_key_hash { + size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.src_overlap); + return seed; + } +}; + /** Pad **/ struct ggml_webgpu_pad_pipeline_key { bool circular; @@ -517,7 +537,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ const size_t q_tile = context.sg_mat_m; const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; + size_t bytes_per_kv = 0; if (!key.kv_direct) { bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); } @@ -755,16 +775,17 @@ class ggml_webgpu_shader_lib { std::unordered_map cumsum_pipelines; // key is fixed, no variants yet std::unordered_map row_norm_pipelines; // op/inplace + std::unordered_map - get_rows_pipelines; // src_type, vectorized + get_rows_pipelines; // src_type, vectorized std::unordered_map - unary_pipelines; // type/op/inplace + unary_pipelines; // type/op/inplace std::unordered_map - scale_pipelines; // inplace + scale_pipelines; // inplace std::unordered_map - solve_tri_pipelines; // type + solve_tri_pipelines; // type std::unordered_map - ssm_conv_pipelines; // type/vectorized + ssm_conv_pipelines; // type/vectorized std::unordered_map @@ -813,6 +834,11 @@ class ggml_webgpu_shader_lib { std::unordered_map conv2d_pipelines; + std::unordered_map + rms_norm_mul_pipelines; + public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -1828,6 +1854,39 @@ class ggml_webgpu_shader_lib { return unary_pipelines[key]; } + webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_rms_norm_mul_pipeline_key key = {}; + key.inplace = context.inplace; + key.src_overlap = context.src_overlap; + + auto it = rms_norm_mul_pipelines.find(key); + if (it != rms_norm_mul_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string op_name = "RMS_NORM_MUL"; + std::string variant = op_name; + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } else if (key.src_overlap) { + defines.push_back("SRC_OVERLAP"); + variant += "_src_overlap"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + rms_norm_mul_pipelines[key] = pipeline; + return rms_norm_mul_pipelines[key]; + } + webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_binary_pipeline_key key = {}; key.type = context.dst->type; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 551586751c0..5d3169904c5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1972,6 +1972,94 @@ static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } +static std::optional ggml_webgpu_rms_norm_mul(webgpu_context & ctx, + ggml_tensor * rn_src, + ggml_tensor * rn_dst, + ggml_tensor * mul_src0, + ggml_tensor * mul_src1, + ggml_tensor * dst) { + ggml_tensor * mul_src; + + if (ggml_webgpu_tensor_equal(rn_dst, mul_src0)) { + mul_src = mul_src1; + } else if (ggml_webgpu_tensor_equal(rn_dst, mul_src1)) { + mul_src = mul_src0; + } else { + GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); + } + + bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || + (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); + bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src); + + uint32_t offset_merged_rn_src = 0; + uint32_t offset_merged_mul_src = 0; + size_t rn_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, rn_src); + size_t mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src); + + if (src_overlap) { + size_t min_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset); + offset_merged_rn_src = + (uint32_t) ((rn_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(rn_src->type)); + offset_merged_mul_src = + (uint32_t) ((mul_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(mul_src->type)); + } + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)), + offset_merged_rn_src, + offset_merged_mul_src, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)), + (uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)), + (uint32_t) (rn_src->nb[3] / ggml_type_size(rn_src->type)), + (uint32_t) (mul_src->nb[1] / ggml_type_size(mul_src->type)), + (uint32_t) (mul_src->nb[2] / ggml_type_size(mul_src->type)), + (uint32_t) (mul_src->nb[3] / ggml_type_size(mul_src->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) mul_src->ne[0], + (uint32_t) mul_src->ne[1], + (uint32_t) mul_src->ne[2], + (uint32_t) mul_src->ne[3], + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(rn_dst, 0)) // epsilon, treated as f32 in the shader + }; + + std::vector entries; + + if (inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); + } else if (src_overlap) { + size_t merged_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset); + size_t merged_end = + std::max(rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src), + mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src)); + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset, + merged_end - merged_offset)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); + } + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.inplace = inplace; + shader_lib_ctx.src_overlap = src_overlap; + + webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); +} + static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool inplace = ggml_webgpu_tensor_equal(src, dst); @@ -2468,15 +2556,48 @@ static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } +static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, int node_idx) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + return false; + } + + // additional constraints specific to this fusion + const ggml_tensor * rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor * mul = cgraph->nodes[node_idx + 1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + // rms_norm only supports f32 + if (mul->src[0]->type != GGML_TYPE_F32 || mul->src[1]->type != GGML_TYPE_F32 || mul->type != GGML_TYPE_F32) { + return false; + } + // if rms_norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) { + return false; + } + // rms_norm shader assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + + return true; +} + // Returns the encoded command, or std::nullopt if the operation is a no-op -static std::optional ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { +static std::optional ggml_webgpu_encode(webgpu_context ctx, + ggml_cgraph * cgraph, + int node_idx, + int & num_encoded_ops) { + ggml_tensor ** nodes = cgraph->nodes; + ggml_tensor * node = nodes[node_idx]; + if (ggml_is_empty(node)) { return std::nullopt; } if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { return std::nullopt; } - WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")"); + WEBGPU_LOG_DEBUG("ggml_webgpu_encode(" << node << ", " << ggml_op_name(node->op) << ")"); ggml_tensor * src0 = node->src[0]; ggml_tensor * src1 = node->src[1]; @@ -2519,6 +2640,13 @@ static std::optional ggml_webgpu_encode_node(webgpu_context c case GGML_OP_REPEAT: return ggml_webgpu_repeat(ctx, src0, node); case GGML_OP_RMS_NORM: + if (ggml_webgpu_can_fuse_rms_norm_mul(cgraph, node_idx)) { + num_encoded_ops = 2; + ggml_tensor * mul_node = nodes[node_idx + 1]; + return ggml_webgpu_rms_norm_mul(ctx, src0, node, mul_node->src[0], mul_node->src[1], mul_node); + } else { + return ggml_webgpu_row_norm(ctx, src0, node); + } case GGML_OP_L2_NORM: return ggml_webgpu_row_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -2629,6 +2757,8 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str uint32_t num_inflight_batches = 0; bool contains_set_rows = false; bool batch_compute_passes = true; + int num_encoded_ops = 1; + int node_idx = 0; #ifdef GGML_WEBGPU_GPU_PROFILE ctx->profile_timestamp_query_count = 0; @@ -2641,11 +2771,11 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); } - for (int i = 0; i < cgraph->n_nodes; i++) { - if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) { + while (node_idx < cgraph->n_nodes) { + if (cgraph->nodes[node_idx]->op == GGML_OP_SET_ROWS) { contains_set_rows = true; } - if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { + if (auto cmd = ggml_webgpu_encode(ctx, cgraph, node_idx, num_encoded_ops)) { commands.push_back(*cmd); num_batched_kernels += cmd.value().num_kernels; #ifdef GGML_WEBGPU_GPU_PROFILE @@ -2670,6 +2800,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ctx->param_arena.reset(); commands.clear(); } + + node_idx += num_encoded_ops; + num_encoded_ops = 1; } if (ctx->active_compute_pass) { @@ -3237,7 +3370,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; - webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); + webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_arena.init( webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN, @@ -3487,12 +3620,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - const bool kv_direct = src1->type == GGML_TYPE_F16 && - (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && - (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + const bool has_mask = op->src[3] != nullptr; + const bool kv_direct = src1->type == GGML_TYPE_F16 && + (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && + (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct); if (min_bytes > limit_bytes) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl new file mode 100644 index 00000000000..71f063b51aa --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl @@ -0,0 +1,139 @@ +#ifdef INPLACE + +@group(0) @binding(0) +var rn_src: array; + +@group(0) @binding(1) +var mul_src: array; + +@group(0) @binding(2) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + mul_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + +#elif SRC_OVERLAP + +@group(0) @binding(0) +var merged_src: array; + +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + dst[dst_offset] = scale * merged_src[rn_src_offset] * merged_src[mul_src_offset]; +} + +#else + +@group(0) @binding(0) +var rn_src: array; + +@group(0) @binding(1) +var mul_src: array; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + dst[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + +#endif + +struct Params { + offset_rn_src: u32, + offset_mul_src: u32, + offset_merged_rn_src: u32, + offset_merged_mul_src: u32, + offset_dst: u32, + + stride_rn_src1: u32, + stride_rn_src2: u32, + stride_rn_src3: u32, + + stride_mul_src1: u32, + stride_mul_src2: u32, + stride_mul_src3: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + mul_src_ne0: u32, + mul_src_ne1: u32, + mul_src_ne2: u32, + mul_src_ne3: u32, + + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, + + eps: f32 +}; + +var scratch: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + + // one thread per row + var i = wid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_rn_src_row = params.offset_rn_src + params.offset_merged_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1; + let i_mul_src_row = params.offset_mul_src + params.offset_merged_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; + + var sum = 0.0f; + var col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } +#ifdef SRC_OVERLAP + sum += pow(merged_src[i_rn_src_row + col], 2.0); +#else + sum += pow(rn_src[i_rn_src_row + col], 2.0); +#endif + col += WG_SIZE; + } + + scratch[lid.x] = sum; + + workgroupBarrier(); + + var offset: u32 = WG_SIZE / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + sum = scratch[0]; + + let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); + + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + update(i_rn_src_row + col, i_dst_row + col, scale, i_mul_src_row + col % params.mul_src_ne0); + col += WG_SIZE; + } +} From d2a26dc8e26edc72f0ba1b9d9f727d34625c9c7b Mon Sep 17 00:00:00 2001 From: Nikhil Jain Date: Wed, 22 Apr 2026 10:52:01 -0700 Subject: [PATCH 476/831] Implement async tensor api and event api (llama/22099) * Only run webgpu CI on my fork * Implement set_tensor_async * Implement synchronize api * Implement event creation and deletion API * Cleanup * Cleanup * Comment out jobs for local CI run * Add webgpu only workflow * Delete .github/workflows/build-webgpu.yml * Cleanup * Cleanup * Update API with function handlers * Run clang-format * Replace one-shot buffer with a direct queue.WriteBuffer using the buffer context --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 99 ++++++++++++++++++++++++++-- 1 file changed, 92 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5d3169904c5..44e3bf82216 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2832,22 +2832,107 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str return GGML_STATUS_SUCCESS; } +struct ggml_backend_webgpu_event_context { + webgpu_global_context global_ctx; + wgpu::Future future; + bool recorded = false; +}; + +static ggml_backend_event_t ggml_backend_webgpu_device_event_new(ggml_backend_dev_t device) { + ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) device->context; + + auto * event_ctx = new ggml_backend_webgpu_event_context(); + event_ctx->global_ctx = dev_ctx->webgpu_global_ctx; + + auto * event = new ggml_backend_event; + event->device = device; + event->context = event_ctx; + return event; +} + +static void ggml_backend_webgpu_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + delete static_cast(event->context); + delete event; +} + +static void ggml_backend_webgpu_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + ggml_backend_webgpu_event_context * event_ctx = (ggml_backend_webgpu_event_context *) event->context; + if (!event_ctx->recorded) { + return; + } + wgpu::WaitStatus status = + event_ctx->global_ctx->instance.WaitAny(event_ctx->future, WEBGPU_RUNTIME_WAIT_TIMEOUT_NS); + if (status == wgpu::WaitStatus::TimedOut) { + GGML_ABORT("ggml_webgpu: event_synchronize timed out after %u ms\n", WEBGPU_RUNTIME_WAIT_TIMEOUT_MS); + } + event_ctx->recorded = false; +} + +static void ggml_backend_webgpu_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; + ggml_backend_webgpu_event_context * event_ctx = (ggml_backend_webgpu_event_context *) event->context; + + event_ctx->future = backend_ctx->webgpu_ctx->global_ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, [](wgpu::QueueWorkDoneStatus, wgpu::StringView) {}); + event_ctx->recorded = true; +} + +static void ggml_backend_webgpu_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + GGML_UNUSED(backend); + ggml_backend_webgpu_device_event_synchronize(nullptr, event); +} + +static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + GGML_UNUSED(backend); + auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + + // Write aligned portion + buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); + + if (size % 4 != 0) { + // If size is not a multiple of 4, we need to memset the remaining bytes + size_t remaining_size = size % 4; + + // pack the remaining bytes into a uint32_t + uint32_t val32 = 0; + + for (size_t i = 0; i < remaining_size; i++) { + ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i]; + } + // memset the remaining bytes + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, + total_offset + (size - remaining_size), remaining_size); + } +} + +static void ggml_backend_webgpu_synchronize(ggml_backend_t backend) { + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; + ggml_backend_webgpu_wait_queue(backend_ctx->webgpu_ctx->global_ctx); +} + static ggml_backend_i ggml_backend_webgpu_i = { /* .get_name = */ ggml_backend_webgpu_name, /* .free = */ ggml_backend_webgpu_free, - /* .set_tensor_async = */ NULL, + /* .set_tensor_async = */ ggml_backend_webgpu_set_tensor_async, /* .get_tensor_async = */ NULL, /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, + /* .synchronize = */ ggml_backend_webgpu_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_webgpu_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, + /* .event_record = */ ggml_backend_webgpu_event_record, + /* .event_wait = */ ggml_backend_webgpu_event_wait, /* .graph_optimize = */ NULL, }; @@ -3810,9 +3895,9 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = { /* .supports_op = */ ggml_backend_webgpu_device_supports_op, /* .supports_buft = */ ggml_backend_webgpu_device_supports_buft, /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, + /* .event_new = */ ggml_backend_webgpu_device_event_new, + /* .event_free = */ ggml_backend_webgpu_device_event_free, + /* .event_synchronize = */ ggml_backend_webgpu_device_event_synchronize, }; /* End GGML Backend Device Interface */ From 393fdffe20e5bdd9c0803220c75d53eeca90664f Mon Sep 17 00:00:00 2001 From: uvos Date: Thu, 23 Apr 2026 02:34:31 +0200 Subject: [PATCH 477/831] HIP: flip GGML_HIP_GRAPHS to default on (llama/22254) In #11362 hip graph was disabled by default as, at the time, its performance impact was negative. Due to improvements in rocm and our usage and construction of graphs this is no longer true, so lets change the default. --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 2effd587b41..b9f7deb150d 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -213,7 +213,7 @@ set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size") option(GGML_HIP "ggml: use HIP" OFF) -option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) +option(GGML_HIP_GRAPHS "ggml: use HIP graph" ON) option(GGML_HIP_RCCL "ggml: use ROCm Collective Comm. Library" OFF) option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) From b6b547885cd431e457d58bad58eb5e9ba972919b Mon Sep 17 00:00:00 2001 From: Anav Prasad Date: Thu, 23 Apr 2026 02:28:56 +0000 Subject: [PATCH 478/831] CUDA: fuse relu + sqr (llama/22249) --- ggml/src/ggml-cuda/ggml-cuda.cu | 30 ++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/unary.cu | 23 +++++++++++++++++++++++ ggml/src/ggml-cuda/unary.cuh | 2 ++ 3 files changed, 55 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 185956317e0..1c2c3b4ac69 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3592,6 +3592,30 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return true; } + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_SQR + && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_RELU) { + const ggml_tensor * unary = cgraph->nodes[node_idx]; + const ggml_tensor * sqr = cgraph->nodes[node_idx+1]; + + if (ggml_get_unary_op(unary) != GGML_UNARY_OP_RELU) { + return false; + } + + if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) { + return false; + } + + if (unary->type != sqr->type) { + return false; + } + + if (!ggml_is_contiguous(unary->src[0])) { + return false; + } + + return true; + } + if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) { const ggml_tensor *scale = cgraph->nodes[node_idx]; @@ -4100,6 +4124,12 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) { + ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; + } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { i += 2; ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 4ad30fa1f35..2aeba26f414 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -65,6 +65,11 @@ static __device__ __forceinline__ float op_sqr(float x) { return x * x; } +static __device__ __forceinline__ float op_relu_sqr(float x) { + const float r = fmaxf(x, 0.0f); + return r * r; +} + static __device__ __forceinline__ float op_sqrt(float x) { return sqrtf(x); } @@ -615,3 +620,21 @@ void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary GGML_ABORT("Unsupported unary op for fused unary+mul"); } } + +/* fused relu + sqr */ + +void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node) { + const ggml_tensor * src = relu_node->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); + GGML_ASSERT(src->type == sqr_node->type); + + const int k = ggml_nelements(src); + if (src->type == GGML_TYPE_F16) { + unary_cuda((const half *)src->data, (half *)sqr_node->data, k, stream); + } else { + unary_cuda((const float *)src->data, (float *)sqr_node->data, k, stream); + } +} diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index f1dd2183a6c..81ed873ecc3 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -91,6 +91,8 @@ void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node); +void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node); + __device__ __forceinline__ float ggml_cuda_op_silu_single(float x) { return x / (1.0f + expf(-x)); } From df528c4f71cec95e0d024a512f4d36928e4e61e6 Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Wed, 22 Apr 2026 23:17:41 -0400 Subject: [PATCH 479/831] ggml-webgpu: add support for im2col (llama/22259) * shader(im2col): implement the im2col shader * shader(im2col): clean the formatting issues * shader(im2col): clean the editorconfig checker warning * fix(shader): address the workgroup issues of im2col and conv2d --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 59 ++++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 127 +++++++++++++++--- ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl | 101 ++++++++++++++ 3 files changed, 268 insertions(+), 19 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 6593a9fe16b..efc5b8c97a7 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -281,6 +281,25 @@ struct ggml_webgpu_conv2d_pipeline_key_hash { } }; +/** Im2Col **/ +struct ggml_webgpu_im2col_pipeline_key { + ggml_type input_type; + ggml_type output_type; + + bool operator==(const ggml_webgpu_im2col_pipeline_key & other) const { + return input_type == other.input_type && output_type == other.output_type; + } +}; + +struct ggml_webgpu_im2col_pipeline_key_hash { + size_t operator()(const ggml_webgpu_im2col_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + return seed; + } +}; + /** Gated Delta Net **/ struct ggml_webgpu_gated_delta_net_pipeline_key { int type; @@ -833,6 +852,8 @@ class ggml_webgpu_shader_lib { soft_max_pipelines; std::unordered_map conv2d_pipelines; + std::unordered_map + im2col_pipelines; std::unordered_maptype; + key.output_type = context.dst->type; + + auto it = im2col_pipelines.find(key); + if (it != im2col_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "im2col"; + + auto push_type_defines = [&](const char * prefix, ggml_type type) { + std::string s_prefix = prefix; + if (type == GGML_TYPE_F32) { + defines.push_back(s_prefix + "_F32"); + } else if (type == GGML_TYPE_F16) { + defines.push_back(s_prefix + "_F16"); + } else { + GGML_ABORT("Unsupported type for IM2COL shader"); + } + }; + + push_type_defines("INPUT", key.input_type); + push_type_defines("OUTPUT", key.output_type); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_im2col, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + im2col_pipelines[key] = pipeline; + return im2col_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 44e3bf82216..bcca2bd4627 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -979,25 +979,108 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx, ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; - uint32_t max_wg_size = - std::min((uint32_t) WEBGPU_MAX_WG_SIZE, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupSizeX); - uint32_t wg_size = - std::min((uint32_t) ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, max_wg_size); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src0; shader_lib_ctx.src1 = src1; shader_lib_ctx.dst = dst; - shader_lib_ctx.max_wg_size = wg_size; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); - uint32_t n_out = ggml_nelements(dst); - uint32_t total_wg = CEIL_DIV(n_out, decisions->wg_size); - uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - uint32_t wg_x = std::min(total_wg, max_wg); + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); + uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t p0 = ggml_get_op_params_i32(dst, 2); + const int32_t p1 = ggml_get_op_params_i32(dst, 3); + const int32_t d0 = ggml_get_op_params_i32(dst, 4); + const int32_t d1 = ggml_get_op_params_i32(dst, 5); + const bool is_2D = ggml_get_op_params_i32(dst, 6) == 1; + + const uint32_t KW = src0->ne[0]; + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t IC = is_2D ? src0->ne[2] : src0->ne[1]; + + const uint32_t IW = src1->ne[0]; + const uint32_t IH = is_2D ? src1->ne[1] : 1; + const uint32_t N = is_2D ? src1->ne[3] : src1->ne[2]; + + const uint32_t OW = dst->ne[1]; + const uint32_t OH = is_2D ? dst->ne[2] : 1; + + const uint32_t si0 = (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)); + const uint32_t si1 = is_2D ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0; + const uint32_t si2 = is_2D ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)); + const uint32_t si3 = is_2D ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)); + + const uint32_t so0 = (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)); + const uint32_t so1 = (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)); + const uint32_t so2 = is_2D ? (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)) : 0; + const uint32_t so3 = is_2D ? (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)) : + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + si0, + si1, + si2, + si3, + so0, + so1, + so2, + so3, + + KW, + KH, + IC, + + IW, + IH, + N, + + OW, + OH, + + (uint32_t) s0, + (uint32_t) s1, + (uint32_t) p0, + (uint32_t) p1, + (uint32_t) d0, + (uint32_t) d1, + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_im2col_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); uint32_t wg_y = CEIL_DIV(total_wg, wg_x); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); @@ -1988,8 +2071,8 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); } - bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || - (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); + bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || + (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src); uint32_t offset_merged_rn_src = 0; @@ -2689,6 +2772,8 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, return ggml_webgpu_sum_rows(ctx, src0, node); case GGML_OP_CONV_2D: return ggml_webgpu_conv_2d(ctx, src0, src1, node); + case GGML_OP_IM2COL: + return ggml_webgpu_im2col(ctx, src0, src1, node); default: return std::nullopt; } @@ -3455,7 +3540,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; - webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); + webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_arena.init( webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN, @@ -3705,12 +3790,12 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - const bool kv_direct = src1->type == GGML_TYPE_F16 && - (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && - (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + const bool has_mask = op->src[3] != nullptr; + const bool kv_direct = src1->type == GGML_TYPE_F16 && + (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && + (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct); if (min_bytes > limit_bytes) { @@ -3802,6 +3887,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) && (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); break; + case GGML_OP_IM2COL: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; case GGML_OP_SSM_CONV: supports_op = op->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl new file mode 100644 index 00000000000..386ebab879f --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl @@ -0,0 +1,101 @@ +#include "common_decls.tmpl" +enable f16; + +@group(0) @binding(0) +#if defined(INPUT_F32) +var input: array; +#elif defined(INPUT_F16) +var input: array; +#endif + +@group(0) @binding(1) +#if defined(OUTPUT_F32) +var output: array; +#elif defined(OUTPUT_F16) +var output: array; +#endif + +struct Params { + offset_i: u32, + offset_o: u32, + + // element strides + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + KW: u32, KH: u32, IC: u32, + IW: u32, IH: u32, N: u32, + OW: u32, OH: u32, + + // stride + s0: u32, s1: u32, + // padding + p0: u32, p1: u32, + // dilation + d0: u32, d1: u32, +} + +@group(0) @binding(2) +var params: Params; + +fn load_input(idx: u32) -> f32 { + #if defined(INPUT_F32) + return input[idx]; + #elif defined(INPUT_F16) + return f32(input[idx]); + #endif +} + +fn store_output(idx: u32, val: f32) { + #if defined(OUTPUT_F32) + output[idx] = val; + #elif defined(OUTPUT_F16) + output[idx] = f16(val); + #endif +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + + let threads_per_group = u32(WG_SIZE); + let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y; + let K = params.KW * params.KH * params.IC; + let M = params.OW * params.OH; + let total = K * M * params.N; + + if (i_out >= total) { + return; + } + + // decode (k, m, n) + var i = i_out; + let n = i / (K * M); + i = i % (K * M); + let m = i / K; + let k = i % K; + + // decode (oh, ow) + let oh = m / params.OW; + let ow = m % params.OW; + + // decode (kw, kh, ic) + let kw = k % params.KW; + let tmp = k / params.KW; + let kh = tmp % params.KH; + let ic = tmp / params.KH; + + let iw_i32 = i32(ow * params.s0 + kw * params.d0) - i32(params.p0); + let ih_i32 = i32(oh * params.s1 + kh * params.d1) - i32(params.p1); + + if (iw_i32 >= 0 && iw_i32 < i32(params.IW) && ih_i32 >= 0 && ih_i32 < i32(params.IH)) { + let iw = u32(iw_i32); + let ih = u32(ih_i32); + let in_idx = params.offset_i + iw * params.si0 + ih * params.si1 + ic * params.si2 + n * params.si3; + store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, load_input(in_idx)); + } else { + store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, 0.0); + } +} From b938c5026c42d5a1c3a11665fe714cc1768d0823 Mon Sep 17 00:00:00 2001 From: abotsis Date: Wed, 22 Apr 2026 23:18:56 -0600 Subject: [PATCH 480/831] sycl : fused MoE mul_mat_vec_q for TG (llama/21920) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * sycl : fused MoE mul_mat_vec_q for TG Create an MMVQ kernel so ggml_sycl_mul_mat_id can consolidate n_experts_used matmuls in a single kernel launch. The kernel also reads expert IDs directly, removing a per-call host sync. This is similar to the CUDA backend's ggml_cuda_mul_mat_vec_q* paths. All types supported in the current MMVQ are supported here as well: Q2_K, Q3_K, Q4_K, Q5_K, Q6_K, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0 It will fall back to the existing per-expert path when src0 has been rewritten by opt_for_reorder(), and for any shape the fused path doesn't handle. test-backend-ops passes for supported type/shape combos. Benchmark: Qwen3-Next-35B-A3B Q4_K_M on Intel Arc B70 (SYCL0), baseline 707c0b7a6, 16k context, -fa 0. build/bin/llama-bench -hf unsloth/Qwen3.5-35B-A3B-GGUF:Q4_K_M \ -p 1024 -n 128 -d 16384 -ngl 99 -fa 0 -ub 2048 -r 2 -dev SYCL0 Before (3 runs on 707c0b7a6): | test | run 1 | run 2 | run 3 | | --------------- | ----------------:| ----------------:| ----------------:| | pp1024 @ d16384 | 533.26 ± 4.87 | 535.20 ± 2.78 | 524.27 ± 3.10 | | tg128 @ d16384 | 33.47 ± 0.02 | 33.31 ± 0.02 | 33.17 ± 0.05 | After (3 runs on 707c0b7a6 + this patch): | test | run 1 | run 2 | run 3 | | --------------- | ----------------:| ----------------:| ----------------:| | pp1024 @ d16384 | 534.06 ± 0.97 | 531.95 ± 0.02 | 520.94 ± 20.10 | | tg128 @ d16384 | 45.85 ± 0.21 | 45.95 ± 0.45 | 46.22 ± 0.12 | disclosure: Claude wrote it, but I reviewed and understand the implementation (albeit my C is a little rusty). * sycl: also support nvfp4 and mxfp4 expert types * sycl: terser comments/nested dispatch in response to review * sycl: more comment cleanup in mmvq.cpp/hpp --------- Co-authored-by: Debian --- ggml/src/ggml-sycl/ggml-sycl.cpp | 51 +++++++++++ ggml/src/ggml-sycl/mmvq.cpp | 151 +++++++++++++++++++++++++++++++ ggml/src/ggml-sycl/mmvq.hpp | 16 ++++ 3 files changed, 218 insertions(+) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 3829da87903..36923160d72 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3808,6 +3808,51 @@ __dpct_inline__ static void k_copy_dst_from_contiguous( } } +// Fused MoE TG fast path. Returns false to fall back to the per-expert loop below. +static bool ggml_sycl_mul_mat_id_mmvq_fused( + ggml_backend_sycl_context & ctx, const ggml_tensor * src0, + const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) +{ + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + if (ne12 != 1) return false; + if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) return false; + if (ne10 != src0->ne[0] || ne10 % QK8_1 != 0) return false; + if (!ggml_is_contiguous(src1)) return false; + + // Reorder layout not supported; fall back. + const ggml_tensor_extra_gpu * src0_extra = + static_cast(src0->extra); + if (src0_extra && src0_extra->optimized_feature.reorder) return false; + + const int64_t n_ids_per_group = ids->ne[0]; + if (ids->ne[1] != 1) return false; + if (ne11 != 1 && ne11 != n_ids_per_group) return false; + + const queue_ptr stream = ctx.stream(); + const int src1_padded_cols = GGML_PAD((int) ne10, MATRIX_ROW_PADDING); + const int n_experts_used = (int) n_ids_per_group; + const int nrows = (int) src0->ne[1]; + + ggml_sycl_pool_alloc src1_q8_alloc(ctx.pool(), + (size_t) ne11 * src1_padded_cols * sizeof(block_q8_1) / QK8_1); + char * src1_ddq = src1_q8_alloc.get(); + quantize_row_q8_1_sycl( + (const float *) src1->data, src1_ddq, (int) ne10, (int) ne11, + src1_padded_cols, stream); + + const size_t bytes_per_qrow = (size_t) src1_padded_cols * sizeof(block_q8_1) / QK8_1; + const size_t src1_row_stride = (ne11 == 1) ? 0 : bytes_per_qrow; + + return ggml_sycl_mul_mat_vec_q_id( + src0->type, src0->data, src1_ddq, (const int32_t *) ids->data, + (float *) dst->data, (int) ne10, nrows, n_experts_used, + /*expert_weight_stride=*/ src0->nb[2], + /*dst_row_stride=*/ dst->nb[1], + src1_row_stride, stream); +} + static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, ggml_tensor *dst) try { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); @@ -3823,6 +3868,12 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const int64_t n_as = ne02; const int64_t n_ids = ids->ne[0]; + if (ne12 == 1) { + if (ggml_sycl_mul_mat_id_mmvq_fused(ctx, src0, src1, ids, dst)) { + return; + } + } + std::vector ids_host(ggml_nbytes(ids)); const char * ids_dev = (const char *) ids->data; diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 3a4577ecbbc..8fa2198f35a 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -1199,3 +1199,154 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens GGML_UNUSED(src1_ddf_i); GGML_UNUSED(ctx); } + +// src1_row_stride: 0 for shared src1 (gate/up proj), else per-expert stride (down proj). +template +static void mul_mat_vec_q_moe( + const void * __restrict__ vx_base, const void * __restrict__ vy_base, + float * __restrict__ dst_base, const int32_t * __restrict__ ids_dev, + const int ncols, const int nrows, + const size_t expert_weight_stride, const size_t dst_row_stride, + const size_t src1_row_stride, + const sycl::nd_item<3> & item_ct1) { + + const int expert_idx = item_ct1.get_group(1); + const int i02 = ids_dev[expert_idx]; + + const char * vx = (const char *) vx_base + (size_t) i02 * expert_weight_stride; + const char * vy = (const char *) vy_base + (size_t) expert_idx * src1_row_stride; + float * dst = (float *) ((char *) dst_base + (size_t) expert_idx * dst_row_stride); + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi; + + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) { + const int ibx = row * blocks_per_row + i; + const int iby = i * (qk / QK8_1); + + for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) { + const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr)); + tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs); + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template +static void launch_mul_mat_vec_q_moe( + const void * vx_base, const void * vy, const int32_t * ids_dev, + float * dst_base, const int ncols, const int nrows, const int n_experts_used, + const size_t expert_weight_stride, const size_t dst_row_stride, + const size_t src1_row_stride, + dpct::queue_ptr stream) { + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, (unsigned) n_experts_used, (unsigned) block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_moe( + vx_base, vy, dst_base, ids_dev, ncols, nrows, + expert_weight_stride, dst_row_stride, src1_row_stride, item); + }); + }); +} + +bool ggml_sycl_mul_mat_vec_q_id( + enum ggml_type src0_type, + const void * vx_base, + const void * vy, + const int32_t * ids_dev, + float * dst_base, + int ncols, + int nrows, + int n_experts_used, + size_t expert_weight_stride, + size_t dst_row_stride, + size_t src1_row_stride, + dpct::queue_ptr stream) { + switch (src0_type) { + case GGML_TYPE_Q4_0: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q4_1: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_0: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_1: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q8_0: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q2_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q3_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q4_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q6_K: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_MXFP4: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_NVFP4: + launch_mul_mat_vec_q_moe( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + default: + return false; + } +} diff --git a/ggml/src/ggml-sycl/mmvq.hpp b/ggml/src/ggml-sycl/mmvq.hpp index 049b43d4535..d674dc1d61e 100644 --- a/ggml/src/ggml-sycl/mmvq.hpp +++ b/ggml/src/ggml-sycl/mmvq.hpp @@ -24,4 +24,20 @@ void ggml_sycl_op_mul_mat_vec_q( const int64_t src1_ncols, const int64_t src1_padded_row_size, const dpct::queue_ptr &stream); +// Requires standard (non-reorder) block layout for src0. +// Returns false if src0_type isn't handled; caller should fall back. +bool ggml_sycl_mul_mat_vec_q_id( + enum ggml_type src0_type, + const void * vx_base, // start of stacked expert weights + const void * vy, // pre-quantized src1 (Q8_1) + const int32_t * ids_dev, // device-side int32, length n_experts_used + float * dst_base, + int ncols, + int nrows, + int n_experts_used, + size_t expert_weight_stride, // bytes between experts in vx_base + size_t dst_row_stride, // bytes between dst rows + size_t src1_row_stride, // 0 = shared src1, else per-expert stride in bytes + dpct::queue_ptr stream); + #endif // GGML_SYCL_MMVQ_HPP From 1aba06173778618cc3d56ce6201eb703935c5e99 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 23 Apr 2026 08:22:08 +0300 Subject: [PATCH 481/831] ggml-base: use MATH_LIBRARY variable instead of hardcoded 'm' (llama/22239) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #22237 — the find_library(MATH_LIBRARY m) result was being discarded and the target linked against the literal 'm' string. This prevents users from overriding the math library (e.g. for AMD AOCL) via CMake variables. Now the discovered MATH_LIBRARY is used directly. --- ggml/src/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 48fbe208d90..52754e1b9d6 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -473,7 +473,7 @@ target_link_libraries(ggml-base PRIVATE Threads::Threads) find_library(MATH_LIBRARY m) if (MATH_LIBRARY) if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT}) - target_link_libraries(ggml-base PRIVATE m) + target_link_libraries(ggml-base PRIVATE ${MATH_LIBRARY}) endif() endif() From 682ee993057efd295a81140bae4067d6346e1f92 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 23 Apr 2026 08:22:49 +0300 Subject: [PATCH 482/831] metal : fix event synchronization (llama/22260) --- ggml/src/ggml-metal/ggml-metal-device.m | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 27cb1683518..f17f7e2e0ce 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -931,13 +931,13 @@ void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) { } struct ggml_metal_event { - void * obj; // id + void * obj; // id atomic_int value; }; void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { - id event = (id)ev->obj; + id event = (id)ev->obj; id cmd_buf = (id) cmd_buf_raw; @@ -945,7 +945,7 @@ void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t } void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { - id event = (id)ev->obj; + id event = (id)ev->obj; id cmd_buf = (id) cmd_buf_raw; @@ -953,7 +953,7 @@ void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cm } ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) { - id event = [dev->mtl_device newEvent]; + id event = [dev->mtl_device newSharedEvent]; ggml_metal_event_t ev = calloc(1, sizeof(struct ggml_metal_event)); @@ -964,7 +964,7 @@ ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) { } void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev) { - id event = ev->obj; + id event = ev->obj; [event release]; free(ev); @@ -973,14 +973,13 @@ void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev } void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev) { - @autoreleasepool { - id event = ev->obj; - - id cmd_buf = [dev->mtl_queue commandBuffer]; - [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)]; - [cmd_buf commit]; - [cmd_buf waitUntilCompleted]; + id event = ev->obj; + const bool res = [event waitUntilSignaledValue:atomic_load_explicit(&ev->value, memory_order_relaxed) timeoutMS:60000]; + if (!res) { + GGML_ABORT("%s: failed to wait for event\n", __func__); } + + GGML_UNUSED(dev); } void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) { From 71b1ab37841177903e7e97420489ede09bd231b1 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Thu, 23 Apr 2026 14:17:21 -0700 Subject: [PATCH 483/831] hexagon: add support for basic and extended Op profiling (llama/22269) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * hexagon: restore HTP_OPMASK_QUEUE * hexagon: honor OPMASK_SKIP_COMPUTE in hmx-matmul * hex-prof: restore op profiling * hex-prof: enable PMU * hexagon: simplify and improve op-queuing with full profiling support Add separate profile descriptors. * hexagon: remove opsync and rename opmask into opstage opsync is no longer needed since the profiler is fully async now. opmask name was confusing and opstage is more accurate. * hexagon: refactor opbatch queue handling * hexagon: add iface hooks for enabling profiler from the host Also move all the PMU setup stuff out of the hex-utils since it's not inteded for normal use. * hexagon: make profiler mode configurable On older devices getting PMU counters is expensive so it's now optional. * hexagon: add support for setting profiler pmu events from env * hexagon: simplify profiler output (no need to print buffs, etc) * hexagon: simplify pmu counter formating * hexagon: add a simple profile post-proc tool * hex-prof: add support for reading logs from stdin * hexagon: document GGML_HEXAGON_PROFILE * hex-prof: update default width for dims field * hex-prof: fix linter warnings and errors * Update ggml/src/ggml-hexagon/htp/htp-ops.h Co-authored-by: Sigbjørn Skjæret * Update scripts/snapdragon/ggml-hexagon-profile.py Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: Trivikram Reddy Co-authored-by: Sigbjørn Skjæret --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 402 +++++++++++++++--------- ggml/src/ggml-hexagon/htp/hex-utils.h | 28 ++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 5 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 36 ++- ggml/src/ggml-hexagon/htp/htp_iface.idl | 8 +- ggml/src/ggml-hexagon/htp/main.c | 172 +++++++--- ggml/src/ggml-hexagon/htp/matmul-ops.c | 4 + 7 files changed, 442 insertions(+), 213 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index cdd9fcf5928..955903418b6 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -12,9 +12,12 @@ #include #include #include +#include +#include #include #include #include +#include #ifdef _WIN32 # include @@ -41,18 +44,26 @@ #include "htp_iface.h" #include "htp-drv.h" +using intvec = std::vector; +using uintvec = std::vector; +using u32vec = std::vector; + static size_t opt_ndev = 1; static size_t opt_nhvx = 0; // use all static int opt_arch = 0; // autodetect static int opt_etm = 0; static int opt_verbose = 0; -static int opt_profile = 0; +static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu) static int opt_hostbuf = 1; // hostbuf ON by default static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only +// Default PMU events, if profiling with PMU (mode=2) is enabled +// See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html +// https://docs.qualcomm.com/doc/80-N2040-61/topic/hvx-pmu-events.html +static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C }; + // Enable all stages by default -static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_COMPUTE; -static int opt_opsync = 0; // synchronous ops +static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; static int opt_opbatch = 1024; // max number of ops in a batch static int opt_opqueue = 16; // max number of pending batches static std::regex* opt_opfilter = NULL; // regex of ops to not claim @@ -104,19 +115,26 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct } static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op, - uint32_t op_usec, uint32_t op_cycles, uint32_t op_pkts, uint64_t call_usec) { + uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) { if (!opt_profile) return; op_desc desc(op); - GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : %s : op-usec %u op-cycles %u op-pkts %u (%f) call-usec %llu\n", sess_name.c_str(), - ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, - op_usec, op_cycles, op_pkts, (float) op_cycles / op_pkts, (unsigned long long) call_usec); + + char pmu_str[256] = ""; + if (opt_profile > 1) { + static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters"); + sprintf(pmu_str, " pmu [%u,%u,%u,%u,%u,%u,%u,%u]", + pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]); + } + + GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(), + ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, op_usec, op_cycles, pmu_str); } // ** backend sessions struct ggml_hexagon_opbatch; -struct ggml_hexagon_opshm; +struct ggml_hexagon_opqueue; struct ggml_hexagon_session { std::string name; @@ -132,8 +150,8 @@ struct ggml_hexagon_session { bool valid_iface; std::atomic op_pending; - ggml_hexagon_opbatch *op_batch; - ggml_hexagon_opshm *op_shm; + ggml_hexagon_opbatch* op_batch; + ggml_hexagon_opqueue* op_queue; ggml_backend_buffer_type buffer_type = {}; ggml_backend_buffer_type repack_buffer_type = {}; @@ -1521,65 +1539,14 @@ static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interf // Backend session implementation -struct ggml_hexagon_opshm { - ggml_hexagon_shared_buffer *sbuf; - - std::vector block_mask; - size_t block_size; - - uint8_t * base() const { return this->sbuf->base; } - int fd() const { return this->sbuf->fd; } - size_t n_blocks() const { return this->block_mask.size(); } - - ggml_hexagon_opshm(ggml_hexagon_session *sess, size_t max_batch, size_t max_pending) { - size_t n_bufs = HTP_OP_MAX_BUFS; - size_t n_ops = max_batch; - size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS; - - block_mask.resize(max_pending, true); - - block_size = sizeof(htp_buf_desc) * n_bufs + - sizeof(htp_tensor) * n_tensors + - sizeof(htp_op_desc) * n_ops; - - sbuf = new ggml_hexagon_shared_buffer(sess, block_size * block_mask.size(), true /* pinned */); - - if (opt_verbose) { - GGML_LOG_INFO("ggml-hex: %s allocated shared buf %zu : block-size %zu max-batch %zu max-pending %zu\n", - sess->c_name(), (size_t) sbuf->size, block_size, max_batch, max_pending); - } - } - - ~ggml_hexagon_opshm() { - delete sbuf; - } - - uint8_t * allocate() { - auto it = std::find(block_mask.begin(), block_mask.end(), true); - if (it == block_mask.end()) - return nullptr; - - unsigned int i = std::distance(block_mask.begin(), it); - uint8_t* addr = sbuf->base + (i * block_size); - block_mask[i] = false; - - HEX_VERBOSE("ggml-hex: %s allocated op shm #%u %p\n", sbuf->sess->c_name(), i, (void*) addr); - return addr; - } - - void release(uint8_t * addr) { - int i = (addr - sbuf->base) / block_size; - block_mask[i] = true; - HEX_VERBOSE("ggml-hex: %s released op shm #%u %p\n", sbuf->sess->c_name(), i, (void*) addr); - } -}; - struct ggml_hexagon_opbatch { - const char* name; + ggml_hexagon_session* sess; - std::vector buffers; - std::vector tensors; - std::vector ops; + std::vector ops; // pointers to original ops + + std::vector h_bufs; // htp buffer descriptors + std::vector h_tens; // htp tensor descriptors + std::vector h_ops; // htp op descriptors std::unordered_map b_map; // buffer fd to index std::unordered_map t_map; // tensor ptr to index @@ -1606,19 +1573,21 @@ struct ggml_hexagon_opbatch { d_map.clear(); } - ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t max_batch) { - name = sess->c_name(); + ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size) { + this->sess = sess; n_bufs_max = HTP_OP_MAX_BUFS; - n_ops_max = max_batch; + n_ops_max = batch_size; n_tens_max = n_ops_max + n_ops_max * HTP_OP_MAX_INPUTS; b_vmem_max = HTP_OP_MAX_VMEM; - buffers.resize(n_bufs_max); - tensors.resize(n_tens_max); ops.resize(n_ops_max); + h_bufs.resize(n_bufs_max); + h_tens.resize(n_tens_max); + h_ops.resize(n_ops_max); + b_map.reserve(n_bufs_max); t_map.reserve(n_tens_max); d_map.reserve(n_tens_max); @@ -1640,7 +1609,7 @@ struct ggml_hexagon_opbatch { b_map.insert({sbuf->fd, bi}); - htp_buf_desc &b = buffers[bi]; + htp_buf_desc &b = h_bufs[bi]; b.base = (uint64_t) sbuf->base; b.fd = sbuf->fd; b.size = sbuf->size; @@ -1664,7 +1633,7 @@ struct ggml_hexagon_opbatch { // First lookup by tensor data auto range = d_map.equal_range(t->data); for (auto it = range.first; it != range.second; ++it) { - htp_tensor * h = &tensors[it->second]; + htp_tensor * h = &h_tens[it->second]; if (same_shape(h, t)) { return it->second; } } @@ -1682,7 +1651,7 @@ struct ggml_hexagon_opbatch { uint64_t t_offset = (uint8_t *) t->data - sbuf->base; size_t t_size = ggml_nbytes(t); - htp_tensor &h = tensors[ti]; + htp_tensor &h = h_tens[ti]; h.bi = add_buffer(sbuf); h.data = t_offset; h.size = t_size; @@ -1737,65 +1706,170 @@ struct ggml_hexagon_opbatch { // assumes that fit_op() was called first and returned true void add_op(htp_op_code opcode, const struct ggml_tensor * t) { // Add new op - htp_op_desc &o = ops[n_ops++]; + + unsigned int n = n_ops++; GGML_ASSERT(n_ops <= n_ops_max); + ops[n] = t; + + htp_op_desc &o = h_ops[n]; memcpy(&o.params, &t->op_params, sizeof(t->op_params)); o.opcode = opcode; o.flags = 0; - if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { + if (!(opt_opstage & HTP_OPSTAGE_COMPUTE)) { o.flags |= HTP_OPFLAGS_SKIP_COMPUTE; } - ggml_hexagon_dump_op_exec(name, t, o.flags); + ggml_hexagon_dump_op_exec(sess->c_name(), t, o.flags); for (unsigned int i=0; i < HTP_OP_MAX_INPUTS; i++) { o.src[i] = t->src[i] ? add_tensor(t->src[i]) : 0xffff; } o.dst = add_tensor(t); } +}; + +struct ggml_hexagon_opqueue { + // Shared buffer for storing batches + ggml_hexagon_shared_buffer *shm_buf; + size_t shm_blk_size; + + using opvec = std::vector; + + std::queue done; // completed batch ids + std::vector op_cache; // per batch op cache + std::vector start_usec; // per batch start time + + ggml_hexagon_opqueue(ggml_hexagon_session *sess, size_t batch_size, size_t depth) { + size_t n_bufs = HTP_OP_MAX_BUFS; + size_t n_ops = batch_size; + size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS; + + shm_blk_size = sizeof(htp_buf_desc) * n_bufs + + sizeof(htp_tensor) * n_tensors + + sizeof(htp_op_desc) * n_ops + + sizeof(htp_prof_desc) * n_ops; + + shm_buf = new ggml_hexagon_shared_buffer(sess, shm_blk_size * depth, true /* pinned */); + + op_cache.resize(depth); + start_usec.resize(depth, 0); + + // init done queue + for (unsigned int i = 0; i < depth; i++) { done.push(i); } + + if (opt_verbose) { + GGML_LOG_INFO("ggml-hex: %s allocated op-queue : batch-size %zu depth %zu shm-size %zu shm-block-size %zu\n", + sess->c_name(), batch_size, depth, shm_buf->size, shm_blk_size); + } + } - size_t flush(uint8_t * mem_addr, size_t mem_size) { - static_assert(sizeof(htp_buf_desc) % 8 == 0, "sizeof(htp_buf_desc) must be multiple of 8"); - static_assert(sizeof(htp_tensor) % 8 == 0, "sizeof(htp_tensor) must be multiple of 8"); - static_assert(sizeof(htp_op_desc) % 8 == 0, "sizeof(htp_op_desc) must be multiple of 8"); + ~ggml_hexagon_opqueue() { + delete shm_buf; + } - const size_t b_size = sizeof(htp_buf_desc) * n_bufs; - const size_t t_size = sizeof(htp_tensor) * n_tens; - const size_t o_size = sizeof(htp_op_desc) * n_ops; + // push new batch + bool push(htp_opbatch_req& req, dspqueue_buffer& dbuf, ggml_hexagon_opbatch* op_batch) { + static_assert(sizeof(htp_opbatch_req) % 8 == 0, "sizeof(htp_opbatch_req) must be multiple of 8"); + static_assert(sizeof(htp_opbatch_rsp) % 8 == 0, "sizeof(htp_opbatch_rsp) must be multiple of 8"); + static_assert(sizeof(htp_buf_desc) % 8 == 0, "sizeof(htp_buf_desc) must be multiple of 8"); + static_assert(sizeof(htp_tensor) % 8 == 0, "sizeof(htp_tensor) must be multiple of 8"); + static_assert(sizeof(htp_op_desc) % 8 == 0, "sizeof(htp_op_desc) must be multiple of 8"); + static_assert(sizeof(htp_prof_desc) % 8 == 0, "sizeof(htp_prof_desc) must be multiple of 8"); - const size_t m_size = b_size + t_size + o_size; - GGML_ASSERT(m_size <= mem_size); + if (done.empty()) { return false; } - uint8_t * b_ptr = (uint8_t *) mem_addr; - uint8_t * t_ptr = (uint8_t *) b_ptr + b_size; - uint8_t * o_ptr = (uint8_t *) t_ptr + t_size; + req.id = done.front(); done.pop(); // batch id + req.n_bufs = op_batch->n_bufs; + req.n_tensors = op_batch->n_tens; + req.n_ops = op_batch->n_ops; - memcpy(b_ptr, (void *) buffers.data(), b_size); - memcpy(t_ptr, (void *) tensors.data(), t_size); - memcpy(o_ptr, (void *) ops.data(), o_size); + op_cache[req.id] = op_batch->ops; + start_usec[req.id] = ggml_time_us(); - HEX_VERBOSE("ggml-hex: %s flush-opbatch : n-bufs %u n-tensors %u n-ops %u vmem %zu : b-size %zu t-size %zu o-size %zu\n", - name, n_bufs, n_tens, n_ops, b_vmem, b_size, t_size, o_size); + const size_t b_size = sizeof(htp_buf_desc) * req.n_bufs; + const size_t t_size = sizeof(htp_tensor) * req.n_tensors; + const size_t o_size = sizeof(htp_op_desc) * req.n_ops; + const size_t p_size = sizeof(htp_prof_desc) * req.n_ops; + + dbuf.ptr = shm_buf->base + (req.id * shm_blk_size); + dbuf.fd = shm_buf->fd; + dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; + dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) shm_buf->base; + dbuf.size = b_size + t_size + o_size + p_size; + + GGML_ASSERT(dbuf.size <= shm_blk_size); + + uint8_t * m_ptr = (uint8_t*) dbuf.ptr; + uint8_t * b_ptr = m_ptr; m_ptr += b_size; + uint8_t * t_ptr = m_ptr; m_ptr += t_size; + uint8_t * o_ptr = m_ptr; + + memcpy(b_ptr, (void *) op_batch->h_bufs.data(), b_size); + memcpy(t_ptr, (void *) op_batch->h_tens.data(), t_size); + memcpy(o_ptr, (void *) op_batch->h_ops.data(), o_size); + + HEX_VERBOSE("ggml-hex: %s op-queue push batch #%u : n-bufs %u n-tensors %u n-ops %u vmem %zu : b-size %zu t-size %zu o-size %zu m-size %zu\n", + shm_buf->sess->c_name(), req.id, req.n_bufs, req.n_tensors, req.n_ops, op_batch->b_vmem, + b_size, t_size, o_size, (size_t) dbuf.size); + + op_batch->reset(); if (opt_verbose > 1) { htp_buf_desc *b = (htp_buf_desc*) b_ptr; - for (unsigned int i=0; i < n_bufs; i++) { - GGML_LOG_DEBUG("ggml-hex: %s htp-buf #%u : fd %d base %p size %zu\n", name, i, + for (unsigned int i=0; i < req.n_bufs; i++) { + GGML_LOG_DEBUG("ggml-hex: %s htp-buf #%u : fd %d base %p size %zu\n", shm_buf->sess->c_name(), i, b[i].fd, (void *) b[i].base, (size_t) b[i].size); } htp_tensor *t = (htp_tensor*) t_ptr; - for (unsigned int i=0; i < n_tens; i++) { + for (unsigned int i=0; i < req.n_tensors; i++) { GGML_LOG_DEBUG("ggml-hex: %s htp-tensor #%u : bi %u offset %u size %u : %zu:%zu:%zu:%zu\n", - name, i, t[i].bi, t[i].data, t[i].size, + shm_buf->sess->c_name(), i, t[i].bi, t[i].data, t[i].size, (size_t) t[i].ne[0], (size_t) t[i].ne[1], (size_t) t[i].ne[2], (size_t) t[i].ne[3]); } } - reset(); + return true; + } + + void pop(htp_opbatch_rsp rsp, dspqueue_buffer dbuf) { + GGML_ASSERT(rsp.id < op_cache.size()); + + done.push(rsp.id); + + const size_t b_size = sizeof(htp_buf_desc) * rsp.n_bufs; + const size_t t_size = sizeof(htp_tensor) * rsp.n_tensors; + const size_t o_size = sizeof(htp_op_desc) * rsp.n_ops; + const size_t p_size = sizeof(htp_prof_desc) * rsp.n_ops; - return m_size; + const size_t m_size = b_size + t_size + o_size + p_size; + GGML_ASSERT(m_size <= shm_blk_size); + + HEX_VERBOSE("ggml-hex: %s op-queue pop batch #%u : n-bufs %u n-tensors %u n-ops %u : m-size %zu b-size %zu t-size %zu o-size %zu\n", + shm_buf->sess->c_name(), rsp.id, rsp.n_bufs, rsp.n_tensors, rsp.n_ops, + (size_t) dbuf.size, b_size, t_size, o_size); + + uint8_t * m_ptr = (uint8_t*) dbuf.ptr; + uint8_t * p_ptr = m_ptr + (b_size + t_size + o_size); + + if (opt_profile && rsp.n_ops > 0) { + auto & ops = op_cache[rsp.id]; + + uint64_t batch_usec = ggml_time_us() - start_usec[rsp.id]; + uint32_t htp_usec = 0; + + GGML_ASSERT(rsp.n_ops <= ops.size()); + + const htp_prof_desc * pd = (const htp_prof_desc *) p_ptr; + for (uint32_t i = 0; i < rsp.n_ops; i++) { + htp_usec += pd[i].usecs; + ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i].usecs, pd[i].cycles, pd[i].pmu); + } + + GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u\n", + shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec); + } } }; @@ -1824,17 +1898,12 @@ void ggml_hexagon_session::flush_pending(bool all) { GGML_ABORT("ggml-hex: %s dspcall : bad response : size %u dspbufs %u\n", this->c_name(), rsp_size, n_dbufs); } - op_shm->release((uint8_t*) dbuf.ptr); - if (rsp.status != HTP_STATUS_OK) { GGML_LOG_ERROR("ggml-hex: %s dspcall : dsp-rsp: %s\n", this->c_name(), status_to_str(rsp.status)); // TODO: handle errors } - // FIXME: profile will be per opreq - // this->prof_usecs = rsp.prof_usecs; - // this->prof_cycles = rsp.prof_cycles; - // this->prof_pkts = rsp.prof_pkts; + op_queue->pop(rsp, dbuf); this->op_pending--; // atomic dec @@ -1845,28 +1914,17 @@ void ggml_hexagon_session::flush_pending(bool all) { void ggml_hexagon_session::flush_batch() { if (op_batch->empty()) { return; } - htp_opbatch_req req; - req.n_bufs = op_batch->n_bufs; - req.n_tensors = op_batch->n_tens; - req.n_ops = op_batch->n_ops; + htp_opbatch_req req {}; + dspqueue_buffer dbuf{}; - dspqueue_buffer dbuf; - dbuf.fd = op_shm->fd(); - dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; - dbuf.ptr = op_shm->allocate(); - if (!dbuf.ptr) { + if (!op_queue->push(req, dbuf, op_batch)) { flush_pending(false); - dbuf.ptr = op_shm->allocate(); + op_queue->push(req, dbuf, op_batch); } - dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) op_shm->base(); - dbuf.size = op_batch->flush((uint8_t*) dbuf.ptr, op_shm->block_size); - // Bump pending flag (cleared in the session::flush once we get the response) this->op_pending++; // atomic inc - HEX_VERBOSE("ggml-hex: %s: queue-opbatch : %p size %u\n", this->c_name(), dbuf.ptr, dbuf.size); - int err = dspqueue_write(this->queue, 0, 1, &dbuf, sizeof(req), (const uint8_t*) &req, DSPQUEUE_TIMEOUT); if (err != 0) { GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->c_name(), (unsigned) err); @@ -2016,25 +2074,33 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } if (opt_etm) { - err = htp_iface_enable_etm(this->handle); + err = htp_iface_etm(this->handle, 1); if (err != 0) { GGML_LOG_ERROR("ggml-hex: failed to enable ETM tracing: 0x%08x\n", (unsigned) err); } } - // Start the DSP-side service. We need to pass the queue ID to the - // DSP in a FastRPC call; the DSP side will import the queue and start - // listening for packets in a callback. + if (opt_profile) { + htp_iface_pmu_conf pmu_conf{}; + std::copy(opt_pmu_evt.begin(), opt_pmu_evt.end(), pmu_conf.events); + + err = htp_iface_profiler(this->handle, opt_profile, &pmu_conf); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: failed to enable profiling: 0x%08x\n", (unsigned) err); + } + } + + // Allocate buffers and state for op batching + this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch); + this->op_queue = new ggml_hexagon_opqueue(this, opt_opbatch, opt_opqueue); + + // Start processing op batch requests err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx); if (err != 0) { GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err); throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); } this->valid_iface = true; - - // Allocate buffers and state for op batching - this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch); - this->op_shm = new ggml_hexagon_opshm(this, opt_opbatch, opt_opqueue); } void ggml_hexagon_session::release() noexcept(true) { @@ -2043,7 +2109,7 @@ void ggml_hexagon_session::release() noexcept(true) { int err; delete this->op_batch; - delete this->op_shm; + delete this->op_queue; // Stop the DSP-side service and close the queue if (this->valid_iface) { @@ -2054,12 +2120,20 @@ void ggml_hexagon_session::release() noexcept(true) { } if (opt_etm) { - err = htp_iface_disable_etm(this->handle); + err = htp_iface_etm(this->handle, 0); if (err != 0) { GGML_LOG_ERROR("ggml-hex: warn : failed to disable ETM tracing: 0x%08x\n", (unsigned) err); } } + if (opt_profile) { + htp_iface_pmu_conf pmu_conf{}; + err = htp_iface_profiler(this->handle, 0, &pmu_conf); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: warn : failed to disable profiling: 0x%08x\n", (unsigned) err); + } + } + if (this->valid_queue) { err = dspqueue_close(queue); if (err != 0) { @@ -2077,7 +2151,7 @@ ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) n repack_buffer_type.device = dev; op_batch = nullptr; - op_shm = nullptr; + op_queue = nullptr; try { allocate(dev_id); @@ -2698,7 +2772,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg for (int i = 0; i < graph->n_nodes; ++i) { ggml_tensor * n = graph->nodes[i]; - if (op_is_compute(n)) { + if (op_is_compute(n) && (opt_opstage & HTP_OPSTAGE_QUEUE)) { sess->enqueue_op(op_remap_to_htp(n), n); } } @@ -3338,6 +3412,26 @@ static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, cons return NULL; } +template std::vector str_to_vec(const char* str) { + std::stringstream ss(str); + std::vector v; + std::string t; + + while (std::getline(ss, t, ',')) { + v.push_back(std::stoul(t, nullptr, 0)); + } + + return v; +} + +template std::string vec_to_str(std::vector v) { + std::stringstream ss; + ss << std::setbase(BASE) << std::showbase; + for (auto i : v) { ss << i << ','; } + auto str = ss.str(); str.pop_back(); // drop last comma + return str; +} + static void ggml_hexagon_init(ggml_backend_reg * reg) { // Basic sanity checks to make sure definitions match static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0, @@ -3351,8 +3445,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); - const char * str_opmask = getenv("GGML_HEXAGON_OPMASK"); - const char * str_opsync = getenv("GGML_HEXAGON_OPSYNC"); + const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); const char * str_opfilter= getenv("GGML_HEXAGON_OPFILTER"); @@ -3365,19 +3458,30 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { auto RE_ICASE = std::regex_constants::icase; - opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; - opt_verbose = str_verbose ? atoi(str_verbose) : 0; - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; - opt_opmask = str_opmask ? strtoul(str_opmask, NULL, 0) : opt_opmask; - opt_opsync = str_opsync ? atoi(str_opsync) : opt_opsync; - opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; - opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; - opt_profile = str_profile ? atoi(str_profile) : 0; - opt_etm = str_etm ? atoi(str_etm) : 0; - opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; - opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; - opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; + opt_verbose = str_verbose ? atoi(str_verbose) : 0; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; + opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; + opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; + opt_etm = str_etm ? atoi(str_etm) : 0; + opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; + opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; + opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + + if (str_profile) { + opt_pmu_evt = [&]() -> std::vector { + auto v = str_to_vec(str_profile); + switch (v.size()) { + case 1: opt_profile = v[0]; return opt_pmu_evt; // mode with default pmu events + case 8: opt_profile = 2; return v; // mode with custom pmu events + default: opt_profile = 0; return {}; // garbage input + }}(); + if (opt_profile == 1) opt_pmu_evt = {}; + GGML_LOG_INFO("ggml-hex: Profiling mode %u : pmu-evt [ %s ]\n", opt_profile, + vec_to_str(opt_pmu_evt).c_str()); + } if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { opt_ndev = GGML_HEXAGON_MAX_SESSIONS; diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index f6713c5cf8f..329249e11da 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "hexagon_types.h" #include "hexagon_protos.h" @@ -100,4 +101,31 @@ static inline void hex_pause() { asm volatile(" pause(#255)\n"); } +#ifndef HEX_NUM_PMU_COUNTERS +#define HEX_NUM_PMU_COUNTERS 8 +#endif + +static inline void hex_get_pmu(uint32_t counters[]) { +#if __HVX_ARCH__ >= 79 + asm volatile("%0 = upmucnt0" : "=r"(counters[0])); + asm volatile("%0 = upmucnt1" : "=r"(counters[1])); + asm volatile("%0 = upmucnt2" : "=r"(counters[2])); + asm volatile("%0 = upmucnt3" : "=r"(counters[3])); + asm volatile("%0 = upmucnt4" : "=r"(counters[4])); + asm volatile("%0 = upmucnt5" : "=r"(counters[5])); + asm volatile("%0 = upmucnt6" : "=r"(counters[6])); + asm volatile("%0 = upmucnt7" : "=r"(counters[7])); +#else + counters[0] = qurt_pmu_get(QURT_PMUCNT0); + counters[1] = qurt_pmu_get(QURT_PMUCNT1); + counters[2] = qurt_pmu_get(QURT_PMUCNT2); + counters[3] = qurt_pmu_get(QURT_PMUCNT3); + counters[4] = qurt_pmu_get(QURT_PMUCNT4); + counters[5] = qurt_pmu_get(QURT_PMUCNT5); + counters[6] = qurt_pmu_get(QURT_PMUCNT6); + counters[7] = qurt_pmu_get(QURT_PMUCNT7); + // qurt_pmu_get_pmucnt(counters); +#endif +} + #endif /* HEX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 78455e6b071..f8c89211aed 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -10,6 +10,7 @@ #include #include #include +#include #define HTP_MAX_NTHREADS 10 #define HTP_MAX_MMAPS 16 @@ -66,7 +67,9 @@ struct htp_context { int thread_id; int thread_prio; - int hmx_enabled; + bool hmx_enabled; + bool etm; + uint32_t profiler; uint8_t * vtcm_base; size_t vtcm_size; diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 62d6ec02241..56d7b398d10 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -42,9 +42,9 @@ enum htp_data_type { // Mask to enable various stages of the Ops. // Used for debugging and profiling. -enum htp_op_mask { - HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP) - HTP_OPMASK_COMPUTE = (1 << 1), // Enable Compute +enum htp_op_stage { + HTP_OPSTAGE_QUEUE = (1 << 0), // Enable Queueing (ie calls into NPU) + HTP_OPSTAGE_COMPUTE = (1 << 1), // Enable Compute }; // Do not reorder first 4 (used as an index) @@ -137,27 +137,45 @@ struct htp_op_desc { int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices uint16_t dst; // Output tensor index +}; + +enum htp_profiler_mode { + HTP_PROF_DISABLED = 0, + HTP_PROF_BASIC = 1, + HTP_PROF_PMU = 2, +}; + +#define HTP_PROF_PMU_NCNT 8 - // the rest is filled in-place by the NPU - uint32_t prof_usecs; // Number of usec per request - uint32_t prof_cycles; // Number of cycles per request - uint32_t prof_pkts; // Number of instruction packets per request - uint32_t unused; +// Profile descriptor +struct htp_prof_desc { + uint32_t opcode; // GGML/HTP Op + uint32_t usecs; // Number of usec + uint32_t cycles; // Number of cycles + uint32_t pad; // Unused + uint32_t pmu[HTP_PROF_PMU_NCNT]; // PMU counters }; struct htp_opbatch_req { + uint32_t id; // Batch id uint32_t n_bufs; // Number of buffers uint32_t n_tensors; // Number of tensors uint32_t n_ops; // Number of ops uint32_t flags; // unused + uint32_t pad; // unused // struct htp_buf_desc bufs[]; -- dspqueue buf 0 // struct htp_tensor tensors[]; -- dspqueue buf 0 // struct htp_op_desc ops[]; -- dspqueue buf 0 }; struct htp_opbatch_rsp { + uint32_t id; // Batch id uint32_t status; // HTP_STATUS_... - // struct htp_op_req ops[]; -- dspqueue buf 0 + uint32_t n_bufs; // Number of buffers + uint32_t n_tensors; // Number of tensors + uint32_t n_ops; // Number of op profile descriptors + uint32_t pad; // unused + // struct htp_prof_desc profs[]; -- dspqueue buf 0 }; #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index 3eb5d5a6912..dbcafd1d856 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -6,13 +6,17 @@ #include "AEEStdDef.idl" #include "remote.idl" +struct htp_iface_pmu_conf { + uint32 events[8]; +}; + interface htp_iface : remote_handle64 { AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx); AEEResult stop(); AEEResult mmap(in uint32 fd, in uint32 size, in uint32 pinned); AEEResult munmap(in uint32 fd); - AEEResult enable_etm(); - AEEResult disable_etm(); + AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu); + AEEResult etm(in uint32 enable); }; #endif /* HTP_IDL */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 9185c9ffe15..088434a63e9 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -27,6 +27,7 @@ #include "htp-ctx.h" #include "htp-ops.h" #include "htp-ops.h" +#include "htp_iface.h" #include "worker-pool.h" AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { @@ -103,6 +104,54 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { return AEE_SUCCESS; } +AEEResult htp_iface_etm(remote_handle64 handle, uint32_t enable) { + int err = enable ? HAP_user_etm_enable() : HAP_user_etm_disable(); + if (err) { + if (err == AEE_EVERSIONNOTSUPPORT) { + FARF(ERROR, "API HAP_user_etm_enable/disable is not supported\n"); + } else { + FARF(ERROR, "Error executing HAP_user_etm_enable/disable with error code : 0x%x\n", err); + } + } + return err; +} + +AEEResult htp_iface_profiler(remote_handle64 handle, uint32_t mode, const htp_iface_pmu_conf* pmu_conf) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } + + if (mode == HTP_PROF_PMU) { + const uint32_t* events = pmu_conf->events; + + // Pack 4 event IDs (low 8 bits) into each 32-bit config register + uint32_t evtcfg = 0, evtcfg1 = 0, cfg = 0, i = 0; + for (; i < HEX_NUM_PMU_COUNTERS/2; i++) { + evtcfg |= ((events[i + 0] & 0xFF) << (i * 8)); + evtcfg1 |= ((events[i + 4] & 0xFF) << (i * 8)); + } + + // For events >255 pack high 2 bits of all 8 event IDs into cfg register + // 2 bits per counter: bits [1:0] for counter 0, [3:2] for counter 1, etc. + for (i = 0; i < HEX_NUM_PMU_COUNTERS; i++) { + cfg |= (((events[i] >> 8) & 3) << (i * 2)); + } + + FARF(ALWAYS, "Configuring PMU registers: evtcfg = 0x%x, evtcfg1 = 0x%x, pmucfg = 0x%x", evtcfg, evtcfg1, cfg); + + // Configure PMU registers + qurt_pmu_set(QURT_PMUCFG, cfg); + qurt_pmu_set(QURT_PMUEVTCFG, evtcfg); + qurt_pmu_set(QURT_PMUEVTCFG1, evtcfg1); + qurt_pmu_enable(1); + } + + ctx->profiler = mode; + + return AEE_SUCCESS; +} + AEEResult htp_iface_close(remote_handle64 handle) { struct htp_context * ctx = (struct htp_context *) handle; @@ -129,35 +178,19 @@ AEEResult htp_iface_close(remote_handle64 handle) { } } - free(ctx); - return AEE_SUCCESS; -} - -AEEResult htp_iface_enable_etm(remote_handle64 handle) { - int err = HAP_user_etm_enable(); - if (err) { - if (err == AEE_EVERSIONNOTSUPPORT) { - FARF(ERROR, "API HAP_user_etm_enable is not supported\n"); - } else { - FARF(ERROR, "Error executing HAP_user_etm_enable with error code : 0x%x\n", err); - } + if (ctx->profiler) { + qurt_pmu_enable(1); } - return err; -} -AEEResult htp_iface_disable_etm(remote_handle64 handle) { - int err = HAP_user_etm_disable(); - if (err) { - if (err == AEE_EVERSIONNOTSUPPORT) { - FARF(ERROR, "API HAP_user_etm_disable is not supported\n"); - } else { - FARF(ERROR, "Error executing HAP_user_etm_disable with error code : 0x%x\n", err); - } + if (ctx->etm) { + HAP_user_etm_disable(); } - return err; + + free(ctx); + return AEE_SUCCESS; } -AEEResult htp_iface_mmap(remote_handle64 handle, int fd, uint32_t size, uint32_t pinned) { +AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32 pinned) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { return AEE_EBADPARM; @@ -204,7 +237,7 @@ AEEResult htp_iface_mmap(remote_handle64 handle, int fd, uint32_t size, uint32_t return AEE_ENOMEMORY; } -AEEResult htp_iface_munmap(remote_handle64 handle, int fd) { +AEEResult htp_iface_munmap(remote_handle64 handle, uint32 fd) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { return AEE_EBADPARM; @@ -434,19 +467,39 @@ static void htp_error_callback(dspqueue_t queue, int error, void * context) { struct profile_data { uint64_t usecs; uint64_t cycles; - uint64_t pkts; + uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS]; }; -static inline void profile_start(struct profile_data * d) { - d->usecs = HAP_perf_get_qtimer_count(); - d->cycles = hex_get_cycles(); - d->pkts = hex_get_pktcnt(); +static inline void profile_start(uint32_t mode, struct profile_data * d) { + switch (mode) { + case HTP_PROF_PMU: + hex_get_pmu(d->pmu_counters); + // fallthrough + case HTP_PROF_BASIC: + d->usecs = HAP_perf_get_qtimer_count(); + d->cycles = hex_get_cycles(); + break; + default: + break; + } } -static inline void profile_stop(struct profile_data * d) { - d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs); - d->cycles = hex_get_cycles() - d->cycles; - d->pkts = hex_get_pktcnt() - d->pkts; +static inline void profile_stop(uint32_t mode, struct profile_data * d) { + uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS]; + switch (mode) { + case HTP_PROF_PMU: + hex_get_pmu(pmu_counters); + for (int i = 0; i < HEX_NUM_PMU_COUNTERS; i++) { + d->pmu_counters[i] = pmu_counters[i] - d->pmu_counters[i]; + } + // fallthrough + case HTP_PROF_BASIC: + d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs); + d->cycles = hex_get_cycles() - d->cycles; + break; + default: + break; + } } static int execute_op(struct htp_ops_context * octx) { @@ -726,29 +779,32 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { continue; } + // Reset poll count for valid requests + poll_count = DSPQUEUE_POLL_COUNT; + const uint32_t n_bufs = req.n_bufs; const uint32_t n_tens = req.n_tensors; const uint32_t n_ops = req.n_ops; - const uint32_t b_size = sizeof(struct htp_buf_desc) * n_bufs; - const uint32_t t_size = sizeof(struct htp_tensor) * n_tens; - const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops; + const uint32_t b_size = sizeof(struct htp_buf_desc) * n_bufs; + const uint32_t t_size = sizeof(struct htp_tensor) * n_tens; + const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops; + const uint32_t p_size = sizeof(struct htp_prof_desc) * n_ops; - if (dbuf.size < b_size + t_size + o_size) { + if (dbuf.size < b_size + t_size + o_size + p_size) { FARF(ERROR, "invalid opbatch memory block size %u", dbuf.size); break; } - // Reset poll count for valid requests - poll_count = DSPQUEUE_POLL_COUNT; + FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", req.id, + n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size); + // Setup descriptor pointers uint8_t * m_ptr = dbuf.ptr; - struct htp_buf_desc* bufs = (struct htp_buf_desc*) m_ptr; m_ptr += b_size; - struct htp_tensor* tens = (struct htp_tensor*) m_ptr; m_ptr += t_size; - struct htp_op_desc* ops = (struct htp_op_desc*) m_ptr; - - FARF(HIGH, "processing opbatch: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", - n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size); + struct htp_buf_desc* bufs = (struct htp_buf_desc*) m_ptr; m_ptr += b_size; + struct htp_tensor* tens = (struct htp_tensor*) m_ptr; m_ptr += t_size; + struct htp_op_desc* ops = (struct htp_op_desc*) m_ptr; m_ptr += o_size; + struct htp_prof_desc* pds = (struct htp_prof_desc*) m_ptr; prep_op_bufs(ctx, bufs, n_bufs); prep_tensors(ctx, bufs, tens, n_tens); @@ -760,22 +816,34 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { for (uint32_t i=0; i < n_ops; i++) { struct profile_data prof; - profile_start(&prof); + + profile_start(ctx->profiler, &prof); proc_op_req(octx, tens, i, &ops[i]); - profile_stop(&prof); - ops[i].prof_usecs = prof.usecs; - ops[i].prof_cycles = prof.cycles; - ops[i].prof_pkts = prof.pkts; + profile_stop(ctx->profiler, &prof); + + if (ctx->profiler) { + pds[i].opcode = ops[i].opcode; + pds[i].usecs = prof.usecs; + pds[i].cycles = prof.cycles; + for (int j = 0; j < HEX_NUM_PMU_COUNTERS; j++) { + pds[i].pmu[j] = prof.pmu_counters[j]; + } + } } // dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); struct htp_opbatch_rsp rsp; - rsp.status = HTP_STATUS_OK; // FIXME + rsp.id = req.id; + rsp.status = HTP_STATUS_OK; + rsp.n_bufs = n_bufs; + rsp.n_tensors = n_tens; + rsp.n_ops = n_ops; dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; + err = dspqueue_write(queue, 0, 1, &dbuf, sizeof(rsp), (const uint8_t *) &rsp, DSPQUEUE_TIMEOUT_NONE); if (err != 0) { FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err); diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index bac06693d81..a0c265132c8 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -3017,6 +3017,10 @@ int op_matmul(struct htp_ops_context * octx) { const int act_stride = (int)(src1->nb[1] / sizeof(float)); const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16)); + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + if (src0->type == HTP_TYPE_F16) { if (is_batched) { hmx_matmul_w16a32_batched_params_t batch_params = { From 641998f558afb6dae907e86ef0a44995b8a00592 Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Thu, 23 Apr 2026 19:32:59 -0400 Subject: [PATCH 484/831] fix(shader): handle the buffer aliasing for rms fuse (llama/22266) --- ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp | 14 ++++++++++---- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 6 ++++-- .../ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl | 17 ++++++++++++++++- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index efc5b8c97a7..449eae808e4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -197,11 +197,12 @@ struct ggml_webgpu_row_norm_pipeline_key_hash { /** RMS_NORM + MUL **/ struct ggml_webgpu_rms_norm_mul_pipeline_key { - bool inplace; - bool src_overlap; + bool inplace; // rn_src == dst + bool overlap; // mul_src == dst + bool src_overlap; // rn_src == mul_src bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const { - return inplace == other.inplace && src_overlap == other.src_overlap; + return inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap; } }; @@ -209,6 +210,7 @@ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash { size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.overlap); ggml_webgpu_hash_combine(seed, key.src_overlap); return seed; } @@ -556,7 +558,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ const size_t q_tile = context.sg_mat_m; const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; + size_t bytes_per_kv = 0; if (!key.kv_direct) { bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); } @@ -1878,6 +1880,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_rms_norm_mul_pipeline_key key = {}; key.inplace = context.inplace; + key.overlap = context.overlap; key.src_overlap = context.src_overlap; auto it = rms_norm_mul_pipelines.find(key); @@ -1892,6 +1895,9 @@ class ggml_webgpu_shader_lib { if (key.inplace) { defines.push_back("INPLACE"); variant += "_inplace"; + } else if (key.overlap) { + defines.push_back("OVERLAP"); + variant += "_overlap"; } else if (key.src_overlap) { defines.push_back("SRC_OVERLAP"); variant += "_src_overlap"; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index bcca2bd4627..acc486cfdda 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2071,8 +2071,9 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); } - bool inplace = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || + bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); + bool inplace = ggml_webgpu_tensor_equal(rn_src, dst); bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src); uint32_t offset_merged_rn_src = 0; @@ -2116,7 +2117,7 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context std::vector entries; - if (inplace) { + if (inplace || overlap) { entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); } else if (src_overlap) { @@ -2136,6 +2137,7 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; shader_lib_ctx.inplace = inplace; + shader_lib_ctx.overlap = overlap; shader_lib_ctx.src_overlap = src_overlap; webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl index 71f063b51aa..74aaa2753ae 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl @@ -1,4 +1,4 @@ -#ifdef INPLACE +#ifdef OVERLAP @group(0) @binding(0) var rn_src: array; @@ -13,6 +13,21 @@ fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) mul_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; } +#elif INPLACE + +@group(0) @binding(0) +var rn_src: array; + +@group(0) @binding(1) +var mul_src: array; + +@group(0) @binding(2) +var params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + rn_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + #elif SRC_OVERLAP @group(0) @binding(0) From 23921d5a695262bdf9bdb34300f179aa97ae7a1e Mon Sep 17 00:00:00 2001 From: Mengsheng Wu Date: Fri, 24 Apr 2026 09:39:13 +0800 Subject: [PATCH 485/831] hexagon: add SOLVE_TRI op (llama/21974) * hexagon: add SOLVE_TRI op * ggml: fix TODO description for solve_tri * hexagon: rm unused variable/function warnings * hexagon: chunk vs batch processingfor better thread utilization * hexagon: vectorize partial f32 loads * hexagon: move HVX f32 add/sub/mul wrappers to hvx-base.h --------- Co-authored-by: Todor Boinovski --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 39 +++- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 2 +- ggml/src/ggml-hexagon/htp/hvx-base.h | 24 ++ ggml/src/ggml-hexagon/htp/main.c | 3 + ggml/src/ggml-hexagon/htp/solve-tri-ops.c | 267 ++++++++++++++++++++++ 7 files changed, 335 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/solve-tri-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 955903418b6..0d9b5e289bb 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2693,6 +2693,39 @@ static bool ggml_hexagon_supported_diag(const struct ggml_hexagon_session * sess return true; } +static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // A + const struct ggml_tensor * src1 = op->src[1]; // B + const struct ggml_tensor * dst = op; // X + + if (!src0 || !src1) { + return false; + } + + if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + if (src0->ne[0] != src0->ne[1]) { + return false; + } + + if (src0->ne[1] != src1->ne[1]) { + return false; + } + + if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) { + return false; + } + + if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] || dst->ne[3] != src1->ne[3]) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->c_name(); @@ -2731,7 +2764,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; case GGML_OP_FILL: return HTP_OP_FILL; case GGML_OP_DIAG: return HTP_OP_DIAG; - + case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; @@ -3277,6 +3310,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_diag(sess, op); break; + case GGML_OP_SOLVE_TRI: + supp = ggml_hexagon_supported_solve_tri(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index b1ae60a9c43..8bd528478ba 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -36,6 +36,7 @@ add_library(${HTP_LIB} SHARED cumsum-ops.c fill-ops.c diag-ops.c + solve-tri-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index f8c89211aed..d704fedee9d 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -103,5 +103,6 @@ int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); +int op_solve_tri(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 56d7b398d10..4397245c5b8 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -82,7 +82,7 @@ enum htp_op_code { HTP_OP_CUMSUM, HTP_OP_FILL, HTP_OP_DIAG, - + HTP_OP_SOLVE_TRI, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index ed6026e762a..d0926dedd28 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -256,6 +256,18 @@ static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b)); } +static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)); +} + +static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)); +} + +static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)); +} + #else static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b) @@ -273,6 +285,18 @@ static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) return Q6_Vhf_vmpy_VhfVhf(a, b); } +static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vadd_VsfVsf(a, b); +} + +static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vsub_VsfVsf(a, b); +} + +static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vmpy_VsfVsf(a, b); +} + #endif // __HVX_ARCH__ < 79 #endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 088434a63e9..db277a25e5a 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -573,6 +573,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_DIAG: return op_diag(octx); + case HTP_OP_SOLVE_TRI: + return op_solve_tri(octx); + case HTP_OP_INVALID: break; diff --git a/ggml/src/ggml-hexagon/htp/solve-tri-ops.c b/ggml/src/ggml-hexagon/htp/solve-tri-ops.c new file mode 100644 index 00000000000..ae8e1a50495 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/solve-tri-ops.c @@ -0,0 +1,267 @@ +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hvx-utils.h" + +struct htp_solve_tri_context { + struct htp_ops_context * octx; + uint32_t jobs_per_thread; + uint32_t total_jobs; + uint32_t k_chunks; + uint32_t col_block; +}; + +static inline void solve_tri_row_scalar(const float * A_row, + const float * B_row, + float * X, + uint32_t row, + uint32_t k, + uint32_t col0, + uint32_t coln, + float inv_diag) { + for (uint32_t col = col0; col < col0 + coln; ++col) { + float sum = 0.0f; + for (uint32_t t = 0; t < row; ++t) { + sum += A_row[t] * X[t * k + col]; + } + X[row * k + col] = (B_row[col] - sum) * inv_diag; + } +} + +static inline HVX_Vector hvx_load_partial_f32(const float * src, uint32_t n) { + HVX_Vector v = *((const HVX_UVector *) src); + HVX_VectorPred mask = Q6_Q_vsetq2_R(n * sizeof(float)); + return Q6_V_vmux_QVV(mask, v, Q6_V_vzero()); +} + +static inline void solve_tri_row_hvx(const float * A_row, + const float * B_row, + float * X, + uint32_t row, + uint32_t k, + uint32_t col0, + uint32_t coln, + float inv_diag) { + const bool full = (coln == VLEN_FP32); + + HVX_Vector sum_v = Q6_V_vzero(); + for (uint32_t t = 0; t < row; ++t) { + const float a = A_row[t]; + const float * x_row_col = X + t * k + col0; + + HVX_Vector x_v = full ? *((const HVX_UVector *) x_row_col) : hvx_load_partial_f32(x_row_col, coln); + HVX_Vector a_v = hvx_vec_splat_f32(a); + sum_v = hvx_vec_add_f32_f32(sum_v, hvx_vec_mul_f32_f32(x_v, a_v)); + } + + const float * b_row_col = B_row + col0; + float * x_out_col = X + row * k + col0; + + HVX_Vector b_v = full ? *((const HVX_UVector *) b_row_col) : hvx_load_partial_f32(b_row_col, coln); + HVX_Vector inv_diag_v = hvx_vec_splat_f32(inv_diag); + + HVX_Vector out_v = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(b_v, sum_v), inv_diag_v); + hvx_vec_store_u((void *) x_out_col, coln * sizeof(float), out_v); +} + +// Batch-level thread: each job is one full batch. +static void solve_tri_batch_thread_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data; + struct htp_ops_context * octx = sctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + const uint32_t n = src0->ne[0]; + const uint32_t k = src1->ne[0]; + + const uint32_t ne02 = src0->ne[2]; + + const uint32_t col_block = VLEN_FP32; + const uint32_t k_full = (k / col_block) * col_block; + + const uint32_t start_batch = sctx->jobs_per_thread * ith; + const uint32_t end_batch = MIN(start_batch + sctx->jobs_per_thread, sctx->total_jobs); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t batch = start_batch; batch < end_batch; ++batch) { + const uint32_t i03 = batch / ne02; + const uint32_t i02 = batch - i03 * ne02; + + const float * A_batch = + (const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]); + const float * B_batch = + (const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]); + float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]); + + for (uint32_t row = 0; row < n; ++row) { + const float diag = A_batch[row * n + row]; + const float inv_diag = 1.0f / diag; + const float * A_row = A_batch + row * n; + const float * B_row = B_batch + row * k; + + uint32_t col0 = 0; + for (; col0 < k_full; col0 += col_block) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, col_block, inv_diag); + } + + if (col0 < k) { + const uint32_t coln = k - col0; + if (coln >= 8) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } else { + solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "solve-tri-batch %d/%d: A=(%ux%u) B=(%ux%u) batch %u:%u usec %u\n", + ith, nth, n, n, k, n, start_batch, end_batch, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// Chunk-level thread: each job is one (batch, col_chunk) pair. +static void solve_tri_chunk_thread_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data; + struct htp_ops_context * octx = sctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + const uint32_t n = src0->ne[0]; + const uint32_t k = src1->ne[0]; + + const uint32_t ne02 = src0->ne[2]; + + const uint32_t start_job = sctx->jobs_per_thread * ith; + const uint32_t end_job = MIN(start_job + sctx->jobs_per_thread, sctx->total_jobs); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t job = start_job; job < end_job; ++job) { + const uint32_t batch = job / sctx->k_chunks; + const uint32_t chunk = job - batch * sctx->k_chunks; + + const uint32_t i03 = batch / ne02; + const uint32_t i02 = batch - i03 * ne02; + + const uint32_t col0 = chunk * sctx->col_block; + const uint32_t coln = MIN(sctx->col_block, k - col0); + + const float * A_batch = + (const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]); + const float * B_batch = + (const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]); + float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]); + + const bool use_hvx = (coln >= 8); + + for (uint32_t row = 0; row < n; ++row) { + const float diag = A_batch[row * n + row]; + const float inv_diag = 1.0f / diag; + + const float * A_row = A_batch + row * n; + const float * B_row = B_batch + row * k; + + if (use_hvx) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } else { + solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "solve-tri-chunk %d/%d: A=(%ux%u) B=(%ux%u) job %u:%u usec %u\n", + ith, nth, n, n, k, n, start_job, end_job, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_solve_tri(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + // left=true, lower=true, uni=false only + if (src0->ne[0] != src0->ne[1]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (src0->ne[1] != src1->ne[1]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] || + dst->ne[3] != src1->ne[3]) { + return HTP_STATUS_INVAL_PARAMS; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t k = src1->ne[0]; + + const uint32_t col_block = VLEN_FP32; + const uint32_t k_chunks = (k + col_block - 1) / col_block; + const uint32_t total_batches = src0->ne[2] * src0->ne[3]; + const bool batched = total_batches >= (uint32_t) octx->n_threads; + + FARF(HIGH, "solve-tri: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : batched %d\n", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], batched); + + if (batched) { + // Batch-level parallelism + const uint32_t n_threads = MIN((uint32_t) octx->n_threads, total_batches); + + struct htp_solve_tri_context sctx = { + .octx = octx, + .jobs_per_thread = (total_batches + n_threads - 1) / n_threads, + .total_jobs = total_batches, + .k_chunks = k_chunks, + .col_block = col_block, + }; + + worker_pool_run_func(octx->ctx->worker_pool, solve_tri_batch_thread_f32, &sctx, n_threads); + } else { + // Chunk-level parallelism + const uint32_t total_jobs = total_batches * k_chunks; + const uint32_t n_threads = MIN((uint32_t) octx->n_threads, MAX(total_jobs, 1)); + + struct htp_solve_tri_context sctx = { + .octx = octx, + .jobs_per_thread = (total_jobs + n_threads - 1) / n_threads, + .total_jobs = total_jobs, + .k_chunks = k_chunks, + .col_block = col_block, + }; + + worker_pool_run_func(octx->ctx->worker_pool, solve_tri_chunk_thread_f32, &sctx, n_threads); + } + + return HTP_STATUS_OK; +} From dfb8b68799f3aa4781c11007ecfc82fc146728eb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 24 Apr 2026 11:02:00 +0300 Subject: [PATCH 486/831] ggml : minor coding style (llama/22308) --- ggml/src/ggml.c | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index eda041f4518..54d3eae3e4d 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7656,7 +7656,7 @@ size_t ggml_quantize_chunk( int64_t nrows, int64_t n_per_row, const float * imatrix) { - const int64_t n = (int64_t) nrows * n_per_row; + const int64_t n = nrows * n_per_row; if (ggml_quantize_requires_imatrix(type)) { GGML_ASSERT(imatrix != NULL); @@ -7673,21 +7673,21 @@ size_t ggml_quantize_chunk( size_t result = 0; switch (type) { - case GGML_TYPE_Q1_0: result = quantize_q1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_NVFP4: result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q1_0: result = quantize_q1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0: result = quantize_q4_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_1: result = quantize_q4_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_0: result = quantize_q5_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_1: result = quantize_q5_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_0: result = quantize_q8_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP4: result = quantize_mxfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_NVFP4: result = quantize_nvfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q2_K: result = quantize_q2_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q3_K: result = quantize_q3_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_K: result = quantize_q4_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_K: result = quantize_q5_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q6_K: result = quantize_q6_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ1_0: result = quantize_tq1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ2_0: result = quantize_tq2_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; @@ -7752,9 +7752,9 @@ struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) { } bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) { - if (p0->n_threads != p1->n_threads ) return false; - if (p0->prio != p1->prio ) return false; - if (p0->poll != p1->poll ) return false; - if (p0->strict_cpu != p1->strict_cpu ) return false; + if (p0->n_threads != p1->n_threads ) return false; + if (p0->prio != p1->prio ) return false; + if (p0->poll != p1->poll ) return false; + if (p0->strict_cpu != p1->strict_cpu ) return false; return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0; } From 07d6db39e5f659a048e07e28f19e4439ebe1625e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 24 Apr 2026 13:56:03 +0300 Subject: [PATCH 487/831] metal : print GPU description (llama/22318) --- ggml/src/ggml-metal/ggml-metal-device.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index f17f7e2e0ce..27b78c5e6d7 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -814,7 +814,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { } // print MTL GPU family: - GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name); + GGML_LOG_INFO("%s: GPU name: %s (%s)\n", __func__, dev->props.name, dev->props.desc); // determine max supported GPU family // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf From 6576c4da90f5a8b1662697a2b73442276657677c Mon Sep 17 00:00:00 2001 From: Mengsheng Wu Date: Sat, 25 Apr 2026 00:21:33 +0800 Subject: [PATCH 488/831] hexagon: use DIRID 13 in libggml-htp.inf for modern InfVerif (llama/22306) --- ggml/src/ggml-hexagon/libggml-htp.inf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-hexagon/libggml-htp.inf b/ggml/src/ggml-hexagon/libggml-htp.inf index 360d8b1228e..39cefcdda38 100644 --- a/ggml/src/ggml-hexagon/libggml-htp.inf +++ b/ggml/src/ggml-hexagon/libggml-htp.inf @@ -8,7 +8,7 @@ CatalogFile = libggml-htp.cat PnpLockDown = 1 [DestinationDirs] -Drivers_Dir = 6 +Drivers_Dir = 13 [SourceDisksNames] 1 = %DiskId% From 35d679a4f8f51833e6d25b0f748632ba888d3d7b Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Fri, 24 Apr 2026 10:39:09 -0700 Subject: [PATCH 489/831] ggml-webgpu: enable FLASH_ATTN_EXT on browser without subgroup matrix (llama/22199) * ggml-webgpu: add tile flash attention fallback * ggml-webgpu: add new fields and discard usage of mnk for tile version * ggml-webgpu: modify the vec path to discard the mnk parameter * ggml-webgpu: enable flash attention vec and tile version for broswer * ggml-webgpu: stagging KV for flash attention tile version * formatting * turn on subgroup uniformity check * remove Q_TILE as it is always 1 for vec path * make row_max and exp_sum to local register * make different bindings with same underlying buffer to have the same usage flags * move path selection into the shader library and have the host consume a single flash-attn decision object. * turn off skip_validation and address buffer overlapping when nwg==1 * formatting * merge binding when kv overlap --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 326 +++++++++-------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 193 ++++++---- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 29 ++ .../wgsl-shaders/flash_attn_tile.wgsl | 330 ++++++++++++++++++ .../wgsl-shaders/flash_attn_vec_blk.wgsl | 2 +- .../wgsl-shaders/flash_attn_vec_split.wgsl | 321 ++++++++--------- 6 files changed, 809 insertions(+), 392 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 449eae808e4..e492c2123a4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -436,19 +436,27 @@ struct ggml_webgpu_unary_pipeline_key_hash { /** FlashAttention */ +enum ggml_webgpu_flash_attn_path : uint32_t { + GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 0u, + GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 1u, + GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 2u, +}; + struct ggml_webgpu_flash_attn_pipeline_key { ggml_type kv_type; uint32_t head_dim_qk; uint32_t head_dim_v; bool kv_direct; + bool kv_overlap; bool has_mask; bool has_sinks; bool uses_logit_softcap; + uint32_t path; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && - kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap; + kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && + has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap && path == other.path; } }; @@ -459,39 +467,70 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.head_dim_qk); ggml_webgpu_hash_combine(seed, key.head_dim_v); ggml_webgpu_hash_combine(seed, key.kv_direct); + ggml_webgpu_hash_combine(seed, key.kv_overlap); ggml_webgpu_hash_combine(seed, key.has_mask); ggml_webgpu_hash_combine(seed, key.has_sinks); ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); + ggml_webgpu_hash_combine(seed, key.path); return seed; } }; struct ggml_webgpu_flash_attn_decisions { - uint32_t q_tile = 0; - uint32_t kv_tile = 0; - uint32_t wg_size = 0; + uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX; + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; + bool kv_direct = false; }; -struct ggml_webgpu_flash_attn_vec_decisions { - uint32_t kv_tile = 0; - uint32_t wg_size = 0; -}; +inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u; +inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u; + +inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { + if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 || + key.head_dim_qk != key.head_dim_v) { + return 1u; + } + + switch (key.head_dim_qk) { + case 64: + case 192: + case 576: + return 2u; + case 96: + return 4u; + default: + return 1u; + } +} inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( - const ggml_webgpu_shader_lib_context & context) { + const ggml_webgpu_shader_lib_context & context, + uint32_t path) { const bool has_mask = context.src3 != nullptr; const bool has_sinks = context.src4 != nullptr; - const bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) && - (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + bool kv_direct = false; + if (path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH; + if (path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { + kv_direct_align = context.sg_mat_k; + } + kv_direct = (context.src1->type == GGML_TYPE_F16) && + (context.src0->ne[0] % std::max(1u, kv_direct_align) == 0) && + (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); + } ggml_webgpu_flash_attn_pipeline_key key = {}; key.kv_type = context.src1->type; key.head_dim_qk = (uint32_t) context.src0->ne[0]; key.head_dim_v = (uint32_t) context.src2->ne[0]; key.kv_direct = kv_direct; + key.kv_overlap = context.src_overlap; key.has_mask = has_mask; key.has_sinks = has_sinks; key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; + key.path = path; return key; } @@ -554,8 +593,16 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context, const ggml_webgpu_flash_attn_pipeline_key & key) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t q_tile = context.sg_mat_m; + const size_t limit_bytes = context.wg_mem_limit_bytes; + uint32_t q_tile = context.sg_mat_m; + uint32_t kv_granularity = context.sg_mat_n; + if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; + kv_granularity = std::max(1u, context.max_subgroup_size); + } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + q_tile = 1u; + kv_granularity = 8u; + } const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; size_t bytes_per_kv = 0; @@ -568,23 +615,90 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ bytes_per_kv += q_tile; bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; + return (max_kv_tile / kv_granularity) * kv_granularity; } -inline uint32_t ggml_webgpu_flash_attn_vec_get_kv_tile(const ggml_webgpu_shader_lib_context & context) { - const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); - const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); - uint32_t kv_tile = std::max(context.sg_mat_n, std::min(32u, min_kv_tile)); - kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n; +inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( + const ggml_webgpu_shader_lib_context & context, + size_t storage_offset_alignment) { + ggml_webgpu_flash_attn_decisions decisions = {}; + const size_t alignment = std::max(1u, storage_offset_alignment); + const auto * K = context.src1; + const auto * V = context.src2; + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + + const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t { + constexpr uintptr_t ptr_base_addr = 0x1000u; + const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor; + return reinterpret_cast(base->data) - ptr_base_addr + tensor->view_offs; + }; + + const uint32_t k_offset_elems = + (uint32_t) ((flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); + const uint32_t v_offset_elems = + (uint32_t) ((flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); + const bool f16_vec4_aligned = (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) && + (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); + const bool kv_vec_type_supported = + K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && + (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + (context.src2->type == K->type); + const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 && + V->type == GGML_TYPE_F16 && f16_vec4_aligned && + (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec; + + decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : + use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : + GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX; + + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); + decisions.kv_direct = key.kv_direct; + + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); + decisions.q_tile = 1u; + decisions.kv_tile = std::max(8u, std::min(32u, min_kv_tile)); + decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; + decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + if (decisions.kv_direct) { + decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { + decisions.kv_tile -= 8u; + } + } + return decisions; + } + + decisions.q_tile = + decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m; + decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? + std::min(64u, ggml_webgpu_flash_attn_max_kv_tile(context, key)) : + std::min(ggml_webgpu_flash_attn_max_kv_tile(context, key), + context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? + GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE : + std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - if (key.kv_direct) { - kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= context.sg_mat_n; + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + const uint32_t tile_kv_granularity = std::max(1u, context.max_subgroup_size); + decisions.kv_tile = + std::max(tile_kv_granularity, (decisions.kv_tile / tile_kv_granularity) * tile_kv_granularity); + } + + if (decisions.kv_direct) { + GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { + decisions.kv_tile -= decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? + std::max(1u, context.max_subgroup_size) : + context.sg_mat_n; } } - return kv_tile; + return decisions; } /** Matrix Multiplication **/ @@ -821,8 +935,6 @@ class ggml_webgpu_shader_lib { repeat_pipelines; // type std::unordered_map flash_attn_pipelines; - std::unordered_map - flash_attn_vec_pipelines; std::unordered_map @@ -2044,14 +2156,19 @@ class ggml_webgpu_shader_lib { return repeat_pipelines[key]; } - webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { - const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); - auto it = flash_attn_pipelines.find(key); + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context, + size_t storage_offset_alignment) { + const ggml_webgpu_flash_attn_decisions decisions = + ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment); + ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); + auto it = flash_attn_pipelines.find(key); if (it != flash_attn_pipelines.end()) { return it->second; } std::vector defines; - std::string variant = "flash_attn"; + std::string variant = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC ? "flash_attn_vec" : + decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? "flash_attn_tile" : + "flash_attn"; switch (key.kv_type) { case GGML_TYPE_F32: @@ -2073,7 +2190,12 @@ class ggml_webgpu_shader_lib { if (key.has_mask) { defines.push_back("MASK"); - variant += "_mask"; + if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + defines.push_back("BLK"); + variant += "_mask_blk"; + } else { + variant += "_mask"; + } } if (key.has_sinks) { defines.push_back("SINKS"); @@ -2087,6 +2209,10 @@ class ggml_webgpu_shader_lib { defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } + if (key.kv_overlap) { + defines.push_back("KV_OVERLAP"); + variant += "_kv_overlap"; + } defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); @@ -2094,129 +2220,37 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); variant += std::string("_hsv") + std::to_string(key.head_dim_v); - defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); - defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); - defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - - auto decisions = std::make_shared(); - decisions->q_tile = context.sg_mat_m; - - const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); - uint32_t kv_tile = std::min(min_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - - if (key.kv_direct) { - kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= context.sg_mat_n; - } + const char * shader_src = wgsl_flash_attn; + if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + defines.push_back("KV_GRANULARITY=8"); + defines.push_back(std::string("VEC_NE=") + std::to_string(ggml_webgpu_flash_attn_pick_vec_ne(key)) + "u"); + shader_src = wgsl_flash_attn_vec_split; + } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + shader_src = wgsl_flash_attn_tile; + defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size)); + defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v))); + variant += "_tile"; + } else { + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); } - decisions->kv_tile = kv_tile; - decisions->wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - - defines.push_back(std::string("Q_TILE=") + std::to_string(decisions->q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile)); - defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size)); + auto pipeline_decisions = std::make_shared(decisions); + defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size)); webgpu_pipeline pipeline = - ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn, defines), variant); - pipeline.context = decisions; + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); + pipeline.context = pipeline_decisions; flash_attn_pipelines[key] = pipeline; return flash_attn_pipelines[key]; } - webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { - const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context); - auto it = flash_attn_vec_pipelines.find(key); - if (it != flash_attn_vec_pipelines.end()) { - return it->second; - } - - std::vector defines; - std::string variant = "flash_attn_vec"; - - switch (key.kv_type) { - case GGML_TYPE_F32: - defines.push_back("KV_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("KV_F16"); - break; - case GGML_TYPE_Q4_0: - defines.push_back("KV_Q4_0"); - break; - case GGML_TYPE_Q8_0: - defines.push_back("KV_Q8_0"); - break; - default: - GGML_ABORT("Unsupported KV type for flash attention shader"); - } - variant += std::string("_") + ggml_type_name(key.kv_type); - - if (key.has_mask) { - defines.push_back("MASK"); - defines.push_back("BLK"); - variant += "_mask_blk"; - } - if (key.has_sinks) { - defines.push_back("SINKS"); - variant += "_sinks"; - } - if (key.uses_logit_softcap) { - defines.push_back("LOGIT_SOFTCAP"); - variant += "_lgsc"; - } - if (key.kv_direct) { - defines.push_back("KV_DIRECT"); - variant += "_kvdirect"; - } - - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); - - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(key.head_dim_v); - - defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); - defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); - defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - defines.push_back("Q_TILE=1"); - - auto decisions = std::make_shared(); - decisions->kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context); - decisions->wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); - uint32_t vec_ne = 1u; - - // Keep conservative defaults unless this is the f16 vec-split shape family. - if (key.kv_type == GGML_TYPE_F16 && key.head_dim_qk == key.head_dim_v) { - switch (key.head_dim_qk) { - case 64: - case 192: - case 576: - vec_ne = 2u; - break; - case 96: - vec_ne = 4u; - break; - default: - break; - } - } - - defines.push_back(std::string("KV_TILE=") + std::to_string(decisions->kv_tile)); - defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions->wg_size)); - defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); - - webgpu_pipeline pipeline = - ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant); - pipeline.context = decisions; - flash_attn_vec_pipelines[key] = pipeline; - return flash_attn_vec_pipelines[key]; - } - - webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context) { + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) { ggml_webgpu_flash_attn_blk_pipeline_key key = {}; - key.kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(context); + key.kv_tile = kv_tile; auto it = flash_attn_blk_pipelines.find(key); if (it != flash_attn_blk_pipelines.end()) { return it->second; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index acc486cfdda..7ed6fdd1625 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -389,23 +389,6 @@ static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_t return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); } -static bool ggml_webgpu_flash_attn_use_vec(webgpu_global_context & global_ctx, - const ggml_tensor * Q, - const ggml_tensor * K, - const ggml_tensor * V) { - const size_t alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; - const uint32_t k_offset_elems = - (uint32_t) ((ggml_webgpu_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); - const uint32_t v_offset_elems = - (uint32_t) ((ggml_webgpu_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); - const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u); - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - - return (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported && - (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type); -} - static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { size_t offset = ggml_webgpu_tensor_offset(t); return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); @@ -1567,7 +1550,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, return ggml_backend_webgpu_build_multi(ctx, dispatches); } -#ifndef __EMSCRIPTEN__ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_tensor * Q, ggml_tensor * K, @@ -1585,13 +1567,29 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - const int has_mask = (mask != nullptr); - const int has_sinks = (sinks != nullptr); + const int has_mask = (mask != nullptr); + const int has_sinks = (sinks != nullptr); + const bool kv_overlap = ggml_webgpu_tensor_overlap(K, V) && K->type == V->type; + + uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); + uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); + size_t kv_bind_offset = 0; + size_t kv_bind_size = 0; + if (kv_overlap) { + const size_t k_bind_offset = ggml_webgpu_tensor_align_offset(ctx, K); + const size_t v_bind_offset = ggml_webgpu_tensor_align_offset(ctx, V); + const size_t k_bind_end = k_bind_offset + ggml_webgpu_tensor_binding_size(ctx, K); + const size_t v_bind_end = v_bind_offset + ggml_webgpu_tensor_binding_size(ctx, V); + kv_bind_offset = std::min(k_bind_offset, v_bind_offset); + kv_bind_size = std::max(k_bind_end, v_bind_end) - kv_bind_offset; + offset_k = (uint32_t) ((ggml_webgpu_tensor_offset(K) - kv_bind_offset) / ggml_type_size(K->type)); + offset_v = (uint32_t) ((ggml_webgpu_tensor_offset(V) - kv_bind_offset) / ggml_type_size(V->type)); + } std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)), + offset_k, + offset_v, has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1619,10 +1617,15 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, }; std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V), }; - uint32_t binding_index = 3; + if (kv_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V)); + } + uint32_t binding_index = kv_overlap ? 2u : 3u; if (has_mask) { entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); } @@ -1638,25 +1641,25 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, shader_lib_ctx.src3 = mask; shader_lib_ctx.src4 = sinks; shader_lib_ctx.dst = dst; + shader_lib_ctx.src_overlap = kv_overlap; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; - const bool use_vec = ggml_webgpu_flash_attn_use_vec(ctx->global_ctx, Q, K, V); - webgpu_pipeline pipeline = use_vec ? ctx->shader_lib->get_flash_attn_vec_pipeline(shader_lib_ctx) : - ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( + shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + auto * decisions = static_cast(pipeline.context.get()); - if (!use_vec) { - auto * decisions = static_cast(pipeline.context.get()); + if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } - auto * decisions = static_cast(pipeline.context.get()); - wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; uint32_t blk_nblk0 = 0; @@ -1695,10 +1698,12 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, tmp_bind_size = tmp_size_bytes; scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); } else { - // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation. + // nwg==1 writes final dst directly in vec-split; bind tmp to a tiny non-overlapping scratch region. + tmp_size_bytes = WEBGPU_STORAGE_BUF_BINDING_MULT; tmp_buf = ggml_webgpu_tensor_buf(dst); - tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst); - tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst); + tmp_bind_offset = scratch_offset; + tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); } webgpu_pipeline blk_pipeline; @@ -1713,7 +1718,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx; - blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx); + blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile); blk_params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask @@ -1745,12 +1750,19 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector split_entries = { ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q), ggml_webgpu_tensor_binding_size(ctx, Q)), - ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K), - ggml_webgpu_tensor_binding_size(ctx, K)), - ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), ggml_webgpu_tensor_align_offset(ctx, V), - ggml_webgpu_tensor_binding_size(ctx, V)), }; - uint32_t split_binding_index = 3; + if (kv_overlap) { + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + } else { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), + ggml_webgpu_tensor_align_offset(ctx, K), + ggml_webgpu_tensor_binding_size(ctx, K))); + split_entries.push_back(ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), + ggml_webgpu_tensor_align_offset(ctx, V), + ggml_webgpu_tensor_binding_size(ctx, V))); + } + uint32_t split_binding_index = kv_overlap ? 2u : 3u; if (has_mask) { split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask), ggml_webgpu_tensor_align_offset(ctx, mask), @@ -1820,7 +1832,6 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, return ggml_backend_webgpu_build_multi(ctx, dispatches); } -#endif // __EMSCRIPTEN__ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; @@ -2710,11 +2721,7 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, case GGML_OP_MUL_MAT_ID: return ggml_webgpu_mul_mat_id(ctx, src0, src1, src2, node); case GGML_OP_FLASH_ATTN_EXT: -#ifndef __EMSCRIPTEN__ return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); -#else - return std::nullopt; -#endif case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -3257,13 +3264,19 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; shader_lib_ctx.wg_mem_limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = + ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; - if (ggml_webgpu_flash_attn_use_vec(ctx->webgpu_global_ctx, Q, K, V)) { - const uint32_t kv_tile = ggml_webgpu_flash_attn_vec_get_kv_tile(shader_lib_ctx); + const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( + shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + const uint32_t kv_tile = decisions.kv_tile; const uint32_t vec_nwg_cap = std::max( 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); @@ -3283,6 +3296,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const size_t tmp_size_bytes = ROUNDUP_POW2( (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); res += tmp_size_bytes + align; + } else { + res += WEBGPU_STORAGE_BUF_BINDING_MULT + align; } if (mask != nullptr) { const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); @@ -3431,12 +3446,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->capabilities.supports_subgroups = ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups); + bool valid_subgroup_matrix_config = false; #ifndef __EMSCRIPTEN__ // Accept f16 subgroup matrix configurations (square or non-square). // NVIDIA GPUs typically report square configs (e.g. 16x16x16), // while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16). // The shaders are already parameterized to handle any M/N/K dimensions. - bool valid_subgroup_matrix_config = false; if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; @@ -3450,8 +3465,8 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { } } } - ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; #endif + ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. @@ -3499,12 +3514,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { // Enable Dawn-specific toggles to increase native performance // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, // only for native performance? - const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", - "disable_polyfills_on_integer_div_and_mod" }; - const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; + const char * const deviceEnabledToggles[] = { "disable_robustness", "disable_workgroup_init", + "disable_polyfills_on_integer_div_and_mod" }; + const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; wgpu::DawnTogglesDescriptor deviceTogglesDesc; deviceTogglesDesc.enabledToggles = deviceEnabledToggles; - deviceTogglesDesc.enabledToggleCount = 4; + deviceTogglesDesc.enabledToggleCount = 3; deviceTogglesDesc.disabledToggles = deviceDisabledToggles; deviceTogglesDesc.disabledToggleCount = 1; @@ -3782,33 +3797,63 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; case GGML_OP_FLASH_ATTN_EXT: { -#ifndef __EMSCRIPTEN__ - if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { + supports_op = src0->type == GGML_TYPE_F32 && + (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || + src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && + src2->type == src1->type && op->type == GGML_TYPE_F32; + if (!supports_op) { break; } - // Head dimensions must be divisible by subgroup matrix dimensions - if (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k != 0 || - src2->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_n != 0) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.src3 = op->src[3]; + shader_lib_ctx.src4 = op->src[4]; + shader_lib_ctx.dst = const_cast(op); + shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.wg_mem_limit_bytes = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; + + const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( + shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + const bool has_mask = op->src[3] != nullptr; + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + const size_t min_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + if (min_bytes > limit_bytes) { + supports_op = false; + } break; } - // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - const bool kv_direct = src1->type == GGML_TYPE_F16 && - (src0->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k) == 0 && - (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - ctx->webgpu_global_ctx->capabilities.sg_mat_m, ctx->webgpu_global_ctx->capabilities.sg_mat_n, - (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, kv_direct); - if (min_bytes > limit_bytes) { + + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + const size_t min_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + if (min_bytes > limit_bytes) { + supports_op = false; + } break; } - supports_op = src0->type == GGML_TYPE_F32 && - (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || - src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && - src2->type == src1->type && op->type == GGML_TYPE_F32; -#endif + if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { + supports_op = false; + break; + } + const size_t min_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + if (min_bytes > limit_bytes) { + supports_op = false; + } break; } case GGML_OP_RMS_NORM: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index aa2d2e54db9..6d5d69fb8de 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -138,26 +138,55 @@ struct Params { }; @group(0) @binding(0) var Q: array; +#ifdef KV_OVERLAP +@group(0) @binding(1) var K: array; +#define V K +#else @group(0) @binding(1) var K: array; @group(0) @binding(2) var V: array; +#endif #if defined(MASK) && defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +@group(0) @binding(3) var sinks: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else @group(0) @binding(3) var mask: array; @group(0) @binding(4) var sinks: array; #define DST_BINDING 5 #define PARAMS_BINDING 6 +#endif #elif defined(MASK) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else @group(0) @binding(3) var mask: array; #define DST_BINDING 4 #define PARAMS_BINDING 5 +#endif #elif defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var sinks: array; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else @group(0) @binding(3) var sinks: array; #define DST_BINDING 4 #define PARAMS_BINDING 5 +#endif +#else +#ifdef KV_OVERLAP +#define DST_BINDING 2 +#define PARAMS_BINDING 3 #else #define DST_BINDING 3 #define PARAMS_BINDING 4 #endif +#endif @group(0) @binding(DST_BINDING) var dst: array>; @group(0) @binding(PARAMS_BINDING) var params: Params; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl new file mode 100644 index 00000000000..37ea23b80c8 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -0,0 +1,330 @@ +enable f16; +enable subgroups; + +#define HEAD_DIM_QK 64 +#define HEAD_DIM_V 64 +#define KV_STAGE_STRIDE 64 +#define Q_TILE 4 +#define KV_TILE 64 +#define WG_SIZE 128 + +struct Params { + offset_q: u32, + offset_k: u32, + offset_v: u32, + offset_mask: u32, + offset_sinks: u32, + offset_dst: u32, + + n_heads: u32, + seq_len_q: u32, + seq_len_kv: u32, + + stride_q1: u32, + stride_q2: u32, + stride_q3: u32, + stride_k1: u32, + stride_k2: u32, + stride_k3: u32, + stride_v1: u32, + stride_v2: u32, + stride_v3: u32, + stride_mask3: u32, + + q_per_kv: u32, + + scale: f32, + max_bias: f32, + logit_softcap: f32, + n_head_log2: f32, + m0: f32, + m1: f32, +}; + +@group(0) @binding(0) var Q: array; +#ifdef KV_OVERLAP +@group(0) @binding(1) var K: array>; +#define V K +#else +@group(0) @binding(1) var K: array>; +@group(0) @binding(2) var V: array>; +#endif + +#if defined(MASK) && defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +@group(0) @binding(3) var sinks: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else +@group(0) @binding(3) var mask: array; +@group(0) @binding(4) var sinks: array; +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#elif defined(MASK) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else +@group(0) @binding(3) var mask: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#elif defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var sinks: array; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else +@group(0) @binding(3) var sinks: array; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#else +#ifdef KV_OVERLAP +#define DST_BINDING 2 +#define PARAMS_BINDING 3 +#else +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#endif +#endif + +@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(PARAMS_BINDING) var params: Params; + +const FLOAT_MIN: f32 = -1.0e9; +const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u; +const V_CHUNKS: u32 = HEAD_DIM_V / 4u; +const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE; +const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE; + +var q_shmem: array; +var kv_shmem: array; +var p_shmem: array; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(local_invocation_id) local_id: vec3, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + if (subgroup_size == 0u || num_subgroups < Q_TILE) { + return; + } + + let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; + let wg_per_batch = wg_per_head * params.n_heads; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + + let batch_idx = wg_id.x / wg_per_batch; + let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; + let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; + let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; + let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride; + let wg_in_batch = wg_id.x % wg_per_batch; + + let head_idx = wg_in_batch / wg_per_head; + let q_head_offset = q_batch_offset + head_idx * params.stride_q2; + let k_head_idx = head_idx / params.q_per_kv; + let v_head_offset = v_batch_offset + k_head_idx * params.stride_v2; + let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; + + let wg_in_head = wg_in_batch % wg_per_head; + let q_row_start = wg_in_head * Q_TILE; + let global_q_row = q_row_start + subgroup_id; + let row_active = subgroup_id < Q_TILE && global_q_row < params.seq_len_q; + +#ifdef MASK + let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; +#endif + + let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V; + + let head = f32(head_idx); + let slope = select(1.0, + select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), + pow(params.m0, head + 1.0), + head < params.n_head_log2), + params.max_bias > 0.0); + + for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let q_tile_row = elem_idx / HEAD_DIM_QK; + let q_col = elem_idx % HEAD_DIM_QK; + let head_q_row = q_row_start + q_tile_row; + let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; + q_shmem[elem_idx] = f16(select( + 0.0, + Q[global_q_row_offset + q_col] * params.scale, + head_q_row < params.seq_len_q)); + } + + workgroupBarrier(); + + var row_max = FLOAT_MIN; + var exp_sum = 0.0; + var out_regs: array, OUT_REGS_PER_LANE>; + for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) { + out_regs[reg_idx] = vec4(0.0); + } + + let q_base = subgroup_id * HEAD_DIM_QK; + let subgroup_p_offset = subgroup_id * KV_TILE; + + for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile); + let score_slots = min(SCORE_REGS_PER_LANE, (kv_count + subgroup_size - 1u) / subgroup_size); + let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size); + var local_scores: array; + for (var slot = 0u; slot < SCORE_REGS_PER_LANE; slot += 1u) { + local_scores[slot] = FLOAT_MIN; + } + + for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / Q_CHUNKS; + let chunk = vec_idx_local % Q_CHUNKS; + let global_k_row = kv_tile + kv_local; + let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; + let k4 = K[k_vec_index]; + let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + kv_shmem[kv_off + 0u] = k4.x; + kv_shmem[kv_off + 1u] = k4.y; + kv_shmem[kv_off + 2u] = k4.z; + kv_shmem[kv_off + 3u] = k4.w; + } + + workgroupBarrier(); + + var local_max = FLOAT_MIN; + if (row_active) { + for (var slot = 0u; slot < score_slots; slot += 1u) { + let kv_local = sg_inv_id + slot * subgroup_size; + if (kv_local >= kv_count) { + continue; + } + + let global_k_row = kv_tile + kv_local; + var dot_val = 0.0; + for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) { + let q_off = q_base + chunk * 4u; + let qv = vec4( + f32(q_shmem[q_off + 0u]), + f32(q_shmem[q_off + 1u]), + f32(q_shmem[q_off + 2u]), + f32(q_shmem[q_off + 3u])); + let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + let kv = vec4( + f32(kv_shmem[kv_off + 0u]), + f32(kv_shmem[kv_off + 1u]), + f32(kv_shmem[kv_off + 2u]), + f32(kv_shmem[kv_off + 3u])); + dot_val += dot(qv, kv); + } +#ifdef LOGIT_SOFTCAP + dot_val = params.logit_softcap * tanh(dot_val); +#endif +#ifdef MASK + let mask_idx = mask_global_offset + subgroup_id * params.seq_len_kv + global_k_row; + dot_val += slope * f32(mask[mask_idx]); +#endif + local_scores[slot] = dot_val; + local_max = max(local_max, dot_val); + } + } + + let tile_max = subgroupMax(local_max); + let new_max = max(row_max, tile_max); + let cur_exp = exp(row_max - new_max); + exp_sum *= cur_exp; + for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) { + out_regs[reg_idx] *= cur_exp; + } + + var local_sum = 0.0; + for (var slot = 0u; slot < score_slots; slot += 1u) { + let kv_local = sg_inv_id + slot * subgroup_size; + if (row_active && kv_local < kv_count) { + let p = exp(local_scores[slot] - new_max); + p_shmem[subgroup_p_offset + kv_local] = p; + local_sum += p; + } + } + + workgroupBarrier(); + + for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / V_CHUNKS; + let chunk = vec_idx_local % V_CHUNKS; + let global_v_row = kv_tile + kv_local; + let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; + let v4 = V[v_vec_index]; + let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + kv_shmem[kv_off + 0u] = v4.x; + kv_shmem[kv_off + 1u] = v4.y; + kv_shmem[kv_off + 2u] = v4.z; + kv_shmem[kv_off + 3u] = v4.w; + } + + workgroupBarrier(); + + let tile_sum = subgroupAdd(local_sum); + exp_sum += tile_sum; + row_max = new_max; + + if (row_active) { + for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) { + let chunk = sg_inv_id + reg_idx * subgroup_size; + if (chunk >= V_CHUNKS) { + continue; + } + + var acc = out_regs[reg_idx]; + for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) { + let p = p_shmem[subgroup_p_offset + kv_local]; + let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; + let v4 = vec4( + f32(kv_shmem[kv_off + 0u]), + f32(kv_shmem[kv_off + 1u]), + f32(kv_shmem[kv_off + 2u]), + f32(kv_shmem[kv_off + 3u])); + acc += p * v4; + } + out_regs[reg_idx] = acc; + } + } + + workgroupBarrier(); + } + +#ifdef SINKS + if (row_active) { + let sink_score = sinks[params.offset_sinks + head_idx]; + let sink_max = max(row_max, sink_score); + let sink_scale = exp(row_max - sink_max); + for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) { + out_regs[reg_idx] *= sink_scale; + } + exp_sum = exp_sum * sink_scale + exp(sink_score - sink_max); + row_max = sink_max; + } +#endif + + if (row_active) { + let inv_exp_sum = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); + let row_base = dst_global_offset + subgroup_id * dst2_stride; + let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size); + for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) { + let chunk = sg_inv_id + reg_idx * subgroup_size; + if (chunk >= V_CHUNKS) { + continue; + } + let dst_vec_index = (row_base + chunk * 4u) >> 2u; + dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum; + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl index 61107c6a985..b4f7c16c35d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl @@ -15,7 +15,7 @@ struct Params { nblk1: u32, }; -@group(0) @binding(0) var mask: array; +@group(0) @binding(0) var mask: array; @group(0) @binding(1) var blk: array; @group(0) @binding(2) var params: Params; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index a52575871ae..b1e234784a8 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -1,8 +1,6 @@ -diagnostic(off, chromium.subgroup_matrix_uniformity); diagnostic(off, subgroup_uniformity); enable f16; enable subgroups; -enable chromium_experimental_subgroup_matrix; #ifdef KV_F32 #define KV_TYPE f32 @@ -13,19 +11,14 @@ enable chromium_experimental_subgroup_matrix; #define HEAD_DIM_QK 64 #define HEAD_DIM_V 64 - -#define SG_MAT_M 8 -#define SG_MAT_N 8 -#define SG_MAT_K 8 - -#define Q_TILE SG_MAT_M +#define KV_GRANULARITY 8 #define KV_TILE 16 #define WG_SIZE 64 #ifndef VEC_NE #define VEC_NE 4u #endif -#define KV_BLOCKS (KV_TILE / SG_MAT_N) +#define KV_BLOCKS (KV_TILE / KV_GRANULARITY) #define BLOCK_SIZE 32 #define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) @@ -97,6 +90,14 @@ struct Params { }; @group(0) @binding(0) var Q: array; +#ifdef KV_OVERLAP +#if defined(KV_Q4_0) || defined(KV_Q8_0) +@group(0) @binding(1) var K: array; +#else +@group(0) @binding(1) var K: array>; +#endif +#define V K +#else #if defined(KV_Q4_0) || defined(KV_Q8_0) @group(0) @binding(1) var K: array; #else @@ -107,7 +108,22 @@ struct Params { #else @group(0) @binding(2) var V: array>; #endif +#endif #if defined(MASK) && defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +@group(0) @binding(3) var sinks: array; +#ifdef BLK +#define BLK_BINDING 4 +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#else +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#else @group(0) @binding(3) var mask: array; @group(0) @binding(4) var sinks: array; #ifdef BLK @@ -120,7 +136,21 @@ struct Params { #define DST_BINDING 6 #define PARAMS_BINDING 7 #endif +#endif #elif defined(MASK) +#ifdef KV_OVERLAP +@group(0) @binding(2) var mask: array; +#ifdef BLK +#define BLK_BINDING 3 +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#else +#define TMP_BINDING 3 +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#else @group(0) @binding(3) var mask: array; #ifdef BLK #define BLK_BINDING 4 @@ -132,16 +162,30 @@ struct Params { #define DST_BINDING 5 #define PARAMS_BINDING 6 #endif +#endif #elif defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var sinks: array; +#define TMP_BINDING 3 +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else @group(0) @binding(3) var sinks: array; #define TMP_BINDING 4 #define DST_BINDING 5 #define PARAMS_BINDING 6 +#endif +#else +#ifdef KV_OVERLAP +#define TMP_BINDING 2 +#define DST_BINDING 3 +#define PARAMS_BINDING 4 #else #define TMP_BINDING 3 #define DST_BINDING 4 #define PARAMS_BINDING 5 #endif +#endif #ifdef BLK @group(0) @binding(BLK_BINDING) var blk: array; @@ -153,7 +197,7 @@ struct Params { // Just a very small float value. const FLOAT_MIN: f32 = -1.0e9; -var q_shmem: array; +var q_shmem: array; #ifndef KV_DIRECT const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); @@ -161,31 +205,27 @@ const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); var kv_shmem: array; #endif -var o_shmem: array; +var o_shmem: array; #ifdef MASK // storage for mask values -var mask_shmem: array; +var mask_shmem: array; #endif // note that we reuse the same storage for both since we only need one at a time -var inter_shmem: array; +var inter_shmem: array; // Storage for row max and exp sum during online softmax -var row_max_shmem: array; -var exp_sum_shmem: array; -var blk_state_wg: u32; - -fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 { +fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 { var v = select(FLOAT_MIN, - f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale, + f32(inter_shmem[kv_idx]) * params.scale, kv_idx < KV_TILE); #ifdef LOGIT_SOFTCAP v = params.logit_softcap * tanh(v); #endif #ifdef MASK if (apply_mask) { - var mask_val = select(0.0,f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE); + var mask_val = select(0.0, f32(mask_shmem[kv_idx]), kv_idx < KV_TILE); v += select(mask_val, slope * mask_val, has_bias); } #endif @@ -199,19 +239,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(subgroup_size) subgroup_size: u32, @builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_invocation_id) sg_inv_id: u32) { + // Vec path processes exactly one query row per workgroup, so subgroup 0 can + // keep the running softmax state in private storage. + var row_max = FLOAT_MIN; + var exp_sum = 0.0; - // initialize row max for online softmax - for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { - row_max_shmem[i] = FLOAT_MIN; - exp_sum_shmem[i] = 0.0; - } - - for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) { + for (var i = local_id.x; i < HEAD_DIM_V; i += WG_SIZE) { o_shmem[i] = 0.0; } // workgroups per head/batch - let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; + let wg_per_head = params.seq_len_q; let wg_per_batch = wg_per_head * params.n_heads; let dst2_stride = HEAD_DIM_V * params.n_heads; @@ -235,9 +273,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2; - // starting Q row for this workgroup + // Vec path handles one Q row per workgroup. let wg_in_head = wg_in_batch % wg_per_head; - let q_row_start = wg_in_head * Q_TILE; + let q_row_start = wg_in_head; #ifdef MASK // mask offset @@ -248,21 +286,18 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let has_bias = params.max_bias > 0.0; let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias); - // load q tile into shared memory - for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { - let q_row = elem_idx / HEAD_DIM_QK; - let q_col = elem_idx % HEAD_DIM_QK; - let head_q_row = q_row_start + q_row; - let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; + // load the single Q row into shared memory + for (var elem_idx = local_id.x; elem_idx < HEAD_DIM_QK; elem_idx += WG_SIZE) { + let global_q_row_offset = q_head_offset + q_row_start * params.stride_q1; q_shmem[elem_idx] = f16(select( 0.0, - Q[global_q_row_offset + q_col], - head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK)); + Q[global_q_row_offset + elem_idx], + q_row_start < params.seq_len_q)); } for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { #ifdef BLK - let q_blk = q_row_start / Q_TILE; + let q_blk = q_row_start; let kv_blk = kv_tile / KV_TILE; let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u); let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk; @@ -270,13 +305,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #else let blk_state_local = 1u; #endif - if (local_id.x == 0u) { - blk_state_wg = blk_state_local; - } - workgroupBarrier(); - let blk_state = blk_state_wg; + let blk_state = blk_state_local; let skip_tile = blk_state == 0u; - for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) { inter_shmem[elem_idx] = f16(0.0); } @@ -360,20 +391,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let num_of_threads = subgroup_size / VEC_NE; let tx = sg_inv_id % num_of_threads; let ty = sg_inv_id / num_of_threads; - for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - continue; - } - let local_q_row_offset = q_tile_row * HEAD_DIM_QK; - + if (subgroup_id == 0u && q_row_start < params.seq_len_q) { for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) { let kv_idx = kv_base + ty; var partial_sum: f32 = 0.0; let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv; if (kv_valid) { for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { - let q_off = local_q_row_offset + i * 4u; + let q_off = i * 4u; let qv = vec4( f32(q_shmem[q_off + 0u]), @@ -410,8 +435,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let sum_bcast = subgroupShuffle(sum, num_of_threads * ty); if (tx == 0u && kv_valid) { - let dst_idx = q_tile_row * KV_TILE + kv_idx; - inter_shmem[dst_idx] = f16(sum_bcast); + inter_shmem[kv_idx] = f16(sum_bcast); } } } @@ -422,13 +446,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let apply_mask = !skip_tile && (blk_state != 2u); if (apply_mask) { // load mask tile into shared memory for this KV block - for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { - let mask_row = elem_idx / KV_TILE; - let mask_col = elem_idx % KV_TILE; - let global_q_row = q_row_start + mask_row; - let global_k_col = kv_tile + mask_col; - let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv; - let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col; + for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) { + let global_k_col = kv_tile + elem_idx; + let mask_in_bounds = q_row_start < params.seq_len_q && global_k_col < params.seq_len_kv; + let mask_idx = mask_global_offset + global_k_col; mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); } } @@ -439,50 +460,40 @@ fn main(@builtin(workgroup_id) wg_id: vec3, workgroupBarrier(); // online softmax - if (!skip_tile) { - for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } - - var prev_max = row_max_shmem[q_tile_row]; - var final_max = prev_max; - // pass 1: compute final max across the full KV tile in chunks - for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { - let kv_idx = kv_offset + sg_inv_id; - let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE; - let softmax_term = select(FLOAT_MIN, - calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask), - kv_valid); - final_max = subgroupMax(max(final_max, softmax_term)); - } + if (!skip_tile && subgroup_id == 0u && q_row_start < params.seq_len_q) { + var prev_max = row_max; + var final_max = prev_max; + // pass 1: compute final max across the full KV tile in chunks + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE; + let softmax_term = select(FLOAT_MIN, + calc_softmax_term(kv_idx, slope, has_bias, apply_mask), + kv_valid); + final_max = subgroupMax(max(final_max, softmax_term)); + } - var total_exp_term: f32 = 0.0; - // pass 2: compute exp sum and write P using final_max - for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { - let kv_idx = kv_offset + sg_inv_id; - let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask); - let cur_p = select(0.0, - exp(softmax_term - final_max), - kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); - total_exp_term += subgroupAdd(cur_p); - if (kv_idx < KV_TILE) { - inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p); - } + var total_exp_term: f32 = 0.0; + // pass 2: compute exp sum and write P using final_max + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, slope, has_bias, apply_mask); + let cur_p = select(0.0, + exp(softmax_term - final_max), + kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); + total_exp_term += subgroupAdd(cur_p); + if (kv_idx < KV_TILE) { + inter_shmem[kv_idx] = f16(cur_p); } + } - let cur_exp = exp(prev_max - final_max); + let cur_exp = exp(prev_max - final_max); - if (sg_inv_id == 0) { - row_max_shmem[q_tile_row] = final_max; - exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term; - } + row_max = final_max; + exp_sum = exp_sum * cur_exp + total_exp_term; - for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - let idx = q_tile_row * HEAD_DIM_V + elem_idx; - o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp); - } + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * cur_exp); } } @@ -562,15 +573,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3, workgroupBarrier(); if (!skip_tile) { - // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem + // we have P (KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem // we want to compute O += P * V across the full KV tile let ne_threads : u32 = VEC_NE; let nl_threads = max(1u, subgroup_size / ne_threads); let tx_pv = sg_inv_id % nl_threads; let ty_pv = sg_inv_id / nl_threads; - for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { + if (subgroup_id == 0u && q_row_start < params.seq_len_q) { for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) { var lo = vec4(0.0, 0.0, 0.0, 0.0); for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) { @@ -580,7 +589,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, continue; } - let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]); + let p = f32(inter_shmem[kv_idx]); #ifdef KV_DIRECT let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; let v4 = vec4(V[v_idx >> 2u]); @@ -621,11 +630,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (ty_pv == 0u) { let elem_base = vec_col * 4u; - let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base; - o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x); - o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y); - o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z); - o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w); + o_shmem[elem_base + 0u] = f16(f32(o_shmem[elem_base + 0u]) + lo_x); + o_shmem[elem_base + 1u] = f16(f32(o_shmem[elem_base + 1u]) + lo_y); + o_shmem[elem_base + 2u] = f16(f32(o_shmem[elem_base + 2u]) + lo_z); + o_shmem[elem_base + 3u] = f16(f32(o_shmem[elem_base + 3u]) + lo_w); } } } @@ -637,70 +645,46 @@ fn main(@builtin(workgroup_id) wg_id: vec3, #ifdef SINKS // Sinks are global terms and must be applied exactly once across split workgroups. - if (iwg == 0u) { - for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } - - var prev_max = row_max_shmem[q_tile_row]; - - // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum - let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0); - let new_max = subgroupMax(max(prev_max, sink_val)); - let max_exp = exp(prev_max - new_max); - let sink_exp = exp(sink_val - new_max); - - let sink_exp_sum = subgroupAdd(sink_exp); - - if (sg_inv_id == 0) { - row_max_shmem[q_tile_row] = new_max; - exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum; - } - - for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - let idx = q_tile_row * HEAD_DIM_V + elem_idx; - o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp); - } + if (iwg == 0u && subgroup_id == 0u && q_row_start < params.seq_len_q) { + var prev_max = row_max; + + // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum + let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0u); + let new_max = subgroupMax(max(prev_max, sink_val)); + let max_exp = exp(prev_max - new_max); + let sink_exp = exp(sink_val - new_max); + + let sink_exp_sum = subgroupAdd(sink_exp); + + row_max = new_max; + exp_sum = exp_sum * max_exp + sink_exp_sum; + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * max_exp); } - workgroupBarrier(); } + workgroupBarrier(); #endif let rows_per_batch = params.n_heads * params.seq_len_q; - for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { - - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { break; } - + if (subgroup_id == 0u && q_row_start < params.seq_len_q) { if (params.nwg == 1u) { - let exp_sum = exp_sum_shmem[q_tile_row]; let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); - let row_base: u32 = - params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V; + let row_base: u32 = params.offset_dst + batch_idx * dst3_stride + q_row_start * dst2_stride + + head_idx * HEAD_DIM_V; for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) { - let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); - let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); - let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); - let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); - let v = vec4( - f32(o_shmem[i0]) * scale, - f32(o_shmem[i1]) * scale, - f32(o_shmem[i2]) * scale, - f32(o_shmem[i3]) * scale + f32(o_shmem[elem_base + 0u]) * scale, + f32(o_shmem[elem_base + 1u]) * scale, + f32(o_shmem[elem_base + 2u]) * scale, + f32(o_shmem[elem_base + 3u]) * scale ); let dst_vec_index: u32 = (row_base + elem_base) >> 2u; dst[dst_vec_index] = v; } } else { - let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row; + let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + q_row_start; let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V; let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg; @@ -708,21 +692,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) { - let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); - let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); - let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); - let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); - let tbase = tmp_row_data_base + elem_base; - tmp[tbase + 0u] = f32(o_shmem[i0]); - tmp[tbase + 1u] = f32(o_shmem[i1]); - tmp[tbase + 2u] = f32(o_shmem[i2]); - tmp[tbase + 3u] = f32(o_shmem[i3]); + tmp[tbase + 0u] = f32(o_shmem[elem_base + 0u]); + tmp[tbase + 1u] = f32(o_shmem[elem_base + 1u]); + tmp[tbase + 2u] = f32(o_shmem[elem_base + 2u]); + tmp[tbase + 3u] = f32(o_shmem[elem_base + 3u]); } if (sg_inv_id == 0u) { - tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row]; - tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row]; + tmp[tmp_row_stats_base + 0u] = exp_sum; + tmp[tmp_row_stats_base + 1u] = row_max; } } } From c546b0b1bc9a0b6dde7b330986800bf8183eda14 Mon Sep 17 00:00:00 2001 From: Trivikram Reddy <127072883+trivikram-reddy1@users.noreply.github.com> Date: Fri, 24 Apr 2026 15:55:17 -0500 Subject: [PATCH 490/831] Hexagon: Bump HMX Frequency to Max Corner (llama/22334) * hexagon: bump HMX freq to max corner * hex-mm: fix error in log msg --- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 2 +- ggml/src/ggml-hexagon/htp/main.c | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index dbca8220fab..05e3c6c2b0f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -1683,7 +1683,7 @@ int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); - FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", __func__, m, k, n, weight_type, + FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); // initialize eye tile (32x32 identity matrix) diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index db277a25e5a..62942f6384c 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -101,6 +101,24 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { } } + { + // Set HMX clock + HAP_power_request_t request; + memset(&request, 0, sizeof(HAP_power_request_t)); + request.type = HAP_power_set_HMX_v2; + request.hmx_v2.set_clock = TRUE; + request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH; + FARF(ALWAYS, "Setting HMX clock\n"); + err = HAP_power_set((void *) &ctx, &request); + if (err != AEE_SUCCESS) { + FARF(ERROR, "Error setting HMX clock."); + return err; + } + } + return AEE_SUCCESS; } From c235b05d8a0044b771cc06d975128415810cc002 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Fri, 24 Apr 2026 23:18:15 -0700 Subject: [PATCH 491/831] ggml-webgpu: support for SSM_SCAN and disable set_rows error checking (llama/22327) * Implement ssm_scan * Remove blocking in graph_compute and check for set rows * Fix bindings * Update op support --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 72 ++++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 90 +++++++++- .../ggml-webgpu/wgsl-shaders/ssm_scan.wgsl | 168 ++++++++++++++++++ 3 files changed, 328 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index e492c2123a4..16ebc32cbc7 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -98,6 +98,29 @@ struct ggml_webgpu_ssm_conv_shader_decisions { uint32_t tokens_per_wg; }; +struct ggml_webgpu_ssm_scan_pipeline_key { + int type; + int d_state; + + bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const { + return type == other.type && d_state == other.d_state; + } +}; + +struct ggml_webgpu_ssm_scan_pipeline_key_hash { + size_t operator()(const ggml_webgpu_ssm_scan_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.d_state); + return seed; + } +}; + +struct ggml_webgpu_ssm_scan_shader_decisions { + uint32_t wg_size; + uint32_t tokens_per_tile; +}; + /** Argsort **/ struct ggml_webgpu_argsort_shader_lib_context { @@ -921,6 +944,8 @@ class ggml_webgpu_shader_lib { solve_tri_pipelines; // type std::unordered_map ssm_conv_pipelines; // type/vectorized + std::unordered_map + ssm_scan_pipelines; // type/d_state std::unordered_map @@ -1433,6 +1458,53 @@ class ggml_webgpu_shader_lib { return ssm_conv_pipelines[key]; } + webgpu_pipeline get_ssm_scan_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_ssm_scan_pipeline_key key = {}; + key.type = context.dst->type; + key.d_state = (int) context.src0->ne[0]; + + auto it = ssm_scan_pipelines.find(key); + if (it != ssm_scan_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "ssm_scan"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for ssm_scan shader"); + } + + const uint32_t wg_size = (uint32_t) key.d_state; + + constexpr uint32_t tokens_per_tile = 4u; + + defines.push_back("WG_SIZE=" + std::to_string(wg_size) + "u"); + defines.push_back("TOKENS_PER_TILE=" + std::to_string(tokens_per_tile) + "u"); + + if (context.supports_subgroups) { + defines.push_back("USE_SUBGROUP_REDUCTION"); + variant += "_sg_reduce"; + } else { + variant += "_wg_reduce"; + } + + variant += "_d" + std::to_string(key.d_state); + + auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + decisions->tokens_per_tile = tokens_per_tile; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + ssm_scan_pipelines[key] = pipeline; + return ssm_scan_pipelines[key]; + } + webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_gated_delta_net_pipeline_key key = {}; key.type = context.dst->type; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 7ed6fdd1625..bcec20c1a11 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1115,6 +1115,80 @@ static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } +static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * src6, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + + webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[2] / ggml_type_size(src2->type)), + + (uint32_t) src3->ne[0], + (uint32_t) (src3->nb[1] / ggml_type_size(src3->type)), + + (uint32_t) (src4->nb[1] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[2] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[3] / ggml_type_size(src4->type)), + + (uint32_t) (src5->nb[1] / ggml_type_size(src5->type)), + (uint32_t) (src5->nb[2] / ggml_type_size(src5->type)), + (uint32_t) (src5->nb[3] / ggml_type_size(src5->type)), + + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + (uint32_t) src4->ne[1], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], + (uint32_t) ggml_nelements(src1), + }; + + std::vector entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6), ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst), + }; + + const uint32_t total_wg = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]); + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + uint32_t wg_x; + uint32_t wg_y; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, @@ -2764,6 +2838,9 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, return ggml_webgpu_solve_tri(ctx, src0, src1, node); case GGML_OP_SSM_CONV: return ggml_webgpu_ssm_conv(ctx, src0, src1, node); + case GGML_OP_SSM_SCAN: + return ggml_webgpu_ssm_scan(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node->src[6], + node); case GGML_OP_GATED_DELTA_NET: return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node); case GGML_OP_PAD: @@ -2822,7 +2899,10 @@ static void ggml_backend_webgpu_collect_profile_results(webgpu_context & } #endif +// Don't bother checking set_rows index overflow for now, since practically the WebGPU doesn't need to support +// models that would require it right now. static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t & num_inflight_batches) { +#ifdef GGML_WEBGPU_CHECK_SET_ROWS wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, ctx->set_rows_host_error_buf.GetSize()); @@ -2835,6 +2915,10 @@ static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t & GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); } ctx->set_rows_host_error_buf.Unmap(); +#else + GGML_UNUSED(ctx); + GGML_UNUSED(num_inflight_batches); +#endif } static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { @@ -2920,8 +3004,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str ggml_backend_webgpu_check_set_rows(ctx, num_inflight_batches); } - ggml_backend_webgpu_wait_queue(ctx->global_ctx); - WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -3941,6 +4023,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SSM_CONV: supports_op = op->type == GGML_TYPE_F32; break; + case GGML_OP_SSM_SCAN: + supports_op = op->type == GGML_TYPE_F32 && + src0->ne[0] <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + break; case GGML_OP_GATED_DELTA_NET: { const uint32_t s_v = (uint32_t) src2->ne[0]; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl new file mode 100644 index 00000000000..64324738591 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl @@ -0,0 +1,168 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif + +struct Params { + offset_s: u32, + offset_x: u32, + offset_dt: u32, + offset_A: u32, + offset_B: u32, + offset_C: u32, + offset_ids: u32, + offset_dst: u32, + + stride_s1: u32, + stride_s2: u32, + stride_s3: u32, + + stride_x1: u32, + stride_x2: u32, + stride_x3: u32, + + stride_dt1: u32, + stride_dt2: u32, + + a_ne0: u32, + stride_A1: u32, + + stride_B1: u32, + stride_B2: u32, + stride_B3: u32, + + stride_C1: u32, + stride_C2: u32, + stride_C3: u32, + + d_state: u32, + d_inner: u32, + n_head: u32, + n_group: u32, + n_seq_tokens: u32, + n_seqs: u32, + + y_elems: u32, +}; + +@group(0) @binding(0) var s_in: array; +@group(0) @binding(1) var x: array; +@group(0) @binding(2) var dt: array; +@group(0) @binding(3) var A: array; +@group(0) @binding(4) var B: array; +@group(0) @binding(5) var C: array; +@group(0) @binding(6) var ids: array; +@group(0) @binding(7) var dst: array; +@group(0) @binding(8) var params: Params; + +var shared_x_dt: array; +var shared_dtsp: array; +var shared_reduce: array; + +fn reduce_base(token_in_tile: u32) -> u32 { + return token_in_tile * WG_SIZE; +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3 +#ifdef USE_SUBGROUP_REDUCTION + , @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, + @builtin(num_subgroups) num_subgroups: u32 +#endif +) { + let tid = local_id.x; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + + let i1 = wg_linear % params.d_inner; + let head_seq = wg_linear / params.d_inner; + let ir = head_seq % params.n_head; + let i3 = head_seq / params.n_head; + + let state_slot = u32(ids[params.offset_ids + i3]); + let g = ir / (params.n_head / params.n_group); + + let s_idx = params.offset_s + tid + i1 * params.stride_s1 + ir * params.stride_s2 + state_slot * params.stride_s3; + var s_prev = s_in[s_idx]; + + let A0 = A[params.offset_A + (tid % params.a_ne0) + ir * params.stride_A1]; + + for (var token_base = 0u; token_base < params.n_seq_tokens; token_base += TOKENS_PER_TILE) { + if (tid < TOKENS_PER_TILE) { + let token = token_base + tid; + if (token < params.n_seq_tokens) { + let x_idx = params.offset_x + i1 + ir * params.stride_x1 + token * params.stride_x2 + i3 * params.stride_x3; + let dt_idx = params.offset_dt + ir + token * params.stride_dt1 + i3 * params.stride_dt2; + let dt0 = dt[dt_idx]; + let dtsp = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0); + shared_dtsp[tid] = dtsp; + shared_x_dt[tid] = x[x_idx] * dtsp; + } + } + + workgroupBarrier(); + + for (var token_in_tile = 0u; token_in_tile < TOKENS_PER_TILE; token_in_tile++) { + let token = token_base + token_in_tile; + if (token >= params.n_seq_tokens) { + break; + } + + let x_dt = shared_x_dt[token_in_tile]; + let dA = exp(shared_dtsp[token_in_tile] * A0); + let reduce_idx = reduce_base(token_in_tile) + tid; + + let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3; + let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3; + let s = s_prev * dA + B[b_idx] * x_dt; + s_prev = s; + +#ifdef USE_SUBGROUP_REDUCTION + let subgroup_partial = subgroupAdd(s * C[c_idx]); + if (subgroup_invocation_id == 0u) { + shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial; + } +#else + shared_reduce[reduce_idx] = s * C[c_idx]; +#endif + + workgroupBarrier(); + +#ifdef USE_SUBGROUP_REDUCTION + if (tid == 0u) { + var sum = 0.0; + for (var sg = 0u; sg < num_subgroups; sg++) { + sum += shared_reduce[reduce_base(token_in_tile) + sg]; + } + let y_idx = + params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) + + i3 * (params.n_seq_tokens * params.n_head * params.d_inner); + dst[y_idx] = sum; + } +#else + for (var stride = WG_SIZE / 2u; stride > 0u; stride >>= 1u) { + if (tid < stride) { + shared_reduce[reduce_idx] += shared_reduce[reduce_idx + stride]; + } + workgroupBarrier(); + } + + if (tid == 0u) { + let y_idx = + params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) + + i3 * (params.n_seq_tokens * params.n_head * params.d_inner); + dst[y_idx] = shared_reduce[reduce_base(token_in_tile)]; + } +#endif + + workgroupBarrier(); + } + } + + let state_idx = + params.offset_dst + params.y_elems + tid + i1 * params.d_state + ir * (params.d_state * params.d_inner) + + i3 * (params.d_state * params.d_inner * params.n_head); + dst[state_idx] = s_prev; +} From 6296fd5a904edbd9785a9e8e06d38564e3c70b49 Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Sat, 25 Apr 2026 14:20:14 +0800 Subject: [PATCH 492/831] Optimize Q4_0 mul_mat for Arc770, add scripts (llama/22291) * opt arc770 for Q4_0 * add for Q4_0 * update the script * add help script for windows * update guide * fix format issue * convert from dos to unix for format issue * fix missed -sm parameter --- ggml/src/ggml-sycl/common.hpp | 2 +- ggml/src/ggml-sycl/ggml-sycl.cpp | 10 ++++- ggml/src/ggml-sycl/sycl_hw.cpp | 72 +++++++++++++++++++++++++++----- ggml/src/ggml-sycl/sycl_hw.hpp | 24 ++++++++--- 4 files changed, 90 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 0101b27640a..5abf2290651 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -224,7 +224,7 @@ struct sycl_device_info { // cudaOccupancyMaxActiveBlocksPerMultiprocessor bool vmm; // virtual memory support size_t total_vram; - //sycl_hw_info hw_info; \\ device id and aarch, currently not used + sycl_hw_info hw_info; optimize_feature opt_feature; }; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 36923160d72..1eead625e76 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -104,6 +104,7 @@ static ggml_sycl_device_info ggml_sycl_init() { info.max_work_group_sizes[i] = prop.get_max_work_group_size(); info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units(); + info.devices[i].hw_info = get_device_hw_info(&device); } @@ -3703,9 +3704,16 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization // is enabled takes precedence over DMMV, the current if-else implementation // requires disabling DMMV if both conditions are met + if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) { - use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; + // Arc770 get benefit with Q4_0 by skipping it. + if (!(ggml_sycl_info().devices[ctx.device].hw_info.arch == + gpu_arch::intel_gpu_acm_g10 && + src0->type == GGML_TYPE_Q4_0)) { + use_dequantize_mul_mat_vec = + use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; + } } if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { diff --git a/ggml/src/ggml-sycl/sycl_hw.cpp b/ggml/src/ggml-sycl/sycl_hw.cpp index 7041140034b..03b0c37a3cd 100644 --- a/ggml/src/ggml-sycl/sycl_hw.cpp +++ b/ggml/src/ggml-sycl/sycl_hw.cpp @@ -1,15 +1,67 @@ #include "sycl_hw.hpp" -// TODO: currently not used -/* -sycl_hw_info get_device_hw_info(sycl::device *device_ptr) { - sycl_hw_info res; - int32_t id = device_ptr->get_info(); - res.device_id = id; +using namespace std; - syclex::architecture arch = device_ptr->get_info(); - res.arch = arch; +/*defined in +* /opt/intel/oneapi/compiler/latest/include/sycl/ext/oneapi/experimental/device_architecture.def +*/ +static map> arch2name = { + {gpu_arch::intel_gpu_bdw, {"intel_gpu_bdw", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_skl, {"intel_gpu_skl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_kbl, {"intel_gpu_kbl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_cfl, {"intel_gpu_cfl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_apl, {"intel_gpu_apl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_glk, {"intel_gpu_glk", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_whl, {"intel_gpu_whl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_aml, {"intel_gpu_aml", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_cml, {"intel_gpu_cml", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_icllp, {"intel_gpu_icllp", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_ehl, {"intel_gpu_ehl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_tgllp, {"intel_gpu_tgllp", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_rkl, {"intel_gpu_rkl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_adl_s, {"intel_gpu_adl_s", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_adl_p, {"intel_gpu_adl_p", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_adl_n, {"intel_gpu_adl_n", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_dg1, {"intel_gpu_dg1", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_acm_g10, {"intel_gpu_acm_g10", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_acm_g11, {"intel_gpu_acm_g11", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_acm_g12, {"intel_gpu_acm_g12", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_pvc, {"intel_gpu_pvc", GPU_FAMILY_DGPU_CLOUD}}, + {gpu_arch::intel_gpu_pvc_vg, {"intel_gpu_pvc_vg", GPU_FAMILY_DGPU_CLOUD}}, + {gpu_arch::intel_gpu_mtl_u, {"intel_gpu_mtl_u", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_mtl_h, {"intel_gpu_mtl_h", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_arl_h, {"intel_gpu_arl_h", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_bmg_g21, {"intel_gpu_bmg_g21", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_bmg_g31, {"intel_gpu_bmg_g31", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_lnl_m, {"intel_gpu_lnl_m", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_ptl_h, {"intel_gpu_ptl_h", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_ptl_u, {"intel_gpu_ptl_u", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_wcl, {"intel_gpu_wcl", GPU_FAMILY_IGPU_XE}} +}; + + +sycl_hw_info get_device_hw_info(sycl::device* device_ptr) { + sycl_hw_info res; + int32_t id = + device_ptr->get_info(); + res.device_id = id; + + res.name = device_ptr->get_info(); - return res; + syclex::architecture arch = + device_ptr->get_info(); + res.arch = arch; + + map>::iterator it = + arch2name.find(res.arch); + if (it != arch2name.end()) { + res.arch_name = it->second.first; + res.gpu_family = it->second.second; + } else { + res.arch_name = "unknown"; + res.gpu_family = GPU_FAMILY_UKNOWN; + } + + return res; } -*/ diff --git a/ggml/src/ggml-sycl/sycl_hw.hpp b/ggml/src/ggml-sycl/sycl_hw.hpp index 36b140bf037..a5d20462572 100644 --- a/ggml/src/ggml-sycl/sycl_hw.hpp +++ b/ggml/src/ggml-sycl/sycl_hw.hpp @@ -9,18 +9,30 @@ #include namespace syclex = sycl::ext::oneapi::experimental; +using gpu_arch = sycl::ext::oneapi::experimental::architecture; + +// It's used to mark the GPU computing capacity +// The value must flow the order of performance. +enum sycl_intel_gpu_family { + GPU_FAMILY_UKNOWN = -1, + // iGPU without Xe core, before Meteor Lake iGPU(Xe) + GPU_FAMILY_IGPU_NON_XE = 0, + // iGPU with Xe core, Meteor Lake iGPU or newer. + GPU_FAMILY_IGPU_XE = 1, + // dGPU for gaming in client/data center (DG1/FLex 140 or newer). + GPU_FAMILY_DGPU_CLIENT_GAME = 2, + // dGPU for AI in cloud, PVC or newer. + GPU_FAMILY_DGPU_CLOUD = 3 +}; -// TODO: currently not used -/* struct sycl_hw_info { syclex::architecture arch; + const char* arch_name; int32_t device_id; + std::string name; + sycl_intel_gpu_family gpu_family; }; -bool is_in_vector(std::vector &vec, int item); - sycl_hw_info get_device_hw_info(sycl::device *device_ptr); -*/ - #endif // SYCL_HW_HPP From 21da84303e9cea074f16850e9a2573f68b75b48f Mon Sep 17 00:00:00 2001 From: Developer-Ecosystem-Engineering <65677710+Developer-Ecosystem-Engineering@users.noreply.github.com> Date: Sat, 25 Apr 2026 05:14:28 -0700 Subject: [PATCH 493/831] metal : optimize Metal Tensor API usage for GGML_OP_MUL_MAT (llama/20962) * Optimize Metal Tensor API usage for matmul2d Separates the Metal Tensor API (matmul2d) path in kernel_mul_mm into its own standalone kernel, gated by GGML_METAL_HAS_TENSOR. The legacy simdgroup_matrix kernel is preserved under #else. Previously both paths were interleaved via #ifdef blocks within a single kernel, forcing the tensor path to share the legacy kernel's data layout and threadgroup memory scheme. Splitting the kernel enabled memory and dispatch optimizations that weren't possible when the two paths shared code structure. * cont : cleanup * cont : cleanup * cont : cleanup --------- Co-authored-by: Georgi Gerganov --- ggml/src/ggml-metal/ggml-metal-device.cpp | 26 ++- ggml/src/ggml-metal/ggml-metal-device.h | 2 + ggml/src/ggml-metal/ggml-metal-device.m | 17 +- ggml/src/ggml-metal/ggml-metal-impl.h | 13 ++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 7 +- ggml/src/ggml-metal/ggml-metal.metal | 233 +++++++++++++--------- 6 files changed, 189 insertions(+), 109 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 07d016d2227..d211bf79f14 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -677,7 +677,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta const ggml_type tsrc1 = op->src[1]->type; const bool bc_inp = op->src[0]->ne[0] % 32 != 0; - const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0; + + constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y; + constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; + + const bool has_tensor = ggml_metal_device_get_props(ggml_metal_library_get_device(lib))->has_tensor; + + const bool bc_out = has_tensor + ? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0) + : (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0); snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); @@ -694,8 +702,20 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ggml_metal_cv_free(cv); } - // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes - res.smem = bc_out ? 8192 : 4096 + 2048; + if (has_tensor) { + res.nr0 = NRA; + res.nr1 = NRB; + + const size_t smem_a = NRA * N_MM_NK_TOTAL * sizeof(ggml_fp16_t); + res.smem = smem_a; + } else { + res.nr0 = 64; + res.nr1 = 32; + + res.smem = bc_out ? 8192 : (4096 + 2048); + } + + res.nsg = N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y; return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index b423501358e..a6c1dab5515 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -102,6 +102,8 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev void ggml_metal_library_free(ggml_metal_library_t lib); +ggml_metal_device_t ggml_metal_library_get_device(ggml_metal_library_t lib); + struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 27b78c5e6d7..fe90aafe7bc 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -95,8 +95,8 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_wi struct ggml_metal_library { id obj; - id device; + ggml_metal_device_t dev; ggml_metal_pipelines_t pipelines; // cache of compiled pipelines NSLock * lock; @@ -251,7 +251,7 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library)); res->obj = library; - res->device = device; + res->dev = dev; res->pipelines = ggml_metal_pipelines_init(); res->lock = [NSLock new]; @@ -318,7 +318,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev } res->obj = library; - res->device = device; + res->dev = dev; res->pipelines = ggml_metal_pipelines_init(); res->lock = [NSLock new]; @@ -341,6 +341,10 @@ void ggml_metal_library_free(ggml_metal_library_t lib) { free(lib); } +ggml_metal_device_t ggml_metal_library_get_device(ggml_metal_library_t lib) { + return lib->dev; +} + struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { [lib->lock lock]; @@ -405,7 +409,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_ return res; } - id obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; + id device = ggml_metal_device_get_obj(lib->dev); + id obj = [device newComputePipelineStateWithFunction:mtl_function error:&error]; [mtl_function release]; @@ -699,7 +704,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { " auto sB = tB.slice(0, 0); \n" " mm.run(sB, sA, cT); \n" " \n" - " auto tC = tensor, tensor_inline>(C, dextents(4, 4)); \n" + " auto tC = tensor, tensor_inline>(C, dextents(16, 16)); \n" " \n" " cT.store(tC); \n" "}"; @@ -749,7 +754,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { " auto sB = tB.slice(0, 0); \n" " mm.run(sB, sA, cT); \n" " \n" - " auto tC = tensor, tensor_inline>(C, dextents(4, 4)); \n" + " auto tC = tensor, tensor_inline>(C, dextents(16, 16)); \n" " \n" " cT.store(tC); \n" "}"; diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 379a8b33a14..ff74cafb5b7 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -1,6 +1,19 @@ #ifndef GGML_METAL_IMPL #define GGML_METAL_IMPL +// kernel parameters for mat-mat threadgroups +// +// TODO: become function constants + +#define SZ_SIMDGROUP 16 +#define N_MM_NK 2 +#define N_MM_NK_TOTAL (SZ_SIMDGROUP * N_MM_NK) + +#define N_MM_BLOCK_X 4 +#define N_MM_BLOCK_Y 2 +#define N_MM_SIMD_GROUP_X 2 +#define N_MM_SIMD_GROUP_Y 2 + // kernel parameters for mat-vec threadgroups // // N_R0: number of src0 rows to process per simdgroup diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e173527909a..5fa162c875c 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2195,7 +2195,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { const size_t smem = pipeline.smem; ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1); + + const int nr0 = pipeline.nr0; + const int nr1 = pipeline.nr1; + const int nsg = pipeline.nsg; + + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + nr1 - 1) / nr1), ((ne01 + nr0 - 1) / nr0), ne12 * ne13, 32, nsg, 1); } else { auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9f38c9d2968..c372eaedeae 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -9306,7 +9306,137 @@ constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; // each block_q contains 16*nl weights -template +#ifdef GGML_METAL_HAS_TENSOR +template< + typename SA, typename SA_4x4, typename SA_8x8, + typename SB, typename SB_2x4, typename SB_8x8, + typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &), + typename T0, typename T0_4x4, typename T1, typename T1_2x4> +kernel void kernel_mul_mm( + constant ggml_metal_kargs_mul_mm & args, + device const char * srcA, + device const char * srcB, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + ushort tiitg [[thread_index_in_threadgroup]], + ushort sgitg [[simdgroup_index_in_threadgroup]]) { + (void) sgitg; + + // Matrix dimensions: A(M,K) x B(K,N) -> C(M,N) + const int K = args.ne00; + const int M = args.ne0; + const int N = args.ne1; + + // Batch dimension handling + const int im = tgpig.z; + const int i12 = im % args.ne12; + const int i13 = im / args.ne12; + + // Batch offsets for srcA and srcB + const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + // Tile dimensions + constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; + constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y; + + // Tile offsets in output matrix + const int ra = tgpig.y * NRA; + const int rb = tgpig.x * NRB; + + // Threadgroup memory for dequantized A tile only + threadgroup SA * sa = (threadgroup SA *)(shmem); + + // Work-item count for A loading + constexpr int A_WORK_ITEMS = NRA * N_MM_NK; + constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y; + + // tA wraps threadgroup memory + auto tA = tensor(sa, dextents(N_MM_NK_TOTAL, NRA)); + + // tB wraps device memory directly + device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13); + const int strideB = args.nb11 / sizeof(T1); + auto tB = tensor(ptrB, dextents(K, N), array({1, strideB})); + + // Configure matmul operation + mpp::tensor_ops::matmul2d< + mpp::tensor_ops::matmul2d_descriptor( + NRB, NRA, N_MM_NK_TOTAL, false, true, true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups> mm; + + auto cT = mm.get_destination_cooperative_tensor(); + + // Accumulate partial results over K dimension + for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) { + // === PHASE 1: Dequantization of A into threadgroup memory === + for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) { + const int row = work / N_MM_NK; + const int k_chunk = work % N_MM_NK; + const int k_pos = loop_k + k_chunk * 16; + const short k_base = k_chunk * 16; + + // Bounds check: skip device read if row is out of matrix bounds + if (ra + row < M) { + if (is_same::value && FC_mul_mm_bc_inp) { + // Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4). + // MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd, + // nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned. + // Mirrors the legacy kernel's existing guard. + device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0); + + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0; + } + } else { + const int block_idx = k_pos / (16 * nl); + const short il = (k_pos / 16) % nl; + + device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0); + + SA_4x4 temp_a; + dequantize_func(row_ptr + block_idx, il, temp_a); + + FOR_UNROLL (short i = 0; i < 16; i++) { + // Zero-pad A for K positions beyond valid range (handles partial K iterations) + sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0; + } + } + } else { + // Zero-pad rows beyond matrix bounds + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // === PHASE 2: Tensor matmul === + auto mA = tA.slice(0, 0); + auto mB = tB.slice(loop_k, rb); + + mm.run(mB, mA, cT); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Store result tile to output matrix (with batch offset) + // cT.store handles bounds checking via tD's extents (M, N) + device float * dstBatch = (device float *)dst + im * N * M; + + auto tD = tensor(dstBatch, dextents(M, N), array({1, M})); + cT.store(tD.slice(ra, rb)); +} + +#else + +template< + typename S0, typename S0_4x4, typename S0_8x8, + typename S1, typename S1_2x4, typename S1_8x8, + typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), + typename T0, typename T0_4x4, typename T1, typename T1_2x4> kernel void kernel_mul_mm( constant ggml_metal_kargs_mul_mm & args, device const char * src0, @@ -9320,10 +9450,6 @@ kernel void kernel_mul_mm( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); -#ifdef GGML_METAL_HAS_TENSOR - threadgroup float * sc = (threadgroup float *)(shmem); -#endif - constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -9363,7 +9489,6 @@ kernel void kernel_mul_mm( + args.nb11*(r1 + lr1) + args.nb10*iy); -#ifndef GGML_METAL_HAS_TENSOR S0_8x8 ma[4]; S1_8x8 mb[2]; @@ -9372,19 +9497,8 @@ kernel void kernel_mul_mm( for (short i = 0; i < 8; i++){ mc[i] = make_filled_simdgroup_matrix(0.f); } -#else - auto tA = tensor, tensor_inline>(sa, dextents(NK, NR0)); - auto tB = tensor, tensor_inline>(sb, dextents(NR1, NK )); - - mpp::tensor_ops::matmul2d< - mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), - execution_simdgroups<4>> mm; - - auto cT = mm.get_destination_cooperative_tensor(); -#endif for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { -#ifndef GGML_METAL_HAS_TENSOR // load data and store to threadgroup memory if (is_same::value && FC_mul_mm_bc_inp) { threadgroup_barrier(mem_flags::mem_threadgroup); @@ -9454,66 +9568,6 @@ kernel void kernel_mul_mm( *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y)); } -#else - // load data and store to threadgroup memory - if (is_same::value && FC_mul_mm_bc_inp) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // no need for dequantization - for (short i = 0; i < 16; i++) { - const short sx = 2*il0 + i/8; - const short sy = (tiitg/NL0)/8; - - const short lx = i%8; - const short ly = (tiitg/NL0)%8; - //const short lx = (tiitg/NL0)%8; - //const short ly = i%8; - - *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; - } - } else { - S0_4x4 temp_a; - dequantize_func(x, il, temp_a); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - FOR_UNROLL (short i = 0; i < 16; i++) { - const short sx = 2*il0 + i/8; - const short sy = (tiitg/NL0)/8; - - const short lx = i%8; - const short ly = (tiitg/NL0)%8; - //const short lx = (tiitg/NL0)%8; - //const short ly = i%8; - - *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4]; - } - } - - if (FC_mul_mm_bc_inp) { - for (short i = 0; i < 8; ++i) { - const short sx = (tiitg%NL1); - const short sy = (tiitg/NL1)/8; - - const short lx = i; - const short ly = (tiitg/NL1)%8; - //const short lx = (tiitg/NL1)%8; - //const short ly = i; - - *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; - } - } else { - const short sx = (tiitg%NL1); - const short sy = (tiitg/NL1)/8; - - //const short lx = i; - const short ly = (tiitg/NL1)%8; - //const short lx = (tiitg/NL1)%8; - //const short ly = i; - - *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y)); - } -#endif il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2 + nl - 1)/nl : x; @@ -9522,7 +9576,6 @@ kernel void kernel_mul_mm( threadgroup_barrier(mem_flags::mem_threadgroup); -#ifndef GGML_METAL_HAS_TENSOR // load matrices from threadgroup memory and conduct outer products threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2)); threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2)); @@ -9549,24 +9602,10 @@ kernel void kernel_mul_mm( lsma += 8*64; lsmb += 4*64; } -#else - auto sA = tA.slice(0, 0); - auto sB = tB.slice(0, 0); - - mm.run(sB, sA, cT); -#endif } if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) { // if no bounds checks on the output are needed, we can directly write to device memory -#ifdef GGML_METAL_HAS_TENSOR - device float * C = (device float *) dst + - r0 + \ - r1 * args.ne0 + im*args.ne1*args.ne0; - - auto tC = tensor, tensor_inline>(C, dextents(args.ne0, NR1)); - cT.store(tC); -#else device float * C = (device float *) dst + (r0 + 32*(sgitg & 1)) + \ (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; @@ -9574,21 +9613,15 @@ kernel void kernel_mul_mm( for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false); } -#endif } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0; -#ifdef GGML_METAL_HAS_TENSOR - auto tC = tensor, tensor_inline>(sc, dextents(NR0, NR1)); - cT.store(tC); -#else for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false); } -#endif threadgroup_barrier(mem_flags::mem_threadgroup); @@ -9614,6 +9647,8 @@ kernel void kernel_mul_mm( } } +#endif // GGML_METAL_HAS_TENSOR + template // n_expert_used kernel void kernel_mul_mm_id_map0( constant ggml_metal_kargs_mul_mm_id_map0 & args, @@ -9789,7 +9824,7 @@ kernel void kernel_mul_mm_id( const short ib = 8*sx + sy; - *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; + *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0; } } else { S0_4x4 temp_a; From da738a74f56248a3488bf9f54dfd2da67abe1196 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 25 Apr 2026 14:15:03 +0200 Subject: [PATCH 494/831] CUDA: reduce MMQ stream-k overhead (llama/22298) * CUDA: reduce MMQ stream-k overhead * use 32 bit integers for kbc --- ggml/src/ggml-cuda/mmq.cuh | 277 ++++++++++++++++++------------------- 1 file changed, 138 insertions(+), 139 deletions(-) diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index b1a319de9be..91a1b737a82 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -3478,10 +3478,10 @@ template static __global__ void mul_mat_q( const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup, - const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, - const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - const int ncols_max) { + const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, + const uint3 channel_ratio, const uint3 nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const uint3 sample_ratio, const uint3 nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const uint3 ntx) { // Skip unused template specializations for faster compilation: if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { @@ -3495,8 +3495,7 @@ static __global__ void mul_mat_q( constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); - const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x - const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y + const uint32_t nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y // Initialize the ids for writing back data with just the index. // For regular matrix multiplications this is never changed. @@ -3517,8 +3516,9 @@ static __global__ void mul_mat_q( // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: #if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA { - const int wt = blockIdx.z / nchannels_y; - const int zt = blockIdx.z - wt*nchannels_y; + const uint2 tmp2 = fast_div_modulo(blockIdx.z, nchannels_y); + const int wt = tmp2.x; + const int zt = tmp2.y; const int jt = blockIdx.y; const int it = blockIdx.x; @@ -3561,40 +3561,40 @@ static __global__ void mul_mat_q( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, - tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); + tile_x_max_i, tile_y_max_j, 0, blocks_per_ne00.z); return; } #endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA - constexpr int ITER_K = get_iter_k(type); - - const int64_t blocks_per_ne00 = ncols_x / qk; - constexpr int blocks_per_iter = ITER_K / qk; + constexpr int ITER_K = get_iter_k(type); + constexpr int blocks_per_iter = ITER_K / qk; // kbc == k block continuous, current index in continuous ijk space. - int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int kbc = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; + int kbc_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; - kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; - kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter; + kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter; + kbc_stop -= fastmodulo(kbc_stop, blocks_per_ne00) % blocks_per_iter; // kb0 == k index when doing the matrix multiplication for an output tile. - int kb0_start = kbc % blocks_per_ne00; - int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc); - while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) { - int tmp = kbc; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; + int kb0_start = fastmodulo(kbc, blocks_per_ne00); + int kb0_stop = min(blocks_per_ne00.z, uint32_t(kb0_start + kbc_stop - kbc)); + while (kbc < kbc_stop && kb0_stop == int(blocks_per_ne00.z)) { + int tmp = fastdiv(kbc, blocks_per_ne00); + uint2 tmp2 = fast_div_modulo(tmp, ntx); + const int jt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nchannels_y); + const int zt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nsamples_y); + const int wt = tmp2.y; + const int it = tmp2.x; // Defaults for regular matrix multiplication: int col_low = 0; @@ -3612,11 +3612,11 @@ static __global__ void mul_mat_q( offset_dst = 0; if (jt*mmq_x >= col_diff) { - kbc += blocks_per_ne00; - kbc -= kbc % blocks_per_ne00; + kbc += blocks_per_ne00.z; + kbc -= fastmodulo(kbc, blocks_per_ne00); kb0_start = 0; - kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc)); continue; } @@ -3641,32 +3641,34 @@ static __global__ void mul_mat_q( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); - kbc += blocks_per_ne00; - kbc -= kbc % blocks_per_ne00; + kbc += blocks_per_ne00.z; + kbc -= fastmodulo(kbc, blocks_per_ne00); kb0_start = 0; - kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc)); } if (kbc >= kbc_stop) { return; } - int tmp = kbc; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; + int tmp = fastdiv(kbc, blocks_per_ne00); + uint2 tmp2 = fast_div_modulo(tmp, ntx); + const int jt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nchannels_y); + const int zt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nsamples_y); + const int wt = tmp2.y; + const int it = tmp2.x; // Defaults for regular matrix multiplication: int col_low = 0; @@ -3708,7 +3710,7 @@ static __global__ void mul_mat_q( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. mul_mat_q_process_tile @@ -3717,46 +3719,37 @@ static __global__ void mul_mat_q( } template -static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, - const int32_t * expert_bounds, - float * __restrict__ dst, - const float * __restrict__ tmp_last_tile, - const int ncols_x, - const int nrows_x, - const int ncols_dst, - const size_t stride_col_dst, - const int nchannels_y, - const size_t stride_channel_dst, - const int nsamples_y, - const size_t stride_sample_dst, - const int ncols_max) { - constexpr int mmq_y = get_mmq_y_device(); - constexpr int qk = ggml_cuda_type_traits::qk; - constexpr int ITER_K = get_iter_k(type); - - constexpr int blocks_per_iter = ITER_K / qk; - const int64_t blocks_per_ne00 = ncols_x / qk; +__launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device()/2, 1) +static __global__ void mul_mat_q_stream_k_fixup( + const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, + float * __restrict__ tmp_last_tile, const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, + const int stride_col_dst, const uint3 nchannels_y, const int stride_channel_dst, const uint3 nsamples_y, + const int stride_sample_dst, const uint3 ntx) { + constexpr int mmq_y = get_mmq_y_device(); + constexpr int qk = ggml_cuda_type_traits::qk; + constexpr int ITER_K = get_iter_k(type); + constexpr int blocks_per_iter = ITER_K / qk; - constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int nwarps = mmq_get_nwarps_device()/2; constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; + float sum[mmq_x / nwarps] = {0.0f}; + const int i = blockIdx.y*warp_size + threadIdx.x; - const int ntx = (ncols_max + mmq_x - 1) / mmq_x; - const int nty = (nrows_x + mmq_y - 1) / mmq_y; + const int nty = (nrows_x + mmq_y - 1) / mmq_y; const int bidx0 = blockIdx.x; // kbc == k block continuous, current index in continuous ijk space. - int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int kbc0 = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; + int kbc0_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; - kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter; - kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter; + kbc0 -= fastmodulo(kbc0, blocks_per_ne00) % blocks_per_iter; + kbc0_stop -= fastmodulo(kbc0_stop, blocks_per_ne00) % blocks_per_iter; const bool did_not_have_any_data = kbc0 == kbc0_stop; - const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0; - const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0; + const bool wrote_beginning_of_tile = fastmodulo(kbc0, blocks_per_ne00) == 0; + const bool did_not_write_last = fastdiv(kbc0, blocks_per_ne00) == fastdiv(kbc0_stop, blocks_per_ne00) && fastmodulo(kbc0_stop, blocks_per_ne00) != 0; if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { return; } @@ -3765,11 +3758,11 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, // Iterate over previous blocks and sum up partial sums written to fixup buffer. // All CUDA blocks that get here must have a previous block that needs a fixup. - int64_t bidx = bidx0 - 1; - int64_t kbc_stop = kbc0; + int bidx = bidx0 - 1; + int kbc_stop = kbc0; while(true) { - int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + int kbc = int64_t(bidx)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; + kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter; if (kbc == kbc_stop) { // Did not have any data. bidx--; @@ -3779,20 +3772,16 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, any_fixup = true; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; - } + sum[j0/nwarps] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; } // If this block started in a previous tile we are done and don't need to combine additional partial results. - if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) { + if (fastmodulo(kbc, blocks_per_ne00) == 0 || fastdiv(kbc, blocks_per_ne00) < fastdiv(kbc0, blocks_per_ne00)) { break; } bidx--; @@ -3803,14 +3792,16 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, return; } - int tmp = kbc0; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; + int tmp = fastdiv(kbc0, blocks_per_ne00); + uint2 tmp2 = fast_div_modulo(tmp, ntx); + const int jt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nchannels_y); + const int zt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nsamples_y); + const int wt = tmp2.y; + const int it = tmp2.x; if (!ids_dst) { const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y; @@ -3818,6 +3809,9 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, const int i_max = nrows_x - it*mmq_y - 1; const int j_max = ncols_dst - jt*mmq_x - 1; + if (need_check && i > i_max) { + return; + } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -3827,16 +3821,7 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, return; } -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - if (need_check && i > i_max) { - continue; - } - - dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; - } + dst[j*stride_col_dst + i] += sum[j0/nwarps]; } return; } @@ -3856,6 +3841,9 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, const int i_max = nrows_x - it*mmq_y - 1; const int j_max = col_diff - jt*mmq_x - 1; + if (need_check && i > i_max) { + return; + } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -3865,16 +3853,7 @@ static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst, return; } -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - if (need_check && i > i_max) { - continue; - } - - dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; - } + dst[ids_dst_shared[j]*stride_col_dst + i] += sum[j0/nwarps]; } } @@ -3922,29 +3901,44 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int channel_ratio = args.nchannels_y / args.nchannels_x; const int sample_ratio = args.nsamples_y / args.nsamples_x; + const uint3 blocks_per_ne00_fd = init_fastdiv_values(args.ncols_x / ggml_cuda_type_traits::qk); + const uint3 ntx_fd = init_fastdiv_values(ntx); + const uint3 nchannels_y_fd = init_fastdiv_values(args.nchannels_y); + const uint3 nsamples_y_fd = init_fastdiv_values(args.nsamples_y); + const uint3 channel_ratio_fd = init_fastdiv_values(channel_ratio); + const uint3 sample_ratio_fd = init_fastdiv_values(sample_ratio); + if (!args.use_stream_k) { if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); } else { constexpr bool need_check = true; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); } return; } - const dim3 block_nums_stream_k(nsm, 1, 1); - const bool fixup_needed = ntx*nty*ntzw % nsm != 0; + // For the stream-k kernel it is possible to run it with tiling by setting the number of CUDA blocks equal to the number of tiles. + // This is worthwhile if the efficiency of tiling is high and skipping the fixup kernel is more important. + const int ntiles_dst = ntx * nty * ntzw; + const int tiles_nwaves = (ntiles_dst + nsm - 1) / nsm; + const int tiles_efficiency_percent = 100 * ntiles_dst / (nsm*tiles_nwaves); + const dim3 block_nums_stream_k(GGML_CUDA_CC_IS_NVIDIA(cc) && tiles_efficiency_percent >= 90 ? ntiles_dst : nsm, 1, 1); + + GGML_ASSERT(ntiles_dst * blocks_per_ne00_fd.z < (1 << 30)); // Assert that variable kbc will not overflow. + + const bool fixup_needed = ntiles_dst % block_nums_stream_k.x != 0; ggml_cuda_pool & pool = ctx.pool(id); ggml_cuda_pool_alloc tmp_fixup(pool); @@ -3952,40 +3946,45 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y); } + const dim3 block_nums_fixup(block_nums_stream_k.x, mmq_y/warp_size, 1); + const dim3 block_dims_fixup(block_dims.x, block_dims.y/2, block_dims.z); + if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); if (!fixup_needed) { return; } - mul_mat_q_stream_k_fixup<<>> - (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, - args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, - args.ncols_max); + CUDA_CHECK(cudaGetLastError()); + mul_mat_q_stream_k_fixup<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, + args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst, + ntx_fd); } else { constexpr bool need_check = true; mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); if (!fixup_needed) { return; } - mul_mat_q_stream_k_fixup<<>> - (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, - args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, - args.ncols_max); + CUDA_CHECK(cudaGetLastError()); + mul_mat_q_stream_k_fixup<<>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, + args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst, + ntx_fd); } } From 1be2adf7b3df28f450a60822ad3952316aaa6644 Mon Sep 17 00:00:00 2001 From: Trivikram Reddy <127072883+trivikram-reddy1@users.noreply.github.com> Date: Sat, 25 Apr 2026 19:58:26 -0500 Subject: [PATCH 495/831] hexagon: guard HMX clock request for v75+ platforms (llama/22377) --- ggml/src/ggml-hexagon/htp/main.c | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 62942f6384c..f58347304be 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -101,6 +101,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { } } +#if __HVX_ARCH__ >= 75 { // Set HMX clock HAP_power_request_t request; @@ -118,6 +119,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { return err; } } +#endif return AEE_SUCCESS; } From 93a3f376421cd1439e2f25e2b8687bb5685e4e15 Mon Sep 17 00:00:00 2001 From: lhez Date: Sat, 25 Apr 2026 21:21:58 -0700 Subject: [PATCH 496/831] opencl: add iq4_nl support (llama/22272) * opencl: add general support for iq4_nl * opencl: add iq4_nl gemm/gemv for adreno * opencl: pack 2 lut entries into a uint --- ggml/src/ggml-opencl/CMakeLists.txt | 5 + ggml/src/ggml-opencl/ggml-opencl.cpp | 594 ++++++++++++++++++ ggml/src/ggml-opencl/kernels/cvt.cl | 107 ++++ .../kernels/gemm_noshuffle_iq4_nl_f32.cl | 150 +++++ .../kernels/gemv_noshuffle_iq4_nl_f32.cl | 302 +++++++++ .../kernels/mul_mm_iq4_nl_f32_l4_lm.cl | 171 +++++ .../ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl | 164 +++++ .../kernels/mul_mv_iq4_nl_f32_flat.cl | 202 ++++++ 8 files changed, 1695 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 772fc537494..5ed83eeb48a 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -96,6 +96,8 @@ set(GGML_OPENCL_KERNELS mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 mul_mv_q8_0_f32_flat + mul_mv_iq4_nl_f32 + mul_mv_iq4_nl_f32_flat mul_mv_mxfp4_f32 mul_mv_mxfp4_f32_flat mul_mv_id_q4_0_f32_8x_flat @@ -110,12 +112,15 @@ set(GGML_OPENCL_KERNELS mul_mm_q4_0_f32_l4_lm mul_mm_q4_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm + mul_mm_iq4_nl_f32_l4_lm mul_mm_q4_k_f32_l4_lm mul_mm_q5_k_f32_l4_lm mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 gemv_noshuffle_q4_1_f32 gemm_noshuffle_q4_1_f32 + gemv_noshuffle_iq4_nl_f32 + gemm_noshuffle_iq4_nl_f32 gemv_noshuffle_general_q8_0_f32 gemv_noshuffle_q4_k_f32 gemm_noshuffle_q4_k_f32 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 8bc7ae65a6d..4d31591a4a6 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -545,6 +545,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q5_K_noshuffle; cl_kernel kernel_restore_block_q5_K_noshuffle; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; + cl_kernel kernel_convert_block_iq4_nl, kernel_restore_block_iq4_nl; + cl_kernel kernel_convert_block_iq4_nl_noshuffle; + cl_kernel kernel_restore_block_iq4_nl_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q4_1_f32; cl_kernel kernel_mul_mv_q4_1_f32_flat; @@ -556,6 +559,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mv_q6_K_f32_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat; + cl_kernel kernel_mul_mv_iq4_nl_f32; + cl_kernel kernel_mul_mv_iq4_nl_f32_flat; cl_kernel kernel_solve_tri_f32; cl_kernel kernel_im2col_f32, kernel_im2col_f16; cl_kernel kernel_argsort_f32_i32; @@ -594,6 +599,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_q4_k_f32_l4_lm; cl_kernel kernel_mul_mm_q5_k_f32_l4_lm; cl_kernel kernel_mul_mm_q6_k_f32_l4_lm; + cl_kernel kernel_mul_mm_iq4_nl_f32_l4_lm; std::vector profiling_info; @@ -734,6 +740,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemm_noshuffle_q6_K_f32; cl_kernel kernel_gemv_noshuffle_q5_k_f32; cl_kernel kernel_gemm_noshuffle_q5_k_f32; + cl_kernel kernel_gemv_noshuffle_iq4_nl_f32; + cl_kernel kernel_gemm_noshuffle_iq4_nl_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS void free() { @@ -954,6 +962,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_iq4_nl = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_iq4_nl", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_iq4_nl_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl_noshuffle", &err), err)); GGML_LOG_CONT("."); } @@ -1359,6 +1371,40 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_iq4_nl_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_iq4_nl_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_iq4_nl_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_iq4_nl_f32 = clCreateKernel(prog, "kernel_mul_mv_iq4_nl_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_iq4_nl_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_iq4_nl_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_iq4_nl_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_iq4_nl_f32_flat = clCreateKernel(prog, "kernel_mul_mv_iq4_nl_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_mxfp4_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1567,6 +1613,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_iq4_nl_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_iq4_nl_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_iq4_nl_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_iq4_nl_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_iq4_nl_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_q4_k_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2647,6 +2710,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemm_noshuffle_iq4_nl_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_iq4_nl_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_iq4_nl_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_iq4_nl_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_iq4_nl_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_iq4_nl_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_iq4_nl_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_iq4_nl_f32.cl"); +#endif + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_iq4_nl_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_iq4_nl_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_q8_0_f32_8x4 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3597,6 +3699,30 @@ struct ggml_tensor_extra_cl_q8_0 { } }; +struct ggml_tensor_extra_cl_iq4_nl { + cl_mem q = nullptr; + cl_mem q_img = nullptr; + + cl_mem d = nullptr; + cl_mem d_img = nullptr; + + size_t size_q = 0; + size_t size_d = 0; + + ~ggml_tensor_extra_cl_iq4_nl() { + reset(); + } + + void reset() { + if (q != nullptr) { CL_CHECK(clReleaseMemObject(q)); q = nullptr; } + if (d != nullptr) { CL_CHECK(clReleaseMemObject(d)); d = nullptr; } + q_img = nullptr; + d_img = nullptr; + size_q = 0; + size_d = 0; + } +}; + struct ggml_tensor_extra_cl_q4_K { // Quantized values cl_mem q = nullptr; @@ -4097,6 +4223,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[1]->type == GGML_TYPE_F32; } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_IQ4_NL || op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q5_K || op->src[0]->type == GGML_TYPE_Q6_K) { @@ -4295,6 +4422,12 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) { delete e; } + for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl) { + delete e; + } + for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl_in_use) { + delete e; + } for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K) { delete e; } @@ -4390,6 +4523,21 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_iq4_nl * ggml_opencl_alloc_temp_tensor_extra_iq4_nl() { + ggml_tensor_extra_cl_iq4_nl * extra; + if (temp_tensor_extras_iq4_nl.empty()) { + extra = new ggml_tensor_extra_cl_iq4_nl(); + } else { + extra = temp_tensor_extras_iq4_nl.back(); + temp_tensor_extras_iq4_nl.pop_back(); + } + + temp_tensor_extras_iq4_nl_in_use.push_back(extra); + + extra->reset(); + return extra; + } + ggml_tensor_extra_cl_q4_K * ggml_opencl_alloc_temp_tensor_extra_q4_K() { ggml_tensor_extra_cl_q4_K * extra; if (temp_tensor_extras_q4_K.empty()) { @@ -4461,6 +4609,11 @@ struct ggml_backend_opencl_buffer_context { } temp_tensor_extras_q8_0_in_use.clear(); + for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl_in_use) { + temp_tensor_extras_iq4_nl.push_back(e); + } + temp_tensor_extras_iq4_nl_in_use.clear(); + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K_in_use) { temp_tensor_extras_q4_K.push_back(e); } @@ -4492,6 +4645,8 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_mxfp4_in_use; std::vector temp_tensor_extras_q8_0; std::vector temp_tensor_extras_q8_0_in_use; + std::vector temp_tensor_extras_iq4_nl; + std::vector temp_tensor_extras_iq4_nl_in_use; std::vector temp_tensor_extras_q4_K; std::vector temp_tensor_extras_q4_K_in_use; std::vector temp_tensor_extras_q5_K; @@ -5123,6 +5278,87 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } + if (tensor->type == GGML_TYPE_IQ4_NL) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tensors in OpenCL backend should have been allocated and initialized"); + + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_iq4_nl * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_iq4_nl(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)/2); + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_iq4_nl; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_iq4_nl_noshuffle; + } + #else + cl_kernel kernel = backend_ctx->kernel_convert_block_iq4_nl; + #endif + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64)*64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + int M = tensor->ne[1]; + int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + // Transpose q as ushort + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + } +#endif + return; + } if (tensor->type == GGML_TYPE_Q4_K) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -5775,6 +6011,78 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + if (tensor->type == GGML_TYPE_IQ4_NL) { + ggml_tensor_extra_cl_iq4_nl * extra = (ggml_tensor_extra_cl_iq4_nl *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*(ggml_blck_size(tensor->type)/2); + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + // transpose q, d back + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + cl_kernel kernel = backend_ctx->kernel_restore_block_iq4_nl_noshuffle; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } +#endif + cl_kernel kernel = backend_ctx->kernel_restore_block_iq4_nl; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (tensor->type == GGML_TYPE_Q4_K) { ggml_tensor_extra_cl_q4_K * extra = (ggml_tensor_extra_cl_q4_K *)tensor->extra; @@ -9840,6 +10148,178 @@ static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_t #endif } +static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % 32 == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_iq4_nl->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_iq4_nl_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_iq4_nl_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); @@ -10634,6 +11114,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra; ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; @@ -10738,6 +11219,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } + // iq4_nl x fp32 + if (src0t == GGML_TYPE_IQ4_NL && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_iq4_nl_f32_adreno(backend, src0, src1, dst); + return; + } + // q8_0 x fp32 if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 && enable_adreno_trans_weight(backend_ctx, src0)) { @@ -11302,6 +11789,48 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_IQ4_NL: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_iq4_nl_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } case GGML_TYPE_Q4_K: { if (ne11 < 32) { break; @@ -11829,6 +12358,70 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } + case GGML_TYPE_IQ4_NL: { +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_iq4_nl_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 8; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 8; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#else + kernel = backend_ctx->kernel_mul_mv_iq4_nl_f32; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); #endif // GGML_OPENCL_SOA_Q break; } @@ -12131,6 +12724,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_Q2_K) { // Each SIMD group produces N_DST values in the result. Assuming each // workgroup has N_SIMDGROUP SIMD groups, then each workgroup will diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 39af32d282b..f3937d8304c 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -87,6 +87,17 @@ struct block_q6_K { half d; // super-block scale }; +//------------------------------------------------------------------------------ +// block_iq4_nl +//------------------------------------------------------------------------------ +#define QK4_NL 32 + +struct block_iq4_nl +{ + half d; + uint8_t qs[QK4_NL / 2]; +}; + //------------------------------------------------------------------------------ // kernel_convert_block_q4_0 // Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). @@ -895,3 +906,99 @@ kernel void kernel_restore_block_q6_K_noshuffle( b->scales[i] = s[i]; } } + +//------------------------------------------------------------------------------ +// kernel_convert_block_iq4_nl +// Convert the block_iq4_nl format to 2 separate arrays (AOS -> SOA). +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_iq4_nl( + global struct block_iq4_nl * src0, + global uchar * dst_q, + global half * dst_d, + uchar mask_0F, + uchar mask_F0, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK4_NL/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_iq4_nl( + global uchar * src_q, + global half * src_d, + global struct block_iq4_nl * dst, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + + for (int i = 0; i < QK4_NL/2; ++i) { + b->qs[i] = q[i]; + } +} + +kernel void kernel_convert_block_iq4_nl_noshuffle( + global struct block_iq4_nl * src0, + global uchar * dst_q, + global half * dst_d, + uchar mask_0F, + uchar mask_F0, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + for (int i = 0; i < QK4_NL/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + q[i + QK4_NL/4] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } +} + +kernel void kernel_restore_block_iq4_nl_noshuffle( + global uchar * src_q, + global half * src_d, + global struct block_iq4_nl * dst, + uchar mask_0F, + uchar mask_F0, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + for (int i = 0; i < QK4_NL/4; ++i) { + uchar x0 = q[i + 0 ]; + uchar x1 = q[i + QK4_NL/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl new file mode 100644 index 00000000000..6869d822862 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl @@ -0,0 +1,150 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +constant half kvalues_iq4nl[16] = { + (half)-127.f, (half)-104.f, (half)-83.f, (half)-65.f, + (half) -49.f, (half) -35.f, (half)-22.f, (half)-10.f, + (half) 1.f, (half) 13.f, (half) 25.f, (half) 38.f, + (half) 53.f, (half) 69.f, (half) 89.f, (half)113.f +}; + +// Packed LUT: 2 FP16 values per uint, 8 unique constant loads instead of 16 +constant uint iq4nl_packed[8] = { + 0xD680D7F0u, // idx 0,1: -127, -104 + 0xD410D530u, // idx 2,3: -83, -65 + 0xD060D220u, // idx 4,5: -49, -35 + 0xC900CD80u, // idx 6,7: -22, -10 + 0x4A803C00u, // idx 8,9: 1, 13 + 0x50C04E40u, // idx 10,11: 25, 38 + 0x545052A0u, // idx 12,13: 53, 69 + 0x57105590u // idx 14,15: 89, 113 +}; + +// Packed dequant: 1 uint constant load (8-way divergence) + shift + as_half +#define IQ4_NL_DEQUANT(nibble) as_half((ushort)(iq4nl_packed[(nibble) >> 1] >> (((nibble) & 1u) << 4))) + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_iq4_nl_f32( + global const ushort * src0_q, + global const half * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding +) { + dst = (global float *)((global char *)dst + offsetd); + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * weight_ptr = src0_q + gx_2; + global const half * scale_ptr = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1); + + ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m)); + + half4 scale = vload4(0, scale_ptr + (i/32)*(m)); + + // j=0 + dequantized_weights.s0 = IQ4_NL_DEQUANT(bits4.s0 & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT(bits4.s1 & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT(bits4.s2 & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT(bits4.s3 & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1); + dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 4) & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 4) & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 4) & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 4) & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1); + dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 8) & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 8) & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 8) & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 8) & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1); + dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 12) & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 12) & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 12) & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 12) & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl new file mode 100644 index 00000000000..9386bf25a6f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl @@ -0,0 +1,302 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK4_NL 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +constant half kvalues_iq4nl[16] = { + (half)-127.f, (half)-104.f, (half)-83.f, (half)-65.f, + (half) -49.f, (half) -35.f, (half)-22.f, (half)-10.f, + (half) 1.f, (half) 13.f, (half) 25.f, (half) 38.f, + (half) 53.f, (half) 69.f, (half) 89.f, (half)113.f +}; + +// Packed LUT: 2 FP16 values per uint, 8 unique constant loads instead of 16 +constant uint iq4nl_packed[8] = { + 0xD680D7F0u, // idx 0,1: -127, -104 + 0xD410D530u, // idx 2,3: -83, -65 + 0xD060D220u, // idx 4,5: -49, -35 + 0xC900CD80u, // idx 6,7: -22, -10 + 0x4A803C00u, // idx 8,9: 1, 13 + 0x50C04E40u, // idx 10,11: 25, 38 + 0x545052A0u, // idx 12,13: 53, 69 + 0x57105590u // idx 14,15: 89, 113 +}; + +// Packed dequant: 1 uint constant load (8-way divergence) + shift + as_half +#define IQ4_NL_DEQUANT(nibble) as_half((ushort)(iq4nl_packed[(nibble) >> 1] >> (((nibble) & 1u) << 4))) + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_iq4_nl_f32( + read_only image1d_buffer_t src0_q, + global half2 * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + private uint4 regA; + private half2 regS; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + for (uint k = groupId; k < (K / QK4_NL); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl new file mode 100644 index 00000000000..11ff7f8d9dc --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl @@ -0,0 +1,171 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +constant float kvalues_iq4nl[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, + 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +kernel void kernel_mul_mm_iq4_nl_f32_l4_lm( + global uchar4 * src0_q, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + global uchar4 * qs = src0_q + ib*4 + iqs; + uchar4 q = *qs; + // IQ4_NL: use lookup table instead of linear (nibble - 8) + float4 v1 = (float4)(kvalues_iq4nl[(q.s0 )&0x0F], kvalues_iq4nl[(q.s1 )&0x0F], + kvalues_iq4nl[(q.s2 )&0x0F], kvalues_iq4nl[(q.s3 )&0x0F])*d; + float4 v2 = (float4)(kvalues_iq4nl[(q.s0>>4)&0x0F], kvalues_iq4nl[(q.s1>>4)&0x0F], + kvalues_iq4nl[(q.s2>>4)&0x0F], kvalues_iq4nl[(q.s3>>4)&0x0F])*d; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl new file mode 100644 index 00000000000..a6a325cd729 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl @@ -0,0 +1,164 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_NL 32 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +constant float kvalues_iq4nl[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, + 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +//------------------------------------------------------------------------------ +// block_iq4_nl +//------------------------------------------------------------------------------ +struct block_iq4_nl +{ + half d; + uint8_t qs[QK4_NL / 2]; +}; + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32 +//------------------------------------------------------------------------------ +// Compute inner product between half a block of iq4_nl and 16 floats (yl). +// il indicates where the quants begin (0 or 8). +inline float block_iq4_nl_dot_y( + global struct block_iq4_nl * qb_curr, + private float * yl, + int il +) { + float d = qb_curr->d; + float acc = 0.f; + global uchar * qs = qb_curr->qs + il; + for (int i = 0; i < 8; ++i) { + acc += yl[i] * kvalues_iq4nl[qs[i] & 0x0F]; + acc += yl[i+8] * kvalues_iq4nl[qs[i] >> 4]; + } + return d * acc; +} + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup group works on 4 rows +#define N_SUBGROUP 1 // number of subgroups in a thread group +#define N_SUBGROUP_SIZE 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SUBGROUP 1 +#define N_SUBGROUP_SIZE 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + + const ulong nb = ne00/QK4_NL; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SUBGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_iq4_nl * x = (global struct block_iq4_nl *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; // src1 vector cache + float sumf[N_DST]={0.f}; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_NL + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SUBGROUP_SIZE/2) { + for (int i = 0; i < 8; ++i) { + yl[i] = yb[i]; + yl[i+8] = yb[i+16]; + } + + for (int row = 0; row < N_DST; row++) { + sumf[row] += block_iq4_nl_dot_y(x+ib+row*nb, yl, il); + } + + yb += QK4_NL * (N_SUBGROUP_SIZE/2); + } + + float tot[N_DST] = { + sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]), + sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])}; + for (int row = 0; row < N_DST; ++row) { + if (get_sub_group_local_id() == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row]; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_iq4_nl_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl new file mode 100644 index 00000000000..8c5b3f52e42 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl @@ -0,0 +1,202 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_NL 32 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +constant float kvalues_iq4nl[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, + 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +//------------------------------------------------------------------------------ +// block_iq4_nl +//------------------------------------------------------------------------------ +struct block_iq4_nl +{ + half d; + uint8_t qs[QK4_NL / 2]; +}; + +// Compute dot product between half a block of iq4_nl quants and activations. +// x points to the quant bytes, dh points to the scale. +// yl has 16 activation values: [0..7] for low nibbles, [8..15] for high nibbles. +// il indicates offset into the quant bytes (0 or 8). +inline float block_iq4_nl_dot_y_flat( + global uchar * x, + global half * dh, + private float * yl, + int il +) { + float d = *dh; + global uchar * qs = x + il; + float acc = 0.f; + for (int i = 0; i < 8; ++i) { + acc += yl[i] * kvalues_iq4nl[qs[i] & 0x0F]; + acc += yl[i+8] * kvalues_iq4nl[qs[i] >> 4]; + } + return d * acc; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 8 // each subgroup works on 8 rows +#define N_SUBGROUP 1 // number of subgroups in a thread group +#define N_SUBGROUP_SIZE 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SUBGROUP 1 +#define N_SUBGROUP_SIZE 64 +#endif + +inline void mul_vec_q_n_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_NL; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SUBGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_NL/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_NL/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; + float8 sumf = 0.f; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_NL + il; + + for (int ib = ix; ib < nb; ib += N_SUBGROUP_SIZE/2) { + for (int i = 0; i < 8; ++i) { + yl[i] = yb[i]; + yl[i+8] = yb[i+16]; + } + + sumf.s0 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 0*nb*QK4_NL/2, d + ib + 0*nb, yl, il); + sumf.s1 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 1*nb*QK4_NL/2, d + ib + 1*nb, yl, il); + sumf.s2 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 2*nb*QK4_NL/2, d + ib + 2*nb, yl, il); + sumf.s3 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 3*nb*QK4_NL/2, d + ib + 3*nb, yl, il); + + sumf.s4 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 4*nb*QK4_NL/2, d + ib + 4*nb, yl, il); + sumf.s5 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 5*nb*QK4_NL/2, d + ib + 5*nb, yl, il); + sumf.s6 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 6*nb*QK4_NL/2, d + ib + 6*nb, yl, il); + sumf.s7 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 7*nb*QK4_NL/2, d + ib + 7*nb, yl, il); + + yb += QK4_NL * (N_SUBGROUP_SIZE/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_iq4_nl_f32_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} From 4e11277a198de0f0ccc9a6fbfd6e943a7602b546 Mon Sep 17 00:00:00 2001 From: Eve <139727413+netrunnereve@users.noreply.github.com> Date: Sun, 26 Apr 2026 06:27:50 +0000 Subject: [PATCH 497/831] ggml-cpu: optimize avx2 q6_k (llama/22345) --- ggml/src/ggml-cpu/arch/x86/quants.c | 46 ++++++++++++----------------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index 0a3e071e57c..94b19b82bbc 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -2300,9 +2300,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #if defined __AVX2__ - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i m15 = _mm256_set1_epi8(15); __m256 acc = _mm256_setzero_ps(); @@ -2314,53 +2313,45 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const uint8_t * GGML_RESTRICT qh = x[i].qh; const int8_t * GGML_RESTRICT q8 = y[i].qs; + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m256i scales_16 = _mm256_cvtepi8_epi16(scales); + const __m256i q8sclsub = _mm256_slli_epi32(_mm256_madd_epi16(q8sums, scales_16), 5); __m256i sumi = _mm256_setzero_si256(); int is = 0; for (int j = 0; j < QK_K/128; ++j) { - - const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); - is += 4; - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m3), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, _mm256_set1_epi8(12)), 2); + const __m256i q4h_2 = _mm256_and_si256(q4bitsH, _mm256_set1_epi8(48)); + const __m256i q4h_3 = _mm256_srli_epi16(_mm256_and_si256(q4bitsH, _mm256_set1_epi8(-64)), 2); - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); - const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m15), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m15), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m15), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m15), q4h_3); const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); @@ -2372,6 +2363,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } + sumi = _mm256_sub_epi32(sumi, q8sclsub); acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); } From 2f3df42cddca762047c2884342b683549420be71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Sun, 26 Apr 2026 08:28:14 +0200 Subject: [PATCH 498/831] ggml-cpu : re-enable fast gelu_quick_f16 (llama/22339) --- ggml/src/ggml-cpu/vec.h | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index a0375a28de0..bcd68da9aa9 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -1036,12 +1036,12 @@ inline static float ggml_gelu_quick_f32(float x) { return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); } -//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { -// const uint16_t * i16 = (const uint16_t *) x; -// for (int i = 0; i < n; ++i) { -// y[i] = ggml_table_gelu_quick_f16[i16[i]]; -// } -//} +inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + y[i] = ggml_table_gelu_quick_f16[i16[i]]; + } +} #ifdef GGML_GELU_QUICK_FP16 inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { @@ -1060,13 +1060,6 @@ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * } #endif -inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { - for (int i = 0; i < n; ++i) { - float v = GGML_CPU_FP16_TO_FP32(x[i]); - y[i] = GGML_CPU_FP32_TO_FP16(v*(1.0f/(1.0f+expf(GELU_QUICK_COEF*v)))); - } -} - // Sigmoid Linear Unit (SiLU) function inline static float ggml_silu_f32(float x) { return x/(1.0f + expf(-x)); From 9bf6c3c8602b976b79139d908fd63ffe048749ba Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Sun, 26 Apr 2026 09:21:45 +0200 Subject: [PATCH 499/831] CUDA: better coalesce data-access for contiguous concat (llama/22330) Also, distribute all elements across CTAs evenly instead of launching one CTA per dim --- ggml/src/ggml-cuda/concat.cu | 141 +++++++++++++++-------------------- 1 file changed, 62 insertions(+), 79 deletions(-) diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index e9ffd274b99..102f944f924 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -1,96 +1,79 @@ #include "concat.cuh" // contiguous kernels -static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) { - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { - return; - } - - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - - if (nidx < ne00) { // src0 - int offset_src = - nidx + - blockIdx.y * ne00 + - blockIdx.z * ne00 * gridDim.y; - dst[offset_dst] = x[offset_src]; - } else { - int offset_src = - (nidx - ne00) + - blockIdx.y * (ne0 - ne00) + - blockIdx.z * (ne0 - ne00) * gridDim.y; - dst[offset_dst] = y[offset_src]; - } -} - -static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) { - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { - return; - } +template +static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont(const float * x, + const float * y, + float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + static_assert(dim >= 0 && dim <= 2, "dim must be in [0, 2]"); + + const int64_t n = ne0 * ne1 * ne2; + + for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t) blockDim.x * gridDim.x) { + if constexpr (dim == 0) { + const int64_t row = i / ne0; + const int64_t i0 = i - row * ne0; + + if (i0 < ne00) { + dst[i] = x[row * ne00 + i0]; + } else { + dst[i] = y[row * (ne0 - ne00) + (i0 - ne00)]; + } + } else if constexpr (dim == 1) { + const int64_t dst_plane = ne0 * ne1; + const int64_t src0_plane = ne0 * ne01; + const int64_t src1_plane = dst_plane - src0_plane; + const int64_t i2 = i / dst_plane; + const int64_t i01 = i - i2 * dst_plane; + + if (i01 < src0_plane) { + dst[i] = x[i2 * src0_plane + i01]; + } else { + dst[i] = y[i2 * src1_plane + (i01 - src0_plane)]; + } + } else { + const int64_t src0_size = ne0 * ne1 * ne02; - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - - if (blockIdx.y < (unsigned)ne01) { // src0 - int offset_src = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * ne01; - dst[offset_dst] = x[offset_src]; - } else { - int offset_src = - nidx + - (blockIdx.y - ne01) * ne0 + - blockIdx.z * ne0 * (gridDim.y - ne01); - dst[offset_dst] = y[offset_src]; + if (i < src0_size) { + dst[i] = x[i]; + } else { + dst[i] = y[i - src0_size]; + } + } } } -static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) { - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { - return; - } - - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - - if (blockIdx.z < (unsigned)ne02) { // src0 - int offset_src = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - dst[offset_dst] = x[offset_src]; - } else { - int offset_src = - nidx + - blockIdx.y * ne0 + - (blockIdx.z - ne02) * ne0 * gridDim.y; - dst[offset_dst] = y[offset_src]; - } -} +static void concat_f32_cuda(const float * x, + const float * y, + float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int dim, + cudaStream_t stream) { + const int64_t n = ne0 * ne1 * ne2; + const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; -static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) { - int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; - dim3 gridDim(num_blocks, ne1, ne2); if (dim == 0) { - concat_f32_dim0<<>>(x, y, dst, ne0, ne00); + concat_f32_cont<0> + <<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } if (dim == 1) { - concat_f32_dim1<<>>(x, y, dst, ne0, ne01); + concat_f32_cont<1> + <<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } - concat_f32_dim2<<>>(x, y, dst, ne0, ne02); + concat_f32_cont<2><<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); } // non-contiguous kernel (slow) From 7296b9c7faec4df1e683d0ef652c3ed4c79ac6ff Mon Sep 17 00:00:00 2001 From: Gaurav Garg Date: Sun, 26 Apr 2026 17:04:40 +0530 Subject: [PATCH 500/831] Fix recurrent state serialization for partial reads and writes (llama/22362) The previous code worked only for full tensor reads and writes and was hitting `GGML_ASSERT(size == ggml_nbytes(tensor)); ` assert when tested with llama-server. --- ggml/src/ggml-backend-meta.cpp | 66 +++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 6d22f3421b1..41a61775bd6 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1205,40 +1205,57 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg if (split_state.n_segments != 1) { GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); - GGML_ASSERT(offset == 0); - GGML_ASSERT(size == ggml_nbytes(tensor)); GGML_ASSERT(tensor->ne[3] == 1); + size_t offset_data = 0; std::vector simple_offsets(n_bufs, 0); if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { GGML_ASSERT(tensor->ne[2] == 1); + + const size_t row_stride = tensor->nb[1]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t r_start = offset / row_stride; + const int64_t r_count = size / row_stride; + GGML_ASSERT(r_start + r_count <= tensor->ne[1]); + const int64_t blck_size = ggml_blck_size(tensor->type); for (size_t s = 0; s < split_state.n_segments; s++) { for (size_t j = 0; j < n_bufs; j++) { ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; - ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes, - tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]); + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes, + r_count, simple_tensor->nb[1], tensor->nb[1]); offset_data += nbytes; simple_offsets[j] += nbytes; } } - GGML_ASSERT(offset_data*tensor->ne[1] == size); + GGML_ASSERT(offset_data*r_count == size); return; } GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); + + const size_t row_stride = tensor->nb[2]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t r_start = offset / row_stride; + const int64_t r_count = size / row_stride; + GGML_ASSERT(r_start + r_count <= tensor->ne[2]); + for (size_t s = 0; s < split_state.n_segments; s++) { for (size_t j = 0; j < n_bufs; j++) { ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; - ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes, - tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]); + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes, + r_count, simple_tensor->nb[2], tensor->nb[2]); offset_data += nbytes; simple_offsets[j] += nbytes; } } - GGML_ASSERT(offset_data*tensor->ne[2] == size); + GGML_ASSERT(offset_data*r_count == size); return; } @@ -1295,40 +1312,57 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co if (split_state.n_segments != 1) { GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); - GGML_ASSERT(offset == 0); - GGML_ASSERT(size == ggml_nbytes(tensor)); GGML_ASSERT(tensor->ne[3] == 1); + size_t offset_data = 0; std::vector simple_offsets(n_bufs, 0); if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { GGML_ASSERT(tensor->ne[2] == 1); + + const size_t row_stride = tensor->nb[1]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t r_start = offset / row_stride; + const int64_t r_count = size / row_stride; + GGML_ASSERT(r_start + r_count <= tensor->ne[1]); + const int64_t blck_size = ggml_blck_size(tensor->type); for (size_t s = 0; s < split_state.n_segments; s++) { for (size_t j = 0; j < n_bufs; j++) { const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; - ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes, - tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]); + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes, + r_count, simple_tensor->nb[1], tensor->nb[1]); offset_data += nbytes; simple_offsets[j] += nbytes; } } - GGML_ASSERT(offset_data*tensor->ne[1] == size); + GGML_ASSERT(offset_data*r_count == size); return; } GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); + + const size_t row_stride = tensor->nb[2]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t r_start = offset / row_stride; + const int64_t r_count = size / row_stride; + GGML_ASSERT(r_start + r_count <= tensor->ne[2]); + for (size_t s = 0; s < split_state.n_segments; s++) { for (size_t j = 0; j < n_bufs; j++) { const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; - ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes, - tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]); + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes, + r_count, simple_tensor->nb[2], tensor->nb[2]); offset_data += nbytes; simple_offsets[j] += nbytes; } } - GGML_ASSERT(offset_data*tensor->ne[2] == size); + GGML_ASSERT(offset_data*r_count == size); return; } From 1478450e61487ae2cd44916d902ccd626539de47 Mon Sep 17 00:00:00 2001 From: Rithik Sharma Date: Sun, 26 Apr 2026 09:26:28 -0700 Subject: [PATCH 501/831] add performance-portable tuning for register-tile and subgroup matmul (llama/22241) --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 36 ++++++++++++++----- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 16ebc32cbc7..503171ee14f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -26,20 +26,23 @@ // Matrix multiplication parameters // Register tiling parameters -#define WEBGPU_MUL_MAT_TILE_M 8 -#define WEBGPU_MUL_MAT_TILE_N 8 +#define WEBGPU_MUL_MAT_TILE_M 4 +#define WEBGPU_MUL_MAT_TILE_N 4 #define WEBGPU_MUL_MAT_WG_SIZE_M 8 #define WEBGPU_MUL_MAT_WG_SIZE_N 8 -#define WEBGPU_MUL_MAT_TILE_K 32 +#define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8 +#define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32 // Subgroup matrix parameters // The number of subgroups in the M dimension #define WEBGPU_MUL_MAT_SUBGROUP_M 2 // The number of subgroups in the N dimension -#define WEBGPU_MUL_MAT_SUBGROUP_N 2 +#define WEBGPU_MUL_MAT_SUBGROUP_N 4 // The number of subgroup matrices each subgroup accumulates over #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 +#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32 +#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32 // Matrix-vector multiplication parameters #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 @@ -1734,13 +1737,24 @@ class ggml_webgpu_shader_lib { // VEC/SCALAR controls defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + const bool is_quant = ggml_is_quantized(context.src0->type); + + uint32_t tile_k; + if (key.use_subgroup_matrix) { + tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT + : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT; + } else { + tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT + : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; + } + // Tiles defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); - defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); // Subgroup matrix specifics if (key.use_subgroup_matrix) { + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u"); defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u"); @@ -1760,12 +1774,13 @@ class ggml_webgpu_shader_lib { if (!key.use_subgroup_matrix) { defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); } auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); - decisions->tile_k = WEBGPU_MUL_MAT_TILE_K; + decisions->tile_k = tile_k; decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; decisions->use_subgroup_matrix = key.use_subgroup_matrix; @@ -1962,10 +1977,15 @@ class ggml_webgpu_shader_lib { defines.push_back("SCALAR"); + // mul_mat_id is register-tile only. + const uint32_t tile_k = ggml_is_quantized(context.src0->type) + ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT + : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; + // Tiles defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); - defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); @@ -1976,7 +1996,7 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines); auto decisions = std::make_shared(); - decisions->tile_k = WEBGPU_MUL_MAT_TILE_K; + decisions->tile_k = tile_k; decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M; From f5c3ce17d563b7a86561062c1cd82ad7b1ebdd24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Mon, 27 Apr 2026 08:30:55 +0200 Subject: [PATCH 502/831] ggml : use 64 bytes aligned tile buffers (llama/21058) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit | Model | Test | t/s OLD | t/s NEW | Speedup | |:---------------------------------|:-------|----------:|----------:|----------:| | qwen35 0.8B BF16 | pp512 | 584.59 | 595.41 | 1.02 | | qwen35 0.8B BF16 | tg128 | 52.23 | 52.82 | 1.01 | | qwen35 0.8B IQ2_M - 2.7 bpw | pp512 | 260.64 | 261.70 | 1.00 | | qwen35 0.8B IQ2_M - 2.7 bpw | tg128 | 81.17 | 80.89 | 1.00 | | qwen35 0.8B IQ2_XXS - 2.0625 bpw | pp512 | 302.36 | 302.56 | 1.00 | | qwen35 0.8B IQ2_XXS - 2.0625 bpw | tg128 | 84.93 | 85.12 | 1.00 | | qwen35 0.8B IQ3_XXS - 3.0625 bpw | pp512 | 263.22 | 260.01 | 0.99 | | qwen35 0.8B IQ3_XXS - 3.0625 bpw | tg128 | 80.29 | 78.94 | 0.98 | | qwen35 0.8B IQ4_NL - 4.5 bpw | pp512 | 728.65 | 742.09 | 1.02 | | qwen35 0.8B IQ4_NL - 4.5 bpw | tg128 | 82.39 | 84.46 | 1.03 | | qwen35 0.8B IQ4_XS - 4.25 bpw | pp512 | 681.33 | 677.06 | 0.99 | | qwen35 0.8B IQ4_XS - 4.25 bpw | tg128 | 80.18 | 79.28 | 0.99 | | qwen35 0.8B Q2_K_M | pp512 | 413.28 | 415.94 | 1.01 | | qwen35 0.8B Q2_K_M | tg128 | 81.90 | 82.78 | 1.01 | | qwen35 0.8B Q3_K_M | pp512 | 493.17 | 495.08 | 1.00 | | qwen35 0.8B Q3_K_M | tg128 | 82.75 | 83.23 | 1.01 | | qwen35 0.8B Q3_K_S | pp512 | 429.35 | 427.64 | 1.00 | | qwen35 0.8B Q3_K_S | tg128 | 86.69 | 87.02 | 1.00 | | qwen35 0.8B Q4_0 | pp512 | 783.46 | 782.32 | 1.00 | | qwen35 0.8B Q4_0 | tg128 | 88.23 | 87.90 | 1.00 | | qwen35 0.8B Q4_1 | pp512 | 741.71 | 729.76 | 0.98 | | qwen35 0.8B Q4_1 | tg128 | 85.44 | 86.01 | 1.01 | | qwen35 0.8B Q4_K_M | pp512 | 676.24 | 681.31 | 1.01 | | qwen35 0.8B Q4_K_M | tg128 | 76.59 | 77.06 | 1.01 | | qwen35 0.8B Q4_K_S | pp512 | 683.12 | 688.81 | 1.01 | | qwen35 0.8B Q4_K_S | tg128 | 80.50 | 81.19 | 1.01 | | qwen35 0.8B Q5_K_M | pp512 | 635.33 | 642.11 | 1.01 | | qwen35 0.8B Q5_K_M | tg128 | 72.07 | 72.49 | 1.01 | | qwen35 0.8B Q5_K_S | pp512 | 660.95 | 658.18 | 1.00 | | qwen35 0.8B Q5_K_S | tg128 | 72.19 | 72.95 | 1.01 | | qwen35 0.8B Q6_K | pp512 | 647.97 | 638.84 | 0.99 | | qwen35 0.8B Q6_K | tg128 | 72.83 | 72.49 | 1.00 | | qwen35 0.8B Q8_0 | pp512 | 805.01 | 785.49 | 0.98 | | qwen35 0.8B Q8_0 | tg128 | 70.10 | 70.13 | 1.00 | Signed-off-by: Adrien Gallouët --- ggml/src/ggml-cpu/amx/mmq.cpp | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index 93a6d397f79..d9383a04be8 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -2005,12 +2005,12 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v const int lda = KB * sizeof(TA); //const int ldb = KB * sizeof(TB); - static thread_local packed_B_t Tile0[TILE_N * TILE_K]; - static thread_local packed_B_t Tile1[TILE_N * TILE_K]; - static thread_local int8_t Tile23[TILE_M * TILE_K]; + alignas(64) static thread_local packed_B_t Tile0[TILE_N * TILE_K]; + alignas(64) static thread_local packed_B_t Tile1[TILE_N * TILE_K]; + alignas(64) static thread_local int8_t Tile23[TILE_M * TILE_K]; - static thread_local int32_t TileC0[TILE_M * TILE_N * 4]; - static thread_local int32_t TileC1[TILE_M * TILE_N * 4]; + alignas(64) static thread_local int32_t TileC0[TILE_M * TILE_N * 4]; + alignas(64) static thread_local int32_t TileC1[TILE_M * TILE_N * 4]; // double buffering C to interleave avx512 and amx int32_t * C_cur = TileC0; @@ -2187,21 +2187,21 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v const int m1 = std::max(M - TILE_M, 0); //const int lda = KB * sizeof(TA); - static thread_local int8_t Tile0[TILE_N * TILE_K]; - static thread_local int8_t Tile1[TILE_N * TILE_K]; - static thread_local int8_t Tile23[TILE_M * TILE_K]; + alignas(64) static thread_local int8_t Tile0[TILE_N * TILE_K]; + alignas(64) static thread_local int8_t Tile1[TILE_N * TILE_K]; + alignas(64) static thread_local int8_t Tile23[TILE_M * TILE_K]; // mat mul result for each group - static thread_local int32_t Tile4[TILE_M * TILE_N]; - static thread_local int32_t Tile5[TILE_M * TILE_N]; - static thread_local int32_t Tile6[TILE_M * TILE_N]; - static thread_local int32_t Tile7[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile4[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile5[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile6[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile7[TILE_M * TILE_N]; // sum of each QK_K block, contains 8 groups, int32 - static thread_local int32_t Sumi4[TILE_M * TILE_N]; - static thread_local int32_t Sumi5[TILE_M * TILE_N]; - static thread_local int32_t Sumi6[TILE_M * TILE_N]; - static thread_local int32_t Sumi7[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi4[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi5[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi6[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi7[TILE_M * TILE_N]; const int k_group_size = std::is_same::value ? 16 : 32; for (int i = 0; i < KB; ++i) { From c9ba41397cb81e13f98d896a5a63fd5e9a1ea8dc Mon Sep 17 00:00:00 2001 From: unraido <127105806+unraido@users.noreply.github.com> Date: Mon, 27 Apr 2026 23:25:09 +0900 Subject: [PATCH 503/831] fix: rpc-server cache may not work in Windows environments (llama/22394) * fix: create directory and log cache file name. * Remove GGML_LOG_INFO conditional compilation. --------- Co-authored-by: kotaro --- ggml/src/ggml-rpc/ggml-rpc.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 2ded7397868..505bec73d37 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1101,7 +1101,7 @@ bool rpc_server::set_tensor(const std::vector & input) { fs::path cache_file = fs::path(cache_dir) / hash_str; std::ofstream ofs(cache_file, std::ios::binary); ofs.write((const char *)data, size); - GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str()); + GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.string().c_str()); } ggml_backend_tensor_set(tensor, data, offset, size); return true; From f675a8c9264682c720ae0d3b7badb06227065cc3 Mon Sep 17 00:00:00 2001 From: Rithik Sharma Date: Mon, 27 Apr 2026 08:25:45 -0700 Subject: [PATCH 504/831] add fast mat-vec kernels for i-quants (llama/22344) --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 18 + ggml/src/ggml-webgpu/ggml-webgpu.cpp | 11 + .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 514 ++++++++++++++++++ 3 files changed, 543 insertions(+) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 503171ee14f..08ea2906ada 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1615,6 +1615,24 @@ class ggml_webgpu_shader_lib { defines.push_back("MUL_ACC_" + type_upper); defines.push_back("U32_DEQUANT_HELPERS"); defines.push_back("SRC0_INNER_TYPE=u32"); + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + default: + break; + } break; } } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index bcec20c1a11..d6d7dbdaf3c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1391,6 +1391,17 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_Q2_K: use_fast = true; break; + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + use_fast = is_vec; + break; default: break; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index 97c9f6d7a09..c2eafee6c75 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -812,6 +812,520 @@ fn main( } #endif +#ifdef MUL_ACC_IQ1_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 50 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base)); + let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu; + let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); + let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ1_M +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 56 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let sc_lo = load_u32_at_src0(block_byte_base + 48u); + let sc_hi = load_u32_at_src0(block_byte_base + 52u); + let sc0 = sc_lo & 0xFFFFu; + let sc1 = (sc_lo >> 16u) & 0xFFFFu; + let sc2 = sc_hi & 0xFFFFu; + let sc3 = (sc_hi >> 16u) & 0xFFFFu; + let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u); + let d = f32(bitcast>(d_bits)[0]); + + let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u), + select(sc0, sc1, sub_blk >= 2u), + sub_blk < 4u); + + let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u); + let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu; + let qh_lo = qh & 0xFFu; + let qh_hi = (qh >> 8u) & 0xFFu; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); + let sub_scale = (sc_u16 >> bit_off) & 0x7u; + let dl = d * f32(2u * sub_scale + 1u); + let qh_byte = select(qh_lo, qh_hi, l >= 2u); + let ll2 = l % 2u; + let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); + let ig = grid_idx * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ2_XXS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 66 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let aux_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let aux_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let ls = aux_hi >> 28u; + let db = d * (0.5 + f32(ls)) * 0.25; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; + let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xxs_grid[grid_idx * 2u]; + let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ2_XS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 74 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let scales_byte = get_byte(scales_word, sub_blk % 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let half2 = (l % 2u) * 16u; + let qs_val = (qs_word >> half2) & 0xFFFFu; + let grid_idx = qs_val & 0x1FFu; + let signs_idx = (qs_val >> 9u) & 0x7Fu; + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xs_grid[grid_idx * 2u]; + let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ2_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 82 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); + let sg_w = load_u32_at_src0(block_byte_base + 34u + sub_blk * 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let qh_byte = get_byte(qh_word, sub_blk % 4u); + let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u); + let scales_byte = get_byte(sc_word, sub_blk % 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let sign_byte = get_byte(sg_w, l); + let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let gw_lo = iq2s_grid[grid_idx * 2u]; + let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ3_XXS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 98 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let aux = load_u32_at_src0(block_byte_base + 66u + sub_blk * 4u); + let ls = aux >> 28u; + let db = d * (0.5 + f32(ls)) * 0.5; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let signs_idx = (aux >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let grid1 = iq3xxs_grid[grid_idx_0]; + let grid2 = iq3xxs_grid[grid_idx_1]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ3_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let qh_byte = get_byte(qh_word, sub_blk % 4u); + let sg_w = load_u32_at_src0(block_byte_base + 74u + sub_blk * 4u); + let sc_word = load_u32_at_src0(block_byte_base + 106u); + let scales_byte = get_byte(sc_word, sub_blk / 2u); + let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu; + let db = d * (1.0 + 2.0 * f32(sub_scale)); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); + let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); + let sign_byte = get_byte(sg_w, l); + let grid1 = iq3s_grid[grid_idx_1]; + let grid2 = iq3s_grid[grid_idx_2]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + } + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ4_NL +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + i + 16u]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; + let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } +#endif + +#ifdef MUL_ACC_IQ4_XS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 136 +#define THREADS_PER_BLOCK 16 + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let y_offset = sub_blk * 32u + half * 16u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let scales_h = load_u16_at_src0(block_byte_base + 2u); + let scales_l_word = load_u32_at_src0(block_byte_base + 4u); + let sl_byte = get_byte(scales_l_word, sub_blk / 2u); + let sl = (sl_byte >> (4u * (sub_blk % 2u))) & 0xFu; + let sh_bits = (scales_h >> (2u * sub_blk)) & 3u; + let ls = i32(sl | (sh_bits << 4u)); + let dl = d * f32(ls - 32); + + let qs_byte_off = 8u + sub_blk * 16u; + let q_w0 = load_u32_at_src0(block_byte_base + qs_byte_off); + let q_w1 = load_u32_at_src0(block_byte_base + qs_byte_off + 4u); + let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u); + let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u); + + var row_sum = 0.0; + for (var i = 0u; i < 16u; i++) { + let q_word = select( + select(q_w0, q_w1, i >= 4u), + select(q_w2, q_w3, i >= 12u), + i >= 8u); + let q_byte = get_byte(q_word, i % 4u); + let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); + row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i]; + } + acc[row] += row_sum; + } + } + } +#endif + #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let subgroup_total = subgroupAdd(acc[row]); From 9c233f11f09c0ea3d7d8df0056c3c312ef9248f3 Mon Sep 17 00:00:00 2001 From: Rithik Sharma Date: Mon, 27 Apr 2026 15:50:59 -0700 Subject: [PATCH 505/831] ggml-webgpu: add Q1_0 support (llama/22374) * add fast matmul matvec q1_0 kernel * ggml-webgpu: drop redundant zero-fills in Q1_0 shmem init --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 9 +++-- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 +++ .../ggml-webgpu/wgsl-shaders/get_rows.wgsl | 18 ++++++++++ .../wgsl-shaders/mul_mat_decls.tmpl | 33 +++++++++++++++++++ .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 32 ++++++++++++++++++ 5 files changed, 94 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 08ea2906ada..fb2c9527f3c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1287,6 +1287,7 @@ class ggml_webgpu_shader_lib { std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); switch (key.src_type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q5_0: case GGML_TYPE_Q8_0: @@ -1323,7 +1324,9 @@ class ggml_webgpu_shader_lib { defines.push_back("DST_TYPE=f32"); - if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || + if (key.src_type == GGML_TYPE_Q1_0) { + defines.push_back("BLOCK_SIZE=128u"); + } else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || key.src_type == GGML_TYPE_IQ4_NL) { defines.push_back("BLOCK_SIZE=32u"); } else if (key.src_type >= GGML_TYPE_Q2_K) { @@ -1657,7 +1660,9 @@ class ggml_webgpu_shader_lib { uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; - if (key.src0_type >= GGML_TYPE_Q2_K) { + if (key.src0_type == GGML_TYPE_Q1_0) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q2_K) { outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; } else if (key.src0_type >= GGML_TYPE_Q4_0) { outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index d6d7dbdaf3c..6d861c0c781 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1389,6 +1389,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_Q5_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q2_K: + case GGML_TYPE_Q1_0: use_fast = true; break; case GGML_TYPE_IQ1_S: @@ -3736,6 +3737,7 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm static bool ggml_webgpu_supported_qtype(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -3830,6 +3832,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -3868,6 +3871,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index 1415798fa6b..5710cd35469 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -27,6 +27,24 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } #endif +#ifdef Q1_0 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_byte_base = (src_base + offset) * 18; + let d = load_f16_as_f32_at_src(block_byte_base); + for (var j: u32 = 0u; j < 4u; j++) { + let q_packed = load_u32_at_src(block_byte_base + 2u + j * 4u); + let dst_base128 = dst_base + offset * 128u + j * 32u; + for (var k: u32 = 0; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + for (var bit: u32 = 0; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + dst[dst_base128 + k * 8u + bit] = w; + } + } + } +} +#endif + #ifdef Q4_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 5a323818260..15b22c4f731 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -61,6 +61,39 @@ fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u3 #endif // INIT_SRC1_SHMEM_FLOAT #endif +#ifdef INIT_SRC0_SHMEM_Q1_0 +const BLOCK_SIZE = 128u; +const BLOCK_SIZE_BYTES = 18u; +const NQ = 8u; // 8 weights (1 byte of qs) per thread per iteration + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let tile_m = i / TILE_K; + let tile_k_start = i % TILE_K; + let global_m = offset_m + tile_m; + let global_k_start = k_outer + tile_k_start; + + if (global_m >= params.m) { + break; + } + + let block_k = global_k_start / BLOCK_SIZE; + let byte_in_block = (global_k_start % BLOCK_SIZE) / 8u; + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_at_src0(block_byte_base); + let q_byte = load_u32_at_src0(block_byte_base + 2u + byte_in_block) & 0xFFu; + + for (var bit = 0u; bit < NQ; bit++) { + let global_k = global_k_start + bit; + if (global_k < params.k) { + shmem[i + bit] = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + } + } + } +} +#endif // INIT_SRC0_SHMEM_Q1_0 + #ifdef INIT_SRC0_SHMEM_Q4_0 const BLOCK_SIZE = 32u; const BLOCK_SIZE_BYTES = 18u; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index c2eafee6c75..a8000439bfb 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -128,6 +128,38 @@ fn main( } #endif +#ifdef MUL_ACC_Q1_0 +#define BLOCK_SIZE 128 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 16 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu; + var row_sum = 0.0; + for (var bit = 0u; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + row_sum += w * x_block[bit]; + } + acc[row] += row_sum; + } + } + } +#endif + #ifdef MUL_ACC_Q4_0 #define BLOCK_SIZE 32 #define BLOCK_SIZE_BYTES 18 From 70e4c0aec058a27f7abf0df1dd7a9660ba3bd4a0 Mon Sep 17 00:00:00 2001 From: hipudding Date: Tue, 28 Apr 2026 14:27:22 +0800 Subject: [PATCH 506/831] CANN: add new ops, optimize existing ops (llama/21204) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit New operators: - GGML_OP_SET: implement via aclnnInplaceCopy on target region - GGML_OP_CUMSUM: implement via aclnnCumsum - GGML_OP_FILL: implement via aclnnInplaceFillScalar - GGML_OP_DIAG: implement via aclnnInplaceCopy on diagonal strides - GGML_OP_TRI (lower/lower_diag/upper_diag/upper): implement via aclnnTril(-1/0) and aclnnTriu(0/1) with appropriate diagonal offsets - GGML_OP_SOLVE_TRI: implement via aclnnTriangularSolve - GGML_UNARY_OP_SOFTPLUS: implement via aclnnSoftplus Optimizations: - GLU (SwiGLU/GeGLU/GeGLU_ERF/GeGLU_QUICK): fuse with aclnnSwiGlu / aclnnGeGluV3 when applicable; fallback conditions now checked inside each function rather than at the call site - CROSS_ENTROPY_LOSS: replace 5-kernel sequence (LogSoftmax→Mul→ ReduceSum×2→Muls) with single aclnnSoftmaxCrossEntropyWithLogits call - L2_NORM: fix in-place ClampMin on norm result (was clamping wrong tensor); add eps clamping before division to avoid divide-by-zero - PAD_REFLECT_1D: eliminate per-ne[3] loop; assert contiguity and call ReflectionPad1d once on the full 4-D view; remove redundant nb copies - GET_ROWS: replace IndexSelect with GatherV2 per batch slice; refactor helper into gather_batched lambda with batch loop inlined - SET_ROWS: replace IndexCopy with InplaceIndexCopy per batch slice; refactor helper into scatter_batched lambda with batch loop inlined - OUT_PROD: replace O(ne[3]*ne[2]*ne[1]) Ger+InplaceAdd loop with per-slice Matmul loop (src0 @ src1^T); handles strided-broadcast batch dims where ne02/ne03 may differ from ne2/ne3 - backend memset_tensor: implement via aclrtMemset (was NULL) Bug fixes: - COUNT_EQUAL: use non-inplace EqTensor into a same-type temporary buffer instead of InplaceEqTensor, avoiding corruption of src0 - ACL graph cache (USE_ACL_GRAPH): restore node_type and src_type[] fields in ggml_graph_node_properties; has_matching_properties() was missing type checks, causing F16 and BF16 tensors (same nb[0]=2) to incorrectly share cached graphs and produce wrong results (ERR≈679) - graph cache op_params matching: compare full GGML_MAX_OP_PARAMS bytes so that ops differing only in parameters are not incorrectly replayed from cache --- ggml/src/ggml-cann/aclnn_ops.cpp | 768 +++++++++++++++++++++---------- ggml/src/ggml-cann/aclnn_ops.h | 56 +++ ggml/src/ggml-cann/ggml-cann.cpp | 66 ++- 3 files changed, 628 insertions(+), 262 deletions(-) diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index a950475fc3b..2dc0f40917d 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -25,6 +25,7 @@ #include "ggml-impl.h" #include "ggml.h" + #include #include #include @@ -45,7 +46,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -62,6 +65,7 @@ #include #include #include +#include #include #include #include @@ -69,11 +73,15 @@ #include #include #include +#include #include #include #include #include +#include #include +#include +#include #include #include #include @@ -151,6 +159,107 @@ void ggml_cann_op_unary_gated(std::functionsrc[1] != nullptr || swapped != 0) { + ggml_cann_op_unary_gated(silu_fn, ctx, dst); + return; + } + + // aclnnSwiGlu requires the split dim (src->ne[0]) to be even; fall back otherwise. + if (dst->src[0]->ne[0] % 2 != 0) { + ggml_cann_op_unary_gated(silu_fn, ctx, dst); + return; + } + + ggml_tensor * src0 = dst->src[0]; + size_t elem_size = ggml_element_size(src0); + + // src0 GGML: [2*ne0, ne1, ne2, ne3] → 3D view [2*ne0, ne1, ne2*ne3] + // CANN reversed: [ne2*ne3, ne1, 2*ne0], split along CANN dim 2 (last). + int64_t ne0_x2 = src0->ne[0]; + int64_t ne1 = src0->ne[1]; + int64_t ne23 = src0->ne[2] * src0->ne[3]; + int64_t src3d_ne[] = { ne0_x2, ne1, ne23 }; + size_t src3d_nb[] = { (size_t)src0->nb[0], (size_t)src0->nb[1], (size_t)src0->nb[2] }; + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type), + elem_size, src3d_ne, src3d_nb, 3); + + // dst GGML: [ne0, ne1, ne2, ne3] → 3D view [ne0, ne1, ne2*ne3] + int64_t ne0 = dst->ne[0]; + int64_t dst3d_ne[] = { ne0, ne1, ne23 }; + size_t dst3d_nb[] = { (size_t)dst->nb[0], (size_t)dst->nb[1], (size_t)dst->nb[2] }; + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type), + elem_size, dst3d_ne, dst3d_nb, 3); + + // CANN tensor [ne23, ne1, 2*ne0]: split along CANN dim 2 (last) = 2*ne0. + GGML_CANN_CALL_ACLNN_OP(ctx, SwiGlu, acl_src.get(), (int64_t)2, acl_dst.get()); +} + +// Fused GeGLU using aclnnGeGluV3: splits input along ne[0] (CANN last dim), +// activates the LEFT half with GELU, multiplies by right half. +// approximate: 0=tanh, 1=none(erf). activateLeft=true matches GGML convention. +// outGelu is a required-but-discard output buffer. +// +// Falls back to the generic two-kernel path when src[1] != nullptr (two +// independent halves) or swapped != 0 (reversed activation order), as +// aclnnGeGluV3 only handles the single interleaved tensor in standard order. +void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate) { + auto gelu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, Gelu, acl_src, acl_dst); + }; + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + if (dst->src[1] != nullptr || swapped != 0) { + ggml_cann_op_unary_gated(gelu_fn, ctx, dst); + return; + } + + // aclnnGeGluV3 requires the split dim (src->ne[0]) to be even; fall back otherwise. + if (dst->src[0]->ne[0] % 2 != 0) { + ggml_cann_op_unary_gated(gelu_fn, ctx, dst); + return; + } + + ggml_tensor * src0 = dst->src[0]; + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + + // Allocate a temporary buffer for the required outGelu output (same shape as dst). + // Build contiguous strides since the pool allocation is a fresh buffer. + size_t elem_size = ggml_element_size(dst); + int64_t ne[GGML_MAX_DIMS] = { dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3] }; + size_t nb[GGML_MAX_DIMS]; + nb[0] = elem_size; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + nb[i] = nb[i - 1] * ne[i - 1]; + } + size_t gelu_out_size = nb[GGML_MAX_DIMS - 1] * ne[GGML_MAX_DIMS - 1]; + ggml_cann_pool_alloc gelu_out_alloc(ctx.pool(), gelu_out_size); + + acl_tensor_ptr acl_gelu_out = ggml_cann_create_tensor( + gelu_out_alloc.get(), ggml_cann_type_mapping(dst->type), elem_size, ne, nb, GGML_MAX_DIMS); + // V3 adds activateLeft param; true → Gelu(left)*right, matching GGML convention. + // GGML dim 0 → CANN last dim (index GGML_MAX_DIMS-1 = 3 for 4D tensor). + GGML_CANN_CALL_ACLNN_OP(ctx, GeGluV3, acl_src.get(), (int64_t)(GGML_MAX_DIMS - 1), approximate, true, + acl_dst.get(), acl_gelu_out.get()); +} + /** * @brief Repeats elements of a tensor along each dimension according to the * specified repeat array. @@ -445,28 +554,33 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes); void * buffer = temp_buffer_allocator.get(); - int64_t div_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] }; - size_t div_nb[GGML_MAX_DIMS]; - div_nb[0] = sizeof(float); + int64_t norm_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] }; + size_t norm_nb[GGML_MAX_DIMS]; + norm_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; ++i) { - div_nb[i] = div_nb[i - 1] * div_ne[i - 1]; + norm_nb[i] = norm_nb[i - 1] * norm_ne[i - 1]; } - acl_tensor_ptr acl_div = ggml_cann_create_tensor(buffer, ACL_FLOAT, type_size, div_ne, div_nb, GGML_MAX_DIMS); + acl_tensor_ptr acl_norm = ggml_cann_create_tensor(buffer, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS); std::vector norm_dims = { 3 }; acl_int_array_ptr dims_array = ggml_cann_create_int_array(norm_dims.data(), norm_dims.size()); float p_value = 2.0f; acl_scalar_ptr p_scalar = ggml_cann_create_scalar(&p_value, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_div.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_norm.get()); + + ggml_cann_pool_alloc clamp_buffer_allocator(ctx.pool()); + acl_tensor_ptr acl_clamped; - // Clamp norm to at least eps: scale = 1/fmaxf(norm, eps) - acl_scalar_ptr acl_min = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT); - float flt_max = FLT_MAX; - acl_scalar_ptr acl_max = ggml_cann_create_scalar(&flt_max, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, Clamp, acl_div.get(), acl_min.get(), acl_max.get(), acl_div.get()); + if (eps > 0.0f) { + void * clamp_buf = clamp_buffer_allocator.alloc(n_bytes); + acl_clamped = ggml_cann_create_tensor(clamp_buf, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS); + acl_scalar_ptr eps_scalar = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, ClampMin, acl_norm.get(), eps_scalar.get(), acl_clamped.get()); + } - GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div.get(), acl_dst.get()); + aclTensor * acl_div_input = acl_clamped ? acl_clamped.get() : acl_norm.get(); + GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div_input, acl_dst.get()); } void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst) { @@ -482,56 +596,30 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * logits_nb[1] = logits_nb[0] * logits_ne[0]; acl_tensor_ptr acl_logits = ggml_cann_create_tensor(src0->data, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2); - size_t log_softmax_type_size = sizeof(float); - int64_t log_softmax_n_bytes = nr * nc * log_softmax_type_size; - ggml_cann_pool_alloc log_softmax_allocator(ctx.pool(), log_softmax_n_bytes); - void * log_softmax_buffer = log_softmax_allocator.get(); - - int64_t log_softmax_ne[] = { nc, nr }; - size_t log_softmax_nb[2]; - log_softmax_nb[0] = log_softmax_type_size; - log_softmax_nb[1] = log_softmax_nb[0] * log_softmax_ne[0]; - acl_tensor_ptr acl_log_softmax = ggml_cann_create_tensor(log_softmax_buffer, ACL_FLOAT, log_softmax_type_size, - log_softmax_ne, log_softmax_nb, 2); - - GGML_CANN_CALL_ACLNN_OP(ctx, LogSoftmax, acl_logits.get(), 1, acl_log_softmax.get()); - int64_t labels_ne[] = { nc, nr }; size_t labels_nb[2]; labels_nb[0] = ggml_type_size(src1->type); labels_nb[1] = labels_nb[0] * labels_ne[0]; acl_tensor_ptr acl_labels = ggml_cann_create_tensor(src1->data, ACL_FLOAT, sizeof(float), labels_ne, labels_nb, 2); - size_t mul_type_size = sizeof(float); - int64_t mul_n_bytes = nr * nc * mul_type_size; - ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_n_bytes); - void * mul_buffer = mul_allocator.get(); - - int64_t mul_ne[] = { nc, nr }; - size_t mul_nb[2]; - mul_nb[0] = mul_type_size; - mul_nb[1] = mul_nb[0] * mul_ne[0]; - acl_tensor_ptr acl_mul_result = ggml_cann_create_tensor(mul_buffer, ACL_FLOAT, mul_type_size, mul_ne, mul_nb, 2); - - GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_log_softmax.get(), acl_labels.get(), acl_mul_result.get()); + size_t loss_per_sample_type_size = sizeof(float); + int64_t loss_per_sample_n_bytes = nr * loss_per_sample_type_size; + ggml_cann_pool_alloc loss_per_sample_allocator(ctx.pool(), loss_per_sample_n_bytes); + void * loss_per_sample_buffer = loss_per_sample_allocator.get(); - size_t sum_per_sample_type_size = sizeof(float); - int64_t sum_per_sample_n_bytes = nr * sum_per_sample_type_size; - ggml_cann_pool_alloc sum_per_sample_allocator(ctx.pool(), sum_per_sample_n_bytes); - void * sum_per_sample_buffer = sum_per_sample_allocator.get(); + int64_t loss_per_sample_ne[] = { nr }; + size_t loss_per_sample_nb[1]; + loss_per_sample_nb[0] = loss_per_sample_type_size; + acl_tensor_ptr acl_loss_per_sample = ggml_cann_create_tensor( + loss_per_sample_buffer, ACL_FLOAT, loss_per_sample_type_size, loss_per_sample_ne, loss_per_sample_nb, 1); - int64_t sum_per_sample_ne[] = { nr }; - size_t sum_per_sample_nb[1]; - sum_per_sample_nb[0] = sum_per_sample_type_size; - acl_tensor_ptr acl_sum_per_sample = ggml_cann_create_tensor( - sum_per_sample_buffer, ACL_FLOAT, sum_per_sample_type_size, sum_per_sample_ne, sum_per_sample_nb, 1); + size_t backprop_n_bytes = nr * nc * sizeof(float); + ggml_cann_pool_alloc backprop_allocator(ctx.pool(), backprop_n_bytes); + void * backprop_buffer = backprop_allocator.get(); + acl_tensor_ptr acl_backprop = ggml_cann_create_tensor(backprop_buffer, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2); - std::vector sum_dims = { 1 }; - acl_int_array_ptr dims_array = ggml_cann_create_int_array(sum_dims.data(), sum_dims.size()); - bool keep_dims = false; - - GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_mul_result.get(), dims_array.get(), keep_dims, ACL_FLOAT, - acl_sum_per_sample.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, SoftmaxCrossEntropyWithLogits, acl_logits.get(), acl_labels.get(), + acl_loss_per_sample.get(), acl_backprop.get()); size_t total_sum_type_size = sizeof(float); int64_t total_sum_n_bytes = 1 * total_sum_type_size; @@ -547,11 +635,12 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * std::vector total_sum_dims = { 0 }; acl_int_array_ptr total_sum_dims_array = ggml_cann_create_int_array(total_sum_dims.data(), total_sum_dims.size()); + bool keep_dims = false; - GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_sum_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT, + GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_loss_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT, acl_total_sum.get()); - float value = -1.0f / static_cast(nr); + float value = 1.0f / static_cast(nr); acl_scalar_ptr scale_factor = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT); acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, sizeof(float), total_sum_ne, total_sum_nb, 1); @@ -589,6 +678,33 @@ void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { acl_mean_out.get(), acl_rstd_out.get()); } +void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + size_t param_nb[] = { ggml_element_size(src0), nb1, nb2, nb3 }; + + // Create a view of dst at the target offset with src1's dimensions + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset); + acl_tensor_ptr acl_src1 = ggml_cann_create_tensor(src1); + + if (!inplace) { + // First copy src0 to dst entirely + size_t cpy_size = ggml_nbytes(dst); + ACL_CHECK( + aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + } + + // Copy src1 into the target region of dst + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst.get(), acl_src1.get()); +} + void ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; @@ -652,6 +768,113 @@ void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst) { aclnn_reduce_sum(ctx, dst, reduce_dims, 4); } +void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + // GGML cumsum operates along dim 0 (innermost / ne[0]). + // ggml_cann_create_tensor reverses dimensions to [ne3,ne2,ne1,ne0], + // so GGML dim 0 maps to CANN dim 3 (the last dim of the 4-D tensor). + GGML_CANN_CALL_ACLNN_OP(ctx, Cumsum, acl_src.get(), (int64_t)3, + ggml_cann_type_mapping(dst->type), acl_dst.get()); +} + +void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; // A: [N, N, B2, B3] lower triangular + ggml_tensor * src1 = dst->src[1]; // B: [K, N, B2, B3] + + acl_tensor_ptr acl_a = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_b = ggml_cann_create_tensor(src1); + acl_tensor_ptr acl_x = ggml_cann_create_tensor(dst); + + // mOut: triangular copy of A (required output), same shape as A. + const size_t a_bytes = ggml_nbytes(src0); + ggml_cann_pool_alloc m_alloc(ctx.pool(), a_bytes); + acl_tensor_ptr acl_m = ggml_cann_create_tensor( + m_alloc.get(), ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS); + + // Solve AX = B: upper=false (lower tri), transpose=false, unitriangular=false. + GGML_CANN_CALL_ACLNN_OP(ctx, TriangularSolve, + acl_b.get(), acl_a.get(), false, false, false, + acl_x.get(), acl_m.get()); +} + +void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + + GGML_ASSERT(src->ne[1] == 1); + + const int64_t N = src->ne[0]; + const int64_t n_batch = src->ne[2] * src->ne[3]; + const size_t nb_f32 = sizeof(float); + + // Fill dst with zeros. + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + { + float zero = 0.0f; + acl_scalar_ptr acl_zero = ggml_cann_create_scalar(&zero, ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_zero.get()); + } + + // Copy src vector onto the diagonal of dst via strided views. + // src viewed as [N, n_batch], contiguous strides. + int64_t ne_vec[2] = { N, n_batch }; + size_t nb_src_vec[2] = { nb_f32, N * nb_f32 }; + // dst diagonal view: stride (N+1)*4 steps along the diagonal. + size_t nb_dst_diag[2] = { (N + 1) * nb_f32, N * N * nb_f32 }; + + acl_tensor_ptr acl_src_vec = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne_vec, nb_src_vec, 2); + acl_tensor_ptr acl_dst_diag = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne_vec, nb_dst_diag, 2); + + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst_diag.get(), acl_src_vec.get()); +} + +void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + float c = ggml_get_op_params_f32(dst, 0); + + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + acl_scalar_ptr acl_c = ggml_cann_create_scalar(&c, ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_c.get()); +} + +void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + + const int64_t S = src->ne[0]; + const int64_t n_batch = src->ne[2] * src->ne[3]; + const size_t nb_f32 = sizeof(float); + + int64_t ne3d[3] = { S, S, n_batch }; + size_t nb3d[3] = { nb_f32, S * nb_f32, S * S * nb_f32 }; + + const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0); + + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3); + + switch (ttype) { + case GGML_TRI_TYPE_LOWER: + // Tril(-1): preserve row > col (strict lower), zero upper + diagonal. + GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)-1, acl_dst.get()); + break; + case GGML_TRI_TYPE_UPPER_DIAG: + // Triu(0): preserve row <= col (upper + diagonal), zero strict lower. + GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)0, acl_dst.get()); + break; + case GGML_TRI_TYPE_UPPER: + // Triu(1): preserve row < col (strict upper), zero lower + diagonal. + GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)1, acl_dst.get()); + break; + case GGML_TRI_TYPE_LOWER_DIAG: + // Tril(0): preserve row >= col (lower + diagonal), zero strict upper. + GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)0, acl_dst.get()); + break; + default: + GGML_ABORT("unsupported tri type"); + } +} + void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src = dst->src[0]; acl_tensor_ptr acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW); @@ -1695,152 +1918,90 @@ void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { aclnn_softmax(ctx, softmax_tensor.get(), 3, acl_dst.get()); } -/** - * @brief Performs index select operation on a 4D tensor using the CANN backend. - * - * This function applies the `IndexSelect` operation along a specific dimension - * of the source tensor (`src_buffer`) using the indices from the index tensor (`index`). - * It iterates over the last two dimensions of the source tensor, creates the corresponding - * CANN tensors for the source, index, and output slices, and executes the `IndexSelect` - * operation for each slice. - * - * @param ctx The context for CANN backend operations. - * @param src_buffer The source buffer containing the 4D input tensor data. - * @param src_ne The dimensions of the source tensor. - * @param src_nb The strides (byte offsets) of the source tensor. - * @param dst_buffer The destination buffer where the output tensor data will be written. - * @param dst_ne The dimensions of the destination tensor. - * @param dst_nb The strides (byte offsets) of the destination tensor. - * @param index The index tensor specifying the indices to select from the source tensor. - * @param type The data type of the source and destination tensors. - */ -static void aclnn_index_select_4d(ggml_backend_cann_context & ctx, - void * src_buffer, - int64_t * src_ne, - size_t * src_nb, - void * dst_buffer, - int64_t * dst_ne, - size_t * dst_nb, - ggml_tensor * index, - ggml_type type) { - for (int64_t i = 0; i < src_ne[3]; i++) { - for (int64_t j = 0; j < src_ne[2]; j++) { - // src - acl_tensor_ptr acl_src_tensor = - ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2); - - // index - acl_tensor_ptr acl_index = ggml_cann_create_tensor( - (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1], - ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1); - - // out - acl_tensor_ptr acl_out = - ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2); - GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, acl_src_tensor.get(), 0, acl_index.get(), acl_out.get()); - } - } -} - -/** - * @brief Performs inplace index copy operation on a 4D tensor using the CANN backend. - * - * This function applies the `IndexCopy` operation along a specific dimension of the - * destination tensor (`dst_buffer`) by copying elements from the source tensor (`src_buffer`) - * to positions specified by the index tensor (`index`). - * It iterates over the last two dimensions of the tensors, creates the corresponding - * CANN tensors for source, index, and destination slices, and performs the index copy - * operation for each slice. - * - * @param ctx The context for CANN backend operations. - * @param src_buffer The source buffer containing the 4D input tensor data to be copied. - * @param src_ne The dimensions of the source tensor. - * @param src_nb The strides (byte offsets) of the source tensor. - * @param dst_buffer The destination buffer where values will be copied to. - * @param dst_ne The dimensions of the destination tensor. - * @param dst_nb The strides (byte offsets) of the destination tensor. - * @param index The index tensor specifying target positions in the destination tensor. - * @param type The data type of the source and destination tensors. - */ -static void aclnn_index_copy_4d(ggml_backend_cann_context & ctx, - void * src_buffer, - int64_t * src_ne, - size_t * src_nb, - void * dst_buffer, - int64_t * dst_ne, - size_t * dst_nb, - ggml_tensor * index, - ggml_type type) { - for (int64_t i = 0; i < src_ne[3]; i++) { - for (int64_t j = 0; j < src_ne[2]; j++) { - // src - acl_tensor_ptr acl_src_tensor = - ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2); - - // index - acl_tensor_ptr acl_index = ggml_cann_create_tensor( - (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1], - ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1); - - // out - acl_tensor_ptr acl_out = - ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_out.get(), 0, acl_index.get(), acl_src_tensor.get()); - } - } -} void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src0 = dst->src[0]; // src + ggml_tensor * src0 = dst->src[0]; // weight ggml_tensor * src1 = dst->src[1]; // index GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16); + // n_idx: number of row indices per (i2, i3) batch slice. + // ggml guarantees: src0->ne[2] == src1->ne[1], src0->ne[3] == src1->ne[2], src1->ne[3] == 1. + const int64_t n_idx = src1->ne[0]; + + // Gather all (i2, i3) batch slices from src into dst. + // ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0]. + // GatherV2 with dim=0 gathers along ACL dim-0 == ggml ne[1] (the vocabulary / row axis). + // nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape, + // nb[2..3] for computing per-batch-slice base pointer offsets). + auto gather_batched = [&](void * src_base, aclDataType acl_type, size_t type_size, + const size_t * nb) { + int64_t src_ne[2] = { src0->ne[0], src0->ne[1] }; + size_t src_nb_2d[2] = { nb[0], nb[1] }; + int64_t dst_ne[2] = { src0->ne[0], n_idx }; + size_t dst_nb_2d[2] = { dst->nb[0], dst->nb[1] }; + int64_t idx_ne[1] = { n_idx }; + size_t idx_nb[1] = { (size_t)ggml_element_size(src1) }; + + for (int64_t i3 = 0; i3 < src0->ne[3]; i3++) { + for (int64_t i2 = 0; i2 < src0->ne[2]; i2++) { + acl_tensor_ptr acl_src = ggml_cann_create_tensor( + (char *)src_base + i3 * nb[3] + i2 * nb[2], + acl_type, type_size, src_ne, src_nb_2d, 2); + acl_tensor_ptr acl_idx = ggml_cann_create_tensor( + (char *)src1->data + i3 * src1->nb[2] + i2 * src1->nb[1], + ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1), + idx_ne, idx_nb, 1); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor( + (char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2], + acl_type, type_size, dst_ne, dst_nb_2d, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, GatherV2, acl_src.get(), 0, acl_idx.get(), acl_dst.get()); + } + } + }; + switch (src0->type) { case GGML_TYPE_BF16: case GGML_TYPE_F16: case GGML_TYPE_F32: if (src0->type == dst->type) { - aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, - dst->type); + gather_batched(src0->data, + ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type), + src0->nb); } else { - acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); - ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst)); - void * src_trans_buffer = src_buffer_allocator.get(); - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = dst->nb[0]; + // Cast src0 to dst type, then gather. + ggml_cann_pool_alloc src_cast_allocator(ctx.pool(), + ggml_nelements(src0) * ggml_element_size(dst)); + size_t src_cast_nb[GGML_MAX_DIMS]; + src_cast_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1]; } - acl_tensor_ptr src_trans_tensor = - ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type)); - aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1, - dst->type); + acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor( + src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src0->ne, src_cast_nb, GGML_MAX_DIMS); + aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type)); + + gather_batched(src_cast_allocator.get(), + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src_cast_nb); } break; case GGML_TYPE_Q8_0: { - // add 1 dim for bcast mul. + // Dequantize Q8_0 to dst type, then gather. size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], dequant_nb[GGML_MAX_DIMS + 1]; int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], *dequant_ne; - int64_t scale_offset = 0; - // [3,4,5,64] -> [3,4,5,2,32] - weight_ne[0] = QK8_0; - weight_ne[1] = src0->ne[0] / QK8_0; - weight_nb[0] = sizeof(int8_t); - weight_nb[1] = weight_nb[0] * weight_ne[0]; + weight_ne[0] = QK8_0; + weight_ne[1] = src0->ne[0] / QK8_0; + weight_nb[0] = sizeof(int8_t); + weight_nb[1] = weight_nb[0] * weight_ne[0]; for (int i = 2; i < GGML_MAX_DIMS + 1; i++) { weight_ne[i] = src0->ne[i - 1]; weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1]; } - // [3,4,5,64] -> [3,4,5,2,1] scale_ne[0] = 1; scale_ne[1] = src0->ne[0] / QK8_0; scale_nb[0] = sizeof(uint16_t); @@ -1849,31 +2010,33 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { scale_ne[i] = src0->ne[i - 1]; scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1]; } - // [3,4,5,64] -> [3,4,5,2,32] dequant_ne = weight_ne; dequant_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS + 1; i++) { dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1]; } - scale_offset = ggml_nelements(src0) * sizeof(int8_t); - ggml_cann_pool_alloc dequant_buffer_allocator(ctx.pool(), - ggml_nelements(src0) * ggml_type_size(dst->type)); - acl_tensor_ptr acl_weight_tensor = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), - weight_ne, weight_nb, GGML_MAX_DIMS + 1); - acl_tensor_ptr acl_scale_tensor = - ggml_cann_create_tensor(src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, - GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); - acl_tensor_ptr dequant_tensor = - ggml_cann_create_tensor(dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); - aclnn_mul(ctx, acl_weight_tensor.get(), acl_scale_tensor.get(), dequant_tensor.get()); - dequant_nb[0] = ggml_type_size(dst->type); + const int64_t scale_offset = ggml_nelements(src0) * sizeof(int8_t); + ggml_cann_pool_alloc dequant_allocator(ctx.pool(), + ggml_nelements(src0) * ggml_type_size(dst->type)); + acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), + weight_ne, weight_nb, GGML_MAX_DIMS + 1); + acl_tensor_ptr acl_scale = ggml_cann_create_tensor( + src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, + GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); + acl_tensor_ptr acl_dequant = ggml_cann_create_tensor( + dequant_allocator.get(), ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); + aclnn_mul(ctx, acl_weight.get(), acl_scale.get(), acl_dequant.get()); + + // Reinterpret dequant buffer as 4D [src0->ne] with contiguous strides. dequant_ne = src0->ne; + dequant_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1]; } - aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, dst->data, dst->ne, - dst->nb, src1, dst->type); + gather_batched(dequant_allocator.get(), + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + dequant_nb); break; } default: @@ -1883,31 +2046,70 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { } void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src0 = dst->src[0]; // src - ggml_tensor * src1 = dst->src[1]; // index + ggml_tensor * src0 = dst->src[0]; // source values + ggml_tensor * src1 = dst->src[1]; // row indices + + // n_idx: number of source rows to scatter per batch slice. + // ggml guarantees: src0->ne[1] == src1->ne[0]. + const int64_t n_idx = src1->ne[0]; + + // Copy n_idx rows of src [ne0, n_idx] into dst [ne0, ne1] at positions given by a 1D index. + // ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0] for dst. + // InplaceIndexCopy with dim=0 copies along ACL dim-0 == ggml ne[1] (the row axis). + // src_nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape, + // nb[2..3] for computing per-batch-slice base pointer offsets). + auto scatter_batched = [&](void * src_base, aclDataType acl_type, size_t type_size, + const size_t * src_nb) { + int64_t d_ne[2] = { dst->ne[0], dst->ne[1] }; + size_t d_nb[2] = { dst->nb[0], dst->nb[1] }; + int64_t s_ne[2] = { dst->ne[0], n_idx }; + size_t s_nb_2d[2] = { src_nb[0], src_nb[1] }; + int64_t i_ne[1] = { n_idx }; + size_t i_nb[1] = { (size_t)ggml_element_size(src1) }; + + for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) { + for (int64_t i2 = 0; i2 < dst->ne[2]; i2++) { + acl_tensor_ptr acl_dst = ggml_cann_create_tensor( + (char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2], + acl_type, type_size, d_ne, d_nb, 2); + acl_tensor_ptr acl_idx = ggml_cann_create_tensor( + (char *)src1->data + (i3 % src1->ne[2]) * src1->nb[2] + (i2 % src1->ne[1]) * src1->nb[1], + ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1), + i_ne, i_nb, 1); + acl_tensor_ptr acl_src = ggml_cann_create_tensor( + (char *)src_base + i3 * src_nb[3] + i2 * src_nb[2], + acl_type, type_size, s_ne, s_nb_2d, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_dst.get(), 0, acl_idx.get(), acl_src.get()); + } + } + }; switch (dst->type) { case GGML_TYPE_F32: - { - aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, dst->type); - break; - } + scatter_batched(src0->data, + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src0->nb); + break; case GGML_TYPE_F16: case GGML_TYPE_BF16: { - acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); - ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t)); - void * src_trans_buffer = src_buffer_allocator.get(); - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = sizeof(uint16_t); + // Cast src0 (F32) to dst type first. + ggml_cann_pool_alloc src_cast_allocator(ctx.pool(), + ggml_nelements(src0) * ggml_type_size(dst->type)); + size_t src_cast_nb[GGML_MAX_DIMS]; + src_cast_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1]; } - acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type)); - aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1, - dst->type); + acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor( + src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src0->ne, src_cast_nb, GGML_MAX_DIMS); + aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type)); + + scatter_batched(src_cast_allocator.get(), + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src_cast_nb); break; } default: @@ -3268,29 +3470,50 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst int64_t paddingsArray[2] = { opts[0], opts[1] }; acl_int_array_ptr paddings = ggml_cann_create_int_array(paddingsArray, 2); - for (int64_t i = 0; i < src0->ne[3]; i++) { - acl_tensor_ptr acl_src = - ggml_cann_create_tensor((char *) src0->data + i * src0->ne[3], ggml_cann_type_mapping(src0->type), - ggml_element_size(src0), src0->ne, src0->nb, 3); + // Collapsing ne[2]*ne[3] into a single batch dimension requires that dim3 + // is contiguous with respect to dim2 in both src and dst. + GGML_ASSERT(src0->nb[3] == src0->nb[2] * src0->ne[2]); + GGML_ASSERT(dst->nb[3] == dst->nb[2] * dst->ne[2]); - acl_tensor_ptr acl_dst = - ggml_cann_create_tensor((char *) dst->data + i * src0->ne[3], ggml_cann_type_mapping(dst->type), - ggml_element_size(dst), dst->ne, dst->nb, 3); + int64_t src_ne_3d[3] = { src0->ne[0], src0->ne[1], src0->ne[2] * src0->ne[3] }; + int64_t dst_ne_3d[3] = { dst->ne[0], dst->ne[1], dst->ne[2] * dst->ne[3] }; - GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get()); - } + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type), + ggml_element_size(src0), src_ne_3d, src0->nb, 3); + + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type), + ggml_element_size(dst), dst_ne_3d, dst->nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get()); } void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; + // Write element-wise equality (0 or 1) into a temporary buffer to avoid + // modifying src0 in-place. Use the same type as src0 so ReduceSum can + // consume it directly without a type cast. + ggml_cann_pool_alloc eq_alloc(ctx.pool(), ggml_nelements(src0) * ggml_element_size(src0)); + size_t eq_nb[GGML_MAX_DIMS]; + eq_nb[0] = ggml_element_size(src0); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + eq_nb[i] = eq_nb[i - 1] * src0->ne[i - 1]; + } + acl_tensor_ptr acl_eq = ggml_cann_create_tensor( + eq_alloc.get(), ggml_cann_type_mapping(src0->type), ggml_element_size(src0), + src0->ne, eq_nb, GGML_MAX_DIMS); + acl_tensor_ptr acl_self = ggml_cann_create_tensor(src0); acl_tensor_ptr acl_other = ggml_cann_create_tensor(src1); - - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self.get(), acl_other.get()); - - ggml_cann_sum(ctx, dst); + GGML_CANN_CALL_ACLNN_OP(ctx, EqTensor, acl_self.get(), acl_other.get(), acl_eq.get()); + + // Sum the 0/1 values into dst. + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + int64_t dims[4] = { 0, 1, 2, 3 }; + acl_int_array_ptr dims_arr = ggml_cann_create_int_array(dims, 4); + GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_eq.get(), dims_arr.get(), true, + ggml_cann_type_mapping(dst->type), acl_dst.get()); } void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) { @@ -3306,6 +3529,27 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) { GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src.get(), alpha.get(), acl_dst.get()); } +void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + + float beta_val = 1.0f; + float threshold_val = 20.0f; + acl_scalar_ptr beta = ggml_cann_create_scalar(&beta_val, ACL_FLOAT); + acl_scalar_ptr threshold = ggml_cann_create_scalar(&threshold_val, ACL_FLOAT); + + GGML_CANN_CALL_ACLNN_OP(ctx, Softplus, acl_src.get(), beta.get(), threshold.get(), acl_dst.get()); +} + +void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + auto gelu_quick_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); + }; + ggml_cann_op_unary_gated(gelu_quick_fn, ctx, dst); +} + /** * @brief Performs expert-specific matrix multiplication (MoE) with * floating-point precision using the CANN backend. @@ -3892,46 +4136,65 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst } static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src0 = dst->src[0]; // weight - ggml_tensor * src1 = dst->src[1]; // input + ggml_tensor * src0 = dst->src[0]; // weight [ne00=m, ne01=K, ne02, ne03] + ggml_tensor * src1 = dst->src[1]; // input [ne10=n, ne11=K, ne12, ne13] GGML_TENSOR_BINARY_OP_LOCALS - acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get()); + // dst[i,j] = sum_k src0[i,k] * src1[j,k] i.e. dst = src0 @ src1^T. + // + // ggml_cann_create_tensor reverses dimension order, so ACL sees: + // acl_src0 slice: ggml[m,K] -> ACL[K,m] + // acl_src1 slice: ggml[n,K] -> ACL[K,n] + // acl_dst slice: ggml[m,n] -> ACL[n,m] + // + // Build a transposed view of src1 by swapping ne[0]/ne[1]: + // src1_t: ggml[K,n] (swapped strides) -> ACL[n,K] + // + // Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst ✓ + // + // The outer batch loop is kept because src0 may have fewer batch slices than + // dst (ne02 <= ne2, ne03 <= ne3): this is a strided-broadcast not supported + // by standard CANN Matmul broadcasting. + + const aclDataType src0_acl_type = ggml_cann_type_mapping(src0->type); + const aclDataType src1_acl_type = ggml_cann_type_mapping(src1->type); + const aclDataType dst_acl_type = ggml_cann_type_mapping(dst->type); + const size_t src0_type_sz = ggml_type_size(src0->type); + const size_t src1_type_sz = ggml_type_size(src1->type); + const size_t dst_type_sz = ggml_type_size(dst->type); const int64_t dps2 = ne2 / ne02; const int64_t dps3 = ne3 / ne03; + for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = 0; i2 < ne2; i2++) { const int64_t i02 = i2 / dps2; const int64_t i03 = i3 / dps3; - const int64_t i12 = i2; - const int64_t i13 = i3; - acl_tensor_ptr accumulator = - ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), dst->ne, dst->nb, 2); - - // The outer product needs to be accumulated in this dimension. - for (int64_t i1 = 0; i1 < ne11; i1++) { - acl_tensor_ptr acl_input = ggml_cann_create_tensor( - (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src1->ne, src1->nb, 1); - - acl_tensor_ptr acl_weight = ggml_cann_create_tensor( - (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src0->ne, src0->nb, 1); - - ggml_cann_pool_alloc output_allocator(ctx.pool()); - void * output_buffer = output_allocator.alloc(ggml_nbytes(dst)); - acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), dst->ne, dst->nb, 2); - - GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get()); - float alpha_value = 1.0f; - aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha); - } + // src0 2D slice at [i02, i03]: ggml [m, K] -> ACL [K, m] + int64_t src0_ne[2] = { ne00, ne01 }; + size_t src0_nb[2] = { nb00, nb01 }; + acl_tensor_ptr acl_src0_s = ggml_cann_create_tensor( + (char *) src0->data + i02 * nb02 + i03 * nb03, + src0_acl_type, src0_type_sz, src0_ne, src0_nb, 2); + + // src1 transposed 2D slice at [i2, i3]: swap ne/nb -> ggml[K,n] -> ACL[n,K] + int64_t src1_t_ne[2] = { ne11, ne10 }; + size_t src1_t_nb[2] = { nb11, nb10 }; + acl_tensor_ptr acl_src1_t = ggml_cann_create_tensor( + (char *) src1->data + i2 * nb12 + i3 * nb13, + src1_acl_type, src1_type_sz, src1_t_ne, src1_t_nb, 2); + + // dst 2D slice at [i2, i3]: ggml [m, n] -> ACL [n, m] + int64_t dst_ne[2] = { ne0, ne1 }; + size_t dst_nb[2] = { nb0, nb1 }; + acl_tensor_ptr acl_dst_s = ggml_cann_create_tensor( + (char *) dst->data + i2 * nb2 + i3 * nb3, + dst_acl_type, dst_type_sz, dst_ne, dst_nb, 2); + + // Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst_s ✓ + GGML_CANN_CALL_ACLNN_OP(ctx, Matmul, + acl_src1_t.get(), acl_src0_s.get(), acl_dst_s.get(), (int8_t) 1); } } } @@ -4170,3 +4433,4 @@ void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * } } } + diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 7f5ba4d3302..cdbf9260f85 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -32,6 +32,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -47,6 +50,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -69,6 +75,9 @@ */ void ggml_cann_repeat(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_swiglu(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate); + /** * @brief Applies the Leaky ReLU activation function to a tensor using the CANN * backend. @@ -325,6 +334,48 @@ void ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst); void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst); +/** + * @brief Computes the cumulative sum of a ggml tensor along dim 0 using the + * CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_CUMSUM`. + */ +void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Computes a triangular mask (tril/triu) of a square ggml tensor + * using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_TRI`. + */ +void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Solves a triangular linear system AX=B using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_SOLVE_TRI`. + */ +void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Creates a diagonal matrix from a vector using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_DIAG`. + */ +void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Fills a tensor with a constant scalar value using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_FILL`. + */ +void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Upsamples a ggml tensor using nearest neighbor interpolation using * the CANN backend. @@ -461,6 +512,9 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context & ctx, ggml_tensor * // @see ggml_cann_dup. void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst); +// @see ggml_cann_acc, but copies src1 into dst instead of adding. +void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Computes the softmax activation with optional masking. * @@ -813,6 +867,8 @@ void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst); * dst->op is expected to be `GGML_OP_STEP`. */ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Performs the Flash Attention extended operator using the CANN backend. diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 5fc484b342b..3618ba7f6f6 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1428,6 +1428,22 @@ static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer, return false; } +/** + * @brief Set a region of a tensor's device memory to a specified value. + * + * @param buffer The CANN buffer containing the tensor. + * @param tensor Pointer to the tensor whose memory will be set. + * @param value The value to which each byte in the region will be set. + * @param offset Byte offset within the tensor's data to start setting. + * @param size Number of bytes to set. + */ +static void ggml_backend_cann_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context; + + ggml_cann_set_device(ctx->device); + ACL_CHECK(aclrtMemset((char *) tensor->data + offset, size, value, size)); +} + /** * @brief Clear a CANN buffer by setting all its memory to a specified value. * @@ -1454,7 +1470,7 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer, /* .get_base = */ ggml_backend_cann_buffer_get_base, /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor, - /* .memset_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_cann_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor, /* .set_tensor_2d = */ NULL, @@ -1835,6 +1851,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_UNARY_OP_STEP: ggml_cann_step(ctx, dst); break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_cann_softplus(ctx, dst); + break; default: return false; } @@ -1845,20 +1864,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg GGML_CANN_CALL_OP_UNARY_GATED(Relu); break; case GGML_GLU_OP_GEGLU: + ggml_cann_geglu(ctx, dst, 0); // approximate=0 → tanh + break; case GGML_GLU_OP_GEGLU_ERF: - // aclnnGelu internally uses the erf-based approximation. - GGML_CANN_CALL_OP_UNARY_GATED(Gelu); + ggml_cann_geglu(ctx, dst, 1); // approximate=1 → erf break; case GGML_GLU_OP_SWIGLU: - GGML_CANN_CALL_OP_UNARY_GATED(Silu); + ggml_cann_swiglu(ctx, dst); break; case GGML_GLU_OP_GEGLU_QUICK: - { - auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { - GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); - }; - ggml_cann_op_unary_gated(lambda, ctx, dst); - } + ggml_cann_geglu_quick(ctx, dst); break; default: return false; @@ -1920,6 +1935,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_CPY: ggml_cann_cpy(ctx, dst); break; + case GGML_OP_SET: + ggml_cann_set(ctx, dst); + break; case GGML_OP_CONT: ggml_cann_dup(ctx, dst); break; @@ -1989,6 +2007,21 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_SSM_CONV: ggml_cann_ssm_conv(ctx, dst); break; + case GGML_OP_CUMSUM: + ggml_cann_cumsum(ctx, dst); + break; + case GGML_OP_TRI: + ggml_cann_tri(ctx, dst); + break; + case GGML_OP_FILL: + ggml_cann_fill(ctx, dst); + break; + case GGML_OP_DIAG: + ggml_cann_diag(ctx, dst); + break; + case GGML_OP_SOLVE_TRI: + ggml_cann_solve_tri(ctx, dst); + break; default: return false; } @@ -2324,6 +2357,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, if (use_cann_graph) { // If no matching graph is found, the graph needs to be recaptured. graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph); + if (graph_capture_required) { // If no matching graph is found, add a new ACL graph. ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph); @@ -2382,6 +2416,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_UNARY_OP_SGN: case GGML_UNARY_OP_STEP: case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_SOFTPLUS: return true; default: return false; @@ -2572,6 +2607,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: + case GGML_OP_SET: case GGML_OP_GROUP_NORM: return true; case GGML_OP_PAD: @@ -2649,6 +2685,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten } case GGML_OP_SSM_CONV: return true; + case GGML_OP_CUMSUM: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_TRI: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_FILL: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_DIAG: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SOLVE_TRI: + return op->src[0]->type == GGML_TYPE_F32; default: return false; } From ca624d86abdbb9f332850227fc02d4b2f6d4f10e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Tue, 28 Apr 2026 08:56:02 +0200 Subject: [PATCH 507/831] ggml : revert to -lm linking instead of find_library (llama/22355) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml : revert to -lm linking instead of find_library `find_library(MATH_LIBRARY m)` was introduced recently, but it breaks CUDA compilation with GGML_STATIC. I could not find any valid use case where we would prefer `find_library` over the standard `-lm` approach. This commit is also meant to start a discussion if there is a valid reason to keep `find_library(MATH_LIBRARY m)`, we should clarify what problem it was solving and find an alternative fix that does not break CUDA with GGML_STATIC. Signed-off-by: Adrien Gallouët * ggml : use MATH_LIBRARY only if defined Signed-off-by: Adrien Gallouët * ggml : fix initial broken condition Signed-off-by: Adrien Gallouët * ggml : always respect MATH_LIBRARY when defined Signed-off-by: Adrien Gallouët --------- Signed-off-by: Adrien Gallouët --- ggml/src/CMakeLists.txt | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 52754e1b9d6..3e48860bfc8 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -470,11 +470,10 @@ endforeach() target_link_libraries(ggml-base PRIVATE Threads::Threads) -find_library(MATH_LIBRARY m) -if (MATH_LIBRARY) - if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT}) - target_link_libraries(ggml-base PRIVATE ${MATH_LIBRARY}) - endif() +if (DEFINED MATH_LIBRARY) + target_link_libraries(ggml-base PRIVATE ${MATH_LIBRARY}) +elseif (NOT WIN32 AND NOT DEFINED ENV{ONEAPI_ROOT}) + target_link_libraries(ggml-base PRIVATE m) endif() if (CMAKE_SYSTEM_NAME MATCHES "Android") From 6fceff2eb4b248e57f69b4ed6d1cf82a471ad493 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrien=20Gallou=C3=ABt?= Date: Tue, 28 Apr 2026 09:02:32 +0200 Subject: [PATCH 508/831] ggml : skip already registered backends and devices (llama/22296) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Adrien Gallouët --- ggml/src/ggml-backend-reg.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 0587109212e..8165ae2c8bb 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -181,6 +181,12 @@ struct ggml_backend_registry { return; } + for (auto & entry : backends) { + if (entry.reg == reg) { + return; + } + } + #ifndef NDEBUG GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n", __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg)); @@ -192,6 +198,12 @@ struct ggml_backend_registry { } void register_device(ggml_backend_dev_t device) { + for (auto & dev : devices) { + if (dev == device) { + return; + } + } + #ifndef NDEBUG GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device)); #endif From 0fa31f9bb612e55b92b6877d729b947c9e6db4e0 Mon Sep 17 00:00:00 2001 From: Emil Askerov <56842174+EmilAskerov@users.noreply.github.com> Date: Tue, 28 Apr 2026 13:19:06 +0300 Subject: [PATCH 509/831] ggml: improve SPIR-V headers detection with __has_include (llama/21918) * ggml: improve SPIR-V headers detection with __has_include while preserving original _WIN32 logic * Address review comments: fix fallback logic and add FreeBSD support * Remove spirv_cross fallback as per review * Remove redundant __has_include check --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d4acee8b1df..6256639ab97 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -20,12 +20,19 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher() #include -// SPIRV-Headers: LunarG Windows SDK uses Include/spirv-headers/spirv.hpp (not spirv/unified1/). MinGW/MSYS2 and -// Linux packages use Khronos layout spirv/unified1/spirv.hpp. See docs/build.md#vulkan. -#if defined(_WIN32) && !defined(__MINGW32__) -#include + +// SPIR-V Headers: different SDK installations expose different include paths. +// LunarG Vulkan SDK on Windows typically provides . +// Linux packages, MSYS2 and MinGW often use the Khronos layout . +#if __has_include() +# include +#elif __has_include() +# include +#elif __has_include() +# include #else -#include + // Fallback to let the compiler throw a standard "file not found" error +# include #endif #include From 35fa508360cf2baee08c5eeb7b78c01bc79af000 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 28 Apr 2026 12:28:12 +0200 Subject: [PATCH 510/831] vulkan: add barrier after writetimestamp (llama/21865) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 6256639ab97..69c24bb5877 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -13014,6 +13014,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr if (vk_perf_logger_enabled && vk_perf_logger_concurrent) { ctx->query_node_idx[ctx->query_idx] = node_idx; compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + ggml_vk_sync_buffers(ctx, compute_ctx); } } // Add all fused nodes to the unsynchronized lists. @@ -14503,6 +14504,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg compute_ctx = ggml_vk_get_compute_ctx(ctx); ctx->query_idx = 0; compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + ggml_vk_sync_buffers(ctx, compute_ctx); } ctx->prealloc_y_last_pipeline_used = nullptr; @@ -14739,6 +14741,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i]; ctx->query_fusion_names[ctx->query_idx] = fusion_string; compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + ggml_vk_sync_buffers(ctx, compute_ctx); } else { // track a fusion string and number of fused ops for the current node_idx ctx->query_fusion_names[i] = fusion_string; From 4ea5b6febcbab8c100da752d8b826afbcfec1382 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Tue, 28 Apr 2026 07:27:17 -0700 Subject: [PATCH 511/831] ggml-webgpu: fix buffer aliasing for ssm_scan and refactor aliasing logic (llama/22456) * Refactor buffer aliasing to be part of shader lib decisions * cleanup * formatting --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 159 ++++++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 335 +++++++++--------- ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl | 6 +- .../wgsl-shaders/rms_norm_mul.wgsl | 6 +- .../ggml-webgpu/wgsl-shaders/ssm_scan.wgsl | 25 ++ 5 files changed, 301 insertions(+), 230 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index fb2c9527f3c..34cbf3694b1 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -26,21 +26,21 @@ // Matrix multiplication parameters // Register tiling parameters -#define WEBGPU_MUL_MAT_TILE_M 4 -#define WEBGPU_MUL_MAT_TILE_N 4 -#define WEBGPU_MUL_MAT_WG_SIZE_M 8 -#define WEBGPU_MUL_MAT_WG_SIZE_N 8 +#define WEBGPU_MUL_MAT_TILE_M 4 +#define WEBGPU_MUL_MAT_TILE_N 4 +#define WEBGPU_MUL_MAT_WG_SIZE_M 8 +#define WEBGPU_MUL_MAT_WG_SIZE_N 8 #define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8 #define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32 // Subgroup matrix parameters // The number of subgroups in the M dimension -#define WEBGPU_MUL_MAT_SUBGROUP_M 2 +#define WEBGPU_MUL_MAT_SUBGROUP_M 2 // The number of subgroups in the N dimension -#define WEBGPU_MUL_MAT_SUBGROUP_N 4 +#define WEBGPU_MUL_MAT_SUBGROUP_N 4 // The number of subgroup matrices each subgroup accumulates over -#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 -#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 #define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32 #define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32 @@ -59,19 +59,32 @@ template inline void ggml_webgpu_hash_combine(size_t & seed, const seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } +// Calculates base address of a tensor ignoring the fake base pointer +inline uintptr_t ggml_webgpu_tensor_addr(const ggml_tensor * tensor) { + const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor; + return (uintptr_t) base_tensor->data + tensor->view_offs; +} + +inline bool ggml_webgpu_tensor_equal(const ggml_tensor * a, const ggml_tensor * b) { + return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) == ggml_webgpu_tensor_addr(b); +} + +inline bool ggml_webgpu_tensor_overlap(const ggml_tensor * a, const ggml_tensor * b) { + return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) < ggml_webgpu_tensor_addr(b) + ggml_nbytes(b) && + ggml_webgpu_tensor_addr(b) < ggml_webgpu_tensor_addr(a) + ggml_nbytes(a); +} + struct ggml_webgpu_shader_lib_context { ggml_tensor * src0; ggml_tensor * src1; ggml_tensor * src2; ggml_tensor * src3; ggml_tensor * src4; + ggml_tensor * src5; ggml_tensor * dst; uint32_t max_wg_size; size_t wg_mem_limit_bytes = 0; - bool inplace = false; - bool overlap = false; - bool src_overlap = false; bool supports_subgroups = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; @@ -88,6 +101,14 @@ struct webgpu_pipeline { struct ggml_webgpu_generic_shader_decisions { uint32_t wg_size = 0; + bool inplace = false; +}; + +struct ggml_webgpu_binary_shader_decisions { + uint32_t wg_size = 0; + bool inplace = false; + bool overlap = false; + bool src_overlap = false; }; struct ggml_webgpu_processed_shader { @@ -102,11 +123,12 @@ struct ggml_webgpu_ssm_conv_shader_decisions { }; struct ggml_webgpu_ssm_scan_pipeline_key { - int type; - int d_state; + int type; + int d_state; + bool xbc_overlap; bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const { - return type == other.type && d_state == other.d_state; + return type == other.type && d_state == other.d_state && xbc_overlap == other.xbc_overlap; } }; @@ -115,6 +137,7 @@ struct ggml_webgpu_ssm_scan_pipeline_key_hash { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.type); ggml_webgpu_hash_combine(seed, key.d_state); + ggml_webgpu_hash_combine(seed, key.xbc_overlap); return seed; } }; @@ -122,6 +145,7 @@ struct ggml_webgpu_ssm_scan_pipeline_key_hash { struct ggml_webgpu_ssm_scan_shader_decisions { uint32_t wg_size; uint32_t tokens_per_tile; + bool xbc_overlap = false; }; /** Argsort **/ @@ -242,6 +266,13 @@ struct ggml_webgpu_rms_norm_mul_pipeline_key_hash { } }; +struct ggml_webgpu_rms_norm_mul_shader_decisions { + uint32_t wg_size = 0; + bool inplace = false; + bool overlap = false; + bool src_overlap = false; +}; + /** Pad **/ struct ggml_webgpu_pad_pipeline_key { bool circular; @@ -503,11 +534,12 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { }; struct ggml_webgpu_flash_attn_decisions { - uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX; - uint32_t q_tile = 0; - uint32_t kv_tile = 0; - uint32_t wg_size = 0; - bool kv_direct = false; + uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX; + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; + bool kv_direct = false; + bool kv_overlap = false; }; inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u; @@ -552,7 +584,7 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_ key.head_dim_qk = (uint32_t) context.src0->ne[0]; key.head_dim_v = (uint32_t) context.src2->ne[0]; key.kv_direct = kv_direct; - key.kv_overlap = context.src_overlap; + key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); key.has_mask = has_mask; key.has_sinks = has_sinks; key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; @@ -1021,7 +1053,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_row_norm_pipeline_key key = {}; key.op = context.dst->op; - key.inplace = context.inplace; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); auto it = row_norm_pipelines.find(key); if (it != row_norm_pipelines.end()) { @@ -1051,8 +1083,12 @@ class ggml_webgpu_shader_lib { const uint32_t row_norm_wg_size = 128u; uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size); defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - auto processed = preprocessor.preprocess(wgsl_row_norm, defines); - row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); + auto processed = preprocessor.preprocess(wgsl_row_norm, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + decisions->inplace = key.inplace; + row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); + row_norm_pipelines[key].context = decisions; return row_norm_pipelines[key]; } @@ -1127,7 +1163,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_set_pipeline_key key = {}; key.type = context.dst->type; - key.inplace = context.inplace; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); auto it = set_pipelines.find(key); if (it != set_pipelines.end()) { @@ -1160,6 +1196,7 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(wgsl_set, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; set_pipelines[key] = pipeline; @@ -1355,7 +1392,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_scale_pipeline_key key = {}; - key.inplace = context.inplace; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); auto it = scale_pipelines.find(key); if (it != scale_pipelines.end()) { @@ -1375,6 +1412,7 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(wgsl_scale, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; scale_pipelines[key] = pipeline; @@ -1468,6 +1506,8 @@ class ggml_webgpu_shader_lib { ggml_webgpu_ssm_scan_pipeline_key key = {}; key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; + key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { @@ -1499,12 +1539,17 @@ class ggml_webgpu_shader_lib { variant += "_wg_reduce"; } + if (key.xbc_overlap) { + defines.push_back("XBC_OVERLAP"); + } + variant += "_d" + std::to_string(key.d_state); auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines); auto decisions = std::make_shared(); decisions->wg_size = wg_size; decisions->tokens_per_tile = tokens_per_tile; + decisions->xbc_overlap = key.xbc_overlap; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; ssm_scan_pipelines[key] = pipeline; @@ -1764,11 +1809,9 @@ class ggml_webgpu_shader_lib { uint32_t tile_k; if (key.use_subgroup_matrix) { - tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT - : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT; + tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT; } else { - tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT - : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; + tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; } // Tiles @@ -2001,9 +2044,8 @@ class ggml_webgpu_shader_lib { defines.push_back("SCALAR"); // mul_mat_id is register-tile only. - const uint32_t tile_k = ggml_is_quantized(context.src0->type) - ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT - : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; + const uint32_t tile_k = + ggml_is_quantized(context.src0->type) ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; // Tiles defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); @@ -2039,8 +2081,8 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.op = op; key.is_unary = is_unary; - key.inplace = context.inplace; - key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0); + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst) || context.dst->op == GGML_OP_FILL; + key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0); auto it = unary_pipelines.find(key); if (it != unary_pipelines.end()) { @@ -2098,6 +2140,7 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(wgsl_unary, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; unary_pipelines[key] = pipeline; @@ -2106,9 +2149,9 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_rms_norm_mul_pipeline_key key = {}; - key.inplace = context.inplace; - key.overlap = context.overlap; - key.src_overlap = context.src_overlap; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst); + key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1); auto it = rms_norm_mul_pipelines.find(key); if (it != rms_norm_mul_pipelines.end()) { @@ -2132,12 +2175,15 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines); - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; - rms_norm_mul_pipelines[key] = pipeline; + auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines); + auto pipeline_decisions = std::make_shared(); + pipeline_decisions->wg_size = context.max_wg_size; + pipeline_decisions->inplace = key.inplace; + pipeline_decisions->overlap = key.overlap; + pipeline_decisions->src_overlap = key.src_overlap; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = pipeline_decisions; + rms_norm_mul_pipelines[key] = pipeline; return rms_norm_mul_pipelines[key]; } @@ -2145,9 +2191,9 @@ class ggml_webgpu_shader_lib { ggml_webgpu_binary_pipeline_key key = {}; key.type = context.dst->type; key.op = context.dst->op; - key.inplace = context.inplace; - key.overlap = context.overlap; - key.src_overlap = context.src_overlap; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst); + key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1); auto it = binary_pipelines.find(key); if (it != binary_pipelines.end()) { @@ -2186,11 +2232,15 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_binary, defines); - auto decisions = std::make_shared(); - decisions->wg_size = context.max_wg_size; + auto processed = preprocessor.preprocess(wgsl_binary, defines); + auto pipeline_decisions = std::make_shared(); + pipeline_decisions->wg_size = context.max_wg_size; + pipeline_decisions->inplace = key.inplace; + pipeline_decisions->overlap = key.overlap; + pipeline_decisions->src_overlap = key.src_overlap; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; + pipeline.context = pipeline_decisions; binary_pipelines[key] = pipeline; return binary_pipelines[key]; } @@ -2351,7 +2401,8 @@ class ggml_webgpu_shader_lib { defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); } - auto pipeline_decisions = std::make_shared(decisions); + auto pipeline_decisions = std::make_shared(decisions); + pipeline_decisions->kv_overlap = key.kv_overlap; defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile)); defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile)); defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size)); @@ -2543,7 +2594,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_rope_pipeline_key key = {}; key.type = context.dst->type; - key.inplace = context.inplace; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); key.has_ff = (context.src2 != nullptr); auto it = rope_pipelines.find(key); @@ -2582,6 +2633,7 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(wgsl_rope, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; rope_pipelines[key] = pipeline; @@ -2593,7 +2645,7 @@ class ggml_webgpu_shader_lib { key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32; key.has_mask = (context.src1 != nullptr); key.has_sink = (context.src2 != nullptr); - key.inplace = context.inplace; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); auto it = soft_max_pipelines.find(key); if (it != soft_max_pipelines.end()) { @@ -2634,6 +2686,7 @@ class ggml_webgpu_shader_lib { auto processed = preprocessor.preprocess(wgsl_soft_max, defines); auto decisions = std::make_shared(); decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; soft_max_pipelines[key] = pipeline; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 6d861c0c781..762d9f8d1b4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -108,12 +108,9 @@ static inline uint32_t ggml_webgpu_u32_from_f32(float value) { // their locations. static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT -// Always returns the base offset of a tensor, regardless of views. -static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) { - if (tensor->view_src) { - return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base; - } - return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base; +static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { + const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor; + return (size_t) ((uintptr_t) base_tensor->data - (uintptr_t) webgpu_ptr_base) + tensor->view_offs; } /* Struct definitions */ @@ -375,10 +372,6 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, buffer = device.CreateBuffer(&buffer_desc); } -static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { - return webgpu_tensor_offset(tensor) + tensor->view_offs; -} - static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; return ctx->buffer; @@ -398,34 +391,31 @@ static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); } -// Used to determine if two tensors are the same for in-place operations -static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); -} +struct ggml_webgpu_merged_binding_range { + size_t offset; + size_t size; +}; -// Used to determine if two tensors share the same buffer and their byte ranges overlap, -static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) && - ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a)); -} +static ggml_webgpu_merged_binding_range ggml_webgpu_tensor_merged_binding_range( + webgpu_context & ctx, + std::initializer_list tensors) { + size_t merged_offset = SIZE_MAX; + size_t merged_end = 0; -struct binary_overlap_flags { - bool inplace; // src0 == dst - bool overlap; // src1 == dst - bool src_overlap; -}; + for (ggml_tensor * tensor : tensors) { + const size_t bind_offset = ggml_webgpu_tensor_align_offset(ctx, tensor); + const size_t bind_end = bind_offset + ggml_webgpu_tensor_binding_size(ctx, tensor); -static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { - binary_overlap_flags flags = {}; - flags.inplace = ggml_webgpu_tensor_equal(src0, dst); - flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); - flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); + merged_offset = std::min(merged_offset, bind_offset); + merged_end = std::max(merged_end, bind_end); + } + + return { merged_offset, merged_end - merged_offset }; +} - return flags; +static uint32_t ggml_webgpu_tensor_merged_element_offset(const ggml_tensor * tensor, + const ggml_webgpu_merged_binding_range & merged_range) { + return (uint32_t) ((ggml_webgpu_tensor_offset(tensor) - merged_range.offset) / ggml_type_size(tensor->type)); } static wgpu::BindGroupEntry ggml_webgpu_make_bind_group_entry(uint32_t binding, @@ -753,18 +743,16 @@ static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - const bool inplace = ggml_webgpu_tensor_equal(src0, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src0; shader_lib_ctx.src1 = src1; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); + const bool inplace = decisions->inplace; const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst); const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type); @@ -1126,19 +1114,39 @@ static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx, ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src4 = src4; + shader_lib_ctx.src5 = src5; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; - webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx); + webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + const bool xbc_overlap = decisions->xbc_overlap; + + uint32_t offset_x = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)); + uint32_t offset_B = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)); + uint32_t offset_C = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)); + size_t xbc_bind_offset = 0; + size_t xbc_bind_size = 0; + if (xbc_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { src1, src4, src5 }); + xbc_bind_offset = merged_range.offset; + xbc_bind_size = merged_range.size; + offset_x = ggml_webgpu_tensor_merged_element_offset(src1, merged_range); + offset_B = ggml_webgpu_tensor_merged_element_offset(src4, merged_range); + offset_C = ggml_webgpu_tensor_merged_element_offset(src5, merged_range); + } std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + offset_x, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)), + offset_B, + offset_C, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1174,11 +1182,24 @@ static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx, }; std::vector entries = { - ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6), ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), }; + if (xbc_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), xbc_bind_offset, xbc_bind_size)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src6)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, dst)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst)); + } const uint32_t total_wg = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]); const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; @@ -1653,23 +1674,38 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = Q; + shader_lib_ctx.src1 = K; + shader_lib_ctx.src2 = V; + shader_lib_ctx.src3 = mask; + shader_lib_ctx.src4 = sinks; + shader_lib_ctx.dst = dst; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( + shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + auto * decisions = static_cast(pipeline.context.get()); const int has_mask = (mask != nullptr); const int has_sinks = (sinks != nullptr); - const bool kv_overlap = ggml_webgpu_tensor_overlap(K, V) && K->type == V->type; + const bool kv_overlap = decisions->kv_overlap; uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); size_t kv_bind_offset = 0; size_t kv_bind_size = 0; if (kv_overlap) { - const size_t k_bind_offset = ggml_webgpu_tensor_align_offset(ctx, K); - const size_t v_bind_offset = ggml_webgpu_tensor_align_offset(ctx, V); - const size_t k_bind_end = k_bind_offset + ggml_webgpu_tensor_binding_size(ctx, K); - const size_t v_bind_end = v_bind_offset + ggml_webgpu_tensor_binding_size(ctx, V); - kv_bind_offset = std::min(k_bind_offset, v_bind_offset); - kv_bind_size = std::max(k_bind_end, v_bind_end) - kv_bind_offset; - offset_k = (uint32_t) ((ggml_webgpu_tensor_offset(K) - kv_bind_offset) / ggml_type_size(K->type)); - offset_v = (uint32_t) ((ggml_webgpu_tensor_offset(V) - kv_bind_offset) / ggml_type_size(V->type)); + const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V }); + kv_bind_offset = merged_range.offset; + kv_bind_size = merged_range.size; + offset_k = ggml_webgpu_tensor_merged_element_offset(K, merged_range); + offset_v = ggml_webgpu_tensor_merged_element_offset(V, merged_range); } std::vector params = { @@ -1720,26 +1756,6 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, } entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = Q; - shader_lib_ctx.src1 = K; - shader_lib_ctx.src2 = V; - shader_lib_ctx.src3 = mask; - shader_lib_ctx.src4 = sinks; - shader_lib_ctx.dst = dst; - shader_lib_ctx.src_overlap = kv_overlap; - shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; - shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; - shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; - shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; - shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; - shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( - shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - auto * decisions = static_cast(pipeline.context.get()); - if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches @@ -1921,18 +1937,17 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; - bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src; shader_lib_ctx.src1 = nullptr; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); + auto * decisions = static_cast(pipeline.context.get()); + const bool inplace = decisions->inplace; uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -1994,41 +2009,38 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src0; shader_lib_ctx.src1 = src1; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = flags.inplace; - shader_lib_ctx.overlap = flags.overlap; - shader_lib_ctx.src_overlap = flags.src_overlap; - - webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); - auto * decisions = static_cast(pipeline.context.get()); + webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); uint32_t ne = (uint32_t) ggml_nelements(dst); size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0); size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1); - uint32_t offset_merged_src0 = 0; - uint32_t offset_merged_src1 = 0; - if (flags.src_overlap) { - size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); - offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); - offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); + uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)); + uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)); + size_t merged_offset = 0; + size_t merged_size = 0; + if (decisions->src_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 }); + merged_offset = merged_range.offset; + merged_size = merged_range.size; + offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range); + offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range); } std::vector params = { ne, - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + offset_src0, + offset_src1, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - offset_merged_src0, - offset_merged_src1, (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), @@ -2048,12 +2060,9 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, std::vector entries; - if (flags.src_overlap) { - size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); - size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0), - src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1)); - entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, - merged_end - merged_offset)); + if (decisions->src_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size)); entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } else { entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), @@ -2062,7 +2071,7 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), src1_webgpu_tensor_align_offset, ggml_webgpu_tensor_binding_size(ctx, src1))); - if (!flags.inplace && !flags.overlap) { + if (!decisions->inplace && !decisions->overlap) { entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); } } @@ -2168,29 +2177,15 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); } - bool overlap = (ggml_webgpu_tensor_equal(rn_dst, mul_src0) && ggml_webgpu_tensor_equal(mul_src1, dst)) || - (ggml_webgpu_tensor_equal(rn_dst, mul_src1) && ggml_webgpu_tensor_equal(mul_src0, dst)); - bool inplace = ggml_webgpu_tensor_equal(rn_src, dst); - bool src_overlap = ggml_webgpu_tensor_overlap(rn_src, mul_src); - - uint32_t offset_merged_rn_src = 0; - uint32_t offset_merged_mul_src = 0; - size_t rn_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, rn_src); - size_t mul_src_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, mul_src); - - if (src_overlap) { - size_t min_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset); - offset_merged_rn_src = - (uint32_t) ((rn_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(rn_src->type)); - offset_merged_mul_src = - (uint32_t) ((mul_src_webgpu_tensor_align_offset - min_offset) / ggml_type_size(mul_src->type)); - } + uint32_t offset_rn_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)); + uint32_t offset_mul_src = + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)); + size_t merged_offset = 0; + size_t merged_size = 0; std::vector params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)), - offset_merged_rn_src, - offset_merged_mul_src, + offset_rn_src, + offset_mul_src, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)), (uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)), @@ -2214,16 +2209,32 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context std::vector entries; - if (inplace || overlap) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = rn_src; + shader_lib_ctx.src1 = mul_src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + if (decisions->src_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { rn_src, mul_src }); + merged_offset = merged_range.offset; + merged_size = merged_range.size; + offset_rn_src = ggml_webgpu_tensor_merged_element_offset(rn_src, merged_range); + offset_mul_src = ggml_webgpu_tensor_merged_element_offset(mul_src, merged_range); + params[0] = offset_rn_src; + params[1] = offset_mul_src; + } + + if (decisions->inplace || decisions->overlap) { entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); - } else if (src_overlap) { - size_t merged_offset = std::min(rn_src_webgpu_tensor_align_offset, mul_src_webgpu_tensor_align_offset); - size_t merged_end = - std::max(rn_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, rn_src), - mul_src_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, mul_src)); - entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset, - merged_end - merged_offset)); + } else if (decisions->src_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset, merged_size)); entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } else { entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); @@ -2231,20 +2242,10 @@ static std::optional ggml_webgpu_rms_norm_mul(webgpu_context entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); } - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = inplace; - shader_lib_ctx.overlap = overlap; - shader_lib_ctx.src_overlap = src_overlap; - - webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); } static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - bool inplace = ggml_webgpu_tensor_equal(src, dst); - std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -2261,18 +2262,18 @@ static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)) // epsilon, treated as f32 in the shader }; - std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; - if (!inplace) { - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); - } - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = inplace; - webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); + webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; + if (!decisions->inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src)); } @@ -2287,14 +2288,13 @@ static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, shader_lib_ctx.src2 = src2; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst); webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int has_freq_factor = (src2 != nullptr); + const bool inplace = decisions->inplace; + const int has_freq_factor = (src2 != nullptr); const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; @@ -2421,14 +2421,11 @@ static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, } static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - bool inplace = ggml_webgpu_tensor_equal(src, dst); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src; shader_lib_ctx.src1 = nullptr; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = inplace; webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -2454,7 +2451,7 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s // bindgroups unchanged std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; - if (!inplace) { + if (!decisions->inplace) { entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } @@ -2473,17 +2470,17 @@ static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, shader_lib_ctx.src2 = src2; shader_lib_ctx.dst = dst; shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.inplace = ggml_webgpu_tensor_equal(src0, dst); - webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); + webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int has_mask = (src1 != nullptr); - const int has_sink = (src2 != nullptr); - float max_bias = ggml_get_op_params_f32(dst, 1); - float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); - float m0 = powf(2.0f, -(max_bias) / n_head_log2); - float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const bool inplace = decisions->inplace; + const int has_mask = (src1 != nullptr); + const int has_sink = (src2 != nullptr); + float max_bias = ggml_get_op_params_f32(dst, 1); + float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -3079,7 +3076,7 @@ static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend, size_t size) { GGML_UNUSED(backend); auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; - size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; // Write aligned portion buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); @@ -3161,7 +3158,7 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); - size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; // This is a trick to set all bytes of a u32 to the same 1 byte value. uint32_t val32 = (uint32_t) value * 0x01010101; @@ -3180,7 +3177,7 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); @@ -3212,7 +3209,7 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, << ", " << offset << ", " << size << ")"); wgpu::Device device = buf_ctx->global_ctx->device; - size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; size_t final_size = size; if (size % 4 != 0) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl index a748dc1b86c..605de7aa7be 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl @@ -7,8 +7,6 @@ struct Params { offset_src0: u32, offset_src1: u32, offset_dst: u32, - offset_merged_src0: u32, - offset_merged_src1: u32, stride_src0_0: u32, stride_src0_1: u32, @@ -134,8 +132,8 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) { @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x < params.ne) { - let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x); - let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x); + let src0_i = params.offset_src0 + src0_index(gid.x); + let src1_i = params.offset_src1 + src1_index(gid.x); update(params.offset_dst + gid.x, src0_i, src1_i); } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl index 74aaa2753ae..fd20a4e54c9 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl @@ -66,8 +66,6 @@ fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) struct Params { offset_rn_src: u32, offset_mul_src: u32, - offset_merged_rn_src: u32, - offset_merged_mul_src: u32, offset_dst: u32, stride_rn_src1: u32, @@ -107,8 +105,8 @@ fn main(@builtin(workgroup_id) wid: vec3, i = i % (params.ne2 * params.ne1); let i2 = i / params.ne1; let i1 = i % params.ne1; - let i_rn_src_row = params.offset_rn_src + params.offset_merged_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1; - let i_mul_src_row = params.offset_mul_src + params.offset_merged_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1; + let i_rn_src_row = params.offset_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1; + let i_mul_src_row = params.offset_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1; let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl index 64324738591..05761dec353 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl @@ -45,6 +45,14 @@ struct Params { }; @group(0) @binding(0) var s_in: array; +#ifdef XBC_OVERLAP +@group(0) @binding(1) var x_B_C_merged: array; +@group(0) @binding(2) var dt: array; +@group(0) @binding(3) var A: array; +@group(0) @binding(4) var ids: array; +@group(0) @binding(5) var dst: array; +@group(0) @binding(6) var params: Params; +#else @group(0) @binding(1) var x: array; @group(0) @binding(2) var dt: array; @group(0) @binding(3) var A: array; @@ -53,6 +61,7 @@ struct Params { @group(0) @binding(6) var ids: array; @group(0) @binding(7) var dst: array; @group(0) @binding(8) var params: Params; +#endif var shared_x_dt: array; var shared_dtsp: array; @@ -98,7 +107,11 @@ fn main( let dt0 = dt[dt_idx]; let dtsp = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0); shared_dtsp[tid] = dtsp; +#ifdef XBC_OVERLAP + shared_x_dt[tid] = x_B_C_merged[x_idx] * dtsp; +#else shared_x_dt[tid] = x[x_idx] * dtsp; +#endif } } @@ -116,16 +129,28 @@ fn main( let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3; let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3; +#ifdef XBC_OVERLAP + let s = s_prev * dA + x_B_C_merged[b_idx] * x_dt; +#else let s = s_prev * dA + B[b_idx] * x_dt; +#endif s_prev = s; #ifdef USE_SUBGROUP_REDUCTION +#ifdef XBC_OVERLAP + let subgroup_partial = subgroupAdd(s * x_B_C_merged[c_idx]); +#else let subgroup_partial = subgroupAdd(s * C[c_idx]); +#endif if (subgroup_invocation_id == 0u) { shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial; } +#else +#ifdef XBC_OVERLAP + shared_reduce[reduce_idx] = s * x_B_C_merged[c_idx]; #else shared_reduce[reduce_idx] = s * C[c_idx]; +#endif #endif workgroupBarrier(); From e69c109aac3f7ca1643a50027603902c123a3849 Mon Sep 17 00:00:00 2001 From: Matt Corallo <649246+TheBlueMatt@users.noreply.github.com> Date: Tue, 28 Apr 2026 15:31:04 +0000 Subject: [PATCH 512/831] vulkan: Coalesce Q4_K/Q5_K scale loads (llama/21751) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Some SPIR-V compilers (notably mesa) don't handle the current vulkan Q4_K/Q5_K scale load pattern in mul_mat particularly well. While reading three `u8`s from the 12-byte scale array should (at least on some hardware) result in loading the full 12 bytes in a single LOAD followed by whatever extraction is needed, at least the ANV Intel driver really can't practically perform this optimization. `mesa`'s unsigned upper bound logic doesn't handle tracking bounds through ternary, resulting in the `(is < 4) ? ... : is - 4` having an infinite upper bound (as it cannot prove `is - 4` doesn't underflow). While this could still be rectified if mesa looked at the array bounds, it currently doesn't and `glslc` currently emits SPIR-V that doesn't allow for this optimization anyway (though maybe it will at some point, see https://github.com/KhronosGroup/glslang/issues/4206). In mul_mat_vecq we took a different approach to loading the same fields. We read the first two bytes we needed from `scale` then took a branch before deciding whether we needed to read a third byte. In mesa this did, indeed, lead to a top-level branch with conditional loads. As such these loads ended up not being coalesced either (at least in the ANV driver) resulting in additional instructions in our hot loop. Instead, here, we go ahead and force loading the full 12 bytes and extract the bits we need from the packed-u32s instead. In mul_mat there's a few less ternaries and only one extra shift, so even on drivers that did optimize the previous loads properly the only material change should be pulling a few extra bytes into registers (which on most hardware won't cost anything anyway, though ironically on Intel it theoretically could). In mul_mat_vecq this requires a bit of extra math and may read bytes from the u32 that weren't needed, but it seems likely avoiding the branch is a win on most platforms. On Intel Xe2/mesa 26.0.4 with the optimizations from https://gitlab.freedesktop.org/mesa/mesa/-/work_items/15162, for shader matmul_id_subgroup_q4_k_f32_f16acc_aligned_l: * Instruction Count: 2753 -> 2688 * SEND Count: 269 -> 261 * Cycle Count: 273976 -> 266138 * Max live registers: 248 -> 246 * Non SSA regs after NIR: 381 -> 382 for shader matmul_id_subgroup_q5_k_f32_f16acc_aligned_l: * Instruction Count: 2767 -> 2702 * SEND Count: 271 -> 263 * Cycle Count: 274140 -> 268144 * Max live registers: 248 -> 246 * Non SSA regs after NIR: 381 -> 382 for shader mul_mat_vec_id_q4_k_q8_1_f32: * Instruction Count: 1930 -> 1646 * SEND Count: 116 -> 71 * Cycle Count: 1348306 -> 843350 * Max live registers: 78 -> 84 * Non SSA regs after NIR: 300 -> 135 for shader mul_mat_vec_id_q5_k_q8_1_f32: * Instruction Count: 2207 -> 1922 * SEND Count: 131 -> 86 * Cycle Count: 1392012 -> 1037836 * Max live registers: 90 -> 90 * Non SSA regs after NIR: 300 -> 135 for shader mul_mat_vec_q4_k_q8_1_f32: * Instruction Count: 2029 -> 1749 * SEND Count: 111 -> 66 * Cycle Count: 1347278 -> 840118 * Max live registers: 74 -> 80 * Non SSA regs after NIR: 299 -> 134 for shader mul_mat_vec_q5_k_q8_1_f32: * Instruction Count: 2307 -> 2022 * SEND Count: 126 -> 81 * Cycle Count: 1379820 -> 954042 * Max live registers: 86 -> 86 * Non SSA regs after NIR: 299 -> 134 On one Arc Pro B60, unsloth/Qwen3.5-35B-A3B-GGUF:UD-Q4_K_XL: * pp512: 907.34 ± 9.28 -> 941.94 ± 10.53 (+4%) * pp2048: 897.95 ± 1.82 -> 931.55 ± 1.79 (+4%) * tg128: 49.49 ± 0.02 -> 49.86 ± 0.05 (+ <1%) On one Arc Pro B60, unsloth/Qwen3.5-27B-GGUF:Q4_K_S: * pp512: 324.13 ± 10.52 -> 354.33 ± 6.81 (+9%) * pp2048: 329.80 ± 0.25 -> 357.10 ± 0.06 (+8%) * tg128: 17.11 ± 0.01 -> 18.11 ± 0.01 (+6%) On four Arc Pro B60s, unsloth/Qwen3.5-122B-A10B-GGUF:Q5_K_S with -sm layer (note that -sm tensor improvements will naturally be less): * pp512: 264.55 ± 2.81 -> 280.45 ± 3.94 (+6%) * pp2048: 319.32 ± 2.72 -> 335.70 ± 3.48 (+5%) * tg128: 26.39 ± 0.01 -> 26.67 ± 0.01 (+1%) --- .../vulkan-shaders/mul_mat_vecq_funcs.glsl | 23 +++++--- .../vulkan-shaders/mul_mm_funcs.glsl | 54 ++++++++++--------- 2 files changed, 44 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index e99108dc50c..bc580aeeb83 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -296,13 +296,22 @@ vec2 get_dm_scale(uint ib, uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs; const uint is = iqs_k / 8; - u8vec2 scale_dm; - if (is < 4) { - scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F); - } else { - scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2), - (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); - } + + const uvec3 scales = uvec3(data_a_packed32[ib_k].scales[0], + data_a_packed32[ib_k].scales[1], + data_a_packed32[ib_k].scales[2]); + const uint scalesoffs = (is & 3) * 8; + + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); + u8vec2 scale_dm = u8vec2(sc, mbyte); return FLOAT_TYPEV2(data_a_packed32[ib_k].dm) * FLOAT_TYPEV2(scale_dm); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 6e4a29d2fdd..73595168984 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -201,19 +201,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 loadd = vec2(data_a[ib].dm); - const uint scidx0 = (is < 4) ? is : (is + 4); - const uint scidx1 = (is < 4) ? is : (is - 4); - const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint scidxshift1 = (is < 4) ? 0 : 2; - const uint mbidx0 = is + 4; - const uint mbidx1 = (is < 4) ? is + 4 : is; - const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; - const uint mbidxshift0 = (is < 4) ? 0 : 4; - const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint mbidxshift1 = (is < 4) ? 0 : 2; - - const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); - const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + const uvec3 scales = uvec3(data_a_packed32[ib].scales[0], + data_a_packed32[ib].scales[1], + data_a_packed32[ib].scales[2]); + const uint scalesoffs = (is & 3) * 8; + + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); const float d = loadd.x * sc; const float m = -loadd.y * mbyte; @@ -237,19 +238,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 loadd = vec2(data_a[ib].dm); - const uint scidx0 = (is < 4) ? is : (is + 4); - const uint scidx1 = (is < 4) ? is : (is - 4); - const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint scidxshift1 = (is < 4) ? 0 : 2; - const uint mbidx0 = is + 4; - const uint mbidx1 = (is < 4) ? is + 4 : is; - const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; - const uint mbidxshift0 = (is < 4) ? 0 : 4; - const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint mbidxshift1 = (is < 4) ? 0 : 2; - - const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); - const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + const uvec3 scales = uvec3(data_a_packed32[ib].scales[0], + data_a_packed32[ib].scales[1], + data_a_packed32[ib].scales[2]); + const uint scalesoffs = (is & 3) * 8; + + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); const float d = loadd.x * sc; const float m = -loadd.y * mbyte; From b553e17071862cd10feed563f8afb027cc713e18 Mon Sep 17 00:00:00 2001 From: lnigam Date: Wed, 29 Apr 2026 01:07:35 +0530 Subject: [PATCH 513/831] =?UTF-8?q?ggml-cuda:=20add=20flash-attn=20support?= =?UTF-8?q?=20for=20DKQ=3D320/DV=3D256=20with=20ncols2=3D32=20(=E2=80=A6?= =?UTF-8?q?=20(#22286)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-cuda: add flash-attn support for DKQ=320/DV=256 with ncols2=32 (GQA=32) Adds MMA-f16 and tile kernel configs, dispatch logic, template instances, and tile .cu file for Mistral Small 4 (head sizes 320/256), restricting to ncols2=32 to support GQA ratio 32 only. * Adding check to return BEST_FATTN_KERNEL_NONE in case GQA!=32 * Apply suggestions from code review Address review comments Co-authored-by: Johannes Gäßler * Address review comments and making kernel config default to DQK=512, DV=512 instead of DQK=256,DV=256 * Fixed bug with sinks=1, with ncols=32, there are two warp-groups created but sinks index is same(0,...,15) for both the groups hence with sinks=1, output is not matching with CPU output. Added sink_base which will be base index for each warp_group (threadIdx.y / np) * Apply suggestions from code review Co-authored-by: Johannes Gäßler * Update ggml/src/ggml-cuda/template-instances/generate_cu_files.py Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 15 +++++++- ggml/src/ggml-cuda/fattn-tile.cu | 4 ++ ggml/src/ggml-cuda/fattn-tile.cuh | 37 ++++++++++++++----- ggml/src/ggml-cuda/fattn.cu | 24 ++++++++++++ ...ttn-mma-f16-instance-ncols1_1-ncols2_32.cu | 1 + ...ttn-mma-f16-instance-ncols1_2-ncols2_32.cu | 1 + .../fattn-tile-instance-dkq320-dv256.cu | 5 +++ .../template-instances/generate_cu_files.py | 15 +++++--- 8 files changed, 86 insertions(+), 16 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index e185449d491..3f01e858de7 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -66,6 +66,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); @@ -85,6 +88,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); @@ -118,6 +124,9 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); @@ -1217,7 +1226,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col); + const int jc = (threadIdx.y/np)*cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col)); const float sink = sinks_f[jc % ncols2]; const float KQ_max_new = fmaxf(KQ_max[col], sink); @@ -1825,6 +1834,10 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); +// Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build: +extern DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32); +extern DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32); + // For GLM 4.7 Flash extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index 25b16e83cac..d60634cc0e9 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); } break; + case 320: { + GGML_ASSERT(V->ne[0] == 256); + ggml_cuda_flash_attn_ext_tile_case<320, 256>(ctx, dst); + } break; case 512: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 26721cc4c7d..928b856f9d2 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) @@ -128,6 +130,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64) @@ -195,6 +199,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) @@ -264,6 +270,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64) @@ -1144,14 +1152,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm } } - if (Q->ne[1] > 8/ncols2) { - constexpr int cols_per_block = 16; - const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); - return; + if constexpr (ncols2 <= 16) { + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } } if constexpr (ncols2 <= 8) { @@ -1210,6 +1220,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; + if constexpr (DKQ == 320) { // Mistral Small 4 + if (use_gqa_opt && gqa_ratio % 32 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32"); + } + if constexpr (DKQ == 576) { if (use_gqa_opt && gqa_ratio % 16 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); @@ -1221,7 +1239,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm } } - if constexpr (DKQ <= 512) { + if constexpr (DKQ <= 512 && DKQ != 320) { if (use_gqa_opt && gqa_ratio % 8 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1275,5 +1293,6 @@ extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(320, 256); extern DECL_FATTN_TILE_CASE(512, 512); extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index ea6607cd337..8256591b21d 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -143,6 +143,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(V->ne[0] == 256); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); break; + case 320: + // For Mistral Small 4, go straight to the ncols1 switch (ncols2=32-only build). + GGML_ASSERT(V->ne[0] == 256); + { + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(gqa_ratio % 32 == 0); + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<320, 256, 32>(ctx, dst); + } + break; case 512: GGML_ASSERT(V->ne[0] == 512); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst); @@ -352,6 +368,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } break; + case 320: + if (V->ne[0] != 256 || !gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + if (gqa_ratio % 32 != 0) { + return BEST_FATTN_KERNEL_NONE; + } + break; case 512: if (V->ne[0] != K->ne[0]) { return BEST_FATTN_KERNEL_NONE; diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu index 1f554d81e5e..8fc3b17976e 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32); DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu index 264751d65ec..abd2b21ce04 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32); DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu new file mode 100644 index 00000000000..c91f508079d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(320, 256); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 841059c15b5..5e9a1cb2eb3 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,7 +3,7 @@ from glob import glob import os -HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576] +HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576] TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] @@ -62,7 +62,7 @@ def get_short_name(long_quant_name): os.remove(filename) for head_size_kq in HEAD_SIZES_KQ: - head_size_v = head_size_kq if head_size_kq != 576 else 512 + head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f: f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v)) @@ -84,13 +84,16 @@ def get_short_name(long_quant_name): continue if head_size_kq == 72: continue - if head_size_kq == 512 and ncols2 not in (4, 8): + # Skip compilation of unused ncols2 values for niche head sizes: + if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4 continue - if head_size_kq != 576 and ncols2 in (16, 32): + if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4 continue - if head_size_kq == 576 and ncols2 not in (4, 16, 32): + if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash continue - head_size_v = head_size_kq if head_size_kq != 576 else 512 + if head_size_kq not in (320, 576) and ncols2 in (16, 32): + continue + head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) for type in TYPES_MMQ: From c200b588f88301ab77f8f368355ed718ecb18ce7 Mon Sep 17 00:00:00 2001 From: Michael Wand Date: Tue, 28 Apr 2026 15:47:42 -0700 Subject: [PATCH 514/831] ggml-cuda: Repost of 21896: Blackwell native NVFP4 support (llama/22196) --- ggml/src/ggml-cuda/common.cuh | 12 ++ ggml/src/ggml-cuda/mma.cuh | 34 +++-- ggml/src/ggml-cuda/mmq.cu | 21 ++- ggml/src/ggml-cuda/mmq.cuh | 230 ++++++++++++++++++++------------ ggml/src/ggml-cuda/mmvq.cu | 3 + ggml/src/ggml-cuda/quantize.cu | 148 ++++++++++++++++---- ggml/src/ggml-cuda/quantize.cuh | 2 +- 7 files changed, 319 insertions(+), 131 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 3aec1742ee1..10817505d9f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -830,6 +830,18 @@ static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) { #endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000 } +static __device__ __forceinline__ uint8_t ggml_cuda_fp32_to_ue4m3(float x) { +#if defined(BLACKWELL_MMA_AVAILABLE) // This is used for NVFP4 subblock scale quantizations only + if (!(x > 0.0f)) { + return 0; + } + const __nv_fp8_e4m3 xf(x); + return xf.__x; +#else + NO_DEVICE_CODE; // Used only for NVFP4 Scales for Activations, only for Blackwell +#endif // defined(BLACKWELL_MMA_AVAILABLE) +} + __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) { const uint8_t sign_bit = (x < 0.0f) << 3; float ax = fabsf(x) * e; diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index b0f674635f1..79bb2934c5f 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -1015,25 +1015,35 @@ namespace ggml_cuda_mma { #endif // AMD_MFMA_AVAILABLE } - static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D, - const tile<16, 8, int> & A, - const tile<8, 8, int> & B, - uint32_t a_scale, - uint32_t b_scale) { + template + static __device__ __forceinline__ void mma_block_scaled_fp4(tile<16, 8, float> & D, + const tile<16, 8, int> & A, + const tile<8, 8, int> & B, + uint32_t a_scale, + uint32_t b_scale) { #ifdef BLACKWELL_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; float * Dxi = (float *) D.x; - asm volatile( - "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " - "%10, {0, 0}, %11, {0, 0};" - : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) - : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); + if constexpr (type == GGML_TYPE_MXFP4) { + asm volatile( + "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " + "%10, {0, 0}, %11, {0, 0};" + : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); + } else { + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " + "%10, {0, 0}, %11, {0, 0};" + : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); + } #else GGML_UNUSED_VARS(D, A, B, a_scale, b_scale); -#endif // BLACKWELL_MMA_AVAILABLE +#endif // BLACKWELL_MMA_AVAILABLE } static __device__ __forceinline__ void mma( diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 3f01ff5bfb0..e1add5e0331 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -122,7 +122,7 @@ void ggml_cuda_mul_mat_q( || GGML_CUDA_CC_IS_CDNA(cc); // TODO: tighter pool buffer size vs q8 path - const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4; + const bool use_native_fp4 = blackwell_mma_available(cc) && (src0->type == GGML_TYPE_MXFP4 || src0->type == GGML_TYPE_NVFP4); if (!ids) { const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + @@ -133,9 +133,9 @@ void ggml_cuda_mul_mat_q( const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[3] / ts_src1; - if (use_native_mxfp4) { + if (use_native_fp4) { static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1)); - quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, + quantize_mmq_fp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); } else { @@ -146,10 +146,8 @@ void ggml_cuda_mul_mat_q( } // Stride depends on quantization format - const int64_t s12 = use_native_mxfp4 ? - ne11 * ne10_padded * sizeof(block_fp4_mmq) / - (8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32) - : + const int64_t s12 = use_native_fp4 ? + ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) : // block_fp4_mmq holds 256 values ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); const int64_t s13 = ne12*s12; @@ -198,8 +196,8 @@ void ggml_cuda_mul_mat_q( const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[3] / ts_src1; - if (use_native_mxfp4) { - quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, + if (use_native_fp4) { + quantize_mmq_fp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); } else { quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, @@ -208,8 +206,9 @@ void ggml_cuda_mul_mat_q( CUDA_CHECK(cudaGetLastError()); } - const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) : - ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); + static_assert(QK_K == 8 * QK_MXFP4, "QK_K needs to be 8 * QK_MXFP4"); + const int64_t s12 = use_native_fp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) : + ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); const int64_t s13 = ne12*s12; // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid. diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 91a1b737a82..edf546d8f1e 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -10,9 +10,9 @@ using namespace ggml_cuda_mma; #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. -#define MMQ_ITER_K 256 -#define MMQ_ITER_K_MXFP4_FP4 512 -#define MMQ_NWARPS 8 +#define MMQ_ITER_K 256 +#define MMQ_ITER_K_FP4 512 +#define MMQ_NWARPS 8 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride); typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00); @@ -46,9 +46,12 @@ struct block_q8_1_mmq { int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each }; +// this struct is used for fp4 data types (currently only used for Blackwell) +// mxfp4 has block size 32, each int32 of d4 contains 2 e8m0 scales in the lower 16 bits +// nvfp4 has block size 16, each int32 of d4 contains 4 ue4m3 scales struct block_fp4_mmq { - uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc. - int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values + uint32_t d4[4]; + int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte) }; static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); @@ -143,10 +146,11 @@ static int get_mmq_y_host(const int cc) { static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) { #if defined(BLACKWELL_MMA_AVAILABLE) - return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K; -#else - return MMQ_ITER_K; +if (type == GGML_TYPE_NVFP4 || type == GGML_TYPE_MXFP4) { + return MMQ_ITER_K_FP4; +} #endif // defined(BLACKWELL_MMA_AVAILABLE) + return MMQ_ITER_K; } static constexpr __device__ int get_mmq_y_device() { @@ -213,8 +217,8 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml } #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 -#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 +#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 and NVFP4 Blackwell +#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 Generic #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) @@ -240,7 +244,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; // tile sizes are the same for Q8_1 and FP4 for blackwell case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; +#if defined(BLACKWELL_MMA_AVAILABLE) + case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_FP4; +#else case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4; +#endif // defined(BLACKWELL_MMA_AVAILABLE) case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; @@ -934,6 +942,128 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr } } +#ifdef BLACKWELL_MMA_AVAILABLE +template +static __device__ __forceinline__ void load_tiles_nvfp4_nvfp4(const char * __restrict__ x, + int * __restrict__ x_tile, + const int kbx0, + const int i_max, + const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int iter_k = get_iter_k(GGML_TYPE_NVFP4); + constexpr int threads_per_row = iter_k / QK_NVFP4; // each thread processes 1 block + constexpr int rows_per_warp = warp_size / threads_per_row; + + uint32_t * x_u32 = (uint32_t *) x_tile; + + const int txi = threadIdx.x; + const int kbx = txi % threads_per_row; + const int row_in_warp = txi / threads_per_row; + + const block_nvfp4 * bxi_base = (const block_nvfp4 *) x + kbx0 + kbx; + uint32_t * x_u32_scale = x_u32 + 64 + kbx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) { + int i = i0 + threadIdx.y * rows_per_warp + row_in_warp; + + if constexpr (need_check) { + i = min(i, i_max); + } + + const block_nvfp4 * bxi = bxi_base + i * stride; + const int row_base = i * MMQ_MMA_TILE_X_K_FP4; + const int q_base = row_base + 8 * kbx; + + const uint32_t * src_qs = reinterpret_cast(bxi->qs); + +#pragma unroll + for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) { + x_u32[q_base + 2 * sub + 0] = src_qs[2 * sub + 0]; + x_u32[q_base + 2 * sub + 1] = src_qs[2 * sub + 1]; + } + + x_u32_scale[row_base] = get_int_b4(bxi->d, 0); + } +} + +// Shared MMA kernel for MXFP4 and NVFP4 on Blackwell. +// Both quantizations encode values as e2m1 (FP4) and produce one uint32 scale per +// m16n8k64 MMA call; only the PTX kind (scale_vec::2X ue8m0 vs scale_vec::4X ue4m3) +// and the per-type stride constant differ. +template +static __device__ __forceinline__ void vec_dot_fp4_fp4_mma(const int * __restrict__ x, + const int * __restrict__ y, + float * __restrict__ sum, + const int k00) { + static_assert(type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4, + "vec_dot_fp4_fp4_mma: type must be MXFP4 or NVFP4"); + + typedef tile<16, 8, int> tile_A; + typedef tile<8, 8, int> tile_B; + typedef tile<16, 8, float> tile_C; + + constexpr int stride = MMQ_MMA_TILE_X_K_FP4; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp / tile_C::I; + constexpr int nfrags = MMQ_TILE_NE_K / tile_A::J; + + y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K); + const int * y_qs = (const int *) y + 4; + const uint32_t * y_sc = (const uint32_t *) y; + + // 2 threads per quad supply the packed scale register to the block_scale MMA, + // see https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling + const int tidx_A = threadIdx.x / 4 + (threadIdx.x % 2) * 8; + const int tidx_B = threadIdx.x / 4; + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + tile_A A[ntx][nfrags]; + uint32_t scaleA[ntx][nfrags]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int frag = 0; frag < nfrags; ++frag) { + const int k0 = k00 + frag * tile_A::J; + load_ldmatrix(A[n][frag], x_qs + (i0 + n * tile_A::I) * stride + k0, stride); + scaleA[n][frag] = x_sc[(i0 + n * tile_A::I + tidx_A) * stride + k0 / tile_A::J]; + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) { + tile_B B[nfrags]; + uint32_t scaleB[nfrags]; + +#pragma unroll + for (int frag = 0; frag < nfrags; ++frag) { + const int k0 = frag * tile_B::J; + load_generic(B[frag], y_qs + j0 * MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K); + scaleB[frag] = y_sc[(j0 + tidx_B) * MMQ_TILE_Y_K + frag]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int frag = 0; frag < nfrags; ++frag) { + tile_C C = {}; + mma_block_scaled_fp4(C, A[n][frag], B[frag], scaleA[n][frag], scaleB[frag]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l]; + } + } + } + } +} +#endif // BLACKWELL_MMA_AVAILABLE + template static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x, @@ -1163,77 +1293,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -template -static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x, - const int * __restrict__ y, - float * __restrict__ sum, - const int k00) { - typedef tile<16, 8, int> tile_A; - typedef tile<8, 8, int> tile_B; - typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K); - - // Match layout from load_tiles_mxfp4_fp4 - const int * x_qs = (const int *) x; - const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K); - const int * y_qs = (const int *) y + 4; - const uint32_t * y_sc = (const uint32_t *) y; - - // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4 - tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)]; - uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)]; - - // Block scale - // Each thread has to point to a 4 byte scale value - // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) { - const int k0 = k00 + k01; - - load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0, - MMQ_MMA_TILE_X_K_FP4); - - // based on block-scaling document, 2 threads in each quad need to supply to the scale value - const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8; - scaleA[n][k01 / (2 * QI_MXFP4)] = - *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4)); - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) { - tile_B B; - uint32_t scaleB; // 2xN scales - - load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K); - - scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)]; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - - mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB); -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l]; - } - } - } - } -} template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( @@ -3259,7 +3318,7 @@ struct mmq_type_traits { static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ; #ifdef BLACKWELL_MMA_AVAILABLE static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma; #else static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; @@ -3270,8 +3329,13 @@ struct mmq_type_traits { template struct mmq_type_traits { static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ; +#ifdef BLACKWELL_MMA_AVAILABLE + static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4_nvfp4; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma; +#else static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; +#endif // BLACKWELL_MMA_AVAILABLE static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; }; @@ -3406,7 +3470,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( #if defined(BLACKWELL_MMA_AVAILABLE) // FP4 tile stores 8 blocks - constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1; + constexpr int ne_block = (type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4) ? QK_K : 4 * QK8_1; #else constexpr int ne_block = 4 * QK8_1; #endif // defined(BLACKWELL_MMA_AVAILABLE) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 8f55cace1a1..da48f313a38 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -115,6 +115,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(gg case GGML_TYPE_IQ4_NL: return 6; case GGML_TYPE_IQ4_XS: return 5; case GGML_TYPE_MXFP4: return 4; + case GGML_TYPE_NVFP4: return 4; case GGML_TYPE_Q2_K: return 4; case GGML_TYPE_Q3_K: return 4; case GGML_TYPE_Q4_0: return 6; @@ -135,6 +136,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggm case GGML_TYPE_IQ3_S: return 6; case GGML_TYPE_IQ3_XXS: return 7; case GGML_TYPE_MXFP4: return 7; + case GGML_TYPE_NVFP4: return 8; case GGML_TYPE_Q2_K: return 7; case GGML_TYPE_Q3_K: return 5; default: return MMVQ_MAX_BATCH_SIZE; @@ -221,6 +223,7 @@ static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type case GGML_TYPE_IQ4_NL: return 7; case GGML_TYPE_IQ4_XS: return 5; case GGML_TYPE_MXFP4: return 5; + case GGML_TYPE_NVFP4: return 5; case GGML_TYPE_Q3_K: return 4; case GGML_TYPE_Q4_0: return 7; case GGML_TYPE_Q4_1: return 7; diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 4300ffc148c..52f664719ae 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -70,6 +70,102 @@ __device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) { return static_cast(biased); } + +static __global__ void quantize_mmq_nvfp4( + const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2) { +#if defined(BLACKWELL_MMA_AVAILABLE) + + const int64_t i0_base = ((int64_t) blockDim.x * blockIdx.y + threadIdx.x) * QK_NVFP4_SUB; + if (i0_base >= ne0) { + return; + } + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.z % ne2; + const int64_t i3 = blockIdx.z / ne2; + const int64_t i01 = ids ? ids[i1] : i1; + const int64_t k_block = i0_base / QK_K; + const int64_t blocks_per_col = (ne0 + QK_K - 1) / QK_K; + if (k_block >= blocks_per_col) { + return; + } + + const int64_t ib = blockIdx.z * ((int64_t) blocks_per_col * ne1) + k_block * ne1 + blockIdx.x; + block_fp4_mmq * y = (block_fp4_mmq *) vy; + block_fp4_mmq * yb = y + ib; + + const int sub = (i0_base % QK_K) / QK_NVFP4_SUB; + + float vals_raw[QK_NVFP4_SUB]; + float amax_raw = 0.0f; + const int64_t base_idx = i3 * s03 + i2 * s02 + i01 * s01; +#pragma unroll + for (int k = 0; k < QK_NVFP4_SUB; k++) { + const int64_t i00 = i0_base + k; + if (i00 < ne00) { + const float v = x[base_idx + i00]; + vals_raw[k] = v; + amax_raw = fmaxf(amax_raw, fabsf(v)); + } else { + vals_raw[k] = 0.0f; + } + } + + static constexpr int test_offsets[5] = { 0, -1, 1, -2, 2}; + const int first_fp8_code = (int) ggml_cuda_fp32_to_ue4m3(amax_raw / 6.0f); + + float best_err = FLT_MAX; + uint8_t fp8_code = 0; + float subblock_scale = 0.0f; + +#pragma unroll // Check +/- 2 to find best code to reduce NVFP4 activation loss. Negligible overhead on Blackwell. + for (int i = 0; i < 5; i++) { + const int test_code = first_fp8_code + test_offsets[i]; + if (test_code < 0 || test_code > 0x7e) { + continue; + } + const uint8_t code = (uint8_t) test_code; + const float test_scale = ggml_cuda_ue4m3_to_fp32(code); + const float test_inv_scale = test_scale > 0.0f ? 0.5f / test_scale : 0.0f; + float cur_err = 0.0f; +#pragma unroll + for (int k = 0; k < QK_NVFP4_SUB; ++k) { + const float v = vals_raw[k]; + const uint8_t q = ggml_cuda_float_to_fp4_e2m1(v, test_inv_scale); + const float err_diff = fabsf(v) - fabsf(kvalues_mxfp4[q & 0x7]) * test_scale; + cur_err = fmaf(err_diff, err_diff, cur_err); + } + + if (cur_err < best_err) { + best_err = cur_err; + fp8_code = test_code; + subblock_scale = test_scale; + } + } + + const float inv_scale = subblock_scale > 0.0f ? 0.5f / subblock_scale : 0.0f; + uint32_t q0 = 0; + uint32_t q1 = 0; +#pragma unroll // this is faster than the previous __nv_fp4x4_e2m1 + for (int k = 0; k < QK_NVFP4_SUB / 4; ++k) { + q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 0], inv_scale) << (8 * k); + q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 8], inv_scale) << (8 * k + 4); + q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 4], inv_scale) << (8 * k); + q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 12], inv_scale) << (8 * k + 4); + } + + uint32_t * yqs = reinterpret_cast(yb->qs); + yqs[2 * sub + 0] = q0; + yqs[2 * sub + 1] = q1; + reinterpret_cast(yb->d4)[sub] = fp8_code; +#else + NO_DEVICE_CODE; // This is for Blackwell NVFP4 activations only. +#endif // defined(BLACKWELL_MMA_AVAILABLE) + +} + // quantize values in the format mxfp4 is stored which is interleaved nibbles // i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31 static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x, @@ -316,28 +412,32 @@ void quantize_mmq_q8_1_cuda( } } -void quantize_mmq_mxfp4_cuda(const float * x, - const int32_t * ids, - void * vy, - [[maybe_unused]] const ggml_type type_src0, - const int64_t ne00, - const int64_t s01, - const int64_t s02, - const int64_t s03, - const int64_t ne0, - const int64_t ne1, - const int64_t ne2, - const int64_t ne3, - cudaStream_t stream) { - GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); - - constexpr int nwarps = 8; - constexpr int vals_per_warp = 2 * QK_MXFP4; - constexpr int vals_per_block = nwarps * vals_per_warp; - - const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block; - const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); - const dim3 block_size(WARP_SIZE, nwarps, 1); - - quantize_mmq_mxfp4<<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); +void quantize_mmq_fp4_cuda( + const float * x, const int32_t * ids, void * vy, const ggml_type type_src0, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { + GGML_ASSERT(type_src0 == GGML_TYPE_MXFP4 || type_src0 == GGML_TYPE_NVFP4); + GGML_ASSERT(ne0 > 0); + + if (type_src0 == GGML_TYPE_NVFP4) { + GGML_ASSERT(ne00 % QK_NVFP4 == 0); + constexpr int nvfp4_block_size = 128; + const int64_t block_num_y = (ne0 + QK_NVFP4_SUB * nvfp4_block_size - 1) / (QK_NVFP4_SUB * nvfp4_block_size); + const dim3 block_size(nvfp4_block_size, 1, 1); + const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); + quantize_mmq_nvfp4<<>>( + x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + } else { + GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); + + constexpr int nwarps = 8; + constexpr int vals_per_warp = 2 * QK_MXFP4; + constexpr int vals_per_block = nwarps * vals_per_warp; + + const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block; + const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); + const dim3 block_size(WARP_SIZE, nwarps, 1); + + quantize_mmq_mxfp4<<>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + } } diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh index 6a91df63578..768a3ae6de6 100644 --- a/ggml/src/ggml-cuda/quantize.cuh +++ b/ggml/src/ggml-cuda/quantize.cuh @@ -26,7 +26,7 @@ void quantize_mmq_q8_1_cuda( ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); -void quantize_mmq_mxfp4_cuda(const float * x, +void quantize_mmq_fp4_cuda(const float * x, const int32_t * ids, void * vy, ggml_type type_src0, From 53011393746fcdc9423af536fba0be02a1d66363 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 29 Apr 2026 08:55:07 +0200 Subject: [PATCH 515/831] TP: fix delayed AllReduce + zero-sized slices (llama/22489) --- ggml/src/ggml-backend-meta.cpp | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 41a61775bd6..fbc02d6458a 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1826,7 +1826,24 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, continue; } - i = get_i_delayed(i); + const int i_delayed = get_i_delayed(i); + + // If we can delay the AllReduce we need to consider the interaction with zero-sized tensor slices. + // A backend with such a slice would normally have valid data after participating in the AllReduce with a node that has + // its compute flag disabled and thus gets its data zeroed out. + // If the AllReduce is delayed then the nodes until that point also need to have their compute flag disabled. + if (i_delayed > i) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + if ((bcj.nodes[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + for (int ii = i + 1; ii <= i_delayed; ii++) { + bcj.nodes[ii]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; + } + } + } + } + + i = i_delayed; for (size_t j = 0; j < n_backends; j++) { auto & bcj = backend_ctx->backend_configs[j]; From 3076725eb074338c0b0fa0bb50bfc00d4aec6497 Mon Sep 17 00:00:00 2001 From: hrushitfujitsu Date: Wed, 29 Apr 2026 13:27:37 +0530 Subject: [PATCH 516/831] ggml : add sve tuned code for gemm_q8_0_4x8_q8_0() kernel (llama/21916) * Added sve tuned code for gemm_q8_0_4x8_q8_0() kernel * Change arrays to static const in repack.cpp --------- Co-authored-by: Vithulep --- ggml/src/ggml-cpu/arch/arm/repack.cpp | 65 +++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 80ff5ce549b..a7534443091 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -5023,6 +5023,71 @@ void ggml_gemm_q8_0_4x8_q8_0(int n, UNUSED(ncols_interleaved); UNUSED(blocklen); +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntb() * 8 == 256) { + const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx; + + static const uint32_t idx_arr[8] = {0, 1, 4, 5, 2, 3, 6, 7}; + svuint32_t idx = svld1(svptrue_b32(), idx_arr); + static const uint32_t idx_arr1[8] = {0, 1, 2, 3, 1, 2, 3, 0}; + svuint32_t idx_sc1 = svld1(svptrue_b32(), idx_arr1); + static const uint32_t idx_arr2[8] = {0, 1, 2, 3, 0, 1, 2, 3}; + svuint32_t idx_sc2 = svld1(svptrue_b32(), idx_arr2); + + for (int y = 0; y < nr; y += 4) { + const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb; + + for (int x = 0; x < nc; x += ncols_interleaved) { + const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb; + const block_q8_0x4 * a_ptr = a_ptr_base; + + svfloat32_t acc_f32_01 = svdup_f32(0); + svfloat32_t acc_f32_23 = svdup_f32(0); + + for (int b = 0; b < nb; b++) { + + svint32_t acc_01 = svdup_s32(0); + svint32_t acc_23 = svdup_s32(0); + + // Process 4 chunks of 8 positions each + for (int chunk = 0; chunk < 4; chunk++) { + svint8_t s_a01 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32); + svint8_t s_a23 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32 + 16); + svint8_t s_b0123 = svld1_s8(svptrue_b8(), b_ptr->qs + chunk * 32); + + acc_01 = svmmla_s32(acc_01, s_a01, s_b0123); + acc_23 = svmmla_s32(acc_23, s_a23, s_b0123); + } + + // Reorder outputs from 2×2 tiles to row-major + // acc[01] = [r0c0, r0c1, r1c0, r1c1, r0c2, r0c3, r1c2, r1c3] + // acc[23] = [r2c0, r2c1, r3c0, r3c1, r2c2, r2c3, r3c2, r3c3] + + svint32_t row01 = svtbl_s32(acc_01, idx); + svint32_t row23 = svtbl_s32(acc_23, idx); + + svfloat16_t temp1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) a_ptr->d); + svfloat16_t temp2 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) b_ptr->d); + svfloat32_t sv_a_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp1, temp1)), idx_sc1); + svfloat32_t sv_b_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp2, temp2)), idx_sc2); + + acc_f32_01 = svmla_f32_x(svptrue_b32(), acc_f32_01, svcvt_f32_s32_x(svptrue_b32(), row01), svmul_lane_f32(sv_b_d, sv_a_d, 0)); + acc_f32_23 = svmla_f32_x(svptrue_b32(), acc_f32_23, svcvt_f32_s32_x(svptrue_b32(), row23), svmul_lane_f32(sv_b_d, sv_a_d, 2)); + a_ptr++; + b_ptr++; + } + + svbool_t pg4 = svptrue_pat_b32(SV_VL4); + svst1_f32(pg4, s + (y+0) * bs + x, acc_f32_01); + svst1_f32(pg4, s + (y+1) * bs + x, svext_f32(acc_f32_01, acc_f32_01, 4)); + svst1_f32(pg4, s + (y+2) * bs + x, acc_f32_23); + svst1_f32(pg4, s + (y+3) * bs + x, svext_f32(acc_f32_23, acc_f32_23, 4)); + } + } + return; + } +#endif // SVE compile-time end + #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx; From fa20229eeb54ee219fe9f67782bbae799d953f2b Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 29 Apr 2026 00:59:00 -0700 Subject: [PATCH 517/831] ggml-webgpu: Fix bug in FlashAttention support check (llama/22492) * Fix flashattention support check for devices that don't support subgroups * set path to none if kv_tile doesn't fit --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 44 ++++++++++++------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 4 ++ 2 files changed, 31 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 34cbf3694b1..b7771ac230e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -494,9 +494,10 @@ struct ggml_webgpu_unary_pipeline_key_hash { /** FlashAttention */ enum ggml_webgpu_flash_attn_path : uint32_t { - GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 0u, - GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 1u, - GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 2u, + GGML_WEBGPU_FLASH_ATTN_PATH_NONE = 0u, + GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 1u, + GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 2u, + GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 3u, }; struct ggml_webgpu_flash_attn_pipeline_key { @@ -534,7 +535,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash { }; struct ggml_webgpu_flash_attn_decisions { - uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX; + uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; uint32_t q_tile = 0; uint32_t kv_tile = 0; uint32_t wg_size = 0; @@ -709,19 +710,29 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec; - decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : - use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : - GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX; + decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : + use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : + context.supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : + GGML_WEBGPU_FLASH_ATTN_PATH_NONE; + + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { + return decisions; + } const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); decisions.kv_direct = key.kv_direct; + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); + // invalidate if even the smallest kv_tile doesn't fit in shared memory + if (max_kv_tile == 0) { + decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; + return decisions; + } if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - const uint32_t min_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); - decisions.q_tile = 1u; - decisions.kv_tile = std::max(8u, std::min(32u, min_kv_tile)); - decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; - decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + decisions.q_tile = 1u; + decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); + decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; + decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); if (decisions.kv_direct) { decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { @@ -734,9 +745,8 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.q_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m; decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::min(64u, ggml_webgpu_flash_attn_max_kv_tile(context, key)) : - std::min(ggml_webgpu_flash_attn_max_kv_tile(context, key), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + std::min(64u, max_kv_tile) : + std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE : std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); @@ -755,7 +765,6 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( context.sg_mat_n; } } - return decisions; } @@ -1364,7 +1373,7 @@ class ggml_webgpu_shader_lib { if (key.src_type == GGML_TYPE_Q1_0) { defines.push_back("BLOCK_SIZE=128u"); } else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || - key.src_type == GGML_TYPE_IQ4_NL) { + key.src_type == GGML_TYPE_IQ4_NL) { defines.push_back("BLOCK_SIZE=32u"); } else if (key.src_type >= GGML_TYPE_Q2_K) { defines.push_back("BLOCK_SIZE=256u"); @@ -2325,6 +2334,7 @@ class ggml_webgpu_shader_lib { size_t storage_offset_alignment) { const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment); + GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE); ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); auto it = flash_attn_pipelines.find(key); if (it != flash_attn_pipelines.end()) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 762d9f8d1b4..f7fd73ae144 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3918,6 +3918,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; const bool has_mask = op->src[3] != nullptr; + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { + supports_op = false; + break; + } if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], From 6119537e9aa65c4dc117c395c2d5acac07eb6b21 Mon Sep 17 00:00:00 2001 From: qiurui144 <39214303+qiurui144@users.noreply.github.com> Date: Wed, 29 Apr 2026 15:59:21 +0800 Subject: [PATCH 518/831] ggml-cpu: cmake: append xsmtvdotii march for SpacemiT IME (llama/22317) * ggml-cpu: cmake: append xsmtvdotii march for SpacemiT IME When GGML_CPU_RISCV64_SPACEMIT=ON is set, ime1_kernels.cpp contains inline asm for the vmadot family which requires the xsmtvdotii custom extension.(problem can see in some blogs and make sure in K3 platform) The current CMakeLists does not include xsmtvdotii, so any toolchain that honours the explicit -march (tested with SpacemiT GCC 15.2) fails at the assembler stage: Error: unrecognized opcode `vmadot v16,v14,v0', extension `xsmtvdotii' required Append _xsmtvdotii to MARCH_STR when GGML_CPU_RISCV64_SPACEMIT is enabled so the IME path can actually build with a capable toolchain. No effect on builds that leave GGML_CPU_RISCV64_SPACEMIT off. toolchain from https://www.spacemit.com/community/resources-download/Tools * Update ggml/src/ggml-cpu/CMakeLists.txt Co-authored-by: alex-spacemit --------- Co-authored-by: alex-spacemit --- ggml/src/ggml-cpu/CMakeLists.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index beebc4760d2..c1c225f0197 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -485,6 +485,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_RV_ZIHINTPAUSE) string(APPEND MARCH_STR "_zihintpause") endif() + if (GGML_CPU_RISCV64_SPACEMIT) + # `xsmtvdotii' is only required for GCC >= 15. + if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND + CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 15) + string(APPEND MARCH_STR "_xsmtvdotii") + endif() + endif() list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) else() From 44e7803661cf16d648b8c0a5b250aea1167d99c1 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Wed, 29 Apr 2026 16:19:33 +0800 Subject: [PATCH 519/831] ggml-cuda: refactor fusion code (llama/22468) * ggml-cuda: refactor fusion code * apply formatting + make env variable truthy --- ggml/src/ggml-cuda/ggml-cuda.cu | 703 ++++++++++++++++---------------- 1 file changed, 355 insertions(+), 348 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 1c2c3b4ac69..fd8dd91714c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3640,6 +3640,357 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return false; } +// try and fuse nodes and return the number of nodes to skip +static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, int i) { + + static bool disable_fusion = getenv("GGML_CUDA_DISABLE_FUSION") != nullptr && std::atoi(getenv("GGML_CUDA_DISABLE_FUSION")); + if (disable_fusion) { + return 0; + } + + ggml_tensor * node = cgraph->nodes[i]; + + //topk-moe + if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX || + cgraph->nodes[i]->op == GGML_OP_ARGSORT) { + ggml_cuda_topk_moe_args args; + const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args); + std::vector ops; + + if (can_fuse) { + const ggml_tensor * logits = node->src[0]; + ggml_tensor * weights = nullptr; + ggml_tensor * ids = nullptr; + const ggml_tensor * bias = nullptr; + const ggml_tensor * clamp = nullptr; + const ggml_tensor * scale = nullptr; + + if (!args.delayed_softmax) { + ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX; + int out_nodes[2]; // nodes which can't be elided + + if (args.prob_bias) { + bias = cgraph->nodes[i + 2]->src[1]; + ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, GGML_OP_VIEW, + GGML_OP_GET_ROWS }); + out_nodes[0] = i + 4; + ids = cgraph->nodes[i + 4]; + } else { + ops.insert(ops.end(), + { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS }); + out_nodes[0] = i + 3; + ids = cgraph->nodes[i + 3]; + } + + if (args.norm) { + ops.insert(ops.end(), + { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, GGML_OP_RESHAPE }); + clamp = cgraph->nodes[i + ops.size() - 3]; + } + if (args.scale) { + ops.insert(ops.end(), { GGML_OP_SCALE }); + scale = cgraph->nodes[i + ops.size() - 1]; + } + + weights = cgraph->nodes[i + ops.size() - 1]; + out_nodes[1] = i + ops.size() - 1; + + if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && + ggml_cuda_should_use_topk_moe(node, logits, weights, ids) && + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) { + ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); + return ops.size() - 1; + } + } else if (!args.norm && !args.prob_bias) { + //special case gpt-oss, no norm, no bias. + ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, + GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }); + weights = cgraph->nodes[i + 5]; + ids = cgraph->nodes[i + 1]; + const ggml_tensor * softmax = cgraph->nodes[i + 4]; + + int out_nodes[2] = { i + 1, i + 5 }; + if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && + ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) && + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) { + ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); + return ops.size() - 1; + } + } + } + } + + //RoPE + view + set-rows + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) { + ggml_tensor * rope = cgraph->nodes[i]; + ggml_tensor * set_rows = cgraph->nodes[i + 2]; + + ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows); + return 2; + } + + // multi-(add or mul) + if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) { + int n_fuse = 0; + ggml_op ops[8]; + std::fill(ops, ops + 8, node->op); + + for (; n_fuse <= 6; ++n_fuse) { + if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { + break; + } + if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) { + break; + } + if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) { + break; + } + } + + n_fuse++; + + if (n_fuse > 1) { + ggml_tensor fused_node; + memcpy(&fused_node, node, sizeof(ggml_tensor)); + for (int j = 0; j < n_fuse - 1; ++j) { + fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + } + fused_node.data = cgraph->nodes[i + n_fuse - 1]->data; + if (node->op == GGML_OP_ADD) { + ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse); + } else { + ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse); + } + return n_fuse - 1; + } + } + + bool fused_mul_mat_vec = false; + int fused_node_count = 0; + + // gate + glu + up + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) { + ggml_tensor * glu = cgraph->nodes[i + 4]; + ggml_tensor * gate_bias_n = glu->src[0]; + ggml_tensor * up_bias_n = glu->src[1]; + + //we don't assume the order for {gate, up}. Instead infer it from the bias tensor + ggml_tensor * gate_n = nullptr; + ggml_tensor * up_n = nullptr; + + if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) { + gate_n = cgraph->nodes[i]; + up_n = cgraph->nodes[i + 2]; + } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) { + gate_n = cgraph->nodes[i + 2]; + up_n = cgraph->nodes[i]; + } else { + continue; + } + + auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) { + if (op_bias == GGML_OP_ADD) { + if (bias_node->src[0] == mul_node) { + return bias_node->src[1]; + } + if (bias_node->src[1] == mul_node) { + return bias_node->src[0]; + } + return (ggml_tensor *) nullptr; + } + GGML_ASSERT(op_bias == GGML_OP_ADD_ID); + GGML_ASSERT(bias_node->src[0] == mul_node); + return bias_node->src[1]; + }; + + ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op); + ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op); + + if (!up_bias_tensor || !gate_bias_tensor) { + continue; + } + + // we don't support repeating adds + if (bias_op == GGML_OP_ADD && (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) || + !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) { + continue; + } + + const ggml_tensor * src0 = up_n->src[0]; + const ggml_tensor * src1 = up_n->src[1]; + const ggml_tensor * ids = up_n->src[2]; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate_n->src[0]; + fusion_data.x_bias = up_bias_tensor; + fusion_data.gate_bias = gate_bias_tensor; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 5; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate_n->src[0]; + fusion_data.x_bias = up_bias_tensor; + fusion_data.gate_bias = gate_bias_tensor; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 5; + break; + } + } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) { + ggml_tensor * glu = cgraph->nodes[i + 2]; + ggml_tensor * gate = glu->src[0]; + ggml_tensor * up = glu->src[1]; + + bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) || + (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]); + + if (!ok) { + continue; + } + + const ggml_tensor * src0 = up->src[0]; + const ggml_tensor * src1 = up->src[1]; + const ggml_tensor * ids = up->src[2]; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate->src[0]; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate->src[0]; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + break; + } + } + } + + if (fused_mul_mat_vec) { + return fused_node_count - 1; + } + + fused_mul_mat_vec = false; + fused_node_count = 0; + + // gate + add + glu + up + add + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (!ggml_can_fuse(cgraph, i, { op, bias_op })) { + continue; + } + + ggml_tensor * mm_node = cgraph->nodes[i]; + ggml_tensor * bias_node = cgraph->nodes[i + 1]; + + ggml_tensor * bias_tensor = nullptr; + if (bias_op == GGML_OP_ADD) { + if (bias_node->src[0] == mm_node) { + bias_tensor = bias_node->src[1]; + } else if (bias_node->src[1] == mm_node) { + bias_tensor = bias_node->src[0]; + } else { + continue; + } + } else { + if (bias_node->src[0] != mm_node) { + continue; + } + bias_tensor = bias_node->src[1]; + } + + const ggml_tensor * src0 = mm_node->src[0]; + const ggml_tensor * src1 = mm_node->src[1]; + const ggml_tensor * ids = mm_node->src[2]; + + if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) { + continue; + } + + if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) { + continue; + } + + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.x_bias = bias_tensor; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) { + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) { + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + } + + if (fused_mul_mat_vec) { + return fused_node_count - 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD }, {})) { + ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]); + return 2; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) { + ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { + ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) || + ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) || + ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) { + ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) { + ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { + ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i + 2], node); + return 2; + } + + return 0; +} + static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) { bool graph_evaluated_or_captured = false; @@ -3786,355 +4137,11 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } - // start of fusion operations - static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); - if (!disable_fusion) { - ggml_cuda_topk_moe_args args; - - if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX || - cgraph->nodes[i]->op == GGML_OP_ARGSORT) { - const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args); - - std::vector ops; - - if (can_fuse) { - const ggml_tensor * logits = node->src[0]; - ggml_tensor * weights = nullptr; - ggml_tensor * ids = nullptr; - const ggml_tensor * bias = nullptr; - const ggml_tensor * clamp = nullptr; - const ggml_tensor * scale = nullptr; - - if (!args.delayed_softmax) { - ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX; - int out_nodes[2]; // nodes which can't be elided - - if (args.prob_bias) { - bias = cgraph->nodes[i + 2]->src[1]; - ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS }); - out_nodes[0] = i + 4; - ids = cgraph->nodes[i + 4]; - } else { - ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, - GGML_OP_GET_ROWS }); - out_nodes[0] = i + 3; - ids = cgraph->nodes[i + 3]; - } - - if (args.norm) { - ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, - GGML_OP_DIV, GGML_OP_RESHAPE }); - clamp = cgraph->nodes[i + ops.size() - 3]; - } - if (args.scale) { - ops.insert(ops.end(), { GGML_OP_SCALE }); - scale = cgraph->nodes[i + ops.size() - 1]; - } - - weights = cgraph->nodes[i + ops.size() - 1]; - out_nodes[1] = i + ops.size() - 1; - - if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && - ggml_cuda_should_use_topk_moe(node, logits, weights, ids) && - ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) { - ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); - i += ops.size() - 1; - continue; - } - } else if (!args.norm && !args.prob_bias) { - //special case gpt-oss, no norm, no bias. - ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, - GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }); - weights = cgraph->nodes[i + 5]; - ids = cgraph->nodes[i + 1]; - const ggml_tensor * softmax = cgraph->nodes[i + 4]; - - int out_nodes[2] = { i + 1, i + 5 }; - if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && - ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) && - ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/ true)) { - ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); - i += ops.size() - 1; - continue; - } - } - } - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) { - ggml_tensor * rope = cgraph->nodes[i]; - ggml_tensor * set_rows = cgraph->nodes[i + 2]; - - ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows); - i += 2; - continue; - } - - if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) { - int n_fuse = 0; - ggml_op ops[8]; - std::fill(ops, ops + 8, node->op); - - for (; n_fuse <= 6; ++n_fuse){ - if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { - break; - } - if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) { - break; - } - if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) { - break; - } - } - - n_fuse++; + int nodes_to_skip = ggml_cuda_try_fuse(cuda_ctx, cgraph, i); - if (n_fuse > 1) { - ggml_tensor fused_node; - memcpy(&fused_node, node, sizeof(ggml_tensor)); - for (int j = 0; j < n_fuse - 1; ++j) { - fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; - } - fused_node.data = cgraph->nodes[i + n_fuse - 1]->data; - if (node->op == GGML_OP_ADD) { - ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse); - } else { - ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse); - } - i += n_fuse - 1; - - continue; - } - } - - bool fused_mul_mat_vec = false; - int fused_node_count = 0; - - for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { - const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; - - if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) { - ggml_tensor * glu = cgraph->nodes[i + 4]; - ggml_tensor * gate_bias_n = glu->src[0]; - ggml_tensor * up_bias_n = glu->src[1]; - - //we don't assume the order for {gate, up}. Instead infer it from the bias tensor - ggml_tensor * gate_n = nullptr; - ggml_tensor * up_n = nullptr; - - if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) { - gate_n = cgraph->nodes[i]; - up_n = cgraph->nodes[i + 2]; - } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) { - gate_n = cgraph->nodes[i + 2]; - up_n = cgraph->nodes[i]; - } else { - continue; - } - - auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) { - if (op_bias == GGML_OP_ADD) { - if (bias_node->src[0] == mul_node) { - return bias_node->src[1]; - } - if (bias_node->src[1] == mul_node) { - return bias_node->src[0]; - } - return (ggml_tensor *) nullptr; - } - GGML_ASSERT(op_bias == GGML_OP_ADD_ID); - GGML_ASSERT(bias_node->src[0] == mul_node); - return bias_node->src[1]; - }; - - ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op); - ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op); - - if (!up_bias_tensor || !gate_bias_tensor) { - continue; - } - - // we don't support repeating adds - if (bias_op == GGML_OP_ADD && - (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) || - !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) { - continue; - } - - const ggml_tensor * src0 = up_n->src[0]; - const ggml_tensor * src1 = up_n->src[1]; - const ggml_tensor * ids = up_n->src[2]; - - if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate_n->src[0]; - fusion_data.x_bias = up_bias_tensor; - fusion_data.gate_bias = gate_bias_tensor; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 5; - break; - } - - if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate_n->src[0]; - fusion_data.x_bias = up_bias_tensor; - fusion_data.gate_bias = gate_bias_tensor; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 5; - break; - } - } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) { - ggml_tensor * glu = cgraph->nodes[i + 2]; - ggml_tensor * gate = glu->src[0]; - ggml_tensor * up = glu->src[1]; - - bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) - || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]); - - if (!ok) continue; - - const ggml_tensor * src0 = up->src[0]; - const ggml_tensor * src1 = up->src[1]; - const ggml_tensor * ids = up->src[2]; - - if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate->src[0]; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 3; - break; - } - - if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate->src[0]; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 3; - break; - } - } - } - - if (fused_mul_mat_vec) { - i += fused_node_count - 1; - continue; - } - - fused_mul_mat_vec = false; - fused_node_count = 0; - - for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { - const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; - - if (!ggml_can_fuse(cgraph, i, { op, bias_op })) { - continue; - } - - ggml_tensor * mm_node = cgraph->nodes[i]; - ggml_tensor * bias_node = cgraph->nodes[i + 1]; - - ggml_tensor * bias_tensor = nullptr; - if (bias_op == GGML_OP_ADD) { - if (bias_node->src[0] == mm_node) { - bias_tensor = bias_node->src[1]; - } else if (bias_node->src[1] == mm_node) { - bias_tensor = bias_node->src[0]; - } else { - continue; - } - } else { - if (bias_node->src[0] != mm_node) { - continue; - } - bias_tensor = bias_node->src[1]; - } - - const ggml_tensor * src0 = mm_node->src[0]; - const ggml_tensor * src1 = mm_node->src[1]; - const ggml_tensor * ids = mm_node->src[2]; - - if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) { - continue; - } - - if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) { - continue; - } - - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.x_bias = bias_tensor; - - if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) { - ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 2; - break; - } - - if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) { - ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 2; - break; - } - } - - if (fused_mul_mat_vec) { - i += fused_node_count - 1; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { - ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); - i += 2; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) { - ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { - ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) || - ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) || - ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) { - ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) { - ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { - i += 2; - ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); - continue; - } + if (nodes_to_skip != 0) { + i += nodes_to_skip; + continue; } #ifndef NDEBUG assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); From ad670182d95023221b71a0852adf245e7b73cd1c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 29 Apr 2026 16:41:45 +0300 Subject: [PATCH 520/831] ggml : bump version to 0.10.1 (ggml/1469) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index b9f7deb150d..f7b6f1f334f 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 10) -set(GGML_VERSION_PATCH 0) +set(GGML_VERSION_PATCH 1) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 320c048724d0c6e393540ff6ac51eec23afea04c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 30 Apr 2026 21:44:28 +0300 Subject: [PATCH 521/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 58863dc6bbb..236ae95a80f 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -404fcb9d7c96989569e68c9e7881ee3465a05c50 +387fa29fbbf3149f06a631c7850b6c35c24b0232 From c59a7736051d497d9370db54f01c46845e6bb8ad Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 1 May 2026 11:53:27 +0300 Subject: [PATCH 522/831] examples : update to Q1_0 --- examples/common-ggml.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/common-ggml.cpp b/examples/common-ggml.cpp index 6f02a2504c5..3f2eded86f7 100644 --- a/examples/common-ggml.cpp +++ b/examples/common-ggml.cpp @@ -74,6 +74,7 @@ bool ggml_common_quantize_0( case GGML_FTYPE_MOSTLY_BF16: case GGML_FTYPE_MOSTLY_MXFP4: case GGML_FTYPE_MOSTLY_NVFP4: + case GGML_FTYPE_MOSTLY_Q1_0: { fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype); return false; @@ -215,6 +216,7 @@ bool ggml_common_quantize_0( case GGML_TYPE_TQ2_0: case GGML_TYPE_MXFP4: case GGML_TYPE_NVFP4: + case GGML_TYPE_Q1_0: case GGML_TYPE_COUNT: { fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); From 9f2cec1840b38f510b0098fc767a622c40b8a433 Mon Sep 17 00:00:00 2001 From: shalinib-ibm Date: Wed, 29 Apr 2026 16:02:40 +0530 Subject: [PATCH 523/831] ggml-cpu : disable tiled matmul on AIX to fix page boundary segfault (llama/22293) * ggml-cpu : disable tiled matmul on AIX to fix page boundary segfault vec_xst operations in the tiled path crash on AIX when writing near 4KB page boundaries due to strict memory protection. Fall back to mnpack implementation on AIX for stable execution. Signed-off-by: Shalini Salomi Bodapati * Update ggml/src/ggml-cpu/llamafile/sgemm.cpp Co-authored-by: Aaron Teo * Update sgemm.cpp * Update sgemm.cpp --------- Signed-off-by: Shalini Salomi Bodapati Co-authored-by: Aaron Teo --- ggml/src/ggml-cpu/llamafile/sgemm.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 34e320e2f50..e13828e3be6 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -2321,6 +2321,9 @@ class tinyBLAS_Q0_PPC { } void matmul(int64_t m, int64_t n) { + #if defined(_AIX) || defined(__BIG_ENDIAN__) + mnpack(0, m, 0, n); + #else const int64_t mc = 64; const int64_t kc = 64; int64_t nc = 64; @@ -2334,7 +2337,6 @@ class tinyBLAS_Q0_PPC { } else { n_aligned = (n / 64) * 64; } - if (n_aligned > 0) { if (n_aligned % 64 == 0) nc = 64; else if (n_aligned == n) nc = n; @@ -2352,6 +2354,7 @@ class tinyBLAS_Q0_PPC { } else { mnpack(0, m, 0, n); } + #endif } private: @@ -3191,12 +3194,16 @@ class tinyBLAS_PPC { } void matmul(int64_t m, int64_t n) { + #if defined(_AIX) || defined(__BIG_ENDIAN__) + mnpack(0, m, 0, n); + #else int64_t mc = 256; int64_t nc = 256; int64_t kc = 256; if (m % mc == 0 && n % nc == 0 && k % kc == 0) { matmul_tiled(m, n, mc, nc, kc); } else { mnpack(0, m, 0, n); } + #endif } private: From aec8e69c2f1f78ea3872361b1483ba99ebf74468 Mon Sep 17 00:00:00 2001 From: Anav Prasad Date: Wed, 29 Apr 2026 11:39:56 -0700 Subject: [PATCH 524/831] CUDA: fuse SSM_CONV + ADD(bias) + SILU (llama/22478) --- ggml/src/ggml-cuda/ggml-cuda.cu | 35 ++++++++++++++++++++++++++++++++- ggml/src/ggml-cuda/ssm-conv.cu | 34 ++++++++++++++++++++++++++------ ggml/src/ggml-cuda/ssm-conv.cuh | 2 +- 3 files changed, 63 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fd8dd91714c..0e6f74685d6 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3556,6 +3556,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; const ggml_tensor * silu = cgraph->nodes[node_idx+1]; + if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) { + return false; + } if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { return false; @@ -3564,6 +3567,31 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return true; } + if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_ADD + && ops.begin()[2] == GGML_OP_UNARY && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { + const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; + const ggml_tensor * add = cgraph->nodes[node_idx+1]; + const ggml_tensor * silu = cgraph->nodes[node_idx+2]; + if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) { + return false; + } + + if (ssm_conv->type != GGML_TYPE_F32 || add->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + + // ADD must consume ssm_conv's output and broadcast a 1-D channel-wise bias. + const ggml_tensor * bias = (add->src[0] == ssm_conv) ? add->src[1] : add->src[0]; + if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) { + return false; + } + if (ggml_nelements(bias) != ssm_conv->ne[0] || bias->ne[0] != ssm_conv->ne[0]) { + return false; + } + + return true; + } + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) { const ggml_tensor * unary = cgraph->nodes[node_idx]; @@ -3966,8 +3994,13 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph return 1; } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { + ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]); + return 2; + } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { - ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1]); + ggml_cuda_op_ssm_conv(*cuda_ctx, node, /*bias_add_node=*/ nullptr, cgraph->nodes[i + 1]); return 1; } diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index b77cdc1c137..4841389fbc8 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -3,6 +3,7 @@ template static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const float * __restrict__ bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { @@ -27,6 +28,8 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float w[j] = w_block[tid * stride_w + j]; } + float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f; + for (int64_t i = 0; i < n_t; i++) { float sumf = 0.0f; @@ -42,12 +45,14 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float for (size_t j = 0; j < d_conv; j++) { sumf += x[(i + j) % d_conv] * w[j]; } + sumf += b; y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } template static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const float * __restrict__ bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { @@ -97,6 +102,8 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, w[j] = w_block[tid * stride_w + j]; } + float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f; + // Compute from shared memory for (int64_t i = 0; i < local_n_t; i++) { float sumf = 0.0f; @@ -104,12 +111,13 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, for (size_t j = 0; j < d_conv; j++) { sumf += smem[tid * n_cols + i + j] * w[j]; } + sumf += b; y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } template -static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, +static void ssm_conv_f32_cuda(const float * src0, const float * src1, const float * bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t, const int64_t n_s, cudaStream_t stream) { @@ -120,14 +128,14 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int constexpr int kNC = decltype(NC)::value; if (n_t <= 32) { const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); - ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + ssm_conv_f32<<>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { const int64_t split_n_t = 32; dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float); ssm_conv_long_token_f32<<>>( - src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); + src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } }; @@ -140,11 +148,18 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int } } -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) { +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node, ggml_tensor * silu_dst) { const struct ggml_tensor * src0 = dst->src[0]; // conv_x const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight + const bool fuse_bias = bias_add_node != nullptr; const bool fuse_silu = silu_dst != nullptr; + // bias always comes with silu. + GGML_ASSERT(!fuse_bias || fuse_silu); + + // The bias (when fused) is the non-conv operand of the ADD node. + const struct ggml_tensor * bias = fuse_bias ? (bias_add_node->src[0] == dst ? bias_add_node->src[1] : bias_add_node->src[0]) : nullptr; + // When fusing, write to silu_dst (the node downstream references). const struct ggml_tensor * out = fuse_silu ? silu_dst : dst; @@ -160,16 +175,23 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g const float * src0_d = (const float *) src0->data; const float * src1_d = (const float *) src1->data; + const float * bias_d = fuse_bias ? (const float *) bias->data : nullptr; float * dst_d = (float *) out->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(out->type == GGML_TYPE_F32); + if (fuse_bias) { + GGML_ASSERT(bias->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(bias)); + GGML_ASSERT(ggml_nelements(bias) == nr); + } + if (fuse_silu) { - ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + ssm_conv_f32_cuda(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], out->nb[2], nc, nr, n_t, n_s, stream); } else { - ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + ssm_conv_f32_cuda(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], out->nb[2], nc, nr, n_t, n_s, stream); } } diff --git a/ggml/src/ggml-cuda/ssm-conv.cuh b/ggml/src/ggml-cuda/ssm-conv.cuh index f96a1cd2484..8514ca84920 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cuh +++ b/ggml/src/ggml-cuda/ssm-conv.cuh @@ -1,3 +1,3 @@ #include "common.cuh" -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr); +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node = nullptr, ggml_tensor * silu_dst = nullptr); From 66392cf1a2624fb2688d10c835bdc178f669460b Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 29 Apr 2026 11:51:21 -0700 Subject: [PATCH 525/831] hexagon: make vmem and buffer-size configurable (llama/22487) * hexagon: allow host to set max vmem size We use a sane default but it's helpful to allow for an override if needed. * hexagon: add support for measuring vmem space and move pinned mmaping management to host * hexagon: update vmem checks to use uint64 * hexagon: bump op buffers to 16 (matches max mmaps) * hexagon: bump default vmem to 3.2GB * hexagon: add support for autodetecting vmem space and some logging cleanup in that area * hexagon: fix whitespace warnings * Update scripts/snapdragon/adb/run-cli.sh Co-authored-by: Pascal * hex-adb: fix run-completion script --------- Co-authored-by: Pascal --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 238 ++++++++++++++---------- ggml/src/ggml-hexagon/htp/htp-ctx.h | 4 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 8 +- ggml/src/ggml-hexagon/htp/htp_iface.idl | 4 +- ggml/src/ggml-hexagon/htp/main.c | 27 ++- 5 files changed, 162 insertions(+), 119 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 0d9b5e289bb..9345da62168 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -48,14 +48,16 @@ using intvec = std::vector; using uintvec = std::vector; using u32vec = std::vector; -static size_t opt_ndev = 1; -static size_t opt_nhvx = 0; // use all -static int opt_arch = 0; // autodetect -static int opt_etm = 0; -static int opt_verbose = 0; -static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu) -static int opt_hostbuf = 1; // hostbuf ON by default -static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only +static int opt_arch = 0; // autodetect +static size_t opt_ndev = 1; +static size_t opt_nhvx = 0; // use all +static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only +static size_t opt_vmem = HTP_OP_MAX_VMEM_DEFAULT; // max available va space for buffer mappings +static size_t opt_mbuf = 1ul * 1024 * 1024 * 1024; // max buffer size +static int opt_etm = 0; +static int opt_verbose = 0; +static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu) +static int opt_hostbuf = 1; // hostbuf ON by default // Default PMU events, if profiling with PMU (mode=2) is enabled // See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html @@ -66,6 +68,7 @@ static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C } static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; static int opt_opbatch = 1024; // max number of ops in a batch static int opt_opqueue = 16; // max number of pending batches + static std::regex* opt_opfilter = NULL; // regex of ops to not claim #define HEX_VERBOSE(...) \ @@ -110,7 +113,7 @@ static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct if (!opt_verbose) return; op_desc desc(op); - GGML_LOG_DEBUG("ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), + GGML_LOG_DEBUG("ggml-hex: %s supports-op %s: %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no"); } @@ -118,8 +121,6 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) { if (!opt_profile) return; - op_desc desc(op); - char pmu_str[256] = ""; if (opt_profile > 1) { static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters"); @@ -127,6 +128,7 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]); } + op_desc desc(op); GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(), ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, op_usec, op_cycles, pmu_str); } @@ -191,33 +193,30 @@ struct ggml_hexagon_shared_buffer { bool mapped; bool pinned; - void mmap(bool pinned = false) { - int err = fastrpc_mmap(sess->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD_DELAYED); + void mmap() { + fastrpc_map_flags flags = this->pinned ? FASTRPC_MAP_FD : FASTRPC_MAP_FD_DELAYED; + + int err = fastrpc_mmap(sess->domain_id, this->fd, (void *) this->base, 0, this->size, flags); if (err != 0) { GGML_LOG_ERROR("ggml-hex: %s buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(), sess->domain_id, this->size, this->fd, (unsigned) err); throw std::runtime_error("ggml-hex: fastrpc_mmap failed (see log for details)"); } - if (pinned) { - err = htp_iface_mmap(sess->handle, this->fd, this->size, pinned); - if (err != 0) { - GGML_LOG_ERROR("ggml-hex: %s buffer pinning failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(), - sess->domain_id, this->size, this->fd, (unsigned) err); - throw std::runtime_error("ggml-hex: htp_iface_mmap failed (see log for details)"); - } - } - - this->mapped = true; - this->pinned = pinned; HEX_VERBOSE("ggml-hex: %s mapped buffer: base %p size %zu fd %d pinned %u\n", sess->c_name(), (void *) this->base, this->size, this->fd, pinned); + + this->mapped = true; } void unmap() { if (!this->mapped) return; - htp_iface_munmap(sess->handle, this->fd); + if (!this->pinned) { + // HTP might still hold a reference, tell it drop it + htp_iface_munmap(sess->handle, this->fd); + } + fastrpc_munmap(sess->domain_id, this->fd, (void *) this->base, this->size); HEX_VERBOSE("ggml-hex: %s unmapped buffer: base %p size %zu fd %d\n", sess->c_name(), @@ -227,7 +226,7 @@ struct ggml_hexagon_shared_buffer { this->fd = -1; } - void alloc(size_t size, bool pinned = false) { + void alloc(size_t size) { if (this->base) return; this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, size); @@ -245,8 +244,7 @@ struct ggml_hexagon_shared_buffer { HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d pinned %d\n", sess->c_name(), (void *) this->base, this->size, this->fd, (int) pinned); - - mmap(pinned); + mmap(); } void free() { @@ -262,15 +260,14 @@ struct ggml_hexagon_shared_buffer { } ggml_hexagon_shared_buffer(ggml_hexagon_session * sess, size_t size, bool pinned = false) { - size += 4 * 1024; // extra page for padding - this->sess = sess; this->size = 0; this->base = nullptr; this->fd = -1; this->mapped = false; + this->pinned = pinned; - alloc(size, pinned); + alloc(size); } ~ggml_hexagon_shared_buffer() { @@ -1475,6 +1472,7 @@ static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer( ggml_backend_buffer_type_t buffer_type, size_t size) { auto sess = static_cast(buffer_type->context)->sess; try { + size += 4 * 1024; // guard page ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size); return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size); } catch (const std::exception & exc) { @@ -1487,6 +1485,7 @@ static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffe ggml_backend_buffer_type_t buffer_type, size_t size) { auto sess = static_cast(buffer_type->context)->sess; try { + size += 4 * 1024; // guard page ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size); return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size); } catch (const std::exception & exc) { @@ -1505,7 +1504,7 @@ static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffe } static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { - return 1UL * 1024 * 1024 * 1024; // 1GB per buffer + return opt_mbuf; // typically 1GB per buffer GGML_UNUSED(buffer_type); } @@ -1573,14 +1572,14 @@ struct ggml_hexagon_opbatch { d_map.clear(); } - ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size) { + ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size, size_t max_vmem) { this->sess = sess; n_bufs_max = HTP_OP_MAX_BUFS; n_ops_max = batch_size; n_tens_max = n_ops_max + n_ops_max * HTP_OP_MAX_INPUTS; - b_vmem_max = HTP_OP_MAX_VMEM; + b_vmem_max = max_vmem; ops.resize(n_ops_max); @@ -1592,6 +1591,9 @@ struct ggml_hexagon_opbatch { t_map.reserve(n_tens_max); d_map.reserve(n_tens_max); + GGML_LOG_INFO("ggml-hex: %s op batching: n-bufs %u n-tensors %u n-ops %u vmem %zu\n", + sess->c_name(), n_bufs_max, n_tens_max, n_ops_max, b_vmem_max); + reset(); } @@ -1925,6 +1927,8 @@ void ggml_hexagon_session::flush_batch() { // Bump pending flag (cleared in the session::flush once we get the response) this->op_pending++; // atomic inc + HEX_VERBOSE("ggml-hex: %s queue-opbatch: %p size %u\n", this->c_name(), dbuf.ptr, dbuf.size); + int err = dspqueue_write(this->queue, 0, 1, &dbuf, sizeof(req), (const uint8_t*) &req, DSPQUEUE_TIMEOUT); if (err != 0) { GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->c_name(), (unsigned) err); @@ -1944,6 +1948,35 @@ void ggml_hexagon_session::flush(bool all) { flush_pending(all); } +static size_t ggml_hexagon_measure_max_vmem(ggml_hexagon_session *sess) { + // Allocate a bunch pinned buffers till failure. + // This is kind of expensive but handy for figuring out exactly how much we can mmap on a specific device. + // Typically we're going to allocate all/most of these buffers anyway for the model weights. + + std::vector sbufs; + + const size_t MiB = 1024 * 1024; + const size_t GiB = MiB * 1024; + + size_t vmem = 0; + size_t step = 256u * MiB; + + try { + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + + while (1) { + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, step, true)); + vmem += step; + } + } catch (...) { } + + for (auto b : sbufs) { delete b; } + + return vmem - step; // backoff to account for overhead from internal mappings +} + void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->valid_session = false; this->valid_handle = false; @@ -1957,7 +1990,7 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->op_pending = 0; - GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str()); + GGML_LOG_DEBUG("ggml-hex: %s allocating new session\n", this->name.c_str()); domain * my_domain = get_domain(this->domain_id); if (my_domain == NULL) { @@ -2033,9 +2066,6 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->valid_handle = true; - GGML_LOG_INFO("ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\n", this->name.c_str(), - this->session_id, this->domain_id, session_uri, (unsigned long) this->handle); - // Enable FastRPC QoS mode { struct remote_rpc_control_latency l; @@ -2047,6 +2077,9 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } } + GGML_LOG_INFO("ggml-hex: %s new session : session-id %d domain-id %d uri %s handle 0x%lx\n", this->c_name(), + this->session_id, this->domain_id, session_uri, (unsigned long) this->handle); + const size_t req_q_size = (sizeof(htp_opbatch_req) * opt_opqueue * 2) + 1024; const size_t rsp_q_size = (sizeof(htp_opbatch_rsp) * opt_opqueue * 2) + 1024; @@ -2091,13 +2124,19 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } // Allocate buffers and state for op batching - this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch); this->op_queue = new ggml_hexagon_opqueue(this, opt_opbatch, opt_opqueue); - // Start processing op batch requests - err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx); + if (!opt_vmem) { + opt_vmem = ggml_hexagon_measure_max_vmem(this); + GGML_LOG_INFO("ggml-hex: %s measured max vmem %zu\n", this->c_name(), opt_vmem); + } + + this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch, opt_vmem); + + // Start dspqueue/opbatch processing + err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx, opt_vmem); if (err != 0) { - GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err); + GGML_LOG_ERROR("ggml-hex: %s failed to start session: 0x%08x\n", this->c_name(), (unsigned) err); throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); } this->valid_iface = true; @@ -2108,17 +2147,17 @@ void ggml_hexagon_session::release() noexcept(true) { int err; - delete this->op_batch; - delete this->op_queue; - - // Stop the DSP-side service and close the queue if (this->valid_iface) { + // Stop dspqueue/opbatch processing err = htp_iface_stop(this->handle); if (err != 0) { GGML_ABORT("ggml-hex: htp_iface_stop failed: 0x%08x\n", (unsigned) err); } } + delete this->op_batch; + delete this->op_queue; + if (opt_etm) { err = htp_iface_etm(this->handle, 0); if (err != 0) { @@ -3380,21 +3419,6 @@ struct ggml_hexagon_registry { ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) { GGML_LOG_INFO("ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev %zu\n", opt_ndev); - if (!opt_arch) { - int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch); - if (err != 0) { - GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err); - opt_arch = 73; - } - } - -#if defined(__ANDROID__) - if (opt_arch < 75) { - opt_ndev = 1; - GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n"); - } -#endif - GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch); // Create devices / sessions @@ -3480,32 +3504,67 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL, "please update hexagon_type to match ggml_type"); - const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); - const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); - const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); - const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); - const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); - const char * str_opfilter= getenv("GGML_HEXAGON_OPFILTER"); - const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); - const char * str_etm = getenv("GGML_HEXAGON_ETM"); - const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); - const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX"); - const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); - const char * str_arch = getenv("GGML_HEXAGON_ARCH"); + const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); + const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); + const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); + const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); + const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); + const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER"); + const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); + const char * str_etm = getenv("GGML_HEXAGON_ETM"); + const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); + const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX"); + const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); + const char * str_arch = getenv("GGML_HEXAGON_ARCH"); + const char * str_vmem = getenv("GGML_HEXAGON_VMEM"); + const char * str_mbuf = getenv("GGML_HEXAGON_MBUF"); + + // Init Arch first since it affects other defaults + if (!str_arch) { + int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err); + opt_arch = 73; + } + } else { + if (str_arch[0] == 'v' || str_arch[0] == 'V') { + str_arch++; + } + opt_arch = strtoul(str_arch, NULL, 0); + } + + size_t MiB = 1024 * 1024; + + // Update vmem default + opt_vmem = opt_arch >= 75 ? HTP_OP_MAX_VMEM_DEFAULT : 3000 * MiB; auto RE_ICASE = std::regex_constants::icase; - opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; - opt_verbose = str_verbose ? atoi(str_verbose) : 0; - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; - opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; - opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; - opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; - opt_etm = str_etm ? atoi(str_etm) : 0; - opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; - opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; - opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; + opt_verbose = str_verbose ? atoi(str_verbose) : 0; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; + opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; + opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; + opt_profile = str_profile ? atoi(str_profile) : 0; + opt_etm = str_etm ? atoi(str_etm) : 0; + opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; + opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; + opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf; + opt_vmem = str_vmem ? strtoul(str_vmem, NULL, 0) * MiB : opt_vmem; + + if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { + opt_ndev = GGML_HEXAGON_MAX_SESSIONS; + } + +#if defined(__ANDROID__) + if (opt_arch < 75) { + opt_ndev = 1; + GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n"); + } +#endif if (str_profile) { opt_pmu_evt = [&]() -> std::vector { @@ -3520,17 +3579,6 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { vec_to_str(opt_pmu_evt).c_str()); } - if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { - opt_ndev = GGML_HEXAGON_MAX_SESSIONS; - } - - if (str_arch) { - if (str_arch[0] == 'v') { - str_arch++; - } - opt_arch = strtoul(str_arch, NULL, 0); - } - reg->context = new ggml_hexagon_registry(reg); } diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index d704fedee9d..e9c563ca887 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -20,7 +20,7 @@ struct htp_mmap { uint64_t size; uint64_t base; uint32_t fd; - uint32_t pinned; + uint32_t reserved; }; // Scratchpad state @@ -77,6 +77,8 @@ struct htp_context { atomic_bool vtcm_valid; atomic_bool vtcm_needs_release; + uint64_t max_vmem; + struct htp_ops_context octx; #ifdef HTP_HAS_HMX diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 4397245c5b8..66a3150c1a0 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -90,15 +90,11 @@ enum htp_op_code { #define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS #define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS -#define HTP_OP_MAX_BUFS 8 +#define HTP_OP_MAX_BUFS 16 #define HTP_OP_MAX_REQS 256 #define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS) -#if __HVX_ARCH__ < 75 -#define HTP_OP_MAX_VMEM (3167538380u) -#else -#define HTP_OP_MAX_VMEM (3221225472u) -#endif +#define HTP_OP_MAX_VMEM_DEFAULT (3355443200u) #define HTP_MMAP_MAX_VMEM (2147483648u) diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index dbcafd1d856..d696a5fba0c 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -11,9 +11,9 @@ struct htp_iface_pmu_conf { }; interface htp_iface : remote_handle64 { - AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx); + AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem); AEEResult stop(); - AEEResult mmap(in uint32 fd, in uint32 size, in uint32 pinned); + AEEResult mmap(in uint32 fd, in uint32 size); AEEResult munmap(in uint32 fd); AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu); AEEResult etm(in uint32 enable); diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index f58347304be..49c1a15b344 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -210,7 +210,7 @@ AEEResult htp_iface_close(remote_handle64 handle) { return AEE_SUCCESS; } -AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32 pinned) { +AEEResult htp_iface_mmap(remote_handle64 handle, uint32_t fd, uint32_t size) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { return AEE_EBADPARM; @@ -220,7 +220,6 @@ AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32 for (uint32_t i=0; immap[i]; if (m->fd == fd) { - m->pinned = pinned; return AEE_SUCCESS; } } @@ -229,7 +228,7 @@ AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32 for (uint32_t i=0; immap[i]; if (!m->size) { - FARF(HIGH, "mmap : fd %u size %u pinned %u", fd, size, pinned); + FARF(HIGH, "mmap : fd %u size %u", fd, size); #if __HVX_ARCH__ > 73 void *va = HAP_mmap2(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0); #else @@ -248,7 +247,6 @@ AEEResult htp_iface_mmap(remote_handle64 handle, uint32 fd, uint32 size, uint32 m->base = (uint64_t) va; m->fd = fd; m->size = size; - m->pinned = pinned; return AEE_SUCCESS; } @@ -275,7 +273,6 @@ AEEResult htp_iface_munmap(remote_handle64 handle, uint32 fd) { m->size = 0; m->base = NULL; m->fd = -1; - m->pinned = 0; } } @@ -358,7 +355,7 @@ static void vtcm_free(struct htp_context * ctx) { static void htp_packet_callback(dspqueue_t queue, int error, void * context); static void htp_error_callback(dspqueue_t queue, int error, void * context); -AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx) { +AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { @@ -376,12 +373,12 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que htp_error_callback, // Error callback; no errors expected on the DSP (void *) ctx, // Callback context &ctx->queue); - if (err) { FARF(ERROR, "Queue import failed with 0x%08x", (unsigned) err); return err; } + ctx->max_vmem = max_vmem; ctx->thread_id = qurt_thread_get_id(); ctx->thread_prio = qurt_thread_get_priority(ctx->thread_id); @@ -622,8 +619,8 @@ static inline bool reuse_buf(struct htp_context *ctx, uint32_t *m_reuse, struct } static inline void drop_mmap(struct htp_context *ctx, struct htp_mmap *m) { - if (m->size && !m->pinned) { - FARF(HIGH, "unmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned); + if (m->size) { + FARF(HIGH, "unmap : fd %u base %p size %u", m->fd, (void*) m->base, (uint32_t) m->size); #if __HVX_ARCH__ > 73 HAP_munmap2((void *) m->base, m->size); #else @@ -660,9 +657,8 @@ static inline void mmap_buf(struct htp_context *ctx, struct htp_buf_desc *b) { m->base = b->base = (uint64_t) va; m->fd = b->fd; m->size = b->size; - m->pinned = 0; - FARF(HIGH, "mmap : fd %u base %p size %u pinned %u", m->fd, (void*) m->base, (uint32_t) m->size, m->pinned); + FARF(HIGH, "mmap : fd %u base %p size %u", m->fd, (void*) m->base, (uint32_t) m->size); return; } } @@ -672,8 +668,8 @@ static void prep_op_bufs(struct htp_context *ctx, struct htp_buf_desc *bufs, uin uint32_t m_reuse = 0; // mmap reuse mask (index from ctx->mmap array) uint32_t b_reuse = 0; // buf reuse count - size_t m_vmem = 0; // mapped vmem - size_t e_vmem = 0; // extra vmem + uint64_t m_vmem = 0; // mapped vmem + uint64_t e_vmem = 0; // extra vmem // See what we can reuse for (uint32_t i=0; i < n_bufs; i++) { @@ -687,9 +683,10 @@ static void prep_op_bufs(struct htp_context *ctx, struct htp_buf_desc *bufs, uin // See how much vmem we have mmaped right now for (uint32_t i=0; immap[i].size; } - FARF(HIGH, "prep-bufs : pass1 mmap-vmem %zu extra-vmem %zu n-bufs %u b-reuse %u", m_vmem, e_vmem, n_bufs, b_reuse); + FARF(HIGH, "prep-bufs : pass1 mmap-vmem %zu extra-vmem %zu max-vmem %zu : n-bufs %u b-reuse %u", + (size_t) m_vmem, (size_t) e_vmem, (size_t) ctx->max_vmem, n_bufs, b_reuse); - if ((m_vmem + e_vmem) > HTP_OP_MAX_VMEM) { + if ((m_vmem + e_vmem) > ctx->max_vmem) { // Drop unused mappings for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) { bool used = m_reuse & (1< Date: Wed, 29 Apr 2026 22:58:32 -0700 Subject: [PATCH 526/831] add fast matmul iquants (llama/22504) --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 19 + ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 +- .../wgsl-shaders/mul_mat_decls.tmpl | 423 ++++++++++++++++++ 3 files changed, 443 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index b7771ac230e..5239164cd00 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1806,6 +1806,25 @@ class ggml_webgpu_shader_lib { defines.push_back("U32_DEQUANT_HELPERS"); defines.push_back("SRC0_INNER_TYPE=u32"); + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + default: + break; + } + variant += std::string("_") + src0_name; break; } diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f7fd73ae144..5e55a2a1e1b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1422,7 +1422,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: - use_fast = is_vec; + use_fast = true; break; default: break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 15b22c4f731..51cf08f196f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -740,3 +740,426 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } #endif // INIT_SRC0_SHMEM_Q6_K + +#ifdef INIT_SRC0_SHMEM_IQ4_NL +const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 18u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_at_src0(block_byte_base); + + let pos = k_in_block % 16u; + let nib_shift = (k_in_block / 16u) * 4u; + let q_packed = load_u32_at_src0(block_byte_base + 2u + (pos / 4u) * 4u); + let nib = (get_byte(q_packed, pos % 4u) >> nib_shift) & 0xFu; + + shmem[elem_idx] = d * f16(kvalues_iq4nl[nib]); + } +} +#endif // INIT_SRC0_SHMEM_IQ4_NL + +#ifdef INIT_SRC0_SHMEM_IQ4_XS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 136u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + + let d_scales_h = load_u32_at_src0(block_byte_base); + let d = bitcast>(d_scales_h).x; + let scales_h = d_scales_h >> 16u; + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + + let scales_l_word = load_u32_at_src0(block_byte_base + 4u); + let ls_lo = (get_byte(scales_l_word, ib / 2u) >> ((ib & 1u) * 4u)) & 0xFu; + let ls_hi = ((scales_h >> (2u * ib)) & 3u) << 4u; + let dl = d * f16(i32(ls_lo | ls_hi) - 32); + + let iqs = ib * 16u + (pos % 16u); + let nib_shift = (pos / 16u) * 4u; + let q_packed = load_u32_at_src0(block_byte_base + 8u + (iqs / 4u) * 4u); + let nib = (get_byte(q_packed, iqs % 4u) >> nib_shift) & 0xFu; + + shmem[elem_idx] = dl * f16(kvalues_iq4nl[nib]); + } +} +#endif // INIT_SRC0_SHMEM_IQ4_XS + +#ifdef INIT_SRC0_SHMEM_IQ1_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 50u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + let l = pos / 8u; + let j = pos % 8u; + + let qh = load_u32_at_src0(block_byte_base + 34u + ib * 2u) & 0xFFFFu; + let dl = d * (2.0 * f32((qh >> 12u) & 7u) + 1.0); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); + + let qs_w = load_u32_at_src0(block_byte_base + 2u + ib * 4u); + let ig = (get_byte(qs_w, l) | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + + let gw = iq1_grid[(ig + j) / 16u]; + let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u; + let gs = bitcast(g << 30u) >> 30u; + + shmem[elem_idx] = f16(dl * (f32(gs) + delta)); + } +} +#endif // INIT_SRC0_SHMEM_IQ1_S + +#ifdef INIT_SRC0_SHMEM_IQ1_M +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 56u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + + let scales0 = load_u32_at_src0(block_byte_base + 48u); + let scales1 = load_u32_at_src0(block_byte_base + 52u); + let scale_packed = ((scales0 >> 12u) & 0xFu) | + ((scales0 >> 24u) & 0x00F0u) | + ((scales1 >> 4u) & 0x0F00u) | + ((scales1 >> 16u) & 0xF000u); + let d = f32(bitcast>(scale_packed).x); + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + let l = pos / 8u; + let j = pos % 8u; + + let scales = select(scales0, scales1, ib >= 4u); + let sw = (scales >> (16u * ((ib / 2u) % 2u))) & 0xFFFFu; + let s_pair = (sw >> (6u * (ib % 2u) + 3u * (l / 2u))) & 0x7u; + let dl = d * f32(2u * s_pair + 1u); + + let qh_word = load_u32_at_src0(block_byte_base + 32u + (ib / 2u) * 4u); + let qh = qh_word >> (16u * (ib % 2u)); + let qh_nib = (qh >> (4u * l)) & 0xFu; + + let qs_w = load_u32_at_src0(block_byte_base + ib * 4u); + let idx = get_byte(qs_w, l) | ((qh_nib & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_nib & 0x8u) != 0u); + + let ig = idx * 8u; + let gw = iq1_grid[(ig + j) / 16u]; + let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u; + let gs = bitcast(g << 30u) >> 30u; + + shmem[elem_idx] = f16(dl * (f32(gs) + delta)); + } +} +#endif // INIT_SRC0_SHMEM_IQ1_M + +#ifdef INIT_SRC0_SHMEM_IQ2_XXS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 66u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let entry_idx = k_in_block / 8u; + let j = k_in_block % 8u; + + let ib = entry_idx & ~3u; + let l = entry_idx & 3u; + + let aux0 = load_u32_at_src0(block_byte_base + 2u + ib * 2u); + let aux1 = load_u32_at_src0(block_byte_base + 2u + (ib + 2u) * 2u); + let db = d * (0.5 + f32(aux1 >> 28u)) * 0.25; + + let ig = get_byte(aux0, l) * 8u; + let is = (aux1 >> (7u * l)) & 127u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let g = get_byte(iq2xxs_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(db * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ2_XXS + +#ifdef INIT_SRC0_SHMEM_IQ2_XS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 74u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let entry_idx = k_in_block / 8u; + let j = k_in_block % 8u; + + let ib = entry_idx & ~3u; + let l = entry_idx & 3u; + + let scales_word = load_u32_at_src0(block_byte_base + 66u + (ib / 16u) * 4u); + let s = get_byte(scales_word, (ib % 16u) / 4u); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u); + let dl = d * (0.5 + f32(s_nib)) * 0.25; + + let qs_word = load_u32_at_src0(block_byte_base + 2u + (ib + l) * 2u); + let qs_val = qs_word & 0xFFFFu; + let ig = (qs_val & 511u) * 8u; + let is = qs_val >> 9u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let g = get_byte(iq2xs_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ2_XS + +#ifdef INIT_SRC0_SHMEM_IQ2_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 82u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 32u; + let l = (k_in_block % 32u) / 8u; + let j = k_in_block % 8u; + + let scales_word = load_u32_at_src0(block_byte_base + 74u + (ib / 4u) * 4u); + let s = get_byte(scales_word, ib % 4u); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u); + let dl = d * (0.5 + f32(s_nib)) * 0.25; + + let qs_word = load_u32_at_src0(block_byte_base + 2u + ib * 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 4u) * 4u); + let qh_b = (get_byte(qh_word, ib % 4u) << (8u - 2u * l)) & 0x300u; + let ig = (get_byte(qs_word, l) | qh_b) * 8u; + + let signs_word = load_u32_at_src0(block_byte_base + 34u + ib * 4u); + let signs = get_byte(signs_word, l); + + let g = get_byte(iq2s_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ2_S + +#ifdef INIT_SRC0_SHMEM_IQ3_XXS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 98u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib_pair = k_in_block / 32u; + let in_pair = k_in_block % 32u; + let l = in_pair / 8u; + let in_l = in_pair % 8u; + let k2 = in_l / 4u; + let j = in_l % 4u; + + let ib = ib_pair * 2u; + let sc_sign_off = block_byte_base + 2u + (ib + 32u) * 2u; + let sc_sign = load_u32_at_src0(sc_sign_off); + let db = d * (0.5 + f32(sc_sign >> 28u)) * 0.5; + let is = (sc_sign >> (7u * l)) & 127u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 2u + l) * 2u) & 0xFFFFu; + let ig_byte = get_byte(ig_word, k2); + let g = get_byte(iq3xxs_grid[ig_byte], j); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u); + + shmem[elem_idx] = f16(db * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ3_XXS + +#ifdef INIT_SRC0_SHMEM_IQ3_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 110u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 64u; + let rest = k_in_block % 64u; + let k = rest / 32u; + let in_k = rest % 32u; + let l = in_k / 8u; + let in_l = in_k % 8u; + let k2 = in_l / 4u; + let j = in_l % 4u; + + let scales_word = load_u32_at_src0(block_byte_base + 106u); + let s = get_byte(scales_word, ib); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, k != 0u); + let dl = d * (1.0 + 2.0 * f32(s_nib)); + + let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 2u) * 4u); + let qh_byte = get_byte(qh_word, (ib % 2u) * 2u + k); + + let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 8u + k * 4u + l) * 2u) & 0xFFFFu; + let ig_lo = get_byte(ig_word, 0u) | ((qh_byte << (8u - 2u * l)) & 256u); + let ig_hi = get_byte(ig_word, 1u) | ((qh_byte << (7u - 2u * l)) & 256u); + let ig = select(ig_lo, ig_hi, k2 != 0u); + + let signs_word = load_u32_at_src0(block_byte_base + 74u + (ib * 2u + k) * 4u); + let signs = get_byte(signs_word, l); + + let g = get_byte(iq3s_grid[ig], j); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ3_S From 582d2562a41f89388e5040253100780b3934c7c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 30 Apr 2026 13:04:50 +0200 Subject: [PATCH 527/831] CUDA: fix tile FA kernel on Pascal (llama/22541) --- ggml/src/ggml-cuda/fattn-tile.cuh | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 928b856f9d2..585f2c22853 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -68,7 +68,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) @@ -130,7 +130,7 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) - GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64) @@ -1124,7 +1124,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr size_t nbytes_shared = 0; #ifdef GGML_USE_HIP - if constexpr (DV <= 128) { + if constexpr (DKQ <= 128) { if (Q->ne[1] > 32/ncols2) { constexpr int cols_per_block = 64; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; @@ -1138,7 +1138,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm #endif // GGML_USE_HIP #ifndef GGML_USE_HIP - if constexpr (DV <= 256) + if constexpr (DKQ <= 256) #endif // GGML_USE_HIP { if (Q->ne[1] > 16/ncols2) { @@ -1220,11 +1220,22 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; - if constexpr (DKQ == 320) { // Mistral Small 4 + if constexpr (DKQ == 320) { + // This branch is only used for Mistral Small 4 which has a GQA ratio of 32. + // On AMD, simply use that GQA ratio with 32 columns / block since we always have enough SRAM. + // On NVIDIA however, the tile kernel is only used for GPUs that can't use the mma kernel (Pascal and older). + // Therefore, use a GQA ratio of 16 with 16 columns / block to stay below 48 kiB of SRAM / block. +#ifdef GGML_USE_HIP if (use_gqa_opt && gqa_ratio % 32 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; } +#else + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } +#endif // GGML_USE_HIP GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32"); } From 0c7c3ba570cb0b6f03da762d53ba211022cfb89a Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 30 Apr 2026 17:37:13 +0200 Subject: [PATCH 528/831] vulkan: add get/set tensor 2d functions (llama/22514) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * vulkan: add get/set_tensor_2d functions * fix backend interface comments * Update ggml/src/ggml-metal/ggml-metal.cpp Co-authored-by: Sigbjørn Skjæret --- ggml/src/ggml-backend-meta.cpp | 2 +- ggml/src/ggml-blas/ggml-blas.cpp | 4 +- ggml/src/ggml-cann/ggml-cann.cpp | 2 +- ggml/src/ggml-cpu/ggml-cpu.cpp | 2 +- ggml/src/ggml-cuda/ggml-cuda.cu | 4 +- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 2 +- ggml/src/ggml-metal/ggml-metal.cpp | 6 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 4 +- ggml/src/ggml-rpc/ggml-rpc.cpp | 4 +- ggml/src/ggml-sycl/ggml-sycl.cpp | 2 +- ggml/src/ggml-virtgpu/ggml-backend.cpp | 2 +- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 213 +++++++++++++++++++------ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 +- ggml/src/ggml-zdnn/ggml-zdnn.cpp | 2 +- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 2 +- 15 files changed, 181 insertions(+), 72 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index fbc02d6458a..c0ffd9a048b 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -2100,8 +2100,8 @@ static const ggml_backend_i ggml_backend_meta_i = { /* .free = */ ggml_backend_meta_free, /* .set_tensor_async = */ ggml_backend_meta_set_tensor_async, /* .get_tensor_async = */ ggml_backend_meta_get_tensor_async, - /* .get_tensor_2d_async = */ nullptr, /* .set_tensor_2d_async = */ nullptr, + /* .get_tensor_2d_async = */ nullptr, /* .cpy_tensor_async = */ nullptr, /* .synchronize = */ ggml_backend_meta_synchronize, /* .graph_plan_create = */ nullptr, diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index 05245b69807..b4c735267e0 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -262,9 +262,9 @@ static struct ggml_backend_i blas_backend_i = { /* .get_name = */ ggml_backend_blas_get_name, /* .free = */ ggml_backend_blas_free, /* .set_tensor_async = */ NULL, - /* .get_tensor_2d_async = */ NULL, - /* .set_tensor_2d_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 3618ba7f6f6..5f51ea3bb3c 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2746,8 +2746,8 @@ static const ggml_backend_i ggml_backend_cann_interface = { /* .free = */ ggml_backend_cann_free, /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async, /* .synchronize = */ ggml_backend_cann_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 49f840be207..128883b41ce 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -195,8 +195,8 @@ static const struct ggml_backend_i ggml_backend_cpu_i = { /* .free = */ ggml_backend_cpu_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0e6f74685d6..fbe0fa06242 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4588,8 +4588,8 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .free = */ ggml_backend_cuda_free, /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, - /* .get_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async, - /* .set_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async, + /* .set_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async, + /* .get_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async, /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async, /* .synchronize = */ ggml_backend_cuda_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 9345da62168..17ac083f4ea 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -3036,8 +3036,8 @@ static struct ggml_backend_i hexagon_backend_i = { /* .free = */ ggml_backend_hexagon_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ ggml_backend_hexagon_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 6a836e45908..cc329d67594 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -166,8 +166,8 @@ static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, - /* .get_tensor_2d_async = */ NULL, - /* .set_tensor_2d_async = */ NULL, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, /* .clear = */ ggml_backend_metal_buffer_private_clear, /* .reset = */ NULL, @@ -567,8 +567,8 @@ static ggml_backend_i ggml_backend_metal_i = { /* .free = */ ggml_backend_metal_free, /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async, /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups /* .synchronize = */ ggml_backend_metal_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 4d31591a4a6..11f72a5198a 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4343,9 +4343,9 @@ static ggml_backend_i ggml_backend_opencl_i = { /* .free = */ ggml_backend_opencl_free, /* .set_tensor_async = */ NULL, /* ggml_backend_opencl_set_tensor_async */ /* .get_tensor_async = */ NULL, /* ggml_backend_opencl_get_tensor_async */ - /* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */ - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */ /* .synchronize = */ ggml_backend_opencl_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 505bec73d37..7176d2feef9 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -740,9 +740,9 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .free = */ ggml_backend_rpc_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .cpy_tensor_async = */ NULL, /* .synchronize = */ ggml_backend_rpc_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 1eead625e76..f06147eeeb8 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4700,8 +4700,8 @@ static ggml_backend_i ggml_backend_sycl_interface = { /* .free = */ ggml_backend_sycl_free, /* .set_tensor_async = */ ggml_backend_sycl_set_tensor_async, /* .get_tensor_async = */ ggml_backend_sycl_get_tensor_async, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, // ggml_backend_sycl_cpy_tensor_async, // // TODO: update for the new // interface diff --git a/ggml/src/ggml-virtgpu/ggml-backend.cpp b/ggml/src/ggml-virtgpu/ggml-backend.cpp index 2b978556228..12756c9282f 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend.cpp @@ -34,8 +34,8 @@ static ggml_backend_i ggml_backend_remoting_interface = { /* .free = */ ggml_backend_remoting_free, /* .set_tensor_async = */ NULL, // ggml_backend_remoting_set_tensor_async, /* .get_tensor_async = */ NULL, // ggml_backend_remoting_get_tensor_async, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, // ggml_backend_remoting_cpy_tensor_async, /* .synchronize = */ NULL, // ggml_backend_remoting_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 69c24bb5877..10b73317943 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -6845,7 +6845,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont } } -static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) { +static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")"); // Check if src is pinned memory vk_buffer buf = nullptr; @@ -6855,7 +6855,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz if (buf != nullptr) { // Memory is pinned, use as staging buffer std::vector slices(1); - if (width == spitch) { + if (width == spitch && width == dpitch) { // Only do single write if stride is equal slices[0].srcOffset = buf_offset; slices[0].dstOffset = offset; @@ -6864,7 +6864,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz slices.resize(height); for (size_t i = 0; i < height; i++) { slices[i].srcOffset = buf_offset + i * spitch; - slices[i].dstOffset = offset + i * width; + slices[i].dstOffset = offset + i * dpitch; slices[i].size = width; } } @@ -6881,21 +6881,30 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz } // Staging buffer required - const size_t copy_size = width*height; - ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size); + const size_t staging_size = width * height; + ggml_vk_ensure_sync_staging_buffer(dst->device, staging_size); vk_buffer& staging_buffer = dst->device->sync_staging; - VkBufferCopy buf_copy = { - 0, - offset, - copy_size}; + std::vector slices(1); + if (width == dpitch) { + slices[0].srcOffset = 0; + slices[0].dstOffset = offset; + slices[0].size = staging_size; + } else { + slices.resize(height); + for (size_t i = 0; i < height; i++) { + slices[i].srcOffset = i * width; + slices[i].dstOffset = offset + i * dpitch; + slices[i].size = width; + } + } ggml_vk_sync_buffers(nullptr, subctx); - vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + subctx->s->buffer->buf.copyBuffer((VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, slices); if (width == spitch) { - deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys); + deferred_memcpy((uint8_t *)staging_buffer->ptr, src, staging_size, &subctx->in_memcpys); } else { for (size_t i = 0; i < height; i++) { deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys); @@ -6906,24 +6915,24 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")"); - return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging); + return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, size, 1, sync_staging); } -static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) { +static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height) { VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")"); // Buffer is already mapped if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); for (size_t i = 0; i < height; i++) { - memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); + memcpy((uint8_t *)dst->ptr + offset + i * dpitch, (const uint8_t *) src + i * spitch, width); } } else { std::lock_guard guard(dst->device->mutex); vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); - bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); + bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, dpitch, width, height, true); GGML_ASSERT(ret); ggml_vk_ctx_end(subctx); @@ -6944,7 +6953,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")"); - ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1); + ggml_vk_buffer_write_2d(dst, offset, src, size, size, size, 1); } static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { @@ -6990,15 +6999,35 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size } // Fall back to staging buffer - const size_t copy_size = dpitch * height; - ggml_vk_ensure_sync_staging_buffer(src->device, copy_size); + const size_t staging_size = width * height; + ggml_vk_ensure_sync_staging_buffer(src->device, staging_size); vk_buffer& staging_buffer = src->device->sync_staging; + std::vector staging_slices(1); + if (width == spitch) { + staging_slices[0].srcOffset = offset; + staging_slices[0].dstOffset = 0; + staging_slices[0].size = staging_size; + } else { + staging_slices.resize(height); + for (size_t i = 0; i < height; i++) { + staging_slices[i].srcOffset = offset + i * spitch; + staging_slices[i].dstOffset = i * width; + staging_slices[i].size = width; + } + } + ggml_vk_sync_buffers(nullptr, subctx); - subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, slices); + subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, staging_slices); - deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); + if (width == dpitch) { + deferred_memcpy(dst, staging_buffer->ptr, staging_size, &subctx->out_memcpys); + } else { + for (size_t i = 0; i < height; i++) { + deferred_memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) staging_buffer->ptr + i * width, width, &subctx->out_memcpys); + } + } return true; } @@ -7006,8 +7035,8 @@ static bool ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging); } -static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { - VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); +static void ggml_vk_buffer_read_2d(vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height) { + VK_LOG_DEBUG("ggml_vk_buffer_read_2d(" << src->buffer << ", " << offset << ", " << width << ", " << height << ")"); // If the device is not an UMA device the memory is host-accessible through rebar. While writing // through PCIe is sufficient fast reading back data from PCIe is slower than going through @@ -7015,18 +7044,20 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); - memcpy(dst, (uint8_t *) src->ptr + offset, size); + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) src->ptr + offset + i * spitch, width); + } } else { std::lock_guard guard(src->device->mutex); vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(src->device, subctx); - bool ret = ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); + bool ret = ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, spitch, dpitch, width, height, true); GGML_ASSERT(ret); ggml_vk_ctx_end(subctx); ggml_vk_submit(subctx, src->device->fence); - VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences"); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read_2d waitForFences"); src->device->device.resetFences({ src->device->fence }); ggml_vk_queue_command_pools_cleanup(src->device); @@ -7036,6 +7067,11 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ } } +static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); + ggml_vk_buffer_read_2d(src, offset, dst, size, size, size, 1); +} + static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")"); // Make sure both buffers are on same device @@ -7067,7 +7103,7 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr // Copy to src staging buffer ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size); // Copy to dst buffer - ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1); + ggml_vk_buffer_write(dst, dst_offset, src->device->sync_staging->ptr, size); } } @@ -13615,6 +13651,20 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } +static void ggml_backend_vk_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor_2d(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ", " << + n_copies << ", " << stride_tensor << ", " << stride_data << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + if (size == 0) { + return; + } + + ggml_vk_buffer_write_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_data, stride_tensor, size, n_copies); +} + static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; @@ -13628,6 +13678,21 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } +static void ggml_backend_vk_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor_2d(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ", " << + n_copies << ", " << stride_tensor << ", " << stride_data << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + + if (size == 0) { + return; + } + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_read_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_tensor, stride_data, size, n_copies); +} + static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { if (ggml_nbytes(src) == 0) { return true; @@ -13662,8 +13727,8 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, - /* .set_tensor_2d = */ NULL, - /* .get_tensor_2d = */ NULL, + /* .set_tensor_2d = */ ggml_backend_vk_buffer_set_tensor_2d, + /* .get_tensor_2d = */ ggml_backend_vk_buffer_get_tensor_2d, /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, /* .clear = */ ggml_backend_vk_buffer_clear, /* .reset = */ NULL, @@ -13819,8 +13884,9 @@ static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_b return &ctx->device->buffer_type; } -static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); +static void ggml_backend_vk_set_tensor_2d_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_set_tensor_2d_async(" << size << ", " << n_copies << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); @@ -13834,7 +13900,6 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor if (ctx->device->async_use_transfer_queue) { if (ctx->transfer_ctx.expired()) { - // Initialize new transfer context cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); ctx->transfer_ctx = cpy_ctx; ggml_vk_ctx_begin(ctx->device, cpy_ctx); @@ -13849,25 +13914,48 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_write_async(cpy_ctx, buf, dst_offset, data, size); + bool ret = ggml_vk_buffer_write_2d_async(cpy_ctx, buf, dst_offset, data, stride_data, stride_tensor, size, n_copies); if (!ret) { - ggml_vk_ensure_sync_staging_buffer(ctx, size); + const size_t staging_size = size * n_copies; + ggml_vk_ensure_sync_staging_buffer(ctx, staging_size); ggml_vk_sync_buffers(nullptr, cpy_ctx); - vk::BufferCopy buffer_cpy; - buffer_cpy.srcOffset = 0; - buffer_cpy.dstOffset = dst_offset; - buffer_cpy.size = size; + std::vector slices(1); + if (size == stride_tensor) { + slices[0].srcOffset = 0; + slices[0].dstOffset = dst_offset; + slices[0].size = staging_size; + } else { + slices.resize(n_copies); + for (size_t i = 0; i < n_copies; i++) { + slices[i].srcOffset = i * size; + slices[i].dstOffset = dst_offset + i * stride_tensor; + slices[i].size = size; + } + } - cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); - deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys); + cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, slices); + + if (size == stride_data) { + deferred_memcpy(ctx->sync_staging->ptr, data, staging_size, &cpy_ctx->in_memcpys); + } else { + for (size_t i = 0; i < n_copies; i++) { + deferred_memcpy((uint8_t *)ctx->sync_staging->ptr + i * size, (const uint8_t *)data + i * stride_data, size, &cpy_ctx->in_memcpys); + } + } ggml_vk_synchronize(ctx); } } -static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); +static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); + ggml_backend_vk_set_tensor_2d_async(backend, tensor, data, offset, size, 1, size, size); +} + +static void ggml_backend_vk_get_tensor_2d_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_get_tensor_2d_async(" << size << ", " << n_copies << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); @@ -13882,24 +13970,45 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ vk_buffer buf = buf_ctx->dev_buffer; auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size); + bool ret = ggml_vk_buffer_read_2d_async(compute_ctx, buf, src_offset, data, stride_tensor, stride_data, size, n_copies); - // If that failed, copy synchronously through a staging buffer if (!ret) { - ggml_vk_ensure_sync_staging_buffer(ctx, size); + const size_t staging_size = size * n_copies; + ggml_vk_ensure_sync_staging_buffer(ctx, staging_size); ggml_vk_sync_buffers(nullptr, compute_ctx); - vk::BufferCopy buffer_cpy; - buffer_cpy.srcOffset = src_offset; - buffer_cpy.dstOffset = 0; - buffer_cpy.size = size; + std::vector slices(1); + if (size == stride_tensor) { + slices[0].srcOffset = src_offset; + slices[0].dstOffset = 0; + slices[0].size = staging_size; + } else { + slices.resize(n_copies); + for (size_t i = 0; i < n_copies; i++) { + slices[i].srcOffset = src_offset + i * stride_tensor; + slices[i].dstOffset = i * size; + slices[i].size = size; + } + } + + compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, slices); - compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy }); - deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys); + if (size == stride_data) { + deferred_memcpy(data, ctx->sync_staging->ptr, staging_size, &compute_ctx->out_memcpys); + } else { + for (size_t i = 0; i < n_copies; i++) { + deferred_memcpy((uint8_t *)data + i * stride_data, (const uint8_t *)ctx->sync_staging->ptr + i * size, size, &compute_ctx->out_memcpys); + } + } ggml_vk_synchronize(ctx); } } +static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); + ggml_backend_vk_get_tensor_2d_async(backend, tensor, data, offset, size, 1, size, size); +} + static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async(" << src << " -> " << dst << ", size=" << ggml_nbytes(src) << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context; @@ -15123,8 +15232,8 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .free = */ ggml_backend_vk_free, /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async, /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async, - /* .get_tensor_2d_async = */ NULL, - /* .set_tensor_2d_async = */ NULL, + /* .set_tensor_2d_async = */ ggml_backend_vk_set_tensor_2d_async, + /* .get_tensor_2d_async = */ ggml_backend_vk_get_tensor_2d_async, /* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async, /* .synchronize = */ ggml_backend_vk_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5e55a2a1e1b..a1dccfc0f5a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3107,8 +3107,8 @@ static ggml_backend_i ggml_backend_webgpu_i = { /* .free = */ ggml_backend_webgpu_free, /* .set_tensor_async = */ ggml_backend_webgpu_set_tensor_async, /* .get_tensor_async = */ NULL, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ ggml_backend_webgpu_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index e6b6fc24fd7..639b818d128 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -423,8 +423,8 @@ static ggml_backend_i ggml_backend_zdnn_i = { /* .free = */ ggml_backend_zdnn_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index fc1df4dbef4..2b82c7c1dbb 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -407,8 +407,8 @@ static struct ggml_backend_i ggml_backend_zendnn_i = { /* .free = */ ggml_backend_zendnn_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, - /* .get_tensor_2d_async = */ NULL, /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, From b34a9f3d83e8443835ad42778885d3b5ec8b825a Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Fri, 1 May 2026 06:19:10 +0900 Subject: [PATCH 529/831] ggml-webgpu: Improve performance of mat-vec and mat-mat for MUL_MAT_ID (llama/22464) * Add mat-vec fast path of MUL_MAT_ID. * Add shared accumulation vec logic and the other types supports. * Add i-quant mat-mat for MUL_MAT_ID and fix some parts * Remove n_experts from shader_lib_context. --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 173 +- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 73 +- .../wgsl-shaders/mul_mat_id_vec.wgsl | 154 ++ .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 1284 +-------------- .../wgsl-shaders/mul_mat_vec_acc.tmpl | 1391 +++++++++++++++++ 5 files changed, 1780 insertions(+), 1295 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 5239164cd00..0f66275c6a3 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -664,7 +664,7 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ } const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; + size_t bytes_per_kv = 0; if (!key.kv_direct) { bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); } @@ -701,10 +701,10 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && - (context.src2->type == K->type); + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && + (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + (context.src2->type == K->type); const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned && (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && @@ -862,9 +862,12 @@ struct ggml_webgpu_mul_mat_shader_decisions { struct ggml_webgpu_mul_mat_id_pipeline_key { ggml_type src0_type; ggml_type src1_type; + uint32_t n_experts; + int vectorized; bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const { - return src0_type == other.src0_type && src1_type == other.src1_type; + return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts && + vectorized == other.vectorized; } }; @@ -873,6 +876,8 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.n_experts); + ggml_webgpu_hash_combine(seed, key.vectorized); return seed; } }; @@ -1023,6 +1028,8 @@ class ggml_webgpu_shader_lib { std::unordered_map mul_mat_id_gather_pipelines; // key is fixed std::unordered_map mul_mat_id_pipelines; // src0_type/src1_type + std::unordered_map + mul_mat_id_vec_pipelines; // src0_type/src1_type std::unordered_map set_rows_pipelines; @@ -1516,7 +1523,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { @@ -1633,10 +1640,10 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_vec_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { @@ -2012,6 +2019,11 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_id_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; + key.n_experts = context.src0->ne[2]; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; auto it = mul_mat_id_pipelines.find(key); if (it != mul_mat_id_pipelines.end()) { @@ -2041,14 +2053,12 @@ class ggml_webgpu_shader_lib { switch (context.src0->type) { case GGML_TYPE_F32: defines.push_back("SRC0_INNER_TYPE=f32"); - defines.push_back("FLOAT"); defines.push_back("INIT_SRC0_SHMEM_FLOAT"); defines.push_back("INIT_SRC1_SHMEM_FLOAT"); variant += "_f32"; break; case GGML_TYPE_F16: defines.push_back("SRC0_INNER_TYPE=f16"); - defines.push_back("FLOAT"); defines.push_back("INIT_SRC0_SHMEM_FLOAT"); defines.push_back("INIT_SRC1_SHMEM_FLOAT"); variant += "_f16"; @@ -2064,12 +2074,32 @@ class ggml_webgpu_shader_lib { defines.push_back("U32_DEQUANT_HELPERS"); defines.push_back("SRC0_INNER_TYPE=u32"); + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + default: + break; + } + variant += std::string("_") + src0_name; break; } } - defines.push_back("SCALAR"); + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); // mul_mat_id is register-tile only. const uint32_t tile_k = @@ -2102,6 +2132,123 @@ class ggml_webgpu_shader_lib { return mul_mat_id_pipelines[key]; } + webgpu_pipeline get_mul_mat_id_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_id_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.n_experts = context.src0->ne[2]; + key.vectorized = (context.src0->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + + auto it = mul_mat_id_vec_pipelines.find(key); + if (it != mul_mat_id_vec_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "mul_mat_id_vec"; + const char * shader_src = wgsl_mul_mat_id_vec; + + // src1 type + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat fast shader"); + } + + // src0 type + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("MUL_ACC_FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("MUL_ACC_FLOAT"); + variant += "_f16"; + break; + default: + { + // Quantized types: use helpers but accumulate in f16 + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + std::string src0_name = src0_traits->type_name; + std::string type_upper = src0_name; + variant += "_" + src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + default: + break; + } + break; + } + } + + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + + uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; + uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; + + if (key.src0_type == GGML_TYPE_Q1_0) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q2_K) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q4_0) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } + + // variant suffix for src1 type + variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16"); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; + if (key.vectorized) { + variant += "_vectorized"; + } + + defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts)); + + auto processed = preprocessor.preprocess(shader_src, defines); + + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + decisions->outputs_per_wg = outputs_per_wg; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_id_vec_pipelines[key] = pipeline; + return mul_mat_id_vec_pipelines[key]; + } + webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool is_unary = context.dst->op == GGML_OP_UNARY; const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index a1dccfc0f5a..f102c7a818b 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1404,7 +1404,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: - case GGML_TYPE_Q8_1: case GGML_TYPE_Q6_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: @@ -1527,11 +1526,74 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } +static webgpu_encoded_op ggml_webgpu_mul_mat_id_vec(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + const uint32_t param_n_expert = (uint32_t) src0->ne[2]; + const uint32_t param_n_expert_used = (uint32_t) dst->ne[1]; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_mul_mat_id_vec_pipeline(shader_lib_ctx); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + param_n_expert, + param_n_expert_used, + (uint32_t) src1->ne[1], + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + }; + + std::vector entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(src2), ggml_webgpu_tensor_align_offset(ctx, src2), + ggml_webgpu_tensor_binding_size(ctx, src2)), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + }; + + uint32_t wg_x = 1; + uint32_t wg_y = 1; + + auto * decisions = static_cast(pipeline.context.get()); + + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); + uint32_t total_wg = output_groups * param_n_expert_used; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * src2, ggml_tensor * dst) { + // we can use mat-vec fast path + if (dst->ne[2] == 1) { + return ggml_webgpu_mul_mat_id_vec(ctx, src0, src1, src2, dst); + } + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src0; shader_lib_ctx.src1 = src1; @@ -3879,6 +3941,15 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: supports_op = true; break; default: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl new file mode 100644 index 00000000000..6ff9bcf2df0 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl @@ -0,0 +1,154 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif +enable f16; + +#define DECLARE_BYTE_LOADERS_SRC0 +#include "common_decls.tmpl" + +#include "mul_mat_vec_acc.tmpl" + +struct MulMatIdVecParams { + offset_src0: u32, + offset_src1: u32, + offset_ids: u32, + offset_dst: u32, + + k: u32, + m: u32, + n_expert: u32, + n_expert_used: u32, + b_ne1: u32, + + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, +}; + +@group(0) @binding(0) var src0: array; // [cols, rows, n_expert] +@group(0) @binding(1) var src1: array; // [cols, b_ne1, n_tokens(1)] +@group(0) @binding(2) var ids: array; // [n_experd_used, n_tokens(1)] +@group(0) @binding(3) var dst: array; // [rows, n_expert_used, n_tokens(1)] + +// "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01 +@group(0) @binding(4) var params: MulMatIdVecParams; + +// Flattened as [row][thread] to keep each row's reduction contiguous in memory. +var partial_sums: array; + +fn partial_index(row: u32, thread: u32) -> u32 { + return row * WG_SIZE + thread; +} + +var gathered_count_ids: array; +var gathered_expert_used: array; + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3 +#ifdef USE_SUBGROUP_REDUCTION + , @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32 +#endif +) { + + let thread_id = local_id.x; + + for (var i = thread_id;i < params.n_expert;i += WG_SIZE) { + gathered_count_ids[i] = 0; + } + + workgroupBarrier(); + + // gather the selected experts for the target token. + for (var col = thread_id;col < params.n_expert_used;col += WG_SIZE) { + let expert = ids[params.offset_ids + col]; + gathered_count_ids[expert] = 1; + gathered_expert_used[expert] = col; + } + + workgroupBarrier(); + + let output_groups:u32 = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + + var own_expert:u32 = 0; + var wg_in_batch:u32 = 0; + var wg_sum:u32 = 0; + + for (var i = 0u;i < params.n_expert;i += 1) { + let wg_vec_count = gathered_count_ids[i]; // 1 or 0 + let wg_per_matrix = output_groups * wg_vec_count; + if (wg_sum <= wg_linear && wg_linear < wg_sum + wg_per_matrix) { + own_expert = i; + wg_in_batch = wg_linear - wg_sum; + break; + } + wg_sum += wg_per_matrix; + } + + let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG; + let dst1_stride = params.m; + + let src0_batch_offset = params.offset_src0 + own_expert * params.stride_02; + let src1_idx_base = params.offset_src1 + (gathered_expert_used[own_expert] % params.b_ne1) * params.stride_11; + let dst_idx_base = params.offset_dst + gathered_expert_used[own_expert] * dst1_stride + row_base; + + let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); + +#ifdef USE_SUBGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; + } + } + + workgroupBarrier(); + + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; + } + let row_total = subgroupAdd(row_acc); + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + row] = row_total; + } + } +#endif + +#ifdef USE_WORKGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[row]; + } + + workgroupBarrier(); + + var stride:u32 = WG_SIZE / 2u; + + while (stride > 0) { + if (thread_id < stride) { + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } + } + + workgroupBarrier(); + stride = stride / 2; + } + + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; + } + } +#endif +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index a8000439bfb..a194cf40468 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -6,38 +6,7 @@ enable f16; #define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" -#ifdef U32_DEQUANT_HELPERS -#define SRC0_TYPE u32 - -fn byte_of(v: u32, b: u32) -> u32 { - return (v >> (b * 8u)) & 0xFFu; -} - -fn sbyte_of(v: u32, b: u32) -> i32 { - let raw = i32((v >> (b * 8u)) & 0xFFu); - return select(raw, raw - 256, raw >= 128); -} -#endif - -#ifdef VEC -#define VEC_SIZE 4u -#define SRC0_TYPE vec4 -#define SRC1_TYPE vec4 - -fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { - return f32(dot(SRC1_TYPE(src0_val), src1_val)); -} -#endif - -#ifdef SCALAR -#define VEC_SIZE 1u -#define SRC0_TYPE SRC0_INNER_TYPE -#define SRC1_TYPE SRC1_INNER_TYPE - -fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { - return f32(src0_val) * f32(src1_val); -} -#endif +#include "mul_mat_vec_acc.tmpl" struct MulMatParams { offset_src0: u32, @@ -62,6 +31,7 @@ struct MulMatParams { @group(0) @binding(1) var src1: array; @group(0) @binding(2) var dst: array; +// "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01 @group(0) @binding(3) var params: MulMatParams; // Flattened as [row][thread] to keep each row's reduction contiguous in memory. @@ -108,1255 +78,7 @@ fn main( let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; - var acc: array; - -#ifdef MUL_ACC_FLOAT - let k_vec = params.k / VEC_SIZE; - let src1_idx_base_vec = src1_idx_base / VEC_SIZE; - - // Each thread walks K, loads from the vector, and updates - // a small block of output rows held in registers. - for (var k = thread_id; k < k_vec; k += WG_SIZE) { - let x = src1[src1_idx_base_vec + k]; - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; - acc[row] += inner_dot(src0[src0_idx], x); - } - } - } -#endif - -#ifdef MUL_ACC_Q1_0 -#define BLOCK_SIZE 128 -#define BLOCK_SIZE_BYTES 18 -#define THREADS_PER_BLOCK 16 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu; - var row_sum = 0.0; - for (var bit = 0u; bit < 8u; bit++) { - let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); - row_sum += w * x_block[bit]; - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_Q4_0 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 18 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % 4; - for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; - let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_Q4_1 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 20 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(q_byte & 0xFu) * d + m; - let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_Q5_0 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 22 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qh_packed = load_u32_at_src0(block_byte_base + 2u); - let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); - let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; - let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_Q5_1 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 24 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = f32(load_f16_at_src0(block_byte_base + 2u)); - let qh_packed = load_u32_at_src0(block_byte_base + 4u); - let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); - let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; - let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_Q8_0 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 34 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - - for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_Q8_1 -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 36 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - - for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_Q2_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 84 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let lane = tid / 2u; - let phase = tid % 2u; - let iq = lane / 4u; - let ir = lane % 4u; - let is = ir / 2u; - - let y_offset = 128u * iq + 8u * ir + 4u * phase; - let sc0_byte = 8u * iq + is; - let sc2_byte = 8u * iq + is + 2u; - let sc4_byte = 8u * iq + is + 4u; - let sc6_byte = 8u * iq + is + 6u; - let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 64u + i]); - x_block[i + 12u] = f32(src1[x_base + 96u + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let dall = f32(load_f16_at_src0(block_byte_base + 80u)); - let dmin = f32(load_f16_at_src0(block_byte_base + 82u)) * (1.0 / 16.0); - - let sc0 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc0_byte), sc0_byte & 3u); - let sc2 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc2_byte), sc2_byte & 3u); - let sc4 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc4_byte), sc4_byte & 3u); - let sc6 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc6_byte), sc6_byte & 3u); - - let q_u32 = load_u32_at_src0_aligned(block_byte_base + qs_byte); - let qs0 = q_u32 & 0xFFFFu; - let qs1 = q_u32 >> 16u; - - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - var acc1 = vec4(0.0, 0.0, 0.0, 0.0); - var acc2 = vec4(0.0, 0.0, 0.0, 0.0); - - sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; - sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; - sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; - sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; - - acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); - acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); - acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); - acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); - acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); - acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); - acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); - acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); - - acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + - (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + - (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + - (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) - - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + - sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); - } - } - } -#endif - - -#ifdef MUL_ACC_Q3_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 110 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let lane = tid / 2u; - let phase = tid % 2u; - let ip = lane / 4u; - let il = 2u * ((lane % 4u) / 2u); - let ir = lane % 2u; - let l0 = 8u * ir; - - let q_byte = 32u + 32u * ip + l0 + 16u * phase; - let h_byte = l0 + 16u * phase; - let y_offset = 128u * ip + 32u * il + l0 + 16u * phase; - - let s_shift1 = 4u * ip; - let s_shift2 = s_shift1 + il; - - let v1 = select(64.0, 4.0, il == 0u); - let v2 = 4.0 * v1; - let shift = 2u * il; - - var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32; - if (il == 0u) { - qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u; - } else { - qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u; - } - - let mm_idx = 2u * ip + il / 2u; - var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32; - switch (mm_idx) { - case 0u: { hm0=0x0001u; hm1=0x0100u; hm2=0x0002u; hm3=0x0200u; } - case 1u: { hm0=0x0004u; hm1=0x0400u; hm2=0x0008u; hm3=0x0800u; } - case 2u: { hm0=0x0010u; hm1=0x1000u; hm2=0x0020u; hm3=0x2000u; } - default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; } - } - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 8u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 8u] = f32(src1[x_base + 32u + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base + 108u)); - let a_base = 96u; - let a_il0 = load_u16_at_src0(block_byte_base + a_base + il * 2u); - let a_il1 = load_u16_at_src0(block_byte_base + a_base + (il + 1u) * 2u); - let a_4 = load_u16_at_src0(block_byte_base + a_base + 8u); - let a_5 = load_u16_at_src0(block_byte_base + a_base + 10u); - - var scales32 = a_4 | (a_5 << 16u); - let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u; - scales32 = a_il0 | (a_il1 << 16u); - scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32; - - let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32); - let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32); - - let q_u32_0 = load_u32_at_src0(block_byte_base + q_byte + 0u); - let q_u32_1 = load_u32_at_src0(block_byte_base + q_byte + 4u); - let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u); - let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); - - var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; - var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; - - for (var l = 0u; l < 8u; l += 2u) { - let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); - let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); - let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); - let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); - - s1 += x_block[l + 0u] * f32(qs & qm0); - s2 += x_block[l + 1u] * f32(qs & qm1); - s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + - select(0.0, x_block[l + 1u], (hv & hm1) == 0u); - s4 += x_block[l + 8u] * f32(qs & qm2); - s5 += x_block[l + 9u] * f32(qs & qm3); - s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + - select(0.0, x_block[l + 9u], (hv & hm3) == 0u); - } - - let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); - let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); - acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); - } - } - } -#endif - -#ifdef MUL_ACC_Q4_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 144 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let il = tid / 4u; - let ir = tid % 4u; - let im = il / 2u; - let in = il % 2u; - let l0 = 4u * (2u * ir + in); - - let y_offset = 64u * im + l0; - let q_offset = 32u * im + l0; - let sc0_byte = 4u + im * 2u; - let sc2_byte = 4u + (im + 2u) * 2u; - let sc4_byte = 4u + (im + 4u) * 2u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base + 0u)); - let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); - - let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); - let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); - let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); - let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); - let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); - let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); - - let sc16_0 = sc0 & 0x3F3Fu; - let sc16_1 = sc2 & 0x3F3Fu; - let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); - let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); - - let scale0 = f32(sc16_0 & 0xFFu); - let scale1 = f32((sc16_0 >> 8u) & 0xFFu); - let min0 = f32(sc16_1 & 0xFFu); - let min1 = f32((sc16_1 >> 8u) & 0xFFu); - let scale2 = f32(sc16_2 & 0xFFu); - let scale3 = f32((sc16_2 >> 8u) & 0xFFu); - let min2 = f32(sc16_3 & 0xFFu); - let min3 = f32((sc16_3 >> 8u) & 0xFFu); - - let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); - let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); - - var dot = vec4(0.0, 0.0, 0.0, 0.0); - var sumx = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - dot[0] += x_block[i] * f32(q1b & 0x0Fu); - dot[1] += x_block[i + 4u] * f32(q1b >> 4u); - dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); - dot[3] += x_block[i + 12u] * f32(q2b >> 4u); - sumx[0] += x_block[i]; - sumx[1] += x_block[i + 4u]; - sumx[2] += x_block[i + 8u]; - sumx[3] += x_block[i + 12u]; - } - - acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) - - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); - } - } - } -#endif - -#ifdef MUL_ACC_Q5_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 176 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let il = tid / 4u; - let ir = tid % 4u; - let im = il / 2u; - let in = il % 2u; - let l0 = 4u * (2u * ir + in); - - let y_offset = 64u * im + l0; - let q_offset = 48u + 32u * im + l0; - let qh_offset = 16u + 8u * ir + 4u * in; - let sc0_byte = 4u + im * 2u; - let sc2_byte = 4u + (im + 2u) * 2u; - let sc4_byte = 4u + (im + 4u) * 2u; - - let hm1 = 1u << (2u * im); - let hm2 = hm1 << 1u; - let hm3 = hm1 << 4u; - let hm4 = hm2 << 4u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base + 0u)); - let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); - - let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); - let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); - let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); - let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); - let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); - let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); - - let sc16_0 = sc0 & 0x3F3Fu; - let sc16_1 = sc2 & 0x3F3Fu; - let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); - let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); - - let f0 = f32(sc16_0 & 0xFFu); - let f1 = f32((sc16_0 >> 8u) & 0xFFu); - let m0 = f32(sc16_1 & 0xFFu); - let m1 = f32((sc16_1 >> 8u) & 0xFFu); - let f4 = f32(sc16_2 & 0xFFu); - let f5 = f32((sc16_2 >> 8u) & 0xFFu); - let m4 = f32(sc16_3 & 0xFFu); - let m5 = f32((sc16_3 >> 8u) & 0xFFu); - - let q1_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset); - let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); - let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); - - var vals = vec4(0.0, 0.0, 0.0, 0.0); - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - let qhb = byte_of(qh_u32, i); - - let yl0 = x_block[i]; - let yl8 = x_block[i + 4u]; - let yh0 = x_block[i + 8u]; - let yh8 = x_block[i + 12u]; - - sumy[0] += yl0; - sumy[1] += yl8; - sumy[2] += yh0; - sumy[3] += yh8; - - let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); - let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); - let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); - let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); - - vals[0] += yl0 * q0; - vals[1] += yl8 * q1; - vals[2] += yh0 * q2; - vals[3] += yh8 * q3; - } - - acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) - - dmin * (sumy[0] * m0 + sumy[1] * m1 + - sumy[2] * m4 + sumy[3] * m5); - } - } - } -#endif - -#ifdef MUL_ACC_Q6_K -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 210 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let ip = tid / 8u; - let il = tid % 8u; - let l0 = 4u * il; - let is = 8u * ip + l0 / 16u; - - let y_offset = 128u * ip + l0; - let q_offset_l = 64u * ip + l0; - let q_offset_h = 32u * ip + l0; - - let num_blocks = params.k / BLOCK_SIZE; - let sc_base_byte = 192u + (is & ~3u); - let sc_byte_pos = is & 3u; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var l = 0u; l < 4u; l++) { - x_block[l] = f32(src1[x_base + l]); - x_block[l + 4u] = f32(src1[x_base + 32u + l]); - x_block[l + 8u] = f32(src1[x_base + 64u + l]); - x_block[l + 12u] = f32(src1[x_base + 96u + l]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base + 208u)); - let ql1_u32 = load_u32_at_src0(block_byte_base + q_offset_l); - let ql2_u32 = load_u32_at_src0(block_byte_base + q_offset_l + 32u); - let qh_u32 = load_u32_at_src0(block_byte_base + 128u + q_offset_h); - let sc_u32_0 = load_u32_at_src0(block_byte_base + sc_base_byte); - let sc_u32_1 = load_u32_at_src0(block_byte_base + sc_base_byte + 4u); - - let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); - let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); - let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); - let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); - - var sums = vec4(0.0, 0.0, 0.0, 0.0); - - for (var l = 0u; l < 4u; l++) { - let q1b = byte_of(ql1_u32, l); - let q2b = byte_of(ql2_u32, l); - let qhb = byte_of(qh_u32, l); - - let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); - let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); - let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); - let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); - - sums[0] += x_block[l] * dq0; - sums[1] += x_block[l + 4u] * dq1; - sums[2] += x_block[l + 8u] * dq2; - sums[3] += x_block[l + 12u] * dq3; - } - - acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + - sums[2] * f32(sc4) + sums[3] * f32(sc6)); - } - } - } -#endif - -#ifdef MUL_ACC_IQ1_S -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 50 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let d = f32(load_f16_at_src0(block_byte_base)); - let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu; - let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u); - let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); - let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ1_M -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 56 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - - let sc_lo = load_u32_at_src0(block_byte_base + 48u); - let sc_hi = load_u32_at_src0(block_byte_base + 52u); - let sc0 = sc_lo & 0xFFFFu; - let sc1 = (sc_lo >> 16u) & 0xFFFFu; - let sc2 = sc_hi & 0xFFFFu; - let sc3 = (sc_hi >> 16u) & 0xFFFFu; - let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u); - let d = f32(bitcast>(d_bits)[0]); - - let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u), - select(sc0, sc1, sub_blk >= 2u), - sub_blk < 4u); - - let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u); - let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu; - let qh_lo = qh & 0xFFu; - let qh_hi = (qh >> 8u) & 0xFFu; - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); - let sub_scale = (sc_u16 >> bit_off) & 0x7u; - let dl = d * f32(2u * sub_scale + 1u); - let qh_byte = select(qh_lo, qh_hi, l >= 2u); - let ll2 = l % 2u; - let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); - let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); - let ig = grid_idx * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ2_XXS -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 66 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let aux_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); - let aux_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); - let ls = aux_hi >> 28u; - let db = d * (0.5 + f32(ls)) * 0.25; - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; - let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xxs_grid[grid_idx * 2u]; - let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ2_XS -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 74 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); - let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); - let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); - let scales_byte = get_byte(scales_word, sub_blk % 4u); - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let half2 = (l % 2u) * 16u; - let qs_val = (qs_word >> half2) & 0xFFFFu; - let grid_idx = qs_val & 0x1FFu; - let signs_idx = (qs_val >> 9u) & 0x7Fu; - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xs_grid[grid_idx * 2u]; - let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ2_S -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 82 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); - let sg_w = load_u32_at_src0(block_byte_base + 34u + sub_blk * 4u); - let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); - let qh_byte = get_byte(qh_word, sub_blk % 4u); - let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u); - let scales_byte = get_byte(sc_word, sub_blk % 4u); - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let sign_byte = get_byte(sg_w, l); - let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let gw_lo = iq2s_grid[grid_idx * 2u]; - let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ3_XXS -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 98 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); - let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); - let aux = load_u32_at_src0(block_byte_base + 66u + sub_blk * 4u); - let ls = aux >> 28u; - let db = d * (0.5 + f32(ls)) * 0.5; - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let signs_idx = (aux >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let grid1 = iq3xxs_grid[grid_idx_0]; - let grid2 = iq3xxs_grid[grid_idx_1]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ3_S -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 110 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let slot0 = half * 2u; - let y_offset = sub_blk * 32u + slot0 * 8u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); - let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); - let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); - let qh_byte = get_byte(qh_word, sub_blk % 4u); - let sg_w = load_u32_at_src0(block_byte_base + 74u + sub_blk * 4u); - let sc_word = load_u32_at_src0(block_byte_base + 106u); - let scales_byte = get_byte(sc_word, sub_blk / 2u); - let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu; - let db = d * (1.0 + 2.0 * f32(sub_scale)); - - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); - let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); - let sign_byte = get_byte(sg_w, l); - let grid1 = iq3s_grid[grid_idx_1]; - let grid2 = iq3s_grid[grid_idx_2]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; - } - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ4_NL -#define BLOCK_SIZE 32 -#define BLOCK_SIZE_BYTES 18 -#define THREADS_PER_BLOCK 4 -#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) - - let num_blocks = params.k / BLOCK_SIZE; - let thread_within_block = thread_id % THREADS_PER_BLOCK; - for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + i + 16u]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; - let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; - } - acc[row] += row_sum; - } - } - } -#endif - -#ifdef MUL_ACC_IQ4_XS -#define BLOCK_SIZE 256 -#define BLOCK_SIZE_BYTES 136 -#define THREADS_PER_BLOCK 16 - - let tid = thread_id % THREADS_PER_BLOCK; - let block_group = thread_id / THREADS_PER_BLOCK; - let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; - - let sub_blk = tid / 2u; - let half = tid % 2u; - let y_offset = sub_blk * 32u + half * 16u; - - let num_blocks = params.k / BLOCK_SIZE; - - for (var block = block_group; block < num_blocks; block += num_block_groups) { - let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); - } - - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let output_row = row_base + row; - if (output_row < params.m) { - let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - let d = f32(load_f16_at_src0(block_byte_base)); - let scales_h = load_u16_at_src0(block_byte_base + 2u); - let scales_l_word = load_u32_at_src0(block_byte_base + 4u); - let sl_byte = get_byte(scales_l_word, sub_blk / 2u); - let sl = (sl_byte >> (4u * (sub_blk % 2u))) & 0xFu; - let sh_bits = (scales_h >> (2u * sub_blk)) & 3u; - let ls = i32(sl | (sh_bits << 4u)); - let dl = d * f32(ls - 32); - - let qs_byte_off = 8u + sub_blk * 16u; - let q_w0 = load_u32_at_src0(block_byte_base + qs_byte_off); - let q_w1 = load_u32_at_src0(block_byte_base + qs_byte_off + 4u); - let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u); - let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u); - - var row_sum = 0.0; - for (var i = 0u; i < 16u; i++) { - let q_word = select( - select(q_w0, q_w1, i >= 4u), - select(q_w2, q_w3, i >= 12u), - i >= 8u); - let q_byte = get_byte(q_word, i % 4u); - let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); - row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i]; - } - acc[row] += row_sum; - } - } - } -#endif + let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl new file mode 100644 index 00000000000..1f59bd14863 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -0,0 +1,1391 @@ +#ifdef U32_DEQUANT_HELPERS +#define SRC0_TYPE u32 + +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} +#endif + +#ifdef VEC +#define VEC_SIZE 4u +#define SRC0_TYPE vec4 +#define SRC1_TYPE vec4 + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(dot(SRC1_TYPE(src0_val), src1_val)); +} +#endif + +#ifdef SCALAR +#define VEC_SIZE 1u +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(src0_val) * f32(src1_val); +} +#endif + +#ifdef MUL_ACC_FLOAT +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let k_vec = params.k / VEC_SIZE; + let src1_idx_base_vec = src1_idx_base / VEC_SIZE; + + // Each thread walks K, loads from the vector, and updates + // a small block of output rows held in registers. + for (var k = thread_id; k < k_vec; k += WG_SIZE) { + let x = src1[src1_idx_base_vec + k]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; + acc[row] += inner_dot(src0[src0_idx], x); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q1_0 +#define BLOCK_SIZE 128 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 16 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu; + var row_sum = 0.0; + for (var bit = 0u; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + row_sum += w * x_block[bit]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q4_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % 4; + for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; + let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q4_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 20 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(q_byte & 0xFu) * d + m; + let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q5_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 22 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qh_packed = load_u32_at_src0(block_byte_base + 2u); + let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; + let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q5_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 24 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + let qh_packed = load_u32_at_src0(block_byte_base + 4u); + let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; + let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q8_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 34 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q8_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 36 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q2_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 84 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let lane = tid / 2u; + let phase = tid % 2u; + let iq = lane / 4u; + let ir = lane % 4u; + let is = ir / 2u; + + let y_offset = 128u * iq + 8u * ir + 4u * phase; + let sc0_byte = 8u * iq + is; + let sc2_byte = 8u * iq + is + 2u; + let sc4_byte = 8u * iq + is + 4u; + let sc6_byte = 8u * iq + is + 6u; + let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 64u + i]); + x_block[i + 12u] = f32(src1[x_base + 96u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let dall = f32(load_f16_at_src0(block_byte_base + 80u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 82u)) * (1.0 / 16.0); + + let sc0 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc0_byte), sc0_byte & 3u); + let sc2 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc2_byte), sc2_byte & 3u); + let sc4 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc4_byte), sc4_byte & 3u); + let sc6 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc6_byte), sc6_byte & 3u); + + let q_u32 = load_u32_at_src0_aligned(block_byte_base + qs_byte); + let qs0 = q_u32 & 0xFFFFu; + let qs1 = q_u32 >> 16u; + + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + + sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; + sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; + sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; + sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; + + acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); + acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); + acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); + acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); + acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); + acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); + acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); + acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); + + acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } + } + } + + return acc; +} +#endif + + +#ifdef MUL_ACC_Q3_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let lane = tid / 2u; + let phase = tid % 2u; + let ip = lane / 4u; + let il = 2u * ((lane % 4u) / 2u); + let ir = lane % 2u; + let l0 = 8u * ir; + + let q_byte = 32u + 32u * ip + l0 + 16u * phase; + let h_byte = l0 + 16u * phase; + let y_offset = 128u * ip + 32u * il + l0 + 16u * phase; + + let s_shift1 = 4u * ip; + let s_shift2 = s_shift1 + il; + + let v1 = select(64.0, 4.0, il == 0u); + let v2 = 4.0 * v1; + let shift = 2u * il; + + var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32; + if (il == 0u) { + qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u; + } else { + qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u; + } + + let mm_idx = 2u * ip + il / 2u; + var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32; + switch (mm_idx) { + case 0u: { hm0=0x0001u; hm1=0x0100u; hm2=0x0002u; hm3=0x0200u; } + case 1u: { hm0=0x0004u; hm1=0x0400u; hm2=0x0008u; hm3=0x0800u; } + case 2u: { hm0=0x0010u; hm1=0x1000u; hm2=0x0020u; hm3=0x2000u; } + default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; } + } + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 8u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 8u] = f32(src1[x_base + 32u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 108u)); + let a_base = 96u; + let a_il0 = load_u16_at_src0(block_byte_base + a_base + il * 2u); + let a_il1 = load_u16_at_src0(block_byte_base + a_base + (il + 1u) * 2u); + let a_4 = load_u16_at_src0(block_byte_base + a_base + 8u); + let a_5 = load_u16_at_src0(block_byte_base + a_base + 10u); + + var scales32 = a_4 | (a_5 << 16u); + let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u; + scales32 = a_il0 | (a_il1 << 16u); + scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32; + + let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32); + let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32); + + let q_u32_0 = load_u32_at_src0(block_byte_base + q_byte + 0u); + let q_u32_1 = load_u32_at_src0(block_byte_base + q_byte + 4u); + let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u); + let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); + + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + + s1 += x_block[l + 0u] * f32(qs & qm0); + s2 += x_block[l + 1u] * f32(qs & qm1); + s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + + select(0.0, x_block[l + 1u], (hv & hm1) == 0u); + s4 += x_block[l + 8u] * f32(qs & qm2); + s5 += x_block[l + 9u] * f32(qs & qm3); + s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + + select(0.0, x_block[l + 9u], (hv & hm3) == 0u); + } + + let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); + let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); + acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q4_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 144 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 32u * im + l0; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 0u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); + + let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let scale0 = f32(sc16_0 & 0xFFu); + let scale1 = f32((sc16_0 >> 8u) & 0xFFu); + let min0 = f32(sc16_1 & 0xFFu); + let min1 = f32((sc16_1 >> 8u) & 0xFFu); + let scale2 = f32(sc16_2 & 0xFFu); + let scale3 = f32((sc16_2 >> 8u) & 0xFFu); + let min2 = f32(sc16_3 & 0xFFu); + let min3 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); + let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); + + var dot = vec4(0.0, 0.0, 0.0, 0.0); + var sumx = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + dot[0] += x_block[i] * f32(q1b & 0x0Fu); + dot[1] += x_block[i + 4u] * f32(q1b >> 4u); + dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); + dot[3] += x_block[i + 12u] * f32(q2b >> 4u); + sumx[0] += x_block[i]; + sumx[1] += x_block[i + 4u]; + sumx[2] += x_block[i + 8u]; + sumx[3] += x_block[i + 12u]; + } + + acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) + - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q5_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 176 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 48u + 32u * im + l0; + let qh_offset = 16u + 8u * ir + 4u * in; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let hm1 = 1u << (2u * im); + let hm2 = hm1 << 1u; + let hm3 = hm1 << 4u; + let hm4 = hm2 << 4u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 0u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); + + let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let f0 = f32(sc16_0 & 0xFFu); + let f1 = f32((sc16_0 >> 8u) & 0xFFu); + let m0 = f32(sc16_1 & 0xFFu); + let m1 = f32((sc16_1 >> 8u) & 0xFFu); + let f4 = f32(sc16_2 & 0xFFu); + let f5 = f32((sc16_2 >> 8u) & 0xFFu); + let m4 = f32(sc16_3 & 0xFFu); + let m5 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset); + let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); + let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); + + var vals = vec4(0.0, 0.0, 0.0, 0.0); + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + let qhb = byte_of(qh_u32, i); + + let yl0 = x_block[i]; + let yl8 = x_block[i + 4u]; + let yh0 = x_block[i + 8u]; + let yh8 = x_block[i + 12u]; + + sumy[0] += yl0; + sumy[1] += yl8; + sumy[2] += yh0; + sumy[3] += yh8; + + let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); + let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); + let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); + let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); + + vals[0] += yl0 * q0; + vals[1] += yl8 * q1; + vals[2] += yh0 * q2; + vals[3] += yh8 * q3; + } + + acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) + - dmin * (sumy[0] * m0 + sumy[1] * m1 + + sumy[2] * m4 + sumy[3] * m5); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q6_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 210 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let ip = tid / 8u; + let il = tid % 8u; + let l0 = 4u * il; + let is = 8u * ip + l0 / 16u; + + let y_offset = 128u * ip + l0; + let q_offset_l = 64u * ip + l0; + let q_offset_h = 32u * ip + l0; + + let num_blocks = params.k / BLOCK_SIZE; + let sc_base_byte = 192u + (is & ~3u); + let sc_byte_pos = is & 3u; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var l = 0u; l < 4u; l++) { + x_block[l] = f32(src1[x_base + l]); + x_block[l + 4u] = f32(src1[x_base + 32u + l]); + x_block[l + 8u] = f32(src1[x_base + 64u + l]); + x_block[l + 12u] = f32(src1[x_base + 96u + l]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 208u)); + let ql1_u32 = load_u32_at_src0(block_byte_base + q_offset_l); + let ql2_u32 = load_u32_at_src0(block_byte_base + q_offset_l + 32u); + let qh_u32 = load_u32_at_src0(block_byte_base + 128u + q_offset_h); + let sc_u32_0 = load_u32_at_src0(block_byte_base + sc_base_byte); + let sc_u32_1 = load_u32_at_src0(block_byte_base + sc_base_byte + 4u); + + let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); + let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); + let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); + let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); + + var sums = vec4(0.0, 0.0, 0.0, 0.0); + + for (var l = 0u; l < 4u; l++) { + let q1b = byte_of(ql1_u32, l); + let q2b = byte_of(ql2_u32, l); + let qhb = byte_of(qh_u32, l); + + let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); + let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); + let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); + let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); + + sums[0] += x_block[l] * dq0; + sums[1] += x_block[l + 4u] * dq1; + sums[2] += x_block[l + 8u] * dq2; + sums[3] += x_block[l + 12u] * dq3; + } + + acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + + sums[2] * f32(sc4) + sums[3] * f32(sc6)); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ1_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 50 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base)); + let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu; + let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); + let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ1_M +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 56 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let sc_lo = load_u32_at_src0(block_byte_base + 48u); + let sc_hi = load_u32_at_src0(block_byte_base + 52u); + let sc0 = sc_lo & 0xFFFFu; + let sc1 = (sc_lo >> 16u) & 0xFFFFu; + let sc2 = sc_hi & 0xFFFFu; + let sc3 = (sc_hi >> 16u) & 0xFFFFu; + let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u); + let d = f32(bitcast>(d_bits)[0]); + + let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u), + select(sc0, sc1, sub_blk >= 2u), + sub_blk < 4u); + + let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u); + let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu; + let qh_lo = qh & 0xFFu; + let qh_hi = (qh >> 8u) & 0xFFu; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); + let sub_scale = (sc_u16 >> bit_off) & 0x7u; + let dl = d * f32(2u * sub_scale + 1u); + let qh_byte = select(qh_lo, qh_hi, l >= 2u); + let ll2 = l % 2u; + let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); + let ig = grid_idx * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ2_XXS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 66 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let aux_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let aux_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let ls = aux_hi >> 28u; + let db = d * (0.5 + f32(ls)) * 0.25; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; + let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xxs_grid[grid_idx * 2u]; + let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ2_XS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 74 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let scales_byte = get_byte(scales_word, sub_blk % 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let half2 = (l % 2u) * 16u; + let qs_val = (qs_word >> half2) & 0xFFFFu; + let grid_idx = qs_val & 0x1FFu; + let signs_idx = (qs_val >> 9u) & 0x7Fu; + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xs_grid[grid_idx * 2u]; + let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ2_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 82 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); + let sg_w = load_u32_at_src0(block_byte_base + 34u + sub_blk * 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let qh_byte = get_byte(qh_word, sub_blk % 4u); + let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u); + let scales_byte = get_byte(sc_word, sub_blk % 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let sign_byte = get_byte(sg_w, l); + let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let gw_lo = iq2s_grid[grid_idx * 2u]; + let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ3_XXS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 98 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let aux = load_u32_at_src0(block_byte_base + 66u + sub_blk * 4u); + let ls = aux >> 28u; + let db = d * (0.5 + f32(ls)) * 0.5; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let signs_idx = (aux >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let grid1 = iq3xxs_grid[grid_idx_0]; + let grid2 = iq3xxs_grid[grid_idx_1]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ3_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let qh_byte = get_byte(qh_word, sub_blk % 4u); + let sg_w = load_u32_at_src0(block_byte_base + 74u + sub_blk * 4u); + let sc_word = load_u32_at_src0(block_byte_base + 106u); + let scales_byte = get_byte(sc_word, sub_blk / 2u); + let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu; + let db = d * (1.0 + 2.0 * f32(sub_scale)); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); + let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); + let sign_byte = get_byte(sg_w, l); + let grid1 = iq3s_grid[grid_idx_1]; + let grid2 = iq3s_grid[grid_idx_2]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ4_NL +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + i + 16u]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; + let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ4_XS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 136 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let y_offset = sub_blk * 32u + half * 16u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let scales_h = load_u16_at_src0(block_byte_base + 2u); + let scales_l_word = load_u32_at_src0(block_byte_base + 4u); + let sl_byte = get_byte(scales_l_word, sub_blk / 2u); + let sl = (sl_byte >> (4u * (sub_blk % 2u))) & 0xFu; + let sh_bits = (scales_h >> (2u * sub_blk)) & 3u; + let ls = i32(sl | (sh_bits << 4u)); + let dl = d * f32(ls - 32); + + let qs_byte_off = 8u + sub_blk * 16u; + let q_w0 = load_u32_at_src0(block_byte_base + qs_byte_off); + let q_w1 = load_u32_at_src0(block_byte_base + qs_byte_off + 4u); + let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u); + let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u); + + var row_sum = 0.0; + for (var i = 0u; i < 16u; i++) { + let q_word = select( + select(q_w0, q_w1, i >= 4u), + select(q_w2, q_w3, i >= 12u), + i >= 8u); + let q_byte = get_byte(q_word, i % 4u); + let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); + row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif From ccd04522f96ff68cdba1312cca8e7472a4a8bb13 Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Fri, 1 May 2026 01:22:18 -0400 Subject: [PATCH 530/831] ggml-webgpu: add the upscale shader (llama/22419) * shader(upscale): add the upscale shader with nearest, bilinear and bicubic implementations * shader(upscale): use macro --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 94 +++++++ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 49 ++++ .../src/ggml-webgpu/wgsl-shaders/upscale.wgsl | 240 ++++++++++++++++++ 3 files changed, 383 insertions(+) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 0f66275c6a3..651c9cbcdf6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1,6 +1,7 @@ #ifndef GGML_WEBGPU_SHADER_LIB_HPP #define GGML_WEBGPU_SHADER_LIB_HPP +#include "ggml-impl.h" #include "ggml-wgsl-shaders.hpp" #include "ggml.h" #include "pre_wgsl.hpp" @@ -405,6 +406,31 @@ struct ggml_webgpu_scale_pipeline_key_hash { } }; +/** Upscale **/ + +struct ggml_webgpu_upscale_pipeline_key { + ggml_type input_type; + ggml_type output_type; + uint32_t base_mode; + bool antialias; + + bool operator==(const ggml_webgpu_upscale_pipeline_key & other) const { + return input_type == other.input_type && output_type == other.output_type && base_mode == other.base_mode && + antialias == other.antialias; + } +}; + +struct ggml_webgpu_upscale_pipeline_key_hash { + size_t operator()(const ggml_webgpu_upscale_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + ggml_webgpu_hash_combine(seed, key.base_mode); + ggml_webgpu_hash_combine(seed, key.antialias); + return seed; + } +}; + /** Concat **/ struct ggml_webgpu_concat_pipeline_key { @@ -1049,6 +1075,8 @@ class ggml_webgpu_shader_lib { webgpu_pipeline, ggml_webgpu_rms_norm_mul_pipeline_key_hash> rms_norm_mul_pipelines; + std::unordered_map + upscale_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -2947,6 +2975,72 @@ class ggml_webgpu_shader_lib { return im2col_pipelines[key]; } + webgpu_pipeline get_upscale_pipeline(const ggml_webgpu_shader_lib_context & context) { + const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(context.dst, 0); + const uint32_t base_mode = mode_flags & 0xFFu; + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS) != 0u; + + ggml_webgpu_upscale_pipeline_key key = {}; + key.input_type = context.src0->type; + key.output_type = context.dst->type; + key.base_mode = base_mode; + key.antialias = antialias; + + auto it = upscale_pipelines.find(key); + if (it != upscale_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "upscale"; + + if (key.input_type == GGML_TYPE_F16) { + defines.push_back("SRC_F16"); + variant += "_src_f16"; + } else { + variant += "_src_f32"; + } + + if (key.output_type == GGML_TYPE_F16) { + defines.push_back("DST_F16"); + variant += "_dst_f16"; + } else { + variant += "_dst_f32"; + } + + switch (base_mode) { + case GGML_SCALE_MODE_NEAREST: + defines.push_back("NEAREST"); + variant += "_nearest"; + break; + case GGML_SCALE_MODE_BILINEAR: + defines.push_back("BILINEAR"); + variant += "_bilinear"; + break; + case GGML_SCALE_MODE_BICUBIC: + defines.push_back("BICUBIC"); + variant += "_bicubic"; + break; + default: + GGML_ABORT("Unsupported upscale mode"); + } + + if (antialias) { + defines.push_back("ANTIALIAS"); + variant += "_aa"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_upscale, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + upscale_pipelines[key] = pipeline; + return upscale_pipelines[key]; + } + private: static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, std::string shader_code, diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f102c7a818b..cab0aead198 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2824,6 +2824,49 @@ static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, return true; } +static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * src, ggml_tensor * dst) { + const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(dst, 0); + std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3], + + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + + mode_flags }; + + std::vector entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_upscale_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); + uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + // Returns the encoded command, or std::nullopt if the operation is a no-op static std::optional ggml_webgpu_encode(webgpu_context ctx, ggml_cgraph * cgraph, @@ -2931,6 +2974,8 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, return ggml_webgpu_conv_2d(ctx, src0, src1, node); case GGML_OP_IM2COL: return ggml_webgpu_im2col(ctx, src0, src1, node); + case GGML_OP_UPSCALE: + return ggml_webgpu_upscale(ctx, src0, node); default: return std::nullopt; } @@ -4163,6 +4208,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SUM_ROWS: supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0); break; + case GGML_OP_UPSCALE: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; default: break; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl new file mode 100644 index 00000000000..e9ef8822644 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl @@ -0,0 +1,240 @@ +#if defined(SRC_F16) || defined(DST_F16) +enable f16; +#endif + +#ifdef SRC_F16 +#define SRC_TYPE f16 +#else +#define SRC_TYPE f32 +#endif + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + +@group(0) @binding(0) +var input: array; + +@group(0) @binding(1) +var output: array; + +struct Params { + offset_i: u32, + offset_o: u32, + + // element strides + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + src_w: u32, + src_h: u32, + src_z: u32, + src_n: u32, + + dst_w: u32, + dst_h: u32, + dst_z: u32, + dst_n: u32, + + mode_flags: u32, +}; + +@group(0) @binding(2) +var params: Params; + +const GGML_SCALE_FLAG_ALIGN_CORNERS: u32 = 1u << 8u; + +fn get_clamped_input(x: i32, y: i32, z: u32, n: u32) -> f32 { + let cx = u32(clamp(x, 0, i32(params.src_w) - 1)); + let cy = u32(clamp(y, 0, i32(params.src_h) - 1)); + let i = params.offset_i + cx * params.si0 + cy * params.si1 + z * params.si2 + n * params.si3; + return f32(input[i]); +} + +fn cubic_weight(t: f32, a: f32) -> f32 { + let at = abs(t); + if (at <= 1.0) { + return (a + 2.0) * at * at * at - (a + 3.0) * at * at + 1.0; + } else if (at <= 2.0) { + return a * at * at * at - 5.0 * a * at * at + 8.0 * a * at - 4.0 * a; + } else { + return 0.0; + } +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + + let i_out = gid.x + (num_wg.x * u32(WG_SIZE)) * gid.y; + let total = params.dst_w * params.dst_h * params.dst_z * params.dst_n; + + if (i_out >= total) { + return; + } + + // decode (x, y, z, n) + var i = i_out; + let x_dst = i % params.dst_w; + i = i / params.dst_w; + let y_dst = i % params.dst_h; + i = i / params.dst_h; + let z_dst = i % params.dst_z; + let n_dst = i / params.dst_z; + + // scale factors + var sf0 = f32(params.dst_w) / f32(params.src_w); + var sf1 = f32(params.dst_h) / f32(params.src_h); + var sf2 = f32(params.dst_z) / f32(params.src_z); + var sf3 = f32(params.dst_n) / f32(params.src_n); + + let align_corners = (params.mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) != 0; + + // pixel_offset: 0.5 for half-pixel-center (default), 0.0 for align_corners + var pixel_offset = 0.5; + if (align_corners) { + pixel_offset = 0.0; + if (params.dst_w > 1 && params.src_w > 1) { + sf0 = f32(params.dst_w - 1) / f32(params.src_w - 1); + } + if (params.dst_h > 1 && params.src_h > 1) { + sf1 = f32(params.dst_h - 1) / f32(params.src_h - 1); + } + } + + let z_src = min(params.src_z - 1, u32(floor(f32(z_dst) / sf2))); + let n_src = min(params.src_n - 1, u32(floor(f32(n_dst) / sf3))); + + var result = 0.0; + +#if defined(NEAREST) + + let x_src = min(params.src_w - 1, u32(floor(f32(x_dst) / sf0))); + let y_src = min(params.src_h - 1, u32(floor(f32(y_dst) / sf1))); + + result = get_clamped_input(i32(x_src), i32(y_src), z_src, n_src); + +#elif defined(BILINEAR) + +#if defined(ANTIALIAS) + + // Antialiased bilinear: triangle filter over a variable support region. + let support0 = max(1.0f / sf0, 1.0f); + let support1 = max(1.0f / sf1, 1.0f); + let invscale0 = 1.0 / support0; + let invscale1 = 1.0 / support1; + + let fx = (f32(x_dst) + pixel_offset) / sf0; + let fy = (f32(y_dst) + pixel_offset) / sf1; + + let x_min = max(i32(fx - support0 + pixel_offset), 0); + let y_min = max(i32(fy - support1 + pixel_offset), 0); + let x_max = min(i32(fx + support0 + pixel_offset), i32(params.src_w)); + let y_max = min(i32(fy + support1 + pixel_offset), i32(params.src_h)); + + var weighted_sum = 0.0; + var total_weight = 0.0; + + for (var x = x_min; x < x_max; x += 1) { + let wx = max(1.0 - abs(f32(x) - fx + pixel_offset) * invscale0, 0.0); + for (var y = y_min; y < y_max; y += 1) { + let wy = max(1.0 - abs(f32(y) - fy + pixel_offset) * invscale1, 0.0); + let w = wx * wy; + if (w > 0.0) { + weighted_sum += get_clamped_input(x, y, z_src, n_src) * w; + total_weight += w; + } + } + } + + if (total_weight > 0.0) { + result = weighted_sum / total_weight; + } + +#else + + let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset; + let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset; + let x0 = i32(floor(fx)); + let y0 = i32(floor(fy)); + let dx = clamp(fx - f32(x0), 0.0, 1.0); + let dy = clamp(fy - f32(y0), 0.0, 1.0); + let a = get_clamped_input(x0, y0, z_src, n_src); + let b = get_clamped_input(x0 + 1, y0, z_src, n_src); + let c = get_clamped_input(x0, y0 + 1, z_src, n_src); + let d = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src); + + let wa = (1.0 - dx) * (1.0 - dy); + let wb = dx * (1.0 - dy); + let wc = (1.0 - dx) * dy; + let wd = dx * dy; + + result = a * wa + b * wb + c * wc + d * wd; + +#endif + +#elif defined(BICUBIC) + + // bicubic convolution with alpha = -0.75 (PyTorch default) + let alpha = -0.75; + let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset; + let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset; + + let x0 = i32(floor(fx)); + let y0 = i32(floor(fy)); + let dx = fx - f32(x0); + let dy = fy - f32(y0); + + // horizontal weights for offsets -1, 0, 1, 2 + let wx0 = cubic_weight(dx + 1.0, alpha); + let wx1 = cubic_weight(dx, alpha); + let wx2 = cubic_weight(1.0 - dx, alpha); + let wx3 = cubic_weight(2.0 - dx, alpha); + + // vertical weights for offsets -1, 0, 1, 2 + let wy0 = cubic_weight(dy + 1.0, alpha); + let wy1 = cubic_weight(dy, alpha); + let wy2 = cubic_weight(1.0 - dy, alpha); + let wy3 = cubic_weight(2.0 - dy, alpha); + + // intermediate horizontal interpolation for 4x4 grid of pixels + // x0-1, x0, x0+1, x0+2, y0-1 + let p0 = get_clamped_input(x0 - 1, y0 - 1, z_src, n_src); + let p1 = get_clamped_input(x0, y0 - 1, z_src, n_src); + let p2 = get_clamped_input(x0 + 1, y0 - 1, z_src, n_src); + let p3 = get_clamped_input(x0 + 2, y0 - 1, z_src, n_src); + let row0 = p0 * wx0 + p1 * wx1 + p2 * wx2 + p3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0 + let q0 = get_clamped_input(x0 - 1, y0, z_src, n_src); + let q1 = get_clamped_input(x0, y0, z_src, n_src); + let q2 = get_clamped_input(x0 + 1, y0, z_src, n_src); + let q3 = get_clamped_input(x0 + 2, y0, z_src, n_src); + let row1 = q0 * wx0 + q1 * wx1 + q2 * wx2 + q3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0+1 + let r0 = get_clamped_input(x0 - 1, y0 + 1, z_src, n_src); + let r1 = get_clamped_input(x0, y0 + 1, z_src, n_src); + let r2 = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src); + let r3 = get_clamped_input(x0 + 2, y0 + 1, z_src, n_src); + let row2 = r0 * wx0 + r1 * wx1 + r2 * wx2 + r3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0+2 + let s0 = get_clamped_input(x0 - 1, y0 + 2, z_src, n_src); + let s1 = get_clamped_input(x0, y0 + 2, z_src, n_src); + let s2 = get_clamped_input(x0 + 1, y0 + 2, z_src, n_src); + let s3 = get_clamped_input(x0 + 2, y0 + 2, z_src, n_src); + let row3 = s0 * wx0 + s1 * wx1 + s2 * wx2 + s3 * wx3; + + // final vertical interpolation + result = row0 * wy0 + row1 * wy1 + row2 * wy2 + row3 * wy3; + +#endif + + let dst_idx = params.offset_o + x_dst * params.so0 + y_dst * params.so1 + z_dst * params.so2 + n_dst * params.so3; + output[dst_idx] = DST_TYPE(result); +} From e10025351cc17bf52b94bacb0cd705deec947f8d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 1 May 2026 13:08:32 +0300 Subject: [PATCH 531/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 236ae95a80f..a03455e74c8 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -387fa29fbbf3149f06a631c7850b6c35c24b0232 +b70770970e84c30a007b3859a453768b3ece2d3d From 35cb6841299888d20ad320966ce2176c403ada7d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 1 May 2026 18:53:30 +0300 Subject: [PATCH 532/831] ggml : try fix win32 build (#0) --- ggml/src/ggml.c | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 54d3eae3e4d..81343eeb14c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -55,8 +55,13 @@ uint64_t ggml_graph_next_uid(void) { #ifdef _MSC_VER +#if defined(_WIN32) + static volatile LONG counter = 1; + return (uint64_t) InterlockedIncrement(&counter) - 1; +#else static volatile long long counter = 1; return (uint64_t) _InterlockedIncrement64(&counter) - 1; +#endif #else static uint64_t counter = 1; return __atomic_fetch_add(&counter, 1, __ATOMIC_RELAXED); From 95053f68e4c2b638b3b33c200cfd9f4dd96976b7 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Fri, 1 May 2026 15:28:32 +0200 Subject: [PATCH 533/831] vulkan: Support asymmetric FA in coopmat2 path (llama/21753) * vulkan: Support asymmetric FA in coopmat2 path There has been some recent interest/experimentation with mixed quantization types for FA. I had originally designed the cm2 FA shader with this in mind (because I didn't realize it wasn't supported at the time!), this change adds the missing pieces and enables it. Also support Q1_0 since people have been trying that out (seems crazy, but who knows). We should be able to do similar things in the coopmat1/scalar path, but there's another change open against the scalar path and I don't want to conflict. * reorder cases --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 155 +++++++++++------- .../vulkan-shaders/flash_attn_base.glsl | 6 + .../vulkan-shaders/flash_attn_cm2.comp | 94 ++++++++--- .../vulkan-shaders/vulkan-shaders-gen.cpp | 17 +- 4 files changed, 185 insertions(+), 87 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 10b73317943..c2f1883328f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -440,10 +440,12 @@ struct vk_fa_pipeline_state { bool f32acc; uint32_t flags; uint32_t limit_occupancy_shmem; + ggml_type k_type; + ggml_type v_type; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) < - std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem); + return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem, k_type, v_type) < + std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem, b.k_type, b.v_type); } }; @@ -3041,7 +3043,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device return result; } -static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { GGML_UNUSED(n_kv); GGML_UNUSED(f32acc); @@ -3055,7 +3057,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device if (small_rows) { result.block_rows = 32; result.block_cols = 32; - } else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) { + } else if (ggml_is_quantized(k_type) || ggml_is_quantized(v_type) || hsk >= 256 || hsv >= 256) { result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64; result.block_cols = 32; } else { @@ -3069,7 +3071,13 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device return result; } -static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { + // Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1. + if (k_type != v_type) { + GGML_ASSERT(device->coopmat2); + return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); + } + FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; @@ -3081,7 +3089,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ if (path == FA_COOPMAT1) { bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || (!f32acc && device->coopmat_support_16x16x16_f16acc); - const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); if (!shape_ok || !shmem_ok) { @@ -3094,20 +3102,25 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ path = FA_SCALAR; } + // Q1_0 K/V is only implemented on coopmat2 (flash_attn_cm2); there is no scalar FA shader for it. + if ((k_type == GGML_TYPE_Q1_0 || v_type == GGML_TYPE_Q1_0) && device->coopmat2) { + path = FA_COOPMAT2; + } + switch (path) { case FA_SCALAR: - return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); case FA_COOPMAT1: - return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); case FA_COOPMAT2: - return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); default: throw std::runtime_error("unsupported FaCodePath"); } } static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc, - bool use_mask, bool use_mask_opt, bool use_logit_softcap) { + bool use_mask, bool use_mask_opt, bool use_logit_softcap, ggml_type k_type, ggml_type v_type) { const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary && (device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2); @@ -3118,12 +3131,32 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size; - return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem}; + return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem, k_type, v_type}; } static std::vector get_fa_spec_constants(const vk_fa_pipeline_state& state) { - return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split, - state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem}; + const auto fa_block_bytes = [](ggml_type t) -> uint32_t { + // decodeBufF32 uses a block of vec4s for a better memory access pattern. + return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t); + }; + return { + /* 0 WorkGroupSize */ state.workgroup_size, + /* 1 Br */ state.Br, + /* 2 Bc */ state.Bc, + /* 3 HSK */ state.HSK, + /* 4 HSV */ state.HSV, + /* 5 Clamp */ static_cast(!state.aligned), + /* 6 D_split */ state.D_split, + /* 7 row_split */ state.row_split, + /* 8 SubGroupSize */ state.subgroup_size, + /* 9 SHMEM_STAGING */ state.shmem_staging ? 1u : 0u, + /*10 Flags */ state.flags, + /*11 LIMIT_OCCUPANCY_SHMEM */ state.limit_occupancy_shmem, + /*12 FaTypeK */ static_cast(state.k_type), + /*13 FaTypeV */ static_cast(state.v_type), + /*14 FaBlockBytesK */ fa_block_bytes(state.k_type), + /*15 FaBlockBytesV */ fa_block_bytes(state.v_type), + }; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -3578,16 +3611,35 @@ static void ggml_vk_load_shaders(vk_device& device) { } #endif #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) +#define CREATE_FA_CM2_MIXED() \ + for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \ + for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \ + FaCodePath path = fa.first.path; \ + uint32_t Br = fa.first.Br; \ + uint32_t Bc = fa.first.Bc; \ + bool aligned = fa.first.aligned; \ + bool f32acc = fa.first.f32acc; \ + if (path == FA_COOPMAT2) { \ + if (aligned) { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ + } \ + } else { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ + } \ + } \ + } \ + } \ + } if (device->coopmat2) { - CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) + CREATE_FA_CM2_MIXED(); } +#undef CREATE_FA_CM2_MIXED #endif #undef CREATE_FA @@ -9042,8 +9094,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(dst->type == GGML_TYPE_F32); assert(q->type == GGML_TYPE_F32); - assert(k->type == v->type); - uint32_t gqa_ratio = 1; uint32_t qk_ratio = neq2 / nek2; uint32_t workgroups_x = (uint32_t)neq1; @@ -9054,7 +9104,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). - vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc); + vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, v->type, f32acc); const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u); if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa && @@ -9067,7 +9117,11 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx workgroups_y /= gqa_ratio; } - tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc); + tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc); + + if (tuning_params.path != FA_COOPMAT2) { + GGML_ASSERT(k->type == v->type); + } const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); @@ -9106,7 +9160,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16; vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc, - mask != nullptr, use_mask_opt, logit_softcap != 0); + mask != nullptr, use_mask_opt, logit_softcap != 0, k->type, v->type); vk_pipeline pipeline = nullptr; @@ -15590,38 +15644,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { return false; } - // It's straightforward to support different K/V dequant, but would - // significantly increase the number of pipelines - if (op->src[1]->type != op->src[2]->type) { + // mismatching K/V type is currently supported for coopmat2 only. + if (op->src[1]->type != op->src[2]->type && !coopmat2) { return false; } - switch (op->src[1]->type) { - case GGML_TYPE_F16: - case GGML_TYPE_F32: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_IQ4_NL: - // supported in scalar and coopmat2 paths - break; - // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently - //case GGML_TYPE_Q2_K: - //case GGML_TYPE_Q3_K: - //case GGML_TYPE_Q4_K: - //case GGML_TYPE_Q5_K: - //case GGML_TYPE_Q6_K: - //case GGML_TYPE_IQ1_S: - //case GGML_TYPE_IQ1_M: - //case GGML_TYPE_IQ2_XXS: - //case GGML_TYPE_IQ2_XS: - //case GGML_TYPE_IQ2_S: - //case GGML_TYPE_IQ3_XXS: - //case GGML_TYPE_IQ3_S: - //case GGML_TYPE_IQ4_XS: - - default: + auto fa_kv_ok = [coopmat2](ggml_type t) { + switch (t) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_0: + return true; + case GGML_TYPE_Q1_0: + return coopmat2; + default: + return false; + } + }; + if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) { return false; } if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 6f349246915..efed3a73e22 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -13,6 +13,12 @@ layout (constant_id = 8) const uint32_t SubGroupSize = 32; layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0; layout (constant_id = 10) const uint32_t Flags = 0; layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0; +// ggml_type enumerant for K/V +layout (constant_id = 12) const uint32_t FaTypeK = 0; +layout (constant_id = 13) const uint32_t FaTypeV = 0; +// sizeof(decode buffer): quants -> ggml block size; F32 -> 16 (decodeBufF32 vec4). +layout (constant_id = 14) const uint32_t FaBlockBytesK = 2; +layout (constant_id = 15) const uint32_t FaBlockBytesV = 2; const bool USE_MASK_OPT = (Flags & 1) != 0; const bool MASK_ENABLE = (Flags & 2) != 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 0ea181342ce..8a7bbaeb92c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -17,8 +17,57 @@ #extension GL_EXT_null_initializer : enable #include "types.glsl" -#include "dequant_funcs_cm2.glsl" #include "flash_attn_base.glsl" +#include "dequant_funcs_cm2.glsl" + +// buffer_reference stride = sizeof(struct) = FaBlockBytesK/V. +layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K { + uint8_t raw[FaBlockBytesK]; +}; +layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_V { + uint8_t raw[FaBlockBytesV]; +}; + +uint fa_block_elems(uint ty) { + switch (ty) { + case 0u: return 4u; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32) + case 1u: return 1u; // GGML_TYPE_F16 + case 2u: return uint(QUANT_K_Q4_0); + case 3u: return uint(QUANT_K_Q4_1); + case 6u: return uint(QUANT_K_Q5_0); + case 7u: return uint(QUANT_K_Q5_1); + case 8u: return uint(QUANT_K_Q8_0); + case 41u: return uint(QUANT_K_Q1_0); + default: + return 1u; + } +} + +float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeK) { + case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); + case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return float16_t(0); + } +} + +float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeV) { + case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); + case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return float16_t(0); + } +} layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; @@ -55,12 +104,6 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele return max(elem0, elem1); } -#if BLOCK_SIZE > 1 -#define DECODEFUNC , DEQUANTFUNC -#else -#define DECODEFUNC -#endif - // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) @@ -95,10 +138,6 @@ ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c } void main() { -#ifdef NEEDS_INIT_IQ_SHMEM - init_iq_shmem(gl_WorkGroupSize); -#endif - init_indices(); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); @@ -107,10 +146,10 @@ void main() { tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); -#if BLOCK_SIZE > 1 - tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); - tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); -#endif + const uint bs_k = fa_block_elems(FaTypeK); + const uint bs_v = fa_block_elems(FaTypeV); + tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, bs_k); + tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, bs_v); tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK); tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK); @@ -120,10 +159,12 @@ void main() { if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { q_stride &= ~7; -#if BLOCK_SIZE == 1 - k_stride &= ~7; - v_stride &= ~7; -#endif + if (bs_k == 1u) { + k_stride &= ~7; + } + if (bs_v == 1u) { + v_stride &= ~7; + } m_stride &= ~7; } tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); @@ -230,7 +271,13 @@ void main() { coopmat K_T; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; - coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC); + // F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128. + const bool k_use_decode = (bs_k > 1u); + if (k_use_decode) { + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose, faDecodeK); + } else { + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); + } S = coopMatMulAdd(Qf16, K_T, S); if (LOGIT_SOFTCAP) { @@ -291,7 +338,12 @@ void main() { coopmat V; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; - coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC); + const bool v_use_decode = (bs_v > 1u); + if (v_use_decode) { + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad), faDecodeV); + } else { + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); + } L = eM*L + rowsum; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ff836615330..6f2a929c40c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -641,20 +641,17 @@ void process_shaders() { fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; } + if (fp16) { +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); +#endif + } + for (const auto& tname : type_names) { if (tname == "bf16") continue; if (fp16) { -#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); - } else { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc); - } -#endif #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", From 9623c1203b91da1467c1f3692c713a1f09dfa8c4 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Fri, 1 May 2026 23:55:01 +0900 Subject: [PATCH 534/831] ggml-webgpu: Fix vectorized handling in mul-mat and mul-mat-id (llama/22578) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix vectorized condition of mul-mat-fast pipeline and add vectorized variant to mul-mat-id * Apply suggestion from @CISC Co-authored-by: Sigbjørn Skjæret --------- Co-authored-by: Sigbjørn Skjæret --- ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 651c9cbcdf6..cff93b8d170 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1779,12 +1779,12 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_mul_mat_pipeline_key key = {}; - key.src0_type = context.src0->type; - key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; key.use_subgroup_matrix = context.supports_subgroup_matrix; auto it = mul_mat_fast_pipelines.find(key); @@ -2143,6 +2143,9 @@ class ggml_webgpu_shader_lib { // variant suffix for src1 type variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16"); + if (key.vectorized) { + variant += "_vectorized"; + } auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines); From f2ce24fa5c946d7cbc42f52947428c2eba299393 Mon Sep 17 00:00:00 2001 From: Aparna M P Date: Fri, 1 May 2026 22:39:23 +0530 Subject: [PATCH 535/831] hexagon: enable non-contiguous row tensor support for unary ops (llama/22574) --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 4 +- ggml/src/ggml-hexagon/htp/hvx-exp.h | 4 +- ggml/src/ggml-hexagon/htp/unary-ops.c | 110 ++++++++++++++++++------- 3 files changed, 85 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 17ac083f4ea..6bb073102c0 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2421,8 +2421,8 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses return false; } - // TODO: add support for non-contigiuos tensors - if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { + // TODO: add support for non-contiguous elements within a row + if (!ggml_is_contiguous_rows(src0) || !ggml_is_contiguous_rows(dst)) { return false; } diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.h b/ggml/src/ggml-hexagon/htp/hvx-exp.h index 84e4836dc92..e71ec4909a6 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-exp.h +++ b/ggml/src/ggml-hexagon/htp/hvx-exp.h @@ -17,7 +17,7 @@ #define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805 #define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408 #define EXP_ONE (0x3f800000) // 1.0 -#define EXP_RANGE_R (0x42B16666) // 88.7 +#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228 #define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN)) static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { @@ -163,7 +163,7 @@ static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict HVX_Vector vec_out = Q6_V_vzero(); static const float kInf = INFINITY; - static const float kMaxExp = 88.7f; + static const float kMaxExp = 88.7228f; const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); const HVX_Vector inf = hvx_vec_splat_f32(kInf); diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 03eccfd55e3..819cdc49bd9 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -26,8 +26,8 @@ struct htp_unary_context { const uint8_t * data_src0; uint8_t * data_dst; - size_t src0_row_size; - size_t dst_row_size; + size_t src0_data_row_size; // actual data bytes per row + size_t dst_data_row_size; // actual data bytes per row size_t src0_row_size_aligned; size_t dst_row_size_aligned; @@ -41,6 +41,40 @@ struct htp_unary_context { uint32_t nc; }; +// Convert flat row index to DDR byte offset using the tensor's actual strides. +// ir = i1 + ne1*(i2 + ne2*i3) => offset = i1*nb1 + i2*nb2 + i3*nb3 +static inline size_t unary_row_offset(uint32_t ir, + uint32_t ne1, uint32_t ne2, + size_t nb1, size_t nb2, size_t nb3) { + const uint32_t i1 = ir % ne1; + const uint32_t i2 = (ir / ne1) % ne2; + const uint32_t i3 = ir / (ne1 * ne2); + return i1 * nb1 + i2 * nb2 + i3 * nb3; +} +// Safe DMA block size from row `ir`: clamp to the tighter dim-1 slice +// boundary of src and dst so the nb1 stride stays valid for all rows. +static inline uint32_t unary_block_size(uint32_t ir, + uint32_t end_row, + uint32_t block, + bool src_contig, + bool dst_contig, + uint32_t src_ne1, + uint32_t dst_ne1) { + uint32_t limit = MIN(block, end_row - ir); + + if (!src_contig) { + const uint32_t src_slice_end = (ir / src_ne1 + 1) * src_ne1; + limit = MIN(limit, src_slice_end - ir); + } + + if (!dst_contig) { + const uint32_t dst_slice_end = (ir / dst_ne1 + 1) * dst_ne1; + limit = MIN(limit, dst_slice_end - ir); + } + + return limit; +} + #define htp_unary_preamble \ const uint32_t ne00 = src->ne[0]; \ const uint32_t ne01 = src->ne[1]; \ @@ -276,8 +310,8 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * int32_t * op_params = octx->op_params; uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread; - const size_t src0_row_size = uctx->src0_row_size; - const size_t dst_row_size = uctx->dst_row_size; + const size_t src0_data_row_size = uctx->src0_data_row_size; + const size_t dst_data_row_size = uctx->dst_data_row_size; const size_t src0_row_size_aligned = uctx->src0_row_size_aligned; const size_t dst_row_size_aligned = uctx->dst_row_size_aligned; @@ -303,7 +337,16 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * size_t src0_spad_half_size = uctx->src0_spad_half_size; size_t dst_spad_half_size = uctx->dst_spad_half_size; - const int BLOCK = uctx->block; + // Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride + // 2D DMA descriptor cannot span. Clamp BLOCK to ne1 (one dim-1 slice) so every + // transfer stays within a nb1-uniform region. Skipped for contiguous tensors. + const bool src0_contig = (nb02 == (size_t)ne01 * nb01) && + (nb03 == (size_t)ne02 * nb02); + const bool dst_contig = (nb2 == (size_t)ne1 * nb1) && + (nb3 == (size_t)ne2 * nb2); + const uint32_t src0_max_block = src0_contig ? uctx->block : MIN((uint32_t)uctx->block, ne01); + const uint32_t dst_max_block = dst_contig ? uctx->block : MIN((uint32_t)uctx->block, ne1); + const uint32_t BLOCK = MIN(src0_max_block, dst_max_block); if (BLOCK == 0) { FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", octx->src0_spad.size_per_thread, src0_row_size_aligned); @@ -312,21 +355,23 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * dma_queue * dma_queue = octx->ctx->dma[ith]; - for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { - const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) { + const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); // Dummy DMA transation for sequencing (interleaving dst,src,dst,...) - dma_queue_push_vtcm_to_ddr(dma_queue, + dma_queue_push(dma_queue, dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), - dst_row_size, dst_row_size_aligned, 0); + nb1, dst_row_size_aligned, dst_data_row_size, 0); - dma_queue_push_ddr_to_vtcm(dma_queue, - dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)), - src0_row_size_aligned, src0_row_size, block_size); + const size_t src0_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03); + dma_queue_push(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off), + src0_row_size_aligned, nb01, src0_data_row_size, block_size); + ir += block_size; } - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { - const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + for (uint32_t ir = src0_start_row; ir < src0_end_row; ) { + const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); float * dst_spad = (float *) dma_queue_pop(dma_queue).src; float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; @@ -361,18 +406,25 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * break; } - dma_queue_push_vtcm_to_ddr(dma_queue, - dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), - dst_row_size, dst_row_size_aligned, block_size); + const size_t dst_off = unary_row_offset(ir, ne1, ne2, nb1, nb2, nb3); + dma_queue_push(dma_queue, + dma_make_ptr(data_dst + dst_off, dst_spad), + nb1, dst_row_size_aligned, dst_data_row_size, block_size); // prefetch N+2 loop iteration if any - const uint32_t pref_block = (ir + BLOCK * 2); - if (pref_block < src0_end_row) { - const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); - dma_queue_push_ddr_to_vtcm(dma_queue, - dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)), - src0_row_size_aligned, src0_row_size, pref_block_size); + const uint32_t next_ir = ir + block_size; + if (next_ir < src0_end_row) { + const uint32_t next_block_size = unary_block_size(next_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); + const uint32_t pref_ir = next_ir + next_block_size; + if (pref_ir < src0_end_row) { + const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); + const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03); + dma_queue_push(dma_queue, + dma_make_ptr(src0_spad, data_src + src0_pref_off), + src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size); + } } + ir += block_size; } dma_queue_flush(dma_queue); @@ -426,11 +478,11 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); - const size_t src0_row_size = src0->nb[1]; - const size_t dst_row_size = dst->nb[1]; + const size_t src0_data_row_size = src0->ne[0] * sizeof(float); + const size_t dst_data_row_size = dst->ne[0] * sizeof(float); - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN); // VTCM scratchpads for all tensors // N rows per thread, padded to HVX vector size @@ -468,8 +520,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { .data_src0 = (const uint8_t *)src0->data, .data_dst = (uint8_t *)dst->data, - .src0_row_size = src0_row_size, - .dst_row_size = dst_row_size, + .src0_data_row_size = src0_data_row_size, + .dst_data_row_size = dst_data_row_size, .src0_row_size_aligned = src0_row_size_aligned, .dst_row_size_aligned = dst_row_size_aligned, From 4861a3eeb5cb86df2de29c38c488e44d8dc9f6ca Mon Sep 17 00:00:00 2001 From: Yiwei Shao <44545837+njsyw1997@users.noreply.github.com> Date: Fri, 1 May 2026 20:29:13 -0700 Subject: [PATCH 536/831] hexagon: hmx flash attention (llama/22347) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * hmx: extract shared interleave headers and unify matmul batched * hmx: add HMX-accelerated flash attention for prefill * hmx: replace asm wrappers with Q6_ intrinsics in hmx-utils.h Switches three single-instruction helpers from inline asm to the matching Q6_ intrinsics, matching the style established by aizip f8737609a and used by the upstream PR #21554 hmx-matmul-ops.c rewrite: hmx_set_output_scales asm "bias=mxmem2" -> Q6_bias_mxmem2_A hmx_load_tile_pair_fp16 asm packet -> Q6_activation_hf_mxmem_RR + Q6_weight_hf_mxmem_RR hmx_consume_accumulator_fp16 asm "mxmem=acc" -> Q6_mxmem_AR_after_hf hmx_load_tiles_fp16 stays on inline asm: it uses ":deep" activation streaming, and the mixed Q6_activation_hf_mxmem_RR_deep + non-deep Q6_weight_hf_mxmem_RR pair fails the HMX backend constraint check ("activate weight pair (1) exceeds limit (1)"). The asm bundle keeps both halves in one VLIW packet and avoids the diagnostic. Functionally equivalent — same instructions emitted; the Q6_ intrinsics just give the compiler more visibility for scheduling. * hmx: drop the duplicate interleave_fp16_weight_chunk_to_tiles * hmx: apply upstream optimization to hmx-flash-attn-ops.c apply restrict, __builtin_assume, and pointer accumulation to the three HMX workers (qk_dot, o_update, o_norm) and the matching inline HMX loops in op_hmx_flash_attn_ext. * hmx: unify interleave helper * hmx: multi-thread Q load / O store and enable prefill FA dispatch Extract inline Q-load and O-store loops into worker_pool-parallel helpers (fa_phase_q_load, fa_phase_o_store) so HVX threads split the F32↔F16 conversion work across row ranges. Also relax the softmax threading gate from n_row_vec_cnt >= n_threads to >= 2, which was unnecessarily forcing single-thread fallback when n_rows_g < 512. On the dispatch side, remove the ne[2] != 1 guard that blocked multi-head (prefill) FA from reaching the HTP backend — GQA is already handled internally by both the HMX and HVX flash-attention paths. * hmx: relax matmul pipeline gate to cover k > n shapes (e.g. FFN_down) * hmx: optimize FA softmax mask phase (no-ALiBi fast path + GQA dedup) * hmx: Add an asm memory clobber at the phase boundary to prevent reorder bug * [experimental]: fp16 softmax (EXP2_HF) to accelerate fa Bake log2(e) into qk_scale and use hvx_exp2_hf directly for P and m_diff (base-2 consistent, matches htp-ops-lib). ~22 ALU ops for 64 lanes vs ~44 for the F32 round-trip path. * hmx flash-attn: refine cost model coefficients based on profiling data * hmx flash-attn: replace asm clobber with targeted volatile reads on vtcm_d_tiles * hmx flash-attn: fix prefill correctness (dst indexing, softmax reduce, V stride) * hmx flash-attn: fix p_tiles dual-tile OOB race; enable MT + pipeline * hmx flash-attn: preserve additive mask bias in no-ALiBi fast path The no-ALiBi fast path (max_bias==0) was skipping mask add entirely on the assumption that mask values are only {0, -inf}. This is wrong when the mask carries additive positional bias — those terms were silently dropped. Keep the slope-mul skip (slope≡1.0) but add mask back so the bias survives; vmux still clamps below -16 to -inf. Also add HMX FA coverage to test-backend-ops: prefill shapes (nb=64, nb=32) × {mask on/off} × {ALiBi on/off} × {softcap on/off}, F16 KV, hs ∈ {64, 128}. * hmx: fix softcap+EXP2_HF interaction, tighten matmul pipeline gate, add FA tests - flash-attn: when EXP2_HF is on AND logit_softcap is active, fold log2(e) into the post-tanh multiplier (v_cap) instead of pre-baking it into qk_scale. Pre-baking shifted the tanh knee from x≈c to x≈c/log2(e) and produced numerically wrong softcapped outputs whenever both knobs were enabled. - flash-attn softmax (fa_softmax_thread): replace the union+memcpy scalar extract pattern with HVX vmux-based per-row accumulators on rowmax/rowsum. Add hvx_vec_get_f16 helper in hvx-base.h. Functional parity, less scalar code, clearer hf/qf16 lane-format contract. - matmul (hmx_mat_mul_permuted_qk_0_d16a32): pick pipeline vs sequential layout based on whether the chunker actually yields >=2 n-chunks, instead of the static (m>=128 && n>=256) gate. Avoids paying for output double-buffer + worker dispatch when there is no HMX/HVX overlap to gain (e.g. shapes that collapse to one n-chunk). - tests: add HMX flash-attention coverage over the {mask, ALiBi (max_bias), logit_softcap} cross-product for the prefill path — head_dim 64/128, GQA 4×4, kv=512/nb=64 plus a kv=113/nb=32 non-aligned case. * [Help Wanted]: refactor D matrix computation into separate function for clarity and maintainability * format code * hexagon: looks like -O3 is causing issues with the large code base, switch to -O2 and -flto instead * hexagon: use hex_ prefix for swap_ptr * hexagon: move vtcm_seq_alloc into vtcm-utils.h More vtcm allocator updates are coming so it makes sense to start the separate hdr for it. * hmx-utils: add hmx_prefix for layout converters * hmx-mm: move main hmx_mm functions to the end, remove unused fwd decls, etc * hmx-mm: remove unused qweight_fetch_task_state_t and minor alignment fixes * hmx-fa: minor alignment fixes * hmx-fa: move hmx_flash_atten into hmx-ops.h * hmx-fa: remove redundant workpool pointer in the hmx_fa_ctx, plus minor alignment updates * hmx-fa: minor alignment and simplifications * hexagon: move FA_EXP_F16 option to hostside CMake file * hmx-fa: use hvx_vec_splat_f16 instead of fp16_to_bits * hmx-fa: add hvx_splat_u16/u8 and use that in the fa instead custom hvx_fill * hmx-fa: some more alignment updates in the core fa function * hmx-fa: keep slopes in vtcm in fp16 Saves malloc/free and removes the need for float -> fp16 downcast on every use. * hexagon: consistent noinline usage (after static) * hex-hmx: consistent use FARF_HIGH to enable debug output * hmx-utils: no need for always_inline attr * hex-hmx: consistent noinline usage (static noinline ...) * hex-hmx: simplify init_col_scales * hexagon: fix editorconfig errors * hmx-mm: minor alignment fixes --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/CMakeLists.txt | 3 +- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 3 +- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 7 + .../ggml-hexagon/htp/cmake-toolchain.cmake | 10 +- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 14 +- ggml/src/ggml-hexagon/htp/hex-utils.h | 6 + .../src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 1840 +++++++++++++++++ ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 1435 +++++++------ ggml/src/ggml-hexagon/htp/hmx-ops.h | 3 + ggml/src/ggml-hexagon/htp/hmx-utils.h | 192 +- ggml/src/ggml-hexagon/htp/hvx-base.h | 6 + ggml/src/ggml-hexagon/htp/hvx-copy.h | 37 +- ggml/src/ggml-hexagon/htp/vtcm-utils.h | 16 + 13 files changed, 2798 insertions(+), 774 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/vtcm-utils.h diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt index f3a583543c6..b82bae0c103 100644 --- a/ggml/src/ggml-hexagon/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -22,7 +22,8 @@ message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) include(ExternalProject) -option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) +option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) +option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF) set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate") set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)") diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 6bb073102c0..df4ed101464 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2254,8 +2254,7 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess return false; } - if (dst->ne[2] != 1 || dst->ne[3] != 1) { - // FA during prompt still needs work + if (dst->ne[3] != 1) { return false; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 8bd528478ba..7c9e4cda5f1 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -44,6 +44,11 @@ target_compile_definitions(${HTP_LIB} PRIVATE $,FARF_HIGH=1,> FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) +if (GGML_HEXAGON_FA_EXP2_HF) + message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") + target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) +endif() + # HMX acceleration: available on v73+ architectures set(HTP_HMX_VERSIONS v73 v75 v79 v81) list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) @@ -52,11 +57,13 @@ if (_hmx_idx GREATER_EQUAL 0) target_sources(${HTP_LIB} PRIVATE hmx-queue.c hmx-matmul-ops.c + hmx-flash-attn-ops.c ) # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) set_source_files_properties( hmx-matmul-ops.c + hmx-flash-attn-ops.c PROPERTIES COMPILE_OPTIONS "-mhmx" ) diff --git a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake index 7fa236e328f..ed5c198468c 100644 --- a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +++ b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake @@ -138,15 +138,15 @@ set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,") set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,") #Compiler Options -set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") +set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") -set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g") -set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3") +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g") +set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O2") set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") -set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g") -set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3") +set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g") +set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O2") set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}") set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}") diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index d296a322589..d95df6ac9d5 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -17,13 +17,14 @@ #include "htp-ctx.h" #include "htp-ops.h" #include "htp-ops.h" +#include "hmx-ops.h" // Must be multiple of 32 #define FLASH_ATTN_BLOCK_SIZE (32 * 2) // This is a bit of a hack because the compiler is strugling to properly inline // the default hvx_vec_f32_to_f16 with output into the local array. -static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) +static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) { *(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1); } @@ -621,6 +622,17 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } +#ifdef HTP_HAS_HMX + // HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV + if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) { + int ret = hmx_flash_attn_ext(octx); + if (ret == HTP_STATUS_OK) { + return ret; + } + // VTCM too small or other failure -> fall through to HVX path + } +#endif + struct htp_fa_context factx; factx.octx = octx; diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index 329249e11da..6239ceff4b4 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -74,6 +74,12 @@ static inline size_t hex_smax(size_t a, size_t b) { return a > b ? a : b; } +static inline void hex_swap_ptr(void ** p1, void ** p2) { + void * t = *p1; + *p1 = *p2; + *p2 = t; +} + static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) { const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); Q6_l2fetch_AP((void *) p, control); diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c new file mode 100644 index 00000000000..8a6d7c14edf --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -0,0 +1,1840 @@ +// HMX-accelerated Flash Attention for prefill (neq1 >= 32). +// Ported from htp-ops-lib/src/dsp/ops/flash_attn.c, adapted to the htp/ codebase. + +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "hex-dma.h" +#include "hmx-profile.h" +#include "hmx-queue.h" +#include "hmx-utils.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-dump.h" +#include "hvx-reduce.h" +#include "hvx-utils.h" +#include "vtcm-utils.h" +#include "worker-pool.h" + +// ============================================================================ +// Constants +// ============================================================================ + +// Tile constants from hmx-utils.h +// HMX_FP16_TILE_N_ROWS = 32 +// HMX_FP16_TILE_N_COLS = 32 +// HMX_FP16_TILE_N_ELMS = 1024 +// HMX_FP16_TILE_SIZE = 2048 + +// ============================================================================ +// Dynamic block size computation (GQA-aware) +// ============================================================================ + +// Exact VTCM usage for a given (gqa_factor, DK, DV, Br, Bc) configuration. +// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions. +// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales +// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax. +static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads) { + const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); + const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] + const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong + const size_t k_dma_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K DMA: [Bc, DK] x2 double-buf + const size_t v_dma_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V DMA: [Bc, DV] x2 double-buf + const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved + const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved + const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc] + const size_t d_tile_size = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); // D: [g_br, g_br] + const size_t col_vec_size = hex_align_up(g_br * sizeof(__fp16), 256); // m, l, etc. + const size_t row_vec_size = hex_align_up(Bc * sizeof(__fp16), 256); + const size_t m_line_size = hex_align_up(Bc * sizeof(__fp16), 128); + const size_t m_buf_size = hex_align_up(Br * m_line_size, 4096); + const size_t slopes_size = hex_align_up(g_br * sizeof(__fp16), 128); + + return q_tile_size * 1 // Q tiles + + o_tile_size * 2 // O ping-pong + + k_dma_size * 2 // K DMA x2 + + v_dma_size * 2 // V DMA x2 + + k_tile_size * 1 // K tiles + + v_tile_size * 1 // V tiles + + s_tile_size * 2 // S + P + + d_tile_size * 1 // D (diagonal matrix) + + col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum + + row_vec_size * 2 * n_threads // per-thread softmax row scratch + + m_buf_size * 1 // mask VTCM buffer [Br rows] + + slopes_size // Slopes + + 256 * 2; // HMX scales (id + qk) +} + +// ============================================================================ +// FP16 exp2 polynomial (ported from htp-ops-lib/include/dsp/hvx_math.h) +// ============================================================================ +// 5th-order Horner polynomial for exp2(x) in qf16/hf16 domain. Input must be +// ≤ 0 (safe softmax invariant — overflow handling omitted). ~18 ALU ops per +// 64 fp16 lanes, fully parallel across HVX threads (no scatter/gather engine). +// Replaces the F32 round-trip (qf16→f32→exp→f32→f16, ~44 ops for 2×32 lanes). +static inline HVX_Vector hvx_exp2_hf(HVX_Vector x_v) { + const HVX_Vector zero_v = Q6_V_vzero(); + const HVX_Vector half_hf_v = Q6_Vh_vsplat_R(0x3800); // fp16 0.5 + + // k = round_toward_neg_inf(x); f = (float)k; frac = x - f + HVX_Vector x_minus_half = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(x_v, half_hf_v)); + HVX_Vector k_v = Q6_Vh_equals_Vhf(x_minus_half); // truncate to int16 + HVX_Vector f_v = Q6_Vhf_equals_Vh(k_v); // back to fp16 + + HVX_Vector x_qf16 = Q6_Vqf16_vsub_VhfVhf(x_v, f_v); // fractional part in qf16 + + // Horner: y = ((((E5*x + E4)*x + E3)*x + E2)*x + E1)*x + E0 + HVX_Vector y = Q6_Vqf16_vmpy_Vqf16Vqf16(Q6_Vh_vsplat_R(0x5082), x_qf16); // E5*x + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x157d)); // + E4 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x20ed)); // + E3 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x2b1b)); // + E2 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x33b0)); // + E1 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x398c)); // + E0 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); // y = y * x + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x3c00)); // + 1.0 + + // Combine polynomial (mantissa) with integer part (exponent): result = y * 2^k + y = Q6_Vhf_equals_Vqf16(y); + HVX_Vector y_exp = Q6_Vuh_vlsr_VuhR(Q6_Vh_vasl_VhR(y, 1), 11); + y_exp = Q6_Vh_vadd_VhVh(k_v, y_exp); + HVX_VectorPred q_underflow = Q6_Q_vcmp_gt_VhVh(zero_v, y_exp); + y = Q6_Vh_vaslacc_VhVhR(y, k_v, 10); + return Q6_V_vmux_QVV(q_underflow, zero_v, y); +} + +#define FA_MIN_KV_BLOCKS 3 + +// Cost-based (Br, Bc) search for flash attention with pipeline constraint. +// +// VTCM model (same as before): +// overhead + g_br * per_gbr + g_br² * per_gbr2 + Bc * per_bc + g_br * Bc * per_gbr_bc +// +// Cost model (minimization objective): +// Q * (c_q_fixed + K * c_iter_fixed), where Q = ceil(qo/Br), K = ceil(kv/Bc) +static int hmx_fa_find_chunk_size(size_t * Br_out, + size_t * Bc_out, + size_t gqa_factor, + size_t DK, + size_t DV, + size_t qo_len, + size_t kv_len, + size_t vtcm_budget, + size_t n_threads) { + const size_t T = HMX_FP16_TILE_N_ROWS; // 32 + const size_t br_unit = hmx_ceil_div(T, gqa_factor); + // Bc must be a multiple of 64 so that n_tiles_per_bc is even. The softmax + // P-tile write uses a dual-tile pattern (vshuff + two stores 16 slots apart) + // that would race across r0 blocks if the last dual-tile is half-occupied. + // See .cursor/todos/hmx-flash-attn-bc-search-space.md for the perf trade-off. + const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64 + const size_t fp16 = sizeof(__fp16); + + // Approximate per-unit VTCM costs (without per-buffer alignment padding). + const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * fp16; // Q + O×2 + 4 col vectors + const size_t per_gbr2 = fp16; // D diagonal matrix + const size_t per_bc = + 3 * (DK + DV) * fp16 + 2 * n_threads * fp16; // K_dma×2 + V_dma×2 + K_tile + V_tile + row bufs + const size_t per_gbr_bc = 2 * fp16; // S + P + + const size_t overhead = 256 * 2 + 13 * 4096; + + if (vtcm_budget <= overhead) { + return -1; + } + const size_t usable = vtcm_budget - overhead; + + // Br_max: largest Br aligned to br_unit that does not exceed qo_len. + const size_t Br_max = qo_len >= br_unit ? hex_align_down(qo_len, br_unit) : br_unit; + + // Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS. + // Only relax when kv_len is too short to form enough blocks. + const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); + const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) : + (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit); + // Cost coefficients calibrated from profiling + const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store + const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers + + size_t best_cost = SIZE_MAX, best_mn = 0; + size_t best_Br = 0, best_Bc = 0; + + for (size_t Br = Br_max; Br >= br_unit; Br -= br_unit) { + const size_t g_br = hex_align_up(gqa_factor * Br, T); + + // g_br-dependent VTCM cost: g_br * per_gbr + g_br² * per_gbr2 + const size_t gbr_cost = g_br * per_gbr + g_br * g_br * per_gbr2; + if (gbr_cost >= usable) { + if (Br == br_unit) { + break; + } + continue; + } + + // Analytically solve for max Bc: + // remain >= Bc * (per_bc + g_br * per_gbr_bc + Br * fp16_mask) + // The Br * fp16 term accounts for the VTCM mask buffer [Br × Bc]. + const size_t remain = usable - gbr_cost; + const size_t bc_denom = per_bc + g_br * per_gbr_bc + Br * fp16; + size_t Bc = hex_smin(hex_align_down(remain / bc_denom, bc_unit), Bc_limit); + if (Bc < bc_unit) { + if (Br == br_unit) { + break; + } + continue; + } + + // Exact VTCM verification (alignment padding may push over budget) + while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads) > vtcm_budget) { + Bc -= bc_unit; + } + if (Bc < bc_unit) { + if (Br == br_unit) { + break; + } + continue; + } + + const size_t q_blocks = (qo_len + Br - 1) / Br; + const size_t kv_blocks = (kv_len + Bc - 1) / Bc; + const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed); + const size_t mn = Br * Bc; + + if (cost < best_cost || (cost == best_cost && mn > best_mn)) { + best_cost = cost; + best_mn = mn; + best_Br = Br; + best_Bc = Bc; + } + + if (Br == br_unit) { + break; + } + } + + if (best_Br == 0) { + return -1; + } + + *Br_out = best_Br; + *Bc_out = best_Bc; + return 0; +} + +// ============================================================================ +// Tile interleave / extract helpers +// ============================================================================ + +// transpose scatter offsets moved to hmx-utils.h as hmx_transpose_scatter_offsets + +// Scatter offsets for diagonal tile: entry[2i] = i*136, entry[2i+1] = i*136+6 +// 136 = 4 * 32 + 8 = byte offset to diagonal in a 32x32 fp16 interleaved tile +static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = { + 0 * 136, 0 * 136 + 6, + 1 * 136, 1 * 136 + 6, + 2 * 136, 2 * 136 + 6, + 3 * 136, 3 * 136 + 6, + 4 * 136, 4 * 136 + 6, + 5 * 136, 5 * 136 + 6, + 6 * 136, 6 * 136 + 6, + 7 * 136, 7 * 136 + 6, + 8 * 136, 8 * 136 + 6, + 9 * 136, 9 * 136 + 6, + 10 * 136, 10 * 136 + 6, + 11 * 136, 11 * 136 + 6, + 12 * 136, 12 * 136 + 6, + 13 * 136, 13 * 136 + 6, + 14 * 136, 14 * 136 + 6, + 15 * 136, 15 * 136 + 6, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, +}; + +// hmx_interleave_rows_to_tiles and hmx_interleave_cols_to_tiles are in hmx-utils.h + +// ============================================================================ +// HMX Flash Attention context (GQA-merged) +// ============================================================================ + +struct hmx_fa_context { + const struct htp_ops_context * octx; + bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2 + uint32_t n_threads; + + // Op parameters + float scale; + float max_bias; + float logit_softcap; + uint32_t n_head_log2; + float m0, m1; + + // Dimensions + uint32_t DK, DV; + uint32_t n_kv; // kv_len + uint32_t n_kv_heads; // number of KV heads + uint32_t n_heads; // number of Q heads + uint32_t G; // GQA factor = n_heads / n_kv_heads + uint32_t n_kv_blocks; + uint32_t neq1; // Q token count + + // Types + bool is_q_fp32; + bool is_dst_fp32; + + // Dynamic block sizes + uint32_t Br; // Q tokens per block (before GQA expansion) + uint32_t Bc; + uint32_t g_br; // hex_align_up(G * Br, 32) - actual tile row dim + + // VTCM buffers (allocated by vtcm_seq_alloc) + __fp16 * vtcm_q_tiles; // Q tile format [g_br, D] + __fp16 * vtcm_o_tiles[2]; // O ping-pong [g_br, D] + __fp16 * vtcm_k_fp16[2]; // K DMA double-buffer [Bc, D] + __fp16 * vtcm_v_fp16[2]; // V DMA double-buffer [Bc, D] + __fp16 * vtcm_k_tiles; // K tiles (transposed) + __fp16 * vtcm_v_tiles; // V tiles (column-major) + __fp16 * vtcm_s_tiles; // S = QK^T [g_br, Bc] + __fp16 * vtcm_p_tiles; // P = softmax(S) [g_br, Bc] + __fp16 * vtcm_d_tiles; // Diagonal rescale [g_br, g_br] + HVX_Vector * vtcm_m_vec; // Row max [g_br] + HVX_Vector * vtcm_l_vec; // Row sum [g_br] + HVX_Vector * vtcm_s_rowmax; // Softmax intermediate [g_br] + HVX_Vector * vtcm_p_rowsum; // Softmax intermediate [g_br] + HVX_Vector * vtcm_row_bufs; // Per-thread softmax row scratch [n_threads][2][Bc/64] + uint8_t * vtcm_hmx_scales_id; // HMX output scales (identity) + uint8_t * vtcm_hmx_scales_qk; // HMX output scales (qk_scale) + __fp16 * vtcm_mask_buf; // VTCM mask buffer [Br × m_line], DMA'd per KV block + __fp16 * vtcm_slopes; // ALiBi slopes [g_br] + size_t row_buf_stride; // HVX vectors per row buffer (Bc/64) + size_t mask_buf_row_stride; // elements (__fp16) per row in mask buffer + bool mask_broadcast; // true when mask->ne[2] == 1 (head-independent, single 2D DMA) +}; + +// ============================================================================ +// Multi-thread K interleave phase +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + int kv_rows; + size_t src_stride; + size_t buf_idx; +} fa_k_int_args_t; + +static void fa_k_interleave_thread(unsigned int n, unsigned int i, void * data) { + fa_k_int_args_t * args = (fa_k_int_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const int total_rows = args->kv_rows; + const int rows_per_t = hex_align_up(hmx_ceil_div(total_rows, n), 2); // ensure even (row pairs) + const int start = i * rows_per_t; + const int end = hex_smin(start + rows_per_t, total_rows); + + if (start >= total_rows) { + return; + } + + hmx_interleave_rows_to_tiles(factx->vtcm_k_tiles, factx->vtcm_k_fp16[args->buf_idx], total_rows, (int) factx->DK, + (int) args->src_stride, start, end); +} + +static void fa_phase_k_interleave(struct hmx_fa_context * factx, int kv_rows, size_t src_stride, size_t buf_idx) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_k_int_args_t args = { factx, kv_rows, src_stride, buf_idx }; + if (factx->n_threads > 1 && kv_rows >= (int) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_k_interleave_thread, &args, factx->n_threads); + } else { + fa_k_interleave_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread V interleave phase +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + int kv_rows; + size_t src_stride; + size_t buf_idx; + size_t n_col_tiles; +} fa_v_int_args_t; + +static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) { + fa_v_int_args_t * args = (fa_v_int_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const int total_rows = args->kv_rows; + const int rows_per_t = hex_align_up(hmx_ceil_div(total_rows, n), 2); + const int start = i * rows_per_t; + const int end = hex_smin(start + rows_per_t, total_rows); + + if (start >= total_rows) { + return; + } + + hmx_interleave_cols_to_tiles(factx->vtcm_v_tiles, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, + (int) args->src_stride, (int) args->n_col_tiles, start, end); +} + +static void fa_phase_v_interleave(struct hmx_fa_context * factx, + int kv_rows, + size_t src_stride, + size_t buf_idx, + size_t n_col_tiles) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_v_int_args_t args = { factx, kv_rows, src_stride, buf_idx, n_col_tiles }; + if (factx->n_threads > 1 && kv_rows >= (int) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_v_interleave_thread, &args, factx->n_threads); + } else { + fa_v_interleave_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread Q load phase: read Q[G × neq1, DK] from DDR, convert F32→F16 +// (or deal F16 pairs), and write interleaved into vtcm_q_tiles. +// Each thread owns a disjoint range of row pairs; writes target distinct tile +// slots (r0 selects tile row, r1 selects intra-tile slot), so there is no +// write conflict. Padding fill (when n_rows_g < g_br) is done single-threaded +// by the caller before dispatching. +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + const struct htp_tensor * q; + uint32_t q_start; + uint32_t kv_head; + uint32_t ib3; + size_t n_rows_g; +} fa_q_load_args_t; + +static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) { + fa_q_load_args_t * args = (fa_q_load_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t G = factx->G; + const size_t DK = factx->DK; + + // Partition row pairs across threads. Keep each thread's start even so r/r+1 + // are always in the same thread's range. + const size_t rows_per_t = hex_align_up(hmx_ceil_div(n_rows_g, n), 2); + const size_t start = (size_t) i * rows_per_t; + const size_t end = hex_smin(start + rows_per_t, n_rows_g); + + if (start >= n_rows_g) { + return; + } + + const struct htp_tensor * q = args->q; + const uint32_t q_start = args->q_start; + const uint32_t kv_head = args->kv_head; + const uint32_t ib3 = args->ib3; + + for (size_t r = start; r < end; r += 2) { + const bool next_row_valid = (r + 1) < n_rows_g; + + const size_t q_idx0 = (r + 0) / G; + const size_t h_idx0 = (r + 0) % G; + const size_t q_idx1 = (r + 1) / G; + const size_t h_idx1 = (r + 1) % G; + + const uint8_t * q_ptr0 = (const uint8_t *) q->data + (q_start + q_idx0) * q->nb[1] + + (kv_head * G + h_idx0) * q->nb[2] + ib3 * q->nb[3]; + const uint8_t * q_ptr1 = next_row_valid ? ((const uint8_t *) q->data + (q_start + q_idx1) * q->nb[1] + + (kv_head * G + h_idx1) * q->nb[2] + ib3 * q->nb[3]) : + NULL; + + size_t r0 = r / HMX_FP16_TILE_N_ROWS; + size_t r1 = r % HMX_FP16_TILE_N_ROWS; + __fp16 * out_base = factx->vtcm_q_tiles + r0 * HMX_FP16_TILE_N_ROWS * DK; + + if (factx->is_q_fp32) { + const HVX_Vector * pv_in0 = (const HVX_Vector *) q_ptr0; + const HVX_Vector * pv_in1 = q_ptr1 ? (const HVX_Vector *) q_ptr1 : NULL; + + for (uint32_t d = 0; d < DK / 32; ++d) { + HVX_Vector v0 = pv_in0[d]; + HVX_Vector v1 = pv_in1 ? pv_in1[d] : Q6_V_vzero(); + HVX_Vector v_hf = hvx_vec_f32_to_f16_shuff(v0, v1); + + HVX_Vector * out_tile = (HVX_Vector *) (out_base + d * HMX_FP16_TILE_N_ELMS); + out_tile[r1 / 2] = v_hf; + } + } else { + const HVX_Vector * pv_in0 = (const HVX_Vector *) q_ptr0; + const HVX_Vector * pv_in1 = q_ptr1 ? (const HVX_Vector *) q_ptr1 : NULL; + + for (uint32_t d = 0; d < DK / 64; ++d) { + HVX_Vector v0 = pv_in0[d]; + HVX_Vector v1 = pv_in1 ? pv_in1[d] : Q6_V_vzero(); + HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2); + + __fp16 * out_dual_tile = out_base + d * HMX_FP16_TILE_N_ELMS * 2; + HVX_Vector * pv_out0 = ((HVX_Vector *) out_dual_tile) + r1 / 2; + HVX_Vector * pv_out1 = pv_out0 + 16; + + *pv_out0 = Q6_V_lo_W(vp); + *pv_out1 = Q6_V_hi_W(vp); + } + } + } +} + +static void fa_phase_q_load(struct hmx_fa_context * factx, + const struct htp_tensor * q, + uint32_t q_start, + uint32_t kv_head, + uint32_t ib3, + size_t n_rows_g) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_q_load_args_t args = { factx, q, q_start, kv_head, ib3, n_rows_g }; + // Require >= 2 row pairs per thread so partitioning is worthwhile. + if (factx->n_threads > 1 && n_rows_g >= (size_t) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_q_load_thread, &args, factx->n_threads); + } else { + fa_q_load_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread O store phase: read O tiles from VTCM, convert F16->F32 (or +// deal F16 pairs), and write to strided DDR dst tensor. Each thread owns a +// disjoint row range; writes target distinct dst rows (different q_idx/h_idx +// pairs produced by r/G and r%G), so there is no write conflict. +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + const struct htp_tensor * dst; + const __fp16 * o_tile_src; + uint32_t q_start; + uint32_t kv_head; + uint32_t ib3; + size_t n_rows_g; +} fa_o_store_args_t; + +static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) { + fa_o_store_args_t * args = (fa_o_store_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t G = factx->G; + const size_t DV = factx->DV; + + const size_t rows_per_t = hmx_ceil_div(n_rows_g, n); + const size_t start = (size_t) i * rows_per_t; + const size_t end = hex_smin(start + rows_per_t, n_rows_g); + + if (start >= n_rows_g) { + return; + } + + const struct htp_tensor * dst = args->dst; + const __fp16 * o_tile_src = args->o_tile_src; + const uint32_t q_start = args->q_start; + const uint32_t kv_head = args->kv_head; + const uint32_t ib3 = args->ib3; + + for (size_t r = start; r < end; ++r) { + const size_t q_idx = r / G; + const size_t h_idx = r % G; + + // FIX(dst-indexing): ggml_flash_attn_ext() creates dst as permute(0,2,1,3) -> + // [DV, n_heads, n_tokens, n_seq], so head stride is nb[1] and token stride is nb[2]. + uint8_t * dst_row = (uint8_t *) dst->data + (kv_head * G + h_idx) * dst->nb[1] + + (q_start + q_idx) * dst->nb[2] + ib3 * dst->nb[3]; + + size_t r0 = r / HMX_FP16_TILE_N_ROWS; + size_t r1 = r % HMX_FP16_TILE_N_ROWS; + const __fp16 * tile_row_base = o_tile_src + r0 * HMX_FP16_TILE_N_ROWS * DV; + + if (factx->is_dst_fp32) { + float * out = (float *) dst_row; + for (uint32_t d = 0; d < DV / 32; ++d) { + const HVX_Vector * in_tile = (const HVX_Vector *) (tile_row_base + d * HMX_FP16_TILE_N_ELMS); + HVX_VectorPair vp = hvx_vec_f16_to_f32_shuff(in_tile[r1 / 2]); + if (r1 % 2 == 0) { + *(HVX_UVector *) (out + d * 32) = Q6_V_lo_W(vp); + } else { + *(HVX_UVector *) (out + d * 32) = Q6_V_hi_W(vp); + } + } + } else { + __fp16 * out = (__fp16 *) dst_row; + for (uint32_t d = 0; d < DV / 64; ++d) { + const __fp16 * in_dual_tile = tile_row_base + d * HMX_FP16_TILE_N_ELMS * 2; + const HVX_Vector * pv_in0 = ((const HVX_Vector *) in_dual_tile) + r1 / 2; + const HVX_Vector * pv_in1 = pv_in0 + 16; + HVX_VectorPair vp = Q6_W_vdeal_VVR(*pv_in1, *pv_in0, -2); + if (r1 % 2 == 0) { + *(HVX_UVector *) (out + d * 64) = Q6_V_lo_W(vp); + } else { + *(HVX_UVector *) (out + d * 64) = Q6_V_hi_W(vp); + } + } + } + } +} + +static void fa_phase_o_store(struct hmx_fa_context * factx, + const struct htp_tensor * dst, + const __fp16 * o_tile_src, + uint32_t q_start, + uint32_t kv_head, + uint32_t ib3, + size_t n_rows_g) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_o_store_args_t args = { factx, dst, o_tile_src, q_start, kv_head, ib3, n_rows_g }; + if (factx->n_threads > 1 && n_rows_g >= (size_t) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_o_store_thread, &args, factx->n_threads); + } else { + fa_o_store_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread softmax phase + serial m/l update + build_D +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + size_t kv_rows; + size_t n_rows_g; + size_t n_col_tiles; + size_t n_tiles_per_bc; + size_t n_row_tiles; + size_t n_row_tiles_g_br; + uint32_t Bc; + uint32_t G; + uint32_t kv_head; + uint32_t kv_start; + uint32_t q_start; + uint32_t ib3; + bool has_alibi; // true when max_bias != 0 (need slope * mask + add) + + // ALiBi per-head slopes (indexed by GQA-merged row: slope[r] for r in [0, n_rows_g)) + // slope[r] = 1.0 when max_bias == 0 (no ALiBi) + // Pointer into hmx_fa_context.vtcm_slopes (sized to g_br) + __fp16 * slopes; + + // Mask info (preloaded before softmax) + const struct htp_tensor * mask; + const __fp16 * mask_vtcm; // VTCM mask buffer base (NULL = DDR fallback) + size_t mask_vtcm_row_stride; // elements (__fp16) per row in VTCM mask buffer +} fa_softmax_args_t; + +static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { + fa_softmax_args_t * args = (fa_softmax_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t kv_rows = args->kv_rows; + const size_t Bc = args->Bc; + const size_t G = args->G; + const size_t n_tiles_per_bc = args->n_tiles_per_bc; + const size_t n_row_vec_cnt = hmx_ceil_div(n_rows_g, 64); + + // Partition r_vec_idx across threads + const size_t vecs_per_t = hmx_ceil_div(n_row_vec_cnt, n); + const size_t vec_start = i * vecs_per_t; + const size_t vec_end = hex_smin(vec_start + vecs_per_t, n_row_vec_cnt); + + if (vec_start >= n_row_vec_cnt) { + return; + } + + // Per-thread row scratch: thread i uses bufs at offset i * 2 * stride + const size_t row_buf_stride = factx->row_buf_stride; + HVX_Vector * my_row_buf0 = factx->vtcm_row_bufs + i * 2 * row_buf_stride; + HVX_Vector * my_row_buf1 = my_row_buf0 + row_buf_stride; + + const HVX_Vector v_neg_inf = Q6_Vh_vsplat_R(0xfbff); + + // Per-row accumulators: each fp16 lane in a 64-lane vector holds one row's scalar. + // CONTRACT: lane bits must be IEEE fp16 (hf), never qf16 — qf16 uses a different + // bit layout, so a later hf-domain read would silently produce wrong values. + // Convert first via Q6_Vhf_equals_Vqf16(). For reference: vtcm_m_vec/vtcm_s_rowmax + // are hf; vtcm_l_vec is qf16 — don't mix them up. + + for (size_t r_vec_idx = vec_start; r_vec_idx < vec_end; ++r_vec_idx) { + HVX_Vector rowmax_acc_v = v_neg_inf; + HVX_Vector rowsum_acc_v = Q6_V_vzero(); + HVX_Vector m_prev_v = factx->vtcm_m_vec[r_vec_idx]; + + for (int r_vec_off = 0; r_vec_off < 64; r_vec_off += 2) { + int r = r_vec_idx * 64 + r_vec_off; + if (r >= (int) hex_align_up(n_rows_g, 2)) { + break; + } + + int r0 = r / HMX_FP16_TILE_N_ROWS; + int r1 = r % HMX_FP16_TILE_N_ROWS; + + const __fp16 * s_ld_base = factx->vtcm_s_tiles + r0 * HMX_FP16_TILE_N_ROWS * Bc; + __fp16 * p_st_base = factx->vtcm_p_tiles + r0 * HMX_FP16_TILE_N_ROWS * Bc; + + // Decode 2 rows from S tiles into per-thread row buffers + HVX_Vector * pv_row_buf0 = my_row_buf0; + HVX_Vector * pv_row_buf1 = my_row_buf1; + for (size_t c = 0; c < kv_rows; c += 64) { + const __fp16 * in_dual_tile = s_ld_base + (c / 64) * HMX_FP16_TILE_N_ELMS * 2; + const HVX_Vector * pv_s_in0 = ((const HVX_Vector *) in_dual_tile) + r1 / 2; + const HVX_Vector * pv_s_in1 = pv_s_in0 + 16; + + HVX_VectorPair vp_s_dual_row = Q6_W_vdeal_VVR(*pv_s_in1, *pv_s_in0, -2); + *pv_row_buf0++ = Q6_V_lo_W(vp_s_dual_row); + *pv_row_buf1++ = Q6_V_hi_W(vp_s_dual_row); + } + + // Apply softcap if enabled (in F32 precision) + if (factx->logit_softcap != 0.0f) { + // When EXP2_HF is on, fold log2(e) into v_cap so the output lands in + // log2(e)-scaled space for the downstream exp2. log2(e) is kept OUT + // of qk_scale in this configuration (see scale setup) so tanh sees + // the physical QK/(√d·c) argument. + float cap = factx->logit_softcap; +#ifdef HMX_FA_USE_EXP2_HF + cap *= 1.44269504f; // log2(e) +#endif + const HVX_Vector v_cap = hvx_vec_splat_f32(cap); + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + + HVX_VectorPair r0_f32 = hvx_vec_f16_to_f32(my_row_buf0[ci]); + HVX_Vector t0_lo = hvx_vec_tanh_f32(Q6_V_lo_W(r0_f32)); + HVX_Vector t0_hi = hvx_vec_tanh_f32(Q6_V_hi_W(r0_f32)); + t0_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t0_lo, v_cap)); + t0_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t0_hi, v_cap)); + my_row_buf0[ci] = hvx_vec_f32_to_f16(t0_lo, t0_hi); + + HVX_VectorPair r1_f32 = hvx_vec_f16_to_f32(my_row_buf1[ci]); + HVX_Vector t1_lo = hvx_vec_tanh_f32(Q6_V_lo_W(r1_f32)); + HVX_Vector t1_hi = hvx_vec_tanh_f32(Q6_V_hi_W(r1_f32)); + t1_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t1_lo, v_cap)); + t1_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t1_hi, v_cap)); + my_row_buf1[ci] = hvx_vec_f32_to_f16(t1_lo, t1_hi); + } + } + + // Apply mask & compute rowmax(S) + // + // Optimizations over baseline: + // A. No-ALiBi fast path: when max_bias==0 (slope≡1.0), skip the + // slope multiplication — still add mask (additive bias) but + // avoid the mul_f16_f16. Saves 2 ops/dual-row vs ALiBi path. + // B. GQA mask row dedup: G consecutive Q rows share one mask row + // (qi = r / G). Reuse mask vector when qi is unchanged between + // row0 and row1 (saves ~75% of VTCM loads for G=4). + + // ALiBi slopes — only needed when has_alibi (scheme A) + HVX_Vector v_slope0, v_slope1; + if (args->has_alibi) { + v_slope0 = hvx_vec_splat_f16(args->slopes[r + 0]); + v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_splat_f16(args->slopes[r + 1]) : Q6_V_vzero(); + } + + const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c) + + HVX_Vector v_s_rowmax0 = v_neg_inf; + HVX_Vector v_s_rowmax1 = v_neg_inf; + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + const size_t ne = hex_smin(kv_rows - c, 64); + HVX_VectorPred q_tail_keep = Q6_Q_vsetq2_R(ne * sizeof(__fp16)); + + if (args->mask) { + HVX_Vector v_mask0, v_mask1; + + if (args->mask_vtcm) { + // Read mask from VTCM buffer (DMA'd per KV block). + // GQA dedup (scheme B): skip load when qi unchanged. + const size_t qi0 = (r + 0) / G; + v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c); + v_mask1 = v_neg_inf; + if (r + 1 < (int) n_rows_g) { + const size_t qi1 = (r + 1) / G; + if (qi1 == qi0) { + v_mask1 = v_mask0; // scheme B: reuse — same mask row + } else { + v_mask1 = *(const HVX_UVector *) (args->mask_vtcm + qi1 * args->mask_vtcm_row_stride + c); + } + } + } else { + // Fallback: read mask directly from DDR (when mask->ne[2] > 1). + const struct htp_tensor * mask = args->mask; + const size_t q_idx0 = args->q_start + ((r + 0) / G); + const size_t h_idx0 = args->kv_head * G + (r + 0) % G; + const uint32_t im2_0 = h_idx0 % mask->ne[2]; + const uint32_t im3_0 = args->ib3 % mask->ne[3]; + + const __fp16 * m0_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx0 * mask->nb[1] + + im2_0 * mask->nb[2] + im3_0 * mask->nb[3]) + args->kv_start + c; + v_mask0 = *(const HVX_UVector *) m0_ptr; + v_mask1 = v_neg_inf; + + if (r + 1 < (int) n_rows_g) { + const size_t q_idx1 = args->q_start + ((r + 1) / G); + if (q_idx1 == q_idx0) { + // scheme B: same mask row in DDR path + v_mask1 = v_mask0; + } else { + const size_t h_idx1 = args->kv_head * G + (r + 1) % G; + const uint32_t im2_1 = h_idx1 % mask->ne[2]; + const uint32_t im3_1 = args->ib3 % mask->ne[3]; + const __fp16 * m1_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx1 * mask->nb[1] + + im2_1 * mask->nb[2] + im3_1 * mask->nb[3]) + args->kv_start + c; + v_mask1 = *(const HVX_UVector *) m1_ptr; + } + } + } + + // Threshold: mask values below -16.0 are treated as -inf (causal mask). + HVX_VectorPred q_keep0 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask0, v_threshold), q_tail_keep); + HVX_VectorPred q_keep1 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask1, v_threshold), q_tail_keep); + + if (args->has_alibi) { + // ALiBi path: S += slope * mask (full mul + add) + HVX_Vector v_sm0 = hvx_vec_mul_f16_f16(v_mask0, v_slope0); + HVX_Vector v_sm1 = hvx_vec_mul_f16_f16(v_mask1, v_slope1); + my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, hvx_vec_add_f16_f16(my_row_buf0[ci], v_sm0), v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, hvx_vec_add_f16_f16(my_row_buf1[ci], v_sm1), v_neg_inf); + } else { + // No-ALiBi fast path (scheme A): slope≡1.0, skip the mul + // but still add mask (additive positional bias). vmux + // clamps mask < -16 to -inf as a numerical safeguard. + my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, hvx_vec_add_f16_f16(my_row_buf0[ci], v_mask0), v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, hvx_vec_add_f16_f16(my_row_buf1[ci], v_mask1), v_neg_inf); + } + } else { + if (ne < 64) { + my_row_buf0[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf0[ci], v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf1[ci], v_neg_inf); + } + } + + v_s_rowmax0 = Q6_Vhf_vmax_VhfVhf(v_s_rowmax0, my_row_buf0[ci]); + v_s_rowmax1 = Q6_Vhf_vmax_VhfVhf(v_s_rowmax1, my_row_buf1[ci]); + } + + v_s_rowmax0 = hvx_vec_reduce_max_f16(v_s_rowmax0); + v_s_rowmax1 = hvx_vec_reduce_max_f16(v_s_rowmax1); + + // Splat m_prev[r], m_prev[r+1] from the per-row accumulator. + // vror brings the target lane to lane 0, then extract + re-splat. + HVX_Vector v_m_prev0 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2))); + HVX_Vector v_m_prev1 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2))); + + // HVX max — both operands are splats, so result is splat of m_new. + HVX_Vector v_dup_m0 = Q6_Vhf_vmax_VhfVhf(v_m_prev0, v_s_rowmax0); + HVX_Vector v_dup_m1 = Q6_Vhf_vmax_VhfVhf(v_m_prev1, v_s_rowmax1); + + // Insert row r, r+1 rowmax into rowmax_acc_v via 2-byte-wide vmux. + // Byte ranges: lane0 = [r_vec_off*2 .. r_vec_off*2+1], lane1 shifted by 2. + // vsetq2 handles the n=128 corner case when r_vec_off reaches 62. + { + HVX_VectorPred p_start = Q6_Q_vsetq_R(r_vec_off * 2); + HVX_VectorPred p_mid = Q6_Q_vsetq_R((r_vec_off + 1) * 2); + HVX_VectorPred p_end = Q6_Q_vsetq2_R((r_vec_off + 2) * 2); + HVX_VectorPred p_lane0 = Q6_Q_and_QQn(p_mid, p_start); + HVX_VectorPred p_lane1 = Q6_Q_and_QQn(p_end, p_mid); + rowmax_acc_v = Q6_V_vmux_QVV(p_lane0, v_dup_m0, rowmax_acc_v); + rowmax_acc_v = Q6_V_vmux_QVV(p_lane1, v_dup_m1, rowmax_acc_v); + } + + // Compute P = exp(S - m_new), using HVX exp + const HVX_Vector v_zero = Q6_V_vzero(); + HVX_Vector v_p_rowsum0 = v_zero; + HVX_Vector v_p_rowsum1 = v_zero; + +#ifdef HMX_FA_USE_EXP2_HF + // FP16 exp2 polynomial path (matches htp-ops-lib flash_attn.c): + // P = exp2(S - m_new) + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + HVX_Vector v_s_minus_m0 = Q6_Vqf16_vsub_VhfVhf(my_row_buf0[ci], v_dup_m0); + HVX_Vector v_s_minus_m1 = Q6_Vqf16_vsub_VhfVhf(my_row_buf1[ci], v_dup_m1); + + HVX_Vector v_p_row0_hf = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_s_minus_m0)); + HVX_Vector v_p_row1_hf = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_s_minus_m1)); +#else + // F32 exp path: qf16 → f32 → exp → f32 → f16. Higher precision, + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + HVX_Vector v_s_minus_m0 = Q6_Vqf16_vsub_VhfVhf(my_row_buf0[ci], v_dup_m0); + HVX_Vector v_s_minus_m1 = Q6_Vqf16_vsub_VhfVhf(my_row_buf1[ci], v_dup_m1); + + HVX_VectorPair vp0 = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_s_minus_m0)); + HVX_Vector p0_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp0)); + HVX_Vector p0_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp0)); + HVX_Vector v_p_row0_hf = hvx_vec_f32_to_f16_shuff(p0_lo, p0_hi); + + HVX_VectorPair vp1 = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_s_minus_m1)); + HVX_Vector p1_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp1)); + HVX_Vector p1_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp1)); + HVX_Vector v_p_row1_hf = hvx_vec_f32_to_f16_shuff(p1_lo, p1_hi); +#endif + // Write P to tile format. Dual-tile pattern assumes Bc is a + // multiple of 64 (enforced by bc_unit=64 in hmx_fa_find_chunk_size), + // so both tile halves are always in the current r0 block. + __fp16 * out_dual_tile = p_st_base + (c / 64) * HMX_FP16_TILE_N_ELMS * 2; + HVX_Vector * pv_p_out0 = ((HVX_Vector *) out_dual_tile) + r1 / 2; + HVX_Vector * pv_p_out1 = pv_p_out0 + 16; + + HVX_VectorPair vp_p_dual = Q6_W_vshuff_VVR(v_p_row1_hf, v_p_row0_hf, -2); + *pv_p_out0 = Q6_V_lo_W(vp_p_dual); + *pv_p_out1 = Q6_V_hi_W(vp_p_dual); + + HVX_VectorPair vp_p0 = hvx_vec_f16_to_f32_shuff(v_p_row0_hf); + HVX_VectorPair vp_p1 = hvx_vec_f16_to_f32_shuff(v_p_row1_hf); + + v_p_rowsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(v_p_rowsum0, Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(vp_p0), Q6_V_hi_W(vp_p0))); + v_p_rowsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(v_p_rowsum1, Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(vp_p1), Q6_V_hi_W(vp_p1))); + } + + HVX_Vector rowsum0_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(v_p_rowsum0)); + HVX_Vector rowsum1_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(v_p_rowsum1)); + { + // Both inputs are f32 splats, so the f32->f16 output is an fp16 splat. + HVX_Vector rv0_v = hvx_vec_f32_to_f16(rowsum0_sf, rowsum0_sf); + HVX_Vector rv1_v = hvx_vec_f32_to_f16(rowsum1_sf, rowsum1_sf); + + HVX_VectorPred p_start = Q6_Q_vsetq_R(r_vec_off * 2); + HVX_VectorPred p_mid = Q6_Q_vsetq_R((r_vec_off + 1) * 2); + HVX_VectorPred p_end = Q6_Q_vsetq2_R((r_vec_off + 2) * 2); + HVX_VectorPred p_lane0 = Q6_Q_and_QQn(p_mid, p_start); + HVX_VectorPred p_lane1 = Q6_Q_and_QQn(p_end, p_mid); + rowsum_acc_v = Q6_V_vmux_QVV(p_lane0, rv0_v, rowsum_acc_v); + rowsum_acc_v = Q6_V_vmux_QVV(p_lane1, rv1_v, rowsum_acc_v); + } + } + + factx->vtcm_s_rowmax[r_vec_idx] = rowmax_acc_v; + factx->vtcm_p_rowsum[r_vec_idx] = rowsum_acc_v; + } +} + +// Serial m/l update + build_D. Must run after softmax barrier (s_rowmax written by all threads). +// +// noinline: function boundary acts as a hard compiler barrier so the (size_t)addr scatter +// intrinsics inside cannot be hoisted past the call site. Mirrors the structural protection +// matmul gets for free via worker_pool function-pointer dispatch. Without this, the compiler +// can reorder the scatter past the subsequent hmx_queue_push and the HMX-queue worker thread +// reads stale VTCM (PPL → ~vocab-size). +static __attribute__((noinline)) void fa_ml_update_and_build_d(struct hmx_fa_context * factx, + size_t n_rows_g, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + // Reuse s_rowmax buffer for exp(m_diff) — safe because softmax is fully complete + HVX_Vector * const mvec_exp_m_diff = factx->vtcm_s_rowmax; + + const size_t n_row_vec_cnt = hmx_ceil_div(n_rows_g, 64); + for (size_t i = 0; i < n_row_vec_cnt; ++i) { + HVX_Vector v_m_prev = factx->vtcm_m_vec[i]; + HVX_Vector v_m_curr = Q6_Vhf_vmax_VhfVhf(v_m_prev, factx->vtcm_s_rowmax[i]); + HVX_Vector v_m_diff = Q6_Vqf16_vsub_VhfVhf(v_m_prev, v_m_curr); + +#ifdef HMX_FA_USE_EXP2_HF + // Base-2 path: must match P = exp2(S - m_new) in fa_softmax_thread. + HVX_Vector v_exp_m_diff = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_m_diff)); +#else + HVX_VectorPair vp_diff = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_m_diff)); + HVX_Vector exp_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp_diff)); + HVX_Vector exp_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp_diff)); + HVX_Vector v_exp_m_diff = hvx_vec_f32_to_f16_shuff(exp_lo, exp_hi); +#endif + + HVX_Vector v_l_curr = Q6_Vqf16_vmpy_Vqf16Vhf(factx->vtcm_l_vec[i], v_exp_m_diff); + v_l_curr = Q6_Vqf16_vadd_Vqf16Vhf(v_l_curr, factx->vtcm_p_rowsum[i]); + + factx->vtcm_m_vec[i] = v_m_curr; + factx->vtcm_l_vec[i] = v_l_curr; + mvec_exp_m_diff[i] = v_exp_m_diff; + } + + // Build diagonal tile D = diag(exp(m_diff)) + const HVX_Vector v_offsets = *(const HVX_Vector *) d_tile_scatter_offsets; + const HVX_VectorPred q_32_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16)); + for (size_t i = 0; i < n_row_tiles; ++i) { + const HVX_Vector v_content = Q6_V_vror_VR(mvec_exp_m_diff[i / 2], (i % 2) * 64); + __fp16 * out_base = factx->vtcm_d_tiles + i * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + Q6_vscatter_QRMVhV(q_32_mask, (size_t) out_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_content); + // Compiler barrier — Q6_vscatter takes (size_t)addr; without this the + // compiler may not recognize the volatile read below as aliasing and + // could reorder it before the scatter, defeating the HW drain. + __asm__ __volatile__("" ::: "memory"); + // Per-tile drain: scatter regions are disjoint (stride > tile size), + // so a single drain at tile 0 does NOT retire later tiles' entries. + (void) *(volatile HVX_Vector *) out_base; + } +} + +// Build D = diag(1/l) tile for the final O = D @ O normalization. +// +// noinline: same rationale as fa_ml_update_and_build_d — keeps Q6_vscatter from +// being hoisted past the subsequent hmx_queue_push at the o_norm call site. +static __attribute__((noinline)) void fa_build_d_diag_inv_l(struct hmx_fa_context * factx, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + const HVX_Vector v_offsets = *(const HVX_Vector *) d_tile_scatter_offsets; + const HVX_VectorPred q_32_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16)); + const HVX_Vector one = hvx_vec_splat_f32(1.0f); + + HVX_Vector v_content = Q6_V_vzero(); + for (size_t i = 0; i < n_row_tiles; ++i) { + if ((i % 2) == 0) { + HVX_Vector v_l_hf = Q6_Vhf_equals_Vqf16(factx->vtcm_l_vec[i / 2]); + HVX_VectorPair vp_l = hvx_vec_f16_to_f32_shuff(v_l_hf); + HVX_Vector inv_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(one, hvx_vec_inverse_f32(Q6_V_lo_W(vp_l)))); + HVX_Vector inv_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(one, hvx_vec_inverse_f32(Q6_V_hi_W(vp_l)))); + v_content = hvx_vec_f32_to_f16_shuff(inv_lo, inv_hi); + } else { + v_content = Q6_V_vror_VR(v_content, 64); + } + + __fp16 * out_base = factx->vtcm_d_tiles + i * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + Q6_vscatter_QRMVhV(q_32_mask, (size_t) out_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_content); + // Compiler barrier — see fa_ml_update_and_build_d for rationale. + __asm__ __volatile__("" ::: "memory"); + (void) *(volatile HVX_Vector *) out_base; + } +} + +// Combined: multi-thread softmax -> barrier -> serial m/l update + build_D +static void fa_phase_softmax_and_build_d(struct hmx_fa_context * factx, + fa_softmax_args_t * sargs, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + const size_t n_row_vec_cnt = hmx_ceil_div(sargs->n_rows_g, 64); + + if (factx->n_threads > 1 && n_row_vec_cnt >= 2) { + uint32_t n_use = (uint32_t) hex_smin((size_t) factx->n_threads, n_row_vec_cnt); + worker_pool_run_func(wp, fa_softmax_thread, sargs, n_use); + } else { + fa_softmax_thread(1, 0, sargs); + } + // barrier implicit in worker_pool_run_func return + + fa_ml_update_and_build_d(factx, sargs->n_rows_g, n_row_tiles, n_row_tiles_g_br); +} + +// ============================================================================ +// HMX job structs and worker functions +// ============================================================================ + +typedef struct { + const __fp16 * q_tiles; + const __fp16 * k_tiles; + __fp16 * s_tiles; + size_t n_row_tiles; + size_t n_col_tiles; + size_t n_dot_tiles; // DK / 32 + size_t n_tiles_per_bc; + uint8_t * hmx_scales; +} hmx_fa_qk_job_t; + +static void hmx_fa_qk_dot_worker(void * data) { + hmx_fa_qk_job_t * job = (hmx_fa_qk_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_col_tiles = job->n_col_tiles; + const size_t n_dot_tiles = job->n_dot_tiles; + const size_t n_tiles_per_bc = job->n_tiles_per_bc; + const __fp16 * restrict q_tiles = job->q_tiles; + const __fp16 * restrict k_tiles = job->k_tiles; + __fp16 * restrict s_tiles = job->s_tiles; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + const __fp16 * row_tiles = q_tiles + r * HMX_FP16_TILE_N_ROWS * n_dot_tiles * HMX_FP16_TILE_N_COLS; + const __fp16 * col_tiles = k_tiles + c * HMX_FP16_TILE_N_COLS * n_dot_tiles * HMX_FP16_TILE_N_COLS; + __fp16 * out_tile = s_tiles + (r * n_tiles_per_bc + c) * HMX_FP16_TILE_N_ELMS; + + for (size_t k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } +} + +typedef struct { + __fp16 * o_curr; + const __fp16 * o_prev; + const __fp16 * p_tiles; + const __fp16 * v_tiles; + const __fp16 * d_tiles; + uint8_t * hmx_scales; + size_t n_row_tiles; + size_t n_col_tiles; + size_t n_row_tiles_g_br; + size_t n_tiles_per_bc; + size_t DV; +} hmx_fa_o_update_job_t; + +static void hmx_fa_o_update_worker(void * data) { + hmx_fa_o_update_job_t * job = (hmx_fa_o_update_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_col_tiles = job->n_col_tiles; + const size_t n_row_tiles_g_br = job->n_row_tiles_g_br; + const size_t n_tiles_per_bc = job->n_tiles_per_bc; + const size_t DV_tiles = job->DV / 32; + const __fp16 * restrict d_tiles = job->d_tiles; + const __fp16 * restrict p_tiles = job->p_tiles; + const __fp16 * restrict v_tiles = job->v_tiles; + const __fp16 * restrict o_prev = job->o_prev; + __fp16 * restrict o_curr = job->o_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + // D[r,r] @ O_prev[r,c] — only the diagonal tile + const __fp16 * d_diag = d_tiles + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = o_prev + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + + // P @ V (accumulate on same accumulator) + const __fp16 * p_tile_in = p_tiles + (r * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + const __fp16 * v_tile_in = v_tiles + (c * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_col_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047); + p_tile_in += HMX_FP16_TILE_N_ELMS; + v_tile_in += HMX_FP16_TILE_N_ELMS; + } + + __fp16 * o_tile_out = o_curr + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(o_tile_out, 0); + } + } +} + +typedef struct { + __fp16 * o_curr; // output (row-major tile layout) + const __fp16 * o_prev; // input (column-major tile layout) + const __fp16 * d_tiles; // diag(1/l) tiles + uint8_t * hmx_scales; + size_t n_row_tiles; + size_t n_row_tiles_g_br; + size_t DV; +} hmx_fa_o_norm_job_t; + +static void hmx_fa_o_norm_worker(void * data) { + hmx_fa_o_norm_job_t * job = (hmx_fa_o_norm_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_row_tiles_g_br = job->n_row_tiles_g_br; + const size_t DV_tiles = job->DV / 32; + const __fp16 * restrict d_tiles = job->d_tiles; + const __fp16 * restrict o_prev = job->o_prev; + __fp16 * restrict o_curr = job->o_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_tiles + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = o_prev + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + __fp16 * o_out = o_curr + (r * DV_tiles + c) * HMX_FP16_TILE_N_ELMS; + + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + Q6_mxmem_AR_after_hf(o_out, 0); + } + } +} + +// Populate per-GQA-row ALiBi slopes for a given KV head. +// Row r in the GQA-merged block maps to Q head h = kv_head * G + r % G. +// slope(h) = m0^(h+1) when h < n_head_log2, else m1^(2*(h-n_head_log2)+1). +// When max_bias == 0, all slopes are 1.0 (no ALiBi). +static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sargs, + const struct hmx_fa_context * factx, + uint32_t kv_head, + size_t n_rows_g) { + if (factx->max_bias == 0.0f) { + for (size_t r = 0; r < n_rows_g; ++r) { + sargs->slopes[r] = 1.0f; + } + return; + } + + const uint32_t G = factx->G; + const uint32_t n_head_log2 = factx->n_head_log2; + const float m0 = factx->m0; + const float m1 = factx->m1; + + for (size_t r = 0; r < n_rows_g; ++r) { + const uint32_t h = kv_head * G + r % G; + sargs->slopes[r] = (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); + } +} + +// ============================================================================ +// Core HMX flash attention algorithm (GQA-merged) +// ============================================================================ + +int hmx_flash_attn_ext(struct htp_ops_context * octx) { + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * mask = (octx->src[3] && octx->src[3]->data) ? octx->src[3] : NULL; + const struct htp_tensor * dst = octx->dst; + + struct htp_context * const ctx = octx->ctx; + + if (!ctx->hmx_enabled) { + return HTP_STATUS_NO_SUPPORT; + } + + // Dimensions + const uint32_t neq0 = q->ne[0]; // head_dim (DK) + const uint32_t neq1 = q->ne[1]; // n_tokens + const uint32_t neq2 = q->ne[2]; // n_heads + const uint32_t neq3 = q->ne[3]; // n_seqs + + const uint32_t nek0 = k->ne[0]; // head_dim + const uint32_t nek1 = k->ne[1]; // kv_len + + const uint32_t nev0 = v->ne[0]; // head_dim (DV) + + const uint32_t DK = neq0; + const uint32_t DV = nev0; + + // HMX requires head_dim to be multiple of 32 + if (DK % 32 != 0 || DV % 32 != 0) { + return HTP_STATUS_NO_SUPPORT; + } + if (neq1 < 32) { + return HTP_STATUS_NO_SUPPORT; + } + + // GQA factor + const uint32_t n_kv_heads = k->ne[2]; + const uint32_t G = neq2 / n_kv_heads; + + // Thread count for multi-thread HVX phases + const uint32_t n_threads = octx->n_threads; + + // Compute dynamic block sizes (GQA-aware, accounting for per-thread row bufs) + size_t Br, Bc; + const size_t vtcm_budget = ctx->vtcm_size; + if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads) != 0) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS); + + const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc; + const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2); + + FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu", + neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget); + + // ======== Build context ======== + struct hmx_fa_context factx; + memset(&factx, 0, sizeof(factx)); + factx.octx = octx; + factx.n_threads = octx->ctx->n_threads; + factx.DK = DK; + factx.DV = DV; + factx.n_kv = nek1; + factx.n_kv_heads = n_kv_heads; + factx.n_heads = neq2; + factx.G = G; + factx.neq1 = neq1; + factx.Br = (uint32_t) Br; + factx.Bc = (uint32_t) Bc; + factx.g_br = (uint32_t) g_br; + factx.n_kv_blocks = n_kv_blocks; + factx.is_q_fp32 = (q->type == HTP_TYPE_F32); + factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32); + factx.use_pipeline = use_pipeline; + factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1); + + // Extract op parameters (mutable during softcap adjustment, then stored as const in factx) + float scale = 1.0f, max_bias = 0.0f, logit_softcap = 0.0f; + memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + +#ifdef HMX_FA_USE_EXP2_HF + // Pre-bake log2(e) into qk_scale so HMX-produced S tiles are in log2(e)-scaled + // space. Then exp2(S - m) in the softmax equals base-e exp((S - m) / log2(e)), + // preserving ggml's base-e softmax semantics. Matches htp-ops-lib flash_attn.c. + // + // When softcap is active we cannot pre-bake log2(e) here — it would land inside + // the tanh argument and shift the softcap knee from x≈c to x≈c/log2(e), giving + // numerically wrong softcapped values. Instead fold log2(e) into the post-tanh + // multiplier (see softcap block: v_cap absorbs log2(e)). + if (logit_softcap == 0.0f) { + scale *= 1.44269504f; // log2(e) + } +#endif + + factx.scale = scale; + factx.max_bias = max_bias; + factx.logit_softcap = logit_softcap; + + factx.n_head_log2 = 1u << (uint32_t) floor(log2(neq2)); + factx.m0 = powf(2.0f, -(max_bias) / factx.n_head_log2); + factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); + + // ======== VTCM allocation (GQA-aware) ======== + const size_t q_tile_bytes = hex_align_up(g_br * DK * sizeof(__fp16), 4096); + const size_t o_tile_bytes = hex_align_up(g_br * DV * sizeof(__fp16), 4096); + const size_t k_dma_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); + const size_t v_dma_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); + const size_t k_tile_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); + const size_t v_tile_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); + const size_t s_tile_bytes = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); + const size_t d_tile_bytes = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); + const size_t col_vec_bytes = hex_align_up(g_br * sizeof(__fp16), 256); + const size_t row_vec_bytes = hex_align_up(Bc * sizeof(__fp16), 256); + const size_t m_line_bytes = hex_align_up(Bc * sizeof(__fp16), 128); + const size_t m_buf_bytes = hex_align_up(Br * m_line_bytes, 4096); + const size_t slopes_bytes = hex_align_up(g_br * sizeof(__fp16), 128); + + uint8_t * vtcm_cur = ctx->vtcm_base; + + factx.vtcm_q_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, q_tile_bytes); + factx.vtcm_o_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, o_tile_bytes); + factx.vtcm_o_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, o_tile_bytes); + factx.vtcm_k_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_dma_bytes); + factx.vtcm_k_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_dma_bytes); + factx.vtcm_v_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); + factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); + factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes); + factx.vtcm_v_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + factx.vtcm_s_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); + factx.vtcm_p_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); + factx.vtcm_d_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, d_tile_bytes); + factx.vtcm_m_vec = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_l_vec = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_s_rowmax = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_p_rowsum = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_row_bufs = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, row_vec_bytes * 2 * n_threads); + factx.row_buf_stride = row_vec_bytes / sizeof(HVX_Vector); + factx.vtcm_hmx_scales_id = vtcm_seq_alloc(&vtcm_cur, 256); + factx.vtcm_hmx_scales_qk = vtcm_seq_alloc(&vtcm_cur, 256); + factx.vtcm_mask_buf = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, m_buf_bytes); + factx.mask_buf_row_stride = m_line_bytes / sizeof(__fp16); + factx.vtcm_slopes = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, slopes_bytes); + + if ((size_t) (vtcm_cur - ctx->vtcm_base) > ctx->vtcm_size) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + // ======== Initialize HMX output scales ======== + // Identity scale (1.0) for O updates and normalization + hmx_init_column_scales(factx.vtcm_hmx_scales_id, Q6_V_vsplat_R(0x3c00)); // 1.0 + + // QK scale embedded in HMX output + hmx_init_column_scales(factx.vtcm_hmx_scales_qk, hvx_vec_splat_f16(factx.scale)); + + // ======== Skip compute if profiling ======== + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + // Profiling timers + TIMER_DEFINE(total); + TIMER_DEFINE(q_load); + TIMER_DEFINE(kv_dma); + TIMER_DEFINE(k_interleave); + TIMER_DEFINE(v_interleave); + TIMER_DEFINE(qk_dot); + TIMER_DEFINE(softmax); + TIMER_DEFINE(o_update); + TIMER_DEFINE(o_norm); + TIMER_DEFINE(o_store); + + TIMER_START(total); + + // ======== DMA setup ======== + dma_queue * const dma = ctx->dma[0]; + + // Padded row sizes for DMA + const size_t size_k_row = nek0 * sizeof(__fp16); + const size_t size_v_row = nev0 * sizeof(__fp16); + const size_t size_k_row_padded = hex_round_up(nek0 * sizeof(__fp16), 128); + const size_t size_v_row_padded = hex_round_up(nev0 * sizeof(__fp16), 128); + + const size_t n_row_tiles_g_br = g_br / HMX_FP16_TILE_N_ROWS; + const size_t n_tiles_per_bc = Bc / HMX_FP16_TILE_N_COLS; + + // Q/O element size for Q load and O store + const size_t qo_element_size = factx.is_q_fp32 ? sizeof(float) : sizeof(__fp16); + + // ======== HMX lock strategy ======== + // Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend. + // Fallback: main thread holds the lock (original behavior). + if (!factx.use_pipeline) { + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + } + + // ======== Reusable job descriptors for pipeline ======== + hmx_fa_qk_job_t qk_job; + hmx_fa_o_update_job_t ou_job; + hmx_fa_o_norm_job_t on_job; + + // ======== Main loop: per batch, per KV head, per Q block ======== + for (uint32_t ib3 = 0; ib3 < neq3; ++ib3) { + for (uint32_t kv_head = 0; kv_head < n_kv_heads; ++kv_head) { + const uint32_t ik2 = kv_head; + const uint32_t ik3 = ib3 / (neq3 / k->ne[3]); + const uint32_t iv2 = kv_head; + const uint32_t iv3 = ib3 / (neq3 / v->ne[3]); + + for (uint32_t q_start = 0; q_start < neq1; q_start += Br) { + const uint32_t n_q_rows = hex_smin(Br, neq1 - q_start); + const size_t n_rows_g = n_q_rows * G; + const size_t g_br_actual = hex_align_up(n_rows_g, HMX_FP16_TILE_N_ROWS); + const size_t n_row_tiles = g_br_actual / HMX_FP16_TILE_N_ROWS; + + // ---- Load Q block [g_br, D] -> tiles, interleaving G heads ---- + TIMER_START(q_load); + if (n_rows_g < g_br) { + hvx_splat_u8_a(factx.vtcm_q_tiles, 0, q_tile_bytes); + } + fa_phase_q_load(&factx, q, q_start, kv_head, ib3, n_rows_g); + TIMER_STOP(q_load); + + // ---- Initialize per-block state ---- + hvx_splat_u8_a(factx.vtcm_l_vec, 0, col_vec_bytes); + hvx_splat_u8_a(factx.vtcm_d_tiles, 0, d_tile_bytes); + hvx_splat_u16_a(factx.vtcm_m_vec, 0xfbff, col_vec_bytes/2); + + __fp16 * o_tile_prev = factx.vtcm_o_tiles[0]; + __fp16 * o_tile_curr = factx.vtcm_o_tiles[1]; + hvx_splat_u8_a(o_tile_prev, 0, o_tile_bytes); + + // ---- KV block loop with DMA double-buffering ---- + size_t buf_idx = 0; + + // Prefetch first KV block + if (factx.n_kv_blocks > 0) { + const uint32_t kv_rows0 = hex_smin(Bc, nek1); + + const uint8_t * k_src = (const uint8_t *) k->data + ik2 * k->nb[2] + ik3 * k->nb[3]; + dma_queue_push(dma, dma_make_ptr(factx.vtcm_k_fp16[0], k_src), size_k_row_padded, k->nb[1], + size_k_row, kv_rows0); + + const uint8_t * v_src = (const uint8_t *) v->data + iv2 * v->nb[2] + iv3 * v->nb[3]; + dma_queue_push(dma, dma_make_ptr(factx.vtcm_v_fp16[0], v_src), size_v_row_padded, v->nb[1], + size_v_row, kv_rows0); + } + + // Mask DMA: single 2D transfer of n_q_rows unique mask rows into VTCM buffer. + // Only when mask is head-broadcast (ne[2]==1); otherwise softmax reads DDR directly. + #define MASK_DMA_PUSH(kv_start_val, kv_rows_val, has_mask_dma_var) \ + do { \ + has_mask_dma_var = false; \ + if (mask && factx.mask_broadcast) { \ + const uint32_t _im3 = ib3 % mask->ne[3]; \ + const uint8_t * _ms = (const uint8_t *) mask->data + q_start * mask->nb[1] + _im3 * mask->nb[3] + \ + (kv_start_val) * sizeof(__fp16); \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_mask_buf, _ms), m_line_bytes, mask->nb[1], \ + (kv_rows_val) * sizeof(__fp16), n_q_rows); \ + has_mask_dma_var = true; \ + } \ + } while (0) + + #define MASK_DMA_POP(has_mask_dma_var) \ + do { \ + if (has_mask_dma_var) { \ + dma_queue_pop(dma); \ + } \ + } while (0) + + #define DMA_PREFETCH_KV(blk_val) \ + do { \ + if ((blk_val) < factx.n_kv_blocks) { \ + const uint32_t _ns = (blk_val) * Bc; \ + const uint32_t _nr = hex_smin(Bc, nek1 - _ns); \ + size_t _nb = 1 - buf_idx; \ + const uint8_t * _ks = (const uint8_t *) k->data + _ns * k->nb[1] + ik2 * k->nb[2] + ik3 * k->nb[3]; \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_k_fp16[_nb], _ks), size_k_row_padded, k->nb[1], size_k_row, _nr); \ + const uint8_t * _vs = (const uint8_t *) v->data + _ns * v->nb[1] + iv2 * v->nb[2] + iv3 * v->nb[3]; \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_v_fp16[_nb], _vs), size_v_row_padded, v->nb[1], size_v_row, _nr); \ + } \ + } while (0) + + const size_t k_src_stride = size_k_row_padded / sizeof(__fp16); + const size_t v_src_stride = size_v_row_padded / sizeof(__fp16); + + if (factx.use_pipeline) { + // ================================================================== + // Pipeline path: HVX phases ‖ HMX queue worker + // ================================================================== + struct hmx_queue * hmx_q = ctx->hmx_queue; + + for (uint32_t kv_blk = 0; kv_blk < factx.n_kv_blocks; ++kv_blk) { + const uint32_t kv_start = kv_blk * Bc; + const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start); + const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS); + + // Wait for current KV DMA + TIMER_START(kv_dma); + dma_queue_pop(dma); // K + dma_queue_pop(dma); // V + TIMER_STOP(kv_dma); + + // Push mask DMA for this block (single 2D DMA when broadcast) + bool has_mask_dma = false; + MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma); + + // ---- Phase 1: K_int(blk) ‖ O_update(blk-1) ---- + if (kv_blk > 0) { + // Submit O_update for previous block (HMX worker) + ou_job.o_curr = o_tile_curr; + ou_job.o_prev = o_tile_prev; + ou_job.p_tiles = factx.vtcm_p_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.d_tiles = factx.vtcm_d_tiles; + ou_job.hmx_scales = factx.vtcm_hmx_scales_id; + ou_job.n_row_tiles = n_row_tiles; + ou_job.n_col_tiles = hmx_ceil_div(hex_smin(Bc, nek1 - (kv_blk - 1) * Bc), HMX_FP16_TILE_N_COLS); + ou_job.n_row_tiles_g_br = n_row_tiles_g_br; + ou_job.n_tiles_per_bc = n_tiles_per_bc; + ou_job.DV = DV; + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job)); + } + + TIMER_START(k_interleave); + fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); + TIMER_STOP(k_interleave); + + if (kv_blk > 0) { + hmx_queue_pop(hmx_q); + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + // ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ---- + qk_job.q_tiles = factx.vtcm_q_tiles; + qk_job.k_tiles = factx.vtcm_k_tiles; + qk_job.s_tiles = factx.vtcm_s_tiles; + qk_job.n_row_tiles = n_row_tiles; + qk_job.n_col_tiles = n_col_tiles; + qk_job.n_dot_tiles = DK / 32; + qk_job.n_tiles_per_bc = n_tiles_per_bc; + qk_job.hmx_scales = factx.vtcm_hmx_scales_qk; + TIMER_START(qk_dot); + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_qk_dot_worker, &qk_job)); + + // DMA push next block (non-blocking, before worker_pool) + DMA_PREFETCH_KV(kv_blk + 1); + + TIMER_START(v_interleave); + fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); + TIMER_STOP(v_interleave); + + hmx_queue_pop(hmx_q); + TIMER_STOP(qk_dot); + + // ---- Phase 3: softmax(blk) + build_D(blk) | HMX idle ---- + // Pop mask DMA before softmax (ensures VTCM buffer is ready) + MASK_DMA_POP(has_mask_dma); + + fa_softmax_args_t sargs; + memset(&sargs, 0, sizeof(sargs)); + sargs.factx = &factx; + sargs.kv_rows = kv_rows; + sargs.n_rows_g = n_rows_g; + sargs.n_col_tiles = n_col_tiles; + sargs.n_tiles_per_bc = n_tiles_per_bc; + sargs.n_row_tiles = n_row_tiles; + sargs.n_row_tiles_g_br = n_row_tiles_g_br; + sargs.Bc = Bc; + sargs.G = G; + sargs.kv_head = kv_head; + sargs.kv_start = kv_start; + sargs.q_start = q_start; + sargs.ib3 = ib3; + sargs.has_alibi = (factx.max_bias != 0.0f); + sargs.mask = mask; + sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; + sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; + sargs.slopes = factx.vtcm_slopes; + fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); + + TIMER_START(softmax); + fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); + TIMER_STOP(softmax); + + buf_idx = 1 - buf_idx; + } // end KV block loop (pipeline) + + // Epilogue: O_update for last block + if (factx.n_kv_blocks > 0) { + const uint32_t last_blk = factx.n_kv_blocks - 1; + const size_t last_cols = hmx_ceil_div(hex_smin(Bc, nek1 - last_blk * Bc), HMX_FP16_TILE_N_COLS); + ou_job.o_curr = o_tile_curr; + ou_job.o_prev = o_tile_prev; + ou_job.p_tiles = factx.vtcm_p_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.d_tiles = factx.vtcm_d_tiles; + ou_job.hmx_scales = factx.vtcm_hmx_scales_id; + ou_job.n_row_tiles = n_row_tiles; + ou_job.n_col_tiles = last_cols; + ou_job.n_row_tiles_g_br = n_row_tiles_g_br; + ou_job.n_tiles_per_bc = n_tiles_per_bc; + ou_job.DV = DV; + + TIMER_START(o_update); + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job)); + hmx_queue_pop(hmx_q); + TIMER_STOP(o_update); + + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + } else { + // ================================================================== + // Fallback path: sequential with multi-thread HVX phases + // Main thread holds HMX lock, runs HMX inline. + // ================================================================== + + for (uint32_t kv_blk = 0; kv_blk < factx.n_kv_blocks; ++kv_blk) { + const uint32_t kv_start = kv_blk * Bc; + const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start); + const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS); + + TIMER_START(kv_dma); + dma_queue_pop(dma); // K + dma_queue_pop(dma); // V + TIMER_STOP(kv_dma); + + bool has_mask_dma = false; + MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma); + DMA_PREFETCH_KV(kv_blk + 1); + + // K interleave (multi-thread HVX) + TIMER_START(k_interleave); + fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); + TIMER_STOP(k_interleave); + + // QK dot (inline HMX on main thread) + TIMER_START(qk_dot); + { + const size_t n_dot_tiles = (size_t) (DK / 32); + const __fp16 * restrict q_base = factx.vtcm_q_tiles; + const __fp16 * restrict k_base = factx.vtcm_k_tiles; + __fp16 * restrict s_base = factx.vtcm_s_tiles; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_qk); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + const __fp16 * row_tiles = q_base + r * HMX_FP16_TILE_N_ROWS * DK; + const __fp16 * col_tiles = k_base + c * HMX_FP16_TILE_N_COLS * DK; + __fp16 * out_tile = s_base + (r * n_tiles_per_bc + c) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } + } + TIMER_STOP(qk_dot); + + // Pop mask DMA + MASK_DMA_POP(has_mask_dma); + + // Softmax + build_D (multi-thread HVX + serial m/l update) + fa_softmax_args_t sargs; + memset(&sargs, 0, sizeof(sargs)); + sargs.factx = &factx; + sargs.kv_rows = kv_rows; + sargs.n_rows_g = n_rows_g; + sargs.n_col_tiles = n_col_tiles; + sargs.n_tiles_per_bc = n_tiles_per_bc; + sargs.n_row_tiles = n_row_tiles; + sargs.n_row_tiles_g_br = n_row_tiles_g_br; + sargs.Bc = Bc; + sargs.G = G; + sargs.kv_head = kv_head; + sargs.kv_start = kv_start; + sargs.q_start = q_start; + sargs.ib3 = ib3; + sargs.has_alibi = (factx.max_bias != 0.0f); + sargs.mask = mask; + sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; + sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; + sargs.slopes = factx.vtcm_slopes; + fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); + + TIMER_START(softmax); + fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); + TIMER_STOP(softmax); + + // V interleave (multi-thread HVX) + TIMER_START(v_interleave); + // FIX(v-stride): use n_tiles_per_bc (block-invariant) as V tile layout + // stride to match o_update's v_tile access. Using per-block n_col_tiles + // misplaces DV_tile 1..3 in the last partial KV block. + fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); + TIMER_STOP(v_interleave); + + // O update (inline HMX on main thread) + TIMER_START(o_update); + { + const size_t DV_tiles = (size_t) (DV / 32); + const __fp16 * restrict d_base = factx.vtcm_d_tiles; + const __fp16 * restrict p_base = factx.vtcm_p_tiles; + const __fp16 * restrict v_base = factx.vtcm_v_tiles; + const __fp16 * restrict op_base = o_tile_prev; + __fp16 * restrict oc_base = o_tile_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_base + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = op_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + + const __fp16 * p_tile_in = p_base + (r * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + const __fp16 * v_tile_in = v_base + (c * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_col_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047); + p_tile_in += HMX_FP16_TILE_N_ELMS; + v_tile_in += HMX_FP16_TILE_N_ELMS; + } + + __fp16 * o_tile_out = oc_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(o_tile_out, 0); + } + } + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + TIMER_STOP(o_update); + + buf_idx = 1 - buf_idx; + } // end KV block loop (fallback) + } + + // ---- Final normalization: O = diag(1/l) @ O ---- + TIMER_START(o_norm); + { + fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br); + + // HMX: O_final = diag(1/l) @ O_prev + if (factx.use_pipeline) { + on_job.o_curr = o_tile_curr; + on_job.o_prev = o_tile_prev; + on_job.d_tiles = factx.vtcm_d_tiles; + on_job.hmx_scales = factx.vtcm_hmx_scales_id; + on_job.n_row_tiles = n_row_tiles; + on_job.n_row_tiles_g_br = n_row_tiles_g_br; + on_job.DV = DV; + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_fa_o_norm_worker, &on_job)); + hmx_queue_pop(ctx->hmx_queue); + } else { + const size_t DV_tiles = (size_t) (DV / 32); + const __fp16 * restrict d_base = factx.vtcm_d_tiles; + const __fp16 * restrict op_base = o_tile_prev; + __fp16 * restrict oc_base = o_tile_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_base + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = op_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + __fp16 * o_out = oc_base + (r * DV_tiles + c) * HMX_FP16_TILE_N_ELMS; + + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + Q6_mxmem_AR_after_hf(o_out, 0); + } + } + } + } + TIMER_STOP(o_norm); + + // ---- Store O block ---- + TIMER_START(o_store); + fa_phase_o_store(&factx, dst, o_tile_curr, q_start, kv_head, ib3, n_rows_g); + TIMER_STOP(o_store); + +#undef MASK_DMA_PUSH +#undef MASK_DMA_POP +#undef DMA_PREFETCH_KV + + } // end Q block loop + } // end KV head loop + } // end batch loop + + if (factx.use_pipeline) { + hmx_queue_suspend(ctx->hmx_queue); + } else { + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + } + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "hmx-fa: %lld us, q_load=%lld kv_dma=%lld k_interleave=%lld v_interleave=%lld", TIMER_US(total), + TIMER_US(q_load), TIMER_US(kv_dma), TIMER_US(k_interleave), TIMER_US(v_interleave)); + FARF(HIGH, " qk_dot=%lld softmax=%lld o_update=%lld o_norm=%lld o_store=%lld", TIMER_US(qk_dot), TIMER_US(softmax), + TIMER_US(o_update), TIMER_US(o_norm), TIMER_US(o_store)); +#endif + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 05e3c6c2b0f..2666a78a96a 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -28,6 +28,8 @@ #include "hmx-queue.h" #include "hmx-profile.h" +#include "vtcm-utils.h" + static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, }; @@ -43,40 +45,11 @@ static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, }; -// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile. -// word[i] = i*128 maps K-row-pair i to byte offset i*128 in the tile. -// Column offset (n*4) is added at runtime. Only entries 0..15 are used (masked by predicate). -static const int32_t weight_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = { - 0*128, 1*128, 2*128, 3*128, 4*128, 5*128, 6*128, 7*128, - 8*128, 9*128, 10*128, 11*128, 12*128, 13*128, 14*128, 15*128, - 16*128, 17*128, 18*128, 19*128, 20*128, 21*128, 22*128, 23*128, - 24*128, 25*128, 26*128, 27*128, 28*128, 29*128, 30*128, 31*128 -}; - // Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes #define HMX_X4X2_SCALES_PER_BLK 8 #define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL) #define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4) -static inline void swap_ptr(void **p1, void **p2) { - void *t = *p1; - *p1 = *p2; - *p2 = t; -} - -typedef struct { - uint8_t *dst; - const uint8_t *src; - dma_queue *dma; - size_t n_rows; - size_t src_stride; // DDR row stride (full row_stride) - size_t dst_stride; // VTCM sub-block row stride - size_t quant_off; // quant byte offset in each DDR row - size_t quant_width; // quant bytes to copy per row - size_t scale_off; // scale byte offset in each DDR row - size_t scale_width; // scale bytes to copy per row -} qweight_fetch_task_state_t; - // Compute the byte stride of one row in x4x2 format. // Numerically equals ggml_row_size(type, k) when k is 256-aligned, because // x4x2 packing has the same density as block_q4_0 / block_q8_0. @@ -202,46 +175,6 @@ static int hmx_compute_chunks(size_t vtcm_total, return 0; } -// forward declaration – defined after transfer_activation_chunk_fp32_to_fp16 -void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride); - -// Scatter row-major FP16 weight (already in VTCM scratch) directly into transposed [K][N] tiles. -// vtcm_src: [n_cols][k] row-major fp16 in VTCM scratch buffer -// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16 -static void interleave_fp16_weight_chunk_to_tiles(__fp16 *restrict vtcm_dst, - const __fp16 *restrict vtcm_src, - int n_cols, int k) { - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - assert(k % HMX_FP16_TILE_N_COLS == 0); - - const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; - const HVX_Vector v_scat_base = hvx_vmem(weight_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - for (int r = 0; r < n_cols; r += 2) { - int ct = r / HMX_FP16_TILE_N_ROWS; // N-dimension tile index - int local_r = r % HMX_FP16_TILE_N_ROWS; // intra-tile row index - const bool next_row_valid = (r + 1) < n_cols; - - // Offset vectors for N-columns local_r and local_r+1, reused across K-tiles. - HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); - HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); - - for (int c = 0; c < k; c += HMX_FP16_TILE_N_COLS) { - int kt = c / HMX_FP16_TILE_N_COLS; - int tile_idx = ct * n_k_tiles + kt; - __fp16 *tile_base = vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS; - - HVX_Vector v0 = hvx_vmemu(vtcm_src + r * k + c); - HVX_Vector v1 = next_row_valid ? hvx_vmemu(vtcm_src + (r + 1) * k + c) : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off0, v0); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off1, v1); - } - } -} - // --- x4x2 format dequantizers --- // Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes. @@ -303,8 +236,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( } // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. -static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx( - const int8_t *quants_32, const __fp16 *scale) { +static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { HVX_Vector vq = hvx_vmemu(quants_32); HVX_Vector v_scales = hvx_vec_splat_f16(*scale); HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq)); @@ -414,8 +346,8 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. // Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128 // maps to K-rows 2i and 2i+1. Column offset (n*4) added per row. - const HVX_Vector v_scat_base = hvx_vmem(weight_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes) unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index @@ -658,12 +590,12 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; state.n_tot_tiles = n_tot_tiles; state.n_tiles_per_task = n_tiles_per_task; - state.dst = vtcm_dst; - state.src = (const uint8_t *)vtcm_src; - state.n_cols = n_cols; - state.k_block = k_block; - state.row_stride = row_stride; - state.weight_type = weight_type; + state.dst = vtcm_dst; + state.src = (const uint8_t *)vtcm_src; + state.n_cols = n_cols; + state.k_block = k_block; + state.row_stride = row_stride; + state.weight_type = weight_type; worker_pool_run_func(ctx->worker_pool, dequantize_x4x2_worker_loop, &state, ctx->n_threads); } @@ -733,7 +665,7 @@ static inline void hmx_matmul_job_init(hmx_matmul_job_t * job, job->n_dot_tiles = n_dot_tiles; } -// --- End async HMX matmul job --- +// output : fp16 -> f32p static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); @@ -807,295 +739,397 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, ctx->n_threads); } -static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) { - return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; -} +// activations : fp32 -> fp16 -static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) { - return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; -} +static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) { + for (int r = 0; r < n_rows; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx -static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, - int dst_b2, int dst_b3) { - const int r2 = hmx_matmul_batch_r2(params); - const int r3 = hmx_matmul_batch_r3(params); - return (const __fp16 *) ((const uint8_t *) params->permuted_weight + - (size_t) (dst_b2 / r2) * params->src0_nb2 + - (size_t) (dst_b3 / r3) * params->src0_nb3); -} + const bool next_row_valid = (r + 1) < n_rows; -static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, - int dst_b2, int dst_b3) { - return (const float *) ((const uint8_t *) params->activation + - (size_t) dst_b2 * params->src1_nb2 + - (size_t) dst_b3 * params->src1_nb3); -} + const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); + const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero(); -static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, - int dst_b2, int dst_b3) { - return (float *) ((uint8_t *) params->dst + - (size_t) dst_b2 * params->dst_nb2 + - (size_t) dst_b3 * params->dst_nb3); -} + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); -static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx, - const hmx_matmul_w16a32_batched_params_t *params) { - int ret = 0; - for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { - for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { - ret = hmx_mat_mul_permuted_w16a32(ctx, - hmx_matmul_dst_batch_ptr(params, b2, b3), - hmx_matmul_activation_batch_ptr(params, b2, b3), - hmx_matmul_weight_batch_ptr(params, b2, b3), - params->m, params->k, params->n, - params->act_stride, params->weight_stride); + // compute output position + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; } } - return ret; } -int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) { - if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } - if (!params->m || !params->k || !params->n) { return -1; } - if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } - if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } - if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } - if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } +typedef struct { + __fp16 *dst; + const float *src; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int k_block; + int k_stride; +} activation_transfer_task_state_t; - if (!hex_is_aligned(params->dst, VLEN) || - !hex_is_aligned(params->activation, VLEN) || - !hex_is_aligned(params->permuted_weight, VLEN)) { - return -1; - } +static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; - const int group_size = hmx_matmul_batch_r2(params); + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + // one chunk: one row + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); - if (group_size <= 1) { - FARF(MEDIUM, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + __fp16 *dst = st->dst + chunk_idx * st->k_block; + const float *src = st->src + chunk_idx * st->k_stride; + transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride); } +} - // Grouped path: reuse interleaved weight across all q_heads sharing a - // kv_head. Each q_head gets its own activation buffer in VTCM (so - // activation is loaded once per m_chunk and reused across all n_chunks), - // and each q_head is computed individually to avoid tile-major packing - // issues. m_chunk_n_rows is always a multiple of 32 (from - // hmx_compute_chunks), so per-head tile arrays don't overlap. - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = params->k * sizeof(__fp16); +static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) { + assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); + assert(VLEN == 32 * sizeof(float)); - // When the activation has a large stride (e.g. permuted Q tensor with - // act_stride >> k), HVX vector loads from strided DDR thrash L2 cache. - // Allocate an F32 scratch buffer in VTCM and use 2D DMA to gather - // strided rows into a contiguous block before the F32->F16 conversion. - const bool use_dma_activation = (params->act_stride > params->k); - const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = 32; // must be multiple of 32 to ensure correct destination address - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // FP16 weight: interleave and activation load have similar per-element cost. - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, - /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, - /*per_mn=*/sizeof(__fp16), params->m, params->n, - /*m_block_cost=*/(size_t) params->n, - /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); - } + activation_transfer_task_state_t state; + state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.src = src; + state.k_block = k_block; + state.k_stride = k_stride; - const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t f32_scratch_size = use_dma_activation - ? hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); +} - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; +// - if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { - FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); - } +#define FALLBACK_TO_STANDARD 1 - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 +// C += AB +static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, + const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, + int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); - FARF(MEDIUM, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, params->m, params->k, params->n, group_size, params->ne13, - m_chunk_n_rows, n_chunk_n_cols, - (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); + Q6_bias_mxmem2_A((void *)col_scales); - TIMER_DEFINE(activation_load); - TIMER_DEFINE(weight_load); - TIMER_DEFINE(hmx_core); - TIMER_DEFINE(output_store); - TIMER_DEFINE(total); + const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS; + for (size_t i = 0; i < n_row_tiles; ++i) { + const __fp16 *row_base = a + i * dot_tile_stride; + __fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS; + for (size_t j = 0; j < n_col_tiles; ++j) { + Q6_mxclracc_hf(); - TIMER_START(total); + const __fp16 *col_tiles = b + j * dot_tile_stride; + const __fp16 *row_tiles = row_base; + __fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS; + if (!zero_init) { + Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); + } - const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); - const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); + for (int k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + Q6_mxmem_AR_after_hf(accum_tile, 0); + } + } +} - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); +static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, + float *restrict out, const float *restrict x, const uint8_t *restrict w, + int m, int k, int n, int weight_type) { + // assume k % 32 == 0 && n % 32 == 0 + const size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } - for (int b3 = 0; b3 < params->ne13; ++b3) { - for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { - const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); + const size_t vtcm_budget = ctx->vtcm_size; - for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); + const size_t K_BLOCK_SIZE = 1024; - // Pre-load activations for all heads in the group (once per m_chunk). - // When the source is strided (permuted Q), use 2D DMA to gather - // contiguous rows into a VTCM scratch buffer first, then HVX - // converts from the contiguous VTCM buffer. This avoids L2 cache - // thrashing from HVX loads at large strides. - TIMER_START(activation_load); - for (int g = 0; g < group_size; ++g) { - const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; - __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; - if (use_dma_activation) { - const size_t row_bytes = (size_t) params->k * sizeof(float); - const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_f32_act, activation_chunk), - row_bytes, stride_bytes, row_bytes, n_rows); - dma_queue_pop(ctx->dma[0]); - transfer_activation_chunk_threaded(ctx, vtcm_act_g, - vtcm_f32_act, (int) n_rows, - params->k, params->k); - } else { - transfer_activation_chunk_threaded(ctx, vtcm_act_g, - activation_chunk, (int) n_rows, - params->k, params->act_stride); - } - } - TIMER_STOP(activation_load); + // Fallback: if k doesn't need K-blocking, out-stationary has no advantage + const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE; + if (k_iters_check <= 1) { + FARF(HIGH, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k); + return FALLBACK_TO_STANDARD; + } - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; + // Dynamic M,N search via hmx_compute_chunks + const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); + const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32) + + K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles) + const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant) + + K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles) + const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary) - { - const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); - } + // Alignment margin: hex_align_up can add up to 2047 bytes per buffer; + // scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin + const size_t align_margin = 4 * HMX_FP16_TILE_SIZE; + const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment - for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); + size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used; + // Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost. + // From profiling: wt_dequant per element ≈ 1.5× activation load per element. + // m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive). + // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). + const size_t m_block_cost = (size_t) n * 3; + const size_t n_block_cost = (size_t) m * 2; + if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE, + &N_BLOCK_SIZE, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); + return -1; + } - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); + // Compute precise buffer sizes from searched M,N and fixed K + const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE); + const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < (size_t) params->n) { - const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); - const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; + const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; + if (total_vtcm > vtcm_budget) { + FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm, + vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE); + return -1; + } - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); - } + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size); + uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz); + uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz); + __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); - interleave_fp16_weight_chunk_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k); - swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); + FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type, + M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); - // Reuse the interleaved weight for every q_head in this GQA group - for (int g = 0; g < group_size; ++g) { - TIMER_START(hmx_core); - { - const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; - core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, - params->k / 32); - } - TIMER_STOP(hmx_core); + // initialize eye tile (32x32 identity matrix) + { + HVX_Vector v; + v = Q6_V_vzero(); + v = Q6_Vw_vinsert_VwR(v, 0x3c000000); + v = Q6_V_vror_VR(v, VLEN - 4); + v = Q6_Vw_vinsert_VwR(v, 0x00003c00); + for (int i = 0; i < 16; ++i) { + ((HVX_Vector *) vtcm_eye_tile)[i] = v; + v = Q6_V_vror_VR(v, VLEN - 8); + } + } + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - TIMER_START(output_store); - { - float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; - transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride); - } - TIMER_STOP(output_store); - } + TIMER_DEFINE(fetch); + TIMER_DEFINE(act_load); + TIMER_DEFINE(wt_dequant); + TIMER_DEFINE(core); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) { + size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE); + for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) { + size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE); + + const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS); + const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); + + for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { + const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); + + TIMER_START(fetch); + // fetch activation block into VTCM + { + const float *activation_block = x + mr * k + kk; + + dma_queue_push(ctx->dma[0], + dma_make_ptr(vtcm_scratch1, activation_block), + k_blk_sz * sizeof(float), + k * sizeof(float), + k_blk_sz * sizeof(float), + m_blk_sz); + } + + // fetch weight block into VTCM (x4x2 sub-block: quants + scales) + const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); + { + const int blk_start = kk / QK_Q4_0x4x2; + const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; + const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); + const int scale_blk_size = (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; + uint8_t *dst = vtcm_scratch0; + const uint8_t *src = w + nc * row_stride; + const size_t n_rows = n_blk_sz; + const size_t src_stride = row_stride; + const size_t dst_stride = sub_row_stride; + const size_t quant_off = (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2)); + const size_t quant_width = (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2)); + const size_t scale_off = full_qrow + blk_start * scale_blk_size; + const size_t scale_width = nb_sub * scale_blk_size; + + // 2D DMA: quants sub-range + dma_queue_push(ctx->dma[0], dma_make_ptr(dst, src + quant_off), dst_stride, src_stride, quant_width, n_rows); + // 2D DMA: scales sub-range + dma_queue_push(ctx->dma[0], dma_make_ptr(dst + quant_width, src + scale_off), dst_stride, src_stride, scale_width, n_rows); + } + TIMER_STOP(fetch); + + TIMER_START(act_load); + // load activation block + { + dma_queue_pop(ctx->dma[0]); // wait for act DNA + transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz); + } + TIMER_STOP(act_load); + + TIMER_START(wt_dequant); + // dequantize weight block + { + dma_queue_pop(ctx->dma[0]); + dma_queue_pop(ctx->dma[0]); + // vtcm_scratch0 is used to store the qweight chunk + // worker_pool_run_func already returned, so fetch is done + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, + n_blk_sz, k_blk_sz, sub_row_stride, weight_type); + } + TIMER_STOP(wt_dequant); + + // core mma + TIMER_START(core); + { + core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles, + n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0); } + TIMER_STOP(core); + } + + // store output block + { + float *output_block = out + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n); } } } HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - TIMER_STOP(total); - #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total), - params->m, params->k, params->n, group_size); - FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", - TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); + FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us", + TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core)); #endif - - return 0; + return 0; } -int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, - const __fp16 *restrict permuted_weight, int m, int k, int n, - int act_stride, int weight_stride) { +int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, + const uint8_t *restrict permuted_weight, int m, int k, int n, + int weight_type) { if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } - if (act_stride < k || weight_stride < k) { return -1; } if (k % 32 != 0 || n % 32 != 0) { return -1; } if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; + return -1; + } + + // for large m, k (e.g. prefill FFN Down), use out-stationary version + if (m >= 128 && k > n && n > 1024) { + int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); + if (rc != FALLBACK_TO_STANDARD) { + return rc; // 0 success, -1 error + } + FARF(HIGH, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n); + // fall through to standard path + } + + size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; } + FARF(HIGH, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); + // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + const size_t vec_dot_size = k * sizeof(__fp16); - // DMA-based activation gather for strided tensors (see batched path comment). - const bool use_dma_activation = (act_stride > k); - const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; + // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. + // Only pays off when the chunker yields >=2 n-chunks, so the main loop can + // overlap HMX (C) with HVX (B/D); with a single n-chunk the extra VTCM for + // double-buffered output and the worker-dispatch overhead are pure loss. + // Try pipeline costs first; fall back to sequential if the layout collapses + // to one n-chunk. m >= 128 floor keeps HMX utilization reasonable. + const size_t pipe_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) + const size_t pipe_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) + const size_t seq_per_n = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs) + const size_t seq_per_mn = sizeof(__fp16); // O x 1 size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // FP16 weight: interleave and activation load have similar per-element cost. - if (hmx_compute_chunks(vtcm_budget, - /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, // W + S0 + S1 - /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch - /*per_mn=*/sizeof(__fp16), // O - m, n, - /*m_block_cost=*/(size_t) n, - /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; + bool use_pipeline = false; + + if (m >= 128) { + size_t mc = 0, nc = 0, used = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, m, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 && + hmx_ceil_div((size_t) n, nc) >= 2) { + m_chunk_n_rows = mc; + n_chunk_n_cols = nc; + vtcm_used = used; + use_pipeline = true; + } } - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + if (!use_pipeline) { + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, m, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); + return -1; + } + } + + // Compute precise buffer sizes per execution path + const size_t weight_area_size = hex_align_up( + n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE); const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t f32_scratch_size = use_dma_activation - ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + const size_t output_area_size = hex_align_up( + m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + + size_t scratch0_size, scratch1_size, scratch2_size; + if (use_pipeline) { + scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 + scratch1_size = scratch0_size; // dequant buf 1 + scratch2_size = output_area_size; // output buf 1 + } else { + scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0 + scratch1_size = scratch0_size; // x4x2 DMA buf 1 + scratch2_size = 0; // unused + } - // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch] uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); + void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); @@ -1104,8 +1138,9 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(MEDIUM, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, + FARF(HIGH, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, m, k, n, weight_type, use_pipeline, + m_chunk_n_rows, n_chunk_n_cols, (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); TIMER_DEFINE(activation_load); @@ -1116,214 +1151,9 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co TIMER_DEFINE(total); TIMER_START(total); - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - // transfer activation matrix chunk into VTCM - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - - TIMER_START(activation_load); - { - const float *activation_chunk = activation + mr * act_stride; - if (use_dma_activation) { - const size_t row_bytes = (size_t) k * sizeof(float); - const size_t stride_bytes = (size_t) act_stride * sizeof(float); - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_f32_act, activation_chunk), - row_bytes, stride_bytes, row_bytes, n_rows); - dma_queue_pop(ctx->dma[0]); - transfer_activation_chunk_threaded(ctx, vtcm_activation, - vtcm_f32_act, n_rows, k, k); - } else { - transfer_activation_chunk_threaded(ctx, vtcm_activation, - activation_chunk, n_rows, k, act_stride); - } - } - TIMER_STOP(activation_load); - - const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16); - const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16); - - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; - - // issue async DMA for the first weight chunk - // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow. - // The source rows can be strided (e.g. KV-cache K after ggml_permute). - { - const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); - } - - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready - - // issue async DMA for the next weight chunk (double buffering) - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < n) { - const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); - const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); - } - - // interleave row-major fp16 from scratch into tile-major in vtcm_weight - interleave_fp16_weight_chunk_to_tiles(vtcm_weight, (const __fp16 *)buf_curr, n_cols, k); - - swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); - - TIMER_START(hmx_core); - { - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); - } - TIMER_STOP(hmx_core); - - TIMER_START(output_store); - { - float *output = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); - } - TIMER_STOP(output_store); - } - - } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - - TIMER_STOP(total); - -#if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n); - FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", - TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); - { - size_t weight_size = (size_t)k * n * sizeof(__fp16); - float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); - FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); - } -#endif - - return 0; -} - -int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, int m, - int k, int n, int w_type); - -#define FALLBACK_TO_STANDARD 1 - -int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, - const uint8_t *restrict permuted_weight, int m, int k, int n, - int weight_type) { - if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } - if (k % 32 != 0 || n % 32 != 0) { return -1; } - - if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; - } - - // for large m, k (e.g. prefill FFN Down), use out-stationary version - if (m >= 128 && k > n && n > 1024) { - int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); - if (rc != FALLBACK_TO_STANDARD) { - return rc; // 0 success, -1 error - } - FARF(MEDIUM, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n); - // fall through to standard path - } - - size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; - } - - FARF(MEDIUM, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); - - // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); - const bool use_pipeline = (m >= 128) && (k <= n); - - // Select cost parameters based on execution path - size_t per_n_cost, per_mn_cost; - if (use_pipeline) { - per_n_cost = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) - per_mn_cost = 2 * sizeof(__fp16); // O x 2 (output double buffer) - } else { - per_n_cost = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs) - per_mn_cost = sizeof(__fp16); // O x 1 - } - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // Quantized weight: dequant ~1.5x more expensive per element than activation load. - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, per_n_cost, /*per_m=*/vec_dot_size, per_mn_cost, m, n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d pipe=%d budget=%zu)", - __func__, m, k, n, use_pipeline, vtcm_budget); - return -1; - } - - // Compute precise buffer sizes per execution path - const size_t weight_area_size = hex_align_up( - n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up( - m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - - size_t scratch0_size, scratch1_size, scratch2_size; - if (use_pipeline) { - scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = scratch0_size; // dequant buf 1 - scratch2_size = output_area_size; // output buf 1 - } else { - scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0 - scratch1_size = scratch0_size; // x4x2 DMA buf 1 - scratch2_size = 0; // unused - } - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); - void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { - FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); - return -1; - } - - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - FARF(MEDIUM, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, weight_type, use_pipeline, - m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); - - TIMER_DEFINE(activation_load); - TIMER_DEFINE(weight_load); - TIMER_DEFINE(hmx_core); - TIMER_DEFINE(output_store); - - TIMER_DEFINE(total); - TIMER_START(total); - - FARF(MEDIUM, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu", - use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + FARF(HIGH, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu", + use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); if (!use_pipeline) { HAP_compute_res_hmx_lock(ctx->vtcm_rctx); @@ -1368,7 +1198,7 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds // HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight. dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type); - swap_ptr(&buf_curr, &buf_next); + hex_swap_ptr(&buf_curr, &buf_next); } TIMER_STOP(weight_load); @@ -1511,300 +1341,417 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds return 0; } -// C += AB -void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, - int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { - __builtin_assume(n_row_tiles > 0); - __builtin_assume(n_col_tiles > 0); - __builtin_assume(n_dot_tiles > 0); +// - Q6_bias_mxmem2_A((void *)col_scales); +static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) { + return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; +} - const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS; - for (size_t i = 0; i < n_row_tiles; ++i) { - const __fp16 *row_base = a + i * dot_tile_stride; - __fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS; - for (size_t j = 0; j < n_col_tiles; ++j) { - Q6_mxclracc_hf(); +static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) { + return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; +} - const __fp16 *col_tiles = b + j * dot_tile_stride; - const __fp16 *row_tiles = row_base; - __fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS; - if (!zero_init) { - Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); - } +static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + const int r2 = hmx_matmul_batch_r2(params); + const int r3 = hmx_matmul_batch_r3(params); + return (const __fp16 *) ((const uint8_t *) params->permuted_weight + + (size_t) (dst_b2 / r2) * params->src0_nb2 + + (size_t) (dst_b3 / r3) * params->src0_nb3); +} - for (int k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; - } - Q6_mxmem_AR_after_hf(accum_tile, 0); - } - } +static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (const float *) ((const uint8_t *) params->activation + + (size_t) dst_b2 * params->src1_nb2 + + (size_t) dst_b3 * params->src1_nb3); } -static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, - int k_block, int k_stride) { - for (int r = 0; r < n_rows; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx +static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (float *) ((uint8_t *) params->dst + + (size_t) dst_b2 * params->dst_nb2 + + (size_t) dst_b3 * params->dst_nb3); +} - const bool next_row_valid = (r + 1) < n_rows; +static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx, + const hmx_matmul_w16a32_batched_params_t *params) { + int ret = 0; + for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { + for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { + ret = hmx_mat_mul_permuted_w16a32(ctx, + hmx_matmul_dst_batch_ptr(params, b2, b3), + hmx_matmul_activation_batch_ptr(params, b2, b3), + hmx_matmul_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride); + } + } + return ret; +} - const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); - const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = *pv_in0++; - HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero(); +int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) { + if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } + if (!params->m || !params->k || !params->n) { return -1; } + if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } + if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } + if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } + if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + if (!hex_is_aligned(params->dst, VLEN) || + !hex_is_aligned(params->activation, VLEN) || + !hex_is_aligned(params->permuted_weight, VLEN)) { + return -1; + } - // compute output position - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + const int group_size = hmx_matmul_batch_r2(params); - HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } + if (group_size <= 1) { + FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); } -} -typedef struct { - __fp16 *dst; - const float *src; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int k_block; - int k_stride; -} activation_transfer_task_state_t; + // Grouped path: reuse interleaved weight across all q_heads sharing a + // kv_head. Each q_head gets its own activation buffer in VTCM (so + // activation is loaded once per m_chunk and reused across all n_chunks), + // and each q_head is computed individually to avoid tile-major packing + // issues. m_chunk_n_rows is always a multiple of 32 (from + // hmx_compute_chunks), so per-head tile arrays don't overlap. + const size_t vtcm_budget = ctx->vtcm_size; + const size_t vec_dot_size = params->k * sizeof(__fp16); -static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { - activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; + // When the activation has a large stride (e.g. permuted Q tensor with + // act_stride >> k), HVX vector loads from strided DDR thrash L2 cache. + // Allocate an F32 scratch buffer in VTCM and use 2D DMA to gather + // strided rows into a contiguous block before the F32->F16 conversion. + const bool use_dma_activation = (params->act_stride > params->k); + const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; - for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { - // one chunk: one row - int chunk_idx = task_id * st->n_chunks_per_task; - size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + // FP16 weight: interleave and activation load have similar per-element cost. + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, + /*per_n=*/3 * vec_dot_size, + /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, + /*per_mn=*/sizeof(__fp16), params->m, params->n, + /*m_block_cost=*/(size_t) params->n, + /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + } - __fp16 *dst = st->dst + chunk_idx * st->k_block; - const float *src = st->src + chunk_idx * st->k_stride; - transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride); + const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + + if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { + FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); + return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); } -} -void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) { - assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); - assert(VLEN == 32 * sizeof(float)); + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = 32; // must be multiple of 32 to ensure correct destination address + FARF(HIGH, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, params->m, params->k, params->n, group_size, params->ne13, + m_chunk_n_rows, n_chunk_n_cols, + (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); - activation_transfer_task_state_t state; - state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; - state.n_tot_chunks = n_tot_chunks; - state.n_chunks_per_task = n_chunks_per_task; - state.dst = dst; - state.src = src; - state.k_block = k_block; - state.k_stride = k_stride; + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + TIMER_DEFINE(total); - worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); -} + TIMER_START(total); -int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, float *restrict out, const float *restrict x, const uint8_t *restrict w, - int m, int k, int n, int weight_type) { - // assume k % 32 == 0 && n % 32 == 0 - const size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; + const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (int b3 = 0; b3 < params->ne13; ++b3) { + for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { + const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); + + for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); + + // Pre-load activations for all heads in the group (once per m_chunk). + // When the source is strided (permuted Q), use 2D DMA to gather + // contiguous rows into a VTCM scratch buffer first, then HVX + // converts from the contiguous VTCM buffer. This avoids L2 cache + // thrashing from HVX loads at large strides. + TIMER_START(activation_load); + for (int g = 0; g < group_size; ++g) { + const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; + __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + if (use_dma_activation) { + const size_t row_bytes = (size_t) params->k * sizeof(float); + const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); + dma_queue_push(ctx->dma[0], + dma_make_ptr(vtcm_f32_act, activation_chunk), + row_bytes, stride_bytes, row_bytes, n_rows); + dma_queue_pop(ctx->dma[0]); + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + vtcm_f32_act, (int) n_rows, + params->k, params->k); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + activation_chunk, (int) n_rows, + params->k, params->act_stride); + } + } + TIMER_STOP(activation_load); + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + { + const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } + + for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); + + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); + + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < (size_t) params->n) { + const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; + + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + } + + hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k, params->k, + 0, n_cols); + hex_swap_ptr(&buf_curr, &buf_next); + } + TIMER_STOP(weight_load); + + // Reuse the interleaved weight for every q_head in this GQA group + for (int g = 0; g < group_size; ++g) { + TIMER_START(hmx_core); + { + const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, + params->k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride); + } + TIMER_STOP(output_store); + } + } + } + } } - const size_t vtcm_budget = ctx->vtcm_size; + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - const size_t K_BLOCK_SIZE = 1024; + TIMER_STOP(total); - // Fallback: if k doesn't need K-blocking, out-stationary has no advantage - const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE; - if (k_iters_check <= 1) { - FARF(MEDIUM, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k); - return FALLBACK_TO_STANDARD; +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total), + params->m, params->k, params->n, group_size); + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); +#endif + + return 0; +} + +// + +int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, + const __fp16 *restrict permuted_weight, int m, int k, int n, + int act_stride, int weight_stride) { + if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } + if (act_stride < k || weight_stride < k) { return -1; } + if (k % 32 != 0 || n % 32 != 0) { return -1; } + + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { + return -1; } - // Dynamic M,N search via hmx_compute_chunks - const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); - const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32) - + K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles) - const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant) - + K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles) - const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary) - // Alignment margin: hex_align_up can add up to 2047 bytes per buffer; - // scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin - const size_t align_margin = 4 * HMX_FP16_TILE_SIZE; - const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment + // --- Dynamic VTCM layout --- + const size_t vtcm_budget = ctx->vtcm_size; + const size_t vec_dot_size = k * sizeof(__fp16); - size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used; - // Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost. - // From profiling: wt_dequant per element ≈ 1.5× activation load per element. - // m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive). - // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). - const size_t m_block_cost = (size_t) n * 3; - const size_t n_block_cost = (size_t) m * 2; - if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE, - &N_BLOCK_SIZE, &vtcm_used) != 0) { + // DMA-based activation gather for strided tensors (see batched path comment). + const bool use_dma_activation = (act_stride > k); + const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + // FP16 weight: interleave and activation load have similar per-element cost. + if (hmx_compute_chunks(vtcm_budget, + /*overhead=*/256, + /*per_n=*/3 * vec_dot_size, // W + S0 + S1 + /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch + /*per_mn=*/sizeof(__fp16), // O + m, n, + /*m_block_cost=*/(size_t) n, + /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); return -1; } - // Compute precise buffer sizes from searched M,N and fixed K - const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE); - const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; - const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; - if (total_vtcm > vtcm_budget) { - FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm, - vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE); + // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch] + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { + FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); return -1; } - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size); - uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz); - uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz); - __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type, - M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); + FARF(HIGH, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, + (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); - // initialize eye tile (32x32 identity matrix) - { - HVX_Vector v; - v = Q6_V_vzero(); - v = Q6_Vw_vinsert_VwR(v, 0x3c000000); - v = Q6_V_vror_VR(v, VLEN - 4); - v = Q6_Vw_vinsert_VwR(v, 0x00003c00); - for (int i = 0; i < 16; ++i) { - ((HVX_Vector *) vtcm_eye_tile)[i] = v; - v = Q6_V_vror_VR(v, VLEN - 8); - } - } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); - TIMER_DEFINE(fetch); - TIMER_DEFINE(act_load); - TIMER_DEFINE(wt_dequant); - TIMER_DEFINE(core); + TIMER_DEFINE(total); + TIMER_START(total); HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) { - size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE); - for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) { - size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE); + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + // transfer activation matrix chunk into VTCM + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS); - const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); + TIMER_START(activation_load); + { + const float *activation_chunk = activation + mr * act_stride; + if (use_dma_activation) { + const size_t row_bytes = (size_t) k * sizeof(float); + const size_t stride_bytes = (size_t) act_stride * sizeof(float); + dma_queue_push(ctx->dma[0], + dma_make_ptr(vtcm_f32_act, activation_chunk), + row_bytes, stride_bytes, row_bytes, n_rows); + dma_queue_pop(ctx->dma[0]); + transfer_activation_chunk_threaded(ctx, vtcm_activation, + vtcm_f32_act, n_rows, k, k); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_activation, + activation_chunk, n_rows, k, act_stride); + } + } + TIMER_STOP(activation_load); - for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { - const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); + const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16); - TIMER_START(fetch); - // fetch activation block into VTCM - { - const float *activation_block = x + mr * k + kk; + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_scratch1, activation_block), - k_blk_sz * sizeof(float), - k * sizeof(float), - k_blk_sz * sizeof(float), - m_blk_sz); - } + // issue async DMA for the first weight chunk + // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow. + // The source rows can be strided (e.g. KV-cache K after ggml_permute). + { + const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - // fetch weight block into VTCM (x4x2 sub-block: quants + scales) - const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); - { - qweight_fetch_task_state_t s; - - const int blk_start = kk / QK_Q4_0x4x2; - const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; - const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); - const int scale_blk_size = - (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; - - s.dst = vtcm_scratch0; - s.src = w + nc * row_stride; - s.n_rows = n_blk_sz; - s.src_stride = row_stride; - s.dst_stride = sub_row_stride; - s.quant_off = - (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2)); - s.quant_width = - (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2)); - s.scale_off = full_qrow + blk_start * scale_blk_size; - s.scale_width = nb_sub * scale_blk_size; + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } - // 2D DMA: quants sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst, s.src + s.quant_off), - s.dst_stride, s.src_stride, s.quant_width, s.n_rows); - // 2D DMA: scales sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(s.dst + s.quant_width, s.src + s.scale_off), - s.dst_stride, s.src_stride, s.scale_width, s.n_rows); - } - TIMER_STOP(fetch); + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - TIMER_START(act_load); - // load activation block - { - dma_queue_pop(ctx->dma[0]); // wait for act DNA - transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz); - } - TIMER_STOP(act_load); + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready - TIMER_START(wt_dequant); - // dequantize weight block - { - dma_queue_pop(ctx->dma[0]); - dma_queue_pop(ctx->dma[0]); - // vtcm_scratch0 is used to store the qweight chunk - // worker_pool_run_func already returned, so fetch is done - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, - n_blk_sz, k_blk_sz, sub_row_stride, weight_type); - } - TIMER_STOP(wt_dequant); + // issue async DMA for the next weight chunk (double buffering) + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < n) { + const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; - // core mma - TIMER_START(core); - { - core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles, - n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); } - TIMER_STOP(core); + + // interleave row-major fp16 from scratch into tile-major in vtcm_weight + hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, k, k, 0, n_cols); + + hex_swap_ptr(&buf_curr, &buf_next); } + TIMER_STOP(weight_load); - // store output block + TIMER_START(hmx_core); { - float *output_block = out + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n); + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); } + TIMER_STOP(output_store); } + } HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + TIMER_STOP(total); + #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us", - TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core)); + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n); + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); + { + size_t weight_size = (size_t)k * n * sizeof(__fp16); + float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); + FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); + } #endif + return 0; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index fb95d36f5a9..1c78ffadd1c 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -61,6 +61,9 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, int m, int k, int n, int weight_type); +// HMX flash attention +int hmx_flash_attn_ext(struct htp_ops_context * octx); + #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-hexagon/htp/hmx-utils.h b/ggml/src/ggml-hexagon/htp/hmx-utils.h index af04619cebb..68f174d6937 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hmx-utils.h @@ -4,6 +4,9 @@ #ifndef HMX_UTILS_H #define HMX_UTILS_H +#include "hvx-base.h" + +#include #include #include @@ -12,21 +15,188 @@ #define HMX_FP16_TILE_N_ELMS 1024 #define HMX_FP16_TILE_SIZE 2048 -#define HMX_INLINE_ALWAYS inline __attribute__((unused, always_inline)) - // Initialise aligned 256-byte area with scale vector + zero padding. -static HMX_INLINE_ALWAYS void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) { - HVX_Vector *pv = (HVX_Vector *)out_scales; - *pv++ = v_scale; - *pv = Q6_V_vzero(); +static inline void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) { + volatile HVX_Vector *pv = (HVX_Vector *) out_scales; + pv[0] = v_scale; + pv[1] = Q6_V_vzero(); +} + +// --- Shared scatter offsets and interleave helper --- + +// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile. +// word[i] = i*128 maps K-row-pair i to byte offset i*128. +// Column offset (n*4) is added at runtime. Entries 0..15 cover one tile (region 2047); +// entries 16..31 cover the next adjacent tile (region 4095) — pick region size at the +// call site to scatter into one tile (masked) or two contiguous tiles (unmasked). +static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = { + 0 * 128, 1 * 128, 2 * 128, 3 * 128, 4 * 128, 5 * 128, 6 * 128, 7 * 128, 8 * 128, 9 * 128, 10 * 128, + 11 * 128, 12 * 128, 13 * 128, 14 * 128, 15 * 128, 16 * 128, 17 * 128, 18 * 128, 19 * 128, 20 * 128, 21 * 128, + 22 * 128, 23 * 128, 24 * 128, 25 * 128, 26 * 128, 27 * 128, 28 * 128, 29 * 128, 30 * 128, 31 * 128, +}; + +// Scatter row-major FP16 data (in VTCM scratch) into transposed [K][N] tiles. +// vtcm_src: [n_cols][src_stride] row-major fp16 (only first k elements per row are used) +// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16 +// Processes rows [start_row, end_row) for multi-thread slicing. +// Full range: start_row=0, end_row=n_cols. +static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst, + const __fp16 * restrict vtcm_src, + int n_cols, + int k, + int src_stride, + int start_row, + int end_row) { + assert(k % HMX_FP16_TILE_N_COLS == 0); + + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + // Each hvx_vmemu load brings 64 fp16 = 128 bytes covering 2 adjacent K-tiles. + // When n_k_tiles is even, scatter into 2 K-tiles per call (region 4095, no mask) + // using the upper half of hmx_transpose_scatter_offsets. Tail one K-tile (when + // n_k_tiles is odd) falls back to single-tile masked scatter. + const bool pair_scatter = (n_k_tiles & 1) == 0; + const size_t pair_region = (size_t) (2 * HMX_FP16_TILE_SIZE - 1); + const size_t single_region = (size_t) (HMX_FP16_TILE_SIZE - 1); + __builtin_assume(k > 0); + __builtin_assume(end_row > start_row); + + if (pair_scatter) { + // Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter. + const int c_step = 2 * HMX_FP16_TILE_N_COLS; + const size_t c_byte_step = (size_t) c_step * sizeof(__fp16); + const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS; + const int n_c_iters = k / c_step; + + for (int r = start_row; r < end_row; r += 2) { + const int ct = r / HMX_FP16_TILE_N_ROWS; + const int local_r = r % HMX_FP16_TILE_N_ROWS; + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols; + const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); + const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); + + __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; + const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); + const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + + if (p1) { + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + HVX_Vector v1 = hvx_vmemu(p1); + p1 += c_byte_step; + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0); + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1); + tile_base += dst_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0); + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero); + tile_base += dst_step; + } + } + } + } else { + // Fallback: scatter one K-tile per call (region 2047, masked). + const int c_step = HMX_FP16_TILE_N_COLS; + const size_t c_byte_step = (size_t) c_step * sizeof(__fp16); + const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS; + const int n_c_iters = k / c_step; + + for (int r = start_row; r < end_row; r += 2) { + const int ct = r / HMX_FP16_TILE_N_ROWS; + const int local_r = r % HMX_FP16_TILE_N_ROWS; + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols; + const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); + const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); + + __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; + const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); + const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + + if (p1) { + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + HVX_Vector v1 = hvx_vmemu(p1); + p1 += c_byte_step; + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1); + tile_base += dst_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); + p0 += c_byte_step; + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero); + tile_base += dst_step; + } + } + } + } } -// --- VTCM sequential allocator (from htp-ops-lib/include/dsp/vtcm_mgr.h) --- +// Interleave row-major FP16 data into column-major tile format. +// Input: [n_rows, head_dim] row-major. Output: tile[dim_tile][row_tile]. +// Processes rows [start_row, end_row) for multi-thread slicing. +// Full range: start_row=0, end_row=n_rows. +static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out, + const __fp16 * restrict src, + int n_rows, + int head_dim, + int src_stride, + int n_row_tiles, + int start_row, + int end_row) { + __builtin_assume(head_dim > 0); + const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS; + + for (int r = start_row; r < end_row; r += 2) { + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows; + + const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride); + const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL; + + // Row-pair invariants hoisted out of the c loop. + const int r0 = r / HMX_FP16_TILE_N_ROWS; + const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2; + + // tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0). + // Each c step (+= 64) advances both by 2 dim-tiles worth of fp16. + __fp16 * tb0 = tiles_out + (size_t) r0 * HMX_FP16_TILE_N_ELMS; + __fp16 * tb1 = tb0 + tile_stride_elms; + const size_t tb_step = 2 * tile_stride_elms; -static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) { - uint8_t *p = *vtcm_ptr; - *vtcm_ptr += size; - return p; + if (pv_in1) { + for (int c = 0; c < head_dim; c += 64) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = *pv_in1++; + HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2); + ((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp); + ((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp); + tb0 += tb_step; + tb1 += tb_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int c = 0; c < head_dim; c += 64) { + HVX_Vector v0 = *pv_in0++; + HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2); + ((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp); + ((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp); + tb0 += tb_step; + tb1 += tb_step; + } + } + } } #endif // HMX_UTILS_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index d0926dedd28..f6cb02951d0 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -77,6 +77,12 @@ static inline int32_t hvx_vec_get_i32(HVX_Vector v) { return x; } +static inline _Float16 hvx_vec_get_f16(HVX_Vector v) { + _Float16 __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 2, v); + return x; +} + static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) { // abs by clearing the fp16 sign bit HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff); diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h index 851482e01b2..a3e33c3b3af 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-copy.h +++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h @@ -7,7 +7,8 @@ #include "hvx-base.h" -#define hvx_splat_loop_body(dst_type, vec_store) \ +#define hvx_splat_pragma(x) _Pragma(#x) +#define hvx_splat_loop_body(dst_type, vec_store, unroll_cnt) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ \ @@ -16,7 +17,7 @@ \ uint32_t i = 0; \ \ - _Pragma("unroll(4)") \ + hvx_splat_pragma(unroll(unroll_cnt)) \ for (; i < nvec; i++) { \ vdst[i] = src; \ } \ @@ -25,31 +26,47 @@ } \ } while(0) -static inline void hvx_splat_a(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { +static inline void hvx_splat_a(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { assert((unsigned long) dst % 128 == 0); - hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a); + hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a, 4); } -static inline void hvx_splat_u(uint8_t * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { - hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u); +static inline void hvx_splat_u(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { + hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u, 4); } -static inline void hvx_splat_f32_a(uint8_t * restrict dst, float v, uint32_t n) { +static inline void hvx_splat_f32_a(void * restrict dst, float v, uint32_t n) { hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float)); } -static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) { +static inline void hvx_splat_f32_u(void * restrict dst, float v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float)); } -static inline void hvx_splat_f16_a(uint8_t * restrict dst, _Float16 v, uint32_t n) { +static inline void hvx_splat_f16_a(void * restrict dst, _Float16 v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); } -static inline void hvx_splat_f16_u(uint8_t * restrict dst, _Float16 v, uint32_t n) { +static inline void hvx_splat_f16_u(void * restrict dst, _Float16 v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); } +static inline void hvx_splat_u16_a(void * restrict dst, uint16_t v, uint32_t n) { + hvx_splat_a(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t)); +} + +static inline void hvx_splat_u16_u(void * restrict dst, uint16_t v, uint32_t n) { + hvx_splat_u(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t)); +} + +static inline void hvx_splat_u8_a(void * restrict dst, uint8_t v, uint32_t n) { + hvx_splat_a(dst, Q6_Vb_vsplat_R(v), n, 1); +} + +static inline void hvx_splat_u8_u(void * restrict dst, uint8_t v, uint32_t n) { + hvx_splat_u(dst, Q6_Vb_vsplat_R(v), n, 1); +} + #define hvx_copy_loop_body(dst_type, src_type, vec_store) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ diff --git a/ggml/src/ggml-hexagon/htp/vtcm-utils.h b/ggml/src/ggml-hexagon/htp/vtcm-utils.h new file mode 100644 index 00000000000..b129fb74e31 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/vtcm-utils.h @@ -0,0 +1,16 @@ +#ifndef VTCM_UTILS_H +#define VTCM_UTILS_H + +#include "hex-utils.h" + +#include +#include +#include + +static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) { + uint8_t *p = *vtcm_ptr; + *vtcm_ptr += size; + return p; +} + +#endif // VTCM_UTILS_H From 28f8534532f5be51fa6bc0a27c30e0dbecc9769f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 May 2026 08:45:46 +0300 Subject: [PATCH 537/831] ggml : bump version to 0.10.2 (ggml/1474) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index f7b6f1f334f..c97f681988b 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 10) -set(GGML_VERSION_PATCH 1) +set(GGML_VERSION_PATCH 2) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From a5a8496d31ef1690ff2addc65b555916c3cf8895 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 May 2026 08:49:06 +0300 Subject: [PATCH 538/831] ggml : remove obsoloete wgsl templates (ggml/0) --- .../ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl | 107 ------ .../ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl | 323 ---------------- .../ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl | 295 --------------- .../wgsl-shaders/soft_max.tmpl.wgsl | 345 ------------------ 4 files changed, 1070 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl deleted file mode 100644 index b5e93b812fd..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +++ /dev/null @@ -1,107 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "f32" - } - }, - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "i32" - } - }, - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "f16" - } - }, - { - "REPLS": { - "SRC_TYPE": "f16", - "DST_TYPE": "f16" - } - }, - { - "REPLS": { - "SRC_TYPE": "f16", - "DST_TYPE": "f32" - } - } -] - -#end(VARIANTS) - -#define(SHADER) -enable f16; - -@group(0) @binding(0) -var src: array<{{SRC_TYPE}}>; - -@group(0) @binding(1) -var dst: array<{{DST_TYPE}}>; - -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) — may be permuted - stride_src0: u32, - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst0: u32, - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Logical shapes - src_ne0: u32, - src_ne1: u32, - src_ne2: u32, - - dst_ne0: u32, - dst_ne1: u32, - dst_ne2: u32 -}; - -@group(0) @binding(2) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); - i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); - let i2 = i / (params.src_ne1 * params.src_ne0); - i = i % (params.src_ne1 * params.src_ne0); - let i1 = i / params.src_ne0; - let i0 = i % params.src_ne0; - - var j = gid.x; - let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); - j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); - let j2 = j / (params.dst_ne1 * params.dst_ne0); - j = j % (params.dst_ne1 * params.dst_ne0); - let j1 = j / params.dst_ne0; - let j0 = j % params.dst_ne0; - - let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + - i2 * params.stride_src2 + i3 * params.stride_src3; - - let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + - j2 * params.stride_dst2 + j3 * params.stride_dst3; - - dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx])); -} -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl deleted file mode 100644 index 03fcd548689..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +++ /dev/null @@ -1,323 +0,0 @@ -#define(VARIANTS) - -[ - { - "SHADER_NAME": "reglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "geglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "swiglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_oai_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "SWIGLU_OAI"] - }, - { - "SHADER_NAME": "swiglu_oai_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "SWIGLU_OAI"] - }, - { - "SHADER_NAME": "geglu_erf_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_quick_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU_QUICK"] - }, -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(REGLU) -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return max(a, 0) * b; -} -#enddecl(REGLU) - -#decl(GEGLU) -const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876; -const GELU_COEF_A: {{TYPE}} = 0.044715; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); - return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b; -} -#enddecl(GEGLU) - -#decl(SWIGLU) -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return a / (1.0 + exp(-a)) * b; -} -#enddecl(SWIGLU) - -#decl(SWIGLU_OAI) -fn op(a: f32, b: f32) -> f32 { - let xi = min(a, params.limit); - let gi = max(min(b, params.limit), -params.limit); - var out_glu = xi / (1.0 + exp(-xi * params.alpha)); - out_glu = out_glu * (1.0 + gi); - return out_glu; -} -#enddecl(SWIGLU_OAI) - -#decl(GEGLU_ERF) -const p_erf: {{TYPE}} = 0.3275911; -const a1_erf: {{TYPE}} = 0.254829592; -const a2_erf: {{TYPE}} = -0.284496736; -const a3_erf: {{TYPE}} = 1.421413741; -const a4_erf: {{TYPE}} = -1.453152027; -const a5_erf: {{TYPE}} = 1.061405429; -const SQRT_2_INV: {{TYPE}} = 0.7071067811865476; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - let a_div_sqr2 = a * SQRT_2_INV; - let sign_x = sign(a_div_sqr2); - let x = abs(a_div_sqr2); - let t = 1.0 / (1.0 + p_erf * x); - let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); - let erf_approx = sign_x * y; - return 0.5 * a * (1.0 + erf_approx) * b; -} -#enddecl(GEGLU_ERF) - -#decl(GEGLU_QUICK) -const GELU_QUICK_COEF: {{TYPE}} = -1.702; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b; -} -#enddecl(GEGLU_QUICK) - -#decl(NO_SPLIT) -@group(0) @binding(1) -var dst: array<{{TYPE}}>; - -@group(0) @binding(2) -var params: Params; - -fn a_value(base: u32) -> {{TYPE}} { - let offset: u32 = select(0, params.ne0, params.swapped != 0); - return src0[base + offset]; -} - -fn b_value(base: u32) -> {{TYPE}} { - let offset: u32 = select(params.ne0, 0, params.swapped != 0); - return src0[base + offset]; -} -#enddecl(NO_SPLIT) - -#decl(SPLIT) -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -fn a_value(base: u32) -> {{TYPE}} { - return src0[base]; -} - -fn b_value(base: u32) -> {{TYPE}} { - return src1[base]; -} -#enddecl(SPLIT) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -struct Params { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src01: u32, - stride_src02: u32, - stride_src03: u32, - - stride_src11: u32, - stride_src12: u32, - stride_src13: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // shape of dst - ne: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - swapped: u32, - alpha: f32, - limit: f32, -} - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; - let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; - let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; - - dst[i_dst] = op(a_value(i_a), b_value(i_b)); -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl deleted file mode 100644 index 84dc8dbff61..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +++ /dev/null @@ -1,295 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f32_inplace", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] - }, - { - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f16_inplace", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] - }, - { - "SHADER_SUFFIX": "f32_ff", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f32_ff_inplace", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] - }, - { - "SHADER_SUFFIX": "f16_ff", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f16_ff_inplace", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(ROTATE) -fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { - dst[i_dst0] = {{TYPE}}(out0); - dst[i_dst1] = {{TYPE}}(out1); -} -#enddecl(ROTATE) - -#decl(ROTATE_INPLACE) -fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { - src0[i_dst0] = {{TYPE}}(out0); - src0[i_dst1] = {{TYPE}}(out1); -} -#enddecl(ROTATE_INPLACE) - -#decl(NO_FF_FUNC) -fn freq_factor(i: u32) -> f32 { - return 1.0f; -} -#enddecl(NO_FF_FUNC) - -#decl(FF_FUNC) -fn freq_factor(i: u32) -> f32 { - return src2[params.offset_src2 + i/2]; -} -#enddecl(FF_FUNC) - -#decl(NO_FF_BINDINGS) - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -#enddecl(NO_FF_BINDINGS) - -#decl(NO_FF_BINDINGS_INPLACE) - -@group(0) @binding(2) -var params: Params; - -#enddecl(NO_FF_BINDINGS_INPLACE) - -#decl(FF_BINDINGS) - -@group(0) @binding(2) -var src2: array; - -@group(0) @binding(3) -var dst: array<{{TYPE}}>; - -@group(0) @binding(4) -var params: Params; - -#enddecl(FF_BINDINGS) - -#decl(FF_BINDINGS_INPLACE) - -@group(0) @binding(2) -var src2: array; - -@group(0) @binding(3) -var params: Params; - -#enddecl(FF_BINDINGS_INPLACE) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -struct Params { - offset_src0: u32, - offset_src1: u32, - offset_src2: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src01: u32, - stride_src02: u32, - stride_src03: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - n_threads: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - n_dims: u32, - mode: u32, - theta_scale: f32, - attn_factor: f32, - freq_scale: f32, - ext_factor: f32, - corr_dim0: f32, - corr_dim1: f32, - sections0: u32, - sections1: u32, - sections2: u32, - sections3: u32 -}; - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array; - -DECLS - -fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 { - let y = (f32(i / 2) - low) / max(0.001f, high - low); - return 1.0f - min(1.0f, max(0.0f, y)); -} - -// returns vector of (cos_theta, sin_theta) -// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row -fn rope_yarn(theta_extrap: f32, i: u32) -> vec2 { - var mscale = params.attn_factor; - var theta = params.freq_scale * theta_extrap; - if (params.ext_factor != 0.0f) { - let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor; - theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix; - mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale); - } - return vec2(cos(theta) * mscale, sin(theta) * mscale); -} - -fn pair_base(i0: u32, div_2: bool) -> u32 { - if (div_2) { - return i0 / 2; - } else { - return i0; - } -} - -fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 { - if (is_vision) { - return params.n_dims; - } else if (is_neox || is_mrope) { - return params.n_dims / 2; - } else { - return 1; - } -} - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - // two elements per thread - if (gid.x >= params.n_threads) { - return; - } - - let is_neox = bool(params.mode & 2); - let is_mrope = bool(params.mode & 8); - let is_imrope = params.mode == 40; - let is_vision = params.mode == 24; - - var i = gid.x * 2; // start index for this thread - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; - let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - - if (i0 >= params.n_dims && !is_vision) { - let i_src = i_src_row + i0; - let i_dst = i_dst_row + i0; - rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1])); - return; - } - - var theta_base_mult: u32 = 0; - var theta_scale_pwr: u32 = i0 / 2; - if (is_mrope) { - let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3; - let sec_w = params.sections1 + params.sections0; - let sec_e = params.sections2 + sec_w; - let sector = (i0 / 2) % sect_dims; - if (is_imrope) { - if (sector % 3 == 1 && sector < 3 * params.sections1) { - theta_base_mult = 1; - } else if (sector % 3 == 2 && sector < 3 * params.sections2) { - theta_base_mult = 2; - } else if (sector % 3 == 0 && sector < 3 * params.sections0) { - theta_base_mult = 0; - } else { - theta_base_mult = 3; - } - } else { - if (sector >= params.sections0 && sector < sec_w) { - theta_base_mult = 1; - if (is_vision) { - theta_scale_pwr = sector - params.sections0; - } - } else if (sector >= sec_w && sector < sec_e) { - theta_base_mult = 2; - if (is_vision) { - theta_scale_pwr = sector - sec_w; - } - } else if (sector >= sec_e) { - if (is_vision) { - theta_scale_pwr = sector - sec_e; - theta_scale_pwr = (i0 / 2) % sec_e; - } - theta_base_mult = 3; - } else if (is_vision) { - theta_scale_pwr = sector; - } - } - } - let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr)); - let thetas = rope_yarn(theta_base/freq_factor(i0), i0); - - let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision); - let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision); - - let x0 = f32(src0[i_src]); - let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]); - rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x); -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl deleted file mode 100644 index c74dc4cc923..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +++ /dev/null @@ -1,345 +0,0 @@ -#define(VARIANTS) -[ - { - "SHADER_NAME": "soft_max_f32", - "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_inplace", - "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_sink", - "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_sink_inplace", - "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_inplace", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_inplace", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_sink", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_sink", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] - } -] -#end(VARIANTS) - -#define(DECLS) - -#decl(BASE_BINDINGS) -@group(0) @binding(1) -var dst: array; - -@group(0) @binding(2) -var params: Params; -#enddecl(BASE_BINDINGS) - -#decl(BASE_BINDINGS_INPLACE) -@group(0) @binding(1) -var params: Params; -#enddecl(BASE_BINDINGS_INPLACE) - -#decl(SINK_BINDINGS) -@group(0) @binding(1) -var sinks: array; - -@group(0) @binding(2) -var dst: array; - -@group(0) @binding(3) -var params: Params; -#enddecl(SINK_BINDINGS) - -#decl(SINK_BINDINGS_INPLACE) -@group(0) @binding(1) -var sinks: array; - -@group(0) @binding(2) -var params: Params; -#enddecl(SINK_BINDINGS_INPLACE) - -#decl(MASK_BINDINGS) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var dst: array; - -@group(0) @binding(3) -var params: Params; -#enddecl(MASK_BINDINGS) - -#decl(MASK_BINDINGS_INPLACE) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var params: Params; -#enddecl(MASK_BINDINGS_INPLACE) - -#decl(MASK_SINK_BINDINGS) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var sinks: array; - -@group(0) @binding(3) -var dst: array; - -@group(0) @binding(4) -var params: Params; -#enddecl(MASK_SINK_BINDINGS) - -#decl(MASK_SINK_BINDINGS_INPLACE) -@group(0) @binding(1) -var mask: array<{{MASK_TYPE}}>; - -@group(0) @binding(2) -var sinks: array; - -@group(0) @binding(3) -var params: Params; -#enddecl(MASK_SINK_BINDINGS_INPLACE) - -#decl(NOT_INPLACE) -fn inter_value(i: u32) -> f32 { - return dst[i]; -} - -fn update(i: u32, val: f32) { - dst[i] = val; -} -#enddecl(NOT_INPLACE) - -#decl(INPLACE) -fn inter_value(i: u32) -> f32 { - return src[i]; -} - -fn update(i: u32, val: f32) { - src[i] = val; -} -#enddecl(INPLACE) - -#decl(NO_MASK) -fn mask_val(i: u32) -> f32 { - return 0.0; -} -#enddecl(NO_MASK) - -#decl(MASK) -fn mask_val(i: u32) -> f32 { - return f32(mask[i]); -} -#enddecl(MASK) - -#decl(NO_SINK) -fn lower_max_bound(i2: u32) -> f32 { - return -1e30; -} - -fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { - return val; -} -#enddecl(NO_SINK) - -#decl(SINK) -fn lower_max_bound(i2: u32) -> f32 { - return sinks[params.offset_sinks + i2]; -} - -fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { - return val + exp(sinks[params.offset_sinks + i2] - max_val); -} -#enddecl(SINK) - -#end(DECLS) - -#define(SHADER) -enable f16; - -struct Params { - offset_src0: u32, - offset_src1: u32, - offset_sinks: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src01: u32, - stride_src02: u32, - stride_src03: u32, - - stride_src11: u32, - stride_src12: u32, - stride_src13: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // shape of src0/dst - ne: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - // shape of src1 - ne12: u32, - ne13: u32, - - scale: f32, - max_bias: f32, - n_head_log2: f32, - m0: f32, - m1: f32, -}; - -@group(0) @binding(0) -var src: array; - -DECLS - -const CACHE_SIZE: u32 = 16; - -override wg_size: u32; -var scratch: array; - -@compute @workgroup_size(wg_size) -fn main(@builtin(workgroup_id) wid: vec3, - @builtin(local_invocation_id) lid: vec3) { - - var i = wid.x; - let i3 = i / (params.ne2 * params.ne1); - i = i % (params.ne2 * params.ne1); - let i2 = i / params.ne1; - let i1 = i % params.ne1; - let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; - let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11; - let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - let elems = (params.ne0 + wg_size - 1) / wg_size; - - let head = f32(i2); - let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0); - - var cache: array; - - var max_val = lower_max_bound(i2); - var col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col); - max_val = max(max_val, val); - if (col < CACHE_SIZE) { - cache[col] = val; - } - col += wg_size; - } - - scratch[lid.x] = max_val; - workgroupBarrier(); - var offset = wg_size / 2; - while (offset > 0) { - if (lid.x < offset) { - scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]); - } - offset = offset / 2; - workgroupBarrier(); - } - let row_max = scratch[0]; - workgroupBarrier(); - - var sum = 0.0f; - col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - let val = select(src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col), - cache[col], col < CACHE_SIZE); - let ex = exp(val - row_max); - sum += ex; - if (col < CACHE_SIZE) { - cache[col] = ex; - } else { - update(i_dst_row + col, ex); - } - col += wg_size; - } - - scratch[lid.x] = sum; - workgroupBarrier(); - offset = wg_size / 2; - while (offset > 0) { - if (lid.x < offset) { - scratch[lid.x] += scratch[lid.x + offset]; - } - offset = offset / 2; - workgroupBarrier(); - } - let row_sum = add_sinks(scratch[0], i2, row_max); - - let sum_recip = 1.0 / row_sum; - col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip); - col += wg_size; - } -} -#end(SHADER) From bbdaa21aa7d301675f5cf7fd87f8c0b8c272dd29 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 May 2026 08:51:39 +0300 Subject: [PATCH 539/831] ggml : remove obsolete rms_norm.wgsl (ggml/0) --- .../ggml-webgpu/wgsl-shaders/rms_norm.wgsl | 123 ------------------ 1 file changed, 123 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl deleted file mode 100644 index 712b921f1ab..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ /dev/null @@ -1,123 +0,0 @@ -#define(VARIANTS) - -[ - { - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_SUFFIX": "inplace", - "DECLS": ["INPLACE"] - }, -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(NOT_INPLACE) - -fn update(src_offset: u32, dst_offset: u32, scale: f32) { - dst[dst_offset] = scale * src[src_offset]; -} - -@group(0) @binding(1) -var dst: array; - -@group(0) @binding(2) -var params: Params; - -#enddecl(NOT_INPLACE) - -#decl(INPLACE) - -fn update(src_offset: u32, dst_offset: u32, scale: f32) { - src[dst_offset] = scale * src[src_offset]; -} - -@group(0) @binding(1) -var params: Params; - -#enddecl(INPLACE) - -#end(DECLS) - -#define(SHADER) - -struct Params { - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Shape of src/dst - ne0: u32, - ne1: u32, - ne2: u32, - ne3: u32, - - eps: f32 -}; - -@group(0) @binding(0) -var src: array; - -DECLS - -override wg_size: u32; -var scratch: array; - -@compute @workgroup_size(wg_size) -fn main(@builtin(workgroup_id) wid: vec3, - @builtin(local_invocation_id) lid: vec3) { - - // one thread per row - var i = wid.x; - let i3 = i / (params.ne2 * params.ne1); - i = i % (params.ne2 * params.ne1); - let i2 = i / params.ne1; - let i1 = i % params.ne1; - let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; - let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - - let elems = (params.ne0 + wg_size - 1) / wg_size; - - var sum = 0.0f; - var col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - sum += pow(src[i_src_row + col], 2.0); - col += wg_size; - } - - scratch[lid.x] = sum; - workgroupBarrier(); - var offset = wg_size / 2; - while (offset > 0) { - if (lid.x < offset) { - scratch[lid.x] += scratch[lid.x + offset]; - } - offset = offset / 2; - workgroupBarrier(); - } - sum = scratch[0]; - - let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); - col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - update(i_src_row + col, i_dst_row + col, scale); - col += wg_size; - } -} -#end(SHADER) From 8384aa8086714d6177f24eb5c409b39949efd2ce Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 May 2026 08:53:58 +0300 Subject: [PATCH 540/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index a03455e74c8..812e721a8c5 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -b70770970e84c30a007b3859a453768b3ece2d3d +19eac6f0edaf285506eb6228d31bb9caeda9aba1 From 18162bcf6120551cfd447c81a09a98c6ed3db675 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 May 2026 08:54:20 +0300 Subject: [PATCH 541/831] cmake : add FindNCCL.cmake (ggml/0) --- ggml/cmake/FindNCCL.cmake | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 ggml/cmake/FindNCCL.cmake diff --git a/ggml/cmake/FindNCCL.cmake b/ggml/cmake/FindNCCL.cmake new file mode 100644 index 00000000000..67511e2d56a --- /dev/null +++ b/ggml/cmake/FindNCCL.cmake @@ -0,0 +1,36 @@ +# cmake/FindNCCL.cmake + +# NVIDIA does not distribute CMake files with NCCl, therefore use this file to find it instead. + +find_path(NCCL_INCLUDE_DIR + NAMES nccl.h + HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda + PATH_SUFFIXES include +) + +find_library(NCCL_LIBRARY + NAMES nccl + HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda + PATH_SUFFIXES lib lib64 +) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL + DEFAULT_MSG + NCCL_LIBRARY NCCL_INCLUDE_DIR +) + +if(NCCL_FOUND) + set(NCCL_LIBRARIES ${NCCL_LIBRARY}) + set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR}) + + if(NOT TARGET NCCL::NCCL) + add_library(NCCL::NCCL UNKNOWN IMPORTED) + set_target_properties(NCCL::NCCL PROPERTIES + IMPORTED_LOCATION "${NCCL_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}" + ) + endif() +endif() + +mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY) From 4bf733672b2871d4153158af4f621a6dd9104f4a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 May 2026 09:01:24 +0300 Subject: [PATCH 542/831] talk-llama : sync llama.cpp --- examples/talk-llama/llama-adapter.cpp | 17 +- examples/talk-llama/llama-adapter.h | 4 +- examples/talk-llama/llama-arch.cpp | 2099 +---------------- examples/talk-llama/llama-arch.h | 20 +- examples/talk-llama/llama-batch.h | 2 +- examples/talk-llama/llama-chat.cpp | 51 +- examples/talk-llama/llama-chat.h | 5 +- examples/talk-llama/llama-context.cpp | 224 +- examples/talk-llama/llama-context.h | 14 +- examples/talk-llama/llama-ext.h | 84 +- examples/talk-llama/llama-grammar.cpp | 74 +- examples/talk-llama/llama-graph.cpp | 259 +- examples/talk-llama/llama-graph.h | 33 + examples/talk-llama/llama-hparams.h | 4 + examples/talk-llama/llama-impl.cpp | 2 +- examples/talk-llama/llama-kv-cache.cpp | 223 +- examples/talk-llama/llama-kv-cache.h | 33 +- .../talk-llama/llama-memory-hybrid-iswa.cpp | 6 +- examples/talk-llama/llama-memory-hybrid.cpp | 6 +- .../talk-llama/llama-memory-recurrent.cpp | 12 +- examples/talk-llama/llama-mmap.cpp | 35 +- examples/talk-llama/llama-mmap.h | 1 + examples/talk-llama/llama-model-loader.cpp | 46 +- examples/talk-llama/llama-model-loader.h | 1 + examples/talk-llama/llama-model-saver.cpp | 127 +- examples/talk-llama/llama-model-saver.h | 4 + examples/talk-llama/llama-model.cpp | 1489 +++++++----- examples/talk-llama/llama-model.h | 63 +- examples/talk-llama/llama-quant.cpp | 220 +- examples/talk-llama/llama-vocab.cpp | 148 +- examples/talk-llama/llama-vocab.h | 1 + examples/talk-llama/llama.cpp | 906 ++----- examples/talk-llama/llama.h | 82 +- examples/talk-llama/models/afmoe.cpp | 19 +- examples/talk-llama/models/apertus.cpp | 18 +- examples/talk-llama/models/arcee.cpp | 29 +- examples/talk-llama/models/arctic.cpp | 16 +- examples/talk-llama/models/baichuan.cpp | 17 +- examples/talk-llama/models/bailingmoe.cpp | 28 +- examples/talk-llama/models/bailingmoe2.cpp | 14 +- examples/talk-llama/models/bert.cpp | 38 +- examples/talk-llama/models/bitnet.cpp | 38 +- examples/talk-llama/models/bloom.cpp | 22 +- examples/talk-llama/models/chameleon.cpp | 28 +- examples/talk-llama/models/chatglm.cpp | 45 +- examples/talk-llama/models/codeshell.cpp | 14 +- examples/talk-llama/models/cogvlm.cpp | 10 +- examples/talk-llama/models/cohere2-iswa.cpp | 28 +- examples/talk-llama/models/command-r.cpp | 25 +- examples/talk-llama/models/dbrx.cpp | 18 +- examples/talk-llama/models/deci.cpp | 27 +- examples/talk-llama/models/deepseek.cpp | 25 +- examples/talk-llama/models/deepseek2.cpp | 40 +- examples/talk-llama/models/dots1.cpp | 16 +- examples/talk-llama/models/dream.cpp | 22 +- examples/talk-llama/models/ernie4-5-moe.cpp | 25 +- examples/talk-llama/models/ernie4-5.cpp | 25 +- examples/talk-llama/models/eurobert.cpp | 16 +- examples/talk-llama/models/exaone-moe.cpp | 16 +- examples/talk-llama/models/exaone.cpp | 27 +- examples/talk-llama/models/exaone4.cpp | 17 +- examples/talk-llama/models/falcon-h1.cpp | 17 +- examples/talk-llama/models/falcon.cpp | 12 +- .../talk-llama/models/gemma-embedding.cpp | 18 +- examples/talk-llama/models/gemma.cpp | 17 +- examples/talk-llama/models/gemma2-iswa.cpp | 16 +- examples/talk-llama/models/gemma3.cpp | 18 +- examples/talk-llama/models/gemma3n-iswa.cpp | 98 +- examples/talk-llama/models/gemma4-iswa.cpp | 322 +++ examples/talk-llama/models/glm4-moe.cpp | 25 +- examples/talk-llama/models/glm4.cpp | 41 +- examples/talk-llama/models/gpt2.cpp | 18 +- examples/talk-llama/models/gptneox.cpp | 15 +- examples/talk-llama/models/granite-hybrid.cpp | 28 +- examples/talk-llama/models/granite.cpp | 29 +- examples/talk-llama/models/grok.cpp | 25 +- examples/talk-llama/models/grovemoe.cpp | 16 +- examples/talk-llama/models/hunyuan-dense.cpp | 66 +- examples/talk-llama/models/hunyuan-moe.cpp | 25 +- examples/talk-llama/models/internlm2.cpp | 25 +- examples/talk-llama/models/jais.cpp | 28 +- examples/talk-llama/models/jais2.cpp | 23 +- examples/talk-llama/models/jamba.cpp | 19 +- examples/talk-llama/models/kimi-linear.cpp | 5 +- examples/talk-llama/models/lfm2.cpp | 17 +- examples/talk-llama/models/llada-moe.cpp | 16 +- examples/talk-llama/models/llada.cpp | 15 +- examples/talk-llama/models/llama.cpp | 28 +- .../models/{llama-iswa.cpp => llama4.cpp} | 41 +- examples/talk-llama/models/maincoder.cpp | 16 +- examples/talk-llama/models/mamba-base.cpp | 8 +- examples/talk-llama/models/mimo2-iswa.cpp | 2 +- examples/talk-llama/models/minicpm3.cpp | 2 +- examples/talk-llama/models/minimax-m2.cpp | 2 +- examples/talk-llama/models/mistral3.cpp | 25 +- examples/talk-llama/models/models.h | 36 +- examples/talk-llama/models/modern-bert.cpp | 17 +- examples/talk-llama/models/mpt.cpp | 26 +- examples/talk-llama/models/nemotron-h.cpp | 47 +- examples/talk-llama/models/nemotron.cpp | 25 +- examples/talk-llama/models/neo-bert.cpp | 16 +- examples/talk-llama/models/olmo.cpp | 25 +- examples/talk-llama/models/olmo2.cpp | 2 +- examples/talk-llama/models/olmoe.cpp | 2 +- .../talk-llama/models/openai-moe-iswa.cpp | 25 +- examples/talk-llama/models/openelm.cpp | 2 +- examples/talk-llama/models/orion.cpp | 28 +- examples/talk-llama/models/paddleocr.cpp | 25 +- examples/talk-llama/models/pangu-embedded.cpp | 20 +- examples/talk-llama/models/phi2.cpp | 29 +- examples/talk-llama/models/phi3.cpp | 26 +- examples/talk-llama/models/plamo.cpp | 16 +- examples/talk-llama/models/plamo2.cpp | 3 +- examples/talk-llama/models/plamo3.cpp | 4 +- examples/talk-llama/models/plm.cpp | 2 +- examples/talk-llama/models/qwen.cpp | 14 +- examples/talk-llama/models/qwen2.cpp | 28 +- examples/talk-llama/models/qwen2moe.cpp | 25 +- examples/talk-llama/models/qwen2vl.cpp | 19 +- examples/talk-llama/models/qwen3.cpp | 19 +- examples/talk-llama/models/qwen35.cpp | 12 +- examples/talk-llama/models/qwen35moe.cpp | 12 +- examples/talk-llama/models/qwen3moe.cpp | 19 +- examples/talk-llama/models/qwen3next.cpp | 25 +- examples/talk-llama/models/qwen3vl-moe.cpp | 16 +- examples/talk-llama/models/qwen3vl.cpp | 16 +- examples/talk-llama/models/refact.cpp | 16 +- examples/talk-llama/models/rnd1.cpp | 16 +- examples/talk-llama/models/rwkv6.cpp | 2 +- examples/talk-llama/models/rwkv7.cpp | 2 +- examples/talk-llama/models/seed-oss.cpp | 25 +- examples/talk-llama/models/smallthinker.cpp | 17 +- examples/talk-llama/models/smollm3.cpp | 25 +- examples/talk-llama/models/stablelm.cpp | 28 +- examples/talk-llama/models/starcoder.cpp | 18 +- examples/talk-llama/models/starcoder2.cpp | 25 +- examples/talk-llama/models/step35-iswa.cpp | 6 +- examples/talk-llama/models/t5-enc.cpp | 96 - .../talk-llama/models/{t5-dec.cpp => t5.cpp} | 116 +- examples/talk-llama/models/t5encoder.cpp | 3 + .../talk-llama/models/wavtokenizer-dec.cpp | 2 +- examples/talk-llama/models/xverse.cpp | 16 +- examples/talk-llama/unicode.cpp | 178 +- examples/talk-llama/unicode.h | 2 +- 144 files changed, 3675 insertions(+), 5535 deletions(-) create mode 100644 examples/talk-llama/models/gemma4-iswa.cpp rename examples/talk-llama/models/{llama-iswa.cpp => llama4.cpp} (81%) delete mode 100644 examples/talk-llama/models/t5-enc.cpp rename examples/talk-llama/models/{t5-dec.cpp => t5.cpp} (64%) create mode 100644 examples/talk-llama/models/t5encoder.cpp diff --git a/examples/talk-llama/llama-adapter.cpp b/examples/talk-llama/llama-adapter.cpp index d6a5800e63a..4a1aaa955a8 100644 --- a/examples/talk-llama/llama-adapter.cpp +++ b/examples/talk-llama/llama-adapter.cpp @@ -294,7 +294,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ } // get extra buffer types of the CPU - // TODO: a more general solution for non-CPU extra buft should be imlpemented in the future + // TODO: a more general solution for non-CPU extra buft should be implemented in the future // ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948 std::vector buft_extra; { @@ -418,7 +418,7 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_ } llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) { - llama_adapter_lora * adapter = new llama_adapter_lora(); + llama_adapter_lora * adapter = new llama_adapter_lora(model); try { llama_adapter_lora_init_impl(*model, path_lora, *adapter); @@ -471,8 +471,17 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, return snprintf(buf, buf_size, "%s", it->second.c_str()); } -void llama_adapter_lora_free(llama_adapter_lora *) { - // deprecated: adapters are freed by llama_model's destructor +void llama_adapter_lora_free(llama_adapter_lora * adapter) { + if (adapter == nullptr) { + return; + } + + if (adapter->model != nullptr) { + adapter->model->loras.erase(adapter); + adapter->model = nullptr; + } + + delete adapter; } uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) { diff --git a/examples/talk-llama/llama-adapter.h b/examples/talk-llama/llama-adapter.h index aa3ab63ad75..f0b1e50f816 100644 --- a/examples/talk-llama/llama-adapter.h +++ b/examples/talk-llama/llama-adapter.h @@ -61,6 +61,8 @@ struct llama_adapter_lora_weight { }; struct llama_adapter_lora { + llama_model * model = nullptr; + // map tensor name to lora_a_b std::unordered_map ab_map; @@ -75,7 +77,7 @@ struct llama_adapter_lora { // activated lora (aLoRA) std::vector alora_invocation_tokens; - llama_adapter_lora() = default; + explicit llama_adapter_lora(llama_model * model) : model(model) {} ~llama_adapter_lora() = default; llama_adapter_lora_weight * get_weight(ggml_tensor * w); diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 799d16167ba..633a66fc665 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -56,6 +56,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, + { LLM_ARCH_GEMMA4, "gemma4" }, { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, @@ -73,6 +74,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK, "deepseek" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, @@ -107,6 +109,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, { LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" }, + { LLM_ARCH_HUNYUAN_VL, "hunyuan_vl" }, { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, @@ -123,6 +126,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_MISTRAL4, "mistral4" }, { LLM_ARCH_PADDLEOCR, "paddleocr" }, { LLM_ARCH_MIMO2, "mimo2" }, { LLM_ARCH_STEP35, "step35" }, @@ -163,6 +167,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, { LLM_KV_EMBEDDING_LENGTH_OUT, "%s.embedding_length_out" }, + { LLM_KV_EMBEDDING_LENGTH_PER_LAYER, "%s.embedding_length_per_layer_input" }, { LLM_KV_FEATURES_LENGTH, "%s.features_length" }, { LLM_KV_BLOCK_COUNT, "%s.block_count" }, { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, @@ -236,6 +241,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" }, { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, + { LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" }, @@ -245,6 +251,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ALPHA, "%s.rope.scaling.alpha" }, { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, @@ -362,6 +369,9 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" }, { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + { LLM_TENSOR_FFN_POST_NORM_1, "blk.%d.post_ffw_norm_1" }, + { LLM_TENSOR_FFN_POST_NORM_2, "blk.%d.post_ffw_norm_2" }, + { LLM_TENSOR_FFN_PRE_NORM_2, "blk.%d.pre_ffw_norm_2" }, { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, @@ -371,6 +381,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_LAYER_OUT_SCALE, "blk.%d.layer_output_scale" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, { LLM_TENSOR_POS_EMBD, "position_embd" }, { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" }, @@ -538,2016 +549,6 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, }; -static std::set llm_get_tensor_names(llm_arch arch) { - switch (arch) { - case LLM_ARCH_CLIP: - return {}; - case LLM_ARCH_LLAMA: - case LLM_ARCH_DECI: - case LLM_ARCH_MISTRAL3: - case LLM_ARCH_LLAMA_EMBED: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_ARCEE: - case LLM_ARCH_STARCODER2: - case LLM_ARCH_NEMOTRON: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_AFMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_LLAMA4: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_BAICHUAN: - case LLM_ARCH_ORION: - case LLM_ARCH_XVERSE: - case LLM_ARCH_EXAONE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_FALCON: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GROK: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_ATTN_OUT_NORM, - }; - case LLM_ARCH_GPT2: - case LLM_ARCH_STARCODER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_POS_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_GPTNEOX: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_MPT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_ACT, - LLM_TENSOR_POS_EMBD, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - }; - case LLM_ARCH_REFACT: - case LLM_ARCH_QWEN2: - case LLM_ARCH_QWEN2VL: - case LLM_ARCH_INTERNLM2: - case LLM_ARCH_GRANITE: - case LLM_ARCH_ERNIE4_5: - case LLM_ARCH_PADDLEOCR: - case LLM_ARCH_SMOLLM3: - case LLM_ARCH_DREAM: - case LLM_ARCH_LLADA: - case LLM_ARCH_PANGU_EMBED: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_POS_EMBD, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_CLS, - LLM_TENSOR_CLS_OUT, - }; - case LLM_ARCH_NOMIC_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_NOMIC_BERT_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_NEO_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_ENC_OUTPUT_NORM, - LLM_TENSOR_CLS, - LLM_TENSOR_CLS_OUT, - }; - case LLM_ARCH_EUROBERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_MODERN_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_CLS, - LLM_TENSOR_CLS_OUT, - LLM_TENSOR_CLS_NORM, - }; - case LLM_ARCH_JINA_BERT_V2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_CLS, - }; - case LLM_ARCH_JINA_BERT_V3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_LAYER_OUT_NORM, - }; - case LLM_ARCH_BLOOM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_STABLELM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - }; - case LLM_ARCH_QWEN: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_QWEN2MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_QWEN3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_CLS_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_QWEN3MOE: - case LLM_ARCH_QWEN3VLMOE: - case LLM_ARCH_OLMOE: - case LLM_ARCH_LLADA_MOE: - case LLM_ARCH_RND1: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_QWEN3NEXT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_SSM_A_NOSCAN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_BETA_ALPHA, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_QWEN35: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SSM_A_NOSCAN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_BETA, - LLM_TENSOR_SSM_ALPHA, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_QWEN35MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_SSM_A_NOSCAN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_BETA, - LLM_TENSOR_SSM_ALPHA, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_QWEN3VL: - case LLM_ARCH_CHAMELEON: - case LLM_ARCH_HUNYUAN_DENSE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_CLS_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PHI2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PHI3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PHIMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_PLAMO: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PLAMO2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_X, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_SSM_DT_NORM, - LLM_TENSOR_SSM_B_NORM, - LLM_TENSOR_SSM_C_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_PLAMO3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_CODESHELL: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_MINICPM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - }; - case LLM_ARCH_MINICPM3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_A_NORM, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_A, - LLM_TENSOR_ATTN_Q_B, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_GEMMA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GEMMA2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_GEMMA3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_GEMMA3N: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_PER_LAYER_TOKEN_EMBD, - LLM_TENSOR_PER_LAYER_MODEL_PROJ, - LLM_TENSOR_PER_LAYER_PROJ_NORM, - LLM_TENSOR_ALTUP_UNEMBD_PROJ, - LLM_TENSOR_ALTUP_PROJ, - LLM_TENSOR_PER_LAYER_INP_GATE, - LLM_TENSOR_PER_LAYER_PROJ, - LLM_TENSOR_PER_LAYER_POST_NORM, - LLM_TENSOR_ALTUP_CORRECT_COEF, - LLM_TENSOR_ALTUP_CORRECT_SCALE, - LLM_TENSOR_ALTUP_PREDICT_COEF, - LLM_TENSOR_ALTUP_ROUTER, - LLM_TENSOR_ALTUP_ROUTER_NORM, - LLM_TENSOR_LAUREL_L, - LLM_TENSOR_LAUREL_R, - LLM_TENSOR_LAUREL_POST_NORM, - }; - case LLM_ARCH_GEMMA_EMBEDDING: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_DENSE_2_OUT, - LLM_TENSOR_DENSE_3_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_MAMBA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_X, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_MAMBA2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_JAMBA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_X, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_DT_NORM, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_B_NORM, - LLM_TENSOR_SSM_C_NORM, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_FALCON_H1: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_COMMAND_R: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - }; - case LLM_ARCH_COHERE2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_DBRX: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_OLMO: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_OLMO2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_OPENELM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_ARCTIC: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM_EXPS, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_DEEPSEEK: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_DEEPSEEK2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_A_NORM, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_A, - LLM_TENSOR_ATTN_Q_B, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_K_B, - LLM_TENSOR_ATTN_V_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_PLM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_CHATGLM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_GLM4: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - }; - case LLM_ARCH_GLM4_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - }; - case LLM_ARCH_GLM_DSA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_A_NORM, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_A, - LLM_TENSOR_ATTN_Q_B, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_K_B, - LLM_TENSOR_ATTN_V_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_INDEXER_K_NORM, - LLM_TENSOR_INDEXER_PROJ, - LLM_TENSOR_INDEXER_ATTN_K, - LLM_TENSOR_INDEXER_ATTN_Q_B, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - }; - case LLM_ARCH_BITNET: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_SUB_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_SUB_NORM, - }; - case LLM_ARCH_T5: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_DEC_OUTPUT_NORM, - LLM_TENSOR_DEC_ATTN_NORM, - LLM_TENSOR_DEC_ATTN_Q, - LLM_TENSOR_DEC_ATTN_K, - LLM_TENSOR_DEC_ATTN_V, - LLM_TENSOR_DEC_ATTN_OUT, - LLM_TENSOR_DEC_ATTN_REL_B, - LLM_TENSOR_DEC_CROSS_ATTN_NORM, - LLM_TENSOR_DEC_CROSS_ATTN_Q, - LLM_TENSOR_DEC_CROSS_ATTN_K, - LLM_TENSOR_DEC_CROSS_ATTN_V, - LLM_TENSOR_DEC_CROSS_ATTN_OUT, - LLM_TENSOR_DEC_CROSS_ATTN_REL_B, - LLM_TENSOR_DEC_FFN_NORM, - LLM_TENSOR_DEC_FFN_GATE, - LLM_TENSOR_DEC_FFN_DOWN, - LLM_TENSOR_DEC_FFN_UP, - LLM_TENSOR_ENC_OUTPUT_NORM, - LLM_TENSOR_ENC_ATTN_NORM, - LLM_TENSOR_ENC_ATTN_Q, - LLM_TENSOR_ENC_ATTN_K, - LLM_TENSOR_ENC_ATTN_V, - LLM_TENSOR_ENC_ATTN_OUT, - LLM_TENSOR_ENC_ATTN_REL_B, - LLM_TENSOR_ENC_FFN_NORM, - LLM_TENSOR_ENC_FFN_GATE, - LLM_TENSOR_ENC_FFN_DOWN, - LLM_TENSOR_ENC_FFN_UP, - }; - case LLM_ARCH_T5ENCODER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ENC_OUTPUT_NORM, - LLM_TENSOR_ENC_ATTN_NORM, - LLM_TENSOR_ENC_ATTN_Q, - LLM_TENSOR_ENC_ATTN_K, - LLM_TENSOR_ENC_ATTN_V, - LLM_TENSOR_ENC_ATTN_OUT, - LLM_TENSOR_ENC_ATTN_REL_B, - LLM_TENSOR_ENC_FFN_NORM, - LLM_TENSOR_ENC_FFN_GATE, - LLM_TENSOR_ENC_FFN_DOWN, - LLM_TENSOR_ENC_FFN_UP, - }; - case LLM_ARCH_JAIS: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_JAIS2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_NEMOTRON_H: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_NEMOTRON_H_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - // mamba(2) ssm layers - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - // attention layers - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - // dense FFN - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - // MoE FFN (for MoE layers) - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_FFN_LATENT_DOWN, - LLM_TENSOR_FFN_LATENT_UP, - // MoE shared expert layer - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_EXAONE4: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_EXAONE_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - }; - case LLM_ARCH_RWKV6: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_LERP_X, - LLM_TENSOR_TIME_MIX_LERP_W, - LLM_TENSOR_TIME_MIX_LERP_K, - LLM_TENSOR_TIME_MIX_LERP_V, - LLM_TENSOR_TIME_MIX_LERP_R, - LLM_TENSOR_TIME_MIX_LERP_G, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_FIRST, - LLM_TENSOR_TIME_MIX_DECAY, - LLM_TENSOR_TIME_MIX_DECAY_W1, - LLM_TENSOR_TIME_MIX_DECAY_W2, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_GATE, - LLM_TENSOR_TIME_MIX_LN, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_CHANNEL_MIX_LERP_K, - LLM_TENSOR_CHANNEL_MIX_LERP_R, - LLM_TENSOR_CHANNEL_MIX_KEY, - LLM_TENSOR_CHANNEL_MIX_VALUE, - LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, - }; - case LLM_ARCH_RWKV6QWEN2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_LERP_X, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_FIRST, - LLM_TENSOR_TIME_MIX_DECAY, - LLM_TENSOR_TIME_MIX_DECAY_W1, - LLM_TENSOR_TIME_MIX_DECAY_W2, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_GATE, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_RWKV7: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_TIME_MIX_W0, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_A0, - LLM_TENSOR_TIME_MIX_A1, - LLM_TENSOR_TIME_MIX_A2, - LLM_TENSOR_TIME_MIX_V0, - LLM_TENSOR_TIME_MIX_V1, - LLM_TENSOR_TIME_MIX_V2, - LLM_TENSOR_TIME_MIX_G1, - LLM_TENSOR_TIME_MIX_G2, - LLM_TENSOR_TIME_MIX_K_K, - LLM_TENSOR_TIME_MIX_K_A, - LLM_TENSOR_TIME_MIX_R_K, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_LN, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_CHANNEL_MIX_LERP_K, - LLM_TENSOR_CHANNEL_MIX_KEY, - LLM_TENSOR_CHANNEL_MIX_VALUE, - }; - case LLM_ARCH_ARWKV7: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_TIME_MIX_W0, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_A0, - LLM_TENSOR_TIME_MIX_A1, - LLM_TENSOR_TIME_MIX_A2, - LLM_TENSOR_TIME_MIX_V0, - LLM_TENSOR_TIME_MIX_V1, - LLM_TENSOR_TIME_MIX_V2, - LLM_TENSOR_TIME_MIX_G1, - LLM_TENSOR_TIME_MIX_G2, - LLM_TENSOR_TIME_MIX_K_K, - LLM_TENSOR_TIME_MIX_K_A, - LLM_TENSOR_TIME_MIX_R_K, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_LN, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GRANITE_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_GRANITE_HYBRID: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_WAVTOKENIZER_DEC: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_CONV1D, - LLM_TENSOR_CONVNEXT_DW, - LLM_TENSOR_CONVNEXT_NORM, - LLM_TENSOR_CONVNEXT_PW1, - LLM_TENSOR_CONVNEXT_PW2, - LLM_TENSOR_CONVNEXT_GAMMA, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_POS_NET_CONV1, - LLM_TENSOR_POS_NET_CONV2, - LLM_TENSOR_POS_NET_NORM, - LLM_TENSOR_POS_NET_NORM1, - LLM_TENSOR_POS_NET_NORM2, - LLM_TENSOR_POS_NET_ATTN_NORM, - LLM_TENSOR_POS_NET_ATTN_Q, - LLM_TENSOR_POS_NET_ATTN_K, - LLM_TENSOR_POS_NET_ATTN_V, - LLM_TENSOR_POS_NET_ATTN_OUT, - }; - case LLM_ARCH_BAILINGMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_BAILINGMOE2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - LLM_TENSOR_LAYER_OUT_NORM, - }; - case LLM_ARCH_DOTS1: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_ERNIE4_5_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_HUNYUAN_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_OPENAI_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_SINKS, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_LFM2: - return { - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SHORTCONV_CONV, - LLM_TENSOR_SHORTCONV_INPROJ, - LLM_TENSOR_SHORTCONV_OUTPROJ, - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM_LFM2, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_DENSE_2_OUT, - }; - case LLM_ARCH_LFM2MOE: - return { - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SHORTCONV_CONV, - LLM_TENSOR_SHORTCONV_INPROJ, - LLM_TENSOR_SHORTCONV_OUTPROJ, - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM_LFM2, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_SMALLTHINKER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_APERTUS: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_SEED_OSS: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GROVEMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_CHEXPS, - LLM_TENSOR_FFN_DOWN_CHEXPS, - LLM_TENSOR_FFN_UP_CHEXPS, - }; - case LLM_ARCH_MINIMAX_M2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_COGVLM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_VISEXP_ATTN_QKV, - LLM_TENSOR_VISEXP_ATTN_OUT, - LLM_TENSOR_VISEXP_FFN_GATE, - LLM_TENSOR_VISEXP_FFN_DOWN, - LLM_TENSOR_VISEXP_FFN_UP, - }; - case LLM_ARCH_MIMO2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_SINKS, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_STEP35: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_GPTJ: - case LLM_ARCH_UNKNOWN: - return { - LLM_TENSOR_TOKEN_EMBD, - }; - case LLM_ARCH_MAINCODER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_KIMI_LINEAR: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - // Dense FFN (layer 0 only) - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - // MoE FFN (layers 1+) - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - // Shared experts - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - // KDA (using SSM_ enum prefix, keeping GGUF names for backward compat) - LLM_TENSOR_SSM_CONV1D_Q, - LLM_TENSOR_SSM_CONV1D_K, - LLM_TENSOR_SSM_CONV1D_V, - LLM_TENSOR_SSM_F_A, - LLM_TENSOR_SSM_F_B, - LLM_TENSOR_SSM_BETA, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_G_A, - LLM_TENSOR_SSM_G_B, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_NORM, - // MLA - LLM_TENSOR_ATTN_Q_A, - LLM_TENSOR_ATTN_Q_B, - LLM_TENSOR_ATTN_Q_A_NORM, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_K_B, - LLM_TENSOR_ATTN_V_B, - LLM_TENSOR_ATTN_KV_A_NORM, - }; - default: - GGML_ABORT("unknown architecture for tensor mapping"); - } -} - // declare information about the model weight tensors: // - the layer in which the tensor is going to be used. this is needed in order to assign the correct buffer type for the weight // - the operator which is going to use the weight. this is needed to determine if the respective backend supports the operator @@ -2559,20 +560,20 @@ static std::set llm_get_tensor_names(llm_arch arch) { // example: https://github.com/ggml-org/llama.cpp/pull/17548 // static const std::map LLM_TENSOR_INFOS = { - {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}}, - {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output - {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output - {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // do the norms on the first layer (not the input layer) + {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, {LLM_TENSOR_ROPE_FACTORS_LONG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, {LLM_TENSOR_ROPE_FACTORS_SHORT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, @@ -2680,11 +681,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_PRE_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_POST_NORM_1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_POST_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_NORM_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_LAYER_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_LAYER_OUT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_Q_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_KV_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2705,9 +710,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, // altup / laurel (gemma 3n) - {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2723,7 +728,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, - {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}}, + {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}}, {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2778,18 +783,13 @@ std::string LLM_KV::operator()(llm_kv kv) const { } LLM_TN_IMPL::LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid) - : arch(arch), tensor(tensor), suffix(suffix), bid(bid), xid(xid), - model_tensors(llm_get_tensor_names(arch)) {} + : arch(arch), tensor(tensor), suffix(suffix), bid(bid), xid(xid) {} std::string LLM_TN_IMPL::str() const { if (LLM_TENSOR_NAMES.find(tensor) == LLM_TENSOR_NAMES.end()) { GGML_ABORT("unknown tensor name for tensor id %d", static_cast(tensor)); } - if (model_tensors.find(tensor) == model_tensors.end()) { - return LLM_TENSOR_NAMES.at(tensor); - } - std::string name = ::format(LLM_TENSOR_NAMES.at(tensor), bid, xid); if (suffix != nullptr) { name += "."; @@ -2875,3 +875,34 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { return false; } } + +bool llm_arch_supports_sm_tensor(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_GROK: + case LLM_ARCH_MPT: + case LLM_ARCH_PLAMO2: + case LLM_ARCH_MINICPM3: + case LLM_ARCH_GEMMA3N: + case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: + case LLM_ARCH_JAMBA: + case LLM_ARCH_FALCON_H1: + case LLM_ARCH_OLMO2: + case LLM_ARCH_OLMOE: + case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_GLM_DSA: + case LLM_ARCH_BITNET: + case LLM_ARCH_T5: + case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: + case LLM_ARCH_GRANITE_HYBRID: + case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: + case LLM_ARCH_MINIMAX_M2: + case LLM_ARCH_MISTRAL4: + case LLM_ARCH_KIMI_LINEAR: + return false; + default: + return true; + } +} diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index b1b1dcf1883..8f335f5c7b3 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -60,6 +60,7 @@ enum llm_arch { LLM_ARCH_GEMMA2, LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, + LLM_ARCH_GEMMA4, LLM_ARCH_GEMMA_EMBEDDING, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, @@ -77,6 +78,7 @@ enum llm_arch { LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK2, + LLM_ARCH_DEEPSEEK2OCR, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, @@ -111,6 +113,7 @@ enum llm_arch { LLM_ARCH_ERNIE4_5_MOE, LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_HUNYUAN_DENSE, + LLM_ARCH_HUNYUAN_VL, LLM_ARCH_SMOLLM3, LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, @@ -127,6 +130,7 @@ enum llm_arch { LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, + LLM_ARCH_MISTRAL4, LLM_ARCH_PADDLEOCR, LLM_ARCH_MIMO2, LLM_ARCH_STEP35, @@ -167,6 +171,7 @@ enum llm_kv { LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, LLM_KV_EMBEDDING_LENGTH_OUT, + LLM_KV_EMBEDDING_LENGTH_PER_LAYER, LLM_KV_FEATURES_LENGTH, LLM_KV_BLOCK_COUNT, LLM_KV_LEADING_DENSE_BLOCK_COUNT, @@ -240,6 +245,7 @@ enum llm_kv { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, LLM_KV_ATTENTION_INDEXER_TOP_K, + LLM_KV_ATTENTION_SHARED_KV_LAYERS, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_COUNT_SWA, @@ -249,6 +255,7 @@ enum llm_kv { LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ALPHA, LLM_KV_ROPE_SCALING_ATTN_FACTOR, LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, @@ -367,6 +374,9 @@ enum llm_tensor { LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_POST_NORM_1, + LLM_TENSOR_FFN_POST_NORM_2, + LLM_TENSOR_FFN_PRE_NORM_2, LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, @@ -391,6 +401,7 @@ enum llm_tensor { LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_LAYER_OUT_SCALE, LLM_TENSOR_POST_ATTN_NORM, LLM_TENSOR_POST_MLP_NORM, LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n @@ -576,8 +587,6 @@ struct LLM_TN_IMPL { const int bid; const int xid; - const std::set model_tensors; - LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid); std::string str() const; @@ -623,6 +632,7 @@ llm_arch llm_arch_from_string(const std::string & name); const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); -bool llm_arch_is_recurrent(const llm_arch & arch); -bool llm_arch_is_hybrid (const llm_arch & arch); -bool llm_arch_is_diffusion(const llm_arch & arch); +bool llm_arch_is_recurrent (const llm_arch & arch); +bool llm_arch_is_hybrid (const llm_arch & arch); +bool llm_arch_is_diffusion (const llm_arch & arch); +bool llm_arch_supports_sm_tensor(const llm_arch & arch); diff --git a/examples/talk-llama/llama-batch.h b/examples/talk-llama/llama-batch.h index 8e6fac0efab..f77520e86c3 100644 --- a/examples/talk-llama/llama-batch.h +++ b/examples/talk-llama/llama-batch.h @@ -18,7 +18,7 @@ struct llama_ubatch { } // typical for M-RoPE cases: - // 0 - sequantial position of the tokens/embeddings in the sequence + // 0 - sequential position of the tokens/embeddings in the sequence // 1 - y position in the image // 2 - x position in the image // 3 - other diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index c415a998f33..6554a89b28a 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -49,6 +49,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK }, { "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 }, { "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 }, + { "deepseek-ocr", LLM_CHAT_TEMPLATE_DEEPSEEK_OCR }, { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R }, { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, { "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 }, @@ -59,7 +60,8 @@ static const std::map LLM_CHAT_TEMPLATES = { { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 }, { "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, - { "granite", LLM_CHAT_TEMPLATE_GRANITE }, + { "granite", LLM_CHAT_TEMPLATE_GRANITE_3_X }, + { "granite-4.0", LLM_CHAT_TEMPLATE_GRANITE_4_0 }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "yandex", LLM_CHAT_TEMPLATE_YANDEX }, @@ -71,6 +73,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE }, { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, + { "hunyuan-ocr", LLM_CHAT_TEMPLATE_HUNYUAN_OCR }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, @@ -190,7 +193,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) { return LLM_CHAT_TEMPLATE_RWKV_WORLD; } else if (tmpl_contains("<|start_of_role|>")) { - return LLM_CHAT_TEMPLATE_GRANITE; + if (tmpl_contains("") || tmpl_contains("")) { + return LLM_CHAT_TEMPLATE_GRANITE_4_0; + } + return LLM_CHAT_TEMPLATE_GRANITE_3_X; } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) { return LLM_CHAT_TEMPLATE_GIGACHAT; } else if (tmpl_contains("<|role_start|>")) { @@ -211,6 +217,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) { return LLM_CHAT_TEMPLATE_OPENAI_MOE; + } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_begin▁of▁sentence|>")) { + return LLM_CHAT_TEMPLATE_HUNYUAN_OCR; } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { @@ -548,6 +556,11 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << LU8("<|Assistant|>"); } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_OCR) { + for (auto message : chat) { + // no template + ss << message->content; + } } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) { // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb // EXAONE-3.0-7.8B-Instruct @@ -611,8 +624,8 @@ int32_t llm_chat_apply_template( ss << "Assistant: " << trim(chat[i]->content) << "\n\n"; } } - } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) { - // IBM Granite template + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_3_X) { + // IBM Granite 3.x template for (const auto & message : chat) { std::string role(message->role); ss << "<|start_of_role|>" << role << "<|end_of_role|>"; @@ -624,6 +637,20 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|start_of_role|>assistant<|end_of_role|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_4_0) { + // IBM Granite 4.0 template + for (const auto & message : chat) { + std::string role(message->role); + if (role == "assistant_tool_call") { + ss << "<|start_of_role|>assistant<|end_of_role|><|tool_call|>"; + } else { + ss << "<|start_of_role|>" << role << "<|end_of_role|>"; + } + ss << message->content << "<|end_of_text|>\n"; + } + if (add_ass) { + ss << "<|start_of_role|>assistant<|end_of_role|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { // GigaChat template bool has_system = !chat.empty() && std::string(chat[0]->role) == "system"; @@ -798,6 +825,22 @@ int32_t llm_chat_apply_template( ss << "<|hy_User|>" << chat[i]->content << "<|hy_Assistant|>"; } } + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_OCR) { + // tencent/HunyuanOCR + ss << "<|hy_begin▁of▁sentence|>"; + for (size_t i = 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (i == 0 && role == "system") { + ss << chat[i]->content << "<|hy_place▁holder▁no▁3|>"; + continue; + } + + if (role == "user") { + ss << chat[i]->content << "<|hy_User|>"; + } else if (role == "assistant") { + ss << chat[i]->content << "<|hy_Assistant|>"; + } + } } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) { // moonshotai/Kimi-K2-Instruct for (auto message : chat) { diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index 9ed1db128ec..13f936a946c 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -28,6 +28,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_DEEPSEEK, LLM_CHAT_TEMPLATE_DEEPSEEK_2, LLM_CHAT_TEMPLATE_DEEPSEEK_3, + LLM_CHAT_TEMPLATE_DEEPSEEK_OCR, LLM_CHAT_TEMPLATE_COMMAND_R, LLM_CHAT_TEMPLATE_LLAMA_3, LLM_CHAT_TEMPLATE_CHATGLM_3, @@ -38,7 +39,8 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_EXAONE_4, LLM_CHAT_TEMPLATE_EXAONE_MOE, LLM_CHAT_TEMPLATE_RWKV_WORLD, - LLM_CHAT_TEMPLATE_GRANITE, + LLM_CHAT_TEMPLATE_GRANITE_3_X, + LLM_CHAT_TEMPLATE_GRANITE_4_0, LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_YANDEX, @@ -51,6 +53,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_HUNYUAN_MOE, LLM_CHAT_TEMPLATE_OPENAI_MOE, LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, + LLM_CHAT_TEMPLATE_HUNYUAN_OCR, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_GROK_2, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 1f7a52d7895..8126249e143 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -1,5 +1,6 @@ #include "llama-context.h" +#include "ggml.h" #include "llama-arch.h" #include "llama-impl.h" #include "llama-batch.h" @@ -8,6 +9,7 @@ #include "llama-mmap.h" #include "llama-model.h" #include "llama-ext.h" +#include "llama.h" #include #include @@ -217,10 +219,10 @@ llama_context::llama_context( if (!hparams.vocab_only) { // GPU backends - for (auto * dev : model.devices) { - ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + for (const auto & dev : model.devices) { + ggml_backend_t backend = ggml_backend_dev_init(dev.dev, nullptr); if (backend == nullptr) { - throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); + throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev.dev))); } backends.emplace_back(backend); } @@ -295,8 +297,8 @@ llama_context::llama_context( if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) { // use the host buffer of the first device CPU for faster transfer of the intermediate state - auto * dev = model.devices[0]; - auto * host_buft = ggml_backend_dev_host_buffer_type(dev); + const auto & dev = model.devices[0]; + auto * host_buft = ggml_backend_dev_host_buffer_type(dev.dev); if (host_buft) { buft = host_buft; } @@ -342,14 +344,6 @@ llama_context::llama_context( if (cparams.pipeline_parallel) { LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__); - - if (!graph_reuse_disable) { - // TODO: figure out a way to make graph reuse work with pipeline parallelism - // ref: https://github.com/ggml-org/llama.cpp/pull/20463 - LLAMA_LOG_WARN("%s: graph reuse is currently not compatible with pipeline parallelism - disabling\n", __func__); - - graph_reuse_disable = true; - } } sched_reserve(); @@ -594,7 +588,7 @@ void llama_context::sched_reserve() { // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - // TODO: not sure if the following graph would be worster case for multi-stream KV caches: + // TODO: not sure if the following graph would be worst case for multi-stream KV caches: // // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); // @@ -1028,9 +1022,11 @@ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void for (auto & backend : backends) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); - auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); - if (set_abort_callback_fn) { - set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); + if (reg) { + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); + } } } } @@ -1165,9 +1161,11 @@ bool llama_context::set_adapter_cvec( int32_t il_end) { LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); - // TODO: should we reserve? + bool res = cvec->apply(model, data, len, n_embd, il_start, il_end); - return cvec->apply(model, data, len, n_embd, il_start, il_end); + sched_need_reserve = true; + + return res; } llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { @@ -1187,6 +1185,13 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll if (!graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); + // with pipeline parallelism, the previous graph_compute_async may still be running + // on the GPU. we must synchronize before set_inputs to avoid overwriting input tensors + // that the previous compute is still reading. + if (cparams.pipeline_parallel) { + ggml_backend_sched_synchronize(sched.get()); + } + n_reused++; } else { res->reset(); @@ -1345,8 +1350,11 @@ int llama_context::encode(const llama_batch & batch_inp) { const llama_seq_id seq_id = ubatch.seq_id_unq[s]; const int32_t seq_idx = ubatch.seq_idx[seq_id]; - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + // use n_embd_out (not n_embd_inp) - the pooled embedding has the model's + // output dimension, which differs from input dimension for deepstack models (e.g. qwen3vl) + const uint32_t n_embd_out = hparams.n_embd_out(); + embd_seq_out[seq_id].resize(n_embd_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd_out*seq_idx)*sizeof(float), n_embd_out*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_RANK: @@ -1767,12 +1775,16 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract sequence embeddings (cleared before processing each batch) auto & embd_seq_out = embd_seq; + // use n_embd_out (not n_embd_inp) - the pooled embedding has the model's + // output dimension, which differs from input dimension for deepstack models (e.g. qwen3vl) + const uint32_t n_embd_out = hparams.n_embd_out(); + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { const llama_seq_id seq_id = ubatch.seq_id_unq[s]; const int32_t seq_idx = ubatch.seq_idx[seq_id]; - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + embd_seq_out[seq_id].resize(n_embd_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd_out*seq_idx)*sizeof(float), n_embd_out*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_RANK: @@ -1944,6 +1956,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); return 0; } + ggml_backend_buffer_clear(buf_output.get(), 0); } float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); @@ -2623,7 +2636,7 @@ void llama_context::perf_reset() { n_reused = 0; } -std::map llama_context::memory_breakdown() const { +llama_memory_breakdown llama_context::memory_breakdown() const { std::map ret; for (const auto & [buft, size] : model.memory_breakdown()) { ret[buft].model += size; @@ -2933,7 +2946,22 @@ llama_context * llama_init_from_model( params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; } - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { + if (model->split_mode() == LLAMA_SPLIT_MODE_TENSOR) { + if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { + LLAMA_LOG_INFO("%s: enabling flash_attn since it is required for SPLIT_MODE_TENSOR\n", __func__); + params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; + } + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_ENABLED) { + LLAMA_LOG_ERROR("%s: SPLIT_MODE_TENSOR requires flash_attn to be enabled\n", __func__); + return nullptr; + } + if (ggml_is_quantized(params.type_k) || ggml_is_quantized(params.type_v)) { + LLAMA_LOG_ERROR("%s: simultaneous use of SPLIT_MODE_TENSOR and KV cache quantization not implemented\n", __func__); + return nullptr; + } + } + + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) { const uint32_t blck_size = ggml_blck_size(params.type_k); for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { if (model->hparams.n_embd_head_k(il) % blck_size != 0) { @@ -2944,7 +2972,7 @@ llama_context * llama_init_from_model( } } - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_v)) { const uint32_t blck_size = ggml_blck_size(params.type_v); for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { if (model->hparams.n_embd_head_v(il) % blck_size != 0) { @@ -3465,142 +3493,6 @@ void llama_perf_context_reset(llama_context * ctx) { ctx->perf_reset(); } -void llama_memory_breakdown_print(const struct llama_context * ctx) { - const std::vector & devices = ctx->get_model().devices; - - std::map memory_breakdown = ctx->memory_breakdown(); - - std::vector> table_data; - table_data.reserve(devices.size()); - const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n"; - const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n"; - const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n"; - - table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"}); - - constexpr size_t MiB = 1024 * 1024; - const std::vector desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "}; - - // track seen buffer types to avoid double counting: - std::set seen_buffer_types; - - // accumulative memory breakdown for each device and for host: - std::vector mb_dev(devices.size()); - llama_memory_breakdown_data mb_host; - - for (const auto & buft_mb : memory_breakdown) { - ggml_backend_buffer_type_t buft = buft_mb.first; - const llama_memory_breakdown_data & mb = buft_mb.second; - if (ggml_backend_buft_is_host(buft)) { - mb_host.model += mb.model; - mb_host.context += mb.context; - mb_host.compute += mb.compute; - seen_buffer_types.insert(buft); - continue; - } - ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); - if (dev) { - int i_dev = -1; - for (size_t i = 0; i < devices.size(); i++) { - if (devices[i] == dev) { - i_dev = i; - break; - } - } - if (i_dev != -1) { - mb_dev[i_dev].model += mb.model; - mb_dev[i_dev].context += mb.context; - mb_dev[i_dev].compute += mb.compute; - seen_buffer_types.insert(buft); - continue; - } - } - } - - // print memory breakdown for each device: - for (size_t i = 0; i < devices.size(); i++) { - ggml_backend_dev_t dev = devices[i]; - llama_memory_breakdown_data mb = mb_dev[i]; - - const std::string name = ggml_backend_dev_name(dev); - std::string desc = ggml_backend_dev_description(dev); - for (const std::string & prefix : desc_prefixes_strip) { - if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) { - desc = desc.substr(prefix.length()); - } - } - - size_t free, total; - ggml_backend_dev_memory(dev, &free, &total); - - const size_t self = mb.model + mb.context + mb.compute; - const size_t unaccounted = total - self - free; - - table_data.push_back({ - template_gpu, - " - " + name + " (" + desc + ")", - std::to_string(total / MiB), - std::to_string(free / MiB), - std::to_string(self / MiB), - std::to_string(mb.model / MiB), - std::to_string(mb.context / MiB), - std::to_string(mb.compute / MiB), - std::to_string(unaccounted / MiB)}); - } - - // print memory breakdown for host: - { - const size_t self = mb_host.model + mb_host.context + mb_host.compute; - table_data.push_back({ - template_other, - " - Host", - "", // total - "", // free - std::to_string(self / MiB), - std::to_string(mb_host.model / MiB), - std::to_string(mb_host.context / MiB), - std::to_string(mb_host.compute / MiB), - ""}); // unaccounted - } - - // print memory breakdown for all remaining buffer types: - for (const auto & buft_mb : memory_breakdown) { - ggml_backend_buffer_type_t buft = buft_mb.first; - const llama_memory_breakdown_data & mb = buft_mb.second; - if (seen_buffer_types.count(buft) == 1) { - continue; - } - const std::string name = ggml_backend_buft_name(buft); - const size_t self = mb.model + mb.context + mb.compute; - table_data.push_back({ - template_other, - " - " + name, - "", // total - "", // free - std::to_string(self / MiB), - std::to_string(mb.model / MiB), - std::to_string(mb.context / MiB), - std::to_string(mb.compute / MiB), - ""}); // unaccounted - seen_buffer_types.insert(buft); - } - - for (size_t j = 1; j < table_data[0].size(); j++) { - size_t max_len = 0; - for (const auto & td : table_data) { - max_len = std::max(max_len, td[j].length()); - } - for (auto & td : table_data) { - td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' '); - } - } - for (const auto & td : table_data) { - LLAMA_LOG_INFO(td[0].c_str(), - __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(), - td[6].c_str(), td[7].c_str(), td[8].c_str()); - } -} - // // training // @@ -3631,3 +3523,11 @@ void llama_opt_epoch( callback_train, callback_eval); } + +// +// ext +// + +llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx) { + return ctx->memory_breakdown(); +} diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index e0d0085c1c3..53c705eaffc 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-ext.h" #include "llama-cparams.h" #include "llama-graph.h" #include "llama-adapter.h" @@ -22,17 +23,6 @@ class llama_io_write_i; struct llama_memory_i; struct llama_memory_context_i; -// "memory" as in physical memory for a buffer type, in bytes -struct llama_memory_breakdown_data { - size_t model = 0; // memory allocated for the model - size_t context = 0; // memory allocated for the context - size_t compute = 0; // memory allocated for temporary compute buffers - - size_t total() const { - return model + context + compute; - } -}; - struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -172,7 +162,7 @@ struct llama_context { llama_perf_context_data perf_get_data() const; void perf_reset(); - std::map memory_breakdown() const; + llama_memory_breakdown memory_breakdown() const; // // training diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h index 13ced783b42..8ce29d217cb 100644 --- a/examples/talk-llama/llama-ext.h +++ b/examples/talk-llama/llama-ext.h @@ -1,8 +1,12 @@ #pragma once -#include "llama-context.h" -#include "ggml.h" -#include "stdint.h" +// this is a staging header for new llama.cpp API +// breaking changes and C++ are allowed. everything here should be considered WIP + +#include "llama.h" + +#include +#include // Reserve a new compute graph. It is valid until the next call to llama_graph_reserve. LLAMA_API struct ggml_cgraph * llama_graph_reserve( @@ -10,3 +14,77 @@ LLAMA_API struct ggml_cgraph * llama_graph_reserve( uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs); + +// Get the default ggml_type for a given ftype. +LLAMA_API ggml_type llama_ftype_get_default_type(llama_ftype ftype); + +struct quantize_state_impl; + +LLAMA_API quantize_state_impl * llama_quant_init( + const llama_model * model, + const llama_model_quantize_params * params); + +LLAMA_API void llama_quant_free(quantize_state_impl * qs); + +// Descriptor for constructing a mock model for quantization testing. +struct llama_quant_model_desc { + const char * architecture; + uint32_t n_embd; + uint32_t n_ff; + uint32_t n_layer; + uint32_t n_head; + uint32_t n_head_kv; + uint32_t n_expert; + uint32_t n_embd_head_k; + uint32_t n_embd_head_v; +}; + +// Create a mock model from a metadata descriptor (for testing). +// The returned model must be freed with llama_model_free(). +LLAMA_API llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc); + +// Returns true if this tensor should be quantized (based on name, dims, params). +LLAMA_API bool llama_quant_tensor_allows_quantization( + const quantize_state_impl * qs, + const ggml_tensor * tensor); + +// Compute quantization type assignments for a list of tensors. +// All tensors should be quantizable (use llama_quant_tensor_allows_quantization to filter). +// result_types: caller-allocated array of n_tensors elements, filled with assigned types. +LLAMA_API void llama_quant_compute_types( + quantize_state_impl * qs, + llama_ftype ftype, + ggml_tensor ** tensors, + ggml_type * result_types, + size_t n_tensors); + +// +// device memory querying +// + +// "memory" as in physical memory for a buffer type, in bytes +struct llama_memory_breakdown_data { + size_t model = 0; // memory allocated for the model + size_t context = 0; // memory allocated for the context + size_t compute = 0; // memory allocated for temporary compute buffers + + size_t total() const { + return model + context + compute; + } +}; + +struct llama_device_memory_data { + int64_t total; + int64_t free; + llama_memory_breakdown_data mb; +}; + +// TODO: convert to C-style data structure +using llama_memory_breakdown = std::map; + +LLAMA_API int32_t llama_model_n_expert (const struct llama_model * model); +LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); + +LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); + +LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); diff --git a/examples/talk-llama/llama-grammar.cpp b/examples/talk-llama/llama-grammar.cpp index aac0d41f2b4..badcbfd0fbb 100644 --- a/examples/talk-llama/llama-grammar.cpp +++ b/examples/talk-llama/llama-grammar.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #define MAX_REPETITION_THRESHOLD 2000 @@ -454,6 +455,7 @@ const char * llama_grammar_parser::parse_sequence( bool is_nested) { size_t last_sym_start = rule.size(); const char * pos = src; + uint64_t n_prev_rules = 1; // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used // (though it's technically the same as -1 now) @@ -481,6 +483,18 @@ const char * llama_grammar_parser::parse_sequence( // S' ::= S | llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); + // Calculate the total number of rules that will be generated by this repetition + uint64_t total_rules = 1; // Start with 1 for the original rule + if (!no_max && max_times > 0) { + total_rules = max_times; + } else if (min_times > 0) { + total_rules = min_times; + } + + if (n_prev_rules * total_rules >= MAX_REPETITION_THRESHOLD) { + throw std::runtime_error("number of rules that are going to be repeated multiplied by the new repetition exceeds sane defaults, please reduce the number of repetitions or rule complexity"); + } + if (min_times == 0) { rule.resize(last_sym_start); } else { @@ -508,12 +522,15 @@ const char * llama_grammar_parser::parse_sequence( if (n_opt > 0) { rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); } + n_prev_rules *= total_rules; + GGML_ASSERT(n_prev_rules >= 1); }; while (*pos) { if (*pos == '"') { // literal string pos++; last_sym_start = rule.size(); + n_prev_rules = 1; while (*pos != '"') { if (!*pos) { throw std::runtime_error("unexpected end of input"); @@ -531,6 +548,7 @@ const char * llama_grammar_parser::parse_sequence( start_type = LLAMA_GRETYPE_CHAR_NOT; } last_sym_start = rule.size(); + n_prev_rules = 1; while (*pos != ']') { if (!*pos) { throw std::runtime_error("unexpected end of input"); @@ -561,6 +579,7 @@ const char * llama_grammar_parser::parse_sequence( auto token_pair = parse_token(vocab, pos); const char * token_end = token_pair.second; last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({type, token_pair.first}); pos = parse_space(token_end, is_nested); } else if (is_word_char(*pos)) { // rule reference @@ -568,12 +587,15 @@ const char * llama_grammar_parser::parse_sequence( uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); pos = parse_space(name_end, is_nested); last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); } else if (*pos == '(') { // grouping // parse nested alternates into synthesized rule pos = parse_space(pos + 1, true); + uint32_t n_rules_before = symbol_ids.size(); uint32_t sub_rule_id = generate_symbol_id(rule_name); pos = parse_alternates(pos, rule_name, sub_rule_id, true); + n_prev_rules = std::max(1u, (uint32_t)symbol_ids.size() - n_rules_before); last_sym_start = rule.size(); // output reference to synthesized rule rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); @@ -583,6 +605,7 @@ const char * llama_grammar_parser::parse_sequence( pos = parse_space(pos + 1, is_nested); } else if (*pos == '.') { // any char last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); pos = parse_space(pos + 1, is_nested); } else if (*pos == '*') { @@ -830,32 +853,54 @@ static bool llama_grammar_match_token( static void llama_grammar_advance_stack( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks) { - if (stack.empty()) { - if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { - new_stacks.emplace_back(stack); + llama_grammar_stacks & new_stacks) { + std::vector todo; + todo.push_back(stack); + + auto stack_cmp = [](const llama_grammar_stack & a, const llama_grammar_stack & b) { + return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end(), + [](const llama_grammar_element * pa, const llama_grammar_element * pb) { + return pa < pb; // Compare pointer addresses + } + ); + }; + + std::set seen(stack_cmp); + + while (!todo.empty()) { + llama_grammar_stack curr_stack = std::move(todo.back()); + todo.pop_back(); + + if (seen.find( curr_stack) != seen.end()) { + continue; } - return; - } + seen.insert(curr_stack); - const llama_grammar_element * pos = stack.back(); + if (curr_stack.empty()) { + if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) { + new_stacks.emplace_back(std::move(curr_stack)); + } + continue; + } - switch (pos->type) { + const llama_grammar_element * pos = curr_stack.back(); + + switch (pos->type) { case LLAMA_GRETYPE_RULE_REF: { const size_t rule_id = static_cast(pos->value); const llama_grammar_element * subpos = rules[rule_id].data(); do { // init new stack without the top (pos) - llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + llama_grammar_stack next_stack(curr_stack.begin(), curr_stack.end() - 1); if (!llama_grammar_is_end_of_sequence(pos + 1)) { // if this rule ref is followed by another element, add that to stack - new_stack.push_back(pos + 1); + next_stack.push_back(pos + 1); } if (!llama_grammar_is_end_of_sequence(subpos)) { // if alternate is nonempty, add to stack - new_stack.push_back(subpos); + next_stack.push_back(subpos); } - llama_grammar_advance_stack(rules, new_stack, new_stacks); + todo.push_back(std::move(next_stack)); while (!llama_grammar_is_end_of_sequence(subpos)) { // scan to end of alternate def subpos++; @@ -874,9 +919,9 @@ static void llama_grammar_advance_stack( case LLAMA_GRETYPE_CHAR_ANY: case LLAMA_GRETYPE_TOKEN: case LLAMA_GRETYPE_TOKEN_NOT: - if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { + if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) { // only add the stack if it's not a duplicate of one we already have - new_stacks.emplace_back(stack); + new_stacks.emplace_back(std::move(curr_stack)); } break; default: @@ -884,6 +929,7 @@ static void llama_grammar_advance_stack( // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on // those GGML_ABORT("fatal error"); + } } } diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 9a215bb77a0..2ff23f87cf4 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -1,6 +1,7 @@ #include "llama-graph.h" #include "llama-impl.h" +#include "llama-model.h" #include "llama-batch.h" #include "llama-cparams.h" @@ -19,7 +20,7 @@ // dedup helpers -static ggml_tensor * build_kq_mask( +static ggml_tensor * build_attn_inp_kq_mask( ggml_context * ctx, const llama_kv_cache_context * mctx, const llama_ubatch & ubatch, @@ -28,7 +29,11 @@ static ggml_tensor * build_kq_mask( const auto n_tokens = ubatch.n_tokens; const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_kq_mask"); + + return res; } static bool can_reuse_kq_mask( @@ -52,6 +57,21 @@ static bool can_reuse_kq_mask( // impl +static ggml_tensor * ggml_mul_mat_aux( + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * rot) { + const auto n = rot->ne[0]; + + ggml_tensor * res; + + res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + res = ggml_mul_mat (ctx, rot, res); + res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + + return res; +} + void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { const int64_t n_tokens = ubatch->n_tokens; @@ -429,6 +449,14 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) { mctx->set_input_v_idxs(self_v_idxs, ubatch); mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + + if (self_k_rot) { + mctx->set_input_k_rot(self_k_rot); + } + + if (self_v_rot) { + mctx->set_input_v_rot(self_v_rot); + } } bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { @@ -476,6 +504,22 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + + if (self_k_rot) { + mctx->get_base()->set_input_k_rot(self_k_rot); + } + + if (self_v_rot) { + mctx->get_base()->set_input_v_rot(self_v_rot); + } + + if (self_k_rot_swa) { + mctx->get_swa()->set_input_k_rot(self_k_rot_swa); + } + + if (self_v_rot_swa) { + mctx->get_swa()->set_input_v_rot(self_v_rot_swa); + } } bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { @@ -532,6 +576,14 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + if (inp_attn->self_k_rot) { + mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot); + } + + if (inp_attn->self_v_rot) { + mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot); + } + const int64_t n_rs = mctx->get_recr()->get_n_rs(); if (inp_rs->s_copy) { @@ -630,6 +682,22 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); } + if (inp_attn->self_k_rot) { + attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot); + } + + if (inp_attn->self_v_rot) { + attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot); + } + + if (inp_attn->self_k_rot_swa) { + attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa); + } + + if (inp_attn->self_v_rot_swa) { + attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa); + } + const int64_t n_rs = mctx->get_recr()->get_n_rs(); if (inp_rs->s_copy) { @@ -992,6 +1060,84 @@ ggml_tensor * llm_graph_context::build_norm( return cur; } + +llm_graph_qkv llm_graph_context::build_qkv( + const llama_layer & layer, + ggml_tensor * cur, + int64_t n_embd_head, + int64_t n_head, + int64_t n_head_kv, + int il) const { + const int64_t n_embd_q = n_embd_head * n_head; + const int64_t n_embd_kv = n_embd_head * n_head_kv; + + ggml_tensor * Qcur, * Kcur, * Vcur; + + if (layer.wqkv) { + // fused QKV path + ggml_tensor * qkv = build_lora_mm(layer.wqkv, cur, layer.wqkv_s); + cb(qkv, "wqkv", il); + if (layer.wqkv_b) { + qkv = ggml_add(ctx0, qkv, layer.wqkv_b); + cb(qkv, "wqkv_b", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + qkv = ggml_clamp(ctx0, qkv, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(qkv, "wqkv_clamped", il); + } + Qcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head, n_tokens, + ggml_row_size(qkv->type, n_embd_head), qkv->nb[1], 0); + Kcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, + ggml_row_size(qkv->type, n_embd_head), qkv->nb[1], + ggml_row_size(qkv->type, n_embd_q)); + Vcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, + ggml_row_size(qkv->type, n_embd_head), qkv->nb[1], + ggml_row_size(qkv->type, n_embd_q + n_embd_kv)); + } else { + // separate Q/K/V path + Qcur = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur, "Qcur", il); + if (layer.wq_b) { + Qcur = ggml_add(ctx0, Qcur, layer.wq_b); + cb(Qcur, "Qcur", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Qcur, "Qcur_clamped", il); + } + Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + cb(Kcur, "Kcur", il); + if (layer.wk_b) { + Kcur = ggml_add(ctx0, Kcur, layer.wk_b); + cb(Kcur, "Kcur", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Kcur, "Kcur_clamped", il); + } + Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + cb(Vcur, "Vcur", il); + if (layer.wv_b) { + Vcur = ggml_add(ctx0, Vcur, layer.wv_b); + cb(Vcur, "Vcur", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Vcur, "Vcur_clamped", il); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + return { Qcur, Kcur, Vcur }; +} + + ggml_tensor * llm_graph_context::build_ffn( ggml_tensor * cur, ggml_tensor * up, @@ -1516,9 +1662,11 @@ ggml_tensor * llm_graph_context::build_moe_ffn( if (!weight_before_ffn) { experts = ggml_mul(ctx0, experts, weights); - cb(cur, "ffn_moe_weighted", il); + cb(experts, "ffn_moe_weighted", il); } + ggml_build_forward_expand(gf, experts); + ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr }; assert(n_expert_used > 0); @@ -1538,6 +1686,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( for (uint32_t i = 1; i < hparams.n_expert_used; ++i) { moe_out = ggml_add(ctx0, moe_out, cur_experts[i]); + + ggml_build_forward_expand(gf, moe_out); } if (hparams.n_expert_used == 1) { @@ -1665,7 +1815,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const { ggml_tensor * llm_graph_context::build_inp_out_ids() const { // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls, - // but this would make the graph topology depend on the number of output tokens, which can interere with + // but this would make the graph topology depend on the number of output tokens, which can interfere with // features that require constant topology such as pipeline parallelism // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471 //if (n_outputs < n_tokens) { @@ -1940,6 +2090,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_no_cache * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -1973,7 +2124,7 @@ ggml_tensor * llm_graph_context::build_attn( cb(cur, "kqv_out", il); if (wo) { - cur = build_lora_mm(wo, cur); + cur = build_lora_mm(wo, cur, wo_s); } if (wo_b) { @@ -2002,13 +2153,13 @@ static std::unique_ptr build_attn_inp_kv_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); - - ggml_set_input(inp->self_kq_mask); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } + inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0); + inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0); + return inp; } @@ -2024,6 +2175,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -2034,6 +2186,15 @@ ggml_tensor * llm_graph_context::build_attn( int il) const { GGML_ASSERT(v_mla == nullptr); + if (inp->self_k_rot) { + q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot); + k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot); + } + + if (inp->self_v_rot) { + v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot); + } + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced // expand k later to enable rope fusion which directly writes into k-v cache @@ -2061,11 +2222,20 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + if (inp->self_v_rot) { + cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot); + } + if (wo) { - cur = build_lora_mm(wo, cur); if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators + cur = build_lora_mm(wo, cur); ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + if (wo_s) { + cur = ggml_mul(ctx0, cur, wo_s); + } + } else { + cur = build_lora_mm(wo, cur, wo_s); } } @@ -2090,9 +2260,7 @@ static std::unique_ptr build_attn_inp_k_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); - ggml_set_input(inp->self_kq_mask); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } @@ -2111,6 +2279,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_k * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -2145,10 +2314,15 @@ ggml_tensor * llm_graph_context::build_attn( cb(cur, "kqv_out", il); if (wo) { - cur = build_lora_mm(wo, cur); if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + cur = build_lora_mm(wo, cur); ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + if (wo_s) { + cur = ggml_mul(ctx0, cur, wo_s); + } + } else { + cur = build_lora_mm(wo, cur, wo_s); } } @@ -2163,6 +2337,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -2171,6 +2346,23 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v_mla, float kq_scale, int il) const { + const bool is_swa = hparams.is_swa(il); + + auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot; + auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot; + + if (k_rot) { + q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot); + if (k_cur) { + k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot); + } + } + if (v_rot) { + if (v_cur) { + v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot); + } + } + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced ggml_build_forward_expand(gf, q_cur); @@ -2185,8 +2377,6 @@ ggml_tensor * llm_graph_context::build_attn( const auto * mctx_iswa = inp->mctx; - const bool is_swa = hparams.is_swa(il); - const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base(); // optionally store to KV cache @@ -2211,8 +2401,12 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + if (v_rot) { + cur = ggml_mul_mat_aux(ctx0, cur, v_rot); + } + if (wo) { - cur = build_lora_mm(wo, cur); + cur = build_lora_mm(wo, cur, wo_s); } if (wo_b) { @@ -2243,6 +2437,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_cross * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -2267,7 +2462,7 @@ ggml_tensor * llm_graph_context::build_attn( cb(cur, "kqv_out", il); if (wo) { - cur = build_lora_mm(wo, cur); + cur = build_lora_mm(wo, cur, wo_s); } if (wo_b) { @@ -2293,12 +2488,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); - ggml_set_input(inp->self_kq_mask); - ggml_set_name(inp->self_kq_mask, "self_kq_mask"); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; - ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv"); } { @@ -2307,14 +2498,16 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); - ggml_set_input(inp->self_kq_mask_swa); - ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); - + inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; - ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv"); } + inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0); + inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0); + + inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0); + inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0); + return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); } @@ -2348,7 +2541,7 @@ ggml_tensor * llm_graph_context::build_rs( ggml_build_forward_expand(gf, ggml_cpy(ctx0, states_extra, - ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s)))); + ggml_view_2d(ctx0, s, state_size, (n_rs - n_seqs), s->nb[1], (rs_head + n_seqs)*s->nb[1]))); return output_states; } @@ -2473,9 +2666,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); - ggml_set_input(inp_attn->self_kq_mask); - + inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; } @@ -2483,9 +2674,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); - ggml_set_input(inp_attn->self_kq_mask_swa); - + inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; } diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 4855685ef71..5cb1756c6a9 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -17,6 +17,7 @@ struct ggml_context; struct ggml_tensor; struct llama_cparams; +struct llama_layer; struct llama_memory_context_i; @@ -308,6 +309,10 @@ class llm_graph_input_attn_kv : public llm_graph_input_i { ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + // note: assumes v_rot^2 == I + ggml_tensor * self_k_rot = nullptr; + ggml_tensor * self_v_rot = nullptr; + // note: these have to be copies because in order to be able to reuse a graph, its inputs // need to carry these parameters with them. otherwise, they can point to freed // llm_graph_params from a previous batch, causing stack-use-after-return @@ -384,6 +389,12 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_k_rot = nullptr; + ggml_tensor * self_v_rot = nullptr; + + ggml_tensor * self_k_rot_swa = nullptr; + ggml_tensor * self_v_rot_swa = nullptr; + const llama_hparams hparams; const llama_cparams cparams; @@ -697,6 +708,12 @@ using llm_graph_result_ptr = std::unique_ptr; // used in build_rs to properly order writes and avoid unnecessary copies using llm_graph_get_rows_fn = std::function; +struct llm_graph_qkv { + ggml_tensor * q; // [n_embd_head, n_head, n_tokens] + ggml_tensor * k; // [n_embd_head, n_head_kv, n_tokens] + ggml_tensor * v; // [n_embd_head, n_head_kv, n_tokens] +}; + struct llm_graph_context { const llm_arch arch; @@ -783,6 +800,17 @@ struct llm_graph_context { llm_norm_type type, int il) const; + + // compute Q, K, V projections with optional bias and reshape + // supports both fused wqkv and separate wq/wk/wv paths + llm_graph_qkv build_qkv( + const llama_layer & layer, + ggml_tensor * cur, + int64_t n_embd_head, + int64_t n_head, + int64_t n_head_kv, + int il) const; + ggml_tensor * build_ffn( ggml_tensor * cur, ggml_tensor * up, @@ -882,6 +910,7 @@ struct llm_graph_context { llm_graph_input_attn_no_cache * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] @@ -897,6 +926,7 @@ struct llm_graph_context { llm_graph_input_attn_kv * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] @@ -912,6 +942,7 @@ struct llm_graph_context { llm_graph_input_attn_k * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] @@ -928,6 +959,7 @@ struct llm_graph_context { llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional @@ -943,6 +975,7 @@ struct llm_graph_context { llm_graph_input_attn_cross * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 78c0bc27d4d..ac7f9ee8650 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -116,6 +116,7 @@ struct llama_hparams { float rope_freq_base_train_swa = 10000.0f; float rope_freq_scale_train; float rope_freq_scale_train_swa = 1.0f; + float rope_scaling_alpha = 0.0f; // NTK-aware alpha for XDRoPE uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul = 0.0f; @@ -209,6 +210,9 @@ struct llama_hparams { // qwen3vl deepstack uint32_t n_deepstack_layers = 0; + // gemma4 per-layer embedding + uint32_t n_embd_per_layer = 0; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggml-org/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/examples/talk-llama/llama-impl.cpp b/examples/talk-llama/llama-impl.cpp index 4c0188ee722..b3a94b946d2 100644 --- a/examples/talk-llama/llama-impl.cpp +++ b/examples/talk-llama/llama-impl.cpp @@ -128,7 +128,7 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]); case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]); case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]); - case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false"; + case GGUF_TYPE_BOOL: return ((const int8_t *)data)[i] != 0 ? "true" : "false"; default: return format("unknown type %d", type); } } diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 01166fac9ce..09102f549c8 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -13,6 +13,65 @@ #include #include +static bool ggml_is_power_of_2(int n) { + return (n & (n - 1)) == 0; +} + +// orthonormal Walsh-Hadamard rotation matrix +// note: res^2 == I +static void ggml_gen_hadamard(ggml_tensor * tensor) { + assert(tensor->type == GGML_TYPE_F32); + + const int n = tensor->ne[0]; + + assert(ggml_is_power_of_2(n)); + assert(tensor->ne[1] == n); + assert(tensor->ne[2] == 1); + assert(tensor->ne[3] == 1); + + std::vector data_f32; + + float * data = (float *) tensor->data; + + if (tensor->type != GGML_TYPE_F32) { + data_f32.resize(n*n); + data = data_f32.data(); + } + + data[0*n + 0] = 1.0 / sqrtf(n); + + for (int s = 1; s < n; s *= 2) { + for (int i = 0; i < s; i++) { + for (int j = 0; j < s; j++) { + const float val = data[i*n + j]; + + data[(i + s)*n + (j )] = val; + data[(i )*n + (j + s)] = val; + data[(i + s)*n + (j + s)] = -val; + } + } + } + + if (tensor->type != GGML_TYPE_F32) { + ggml_quantize_chunk(tensor->type, data, tensor->data, 0, 1, n*n, nullptr); + } +} + +static ggml_tensor * ggml_mul_mat_aux( + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * rot) { + const auto n = rot->ne[0]; + + ggml_tensor * res; + + res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + res = ggml_mul_mat (ctx, rot, res); + res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + + return res; +} + // // llama_kv_cache // @@ -110,6 +169,18 @@ llama_kv_cache::llama_kv_cache( continue; } + if (n_embd_head_k_all == 0) { + n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il); + } else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) { + n_embd_head_k_all = -1; + } + + if (n_embd_head_v_all == 0) { + n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il); + } else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) { + n_embd_head_v_all = -1; + } + // [TAG_V_CACHE_VARIABLE] const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max(); @@ -209,6 +280,48 @@ llama_kv_cache::llama_kv_cache( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } + const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE"); + const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false; + if (attn_rot_disable) { + LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); + } + + attn_rot_k = + !attn_rot_disable && + n_embd_head_k_all > 0 && + ggml_is_quantized(type_k) && + hparams.n_embd_head_k() % 64 == 0; + + attn_rot_v = + !attn_rot_disable && + n_embd_head_v_all > 0 && + ggml_is_quantized(type_v) && + hparams.n_embd_head_v() % 64 == 0; + + LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all); + LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all); + + // pre-compute the haramard matrices and keep them in host memory + // TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers + if (attn_rot_k || attn_rot_v) { + for (int64_t n = 64; n <= std::max(n_embd_head_k_all, n_embd_head_v_all); n *= 2) { + attn_rot_hadamard[n] = std::vector(n*n); + + ggml_init_params params = { + /* .mem_size = */ 1*ggml_tensor_overhead(), + /* .mem_buffer = */ nullptr, + /* .no_alloc = */ true, + }; + + ggml_context_ptr ctx { ggml_init(params) }; + + ggml_tensor * tmp = ggml_new_tensor_2d(ctx.get(), GGML_TYPE_F32, n, n); + tmp->data = attn_rot_hadamard[n].data(); + + ggml_gen_hadamard(tmp); + } + } + const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; } @@ -1004,6 +1117,14 @@ bool llama_kv_cache::get_has_shift() const { return result; } +ggml_type llama_kv_cache::type_k() const { + return layers[0].k->type; +} + +ggml_type llama_kv_cache::type_v() const { + return layers[0].v->type; +} + uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const { uint32_t result = 0; @@ -1189,6 +1310,47 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama return v_idxs; } +ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const { + ggml_tensor * res = nullptr; + + if (attn_rot_k) { + int nrot = 64; + + // TODO: investigate if using the smallest rotation matrix is beneficial also for K (similar as for V) + // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088 + do { + nrot *= 2; + } while (n_embd_head_k_all % nrot == 0); + nrot /= 2; + + res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_k_rot"); + } + + return res; +} + +ggml_tensor * llama_kv_cache::build_input_v_rot(ggml_context * ctx) const { + ggml_tensor * res = nullptr; + + if (attn_rot_v) { + int nrot = 64; + // using smaller rotation matrices for V seems beneficial + // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4146397570 + //do { + // nrot *= 2; + //} while (hparams.n_embd_head_v() % nrot == 0); + //nrot /= 2; + + res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_v_rot"); + } + + return res; +} + void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); @@ -1507,6 +1669,24 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch } } +void llama_kv_cache::set_input_k_rot(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + const auto n_rot = dst->ne[0]; + GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0])); + + memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst)); +} + +void llama_kv_cache::set_input_v_rot(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + const auto n_rot = dst->ne[0]; + GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0])); + + memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst)); +} + size_t llama_kv_cache::total_size() const { size_t size = 0; @@ -1542,6 +1722,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift( ggml_context * ctx, ggml_tensor * cur, ggml_tensor * shift, + ggml_tensor * rot, ggml_tensor * factors, float freq_base, float freq_scale, @@ -1561,17 +1742,22 @@ ggml_tensor * llama_kv_cache::build_rope_shift( // ref: https://github.com/ggml-org/llama.cpp/pull/13870 ? LLAMA_ROPE_TYPE_NEOX : hparams.rope_type; - ggml_tensor * tmp; if (ggml_is_quantized(cur->type)) { // dequantize to f32 -> RoPE -> quantize back tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); + // rotate back + tmp = ggml_mul_mat_aux(ctx, tmp, rot); + tmp = ggml_rope_ext(ctx, tmp, shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + // rotate fwd + tmp = ggml_mul_mat_aux(ctx, tmp, rot); + tmp = ggml_cpy(ctx, tmp, cur); } else { // we rotate only the first n_rot dimensions @@ -1592,6 +1778,9 @@ class llm_graph_input_k_shift : public llm_graph_input_i { ggml_tensor * k_shift; // I32 [kv_size*n_stream] + // note: assumes k_rot^2 == I + ggml_tensor * k_rot = nullptr; + const llama_kv_cache * kv_self; }; @@ -1601,6 +1790,10 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { if (k_shift) { kv_self->set_input_k_shift(k_shift); } + + if (k_rot) { + kv_self->set_input_k_rot(k_rot); + } } ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { @@ -1612,6 +1805,8 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); ggml_set_input(inp->k_shift); + inp->k_rot = build_input_k_rot(ctx); + const auto & cparams = lctx->get_cparams(); for (const auto & layer : layers) { @@ -1636,7 +1831,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co ggml_row_size(layer.k->type, n_embd_k_gqa), ggml_row_size(layer.k->type, n_embd_nope)); - ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il); + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->k_rot, rope_factors, freq_base_l, freq_scale_l, il); ggml_build_forward_expand(gf, cur); } @@ -2240,6 +2435,14 @@ uint32_t llama_kv_cache_context::get_n_kv() const { return n_kv; } +ggml_type llama_kv_cache_context::type_k() const { + return kv->type_k(); +} + +ggml_type llama_kv_cache_context::type_v() const { + return kv->type_v(); +} + ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const { return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); } @@ -2264,6 +2467,14 @@ ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, con return kv->build_input_v_idxs(ctx, ubatch); } +ggml_tensor * llama_kv_cache_context::build_input_k_rot(ggml_context * ctx) const { + return kv->build_input_k_rot(ctx); +} + +ggml_tensor * llama_kv_cache_context::build_input_v_rot(ggml_context * ctx) const { + return kv->build_input_v_rot(ctx); +} + void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const { kv->set_input_k_shift(dst); } @@ -2283,3 +2494,11 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_pos_bucket(dst, ubatch); } + +void llama_kv_cache_context::set_input_k_rot(ggml_tensor * dst) const { + kv->set_input_k_rot(dst); +} + +void llama_kv_cache_context::set_input_v_rot(ggml_tensor * dst) const { + kv->set_input_v_rot(dst); +} diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index 33c78c5f210..0b62dc7b232 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -152,6 +152,9 @@ class llama_kv_cache : public llama_memory_i { bool get_has_shift() const; + ggml_type type_k() const; + ggml_type type_v() const; + // // graph_build API // @@ -191,6 +194,9 @@ class llama_kv_cache : public llama_memory_i { ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_k_rot(ggml_context * ctx) const; + ggml_tensor * build_input_v_rot(ggml_context * ctx) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; @@ -199,6 +205,9 @@ class llama_kv_cache : public llama_memory_i { void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_k_rot(ggml_tensor * dst) const; + void set_input_v_rot(ggml_tensor * dst) const; + private: const llama_model & model; const llama_hparams & hparams; @@ -226,6 +235,18 @@ class llama_kv_cache : public llama_memory_i { // SWA const uint32_t n_swa = 0; + // env: LLAMA_ATTN_ROT_DISABLE + bool attn_rot_k = false; + bool attn_rot_v = false; + + // if all layers participating in the cache have constant head size, the value is stored here + // otherwise the value is -1 + int32_t n_embd_head_k_all = 0; + int32_t n_embd_head_v_all = 0; + + // pre-computed hadamard martrices + std::unordered_map> attn_rot_hadamard; + // env: LLAMA_KV_CACHE_DEBUG int debug = 0; @@ -262,6 +283,7 @@ class llama_kv_cache : public llama_memory_i { ggml_context * ctx, ggml_tensor * cur, ggml_tensor * shift, + ggml_tensor * rot, ggml_tensor * factors, float freq_base, float freq_scale, @@ -328,12 +350,15 @@ class llama_kv_cache_context : public llama_memory_context_i { uint32_t get_n_kv() const; + ggml_type type_k() const; + ggml_type type_v() const; + // get views of the current state of the cache ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; // store k_cur and v_cur in the cache based on the provided head location - // note: the heads in k_cur and v_cur should be layed out contiguously in memory + // note: the heads in k_cur and v_cur should be laid out contiguously in memory // - k_cur [n_embd_head_k, n_head_k, n_tokens] // - k_idxs [n_tokens] // - v_cur [n_embd_head_v, n_head_v, n_tokens] @@ -347,6 +372,9 @@ class llama_kv_cache_context : public llama_memory_context_i { ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_k_rot(ggml_context * ctx) const; + ggml_tensor * build_input_v_rot(ggml_context * ctx) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; @@ -354,6 +382,9 @@ class llama_kv_cache_context : public llama_memory_context_i { void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_k_rot(ggml_tensor * dst) const; + void set_input_v_rot(ggml_tensor * dst) const; + private: llama_memory_status status; diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.cpp b/examples/talk-llama/llama-memory-hybrid-iswa.cpp index 411769672af..10e6b459797 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.cpp +++ b/examples/talk-llama/llama-memory-hybrid-iswa.cpp @@ -73,9 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_base()->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index a1b45e4a3cc..4ce1af592c1 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -73,9 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index 6e8413f493d..9287fe45e96 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -1,5 +1,6 @@ #include "llama-memory-recurrent.h" +#include "ggml-backend.h" #include "llama-impl.h" #include "llama-io.h" #include "llama-batch.h" @@ -91,8 +92,8 @@ llama_memory_recurrent::llama_memory_recurrent( throw std::runtime_error("failed to create ggml context for rs cache"); } - ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size); - ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size); + ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), mem_size); + ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), mem_size); ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(s, "cache_s_l%d", i); r_l[i] = r; @@ -928,11 +929,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell llama_seq_id seq_id; io.read_to(&seq_id, sizeof(seq_id)); - // TODO: llama_memory_recurrent should have a notion of max sequences - //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { - if (seq_id < 0) { - //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + if (seq_id < 0 || (uint32_t) seq_id >= this->n_seq_max) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, this->n_seq_max); return false; } diff --git a/examples/talk-llama/llama-mmap.cpp b/examples/talk-llama/llama-mmap.cpp index c03228e9ce2..ed572da7fb5 100644 --- a/examples/talk-llama/llama-mmap.cpp +++ b/examples/talk-llama/llama-mmap.cpp @@ -40,6 +40,14 @@ #include #endif +#ifdef _WIN32 +# define llama_mmap_ftell _ftelli64 +# define llama_mmap_fseek _fseeki64 +#else +# define llama_mmap_ftell ftello +# define llama_mmap_fseek fseeko +#endif + // TODO: consider moving to llama-impl.h if needed in more places #if defined(_WIN32) static std::string llama_format_win_err(DWORD err) { @@ -86,6 +94,14 @@ struct llama_file::impl { seek(0, SEEK_SET); } + impl(FILE * file) : owns_fp(false) { + fp = file; + fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp)); + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + size_t tell() const { LARGE_INTEGER li; li.QuadPart = 0; @@ -159,7 +175,7 @@ struct llama_file::impl { } ~impl() { - if (fp) { + if (fp && owns_fp) { std::fclose(fp); } } @@ -209,9 +225,16 @@ struct llama_file::impl { seek(0, SEEK_SET); } + impl(FILE * file) : fname("(file*)"), owns_fp(false) { + fp = file; + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + size_t tell() const { if (fd == -1) { - long ret = std::ftell(fp); + off_t ret = llama_mmap_ftell(fp); if (ret == -1) { throw std::runtime_error(format("ftell error: %s", strerror(errno))); } @@ -229,7 +252,7 @@ struct llama_file::impl { void seek(size_t offset, int whence) const { off_t ret = 0; if (fd == -1) { - ret = std::fseek(fp, (long) offset, whence); + ret = llama_mmap_fseek(fp, offset, whence); } else { ret = lseek(fd, offset, whence); } @@ -353,7 +376,7 @@ struct llama_file::impl { ~impl() { if (fd != -1) { close(fd); - } else { + } else if (owns_fp) { std::fclose(fp); } } @@ -369,10 +392,14 @@ struct llama_file::impl { FILE * fp{}; size_t size{}; + bool owns_fp = true; }; llama_file::llama_file(const char * fname, const char * mode, const bool use_direct_io) : pimpl(std::make_unique(fname, mode, use_direct_io)) {} + +llama_file::llama_file(FILE * file) : pimpl(std::make_unique(file)) {} + llama_file::~llama_file() = default; size_t llama_file::tell() const { return pimpl->tell(); } diff --git a/examples/talk-llama/llama-mmap.h b/examples/talk-llama/llama-mmap.h index 29ce4d24685..b7d5c61e95f 100644 --- a/examples/talk-llama/llama-mmap.h +++ b/examples/talk-llama/llama-mmap.h @@ -15,6 +15,7 @@ using llama_mlocks = std::vector>; struct llama_file { llama_file(const char * fname, const char * mode, bool use_direct_io = false); + llama_file(FILE * file); ~llama_file(); size_t tell() const; diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 413f34c2268..4e65a45a50d 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -36,6 +36,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "F16"; case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_Q1_0: return "Q1_0"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; @@ -374,8 +375,9 @@ namespace GGUFMeta { } } else { if (arr_info.gt == GGUF_TYPE_BOOL) { - std::transform((const bool *)arr_info.data, (const bool *)arr_info.data + arr_info.length, result.begin(), [](bool x) { - return static_cast(x); + const int8_t * values = (const int8_t *) arr_info.data; + std::transform(values, values + arr_info.length, result.begin(), [](int8_t x) { + return static_cast(x != 0); }); } else { std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); @@ -511,6 +513,7 @@ llama_model_loader::llama_model_loader( void * set_tensor_data_ud, const std::string & fname, std::vector & splits, + FILE * file, bool use_mmap, bool use_direct_io, bool check_tensors, @@ -658,6 +661,36 @@ llama_model_loader::llama_model_loader( LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1); } + } else if (file != nullptr) { + struct ggml_context * ctx = NULL; + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + + metadata_ptr.reset(gguf_init_from_file_ptr(file, params)); + metadata = metadata_ptr.get(); + if (metadata == nullptr) { + throw std::runtime_error(format("%s: failed to load model from file pointer", __func__)); + } + + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); + + files.emplace_back(new llama_file(file)); + contexts.emplace_back(ctx); + + // Save tensors data offset info of the main file. + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, metadata, cur)); + } } else { get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); llm_kv = LLM_KV(llm_arch_from_string(arch_name)); @@ -669,7 +702,7 @@ llama_model_loader::llama_model_loader( fver = (enum llama_fver) gguf_get_version(metadata); LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", - __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); + __func__, n_kv, n_tensors, fname.empty() ? "(file*)" : fname.c_str(), llama_file_version_name(fver)); // determine file type based on the number of tensors for each quantization and print meta data // TODO: make optional @@ -726,6 +759,7 @@ llama_model_loader::llama_model_loader( case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; case GGML_TYPE_NVFP4: ftype = LLAMA_FTYPE_MOSTLY_NVFP4; break; + case GGML_TYPE_Q1_0: ftype = LLAMA_FTYPE_MOSTLY_Q1_0; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); @@ -1127,6 +1161,12 @@ struct ggml_tensor * llama_model_loader::create_tensor( if (overrides->buft == ggml_backend_cpu_buffer_type()) { // when overriding to a CPU buffer, consider the extra buffer types buft = select_weight_buft(hparams, t_meta, op, buft_list_cpu); + if (use_mmap) { + static std::once_flag once; + std::call_once(once, [] { + LLAMA_LOG_WARN("llama_model_loader: tensor overrides to CPU are used with mmap enabled - consider using --no-mmap for better performance\n"); + }); + } } else { buft = overrides->buft; } diff --git a/examples/talk-llama/llama-model-loader.h b/examples/talk-llama/llama-model-loader.h index ed5de729caf..7b3d6703c03 100644 --- a/examples/talk-llama/llama-model-loader.h +++ b/examples/talk-llama/llama-model-loader.h @@ -125,6 +125,7 @@ struct llama_model_loader { void * set_tensor_data_ud, const std::string & fname, std::vector & splits, // optional, only need if the split does not follow naming scheme + FILE * file, bool use_mmap, bool use_direct_io, bool check_tensors, diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index 6f6538aeccd..26864c18e97 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -1,7 +1,9 @@ #include "llama-model-saver.h" +#include "ggml.h" #include "gguf.h" +#include "llama-arch.h" #include "llama.h" #include "llama-hparams.h" #include "llama-model.h" @@ -10,8 +12,33 @@ #include #include +bool llama_model_saver_supports_arch(llm_arch arch) { + switch (arch) { + case LLM_ARCH_QWEN3NEXT: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: + case LLM_ARCH_PLAMO3: + case LLM_ARCH_GEMMA3: + case LLM_ARCH_GEMMA3N: + case LLM_ARCH_COHERE2: + case LLM_ARCH_OLMO2: + case LLM_ARCH_BITNET: + case LLM_ARCH_T5: + case LLM_ARCH_EXAONE_MOE: + case LLM_ARCH_AFMOE: + case LLM_ARCH_APERTUS: + case LLM_ARCH_MIMO2: + case LLM_ARCH_STEP35: + return false; + default: + return true; + } +} + llama_model_saver::llama_model_saver(const struct llama_model * model) : - gguf_ctx(gguf_init_empty()), gguf_ctx_owned(true), model(model), llm_kv(model->arch) {} + gguf_ctx(gguf_init_empty()), gguf_ctx_owned(true), model(model), llm_kv(model->arch) { + GGML_ASSERT(llama_model_saver_supports_arch(model->arch)); +} llama_model_saver::llama_model_saver(enum llm_arch arch, struct gguf_context * gguf_ctx) : gguf_ctx(gguf_ctx == nullptr ? gguf_init_empty() : gguf_ctx), gguf_ctx_owned(gguf_ctx == nullptr), model(nullptr), llm_kv(arch) {} @@ -105,7 +132,10 @@ void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) { return; } if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) { - GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME + const std::string tensor_name = tensor->name; + GGML_ASSERT( + tensor_name == "rope_freqs.weight" || tensor_name == "rope_factors_long.weight" || + tensor_name == "rope_factors_short.weight"); // FIXME return; } gguf_add_tensor(gguf_ctx, tensor); @@ -127,6 +157,7 @@ void llama_model_saver::add_kv_from_model() { tokens[id] = token_data.text; scores[id] = token_data.score; + // FIXME should this be treated as flags? switch(token_data.attr) { case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break; case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break; @@ -134,6 +165,9 @@ void llama_model_saver::add_kv_from_model() { case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break; case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break; case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break; + // case LLAMA_TOKEN_ATTR_NORMALIZED: ??? + // case LLAMA_TOKEN_ATTR_LSTRIP: ??? + // case LLAMA_TOKEN_ATTR_RSTRIP: ??? case LLAMA_TOKEN_ATTR_UNDEFINED: default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break; } @@ -144,6 +178,19 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_GENERAL_ARCHITECTURE, model->arch_name()); // add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???); // add_kv(LLM_KV_GENERAL_ALIGNMENT, ???); + // add_kv(LLM_KV_GENERAL_FILE_TYPE, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_SEQUENCE, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_TOP_K, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_TOP_P, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIN_P, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_TEMP, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIROSTAT, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, ???); add_kv(LLM_KV_GENERAL_NAME, model->name); // add_kv(LLM_KV_GENERAL_AUTHOR, ???); // add_kv(LLM_KV_GENERAL_VERSION, ???); @@ -163,17 +210,31 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true); add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_chexp); + add_kv(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp); + add_kv(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp); add_kv(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); // add_kv(LLM_KV_TENSOR_DATA_LAYOUT, ???); add_kv(LLM_KV_EXPERT_COUNT, hparams.n_expert); add_kv(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); add_kv(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + add_kv(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups); + add_kv(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used); add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + add_kv(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm); + add_kv(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + add_kv(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); + add_kv(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); + add_kv(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers); + add_kv(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers); + add_kv(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers); add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type)); add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id); + add_kv(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer); add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping); + add_kv(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping); add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping); add_kv(LLM_KV_SWIN_NORM, hparams.swin_norm); add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers); @@ -181,6 +242,9 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); add_kv(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + add_kv(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count); + add_kv(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + // add_kv(LLM_KV_FULL_ATTENTION_INTERVAL, ???); add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true); add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true); @@ -188,22 +252,39 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full); add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full); - add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); - add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + add_kv(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); + add_kv(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); add_kv(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); add_kv(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + add_kv(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); + add_kv(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); + add_kv(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); + add_kv(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate); add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + // add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, ???); add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + add_kv(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale); + add_kv(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length); + add_kv(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); + add_kv(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + add_kv(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + add_kv(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train; add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full); add_kv(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa); + add_kv(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections); add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train); + add_kv(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); // add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train)); add_kv(LLM_KV_ROPE_SCALING_FACTOR, rope_scaling_factor); @@ -211,6 +292,10 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn); add_kv(LLM_KV_ROPE_SCALING_FINETUNED, hparams.rope_finetuned); add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + add_kv(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor); + add_kv(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor); + add_kv(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast); + add_kv(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow); // TODO: implement split file support // add_kv(LLM_KV_SPLIT_NO, ???); @@ -221,8 +306,11 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); add_kv(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); add_kv(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + add_kv(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); add_kv(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms); + add_kv(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); + add_kv(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); add_kv(LLM_KV_TOKENIZER_MODEL, vocab.get_tokenizer_model()); @@ -260,15 +348,39 @@ void llama_model_saver::add_kv_from_model() { // TODO: implement LoRA support // add_kv(LLM_KV_ADAPTER_TYPE, ???); // add_kv(LLM_KV_ADAPTER_LORA_ALPHA, ???); + // add_kv(LLM_KV_ADAPTER_LORA_TASK_NAME, ???); + // add_kv(LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, ???); + // add_kv(LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, ???); + + add_kv(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); + add_kv(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer); + + add_kv(LLM_KV_CONVNEXT_EMBEDDING_LENGTH, hparams.convnext.n_embd); + add_kv(LLM_KV_CONVNEXT_BLOCK_COUNT, hparams.convnext.n_layer); + + add_kv(LLM_KV_CLASSIFIER_OUTPUT_LABELS, model->classifier_labels); + + add_kv(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + + add_kv(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n); + add_kv(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p); + add_kv(LLM_KV_XIELU_BETA, hparams.xielu_beta); + add_kv(LLM_KV_XIELU_EPS, hparams.xielu_eps); // deprecated // add_kv(LLM_KV_TOKENIZER_PREFIX_ID, ???); // add_kv(LLM_KV_TOKENIZER_SUFFIX_ID, ???); // add_kv(LLM_KV_TOKENIZER_MIDDLE_ID, ???); + + add_kv(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in); + add_kv(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out); + add_kv(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in); + add_kv(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out); } void llama_model_saver::add_tensors_from_model() { - if (std::string(model->output->name) != std::string(model->tok_embd->name)) { + if (model->output != nullptr && + std::string(model->output->name) != std::string(model->tok_embd->name)) { add_tensor(model->tok_embd); // some models use the same tensor for tok_embd and output } add_tensor(model->type_embd); @@ -297,3 +409,6 @@ void llama_model_saver::save(const std::string & path_model) { gguf_write_to_file(gguf_ctx, path_model.c_str(), false); } +void llama_model_saver::save(FILE * file) { + gguf_write_to_file_ptr(gguf_ctx, file, false); +} diff --git a/examples/talk-llama/llama-model-saver.h b/examples/talk-llama/llama-model-saver.h index 2b3541ce6c5..36a715e2b6b 100644 --- a/examples/talk-llama/llama-model-saver.h +++ b/examples/talk-llama/llama-model-saver.h @@ -6,6 +6,9 @@ #include +// FIXME temporary function for better error messages +bool llama_model_saver_supports_arch(llm_arch arch); + struct llama_model_saver { struct gguf_context * gguf_ctx = nullptr; const bool gguf_ctx_owned; @@ -37,4 +40,5 @@ struct llama_model_saver { void add_tensors_from_model(); void save(const std::string & path_model); + void save(FILE * file); }; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index e8e1bbf1cd1..9e2a13cbd43 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -1,6 +1,8 @@ #include "llama-model.h" -#include "ggml.h" +#include "llama-arch.h" +#include "llama-ext.h" +#include "llama-hparams.h" #include "llama-impl.h" #include "llama-mmap.h" #include "llama-cparams.h" @@ -12,10 +14,11 @@ #include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" -#include "ggml-cpp.h" - #include "models/models.h" +#include "ggml.h" +#include "ggml-cpp.h" + #include #include #include @@ -24,9 +27,358 @@ #include #include #include +#include #include #include #include +#include +#include + +struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata) { + const llama_meta_device_get_split_state_userdata * ud = (const llama_meta_device_get_split_state_userdata *) userdata; + const llama_hparams & hparams = ud->model->hparams; + const std::string tensor_name = tensor->name; + + const std::regex pattern_q_weight ("blk\\.\\d*\\.attn_q.weight"); + const std::regex pattern_kv_weight ("blk\\.\\d*\\.attn_(k|v).weight"); + const std::regex pattern_qkv_weight ("blk\\.\\d*\\.attn_qkv.weight"); + const std::regex pattern_q_bias ("blk\\.\\d*\\.attn_q\\.bias"); + const std::regex pattern_kv_bias ("blk\\.\\d*\\.attn_(k|v)\\.bias"); + const std::regex pattern_qkv_bias ("blk\\.\\d*\\.attn_qkv.bias"); + const std::regex pattern_qk_norm ("blk\\.\\d*\\.attn_(q|k)_norm\\.weight"); + const std::regex pattern_kv_cache ("cache_(k|v)_l\\d*"); + const std::regex pattern_attn_sinks ("blk\\.\\d*\\.attn_sinks.weight"); + const std::regex pattern_attn_out_weight ("blk\\.\\d*\\.attn_output.weight"); + const std::regex pattern_attn_out_bias ("blk\\.\\d*\\.attn_output.bias"); + const std::regex pattern_attn_gate_weight("blk\\.\\d*\\.attn_gate.weight"); + + const std::regex pattern_ssm_dt ("blk\\.\\d*\\.ssm_dt.bias"); + const std::regex pattern_ssm_a ("blk\\.\\d*\\.ssm_a"); + const std::regex pattern_ssm_alpha ("blk\\.\\d*\\.ssm_alpha.weight"); + const std::regex pattern_ssm_beta ("blk\\.\\d*\\.ssm_beta.weight"); + const std::regex pattern_ssm_beta_alpha ("blk\\.\\d*\\.ssm_ba.weight"); + const std::regex pattern_r_cache ("cache_r_l\\d*"); + const std::regex pattern_s_cache ("cache_s_l\\d*"); + const std::regex pattern_ssm_conv1d ("blk\\.\\d*\\.ssm_conv1d.weight"); + const std::regex pattern_ssm_out_weight ("blk\\.\\d*\\.ssm_out.weight"); + + const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight"); + const std::regex pattern_ffn_up_gate_bias ("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias"); + const std::regex pattern_ffn_gate_up_weight("blk\\.\\d*\\.ffn_gate_up(_exps)?.weight"); + const std::regex pattern_ffn_down_weight ("blk\\.\\d*\\.ffn_down(_exps)?.weight"); + const std::regex pattern_ffn_down_bias ("blk\\.\\d*\\.ffn_down.bias"); + const std::regex pattern_ffn_down_exps_bias("blk\\.\\d*\\.ffn_down_exps.bias"); + + const std::regex pattern_output_weight("output\\.weight"); + const std::regex pattern_output_bias ("output\\.bias"); + + struct tensor_config { + ggml_backend_meta_split_axis axis; + + const ggml_tensor * tensor_axis_0; + + uint32_t il; + size_t rotation; // when assigning tensor slices, rotate how the rounding is done for more even allocation + }; + + auto get_tensor_config_impl = [&]( + const ggml_backend_meta_split_axis axis, const std::string & suffix = "", const std::string & suffix_fallback = "") -> tensor_config { + // the layers in a tensor can be inhomogeneous, if the pattern is cleanly divided by the number of GPUs there can be aliasing effects, + // count only the same type of previous layers to avoid this + auto get_il_eff = [&](const size_t il){ + size_t ret = 0; + const bool il_is_recurrent = hparams.is_recurrent(il); + const bool il_is_swa = hparams.is_swa(il); + for (size_t il_prev = 0; il_prev < il; il_prev++) { + ret += hparams.is_recurrent(il_prev) == il_is_recurrent && hparams.is_swa(il_prev) == il_is_swa; + } + return ret; + }; + + uint32_t il; + std::string prefix; + size_t rotation; + if (tensor_name.substr(0, 4) == "blk.") { + const size_t length_prefix = tensor_name.find('.', 4); + GGML_ASSERT(length_prefix != std::string::npos); + prefix = tensor_name.substr(0, length_prefix + 1); + il = std::stoull(tensor_name.substr(4, length_prefix)); + rotation = get_il_eff(il) % ud->n_devices; + } else if (tensor_name.substr(0, 6) == "cache_") { + const size_t layer_index_start = tensor_name.find("_l", 6); + GGML_ASSERT(layer_index_start != std::string::npos); + il = std::stoull(tensor_name.substr(layer_index_start + 2)); + prefix = "blk." + std::to_string(il) + "."; + rotation = get_il_eff(il) % ud->n_devices; + } else { + il = 0; + rotation = hparams.n_layer % ud->n_devices; + } + const ggml_tensor * tensor_axis_0 = suffix.empty() ? tensor : ud->model->get_tensor((prefix + suffix).c_str()); + if (tensor_axis_0 == nullptr) { + GGML_ASSERT(!suffix_fallback.empty()); + tensor_axis_0 = ud->model->get_tensor((prefix + suffix_fallback).c_str()); + } + GGML_ASSERT(tensor_axis_0 != nullptr); + return {axis, tensor_axis_0, il, rotation}; + }; + + auto get_tensor_config = [&]() -> tensor_config { + // standard attention + if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_kv_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight"); + } + if (std::regex_match(tensor_name, pattern_q_bias) || std::regex_match(tensor_name, pattern_kv_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight"); + } + if (std::regex_match(tensor_name, pattern_qkv_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + } + if ( std::regex_match(tensor_name, pattern_qkv_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + if (std::regex_match(tensor_name, pattern_qk_norm)) { + return get_tensor_config_impl(tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight"); + } + if (std::regex_match(tensor_name, pattern_kv_cache) || std::regex_match(tensor_name, pattern_attn_sinks)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight"); + } + if (std::regex_match(tensor_name, pattern_attn_out_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + if (std::regex_match(tensor_name, pattern_attn_out_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); + } + + if (std::regex_match(tensor_name, pattern_attn_gate_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + } + if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta) || + std::regex_match(tensor_name, pattern_ssm_beta_alpha)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_r_cache) || std::regex_match(tensor_name, pattern_s_cache)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_conv1d)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_out_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + + // FFN + if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_up_gate_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_down_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_down_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); + } + if (std::regex_match(tensor_name, pattern_ffn_down_exps_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_PARTIAL); + } + + // output + if (std::regex_match(tensor_name, pattern_output_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + } + if (std::regex_match(tensor_name, pattern_output_bias)) { + const ggml_tensor * output_weight = ud->model->get_tensor("output.weight"); + GGML_ASSERT(output_weight != nullptr); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + + // everything else + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); + }; + + auto get_split_segments = [&](int axis, uint32_t il) -> std::vector { + if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + + // both Qwen 3 Next and Qwen 3.5 support n_v_heads > n_k_heads but the broadcasting pattern is different: + // - Qwen 3 Next: [k0_v0, k0_v1, k1_v2, k1_v3] (this is the default split pattern) + // - Qwen 3.5: [k0_v0, k1_v1, k0_v2, k1_v3] (needs segmenting of V on the scale of K to get the correct pattern) + if (ud->model->arch == LLM_ARCH_QWEN3NEXT) { + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { + GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); + return {key_dim, key_dim, value_dim}; + } + } else { + const int64_t head_ratio = n_v_heads / n_k_heads; + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { + GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); + return std::vector(2 + head_ratio, key_dim); + } + if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { + return std::vector(head_ratio, key_dim); + } + if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) || + std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) { + return std::vector(head_ratio, n_k_heads); + } + if (std::regex_match(tensor_name, pattern_r_cache)) { + return std::vector(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1)); + } + if (std::regex_match(tensor_name, pattern_s_cache)) { + return std::vector(head_ratio, n_k_heads * head_v_dim * head_v_dim); + } + } + + // the FFN is the same for Qwen 3 Next and Qwen 3.5: + if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { + const int64_t n_ff_exp = hparams.n_ff_exp; + GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); + return {n_ff_exp, n_ff_exp}; + } + return {tensor->ne[axis]}; + } + + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { + const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(il); + GGML_ASSERT(hparams.n_embd_k_gqa() == n_embd_gqa); + GGML_ASSERT(tensor->ne[axis] == n_embd + 2*n_embd_gqa); + return {n_embd, n_embd_gqa, n_embd_gqa}; + } + if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { + const int64_t n_ff_exp = hparams.n_ff_exp; + GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); + return {n_ff_exp, n_ff_exp}; + } + return {tensor->ne[axis]}; + }; + + auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector & segments) -> std::vector { + if (hparams.is_recurrent(il)) { + // linear attention + const int64_t head_dim = hparams.ssm_d_state; + const int64_t granularity_qkv = std::lcm(blck_size, head_dim); + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) || + std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { + return std::vector(segments.size(), granularity_qkv); + } + if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) || + std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) { + return std::vector(segments.size(), granularity_qkv / head_dim); + } + if (std::regex_match(tensor_name, pattern_ssm_beta_alpha)) { + return std::vector(segments.size(), 2 * (granularity_qkv / head_dim)); + } + if (std::regex_match(tensor_name, pattern_r_cache)) { + return std::vector(segments.size(), granularity_qkv * (hparams.ssm_d_conv - 1)); + } + if (std::regex_match(tensor_name, pattern_s_cache)) { + return std::vector(segments.size(), granularity_qkv * head_dim); + } + } else { + // regular attention + const uint32_t n_gqa = hparams.n_gqa(il); + const uint32_t n_embd_q = n_gqa * hparams.n_embd_head_k(il); + if (std::regex_match(tensor_name, pattern_attn_sinks)) { + GGML_ASSERT(segments.size() == 1); + return {std::lcm(n_embd_q, blck_size)/n_embd_q * n_gqa}; + } + + const int64_t granularity_q = std::lcm(n_embd_q, blck_size); + if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_q_bias)) { + GGML_ASSERT(segments.size() == 1); + // some models have Q gate tensors, for those cases the granularity needs to be doubled: + if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { + return {std::lcm(2*n_embd_q, blck_size)}; + } + return {granularity_q}; + } + if (std::regex_match(tensor_name, pattern_attn_out_weight)) { + GGML_ASSERT(segments.size() == 1); + return {granularity_q}; + } + + const int64_t granularity_kv = granularity_q / n_gqa; + if (std::regex_match(tensor_name, pattern_kv_weight) || + std::regex_match(tensor_name, pattern_kv_bias) || + std::regex_match(tensor_name, pattern_kv_cache)) { + GGML_ASSERT(segments.size() == 1); + return {granularity_kv}; + } + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { + GGML_ASSERT(segments.size() == 3); + return {granularity_q, granularity_kv, granularity_kv}; + } + } + + // FFN + if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) || + std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) { + GGML_ASSERT(segments.size() <= 2); + return std::vector(segments.size(), blck_size); + } + + // everything else + GGML_ASSERT(segments.size() == 1); + return {1}; + }; + + ggml_backend_meta_split_state split_state; + memset(&split_state, 0, sizeof(split_state)); + tensor_config tc = get_tensor_config(); + split_state.axis = tc.axis; + if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { + const int64_t ne_full = tensor->ne[split_state.axis]; + const int64_t blck_size = ggml_blck_size(tc.tensor_axis_0->type); + const float * tensor_split = ud->model->tensor_split(); + std::vector tensor_split_scan; + tensor_split_scan.reserve(ud->n_devices); + for (size_t j = 0; j < ud->n_devices; j++) { + tensor_split_scan.push_back(tensor_split == nullptr ? 0.0f : tensor_split[(j + tc.rotation) % ud->n_devices]); + if (j > 0) { + tensor_split_scan[j] += tensor_split_scan[j - 1]; + } + } + const std::vector segments = get_split_segments(split_state.axis, tc.il); + const std::vector granularity = get_split_granularity(blck_size, tc.il, segments); + for (size_t is = 0; is < segments.size(); is++) { + const int64_t ne_s = segments[is]; + const int64_t g_s = granularity[is]; + GGML_ASSERT(ne_full % g_s == 0); + int64_t low = 0; + size_t j = 0; + for (; j < ud->n_devices - 1; j++) { + int64_t high = tensor_split_scan.back() == 0.0f ? + ne_s * (j+1)/ud->n_devices : ne_s * tensor_split_scan[j]/tensor_split_scan.back(); + if (high % g_s != 0) { + high -= high % g_s; + } + split_state.ne[is*ud->n_devices + (j + tc.rotation) % ud->n_devices] = high - low; + low = high; + } + split_state.ne[is*ud->n_devices + (j + tc.rotation) % ud->n_devices] = ne_s - low; + } + split_state.n_segments = segments.size(); + } else { + memset(split_state.ne, 0, sizeof(split_state.ne)); + split_state.n_segments = 1; + } + return split_state; + GGML_UNUSED(userdata); +} const char * llm_type_name(llm_type type) { switch (type) { @@ -93,6 +445,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_26B: return "26B"; case LLM_TYPE_27B: return "27B"; case LLM_TYPE_30B: return "30B"; + case LLM_TYPE_31B: return "31B"; case LLM_TYPE_32B: return "32B"; case LLM_TYPE_34B: return "34B"; case LLM_TYPE_35B: return "35B"; @@ -127,6 +480,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_24B_A2B: return "24B.A2B"; + case LLM_TYPE_26B_A4B: return "26B.A4B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; case LLM_TYPE_35B_A3B: return "35B.A3B"; @@ -181,7 +535,7 @@ static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::st } // CPU: ACCEL -> GPU host -> CPU extra -> CPU -static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { +static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; // add ACCEL buffer types @@ -203,10 +557,10 @@ static buft_list_t make_cpu_buft_list(const std::vector & de // a better approach would be to handle this on a weight-by-weight basis using the offload_op // function of the device to determine if it would benefit from being stored in a host buffer if (!no_host) { - for (auto * dev : devices) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + for (const auto & dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev.dev); if (buft) { - buft_list.emplace_back(dev, buft); + buft_list.emplace_back(dev.dev, buft); break; } } @@ -273,14 +627,16 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s // add the device extra buffer type (if any) ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); - auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) - ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts"); - - if (ggml_backend_dev_get_extra_bufts_fn) { - ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(dev); - while (extra_bufts && *extra_bufts) { - buft_list.emplace_back(dev, *extra_bufts); - ++extra_bufts; + if (reg) { + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts"); + + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(dev, *extra_bufts); + ++extra_bufts; + } } } @@ -342,6 +698,9 @@ void llama_model::load_arch(llama_model_loader & ml) { if (arch == LLM_ARCH_UNKNOWN) { throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); } + if (!devices.empty() && devices[0].is_meta && !llm_arch_supports_sm_tensor(arch)) { + throw std::runtime_error(std::string("LLAMA_SPLIT_MODE_TENSOR not implemented for architecture '") + llm_arch_name(arch) + "'"); + } } void llama_model::load_hparams(llama_model_loader & ml) { @@ -370,12 +729,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl, false); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false); ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false); + if (arch == LLM_ARCH_HUNYUAN_VL || arch == LLM_ARCH_HUNYUAN_DENSE) { + if (hparams.n_expert <= 1) { + hparams.n_expert = 0; + hparams.n_expert_used = 0; + } + } + if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl); @@ -454,6 +822,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_ALPHA, hparams.rope_scaling_alpha, false); // non-transformer models do not have attention heads if (hparams.n_head() > 0) { @@ -748,8 +1117,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { case 3: @@ -781,8 +1148,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { } ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { case 12: @@ -797,8 +1162,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_JINA_BERT_V2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); hparams.f_max_alibi_bias = 8.0f; switch (hparams.n_layer) { @@ -810,8 +1173,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_JINA_BERT_V3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); switch (hparams.n_layer) { case 24: @@ -823,8 +1184,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_NOMIC_BERT_MOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); if (hparams.n_layer == 12 && hparams.n_embd == 768) { @@ -838,8 +1197,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_NEO_BERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); if (hparams.n_layer == 28) { type = LLM_TYPE_250M; @@ -848,8 +1205,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { case LLM_ARCH_EUROBERT: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); if (hparams.n_layer == 12) { type = LLM_TYPE_SMALL; // 0.2B @@ -913,7 +1268,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { // fall through case LLM_ARCH_QWEN2: { - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; @@ -940,8 +1294,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } // Set non-causal attention for diffusion models hparams.causal_attn = false; - } - break; + } break; case LLM_ARCH_LLADA: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -955,8 +1308,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } // Set non-causal attention for diffusion models hparams.causal_attn = false; - } - break; + } break; case LLM_ARCH_LLADA_MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); @@ -995,7 +1347,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_QWEN3: { - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; @@ -1275,6 +1626,34 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GEMMA4: + { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + uint32_t n_kv_shared_layers = 0; + ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); + + hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t)n_kv_shared_layers; + hparams.f_attention_scale = 1.0f; // Gemma4 uses self.scaling = 1.0 (no pre-attn scaling) + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, hparams.n_embd_per_layer); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_26B_A4B; break; + case 35: type = LLM_TYPE_E2B; break; + case 42: type = LLM_TYPE_E4B; break; + case 60: type = LLM_TYPE_31B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_GEMMA_EMBEDDING: { hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; @@ -1287,7 +1666,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); //applied only if model converted with --sentence-transformers-dense-modules ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); @@ -1587,6 +1965,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_MISTRAL4: { // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); @@ -1623,7 +2002,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // (optional) temperature tuning - used by mistral-large ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); // FIXME why not use temperature_length? hparams.f_attn_temp_offset = 0.0f; @@ -1635,6 +2014,26 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_DEEPSEEK2OCR: + { + // similar to deepseek2, but without MLA + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } + + switch (hparams.n_layer) { + case 12: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_PLM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1672,6 +2071,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters (GLM-OCR) ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -1705,6 +2105,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -1751,6 +2152,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // NextN/MTP parameters ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -1925,6 +2327,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); switch (hparams.n_layer) { case 32: type = LLM_TYPE_30B_A3B; break; @@ -2053,7 +2456,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_embd) { case 768: type = LLM_TYPE_350M; break; - case 1536: type = (hparams.n_embd == 2048 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; + case 1536: type = (hparams.n_ff() == 512 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; case 2048: case 2560: type = LLM_TYPE_3B; break; case 4096: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; @@ -2079,7 +2482,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); } break; case LLM_ARCH_BAILINGMOE: { @@ -2107,6 +2509,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); // TODO: when MTP is implemented, this should probably be updated if needed hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; @@ -2197,9 +2600,18 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // XDRoPE / NTK-aware scaling: base = rope_theta * alpha^(dim / (dim - 2)) + if (hparams.rope_scaling_alpha > 0.0f) { + const int dim = hparams.n_embd_head_k(); + hparams.rope_freq_base_train = hparams.rope_freq_base_train + * powf(hparams.rope_scaling_alpha, (float)dim / (float)(dim - 2)); + } switch (hparams.n_embd) { case 1024: type = LLM_TYPE_0_5B; break; @@ -2588,11 +3000,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // build a list of buffer types for the CPU and GPU devices pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); - for (auto * dev : devices) { - buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); + for (const auto & dev : devices) { + buft_list_t buft_list = make_gpu_buft_list(dev.dev, split_mode, tensor_split); // add CPU buffer types as a fallback buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); - pimpl->gpu_buft_list.emplace(dev, std::move(buft_list)); + pimpl->gpu_buft_list.emplace(dev.dev, std::move(buft_list)); } ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); @@ -2606,7 +3018,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (all_zero) { // default split, by free memory for (size_t i = 0; i < n_devices(); ++i) { - ggml_backend_dev_t dev = devices[i]; + ggml_backend_dev_t dev = devices[i].dev; size_t total; size_t free; ggml_backend_dev_memory(dev, &free, &total); @@ -2642,7 +3054,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { return {cpu_dev, &pimpl->cpu_buft_list}; } const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); - auto * dev = devices.at(layer_gpu); + auto * dev = devices.at(layer_gpu).dev; LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(dev), is_swa); return {dev, &pimpl->gpu_buft_list.at(dev)}; }; @@ -2708,6 +3120,25 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); } }; + + // helper: try to load merged qkv first, fall back to separate q, k, v + auto create_tensor_qkv = [&](llama_layer & layer, int bid, + int64_t n_embd_, int64_t n_embd_q_, int64_t n_embd_k_, int64_t n_embd_v_, + int flags) { + const int64_t n_embd_qkv = n_embd_q_ + n_embd_k_ + n_embd_v_; + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", bid), {n_embd_, n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + if (layer.wqkv) { + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", bid), {n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + } else { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", bid), {n_embd_, n_embd_q_}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", bid), {n_embd_, n_embd_k_}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", bid), {n_embd_, n_embd_v_}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", bid), {n_embd_q_}, TENSOR_NOT_REQUIRED); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", bid), {n_embd_k_}, TENSOR_NOT_REQUIRED); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); + } + }; + switch (arch) { case LLM_ARCH_LLAMA: case LLM_ARCH_REFACT: @@ -2733,16 +3164,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -2805,7 +3231,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); @@ -2841,9 +3267,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -2882,9 +3306,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -2928,7 +3350,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); const int64_t n_ff = hparams.n_ff(i); const int64_t n_head = hparams.n_head(i); const int64_t n_head_kv = hparams.n_head_kv(i); @@ -2941,17 +3362,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { else if (n_head_kv > 0) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); } // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); if (n_ff > 0) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3043,9 +3459,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); @@ -3108,9 +3522,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3175,10 +3587,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -3211,28 +3623,16 @@ bool llama_model::load_tensors(llama_model_loader & ml) { cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); } - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - if (!layer.wqkv) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); @@ -3259,7 +3659,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_MODERN_BERT: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -3325,9 +3725,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3342,31 +3740,24 @@ bool llama_model::load_tensors(llama_model_loader & ml) { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); //LayerNorm bias + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); // LayerNorm + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // LayerNorm bias cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, TENSOR_NOT_REQUIRED); for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; // JinaBertLayer - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); @@ -3394,8 +3785,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_BLOOM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -3414,10 +3805,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -3450,10 +3841,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); @@ -3490,16 +3881,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - // optional bias tensors, present in Stable LM 2 1.6B - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - // optional q and k layernorms, present in StableLM 2 12B layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); @@ -3527,7 +3911,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -3557,16 +3941,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); @@ -3587,16 +3964,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); @@ -3645,9 +4015,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -3678,9 +4046,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -3721,22 +4087,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); @@ -3763,7 +4117,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, TENSOR_NOT_REQUIRED); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); @@ -3793,19 +4147,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED); - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); @@ -3832,9 +4176,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); @@ -3971,10 +4313,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4006,11 +4348,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4036,9 +4377,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4062,9 +4401,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4086,9 +4423,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4110,9 +4445,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); @@ -4147,9 +4480,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); @@ -4175,13 +4506,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); + altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0); - per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0); + per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); + per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0); + per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_altup}, 0); output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -4190,9 +4522,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); @@ -4219,6 +4549,101 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0); } } break; + case LLM_ARCH_GEMMA4: + { + const uint32_t n_embd_per_layer = hparams.n_embd_per_layer; + const int64_t n_ff_exp = hparams.n_ff_exp; + + if (n_embd_head_k != n_embd_head_v) { + throw std::runtime_error("Gemma 4 requires n_embd_head_k == n_embd_head_v"); + } + if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { + throw std::runtime_error("Gemma 4 requires n_embd_head_k_swa == n_embd_head_v_swa"); + } + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + if (n_embd_per_layer > 0) { + per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0); + per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_per_layer * n_layer}, 0); + per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_per_layer}, 0); + } + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + int rope_freqs_flag = 0; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const int64_t n_head = hparams.n_head(i); + const int64_t n_embd_head = hparams.n_embd_head_k(i); + const int64_t n_embd_k = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v = hparams.n_embd_v_gqa(i); + const int kv_flags = hparams.has_kv(i) ? 0 : TENSOR_NOT_REQUIRED; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // note: use_alternative_attention (v_proj is optional, if it's not present, use k_proj) + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, kv_flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head * n_head, n_embd}, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, kv_flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, TENSOR_NOT_REQUIRED); + + if (!hparams.is_swa(i)) { + // full_attention layers use rope_freqs for proportional rope + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_embd_head/2}, rope_freqs_flag); + rope_freqs_flag = TENSOR_DUPLICATED; + } + + // handle use_double_wide_mlp + int64_t n_ff_cur = hparams.n_ff(i); + + // for expert layers, we use normal FFN as shared expert (same as python code) + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_cur}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + // MoE router + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + bool has_expert = layer.ffn_gate_inp != nullptr; + + // norm + if (has_expert) { + layer.ffn_gate_inp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "scale", i), {n_embd}, 0); + + layer.ffn_pre_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM_2, "weight", i), {n_embd}, 0); + layer.ffn_post_norm_1 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_1, "weight", i), {n_embd}, 0); + layer.ffn_post_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_2, "weight", i), {n_embd}, 0); + + // MoE FFN + layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i), {n_embd, n_ff_exp * 2, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + + // per-expert scale will be loaded as down_exps_s at the end of the current switch case + } + + // per-layer embeddings + if (n_embd_per_layer > 0) { + layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_per_layer}, 0); + layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_per_layer, n_embd}, 0); + layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); + } + } + } break; case LLM_ARCH_STARCODER2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4239,16 +4664,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4414,9 +4834,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } else { // Attention layers - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); } @@ -4492,14 +4910,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_head_i = hparams.n_head(i); const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); } // feed forward (w/ optional biases) @@ -4542,9 +4955,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4572,9 +4983,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); } - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); @@ -4597,9 +5006,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); @@ -4622,9 +5029,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); @@ -4645,9 +5050,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_head_kv * n_embd_head}, 0); @@ -4678,14 +5081,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, 0); + create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); @@ -4709,9 +5107,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0); @@ -4778,10 +5174,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -4811,9 +5207,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4850,9 +5244,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -4883,6 +5275,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_MISTRAL4: { const bool is_mla = hparams.is_mla(); @@ -4960,6 +5353,60 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } + } break; + case LLM_ARCH_DEEPSEEK2OCR: + { + // similar to deepseek2, but without MLA + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // norm + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + // Shared expert branch layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); @@ -5145,10 +5592,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5187,10 +5634,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // attention biases - all have shape n_embd (output dimension of projections) - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5218,17 +5665,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - } + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); @@ -5261,17 +5698,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { auto & layer = layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); - - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, flags); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, flags); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, flags); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); - } + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); @@ -5329,12 +5756,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags); // GLM-style attention with bias terms - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, TENSOR_NOT_REQUIRED | flags); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, TENSOR_NOT_REQUIRED | flags); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, TENSOR_NOT_REQUIRED | flags); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); @@ -5514,16 +5936,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); @@ -5590,14 +6007,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t n_head_i = hparams.n_head(i); const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); } else { if (n_expert != 0) { const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; @@ -5645,9 +6057,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -5673,9 +6083,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); @@ -5718,9 +6126,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, flags); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, flags); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, flags); + create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, flags); layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0) | flags); @@ -5773,8 +6179,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // Block 0, LN0 - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -5888,8 +6294,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // Block 0, LN0 - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); @@ -6044,9 +6450,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6060,8 +6464,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0); - conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); - conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0); + conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight", 0), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); + conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias", 0), {1, hparams.posnet.n_embd}, 0); // posnet { @@ -6126,8 +6530,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {hparams.posnet.n_embd}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {hparams.posnet.n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {hparams.posnet.n_embd}, 0); // convnext { @@ -6175,9 +6579,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0); + create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6278,9 +6680,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_head_k * n_head, n_embd_head_k * n_head, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -6333,9 +6733,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6370,9 +6768,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); // attention projections - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // Q/K normalization @@ -6430,16 +6826,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6519,14 +6910,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { /*ATTENTION LAYERS*/ // attention layers (with optional bias) - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, n_embd_head_k * attn_num_attention_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_k}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_v}, 0); + create_tensor_qkv(layer, i, hidden_size, n_embd_head_k * attn_num_attention_head, attn_num_key_value_head * n_embd_head_k, attn_num_key_value_head * n_embd_head_v, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * n_embd_head_k}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * n_embd_head_v}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0); @@ -6560,9 +6946,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -6580,6 +6964,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); } } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -6597,9 +6982,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -6631,9 +7014,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -6658,9 +7039,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0); + create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0); @@ -6670,11 +7049,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - // bias - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_head * n_rot}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_head_kv * n_rot}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_head_kv * n_rot}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); @@ -6722,9 +7097,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, hparams.n_embd_k_gqa(i)}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, hparams.n_embd_v_gqa(i)}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, hparams.n_embd_k_gqa(i), hparams.n_embd_v_gqa(i), 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); } else { @@ -6756,9 +7129,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); @@ -6795,9 +7166,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -6841,16 +7210,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); @@ -6874,9 +7238,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -6933,9 +7295,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // q, k, v projections // Python: q_proj, k_proj, v_proj - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k_kda * n_head}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_v_kda * n_head}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k_kda * n_head, n_embd_head_k_kda * n_head, n_embd_head_v_kda * n_head, 0); // KDA specific projections // f_a_proj, f_b_proj @@ -7081,16 +7441,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); // weight tensors - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); // bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -7147,9 +7502,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!hparams.is_recurrent(i)) { // Attention layers - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); // Q/K normalization for attention layers @@ -7213,9 +7566,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!hparams.is_recurrent(i)) { // Attention layers - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); // Q/K normalization for attention layers @@ -7278,9 +7629,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!hparams.is_recurrent(i)) { // Attention layers - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); // Q/K normalization for attention layers @@ -7319,9 +7668,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); uint32_t n_head = hparams.n_head(i); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -7380,9 +7727,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_l}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); // head-wise attention gate (Step35 self_attn.g_proj) @@ -7426,9 +7771,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); @@ -7501,6 +7844,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // recurrent / linear-attention weight scales (per-tensor, shape {1}) + if (!layer.ssm_in_s && layer.ssm_in) { + layer.ssm_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } if (!layer.ssm_out_s && layer.ssm_out) { layer.ssm_out_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); } @@ -7510,11 +7856,77 @@ bool llama_model::load_tensors(llama_model_loader & ml) { if (!layer.ssm_beta_s && layer.ssm_beta) { layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); } + + // input scales + if (!layer.wq_in_s && layer.wq) { + layer.wq_in_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wk_in_s && layer.wk) { + layer.wk_in_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wv_in_s && layer.wv) { + layer.wv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wo_in_s && layer.wo) { + layer.wo_in_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_in_s && layer.wqkv) { + layer.wqkv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_gate_in_s && layer.wqkv_gate) { + layer.wqkv_gate_in_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_in_s && layer.ffn_gate) { + layer.ffn_gate_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_in_s && layer.ffn_down) { + layer.ffn_down_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_in_s && layer.ffn_up) { + layer.ffn_up_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_exps_in_s && layer.ffn_gate_exps) { + layer.ffn_gate_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_exps_in_s && layer.ffn_down_exps) { + layer.ffn_down_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_exps_in_s && layer.ffn_up_exps) { + layer.ffn_up_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_shexp_in_s && layer.ffn_gate_shexp) { + layer.ffn_gate_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_shexp_in_s && layer.ffn_down_shexp) { + layer.ffn_down_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_shexp_in_s && layer.ffn_up_shexp) { + layer.ffn_up_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_in_in_s && layer.ssm_in) { + layer.ssm_in_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_out_in_s && layer.ssm_out) { + layer.ssm_out_in_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_alpha_in_s && layer.ssm_alpha) { + layer.ssm_alpha_in_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_beta_in_s && layer.ssm_beta) { + layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } } } ml.done_getting_tensors(); + // populate tensors_by_name + for (auto & [_, ctx_ptr] : ml.ctx_map) { + for (auto * cur = ggml_get_first_tensor(ctx_ptr.get()); cur != NULL; cur = ggml_get_next_tensor(ctx_ptr.get(), cur)) { + tensors_by_name.emplace_back(ggml_get_name(cur), cur); + } + } + ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr); pimpl->mappings.reserve(ml.mappings.size()); @@ -7597,14 +8009,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { buf_map.emplace(idx, buf); } } - pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs)); - for (auto & buf : buf_map) { + for (auto & buf : bufs) { // indicate that this buffer contains weights // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight - ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + ggml_backend_buffer_set_usage(buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); } + pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs)); + ctx_buf_maps.emplace_back(ctx, buf_map); } @@ -7632,13 +8045,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } - // populate tensors_by_name - for (auto & [ctx, _] : pimpl->ctxs_bufs) { - for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) { - tensors_by_name.emplace_back(ggml_get_name(cur), cur); - } - } - if (ml.no_alloc) { return true; } @@ -7683,6 +8089,10 @@ size_t llama_model::n_devices() const { return devices.size(); } +const float * llama_model::tensor_split() const { + return params.tensor_split; +} + uint32_t llama_model::n_gpu_layers() const { return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1; } @@ -7801,114 +8211,114 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); size_t i = 0; - for (auto label : classifier_labels) { + for (const auto & label : classifier_labels) { LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); } } - } - if (arch == LLM_ARCH_MAMBA || - arch == LLM_ARCH_MAMBA2 || - arch == LLM_ARCH_JAMBA || - arch == LLM_ARCH_FALCON_H1 || - arch == LLM_ARCH_PLAMO2 || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_QWEN3NEXT || - arch == LLM_ARCH_QWEN35 || - arch == LLM_ARCH_QWEN35MOE || - arch == LLM_ARCH_NEMOTRON_H || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); - } + if (arch == LLM_ARCH_MAMBA || + arch == LLM_ARCH_MAMBA2 || + arch == LLM_ARCH_JAMBA || + arch == LLM_ARCH_FALCON_H1 || + arch == LLM_ARCH_PLAMO2 || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_QWEN3NEXT || + arch == LLM_ARCH_QWEN35 || + arch == LLM_ARCH_QWEN35MOE || + arch == LLM_ARCH_NEMOTRON_H || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + } - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); - if (pimpl->n_elements >= 1e12) { - LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); - } else if (pimpl->n_elements >= 1e9) { - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); - } else if (pimpl->n_elements >= 1e6) { - LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); - } else { - LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); - } + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); + if (pimpl->n_elements >= 1e12) { + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + } else if (pimpl->n_elements >= 1e9) { + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + } else if (pimpl->n_elements >= 1e6) { + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); + } else { + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); + } - // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); + // general kv + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); - if (arch == LLM_ARCH_DEEPSEEK) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - } + if (arch == LLM_ARCH_DEEPSEEK) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + } - if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); - LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); - LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); - LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - if (arch == LLM_ARCH_QWEN2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } + if (arch == LLM_ARCH_QWEN2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - } + if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } - if (arch == LLM_ARCH_MINICPM || - arch == LLM_ARCH_GRANITE || - arch == LLM_ARCH_GRANITE_MOE || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); - LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); - LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - if (arch == LLM_ARCH_BAILINGMOE) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - } + if (arch == LLM_ARCH_BAILINGMOE) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + } - if (arch == LLM_ARCH_BAILINGMOE2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); - } + if (arch == LLM_ARCH_BAILINGMOE2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + } - if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - if (arch == LLM_ARCH_GROVEMOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); - LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); - LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + if (arch == LLM_ARCH_GROVEMOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); + LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); + LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + } } vocab.print_info(); @@ -8105,7 +8515,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } else { llama_memory_i::layer_reuse_cb reuse = nullptr; - if (arch == LLM_ARCH_GEMMA3N) { + if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { reuse = [&](int32_t il) { if (il >= (int32_t) hparams.n_layer_kv_from_start) { return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); @@ -8168,9 +8578,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_LLAMA4: { if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique>(*this, params); + llm = std::make_unique>(*this, params); } else { - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); } } break; case LLM_ARCH_LLAMA_EMBED: @@ -8248,23 +8658,19 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_DREAM: { llm = std::make_unique(*this, params); - } - break; + } break; case LLM_ARCH_LLADA: { llm = std::make_unique(*this, params); - } - break; + } break; case LLM_ARCH_LLADA_MOE: { llm = std::make_unique(*this, params); - } - break; + } break; case LLM_ARCH_RND1: { llm = std::make_unique(*this, params); - } - break; + } break; case LLM_ARCH_QWEN2VL: { llm = std::make_unique(*this, params); @@ -8358,6 +8764,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_GEMMA4: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_GEMMA_EMBEDDING: { llm = std::make_unique(*this, params); @@ -8424,7 +8834,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK2OCR: case LLM_ARCH_GLM_DSA: + case LLM_ARCH_MISTRAL4: { llm = std::make_unique(*this, params); } break; @@ -8448,11 +8860,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { switch (params.gtype) { case LLM_GRAPH_TYPE_ENCODER: - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); break; case LLM_GRAPH_TYPE_DEFAULT: case LLM_GRAPH_TYPE_DECODER: - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); break; default: GGML_ABORT("invalid graph type"); @@ -8460,9 +8872,8 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_T5ENCODER: { - llm = std::make_unique(*this, params); - } - break; + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_JAIS: { llm = std::make_unique(*this, params); @@ -8574,6 +8985,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_HUNYUAN_VL: case LLM_ARCH_HUNYUAN_DENSE: { llm = std::make_unique(*this, params); @@ -8823,6 +9235,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK2OCR: case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GRANITE: @@ -8836,6 +9249,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_MISTRAL4: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: case LLM_ARCH_GLM_DSA: @@ -8874,6 +9288,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: + case LLM_ARCH_GEMMA4: case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: @@ -8920,6 +9335,9 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GLM4_MOE: return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX; + case LLM_ARCH_HUNYUAN_VL: + return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX; + // all model arches should be listed explicitly here case LLM_ARCH_UNKNOWN: GGML_ABORT("unknown architecture"); @@ -9054,3 +9472,18 @@ bool llama_model_is_diffusion(const llama_model * model) { const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; } + +int32_t llama_model_n_expert(const struct llama_model * model) { + return model->hparams.n_expert; +} + +int32_t llama_model_n_devices(const struct llama_model * model) { + return (int32_t)model->devices.size(); +} + +ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i) { + if (i < 0 || i >= (int)model->devices.size()) { + return nullptr; + } + return model->devices[i].dev; +} diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 25bf892e7e2..5f101bd6374 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -84,6 +84,7 @@ enum llm_type { LLM_TYPE_26B, LLM_TYPE_27B, LLM_TYPE_30B, + LLM_TYPE_31B, LLM_TYPE_32B, LLM_TYPE_34B, LLM_TYPE_35B, @@ -118,6 +119,7 @@ enum llm_type { LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_24B_A2B, // lfm2moe + LLM_TYPE_26B_A4B, // Gemma4 LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, LLM_TYPE_35B_A3B, // Qwen3.5 @@ -244,6 +246,8 @@ struct llama_layer { struct ggml_tensor * wkv_b = nullptr; struct ggml_tensor * wk_b = nullptr; struct ggml_tensor * wv_b = nullptr; + struct ggml_tensor * wqkv_b = nullptr; + struct ggml_tensor * wo_b = nullptr; struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; struct ggml_tensor * wv_cross = nullptr; @@ -254,13 +258,6 @@ struct llama_layer { struct ggml_tensor * wo_enc = nullptr; struct ggml_tensor * wqkv_gate = nullptr; - // attention bias - struct ggml_tensor * bq = nullptr; - struct ggml_tensor * bk = nullptr; - struct ggml_tensor * bv = nullptr; - struct ggml_tensor * bo = nullptr; - struct ggml_tensor * bqkv = nullptr; - // relative position bias struct ggml_tensor * attn_rel_b = nullptr; struct ggml_tensor * attn_rel_b_enc = nullptr; @@ -270,6 +267,9 @@ struct llama_layer { struct ggml_tensor * ffn_norm = nullptr; struct ggml_tensor * ffn_norm_b = nullptr; struct ggml_tensor * ffn_post_norm = nullptr; + struct ggml_tensor * ffn_post_norm_1 = nullptr; // gemma4 + struct ggml_tensor * ffn_post_norm_2 = nullptr; // gemma4 + struct ggml_tensor * ffn_pre_norm_2 = nullptr; // gemma4 struct ggml_tensor * layer_out_norm = nullptr; struct ggml_tensor * layer_out_norm_b = nullptr; struct ggml_tensor * ffn_norm_exps = nullptr; @@ -285,6 +285,7 @@ struct llama_layer { // ff MoE struct ggml_tensor * ffn_gate_inp = nullptr; + struct ggml_tensor * ffn_gate_inp_s = nullptr; // gemma4 struct ggml_tensor * ffn_gate_exps = nullptr; struct ggml_tensor * ffn_down_exps = nullptr; struct ggml_tensor * ffn_up_exps = nullptr; @@ -409,10 +410,32 @@ struct llama_layer { struct ggml_tensor * ffn_gate_shexp_s = nullptr; struct ggml_tensor * ffn_up_shexp_s = nullptr; struct ggml_tensor * ffn_down_shexp_s = nullptr; - struct ggml_tensor * ssm_out_s = nullptr; + struct ggml_tensor * ssm_in_s = nullptr; + struct ggml_tensor * ssm_out_s = nullptr; struct ggml_tensor * ssm_alpha_s = nullptr; struct ggml_tensor * ssm_beta_s = nullptr; + // input scales + struct ggml_tensor * wq_in_s = nullptr; + struct ggml_tensor * wk_in_s = nullptr; + struct ggml_tensor * wv_in_s = nullptr; + struct ggml_tensor * wo_in_s = nullptr; + struct ggml_tensor * wqkv_in_s = nullptr; + struct ggml_tensor * wqkv_gate_in_s = nullptr; + struct ggml_tensor * ffn_gate_in_s = nullptr; + struct ggml_tensor * ffn_up_in_s = nullptr; + struct ggml_tensor * ffn_down_in_s = nullptr; + struct ggml_tensor * ffn_gate_exps_in_s = nullptr; + struct ggml_tensor * ffn_down_exps_in_s = nullptr; + struct ggml_tensor * ffn_up_exps_in_s = nullptr; + struct ggml_tensor * ffn_gate_shexp_in_s= nullptr; + struct ggml_tensor * ffn_up_shexp_in_s = nullptr; + struct ggml_tensor * ffn_down_shexp_in_s= nullptr; + struct ggml_tensor * ssm_in_in_s = nullptr; + struct ggml_tensor * ssm_out_in_s = nullptr; + struct ggml_tensor * ssm_alpha_in_s = nullptr; + struct ggml_tensor * ssm_beta_in_s = nullptr; + // altup & laurel struct ggml_tensor * per_layer_inp_gate = nullptr; struct ggml_tensor * per_layer_proj = nullptr; @@ -461,6 +484,9 @@ struct llama_layer { struct ggml_tensor * indexer_attn_k = nullptr; struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias + // gemma4 layer output scale + struct ggml_tensor * out_scale = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; @@ -470,6 +496,19 @@ struct llama_layer { struct llama_layer_nextn nextn; }; +struct llama_device { + bool is_meta; + + ggml_backend_dev_t dev; +}; + +struct llama_meta_device_get_split_state_userdata { + size_t n_devices; + const struct llama_model * model; +}; + +struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata); + struct llama_model { llm_type type = LLM_TYPE_UNKNOWN; llm_arch arch = LLM_ARCH_UNKNOWN; @@ -505,9 +544,9 @@ struct llama_model { struct ggml_tensor * conv1d_b = nullptr; // gemma3n altup - struct ggml_tensor * tok_embd_per_layer = nullptr; struct ggml_tensor * altup_proj = nullptr; struct ggml_tensor * altup_unembd_proj = nullptr; + struct ggml_tensor * per_layer_tok_embd = nullptr; struct ggml_tensor * per_layer_model_proj = nullptr; struct ggml_tensor * per_layer_proj_norm = nullptr; @@ -524,7 +563,7 @@ struct llama_model { std::unordered_map gguf_kv; // list of devices used in this model - std::vector devices; + std::vector devices; // for quantize-stats only std::vector> tensors_by_name; @@ -532,6 +571,9 @@ struct llama_model { // for keeping track of associated LoRA adapters std::unordered_set loras; + // statically allocated context for assigning + struct llama_meta_device_get_split_state_userdata get_split_state_ud; + int64_t t_load_us = 0; int64_t t_start_us = 0; @@ -552,6 +594,7 @@ struct llama_model { size_t size() const; // file size size_t n_tensors() const; size_t n_devices() const; + const float * tensor_split() const; uint32_t n_gpu_layers() const; llama_split_mode split_mode() const; diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 8e8ce231249..2f0f70b73b6 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -1,11 +1,11 @@ -#include "llama.h" #include "llama-impl.h" #include "llama-model.h" #include "llama-model-loader.h" +#include "llama-ext.h" +#include #include #include -#include #include #include #include @@ -84,7 +84,6 @@ static std::string remap_imatrix(const std::string & orig_name, const std::maptensor_types) { - const auto & tensor_types = *static_cast *>(params->tensor_types); - for (const auto & [tname, qtype] : tensor_types) { - tensor_type_patterns.emplace_back(std::regex(tname), qtype); + if (params->tt_overrides) { + for (const auto * p = params->tt_overrides; p->pattern != nullptr; p++) { + tensor_type_patterns.emplace_back(std::regex(p->pattern), p->type); } } } @@ -199,6 +197,7 @@ struct quantize_state_impl { // per-tensor metadata, computed in the preliminary loop and used in the main loop struct tensor_metadata { + std::string name; ggml_type target_type; tensor_category category; std::string remapped_imatrix_name; @@ -344,7 +343,13 @@ static bool tensor_allows_quantization(const llama_model_quantize_params * param quantize &= name.find("attn_rel_b.weight") == std::string::npos; // do not quantize specific multimodal tensors - quantize &= name.find(".position_embd.") == std::string::npos; + quantize &= name.find(".position_embd") == std::string::npos; + quantize &= name.find("sam.pos_embd") == std::string::npos; + quantize &= name.find("sam.neck.") == std::string::npos; + quantize &= name.find("sam.net_") == std::string::npos; + quantize &= name.find(".rel_pos") == std::string::npos; + quantize &= name.find(".patch_embd") == std::string::npos; + quantize &= name.find(".patch_merger") == std::string::npos; return quantize; } @@ -678,9 +683,9 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_mod LLAMA_LOG_WARN("%s: %-36s - applying manual override: %s -> %s\n", __func__, tensor_name.c_str(), ggml_type_name(new_type), ggml_type_name(qtype)); new_type = qtype; - manual = true; - break; } + manual = true; + break; } } } @@ -784,7 +789,7 @@ static bool tensor_requires_imatrix(const char * tensor_name, const ggml_type ds // given a file type, get the default tensor type // -static ggml_type llama_ftype_get_default_type(llama_ftype ftype) { +ggml_type llama_ftype_get_default_type(llama_ftype ftype) { switch (ftype) { case LLAMA_FTYPE_MOSTLY_Q4_0: return GGML_TYPE_Q4_0; case LLAMA_FTYPE_MOSTLY_Q4_1: return GGML_TYPE_Q4_1; @@ -794,6 +799,7 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_F16: return GGML_TYPE_F16; case LLAMA_FTYPE_MOSTLY_BF16: return GGML_TYPE_BF16; case LLAMA_FTYPE_ALL_F32: return GGML_TYPE_F32; + case LLAMA_FTYPE_MOSTLY_Q1_0: return GGML_TYPE_Q1_0; case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4; @@ -823,16 +829,32 @@ static ggml_type llama_ftype_get_default_type(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ3_S: case LLAMA_FTYPE_MOSTLY_IQ3_M: return GGML_TYPE_IQ3_S; - default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); + default: return GGML_TYPE_COUNT; } } + +static void init_quantize_state_counters(quantize_state_impl & qs, std::vector & metadata) { + for (auto & tm : metadata) { + tensor_category cat = tensor_get_category(tm.name); + tm.category = cat; + + if (category_is_attn_v(cat)) { + ++qs.n_attention_wv; + } + + if (cat == tensor_category::OUTPUT) { + qs.has_tied_embeddings = false; + } + } + qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer; +} + // // main quantization driver // static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { - ggml_type default_type; llama_ftype ftype = params->ftype; int nthread = params->nthread; @@ -841,7 +863,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: nthread = std::thread::hardware_concurrency(); } - default_type = llama_ftype_get_default_type(ftype); + ggml_type default_type = llama_ftype_get_default_type(ftype); + if (default_type == GGML_TYPE_COUNT) { + throw std::runtime_error(format("invalid output file type %d\n", ftype)); + } // mmap consistently increases speed on Linux, and also increases speed on Windows with // hot cache. It may cause a slowdown on macOS, possibly related to free memory. @@ -851,15 +876,10 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: constexpr bool use_mmap = false; #endif - llama_model_kv_override * kv_overrides = nullptr; - if (params->kv_overrides) { - auto * v = (std::vector*)params->kv_overrides; - kv_overrides = v->data(); - } - + const llama_model_kv_override * kv_overrides = params->kv_overrides; std::vector splits = {}; llama_model_loader ml(/*metadata*/ nullptr, /*set_tensor_data*/ nullptr, /*set_tensor_data_ud*/ nullptr, - fname_inp, splits, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); + fname_inp, splits, /*file*/ nullptr, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching llama_model model(llama_model_default_params()); @@ -873,9 +893,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (params->only_copy) { ftype = ml.ftype; } + std::unordered_map> i_data; const std::unordered_map> * imatrix_data = nullptr; if (params->imatrix) { - imatrix_data = static_cast>*>(params->imatrix); + for (const llama_model_imatrix_data * p = params->imatrix; p->name != nullptr; p++) { + i_data.emplace(p->name, std::vector(p->data, p->data + p->size)); + } + imatrix_data = & i_data; if (imatrix_data) { LLAMA_LOG_INFO("\n%s: have importance matrix data with %d entries\n", __func__, (int)imatrix_data->size()); @@ -896,7 +920,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: std::vector prune_list = {}; if (params->prune_layers) { - prune_list = *static_cast *>(params->prune_layers); + for (const int32_t * p = params->prune_layers; * p != -1; p++) { + prune_list.push_back(* p); + } } // copy the KV pairs from the input file @@ -910,20 +936,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str()); if (params->kv_overrides) { - const std::vector & overrides = *(const std::vector *)params->kv_overrides; - for (const auto & o : overrides) { - if (o.key[0] == 0) break; - if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { - gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64); - } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) { + for (const llama_model_kv_override * o = params->kv_overrides; o->key[0] != 0; ++o) { + if (o->tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { + gguf_set_val_f32(ctx_out.get(), o->key, o->val_f64); + } else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_INT) { // Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context - gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)std::abs(o.val_i64)); - } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { - gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool); - } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) { - gguf_set_val_str(ctx_out.get(), o.key, o.val_str); + gguf_set_val_u32(ctx_out.get(), o->key, (uint32_t)std::abs(o->val_i64)); + } else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { + gguf_set_val_bool(ctx_out.get(), o->key, o->val_bool); + } else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_STR) { + gguf_set_val_str(ctx_out.get(), o->key, o->val_str); } else { - LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key); + LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o->key); } } } @@ -961,6 +985,15 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: }); } + // compute tensor metadata once and cache it + std::vector metadata(tensors.size()); + for (size_t i = 0; i < tensors.size(); ++i) { + metadata[i].name = ggml_get_name(tensors[i]->tensor); + } + + // initialize quantization state counters and metadata categories + init_quantize_state_counters(qs, metadata); + int idx = 0; uint16_t n_split = 1; @@ -973,25 +1006,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: std::vector ctx_outs(n_split); ctx_outs[0] = std::move(ctx_out); - // compute tensor metadata once and cache it - std::vector metadata(tensors.size()); - - // initialize quantization state before preliminary loop (counters for use_more_bits) - { - for (size_t i = 0; i < tensors.size(); ++i) { - const auto cat = tensor_get_category(tensors[i]->tensor->name); - if (category_is_attn_v(cat)) { - ++qs.n_attention_wv; - } - if (cat == tensor_category::OUTPUT) { - qs.has_tied_embeddings = false; - } - metadata[i].category = cat; // save and re-use the category while we're at it - } - // these also need to be set to n_layer by default - qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer; - } - // flag for --dry-run bool will_require_imatrix = false; @@ -1002,7 +1016,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: for (size_t i = 0; i < tensors.size(); ++i) { const auto * it = tensors[i]; const struct ggml_tensor * tensor = it->tensor; - const std::string name = ggml_get_name(tensor); uint16_t i_split = params->keep_split ? it->idx : 0; if (!ctx_outs[i_split]) { @@ -1031,7 +1044,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: " - offending tensor: %s\n" " - target type: %s\n" "============================================================================\n\n", - name.c_str(), ggml_type_name(metadata[i].target_type)); + metadata[i].name.c_str(), ggml_type_name(metadata[i].target_type)); throw std::runtime_error("this quantization requires an imatrix!"); } } @@ -1104,7 +1117,6 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: new_ofstream(weight.idx); } - const std::string name = ggml_get_name(tensor); const size_t tensor_size = ggml_nbytes(tensor); if (!params->dry_run) { @@ -1235,9 +1247,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: total_size_new += new_size; // update the gguf meta data as we go - gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); - GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); - gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data); + gguf_set_tensor_type(ctx_outs[cur_split].get(), metadata[i].name.c_str(), new_type); + GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), metadata[i].name.c_str())) == new_size); + gguf_set_tensor_data(ctx_outs[cur_split].get(), metadata[i].name.c_str(), new_data); // write tensor data + padding fout.write((const char *) new_data, new_size); @@ -1271,7 +1283,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: llama_model_quantize_params llama_model_quantize_default_params() { llama_model_quantize_params result = { /*.nthread =*/ 0, - /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, + /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q8_0, /*.output_tensor_type =*/ GGML_TYPE_COUNT, /*.token_embedding_type =*/ GGML_TYPE_COUNT, /*.allow_requantize =*/ false, @@ -1302,3 +1314,89 @@ uint32_t llama_model_quantize( return 0; } + +// +// Helper functions for external tools exposed in llama-ext.h +// + +quantize_state_impl * llama_quant_init( + const llama_model * model, + const llama_model_quantize_params * params) { + return new quantize_state_impl(*model, params); +} + +void llama_quant_free(quantize_state_impl * qs) { + delete qs; +} + +llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc) { + struct llama_model_params mparams = llama_model_default_params(); + auto * model = new llama_model(mparams); + + model->arch = llm_arch_from_string(desc->architecture); + + // infer llm_type: only LLM_TYPE_70B matters for quantization logic + if (model->arch == LLM_ARCH_LLAMA && desc->n_layer == 80 && desc->n_head != desc->n_head_kv) { + model->type = LLM_TYPE_70B; + } + + model->hparams.n_embd = desc->n_embd; + model->hparams.n_embd_head_k_full = desc->n_embd_head_k; + model->hparams.n_embd_head_v_full = desc->n_embd_head_v; + model->hparams.n_layer = desc->n_layer; + model->hparams.n_expert = desc->n_expert; + + for (uint32_t i = 0; i < desc->n_layer; i++) { + model->hparams.n_head_arr[i] = desc->n_head; + model->hparams.n_head_kv_arr[i] = desc->n_head_kv; + model->hparams.n_ff_arr[i] = desc->n_ff; + } + + return model; +} + +bool llama_quant_tensor_allows_quantization( + const quantize_state_impl * qs, + const ggml_tensor * tensor) { + return tensor_allows_quantization(qs->params, qs->model.arch, tensor); +} + +void llama_quant_compute_types( + quantize_state_impl * qs, + llama_ftype ftype, + ggml_tensor ** tensors, + ggml_type * result_types, + size_t n_tensors) { + // reset per-computation state + qs->n_attention_wv = 0; + qs->n_ffn_down = 0; + qs->n_ffn_gate = 0; + qs->n_ffn_up = 0; + qs->i_attention_wv = 0; + qs->i_ffn_down = 0; + qs->i_ffn_gate = 0; + qs->i_ffn_up = 0; + qs->n_fallback = 0; + qs->has_imatrix = false; + qs->has_tied_embeddings = true; + + // build metadata from tensor names + std::vector metadata(n_tensors); + for (size_t i = 0; i < n_tensors; i++) { + metadata[i].name = ggml_get_name(tensors[i]); + } + + // initialize counters and categories + init_quantize_state_counters(*qs, metadata); + + // use a local copy of params with the requested ftype + llama_model_quantize_params local_params = *qs->params; + local_params.ftype = ftype; + + ggml_type default_type = llama_ftype_get_default_type(ftype); + + // compute types + for (size_t i = 0; i < n_tensors; i++) { + result_types[i] = llama_tensor_get_type(*qs, &local_params, tensors[i], default_type, metadata[i]); + } +} diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 68ba292d426..163f222ef61 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -493,6 +493,16 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_GEMMA4: + // Gemma4 uses SPM-style BPE: spaces are replaced with ▁ by the + // normalizer, then BPE merges run on the whole text without + // word-level pre-splitting. We only need to split on newlines + // since BPE merge lookup asserts no newlines in tokens. + regex_exprs = { + "[^\\n]+|[\\n]+", + }; + byte_encode = false; // uses raw UTF-8, not GPT-2 byte encoding + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -506,6 +516,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { } std::vector regex_exprs; + bool byte_encode = true; // GPT-2 byte encoding; false for SPM-style BPE (raw UTF-8) }; struct llm_tokenizer_bpe_session { @@ -550,9 +561,10 @@ struct llm_tokenizer_bpe_session { void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; - const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs); + const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs, tokenizer.byte_encode); symbols_final.clear(); + auto tok_pre = vocab.get_pre_type(); for (const auto & word : word_collection) { work_queue = llm_bigram_bpe::queue(); @@ -565,6 +577,13 @@ struct llm_tokenizer_bpe_session { if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) { symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); offset = word.size(); + } else if (tok_pre == LLAMA_VOCAB_PRE_TYPE_GEMMA4 && word.find_first_not_of('\n') == std::string::npos) { + // fix for gemma 4, ref: https://github.com/ggml-org/llama.cpp/pull/21343 + auto tok = vocab.text_to_token(word); + if (tok != LLAMA_TOKEN_NULL) { + symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); + offset = word.size(); + } } while (offset < word.size()) { @@ -640,8 +659,17 @@ struct llm_tokenizer_bpe_session { if (token == LLAMA_TOKEN_NULL) { for (auto j = str.begin(); j != str.end(); ++j) { - std::string byte_str(1, *j); - auto token_multibyte = vocab.text_to_token(byte_str); + llama_token token_multibyte = LLAMA_TOKEN_NULL; + if (tokenizer.byte_encode) { + std::string byte_str(1, *j); + token_multibyte = vocab.text_to_token(byte_str); + } else { + // For non-byte-encoded BPE (e.g. gemma-4), byte tokens use <0xXX> format + static const char * hex = "0123456789ABCDEF"; + const uint8_t ch = (uint8_t)*j; + const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 }; + token_multibyte = vocab.text_to_token(buf); + } if (token_multibyte != LLAMA_TOKEN_NULL) { output.push_back(token_multibyte); } @@ -1863,6 +1891,42 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_sep_id = LLAMA_TOKEN_NULL; special_pad_id = 3; // <|plamo:pad|> special_mask_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "gemma4") { + type = LLAMA_VOCAB_TYPE_BPE; + + // read bpe merges and populate bpe ranks + const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + if (merges_keyidx == -1) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + { + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + bpe_ranks.emplace(std::make_pair(first, second), i); + } + } + + // default special tokens (to be read from GGUF) + special_bos_id = LLAMA_TOKEN_NULL; + special_eos_id = LLAMA_TOKEN_NULL; + special_unk_id = LLAMA_TOKEN_NULL; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + special_mask_id = LLAMA_TOKEN_NULL; + + tokenizer_pre = "gemma4"; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -1870,6 +1934,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // for now, only BPE models have pre-tokenizers if (type == LLAMA_VOCAB_TYPE_BPE) { add_space_prefix = false; + escape_whitespaces = false; clean_spaces = true; if (tokenizer_pre.empty()) { LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); @@ -1936,6 +2001,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "jais-2") { pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2; + } else if ( + tokenizer_pre == "gemma4") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4; + escape_whitespaces = true; } else if ( tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-code" || @@ -1952,7 +2021,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "qwen2" || tokenizer_pre == "deepseek-r1-qwen" || - tokenizer_pre == "kormo") { + tokenizer_pre == "kormo" || + tokenizer_pre == "f2llmv2") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; } else if ( @@ -2129,19 +2199,28 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { throw std::runtime_error("cannot find tokenizer vocab in model file\n"); } + const uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx); + const float * scores = nullptr; const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); if (score_idx != -1) { + const uint32_t n_scores = gguf_get_arr_n(ctx, score_idx); + if (n_scores < n_tokens) { + throw std::runtime_error("Index out of array bounds for scores (" + std::to_string(n_scores) + " < " + std::to_string(n_tokens) + ")\n"); + } scores = (const float * ) gguf_get_arr_data(ctx, score_idx); } const int * toktypes = nullptr; const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); if (toktype_idx != -1) { + const uint32_t n_toktypes = gguf_get_arr_n(ctx, toktype_idx); + if (n_toktypes < n_tokens) { + throw std::runtime_error("Index out of array bounds for toktypes (" + std::to_string(n_toktypes) + " < " + std::to_string(n_tokens) + ")\n"); + } toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); } - uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx); id_to_token.resize(n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { @@ -2255,6 +2334,14 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { if (ml.get_key(LLM_KV_TOKENIZER_ADD_SEP, temp, false)) { add_sep = temp; } + + // workaround for Gemma 4 + // ref: https://github.com/ggml-org/llama.cpp/pull/21500 + if (pre_type == LLAMA_VOCAB_PRE_TYPE_GEMMA4 && !add_bos) { + add_bos = true; + + LLAMA_LOG_WARN("%s: override '%s' to 'true' for Gemma4\n", __func__, kv(LLM_KV_TOKENIZER_ADD_BOS).c_str()); + } } // auto-detect special tokens by text @@ -2480,6 +2567,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "[EOS]" // Kimi-K2 || t.first == "<|end_of_text|>" || t.first == "" // smoldocling + || t.first == "" // gemma4 + || t.first == "" // gemma4 + || t.first == "<|tool_response>" // gemma4 + || t.first == "<|end▁of▁sentence|>" // deepseek-ocr ) { special_eog_ids.insert(t.second); if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { @@ -2564,6 +2655,33 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__); } } + + // workaround for gemma4 and paddleocr: do not include as an eog token + { + bool has_tool_response = false; + bool has_s = false; + + llama_token s_id = LLAMA_TOKEN_NULL; + + for (auto tid : special_eog_ids) { + const auto & text = id_to_token[tid].text; + if (text == "<|tool_response>") { + has_tool_response = true; + } else if (text == "") { + has_s = true; + s_id = tid; + } + } + + if (has_tool_response && has_s) { + special_eog_ids.erase(s_id); + + auto & attr = id_to_token[s_id].attr; + attr = LLAMA_TOKEN_ATTR_NORMAL; + + LLAMA_LOG_WARN("%s: special_eog_ids contains '<|tool_response>', removing '' token from EOG list\n", __func__); + } + } } // build special tokens cache @@ -2732,7 +2850,9 @@ uint8_t llama_vocab::impl::token_to_byte(llama_token id) const { return strtol(buf.c_str(), NULL, 16); } case LLAMA_VOCAB_TYPE_BPE: { - GGML_ABORT("fatal error"); + // Gemma4 uses BPE with SPM-style byte fallback tokens (<0xXX>) + auto buf = token_data.text.substr(3, 2); + return strtol(buf.c_str(), NULL, 16); } case LLAMA_VOCAB_TYPE_WPM: { GGML_ABORT("fatal error"); @@ -3021,6 +3141,10 @@ std::vector llama_vocab::impl::tokenize( if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { std::string text = fragment.raw_text.substr(fragment.offset, fragment.length); + if (escape_whitespaces) { + llama_escape_whitespace(text); + } + #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); #endif @@ -3200,9 +3324,19 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t return _try_copy(token_text.data(), token_text.size()); } if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + if (escape_whitespaces) { + // SPM-style BPE: tokens contain ▁ for spaces + std::string result = token_text; + llama_unescape_whitespace(result); + return _try_copy(result.data(), result.size()); + } std::string result = llama_decode_text(token_text); return _try_copy(result.data(), result.size()); } + if (attr & LLAMA_TOKEN_ATTR_BYTE) { + char byte = (char) token_to_byte(token); + return _try_copy((char*) &byte, 1); + } break; } case LLAMA_VOCAB_TYPE_RWKV: { @@ -3630,9 +3764,7 @@ int llama_vocab::max_token_len() const { int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const { GGML_ASSERT(token_left.find(' ') == std::string::npos); - GGML_ASSERT(token_left.find('\n') == std::string::npos); GGML_ASSERT(token_right.find(' ') == std::string::npos); - GGML_ASSERT(token_right.find('\n') == std::string::npos); auto it = pimpl->bpe_ranks.find(std::make_pair(token_left, token_right)); if (it == pimpl->bpe_ranks.end()) { diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index be5b08012df..dd38f45d3a2 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -58,6 +58,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, + LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, }; struct LLM_KV; diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index 872e659edca..e9c3028585d 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -1,6 +1,5 @@ #include "llama.h" -#include "ggml-cpp.h" #include "llama-impl.h" #include "llama-chat.h" @@ -12,6 +11,7 @@ #include "llama-model.h" #include "ggml.h" +#include "ggml-cpp.h" #include "ggml-backend.h" #include "gguf.h" @@ -24,6 +24,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -45,722 +46,6 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_ty GGML_ABORT("fatal error"); } -struct llama_device_memory_data { - int64_t total; - int64_t free; - llama_memory_breakdown_data mb; -}; - -static std::vector llama_get_device_memory_data( - const char * path_model, const llama_model_params * mparams, const llama_context_params * cparams, - std::vector & devs, uint32_t & hp_ngl, uint32_t & hp_n_ctx_train, uint32_t & hp_n_expert, - const ggml_log_level log_level) { - struct user_data_t { - struct { - ggml_log_callback callback; - void * user_data; - } original_logger; - ggml_log_level min_level; // prints below this log level go to debug log - }; - user_data_t ud; - llama_log_get(&ud.original_logger.callback, &ud.original_logger.user_data); - ud.min_level = log_level; - - llama_log_set([](ggml_log_level level, const char * text, void * user_data) { - const user_data_t * ud = (const user_data_t *) user_data; - const ggml_log_level level_eff = level >= ud->min_level ? level : GGML_LOG_LEVEL_DEBUG; - ud->original_logger.callback(level_eff, text, ud->original_logger.user_data); - }, &ud); - - llama_model_params mparams_copy = *mparams; - mparams_copy.no_alloc = true; - mparams_copy.use_mmap = false; - mparams_copy.use_mlock = false; - - llama_model * model = llama_model_load_from_file(path_model, mparams_copy); - if (model == nullptr) { - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - throw std::runtime_error("failed to load model"); - } - - llama_context * ctx = llama_init_from_model(model, *cparams); - if (ctx == nullptr) { - llama_model_free(model); - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - throw std::runtime_error("failed to create llama_context from model"); - } - - std::vector ret(model->devices.size()); - - std::map memory_breakdown = ctx->memory_breakdown(); - - for (const auto & [buft, mb] : memory_breakdown) { - if (ggml_backend_buft_is_host(buft)) { - continue; - } - - ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); - if (!dev) { - continue; - } - for (size_t i = 0; i < ret.size(); i++) { - if (model->devices[i] == dev) { - ret[i].mb.model += mb.model; - ret[i].mb.context += mb.context; - ret[i].mb.compute += mb.compute; - break; - } - } - } - for (size_t i = 0; i < ret.size(); i++) { - size_t free; - size_t total; - ggml_backend_dev_memory(model->devices[i], &free, &total); - - // devices can return 0 bytes for free and total memory if they do not - // have any to report. in this case, we will use the host memory as a fallback - // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 - if (free == 0 && total == 0) { - ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpu_dev == nullptr) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } - ggml_backend_dev_memory(cpu_dev, &free, &total); - } - ret[i].free = free; - ret[i].total = total; - } - - devs = model->devices; - hp_ngl = model->hparams.n_layer; - hp_n_ctx_train = model->hparams.n_ctx_train; - hp_n_expert = model->hparams.n_expert; - - llama_memory_breakdown_print(ctx); // goes to debug log - - llama_free(ctx); - llama_model_free(model); - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - return ret; -} - -// enum to identify part of a layer for distributing its tensors: -enum layer_fraction_t { - LAYER_FRACTION_NONE = 0, // nothing - LAYER_FRACTION_ATTN = 1, // attention - LAYER_FRACTION_UP = 2, // attention + up - LAYER_FRACTION_GATE = 3, // attention + up + gate - LAYER_FRACTION_MOE = 4, // everything but sparse MoE weights -}; -// this enum is only used in llama_params_fit_impl but needs to be defined outside of it to fix a Windows compilation issue - -class llama_params_fit_exception : public std::runtime_error { - using std::runtime_error::runtime_error; -}; - -static void llama_params_fit_impl( - const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, - float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t * margins_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { - constexpr int64_t MiB = 1024*1024; - typedef std::vector dmds_t; - const llama_model_params default_mparams = llama_model_default_params(); - - std::vector devs; - uint32_t hp_ngl = 0; // hparams.n_gpu_layers - uint32_t hp_nct = 0; // hparams.n_ctx_train - uint32_t hp_nex = 0; // hparams.n_expert - - // step 1: get data for default parameters and check whether any changes are necessary in the first place - - LLAMA_LOG_DEBUG("%s: getting device memory data for initial parameters:\n", __func__); - const dmds_t dmds_full = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - const size_t nd = devs.size(); // number of devices - if (nd == 0) { - LLAMA_LOG_INFO("%s: no devices with dedicated memory found\n", __func__); - return; - } - - std::vector margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits - margins.reserve(nd); - for (size_t id = 0; id < nd; id++) { - margins.push_back(margins_s[id]); - } - - std::vector dev_names; - { - dev_names.reserve(nd); - size_t max_length = 0; - for (ggml_backend_dev_t dev : devs) { - std::string name = ggml_backend_dev_name(dev); - name += " ("; - name += ggml_backend_dev_description(dev); - name += ")"; - dev_names.push_back(name); - max_length = std::max(max_length, name.length()); - } - for (std::string & dn : dev_names) { - dn.insert(dn.end(), max_length - dn.length(), ' '); - } - } - - int64_t sum_free = 0; - int64_t sum_projected_free = 0; - int64_t sum_projected_used = 0; - int64_t sum_projected_model = 0; - std::vector projected_free_per_device; - projected_free_per_device.reserve(nd); - - if (nd > 1) { - LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__); - } - for (size_t id = 0; id < nd; id++) { - const llama_device_memory_data & dmd = dmds_full[id]; - - const int64_t projected_used = dmd.mb.total(); - const int64_t projected_free = dmd.free - projected_used; - projected_free_per_device.push_back(projected_free); - - sum_free += dmd.free; - sum_projected_used += projected_used; - sum_projected_free += projected_free; - sum_projected_model += dmd.mb.model; - - if (nd > 1) { - LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n", - __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB); - } - } - assert(sum_free >= 0 && sum_projected_used >= 0); - LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n", - __func__, sum_projected_used/MiB, sum_free/MiB); - if (nd == 1) { - if (projected_free_per_device[0] >= margins[0]) { - LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n", - __func__, projected_free_per_device[0]/MiB, margins[0]/MiB); - return; - } - } else { - bool changes_needed = false; - for (size_t id = 0; id < nd; id++) { - if (projected_free_per_device[id] < margins[id]) { - changes_needed = true; - break; - } - } - if (!changes_needed) { - LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__); - return; - } - } - - // step 2: try reducing memory use by reducing the context size - - { - int64_t global_surplus = sum_projected_free; - for (size_t id = 0; id < nd; id++) { - global_surplus -= margins[id]; - } - if (global_surplus < 0) { - if (nd == 1) { - LLAMA_LOG_INFO("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n", - __func__, margins[0]/MiB, -global_surplus/MiB); - } else { - LLAMA_LOG_INFO( - "%s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n", - __func__, -global_surplus/MiB); - } - if (cparams->n_ctx == 0) { - if (hp_nct > n_ctx_min) { - int64_t sum_used_target = sum_free; - for (size_t id = 0; id < nd; id++) { - sum_used_target -= margins[id]; - } - if (nd > 1) { - // for multiple devices we need to be more conservative in terms of how much context we think can fit: - // - for dense models only whole layers can be assigned to devices - // - for MoE models only whole tensors can be assigned to devices, which we estimate to be <= 1/3 of a layer - // - on average we expect a waste of 0.5 layers/tensors per device - // - use slightly more than the expected average for nd devices to be safe - const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl); - sum_used_target -= (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6); - } - - int64_t sum_projected_used_min_ctx = 0; - cparams->n_ctx = n_ctx_min; - const dmds_t dmds_min_ctx = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - for (const auto & dmd : dmds_min_ctx) { - sum_projected_used_min_ctx += dmd.mb.total(); - } - if (sum_used_target > sum_projected_used_min_ctx) { - // linear interpolation between minimum and maximum context size: - cparams->n_ctx += (hp_nct - n_ctx_min) * (sum_used_target - sum_projected_used_min_ctx) - / (sum_projected_used - sum_projected_used_min_ctx); - cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend - - const int64_t bytes_per_ctx = (sum_projected_used - sum_projected_used_min_ctx) / (hp_nct - n_ctx_min); - const int64_t memory_reduction = (hp_nct - cparams->n_ctx) * bytes_per_ctx; - LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", - __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); - if (nd == 1) { - LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__); - return; - } - LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__); - } else { - const int64_t memory_reduction = sum_projected_used - sum_projected_used_min_ctx; - LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", - __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); - } - } else { - if (n_ctx_min == UINT32_MAX) { - LLAMA_LOG_INFO("%s: user has requested full context size of %" PRIu32 " -> no change\n", __func__, hp_nct); - } else { - LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n", - __func__, hp_nct, n_ctx_min); - } - } - } else { - LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx); - } - } - } - - if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) { - throw llama_params_fit_exception("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort"); - } - if (nd > 1) { - if (!tensor_split) { - throw llama_params_fit_exception("did not provide a buffer to write the tensor_split to, abort"); - } - if (mparams->tensor_split) { - for (size_t id = 0; id < nd; id++) { - if (mparams->tensor_split[id] != 0.0f) { - throw llama_params_fit_exception("model_params::tensor_split already set by user, abort"); - } - } - } - if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) { - throw llama_params_fit_exception("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort"); - } - } - if (!tensor_buft_overrides) { - throw llama_params_fit_exception("did not provide buffer to set tensor_buft_overrides, abort"); - } - if (mparams->tensor_buft_overrides && (mparams->tensor_buft_overrides->pattern || mparams->tensor_buft_overrides->buft)) { - throw llama_params_fit_exception("model_params::tensor_buft_overrides already set by user, abort"); - } - - // step 3: iteratively fill the back to front with "dense" layers - // - for a dense model simply fill full layers, giving each device a contiguous slice of the model - // - for a MoE model, same as dense model but with all MoE tensors in system memory - - // utility function that returns a static C string matching the tensors for a specific layer index and layer fraction: - auto get_overflow_pattern = [&](const size_t il, const layer_fraction_t lf) -> const char * { - constexpr size_t n_strings = 1000; - if (il >= n_strings) { - throw std::runtime_error("at most " + std::to_string(n_strings) + " model layers are supported"); - } - switch (lf) { - case LAYER_FRACTION_ATTN: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|gate|down).*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_UP: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|down).*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_GATE: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_down.*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_MOE: { - static std::array patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|down|gate)_(ch|)exps"; - } - return patterns[il].c_str(); - } - default: - GGML_ABORT("fatal error"); - } - }; - - struct ngl_t { - uint32_t n_layer = 0; // number of total layers - uint32_t n_part = 0; // number of partial layers, <= n_layer - - // for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE: - layer_fraction_t overflow_type = LAYER_FRACTION_MOE; - - uint32_t n_full() const { - assert(n_layer >= n_part); - return n_layer - n_part; - } - }; - - const size_t ntbo = llama_max_tensor_buft_overrides(); - - // utility function to set n_gpu_layers and tensor_split - auto set_ngl_tensor_split_tbo = [&]( - const std::vector & ngl_per_device, - const std::vector & overflow_bufts, - llama_model_params & mparams) { - mparams.n_gpu_layers = 0; - for (size_t id = 0; id < nd; id++) { - mparams.n_gpu_layers += ngl_per_device[id].n_layer; - if (nd > 1) { - tensor_split[id] = ngl_per_device[id].n_layer; - } - } - assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl + 1); - uint32_t il0 = hp_ngl + 1 - mparams.n_gpu_layers; // start index for tensor buft overrides - - mparams.tensor_split = tensor_split; - - size_t itbo = 0; - for (size_t id = 0; id < nd; id++) { - il0 += ngl_per_device[id].n_full(); - for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) { - if (itbo + 1 >= ntbo) { - tensor_buft_overrides[itbo].pattern = nullptr; - tensor_buft_overrides[itbo].buft = nullptr; - itbo++; - mparams.tensor_buft_overrides = tensor_buft_overrides; - throw llama_params_fit_exception("llama_max_tensor_buft_overrides() == " - + std::to_string(ntbo) + " is insufficient for model"); - } - tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE); - tensor_buft_overrides[itbo].buft = il == il0 ? overflow_bufts[id] : ggml_backend_cpu_buffer_type(); - itbo++; - } - il0 += ngl_per_device[id].n_part; - } - tensor_buft_overrides[itbo].pattern = nullptr; - tensor_buft_overrides[itbo].buft = nullptr; - itbo++; - mparams.tensor_buft_overrides = tensor_buft_overrides; - }; - - // utility function that returns the memory use per device for given numbers of layers per device - auto get_memory_for_layers = [&]( - const char * func_name, - const std::vector & ngl_per_device, - const std::vector & overflow_bufts) -> std::vector { - llama_model_params mparams_copy = *mparams; - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy); - - const dmds_t dmd_nl = llama_get_device_memory_data( - path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - - LLAMA_LOG_DEBUG("%s: memory for test allocation by device:\n", func_name); - for (size_t id = 0; id < nd; id++) { - const ngl_t & n = ngl_per_device[id]; - LLAMA_LOG_DEBUG( - "%s: id=%zu, n_layer=%2" PRIu32 ", n_part=%2" PRIu32 ", overflow_type=%d, mem=%6" PRId64 " MiB\n", - func_name, id, n.n_layer, n.n_part, int(n.overflow_type), dmd_nl[id].mb.total()/MiB); - } - - std::vector ret; - ret.reserve(nd); - for (const llama_device_memory_data & dmd : dmd_nl) { - ret.push_back(dmd.mb.total()); - } - return ret; - }; - - int64_t global_surplus_cpu_moe = 0; - if (hp_nex > 0) { - const static std::string pattern_moe_all = "blk\\.\\d+\\.ffn_(up|down|gate)_(ch|)exps"; // matches all MoE tensors - ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type(); - tensor_buft_overrides[0] = {pattern_moe_all.c_str(), cpu_buft}; - tensor_buft_overrides[1] = {nullptr, nullptr}; - mparams->tensor_buft_overrides = tensor_buft_overrides; - - LLAMA_LOG_DEBUG("%s: getting device memory data with all MoE tensors moved to system memory:\n", __func__); - const dmds_t dmds_cpu_moe = llama_get_device_memory_data( - path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - - for (size_t id = 0; id < nd; id++) { - global_surplus_cpu_moe += dmds_cpu_moe[id].free; - global_surplus_cpu_moe -= int64_t(dmds_cpu_moe[id].mb.total()) + margins[id]; - } - - if (global_surplus_cpu_moe > 0) { - LLAMA_LOG_INFO("%s: with only dense weights in device memory there is a total surplus of %" PRId64 " MiB\n", - __func__, global_surplus_cpu_moe/MiB); - } else { - LLAMA_LOG_INFO("%s: with only dense weights in device memory there is still a total deficit of %" PRId64 " MiB\n", - __func__, -global_surplus_cpu_moe/MiB); - } - - // reset - tensor_buft_overrides[0] = {nullptr, nullptr}; - mparams->tensor_buft_overrides = tensor_buft_overrides; - } - - std::vector targets; // maximum acceptable memory use per device - targets.reserve(nd); - for (size_t id = 0; id < nd; id++) { - targets.push_back(dmds_full[id].free - margins[id]); - LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); - } - - std::vector overflow_bufts; // which bufts the first partial layer of a device overflows to: - overflow_bufts.reserve(nd); - for (size_t id = 0; id < nd; id++) { - overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); - } - - std::vector ngl_per_device(nd); - std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts); - - // optimize the number of layers per device using the method of false position: - // - ngl_per_device has 0 layers for each device, lower bound - // - try a "high" configuration where a device is given all unassigned layers - // - interpolate the memory use / layer between low and high linearly to get a guess where it meets our target - // - check memory use of our guess, replace either the low or high bound - // - once we only have a difference of a single layer, stop and return the lower bound that just barely still fits - // - the last device has the output layer, which cannot be a partial layer - if (hp_nex == 0) { - LLAMA_LOG_INFO("%s: filling dense layers back-to-front:\n", __func__); - } else { - LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__); - } - for (int id = nd - 1; id >= 0; id--) { - uint32_t n_unassigned = hp_ngl + 1; - for (size_t jd = id + 1; jd < nd; ++jd) { - assert(n_unassigned >= ngl_per_device[jd].n_layer); - n_unassigned -= ngl_per_device[jd].n_layer; - } - - std::vector ngl_per_device_high = ngl_per_device; - ngl_per_device_high[id].n_layer = n_unassigned; - if (hp_nex > 0) { - ngl_per_device_high[id].n_part = size_t(id) < nd - 1 ? ngl_per_device_high[id].n_layer : ngl_per_device_high[id].n_layer - 1; - } - if (ngl_per_device_high[id].n_layer > 0) { - std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); - if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer); - uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; - LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta); - while (delta > 1) { - uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); - step_size = std::max(step_size, uint32_t(1)); - step_size = std::min(step_size, delta - 1); - - std::vector ngl_per_device_test = ngl_per_device; - ngl_per_device_test[id].n_layer += step_size; - if (hp_nex) { - ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ? - step_size - 1 : step_size; // the first layer is the output layer which must always be full - } - const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); - - if (mem_test[id] <= targets[id]) { - ngl_per_device = ngl_per_device_test; - mem = mem_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); - } else { - ngl_per_device_high = ngl_per_device_test; - mem_high = mem_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device_high[id].n_layer); - } - delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; - } - } else { - assert(ngl_per_device_high[id].n_layer == n_unassigned); - ngl_per_device = ngl_per_device_high; - mem = mem_high; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); - } - } - - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers, %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, mem[id]/MiB, projected_margin/MiB); - } - if (hp_nex == 0 || global_surplus_cpu_moe <= 0) { - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); - return; - } - - // step 4: for a MoE model where all dense tensors fit, - // convert the dense-only layers in the back to full layers in the front until all devices are full - // essentially the same procedure as for the dense-only layers except front-to-back - // also, try fitting at least part of one more layer to reduce waste for "small" GPUs with e.g. 24 GiB VRAM - - size_t id_dense_start = nd; - for (int id = nd - 1; id >= 0; id--) { - if (ngl_per_device[id].n_layer > 0) { - id_dense_start = id; - continue; - } - break; - } - assert(id_dense_start < nd); - - LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__); - for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) { - std::vector ngl_per_device_high = ngl_per_device; - for (size_t jd = id_dense_start; jd < nd; jd++) { - const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1; - ngl_per_device_high[id].n_layer += n_layer_move; - ngl_per_device_high[jd].n_layer -= n_layer_move; - ngl_per_device_high[jd].n_part = 0; - } - size_t id_dense_start_high = nd - 1; - std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); - - if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); - uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); - while (delta > 1) { - uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); - step_size = std::max(step_size, uint32_t(1)); - step_size = std::min(step_size, delta - 1); - - std::vector ngl_per_device_test = ngl_per_device; - size_t id_dense_start_test = id_dense_start; - uint32_t n_converted_test = 0; - for (;id_dense_start_test < nd; id_dense_start_test++) { - const uint32_t n_convert_jd = std::min(step_size - n_converted_test, ngl_per_device_test[id_dense_start_test].n_part); - ngl_per_device_test[id_dense_start_test].n_layer -= n_convert_jd; - ngl_per_device_test[id_dense_start_test].n_part -= n_convert_jd; - ngl_per_device_test[id].n_layer += n_convert_jd; - n_converted_test += n_convert_jd; - - if (ngl_per_device_test[id_dense_start_test].n_part > 0) { - break; - } - } - const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); - - if (mem_test[id] <= targets[id]) { - ngl_per_device = ngl_per_device_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } else { - ngl_per_device_high = ngl_per_device_test; - mem_high = mem_test; - id_dense_start_high = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n", - __func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high); - } - assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); - delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); - } - } else { - ngl_per_device = ngl_per_device_high; - mem = mem_high; - id_dense_start = id_dense_start_high; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - - // try to fit at least part of one more layer - if (ngl_per_device[id_dense_start].n_layer > (id < nd - 1 ? 0 : 1)) { - std::vector ngl_per_device_test = ngl_per_device; - size_t id_dense_start_test = id_dense_start; - ngl_per_device_test[id_dense_start_test].n_layer--; - ngl_per_device_test[id_dense_start_test].n_part--; - ngl_per_device_test[id].n_layer++; - ngl_per_device_test[id].n_part++; - if (ngl_per_device_test[id_dense_start_test].n_part == 0) { - id_dense_start_test++; - } - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP; - std::vector overflow_bufts_test = overflow_bufts; - if (id < nd - 1) { - overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]); - } - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__); - std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE; - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - } else { - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN; - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - } - } - - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); - } - - // print info for devices that were not changed during the conversion from dense only to full layers: - for (size_t id = id_dense_start + 1; id < nd; id++) { - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); - } - - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); -} - -enum llama_params_fit_status llama_params_fit( - const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, - float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t * margins, uint32_t n_ctx_min, enum ggml_log_level log_level) { - const int64_t t0_us = llama_time_us(); - llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS; - try { - llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins, n_ctx_min, log_level); - LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__); - } catch (const llama_params_fit_exception & e) { - LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what()); - status = LLAMA_PARAMS_FIT_STATUS_FAILURE; - } catch (const std::runtime_error & e) { - LLAMA_LOG_ERROR("%s: encountered an error while trying to fit params to free device memory: %s\n", __func__, e.what()); - status = LLAMA_PARAMS_FIT_STATUS_ERROR; - } - const int64_t t1_us = llama_time_us(); - LLAMA_LOG_INFO("%s: fitting params to free memory took %.2f seconds\n", __func__, (t1_us - t0_us) * 1e-6); - return status; -} - struct llama_sampler_chain_params llama_sampler_chain_default_params() { struct llama_sampler_chain_params result = { /*.no_perf =*/ true, @@ -828,7 +113,7 @@ int64_t llama_time_us(void) { // Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback static int llama_model_load(struct gguf_context * metadata, llama_model_set_tensor_data_t set_tensor_data, void * set_tensor_data_ud, - const std::string & fname, std::vector & splits, llama_model & model, llama_model_params & params) { + const std::string & fname, std::vector & splits, FILE * file, llama_model & model, llama_model_params & params) { // loading time will be recalculated after the first eval, so // we take page faults deferred by mmap() into consideration model.t_load_us = 0; @@ -837,7 +122,7 @@ static int llama_model_load(struct gguf_context * metadata, llama_model_set_tens model.t_start_us = tm.t_start_us; try { - llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, params.use_mmap, params.use_direct_io, + llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, file, params.use_mmap, params.use_direct_io, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); ml.print_info(); @@ -889,8 +174,24 @@ static struct llama_model * llama_model_load_from_file_impl( void * set_tensor_data_ud, const std::string & path_model, std::vector & splits, + FILE * file, struct llama_model_params params) { - GGML_ASSERT((metadata == nullptr) != path_model.empty() && "exactly one out of metadata and path_model needs to be defined"); + { + int n_sources_defined = 0; + if (metadata != nullptr) { + n_sources_defined++; + } + if (!path_model.empty()) { + n_sources_defined++; + } + if (file != nullptr) { + n_sources_defined++; + } + if (n_sources_defined != 1) { + LLAMA_LOG_ERROR("%s: exactly one out metadata, path_model, and file must be defined\n", __func__); + return nullptr; + } + } ggml_time_init(); if (!params.vocab_only && ggml_backend_reg_count() == 0) { @@ -919,58 +220,111 @@ static struct llama_model * llama_model_load_from_file_impl( // create list of devices to use with this model if (params.devices) { - for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) { - model->devices.push_back(*dev); + if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) { + size_t n_devs = 0; + while (params.devices[n_devs]) { + n_devs++; + } + if (n_devs == 0) { + LLAMA_LOG_ERROR("%s: LLAMA_SPLIT_MODE_TENSOR needs >= 1 devices\n", __func__); + return nullptr; + } + LLAMA_LOG_INFO("%s: creating a Meta device with %zu devices\n", __func__, n_devs); + for (size_t i = 0; i < n_devs; ++i) { + LLAMA_LOG_INFO("%s: - device %zu: %s\n", __func__, i, ggml_backend_dev_name(params.devices[i])); + } + model->get_split_state_ud.n_devices = n_devs; + model->get_split_state_ud.model = model; + model->devices.push_back({ + true, ggml_backend_meta_device( + params.devices, n_devs, llama_meta_device_get_split_state, &model->get_split_state_ud) + }); + } else { + for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) { + model->devices.push_back({false, *dev}); + } } } else { // default device selection // build list of available devices - std::vector gpus; - std::vector igpus; - std::vector rpc_servers; - - for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { - ggml_backend_dev_t dev = ggml_backend_dev_get(i); - switch (ggml_backend_dev_type(dev)) { - case GGML_BACKEND_DEVICE_TYPE_CPU: - case GGML_BACKEND_DEVICE_TYPE_ACCEL: - // skip CPU backends since they are handled separately - break; - - case GGML_BACKEND_DEVICE_TYPE_GPU: { - ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); - if (ggml_backend_reg_name(reg) == std::string("RPC")) { - rpc_servers.push_back(dev); - } else { - // check if there is already a GPU with the same device id - ggml_backend_dev_props props; - ggml_backend_dev_get_props(dev, &props); - auto it = std::find_if(gpus.begin(), gpus.end(), [&props](ggml_backend_dev_t d) { - ggml_backend_dev_props d_props; - ggml_backend_dev_get_props(d, &d_props); - if (props.device_id && d_props.device_id) { - return strcmp(props.device_id, d_props.device_id) == 0; - } - return false; - }); - - if (it != gpus.end()) { - LLAMA_LOG_INFO("%s: skipping device %s (%s) with id %s - already using device %s (%s) with the same id\n", - __func__, - ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), - props.device_id ? props.device_id : "unknown id", - ggml_backend_dev_name(*it), ggml_backend_dev_description(*it)); + std::vector gpus; + std::vector igpus; + std::vector rpc_servers; + + if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) { + std::vector devs; + devs.reserve(ggml_backend_dev_count()); + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + auto * dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_buffer_type(dev) == ggml_backend_cpu_buffer_type()) { + LLAMA_LOG_INFO("%s: skipping %s (%s) for tensor parallelism\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev)); + continue; + } + devs.push_back(dev); + } + if (devs.empty()) { + LLAMA_LOG_ERROR("%s: LLAMA_SPLIT_MODE_TENSOR needs >= 1 devices\n", __func__); + return nullptr; + } + + LLAMA_LOG_INFO("%s: creating a Meta device for tensor parallelism from %zu devices:\n", __func__, devs.size()); + for (size_t i = 0; i < devs.size(); ++i) { + LLAMA_LOG_INFO("%s: - device %zu: %s (%s)\n", __func__, i, ggml_backend_dev_name(devs[i]), ggml_backend_dev_description(devs[i])); + } + + GGML_ASSERT(!devs.empty()); + model->get_split_state_ud.n_devices = devs.size(); + model->get_split_state_ud.model = model; + gpus.push_back({ + true, ggml_backend_meta_device( + devs.data(), devs.size(), llama_meta_device_get_split_state, &model->get_split_state_ud) + }); + } else { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + switch (ggml_backend_dev_type(dev)) { + case GGML_BACKEND_DEVICE_TYPE_CPU: + case GGML_BACKEND_DEVICE_TYPE_ACCEL: + // skip CPU backends since they are handled separately + break; + + case GGML_BACKEND_DEVICE_TYPE_GPU: { + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + if (ggml_backend_reg_name(reg) == std::string("RPC")) { + rpc_servers.push_back({false, dev}); } else { - gpus.push_back(dev); + // check if there is already a GPU with the same device id + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + auto it = std::find_if(gpus.begin(), gpus.end(), [&props](const llama_device & d) { + ggml_backend_dev_props d_props; + ggml_backend_dev_get_props(d.dev, &d_props); + if (props.device_id && d_props.device_id) { + return strcmp(props.device_id, d_props.device_id) == 0; + } + return false; + }); + + if (it != gpus.end()) { + LLAMA_LOG_INFO("%s: skipping device %s (%s) with id %s - already using device %s (%s) with the same id\n", + __func__, + ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + props.device_id ? props.device_id : "unknown id", + ggml_backend_dev_name(it->dev), ggml_backend_dev_description(it->dev)); + } else { + gpus.push_back({false, dev}); + } } + break; } - break; - } - case GGML_BACKEND_DEVICE_TYPE_IGPU: - igpus.push_back(dev); - break; + case GGML_BACKEND_DEVICE_TYPE_IGPU: + igpus.push_back({false, dev}); + break; + case GGML_BACKEND_DEVICE_TYPE_META: + GGML_ABORT("fatal error"); + } } } @@ -996,22 +350,22 @@ static struct llama_model * llama_model_load_from_file_impl( llama_model_free(model); return nullptr; } - ggml_backend_dev_t main_gpu = model->devices[params.main_gpu]; + llama_device main_gpu = model->devices[params.main_gpu]; model->devices.clear(); model->devices.push_back(main_gpu); } } - for (auto * dev : model->devices) { + for (const auto & dev : model->devices) { ggml_backend_dev_props props; - ggml_backend_dev_get_props(dev, &props); + ggml_backend_dev_get_props(dev.dev, &props); LLAMA_LOG_INFO("%s: using device %s (%s) (%s) - %zu MiB free\n", __func__, - ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + ggml_backend_dev_name(dev.dev), ggml_backend_dev_description(dev.dev), props.device_id ? props.device_id : "unknown id", props.memory_free/1024/1024); } - const int status = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, *model, params); + const int status = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, file, *model, params); GGML_ASSERT(status <= 0); if (status < 0) { if (status == -1) { @@ -1037,7 +391,7 @@ struct llama_model * llama_model_init_from_user( std::vector splits = {}; params.use_mmap = false; params.use_extra_bufts = false; - return llama_model_load_from_file_impl(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, params); + return llama_model_load_from_file_impl(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, /*file*/ nullptr, params); } // deprecated struct llama_model * llama_load_model_from_file( @@ -1050,7 +404,7 @@ struct llama_model * llama_model_load_from_file( const char * path_model, struct llama_model_params params) { std::vector splits = {}; - return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, path_model, splits, params); + return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, path_model, splits, /*file*/ nullptr, params); } struct llama_model * llama_model_load_from_splits( @@ -1066,7 +420,17 @@ struct llama_model * llama_model_load_from_splits( for (size_t i = 0; i < n_paths; ++i) { splits.push_back(paths[i]); } - return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, splits.front(), splits, params); + return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, splits.front(), splits, /*file*/ nullptr, params); +} + +struct llama_model * llama_model_load_from_file_ptr(FILE * file, struct llama_model_params params) { + if (!file) { + LLAMA_LOG_ERROR("%s: file is NULL\n", __func__); + return nullptr; + } + std::string path_model; + std::vector splits = {}; + return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, path_model, splits, file, params); } void llama_model_save_to_file(const struct llama_model * model, const char * path_model) { diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index c6e102abe51..eb869814097 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -154,6 +154,7 @@ extern "C" { LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors LLAMA_FTYPE_MOSTLY_NVFP4 = 39, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q1_0 = 40, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; @@ -191,9 +192,10 @@ extern "C" { LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type); enum llama_split_mode { - LLAMA_SPLIT_MODE_NONE = 0, // single GPU - LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs - LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported + LLAMA_SPLIT_MODE_NONE = 0, // single GPU + LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs + LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported + LLAMA_SPLIT_MODE_TENSOR = 3, }; // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) @@ -380,22 +382,33 @@ extern "C" { size_t n_samplers; }; + struct llama_model_tensor_override { + const char * pattern; + enum ggml_type type; + }; + + struct llama_model_imatrix_data { + const char * name; + const float * data; + size_t size; + }; + // model quantization parameters typedef struct llama_model_quantize_params { - int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() - enum llama_ftype ftype; // quantize to this llama_ftype - enum ggml_type output_tensor_type; // output tensor type - enum ggml_type token_embedding_type; // token embeddings tensor type - bool allow_requantize; // allow quantizing non-f32/f16 tensors - bool quantize_output_tensor; // quantize output.weight - bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored - bool pure; // quantize all tensors to the default type - bool keep_split; // quantize to the same number of shards - bool dry_run; // calculate and show the final quantization size without performing quantization - void * imatrix; // pointer to importance matrix data - void * kv_overrides; // pointer to vector containing overrides - void * tensor_types; // pointer to vector containing tensor types - void * prune_layers; // pointer to vector containing layer indices to prune + int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() + enum llama_ftype ftype; // quantize to this llama_ftype + enum ggml_type output_tensor_type; // output tensor type + enum ggml_type token_embedding_type; // token embeddings tensor type + bool allow_requantize; // allow quantizing non-f32/f16 tensors + bool quantize_output_tensor; // quantize output.weight + bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored + bool pure; // quantize all tensors to the default type + bool keep_split; // quantize to the same number of shards + bool dry_run; // calculate and show the final quantization size without performing quantization + const struct llama_model_imatrix_data * imatrix; // pointer to importance matrix data + const struct llama_model_kv_override * kv_overrides; // pointer to kv overrides + const struct llama_model_tensor_override * tt_overrides; // pointer to tensor overrides + const int32_t * prune_layers; // pointer to layer indices to prune } llama_model_quantize_params; typedef struct llama_logit_bias { @@ -465,6 +478,11 @@ extern "C" { const char * path_model, struct llama_model_params params); + // Load a model from an open FILE pointer + LLAMA_API struct llama_model * llama_model_load_from_file_ptr( + FILE * file, + struct llama_model_params params); + // Load a model from multiple splits (support custom naming scheme) // The paths must be in the correct order LLAMA_API struct llama_model * llama_model_load_from_splits( @@ -493,27 +511,6 @@ extern "C" { // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); - enum llama_params_fit_status { - LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit - LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit - LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occurred, e.g. because no model could be found at the specified path - }; - - // fits mparams and cparams to free device memory (assumes system memory is unlimited) - // - returns true if the parameters could be successfully modified to fit device memory - // - this function is NOT thread safe because it modifies the global llama logger state - // - only parameters that have the same value as in llama_default_model_params are modified - // with the exception of the context size which is modified if and only if equal to 0 - LLAMA_API enum llama_params_fit_status llama_params_fit( - const char * path_model, - struct llama_model_params * mparams, - struct llama_context_params * cparams, - float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements - struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements - size_t * margins, // margins of memory to leave per device in bytes - uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use - enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log - LLAMA_API int64_t llama_time_us(void); LLAMA_API size_t llama_max_devices(void); @@ -636,7 +633,6 @@ extern "C" { // Load a LoRA adapter from file // The adapter is valid as long as the associated model is not freed - // All adapters must be loaded before context creation LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( struct llama_model * model, const char * path_lora); @@ -660,9 +656,8 @@ extern "C" { LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size); // Manually free a LoRA adapter - // NOTE: loaded adapters will be free when the associated model is deleted - LLAMA_API DEPRECATED(void llama_adapter_lora_free(struct llama_adapter_lora * adapter), - "adapters are now freed together with the associated model"); + // NOTE: loaded adapters that are not manually freed will be freed when the associated model is deleted + LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); // Get the invocation tokens if the current lora is an alora LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter); @@ -1530,9 +1525,6 @@ extern "C" { LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); - // print a breakdown of per-device memory use via LLAMA_LOG: - LLAMA_API void llama_memory_breakdown_print(const struct llama_context * ctx); - // // training // diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index 9aabe25c965..2790b12111d 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -41,22 +41,13 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para { ggml_tensor * attn_inp = cur; // save input for gate computation - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // compute gate from input ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp); cb(gate, "attn_gate_proj", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - // Q/K normalization Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -77,10 +68,8 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para cb(Kcur, "Kcur_rope", il); } - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, - NULL, NULL, // wo will be applied after gating + NULL, NULL, NULL, // wo will be applied after gating Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -91,7 +80,7 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para cb(cur, "attn_gated", il); // now apply output projection - cur = build_lora_mm(model.layers[il].wo, cur); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); cb(cur, "attn_o_proj", il); } diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index 4d65614e466..af44cea6054 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -32,25 +30,15 @@ llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); cb(Kcur, "Kcur_normed", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -62,7 +50,7 @@ llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur_pos", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index 20b9ffd49eb..2e71f5d9e2a 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -36,30 +35,8 @@ llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_para ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -78,7 +55,7 @@ llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index b712e08cbd3..f8ca6aff6ab 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -30,18 +30,8 @@ llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_pa // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -60,7 +50,7 @@ llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index abd03cd0b97..2d0d05df485 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -29,18 +28,8 @@ llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_grap // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); switch (model.type) { case LLM_TYPE_7B: @@ -67,7 +56,7 @@ llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index 25e3369c313..67a7120d622 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -28,30 +28,8 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head_k, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -70,7 +48,7 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); } diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index 42098624663..497b4babd0c 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -3,7 +3,6 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -29,15 +28,8 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll // self_attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 0 * sizeof(float) * (n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -56,7 +48,7 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/bert.cpp b/examples/talk-llama/models/bert.cpp index 87331791418..7e046cfd2a4 100644 --- a/examples/talk-llama/models/bert.cpp +++ b/examples/talk-llama/models/bert.cpp @@ -2,7 +2,6 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -28,8 +27,8 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params cb(inpL, "inp_embd", -1); // embed layer norm - inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0); + cb(inpL, "inp_norm", 0); auto * inp_attn = build_attn_inp_no_cache(); @@ -39,35 +38,8 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params ggml_tensor * cur = inpL; { - ggml_tensor * Qcur; - ggml_tensor * Kcur; - ggml_tensor * Vcur; - - // self-attention - if (model.layers[il].wqkv) { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], - 0 * sizeof(float) * (n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); - } else { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head * n_head, n_tokens); @@ -100,7 +72,7 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } diff --git a/examples/talk-llama/models/bitnet.cpp b/examples/talk-llama/models/bitnet.cpp index ccf5bc8e82b..71526354ca6 100644 --- a/examples/talk-llama/models/bitnet.cpp +++ b/examples/talk-llama/models/bitnet.cpp @@ -28,33 +28,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa // self-attention { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - // B1.K - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - // B1.V - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -73,7 +48,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - NULL, NULL, + NULL, NULL, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cur = build_norm(cur, @@ -82,8 +57,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa cb(cur, "attn_sub_norm", il); cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); - if (model.layers[il].bo) { - cur = ggml_add(ctx0, cur, model.layers[il].bo); + if (model.layers[il].wo_b) { + cur = ggml_add(ctx0, cur, model.layers[il].wo_b); } cb(cur, "attn_out", il); } @@ -121,6 +96,9 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "l_out", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + // input for next layer inpL = cur; } diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index b1c19bb58a2..f3b0999bf54 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -2,7 +2,6 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -16,8 +15,8 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, - LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + LLM_NORM, 0); + cb(inpL, "inp_norm", 0); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -30,22 +29,11 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 2f24105fa14..21deaba1a6d 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -36,22 +36,10 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { - Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, - ggml_element_size(Qcur) * n_embd_head, - ggml_element_size(Qcur) * n_embd_head * n_head, - 0); - cb(Qcur, "Qcur", il); - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, @@ -60,12 +48,6 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr } if (model.layers[il].attn_k_norm) { - Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, - ggml_element_size(Kcur) * n_embd_head, - ggml_element_size(Kcur) * n_embd_head * n_head_kv, - 0); - cb(Kcur, "Kcur", il); - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, @@ -73,10 +55,6 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr cb(Kcur, "Kcur", il); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -94,7 +72,7 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index 5887ed22e7e..7d4a43fdca5 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -3,7 +3,6 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -30,37 +29,8 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_ // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv == nullptr) { - Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } else { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); Qcur = ggml_rope_ext( @@ -80,7 +50,7 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -111,8 +81,13 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_ } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = build_norm(inpL, diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index e8e13e143f2..3ceb5835b85 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -2,7 +2,6 @@ llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); GGML_ASSERT(n_embd_head == n_rot); @@ -28,15 +27,8 @@ llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_gr // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -55,7 +47,7 @@ llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index 2ef2b6e389b..be3eeeddac7 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -28,18 +28,20 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa for (int il = 0; il < n_layer; ++il) { // get either the text or image weight tensors - ggml_tensor *wqkv, *wo; + ggml_tensor *wqkv, *wo, *wo_s; ggml_tensor *ffn_gate, *ffn_down, *ffn_up; if (is_text) { wqkv = model.layers[il].wqkv; wo = model.layers[il].wo; + wo_s = model.layers[il].wo_s; ffn_gate = model.layers[il].ffn_gate; ffn_down = model.layers[il].ffn_down; ffn_up = model.layers[il].ffn_up; } else { wqkv = model.layers[il].visexp_attn_wqkv; wo = model.layers[il].visexp_attn_wo; + wo_s = nullptr; ffn_gate = model.layers[il].visexp_ffn_gate; ffn_down = model.layers[il].visexp_ffn_down; ffn_up = model.layers[il].visexp_ffn_up; @@ -64,7 +66,7 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa Kcur = ggml_rope(ctx0, Kcur, inp_pos, n_embd_head, rope_type); cur = build_attn(inp_attn, - wo, nullptr, + wo, nullptr, wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); @@ -86,6 +88,10 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer inpL = cur; } diff --git a/examples/talk-llama/models/cohere2-iswa.cpp b/examples/talk-llama/models/cohere2-iswa.cpp index 7c71a59ae7f..670b08e7d97 100644 --- a/examples/talk-llama/models/cohere2-iswa.cpp +++ b/examples/talk-llama/models/cohere2-iswa.cpp @@ -36,30 +36,8 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (is_swa) { Qcur = ggml_rope_ext( @@ -80,7 +58,7 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index ba1230f0419..067961caa08 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -32,27 +32,8 @@ llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM, il); @@ -73,7 +54,7 @@ llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index 73eb5cd24e7..0e882721807 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -2,7 +2,6 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); GGML_ASSERT(n_embd_head == n_rot); @@ -30,19 +29,8 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(cur, "wqkv_clamped", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -61,7 +49,7 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index ac448bfcaa8..30272eabd69 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -47,27 +45,8 @@ llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -80,7 +59,7 @@ llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/deepseek.cpp b/examples/talk-llama/models/deepseek.cpp index 3432359e03a..671b72dfead 100644 --- a/examples/talk-llama/models/deepseek.cpp +++ b/examples/talk-llama/models/deepseek.cpp @@ -35,27 +35,8 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -68,7 +49,7 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp index d437fe29e71..303fc72c610 100644 --- a/examples/talk-llama/models/deepseek2.cpp +++ b/examples/talk-llama/models/deepseek2.cpp @@ -2,6 +2,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B + bool is_ocr = model.arch == LLM_ARCH_DEEPSEEK2OCR; + const bool is_mla = hparams.is_mla(); // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA @@ -54,7 +57,38 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cb(cur, "attn_norm", il); // self_attention - { + if (is_ocr) { + const int n_embed_head = hparams.n_embd / hparams.n_head(); + const int ocr_rope_type = GGML_ROPE_TYPE_NEOX; + GGML_ASSERT(n_embed_head == n_embd_head_k && n_embed_head == n_embd_head_v); + + ggml_tensor * Qcur = NULL; + ggml_tensor * Kcur = NULL; + ggml_tensor * Vcur = NULL; + + Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Qcur, "q", il); + cb(Kcur, "k", il); + cb(Vcur, "v", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embed_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embed_head, n_head, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embed_head, n_head, n_tokens); + + GGML_ASSERT(fabs(freq_base - 10000.0) < 1e-4); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0); + cb(Qcur, "q_pe", il); + cb(Kcur, "k_pe", il); + + cur = build_attn(inp_attn_kv, + model.layers[il].wo, NULL, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + else { ggml_tensor * q = NULL; const bool is_lite = model.layers[il].wq; @@ -148,7 +182,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group) cur = build_attn(inp_attn_k, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il); } else { ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); @@ -185,7 +219,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) cur = build_attn(inp_attn_kv, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } } diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 07236dd27c9..5d1750fedda 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -29,18 +29,8 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -59,7 +49,7 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 4edc8530cb3..8e7d9ae64c7 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { //copied from qwen2 @@ -31,22 +29,8 @@ llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_para // self-attention { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -59,7 +43,7 @@ llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/ernie4-5-moe.cpp b/examples/talk-llama/models/ernie4-5-moe.cpp index 63baf152c40..fc6a3e17a09 100644 --- a/examples/talk-llama/models/ernie4-5-moe.cpp +++ b/examples/talk-llama/models/ernie4-5-moe.cpp @@ -30,27 +30,8 @@ llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -63,7 +44,7 @@ llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index d548de0547b..033ba409eab 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -29,27 +29,8 @@ llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_grap } // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -62,7 +43,7 @@ llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { diff --git a/examples/talk-llama/models/eurobert.cpp b/examples/talk-llama/models/eurobert.cpp index e8628d165d0..43fff4daf3a 100644 --- a/examples/talk-llama/models/eurobert.cpp +++ b/examples/talk-llama/models/eurobert.cpp @@ -24,17 +24,8 @@ llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_grap LLM_NORM_RMS, il); { - ggml_tensor * Qcur; - ggml_tensor * Kcur; - ggml_tensor * Vcur; - - Qcur = build_lora_mm(model.layers[il].wq, cur); - Kcur = build_lora_mm(model.layers[il].wk, cur); - Vcur = build_lora_mm(model.layers[il].wv, cur); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -53,7 +44,7 @@ llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } @@ -82,6 +73,7 @@ llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_grap cur = ggml_add(ctx0, cur, ffn_inp); + // input for next layer inpL = cur; } cur = inpL; diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp index ea75701c528..7b88a31d39d 100644 --- a/examples/talk-llama/models/exaone-moe.cpp +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -35,18 +35,8 @@ llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -65,7 +55,7 @@ llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn_iswa, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index d4eea58e2f1..4f845bf4106 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -34,27 +32,8 @@ llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_pa ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -67,7 +46,7 @@ llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 755af3b747b..34bee3b8fe9 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -1,6 +1,5 @@ #include "models.h" - template llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { @@ -39,18 +38,8 @@ llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_ { ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -69,7 +58,7 @@ llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index ff842d93a41..05accf90fad 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -27,19 +27,8 @@ llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_gr cb(cur, "attn_norm", il); // self-attention - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, hparams.rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -52,7 +41,7 @@ llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_gr cb(Vcur, "Vcur-post-rope", il); ggml_tensor * attn_out = build_attn(inp->get_attn(), - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(attn_out, "attn_out", il); diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index 9fcba508878..2f65fa56e1f 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -1,9 +1,7 @@ #include "models.h" - llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); GGML_ASSERT(n_embd_head == n_rot); @@ -42,12 +40,8 @@ llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_pa cur = attn_norm; } - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -67,7 +61,7 @@ llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/gemma-embedding.cpp b/examples/talk-llama/models/gemma-embedding.cpp index 98110d45e3b..b6de9551c52 100644 --- a/examples/talk-llama/models/gemma-embedding.cpp +++ b/examples/talk-llama/models/gemma-embedding.cpp @@ -9,7 +9,7 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, inpL = build_inp_embd(model.tok_embd); - // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings) inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); cb(inpL, "inp_scaled", -1); @@ -31,18 +31,8 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -65,7 +55,7 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 1869efd389a..09d2ff8bae7 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -29,18 +28,8 @@ llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -60,7 +49,7 @@ llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_para cb(Qcur, "Qcur_scaled", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/gemma2-iswa.cpp b/examples/talk-llama/models/gemma2-iswa.cpp index 3927ddd297b..0ef07df8d01 100644 --- a/examples/talk-llama/models/gemma2-iswa.cpp +++ b/examples/talk-llama/models/gemma2-iswa.cpp @@ -31,18 +31,8 @@ llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const ll // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -61,7 +51,7 @@ llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const ll Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index bbb4d9a81e8..0da4af21c17 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -9,7 +9,7 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr inpL = build_inp_embd(model.tok_embd); - // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings) inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); cb(inpL, "inp_scaled", -1); @@ -47,18 +47,8 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -84,7 +74,7 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/gemma3n-iswa.cpp b/examples/talk-llama/models/gemma3n-iswa.cpp index 8ce2ae39c2f..f8095417e06 100644 --- a/examples/talk-llama/models/gemma3n-iswa.cpp +++ b/examples/talk-llama/models/gemma3n-iswa.cpp @@ -1,5 +1,12 @@ #include "models.h" +// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim +static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) { + GGML_ASSERT(idx < (int) x->ne[2]); + return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), + idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); +} + llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), @@ -12,7 +19,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const inpL = build_inp_embd(model.tok_embd); - // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings) inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); cb(inpL, "inp_scaled", -1); @@ -22,8 +29,11 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const // TODO: is causal == true correct? might need some changes auto * inp_attn = build_attn_inp_kv_iswa(); - // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer] - ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs()); + ggml_tensor * inp_per_layer = build_inp_per_layer(); + ggml_build_forward_expand(gf, inp_per_layer); + + // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer] + inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer); // inpL now has only 1 altup, project it to the rest of the altups // these "added" altups will be concat to the last dim of inpL @@ -37,8 +47,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup] cb(inpL, "inp_stacked", -1); } - // inpL now has shape: [n_embd, n_tokens, n_altup] - // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer] + // inpL now has shape: [n_embd, n_tokens, n_altup] for (int il = 0; il < n_layer; ++il) { // this block is made to be closely resemble Gemma3p5DecoderLayer on python code @@ -49,8 +58,8 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup] // predicted value will go through self-attention and laurel - ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens] - cur = active_prediction; + ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); // [n_embd, n_tokens] + cur = active_prediction; cb(cur, "active_prediction", il); // norm @@ -62,19 +71,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const // self-attention if (hparams.has_kv(il)) { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -94,7 +91,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const cb(Kcur, "Kcur_pos", il); cur = build_attn(inp_attn, model.layers[il].wo, - NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, + NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); } else { // reuse KV cache of earlier layers @@ -110,7 +107,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const cb(Qcur, "Qcur_pos", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); } cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); @@ -151,12 +148,13 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const ggml_tensor * first_prediction; // [n_embd, n_tokens] { - first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens] + first_prediction = ggml_view_2d_slice(ctx0, corrected, i_altup_act); // [n_embd, n_tokens] first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale); first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction); first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens] cb(first_prediction, "first_prediction_gated", il); - ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens] + + ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_altup, n_tokens] first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens] cb(first_prediction, "first_prediction_scaled", il); @@ -167,7 +165,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const } // equivalent to python code: corrected_predictions[1:] += first_prediction { - ggml_tensor * slice_first = view_2d_slice(corrected, 0); + ggml_tensor * slice_first = ggml_view_2d_slice(ctx0, corrected, 0); ggml_tensor * slice_rest = ggml_view_3d( ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd), ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected)); @@ -185,7 +183,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const // cur now has multiple altup(s), we want to merge them back to 1 altup { - ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens] + ggml_tensor * target_magnitude = calc_magnitude(ggml_view_2d_slice(ctx0, cur, i_altup_act)); // [n_embd, n_tokens] // do a view to skip the first slice (active altup) ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd), @@ -197,9 +195,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const cb(altup_unembd, "altup_unembd", -1); // equivalent to torch.mean(hidden_states, dim=0) - cur = view_2d_slice(cur, 0); // [n_embd, n_tokens] + cur = ggml_view_2d_slice(ctx0, cur, 0); // [n_embd, n_tokens] for (int i = 0; i < n_altup - 1; ++i) { - cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i)); + cur = ggml_add(ctx0, cur, ggml_view_2d_slice(ctx0, altup_unembd, i)); } cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens] cb(cur, "unembd_merged", -1); @@ -235,39 +233,34 @@ ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) { return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x))); } -// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim -ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) { - GGML_ASSERT(idx < (int) x->ne[2]); - return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), - idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); -} - // equivalent to get_per_layer_inputs() in python code // output shape: [n_embd_altup, n_layer, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { +ggml_tensor * llm_build_gemma3n_iswa::build_inp_per_layer() { auto inp = std::make_unique(n_embd); ggml_tensor * inp_per_layer; + float tok_embd_scale = sqrtf((float) n_embd_altup); if (ubatch.token) { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); ggml_set_input(inp->tokens); res->t_inp_tokens = inp->tokens; - inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens); + inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens); inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); - inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup)); + inp_per_layer = ggml_scale (ctx0, inp_per_layer, tok_embd_scale); cb(inp_per_layer, "inp_per_layer_selected", -1); res->add_input(std::move(inp)); } else { - // Vision embedding path: use padding token (ID=0) embedding + // Multimodal embedding path: use padding token (ID=0) embedding // TODO: verify if this is the correct behavior in transformers implementation - const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer + const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_altup * n_layer // Extract and dequantize padding token embedding (row 0) - ggml_tensor * padding = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0); - inp_per_layer = ggml_cast(ctx0, padding, GGML_TYPE_F32); + ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0); + inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, tok_embd_scale); // Reshape to [n_embd_altup, n_layer, 1] inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1); - cb(inp_per_layer, "inp_per_layer_vision", -1); + cb(inp_per_layer, "inp_per_layer_multimodal", -1); } return inp_per_layer; } @@ -275,18 +268,19 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { // equivalent to project_per_layer_inputs() in python code // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim // output shape: [n_embd_altup, n_tokens, n_layer] -ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) { +ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); const float per_layer_input_scale = 1.0f / sqrtf(2.0f); - ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds); - per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale); - per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); - per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, - -1); // [n_embd_altup, n_layer, n_tokens] + ggml_tensor * per_layer_proj; + per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch); + per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale); + per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); + + per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, -1); cb(per_layer_proj, "per_layer_proj", -1); - inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer); + inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer); inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale); cb(inp_per_layer, "inp_per_layer", -1); @@ -337,7 +331,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tenso // input cur shape: [n_embd, n_tokens, n_altup] // output shape: [n_embd, n_tokens, n_altup] ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) { - ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens] + ggml_tensor * activated = ggml_view_2d_slice(ctx0, cur, i_altup_act); // [n_embd, n_tokens] ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] cb(modalities, "modalities", il); @@ -365,7 +359,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, g ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] cb(modalities, "modalities", il); - ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); + ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens] cb(innovation, "innovation", il); diff --git a/examples/talk-llama/models/gemma4-iswa.cpp b/examples/talk-llama/models/gemma4-iswa.cpp new file mode 100644 index 00000000000..c7fb7747414 --- /dev/null +++ b/examples/talk-llama/models/gemma4-iswa.cpp @@ -0,0 +1,322 @@ +#include "models.h" + +// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim +static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) { + GGML_ASSERT(idx < (int) x->ne[2]); + return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), + idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); +} + +llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params), + model(model), + n_embd_per_layer(model.hparams.n_embd_per_layer) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); + cb(inpL, "inp_scaled", -1); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // TODO: is causal == true correct? might need some changes + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * inp_per_layer = nullptr; + if (model.per_layer_tok_embd) { + inp_per_layer = build_inp_per_layer(); + ggml_build_forward_expand(gf, inp_per_layer); + + // inp_per_layer shape: [n_embd_per_layer, n_tokens, n_layer] + inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer); + } + + for (int il = 0; il < n_layer; ++il) { + const int64_t n_embd_head = hparams.n_embd_head_k(il); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v(il)); + + const int64_t n_head = hparams.n_head(il); + const int64_t n_head_kv = hparams.n_head_kv(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + const int n_rot_l = hparams.n_rot(il); + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_tensor * freq_factors = nullptr; + if (!hparams.is_swa(il)) { + // full_attention layers use rope_freqs for proportional rope + freq_factors = model.layers[il].rope_freqs; + } + + // Q projection (shared for both non-KV and KV layers) + // this is to mirror Gemma4Attention in pytorch code + ggml_tensor * Qcur; + { + Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); + cb(Qcur, "Qcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_pos", il); + } + + // self-attention + if (hparams.has_kv(il)) { + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = model.layers[il].wv + ? build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s) + : Kcur; // if v_proj is not present, use Kcur as Vcur + cb(Vcur, "Vcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps); + + cb(Kcur, "Kcur_normed", il); + cb(Vcur, "Vcur_normed", il); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Kcur, "Kcur_pos", il); + + cur = build_attn(inp_attn, model.layers[il].wo, + nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, + hparams.f_attention_scale, il); + } else { + // reuse KV cache of earlier layers + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, model.layers[il].wo_s, + Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + } + + // TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + cur = build_norm(cur, + model.layers[il].attn_post_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL); + cb(attn_out, "attn_out", il); + + // feed-forward network + const bool is_moe_layer = model.layers[il].ffn_gate_inp != nullptr; + if (is_moe_layer) { + // MLP (shared exp) + ggml_tensor * cur_mlp = build_norm(attn_out, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur_mlp, "ffn_norm_1", il); + + cur_mlp = build_ffn(cur_mlp, + model.layers[il].ffn_up, nullptr, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, nullptr, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cur_mlp = build_norm(cur_mlp, + model.layers[il].ffn_post_norm_1, nullptr, + LLM_NORM_RMS, il); + cb(cur_mlp, "ffn_mlp", il); + + // Expert FFN + ggml_tensor * cur_moe = build_norm(attn_out, + model.layers[il].ffn_pre_norm_2, nullptr, + LLM_NORM_RMS, il); + cb(cur_moe, "ffn_norm_2", il); + + // custom MoE logits calculation (router operates on attn_out, not cur) + ggml_tensor * tmp = ggml_rms_norm(ctx0, attn_out, hparams.f_norm_rms_eps); + tmp = ggml_scale(ctx0, tmp, 1.0f / sqrtf((float) n_embd)); + tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_gate_inp_s); + ggml_tensor * logits = build_lora_mm(model.layers[il].ffn_gate_inp, tmp); // [n_expert, n_tokens] + cb(logits, "ffn_moe_logits", il); + + cur_moe = build_moe_ffn(cur_moe, + nullptr, // gate_inp + nullptr, // up_exps + nullptr, // gate_exps + model.layers[il].ffn_down_exps, + nullptr, // exp_probs_b (not used for gemma4) + n_expert, n_expert_used, + LLM_FFN_GELU, true, + 1.0f, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il, logits, + model.layers[il].ffn_gate_up_exps, + nullptr, // up_exps_s + nullptr, // gate_exps_s + model.layers[il].ffn_down_exps_s); + cur_moe = build_norm(cur_moe, + model.layers[il].ffn_post_norm_2, nullptr, + LLM_NORM_RMS, il); + cb(cur_moe, "ffn_moe", il); + + cur = ggml_add(ctx0, cur_mlp, cur_moe); + cb(cur, "ffn_moe_combined", il); + } else { + cur = build_norm(attn_out, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, nullptr, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + cur = build_norm(cur, + model.layers[il].ffn_post_norm, nullptr, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", il); + + // residual connection + cur = ggml_add(ctx0, cur, attn_out); + + // per-layer embedding + if (inp_per_layer) { + ggml_tensor * pe_in = cur; + cb(cur, "pe_in", il); + + cur = build_lora_mm(model.layers[il].per_layer_inp_gate, cur); // [n_embd_per_layer, n_tokens] + cur = ggml_gelu(ctx0, cur); + + ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens] + + // TODO @ngxson : improve this + if (il == n_layer - 1 && inp_out_ids) { + inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids); + } + + cur = ggml_mul(ctx0, cur, inp_this_layer); + cur = build_lora_mm(model.layers[il].per_layer_proj, cur); // [n_embd, n_tokens] + cur = build_norm(cur, model.layers[il].per_layer_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "per_layer_embd_out", il); + + // residual connection + cur = ggml_add(ctx0, pe_in, cur); + } + + // layer_scalar + if (model.layers[il].out_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); + cb(cur, "out_scaled", il); + } + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + if (hparams.f_final_logit_softcapping) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +// equivalent to get_per_layer_inputs() in python code +// output shape: [n_embd_per_layer, n_layer, n_tokens] +ggml_tensor * llm_build_gemma4_iswa::build_inp_per_layer() { + auto inp = std::make_unique(n_embd); + + ggml_tensor * inp_per_layer; + float tok_embd_scale = sqrtf((float) n_embd_per_layer); + if (ubatch.token) { + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + ggml_set_input(inp->tokens); + res->t_inp_tokens = inp->tokens; + + inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens); + inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, n_tokens); + inp_per_layer = ggml_scale (ctx0, inp_per_layer, tok_embd_scale); + cb(inp_per_layer, "inp_per_layer_selected", -1); + + res->add_input(std::move(inp)); + } else { + // Multimodal embedding path: use padding token (ID=0) embedding + // TODO: verify if this is the correct behavior in transformers implementation + const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_per_layer * n_layer + + // Extract and dequantize padding token embedding (row 0) + ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0); + inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, tok_embd_scale); + + // Reshape to [n_embd_per_layer, n_layer, 1] + inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, 1); + cb(inp_per_layer, "inp_per_layer_multimodal", -1); + } + return inp_per_layer; +} + +// equivalent to project_per_layer_inputs() in python code +// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim +// inp_batch shape: [n_embd, n_tokens] +// inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from build_inp_per_layer) +// output shape: [n_embd_per_layer, n_tokens, n_layer] +ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { + const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); + const float per_layer_input_scale = 1.0f / sqrtf(2.0f); + + // note: this matrix multiplication will be performed in the input layer (i.e. on the CPU) + ggml_tensor * per_layer_proj; + per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch); + per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale); + per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_per_layer, n_layer, n_tokens); + + per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, nullptr, LLM_NORM_RMS, -1); + cb(per_layer_proj, "per_layer_proj", -1); + + inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale); + cb(inp_per_layer, "inp_per_layer", -1); + + // permute to shape: [n_embd_per_layer, n_tokens, n_layer] + inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3)); + return inp_per_layer; +} diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 7938545ed8a..8d4f4a01553 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -38,27 +38,8 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // Apply Q/K norm if available (GLM-4.5 355B variant) if (model.layers[il].attn_q_norm) { @@ -94,7 +75,7 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_transformer_layers - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index b6ad8febed3..f0bfda393fa 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -1,10 +1,7 @@ #include "models.h" - - llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -41,40 +38,8 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv == nullptr) { - Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } else { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], - 0 * sizeof(float) * (n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_mrope) { Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, @@ -100,7 +65,7 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_transformer_layers - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index cb1238f2d34..f8dc53eb723 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -2,7 +2,6 @@ llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -34,22 +33,11 @@ llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 1c8fe6c836d..0016ddede43 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -1,9 +1,7 @@ #include "models.h" - llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -28,15 +26,8 @@ llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_ // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -55,7 +46,7 @@ llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index 9b54a38c386..e983742bef5 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -73,31 +73,7 @@ ggml_tensor * llm_build_granite_hybrid::build_attention_layer(ggml_tensor * const llama_model & model, const int64_t n_embd_head, const int il) { - // compute Q and K and (optionally) RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, hparams.n_head(il), hparams.n_head_kv(il), il); const bool use_rope = hparams.rope_finetuned; if (use_rope) { @@ -116,7 +92,7 @@ ggml_tensor * llm_build_granite_hybrid::build_attention_layer(ggml_tensor * const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index 7a7e1664c29..6ea90285225 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -76,31 +76,8 @@ ggml_tensor * llm_build_granite::build_attention_layer( const int64_t n_embd_head, const int il) { - // compute Q and K and (optionally) RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, hparams.n_head(il), hparams.n_head_kv(il), il); const bool use_rope = hparams.rope_finetuned; if (use_rope) { @@ -124,7 +101,7 @@ ggml_tensor * llm_build_granite::build_attention_layer( const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index 580d63e36ae..b8f35afdc03 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -30,27 +30,8 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index aa60d3e9388..151108a2a71 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -30,18 +30,8 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -60,7 +50,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/hunyuan-dense.cpp b/examples/talk-llama/models/hunyuan-dense.cpp index 6a51707c85b..1cd85d6d9d4 100644 --- a/examples/talk-llama/models/hunyuan-dense.cpp +++ b/examples/talk-llama/models/hunyuan-dense.cpp @@ -6,6 +6,11 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); GGML_ASSERT(n_embd_head == n_rot); + const bool use_mrope = hparams.use_mrope(); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + ggml_tensor * cur; ggml_tensor * inpL; @@ -34,44 +39,39 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + if (use_mrope) { + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); @@ -83,7 +83,7 @@ llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, cons cb(Qcur, "Qcur_norm", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index 806c30b3667..ffe1664b0e1 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -35,27 +35,8 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -84,7 +65,7 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll cb(Qcur, "Qcur_norm", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index 441d250268e..83be2ca0aee 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -30,27 +30,8 @@ llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index 135bf288ba1..31101f3c14b 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -2,7 +2,6 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -24,22 +23,11 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*cur->nb[0]*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/float(n_embd_head), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -66,8 +54,14 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = build_norm(inpL, model.output_norm, diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp index 2cfe484eb52..507e04fa4aa 100644 --- a/examples/talk-llama/models/jais2.cpp +++ b/examples/talk-llama/models/jais2.cpp @@ -31,25 +31,8 @@ llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_para // Self-attention with separate Q, K, V projections { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur_bias", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur_bias", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur_bias", il); - - // Reshape for attention - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // Apply RoPE Qcur = ggml_rope_ext( @@ -68,7 +51,7 @@ llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_para cb(Kcur, "Kcur_rope", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index c0c89de187a..f82b7795c87 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -24,25 +24,12 @@ llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_para } else { // Attention - struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // No RoPE :) cur = build_attn(inp_hybrid->get_attn(), - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp index 4d62f4e7159..58c89c417fc 100644 --- a/examples/talk-llama/models/kimi-linear.cpp +++ b/examples/talk-llama/models/kimi-linear.cpp @@ -268,7 +268,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_tensor * Vcur = kv_cmpr; cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn_k, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il); + cur = build_attn(inp_attn_k, layer.wo, NULL, layer.wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il); cb(cur, "mla_out", il); } else { // MLA KV cache disabled. Fall back to MHA KV cache. Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens); @@ -299,7 +299,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll // Direct softmax attention (with MHA KV cache) // Use build_attn with inp_attn for proper mask handling - cur = build_attn(inp_attn_kv, layer.wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il); + cur = build_attn(inp_attn_kv, layer.wo, NULL, layer.wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il); cb(cur, "mla_out", il); } } @@ -362,6 +362,7 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll cur = build_cvec(cur, il); cb(cur, "l_out", il); + // input for next layer inpL = cur; } cur = inpL; diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index dfa322166b1..eb8ec3c803a 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -42,16 +42,8 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_ const auto n_embd_head = hparams.n_embd_head_v(); const auto n_head_kv = hparams.n_head_kv(il); - auto * q = build_lora_mm(model.layers[il].wq, cur); - cb(q, "model.layers.{}.self_attn.q_proj", il); - auto * k = build_lora_mm(model.layers[il].wk, cur); - cb(k, "model.layers.{}.self_attn.k_proj", il); - auto * v = build_lora_mm(model.layers[il].wv, cur); - cb(v, "model.layers.{}.self_attn.v_proj", il); - - q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens); - k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens); - v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens); + auto [q, k, v] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // qk norm q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); @@ -66,7 +58,7 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_ attn_factor, beta_fast, beta_slow); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "model.layers.{}.self_attn.out_proj", il); @@ -177,6 +169,9 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_ cb(ffn_norm_out, "model.layers.{}.ffn_out", il); cur = ggml_add(ctx0, cur, ffn_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); } cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index 18de88fde1f..c756d6fde5f 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -30,18 +30,8 @@ llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_gr // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -66,7 +56,7 @@ llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index 0dac9d616ae..501df3c7eaf 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -30,17 +30,8 @@ llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_para // self-attention { // compute separate Q, K, V projections without bias, matching LLaDALlamaBlock - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -53,7 +44,7 @@ llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index e08ae0c0b0e..8d478dc6747 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -43,27 +43,8 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -89,11 +70,8 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra cb(Kcur, "Kcur_normed", il); } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); - if (model.layers[il].wo_s) { - cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); - } cb(cur, "attn_out", il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/llama-iswa.cpp b/examples/talk-llama/models/llama4.cpp similarity index 81% rename from examples/talk-llama/models/llama-iswa.cpp rename to examples/talk-llama/models/llama4.cpp index 67cb9a10ec5..4e4bfb43f33 100644 --- a/examples/talk-llama/models/llama-iswa.cpp +++ b/examples/talk-llama/models/llama4.cpp @@ -1,6 +1,7 @@ #include "models.h" -llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +template +llm_build_llama4::llm_build_llama4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -18,7 +19,14 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ ggml_tensor * inp_attn_scale = nullptr; inp_attn_scale = build_inp_attn_scale(); - auto * inp_attn = build_attn_inp_kv_iswa(); + using inp_attn_type = std::conditional_t; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; @@ -46,27 +54,8 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_rope) { Qcur = ggml_rope_ext( @@ -95,7 +84,7 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ cb(Kcur, "Kcur_normed", il); } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -176,3 +165,7 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ ggml_build_forward_expand(gf, cur); } + +// Explicit template instantiations +template struct llm_build_llama4; +template struct llm_build_llama4; diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index a72b7790a1f..8a76931c007 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -30,18 +30,8 @@ llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -66,7 +56,7 @@ llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/mamba-base.cpp b/examples/talk-llama/models/mamba-base.cpp index 9de587db55f..c37f29c487e 100644 --- a/examples/talk-llama/models/mamba-base.cpp +++ b/examples/talk-llama/models/mamba-base.cpp @@ -42,7 +42,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp, cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} - ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur); + ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur, layer.ssm_in_s); // split the above in two // => {d_inner, n_seq_tokens, n_seqs} ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); @@ -137,7 +137,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp, y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(layer.ssm_out, y); + cur = build_lora_mm(layer.ssm_out, y, layer.ssm_out_s); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} @@ -184,7 +184,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp, // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} - ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); + ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur, model.layers[il].ssm_in_s); // split the above in three ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * zxBCdt->nb[0], @@ -278,7 +278,7 @@ ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp, y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(model.layers[il].ssm_out, y); + cur = build_lora_mm(model.layers[il].ssm_out, y, model.layers[il].ssm_out_s); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} diff --git a/examples/talk-llama/models/mimo2-iswa.cpp b/examples/talk-llama/models/mimo2-iswa.cpp index 06956915ea0..52c6acfe214 100644 --- a/examples/talk-llama/models/mimo2-iswa.cpp +++ b/examples/talk-llama/models/mimo2-iswa.cpp @@ -58,7 +58,7 @@ llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_ ggml_tensor * sinks = model.layers[il].attn_sinks; cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il); } diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index 89dd7105157..bf12ab73c74 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -134,7 +134,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap cb(k_states, "k_states", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index 83d0916c08c..b809b79f2b9 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -64,7 +64,7 @@ llm_build_minimax_m2::llm_build_minimax_m2(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 42a5117ff02..b5ae72a2ee1 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -41,27 +41,8 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -86,7 +67,7 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index a86b2b1ebd7..94991c55fe8 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -256,9 +256,11 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params); ggml_tensor * calc_magnitude(ggml_tensor * x); - ggml_tensor * view_2d_slice(ggml_tensor * x, int idx); - ggml_tensor * get_per_layer_inputs(); - ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer); + + // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] + ggml_tensor * build_inp_per_layer(); + ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); + ggml_tensor * gaussian_topk(ggml_tensor * x); ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il); ggml_tensor * altup_predict(ggml_tensor * cur, int il); @@ -266,6 +268,18 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il); }; +struct llm_build_gemma4_iswa : public llm_graph_context { + const llama_model & model; + + const int64_t n_embd_per_layer; + + llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params); + + // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] + ggml_tensor * build_inp_per_layer(); + ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); +}; + struct llm_build_gemma_embedding : public llm_graph_context { llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params); }; @@ -393,8 +407,9 @@ struct llm_build_llama : public llm_graph_context { llm_build_llama(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_llama_iswa : public llm_graph_context { - llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params); +template +struct llm_build_llama4 : public llm_graph_context { + llm_build_llama4(const llama_model & model, const llm_graph_params & params); }; struct llm_build_maincoder : public llm_graph_context { @@ -481,7 +496,7 @@ struct llm_build_phi2 : public llm_graph_context { llm_build_phi2(const llama_model & model, const llm_graph_params & params); }; -template +template struct llm_build_phi3 : public llm_graph_context { llm_build_phi3(const llama_model & model, const llm_graph_params & params); }; @@ -687,12 +702,13 @@ struct llm_build_step35_iswa : public llm_graph_context { llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_t5_dec : public llm_graph_context { - llm_build_t5_dec(const llama_model & model, const llm_graph_params & params); +template +struct llm_build_t5 : public llm_graph_context { + llm_build_t5(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_t5_enc : public llm_graph_context { - llm_build_t5_enc(const llama_model & model, const llm_graph_params & params); +struct llm_build_t5encoder : public llm_build_t5 { + llm_build_t5encoder(const llama_model & model, const llm_graph_params & params); }; struct llm_build_wavtokenizer_dec : public llm_graph_context { diff --git a/examples/talk-llama/models/modern-bert.cpp b/examples/talk-llama/models/modern-bert.cpp index 26020584c6d..5c6a1b5e1bc 100644 --- a/examples/talk-llama/models/modern-bert.cpp +++ b/examples/talk-llama/models/modern-bert.cpp @@ -2,7 +2,6 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -15,8 +14,8 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll cb(inpL, "inp_embd", -1); // embed layer norm - inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, 0); + cb(inpL, "inp_norm", 0); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -37,14 +36,8 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll } // self attention - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - const size_t type_size = ggml_type_size(cur->type); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*type_size, cur->nb[1], 0*type_size*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // RoPE Qcur = ggml_rope_ext( @@ -64,7 +57,7 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index ce44a805f5c..8596bbb2024 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -1,10 +1,7 @@ #include "models.h" - - llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -38,25 +35,8 @@ llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & { cur = attn_norm; - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - - if (hparams.f_clamp_kqv > 0.0f) { - cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(cur, "wqkv_clamped", il); - } - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 0 * sizeof(float) * (n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // Q/K Layernorm if (model.layers[il].attn_q_norm) { @@ -76,7 +56,7 @@ llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index 7af99174d16..dc07d43df58 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -65,40 +65,12 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * const llama_model & model, int64_t n_embd_head, int il) { - // compute Q and K - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, hparams.n_head(il), hparams.n_head_kv(il), il); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; @@ -107,9 +79,9 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) { if (model.layers[il].ffn_gate_inp == nullptr) { cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s, NULL, NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s, NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); @@ -136,7 +108,10 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, il, - router_logits); + router_logits, nullptr, + model.layers[il].ffn_up_exps_s, + nullptr, // no gate + model.layers[il].ffn_down_exps_s); cb(moe_out, "ffn_moe_out", il); if (model.layers[il].ffn_latent_up) { @@ -144,9 +119,9 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla } ggml_tensor * ffn_shexp = build_ffn(inp_emb, - model.layers[il].ffn_up_shexp, NULL, NULL, - NULL /* no gate */ , NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, + model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s, + NULL /* no gate */ , NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s, NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); cb(ffn_shexp, "ffn_shexp", il); diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index 34aa6fa5ec4..054b16fe0ef 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -31,27 +31,8 @@ llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_grap // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -70,7 +51,7 @@ llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/neo-bert.cpp b/examples/talk-llama/models/neo-bert.cpp index 2fdf4a3692f..da68024a34d 100644 --- a/examples/talk-llama/models/neo-bert.cpp +++ b/examples/talk-llama/models/neo-bert.cpp @@ -2,7 +2,6 @@ llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -27,17 +26,8 @@ llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_grap LLM_NORM_RMS, il); { - ggml_tensor * Qcur; - ggml_tensor * Kcur; - ggml_tensor * Vcur; - - // self-attention - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // RoPE Qcur = ggml_rope_ext( @@ -57,7 +47,7 @@ llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index 26f4b6ee628..a9974025f07 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -30,27 +30,8 @@ llm_build_olmo::llm_build_olmo(const llama_model & model, const llm_graph_params // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (hparams.f_clamp_kqv > 0.0f) { - Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (hparams.f_clamp_kqv > 0.0f) { - Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (hparams.f_clamp_kqv > 0.0f) { - Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_olmo::llm_build_olmo(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 5076359e3f9..308d2a600c2 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -89,7 +89,7 @@ llm_build_olmo2::llm_build_olmo2(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index 83a56a0b3b6..ed46a00ef90 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -68,7 +68,7 @@ llm_build_olmoe::llm_build_olmoe(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/openai-moe-iswa.cpp b/examples/talk-llama/models/openai-moe-iswa.cpp index 403f130bc41..50992b8d506 100644 --- a/examples/talk-llama/models/openai-moe-iswa.cpp +++ b/examples/talk-llama/models/openai-moe-iswa.cpp @@ -28,27 +28,8 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_rot, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -67,7 +48,7 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, model.layers[il].attn_sinks, nullptr, 1.0f/sqrtf(float(n_rot)), il); cb(cur, "attn_out", il); diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index 5df6fe3e3ce..514ac33517f 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -73,7 +73,7 @@ llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_ cb(Qcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index 48c01efe368..a5874b6dee7 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -30,30 +30,8 @@ llm_build_orion::llm_build_orion(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - // if (model.layers[il].bq) { - // Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - // cb(Qcur, "Qcur", il); - // } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - // if (model.layers[il].bk) { - // Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - // cb(Kcur, "Kcur", il); - // } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - // if (model.layers[il].bv) { - // Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - // cb(Vcur, "Vcur", il); - // } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -72,7 +50,7 @@ llm_build_orion::llm_build_orion(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/paddleocr.cpp b/examples/talk-llama/models/paddleocr.cpp index 340455c2d5f..56cb1d94c5f 100644 --- a/examples/talk-llama/models/paddleocr.cpp +++ b/examples/talk-llama/models/paddleocr.cpp @@ -35,27 +35,8 @@ llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_gr } // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_multi( ctx0, Qcur, inp_pos, nullptr, @@ -74,7 +55,7 @@ llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { diff --git a/examples/talk-llama/models/pangu-embedded.cpp b/examples/talk-llama/models/pangu-embedded.cpp index 1cf0938e68f..53464f21d22 100644 --- a/examples/talk-llama/models/pangu-embedded.cpp +++ b/examples/talk-llama/models/pangu-embedded.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -31,21 +30,8 @@ llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, co // self attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -63,7 +49,7 @@ llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, co cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index 32d40d71fb7..0fb3ffa2e63 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -1,9 +1,7 @@ #include "models.h" - llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -30,29 +28,8 @@ llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv) { - cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - } else { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], attn_norm_output, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -74,7 +51,7 @@ llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index 3d11a9459c4..39af285d3c5 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -3,7 +3,6 @@ template llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -39,27 +38,8 @@ llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_ LLM_NORM_RMS, il); cb(attn_norm_output, "attn_norm", il); - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv) { - cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); - cb(cur, "wqkv", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); - } - else { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], attn_norm_output, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -80,7 +60,7 @@ llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_ cb(Qcur, "Qcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index b7a71211042..4d5c84506c2 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -30,18 +30,8 @@ llm_build_plamo::llm_build_plamo(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -60,7 +50,7 @@ llm_build_plamo::llm_build_plamo(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index f02acbc1869..b6142daebd9 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -71,6 +71,7 @@ llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, residual); cb(cur, "ffn_residual", il); + // input for next layer inpL = cur; } @@ -140,7 +141,7 @@ ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv ext_factor, attn_factor, beta_fast, beta_slow); cur = build_attn(inp, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f / sqrtf(float(n_embd_head_v)), il); } diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 32af6e04663..67844c09f24 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -73,7 +73,7 @@ llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_gr const float attn_scale = 1.0f / sqrtf(float(head_dim_q)); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, attn_scale, il); cb(cur, "attn_out", il); @@ -109,6 +109,8 @@ llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_gr cur = build_cvec(cur, il); cb(cur, "l_out", il); + + // input for next layer inpL = cur; } diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index bcb651ce543..abce6b34d04 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -120,7 +120,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & cb(k_states, "k_states", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index 7390f1320bf..44e75d87437 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -1,6 +1,5 @@ #include "models.h" - llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -28,15 +27,8 @@ llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 2*sizeof(float)*(n_embd)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -56,7 +48,7 @@ llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index 58c10622508..2892dd75087 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -30,30 +30,8 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -72,7 +50,7 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 60761789dc9..5f0a6861b68 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -30,27 +30,8 @@ llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_grap // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen2vl.cpp b/examples/talk-llama/models/qwen2vl.cpp index 9004bab9db1..da7937c7667 100644 --- a/examples/talk-llama/models/qwen2vl.cpp +++ b/examples/talk-llama/models/qwen2vl.cpp @@ -33,21 +33,8 @@ llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_multi( ctx0, Qcur, inp_pos, nullptr, @@ -66,7 +53,7 @@ llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index 52081668477..883dd5f9a90 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -30,18 +30,8 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -66,11 +56,8 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); - if (model.layers[il].wo_s) { - cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); - } } if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index 3108bf331ac..87790f08e4e 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -64,6 +64,9 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "post_ffn", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + // Input for next layer inpL = cur; } @@ -176,7 +179,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp, - nullptr, nullptr, + nullptr, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_pregate", il); @@ -222,9 +225,10 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(beta, "beta", il); beta = ggml_sigmoid(ctx0, beta); + cb(beta, "beta_sigmoid", il); ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s); - alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + alpha = ggml_reshape_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); cb(alpha, "alpha", il); ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); @@ -266,7 +270,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(last_conv_states, "last_conv_states", il); ggml_tensor * state_update_target = - ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); cb(state_update_target, "state_update_target", il); @@ -342,7 +346,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index 165e2412e56..7dc6a23c751 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -64,6 +64,9 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "post_moe", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + // Input for next layer inpL = cur; } @@ -176,7 +179,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp, - nullptr, nullptr, + nullptr, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_pregate", il); @@ -222,9 +225,10 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(beta, "beta", il); beta = ggml_sigmoid(ctx0, beta); + cb(beta, "beta_sigmoid", il); ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s); - alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + alpha = ggml_reshape_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); cb(alpha, "alpha", il); ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); @@ -266,7 +270,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(last_conv_states, "last_conv_states", il); ggml_tensor * state_update_target = - ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); cb(state_update_target, "state_update_target", il); @@ -342,7 +346,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index dba46618ff2..16bedba994d 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -30,18 +30,8 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -66,11 +56,8 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); - if (model.layers[il].wo_s) { - cur = ggml_mul(ctx0, cur, model.layers[il].wo_s); - } } if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index cc479dd075c..1beda70b7cf 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -56,6 +56,9 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "post_moe", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + // Input for next layer inpL = cur; } @@ -154,7 +157,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp, - nullptr, nullptr, + nullptr, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_pregate", il); @@ -169,7 +172,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( cur = ggml_mul(ctx0, cur, gate); cb(cur, "attn_gated", il); - cur = build_lora_mm(model.layers[il].wo, cur); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); cb(cur, "attn_output", il); return cur; @@ -351,7 +354,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(last_conv_states, "last_conv_states", il); ggml_tensor * state_update_target = - ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); cb(state_update_target, "state_update_target", il); @@ -411,19 +414,19 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( GGML_ASSERT(num_v_heads % num_k_heads == 0); int64_t repeat_factor = num_v_heads / num_k_heads; - // repeat interleave: reshape to (repeat part, 1, remaining part), do repeat, then reshape back - ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); - ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); + // repeat interleave: reshape to (repeat part, 1, remaining part...), do repeat, then reshape back + ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_conv, head_k_dim, 1, num_k_heads, n_seq_tokens * n_seqs); + ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_conv, head_k_dim, 1, num_k_heads, n_seq_tokens * n_seqs); // Repeat along the third dimension (the new dimension with size 1) ggml_tensor * q_repeated = - ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads, n_seq_tokens * n_seqs); ggml_tensor * k_repeated = - ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads, n_seq_tokens * n_seqs); // Reshape back to merge the head and repeat dimensions - // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs] - // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs] + // From [head_dim, repeat_factor, num_k_heads, n_seq_tokens * n_seqs] + // Back to [head_dim, repeat_factor * num_k_heads, n_seq_tokens, n_seqs] q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); } @@ -442,7 +445,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] diff --git a/examples/talk-llama/models/qwen3vl-moe.cpp b/examples/talk-llama/models/qwen3vl-moe.cpp index 195daea66c9..29ee8278a4d 100644 --- a/examples/talk-llama/models/qwen3vl-moe.cpp +++ b/examples/talk-llama/models/qwen3vl-moe.cpp @@ -36,18 +36,8 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -72,7 +62,7 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index bbd5f42ba5b..faa5f2ef3c8 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -36,18 +36,8 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -72,7 +62,7 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index 140700d9e2d..398eb368db0 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -24,25 +24,15 @@ llm_build_refact::llm_build_refact(const llama_model & model, const llm_graph_pa // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index c8e1f43400f..a917c19f25a 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -32,18 +32,8 @@ llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -68,7 +58,7 @@ llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/rwkv6.cpp b/examples/talk-llama/models/rwkv6.cpp index 15453fbf50f..032b219d6cb 100644 --- a/examples/talk-llama/models/rwkv6.cpp +++ b/examples/talk-llama/models/rwkv6.cpp @@ -8,7 +8,7 @@ llm_build_rwkv6::llm_build_rwkv6(const llama_model & model, const llm_graph_para ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); - inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0); auto * rs_inp = build_rs_inp(); diff --git a/examples/talk-llama/models/rwkv7.cpp b/examples/talk-llama/models/rwkv7.cpp index 5caf6553dfe..16ffa6901b9 100644 --- a/examples/talk-llama/models/rwkv7.cpp +++ b/examples/talk-llama/models/rwkv7.cpp @@ -9,7 +9,7 @@ llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_para ggml_tensor * v_first = nullptr; inpL = build_inp_embd(model.tok_embd); - inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0); auto * rs_inp = build_rs_inp(); diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index a4d0b75d846..6db8d9781fe 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -32,27 +32,8 @@ llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_grap // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -71,7 +52,7 @@ llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index e2155aacef4..55d09ec325d 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -45,18 +45,8 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, // self_attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_rope) { Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, @@ -69,7 +59,7 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, cb(Kcur, "Kcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -101,6 +91,7 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, cur = ffn_out; cur = ggml_add(ctx0, cur, ffn_inp); + cur = build_cvec(cur, il); cb(cur, "l_out", il); diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index e267fd8f32f..83636dbf546 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -34,27 +34,8 @@ llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_rope) { Qcur = ggml_rope_ext( @@ -74,7 +55,7 @@ llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index ff5aced93b3..9c19abd8835 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -30,30 +30,8 @@ llm_build_stablelm::llm_build_stablelm(const llama_model & model, const llm_grap // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { Qcur = build_norm(Qcur, @@ -87,7 +65,7 @@ llm_build_stablelm::llm_build_stablelm(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index 941cee98219..cf9fe95c35b 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -2,7 +2,6 @@ llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -33,22 +32,11 @@ llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_gr // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index a5965aceb3b..b6d4d5aac1a 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -30,27 +30,8 @@ llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +50,7 @@ llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/models/step35-iswa.cpp b/examples/talk-llama/models/step35-iswa.cpp index 176209cd93e..86aa98909e7 100644 --- a/examples/talk-llama/models/step35-iswa.cpp +++ b/examples/talk-llama/models/step35-iswa.cpp @@ -68,7 +68,7 @@ llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const ll const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k)); ggml_tensor * attn_out = build_attn(inp_attn, - nullptr, nullptr, + nullptr, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(attn_out, "attn_out", il); // head-wise attention gate: sigmoid(g_proj(x)) in torch @@ -92,7 +92,7 @@ llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const ll } // output projection - cur = build_lora_mm(model.layers[il].wo, attn_out); + cur = build_lora_mm(model.layers[il].wo, attn_out, model.layers[il].wo_s); cb(cur, "attn_proj", il); } @@ -145,9 +145,11 @@ llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const ll cb(cur, "ffn_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); + cur = build_cvec(cur, il); cb(cur, "l_out", il); + // input for next layer inpL = cur; } diff --git a/examples/talk-llama/models/t5-enc.cpp b/examples/talk-llama/models/t5-enc.cpp deleted file mode 100644 index 395dfb51042..00000000000 --- a/examples/talk-llama/models/t5-enc.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include "models.h" - -llm_build_t5_enc::llm_build_t5_enc(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v(); - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - ggml_tensor * pos_bucket_enc = build_inp_pos_bucket_enc(); - - auto * inp_attn = build_attn_inp_no_cache(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm_enc, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self-attention - { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; - ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b); - - cur = build_attn(inp_attn, - model.layers[il].wo_enc, nullptr, - Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); - cb(cur, "kqv_out", il); - } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - { - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm_enc, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - // T5 uses relu, flan-T5 uses gelu-gated - cur = build_ffn(cur, - model.layers[il].ffn_up_enc, NULL, NULL, - model.layers[il].ffn_gate_enc, NULL, NULL, - model.layers[il].ffn_down_enc, NULL, NULL, - NULL, - model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, - model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, - il); - cb(cur, "ffn_out", il); - } - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - cur = inpL; - cb(cur, "result_embd", -1); - - cur = build_norm(cur, - model.output_norm_enc, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - ggml_build_forward_expand(gf, cur); -} diff --git a/examples/talk-llama/models/t5-dec.cpp b/examples/talk-llama/models/t5.cpp similarity index 64% rename from examples/talk-llama/models/t5-dec.cpp rename to examples/talk-llama/models/t5.cpp index 8ca8372bd4c..9f9dfef4012 100644 --- a/examples/talk-llama/models/t5-dec.cpp +++ b/examples/talk-llama/models/t5.cpp @@ -1,6 +1,7 @@ #include "models.h" -llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +template <> +llm_build_t5::llm_build_t5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); //const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -34,24 +35,13 @@ llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_pa // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il); ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b; ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b); cur = build_attn(inp_attn_self, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); } @@ -82,7 +72,7 @@ llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_pa Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc); cur = build_attn(inp_attn_cross, - model.layers[il].wo_cross, nullptr, + model.layers[il].wo_cross, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); @@ -164,3 +154,99 @@ llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_pa ggml_build_forward_expand(gf, cur); } + +template <> +llm_build_t5::llm_build_t5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * pos_bucket_enc = build_inp_pos_bucket_enc(); + + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm_enc, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; + ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b); + + cur = build_attn(inp_attn, + model.layers[il].wo_enc, nullptr, nullptr, + Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); + cb(cur, "kqv_out", il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm_enc, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // T5 uses relu, flan-T5 uses gelu-gated + cur = build_ffn(cur, + model.layers[il].ffn_up_enc, NULL, NULL, + model.layers[il].ffn_gate_enc, NULL, NULL, + model.layers[il].ffn_down_enc, NULL, NULL, + NULL, + model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, + il); + cb(cur, "ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + cb(cur, "result_embd", -1); + + cur = build_norm(cur, + model.output_norm_enc, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/t5encoder.cpp b/examples/talk-llama/models/t5encoder.cpp new file mode 100644 index 00000000000..5c1f9eb4030 --- /dev/null +++ b/examples/talk-llama/models/t5encoder.cpp @@ -0,0 +1,3 @@ +#include "models.h" + +llm_build_t5encoder::llm_build_t5encoder(const llama_model & model, const llm_graph_params & params) : llm_build_t5(model, params) {} diff --git a/examples/talk-llama/models/wavtokenizer-dec.cpp b/examples/talk-llama/models/wavtokenizer-dec.cpp index 537a0d41248..a7776d9cdc9 100644 --- a/examples/talk-llama/models/wavtokenizer-dec.cpp +++ b/examples/talk-llama/models/wavtokenizer-dec.cpp @@ -93,7 +93,7 @@ llm_build_wavtokenizer_dec::llm_build_wavtokenizer_dec(const llama_model & model cur = build_norm(cur, model.tok_norm, model.tok_norm_b, - LLM_NORM, -1); + LLM_NORM, 0); cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index 3a8dfafcceb..53085ec80f6 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -28,18 +28,8 @@ llm_build_xverse::llm_build_xverse(const llama_model & model, const llm_graph_pa // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -58,7 +48,7 @@ llm_build_xverse::llm_build_xverse(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index 122c8ca04a5..dc13e53f09f 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -470,6 +470,141 @@ static std::vector unicode_regex_split_custom_llama3(const std::string & return bpe_offsets; } +// Qwen2 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" +static std::vector unicode_regex_split_custom_qwen2(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&] (const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + //if (len > 0) { + // std::string s = ""; + // for(size_t p = end-len; p < end; p++) + // s += unicode_cpt_to_utf8(cpts[p]); + // printf(">>> '%s'\n", s.c_str()); + //} + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive + if (cpt == '\'' && pos+1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos+2); + continue; + } + if (pos+2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos+3); + continue; + } + } + } + + // regex: [^\r\n\p{L}\p{N}]?\p{L}+ + if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) { + if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters + pos++; + while (_get_flags(pos).is_letter) { + pos++; + } + _add_token(pos); + continue; + } + } + + // regex: \p{N} + if (flags.is_number) { + pos++; + _add_token(pos); + continue; + } + + // regex: ?[^\s\p{L}\p{N}]+[\r\n]* + auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags); + if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos+num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos+num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // regex: \s*[\r\n]+ + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // regex: \s+(?!\S) + if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // regex: \s+ + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // no matches + _add_token(++pos); + } + } + + return bpe_offsets; +} + template static std::vector unicode_regex_split_stl(const std::basic_string & text, const std::basic_string & regex, const std::vector & offsets) { using BidirIt = typename std::basic_string::const_iterator; @@ -753,6 +888,35 @@ static std::vector unicode_regex_split_custom_afmoe(const std::string & return bpe_offsets; } +// regex: [^\n]+|[\n]+ +// splits text into runs of non-newline characters and runs of newline characters +static std::vector unicode_regex_split_custom_newlines(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; + bpe_offsets.reserve(offsets.size()); + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + size_t pos = offset_ini; + while (pos < offset_end) { + const bool is_newline = (cpts[pos] == '\n'); + const size_t run_start = pos; + while (pos < offset_end && (cpts[pos] == '\n') == is_newline) { + pos++; + } + bpe_offsets.push_back(pos - run_start); + } + } + + return bpe_offsets; +} + static std::vector unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { std::vector bpe_offsets; @@ -761,14 +925,18 @@ static std::vector unicode_regex_split_custom(const std::string & text, } else if ( regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" || regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { - bpe_offsets = unicode_regex_split_custom_llama3(text, offsets); + } else if ( + regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { + bpe_offsets = unicode_regex_split_custom_qwen2(text, offsets); } else if (regex_expr == "\\p{Han}+") { // K2's first pattern - handle all K2 patterns together bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets); } else if (regex_expr == "\\p{AFMoE_digits}") { // AFMOE digit pattern - use custom implementation for proper splitting bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets); + } else if (regex_expr == "[^\\n]+|[\\n]+") { + bpe_offsets = unicode_regex_split_custom_newlines(text, offsets); } else if (regex_expr == "\\d{1,3}(?=(?:\\d{3})*\\b)") { // tiny_aya digit grouping pattern from tokenizer.json: // {"type": "Split", "pattern": {"Regex": "\\d{1,3}(?=(?:\\d{3})*\\b)"}, "behavior": "Isolated"} @@ -912,7 +1080,7 @@ bool unicode_cpt_is_han(uint32_t cpt) { return false; } -std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { +std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs, bool byte_encode) { // unicode categories static const std::map k_ucat_enum = { { "\\p{N}", unicode_cpt_flags::NUMBER }, @@ -1099,5 +1267,9 @@ std::vector unicode_regex_split(const std::string & text, const std start += offset; } - return unicode_byte_encoding_process(bpe_words); + if (byte_encode) { + return unicode_byte_encoding_process(bpe_words); + } + + return bpe_words; } diff --git a/examples/talk-llama/unicode.h b/examples/talk-llama/unicode.h index 5bd1362ff41..600ab9216b9 100644 --- a/examples/talk-llama/unicode.h +++ b/examples/talk-llama/unicode.h @@ -108,4 +108,4 @@ uint32_t unicode_tolower(uint32_t cpt); bool unicode_cpt_is_han(uint32_t cpt); -std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs); +std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs, bool byte_encode = true); From c81b2dabbc45484dee2ca6658cfe39c841df5c70 Mon Sep 17 00:00:00 2001 From: KITAITI Makoto Date: Thu, 7 May 2026 13:28:18 +0900 Subject: [PATCH 543/831] ruby : transcribe without GVL, accept more MemoryViews, Windows support, fix memory size report, improve document (#3775) * Change MemoryView example using NDAV * Add note on audio attributes for #full and #full_parallel * Support more variants of MemoryView * Use IO.popen instead of Kernel.` for Windows compatibility * Use cmake's -C option instead of multiple -D options * Fix memsize calculation * Remove unused argument * Add is_interrupted field to abort callback container * Fix RBS syntax * Address document comment for RDoc * Add .document for RDoc * Add .rdoc_options * Run #full without GVL * Initialize callbacks with nil * Specify implicity Whisper::Params to distinguish from Whisper::Context::Params * Run callbacks without GVL * Call log callback with GVL * Run full_parallel without GVL * Run transcribe without GVL * Fix ruby_whisper_lock_gvl and ruby_whisper_unlock_gvl * Fix return value of encoder_begin_callback * Report GVL unlocking from transcribe * Remove unused interface * Restore overload of full_parallel * Close process * Fix struct name * Make is_without_gvl thread local * Use rb_thread_call_with_gvl instead of global variable * Retrieve instance variable in GVL * Narrow acceptable MemoryView format * Fix option cache path * Reduce files in package * Use append_cflags * Add ext/*.rb to task dependencies * Use copy instead of cp * Make TestPackage more portable * Patch for lower version Ruby * Make build scripts more portable * Add Windows support * Don't raise exceptions --- bindings/ruby/.document | 3 + bindings/ruby/.rdoc_options | 2 + bindings/ruby/README.md | 10 +- bindings/ruby/Rakefile | 4 +- bindings/ruby/ext/dependencies.rb | 14 +- bindings/ruby/ext/dependencies_for_windows.rb | 17 ++ bindings/ruby/ext/extconf.rb | 28 +- bindings/ruby/ext/options.rb | 68 ++++- bindings/ruby/ext/options_for_windows.rb | 51 ++++ bindings/ruby/ext/ruby_whisper.c | 36 ++- bindings/ruby/ext/ruby_whisper.h | 17 +- bindings/ruby/ext/ruby_whisper_context.c | 102 ++++++- bindings/ruby/ext/ruby_whisper_params.c | 261 +++++++++++++++--- bindings/ruby/ext/ruby_whisper_transcribe.cpp | 47 +++- bindings/ruby/extsources.rb | 36 ++- bindings/ruby/sig/whisper.rbs | 100 +++---- bindings/ruby/test/test_package.rb | 11 +- 17 files changed, 647 insertions(+), 160 deletions(-) create mode 100644 bindings/ruby/.document create mode 100644 bindings/ruby/.rdoc_options create mode 100644 bindings/ruby/ext/dependencies_for_windows.rb create mode 100644 bindings/ruby/ext/options_for_windows.rb diff --git a/bindings/ruby/.document b/bindings/ruby/.document new file mode 100644 index 00000000000..a8e9788fc7c --- /dev/null +++ b/bindings/ruby/.document @@ -0,0 +1,3 @@ +README.md +LICENSE +sig diff --git a/bindings/ruby/.rdoc_options b/bindings/ruby/.rdoc_options new file mode 100644 index 00000000000..cf14aa5f5b4 --- /dev/null +++ b/bindings/ruby/.rdoc_options @@ -0,0 +1,2 @@ +title: whispercpp +main_page: README.md diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 41e7b330d58..07b81830c58 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -360,7 +360,7 @@ Whisper::Context.new("base") ### Low-level API to transcribe ### -You can also call `Whisper::Context#full` and `#full_parallel` with a Ruby array as samples. Although `#transcribe` with audio file path is recommended because it extracts PCM samples in C++ and is fast, `#full` and `#full_parallel` give you flexibility. +You can also call `Whisper::Context#full` and `#full_parallel` with a Ruby array as samples. Although `#transcribe` with audio file path is recommended because it extracts PCM samples in C++ and is fast, `#full` and `#full_parallel` give you flexibility. Unlike `#transcribe`, these methods requires 16,000 Hz, 32-bit float audio. ```ruby require "whisper" @@ -383,16 +383,16 @@ If you can prepare audio data as C array and export it as a MemoryView, whisperc ```ruby require "torchaudio" -require "arrow-numo-narray" +require "ndav/torch/tensor" require "whisper" waveform, sample_rate = TorchAudio.load("test/fixtures/jfk.wav") -# Convert Torch::Tensor to Arrow::Array via Numo::NArray -samples = waveform.squeeze.numo.to_arrow.to_arrow_array +# Convert Torch::Tensor to NDAV +samples = waveform.squeeze.to_ndav whisper = Whisper::Context.new("base") whisper - # Arrow::Array exports MemoryView + # NDAV exports MemoryView .full(Whisper::Params.new, samples) ``` diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile index d9a66030de4..7b521b3bdfa 100644 --- a/bindings/ruby/Rakefile +++ b/bindings/ruby/Rakefile @@ -16,7 +16,7 @@ EXTSOURCES.each do |src| file src directory dir file dest => [src, dir] do |t| - cp t.source, t.name + copy t.source, t.name end SOURCES.include dest end @@ -34,7 +34,7 @@ LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"]) SO_FILE = File.join("ext", LIB_NAME) LIB_FILE = File.join("lib", LIB_NAME) -file "ext/Makefile" => SRC + ["ext/extconf.rb"] + SOURCES do |t| +file "ext/Makefile" => SRC + SOURCES + FileList["ext/*.rb"] do |t| chdir "ext" do ruby "extconf.rb" end diff --git a/bindings/ruby/ext/dependencies.rb b/bindings/ruby/ext/dependencies.rb index 2ba4b94b62b..b2eb9beb84f 100644 --- a/bindings/ruby/ext/dependencies.rb +++ b/bindings/ruby/ext/dependencies.rb @@ -22,13 +22,17 @@ def libs else nil end - }.reverse.collect {|lib| "lib#{lib}.a"} + }.reverse.collect {|lib| "#{prefix(lib)}#{lib}.#{RbConfig::CONFIG['LIBEXT']}"} end def to_s libs.join(" ") end + def local_libs + to_s + end + private def dot_path @@ -36,9 +40,7 @@ def dot_path end def generate_dot - args = ["-S", "sources", "-B", "build", "--graphviz", dot_path, "-D", "BUILD_SHARED_LIBS=OFF"] - args << @options.to_s unless @options.to_s.empty? - system @cmake, *args, exception: true + system @cmake, "-S", "sources", "-B", "build", *@options.graphviz_cmake_args, "--graphviz", dot_path, *@options, exception: true end def parse_dot @@ -59,6 +61,10 @@ def parse_dot end end + def prefix(lib) + "lib" + end + def tsort_each_node @nodes.each_key do |node| yield node diff --git a/bindings/ruby/ext/dependencies_for_windows.rb b/bindings/ruby/ext/dependencies_for_windows.rb new file mode 100644 index 00000000000..5574107182d --- /dev/null +++ b/bindings/ruby/ext/dependencies_for_windows.rb @@ -0,0 +1,17 @@ +require_relative "dependencies" + +class DependenciesForWindows < Dependencies + def local_libs + libs.collect {|lib| %|"#{lib_path(lib)}"|}.join(" ") + end + + private + + def prefix(lib) + lib.start_with?("ggml") ? "" : "lib" + end + + def lib_path(lib) + File.join(__dir__, lib).tr("\\", "/") + end +end diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index acff501aa3b..4b09b6ebe13 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -1,15 +1,27 @@ require "mkmf" -require_relative "options" -require_relative "dependencies" + +if RUBY_PLATFORM.match? /mswin|mingw|ucrt/ + require_relative "options_for_windows" + require_relative "dependencies_for_windows" + + Opts = OptionsForWindows + Deps = DependenciesForWindows +else + require_relative "options" + require_relative "dependencies" + + Opts = Options + Deps = Dependencies +end cmake = find_executable("cmake") || abort -options = Options.new(cmake).to_s +options = Opts.new(cmake) have_library("gomp") rescue nil -libs = Dependencies.new(cmake, options).to_s +libs = Deps.new(cmake, options) -$CFLAGS << " -O3 -march=native" +append_cflags ["-O3", "-march=native"] $INCFLAGS << " -Isources/include -Isources/ggml/include -Isources/examples" -$LOCAL_LIBS << " #{libs}" +$LOCAL_LIBS << " #{libs.local_libs}" $cleanfiles << " build #{libs}" create_makefile "whisper" do |conf| @@ -17,7 +29,7 @@ $(TARGET_SO): #{libs} #{libs}: cmake-targets cmake-targets: - #{"\t"}#{cmake} -S sources -B build -D BUILD_SHARED_LIBS=OFF -D CMAKE_ARCHIVE_OUTPUT_DIRECTORY=#{__dir__} -D CMAKE_POSITION_INDEPENDENT_CODE=ON #{options} - #{"\t"}#{cmake} --build build --config Release --target common whisper + #{"\t"}"#{cmake}" -S sources -B build #{options} + #{"\t"}"#{cmake}" --build build --config Release --target common whisper EOF end diff --git a/bindings/ruby/ext/options.rb b/bindings/ruby/ext/options.rb index ede80c0656b..e723af9fd9a 100644 --- a/bindings/ruby/ext/options.rb +++ b/bindings/ruby/ext/options.rb @@ -1,26 +1,36 @@ +require "fileutils" + class Options def initialize(cmake="cmake") @cmake = cmake @options = {} configure + write_cache_file + end + + def to_a + [ + "-D", "BUILD_SHARED_LIBS=OFF", + "-D", "WHISPER_BUILD_TESTS=OFF", + "-D", "CMAKE_ARCHIVE_OUTPUT_DIRECTORY=#{__dir__}", + "-D", "CMAKE_POSITION_INDEPENDENT_CODE=ON", + "-C", cache_path + ] end def to_s - @options - .reject {|name, (type, value)| value.nil?} - .collect {|name, (type, value)| "-D #{name}=#{value == true ? "ON" : value == false ? "OFF" : value.shellescape}"} - .join(" ") + command_line(*to_a) end - def cmake_options - return @cmake_options if @cmake_options + def graphviz_cmake_args + [] + end - output = nil - Dir.chdir __dir__ do - output = `#{@cmake.shellescape} -S sources -B build -L` - end - @cmake_options = output.lines.drop_while {|line| line.chomp != "-- Cache values"}.drop(1) + private + + def cmake_options + @cmake_options ||= cmake_options_output.lines.drop_while {|line| line.chomp != "-- Cache values"}.drop(1) .filter_map {|line| option, value = line.chomp.split("=", 2) name, type = option.split(":", 2) @@ -34,7 +44,11 @@ def cmake_options }.to_h end - private + def cmake_options_output + Dir.chdir(__dir__) do + IO.popen([@cmake, "-S", "sources", "-B", "build", "-L"]) {|io| io.read} + end + end def configure cmake_options.each_pair do |name, (type, default_value)| @@ -74,12 +88,38 @@ def option_name(name) def enabled?(option) op = @options[option] - raise "Option not exist: #{option}" unless op - raise "Option not boolean: #{option}(#{op[0]})" unless op[0] == "BOOL" + return false unless op + return false unless op[0] == "BOOL" if op[1].nil? cmake_options[option][1] else op[1] end end + + def cache_path + File.join(__dir__, "sources", "Options.cmake") + end + + def write_cache_file + FileUtils.mkpath File.dirname(cache_path) + File.open cache_path, "w" do |file| + @options.reject {|name, (type, value)| value.nil?}.each do |name, (type, value)| + line = "set(CACHE{%s} TYPE %s FORCE VALUE %s)" % { + name:, + type:, + value: value == true ? "ON" : value == false ? "OFF" : escape_cmake(value) + } + file.puts line + end + end + end + + def escape_cmake(str) + str.gsub(/[\\"]/, '\\\\\&') + end + + def command_line(*args) + args.collect {|arg| %|"#{arg.to_s.gsub(/[\\"]/, '\\\\\&')}"|}.join(" ") + end end diff --git a/bindings/ruby/ext/options_for_windows.rb b/bindings/ruby/ext/options_for_windows.rb new file mode 100644 index 00000000000..7db785d8a2d --- /dev/null +++ b/bindings/ruby/ext/options_for_windows.rb @@ -0,0 +1,51 @@ +require_relative "options" + +class OptionsForWindows < Options + def to_s + command_line(*generator_args, *to_a) + end + + def graphviz_cmake_args + generator_args + end + + private + + def arm? + RbConfig::CONFIG["host_cpu"].to_s.downcase.match?(/\A(?:arm64|aarch64)\z/) + end + + def cmake_options_output + Dir.chdir(__dir__) do + IO.popen([@cmake, "-S", "sources", "-B", "build", *generator_args, "-L"]) {|io| io.read} + end + end + + def generator_args + generator = cmake_generator + ["-G", generator] if generator && !generator.empty? + end + + def cmake_generator + return @cmake_generator if defined?(@cmake_generator) + + generator = ENV["CMAKE_GENERATOR"] + abort "CMAKE_GENERATOR=#{generator} is unsupported for mingw/ucrt Ruby" if visual_studio_generator_name?(generator) + return @cmake_generator = generator unless generator.nil? || generator.empty? + + ninja = find_executable("ninja") + return @cmake_generator = "Ninja" if ninja + + make = find_executable("make") + return @cmake_generator = "MSYS Makefiles" if make + + mingw32_make = find_executable("mingw32-make") + return @cmake_generator = "MinGW Makefiles" if mingw32_make + + @cmake_generator = nil + end + + def visual_studio_generator_name?(generator) + generator && generator.start_with?("Visual Studio") + end +end diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index 5f1917ee805..56fceb1c894 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -29,6 +29,7 @@ ID id_cache; ID id_n_processors; static bool is_log_callback_finalized = false; +static bool is_ruby_log_callback_present = false; // High level API extern VALUE ruby_whisper_segment_allocate(VALUE klass); @@ -106,18 +107,43 @@ static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) { return Qnil; } +typedef struct { + int level; + const char * buffer; +} call_log_callbacks_args; + +static void* +call_log_callbacks(void *v_args) { + VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); + if (NIL_P(log_callback)) { + return NULL; + } + + call_log_callbacks_args *args = (call_log_callbacks_args *)v_args; + VALUE user_data = rb_iv_get(mWhisper, "user_data"); + rb_funcall(log_callback, id_call, 3, INT2NUM(args->level), rb_str_new2(args->buffer), user_data); + + return NULL; +} + static void ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) { if (is_log_callback_finalized) { return; } - VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); - if (NIL_P(log_callback)) { + if (!is_ruby_log_callback_present) { return; } - VALUE udata = rb_iv_get(mWhisper, "user_data"); - rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata); + call_log_callbacks_args args = { + level, + buffer, + }; + if (ruby_thread_has_gvl_p()) { + call_log_callbacks((void *)&args); + } else { + rb_thread_call_with_gvl(call_log_callbacks, (void *)&args); + } } /* @@ -140,8 +166,10 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d if (NIL_P(log_callback)) { whisper_log_set(NULL, NULL); + is_ruby_log_callback_present = false; } else { whisper_log_set(ruby_whisper_log_callback, NULL); + is_ruby_log_callback_present = true; } return Qnil; diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 6b0b4df7214..ba4d8b6fbcc 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -2,10 +2,17 @@ #define RUBY_WHISPER_H #include +#include #include +#include #include #include "whisper.h" +#if RUBY_API_VERSION_MAJOR < 4 +// Exists but not declared as public API +int ruby_thread_has_gvl_p(void); +#endif + typedef struct { VALUE *context; VALUE user_data; @@ -13,6 +20,14 @@ typedef struct { VALUE callbacks; } ruby_whisper_callback_container; +typedef struct { + VALUE *context; + VALUE user_data; + VALUE callback; + VALUE callbacks; + bool is_interrupted; +} ruby_whisper_abort_callback_container; + typedef struct { struct whisper_context *context; } ruby_whisper; @@ -27,7 +42,7 @@ typedef struct { ruby_whisper_callback_container *new_segment_callback_container; ruby_whisper_callback_container *progress_callback_container; ruby_whisper_callback_container *encoder_begin_callback_container; - ruby_whisper_callback_container *abort_callback_container; + ruby_whisper_abort_callback_container *abort_callback_container; VALUE vad_params; } ruby_whisper_params; diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index 6e38ead6321..26058fc07e6 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -1,5 +1,11 @@ #include "ruby_whisper.h" +#ifdef WORDS_BIGENDIAN + #define IS_BIGENDIAN true +#else + #define IS_BIGENDIAN false +#endif + extern ID id_to_s; extern ID id___method__; extern ID id_to_enum; @@ -47,6 +53,27 @@ typedef struct full_parallel_args { int n_processors; } full_parallel_args; +typedef struct full_without_gvl_args { + struct whisper_context *context; + struct whisper_full_params *params; + float *samples; + int n_samples; + int result; +} full_without_gvl_args; + +typedef struct full_parallel_without_gvl_args { + struct whisper_context *context; + struct whisper_full_params *params; + float *samples; + int n_samples; + int n_processors; + int result; +} full_parallel_without_gvl_args; + +typedef struct full_ubf_args { + ruby_whisper_abort_callback_container *abort_callback_container; +} full_ubf_args; + static void ruby_whisper_free(ruby_whisper *rw) { @@ -74,7 +101,7 @@ static size_t ruby_whisper_memsize(const void *p) { const ruby_whisper *rw = (const ruby_whisper *)p; - size_t size = sizeof(rw); + size_t size = sizeof(*rw); if (!rw) { return 0; } @@ -304,11 +331,25 @@ VALUE ruby_whisper_model_type(VALUE self) static bool check_memory_view(rb_memory_view_t *memview) { - if (memview->format != NULL && strcmp(memview->format, "f") != 0) { - rb_warn("currently only format \"f\" is supported for MemoryView, but given: %s", memview->format); + if (!memview->format) { + rb_warn("currently format is required"); + return false; + } + + if (strcmp(memview->format, "f") == 0) { + // accept + } else if (strcmp(memview->format, "e") == 0) { + if (IS_BIGENDIAN) { + rb_warn("currently format \"e\" is only supported on little-endian environment"); + return false; + } + } else { + rb_warn("currently only format \"f\" and \"e\" on little-endian environment is supported for MemoryView, but given: %s", memview->format); return false; } - if (memview->format != NULL && memview->ndim != 1) { + + if (memview->ndim != 1 && !(memview->ndim == 2 && memview->shape[1] == 1)) { + // TODO: Accept ndim == 2 with shape [n_samples, channels] and channels > 1 by averaging the samples in different channels or just taking the first channel rb_warn("currently only 1 dimensional MemoryView is supported, but given: %zd", memview->ndim); return false; } @@ -426,6 +467,22 @@ release_samples(VALUE rb_parsed_args) return Qnil; } +static void* +full_without_gvl(void *rb_args) +{ + full_without_gvl_args *args = (full_without_gvl_args *)rb_args; + args->result = whisper_full(args->context, *args->params, args->samples, args->n_samples); + return NULL; +} + +static void +full_ubf(void *rb_args) +{ + full_ubf_args *args = (full_ubf_args *)rb_args; + + args->abort_callback_container->is_interrupted = true; +} + static VALUE full_body(VALUE rb_args) { @@ -437,9 +494,19 @@ full_body(VALUE rb_args) TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); prepare_transcription(rwp, args->context, 1); - int result = whisper_full(rw->context, rwp->params, args->samples, args->n_samples); - return INT2NUM(result); + struct full_without_gvl_args full_without_gvl_args = { + rw->context, + &rwp->params, + args->samples, + args->n_samples, + 0, + }; + full_ubf_args full_ubf_args = { + rwp->abort_callback_container, + }; + rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args); + return INT2NUM(full_without_gvl_args.result); } /* @@ -477,6 +544,14 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) } } +static void* +full_parallel_without_gvl(void *rb_args) +{ + full_parallel_without_gvl_args *args = (full_parallel_without_gvl_args *)rb_args; + args->result = whisper_full_parallel(args->context, *args->params, args->samples, args->n_samples, args->n_processors); + return NULL; +} + static VALUE full_parallel_body(VALUE rb_args) { @@ -488,9 +563,20 @@ full_parallel_body(VALUE rb_args) TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); prepare_transcription(rwp, args->context, args->n_processors); - int result = whisper_full_parallel(rw->context, rwp->params, args->samples, args->n_samples, args->n_processors); - return INT2NUM(result); + struct full_parallel_without_gvl_args full_parallel_without_gvl_args = { + rw->context, + &rwp->params, + args->samples, + args->n_samples, + args->n_processors, + 0, + }; + full_ubf_args full_ubf_args = { + rwp->abort_callback_container, + }; + rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args); + return INT2NUM(full_parallel_without_gvl_args.result); } /* diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 3e5dca9c1e1..2aae7c12d19 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -93,21 +93,66 @@ rb_whisper_callback_container_allocate() { container->context = NULL; container->user_data = Qnil; container->callback = Qnil; - container->callbacks = rb_ary_new(); + container->callbacks = Qnil; return container; } -static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) { - const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; +static void +rb_whisper_abort_callback_container_mark(ruby_whisper_abort_callback_container *rwc) +{ + if (rwc == NULL) return; + + rb_gc_mark(rwc->user_data); + rb_gc_mark(rwc->callback); + rb_gc_mark(rwc->callbacks); +} + +static ruby_whisper_abort_callback_container* +rb_whisper_abort_callback_container_allocate() { + ruby_whisper_abort_callback_container *container; + container = ALLOC(ruby_whisper_abort_callback_container); + container->context = NULL; + container->user_data = Qnil; + container->callback = Qnil; + container->callbacks = Qnil; + container->is_interrupted = false; + return container; +} + +static bool +ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) { + return !NIL_P(container->callback) || !NIL_P(container->callbacks); +} + +static bool +ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) { + return !NIL_P(container->callback) || !NIL_P(container->callbacks); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct whisper_state *state; + int n_new; +} call_new_segment_callbacks_args; + +static void* +call_new_segment_callbacks(void *v_args) { + call_new_segment_callbacks_args *args = (call_new_segment_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + struct whisper_state *state = args->state; + int n_new = args->n_new; // Currently, doesn't support state because // those require to resolve GC-related problems. if (!NIL_P(container->callback)) { rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); } + if (NIL_P(container->callbacks)) { + return NULL; + } const long callbacks_len = RARRAY_LEN(container->callbacks); if (0 == callbacks_len) { - return; + return NULL; } const int n_segments = whisper_full_n_segments_from_state(state); for (int i = n_new; i > 0; i--) { @@ -118,95 +163,208 @@ static void new_segment_callback(struct whisper_context *ctx, struct whisper_sta rb_funcall(cb, id_call, 1, segment); } } + + return NULL; } -static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) { +static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) { const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; - const VALUE progress = INT2NUM(progress_cur); - // Currently, doesn't support state because + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_new_segment_callbacks_args args = { + container, + state, + n_new + }; + rb_thread_call_with_gvl(call_new_segment_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct whisper_state *state; + int progress_cur; +} call_progress_callbacks_args; + +static void* +call_progress_callbacks(void *v_args) { + call_progress_callbacks_args *args = (call_progress_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + int progress_cur = args->progress_cur; + + // Currently, doesn't support state because // those require to resolve GC-related problems. - if (!NIL_P(container->callback)) { - rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data); + if (!NIL_P(args->container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(progress_cur), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; } const long callbacks_len = RARRAY_LEN(container->callbacks); if (0 == callbacks_len) { - return; + return NULL; } for (int j = 0; j < callbacks_len; j++) { VALUE cb = rb_ary_entry(container->callbacks, j); - rb_funcall(cb, id_call, 1, progress); + rb_funcall(cb, id_call, 1, INT2NUM(progress_cur)); } + + return NULL; } -static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) { +static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) { const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; - bool is_aborted = false; - VALUE result; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_progress_callbacks_args args = { + container, + state, + progress_cur + }; + rb_thread_call_with_gvl(call_progress_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct whisper_state *state; + bool is_continued; +} call_encoder_begin_callbacks_args; + +static void* +call_encoder_begin_callbacks(void *v_args) { + call_encoder_begin_callbacks_args *args = (call_encoder_begin_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; // Currently, doesn't support state because // those require to resolve GC-related problems. if (!NIL_P(container->callback)) { result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data); if (result == Qfalse) { - is_aborted = true; + args->is_continued = false; + return NULL; } } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { - return !is_aborted; - } - for (int j = 0; j < callbacks_len; j++) { - VALUE cb = rb_ary_entry(container->callbacks, j); - result = rb_funcall(cb, id_call, 0); - if (result == Qfalse) { - is_aborted = true; + if (!NIL_P(container->callbacks)) { + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return NULL; + } + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + result = rb_funcall(cb, id_call, 0); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } } } - return !is_aborted; + + return NULL; } -static bool abort_callback(void * user_data) { +static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) { const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return true; + } + + call_encoder_begin_callbacks_args args = { + container, + state, + true + }; + rb_thread_call_with_gvl(call_encoder_begin_callbacks, (void *)&args); + + return args.is_continued; +} + +typedef struct { + const ruby_whisper_abort_callback_container *container; + struct whisper_state *state; + bool is_interrupted; +} call_abort_callbacks_args; + +static void* +call_abort_callbacks(void *v_args) { + call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args; + const ruby_whisper_abort_callback_container *container = args->container; + + if (container->is_interrupted) { + args->is_interrupted = true; + return NULL; + } + if (!NIL_P(container->callback)) { VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data); if (!NIL_P(result) && Qfalse != result) { - return true; + args->is_interrupted = true; + return NULL; } } + if (NIL_P(container->callbacks)) { + return NULL; + } const long callbacks_len = RARRAY_LEN(container->callbacks); if (0 == callbacks_len) { - return false; + return NULL; } for (int j = 0; j < callbacks_len; j++) { VALUE cb = rb_ary_entry(container->callbacks, j); VALUE result = rb_funcall(cb, id_call, 1, container->user_data); if (!NIL_P(result) && Qfalse != result) { - return true; + args->is_interrupted = true; + return NULL; } } - return false; + + return NULL; +} + +static bool abort_callback(void * user_data) { + const ruby_whisper_abort_callback_container *container = (ruby_whisper_abort_callback_container *)user_data; + + if (container->is_interrupted) { + return true; + } + + if (!ruby_whisper_abort_callback_container_is_present(container)) { + return false; + } + + call_abort_callbacks_args args = { + container, + NULL, + false + }; + rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args); + + return args.is_interrupted; } static void -check_thread_safety(ruby_whisper_params *rwp, VALUE *context, int n_processors) +check_thread_safety(ruby_whisper_params *rwp, int n_processors) { if (n_processors == 1) { return; } - if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription"); } - if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) { rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription"); } - if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) { rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription"); } - if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { + if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) { rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription"); } @@ -217,29 +375,28 @@ check_thread_safety(ruby_whisper_params *rwp, VALUE *context, int n_processors) } static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { - if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { rwp->new_segment_callback_container->context = context; rwp->params.new_segment_callback = new_segment_callback; rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container; } - if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) { rwp->progress_callback_container->context = context; rwp->params.progress_callback = progress_callback; rwp->params.progress_callback_user_data = rwp->progress_callback_container; } - if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) { rwp->encoder_begin_callback_container->context = context; rwp->params.encoder_begin_callback = encoder_begin_callback; rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container; } - if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { - rwp->abort_callback_container->context = context; - rwp->params.abort_callback = abort_callback; - rwp->params.abort_callback_user_data = rwp->abort_callback_container; - } + rwp->abort_callback_container->context = context; + rwp->params.abort_callback = abort_callback; + rwp->abort_callback_container->is_interrupted = false; + rwp->params.abort_callback_user_data = rwp->abort_callback_container; } static void set_vad_params(ruby_whisper_params *rwp) @@ -255,7 +412,7 @@ static void set_vad_params(ruby_whisper_params *rwp) void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors) { - check_thread_safety(rwp, context, n_processors); + check_thread_safety(rwp, n_processors); register_callbacks(rwp, context); set_vad_params(rwp); } @@ -267,7 +424,7 @@ rb_whisper_params_mark(void *p) rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); rb_whisper_callbcack_container_mark(rwp->progress_callback_container); rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container); - rb_whisper_callbcack_container_mark(rwp->abort_callback_container); + rb_whisper_abort_callback_container_mark(rwp->abort_callback_container); rb_gc_mark(rwp->vad_params); } @@ -338,7 +495,7 @@ ruby_whisper_params_allocate(VALUE klass) rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); rwp->progress_callback_container = rb_whisper_callback_container_allocate(); rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate(); - rwp->abort_callback_container = rb_whisper_callback_container_allocate(); + rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate(); return obj; } @@ -1302,6 +1459,9 @@ ruby_whisper_params_on_new_segment(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->new_segment_callback_container->callbacks)) { + rwp->new_segment_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->new_segment_callback_container->callbacks, blk); return Qnil; } @@ -1322,6 +1482,9 @@ ruby_whisper_params_on_progress(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->progress_callback_container->callbacks)) { + rwp->progress_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->progress_callback_container->callbacks, blk); return Qnil; } @@ -1342,6 +1505,9 @@ ruby_whisper_params_on_encoder_begin(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->encoder_begin_callback_container->callbacks)) { + rwp->encoder_begin_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->encoder_begin_callback_container->callbacks, blk); return Qnil; } @@ -1366,6 +1532,9 @@ ruby_whisper_params_abort_on(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->abort_callback_container->callbacks)) { + rwp->abort_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->abort_callback_container->callbacks, blk); return Qnil; } diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index 3d00566009a..37656af1c44 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -15,8 +15,37 @@ extern ID id_call; extern ID id_to_path; extern ID transcribe_option_names[1]; -extern void -prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors); +extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors); + +typedef struct{ + struct whisper_context *context; + struct whisper_full_params *params; + float *samples; + size_t n_samples; + int n_processors; + int result; +} transcribe_without_gvl_args; + +static void* +transcribe_without_gvl(void *rb_args) +{ + transcribe_without_gvl_args *args = (transcribe_without_gvl_args *)rb_args; + args->result = whisper_full_parallel(args->context, *args->params, args->samples, args->n_samples, args->n_processors); + + return NULL; +} + +typedef struct { + ruby_whisper_abort_callback_container *abort_callback_container; +} transcribe_ubf_args; + +static void +transcribe_ubf(void *rb_args) +{ + transcribe_ubf_args *args = (transcribe_ubf_args *)rb_args; + + args->abort_callback_container->is_interrupted = true; +} /* * transcribe a single file @@ -75,7 +104,19 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { prepare_transcription(rwp, &self, n_processors); - if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), n_processors) != 0) { + transcribe_without_gvl_args args = { + rw->context, + &rwp->params, + pcmf32.data(), + pcmf32.size(), + n_processors, + 0, + }; + transcribe_ubf_args ubf_args = { + rwp->abort_callback_container, + }; + rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, transcribe_ubf, (void *)&ubf_args); + if (args.result != 0) { fprintf(stderr, "failed to process audio\n"); return self; } diff --git a/bindings/ruby/extsources.rb b/bindings/ruby/extsources.rb index b24f1a7f13d..850ac9841b1 100644 --- a/bindings/ruby/extsources.rb +++ b/bindings/ruby/extsources.rb @@ -5,37 +5,53 @@ .devops .github ci - examples/wchess/wchess.wasm + examples/addon.node + examples/bench.wasm + examples/command + examples/command.wasm + examples/lsp + examples/main + examples/python + examples/stream + examples/stream.wasm + examples/sycl + examples/talk-llama + examples/wchess examples/whisper.android examples/whisper.android.java + examples/whisper.nvim examples/whisper.objc examples/whisper.swiftui + examples/whisper.wasm grammars models samples scripts + tests ].collect {|dir| root/dir} ignored_files = %w[ AUTHORS Makefile - README.md - README_sycl.md .gitignore .gitmodules .dockerignore - whisper.nvim - twitch.sh - yt-wsp.sh - close-issue.yml - build-xcframework.sh +] +ignored_exts = %w[ + .yml + .sh + .md + .py + .js + .nvim ] EXTSOURCES = `git ls-files -z #{root}`.split("\x0") .collect {|file| Pathname(file)} .reject {|file| - ignored_dirs.any? {|dir| file.descend.any? {|desc| desc == dir}} || + ignored_exts.include?(file.extname) || ignored_files.include?(file.basename.to_path) || - (file.descend.to_a[1] != root && file.descend.to_a[1] != Pathname("..")/"javascript") + ignored_dirs.any? {|dir| file.descend.any? {|desc| desc == dir}} || + (file.descend.to_a[1] != root && file != Pathname("..")/"javascript"/"package-tmpl.json") } .collect(&:to_path) diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 3c59661975b..cbec4803820 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -5,10 +5,10 @@ module Whisper end type log_callback = ^(Integer level, String message, Object user_data) -> void - type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void - type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void - type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void - type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish + type new_segment_callback = ^(Whisper::Context, untyped, Integer n_new, Object user_data) -> void + type progress_callback = ^(Whisper::Context, untyped, Integer progress, Object user_data) -> void + type encoder_begin_callback = ^(Whisper::Context, untyped, Object user_data) -> void + type abort_callback = ^(Whisper::Context, untyped, Object user_data) -> boolish VERSION: String LOG_LEVEL_NONE: Integer @@ -52,11 +52,11 @@ module Whisper # puts text # end # - # If n_processors is greater than 1, you cannot set any callbacks including + # If `n_processors` is greater than 1, you cannot set any callbacks including # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, # and log_callback set by Whisper.log_set - def transcribe: (path, Params, ?n_processors: Integer) -> self - | (path, Params, ?n_processors: Integer) { (String) -> void } -> self + def transcribe: (path, Whisper::Params, ?n_processors: Integer) -> self + | (path, Whisper::Params, ?n_processors: Integer) { (String) -> void } -> self def model_n_vocab: () -> Integer def model_n_audio_ctx: () -> Integer @@ -74,7 +74,7 @@ module Whisper # puts segment.text # end # - # Returns an Enumerator if no block given: + # Returns an `Enumerator` if no block given: # # whisper.transcribe("path/to/audio.wav", params) # enum = whisper.each_segment @@ -91,25 +91,25 @@ module Whisper # def full_lang_id: () -> Integer - # Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). + # Start time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). # # full_get_segment_t0(3) # => 1668 (16680 ms) # def full_get_segment_t0: (Integer) -> Integer - # End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). + # End time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). # # full_get_segment_t1(3) # => 1668 (16680 ms) # def full_get_segment_t1: (Integer) -> Integer - # Whether the next segment indexed by +segment_index+ is predicated as a speaker turn. + # Whether the next segment indexed by `segment_index` is predicated as a speaker turn. # # full_get_segment_speacker_turn_next(3) # => true # def full_get_segment_speaker_turn_next: (Integer) -> (true | false) - # Text of a segment indexed by +segment_index+. + # Text of a segment indexed by `segment_index`. # # full_get_segment_text(3) # => "ask not what your country can do for you, ..." # @@ -117,27 +117,27 @@ module Whisper def full_get_segment_no_speech_prob: (Integer) -> Float - # Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text - # Not thread safe for same context + # Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + # Not thread safe for same context # Uses the specified decoding strategy to obtain the text. # - # The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. + # The second argument `samples` must be an array of samples, respond to `:length`, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. # - def full: (Params, Array[Float] samples, ?Integer n_samples) -> self - | (Params, _Samples, ?Integer n_samples) -> self + def full: (Whisper::Params, Array[Float] samples, ?Integer n_samples) -> self + | (Whisper::Params, _Samples, ?Integer n_samples) -> self - # Split the input audio in chunks and process each chunk separately using whisper_full_with_state() - # Result is stored in the default state of the context - # Not thread safe if executed in parallel on the same context. - # It seems this approach can offer some speedup in some cases. + # Split the input audio in chunks and process each chunk separately using `whisper_full_with_state()` + # Result is stored in the default state of the context + # Not thread safe if executed in parallel on the same context. + # It seems this approach can offer some speedup in some cases. # However, the transcription accuracy can be worse at the beginning and end of each chunk. # - # If n_processors is greater than 1, you cannot set any callbacks including + # If `n_processors` is greater than 1, you cannot set any callbacks including # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, # and log_callback set by Whisper.log_set - def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self - | (Params, _Samples, ?Integer n_samples) -> self - | (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self + def full_parallel: (Whisper::Params, Array[Float], ?Integer n_samples) -> self + | (Whisper::Params, _Samples, ?Integer n_samples) -> self + | (Whisper::Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self def to_srt: () -> String def to_webvtt: () -> String @@ -217,35 +217,35 @@ module Whisper def translate: () -> (true | false) def no_context=: (boolish) -> boolish - # If true, does not use past transcription (if any) as initial prompt for the decoder. + # If `true`, does not use past transcription (if any) as initial prompt for the decoder. # def no_context: () -> (true | false) def single_segment=: (boolish) -> boolish - # If true, forces single segment output (useful for streaming). + # If `true`, forces single segment output (useful for streaming). # def single_segment: () -> (true | false) def print_special=: (boolish) -> boolish - # If true, prints special tokens (e.g. , , , etc.). + # If `true`, prints special tokens (e.g. , , , etc.). # def print_special: () -> (true | false) def print_progress=: (boolish) -> boolish - # If true, prints progress information. + # If `true`, prints progress information. # def print_progress: () -> (true | false) def print_realtime=: (boolish) -> boolish - # If true, prints results from within whisper.cpp. (avoid it, use callback instead) + # If `true`, prints results from within whisper.cpp. (avoid it, use callback instead) # def print_realtime: () -> (true | false) - # If true, prints timestamps for each text segment when printing realtime. + # If `true`, prints timestamps for each text segment when printing realtime. # def print_timestamps=: (boolish) -> boolish @@ -253,19 +253,19 @@ module Whisper def suppress_blank=: (boolish) -> boolish - # If true, suppresses blank outputs. + # If `true`, suppresses blank outputs. # def suppress_blank: () -> (true | false) def suppress_nst=: (boolish) -> boolish - # If true, suppresses non-speech-tokens. + # If `true`, suppresses non-speech-tokens. # def suppress_nst: () -> (true | false) def token_timestamps=: (boolish) -> boolish - # If true, enables token-level timestamps. + # If `true`, enables token-level timestamps. # def token_timestamps: () -> (true | false) @@ -277,16 +277,16 @@ module Whisper def split_on_word=: (boolish) -> boolish - # If true, split on word rather than on token (when used with max_len). + # If `true`, split on word rather than on token (when used with max_len). # def split_on_word: () -> (true | false) def initial_prompt=: (_ToS) -> _ToS def carry_initial_prompt=: (boolish) -> boolish - # Tokens to provide to the whisper decoder as initial prompt - # these are prepended to any existing text context from a previous call - # use whisper_tokenize() to convert text to tokens. + # Tokens to provide to the whisper decoder as initial prompt + # these are prepended to any existing text context from a previous call + # use whisper_tokenize() to convert text to tokens. # Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224). # def initial_prompt: () -> (String | nil) @@ -294,7 +294,7 @@ module Whisper def diarize=: (boolish) -> boolish - # If true, enables diarization. + # If `true`, enables diarization. # def diarize: () -> (true | false) @@ -423,7 +423,7 @@ module Whisper # def on_new_segment: { (Segment) -> void } -> void - # Hook called on progress update. Yields each progress Integer between 0 and 100. + # Hook called on progress update. Yields each progress `Integer` between 0 and 100. # def on_progress: { (Integer progress) -> void } -> void @@ -431,7 +431,7 @@ module Whisper # def on_encoder_begin: { () -> void } -> void - # Call block to determine whether abort or not. Return +true+ when you want to abort. + # Call block to determine whether abort or not. Return `true` when you want to abort. # # params.abort_on do # if some_condition @@ -504,13 +504,13 @@ module Whisper # Yields each Whisper::Token: # - # whisper.each_segment.first.each_token do |token| - # p token - # end + # whisper.each_segment.first.each_token do |token| + # p token + # end # - # Returns an Enumerator if no block is given: + # Returns an `Enumerator` if no block is given: # - # whisper.each_segment.first.each_token.to_a # => [#, ...] + # whisper.each_segment.first.each_token.to_a # => [#, ...] # def each_token: { (Token) -> void } -> void | () -> Enumerator[Token] @@ -518,7 +518,7 @@ module Whisper def to_webvtt_cue: () -> String - # Possible keys: :start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next + # Possible keys: `:start_time`, `:end_time`, `:text`, `:no_speech_prob`, `:speaker_turn_next` # # whisper.each_segment do |segment| # segment => {start_time:, end_time:, text:, no_speech_prob:, speaker_turn_next:} @@ -569,7 +569,7 @@ module Whisper # [EXPERIMENTAL] Token-level timestamps with DTW # - # Do not use if you haven't computed token-level timestamps with dtw. + # Do not use if you haven't computed token-level timestamps with dtw. # Roughly corresponds to the moment in audio in which the token was output. # def t_dtw: () -> Integer @@ -580,14 +580,14 @@ module Whisper # Start time of the token. # - # Token-level timestamp data. + # Token-level timestamp data. # Do not use if you haven't computed token-level timestamps. # def start_time: () -> Integer # End time of the token. # - # Token-level timestamp data. + # Token-level timestamp data. # Do not use if you haven't computed token-level timestamps. # def end_time: () -> Integer diff --git a/bindings/ruby/test/test_package.rb b/bindings/ruby/test/test_package.rb index 108f34efbeb..f99012cce83 100644 --- a/bindings/ruby/test/test_package.rb +++ b/bindings/ruby/test/test_package.rb @@ -1,12 +1,12 @@ require_relative "helper" require 'tempfile' require 'tmpdir' -require 'shellwords' +require 'open3' class TestPackage < TestBase def test_build Tempfile.create do |file| - assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) + assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path, exception: true) assert file.size > 0 assert_path_exist file.to_path end @@ -20,7 +20,7 @@ def setup def test_install gemspec = Gem::Specification.load("whispercpp.gemspec") Dir.mktmpdir do |dir| - system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", exception: true + system "gem", "install", "--install-dir", dir, "--no-document", File.join("pkg", gemspec.file_name), exception: true assert_installed dir, gemspec.version end end @@ -29,13 +29,14 @@ def test_install_with_coreml omit_unless RUBY_PLATFORM.match?(/darwin/) do gemspec = Gem::Specification.load("whispercpp.gemspec") Dir.mktmpdir do |dir| - system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", "--", "--enable-whisper-coreml", exception: true + system "gem", "install", "--install-dir", dir, "--no-document", File.join("pkg", gemspec.file_name), "--", "--enable-whisper-coreml", exception: true assert_installed dir, gemspec.version libdir = File.join(dir, "gems", "#{gemspec.name}-#{gemspec.version}", "lib") assert_nothing_raised do system "ruby", "-I", libdir, "-r", "whisper", "-e", "Whisper::Context.new('tiny')", exception: true end - assert_match(/COREML = 1/, `ruby -I #{libdir.shellescape} -r whisper -e 'puts Whisper.system_info_str'`) + output, status = Open3.capture2("ruby", "-I", libdir, "-r", "whisper", "-e", "puts Whisper.system_info_str") + assert_match /COREML = 1/, output end end end From c33c5618b72bb345df029b730b36bc0e369845a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bjarke=20Viks=C3=B8e?= <164612031+bviksoe@users.noreply.github.com> Date: Sun, 10 May 2026 16:24:12 +0200 Subject: [PATCH 544/831] whisper : fix incorrect timestamps, usually near silences (#2279) * Incorrect timetstamps Fixes #2271 - Adds consecutive timestamps after end of last segment as the new starting ts - Add these timestamp to output when "print-special" enabled - Fixes fflush usage in live reporting I was not able to test this with the special "token_timestamps" option. * Skip initial timestamp --- src/whisper.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/whisper.cpp b/src/whisper.cpp index 2f356da0f06..6176d21f53c 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -7659,11 +7659,14 @@ int whisper_full_with_state( } } text = ""; - while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { + t0 = t1; + while (i + 1 < (int) tokens_cur.size() && tokens_cur[i + 1].id > whisper_token_beg(ctx)) { i++; + if (params.print_special) { + text += whisper_token_to_str(ctx, tokens_cur[i].id); + } + t0 = seek + 2 * (tokens_cur[i].tid - whisper_token_beg(ctx)); } - i--; - t0 = t1; i0 = i + 1; speaker_turn_next = false; } @@ -7680,8 +7683,8 @@ int whisper_full_with_state( printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); } else { printf("%s", text.c_str()); - fflush(stdout); } + fflush(stdout); } result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next }); From 338cce1e58133261753243802a0e7a430118866d Mon Sep 17 00:00:00 2001 From: Andreas Lubbe Date: Tue, 12 May 2026 07:36:00 +0200 Subject: [PATCH 545/831] server: Add support for controlling token_timestamps directly (#3785) --- examples/server/server.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f6a7a83181a..08c0988d2be 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -101,6 +101,7 @@ struct whisper_params { bool print_realtime = false; bool print_progress = false; bool no_timestamps = false; + bool token_timestamps = true; bool use_gpu = true; bool flash_attn = true; int32_t gpu_device = 0; @@ -550,6 +551,12 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.no_timestamps = parse_str_to_bool(req.get_file_value("no_timestamps").content); } + if (req.has_file("token_timestamps")) + { + params.token_timestamps = parse_str_to_bool(req.get_file_value("token_timestamps").content); + } else { + params.token_timestamps = !params.no_timestamps; + } if (req.has_file("language")) { params.language = req.get_file_value("language").content; @@ -690,10 +697,10 @@ int main(int argc, char ** argv) { if (params.dtw == "large.v3") { cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3; } - if (params.dtw == "large.v3.turbo") { + if (params.dtw == "large.v3.turbo") { cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3_TURBO; } - + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) { fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str()); return 3; @@ -939,7 +946,7 @@ int main(int argc, char ** argv) { wparams.logprob_thold = params.logprob_thold; wparams.no_timestamps = params.no_timestamps; - wparams.token_timestamps = !params.no_timestamps; + wparams.token_timestamps = params.token_timestamps; wparams.no_context = params.no_context; wparams.suppress_nst = params.suppress_nst; @@ -1043,7 +1050,7 @@ int main(int argc, char ** argv) { res.set_content(ss.str(), "text/vtt"); } else if (params.response_format == vjson_format) { /* try to match openai/whisper's Python format */ - std::string results = output_str(ctx, params, pcmf32s); + std::string results = output_str(ctx, params, pcmf32s); json jres = json{ {"task", params.translate ? "translate" : "transcribe"}, {"language", whisper_lang_str_full(whisper_full_lang_id(ctx))}, @@ -1088,7 +1095,7 @@ int main(int argc, char ** argv) { segment["tokens"].push_back(token.id); json word = json{{"word", whisper_full_get_token_text(ctx, i, j)}}; - if (!params.no_timestamps) { + if (!params.no_timestamps && params.token_timestamps) { word["start"] = token.t0 * 0.01; word["end"] = token.t1 * 0.01; word["t_dtw"] = token.t_dtw; From f08258abd74b995bb95d8005103f72f1afd66a8a Mon Sep 17 00:00:00 2001 From: annaeina <2846698728@qq.com> Date: Wed, 13 May 2026 13:32:00 +0800 Subject: [PATCH 546/831] whisper : fix max_tokens skipping remaining audio (#3798) * whisper: fix max_tokens skipping remaining audio * add PR reference comment as suggested Co-authored-by: Georgi Gerganov * fix(ci): enable artifact overwrite --- .github/workflows/build.yml | 1 + bindings/go/pkg/whisper/context_test.go | 48 +++++++++++++++++++++++++ src/whisper.cpp | 12 +++++++ 3 files changed, 61 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fb115b22abb..be3f78a3f5b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -662,6 +662,7 @@ jobs: with: name: ggml_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml.dll + overwrite: true - name: Upload ggml base dll uses: actions/upload-artifact@v6 diff --git a/bindings/go/pkg/whisper/context_test.go b/bindings/go/pkg/whisper/context_test.go index e98a4c2b80b..79f6a593024 100644 --- a/bindings/go/pkg/whisper/context_test.go +++ b/bindings/go/pkg/whisper/context_test.go @@ -2,6 +2,7 @@ package whisper_test import ( "os" + "strings" "testing" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" @@ -92,6 +93,53 @@ func TestProcess(t *testing.T) { assert.NoError(err) } +func TestProcessMaxTokensPerSegment(t *testing.T) { + assert := assert.New(t) + + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + + fh, err := os.Open(SamplePath) + assert.NoError(err) + defer fh.Close() + + // Decode the WAV file - load the full buffer + dec := wav.NewDecoder(fh) + buf, err := dec.FullPCMBuffer() + assert.NoError(err) + assert.Equal(uint16(1), dec.NumChans) + + data := buf.AsFloat32Buffer().Data + + model, err := whisper.New(ModelPath) + assert.NoError(err) + assert.NotNil(model) + defer model.Close() + + context, err := model.NewContext() + assert.NoError(err) + + context.SetMaxTokensPerSegment(5) + + err = context.Process(data, nil, nil, nil) + assert.NoError(err) + + var text strings.Builder + nSegments := 0 + for { + segment, err := context.NextSegment() + if err != nil { + break + } + nSegments++ + text.WriteString(segment.Text) + } + + assert.Greater(nSegments, 1) + assert.Contains(text.String(), "country") +} + func TestDetectedLanguage(t *testing.T) { assert := assert.New(t) diff --git a/src/whisper.cpp b/src/whisper.cpp index 6176d21f53c..210ca597fb4 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -6216,6 +6216,13 @@ static void whisper_process_logits( } } + // ref: https://github.com/ggml-org/whisper.cpp/pull/3798 + if (!params.no_timestamps && !params.single_segment && params.max_tokens > 0 && (int) tokens_cur.size() >= params.max_tokens) { + for (int i = 0; i < vocab.token_eot; ++i) { + logits[i] = -INFINITY; + } + } + // suppress sot and nosp tokens logits[vocab.token_sot] = -INFINITY; logits[vocab.token_nosp] = -INFINITY; @@ -7725,7 +7732,12 @@ int whisper_full_with_state( } // ref: https://github.com/ggml-org/whisper.cpp/pull/2629 + const bool max_tokens_timestamp_ending = params.max_tokens > 0 && + !params.single_segment && + tokens_cur.size() > (size_t) params.max_tokens; + const bool single_timestamp_ending = tokens_cur.size() > 1 && + !max_tokens_timestamp_ending && tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) && tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx); if (single_timestamp_ending) { From a604a9b5b0ff9108191769a09843ae325c6c0d7f Mon Sep 17 00:00:00 2001 From: Andreas Lubbe Date: Wed, 13 May 2026 08:54:56 +0200 Subject: [PATCH 547/831] server: fix params leak between requests (#3784) --- examples/server/server.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 08c0988d2be..c582c448de1 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -824,7 +824,7 @@ int main(int argc, char ** argv) { } auto audio_file = req.get_file_value("file"); - // check non-required fields + whisper_params params = default_params; get_req_parameters(req, params); std::string filename{audio_file.filename}; @@ -1127,9 +1127,6 @@ int main(int argc, char ** argv) { res.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace), "application/json"); } - - // reset params to their defaults - params = default_params; }); svr->Post(sparams.request_path + "/load", [&](const Request &req, Response &res){ std::lock_guard lock(whisper_mutex); From 3e9b7d0fef3528ee2208da3cdb873a2c53d2ae2f Mon Sep 17 00:00:00 2001 From: Andreas Lubbe Date: Wed, 13 May 2026 10:37:28 +0200 Subject: [PATCH 548/831] server : fix no_speech_thold not being read (#3783) --- examples/server/server.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c582c448de1..735255b6290 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -87,7 +87,7 @@ struct whisper_params { float logprob_thold = -1.00f; float temperature = 0.00f; float temperature_inc = 0.20f; - float no_speech_thold = 0.6f; + float no_speech_thold = 0.6f; bool debug_mode = false; bool translate = false; @@ -527,6 +527,10 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.logprob_thold = std::stof(req.get_file_value("logprob_thold").content); } + if (req.has_file("no_speech_thold")) + { + params.no_speech_thold = std::stof(req.get_file_value("no_speech_thold").content); + } if (req.has_file("debug_mode")) { params.debug_mode = parse_str_to_bool(req.get_file_value("debug_mode").content); @@ -762,6 +766,7 @@ int main(int argc, char ** argv) { -F file="@<file-path>" \ -F temperature="0.0" \ -F temperature_inc="0.2" \ + -F no_speech_thold="0.6" \ -F response_format="json"
@@ -940,7 +945,7 @@ int main(int argc, char ** argv) { wparams.beam_search.beam_size = params.beam_size; wparams.temperature = params.temperature; - wparams.no_speech_thold = params.no_speech_thold; + wparams.no_speech_thold = params.no_speech_thold; wparams.temperature_inc = params.temperature_inc; wparams.entropy_thold = params.entropy_thold; wparams.logprob_thold = params.logprob_thold; From ff5704a416813610e30e54d864c5af1be41288c6 Mon Sep 17 00:00:00 2001 From: Shawn Gu Date: Fri, 1 May 2026 23:02:24 -0700 Subject: [PATCH 549/831] opencl: Adreno optimization for MoE - MxFP4 (llama/22301) * MoE Mxfp4 CLC kernel added, router reorder on GPU * Pass test-backend-ops for MoE mxfp4 Adreno CLC * remove putenv in llama-model.cpp * fix indent style and whitespace * opencl: remove unnecessary headers * opencl: do not save cl_program objects * opencl: remove unnecessary assert * fix precision issue --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 4 + ggml/src/ggml-opencl/ggml-opencl.cpp | 451 +++++++++++++++--- ggml/src/ggml-opencl/kernels/cvt.cl | 87 ++++ .../kernels/gemm_moe_mxfp4_f32_ns.cl | 302 ++++++++++++ .../kernels/gemv_moe_mxfp4_f32_ns.cl | 161 +++++++ ggml/src/ggml-opencl/kernels/moe_reorder_b.cl | 30 ++ .../ggml-opencl/kernels/moe_sort_by_expert.cl | 82 ++++ 7 files changed, 1040 insertions(+), 77 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/moe_reorder_b.cl create mode 100644 ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 5ed83eeb48a..35d425a431f 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -107,6 +107,10 @@ set(GGML_OPENCL_KERNELS mul_mv_id_mxfp4_f32_flat gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 + gemm_moe_mxfp4_f32_ns + gemv_moe_mxfp4_f32_ns + moe_reorder_b + moe_sort_by_expert mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm mul_mm_q4_0_f32_l4_lm diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 11f72a5198a..74948c27e4e 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -416,6 +416,15 @@ struct ggml_backend_opencl_context { ggml_cl_buffer prealloc_src0; ggml_cl_buffer prealloc_src1; + // prealloc buffers for MoE router table preprocess + bool toggle_reorder = false; + ggml_cl_buffer prealloc_post_router; + ggml_cl_buffer prealloc_emap; + ggml_cl_buffer prealloc_hist; + ggml_cl_buffer prealloc_tile_offset; + ggml_cl_buffer prealloc_total_tiles; + ggml_cl_buffer prealloc_slot_counter; + cl_program program_add; cl_program program_add_id; cl_program program_clamp; @@ -531,6 +540,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; + cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; @@ -587,6 +597,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; cl_kernel kernel_timestep_embedding; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; + cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; + cl_kernel kernel_moe_reorder_b; + cl_kernel kernel_moe_histogram, kernel_moe_scan, kernel_moe_fill, kernel_moe_scatter; cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat; cl_kernel kernel_mul_mv_id_mxfp4_f32; @@ -945,6 +958,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err)); @@ -2864,6 +2879,77 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemv_moe_mxfp4_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_mxfp4_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_mxfp4_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_mxfp4_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_mxfp4_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_mxfp4_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_mxfp4_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // moe_reorder_b + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "moe_reorder_b.cl.h" + }; +#else + const std::string kernel_src = read_file("moe_reorder_b.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_moe_reorder_b = clCreateKernel(prog, "kernel_moe_reorder_b", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // moe_sort_by_expert + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "moe_sort_by_expert.cl.h" + }; +#else + const std::string kernel_src = read_file("moe_sort_by_expert.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_moe_histogram = clCreateKernel(prog, "kernel_moe_histogram", &err), err)); + CL_CHECK((backend_ctx->kernel_moe_scan = clCreateKernel(prog, "kernel_moe_scan", &err), err)); + CL_CHECK((backend_ctx->kernel_moe_fill = clCreateKernel(prog, "kernel_moe_fill", &err), err)); + CL_CHECK((backend_ctx->kernel_moe_scatter = clCreateKernel(prog, "kernel_moe_scatter", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemv_noshuffle_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3651,13 +3737,12 @@ struct ggml_tensor_extra_cl_mxfp4 { CL_CHECK(clReleaseMemObject(e)); e = nullptr; } - if (q != nullptr) { + if (q_img != nullptr) { CL_CHECK(clReleaseMemObject(q_img)); - q = nullptr; + q_img = nullptr; } - // Currently, q_img and d_img are not used. They can be image1d_buffer_t + // Currently, e_img is not used. They can be image1d_buffer_t // that wraps around q and d to utilize image access path. - q_img = nullptr; e_img = nullptr; size_q = 0; size_e = 0; @@ -4740,7 +4825,7 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { GGML_UNUSED(backend_ctx); int ne01 = tensor->ne[1]; - return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); + return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); } inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { @@ -5151,8 +5236,9 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(err); #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe mxfp4 kernel needs special transpose and unshuffling if (use_adreno_moe_kernels(backend_ctx, tensor)) { - cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans; + cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans4_ns; int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; @@ -5172,9 +5258,21 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); tensor->extra = extra; + // Create image for Q + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + return; } -#endif + +#endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); @@ -5912,7 +6010,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, tensor)) { - cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans; + cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans4_ns; int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; @@ -5936,7 +6034,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } -#endif + +#endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e)); @@ -12763,6 +12862,118 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } } +static void moe_router_reoerder(ggml_backend_t backend, const ggml_tensor * src, int ne20) { + cl_int err; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra; + cl_ulong offset = extra->offset + src->view_offs; + + const int ne21 = src->ne[1]; + const int nb21 = src->nb[1]; + const int ne02 = nb21 / src->nb[0]; + const int n_tile_size = 32; + const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02; + + cl_buffer_region region; + region.origin = offset; + region.size = nb21 * ne21; + cl_mem original_router_buf = clCreateSubBuffer(extra->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_post_router.allocate(backend_ctx->context, sizeof(int) * max_post_router_tile * n_tile_size); + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + cl_mem post_router_buf = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_emap.allocate(backend_ctx->context, sizeof(short) * max_post_router_tile); + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + cl_mem emap_buf = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_hist.allocate(backend_ctx->context, sizeof(int) * ne02); + region.origin = 0; + region.size = sizeof(int) * ne02; + cl_mem hist_buf = clCreateSubBuffer(backend_ctx->prealloc_hist.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_tile_offset.allocate(backend_ctx->context, sizeof(int) * ne02); + region.origin = 0; + region.size = sizeof(int) * ne02; + cl_mem tile_offset_buf = clCreateSubBuffer(backend_ctx->prealloc_tile_offset.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_slot_counter.allocate(backend_ctx->context, sizeof(int) * ne02); + region.origin = 0; + region.size = sizeof(int) * ne02; + cl_mem slot_counter_buf = clCreateSubBuffer(backend_ctx->prealloc_slot_counter.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_total_tiles.allocate(backend_ctx->context, sizeof(int)); + region.origin = 0; + region.size = sizeof(int); + cl_mem total_tiles_buf = clCreateSubBuffer(backend_ctx->prealloc_total_tiles.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + // Histogram + cl_kernel kernel = backend_ctx->kernel_moe_histogram; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &original_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &hist_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne02)); + + size_t histogram_global_size[] = {(size_t)(((ne21 + 63) / 64) * 64), static_cast(ne20), 1}; + size_t histogram_local_size[] = {64, static_cast(ne20), 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src); + + // Scan + kernel = backend_ctx->kernel_moe_scan; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &hist_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &tile_offset_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &total_tiles_buf)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &slot_counter_buf)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &n_tile_size)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne02)); + + size_t scan_global_size[] = {1}; + size_t scan_local_size[] = {1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 1, scan_global_size, scan_local_size, src); + + // Fill + kernel = backend_ctx->kernel_moe_fill; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &post_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &total_tiles_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &n_tile_size)); + + size_t fill_global_size[] = {(size_t)(((max_post_router_tile + 63) / 64) * 64), n_tile_size, 1}; + size_t fill_local_size[] = {64, 1, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, fill_global_size, fill_local_size, src); + + // Scatter + kernel = backend_ctx->kernel_moe_scatter; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &original_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &post_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &emap_buf)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &tile_offset_buf)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &slot_counter_buf)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src); + + CL_CHECK(clReleaseMemObject(original_router_buf)); + CL_CHECK(clReleaseMemObject(hist_buf)); + CL_CHECK(clReleaseMemObject(tile_offset_buf)); + CL_CHECK(clReleaseMemObject(total_tiles_buf)); + CL_CHECK(clReleaseMemObject(slot_counter_buf)); + CL_CHECK(clReleaseMemObject(post_router_buf)); + CL_CHECK(clReleaseMemObject(emap_buf)); +} + static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -12824,6 +13035,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const int ne0 = dst->ne[0]; const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; const int r2 = ne12/ne02; const int r3 = ne13/ne03; @@ -12836,6 +13048,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, int nrows = 1; // number of row in src1 int ndst = 4; // number of values produced by each subgroup + const int n_tile_size = 32; + const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02; + cl_kernel kernel; // subgroup mat vec @@ -12967,11 +13182,10 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, size_t local_size[3] = {64, 2, 1}; size_t global_size[3] = {64, 2, 1}; - cl_mem src1_sub_buffer, buf_src1_image, buf_src2; - - int tile_size = 320; if (ne12 == 1) { // for gemv - kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32; + kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; // create a sub_buffer for src2 cl_buffer_region region; @@ -12985,78 +13199,154 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + } else { // for gemm - kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32; - - // preprocess router table - int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size; - void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short)); - void * host_src2 = malloc(ne21 * nb21); - CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL)); - int total_experts = nb21 / nb20; - int out_idx = 0; - for (int i_expert = 0; i_expert < ne02; i_expert++) { - for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) { - for (int j = 0; j < ne21; j++) { - for (int i = 0; i < ne20; i++) { - int expert = ((int *)host_src2)[j * total_experts + i]; - if (i_expert == expert) { - ((short *)host_src2_reorder)[out_idx] = static_cast(expert); - ((short *)host_src2_reorder)[out_idx + 1] = static_cast(j * ne11 + (i % ne11)); - ((short *)host_src2_reorder)[out_idx + 2] = static_cast(j * ne20 + i); - ((short *)host_src2_reorder)[out_idx + 3] = static_cast(i_tile); - out_idx += 4; - } - } - } - } + kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; } - buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status); + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + GGML_ASSERT(backend_ctx->prealloc_post_router.buffer); + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); CL_CHECK(status); - // set thread grid - global_size[0] = static_cast(tile_size); - global_size[2] = static_cast(ne20 * ne21 * num_tiles_per_expert); - } + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); - // create a sub_buffer for src1 - cl_buffer_region region; - region.origin = offset1; - region.size = ne10 * ne11 * ne12 * sizeof(float); - src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - - // create image for src1 - cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; - cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; - buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); - CL_CHECK(status); - - // Set kernel args - int arg_idx = 0; - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); - if (ne12 == 1) { - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); - } else { - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size)); - } + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); - // launch kernel - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); - // deallocate sub buffers and images - CL_CHECK(clReleaseMemObject(src1_sub_buffer)); - CL_CHECK(clReleaseMemObject(buf_src1_image)); - CL_CHECK(clReleaseMemObject(buf_src2)); + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } return; - } // else fallback to generic kernel + } // fallback to generic MoE mxfp4 kernel #endif // GGML_OPENCL_USE_ADRENO_KERNELS #ifdef GGML_OPENCL_SOA_Q @@ -14002,6 +14292,13 @@ static void ggml_cl_argsort(ggml_backend_t backend, const ggml_tensor * src0, co size_t local_work_size[] = {(size_t)ne00_padded, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + const int ne21 = dst->ne[1]; + if ((strstr(src0->name, "_moe") != NULL) && (ne21 != 1)) { + backend_ctx->toggle_reorder = true; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS } static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index f3937d8304c..c1ad46f4435 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -371,6 +371,93 @@ kernel void kernel_restore_block_mxfp4_trans( b->e = src_e[src_blk_offset]; } +kernel void kernel_convert_block_mxfp4_trans4_ns( + global struct block_mxfp4 * src0, + __global uint * dst_q, + __global uchar * dst_e, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_MXFP4; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_mxfp4 * b = src0 + src_blk_offset; + dst_e[dst_blk_offset] = b->e; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK_MXFP4 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK_MXFP4 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_q[offset] = q_block.x; + dst_q[offset + ne01] = q_block.y; + dst_q[offset + ne01 * 2] = q_block.z; + dst_q[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_mxfp4_trans4_ns( + __global uint * src_q, + __global uchar * src_e, + __global struct block_mxfp4 * dst0, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_MXFP4; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_mxfp4 * b = dst0 + dst_blk_offset; + b->e = src_e[src_d_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_q[src_q_offset]; + q_block.y = src_q[src_q_offset + ne01]; + q_block.z = src_q[src_q_offset + ne01 * 2]; + q_block.w = src_q[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK_MXFP4 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK_MXFP4 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + + //------------------------------------------------------------------------------ // block_q8_0 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl new file mode 100644 index 00000000000..e404f392bdd --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl @@ -0,0 +1,302 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { + ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b; + fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00; + fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00; + fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00; + fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0; + fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0; + fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0; + fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0; + + sign_a.lo = (fp4x8.s0 << 12) & 0x8000; + sign_a.hi = (fp4x8.s0 << 8) & 0x8000; + sign_b.lo = (fp4x8.s0 << 4) & 0x8000; + sign_b.hi = fp4x8.s0 & 0x8000; + + fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0; + fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0; + + ushort2 fp16_packed_a_1, fp16_packed_b_1; + fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00; + fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00; + fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00; + fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0; + fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0; + fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0; + fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0; + + sign_a.lo = (fp4x8.s1 << 12) & 0x8000; + sign_a.hi = (fp4x8.s1 << 8) & 0x8000; + sign_b.lo = (fp4x8.s1 << 4) & 0x8000; + sign_b.hi = fp4x8.s1 & 0x8000; + + fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1; + fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1; + + return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1)); +} + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +static inline half e8m0_to_fp16(uchar x) { + ushort bits; + bits = (ushort)(x) - (ushort)(112); + bits = ((bits & 0x00E0) != 0) ? 0x7C00 : (bits << 10); + return as_half(bits); +} + +static inline float e8m0_to_fp32(uchar x) { + int bits; + bits = (x == 0) ? 0x00400000 : ((uint) x << 23); + return as_float(bits); +} + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_mxfp4_f32_ns( + __read_only image1d_buffer_t src0_q, + __global uchar * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale for current mxfp4 block + uint s_offset = s_sub_offset + get_global_id(0); + float s = e8m0_to_fp32(src0_d[s_offset]); + + // Load 16 fp4 (64-bits) in transposed layout + uint2 mxfp4x16; + mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s; + reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s; + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 fp4 (64-bits) in transposed layout + mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s; + reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s; + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl new file mode 100644 index 00000000000..e4b44c1a56a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl @@ -0,0 +1,161 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_MXFP4 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { + ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b; + fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00; + fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00; + fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00; + fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0; + fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0; + fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0; + fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0; + + sign_a.lo = (fp4x8.s0 << 12) & 0x8000; + sign_a.hi = (fp4x8.s0 << 8) & 0x8000; + sign_b.lo = (fp4x8.s0 << 4) & 0x8000; + sign_b.hi = fp4x8.s0 & 0x8000; + + fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0; + fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0; + + ushort2 fp16_packed_a_1, fp16_packed_b_1; + fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00; + fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00; + fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00; + fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0; + fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0; + fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0; + fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0; + + sign_a.lo = (fp4x8.s1 << 12) & 0x8000; + sign_a.hi = (fp4x8.s1 << 8) & 0x8000; + sign_b.lo = (fp4x8.s1 << 4) & 0x8000; + sign_b.hi = fp4x8.s1 & 0x8000; + + fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1; + fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1; + + return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1)); +} + +static inline float e8m0_to_fp32(uchar x) { + int bits; + bits = (x == 0) ? 0x00400000 : ((uint) x << 23); + return as_float(bits); +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_mxfp4_f32_ns( + __global uint * src0_q, + __global uchar * src0_e, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_q[block_offset]; + regQ.s1 = src0_q[block_offset + ne01]; + regQ.s2 = src0_q[block_offset + ne01 * 2]; + regQ.s3 = src0_q[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0)); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1)); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2)); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3)); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset]; + sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} diff --git a/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl b/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl new file mode 100644 index 00000000000..e6295c81648 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl @@ -0,0 +1,30 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define QK4_0 32 + +kernel void kernel_moe_reorder_b( + global float4 * src, + global uint * router, + global float4 * dst, + global int * total_tiles, + uint K, + ushort map_ratio, + uint tile_size +) { + uint k_4 = get_global_id(0); + uint post_router_idx = get_global_id(1); + + if ((k_4 >= (K / 4)) || (post_router_idx >= total_tiles[0] * tile_size)) { + return; + } + + uint router_idx = router[post_router_idx]; + + float4 out = (float4)(0); + if (router_idx != 0xFFFFFFFF) { + ushort activation_idx = router_idx / map_ratio; + out = src[activation_idx * K / 4 + k_4]; + } + + dst[post_router_idx * K / 4 + k_4] = out; +} diff --git a/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl b/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl new file mode 100644 index 00000000000..d9703429b11 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl @@ -0,0 +1,82 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void kernel_moe_histogram( + __global const int * input, + __global int * hist, + uint N, + uint topK, + uint n_experts +) { + uint n = get_global_id(0); + uint k = get_global_id(1); + + if (n >= N || k >= topK) { + return; + } + + int expert_id = input[n * n_experts + k]; + atomic_inc(&hist[expert_id]); +} + +__kernel void kernel_moe_scan( + __global int * hist, + __global int * tile_offset, + __global int * total_tiles, + __global int * slot_counter, + int tile_size, + uint n_experts +) { + int offset = 0; + for (int v = 0; v < n_experts; v++) { + int count = hist[v]; + int tiles = (count + tile_size - 1) / tile_size; + tile_offset[v] = offset; + offset += tiles; + hist[v] = 0; + slot_counter[v] = 0; + } + + *total_tiles = offset; +} + +__kernel void kernel_moe_scatter( + __global const int * input, + __global int * post_router, + __global ushort * emap, + __global const int * tile_offset, + __global int * slot_counter, + int N, + int topK, + uint n_experts +) { + uint n = get_global_id(0); + uint k = get_global_id(1); + + if (n >= N || k >= topK) { + return; + } + + int val = input[n * n_experts + k]; + + int local_slot = atomic_inc(&slot_counter[val]); + + int tile_idx = tile_offset[val] + (local_slot / 32); + int lane = local_slot % 32; + int out_pos = tile_idx * 32 + lane; + + post_router[out_pos] = n * topK + k; + emap[tile_idx] = val; +} + +__kernel void kernel_moe_fill( + __global int * post_router, + __global int * total_tiles, + int tile_size +) { + int tile_id = get_global_id(0); + int vec_id_in_tile = get_global_id(1); + + if (tile_id < total_tiles[0]) { + post_router[tile_id * tile_size + vec_id_in_tile] = 0xFFFFFFFF; + } +} From 9ab94b8cdac1845f492421321aa14cebdc705853 Mon Sep 17 00:00:00 2001 From: JusteLeo Date: Sat, 2 May 2026 15:28:50 +0200 Subject: [PATCH 550/831] ggml-virtgpu: fix circular dependency in headers (llama/22557) --- ggml/src/ggml-virtgpu/virtgpu-shm.cpp | 1 + ggml/src/ggml-virtgpu/virtgpu.cpp | 1 + ggml/src/ggml-virtgpu/virtgpu.h | 2 -- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-virtgpu/virtgpu-shm.cpp b/ggml/src/ggml-virtgpu/virtgpu-shm.cpp index ce6b3b3e607..7f2c2322d91 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-shm.cpp @@ -1,6 +1,7 @@ #include "virtgpu-shm.h" #include "virtgpu.h" +#include "ggml-remoting.h" #include diff --git a/ggml/src/ggml-virtgpu/virtgpu.cpp b/ggml/src/ggml-virtgpu/virtgpu.cpp index a84a77399d9..e3ae1cc75e0 100644 --- a/ggml/src/ggml-virtgpu/virtgpu.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu.cpp @@ -1,4 +1,5 @@ #include "virtgpu.h" +#include "ggml-remoting.h" #include #include diff --git a/ggml/src/ggml-virtgpu/virtgpu.h b/ggml/src/ggml-virtgpu/virtgpu.h index f82d8fb50ba..6b8de583893 100644 --- a/ggml/src/ggml-virtgpu/virtgpu.h +++ b/ggml/src/ggml-virtgpu/virtgpu.h @@ -18,8 +18,6 @@ #include -#include "ggml-remoting.h" - #define VIRGL_RENDERER_UNSTABLE_APIS 1 #include "apir_hw.h" #include From 3bcac0a0c7d73128fa9cf6e65ff1d16ff8438933 Mon Sep 17 00:00:00 2001 From: lucy <154630366+lucyknada@users.noreply.github.com> Date: Sat, 2 May 2026 16:19:25 -0400 Subject: [PATCH 551/831] fix: CUDA device PCI bus ID de-dupe OOMing (ignoring other 3 gpus entirely) (llama/22533) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: CUDA device PCI bus ID detection for multi-GPU de-dupe * HIP, MUSA macros --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++-- ggml/src/ggml-cuda/vendors/hip.h | 1 + ggml/src/ggml-cuda/vendors/musa.h | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fbe0fa06242..8d21b2267f5 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -5431,8 +5431,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; - char pci_bus_id[16] = {}; - snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); + char pci_bus_id[32] = {}; + CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i)); dev_ctx->pci_bus_id = pci_bus_id; dev_ctx->op_offload_min_batch_size = min_batch_size; diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 78ca364d38f..e5d363c65d1 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -55,6 +55,7 @@ #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess #define cudaDeviceGetAttribute hipDeviceGetAttribute +#define cudaDeviceGetPCIBusId hipDeviceGetPCIBusId #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 8aa056e9174..940c34a9fb2 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -39,6 +39,7 @@ #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess +#define cudaDeviceGetPCIBusId musaDeviceGetPCIBusId #define cudaDeviceProp musaDeviceProp #define cudaDeviceSynchronize musaDeviceSynchronize #define cudaError_t musaError_t From d1d0dc2348f6b294598725ef8cd3d40652fb674d Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Sun, 3 May 2026 23:52:53 -0400 Subject: [PATCH 552/831] ggml-webgpu: add layer norm ops (llama/22406) * shader(norm): add layer norm ops * shader(norm): stablize floating point computation with Kahan summation and handle mixed types * shader(norm): remove the non-contiguous strides * shader(norm): use the original implementation rather than the kahan summation --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 32 +++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 + .../ggml-webgpu/wgsl-shaders/row_norm.wgsl | 97 +++++++++++++++---- 3 files changed, 107 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index cff93b8d170..c6dc2c21147 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -228,11 +228,13 @@ struct ggml_webgpu_get_rows_pipeline_key_hash { /** Row Norm **/ struct ggml_webgpu_row_norm_pipeline_key { - ggml_op op; - bool inplace; + ggml_op op; + ggml_type src_type; + ggml_type dst_type; + bool inplace; bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const { - return op == other.op && inplace == other.inplace; + return op == other.op && src_type == other.src_type && dst_type == other.dst_type && inplace == other.inplace; } }; @@ -240,6 +242,8 @@ struct ggml_webgpu_row_norm_pipeline_key_hash { size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.dst_type); ggml_webgpu_hash_combine(seed, key.inplace); return seed; } @@ -1097,6 +1101,8 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_row_norm_pipeline_key key = {}; key.op = context.dst->op; + key.src_type = context.src0->type; + key.dst_type = context.dst->type; key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); auto it = row_norm_pipelines.find(key); @@ -1111,6 +1117,10 @@ class ggml_webgpu_shader_lib { defines.push_back("RMS_NORM"); variant = "rms_norm"; break; + case GGML_OP_NORM: + defines.push_back("NORM"); + variant = "norm"; + break; case GGML_OP_L2_NORM: defines.push_back("L2_NORM"); variant = "l2_norm"; @@ -1124,6 +1134,22 @@ class ggml_webgpu_shader_lib { variant += "_inplace"; } + if (key.src_type == GGML_TYPE_F32) { + defines.push_back("SRC_F32"); + variant += "_src_f32"; + } else if (key.src_type == GGML_TYPE_F16) { + defines.push_back("SRC_F16"); + variant += "_src_f16"; + } + + if (key.dst_type == GGML_TYPE_F32) { + defines.push_back("DST_F32"); + variant += "_dst_f32"; + } else if (key.dst_type == GGML_TYPE_F16) { + defines.push_back("DST_F16"); + variant += "_dst_f16"; + } + const uint32_t row_norm_wg_size = 128u; uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size); defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index cab0aead198..12f60a9900e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2927,6 +2927,7 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, } else { return ggml_webgpu_row_norm(ctx, src0, node); } + case GGML_OP_NORM: case GGML_OP_L2_NORM: return ggml_webgpu_row_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -4071,6 +4072,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } case GGML_OP_RMS_NORM: + case GGML_OP_NORM: case GGML_OP_L2_NORM: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl index bd8d32bded7..5eaf5e7bbe5 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl @@ -1,20 +1,17 @@ -#ifdef INPLACE -fn update(src_offset: u32, dst_offset: u32, scale: f32) { - src[dst_offset] = scale * src[src_offset]; -} +#if defined(SRC_F16) || defined(DST_F16) +enable f16; +#endif -@group(0) @binding(1) -var params: Params; +#ifdef SRC_F16 +#define SRC_TYPE f16 #else -fn update(src_offset: u32, dst_offset: u32, scale: f32) { - dst[dst_offset] = scale * src[src_offset]; -} - -@group(0) @binding(1) -var dst: array; +#define SRC_TYPE f32 +#endif -@group(0) @binding(2) -var params: Params; +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 #endif struct Params { @@ -40,9 +37,20 @@ struct Params { }; @group(0) @binding(0) -var src: array; +var src: array; -var scratch: array; +#ifdef INPLACE +@group(0) @binding(1) +var params: Params; +#else +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; +#endif + +var scratch: array; @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wid: vec3, @@ -65,34 +73,81 @@ fn main(@builtin(workgroup_id) wid: vec3, if (col >= params.ne0) { break; } - sum += pow(src[i_src_row + col], 2.0); + let v = f32(src[i_src_row + col]); +#ifdef NORM + sum += v; +#else + sum += v * v; +#endif col += WG_SIZE; } scratch[lid.x] = sum; workgroupBarrier(); - var offset: u32 = WG_SIZE / 2; + + var offset: u32 = WG_SIZE / 2u; while (offset > 0) { if (lid.x < offset) { scratch[lid.x] += scratch[lid.x + offset]; } - offset = offset / 2; + offset /= 2u; workgroupBarrier(); } sum = scratch[0]; -#ifdef RMS_NORM +#ifdef NORM + let mean = sum / f32(params.ne0); + var sq_sum = 0.0f; + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + let v = f32(src[i_src_row + col]); + let d = v - mean; + sq_sum += d * d; + col += WG_SIZE; + } + + workgroupBarrier(); + scratch[lid.x] = sq_sum; + workgroupBarrier(); + offset = WG_SIZE / 2u; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset /= 2u; + workgroupBarrier(); + } + + let variance = scratch[0] / f32(params.ne0); + let scale = 1.0 / sqrt(variance + params.eps); +#elif defined(RMS_NORM) let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); #elif defined(L2_NORM) let scale = 1.0/max(sqrt(sum), params.eps); #endif +#ifdef NORM + let mean_val = mean; +#else + let mean_val = 0.0f; +#endif + col = lid.x; for (var j: u32 = 0; j < elems; j++) { if (col >= params.ne0) { break; } - update(i_src_row + col, i_dst_row + col, scale); + let i_src = i_src_row + col; + let i_dst = i_dst_row + col; + let v = src[i_src]; +#ifdef INPLACE + src[i_dst] = scale * (v - mean_val); +#else + dst[i_dst] = scale * (v - mean_val); +#endif col += WG_SIZE; } } From 0fffe2cdb87a4bf0b91dde4068858f6b7ef0838a Mon Sep 17 00:00:00 2001 From: Atomic-Germ <97569476+Atomic-Germ@users.noreply.github.com> Date: Sun, 3 May 2026 22:49:29 -0700 Subject: [PATCH 553/831] vulkan: delete dead GGML_VK_MAX_NODES def (llama/22621) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c2f1883328f..423e01dbff1 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -111,8 +111,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 -#define GGML_VK_MAX_NODES 8192 - #define VK_CHECK(err, msg) \ do { \ vk::Result err_ = (err); \ From 36a83b84bb29651e9802b8d287d3941240fc860b Mon Sep 17 00:00:00 2001 From: leonardHONG <2695316095@qq.com> Date: Mon, 4 May 2026 22:24:05 +0800 Subject: [PATCH 554/831] CUDA: use fastdiv for batch index split in get_rows (llama/22650) --- ggml/src/ggml-cuda/getrows.cu | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index e99cba63d34..36b840e8148 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -6,17 +6,18 @@ template static __global__ void k_get_rows( const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ - /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ + /*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. const int i10 = blockIdx.x; - const int i11 = z / ne12; // TODO fastdiv - const int i12 = z % ne12; + const uint2 dm = fast_div_modulo((uint32_t)z, ne12_fdv); + const int i11 = dm.x; + const int i12 = dm.y; const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; @@ -42,17 +43,18 @@ template static __global__ void k_get_rows_float( const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ - /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ + /*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. const int i10 = blockIdx.x; - const int i11 = z / ne12; // TODO fastdiv - const int i12 = z % ne12; + const uint2 dm = fast_div_modulo((uint32_t)z, ne12_fdv); + const int i11 = dm.x; + const int i12 = dm.y; if (i00 >= ne00) { return; @@ -115,10 +117,14 @@ static void get_rows_cuda_q( GGML_ASSERT(ne00 % 2 == 0); + GGML_ASSERT(ne12 > 0); + GGML_ASSERT(ne11 <= std::numeric_limits::max() / ne12); + const uint3 ne12_fdv = init_fastdiv_values(ne12); + k_get_rows<<>>( src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ - /*ne10,*/ ne11, ne12, /*ne13,*/ + /*ne10,*/ ne11, ne12_fdv, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); @@ -146,10 +152,14 @@ static void get_rows_cuda_float( const size_t s12 = nb12 / sizeof(int32_t); // const size_t s13 = nb13 / sizeof(int32_t); + GGML_ASSERT(ne12 > 0); + GGML_ASSERT(ne11 <= std::numeric_limits::max() / ne12); + const uint3 ne12_fdv = init_fastdiv_values(ne12); + k_get_rows_float<<>>( src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ - /*ne10,*/ ne11, ne12, /*ne13,*/ + /*ne10,*/ ne11, ne12_fdv, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); From 254f951db8ffd75512c87dc3a234a4b84bf8c6ad Mon Sep 17 00:00:00 2001 From: Charles Xu Date: Mon, 4 May 2026 21:13:31 +0200 Subject: [PATCH 555/831] kleidiai : update to v1.24.0 and use release archive (llama/22549) --- ggml/src/ggml-cpu/CMakeLists.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index c1c225f0197..869c7b238bf 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -578,13 +578,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_TAG "v1.22.0") - set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1") + set(KLEIDIAI_COMMIT_TAG "v1.24.0") + set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/releases/download/${KLEIDIAI_COMMIT_TAG}/kleidiai-${KLEIDIAI_COMMIT_TAG}-src.tar.gz") + set(KLEIDIAI_RELEASE_ARCHIVE_MD5 "2f02ebe29573d45813e671eb304f2a00") set(KLEIDIAI_FETCH_ARGS URL ${KLEIDIAI_DOWNLOAD_URL} - URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5} + URL_HASH MD5=${KLEIDIAI_RELEASE_ARCHIVE_MD5} ) if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW) From 4794432337769d04ffb7d443747d69e6c0fd7469 Mon Sep 17 00:00:00 2001 From: Ismail <115064057+AlrIsmail@users.noreply.github.com> Date: Tue, 5 May 2026 04:05:05 +0200 Subject: [PATCH 556/831] ggml : implement fast walsh-hadamard transform for kv rotation (#21352) (llama/22631) --- ggml/include/ggml.h | 11 +++++ ggml/src/ggml-cpu/ggml-cpu.c | 6 +++ ggml/src/ggml-cpu/ops.cpp | 88 ++++++++++++++++++++++++++++++++++++ ggml/src/ggml-cpu/ops.h | 1 + ggml/src/ggml.c | 10 ++++ 5 files changed, 116 insertions(+) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 703e3783136..3357a0d9985 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -438,6 +438,12 @@ extern "C" { GGML_PREC_F32 = 10, }; + // op hint + enum ggml_op_hint { + GGML_HINT_NONE = 0, + GGML_HINT_SRC0_IS_HADAMARD = 1, + }; + // model file types enum ggml_ftype { GGML_FTYPE_UNKNOWN = -1, @@ -1419,6 +1425,11 @@ extern "C" { struct ggml_tensor * a, enum ggml_prec prec); + // change the hint of a matrix multiplication + GGML_API void ggml_mul_mat_set_hint( + struct ggml_tensor * a, + enum ggml_op_hint hint); + // indirect matrix multiplication GGML_API struct ggml_tensor * ggml_mul_mat_id( struct ggml_context * ctx, diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 2b3eb5b5ce6..2d6cc1fcd46 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1245,6 +1245,12 @@ void ggml_compute_forward_mul_mat( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; + const int32_t hint = ggml_get_op_params_i32(dst, 1); + if (hint == GGML_HINT_SRC0_IS_HADAMARD && !params->use_ref) { + ggml_compute_forward_fwht(params, dst); + return; + } + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index a9bc21da6f0..211f1ba1b2f 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -11212,3 +11212,91 @@ void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_ } } } + +static void ggml_compute_forward_fwht_f32(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t n = ne10; + GGML_ASSERT((n & (n - 1)) == 0); // must be power of 2 + + const int64_t nr = ne11 * ne12 * ne13; + const int64_t rows_per_thread = (nr + nth - 1) / nth; + const int64_t start_row = ith * rows_per_thread; + const int64_t end_row = MIN(start_row + rows_per_thread, nr); + + const float scale = 1.0f / sqrtf((float)n); + +#if defined(GGML_SIMD) + const GGML_F32_VEC v_minus_one = GGML_F32_VEC_SET1(-1.0f); +#endif + + for (int64_t r = start_row; r < end_row; r++) { + const int64_t i13 = r / (ne11 * ne12); + const int64_t i12 = (r - i13 * ne11 * ne12) / ne11; + const int64_t i11 = r - i13 * ne11 * ne12 - i12 * ne11; + + const float * src_row = (const float *) ((const char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13); + float * dst_row = (float *) ((char *) dst->data + i11 * nb1 + i12 * nb2 + i13 * nb3); + + for (int64_t j = 0; j < n; j++) { + dst_row[j] = src_row[j] * scale; + } + + // Scalar passes +#if defined(GGML_SIMD) + const int step = GGML_F32_EPR; +#else + const int step = n; +#endif + for (int64_t len = 1; len < step && len < n; len <<= 1) { + for (int64_t i = 0; i < n; i += 2 * len) { + for (int64_t j = 0; j < len; j++) { + float u = dst_row[i + j]; + float v = dst_row[i + len + j]; + dst_row[i + j] = u + v; + dst_row[i + len + j] = u - v; + } + } + } + + // SIMD passes using GGML_F32_VEC_* macros for multi-architecture support +#if defined(GGML_SIMD) + for (int64_t len = step; len < n; len <<= 1) { + for (int64_t i = 0; i < n; i += 2 * len) { + for (int64_t j = 0; j < len; j += step) { + GGML_F32_VEC u = GGML_F32_VEC_LOAD(dst_row + i + j); + GGML_F32_VEC v = GGML_F32_VEC_LOAD(dst_row + i + len + j); + + GGML_F32_VEC_STORE(dst_row + i + j, GGML_F32_VEC_ADD(u, v)); + GGML_F32_VEC_STORE(dst_row + i + len + j, GGML_F32_VEC_FMA(u, v, v_minus_one)); + } + } + } +#endif + } +} + +void ggml_compute_forward_fwht(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src1 = dst->src[1]; + + switch (src1->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_fwht_f32(params, dst); + } + break; + default: + { + GGML_ABORT("fatal error - fwht is F32 only"); + } + } +} diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 3fa1443abc4..29efdeee37f 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -111,6 +111,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_fwht(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 81343eeb14c..191cf2fa106 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3264,6 +3264,16 @@ void ggml_mul_mat_set_prec( ggml_set_op_params_i32(a, 0, prec_i32); } +void ggml_mul_mat_set_hint( + struct ggml_tensor * a, + enum ggml_op_hint hint) { + GGML_ASSERT(a->op == GGML_OP_MUL_MAT); + + const int32_t hint_i32 = (int32_t) hint; + + ggml_set_op_params_i32(a, 1, hint_i32); +} + // ggml_mul_mat_id /* From 6f6103f6d0034945a9377d16e29cf0d3ec2b4c35 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 May 2026 06:35:07 +0300 Subject: [PATCH 557/831] llama : add option to save memory in device buffers (llama/22679) * llama : add option to save memory in device buffers * tests : extend llama-save-load-state --- ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 42 +++++++++++++++++++++++++ ggml/src/ggml-metal/ggml-metal.cpp | 19 ++++++----- 3 files changed, 54 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index a6c1dab5515..4718ca083b0 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -282,6 +282,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf); void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); void ggml_metal_buffer_set_tensor (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); void ggml_metal_buffer_get_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); +bool ggml_metal_buffer_cpy_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * src, struct ggml_tensor * dst); void ggml_metal_buffer_clear (ggml_metal_buffer_t buf, uint8_t value); // finds the Metal buffer that contains the tensor data on the GPU device diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index fe90aafe7bc..fab7891c008 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1,6 +1,7 @@ #import "ggml-metal-device.h" #import "ggml-impl.h" +#import "ggml-backend-impl.h" #include @@ -1737,6 +1738,47 @@ void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_ten } } +bool ggml_metal_buffer_cpy_tensor(ggml_metal_buffer_t buf_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { + ggml_metal_buffer_t buf_src = (ggml_metal_buffer_t)src->buffer->context; + + const size_t size = ggml_nbytes(src); + + // if both buffers are shared, we can use memcpy directly + if (buf_dst->is_shared && buf_src->is_shared) { + memcpy(dst->data, src->data, size); + return true; + } + + // for private buffers, we need to use Metal blit commands + @autoreleasepool { + struct ggml_metal_buffer_id bid_src = ggml_metal_buffer_get_id(buf_src, src); + struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf_dst, dst); + + if (bid_src.metal == nil || bid_dst.metal == nil) { + return false; + } + + id cmd_buf = [buf_dst->dev->mtl_queue commandBufferWithUnretainedReferences]; + + { + id encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:bid_src.metal + sourceOffset:bid_src.offs + toBuffer:bid_dst.metal + destinationOffset:bid_dst.offs + size:size]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } + + return true; +} + void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) { if (buf->is_shared) { memset(buf->all_data, value, buf->all_size); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index cc329d67594..35774254983 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -17,6 +17,9 @@ // note: can be overridden with GGML_METAL_DEVICES env to simulate virtual devices static int g_devices = 1; +// forward declaration +static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer); + //////////////////////////////////////////////////////////////////////////////// // backend interface //////////////////////////////////////////////////////////////////////////////// @@ -68,11 +71,11 @@ static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t bu GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); - GGML_UNUSED(buffer); - GGML_UNUSED(src); - GGML_UNUSED(dst); + if (!ggml_backend_buffer_is_metal(src->buffer)) { + return false; + } - return false; + return ggml_metal_buffer_cpy_tensor(ctx, src, dst); } static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) { @@ -144,11 +147,11 @@ static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t b GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); - GGML_UNUSED(buffer); - GGML_UNUSED(src); - GGML_UNUSED(dst); + if (!ggml_backend_buffer_is_metal(src->buffer)) { + return false; + } - return false; + return ggml_metal_buffer_cpy_tensor(ctx, src, dst); } static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) { From 716acdb08212087ca61b56f569e354068e6613eb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 5 May 2026 13:14:32 +0300 Subject: [PATCH 558/831] ggml : bump version to 0.11.0 (ggml/1478) --- ggml/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index c97f681988b..8dd4d64063f 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,8 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 10) -set(GGML_VERSION_PATCH 2) +set(GGML_VERSION_MINOR 11) +set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 0bafd810b60a76bdb9e9784759cd22511d779c40 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Tue, 5 May 2026 13:47:13 +0300 Subject: [PATCH 559/831] rpc : use graph uid instead of graph cache (llama/22701) Store the last graph uid and compare against it to determine if the same graph is being computed. --- ggml/src/ggml-rpc/ggml-rpc.cpp | 38 +++++++--------------------------- 1 file changed, 7 insertions(+), 31 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 7176d2feef9..1cb8f563d85 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -207,35 +207,11 @@ struct ggml_backend_rpc_buffer_type_context { size_t max_size; }; -struct graph_cache { - - bool is_cached(const ggml_cgraph * cgraph) { - if ((int)last_graph.size() != cgraph->n_nodes) { - return false; - } - for (int i = 0; i < cgraph->n_nodes; i++) { - if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) { - return false; - } - } - return true; - } - - void add(const ggml_cgraph * cgraph) { - last_graph.resize(cgraph->n_nodes); - for (int i = 0; i < cgraph->n_nodes; i++) { - memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)); - } - } - - std::vector last_graph; -}; - struct ggml_backend_rpc_context { std::string endpoint; uint32_t device; std::string name; - graph_cache gc; + uint64_t last_graph_uid; }; struct ggml_backend_rpc_buffer_context { @@ -717,7 +693,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; GGML_ASSERT(cgraph->n_nodes > 0); - bool reuse = rpc_ctx->gc.is_cached(cgraph); + bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid; if (reuse) { rpc_msg_graph_recompute_req request; request.device = rpc_ctx->device; @@ -725,7 +701,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request)); RPC_STATUS_ASSERT(status); } else { - rpc_ctx->gc.add(cgraph); + rpc_ctx->last_graph_uid = cgraph->uid; std::vector input; serialize_graph(rpc_ctx->device, cgraph, input); auto sock = get_socket(rpc_ctx->endpoint); @@ -791,10 +767,10 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, u ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) { std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]"; ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { - /* .endpoint = */ endpoint, - /* .device = */ device, - /* .name = */ dev_name, - /* .gc = */ {}, + /* .endpoint = */ endpoint, + /* .device = */ device, + /* .name = */ dev_name, + /* .last_graph_uid = */ 0, }; auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_t backend = new ggml_backend { From f83b6bdc44c5e44607cae112cfa883686cd57271 Mon Sep 17 00:00:00 2001 From: lhez Date: Sun, 10 May 2026 14:52:20 +0300 Subject: [PATCH 560/831] opencl: refactor Adreno q4_0 (llama/22335) --- ggml/src/ggml-opencl/CMakeLists.txt | 10 +- ggml/src/ggml-opencl/ggml-opencl.cpp | 944 +++++++----------- ...b_Bi_8x4.cl => gemm_noshuffle_q4_0_f32.cl} | 2 +- ..._f32_8x4.cl => gemm_noshuffle_q8_0_f32.cl} | 2 +- ..._general.cl => gemv_noshuffle_q4_0_f32.cl} | 10 +- ...fle.cl => gemv_noshuffle_q4_0_f32_spec.cl} | 10 +- ...q8_0_f32.cl => gemv_noshuffle_q8_0_f32.cl} | 0 7 files changed, 355 insertions(+), 623 deletions(-) rename ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl => gemm_noshuffle_q4_0_f32.cl} (99%) rename ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl => gemm_noshuffle_q8_0_f32.cl} (98%) rename ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl => gemv_noshuffle_q4_0_f32.cl} (98%) rename ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl => gemv_noshuffle_q4_0_f32_spec.cl} (98%) rename ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl => gemv_noshuffle_q8_0_f32.cl} (100%) diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 35d425a431f..0a45a4daa13 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -66,8 +66,6 @@ set(GGML_OPENCL_KERNELS diag div gelu - gemv_noshuffle_general - gemv_noshuffle get_rows glu group_norm @@ -75,7 +73,6 @@ set(GGML_OPENCL_KERNELS im2col_f32 im2col_f16 mean - mul_mat_Ab_Bi_8x4 mul_mv_f16_f16 mul_mv_f16_f32_1row mul_mv_f16_f32_l4 @@ -120,12 +117,15 @@ set(GGML_OPENCL_KERNELS mul_mm_q4_k_f32_l4_lm mul_mm_q5_k_f32_l4_lm mul_mm_q6_k_f32_l4_lm - mul_mm_q8_0_f32_8x4 + gemv_noshuffle_q4_0_f32 + gemv_noshuffle_q4_0_f32_spec + gemm_noshuffle_q4_0_f32 gemv_noshuffle_q4_1_f32 gemm_noshuffle_q4_1_f32 gemv_noshuffle_iq4_nl_f32 gemm_noshuffle_iq4_nl_f32 - gemv_noshuffle_general_q8_0_f32 + gemv_noshuffle_q8_0_f32 + gemm_noshuffle_q8_0_f32 gemv_noshuffle_q4_k_f32 gemm_noshuffle_q4_k_f32 gemv_noshuffle_q6_k_f32 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 74948c27e4e..8c7bf98c16f 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -731,22 +731,16 @@ struct ggml_backend_opencl_context { cl_kernel kernel_transpose_16_4x1; // Gemm and Gemv related programs, kernels, etc - cl_program program_CL_gemm; - cl_program program_CL_gemv_general; - cl_program program_CL_gemv_4096_1_11008; - cl_program program_CL_gemv_4096_1_4096; - cl_program program_CL_gemv_11008_1_4096; - cl_program program_CL_gemv_32000_1_4096; - cl_kernel CL_mul_mat_Ab_Bi_8x4; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; + cl_kernel kernel_gemm_noshuffle_q4_0_f32; + cl_kernel kernel_gemv_noshuffle_q4_0_f32; + cl_kernel kernel_gemv_noshuffle_q4_0_f32_4096_1_11008; + cl_kernel kernel_gemv_noshuffle_q4_0_f32_4096_1_4096; + cl_kernel kernel_gemv_noshuffle_q4_0_f32_11008_1_4096; + cl_kernel kernel_gemv_noshuffle_q4_0_f32_32000_1_4096; cl_kernel kernel_gemv_noshuffle_q4_1_f32; cl_kernel kernel_gemm_noshuffle_q4_1_f32; - cl_kernel kernel_mul_mm_q8_0_f32_8x4; - cl_kernel CL_mul_mat_vec_q8_0_f32; + cl_kernel kernel_gemm_noshuffle_q8_0_f32; + cl_kernel kernel_gemv_noshuffle_q8_0_f32; cl_kernel kernel_gemv_noshuffle_q4_k_f32; cl_kernel kernel_gemm_noshuffle_q4_k_f32; cl_kernel kernel_gemv_noshuffle_q6_K_f32; @@ -2578,21 +2572,22 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; } #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src_CL_gemv_general { - #include "gemv_noshuffle_general.cl.h" + #include "gemv_noshuffle_q4_0_f32.cl.h" }; #else - const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_general.cl"); + const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_q4_0_f32.cl"); #endif - backend_ctx->program_CL_gemv_general = build_program_from_source( + cl_program prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -2606,20 +2601,21 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; } #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src_CL_gemv { - #include "gemv_noshuffle.cl.h" + #include "gemv_noshuffle_q4_0_f32_spec.cl.h" }; #else - const std::string kernel_src_CL_gemv = read_file("gemv_noshuffle.cl"); + const std::string kernel_src_CL_gemv = read_file("gemv_noshuffle_q4_0_f32_spec.cl"); #endif - backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source( + cl_program prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32_4096_1_4096 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); // Gemv 2048, 16384 @@ -2630,12 +2626,13 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; } - backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source( + prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32_4096_1_11008 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); // Gemv 5504, 44032 @@ -2646,12 +2643,13 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -DSIMDGROUP_WIDTH=" + std::to_string(backend_ctx->adreno_wave_size); if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; } - backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source( + prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32_11008_1_4096 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); // Gemv 16000, 128000 @@ -2663,12 +2661,13 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve std::to_string(backend_ctx->adreno_wave_size); if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; } - backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source( + prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32_32000_1_4096 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -2676,13 +2675,14 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src_CL_gemm { - #include "mul_mat_Ab_Bi_8x4.cl.h" + #include "gemm_noshuffle_q4_0_f32.cl.h" }; #else - const std::string kernel_src_CL_gemm = read_file("mul_mat_Ab_Bi_8x4.cl"); + const std::string kernel_src_CL_gemm = read_file("gemm_noshuffle_q4_0_f32.cl"); #endif - backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_CL_gemm.c_str(), compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_CL_gemm.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_0_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -2767,14 +2767,15 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve // mul_mm_q8_0_f32_8x4 { #ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_q8_8x4_gemm { - #include "mul_mm_q8_0_f32_8x4.cl.h" + const std::string kernel_src { + #include "gemm_noshuffle_q8_0_f32.cl.h" }; #else - const std::string kernel_src_q8_8x4_gemm = read_file("mul_mm_q8_0_f32_8x4.cl"); + const std::string kernel_src = read_file("gemm_noshuffle_q8_0_f32.cl"); #endif - backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_q8_8x4_gemm.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mm_q8_0_f32_8x4", &err), err)); + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q8_0_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q8_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -2790,16 +2791,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src_CL_gemv_general { - #include "gemv_noshuffle_general_q8_0_f32.cl.h" + #include "gemv_noshuffle_q8_0_f32.cl.h" }; #else - const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_general_q8_0_f32.cl"); + const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_q8_0_f32.cl"); #endif cl_program prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q8_0_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q8_0_f32", &err), err)); CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -4937,164 +4938,15 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, // Only do transpose for large, non batched matrix // TODO: use preallocated images instead of sub-buffer then image if (use_adreno_kernels(backend_ctx, tensor)) { - // <----------------------------------------------------------------------------------> // - // start transpose - // <----------------------------------------------------------------------------------> // - int M = tensor->ne[1]; // ne01 - int K = tensor->ne[0]; // ne00 - - //For matrix-vector multiplication kernel, we assume K is a multiple of 32 - GGML_ASSERT(K % 32 == 0); - //For transpose kernels, we assume K is a multiple of 4 (satisfied by prior assert), and M is a multiple of 4 - GGML_ASSERT(M % 4 == 0); - - // transpose is out of place, so we need to allocate transposed buffers - // <----------------------------------------------------------------------------------> // - // use sub_buffer of max buffer size instead - - size_t q_size_bytes = K * M / 8 * sizeof(float); - backend_ctx->prealloc_quant_trans.allocate(context, q_size_bytes); - - cl_buffer_region region; - region.origin = 0; - region.size = q_size_bytes; - cl_mem qT_d = clCreateSubBuffer( - backend_ctx->prealloc_quant_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &err); - CL_CHECK(err); - - bool K_tile_trans = true; - if ((K / 32) % 4 != 0){ - K_tile_trans =false; - } - - size_t d_size_bytes = M * (K / 32) * 2; - backend_ctx->prealloc_scales_trans.allocate(context, d_size_bytes); - - region.origin = 0; - region.size = d_size_bytes; - cl_mem dT_d = clCreateSubBuffer( - backend_ctx->prealloc_scales_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &err); - CL_CHECK(err); - - // <----------------------------------------------------------------------------------> // - - - // create images from the buffers - // <----------------------------------------------------------------------------------> // - cl_mem q_d_image1D; - cl_mem d_d_image1D; - cl_mem qT_d_image1D; - cl_mem dT_d_image1D; - - cl_image_format img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - cl_image_desc img_desc_1d; - - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4 / 4; - img_desc_1d.buffer = extra->q; - q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4 / 4; - img_desc_1d.buffer = qT_d; - qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - if (K_tile_trans) { - img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - img_desc_1d.image_width = M * K / 32 / 4; - } else { - img_fmt_1d = { CL_R, CL_HALF_FLOAT }; - img_desc_1d.image_width = M * K / 32; - } - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.buffer = extra->d; - d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - - img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32 / 4; - img_desc_1d.buffer = dT_d; - dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); - // <----------------------------------------------------------------------------------> // - - // set up and call the transpose kernels - // <----------------------------------------------------------------------------------> // - // weights - int height_q = M / 4; - int width_q = K / 4 / 4; - kernel = backend_ctx->kernel_transpose_16; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_q)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_q)); - - size_t local_size_q[3] = {4, 16, 1}; - size_t global_size_q[3] = {static_cast(width_q), static_cast(height_q), 1}; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - - // scales - int height_s = M / 4; - int width_s = K / 32 / 4; - - kernel = backend_ctx->kernel_transpose_16; - if (!K_tile_trans) { - kernel = backend_ctx->kernel_transpose_16_4x1; - width_s = K / 32; - } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s)); - - size_t local_size_s[3] = {4, 16, 1}; - size_t global_size_s[3] = {static_cast(width_s), static_cast(height_s), 1}; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - // <----------------------------------------------------------------------------------> // + int M = tensor->ne[1]; + int K = tensor->ne[0]; - // copy transposed buffer contents to original buffers - // <----------------------------------------------------------------------------------> // - // weights - CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); + GGML_ASSERT(K % 32 == 0); - // scales - CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - // <----------------------------------------------------------------------------------> // - - // deallocate transpose buffers - // <----------------------------------------------------------------------------------> // - CL_CHECK(clReleaseMemObject(qT_d)); - CL_CHECK(clReleaseMemObject(dT_d)); - - // deallocate temporary images - CL_CHECK(clReleaseMemObject(q_d_image1D)); - CL_CHECK(clReleaseMemObject(d_d_image1D)); - CL_CHECK(clReleaseMemObject(qT_d_image1D)); - CL_CHECK(clReleaseMemObject(dT_d_image1D)); - // <----------------------------------------------------------------------------------> // - // end transpose - // <----------------------------------------------------------------------------------> // + // Transpose q as ushort + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); } #endif // GGML_OPENCL_USE_ADRENO_KERNELS @@ -5820,8 +5672,9 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_kernels(backend_ctx, tensor)) { - cl_int err; - cl_kernel kernel; + ggml_cl_buffer buf_trans_q; + ggml_cl_buffer buf_trans_d; + ggml_cl_buffer buf_unpacked; cl_int M = tensor->ne[1]; // ne01 cl_int K = tensor->ne[0]; // ne00 @@ -5833,46 +5686,12 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); - cl_mem buf_trans_q; - cl_mem buf_trans_d; - - CL_CHECK((buf_trans_q = clCreateBuffer(context, CL_MEM_READ_WRITE, - size_q, NULL, &err), err)); - CL_CHECK((buf_trans_d = clCreateBuffer(context, CL_MEM_READ_WRITE, - size_d, NULL, &err), err)); - - kernel = backend_ctx->kernel_transpose_16_buf; - - // transpose q back - cl_int stride_k_q = K/4; - size_t local_size_q[3] = {64, 1, 1}; - size_t global_size_q[3] = {(size_t)M, (size_t)stride_k_q, 1}; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &M)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &stride_k_q)); - - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, - global_size_q, local_size_q, 0, NULL, NULL)); - - // transpose scales back - cl_int stride_k_d = K/32; - size_t local_size_d[3] = {64, 1, 1}; - size_t global_size_d[3] = {(size_t)M, (size_t)stride_k_d, 1}; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->d)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &M)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &stride_k_d)); - - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, - global_size_d, local_size_d, 0, NULL, NULL)); + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); - // unpack - cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, - ggml_nbytes(tensor), NULL, &err); - CL_CHECK(err); + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); cl_uchar mask_0F = 0x0F; cl_uchar mask_F0 = 0xF0; @@ -5880,25 +5699,15 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {1, 1, 1}; - kernel = backend_ctx->kernel_restore_block_q4_0_noshuffle; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_0_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_unpacked.buffer)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0)); - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, - global_work_size, local_work_size, 0, NULL, NULL)); - - // read back to host - CL_CHECK(clEnqueueReadBuffer( - queue, data_device, CL_TRUE, offset, - size, data, 0, NULL, NULL)); - - CL_CHECK(clReleaseMemObject(data_device)); - CL_CHECK(clReleaseMemObject(buf_trans_q)); - CL_CHECK(clReleaseMemObject(buf_trans_d)); - + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); return; } #endif @@ -10073,6 +9882,235 @@ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_ten CL_CHECK(clReleaseMemObject(D_sub_buffer)); } +static void ggml_cl_mul_mat_q4_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + + const int ne10 = src1->ne[0]; + const int ne12 = src1->ne[2]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q4_0->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32; + if (M == 4096 && K == 4096) { + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32_4096_1_4096; + } else if (M == 4096 && K == 11008) { + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32_4096_1_11008; + } else if (M == 11008 && K == 4096) { + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32_11008_1_4096; + } else if (M == 32000 && K == 4096) { + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32_32000_1_4096; + } + + int r2 = 1; + int r3 = 1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + cl_mem d_sub_buf = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for output + region.origin = extrad->offset; // Specify the starting offset (in bytes) + region.size = M * N * sizeof(float); // Specify the size of the sub-buffer + CL_CHECK((d_sub_buf = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { + local_work_size_t[0]=4; + local_work_size_t[1]=8; + } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { + local_work_size_t[0]=2; + local_work_size_t[1]=8; + } else if(ne0 == 4096 && ne1 == 128 && ne10 == 11008) { + local_work_size_t[0]=1; + local_work_size_t[1]=8; + } else if(ne0 == 32000 && ne1 == 128 && ne10 == 4096) { + local_work_size_t[0]=2; + local_work_size_t[1]=8; + } + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q4_0_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &d_sub_buf)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 1; + local_work_size[1] = 128; + } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } else if (ne0 == 4096 && ne1 == 128 && ne10 == 11008) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } else if (ne0 == 32000 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(d_sub_buf)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); @@ -10495,7 +10533,7 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t img_desc.buffer = b_sub_buf; CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - kernel = backend_ctx->CL_mul_mat_vec_q8_0_f32; + kernel = backend_ctx->kernel_gemv_noshuffle_q8_0_f32; int r2 = 1; int r3 = 1; @@ -10585,7 +10623,7 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); // gemm - kernel = backend_ctx->kernel_mul_mm_q8_0_f32_8x4; + kernel = backend_ctx->kernel_gemm_noshuffle_q8_0_f32; int padded_N = N + padding; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q)); @@ -11195,8 +11233,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type src0t = src0->type; + const enum ggml_type src1t = src1->type; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -11219,28 +11257,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; #endif - const int ne00 = src0 ? src0->ne[0] : 0; - const int ne01 = src0 ? src0->ne[1] : 0; - const int ne02 = src0 ? src0->ne[2] : 0; - const int ne03 = src0 ? src0->ne[3] : 0; - - const cl_ulong nb00 = src0 ? src0->nb[0] : 0; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const cl_ulong nb03 = src0 ? src0->nb[3] : 0; - - const int ne10 = src1 ? src1->ne[0] : 0; - const int ne11 = src1 ? src1->ne[1] : 0; - const int ne12 = src1 ? src1->ne[2] : 0; - const int ne13 = src1 ? src1->ne[3] : 0; - - const cl_ulong nb10 = src1 ? src1->nb[0] : 0; - const cl_ulong nb11 = src1 ? src1->nb[1] : 0; - const cl_ulong nb12 = src1 ? src1->nb[2] : 0; - const cl_ulong nb13 = src1 ? src1->nb[3] : 0; - - const int ne0 = dst ? dst->ne[0] : 0; - const int ne1 = dst ? dst->ne[1] : 0; + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne1, src1, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); int r2 = ne12/ne02; int r3 = ne13/ne03; @@ -11256,8 +11278,6 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co cl_kernel kernel; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - cl_context context = backend_ctx->context; - if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){ if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0 && // dst is wrapped with image1d_buffer, the size limit applies, also src0 @@ -11284,340 +11304,52 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) { + // NOTE: Kernels using image1d_buffer_t (e.g., src0_q) would normally require + // a limit check, but q4_0 / q4_1 tensors are very unlikely to exceed that + // limit, so the check is omitted. - // init CL objects - // <--------------------------------------------> // - cl_int status; - cl_image_format img_fmt_1d; - cl_image_desc img_desc_1d; - cl_buffer_region region; - cl_mem A_image1d = nullptr; - cl_mem B_image1d = nullptr; - cl_mem B_sub_buffer = nullptr; - cl_mem C_d = nullptr; - // for B transpose - cl_mem B_d = nullptr; - cl_mem B_d_input_image = nullptr; - // <--------------------------------------------> // - - // define matrix dimensions - // <--------------------------------------------> // - int M = ne01; - int N = ne1; - int K = ne00; - int padding; - // <--------------------------------------------> // - - // NOTE: Kernels using image1d_buffer_t (e.g., src0_q) would normally require - // a limit check, but q4_0 / q4_1 tensors are very unlikely to exceed that - // limit, so the check is omitted. + // q4_0 x fp32 + if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q4_0_f32_adreno(backend, src0, src1, dst); + return; + } - // q4_1 x fp32 - if (src0t == GGML_TYPE_Q4_1 && src1t == GGML_TYPE_F32) { + // q4_1 x fp32 + if (src0t == GGML_TYPE_Q4_1 && src1t == GGML_TYPE_F32) { ggml_cl_mul_mat_q4_1_f32_adreno(backend, src0, src1, dst); return; - } - - // iq4_nl x fp32 - if (src0t == GGML_TYPE_IQ4_NL && src1t == GGML_TYPE_F32) { - ggml_cl_mul_mat_iq4_nl_f32_adreno(backend, src0, src1, dst); - return; - } + } - // q8_0 x fp32 - if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 && - enable_adreno_trans_weight(backend_ctx, src0)) { - ggml_cl_mul_mat_q8_0_f32_adreno(backend, src0, src1, dst); + // iq4_nl x fp32 + if (src0t == GGML_TYPE_IQ4_NL && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_iq4_nl_f32_adreno(backend, src0, src1, dst); return; - } + } + + // q8_0 x fp32 + if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 && + enable_adreno_trans_weight(backend_ctx, src0)) { + ggml_cl_mul_mat_q8_0_f32_adreno(backend, src0, src1, dst); + return; + } - // q4_k x fp32 - if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32) { + // q4_k x fp32 + if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32) { ggml_cl_mul_mat_q4_k_f32_adreno(backend, src0, src1, dst); return; - } - - // q6_K x fp32 - if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) { - ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst); - return; - } - - // q5_K x fp32 - if (src0t == GGML_TYPE_Q5_K && src1t == GGML_TYPE_F32) { - ggml_cl_mul_mat_q5_K_f32_adreno(backend, src0, src1, dst); - return; - } - - // q4_0 x fp32 - if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { - // TODO: remove duplicate definitions of image description + format -- move to top - - // create an image for A - // <--------------------------------------------> // - if (N == 1) { - img_fmt_1d = { CL_R, CL_UNSIGNED_INT32}; - } else { - img_fmt_1d = { CL_R, CL_FLOAT}; - } - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 2 / 4; // Divide by 4 for char -> float - img_desc_1d.buffer = extra0_q4_0->q; - A_image1d = clCreateImage( - context, - CL_MEM_READ_ONLY, - &img_fmt_1d, - &img_desc_1d, - NULL, - &status); - CL_CHECK(status); - // <--------------------------------------------> // - - - // create a sub_buffer for B - // <--------------------------------------------> // - region.origin = (extra1->offset); - region.size = K * N * sizeof(float); - B_sub_buffer = clCreateSubBuffer( - extra1->data_device, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &status); - CL_CHECK(status); - // <--------------------------------------------> // - - // transpose activation for Skyler's gemm - if (N != 1) { - //how many extra elements beyond multiple of 8 - int extra_elements = N % 8; - - //how much padding to add - padding = 0; - if (extra_elements > 0){ - padding = 8 - extra_elements; - } - - // Specify the starting offset (in bytes) - region.origin = 0; - // Specify the size of the sub-buffer (divide by 2 for FP16) - region.size = K * (N + padding) * sizeof(float)/2; - backend_ctx->prealloc_act_trans.allocate(context, region.size); - - B_d = clCreateSubBuffer( - backend_ctx->prealloc_act_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &status); - CL_CHECK(status); - - cl_image_format image_format_B_d_input = { CL_RGBA, CL_FLOAT }; - cl_image_desc image_desc_B_d_input = { - CL_MEM_OBJECT_IMAGE1D_BUFFER, - static_cast(K * N / 4), - 0, 0, 0, 0, 0, 0, 0, { B_sub_buffer } - }; - B_d_input_image = clCreateImage( - context, - 0, - &image_format_B_d_input, - &image_desc_B_d_input, - NULL, - &status); - CL_CHECK(status); - - cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16) - cl_image_desc image_desc_B_d_output = { - CL_MEM_OBJECT_IMAGE1D_BUFFER, - static_cast(K * (N + padding)/4), - 0, 0, 0, 0, 0, 0, 0, { B_d } - }; - B_image1d = clCreateImage( - context, - 0, - &image_format_B_d_output, - &image_desc_B_d_output, - NULL, - &status); - CL_CHECK(status); - - int height_B = N/4; - if (height_B == 0) { - height_B = 1; - } - int width_B = K/4; - int padded_height_B = (N + padding)/4; - - kernel = backend_ctx->kernel_transpose_32_16; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_d_input_image)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); - - size_t local_size_t[2] = { 1, 16 }; - //WGS tuning - if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { - local_size_t[0]=4; - local_size_t[1]=8; - } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { - local_size_t[0]=2; - local_size_t[1]=8; - } else if(ne0 == 4096 && ne1 == 128 && ne10 == 11008) { - local_size_t[0]=1; - local_size_t[1]=8; - } else if(ne0 == 32000 && ne1 == 128 && ne10 == 4096) { - local_size_t[0]=2; - local_size_t[1]=8; - } - - size_t global_size_t[2] = { - static_cast(width_B), - static_cast(padded_height_B) - }; - - backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst); - } else { - // no need to transpose B in other cases - // create an image for B from sub_buffer - // <--------------------------------------------> // - img_fmt_1d = {CL_RGBA, CL_FLOAT}; - - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_width = K * N / 4; - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.buffer = B_sub_buffer; - B_image1d = clCreateImage( - context, - CL_MEM_READ_ONLY, - &img_fmt_1d, - &img_desc_1d, - NULL, - &status); - CL_CHECK(status); - // <--------------------------------------------> // - } - - // choose gemm or gemv kernel - // <--------------------------------------------> // - if (N == 1) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general; - if (M == 4096 && K == 4096) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; - } else if (M == 4096 && K == 11008) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008; - } else if (M == 11008 && K == 4096) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; - } else if (M == 32000 && K == 4096) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; - } - } else { - kernel = backend_ctx->CL_mul_mat_Ab_Bi_8x4; - } - // <--------------------------------------------> // - - // set kernel args - // <--------------------------------------------> // - cl_uint k_arg = 0; - - if (N == 1) { - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extra0_q4_0->d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extra1->offset)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extrad->offset)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne1)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r2)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r3)); - } else { - region.origin = extrad->offset; // Specify the starting offset (in bytes) - region.size = M * N * sizeof(float); // Specify the size of the sub-buffer - C_d = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - - int padded_N = ne1 + padding; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); //A_q_dextra0_q4_0->q - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); //A_s_d - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d)); //B_d - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &C_d)); //C_d - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); //M - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &padded_N)); //N with padding - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); //K - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne1)); //N without padding - } - // <--------------------------------------------> // - - // choose workgroup size - // <--------------------------------------------> // - size_t global_work_size[3] = { - 64, static_cast((M+63)/64), static_cast((N+31)/32)}; - size_t local_work_size[3] = {64, 2, 4}; - - global_work_size[0] = (size_t)(ceil((float)ne1/8)); - global_work_size[1] = (size_t)(ne01/4); - global_work_size[2] = (size_t)(1); - - local_work_size[0] = (size_t)(1); //4x32 for FP32 - local_work_size[1] = (size_t)(128); - local_work_size[2] = (size_t)(1); - - //WGS tuning - if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { - local_work_size[0] = 1; - local_work_size[1] = 128; - } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { - local_work_size[0] = 2; - local_work_size[1] = 64; - } else if (ne0 == 4096 && ne1 == 128 && ne10 == 11008) { - local_work_size[0] = 2; - local_work_size[1] = 64; - } else if (ne0 == 32000 && ne1 == 128 && ne10 == 4096) { - local_work_size[0] = 2; - local_work_size[1] = 64; } - if (N == 1) { - size_t wavesize = backend_ctx->adreno_wave_size; - local_work_size[0] = wavesize; // localsize - local_work_size[1] = 4; // reduce factor - local_work_size[2] = 1; - - global_work_size[0] = (((M / 2) + wavesize - 1) / wavesize) * wavesize; - global_work_size[1] = 4; // reduce factor - global_work_size[2] = 1; + // q6_K x fp32 + if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst); + return; } - // <--------------------------------------------> // - - // enqueue kernel with profiling - // <--------------------------------------------> // - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - // <--------------------------------------------> // - - // deallocate sub buffers and images - // <--------------------------------------------> // - CL_CHECK(clReleaseMemObject(A_image1d)); - CL_CHECK(clReleaseMemObject(B_sub_buffer)); - CL_CHECK(clReleaseMemObject(B_image1d)); - if (N != 1) { - CL_CHECK(clReleaseMemObject(B_d)); - CL_CHECK(clReleaseMemObject(B_d_input_image)); - CL_CHECK(clReleaseMemObject(C_d)); + // q5_K x fp32 + if (src0t == GGML_TYPE_Q5_K && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q5_K_f32_adreno(backend, src0, src1, dst); + return; } - // <--------------------------------------------> // - - return; - } } // if (ne01 && ne1) #endif // GGML_OPENCL_USE_ADRENO_KERNELS diff --git a/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_0_f32.cl similarity index 99% rename from ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl rename to ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_0_f32.cl index ecb577b9933..159378049fb 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_0_f32.cl @@ -17,7 +17,7 @@ REQD_SUBGROUP_SIZE_128 #endif -kernel void kernel_mul_mat_Ab_Bi_8x4( +kernel void kernel_gemm_noshuffle_q4_0_f32( global const ushort * src0_q, // quantized A global const half * src0_d, // A scales __read_only image1d_buffer_t src1, // B (1d image) diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl similarity index 98% rename from ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl rename to ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl index 51ce2121ce2..7f06a22a2cb 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl @@ -11,7 +11,7 @@ REQD_SUBGROUP_SIZE_128 #endif -kernel void kernel_mul_mm_q8_0_f32_8x4( +kernel void kernel_gemm_noshuffle_q8_0_f32( global const uint * src0_q, global const half * src0_d, __read_only image1d_buffer_t src1, diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32.cl similarity index 98% rename from ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32.cl index 469d3edef00..10683206919 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32.cl @@ -191,7 +191,7 @@ #ifdef ADRENO_GPU REQD_SUBGROUP_SIZE_64 #endif -__kernel void kernel_gemv_noshuffle( +__kernel void kernel_gemv_noshuffle_q4_0_f32( __read_only image1d_buffer_t src0_q, // quantized A global half2 * src0_d, // A scales __read_only image1d_buffer_t src1, // B @@ -238,21 +238,21 @@ __kernel void kernel_gemv_noshuffle( regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; -#ifdef VECTOR_SUB_GROUP_BROADCAT +#ifdef VECTOR_SUB_GROUP_BROADCAST dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); #else dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); -#endif // VECTOR_SUB_GROUP_BROADCAT +#endif // VECTOR_SUB_GROUP_BROADCAST regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; -#ifdef VECTOR_SUB_GROUP_BROADCAT +#ifdef VECTOR_SUB_GROUP_BROADCAST dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); #else dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); -#endif // VECTOR_SUB_GROUP_BROADCAT +#endif // VECTOR_SUB_GROUP_BROADCAST } // reduction in local memory, assumes #wave=4 diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32_spec.cl similarity index 98% rename from ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32_spec.cl index ee5c79f000d..571a375da7f 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32_spec.cl @@ -191,7 +191,7 @@ #ifdef ADRENO_GPU REQD_SUBGROUP_SIZE_64 #endif -__kernel void kernel_gemv_noshuffle( +__kernel void kernel_gemv_noshuffle_q4_0_f32( __read_only image1d_buffer_t src0_q, // quantized A global half2 * src0_d, // A scales __read_only image1d_buffer_t src1, // B @@ -232,21 +232,21 @@ __kernel void kernel_gemv_noshuffle( regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; -#ifdef VECTOR_SUB_GROUP_BROADCAT +#ifdef VECTOR_SUB_GROUP_BROADCAST dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); #else dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); -#endif // VECTOR_SUB_GROUP_BROADCAT +#endif // VECTOR_SUB_GROUP_BROADCAST regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; -#ifdef VECTOR_SUB_GROUP_BROADCAT +#ifdef VECTOR_SUB_GROUP_BROADCAST dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); #else dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); -#endif // VECTOR_SUB_GROUP_BROADCAT +#endif // VECTOR_SUB_GROUP_BROADCAST } // reduction in local memory, assumes #wave=4 diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl similarity index 100% rename from ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl From a6d678954ab9935383cfb74a2b64735537a4abb8 Mon Sep 17 00:00:00 2001 From: Trivikram Reddy <127072883+trivikram-reddy1@users.noreply.github.com> Date: Tue, 5 May 2026 11:43:03 -0500 Subject: [PATCH 561/831] Hexagon: Process M-tail rows on HMX instead of HVX (llama/22724) * hex-mm: process m-tail rows on HMX instead of HVX * hmx-mm: unroll and optimize padded activation loop --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 51 ++++++++++++++++++---- ggml/src/ggml-hexagon/htp/matmul-ops.c | 36 +++------------ 2 files changed, 48 insertions(+), 39 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 2666a78a96a..9e8c9966e04 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -742,17 +742,45 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, // activations : fp32 -> fp16 static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) { - for (int r = 0; r < n_rows; r += 2) { + const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); + const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; + + int r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - const bool next_row_valid = (r + 1) < n_rows; - const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); for (int c = 0; c < k_block; c += 32) { HVX_Vector v0 = *pv_in0++; - HVX_Vector v1 = next_row_valid ? *pv_in1++ : Q6_V_vzero(); + HVX_Vector v1 = *pv_in1++; + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + // compute output position + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = r < n_rows; + const bool row1_valid = (r + 1) < n_rows; + + const HVX_Vector *pv_in0 = row0_valid ? (const HVX_Vector *) (src + (r + 0) * k_stride) : NULL; + const HVX_Vector *pv_in1 = row1_valid ? (const HVX_Vector *) (src + (r + 1) * k_stride) : NULL; + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); + HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); @@ -889,7 +917,9 @@ static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct h // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). const size_t m_block_cost = (size_t) n * 3; const size_t n_block_cost = (size_t) m * 2; - if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, m, n, m_block_cost, n_block_cost, &M_BLOCK_SIZE, + if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, + hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, + m_block_cost, n_block_cost, &M_BLOCK_SIZE, &N_BLOCK_SIZE, &vtcm_used) != 0) { FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); return -1; @@ -1084,7 +1114,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds if (m >= 128) { size_t mc = 0, nc = 0, used = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, m, n, + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, + hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, /*m_block_cost=*/(size_t) n * 3, /*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 && hmx_ceil_div((size_t) n, nc) >= 2) { @@ -1096,7 +1127,8 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds } if (!use_pipeline) { - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, m, n, + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, + hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, /*m_block_cost=*/(size_t) n * 3, /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); @@ -1432,7 +1464,8 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, /*per_n=*/3 * vec_dot_size, /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, - /*per_mn=*/sizeof(__fp16), params->m, params->n, + /*per_mn=*/sizeof(__fp16), + hex_align_up(params->m, HMX_FP16_TILE_N_ROWS), params->n, /*m_block_cost=*/(size_t) params->n, /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); @@ -1612,7 +1645,7 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, co /*per_n=*/3 * vec_dot_size, // W + S0 + S1 /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch /*per_mn=*/sizeof(__fp16), // O - m, n, + hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, /*m_block_cost=*/(size_t) n, /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index a0c265132c8..2461ae617fa 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -2991,12 +2991,10 @@ int op_matmul(struct htp_ops_context * octx) { return op_matmul_hvx(octx); } - // M alignment: when M > 32 but not 32-aligned, we split into - // HMX (first m_hmx = M & ~31 rows) + HVX (remaining m_tail rows). - // When M <= 32 and not 32-aligned, fall back entirely to HVX. + // M alignment: Use HMX when M >= 32, the last partial tile (m_total % 32 rows) + // is handled by HMX itself; when M < 32 fall back to HVX. const int m_total = (int) src1->ne[1]; - const int m_tail = m_total % 32; - const int m_hmx = m_total - m_tail; + const int m_hmx = m_total & ~31; // 0 when M < 32 if (m_hmx == 0) { return op_matmul_hvx(octx); @@ -3009,7 +3007,6 @@ int op_matmul(struct htp_ops_context * octx) { int k = (int) src0->ne[0]; // inner dimension int n = (int) src0->ne[1]; // weight columns - // --- Phase 1: HMX on the first m_hmx (32-aligned) rows --- int ret = -1; // Row strides in elements. For compact tensors these equal k; for @@ -3027,7 +3024,7 @@ int op_matmul(struct htp_ops_context * octx) { .dst = (float *) dst->data, .activation = (float *) src1->data, .permuted_weight = (const __fp16 *) src0->data, - .m = m_hmx, + .m = m_total, .k = k, .n = n, .act_stride = act_stride, @@ -3048,12 +3045,12 @@ int op_matmul(struct htp_ops_context * octx) { } else { ret = hmx_mat_mul_permuted_w16a32(octx->ctx, (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, - m_hmx, k, n, act_stride, wgt_stride); + m_total, k, n, act_stride, wgt_stride); } } else { ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, - m_hmx, k, n, (int) src0->type); + m_total, k, n, (int) src0->type); } if (ret != 0) { @@ -3061,27 +3058,6 @@ int op_matmul(struct htp_ops_context * octx) { return op_matmul(octx); } - // --- Phase 2: HVX on the remaining m_tail rows --- - if (m_tail > 0) { - // copy of src1 and dst - struct htp_tensor src1_tail = *src1; - struct htp_tensor dst_tail = *dst; - - src1_tail.ne[1] = m_tail; // only tail rows - dst_tail.ne[1] = m_tail; // only tail rows - - // Offset activation and dst pointers past the HMX-processed rows. - // Use nb[1] (row stride in bytes) to compute the byte offset. - src1_tail.data += (uint32_t) m_hmx * src1->nb[1]; - dst_tail.data += (uint32_t) m_hmx * dst->nb[1]; - - octx->src[1] = &src1_tail; - octx->dst = &dst_tail; - - FARF(HIGH, "hmx-matmul: HVX tail m_tail %d src1 %p dst %p", m_tail, (void *) src1_tail.data, (void *) dst_tail.data); - return op_matmul_hvx(octx); - } - return 0; #endif // HTP_HAS_HMX } From 3613268bc73ffaaf7c5d768ecfbfee24125e67b0 Mon Sep 17 00:00:00 2001 From: fl0rianr Date: Wed, 6 May 2026 07:12:48 +0200 Subject: [PATCH 562/831] ggml : use `CL_DEVICE_GLOBAL_MEM_SIZE` as memory estimate for OpenCL --fit (llama/22688) * ggml : report estimated OpenCL memory for --fit Signed-off-by: Florian Reinle * ggml : estimated OpenCL memory backend integrated Signed-off-by: Florian Reinle --------- Signed-off-by: Florian Reinle --- ggml/src/ggml-opencl/ggml-opencl.cpp | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 8c7bf98c16f..d344bde0fe3 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -389,6 +389,7 @@ struct ggml_backend_opencl_context { ADRENO_GPU_GEN adreno_gen; cl_int alignment; + size_t global_mem_size; size_t max_alloc_size; size_t max_workgroup_size; bool fp16_support; @@ -3386,6 +3387,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { backend_ctx->alignment = base_align_in_bits / 8u; GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment); + clGetDeviceInfo(device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(size_t), &backend_ctx->global_mem_size, NULL); + GGML_LOG_INFO("ggml_opencl: global mem size: %zu MB\n", backend_ctx->global_mem_size/1024/1024); + clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024); @@ -6356,11 +6360,16 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_ } static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - // no memory to report - *free = 0; - *total = 0; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *) dev_ctx->backend_ctx; - GGML_UNUSED(dev); + static const size_t opencl_extra_margin = 1024ull*1024ull*1024ull; + + // OpenCL does not provide reliable currently-free device memory. + // Use total/global memory as a best-effort upper bound. + // Improved safety: Reduce by a 1GiB extra margin for common --fit + *total = backend_ctx->global_mem_size; + *free = *total > opencl_extra_margin ? *total - opencl_extra_margin : 0; } static enum ggml_backend_dev_type ggml_backend_opencl_device_get_type(ggml_backend_dev_t dev) { From d3f16afcf57d7f6b9ae7aa469dbebd7113dc1dc1 Mon Sep 17 00:00:00 2001 From: zzzzwc Date: Wed, 6 May 2026 15:41:14 +0800 Subject: [PATCH 563/831] ggml-cpu: fuse RMS_NORM + MUL on CPU backend (llama/22423) --- ggml/src/ggml-cpu/ggml-cpu.c | 53 +++++++++++++++++++++++- ggml/src/ggml-cpu/ops.cpp | 78 ++++++++++++++++++++++++++++-------- ggml/src/ggml-cpu/ops.h | 1 + 3 files changed, 115 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 2d6cc1fcd46..8b7acafdaa8 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2965,6 +2965,45 @@ struct ggml_cplan ggml_graph_plan( return cplan; } + +// Try to fuse the current node with subsequent nodes for better performance. +// Returns the number of nodes skipped by fusion (>=1), or 0 if no fusion was applied. +static bool ggml_cpu_disable_fusion = false; // initialized once in ggml_cpu_init(), read-only afterwards + +static int ggml_cpu_try_fuse_ops( + const struct ggml_cgraph * cgraph, + const int node_n, + const struct ggml_compute_params * params, + const struct ggml_cplan * cplan) { + + if (ggml_cpu_disable_fusion || cplan->use_ref) { + return 0; + } + + struct ggml_tensor * node = cgraph->nodes[node_n]; + + if (node->op == GGML_OP_RMS_NORM) { + // RMS_NORM + MUL fusion + const enum ggml_op fuse_ops[] = { GGML_OP_RMS_NORM, GGML_OP_MUL }; + if (ggml_can_fuse(cgraph, node_n, fuse_ops, 2)) { + struct ggml_tensor * mul_node = cgraph->nodes[node_n + 1]; + const struct ggml_tensor * mul_w = (mul_node->src[0] == node) + ? mul_node->src[1] : mul_node->src[0]; + if (node->src[0]->type == GGML_TYPE_F32 && + mul_node->type == GGML_TYPE_F32 && + mul_w->type == GGML_TYPE_F32 && + mul_w->ne[0] == node->ne[0] && + mul_w->nb[0] == sizeof(float)) { + + ggml_compute_forward_rms_norm_mul_fused(params, node, mul_node); + return 1; + } + } + } + + return 0; +} + static thread_ret_t ggml_graph_compute_thread(void * data) { struct ggml_compute_state * state = (struct ggml_compute_state *) data; struct ggml_threadpool * tp = state->threadpool; @@ -3001,7 +3040,14 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { continue; } - ggml_compute_forward(¶ms, node); + // TODO: move fused-op detection into ggml_graph_plan so fusion decisions are made once at planning time + // Try fused ops, fall back to normal compute + const int n_fused = ggml_cpu_try_fuse_ops(cgraph, node_n, ¶ms, cplan); + if (n_fused > 0) { + node_n += n_fused; + } else { + ggml_compute_forward(¶ms, node); + } if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { @@ -3763,6 +3809,11 @@ void ggml_cpu_init(void) { ggml_init_riscv_arch_features(); #endif + { + const char * env = getenv("GGML_CPU_DISABLE_FUSION"); + ggml_cpu_disable_fusion = (env != NULL && atoi(env) == 1); + } + is_first_call = false; } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 211f1ba1b2f..6bc8dc150ce 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3713,11 +3713,27 @@ void ggml_compute_forward_norm( // ggml_compute_forward_group_rms_norm +// fusion kinds that can be combined with the rms_norm computation in a single pass. +// extend this enum when adding new fused variants (e.g. FUSE_ADD, FUSE_MUL_ADD, ...). +enum ggml_rms_norm_fuse_op { + GGML_RMS_NORM_FUSE_OP_NONE, + GGML_RMS_NORM_FUSE_OP_MUL, +}; + +template static void ggml_compute_forward_rms_norm_f32( const ggml_compute_params * params, - ggml_tensor * dst) { + ggml_tensor * dst_rms_norm, + ggml_tensor * dst_fused = nullptr) { - const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src0 = dst_rms_norm->src[0]; + const ggml_tensor * src1 = nullptr; + ggml_tensor * dst = dst_rms_norm; + + if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) { + src1 = (dst_fused->src[0] == dst_rms_norm) ? dst_fused->src[1] : dst_fused->src[0]; + dst = dst_fused; + } GGML_ASSERT(ggml_are_same_shape(src0, dst)); @@ -3726,11 +3742,10 @@ static void ggml_compute_forward_rms_norm_f32( const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS + GGML_TENSOR_BINARY_OP_LOCALS float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - + memcpy(&eps, dst_rms_norm->op_params, sizeof(float)); GGML_ASSERT(eps >= 0.0f); // TODO: optimize @@ -3740,25 +3755,32 @@ static void ggml_compute_forward_rms_norm_f32( const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); ggml_float sum = 0.0; + // worth switching to explicit SIMD? for (int64_t i00 = 0; i00 < ne00; i00++) { sum += (ggml_float)(x[i00] * x[i00]); } - const float mean = sum/ne00; - - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - memcpy(y, x, ne00 * sizeof(float)); - // for (int i00 = 0; i00 < ne00; i00++) { - // y[i00] = x[i00]; - // } - + const float mean = sum/ne00; const float scale = 1.0f/sqrtf(mean + eps); // if you hit this, likely you got an inf somewhere earlier assert(scale > 0.0f); - ggml_vec_scale_f32(ne00, y, scale); + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) { + const int64_t i11 = i01 % ne11; + const int64_t i12 = i02 % ne12; + const int64_t i13 = i03 % ne13; + const float * w = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); + + for (int64_t i00 = 0; i00 < ne00; i00++) { + y[i00] = x[i00] * scale * w[i00]; + } + } else { + memcpy(y, x, ne00 * sizeof(float)); + ggml_vec_scale_f32(ne00, y, scale); + } } } } @@ -3773,7 +3795,31 @@ void ggml_compute_forward_rms_norm( switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_rms_norm_f32(params, dst); + ggml_compute_forward_rms_norm_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// Fused RMS_NORM + MUL: computes dst = rms_norm(src0) * src1 in a single pass. +// This avoids materializing the intermediate rms_norm result in memory. +void ggml_compute_forward_rms_norm_mul_fused( + const ggml_compute_params * params, + ggml_tensor * dst_rms_norm, + ggml_tensor * dst_mul) { + + GGML_ASSERT(dst_mul != nullptr); + GGML_ASSERT(dst_mul->src[0] == dst_rms_norm || dst_mul->src[1] == dst_rms_norm); + + const ggml_tensor * src0 = dst_rms_norm->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rms_norm_f32(params, dst_rms_norm, dst_mul); } break; default: { diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 29efdeee37f..7398e561894 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -44,6 +44,7 @@ void ggml_compute_forward_concat(const struct ggml_compute_params * params, stru void ggml_compute_forward_silu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_rms_norm_mul_fused(const struct ggml_compute_params * params, struct ggml_tensor * dst_rms_norm, struct ggml_tensor * dst_mul); void ggml_compute_forward_rms_norm_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_group_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_l2_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); From 4395364605738226787117697e7227ea514fcd3e Mon Sep 17 00:00:00 2001 From: pl752 Date: Thu, 7 May 2026 18:09:25 +0500 Subject: [PATCH 564/831] ggml-cpu: Optimized risc-v cpu q1_0 dot --- ggml/src/ggml-cpu/arch-fallback.h | 1 - ggml/src/ggml-cpu/arch/riscv/quants.c | 98 +++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 595ded09f03..b0391a67c88 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -203,7 +203,6 @@ #elif defined(__riscv) // quants.c #define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 -#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index d3278d6489f..ee69e5ab5e5 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -480,6 +480,104 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } +#if defined(__riscv_v) +static NOINLINE void ggml_vec_dot_q1_0_q8_0_vl256(const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy) { + const int qk = QK1_0; + const int nb = n / qk; + assert(n % qk == 0); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + //LMUL = 1, VLMAX = 32 + const size_t vl32 = __riscv_vsetvl_e8m1(32); + assert(vl32 == 32); + + const vint16m1_t zero = __riscv_vmv_v_x_i16m1(0, 1); + + float sumf = 0; + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + + float acc = 0; + + for (int k = 0; k < 4; ++k) { + const block_q8_0 * GGML_RESTRICT yb = &y[ib * 4 + k]; + const vbool8_t is_not_zero = __riscv_vlm_v_b8(x[ib].qs + 4 * k, vl32); + + const vint8m1_t qy = __riscv_vle8_v_i8m1(yb->qs, vl32); + const vint8m1_t neg_qy = __riscv_vneg_v_i8m1(qy, vl32); + const vint8m1_t sy = __riscv_vmerge_vvm_i8m1(neg_qy, qy, is_not_zero, vl32); + + const vint16m1_t red = __riscv_vwredsum_vs_i8m1_i16m1(sy, zero, vl32); + acc += GGML_CPU_FP16_TO_FP32(yb->d) * (float)__riscv_vmv_x_s_i16m1_i16(red); + } + + sumf += d0 * acc; + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q1_0_q8_0_vl128(const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy) { + const int qk = QK1_0; + const int nb = n / qk; + assert(n % qk == 0); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + //LMUL = 2, VLMAX = 32 + const size_t vl32 = __riscv_vsetvl_e8m2(32); + assert(vl32 == 32); + + const vint16m1_t zero = __riscv_vmv_v_x_i16m1(0, 1); + + float sumf = 0; + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + + float acc = 0; + + for (int k = 0; k < 4; ++k) { + const block_q8_0 * GGML_RESTRICT yb = &y[ib * 4 + k]; + const vbool4_t is_not_zero = __riscv_vlm_v_b4(x[ib].qs + 4 * k, vl32); + + const vint8m2_t qy = __riscv_vle8_v_i8m2(yb->qs, vl32); + const vint8m2_t neg_qy =__riscv_vneg_v_i8m2(qy, vl32); + const vint8m2_t sy = __riscv_vmerge_vvm_i8m2(neg_qy, qy, is_not_zero, vl32); + + const vint16m1_t red = __riscv_vwredsum_vs_i8m2_i16m1(sy, zero, vl32); + acc += GGML_CPU_FP16_TO_FP32(yb->d) * (float)__riscv_vmv_x_s_i16m1_i16(red); + } + + sumf += d0 * acc; + } + + *s = sumf; +} +#endif + +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined(__riscv_v) + assert(nrc == 1); + + const size_t vlen_bits = __riscv_vlenb() * 8; + + if (vlen_bits >= 256) { + ggml_vec_dot_q1_0_q8_0_vl256(n, s, vx, vy); + } else if (vlen_bits >= 128) { + ggml_vec_dot_q1_0_q8_0_vl128(n, s, vx, vy); + } else { + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); + } +#else + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); From bd693bb1ebf5cb8bc987b1a43dd8bf31a7b65edc Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Thu, 7 May 2026 08:51:33 -0700 Subject: [PATCH 565/831] sycl: add FILL, CUMSUM, DIAG, SOLVE_TRI, SSM_SCAN, GATED_DELTA_NET (llama/22149) * sycl: add FILL, CUMSUM, DIAG, SOLVE_TRI, SSM_SCAN, GATED_DELTA_NET Signed-off-by: Chun Tao * Fix abort during test-backend-ops Signed-off-by: Todd Malsbary * Regenerate ops.md Signed-off-by: Todd Malsbary * Add scope_dbg_print to newly added SYCL ops. Also add scope_dbg_print to existing ssm_conv op. Signed-off-by: Todd Malsbary --------- Signed-off-by: Chun Tao Signed-off-by: Todd Malsbary Co-authored-by: Chun Tao Co-authored-by: Todd Malsbary --- ggml/src/ggml-sycl/cumsum.cpp | 148 +++++++++++++++++++++ ggml/src/ggml-sycl/cumsum.hpp | 5 + ggml/src/ggml-sycl/diag.cpp | 67 ++++++++++ ggml/src/ggml-sycl/diag.hpp | 5 + ggml/src/ggml-sycl/fill.cpp | 55 ++++++++ ggml/src/ggml-sycl/fill.hpp | 5 + ggml/src/ggml-sycl/gated_delta_net.hpp | 1 + ggml/src/ggml-sycl/ggml-sycl.cpp | 37 +++++- ggml/src/ggml-sycl/solve_tri.cpp | 172 +++++++++++++++++++++++++ ggml/src/ggml-sycl/solve_tri.hpp | 8 ++ ggml/src/ggml-sycl/ssm_conv.cpp | 7 +- ggml/src/ggml-sycl/ssm_scan.cpp | 156 ++++++++++++++++++++++ ggml/src/ggml-sycl/ssm_scan.hpp | 5 + 13 files changed, 669 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-sycl/cumsum.cpp create mode 100644 ggml/src/ggml-sycl/cumsum.hpp create mode 100644 ggml/src/ggml-sycl/diag.cpp create mode 100644 ggml/src/ggml-sycl/diag.hpp create mode 100644 ggml/src/ggml-sycl/fill.cpp create mode 100644 ggml/src/ggml-sycl/fill.hpp create mode 100644 ggml/src/ggml-sycl/solve_tri.cpp create mode 100644 ggml/src/ggml-sycl/solve_tri.hpp create mode 100644 ggml/src/ggml-sycl/ssm_scan.cpp create mode 100644 ggml/src/ggml-sycl/ssm_scan.hpp diff --git a/ggml/src/ggml-sycl/cumsum.cpp b/ggml/src/ggml-sycl/cumsum.cpp new file mode 100644 index 00000000000..c1c5fe4fe4a --- /dev/null +++ b/ggml/src/ggml-sycl/cumsum.cpp @@ -0,0 +1,148 @@ +#include "cumsum.hpp" +#include "common.hpp" + +#include + +#define SYCL_CUMSUM_BLOCK_SIZE 256 + +static __dpct_inline__ float warp_prefix_inclusive_sum_f32(float x, const sycl::nd_item<3> & item) { + return sycl::inclusive_scan_over_group(item.get_sub_group(), x, sycl::plus()); +} + +static void cumsum_f32_kernel( + const float * __restrict__ src, float * __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t d1, const int64_t d2, const int64_t d3, + const sycl::nd_item<3> & item, float * smem) { + + const int tid = item.get_local_id(2); + const int block_size = item.get_local_range(2); + const int lane = tid % WARP_SIZE; + const int warp = tid / WARP_SIZE; + const int warps_per_block = block_size / WARP_SIZE; + + float * s_vals = smem; + float * s_warp_sums = smem + block_size; + float * s_carry = smem + block_size + warps_per_block; + + if (tid == 0) { + s_carry[0] = 0.0f; + } + item.barrier(sycl::access::fence_space::local_space); + + const int64_t i3 = item.get_group(0); + const int64_t i2 = item.get_group(1); + const int64_t i1 = item.get_group(2); + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + const float * src_row = src + i1 * s01 + i2 * s02 + i3 * s03; + float * dst_row = dst + i1 * d1 + i2 * d2 + i3 * d3; + + constexpr int num_unroll = 4; + float temp[num_unroll]; + + for (int64_t i = 0; i < ne00; i += num_unroll * block_size) { + int64_t idx = i + tid * num_unroll; + + temp[0] = (idx < ne00 ? src_row[idx] : 0.0f); +#pragma unroll + for (int j = 1; j < num_unroll; j++) { + temp[j] = temp[j - 1]; + if (idx + j < ne00) { + temp[j] += src_row[idx + j]; + } + } + + float val = (idx < ne00) ? temp[num_unroll - 1] : 0.0f; + + val = warp_prefix_inclusive_sum_f32(val, item); + s_vals[tid] = val; + + if (lane == WARP_SIZE - 1) { + s_warp_sums[warp] = val; + } + item.barrier(sycl::access::fence_space::local_space); + + if (warp == 0) { + float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f; + float inc = warp_prefix_inclusive_sum_f32(w, item); + if (tid < warps_per_block) { + s_warp_sums[tid] = inc - w; + } + if (tid == warps_per_block - 1) { + s_carry[1] = inc; + } + } + item.barrier(sycl::access::fence_space::local_space); + + float carry = s_carry[0]; + float final_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1]; + +#pragma unroll + for (int j = 0; j < num_unroll; j++) { + if (idx + j < ne00) { + dst_row[idx + j] = temp[j] + final_offset; + } + } + + item.barrier(sycl::access::fence_space::local_space); + + if (tid == 0) { + s_carry[0] += s_carry[1]; + } + } +} + +inline void ggml_sycl_op_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float * src_d = static_cast(src0->data); + float * dst_d = static_cast(dst->data); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const size_t ts = sizeof(float); + const int64_t s01 = src0->nb[1] / ts; + const int64_t s02 = src0->nb[2] / ts; + const int64_t s03 = src0->nb[3] / ts; + const int64_t d1 = dst->nb[1] / ts; + const int64_t d2 = dst->nb[2] / ts; + const int64_t d3 = dst->nb[3] / ts; + + const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; + int block_size = num_warps * WARP_SIZE; + block_size = std::min(block_size, SYCL_CUMSUM_BLOCK_SIZE); + const int warps_per_block = block_size / WARP_SIZE; + const int smem_size = block_size + warps_per_block + 2; + + const sycl::range<3> grid(ne03, ne02, ne01); + const sycl::range<3> block(1, 1, block_size); + + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor smem_acc(sycl::range<1>(smem_size), cgh); + cgh.parallel_for( + sycl::nd_range<3>(grid * block, block), + [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + cumsum_f32_kernel(src_d, dst_d, ne00, ne01, ne02, ne03, + s01, s02, s03, d1, d2, d3, + item, get_pointer(smem_acc)); + }); + }); +} + +void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_cumsum(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/cumsum.hpp b/ggml/src/ggml-sycl/cumsum.hpp new file mode 100644 index 00000000000..f1a564472c5 --- /dev/null +++ b/ggml/src/ggml-sycl/cumsum.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/diag.cpp b/ggml/src/ggml-sycl/diag.cpp new file mode 100644 index 00000000000..c4264fee342 --- /dev/null +++ b/ggml/src/ggml-sycl/diag.cpp @@ -0,0 +1,67 @@ +#include "diag.hpp" +#include "common.hpp" + +#define SYCL_DIAG_BLOCK_SIZE 256 + +template +static void diag_kernel(T * __restrict__ dst, const T * __restrict__ src, + const int64_t ne0, const int64_t ne1, + const int64_t ne2, const int64_t ne3, + const int64_t total_elements, + const sycl::nd_item<1> & item) { + const int64_t i = item.get_global_id(0); + if (i >= total_elements) { + return; + } + + const int64_t i0 = i % ne0; + const int64_t i1 = (i / ne0) % ne1; + const int64_t i2 = (i / (ne0 * ne1)) % ne2; + const int64_t i3 = i / (ne0 * ne1 * ne2); + + const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0; + + if (i0 == i1) { + const int64_t batch_idx = i3 * ne2 + i2; + dst[dst_idx] = src[batch_idx * ne0 + i0]; + } else { + dst[dst_idx] = T(0); + } + + (void)ne3; +} + +inline void ggml_sycl_op_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->ne[1] == 1); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const void * src0_d = src0->data; + void * dst_d = dst->data; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + const int64_t n_elems = ggml_nelements(dst); + const int64_t num_blocks = (n_elems + SYCL_DIAG_BLOCK_SIZE - 1) / SYCL_DIAG_BLOCK_SIZE; + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + stream->parallel_for( + sycl::nd_range<1>(num_blocks * SYCL_DIAG_BLOCK_SIZE, SYCL_DIAG_BLOCK_SIZE), + [=](sycl::nd_item<1> item) { + diag_kernel(static_cast(dst_d), + static_cast(src0_d), + ne0, ne1, ne2, ne3, n_elems, item); + }); +} + +void ggml_sycl_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_diag(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/diag.hpp b/ggml/src/ggml-sycl/diag.hpp new file mode 100644 index 00000000000..20d7ce4895d --- /dev/null +++ b/ggml/src/ggml-sycl/diag.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/fill.cpp b/ggml/src/ggml-sycl/fill.cpp new file mode 100644 index 00000000000..28e618e4ef5 --- /dev/null +++ b/ggml/src/ggml-sycl/fill.cpp @@ -0,0 +1,55 @@ +#include "fill.hpp" +#include "common.hpp" + +#define SYCL_FILL_BLOCK_SIZE 256 + +template +static void fill_kernel(T * dst, const int64_t k, const T value, + const sycl::nd_item<1> & item) { + const int64_t i = (int64_t)item.get_global_id(0); + if (i >= k) { + return; + } + dst[i] = value; +} + +inline void ggml_sycl_op_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst)); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + float value; + memcpy(&value, dst->op_params, sizeof(float)); + + const int64_t k = ggml_nelements(dst); + const int64_t num_blocks = (k + SYCL_FILL_BLOCK_SIZE - 1) / SYCL_FILL_BLOCK_SIZE; + void * dst_d = dst->data; + + switch (dst->type) { + case GGML_TYPE_F32: + stream->parallel_for( + sycl::nd_range<1>(num_blocks * SYCL_FILL_BLOCK_SIZE, SYCL_FILL_BLOCK_SIZE), + [=](sycl::nd_item<1> item) { + fill_kernel(static_cast(dst_d), k, value, item); + }); + break; + case GGML_TYPE_F16: + { + sycl::half h_value = sycl::half(value); + stream->parallel_for( + sycl::nd_range<1>(num_blocks * SYCL_FILL_BLOCK_SIZE, SYCL_FILL_BLOCK_SIZE), + [=](sycl::nd_item<1> item) { + fill_kernel(static_cast(dst_d), k, h_value, item); + }); + } + break; + default: + GGML_ABORT("unsupported type"); + } +} + +void ggml_sycl_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0); + ggml_sycl_op_fill(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/fill.hpp b/ggml/src/ggml-sycl/fill.hpp new file mode 100644 index 00000000000..b2adb94ff52 --- /dev/null +++ b/ggml/src/ggml-sycl/fill.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/gated_delta_net.hpp b/ggml/src/ggml-sycl/gated_delta_net.hpp index a3308ee8763..350b4ce2f66 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.hpp +++ b/ggml/src/ggml-sycl/gated_delta_net.hpp @@ -5,4 +5,5 @@ #include "common.hpp" #include "ggml.h" +void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index f06147eeeb8..29ecedb5de9 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -54,7 +54,12 @@ #include "ggml-sycl/set.hpp" #include "ggml-sycl/ssm_conv.hpp" #include "ggml-sycl/sycl_hw.hpp" - +#include "ggml-sycl/ssm_scan.hpp" +#include "ggml-sycl/fill.hpp" +#include "ggml-sycl/cumsum.hpp" +#include "ggml-sycl/diag.hpp" +#include "ggml-sycl/solve_tri.hpp" +#include "ggml-sycl/gated_delta_net.hpp" static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; @@ -4394,6 +4399,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_SSM_CONV: ggml_sycl_ssm_conv(ctx, dst); break; + case GGML_OP_SSM_SCAN: + ggml_sycl_ssm_scan(ctx, dst); + break; + case GGML_OP_FILL: + ggml_sycl_fill(ctx, dst); + break; + case GGML_OP_CUMSUM: + ggml_sycl_cumsum(ctx, dst); + break; + case GGML_OP_DIAG: + ggml_sycl_diag(ctx, dst); + break; + case GGML_OP_SOLVE_TRI: + ggml_sycl_solve_tri(ctx, dst); + break; case GGML_OP_ROLL: ggml_sycl_roll(ctx, dst); break; @@ -5104,6 +5124,21 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return op->type == GGML_TYPE_F32; case GGML_OP_ARANGE: return op->type == GGML_TYPE_F32; + case GGML_OP_SSM_SCAN: + if (op->src[3]->ne[0] == 1) { + // Mamba2 + // (kernel only supports (d_state == 128 || d_state == 256) && d_head % WARP_SIZE == 0) + return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % WARP_SIZE == 0; + } else { + // TODO Mamba-1 not yet ported to SYCL + return false; + } + case GGML_OP_FILL: + case GGML_OP_CUMSUM: + case GGML_OP_DIAG: + return true; + case GGML_OP_SOLVE_TRI: + return op->src[0]->ne[0] <= SYCL_SOLVE_TRI_MAX_N && op->src[1]->ne[0] <= SYCL_SOLVE_TRI_MAX_K; case GGML_OP_FLASH_ATTN_EXT: return ggml_sycl_flash_attn_ext_supported(device, op); default: diff --git a/ggml/src/ggml-sycl/solve_tri.cpp b/ggml/src/ggml-sycl/solve_tri.cpp new file mode 100644 index 00000000000..39326deee44 --- /dev/null +++ b/ggml/src/ggml-sycl/solve_tri.cpp @@ -0,0 +1,172 @@ +#include "solve_tri.hpp" +#include "common.hpp" +#include + +template +static void solve_tri_f32_fast(const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const int64_t ne02, [[maybe_unused]] const int64_t ne03, + const int64_t nb02, const int64_t nb03, + const int64_t nb12, const int64_t nb13, + const int64_t nb2, const int64_t nb3, + const int n_arg, const int k_arg, + const sycl::nd_item<2> & item, float * sA) { + + const int n = n_template == 0 ? n_arg : n_template; + const int k = k_template == 0 ? k_arg : k_template; + + const int batch_idx = item.get_group(1); + const int lane = item.get_local_id(1) % WARP_SIZE; + const int col_idx = item.get_local_id(0); + + if (col_idx >= k) { + return; + } + + const int64_t i03 = batch_idx / ne02; + const int64_t i02 = batch_idx % ne02; + + const float * A_batch = (const float *) ((const char *) A + i02 * nb02 + i03 * nb03); + const float * B_batch = (const float *) ((const char *) B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) ((char *) X + i02 * nb2 + i03 * nb3); + + const int offset = item.get_local_id(1) + item.get_local_id(0) * item.get_local_range(1); + +#pragma unroll + for (int i = 0; i < n * n; i += k * WARP_SIZE) { + const int i0 = i + offset; + if (i0 < n * n) { + sA[i0] = A_batch[i0]; + } + } + + item.barrier(sycl::access::fence_space::local_space); + + float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f; + float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f; + + const int half = WARP_SIZE; + const int nrows_low = (n < half) ? n : half; + +#pragma unroll + for (int row = 0; row < nrows_low; ++row) { + float sum = 0.0f; + if (lane < row) { + sum += sA[row * n + lane] * x_low; + } + sum = warp_reduce_sum(sum); + if (lane == row) { + x_low = (x_low - sum) / sA[row * n + row]; + } + } + +#pragma unroll + for (int row = half; row < n; ++row) { + float sum = sA[row * n + lane] * x_low; + const int j = half + lane; + if (j < row) { + sum += sA[row * n + j] * x_high; + } + sum = warp_reduce_sum(sum); + if (lane == row - half) { + x_high = (x_high - sum) / sA[row * n + row]; + } + } + +#pragma unroll + for (int rr = 0; rr < 2; ++rr) { + const int row = rr * WARP_SIZE + lane; + if (row < n) { + const float val = (row < half) ? x_low : x_high; + X_batch[row * k + col_idx] = val; + } + } +} + +static void solve_tri_f32_mkl(dpct::queue_ptr stream, + const float * A, float * X, + int n, int k, + int64_t ne02, [[maybe_unused]] int64_t ne03, + int64_t nb02, [[maybe_unused]] int64_t nb03, + int64_t nb2, [[maybe_unused]] int64_t nb3) { + const float alpha = 1.0f; + const int64_t total_batches = ne02 * ne03; + if (total_batches == 0) { + return; + } + + const int64_t stride_a = nb02 / sizeof(float); + const int64_t stride_x = nb2 / sizeof(float); + + oneapi::mkl::blas::trsm_batch( + *stream, + oneapi::mkl::side::right, + oneapi::mkl::uplo::upper, + oneapi::mkl::transpose::nontrans, + oneapi::mkl::diag::nonunit, + k, n, alpha, + A, n, stride_a, + X, k, stride_x, + total_batches); +} + +inline void ggml_sycl_op_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const int n = src0->ne[0]; + const int k = src1->ne[0]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + GGML_ASSERT(n <= SYCL_SOLVE_TRI_MAX_N && k <= SYCL_SOLVE_TRI_MAX_K); + + const float * A_d = static_cast(src0->data); + const float * B_d = static_cast(src1->data); + float * X_d = static_cast(dst->data); + + if (X_d != B_d) { + const int64_t total_elements = (int64_t)n * k * ne02 * ne03; + stream->memcpy(X_d, B_d, total_elements * sizeof(float)); + } + + const int64_t nb02 = src0->nb[2]; + const int64_t nb03 = src0->nb[3]; + const int64_t nb12 = src1->nb[2]; + const int64_t nb13 = src1->nb[3]; + const int64_t nb2 = dst->nb[2]; + const int64_t nb3 = dst->nb[3]; + + const int64_t total_batches = ne02 * ne03; + + if (n <= 2 * WARP_SIZE && k <= 32) { + const int smem_size = 2 * WARP_SIZE * 2 * WARP_SIZE; + const sycl::range<2> grid(1, total_batches); + const sycl::range<2> block(k, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor smem_acc(sycl::range<1>(smem_size), cgh); + cgh.parallel_for( + sycl::nd_range<2>(grid * block, block), + [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + solve_tri_f32_fast<0, 0>(A_d, B_d, X_d, ne02, ne03, + nb02, nb03, nb12, nb13, nb2, nb3, + n, k, item, get_pointer(smem_acc)); + }); + }); + } else { + solve_tri_f32_mkl(stream, A_d, X_d, n, k, ne02, ne03, nb02, nb03, nb2, nb3); + } +} + +void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_solve_tri(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/solve_tri.hpp b/ggml/src/ggml-sycl/solve_tri.hpp new file mode 100644 index 00000000000..c7c34cfa2bb --- /dev/null +++ b/ggml/src/ggml-sycl/solve_tri.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include "common.hpp" + +#define SYCL_SOLVE_TRI_MAX_N 64 +#define SYCL_SOLVE_TRI_MAX_K 64 + +void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ssm_conv.cpp b/ggml/src/ggml-sycl/ssm_conv.cpp index eea9a73d67e..e55223586a1 100644 --- a/ggml/src/ggml-sycl/ssm_conv.cpp +++ b/ggml/src/ggml-sycl/ssm_conv.cpp @@ -63,7 +63,7 @@ static void kernel_ssm_conv( }); } -void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +inline void ggml_sycl_op_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; @@ -125,3 +125,8 @@ void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { throw; } } + +void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_ssm_conv(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/ssm_scan.cpp b/ggml/src/ggml-sycl/ssm_scan.cpp new file mode 100644 index 00000000000..ae652981384 --- /dev/null +++ b/ggml/src/ggml-sycl/ssm_scan.cpp @@ -0,0 +1,156 @@ +#include "ssm_scan.hpp" +#include "common.hpp" + +template +static void ssm_scan_f32_group( + const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, + const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, + const int32_t * __restrict__ src6, float * __restrict__ dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, + const int src2_nb1, const int src2_nb2, const int src3_nb1, + const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, + const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok, + const sycl::nd_item<2> & item) { + + const int lane = item.get_local_id(1) % WARP_SIZE; + const int warp = item.get_local_id(1) / WARP_SIZE; + const int warp_idx = item.get_group(1) * c_factor + warp; + const int seq_idx = item.get_group(0); + + const int head_idx = warp_idx / d_head; + const int head_off = (warp_idx % d_head) * sizeof(float); + const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); + + const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); + const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); + const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1); + const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); + const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); + float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx; + float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + + const int stride_x = src1_nb2 / sizeof(float); + const int stride_dt = src2_nb1 / sizeof(float); + const int stride_B = src4_nb2 / sizeof(float); + const int stride_C = src5_nb2 / sizeof(float); + const int stride_y = n_head * d_head; + + float state[c_factor]; + float state_sum = 0.0f; + +#pragma unroll + for (int j = 0; j < c_factor; j++) { + state[j] = s0_warp[WARP_SIZE * j + lane]; + } + + for (int64_t i = 0; i < n_tok; i++) { + const float dt_val = dt_warp[i * stride_dt]; + const float dt_soft_plus = (dt_val <= 20.0f ? sycl::log1p(sycl::exp(dt_val)) : dt_val); + + state_sum = 0.0f; + const float dA = sycl::exp(dt_soft_plus * A_warp[0]); + const float x_dt = x_warp[i * stride_x] * dt_soft_plus; +#pragma unroll + for (int j = 0; j < c_factor; j++) { + const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane]; + const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane]; + state[j] = (state[j] * dA) + (B_val * x_dt); + state_sum += state[j] * C_val; + } + + state_sum = warp_reduce_sum(state_sum); + + if (lane == 0) { + y_warp[i * stride_y] = state_sum; + } + } + +#pragma unroll + for (int j = 0; j < c_factor; j++) { + s_warp[WARP_SIZE * j + lane] = state[j]; + } +} + +static void ssm_scan_f32_sycl( + const float * src0, const float * src1, const float * src2, const float * src3, + const float * src4, const float * src5, const int32_t * src6, float * dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, + const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, + const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, + const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, + dpct::queue_ptr stream) { + + // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! + GGML_ASSERT(src3_nb1 == sizeof(float)); + if (d_state == 128) { + constexpr int threads = 128; + constexpr int num_warps = threads / WARP_SIZE; + const sycl::range<2> grid(n_seq, (n_head * head_dim + num_warps - 1) / num_warps); + const sycl::range<2> block(1, threads); + stream->parallel_for( + sycl::nd_range<2>(grid * block, block), + [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + ssm_scan_f32_group<128 / WARP_SIZE, 128>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, + src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item); + }); + } else if (d_state == 256) { + constexpr int threads = 256; + constexpr int num_warps = threads / WARP_SIZE; + const sycl::range<2> grid(n_seq, (n_head * head_dim + num_warps - 1) / num_warps); + const sycl::range<2> block(1, threads); + stream->parallel_for( + sycl::nd_range<2>(grid * block, block), + [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + ssm_scan_f32_group<256 / WARP_SIZE, 256>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, + src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item); + }); + } else { + GGML_ABORT("ssm_scan: unsupported d_state (must be 128 or 256)"); + } +} + +inline void ggml_sycl_op_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const ggml_tensor * src3 = dst->src[3]; + const ggml_tensor * src4 = dst->src[4]; + const ggml_tensor * src5 = dst->src[5]; + const ggml_tensor * src6 = dst->src[6]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src6->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t nc = src0->ne[0]; + const int64_t nr = src0->ne[1]; + const int64_t nh = src1->ne[1]; + const int64_t ng = src4->ne[1]; + const int64_t n_t = src1->ne[2]; + const int64_t n_s = src1->ne[3]; + const int64_t s_off = ggml_nelements(src1) * sizeof(float); + + GGML_ASSERT(ggml_nelements(src1) + nc * nr * nh * n_s == ggml_nelements(dst)); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + ssm_scan_f32_sycl( + static_cast(src0->data), static_cast(src1->data), + static_cast(src2->data), static_cast(src3->data), + static_cast(src4->data), static_cast(src5->data), + static_cast(src6->data), static_cast(dst->data), + src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2], + src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3], + s_off, nc, nr, nh, ng, n_t, n_s, stream); +} + +void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/7); + ggml_sycl_op_ssm_scan(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/ssm_scan.hpp b/ggml/src/ggml-sycl/ssm_scan.hpp new file mode 100644 index 00000000000..1f9731fb6fd --- /dev/null +++ b/ggml/src/ggml-sycl/ssm_scan.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst); From 7774fe2c8d19813a3ad2d74ce33ba51ae01996da Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Thu, 7 May 2026 11:00:20 -0700 Subject: [PATCH 566/831] opencl: add opfilter regex for debugging (llama/22782) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index d344bde0fe3..e5a5d42f6fb 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #undef MIN #undef MAX @@ -396,6 +397,8 @@ struct ggml_backend_opencl_context { bool has_vector_subgroup_broadcast; bool disable_fusion; + std::regex *opfilter = nullptr; // regex of ops to not claim + bool adreno_has_large_buffer; bool adreno_use_large_buffer; ggml_cl_compiler_version adreno_cl_compiler_version; @@ -3494,6 +3497,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr; + const char * str_opfilter = getenv("GGML_OPENCL_OPFILTER"); + if (str_opfilter) { + backend_ctx->opfilter = new std::regex(str_opfilter, std::regex_constants::icase); + GGML_LOG_INFO("ggml_opencl: opfilter regex = \"%s\"\n", str_opfilter); + } + dev_ctx->backend_ctx = backend_ctx.release(); return dev_ctx->backend_ctx; } @@ -4143,6 +4152,11 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context; ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; + // reject ops that match the opfilter regex + if (backend_ctx->opfilter && std::regex_match(std::string(ggml_op_desc(op)), *backend_ctx->opfilter)) { + return false; + } + switch (op->op) { case GGML_OP_NONE: return true; From 5fd75cda3fec6b87f724a5784479f8ff9348a7d4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 7 May 2026 21:43:40 +0300 Subject: [PATCH 567/831] llama : fix device state save/load (llama/22805) --- ggml/src/ggml-metal/ggml-metal.cpp | 44 +++++++++++++++--------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 35774254983..a1003b3acff 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -87,17 +87,17 @@ static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, } static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = { - /* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_shared_get_base, - /* .init_tensor = */ NULL, - /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, - /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, - /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, - /* .set_tensor_2d = */ NULL, - /* .get_tensor_2d = */ NULL, - /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, - /* .clear = */ ggml_backend_metal_buffer_shared_clear, - /* .reset = */ NULL, + /* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_shared_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, + /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_shared_clear, + /* .reset = */ NULL, }; // private buffer @@ -163,17 +163,17 @@ static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer } static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { - /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_private_get_base, - /* .init_tensor = */ NULL, - /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, - /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, - /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, - /* .set_tensor_2d = */ NULL, - /* .get_tensor_2d = */ NULL, - /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, - /* .clear = */ ggml_backend_metal_buffer_private_clear, - /* .reset = */ NULL, + /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_private_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, + /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_private_clear, + /* .reset = */ NULL, }; static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) { From 6e91ed3b338f221c51c02439abf4c4e952ebe306 Mon Sep 17 00:00:00 2001 From: leonardHONG <2695316095@qq.com> Date: Fri, 8 May 2026 03:59:29 +0800 Subject: [PATCH 568/831] CUDA: batch out_prod inner loop with cublasSgemmStridedBatched (llama/22651) * CUDA: batch out_prod inner loop with cublasSgemmStridedBatched * CUDA: batch out_prod inner loop with cublasSgemmStridedBatched * CUDA: add cublasSgemmStridedBatched mapping for HIP and MUSA backends --- ggml/src/ggml-cuda/out-prod.cu | 30 +++++++++++++++++++++++------- ggml/src/ggml-cuda/vendors/hip.h | 1 + ggml/src/ggml-cuda/vendors/musa.h | 1 + 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu index c9b2b699c6a..499903d09b1 100644 --- a/ggml/src/ggml-cuda/out-prod.cu +++ b/ggml/src/ggml-cuda/out-prod.cu @@ -54,15 +54,31 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t dps2 = ne2 / ne02; const int64_t dps3 = ne3 / ne03; - // TODO batched matrix multiplication - for (int64_t i3 = 0; i3 < ne3; ++i3) { - for (int64_t i2 = 0; i2 < ne2; ++i2) { + if (dps2 == 1 && ne2 > 1) { + // src0 has uniform stride s02 along dim 2; batch the inner loop with a strided GEMM + GGML_ASSERT(ne2 <= std::numeric_limits::max()); + const int batch_count = (int) ne2; + for (int64_t i3 = 0; i3 < ne3; ++i3) { CUBLAS_CHECK( - cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, + cublasSgemmStridedBatched(handle, CUBLAS_OP_N, src1_cublas_op, ne0, ne1, ne01, - &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda, - src1_d + i3 *s13 + i2 *s12, ldb, - &beta, dst_d + i3 *s3 + i2 *s2, ldc)); + &alpha, src0_d + (i3/dps3)*s03, lda, s02, + src1_d + i3 *s13, ldb, s12, + &beta, dst_d + i3 *s3, ldc, s2, + batch_count)); + } + } else { + // Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2 + // with non-uniform stride; would need cublasSgemmBatched with pointer arrays). + for (int64_t i3 = 0; i3 < ne3; ++i3) { + for (int64_t i2 = 0; i2 < ne2; ++i2) { + CUBLAS_CHECK( + cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, + ne0, ne1, ne01, + &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda, + src1_d + i3 *s13 + i2 *s12, ldb, + &beta, dst_d + i3 *s3 + i2 *s2, ldc)); + } } } } diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index e5d363c65d1..5e0e22c7fc2 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -48,6 +48,7 @@ #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS #define cublasSetStream hipblasSetStream #define cublasSgemm hipblasSgemm +#define cublasSgemmStridedBatched hipblasSgemmStridedBatched #define cublasStatus_t hipblasStatus_t #define cublasOperation_t hipblasOperation_t #define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 940c34a9fb2..99e8fa3703e 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -32,6 +32,7 @@ #define cublasSetMathMode mublasSetMathMode #define cublasSetStream mublasSetStream #define cublasSgemm mublasSgemm +#define cublasSgemmStridedBatched mublasSgemmStridedBatched #define cublasStatus_t mublasStatus_t #define cublasOperation_t mublasOperation_t #define cublasGetStatusString mublasGetStatusString From ef77e10404ade9b4e74f5a546d974c9463defe94 Mon Sep 17 00:00:00 2001 From: Shawn Gu Date: Thu, 7 May 2026 21:17:07 -0700 Subject: [PATCH 569/831] opencl: add q4_0 MoE GEMM for Adreno (llama/22731) * Q4_0 MoE CLC pass sanity check * release program * opencl: fix whitespace * opencl: remove unused cl_program * opencl: break #if block to make it more clear * opencl: adjust format --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 296 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 86 +++++ .../kernels/gemm_moe_q4_0_f32_ns.cl | 252 +++++++++++++++ .../kernels/gemv_moe_q4_0_f32_ns.cl | 116 +++++++ 5 files changed, 743 insertions(+), 9 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 0a45a4daa13..ffde6a4f063 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -102,6 +102,8 @@ set(GGML_OPENCL_KERNELS mul_mv_id_q8_0_f32_flat mul_mv_id_mxfp4_f32 mul_mv_id_mxfp4_f32_flat + gemm_moe_q4_0_f32_ns + gemv_moe_q4_0_f32_ns gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 gemm_moe_mxfp4_f32_ns diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index e5a5d42f6fb..4e6f6fb43d2 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -542,6 +542,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_f16_f32_kq; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; + cl_kernel kernel_convert_block_q4_0_trans4_ns, kernel_restore_block_q4_0_trans4_ns; cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; @@ -600,6 +601,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_conv_2d_f16_f32; cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; cl_kernel kernel_timestep_embedding; + cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; cl_kernel kernel_moe_reorder_b; @@ -950,6 +952,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err)); @@ -2884,6 +2888,40 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemv_moe_q4_0_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q4_0_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q4_0_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q4_0_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q4_0_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q4_0_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q4_0_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q4_0_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q4_0_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q4_0_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemv_moe_mxfp4_f32_ns { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3657,11 +3695,14 @@ struct ggml_tensor_extra_cl_q4_0 { CL_CHECK(clReleaseMemObject(d)); d = nullptr; } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } // Currently, q_img and d_img are only initialized when SMALL_ALLOC is // enabled. They point to the images in ggml_backend_opencl_buffer_context. // So, there is no need to release them here. // TODO: initialize them for non SMALL_PATH path, or remove them. - q_img = nullptr; d_img = nullptr; size_q = 0; size_d = 0; @@ -4926,17 +4967,53 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); - //cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe q4_0 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; // The optimized kernels need weights in natural order, so unshuffle. if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle; } - #else +#else cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; - #endif // GGML_OPENCL_USE_ADRENO_KERNELS +#endif // GGML_OPENCL_USE_ADRENO_KERNELS CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); @@ -4952,7 +5029,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; // transpose the weights and scales - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Only do transpose for large, non batched matrix // TODO: use preallocated images instead of sub-buffer then image if (use_adreno_kernels(backend_ctx, tensor)) { @@ -4966,10 +5043,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, // Transpose d as ushort transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); } - #endif // GGML_OPENCL_USE_ADRENO_KERNELS - +#endif // GGML_OPENCL_USE_ADRENO_KERNELS return; - } if (tensor->type == GGML_TYPE_Q4_1) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; @@ -5689,6 +5764,36 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_0_trans4_ns; + + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { ggml_cl_buffer buf_trans_q; ggml_cl_buffer buf_trans_d; @@ -12811,6 +12916,179 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, // subgroup mat vec switch (src0->type) { case GGML_TYPE_Q4_0: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q4_0_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q4_0_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_0->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } // fallback to generic Q4_0 MoE kernel + +#endif // GGML_OPENCL_USE_ADRENO_KERNELS kernel = backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat; if (backend_ctx->gpu_family == INTEL) { diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index c1ad46f4435..c87450dc49e 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -190,6 +190,92 @@ kernel void kernel_restore_block_q4_0_noshuffle( } } +kernel void kernel_convert_block_q4_0_trans4_ns( + global struct block_q4_0 * src0, + __global uint * dst_q, + __global half * dst_d, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK4_0; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_q4_0 * b = src0 + src_blk_offset; + dst_d[dst_blk_offset] = b->d; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_0 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK4_0 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_q[offset] = q_block.x; + dst_q[offset + ne01] = q_block.y; + dst_q[offset + ne01 * 2] = q_block.z; + dst_q[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_q4_0_trans4_ns( + __global uint * src_q, + __global half * src_d, + __global struct block_q4_0 * dst0, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK4_0; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q4_0 * b = dst0 + dst_blk_offset; + b->d = src_d[src_d_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_q[src_q_offset]; + q_block.y = src_q[src_q_offset + ne01]; + q_block.z = src_q[src_q_offset + ne01 * 2]; + q_block.w = src_q[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_0 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK4_0 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + //------------------------------------------------------------------------------ // kernel_convert_block_q4_1 // Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA). diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl new file mode 100644 index 00000000000..02290c17eb1 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl @@ -0,0 +1,252 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +#define dequantize_q4_0(q4, a_f16, scale) \ + a_f16.s0 = (half)((q4.s0 & 0x000F) - 8) * scale; \ + a_f16.s1 = (half)(((q4.s0 & 0x00F0) >> 4) - 8) * scale; \ + a_f16.s2 = (half)(((q4.s0 & 0x0F00) >> 8) - 8) * scale; \ + a_f16.s3 = (half)(((q4.s0 & 0xF000) >> 12) - 8) * scale; \ + a_f16.s4 = (half)((q4.s1 & 0x000F) - 8) * scale; \ + a_f16.s5 = (half)(((q4.s1 & 0x00F0) >> 4) - 8) * scale; \ + a_f16.s6 = (half)(((q4.s1 & 0x0F00) >> 8) - 8) * scale; \ + a_f16.s7 = (half)(((q4.s1 & 0xF000) >> 12) - 8) * scale; \ + a_f16.s8 = (half)((q4.s2 & 0x000F) - 8) * scale; \ + a_f16.s9 = (half)(((q4.s2 & 0x00F0) >> 4) - 8) * scale; \ + a_f16.sa = (half)(((q4.s2 & 0x0F00) >> 8) - 8) * scale; \ + a_f16.sb = (half)(((q4.s2 & 0xF000) >> 12) - 8) * scale; \ + a_f16.sc = (half)((q4.s3 & 0x000F) - 8) * scale; \ + a_f16.sd = (half)(((q4.s3 & 0x00F0) >> 4) - 8) * scale; \ + a_f16.se = (half)(((q4.s3 & 0x0F00) >> 8) - 8) * scale; \ + a_f16.sf = (half)(((q4.s3 & 0xF000) >> 12) - 8) * scale; \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_q4_0_f32_ns( + __read_only image1d_buffer_t src0_q, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale for current Q4_0 block + uint s_offset = s_sub_offset + get_global_id(0); + half s = src0_d[s_offset]; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_0(as_ushort4(q4x16), reg_a, s); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 q (64-bits) in transposed layout + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_0(as_ushort4(q4x16), reg_a, s); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl new file mode 100644 index 00000000000..6f4d3f53216 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl @@ -0,0 +1,116 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_Q4_0 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q4_0_to_fp32_packed8(ushort2 q4x8) { + float8 fp32x8; + fp32x8.s0 = (float)((q4x8.s0 & 0x000F) - 8); + fp32x8.s1 = (float)(((q4x8.s0 & 0x00F0) >> 4) - 8); + fp32x8.s2 = (float)(((q4x8.s0 & 0x0F00) >> 8) - 8); + fp32x8.s3 = (float)(((q4x8.s0 & 0xF000) >> 12) - 8); + fp32x8.s4 = (float)((q4x8.s1 & 0x000F) - 8); + fp32x8.s5 = (float)(((q4x8.s1 & 0x00F0) >> 4) - 8); + fp32x8.s6 = (float)(((q4x8.s1 & 0x0F00) >> 8) - 8); + fp32x8.s7 = (float)(((q4x8.s1 & 0xF000) >> 12) - 8); + return fp32x8; +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q4_0_f32_ns( + __global uint * src0_q, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_Q4_0); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_q[block_offset]; + regQ.s1 = src0_q[block_offset + ne01]; + regQ.s2 = src0_q[block_offset + ne01 * 2]; + regQ.s3 = src0_q[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + float8 fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s0)); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s1)); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s2)); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s3)); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * fp32x8.hi; + + half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; + sum += (float)(regS) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} From eb38a02de13c2778c18a514a4c93f3e49dda016d Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Thu, 7 May 2026 22:43:04 -0700 Subject: [PATCH 570/831] ggml: update SCHED_DEBUG output to use ggml_op_desc() (llama/22825) --- ggml/src/ggml-backend.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index d9f8aaec52f..4e36909f45e 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -965,7 +965,7 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str } if (sched->debug > 1) { ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node); - GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name, + GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_desc(node), node->name, fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node), graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0); for (int j = 0; j < GGML_MAX_SRC; j++) { From 803424ac5a03c1b05945b432f1ea94e1e1b5b1bb Mon Sep 17 00:00:00 2001 From: miyan <1138989048@qq.com> Date: Fri, 8 May 2026 15:35:22 +0800 Subject: [PATCH 571/831] vulkan: fix spv shadowing (llama/22760) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 423e01dbff1..0a7931002ab 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2149,11 +2149,11 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin // Patch SPIR-V to enable RTE rounding for FP16, avoiding the need for // separate shader variants compiled with -DRTE16. - std::vector spv; + std::vector spirv; if (device->float_controls_rte_fp16) { const uint32_t* spv_words = reinterpret_cast(spv_data); size_t word_count = spv_size / sizeof(uint32_t); - spv.assign(spv_words, spv_words + word_count); + spirv.assign(spv_words, spv_words + word_count); // Find insertion points respecting SPIR-V layout order: // Header(5) -> OpCapability -> OpExtension -> ... -> OpEntryPoint -> OpExecutionMode -> ... @@ -2163,9 +2163,9 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin size_t exec_insert_pos = pos; uint32_t entry_point_id = 0; - while (pos < spv.size()) { - uint32_t opcode = spv[pos] & spv::OpCodeMask; - uint32_t len = spv[pos] >> spv::WordCountShift; + while (pos < spirv.size()) { + uint32_t opcode = spirv[pos] & spv::OpCodeMask; + uint32_t len = spirv[pos] >> spv::WordCountShift; if (len == 0) break; if (opcode == spv::OpCapability) { @@ -2174,7 +2174,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } else if (opcode == spv::OpExtension) { ext_insert_pos = pos + len; } else if (opcode == spv::OpEntryPoint) { - entry_point_id = spv[pos + 2]; + entry_point_id = spirv[pos + 2]; exec_insert_pos = pos + len; } else if (opcode == spv::OpExecutionMode || opcode == spv::OpExecutionModeId) { exec_insert_pos = pos + len; @@ -2189,7 +2189,7 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin // OpExecutionMode %entrypoint RoundingModeRTE 16 uint32_t exec_mode[] = { (4u << spv::WordCountShift) | spv::OpExecutionMode, entry_point_id, spv::ExecutionModeRoundingModeRTE, 16 }; - spv.insert(spv.begin() + exec_insert_pos, std::begin(exec_mode), std::end(exec_mode)); + spirv.insert(spirv.begin() + exec_insert_pos, std::begin(exec_mode), std::end(exec_mode)); // OpExtension "SPV_KHR_float_controls" const char ext_str[] = "SPV_KHR_float_controls"; @@ -2197,13 +2197,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin std::vector extension(1 + ext_str_words, 0); extension[0] = (uint32_t)((1 + ext_str_words) << spv::WordCountShift) | spv::OpExtension; memcpy(&extension[1], ext_str, sizeof(ext_str)); - spv.insert(spv.begin() + ext_insert_pos, extension.begin(), extension.end()); + spirv.insert(spirv.begin() + ext_insert_pos, extension.begin(), extension.end()); // OpCapability RoundingModeRTE uint32_t capability[] = { (2u << spv::WordCountShift) | spv::OpCapability, spv::CapabilityRoundingModeRTE }; - spv.insert(spv.begin() + cap_insert_pos, std::begin(capability), std::end(capability)); + spirv.insert(spirv.begin() + cap_insert_pos, std::begin(capability), std::end(capability)); - shader_module_create_info = vk::ShaderModuleCreateInfo({}, spv.size() * sizeof(uint32_t), spv.data()); + shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data()); } pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); From ea459fba9d7c88bb2137f32607d7a28f8538e1b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Fri, 8 May 2026 10:09:38 +0200 Subject: [PATCH 572/831] CUDA: lower-case PCI bus id, standardize for ggml (llama/22820) --- ggml/include/ggml-backend.h | 2 +- ggml/src/ggml-cuda/ggml-cuda.cu | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index d0c7e5a1be0..b6f73739809 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -169,7 +169,7 @@ extern "C" { // device type enum ggml_backend_dev_type type; // device id - // for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0") + // for PCI devices, this should be the lower-case PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:c1:00.0") // if the id is unknown, this should be NULL const char * device_id; // device capabilities diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 8d21b2267f5..925a9ffe04c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -5434,6 +5434,9 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { char pci_bus_id[32] = {}; CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i)); dev_ctx->pci_bus_id = pci_bus_id; + for (char & c : dev_ctx->pci_bus_id) { + c = std::tolower(c); + } dev_ctx->op_offload_min_batch_size = min_batch_size; ggml_backend_dev_t dev = new ggml_backend_device { From 184f1a1383e3e6917527e723f8b4ccf9fd571550 Mon Sep 17 00:00:00 2001 From: Pascal Date: Fri, 8 May 2026 11:44:09 +0200 Subject: [PATCH 573/831] cuda: fuse snake activation (mul, sin, sqr, mul, add) (llama/22667) * cuda: fuse snake activation (mul, sin, sqr, mul, add) Add ggml_cuda_op_snake_fused with F32 / F16 / BF16 templates. The matcher recognizes the naive 5 op decomposition emitted by audio decoders (BigVGAN, Vocos) for snake activation y = x + sin(a*x)^2 * inv_b and rewrites it to a single elementwise kernel. Add test_snake_fuse comparing CPU naive vs CUDA fused across F32 / F16 / BF16. * cuda: address review feedback from @am17an Use ggml_cuda_cast for F32/F16/BF16 conversions and rename kernel_snake to snake_kernel to match upstream conventions. * cuda: snake fusion fastdiv on T_len, Suggested-by: @am17an * Update tests/test-backend-ops.cpp Co-authored-by: Aman Gupta * cuda: snake fusion check add->type matches x->type Address review feedback from @am17an * cuda: snake fusion check add->type matches x->type Moved for readability (equivalent) Address review feedback from @am17an --------- Co-authored-by: Aman Gupta --- ggml/src/ggml-cuda/ggml-cuda.cu | 30 ++++++++++++++ ggml/src/ggml-cuda/snake.cu | 72 +++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/snake.cuh | 8 ++++ 3 files changed, 110 insertions(+) create mode 100644 ggml/src/ggml-cuda/snake.cu create mode 100644 ggml/src/ggml-cuda/snake.cuh diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 925a9ffe04c..4df1b930882 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -39,6 +39,7 @@ #include "ggml-cuda/rope.cuh" #include "ggml-cuda/roll.cuh" #include "ggml-cuda/scale.cuh" +#include "ggml-cuda/snake.cuh" #include "ggml-cuda/softcap.cuh" #include "ggml-cuda/softmax.cuh" #include "ggml-cuda/ssm-conv.cuh" @@ -3757,6 +3758,35 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph return 2; } + // Snake activation: y = x + sin(a*x)^2 * inv_b + // Naive 5-op decomposition emitted by frontends: mul -> sin -> sqr -> mul -> add + if (ggml_can_fuse_subgraph(cgraph, i, + { GGML_OP_MUL, GGML_OP_SIN, GGML_OP_SQR, GGML_OP_MUL, GGML_OP_ADD }, + { i + 4 })) { + const ggml_tensor * mul0 = cgraph->nodes[i]; + const ggml_tensor * sqr = cgraph->nodes[i + 2]; + const ggml_tensor * mul1 = cgraph->nodes[i + 3]; + ggml_tensor * add = cgraph->nodes[i + 4]; + + // x carries the full activation shape, a is the broadcast operand + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + // mul1 reads sqr and inv_b in either operand order + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + + // closure check: the trailing add must read the same x as the leading mul + const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; + + const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16); + const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1]; + + if (type_ok && shape_ok && x_in_add == x && add->type == x->type) { + ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add); + return 4; + } + } + // multi-(add or mul) if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) { int n_fuse = 0; diff --git a/ggml/src/ggml-cuda/snake.cu b/ggml/src/ggml-cuda/snake.cu new file mode 100644 index 00000000000..384638c1f47 --- /dev/null +++ b/ggml/src/ggml-cuda/snake.cu @@ -0,0 +1,72 @@ +#include "snake.cuh" +#include "convert.cuh" + +// Fused Snake activation: y = x + sin^2(a * x) * inv_b +// x: [T, C] (T contiguous), a: [1, C], inv_b: [1, C] +// Supports F32, F16, BF16 data with F32 compute. + +template +static __global__ void snake_kernel( + const T * __restrict__ x, + const float * __restrict__ a, + const float * __restrict__ inv_b, + T * __restrict__ dst, + const int total, + const uint3 T_len_fastdiv) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total) return; + + const int c = (int) fastdiv((uint32_t) idx, T_len_fastdiv); + + const float xi = ggml_cuda_cast(x[idx]); + const float s = sinf(a[c] * xi); + dst[idx] = ggml_cuda_cast(xi + s * s * inv_b[c]); +} + +// Internal launcher with explicit x/a/inv_b/dst tensors. +// Shared by the public op (reads dst->src) and the fusion path (explicit args). +static void launch_snake(ggml_backend_cuda_context & ctx, + const ggml_tensor * x, + const ggml_tensor * a, + const ggml_tensor * inv_b, + ggml_tensor * dst) { + const float * a_d = (const float *)a->data; + const float * inv_b_d = (const float *)inv_b->data; + + const int T = (int)x->ne[0]; + const int C = (int)x->ne[1]; + const int total = T * C; + const uint3 T_len_fastdiv = init_fastdiv_values((uint64_t) T); + + const int block_size = 256; + const int grid_size = (total + block_size - 1) / block_size; + + cudaStream_t stream = ctx.stream(); + + switch (x->type) { + case GGML_TYPE_F32: { + snake_kernel<<>>( + (const float *)x->data, a_d, inv_b_d, (float *)dst->data, total, T_len_fastdiv); + } break; + case GGML_TYPE_F16: { + snake_kernel<<>>( + (const half *)x->data, a_d, inv_b_d, (half *)dst->data, total, T_len_fastdiv); + } break; + case GGML_TYPE_BF16: { + snake_kernel<<>>( + (const nv_bfloat16 *)x->data, a_d, inv_b_d, (nv_bfloat16 *)dst->data, total, T_len_fastdiv); + } break; + default: + GGML_ABORT("snake: unsupported type"); + } +} + +// Fusion entry: caller supplies x/a/inv_b explicitly from the matched +// mul -> sin -> sqr -> mul -> add pattern. The dst is the trailing add output. +void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx, + const ggml_tensor * x, + const ggml_tensor * a, + const ggml_tensor * inv_b, + ggml_tensor * dst) { + launch_snake(ctx, x, a, inv_b, dst); +} diff --git a/ggml/src/ggml-cuda/snake.cuh b/ggml/src/ggml-cuda/snake.cuh new file mode 100644 index 00000000000..7f6f1cb3b41 --- /dev/null +++ b/ggml/src/ggml-cuda/snake.cuh @@ -0,0 +1,8 @@ +#include "common.cuh" + +// Fusion entry point. Caller supplies x/a/inv_b explicitly. +void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx, + const ggml_tensor * x, + const ggml_tensor * a, + const ggml_tensor * inv_b, + ggml_tensor * dst); From e0573051c6e7f814db4353429980482f663a0057 Mon Sep 17 00:00:00 2001 From: Pranav Dhinakar Date: Fri, 8 May 2026 13:41:40 -0700 Subject: [PATCH 574/831] Feature hexagon l2 norm (llama/22816) * L2_NORM Updates * Addressed PR Comments * ggml-hexagon: add L2_NORM HVX kernel for Hexagon backend * hex-unary: remove supported_unary_nc since the outer loop is the same for all unary ops --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 9 ++- ggml/src/ggml-hexagon/htp/htp-ops.h | 2 + ggml/src/ggml-hexagon/htp/main.c | 1 + ggml/src/ggml-hexagon/htp/unary-ops.c | 81 ++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index df4ed101464..8ddd1915c83 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2420,8 +2420,8 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses return false; } - // TODO: add support for non-contiguous elements within a row - if (!ggml_is_contiguous_rows(src0) || !ggml_is_contiguous_rows(dst)) { + // dst must be contiguous; src0 may be non-contiguous + if (!ggml_is_contiguous(dst)) { return false; } @@ -2791,6 +2791,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; + case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; case GGML_OP_SCALE: return HTP_OP_SCALE; case GGML_OP_SQR: return HTP_OP_SQR; @@ -3253,6 +3254,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_add_id(sess, op); break; + case GGML_OP_L2_NORM: + supp = ggml_hexagon_supported_unary(sess, op); + break; + case GGML_OP_RMS_NORM: case GGML_OP_SCALE: supp = ggml_hexagon_supported_unary(sess, op); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 66a3150c1a0..ef96ad38278 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -83,6 +83,8 @@ enum htp_op_code { HTP_OP_FILL, HTP_OP_DIAG, HTP_OP_SOLVE_TRI, + HTP_OP_L2_NORM, + HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 49c1a15b344..e18f1a0e61e 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -542,6 +542,7 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_UNARY_SIGMOID: case HTP_OP_UNARY_NEG: case HTP_OP_UNARY_EXP: + case HTP_OP_L2_NORM: return op_unary(octx); case HTP_OP_UNARY_SILU: diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 819cdc49bd9..26a0e0bd793 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -298,6 +298,81 @@ static void softplus_f32(const float * restrict src, } } +// --- L2_NORM HVX kernel --- +// Computes y[i] = x[i] / fmax(sqrt(sum(x[j]^2)), epsilon) for each row. +// scale = 1/fmax(sqrt(sum), epsilon) is computed entirely in HVX registers +// using rsqrt + inverse to avoid scalar extraction. +static void hvx_fast_l2_norm_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict pad, + const int num_elems, + float epsilon) { + (void)pad; + + const HVX_Vector * restrict v_src = (HVX_Vector *) src; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + HVX_Vector sum_v = hvx_vec_splat_f32(0.0f); + + const int nvec = num_elems / VLEN_FP32; + const int nloe = num_elems % VLEN_FP32; + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq); + } + + // Include tail elements in the sum-of-squares using a predicate mask + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq); + } + + // Compute scale = 1/fmax(sqrt(sum), epsilon) entirely in HVX registers. + // hvx_vec_rsqrt_f32 + hvx_vec_inverse_f32 avoids scalar extraction. + HVX_Vector sum_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); + HVX_Vector rsqrt_v = hvx_vec_rsqrt_f32(sum_sf); // 1/sqrt(sum) + HVX_Vector sqrt_v = hvx_vec_inverse_f32(rsqrt_v); // sqrt(sum) + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + HVX_Vector denom_v = Q6_Vsf_vmax_VsfVsf(sqrt_v, epsilon_v); // fmax(sqrt(sum), epsilon) + HVX_Vector scale_v = hvx_vec_inverse_f32(denom_v); // 1/fmax(sqrt(sum), epsilon) + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + v_dst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v)); + } + + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector result = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v)); + hvx_vec_store_a(&v_dst[nvec], nloe * 4, result); + } +} + +static void l2_norm_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size)); + float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size)); + + hvx_fast_l2_norm_f32((const uint8_t *)src_f, (uint8_t *)dst_f, spad, row_elems, epsilon); + } +} + static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { const struct htp_unary_context * uctx = (const struct htp_unary_context *) data; struct htp_ops_context * octx = uctx->octx; @@ -402,6 +477,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_UNARY_SOFTPLUS: softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_L2_NORM: + l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; default: break; } @@ -469,6 +547,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_UNARY_SOFTPLUS: op_type = "softplus-f32"; break; + case HTP_OP_L2_NORM: + op_type = "l2norm-f32"; + break; default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); From 892f786a653d19c2f474b3a9c56d9d4d7be2fb1f Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Fri, 8 May 2026 17:05:22 -0700 Subject: [PATCH 575/831] sycl: support non-contiguous input in PAD op (llama/22148) Signed-off-by: Chun Tao Co-authored-by: Chun Tao Co-authored-by: Todd Malsbary --- ggml/src/ggml-sycl/ggml-sycl.cpp | 3 +- ggml/src/ggml-sycl/pad.cpp | 54 ++++++++++++++++---------------- 2 files changed, 28 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 29ecedb5de9..c3ac281067a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -5104,11 +5104,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ACC: return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); case GGML_OP_PAD: - // TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985 if (ggml_get_op_params_i32(op, 8) != 0) { return false; } - return ggml_is_contiguous(op->src[0]); + return true; case GGML_OP_LEAKY_RELU: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_RWKV_WKV6: diff --git a/ggml/src/ggml-sycl/pad.cpp b/ggml/src/ggml-sycl/pad.cpp index f989c5e4b8b..ee93bb51801 100644 --- a/ggml/src/ggml-sycl/pad.cpp +++ b/ggml/src/ggml-sycl/pad.cpp @@ -13,7 +13,8 @@ //#include "common.hpp" #include "pad.hpp" -static void pad_f32(const float * src, float * dst, +static void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, + float * dst, const int lp0, const int rp0, const int lp1, const int rp1, const int lp2, const int rp2, const int lp3, const int rp3, const int ne0, const int ne1, const int ne2, const int ne3, @@ -27,7 +28,6 @@ static void pad_f32(const float * src, float * dst, return; } - // operation const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && @@ -37,12 +37,8 @@ static void pad_f32(const float * src, float * dst, const int64_t i01 = i1 - lp1; const int64_t i02 = i2 - lp2; const int64_t i03 = i3 - lp3; - const int64_t ne02 = ne2 - lp2 - rp2; - const int64_t ne01 = ne1 - lp1 - rp1; - const int64_t ne00 = ne0 - lp0 - rp0; - const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + - i02 * (ne00 * ne01) + i01 * ne00 + i00; + const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; dst[dst_idx] = src[src_idx]; } else { @@ -50,20 +46,19 @@ static void pad_f32(const float * src, float * dst, } } -static void pad_f32_sycl(const float *src, float *dst, const int lp0, - const int rp0, const int lp1, const int rp1, - const int lp2, const int rp2, const int lp3, - const int rp3, const int ne0, const int ne1, - const int ne2, const int ne3, +static void pad_f32_sycl(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, + float * dst, const int lp0, const int rp0, const int lp1, const int rp1, + const int lp2, const int rp2, const int lp3, const int rp3, + const int ne0, const int ne1, const int ne2, const int ne3, dpct::queue_ptr stream) { int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; - dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3); + sycl::range<3> grid(ne2 * ne3, ne1, num_blocks); stream->parallel_for( - sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), + sycl::nd_range<3>(grid * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, - ne2, ne3, item_ct1); + pad_f32(src, s00, s01, s02, s03, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, + ne0, ne1, ne2, ne3, item_ct1); }); } @@ -71,22 +66,27 @@ void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; float * dst_d = (float *)dst->data; - dpct::queue_ptr stream = ctx.stream(); + dpct::queue_ptr stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); - const int32_t lp0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t rp0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t lp1 = ((const int32_t*)(dst->op_params))[2]; - const int32_t rp1 = ((const int32_t*)(dst->op_params))[3]; - const int32_t lp2 = ((const int32_t*)(dst->op_params))[4]; - const int32_t rp2 = ((const int32_t*)(dst->op_params))[5]; - const int32_t lp3 = ((const int32_t*)(dst->op_params))[6]; - const int32_t rp3 = ((const int32_t*)(dst->op_params))[7]; + const size_t ts = ggml_type_size(src0->type); + const size_t s00 = src0->nb[0] / ts; + const size_t s01 = src0->nb[1] / ts; + const size_t s02 = src0->nb[2] / ts; + const size_t s03 = src0->nb[3] / ts; - pad_f32_sycl(src0_d, dst_d, + const int32_t lp0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t rp0 = ((const int32_t *)(dst->op_params))[1]; + const int32_t lp1 = ((const int32_t *)(dst->op_params))[2]; + const int32_t rp1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t lp2 = ((const int32_t *)(dst->op_params))[4]; + const int32_t rp2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t lp3 = ((const int32_t *)(dst->op_params))[6]; + const int32_t rp3 = ((const int32_t *)(dst->op_params))[7]; + + pad_f32_sycl(src0_d, s00, s01, s02, s03, dst_d, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); } From 42aea65eda1b4100bedb42ea1cc7344da9f1054a Mon Sep 17 00:00:00 2001 From: Yanzhao Wang Date: Fri, 8 May 2026 17:12:04 -0700 Subject: [PATCH 576/831] hexagon: add HTP kernel for GGML_OP_GATED_DELTA_NET (llama/22837) Implement the Gated Delta Net recurrence on HVX with: - 4-row fused kernels for PP (prompt processing) path - 8-row fused kernels for TG (token generation) path, reducing K/Q/gate vector reload overhead by 2x - Separate PP/TG thread functions for I-cache isolation - VTCM state scratchpad with DMA in/out for TG single-cycle access - Vectorized gate exp via hvx_exp_f32 --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 111 +- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + .../ggml-hexagon/htp/gated-delta-net-ops.c | 955 ++++++++++++++++++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + 6 files changed, 1045 insertions(+), 27 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 8ddd1915c83..d3c125dbc3d 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2261,6 +2261,58 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess return true; } +static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * q = op->src[0]; + const struct ggml_tensor * k = op->src[1]; + const struct ggml_tensor * v = op->src[2]; + const struct ggml_tensor * g = op->src[3]; + const struct ggml_tensor * beta = op->src[4]; + const struct ggml_tensor * state = op->src[5]; + const struct ggml_tensor * dst = op; + + if (!q || !k || !v || !g || !beta || !state) { + return false; + } + + if (q->type != GGML_TYPE_F32 || k->type != GGML_TYPE_F32 || v->type != GGML_TYPE_F32 || + g->type != GGML_TYPE_F32 || beta->type != GGML_TYPE_F32 || state->type != GGML_TYPE_F32 || + dst->type != GGML_TYPE_F32) { + return false; + } + + if (!ggml_is_contiguous_rows(q) || !ggml_is_contiguous_rows(k) || !ggml_is_contiguous_rows(v) || + !ggml_is_contiguous(g) || !ggml_is_contiguous(beta) || !ggml_is_contiguous(state) || + !ggml_is_contiguous(dst)) { + return false; + } + + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; + const int64_t n_tokens = v->ne[2]; + const int64_t n_seqs = v->ne[3]; + + if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) { + return false; + } + if (q->ne[0] != S_v || k->ne[0] != S_v || q->ne[1] <= 0 || k->ne[1] <= 0 || + q->ne[2] != n_tokens || k->ne[2] != n_tokens || q->ne[3] <= 0 || k->ne[3] <= 0 || + (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) { + return false; + } + if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) { + return false; + } + if (ggml_nelements(state) != S_v * S_v * H * n_seqs) { + return false; + } + if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; @@ -2777,33 +2829,34 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) { static htp_op_code op_remap_to_htp(const ggml_tensor * t) { switch (t->op) { - case GGML_OP_FLASH_ATTN_EXT: return HTP_OP_FLASH_ATTN_EXT; - case GGML_OP_MUL_MAT: return HTP_OP_MUL_MAT; - case GGML_OP_MUL_MAT_ID: return HTP_OP_MUL_MAT_ID; - case GGML_OP_MUL: return HTP_OP_MUL; - case GGML_OP_ADD: return HTP_OP_ADD; - case GGML_OP_ADD_ID: return HTP_OP_ADD_ID; - case GGML_OP_SUB: return HTP_OP_SUB; - case GGML_OP_DIV: return HTP_OP_DIV; - case GGML_OP_CPY: return HTP_OP_CPY; - case GGML_OP_CONT: return HTP_OP_CPY; - case GGML_OP_GET_ROWS: return HTP_OP_GET_ROWS; - case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; - case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; - case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; - case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; - case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; - case GGML_OP_SCALE: return HTP_OP_SCALE; - case GGML_OP_SQR: return HTP_OP_SQR; - case GGML_OP_SQRT: return HTP_OP_SQRT; - case GGML_OP_SOFT_MAX: return HTP_OP_SOFTMAX; - case GGML_OP_SSM_CONV: return HTP_OP_SSM_CONV; - case GGML_OP_ROPE: return HTP_OP_ROPE; - case GGML_OP_REPEAT: return HTP_OP_REPEAT; - case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; - case GGML_OP_FILL: return HTP_OP_FILL; - case GGML_OP_DIAG: return HTP_OP_DIAG; - case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; + case GGML_OP_FLASH_ATTN_EXT: return HTP_OP_FLASH_ATTN_EXT; + case GGML_OP_MUL_MAT: return HTP_OP_MUL_MAT; + case GGML_OP_MUL_MAT_ID: return HTP_OP_MUL_MAT_ID; + case GGML_OP_MUL: return HTP_OP_MUL; + case GGML_OP_ADD: return HTP_OP_ADD; + case GGML_OP_ADD_ID: return HTP_OP_ADD_ID; + case GGML_OP_SUB: return HTP_OP_SUB; + case GGML_OP_DIV: return HTP_OP_DIV; + case GGML_OP_CPY: return HTP_OP_CPY; + case GGML_OP_CONT: return HTP_OP_CPY; + case GGML_OP_GET_ROWS: return HTP_OP_GET_ROWS; + case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; + case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; + case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; + case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; + case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; + case GGML_OP_SCALE: return HTP_OP_SCALE; + case GGML_OP_SQR: return HTP_OP_SQR; + case GGML_OP_SQRT: return HTP_OP_SQRT; + case GGML_OP_SOFT_MAX: return HTP_OP_SOFTMAX; + case GGML_OP_SSM_CONV: return HTP_OP_SSM_CONV; + case GGML_OP_GATED_DELTA_NET: return HTP_OP_GATED_DELTA_NET; + case GGML_OP_ROPE: return HTP_OP_ROPE; + case GGML_OP_REPEAT: return HTP_OP_REPEAT; + case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; + case GGML_OP_FILL: return HTP_OP_FILL; + case GGML_OP_DIAG: return HTP_OP_DIAG; + case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; @@ -3341,6 +3394,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_ssm_conv(sess, op); break; + case GGML_OP_GATED_DELTA_NET: + supp = ggml_hexagon_supported_gated_delta_net(sess, op); + break; + case GGML_OP_CUMSUM: supp = ggml_hexagon_supported_cumsum(sess, op); break; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 7c9e4cda5f1..bcadac11f95 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -37,6 +37,7 @@ add_library(${HTP_LIB} SHARED fill-ops.c diag-ops.c solve-tri-ops.c + gated-delta-net-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c new file mode 100644 index 00000000000..2e84badc9b7 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c @@ -0,0 +1,955 @@ +#include +#include +#include + +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +#define HTP_GDN_MAX_SV 128 + +struct htp_gdn_context { + struct htp_ops_context * octx; + uint32_t rows_per_thread; + size_t state_bytes; + bool use_vtcm; + uint8_t * vtcm_state_base; + size_t vtcm_state_per_thread; +}; + +static inline float gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, + const float * restrict dot, uint32_t n) { + HVX_Vector acc = Q6_V_vzero(); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vm = hvx_vmem(mul + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + hvx_vmemu(dst + i * epv) = out; + acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + hvx_vec_store_u(dst + off, tail * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); + acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); + } + + return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); +} + +static inline float gdn_mul_scalar_dot_f32(float * restrict dst, float mul, + const float * restrict dot, uint32_t n) { + HVX_Vector acc = Q6_V_vzero(); + const HVX_Vector vmul = hvx_vec_splat_f32(mul); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + hvx_vmemu(dst + i * epv) = out; + acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + hvx_vec_store_u(dst + off, tail * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); + acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); + } + + return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); +} + +static inline float gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src, + float scale, const float * restrict dot, uint32_t n) { + HVX_Vector acc = Q6_V_vzero(); + const HVX_Vector vscale = hvx_vec_splat_f32(scale); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vs = hvx_vmem(src + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + hvx_vmemu(dst + i * epv) = out; + acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vs = hvx_vmem(src + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + hvx_vec_store_u(dst + off, tail * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); + acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); + } + + return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); +} + +static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, const float * restrict mul, + const float * restrict dot, uint32_t n, float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vm = hvx_vmem(mul + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm); + + hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + } + + HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } }; + hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc)); +} + +static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, float mul, + const float * restrict dot, uint32_t n, float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + const HVX_Vector vmul = hvx_vec_splat_f32(mul); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul); + + hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + } + + HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } }; + hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc)); +} + +static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, const float * restrict src, + const float * restrict scale, const float * restrict dot, uint32_t n, + float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]); + const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]); + const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]); + const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vs = hvx_vmem(src + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0)); + HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1)); + HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2)); + HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3)); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vs = hvx_vmem(src + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); + HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1)); + HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2)); + HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3)); + + hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + } + + HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } }; + hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc)); +} + +static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, float * restrict dst4, + float * restrict dst5, float * restrict dst6, float * restrict dst7, + const float * restrict mul, const float * restrict dot, uint32_t n, + float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + HVX_Vector acc4 = Q6_V_vzero(); + HVX_Vector acc5 = Q6_V_vzero(); + HVX_Vector acc6 = Q6_V_vzero(); + HVX_Vector acc7 = Q6_V_vzero(); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vm = hvx_vmem(mul + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm); + HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vm); + HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vm); + HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vm); + HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vm); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + hvx_vmemu(dst4 + i * epv) = out4; + hvx_vmemu(dst5 + i * epv) = out5; + hvx_vmemu(dst6 + i * epv) = out6; + hvx_vmemu(dst7 + i * epv) = out7; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot)); + acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot)); + acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot)); + acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm); + HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vm); + HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vm); + HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vm); + HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vm); + + hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero)); + acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero)); + acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero)); + acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero)); + } + + HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } }; + HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } }; + hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA)); + hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB)); +} + +static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, float * restrict dst4, + float * restrict dst5, float * restrict dst6, float * restrict dst7, + float mul, const float * restrict dot, uint32_t n, float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + HVX_Vector acc4 = Q6_V_vzero(); + HVX_Vector acc5 = Q6_V_vzero(); + HVX_Vector acc6 = Q6_V_vzero(); + HVX_Vector acc7 = Q6_V_vzero(); + const HVX_Vector vmul = hvx_vec_splat_f32(mul); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul); + HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vmul); + HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vmul); + HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vmul); + HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vmul); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + hvx_vmemu(dst4 + i * epv) = out4; + hvx_vmemu(dst5 + i * epv) = out5; + hvx_vmemu(dst6 + i * epv) = out6; + hvx_vmemu(dst7 + i * epv) = out7; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot)); + acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot)); + acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot)); + acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul); + HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vmul); + HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vmul); + HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vmul); + HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vmul); + + hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero)); + acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero)); + acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero)); + acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero)); + } + + HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } }; + HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } }; + hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA)); + hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB)); +} + +static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, float * restrict dst4, + float * restrict dst5, float * restrict dst6, float * restrict dst7, + const float * restrict src, const float * restrict scale, + const float * restrict dot, uint32_t n, float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + HVX_Vector acc4 = Q6_V_vzero(); + HVX_Vector acc5 = Q6_V_vzero(); + HVX_Vector acc6 = Q6_V_vzero(); + HVX_Vector acc7 = Q6_V_vzero(); + const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]); + const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]); + const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]); + const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]); + const HVX_Vector scale4 = hvx_vec_splat_f32(scale[4]); + const HVX_Vector scale5 = hvx_vec_splat_f32(scale[5]); + const HVX_Vector scale6 = hvx_vec_splat_f32(scale[6]); + const HVX_Vector scale7 = hvx_vec_splat_f32(scale[7]); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t tail = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vs = hvx_vmem(src + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0)); + HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1)); + HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2)); + HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3)); + HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + i * epv), hvx_vec_mul_f32_f32(vs, scale4)); + HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + i * epv), hvx_vec_mul_f32_f32(vs, scale5)); + HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + i * epv), hvx_vec_mul_f32_f32(vs, scale6)); + HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + i * epv), hvx_vec_mul_f32_f32(vs, scale7)); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + hvx_vmemu(dst4 + i * epv) = out4; + hvx_vmemu(dst5 + i * epv) = out5; + hvx_vmemu(dst6 + i * epv) = out6; + hvx_vmemu(dst7 + i * epv) = out7; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot)); + acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot)); + acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot)); + acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); + } + + if (tail) { + const uint32_t off = nvec * epv; + HVX_Vector vs = hvx_vmem(src + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); + HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1)); + HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2)); + HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3)); + HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + off), hvx_vec_mul_f32_f32(vs, scale4)); + HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + off), hvx_vec_mul_f32_f32(vs, scale5)); + HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + off), hvx_vec_mul_f32_f32(vs, scale6)); + HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + off), hvx_vec_mul_f32_f32(vs, scale7)); + + hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero)); + acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero)); + acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero)); + acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero)); + } + + HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } }; + HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } }; + hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA)); + hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB)); +} + +static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_gdn_context * gctx = (struct htp_gdn_context *) data; + struct htp_ops_context * octx = gctx->octx; + + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * g = octx->src[3]; + const struct htp_tensor * beta = octx->src[4]; + const struct htp_tensor * state = octx->src[5]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t S_v = v->ne[0]; + const uint32_t H = v->ne[1]; + const uint32_t n_tokens = v->ne[2]; + const uint32_t n_seqs = v->ne[3]; + + const uint32_t total_rows = H * n_seqs; + if (ith >= total_rows) { + return; + } + + const uint32_t rq3 = n_seqs / q->ne[3]; + const uint32_t rk3 = n_seqs / k->ne[3]; + const float scale = 1.0f / sqrtf((float) S_v); + + float * dst_base = (float *) (uintptr_t) dst->data; + float * state_out_base = dst_base + (uint64_t) S_v * H * n_tokens * n_seqs; + const float * state_in_base = (const float *) (uintptr_t) state->data; + + const bool kda = (g->ne[0] == S_v); + float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_sums[4] __attribute__((aligned(128))); + + for (uint32_t ir = ith; ir < total_rows; ir += nth) { + const uint32_t iv1 = ir % H; + const uint32_t iv3 = ir / H; + + const uint32_t iq1 = iv1 % q->ne[1]; + const uint32_t ik1 = iv1 % k->ne[1]; + const uint32_t iq3 = iv3 / rq3; + const uint32_t ik3 = iv3 / rk3; + + float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + + memcpy(s_out, s_in, gctx->state_bytes); + float * s_work = s_out; + + float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v; + + for (uint32_t t = 0; t < n_tokens; ++t) { + const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data + + (uint64_t) iq3 * q->nb[3] + (uint64_t) t * q->nb[2] + (uint64_t) iq1 * q->nb[1]); + const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data + + (uint64_t) ik3 * k->nb[3] + (uint64_t) t * k->nb[2] + (uint64_t) ik1 * k->nb[1]); + const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data + + (uint64_t) iv3 * v->nb[3] + (uint64_t) t * v->nb[2] + (uint64_t) iv1 * v->nb[1]); + const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data + + (uint64_t) iv3 * g->nb[3] + (uint64_t) t * g->nb[2] + (uint64_t) iv1 * g->nb[1]); + const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + + (uint64_t) iv3 * beta->nb[3] + (uint64_t) t * beta->nb[2] + (uint64_t) iv1 * beta->nb[1]); + + memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); + memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); + + if (kda) { + hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); + + uint32_t j = 0; + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work + (uint64_t) (j + 0) * S_v; + float * row1 = s_work + (uint64_t) (j + 1) * S_v; + float * row2 = s_work + (uint64_t) (j + 2) * S_v; + float * row3 = s_work + (uint64_t) (j + 3) * S_v; + gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); + float local_delta_b[4] __attribute__((aligned(128))); + for (uint32_t r = 0; r < 4; ++r) { + local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; + } + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + for (uint32_t r = 0; r < 4; ++r) { + attn_data[j + r] = local_sums[r] * scale; + } + } + for (; j < S_v; ++j) { + float * row = s_work + (uint64_t) j * S_v; + const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + const float dj = (v_t[j] - sum) * beta_val; + attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + } + } else { + const float gate = expf(g_t[0]); + uint32_t j = 0; + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work + (uint64_t) (j + 0) * S_v; + float * row1 = s_work + (uint64_t) (j + 1) * S_v; + float * row2 = s_work + (uint64_t) (j + 2) * S_v; + float * row3 = s_work + (uint64_t) (j + 3) * S_v; + gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); + float local_delta_b[4] __attribute__((aligned(128))); + for (uint32_t r = 0; r < 4; ++r) { + local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; + } + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + for (uint32_t r = 0; r < 4; ++r) { + attn_data[j + r] = local_sums[r] * scale; + } + } + for (; j < S_v; ++j) { + float * row = s_work + (uint64_t) j * S_v; + const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + const float dj = (v_t[j] - sum) * beta_val; + attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + } + } + + attn_data += (uint64_t) S_v * H; + } + } +} + +static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_gdn_context * gctx = (struct htp_gdn_context *) data; + struct htp_ops_context * octx = gctx->octx; + + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * g = octx->src[3]; + const struct htp_tensor * beta = octx->src[4]; + const struct htp_tensor * state = octx->src[5]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t S_v = v->ne[0]; + const uint32_t H = v->ne[1]; + const uint32_t n_seqs = v->ne[3]; + + const uint32_t total_rows = H * n_seqs; + if (ith >= total_rows) { + return; + } + + const uint32_t rq3 = n_seqs / q->ne[3]; + const uint32_t rk3 = n_seqs / k->ne[3]; + const float scale = 1.0f / sqrtf((float) S_v); + + float * dst_base = (float *) (uintptr_t) dst->data; + float * state_out_base = dst_base + (uint64_t) S_v * H * n_seqs; + const float * state_in_base = (const float *) (uintptr_t) state->data; + + const bool kda = (g->ne[0] == S_v); + float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_sums[8] __attribute__((aligned(128))); + + dma_queue * dma = octx->ctx->dma[ith]; + + uint8_t * spad = NULL; + if (gctx->use_vtcm) { + spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith; + } + + for (uint32_t ir = ith; ir < total_rows; ir += nth) { + const uint32_t iv1 = ir % H; + const uint32_t iv3 = ir / H; + + const uint32_t iq1 = iv1 % q->ne[1]; + const uint32_t ik1 = iv1 % k->ne[1]; + const uint32_t iq3 = iv3 / rq3; + const uint32_t ik3 = iv3 / rk3; + + float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + float * s_work; + + if (spad) { + dma_queue_push(dma, dma_make_ptr(spad, s_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + dma_queue_pop(dma); + s_work = (float *) spad; + } else { + s_work = s_out; + memcpy(s_work, s_in, gctx->state_bytes); + } + + float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v; + + const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data + + (uint64_t) iq3 * q->nb[3] + (uint64_t) iq1 * q->nb[1]); + const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data + + (uint64_t) ik3 * k->nb[3] + (uint64_t) ik1 * k->nb[1]); + const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data + + (uint64_t) iv3 * v->nb[3] + (uint64_t) iv1 * v->nb[1]); + const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data + + (uint64_t) iv3 * g->nb[3] + (uint64_t) iv1 * g->nb[1]); + const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + + (uint64_t) iv3 * beta->nb[3] + (uint64_t) iv1 * beta->nb[1]); + + memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); + memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); + + if (kda) { + hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); + + uint32_t j = 0; + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work + (uint64_t) (j + 0) * S_v; + float * row1 = s_work + (uint64_t) (j + 1) * S_v; + float * row2 = s_work + (uint64_t) (j + 2) * S_v; + float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row4 = s_work + (uint64_t) (j + 4) * S_v; + float * row5 = s_work + (uint64_t) (j + 5) * S_v; + float * row6 = s_work + (uint64_t) (j + 6) * S_v; + float * row7 = s_work + (uint64_t) (j + 7) * S_v; + gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_gate, local_k, S_v, local_sums); + float local_delta_b[8] __attribute__((aligned(128))); + for (uint32_t r = 0; r < 8; ++r) { + local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; + } + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + for (uint32_t r = 0; r < 8; ++r) { + attn_data[j + r] = local_sums[r] * scale; + } + } + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work + (uint64_t) (j + 0) * S_v; + float * row1 = s_work + (uint64_t) (j + 1) * S_v; + float * row2 = s_work + (uint64_t) (j + 2) * S_v; + float * row3 = s_work + (uint64_t) (j + 3) * S_v; + gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); + float local_delta_b[4] __attribute__((aligned(128))); + for (uint32_t r = 0; r < 4; ++r) { + local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; + } + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + for (uint32_t r = 0; r < 4; ++r) { + attn_data[j + r] = local_sums[r] * scale; + } + } + for (; j < S_v; ++j) { + float * row = s_work + (uint64_t) j * S_v; + const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + const float dj = (v_t[j] - sum) * beta_val; + attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + } + } else { + const float gate = expf(g_t[0]); + uint32_t j = 0; + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work + (uint64_t) (j + 0) * S_v; + float * row1 = s_work + (uint64_t) (j + 1) * S_v; + float * row2 = s_work + (uint64_t) (j + 2) * S_v; + float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row4 = s_work + (uint64_t) (j + 4) * S_v; + float * row5 = s_work + (uint64_t) (j + 5) * S_v; + float * row6 = s_work + (uint64_t) (j + 6) * S_v; + float * row7 = s_work + (uint64_t) (j + 7) * S_v; + gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + gate, local_k, S_v, local_sums); + float local_delta_b[8] __attribute__((aligned(128))); + for (uint32_t r = 0; r < 8; ++r) { + local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; + } + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + for (uint32_t r = 0; r < 8; ++r) { + attn_data[j + r] = local_sums[r] * scale; + } + } + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work + (uint64_t) (j + 0) * S_v; + float * row1 = s_work + (uint64_t) (j + 1) * S_v; + float * row2 = s_work + (uint64_t) (j + 2) * S_v; + float * row3 = s_work + (uint64_t) (j + 3) * S_v; + gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); + float local_delta_b[4] __attribute__((aligned(128))); + for (uint32_t r = 0; r < 4; ++r) { + local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; + } + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + for (uint32_t r = 0; r < 4; ++r) { + attn_data[j + r] = local_sums[r] * scale; + } + } + for (; j < S_v; ++j) { + float * row = s_work + (uint64_t) j * S_v; + const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + const float dj = (v_t[j] - sum) * beta_val; + attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + } + } + + if (spad) { + dma_queue_push(dma, dma_make_ptr(s_out, spad), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + dma_queue_pop(dma); + } + } +} + +int op_gated_delta_net(struct htp_ops_context * octx) { + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * g = octx->src[3]; + const struct htp_tensor * beta = octx->src[4]; + const struct htp_tensor * state = octx->src[5]; + const struct htp_tensor * dst = octx->dst; + + if (!q || !k || !v || !g || !beta || !state || !dst) { + return HTP_STATUS_INVAL_PARAMS; + } + + if (q->type != HTP_TYPE_F32 || k->type != HTP_TYPE_F32 || v->type != HTP_TYPE_F32 || + g->type != HTP_TYPE_F32 || beta->type != HTP_TYPE_F32 || state->type != HTP_TYPE_F32 || + dst->type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t S_v = v->ne[0]; + const uint32_t H = v->ne[1]; + const uint32_t n_tokens = v->ne[2]; + const uint32_t n_seqs = v->ne[3]; + + if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) { + return HTP_STATUS_NO_SUPPORT; + } + if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) { + return HTP_STATUS_NO_SUPPORT; + } + if (q->ne[0] != S_v || k->ne[0] != S_v || q->ne[1] == 0 || k->ne[1] == 0 || + q->ne[2] != n_tokens || k->ne[2] != n_tokens || q->ne[3] == 0 || k->ne[3] == 0 || + (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) { + return HTP_STATUS_NO_SUPPORT; + } + if (state->ne[0] * state->ne[1] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) { + return HTP_STATUS_NO_SUPPORT; + } + if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + struct htp_gdn_context gctx; + gctx.octx = octx; + gctx.rows_per_thread = (H * n_seqs + octx->n_threads - 1) / octx->n_threads; + gctx.state_bytes = (size_t) S_v * S_v * sizeof(float); + + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + + gctx.use_vtcm = false; + gctx.vtcm_state_base = NULL; + gctx.vtcm_state_per_thread = 0; + + if (n_tokens == 1 && octx->ctx->vtcm_base) { + size_t vtcm_total = state_aligned * octx->n_threads; + if (octx->ctx->vtcm_size >= vtcm_total) { + gctx.use_vtcm = true; + gctx.vtcm_state_base = octx->ctx->vtcm_base; + gctx.vtcm_state_per_thread = state_aligned; + } + } + + if (n_tokens == 1) { + worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_tg_thread, &gctx, octx->n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_pp_thread, &gctx, octx->n_threads); + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index e9c563ca887..92f02eac6e3 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -106,5 +106,6 @@ int op_cumsum(struct htp_ops_context * octx); int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); int op_solve_tri(struct htp_ops_context * octx); +int op_gated_delta_net(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index ef96ad38278..6203e3848b9 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -84,6 +84,7 @@ enum htp_op_code { HTP_OP_DIAG, HTP_OP_SOLVE_TRI, HTP_OP_L2_NORM, + HTP_OP_GATED_DELTA_NET, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index e18f1a0e61e..fa1e0698f4a 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -594,6 +594,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_SOLVE_TRI: return op_solve_tri(octx); + case HTP_OP_GATED_DELTA_NET: + return op_gated_delta_net(octx); + case HTP_OP_INVALID: break; From 197c62c10b0fd0452a740704cbba3257526df04c Mon Sep 17 00:00:00 2001 From: AesSedai <7980540+AesSedai@users.noreply.github.com> Date: Fri, 8 May 2026 20:28:29 -0700 Subject: [PATCH 577/831] Add flash attention MMA / Tiles to support MiMo-V2.5 (llama/22812) * mimo-v2.5: add flash attention mma/tiles for for d_kq=192 d_v=128 * mimo-v2.5: follow (256, 256) fattn templates * mimo-v2.5: cleanup comments * mimo-v2.5: further comment cleanup * mimo-v2.5: address PR feedback fix GQA handling check for other dangling 320/576 carveouts and mirror them for 192 Add to backend ops test so new paths are covered --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 9 +++++ ggml/src/ggml-cuda/fattn-tile.cu | 4 ++ ggml/src/ggml-cuda/fattn-tile.cuh | 40 ++++++++++++++++++- ggml/src/ggml-cuda/fattn.cu | 33 +++++++++++++-- ...ttn-mma-f16-instance-ncols1_1-ncols2_16.cu | 1 + ...attn-mma-f16-instance-ncols1_1-ncols2_8.cu | 1 + ...ttn-mma-f16-instance-ncols1_2-ncols2_16.cu | 1 + ...attn-mma-f16-instance-ncols1_2-ncols2_8.cu | 1 + ...ttn-mma-f16-instance-ncols1_4-ncols2_16.cu | 1 + ...attn-mma-f16-instance-ncols1_4-ncols2_8.cu | 1 + ...attn-mma-f16-instance-ncols1_8-ncols2_8.cu | 1 + .../fattn-tile-instance-dkq192-dv128.cu | 5 +++ .../template-instances/generate_cu_files.py | 13 ++++-- 13 files changed, 102 insertions(+), 9 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 3f01e858de7..43e22c5e5ee 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -61,6 +61,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 4, 64, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 4, 32, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 32, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 32, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); @@ -1561,6 +1566,10 @@ static __global__ void flash_attn_ext_f16( NO_DEVICE_CODE; return; } + if (DKQ == 192 && ncols2 != 8 && ncols2 != 16) { + NO_DEVICE_CODE; + return; + } #ifdef VOLTA_MMA_AVAILABLE if (ncols1*ncols2 < 32) { NO_DEVICE_CODE; diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index d60634cc0e9..c8281497d14 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -34,6 +34,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst); } break; + case 192: { + GGML_ASSERT(V->ne[0] == 128); + ggml_cuda_flash_attn_ext_tile_case<192, 128>(ctx, dst); + } break; case 256: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 585f2c22853..7b0a5e5cf49 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -62,6 +62,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64) @@ -124,6 +130,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256) @@ -193,6 +205,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128) @@ -264,6 +282,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 6, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 128, 6, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 5, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256) @@ -1250,7 +1274,20 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm } } - if constexpr (DKQ <= 512 && DKQ != 320) { + if constexpr (DKQ == 192) { + // MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn) + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + if (use_gqa_opt && gqa_ratio % 8 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + GGML_ABORT("flash-attn tile (192/128): expected GQA ratio multiple of 8"); + } + + if constexpr (DKQ <= 512 && DKQ != 320 && DKQ != 192) { if (use_gqa_opt && gqa_ratio % 8 == 0) { launch_fattn_tile_switch_ncols1(ctx, dst); return; @@ -1303,6 +1340,7 @@ extern DECL_FATTN_TILE_CASE( 80, 80); extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); +extern DECL_FATTN_TILE_CASE(192, 128); extern DECL_FATTN_TILE_CASE(256, 256); extern DECL_FATTN_TILE_CASE(320, 256); extern DECL_FATTN_TILE_CASE(512, 512); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 8256591b21d..e045b04f727 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -139,6 +139,22 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(V->ne[0] == 128); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst); break; + case 192: { + // MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn) + GGML_ASSERT(V->ne[0] == 128); + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + if (gqa_ratio % 16 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 16>(ctx, dst); + } else { + GGML_ASSERT(gqa_ratio % 8 == 0); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 8>(ctx, dst); + } + } break; case 256: GGML_ASSERT(V->ne[0] == 256); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); @@ -368,6 +384,14 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } break; + case 192: + if (V->ne[0] != 128 || !gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + if (gqa_ratio % 8 != 0) { + return BEST_FATTN_KERNEL_NONE; + } + break; case 320: if (V->ne[0] != 256 || !gqa_opt_applies) { return BEST_FATTN_KERNEL_NONE; @@ -425,7 +449,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: - const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + // 192 satisfies % 64 == 0 but has no vec instance (DKQ != DV); force it onto the MMA path. + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && Q->ne[0] != 192 && K->ne[1] % FATTN_KQ_STRIDE == 0; // If Turing tensor cores are available, use them: if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { @@ -454,7 +479,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { int gqa_ratio_eff = 1; - const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; + const int ncols2_max = (Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8; while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { gqa_ratio_eff *= 2; } @@ -468,7 +493,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use the WMMA kernel if possible: - if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) { + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 512 && Q->ne[0] != 576) { if (can_use_vector_kernel && Q->ne[1] <= 2) { return BEST_FATTN_KERNEL_VEC; } @@ -501,7 +526,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use MFMA flash attention for CDNA (MI100+): - if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) { + if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) { const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); // MMA vs tile crossover benchmarked on MI300X @ d32768: // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu index fb26abeb0da..b2661b93162 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(192, 128, 1, 16); DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu index 22d383173f3..6ae77bec895 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 1, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu index f011a208cd2..fd41e71b142 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(192, 128, 2, 16); DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu index 84b674cd05a..9f4bef11a44 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 2, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu index f5fd0e2369c..cc41fa52f13 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(192, 128, 4, 16); DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu index 5906398db91..859bea5c525 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 4, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu index 4bc60d62f91..c975ce6b9b7 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -7,5 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 8, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu new file mode 100644 index 00000000000..b571cca0df2 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(192, 128); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 5e9a1cb2eb3..af05a9eff71 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,7 +3,10 @@ from glob import glob import os -HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 320, 512, 576] +HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 192, 256, 320, 512, 576] + +# DKQ -> DV override for asymmetric head dims. +HEAD_SIZES_V_OVERRIDE = {576: 512, 320: 256, 192: 128} TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] @@ -62,7 +65,7 @@ def get_short_name(long_quant_name): os.remove(filename) for head_size_kq in HEAD_SIZES_KQ: - head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) + head_size_v = HEAD_SIZES_V_OVERRIDE.get(head_size_kq, head_size_kq) with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f: f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v)) @@ -85,15 +88,17 @@ def get_short_name(long_quant_name): if head_size_kq == 72: continue # Skip compilation of unused ncols2 values for niche head sizes: + if head_size_kq == 192 and ncols2 not in (8, 16): # MiMo-V2.5 + continue if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4 continue if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4 continue if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash continue - if head_size_kq not in (320, 576) and ncols2 in (16, 32): + if head_size_kq not in (192, 320, 576) and ncols2 in (16, 32): continue - head_size_v = 256 if head_size_kq == 320 else (head_size_kq if head_size_kq != 576 else 512) + head_size_v = HEAD_SIZES_V_OVERRIDE.get(head_size_kq, head_size_kq) f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) for type in TYPES_MMQ: From 63f788320628650ca377d77f95504ca8538f6d46 Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Fri, 8 May 2026 22:42:40 -0700 Subject: [PATCH 578/831] sycl: Battlemage AOT build via spir64_gen + MMQ subgroup annotations (llama/22147) * sycl: Battlemage AOT build via spir64_gen + MMQ subgroup annotations Signed-off-by: Chun Tao * Remove unneeded/unnecessary comments and annotations The MMQ subgroup annotations added are on functions gated behind ggml_sycl_supports_mmq(). Revisit the need for these annotations when that function changes. --------- Signed-off-by: Chun Tao Co-authored-by: Chun Tao Co-authored-by: Todd Malsbary --- ggml/src/ggml-sycl/CMakeLists.txt | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 8e589fa238d..8f44c6ed080 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -135,7 +135,11 @@ endif() if (GGML_SYCL_TARGET STREQUAL "INTEL") add_compile_definitions(GGML_SYCL_WARP_SIZE=16) - target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required) + if (NOT GGML_SYCL_DEVICE_ARCH) + target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required) + else() + message(STATUS "Skipping -ze-intel-greater-than-4GB-buffer-required for spir64_gen AOT") + endif() # Link against Intel oneMKL if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") @@ -160,7 +164,15 @@ if (GGML_SYCL_HOST_MEM_FALLBACK) endif() if (GGML_SYCL_DEVICE_ARCH) - target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) - target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) + message(STATUS "GGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} (AOT via spir64_gen)") + target_compile_options( + ggml-sycl PRIVATE + -fsycl-targets=spir64_gen + "SHELL:-Xsycl-target-backend=spir64_gen \"-device ${GGML_SYCL_DEVICE_ARCH}\"" + ) + target_link_options( + ggml-sycl PRIVATE + -fsycl-targets=spir64_gen + "SHELL:-Xsycl-target-backend=spir64_gen \"-device ${GGML_SYCL_DEVICE_ARCH}\"" + ) endif() - From 3542894544e53a429e6b6f110fbabfb1382d1898 Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Fri, 8 May 2026 22:48:07 -0700 Subject: [PATCH 579/831] sycl: Q5_K reorder MMVQ/dequant + Q8_0 reorder MMVQ path (llama/22152) * sycl: Q5_K reorder MMVQ/dequant + Q8_0 reorder MMVQ path Signed-off-by: Chun Tao * Remove duplicate definitions --------- Signed-off-by: Chun Tao Co-authored-by: Chun Tao Co-authored-by: Todd Malsbary --- ggml/src/ggml-sycl/convert.cpp | 29 ++++++++- ggml/src/ggml-sycl/dequantize.hpp | 57 ++++++++++++++++++ ggml/src/ggml-sycl/ggml-sycl.cpp | 52 ++++++++++++++++ ggml/src/ggml-sycl/mmvq.cpp | 30 +++++++++- ggml/src/ggml-sycl/quants.hpp | 25 ++++++++ ggml/src/ggml-sycl/vecdotq.hpp | 98 +++++++++++++++++++++++-------- 6 files changed, 265 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 67b9c06f3e4..576f19d79ae 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -252,6 +252,23 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k, #endif } +template +static void dequantize_row_q5_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor scale_local_acc(sycl::range<1>(K_SCALE_SIZE), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q5_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb); + }); + }); +} + template static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -643,7 +660,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return dequantize_row_q4_K_sycl; } case GGML_TYPE_Q5_K: - return dequantize_row_q5_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q5_K_sycl_reorder; + } else { + return dequantize_row_q5_K_sycl; + } case GGML_TYPE_Q6_K: if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { return dequantize_row_q6_K_sycl_reorder; @@ -718,7 +739,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return dequantize_row_q4_K_sycl; } case GGML_TYPE_Q5_K: - return dequantize_row_q5_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q5_K_sycl_reorder; + } else { + return dequantize_row_q5_K_sycl; + } case GGML_TYPE_Q6_K: if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { return dequantize_row_q6_K_sycl_reorder; diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 19fa88680d6..2324bfacd22 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -537,6 +537,63 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri #endif } +template +static void dequantize_block_q5_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, + uint8_t * scales_local, const sycl::nd_item<3> & item_ct1, int64_t n_blocks) { + const int64_t ib = item_ct1.get_group(2); + +#if QK_K == 256 + // assume 64 threads + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid / 16; // 0...3 + const int64_t ir = tid % 16; // 0...15 + const int64_t is = 2 * il; + + dst_t * y = yy + ib * QK_K + 64 * il + 2 * ir; + + const uint8_t * base = static_cast(vx); + + // Reordered layout: [qs (QK_K/2 per block)] [qh (QK_K/8 per block)] [scales (K_SCALE_SIZE per block)] [dm (half2 per block)] + const size_t qs_offset = ib * (QK_K / 2); + const size_t qh_offset = n_blocks * (QK_K / 2) + ib * (QK_K / 8); + const size_t scales_offset = n_blocks * (QK_K / 2) + n_blocks * (QK_K / 8) + ib * K_SCALE_SIZE; + const size_t dm_offset = n_blocks * (QK_K / 2) + n_blocks * (QK_K / 8) + n_blocks * K_SCALE_SIZE + ib * sizeof(ggml_half2); + + const uint8_t * qs_ptr = base + qs_offset; + const uint8_t * qh_ptr = base + qh_offset; + const uint8_t * scales_ptr = base + scales_offset; + const ggml_half2 dm_values = *reinterpret_cast(base + dm_offset); + + const float dall = dm_values.x(); + const float dmin = dm_values.y(); + + const uint8_t * ql = qs_ptr + 32 * il + 2 * ir; + const uint8_t * qh = qh_ptr + 2 * ir; + + if (tid < K_SCALE_SIZE) { + scales_local[tid] = scales_ptr[tid]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + uint8_t sc, m; + get_scale_min_k4(is + 0, scales_local, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, scales_local, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + + uint8_t hm = 1 << (2 * il); + y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1; + y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1; + hm <<= 1; + y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; + y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; +#else + GGML_UNUSED(ib); GGML_UNUSED(tid); GGML_UNUSED(yy); GGML_UNUSED(scales_local); GGML_UNUSED(n_blocks); + GGML_ABORT("Q5_K reorder dequantize not supported for QK_K != 256"); +#endif +} + template static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy, const sycl::nd_item<3> &item_ct1) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index c3ac281067a..f86ff3e9466 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3303,6 +3303,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { case GGML_TYPE_Q8_0: return true; case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: return !g_ggml_sycl_prioritize_dmmv; default: @@ -3325,6 +3326,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: return true; default: @@ -3541,6 +3543,54 @@ static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d return true; } +static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q5_K) == 0); + GGML_ASSERT(offset % sizeof(block_q5_K) == 0); + + const int nblocks = size / sizeof(block_q5_K); + + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast(tmp.ptr); + + sycl::event copy_event; + SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); + if (!g_ggml_sycl_use_async_mem_op) { + copy_event.wait(); + } + + auto * qs_ptr = data_device; + auto * qh_ptr = qs_ptr + (QK_K / 2) * nblocks; + auto * scales_ptr = qh_ptr + (QK_K / 8) * nblocks; + auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks); + + auto reorder_event = stream->parallel_for(nblocks, [=](auto i) { + const block_q5_K * x = (const block_q5_K *) tmp_buf; + const int ib = i; + + for (int j = 0; j < QK_K / 2; ++j) { + qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j]; + } + + for (int j = 0; j < QK_K / 8; ++j) { + qh_ptr[ib * (QK_K / 8) + j] = x[ib].qh[j]; + } + + for (int j = 0; j < K_SCALE_SIZE; ++j) { + scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j]; + } + + dm_ptr[ib] = x[ib].dm; + }); + if (!g_ggml_sycl_use_async_mem_op) { + reorder_event.wait_and_throw(); + } + return true; +} + static bool reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q6_K) == 0); GGML_ASSERT(offset % sizeof(block_q6_K) == 0); @@ -3607,6 +3657,8 @@ static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream); case GGML_TYPE_Q4_K: return reorder_qw_q4_k(data_device, size, 0, stream); + case GGML_TYPE_Q5_K: + return reorder_qw_q5_k(data_device, size, 0, stream); case GGML_TYPE_Q6_K: return reorder_qw_q6_k(data_device, size, 0, stream); default: diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 8fa2198f35a..49998f13ba8 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -839,6 +839,26 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, } } +static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder>(vx, vy, dst, ncols, + nrows, nd_item); + }); + }); +} + static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); @@ -1125,6 +1145,7 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n"); reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl\n"); mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } break; @@ -1145,7 +1166,14 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens } break; case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q5_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl\n"); + mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q6_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index 1f5b62740a8..806028ef3a3 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -79,6 +79,31 @@ template <> struct block_q_t { static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; +template <> struct block_q_t { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI5_K; + static constexpr uint32_t qr = QR5_K; + static constexpr uint32_t vdr_mmvq = 2; + }; + + // Reordered layout: [qs (QK_K/2 per block)] [qh (QK_K/8 per block)] [scales] [dm] + static constexpr std::pair get_block_offset(const int block_index, const int n_blocks) { + auto qs_offset = block_index * (QK_K / 2); + auto qh_offset = n_blocks * (QK_K / 2) + block_index * (QK_K / 8); + return { qs_offset, qh_offset }; + } + + static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / QK_K)); + auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 8); + return { total_qs_bytes + block_index * K_SCALE_SIZE, + total_qs_bytes + nblocks * K_SCALE_SIZE + block_index * sizeof(ggml_half2) }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; + template <> struct block_q_t { struct traits { static constexpr uint32_t qk = QK_K; diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 9253168e5ea..d7770047424 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -357,38 +357,31 @@ template <> struct reorder_vec_dot_q_sycl { using q8_0_block = ggml_sycl_reordered::block_q_t; using q8_0_traits = typename q8_0_block::traits; - __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const float & d8_0, const sycl::half2 & ds8) { - int sumi = 0; - -#pragma unroll - for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) { - // Q8_0 values are signed int8, no nibble extraction needed - // Direct dp4a: each int packs 4 int8 values - sumi = dpct::dp4a(v[i], u[i], sumi); - } - - const sycl::float2 ds8f = ds8.convert(); - - // Q8_0 has no bias term (values are signed), so just scale - return d8_0 * sumi * ds8f.x(); - } - __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, const std::pair d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds, const int & iqs) { - const int8_t * bq8_0 = static_cast(vbq) + ibx_offset.first; - const ggml_half d = *(reinterpret_cast(static_cast(vbq) + d_offset.first)); - int v[q8_0_traits::vdr_mmvq]; - int u[q8_0_traits::vdr_mmvq]; + const uint8_t * base = static_cast(vbq); + const int8_t * qs = reinterpret_cast(base + ibx_offset.first); + const ggml_half d = *reinterpret_cast(base + d_offset.first); + + int v[q8_0_traits::vdr_mmvq]; + int u[q8_0_traits::vdr_mmvq]; #pragma unroll for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) { - v[i] = get_int_from_int8(bq8_0, iqs + i); + v[i] = get_int_from_int8(qs, iqs + i); u[i] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i); } - return vec_dot_q8_0_q8_1_impl(v, u, d, *q8_1_ds); - }; + int sumi = 0; +#pragma unroll + for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) { + sumi = dpct::dp4a(v[i], u[i], sumi); + } + + const sycl::half2 ds_values = *q8_1_ds; + return static_cast(d) * static_cast(ds_values[0]) * sumi; + } }; static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales, @@ -481,6 +474,65 @@ template <> struct reorder_vec_dot_q_sycl { } }; +template <> struct reorder_vec_dot_q_sycl { + static constexpr ggml_type gtype = GGML_TYPE_Q5_K; + + using q5_k_block = ggml_sycl_reordered::block_q_t; + using q5_k_traits = typename q5_k_block::traits; + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, + const std::pair d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const uint8_t * base = static_cast(vbq); + const uint8_t * qs = base + ibx_offset.first; // low 4 bits + const uint8_t * qh_base = base + ibx_offset.second; // high bit + const uint8_t * scs = base + d_offset.first; + const ggml_half2 * dms = reinterpret_cast(base + d_offset.second); + + const int bq8_offset = QR5_K * ((iqs / 2) / (QI8_1 / 2)); + const int * ql_ptr = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + const int * qh_ptr = (const int *) (qh_base + 4 * ((iqs / 2) % 4)); + const uint16_t * scales = (const uint16_t *) scs; + + int vl[2]; + int vh[2]; + int u[2 * QR5_K]; + float d8[QR5_K]; + + vl[0] = ql_ptr[0]; + vl[1] = ql_ptr[4]; + + vh[0] = qh_ptr[0] >> bq8_offset; + vh[1] = qh_ptr[4] >> bq8_offset; + + uint16_t aux[2]; + const int j = (QR5_K * ((iqs / 2) / (QI8_1 / 2))) / 2; + if (j < 2) { + aux[0] = scales[j + 0] & 0x3f3f; + aux[1] = scales[j + 2] & 0x3f3f; + } else { + aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2); + aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2); + } + + const uint8_t * sc = (const uint8_t *) aux; + const uint8_t * m = sc + 2; + + for (int i = 0; i < QR5_K; ++i) { + const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1; + sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i); + + d8[i] = ds_values[0]; + + const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4); + u[2 * i + 0] = q8[0]; + u[2 * i + 1] = q8[4]; + } + + return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, *dms, d8); + } +}; + template <> struct reorder_vec_dot_q_sycl { static constexpr ggml_type gtype = GGML_TYPE_Q6_K; From 25f543175d0652204eacd643864a09a8b5fd39fe Mon Sep 17 00:00:00 2001 From: Devedse <2350015+devedse@users.noreply.github.com> Date: Sat, 9 May 2026 07:50:24 +0200 Subject: [PATCH 580/831] Add BF16 support to GET_ROWS operation (llama/21391) Add GGML_TYPE_BF16 to the SYCL backend's GET_ROWS operation, both in supports_op and in the kernel dispatch. This fixes a performance regression where models using BF16 embedding tensors (e.g., Gemma4's per_layer_token_embd.weight) fall back to CPU for the GET_ROWS op, causing a full GPU-to-CPU tensor transfer every token. The fix reuses the existing get_rows_sycl_float template with sycl::ext::oneapi::bfloat16, matching the pattern already used for sycl::half (F16) and float (F32). --- ggml/src/ggml-sycl/getrows.cpp | 4 ++++ ggml/src/ggml-sycl/ggml-sycl.cpp | 1 + 2 files changed, 5 insertions(+) diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp index 03f8dd90748..ca457454775 100644 --- a/ggml/src/ggml-sycl/getrows.cpp +++ b/ggml/src/ggml-sycl/getrows.cpp @@ -183,6 +183,10 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::half *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_BF16: + get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::ext::oneapi::bfloat16 *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_F32: get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index f86ff3e9466..b6e705cdf3a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4974,6 +4974,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g { switch (op->src[0]->type) { case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_F32: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: From 8c7efe885cb38b9617ca952ebf7bc19d79bd6ffa Mon Sep 17 00:00:00 2001 From: Alexey Kopytko Date: Sat, 9 May 2026 15:30:39 +0900 Subject: [PATCH 581/831] SYCL: reduce allocation overhead during flash attention (llama/22732) * SYCL: reduce allocation overhead during flash attention * tidy up whitespace * add a note about the flag * move ggml_sycl_fattn_* into fattn-buffers.hpp * refactor implementation into fattn-buffers.cpp * move new_fattn_kv_buffers back into ggml-sycl.cpp --- ggml/src/ggml-sycl/common.hpp | 16 +++++++ ggml/src/ggml-sycl/fattn-buffers.cpp | 56 +++++++++++++++++++++++++ ggml/src/ggml-sycl/fattn-buffers.hpp | 63 ++++++++++++++++++++++++++++ ggml/src/ggml-sycl/fattn-common.hpp | 6 ++- ggml/src/ggml-sycl/ggml-sycl.cpp | 41 ++++++++++++++++++ 5 files changed, 180 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-sycl/fattn-buffers.cpp create mode 100644 ggml/src/ggml-sycl/fattn-buffers.hpp diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 5abf2290651..eec36e8db9a 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -25,6 +25,7 @@ #include "presets.hpp" #include "type.hpp" #include "sycl_hw.hpp" +#include "fattn-buffers.hpp" namespace syclexp = sycl::ext::oneapi::experimental; @@ -404,12 +405,16 @@ struct ggml_backend_sycl_context { std::unique_ptr pools[GGML_SYCL_MAX_DEVICES]; std::unordered_map>> scratchpad_map; + std::unique_ptr fattn_bufs[GGML_SYCL_MAX_DEVICES]; + std::unique_ptr host_pools[GGML_SYCL_MAX_DEVICES]; static std::unique_ptr new_pool_for_device(queue_ptr qptr, int device); static std::unique_ptr new_pool_for_host(queue_ptr qptr, int device); + static std::unique_ptr new_fattn_kv_buffers(queue_ptr qptr, int device); + ggml_sycl_pool & pool(int device) { if (pools[device] == nullptr) { pools[device] = new_pool_for_device(stream(device,0), device); @@ -421,6 +426,17 @@ struct ggml_backend_sycl_context { return pool(device); } + ggml_sycl_fattn_kv_buffers & fattn_buffers(int device) { + if (fattn_bufs[device] == nullptr) { + fattn_bufs[device] = new_fattn_kv_buffers(stream(device, 0), device); + } + return *fattn_bufs[device]; + } + + ggml_sycl_fattn_kv_buffers & fattn_buffers() { + return fattn_buffers(device); + } + #ifdef GGML_SYCL_GRAPH std::unique_ptr> exec_graph = nullptr; #endif diff --git a/ggml/src/ggml-sycl/fattn-buffers.cpp b/ggml/src/ggml-sycl/fattn-buffers.cpp new file mode 100644 index 00000000000..46cf6d551f1 --- /dev/null +++ b/ggml/src/ggml-sycl/fattn-buffers.cpp @@ -0,0 +1,56 @@ +// +// MIT license +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "common.hpp" + +sycl::half * ggml_sycl_fattn_kv_buffers::kv_buffer::ensure_half(size_t n_elems) { + const size_t need_bytes = n_elems * sizeof(sycl::half); + + if (capacity >= need_bytes) { + return ptr; + } + + if (ptr) { + SYCL_CHECK(CHECK_TRY_ERROR(qptr->wait())); + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); + ptr = nullptr; + capacity = 0; + } + + size_t cap = 0; + while (cap < need_bytes) { + cap += CHUNK_SIZE; + } + + void * dev_ptr; + SYCL_CHECK( + CHECK_TRY_ERROR(dev_ptr = sycl::malloc_device( + cap, *qptr))); + + if (!dev_ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device\n", __func__, cap); + GGML_ABORT("fattn buffer alloc failed"); + } + + ptr = static_cast(dev_ptr); + capacity = cap; + return ptr; +} + +ggml_sycl_fattn_kv_buffers::kv_buffer::~kv_buffer() { +#ifdef DEBUG_SYCL_POOL + GGML_LOG_INFO("ggml_sycl_fattn_kv_buffer[%d]: %.2f MiB\n", device, capacity / 1024.0 / 1024.0); +#endif + if (ptr) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); + } +} diff --git a/ggml/src/ggml-sycl/fattn-buffers.hpp b/ggml/src/ggml-sycl/fattn-buffers.hpp new file mode 100644 index 00000000000..c00461de620 --- /dev/null +++ b/ggml/src/ggml-sycl/fattn-buffers.hpp @@ -0,0 +1,63 @@ +// +// MIT license +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_FATTN_BUFFERS_HPP +#define GGML_SYCL_FATTN_BUFFERS_HPP + +#include + +typedef sycl::queue *queue_ptr; + +struct ggml_sycl_fattn_kv_buffers { + // buffers grow in chunks of this size + static constexpr size_t CHUNK_SIZE = 16ull << 20; // 16 MiB + + struct kv_buffer { + kv_buffer(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {} + ~kv_buffer(); + + kv_buffer(const kv_buffer &) = delete; + kv_buffer & operator=(const kv_buffer &) = delete; + + sycl::half * ensure_half(size_t n_elems); + + private: + sycl::half * ptr = nullptr; + size_t capacity = 0; + queue_ptr qptr = nullptr; + [[maybe_unused]] int device = 0; + }; + + kv_buffer K; + kv_buffer V; + + ggml_sycl_fattn_kv_buffers(queue_ptr qptr, int device) : K(qptr, device), V(qptr, device) {} + + ggml_sycl_fattn_kv_buffers(const ggml_sycl_fattn_kv_buffers &) = delete; + ggml_sycl_fattn_kv_buffers & operator=(const ggml_sycl_fattn_kv_buffers &) = delete; +}; + +/** + * Imitates `ggml_sycl_pool_alloc` to keep the code calling alloc unchanged. + */ +struct ggml_sycl_fattn_alloc { + ggml_sycl_fattn_kv_buffers::kv_buffer & buf; + sycl::half * ptr = nullptr; + + explicit ggml_sycl_fattn_alloc(ggml_sycl_fattn_kv_buffers::kv_buffer & buf_) : buf(buf_) {} + + sycl::half * alloc(size_t n_elems) { + ptr = buf.ensure_half(n_elems); + return ptr; + } +}; +#endif diff --git a/ggml/src/ggml-sycl/fattn-common.hpp b/ggml/src/ggml-sycl/fattn-common.hpp index ed00d03c3b6..03f0c2623c8 100644 --- a/ggml/src/ggml-sycl/fattn-common.hpp +++ b/ggml/src/ggml-sycl/fattn-common.hpp @@ -5,6 +5,7 @@ #include "common.hpp" #include "convert.hpp" #include "vecdotq.hpp" +#include "fattn-buffers.hpp" #include "ggml.h" @@ -918,12 +919,13 @@ void launch_fattn( GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); ggml_sycl_pool & pool = ctx.pool(); + ggml_sycl_fattn_kv_buffers & fbuf = ctx.fattn_buffers(); dpct::queue_ptr main_stream = ctx.stream(); const int id = ggml_sycl_get_device(); const int nsm = ggml_sycl_info().devices[id].nsm; - ggml_sycl_pool_alloc K_f16(pool); - ggml_sycl_pool_alloc V_f16(pool); + ggml_sycl_fattn_alloc K_f16(fbuf.K); + ggml_sycl_fattn_alloc V_f16(fbuf.V); ggml_sycl_pool_alloc KV_max(pool); ggml_sycl_pool_alloc dst_tmp(pool); ggml_sycl_pool_alloc dst_tmp_meta(pool); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index b6e705cdf3a..e7768b8bf61 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -1286,6 +1286,23 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : device(device_), qptr(qptr_) {} ~ggml_sycl_pool_leg() { +#ifdef DEBUG_SYCL_POOL + int n_cached = 0; + size_t bytes_cached = 0; + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + if (buffer_pool[i].ptr != nullptr) { + ++n_cached; + bytes_cached += buffer_pool[i].size; + } + } + GGML_LOG_INFO("%s: %d buffers, cached = %.2f MiB\n", __func__, + n_cached, bytes_cached / 1024.0 / 1024.0); + const auto slots = format_slots_in_alloc_order(); + if (!slots.empty()) { + GGML_LOG_INFO("%s: slots MiB: %s\n", __func__, slots.c_str()); + } +#endif + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { ggml_sycl_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { @@ -1296,6 +1313,26 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { GGML_ASSERT(pool_size == 0); } +#ifdef DEBUG_SYCL_POOL + std::string format_slots_in_alloc_order() const { + std::string line; + char buf[32]; + bool first = true; + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + if (buffer_pool[i].ptr == nullptr) { + continue; + } + if (!first) { + line += '/'; + } + first = false; + snprintf(buf, sizeof(buf), "%.2f", buffer_pool[i].size / 1024.0 / 1024.0); + line += buf; + } + return line; + } +#endif + void * alloc(size_t size, size_t * actual_size) override { #ifdef DEBUG_sycl_MALLOC int nnz = 0; @@ -1459,6 +1496,10 @@ std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(q return std::unique_ptr(new ggml_sycl_pool_leg(qptr, device)); } +std::unique_ptr ggml_backend_sycl_context::new_fattn_kv_buffers(queue_ptr qptr, int device) { + return std::unique_ptr(new ggml_sycl_fattn_kv_buffers(qptr, device)); +} + // TBD pool with virtual memory management // struct ggml_sycl_pool_vmm : public ggml_sycl_pool From 7072bdab9233401c45d27035a4cdbebd1d5bef49 Mon Sep 17 00:00:00 2001 From: scutler-nv Date: Sun, 10 May 2026 02:05:22 -0700 Subject: [PATCH 582/831] internal AllReduce kernel for CUDA provider (llama/22299) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-cuda: add internal AllReduce provider for tensor parallelism Introduces a NCCL-free AllReduce implementation for LLAMA_SPLIT_MODE_TENSOR using a single-phase CUDA kernel that pipelines D2H copy, cross-GPU handshake via pinned-memory volatile flags, and the reduction in one kernel launch per GPU. New files: - ggml/src/ggml-cuda/comm.cuh — ggml_cuda_allreduce_provider enum - ggml/src/ggml-cuda/allreduce.cuh — pipeline API declarations - ggml/src/ggml-cuda/allreduce.cu — kernel + pipeline init/dispatch ggml-cuda.cu changes: - ggml_backend_cuda_comm_context gains ar_pipeline field - Provider selection via GGML_CUDA_ALLREDUCE env var ("nccl" / "internal") - INTERNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * llama-bench: add --allreduce flag to select AllReduce provider Adds --allreduce to llama-bench (and via the shared field pattern, consistent with other multi-value flags). Useful for isolating hangs or regressions in tensor-parallel mode: pass --allreduce nccl to force NCCL and bypass the internal provider. Also fixes ggml_cuda_select_allreduce_provider() to treat an empty GGML_CUDA_ALLREDUCE env var the same as unset (avoids spurious warning when llama-bench sets it to "" for the "auto" case). Co-Authored-By: Claude Sonnet 4.6 xt gains ar_pipeline field - Provider selection via GGML_CUDA_ALLREDUCE env var ("nccl" / "internal") - INTERNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * llama-bench: rename --allreduce to --reduction-provider / -rp Co-Authored-By: Claude Sonnet 4.6 via the shared field pattern, consistent with other multi-value flags). Useful for isolating hangs or regressions in tensor-parallel mode: pass --allreduce nccl to force NCCL and bypass the internal provider. Also fixes ggml_cuda_select_allreduce_provider() to treat an empty GGML_CUDA_ALLREDUCE env var the same as unset (avoids spurious warning when llama-bench sets it to "" for the "auto" case). Co-Authored-By: Claude Sonnet 4.6 xt gains ar_pipeline field - Provider selection via GGML_CUDA_ALLREDUCE env var ("nccl" / "internal") - INTERNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * llama-bench: pass WARN/ERROR log messages through in non-verbose mode The null log callback was silently dropping all messages. WARN and ERROR should always be visible since they indicate legitimate issues (e.g. a requested reduction provider not being available). Co-Authored-By: Claude Sonnet 4.6 vider. Also fixes ggml_cuda_select_allreduce_provider() to treat an empty GGML_CUDA_ALLREDUCE env var the same as unset (avoids spurious warning when llama-bench sets it to "" for the "auto" case). Co-Authored-By: Claude Sonnet 4.6 xt gains ar_pipeline field - Provider selection via GGML_CUDA_ALLREDUCE env var ("nccl" / "internal") - INTERNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * cmake: improve NCCL detection for source-tree builds, add static/dynamic switch FindNCCL.cmake now searches the cmake source-build layout used by the Windows NCCL port (cmake/lib/Release for static, cmake/src/Release for dynamic import lib) and also checks src/include for the generated nccl.h header. New option GGML_CUDA_NCCL_STATIC (default OFF) selects static vs dynamic linking and controls which paths and library names are searched. Co-Authored-By: Claude Sonnet 4.6 for the "auto" case). Co-Authored-By: Claude Sonnet 4.6 xt gains ar_pipeline field - Provider selection via GGML_CUDA_ALLREDUCE env var ("nccl" / "internal") - INTERNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * ggml-cuda: add AllReduce hang watchdog (GGML_CUDA_AR_WATCHDOG) When compiled with -DGGML_CUDA_AR_WATCHDOG=ON, uses a debug kernel variant that writes per-GPU spin diagnostics to pinned host memory. A host-side blocking poll (cudaEventQuery + volatile reads) detects hangs and logs WARN with the last observed arrival counters and spin counts, controlled by GGML_CUDA_AR_WATCHDOG (ms timeout) and GGML_CUDA_AR_MAX_SPIN (kernel bailout) env vars at runtime. Zero overhead on the production path — all debug code is behind #ifdef. Co-Authored-By: Claude Sonnet 4.6 ar_pipeline field - Provider selection via GGML_CUDA_ALLREDUCE env var ("nccl" / "internal") - INTERNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * ggml-cuda: fix intermittent AllReduce hang on Blackwell PCIe Add __threadfence_system() before the arrival signal write in signal_set to ensure D2H data is globally visible before the peer observes the arrival flag. Without this fence, the peer could enter Phase 3 host reads before the data had fully landed, causing an intermittent deadlock on RTX 5090 (Blackwell, PCIe-only). Also redesign the watchdog from a blocking dispatch-thread poll to a non-blocking background thread, eliminating the ~20ms per-slot latency the old design added. Verified: 30/30 soak test runs clean at ~50 t/s (previously ~1-in-15 hang rate). Co-Authored-By: Claude Sonnet 4.6 - INTERNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * ggml-cuda: fix watchdog shutdown ordering and pipeline_free drain - Stop watchdog thread BEFORE destroying GPU resources (events, streams) to prevent polling destroyed handles → spurious "busy" readings - Add cudaStreamSynchronize in pipeline_free to drain in-flight kernels before freeing pinned host buffers they may still be reading - Sleep-first watchdog polling: no +0ms noise, only logs when a kernel is genuinely stuck past the poll interval - Check wdog_stop in both outer and inner loops so join() returns promptly instead of draining the entire queue - Add Phase 3 breadcrumbs to debug[3] for hang localization Co-Authored-By: Claude Sonnet 4.6 RNAL provider initialises the pipeline at comm_init time - Dispatch routes to ggml_cuda_ar_allreduce(); falls back to meta-backend CPU reduce for unsupported sizes or GPU counts (> 2) Current scope: 2 GPUs, FP32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * ggml-cuda: replace event-based watchdog with per-GPU ring buffer Completely rework the GGML_CUDA_AR_WATCHDOG system: - Replace the shared debug_buf + event-polling + queue design with per-GPU ring buffers in pinned host memory - Kernel writes a debug record only on spin-limit bailout: claims a ring slot via atomicAdd (single-GPU host atomics work on RTX 5090), writes fields, fences, sets completion flag, then all threads exit - Watchdog thread simply polls ring head counters every 1ms and prints any new complete records — no CUDA event queries, no mutex, no queue - Zero overhead on the dispatch path (no queue posting, no memset) - Watchdog shutdown returns within ~1ms (atomic bool, no drain) - On bailout the kernel skips Phase 3 entirely and exits cleanly Verified: 20/20 prefill soak test clean at ~1112 t/s, no hangs. Co-Authored-By: Claude Sonnet 4.6 P32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * fix: normalize line endings to LF (undo Windows CRLF conversion) Five files were inadvertently converted to CRLF by the Windows development environment, causing every line to show as changed in diffs against master. Co-Authored-By: Claude Sonnet 4.6 imit bailout: claims a ring slot via atomicAdd (single-GPU host atomics work on RTX 5090), writes fields, fences, sets completion flag, then all threads exit - Watchdog thread simply polls ring head counters every 1ms and prints any new complete records — no CUDA event queries, no mutex, no queue - Zero overhead on the dispatch path (no queue posting, no memset) - Watchdog shutdown returns within ~1ms (atomic bool, no drain) - On bailout the kernel skips Phase 3 entirely and exits cleanly Verified: 20/20 prefill soak test clean at ~1112 t/s, no hangs. Co-Authored-By: Claude Sonnet 4.6 P32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * .gitattributes: force LF line endings to prevent Windows CRLF conversion Co-Authored-By: Claude Sonnet 4.6 elopment environment, causing every line to show as changed in diffs against master. Co-Authored-By: Claude Sonnet 4.6 imit bailout: claims a ring slot via atomicAdd (single-GPU host atomics work on RTX 5090), writes fields, fences, sets completion flag, then all threads exit - Watchdog thread simply polls ring head counters every 1ms and prints any new complete records — no CUDA event queries, no mutex, no queue - Zero overhead on the dispatch path (no queue posting, no memset) - Watchdog shutdown returns within ~1ms (atomic bool, no drain) - On bailout the kernel skips Phase 3 entirely and exits cleanly Verified: 20/20 prefill soak test clean at ~1112 t/s, no hangs. Co-Authored-By: Claude Sonnet 4.6 P32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * ggml-cuda: move GGML_CUDA_AR_WATCHDOG from CMake option to local define The watchdog is development-only; a global CMake option is overkill. Move the toggle to a #define at the top of allreduce.cu (set to 0 by default) and remove the option from ggml/CMakeLists.txt and the CUDA CMakeLists.txt add_compile_definitions block. Co-Authored-By: Claude Sonnet 4.6 fences, sets completion flag, then all threads exit - Watchdog thread simply polls ring head counters every 1ms and prints any new complete records — no CUDA event queries, no mutex, no queue - Zero overhead on the dispatch path (no queue posting, no memset) - Watchdog shutdown returns within ~1ms (atomic bool, no drain) - On bailout the kernel skips Phase 3 entirely and exits cleanly Verified: 20/20 prefill soak test clean at ~1112 t/s, no hangs. Co-Authored-By: Claude Sonnet 4.6 P32, tensors <= 256 KB. Notes in NOTES-allreduce.md. Co-Authored-By: Claude Sonnet 4.6 * unify kernel debug paths * use __threadfence_system explicitly (not in ggml_cuda_ar_signal_set) * preferentially use internal reduction for <=2 GPUs * templatize the main kernel to support fp16/bf16 * restore llama-bench.cpp changes * revert CMakeLists changes * remove notes from repo * remove dead warmup code * fix comments * improve reduction provider fallback code * add messages for allreduce fallback * rework reduction provider init to not call ncclCommInitAll if using the internal provider * fix case where a given tensor has not been computed * add chunked mode to the kernel for unlimited vector size * rework a few checks/fallbacks * various small cleanups * allow disabling CUDA reductions completely (falling back to the non-CUDA butterfly mode) * simplify reduction provider selection * minor simplifications * more cleanups/fixes * prototype alternate path for large reductions * chunked version of large reduction path * use bf16 for large reductions * experimental reduction using cudaMemcpyPeerAsync (slightly slower) * revert experimental change * add combined conversion/reduction kernel * add bf16 wire format for single kernel mode * experimental on-stream small reduction kernel * double buffer arrival slots, use token (incrementing) method * double buffer host_buf for small reductions * put in waits for use of host_mem in large reduction case (prevents stomping on in-use memory * remove watchdog code * various cleanups / dead code removal * fix fp16 mode * fix some comments/logging statements * use increasing token scheme for arrival signals * add top-level comment to allreduce.cu * improve top-level comment in allreduce.cu * fix comments in ggml_cuda_ar_kernel * improve event handling for hostmem buffer usage tracking * change ev_pool to fixed 2D array * add chunked memcpy fallback for extra-large reductions (>32 MB) * change thresholds for copy-engine path and bf16 demotion * multi-block kernel test * more fine-tuning for chukn-size, etc. * various fixes for PR review * more PR fixes * fix semantics of all host mappings * require ampere+ * small cleanups * properly use host pointer for src/dst in cudaMemcpy calls * allreduce: lazy-init the internal pipeline on first use A config that lives entirely on NCCL never needs the chunked-kernel pipeline (host_buf, host_large, dev_tmp, streams, events, arrival ring). Defer pipeline creation to the first try_allreduce_internal call using the same std::call_once pattern as ensure_nccl, so those resources stay unallocated when only NCCL is in use. Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: assert n_backends == 2 instead of soft-fallback ar_pipeline_init already requires n_devices == 2 and bails before any AR can get here, so by the time we reach try_allreduce_internal we know we have exactly two backends. Replace the runtime-debug-log fallback with a hard assert. Co-Authored-By: Claude Opus 4.7 (1M context) NCCL is in use. Co-Authored-By: Claude Opus 4.7 (1M context) * rework reduction provider selection. internal/nccl is OS dependent; most fallbacks are removed * remove unneeded Turing arch check (llama.cpp doesn't even compile pre-Turing anyway) * allreduce: ASCII-only comments and ggml_cuda_cast for value conversions Replace non-ASCII characters in comments (em dashes, right arrows) with ASCII equivalents (--, ->) so the source stays in the ggml/upstream norm. In the kernel-side code, replace static_cast/static_cast with ggml_cuda_cast<...> so the BF16 conversions go through the fast __float2bfloat16 / __bfloat162float intrinsics from convert.cuh. Pure pointer and integer casts stay as static_cast. Also drops two stray garbage tokens that snuck in from earlier merges (a duplicated 'return ok; }' tail in allreduce.cu and a leftover '_reg)' fragment in ggml-cuda.cu). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: use ggml_cuda_memcpy_1 for the chunked-kernel vector copies The chunked kernel's two 16-byte register<->host transfers (Phase 1 store and Phase 3 load) used reinterpret_cast on both sides. Replace with ggml_cuda_memcpy_1, which is the canonical helper for this pattern and emits the same int4 LD/ST under the hood. Conformance passes; 5x reruns of 70b internal pp512 show 1832-1836 t/s, matching the prior matrix value of 1831 t/s -- no perf change as expected. Co-Authored-By: Claude Opus 4.7 (1M context) ok; }' tail in allreduce.cu and a leftover '_reg)' fragment in ggml-cuda.cu). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: assert cuda_ctx->device matches the pipeline's device Both ggml_cuda_ar_pipeline and ggml_backend_cuda_context carry the device they were created for; if they ever disagree, every cuda call that follows runs on the wrong device. Add GGML_ASSERT at each cuda_ctx retrieval site in the AR path so the misuse fails fast rather than silently corrupting. Also: rename __nv_bfloat16 -> nv_bfloat16 (typedef alias) for consistency with the rest of the file, and tighten one cudaGetLastError check to fire only after the to_bf16 call that can actually fail. Co-Authored-By: Claude Opus 4.7 (1M context) gml-cuda.cu). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: expand one-liner for loops to braced bodies Code-style preference -- match the rest of the file by writing every for loop with the body on its own braced line. Three sites in the copy-engine typed dispatch. Co-Authored-By: Claude Opus 4.7 (1M context) in the AR path so the misuse fails fast rather than silently corrupting. Also: rename __nv_bfloat16 -> nv_bfloat16 (typedef alias) for consistency with the rest of the file, and tighten one cudaGetLastError check to fire only after the to_bf16 call that can actually fail. Co-Authored-By: Claude Opus 4.7 (1M context) gml-cuda.cu). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: rename template parameters Tdst/Twire/Tsrc -> T_dst/T_wire/T_src Code-style preference per PR review -- T_dst/T_wire/T_src is more consistent with surrounding code. Whole-word rename across all 58 sites in allreduce.cu (kernel definitions, internal uses, and comment text). Realigned the parameter columns in three function signatures whose T_src/T_dst lines shifted by 1 char relative to their non-templated neighbors. Co-Authored-By: Claude Opus 4.7 (1M context) to fire only after the to_bf16 call that can actually fail. Co-Authored-By: Claude Opus 4.7 (1M context) gml-cuda.cu). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: drop hyphen in 'chunked-kernel' across comments Per PR review feedback -- 'chunked kernel' (no hyphen) reads more naturally in running prose, especially for ESL readers. Pure comment-only change; all 10 occurrences in allreduce.cu updated. Co-Authored-By: Claude Opus 4.7 (1M context) three function signatures whose T_src/T_dst lines shifted by 1 char relative to their non-templated neighbors. Co-Authored-By: Claude Opus 4.7 (1M context) to fire only after the to_bf16 call that can actually fail. Co-Authored-By: Claude Opus 4.7 (1M context) gml-cuda.cu). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: use ggml_cuda_get_max_cpy_bytes() instead of hardcoded 16 The chunked kernel hardcoded a 16-byte vector unit; replace with the ggml_cuda_get_max_cpy_bytes() helper that fattn-common.cuh uses for the same purpose, so ELEMS_PER_VEC self-adjusts to the arch's widest single-instruction copy. Perf-neutral on supported targets (Volta+ returns 16). Co-Authored-By: Claude Opus 4.7 (1M context) hbors. Co-Authored-By: Claude Opus 4.7 (1M context) to fire only after the to_bf16 call that can actually fail. Co-Authored-By: Claude Opus 4.7 (1M context) gml-cuda.cu). Co-Authored-By: Claude Opus 4.7 (1M context) * ggml-cuda: PR review fixes -- annotate #endif, fix stale comment, assert nbytes alignment Three separate but minor changes from PR #22299 review feedback: 1. Annotate the five GGML_USE_NCCL #endif lines with the matching condition so the pairing is visible without scrolling back. 2. The comment block on ggml_backend_cuda_comm_context claimed NCCL is lazy-initialised; that was true at one point but the dispatch refactor (727b141c0) made both NCCL and the internal pipeline eager. Rewrite the comment to match current behaviour. 3. Assert in ggml_backend_cuda_comm_allreduce_internal that the tensor's byte size is a 16-byte multiple. The chunked-kernel issues full-width vector loads/stores, so this is a precondition; tensor-parallel splits of hidden-dim-multiples satisfy it trivially, but a hard assert turns any caller-side bug into a clear failure rather than UB. Co-Authored-By: Claude Opus 4.7 (1M context) device's new AR records its ev.ker -- otherwise the second device's wait sees the first device's just-recorded event (the in-flight new AR) and creates a circular dependency with the in-kernel peer signal. Two-pass dispatch (all waits, then all launches) avoids this. Bump POOL_SIZE 2 -> 8 (small memory cost, more breathing room for the GPU's view of the event chain) and add a runtime env override for the hybrid kernel chunk size (GGML_CUDA_AR_HYBRID_CHUNK_BYTES) for tuning. One-shot stderr diagnostic at first AR prints the chosen path + sizing. Result on 2x RTX 5090 Linux, 70b ub_sweep: ub=64 (1 MB AR): 913 -> 1036 t/s (+13.5% vs old, +1.8% vs NCCL) ub=128 (2 MB AR): 1056 -> 1181 (+11.9%, +3.7% vs NCCL) ub=256 (4 MB AR): 1212 -> 1424 (+17.5%, +3.5% vs NCCL) Internal now beats NCCL at every size (+1.8% to +15.6%), recovering all ground in the 1-4 MB regime that was previously a 10-12% loss. Co-Authored-By: Claude Opus 4.7 (1M context) * simplify the init logic * address some other PR requests * ggml-cuda: stub internal AllReduce on HIP/MUSA, drop pre-Ampere mention, gate NCCL fallback warning on !HIP The internal AllReduce relies on cudaHostAllocPortable/Mapped, cudaHostGetDevicePointer, and __nanosleep -- none of which the HIP or MUSA shims expose -- so wrap the implementation in !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) and provide nullptr/no-op/false stubs in the #else branch. The dispatcher already treats a null pipeline as init failure and silently falls back to the meta backend's generic AllReduce, so HIP/MUSA builds compile clean and behave correctly without further call-site changes. PR review follow-ups: - drop "or pre-Ampere?" from the internal-init failure warning -- the kernel doesn't require Ampere or newer. - guard the "NCCL not compiled in" fallback warning behind !defined(GGML_USE_HIP); the suggestion to install NCCL only makes sense on NVIDIA builds. Co-Authored-By: Claude Opus 4.7 (1M context) hind, now +6-8% ahead at ub=1024-4096. Perplexity (32 chunks) matches NCCL bit-for-bit (3.4044 vs 3.4043). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: guard __nanosleep on Volta+ and reject pre-Volta devices at init __nanosleep is the only Volta-specific intrinsic in the kernel; wrap it in #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA / NO_DEVICE_CODE so the file still compiles cleanly when targeting older arches (the dispatcher's init check below ensures the kernel is never actually launched on pre-Volta). Add a per-device compute-capability check in pipeline_init that returns nullptr if any device is below sm70. The dispatcher already treats nullptr as init failure and silently falls back to the meta backend's generic AllReduce. Co-Authored-By: Claude Opus 4.7 (1M context) rom the internal-init failure warning -- the kernel doesn't require Ampere or newer. - guard the "NCCL not compiled in" fallback warning behind !defined(GGML_USE_HIP); the suggestion to install NCCL only makes sense on NVIDIA builds. Co-Authored-By: Claude Opus 4.7 (1M context) hind, now +6-8% ahead at ub=1024-4096. Perplexity (32 chunks) matches NCCL bit-for-bit (3.4044 vs 3.4043). Co-Authored-By: Claude Opus 4.7 (1M context) * allreduce: fix CI -Werror warnings (sign-compare, format, restrict alias, maybe-uninitialized) The CUDA CI builds with -Werror -Wsign-compare -Wformat -Wrestrict -Wmaybe-uninitialized. Address each: - n_devices is size_t; change `int i; i < n_devices` to size_t in the three init loops, and the matching GGML_LOG_INFO format from %d to %zu. - ggml_cuda_ar_kernel was launched with sendbuf == recvbuf (in-place reduction), so the __restrict__ qualifiers on those parameters were technically UB. Drop __restrict__ from sendbuf and recvbuf; an A/B sweep showed <0.6% perf delta (within noise) on Linux. - The buf/src/dst pointer arrays in ggml_cuda_ar_allreduce and the per-iteration arrays in ggml_cuda_ar_allreduce_copy_outer were declared with size GGML_CUDA_MAX_DEVICES but the loop only writes indices [0, n_devices); zero-initialise so the compiler sees the tail elements as defined. Co-Authored-By: Claude Opus 4.7 (1M context) now +6-8% ahead at ub=1024-4096. Perplexity (32 chunks) matches NCCL bit-for-bit (3.4044 vs 3.4043). Co-Authored-By: Claude Opus 4.7 (1M context) * ggml-cuda: drop unused-function warning by guarding try_allreduce_nccl behind GGML_USE_NCCL The only call site (in init_nccl) is already inside #ifdef GGML_USE_NCCL, so the function is unreferenced in non-NCCL builds and trips nvcc's -Werror=unused-function check. Move the guard from inside the function body to around the entire definition. Co-Authored-By: Claude Opus 4.7 (1M context) ce reduction), so the __restrict__ qualifiers on those parameters were technically UB. Drop __restrict__ from sendbuf and recvbuf; an A/B sweep showed <0.6% perf delta (within noise) on Linux. - The buf/src/dst pointer arrays in ggml_cuda_ar_allreduce and the per-iteration arrays in ggml_cuda_ar_allreduce_copy_outer were declared with size GGML_CUDA_MAX_DEVICES but the loop only writes indices [0, n_devices); zero-initialise so the compiler sees the tail elements as defined. Co-Authored-By: Claude Opus 4.7 (1M context) now +6-8% ahead at ub=1024-4096. Perplexity (32 chunks) matches NCCL bit-for-bit (3.4044 vs 3.4043). Co-Authored-By: Claude Opus 4.7 (1M context) --------- Co-authored-by: Claude Sonnet 4.6 --- ggml/src/ggml-cuda/allreduce.cu | 968 +++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/allreduce.cuh | 29 + ggml/src/ggml-cuda/ggml-cuda.cu | 265 +++++++-- 3 files changed, 1205 insertions(+), 57 deletions(-) create mode 100644 ggml/src/ggml-cuda/allreduce.cu create mode 100644 ggml/src/ggml-cuda/allreduce.cuh diff --git a/ggml/src/ggml-cuda/allreduce.cu b/ggml/src/ggml-cuda/allreduce.cu new file mode 100644 index 00000000000..434689abd95 --- /dev/null +++ b/ggml/src/ggml-cuda/allreduce.cu @@ -0,0 +1,968 @@ +#include "allreduce.cuh" + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + +#include "convert.cuh" +#include "ggml-impl.h" + +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// CUDA AllReduce for tensor-parallel inference across two GPUs. +// +// Provides an in-place sum reduction over matching tensors on two CUDA +// devices in the same process. Used by the tensor-split path alongside +// NCCL; targets setups without NVLink, where data is exchanged between the +// GPUs by staging it through pinned host memory over PCIe. +// +// Two reduction strategies are selected per call by tensor size: +// +// * Chunked kernel path (small reductions): a single CUDA kernel both +// stages data through pinned host memory and performs the local sum. +// Cross-GPU synchronization happens *inside the kernel* (busy-wait on +// a host-memory flag), which keeps launch overhead low for the +// latency-sensitive token-generation case. +// +// * Copy-engine path (large reductions): the transfer is split into +// D2H + H2D cudaMemcpyAsync chunks driven by the GPU's copy engine, +// followed by a small device-side add kernel. Cross-GPU +// synchronization happens *outside the kernel*, via CUDA events +// between streams. This keeps the compute engine free while large +// transfers are in flight, which matters for prefill-sized tensors. +// Reductions larger than the per-call inner cap are processed by an +// outer chunker that issues sequential inner calls. +// --------------------------------------------------------------------------- + +// --------------------------------------------------------------------------- +// Cross-GPU signal mechanism +// +// One int per (slot, rank) pair in pinned host memory. Each AR call writes a +// strictly increasing token (= the AR call number) into its own arrival int. +// The peer spins until its read of the other's arrival int equals the token +// it expects for this call -- a mismatch means the peer hasn't arrived yet. +// Tokens never repeat over realistic call rates (32-bit int wraps in tens of +// days at thousands of ARs/sec), so arrival ints don't need to be reset +// between calls; we initialize once at pipeline init and let the values +// accumulate. +// +// There is exactly one writer (the owning GPU) and one reader (the peer), so +// we don't need atomics. A volatile store paired with __threadfence_system() +// provides the release ordering that makes the D2H writes visible system-wide +// before the arrival token is observed. +// +// atomicAdd_system() requires hostNativeAtomicSupported, which is unavailable +// on PCIe-attached consumer GPUs without NVLink, so the volatile path is the +// portable choice. +// --------------------------------------------------------------------------- + +static __device__ __forceinline__ void ggml_cuda_ar_signal_set(int * p, int token) { + *(volatile int *)p = token; +} +static __device__ __forceinline__ int ggml_cuda_ar_signal_get(const int * p) { + return *(const volatile int *)p; +} + +// Byte spacing between adjacent arrival ints. 64 bytes (one cache line) +// ensures each GPU/block's arrival slot lives on its own line, preventing +// false-sharing stalls on the polling GPU. +static constexpr size_t GGML_CUDA_AR_ARRIVAL_STRIDE = 64; + +// Number of blocks the chunked kernel launches with. Each block stripes a +// disjoint slice of the data and synchronizes through its own arrival-token +// slot so multiple SMs can pump PCIe stores in parallel. +static constexpr int GGML_CUDA_AR_KERNEL_BLOCKS = 8; + +// --------------------------------------------------------------------------- +// Chunked kernel AllReduce -- 2 GPUs, supports float, half, and bfloat16. +// +// Both GPUs run this kernel simultaneously on independent streams. sendbuf +// and recvbuf live in T_dst (the caller's tensor type); host_mine / host_other +// carry data in T_wire (the on-wire type, possibly narrower than T_dst -- e.g. +// T_dst=F32 with T_wire=BF16 halves the bytes pushed across PCIe). When +// T_dst == T_wire the casts below are no-ops. +// +// Each GPU runs three phases: +// +// Phase 1 (all threads): cast sendbuf (T_dst) -> T_wire and store as +// single-instruction-width vectors into host_mine. +// __threadfence_system() commits these writes to host +// memory. +// Phase 2 (thread 0): write token to arrival_mine; spin until +// arrival_other == token. +// Phase 3 (all threads): read T_wire vectors from host_other, cast +// each element to T_dst, and sum with the local +// sendbuf value (also rounded through T_wire so that +// both GPUs truncate identically -- this guarantees +// bit-equivalent results across the two devices). +// +// Multi-block: blocks stripe vectors across (gridDim.x * blockDim.x) global +// threads to keep multiple SMs issuing PCIe stores in parallel. Each block +// has its own arrival-token slot (offset by blockIdx.x * ARRIVAL_STRIDE); +// thread 0 of each block signals/spins on that slot independently of other +// blocks. Tail elements (the leftover < ELEMS_PER_VEC at the end) are +// handled only by block 0 to avoid cross-block writes to the same slots. +// --------------------------------------------------------------------------- +template +static __global__ void ggml_cuda_ar_kernel( + const T_dst * sendbuf, + T_dst * recvbuf, + T_wire * __restrict__ host_mine, + const T_wire * __restrict__ host_other, + int count, + int * arrival_mine, + int * arrival_other, + int token) { + + // Vector unit for the wire type, sized to the arch's widest single-instruction + // copy (16 B on Volta+). Each phase-1 iter writes one vector to host memory; + // each phase-3 iter reads one and produces ELEMS_PER_VEC sums. + constexpr int ELEMS_PER_VEC = ggml_cuda_get_max_cpy_bytes() / sizeof(T_wire); + constexpr int ARRIVAL_INTS = (int)(GGML_CUDA_AR_ARRIVAL_STRIDE / sizeof(int)); + + const int tid = threadIdx.x; + const int nt = blockDim.x; + const int bid = blockIdx.x; + const int gtid = bid * nt + tid; + const int gnt = gridDim.x * nt; + const int count_vec = count / ELEMS_PER_VEC; + const int tail = count_vec * ELEMS_PER_VEC; + + // Phase 1: cast sendbuf (T_dst) -> host_mine (T_wire) and store as vectors. + { + for (int i = gtid; i < count_vec; i += gnt) { + const int off = i * ELEMS_PER_VEC; + T_wire wire[ELEMS_PER_VEC]; + #pragma unroll + for (int k = 0; k < ELEMS_PER_VEC; ++k) { + wire[k] = ggml_cuda_cast(sendbuf[off + k]); + } + ggml_cuda_memcpy_1(&host_mine[off], wire); + } + if (bid == 0 && tid < count - tail) { + host_mine[tail + tid] = ggml_cuda_cast(sendbuf[tail + tid]); + } + } + + // Commit this block's host writes before signalling. + __threadfence_system(); + __syncthreads(); + + // Phase 2: thread 0 of each block signals on its own arrival slot, then + // spins for the matching slot from peer. Per-block tokens mean blocks + // proceed independently -- no inter-block barrier needed. + if (tid == 0) { + int * my_slot = arrival_mine + bid * ARRIVAL_INTS; + const int * other_slot = arrival_other + bid * ARRIVAL_INTS; + + ggml_cuda_ar_signal_set(my_slot, token); + __threadfence_system(); // make our signal visible system-wide + + while (ggml_cuda_ar_signal_get(other_slot) != token) { +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + __nanosleep(100); +#else + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + } + } + + __syncthreads(); + + // Acquire peer's host_other writes (this block's stripe of them). + __threadfence_system(); + + // Phase 3: read peer's T_wire vector, cast both sides through T_wire for + // bit-equivalence, sum in T_dst precision, and write back to recvbuf. + { + for (int i = gtid; i < count_vec; i += gnt) { + const int off = i * ELEMS_PER_VEC; + T_wire wire[ELEMS_PER_VEC]; + ggml_cuda_memcpy_1(wire, &host_other[off]); + #pragma unroll + for (int k = 0; k < ELEMS_PER_VEC; ++k) { + const T_wire d_low = ggml_cuda_cast(sendbuf[off + k]); + recvbuf[off + k] = ggml_cuda_cast(d_low) + ggml_cuda_cast(wire[k]); + } + } + if (bid == 0 && tid < count - tail) { + const T_wire d_low = ggml_cuda_cast(sendbuf[tail + tid]); + recvbuf[tail + tid] = + ggml_cuda_cast(d_low) + ggml_cuda_cast(host_other[tail + tid]); + } + } +} + +// Combined load-convert-add kernel. The peer's contribution arrives as T_src +// (which may be a lower-precision type than T_dst when the BF16 round-trip is +// active). For bit-equivalence between the two GPUs, dst is first rounded +// through T_src's precision via ggml_cuda_cast -- peer already truncated its +// own value the same way before sending -- so both sides perform identical +// arithmetic. When T_dst == T_src the round-trip cast is a no-op. +template +static __global__ void ggml_cuda_ar_add_kernel( + T_dst * __restrict__ dst, + const T_src * __restrict__ src, + int count) { + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const int nt = gridDim.x * blockDim.x; + for (int i = tid; i < count; i += nt) { + const T_src d_low = ggml_cuda_cast(dst[i]); + dst[i] = ggml_cuda_cast(d_low) + ggml_cuda_cast(src[i]); + } +} + +// --------------------------------------------------------------------------- +// Pipeline structure +// --------------------------------------------------------------------------- + +// Number of slots in the event / arrival ring. Two slots is sufficient: +// lockstep guarantees the two GPUs are at most one AR (or chunk) apart, so +// slot[N%2] is always safe to reuse -- peer has already consumed slot[N%2] +// from AR N-2 by the time we get to AR N. acquire_slot's +// cudaEventSynchronize on ev.ker for both devices makes that consumption +// explicit before we overwrite host_buf[slot] for the new AR. +static constexpr int GGML_CUDA_AR_POOL_SIZE = 2; + +// Maximum chunk size (bytes per GPU) handled by one chunked kernel launch. +// Larger tensors are reduced by issuing multiple chunked launches. +static constexpr size_t GGML_CUDA_AR_MAX_BYTES = 1024 * 1024; // 1 MB + +// Copy-engine path: largest tensor accepted on this path; sets host_large / +// dev_tmp allocation size. +static constexpr size_t GGML_CUDA_AR_COPY_MAX_BYTES = 32 * 1024 * 1024; // 32 MB + +// AR wire size at which the copy-engine path takes over from the chunked- +// kernel path. Override via GGML_CUDA_AR_COPY_THRESHOLD. +static constexpr size_t GGML_CUDA_AR_COPY_THRESHOLD_DEFAULT = 1024 * 1024; // 1 MB +// Per-call CE chunk-size heuristic: chunk_bytes = clamp(nbytes / 4, MIN, MAX). +// The /4 keeps ~4 chunks in flight at any moment (good D2H/H2D overlap with +// the peer); the clamps cover the cases where nbytes/4 is too small (per- +// memcpy fixed cost dominates) or too large (chunk-level pipelining stalls). +// Env var GGML_CUDA_AR_COPY_CHUNK_BYTES can override with a fixed value. +static constexpr size_t GGML_CUDA_AR_COPY_CHUNK_BYTES_HEURISTIC_MIN = 512 * 1024; // 512 KB +static constexpr size_t GGML_CUDA_AR_COPY_CHUNK_BYTES_HEURISTIC_MAX = 2 * 1024 * 1024; // 2 MB +// Absolute floor that an env-var override is allowed to set; this caps the +// per-slot copy-event array. 256 KB -> up to 128 chunks per 32 MB tensor. +static constexpr size_t GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN = 256 * 1024; +static constexpr int GGML_CUDA_AR_COPY_MAX_CHUNKS = + static_cast((GGML_CUDA_AR_COPY_MAX_BYTES + GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN - 1) / + GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN); + +struct ggml_cuda_ar_event_slot { + cudaEvent_t app = nullptr; // upstream computation complete + cudaEvent_t cpy[GGML_CUDA_AR_COPY_MAX_CHUNKS] = {}; // copy-engine D2H chunks complete + cudaEvent_t h2d = nullptr; // copy-engine H2Ds complete (handoff AR stream -> compute stream) + cudaEvent_t ker = nullptr; // AllReduce kernel complete +}; + +// Mapped pinned host allocation: cudaHostAlloc + cudaHostGetDevicePointer +// in one place, with the host handle preserved for cudaFreeHost. Used where +// the CPU never touches the buffer -- only the device reads/writes via the +// mapped device pointer. Required on systems where cudaDevAttrCanUseHost- +// PointerForRegisteredMem is 0 and the host pointer can't be used as a +// device pointer. +struct ggml_cuda_ar_host_mapping { + uint8_t * host = nullptr; // cudaFreeHost handle; also the H-side ptr for cudaMemcpyAsync + uint8_t * dev = nullptr; // device-side pointer for kernels / cudaMemset + + cudaError_t alloc(size_t bytes) { + cudaError_t rc = cudaHostAlloc(reinterpret_cast(&host), bytes, + cudaHostAllocPortable | cudaHostAllocMapped); + if (rc != cudaSuccess) { + host = nullptr; + return rc; + } + rc = cudaHostGetDevicePointer(reinterpret_cast(&dev), host, 0); + if (rc != cudaSuccess) { + cudaFreeHost(host); + host = nullptr; + dev = nullptr; + } + return rc; + } + + void free() { + if (host) { + cudaFreeHost(host); + host = nullptr; + dev = nullptr; + } + } +}; + +struct ggml_cuda_ar_pipeline { + int n_devices; + int devices[GGML_CUDA_MAX_DEVICES]; + size_t buf_bytes; // bytes per device in host_buf[] + size_t copy_bytes; // bytes per device in host_large[] / dev_tmp[] + size_t copy_threshold; + size_t copy_chunk_bytes; + size_t bf16_threshold; // tensors >= this size (bytes) are reduced via FP32->BF16 round-trip; 0 disables + uint64_t call_count; + + // Per-device resources. + ggml_cuda_ar_host_mapping host_buf[GGML_CUDA_MAX_DEVICES]; // pinned staging (chunked kernel) + ggml_cuda_ar_host_mapping host_large[GGML_CUDA_MAX_DEVICES]; // pinned staging (copy-engine) + char * dev_tmp[GGML_CUDA_MAX_DEVICES]; // device scratch for copy-engine path + cudaStream_t streams[GGML_CUDA_MAX_DEVICES]; // non-blocking + ggml_cuda_ar_event_slot ev_pool[GGML_CUDA_MAX_DEVICES][GGML_CUDA_AR_POOL_SIZE]; + + // Copy-engine: per-device "I finished reading my peer's host_large" + // event. Indexed by RECORDER device. Recorded same-device on streams[i] + // after stage 2's last H2D from host_large[peer]. Waited cross-device + // by peer's stage-1 stream before the next AR overwrites host_large[peer]. + cudaEvent_t host_large_read_done[GGML_CUDA_MAX_DEVICES]; + bool host_large_read_done_valid; + + // Copy-engine: per-device "my add_kernel is done with dev_tmp" event. + // Recorded on the compute stream after each add_kernel; the AR stream + // waits on it before the next copy_impl's H2D overwrites dev_tmp. Lets us + // single-buffer dev_tmp despite add_kernel running on a separate stream. + cudaEvent_t dev_tmp_kernel_done[GGML_CUDA_MAX_DEVICES]; + bool dev_tmp_kernel_done_valid; + + // Arrival ring: ARRIVAL_STRIDE bytes between adjacent ints. Mapped pinned + // memory; CPU never reads/writes -- only the kernel and cudaMemset. + // Use ggml_cuda_ar_arrival_ptr() to index. + ggml_cuda_ar_host_mapping arrival; +}; + +// Base pointer for the (slot, rank) per-block token block. The kernel adds +// blockIdx.x * (ARRIVAL_STRIDE/sizeof(int)) internally to land on its own slot. +static int * ggml_cuda_ar_arrival_ptr(const ggml_cuda_ar_pipeline * p, int slot, int rank) { + const size_t offset = ((size_t)slot * p->n_devices + rank) * + GGML_CUDA_AR_KERNEL_BLOCKS * GGML_CUDA_AR_ARRIVAL_STRIDE; + return reinterpret_cast(p->arrival.dev + offset); +} + +static uint64_t ggml_cuda_ar_env_u64(const char * name, uint64_t default_value) { + const char * value = getenv(name); + if (value == nullptr || value[0] == '\0') { + return default_value; + } + + char * end = nullptr; + const unsigned long long parsed = strtoull(value, &end, 10); + return end != value ? (uint64_t) parsed : default_value; +} + +struct ggml_cuda_ar_slot_info { + int slot; + int token; +}; + +static ggml_cuda_ar_slot_info ggml_cuda_ar_acquire_slot(ggml_cuda_ar_pipeline * p) { + const int slot = static_cast(p->call_count % GGML_CUDA_AR_POOL_SIZE); + const bool pool_lapped = p->call_count >= GGML_CUDA_AR_POOL_SIZE; + p->call_count++; + + if (pool_lapped) { + for (int i = 0; i < p->n_devices; ++i) { + ggml_cuda_set_device(p->devices[i]); + CUDA_CHECK(cudaEventSynchronize(p->ev_pool[i][slot].ker)); + } + } + + return { slot, (int) p->call_count }; +} + +// Per-AR copy-engine chunk size: env-var override if set, else heuristic +// (clamp(nbytes/4, HEURISTIC_MIN, HEURISTIC_MAX)). +static size_t ggml_cuda_ar_chunk_bytes(const ggml_cuda_ar_pipeline * p, size_t nbytes) { + if (p->copy_chunk_bytes > 0) { + return p->copy_chunk_bytes; + } + return std::min(GGML_CUDA_AR_COPY_CHUNK_BYTES_HEURISTIC_MAX, + std::max(GGML_CUDA_AR_COPY_CHUNK_BYTES_HEURISTIC_MIN, nbytes / 4)); +} + +static void ggml_cuda_ar_wait_for_compute( + ggml_cuda_ar_pipeline * p, ggml_backend_cuda_context * cuda_ctx, int rank, int slot) { + ggml_cuda_ar_event_slot & ev = p->ev_pool[rank][slot]; + CUDA_CHECK(cudaEventRecord(ev.app, cuda_ctx->stream())); + CUDA_CHECK(cudaStreamWaitEvent(p->streams[rank], ev.app)); +} + +// --------------------------------------------------------------------------- +// Init / free +// --------------------------------------------------------------------------- + +ggml_cuda_ar_pipeline * ggml_cuda_ar_pipeline_init(const int * devices, size_t n_devices) { + + if (n_devices != 2) { + GGML_LOG_DEBUG("%s: internal AllReduce only supports n_devices=2 (got %zu); " + "falling back\n", __func__, n_devices); + return nullptr; + } + + // The chunked kernel uses __nanosleep, which is sm70+ (Volta+). + for (size_t i = 0; i < n_devices; ++i) { + const int cc = ggml_cuda_info().devices[devices[i]].cc; + if (cc < GGML_CUDA_CC_VOLTA) { + GGML_LOG_DEBUG("%s: internal AllReduce requires compute capability >= %d " + "(device %d has cc=%d); falling back\n", + __func__, GGML_CUDA_CC_VOLTA, devices[i], cc); + return nullptr; + } + } + + auto * p = new ggml_cuda_ar_pipeline{}; + p->n_devices = n_devices; + p->copy_bytes = GGML_CUDA_AR_COPY_MAX_BYTES; + p->copy_threshold = ggml_cuda_ar_env_u64("GGML_CUDA_AR_COPY_THRESHOLD", GGML_CUDA_AR_COPY_THRESHOLD_DEFAULT); + // 0 = use the per-call heuristic (default). Non-zero env value forces a + // fixed chunk size for diagnostics, with a floor at COPY_CHUNK_BYTES_MIN. + p->copy_chunk_bytes = ggml_cuda_ar_env_u64("GGML_CUDA_AR_COPY_CHUNK_BYTES", 0); + if (p->copy_chunk_bytes > 0 && p->copy_chunk_bytes < GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN) { + GGML_LOG_WARN("%s: GGML_CUDA_AR_COPY_CHUNK_BYTES=%zu below minimum %zu; clamping\n", + __func__, p->copy_chunk_bytes, GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN); + p->copy_chunk_bytes = GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN; + } + // Default 1: BF16 round-trip is always on for F32 inputs (any non-zero + // ne). Set GGML_CUDA_AR_BF16_THRESHOLD=0 to disable, or to a larger + // byte threshold to opt out for small tensors. + p->bf16_threshold = ggml_cuda_ar_env_u64("GGML_CUDA_AR_BF16_THRESHOLD", 1); + for (size_t i = 0; i < n_devices; ++i) { + p->devices[i] = devices[i]; + } + + // Per-device streams and event pools. + for (size_t i = 0; i < n_devices; ++i) { + ggml_cuda_set_device(p->devices[i]); + + cudaStream_t stream = nullptr; + if (cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaStreamCreateWithFlags failed for device %d\n", + __func__, p->devices[i]); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + p->streams[i] = stream; + + for (int s = 0; s < GGML_CUDA_AR_POOL_SIZE; ++s) { + bool ok = + cudaEventCreateWithFlags(&p->ev_pool[i][s].app, cudaEventDisableTiming) == cudaSuccess && + cudaEventCreateWithFlags(&p->ev_pool[i][s].h2d, cudaEventDisableTiming) == cudaSuccess && + cudaEventCreateWithFlags(&p->ev_pool[i][s].ker, cudaEventDisableTiming) == cudaSuccess; + for (int c = 0; ok && c < GGML_CUDA_AR_COPY_MAX_CHUNKS; ++c) { + ok = cudaEventCreateWithFlags(&p->ev_pool[i][s].cpy[c], cudaEventDisableTiming) == cudaSuccess; + } + if (!ok) { + GGML_LOG_ERROR("%s: cudaEventCreate failed for device %d slot %d\n", + __func__, p->devices[i], s); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + } + + if (cudaEventCreateWithFlags(&p->host_large_read_done[i], cudaEventDisableTiming) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaEventCreate for host_large_read_done failed for device %d\n", + __func__, p->devices[i]); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + if (cudaEventCreateWithFlags(&p->dev_tmp_kernel_done[i], cudaEventDisableTiming) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaEventCreate for dev_tmp_kernel_done failed for device %d\n", + __func__, p->devices[i]); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + } + + // Arrival ring: cache-line padded so each GPU's int is on its own line. + const size_t arrival_bytes = + (size_t)GGML_CUDA_AR_POOL_SIZE * n_devices * + GGML_CUDA_AR_KERNEL_BLOCKS * GGML_CUDA_AR_ARRIVAL_STRIDE; + if (p->arrival.alloc(arrival_bytes) != cudaSuccess) { + GGML_LOG_ERROR("%s: alloc for arrival ring failed (%zu bytes)\n", + __func__, arrival_bytes); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + ggml_cuda_set_device(p->devices[0]); + if (cudaMemset(p->arrival.dev, 0, arrival_bytes) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaMemset for arrival ring failed (%zu bytes)\n", + __func__, arrival_bytes); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + + // Per-device pinned staging buffers -- POOL_SIZE-deep ring so the chunked- + // kernel can write the next slot's data while the peer is still reading + // the previous slot's. Indexed by (slot * buf_bytes) at the call site. + p->buf_bytes = GGML_CUDA_AR_MAX_BYTES; + const size_t host_buf_total = (size_t) GGML_CUDA_AR_POOL_SIZE * p->buf_bytes; + for (size_t i = 0; i < n_devices; ++i) { + if (p->host_buf[i].alloc(host_buf_total) != cudaSuccess) { + GGML_LOG_ERROR("%s: alloc for staging failed (%zu bytes)\n", + __func__, host_buf_total); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + } + + // Copy-engine path: pinned host staging + device scratch, sized for the + // largest tensor we accept on this path (GGML_CUDA_AR_COPY_MAX_BYTES). + // dev_tmp is single-buffered; cross-AR safety is enforced by an explicit + // cross-stream wait in copy_impl on the prior AR's add_kernel-done event. + for (size_t i = 0; i < n_devices; ++i) { + ggml_cuda_set_device(p->devices[i]); + if (p->host_large[i].alloc(p->copy_bytes) != cudaSuccess) { + GGML_LOG_ERROR("%s: alloc for large staging failed (%zu bytes)\n", + __func__, p->copy_bytes); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + if (cudaMalloc(reinterpret_cast(&p->dev_tmp[i]), p->copy_bytes) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaMalloc for copy scratch failed (%zu bytes) on device %d\n", + __func__, p->copy_bytes, p->devices[i]); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + } + + GGML_LOG_INFO("%s: initialized AllReduce pipeline: %zu GPUs, " + "%zu KB chunked kernel staging + %zu MB copy-engine staging per GPU\n", + __func__, n_devices, p->buf_bytes >> 10, p->copy_bytes >> 20); + + return p; +} + +void ggml_cuda_ar_pipeline_free(ggml_cuda_ar_pipeline * p) { + if (!p) { + return; + } + + // Drain all in-flight kernels before tearing down resources. + for (int i = 0; i < p->n_devices; ++i) { + if (p->streams[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaStreamSynchronize(p->streams[i]); + } + } + + for (int i = 0; i < p->n_devices; ++i) { + p->host_buf[i].free(); + p->host_large[i].free(); + if (p->dev_tmp[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaFree(p->dev_tmp[i]); + } + ggml_cuda_set_device(p->devices[i]); + for (int s = 0; s < GGML_CUDA_AR_POOL_SIZE; ++s) { + if (p->ev_pool[i][s].app) { cudaEventDestroy(p->ev_pool[i][s].app); } + for (int c = 0; c < GGML_CUDA_AR_COPY_MAX_CHUNKS; ++c) { + if (p->ev_pool[i][s].cpy[c]) { cudaEventDestroy(p->ev_pool[i][s].cpy[c]); } + } + if (p->ev_pool[i][s].h2d) { cudaEventDestroy(p->ev_pool[i][s].h2d); } + if (p->ev_pool[i][s].ker) { cudaEventDestroy(p->ev_pool[i][s].ker); } + } + if (p->host_large_read_done[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaEventDestroy(p->host_large_read_done[i]); + } + if (p->dev_tmp_kernel_done[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaEventDestroy(p->dev_tmp_kernel_done[i]); + } + if (p->streams[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaStreamDestroy(p->streams[i]); + } + } + p->arrival.free(); + delete p; +} + +// --------------------------------------------------------------------------- +// Dispatch +// --------------------------------------------------------------------------- + +// Asymmetric copy_impl: data sent over PCIe in T_src precision (one element of +// nbytes per ne element); accumulated locally into a T_dst buffer. When +// T_src == T_dst this is the original homogeneous reduction. When they differ +// (e.g. BF16 wire / F32 accumulator) the add kernel rounds dst through T_src +// for bit-equivalence between GPUs and we skip the otherwise-needed +// post-conversion entirely. +template +static bool ggml_cuda_ar_allreduce_copy_impl( + ggml_cuda_ar_pipeline * p, + ggml_backend_t * backends, + T_src * const src_buf[GGML_CUDA_MAX_DEVICES], + T_dst * const dst_buf[GGML_CUDA_MAX_DEVICES], + const bool compute[GGML_CUDA_MAX_DEVICES], + int64_t ne, + size_t nbytes) { + GGML_ASSERT(p->n_devices == 2); + GGML_ASSERT(nbytes <= p->copy_bytes); + GGML_ASSERT(ne <= std::numeric_limits::max()); + + const size_t chunk_bytes = ggml_cuda_ar_chunk_bytes(p, nbytes); + GGML_ASSERT(chunk_bytes > 0); + + const int slot = ggml_cuda_ar_acquire_slot(p).slot; + const size_t copy_chunks = (nbytes + chunk_bytes - 1) / chunk_bytes; + GGML_ASSERT(copy_chunks <= GGML_CUDA_AR_COPY_MAX_CHUNKS); + + ggml_backend_cuda_context * cuda_ctx[2] = {}; + + // Stage 1: both GPUs copy their local contribution to pinned host memory. + for (int i = 0; i < 2; ++i) { + ggml_cuda_set_device(p->devices[i]); + cuda_ctx[i] = static_cast(backends[i]->context); + GGML_ASSERT(cuda_ctx[i]->device == p->devices[i]); + + ggml_cuda_ar_wait_for_compute(p, cuda_ctx[i], i, slot); + + // Wait for peer's H2D from our host_large[i] (recorded in the + // previous AR's stage 2) to complete before we overwrite host_large[i]. + // host_large_read_done[peer] = peer finished reading host_large[i]. + // No-op on the first AR -- no prior record exists. + if (p->host_large_read_done_valid) { + const int peer = 1 - i; + CUDA_CHECK(cudaStreamWaitEvent(p->streams[i], p->host_large_read_done[peer])); + } + + if (!compute[i]) { + CUDA_CHECK(cudaMemsetAsync(src_buf[i], 0, nbytes, p->streams[i])); + } + + for (size_t c = 0; c < copy_chunks; ++c) { + const size_t offset = c * chunk_bytes; + const size_t this_bytes = (nbytes - offset) < chunk_bytes ? + (nbytes - offset) : chunk_bytes; + + CUDA_CHECK(cudaMemcpyAsync( + p->host_large[i].host + offset, reinterpret_cast(src_buf[i]) + offset, this_bytes, + cudaMemcpyDeviceToHost, p->streams[i])); + CUDA_CHECK(cudaEventRecord(p->ev_pool[i][slot].cpy[c], p->streams[i])); + } + } + + // Stage 2: each GPU waits for each peer D2H chunk, pulls that chunk back to + // local device scratch (dev_tmp), then performs one device-local add over + // the assembled peer tensor. The H2Ds run on the AR stream (copy engine) + // and the add_kernel runs on the caller's compute stream, so the AR stream + // stays pure-copy and avoids an in-stream copy->compute engine switch every + // AR. dev_tmp is single-buffered: the AR stream waits cross-stream on the + // prior AR's add_kernel-done event before overwriting it. + for (int i = 0; i < 2; ++i) { + const int peer = 1 - i; + ggml_cuda_set_device(p->devices[i]); + + // Wait for the previous AR's add_kernel (on the compute stream) to + // finish reading dev_tmp before our H2D overwrites it. No-op on the + // first copy_impl call. + if (p->dev_tmp_kernel_done_valid) { + CUDA_CHECK(cudaStreamWaitEvent(p->streams[i], p->dev_tmp_kernel_done[i])); + } + + for (size_t c = 0; c < copy_chunks; ++c) { + const size_t offset = c * chunk_bytes; + const size_t this_bytes = (nbytes - offset) < chunk_bytes ? + (nbytes - offset) : chunk_bytes; + + CUDA_CHECK(cudaStreamWaitEvent(p->streams[i], p->ev_pool[peer][slot].cpy[c])); + CUDA_CHECK(cudaMemcpyAsync( + p->dev_tmp[i] + offset, p->host_large[peer].host + offset, this_bytes, + cudaMemcpyHostToDevice, p->streams[i])); + } + + // Mark our reads of host_large[peer] complete so peer's next AR can + // safely overwrite it. + CUDA_CHECK(cudaEventRecord(p->host_large_read_done[i], p->streams[i])); + + // Hand off from AR stream (copy engine) to compute stream: compute + // stream waits for all H2Ds to finish, then runs the add_kernel. + CUDA_CHECK(cudaEventRecord(p->ev_pool[i][slot].h2d, p->streams[i])); + CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx[i]->stream(), p->ev_pool[i][slot].h2d)); + + const int block_size = 256; + int n_blocks = (int) ((ne + block_size - 1) / block_size); + if (n_blocks > 1024) { + n_blocks = 1024; + } + ggml_cuda_ar_add_kernel<<stream()>>>( + dst_buf[i], + reinterpret_cast(p->dev_tmp[i]), + (int) ne); + CUDA_CHECK(cudaGetLastError()); + + // Record dev_tmp-released on the compute stream so the next copy_impl + // can wait for the kernel to finish before overwriting dev_tmp. Also + // record AR-done as ev.ker for acquire_slot's pool-wraparound sync. + CUDA_CHECK(cudaEventRecord(p->dev_tmp_kernel_done[i], cuda_ctx[i]->stream())); + CUDA_CHECK(cudaEventRecord(p->ev_pool[i][slot].ker, cuda_ctx[i]->stream())); + } + p->host_large_read_done_valid = true; + p->dev_tmp_kernel_done_valid = true; + + return true; +} + +// Outer-level chunker: copy_impl handles up to copy_bytes per call (limited by +// the host_large / dev_tmp allocation size). When the full AR exceeds that, +// slice the tensor into copy_bytes-sized pieces and call copy_impl repeatedly. +// Each slice goes through its own stage 1 -> stage 2 cycle and acquires its own +// slot, so cross-AR fences and pool wraparound work the same way as for any +// other sequence of small ARs. +template +static bool ggml_cuda_ar_allreduce_copy_outer( + ggml_cuda_ar_pipeline * p, + ggml_backend_t * backends, + T_src * const src_buf[GGML_CUDA_MAX_DEVICES], + T_dst * const dst_buf[GGML_CUDA_MAX_DEVICES], + const bool compute[GGML_CUDA_MAX_DEVICES], + int64_t ne) { + const int64_t outer_max_elems = (int64_t) (p->copy_bytes / sizeof(T_src)); + GGML_ASSERT(outer_max_elems > 0); + + bool ok = true; + for (int64_t outer_start = 0; outer_start < ne && ok; outer_start += outer_max_elems) { + const int64_t outer_ne = std::min(outer_max_elems, ne - outer_start); + const size_t outer_nbytes = (size_t) outer_ne * sizeof(T_src); + + T_src * src[GGML_CUDA_MAX_DEVICES] = {}; + T_dst * dst[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < p->n_devices; ++i) { + src[i] = src_buf[i] + outer_start; + dst[i] = dst_buf[i] + outer_start; + } + ok = ggml_cuda_ar_allreduce_copy_impl( + p, backends, src, dst, compute, outer_ne, outer_nbytes); + } + return ok; +} + +bool ggml_cuda_ar_allreduce( + ggml_cuda_ar_pipeline * p, + ggml_backend_t * backends, + ggml_tensor ** tensors) { + GGML_ASSERT(p != nullptr); + + const int n = p->n_devices; + GGML_ASSERT(n == 2); + + const ggml_type input_type = tensors[0]->type; + GGML_ASSERT(input_type == GGML_TYPE_F32 || input_type == GGML_TYPE_F16 || input_type == GGML_TYPE_BF16); + + const int64_t ne = ggml_nelements(tensors[0]); + GGML_ASSERT(ne > 0); + + const size_t input_nbytes = ggml_nbytes(tensors[0]); + + // BF16 round-trip: F32 inputs >= bf16_threshold are converted to BF16 for + // the reduction (chunked or copy-engine), halving on-wire bytes. Matches + // NCCL's behaviour. The pre-conversion zeroes inactive shards so the + // inner paths see them as already-prepared compute tensors. + const bool use_bf16 = + input_type == GGML_TYPE_F32 && + p->bf16_threshold > 0 && + input_nbytes >= p->bf16_threshold; + + const ggml_type kernel_type = use_bf16 ? GGML_TYPE_BF16 : input_type; + const size_t type_size = ggml_type_size(kernel_type); + GGML_ASSERT(p->buf_bytes >= type_size); + const size_t nbytes = (size_t) ne * type_size; + + bool compute_flag[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + compute_flag[i] = (tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) != 0; + } + + // Decide between copy-engine and chunked kernel paths based on the working + // type's actual byte count. No upper bound: copy_outer slices reductions + // larger than copy_bytes into copy_bytes-sized pieces. + const bool use_copy_engine = + p->copy_threshold > 0 && + nbytes >= p->copy_threshold; + + // BF16 inactive-shard zeroing: when use_bf16 is on, the combined kernel + // (chunked kernel path) and the combined add kernel (copy_engine path) + // both accumulate into the F32 tensor data directly, so an inactive + // shard's accumulator must start at zero. + if (use_bf16) { + for (int i = 0; i < n; ++i) { + if (!compute_flag[i]) { + auto * cuda_ctx = static_cast(backends[i]->context); + GGML_ASSERT(cuda_ctx->device == p->devices[i]); + ggml_cuda_set_device(p->devices[i]); + CUDA_CHECK(cudaMemsetAsync(tensors[i]->data, 0, (size_t) ne * sizeof(float), cuda_ctx->stream())); + } + } + } + + // Pre-convert F32 -> BF16 into bf16_tmp ONLY for the copy_engine + use_bf16 + // path; the chunked kernel path's combined kernel does the conversion + // inline as it writes to host_buf. + ggml_cuda_pool_alloc bf16_tmp[GGML_CUDA_MAX_DEVICES]; + void * copy_src_ptr[GGML_CUDA_MAX_DEVICES] = {}; + + if (use_copy_engine && use_bf16) { + to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(GGML_TYPE_F32); + for (int i = 0; i < n; ++i) { + auto * cuda_ctx = static_cast(backends[i]->context); + GGML_ASSERT(cuda_ctx->device == p->devices[i]); + bf16_tmp[i].pool = &cuda_ctx->pool(); + bf16_tmp[i].alloc(ne); + ggml_cuda_set_device(p->devices[i]); + if (compute_flag[i]) { + to_bf16(tensors[i]->data, bf16_tmp[i].get(), ne, cuda_ctx->stream()); + CUDA_CHECK(cudaGetLastError()); + } else { + CUDA_CHECK(cudaMemsetAsync(bf16_tmp[i].get(), 0, nbytes, cuda_ctx->stream())); + } + copy_src_ptr[i] = bf16_tmp[i].get(); + } + } + + bool ok = true; + if (use_copy_engine) { + // After up-front BF16 conversion, the tmp buffers already hold the + // (possibly zeroed-for-inactive) data, so the inner path can treat + // every shard as compute. + bool inner_compute[GGML_CUDA_MAX_DEVICES]; + for (int i = 0; i < n; ++i) { + inner_compute[i] = use_bf16 ? true : compute_flag[i]; + } + + // Dispatch into copy_impl with explicit src/dst types. When use_bf16 + // is on, the wire type is BF16 (src = bf16_tmp) and the accumulator + // is F32 (dst = tensors[i]->data); the combined add kernel rounds dst + // through BF16 for bit-equivalence and writes F32 directly, so no + // post-conversion is needed. Otherwise src == dst (same native type). + if (use_bf16) { + GGML_ASSERT(kernel_type == GGML_TYPE_BF16); + nv_bfloat16 * src[GGML_CUDA_MAX_DEVICES] = {}; + float * dst[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + src[i] = static_cast(copy_src_ptr[i]); + dst[i] = static_cast(tensors[i]->data); + } + ok = ggml_cuda_ar_allreduce_copy_outer( + p, backends, src, dst, inner_compute, ne); + } else { + switch (kernel_type) { + case GGML_TYPE_F32: { + float * buf[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + buf[i] = static_cast(tensors[i]->data); + } + ok = ggml_cuda_ar_allreduce_copy_outer( + p, backends, buf, buf, inner_compute, ne); + break; + } + case GGML_TYPE_BF16: { + nv_bfloat16 * buf[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + buf[i] = static_cast(tensors[i]->data); + } + ok = ggml_cuda_ar_allreduce_copy_outer( + p, backends, buf, buf, inner_compute, ne); + break; + } + case GGML_TYPE_F16: { + half * buf[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + buf[i] = static_cast(tensors[i]->data); + } + ok = ggml_cuda_ar_allreduce_copy_outer( + p, backends, buf, buf, inner_compute, ne); + break; + } + default: + GGML_ASSERT(false); + } + } + } else { + // host_buf carries T_wire-typed data; max_chunk_elems is the count that + // fits in one host_buf at the wire size. + const size_t max_chunk_elems = p->buf_bytes / type_size; + const size_t input_type_size = ggml_type_size(input_type); + + // Chunked kernel path runs entirely on the caller's compute stream: + // since AR is a barrier here, same-stream ordering subsumes any + // cross-stream event handshake that the copy-engine path needs, and + // skips the cross-stream scheduling overhead that was hurting the + // small-tensor (tg) latency on the AR-stream variant. Only ev.ker is + // still recorded at end-of-AR for acquire_slot's pool-wraparound check. + for (int64_t chunk_start = 0; chunk_start < ne; chunk_start += (int64_t) max_chunk_elems) { + const size_t remaining_elems = (size_t) (ne - chunk_start); + const size_t chunk_elems = remaining_elems < max_chunk_elems ? remaining_elems : max_chunk_elems; + const size_t chunk_dst_bytes = chunk_elems * input_type_size; + + const auto [slot, token] = ggml_cuda_ar_acquire_slot(p); + const bool last_chunk = chunk_start + (int64_t) chunk_elems == ne; + + for (int i = 0; i < n; ++i) { + const int peer = 1 - i; // valid for n == 2 only + ggml_cuda_set_device(p->devices[i]); + auto * cuda_ctx = static_cast(backends[i]->context); + GGML_ASSERT(cuda_ctx->device == p->devices[i]); + cudaStream_t stream = cuda_ctx->stream(); + + char * data = static_cast(tensors[i]->data) + chunk_start * (int64_t) input_type_size; + + // Match NCCL/meta-backend semantics: inactive shards contribute + // zeros. On the BF16 path the F32 tensor data was already + // zeroed up-front (above), so per-chunk zeroing isn't needed. + if (!compute_flag[i] && !use_bf16) { + CUDA_CHECK(cudaMemsetAsync(data, 0, chunk_dst_bytes, stream)); + } + +#define LAUNCH_AR_KERNEL(T_dst, T_wire) \ + ggml_cuda_ar_kernel<<>>( \ + reinterpret_cast(data), \ + reinterpret_cast(data), \ + reinterpret_cast(p->host_buf[i].dev + (size_t) slot * p->buf_bytes), \ + reinterpret_cast(p->host_buf[peer].dev + (size_t) slot * p->buf_bytes), \ + static_cast(chunk_elems), \ + ggml_cuda_ar_arrival_ptr(p, slot, i), \ + ggml_cuda_ar_arrival_ptr(p, slot, peer), \ + token) + + if (use_bf16) { + GGML_ASSERT(input_type == GGML_TYPE_F32); + LAUNCH_AR_KERNEL(float, nv_bfloat16); + } else { + switch (input_type) { + case GGML_TYPE_F32: LAUNCH_AR_KERNEL(float, float); break; + case GGML_TYPE_F16: LAUNCH_AR_KERNEL(half, half); break; + case GGML_TYPE_BF16: LAUNCH_AR_KERNEL(nv_bfloat16, nv_bfloat16); break; + default: GGML_ASSERT(false); + } + } + +#undef LAUNCH_AR_KERNEL + CUDA_CHECK(cudaGetLastError()); + + if (last_chunk) { + CUDA_CHECK(cudaEventRecord(p->ev_pool[i][slot].ker, stream)); + } + } + } + } + + return ok; +} + +#else // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) + +// HIP and MUSA lack the host-mapped pinned-memory APIs (cudaHostAllocPortable +// / cudaHostAllocMapped / cudaHostGetDevicePointer) and __nanosleep that this +// implementation relies on, so the internal AllReduce is a CUDA-only feature. +// The dispatcher in ggml-cuda.cu treats a nullptr pipeline as "init failed" +// and silently falls back to the meta backend's generic AllReduce. +ggml_cuda_ar_pipeline * ggml_cuda_ar_pipeline_init(const int *, size_t) { + return nullptr; +} +void ggml_cuda_ar_pipeline_free(ggml_cuda_ar_pipeline *) { +} +bool ggml_cuda_ar_allreduce(ggml_cuda_ar_pipeline *, ggml_backend_t *, ggml_tensor **) { + return false; +} + +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) diff --git a/ggml/src/ggml-cuda/allreduce.cuh b/ggml/src/ggml-cuda/allreduce.cuh new file mode 100644 index 00000000000..0f2c9518d5d --- /dev/null +++ b/ggml/src/ggml-cuda/allreduce.cuh @@ -0,0 +1,29 @@ +#pragma once + +#include "common.cuh" +#include "ggml-backend-impl.h" + +#include + +// Opaque pipeline context -- owns all pinned buffers, streams, and events. +struct ggml_cuda_ar_pipeline; + +// Allocate a pipeline for n_devices GPUs. +// devices[] holds the CUDA device IDs in rank order. +// Returns nullptr on allocation failure. +ggml_cuda_ar_pipeline * ggml_cuda_ar_pipeline_init( + const int * devices, size_t n_devices); + +// Release all resources owned by the pipeline. +void ggml_cuda_ar_pipeline_free(ggml_cuda_ar_pipeline * pipeline); + +// Execute an in-place AllReduce (sum) across tensors[0..n_devices-1]. +// tensors[i] must live on the device managed by backends[i] and be +// contiguous F32, F16, or BF16. +// Preconditions are checked by the CUDA comm dispatcher before calling this. +// Returns true once the reduction work has been enqueued successfully. +bool ggml_cuda_ar_allreduce( + ggml_cuda_ar_pipeline * pipeline, + ggml_backend_t * backends, + ggml_tensor ** tensors); + diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 4df1b930882..b92a208705d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2,6 +2,7 @@ #include "ggml-impl.h" #include "ggml-backend-impl.h" +#include "ggml-cuda/allreduce.cuh" #include "ggml-cuda/common.cuh" #include "ggml-cuda/acc.cuh" #include "ggml-cuda/add-id.cuh" @@ -86,6 +87,9 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); +#define GGML_LOG_WARN_ONCE(str) \ + { static std::once_flag warn_flag; std::call_once(warn_flag, []() { GGML_LOG_WARN(str); }); } + [[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails @@ -1139,70 +1143,46 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host, }; -#ifdef GGML_USE_NCCL +// Communication context for multi-GPU AllReduce during tensor parallelism. +// +// Created once per meta backend instance. Resources for the selected mode +// (NCCL communicators or the internal AllReduce pipeline) are initialised +// eagerly during comm_init so any init failure surfaces at startup rather +// than mid-run. struct ggml_backend_cuda_comm_context { + using try_allreduce_fn = bool(*)(ggml_backend_cuda_comm_context *, struct ggml_tensor **); + std::vector backends; - std::vector comms; + std::vector dev_ids; - ~ggml_backend_cuda_comm_context() { - for (ncclComm_t comm : comms) { - NCCL_CHECK(ncclCommDestroy(comm)); - } - } -}; -#endif // GGML_USE_NCCL + // Set by the init chain (comm_init_{nccl, internal, none}) to one of + // try_allreduce_{nccl, internal, butterfly}. nccl needs `comms`, + // internal needs `ar_pipeline`, butterfly needs nothing. Per-call + // failures return false; the meta backend's generic implementation then + // handles that call. + try_allreduce_fn try_allreduce = nullptr; + + ggml_cuda_ar_pipeline * ar_pipeline = nullptr; -static void ggml_backend_cuda_comm_free(void * comm_ctx_v) { #ifdef GGML_USE_NCCL - if (comm_ctx_v == nullptr) { - return; - } - ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v; - delete comm_ctx; -#else - GGML_UNUSED(comm_ctx_v); + std::vector comms; #endif // GGML_USE_NCCL -} -static void * ggml_backend_cuda_comm_init(ggml_backend_t * backends, size_t n_backends) { + ~ggml_backend_cuda_comm_context() { #ifdef GGML_USE_NCCL - for (size_t i = 0; i < n_backends; i++) { - if (!ggml_backend_is_cuda(backends[i])) { - return nullptr; + for (ncclComm_t comm : comms) { + NCCL_CHECK(ncclCommDestroy(comm)); } - } - ggml_backend_cuda_comm_context * ret = new ggml_backend_cuda_comm_context; - std::vector dev_ids; - ret->backends.reserve(n_backends); - dev_ids.reserve(n_backends); - for (size_t i = 0; i < n_backends; i++) { - ret->backends.push_back(backends[i]); - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backends[i]->context; - dev_ids.push_back(cuda_ctx->device); - } - - ret->comms.resize(n_backends); - NCCL_CHECK(ncclCommInitAll(ret->comms.data(), n_backends, dev_ids.data())); - return ret; -#else - // If NCCL is installed it is used by default for optimal performance. - // However, NVIDIA does not distribute NCCL with CUDA so users may be unwittingly missing this package. - // RCCL is disabled by default, users are explicitly opting in. - // Therefore print no warning for RCCL. -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - static bool warning_printed = false; - if (!warning_printed) { - GGML_LOG_WARN("%s: NVIDIA Collective Communications Library (NCCL) is unavailable, multi GPU performance will be suboptimal\n", __func__); - warning_printed = true; - } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - GGML_UNUSED_VARS(backends, n_backends); - return nullptr; #endif // GGML_USE_NCCL -} + ggml_cuda_ar_pipeline_free(ar_pipeline); + } +}; -static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) { #ifdef GGML_USE_NCCL +// AllReduce via NCCL. Reduces as FP32 for small tensors and BF16 for large +// tensors (bandwidth-bound), then converts back to FP32. +static bool ggml_backend_cuda_comm_allreduce_nccl( + ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) { const int64_t ne = ggml_nelements(tensors[0]); // FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0 // This then causes a crash in this function @@ -1210,8 +1190,6 @@ static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct gg return true; } - GGML_ASSERT(comm_ctx_v != nullptr); - ggml_backend_cuda_comm_context * comm_ctx = (ggml_backend_cuda_comm_context *) comm_ctx_v; const size_t n_backends = comm_ctx->backends.size(); for (size_t i = 0; i < n_backends; ++i) { @@ -1236,7 +1214,6 @@ static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct gg NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, comm_ctx->comms[i], cuda_ctx->stream())); } NCCL_CHECK(ncclGroupEnd()); - return true; } @@ -1275,10 +1252,184 @@ static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct gg } return true; -#else - GGML_UNUSED_VARS(comm_ctx_v, tensors); +} +#endif // GGML_USE_NCCL + +// Run the internal AR pipeline. Returns false on unsupported / failed input +// -- the caller decides whether to abort (env-forced) or fall back silently. +static bool ggml_backend_cuda_comm_allreduce_internal( + ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) { + GGML_ASSERT(comm_ctx->ar_pipeline != nullptr); + + const size_t n_backends = comm_ctx->backends.size(); + GGML_ASSERT(n_backends == 2); + GGML_ASSERT(tensors[0] != nullptr); + + const int64_t ne = ggml_nelements(tensors[0]); + const ggml_type type = tensors[0]->type; + + if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16 && type != GGML_TYPE_BF16) { + GGML_LOG_DEBUG("%s: internal unsupported: type=%d\n", __func__, (int) type); + return false; + } + + if (ne == 0) { + return true; + } + + for (size_t i = 0; i < n_backends; ++i) { + if (tensors[i] == nullptr) { + GGML_LOG_ERROR("%s: internal failed: tensor[%zu] is null\n", __func__, i); + return false; + } + if (ggml_nelements(tensors[i]) != ne || tensors[i]->type != type) { + GGML_LOG_ERROR("%s: internal failed: tensor[%zu] ne=%" PRId64 " type=%d expected ne=%" PRId64 " type=%d\n", + __func__, i, ggml_nelements(tensors[i]), (int) tensors[i]->type, ne, (int) type); + return false; + } + if (!ggml_is_contiguously_allocated(tensors[i])) { + GGML_LOG_DEBUG("%s: internal unsupported: tensor[%zu] is not contiguously allocated: ne=%" PRId64 " nbytes=%zu packed=%zu type=%d\n", + __func__, i, ne, ggml_nbytes(tensors[i]), + (size_t) ne * ggml_type_size(type) / ggml_blck_size(type), (int) type); + return false; + } + if (((uintptr_t) tensors[i]->data & 0xF) != 0) { + GGML_LOG_DEBUG("%s: internal unsupported: tensor[%zu] data pointer is not 16-byte aligned: %p type=%d ne=%" PRId64 "\n", + __func__, i, tensors[i]->data, (int) type, ne); + return false; + } + GGML_ASSERT((ggml_nbytes(tensors[i]) & 0xF) == 0); + } + + return ggml_cuda_ar_allreduce(comm_ctx->ar_pipeline, comm_ctx->backends.data(), tensors); +} + +// --------------------------------------------------------------------------- +// Per-call dispatch -- three variants, one per backend. Each is set as +// comm_ctx->try_allreduce by the matching init step. Per-call failure +// returns false; the meta backend's generic implementation handles that call. +// --------------------------------------------------------------------------- + +#ifdef GGML_USE_NCCL +static bool ggml_backend_cuda_comm_try_allreduce_nccl( + ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) { + return ggml_backend_cuda_comm_allreduce_nccl(comm_ctx, tensors); +} +#endif // GGML_USE_NCCL + +static bool ggml_backend_cuda_comm_try_allreduce_internal( + ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) { + return ggml_backend_cuda_comm_allreduce_internal(comm_ctx, tensors); +} + +static bool ggml_backend_cuda_comm_try_allreduce_butterfly( + ggml_backend_cuda_comm_context *, struct ggml_tensor **) { return false; +} + +static void ggml_backend_cuda_comm_free(void * comm_ctx_v) { + if (comm_ctx_v == nullptr) { + return; + } + delete static_cast(comm_ctx_v); +} + +// --------------------------------------------------------------------------- +// Init -- chained nccl -> internal -> none. Each step tries to bring up its +// resource; on failure it warns and recurses into the next step. +// --------------------------------------------------------------------------- +static void ggml_backend_cuda_comm_init_none(ggml_backend_cuda_comm_context * ret) { + ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_butterfly; +} + +static void ggml_backend_cuda_comm_init_internal(ggml_backend_cuda_comm_context * ret) { + ret->ar_pipeline = ggml_cuda_ar_pipeline_init(ret->dev_ids.data(), ret->dev_ids.size()); + if (ret->ar_pipeline) { + ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_internal; + return; + } + + // Clear sticky CUDA error from the failed init. + (void) cudaGetLastError(); + GGML_LOG_WARN("internal AllReduce init failed (n_devices != 2?); " + "falling back to meta-backend butterfly\n"); + ggml_backend_cuda_comm_init_none(ret); +} + +static void ggml_backend_cuda_comm_init_nccl(ggml_backend_cuda_comm_context * ret) { +#ifdef GGML_USE_NCCL + const size_t n = ret->dev_ids.size(); + ret->comms.resize(n); + ncclResult_t rc = ncclCommInitAll(ret->comms.data(), (int) n, ret->dev_ids.data()); + if (rc == ncclSuccess) { + ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_nccl; + return; + } + + ret->comms.clear(); + GGML_LOG_WARN("NCCL init failed (%s); falling back to internal AllReduce\n", + ncclGetErrorString(rc)); +#else // GGML_USE_NCCL +#ifndef GGML_USE_HIP + GGML_LOG_WARN("NCCL not compiled in; falling back to internal AllReduce. " + "Recompile with -DGGML_CUDA_NCCL=ON for best multi-GPU performance.\n"); +#endif // !GGML_USE_HIP #endif // GGML_USE_NCCL + + ggml_backend_cuda_comm_init_internal(ret); +} + +// Top-level init. Picks one of the three init paths based on +// GGML_CUDA_ALLREDUCE (or the platform default) and lets the chain handle +// any fallback. Unrecognised env values warn and fall through to the +// platform default. +static void * ggml_backend_cuda_comm_init(ggml_backend_t * backends, size_t n_backends) { + for (size_t i = 0; i < n_backends; i++) { + if (!ggml_backend_is_cuda(backends[i])) { + return nullptr; + } + } + + auto * ret = new ggml_backend_cuda_comm_context; + ret->backends.assign(backends, backends + n_backends); + ret->dev_ids.reserve(n_backends); + for (size_t i = 0; i < n_backends; i++) { + ret->dev_ids.push_back(static_cast(backends[i]->context)->device); + } + + const char * env = getenv("GGML_CUDA_ALLREDUCE"); + if (!env) { + // Platform default: Linux uses NCCL, otherwise (generally Windows) internal +#if defined(__linux__) + ggml_backend_cuda_comm_init_nccl(ret); +#else + ggml_backend_cuda_comm_init_internal(ret); +#endif // defined(__linux__) + } else { + std::string env_str(env); + if (env_str == "nccl") { + ggml_backend_cuda_comm_init_nccl(ret); + } else if (env_str == "internal") { + ggml_backend_cuda_comm_init_internal(ret); + } else if (env_str == "none") { + ggml_backend_cuda_comm_init_none(ret); + } else { + GGML_LOG_WARN("unknown GGML_CUDA_ALLREDUCE value: %s\n", env); + ggml_backend_cuda_comm_init_none(ret); + } + } + + return ret; +} + +// Top-level dispatch -- calls the function pointer chosen by comm_init. +// Returns false to let the meta-backend's butterfly run. +static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) { + if (comm_ctx_v == nullptr) { + return false; + } + auto * comm_ctx = static_cast(comm_ctx_v); + return comm_ctx->try_allreduce(comm_ctx, tensors); } ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) { From cf6e65bc594b4bc2648d2b3f8e98ffa42c01034e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 May 2026 16:57:19 +0300 Subject: [PATCH 583/831] ggml : bump version to 0.11.1 (ggml/1484) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 8dd4d64063f..672b37dffc3 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 11) -set(GGML_VERSION_PATCH 0) +set(GGML_VERSION_PATCH 1) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 4730e765525b718133b5d63d664499ba33b7cd5a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 May 2026 17:27:59 +0300 Subject: [PATCH 584/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 812e721a8c5..15685a0718f 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -19eac6f0edaf285506eb6228d31bb9caeda9aba1 +628249b398293fc8d2fa81a449ae2920a02c6523 From 54ecc9dba43ccf99dcbcb6e1ae3396806e19bf4b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 May 2026 17:34:06 +0300 Subject: [PATCH 585/831] talk-llama : sync llama.cpp --- examples/talk-llama/llama-arch.cpp | 1 + examples/talk-llama/llama-arch.h | 1 + examples/talk-llama/llama-context.cpp | 347 +- examples/talk-llama/llama-context.h | 19 + examples/talk-llama/llama-graph.cpp | 7 +- examples/talk-llama/llama-hparams.h | 2 + examples/talk-llama/llama-io.cpp | 9 +- examples/talk-llama/llama-io.h | 6 +- examples/talk-llama/llama-kv-cache.cpp | 54 +- .../talk-llama/llama-memory-recurrent.cpp | 42 +- examples/talk-llama/llama-model-saver.cpp | 1 + examples/talk-llama/llama-model.cpp | 9207 ++--------------- examples/talk-llama/llama-model.h | 89 +- examples/talk-llama/llama-quant.cpp | 23 +- examples/talk-llama/llama-vocab.cpp | 13 + examples/talk-llama/llama-vocab.h | 1 + examples/talk-llama/llama.cpp | 246 +- examples/talk-llama/llama.h | 3 + examples/talk-llama/models/afmoe.cpp | 108 +- examples/talk-llama/models/apertus.cpp | 58 +- examples/talk-llama/models/arcee.cpp | 47 +- examples/talk-llama/models/arctic.cpp | 55 +- examples/talk-llama/models/arwkv7.cpp | 118 +- examples/talk-llama/models/baichuan.cpp | 45 +- examples/talk-llama/models/bailingmoe.cpp | 61 +- examples/talk-llama/models/bailingmoe2.cpp | 96 +- examples/talk-llama/models/bert.cpp | 79 +- examples/talk-llama/models/bitnet.cpp | 49 +- examples/talk-llama/models/bloom.cpp | 64 +- examples/talk-llama/models/chameleon.cpp | 52 +- examples/talk-llama/models/chatglm.cpp | 55 +- examples/talk-llama/models/codeshell.cpp | 51 +- examples/talk-llama/models/cogvlm.cpp | 51 +- .../models/{cohere2-iswa.cpp => cohere2.cpp} | 49 +- examples/talk-llama/models/command-r.cpp | 42 +- examples/talk-llama/models/dbrx.cpp | 46 +- examples/talk-llama/models/deci.cpp | 78 +- examples/talk-llama/models/deepseek.cpp | 73 +- examples/talk-llama/models/deepseek2.cpp | 145 +- examples/talk-llama/models/deepseek2ocr.cpp | 82 + examples/talk-llama/models/dots1.cpp | 72 +- examples/talk-llama/models/dream.cpp | 50 +- examples/talk-llama/models/ernie4-5-moe.cpp | 6 +- examples/talk-llama/models/ernie4-5.cpp | 75 +- examples/talk-llama/models/eurobert.cpp | 37 +- examples/talk-llama/models/exaone-moe.cpp | 113 +- examples/talk-llama/models/exaone.cpp | 45 +- examples/talk-llama/models/exaone4.cpp | 70 +- examples/talk-llama/models/falcon-h1.cpp | 111 +- examples/talk-llama/models/falcon.cpp | 49 +- .../talk-llama/models/gemma-embedding.cpp | 74 +- examples/talk-llama/models/gemma.cpp | 40 +- .../models/{gemma2-iswa.cpp => gemma2.cpp} | 61 +- examples/talk-llama/models/gemma3.cpp | 86 +- .../models/{gemma3n-iswa.cpp => gemma3n.cpp} | 99 +- .../models/{gemma4-iswa.cpp => gemma4.cpp} | 149 +- examples/talk-llama/models/glm-dsa.cpp | 155 + examples/talk-llama/models/glm4-moe.cpp | 135 +- examples/talk-llama/models/glm4.cpp | 74 +- examples/talk-llama/models/gpt2.cpp | 56 +- examples/talk-llama/models/gptneox.cpp | 85 +- examples/talk-llama/models/granite-hybrid.cpp | 137 +- examples/talk-llama/models/granite-moe.cpp | 89 + examples/talk-llama/models/granite.cpp | 93 +- examples/talk-llama/models/grok.cpp | 85 +- examples/talk-llama/models/grovemoe.cpp | 66 +- examples/talk-llama/models/hunyuan-dense.cpp | 132 +- examples/talk-llama/models/hunyuan-moe.cpp | 55 +- examples/talk-llama/models/hunyuan-vl.cpp | 189 + examples/talk-llama/models/internlm2.cpp | 39 +- examples/talk-llama/models/jais.cpp | 54 +- examples/talk-llama/models/jais2.cpp | 57 +- examples/talk-llama/models/jamba.cpp | 107 +- examples/talk-llama/models/jina-bert-v2.cpp | 66 + examples/talk-llama/models/jina-bert-v3.cpp | 69 + examples/talk-llama/models/kimi-linear.cpp | 172 +- examples/talk-llama/models/lfm2.cpp | 92 +- examples/talk-llama/models/lfm2moe.cpp | 85 + examples/talk-llama/models/llada-moe.cpp | 52 +- examples/talk-llama/models/llada.cpp | 68 +- examples/talk-llama/models/llama-embed.cpp | 6 + examples/talk-llama/models/llama.cpp | 101 +- examples/talk-llama/models/llama4.cpp | 108 +- examples/talk-llama/models/maincoder.cpp | 45 +- examples/talk-llama/models/mamba.cpp | 87 +- examples/talk-llama/models/mamba2.cpp | 87 + examples/talk-llama/models/mimo2-iswa.cpp | 129 - examples/talk-llama/models/mimo2.cpp | 240 + examples/talk-llama/models/minicpm.cpp | 89 + examples/talk-llama/models/minicpm3.cpp | 62 +- examples/talk-llama/models/minimax-m2.cpp | 46 +- examples/talk-llama/models/mistral3.cpp | 92 +- examples/talk-llama/models/mistral4.cpp | 6 + examples/talk-llama/models/models.h | 1866 +++- examples/talk-llama/models/modern-bert.cpp | 65 +- examples/talk-llama/models/mpt.cpp | 66 +- examples/talk-llama/models/nemotron-h-moe.cpp | 6 + examples/talk-llama/models/nemotron-h.cpp | 127 +- examples/talk-llama/models/nemotron.cpp | 48 +- examples/talk-llama/models/neo-bert.cpp | 42 +- examples/talk-llama/models/nomic-bert-moe.cpp | 72 + examples/talk-llama/models/nomic-bert.cpp | 72 + examples/talk-llama/models/olmo.cpp | 42 +- examples/talk-llama/models/olmo2.cpp | 67 +- examples/talk-llama/models/olmoe.cpp | 51 +- .../{openai-moe-iswa.cpp => openai-moe.cpp} | 63 +- examples/talk-llama/models/openelm.cpp | 49 +- examples/talk-llama/models/orion.cpp | 42 +- examples/talk-llama/models/paddleocr.cpp | 6 +- .../{pangu-embedded.cpp => pangu-embed.cpp} | 56 +- examples/talk-llama/models/phi2.cpp | 46 +- examples/talk-llama/models/phi3.cpp | 70 +- examples/talk-llama/models/phimoe.cpp | 55 + examples/talk-llama/models/plamo.cpp | 38 +- examples/talk-llama/models/plamo2.cpp | 109 +- examples/talk-llama/models/plamo3.cpp | 73 +- examples/talk-llama/models/plm.cpp | 46 +- examples/talk-llama/models/qwen.cpp | 42 +- examples/talk-llama/models/qwen2.cpp | 51 +- examples/talk-llama/models/qwen2moe.cpp | 63 +- examples/talk-llama/models/qwen2vl.cpp | 41 +- examples/talk-llama/models/qwen3.cpp | 51 +- examples/talk-llama/models/qwen35.cpp | 102 +- examples/talk-llama/models/qwen35moe.cpp | 115 +- examples/talk-llama/models/qwen3moe.cpp | 61 +- examples/talk-llama/models/qwen3next.cpp | 119 +- examples/talk-llama/models/qwen3vl.cpp | 52 +- .../{qwen3vl-moe.cpp => qwen3vlmoe.cpp} | 63 +- examples/talk-llama/models/refact.cpp | 77 +- examples/talk-llama/models/rnd1.cpp | 62 +- examples/talk-llama/models/rwkv6.cpp | 93 +- examples/talk-llama/models/rwkv6qwen2.cpp | 83 +- examples/talk-llama/models/rwkv7.cpp | 123 +- examples/talk-llama/models/seed-oss.cpp | 47 +- examples/talk-llama/models/smallthinker.cpp | 79 +- examples/talk-llama/models/smollm3.cpp | 45 +- examples/talk-llama/models/stablelm.cpp | 50 +- examples/talk-llama/models/starcoder.cpp | 58 +- examples/talk-llama/models/starcoder2.cpp | 57 +- .../models/{step35-iswa.cpp => step35.cpp} | 104 +- examples/talk-llama/models/t5.cpp | 122 +- examples/talk-llama/models/t5encoder.cpp | 43 +- .../talk-llama/models/wavtokenizer-dec.cpp | 117 +- examples/talk-llama/models/xverse.cpp | 39 +- 144 files changed, 12061 insertions(+), 9097 deletions(-) rename examples/talk-llama/models/{cohere2-iswa.cpp => cohere2.cpp} (60%) create mode 100644 examples/talk-llama/models/deepseek2ocr.cpp rename examples/talk-llama/models/{gemma2-iswa.cpp => gemma2.cpp} (53%) rename examples/talk-llama/models/{gemma3n-iswa.cpp => gemma3n.cpp} (76%) rename examples/talk-llama/models/{gemma4-iswa.cpp => gemma4.cpp} (62%) create mode 100644 examples/talk-llama/models/glm-dsa.cpp create mode 100644 examples/talk-llama/models/granite-moe.cpp create mode 100644 examples/talk-llama/models/hunyuan-vl.cpp create mode 100644 examples/talk-llama/models/jina-bert-v2.cpp create mode 100644 examples/talk-llama/models/jina-bert-v3.cpp create mode 100644 examples/talk-llama/models/lfm2moe.cpp create mode 100644 examples/talk-llama/models/llama-embed.cpp create mode 100644 examples/talk-llama/models/mamba2.cpp delete mode 100644 examples/talk-llama/models/mimo2-iswa.cpp create mode 100644 examples/talk-llama/models/mimo2.cpp create mode 100644 examples/talk-llama/models/minicpm.cpp create mode 100644 examples/talk-llama/models/mistral4.cpp create mode 100644 examples/talk-llama/models/nemotron-h-moe.cpp create mode 100644 examples/talk-llama/models/nomic-bert-moe.cpp create mode 100644 examples/talk-llama/models/nomic-bert.cpp rename examples/talk-llama/models/{openai-moe-iswa.cpp => openai-moe.cpp} (51%) rename examples/talk-llama/models/{pangu-embedded.cpp => pangu-embed.cpp} (53%) create mode 100644 examples/talk-llama/models/phimoe.cpp rename examples/talk-llama/models/{qwen3vl-moe.cpp => qwen3vlmoe.cpp} (57%) rename examples/talk-llama/models/{step35-iswa.cpp => step35.cpp} (52%) diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 633a66fc665..59dde99e362 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -232,6 +232,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, "%s.attention.sliding_window_pattern" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, + { LLM_KV_ATTENTION_VALUE_SCALE, "%s.attention.value_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 8f335f5c7b3..e37d548c98e 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -236,6 +236,7 @@ enum llm_kv { LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, + LLM_KV_ATTENTION_VALUE_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 8126249e143..71a59395eb2 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -2230,13 +2230,17 @@ llm_graph_cb llama_context::graph_get_cb() const { class llama_io_write_dummy : public llama_io_write_i { public: - llama_io_write_dummy() = default; + llama_io_write_dummy(bool skip_tensors) : skip_tensors(skip_tensors) {} void write(const void * /* src */, size_t size) override { size_written += size; } - void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { + void write_tensor(ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { + if (skip_tensors) { + return; + } + size_written += size; } @@ -2245,14 +2249,23 @@ class llama_io_write_dummy : public llama_io_write_i { } private: + const bool skip_tensors; + size_t size_written = 0; }; -class llama_io_write_buffer : public llama_io_write_i { +class llama_io_write_host : public llama_io_write_i { public: - llama_io_write_buffer( + llama_io_write_host( uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + ~llama_io_write_host() { + // TODO: add backend support to batch tensor_get? or some other way to speed this up + for (const auto & winfo : winfos) { + ggml_backend_tensor_get(winfo.tensor, winfo.ptr, winfo.offset, winfo.size); + } + } + void write(const void * src, size_t size) override { if (size > buf_size) { throw std::runtime_error("unexpectedly reached end of buffer"); @@ -2263,11 +2276,14 @@ class llama_io_write_buffer : public llama_io_write_i { buf_size -= size; } - void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { + void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { if (size > buf_size) { throw std::runtime_error("unexpectedly reached end of buffer"); } - ggml_backend_tensor_get(tensor, ptr, offset, size); + + // save the write for later during destruction + winfos.push_back({tensor, ptr, size, offset}); + ptr += size; size_written += size; buf_size -= size; @@ -2281,25 +2297,48 @@ class llama_io_write_buffer : public llama_io_write_i { uint8_t * ptr; size_t buf_size = 0; size_t size_written = 0; + + struct write_info { + ggml_tensor * tensor; + uint8_t * ptr; + size_t size; + size_t offset; + }; + std::vector winfos; }; -class llama_io_read_buffer : public llama_io_read_i { +class llama_io_read_host : public llama_io_read_i { public: - llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + llama_io_read_host(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} - const uint8_t * read(size_t size) override { - const uint8_t * base_ptr = ptr; + ~llama_io_read_host() { + // flush the reads + for (const auto & rinfo : rinfos) { + ggml_backend_tensor_set(rinfo.tensor, rinfo.ptr, rinfo.offset, rinfo.size); + } + } + + void read(void * dst, size_t size) override { if (size > buf_size) { throw std::runtime_error("unexpectedly reached end of buffer"); } + memcpy(dst, ptr, size); ptr += size; size_read += size; buf_size -= size; - return base_ptr; } - void read_to(void * dst, size_t size) override { - memcpy(dst, read(size), size); + void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + + // save for later during destruction + rinfos.push_back({tensor, ptr, size, offset}); + + ptr += size; + size_read += size; + buf_size -= size; } size_t n_bytes() override { @@ -2310,6 +2349,14 @@ class llama_io_read_buffer : public llama_io_read_i { const uint8_t * ptr; size_t buf_size = 0; size_t size_read = 0; + + struct read_info { + ggml_tensor * tensor; + const uint8_t * ptr; + size_t size; + size_t offset; + }; + std::vector rinfos; }; class llama_io_write_file : public llama_io_write_i { @@ -2321,7 +2368,7 @@ class llama_io_write_file : public llama_io_write_i { size_written += size; } - void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { + void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { temp_buffer.resize(size); ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); write(temp_buffer.data(), temp_buffer.size()); @@ -2341,15 +2388,15 @@ class llama_io_read_file : public llama_io_read_i { public: llama_io_read_file(llama_file * f) : file(f) {} - void read_to(void * dst, size_t size) override { + void read(void * dst, size_t size) override { file->read_raw(dst, size); size_read += size; } - const uint8_t * read(size_t size) override { + void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { temp_buffer.resize(size); - read_to(temp_buffer.data(), size); - return temp_buffer.data(); + read(temp_buffer.data(), size); + ggml_backend_tensor_set(tensor, temp_buffer.data(), offset, size); } size_t n_bytes() override { @@ -2362,8 +2409,212 @@ class llama_io_read_file : public llama_io_read_i { std::vector temp_buffer; }; +class llama_io_write_device : public llama_io_write_i { +public: + llama_io_write_device(uint8_t * p, size_t len, llama_memory_buffers & mbufs) : ptr(p), buf_size(len), mbufs(mbufs) { + } + + ~llama_io_write_device() { + llama_memory_buffers mbufs_new; + + for (const auto & winfo : winfos) { + auto * buft = ggml_backend_buffer_get_type(winfo.tensor->buffer); + + mbufs_new[buft].n_tensors++; + mbufs_new[buft].total_size += winfo.size; + } + + for (auto & [buft, mbuf] : mbufs_new) { + ggml_init_params params = { + /*.mem_size =*/ 2*mbuf.n_tensors*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + mbuf.ctx.reset(ggml_init(params)); + + mbuf.org.reserve(mbuf.n_tensors); + mbuf.cpy.reserve(mbuf.n_tensors); + } + + for (const auto & winfo : winfos) { + auto * buft = ggml_backend_buffer_get_type(winfo.tensor->buffer); + + const int64_t n = winfo.size/ggml_element_size(winfo.tensor); + + auto & mbuf = mbufs_new[buft]; + + mbuf.org.push_back(ggml_view_1d (mbuf.ctx.get(), winfo.tensor, n, winfo.offset)); + mbuf.cpy.push_back(ggml_new_tensor_1d(mbuf.ctx.get(), winfo.tensor->type, n)); + } + + for (auto & [buft, mbuf] : mbufs_new) { + auto & mbuf_cur = mbufs[buft]; + + bool need_alloc = false; + + need_alloc = need_alloc || (!mbuf_cur.buf); + need_alloc = need_alloc || (mbuf_cur.org.size() != mbuf.org.size()); + need_alloc = need_alloc || (mbuf_cur.total_size != mbuf.total_size); + + if (!need_alloc) { + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + auto * org0 = mbuf_cur.org[i]; + auto * org1 = mbuf.org[i]; + + if (!ggml_are_same_shape(org0, org1)) { + need_alloc = true; + break; + } + + if (org0->view_src != org1->view_src || org0->view_offs != org1->view_offs) { + need_alloc = true; + break; + } + } + } + + if (need_alloc) { + mbuf_cur = std::move(mbuf); + + mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); + + LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + } + + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + ggml_backend_tensor_copy(mbuf_cur.org[i], mbuf_cur.cpy[i]); + } + } + } + + void write(const void * src, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + memcpy(ptr, src, size); + ptr += size; + size_written += size; + buf_size -= size; + } + + void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { + // save the write for later during destruction + winfos.push_back({tensor, ptr, size, offset}); + } + + size_t n_bytes() override { + return size_written; + } + +private: + uint8_t * ptr; + size_t buf_size = 0; + size_t size_written = 0; + + struct write_info { + ggml_tensor * tensor; + uint8_t * ptr; + size_t size; + size_t offset; + }; + std::vector winfos; + + llama_memory_buffers & mbufs; +}; + +class llama_io_read_device : public llama_io_read_i { +public: + llama_io_read_device(const uint8_t * p, size_t len, const llama_memory_buffers & mbufs) : ptr(p), buf_size(len), mbufs(mbufs) { + } + + ~llama_io_read_device() { + llama_memory_buffers mbufs_new; + + for (const auto & rinfo : rinfos) { + auto * buft = ggml_backend_buffer_get_type(rinfo.tensor->buffer); + + mbufs_new[buft].n_tensors++; + mbufs_new[buft].total_size += rinfo.size; + } + + for (auto & [buft, mbuf] : mbufs_new) { + ggml_init_params params = { + /*.mem_size =*/ mbuf.n_tensors*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + mbuf.ctx.reset(ggml_init(params)); + + mbuf.org.reserve(mbuf.n_tensors); + } + + for (const auto & rinfo : rinfos) { + auto * buft = ggml_backend_buffer_get_type(rinfo.tensor->buffer); + + const int64_t n = rinfo.size/ggml_element_size(rinfo.tensor); + + auto & mbuf = mbufs_new[buft]; + + mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset)); + + auto & view = mbuf.org.back(); + view->buffer = rinfo.tensor->buffer; + } + + for (auto & [buft, mbuf] : mbufs_new) { + const auto & mbuf_cur = mbufs.at(buft); + + if (!mbuf_cur.buf || mbuf_cur.n_tensors != mbuf.n_tensors || mbuf_cur.total_size != mbuf.total_size) { + GGML_ABORT("%s: memory buffer mismatch\n", __func__); + } + + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + ggml_backend_tensor_copy(mbuf_cur.cpy[i], mbuf.org[i]); + } + } + + GGML_ASSERT(buf_size == 0); + } + + void read(void * dst, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + memcpy(dst, ptr, size); + ptr += size; + size_read += size; + buf_size -= size; + } + + void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { + // save for later during destruction + rinfos.push_back({tensor, ptr, size, offset}); + } + + size_t n_bytes() override { + return size_read; + } + +private: + const uint8_t * ptr; + size_t buf_size = 0; + size_t size_read = 0; + + struct read_info { + ggml_tensor * tensor; + const uint8_t * ptr; + size_t size; + size_t offset; + }; + std::vector rinfos; + + const llama_memory_buffers & mbufs; +}; + size_t llama_context::state_get_size() { - llama_io_write_dummy io; + llama_io_write_dummy io(false); try { return state_write_data(io); } catch (const std::exception & err) { @@ -2373,7 +2624,7 @@ size_t llama_context::state_get_size() { } size_t llama_context::state_get_data(uint8_t * dst, size_t size) { - llama_io_write_buffer io(dst, size); + llama_io_write_host io(dst, size); try { return state_write_data(io); } catch (const std::exception & err) { @@ -2383,7 +2634,7 @@ size_t llama_context::state_get_data(uint8_t * dst, size_t size) { } size_t llama_context::state_set_data(const uint8_t * src, size_t size) { - llama_io_read_buffer io(src, size); + llama_io_read_host io(src, size); try { return state_read_data(io); } catch (const std::exception & err) { @@ -2392,9 +2643,14 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) { } } +static constexpr uint32_t io_magic = 0xaf143cd8; + size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) { - llama_io_write_dummy io; + llama_io_write_dummy io(flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); try { + io.write(&io_magic, sizeof(io_magic)); + io.write(&seq_id, sizeof(seq_id)); + return state_seq_write_data(io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); @@ -2403,9 +2659,18 @@ size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_fl } size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) { - llama_io_write_buffer io(dst, size); + std::unique_ptr io; + if (flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) { + io = std::make_unique(dst, size, mem_storage[seq_id]); + } else { + io = std::make_unique(dst, size); + } + try { - return state_seq_write_data(io, seq_id, flags); + io->write(&io_magic, sizeof(io_magic)); + io->write(&seq_id, sizeof(seq_id)); + + return state_seq_write_data(*io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); return 0; @@ -2413,9 +2678,38 @@ size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, siz } size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) { - llama_io_read_buffer io(src, size); + std::unique_ptr io; + if (flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) { + // create a temporary io to read the magic and the src seq_id + io = std::make_unique(src, size); + + uint32_t magic_read; + io->read(&magic_read, sizeof(magic_read)); + if (io_magic != magic_read) { + throw std::runtime_error("wrong sequence state magic"); + } + + llama_seq_id seq_id_read; + io->read(&seq_id_read, sizeof(seq_id_read)); + + GGML_ASSERT(mem_storage.find(seq_id_read) != mem_storage.end()); + + io = std::make_unique(src, size, mem_storage[seq_id_read]); + } else { + io = std::make_unique(src, size); + } + try { - return state_seq_read_data(io, seq_id, flags); + uint32_t magic_read; + io->read(&magic_read, sizeof(magic_read)); + if (io_magic != magic_read) { + throw std::runtime_error("wrong sequence state magic"); + } + + llama_seq_id seq_id_read; + io->read(&seq_id_read, sizeof(seq_id_read)); + + return state_seq_read_data(*io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); return 0; @@ -3406,7 +3700,6 @@ size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t s return ctx->state_seq_get_data(seq_id, dst, size, flags); } - size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { ctx->synchronize(); diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 53c705eaffc..92d1b0cf95a 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -23,6 +23,21 @@ class llama_io_write_i; struct llama_memory_i; struct llama_memory_context_i; +// stores copy of the memory in device buffer. used for fast state save/load +struct llama_memory_buffer { + int n_tensors = 0; + size_t total_size = 0; + + ggml_backend_buffer_ptr buf; + + ggml_context_ptr ctx; + + std::vector org; + std::vector cpy; +}; + +using llama_memory_buffers = std::map; + struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -128,6 +143,7 @@ struct llama_context { size_t state_set_data(const uint8_t * src, size_t size); size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags); + size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags); size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags); @@ -328,6 +344,9 @@ struct llama_context { // host buffer for the model output (logits and embeddings) ggml_backend_buffer_ptr buf_output; + // keep copies of the per-sequence memory on the device + std::map mem_storage; + bool has_evaluated_once = false; // env: LLAMA_GRAPH_REUSE_DISABLE diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 2ff23f87cf4..fe155c92dea 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -65,8 +65,13 @@ static ggml_tensor * ggml_mul_mat_aux( ggml_tensor * res; - res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + if (!ggml_is_contiguous(cur)) { + res = ggml_cont_2d (ctx, cur, n, ggml_nelements(cur)/n); + } else { + res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + } res = ggml_mul_mat (ctx, rot, res); + ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD); res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); return res; diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index ac7f9ee8650..0160a89caa2 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -166,6 +166,8 @@ struct llama_hparams { float f_attn_out_scale = 0.0f; uint32_t attn_temp_length = 0; + float f_attn_value_scale = 0.0f; + bool causal_attn = true; bool use_alibi = false; bool attn_soft_cap = false; diff --git a/examples/talk-llama/llama-io.cpp b/examples/talk-llama/llama-io.cpp index 7ad70d16334..5ec4634943f 100644 --- a/examples/talk-llama/llama-io.cpp +++ b/examples/talk-llama/llama-io.cpp @@ -1,5 +1,7 @@ #include "llama-io.h" +#include + void llama_io_write_i::write_string(const std::string & str) { uint32_t str_size = str.size(); @@ -9,7 +11,10 @@ void llama_io_write_i::write_string(const std::string & str) { void llama_io_read_i::read_string(std::string & str) { uint32_t str_size; - read_to(&str_size, sizeof(str_size)); + read(&str_size, sizeof(str_size)); + + std::vector buf(str_size); + read(buf.data(), str_size); - str.assign((const char *) read(str_size), str_size); + str.assign(buf.data(), str_size); } diff --git a/examples/talk-llama/llama-io.h b/examples/talk-llama/llama-io.h index ce9216b83b1..f276af4fb96 100644 --- a/examples/talk-llama/llama-io.h +++ b/examples/talk-llama/llama-io.h @@ -12,7 +12,7 @@ class llama_io_write_i { virtual ~llama_io_write_i() = default; virtual void write(const void * src, size_t size) = 0; - virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0; + virtual void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) = 0; // bytes written so far virtual size_t n_bytes() = 0; @@ -25,8 +25,8 @@ class llama_io_read_i { llama_io_read_i() = default; virtual ~llama_io_read_i() = default; - virtual const uint8_t * read(size_t size) = 0; - virtual void read_to(void * dst, size_t size) = 0; + virtual void read(void * dst, size_t size) = 0; + virtual void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) = 0; // bytes read so far virtual size_t n_bytes() = 0; diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 09102f549c8..a49a055a630 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -67,6 +67,7 @@ static ggml_tensor * ggml_mul_mat_aux( res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); res = ggml_mul_mat (ctx, rot, res); + ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD); res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); return res; @@ -1900,14 +1901,14 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); uint32_t n_stream_cur; - io.read_to(&n_stream_cur, sizeof(n_stream_cur)); + io.read(&n_stream_cur, sizeof(n_stream_cur)); if (n_stream_cur != n_stream) { throw std::runtime_error("n_stream mismatch"); } for (uint32_t s = 0; s < n_stream; ++s) { uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); + io.read(&cell_count, sizeof(cell_count)); if (cell_count == 0) { continue; @@ -2082,8 +2083,8 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); if (n_seq_id != 1) { LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); @@ -2092,7 +2093,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 if (hparams.n_pos_per_embd() > 1) { llama_kv_cell_ext ext; - io.read_to(&ext, sizeof(ext)); + io.read(&ext, sizeof(ext)); ubatch.pos[i + ubatch.n_tokens] = ext.y; ubatch.pos[i + ubatch.n_tokens*2] = ext.x; @@ -2101,7 +2102,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 // read the sequence id, but directly discard it - we will use dest_seq_id instead { llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); + io.read(&seq_id, sizeof(seq_id)); } ubatch.pos[i] = pos; @@ -2143,20 +2144,20 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); cells.pos_set(i, pos); if (hparams.n_pos_per_embd() > 1) { llama_kv_cell_ext ext; - io.read_to(&ext, sizeof(ext)); + io.read(&ext, sizeof(ext)); cells.ext_set(i, ext); } for (uint32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); + io.read(&seq_id, sizeof(seq_id)); if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max); @@ -2189,8 +2190,8 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 uint32_t v_trans; uint32_t n_layer; - io.read_to(&v_trans, sizeof(v_trans)); - io.read_to(&n_layer, sizeof(n_layer)); + io.read(&v_trans, sizeof(v_trans)); + io.read(&n_layer, sizeof(n_layer)); if (n_layer != layers.size()) { LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); @@ -2217,7 +2218,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read type of key int32_t k_type_i_ref; - io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + io.read(&k_type_i_ref, sizeof(k_type_i_ref)); const int32_t k_type_i = (int32_t) k->type; if (k_type_i != k_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); @@ -2226,7 +2227,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read row size of key uint64_t k_size_row_ref; - io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + io.read(&k_size_row_ref, sizeof(k_size_row_ref)); const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); @@ -2236,13 +2237,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 if (cell_count) { if (sinfo.is_contiguous()) { // Fast path: contiguous cells, single memcpy - ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row); + io.read_tensor(k, sinfo.head() * k_size_row, cell_count * k_size_row); } else { // Slow path: scatter to non-contiguous positions - const void * src = io.read(cell_count * k_size_row); for (uint32_t i = 0; i < cell_count; ++i) { const size_t dst_offset = sinfo.idxs[0][i] * k_size_row; - ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row); + io.read_tensor(k, dst_offset, k_size_row); } } } @@ -2261,7 +2261,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read type of value int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + io.read(&v_type_i_ref, sizeof(v_type_i_ref)); const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); @@ -2270,7 +2270,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read row size of value uint64_t v_size_row_ref; - io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + io.read(&v_size_row_ref, sizeof(v_size_row_ref)); const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); if (v_size_row != v_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); @@ -2280,13 +2280,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 if (cell_count) { if (sinfo.is_contiguous()) { // Fast path: contiguous cells, single memcpy - ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row); + io.read_tensor(v, sinfo.head() * v_size_row, cell_count * v_size_row); } else { // Slow path: scatter to non-contiguous positions - const void * src = io.read(cell_count * v_size_row); for (uint32_t i = 0; i < cell_count; ++i) { const size_t dst_offset = sinfo.idxs[0][i] * v_size_row; - ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row); + io.read_tensor(v, dst_offset, v_size_row); } } } @@ -2305,7 +2304,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read type of value int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + io.read(&v_type_i_ref, sizeof(v_type_i_ref)); const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); @@ -2314,7 +2313,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read element size of value uint32_t v_size_el_ref; - io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + io.read(&v_size_el_ref, sizeof(v_size_el_ref)); const size_t v_size_el = ggml_type_size(v->type); if (v_size_el != v_size_el_ref) { LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); @@ -2323,7 +2322,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read GQA embedding size uint32_t n_embd_v_gqa_ref; - io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + io.read(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); if (n_embd_v_gqa != n_embd_v_gqa_ref) { LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); return false; @@ -2335,15 +2334,14 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t h = sinfo.head(); for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { const size_t dst_offset = (h + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + io.read_tensor(v, dst_offset, cell_count * v_size_el); } } else { // Slow path: scatter to non-contiguous positions for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const void * src = io.read(cell_count * v_size_el); for (uint32_t i = 0; i < cell_count; ++i) { const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el); + io.read_tensor(v, dst_offset, v_size_el); } } } diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index 9287fe45e96..c07f1d969cb 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -726,6 +726,10 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq cell_ranges.emplace_back(cell_range_begin, size); } + if (flags % LLAMA_STATE_SEQ_FLAGS_ON_DEVICE && cell_ranges.size() > 1) { + GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n"); + } + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count uint32_t cell_count_check = 0; for (const auto & range : cell_ranges) { @@ -743,7 +747,7 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i GGML_UNUSED(flags); uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); + io.read(&cell_count, sizeof(cell_count)); bool res = true; @@ -784,7 +788,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_layer = hparams.n_layer; io.write(&s_trans, sizeof(s_trans)); - io.write(&n_layer, sizeof(n_layer)); + io.write(&n_layer, sizeof(n_layer)); // Iterate and write all the R tensors first, each row is a cell // Get whole range at a time @@ -879,8 +883,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); if (n_seq_id != 0) { LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); @@ -920,14 +924,14 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); cell.pos = pos; for (uint32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); + io.read(&seq_id, sizeof(seq_id)); if (seq_id < 0 || (uint32_t) seq_id >= this->n_seq_max) { LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, this->n_seq_max); @@ -961,8 +965,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { uint32_t s_trans; uint32_t n_layer; - io.read_to(&s_trans, sizeof(s_trans)); - io.read_to(&n_layer, sizeof(n_layer)); + io.read(&s_trans, sizeof(s_trans)); + io.read(&n_layer, sizeof(n_layer)); if (n_layer != hparams.n_layer) { LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); @@ -984,7 +988,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read type of key int32_t r_type_i_ref; - io.read_to(&r_type_i_ref, sizeof(r_type_i_ref)); + io.read(&r_type_i_ref, sizeof(r_type_i_ref)); const int32_t r_type_i = (int32_t) r_l[il]->type; if (r_type_i != r_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il); @@ -993,7 +997,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of key uint64_t r_size_row_ref; - io.read_to(&r_size_row_ref, sizeof(r_size_row_ref)); + io.read(&r_size_row_ref, sizeof(r_size_row_ref)); const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); if (r_size_row != r_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il); @@ -1002,7 +1006,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the keys for the whole cell range - ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row); + io.read_tensor(r_l[il], head * r_size_row, cell_count * r_size_row); } } @@ -1013,7 +1017,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read type of value int32_t s_type_i_ref; - io.read_to(&s_type_i_ref, sizeof(s_type_i_ref)); + io.read(&s_type_i_ref, sizeof(s_type_i_ref)); const int32_t s_type_i = (int32_t)s_l[il]->type; if (s_type_i != s_type_i_ref) { @@ -1023,7 +1027,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of value uint64_t s_size_row_ref; - io.read_to(&s_size_row_ref, sizeof(s_size_row_ref)); + io.read(&s_size_row_ref, sizeof(s_size_row_ref)); const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); if (s_size_row != s_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il); @@ -1032,7 +1036,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the values for the whole cell range - ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row); + io.read_tensor(s_l[il], head * s_size_row, cell_count * s_size_row); } } } else { @@ -1045,7 +1049,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read type of value int32_t s_type_i_ref; - io.read_to(&s_type_i_ref, sizeof(s_type_i_ref)); + io.read(&s_type_i_ref, sizeof(s_type_i_ref)); const int32_t s_type_i = (int32_t)s_l[il]->type; if (s_type_i != s_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il); @@ -1054,7 +1058,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read element size of value uint32_t s_size_el_ref; - io.read_to(&s_size_el_ref, sizeof(s_size_el_ref)); + io.read(&s_size_el_ref, sizeof(s_size_el_ref)); const size_t s_size_el = ggml_type_size(s_l[il]->type); if (s_size_el != s_size_el_ref) { LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il); @@ -1063,7 +1067,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read state embedding size uint32_t n_embd_s_ref; - io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref)); + io.read(&n_embd_s_ref, sizeof(n_embd_s_ref)); if (n_embd_s != n_embd_s_ref) { LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il); return false; @@ -1073,7 +1077,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // For each row in the transposed matrix, read the values for the whole cell range for (uint32_t j = 0; j < n_embd_s; ++j) { const size_t dst_offset = (head + j * size) * s_size_el; - ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el); + io.read_tensor(s_l[il], dst_offset, cell_count * s_size_el); } } } diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index 26864c18e97..e83056557bf 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -268,6 +268,7 @@ void llama_model_saver::add_kv_from_model() { // add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, ???); add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); add_kv(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale); + add_kv(LLM_KV_ATTENTION_VALUE_SCALE, hparams.f_attn_value_scale); add_kv(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length); add_kv(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale); add_kv(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 9e2a13cbd43..ff30a2ae7a6 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -34,6 +34,285 @@ #include #include +static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params & params) { + switch (arch) { + case LLM_ARCH_LLAMA: + return new llama_model_llama(params); + case LLM_ARCH_LLAMA4: + return new llama_model_llama4(params); + case LLM_ARCH_LLAMA_EMBED: + return new llama_model_llama_embed(params); + case LLM_ARCH_MAINCODER: + return new llama_model_maincoder(params); + case LLM_ARCH_DECI: + return new llama_model_deci(params); + case LLM_ARCH_BAICHUAN: + return new llama_model_baichuan(params); + case LLM_ARCH_FALCON: + return new llama_model_falcon(params); + case LLM_ARCH_GROK: + return new llama_model_grok(params); + case LLM_ARCH_STARCODER: + return new llama_model_starcoder(params); + case LLM_ARCH_REFACT: + return new llama_model_refact(params); + case LLM_ARCH_BERT: + return new llama_model_bert(params); + case LLM_ARCH_JINA_BERT_V2: + return new llama_model_jina_bert_v2(params); + case LLM_ARCH_JINA_BERT_V3: + return new llama_model_jina_bert_v3(params); + case LLM_ARCH_NOMIC_BERT: + return new llama_model_nomic_bert(params); + case LLM_ARCH_NOMIC_BERT_MOE: + return new llama_model_nomic_bert_moe(params); + case LLM_ARCH_MODERN_BERT: + return new llama_model_modern_bert(params); + case LLM_ARCH_NEO_BERT: + return new llama_model_neo_bert(params); + case LLM_ARCH_EUROBERT: + return new llama_model_eurobert(params); + case LLM_ARCH_BLOOM: + return new llama_model_bloom(params); + case LLM_ARCH_MPT: + return new llama_model_mpt(params); + case LLM_ARCH_STABLELM: + return new llama_model_stablelm(params); + case LLM_ARCH_QWEN: + return new llama_model_qwen(params); + case LLM_ARCH_QWEN2: + return new llama_model_qwen2(params); + case LLM_ARCH_DREAM: + return new llama_model_dream(params); + case LLM_ARCH_LLADA: + return new llama_model_llada(params); + case LLM_ARCH_LLADA_MOE: + return new llama_model_llada_moe(params); + case LLM_ARCH_RND1: + return new llama_model_rnd1(params); + case LLM_ARCH_QWEN2VL: + return new llama_model_qwen2vl(params); + case LLM_ARCH_QWEN2MOE: + return new llama_model_qwen2moe(params); + case LLM_ARCH_QWEN3: + return new llama_model_qwen3(params); + case LLM_ARCH_QWEN3MOE: + return new llama_model_qwen3moe(params); + case LLM_ARCH_QWEN3VL: + return new llama_model_qwen3vl(params); + case LLM_ARCH_QWEN3VLMOE: + return new llama_model_qwen3vlmoe(params); + case LLM_ARCH_PHI2: + return new llama_model_phi2(params); + case LLM_ARCH_PHI3: + return new llama_model_phi3(params); + case LLM_ARCH_PHIMOE: + return new llama_model_phimoe(params); + case LLM_ARCH_PLAMO: + return new llama_model_plamo(params); + case LLM_ARCH_PLAMO2: + return new llama_model_plamo2(params); + case LLM_ARCH_PLAMO3: + return new llama_model_plamo3(params); + case LLM_ARCH_GPT2: + return new llama_model_gpt2(params); + case LLM_ARCH_CODESHELL: + return new llama_model_codeshell(params); + case LLM_ARCH_ORION: + return new llama_model_orion(params); + case LLM_ARCH_INTERNLM2: + return new llama_model_internlm2(params); + case LLM_ARCH_MINICPM3: + return new llama_model_minicpm3(params); + case LLM_ARCH_GEMMA: + return new llama_model_gemma(params); + case LLM_ARCH_GEMMA2: + return new llama_model_gemma2(params); + case LLM_ARCH_GEMMA3: + return new llama_model_gemma3(params); + case LLM_ARCH_GEMMA3N: + return new llama_model_gemma3n(params); + case LLM_ARCH_GEMMA4: + return new llama_model_gemma4(params); + case LLM_ARCH_GEMMA_EMBEDDING: + return new llama_model_gemma_embedding(params); + case LLM_ARCH_STARCODER2: + return new llama_model_starcoder2(params); + case LLM_ARCH_MAMBA: + return new llama_model_mamba(params); + case LLM_ARCH_MAMBA2: + return new llama_model_mamba2(params); + case LLM_ARCH_JAMBA: + return new llama_model_jamba(params); + case LLM_ARCH_XVERSE: + return new llama_model_xverse(params); + case LLM_ARCH_COMMAND_R: + return new llama_model_command_r(params); + case LLM_ARCH_COHERE2: + return new llama_model_cohere2(params); + case LLM_ARCH_DBRX: + return new llama_model_dbrx(params); + case LLM_ARCH_OLMO: + return new llama_model_olmo(params); + case LLM_ARCH_OLMO2: + return new llama_model_olmo2(params); + case LLM_ARCH_OLMOE: + return new llama_model_olmoe(params); + case LLM_ARCH_OPENELM: + return new llama_model_openelm(params); + case LLM_ARCH_GPTNEOX: + return new llama_model_gptneox(params); + case LLM_ARCH_ARCTIC: + return new llama_model_arctic(params); + case LLM_ARCH_DEEPSEEK: + return new llama_model_deepseek(params); + case LLM_ARCH_DEEPSEEK2: + return new llama_model_deepseek2(params); + case LLM_ARCH_DEEPSEEK2OCR: + return new llama_model_deepseek2ocr(params); + case LLM_ARCH_GLM_DSA: + return new llama_model_glm_dsa(params); + case LLM_ARCH_MISTRAL4: + return new llama_model_mistral4(params); + case LLM_ARCH_CHATGLM: + return new llama_model_chatglm(params); + case LLM_ARCH_GLM4: + return new llama_model_glm4(params); + case LLM_ARCH_GLM4_MOE: + return new llama_model_glm4_moe(params); + case LLM_ARCH_BITNET: + return new llama_model_bitnet(params); + case LLM_ARCH_T5: + return new llama_model_t5(params); + case LLM_ARCH_T5ENCODER: + return new llama_model_t5encoder(params); + case LLM_ARCH_JAIS: + return new llama_model_jais(params); + case LLM_ARCH_JAIS2: + return new llama_model_jais2(params); + case LLM_ARCH_NEMOTRON: + return new llama_model_nemotron(params); + case LLM_ARCH_NEMOTRON_H: + return new llama_model_nemotron_h(params); + case LLM_ARCH_NEMOTRON_H_MOE: + return new llama_model_nemotron_h_moe(params); + case LLM_ARCH_EXAONE: + return new llama_model_exaone(params); + case LLM_ARCH_EXAONE4: + return new llama_model_exaone4(params); + case LLM_ARCH_EXAONE_MOE: + return new llama_model_exaone_moe(params); + case LLM_ARCH_RWKV6: + return new llama_model_rwkv6(params); + case LLM_ARCH_RWKV6QWEN2: + return new llama_model_rwkv6qwen2(params); + case LLM_ARCH_RWKV7: + return new llama_model_rwkv7(params); + case LLM_ARCH_ARWKV7: + return new llama_model_arwkv7(params); + case LLM_ARCH_GRANITE: + return new llama_model_granite(params); + case LLM_ARCH_GRANITE_MOE: + return new llama_model_granite_moe(params); + case LLM_ARCH_MINICPM: + return new llama_model_minicpm(params); + case LLM_ARCH_GRANITE_HYBRID: + return new llama_model_granite_hybrid(params); + case LLM_ARCH_CHAMELEON: + return new llama_model_chameleon(params); + case LLM_ARCH_WAVTOKENIZER_DEC: + return new llama_model_wavtokenizer_dec(params); + case LLM_ARCH_PLM: + return new llama_model_plm(params); + case LLM_ARCH_BAILINGMOE: + return new llama_model_bailingmoe(params); + case LLM_ARCH_BAILINGMOE2: + return new llama_model_bailingmoe2(params); + case LLM_ARCH_SEED_OSS: + return new llama_model_seed_oss(params); + case LLM_ARCH_DOTS1: + return new llama_model_dots1(params); + case LLM_ARCH_ARCEE: + return new llama_model_arcee(params); + case LLM_ARCH_AFMOE: + return new llama_model_afmoe(params); + case LLM_ARCH_ERNIE4_5: + return new llama_model_ernie4_5(params); + case LLM_ARCH_ERNIE4_5_MOE: + return new llama_model_ernie4_5_moe(params); + case LLM_ARCH_PADDLEOCR: + return new llama_model_paddleocr(params); + case LLM_ARCH_HUNYUAN_MOE: + return new llama_model_hunyuan_moe(params); + case LLM_ARCH_HUNYUAN_VL: + return new llama_model_hunyuan_vl(params); + case LLM_ARCH_HUNYUAN_DENSE: + return new llama_model_hunyuan_dense(params); + case LLM_ARCH_SMOLLM3: + return new llama_model_smollm3(params); + case LLM_ARCH_OPENAI_MOE: + return new llama_model_openai_moe(params); + case LLM_ARCH_FALCON_H1: + return new llama_model_falcon_h1(params); + case LLM_ARCH_LFM2: + return new llama_model_lfm2(params); + case LLM_ARCH_LFM2MOE: + return new llama_model_lfm2moe(params); + case LLM_ARCH_SMALLTHINKER: + return new llama_model_smallthinker(params); + case LLM_ARCH_GROVEMOE: + return new llama_model_grovemoe(params); + case LLM_ARCH_APERTUS: + return new llama_model_apertus(params); + case LLM_ARCH_MINIMAX_M2: + return new llama_model_minimax_m2(params); + case LLM_ARCH_COGVLM: + return new llama_model_cogvlm(params); + case LLM_ARCH_PANGU_EMBED: + return new llama_model_pangu_embed(params); + case LLM_ARCH_QWEN3NEXT: + return new llama_model_qwen3next(params); + case LLM_ARCH_QWEN35: + return new llama_model_qwen35(params); + case LLM_ARCH_QWEN35MOE: + return new llama_model_qwen35moe(params); + case LLM_ARCH_MISTRAL3: + return new llama_model_mistral3(params); + case LLM_ARCH_MIMO2: + return new llama_model_mimo2(params); + case LLM_ARCH_KIMI_LINEAR: + return new llama_model_kimi_linear(params); + case LLM_ARCH_STEP35: + return new llama_model_step35(params); + default: + throw std::runtime_error(std::string("unsupported model architecture: '") + llm_arch_name(arch) + "'"); + } + +} + +llama_model * llama_model_create(llm_arch arch, const llama_model_params & params) { + llama_model * model = llama_model_mapping(arch, params); + + if (model != nullptr) { + model->arch = arch; + auto & devices = model->devices; + if (!devices.empty() && devices[0].is_meta && !llm_arch_supports_sm_tensor(arch)) { + throw std::runtime_error(std::string("LLAMA_SPLIT_MODE_TENSOR not implemented for architecture '") + llm_arch_name(arch) + "'"); + } + } + + return model; +} + +llama_model * llama_model_create(llama_model_loader & ml, const llama_model_params & params) { + llm_arch arch = ml.get_arch(); + if (arch == LLM_ARCH_UNKNOWN) { + throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); + } + + return llama_model_create(arch, params); +} + struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata) { const llama_meta_device_get_split_state_userdata * ud = (const llama_meta_device_get_split_state_userdata *) userdata; const llama_hparams & hparams = ud->model->hparams; @@ -688,22 +967,12 @@ llama_model::~llama_model() { } } -void llama_model::load_stats(llama_model_loader & ml) { +void llama_model_base::load_stats(llama_model_loader & ml) { pimpl->n_elements = ml.n_elements; pimpl->n_bytes = ml.n_bytes; } -void llama_model::load_arch(llama_model_loader & ml) { - arch = ml.get_arch(); - if (arch == LLM_ARCH_UNKNOWN) { - throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); - } - if (!devices.empty() && devices[0].is_meta && !llm_arch_supports_sm_tensor(arch)) { - throw std::runtime_error(std::string("LLAMA_SPLIT_MODE_TENSOR not implemented for architecture '") + llm_arch_name(arch) + "'"); - } -} - -void llama_model::load_hparams(llama_model_loader & ml) { +void llama_model_base::load_hparams(llama_model_loader & ml) { const gguf_context * ctx = ml.metadata; // get metadata as string @@ -862,8215 +1131,931 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa, false); } - // for differentiating model types - uint32_t n_vocab = 0; - ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); - // for classifier models ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false); if (!classifier_labels.empty()) { hparams.n_cls_out = classifier_labels.size(); } - // arch-specific KVs - switch (arch) { - case LLM_ARCH_LLAMA: - case LLM_ARCH_LLAMA_EMBED: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // per-arch hparams + load_arch_hparams(ml); - if (hparams.n_expert == 8) { - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_8x7B; break; - case 56: type = LLM_TYPE_8x22B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } else { - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_1B; break; // Llama 3.2 1B - case 22: type = LLM_TYPE_1B; break; - case 26: type = LLM_TYPE_3B; break; - case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B - case 30: type = LLM_TYPE_256M; break; // smoldocling 256M - // granite uses a vocab with len 49152 - case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break; - case 36: type = LLM_TYPE_8B; break; // granite - case 40: type = LLM_TYPE_13B; break; - case 48: type = LLM_TYPE_34B; break; - case 60: type = LLM_TYPE_30B; break; - case 80: type = hparams.n_head() == hparams.n_head_kv() ? LLM_TYPE_65B : LLM_TYPE_70B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } - } break; - case LLM_ARCH_LLAMA4: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa == 0) { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope - } else { - hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; - hparams.n_swa = 8192; - hparams.n_attn_temp_floor_scale = 8192; - hparams.f_attn_temp_scale = 0.1f; - hparams.f_attn_temp_offset = 1.0f; - uint32_t swa_period = 4; // pattern: 3 chunked - 1 full - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } + pimpl->n_bytes = ml.n_bytes; - switch (hparams.n_expert) { - case 0: { - // MobileLLM (no MoE) - switch (hparams.n_embd) { - case 2048: type = LLM_TYPE_140M; break; - case 4096: type = LLM_TYPE_360M; break; - case 6144: type = LLM_TYPE_950M; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case 16: type = LLM_TYPE_17B_16E; break; - case 128: type = LLM_TYPE_17B_128E; break; - default: type = LLM_TYPE_UNKNOWN; - } + pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name(); - hparams.use_kq_norm = type != LLM_TYPE_17B_128E; - } break; - case LLM_ARCH_ARCEE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (hparams.f_max_alibi_bias > 0.0f) { + hparams.use_alibi = true; + } - // Arcee uses the same structure as Llama - switch (hparams.n_layer) { - case 36: type = LLM_TYPE_4B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_AFMOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - - // Set up interleaved sliding window attention (ISWA) - // Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4) - if (hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - uint32_t swa_period = 4; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + hparams.rope_type = llama_model_rope_type(this); +} - // Default to sigmoid if not set - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; - } +void llama_model_base::load_vocab(llama_model_loader & ml) { + const auto kv = LLM_KV(arch); - switch (hparams.n_layer) { - case 56: type = LLM_TYPE_6B; break; - case 32: type = LLM_TYPE_26B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DECI: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 80: type = LLM_TYPE_70B; break; - case 162: type = LLM_TYPE_405B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MINICPM: - { - // Backward-compatible defaults for older MiniCPM GGUFs - hparams.f_embedding_scale = 12.0f; - hparams.f_residual_scale = 1.4f / sqrtf(float(hparams.n_layer)); - hparams.f_logit_scale = hparams.n_embd ? (256.0f / float(hparams.n_embd)) : 1.0f; + vocab.load(ml, kv); +} - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +bool llama_model_base::load_tensors(llama_model_loader & ml) { + const auto & split_mode = params.split_mode; + const auto & use_mlock = params.use_mlock; + const auto & tensor_split = params.tensor_split; - // Optional KV reads, override defaults if present in newer GGUF exports - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /*required=*/false); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /*required=*/false); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /*required=*/false); + const int n_layer = hparams.n_layer; + const int n_gpu_layers = this->n_gpu_layers(); - // MiniCPM uses rope by default, unlike Granite which uses it as a switch - hparams.rope_finetuned = true; + const bool use_mmap_buffer = true; - switch (hparams.n_layer) { - case 52: type = LLM_TYPE_1B; break; - case 40: type = LLM_TYPE_2B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MINICPM3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + this->ml = &ml; // to be used by create_tensor() and load_arch_tensors() - switch (hparams.n_layer) { - case 62: type = LLM_TYPE_4B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GROK: - { - // defaults for old GGUFs - hparams.yarn_beta_fast = 8.0f; - hparams.f_logit_scale = 0.5773502691896257f; - hparams.f_embedding_scale = 78.38367176906169f; - hparams.f_attn_out_scale = 0.08838834764831845f; - hparams.f_attn_logit_softcapping = 30.0f; - hparams.f_router_logit_softcapping = 30.0f; - // no final_logit_softcapping in grok-1 - hparams.f_final_logit_softcapping = 0.0f; - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); - ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale, false); - ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); - ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false); - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); - - switch (hparams.n_layer) { - case 64: type = LLM_TYPE_314B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_FALCON: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s, direct_io = %s)\n", + __func__, ml.use_mmap ? "true" : "false", ml.use_direct_io ? "true" : "false"); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 60: type = LLM_TYPE_40B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_BAICHUAN: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_13B; break; - default: type = LLM_TYPE_UNKNOWN; - } + // build a list of buffer types for the CPU and GPU devices + pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); + for (const auto & dev : devices) { + buft_list_t buft_list = make_gpu_buft_list(dev.dev, split_mode, tensor_split); + // add CPU buffer types as a fallback + buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); + pimpl->gpu_buft_list.emplace(dev.dev, std::move(buft_list)); + } - if (type == LLM_TYPE_13B) { - // TODO: become GGUF KV parameter - hparams.f_max_alibi_bias = 8.0f; - } - } break; - case LLM_ARCH_STARCODER: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 36: type = LLM_TYPE_3B; break; - case 42: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_15B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_REFACT: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_1B; break; - default: type = LLM_TYPE_UNKNOWN; - } + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } - // TODO: become GGUF KV parameter - hparams.f_max_alibi_bias = 8.0f; - } break; - case LLM_ARCH_BERT: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - - switch (hparams.n_layer) { - case 3: - type = LLM_TYPE_17M; break; // bge-micro - case 6: - type = LLM_TYPE_22M; break; // MiniLM-L6 - case 12: - switch (hparams.n_embd) { - case 384: type = LLM_TYPE_33M; break; // MiniLM-L12, bge-small - case 768: type = LLM_TYPE_109M; break; // bge-base - default: type = LLM_TYPE_UNKNOWN; - } break; - case 24: - type = LLM_TYPE_335M; break; // bge-large - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MODERN_BERT: - { - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - uint32_t swa_period = 3; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period, true); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + // calculate the split points + bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); + std::vector splits(n_devices()); + if (all_zero) { + // default split, by free memory + for (size_t i = 0; i < n_devices(); ++i) { + ggml_backend_dev_t dev = devices[i].dev; + size_t total; + size_t free; + ggml_backend_dev_memory(dev, &free, &total); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - - switch (hparams.n_layer) { - case 12: - type = LLM_TYPE_47M; break; // granite-embedding-small - case 22: - type = LLM_TYPE_149M; break; // modern-bert-base - case 28: - type = LLM_TYPE_395M; break; // modern-bert-large - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_JINA_BERT_V2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - hparams.f_max_alibi_bias = 8.0f; - - switch (hparams.n_layer) { - case 4: type = LLM_TYPE_33M; break; // jina-embeddings-small - case 12: type = LLM_TYPE_137M; break; // jina-embeddings-base - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_JINA_BERT_V3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - - switch (hparams.n_layer) { - case 24: - type = LLM_TYPE_558M; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_NOMIC_BERT: - case LLM_ARCH_NOMIC_BERT_MOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); - - if (hparams.n_layer == 12 && hparams.n_embd == 768) { - if (arch == LLM_ARCH_NOMIC_BERT) { - type = LLM_TYPE_137M; - } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { - type = LLM_TYPE_475M; - } - } - } break; - case LLM_ARCH_NEO_BERT: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - if (hparams.n_layer == 28) { - type = LLM_TYPE_250M; - } - } break; - case LLM_ARCH_EUROBERT: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - if (hparams.n_layer == 12) { - type = LLM_TYPE_SMALL; // 0.2B - } - } break; - case LLM_ARCH_BLOOM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 30: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_3B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - - // TODO: become GGUF KV parameter - hparams.f_max_alibi_bias = 8.0f; - } break; - case LLM_ARCH_MPT: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); - ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 48: type = LLM_TYPE_30B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_STABLELM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_3B; break; - case 40: type = LLM_TYPE_12B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_13B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN2VL: - { - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + // devices can return 0 bytes for free and total memory if they do not + // have any to report. in this case, we will use the host memory as a fallback + // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 + if (free == 0 && total == 0) { + ggml_backend_dev_memory(cpu_dev, &free, &total); } - // fall through - case LLM_ARCH_QWEN2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; - case 28: type = hparams.n_embd == 1536 ? LLM_TYPE_1_5B : LLM_TYPE_7B; break; - case 32: type = LLM_TYPE_7B; break; - case 36: type = LLM_TYPE_3B; break; - case 40: type = hparams.n_head() == 20 ? LLM_TYPE_4B : LLM_TYPE_13B; break; - case 48: type = LLM_TYPE_14B; break; - case 64: type = LLM_TYPE_32B; break; - case 80: type = LLM_TYPE_70B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DREAM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - // Dream models are primarily 7B with 28 layers - switch (hparams.n_layer) { - case 28: - type = LLM_TYPE_7B; - break; - default: - type = LLM_TYPE_UNKNOWN; - } - // Set non-causal attention for diffusion models - hparams.causal_attn = false; - } break; - case LLM_ARCH_LLADA: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - // LLaDA-8B has 32 layers, similar to LLaMA but for diffusion - switch (hparams.n_layer) { - case 32: - type = LLM_TYPE_8B; - break; - default: - type = LLM_TYPE_UNKNOWN; - } - // Set non-causal attention for diffusion models - hparams.causal_attn = false; - } break; - case LLM_ARCH_LLADA_MOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - // diffusion language model uses non-causal attention - hparams.causal_attn = false; - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_A1_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_RND1: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_30B_A3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - // Set non-causal attention for diffusion models - hparams.causal_attn = false; - } break; - case LLM_ARCH_QWEN2MOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_A2_7B; break; - case 28: type = LLM_TYPE_57B_A14B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; - case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; - case 40: type = LLM_TYPE_14B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MAINCODER: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_1B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3VL: - { - ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 28: type = LLM_TYPE_1_7B; break; - case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3MOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_30B_A3B; break; - case 94: type = LLM_TYPE_235B_A22B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3VLMOE: - { - ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_30B_A3B; break; - case 94: type = LLM_TYPE_235B_A22B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PHI2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + splits[i] = free; + } + } else { + std::copy(tensor_split, tensor_split + n_devices(), splits.begin()); + } - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PHI3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // sum and normalize the splits to get the split points + float split_sum = 0.0f; + for (size_t i = 0; i < n_devices(); ++i) { + split_sum += splits[i]; + splits[i] = split_sum; + } + for (size_t i = 0; i < n_devices(); ++i) { + splits[i] /= split_sum; + } - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_3B; break; - case 40: type = LLM_TYPE_14B; break; - default: type = LLM_TYPE_UNKNOWN; - } + const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0); + const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1); + auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { + const bool is_swa = il < int(hparams.n_layer) && hparams.is_swa(il); + if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa); + return {cpu_dev, &pimpl->cpu_buft_list}; + } + const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); + auto * dev = devices.at(layer_gpu).dev; + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(dev), is_swa); + return {dev, &pimpl->gpu_buft_list.at(dev)}; + }; - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + // assign the input layer + // there is very little benefit to offloading the input layer, so always keep it on the CPU + pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; - if (found_swa && hparams.n_swa > 0) { - LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n", - __func__, "https://github.com/ggml-org/llama.cpp/pull/13676"); + // assign the repeating layers to the devices according to the splits + pimpl->dev_layer.resize(n_layer); + for (int il = 0; il < n_layer; ++il) { + pimpl->dev_layer[il] = get_layer_buft_list(il); + } - // TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern` - hparams.swa_type = LLAMA_SWA_TYPE_NONE; + // assign the output layer + pimpl->dev_output = get_layer_buft_list(n_layer); - hparams.n_swa = 0; - hparams.set_swa_pattern(1); - } - } break; - case LLM_ARCH_PHIMOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_16x3_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PLAMO: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // create tensors for the weights + { + // TODO: move to a separate function + const auto tn = LLM_TN(arch); - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_13B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PLAMO2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; - // Load Mamba SSM parameters - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + if (n_expert > 0 && n_expert_used == 0) { + throw std::runtime_error("model has expert layers but no expert layers are used"); + } - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; - } + layers.resize(n_layer); - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_1B; break; - case 32: - if (hparams.n_embd == 2048) { - type = LLM_TYPE_2B; - } else if (hparams.n_embd == 4096) { - type = LLM_TYPE_8B; - } - break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PLAMO3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - uint32_t swa_period = 8; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + // call the per-model loading function + load_arch_tensors(ml); - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_2B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GPT2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 12: type = LLM_TYPE_SMALL; break; - case 24: type = LLM_TYPE_MEDIUM; break; - case 36: type = LLM_TYPE_LARGE; break; - case 48: type = LLM_TYPE_XL; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_CODESHELL: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 42: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_ORION: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + // generic pass: load optional per-tensor/per-expert ".scale" tensors (e.g. NVFP4 scale2) + // this avoids having to add scale loading to every architecture + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_14B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_INTERNLM2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 48: type = LLM_TYPE_20B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GEMMA: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // attention weight scales (per-tensor, shape {1}) + if (!layer.wq_s && layer.wq) { + layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wk_s && layer.wk) { + layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wv_s && layer.wv) { + layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wo_s && layer.wo) { + layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_s && layer.wqkv) { + layer.wqkv_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_gate_s && layer.wqkv_gate) { + layer.wqkv_gate_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } - switch (hparams.n_layer) { - case 18: type = LLM_TYPE_2B; break; - case 28: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GEMMA2: - { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa = 4096; // default value of gemma 2 - uint32_t swa_period = 2; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - hparams.attn_soft_cap = true; - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - - switch (hparams.n_layer) { - case 26: type = LLM_TYPE_2B; break; - case 42: type = LLM_TYPE_9B; break; - case 46: type = LLM_TYPE_27B; break; - default: type = LLM_TYPE_UNKNOWN; - } - - // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173 - hparams.f_attention_scale = type == LLM_TYPE_27B - ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) - : 1.0f / std::sqrt(float(hparams.n_embd_head_k())); - } break; - case LLM_ARCH_GEMMA3: - { - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - uint32_t swa_period = 6; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + // dense FFN weight scales (per-tensor, shape {1}) + if (!layer.ffn_gate_s && layer.ffn_gate) { + layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_s && layer.ffn_down) { + layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_s && layer.ffn_up) { + layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_shexp_s && layer.ffn_gate_shexp) { + layer.ffn_gate_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_shexp_s && layer.ffn_down_shexp) { + layer.ffn_down_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_shexp_s && layer.ffn_up_shexp) { + layer.ffn_up_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } - hparams.f_final_logit_softcapping = 0.0f; - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 18: type = LLM_TYPE_270M; break; - case 26: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_8B; break; // Rnj-1 - case 34: type = LLM_TYPE_4B; break; - case 48: type = LLM_TYPE_12B; break; - case 62: type = LLM_TYPE_27B; break; - default: type = LLM_TYPE_UNKNOWN; - } + // MoE expert weight scales (per-expert, shape {n_expert}) + if (!layer.ffn_gate_exps_s && layer.ffn_gate_exps) { + layer.ffn_gate_exps_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_exps_s && layer.ffn_down_exps) { + layer.ffn_down_exps_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_exps_s && layer.ffn_up_exps) { + layer.ffn_up_exps_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } - // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289 - hparams.f_attention_scale = type == LLM_TYPE_27B - ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) - : 1.0f / std::sqrt(float(hparams.n_embd_head_k())); - } break; - case LLM_ARCH_GEMMA3N: - { - uint32_t swa_period = 5; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(swa_period); - - hparams.n_layer_kv_from_start = 20; - hparams.f_attention_scale = 1.0f; - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 30: type = LLM_TYPE_E2B; break; - case 35: type = LLM_TYPE_E4B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GEMMA4: - { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); - - uint32_t n_kv_shared_layers = 0; - ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); - - hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t)n_kv_shared_layers; - hparams.f_attention_scale = 1.0f; // Gemma4 uses self.scaling = 1.0 (no pre-attn scaling) - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, hparams.n_embd_per_layer); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - - switch (hparams.n_layer) { - case 30: type = LLM_TYPE_26B_A4B; break; - case 35: type = LLM_TYPE_E2B; break; - case 42: type = LLM_TYPE_E4B; break; - case 60: type = LLM_TYPE_31B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GEMMA_EMBEDDING: - { - hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; - uint32_t swa_period = 6; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); + // recurrent / linear-attention weight scales (per-tensor, shape {1}) + if (!layer.ssm_in_s && layer.ssm_in) { + layer.ssm_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_out_s && layer.ssm_out) { + layer.ssm_out_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_alpha_s && layer.ssm_alpha) { + layer.ssm_alpha_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_beta_s && layer.ssm_beta) { + layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } - hparams.causal_attn = false; // embeddings do not use causal attention + // input scales + if (!layer.wq_in_s && layer.wq) { + layer.wq_in_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wk_in_s && layer.wk) { + layer.wk_in_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wv_in_s && layer.wv) { + layer.wv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wo_in_s && layer.wo) { + layer.wo_in_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_in_s && layer.wqkv) { + layer.wqkv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_gate_in_s && layer.wqkv_gate) { + layer.wqkv_gate_in_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_in_s && layer.ffn_gate) { + layer.ffn_gate_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_in_s && layer.ffn_down) { + layer.ffn_down_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_in_s && layer.ffn_up) { + layer.ffn_up_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_exps_in_s && layer.ffn_gate_exps) { + layer.ffn_gate_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_exps_in_s && layer.ffn_down_exps) { + layer.ffn_down_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_exps_in_s && layer.ffn_up_exps) { + layer.ffn_up_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_shexp_in_s && layer.ffn_gate_shexp) { + layer.ffn_gate_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_shexp_in_s && layer.ffn_down_shexp) { + layer.ffn_down_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_shexp_in_s && layer.ffn_up_shexp) { + layer.ffn_up_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_in_in_s && layer.ssm_in) { + layer.ssm_in_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_out_in_s && layer.ssm_out) { + layer.ssm_out_in_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_alpha_in_s && layer.ssm_alpha) { + layer.ssm_alpha_in_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_beta_in_s && layer.ssm_beta) { + layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + } + } - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.done_getting_tensors(); - //applied only if model converted with --sentence-transformers-dense-modules - ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); - ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false); - ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false); - ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false); + // populate tensors_by_name + for (auto & [_, ctx_ptr] : ml.ctx_map) { + for (auto * cur = ggml_get_first_tensor(ctx_ptr.get()); cur != NULL; cur = ggml_get_next_tensor(ctx_ptr.get(), cur)) { + tensors_by_name.emplace_back(ggml_get_name(cur), cur); + } + } - GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); - GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); + ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr); + pimpl->mappings.reserve(ml.mappings.size()); - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_0_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k())); + // create the backend buffers + std::vector> ctx_buf_maps; + ctx_buf_maps.reserve(ml.ctx_map.size()); - } break; - case LLM_ARCH_STARCODER2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 30: type = LLM_TYPE_3B; break; - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_15B; break; - case 52: type = LLM_TYPE_20B; break; // granite - case 88: type = LLM_TYPE_34B; break; // granite - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MAMBA: - { - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 24: - switch (hparams.n_embd) { - case 768: type = LLM_TYPE_SMALL; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 48: - switch (hparams.n_embd) { - case 1024: type = LLM_TYPE_MEDIUM; break; - case 1536: type = LLM_TYPE_LARGE; break; - case 2048: type = LLM_TYPE_XL; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 64: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MAMBA2: - { - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 24: - switch (hparams.n_embd) { - case 768: type = LLM_TYPE_SMALL; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 48: - switch (hparams.n_embd) { - case 1024: type = LLM_TYPE_MEDIUM; break; - case 1536: type = LLM_TYPE_LARGE; break; - case 2048: type = LLM_TYPE_XL; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 64: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_3B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_JAMBA: - { - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + // Ensure we have enough capacity for the maximum backend buffer we will potentially create + const size_t n_max_backend_buffer = ml.ctx_map.size() * ml.files.size(); + pimpl->ctxs_bufs.reserve(n_max_backend_buffer); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + for (auto & [buft, ctx_ptr] : ml.ctx_map) { + ggml_context * ctx = ctx_ptr.get(); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; - } + // skip contexts without tensors + if (ggml_get_first_tensor(ctx) == nullptr) { + continue; + } - switch (hparams.n_layer) { - // TODO: Jamba layers are a bit heterogeneous, so naming this is hard. - case 12: // 900M 8x???M - case 32: // 51B 16x?B - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_XVERSE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_13B; break; - case 80: type = LLM_TYPE_65B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_COMMAND_R: - { - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_35B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_COHERE2: - { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - uint32_t swa_period = 4; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DBRX: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + llama_buf_map buf_map; + buf_map.reserve(n_max_backend_buffer); - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_16x12B; break; - default: type = LLM_TYPE_UNKNOWN; + // check if it is possible to use buffer_from_host_ptr with this buffer type + ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); + if (!dev) { + // FIXME: workaround for CPU backend buft having a NULL device + dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); } - } break; - case LLM_ARCH_OLMO: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); - - switch (hparams.n_layer) { - case 22: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_7B; break; - case 80: type = LLM_TYPE_70B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_OLMO2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - uint32_t swa_period = 4; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = 1.0; // See olmo2.cpp - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + } + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr; + bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev); - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_13B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_SEED_OSS: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 64: type = LLM_TYPE_36B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_OLMOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_A1_7B; break; - default: type = LLM_TYPE_UNKNOWN; + std::vector bufs; + if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) { + GGML_ASSERT(!ml.no_alloc); + for (uint32_t idx = 0; idx < ml.files.size(); idx++) { + // only the mmap region containing the tensors in the model is mapped to the backend buffer + // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, + // then we could just use metal for all layers + // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size + void * addr = nullptr; + size_t first, last; // NOLINT + ml.get_mapping_range(&first, &last, &addr, idx, ctx); + if (first >= last) { + continue; } - } break; - case LLM_ARCH_OPENELM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_270M; break; - case 20: type = LLM_TYPE_450M; break; - case 28: type = LLM_TYPE_1B; break; - case 36: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; + const size_t max_size = ggml_get_max_tensor_size(ctx); + ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); + if (buf == nullptr) { + throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); } - } break; - case LLM_ARCH_GPTNEOX: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); - switch (hparams.n_layer) { - case 6: - switch (hparams.n_ff()) { - case 512: type = LLM_TYPE_14M; break; - case 2048: type = LLM_TYPE_70M; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 12: - switch (hparams.n_ff()) { - case 3072: type = LLM_TYPE_160M; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 16: - switch (hparams.n_ff()) { - case 8192: type = LLM_TYPE_1B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 24: - switch (hparams.n_ff()) { - case 4096: type = LLM_TYPE_410M; break; - case 8192: type = LLM_TYPE_1_4B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 32: - switch (hparams.n_ff()) { - case 10240: type = LLM_TYPE_2_8B; break; - case 16384: type = LLM_TYPE_6_9B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 36: - switch (hparams.n_ff()) { - case 20480: type = LLM_TYPE_12B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 44: - switch (hparams.n_ff()) { - case 24576: type = LLM_TYPE_20B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; + bufs.emplace_back(buf); + buf_map.emplace(idx, buf); + } + } else { + ggml_backend_buffer_t buf; + if (ml.no_alloc) { + buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + t->buffer = buf; // set dummy buffer for weights so that the backend scheduler won't try to allocate them } - } break; - case LLM_ARCH_ARCTIC: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + } else { + buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); // real buffer + } + if (buf == nullptr) { + throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); + } + if (use_mlock && ggml_backend_buffer_is_host(buf)) { + pimpl->mlock_bufs.emplace_back(new llama_mlock); + auto & mlock_buf = pimpl->mlock_bufs.back(); + mlock_buf->init (ggml_backend_buffer_get_base(buf)); + mlock_buf->grow_to(ggml_backend_buffer_get_size(buf)); + } + bufs.emplace_back(buf); + for (uint32_t idx = 0; idx < ml.files.size(); idx++) { + buf_map.emplace(idx, buf); + } + } - if (hparams.n_expert == 128) { - switch (hparams.n_layer) { - case 35: type = LLM_TYPE_10B_128x3_66B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } else { - type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DEEPSEEK: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - - switch (hparams.n_ff_exp) { - case 1408: type = LLM_TYPE_16B; break; - case 1792: type = LLM_TYPE_20B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DEEPSEEK2: - case LLM_ARCH_MISTRAL4: - { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B - const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); + for (auto & buf : bufs) { + // indicate that this buffer contains weights + // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight + ggml_backend_buffer_set_usage(buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + } - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - if (!is_lite) { - ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); - } - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - // for compatibility with existing DeepSeek V2 and V2.5 GGUFs - // that have no expert_gating_func model parameter set - if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) { - // GLM 4.7 Lite - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; - } else { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; - } - } + pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs)); - if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { - // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] - // cancel the factor from the convert script - hparams.rope_yarn_log_mul /= 0.1f; - } + ctx_buf_maps.emplace_back(ctx, buf_map); + } - // (optional) temperature tuning - used by mistral-large - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); // FIXME why not use temperature_length? + if (llama_supports_gpu_offload()) { + const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); - hparams.f_attn_temp_offset = 0.0f; + int n_repeating = n_gpu; + if (n_repeating > 0) { + LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__); + n_repeating--; + } + LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_repeating); - switch (hparams.n_layer) { - case 27: type = LLM_TYPE_16B; break; - case 47: type = LLM_TYPE_30B_A3B; break; - case 60: type = LLM_TYPE_236B; break; - case 61: type = LLM_TYPE_671B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DEEPSEEK2OCR: - { - // similar to deepseek2, but without MLA - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; - } + const int max_backend_supported_layers = hparams.n_layer + 1; + const int max_offloadable_layers = hparams.n_layer + 1; - switch (hparams.n_layer) { - case 12: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PLM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_1_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_CHATGLM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 28: { - if (hparams.n_head(0) == 16) { - type = LLM_TYPE_1_5B; - } else { - type = LLM_TYPE_6B; - } - } break; - case 40: { - if (hparams.n_head(0) == 24) { - type = LLM_TYPE_4B; - } else { - type = LLM_TYPE_9B; - } - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GLM4: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); - - // NextN/MTP parameters (GLM-OCR) - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - - switch (hparams.n_layer) { - case 17: type = LLM_TYPE_1B; break; // GLM-OCR - case 40: type = LLM_TYPE_9B; break; - case 61: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GLM4_MOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); - - // MoE parameters - ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); - ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - - // Expert gating function (GLM-4.5 uses sigmoid) - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; - } - - // NextN/MTP parameters - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); + } - switch (hparams.n_layer) { - case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) - case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open - case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GLM_DSA: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); - - // MoE parameters - ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); - ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - - // deepseek MLA parameters - ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - - // DSA parameters - ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); - ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); - ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); - - // Expert gating function (GLM-4.5 uses sigmoid) - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; - } + // print memory requirements per buffer type + for (auto & [_, bufs] : pimpl->ctxs_bufs) { + for (auto & buf: bufs) { + LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", + __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); + } + } - // NextN/MTP parameters - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + if (ml.no_alloc) { + return true; + } - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + // load tensor data + for (auto & [ctx, buf_map] : ctx_buf_maps) { + if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { + return false; + } + } - switch (hparams.n_layer) { - case 79: type = LLM_TYPE_744B_A40B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_BITNET: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (use_mmap_buffer) { + for (auto & mapping : ml.mappings) { + pimpl->mappings.emplace_back(std::move(mapping)); + } + } - switch (hparams.n_layer) { - case 26: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_T5: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + return true; +} - uint32_t dec_start_token_id; - if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) { - hparams.dec_start_token_id = dec_start_token_id; - } +ggml_tensor * llama_model_base::create_tensor(llama_model_loader & ml, const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) { + const buft_list_t * buft_list_layer = tn.bid == -1 ? nullptr : pimpl->dev_layer.at(tn.bid).buft_list; + return ml.create_tensor( + hparams, &pimpl->cpu_buft_list, pimpl->dev_input.buft_list, pimpl->dev_output.buft_list, buft_list_layer, + tn, ne, flags); +} - hparams.dec_n_layer = hparams.n_layer; - ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false); - - switch (hparams.n_layer) { - case 6: type = LLM_TYPE_60M; break; // t5-small - case 8: type = LLM_TYPE_80M; break; // flan-t5-small - case 12: - switch (hparams.n_ff()) { - case 3072: type = LLM_TYPE_220M; break; // t5-base - case 2048: type = LLM_TYPE_250M; break; // flan-t5-base - default: type = LLM_TYPE_UNKNOWN; - } break; - case 24: - switch (hparams.n_ff()) { - case 4096: type = LLM_TYPE_770M; break; // t5-large - case 2816: type = LLM_TYPE_780M; break; // flan-t5-large - case 16384: type = LLM_TYPE_3B; break; // t5-3b - case 5120: type = LLM_TYPE_3B; break; // flan-t5-xl - case 65536: type = LLM_TYPE_11B; break; // t5-11b - case 10240: type = LLM_TYPE_11B; break; // flan-t5-xxl - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_T5ENCODER: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); - type = LLM_TYPE_UNKNOWN; - } break; - case LLM_ARCH_JAIS: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1_3B; break; - case 40: type = LLM_TYPE_13B; break; - /* TODO: add variants */ - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_JAIS2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); +std::string llama_model::arch_name() const { + return llm_arch_name(arch); +} - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_8B; break; - case 68: type = LLM_TYPE_70B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_NEMOTRON: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_4B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_NEMOTRON_H: - case LLM_ARCH_NEMOTRON_H_MOE: - { - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - // A layer is recurrent IFF the n_head_kv value is set to 0 and - // the n_ff value is set to 0 - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); - } +std::string llama_model::type_name() const { + return llm_type_name(type); +} - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +std::string llama_model::desc() const { + return pimpl->desc_str; +} - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_MOE_LATENT_SIZE, hparams.moe_latent_size, false); +size_t llama_model::size() const { + return pimpl->n_bytes; +} - switch (hparams.n_layer) { - case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B - case 56: type = LLM_TYPE_9B; break; - case 88: type = LLM_TYPE_120B_A12B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_EXAONE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +size_t llama_model::n_tensors() const { + return tensors_by_name.size(); +} - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_EXAONE4: - { - if (hparams.n_layer == 64) { // 32B - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa = 4096; - uint32_t swa_period = 4; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } +size_t llama_model::n_devices() const { + return devices.size(); +} - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +const float * llama_model::tensor_split() const { + return params.tensor_split; +} - switch (hparams.n_layer) { - case 30: type = LLM_TYPE_1_2B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_EXAONE_MOE: - { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa = 128; - uint32_t swa_period = 4; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_30B_A3B; break; - case 48: - case 49: type = LLM_TYPE_235B_A22B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_RWKV6: - case LLM_ARCH_RWKV6QWEN2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); - ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); - ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); - ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); - ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); - ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1_6B; break; - case 32: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_3B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 61: type = LLM_TYPE_14B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_RWKV7: - case LLM_ARCH_ARWKV7: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); - ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); - ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); - ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); - ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); - ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); - ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - - switch (hparams.n_layer) { - case 12: - switch (hparams.n_embd) { - case 768: type = LLM_TYPE_190M; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 24: - switch (hparams.n_embd) { - case 1024: type = LLM_TYPE_450M; break; - case 2048: type = LLM_TYPE_1_5B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 28: - switch (hparams.n_embd) { - case 1536: type = LLM_TYPE_1_5B; break; - case 3584: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 32: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_2_9B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 61: - switch (hparams.n_embd) { - case 4096: type = LLM_TYPE_14B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GRANITE: - case LLM_ARCH_GRANITE_MOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, false); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); - ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, false); - - // Granite uses rope_finetuned as a switch for rope, so default to true - bool rope_finetuned = true; - ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); - hparams.rope_finetuned = rope_finetuned; - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_3B; break; - case 40: type = LLM_TYPE_3B; break; - // Add additional layer/vocab/etc checks here for other model sizes - default: type = LLM_TYPE_UNKNOWN; - } +uint32_t llama_model::n_gpu_layers() const { + return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1; +} - // For Granite MoE Shared - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); - } break; - case LLM_ARCH_GRANITE_HYBRID: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /* required */ false); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /* required */ false); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /* required */ false); - ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, /* required */ false); - - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - // Granite uses rope_finetuned as a switch for rope, so default to true - bool rope_finetuned = true; - ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); - hparams.rope_finetuned = rope_finetuned; - - // A layer is recurrent IFF the n_head_kv value is set to 0 - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; - } +llama_split_mode llama_model::split_mode() const { + return params.split_mode; +} - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +std::map llama_model::memory_breakdown() const { + std::map ret; + for (const auto & [ctx, bufs] : pimpl->ctxs_bufs) { + if (hparams.no_alloc) { + GGML_ASSERT(bufs.size() == 1); + ggml_backend_buffer_t buf = bufs[0].get(); + GGML_ASSERT(ggml_backend_buffer_get_base(buf) == nullptr); + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf); + ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft); + } else { + for (const auto & buf : bufs) { + // GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); + } + } + } + return ret; +} - switch (hparams.n_embd) { - case 768: type = LLM_TYPE_350M; break; - case 1536: type = (hparams.n_ff() == 512 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; - case 2048: case 2560: type = LLM_TYPE_3B; break; - case 4096: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } +uint64_t llama_model::n_elements() const { + return pimpl->n_elements; +} - // For Granite MoE Shared - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); - } break; - case LLM_ARCH_CHAMELEON: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default - ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm, false); - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 48: type = LLM_TYPE_34B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_WAVTOKENIZER_DEC: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); - ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); - } break; - case LLM_ARCH_BAILINGMOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - - switch (hparams.n_layer) { - case 28: type = LLM_TYPE_16B; break; - case 88: type = LLM_TYPE_290B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_BAILINGMOE2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - - switch (hparams.n_layer) { - case 20: type = LLM_TYPE_16B_A1B; break; - case 21: type = LLM_TYPE_16B_A1B; break; - case 32: type = LLM_TYPE_100B_A6B; break; - case 33: type = LLM_TYPE_100B_A6B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DOTS1: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - switch (hparams.n_layer) { - case 62: type = LLM_TYPE_142B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_ERNIE4_5: - case LLM_ARCH_ERNIE4_5_MOE: - case LLM_ARCH_PADDLEOCR: - { - // paddleocr need mrope_section - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - if (arch == LLM_ARCH_ERNIE4_5_MOE) { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - } +void llama_model::print_info() const { + const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train); - switch (hparams.n_layer) { - case 18: type = LLM_TYPE_0_3B; break; - case 28: type = LLM_TYPE_21B_A3B; break; - case 54: type = LLM_TYPE_300B_A47B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_FALCON_H1: - { - // Common parameters - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - // SSM parameters - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true); - - switch (hparams.n_layer) { - case 36: - type = LLM_TYPE_0_5B; break; - case 24: - type = LLM_TYPE_1_5B; break; - case 66: - type = LLM_TYPE_1B; break; - case 32: - type = LLM_TYPE_3B; break; - case 44: - type = LLM_TYPE_7B; break; - case 72: - type = LLM_TYPE_34B; break; - default: - type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_HUNYUAN_MOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + auto print_f = [](const std::function & f, uint32_t n) { + bool is_var = false; - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_A13B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_HUNYUAN_VL: - case LLM_ARCH_HUNYUAN_DENSE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); - - // XDRoPE / NTK-aware scaling: base = rope_theta * alpha^(dim / (dim - 2)) - if (hparams.rope_scaling_alpha > 0.0f) { - const int dim = hparams.n_embd_head_k(); - hparams.rope_freq_base_train = hparams.rope_freq_base_train - * powf(hparams.rope_scaling_alpha, (float)dim / (float)(dim - 2)); - } + std::vector v; + for (uint32_t i = 0; i < n; ++i) { + v.push_back(f(i)); + if (v[i] != v[0]) { + is_var = true; + } + } - switch (hparams.n_embd) { - case 1024: type = LLM_TYPE_0_5B; break; - case 2048: type = LLM_TYPE_1_8B; break; - case 3072: type = LLM_TYPE_4B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_SMOLLM3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - hparams.n_no_rope_layer_step = 4; + std::stringstream ss; - switch (hparams.n_layer) { - case 36: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_OPENAI_MOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - uint32_t swa_period = 2; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_20B; break; - case 36: type = LLM_TYPE_120B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_LFM2: - { - ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; - } - hparams.n_layer_dense_lead = hparams.n_layer; - switch (hparams.n_ff()) { - case 4608: type = LLM_TYPE_350M; break; - case 6912: type = LLM_TYPE_700M; break; - case 8192: type = LLM_TYPE_1_2B; break; - case 10752: type = LLM_TYPE_2_6B; break; - default: type = LLM_TYPE_UNKNOWN; - } - if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); is_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.swa_layers[il] = !hparams.recurrent_layer_arr[il]; - } - } - } break; - case LLM_ARCH_LFM2MOE: - { - ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + if (is_var) { + ss << "["; + for (uint32_t i = 0; i < n; ++i) { + ss << v[i]; + if (i < n - 1) { + ss << ", "; } + } + ss << "]"; + } else { + ss << v[0]; + } - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_8B_A1B; break; - case 40: type = LLM_TYPE_24B_A2B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_SMALLTHINKER: - { - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - - if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa = 4096; - uint32_t swa_period = 4; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period, true); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - hparams.n_no_rope_layer_step = hparams.n_layer; - } + return ss.str(); + }; - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + // hparams + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); + LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); + LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_4B; break; - case 52: type = LLM_TYPE_20B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GROVEMOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, hparams.n_ff_chexp, false); - ml.get_key(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); - ml.get_key(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_30B_A3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_APERTUS: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MINIMAX_M2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (!hparams.vocab_only) { + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full); + LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); + LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); + LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full); + LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full); + LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); + LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); + LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); + LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: f_attn_value_scale = %.4f\n", __func__, hparams.f_attn_value_scale); + LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); + LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); + LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); + LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); + LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); + LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); + LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); + LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); + LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); + LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa); + LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa); + LLAMA_LOG_INFO("%s: n_embd_head_k_swa = %u\n", __func__, hparams.n_embd_head_k_swa); + LLAMA_LOG_INFO("%s: n_embd_head_v_swa = %u\n", __func__, hparams.n_embd_head_v_swa); + LLAMA_LOG_INFO("%s: n_rot_swa = %u\n", __func__, hparams.n_rot_swa); + } + LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); + LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); + LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + // MRoPE (Multi-axis Rotary Position Embedding) sections + if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { + LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); + } + if (!classifier_labels.empty()) { + LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); - switch (hparams.n_layer) { - case 62: type = LLM_TYPE_230B_A10B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_COGVLM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_13B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PANGU_EMBED: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1 - case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1 - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3NEXT: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - // Load linear attention (gated delta net) parameters - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - // Mark recurrent layers (linear attention layers) - { - uint32_t full_attn_interval = 4; - ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); - } - } + size_t i = 0; + for (const auto & label : classifier_labels) { + LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); + } + } - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_80B_A3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN35: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - - // Load linear attention (gated delta net) parameters - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - // Mark recurrent layers (linear attention layers) - { - uint32_t full_attn_interval = 4; - ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); - } - } + if (arch == LLM_ARCH_MAMBA || + arch == LLM_ARCH_MAMBA2 || + arch == LLM_ARCH_JAMBA || + arch == LLM_ARCH_FALCON_H1 || + arch == LLM_ARCH_PLAMO2 || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_QWEN3NEXT || + arch == LLM_ARCH_QWEN35 || + arch == LLM_ARCH_QWEN35MOE || + arch == LLM_ARCH_NEMOTRON_H || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + } - switch (hparams.n_layer) { - case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; - case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; - case 64: type = LLM_TYPE_27B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN35MOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - - // Load linear attention (gated delta net) parameters - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - // Mark recurrent layers (linear attention layers) - { - uint32_t full_attn_interval = 4; - ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); - } - } + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); + if (pimpl->n_elements >= 1e12) { + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + } else if (pimpl->n_elements >= 1e9) { + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + } else if (pimpl->n_elements >= 1e6) { + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); + } else { + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); + } - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_35B_A3B; break; - case 48: type = LLM_TYPE_122B_A10B; break; - case 60: type = LLM_TYPE_397B_A17B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MISTRAL3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); + // general kv + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f); + if (arch == LLM_ARCH_DEEPSEEK) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + } - hparams.f_attn_temp_offset = 0.0f; + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - // TODO: maybe add n_attn_temp_floor_scale as a separate KV? - if (hparams.f_attn_temp_scale != 0.0f) { - hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn; - if (hparams.n_attn_temp_floor_scale == 0) { - throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling"); - } - } + if (arch == LLM_ARCH_QWEN2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - switch (hparams.n_layer) { - case 26: type = LLM_TYPE_3B; break; - case 34: type = LLM_TYPE_8B; break; - case 40: type = LLM_TYPE_14B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MIMO2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + if (arch == LLM_ARCH_BAILINGMOE) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + } - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_310B_A15B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_KIMI_LINEAR: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); - - // MLA qk_rope_head_dim (for reference) - // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192 - - // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) - // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent - } + if (arch == LLM_ARCH_BAILINGMOE2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + } - // MoE parameters - Kimi uses moe_intermediate_size = 1024 - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - switch (hparams.n_layer) { - case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_STEP35: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (arch == LLM_ARCH_GROVEMOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); + LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); + LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + } + } - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + vocab.print_info(); +} - // full_attention layer only use half of the RoPE dimensions - hparams.n_rot_full = hparams.n_rot_full / 2; +ggml_backend_dev_t llama_model::dev_layer(int il) const { + return pimpl->dev_layer.at(il).dev; +} - // MoE + SWA parameters - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); +ggml_backend_dev_t llama_model::dev_output() const { + return pimpl->dev_output.dev; +} - // Step35 uses sigmoid gating by default (if not set in GGUF) - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; - } +template +static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) { + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); - ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); + ggml_context_ptr ctx { ggml_init(params) }; + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); + } - switch (hparams.n_layer) { - case 45: type = LLM_TYPE_196B_A11B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - default: throw std::runtime_error("unsupported model architecture: " + arch_name()); + ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) }; + ggml_tensor * op_tensor = fn(ctx.get()); + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op_tensor->src[i] != nullptr) { + assert(op_tensor->src[i]->buffer == nullptr); + op_tensor->src[i]->buffer = buf.get(); + } } - pimpl->n_bytes = ml.n_bytes; + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); - pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name(); + return op_supported; +} - if (hparams.f_max_alibi_bias > 0.0f) { - hparams.use_alibi = true; +template +static ggml_backend_buffer_type_t select_buft(const buft_list_t & buft_list, const F & fn) { + for (const auto & cur : buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (buft_supported(cur_buft, cur_dev, fn)) { + return cur_buft; + } } - hparams.rope_type = llama_model_rope_type(this); + throw std::runtime_error(format("no suitable buffer type found")); } -void llama_model::load_vocab(llama_model_loader & ml) { - const auto kv = LLM_KV(arch); +ggml_backend_buffer_type_t llama_model::select_buft(int il) const { + return ::select_buft( + *pimpl->dev_layer.at(il).buft_list, + [&](ggml_context * ctx) { + ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); + ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); + return ggml_add(ctx, cur, layer_dir); + }); +} - vocab.load(ml, kv); +bool llama_model::has_tensor_overrides() const { + return pimpl->has_tensor_overrides; } -bool llama_model::load_tensors(llama_model_loader & ml) { - const auto & split_mode = params.split_mode; - const auto & use_mlock = params.use_mlock; - const auto & tensor_split = params.tensor_split; - - const int n_layer = hparams.n_layer; - const int n_gpu_layers = this->n_gpu_layers(); - - const bool use_mmap_buffer = true; - - LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s, direct_io = %s)\n", - __func__, ml.use_mmap ? "true" : "false", ml.use_direct_io ? "true" : "false"); - - // build a list of buffer types for the CPU and GPU devices - pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); - for (const auto & dev : devices) { - buft_list_t buft_list = make_gpu_buft_list(dev.dev, split_mode, tensor_split); - // add CPU buffer types as a fallback - buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); - pimpl->gpu_buft_list.emplace(dev.dev, std::move(buft_list)); - } - - ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpu_dev == nullptr) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } - - // calculate the split points - bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); - std::vector splits(n_devices()); - if (all_zero) { - // default split, by free memory - for (size_t i = 0; i < n_devices(); ++i) { - ggml_backend_dev_t dev = devices[i].dev; - size_t total; - size_t free; - ggml_backend_dev_memory(dev, &free, &total); - - // devices can return 0 bytes for free and total memory if they do not - // have any to report. in this case, we will use the host memory as a fallback - // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 - if (free == 0 && total == 0) { - ggml_backend_dev_memory(cpu_dev, &free, &total); - } - splits[i] = free; - } - } else { - std::copy(tensor_split, tensor_split + n_devices(), splits.begin()); - } - - // sum and normalize the splits to get the split points - float split_sum = 0.0f; - for (size_t i = 0; i < n_devices(); ++i) { - split_sum += splits[i]; - splits[i] = split_sum; - } - for (size_t i = 0; i < n_devices(); ++i) { - splits[i] /= split_sum; - } - - const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0); - const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1); - auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { - const bool is_swa = il < int(hparams.n_layer) && hparams.is_swa(il); - if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { - LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa); - return {cpu_dev, &pimpl->cpu_buft_list}; - } - const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); - auto * dev = devices.at(layer_gpu).dev; - LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(dev), is_swa); - return {dev, &pimpl->gpu_buft_list.at(dev)}; - }; - - // assign the input layer - // there is very little benefit to offloading the input layer, so always keep it on the CPU - pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; - - // assign the repeating layers to the devices according to the splits - pimpl->dev_layer.resize(n_layer); - for (int il = 0; il < n_layer; ++il) { - pimpl->dev_layer[il] = get_layer_buft_list(il); +const ggml_tensor * llama_model::get_tensor(const char * name) const { + auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(), + [name](const std::pair & it) { + return it.first == name; + }); + if (it == tensors_by_name.end()) { + return nullptr; } - // assign the output layer - pimpl->dev_output = get_layer_buft_list(n_layer); - - const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; - const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; - const auto TENSOR_SKIP = llama_model_loader::TENSOR_SKIP; - const auto TENSOR_SKIP_IF_VIRTUAL = llama_model_loader::TENSOR_SKIP_IF_VIRTUAL; - - // create tensors for the weights - { - // note: cast to int64_t since we will use these for the tensor dimensions - const int64_t n_head = hparams.n_head(); - const int64_t n_head_kv = hparams.n_head_kv(); - const int64_t n_embd = hparams.n_embd; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - const int64_t n_embd_head_k = hparams.n_embd_head_k(); - const int64_t n_embd_head_v = hparams.n_embd_head_v(); - const int64_t n_ff = hparams.n_ff(); - const int64_t n_embd_gqa = n_embd_v_gqa; - const int64_t n_vocab = vocab.n_tokens(); - const int64_t n_token_types = vocab.n_token_types(); - const int64_t n_rot = hparams.n_rot(); - const int64_t n_expert = hparams.n_expert; - const int64_t n_expert_used = hparams.n_expert_used; - const int64_t n_ctx_train = hparams.n_ctx_train; - - if (n_expert > 0 && hparams.n_expert_used == 0) { - throw std::runtime_error("model has expert layers but no expert layers are used"); - } - - auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) -> ggml_tensor * { - const buft_list_t * buft_list_layer = tn.bid == -1 ? nullptr : pimpl->dev_layer.at(tn.bid).buft_list; - return ml.create_tensor( - hparams, &pimpl->cpu_buft_list, pimpl->dev_input.buft_list, pimpl->dev_output.buft_list, buft_list_layer, - tn, ne, flags); - }; - - layers.resize(n_layer); - - // TODO: move to a separate function - const auto tn = LLM_TN(arch); - - // helper: try merged gate_up_exps first, fall back to separate gate and up - auto create_tensor_gate_up_exps = [&](llama_layer & layer, int bid, int64_t n_embd_, int64_t n_ff_, int64_t n_expert_, int flags) { - layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", bid), {n_embd_, n_ff_ * 2, n_expert_}, TENSOR_NOT_REQUIRED); - if (layer.ffn_gate_up_exps == nullptr) { - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); - } - }; - - // helper: try to load merged qkv first, fall back to separate q, k, v - auto create_tensor_qkv = [&](llama_layer & layer, int bid, - int64_t n_embd_, int64_t n_embd_q_, int64_t n_embd_k_, int64_t n_embd_v_, - int flags) { - const int64_t n_embd_qkv = n_embd_q_ + n_embd_k_ + n_embd_v_; - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", bid), {n_embd_, n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); - if (layer.wqkv) { - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", bid), {n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); - } else { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", bid), {n_embd_, n_embd_q_}, flags); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", bid), {n_embd_, n_embd_k_}, flags); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", bid), {n_embd_, n_embd_v_}, flags); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", bid), {n_embd_q_}, TENSOR_NOT_REQUIRED); - layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", bid), {n_embd_k_}, TENSOR_NOT_REQUIRED); - layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); - } - }; - - switch (arch) { - case LLM_ARCH_LLAMA: - case LLM_ARCH_REFACT: - case LLM_ARCH_MINICPM: - case LLM_ARCH_GRANITE: - case LLM_ARCH_GRANITE_MOE: - case LLM_ARCH_MISTRAL3: - case LLM_ARCH_LLAMA_EMBED: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - // optional bias tensors - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - - if (n_expert == 0) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - // optional MLP bias - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - - // For Granite MoE Shared - if (hparams.n_ff_shexp > 0) { - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); - } - } - } - } break; - case LLM_ARCH_LLADA: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = - create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - // Use separate Q, K, V projections without bias, matching LLaDALlamaBlock - layer.wq = - create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); - // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false - layer.wo = - create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot / 2 }, - TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); - - // optional MLP bias - layer.ffn_gate_b = - create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = - create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); - } - } - break; - case LLM_ARCH_LLADA_MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for llada-moe"); - GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for llada-moe"); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - } - } break; - case LLM_ARCH_LLAMA4: - { - if (n_expert == 0) { - throw std::runtime_error(arch_name() + " model cannot have zero experts"); - } - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - const bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0; - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - if (is_moe_layer) { - const int64_t n_ff_exp = hparams.n_ff_exp; - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert - const int64_t n_ff_shexp = n_ff_exp; - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); - } else { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } - } break; - case LLM_ARCH_DECI: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - const int64_t n_ff = hparams.n_ff(i); - const int64_t n_head = hparams.n_head(i); - const int64_t n_head_kv = hparams.n_head_kv(i); - - if (n_head_kv == 0 && n_head > 0) { - // linear attention for DeciLMCausalModel - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - } - else if (n_head_kv > 0) { - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - } - - // optional bias tensors - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - if (n_ff > 0) { - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - } - - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - - if (n_ff > 0) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - - // optional MLP bias - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_MINICPM3: - { - const int64_t n_embd_head_qk_rope = hparams.n_rot(); - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); - - const int64_t q_lora_rank = hparams.n_lora_q; - const int64_t kv_lora_rank = hparams.n_lora_kv; - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); - - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - - layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); - - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - } break; - case LLM_ARCH_GROK: - { - if (n_expert == 0) { - throw std::runtime_error(arch_name() + " model cannot have zero experts"); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - if (!layer.ffn_post_norm) { - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } - } break; - case LLM_ARCH_DBRX: - { - if (n_expert == 0) { - throw std::runtime_error("DBRX model cannot have zero experts"); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - } - } break; - case LLM_ARCH_BAICHUAN: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_FALCON: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU - } - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_STARCODER: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - // needs to be on GPU - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_BERT: - case LLM_ARCH_NOMIC_BERT: - case LLM_ARCH_NOMIC_BERT_MOE: - case LLM_ARCH_JINA_BERT_V3: - { - if (n_token_types == 0) { - throw std::runtime_error(arch_name() + " model needs to define token type count"); - } - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); - - if (arch == LLM_ARCH_BERT) { - pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); - - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); - cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - } - - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); - layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); - - if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - } else { - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - if (arch == LLM_ARCH_NOMIC_BERT) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - } - } - - layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); - layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_MODERN_BERT: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - for(int i = 0; i < n_layer; ++i) { - auto& layer = layers[i]; - - if ( i != 0 ) { - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - } else{ - // layer 0 uses identity - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - } - - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2 * n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - } - - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); - cls_norm = create_tensor(tn(LLM_TENSOR_CLS_NORM, "weight"), {n_embd}, TENSOR_NOT_REQUIRED); - - } break; - case LLM_ARCH_NEO_BERT: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); - cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - - output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + return it->second; +} - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); +float llama_model::get_rope_freq_base (const llama_cparams & cparams, int il) const { + return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; +} - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); +float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const { + return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; +} - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff*2}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - } - } break; - case LLM_ARCH_EUROBERT: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); +ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const { + const uint32_t n_ctx_seq = cparams.n_ctx_seq; - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - } - } break; - case LLM_ARCH_JINA_BERT_V2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings - type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings - - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); // LayerNorm - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // LayerNorm bias - - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); - cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, TENSOR_NOT_REQUIRED); - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; // JinaBertLayer - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens - - layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm - layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); - - layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - - const auto tn_ffn_up_weight = tn(LLM_TENSOR_FFN_UP, "weight", i); - ggml_tensor * t_ffn_up = ml.get_tensor_meta(tn_ffn_up_weight.str().c_str()); - const int64_t n_ffn_up = t_ffn_up ? t_ffn_up->ne[1] : n_ff; - - GGML_ASSERT(n_ffn_up == n_ff || n_ffn_up == n_ff * 2); - layer.ffn_up = create_tensor(tn_ffn_up_weight, {n_embd, n_ffn_up}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ffn_up}, TENSOR_NOT_REQUIRED); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); - layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_BLOOM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_MPT: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, TENSOR_NOT_REQUIRED); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - - // FIXME test-llama-archs crashes if q_norm is created - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - // AWQ ScaleActivation layer - layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_STABLELM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - // optional q and k layernorms, present in StableLM 2 12B - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); - - // optional FFN norm, not present in StableLM 2 12B which uses parallel residual - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_QWEN: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0); - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}, 0); - } - } break; - case LLM_ARCH_QWEN2: - case LLM_ARCH_QWEN2VL: - case LLM_ARCH_DREAM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_QWEN2MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0 for QWEN2MOE"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE"); - } - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert branch - const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); - } - } break; - case LLM_ARCH_QWEN3: - case LLM_ARCH_QWEN3VL: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - // output rerank head - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_QWEN3MOE: - case LLM_ARCH_QWEN3VLMOE: - case LLM_ARCH_RND1: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); - } - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - } - } break; - case LLM_ARCH_PHI2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_PHI3: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, TENSOR_NOT_REQUIRED); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0); - - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - } break; - case LLM_ARCH_PHIMOE: - { - const int64_t n_embd_head = n_embd / n_head; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - } break; - case LLM_ARCH_PLAMO: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_PLAMO2: - { - // mamba parameters - const uint32_t d_conv = hparams.ssm_d_conv; - const uint32_t d_state = hparams.ssm_d_state; - const uint32_t num_heads = hparams.ssm_dt_rank; - const uint32_t intermediate_size = hparams.ssm_d_inner; - const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); - - // attention parameters - const uint32_t qk_dim = hparams.n_embd_head_k(); - const uint32_t v_dim = hparams.n_embd_head_v(); - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - bool is_mamba_layer = hparams.is_recurrent(i); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (is_mamba_layer) { - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0); - - layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0); - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0); - - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0); - - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0); - - layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0); - layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); - layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); - } else { - const int64_t num_attention_heads = hparams.n_head(i); - const int64_t q_num_heads = num_attention_heads; - const int64_t num_key_value_heads = hparams.n_head_kv(i); - const int64_t k_num_heads = num_key_value_heads; - const int64_t v_num_heads = num_key_value_heads; - const int64_t q_proj_dim = q_num_heads * qk_dim; - const int64_t k_proj_dim = k_num_heads * qk_dim; - const int64_t v_proj_dim = v_num_heads * v_dim; - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); - } - - // All layers have post-attention norm, FFN norm, and FFN tensors - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); - } - } break; - case LLM_ARCH_PLAMO3: - { - const int64_t head_dim_q = hparams.n_embd_head_k(); - const int64_t head_dim_v = hparams.n_embd_head_v(); - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - const int64_t num_attention_heads = hparams.n_head(i); - const int64_t num_key_value_heads = hparams.n_head_kv(i); - const int64_t q_proj_dim = num_attention_heads * head_dim_q; - const int64_t k_proj_dim = num_key_value_heads * head_dim_q; - const int64_t v_proj_dim = num_key_value_heads * head_dim_v; - const int64_t n_ff_cur = hparams.n_ff(i); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), - {n_embd,q_proj_dim + k_proj_dim + v_proj_dim}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim_q}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim_q}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {num_attention_heads * head_dim_v, n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur * 2}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); - } - } break; - case LLM_ARCH_GPT2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_CODESHELL: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if tok embd is NULL, init from output - if (tok_embd == NULL) { - tok_embd = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_ORION: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_INTERNLM2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_GEMMA: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - } - } break; - case LLM_ARCH_GEMMA2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_GEMMA3: - case LLM_ARCH_GEMMA_EMBEDDING: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - // Dense linear weights - dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); - dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); - - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_GEMMA3N: - { - const int64_t n_altup = hparams.n_altup; - const int64_t laurel_rank = hparams.laurel_rank; - const int64_t n_embd_altup = hparams.n_embd_altup; - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - - per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); - per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0); - per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_altup}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - - // altup & laurel - layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0); - layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0); - layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); - layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0); - layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0); - layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0); - layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0); - layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0); - layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0); - layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0); - layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_GEMMA4: - { - const uint32_t n_embd_per_layer = hparams.n_embd_per_layer; - const int64_t n_ff_exp = hparams.n_ff_exp; - - if (n_embd_head_k != n_embd_head_v) { - throw std::runtime_error("Gemma 4 requires n_embd_head_k == n_embd_head_v"); - } - if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { - throw std::runtime_error("Gemma 4 requires n_embd_head_k_swa == n_embd_head_v_swa"); - } - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - if (n_embd_per_layer > 0) { - per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0); - per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_per_layer * n_layer}, 0); - per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_per_layer}, 0); - } - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - int rope_freqs_flag = 0; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - const int64_t n_head = hparams.n_head(i); - const int64_t n_embd_head = hparams.n_embd_head_k(i); - const int64_t n_embd_k = hparams.n_embd_k_gqa(i); - const int64_t n_embd_v = hparams.n_embd_v_gqa(i); - const int kv_flags = hparams.has_kv(i) ? 0 : TENSOR_NOT_REQUIRED; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - // note: use_alternative_attention (v_proj is optional, if it's not present, use k_proj) - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, kv_flags); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v}, TENSOR_NOT_REQUIRED); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head * n_head, n_embd}, 0); - - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, kv_flags); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, TENSOR_NOT_REQUIRED); - - if (!hparams.is_swa(i)) { - // full_attention layers use rope_freqs for proportional rope - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_embd_head/2}, rope_freqs_flag); - rope_freqs_flag = TENSOR_DUPLICATED; - } - - // handle use_double_wide_mlp - int64_t n_ff_cur = hparams.n_ff(i); - - // for expert layers, we use normal FFN as shared expert (same as python code) - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_cur}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - - // MoE router - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); - bool has_expert = layer.ffn_gate_inp != nullptr; - - // norm - if (has_expert) { - layer.ffn_gate_inp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "scale", i), {n_embd}, 0); - - layer.ffn_pre_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM_2, "weight", i), {n_embd}, 0); - layer.ffn_post_norm_1 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_1, "weight", i), {n_embd}, 0); - layer.ffn_post_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_2, "weight", i), {n_embd}, 0); - - // MoE FFN - layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i), {n_embd, n_ff_exp * 2, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - - // per-expert scale will be loaded as down_exps_s at the end of the current switch case - } - - // per-layer embeddings - if (n_embd_per_layer > 0) { - layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_per_layer}, 0); - layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_per_layer, n_embd}, 0); - layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); - } - } - } break; - case LLM_ARCH_STARCODER2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - // optional bias tensors - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - // optional bias tensors - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}, 0); - } - } break; - case LLM_ARCH_MAMBA: - { - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t dt_rank = hparams.ssm_dt_rank; - - // only an expansion factor of 2 is supported for now - if (2 * n_embd != d_inner) { - throw std::runtime_error("only an expansion factor of 2 is supported for now"); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); - - layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); - - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } - } break; - case LLM_ARCH_MAMBA2: - { - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t n_head = hparams.ssm_dt_rank; - const int64_t n_group = hparams.ssm_n_group; - const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; - - // only an expansion factor of 2 is supported for now - GGML_ASSERT(2 * n_embd == d_inner); - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0); - - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); - - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } - } break; - case LLM_ARCH_JAMBA: - { - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t dt_rank = hparams.ssm_dt_rank; - - // only an expansion factor of 2 is supported for now - GGML_ASSERT(2 * n_embd == d_inner); - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - } - - for (int i = 0; i < n_layer; ++i) { - const int64_t n_head_kv = hparams.n_head_kv(i); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); - - auto & layer = layers[i]; - - // norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (n_head_kv == 0) { - // Mamba layer - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); - - layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); - - layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, "weight", i), {dt_rank}, 0); - - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); - - layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, "weight", i), {d_state}, 0); - layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, "weight", i), {d_state}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } else { - // Attention layers - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - } - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); - - if (layer.ffn_gate_inp) { - // MoE - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - } else { - // FFN (no MoE) - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } - } break; - case LLM_ARCH_GRANITE_HYBRID: - { - // mamba2 Mixer SSM params - // NOTE: int64_t for tensor dimensions - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t n_ssm_head = hparams.ssm_dt_rank; - const int64_t n_group = hparams.ssm_n_group; - const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; - - // only an expansion factor of 2 is supported for now - GGML_ASSERT(2 * n_embd == d_inner); - - // embeddings - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (hparams.is_recurrent(i)) { - // ssm layers - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); - - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); - - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } else { - // attention layers (with optional bias) - const int64_t n_head_i = hparams.n_head(i); - const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); - const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - } - - // feed forward (w/ optional biases) - if (n_expert > 0) { - // MoE FFN - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - - // For Granite MoE Shared - if (hparams.n_ff_shexp > 0) { - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); - } - } else { - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - } - } - } break; - case LLM_ARCH_XVERSE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_COMMAND_R: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // init output from the input tok embed - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (n_layer >= 64){ - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); - } - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_COHERE2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - // init output from the input tok embed - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, - TENSOR_DUPLICATED); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); - } - } - break; - case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_OLMO2: - { - const int64_t n_embd_head = n_embd / n_head; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_head_kv * n_embd_head}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_SEED_OSS: - { - const uint32_t head_dim = hparams.n_embd_head_k(); - const int64_t n_qo_dim = n_head * head_dim; - const int64_t n_kv_dim = n_head_kv * head_dim; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0); - - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - } - } break; - - case LLM_ARCH_OLMOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - } - } break; - case LLM_ARCH_OPENELM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // init output from the input tok embed - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - - for (int i = 0; i < n_layer; ++i) { - const int64_t n_head = hparams.n_head(i); - const int64_t n_head_qkv = 2*hparams.n_head_kv(i) + n_head; - const int64_t n_ff = hparams.n_ff(i); - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_GPTNEOX: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_ARCTIC: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - } - } break; - case LLM_ARCH_DEEPSEEK: - { - - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // try to load output.weight, if not found, use token_embd (tied embeddings) - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert branch - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - } - } - } break; - case LLM_ARCH_DEEPSEEK2: - case LLM_ARCH_MISTRAL4: - { - const bool is_mla = hparams.is_mla(); - - // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA - const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); - const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); - - const int64_t n_embd_head_qk_rope = hparams.n_rot(); - const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; - GGML_ASSERT(n_embd_head_qk_nope >= 1); - - const int64_t q_lora_rank = hparams.n_lora_q; - const int64_t kv_lora_rank = hparams.n_lora_kv; - - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // try to load output.weight, if not found, use token_embd (tied embeddings) - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (q_lora_rank > 0) { - layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); - } - - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - - if (q_lora_rank > 0) { - layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); - } else { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); - } - - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); - - // note: only old legacy GGUF files will have the unsplit wkv_b tensor in - if (is_mla) { - layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); - layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); - } else { - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0); - } - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); - - // Shared expert branch - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - } - } - } break; - case LLM_ARCH_DEEPSEEK2OCR: - { - // similar to deepseek2, but without MLA - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // try to load output.weight, if not found, use token_embd (tied embeddings) - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - // norm - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); - - // Shared expert branch - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - } - } - } break; - case LLM_ARCH_PLM: - { - const int64_t n_embd_head_qk_rope = hparams.n_rot(); - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); - const int64_t kv_lora_rank = hparams.n_lora_kv; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_BITNET: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_T5: - { - const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - // n_layer: number of encoder_layers - // dec_n_layer: number of decoder_layers - const int dec_n_layer = hparams.dec_n_layer; - if (dec_n_layer > n_layer) { - layers.resize(dec_n_layer); - } - - // load encoder layers - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); - - layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); - - layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - - // load decoder layers - for (int i = 0; i < dec_n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); - - layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); - - layer.attn_norm_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}, 0); - // this tensor seems to be unused in HF transformers implementation - layer.attn_rel_b_cross = create_tensor( - tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); - - layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_T5ENCODER: - { - const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); - - layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); - - layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_JAIS: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_JAIS2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - // attention biases - all have shape n_embd (output dimension of projections) - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); - layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - // Jais-2 uses simple MLP (no gate) with biases - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_CHATGLM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - } - } break; - case LLM_ARCH_GLM4: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; - } - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); - - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, flags); - - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); - - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - - // Optional tensors - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); - } - } - } break; - case LLM_ARCH_GLM4_MOE: - { - const int64_t n_expert = hparams.n_expert; - const int64_t n_expert_used = hparams.n_expert_used; - const int64_t n_expert_shared = hparams.n_expert_shared; - - GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); - GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers"); - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - // Load ALL tensors including NextN layer to satisfy total tensor count - // but only PROCESS up to last layer (skipping final NextN layer) in forward pass - for (int i = 0; i < n_layer; ++i) { - int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; - } - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags); - - // GLM-style attention with bias terms - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); - - // K/Q norm tensors (optional for GLM-4.5 355B variant) - layer.attn_q_norm = create_tensor( - tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); - layer.attn_k_norm = create_tensor( - tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); - - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags); - - // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead - // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE - const bool use_moe = (static_cast(i) >= hparams.n_layer_dense_lead); - - if (use_moe) { - // MoE layers - layer.ffn_gate_inp = - create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags); - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - layer.ffn_gate_exps = create_tensor( - tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); - layer.ffn_down_exps = create_tensor( - tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags); - layer.ffn_up_exps = create_tensor( - tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); - - // Shared expert - if (n_expert_shared > 0) { - const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; - layer.ffn_gate_shexp = create_tensor( - tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); - layer.ffn_down_shexp = create_tensor( - tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags); - layer.ffn_up_shexp = create_tensor( - tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); - } - } else { - // Dense layers (first k layers) - GLM uses separate gate/up projections - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags); - } - - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - - // Optional tensors - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); - } - } - } - break; - case LLM_ARCH_GLM_DSA: - { - const bool is_mla = hparams.is_mla(); - if (!is_mla) { - throw std::runtime_error("GLM_DSA architecture requires MLA"); - } - - // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA - const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); - const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); - - const int64_t n_embd_head_qk_rope = hparams.n_rot(); - const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; - - const int64_t q_lora_rank = hparams.n_lora_q; - const int64_t kv_lora_rank = hparams.n_lora_kv; - - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // try to load output.weight, if not found, use token_embd (tied embeddings) - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later - flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; - } - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); - layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); - - layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); - - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); - - // note: only old legacy GGUF files will have the unsplit wkv_b tensor in - layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); - layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); - - // DSA indexer - layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); - layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); - layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); - layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); - layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); - if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); - - // Shared expert branch - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); - } - - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - - // Optional tensors - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); - } - } - } break; - case LLM_ARCH_NEMOTRON: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - // optional bias tensors - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - // optional MLP bias - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_NEMOTRON_H: - case LLM_ARCH_NEMOTRON_H_MOE: - { - // mamba2 Mixer SSM params - // NOTE: int64_t for tensor dimensions - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t n_ssm_head = hparams.ssm_dt_rank; - const int64_t n_group = hparams.ssm_n_group; - const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; - const int64_t moe_n_embd = hparams.moe_latent_size > 0 ? hparams.moe_latent_size : n_embd; - - // embeddings - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // all blocks use the attn norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (hparams.is_recurrent(i)) { - // ssm layers - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); - - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); - - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } else if (hparams.n_ff(i) == 0) { - // attention layers (with optional bias) - const int64_t n_head_i = hparams.n_head(i); - const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); - const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - } else { - if (n_expert != 0) { - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - const int64_t n_ff_shexp = hparams.n_ff_shexp; - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); - - // MoE branch - layer.ffn_latent_down = create_tensor(tn(LLM_TENSOR_FFN_LATENT_DOWN, "weight", i), {n_embd, moe_n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_latent_up = create_tensor(tn(LLM_TENSOR_FFN_LATENT_UP, "weight", i), {moe_n_embd, n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, moe_n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {moe_n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert branch - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); - - } else { - // mlp layers - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); - } - } - } - } break; - case LLM_ARCH_EXAONE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_EXAONE4: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_EXAONE_MOE: - { - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert = hparams.n_expert; - const int64_t n_expert_used = hparams.n_expert_used; - const int64_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : n_ff_exp; - const int64_t head_dim = hparams.n_embd_head_k(); - const int64_t n_qo_dim = n_head * head_dim; - const int64_t n_kv_dim = n_head_kv * head_dim; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; - } - - auto & layer = layers[i]; - create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, flags); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, flags); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0) | flags); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); - - // dense layers for first n_layer_dense_lead layers or nextn_predict_layers layers at the end - if (i < (int) hparams.n_layer_dense_lead || (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers)) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, flags); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); - - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); - } - - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); - - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); - } - } - } break; - case LLM_ARCH_RWKV6: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // Block 0, LN0 - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - const int time_mix_extra_dim = hparams.time_mix_extra_dim; - const int time_decay_extra_dim = hparams.time_decay_extra_dim; - const int head_size = hparams.wkv_head_size; - const int attn_hidden_size = n_embd; - const int ffn_size = hparams.n_ff_arr[0]; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); - layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); - - layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); - layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); - - layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); - layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, TENSOR_NOT_REQUIRED); - GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL)); - - layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0); - layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); - layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); - layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); - layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); - - layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); - layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); - layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); - - layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); - layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0); - - layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); - layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); - layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0); - } - - } break; - case LLM_ARCH_RWKV6QWEN2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - const int time_mix_extra_dim = hparams.time_mix_extra_dim; - const int time_decay_extra_dim = hparams.time_decay_extra_dim; - const int head_size = hparams.wkv_head_size; - const int attn_hidden_size = n_embd; - const int n_head_kv = hparams.n_head_kv(); - int attn_key_value_size; - if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) { - attn_key_value_size = attn_hidden_size; - } else { - attn_key_value_size = n_head_kv * head_size; - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); - layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); - - layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); - - layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, TENSOR_NOT_REQUIRED); - layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); - layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); - layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); - layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0); - layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0); - layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); - // optional bias tensors - layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); - layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); - layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, TENSOR_NOT_REQUIRED); - - layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_RWKV7: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // Block 0, LN0 - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - const int n_lora_decay = hparams.n_lora_decay; - const int n_lora_iclr = hparams.n_lora_iclr; - const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; - const int n_lora_gate = hparams.n_lora_gate; - const int attn_hidden_size = n_embd; - const int ffn_size = hparams.n_ff_arr[0]; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); - layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); - - layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); - layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); - layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); - - layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); - layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); - layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); - - if (i == 0) { - // actually not used - layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); - layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); - layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); - } else { - layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); - layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); - layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); - } - - layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, 0); - layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, 0); - - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); - - layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); - layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); - layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); - - layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); - - layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); - layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); - layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); - - layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); - - layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); - layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); - } - - } break; - case LLM_ARCH_ARWKV7: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - const int n_lora_decay = hparams.n_lora_decay; - const int n_lora_iclr = hparams.n_lora_iclr; - const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; - const int n_lora_gate = hparams.n_lora_gate; - const int attn_hidden_size = n_embd; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); - layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); - layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); - - layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); - layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); - layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); - - if (i == 0) { - // actually not used - layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); - layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); - layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); - } else { - layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); - layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); - layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); - } - - layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, TENSOR_NOT_REQUIRED); - layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, TENSOR_NOT_REQUIRED); - - try { - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); - } catch(std::runtime_error & e) { - // ARWKV models may not have gate tensors - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); - } - - layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); - layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); - layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); - - layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); - - layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - - } break; - case LLM_ARCH_CHAMELEON: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); - - create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_WAVTOKENIZER_DEC: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0); - - conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight", 0), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); - conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias", 0), {1, hparams.posnet.n_embd}, 0); - - // posnet - { - const int64_t n_embd = hparams.posnet.n_embd; - - for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) { - auto & layer = layers[i].posnet; - - // posnet: - // - // - resnet - // - resnet - // - attn - // - resnet - // - resnet - // - norm - // - switch (i) { - case 0: - case 1: - case 3: - case 4: - { - layer.norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0); - layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", i), {1, n_embd}, 0); - - layer.conv1 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0); - layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", i), {1, n_embd}, 0); - - layer.norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0); - layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", i), {1, n_embd}, 0); - - layer.conv2 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0); - layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", i), {1, n_embd}, 0); - } break; - case 2: - { - layer.attn_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); - - layer.attn_q = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "weight", i), {1, n_embd, n_embd}, 0); - layer.attn_q_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "bias", i), {1, n_embd}, 0); - - layer.attn_k = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "weight", i), {1, n_embd, n_embd}, 0); - layer.attn_k_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "bias", i), {1, n_embd}, 0); - - layer.attn_v = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "weight", i), {1, n_embd, n_embd}, 0); - layer.attn_v_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "bias", i), {1, n_embd}, 0); - - layer.attn_o = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "weight", i), {1, n_embd, n_embd}, 0); - layer.attn_o_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "bias", i), {1, n_embd}, 0); - } break; - case 5: - { - layer.norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); - layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); - } break; - default: GGML_ABORT("unknown posnet layer"); - }; - } - } - - GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd); - - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {hparams.posnet.n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {hparams.posnet.n_embd}, 0); - - // convnext - { - const int64_t n_embd = hparams.convnext.n_embd; - - for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) { - auto & layer = layers[i].convnext; - - layer.dw = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "weight", i), {7, 1, n_embd}, 0); - layer.dw_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "bias", i), {1, n_embd}, 0); - - layer.norm = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "weight", i), {n_embd}, 0); - layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "bias", i), {n_embd}, 0); - - layer.pw1 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "weight", i), {n_embd, n_ff}, 0); - layer.pw1_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "bias", i), {n_ff}, 0); - - layer.pw2 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "weight", i), {n_ff, n_embd}, 0); - layer.pw2_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "bias", i), {n_embd}, 0); - - layer.gamma = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0); - } - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - } - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, hparams.n_embd_out()}, 0); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {hparams.n_embd_out()}, 0); - } break; - case LLM_ARCH_BAILINGMOE: - { - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - } - } break; - case LLM_ARCH_BAILINGMOE2: - { - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2"); - GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2"); - - for (int i = 0; i < n_layer; ++i) { - int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; - } - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); - - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); - - if (static_cast(i) >= hparams.n_layer_dense_lead) { // MoE layers - const int64_t n_ff_shexp = (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared; - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); - - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); - } else { // Dense layers - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); - } - - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED | flags); - layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, flags); - } - } - } break; - case LLM_ARCH_DOTS1: - { - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_head_k * n_head, n_embd_head_k * n_head, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert branch - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - } - } - } break; - case LLM_ARCH_ARCEE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_AFMOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // dual attention normalization - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - // attention projections - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - // Q/K normalization - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - - // attention gating - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - - // dual ffn normalization - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - - if (static_cast(i) >= hparams.n_layer_dense_lead) { - // MoE layers - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); - - // grouped expert weights - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - - // shared expert - if (n_expert_shared > 0) { - const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); - } - } else { - // Dense layers - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } - } break; - case LLM_ARCH_ERNIE4_5: - case LLM_ARCH_ERNIE4_5_MOE: - case LLM_ARCH_PADDLEOCR: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - // optional bias tensors - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (arch == LLM_ARCH_ERNIE4_5_MOE && static_cast(i) >= hparams.n_layer_dense_lead) { // MoE layers - int n_ff_exp = hparams.n_ff_exp; - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert (if present) - if (hparams.n_ff_shexp > 0) { - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); - } - } else { // Dense layers - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } - } break; - case LLM_ARCH_FALCON_H1: - { - // Common - const int64_t hidden_size = hparams.n_embd; // hidden_size - - // mamba2 Mixer SSM params - const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size - const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups - const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size - const int64_t ssm_intermediate_size = hparams.ssm_d_inner; // TODO expand - const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads - const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size; - const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads; - - // attn params - const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head - const int64_t attn_num_key_value_head = hparams.n_head_kv(0); - - // ffn params - const int64_t ffn_intermediate_size = hparams.n_ff(0); - - // embeddings - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, 0); - - // output - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, n_vocab}, TENSOR_NOT_REQUIRED); - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - /*SSM LAYERS*/ - // ssm in - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {hidden_size, ssm_projection_size}, 0); - // ssm 1d conv - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {ssm_conv_kernel_size, ssm_conv_dim}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {ssm_conv_dim}, TENSOR_NOT_REQUIRED); - // ssm_dt - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {ssm_num_heads}, 0); - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0); - // ssm_norm - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, TENSOR_NOT_REQUIRED); - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0); - - /*ATTENTION LAYERS*/ - // attention layers (with optional bias) - create_tensor_qkv(layer, i, hidden_size, n_embd_head_k * attn_num_attention_head, attn_num_key_value_head * n_embd_head_k, attn_num_key_value_head * n_embd_head_v, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0); - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0); - - - // feed forward (w/ optional biases) - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, i), {hidden_size}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {hidden_size, ffn_intermediate_size}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { ffn_intermediate_size, hidden_size}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {hidden_size, ffn_intermediate_size}, 0); - - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_HUNYUAN_MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - const uint32_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : hparams.n_ff(i); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); - } - } break; - case LLM_ARCH_HUNYUAN_VL: - case LLM_ARCH_HUNYUAN_DENSE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - } - } break; - case LLM_ARCH_SMOLLM3: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_OPENAI_MOE: - { - const int64_t n_ff_exp = hparams.n_ff_exp; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); - - layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); - layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); - layer.ffn_down_exps_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), { n_embd, n_expert}, 0); - layer.ffn_up_exps_b = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); - } - } break; - case LLM_ARCH_LFM2: - case LLM_ARCH_LFM2MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM_LFM2, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - const bool is_moe_layer = i >= static_cast(hparams.n_layer_dense_lead); - - // ffn/moe is same for transformer and conv layers - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - if (is_moe_layer) { - GGML_ASSERT(n_expert && n_expert_used); - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); - } else { // dense - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - - // for operator_norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (!hparams.is_recurrent(i)) { - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); - - create_tensor_qkv(layer, i, n_embd, n_embd, hparams.n_embd_k_gqa(i), hparams.n_embd_v_gqa(i), 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - } else { - layer.shortconv.conv = create_tensor(tn(LLM_TENSOR_SHORTCONV_CONV, "weight", i), {hparams.n_shortconv_l_cache, n_embd}, 0); - layer.shortconv.in_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_INPROJ, "weight", i), {n_embd, 3 * n_embd}, 0); - layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0); - } - } - - // for LFM2-ColBert-350M - dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); - dense_2_out_layers_b = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "bias"), {hparams.n_embd_out() }, TENSOR_NOT_REQUIRED); - } break; - case LLM_ARCH_SMALLTHINKER: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - - GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for SMALLTHINKER"); - GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for SMALLTHINKER"); - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp; - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); - } - } break; - case LLM_ARCH_GROVEMOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for GROVEMOE"); - GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for GROVEMOE"); - GGML_ASSERT(hparams.n_group_experts > 0 && "n_group_experts must be > 0 for GROVEMOE"); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - const int64_t n_ff_chexp = hparams.n_ff_chexp ? hparams.n_ff_chexp : n_embd_head_k; - const int64_t n_chunk_expert = n_expert / hparams.n_group_experts; - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - layer.ffn_gate_chexps = create_tensor(tn(LLM_TENSOR_FFN_GATE_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); - layer.ffn_down_chexps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_CHEXPS, "weight", i), {n_ff_chexp, n_embd, n_chunk_expert}, 0); - layer.ffn_up_chexps = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); - } - } break; - case LLM_ARCH_APERTUS: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - // optional bias tensors - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); - - // Q and K layernorms for Apertus - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_MINIMAX_M2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k * n_head}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_k_gqa}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); - } - } break; - case LLM_ARCH_KIMI_LINEAR: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - // Check for KDA specific tensors to determine layer type or if it's a mixed model - // Assuming KDA layer if KDA tensors are present - - // KDA uses head_dim = 128 (from linear_attn_config.head_dim) - const int64_t n_embd_head_k_kda = hparams.n_embd_head_kda; - const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; - const int64_t ssm_d_conv = hparams.ssm_d_conv; - - if (hparams.is_recurrent(i)) { - // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) - // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] - layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); - if (!layer.ssm_q_conv) { - layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); - } - - // KDA Layer - Conv1d weights may be 3D or 4D - layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); - if (!layer.ssm_k_conv) { - layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); - } - layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head, 1}, TENSOR_NOT_REQUIRED); - if (!layer.ssm_v_conv) { - layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head}, 0); - } - - // q, k, v projections - // Python: q_proj, k_proj, v_proj - create_tensor_qkv(layer, i, n_embd, n_embd_head_k_kda * n_head, n_embd_head_k_kda * n_head, n_embd_head_v_kda * n_head, 0); - - // KDA specific projections - // f_a_proj, f_b_proj - layer.ssm_f_a = create_tensor(tn(LLM_TENSOR_SSM_F_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); // head_dim - layer.ssm_f_b = create_tensor(tn(LLM_TENSOR_SSM_F_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); // projection_size - - // b_proj (beta mixing coefficient) - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), {n_embd, n_head}, 0); - - // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization) Note: -exp(A_log) is applied in convert_hf_to_gguf.py - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED); - if (!layer.ssm_a) { - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); - } - - // dt_bias - shape [n_embd_head_k_kda * n_head] = [4096] - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_embd_head_k_kda * n_head}, 0); - - // g_a_proj, g_b_proj (output gate) - layer.ssm_g_a = create_tensor(tn(LLM_TENSOR_SSM_G_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); - layer.ssm_g_b = create_tensor(tn(LLM_TENSOR_SSM_G_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); - - // o_norm (reusing SSM_NORM) - layer.ssm_o_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {n_embd_head_k_kda}, 0); // FusedRMSNormGated - - // o_proj - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v_kda * n_head, n_embd}, 0); - - } else { - // MLA Layer - use MLA-specific head dimensions - const int64_t q_lora_rank = hparams.n_lora_q; - const int64_t kv_lora_rank = hparams.n_lora_kv; - const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); - const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); - - layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED); - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - - if (layer.attn_q_a_norm) { - layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); - } else { - // Kimi MLA without Q compression: wq = [n_embd, n_head * n_embd_head_k_mla] - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); - } - - // Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA) - // Note: hparams.n_rot may be 72 (from conversion) but actual is 64 - const int64_t qk_rope_head_dim = hparams.n_rot(); // From config: qk_rope_head_dim - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0); - // Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled) - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), - {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); - if (!layer.wkv_b) { // MLA KV cache enabled - layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0); - layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); - } - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); - } - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - // MoE intermediate size (different from dense FFN) - const int64_t n_ff_exp = hparams.n_ff_exp; - - // Kimi uses n_layer_dense_lead to determine which layers use dense FFN vs MoE - // first_k_dense_replace = 1 means layer 0 uses dense FFN, layers 1+ use MoE - if (i < (int) hparams.n_layer_dense_lead) { - // Dense FFN layer - use normal n_ff - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } else { - // MoE layer - use n_ff_exp (1024) instead of n_ff (9216) - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - - // Shared experts use moe_intermediate_size * num_shared_experts - // Kimi: shared_expert_intermediate_size = 1024 * 1 = 1024 - // Tensors are 2D: [n_embd, n_ff_shexp] or [n_ff_shexp, n_embd] - const int64_t n_ff_shexp_actual = n_ff_exp * (hparams.n_expert_shared > 0 ? hparams.n_expert_shared : 1); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); - - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); - } - } - } break; - case LLM_ARCH_COGVLM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd_head_k * n_head * 3}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.visexp_attn_wqkv = create_tensor(tn(LLM_TENSOR_VISEXP_ATTN_QKV, "weight", i), {n_embd, n_embd_head_k * n_head * 3}, 0); - layer.visexp_attn_wo = create_tensor(tn(LLM_TENSOR_VISEXP_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - layer.visexp_ffn_gate = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.visexp_ffn_down = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.visexp_ffn_up = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_PANGU_EMBED: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - // weight tensors - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - // bias tensors - layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_QWEN3NEXT: - { - if (n_expert == 0) { - throw std::runtime_error(arch_name() + " model cannot have zero experts"); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; - - // Calculate projection sizes - const int64_t qkvz_dim = key_dim * 2 + value_dim * 2; - const int64_t ba_dim = n_v_heads * 2; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - const uint32_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : hparams.n_ff(i); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - - if (!hparams.is_recurrent(i)) { - // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); - } else { - // Linear attention (gated delta net) specific tensors - // Create tensors with calculated dimensions - // note: ssm_in is used by legacy GGUF - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); - } - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); - - // Shared experts - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); - } - } break; - case LLM_ARCH_QWEN35MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - - if (!hparams.is_recurrent(i)) { - // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); - } else { - // Linear attention (gated delta net) specific tensors - // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); - } - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); - - // Shared experts - const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); - } - } break; - case LLM_ARCH_QWEN35: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - - if (!hparams.is_recurrent(i)) { - // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); - } else { - // Linear attention (gated delta net) specific tensors - // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); - } - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_MIMO2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); - uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - uint32_t n_head = hparams.n_head(i); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - // non-MoE branch - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - - // MoE branch - int64_t n_ff_exp = hparams.n_ff_exp; - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_STEP35: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor - // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. - uint32_t n_rot_max = 0; - for (int i = 0; i < n_layer; ++i) { - n_rot_max = std::max(n_rot_max, hparams.n_rot(i)); - } - if (n_rot_max == 0) { - n_rot_max = n_rot; - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - const uint32_t n_head_l = hparams.n_head(i); - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); - - // optional rope factors (llama3) / longrope tensors - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); - - // head-wise attention gate (Step35 self_attn.g_proj) - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - // dense MLP (leading dense blocks) - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - - // MoE routed experts + selection bias (router_bias) - const int64_t n_ff_exp = hparams.n_ff_exp; - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - - // shared expert MLP - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_MAINCODER: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - default: - throw std::runtime_error("unknown architecture"); - } - - // generic pass: load optional per-tensor/per-expert ".scale" tensors (e.g. NVFP4 scale2) - // this avoids having to add scale loading to every architecture - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // attention weight scales (per-tensor, shape {1}) - if (!layer.wq_s && layer.wq) { - layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wk_s && layer.wk) { - layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wv_s && layer.wv) { - layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wo_s && layer.wo) { - layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wqkv_s && layer.wqkv) { - layer.wqkv_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wqkv_gate_s && layer.wqkv_gate) { - layer.wqkv_gate_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - - // dense FFN weight scales (per-tensor, shape {1}) - if (!layer.ffn_gate_s && layer.ffn_gate) { - layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_down_s && layer.ffn_down) { - layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_up_s && layer.ffn_up) { - layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_gate_shexp_s && layer.ffn_gate_shexp) { - layer.ffn_gate_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_down_shexp_s && layer.ffn_down_shexp) { - layer.ffn_down_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_up_shexp_s && layer.ffn_up_shexp) { - layer.ffn_up_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - - // MoE expert weight scales (per-expert, shape {n_expert}) - if (!layer.ffn_gate_exps_s && layer.ffn_gate_exps) { - layer.ffn_gate_exps_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_down_exps_s && layer.ffn_down_exps) { - layer.ffn_down_exps_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_up_exps_s && layer.ffn_up_exps) { - layer.ffn_up_exps_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - - // recurrent / linear-attention weight scales (per-tensor, shape {1}) - if (!layer.ssm_in_s && layer.ssm_in) { - layer.ssm_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ssm_out_s && layer.ssm_out) { - layer.ssm_out_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ssm_alpha_s && layer.ssm_alpha) { - layer.ssm_alpha_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ssm_beta_s && layer.ssm_beta) { - layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - - // input scales - if (!layer.wq_in_s && layer.wq) { - layer.wq_in_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wk_in_s && layer.wk) { - layer.wk_in_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wv_in_s && layer.wv) { - layer.wv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wo_in_s && layer.wo) { - layer.wo_in_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wqkv_in_s && layer.wqkv) { - layer.wqkv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.wqkv_gate_in_s && layer.wqkv_gate) { - layer.wqkv_gate_in_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_gate_in_s && layer.ffn_gate) { - layer.ffn_gate_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_down_in_s && layer.ffn_down) { - layer.ffn_down_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_up_in_s && layer.ffn_up) { - layer.ffn_up_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_gate_exps_in_s && layer.ffn_gate_exps) { - layer.ffn_gate_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_down_exps_in_s && layer.ffn_down_exps) { - layer.ffn_down_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_up_exps_in_s && layer.ffn_up_exps) { - layer.ffn_up_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_gate_shexp_in_s && layer.ffn_gate_shexp) { - layer.ffn_gate_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_down_shexp_in_s && layer.ffn_down_shexp) { - layer.ffn_down_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ffn_up_shexp_in_s && layer.ffn_up_shexp) { - layer.ffn_up_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ssm_in_in_s && layer.ssm_in) { - layer.ssm_in_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ssm_out_in_s && layer.ssm_out) { - layer.ssm_out_in_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ssm_alpha_in_s && layer.ssm_alpha) { - layer.ssm_alpha_in_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - if (!layer.ssm_beta_in_s && layer.ssm_beta) { - layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); - } - } - } - - ml.done_getting_tensors(); - - // populate tensors_by_name - for (auto & [_, ctx_ptr] : ml.ctx_map) { - for (auto * cur = ggml_get_first_tensor(ctx_ptr.get()); cur != NULL; cur = ggml_get_next_tensor(ctx_ptr.get(), cur)) { - tensors_by_name.emplace_back(ggml_get_name(cur), cur); - } - } - - ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr); - pimpl->mappings.reserve(ml.mappings.size()); - - // create the backend buffers - std::vector> ctx_buf_maps; - ctx_buf_maps.reserve(ml.ctx_map.size()); - - // Ensure we have enough capacity for the maximum backend buffer we will potentially create - const size_t n_max_backend_buffer = ml.ctx_map.size() * ml.files.size(); - pimpl->ctxs_bufs.reserve(n_max_backend_buffer); - - for (auto & [buft, ctx_ptr] : ml.ctx_map) { - ggml_context * ctx = ctx_ptr.get(); - - // skip contexts without tensors - if (ggml_get_first_tensor(ctx) == nullptr) { - continue; - } - - llama_buf_map buf_map; - buf_map.reserve(n_max_backend_buffer); - - // check if it is possible to use buffer_from_host_ptr with this buffer type - ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); - if (!dev) { - // FIXME: workaround for CPU backend buft having a NULL device - dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (!dev) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } - } - ggml_backend_dev_props props; - ggml_backend_dev_get_props(dev, &props); - bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr; - bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev); - - std::vector bufs; - if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) { - GGML_ASSERT(!ml.no_alloc); - for (uint32_t idx = 0; idx < ml.files.size(); idx++) { - // only the mmap region containing the tensors in the model is mapped to the backend buffer - // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, - // then we could just use metal for all layers - // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size - void * addr = nullptr; - size_t first, last; // NOLINT - ml.get_mapping_range(&first, &last, &addr, idx, ctx); - if (first >= last) { - continue; - } - const size_t max_size = ggml_get_max_tensor_size(ctx); - ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); - if (buf == nullptr) { - throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); - } - bufs.emplace_back(buf); - buf_map.emplace(idx, buf); - } - } else { - ggml_backend_buffer_t buf; - if (ml.no_alloc) { - buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer - for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { - t->buffer = buf; // set dummy buffer for weights so that the backend scheduler won't try to allocate them - } - } else { - buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); // real buffer - } - if (buf == nullptr) { - throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); - } - if (use_mlock && ggml_backend_buffer_is_host(buf)) { - pimpl->mlock_bufs.emplace_back(new llama_mlock); - auto & mlock_buf = pimpl->mlock_bufs.back(); - mlock_buf->init (ggml_backend_buffer_get_base(buf)); - mlock_buf->grow_to(ggml_backend_buffer_get_size(buf)); - } - bufs.emplace_back(buf); - for (uint32_t idx = 0; idx < ml.files.size(); idx++) { - buf_map.emplace(idx, buf); - } - } - - for (auto & buf : bufs) { - // indicate that this buffer contains weights - // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight - ggml_backend_buffer_set_usage(buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); - } - - pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs)); - - ctx_buf_maps.emplace_back(ctx, buf_map); - } - - if (llama_supports_gpu_offload()) { - const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); - - int n_repeating = n_gpu; - if (n_repeating > 0) { - LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__); - n_repeating--; - } - LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_repeating); - - const int max_backend_supported_layers = hparams.n_layer + 1; - const int max_offloadable_layers = hparams.n_layer + 1; - - LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); - } - - // print memory requirements per buffer type - for (auto & [_, bufs] : pimpl->ctxs_bufs) { - for (auto & buf: bufs) { - LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", - __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); - } - } - - if (ml.no_alloc) { - return true; - } - - // load tensor data - for (auto & [ctx, buf_map] : ctx_buf_maps) { - if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { - return false; - } - } - - if (use_mmap_buffer) { - for (auto & mapping : ml.mappings) { - pimpl->mappings.emplace_back(std::move(mapping)); - } - } - - return true; -} - -std::string llama_model::arch_name() const { - return llm_arch_name(arch); -} - -std::string llama_model::type_name() const { - return llm_type_name(type); -} - -std::string llama_model::desc() const { - return pimpl->desc_str; -} - -size_t llama_model::size() const { - return pimpl->n_bytes; -} - -size_t llama_model::n_tensors() const { - return tensors_by_name.size(); -} - -size_t llama_model::n_devices() const { - return devices.size(); -} - -const float * llama_model::tensor_split() const { - return params.tensor_split; -} - -uint32_t llama_model::n_gpu_layers() const { - return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1; -} - -llama_split_mode llama_model::split_mode() const { - return params.split_mode; -} - -std::map llama_model::memory_breakdown() const { - std::map ret; - for (const auto & [ctx, bufs] : pimpl->ctxs_bufs) { - if (hparams.no_alloc) { - GGML_ASSERT(bufs.size() == 1); - ggml_backend_buffer_t buf = bufs[0].get(); - GGML_ASSERT(ggml_backend_buffer_get_base(buf) == nullptr); - ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf); - ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft); - } else { - for (const auto & buf : bufs) { - // GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base - ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); - } - } - } - return ret; -} - -uint64_t llama_model::n_elements() const { - return pimpl->n_elements; -} - -void llama_model::print_info() const { - const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train); - - auto print_f = [](const std::function & f, uint32_t n) { - bool is_var = false; - - std::vector v; - for (uint32_t i = 0; i < n; ++i) { - v.push_back(f(i)); - if (v[i] != v[0]) { - is_var = true; - } - } - - std::stringstream ss; - - if (is_var) { - ss << "["; - for (uint32_t i = 0; i < n; ++i) { - ss << v[i]; - if (i < n - 1) { - ss << ", "; - } - } - ss << "]"; - } else { - ss << v[0]; - } - - return ss.str(); - }; - - // hparams - LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); - LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); - LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); - - if (!hparams.vocab_only) { - LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); - LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full); - LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); - LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); - LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full); - LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full); - LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); - LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); - LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); - LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); - LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); - LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); - LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); - LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); - LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); - LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); - LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); - LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); - LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); - LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); - LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa); - LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa); - LLAMA_LOG_INFO("%s: n_embd_head_k_swa = %u\n", __func__, hparams.n_embd_head_k_swa); - LLAMA_LOG_INFO("%s: n_embd_head_v_swa = %u\n", __func__, hparams.n_embd_head_v_swa); - LLAMA_LOG_INFO("%s: n_rot_swa = %u\n", __func__, hparams.n_rot_swa); - } - LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); - LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); - LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); - // MRoPE (Multi-axis Rotary Position Embedding) sections - if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { - LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); - } - if (!classifier_labels.empty()) { - LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); - - size_t i = 0; - for (const auto & label : classifier_labels) { - LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); - } - } - - if (arch == LLM_ARCH_MAMBA || - arch == LLM_ARCH_MAMBA2 || - arch == LLM_ARCH_JAMBA || - arch == LLM_ARCH_FALCON_H1 || - arch == LLM_ARCH_PLAMO2 || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_QWEN3NEXT || - arch == LLM_ARCH_QWEN35 || - arch == LLM_ARCH_QWEN35MOE || - arch == LLM_ARCH_NEMOTRON_H || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); - } - - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); - if (pimpl->n_elements >= 1e12) { - LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); - } else if (pimpl->n_elements >= 1e9) { - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); - } else if (pimpl->n_elements >= 1e6) { - LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); - } else { - LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); - } - - // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); - - if (arch == LLM_ARCH_DEEPSEEK) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - } - - if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); - LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); - LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); - LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } - - if (arch == LLM_ARCH_QWEN2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } - - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - } - - if (arch == LLM_ARCH_MINICPM || - arch == LLM_ARCH_GRANITE || - arch == LLM_ARCH_GRANITE_MOE || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); - LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); - LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } - - if (arch == LLM_ARCH_BAILINGMOE) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - } - - if (arch == LLM_ARCH_BAILINGMOE2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); - } - - if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } - - if (arch == LLM_ARCH_GROVEMOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); - LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); - LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); - } - } - - vocab.print_info(); -} - -ggml_backend_dev_t llama_model::dev_layer(int il) const { - return pimpl->dev_layer.at(il).dev; -} - -ggml_backend_dev_t llama_model::dev_output() const { - return pimpl->dev_output.dev; -} - -template -static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead()*8, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ggml_context_ptr ctx { ggml_init(params) }; - if (!ctx) { - throw std::runtime_error(format("failed to create ggml context")); - } - - ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) }; - ggml_tensor * op_tensor = fn(ctx.get()); - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (op_tensor->src[i] != nullptr) { - assert(op_tensor->src[i]->buffer == nullptr); - op_tensor->src[i]->buffer = buf.get(); - } - } - - bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); - - return op_supported; -} - -template -static ggml_backend_buffer_type_t select_buft(const buft_list_t & buft_list, const F & fn) { - for (const auto & cur : buft_list) { - ggml_backend_dev_t cur_dev = cur.first; - ggml_backend_buffer_type_t cur_buft = cur.second; - if (buft_supported(cur_buft, cur_dev, fn)) { - return cur_buft; - } - } - - throw std::runtime_error(format("no suitable buffer type found")); -} - -ggml_backend_buffer_type_t llama_model::select_buft(int il) const { - return ::select_buft( - *pimpl->dev_layer.at(il).buft_list, - [&](ggml_context * ctx) { - ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); - ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); - return ggml_add(ctx, cur, layer_dir); - }); -} - -bool llama_model::has_tensor_overrides() const { - return pimpl->has_tensor_overrides; -} - -const ggml_tensor * llama_model::get_tensor(const char * name) const { - auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(), - [name](const std::pair & it) { - return it.first == name; - }); - if (it == tensors_by_name.end()) { - return nullptr; - } - - return it->second; -} - -float llama_model::get_rope_freq_base (const llama_cparams & cparams, int il) const { - return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; -} - -float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const { - return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; -} - -ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const { - const uint32_t n_ctx_seq = cparams.n_ctx_seq; - - // choose long/short freq factors based on the context size - if (layers[il].rope_freqs != nullptr) { - return layers[il].rope_freqs; - } - - if (n_ctx_seq > hparams.n_ctx_orig_yarn) { - return layers[il].rope_long; - } - - return layers[il].rope_short; -} - -llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const { - llama_memory_i * res; - - switch (arch) { - // Models that need specific instantiation should be handled in the - // switch statement - case LLM_ARCH_BERT: - case LLM_ARCH_JINA_BERT_V2: - case LLM_ARCH_JINA_BERT_V3: - case LLM_ARCH_NOMIC_BERT: - case LLM_ARCH_NOMIC_BERT_MOE: - case LLM_ARCH_NEO_BERT: - case LLM_ARCH_EUROBERT: - case LLM_ARCH_WAVTOKENIZER_DEC: - case LLM_ARCH_MODERN_BERT: - case LLM_ARCH_GEMMA_EMBEDDING: - case LLM_ARCH_DREAM: - case LLM_ARCH_LLADA: - case LLM_ARCH_LLADA_MOE: - case LLM_ARCH_RND1: - { - res = nullptr; - } break; - // Models that need standard caching should rely on recurrent/hybrid - // checks - default: - { - if (llm_arch_is_recurrent(arch)) { - res = new llama_memory_recurrent( - *this, - GGML_TYPE_F32, - GGML_TYPE_F32, - cparams.offload_kqv, - std::max((uint32_t) 1, cparams.n_seq_max), - cparams.n_seq_max, - nullptr); - } else if (llm_arch_is_hybrid(arch)) { - // The main difference between hybrid architectures is the - // layer filters, so pick the right one here - llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; - llama_memory_hybrid::layer_filter_cb filter_recr = nullptr; - if (arch == LLM_ARCH_FALCON_H1) { - filter_attn = [&](int32_t) { return true; }; - filter_recr = [&](int32_t) { return true; }; - } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { - filter_attn = [&](int32_t il) { - return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; - }; - filter_recr = [&](int32_t il) { - return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; - }; - } - - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - // Use hybrid-iswa for hybrid models with SWA - res = new llama_memory_hybrid_iswa( - /* model */ *this, - /* attn_type_k */ params.type_k, - /* attn_type_v */ params.type_v, - /* attn_v_trans */ !cparams.flash_attn, - /* attn_swa_full */ params.swa_full, - /* attn_kv_size */ cparams.n_ctx_seq, - /* attn_n_ubatch */ cparams.n_ubatch, - /* attn_n_pad */ 1, - /* recurrent_type_r */ GGML_TYPE_F32, - /* recurrent_type_s */ GGML_TYPE_F32, - /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), - /* n_seq_max */ cparams.n_seq_max, - /* offload */ cparams.offload_kqv, - /* unified */ cparams.kv_unified, - /* filter_attn */ std::move(filter_attn), - /* filter_recr */ std::move(filter_recr)); - } else { - res = new llama_memory_hybrid( - /* model */ *this, - /* attn_type_k */ params.type_k, - /* attn_type_v */ params.type_v, - /* attn_v_trans */ !cparams.flash_attn, - /* attn_kv_size */ cparams.n_ctx_seq, - /* attn_n_pad */ 1, - /* attn_n_swa */ hparams.n_swa, - /* attn_swa_type */ hparams.swa_type, - /* recurrent_type_k */ GGML_TYPE_F32, - /* recurrent_type_v */ GGML_TYPE_F32, - /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), - /* n_seq_max */ cparams.n_seq_max, - /* offload */ cparams.offload_kqv, - /* unified */ cparams.kv_unified, - /* filter_attn */ std::move(filter_attn), - /* filter_recr */ std::move(filter_recr)); - } - } else { - llama_memory_i::layer_reuse_cb reuse = nullptr; - - if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { - reuse = [&](int32_t il) { - if (il >= (int32_t) hparams.n_layer_kv_from_start) { - return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); - } - - return -1; - }; - } - - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - GGML_ASSERT(hparams.is_swa_any()); - - res = new llama_kv_cache_iswa( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - params.swa_full, - cparams.kv_unified, - cparams.n_ctx_seq, - cparams.n_seq_max, - cparams.n_ubatch, - 1, - nullptr, - reuse); - } else { - GGML_ASSERT(!hparams.is_swa_any()); - - res = new llama_kv_cache( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - cparams.kv_unified, - cparams.n_ctx_seq, - cparams.n_seq_max, - 1, - hparams.n_swa, - hparams.swa_type, - nullptr, - nullptr); - } - } - } - } - - return res; -} - -ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { - std::unique_ptr llm; - - switch (arch) { - case LLM_ARCH_LLAMA: - { - llm = std::make_unique>(*this, params); - } break; - case LLM_ARCH_LLAMA4: - { - if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique>(*this, params); - } else { - llm = std::make_unique>(*this, params); - } - } break; - case LLM_ARCH_LLAMA_EMBED: - { - llm = std::make_unique>(*this, params); - } break; - case LLM_ARCH_MAINCODER: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_DECI: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_BAICHUAN: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_FALCON: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GROK: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_STARCODER: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_REFACT: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_BERT: - case LLM_ARCH_JINA_BERT_V2: - case LLM_ARCH_JINA_BERT_V3: - case LLM_ARCH_NOMIC_BERT: - case LLM_ARCH_NOMIC_BERT_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_MODERN_BERT: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_NEO_BERT: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_EUROBERT: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_BLOOM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_MPT: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_STABLELM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_DREAM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_LLADA: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_LLADA_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_RND1: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN2VL: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN2MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN3: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN3MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN3VL: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN3VLMOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_PHI2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_PHI3: - case LLM_ARCH_PHIMOE: - { - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique> (*this, params); - } else { - llm = std::make_unique>(*this, params); - } - } break; - case LLM_ARCH_PLAMO: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_PLAMO2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_PLAMO3: - { - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique> (*this, params); - } else { - llm = std::make_unique>(*this, params); - } - } break; - case LLM_ARCH_GPT2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_CODESHELL: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_ORION: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_INTERNLM2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_MINICPM3: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GEMMA: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GEMMA2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GEMMA3: - { - if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { - llm = std::make_unique>(*this, params); - } else { - llm = std::make_unique>(*this, params); - } - } break; - case LLM_ARCH_GEMMA3N: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GEMMA4: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GEMMA_EMBEDDING: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_STARCODER2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_MAMBA: - case LLM_ARCH_MAMBA2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_JAMBA: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_XVERSE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_COMMAND_R: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_COHERE2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_DBRX: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_OLMO: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_OLMO2: - { - if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { - llm = std::make_unique>(*this, params); - } else { - llm = std::make_unique>(*this, params); - } - } break; - case LLM_ARCH_OLMOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_OPENELM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GPTNEOX: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_ARCTIC: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_DEEPSEEK: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_DEEPSEEK2: - case LLM_ARCH_DEEPSEEK2OCR: - case LLM_ARCH_GLM_DSA: - case LLM_ARCH_MISTRAL4: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_CHATGLM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GLM4: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GLM4_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_BITNET: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_T5: - { - switch (params.gtype) { - case LLM_GRAPH_TYPE_ENCODER: - llm = std::make_unique>(*this, params); - break; - case LLM_GRAPH_TYPE_DEFAULT: - case LLM_GRAPH_TYPE_DECODER: - llm = std::make_unique>(*this, params); - break; - default: - GGML_ABORT("invalid graph type"); - }; - } break; - case LLM_ARCH_T5ENCODER: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_JAIS: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_JAIS2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_NEMOTRON: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_NEMOTRON_H: - case LLM_ARCH_NEMOTRON_H_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_EXAONE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_EXAONE4: - { - if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { - llm = std::make_unique>(*this, params); - } else { - llm = std::make_unique>(*this, params); - } - } break; - case LLM_ARCH_EXAONE_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_RWKV6: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_RWKV6QWEN2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_RWKV7: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_ARWKV7: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GRANITE: - case LLM_ARCH_GRANITE_MOE: - case LLM_ARCH_MINICPM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_GRANITE_HYBRID: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_CHAMELEON: - { - llm = std::make_unique(*this, params); - } break; + // choose long/short freq factors based on the context size + if (layers[il].rope_freqs != nullptr) { + return layers[il].rope_freqs; + } + + if (n_ctx_seq > hparams.n_ctx_orig_yarn) { + return layers[il].rope_long; + } + + return layers[il].rope_short; +} + +llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const { + llama_memory_i * res; + + switch (arch) { + // Models that need specific instantiation should be handled in the + // switch statement + case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_JINA_BERT_V3: + case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_NEO_BERT: + case LLM_ARCH_EUROBERT: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_MODERN_BERT: + case LLM_ARCH_GEMMA_EMBEDDING: + case LLM_ARCH_DREAM: + case LLM_ARCH_LLADA: + case LLM_ARCH_LLADA_MOE: + case LLM_ARCH_RND1: { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_PLM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_BAILINGMOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_BAILINGMOE2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_SEED_OSS: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_DOTS1: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_ARCEE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_AFMOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_ERNIE4_5: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_ERNIE4_5_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_PADDLEOCR: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_HUNYUAN_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_HUNYUAN_VL: - case LLM_ARCH_HUNYUAN_DENSE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_SMOLLM3: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_OPENAI_MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_FALCON_H1: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_LFM2: - case LLM_ARCH_LFM2MOE: - { - if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { - llm = std::make_unique>(*this, params); - } else { - llm = std::make_unique>(*this, params); - } + res = nullptr; } break; - case LLM_ARCH_SMALLTHINKER: + // Models that need standard caching should rely on recurrent/hybrid + // checks + default: { - if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { - llm = std::make_unique> (*this, params); + if (llm_arch_is_recurrent(arch)) { + res = new llama_memory_recurrent( + *this, + GGML_TYPE_F32, + GGML_TYPE_F32, + cparams.offload_kqv, + std::max((uint32_t) 1, cparams.n_seq_max), + cparams.n_seq_max, + nullptr); + } else if (llm_arch_is_hybrid(arch)) { + // The main difference between hybrid architectures is the + // layer filters, so pick the right one here + llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; + llama_memory_hybrid::layer_filter_cb filter_recr = nullptr; + if (arch == LLM_ARCH_FALCON_H1) { + filter_attn = [&](int32_t) { return true; }; + filter_recr = [&](int32_t) { return true; }; + } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { + filter_attn = [&](int32_t il) { + return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + }; + filter_recr = [&](int32_t il) { + return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + }; + } + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + // Use hybrid-iswa for hybrid models with SWA + res = new llama_memory_hybrid_iswa( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_swa_full */ params.swa_full, + /* attn_kv_size */ cparams.n_ctx_seq, + /* attn_n_ubatch */ cparams.n_ubatch, + /* attn_n_pad */ 1, + /* recurrent_type_r */ GGML_TYPE_F32, + /* recurrent_type_s */ GGML_TYPE_F32, + /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } else { + res = new llama_memory_hybrid( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_kv_size */ cparams.n_ctx_seq, + /* attn_n_pad */ 1, + /* attn_n_swa */ hparams.n_swa, + /* attn_swa_type */ hparams.swa_type, + /* recurrent_type_k */ GGML_TYPE_F32, + /* recurrent_type_v */ GGML_TYPE_F32, + /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } } else { - llm = std::make_unique>(*this, params); + llama_memory_i::layer_reuse_cb reuse = nullptr; + + if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { + reuse = [&](int32_t il) { + if (il >= (int32_t) hparams.n_layer_kv_from_start) { + return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); + } + + return -1; + }; + } + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + GGML_ASSERT(hparams.is_swa_any()); + + res = new llama_kv_cache_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + nullptr, + reuse); + } else { + GGML_ASSERT(!hparams.is_swa_any()); + + res = new llama_kv_cache( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + 1, + hparams.n_swa, + hparams.swa_type, + nullptr, + nullptr); + } } - } break; - case LLM_ARCH_GROVEMOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_APERTUS: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_MINIMAX_M2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_COGVLM: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_PANGU_EMBED: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN3NEXT: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN35: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_QWEN35MOE: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_MISTRAL3: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_MIMO2: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_KIMI_LINEAR: - { - llm = std::make_unique(*this, params); - } break; - case LLM_ARCH_STEP35: - { - llm = std::make_unique(*this, params); - } break; - default: - GGML_ABORT("fatal error"); + } } + return res; +} + +ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { + std::unique_ptr llm = build_arch_graph(params); + // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b, cls_norm); @@ -9487,3 +2472,43 @@ ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int } return model->devices[i].dev; } + +// +// llama_model_base +// + +llama_model_base::llama_model_base(const struct llama_model_params & params) : llama_model(params), model(this), tn(model->arch), + TENSOR_DUPLICATED (llama_model_loader::TENSOR_DUPLICATED), + TENSOR_NOT_REQUIRED (llama_model_loader::TENSOR_NOT_REQUIRED), + TENSOR_SKIP (llama_model_loader::TENSOR_SKIP), + TENSOR_SKIP_IF_VIRTUAL(llama_model_loader::TENSOR_SKIP_IF_VIRTUAL) {} + +ggml_tensor * llama_model_base::create_tensor(const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) { + GGML_ASSERT(ml != nullptr); + return create_tensor(*ml, tn, ne, flags); +} + +void llama_model_base::create_tensor_gate_up_exps(llama_layer & layer, int bid, int64_t n_embd_, int64_t n_ff_, int64_t n_expert_, int flags) { + layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", bid), {n_embd_, n_ff_ * 2, n_expert_}, TENSOR_NOT_REQUIRED); + if (layer.ffn_gate_up_exps == nullptr) { + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); + } +} + +void llama_model_base::create_tensor_qkv(llama_layer & layer, int bid, + int64_t n_embd_, int64_t n_embd_q_, int64_t n_embd_k_, int64_t n_embd_v_, + int flags) { + const int64_t n_embd_qkv = n_embd_q_ + n_embd_k_ + n_embd_v_; + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", bid), {n_embd_, n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + if (layer.wqkv) { + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", bid), {n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + } else { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", bid), {n_embd_, n_embd_q_}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", bid), {n_embd_, n_embd_k_}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", bid), {n_embd_, n_embd_v_}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", bid), {n_embd_q_}, TENSOR_NOT_REQUIRED); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", bid), {n_embd_k_}, TENSOR_NOT_REQUIRED); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); + } +} diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 5f101bd6374..d63c689185a 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -577,14 +577,8 @@ struct llama_model { int64_t t_load_us = 0; int64_t t_start_us = 0; - explicit llama_model(const struct llama_model_params & params); - ~llama_model(); - - void load_stats (llama_model_loader & ml); - void load_arch (llama_model_loader & ml); - void load_hparams(llama_model_loader & ml); - void load_vocab (llama_model_loader & ml); - bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback + explicit llama_model(const llama_model_params & params); + virtual ~llama_model(); std::string arch_name() const; std::string type_name() const; @@ -620,21 +614,94 @@ struct llama_model { ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const; - // TODO: move this to new llm_arch_model_i interface llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const; - // TODO: move this to new llm_arch_model_i interface ggml_cgraph * build_graph(const llm_graph_params & params) const; -private: + virtual void load_stats (llama_model_loader & ml) = 0; + virtual void load_hparams(llama_model_loader & ml) = 0; + virtual void load_vocab (llama_model_loader & ml) = 0; + virtual bool load_tensors(llama_model_loader & ml) = 0; // returns false if cancelled by progress_callback + + // model must define these + virtual void load_arch_hparams(llama_model_loader & ml) = 0; + virtual void load_arch_tensors(llama_model_loader & ml) = 0; + virtual std::unique_ptr build_arch_graph(const llm_graph_params & params) const = 0; + +protected: llama_model_params params; struct impl; std::unique_ptr pimpl; }; +llama_model * llama_model_create(llm_arch arch, const llama_model_params & params); +llama_model * llama_model_create(llama_model_loader & ml, const llama_model_params & params); + +// model must inherit from this +struct llama_model_base : public llama_model { + friend struct llama_model; + + llama_model * model; + llama_model_loader * ml = nullptr; + const LLM_TN tn; + + // llama_model_loader is not yet defined at this point, so we will set it after construction + const int TENSOR_DUPLICATED; + const int TENSOR_NOT_REQUIRED; + const int TENSOR_SKIP; + const int TENSOR_SKIP_IF_VIRTUAL; + + explicit llama_model_base(const llama_model_params & params); + virtual ~llama_model_base() = default; + + ggml_tensor * create_tensor(llama_model_loader & ml, const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags); + + // convenience overload of create_tensor that doesn't require llama_model_loader + ggml_tensor * create_tensor(const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags); + + // helper: try merged gate_up_exps first, fall back to separate gate and up + void create_tensor_gate_up_exps(llama_layer & layer, int bid, int64_t n_embd_, + int64_t n_ff_, int64_t n_expert_, int flags); + + // helper: try to load merged qkv first, fall back to separate q, k, v + void create_tensor_qkv(llama_layer & layer, int bid, + int64_t n_embd_, int64_t n_embd_q_, int64_t n_embd_k_, int64_t n_embd_v_, + int flags); + + void load_stats (llama_model_loader & ml) override; + void load_hparams(llama_model_loader & ml) override; + void load_vocab (llama_model_loader & ml) override; + bool load_tensors(llama_model_loader & ml) override; + + // model must define these + void load_arch_hparams(llama_model_loader & ml) override = 0; + void load_arch_tensors(llama_model_loader & ml) override = 0; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override = 0; +}; + const char * llm_type_name(llm_type type); +// convenience macro for loading local variables for load_tensors() in llama_model_base +// note: cast to int64_t since we will use these for the tensor dimensions +#define LLAMA_LOAD_LOCALS \ + const int n_layer = hparams.n_layer; GGML_UNUSED(n_layer); \ + const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \ + const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \ + const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \ + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); GGML_UNUSED(n_embd_k_gqa); \ + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); GGML_UNUSED(n_embd_v_gqa); \ + const int64_t n_embd_head_k = hparams.n_embd_head_k(); GGML_UNUSED(n_embd_head_k); \ + const int64_t n_embd_head_v = hparams.n_embd_head_v(); GGML_UNUSED(n_embd_head_v); \ + const int64_t n_ff = hparams.n_ff(); GGML_UNUSED(n_ff); \ + const int64_t n_embd_gqa = n_embd_v_gqa; GGML_UNUSED(n_embd_gqa); \ + const int64_t n_vocab = vocab.n_tokens(); GGML_UNUSED(n_vocab); \ + const int64_t n_token_types = vocab.n_token_types(); GGML_UNUSED(n_token_types); \ + const int64_t n_rot = hparams.n_rot(); GGML_UNUSED(n_rot); \ + const int64_t n_expert = hparams.n_expert; GGML_UNUSED(n_expert); \ + const int64_t n_expert_used = hparams.n_expert_used; GGML_UNUSED(n_expert_used); \ + const int64_t n_ctx_train = hparams.n_ctx_train; GGML_UNUSED(n_ctx_train); + // For internal test use // TODO: remove const std::vector> & llama_internal_get_tensor_map(const llama_model * model); diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 2f0f70b73b6..43e05c3d56f 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -882,13 +882,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: fname_inp, splits, /*file*/ nullptr, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching - llama_model model(llama_model_default_params()); + auto mparams = llama_model_default_params(); + std::unique_ptr model_ptr(llama_model_create(ml, mparams)); - model.load_arch (ml); - model.load_hparams(ml); - model.load_stats (ml); + auto * model = dynamic_cast(model_ptr.get()); + if (model == nullptr) { + GGML_ABORT("fatal error: model does not implement llama_model_base"); + } + + model->load_hparams(ml); + model->load_stats (ml); - quantize_state_impl qs(model, params); + quantize_state_impl qs(*model, params); if (params->only_copy) { ftype = ml.ftype; @@ -1023,7 +1028,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } gguf_add_tensor(ctx_outs[i_split].get(), tensor); - metadata[i].allows_quantization = tensor_allows_quantization(params, model.arch, tensor); + metadata[i].allows_quantization = tensor_allows_quantization(params, model->arch, tensor); if (metadata[i].allows_quantization) { metadata[i].target_type = llama_tensor_get_type(qs, params, tensor, default_type, metadata[i]); @@ -1331,9 +1336,9 @@ void llama_quant_free(quantize_state_impl * qs) { llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc) { struct llama_model_params mparams = llama_model_default_params(); - auto * model = new llama_model(mparams); - - model->arch = llm_arch_from_string(desc->architecture); + auto arch = llm_arch_from_string(desc->architecture); + auto * model = llama_model_create(arch, mparams); + model->arch = arch; // infer llm_type: only LLM_TYPE_70B matters for quantization logic if (model->arch == LLM_ARCH_LLAMA && desc->n_layer == 80 && desc->n_head != desc->n_head_kv) { diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 163f222ef61..f43cf546ca0 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -503,6 +503,14 @@ struct llm_tokenizer_bpe : llm_tokenizer { }; byte_encode = false; // uses raw UTF-8, not GPT-2 byte encoding break; + case LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE: + // Sarvam uses SPM-style BPE (same shape as Gemma4): spaces replaced with U+2581 + // by the normalizer, BPE merges over the whole text on raw UTF-8. + regex_exprs = { + "[^\\n]+|[\\n]+", + }; + byte_encode = false; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -2005,6 +2013,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "gemma4") { pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4; escape_whitespaces = true; + } else if ( + tokenizer_pre == "sarvam-moe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE; + escape_whitespaces = true; + clean_spaces = false; } else if ( tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-code" || diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index dd38f45d3a2..8b040b912e2 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -59,6 +59,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, + LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 51, }; struct LLM_KV; diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index e9c3028585d..dfe30ce8f61 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -71,12 +71,18 @@ bool llama_supports_mlock(void) { } bool llama_supports_gpu_offload(void) { + if (!ggml_backend_reg_count()) { + ggml_backend_load_all(); + } return ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU) != nullptr || ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU) != nullptr || llama_supports_rpc(); } bool llama_supports_rpc(void) { + if (!ggml_backend_reg_count()) { + ggml_backend_load_all(); + } return ggml_backend_reg_by_name("RPC") != nullptr; } @@ -89,6 +95,10 @@ void llama_backend_init(void) { struct ggml_context * ctx = ggml_init(params); ggml_free(ctx); } + + if (!ggml_backend_reg_count()) { + ggml_backend_load_all(); + } } void llama_numa_init(enum ggml_numa_strategy numa) { @@ -111,113 +121,8 @@ int64_t llama_time_us(void) { return ggml_time_us(); } -// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback -static int llama_model_load(struct gguf_context * metadata, llama_model_set_tensor_data_t set_tensor_data, void * set_tensor_data_ud, - const std::string & fname, std::vector & splits, FILE * file, llama_model & model, llama_model_params & params) { - // loading time will be recalculated after the first eval, so - // we take page faults deferred by mmap() into consideration - model.t_load_us = 0; - time_meas tm(model.t_load_us); - - model.t_start_us = tm.t_start_us; - - try { - llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, file, params.use_mmap, params.use_direct_io, - params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); - - ml.print_info(); - - model.hparams.vocab_only = params.vocab_only; - model.hparams.no_alloc = params.no_alloc; - - try { - model.load_arch(ml); - } catch(const std::exception & e) { - throw std::runtime_error("error loading model architecture: " + std::string(e.what())); - } - try { - model.load_hparams(ml); - } catch(const std::exception & e) { - throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what())); - } - if (model.arch == LLM_ARCH_CLIP) { - throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead"); - } - try { - model.load_vocab(ml); - } catch(const std::exception & e) { - throw std::runtime_error("error loading model vocabulary: " + std::string(e.what())); - } - - model.load_stats(ml); - model.print_info(); - - if (params.vocab_only) { - LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); - return 0; - } - - if (!model.load_tensors(ml)) { - return -2; - } - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); - return -1; - } - - return 0; -} - -static struct llama_model * llama_model_load_from_file_impl( - struct gguf_context * metadata, - llama_model_set_tensor_data_t set_tensor_data, - void * set_tensor_data_ud, - const std::string & path_model, - std::vector & splits, - FILE * file, - struct llama_model_params params) { - { - int n_sources_defined = 0; - if (metadata != nullptr) { - n_sources_defined++; - } - if (!path_model.empty()) { - n_sources_defined++; - } - if (file != nullptr) { - n_sources_defined++; - } - if (n_sources_defined != 1) { - LLAMA_LOG_ERROR("%s: exactly one out metadata, path_model, and file must be defined\n", __func__); - return nullptr; - } - } - ggml_time_init(); - - if (!params.vocab_only && ggml_backend_reg_count() == 0) { - LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use ggml_backend_load() or ggml_backend_load_all() to load a backend before calling this function\n", __func__); - return nullptr; - } - - unsigned cur_percentage = 0; - if (params.progress_callback == NULL) { - params.progress_callback_user_data = &cur_percentage; - params.progress_callback = [](float progress, void * ctx) { - unsigned * cur_percentage_p = (unsigned *) ctx; - unsigned percentage = (unsigned) (100 * progress); - while (percentage > *cur_percentage_p) { - *cur_percentage_p = percentage; - LLAMA_LOG_CONT("."); - if (percentage >= 100) { - LLAMA_LOG_CONT("\n"); - } - } - return true; - }; - } - - llama_model * model = new llama_model(params); - +// returns true on success +static bool llama_prepare_model_devices(const llama_model_params & params, llama_model * model) { // create list of devices to use with this model if (params.devices) { if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) { @@ -227,7 +132,7 @@ static struct llama_model * llama_model_load_from_file_impl( } if (n_devs == 0) { LLAMA_LOG_ERROR("%s: LLAMA_SPLIT_MODE_TENSOR needs >= 1 devices\n", __func__); - return nullptr; + return false; } LLAMA_LOG_INFO("%s: creating a Meta device with %zu devices\n", __func__, n_devs); for (size_t i = 0; i < n_devs; ++i) { @@ -265,7 +170,7 @@ static struct llama_model * llama_model_load_from_file_impl( } if (devs.empty()) { LLAMA_LOG_ERROR("%s: LLAMA_SPLIT_MODE_TENSOR needs >= 1 devices\n", __func__); - return nullptr; + return false; } LLAMA_LOG_INFO("%s: creating a Meta device for tensor parallelism from %zu devices:\n", __func__, devs.size()); @@ -347,8 +252,7 @@ static struct llama_model * llama_model_load_from_file_impl( } else { if (params.main_gpu >= (int)model->devices.size()) { LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %zu)\n", __func__, params.main_gpu, model->devices.size()); - llama_model_free(model); - return nullptr; + return false; } llama_device main_gpu = model->devices[params.main_gpu]; model->devices.clear(); @@ -365,7 +269,121 @@ static struct llama_model * llama_model_load_from_file_impl( props.memory_free/1024/1024); } - const int status = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, file, *model, params); + return true; +} + +// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback +static std::pair llama_model_load(struct gguf_context * metadata, llama_model_set_tensor_data_t set_tensor_data, void * set_tensor_data_ud, + const std::string & fname, std::vector & splits, FILE * file, llama_model_params & params) { + try { + llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, file, params.use_mmap, params.use_direct_io, + params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); + + ml.print_info(); + std::unique_ptr model_ptr(llama_model_create(ml, params)); + + bool ok = llama_prepare_model_devices(params, model_ptr.get()); + if (!ok) { + return {-1, nullptr}; + } + + auto * model = dynamic_cast(model_ptr.get()); + if (model == nullptr) { + GGML_ABORT("fatal error: model does not implement llama_model_base"); + } + + // loading time will be recalculated after the first eval, so + // we take page faults deferred by mmap() into consideration + model->t_load_us = 0; + time_meas tm(model->t_load_us); + + model->t_start_us = tm.t_start_us; + + model->hparams.vocab_only = params.vocab_only; + model->hparams.no_alloc = params.no_alloc; + + try { + model->load_hparams(ml); + } catch(const std::exception & e) { + throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what())); + } + if (model->arch == LLM_ARCH_CLIP) { + throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead"); + } + try { + model->load_vocab(ml); + } catch(const std::exception & e) { + throw std::runtime_error("error loading model vocabulary: " + std::string(e.what())); + } + + model->load_stats(ml); + model->print_info(); + + if (params.vocab_only) { + LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); + return {0, model_ptr.release()}; + } + + if (!model->load_tensors(ml)) { + return {-2, nullptr}; + } + + return {0, model_ptr.release()}; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); + return {-1, nullptr}; + } +} + +static struct llama_model * llama_model_load_from_file_impl( + struct gguf_context * metadata, + llama_model_set_tensor_data_t set_tensor_data, + void * set_tensor_data_ud, + const std::string & path_model, + std::vector & splits, + FILE * file, + struct llama_model_params params) { + { + int n_sources_defined = 0; + if (metadata != nullptr) { + n_sources_defined++; + } + if (!path_model.empty()) { + n_sources_defined++; + } + if (file != nullptr) { + n_sources_defined++; + } + if (n_sources_defined != 1) { + LLAMA_LOG_ERROR("%s: exactly one out metadata, path_model, and file must be defined\n", __func__); + return nullptr; + } + } + ggml_time_init(); + + if (!params.vocab_only && ggml_backend_reg_count() == 0) { + LLAMA_LOG_ERROR("%s: no backends are loaded. hint: use ggml_backend_load() or ggml_backend_load_all() to load a backend before calling this function\n", __func__); + return nullptr; + } + + unsigned cur_percentage = 0; + if (params.progress_callback == NULL) { + params.progress_callback_user_data = &cur_percentage; + params.progress_callback = [](float progress, void * ctx) { + unsigned * cur_percentage_p = (unsigned *) ctx; + unsigned percentage = (unsigned) (100 * progress); + while (percentage > *cur_percentage_p) { + *cur_percentage_p = percentage; + LLAMA_LOG_CONT("."); + if (percentage >= 100) { + LLAMA_LOG_CONT("\n"); + } + } + return true; + }; + } + + const auto [status, model] = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, file, params); GGML_ASSERT(status <= 0); if (status < 0) { if (status == -1) { @@ -374,7 +392,9 @@ static struct llama_model * llama_model_load_from_file_impl( LLAMA_LOG_INFO("%s: cancelled model load\n", __func__); } - llama_model_free(model); + if (model) { + llama_model_free(model); + } return nullptr; } diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index eb869814097..2ea226726ad 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -864,6 +864,9 @@ extern "C" { // work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba) #define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 +// keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load) +#define LLAMA_STATE_SEQ_FLAGS_ON_DEVICE 2 + typedef uint32_t llama_state_seq_flags; LLAMA_API size_t llama_state_seq_get_size_ext( diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index 2790b12111d..602e3176afd 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -1,6 +1,112 @@ #include "models.h" -llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_afmoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + // Set up interleaved sliding window attention (ISWA) + // Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4) + if (hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + // Default to sigmoid if not set + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + switch (hparams.n_layer) { + case 56: type = LLM_TYPE_6B; break; + case 32: type = LLM_TYPE_26B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_afmoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // dual attention normalization + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + // attention projections + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // Q/K normalization + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + // attention gating + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + + // dual ffn normalization + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + if (static_cast(i) >= hparams.n_layer_dense_lead) { + // MoE layers + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + + // grouped expert weights + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // shared expert + if (n_expert_shared > 0) { + const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + } + } else { + // Dense layers + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } +} + +std::unique_ptr llama_model_afmoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_afmoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index af44cea6054..136ff702957 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -1,6 +1,62 @@ #include "models.h" -llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_apertus::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_apertus::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + + // Q and K layernorms for Apertus + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_apertus::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_apertus::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index 2e71f5d9e2a..70e86d41130 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -1,6 +1,51 @@ #include "models.h" -llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_arcee::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Arcee uses the same structure as Llama + switch (hparams.n_layer) { + case 36: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_arcee::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_arcee::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_arcee::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index f8ca6aff6ab..d8653a44639 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -1,6 +1,59 @@ #include "models.h" -llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_arctic::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_expert == 128) { + switch (hparams.n_layer) { + case 35: type = LLM_TYPE_10B_128x3_66B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } else { + type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_arctic::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } +} + +std::unique_ptr llama_model_arctic::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_arctic::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/arwkv7.cpp b/examples/talk-llama/models/arwkv7.cpp index 107a3bef8da..79aa8c90899 100644 --- a/examples/talk-llama/models/arwkv7.cpp +++ b/examples/talk-llama/models/arwkv7.cpp @@ -1,7 +1,123 @@ #include "models.h" +void llama_model_arwkv7::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); + ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); + ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); + ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer) { + case 12: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_190M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_450M; break; + case 2048: type = LLM_TYPE_1_5B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 28: + switch (hparams.n_embd) { + case 1536: type = LLM_TYPE_1_5B; break; + case 3584: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_2_9B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: + switch (hparams.n_embd) { + case 4096: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_arwkv7::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int n_lora_decay = hparams.n_lora_decay; + const int n_lora_iclr = hparams.n_lora_iclr; + const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; + const int n_lora_gate = hparams.n_lora_gate; + const int attn_hidden_size = n_embd; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); + + layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); + layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); + + if (i == 0) { + // actually not used + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); + } else { + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); + } + + layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, TENSOR_NOT_REQUIRED); + layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, TENSOR_NOT_REQUIRED); + + try { + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); + } catch(std::runtime_error & e) { + // ARWKV models may not have gate tensors + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); + } + + layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); + + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + +} + +std::unique_ptr llama_model_arwkv7::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} -llm_build_arwkv7::llm_build_arwkv7(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) { +llama_model_arwkv7::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) { GGML_ASSERT(n_embd == hparams.n_embd_r()); ggml_tensor * cur; diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index 2d0d05df485..4e55290e4e5 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -1,6 +1,49 @@ #include "models.h" -llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_baichuan::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + if (type == LLM_TYPE_13B) { + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; + } +} + +void llama_model_baichuan::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_baichuan::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_baichuan::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index 67a7120d622..030dd4f42a4 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -1,6 +1,65 @@ #include "models.h" -llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_bailingmoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_16B; break; + case 88: type = LLM_TYPE_290B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_bailingmoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } +} + +std::unique_ptr llama_model_bailingmoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_bailingmoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index 497b4babd0c..e7fe3d5b45a 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -1,6 +1,100 @@ #include "models.h" -llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) : +void llama_model_bailingmoe2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 20: type = LLM_TYPE_16B_A1B; break; + case 21: type = LLM_TYPE_16B_A1B; break; + case 32: type = LLM_TYPE_100B_A6B; break; + case 33: type = LLM_TYPE_100B_A6B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_bailingmoe2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2"); + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + if (static_cast(i) >= hparams.n_layer_dense_lead) { // MoE layers + const int64_t n_ff_shexp = (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + } else { // Dense layers + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED | flags); + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, flags); + } + } +} + +std::unique_ptr llama_model_bailingmoe2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/bert.cpp b/examples/talk-llama/models/bert.cpp index 7e046cfd2a4..3c28f419ccf 100644 --- a/examples/talk-llama/models/bert.cpp +++ b/examples/talk-llama/models/bert.cpp @@ -1,6 +1,83 @@ #include "models.h" -llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_bert::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 3: + type = LLM_TYPE_17M; break; // bge-micro + case 6: + type = LLM_TYPE_22M; break; // MiniLM-L6 + case 12: + switch (hparams.n_embd) { + case 384: type = LLM_TYPE_33M; break; // MiniLM-L12, bge-small + case 768: type = LLM_TYPE_109M; break; // bge-base + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + type = LLM_TYPE_335M; break; // bge-large + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_bert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_token_types == 0) { + throw std::runtime_error(arch_name() + " model needs to define token type count"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_NOMIC_BERT) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_bert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_bert::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/bitnet.cpp b/examples/talk-llama/models/bitnet.cpp index 71526354ca6..7e8125deec4 100644 --- a/examples/talk-llama/models/bitnet.cpp +++ b/examples/talk-llama/models/bitnet.cpp @@ -1,7 +1,54 @@ #include "models.h" +void llama_model_bitnet::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); -llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_bitnet::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_bitnet::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_bitnet::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index f3b0999bf54..b600fb0c954 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -1,6 +1,68 @@ #include "models.h" -llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_bloom::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 30: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; +} + +void llama_model_bloom::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr llama_model_bloom::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_bloom::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 21deaba1a6d..8510b9e29f8 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -1,8 +1,56 @@ #include "models.h" - #include -llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_chameleon::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default + ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm, false); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_34B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_chameleon::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_chameleon::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_chameleon::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index 7d4a43fdca5..e898eff7939 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -1,7 +1,60 @@ #include "models.h" +void llama_model_chatglm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: { + if (hparams.n_head(0) == 16) { + type = LLM_TYPE_1_5B; + } else { + type = LLM_TYPE_6B; + } + } break; + case 40: { + if (hparams.n_head(0) == 24) { + type = LLM_TYPE_4B; + } else { + type = LLM_TYPE_9B; + } + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_chatglm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } +} + +std::unique_ptr llama_model_chatglm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} -llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_chatglm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index 3ceb5835b85..e9e85d96713 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -1,6 +1,55 @@ #include "models.h" -llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_codeshell::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 42: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_codeshell::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if tok embd is NULL, init from output + if (tok_embd == NULL) { + tok_embd = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr llama_model_codeshell::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_codeshell::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index be3eeeddac7..79236121bd5 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -1,6 +1,55 @@ #include "models.h" -llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) : +void llama_model_cogvlm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_cogvlm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd_head_k * n_head * 3}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.visexp_attn_wqkv = create_tensor(tn(LLM_TENSOR_VISEXP_ATTN_QKV, "weight", i), {n_embd, n_embd_head_k * n_head * 3}, 0); + layer.visexp_attn_wo = create_tensor(tn(LLM_TENSOR_VISEXP_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + layer.visexp_ffn_gate = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.visexp_ffn_down = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.visexp_ffn_up = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_cogvlm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_cogvlm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); diff --git a/examples/talk-llama/models/cohere2-iswa.cpp b/examples/talk-llama/models/cohere2.cpp similarity index 60% rename from examples/talk-llama/models/cohere2-iswa.cpp rename to examples/talk-llama/models/cohere2.cpp index 670b08e7d97..12edbae1094 100644 --- a/examples/talk-llama/models/cohere2-iswa.cpp +++ b/examples/talk-llama/models/cohere2.cpp @@ -1,6 +1,53 @@ #include "models.h" -llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_cohere2::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_cohere2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, + TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + } +} + +std::unique_ptr llama_model_cohere2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_cohere2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index 067961caa08..decb89f547b 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -1,8 +1,48 @@ #include "models.h" +void llama_model_command_r::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_35B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_command_r::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + if (n_layer >= 64){ + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); + } + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_command_r::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} -llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_graph_params & params) : +llama_model_command_r::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index 0e882721807..bce6b04bcf9 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -1,6 +1,50 @@ #include "models.h" -llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_dbrx::load_arch_hparams(llama_model_loader & ml) { +ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); +ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + +switch (hparams.n_layer) { + case 40: type = LLM_TYPE_16x12B; break; + default: type = LLM_TYPE_UNKNOWN; +} + } + +void llama_model_dbrx::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_expert == 0) { + throw std::runtime_error("DBRX model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } +} + +std::unique_ptr llama_model_dbrx::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_dbrx::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index 30272eabd69..9f1a959c32c 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -1,6 +1,82 @@ #include "models.h" -llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_deci::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + case 162: type = LLM_TYPE_405B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deci::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_ff = hparams.n_ff(i); + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_kv = hparams.n_head_kv(i); + + if (n_head_kv == 0 && n_head > 0) { + // linear attention for DeciLMCausalModel + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } + else if (n_head_kv > 0) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + } + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (n_ff > 0) { + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_ff > 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_deci::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_deci::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/deepseek.cpp b/examples/talk-llama/models/deepseek.cpp index 671b72dfead..c7946059662 100644 --- a/examples/talk-llama/models/deepseek.cpp +++ b/examples/talk-llama/models/deepseek.cpp @@ -1,6 +1,77 @@ #include "models.h" -llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_graph_params & params) : +void llama_model_deepseek::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + + switch (hparams.n_ff_exp) { + case 1408: type = LLM_TYPE_16B; break; + case 1792: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deepseek::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } +} + +std::unique_ptr llama_model_deepseek::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_deepseek::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp index 303fc72c610..1fe54adc13e 100644 --- a/examples/talk-llama/models/deepseek2.cpp +++ b/examples/talk-llama/models/deepseek2.cpp @@ -1,6 +1,149 @@ #include "models.h" -llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : +void llama_model_deepseek2::load_arch_hparams(llama_model_loader & ml) { + uint32_t n_vocab = 0; + ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); + + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B + const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + if (!is_lite) { + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + } + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + // for compatibility with existing DeepSeek V2 and V2.5 GGUFs + // that have no expert_gating_func model parameter set + if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) { + // GLM 4.7 Lite + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } else { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } + } + + if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false)) { + // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + // cancel the factor from the convert script + hparams.rope_yarn_log_mul /= 0.1f; + } + + // (optional) temperature tuning - used by mistral-large + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); // FIXME why not use temperature_length? + + hparams.f_attn_temp_offset = 0.0f; + + switch (hparams.n_layer) { + case 27: type = LLM_TYPE_16B; break; + case 47: type = LLM_TYPE_30B_A3B; break; + case 60: type = LLM_TYPE_236B; break; + case 61: type = LLM_TYPE_671B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deepseek2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const bool is_mla = hparams.is_mla(); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + GGML_ASSERT(n_embd_head_qk_nope >= 1); + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + if (q_lora_rank > 0) { + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + } + + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (q_lora_rank > 0) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + if (is_mla) { + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } else { + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } +} + +std::unique_ptr llama_model_deepseek2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_deepseek2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B bool is_ocr = model.arch == LLM_ARCH_DEEPSEEK2OCR; diff --git a/examples/talk-llama/models/deepseek2ocr.cpp b/examples/talk-llama/models/deepseek2ocr.cpp new file mode 100644 index 00000000000..f9e4c98785c --- /dev/null +++ b/examples/talk-llama/models/deepseek2ocr.cpp @@ -0,0 +1,82 @@ +#include "models.h" + +void llama_model_deepseek2ocr::load_arch_hparams(llama_model_loader & ml) { + // similar to deepseek2, but without MLA + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } + + switch (hparams.n_layer) { + case 12: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deepseek2ocr::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + // similar to deepseek2, but without MLA + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // norm + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } +} + +std::unique_ptr llama_model_deepseek2ocr::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 5d1750fedda..93cbcf9d931 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -1,6 +1,76 @@ #include "models.h" -llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_params & params) : +void llama_model_dots1::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + switch (hparams.n_layer) { + case 62: type = LLM_TYPE_142B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_dots1::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_head_k * n_head, n_embd_head_k * n_head, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } +} + +std::unique_ptr llama_model_dots1::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_dots1::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 8e7d9ae64c7..60a3f0ec285 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -1,6 +1,54 @@ #include "models.h" -llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_params & params) : +void llama_model_dream::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // Dream models are primarily 7B with 28 layers + switch (hparams.n_layer) { + case 28: + type = LLM_TYPE_7B; + break; + default: + type = LLM_TYPE_UNKNOWN; + } + // Set non-causal attention for diffusion models + hparams.causal_attn = false; +} + +void llama_model_dream::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_dream::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_dream::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { //copied from qwen2 const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/ernie4-5-moe.cpp b/examples/talk-llama/models/ernie4-5-moe.cpp index fc6a3e17a09..2bd01a2c512 100644 --- a/examples/talk-llama/models/ernie4-5-moe.cpp +++ b/examples/talk-llama/models/ernie4-5-moe.cpp @@ -1,6 +1,10 @@ #include "models.h" -llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params) : +std::unique_ptr llama_model_ernie4_5_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_ernie4_5_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index 033ba409eab..fa989fe92cd 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -1,6 +1,79 @@ #include "models.h" -llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) : +void llama_model_ernie4_5::load_arch_hparams(llama_model_loader & ml) { + // paddleocr need mrope_section + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (arch == LLM_ARCH_ERNIE4_5_MOE) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + } + + switch (hparams.n_layer) { + case 18: type = LLM_TYPE_0_3B; break; + case 28: type = LLM_TYPE_21B_A3B; break; + case 54: type = LLM_TYPE_300B_A47B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_ernie4_5::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (arch == LLM_ARCH_ERNIE4_5_MOE && static_cast(i) >= hparams.n_layer_dense_lead) { // MoE layers + int n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert (if present) + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); + } + } else { // Dense layers + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } +} + +std::unique_ptr llama_model_ernie4_5::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_ernie4_5::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/eurobert.cpp b/examples/talk-llama/models/eurobert.cpp index 43fff4daf3a..ddf13c3028f 100644 --- a/examples/talk-llama/models/eurobert.cpp +++ b/examples/talk-llama/models/eurobert.cpp @@ -1,6 +1,41 @@ #include "models.h" -llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_eurobert::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_layer == 12) { + type = LLM_TYPE_SMALL; // 0.2B + } +} + +void llama_model_eurobert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } +} + +std::unique_ptr llama_model_eurobert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_eurobert::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp index 7b88a31d39d..54bb3ca86b3 100644 --- a/examples/talk-llama/models/exaone-moe.cpp +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -1,6 +1,117 @@ #include "models.h" -llm_build_exaone_moe::llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params) : +void llama_model_exaone_moe::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 128; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_30B_A3B; break; + case 48: + case 49: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_exaone_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : n_ff_exp; + const int64_t head_dim = hparams.n_embd_head_k(); + const int64_t n_qo_dim = n_head * head_dim; + const int64_t n_kv_dim = n_head_kv * head_dim; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, flags); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0) | flags); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // dense layers for first n_layer_dense_lead layers or nextn_predict_layers layers at the end + if (i < (int) hparams.n_layer_dense_lead || (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers)) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); + + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr llama_model_exaone_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k(); diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index 4f845bf4106..75d5f60631c 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -1,6 +1,49 @@ #include "models.h" -llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_params & params) : +void llama_model_exaone::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_exaone::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_exaone::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_exaone::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 34bee3b8fe9..5506e76424d 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -1,7 +1,71 @@ #include "models.h" +void llama_model_exaone4::load_arch_hparams(llama_model_loader & ml) { + if (hparams.n_layer == 64) { // 32B + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_1_2B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_exaone4::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_exaone4::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique>(*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_graph_params & params) : +llama_model_exaone4::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k(); @@ -108,5 +172,5 @@ llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_ } // Explicit template instantiations -template struct llm_build_exaone4; -template struct llm_build_exaone4; +template struct llama_model_exaone4::graph; +template struct llama_model_exaone4::graph; diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index 05accf90fad..d353befdb8e 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -1,6 +1,115 @@ #include "models.h" -llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) : +void llama_model_falcon_h1::load_arch_hparams(llama_model_loader & ml) { + // Common parameters + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // SSM parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true); + + switch (hparams.n_layer) { + case 36: + type = LLM_TYPE_0_5B; break; + case 24: + type = LLM_TYPE_1_5B; break; + case 66: + type = LLM_TYPE_1B; break; + case 32: + type = LLM_TYPE_3B; break; + case 44: + type = LLM_TYPE_7B; break; + case 72: + type = LLM_TYPE_34B; break; + default: + type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_falcon_h1::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + // Common + const int64_t hidden_size = hparams.n_embd; // hidden_size + + // mamba2 Mixer SSM params + const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size + const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups + const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size + const int64_t ssm_intermediate_size = hparams.ssm_d_inner; // TODO expand + const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads + const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size; + const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads; + + // attn params + const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head + const int64_t attn_num_key_value_head = hparams.n_head_kv(0); + + // ffn params + const int64_t ffn_intermediate_size = hparams.n_ff(0); + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, 0); + + // output + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, n_vocab}, TENSOR_NOT_REQUIRED); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + /*SSM LAYERS*/ + // ssm in + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {hidden_size, ssm_projection_size}, 0); + // ssm 1d conv + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {ssm_conv_kernel_size, ssm_conv_dim}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {ssm_conv_dim}, TENSOR_NOT_REQUIRED); + // ssm_dt + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {ssm_num_heads}, 0); + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0); + // ssm_norm + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, TENSOR_NOT_REQUIRED); + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0); + + /*ATTENTION LAYERS*/ + // attention layers (with optional bias) + create_tensor_qkv(layer, i, hidden_size, n_embd_head_k * attn_num_attention_head, attn_num_key_value_head * n_embd_head_k, attn_num_key_value_head * n_embd_head_v, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0); + + + // feed forward (w/ optional biases) + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, i), {hidden_size}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {hidden_size, ffn_intermediate_size}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { ffn_intermediate_size, hidden_size}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {hidden_size, ffn_intermediate_size}, 0); + + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_falcon_h1::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_falcon_h1::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index 2f65fa56e1f..75f2cfef560 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -1,6 +1,53 @@ #include "models.h" -llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_falcon::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 60: type = LLM_TYPE_40B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_falcon::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_falcon::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_falcon::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/gemma-embedding.cpp b/examples/talk-llama/models/gemma-embedding.cpp index b6de9551c52..4e07f5f2bda 100644 --- a/examples/talk-llama/models/gemma-embedding.cpp +++ b/examples/talk-llama/models/gemma-embedding.cpp @@ -1,6 +1,78 @@ #include "models.h" -llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : +void llama_model_gemma_embedding::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; + uint32_t swa_period = 6; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.causal_attn = false; // embeddings do not use causal attention + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + //applied only if model converted with --sentence-transformers-dense-modules + ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); + ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false); + + GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); + GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_0_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k())); + +} + +void llama_model_gemma_embedding::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // Dense linear weights + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); + dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); + + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_gemma_embedding::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_gemma_embedding::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k(); diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 09d2ff8bae7..06731670007 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -1,6 +1,44 @@ #include "models.h" -llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_gemma::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 18: type = LLM_TYPE_2B; break; + case 28: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gemma::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + } +} + +std::unique_ptr llama_model_gemma::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_gemma::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); ggml_tensor * cur; diff --git a/examples/talk-llama/models/gemma2-iswa.cpp b/examples/talk-llama/models/gemma2.cpp similarity index 53% rename from examples/talk-llama/models/gemma2-iswa.cpp rename to examples/talk-llama/models/gemma2.cpp index 0ef07df8d01..6255bf740fc 100644 --- a/examples/talk-llama/models/gemma2-iswa.cpp +++ b/examples/talk-llama/models/gemma2.cpp @@ -1,6 +1,65 @@ #include "models.h" -llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_gemma2::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; // default value of gemma 2 + uint32_t swa_period = 2; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + hparams.attn_soft_cap = true; + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_2B; break; + case 42: type = LLM_TYPE_9B; break; + case 46: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173 + hparams.f_attention_scale = type == LLM_TYPE_27B + ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) + : 1.0f / std::sqrt(float(hparams.n_embd_head_k())); +} + +void llama_model_gemma2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_gemma2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_gemma2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k(); ggml_tensor * cur; diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index 0da4af21c17..ee510fe38b0 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -1,7 +1,87 @@ #include "models.h" +void llama_model_gemma3::load_arch_hparams(llama_model_loader & ml) { + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 6; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + hparams.f_final_logit_softcapping = 0.0f; + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 18: type = LLM_TYPE_270M; break; + case 26: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_8B; break; // Rnj-1 + case 34: type = LLM_TYPE_4B; break; + case 48: type = LLM_TYPE_12B; break; + case 62: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289 + hparams.f_attention_scale = type == LLM_TYPE_27B + ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) + : 1.0f / std::sqrt(float(hparams.n_embd_head_k())); +} + +void llama_model_gemma3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // Dense linear weights + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); + dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); + + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_gemma3::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique>(*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_gemma3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k(); ggml_tensor * cur; @@ -141,5 +221,5 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr ggml_build_forward_expand(gf, cur); } -template struct llm_build_gemma3; -template struct llm_build_gemma3; +template struct llama_model_gemma3::graph; +template struct llama_model_gemma3::graph; diff --git a/examples/talk-llama/models/gemma3n-iswa.cpp b/examples/talk-llama/models/gemma3n.cpp similarity index 76% rename from examples/talk-llama/models/gemma3n-iswa.cpp rename to examples/talk-llama/models/gemma3n.cpp index f8095417e06..881499b0ca7 100644 --- a/examples/talk-llama/models/gemma3n-iswa.cpp +++ b/examples/talk-llama/models/gemma3n.cpp @@ -1,5 +1,86 @@ #include "models.h" +void llama_model_gemma3n::load_arch_hparams(llama_model_loader & ml) { + uint32_t swa_period = 5; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(swa_period); + + hparams.n_layer_kv_from_start = 20; + hparams.f_attention_scale = 1.0f; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_E2B; break; + case 35: type = LLM_TYPE_E4B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gemma3n::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_altup = hparams.n_altup; + const int64_t laurel_rank = hparams.laurel_rank; + const int64_t n_embd_altup = hparams.n_embd_altup; + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); + altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); + + per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); + per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0); + per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_altup}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + // altup & laurel + layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0); + layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0); + layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); + layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0); + layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0); + layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0); + layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0); + layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0); + layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0); + layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0); + layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_gemma3n::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) { GGML_ASSERT(idx < (int) x->ne[2]); @@ -7,7 +88,7 @@ static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, in idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); } -llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) : +llama_model_gemma3n::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), n_embd_head(model.hparams.n_embd_head_k()), @@ -229,13 +310,13 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) { +ggml_tensor * llama_model_gemma3n::graph::calc_magnitude(ggml_tensor * x) { return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x))); } // equivalent to get_per_layer_inputs() in python code // output shape: [n_embd_altup, n_layer, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::build_inp_per_layer() { +ggml_tensor * llama_model_gemma3n::graph::build_inp_per_layer() { auto inp = std::make_unique(n_embd); ggml_tensor * inp_per_layer; float tok_embd_scale = sqrtf((float) n_embd_altup); @@ -268,7 +349,7 @@ ggml_tensor * llm_build_gemma3n_iswa::build_inp_per_layer() { // equivalent to project_per_layer_inputs() in python code // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim // output shape: [n_embd_altup, n_tokens, n_layer] -ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { +ggml_tensor * llama_model_gemma3n::graph::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); const float per_layer_input_scale = 1.0f / sqrtf(2.0f); @@ -291,7 +372,7 @@ ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp // input cur shape: [n_altup, n_tokens] // output shape: [n_altup, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) { +ggml_tensor * llama_model_gemma3n::graph::laurel(ggml_tensor * cur, int il) { ggml_tensor * tmp = cur; tmp = build_lora_mm(model.layers[il].laurel_l, tmp); tmp = build_lora_mm(model.layers[il].laurel_r, tmp); @@ -303,7 +384,7 @@ ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) { // input x shape: [n_embd, n_tokens] // output shape: [n_embd, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) { +ggml_tensor * llama_model_gemma3n::graph::gaussian_topk(ggml_tensor * x) { ggml_tensor * mean = ggml_mean(ctx0, x); ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))), 1.0f / (float) (x->ne[0] - 1))); @@ -318,7 +399,7 @@ ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) { // equivalent to compute_router_modalities() in python code // input x shape: [n_embd, n_tokens] // output shape: [n_altup, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tensor * x, int il) { +ggml_tensor * llama_model_gemma3n::graph::altup_compute_router_modalities(ggml_tensor * x, int il) { ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il); // router_input_scale @@ -330,7 +411,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tenso // input cur shape: [n_embd, n_tokens, n_altup] // output shape: [n_embd, n_tokens, n_altup] -ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) { +ggml_tensor * llama_model_gemma3n::graph::altup_predict(ggml_tensor * cur, int il) { ggml_tensor * activated = ggml_view_2d_slice(ctx0, cur, i_altup_act); // [n_embd, n_tokens] ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] cb(modalities, "modalities", il); @@ -355,7 +436,7 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) { // input predictions shape: [n_embd, n_tokens, n_altup] // input activated shape: [n_embd, n_tokens] // output shape: [n_embd, n_tokens, n_altup] -ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) { +ggml_tensor * llama_model_gemma3n::graph::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) { ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] cb(modalities, "modalities", il); diff --git a/examples/talk-llama/models/gemma4-iswa.cpp b/examples/talk-llama/models/gemma4.cpp similarity index 62% rename from examples/talk-llama/models/gemma4-iswa.cpp rename to examples/talk-llama/models/gemma4.cpp index c7fb7747414..f45ae4cad59 100644 --- a/examples/talk-llama/models/gemma4-iswa.cpp +++ b/examples/talk-llama/models/gemma4.cpp @@ -1,5 +1,140 @@ #include "models.h" +void llama_model_gemma4::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + uint32_t n_kv_shared_layers = 0; + ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); + + hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t)n_kv_shared_layers; + hparams.f_attention_scale = 1.0f; // Gemma4 uses self.scaling = 1.0 (no pre-attn scaling) + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, hparams.n_embd_per_layer); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_26B_A4B; break; + case 35: type = LLM_TYPE_E2B; break; + case 42: type = LLM_TYPE_E4B; break; + case 60: type = LLM_TYPE_31B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gemma4::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const uint32_t n_embd_per_layer = hparams.n_embd_per_layer; + const int64_t n_ff_exp = hparams.n_ff_exp; + + if (n_embd_head_k != n_embd_head_v) { + throw std::runtime_error("Gemma 4 requires n_embd_head_k == n_embd_head_v"); + } + if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { + throw std::runtime_error("Gemma 4 requires n_embd_head_k_swa == n_embd_head_v_swa"); + } + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + if (n_embd_per_layer > 0) { + per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0); + per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_per_layer * n_layer}, 0); + per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_per_layer}, 0); + } + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + int rope_freqs_flag = 0; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const int64_t n_head = hparams.n_head(i); + const int64_t n_embd_head = hparams.n_embd_head_k(i); + const int64_t n_embd_k = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v = hparams.n_embd_v_gqa(i); + const int kv_flags = hparams.has_kv(i) ? 0 : TENSOR_NOT_REQUIRED; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // note: use_alternative_attention (v_proj is optional, if it's not present, use k_proj) + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, kv_flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head * n_head, n_embd}, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, kv_flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, TENSOR_NOT_REQUIRED); + + if (!hparams.is_swa(i)) { + // full_attention layers use rope_freqs for proportional rope + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_embd_head/2}, rope_freqs_flag); + rope_freqs_flag = TENSOR_DUPLICATED; + } + + // handle use_double_wide_mlp + int64_t n_ff_cur = hparams.n_ff(i); + + // for expert layers, we use normal FFN as shared expert (same as python code) + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_cur}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + // MoE router + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + bool has_expert = layer.ffn_gate_inp != nullptr; + + // norm + if (has_expert) { + layer.ffn_gate_inp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "scale", i), {n_embd}, 0); + + layer.ffn_pre_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM_2, "weight", i), {n_embd}, 0); + layer.ffn_post_norm_1 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_1, "weight", i), {n_embd}, 0); + layer.ffn_post_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_2, "weight", i), {n_embd}, 0); + + // MoE FFN + layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i), {n_embd, n_ff_exp * 2, n_expert}, TENSOR_NOT_REQUIRED); + + if (layer.ffn_gate_up_exps == nullptr) { + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + } + + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + + // per-expert scale will be loaded as down_exps_s at the end of the current switch case + } + + // per-layer embeddings + if (n_embd_per_layer > 0) { + layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_per_layer}, 0); + layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_per_layer, n_embd}, 0); + layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); + } + } +} + +std::unique_ptr llama_model_gemma4::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) { GGML_ASSERT(idx < (int) x->ne[2]); @@ -7,7 +142,7 @@ static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, in idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); } -llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params) : +llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), n_embd_per_layer(model.hparams.n_embd_per_layer) { @@ -157,8 +292,8 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll cur_moe = build_moe_ffn(cur_moe, nullptr, // gate_inp - nullptr, // up_exps - nullptr, // gate_exps + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, // exp_probs_b (not used for gemma4) n_expert, n_expert_used, @@ -167,8 +302,8 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, logits, model.layers[il].ffn_gate_up_exps, - nullptr, // up_exps_s - nullptr, // gate_exps_s + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, model.layers[il].ffn_down_exps_s); cur_moe = build_norm(cur_moe, model.layers[il].ffn_post_norm_2, nullptr, @@ -261,7 +396,7 @@ llm_build_gemma4_iswa::llm_build_gemma4_iswa(const llama_model & model, const ll // equivalent to get_per_layer_inputs() in python code // output shape: [n_embd_per_layer, n_layer, n_tokens] -ggml_tensor * llm_build_gemma4_iswa::build_inp_per_layer() { +ggml_tensor * llama_model_gemma4::graph::build_inp_per_layer() { auto inp = std::make_unique(n_embd); ggml_tensor * inp_per_layer; @@ -299,7 +434,7 @@ ggml_tensor * llm_build_gemma4_iswa::build_inp_per_layer() { // inp_batch shape: [n_embd, n_tokens] // inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from build_inp_per_layer) // output shape: [n_embd_per_layer, n_tokens, n_layer] -ggml_tensor * llm_build_gemma4_iswa::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { +ggml_tensor * llama_model_gemma4::graph::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); const float per_layer_input_scale = 1.0f / sqrtf(2.0f); diff --git a/examples/talk-llama/models/glm-dsa.cpp b/examples/talk-llama/models/glm-dsa.cpp new file mode 100644 index 00000000000..af2b55ef563 --- /dev/null +++ b/examples/talk-llama/models/glm-dsa.cpp @@ -0,0 +1,155 @@ +#include "models.h" + +void llama_model_glm_dsa::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + + // DSA parameters + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 79: type = LLM_TYPE_744B_A40B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const bool is_mla = hparams.is_mla(); + if (!is_mla) { + throw std::runtime_error("GLM_DSA architecture requires MLA"); + } + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later + flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // DSA indexer + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr llama_model_glm_dsa::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 8d4f4a01553..45886b51ac1 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -1,6 +1,139 @@ #include "models.h" -llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_glm4_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open + case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_glm4_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + + GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); + GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers"); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // Load ALL tensors including NextN layer to satisfy total tensor count + // but only PROCESS up to last layer (skipping final NextN layer) in forward pass + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags); + + // GLM-style attention with bias terms + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); + + // K/Q norm tensors (optional for GLM-4.5 355B variant) + layer.attn_q_norm = create_tensor( + tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); + layer.attn_k_norm = create_tensor( + tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags); + + // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead + // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE + const bool use_moe = (static_cast(i) >= hparams.n_layer_dense_lead); + + if (use_moe) { + // MoE layers + layer.ffn_gate_inp = + create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor( + tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); + layer.ffn_down_exps = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags); + layer.ffn_up_exps = create_tensor( + tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); + + // Shared expert + if (n_expert_shared > 0) { + const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; + layer.ffn_gate_shexp = create_tensor( + tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags); + layer.ffn_up_shexp = create_tensor( + tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); + } + } else { + // Dense layers (first k layers) - GLM uses separate gate/up projections + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr llama_model_glm4_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index f0bfda393fa..d6ef76e26d6 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -1,6 +1,78 @@ #include "models.h" -llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_glm4::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // NextN/MTP parameters (GLM-OCR) + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 17: type = LLM_TYPE_1B; break; // GLM-OCR + case 40: type = LLM_TYPE_9B; break; + case 61: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_glm4::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, flags); + + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr llama_model_glm4::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index f8dc53eb723..ba49c31b56b 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -1,6 +1,60 @@ #include "models.h" -llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_gpt2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 12: type = LLM_TYPE_SMALL; break; + case 24: type = LLM_TYPE_MEDIUM; break; + case 36: type = LLM_TYPE_LARGE; break; + case 48: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gpt2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr llama_model_gpt2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_gpt2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 0016ddede43..33ebe2d8800 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -1,6 +1,89 @@ #include "models.h" -llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_gptneox::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + switch (hparams.n_layer) { + case 6: + switch (hparams.n_ff()) { + case 512: type = LLM_TYPE_14M; break; + case 2048: type = LLM_TYPE_70M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 12: + switch (hparams.n_ff()) { + case 3072: type = LLM_TYPE_160M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 16: + switch (hparams.n_ff()) { + case 8192: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff()) { + case 4096: type = LLM_TYPE_410M; break; + case 8192: type = LLM_TYPE_1_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 32: + switch (hparams.n_ff()) { + case 10240: type = LLM_TYPE_2_8B; break; + case 16384: type = LLM_TYPE_6_9B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 36: + switch (hparams.n_ff()) { + case 20480: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 44: + switch (hparams.n_ff()) { + case 24576: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gptneox::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr llama_model_gptneox::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_gptneox::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index e983742bef5..12e4790ae24 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -1,6 +1,137 @@ #include "models.h" -llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) : +void llama_model_granite_hybrid::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /* required */ false); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /* required */ false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /* required */ false); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, /* required */ false); + + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Granite uses rope_finetuned as a switch for rope, so default to true + bool rope_finetuned = true; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + + // A layer is recurrent IFF the n_head_kv value is set to 0 + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_350M; break; + case 1536: type = (hparams.n_ff() == 512 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; + case 2048: case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); +} + +void llama_model_granite_hybrid::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + // mamba2 Mixer SSM params + // NOTE: int64_t for tensor dimensions + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_ssm_head = hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.is_recurrent(i)) { + // ssm layers + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else { + // attention layers (with optional bias) + const int64_t n_head_i = hparams.n_head(i); + const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + } + + // feed forward (w/ optional biases) + if (n_expert > 0) { + // MoE FFN + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } else { + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr llama_model_granite_hybrid::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_granite_hybrid::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -67,7 +198,7 @@ llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, co ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_granite_hybrid::build_attention_layer(ggml_tensor * cur, +ggml_tensor * llama_model_granite_hybrid::graph::build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, const llama_model & model, @@ -98,7 +229,7 @@ ggml_tensor * llm_build_granite_hybrid::build_attention_layer(ggml_tensor * return cur; } -ggml_tensor * llm_build_granite_hybrid::build_layer_ffn(ggml_tensor * cur, +ggml_tensor * llama_model_granite_hybrid::graph::build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il) { diff --git a/examples/talk-llama/models/granite-moe.cpp b/examples/talk-llama/models/granite-moe.cpp new file mode 100644 index 00000000000..0d89bc1f340 --- /dev/null +++ b/examples/talk-llama/models/granite-moe.cpp @@ -0,0 +1,89 @@ +#include "models.h" + +void llama_model_granite_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, false); + + // Granite uses rope_finetuned as a switch for rope, so default to true + bool rope_finetuned = true; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_3B; break; + // Add additional layer/vocab/etc checks here for other model sizes + default: type = LLM_TYPE_UNKNOWN; + } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); +} + +void llama_model_granite_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr llama_model_granite_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index 6ea90285225..5e7c7b68181 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -1,6 +1,93 @@ #include "models.h" -llm_build_granite::llm_build_granite( +void llama_model_granite::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, false); + + // Granite uses rope_finetuned as a switch for rope, so default to true + bool rope_finetuned = true; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_3B; break; + // Add additional layer/vocab/etc checks here for other model sizes + default: type = LLM_TYPE_UNKNOWN; + } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); +} + +void llama_model_granite::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr llama_model_granite::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_granite::graph::graph( const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { @@ -68,7 +155,7 @@ llm_build_granite::llm_build_granite( ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_granite::build_attention_layer( +ggml_tensor * llama_model_granite::graph::build_attention_layer( ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, @@ -107,7 +194,7 @@ ggml_tensor * llm_build_granite::build_attention_layer( return cur; } -ggml_tensor * llm_build_granite::build_layer_ffn( +ggml_tensor * llama_model_granite::graph::build_layer_ffn( ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index b8f35afdc03..0bc49d00206 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -1,6 +1,89 @@ #include "models.h" -llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_grok::load_arch_hparams(llama_model_loader & ml) { + // defaults for old GGUFs + hparams.yarn_beta_fast = 8.0f; + hparams.f_logit_scale = 0.5773502691896257f; + hparams.f_embedding_scale = 78.38367176906169f; + hparams.f_attn_out_scale = 0.08838834764831845f; + hparams.f_attn_logit_softcapping = 30.0f; + hparams.f_router_logit_softcapping = 30.0f; + // no final_logit_softcapping in grok-1 + hparams.f_final_logit_softcapping = 0.0f; + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); + ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale, false); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); + + switch (hparams.n_layer) { + case 64: type = LLM_TYPE_314B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_grok::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_expert == 0) { + throw std::runtime_error(arch_name() + " model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + if (!layer.ffn_post_norm) { + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } +} + +std::unique_ptr llama_model_grok::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_grok::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index 151108a2a71..feef815165b 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -1,6 +1,70 @@ #include "models.h" -llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_graph_params & params) : +void llama_model_grovemoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, hparams.n_ff_chexp, false); + ml.get_key(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); + ml.get_key(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_grovemoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for GROVEMOE"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for GROVEMOE"); + GGML_ASSERT(hparams.n_group_experts > 0 && "n_group_experts must be > 0 for GROVEMOE"); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_chexp = hparams.n_ff_chexp ? hparams.n_ff_chexp : n_embd_head_k; + const int64_t n_chunk_expert = n_expert / hparams.n_group_experts; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_gate_chexps = create_tensor(tn(LLM_TENSOR_FFN_GATE_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); + layer.ffn_down_chexps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_CHEXPS, "weight", i), {n_ff_chexp, n_embd, n_chunk_expert}, 0); + layer.ffn_up_chexps = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); + } +} + +std::unique_ptr llama_model_grovemoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_grovemoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_chunk_expert = n_expert / hparams.n_group_experts; diff --git a/examples/talk-llama/models/hunyuan-dense.cpp b/examples/talk-llama/models/hunyuan-dense.cpp index 1cd85d6d9d4..c137bd37c02 100644 --- a/examples/talk-llama/models/hunyuan-dense.cpp +++ b/examples/talk-llama/models/hunyuan-dense.cpp @@ -1,132 +1,6 @@ #include "models.h" -llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v(); - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); - GGML_ASSERT(n_embd_head == n_rot); - - const bool use_mrope = hparams.use_mrope(); - - int sections[4]; - std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); - - auto * inp_attn = build_attn_inp_kv(); - - const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - // self-attention - { - // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - - // compute Q and K and RoPE them - auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, - n_embd_head, n_head, n_head_kv, il); - - if (use_mrope) { - Qcur = ggml_rope_multi( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = ggml_rope_multi( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - } else { - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - } - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Kcur = build_norm(Kcur, - model.layers[il].attn_k_norm, nullptr, - LLM_NORM_RMS, il); - cb(Kcur, "Kcur_norm", il); - - Qcur = build_norm(Qcur, - model.layers[il].attn_q_norm, nullptr, - LLM_NORM_RMS, il); - cb(Qcur, "Qcur_norm", il); - - cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); - cb(cur, "attn_out", il); - } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - // feed-forward network (non-MoE) - ggml_tensor * cur_mlp = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur_mlp, "ffn_out", il); - - cur = ggml_add(ctx0, cur_mlp, ffn_inp); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - // lm_head - cur = build_lora_mm(model.output, cur); - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); +std::unique_ptr llama_model_hunyuan_dense::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); } + diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index ffe1664b0e1..44af42412f7 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -1,6 +1,59 @@ #include "models.h" -llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_hunyuan_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_A13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_hunyuan_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const uint32_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : hparams.n_ff(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + } +} + +std::unique_ptr llama_model_hunyuan_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_hunyuan_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/hunyuan-vl.cpp b/examples/talk-llama/models/hunyuan-vl.cpp new file mode 100644 index 00000000000..5fb9154bec0 --- /dev/null +++ b/examples/talk-llama/models/hunyuan-vl.cpp @@ -0,0 +1,189 @@ +#include "models.h" + +void llama_model_hunyuan_vl::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // XDRoPE / NTK-aware scaling: base = rope_theta * alpha^(dim / (dim - 2)) + if (hparams.rope_scaling_alpha > 0.0f) { + const int dim = hparams.n_embd_head_k(); + hparams.rope_freq_base_train = hparams.rope_freq_base_train + * powf(hparams.rope_scaling_alpha, (float)dim / (float)(dim - 2)); + } + + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_0_5B; break; + case 2048: type = LLM_TYPE_1_8B; break; + case 3072: type = LLM_TYPE_4B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_hunyuan_vl::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + } +} + +std::unique_ptr llama_model_hunyuan_vl::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_hunyuan_vl::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); + + const bool use_mrope = hparams.use_mrope(); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + if (use_mrope) { + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, nullptr, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur_norm", il); + + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, nullptr, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur_norm", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + // feed-forward network (non-MoE) + ggml_tensor * cur_mlp = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur_mlp, "ffn_out", il); + + cur = ggml_add(ctx0, cur_mlp, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index 83be2ca0aee..f0c5580a6f4 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -1,6 +1,43 @@ #include "models.h" -llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_internlm2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_internlm2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_internlm2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_internlm2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index 31101f3c14b..a6451dca095 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -1,6 +1,58 @@ #include "models.h" -llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_jais::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1_3B; break; + case 40: type = LLM_TYPE_13B; break; + /* TODO: add variants */ + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jais::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr llama_model_jais::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_jais::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp index 507e04fa4aa..ad59b953e8d 100644 --- a/examples/talk-llama/models/jais2.cpp +++ b/examples/talk-llama/models/jais2.cpp @@ -1,8 +1,63 @@ #include "models.h" +void llama_model_jais2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + case 68: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jais2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // attention biases - all have shape n_embd (output dimension of projections) + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + // Jais-2 uses simple MLP (no gate) with biases + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_jais2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + // JAIS-2 model graph builder // Uses: LayerNorm (not RMSNorm), relu2 activation, separate Q/K/V, RoPE embeddings -llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_jais2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index f82b7795c87..e1b8d137e38 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -1,6 +1,111 @@ #include "models.h" -llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { +void llama_model_jamba::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + } + + switch (hparams.n_layer) { + // TODO: Jamba layers are a bit heterogeneous, so naming this is hard. + case 12: // 900M 8x???M + case 32: // 51B 16x?B + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jamba::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head_kv = hparams.n_head_kv(i); + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (n_head_kv == 0) { + // Mamba layer + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); + + layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, "weight", i), {dt_rank}, 0); + + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); + + layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, "weight", i), {d_state}, 0); + layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, "weight", i), {d_state}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else { + // Attention layers + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + + if (layer.ffn_gate_inp) { + // MoE + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } else { + // FFN (no MoE) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } +} + +std::unique_ptr llama_model_jamba::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_jamba::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); ggml_tensor * cur; diff --git a/examples/talk-llama/models/jina-bert-v2.cpp b/examples/talk-llama/models/jina-bert-v2.cpp new file mode 100644 index 00000000000..4f8866ece4d --- /dev/null +++ b/examples/talk-llama/models/jina-bert-v2.cpp @@ -0,0 +1,66 @@ +#include "models.h" + +void llama_model_jina_bert_v2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + hparams.f_max_alibi_bias = 8.0f; + + switch (hparams.n_layer) { + case 4: type = LLM_TYPE_33M; break; // jina-embeddings-small + case 12: type = LLM_TYPE_137M; break; // jina-embeddings-base + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jina_bert_v2::load_arch_tensors(llama_model_loader & ml) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); // LayerNorm + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // LayerNorm bias + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, TENSOR_NOT_REQUIRED); + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; // JinaBertLayer + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + const auto tn_ffn_up_weight = tn(LLM_TENSOR_FFN_UP, "weight", i); + ggml_tensor * t_ffn_up = ml.get_tensor_meta(tn_ffn_up_weight.str().c_str()); + const int64_t n_ffn_up = t_ffn_up ? t_ffn_up->ne[1] : n_ff; + + GGML_ASSERT(n_ffn_up == n_ff || n_ffn_up == n_ff * 2); + layer.ffn_up = create_tensor(tn_ffn_up_weight, {n_embd, n_ffn_up}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ffn_up}, TENSOR_NOT_REQUIRED); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_jina_bert_v2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/jina-bert-v3.cpp b/examples/talk-llama/models/jina-bert-v3.cpp new file mode 100644 index 00000000000..e0527529f56 --- /dev/null +++ b/examples/talk-llama/models/jina-bert-v3.cpp @@ -0,0 +1,69 @@ +#include "models.h" + +void llama_model_jina_bert_v3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: + type = LLM_TYPE_558M; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jina_bert_v3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_token_types == 0) { + throw std::runtime_error(arch_name() + " model needs to define token type count"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_NOMIC_BERT) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_jina_bert_v3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp index 58c89c417fc..ecffb105496 100644 --- a/examples/talk-llama/models/kimi-linear.cpp +++ b/examples/talk-llama/models/kimi-linear.cpp @@ -1,7 +1,175 @@ #include "models.h" - #include "llama-memory-recurrent.h" +void llama_model_kimi_linear::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); + + // MLA qk_rope_head_dim (for reference) + // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192 + + // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) + // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent + } + + // MoE parameters - Kimi uses moe_intermediate_size = 1024 + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + switch (hparams.n_layer) { + case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_kimi_linear::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // Check for KDA specific tensors to determine layer type or if it's a mixed model + // Assuming KDA layer if KDA tensors are present + + // KDA uses head_dim = 128 (from linear_attn_config.head_dim) + const int64_t n_embd_head_k_kda = hparams.n_embd_head_kda; + const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; + const int64_t ssm_d_conv = hparams.ssm_d_conv; + + if (hparams.is_recurrent(i)) { + // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) + // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_q_conv) { + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } + + // KDA Layer - Conv1d weights may be 3D or 4D + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_k_conv) { + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_v_conv) { + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head}, 0); + } + + // q, k, v projections + // Python: q_proj, k_proj, v_proj + create_tensor_qkv(layer, i, n_embd, n_embd_head_k_kda * n_head, n_embd_head_k_kda * n_head, n_embd_head_v_kda * n_head, 0); + + // KDA specific projections + // f_a_proj, f_b_proj + layer.ssm_f_a = create_tensor(tn(LLM_TENSOR_SSM_F_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); // head_dim + layer.ssm_f_b = create_tensor(tn(LLM_TENSOR_SSM_F_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); // projection_size + + // b_proj (beta mixing coefficient) + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), {n_embd, n_head}, 0); + + // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization) Note: -exp(A_log) is applied in convert_hf_to_gguf.py + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_a) { + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + } + + // dt_bias - shape [n_embd_head_k_kda * n_head] = [4096] + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_embd_head_k_kda * n_head}, 0); + + // g_a_proj, g_b_proj (output gate) + layer.ssm_g_a = create_tensor(tn(LLM_TENSOR_SSM_G_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); + layer.ssm_g_b = create_tensor(tn(LLM_TENSOR_SSM_G_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); + + // o_norm (reusing SSM_NORM) + layer.ssm_o_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {n_embd_head_k_kda}, 0); // FusedRMSNormGated + + // o_proj + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v_kda * n_head, n_embd}, 0); + + } else { + // MLA Layer - use MLA-specific head dimensions + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (layer.attn_q_a_norm) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + // Kimi MLA without Q compression: wq = [n_embd, n_head * n_embd_head_k_mla] + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + // Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA) + // Note: hparams.n_rot may be 72 (from conversion) but actual is 64 + const int64_t qk_rope_head_dim = hparams.n_rot(); // From config: qk_rope_head_dim + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0); + // Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled) + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), + {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + if (!layer.wkv_b) { // MLA KV cache enabled + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // MoE intermediate size (different from dense FFN) + const int64_t n_ff_exp = hparams.n_ff_exp; + + // Kimi uses n_layer_dense_lead to determine which layers use dense FFN vs MoE + // first_k_dense_replace = 1 means layer 0 uses dense FFN, layers 1+ use MoE + if (i < (int) hparams.n_layer_dense_lead) { + // Dense FFN layer - use normal n_ff + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + // MoE layer - use n_ff_exp (1024) instead of n_ff (9216) + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared experts use moe_intermediate_size * num_shared_experts + // Kimi: shared_expert_intermediate_size = 1024 * 1 = 1024 + // Tensors are 2D: [n_embd, n_ff_shexp] or [n_ff_shexp, n_embd] + const int64_t n_ff_shexp_actual = n_ff_exp * (hparams.n_expert_shared > 0 ? hparams.n_expert_shared : 1); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } + } +} + +std::unique_ptr llama_model_kimi_linear::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + // Causal Conv1d function for Q,K,V // When qkv is 0, it is Q, 1 is K, 2 is V static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) { @@ -63,7 +231,7 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t return ggml_reshape_4d(ctx0, Xcur, head_dim, n_head, n_seq_tokens, n_seqs); } -llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) : +llama_model_kimi_linear::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_delta_net_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index eb8ec3c803a..df6a8028736 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -1,10 +1,94 @@ #include "models.h" - #include "../llama-memory-hybrid-iswa.h" #include "../llama-memory-hybrid.h" +void llama_model_lfm2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + } + hparams.n_layer_dense_lead = hparams.n_layer; + switch (hparams.n_ff()) { + case 4608: type = LLM_TYPE_350M; break; + case 6912: type = LLM_TYPE_700M; break; + case 8192: type = LLM_TYPE_1_2B; break; + case 10752: type = LLM_TYPE_2_6B; break; + default: type = LLM_TYPE_UNKNOWN; + } + if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); is_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.swa_layers[il] = !hparams.recurrent_layer_arr[il]; + } + } +} + +void llama_model_lfm2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM_LFM2, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const bool is_moe_layer = i >= static_cast(hparams.n_layer_dense_lead); + + // ffn/moe is same for transformer and conv layers + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + if (is_moe_layer) { + GGML_ASSERT(n_expert && n_expert_used); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } else { // dense + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // for operator_norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (!hparams.is_recurrent(i)) { + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); + + create_tensor_qkv(layer, i, n_embd, n_embd, hparams.n_embd_k_gqa(i), hparams.n_embd_v_gqa(i), 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } else { + layer.shortconv.conv = create_tensor(tn(LLM_TENSOR_SHORTCONV_CONV, "weight", i), {hparams.n_shortconv_l_cache, n_embd}, 0); + layer.shortconv.in_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_INPROJ, "weight", i), {n_embd, 3 * n_embd}, 0); + layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0); + } + } + + // for LFM2-ColBert-350M + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers_b = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "bias"), {hparams.n_embd_out() }, TENSOR_NOT_REQUIRED); +} + +std::unique_ptr llama_model_lfm2::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique>(*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) : +llama_model_lfm2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { using inp_hybrid_type = std::conditional_t; using inp_attn_type = std::conditional_t; @@ -187,5 +271,5 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_ } // Explicit template instantiations -template struct llm_build_lfm2; -template struct llm_build_lfm2; +template struct llama_model_lfm2::graph; +template struct llama_model_lfm2::graph; diff --git a/examples/talk-llama/models/lfm2moe.cpp b/examples/talk-llama/models/lfm2moe.cpp new file mode 100644 index 00000000000..12a66c05c7d --- /dev/null +++ b/examples/talk-llama/models/lfm2moe.cpp @@ -0,0 +1,85 @@ +#include "models.h" +#include "../llama-memory-hybrid-iswa.h" +#include "../llama-memory-hybrid.h" + +void llama_model_lfm2moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + } + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_8B_A1B; break; + case 40: type = LLM_TYPE_24B_A2B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_lfm2moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM_LFM2, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const bool is_moe_layer = i >= static_cast(hparams.n_layer_dense_lead); + + // ffn/moe is same for transformer and conv layers + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + if (is_moe_layer) { + GGML_ASSERT(n_expert && n_expert_used); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } else { // dense + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // for operator_norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (!hparams.is_recurrent(i)) { + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); + + create_tensor_qkv(layer, i, n_embd, n_embd, hparams.n_embd_k_gqa(i), hparams.n_embd_v_gqa(i), 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } else { + layer.shortconv.conv = create_tensor(tn(LLM_TENSOR_SHORTCONV_CONV, "weight", i), {hparams.n_shortconv_l_cache, n_embd}, 0); + layer.shortconv.in_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_INPROJ, "weight", i), {n_embd, 3 * n_embd}, 0); + layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0); + } + } + + // for LFM2-ColBert-350M + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers_b = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "bias"), {hparams.n_embd_out() }, TENSOR_NOT_REQUIRED); +} + +std::unique_ptr llama_model_lfm2moe::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique>(*this, params); + } else { + return std::make_unique>(*this, params); + } +} + diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index c756d6fde5f..b60f67f6c4b 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -1,6 +1,56 @@ #include "models.h" -llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_llada_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // diffusion language model uses non-causal attention + hparams.causal_attn = false; + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_A1_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_llada_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for llada-moe"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for llada-moe"); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr llama_model_llada_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_llada_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index 501df3c7eaf..fa21c5fe32c 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -1,6 +1,72 @@ #include "models.h" -llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_llada::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // LLaDA-8B has 32 layers, similar to LLaMA but for diffusion + switch (hparams.n_layer) { + case 32: + type = LLM_TYPE_8B; + break; + default: + type = LLM_TYPE_UNKNOWN; + } + // Set non-causal attention for diffusion models + hparams.causal_attn = false; +} + +void llama_model_llada::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = + create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + // Use separate Q, K, V projections without bias, matching LLaDALlamaBlock + layer.wq = + create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false + layer.wo = + create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot / 2 }, + TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + + // optional MLP bias + layer.ffn_gate_b = + create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = + create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_llada::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_llada::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { // LLaDA is similar to LLaMA but uses non-causal attention for diffusion const int64_t n_embd_head = hparams.n_embd_head_v(); diff --git a/examples/talk-llama/models/llama-embed.cpp b/examples/talk-llama/models/llama-embed.cpp new file mode 100644 index 00000000000..0699e744461 --- /dev/null +++ b/examples/talk-llama/models/llama-embed.cpp @@ -0,0 +1,6 @@ +#include "models.h" + +std::unique_ptr llama_model_llama_embed::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique>(*this, params); +} + diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index 8d478dc6747..8ddb5936820 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -1,7 +1,102 @@ #include "models.h" +void llama_model_llama::load_arch_hparams(llama_model_loader & ml) { + uint32_t n_vocab = 0; + ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_expert == 8) { + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8x7B; break; + case 56: type = LLM_TYPE_8x22B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } else { + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; // Llama 3.2 1B + case 22: type = LLM_TYPE_1B; break; + case 26: type = LLM_TYPE_3B; break; + case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B + case 30: type = LLM_TYPE_256M; break; // smoldocling 256M + // granite uses a vocab with len 49152 + case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break; + case 36: type = LLM_TYPE_8B; break; // granite + case 40: type = LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_34B; break; + case 60: type = LLM_TYPE_30B; break; + case 80: type = hparams.n_head() == hparams.n_head_kv() ? LLM_TYPE_65B : LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } +} + +void llama_model_llama::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr llama_model_llama::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique>(*this, params); +} + template -llm_build_llama::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_llama::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -149,5 +244,5 @@ llm_build_llama::llm_build_llama(const llama_model & model, const llm_gra ggml_build_forward_expand(gf, cur); } -template struct llm_build_llama; -template struct llm_build_llama; +template struct llama_model_llama::graph; +template struct llama_model_llama::graph; diff --git a/examples/talk-llama/models/llama4.cpp b/examples/talk-llama/models/llama4.cpp index 4e4bfb43f33..899611d53f6 100644 --- a/examples/talk-llama/models/llama4.cpp +++ b/examples/talk-llama/models/llama4.cpp @@ -1,7 +1,109 @@ #include "models.h" +void llama_model_llama4::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa == 0) { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope + } else { + hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; + hparams.n_swa = 8192; + hparams.n_attn_temp_floor_scale = 8192; + hparams.f_attn_temp_scale = 0.1f; + hparams.f_attn_temp_offset = 1.0f; + uint32_t swa_period = 4; // pattern: 3 chunked - 1 full + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } + + switch (hparams.n_expert) { + case 0: { + // MobileLLM (no MoE) + switch (hparams.n_embd) { + case 2048: type = LLM_TYPE_140M; break; + case 4096: type = LLM_TYPE_360M; break; + case 6144: type = LLM_TYPE_950M; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case 16: type = LLM_TYPE_17B_16E; break; + case 128: type = LLM_TYPE_17B_128E; break; + default: type = LLM_TYPE_UNKNOWN; + } + + hparams.use_kq_norm = type != LLM_TYPE_17B_128E; +} + +void llama_model_llama4::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_expert == 0) { + throw std::runtime_error(arch_name() + " model cannot have zero experts"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + const bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0; + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + if (is_moe_layer) { + const int64_t n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert + const int64_t n_ff_shexp = n_ff_exp; + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } else { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } +} + +std::unique_ptr llama_model_llama4::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) { + return std::make_unique>(*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_llama4::llm_build_llama4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_llama4::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -167,5 +269,5 @@ llm_build_llama4::llm_build_llama4(const llama_model & model, const llm_gr } // Explicit template instantiations -template struct llm_build_llama4; -template struct llm_build_llama4; +template struct llama_model_llama4::graph; +template struct llama_model_llama4::graph; diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index 8a76931c007..3dbd82fd362 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -1,6 +1,49 @@ #include "models.h" -llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_maincoder::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_maincoder::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_maincoder::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_maincoder::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/mamba.cpp b/examples/talk-llama/models/mamba.cpp index 55fd2e055c4..b7708d7fdd1 100644 --- a/examples/talk-llama/models/mamba.cpp +++ b/examples/talk-llama/models/mamba.cpp @@ -1,6 +1,90 @@ #include "models.h" -llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { +void llama_model_mamba::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mamba::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + if (2 * n_embd != d_inner) { + throw std::runtime_error("only an expansion factor of 2 is supported for now"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); + + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } +} + +std::unique_ptr llama_model_mamba::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_mamba::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -51,4 +135,3 @@ llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_para ggml_build_forward_expand(gf, cur); } - diff --git a/examples/talk-llama/models/mamba2.cpp b/examples/talk-llama/models/mamba2.cpp new file mode 100644 index 00000000000..3277ca53ec4 --- /dev/null +++ b/examples/talk-llama/models/mamba2.cpp @@ -0,0 +1,87 @@ +#include "models.h" + +void llama_model_mamba2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mamba2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } +} + +std::unique_ptr llama_model_mamba2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/mimo2-iswa.cpp b/examples/talk-llama/models/mimo2-iswa.cpp deleted file mode 100644 index 52c6acfe214..00000000000 --- a/examples/talk-llama/models/mimo2-iswa.cpp +++ /dev/null @@ -1,129 +0,0 @@ -#include "models.h" - -llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_iswa(); - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - uint32_t n_head_l = hparams.n_head(il); - uint32_t n_head_kv_l = hparams.n_head_kv(il); - const float freq_base_l = model.get_rope_freq_base(cparams, il); - const float freq_scale_l = model.get_rope_freq_scale(cparams, il); - - cur = inpL; - - // self_attention - { - cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); - - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - ggml_tensor * sinks = model.layers[il].attn_sinks; - - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, model.layers[il].wo_s, - Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il); - } - - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - // feed-forward network - if (model.layers[il].ffn_gate_inp == nullptr) { - // dense branch - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); - } else { - // MoE branch - cur = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, - n_expert, n_expert_used, - LLM_FFN_SILU, true, - hparams.expert_weights_scale, - LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, - il); - cb(cur, "ffn_moe_out", il); - } - - cur = ggml_add(ctx0, cur, ffn_inp); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); -} diff --git a/examples/talk-llama/models/mimo2.cpp b/examples/talk-llama/models/mimo2.cpp new file mode 100644 index 00000000000..71996616611 --- /dev/null +++ b/examples/talk-llama/models/mimo2.cpp @@ -0,0 +1,240 @@ +#include "models.h" + +void llama_model_mimo2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + float value_scale = 0.0f; + if (ml.get_key(LLM_KV_ATTENTION_VALUE_SCALE, value_scale, false) && value_scale != 1.0f) { + hparams.f_attn_value_scale = value_scale; + } + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer - hparams.nextn_predict_layers) { + case 48: type = LLM_TYPE_310B_A15B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mimo2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const uint32_t n_nextn = hparams.nextn_predict_layers; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + uint32_t n_head = hparams.n_head(i); + + // NextN/MTP layers (the last n_nextn blocks) are preserved but disabled pending support + const bool is_nextn = (n_nextn > 0) && (static_cast(i) >= n_layer - n_nextn); + const int skip = is_nextn ? TENSOR_SKIP : 0; + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, skip); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, skip); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, skip); + layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, TENSOR_NOT_REQUIRED | skip); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, skip); + + // non-MoE branch + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED | skip); + + // MoE branch + int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | skip); + + if (is_nextn) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, skip); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, skip); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, skip); + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, skip); + } + } +} + +std::unique_ptr llama_model_mimo2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + const float v_scale = hparams.f_attn_value_scale; + + // The last hparams.nextn_predict_layers blocks are MTP heads, currently inactive + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + + for (int il = 0; il < n_transformer_layers; ++il) { + ggml_tensor * inpSA = inpL; + + uint32_t n_head_l = hparams.n_head(il); + uint32_t n_head_kv_l = hparams.n_head_kv(il); + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + cur = inpL; + + // self_attention + { + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + + if (model.layers[il].wqkv) { + // Fused qkv_proj - Q/K share head_dim_k, V uses head_dim_v + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + + const size_t row_k = ggml_row_size(qkv->type, n_embd_head_k); + const size_t row_v = ggml_row_size(qkv->type, n_embd_head_v); + const size_t row_full = qkv->nb[1]; + const size_t k_off = row_k * n_head_l; + const size_t v_off = k_off + row_k * n_head_kv_l; + + Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_l, n_tokens, row_k, row_full, 0); + Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv_l, n_tokens, row_k, row_full, k_off); + Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv_l, n_tokens, row_v, row_full, v_off); + } else { + // Split path + Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + ggml_tensor * sinks = model.layers[il].attn_sinks; + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il); + cb(cur, "attn_out", il); + + if (v_scale) { + cur = ggml_scale(ctx0, cur, v_scale); + cb(cur, "attn_out_scaled", il); + } + } + + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // dense branch + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + il); + cb(cur, "ffn_moe_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/minicpm.cpp b/examples/talk-llama/models/minicpm.cpp new file mode 100644 index 00000000000..966d3af615c --- /dev/null +++ b/examples/talk-llama/models/minicpm.cpp @@ -0,0 +1,89 @@ +#include "models.h" + +void llama_model_minicpm::load_arch_hparams(llama_model_loader & ml) { + // Backward-compatible defaults for older MiniCPM GGUFs + hparams.f_embedding_scale = 12.0f; + hparams.f_residual_scale = 1.4f / sqrtf(float(hparams.n_layer)); + hparams.f_logit_scale = hparams.n_embd ? (256.0f / float(hparams.n_embd)) : 1.0f; + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Optional KV reads, override defaults if present in newer GGUF exports + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /*required=*/false); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /*required=*/false); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /*required=*/false); + + // MiniCPM uses rope by default, unlike Granite which uses it as a switch + hparams.rope_finetuned = true; + + switch (hparams.n_layer) { + case 52: type = LLM_TYPE_1B; break; + case 40: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_minicpm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr llama_model_minicpm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index bf12ab73c74..ff5eb6ffa5f 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -1,6 +1,66 @@ #include "models.h" -llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_minicpm3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + + switch (hparams.n_layer) { + case 62: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_minicpm3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } +} + +std::unique_ptr llama_model_minicpm3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_minicpm3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { //TODO: if the model varies, these parameters need to be read from the model const int64_t n_embd_base = 256; const float scale_embd = 12.0f; diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index b809b79f2b9..0dee8934692 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -1,6 +1,50 @@ #include "models.h" -llm_build_minimax_m2::llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_minimax_m2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + switch (hparams.n_layer) { + case 62: type = LLM_TYPE_230B_A10B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_minimax_m2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k * n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_k_gqa}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } +} + +std::unique_ptr llama_model_minimax_m2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_minimax_m2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index b5ae72a2ee1..708da49af1f 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -1,6 +1,96 @@ #include "models.h" -llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_mistral3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); + + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); + + hparams.f_attn_temp_offset = 0.0f; + + // TODO: maybe add n_attn_temp_floor_scale as a separate KV? + if (hparams.f_attn_temp_scale != 0.0f) { + hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn; + if (hparams.n_attn_temp_floor_scale == 0) { + throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling"); + } + } + + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_3B; break; + case 34: type = LLM_TYPE_8B; break; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mistral3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr llama_model_mistral3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/mistral4.cpp b/examples/talk-llama/models/mistral4.cpp new file mode 100644 index 00000000000..3d9190650e3 --- /dev/null +++ b/examples/talk-llama/models/mistral4.cpp @@ -0,0 +1,6 @@ +#include "models.h" + +std::unique_ptr llama_model_mistral4::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index 94991c55fe8..6d5f18a8e20 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -2,6 +2,7 @@ #include "llama-model.h" #include "llama-graph.h" +#include "llama-model-loader.h" // note: almost all graphs require at least sqrtf, so include cmath globally #include @@ -110,611 +111,1750 @@ struct llm_build_rwkv7_base : public llm_graph_context { // models // -struct llm_build_afmoe : public llm_graph_context { - llm_build_afmoe(const llama_model & model, const llm_graph_params & params); +struct llama_model_llama : public llama_model_base { + llama_model_llama(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_apertus : public llm_graph_context { - llm_build_apertus(const llama_model & model, const llm_graph_params & params); + +struct llama_model_llama4 : public llama_model_base { + llama_model_llama4(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_arcee : public llm_graph_context { - llm_build_arcee(const llama_model & model, const llm_graph_params & params); + +struct llama_model_llama_embed : public llama_model_llama { + llama_model_llama_embed(const struct llama_model_params & params) : llama_model_llama(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_llama + + template + using graph = llama_model_llama::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_arctic : public llm_graph_context { - llm_build_arctic(const llama_model & model, const llm_graph_params & params); + +struct llama_model_maincoder : public llama_model_base { + llama_model_maincoder(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_arwkv7 : public llm_build_rwkv7_base { - llm_build_arwkv7(const llama_model & model, const llm_graph_params & params); + +struct llama_model_deci : public llama_model_base { + llama_model_deci(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_baichuan : public llm_graph_context { - llm_build_baichuan(const llama_model & model, const llm_graph_params & params); + +struct llama_model_baichuan : public llama_model_base { + llama_model_baichuan(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_bailingmoe2 : public llm_graph_context { - llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_falcon : public llama_model_base { + llama_model_falcon(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_bailingmoe : public llm_graph_context { - llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_grok : public llama_model_base { + llama_model_grok(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_bert : public llm_graph_context { - llm_build_bert(const llama_model & model, const llm_graph_params & params); + +struct llama_model_starcoder : public llama_model_base { + llama_model_starcoder(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_bitnet : public llm_graph_context { - llm_build_bitnet(const llama_model & model, const llm_graph_params & params); + +struct llama_model_refact : public llama_model_base { + llama_model_refact(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_bloom : public llm_graph_context { - llm_build_bloom(const llama_model & model, const llm_graph_params & params); + +struct llama_model_bert : public llama_model_base { + llama_model_bert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_chameleon : public llm_graph_context { - llm_build_chameleon(const llama_model & model, const llm_graph_params & params); + +struct llama_model_jina_bert_v2 : public llama_model_base { + llama_model_jina_bert_v2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_bert::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_chatglm : public llm_graph_context { - llm_build_chatglm(const llama_model & model, const llm_graph_params & params); + +struct llama_model_jina_bert_v3 : public llama_model_base { + llama_model_jina_bert_v3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_bert::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_codeshell : public llm_graph_context { - llm_build_codeshell(const llama_model & model, const llm_graph_params & params); + +struct llama_model_nomic_bert : public llama_model_base { + llama_model_nomic_bert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_bert::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_cogvlm : public llm_graph_context { - llm_build_cogvlm(const llama_model & model, const llm_graph_params & params); + +struct llama_model_nomic_bert_moe : public llama_model_base { + llama_model_nomic_bert_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_bert::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_cohere2_iswa : public llm_graph_context { - llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params); + +struct llama_model_modern_bert : public llama_model_base { + llama_model_modern_bert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_command_r : public llm_graph_context { - llm_build_command_r(const llama_model & model, const llm_graph_params & params); + +struct llama_model_neo_bert : public llama_model_base { + llama_model_neo_bert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_dbrx : public llm_graph_context { - llm_build_dbrx(const llama_model & model, const llm_graph_params & params); + +struct llama_model_eurobert : public llama_model_base { + llama_model_eurobert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_deci : public llm_graph_context { - llm_build_deci(const llama_model & model, const llm_graph_params & params); + +struct llama_model_bloom : public llama_model_base { + llama_model_bloom(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_deepseek2 : public llm_graph_context { - llm_build_deepseek2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_mpt : public llama_model_base { + llama_model_mpt(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_deepseek : public llm_graph_context { - llm_build_deepseek(const llama_model & model, const llm_graph_params & params); + +struct llama_model_stablelm : public llama_model_base { + llama_model_stablelm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_dots1 : public llm_graph_context { - llm_build_dots1(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen : public llama_model_base { + llama_model_qwen(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_dream : public llm_graph_context { - llm_build_dream(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen2 : public llama_model_base { + llama_model_qwen2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_ernie4_5 : public llm_graph_context { - llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params); + +struct llama_model_dream : public llama_model_base { + llama_model_dream(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_ernie4_5_moe : public llm_graph_context { - llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_llada : public llama_model_base { + llama_model_llada(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_paddleocr : public llm_graph_context { - llm_build_paddleocr(const llama_model & model, const llm_graph_params & params); + +struct llama_model_llada_moe : public llama_model_base { + llama_model_llada_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_exaone4 : public llm_graph_context { - llm_build_exaone4(const llama_model & model, const llm_graph_params & params); + +struct llama_model_rnd1 : public llama_model_base { + llama_model_rnd1(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_exaone : public llm_graph_context { - llm_build_exaone(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen2vl : public llama_model_base { + llama_model_qwen2vl(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_exaone_moe : public llm_graph_context { - llm_build_exaone_moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen2moe : public llama_model_base { + llama_model_qwen2moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_falcon : public llm_graph_context { - llm_build_falcon(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen3 : public llama_model_base { + llama_model_qwen3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_falcon_h1 : public llm_build_mamba_base { - llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen3moe : public llama_model_base { + llama_model_qwen3moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gemma2_iswa : public llm_graph_context { - llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen3vl : public llama_model_base { + llama_model_qwen3vl(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_gemma3 : public llm_graph_context { - llm_build_gemma3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen3vlmoe : public llama_model_base { + llama_model_qwen3vlmoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gemma3n_iswa : public llm_graph_context { - const llama_model & model; - const int64_t n_embd_head; - const int64_t n_embd_altup; - const int64_t n_altup; - const int i_altup_act; - const int n_layer_sparsity = 10; // number of layers using activation sparsity - const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95) +struct llama_model_phi2 : public llama_model_base { + llama_model_phi2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; - llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params); - ggml_tensor * calc_magnitude(ggml_tensor * x); - // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] - ggml_tensor * build_inp_per_layer(); - ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); +struct llama_model_phi3 : public llama_model_base { + llama_model_phi3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - ggml_tensor * gaussian_topk(ggml_tensor * x); - ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il); - ggml_tensor * altup_predict(ggml_tensor * cur, int il); - ggml_tensor * laurel(ggml_tensor * cur, int il); - ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il); + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gemma4_iswa : public llm_graph_context { - const llama_model & model; - const int64_t n_embd_per_layer; +struct llama_model_phimoe : public llama_model_base { + llama_model_phimoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - llm_build_gemma4_iswa(const llama_model & model, const llm_graph_params & params); + template + using graph = llama_model_phi3::graph; - // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] - ggml_tensor * build_inp_per_layer(); - ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gemma_embedding : public llm_graph_context { - llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params); + +struct llama_model_plamo : public llama_model_base { + llama_model_plamo(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gemma : public llm_graph_context { - llm_build_gemma(const llama_model & model, const llm_graph_params & params); + +struct llama_model_plamo2 : public llama_model_base { + llama_model_plamo2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + private: + ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); + ggml_tensor * build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, ggml_tensor * inp_pos, ggml_tensor * cur, + const llama_model & model, int il); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_glm4 : public llm_graph_context { - llm_build_glm4(const llama_model & model, const llm_graph_params & params); + +struct llama_model_plamo3 : public llama_model_base { + llama_model_plamo3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_glm4_moe : public llm_graph_context { - llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gpt2 : public llama_model_base { + llama_model_gpt2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gpt2 : public llm_graph_context { - llm_build_gpt2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_codeshell : public llama_model_base { + llama_model_codeshell(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gptneox : public llm_graph_context { - llm_build_gptneox(const llama_model & model, const llm_graph_params & params); + +struct llama_model_orion : public llama_model_base { + llama_model_orion(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_granite : public llm_graph_context { - llm_build_granite(const llama_model & model, const llm_graph_params & params); -private: - ggml_tensor * build_attention_layer( - ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_attn_kv * inp_attn, - const llama_model & model, - const int64_t n_embd_head, - const int il); +struct llama_model_internlm2 : public llama_model_base { + llama_model_internlm2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - ggml_tensor * inpSA, - const llama_model & model, - const int il); + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_granite_hybrid : public llm_build_mamba_base { - llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il); - ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, - const llama_model & model,const int64_t n_embd_head, const int il); + +struct llama_model_minicpm3 : public llama_model_base { + llama_model_minicpm3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_grok : public llm_graph_context { - llm_build_grok(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma : public llama_model_base { + llama_model_gemma(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_grovemoe : public llm_graph_context { - llm_build_grovemoe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma2 : public llama_model_base { + llama_model_gemma2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_hunyuan_dense : public llm_graph_context { - llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma3 : public llama_model_base { + llama_model_gemma3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_hunyuan_moe : public llm_graph_context { - llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma3n : public llama_model_base { + llama_model_gemma3n(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + const llama_model & model; + + const int64_t n_embd_head; + const int64_t n_embd_altup; + const int64_t n_altup; + const int i_altup_act; + const int n_layer_sparsity = 10; // number of layers using activation sparsity + const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95) + + graph(const llama_model & model, const llm_graph_params & params); + ggml_tensor * calc_magnitude(ggml_tensor * x); + + // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] + ggml_tensor * build_inp_per_layer(); + ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); + + ggml_tensor * gaussian_topk(ggml_tensor * x); + ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il); + ggml_tensor * altup_predict(ggml_tensor * cur, int il); + ggml_tensor * laurel(ggml_tensor * cur, int il); + ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_internlm2 : public llm_graph_context { - llm_build_internlm2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma4 : public llama_model_base { + llama_model_gemma4(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + const llama_model & model; + + const int64_t n_embd_per_layer; + + graph(const llama_model & model, const llm_graph_params & params); + + // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] + ggml_tensor * build_inp_per_layer(); + ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_jais : public llm_graph_context { - llm_build_jais(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma_embedding : public llama_model_base { + llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_jais2 : public llm_graph_context { - llm_build_jais2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_starcoder2 : public llama_model_base { + llama_model_starcoder2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_jamba : public llm_build_mamba_base { - llm_build_jamba(const llama_model & model, const llm_graph_params & params); + +struct llama_model_mamba : public llama_model_base { + llama_model_mamba(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_kimi_linear : public llm_build_delta_net_base { - llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params); - std::pair build_kda_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * gk, - ggml_tensor * beta, - ggml_tensor * state, - int il); +struct llama_model_mamba2 : public llama_model_base { + llama_model_mamba2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - std::pair build_kda_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * gk, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il); + using graph = llama_model_mamba::graph; - const llama_model & model; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_lfm2 : public llm_graph_context { - llm_build_lfm2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_jamba : public llama_model_base { + llama_model_jamba(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_llada : public llm_graph_context { - llm_build_llada(const llama_model & model, const llm_graph_params & params); + +struct llama_model_xverse : public llama_model_base { + llama_model_xverse(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_llada_moe : public llm_graph_context { - llm_build_llada_moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_command_r : public llama_model_base { + llama_model_command_r(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_cohere2 : public llama_model_base { + llama_model_cohere2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_dbrx : public llama_model_base { + llama_model_dbrx(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_olmo : public llama_model_base { + llama_model_olmo(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_olmo2 : public llama_model_base { + llama_model_olmo2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_olmoe : public llama_model_base { + llama_model_olmoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_openelm : public llama_model_base { + llama_model_openelm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_gptneox : public llama_model_base { + llama_model_gptneox(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_arctic : public llama_model_base { + llama_model_arctic(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_llama : public llm_graph_context { - llm_build_llama(const llama_model & model, const llm_graph_params & params); + +struct llama_model_deepseek : public llama_model_base { + llama_model_deepseek(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_llama4 : public llm_graph_context { - llm_build_llama4(const llama_model & model, const llm_graph_params & params); + +struct llama_model_deepseek2 : public llama_model_base { + llama_model_deepseek2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_maincoder : public llm_graph_context { - llm_build_maincoder(const llama_model & model, const llm_graph_params & params); + +struct llama_model_deepseek2ocr : public llama_model_base { + llama_model_deepseek2ocr(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_deepseek2::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_mamba : public llm_build_mamba_base { - llm_build_mamba(const llama_model & model, const llm_graph_params & params); + +struct llama_model_glm_dsa : public llama_model_base { + llama_model_glm_dsa(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_deepseek2::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_mimo2_iswa : public llm_graph_context { - llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params); + +struct llama_model_mistral4 : public llama_model_deepseek2 { + llama_model_mistral4(const struct llama_model_params & params) : llama_model_deepseek2(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_deepseek2 + + using graph = llama_model_deepseek2::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_minicpm3 : public llm_graph_context { - llm_build_minicpm3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_chatglm : public llama_model_base { + llama_model_chatglm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_minimax_m2 : public llm_graph_context { - llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_glm4 : public llama_model_base { + llama_model_glm4(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_mistral3 : public llm_graph_context { - llm_build_mistral3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_glm4_moe : public llama_model_base { + llama_model_glm4_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_modern_bert : public llm_graph_context { - llm_build_modern_bert(const llama_model & model, const llm_graph_params & params); + +struct llama_model_bitnet : public llama_model_base { + llama_model_bitnet(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_mpt : public llm_graph_context { - llm_build_mpt(const llama_model & model, const llm_graph_params & params); + +struct llama_model_t5 : public llama_model_base { + llama_model_t5(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_nemotron : public llm_graph_context { - llm_build_nemotron(const llama_model & model, const llm_graph_params & params); + +struct llama_model_t5encoder : public llama_model_base { + llama_model_t5encoder(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_t5::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_nemotron_h : public llm_build_mamba_base { - llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il); - ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, - const llama_model & model, int64_t n_embd_head, int il); + +struct llama_model_jais : public llama_model_base { + llama_model_jais(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_neo_bert : public llm_graph_context { - llm_build_neo_bert(const llama_model & model, const llm_graph_params & params); + +struct llama_model_jais2 : public llama_model_base { + llama_model_jais2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_eurobert : public llm_graph_context { - llm_build_eurobert(const llama_model & model, const llm_graph_params & params); + +struct llama_model_nemotron : public llama_model_base { + llama_model_nemotron(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_olmo2 : public llm_graph_context { - llm_build_olmo2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_nemotron_h : public llama_model_base { + llama_model_nemotron_h(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il); + ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, + const llama_model & model, int64_t n_embd_head, int il); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_olmoe : public llm_graph_context { - llm_build_olmoe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_nemotron_h_moe : public llama_model_nemotron_h { + llama_model_nemotron_h_moe(const struct llama_model_params & params) : llama_model_nemotron_h(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_nemotron_h + + using graph = llama_model_nemotron_h::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_olmo : public llm_graph_context { - llm_build_olmo(const llama_model & model, const llm_graph_params & params); + +struct llama_model_exaone : public llama_model_base { + llama_model_exaone(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_openai_moe_iswa : public llm_graph_context { - llm_build_openai_moe_iswa(const llama_model & model, const llm_graph_params & params); + +struct llama_model_exaone4 : public llama_model_base { + llama_model_exaone4(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_openelm : public llm_graph_context { - llm_build_openelm(const llama_model & model, const llm_graph_params & params); + +struct llama_model_exaone_moe : public llama_model_base { + llama_model_exaone_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_orion : public llm_graph_context { - llm_build_orion(const llama_model & model, const llm_graph_params & params); + +struct llama_model_rwkv6 : public llama_model_base { + llama_model_rwkv6(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_rwkv6_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_pangu_embedded : public llm_graph_context { - llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params); + +struct llama_model_rwkv6qwen2 : public llama_model_base { + llama_model_rwkv6qwen2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_rwkv6_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_phi2 : public llm_graph_context { - llm_build_phi2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_rwkv7 : public llama_model_base { + llama_model_rwkv7(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_rwkv7_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_phi3 : public llm_graph_context { - llm_build_phi3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_arwkv7 : public llama_model_base { + llama_model_arwkv7(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_rwkv7_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_plamo2 : public llm_build_mamba_base { - llm_build_plamo2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_granite : public llama_model_base { + llama_model_granite(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + private: - ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); - ggml_tensor * build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, ggml_tensor * inp_pos, ggml_tensor * cur, - const llama_model & model, int il); + ggml_tensor * build_attention_layer( + ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv * inp_attn, + const llama_model & model, + const int64_t n_embd_head, + const int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + ggml_tensor * inpSA, + const llama_model & model, + const int il); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_plamo : public llm_graph_context { - llm_build_plamo(const llama_model & model, const llm_graph_params & params); + +struct llama_model_granite_moe : public llama_model_base { + llama_model_granite_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_granite::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_plamo3 : public llm_graph_context { - llm_build_plamo3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_minicpm : public llama_model_base { + llama_model_minicpm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_granite::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_plm : public llm_graph_context { - llm_build_plm(const llama_model & model, const llm_graph_params & params); + +struct llama_model_granite_hybrid : public llama_model_base { + llama_model_granite_hybrid(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il); + ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, + const llama_model & model,const int64_t n_embd_head, const int il); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen2 : public llm_graph_context { - llm_build_qwen2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_chameleon : public llama_model_base { + llama_model_chameleon(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen2moe : public llm_graph_context { - llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_wavtokenizer_dec : public llama_model_base { + llama_model_wavtokenizer_dec(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen2vl : public llm_graph_context { - llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params); + +struct llama_model_plm : public llama_model_base { + llama_model_plm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3 : public llm_graph_context { - llm_build_qwen3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_bailingmoe : public llama_model_base { + llama_model_bailingmoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3moe : public llm_graph_context { - llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_bailingmoe2 : public llama_model_base { + llama_model_bailingmoe2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3vl : public llm_graph_context { - llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params); + +struct llama_model_seed_oss : public llama_model_base { + llama_model_seed_oss(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3vlmoe : public llm_graph_context { - llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_dots1 : public llama_model_base { + llama_model_dots1(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3next : public llm_build_delta_net_base { - llm_build_qwen3next(const llama_model & model, const llm_graph_params & params); -private: - ggml_tensor * build_layer_attn( - llm_graph_input_attn_kv * inp_attn, - ggml_tensor * cur, - ggml_tensor * inp_pos, - int il); - ggml_tensor * build_layer_attn_linear( - llm_graph_input_rs * inp, - ggml_tensor * cur, - int il); +struct llama_model_arcee : public llama_model_base { + llama_model_arcee(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - int il); + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; - ggml_tensor * build_norm_gated( - ggml_tensor * input, - ggml_tensor * weights, - ggml_tensor * gate, - int layer); + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; - // returns pair of qkv, z - std::pair build_qkvz( - ggml_tensor * input, - int il); - const llama_model & model; +struct llama_model_afmoe : public llama_model_base { + llama_model_afmoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen35 : public llm_build_delta_net_base { - llm_build_qwen35(const llama_model & model, const llm_graph_params & params); -private: - ggml_tensor * build_layer_attn( - llm_graph_input_attn_kv * inp_attn, - ggml_tensor * cur, - ggml_tensor * inp_pos, - int * sections, - int il); - ggml_tensor * build_layer_attn_linear( - llm_graph_input_rs * inp, - ggml_tensor * cur, - int il); +struct llama_model_ernie4_5 : public llama_model_base { + llama_model_ernie4_5(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - int il); + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; - ggml_tensor * build_norm_gated( - ggml_tensor * input, - ggml_tensor * weights, - ggml_tensor * gate, - int layer); + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; - // returns pair of qkv, z - std::pair build_qkvz( - ggml_tensor * input, - int il); - const llama_model & model; +struct llama_model_ernie4_5_moe : public llama_model_ernie4_5 { + llama_model_ernie4_5_moe(const struct llama_model_params & params) : llama_model_ernie4_5(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_ernie4_5 + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -// TODO: derive llm_build_delta_net_base instead -struct llm_build_qwen35moe : public llm_build_delta_net_base { - llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params); -private: - ggml_tensor * build_layer_attn( - llm_graph_input_attn_kv * inp_attn, - ggml_tensor * cur, - ggml_tensor * inp_pos, - int * sections, - int il); - ggml_tensor * build_layer_attn_linear( - llm_graph_input_rs * inp, - ggml_tensor * cur, - int il); +struct llama_model_paddleocr : public llama_model_ernie4_5 { + llama_model_paddleocr(const struct llama_model_params & params) : llama_model_ernie4_5(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_ernie4_5 - ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - int il); + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; - ggml_tensor * build_norm_gated( - ggml_tensor * input, - ggml_tensor * weights, - ggml_tensor * gate, - int layer); + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; - // returns pair of qkv, z - std::pair build_qkvz( - ggml_tensor * input, - int il); - const llama_model & model; +struct llama_model_hunyuan_moe : public llama_model_base { + llama_model_hunyuan_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen : public llm_graph_context { - llm_build_qwen(const llama_model & model, const llm_graph_params & params); + +struct llama_model_hunyuan_vl : public llama_model_base { + llama_model_hunyuan_vl(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_refact : public llm_graph_context { - llm_build_refact(const llama_model & model, const llm_graph_params & params); + +struct llama_model_hunyuan_dense : public llama_model_hunyuan_vl { + llama_model_hunyuan_dense(const struct llama_model_params & params) : llama_model_hunyuan_vl(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_hunyuan_vl + + using graph = llama_model_hunyuan_vl::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_rnd1 : public llm_graph_context { - llm_build_rnd1(const llama_model & model, const llm_graph_params & params); + +struct llama_model_smollm3 : public llama_model_base { + llama_model_smollm3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_rwkv6 : public llm_build_rwkv6_base { - llm_build_rwkv6(const llama_model & model, const llm_graph_params & params); + +struct llama_model_openai_moe : public llama_model_base { + llama_model_openai_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { - llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_falcon_h1 : public llama_model_base { + llama_model_falcon_h1(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_rwkv7 : public llm_build_rwkv7_base { - llm_build_rwkv7(const llama_model & model, const llm_graph_params & params); + +struct llama_model_lfm2 : public llama_model_base { + llama_model_lfm2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_seed_oss : public llm_graph_context { - llm_build_seed_oss(const llama_model & model, const llm_graph_params & params); + +struct llama_model_lfm2moe : public llama_model_base { + llama_model_lfm2moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + using graph = llama_model_lfm2::graph; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_smallthinker : public llm_graph_context { - llm_build_smallthinker(const llama_model & model, const llm_graph_params & params); + +struct llama_model_smallthinker : public llama_model_base { + llama_model_smallthinker(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_smollm3 : public llm_graph_context { - llm_build_smollm3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_grovemoe : public llama_model_base { + llama_model_grovemoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_stablelm : public llm_graph_context { - llm_build_stablelm(const llama_model & model, const llm_graph_params & params); + +struct llama_model_apertus : public llama_model_base { + llama_model_apertus(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_starcoder2 : public llm_graph_context { - llm_build_starcoder2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_minimax_m2 : public llama_model_base { + llama_model_minimax_m2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_starcoder : public llm_graph_context { - llm_build_starcoder(const llama_model & model, const llm_graph_params & params); + +struct llama_model_cogvlm : public llama_model_base { + llama_model_cogvlm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_step35_iswa : public llm_graph_context { - llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params); + +struct llama_model_pangu_embed : public llama_model_base { + llama_model_pangu_embed(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -template -struct llm_build_t5 : public llm_graph_context { - llm_build_t5(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen3next : public llama_model_base { + llama_model_qwen3next(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_delta_net_base { + graph(const llama_model & model, const llm_graph_params & params); + private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int il); + + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair build_qkvz( + ggml_tensor * input, + int il); + + const llama_model & model; + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_t5encoder : public llm_build_t5 { - llm_build_t5encoder(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen35 : public llama_model_base { + llama_model_qwen35(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_delta_net_base { + graph(const llama_model & model, const llm_graph_params & params); + private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il); + + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair build_qkvz( + ggml_tensor * input, + int il); + + const llama_model & model; + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_wavtokenizer_dec : public llm_graph_context { - llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen35moe : public llama_model_base { + llama_model_qwen35moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_delta_net_base { + graph(const llama_model & model, const llm_graph_params & params); + private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il); + + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair build_qkvz( + ggml_tensor * input, + int il); + + const llama_model & model; + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_mistral3 : public llama_model_base { + llama_model_mistral3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_mimo2 : public llama_model_base { + llama_model_mimo2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_xverse : public llm_graph_context { - llm_build_xverse(const llama_model & model, const llm_graph_params & params); + +struct llama_model_kimi_linear : public llama_model_base { + llama_model_kimi_linear(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_delta_net_base { + graph(const llama_model & model, const llm_graph_params & params); + + std::pair build_kda_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * gk, + ggml_tensor * beta, + ggml_tensor * state, + int il); + + std::pair build_kda_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * gk, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il); + + const llama_model & model; + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_step35 : public llama_model_base { + llama_model_step35(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; diff --git a/examples/talk-llama/models/modern-bert.cpp b/examples/talk-llama/models/modern-bert.cpp index 5c6a1b5e1bc..e9b79ffc6dc 100644 --- a/examples/talk-llama/models/modern-bert.cpp +++ b/examples/talk-llama/models/modern-bert.cpp @@ -1,6 +1,69 @@ #include "models.h" -llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_modern_bert::load_arch_hparams(llama_model_loader & ml) { + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + uint32_t swa_period = 3; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period, true); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 12: + type = LLM_TYPE_47M; break; // granite-embedding-small + case 22: + type = LLM_TYPE_149M; break; // modern-bert-base + case 28: + type = LLM_TYPE_395M; break; // modern-bert-large + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_modern_bert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for(int i = 0; i < n_layer; ++i) { + auto& layer = layers[i]; + + if ( i != 0 ) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + } else{ + // layer 0 uses identity + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + } + + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2 * n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_norm = create_tensor(tn(LLM_TENSOR_CLS_NORM, "weight"), {n_embd}, TENSOR_NOT_REQUIRED); + +} + +std::unique_ptr llama_model_modern_bert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_modern_bert::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index 8596bbb2024..cfc60e8de29 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -1,6 +1,70 @@ #include "models.h" -llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_mpt::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_30B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mpt::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, TENSOR_NOT_REQUIRED); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + + // FIXME test-llama-archs crashes if q_norm is created + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + // AWQ ScaleActivation layer + layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_mpt::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_mpt::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/nemotron-h-moe.cpp b/examples/talk-llama/models/nemotron-h-moe.cpp new file mode 100644 index 00000000000..a59cc6c9fbd --- /dev/null +++ b/examples/talk-llama/models/nemotron-h-moe.cpp @@ -0,0 +1,6 @@ +#include "models.h" + +std::unique_ptr llama_model_nemotron_h_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index dc07d43df58..865461f61db 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -1,6 +1,127 @@ #include "models.h" -llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) : +void llama_model_nemotron_h::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // A layer is recurrent IFF the n_head_kv value is set to 0 and + // the n_ff value is set to 0 + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_MOE_LATENT_SIZE, hparams.moe_latent_size, false); + + switch (hparams.n_layer) { + case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B + case 56: type = LLM_TYPE_9B; break; + case 88: type = LLM_TYPE_120B_A12B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_nemotron_h::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + // mamba2 Mixer SSM params + // NOTE: int64_t for tensor dimensions + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_ssm_head = hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + const int64_t moe_n_embd = hparams.moe_latent_size > 0 ? hparams.moe_latent_size : n_embd; + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // all blocks use the attn norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.is_recurrent(i)) { + // ssm layers + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else if (hparams.n_ff(i) == 0) { + // attention layers (with optional bias) + const int64_t n_head_i = hparams.n_head(i); + const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + } else { + if (n_expert != 0) { + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); + + // MoE branch + layer.ffn_latent_down = create_tensor(tn(LLM_TENSOR_FFN_LATENT_DOWN, "weight", i), {n_embd, moe_n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_latent_up = create_tensor(tn(LLM_TENSOR_FFN_LATENT_UP, "weight", i), {moe_n_embd, n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, moe_n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {moe_n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + + } else { + // mlp layers + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } + } + } +} + +std::unique_ptr llama_model_nemotron_h::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_nemotron_h::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -60,7 +181,7 @@ llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_ ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * cur, +ggml_tensor * llama_model_nemotron_h::graph::build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, const llama_model & model, int64_t n_embd_head, @@ -76,7 +197,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * return cur; } -ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) { +ggml_tensor * llama_model_nemotron_h::graph::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) { if (model.layers[il].ffn_gate_inp == nullptr) { cur = build_ffn(cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s, diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index 054b16fe0ef..0c72ed297aa 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -1,6 +1,52 @@ #include "models.h" -llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_nemotron::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_nemotron::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_nemotron::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_nemotron::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/neo-bert.cpp b/examples/talk-llama/models/neo-bert.cpp index da68024a34d..f00d6eddfc9 100644 --- a/examples/talk-llama/models/neo-bert.cpp +++ b/examples/talk-llama/models/neo-bert.cpp @@ -1,6 +1,46 @@ #include "models.h" -llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_neo_bert::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_layer == 28) { + type = LLM_TYPE_250M; + } +} + +void llama_model_neo_bert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff*2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } +} + +std::unique_ptr llama_model_neo_bert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_neo_bert::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/nomic-bert-moe.cpp b/examples/talk-llama/models/nomic-bert-moe.cpp new file mode 100644 index 00000000000..a17abe2c269 --- /dev/null +++ b/examples/talk-llama/models/nomic-bert-moe.cpp @@ -0,0 +1,72 @@ +#include "models.h" + +void llama_model_nomic_bert_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); + + if (hparams.n_layer == 12 && hparams.n_embd == 768) { + if (arch == LLM_ARCH_NOMIC_BERT) { + type = LLM_TYPE_137M; + } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { + type = LLM_TYPE_475M; + } + } +} + +void llama_model_nomic_bert_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_token_types == 0) { + throw std::runtime_error(arch_name() + " model needs to define token type count"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_NOMIC_BERT) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_nomic_bert_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/nomic-bert.cpp b/examples/talk-llama/models/nomic-bert.cpp new file mode 100644 index 00000000000..5a8a5584457 --- /dev/null +++ b/examples/talk-llama/models/nomic-bert.cpp @@ -0,0 +1,72 @@ +#include "models.h" + +void llama_model_nomic_bert::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); + + if (hparams.n_layer == 12 && hparams.n_embd == 768) { + if (arch == LLM_ARCH_NOMIC_BERT) { + type = LLM_TYPE_137M; + } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { + type = LLM_TYPE_475M; + } + } +} + +void llama_model_nomic_bert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_token_types == 0) { + throw std::runtime_error(arch_name() + " model needs to define token type count"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_NOMIC_BERT) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_nomic_bert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index a9974025f07..161035e72bc 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -1,6 +1,46 @@ #include "models.h" -llm_build_olmo::llm_build_olmo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_olmo::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + + switch (hparams.n_layer) { + case 22: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_olmo::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_olmo::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_olmo::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 308d2a600c2..9633f269965 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -1,7 +1,68 @@ #include "models.h" +void llama_model_olmo2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = 1.0; // See olmo2.cpp + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_olmo2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_head = n_embd / n_head; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_head_kv * n_embd_head}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_olmo2::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique>(*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_olmo2::llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_olmo2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -146,5 +207,5 @@ llm_build_olmo2::llm_build_olmo2(const llama_model & model, const llm_grap } // Explicit template instantiations -template struct llm_build_olmo2; -template struct llm_build_olmo2; +template struct llama_model_olmo2::graph; +template struct llama_model_olmo2::graph; diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index ed46a00ef90..4bb9013054c 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -1,6 +1,55 @@ #include "models.h" -llm_build_olmoe::llm_build_olmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_olmoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_A1_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_olmoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } +} + +std::unique_ptr llama_model_olmoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_olmoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/openai-moe-iswa.cpp b/examples/talk-llama/models/openai-moe.cpp similarity index 51% rename from examples/talk-llama/models/openai-moe-iswa.cpp rename to examples/talk-llama/models/openai-moe.cpp index 50992b8d506..13a590ce646 100644 --- a/examples/talk-llama/models/openai-moe-iswa.cpp +++ b/examples/talk-llama/models/openai-moe.cpp @@ -1,6 +1,67 @@ #include "models.h" -llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_openai_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 2; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_20B; break; + case 36: type = LLM_TYPE_120B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_openai_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); + + layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); + layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); + layer.ffn_down_exps_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), { n_embd, n_expert}, 0); + layer.ffn_up_exps_b = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr llama_model_openai_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_openai_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index 514ac33517f..b4128e116e7 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -1,6 +1,53 @@ #include "models.h" -llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_openelm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_270M; break; + case 20: type = LLM_TYPE_450M; break; + case 28: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_openelm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_qkv = 2*hparams.n_head_kv(i) + n_head; + const int64_t n_ff = hparams.n_ff(i); + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_openelm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_openelm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index a5874b6dee7..7ace0a5139d 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -1,6 +1,46 @@ #include "models.h" -llm_build_orion::llm_build_orion(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_orion::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_orion::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_orion::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_orion::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/paddleocr.cpp b/examples/talk-llama/models/paddleocr.cpp index 56cb1d94c5f..1c0eadefa98 100644 --- a/examples/talk-llama/models/paddleocr.cpp +++ b/examples/talk-llama/models/paddleocr.cpp @@ -1,6 +1,10 @@ #include "models.h" -llm_build_paddleocr::llm_build_paddleocr(const llama_model & model, const llm_graph_params & params) : +std::unique_ptr llama_model_paddleocr::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_paddleocr::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { // NOTE: same with qwen2vl.cpp, but bias tensors are optional diff --git a/examples/talk-llama/models/pangu-embedded.cpp b/examples/talk-llama/models/pangu-embed.cpp similarity index 53% rename from examples/talk-llama/models/pangu-embedded.cpp rename to examples/talk-llama/models/pangu-embed.cpp index 53464f21d22..41b7e2ac23e 100644 --- a/examples/talk-llama/models/pangu-embedded.cpp +++ b/examples/talk-llama/models/pangu-embed.cpp @@ -1,6 +1,60 @@ #include "models.h" -llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_pangu_embed::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1 + case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1 + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_pangu_embed::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // weight tensors + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_pangu_embed::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_pangu_embed::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index 0fb3ffa2e63..a333602c72d 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -1,6 +1,50 @@ #include "models.h" -llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_phi2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_phi2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr llama_model_phi2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_phi2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index 39af285d3c5..0a65e91fefa 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -1,7 +1,71 @@ #include "models.h" +void llama_model_phi3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (found_swa && hparams.n_swa > 0) { + LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/13676"); + + // TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern` + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + + hparams.n_swa = 0; + hparams.set_swa_pattern(1); + } +} + +void llama_model_phi3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } +} + +std::unique_ptr llama_model_phi3::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + return std::make_unique> (*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_phi3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -128,5 +192,5 @@ llm_build_phi3::llm_build_phi3(const llama_model & model, const llm_graph_ } // Explicit template instantiations -template struct llm_build_phi3; -template struct llm_build_phi3; +template struct llama_model_phi3::graph; +template struct llama_model_phi3::graph; diff --git a/examples/talk-llama/models/phimoe.cpp b/examples/talk-llama/models/phimoe.cpp new file mode 100644 index 00000000000..4575d6139cf --- /dev/null +++ b/examples/talk-llama/models/phimoe.cpp @@ -0,0 +1,55 @@ +#include "models.h" + +void llama_model_phimoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_16x3_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_phimoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_head = n_embd / n_head; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } +} + +std::unique_ptr llama_model_phimoe::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + return std::make_unique> (*this, params); + } else { + return std::make_unique>(*this, params); + } +} + diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index 4d5c84506c2..4c16c20a0d4 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -1,6 +1,42 @@ #include "models.h" -llm_build_plamo::llm_build_plamo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_plamo::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_plamo::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_plamo::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_plamo::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index b6142daebd9..29c8702606a 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -1,8 +1,109 @@ #include "models.h" - #include "llama-memory-recurrent.h" -llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) : +void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load Mamba SSM parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + } + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; + case 32: + if (hparams.n_embd == 2048) { + type = LLM_TYPE_2B; + } else if (hparams.n_embd == 4096) { + type = LLM_TYPE_8B; + } + break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_plamo2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + // mamba parameters + const uint32_t d_conv = hparams.ssm_d_conv; + const uint32_t d_state = hparams.ssm_d_state; + const uint32_t num_heads = hparams.ssm_dt_rank; + const uint32_t intermediate_size = hparams.ssm_d_inner; + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + + // attention parameters + const uint32_t qk_dim = hparams.n_embd_head_k(); + const uint32_t v_dim = hparams.n_embd_head_v(); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + bool is_mamba_layer = hparams.is_recurrent(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (is_mamba_layer) { + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0); + + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0); + + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0); + + layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0); + layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); + layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); + } else { + const int64_t num_attention_heads = hparams.n_head(i); + const int64_t q_num_heads = num_attention_heads; + const int64_t num_key_value_heads = hparams.n_head_kv(i); + const int64_t k_num_heads = num_key_value_heads; + const int64_t v_num_heads = num_key_value_heads; + const int64_t q_proj_dim = q_num_heads * qk_dim; + const int64_t k_proj_dim = k_num_heads * qk_dim; + const int64_t v_proj_dim = v_num_heads * v_dim; + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); + } + + // All layers have post-attention norm, FFN norm, and FFN tensors + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + } +} + +std::unique_ptr llama_model_plamo2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -95,7 +196,7 @@ llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_pa ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, +ggml_tensor * llama_model_plamo2::graph::build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, ggml_tensor * inp_pos, ggml_tensor * cur, const llama_model & model, @@ -150,7 +251,7 @@ ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv return cur; } -ggml_tensor * llm_build_plamo2::build_plamo2_mamba_layer(llm_graph_input_rs * inp, +ggml_tensor * llama_model_plamo2::graph::build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 67844c09f24..849f1579e63 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -1,7 +1,74 @@ #include "models.h" +void llama_model_plamo3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + uint32_t swa_period = 8; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_plamo3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t head_dim_q = hparams.n_embd_head_k(); + const int64_t head_dim_v = hparams.n_embd_head_v(); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const int64_t num_attention_heads = hparams.n_head(i); + const int64_t num_key_value_heads = hparams.n_head_kv(i); + const int64_t q_proj_dim = num_attention_heads * head_dim_q; + const int64_t k_proj_dim = num_key_value_heads * head_dim_q; + const int64_t v_proj_dim = num_key_value_heads * head_dim_v; + const int64_t n_ff_cur = hparams.n_ff(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), + {n_embd,q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim_q}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim_q}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {num_attention_heads * head_dim_v, n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur * 2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); + } +} + +std::unique_ptr llama_model_plamo3::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + return std::make_unique> (*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_graph_params & params) : +llama_model_plamo3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t head_dim_q = hparams.n_embd_head_k(); const int64_t head_dim_v = hparams.n_embd_head_v(); @@ -126,5 +193,5 @@ llm_build_plamo3::llm_build_plamo3(const llama_model & model, const llm_gr } // Explicit template instantiations -template struct llm_build_plamo3; -template struct llm_build_plamo3; +template struct llama_model_plamo3::graph; +template struct llama_model_plamo3::graph; diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index abce6b34d04..57f5995103b 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -1,6 +1,50 @@ #include "models.h" -llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_plm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_plm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); + const int64_t kv_lora_rank = hparams.n_lora_kv; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_plm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_plm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k())); const uint32_t n_embd_head_qk_rope = hparams.n_rot(); diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index 44e75d87437..cdc076cdf77 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -1,6 +1,46 @@ #include "models.h" -llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}, 0); + } +} + +std::unique_ptr llama_model_qwen::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index 2892dd75087..6320458a13b 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -1,6 +1,55 @@ #include "models.h" -llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; + case 28: type = hparams.n_embd == 1536 ? LLM_TYPE_1_5B : LLM_TYPE_7B; break; + case 32: type = LLM_TYPE_7B; break; + case 36: type = LLM_TYPE_3B; break; + case 40: type = hparams.n_head() == 20 ? LLM_TYPE_4B : LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_qwen2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 5f0a6861b68..7587c802c68 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -1,6 +1,67 @@ #include "models.h" -llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen2moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_A2_7B; break; + case 28: type = LLM_TYPE_57B_A14B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen2moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN2MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } +} + +std::unique_ptr llama_model_qwen2moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen2moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/qwen2vl.cpp b/examples/talk-llama/models/qwen2vl.cpp index da7937c7667..1a40fa89be4 100644 --- a/examples/talk-llama/models/qwen2vl.cpp +++ b/examples/talk-llama/models/qwen2vl.cpp @@ -1,6 +1,45 @@ #include "models.h" -llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen2vl::load_arch_hparams(llama_model_loader & ml) { + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); +} +// fall through + +void llama_model_qwen2vl::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_qwen2vl::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen2vl::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index 883dd5f9a90..fa656c84ea0 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -1,6 +1,55 @@ #include "models.h" -llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; + case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; + case 40: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // output rerank head + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_qwen3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index 87790f08e4e..f276be61ba8 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -1,8 +1,96 @@ #include "models.h" - #include "llama-memory-recurrent.h" -llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) : +void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; + case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; + case 64: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_qwen35::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_delta_net_base(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -87,7 +175,7 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa ggml_build_forward_expand(gf, cur); } -std::pair llm_build_qwen35::build_qkvz( +std::pair llama_model_qwen35::graph::build_qkvz( ggml_tensor * input, int il) { const int64_t n_seqs = ubatch.n_seqs; @@ -103,7 +191,7 @@ std::pair llm_build_qwen35::build_qkvz( return { qkv_mixed, z }; } -ggml_tensor * llm_build_qwen35::build_norm_gated( +ggml_tensor * llama_model_qwen35::graph::build_norm_gated( ggml_tensor * input, ggml_tensor * weights, ggml_tensor * gate, @@ -114,7 +202,7 @@ ggml_tensor * llm_build_qwen35::build_norm_gated( return ggml_mul(ctx0, normalized, gated_silu); } -ggml_tensor * llm_build_qwen35::build_layer_attn( +ggml_tensor * llama_model_qwen35::graph::build_layer_attn( llm_graph_input_attn_kv * inp, ggml_tensor * cur, ggml_tensor * inp_pos, @@ -195,7 +283,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn( return cur; } -ggml_tensor * llm_build_qwen35::build_layer_attn_linear( +ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, int il) { @@ -369,7 +457,7 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( return cur; } -ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il) { +ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, const int il) { // Qwen3.5 does not use MoE FFN GGML_ASSERT(model.layers[il].ffn_gate_inp == nullptr); diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index 7dc6a23c751..cf05dc9d61c 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -1,8 +1,109 @@ #include "models.h" - #include "llama-memory-recurrent.h" -llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) : +void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_35B_A3B; break; + case 48: type = LLM_TYPE_122B_A10B; break; + case 60: type = LLM_TYPE_397B_A17B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + } +} + +std::unique_ptr llama_model_qwen35moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_delta_net_base(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v(); @@ -87,7 +188,7 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr ggml_build_forward_expand(gf, cur); } -std::pair llm_build_qwen35moe::build_qkvz( +std::pair llama_model_qwen35moe::graph::build_qkvz( ggml_tensor * input, int il) { const int64_t n_seqs = ubatch.n_seqs; @@ -103,7 +204,7 @@ std::pair llm_build_qwen35moe::build_qkvz( return { qkv_mixed, z }; } -ggml_tensor * llm_build_qwen35moe::build_norm_gated( +ggml_tensor * llama_model_qwen35moe::graph::build_norm_gated( ggml_tensor * input, ggml_tensor * weights, ggml_tensor * gate, @@ -114,7 +215,7 @@ ggml_tensor * llm_build_qwen35moe::build_norm_gated( return ggml_mul(ctx0, normalized, gated_silu); } -ggml_tensor * llm_build_qwen35moe ::build_layer_attn( +ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn( llm_graph_input_attn_kv * inp, ggml_tensor * cur, ggml_tensor * inp_pos, @@ -195,7 +296,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn( return cur; } -ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( +ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, int il) { @@ -369,7 +470,7 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( return cur; } -ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int il) { +ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, const int il) { // Check if this is an MoE layer GGML_ASSERT(model.layers[il].ffn_gate_inp != nullptr); diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index 16bedba994d..4440b83aa45 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -1,6 +1,65 @@ #include "models.h" -llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen3moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + case 94: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr llama_model_qwen3moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen3moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index 1beda70b7cf..cb1b4814caf 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -1,8 +1,113 @@ #include "models.h" - #include "llama-memory-recurrent.h" -llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : +void llama_model_qwen3next::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_80B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3next::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_expert == 0) { + throw std::runtime_error(arch_name() + " model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + // Calculate projection sizes + const int64_t qkvz_dim = key_dim * 2 + value_dim * 2; + const int64_t ba_dim = n_v_heads * 2; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const uint32_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : hparams.n_ff(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + // note: ssm_in is used by legacy GGUF + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + } +} + +std::unique_ptr llama_model_qwen3next::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_delta_net_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; @@ -87,7 +192,7 @@ static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); } -ggml_tensor * llm_build_qwen3next::build_norm_gated( +ggml_tensor * llama_model_qwen3next::graph::build_norm_gated( ggml_tensor * input, ggml_tensor * weights, ggml_tensor * gate, @@ -98,7 +203,7 @@ ggml_tensor * llm_build_qwen3next::build_norm_gated( return ggml_mul(ctx0, normalized, gated_silu); } -ggml_tensor * llm_build_qwen3next::build_layer_attn( +ggml_tensor * llama_model_qwen3next::graph::build_layer_attn( llm_graph_input_attn_kv * inp, ggml_tensor * cur, ggml_tensor * inp_pos, @@ -178,7 +283,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( return cur; } -std::pair llm_build_qwen3next::build_qkvz( +std::pair llama_model_qwen3next::graph::build_qkvz( ggml_tensor * input, int il) { const int64_t d_inner = hparams.ssm_d_inner; @@ -259,7 +364,7 @@ std::pair llm_build_qwen3next::build_qkvz( } } -ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( +ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, int il) { @@ -468,7 +573,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( return cur; } -ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int il) { +ggml_tensor * llama_model_qwen3next::graph::build_layer_ffn(ggml_tensor * cur, const int il) { // Check if this is an MoE layer if (model.layers[il].ffn_gate_inp != nullptr) { // MoE branch diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index faa5f2ef3c8..7871f8f7952 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -1,6 +1,56 @@ #include "models.h" -llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen3vl::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_1_7B; break; + case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3vl::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // output rerank head + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_qwen3vl::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen3vl::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const size_t n_deepstack_layers = hparams.n_deepstack_layers; const int64_t n_embd = hparams.n_embd; diff --git a/examples/talk-llama/models/qwen3vl-moe.cpp b/examples/talk-llama/models/qwen3vlmoe.cpp similarity index 57% rename from examples/talk-llama/models/qwen3vl-moe.cpp rename to examples/talk-llama/models/qwen3vlmoe.cpp index 29ee8278a4d..b99143c8908 100644 --- a/examples/talk-llama/models/qwen3vl-moe.cpp +++ b/examples/talk-llama/models/qwen3vlmoe.cpp @@ -1,6 +1,66 @@ #include "models.h" -llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen3vlmoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + case 94: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3vlmoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr llama_model_qwen3vlmoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_qwen3vlmoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const size_t n_deepstack_layers = hparams.n_deepstack_layers; const int64_t n_embd = hparams.n_embd; @@ -127,4 +187,3 @@ llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_ ggml_build_forward_expand(gf, cur); } - diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index 398eb368db0..f14f10917ff 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -1,6 +1,81 @@ #include "models.h" -llm_build_refact::llm_build_refact(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_refact::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; +} + +void llama_model_refact::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr llama_model_refact::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_refact::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index a917c19f25a..325ee73ba5c 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -1,7 +1,67 @@ #include "models.h" +void llama_model_rnd1::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 48: type = LLM_TYPE_30B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + // Set non-causal attention for diffusion models + hparams.causal_attn = false; +} + +void llama_model_rnd1::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr llama_model_rnd1::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + // RND1 is a Qwen3Moe AR model converted to diffusion model. -llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_rnd1::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/rwkv6.cpp b/examples/talk-llama/models/rwkv6.cpp index 032b219d6cb..2944711acec 100644 --- a/examples/talk-llama/models/rwkv6.cpp +++ b/examples/talk-llama/models/rwkv6.cpp @@ -1,6 +1,97 @@ #include "models.h" -llm_build_rwkv6::llm_build_rwkv6(const llama_model & model, const llm_graph_params & params) : +void llama_model_rwkv6::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1_6B; break; + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_rwkv6::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // Block 0, LN0 + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int time_mix_extra_dim = hparams.time_mix_extra_dim; + const int time_decay_extra_dim = hparams.time_decay_extra_dim; + const int head_size = hparams.wkv_head_size; + const int attn_hidden_size = n_embd; + const int ffn_size = hparams.n_ff_arr[0]; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); + + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); + + layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); + layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, TENSOR_NOT_REQUIRED); + GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL)); + + layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0); + layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); + layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); + layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); + layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0); + + layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); + layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); + layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0); + } + +} + +std::unique_ptr llama_model_rwkv6::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_rwkv6::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) { GGML_ASSERT(hparams.token_shift_count == 2); diff --git a/examples/talk-llama/models/rwkv6qwen2.cpp b/examples/talk-llama/models/rwkv6qwen2.cpp index e84e5973820..6f7d1f5722f 100644 --- a/examples/talk-llama/models/rwkv6qwen2.cpp +++ b/examples/talk-llama/models/rwkv6qwen2.cpp @@ -1,6 +1,87 @@ #include "models.h" -llm_build_rwkv6qwen2::llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) { +void llama_model_rwkv6qwen2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1_6B; break; + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_rwkv6qwen2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int time_mix_extra_dim = hparams.time_mix_extra_dim; + const int time_decay_extra_dim = hparams.time_decay_extra_dim; + const int head_size = hparams.wkv_head_size; + const int attn_hidden_size = n_embd; + int attn_key_value_size; + if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) { + attn_key_value_size = attn_hidden_size; + } else { + attn_key_value_size = n_head_kv * head_size; + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); + + layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); + + layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); + layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); + layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); + // optional bias tensors + layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, TENSOR_NOT_REQUIRED); + + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_rwkv6qwen2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_rwkv6qwen2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) { GGML_ASSERT(n_embd == hparams.n_embd_r()); ggml_tensor * cur; diff --git a/examples/talk-llama/models/rwkv7.cpp b/examples/talk-llama/models/rwkv7.cpp index 16ffa6901b9..b205e3935e1 100644 --- a/examples/talk-llama/models/rwkv7.cpp +++ b/examples/talk-llama/models/rwkv7.cpp @@ -1,6 +1,127 @@ #include "models.h" -llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_params & params) : +void llama_model_rwkv7::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); + ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); + ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); + ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer) { + case 12: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_190M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_450M; break; + case 2048: type = LLM_TYPE_1_5B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 28: + switch (hparams.n_embd) { + case 1536: type = LLM_TYPE_1_5B; break; + case 3584: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_2_9B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: + switch (hparams.n_embd) { + case 4096: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_rwkv7::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // Block 0, LN0 + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int n_lora_decay = hparams.n_lora_decay; + const int n_lora_iclr = hparams.n_lora_iclr; + const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; + const int n_lora_gate = hparams.n_lora_gate; + const int attn_hidden_size = n_embd; + const int ffn_size = hparams.n_ff_arr[0]; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); + + layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); + + layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); + layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); + + if (i == 0) { + // actually not used + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); + } else { + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); + } + + layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, 0); + layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, 0); + + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); + + layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); + + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); + + layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); + layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); + } + +} + +std::unique_ptr llama_model_rwkv7::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_rwkv7::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) { GGML_ASSERT(hparams.token_shift_count == 2); diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index 6db8d9781fe..83e114740b6 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -1,6 +1,51 @@ #include "models.h" -llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_seed_oss::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 64: type = LLM_TYPE_36B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_seed_oss::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const uint32_t head_dim = hparams.n_embd_head_k(); + const int64_t n_qo_dim = n_head * head_dim; + const int64_t n_kv_dim = n_head_kv * head_dim; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0); + + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + } +} + +std::unique_ptr llama_model_seed_oss::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_seed_oss::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index 55d09ec325d..3214e7cbad3 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -1,7 +1,80 @@ #include "models.h" +void llama_model_smallthinker::load_arch_hparams(llama_model_loader & ml) { + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period, true); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + hparams.n_no_rope_layer_step = hparams.n_layer; + } + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_4B; break; + case 52: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_smallthinker::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for SMALLTHINKER"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for SMALLTHINKER"); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + } +} + +std::unique_ptr llama_model_smallthinker::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique> (*this, params); + } else { + return std::make_unique>(*this, params); + } +} + template -llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){ +llama_model_smallthinker::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){ const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); @@ -113,5 +186,5 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, } // Explicit template instantiations -template struct llm_build_smallthinker; -template struct llm_build_smallthinker; +template struct llama_model_smallthinker::graph; +template struct llama_model_smallthinker::graph; diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index 83636dbf546..7adaf34c534 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -1,6 +1,49 @@ #include "models.h" -llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_smollm3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.n_no_rope_layer_step = 4; + + switch (hparams.n_layer) { + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_smollm3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_smollm3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_smollm3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index 9c19abd8835..8f613e55947 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -1,6 +1,54 @@ #include "models.h" -llm_build_stablelm::llm_build_stablelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_stablelm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_stablelm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional q and k layernorms, present in StableLM 2 12B + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); + + // optional FFN norm, not present in StableLM 2 12B which uses parallel residual + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_stablelm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_stablelm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index cf9fe95c35b..58cf0ac0edc 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -1,6 +1,62 @@ #include "models.h" -llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_starcoder::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + case 42: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_starcoder::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + // needs to be on GPU + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr llama_model_starcoder::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_starcoder::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index b6d4d5aac1a..45dae0602d4 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -1,6 +1,61 @@ #include "models.h" -llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_starcoder2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_3B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + case 52: type = LLM_TYPE_20B; break; // granite + case 88: type = LLM_TYPE_34B; break; // granite + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_starcoder2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional bias tensors + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}, 0); + } +} + +std::unique_ptr llama_model_starcoder2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_starcoder2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/step35-iswa.cpp b/examples/talk-llama/models/step35.cpp similarity index 52% rename from examples/talk-llama/models/step35-iswa.cpp rename to examples/talk-llama/models/step35.cpp index 86aa98909e7..c4789752d21 100644 --- a/examples/talk-llama/models/step35-iswa.cpp +++ b/examples/talk-llama/models/step35.cpp @@ -1,6 +1,108 @@ #include "models.h" -llm_build_step35_iswa::llm_build_step35_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_step35::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + // full_attention layer only use half of the RoPE dimensions + hparams.n_rot_full = hparams.n_rot_full / 2; + + // MoE + SWA parameters + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Step35 uses sigmoid gating by default (if not set in GGUF) + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); + + switch (hparams.n_layer) { + case 45: type = LLM_TYPE_196B_A11B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_step35::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor + // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. + uint32_t n_rot_max = 0; + for (int i = 0; i < n_layer; ++i) { + n_rot_max = std::max(n_rot_max, hparams.n_rot(i)); + } + if (n_rot_max == 0) { + n_rot_max = n_rot; + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + // optional rope factors (llama3) / longrope tensors + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); + + // head-wise attention gate (Step35 self_attn.g_proj) + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // dense MLP (leading dense blocks) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + // shared expert MLP + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr llama_model_step35::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_step35::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/t5.cpp b/examples/talk-llama/models/t5.cpp index 9f9dfef4012..27a0711ba41 100644 --- a/examples/talk-llama/models/t5.cpp +++ b/examples/talk-llama/models/t5.cpp @@ -1,7 +1,125 @@ #include "models.h" +void llama_model_t5::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + + uint32_t dec_start_token_id; + if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) { + hparams.dec_start_token_id = dec_start_token_id; + } + + hparams.dec_n_layer = hparams.n_layer; + ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false); + + switch (hparams.n_layer) { + case 6: type = LLM_TYPE_60M; break; // t5-small + case 8: type = LLM_TYPE_80M; break; // flan-t5-small + case 12: + switch (hparams.n_ff()) { + case 3072: type = LLM_TYPE_220M; break; // t5-base + case 2048: type = LLM_TYPE_250M; break; // flan-t5-base + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff()) { + case 4096: type = LLM_TYPE_770M; break; // t5-large + case 2816: type = LLM_TYPE_780M; break; // flan-t5-large + case 16384: type = LLM_TYPE_3B; break; // t5-3b + case 5120: type = LLM_TYPE_3B; break; // flan-t5-xl + case 65536: type = LLM_TYPE_11B; break; // t5-11b + case 10240: type = LLM_TYPE_11B; break; // flan-t5-xxl + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_t5::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // n_layer: number of encoder_layers + // dec_n_layer: number of decoder_layers + const int dec_n_layer = hparams.dec_n_layer; + if (dec_n_layer > n_layer) { + layers.resize(dec_n_layer); + } + + // load encoder layers + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // load decoder layers + for (int i = 0; i < dec_n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.attn_norm_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}, 0); + // this tensor seems to be unused in HF transformers implementation + layer.attn_rel_b_cross = create_tensor( + tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + + layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_t5::build_arch_graph(const llm_graph_params & params) const { + switch (params.gtype) { + case LLM_GRAPH_TYPE_ENCODER: + return std::make_unique>(*this, params); + case LLM_GRAPH_TYPE_DEFAULT: + case LLM_GRAPH_TYPE_DECODER: + return std::make_unique>(*this, params); + default: + GGML_ABORT("invalid graph type"); + }; +} + template <> -llm_build_t5::llm_build_t5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_t5::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); //const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -156,7 +274,7 @@ llm_build_t5::llm_build_t5(const llama_model & model, const llm_graph_par } template <> -llm_build_t5::llm_build_t5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llama_model_t5::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); diff --git a/examples/talk-llama/models/t5encoder.cpp b/examples/talk-llama/models/t5encoder.cpp index 5c1f9eb4030..23c5f9b6a1c 100644 --- a/examples/talk-llama/models/t5encoder.cpp +++ b/examples/talk-llama/models/t5encoder.cpp @@ -1,3 +1,44 @@ #include "models.h" -llm_build_t5encoder::llm_build_t5encoder(const llama_model & model, const llm_graph_params & params) : llm_build_t5(model, params) {} +void llama_model_t5encoder::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + type = LLM_TYPE_UNKNOWN; +} + +void llama_model_t5encoder::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_t5encoder::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} diff --git a/examples/talk-llama/models/wavtokenizer-dec.cpp b/examples/talk-llama/models/wavtokenizer-dec.cpp index a7776d9cdc9..a873e5d2e8f 100644 --- a/examples/talk-llama/models/wavtokenizer-dec.cpp +++ b/examples/talk-llama/models/wavtokenizer-dec.cpp @@ -1,6 +1,121 @@ #include "models.h" -llm_build_wavtokenizer_dec::llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_wavtokenizer_dec::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); +} + +void llama_model_wavtokenizer_dec::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0); + + conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight", 0), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); + conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias", 0), {1, hparams.posnet.n_embd}, 0); + + // posnet + { + const int64_t n_embd = hparams.posnet.n_embd; + + for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) { + auto & layer = layers[i].posnet; + + // posnet: + // + // - resnet + // - resnet + // - attn + // - resnet + // - resnet + // - norm + // + switch (i) { + case 0: + case 1: + case 3: + case 4: + { + layer.norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0); + layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", i), {1, n_embd}, 0); + + layer.conv1 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0); + layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", i), {1, n_embd}, 0); + + layer.norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0); + layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", i), {1, n_embd}, 0); + + layer.conv2 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0); + layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", i), {1, n_embd}, 0); + } break; + case 2: + { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); + + layer.attn_q = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_q_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "bias", i), {1, n_embd}, 0); + + layer.attn_k = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_k_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "bias", i), {1, n_embd}, 0); + + layer.attn_v = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_v_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "bias", i), {1, n_embd}, 0); + + layer.attn_o = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_o_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "bias", i), {1, n_embd}, 0); + } break; + case 5: + { + layer.norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); + layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); + } break; + default: GGML_ABORT("unknown posnet layer"); + }; + } + } + + GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd); + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {hparams.posnet.n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {hparams.posnet.n_embd}, 0); + + // convnext + { + const int64_t n_embd = hparams.convnext.n_embd; + + for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) { + auto & layer = layers[i].convnext; + + layer.dw = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "weight", i), {7, 1, n_embd}, 0); + layer.dw_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "bias", i), {1, n_embd}, 0); + + layer.norm = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "weight", i), {n_embd}, 0); + layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "bias", i), {n_embd}, 0); + + layer.pw1 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "weight", i), {n_embd, n_ff}, 0); + layer.pw1_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "bias", i), {n_ff}, 0); + + layer.pw2 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "weight", i), {n_ff, n_embd}, 0); + layer.pw2_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "bias", i), {n_embd}, 0); + + layer.gamma = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0); + } + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + } + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, hparams.n_embd_out()}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {hparams.n_embd_out()}, 0); +} + +std::unique_ptr llama_model_wavtokenizer_dec::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_wavtokenizer_dec::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index 53085ec80f6..e4d111e622a 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -1,6 +1,43 @@ #include "models.h" -llm_build_xverse::llm_build_xverse(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_xverse::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + case 80: type = LLM_TYPE_65B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_xverse::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr llama_model_xverse::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_xverse::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); From f6f32a7f51c3e1c9fddb5ae55a7848221c667276 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 11 May 2026 14:07:30 +0200 Subject: [PATCH 586/831] try to fix window cublas CI failure Refs: https://github.com/ggml-org/whisper.cpp/actions/runs/25631391231/job/75237266964?pr=3803 --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index be3f78a3f5b..df390a9179c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -944,7 +944,7 @@ jobs: cmake --version where cmake if "${{ matrix.cuda-toolkit }}" == "11.8.0" ( - set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR + set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR -D__CUDA_NO_HALF_CONVERSIONS__ ) else ( set CUDA_FLAGS= ) From 1665885f769e1cefa429d375f0044406f0665989 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 11 May 2026 14:39:16 +0200 Subject: [PATCH 587/831] Revert "try to fix window cublas CI failure" This reverts commit a4d91768aa2ae8cf7083650b3e4dc214413f92b7. --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index df390a9179c..be3f78a3f5b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -944,7 +944,7 @@ jobs: cmake --version where cmake if "${{ matrix.cuda-toolkit }}" == "11.8.0" ( - set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR -D__CUDA_NO_HALF_CONVERSIONS__ + set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR ) else ( set CUDA_FLAGS= ) From e0bfd3ae4d50efd2959b4ae6407210bf74c921ab Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 11 May 2026 14:44:23 +0200 Subject: [PATCH 588/831] try using CCCL 12.4.127 with cuda 11.8.0 to fix CI failure --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index be3f78a3f5b..423b1b28b22 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -822,7 +822,7 @@ jobs: $NVTX_VER = "11.8.86" $VS_VER = "11.8.86" $NVPROF_VER = "11.8.87" - $CCCL_VER = "11.8.89" + $CCCL_VER = "12.4.127" # Create the directory where the CUDA Toolkit will be installed mkdir -p $CUDA_TOOLKIT_DIR From 5b2d4af850edf31dc23e750a769128c4b0feac1a Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 11 May 2026 15:17:13 +0200 Subject: [PATCH 589/831] Revert "try using CCCL 12.4.127 with cuda 11.8.0 to fix CI failure" This reverts commit be867eadf553801eb7d1c383ed47a90fdd3d4b18. Sorry about this noise, I thought it was worth a try. --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 423b1b28b22..be3f78a3f5b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -822,7 +822,7 @@ jobs: $NVTX_VER = "11.8.86" $VS_VER = "11.8.86" $NVPROF_VER = "11.8.87" - $CCCL_VER = "12.4.127" + $CCCL_VER = "11.8.89" # Create the directory where the CUDA Toolkit will be installed mkdir -p $CUDA_TOOLKIT_DIR From 633de7f99e692fe5de95edb8eb9a778f74de548d Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 12 May 2026 06:38:12 +0200 Subject: [PATCH 590/831] devops : add spirv-headers to vulkan dockerfile --- .devops/main-vulkan.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.devops/main-vulkan.Dockerfile b/.devops/main-vulkan.Dockerfile index 2be22e4d53b..077af4f1001 100644 --- a/.devops/main-vulkan.Dockerfile +++ b/.devops/main-vulkan.Dockerfile @@ -2,7 +2,7 @@ FROM ubuntu:24.04 AS build WORKDIR /app RUN apt-get update && \ - apt-get install -y build-essential wget cmake git libvulkan-dev glslc \ + apt-get install -y build-essential wget cmake git libvulkan-dev spirv-headers glslc \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* COPY .. . From b1ebddf154c38adfaff4448ac37b8564d066f559 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 12 May 2026 07:59:24 +0200 Subject: [PATCH 591/831] ggml-cuda : add explicit casts to -INFINITY for float and half2 types This commit adds explicit casts to float for -INFINITY. The motivation for this is that in CUDA 11.8.0, the -INFINITY macro is defined as a double (a header provided NVCC). This triggers a warning and hence causes a CI failure in whisper.cpp. I belive that this header might have been updated in CUDA 12 which is why we don't see this warning. Refs: https://github.com/ggml-org/whisper.cpp/actions/runs/25713948217/job/75500081939?pr=3803 Refs: https://github.com/ggml-org/llama.cpp/issues/22824 --- ggml/src/ggml-cuda/common.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 10817505d9f..246a76193ca 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -582,9 +582,9 @@ template struct block_reduce_policy { static __device__ T sentinel() { if constexpr (std::is_same_v) { - return -INFINITY; + return -(float)INFINITY; } else if constexpr (std::is_same_v) { - return make_half2(-INFINITY, -INFINITY); + return make_half2(__float2half(-(float)INFINITY), __float2half(-(float)INFINITY)); } else { static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce max"); } From b6a4b32a88b743bc42d5f849e435384b14ddeab8 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 12 May 2026 08:30:00 +0200 Subject: [PATCH 592/831] ggml-cuda : add ar_add() to avoid ambiguous operator+ for half/bfloat16 in CUDA 11.8 --- ggml/src/ggml-cuda/allreduce.cu | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/allreduce.cu b/ggml/src/ggml-cuda/allreduce.cu index 434689abd95..03d88968cd5 100644 --- a/ggml/src/ggml-cuda/allreduce.cu +++ b/ggml/src/ggml-cuda/allreduce.cu @@ -105,6 +105,20 @@ static constexpr int GGML_CUDA_AR_KERNEL_BLOCKS = 8; // blocks. Tail elements (the leftover < ELEMS_PER_VEC at the end) are // handled only by block 0 to avoid cross-block writes to the same slots. // --------------------------------------------------------------------------- + +// CUDA 11.8 does not expose operator+ for half/bfloat16 below sm_530, +// so use the explicit intrinsics to avoid ambiguous implicit conversions. +template +static __device__ inline T ar_add(T a, T b) { + if constexpr (std::is_same_v) { + return __hadd(a, b); + } else if constexpr (std::is_same_v) { + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b)); + } else { + return a + b; + } +} + template static __global__ void ggml_cuda_ar_kernel( const T_dst * sendbuf, @@ -184,13 +198,13 @@ static __global__ void ggml_cuda_ar_kernel( #pragma unroll for (int k = 0; k < ELEMS_PER_VEC; ++k) { const T_wire d_low = ggml_cuda_cast(sendbuf[off + k]); - recvbuf[off + k] = ggml_cuda_cast(d_low) + ggml_cuda_cast(wire[k]); + recvbuf[off + k] = ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(wire[k])); } } if (bid == 0 && tid < count - tail) { const T_wire d_low = ggml_cuda_cast(sendbuf[tail + tid]); recvbuf[tail + tid] = - ggml_cuda_cast(d_low) + ggml_cuda_cast(host_other[tail + tid]); + ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(host_other[tail + tid])); } } } @@ -210,7 +224,7 @@ static __global__ void ggml_cuda_ar_add_kernel( const int nt = gridDim.x * blockDim.x; for (int i = tid; i < count; i += nt) { const T_src d_low = ggml_cuda_cast(dst[i]); - dst[i] = ggml_cuda_cast(d_low) + ggml_cuda_cast(src[i]); + dst[i] = ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(src[i])); } } From d04a1faaec814772ca29f801ad7a10f4c330e16f Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 12 May 2026 08:36:14 +0200 Subject: [PATCH 593/831] ci : update ONEAPI version to 2025.3.3-0-devel-ubuntu24.04 --- .devops/main-intel.Dockerfile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.devops/main-intel.Dockerfile b/.devops/main-intel.Dockerfile index 1b5859715d4..dbb60682dce 100644 --- a/.devops/main-intel.Dockerfile +++ b/.devops/main-intel.Dockerfile @@ -1,6 +1,6 @@ -ARG ONEAPI_VERSION=2025.1.1-0-devel-ubuntu24.04 +ARG ONEAPI_VERSION=2025.3.3-0-devel-ubuntu24.04 -FROM intel/oneapi-basekit:$ONEAPI_VERSION AS build +FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS build WORKDIR /app RUN apt-get update && \ @@ -16,7 +16,7 @@ RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ fi && \ make base.en CMAKE_ARGS="-DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ${OPT_SYCL_F16}" -FROM intel/oneapi-basekit:$ONEAPI_VERSION AS runtime +FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS build WORKDIR /app RUN apt-get update && \ From ea29be532eac424323fe690b1b876b3324ced417 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 12 May 2026 11:15:56 +0200 Subject: [PATCH 594/831] squash! ci : update ONEAPI version to 2025.3.3-0-devel-ubuntu24.04 --- .devops/main-intel.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.devops/main-intel.Dockerfile b/.devops/main-intel.Dockerfile index dbb60682dce..86b901c1538 100644 --- a/.devops/main-intel.Dockerfile +++ b/.devops/main-intel.Dockerfile @@ -16,7 +16,7 @@ RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ fi && \ make base.en CMAKE_ARGS="-DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ${OPT_SYCL_F16}" -FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS build +FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS runtime WORKDIR /app RUN apt-get update && \ From db7bcdb791162f6a256babc73e6cbceebe7c5d8d Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 14 May 2026 05:27:13 +0200 Subject: [PATCH 595/831] Revert "ggml-cuda : add ar_add() to avoid ambiguous operator+ for half/bfloat16 in CUDA 11.8" This reverts commit 5cd228494af3973294e90aad95b58c2ede400f43. Reverting in favor of: https://github.com/ggml-org/llama.cpp/pull/22994 --- ggml/src/ggml-cuda/allreduce.cu | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-cuda/allreduce.cu b/ggml/src/ggml-cuda/allreduce.cu index 03d88968cd5..434689abd95 100644 --- a/ggml/src/ggml-cuda/allreduce.cu +++ b/ggml/src/ggml-cuda/allreduce.cu @@ -105,20 +105,6 @@ static constexpr int GGML_CUDA_AR_KERNEL_BLOCKS = 8; // blocks. Tail elements (the leftover < ELEMS_PER_VEC at the end) are // handled only by block 0 to avoid cross-block writes to the same slots. // --------------------------------------------------------------------------- - -// CUDA 11.8 does not expose operator+ for half/bfloat16 below sm_530, -// so use the explicit intrinsics to avoid ambiguous implicit conversions. -template -static __device__ inline T ar_add(T a, T b) { - if constexpr (std::is_same_v) { - return __hadd(a, b); - } else if constexpr (std::is_same_v) { - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b)); - } else { - return a + b; - } -} - template static __global__ void ggml_cuda_ar_kernel( const T_dst * sendbuf, @@ -198,13 +184,13 @@ static __global__ void ggml_cuda_ar_kernel( #pragma unroll for (int k = 0; k < ELEMS_PER_VEC; ++k) { const T_wire d_low = ggml_cuda_cast(sendbuf[off + k]); - recvbuf[off + k] = ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(wire[k])); + recvbuf[off + k] = ggml_cuda_cast(d_low) + ggml_cuda_cast(wire[k]); } } if (bid == 0 && tid < count - tail) { const T_wire d_low = ggml_cuda_cast(sendbuf[tail + tid]); recvbuf[tail + tid] = - ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(host_other[tail + tid])); + ggml_cuda_cast(d_low) + ggml_cuda_cast(host_other[tail + tid]); } } } @@ -224,7 +210,7 @@ static __global__ void ggml_cuda_ar_add_kernel( const int nt = gridDim.x * blockDim.x; for (int i = tid; i < count; i += nt) { const T_src d_low = ggml_cuda_cast(dst[i]); - dst[i] = ar_add(ggml_cuda_cast(d_low), ggml_cuda_cast(src[i])); + dst[i] = ggml_cuda_cast(d_low) + ggml_cuda_cast(src[i]); } } From 5a24c7538fcf5ccc04770b03fe569a98cf1b0f5d Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 14 May 2026 05:28:56 +0200 Subject: [PATCH 596/831] Revert "ggml-cuda : add explicit casts to -INFINITY for float and half2 types" This reverts commit a2839b4404de473bc7af127b7b308d530afda024. Reverting this as after closer inspection these only warnings and not errors. --- ggml/src/ggml-cuda/common.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 246a76193ca..10817505d9f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -582,9 +582,9 @@ template struct block_reduce_policy { static __device__ T sentinel() { if constexpr (std::is_same_v) { - return -(float)INFINITY; + return -INFINITY; } else if constexpr (std::is_same_v) { - return make_half2(__float2half(-(float)INFINITY), __float2half(-(float)INFINITY)); + return make_half2(-INFINITY, -INFINITY); } else { static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce max"); } From dd706793ccdca94b01c5e3a39b000bbccc552502 Mon Sep 17 00:00:00 2001 From: Steve Lhomme Date: Sun, 10 May 2026 16:35:38 +0200 Subject: [PATCH 597/831] ggml: install ggml.pc in /pkgconfig (ggml/1480) That's always how it's done: https://github.com/search?q=path%3ACMakeLists.txt%20%22%24%7BCMAKE_INSTALL_LIBDIR%7D%2Fpkgconfig%22&type=code --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 672b37dffc3..4e65cd68b4e 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -352,7 +352,7 @@ if (GGML_STANDALONE) @ONLY) install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc - DESTINATION share/pkgconfig) + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) endif() # From 5f08683bb615fac03383d171737281f65050fe82 Mon Sep 17 00:00:00 2001 From: CrispStrobe <154636388+CrispStrobe@users.noreply.github.com> Date: Sun, 10 May 2026 16:45:00 +0200 Subject: [PATCH 598/831] metal : tighten input-position loop in kernel_conv_transpose_1d (ggml/1477) For a given output position j on the time axis, only input positions i such that i*s0 <= j < i*s0 + K contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)] intersected with [0, IL-1]. That's at most ceil(K/s0) values (typically 2 for stride==K/2 transposed convs). The current kernel iterates the full IL range and filters with an `if`, amplifying per-thread work by IL/ceil(K/s0) (~160x for IL=320, K=10, s0=5 -- a representative codec-decoder shape). On Apple M1 the wasted work trips the macOS GPU watchdog (kIOGPUCommandBufferCallbackErrorImpactingInteractivity) on long graphs. Compute i_min, i_max analytically before the inner loop and iterate only [i_min, i_max]. Output is bit-identical (same multiplies and adds in the same order); loop bound shrinks by IL/ceil(K/s0). Tested on M1 with a downstream consumer running a TTS codec at full T_codec; end-to-end codec decode ~3-4x faster, zero watchdog hits across long synthesis runs vs ~30% pre-patch. --- ggml/src/ggml-metal/ggml-metal.metal | 31 +++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index c372eaedeae..5c2ec8a4ab8 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4850,15 +4850,32 @@ kernel void kernel_conv_transpose_1d( uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]) { - float v = 0.0f; + // For output position j on the time axis, only input positions + // i such that i*s0 <= j < i*s0 + K + // contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)] + // intersected with [0, IL-1]. That's at most ceil(K/s0) values + // (typically 2 for stride==K/2 transposed convs). + const int32_t j = tgpig[0]; + const int32_t s0 = args.s0; + const int32_t K = args.K; + const int32_t IL = args.IL; + + int32_t i_min; + { + int32_t a = j - K + 1; + i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0 + } + int32_t i_max = j / s0; + if (i_max > IL - 1) i_max = IL - 1; - for (int64_t c = 0; c < args.IC; c++) { - const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1]; - const int32_t input_offset = c * args.IL; + float v = 0.0f; + if (i_min <= i_max) { + for (int64_t c = 0; c < args.IC; c++) { + const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1]; + const int32_t input_offset = c * IL; - for (int64_t i = 0; i < args.IL; i++) { - if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) { - v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i]; + for (int32_t i = i_min; i <= i_max; i++) { + v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i]; } } } From 73f63f529539a740a3a81f87ba4984abfa7daf3d Mon Sep 17 00:00:00 2001 From: Oliver Walsh Date: Sun, 10 May 2026 16:32:41 +0100 Subject: [PATCH 599/831] ggml-virtgpu : include missing mutex header (llama/22810) Add missing `#include ` in ggml-backend-device.cpp. Fixes: #22809 Signed-off-by: Oliver Walsh --- ggml/src/ggml-virtgpu/ggml-backend-device.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp index ec8156bb868..a978812cd90 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp @@ -1,5 +1,7 @@ #include "ggml-remoting.h" +#include + static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) { virtgpu * gpu = DEV_TO_GPU(dev); From 4db2f450754b1aa41e45c3b1328555f48f2ba6fb Mon Sep 17 00:00:00 2001 From: Neo Zhang Date: Mon, 11 May 2026 13:01:47 +0800 Subject: [PATCH 600/831] Add OP im2col_3d (llama/22903) * add im2col_3d * format code * update the ops.md --- ggml/src/ggml-sycl/ggml-sycl.cpp | 9 + ggml/src/ggml-sycl/im2col.cpp | 442 ++++++++++++++++++++++++------- ggml/src/ggml-sycl/im2col.hpp | 8 +- 3 files changed, 367 insertions(+), 92 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index e7768b8bf61..57cc4ffb6f7 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4159,6 +4159,11 @@ static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) ggml_sycl_op_im2col(ctx, dst); } +static void ggml_sycl_im2col_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_im2col_3d(ctx, dst); +} + static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); GGML_ASSERT(ggml_is_contiguous(dst->src[0])); @@ -4456,6 +4461,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_IM2COL: ggml_sycl_im2col(ctx, dst); break; + case GGML_OP_IM2COL_3D: + ggml_sycl_im2col_3d(ctx, dst); + break; case GGML_OP_POOL_2D: ggml_sycl_pool2d(ctx, dst); break; @@ -5175,6 +5183,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: case GGML_OP_UPSCALE: return true; case GGML_OP_SUM: diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp index 6d75d34d83f..7bf3584fb97 100644 --- a/ggml/src/ggml-sycl/im2col.cpp +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -1,6 +1,6 @@ // // MIT license -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2026 Intel Corporation // SPDX-License-Identifier: MIT // @@ -12,125 +12,389 @@ #include "im2col.hpp" -#include -#include // For std::is_same_v - -#include "ggml.h" +#define MAX_GRIDDIM_Z 65535 template -static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_t offset_delta, int64_t IC, int64_t IW, - int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW, - int s0, int s1, int p0, int p1, int d0, int d1, const sycl::nd_item<3> & item_ct1) { - const int64_t work_group_size = item_ct1.get_local_range(2); - const int64_t global_id = item_ct1.get_local_id(2) + (work_group_size * item_ct1.get_group(2)); - - // make each work-item deal with more elements since sycl global range can not exceed max int - for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) { - const int64_t ksize = OW * KH; - const int64_t kx = i / ksize; - const int64_t kd = kx * ksize; - const int64_t ky = (i - kd) / OW; - const int64_t ix = i % OW; - - const int64_t oh = item_ct1.get_group(1); - const int64_t batch = item_ct1.get_group(0) / IC; - const int64_t ic = item_ct1.get_group(0) % IC; - - const int64_t iiw = (ix * s0) + (kx * d0) - p0; - const int64_t iih = (oh * s1) + (ky * d1) - p1; - - const int64_t offset_dst = (((batch * OH + oh) * OW + ix) * CHW) + (ic * (KW * KH) + ky * KW + kx); - - const int64_t offset_src_base = (ic * offset_delta) + (batch * batch_offset); - const int64_t offset_src = offset_src_base + (iih * IW) + iiw; - - const bool out_of_bounds = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW); - const float src_val = out_of_bounds ? 0.0f : x[offset_src]; - - if constexpr (std::is_same_v) { - dst[offset_dst] = sycl::half(src_val); - } else if constexpr (std::is_same_v) { - dst[offset_dst] = src_val; - } +static void im2col_kernel( + const float * x, T * dst, + int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, + int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW, + int s0, int s1, int p0, int p1, int d0, int d1) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (i >= IC_KH_KW) { + return; } -} -template -static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, - int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, - int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { - const int64_t parallel_elements = OW * KW * KH; - const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + const int64_t iic = i / (KH_KW); + const int64_t rem = i - iic * KH_KW; + const int64_t ikh = rem / KW; + const int64_t ikw = rem - ikh * KW; - // decrease global range when it exceeds the max int - int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE); + const int64_t iow = item_ct1.get_group(1); + for (int64_t iz = item_ct1.get_group(0); iz < N_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OH; + const int64_t ioh = iz - in * OH; - sycl::range<3> block_nums(batch * IC, OH, num_blocks); - sycl::range<3> local_range(1, 1, local_size); + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t CHW = IC * KH * KW; + const int64_t offset_dst = + ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; - stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) { - im2col_kernel(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1, - p0, p1, d0, d1, item_ct1); - }); + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; + dst[offset_dst] = x[offset_src + iih * IW + iiw]; + } + } + + GGML_UNUSED(IC); + GGML_UNUSED(KH); } -static void im2col_sycl_f16(const float * x, sycl::half * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, - int64_t KW, int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, - int64_t offset_delta, int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { - if (!stream->get_device().has(sycl::aspect::fp16)) { - throw sycl::exception(sycl::make_error_code(sycl::errc::kernel_not_supported), - "Device does not support half precision (fp16) operations!"); - } - im2col_sycl_internal(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, - p1, d0, d1, stream); +// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] +template +static void im2col_sycl(const float * x, + T * dst, + int64_t IW, + int64_t IH, + int64_t OW, + int64_t OH, + int64_t KW, + int64_t KH, + int64_t IC, + int64_t N, + int64_t IC_IH_IW, + int64_t IH_IW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + dpct::queue_ptr stream) { + const int64_t IC_KH_KW = IC * KH * KW; + const int64_t num_blocks = (IC_KH_KW + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + const int64_t N_OH = N * OH; + const int64_t KH_KW = KW*KH; + dpct::dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z)); + /* + DPCT1049:73: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->parallel_for(sycl::nd_range<3>(block_nums * sycl::range<3>(1, 1, MIN(IC_KH_KW, SYCL_IM2COL_BLOCK_SIZE)), + sycl::range<3>(1, 1, MIN(IC_KH_KW, SYCL_IM2COL_BLOCK_SIZE))), + [=](sycl::nd_item<3> item_ct1) { + im2col_kernel(x, dst, IC, IW, IH, OH, OW, KW, KH, IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW, + s0, s1, p0, p1, d0, d1); + }); +} + +static void im2col_sycl_f16(const float * x, + sycl::half * dst, + int64_t IW, + int64_t IH, + int64_t OW, + int64_t OH, + int64_t KW, + int64_t KH, + int64_t IC, + int64_t N, + int64_t IC_IH_IW, + int64_t IH_IW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + dpct::queue_ptr stream) { + im2col_sycl(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); } -static void im2col_sycl_f32(const float * x, float * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, - int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, int s0, - int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { - im2col_sycl_internal(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, - d0, d1, stream); +static void im2col_sycl_f32(const float * x, + float * dst, + int64_t IW, + int64_t IH, + int64_t OW, + int64_t OH, + int64_t KW, + int64_t KH, + int64_t IC, + int64_t N, + int64_t IC_IH_IW, + int64_t IH_IW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + dpct::queue_ptr stream) { + im2col_sycl(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); } void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + dpct::queue_ptr stream = ctx.stream(); GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - const int32_t s0 = ((const int32_t *) (dst->op_params))[0]; - const int32_t s1 = ((const int32_t *) (dst->op_params))[1]; - const int32_t p0 = ((const int32_t *) (dst->op_params))[2]; - const int32_t p1 = ((const int32_t *) (dst->op_params))[3]; - const int32_t d0 = ((const int32_t *) (dst->op_params))[4]; - const int32_t d1 = ((const int32_t *) (dst->op_params))[5]; + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; - const bool is_2D = ((const int32_t *) (dst->op_params))[6] == 1; + const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; const int64_t IC = src1->ne[is_2D ? 2 : 1]; const int64_t IH = is_2D ? src1->ne[1] : 1; - const int64_t IW = src1->ne[0]; + const int64_t IW = src1->ne[0]; const int64_t KH = is_2D ? src0->ne[1] : 1; - const int64_t KW = src0->ne[0]; + const int64_t KW = src0->ne[0]; const int64_t OH = is_2D ? dst->ne[2] : 1; - const int64_t OW = dst->ne[1]; + const int64_t OW = dst->ne[1]; + + const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const int64_t N = src1->ne[is_2D ? 3 : 2]; + const int64_t IH_IW = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 + + if(dst->type == GGML_TYPE_F16) { + im2col_sycl_f16(src1_d, (sycl::half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, + d0, d1, stream); + } else { + im2col_sycl_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); + } +} + +// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] +template +static void im2col_3d_kernel( + const float * src, T * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW, + int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW, + int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (i >= IC_KD_KH_KW) { + return; + } + GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH); + GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW); + + const int64_t iic = i / KD_KH_KW; + const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW; + const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; + const int64_t ikw = i % KW; + + const int64_t iow = item_ct1.get_group(1); + for (int64_t iz = item_ct1.get_group(0); iz < N_OD_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OD_OH; + const int64_t iod = (iz - in*OD_OH) / OH; + const int64_t ioh = iz % OH; + + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iid = iod * s2 + ikd * d2 - p2; + + const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); + dst[offset_dst] = src[offset_src]; + } + } +} + +// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] +template +static void im2col_3d_sycl(const float * src, + T * dst, + int64_t N, + int64_t IC, + int64_t ID, + int64_t IH, + int64_t IW, + int64_t OC, + int64_t KD, + int64_t KH, + int64_t KW, + int64_t OD, + int64_t OH, + int64_t OW, + int64_t stride_q, + int64_t stride_z, + int64_t stride_y, + int64_t stride_x, + int s0, + int s1, + int s2, + int p0, + int p1, + int p2, + int d0, + int d1, + int d2, + dpct::queue_ptr stream) { + const int64_t OH_OW = OH*OW; + const int64_t KD_KH_KW = KD*KH*KW; + const int64_t ID_IH_IW = ID*IH*IW; + const int64_t KH_KW = KH*KW; + const int64_t IH_IW = IH*IW; + const int64_t IC_KD_KH_KW = IC*KD*KH*KW; + const int64_t OW_KD_KH_KW = OW*KD*KH*KW; + const int64_t N_OD_OH = N*OD*OH; + const int64_t OD_OH = OD*OH; + const int64_t IC_ID_IH_IW = IC*ID*IH*IW; + const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; + const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + const int64_t num_blocks = (IC_KD_KH_KW + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + dpct::dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z)); + /* + DPCT1049:74: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->parallel_for(sycl::nd_range<3>(block_nums * sycl::range<3>(1, 1, MIN(IC_KD_KH_KW, SYCL_IM2COL_BLOCK_SIZE)), + sycl::range<3>(1, 1, MIN(IC_KD_KH_KW, SYCL_IM2COL_BLOCK_SIZE))), + [=](sycl::nd_item<3> item_ct1) { + im2col_3d_kernel(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, OH_OW, KD_KH_KW, + ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, IC_KD_KH_KW, OW_KD_KH_KW, + OD_OH_OW_IC_KD_KH_KW, OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH, + stride_q, stride_z, stride_y, stride_x, s0, s1, s2, p0, p1, p2, d0, d1, + d2); + }); +} + +static void im2col_3d_sycl_f16(const float * src, + sycl::half * dst, + int64_t N, + int64_t IC, + int64_t ID, + int64_t IH, + int64_t IW, + int64_t OC, + int64_t KD, + int64_t KH, + int64_t KW, + int64_t OD, + int64_t OH, + int64_t OW, + int64_t stride_q, + int64_t stride_z, + int64_t stride_y, + int64_t stride_x, + int s0, + int s1, + int s2, + int p0, + int p1, + int p2, + int d0, + int d1, + int d2, + dpct::queue_ptr stream) { + im2col_3d_sycl(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, stride_q, stride_z, stride_y, + stride_x, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); +} + +static void im2col_3d_sycl_f32(const float * src, + float * dst, + int64_t N, + int64_t IC, + int64_t ID, + int64_t IH, + int64_t IW, + int64_t OC, + int64_t KD, + int64_t KH, + int64_t KW, + int64_t OD, + int64_t OH, + int64_t OW, + int64_t stride_q, + int64_t stride_z, + int64_t stride_y, + int64_t stride_x, + int s0, + int s1, + int s2, + int p0, + int p1, + int p2, + int d0, + int d1, + int d2, + dpct::queue_ptr stream) { + im2col_3d_sycl(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); +} + +void ggml_sycl_op_im2col_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t OC = ne03 / IC; + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; - const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / sizeof(float); - const int64_t batch = src1->ne[is_2D ? 3 : 2]; - const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / sizeof(float); + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; - queue_ptr stream = ctx.stream(); + const size_t es = ggml_element_size(src1); + const int64_t stride_x = src1->nb[0] / es; + const int64_t stride_y = src1->nb[1] / es; + const int64_t stride_z = src1->nb[2] / es; + const int64_t stride_q = src1->nb[3] / es; - if (dst->type == GGML_TYPE_F16) { - im2col_sycl_f16((const float *) src1->data, (sycl::half *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch, - batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); + if(dst->type == GGML_TYPE_F16) { + im2col_3d_sycl_f16(src1_d, (sycl::half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); } else { - im2col_sycl_f32((const float *) src1->data, (float *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch, - batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); + im2col_3d_sycl_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); } } diff --git a/ggml/src/ggml-sycl/im2col.hpp b/ggml/src/ggml-sycl/im2col.hpp index dbbb248ddb4..976d1094636 100644 --- a/ggml/src/ggml-sycl/im2col.hpp +++ b/ggml/src/ggml-sycl/im2col.hpp @@ -1,6 +1,6 @@ // // MIT license -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2026 Intel Corporation // SPDX-License-Identifier: MIT // @@ -15,7 +15,9 @@ #include "common.hpp" -void ggml_sycl_op_im2col( - ggml_backend_sycl_context & ctx, ggml_tensor *dst); +#define SYCL_IM2COL_BLOCK_SIZE 256 + +void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_op_im2col_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst); #endif // GGML_SYCL_IM2COL_HPP From 0077a6d3320dcf3d72983a0ce4f0ba35e21d051e Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Mon, 11 May 2026 12:16:38 +0200 Subject: [PATCH 601/831] CUDA: directly include cuda/iterator (llama/22936) Before, we relied on a transient import from `cub/cub.cuh`, which is bad practice to do as cub may not always expose cuda/iterator --- ggml/src/ggml-cuda/argsort.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 0f3f017b534..c4f08091e79 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -4,6 +4,7 @@ # include # if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1) # define STRIDED_ITERATOR_AVAILABLE +# include # endif using namespace cub; #endif // GGML_CUDA_USE_CUB From c0c1f994b711114e907ac3250605b7542d7d19ec Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Mon, 11 May 2026 05:49:03 -0500 Subject: [PATCH 602/831] vulkan: Support asymmetric FA in scalar/mmq/coopmat1 paths (llama/22589) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 253 +++++++-------- .../vulkan-shaders/flash_attn.comp | 176 +++++----- .../vulkan-shaders/flash_attn_base.glsl | 210 +++--------- .../vulkan-shaders/flash_attn_cm1.comp | 160 +++++----- .../vulkan-shaders/flash_attn_cm2.comp | 43 +-- .../vulkan-shaders/flash_attn_dequant.glsl | 123 +++++++ .../vulkan-shaders/flash_attn_mmq_funcs.glsl | 300 +++++++++++------- .../vulkan-shaders/mul_mmq_shmem_types.glsl | 11 +- .../vulkan-shaders/vulkan-shaders-gen.cpp | 36 +-- 9 files changed, 632 insertions(+), 680 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0a7931002ab..7e450a559dd 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -855,7 +855,7 @@ struct vk_device_struct { vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; - std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT]; + std::map pipeline_flash_attn_f32_f16; std::map, vk_pipeline> pipeline_fa_mask_opt; @@ -2933,10 +2933,10 @@ struct vk_fa_tuning_params { } }; -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type); +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type); static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); -static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { vk_fa_tuning_params result{}; result.path = FA_SCALAR; @@ -2988,7 +2988,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; - if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, kv_type)) { + if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, k_type, v_type)) { result.block_rows /= 2; } @@ -3011,10 +3011,11 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, return result; } -static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { +static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { GGML_UNUSED(n_rows); GGML_UNUSED(n_kv); - GGML_UNUSED(kv_type); + GGML_UNUSED(k_type); + GGML_UNUSED(v_type); GGML_UNUSED(f32acc); vk_fa_tuning_params result{}; @@ -3070,12 +3071,6 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device } static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { - // Mixed K/V is only implemented on the coopmat2 (flash_attn_cm2) path; never use scalar/cm1. - if (k_type != v_type) { - GGML_ASSERT(device->coopmat2); - return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); - } - FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; @@ -3087,7 +3082,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ if (path == FA_COOPMAT1) { bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || (!f32acc && device->coopmat_support_16x16x16_f16acc); - const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); + const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); if (!shape_ok || !shmem_ok) { @@ -3107,9 +3102,9 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ switch (path) { case FA_SCALAR: - return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); + return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); case FA_COOPMAT1: - return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, f32acc); + return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); case FA_COOPMAT2: return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); default: @@ -3279,6 +3274,20 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev return 0; // If no matching configuration is found } +// Whether scalar flash attention will use the MMQ path for the given k_type. +static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type) { +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + return device->integer_dot_product && device->subgroup_clustered && + (k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q4_1 || + k_type == GGML_TYPE_Q5_0 || k_type == GGML_TYPE_Q5_1 || + k_type == GGML_TYPE_Q8_0); +#else + GGML_UNUSED(device); + GGML_UNUSED(k_type); + return false; +#endif +} + static void ggml_vk_load_shaders(vk_device& device) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); @@ -3525,121 +3534,96 @@ static void ggml_vk_load_shaders(vk_device& device) { align, disable_robustness, require_full_subgroups, required_subgroup_size); }; -#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ - for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ - FaCodePath path = fa.first.path; \ - uint32_t Br = fa.first.Br; \ - uint32_t Bc = fa.first.Bc; \ - bool aligned = fa.first.aligned; \ - bool f32acc = fa.first.f32acc; \ - uint32_t fa_sgs = fa.first.subgroup_size; \ - bool fa_ds = fa.first.subgroup_size == 0; \ - if (path == FAPATH) { \ - if (aligned) { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } \ - } else { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ - } \ - } \ - } \ - } - - if (device->fp16) { - CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) - -#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (device->integer_dot_product && device->subgroup_clustered) { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _int8) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _int8) - } else -#endif - { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, ) - } - } else { - CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32) - + // FA scalar has two SPIR-V modules (MMQ vs non-MMQ); FA cm1 has one. K/V + // quant type is selected at runtime via the FaTypeK / FaTypeV spec constants. + + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_SCALAR) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + const uint32_t fa_sgs = fa.first.subgroup_size; + const bool fa_ds = fa.first.subgroup_size == 0; + + const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type); + const void * spv_data = nullptr; + size_t spv_size = 0; + if (use_mmq) { #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (device->integer_dot_product && device->subgroup_clustered) { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32_int8) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32_int8) - } else + if (device->fp16) { + if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; } + else { spv_data = flash_attn_f32_f16_f16acc_int8_data; spv_size = flash_attn_f32_f16_f16acc_int8_len; } + } else { + spv_data = flash_attn_f32_f16_fp32_int8_data; + spv_size = flash_attn_f32_f16_fp32_int8_len; + } #endif - { - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_SCALAR, _fp32) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_SCALAR, _fp32) + } else { + if (device->fp16) { + if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; } + else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; } + } else { + spv_data = flash_attn_f32_f16_fp32_data; + spv_size = flash_attn_f32_f16_fp32_len; + } } + const char *name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, + !fa_ds, !fa_ds ? fa_sgs : 0); } + #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { - CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT1, _cm1) + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_COOPMAT1) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + const uint32_t fa_sgs = fa.first.subgroup_size; + const bool fa_ds = fa.first.subgroup_size == 0; + + const void * spv_data; + size_t spv_size; + if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; } + else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; } + const char *name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1"; + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, + !fa_ds, !fa_ds ? fa_sgs : 0); + } } #endif + #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) -#define CREATE_FA_CM2_MIXED() \ - for (int fa_k_ty = 0; fa_k_ty < (int)GGML_TYPE_COUNT; ++fa_k_ty) { \ - for (auto &fa : device->pipeline_flash_attn_f32_f16[fa_k_ty]) { \ - FaCodePath path = fa.first.path; \ - uint32_t Br = fa.first.Br; \ - uint32_t Bc = fa.first.Bc; \ - bool aligned = fa.first.aligned; \ - bool f32acc = fa.first.f32acc; \ - if (path == FA_COOPMAT2) { \ - if (aligned) { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_aligned_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, false, 0); \ - } \ - } else { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f32acc_cm2", flash_attn_f32_f16_mixed_cm2_len, flash_attn_f32_f16_mixed_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_mixed_f16acc_cm2", flash_attn_f32_f16_mixed_f16acc_cm2_len, flash_attn_f32_f16_mixed_f16acc_cm2_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, false, 0); \ - } \ - } \ - } \ - } \ - } if (device->coopmat2) { - CREATE_FA_CM2_MIXED(); + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_COOPMAT2) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + + const void * spv_data; + size_t spv_size; + const char * name; + if (aligned) { + if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; } + else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; } + } else { + if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_f32acc_cm2"; } + else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_f16acc_cm2"; } + } + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, false, 0); + } } -#undef CREATE_FA_CM2_MIXED #endif -#undef CREATE_FA const int mul_mat_id_param_count = 5; @@ -8940,8 +8924,9 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type) { GGML_UNUSED(f32acc); + GGML_UNUSED(v_type); // Needs to be kept up to date on shader changes const uint32_t wg_size = params.workgroup_size; const uint32_t Br = params.block_rows; @@ -8949,10 +8934,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); - const bool mmq = device->integer_dot_product && device->subgroup_clustered && - (kv_type == GGML_TYPE_Q4_0 || kv_type == GGML_TYPE_Q4_1 || - kv_type == GGML_TYPE_Q5_0 || kv_type == GGML_TYPE_Q5_1 || - kv_type == GGML_TYPE_Q8_0 || kv_type == GGML_TYPE_IQ4_NL); + const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type); // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); @@ -8969,17 +8951,10 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con // kvsh uses D = HSV (K goes through kblocksh instead) kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size; - // block_a_cache size depends on quant type - uint32_t block_a_size; - switch (kv_type) { - case GGML_TYPE_Q4_0: block_a_size = 4 * sizeof(uint32_t) + float_type_size; break; - case GGML_TYPE_Q4_1: block_a_size = 4 * sizeof(uint32_t) + 2 * float_type_size; break; - case GGML_TYPE_Q5_0: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + float_type_size; break; - case GGML_TYPE_Q5_1: block_a_size = 4 * sizeof(uint32_t) + sizeof(uint32_t) + 2 * float_type_size; break; - case GGML_TYPE_Q8_0: - case GGML_TYPE_IQ4_NL: block_a_size = 8 * sizeof(int32_t) + float_type_size; break; - default: block_a_size = 0; break; - } + // The mixed MMQ shader uses a superset block_a_cache that fits every + // FA-supported quant: int32_t qs[8] + uint32_t qh + FLOAT_TYPEV2 dm. + // Single-scale types leave dm.y unused; non-Q5_* leave qh unused. + const uint32_t block_a_size = 8 * sizeof(int32_t) + sizeof(uint32_t) + 2 * float_type_size; kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size; } else { Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; @@ -9117,10 +9092,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc); - if (tuning_params.path != FA_COOPMAT2) { - GGML_ASSERT(k->type == v->type); - } - const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); @@ -9164,7 +9135,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx { std::lock_guard guard(ctx->device->mutex); - auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type]; + auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16; auto it = pipelines.find(fa_pipeline_state); if (it != pipelines.end()) { pipeline = it->second; @@ -15642,10 +15613,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { return false; } - // mismatching K/V type is currently supported for coopmat2 only. - if (op->src[1]->type != op->src[2]->type && !coopmat2) { - return false; - } auto fa_kv_ok = [coopmat2](ggml_type t) { switch (t) { case GGML_TYPE_F32: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 6e6bdabc92e..6ac095489b3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -22,6 +22,7 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#include "flash_attn_dequant.glsl" const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; @@ -128,18 +129,20 @@ void main() { Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals)); -#if defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) - if (buf_iqs == 0) { - Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0); - } -#else // Q4_0, Q4_1, Q5_0, Q5_1 - const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w; - const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8); + // Q8_0 K only needs (qd, _); the asymmetric Q4_*/Q5_* family also stores + // the row-sum scaled by qd, used in k_dot_correction. + if (FaTypeK == FA_TYPE_Q8_0) { + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0); + } + } else { + const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w; + const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8); - if (buf_iqs == 0) { - Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd); + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd); + } } -#endif #endif } barrier(); @@ -177,13 +180,9 @@ void main() { // mo_offset will point to the tile starting at row i*Br and col 0 uint32_t mo_offset = mo_stride * i; -#if BLOCK_SIZE > 1 - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; -#else - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; -#endif + // FaBlockBytesK/V == 2 for f16, 16 for f32, ggml block byte size for quants. + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV; uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; @@ -257,21 +256,21 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) { FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); if (!KV_bounds_check || j * Bc + c < KV) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else - K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); -#endif + if (USE_DECODE_K) { + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = K_Tf; } } #else // MMQ - const uint ints_per_block = 8 / QUANT_R_MMQ; + const uint ints_per_block = 8u / fa_quant_r_mmq(FaTypeK); const uint quant_iters = Bc * HSK / 32 * ints_per_block; [[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) { const uint32_t iqs = (idx + tid) % ints_per_block; @@ -310,15 +309,13 @@ void main() { FLOAT_TYPEV4 K_Tf; if (SHMEM_STAGING != 0) { K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; - } else { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); + } else if (USE_DECODE_K) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else + } else { K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); -#endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf)); @@ -335,15 +332,13 @@ void main() { FLOAT_TYPEV4 K_Tf; if (SHMEM_STAGING != 0) { K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; - } else { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); + } else if (USE_DECODE_K) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else + } else { K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); -#endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf)); @@ -366,72 +361,47 @@ void main() { int32_t k_quants[d_per_step]; ACC_TYPEV2 k_dm; + // Q4_*/Q5_* take the block-8 fast path when one step covers a full + // block; Q8_0 always goes through the per-int get_k_qs* helpers + // (its qs is byte-packed, not nibble-packed). + const bool block8_fast = (d_per_step == 8) && (FaTypeK != FA_TYPE_Q8_0); + if (SHMEM_STAGING != 0) { const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8; const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx; -#if QUANT_AUXF == 1 - k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm, 0.0); -#else k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm); -#endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - if (d_per_step == 8) { + if (block8_fast) { + const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1); [[unroll]] for (uint32_t d = 0; d < 4; d++) { uint vui = kblocksh[buf_ib].qs[d]; k_quants[d ] = int32_t( vui & 0x0F0F0F0F); k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF; - uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF; - k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); - k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); -#endif + if (has_qh) { + uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF; + uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF; + k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); + } } - } else -#endif - { + } else { [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d); } } } else { - const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d_tid * (HSK_per_thread / 4) + d_block); - const uint ib = coord / BLOCK_SIZE; - const uint iqs = (coord % BLOCK_SIZE); + const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d_tid * (HSK_per_thread / 4) + d_block); + const uint ib = coord / BLOCK_SIZE_K; + const uint iqs = (coord % BLOCK_SIZE_K); -#if QUANT_AUXF == 1 - k_dm = ACC_TYPEV2(get_k_d(ib, k_offset), 0.0); -#else - k_dm = ACC_TYPEV2(get_k_dm(ib, k_offset)); -#endif -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) || defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - if (d_per_step == 8) { -#if defined(DATA_A_Q5_0) - uint qh = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qh[0], - k_packed.k_data_packed16[k_offset + ib].qh[1])); -#elif defined(DATA_A_Q5_1) - uint qh = k_packed.k_data_packed16[k_offset + ib].qh; -#endif - [[unroll]] for (uint32_t d = 0; d < 4; d++) { -#if defined(A_TYPE_PACKED32) - uint vui = k_packed32.k_data_packed32[k_offset + ib].qs[d]; -#else - uint vui = pack32(u16vec2(k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 0], - k_packed.k_data_packed16[k_offset + ib].qs[iqs / 2 + d * 2 + 1])); -#endif - k_quants[d ] = int32_t( vui & 0x0F0F0F0F); - k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - uint qh_lo = (qh >> (d * 4)) & 0xF; - uint qh_hi = (qh >> (d * 4 + 16)) & 0xF; - k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); - k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); -#endif + k_dm = ACC_TYPEV2(get_k_scale(ib, k_offset)); + + if (block8_fast) { + fa_k_qs_block8 blk = get_k_qs_block8(ib, k_offset); + [[unroll]] for (uint32_t d = 0; d < 8; d++) { + k_quants[d] = blk.qs[d]; } - } else -#endif - { + } else { [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset); } @@ -516,14 +486,14 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) { FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0); if (!KV_bounds_check || j * Bc + c < KV) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); -#endif + if (USE_DECODE_V) { + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d; + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else { + V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = V_Tf; @@ -547,15 +517,13 @@ void main() { FLOAT_TYPEV4 Vf; if (SHMEM_STAGING != 0) { Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; - } else { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); + } else if (USE_DECODE_V) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE_V + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else + } else { Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); -#endif } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index efed3a73e22..9a7957da97b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -87,176 +87,58 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 -#if defined(DATA_A_F32) -layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed; -layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed; -#elif defined(A_TYPE_PACKED16) -layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; -layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; -#endif - -#if defined(A_TYPE_PACKED32) -layout (binding = 1) readonly buffer K_PACKED32 {A_TYPE_PACKED32 k_data_packed32[];} k_packed32; -layout (binding = 2) readonly buffer V_PACKED32 {A_TYPE_PACKED32 v_data_packed32[];} v_packed32; -#endif - -#ifndef BLOCK_SIZE -#define BLOCK_SIZE 1 -#endif - -#if defined(DATA_A_F32) -#undef BLOCK_SIZE -#define BLOCK_SIZE 4 -#define BLOCK_BYTE_SIZE 16 - -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - // iqs is currently always zero in the flash attention shaders - if (binding_idx == BINDING_IDX_K) { - return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]); - } else { - return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]); - } -} -#endif - -#if defined(DATA_A_Q4_0) -#define BLOCK_BYTE_SIZE 18 -#elif defined(DATA_A_Q4_1) -#define BLOCK_BYTE_SIZE 20 -#endif - -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q4_1 - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); -#endif - } else { - uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q4_1 - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * nibbles + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); -#endif - } -} -#endif - -#if defined(DATA_A_Q5_0) -#define BLOCK_BYTE_SIZE 22 -#elif defined(DATA_A_Q5_1) -#define BLOCK_BYTE_SIZE 24 -#endif - -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - -#ifdef DATA_A_Q5_1 - uint qh = k_packed.k_data_packed16[a_offset + ib].qh; -#else - uint qh = uint(k_packed.k_data_packed16[a_offset + ib].qh[0]) | (uint(k_packed.k_data_packed16[a_offset + ib].qh[1]) << 16); -#endif - FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q5_1 - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); -#endif - } else { - uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - -#ifdef DATA_A_Q5_1 - uint qh = v_packed.v_data_packed16[a_offset + ib].qh; -#else - uint qh = uint(v_packed.v_data_packed16[a_offset + ib].qh[0]) | (uint(v_packed.v_data_packed16[a_offset + ib].qh[1]) << 16); -#endif - FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) * FLOAT_TYPE(16.0f); - - FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF); -#ifdef DATA_A_Q5_1 - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb) + FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].m); -#else - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); -#endif - } -} -#endif - - -#if defined(DATA_A_IQ4_NL) -#define BLOCK_BYTE_SIZE 18 - -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( - kvalues_iq4nl[vui_lo & 0xF], - kvalues_iq4nl[(vui_lo >> 8) & 0xF], - kvalues_iq4nl[vui_hi & 0xF], - kvalues_iq4nl[(vui_hi >> 8) & 0xF]); - } else { - uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4( - kvalues_iq4nl[vui_lo & 0xF], - kvalues_iq4nl[(vui_lo >> 8) & 0xF], - kvalues_iq4nl[vui_hi & 0xF], - kvalues_iq4nl[(vui_hi >> 8) & 0xF]); +// FaTypeK / FaTypeV spec constant values. These mirror enum ggml_type so the +// host can pass the type directly. Keep in sync with ggml.h. +#define FA_TYPE_F32 0u +#define FA_TYPE_F16 1u +#define FA_TYPE_Q4_0 2u +#define FA_TYPE_Q4_1 3u +#define FA_TYPE_Q5_0 6u +#define FA_TYPE_Q5_1 7u +#define FA_TYPE_Q8_0 8u +#define FA_TYPE_Q1_0 41u + +// Number of matrix elements per buffer block, derived from the K/V type spec +// constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1 +// and bypasses the dequant path entirely. Quants follow their ggml block sizes. +uint fa_block_elems(uint ty) { + switch (ty) { + case FA_TYPE_F32: return 4u; + case FA_TYPE_F16: return 1u; + case FA_TYPE_Q4_0: return uint(QUANT_K_Q4_0); + case FA_TYPE_Q4_1: return uint(QUANT_K_Q4_1); + case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0); + case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1); + case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0); + case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere + default: return 1u; } } -#endif -#if defined(DATA_A_Q8_0) -#define BLOCK_BYTE_SIZE 34 -FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); - } else { - const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); +// QUANT_R_MMQ for FA-eligible K types. Q4_*/Q5_* store two nibbles per byte +// (R==2); Q8_0 stores one byte per element (R==1). Used to derive the number +// of int32s per 32-element block on the MMQ K path: ints_per_block == 8 / R. +uint fa_quant_r_mmq(uint ty) { + switch (ty) { + case FA_TYPE_Q4_0: return uint(QUANT_R_Q4_0); + case FA_TYPE_Q4_1: return uint(QUANT_R_Q4_1); + case FA_TYPE_Q5_0: return uint(QUANT_R_Q5_0); + case FA_TYPE_Q5_1: return uint(QUANT_R_Q5_1); + case FA_TYPE_Q8_0: return uint(QUANT_R_Q8_0); + default: return 1u; } } -#endif + +// These can't be `const` globals because GLSL forbids function calls in global +// const initializers, even when the spec constants would let the driver fold +// them. Macros expand at the use site and fold after specialization. +#define BLOCK_SIZE_K fa_block_elems(FaTypeK) +#define BLOCK_SIZE_V fa_block_elems(FaTypeV) +// F16 reads f16 elements directly from the binding; everything else routes +// through dequantize4 / the MMQ helpers to unpack from the packed block layout. +#define USE_DECODE_K (FaTypeK != FA_TYPE_F16) +#define USE_DECODE_V (FaTypeV != FA_TYPE_F16) #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 526e8da384e..bffcc095be3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -14,6 +14,7 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#include "flash_attn_dequant.glsl" // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd const uint32_t MatBr = 16; @@ -127,13 +128,9 @@ void main() { // mo_offset will point to the tile starting at row i*Br and col 0 uint32_t mo_offset = mo_stride * i; -#if BLOCK_SIZE > 1 - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; -#else - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; -#endif + // FaBlockBytesK/V == 2 for f16 (sizeof f16) and == 16 for f32 (vec4) and == ggml block size for quants. + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV; uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; @@ -227,14 +224,14 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) { f16vec4 K_Tf = f16vec4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); -#endif + if (USE_DECODE_K) { + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = K_Tf; @@ -256,47 +253,40 @@ void main() { // staged through a Bc * MatBr size staging buffer. // If K is not type f16, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { -#if BLOCK_SIZE == 1 - if (KV_bounds_check || d * 16 + 16 > HSK) { -#endif - barrier(); - [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) { - uint32_t col_vec = (idx + tid) % (MatBr / 4); - uint32_t row = (idx + tid) / (MatBr / 4); - if (idx + tid < Bc * MatBr / 4) { - f16vec4 K_Tf = f16vec4(0); - if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); -#endif - } + // For quants we always need to dequant into kvsh; for f16 we can load + // directly from global memory when alignment / bounds allow it. + const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK; + if (stage_k) { + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) { + uint32_t col_vec = (idx + tid) % (MatBr / 4); + uint32_t row = (idx + tid) / (MatBr / 4); + if (idx + tid < Bc * MatBr / 4) { + f16vec4 K_Tf = f16vec4(0); + if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { + if (USE_DECODE_K) { + uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); + } + } - kvsh[row * kvsh_stride + col_vec] = K_Tf; + kvsh[row * kvsh_stride + col_vec] = K_Tf; + } } + barrier(); } - barrier(); -#if BLOCK_SIZE == 1 - } -#endif -#if BLOCK_SIZE == 1 - if (KV_bounds_check || d * 16 + 16 > HSK) -#endif - { + if (stage_k) { uint coord = (gl_SubgroupID * MatBc) * kvsh_stride; coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); - } -#if BLOCK_SIZE == 1 - else { + } else { const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4; coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor); } -#endif } else { uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4; coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); @@ -397,14 +387,14 @@ void main() { if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) { f16vec4 V_Tf = f16vec4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); -#endif + if (USE_DECODE_V) { + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d; + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else { + V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); + } } kvsh[c * kvsh_stride + d] = V_Tf; @@ -431,36 +421,33 @@ void main() { // staged through a Bc * MatBr size staging buffer. // If V is not type f16, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { -#if BLOCK_SIZE == 1 - // For f16, only preload if not aligned - if (KV_bounds_check) { -#endif - [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) { - const uint idx = i * gl_WorkGroupSize.x + tid; - const uint row = idx / v_cols; - const uint col = idx % v_cols; - - const uint v_row = j * Bc + row; - const uint v_col = hsv_tile * MatBc * row_split + col * 4; - - const uint coord = v_row * v_stride * BLOCK_SIZE + v_col; - const uint ib = coord / BLOCK_SIZE; - const uint iqs = coord % BLOCK_SIZE; - - if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { -#if BLOCK_SIZE > 1 - kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; -#endif - } else { - kvsh[row * vsh_stride + col] = f16vec4(0.0f); + // For quants we always preload via kvsh. For f16 we only preload when + // alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4). + const bool stage_v = USE_DECODE_V || KV_bounds_check; + if (stage_v) { + [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) { + const uint idx = i * gl_WorkGroupSize.x + tid; + const uint row = idx / v_cols; + const uint col = idx % v_cols; + + const uint v_row = j * Bc + row; + const uint v_col = hsv_tile * MatBc * row_split + col * 4; + + const uint coord = v_row * v_stride * BLOCK_SIZE_V + v_col; + const uint ib = coord / BLOCK_SIZE_V; + const uint iqs = coord % BLOCK_SIZE_V; + + if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { + if (USE_DECODE_V) { + kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else { + kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; + } + } else { + kvsh[row * vsh_stride + col] = f16vec4(0.0f); + } } } - -#if BLOCK_SIZE == 1 - } -#endif } barrier(); @@ -471,15 +458,12 @@ void main() { coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor); if (SHMEM_STAGING == 0) { -#if BLOCK_SIZE == 1 - if (!KV_bounds_check) { + if (!USE_DECODE_V && !KV_bounds_check) { // F16 values can be loaded directly from global memory const uint v_tile_row = j * Bc + bc_chunk * MatBc; const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); - } else -#endif - { + } else { const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4); coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 8a7bbaeb92c..141bb870883 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -28,43 +28,28 @@ layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_ uint8_t raw[FaBlockBytesV]; }; -uint fa_block_elems(uint ty) { - switch (ty) { - case 0u: return 4u; // GGML_TYPE_F32: vec4 block (matches decodeBufF32 / dequantFuncF32) - case 1u: return 1u; // GGML_TYPE_F16 - case 2u: return uint(QUANT_K_Q4_0); - case 3u: return uint(QUANT_K_Q4_1); - case 6u: return uint(QUANT_K_Q5_0); - case 7u: return uint(QUANT_K_Q5_1); - case 8u: return uint(QUANT_K_Q8_0); - case 41u: return uint(QUANT_K_Q1_0); - default: - return 1u; - } -} - float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { switch (FaTypeK) { - case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); - case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); - case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); - case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); - case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); - case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); - case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); default: return float16_t(0); } } float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { switch (FaTypeV) { - case 0u: return dequantFuncF32(decodeBufF32(bl_in), blockCoords, coordInBlock); - case 2u: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); - case 3u: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); - case 6u: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); - case 7u: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); - case 8u: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); - case 41u: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); default: return float16_t(0); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl new file mode 100644 index 00000000000..02106f33cbe --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl @@ -0,0 +1,123 @@ +// Asymmetric K/V flash attention: aliased SSBO views of bindings 1 (K) and 2 (V) +// covering every supported FA element type, plus an uber dequantize4() that +// switches on FaTypeK / FaTypeV. After spec-constant specialization the driver +// folds away every path except the one matching the K/V type for this pipeline. +// +// Included by flash_attn.comp and flash_attn_cm1.comp. Not included by +// flash_attn_cm2.comp, which has its own buffer_reference-based decode path. +// +// We use macros (rather than per-quant decode functions taking a struct) on +// purpose: the FA shaders don't enable GL_EXT_shader_explicit_arithmetic_types_float16 +// when FLOAT16 isn't defined, which makes float16-containing struct values +// illegal to return from / pass to functions. Macros expand inline where the +// float16 stays in storage and is converted to FLOAT_TYPE at use. + +// F32 is fed as a vec4 "block" (4 floats), matching what dequant_funcs_cm2.glsl +// does for F32 in the cm2 shader. FaBlockBytesK/V == 16 for F32. +layout (binding = 1) readonly buffer K_PACKED_F32 { vec4 data[]; } k_packed_f32; +layout (binding = 2) readonly buffer V_PACKED_F32 { vec4 data[]; } v_packed_f32; + +layout (binding = 1) readonly buffer K_PACKED_Q4_0 { block_q4_0_packed16 data[]; } k_packed_q4_0; +layout (binding = 2) readonly buffer V_PACKED_Q4_0 { block_q4_0_packed16 data[]; } v_packed_q4_0; +layout (binding = 1) readonly buffer K_PACKED_Q4_1 { block_q4_1_packed16 data[]; } k_packed_q4_1; +layout (binding = 2) readonly buffer V_PACKED_Q4_1 { block_q4_1_packed16 data[]; } v_packed_q4_1; +layout (binding = 1) readonly buffer K_PACKED_Q5_0 { block_q5_0_packed16 data[]; } k_packed_q5_0; +layout (binding = 2) readonly buffer V_PACKED_Q5_0 { block_q5_0_packed16 data[]; } v_packed_q5_0; +layout (binding = 1) readonly buffer K_PACKED_Q5_1 { block_q5_1_packed16 data[]; } k_packed_q5_1; +layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[]; } v_packed_q5_1; +layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0; +layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0; + +// Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16 +// views, used by the MMQ K-side hot path for fast 4-uint loads. +layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32; +layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 data[]; } k_packed_q5_1_p32; + +// Per-quant decode bodies are expanded once for the K view set and once for +// the V view set. The macros take the buffer name as a parameter. +#define FA_DEQUANT4_F32(BUF) \ + return FLOAT_TYPEV4(BUF.data[a_offset + ib]); + +#define FA_DEQUANT4_Q4_0(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); \ +} + +#define FA_DEQUANT4_Q4_1(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * nibbles \ + + FLOAT_TYPE(BUF.data[a_offset + ib].m); \ +} + +#define FA_DEQUANT4_Q5_0(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + uint qh = uint(BUF.data[a_offset + ib].qh[0]) \ + | (uint(BUF.data[a_offset + ib].qh[1]) << 16); \ + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \ + (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \ + * FLOAT_TYPE(16.0f); \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); \ +} + +#define FA_DEQUANT4_Q5_1(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + uint qh = BUF.data[a_offset + ib].qh; \ + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \ + (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \ + * FLOAT_TYPE(16.0f); \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb) \ + + FLOAT_TYPE(BUF.data[a_offset + ib].m); \ +} + +#define FA_DEQUANT4_Q8_0(BUF) { \ + const i8vec2 v0 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 ])).xy; \ + const i8vec2 v1 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 + 1])).xy; \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \ +} + +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + switch (FaTypeK) { + case FA_TYPE_F32: FA_DEQUANT4_F32 (k_packed_f32) + case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(k_packed_q4_0) + case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(k_packed_q4_1) + case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0) + case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1) + case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0) + } + } else { + switch (FaTypeV) { + case FA_TYPE_F32: FA_DEQUANT4_F32 (v_packed_f32) + case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(v_packed_q4_0) + case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(v_packed_q4_1) + case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0) + case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1) + case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0) + } + } + return FLOAT_TYPEV4(0); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl index e14e62d546a..6bf10a7cffd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl @@ -1,149 +1,203 @@ -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) -int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { -#ifdef DATA_A_Q4_0 - uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], - k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); -#else - uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4]; -#endif - - uint shift = (iqs & 0x10) >> 2; - vui >>= shift; - - return int32_t(vui & 0x0F0F0F0F); -} -#endif - -#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) -int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { -#ifdef DATA_A_Q5_0 - uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], - k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); - uint qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qh[0], - k_packed.k_data_packed16[a_offset + ib].qh[1])); -#else - uint vui = k_packed32.k_data_packed32[a_offset + ib].qs[(iqs & 0xF) / 4]; - uint qh = k_packed.k_data_packed16[a_offset + ib].qh; -#endif - - uint shift = (iqs & 0x10) >> 2; - vui >>= shift; - - uint qh_bits = (qh >> iqs) & 0xF; - return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); -} -#endif - -#if defined(DATA_A_Q8_0) -int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { - return pack32(i16vec2(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2], k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])); -} -#endif +// MMQ K-side helpers, asymmetric form. Each function dispatches on FaTypeK and +// reads from the matching aliased K binding declared in flash_attn_dequant.glsl. +// Spec-constant specialization folds the unused paths. -#if defined(DATA_A_IQ4_NL) int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { - uint vui = pack32(u16vec2(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], - k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); - uint shift = (iqs & 0x10) >> 2; - vui >>= shift; - - u8vec4 idx = unpack8(vui & 0x0F0F0F0F); - return pack32(i8vec4(kvalues_iq4nl_const[idx.x], - kvalues_iq4nl_const[idx.y], - kvalues_iq4nl_const[idx.z], - kvalues_iq4nl_const[idx.w])); + switch (FaTypeK) { + case FA_TYPE_Q4_0: { + uint vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + return int32_t(vui & 0x0F0F0F0F); + } + case FA_TYPE_Q4_1: { // uses packed32 alias + uint vui = k_packed_q4_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4]; + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + return int32_t(vui & 0x0F0F0F0F); + } + case FA_TYPE_Q5_0: { + uint vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0], + k_packed_q5_0.data[a_offset + ib].qh[1])); + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + uint qh_bits = (qh >> iqs) & 0xF; + return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q5_1: { // qs via packed32, qh via packed16 + uint vui = k_packed_q5_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4]; + uint qh = k_packed_q5_1.data[a_offset + ib].qh; + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + uint qh_bits = (qh >> iqs) & 0xF; + return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q8_0: { + return pack32(i16vec2(k_packed_q8_0.data[a_offset + ib].qs[iqs / 2], + k_packed_q8_0.data[a_offset + ib].qs[iqs / 2 + 1])); + } + default: return 0; + } } -#endif -#if QUANT_AUXF == 1 -FLOAT_TYPE get_k_d(uint ib, uint a_offset) { - return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d); -} -#else -FLOAT_TYPEV2 get_k_dm(uint ib, uint a_offset) { - return FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + ib].dm); +// Per-block scale/min, packed as (d, m). Single-scale types (Q4_0, Q5_0, Q8_0) +// return (d, 0) so call sites always see the same shape. +FLOAT_TYPEV2 get_k_scale(uint ib, uint a_offset) { + switch (FaTypeK) { + case FA_TYPE_Q4_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + ib].d), 0.0); + case FA_TYPE_Q4_1: return FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + ib].dm); + case FA_TYPE_Q5_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + ib].d), 0.0); + case FA_TYPE_Q5_1: return FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + ib].dm); + case FA_TYPE_Q8_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + ib].d), 0.0); + default: return FLOAT_TYPEV2(0); + } } -#endif void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) { -#if defined(DATA_A_Q4_0) - kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); -#elif defined(DATA_A_Q4_1) - kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs]; -#elif defined(DATA_A_Q5_0) - kblocksh[buf_ib].qs[iqs] = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); - if (iqs == 0) { - kblocksh[buf_ib].qh = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qh[0], - k_packed.k_data_packed16[a_offset + global_ib].qh[1])); + // kblocksh[].qs is int32_t for the unified MMQ struct; uint sources need + // explicit casts. The bit pattern is what we care about here -- the actual + // signed/unsigned interpretation happens downstream in the dot product. + switch (FaTypeK) { + case FA_TYPE_Q4_0: { + kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2 + 1]))); + break; + } + case FA_TYPE_Q4_1: { + kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q4_1_p32.data[a_offset + global_ib].qs[iqs]); + break; + } + case FA_TYPE_Q5_0: { + kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2 + 1]))); + if (iqs == 0) { + kblocksh[buf_ib].qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qh[0], + k_packed_q5_0.data[a_offset + global_ib].qh[1])); + } + break; + } + case FA_TYPE_Q5_1: { + kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q5_1_p32.data[a_offset + global_ib].qs[iqs]); + if (iqs == 0) { + kblocksh[buf_ib].qh = k_packed_q5_1.data[a_offset + global_ib].qh; + } + break; + } + case FA_TYPE_Q8_0: { + kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2 + 1])); + break; + } } -#elif defined(DATA_A_Q5_1) - kblocksh[buf_ib].qs[iqs] = k_packed32.k_data_packed32[a_offset + global_ib].qs[iqs]; + if (iqs == 0) { - kblocksh[buf_ib].qh = k_packed.k_data_packed16[a_offset + global_ib].qh; + // Q4_0/Q5_0/Q8_0 store dm.x = d; Q4_1/Q5_1 store dm = (d, m) pair. + switch (FaTypeK) { + case FA_TYPE_Q4_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + global_ib].d), 0.0); break; + case FA_TYPE_Q4_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + global_ib].dm); break; + case FA_TYPE_Q5_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + global_ib].d), 0.0); break; + case FA_TYPE_Q5_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + global_ib].dm); break; + case FA_TYPE_Q8_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + global_ib].d), 0.0); break; + } } -#elif defined(DATA_A_Q8_0) - kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); -#elif defined(DATA_A_IQ4_NL) - const uint qs = pack32(u16vec2(k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2], - k_packed.k_data_packed16[a_offset + global_ib].qs[iqs * 2 + 1])); - const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F); - const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F); - kblocksh[buf_ib].qs[iqs ] = pack32(i8vec4(kvalues_iq4nl_const[i_a0.x], kvalues_iq4nl_const[i_a0.y], - kvalues_iq4nl_const[i_a0.z], kvalues_iq4nl_const[i_a0.w])); - kblocksh[buf_ib].qs[iqs + 4] = pack32(i8vec4(kvalues_iq4nl_const[i_a1.x], kvalues_iq4nl_const[i_a1.y], - kvalues_iq4nl_const[i_a1.z], kvalues_iq4nl_const[i_a1.w])); -#endif +} - if (iqs == 0) { -#if QUANT_AUXF == 1 - kblocksh[buf_ib].dm = FLOAT_TYPE(k_packed.k_data_packed16[a_offset + global_ib].d); -#else - kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed32.k_data_packed32[a_offset + global_ib].dm); -#endif +// d_per_step==8 hot path: read one full 32-element block worth of nibble-packed +// int32 quants. Equivalent to 8 calls to get_k_qs(ib, d*4, a_offset) but reads +// qh (Q5_*) and runs pack32 (Q4_0/Q5_0) once per block instead of per nibble +// quad. iqs is always 0 in this path (hsk4 % 8 == 0 implies block-aligned). +// Q8_0 takes the generic get_k_qs path because its qs layout (i8 pairs) doesn't +// share this nibble shape. +// +// Returned via a struct so the caller's k_quants array (sized from spec +// constants) doesn't need to match a fixed[8] out-parameter type. +struct fa_k_qs_block8 { + int32_t qs[8]; +}; + +fa_k_qs_block8 get_k_qs_block8(uint ib, uint a_offset) { + fa_k_qs_block8 r; + uint qh = 0; + if (FaTypeK == FA_TYPE_Q5_0) { + qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0], + k_packed_q5_0.data[a_offset + ib].qh[1])); + } else if (FaTypeK == FA_TYPE_Q5_1) { + qh = k_packed_q5_1.data[a_offset + ib].qh; } + const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1); + [[unroll]] for (uint32_t d = 0; d < 4; d++) { + uint vui = 0; + switch (FaTypeK) { + case FA_TYPE_Q4_0: { // packed16 + vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 0], + k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 1])); + break; + } + case FA_TYPE_Q4_1: { // packed32 alias + vui = k_packed_q4_1_p32.data[a_offset + ib].qs[d]; + break; + } + case FA_TYPE_Q5_0: { // packed16 + vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 0], + k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 1])); + break; + } + case FA_TYPE_Q5_1: { // packed32 alias + vui = k_packed_q5_1_p32.data[a_offset + ib].qs[d]; + break; + } + } + r.qs[d ] = int32_t( vui & 0x0F0F0F0F); + r.qs[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); + if (has_qh) { + uint qh_lo = (qh >> (d * 4)) & 0xFu; + uint qh_hi = (qh >> (d * 4 + 16)) & 0xFu; + r.qs[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + r.qs[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); + } + } + return r; } int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) { -#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) - uint sub = pos % 4; - uint shift = ((pos % 8) >= 4) ? 4 : 0; - return int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F); -#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) - uint sub = pos % 4; - uint shift = ((pos % 8) >= 4) ? 4 : 0; - int32_t result = int32_t((kblocksh[buf_ib].qs[sub] >> shift) & 0x0F0F0F0F); - uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4)) & 0xF; - return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u); -#elif defined(DATA_A_Q8_0) || defined(DATA_A_IQ4_NL) - return kblocksh[buf_ib].qs[pos]; -#endif + switch (FaTypeK) { + case FA_TYPE_Q4_0: + case FA_TYPE_Q4_1: { + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4u : 0u; + return int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu); + } + case FA_TYPE_Q5_0: + case FA_TYPE_Q5_1: { + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4u : 0u; + int32_t result = int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu); + uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4u)) & 0xFu; + return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q8_0: { + return kblocksh[buf_ib].qs[pos]; + } + default: return 0; + } } ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) { -#if defined(DATA_A_Q4_0) - return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; -#elif defined(DATA_A_Q5_0) - return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; -#elif defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) - return ACC_TYPE(Qf[qib].ds.y) * k_dm.y; -#else - return ACC_TYPE(0.0); -#endif + switch (FaTypeK) { + case FA_TYPE_Q4_0: return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; + case FA_TYPE_Q5_0: return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; + case FA_TYPE_Q4_1: + case FA_TYPE_Q5_1: return ACC_TYPE(Qf[qib].ds.y) * k_dm.y; + default: return ACC_TYPE(0.0); + } } void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) { kblocksh[buf_ib].qs[iqs] = 0; -#if defined(DATA_A_IQ4_NL) - kblocksh[buf_ib].qs[iqs + 4] = 0; -#endif if (iqs == 0) { -#if QUANT_AUXF == 1 - kblocksh[buf_ib].dm = FLOAT_TYPE(0.0f); -#else kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f); -#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl index 10552d013a2..79c933f40cf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl @@ -1,4 +1,13 @@ -#if defined(DATA_A_Q4_0) +#if defined(FA_MMQ_MIXED) +// Mixed-K flash attention MMQ: superset cache that fits Q4_0/Q4_1/Q5_0/Q5_1/Q8_0. +// Q4_*/Q5_* only use qs[0..3] and (for Q5_*) qh. Q8_0 uses qs[0..7]. Single-scale +// types (Q4_0/Q5_0/Q8_0) leave dm.y unused. +struct block_a_cache { + int32_t qs[8]; + uint32_t qh; + FLOAT_TYPEV2 dm; +}; +#elif defined(DATA_A_Q4_0) #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[16/4]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 6f2a929c40c..d99b2b5d802 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -643,42 +643,22 @@ void process_shaders() { if (fp16) { #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - string_to_spv("flash_attn_f32_f16_mixed", "flash_attn_cm2.comp", + string_to_spv("flash_attn_f32_f16", "flash_attn_cm2.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); #endif - } - - for (const auto& tname : type_names) { - if (tname == "bf16") continue; - if (fp16) { #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); - } + string_to_spv("flash_attn_f32_f16", "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); #endif - } + } - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "iq4_nl" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); + string_to_spv("flash_attn_f32_f16", "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (tname != "f32") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }, {"MMQ", "1"}}), fp16, false, false, f16acc, "_int8"); - } + string_to_spv("flash_attn_f32_f16", "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8"); #endif - } - } } } From 449b33fc8f6aadf267e3b577622deae45c81ea0c Mon Sep 17 00:00:00 2001 From: Pascal Date: Mon, 11 May 2026 18:42:08 +0200 Subject: [PATCH 603/831] Ggml/cuda snake fusion hardening (llama/22912) * cuda: tighten snake fusion type checks for all operands (defensive, sync vulkan) * cuda: reject snake fusion when ne[2] or ne[3] > 1 (mirror vulkan PR review) * cuda: merge type_ok and types_ok into a single types_ok (address am17an review) * cuda: filter ADD/SUB/MUL/DIV in supports_op to F32/F16 bin_bcast only dispatches F32/F16 type triplets, mirror the vulkan filter so unsupported types fall back through cpy instead of aborting. * test-backend-ops: extend snake_fuse to rank-4 with ne[2]/ne[3] > 1 cases --- ggml/src/ggml-cuda/ggml-cuda.cu | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b92a208705d..e25be3592fd 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3929,10 +3929,25 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph // closure check: the trailing add must read the same x as the leading mul const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; - const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16); + // Kernel iterates over total = T * C, so x and add must be 2D and + // a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled. + const bool dim_ok = (x->ne[2] == 1 && x->ne[3] == 1) && + (add->ne[2] == 1 && add->ne[3] == 1) && + (a->ne[2] == 1 && a->ne[3] == 1); const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1]; - if (type_ok && shape_ok && x_in_add == x && add->type == x->type) { + // x must be in the supported whitelist and every operand / intermediate + // result must share x's type, since launch_snake casts a / inv_b as + // float and templates the kernel on a single T. Mixed precision chains + // fall back to the naive path. + const ggml_tensor * sin1 = cgraph->nodes[i + 1]; + const bool types_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) && + (a->type == x->type) && (inv_b->type == x->type) && + (mul0->type == x->type) && (sin1->type == x->type) && + (sqr->type == x->type) && (mul1->type == x->type) && + (add->type == x->type); + + if (types_ok && shape_ok && dim_ok && x_in_add == x) { ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add); return 4; } @@ -5291,12 +5306,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: - case GGML_OP_ADD: case GGML_OP_ADD_ID: case GGML_OP_ADD1: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: @@ -5305,6 +5316,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CLAMP: case GGML_OP_LOG: return true; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_SSM_SCAN: { if (op->src[3]->ne[0] == 1) { // Mamba2 From 287f637fb15f3c59d0f579f86c4b59d3ebfbc443 Mon Sep 17 00:00:00 2001 From: CrispStrobe <154636388+CrispStrobe@users.noreply.github.com> Date: Mon, 11 May 2026 19:48:29 +0200 Subject: [PATCH 604/831] CUDA: handle OW > 65535 in im2col (2D and 3D) (llama/22944) `im2col_cuda` and `im2col_3d_cuda` both dispatch with `block_nums.y = OW`. CUDA caps grid Y at 65535. Conv1d encoders on raw 16 kHz audio with T > 65535 (~ 4 s) trip the limit -- e.g. SEANet at 11 s lands at OW = 176000 -- and the launch returns `invalid configuration argument`. Clamp `block_nums.y` to `MIN(OW, MAX_GRIDDIM_Y)` and loop inside the kernel with stride `MAX_GRIDDIM_Y`. Same in-kernel stride pattern already used for the z axis (`MAX_GRIDDIM_Z`). Both 2D `im2col_kernel` and 3D `im2col_3d_kernel` need the same fix. Bit-identical for OW <= 65535 (single iteration of the new outer loop). Tested on T4 / Jetson Orin with a SEANet encoder running on 11 s / 16 kHz audio (im2col reaching OW ~ 176000); pre-fix launch returns `invalid configuration argument`, post-fix runs to completion. Existing test-backend-ops im2col cases unchanged. --- ggml/src/ggml-cuda/im2col.cu | 61 +++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index 56dc0545742..28c79ab462e 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -1,5 +1,6 @@ #include "im2col.cuh" +#define MAX_GRIDDIM_Y 65535 #define MAX_GRIDDIM_Z 65535 template @@ -18,22 +19,23 @@ static __global__ void im2col_kernel( const int64_t ikh = rem / KW; const int64_t ikw = rem - ikh * KW; - const int64_t iow = blockIdx.y; - for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) { - const int64_t in = iz / OH; - const int64_t ioh = iz - in * OH; + for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) { + for (int64_t iz = blockIdx.z; iz < N_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OH; + const int64_t ioh = iz - in * OH; - const int64_t iiw = iow * s0 + ikw * d0 - p0; - const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t offset_dst = - ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; + const int64_t offset_dst = + ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; - dst[offset_dst] = x[offset_src + iih * IW + iiw]; + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; + dst[offset_dst] = x[offset_src + iih * IW + iiw]; + } } } @@ -51,7 +53,7 @@ static void im2col_cuda(const float * x, T* dst, const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; const int64_t N_OH = N * OH; const int64_t KH_KW = KW*KH; - dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z)); + dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OH, MAX_GRIDDIM_Z)); im2col_kernel<<>>(x, dst, IC, IW, IH, OH, OW, KW, KH, IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW, s0, s1, p0, p1, d0, d1); @@ -136,23 +138,24 @@ static __global__ void im2col_3d_kernel( const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; const int64_t ikw = i % KW; - const int64_t iow = blockIdx.y; - for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) { - const int64_t in = iz / OD_OH; - const int64_t iod = (iz - in*OD_OH) / OH; - const int64_t ioh = iz % OH; + for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) { + for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OD_OH; + const int64_t iod = (iz - in*OD_OH) / OH; + const int64_t ioh = iz % OH; - const int64_t iiw = iow * s0 + ikw * d0 - p0; - const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t iid = iod * s2 + ikd * d2 - p2; + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iid = iod * s2 + ikd * d2 - p2; - const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { - dst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); - dst[offset_dst] = src[offset_src]; + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); + dst[offset_dst] = src[offset_src]; + } } } } @@ -178,7 +181,7 @@ static void im2col_3d_cuda(const float * src, T* dst, const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; - dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z)); + dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OD_OH, MAX_GRIDDIM_Z)); im2col_3d_kernel<<>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW, From ea4652c42704fc298605fb639dedc40937845c78 Mon Sep 17 00:00:00 2001 From: Shawn Gu Date: Mon, 11 May 2026 11:57:26 -0700 Subject: [PATCH 605/831] opencl: add q4_1 MoE for Adreno (llama/22856) * Q4_1 MoE CLC pass sanity check * remove unnecessary code * opencl: remove unnecessary asserts and reformat * opencl: fix supports_op for q4_1 moe * q4_1 moe is supported by Adreno with certain shapes --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 2 + ggml/src/ggml-opencl/ggml-opencl.cpp | 366 ++++++++++++++++-- ggml/src/ggml-opencl/kernels/cvt.cl | 90 +++++ .../kernels/gemm_moe_q4_1_f32_ns.cl | 254 ++++++++++++ .../kernels/gemv_moe_q4_1_f32_ns.cl | 119 ++++++ 5 files changed, 798 insertions(+), 33 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index ffde6a4f063..7edb3eb4e9c 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -104,6 +104,8 @@ set(GGML_OPENCL_KERNELS mul_mv_id_mxfp4_f32_flat gemm_moe_q4_0_f32_ns gemv_moe_q4_0_f32_ns + gemm_moe_q4_1_f32_ns + gemv_moe_q4_1_f32_ns gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 gemm_moe_mxfp4_f32_ns diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 4e6f6fb43d2..73a58f74a94 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -544,6 +544,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; cl_kernel kernel_convert_block_q4_0_trans4_ns, kernel_restore_block_q4_0_trans4_ns; cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; + cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; @@ -602,6 +603,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; cl_kernel kernel_timestep_embedding; cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns; + cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; cl_kernel kernel_moe_reorder_b; @@ -958,6 +960,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err)); @@ -2856,6 +2860,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve " -cl-mad-enable " " -cl-fast-relaxed-math"; + // gemv_moe_q4_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q4_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q4_1_f32_ns.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q4_1_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q4_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q4_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q4_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q4_1_f32_ns.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q4_1_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q4_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemv_moe_mxfp4_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3749,11 +3785,14 @@ struct ggml_tensor_extra_cl_q4_1 { CL_CHECK(clReleaseMemObject(m)); m = nullptr; } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } // Currently, q_img and d_img are only initialized when SMALL_ALLOC is // enabled. They point to the images in ggml_backend_opencl_buffer_context. // So, there is no need to release them here. // TODO: initialize them for non SMALL_PATH path, or remove them. - q_img = nullptr; d_img = nullptr; m_img = nullptr; size_q = 0; @@ -4189,6 +4228,35 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm return GGML_STATUS_SUCCESS; } +// The optimized gemm and gemv kernels are used for large matrices without batch. +// tensor is the quantized weights matrix. +inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + int64_t threshold_ne0 = 512; + int64_t threshold_ne1 = 512; + if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && + backend_ctx->adreno_cl_compiler_version.type != DX) { + threshold_ne0 = 128; + threshold_ne1 = 128; + } + return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && + tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + GGML_UNUSED(backend_ctx); + int ne01 = tensor->ne[1]; + return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); +} + +inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + + bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor); + + size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; + + return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 +} + static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context; ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; @@ -4385,6 +4453,18 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } } + // q4_0, q8_0 and mxfp4 have general MUL_MAT_ID support, + // the quantizations here currently do not - they are only supported by Adreno with certain shapes + if (op->src[0]->type == GGML_TYPE_Q4_1) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (op->src[1]->type == GGML_TYPE_F32) { + return use_adreno_moe_kernels(backend_ctx, op->src[0]) + && ggml_is_contiguous(op->src[0]) + && ggml_is_contiguous(op->src[1]); + } +#endif + return false; + } return false; case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -4555,6 +4635,12 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) { delete e; } + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1) { + delete e; + } + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) { + delete e; + } for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4) { delete e; } @@ -4868,35 +4954,6 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff return GGML_STATUS_SUCCESS; } -// The optimized gemm and gemv kernels are used for large matrices without batch. -// tensor is the quantized weights matrix. -inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { - int64_t threshold_ne0 = 512; - int64_t threshold_ne1 = 512; - if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && - backend_ctx->adreno_cl_compiler_version.type != DX) { - threshold_ne0 = 128; - threshold_ne1 = 128; - } - return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && - tensor->ne[2] == 1 && tensor->ne[3] == 1; -} - -inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { - GGML_UNUSED(backend_ctx); - int ne01 = tensor->ne[1]; - return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); -} - -inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { - - bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor); - - size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; - - return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 -} - static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); @@ -5097,15 +5154,54 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe q4_1 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + // normal q4_1 repack +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q4_1_noshuffle; } - #else +#else cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; - #endif // GGML_OPENCL_USE_ADRENO_KERNELS +#endif // GGML_OPENCL_USE_ADRENO_KERNELS CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); @@ -5862,6 +5958,36 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { static ggml_cl_buffer buf_trans_q; static ggml_cl_buffer buf_trans_m; @@ -12862,6 +12988,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; #endif @@ -13131,6 +13258,179 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, break; } + case GGML_TYPE_Q4_1: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q4_1_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q4_1_f32_ns; + + if (strstr(src0->name, "as") != NULL) { + moe_router_reoerder(backend, src2, ne20); + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } case GGML_TYPE_Q8_0: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_id_q8_0_f32_flat; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index c87450dc49e..5bbf09710f9 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -370,6 +370,96 @@ kernel void kernel_restore_block_q4_1_noshuffle( } } +kernel void kernel_convert_block_q4_1_trans4_ns( + __global struct block_q4_1 * src0, + __global uint * dst_q, + __global half * dst_d, + __global half * dst_m, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK4_1; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_q4_1 * b = src0 + src_blk_offset; + dst_d[dst_blk_offset] = b->d; + dst_m[dst_blk_offset] = b->m; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_1 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK4_1 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_q[offset] = q_block.x; + dst_q[offset + ne01] = q_block.y; + dst_q[offset + ne01 * 2] = q_block.z; + dst_q[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_q4_1_trans4_ns( + __global uint * src_q, + __global half * src_d, + __global half * src_m, + __global struct block_q4_1 * dst0, + uint ne00, + uint ne01 +) { + int i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK4_1; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_dm_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q4_1 * b = dst0 + dst_blk_offset; + b->d = src_d[src_dm_offset]; + b->m = src_m[src_dm_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_q[src_q_offset]; + q_block.y = src_q[src_q_offset + ne01]; + q_block.z = src_q[src_q_offset + ne01 * 2]; + q_block.w = src_q[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_0 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK4_0 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl new file mode 100644 index 00000000000..e2574ae0187 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl @@ -0,0 +1,254 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +#define dequantize_q4_1(q4, a_f16, scale, m) \ + a_f16.s0 = (half)(q4.s0 & 0x000F) * scale + m; \ + a_f16.s1 = (half)((q4.s0 & 0x00F0) >> 4) * scale + m; \ + a_f16.s2 = (half)((q4.s0 & 0x0F00) >> 8) * scale + m; \ + a_f16.s3 = (half)((q4.s0 & 0xF000) >> 12) * scale + m; \ + a_f16.s4 = (half)(q4.s1 & 0x000F) * scale + m; \ + a_f16.s5 = (half)((q4.s1 & 0x00F0) >> 4) * scale + m; \ + a_f16.s6 = (half)((q4.s1 & 0x0F00) >> 8) * scale + m; \ + a_f16.s7 = (half)((q4.s1 & 0xF000) >> 12) * scale + m; \ + a_f16.s8 = (half)(q4.s2 & 0x000F) * scale + m; \ + a_f16.s9 = (half)((q4.s2 & 0x00F0) >> 4) * scale + m; \ + a_f16.sa = (half)((q4.s2 & 0x0F00) >> 8) * scale + m; \ + a_f16.sb = (half)((q4.s2 & 0xF000) >> 12) * scale + m; \ + a_f16.sc = (half)(q4.s3 & 0x000F) * scale + m; \ + a_f16.sd = (half)((q4.s3 & 0x00F0) >> 4) * scale + m; \ + a_f16.se = (half)((q4.s3 & 0x0F00) >> 8) * scale + m; \ + a_f16.sf = (half)((q4.s3 & 0xF000) >> 12) * scale + m; \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_q4_1_f32_ns( + __read_only image1d_buffer_t src0_q, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale and m for current Q4_1 block + uint sm_offset = s_sub_offset + get_global_id(0); + half s = src0_d[sm_offset]; + half m = src0_m[sm_offset]; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_1(as_ushort4(q4x16), reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 q (64-bits) in transposed layout + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_1(as_ushort4(q4x16), reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl new file mode 100644 index 00000000000..3739a215705 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl @@ -0,0 +1,119 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_Q4_1 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q4_1_to_fp32_packed8(ushort2 q4x8, half s, half m) { + float8 fp32x8; + fp32x8.s0 = (float)((q4x8.s0 & 0x000F) * s + m); + fp32x8.s1 = (float)(((q4x8.s0 & 0x00F0) >> 4) * s + m); + fp32x8.s2 = (float)(((q4x8.s0 & 0x0F00) >> 8) * s + m); + fp32x8.s3 = (float)(((q4x8.s0 & 0xF000) >> 12) * s + m); + fp32x8.s4 = (float)((q4x8.s1 & 0x000F) * s + m); + fp32x8.s5 = (float)(((q4x8.s1 & 0x00F0) >> 4) * s + m); + fp32x8.s6 = (float)(((q4x8.s1 & 0x0F00) >> 8) * s + m); + fp32x8.s7 = (float)(((q4x8.s1 & 0xF000) >> 12) * s + m); + return fp32x8; +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q4_1_f32_ns( + __global uint * src0_q, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_Q4_1); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_q[block_offset]; + regQ.s1 = src0_q[block_offset + ne01]; + regQ.s2 = src0_q[block_offset + ne01 * 2]; + regQ.s3 = src0_q[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + half regM = src0_m[ib00 * ne01 + i01 + expert_offset]; + half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; + + float8 fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s0), regS, regM); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s1), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s2), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s3), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} From 8ec91c91e17f043dc920439b0ef2b428b037614d Mon Sep 17 00:00:00 2001 From: guyfischman <138163913+guyfischman@users.noreply.github.com> Date: Tue, 12 May 2026 07:15:02 +0200 Subject: [PATCH 606/831] metal : promote mul_mv/mul_mm batch divisors to function constants (llama/22711) * metal : promote mul_mv/mul_mm batch divisors to function constants * metal : take op directly in get_pipeline_mul_mv_ext --- ggml/src/ggml-metal/ggml-metal-device.cpp | 46 +++++- ggml/src/ggml-metal/ggml-metal-device.h | 2 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +- ggml/src/ggml-metal/ggml-metal.metal | 165 +++++++++++----------- 4 files changed, 127 insertions(+), 88 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index d211bf79f14..f0147af84c1 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -647,19 +647,30 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_m return res; } -ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, const ggml_tensor * op, int nsg, int nxpsg, int r1ptg) { char base[256]; char name[256]; + const ggml_type tsrc0 = op->src[0]->type; + const ggml_type tsrc1 = op->src[1]->type; + const int ne12 = op->src[1]->ne[2]; + const int r2 = ne12 / op->src[0]->ne[2]; + const int r3 = op->src[1]->ne[3] / op->src[0]->ne[3]; + + GGML_ASSERT(ne12 <= INT16_MAX && r2 <= INT16_MAX && r3 <= INT16_MAX); + snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg); - snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg); + snprintf(name, 256, "%s_nsg=%d_nxpsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, nxpsg, ne12, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, (int16_t) r2, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, (int16_t) r3, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -687,8 +698,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0) : (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0); + GGML_ASSERT(op->src[1]->ne[2] <= INT16_MAX && op->src[1]->ne[3] <= INT16_MAX); + const int16_t ne12 = (int16_t) op->src[1]->ne[2]; + const int16_t ne13 = (int16_t) op->src[1]->ne[3]; + const int16_t r2 = (int16_t) (ne12 / op->src[0]->ne[2]); + const int16_t r3 = (int16_t) (ne13 / op->src[0]->ne[3]); + snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); - snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); + snprintf(name, 256, "%s_bci=%d_bco=%d_ne12=%d_ne13=%d_r2=%d_r3=%d", + base, bc_inp, bc_out, ne12, ne13, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -696,6 +714,10 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); + ggml_metal_cv_set_int16(cv, ne12, FC_MUL_MM + 2); + ggml_metal_cv_set_int16(cv, ne13, FC_MUL_MM + 3); + ggml_metal_cv_set_int16(cv, r2, FC_MUL_MM + 4); + ggml_metal_cv_set_int16(cv, r3, FC_MUL_MM + 5); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -877,14 +899,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta } }; + GGML_ASSERT(ne12 <= INT16_MAX && ne13 <= INT16_MAX); + const int16_t r2 = (int16_t) (ne12 / ne02); + const int16_t r3 = (int16_t) (ne13 / ne03); + snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); - snprintf(name, 256, "%s_nsg=%d", base, nsg); + snprintf(name, 256, "%s_nsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, ne12, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, r2, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, r3, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -1102,6 +1131,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m ggml_metal_cv_t cv = ggml_metal_cv_init(); ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 4718ca083b0..1f212a92f98 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -129,7 +129,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); -struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, const struct ggml_tensor * op, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 5fa162c875c..a114391c2e8 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2120,7 +2120,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { GGML_ABORT("unsupported ne11"); }; - auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op, nsg, nxpsg, r1ptg); ggml_metal_kargs_mul_mv_ext args = { /*.ne00 =*/ ne00, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 5c2ec8a4ab8..2d45de8cce2 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3353,6 +3353,9 @@ static inline void helper_mv_reduce_and_write( constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]]; constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]]; +constant short FC_mul_mv_ne12 [[function_constant(FC_MUL_MV + 2)]]; +constant short FC_mul_mv_r2 [[function_constant(FC_MUL_MV + 3)]]; +constant short FC_mul_mv_r3 [[function_constant(FC_MUL_MV + 4)]]; template void mul_vec_q_n_f32_impl( @@ -3376,10 +3379,10 @@ void mul_vec_q_n_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); @@ -3388,7 +3391,7 @@ void mul_vec_q_n_f32_impl( // pointers to src0 rows device const block_q_type * ax[NR0]; FOR_UNROLL (int row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } @@ -3462,8 +3465,8 @@ void kernel_mul_mv_q1_0_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13; @@ -3471,7 +3474,7 @@ void kernel_mul_mv_q1_0_f32_impl( device const block_q1_0 * ax[nr0]; for (int row = 0; row < nr0; ++row) { - const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0); } @@ -3590,10 +3593,10 @@ void kernel_mul_mv_q8_0_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); @@ -3602,7 +3605,7 @@ void kernel_mul_mv_q8_0_f32_impl( // pointers to src0 rows device const block_q8_0 * ax[NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } @@ -3682,10 +3685,10 @@ void kernel_mul_mv_ext_q4_f32_impl( const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; - const int i12 = i1m%args.ne12; - const int i13 = i1m/args.ne12; + const int i12 = i1m%FC_mul_mv_ne12; + const int i13 = i1m/FC_mul_mv_ne12; - const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; @@ -3785,10 +3788,10 @@ void kernel_mul_mv_ext_q4x4_f32_impl( const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; - const int i12 = i1m%args.ne12; - const int i13 = i1m/args.ne12; + const int i12 = i1m%FC_mul_mv_ne12; + const int i13 = i1m/FC_mul_mv_ne12; - const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; @@ -4000,10 +4003,10 @@ void kernel_mul_mv_t_t_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const T0 * x = (device const T0 *) (src0 + offset0); @@ -4012,7 +4015,7 @@ void kernel_mul_mv_t_t_impl( // pointers to src0 rows device const T0 * ax [NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const T0 *) ((device char *) src0 + offset0); } @@ -4122,10 +4125,10 @@ void kernel_mul_mv_t_t_4_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); @@ -4135,7 +4138,7 @@ void kernel_mul_mv_t_t_4_impl( device const T0 * ax [NR0]; device const T04 * ax4[NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax [row] = (device const T0 *) ((device char *) src0 + offset0); ax4[row] = (device const T04 *) ((device char *) src0 + offset0); @@ -4239,10 +4242,10 @@ void kernel_mul_mv_t_t_short_impl( return; } - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); @@ -7479,10 +7482,10 @@ void kernel_mul_mv_q2_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); @@ -7584,10 +7587,10 @@ void kernel_mul_mv_q3_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); @@ -7758,10 +7761,10 @@ void kernel_mul_mv_q4_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); @@ -7870,10 +7873,10 @@ void kernel_mul_mv_q5_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); @@ -8006,10 +8009,10 @@ void kernel_mul_mv_q6_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); @@ -8111,10 +8114,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); @@ -8219,10 +8222,10 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); @@ -8338,10 +8341,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); @@ -8450,10 +8453,10 @@ void kernel_mul_mv_iq3_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); @@ -8562,10 +8565,10 @@ void kernel_mul_mv_iq2_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); @@ -8675,10 +8678,10 @@ void kernel_mul_mv_iq1_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); @@ -8774,10 +8777,10 @@ void kernel_mul_mv_iq1_m_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); @@ -8883,10 +8886,10 @@ void kernel_mul_mv_iq4_nl_f32_impl( const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); @@ -8992,10 +8995,10 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); @@ -9103,10 +9106,10 @@ void kernel_mul_mv_mxfp4_f32_impl( const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); @@ -9321,6 +9324,10 @@ kernel void kernel_diag_f32( constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; +constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]]; +constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]]; +constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]]; +constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]]; // each block_q contains 16*nl weights #ifdef GGML_METAL_HAS_TENSOR @@ -9347,11 +9354,11 @@ kernel void kernel_mul_mm( // Batch dimension handling const int im = tgpig.z; - const int i12 = im % args.ne12; - const int i13 = im / args.ne12; + const int i12 = im % FC_mul_mm_ne12; + const int i13 = im / FC_mul_mm_ne12; // Batch offsets for srcA and srcB - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03; // Tile dimensions constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; @@ -9490,10 +9497,10 @@ kernel void kernel_mul_mm( short il = il0; - const int i12 = im%args.ne12; - const int i13 = im/args.ne12; + const int i12 = im % FC_mul_mm_ne12; + const int i13 = im / FC_mul_mm_ne12; - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03; const short offset1 = il0/nl; device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1; From 20895abdbd5f076118be9f8ddd0170448a399e79 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 12 May 2026 04:41:58 -0500 Subject: [PATCH 607/831] vulkan: Check shared memory size for mmq shaders (llama/22693) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 168 ++++++++++++++++++++++++--- 1 file changed, 149 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 7e450a559dd..90ea7cc1a9b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -681,6 +681,15 @@ struct vk_device_struct { bool mul_mat_id_m[GGML_TYPE_COUNT]; bool mul_mat_id_s[GGML_TYPE_COUNT]; + // Separate flags for the q8_1 (integer dot) mmq path, whose shader uses + // a different shared-memory layout than the float matmul shaders. + bool mul_mat_l_int[GGML_TYPE_COUNT]; + bool mul_mat_m_int[GGML_TYPE_COUNT]; + bool mul_mat_s_int[GGML_TYPE_COUNT]; + bool mul_mat_id_l_int[GGML_TYPE_COUNT]; + bool mul_mat_id_m_int[GGML_TYPE_COUNT]; + bool mul_mat_id_s_int[GGML_TYPE_COUNT]; + vk::DescriptorSetLayout dsl; vk_matmul_pipeline pipeline_matmul_f32 {}; @@ -3207,6 +3216,70 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec return supported; } +// Shmem usage for the q8_1 mmq shader (mul_mmq.comp), which uses +// block_a_cache / block_b_cache layouts (see mul_mmq_shmem_types.glsl) rather +// than the float load buffers checked by ggml_vk_matmul_shmem_support. +// Sizes follow std430 rules. Returns false for types without a q8_1 pipeline. +static bool ggml_vk_matmul_int_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { + + // FLOAT_TYPE in the shader is float16_t with fp16 support, otherwise float. + const uint32_t fp_size = device->fp16 ? 2u : 4u; + const uint32_t fp_align = fp_size; + const uint32_t fp2_size = 2u * fp_size; + const uint32_t fp2_align = device->fp16 ? 4u : 8u; + + struct member { uint32_t size, align; }; + auto std430_size = [](std::initializer_list members) { + uint32_t off = 0, struct_align = 1; + for (const auto &m : members) { + off = (off + m.align - 1) & ~(m.align - 1); + off += m.size; + struct_align = std::max(struct_align, m.align); + } + return (off + struct_align - 1) & ~(struct_align - 1); + }; + + uint32_t block_a_size = 0; + switch (src0_type) { + case GGML_TYPE_Q4_0: block_a_size = std430_size({{16, 4}, {fp_size, fp_align}}); break; // qs[16/4] + dm + case GGML_TYPE_Q4_1: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + dm(vec2) + case GGML_TYPE_Q5_0: block_a_size = std430_size({{16, 4}, {4, 4}, {fp_size, fp_align}}); break; // qs[16/4] + qh + dm + case GGML_TYPE_Q5_1: block_a_size = std430_size({{16, 4}, {4, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + qh + dm(vec2) + case GGML_TYPE_Q8_0: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + dm + case GGML_TYPE_MXFP4: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + d + case GGML_TYPE_Q2_K: block_a_size = std430_size({{ 8, 4}, {2, 2}, {fp2_size, fp2_align}}); break; // qs[2] + scales(u8vec2) + dm(vec2) + case GGML_TYPE_Q3_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + d_scales(vec2) + case GGML_TYPE_Q4_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + dm(vec2) + case GGML_TYPE_Q5_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + dm(vec2) + case GGML_TYPE_Q6_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + d_scales(vec2) + default: + return false; + } + + // block_b_cache: { int32_t qs[8]; FLOAT_TYPEV2 ds; } + const uint32_t block_b_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); + + const uint32_t BM = warptile[1]; + const uint32_t BN = warptile[2]; + // mul_mmq.comp: BK_STEP=1 for MUL_MAT_ID, 4 otherwise. + const uint32_t BK_STEP = mul_mat_id ? 1u : 4u; + + const uint32_t buf_a_size = BM * BK_STEP * block_a_size; + const uint32_t buf_b_size = BN * BK_STEP * block_b_size; + const uint32_t mmid_row_ids = mul_mat_id ? (BN * 2u * (uint32_t)sizeof(uint16_t)) : 0u; + + const uint32_t warps = warptile[0] / warptile[10]; + const uint32_t ballots_sh = mul_mat_id ? (warps * 4u * (uint32_t)sizeof(uint32_t)) : 0u; + + const uint32_t total_size = buf_a_size + buf_b_size + mmid_row_ids + ballots_sh; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_matmul_int_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " + "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", total=" << total_size << ", supported=" << supported); + + return supported; +} + struct GpuPipelineConfig { // GPU architecture identifier. // Example: vk_device_architecture::AMD_GCN @@ -3453,6 +3526,40 @@ static void ggml_vk_load_shaders(vk_device& device) { } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) { device->mul_mat_id_l[i] = false; } + + // The q8_1 mmq path has its own (larger) shmem layout, check it separately. + // K-quants use the _int_k warptiles, others use _int. + const bool is_k_quant = (t == GGML_TYPE_Q2_K || t == GGML_TYPE_Q3_K || + t == GGML_TYPE_Q4_K || t == GGML_TYPE_Q5_K || + t == GGML_TYPE_Q6_K); + const auto & s_int = is_k_quant ? s_warptile_mmq_int_k : s_warptile_mmq_int; + const auto & m_int = is_k_quant ? m_warptile_mmq_int_k : m_warptile_mmq_int; + const auto & l_int = is_k_quant ? l_warptile_mmq_int_k : l_warptile_mmq_int; + const auto & s_intid = is_k_quant ? s_warptile_mmqid_int_k : s_warptile_mmqid_int; + const auto & m_intid = is_k_quant ? m_warptile_mmqid_int_k : m_warptile_mmqid_int; + const auto & l_intid = is_k_quant ? l_warptile_mmqid_int_k : l_warptile_mmqid_int; + + if (!ggml_vk_matmul_int_shmem_support(device, s_int, false, t)) { + device->mul_mat_s_int[i] = false; + device->mul_mat_m_int[i] = false; + device->mul_mat_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, m_int, false, t)) { + device->mul_mat_m_int[i] = false; + device->mul_mat_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, l_int, false, t)) { + device->mul_mat_l_int[i] = false; + } + + if (!ggml_vk_matmul_int_shmem_support(device, s_intid, true, t)) { + device->mul_mat_id_s_int[i] = false; + device->mul_mat_id_m_int[i] = false; + device->mul_mat_id_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, m_intid, true, t)) { + device->mul_mat_id_m_int[i] = false; + device->mul_mat_id_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, l_intid, true, t)) { + device->mul_mat_id_l_int[i] = false; + } } } @@ -5613,6 +5720,13 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mul_mat_id_s[i] = true; break; } + + device->mul_mat_l_int[i] = true; + device->mul_mat_m_int[i] = true; + device->mul_mat_s_int[i] = true; + device->mul_mat_id_l_int[i] = true; + device->mul_mat_id_m_int[i] = true; + device->mul_mat_id_s_int[i] = true; } @@ -7220,6 +7334,13 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + // The q8_1 (integer dot) mmq path uses a different shader with its own + // shared-memory layout, so use the int-specific availability flags. + const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1); + const bool mm_l = is_q8_1 ? ctx->device->mul_mat_l_int[src0_type] : ctx->device->mul_mat_l[src0_type]; + const bool mm_m = is_q8_1 ? ctx->device->mul_mat_m_int[src0_type] : ctx->device->mul_mat_m[src0_type]; + const bool mm_s = is_q8_1 ? ctx->device->mul_mat_s_int[src0_type] : ctx->device->mul_mat_s[src0_type]; + if (ctx->device->coopmat2) { const uint32_t shader_core_count = ctx->device->shader_core_count; const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]); @@ -7236,26 +7357,24 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, // split_k==3 with large tiles likely better than medium tiles with no split_k. (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2); - if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { + if ((mm_l && (n > crossover_large && prefer_large)) || (!mm_m && !mm_s)) { return aligned ? mmp->a_l : mmp->l; } // Use medium shader when the N dimension is greater than the small shader's tile size uint32_t crossover_medium = mmp->s->wg_denoms[1]; - if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) { + if ((mm_m && (n > crossover_medium)) || !mm_s) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { + if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) { + if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; - - GGML_UNUSED(src1_type); } static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { @@ -7312,35 +7431,42 @@ static void ggml_vk_matmul( ctx->prealloc_split_k_need_sync = true; } -static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + + // The q8_1 (integer dot) mmq path uses a different shader with its own + // shared-memory layout, so use the int-specific availability flags. + const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1); + const bool mm_l = is_q8_1 ? ctx->device->mul_mat_id_l_int[src0_type] : ctx->device->mul_mat_id_l[src0_type]; + const bool mm_m = is_q8_1 ? ctx->device->mul_mat_id_m_int[src0_type] : ctx->device->mul_mat_id_m[src0_type]; + const bool mm_s = is_q8_1 ? ctx->device->mul_mat_id_s_int[src0_type] : ctx->device->mul_mat_id_s[src0_type]; if (ctx->device->coopmat2) { // Use large shader when the N dimension is greater than the medium shader's tile size uint32_t crossover_large = mmp->m->wg_denoms[1]; - if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { + if ((mm_l && (n > crossover_large)) || (!mm_m && !mm_s)) { return aligned ? mmp->a_l : mmp->l; } // Use medium shader when the N dimension is greater than the small shader's tile size uint32_t crossover_medium = mmp->s->wg_denoms[1]; - if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) { + if ((mm_m && (n > crossover_medium)) || !mm_s) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { + if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) { + if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; } -static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); - return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align; +static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align; } static void ggml_vk_matmul_id( @@ -7636,10 +7762,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type))); + const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type); + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, effective_src1_type)); const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type); if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); @@ -8471,10 +8599,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); + const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type); + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type, effective_src1_type)); const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type); if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); From be5a35cceebee8b70c75ced22336b4ae4a8882af Mon Sep 17 00:00:00 2001 From: Masato Nakasaka Date: Tue, 12 May 2026 03:15:34 -0700 Subject: [PATCH 608/831] vulkan: Fix Windows performance regression on Intel GPU BF16 workloads for Xe2 and newer (llama/22461) * refactor * Use l_warptile only when coopamt is available for BF16 --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 90ea7cc1a9b..a0a556206d5 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4260,11 +4260,6 @@ static void ggml_vk_load_shaders(vk_device& device) { m_wg_denoms = { 64, 64, 1 }; s_wg_denoms = { 32, 32, 1 }; - if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) { - // Xe2/Xe3 - bf16 warptile performance tuning - l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 }; - } - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } @@ -5689,19 +5684,19 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; - case VK_VENDOR_ID_INTEL: - if (!device->coopmat_support || device->architecture != INTEL_XE2) { - device->mul_mat_l[i] = false; - device->mul_mat_id_l[i] = false; - } else { - device->mul_mat_l[i] = true; // if coopmat & XE2+, allow large matmul warptile config for Intel - device->mul_mat_id_l[i] = true; - } + case VK_VENDOR_ID_INTEL: { + // Current Windows driver does not expose BF16 support. + // We only want to use l_warptile if coopmat is available and is Xe2+ + const bool xe2_with_coopmat = device->coopmat_support && device->architecture == INTEL_XE2; + const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && xe2_with_coopmat) : xe2_with_coopmat; + device->mul_mat_l[i] = use_l_warptile; + device->mul_mat_id_l[i] = use_l_warptile; device->mul_mat_m[i] = true; device->mul_mat_s[i] = true; device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; + } case VK_VENDOR_ID_APPLE: device->mul_mat_l[i] = false; device->mul_mat_m[i] = true; From a9bcbf559577c0a637819a704ad36091f0953fec Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Tue, 12 May 2026 10:27:04 -0400 Subject: [PATCH 609/831] ggml-webgpu: address precision issues for multimodal (llama/22808) * fix(mixed-types): use f32 for precision and update the shared memory calculation logic for f32 * fix(unary): correct the gelu, gelu quick and gelu erf functions * fix(flash-attn-tile): fix the hardcode v type * fix(flash_attn): fix tile path * fix: pass editorconfig and address the type conflicts * fix: remove reduant pipeline keys * fix: remove inline min/max group size functions and revert the flash attn path order * fix: use clamp to avoid NaN for GELU * fix: use the right range for exp, 80 is safer for f32 exp --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 193 ++++++++++++------ ggml/src/ggml-webgpu/ggml-webgpu.cpp | 30 ++- .../wgsl-shaders/flash_attn_tile.wgsl | 87 +++++--- .../wgsl-shaders/flash_attn_vec_reduce.wgsl | 10 +- .../wgsl-shaders/flash_attn_vec_split.wgsl | 112 +++++----- ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 49 +++-- 6 files changed, 295 insertions(+), 186 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index c6dc2c21147..932a01d385e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -91,6 +91,7 @@ struct ggml_webgpu_shader_lib_context { uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; uint32_t sg_mat_k = 0; + uint32_t min_subgroup_size = 0; uint32_t max_subgroup_size = 0; }; @@ -531,7 +532,9 @@ enum ggml_webgpu_flash_attn_path : uint32_t { }; struct ggml_webgpu_flash_attn_pipeline_key { + ggml_type q_type; ggml_type kv_type; + ggml_type dst_type; uint32_t head_dim_qk; uint32_t head_dim_v; bool kv_direct; @@ -542,16 +545,19 @@ struct ggml_webgpu_flash_attn_pipeline_key { uint32_t path; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { - return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && - kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && - has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap && path == other.path; + return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && + head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && + kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks && + uses_logit_softcap == other.uses_logit_softcap && path == other.path; } }; struct ggml_webgpu_flash_attn_pipeline_key_hash { size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const { size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.q_type); ggml_webgpu_hash_combine(seed, key.kv_type); + ggml_webgpu_hash_combine(seed, key.dst_type); ggml_webgpu_hash_combine(seed, key.head_dim_qk); ggml_webgpu_hash_combine(seed, key.head_dim_v); ggml_webgpu_hash_combine(seed, key.kv_direct); @@ -595,14 +601,14 @@ inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_ } inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( - const ggml_webgpu_shader_lib_context & context, - uint32_t path) { + const ggml_webgpu_shader_lib_context & context, + const ggml_webgpu_flash_attn_decisions & decisions) { const bool has_mask = context.src3 != nullptr; const bool has_sinks = context.src4 != nullptr; bool kv_direct = false; - if (path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + if (decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH; - if (path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { + if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { kv_direct_align = context.sg_mat_k; } kv_direct = (context.src1->type == GGML_TYPE_F16) && @@ -611,7 +617,9 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_ } ggml_webgpu_flash_attn_pipeline_key key = {}; + key.q_type = context.src0->type; key.kv_type = context.src1->type; + key.dst_type = context.dst->type; key.head_dim_qk = (uint32_t) context.src0->ne[0]; key.head_dim_v = (uint32_t) context.src2->ne[0]; key.kv_direct = kv_direct; @@ -619,13 +627,14 @@ inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_ key.has_mask = has_mask; key.has_sinks = has_sinks; key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; - key.path = path; + key.path = decisions.path; return key; } struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { - uint32_t head_dim_v; - uint32_t wg_size; + uint32_t head_dim_v; + uint32_t wg_size; + ggml_type dst_type; }; struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash { @@ -633,13 +642,14 @@ struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.head_dim_v); ggml_webgpu_hash_combine(seed, key.wg_size); + ggml_webgpu_hash_combine(seed, key.dst_type); return seed; } }; inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs, const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) { - return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size; + return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size && lhs.dst_type == rhs.dst_type; } struct ggml_webgpu_flash_attn_blk_pipeline_key { @@ -662,19 +672,32 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t head_dim_qk, uint32_t head_dim_v, bool has_mask, - bool kv_direct) { + bool kv_direct, + uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); size_t f16_elems = 0; size_t f32_elems = 0; - f16_elems += q_tile * head_dim_qk; // q_shmem + if (path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { + f32_elems += head_dim_qk; // q_shmem + if (!kv_direct) { + f32_elems += kv_tile * max_head_dim; // kv_shmem + } + f32_elems += head_dim_v; // o_shmem + if (has_mask) { + f32_elems += kv_tile; // mask_shmem + } + f32_elems += kv_tile; // inter_shmem + return f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; + } + f32_elems += q_tile * head_dim_qk; // q_shmem if (!kv_direct) { - f16_elems += kv_tile * max_head_dim; // kv_shmem + f32_elems += kv_tile * max_head_dim; // kv_shmem } - f16_elems += q_tile * head_dim_v; // o_shmem + f32_elems += q_tile * head_dim_v; // o_shmem if (has_mask) { - f16_elems += q_tile * kv_tile; // mask_shmem + f32_elems += q_tile * kv_tile; // mask_shmem } - f16_elems += q_tile * kv_tile; // inter_shmem + f32_elems += q_tile * kv_tile; // inter_shmem f32_elems += q_tile; // row_max_shmem f32_elems += q_tile; // exp_sum_shmem return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; @@ -684,27 +707,27 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ const ggml_webgpu_flash_attn_pipeline_key & key) { const size_t limit_bytes = context.wg_mem_limit_bytes; uint32_t q_tile = context.sg_mat_m; - uint32_t kv_granularity = context.sg_mat_n; + uint32_t kv_granularity = std::max(1u, context.sg_mat_n); if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; - kv_granularity = std::max(1u, context.max_subgroup_size); + kv_granularity = 1u; } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { q_tile = 1u; kv_granularity = 8u; } - const size_t base_q_bytes = (key.head_dim_qk + key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!key.kv_direct) { - bytes_per_kv += std::max(key.head_dim_qk, key.head_dim_v); + const size_t base_q_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, key.head_dim_qk, key.head_dim_v, + key.has_mask, key.kv_direct, key.path); + if (limit_bytes <= base_q_bytes) { + return 0; } - if (key.has_mask) { - bytes_per_kv += q_tile; + const size_t one_kv_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, key.head_dim_qk, key.head_dim_v, + key.has_mask, key.kv_direct, key.path); + const size_t bytes_per_kv = one_kv_bytes - base_q_bytes; + if (bytes_per_kv == 0) { + return 0; } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / kv_granularity) * kv_granularity; + const size_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity); } inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( @@ -731,14 +754,18 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && - (context.src2->type == K->type); + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && + (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + (context.src2->type == K->type); + const bool tile_can_dispatch_all_q_rows = + context.max_subgroup_size > 0 && + context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned && (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && !use_vec; + (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && + tile_can_dispatch_all_q_rows && !use_vec; decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : @@ -749,7 +776,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( return decisions; } - const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); + const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); decisions.kv_direct = key.kv_direct; const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); // invalidate if even the smallest kv_tile doesn't fit in shared memory @@ -778,21 +805,20 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( std::min(64u, max_kv_tile) : std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE : + std::min(std::max(1u, context.max_wg_size), + std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)) : std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - const uint32_t tile_kv_granularity = std::max(1u, context.max_subgroup_size); - decisions.kv_tile = - std::max(tile_kv_granularity, (decisions.kv_tile / tile_kv_granularity) * tile_kv_granularity); + if (decisions.kv_tile == 0) { + return decisions; } if (decisions.kv_direct) { GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { - decisions.kv_tile -= decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::max(1u, context.max_subgroup_size) : - context.sg_mat_n; + decisions.kv_tile -= + decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? context.min_subgroup_size : context.sg_mat_n; } } return decisions; @@ -1577,7 +1603,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { @@ -1694,10 +1720,10 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_vec_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { @@ -1805,13 +1831,13 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_mul_mat_pipeline_key key = {}; - key.src0_type = context.src0->type; - key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; - key.use_subgroup_matrix = context.supports_subgroup_matrix; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + key.use_subgroup_matrix = context.supports_subgroup_matrix; auto it = mul_mat_fast_pipelines.find(key); if (it != mul_mat_fast_pipelines.end()) { @@ -2074,10 +2100,10 @@ class ggml_webgpu_shader_lib { key.src0_type = context.src0->type; key.src1_type = context.src1->type; key.n_experts = context.src0->ne[2]; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_id_pipelines.find(key); if (it != mul_mat_id_pipelines.end()) { @@ -2194,10 +2220,10 @@ class ggml_webgpu_shader_lib { key.src0_type = context.src0->type; key.src1_type = context.src1->type; key.n_experts = context.src0->ne[2]; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_id_vec_pipelines.find(key); if (it != mul_mat_id_vec_pipelines.end()) { @@ -2558,7 +2584,7 @@ class ggml_webgpu_shader_lib { const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment); GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE); - ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions.path); + ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); auto it = flash_attn_pipelines.find(key); if (it != flash_attn_pipelines.end()) { return it->second; @@ -2586,6 +2612,30 @@ class ggml_webgpu_shader_lib { } variant += std::string("_") + ggml_type_name(key.kv_type); + switch (key.q_type) { + case GGML_TYPE_F32: + defines.push_back("Q_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("Q_F16"); + break; + default: + GGML_ABORT("Unsupported Q type for flash attention shader"); + } + variant += std::string("_q") + ggml_type_name(key.q_type); + + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + break; + default: + GGML_ABORT("Unsupported dst type for flash attention shader"); + } + variant += std::string("_dst") + ggml_type_name(key.dst_type); + if (key.has_mask) { defines.push_back("MASK"); if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { @@ -2625,9 +2675,11 @@ class ggml_webgpu_shader_lib { shader_src = wgsl_flash_attn_vec_split; } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { shader_src = wgsl_flash_attn_tile; - defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size)); + defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u"); + defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v))); - variant += "_tile"; + variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" + + std::to_string(context.max_subgroup_size); } else { defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); @@ -2677,6 +2729,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {}; key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.dst_type = context.dst->type; key.wg_size = context.max_wg_size; auto it = flash_attn_vec_reduce_pipelines.find(key); if (it != flash_attn_vec_reduce_pipelines.end()) { @@ -2686,6 +2739,18 @@ class ggml_webgpu_shader_lib { std::vector defines; std::string variant = "flash_attn_vec_reduce"; + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + break; + default: + GGML_ABORT("Unsupported dst type for flash attention vec reduce shader"); + } + variant += std::string("_dst") + ggml_type_name(key.dst_type); + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); variant += std::string("_hsv") + std::to_string(key.head_dim_v); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 12f60a9900e..02414bfc8b6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -187,6 +187,7 @@ struct webgpu_capabilities { uint32_t sg_mat_k = 0; uint32_t subgroup_size = 0; + uint32_t min_subgroup_size = 0; uint32_t max_subgroup_size = 0; size_t memset_bytes_per_thread; }; @@ -1442,6 +1443,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; // Get or create pipeline @@ -1750,6 +1752,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); @@ -3469,6 +3472,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( @@ -3667,8 +3671,9 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { #endif ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; - // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. - // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. + // Runtime subgroup size can be any supported size in this range. Shaders + // that allocate per-lane register arrays must size them for the minimum. + ctx->webgpu_global_ctx->capabilities.min_subgroup_size = info.subgroupMinSize; ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize; // Initialize device std::vector required_features = { wgpu::FeatureName::ShaderF16 }; @@ -4024,11 +4029,14 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const shader_lib_ctx.dst = const_cast(op); shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.max_wg_size = + ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; shader_lib_ctx.wg_mem_limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( @@ -4040,9 +4048,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - const size_t min_bytes = - ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], - (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, + decisions.kv_direct, decisions.path); if (min_bytes > limit_bytes) { supports_op = false; } @@ -4050,9 +4058,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - const size_t min_bytes = - ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], - (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, + decisions.kv_direct, decisions.path); if (min_bytes > limit_bytes) { supports_op = false; } @@ -4063,9 +4071,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const supports_op = false; break; } - const size_t min_bytes = - ggml_webgpu_flash_attn_wg_mem_bytes(decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], - (uint32_t) src2->ne[0], has_mask, decisions.kv_direct); + const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( + decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, + decisions.kv_direct, decisions.path); if (min_bytes > limit_bytes) { supports_op = false; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index 37ea23b80c8..ae8036b9ac5 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -1,12 +1,33 @@ enable f16; enable subgroups; +#ifdef Q_F16 +#define Q_TYPE f16 +#else +#define Q_TYPE f32 +#endif + +#ifdef KV_F32 +#define KV_TYPE f32 +#else +#define KV_TYPE f16 +#endif + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + #define HEAD_DIM_QK 64 #define HEAD_DIM_V 64 #define KV_STAGE_STRIDE 64 #define Q_TILE 4 #define KV_TILE 64 #define WG_SIZE 128 +#ifndef MIN_SUBGROUP_SIZE +#define MIN_SUBGROUP_SIZE MAX_SUBGROUP_SIZE +#endif struct Params { offset_q: u32, @@ -41,13 +62,13 @@ struct Params { m1: f32, }; -@group(0) @binding(0) var Q: array; +@group(0) @binding(0) var Q: array; #ifdef KV_OVERLAP -@group(0) @binding(1) var K: array>; +@group(0) @binding(1) var K: array>; #define V K #else -@group(0) @binding(1) var K: array>; -@group(0) @binding(2) var V: array>; +@group(0) @binding(1) var K: array>; +@group(0) @binding(2) var V: array>; #endif #if defined(MASK) && defined(SINKS) @@ -92,17 +113,17 @@ struct Params { #endif #endif -@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(DST_BINDING) var dst: array>; @group(0) @binding(PARAMS_BINDING) var params: Params; const FLOAT_MIN: f32 = -1.0e9; const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u; const V_CHUNKS: u32 = HEAD_DIM_V / 4u; -const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE; -const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MAX_SUBGROUP_SIZE - 1u) / MAX_SUBGROUP_SIZE; +const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; +const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; -var q_shmem: array; -var kv_shmem: array; +var q_shmem: array; +var kv_shmem: array; var p_shmem: array; @compute @workgroup_size(WG_SIZE) @@ -158,10 +179,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_col = elem_idx % HEAD_DIM_QK; let head_q_row = q_row_start + q_tile_row; let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; - q_shmem[elem_idx] = f16(select( + q_shmem[elem_idx] = select( 0.0, - Q[global_q_row_offset + q_col] * params.scale, - head_q_row < params.seq_len_q)); + f32(Q[global_q_row_offset + q_col]) * params.scale, + head_q_row < params.seq_len_q); } workgroupBarrier(); @@ -192,10 +213,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; let k4 = K[k_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = k4.x; - kv_shmem[kv_off + 1u] = k4.y; - kv_shmem[kv_off + 2u] = k4.z; - kv_shmem[kv_off + 3u] = k4.w; + kv_shmem[kv_off + 0u] = f32(k4.x); + kv_shmem[kv_off + 1u] = f32(k4.y); + kv_shmem[kv_off + 2u] = f32(k4.z); + kv_shmem[kv_off + 3u] = f32(k4.w); } workgroupBarrier(); @@ -213,16 +234,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) { let q_off = q_base + chunk * 4u; let qv = vec4( - f32(q_shmem[q_off + 0u]), - f32(q_shmem[q_off + 1u]), - f32(q_shmem[q_off + 2u]), - f32(q_shmem[q_off + 3u])); + q_shmem[q_off + 0u], + q_shmem[q_off + 1u], + q_shmem[q_off + 2u], + q_shmem[q_off + 3u]); let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; let kv = vec4( - f32(kv_shmem[kv_off + 0u]), - f32(kv_shmem[kv_off + 1u]), - f32(kv_shmem[kv_off + 2u]), - f32(kv_shmem[kv_off + 3u])); + kv_shmem[kv_off + 0u], + kv_shmem[kv_off + 1u], + kv_shmem[kv_off + 2u], + kv_shmem[kv_off + 3u]); dot_val += dot(qv, kv); } #ifdef LOGIT_SOFTCAP @@ -264,10 +285,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; let v4 = V[v_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = v4.x; - kv_shmem[kv_off + 1u] = v4.y; - kv_shmem[kv_off + 2u] = v4.z; - kv_shmem[kv_off + 3u] = v4.w; + kv_shmem[kv_off + 0u] = f32(v4.x); + kv_shmem[kv_off + 1u] = f32(v4.y); + kv_shmem[kv_off + 2u] = f32(v4.z); + kv_shmem[kv_off + 3u] = f32(v4.w); } workgroupBarrier(); @@ -288,10 +309,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let p = p_shmem[subgroup_p_offset + kv_local]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; let v4 = vec4( - f32(kv_shmem[kv_off + 0u]), - f32(kv_shmem[kv_off + 1u]), - f32(kv_shmem[kv_off + 2u]), - f32(kv_shmem[kv_off + 3u])); + kv_shmem[kv_off + 0u], + kv_shmem[kv_off + 1u], + kv_shmem[kv_off + 2u], + kv_shmem[kv_off + 3u]); acc += p * v4; } out_regs[reg_idx] = acc; @@ -324,7 +345,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, continue; } let dst_vec_index = (row_base + chunk * 4u) >> 2u; - dst[dst_vec_index] = out_regs[reg_idx] * inv_exp_sum; + dst[dst_vec_index] = vec4(out_regs[reg_idx] * inv_exp_sum); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl index 9a0de82a56a..1091d744073 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl @@ -2,6 +2,12 @@ diagnostic(off, subgroup_uniformity); enable f16; enable subgroups; +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + // Default values #define HEAD_DIM_V 64 #define WG_SIZE 128 @@ -17,7 +23,7 @@ struct Params { }; @group(0) @binding(0) var tmp: array; -@group(0) @binding(1) var dst: array>; +@group(0) @binding(1) var dst: array>; @group(0) @binding(2) var params: Params; const FLOAT_MIN: f32 = -1.0e9; @@ -72,7 +78,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (thread == 0u) { let dst_vec_index = (row_base + elem_base) >> 2u; - dst[dst_vec_index] = vec4(sum_x, sum_y, sum_z, sum_w) * inv_s; + dst[dst_vec_index] = vec4(vec4(sum_x, sum_y, sum_z, sum_w) * inv_s); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index b1e234784a8..30ebbebe772 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -8,6 +8,18 @@ enable subgroups; #define KV_TYPE f16 #endif +#ifdef Q_F16 +#define Q_TYPE f16 +#else +#define Q_TYPE f32 +#endif + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + #define HEAD_DIM_QK 64 #define HEAD_DIM_V 64 @@ -89,7 +101,7 @@ struct Params { nwg: u32, }; -@group(0) @binding(0) var Q: array; +@group(0) @binding(0) var Q: array; #ifdef KV_OVERLAP #if defined(KV_Q4_0) || defined(KV_Q8_0) @group(0) @binding(1) var K: array; @@ -191,41 +203,41 @@ struct Params { @group(0) @binding(BLK_BINDING) var blk: array; #endif @group(0) @binding(TMP_BINDING) var tmp: array; -@group(0) @binding(DST_BINDING) var dst: array>; +@group(0) @binding(DST_BINDING) var dst: array>; @group(0) @binding(PARAMS_BINDING) var params: Params; // Just a very small float value. const FLOAT_MIN: f32 = -1.0e9; -var q_shmem: array; +var q_shmem: array; #ifndef KV_DIRECT const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); // we can reuse the same shmem for K and V since we only need one at a time -var kv_shmem: array; +var kv_shmem: array; #endif -var o_shmem: array; +var o_shmem: array; #ifdef MASK // storage for mask values -var mask_shmem: array; +var mask_shmem: array; #endif // note that we reuse the same storage for both since we only need one at a time -var inter_shmem: array; +var inter_shmem: array; // Storage for row max and exp sum during online softmax fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 { var v = select(FLOAT_MIN, - f32(inter_shmem[kv_idx]) * params.scale, + inter_shmem[kv_idx] * params.scale, kv_idx < KV_TILE); #ifdef LOGIT_SOFTCAP v = params.logit_softcap * tanh(v); #endif #ifdef MASK if (apply_mask) { - var mask_val = select(0.0, f32(mask_shmem[kv_idx]), kv_idx < KV_TILE); + var mask_val = select(0.0, mask_shmem[kv_idx], kv_idx < KV_TILE); v += select(mask_val, slope * mask_val, has_bias); } #endif @@ -289,10 +301,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, // load the single Q row into shared memory for (var elem_idx = local_id.x; elem_idx < HEAD_DIM_QK; elem_idx += WG_SIZE) { let global_q_row_offset = q_head_offset + q_row_start * params.stride_q1; - q_shmem[elem_idx] = f16(select( + q_shmem[elem_idx] = select( 0.0, - Q[global_q_row_offset + elem_idx], - q_row_start < params.seq_len_q)); + f32(Q[global_q_row_offset + elem_idx]), + q_row_start < params.seq_len_q); } for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { @@ -308,7 +320,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let blk_state = blk_state_local; let skip_tile = blk_state == 0u; for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) { - inter_shmem[elem_idx] = f16(0.0); + inter_shmem[elem_idx] = 0.0; } // load k tile into shared memory @@ -331,8 +343,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_packed = bitcast(vec2(q_0, q_1)); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d); + let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d); let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; kv_shmem[row_offset + idx] = q_lo; kv_shmem[row_offset + idx + 16u] = q_hi; @@ -359,7 +371,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_packed = bitcast(vec2(q_0, q_1)); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; + let q_val = f32(q_byte) * f32(d); let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; kv_shmem[row_offset + idx] = q_val; } @@ -377,10 +389,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; let vec_idx = (global_k_row_offset + k_col) >> 2u; let k4 = select(vec4(0.0), K[vec_idx], in_bounds); - kv_shmem[elem_idx + 0u] = f16(k4.x); - kv_shmem[elem_idx + 1u] = f16(k4.y); - kv_shmem[elem_idx + 2u] = f16(k4.z); - kv_shmem[elem_idx + 3u] = f16(k4.w); + kv_shmem[elem_idx + 0u] = f32(k4.x); + kv_shmem[elem_idx + 1u] = f32(k4.y); + kv_shmem[elem_idx + 2u] = f32(k4.z); + kv_shmem[elem_idx + 3u] = f32(k4.w); } #endif @@ -401,20 +413,20 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_off = i * 4u; let qv = vec4( - f32(q_shmem[q_off + 0u]), - f32(q_shmem[q_off + 1u]), - f32(q_shmem[q_off + 2u]), - f32(q_shmem[q_off + 3u])); + q_shmem[q_off + 0u], + q_shmem[q_off + 1u], + q_shmem[q_off + 2u], + q_shmem[q_off + 3u]); #ifdef KV_DIRECT let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); let kv = vec4(K[idx >> 2u]); #else let idx = kv_idx * HEAD_DIM_QK + (i * 4u); let kv = vec4( - f32(kv_shmem[idx + 0u]), - f32(kv_shmem[idx + 1u]), - f32(kv_shmem[idx + 2u]), - f32(kv_shmem[idx + 3u])); + kv_shmem[idx + 0u], + kv_shmem[idx + 1u], + kv_shmem[idx + 2u], + kv_shmem[idx + 3u]); #endif partial_sum += dot(qv, kv); } @@ -435,7 +447,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let sum_bcast = subgroupShuffle(sum, num_of_threads * ty); if (tx == 0u && kv_valid) { - inter_shmem[kv_idx] = f16(sum_bcast); + inter_shmem[kv_idx] = sum_bcast; } } } @@ -450,7 +462,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let global_k_col = kv_tile + elem_idx; let mask_in_bounds = q_row_start < params.seq_len_q && global_k_col < params.seq_len_kv; let mask_idx = mask_global_offset + global_k_col; - mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds); + mask_shmem[elem_idx] = select(0.0f, f32(mask[mask_idx]), mask_in_bounds); } } #else @@ -483,7 +495,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); total_exp_term += subgroupAdd(cur_p); if (kv_idx < KV_TILE) { - inter_shmem[kv_idx] = f16(cur_p); + inter_shmem[kv_idx] = cur_p; } } @@ -493,7 +505,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, exp_sum = exp_sum * cur_exp + total_exp_term; for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * cur_exp); + o_shmem[elem_idx] = o_shmem[elem_idx] * cur_exp; } } @@ -517,8 +529,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_packed = bitcast(vec2(q_0, q_1)); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d); + let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d); let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; kv_shmem[row_offset + idx] = q_lo; kv_shmem[row_offset + idx + 16u] = q_hi; @@ -545,7 +557,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let q_packed = bitcast(vec2(q_0, q_1)); for (var k = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; + let q_val = f32(q_byte) * f32(d); let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; kv_shmem[row_offset + idx] = q_val; } @@ -563,10 +575,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; let vec_idx = (global_v_row_offset + v_col) >> 2u; let v4 = select(vec4(0.0), V[vec_idx], in_bounds); - kv_shmem[elem_idx + 0u] = f16(v4.x); - kv_shmem[elem_idx + 1u] = f16(v4.y); - kv_shmem[elem_idx + 2u] = f16(v4.z); - kv_shmem[elem_idx + 3u] = f16(v4.w); + kv_shmem[elem_idx + 0u] = f32(v4.x); + kv_shmem[elem_idx + 1u] = f32(v4.y); + kv_shmem[elem_idx + 2u] = f32(v4.z); + kv_shmem[elem_idx + 3u] = f32(v4.w); } #endif @@ -589,17 +601,17 @@ fn main(@builtin(workgroup_id) wg_id: vec3, continue; } - let p = f32(inter_shmem[kv_idx]); + let p = inter_shmem[kv_idx]; #ifdef KV_DIRECT let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; let v4 = vec4(V[v_idx >> 2u]); #else let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u; let v4 = vec4( - f32(kv_shmem[v_idx + 0u]), - f32(kv_shmem[v_idx + 1u]), - f32(kv_shmem[v_idx + 2u]), - f32(kv_shmem[v_idx + 3u])); + kv_shmem[v_idx + 0u], + kv_shmem[v_idx + 1u], + kv_shmem[v_idx + 2u], + kv_shmem[v_idx + 3u]); #endif lo += p * v4; } @@ -630,10 +642,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (ty_pv == 0u) { let elem_base = vec_col * 4u; - o_shmem[elem_base + 0u] = f16(f32(o_shmem[elem_base + 0u]) + lo_x); - o_shmem[elem_base + 1u] = f16(f32(o_shmem[elem_base + 1u]) + lo_y); - o_shmem[elem_base + 2u] = f16(f32(o_shmem[elem_base + 2u]) + lo_z); - o_shmem[elem_base + 3u] = f16(f32(o_shmem[elem_base + 3u]) + lo_w); + o_shmem[elem_base + 0u] = o_shmem[elem_base + 0u] + lo_x; + o_shmem[elem_base + 1u] = o_shmem[elem_base + 1u] + lo_y; + o_shmem[elem_base + 2u] = o_shmem[elem_base + 2u] + lo_z; + o_shmem[elem_base + 3u] = o_shmem[elem_base + 3u] + lo_w; } } } @@ -660,7 +672,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, exp_sum = exp_sum * max_exp + sink_exp_sum; for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - o_shmem[elem_idx] = f16(f32(o_shmem[elem_idx]) * max_exp); + o_shmem[elem_idx] = o_shmem[elem_idx] * max_exp; } } workgroupBarrier(); @@ -681,7 +693,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, ); let dst_vec_index: u32 = (row_base + elem_base) >> 2u; - dst[dst_vec_index] = v; + dst[dst_vec_index] = vec4(v); } } else { let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + q_row_start; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index b8f1bca1284..8e34e1c9ca0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -50,10 +50,25 @@ struct Params { @group(0) @binding(PARAMS_BINDING) var params: Params; +fn erf_approx(x: TYPE) -> TYPE { + let x_f32 = f32(x); + let s = select(-1.0, 1.0, x_f32 >= 0.0); + let ax = abs(x_f32); + + let t = 1.0 / (1.0 + 0.3275911 * ax); + + let y = 1.0 - + (((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t + - 0.284496736) * t + 0.254829592) * t) * + exp(-ax * ax); + + return TYPE(s * y); +} + @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x >= params.ne) { - return; + return; } var i = gid.x; let ne2 = params.ne2; @@ -71,15 +86,13 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i1 = i / ne0; let i0 = i % ne0; - let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + - i2 * params.stride_src2 + i3 * params.stride_src3; + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + i2 * params.stride_src2 + i3 * params.stride_src3; #ifdef ABS let res = abs(src[params.offset_src + src_idx]); #endif #ifdef SGN - let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0), - src[params.offset_src + src_idx] > 0.0); + let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0), src[params.offset_src + src_idx] > 0.0); #endif #ifdef NEG let res = -src[params.offset_src + src_idx]; @@ -94,8 +107,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0); #endif #ifdef ELU - let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx], - src[params.offset_src + src_idx] > 0.0); + let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0); #endif #ifdef HARDSIGMOID let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); @@ -120,31 +132,16 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let res = TYPE(params.fill_val); #endif #ifdef HARDSWISH - let res = src[params.offset_src + src_idx] * - min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); + let res = src[params.offset_src + src_idx] * min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); #endif #ifdef GELU - let res = 0.5 * src[params.offset_src + src_idx] * - (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * - (src[params.offset_src + src_idx] + - 0.044715 * pow(src[params.offset_src + src_idx], 3.0)), - -9.010913, 9.010913))); + let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + tanh(clamp(0.7978845608028654 * (src[params.offset_src + src_idx] + 0.044715 * src[params.offset_src + src_idx] * src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), -9.010913, 9.010913))); #endif #ifdef GELU_QUICK - let res = src[params.offset_src + src_idx] * 0.5 * - (1.0 + tanh(clamp(0.79788456 * - (src[params.offset_src + src_idx] + - 0.044715 * src[params.offset_src + src_idx] * - src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), - -9.010913, 9.010913))); + let res = src[params.offset_src + src_idx] * (1.0 / (1.0 + exp(clamp(-1.702 * src[params.offset_src + src_idx], -80.0, 80.0)))); #endif #ifdef GELU_ERF - let res = 0.5 * src[params.offset_src + src_idx] * - (1.0 + tanh(clamp(0.79788456 * - (src[params.offset_src + src_idx] + - 0.044715 * src[params.offset_src + src_idx] * - src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), - -9.010913, 9.010913))); + let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + erf_approx(src[params.offset_src + src_idx] * 0.7071067811865476)); #endif #ifdef XIELU let val = f32(src[params.offset_src + src_idx]); From e8a7cd314fccab6dc2db3206a70f6f1a782031c2 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Tue, 12 May 2026 23:27:40 +0900 Subject: [PATCH 610/831] ggml-webgpu: Enables running gpt-oss-20b (llama/22906) * Enable to run gpt-oss-20b and refactor mulmat-q * disable test-backend-ops in ubuntu-24-webgpu --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 68 +++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 61 ++++- ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl | 64 +++++ .../wgsl-shaders/common_decls.tmpl | 7 + .../ggml-webgpu/wgsl-shaders/get_rows.wgsl | 21 ++ .../wgsl-shaders/mul_mat_decls.tmpl | 221 +++++++++++------- .../wgsl-shaders/mul_mat_vec_acc.tmpl | 42 ++++ 7 files changed, 392 insertions(+), 92 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 932a01d385e..11701e79433 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -495,6 +495,22 @@ struct ggml_webgpu_binary_pipeline_key_hash { } }; +/* Add_Id */ + +struct ggml_webgpu_add_id_pipeline_key { + bool inplace; + + bool operator==(const ggml_webgpu_add_id_pipeline_key & other) const { return inplace == other.inplace; } +}; + +struct ggml_webgpu_add_id_pipeline_key_hash { + size_t operator()(const ggml_webgpu_add_id_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + /** Unary **/ struct ggml_webgpu_unary_pipeline_key { @@ -1058,7 +1074,9 @@ class ggml_webgpu_shader_lib { std::unordered_map pad_pipelines; // circular/non-circular std::unordered_map - binary_pipelines; // type/op/inplace/overlap + binary_pipelines; // type/op/inplace/overlap/src_overlap + std::unordered_map + add_id_pipelines; // inplace std::unordered_map concat_pipelines; // type std::unordered_map @@ -1433,6 +1451,7 @@ class ggml_webgpu_shader_lib { case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: { // Quantized types using u32 buffers for portability. defines.push_back("SRC_TYPE=u32"); @@ -1451,6 +1470,7 @@ class ggml_webgpu_shader_lib { defines.push_back(type_upper + "_SCALE_MIN"); defines.push_back(type_upper + "_TABLES"); defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_LUT"); variant += "_"; variant += type_str; @@ -1460,7 +1480,7 @@ class ggml_webgpu_shader_lib { if (key.src_type == GGML_TYPE_Q1_0) { defines.push_back("BLOCK_SIZE=128u"); } else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || - key.src_type == GGML_TYPE_IQ4_NL) { + key.src_type == GGML_TYPE_IQ4_NL || key.src_type == GGML_TYPE_MXFP4) { defines.push_back("BLOCK_SIZE=32u"); } else if (key.src_type >= GGML_TYPE_Q2_K) { defines.push_back("BLOCK_SIZE=256u"); @@ -1774,6 +1794,9 @@ class ggml_webgpu_shader_lib { defines.push_back(type_upper + "_GRID"); defines.push_back(type_upper + "_TABLES"); break; + case GGML_TYPE_MXFP4: + defines.push_back(type_upper + "_LUT"); + break; default: break; } @@ -1908,6 +1931,9 @@ class ggml_webgpu_shader_lib { defines.push_back(type_upper + "_GRID"); defines.push_back(type_upper + "_TABLES"); break; + case GGML_TYPE_MXFP4: + defines.push_back(type_upper + "_LUT"); + break; default: break; } @@ -2042,6 +2068,7 @@ class ggml_webgpu_shader_lib { case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: { // Quantized types using u32 buffers for portability. defines.push_back("SRC0_TYPE=u32"); @@ -2169,6 +2196,9 @@ class ggml_webgpu_shader_lib { defines.push_back(type_upper + "_GRID"); defines.push_back(type_upper + "_TABLES"); break; + case GGML_TYPE_MXFP4: + defines.push_back(type_upper + "_LUT"); + break; default: break; } @@ -2286,6 +2316,9 @@ class ggml_webgpu_shader_lib { defines.push_back(type_upper + "_GRID"); defines.push_back(type_upper + "_TABLES"); break; + case GGML_TYPE_MXFP4: + defines.push_back(type_upper + "_LUT"); + break; default: break; } @@ -2503,6 +2536,37 @@ class ggml_webgpu_shader_lib { return binary_pipelines[key]; } + webgpu_pipeline get_add_id_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_add_id_pipeline_key key = {}; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + + auto it = add_id_pipelines.find(key); + if (it != add_id_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "add_id"; + const char * shader_src = wgsl_add_id; + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(shader_src, defines); + auto pipeline_decisions = std::make_shared(); + pipeline_decisions->wg_size = context.max_wg_size; + pipeline_decisions->inplace = key.inplace; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = pipeline_decisions; + add_id_pipelines[key] = pipeline; + return pipeline; + } + webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_concat_pipeline_key key = {}; key.type = context.dst->type; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 02414bfc8b6..b24101c78b0 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1411,8 +1411,6 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_Q3_K: case GGML_TYPE_Q2_K: case GGML_TYPE_Q1_0: - use_fast = true; - break; case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_XXS: @@ -1422,6 +1420,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: use_fast = true; break; default: @@ -2145,6 +2144,56 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } +static webgpu_encoded_op ggml_webgpu_add_id(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_add_id_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src2->nb[0] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + }; + + std::vector entries; + + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); + + if (!decisions->inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, dst)); + } + + uint32_t wg_x = 1; + uint32_t wg_y = 1; + uint32_t total_wg = ggml_nrows(dst); + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, @@ -2918,6 +2967,8 @@ static std::optional ggml_webgpu_encode(webgpu_context ctx, case GGML_OP_MUL: case GGML_OP_DIV: return ggml_webgpu_binary_op(ctx, src0, src1, node); + case GGML_OP_ADD_ID: + return ggml_webgpu_add_id(ctx, src0, src1, src2, node); case GGML_OP_CONCAT: return ggml_webgpu_concat(ctx, src0, src1, node); case GGML_OP_REPEAT: @@ -3867,6 +3918,7 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) { case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: return true; default: return false; @@ -3905,6 +3957,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && (src1->type == op->type); break; + case GGML_OP_ADD_ID: + supports_op = src0->type == GGML_TYPE_F32; + break; case GGML_OP_CONCAT: supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32); break; @@ -3962,6 +4017,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: supports_op = true; break; default: @@ -4001,6 +4057,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: supports_op = true; break; default: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl new file mode 100644 index 00000000000..2573926cb89 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl @@ -0,0 +1,64 @@ +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_ids: u32, + offset_dst: u32, + + nb01: u32, + nb02: u32, + nb11: u32, + nb20: u32, + nb21: u32, + + ne0: u32, + ne1: u32, + ne2: u32, +}; + +@group(0) @binding(0) var src0: array; // [n_embd, n_experts_used, n_token] +@group(0) @binding(1) var src1: array; // [n_embd, n_experts] +@group(0) @binding(2) var ids: array; // [n_experts_used, n_token] + +#ifdef INPLACE + +@group(0) @binding(3) +var params: Params; + +#else + +@group(0) @binding(3) +var dst: array; + +@group(0) @binding(4) +var params: Params; + +#endif + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3, + @builtin(local_invocation_id) local_id: vec3) { + + let wg_linear = wg_id.x + wg_id.y * num_wg.x; + + if (wg_linear < params.ne1 * params.ne2) { + let thread_id = local_id.x; + let i2 = wg_linear / params.ne1; + let i1 = wg_linear % params.ne1; + + let i11 = u32(ids[params.offset_ids + i1 * params.nb20 + i2 * params.nb21]); + + let src0_row = params.offset_src0 + i1 * params.nb01 + i2 * params.nb02; + let src1_row = params.offset_src1 + i11 * params.nb11; + let dst_row = params.offset_dst + i1 * params.ne0 + i2 * (params.ne0 * params.ne1); + + for (var i = thread_id;i < params.ne0; i += WG_SIZE) { +#ifdef INPLACE + src0[src0_row + i] = src0[src0_row + i] + src1[src1_row + i]; +#else + dst[dst_row + i] = src0[src0_row + i] + src1[src1_row + i]; +#endif + } + } + +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 14c045b0ba6..372ea79bf9d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -896,3 +896,10 @@ const kvalues_iq4nl = array( ); #endif + +#ifdef MXFP4_LUT +const kvalues_mxfp4 = array( + 0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12 +); +#endif + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index 5710cd35469..78d61a93d28 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -652,6 +652,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } #endif +#ifdef MXFP4 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_byte_base = (src_base + offset) * 17; + let eu8 = get_byte(load_u32_at_src(block_byte_base), 0); + let d = ldexp(1.0, i32(eu8) - 128); + for (var j: u32 = 0u; j < 4; j++) { + let q_byte_offset = block_byte_base + 1 + j * 4; + let q_packed = load_u32_at_src(q_byte_offset); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * d; + let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * d; + let dst_offset = dst_base + offset * 32 + j * 4 + k; + dst[dst_offset] = q_lo; + dst[dst_offset + 16u] = q_hi; + } + } +} +#endif + + @group(0) @binding(0) var src: array; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 51cf08f196f..eb2a8368f43 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -100,34 +100,37 @@ const BLOCK_SIZE_BYTES = 18u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; +const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; let tile_m = blck_idx / BLOCKS_K; let global_m = offset_m + tile_m; let block_k = blck_idx % BLOCKS_K; - let global_k = k_outer / BLOCK_SIZE + block_k; + let global_block_k = k_outer / BLOCK_SIZE + block_k; - if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let d = load_f16_at_src0(block_byte_base); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + // store NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + + let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - for (var k = 0u; k < 4u; k++) { + + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte(q_packed, k); let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - shmem[shmem_idx + j * 2 + k] = q_lo; - shmem[shmem_idx + j * 2 + k + 16u] = q_hi; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } } @@ -141,35 +144,38 @@ const BLOCK_SIZE_BYTES = 20u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; +const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; let tile_m = blck_idx / BLOCKS_K; let global_m = offset_m + tile_m; let block_k = blck_idx % BLOCKS_K; - let global_k = k_outer / BLOCK_SIZE + block_k; + let global_block_k = k_outer / BLOCK_SIZE + block_k; - if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let d = load_f16_at_src0(block_byte_base); let m = load_f16_at_src0(block_byte_base + 2u); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + // store NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + + let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - for (var k = 0u; k < 4u; k++) { + + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte(q_packed, k); let q_lo = f16(q_byte & 0xF) * d + m; let q_hi = f16((q_byte >> 4) & 0xF) * d + m; - shmem[shmem_idx + j * 2 + k] = q_lo; - shmem[shmem_idx + j * 2 + k + 16u] = q_hi; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } } @@ -178,52 +184,49 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #endif // INIT_SRC0_SHMEM_Q4_1 #ifdef INIT_SRC0_SHMEM_Q5_0 -// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block const BLOCK_SIZE = 32u; const BLOCK_SIZE_BYTES = 22u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. // tile_k is defined as 32u, so blocks_k ends up being 1 always override BLOCKS_K = TILE_K / BLOCK_SIZE; const NQ = 16u; -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights +const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; let tile_m = blck_idx / BLOCKS_K; let global_m = offset_m + tile_m; let block_k = blck_idx % BLOCKS_K; - let global_k = k_outer / BLOCK_SIZE + block_k; + let global_block_k = k_outer / BLOCK_SIZE + block_k; - if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let d = load_f16_at_src0(block_byte_base); let qh_packed = load_u32_at_src0(block_byte_base + 2u); - for (var j = 0u; j < 2; j++) { - let q_byte_offset = block_byte_base + 6u + 2u * (block_offset + j * 2u); + // store NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 6u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - let j_adjusted = j + (block_offset / 2u); - - - for (var k = 0u; k < 4u; k++) { + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte(q_packed, k); - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; + let byte_idx = block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP + k; + let qh_hi = (qh_packed >> (byte_idx + 12u)) & 0x10; let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; - let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; + let qh_lo = ((qh_packed >> byte_idx) << 4) & 0x10; let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d; - - shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight - shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } } @@ -232,54 +235,49 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 #endif // INIT_SRC0_SHMEM_Q5_0 #ifdef INIT_SRC0_SHMEM_Q5_1 -// 32 weights per block, each at 4 bits each = 32 * 4 = 128 bits / 16 = 8 f16s per block const BLOCK_SIZE = 32u; const BLOCK_SIZE_BYTES = 24u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -// tile_k is defined as 32u, so blocks_k ends up being 1 always override BLOCKS_K = TILE_K / BLOCK_SIZE; const NQ = 16u; -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 16 / 4 = 4 f16s per thread, each thread should handle 4 f16s * 4 weights per = 16 weights +const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; let tile_m = blck_idx / BLOCKS_K; let global_m = offset_m + tile_m; let block_k = blck_idx % BLOCKS_K; - let global_k = k_outer / BLOCK_SIZE + block_k; + let global_block_k = k_outer / BLOCK_SIZE + block_k; - if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let d = load_f16_at_src0(block_byte_base); let m = load_f16_at_src0(block_byte_base + 2u); let qh_packed = load_u32_at_src0(block_byte_base + 4u); - for (var j = 0u; j < 2; j++) { - - let q_byte_offset = block_byte_base + 8u + 2u * (block_offset + j * 2u); + // store NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 8u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - let j_adjusted = j + (block_offset / 2u); - - - for (var k = 0u; k < 4u; k++) { + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte(q_packed, k); - let qh_hi = (qh_packed >> (j_adjusted * 4 + k + 12)) & 0x10; - let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi)) * d + m; - let qh_lo = ((qh_packed >> (j_adjusted * 4 + k)) << 4) & 0x10; - let q_lo = (f16((q_byte & 0xF) | qh_lo)) * d + m; - - shmem[shmem_idx + j * 4u + k] = q_lo; // store first weight - shmem[shmem_idx + j * 4u + k + 16u] = q_hi; // store second weight + let byte_idx = block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP + k; + let qh_hi = (qh_packed >> (byte_idx + 12u)) & 0x10; + let q_hi = f16(((q_byte >> 4) & 0xF) | qh_hi) * d + m; + let qh_lo = ((qh_packed >> byte_idx) << 4) & 0x10; + let q_lo = f16((q_byte & 0xF) | qh_lo) * d + m; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } } @@ -293,33 +291,34 @@ const BLOCK_SIZE_BYTES = 34u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread +const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; let tile_m = blck_idx / BLOCKS_K; let global_m = offset_m + tile_m; let block_k = blck_idx % BLOCKS_K; - let global_k = k_outer / BLOCK_SIZE + block_k; + let global_block_k = k_outer / BLOCK_SIZE + block_k; - if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let d = load_f16_at_src0(block_byte_base); - for (var j = 0u; j < F16_PER_THREAD; j+=2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); + // store NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - for (var k = 0u; k < 4u; k++) { + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f16(q_byte) * d; - shmem[shmem_idx + j * 2 + k] = q_val; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val; } } } @@ -333,34 +332,35 @@ const BLOCK_SIZE_BYTES = 36u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const WEIGHTS_PER_F16 = 2u; // 2 8-bit weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; // 8 f16s per thread, 2 threads per block +const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; let tile_m = blck_idx / BLOCKS_K; let global_m = offset_m + tile_m; let block_k = blck_idx % BLOCKS_K; - let global_k = k_outer / BLOCK_SIZE + block_k; + let global_block_k = k_outer / BLOCK_SIZE + block_k; - if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_k; + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; let d = load_f16_at_src0(block_byte_base); let m = load_f16_at_src0(block_byte_base + 2u); - for (var j = 0u; j < F16_PER_THREAD; j+=2) { - let q_byte_offset = block_byte_base + 4u + 2u * (block_offset + j); + // store NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - for (var k = 0u; k < 4u; k++) { + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f16(q_byte) * d + m; - shmem[shmem_idx + j * 2 + k] = q_val; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val; } } } @@ -1163,3 +1163,48 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } #endif // INIT_SRC0_SHMEM_IQ3_S + +#ifdef INIT_SRC0_SHMEM_MXFP4 +const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 17u; +// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. +override BLOCKS_K = TILE_K/BLOCK_SIZE; +const NQ = 16u; +const BYTES_PER_THREAD = 8u; // NQ(16) weights uses 8 bytes of q +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let blck_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; + + let tile_m = blck_idx / BLOCKS_K; + let global_m = offset_m + tile_m; + let block_k = blck_idx % BLOCKS_K; + let global_block_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); + let e = ldexp(1.0, i32(eu8) - 128); + + // store NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + + let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; + let q_packed = load_u32_at_src0(q_byte_offset); + + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e; + let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo); + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi); + } + } + } + } +} +#endif // INIT_SRC0_SHMEM_MXFP4 diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl index 1f59bd14863..711c7e829d8 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -1389,3 +1389,45 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src return acc; } #endif + +#ifdef MUL_ACC_MXFP4 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 17 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % 4; + for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); + let e = ldexp(1.0, i32(eu8) - 128); + var row_sum = 0.0; + let q_packed = load_u32_at_src0(block_byte_base + 1u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e; + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif From 1caed1d2bae10279bc56261c2e03c3a86a27f4bb Mon Sep 17 00:00:00 2001 From: yzyyzyhhh <96101183+happyyzy@users.noreply.github.com> Date: Wed, 13 May 2026 04:10:37 +0800 Subject: [PATCH 611/831] opencl: add opt-in Adreno xmem F16xF32 GEMM for prefill (llama/22755) * ggml-opencl: add Adreno xmem F16xF32 GEMM for prefill * ggml-opencl: address Adreno xmem review comments * ggml-opencl: align xmem gemm kernel naming --------- Co-authored-by: Your Name --- ggml/src/ggml-opencl/CMakeLists.txt | 4 + ggml/src/ggml-opencl/ggml-opencl.cpp | 220 +++++++++++++++++ .../kernels/gemm_xmem_f16_f32_os8.cl | 233 ++++++++++++++++++ 3 files changed, 457 insertions(+) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 7edb3eb4e9c..0b39c011371 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -176,6 +176,10 @@ set(GGML_OPENCL_KERNELS flash_attn_f32 ) +if (GGML_OPENCL_USE_ADRENO_KERNELS) + list(APPEND GGML_OPENCL_KERNELS gemm_xmem_f16_f32_os8) +endif () + foreach (K ${GGML_OPENCL_KERNELS}) ggml_opencl_add_kernel(${K}) endforeach() diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 73a58f74a94..61bdc62cd10 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -407,6 +407,8 @@ struct ggml_backend_opencl_context { cl_bool non_uniform_workgroups; size_t image_max_buffer_size; + size_t image2d_max_width; + size_t image2d_max_height; cl_context context; cl_command_queue queue; @@ -420,6 +422,11 @@ struct ggml_backend_opencl_context { ggml_cl_buffer prealloc_src0; ggml_cl_buffer prealloc_src1; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + ggml_cl_buffer prealloc_adreno_xmem_const; + bool adreno_xmem_gemm_enabled = false; +#endif + // prealloc buffers for MoE router table preprocess bool toggle_reorder = false; ggml_cl_buffer prealloc_post_router; @@ -538,6 +545,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_f16_f32; cl_kernel kernel_mul_mat_f16_f32_l4; cl_kernel kernel_mul_mat_f16_f32_tiled; + cl_kernel kernel_adreno_xmem_pack_src_f32; + cl_kernel kernel_adreno_xmem_prepack_weight_f16; + cl_kernel kernel_gemm_xmem_f16_f32_os8; + cl_kernel kernel_adreno_xmem_store_dst_f32; cl_kernel kernel_mul_mm_f16_f32_kqv; cl_kernel kernel_mul_mm_f16_f32_kq; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; @@ -1554,6 +1565,32 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // gemm_xmem_f16_f32_os8 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_xmem_f16_f32_os8.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_xmem_f16_f32_os8.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_adreno_xmem_pack_src_f32 = + clCreateKernel(prog, "adreno_xmem_pack_src_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_adreno_xmem_prepack_weight_f16 = + clCreateKernel(prog, "adreno_xmem_prepack_weight_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_gemm_xmem_f16_f32_os8 = + clCreateKernel(prog, "kernel_gemm_xmem_f16_f32_os8", &err), err)); + CL_CHECK((backend_ctx->kernel_adreno_xmem_store_dst_f32 = + clCreateKernel(prog, "adreno_xmem_store_dst_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + // mul_mm_f32_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3473,6 +3510,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL); GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", backend_ctx->image_max_buffer_size); + clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_WIDTH, sizeof(size_t), &backend_ctx->image2d_max_width, NULL); + clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_HEIGHT, sizeof(size_t), &backend_ctx->image2d_max_height, NULL); + GGML_LOG_INFO("ggml_opencl: device max image2d size: %lu x %lu\n", backend_ctx->image2d_max_width, backend_ctx->image2d_max_height); + clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL); GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size); @@ -3511,6 +3552,16 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); #endif // GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + backend_ctx->adreno_xmem_gemm_enabled = getenv("GGML_OPENCL_ADRENO_XMEM_GEMM") != nullptr && + backend_ctx->gpu_family == GPU_FAMILY::ADRENO; + if (getenv("GGML_OPENCL_ADRENO_XMEM_GEMM") != nullptr) { + GGML_LOG_INFO("ggml_opencl: Adreno xmem F16xF32 GEMM %s\n", + backend_ctx->adreno_xmem_gemm_enabled ? + "enabled (temporary weight prepack)" : "requested but unsupported by this driver"); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + // determine whether to use large buffer for Adreno backend_ctx->adreno_use_large_buffer = getenv("GGML_OPENCL_ADRENO_USE_LARGE_BUFFER") != nullptr && backend_ctx->gpu_family == GPU_FAMILY::ADRENO; @@ -9920,6 +9971,169 @@ static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_ten backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); } +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS +static bool ggml_cl_can_use_adreno_xmem_gemm_f16_f32( + const ggml_backend_opencl_context * backend_ctx, + const ggml_tensor * src0, + const ggml_tensor * src1, + const ggml_tensor * dst) { + if (!backend_ctx->adreno_xmem_gemm_enabled) { + return false; + } + if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { + return false; + } + if (src0->type != GGML_TYPE_F16 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + return false; + } + if (src0->ne[2] != 1 || src0->ne[3] != 1 || + src1->ne[2] != 1 || src1->ne[3] != 1 || + dst->ne[2] != 1 || dst->ne[3] != 1) { + return false; + } + const int K = src0->ne[0]; + const int M = src0->ne[1]; + const int N = src1->ne[1]; + if (src1->ne[0] != K || dst->ne[0] != M || dst->ne[1] != N) { + return false; + } + if (N <= 1 || M < 64 || N < 16 || K < 64) { + return false; + } + if ((K % 8) != 0) { + return false; + } + const int kpack = K / 4; + const int npack = CEIL_DIV(M, 4); + if (static_cast(N) > backend_ctx->image2d_max_width || + static_cast(kpack) > backend_ctx->image2d_max_height) { + return false; + } + if (static_cast(N) > backend_ctx->image2d_max_width || + static_cast(npack) > backend_ctx->image2d_max_height) { + return false; + } + return true; +} + +static void ggml_cl_mul_mat_f16_f32_adreno_xmem( + ggml_backend_t backend, + const ggml_tensor * src0, + const ggml_tensor * src1, + ggml_tensor * dst) { + ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + const cl_ulong offset0 = extra0->offset + src0->view_offs; + const cl_ulong offset1 = extra1->offset + src1->view_offs; + const cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int K = src0->ne[0]; + const int M = src0->ne[1]; + const int N = src1->ne[1]; + const int kpack = K / 4; + const int npack = CEIL_DIV(M, 4); + const int os = 8; + + const size_t xmem_bytes = 6144; + const size_t weight_bytes = static_cast(kpack) * static_cast(npack) * 4u * sizeof(cl_half4); + + backend_ctx->prealloc_adreno_xmem_const.allocate(backend_ctx->context, xmem_bytes); + + cl_int err = CL_SUCCESS; + cl_image_format fmt = {}; + fmt.image_channel_order = CL_RGBA; + fmt.image_channel_data_type = CL_HALF_FLOAT; + + cl_image_desc desc_src = {}; + desc_src.image_type = CL_MEM_OBJECT_IMAGE2D; + desc_src.image_width = static_cast(N); + desc_src.image_height = static_cast(kpack); + cl_mem src_img = clCreateImage(backend_ctx->context, CL_MEM_READ_WRITE, &fmt, &desc_src, nullptr, &err); + CL_CHECK(err); + + cl_image_desc desc_dst = {}; + desc_dst.image_type = CL_MEM_OBJECT_IMAGE2D; + desc_dst.image_width = static_cast(N); + desc_dst.image_height = static_cast(npack); + cl_mem dst_img = clCreateImage(backend_ctx->context, CL_MEM_READ_WRITE, &fmt, &desc_dst, nullptr, &err); + CL_CHECK(err); + + cl_mem weights = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, weight_bytes, nullptr, &err); + CL_CHECK(err); + + cl_kernel prepack = backend_ctx->kernel_adreno_xmem_prepack_weight_f16; + CL_CHECK(clSetKernelArg(prepack, 0, sizeof(cl_mem), &weights)); + CL_CHECK(clSetKernelArg(prepack, 1, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(prepack, 2, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(prepack, 3, sizeof(int), &K)); + CL_CHECK(clSetKernelArg(prepack, 4, sizeof(int), &M)); + CL_CHECK(clSetKernelArg(prepack, 5, sizeof(int), &kpack)); + CL_CHECK(clSetKernelArg(prepack, 6, sizeof(int), &npack)); + CL_CHECK(clSetKernelArg(prepack, 7, sizeof(int), &os)); + size_t lws = 256; + size_t max_wg = backend_ctx->get_kernel_workgroup_size(prepack); + if (lws > max_wg) { + lws = max_wg; + } + size_t gws = CEIL_DIV(static_cast(kpack) * static_cast(npack), lws) * lws; + backend_ctx->enqueue_ndrange_kernel(prepack, 1, &gws, &lws, dst); + + cl_kernel pack_src = backend_ctx->kernel_adreno_xmem_pack_src_f32; + CL_CHECK(clSetKernelArg(pack_src, 0, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(pack_src, 1, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(pack_src, 2, sizeof(cl_mem), &src_img)); + CL_CHECK(clSetKernelArg(pack_src, 3, sizeof(int), &K)); + CL_CHECK(clSetKernelArg(pack_src, 4, sizeof(int), &N)); + size_t pack_src_lws[2] = { 16, 16 }; + size_t pack_src_gws[2] = { + CEIL_DIV(static_cast(N), pack_src_lws[0])*pack_src_lws[0], + CEIL_DIV(static_cast(kpack), pack_src_lws[1])*pack_src_lws[1] + }; + backend_ctx->enqueue_ndrange_kernel(pack_src, 2, pack_src_gws, pack_src_lws, dst); + + cl_kernel gemm = backend_ctx->kernel_gemm_xmem_f16_f32_os8; + CL_CHECK(clSetKernelArg(gemm, 0, sizeof(cl_mem), &weights)); + CL_CHECK(clSetKernelArg(gemm, 1, sizeof(cl_mem), &backend_ctx->prealloc_adreno_xmem_const.buffer)); + CL_CHECK(clSetKernelArg(gemm, 2, sizeof(cl_mem), &src_img)); + CL_CHECK(clSetKernelArg(gemm, 3, sizeof(cl_mem), &dst_img)); + CL_CHECK(clSetKernelArg(gemm, 4, sizeof(int), &N)); + CL_CHECK(clSetKernelArg(gemm, 5, sizeof(int), &npack)); + CL_CHECK(clSetKernelArg(gemm, 6, sizeof(int), &kpack)); + const size_t z_values = CEIL_DIV(static_cast(npack), static_cast(os)); + size_t gemm_lws[3] = { 64, 1, 1 }; + size_t gemm_gws[3] = { + z_values*gemm_lws[0], + CEIL_DIV(static_cast(N), gemm_lws[0]), + 1 + }; + backend_ctx->enqueue_ndrange_kernel(gemm, 3, gemm_gws, gemm_lws, dst); + + cl_kernel store_dst = backend_ctx->kernel_adreno_xmem_store_dst_f32; + CL_CHECK(clSetKernelArg(store_dst, 0, sizeof(cl_mem), &dst_img)); + CL_CHECK(clSetKernelArg(store_dst, 1, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(store_dst, 2, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(store_dst, 3, sizeof(int), &M)); + CL_CHECK(clSetKernelArg(store_dst, 4, sizeof(int), &N)); + size_t store_lws[2] = { 16, 16 }; + size_t store_gws[2] = { + CEIL_DIV(static_cast(N), store_lws[0])*store_lws[0], + CEIL_DIV(static_cast(npack), store_lws[1])*store_lws[1] + }; + backend_ctx->enqueue_ndrange_kernel(store_dst, 2, store_gws, store_lws, dst); + + CL_CHECK(clReleaseMemObject(weights)); + CL_CHECK(clReleaseMemObject(dst_img)); + CL_CHECK(clReleaseMemObject(src_img)); +} +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_TENSOR_BINARY_OP_LOCALS; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -11681,6 +11895,12 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } case GGML_TYPE_F16: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (ggml_cl_can_use_adreno_xmem_gemm_f16_f32(backend_ctx, src0, src1, dst)) { + ggml_cl_mul_mat_f16_f32_adreno_xmem(backend, src0, src1, dst); + return; + } +#endif kernel = backend_ctx->kernel_mul_mm_f16_f32_l4_lm; nth0 = 128; // calculated as (BM*BN)/(TM*TN) diff --git a/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl b/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl new file mode 100644 index 00000000000..df9d9aed067 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl @@ -0,0 +1,233 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load : enable + +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + +__kernel void adreno_xmem_pack_src_f32( + __global const void * src_void, + ulong offset, + __write_only image2d_t src_img, + int K, + int N) { + const int x = get_global_id(0); + const int y = get_global_id(1); + const int kpack = K / 4; + + if (x >= N || y >= kpack) { + return; + } + + __global const float * src = (__global const float *)((__global const char *)src_void + offset); + const int base = x*K + y*4; + const half4 v = (half4)((half)src[base + 0], (half)src[base + 1], (half)src[base + 2], (half)src[base + 3]); + write_imageh(src_img, (int2)(x, y), v); +} + +__kernel void adreno_xmem_prepack_weight_f16( + __global half4 * dst, + __global const void * src_void, + ulong offset, + int K, + int M, + int kpack, + int npack, + int os) { + const int linear = get_global_id(0); + const int total = kpack*npack; + if (linear >= total) { + return; + } + + __global const half * src = (__global const half *)((__global const char *)src_void + offset); + + const int dst_ogroup = linear % os; + const int dst_o_sp_i = linear / os; + const int dst_i = dst_o_sp_i % kpack; + const int dst_o = dst_o_sp_i / kpack; + const int o_slice = dst_o*os + dst_ogroup; + const int k_base = dst_i*4; + + half4 w0 = (half4)(0.0h); + half4 w1 = (half4)(0.0h); + half4 w2 = (half4)(0.0h); + half4 w3 = (half4)(0.0h); + + const int o0 = o_slice*4 + 0; + const int o1 = o_slice*4 + 1; + const int o2 = o_slice*4 + 2; + const int o3 = o_slice*4 + 3; + + if (k_base + 0 < K) { + if (o0 < M) w0.s0 = src[o0*K + k_base + 0]; + if (o1 < M) w0.s1 = src[o1*K + k_base + 0]; + if (o2 < M) w0.s2 = src[o2*K + k_base + 0]; + if (o3 < M) w0.s3 = src[o3*K + k_base + 0]; + } + if (k_base + 1 < K) { + if (o0 < M) w1.s0 = src[o0*K + k_base + 1]; + if (o1 < M) w1.s1 = src[o1*K + k_base + 1]; + if (o2 < M) w1.s2 = src[o2*K + k_base + 1]; + if (o3 < M) w1.s3 = src[o3*K + k_base + 1]; + } + if (k_base + 2 < K) { + if (o0 < M) w2.s0 = src[o0*K + k_base + 2]; + if (o1 < M) w2.s1 = src[o1*K + k_base + 2]; + if (o2 < M) w2.s2 = src[o2*K + k_base + 2]; + if (o3 < M) w2.s3 = src[o3*K + k_base + 2]; + } + if (k_base + 3 < K) { + if (o0 < M) w3.s0 = src[o0*K + k_base + 3]; + if (o1 < M) w3.s1 = src[o1*K + k_base + 3]; + if (o2 < M) w3.s2 = src[o2*K + k_base + 3]; + if (o3 < M) w3.s3 = src[o3*K + k_base + 3]; + } + + dst[linear*4 + 0] = w0; + dst[linear*4 + 1] = w1; + dst[linear*4 + 2] = w2; + dst[linear*4 + 3] = w3; +} + +__attribute__((qcom_max_concurrent_subgroups(12))) +__kernel void kernel_gemm_xmem_f16_f32_os8( + __constant half8 * weights_buffer __attribute__((sub_group_uniform)), + __constant half8 * xmem_buffer __attribute__((max_constant_size((6144)))), + __read_only image2d_t src_img, + __write_only image2d_t dst_img, + int N, + int npack, + int kpack) { + const int X = get_group_id(1)*get_local_size(0) + get_local_id(0); + const int Z = get_group_id(0)*get_local_size(2) + get_local_id(2); + + if (X >= N || Z*8 >= npack) { + return; + } + + half4 r0 = (half4)(0.0h); + half4 r1 = (half4)(0.0h); + half4 r2 = (half4)(0.0h); + half4 r3 = (half4)(0.0h); + half4 r4 = (half4)(0.0h); + half4 r5 = (half4)(0.0h); + half4 r6 = (half4)(0.0h); + half4 r7 = (half4)(0.0h); + + int f_offset = Z*kpack*32; + int subgroup_id = (int)(0x1F & qcom_get_physical_sub_group_id()); + subgroup_id = subgroup_id % 12; + const int c_offset = subgroup_id*32; + __constant half16 * weights_cache = (__constant half16 *)&xmem_buffer[c_offset]; + + int coord_s = 0; + do { + const half4 src0 = read_imageh(src_img, smp_zero, (int2)(X, coord_s)); + coord_s++; + const half4 src1 = read_imageh(src_img, smp_zero, (int2)(X, coord_s)); + coord_s++; + + qcom_sub_group_constant_load8(xmem_buffer, weights_buffer, c_offset, f_offset >> 1, 32); + f_offset += 64; + qcom_sub_group_sync(QCOM_CLK_CONST_LOAD_SYNC); + + r0 += src0.x * weights_cache[0].s0123; + r0 += src0.y * weights_cache[0].s4567; + r0 += src0.z * weights_cache[0].s89ab; + r0 += src0.w * weights_cache[0].scdef; + r1 += src0.x * weights_cache[1].s0123; + r1 += src0.y * weights_cache[1].s4567; + r1 += src0.z * weights_cache[1].s89ab; + r1 += src0.w * weights_cache[1].scdef; + r2 += src0.x * weights_cache[2].s0123; + r2 += src0.y * weights_cache[2].s4567; + r2 += src0.z * weights_cache[2].s89ab; + r2 += src0.w * weights_cache[2].scdef; + r3 += src0.x * weights_cache[3].s0123; + r3 += src0.y * weights_cache[3].s4567; + r3 += src0.z * weights_cache[3].s89ab; + r3 += src0.w * weights_cache[3].scdef; + r4 += src0.x * weights_cache[4].s0123; + r4 += src0.y * weights_cache[4].s4567; + r4 += src0.z * weights_cache[4].s89ab; + r4 += src0.w * weights_cache[4].scdef; + r5 += src0.x * weights_cache[5].s0123; + r5 += src0.y * weights_cache[5].s4567; + r5 += src0.z * weights_cache[5].s89ab; + r5 += src0.w * weights_cache[5].scdef; + r6 += src0.x * weights_cache[6].s0123; + r6 += src0.y * weights_cache[6].s4567; + r6 += src0.z * weights_cache[6].s89ab; + r6 += src0.w * weights_cache[6].scdef; + r7 += src0.x * weights_cache[7].s0123; + r7 += src0.y * weights_cache[7].s4567; + r7 += src0.z * weights_cache[7].s89ab; + r7 += src0.w * weights_cache[7].scdef; + + r0 += src1.x * weights_cache[8].s0123; + r0 += src1.y * weights_cache[8].s4567; + r0 += src1.z * weights_cache[8].s89ab; + r0 += src1.w * weights_cache[8].scdef; + r1 += src1.x * weights_cache[9].s0123; + r1 += src1.y * weights_cache[9].s4567; + r1 += src1.z * weights_cache[9].s89ab; + r1 += src1.w * weights_cache[9].scdef; + r2 += src1.x * weights_cache[10].s0123; + r2 += src1.y * weights_cache[10].s4567; + r2 += src1.z * weights_cache[10].s89ab; + r2 += src1.w * weights_cache[10].scdef; + r3 += src1.x * weights_cache[11].s0123; + r3 += src1.y * weights_cache[11].s4567; + r3 += src1.z * weights_cache[11].s89ab; + r3 += src1.w * weights_cache[11].scdef; + r4 += src1.x * weights_cache[12].s0123; + r4 += src1.y * weights_cache[12].s4567; + r4 += src1.z * weights_cache[12].s89ab; + r4 += src1.w * weights_cache[12].scdef; + r5 += src1.x * weights_cache[13].s0123; + r5 += src1.y * weights_cache[13].s4567; + r5 += src1.z * weights_cache[13].s89ab; + r5 += src1.w * weights_cache[13].scdef; + r6 += src1.x * weights_cache[14].s0123; + r6 += src1.y * weights_cache[14].s4567; + r6 += src1.z * weights_cache[14].s89ab; + r6 += src1.w * weights_cache[14].scdef; + r7 += src1.x * weights_cache[15].s0123; + r7 += src1.y * weights_cache[15].s4567; + r7 += src1.z * weights_cache[15].s89ab; + r7 += src1.w * weights_cache[15].scdef; + } while (coord_s < kpack); + + int coord_s_out = Z*8; + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r0); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r1); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r2); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r3); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r4); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r5); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r6); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r7); } +} + +__kernel void adreno_xmem_store_dst_f32( + __read_only image2d_t dst_img, + __global void * dst_void, + ulong offset, + int M, + int N) { + const int x = get_global_id(0); + const int y = get_global_id(1); + const int npack = (M + 3) / 4; + + if (x >= N || y >= npack) { + return; + } + + __global float * dst = (__global float *)((__global char *)dst_void + offset); + const half4 hv = read_imageh(dst_img, smp_zero, (int2)(x, y)); + const int m = y*4; + if (m + 0 < M) dst[x*M + m + 0] = (float)hv.s0; + if (m + 1 < M) dst[x*M + m + 1] = (float)hv.s1; + if (m + 2 < M) dst[x*M + m + 2] = (float)hv.s2; + if (m + 3 < M) dst[x*M + m + 3] = (float)hv.s3; +} From bcaf4498269fea44b47621a4f0e8dff562c93cd4 Mon Sep 17 00:00:00 2001 From: Trivikram Reddy <127072883+trivikram-reddy1@users.noreply.github.com> Date: Tue, 12 May 2026 19:28:02 -0500 Subject: [PATCH 612/831] hexagon: eliminate scalar VTCM loads via HVX splat helpers (llama/22993) * hexagon: add hvx_vec_repl helpers and use those for splat-from-vtcm usecase * hmx-mm: optimize per-group scale handling * hmx-fa: optimize slope load from vtcm * hmx-fa: use aligned access where possible in hmx-utils * hexagon: add hvx_vec_repl_2x_f16 helper and consolidate repl helpers --------- Co-authored-by: Max Krasnyansky --- .../src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 5 +- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 29 +++----- ggml/src/ggml-hexagon/htp/hmx-utils.h | 34 ++++----- ggml/src/ggml-hexagon/htp/hvx-repl.h | 74 +++++++++++++++++++ ggml/src/ggml-hexagon/htp/hvx-utils.h | 1 + 5 files changed, 106 insertions(+), 37 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/hvx-repl.h diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index 8a6d7c14edf..4a4ff0b331d 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -760,8 +760,9 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { // ALiBi slopes — only needed when has_alibi (scheme A) HVX_Vector v_slope0, v_slope1; if (args->has_alibi) { - v_slope0 = hvx_vec_splat_f16(args->slopes[r + 0]); - v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_splat_f16(args->slopes[r + 1]) : Q6_V_vzero(); + HVX_Vector v_s = hvx_vmemu(args->slopes + r); + v_slope0 = hvx_vec_repl_f16(v_s); + v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_repl_f16(Q6_V_vror_VR(v_s, 2)) : Q6_V_vzero(); } const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c) diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 9e8c9966e04..e05ccfd5fc7 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -180,12 +180,10 @@ static int hmx_compute_chunks(size_t vtcm_total, // Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes. // In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles // of the same 32 packed bytes. -static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx( - const uint8_t *packed_32, bool upper_nibbles, - const __fp16 *scale, const HVX_Vector vlut_cvt) { +static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { HVX_Vector vq = hvx_vmemu(packed_32); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_scales = hvx_vec_splat_f16(*scale); + HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); // q4x4x2 stores two int4 values per byte. Keep only the selected nibble. HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); @@ -223,9 +221,10 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16] // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b - HVX_VectorPred q64 = Q6_Q_vsetq_R(64); - HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[0]), hvx_vec_splat_f16(scales_4[1])); - HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, hvx_vec_splat_f16(scales_4[2]), hvx_vec_splat_f16(scales_4[3])); + volatile HVX_Vector vscale = hvx_vmemu(scales_4); + + HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); + HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); @@ -237,10 +236,10 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { - HVX_Vector vq = hvx_vmemu(quants_32); - HVX_Vector v_scales = hvx_vec_splat_f16(*scale); - HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq)); - HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); + HVX_Vector vq = hvx_vmemu(quants_32); + HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); + HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq)); + HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); } @@ -521,12 +520,8 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( const uint8_t *r0 = vtcm_src + row0 * row_stride; const uint8_t *r1 = vtcm_src + row1 * row_stride; - HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx( - (const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off)); - HVX_Vector v1 = (row1 < n_cols) - ? dequantize_x4x2_q8_0_group_hvx( - (const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) - : Q6_V_vzero(); + HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off)); + HVX_Vector v1 = (row1 < n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero(); Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); diff --git a/ggml/src/ggml-hexagon/htp/hmx-utils.h b/ggml/src/ggml-hexagon/htp/hmx-utils.h index 68f174d6937..f448ee3372a 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hmx-utils.h @@ -77,16 +77,18 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst, const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); - __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; - const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); - const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; + const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); + const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + + assert(hex_is_aligned(p0, 128)); + assert(hex_is_aligned(p1, 128)); + assert(c_byte_step % 128 == 0); if (p1) { for (int i = 0; i < n_c_iters; ++i) { - HVX_Vector v0 = hvx_vmemu(p0); - p0 += c_byte_step; - HVX_Vector v1 = hvx_vmemu(p1); - p1 += c_byte_step; + HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step; + HVX_Vector v1 = hvx_vmem(p1); p1 += c_byte_step; Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0); Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1); tile_base += dst_step; @@ -94,8 +96,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst, } else { const HVX_Vector vzero = Q6_V_vzero(); for (int i = 0; i < n_c_iters; ++i) { - HVX_Vector v0 = hvx_vmemu(p0); - p0 += c_byte_step; + HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step; Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0); Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero); tile_base += dst_step; @@ -116,16 +117,14 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst, const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); - __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; - const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); - const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; + const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); + const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; if (p1) { for (int i = 0; i < n_c_iters; ++i) { - HVX_Vector v0 = hvx_vmemu(p0); - p0 += c_byte_step; - HVX_Vector v1 = hvx_vmemu(p1); - p1 += c_byte_step; + HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step; + HVX_Vector v1 = hvx_vmemu(p1); p1 += c_byte_step; Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0); Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1); tile_base += dst_step; @@ -133,8 +132,7 @@ static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst, } else { const HVX_Vector vzero = Q6_V_vzero(); for (int i = 0; i < n_c_iters; ++i) { - HVX_Vector v0 = hvx_vmemu(p0); - p0 += c_byte_step; + HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step; Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0); Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero); tile_base += dst_step; diff --git a/ggml/src/ggml-hexagon/htp/hvx-repl.h b/ggml/src/ggml-hexagon/htp/hvx-repl.h new file mode 100644 index 00000000000..fdc7e6c7d2f --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-repl.h @@ -0,0 +1,74 @@ +#ifndef HVX_REPL_H +#define HVX_REPL_H + +#include +#include +#include + +#include "hvx-base.h" + +static inline HVX_Vector hvx_vec_repl(HVX_Vector v, const uint8_t * ctrl) { + return Q6_V_vdelta_VV(v, hvx_vmem(ctrl)); +} + +static inline HVX_Vector hvx_vec_repl_u32(HVX_Vector v) { + // vdelta control to replicate first 4 bytes across all lanes + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + return hvx_vec_repl(v, repl); +} + +static inline HVX_Vector hvx_vec_repl_f32(HVX_Vector v) { + // vdelta control to replicate first 4 bytes across all lanes + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + return hvx_vec_repl(v, repl); +} + +static inline HVX_Vector hvx_vec_repl_f16(HVX_Vector v) { + // vdelta control to replicate first two bytes across all lanes + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + }; + return hvx_vec_repl(v, repl); +} + +static inline HVX_Vector hvx_vec_repl_2x_f16(HVX_Vector v) { + // vdelta control to splat a pair of f16s: first half = f16[0], second half = f16[1] + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, + 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, + 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, + 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, + }; + return hvx_vec_repl(v, repl); +} + +#endif // HVX_REPL_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index a518ad37331..e0452811ec3 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -5,6 +5,7 @@ #include "hvx-types.h" #include "hvx-copy.h" +#include "hvx-repl.h" #include "hvx-scale.h" #include "hvx-exp.h" #include "hvx-inverse.h" From 8b288f5d966e3b3fb048097e991bdb815e79caf6 Mon Sep 17 00:00:00 2001 From: Sachin Sharma Date: Wed, 13 May 2026 11:43:47 +0530 Subject: [PATCH 613/831] ggml-zendnn : adaptive fallback to CPU backend for small batch sizes (llama/22681) * ggml-zendnn : add runtime env var GGML_ZENDNN_ADAPTIVE_FALLBACK to control adaptive fallback (default: enabled) * ggml-zendnn : restore original fallback logic when adaptive fallback is disabled --- ggml/src/ggml-zendnn/CMakeLists.txt | 2 +- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 27 +++++++++++++++++++++++---- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt index 4f321a25257..f1e4f991fae 100644 --- a/ggml/src/ggml-zendnn/CMakeLists.txt +++ b/ggml/src/ggml-zendnn/CMakeLists.txt @@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") ExternalProject_Add( zendnn GIT_REPOSITORY https://github.com/amd/ZenDNN.git - GIT_TAG f79f7321a1add65ced6397a6bfab7edba6e3e14e # ZenDNN-2026-WW13 + GIT_TAG ac9e580d9434b7b98985f2627a7ebfb5eba4bb0d # ZenDNN-2026-WW17 PREFIX ${ZENDNN_PREFIX} SOURCE_DIR ${ZENDNN_SOURCE_DIR} BINARY_DIR ${ZENDNN_BUILD_DIR} diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 2b82c7c1dbb..6a83bb6b1ec 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -47,6 +47,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int params.dtypes.dst = ggml_to_zendnn_type(); params.num_threads = ctx->n_threads; + zendnnl::lowoha::matmul::matmul_batch_params_t batch_params; zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct( 'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major) n, // M: rows of B and C @@ -59,7 +60,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int 0.0f, // beta C, ldc, // output C[n,m] true, // is_weights_const - {}, // batch_params + batch_params, // batch_params params // params ); @@ -520,6 +521,12 @@ static ggml_backend_buffer_t ggml_backend_zendnn_device_buffer_from_host_ptr(ggm GGML_UNUSED(max_tensor_size); } +static bool ggml_zendnn_adaptive_fallback_enabled() { + static const bool enabled = std::getenv("GGML_ZENDNN_ADAPTIVE_FALLBACK") == nullptr || + std::atoi(std::getenv("GGML_ZENDNN_ADAPTIVE_FALLBACK")) != 0; + return enabled; +} + static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { switch (op->op) { case GGML_OP_NONE: @@ -538,12 +545,24 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const const int64_t ne10 = inputs->ne[0]; const int64_t ne0 = op->ne[0]; const int64_t ne1 = op->ne[1]; - const int64_t min_batch = 1; - if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) || - ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) { + + if(!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs)) { + return false; + } + + if (ggml_zendnn_adaptive_fallback_enabled()) { + const int64_t K = inputs->ne[0]; + const int64_t N = (inputs->ne[1]*inputs->ne[2]*inputs->ne[3]); + const int64_t M = weights->ne[1]; + if(K <= 256 || N <= 128 || M <= 96) { return false; + } } + else if (ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) { + return false; + } + // MUL_MAT_ID performs best with a moderate number of experts due to its // gather + batched matmul + scatter approach. Future versions will leverage // ZenDNN's grouped_gemm for better scalability with larger expert counts: From cb7d38bf18f763efa2bc794c44e648808aa8064c Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 13 May 2026 06:59:28 -0700 Subject: [PATCH 614/831] hexagon: add unary tanh op (llama/22999) --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 2 ++ ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 1 + ggml/src/ggml-hexagon/htp/unary-ops.c | 22 +++++++++++++++++++++- 4 files changed, 25 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index d3c125dbc3d..3d1c9da8329 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2865,6 +2865,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_UNARY_OP_NEG: return HTP_OP_UNARY_NEG; case GGML_UNARY_OP_EXP: return HTP_OP_UNARY_EXP; case GGML_UNARY_OP_SOFTPLUS: return HTP_OP_UNARY_SOFTPLUS; + case GGML_UNARY_OP_TANH: return HTP_OP_UNARY_TANH; default: break; } @@ -3335,6 +3336,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SIGMOID: case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_TANH: supp = ggml_hexagon_supported_unary(sess, op); break; case GGML_UNARY_OP_SILU: diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 6203e3848b9..98db864dd42 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -62,6 +62,7 @@ enum htp_op_code { HTP_OP_UNARY_EXP, HTP_OP_UNARY_NEG, HTP_OP_UNARY_SOFTPLUS, + HTP_OP_UNARY_TANH, HTP_OP_GLU_SWIGLU, HTP_OP_GLU_SWIGLU_OAI, HTP_OP_GLU_GEGLU, diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index fa1e0698f4a..883a31d6163 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -542,6 +542,7 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_UNARY_SIGMOID: case HTP_OP_UNARY_NEG: case HTP_OP_UNARY_EXP: + case HTP_OP_UNARY_TANH: case HTP_OP_L2_NORM: return op_unary(octx); diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 26a0e0bd793..d4ae89ee6f0 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -373,6 +373,21 @@ static void l2_norm_f32(const float * restrict src, } } +static void tanh_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_tanh_f32_aa(dst_local, src_local, row_elems); + } +} + static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { const struct htp_unary_context * uctx = (const struct htp_unary_context *) data; struct htp_ops_context * octx = uctx->octx; @@ -477,6 +492,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_UNARY_SOFTPLUS: softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_UNARY_TANH: + tanh_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; case HTP_OP_L2_NORM: l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; @@ -547,10 +565,12 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_UNARY_SOFTPLUS: op_type = "softplus-f32"; break; + case HTP_OP_UNARY_TANH: + op_type = "tanh-f32"; + break; case HTP_OP_L2_NORM: op_type = "l2norm-f32"; break; - default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); return HTP_STATUS_NO_SUPPORT; From 1cbbd0b6d09147657b51f04bd83e48b8db2f6711 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Thu, 14 May 2026 02:22:44 +0900 Subject: [PATCH 615/831] flush the gpu profile timestamp before the queryset is overflowed (llama/22995) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index b24101c78b0..401c75c1230 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3148,6 +3148,16 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str } ctx->param_arena.reset(); commands.clear(); +#ifdef GGML_WEBGPU_GPU_PROFILE + // flush before the next batch can overflow the QuerySet + if (ctx->profile_timestamp_query_count + 2 * ctx->global_ctx->command_submit_batch_size >= + WEBGPU_MAX_PROFILE_QUERY_COUNT) { + ggml_backend_webgpu_collect_profile_results(ctx, profile_pipeline_names, num_inflight_batches); + // reset profile timestamp state + ctx->profile_timestamp_query_count = 0; + profile_pipeline_names.clear(); + } +#endif } node_idx += num_encoded_ops; From b19beb6027f45ae5a413554de05ff6896b3a5540 Mon Sep 17 00:00:00 2001 From: lhez Date: Wed, 13 May 2026 11:24:33 -0700 Subject: [PATCH 616/831] opencl: fix crash when warming up MoE on Adreno (llama/22876) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 61bdc62cd10..248124c2896 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -13132,7 +13132,7 @@ static void moe_router_reoerder(ggml_backend_t backend, const ggml_tensor * src, CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne02)); size_t histogram_global_size[] = {(size_t)(((ne21 + 63) / 64) * 64), static_cast(ne20), 1}; - size_t histogram_local_size[] = {64, static_cast(ne20), 1}; + size_t histogram_local_size[] = {64, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src); // Scan From d4a4d87f0ef6511c1e5fec36a3f84d6710f83c33 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Wed, 13 May 2026 11:57:31 -0700 Subject: [PATCH 617/831] opencl: add q5_0 and q5_1 MoE for Adreno (llama/22985) * opencl: add q5_0 moe support * opencl: add q5_1 moe support * opencl: avoid potential leak * opencl: suppress unused var warning when building for non-Adreno --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 4 + ggml/src/ggml-opencl/ggml-opencl.cpp | 1019 +++++++++++++++-- ggml/src/ggml-opencl/kernels/cvt.cl | 204 ++++ .../kernels/gemm_moe_q5_0_f32_ns.cl | 256 +++++ .../kernels/gemm_moe_q5_1_f32_ns.cl | 258 +++++ .../kernels/gemv_moe_q5_0_f32_ns.cl | 119 ++ .../kernels/gemv_moe_q5_1_f32_ns.cl | 121 ++ 7 files changed, 1914 insertions(+), 67 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 0b39c011371..c6aba608736 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -106,6 +106,10 @@ set(GGML_OPENCL_KERNELS gemv_moe_q4_0_f32_ns gemm_moe_q4_1_f32_ns gemv_moe_q4_1_f32_ns + gemm_moe_q5_0_f32_ns + gemv_moe_q5_0_f32_ns + gemm_moe_q5_1_f32_ns + gemv_moe_q5_1_f32_ns gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 gemm_moe_mxfp4_f32_ns diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 248124c2896..0e511592d53 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -556,6 +556,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_0_trans4_ns, kernel_restore_block_q4_0_trans4_ns; cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns; + cl_kernel kernel_convert_block_q5_0_trans4_ns, kernel_restore_block_q5_0_trans4_ns; + cl_kernel kernel_convert_block_q5_1_trans4_ns, kernel_restore_block_q5_1_trans4_ns; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; @@ -615,6 +617,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_timestep_embedding; cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns; cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns; + cl_kernel kernel_gemv_moe_q5_0_f32_ns, kernel_gemm_moe_q5_0_f32_ns; + cl_kernel kernel_gemv_moe_q5_1_f32_ns, kernel_gemm_moe_q5_1_f32_ns; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; cl_kernel kernel_moe_reorder_b; @@ -973,6 +977,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err)); @@ -2995,6 +3003,74 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemv_moe_q5_0_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q5_0_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q5_0_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q5_0_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q5_0_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q5_0_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q5_0_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q5_0_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q5_0_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q5_0_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_moe_q5_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q5_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q5_1_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q5_1_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q5_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q5_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q5_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q5_1_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q5_1_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q5_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemv_moe_mxfp4_f32_ns { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3852,6 +3928,122 @@ struct ggml_tensor_extra_cl_q4_1 { } }; +struct ggml_tensor_extra_cl_q5_0 { + // Quantized values. + cl_mem qs = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem qs_img = nullptr; + // 5-th bit values. + cl_mem qh = nullptr; + // 5-th bit values in image1d_buffer_t. + cl_mem qh_img = nullptr; + // Scales. + cl_mem d = nullptr; + // Scales in image1d_buffer_t. + cl_mem d_img = nullptr; + // Size of quantized values. + size_t size_qs = 0; + // Size of 5-th bit values. + size_t size_qh = 0; + // Size of scales. + size_t size_d = 0; + + ~ggml_tensor_extra_cl_q5_0() { + reset(); + } + + void reset() { + if (qs != nullptr) { + CL_CHECK(clReleaseMemObject(qs)); + qs = nullptr; + } + if (qh != nullptr) { + CL_CHECK(clReleaseMemObject(qh)); + qh = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (qs_img != nullptr) { + CL_CHECK(clReleaseMemObject(qs_img)); + qs_img = nullptr; + } + + qh_img = nullptr; + d_img = nullptr; + size_qs = 0; + size_qh = 0; + size_d = 0; + } +}; + +struct ggml_tensor_extra_cl_q5_1 { + // Quantized values. + cl_mem qs = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem qs_img = nullptr; + // 5-th bit values. + cl_mem qh = nullptr; + // 5-th bit values in image1d_buffer_t. + cl_mem qh_img = nullptr; + // Scales. + cl_mem d = nullptr; + // Scales in image1d_buffer_t. + cl_mem d_img = nullptr; + // Min + cl_mem m = nullptr; + // Min in image1d_buffer_t. + cl_mem m_img = nullptr; + // Size of quantized values. + size_t size_qs = 0; + // Size of 5-th bit values. + size_t size_qh = 0; + // Size of scales. + size_t size_d = 0; + // Size of min values. + size_t size_m = 0; + + ~ggml_tensor_extra_cl_q5_1() { + reset(); + } + + void reset() { + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. + // They must be properly released so that the original buffer can be + // properly released to avoid memory leak. + if (qs != nullptr) { + CL_CHECK(clReleaseMemObject(qs)); + qs = nullptr; + } + if (qh != nullptr) { + CL_CHECK(clReleaseMemObject(qh)); + qh = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (m != nullptr) { + CL_CHECK(clReleaseMemObject(m)); + m = nullptr; + } + if (qs_img != nullptr) { + CL_CHECK(clReleaseMemObject(qs_img)); + qs_img = nullptr; + } + // qh_img, d_img, and m_img are not currently allocated separately. + // TODO: initialize them for non SMALL_PATH path, or remove them. + qh_img = nullptr; + d_img = nullptr; + m_img = nullptr; + size_qs = 0; + size_qh = 0; + size_d = 0; + size_m = 0; + } +}; + struct ggml_tensor_extra_cl_mxfp4 { // Quantized values. cl_mem q = nullptr; @@ -4506,7 +4698,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } // q4_0, q8_0 and mxfp4 have general MUL_MAT_ID support, // the quantizations here currently do not - they are only supported by Adreno with certain shapes - if (op->src[0]->type == GGML_TYPE_Q4_1) { + if (op->src[0]->type == GGML_TYPE_Q4_1 || + op->src[0]->type == GGML_TYPE_Q5_0 || + op->src[0]->type == GGML_TYPE_Q5_1) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (op->src[1]->type == GGML_TYPE_F32) { return use_adreno_moe_kernels(backend_ctx, op->src[0]) @@ -4692,6 +4886,18 @@ struct ggml_backend_opencl_buffer_context { for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) { delete e; } + for (ggml_tensor_extra_cl_q5_0 * e : temp_tensor_extras_q5_0) { + delete e; + } + for (ggml_tensor_extra_cl_q5_0 * e : temp_tensor_extras_q5_0_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_q5_1 * e : temp_tensor_extras_q5_1) { + delete e; + } + for (ggml_tensor_extra_cl_q5_1 * e : temp_tensor_extras_q5_1_in_use) { + delete e; + } for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4) { delete e; } @@ -4775,6 +4981,36 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_q5_0 * ggml_opencl_alloc_temp_tensor_extra_q5_0() { + ggml_tensor_extra_cl_q5_0 * extra; + if (temp_tensor_extras_q5_0.empty()) { + extra = new ggml_tensor_extra_cl_q5_0(); + } else { + extra = temp_tensor_extras_q5_0.back(); + temp_tensor_extras_q5_0.pop_back(); + } + + temp_tensor_extras_q5_0_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_q5_1 * ggml_opencl_alloc_temp_tensor_extra_q5_1() { + ggml_tensor_extra_cl_q5_1 * extra; + if (temp_tensor_extras_q5_1.empty()) { + extra = new ggml_tensor_extra_cl_q5_1(); + } else { + extra = temp_tensor_extras_q5_1.back(); + temp_tensor_extras_q5_1.pop_back(); + } + + temp_tensor_extras_q5_1_in_use.push_back(extra); + + extra->reset(); + return extra; + } + ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() { ggml_tensor_extra_cl_mxfp4 * extra; if (temp_tensor_extras_mxfp4.empty()) { @@ -4881,6 +5117,16 @@ struct ggml_backend_opencl_buffer_context { } temp_tensor_extras_q4_1_in_use.clear(); + for (ggml_tensor_extra_cl_q5_0 * e : temp_tensor_extras_q5_0_in_use) { + temp_tensor_extras_q5_0.push_back(e); + } + temp_tensor_extras_q5_0_in_use.clear(); + + for (ggml_tensor_extra_cl_q5_1 * e : temp_tensor_extras_q5_1_in_use) { + temp_tensor_extras_q5_1.push_back(e); + } + temp_tensor_extras_q5_1_in_use.clear(); + for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) { temp_tensor_extras_mxfp4.push_back(e); } @@ -4923,6 +5169,10 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_q4_0_in_use; std::vector temp_tensor_extras_q4_1; std::vector temp_tensor_extras_q4_1_in_use; + std::vector temp_tensor_extras_q5_0; + std::vector temp_tensor_extras_q5_0_in_use; + std::vector temp_tensor_extras_q5_1; + std::vector temp_tensor_extras_q5_1_in_use; std::vector temp_tensor_extras_mxfp4; std::vector temp_tensor_extras_mxfp4_in_use; std::vector temp_tensor_extras_q8_0; @@ -5286,17 +5536,18 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } - if (tensor->type == GGML_TYPE_MXFP4) { + if (tensor->type == GGML_TYPE_Q5_0) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); // Allocate the new extra and create aliases from the original. ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - ggml_tensor_extra_cl_mxfp4 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_mxfp4(); + ggml_tensor_extra_cl_q5_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q5_0(); - size_t size_e = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(char); - size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; - GGML_ASSERT(size_e + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_qs = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(int32_t); + GGML_ASSERT(size_d + size_qs + size_qh == ggml_nbytes(tensor) && "Incorrect tensor size"); cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, @@ -5306,40 +5557,48 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, queue, data_device, CL_TRUE, 0, ggml_nbytes(tensor), data, 0, NULL, NULL)); - // The original tensor memory is divided into scales and quants, i.e., - // we first store scales, then quants. cl_buffer_region region; // Create subbuffer for scales. region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); - region.size = size_e; - extra->e = clCreateSubBuffer( + region.size = size_d; + extra->d = clCreateSubBuffer( extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); auto previous_origin = region.origin; - // Create subbuffer for quants. - region.origin = align_to(previous_origin + size_e, backend_ctx->alignment); - region.size = size_q; - extra->q = clCreateSubBuffer( + // Create subbuffer for qh. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_qh; + extra->qh = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for qs. + region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment); + region.size = size_qs; + extra->qs = clCreateSubBuffer( extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - // Adreno moe mxfp4 kernel needs special transpose and unshuffling + // Adreno moe q5_0 kernel needs special transpose and unshuffling if (use_adreno_moe_kernels(backend_ctx, tensor)) { - cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans4_ns; + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0_trans4_ns; int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; int ne02 = tensor->ne[2]; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; size_t local_work_size[3] = {64, 2, 1}; @@ -5348,61 +5607,36 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); CL_CHECK(clWaitForEvents(1, &evt)); CL_CHECK(clReleaseMemObject(data_device)); - tensor->extra = extra; // Create image for Q - cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; - cl_image_desc img_desc_q = { + cl_image_format img_format_qs = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_qs = { CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ggml_nelements(tensor) / 8), 0, 0, 0, 0, 0, 0, 0, - { extra->q } + { extra->qs } }; - extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + extra->qs_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_qs, &img_desc_qs, NULL, &err); tensor->extra = extra; return; } - #endif // GGML_OPENCL_USE_ADRENO_KERNELS - cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); - - size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[3] = {64, 1, 1}; - - cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - CL_CHECK(clReleaseMemObject(data_device)); - - // Create image for Q - cl_image_format img_format_q = {CL_RG, CL_UNSIGNED_INT32}; - cl_image_desc img_desc_q = { - CL_MEM_OBJECT_IMAGE1D_BUFFER, - static_cast(ggml_nelements(tensor)/32*2), - 0, 0, 0, 0, 0, 0, 0, - { extra->q } - }; - extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); - tensor->extra = extra; - return; } - if (tensor->type == GGML_TYPE_Q8_0) { + if (tensor->type == GGML_TYPE_Q5_1) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); // Allocate the new extra and create aliases from the original. ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - ggml_tensor_extra_cl_q8_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q8_0(); + ggml_tensor_extra_cl_q5_1 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q5_1(); size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); - size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)*sizeof(char)); - GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + size_t size_m = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_qs = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(int32_t); + GGML_ASSERT(size_d + size_m + size_qs + size_qh == ggml_nbytes(tensor) && "Incorrect tensor size"); cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, @@ -5412,10 +5646,10 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, queue, data_device, CL_TRUE, 0, ggml_nbytes(tensor), data, 0, NULL, NULL)); - // The original tensor memory is divided into scales and quants, i.e., - // we first store scales, then quants. cl_buffer_region region; + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, mins, then quants. // Create subbuffer for scales. region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); region.size = size_d; @@ -5425,22 +5659,227 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(err); auto previous_origin = region.origin; - // Create subbuffer for quants. + // Create subbuffer for mins. region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); - region.size = size_q; - extra->q = clCreateSubBuffer( + region.size = size_m; + extra->m = clCreateSubBuffer( extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); + previous_origin = region.origin; - cl_kernel kernel = backend_ctx->kernel_convert_block_q8_0; + // Create subbuffer for qh. + region.origin = align_to(previous_origin + size_m, backend_ctx->alignment); + region.size = size_qh; + extra->qh = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + // Create subbuffer for qs. + region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment); + region.size = size_qs; + extra->qs = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[] = {64, 1, 1}; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe q5_1 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_qs = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_qs = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->qs } + }; + extra->qs_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_qs, &img_desc_qs, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + return; + } + if (tensor->type == GGML_TYPE_MXFP4) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_mxfp4 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_mxfp4(); + + size_t size_e = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(char); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_e + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, then quants. + cl_buffer_region region; + + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_e; + extra->e = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_e, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe mxfp4 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + tensor->extra = extra; + + // Create image for Q + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + + return; + } + +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); + + size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[3] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_q = {CL_RG, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor)/32*2), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + + return; + } + if (tensor->type == GGML_TYPE_Q8_0) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q8_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q8_0(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)*sizeof(char)); + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, then quants. + cl_buffer_region region; + + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q8_0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; cl_event evt; CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); @@ -6109,6 +6548,89 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + if (tensor->type == GGML_TYPE_Q5_0) { + ggml_tensor_extra_cl_q5_0 * extra = (ggml_tensor_extra_cl_q5_0 *)tensor->extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + // TODO: use ggml_cl_buffer to manage this temporary buffer + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_0_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + // TODO: normal q5_0 + (void) extra; + return; + } + if (tensor->type == GGML_TYPE_Q5_1) { + ggml_tensor_extra_cl_q5_1 * extra = (ggml_tensor_extra_cl_q5_1 *)tensor->extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + // TODO: use ggml_cl_buffer to manage this temporary buffer + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 32), static_cast(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + // TODO: normal q5_1 + (void) extra; + return; + } if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra; @@ -13209,10 +13731,17 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; + ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; + ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; #endif + // TODO: general MoE for the following types + (void)extra0_q4_1; + (void)extra0_q5_0; + (void)extra0_q5_1; + const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; @@ -13540,8 +14069,11 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, } else { // for gemm kernel = backend_ctx->kernel_gemm_moe_q4_1_f32_ns; - if (strstr(src0->name, "as") != NULL) { + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; } cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; @@ -13649,6 +14181,359 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, } return; } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q5_0: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q5_0_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q5_0_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->qs_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q5_1: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q5_1_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q5_1_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->qs_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } #endif //GGML_OPENCL_USE_ADRENO_KERNELS } case GGML_TYPE_Q8_0: { diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 5bbf09710f9..8f06d570587 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -56,6 +56,25 @@ struct block_q4_1 { uchar qs[QK4_1 / 2]; // nibbles / quants }; +//------------------------------------------------------------------------------ +// block_q5_0 +//------------------------------------------------------------------------------ +struct block_q5_0 { + half d; // delta + uchar qh[4]; // 5-th bit of quants + uchar qs[QK5_0 / 2]; // nibbles / quants +}; + +//------------------------------------------------------------------------------ +// block_q5_1 +//------------------------------------------------------------------------------ +struct block_q5_1 { + half d; // delta + half m; // min + uchar qh[4]; // 5-th bit of quants + uchar qs[QK5_1 / 2]; // nibbles / quants +}; + //------------------------------------------------------------------------------ // block_q4_k //------------------------------------------------------------------------------ @@ -460,6 +479,191 @@ kernel void kernel_restore_block_q4_1_trans4_ns( ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; } +kernel void kernel_convert_block_q5_0_trans4_ns( + __global struct block_q5_0 * src0, + __global uint * dst_qs, + __global uint * dst_qh, + __global half * dst_d, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK5_0; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_q5_0 * b = src0 + src_blk_offset; + dst_d[dst_blk_offset] = b->d; + + dst_qh[dst_blk_offset] = ((global uint *)(&(b->qh[0])))[0]; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK5_0 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK5_0 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_qs[offset] = q_block.x; + dst_qs[offset + ne01] = q_block.y; + dst_qs[offset + ne01 * 2] = q_block.z; + dst_qs[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_q5_0_trans4_ns( + __global uint * src_qs, + __global uint * src_qh, + __global half * src_d, + __global struct block_q5_0 * dst0, + uint ne00, + uint ne01 +) { + int i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK5_0; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q5_0 * b = dst0 + dst_blk_offset; + b->d = src_d[src_blk_offset]; + + ((__global uint *)(&(b->qh[0])))[0] = src_qh[src_blk_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_qs[src_q_offset]; + q_block.y = src_qs[src_q_offset + ne01]; + q_block.z = src_qs[src_q_offset + ne01 * 2]; + q_block.w = src_qs[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK5_0 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK5_0 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + +kernel void kernel_convert_block_q5_1_trans4_ns( + __global struct block_q5_1 * src0, + __global uint * dst_qs, + __global uint * dst_qh, + __global half * dst_d, + __global half * dst_m, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK5_1; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_q5_1 * b = src0 + src_blk_offset; + dst_d[dst_blk_offset] = b->d; + dst_m[dst_blk_offset] = b->m; + + dst_qh[dst_blk_offset] = ((global uint *)(&(b->qh[0])))[0]; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK5_1 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK5_1 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_qs[offset] = q_block.x; + dst_qs[offset + ne01] = q_block.y; + dst_qs[offset + ne01 * 2] = q_block.z; + dst_qs[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_q5_1_trans4_ns( + __global uint * src_qs, + __global uint * src_qh, + __global half * src_d, + __global half * src_m, + __global struct block_q5_1 * dst0, + uint ne00, + uint ne01 +) { + int i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK5_1; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q5_1 * b = dst0 + dst_blk_offset; + b->d = src_d[src_blk_offset]; + b->m = src_m[src_blk_offset]; + + ((__global uint *)(&(b->qh[0])))[0] = src_qh[src_blk_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_qs[src_q_offset]; + q_block.y = src_qs[src_q_offset + ne01]; + q_block.z = src_qs[src_q_offset + ne01 * 2]; + q_block.w = src_qs[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK5_1 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK5_1 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl new file mode 100644 index 00000000000..3524cb1bdbd --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl @@ -0,0 +1,256 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +#define dequantize_q5_0(qs5x16, qh5x16, a_f16, scale) \ + a_f16.s0 = (half)((( qs5x16.s0 & 0x000F) | (( qh5x16.s0 & 0x01) << 4)) - 16) * scale; \ + a_f16.s1 = (half)((((qs5x16.s0 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 1) & 0x01) << 4)) - 16) * scale; \ + a_f16.s2 = (half)((((qs5x16.s0 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 2) & 0x01) << 4)) - 16) * scale; \ + a_f16.s3 = (half)((((qs5x16.s0 & 0xF000) >> 12) | (((qh5x16.s0 >> 3) & 0x01) << 4)) - 16) * scale; \ + a_f16.s4 = (half)((( qs5x16.s1 & 0x000F) | (((qh5x16.s0 >> 4) & 0x01) << 4)) - 16) * scale; \ + a_f16.s5 = (half)((((qs5x16.s1 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 5) & 0x01) << 4)) - 16) * scale; \ + a_f16.s6 = (half)((((qs5x16.s1 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 6) & 0x01) << 4)) - 16) * scale; \ + a_f16.s7 = (half)((((qs5x16.s1 & 0xF000) >> 12) | (((qh5x16.s0 >> 7) & 0x01) << 4)) - 16) * scale; \ + a_f16.s8 = (half)((( qs5x16.s2 & 0x000F) | (( qh5x16.s1 & 0x01) << 4)) - 16) * scale; \ + a_f16.s9 = (half)((((qs5x16.s2 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 1) & 0x01) << 4)) - 16) * scale; \ + a_f16.sa = (half)((((qs5x16.s2 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 2) & 0x01) << 4)) - 16) * scale; \ + a_f16.sb = (half)((((qs5x16.s2 & 0xF000) >> 12) | (((qh5x16.s1 >> 3) & 0x01) << 4)) - 16) * scale; \ + a_f16.sc = (half)((( qs5x16.s3 & 0x000F) | (((qh5x16.s1 >> 4) & 0x01) << 4)) - 16) * scale; \ + a_f16.sd = (half)((((qs5x16.s3 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 5) & 0x01) << 4)) - 16) * scale; \ + a_f16.se = (half)((((qs5x16.s3 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 6) & 0x01) << 4)) - 16) * scale; \ + a_f16.sf = (half)((((qs5x16.s3 & 0xF000) >> 12) | (((qh5x16.s1 >> 7) & 0x01) << 4)) - 16) * scale; \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_q5_0_f32_ns( + __read_only image1d_buffer_t src0_qs, + __global uint * src0_qh, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale for current Q5_0 block + uint blk_offset = s_sub_offset + get_global_id(0); + half s = src0_d[blk_offset]; + + // Load 32 qh (5-th bit of each Q5) for the entire block + uchar4 qhx32 = as_uchar4(src0_qh[blk_offset]); + + // Load 16 qs (half block) in transposed layout + uint2 qsx16; + qsx16.x = read_imageui(src0_qs, q_sub_offset + sub_block_id_m).x; + qsx16.y = read_imageui(src0_qs, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_0(as_ushort4(qsx16), qhx32.lo, reg_a, s); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 qs in transposed layout + qsx16.x = read_imageui(src0_qs, q_sub_offset + sub_block_id_m).x; + qsx16.y = read_imageui(src0_qs, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_0(as_ushort4(qsx16), qhx32.hi, reg_a, s); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl new file mode 100644 index 00000000000..5fc2a523234 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl @@ -0,0 +1,258 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +#define dequantize_q5_1(qs5x16, qh5x16, a_f16, scale, m) \ + a_f16.s0 = (half)((( qs5x16.s0 & 0x000F) | (( qh5x16.s0 & 0x01) << 4)) * scale + m); \ + a_f16.s1 = (half)((((qs5x16.s0 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 1) & 0x01) << 4)) * scale + m); \ + a_f16.s2 = (half)((((qs5x16.s0 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 2) & 0x01) << 4)) * scale + m); \ + a_f16.s3 = (half)((((qs5x16.s0 & 0xF000) >> 12) | (((qh5x16.s0 >> 3) & 0x01) << 4)) * scale + m); \ + a_f16.s4 = (half)((( qs5x16.s1 & 0x000F) | (((qh5x16.s0 >> 4) & 0x01) << 4)) * scale + m); \ + a_f16.s5 = (half)((((qs5x16.s1 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 5) & 0x01) << 4)) * scale + m); \ + a_f16.s6 = (half)((((qs5x16.s1 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 6) & 0x01) << 4)) * scale + m); \ + a_f16.s7 = (half)((((qs5x16.s1 & 0xF000) >> 12) | (((qh5x16.s0 >> 7) & 0x01) << 4)) * scale + m); \ + a_f16.s8 = (half)((( qs5x16.s2 & 0x000F) | (( qh5x16.s1 & 0x01) << 4)) * scale + m); \ + a_f16.s9 = (half)((((qs5x16.s2 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 1) & 0x01) << 4)) * scale + m); \ + a_f16.sa = (half)((((qs5x16.s2 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 2) & 0x01) << 4)) * scale + m); \ + a_f16.sb = (half)((((qs5x16.s2 & 0xF000) >> 12) | (((qh5x16.s1 >> 3) & 0x01) << 4)) * scale + m); \ + a_f16.sc = (half)((( qs5x16.s3 & 0x000F) | (((qh5x16.s1 >> 4) & 0x01) << 4)) * scale + m); \ + a_f16.sd = (half)((((qs5x16.s3 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 5) & 0x01) << 4)) * scale + m); \ + a_f16.se = (half)((((qs5x16.s3 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 6) & 0x01) << 4)) * scale + m); \ + a_f16.sf = (half)((((qs5x16.s3 & 0xF000) >> 12) | (((qh5x16.s1 >> 7) & 0x01) << 4)) * scale + m); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_q5_1_f32_ns( + __read_only image1d_buffer_t src0_qs, + __global uint * src0_qh, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale and m for current Q5_1 block + uint blk_offset = s_sub_offset + get_global_id(0); + half s = src0_d[blk_offset]; + half m = src0_m[blk_offset]; + + // Load 32 qh (5-th bit of each Q5) for the entire block + uchar4 qhx32 = as_uchar4(src0_qh[blk_offset]); + + // Load 16 qs (half block) in transposed layout + uint2 qsx16; + qsx16.x = read_imageui(src0_qs, q_sub_offset + sub_block_id_m).x; + qsx16.y = read_imageui(src0_qs, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_1(as_ushort4(qsx16), qhx32.lo, reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 qs in transposed layout + qsx16.x = read_imageui(src0_qs, q_sub_offset + sub_block_id_m).x; + qsx16.y = read_imageui(src0_qs, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_1(as_ushort4(qsx16), qhx32.hi, reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl new file mode 100644 index 00000000000..938054cf982 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl @@ -0,0 +1,119 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_Q5_0 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q5_0_to_fp32_packed8(ushort2 qs5x8, uchar qh5x8) { + float8 fp32x8; + fp32x8.s0 = (float)((( qs5x8.s0 & 0x000F) | (( qh5x8 & 0x01) << 4)) - 16); + fp32x8.s1 = (float)((((qs5x8.s0 & 0x00F0) >> 4 ) | (((qh5x8 >> 1) & 0x01) << 4)) - 16); + fp32x8.s2 = (float)((((qs5x8.s0 & 0x0F00) >> 8 ) | (((qh5x8 >> 2) & 0x01) << 4)) - 16); + fp32x8.s3 = (float)((((qs5x8.s0 & 0xF000) >> 12) | (((qh5x8 >> 3) & 0x01) << 4)) - 16); + fp32x8.s4 = (float)((( qs5x8.s1 & 0x000F) | (((qh5x8 >> 4) & 0x01) << 4)) - 16); + fp32x8.s5 = (float)((((qs5x8.s1 & 0x00F0) >> 4 ) | (((qh5x8 >> 5) & 0x01) << 4)) - 16); + fp32x8.s6 = (float)((((qs5x8.s1 & 0x0F00) >> 8 ) | (((qh5x8 >> 6) & 0x01) << 4)) - 16); + fp32x8.s7 = (float)((((qs5x8.s1 & 0xF000) >> 12) | (((qh5x8 >> 7) & 0x01) << 4)) - 16); + return fp32x8; +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q5_0_f32_ns( + __global uint * src0_qs, + __global uint * src0_qh, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + uint ne00, + uint ne01, + uint ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_Q5_0); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_qs[block_offset]; + regQ.s1 = src0_qs[block_offset + ne01]; + regQ.s2 = src0_qs[block_offset + ne01 * 2]; + regQ.s3 = src0_qs[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + uchar4 regQh = as_uchar4(src0_qh[ib00 * ne01 + i01 + expert_offset]); + half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; + + float8 fp32x8 = q5_0_to_fp32_packed8(as_ushort2(regQ.s0), regQh.s0); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_0_to_fp32_packed8(as_ushort2(regQ.s1), regQh.s1); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q5_0_to_fp32_packed8(as_ushort2(regQ.s2), regQh.s2); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q5_0_to_fp32_packed8(as_ushort2(regQ.s3), regQh.s3); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += (float)(regS) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl new file mode 100644 index 00000000000..f33a4ef2757 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl @@ -0,0 +1,121 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_Q5_1 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q5_1_to_fp32_packed8(ushort2 qs5x8, uchar qh5x8, half s, half m) { + float8 fp32x8; + fp32x8.s0 = (float)((( qs5x8.s0 & 0x000F) | (( qh5x8 & 0x01) << 4)) * s + m); + fp32x8.s1 = (float)((((qs5x8.s0 & 0x00F0) >> 4 ) | (((qh5x8 >> 1) & 0x01) << 4)) * s + m); + fp32x8.s2 = (float)((((qs5x8.s0 & 0x0F00) >> 8 ) | (((qh5x8 >> 2) & 0x01) << 4)) * s + m); + fp32x8.s3 = (float)((((qs5x8.s0 & 0xF000) >> 12) | (((qh5x8 >> 3) & 0x01) << 4)) * s + m); + fp32x8.s4 = (float)((( qs5x8.s1 & 0x000F) | (((qh5x8 >> 4) & 0x01) << 4)) * s + m); + fp32x8.s5 = (float)((((qs5x8.s1 & 0x00F0) >> 4 ) | (((qh5x8 >> 5) & 0x01) << 4)) * s + m); + fp32x8.s6 = (float)((((qs5x8.s1 & 0x0F00) >> 8 ) | (((qh5x8 >> 6) & 0x01) << 4)) * s + m); + fp32x8.s7 = (float)((((qs5x8.s1 & 0xF000) >> 12) | (((qh5x8 >> 7) & 0x01) << 4)) * s + m); + return fp32x8; +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q5_1_f32_ns( + __global uint * src0_qs, + __global uint * src0_qh, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + uint ne00, + uint ne01, + uint ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_Q5_1); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_qs[block_offset]; + regQ.s1 = src0_qs[block_offset + ne01]; + regQ.s2 = src0_qs[block_offset + ne01 * 2]; + regQ.s3 = src0_qs[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + uchar4 regQh = as_uchar4(src0_qh[ib00 * ne01 + i01 + expert_offset]); + half regM = src0_m[ib00 * ne01 + i01 + expert_offset]; + half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; + + float8 fp32x8 = q5_1_to_fp32_packed8(as_ushort2(regQ.s0), regQh.s0, regS, regM); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_1_to_fp32_packed8(as_ushort2(regQ.s1), regQh.s1, regS, regM); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q5_1_to_fp32_packed8(as_ushort2(regQ.s2), regQh.s2, regS, regM); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q5_1_to_fp32_packed8(as_ushort2(regQ.s3), regQh.s3, regS, regM); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} From 97371e928560fffe4d1309202dcdc6290067269c Mon Sep 17 00:00:00 2001 From: scutler-nv Date: Wed, 13 May 2026 13:36:14 -0700 Subject: [PATCH 618/831] Fix for issue #22974. Cast intermediate results to float before adding and casting the result to the destination type. Avoids half+half operator ambiguity. (llama/22994) --- ggml/src/ggml-cuda/allreduce.cu | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/allreduce.cu b/ggml/src/ggml-cuda/allreduce.cu index 434689abd95..d56129a227e 100644 --- a/ggml/src/ggml-cuda/allreduce.cu +++ b/ggml/src/ggml-cuda/allreduce.cu @@ -184,13 +184,15 @@ static __global__ void ggml_cuda_ar_kernel( #pragma unroll for (int k = 0; k < ELEMS_PER_VEC; ++k) { const T_wire d_low = ggml_cuda_cast(sendbuf[off + k]); - recvbuf[off + k] = ggml_cuda_cast(d_low) + ggml_cuda_cast(wire[k]); + recvbuf[off + k] = ggml_cuda_cast( + ggml_cuda_cast(d_low) + ggml_cuda_cast(wire[k])); } } if (bid == 0 && tid < count - tail) { const T_wire d_low = ggml_cuda_cast(sendbuf[tail + tid]); - recvbuf[tail + tid] = - ggml_cuda_cast(d_low) + ggml_cuda_cast(host_other[tail + tid]); + recvbuf[tail + tid] = ggml_cuda_cast( + ggml_cuda_cast(d_low) + + ggml_cuda_cast(host_other[tail + tid])); } } } @@ -210,7 +212,8 @@ static __global__ void ggml_cuda_ar_add_kernel( const int nt = gridDim.x * blockDim.x; for (int i = tid; i < count; i += nt) { const T_src d_low = ggml_cuda_cast(dst[i]); - dst[i] = ggml_cuda_cast(d_low) + ggml_cuda_cast(src[i]); + dst[i] = ggml_cuda_cast( + ggml_cuda_cast(d_low) + ggml_cuda_cast(src[i])); } } From e4ce42e55f325c83b4e8fb4f848cd3fa086988c0 Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Wed, 13 May 2026 15:12:40 -0700 Subject: [PATCH 619/831] ggml-webgpu: only use subgroup-matrix path when head dims are divisible by sg_mat_k / sg_mat_n (llama/23020) --- ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 11701e79433..62a523365b9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -777,7 +777,10 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( const bool tile_can_dispatch_all_q_rows = context.max_subgroup_size > 0 && context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; - const bool use_tile = context.supports_subgroups && !context.supports_subgroup_matrix && K->type == GGML_TYPE_F16 && + const bool use_subgroup_matrix = + context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 && + context.src0->ne[0] % context.sg_mat_k == 0 && context.src2->ne[0] % context.sg_mat_n == 0; + const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned && (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && @@ -785,7 +788,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : - context.supports_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : + use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : GGML_WEBGPU_FLASH_ATTN_PATH_NONE; if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { From 69500f5502bf508e3d0f350734e9158f59aab1d6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 May 2026 11:53:30 +0300 Subject: [PATCH 620/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 15685a0718f..5a605ba344e 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -628249b398293fc8d2fa81a449ae2920a02c6523 +57ea0bc119d722d74594196cc5b494a34dd87be4 From 46ca43d6399fdeada1b49fb2126ba373bd9ebc38 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 May 2026 11:53:43 +0300 Subject: [PATCH 621/831] talk-llama : sync llama.cpp --- examples/talk-llama/llama-context.cpp | 27 ++++++++++++++++++++++----- examples/talk-llama/llama.h | 2 ++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 71a59395eb2..3d9714ab166 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -2475,11 +2475,29 @@ class llama_io_write_device : public llama_io_write_i { } if (need_alloc) { - mbuf_cur = std::move(mbuf); + if (!mbuf_cur.buf || mbuf_cur.total_size != mbuf.total_size) { + mbuf_cur = std::move(mbuf); - mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); + mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); - LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + } else { + //LLAMA_LOG_INFO("%s: reallocating tensors in '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + + // save the old buffer and allocate the new tensors in it + auto buf = std::move(mbuf_cur.buf); + + mbuf_cur = std::move(mbuf); + + ggml_tallocr talloc = ggml_tallocr_new(buf.get()); + + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + ggml_backend_view_init(mbuf_cur.org[i]); + ggml_tallocr_alloc(&talloc, mbuf_cur.cpy[i]); + } + + mbuf_cur.buf = std::move(buf); + } } for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { @@ -2559,8 +2577,7 @@ class llama_io_read_device : public llama_io_read_i { mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset)); - auto & view = mbuf.org.back(); - view->buffer = rinfo.tensor->buffer; + ggml_backend_view_init(mbuf.org.back()); } for (auto & [buft, mbuf] : mbufs_new) { diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 2ea226726ad..308e8ba9dbd 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -858,6 +858,8 @@ extern "C" { size_t n_token_capacity, size_t * n_token_count_out); +#define LLAMA_STATE_SEQ_FLAGS_NONE 0 + // for backwards-compat #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 From 968eebe77225d25e57a3f981da7c696310f0e881 Mon Sep 17 00:00:00 2001 From: Andreas Lubbe Date: Fri, 15 May 2026 14:03:17 +0200 Subject: [PATCH 622/831] server: add support for carry_initial_prompt (#3781) * Add support for carry_initial_prompt on the server * Update README --- examples/server/README.md | 3 + examples/server/server.cpp | 287 +++++++++++++++++++------------------ 2 files changed, 147 insertions(+), 143 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index ffba5f4edf5..8d4c802b8bf 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -40,6 +40,7 @@ options: -l LANG, --language LANG [en ] spoken language ('auto' for auto-detect) -dl, --detect-language [false ] exit after automatically detecting language --prompt PROMPT [ ] initial prompt + --carry-initial-prompt [false ] always prepend initial prompt -m FNAME, --model FNAME [models/ggml-base.en.bin] model path -oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference -dtw MODEL --dtw MODEL [ ] compute token-level timestamps @@ -78,6 +79,8 @@ curl 127.0.0.1:8080/inference \ -F file="@" \ -F temperature="0.0" \ -F temperature_inc="0.2" \ +-F prompt="" \ +-F carry_initial_prompt="true" \ -F response_format="json" ``` diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 735255b6290..afc95176ec8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -56,11 +56,11 @@ inline void signal_handler(int signal) { struct server_params { - std::string hostname = "127.0.0.1"; - std::string public_path = "examples/server/public"; - std::string request_path = ""; + std::string hostname = "127.0.0.1"; + std::string public_path = "examples/server/public"; + std::string request_path = ""; std::string inference_path = "/inference"; - std::string tmp_dir = "."; + std::string tmp_dir = "."; int32_t port = 8080; int32_t read_timeout = 600; @@ -89,49 +89,45 @@ struct whisper_params { float temperature_inc = 0.20f; float no_speech_thold = 0.6f; - bool debug_mode = false; - bool translate = false; - bool detect_language = false; - bool diarize = false; - bool tinydiarize = false; - bool split_on_word = false; - bool no_fallback = false; - bool print_special = false; - bool print_colors = false; - bool print_realtime = false; - bool print_progress = false; - bool no_timestamps = false; - bool token_timestamps = true; - bool use_gpu = true; - bool flash_attn = true; - int32_t gpu_device = 0; - bool suppress_nst = false; - bool no_context = true; + bool debug_mode = false; + bool translate = false; + bool detect_language = false; + bool diarize = false; + bool tinydiarize = false; + bool split_on_word = false; + bool no_fallback = false; + bool print_special = false; + bool print_colors = false; + bool print_realtime = false; + bool print_progress = false; + bool no_timestamps = false; + bool token_timestamps = true; + bool use_gpu = true; + bool flash_attn = true; + int32_t gpu_device = 0; + bool suppress_nst = false; + bool no_context = true; bool no_language_probabilities = false; - - std::string language = "en"; - std::string prompt = ""; - std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; - std::string model = "models/ggml-base.en.bin"; - - std::string response_format = json_format; - - // [TDRZ] speaker turn string - std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line - + bool carry_initial_prompt = false; + + std::string language = "en"; + std::string prompt = ""; + std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; + std::string model = "models/ggml-base.en.bin"; + std::string response_format = json_format; + std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line std::string openvino_encode_device = "CPU"; - - std::string dtw = ""; + std::string dtw = ""; // Voice Activity Detection (VAD) parameters - bool vad = false; - std::string vad_model = ""; - float vad_threshold = 0.5f; - int vad_min_speech_duration_ms = 250; + bool vad = false; + std::string vad_model = ""; + float vad_threshold = 0.5f; + int vad_min_speech_duration_ms = 250; int vad_min_silence_duration_ms = 100; - float vad_max_speech_duration_s = FLT_MAX; - int vad_speech_pad_ms = 30; - float vad_samples_overlap = 0.1f; + float vad_max_speech_duration_s = FLT_MAX; + int vad_speech_pad_ms = 30; + float vad_samples_overlap = 0.1f; }; void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params, const server_params& sparams) { @@ -139,51 +135,52 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "usage: %s [options] \n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); - fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); - fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); - fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); - fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); - fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); - fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); - fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); - fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); - fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); - fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); - fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); - fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); - fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); - fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); - fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); - fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); - fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); - fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); - fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); - fprintf(stderr, " -pr, --print-realtime [%-7s] print output in realtime\n", params.print_realtime ? "true" : "false"); - fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); - fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); - fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); - fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); - fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); - fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); - fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); + fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); + fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); + fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); + fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " -pr, --print-realtime [%-7s] print output in realtime\n", params.print_realtime ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); + fprintf(stderr, " --carry-initial-prompt [%-7s] always prepend initial prompt\n", params.carry_initial_prompt ? "true" : "false"); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); // server params - fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); - fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str()); - fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port); - fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str()); - fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str()); - fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str()); - fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false"); - fprintf(stderr, " --tmp-dir, [%-7s] Temporary directory for ffmpeg transcoded files\n", sparams.tmp_dir.c_str()); - fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); - fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); - fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true"); - fprintf(stderr, " -dev N, --device N [%-7d] GPU device ID (default: 0)\n", params.gpu_device); - fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); - fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); + fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); + fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str()); + fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port); + fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str()); + fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str()); + fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str()); + fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false"); + fprintf(stderr, " --tmp-dir, [%-7s] Temporary directory for ffmpeg transcoded files\n", sparams.tmp_dir.c_str()); + fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); + fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); + fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -dev N, --device N [%-7d] GPU device ID (default: 0)\n", params.gpu_device); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -nlp, --no-language-probabilities [%-7s] exclude language probabilities from verbose_json output\n", params.no_language_probabilities ? "true" : "false"); // Voice Activity Detection (VAD) parameters fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n"); @@ -191,10 +188,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -vm FNAME, --vad-model FNAME [%-7s] VAD model path\n", params.vad_model.c_str()); fprintf(stderr, " -vt N, --vad-threshold N [%-7.2f] VAD threshold for speech recognition\n", params.vad_threshold); fprintf(stderr, " -vspd N, --vad-min-speech-duration-ms N [%-7d] VAD min speech duration (0.0-1.0)\n", params.vad_min_speech_duration_ms); - fprintf(stderr, " -vsd N, --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n", params.vad_min_silence_duration_ms); - fprintf(stderr, " -vmsd N, --vad-max-speech-duration-s N [%-7s] VAD max speech duration (auto-split longer)\n", params.vad_max_speech_duration_s == FLT_MAX ? - std::string("FLT_MAX").c_str() : - std::to_string(params.vad_max_speech_duration_s).c_str()); + fprintf(stderr, " -vsd N, --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n", params.vad_min_silence_duration_ms); + fprintf(stderr, " -vmsd N, --vad-max-speech-duration-s N [%-7s] VAD max speech duration (auto-split longer)\n", params.vad_max_speech_duration_s == FLT_MAX ? std::string("FLT_MAX").c_str() : std::to_string(params.vad_max_speech_duration_s).c_str()); fprintf(stderr, " -vp N, --vad-speech-pad-ms N [%-7d] VAD speech padding (extend segments)\n", params.vad_speech_pad_ms); fprintf(stderr, " -vo N, --vad-samples-overlap N [%-7.2f] VAD samples overlap (seconds between segments)\n", params.vad_samples_overlap); fprintf(stderr, "\n"); @@ -212,63 +207,64 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve whisper_print_usage(argc, argv, params, sparams); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } - else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } - else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } - else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } - else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } - else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } - else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } - else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } - else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } - else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } - else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } - else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } - else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } - else if (arg == "-tr" || arg == "--translate") { params.translate = true; } - else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } - else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } - else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } - else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } - else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } - else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } - else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } - else if (arg == "-pr" || arg == "--print-realtime") { params.print_realtime = true; } - else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } - else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } - else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } - else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } - else if ( arg == "--prompt") { params.prompt = argv[++i]; } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } - else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } - else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } - else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(argv[++i]); } - else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } - else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } - else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } - else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); } - else if (arg == "-nlp" || arg == "--no-language-probabilities") { params.no_language_probabilities = true; } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } + else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } + else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } + else if (arg == "-debug" || arg == "--debug-mode") { params.debug_mode = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } + else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } + else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } + else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } + else if (arg == "-pr" || arg == "--print-realtime") { params.print_realtime = true; } + else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } + else if ( arg == "--prompt") { params.prompt = argv[++i]; } + else if ( arg == "--carry-initial-prompt") { params.carry_initial_prompt = true; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } + else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(argv[++i]); } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } + else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } + else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); } + else if (arg == "-nlp" || arg == "--no-language-probabilities") { params.no_language_probabilities = true; } // server params - else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } - else if ( arg == "--host") { sparams.hostname = argv[++i]; } - else if ( arg == "--public") { sparams.public_path = argv[++i]; } - else if ( arg == "--request-path") { sparams.request_path = argv[++i]; } - else if ( arg == "--inference-path") { sparams.inference_path = argv[++i]; } - else if ( arg == "--convert") { sparams.ffmpeg_converter = true; } - else if ( arg == "--tmp-dir") { sparams.tmp_dir = argv[++i]; } + else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } + else if ( arg == "--host") { sparams.hostname = argv[++i]; } + else if ( arg == "--public") { sparams.public_path = argv[++i]; } + else if ( arg == "--request-path") { sparams.request_path = argv[++i]; } + else if ( arg == "--inference-path") { sparams.inference_path = argv[++i]; } + else if ( arg == "--convert") { sparams.ffmpeg_converter = true; } + else if ( arg == "--tmp-dir") { sparams.tmp_dir = argv[++i]; } // Voice Activity Detection (VAD) - else if ( arg == "--vad") { params.vad = true; } - else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = argv[++i]; } - else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(argv[++i]); } - else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); } - else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_silence_duration_ms = std::stoi(argv[++i]); } - else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(argv[++i]); } - else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(argv[++i]); } - else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(argv[++i]); } + else if ( arg == "--vad") { params.vad = true; } + else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = argv[++i]; } + else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(argv[++i]); } + else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); } + else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_silence_duration_ms = std::stoi(argv[++i]); } + else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(argv[++i]); } + else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(argv[++i]); } + else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(argv[++i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params, sparams); @@ -573,6 +569,10 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.prompt = req.get_file_value("prompt").content; } + if (req.has_file("carry_initial_prompt")) + { + params.carry_initial_prompt = parse_str_to_bool(req.get_file_value("carry_initial_prompt").content); + } if (req.has_file("response_format")) { params.response_format = req.get_file_value("response_format").content; @@ -940,6 +940,7 @@ int main(int argc, char ** argv) { wparams.tdrz_enable = params.tinydiarize; // [TDRZ] wparams.initial_prompt = params.prompt.c_str(); + wparams.carry_initial_prompt = params.carry_initial_prompt; wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; From 6227a0ef739a78312d96e6f8f85e7b6d63683445 Mon Sep 17 00:00:00 2001 From: Andreas Lubbe Date: Mon, 18 May 2026 09:18:04 +0200 Subject: [PATCH 623/831] server : Return speaker information in JSON (#3782) --- examples/server/server.cpp | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index afc95176ec8..590378b725f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -315,10 +315,10 @@ std::string generate_temp_filename(const std::string &path, const std::string &p return ss.str(); } -bool convert_to_wav(const std::string & temp_filename, std::string & error_resp) { +bool convert_to_wav(const std::string & temp_filename, std::string & error_resp, bool stereo) { std::ostringstream cmd_stream; std::string converted_filename_temp = temp_filename + "_temp.wav"; - cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -y -ar 16000 -ac 1 -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1"; + cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -y -ar 16000 -ac " << (stereo ? 2 : 1) << " -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1"; std::string cmd = cmd_stream.str(); int status = std::system(cmd.c_str()); @@ -341,7 +341,7 @@ bool convert_to_wav(const std::string & temp_filename, std::string & error_resp) return true; } -std::string estimate_diarization_speaker(std::vector> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) { +std::string estimate_diarization_speaker(const std::vector> & pcmf32s, int64_t t0, int64_t t1, bool id_only = false) { std::string speaker = ""; const int64_t n_samples = pcmf32s[0].size(); @@ -451,7 +451,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper } } -std::string output_str(struct whisper_context * ctx, const whisper_params & params, std::vector> pcmf32s) { +std::string output_str(struct whisper_context * ctx, const whisper_params & params, const std::vector> & pcmf32s) { std::stringstream result; const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { @@ -848,7 +848,7 @@ int main(int argc, char ** argv) { temp_file.close(); std::string error_resp = "{\"error\":\"Failed to execute ffmpeg command.\"}"; - const bool is_converted = convert_to_wav(temp_filename, error_resp); + const bool is_converted = convert_to_wav(temp_filename, error_resp, params.diarize); if (!is_converted) { res.status = 500; res.set_content(error_resp, "application/json"); @@ -1091,6 +1091,14 @@ int main(int argc, char ** argv) { segment["end"] = whisper_full_get_segment_t1(ctx, i) * 0.01; } + if (params.diarize && pcmf32s.size() == 2) { + segment["speaker"] = estimate_diarization_speaker( + pcmf32s, + whisper_full_get_segment_t0(ctx, i), + whisper_full_get_segment_t1(ctx, i), + true); + } + float total_logprob = 0; const int n_tokens = whisper_full_n_tokens(ctx, i); for (int j = 0; j < n_tokens; ++j) { From 47b9eb37a33c5031a1b667ace64477330b9f36c1 Mon Sep 17 00:00:00 2001 From: petterreinholdtsen Date: Mon, 18 May 2026 12:16:39 +0200 Subject: [PATCH 624/831] examples : fix memory leak in read_audio_data (#3810) This commit addresses a memory leak in the `read_audio_data` function where it is currently possible that a call to `ma_decoder_init_file` succeeds and the function returns early without calling `ma_decoder_uninit`. A similar situation can occur with `ma_decoder_init_memory`. Refs: https://bugs.debian.org/1124796 Co-authored-by: Daniel Bevenius --- examples/common-whisper.cpp | 55 +++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 21 deletions(-) diff --git a/examples/common-whisper.cpp b/examples/common-whisper.cpp index 6218a882eb5..977527a0ca5 100644 --- a/examples/common-whisper.cpp +++ b/examples/common-whisper.cpp @@ -44,7 +44,18 @@ bool read_audio_data(const std::string & fname, std::vector& pcmf32, std: ma_result result; ma_decoder_config decoder_config; - ma_decoder decoder; + + struct decoder_guard { + ma_decoder decoder; + bool initialized = false; + ma_decoder * operator&() { return &decoder; } + ~decoder_guard() { + if (initialized) { + ma_decoder_uninit(&decoder); + } + } + }; + decoder_guard decoder{}; decoder_config = ma_decoder_config_init(ma_format_f32, stereo ? 2 : 1, WHISPER_SAMPLE_RATE); @@ -63,32 +74,36 @@ bool read_audio_data(const std::string & fname, std::vector& pcmf32, std: audio_data.insert(audio_data.end(), buf, buf + n); } - if ((result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder)) != MA_SUCCESS) { - + result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); + if (result != MA_SUCCESS) { fprintf(stderr, "Error: failed to open audio data from stdin (%s)\n", ma_result_description(result)); - return false; } + decoder.initialized = true; fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, audio_data.size()); } - else if (((result = ma_decoder_init_file(fname.c_str(), &decoder_config, &decoder)) != MA_SUCCESS)) { + else { + result = ma_decoder_init_file(fname.c_str(), &decoder_config, &decoder); + if (result == MA_SUCCESS) { + decoder.initialized = true; + } #if defined(WHISPER_FFMPEG) - if (ffmpeg_decode_audio(fname, audio_data) != 0) { - fprintf(stderr, "error: failed to ffmpeg decode '%s'\n", fname.c_str()); - - return false; - } - - if ((result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder)) != MA_SUCCESS) { - fprintf(stderr, "error: failed to read audio data as wav (%s)\n", ma_result_description(result)); - - return false; - } + if (!decoder.initialized) { + if (ffmpeg_decode_audio(fname, audio_data) != 0) { + fprintf(stderr, "error: failed to ffmpeg decode '%s'\n", fname.c_str()); + return false; + } + result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); + if (result != MA_SUCCESS) { + fprintf(stderr, "error: failed to read audio data as wav (%s)\n", ma_result_description(result)); + return false; + } + decoder.initialized = true; + } #else - if ((result = ma_decoder_init_memory(fname.c_str(), fname.size(), &decoder_config, &decoder)) != MA_SUCCESS) { - fprintf(stderr, "error: failed to read audio data as wav (%s)\n", ma_result_description(result)); - + if (!decoder.initialized) { + fprintf(stderr, "error: failed to read audio data from (%s)\n", fname.c_str()); return false; } #endif @@ -128,8 +143,6 @@ bool read_audio_data(const std::string & fname, std::vector& pcmf32, std: } } - ma_decoder_uninit(&decoder); - return true; } From afa2ea544fb4b0448916b4a31ecd33c8685bd482 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Tue, 19 May 2026 08:58:43 +0200 Subject: [PATCH 625/831] whisper : set bench data for each iteration (#3812) * whisper : set bench data for each iteration This commit updates whisper_bench_ggml_mul_mat_str to intialize the tensors data for each iteration. The motivation for this is that is currently possible for a previous run's results, F32 values, to leak into the next run. When it is time for the F16 iteration then F32 results can cause NaN values to appear in the tensor values causing the F16 iteration to fail. Refs:https://github.com/ggml-org/whisper.cpp/actions/runs/25901678402/job/76152894644?pr=3735 * ci : set GGML_NATIVE=OFF if x86_64 This commit sets GGML_NATIVE=OFF for x86_64 architectures. The motivation for this is to try to get CI to pass and the theory is that the libggml-cpu.so library in the ccache might have been built by a runner that supports a different instruction set. When another runner that does not support that instruction set tries to use it, it will fail with a segmentation fault. I'm not sure about this yet but going to try this out and if it does not work I'll ssh into the runner to debug further. --- ci/run.sh | 4 ++++ src/whisper.cpp | 12 +++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/ci/run.sh b/ci/run.sh index cbe28442e16..b03fdf1c6b1 100644 --- a/ci/run.sh +++ b/ci/run.sh @@ -50,6 +50,10 @@ fi CMAKE_EXTRA="-DWHISPER_FATAL_WARNINGS=ON" +if [[ "$(uname -m)" == "x86_64" ]]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF" +fi + if [ ! -z ${GG_BUILD_METAL} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON" fi diff --git a/src/whisper.cpp b/src/whisper.cpp index 210ca597fb4..0fe29a4541e 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -8258,9 +8258,6 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { // when F16 is used, there is an extra work buffer of size N*N*sizeof(float) std::vector buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead() + ggml_graph_overhead()); - // put a bunch of random data in the buffer - for (size_t i = 0; i < buf.size(); i++) buf[i] = i; - for (int j = 0; j < (int) sizes.size(); j++) { int n_q4_0 = 0; int n_q4_1 = 0; @@ -8304,6 +8301,15 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype, N, N); struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N); + // set tensor data after allocation so previous iteration results don't corrupt it. + { + uint8_t * a_data = (uint8_t *) a->data; + for (size_t ii = 0; ii < ggml_nbytes(a); ii++) a_data[ii] = ii & 0x3F; + + uint8_t * b_data = (uint8_t *) b->data; + for (size_t ii = 0; ii < ggml_nbytes(b); ii++) b_data[ii] = ii & 0x3F; + } + struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b); struct ggml_cgraph * gf = ggml_new_graph(ctx0); From 8443cf05e3fa8ce1b32348e1bcbcf8fc31f7f3ae Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 21 May 2026 10:59:58 +0200 Subject: [PATCH 626/831] ci : use github ubuntu-22.04-arm runner instead of qemu (#3815) * ci : use github ubuntu-22.04-arm runner instead of qemu This commit updates the ubuntu-22-gcc-arm64 job to use a arm github runner instead of QEMU. The motivation for this is that we get intermittent failure specifically related to QEMU. For example: ```console Segmentation fault (core dumped) qemu: uncaught target signal 11 (Segmentation fault) - core dumped Segmentation fault (core dumped) dpkg: error processing package libc-bin (--configure): installed libc-bin package post-installation script subprocess returned error exit status 139 Processing triggers for ca-certificates (20240203~22.04.1) ... Updating certificates in /etc/ssl/certs... 0 added, 0 removed; done. Running hooks in /etc/ca-certificates/update.d... done. Errors were encountered while processing: libc-bin E: Sub-process /usr/bin/dpkg returned an error code (1) ``` This is an attempt to try to avoid QEMU and hence avoid this issue. * ci : remove QEMU where possible --- .github/workflows/build.yml | 122 +++++++++++++++++++----------------- 1 file changed, 64 insertions(+), 58 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index be3f78a3f5b..7ace04e1207 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -150,34 +150,21 @@ jobs: ubuntu-22-arm64: if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - arch: [linux/arm64] + runs-on: ubuntu-22.04-arm steps: - name: Clone uses: actions/checkout@v6 - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} + - name: Install dependencies run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + sudo apt-get update + sudo apt-get install -y build-essential libsdl2-dev cmake git - apt update - apt install -y build-essential libsdl2-dev cmake git - cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a - cmake --build build --config Release -j $(nproc)' + - name: Build + run: | + cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a + cmake --build build --config Release -j $(nproc) ubuntu-22-arm-v7: if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || @@ -305,36 +292,34 @@ jobs: ubuntu-22-gcc-arm64: if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 + runs-on: ubuntu-22.04-arm strategy: fail-fast: false matrix: build: [Debug, Release] - arch: [linux/arm64] steps: - name: Clone uses: actions/checkout@v6 - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y build-essential cmake libsdl2-dev git - - name: Build ${{ matrix.arch }} + - name: Configure CMake run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + cmake . \ + -DWHISPER_SDL2=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ARM_ARCH=armv8-a - apt update - apt install -y build-essential cmake libsdl2-dev git - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a - make - ctest -L gh --output-on-failure' + - name: Build and Test + run: | + make + ctest -L gh --output-on-failure ubuntu-22-gcc-arm-v7: if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || @@ -382,7 +367,7 @@ jobs: #arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] # TODO: arm/v7 disabled due to clang bug # https://github.com/ggerganov/whisper.cpp/actions/runs/9657764109/job/26637633042?pr=2256#step:4:1990 - arch: [linux/amd64, linux/arm64, linux/ppc64le] + arch: [linux/amd64, linux/ppc64le] steps: - name: Clone @@ -407,6 +392,36 @@ jobs: make ctest -L gh --output-on-failure' + ubuntu-22-clang-arm64: + if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || + github.event.inputs.run_type == 'full-ci' }} + runs-on: ubuntu-22.04-arm + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y clang build-essential cmake libsdl2-dev git + + - name: Build and Test + run: | + cmake . -DWHISPER_SDL2=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER=clang \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ARM_ARCH=armv8-a + make + ctest -L gh --output-on-failure + ubuntu-22-gcc-sanitized: if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || github.event.inputs.run_type == 'full-ci' }} @@ -416,32 +431,23 @@ jobs: fail-fast: false matrix: sanitizer: [ADDRESS, THREAD, UNDEFINED] - arch: [linux/amd64] steps: - name: Clone uses: actions/checkout@v6 - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} + - name: Install dependencies run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + sudo apt-get update + sudo apt-get install -y build-essential cmake git - apt update - apt install -y build-essential cmake git - cmake . -DCMAKE_BUILD_TYPE=Debug \ - -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ - -DGGML_OPENMP=OFF - make - ctest -L gh --output-on-failure' + - name: Build and Test + run: | + cmake . -DCMAKE_BUILD_TYPE=Debug \ + -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DGGML_OPENMP=OFF + make + ctest -L gh --output-on-failure ubuntu-22-cmake-sycl: if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || From 0ccd896f5b882628e1c077f9769735ef4ce52860 Mon Sep 17 00:00:00 2001 From: Pascal Date: Fri, 22 May 2026 08:27:35 +0200 Subject: [PATCH 627/831] common : fix server /inference fails to decode in-memory audio (regression) (#3818) * common: add memory buffer overload of read_audio_data whisper-server /inference without --convert passed the uploaded file bytes to read_audio_data as a filename, so ma_decoder_init_file tried to open a path starting with "RIFF" and failed. every request returned HTTP 400 "Invalid request" on builds without WHISPER_FFMPEG, which is the default. factor the PCM extraction into a shared helper and add an overload that decodes straight from a memory buffer via ma_decoder_init_memory, which the function already used for the stdin path. server now calls it with the upload content. the filename overload behavior is unchanged. --- examples/common-whisper.cpp | 79 ++++++++++++++++++++++--------------- examples/common-whisper.h | 8 ++++ examples/server/server.cpp | 3 +- 3 files changed, 57 insertions(+), 33 deletions(-) diff --git a/examples/common-whisper.cpp b/examples/common-whisper.cpp index 977527a0ca5..d29166b50d8 100644 --- a/examples/common-whisper.cpp +++ b/examples/common-whisper.cpp @@ -39,6 +39,42 @@ extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data); #endif +// extract f32 PCM frames from an initialized decoder, downmix to mono and keep the stereo split +static bool read_audio_from_decoder(ma_decoder & decoder, std::vector & pcmf32, std::vector> & pcmf32s, bool stereo) { + ma_result result; + ma_uint64 frame_count; + ma_uint64 frames_read; + + if ((result = ma_decoder_get_length_in_pcm_frames(&decoder, &frame_count)) != MA_SUCCESS) { + fprintf(stderr, "error: failed to retrieve the length of the audio data (%s)\n", ma_result_description(result)); + return false; + } + + pcmf32.resize(stereo ? frame_count*2 : frame_count); + + if ((result = ma_decoder_read_pcm_frames(&decoder, pcmf32.data(), frame_count, &frames_read)) != MA_SUCCESS) { + fprintf(stderr, "error: failed to read the frames of the audio data (%s)\n", ma_result_description(result)); + return false; + } + + if (stereo) { + std::vector stereo_data = pcmf32; + pcmf32.resize(frame_count); + for (uint64_t i = 0; i < frame_count; i++) { + pcmf32[i] = (stereo_data[2*i] + stereo_data[2*i + 1]); + } + pcmf32s.resize(2); + pcmf32s[0].resize(frame_count); + pcmf32s[1].resize(frame_count); + for (uint64_t i = 0; i < frame_count; i++) { + pcmf32s[0][i] = stereo_data[2*i]; + pcmf32s[1][i] = stereo_data[2*i + 1]; + } + } + + return true; +} + bool read_audio_data(const std::string & fname, std::vector& pcmf32, std::vector>& pcmf32s, bool stereo) { std::vector audio_data; // used for pipe input from stdin or ffmpeg decoding output @@ -109,41 +145,22 @@ bool read_audio_data(const std::string & fname, std::vector& pcmf32, std: #endif } - ma_uint64 frame_count; - ma_uint64 frames_read; - - if ((result = ma_decoder_get_length_in_pcm_frames(&decoder, &frame_count)) != MA_SUCCESS) { - fprintf(stderr, "error: failed to retrieve the length of the audio data (%s)\n", ma_result_description(result)); - - return false; - } - - pcmf32.resize(stereo ? frame_count*2 : frame_count); - - if ((result = ma_decoder_read_pcm_frames(&decoder, pcmf32.data(), frame_count, &frames_read)) != MA_SUCCESS) { - fprintf(stderr, "error: failed to read the frames of the audio data (%s)\n", ma_result_description(result)); - - return false; - } - - if (stereo) { - std::vector stereo_data = pcmf32; - pcmf32.resize(frame_count); + return read_audio_from_decoder(decoder.decoder, pcmf32, pcmf32s, stereo); +} - for (uint64_t i = 0; i < frame_count; i++) { - pcmf32[i] = (stereo_data[2*i] + stereo_data[2*i + 1]); - } +// decode audio bytes already held in memory +bool read_audio_data(const char * buffer, size_t buffer_size, std::vector & pcmf32, std::vector> & pcmf32s, bool stereo) { + ma_decoder_config decoder_config = ma_decoder_config_init(ma_format_f32, stereo ? 2 : 1, WHISPER_SAMPLE_RATE); + ma_decoder decoder; - pcmf32s.resize(2); - pcmf32s[0].resize(frame_count); - pcmf32s[1].resize(frame_count); - for (uint64_t i = 0; i < frame_count; i++) { - pcmf32s[0][i] = stereo_data[2*i]; - pcmf32s[1][i] = stereo_data[2*i + 1]; - } + if (ma_decoder_init_memory(buffer, buffer_size, &decoder_config, &decoder) != MA_SUCCESS) { + fprintf(stderr, "error: failed to decode audio data from memory buffer\n"); + return false; } - return true; + bool ok = read_audio_from_decoder(decoder, pcmf32, pcmf32s, stereo); + ma_decoder_uninit(&decoder); + return ok; } // 500 -> 00:05.000 diff --git a/examples/common-whisper.h b/examples/common-whisper.h index 4134362150a..8714c381046 100644 --- a/examples/common-whisper.h +++ b/examples/common-whisper.h @@ -14,6 +14,14 @@ bool read_audio_data( std::vector> & pcmf32s, bool stereo); +// decode audio bytes already held in memory (uploaded file, network buffer) +bool read_audio_data( + const char * buffer, + size_t buffer_size, + std::vector & pcmf32, + std::vector> & pcmf32s, + bool stereo); + // convert timestamp to string, 6000 -> 01:00.000 std::string to_timestamp(int64_t t, bool comma = false); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 590378b725f..aae74c3d840 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -868,8 +868,7 @@ int main(int argc, char ** argv) { // remove temp file std::remove(temp_filename.c_str()); } else { - if (!::read_audio_data(audio_file.content, pcmf32, pcmf32s, params.diarize)) - { + if (!::read_audio_data(audio_file.content.data(), audio_file.content.size(), pcmf32, pcmf32s, params.diarize)) { fprintf(stderr, "error: failed to read audio data\n"); const std::string error_resp = "{\"error\":\"failed to read audio data\"}"; res.status = 400; From b3877e10c0a8435c53209b2ebfbaa8359f7baaae Mon Sep 17 00:00:00 2001 From: OrbisAI Security Date: Mon, 25 May 2026 11:49:23 +0530 Subject: [PATCH 628/831] fix: in bindings/ruby/test/jfk_reader/jfk_reader in jfk_reader.c (#3756) * fix: V-002 security vulnerability Automated security fix generated by Orbis Security AI * fix(ruby): use Ruby allocator macros in jfk_reader and fix memory leak - Replace calloc/free with ALLOC_N/xfree to match Ruby binding conventions (ALLOC_N handles overflow checking and raises NoMemoryError on failure) - Free temporary samples buffer after conversion loop (was leaked) - Add NULL check for fopen return value with rb_raise - Add comment clarifying n_samples is a compile-time constant Co-Authored-By: Claude Opus 4.6 * fix(ruby): return false instead of rb_raise in memory_view callback rb_memory_view_get_func_t callbacks should communicate errors via return value (false), not exceptions. rb_memory_view_get has no exception-handling wrapper around get_func calls. Co-Authored-By: Claude Opus 4.6 * replacing ALLOC_N with rb_protect as ALLOC_N raises Ruby exceptions --------- Co-authored-by: Claude Opus 4.6 --- bindings/ruby/test/jfk_reader/jfk_reader.c | 57 +++++++++++++++++++--- 1 file changed, 50 insertions(+), 7 deletions(-) diff --git a/bindings/ruby/test/jfk_reader/jfk_reader.c b/bindings/ruby/test/jfk_reader/jfk_reader.c index 6657176e767..62207aaa411 100644 --- a/bindings/ruby/test/jfk_reader/jfk_reader.c +++ b/bindings/ruby/test/jfk_reader/jfk_reader.c @@ -2,6 +2,24 @@ #include #include +typedef struct { + VALUE audio_path; + int n_samples; + const char *audio_path_str; + float *data; + short *samples; +} jfk_alloc_args; + +static VALUE +jfk_reader_alloc_resources(VALUE arg) +{ + jfk_alloc_args *a = (jfk_alloc_args *)arg; + a->audio_path_str = StringValueCStr(a->audio_path); + a->data = ALLOC_N(float, a->n_samples); + a->samples = ALLOC_N(short, a->n_samples); + return Qnil; +} + static VALUE jfk_reader_initialize(VALUE self, VALUE audio_path) { @@ -13,21 +31,42 @@ static bool jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags) { VALUE audio_path = rb_iv_get(obj, "audio_path"); - const char *audio_path_str = StringValueCStr(audio_path); + // n_samples is a fixed constant (not derived from user input). const int n_samples = 176000; - float *data = (float *)malloc(n_samples * sizeof(float)); - short *samples = (short *)malloc(n_samples * sizeof(short)); - FILE *file = fopen(audio_path_str, "rb"); + + jfk_alloc_args args = { + .audio_path = audio_path, + .n_samples = n_samples, + .audio_path_str = NULL, + .data = NULL, + .samples = NULL, + }; + + int state; + rb_protect(jfk_reader_alloc_resources, (VALUE)&args, &state); + if (state) { + if (args.samples) xfree(args.samples); + if (args.data) xfree(args.data); + return false; + } + + FILE *file = fopen(args.audio_path_str, "rb"); + if (file == NULL) { + xfree(args.samples); + xfree(args.data); + return false; + } fseek(file, 78, SEEK_SET); - fread(samples, sizeof(short), n_samples, file); + fread(args.samples, sizeof(short), n_samples, file); fclose(file); for (int i = 0; i < n_samples; i++) { - data[i] = samples[i]/32768.0; + args.data[i] = args.samples[i] / 32768.0; } + xfree(args.samples); view->obj = obj; - view->data = (void *)data; + view->data = (void *)args.data; view->byte_size = sizeof(float) * n_samples; view->readonly = true; view->format = "f"; @@ -45,6 +84,10 @@ jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags) static bool jfk_reader_release_memory_view(const VALUE obj, rb_memory_view_t *view) { + if (view->data) { + xfree(view->data); + view->data = NULL; + } return true; } From e414ecf67424f0cd69a3520f99439122ce9aaa1f Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Mon, 25 May 2026 11:25:15 +0200 Subject: [PATCH 629/831] cmake : add CMakePresets.json [no ci] (#3808) This commit adds a CMakePresets.json file similar to the one in llama.cpp. The motivation for this is that this provides sharable named configuration which can be used with cmake --preset . It also allows for extendins these preset with a CMakeUserPresets.json for specific hardware (like CPUs), architectures, and toolchains etc. --- .gitignore | 1 + CMakePresets.json | 95 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) create mode 100644 CMakePresets.json diff --git a/.gitignore b/.gitignore index 957eeb75456..6eb8ff45915 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ .DS_Store .vimspector.json /CMakeSettings.json +/CMakeUserPresets.json /talk-llama.dSYM/ build/ diff --git a/CMakePresets.json b/CMakePresets.json new file mode 100644 index 00000000000..b5afeb3c0f2 --- /dev/null +++ b/CMakePresets.json @@ -0,0 +1,95 @@ +{ + "version": 4, + "configurePresets": [ + { + "name": "base", + "hidden": true, + "generator": "Ninja", + "binaryDir": "${sourceDir}/build-${presetName}", + "cacheVariables": { + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.." + } + }, + { + "name": "sycl-base", + "hidden": true, + "generator": "Ninja", + "binaryDir": "${sourceDir}/build-${presetName}", + "cacheVariables": { + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "CMAKE_CXX_COMPILER": "icx", + "CMAKE_C_COMPILER": "cl", + "GGML_SYCL": "ON", + "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.." + } + }, + { "name": "debug", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } }, + { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, + { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, + { "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } }, + { "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } }, + { "name": "vulkan", "hidden": true, "cacheVariables": { "GGML_VULKAN": "ON" } }, + + { + "name": "x64-windows-llvm", "hidden": true, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/x64-windows-llvm.cmake" + } + }, + + { + "name": "arm64-windows-llvm", "hidden": true, + "architecture": { "value": "arm64", "strategy": "external" }, + "toolset": { "value": "host=x64", "strategy": "external" }, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-llvm.cmake" + } + }, + + { + "name": "arm64-apple-clang", "hidden": true, + "architecture": { "value": "arm64", "strategy": "external" }, + "toolset": { "value": "host=x64", "strategy": "external" }, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake" + } + }, + { + "name": "x64-linux-gcc", "hidden": true, + "cacheVariables": { + "CMAKE_C_COMPILER": "gcc", + "CMAKE_CXX_COMPILER": "g++" + } + }, + { "name": "x64-linux-gcc-debug", "inherits": [ "base", "x64-linux-gcc", "debug" ] }, + { "name": "x64-linux-gcc-release", "inherits": [ "base", "x64-linux-gcc", "release" ] }, + { "name": "x64-linux-gcc-reldbg", "inherits": [ "base", "x64-linux-gcc", "reldbg" ] }, + { "name": "x64-linux-gcc+static-release", "inherits": [ "base", "x64-linux-gcc", "release", "static" ] }, + + { "name": "arm64-windows-llvm-debug", "inherits": [ "base", "arm64-windows-llvm", "debug" ] }, + { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] }, + { "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] }, + + { "name": "arm64-apple-clang-debug", "inherits": [ "base", "arm64-apple-clang", "debug" ] }, + { "name": "arm64-apple-clang-release", "inherits": [ "base", "arm64-apple-clang", "reldbg" ] }, + { "name": "arm64-apple-clang+static-release", "inherits": [ "base", "arm64-apple-clang", "reldbg", "static" ] }, + + { "name": "x64-windows-llvm-debug", "inherits": [ "base", "x64-windows-llvm", "debug" ] }, + { "name": "x64-windows-llvm-release", "inherits": [ "base", "x64-windows-llvm", "release" ] }, + { "name": "x64-windows-llvm-reldbg", "inherits": [ "base", "x64-windows-llvm", "reldbg" ] }, + { "name": "x64-windows-llvm+static-release", "inherits": [ "base", "x64-windows-llvm", "reldbg", "static" ] }, + + { "name": "x64-windows-msvc-debug", "inherits": [ "base", "debug" ] }, + { "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] }, + { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, + + { "name": "x64-windows-sycl-debug", "inherits": [ "sycl-base", "debug" ] }, + { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] }, + { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }, + { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] }, + + { "name": "x64-windows-vulkan-debug", "inherits": [ "base", "vulkan", "debug" ] }, + { "name": "x64-windows-vulkan-release", "inherits": [ "base", "vulkan", "release" ] } + ] +} From 06cfc3653b256e4d6e02553e70ddce8bbc625ede Mon Sep 17 00:00:00 2001 From: Katostrofik Date: Thu, 14 May 2026 01:39:14 -0400 Subject: [PATCH 630/831] SYCL: fix multi-GPU system RAM exhaustion by using Level Zero allocations (llama/21597) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * SYCL: fix multi-GPU system RAM exhaustion by using Level Zero allocations Replace sycl::malloc_device with zeMemAllocDevice for GPU memory allocation in the SYCL backend. sycl::malloc_device triggers the xe kernel driver's DMA-buf/TTM path which mirrors every VRAM allocation 1:1 in system RAM. zeMemAllocDevice uses the SVM/P2P path with no host staging. On a dual Intel Arc Pro B70 system (64GB VRAM, 64GB RAM), a 15.6 GiB model consumed 60 GiB of system RAM via sycl::malloc_device, causing OOM crashes. With zeMemAllocDevice, the same workload uses ~6.7 GiB of system RAM with no performance regression. All Level Zero calls include automatic fallback to the original SYCL allocation path if Level Zero interop is unavailable. * SYCL: address review feedback - remove try/catch, check device types, deduplicate - Remove try/catch from malloc/free/memcpy helpers, check backend and device type upfront instead (ggml_sycl_is_level_zero, ggml_sycl_is_dgpu) - Move shared helpers (is_level_zero, is_dgpu, free_device) to common.cpp and declare in common.hpp to eliminate code duplication - Use SYCL_CHECK(CHECK_TRY_ERROR()) for fallback sycl::free calls - Guard dev2dev_memcpy L0 path to dGPU-to-dGPU only, preserving the host-staged path for iGPU-to-dGPU transfers - Add Windows Level Zero SDK path detection (LEVEL_ZERO_V1_SDK_PATH) in CMakeLists.txt (co-authored with @arthw) * SYCL: add build/runtime flags for Level Zero, address review feedback Implements the architecture suggested by @arthw: compile-time and runtime flags to cleanly separate Level Zero and SYCL memory API paths. - Add GGML_SYCL_SUPPORT_LEVEL_ZERO cmake option (default ON). All Level Zero code is wrapped in #ifdef so the build works on systems without the Level Zero SDK installed (e.g. CPU-only CI servers). Both the loader library and headers are checked before enabling. - Add GGML_SYCL_ENABLE_LEVEL_ZERO runtime env var (default 1). Controls whether Level Zero or SYCL memory APIs are used. Only one API style is used per session, no mixing. If Level Zero is enabled but the devices don't support the Level Zero backend, it auto-disables with a warning. - Remove Level Zero code from dpct_malloc. It was unused (dpct::device_memory is not called anywhere in the backend) and used try/catch for flow control. - Update SYCL.md with documentation for both new parameters. Tested on Intel Arc Pro B70 (32GB), single-GPU and dual-GPU, with both GGML_SYCL_SUPPORT_LEVEL_ZERO=ON and OFF builds. AI-assisted development (Claude). Code reviewed and tested on my hardware. * SYCL: unify Level Zero malloc/free call sites, address review feedback Move ggml_sycl_malloc_device to common.cpp alongside ggml_sycl_free_device. Both functions are now unconditionally available — Level Zero code is #ifdef'd inside the functions, not at call sites. All call sites use uniform SYCL_CHECK(CHECK_TRY_ERROR()) wrapping with no #ifdef blocks. Addresses arthw's review: wrap all malloc/free in SYCL_CHECK for stack traces on failure, eliminate duplicated #ifdef/else patterns at 6 call sites (-29 lines net). Co-Authored-By: Claude Opus 4.6 (1M context) * SYCL: add Level Zero SDK to CI, fix device check and missed alloc paths Add Level Zero SDK installation to Ubuntu and Windows SYCL CI jobs so the Level Zero code path is compiled and tested in CI. Fix two bugs found during extended dual-GPU testing (no ONEAPI_DEVICE_SELECTOR set): - The Level Zero backend check was iterating all SYCL devices including CPU. The OpenCL CPU device caused Level Zero to be disabled for the GPUs, defeating the fix on multi-GPU systems. Added is_gpu() filter so only GPU devices are checked. - sycl_ext_malloc_device/sycl_ext_free (tensor reorder temp buffers) were still calling sycl::malloc/sycl::free directly, bypassing the Level Zero path. Routed through ggml_sycl_malloc_device/free_device for consistency with the other device memory call sites. Co-Authored-By: Claude Opus 4.6 (1M context) * SYCL: address arthw review feedback on Level Zero memory API structure - Move ggml_sycl_malloc_device to static function in ggml-sycl.cpp; only ggml_sycl_free_device (used by common.cpp) stays in common.cpp - Switch both helpers to use g_ggml_sycl_enable_level_zero global instead of per-call queue backend checks - Remove #ifdef wrapper from global definition; always declare at 0, add #else branch in init block so it stays 0 when L0 not compiled in - Update init loop comment to explain GPU-only device check - CMakeLists: message(STATUS) before the if block; align option wording AI-assisted implementation. Reviewed and tested on dual Intel Arc Pro B70 (32 GB each): test-backend-ops OK on both GPUs, single/dual-GPU Q4_K_M and Q8_0 bench correct, zeMemAllocDevice GTT delta confirmed <5 MiB per 4 GiB allocation (vs ~4 GiB shadow with sycl::malloc_device). Co-Authored-By: Claude Sonnet 4.6 * SYCL: remove unused cstdio/cstdlib includes from common.cpp Leftover from the deleted ggml_sycl_queue_supports_level_zero helper. Co-authored-by: Claude Sonnet 4.6 * Apply suggestions from code review Co-authored-by: Neo Zhang * SYCL: preserve Level Zero allocation path during early malloc * ci: fix Level Zero package conflict in Intel Docker build * ci: find Level Zero loader in oneAPI package step * ci: allow Windows SYCL package without Level Zero DLL --------- Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: Neo Zhang --- ggml/CMakeLists.txt | 1 + ggml/src/ggml-sycl/CMakeLists.txt | 29 +++++++++ ggml/src/ggml-sycl/common.cpp | 76 +++++++++++++++++++++++- ggml/src/ggml-sycl/common.hpp | 4 ++ ggml/src/ggml-sycl/ggml-sycl.cpp | 98 +++++++++++++++++++++++++------ 5 files changed, 187 insertions(+), 21 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 4e65cd68b4e..bdeca34bf9f 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -249,6 +249,7 @@ option(GGML_SYCL "ggml: use SYCL" option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF) option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON) option(GGML_SYCL_HOST_MEM_FALLBACK "ggml: allow host memory fallback in SYCL reorder (requires kernel 6.8+)" ON) +option(GGML_SYCL_SUPPORT_LEVEL_ZERO "ggml: use Level Zero API in SYCL backend" ON) option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON) set (GGML_SYCL_TARGET "INTEL" CACHE STRING "ggml: sycl target device") diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 8f44c6ed080..180de92202d 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -39,6 +39,18 @@ if (WIN32) set(CMAKE_CXX_COMPILER "icx") set(CMAKE_CXX_COMPILER_ID "IntelLLVM") endif() + # Level Zero SDK path for Windows (only when GGML_SYCL_SUPPORT_LEVEL_ZERO is enabled) + if(GGML_SYCL_SUPPORT_LEVEL_ZERO) + if(DEFINED ENV{LEVEL_ZERO_V1_SDK_PATH}) + set(LEVEL_ZERO_V1_SDK_PATH $ENV{LEVEL_ZERO_V1_SDK_PATH}) + if(EXISTS "${LEVEL_ZERO_V1_SDK_PATH}") + target_include_directories(ggml-sycl PRIVATE "${LEVEL_ZERO_V1_SDK_PATH}/include") + set(LEVEL_ZERO_V1_SDK_LIB_PATH "${LEVEL_ZERO_V1_SDK_PATH}/lib") + else() + message(WARNING "LEVEL_ZERO_V1_SDK_PATH set but folder not found: ${LEVEL_ZERO_V1_SDK_PATH}") + endif() + endif() + endif() endif() macro(detect_and_find_package package_name) @@ -93,6 +105,23 @@ endif() target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing") +message(STATUS "GGML_SYCL_SUPPORT_LEVEL_ZERO ${GGML_SYCL_SUPPORT_LEVEL_ZERO}") +if (GGML_SYCL_SUPPORT_LEVEL_ZERO) + # Link against Level Zero loader for direct device memory allocation. + # Avoids sycl::malloc_device triggering DMA-buf/TTM system RAM staging + # in the xe kernel driver during multi-GPU inference. + find_path(LEVEL_ZERO_INCLUDE_DIR level_zero/ze_api.h HINTS ${ONEAPI_ROOT}/include ${LEVEL_ZERO_V1_SDK_PATH}/include) + find_library(ZE_LOADER_LIB ze_loader HINTS ${ONEAPI_ROOT}/lib ${LEVEL_ZERO_V1_SDK_LIB_PATH} ENV LD_LIBRARY_PATH) + if(ZE_LOADER_LIB AND LEVEL_ZERO_INCLUDE_DIR) + target_link_libraries(ggml-sycl PRIVATE ${ZE_LOADER_LIB}) + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_SUPPORT_LEVEL_ZERO) + message(STATUS "Level Zero loader found: ${ZE_LOADER_LIB}") + message(STATUS "Level Zero headers found: ${LEVEL_ZERO_INCLUDE_DIR}") + else() + message(WARNING "Level Zero loader or headers not found, Level Zero support disabled") + endif() +endif() + # Link against oneDNN set(GGML_SYCL_DNNL 0) if(GGML_SYCL_DNN) diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp index 05fd5ef46c7..ae08abad81b 100644 --- a/ggml/src/ggml-sycl/common.cpp +++ b/ggml/src/ggml-sycl/common.cpp @@ -11,6 +11,10 @@ // #include "common.hpp" +#include +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +#include +#endif #include "ggml-backend-impl.h" #include "ggml-impl.h" @@ -55,6 +59,20 @@ bool gpu_has_xmx(sycl::device &dev) { return dev.has(sycl::aspect::ext_intel_matrix); } +static int ggml_sycl_get_env(const char *env_name, int default_val) { + char *user_device_string = getenv(env_name); + int user_number = default_val; + + unsigned n; + if (user_device_string != NULL && + sscanf(user_device_string, " %u", &n) == 1) { + user_number = (int)n; + } else { + user_number = default_val; + } + return user_number; +} + int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) { const int64_t max_range = std::numeric_limits::max(); int64_t sycl_down_blk_size = block_size; @@ -66,6 +84,61 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block return sycl_down_blk_size; } +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +static bool ggml_sycl_use_level_zero_device_alloc(sycl::queue &q) { + return ggml_sycl_get_env("GGML_SYCL_ENABLE_LEVEL_ZERO", 1) && + q.get_device().is_gpu() && + q.get_backend() == sycl::backend::ext_oneapi_level_zero; +} +#endif + +// Use Level Zero zeMemAllocDevice to avoid sycl::malloc_device triggering +// DMA-buf/TTM system RAM staging in the xe kernel driver during multi-GPU inference. +// The decision is made from the queue and runtime env because large buffers can be +// allocated before ggml_check_sycl() initializes g_ggml_sycl_enable_level_zero. +void * ggml_sycl_malloc_device(size_t size, sycl::queue &q) { +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + if (ggml_sycl_use_level_zero_device_alloc(q)) { + void *ptr = nullptr; + auto ze_ctx = sycl::get_native(q.get_context()); + auto ze_dev = sycl::get_native(q.get_device()); +#ifdef ZE_RELAXED_ALLOCATION_LIMITS_EXP_NAME + ze_relaxed_allocation_limits_exp_desc_t relaxed_desc = { + ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC, + nullptr, + ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE, + }; + ze_device_mem_alloc_desc_t alloc_desc = { + ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC, + &relaxed_desc, + 0, + 0, + }; +#else + ze_device_mem_alloc_desc_t alloc_desc = {ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC, nullptr, 0, 0}; +#endif + ze_result_t r = zeMemAllocDevice(ze_ctx, &alloc_desc, size, 64, ze_dev, &ptr); + if (r == ZE_RESULT_SUCCESS && ptr) { + return ptr; + } + return nullptr; + } +#endif + return sycl::malloc_device(size, q); +} + +void ggml_sycl_free_device(void *ptr, sycl::queue &q) { + if (!ptr) return; +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + if (ggml_sycl_use_level_zero_device_alloc(q)) { + auto ze_ctx = sycl::get_native(q.get_context()); + zeMemFree(ze_ctx, ptr); + return; + } +#endif + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, q))); +} + void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams) { for (int i = 0; i < ggml_sycl_info().device_count; ++i) { for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { @@ -75,8 +148,7 @@ void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector str } if (extra->data_device[i] != nullptr && streams.size()>0) { ggml_sycl_set_device(i); - SYCL_CHECK( - CHECK_TRY_ERROR(sycl::free(extra->data_device[i], *(streams[i])))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(extra->data_device[i], *(streams[i])))); } } delete extra; diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index eec36e8db9a..96bc1c98bd9 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -310,6 +310,10 @@ struct ggml_tensor_extra_gpu { optimize_feature optimized_feature; }; +extern int g_ggml_sycl_enable_level_zero; +void * ggml_sycl_malloc_device(size_t size, sycl::queue &q); +void ggml_sycl_free_device(void *ptr, sycl::queue &q); + void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector streams={}); namespace sycl_ex = sycl::ext::oneapi::experimental; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 57cc4ffb6f7..f5d10b56de0 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -30,6 +30,10 @@ #include #include +#include +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +#include +#endif #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC # include #endif @@ -68,6 +72,7 @@ int g_ggml_sycl_disable_graph = 0; int g_ggml_sycl_disable_dnn = 0; int g_ggml_sycl_prioritize_dmmv = 0; int g_ggml_sycl_use_async_mem_op = 0; +int g_ggml_sycl_enable_level_zero = 0; int g_ggml_sycl_enable_flash_attention = 1; @@ -223,6 +228,27 @@ static void ggml_check_sycl() try { g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", 1); +#else + g_ggml_sycl_enable_level_zero = 0; +#endif + if (g_ggml_sycl_enable_level_zero) { + // Verify all GPU devices use the Level Zero backend before enabling L0 APIs. + // Only check GPU devices; CPU devices use OpenCL and would otherwise + // disable Level Zero for the GPUs on systems without ONEAPI_DEVICE_SELECTOR set. + for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); i++) { + auto & q = dpct::dev_mgr::instance().get_device(i).default_queue(); + if (!q.get_device().is_gpu()) { + continue; + } + if (q.get_backend() != sycl::backend::ext_oneapi_level_zero) { + GGML_LOG_WARN("SYCL GPU device %d does not use Level Zero backend, disabling Level Zero memory API\n", i); + g_ggml_sycl_enable_level_zero = 0; + break; + } + } + } #ifdef SYCL_FLASH_ATTN g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1); @@ -253,6 +279,11 @@ static void ggml_check_sycl() try { #else GGML_LOG_INFO(" GGML_SYCL_DNNL: no\n"); #endif +#if defined(GGML_SYCL_SUPPORT_LEVEL_ZERO) + GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: no\n"); +#endif GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); @@ -262,6 +293,11 @@ static void ggml_check_sycl() try { #else GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n"); #endif +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + GGML_LOG_INFO(" GGML_SYCL_ENABLE_LEVEL_ZERO: %d\n", g_ggml_sycl_enable_level_zero); +#else + GGML_LOG_INFO(" GGML_SYCL_ENABLE_LEVEL_ZERO: Level Zero disabled by compile flag\n"); +#endif #if GGML_SYCL_DNNL GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn); #else @@ -371,7 +407,7 @@ struct ggml_backend_sycl_buffer_context { ~ggml_backend_sycl_buffer_context() { if (dev_ptr != nullptr) { ggml_sycl_set_device(device); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(dev_ptr, *stream))); } //release extra used by tensors @@ -504,8 +540,43 @@ catch (sycl::exception const &exc) { std::exit(1); } +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +static bool ggml_sycl_is_l0_discrete_gpu(sycl::queue &q) { + if (!q.get_device().is_gpu() || q.get_backend() != sycl::backend::ext_oneapi_level_zero) { + return false; + } + + ze_device_handle_t ze_dev = sycl::get_native(q.get_device()); + ze_device_properties_t props = {}; + props.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; + ze_result_t r = zeDeviceGetProperties(ze_dev, &props); + return r == ZE_RESULT_SUCCESS && !(props.flags & ZE_DEVICE_PROPERTY_FLAG_INTEGRATED); +} +#endif + static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst, const void *ptr_src, size_t size) { +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + // Use Level Zero direct copy for dGPU-to-dGPU transfers. + const bool l0_copy_supported = + ggml_sycl_is_l0_discrete_gpu(q_dst) && ggml_sycl_is_l0_discrete_gpu(q_src); + if (g_ggml_sycl_enable_level_zero && l0_copy_supported) { + auto ze_ctx = sycl::get_native(q_dst.get_context()); + auto ze_dev = sycl::get_native(q_dst.get_device()); + ze_command_queue_desc_t cq_desc = {ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC, nullptr, 0, 0, + 0, ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL}; + ze_command_list_handle_t cl; + ze_result_t r = zeCommandListCreateImmediate(ze_ctx, ze_dev, &cq_desc, &cl); + if (r == ZE_RESULT_SUCCESS) { + r = zeCommandListAppendMemoryCopy(cl, ptr_dst, ptr_src, size, nullptr, 0, nullptr); + zeCommandListDestroy(cl); + if (r == ZE_RESULT_SUCCESS) { + return; + } + } + } +#endif + // Host-staged copy char *host_buf = (char *)malloc(size); q_src.memcpy(host_buf, (const char *)ptr_src, size).wait(); q_dst.memcpy((char *)ptr_dst, host_buf, size).wait(); @@ -675,8 +746,7 @@ ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size = std::max(size, (size_t)1); // syclMalloc returns null for size 0 void * dev_ptr; - SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device( - size, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)ggml_sycl_malloc_device(size, *stream))); if (!dev_ptr) { GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device\n", __func__, size); return nullptr; @@ -917,18 +987,10 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); } - // FIXME: do not crash if SYCL Buffer alloc fails - // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first ggml_sycl_set_device(i); const queue_ptr stream = ctx->streams[i]; char * buf; - /* - DPCT1009:208: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string - was inserted. You need to rewrite this code. - */ - SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device( - size, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)ggml_sycl_malloc_device(size, *stream))); if (!buf) { char err_buf[1024]; snprintf(err_buf, 1023, "%s: can't allocate %lu Bytes of memory on device\n", __func__, size); @@ -1306,7 +1368,7 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { ggml_sycl_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(b.ptr, *qptr))); pool_size -= b.size; } } @@ -1374,9 +1436,7 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { void * ptr; size_t look_ahead_size = (size_t) (1.05 * size); - SYCL_CHECK( - CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device( - look_ahead_size, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *)ggml_sycl_malloc_device(look_ahead_size, *qptr))); if (!ptr) { GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device/GPU\n", __func__, look_ahead_size); return nullptr; @@ -1404,7 +1464,7 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { } } GGML_LOG_WARN("WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n"); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(ptr, *qptr))); pool_size -= size; } }; @@ -3405,7 +3465,7 @@ static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) // If async allocation extension is not available, use_async should always be false. GGML_ASSERT(!use_async); #endif - return sycl::malloc(size, *stream, sycl::usm::alloc::device); + return ggml_sycl_malloc_device(size, *stream); } static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { @@ -3419,7 +3479,7 @@ static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { // If async allocation extension is not available, use_async should always be false. GGML_ASSERT(!use_async); #endif - sycl::free(ptr, *stream); + ggml_sycl_free_device(ptr, *stream); } // RAII wrapper for temporary reorder buffers with optional host memory fallback. From 97ba44338fcda566e10644057135f08ea820ff60 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 14 May 2026 10:36:54 +0200 Subject: [PATCH 631/831] vulkan: fix matmul integer pipeline selection (llama/23005) * vulkan: fix matmul integer pipeline selection * gate pipeline creation with the right bools --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a0a556206d5..8c4cf9ef1db 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3954,13 +3954,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ - if (device->mul_mat ## ID ## _l[TYPE]) { \ + if (device->mul_mat ## ID ## _l_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ - if (device->mul_mat ## ID ## _m[TYPE]) { \ + if (device->mul_mat ## ID ## _m_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ - if (device->mul_mat ## ID ## _s[TYPE]) { \ + if (device->mul_mat ## ID ## _s_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ @@ -4131,11 +4131,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - if (device->mul_mat ## ID ## _l[TYPE]) \ + if (device->mul_mat ## ID ## _l_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _m[TYPE]) \ + if (device->mul_mat ## ID ## _m_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _s[TYPE]) \ + if (device->mul_mat ## ID ## _s_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); @@ -5716,12 +5716,12 @@ static vk_device ggml_vk_get_device(size_t idx) { break; } - device->mul_mat_l_int[i] = true; - device->mul_mat_m_int[i] = true; - device->mul_mat_s_int[i] = true; - device->mul_mat_id_l_int[i] = true; - device->mul_mat_id_m_int[i] = true; - device->mul_mat_id_s_int[i] = true; + device->mul_mat_l_int[i] = device->mul_mat_l[i]; + device->mul_mat_m_int[i] = device->mul_mat_m[i]; + device->mul_mat_s_int[i] = device->mul_mat_s[i]; + device->mul_mat_id_l_int[i] = device->mul_mat_id_l[i]; + device->mul_mat_id_m_int[i] = device->mul_mat_id_m[i]; + device->mul_mat_id_s_int[i] = device->mul_mat_id_s[i]; } From f0223903aa33894bcc6d7f23b7b7466dc95c33dd Mon Sep 17 00:00:00 2001 From: alex-spacemit Date: Thu, 14 May 2026 17:39:30 +0800 Subject: [PATCH 632/831] ggml-cpu: Add IME2 Instruction Support for the SpacemiT Backend (llama/22863) --- ggml/src/ggml-cpu/CMakeLists.txt | 13 + ggml/src/ggml-cpu/cmake/FindSMTIME.cmake | 32 + ggml/src/ggml-cpu/ggml-cpu.c | 12 + ggml/src/ggml-cpu/spacemit/ime.cpp | 2089 ++++-- ggml/src/ggml-cpu/spacemit/ime.h | 8 + ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp | 3363 ++-------- ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp | 5768 +++++++++++++++++ ggml/src/ggml-cpu/spacemit/ime_env.cpp | 320 + ggml/src/ggml-cpu/spacemit/ime_env.h | 55 + ggml/src/ggml-cpu/spacemit/ime_kernels.h | 201 +- ggml/src/ggml-cpu/spacemit/repack.cpp | 1795 +++++ ggml/src/ggml-cpu/spacemit/repack.h | 14 + ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp | 3178 +++++++++ ggml/src/ggml-cpu/spacemit/rvv_kernels.h | 95 + ggml/src/ggml-cpu/spacemit/spine_barrier.h | 34 + ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp | 760 +++ ggml/src/ggml-cpu/spacemit/spine_mem_pool.h | 32 + ggml/src/ggml-cpu/spacemit/spine_tcm.h | 409 ++ 18 files changed, 14706 insertions(+), 3472 deletions(-) create mode 100644 ggml/src/ggml-cpu/cmake/FindSMTIME.cmake create mode 100644 ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/ime_env.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/ime_env.h create mode 100644 ggml/src/ggml-cpu/spacemit/repack.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/repack.h create mode 100644 ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/rvv_kernels.h create mode 100644 ggml/src/ggml-cpu/spacemit/spine_barrier.h create mode 100644 ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/spine_mem_pool.h create mode 100644 ggml/src/ggml-cpu/spacemit/spine_tcm.h diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 869c7b238bf..f3eccff7d72 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -450,12 +450,22 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/arch/riscv/repack.cpp ) if (GGML_CPU_RISCV64_SPACEMIT) + include(ggml-cpu/cmake/FindSMTIME.cmake) target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC}) list(APPEND GGML_CPU_SOURCES ggml-cpu/spacemit/ime.cpp ggml-cpu/spacemit/ime.h + ggml-cpu/spacemit/spine_mem_pool.cpp + ggml-cpu/spacemit/spine_mem_pool.h + ggml-cpu/spacemit/repack.cpp + ggml-cpu/spacemit/repack.h + ggml-cpu/spacemit/ime_env.cpp + ggml-cpu/spacemit/ime_env.h ggml-cpu/spacemit/ime1_kernels.cpp + ggml-cpu/spacemit/ime2_kernels.cpp ggml-cpu/spacemit/ime_kernels.h + ggml-cpu/spacemit/rvv_kernels.cpp + ggml-cpu/spacemit/rvv_kernels.h ) endif() if(NOT GGML_CPU_ALL_VARIANTS) @@ -485,6 +495,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_RV_ZIHINTPAUSE) string(APPEND MARCH_STR "_zihintpause") endif() + if (GGML_RV_ZBA) + string(APPEND MARCH_STR "_zba") + endif() if (GGML_CPU_RISCV64_SPACEMIT) # `xsmtvdotii' is only required for GCC >= 15. if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND diff --git a/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake b/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake new file mode 100644 index 00000000000..c8a4d4b4ec9 --- /dev/null +++ b/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake @@ -0,0 +1,32 @@ +include(CheckCSourceRuns) + +if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(riscv)" AND GGML_CPU_RISCV64_SPACEMIT) + set(SMT_MARCH_STR "-march=rv64gcv_zfh_zvfh_zba_zicbop") + if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND + CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 15) + string(APPEND SMT_MARCH_STR "_xsmtvdotii") + endif() + set(CMAKE_REQUIRED_FLAGS "${SMT_MARCH_STR}") + + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_IME1) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1, i4\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S4) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1, i8\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S8) + check_c_source_compiles("int main() {__asm__ volatile(\"vfwmadot v2, v0, v1, fp16\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFWMADOT_FP16) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot.hp v2, v0, v1, v0, 0, i4\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFMADOT_S4) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot.hp v2, v0, v1, v0, 0, i8\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFMADOT_S8) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot1 v2, v0, v1\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOTN) + check_c_source_compiles("int main() {__asm__ volatile(\"vpack.vv v2, v0, v1, 2\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VPACK) + check_c_source_compiles("int main() {__asm__ volatile(\"vnspack.vv v2, v0, v1, 2\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VNPACK) + unset(CMAKE_REQUIRED_FLAGS) + + list(APPEND RISCV64_SPACEMIT_IME_SPEC "") + if (SPACEMIT_RISCV_COMPILER_SUPPORT_IME1) + set(RISCV64_SPACEMIT_IME_SPEC "RISCV64_SPACEMIT_IME1") + endif() + + if (SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S4 AND SPACEMIT_RISCV_COMPILER_SUPPORT_VPACK AND SPACEMIT_RISCV_COMPILER_SUPPORT_VNPACK) + list(APPEND RISCV64_SPACEMIT_IME_SPEC "RISCV64_SPACEMIT_IME2") + endif() + + message("RISCV64_SPACEMIT_IME_SPEC: ${RISCV64_SPACEMIT_IME_SPEC}") +endif() diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 8b7acafdaa8..7b05edf6b75 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -50,6 +50,10 @@ #include "llamafile/sgemm.h" #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT +# include "spacemit/ime.h" +#endif + // Note: once we move threading into a separate C++ file // will use std::hardware_destructive_interference_size instead of hardcoding it here // and we'll use C++ attribute syntax. @@ -3011,7 +3015,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { const struct ggml_cgraph * cgraph = tp->cgraph; const struct ggml_cplan * cplan = tp->cplan; +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(state->ith); +#else set_numa_thread_affinity(state->ith); +#endif struct ggml_compute_params params = { /*.ith =*/ state->ith, @@ -3068,6 +3076,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { ggml_barrier(state->threadpool); +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(state->ith); +#endif + return 0; } diff --git a/ggml/src/ggml-cpu/spacemit/ime.cpp b/ggml/src/ggml-cpu/spacemit/ime.cpp index 91fe1925eaa..9563ea3e4bd 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime.cpp @@ -3,19 +3,32 @@ #include "ime.h" +#include "binary-ops.h" +#include "common.h" #include "ggml-backend-impl.h" #include "ggml-common.h" #include "ggml-cpu.h" +#include "ime_env.h" #include "ime_kernels.h" +#include "ops.h" +#include "repack.h" +#include "rvv_kernels.h" +#include "spine_mem_pool.h" #include "traits.h" +#include "vec.h" + +#include +#include +#include #include +#include #include +#include #include #include // for GGML_ASSERT #include #include - // clang-format off #if defined(__riscv) @@ -25,13 +38,17 @@ #include #endif -#if !defined(__riscv_zfh) -#error "riscv zfh extension not enabled" +#if !defined(__riscv_zfh) || !defined(__riscv_zvfh) +#error "riscv zfh extension not enabled, GGML_RV_ZFH and GGML_RV_ZVFH must be defined to 1" #endif -#if defined(RISCV64_SPACEMIT_IME1) +#if !defined(__riscv_zba) +#error "riscv zba extension not enabled, GGML_RV_ZBA must be defined to 1" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) || defined(RISCV64_SPACEMIT_IME2) #else -#error "RISCV64_SPACEMIT_IME1 not defined" +#error "RISCV64_SPACEMIT_IME1 or RISCV64_SPACEMIT_IME2 not defined" #endif #else @@ -46,382 +63,490 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #endif -#if defined(RISCV64_SPACEMIT_IME1) -#define QGEMM_STRIDEN_THREAD_ALIGN 16 -#else -#define QGEMM_STRIDEN_THREAD_ALIGN 32 -#endif - // clang-format on -struct qnbitgemm_spacemit_ime_args { - const float * a_ptr = nullptr; - size_t lda = 0; - const std::byte * packed_quant_b_data = nullptr; - const float * quant_b_scale = nullptr; - const void * quant_b_zp = nullptr; - const float * quant_b_blksum = nullptr; - const float * bias = nullptr; - float * c_ptr = nullptr; - size_t ldc = 0; -}; - -constexpr size_t div_round_up(size_t up, size_t down) { - return (up + down - 1) / down; -} - -constexpr size_t q8_blk_size(size_t blk_len) { - const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t); - // Currently, the strictest alignment requirement of a block is for a float. - // Ensure contiguous blocks are suitably aligned. - assert(blk_size % alignof(float) == 0); - return blk_size; +extern "C" { +extern void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value); +extern int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value); } namespace ggml::cpu::riscv64_spacemit { -const int num_ai_cores = std::thread::hardware_concurrency() / 2; - -} // namespace ggml::cpu::riscv64_spacemit +struct TLSContext { + int cpu_id{ -1 }; + cpu_set_t cpuset; + void * tcm_buffer{ nullptr }; + size_t tcm_buffer_size{ 0 }; +}; -static void sqnbitgemm_spacemit_ime_i8i4(const size_t blk_len, - const size_t gemm_k, - const qnbitgemm_spacemit_ime_args * gemm_args, - void * const per_gemm_ws, - const size_t m_start, - const size_t m_count, - const size_t n_start, - const size_t n_count) { - constexpr size_t scale_stride = sizeof(uint16_t); - constexpr size_t blk_bitwidth = 4; +thread_local TLSContext tls_context; + +template constexpr size_t get_repacked_block_type_size() { + if constexpr (std::is_same_v || std::is_same_v) { + return sizeof(block_q8_0); + } else if constexpr (std::is_same_v) { + return sizeof(block_q4_0) * INTER_SIZE / QK4_0; + } else if constexpr (std::is_same_v || std::is_same_v) { + return (sizeof(block_q4_0) + sizeof(uint8_t)) * INTER_SIZE / QK4_1; + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q2_k<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q3_k<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_mxfp4<1>); + } else if constexpr (std::is_same_v || std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q5_1<1>); + } else if constexpr (std::is_same_v) { + return sizeof(spacemit_kernels::nrow_block_q5_0<1>); + } else { + assert(false); + return 0; + } +} - const size_t k_blks = div_round_up(gemm_k, blk_len); +template constexpr bool block_type_has_zp() { + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { + return false; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) { + return true; + } else { + assert(false); + return false; + } +} - const size_t lda = k_blks * q8_blk_size(blk_len); - const size_t ldc = gemm_args->ldc; - const size_t ldb = k_blks * (blk_len * blk_bitwidth / 8); - const std::byte * quant_a_ptr = static_cast(per_gemm_ws) + m_start * lda; +class tensor_traits_base : public ggml::cpu::tensor_traits { + public: + virtual int repack(ggml_tensor * t, const void * data, size_t data_size) = 0; +}; - const size_t zero_point_stride = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0; - const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride); - const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride; +template class tensor_traits : public tensor_traits_base { + bool work_size(int /* n_threads */, const ggml_tensor * op, size_t & size) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + { + int64_t src1_nelements = ggml_nelements(op->src[1]); + + if constexpr (std::is_same_v || std::is_same_v) { + size = + spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K); + } else if constexpr (INTER_SIZE == QK4_0) { + size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) * + spacemit_kernels::q8_blk_size(QK4_0, true); + } else if constexpr (INTER_SIZE == 256) { + size = spacemit_kernels::div_round_up(src1_nelements, 256) * + spacemit_kernels::q8_hp_blk_size(256, true, true); + } else { + GGML_ABORT("unsupported block type"); + } - float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start; + size = GGML_PAD(size, sizeof(int64_t)); - size_t count_n = 0; - const size_t compute_block_count_n = m_count == 1 ? n_count : 16; - for (size_t n = 0; n < n_count; n += count_n) { - count_n = std::min(n_count - n, compute_block_count_n); + return true; + } + case GGML_OP_MUL_MAT_ID: + { + int64_t src1_nelements = ggml_nelements(op->src[1]); + + if constexpr (std::is_same_v || std::is_same_v) { + size = + spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K); + } else if constexpr (INTER_SIZE == QK4_0) { + size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) * + spacemit_kernels::q8_blk_size(QK4_0, true); + } else if constexpr (INTER_SIZE == 256) { + size = spacemit_kernels::div_round_up(src1_nelements, 256) * + spacemit_kernels::q8_hp_blk_size(256, true, true); + } else { + GGML_ABORT("unsupported block type"); + } - const std::byte * a_row = quant_a_ptr; - const std::byte * b_col = packed_quant_b_data + n * packed_b_stride; - const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr; - float * c_blk = c_ptr + n; + size = GGML_PAD(size, sizeof(int64_t)); - int32_t rows_remaining = m_count; + const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert + const int64_t ne12 = op->src[1]->ne[2]; // n_tokens - while (rows_remaining > 0) { - const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4( - blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr, - scale_stride); + const size_t sizeof_mmid_row_mapping = sizeof(int64_t); + size += sizeof_mmid_row_mapping * ne02 * (ne12 + 1) + (ne02 + 1) * sizeof(int64_t); - c_blk += rows_handled * ldc; - a_row += rows_handled * lda; + size = GGML_PAD(size, sizeof(int64_t)); - rows_remaining -= rows_handled; + return true; + } + default: + // GGML_ABORT("fatal error"); + break; } + return false; } -} -template constexpr int QK_0() { - if constexpr (K == 4) { - return QK4_0; - } - if constexpr (K == 8) { - return QK8_0; + bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + switch (op->src[0]->type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + //case GGML_TYPE_MXFP4: + forward_mul_mat(params, op); + return true; + default: + // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT"); + return false; + } + break; + case GGML_OP_MUL_MAT_ID: + switch (op->src[0]->type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + //case GGML_TYPE_MXFP4: + forward_mul_mat_id(params, op); + return true; + default: + // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT_ID"); + return false; + } + break; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; } - return -1; -} -template struct block { - ggml_half d[N]; // deltas for N qK_0 blocks - uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks -}; + void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) { + constexpr size_t a_blk_len = INTER_SIZE; + constexpr size_t b_blk_len = INTER_SIZE; -template struct block_with_zp { - ggml_half d[N]; // deltas for N qK_1 blocks - uint8_t zp[N]; // zero points for N qK_1 blocks - uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks -}; + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; -// control size -static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); -static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), - "wrong block_with_zp<4,16> size/padding"); -static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + GGML_TENSOR_BINARY_OP_LOCALS -using block_q4_0x16 = block<4, 16>; -using block_q4_1x16 = block_with_zp<4, 16>; -using block_q8_0x16 = block<8, 16>; + int ith = params->ith; + int nth = params->nth; -static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { - block_q4_0x16 out; - GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + [[maybe_unused]] const enum ggml_type type = src0->type; - for (int i = 0; i < 16; i++) { - out.d[i] = in[i].d; - } + void * w_data = (void *) src0->data; + const float * feature = (const float *) src1->data; + float * output = (float *) dst->data; - for (int i = 0; i < 16; i++) { - // [0, 15], in.d & 0x0F - for (int j = 0; j < QK4_0 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b0 b8] ......... [b7 b15] - out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + const int64_t gemm_m = ne11 * ne12 * ne13; + const int64_t gemm_k = ne10; + const int64_t gemm_n = ne01; + + spacemit_kernels::quantize_a_row_def quantize_a_row_i8; + spacemit_kernels::quantize_a_row_def quantize_a_4row_i8; + spacemit_kernels::gemm_kernel_quantize_def gemm_kernel; + bool set_kernel_impl = false; + + int64_t block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len); + +#if defined(RISCV64_SPACEMIT_IME2) + if (!set_kernel_impl && (global_spine_env_info.use_ime2)) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + + if constexpr (std::is_same_v || std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if constexpr (INTER_SIZE == 256) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8_hp; + block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true); + set_kernel_impl = true; + } else { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + set_kernel_impl = true; + } + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5; + set_kernel_impl = true; + } } - } +#endif - for (int i = 0; i < 16; i++) { - // [16, 31], in.d & 0xF0 - for (int j = 0; j < QK4_0 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b16 b24] ......... [b23 b31] - out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); +#if defined(RISCV64_SPACEMIT_IME1) + if (!set_kernel_impl && (global_spine_env_info.use_ime1)) { + quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::ime1::quantize_a_4row_i8; + + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4; + set_kernel_impl = true; + } + } +#endif + if (!set_kernel_impl) { + GGML_ABORT("no kernel implementation found for the block type"); } - } - return out; -} + const int64_t a_k_blks = spacemit_kernels::div_round_up(gemm_k, a_blk_len); + const int64_t b_k_blks = spacemit_kernels::div_round_up(gemm_k, b_blk_len); -static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { - block_q4_1x16 out; - GGML_ASSERT(QK4_1 / blck_size_interleave == 2); - - for (int i = 0; i < 16; i++) { - float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); - float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); - float mid = -std::nearbyintf(m / d); - mid = std::min(15.0f, std::max(0.0f, mid)); - out.d[i] = GGML_FP32_TO_FP16(d); - out.zp[i] = static_cast(mid); - } + const int64_t row_stride_a = a_k_blks * block_stride_a; + const int64_t gemm_workspace_size = GGML_PAD(gemm_m * row_stride_a, alignof(int64_t)); - for (int i = 0; i < 16; i++) { - // [0, 15], in.d & 0x0F - for (int j = 0; j < QK4_1 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b0 b8] ......... [b7 b15] - out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + if (ith == 0 && params->wsize < gemm_workspace_size) { + GGML_ABORT("wsize less than gemm_workspace_size"); } - } - for (int i = 0; i < 16; i++) { - // [16, 31], in.d & 0xF0 - for (int j = 0; j < QK4_1 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b16 b24] ......... [b23 b31] - out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); - } - } + uintptr_t ws_ptr = reinterpret_cast(params->wdata); - return out; -} + void * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer; + const int64_t tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size; -static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_0); - GGML_ASSERT(interleave_block == 16); + auto * quant_a_buffer = reinterpret_cast(ws_ptr); - constexpr int nrows_interleaved = 16; + constexpr int64_t row_align = 4; + const int64_t row_blks = spacemit_kernels::div_round_up(gemm_m, row_align); - block_q4_0x16 * dst = (block_q4_0x16 *) t->data; - const block_q4_0 * src = (const block_q4_0 *) data; - block_q4_0 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_0; + const int64_t row_stride_b = b_k_blks * get_repacked_block_type_size(); + const int64_t per_mb_rows_wsize = row_align * row_stride_a; + const int64_t per_nb_cols_wsize = NB_COLS * row_stride_b; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + const int64_t barrier_idx = static_cast(ith / 2); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { - return -1; - } + GGML_ASSERT(global_spine_env_info.init_barrier != nullptr); + GGML_ASSERT(barrier_idx < spine_init_barrier_count); + spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx]; - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; + if (gemm_m == 1) { + int task_per_thread = spacemit_kernels::div_round_up(a_k_blks, nth); + int a_blk_start = ith * task_per_thread; + int a_blk_end = std::min(a_blk_start + task_per_thread, (int) a_k_blks); + if (a_blk_start < a_blk_end) { + quantize_a_row_i8(a_blk_len, feature + a_blk_start * a_blk_len, (a_blk_end - a_blk_start) * a_blk_len, + quant_a_buffer + a_blk_start * block_stride_a); + } + } else { + int task_per_thread = spacemit_kernels::div_round_up(row_blks, nth); + int m_row_blk_start = ith * task_per_thread; + int m_row_blk_end = std::min(m_row_blk_start + task_per_thread, (int) row_blks); + for (int m_row_blk = m_row_blk_start; m_row_blk < m_row_blk_end; m_row_blk++) { + int m_idx = m_row_blk * row_align; + int rows_tobe_handled = (gemm_m - m_idx) > row_align ? row_align : (gemm_m - m_idx); + + if (rows_tobe_handled == row_align && quantize_a_4row_i8 != nullptr) { + const float * a_row_ptr = feature + m_idx * gemm_k; + auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a; + quantize_a_4row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr); + } else { + while (rows_tobe_handled) { + const float * a_row_ptr = feature + m_idx * gemm_k; + auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a; + quantize_a_row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr); + rows_tobe_handled -= 1; + m_idx += 1; + } + } } - *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); } - src += nrows_interleaved * nblocks; - } - return 0; - GGML_UNUSED(data_size); -} + ggml_barrier(params->threadpool); -static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_1); - GGML_ASSERT(interleave_block == 16); + const int64_t gemm_m_stride = gemm_n / gemm_m > 64 ? gemm_m : 16; + const int64_t gemm_m_blocked = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride); + const int64_t max_gemm_n_stride = spacemit_kernels::div_round_up(gemm_n * gemm_m_blocked, nth); - constexpr int nrows_interleaved = 16; + int64_t gemm_n_stride = gemm_n; + if (max_gemm_n_stride < gemm_n) { + gemm_n_stride = + std::min(gemm_n_stride, spacemit_kernels::div_round_up(max_gemm_n_stride, NB_COLS) * NB_COLS); + } - block_q4_1x16 * dst = (block_q4_1x16 *) t->data; - const block_q4_1 * src = (const block_q4_1 *) data; - block_q4_1 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_1; + if (gemm_n_stride == gemm_n && tcm_buffer != nullptr && per_mb_rows_wsize <= tcm_buffer_size) { + for (int64_t m_start = ith * row_align; m_start < gemm_m; m_start += row_align * nth) { + uint8_t * b_col = reinterpret_cast(w_data); + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + int64_t m_row_real = std::min(gemm_m - m_start, row_align); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { - return -1; - } + spacemit_kernels::rvv::memcpy1d(tcm_buffer, quant_a_buffer + m_start * row_stride_a, + m_row_real * row_stride_a); - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; + int64_t n_blk_real = 0; + for (int64_t ni = 0; ni < gemm_n; ni += n_blk_real, b_col += n_blk_real * row_stride_b) { + n_blk_real = std::min(gemm_n - ni, (int64_t) NB_COLS); + + uint8_t * a_row_ptr = (uint8_t *) tcm_buffer; + float * c_blk = output + m_start * gemm_n + ni; + + int32_t rows_remaining = m_row_real; + + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row_ptr, b_col, b_col_zp, c_blk, rows_remaining, + n_blk_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row_ptr += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } + } } - *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; + } else if (tcm_buffer != nullptr && per_nb_cols_wsize <= tcm_buffer_size) { + uint8_t * a_row = quant_a_buffer; + uint8_t * b_col = reinterpret_cast(tcm_buffer); + if ((gemm_workspace_size + per_nb_cols_wsize) <= tcm_buffer_size) { + a_row = (uint8_t *) tcm_buffer; + b_col = reinterpret_cast(tcm_buffer) + gemm_workspace_size; + } + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; - GGML_UNUSED(data_size); -} + int64_t ni = ith * NB_COLS; + int64_t nb_real = std::min(gemm_n - ni, NB_COLS); -static inline void get_scale_min_k4(int j, - const uint8_t * GGML_RESTRICT q, - uint8_t * GGML_RESTRICT d, - uint8_t * GGML_RESTRICT m) { - if (j < 4) { - *d = q[j] & 63; - *m = q[j + 4] & 63; - } else { - *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); - *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); - } -} + if (ith % 2 == 0 && nb_real > 0) { + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + ni * row_stride_b, + nb_real * row_stride_b); + if (a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + } + } -static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 16); - GGML_ASSERT(QK_K / QK4_1 == 8); + spine_barrier_wait(cur_barrier); - constexpr int nrows_interleaved = 16; + if (ith % 2 != 0 && nb_real > 0) { + if (a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + } + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + ni * row_stride_b, + nb_real * row_stride_b); + } - block_q4_1x16 * dst = (block_q4_1x16 *) t->data; - const block_q4_K * src = (const block_q4_K *) data; - block_q4_1 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK_K; + for (; ni < gemm_n; ni += NB_COLS * nth) { + int64_t rows_remaining = gemm_m; + float * c_blk = output + ni; + auto * a_row_cur = a_row; - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { - return -1; - } + if (ith % 2 != 0) { + spine_barrier_wait(cur_barrier); + } - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int j = 0; j < 8; j++) { - for (int i = 0; i < nrows_interleaved; i++) { - uint8_t sc, m; - const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); - const float min = - GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); - get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); - const float d1 = d * sc; - const float m1 = min * m; - - dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); - dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); - // src -> [b0, b32] [b1, b33] ... [b31, b63] - // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] - const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; - if (j % 2 == 0) { - for (int ii = 0; ii < 16; ii++) { - dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); - } - } else { - for (int ii = 0; ii < 16; ii++) { - dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); - } - } + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row_cur, b_col, b_col_zp, c_blk, rows_remaining, + nb_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row_cur += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } + + if (ith % 2 == 0) { + spine_barrier_wait(cur_barrier); + } + + const int64_t next_ni = ni + NB_COLS * nth; + if (next_ni < gemm_n) { + nb_real = std::min(gemm_n - next_ni, NB_COLS); + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(w_data) + next_ni * row_stride_b, + nb_real * row_stride_b); } - *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); } - } - src += nrows_interleaved * nblocks; - } - return 0; + } else { + const int64_t task_count_m = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride); + const int64_t task_count_n = spacemit_kernels::div_round_up(gemm_n, gemm_n_stride); - GGML_UNUSED(data_size); -} + int64_t task_count = task_count_m * task_count_n; + int64_t task_per_thread = (task_count + nth - 1) / nth; + int64_t start = ith * task_per_thread; + int64_t end = std::min((ith + 1) * task_per_thread, task_count); + for (int64_t compute_idx = start; compute_idx < end; compute_idx++) { + const auto tid_n = compute_idx / task_count_m; + const auto tid_m = compute_idx % task_count_m; -namespace ggml::cpu::riscv64_spacemit { + const int64_t m_start = tid_m * gemm_m_stride; + const int64_t m_count = std::min(gemm_m - m_start, (int64_t) gemm_m_stride); -template -int repack(struct ggml_tensor *, const void *, size_t); + const int64_t n_start = tid_n * gemm_n_stride; + const int64_t n_count = std::min(gemm_n - n_start, (int64_t) gemm_n_stride); -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); -} + const int64_t n_blk = m_count == 1 ? n_count : NB_COLS; -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); -} + uint8_t * b_col = reinterpret_cast(w_data) + n_start * row_stride_b; + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; -template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); -} + int64_t n_blk_real = 0; + for (int64_t ni = 0; ni < n_count; ni += n_blk_real, b_col += n_blk_real * row_stride_b) { + n_blk_real = std::min(n_count - ni, n_blk); -class tensor_traits_base : public ggml::cpu::tensor_traits { - public: - virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; -}; + uint8_t * a_row = quant_a_buffer + m_start * row_stride_a; -template class tensor_traits : public tensor_traits_base { - bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4; - size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float)); - return true; - default: - // GGML_ABORT("fatal error"); - break; - } - return false; - } + float * c_blk = output + m_start * gemm_n + n_start + ni; - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - if (op->src[0]->type == GGML_TYPE_Q4_0 || // - op->src[0]->type == GGML_TYPE_Q4_1 || // - op->src[0]->type == GGML_TYPE_Q4_K) { - forward_mul_mat_q4(params, op); - return true; + int64_t rows_remaining = m_count; + + uint8_t * b_col_cur = b_col; + uint8_t * b_col_zp_cur = b_col_zp; + + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row, b_col_cur, b_col_zp_cur, c_blk, + rows_remaining, n_blk_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } } - default: - // GGML_ABORT("fatal error"); - break; + } } - return false; } - void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) { + void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) { + constexpr size_t a_blk_len = INTER_SIZE; + constexpr size_t b_blk_len = INTER_SIZE; + const ggml_tensor * src0 = op->src[0]; const ggml_tensor * src1 = op->src[1]; + const ggml_tensor * ids = op->src[2]; ggml_tensor * dst = op; GGML_TENSOR_BINARY_OP_LOCALS @@ -429,133 +554,381 @@ template class tensor_ int ith = params->ith; int nth = params->nth; - [[maybe_unused]] const enum ggml_type type = src0->type; + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert + + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; + + spacemit_kernels::quantize_a_row_def quantize_a_row_i8; + spacemit_kernels::gemm_kernel_quantize_def gemm_kernel; + spacemit_kernels::moe_gemm_kernel_quantize_def moe_gemm_kernel_m2; + bool set_kernel_impl = false; + size_t block_stride_a = spacemit_kernels::q8_blk_size(QK4_0); + +#if defined(RISCV64_SPACEMIT_IME2) + if (!set_kernel_impl && (global_spine_env_info.use_ime2)) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(QK4_0, true); + + if constexpr (std::is_same_v || std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + if constexpr (INTER_SIZE == 256) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp; + block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true); + set_kernel_impl = true; + } else { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i4; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + set_kernel_impl = true; + } + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8mxfp4; + set_kernel_impl = true; + } else if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i5; + set_kernel_impl = true; + } + } +#endif - void * w_data = (void *) src0->data; - const float * feature = (const float *) src1->data; - float * output = (float *) dst->data; +#if defined(RISCV64_SPACEMIT_IME1) + if (!set_kernel_impl && (global_spine_env_info.use_ime1)) { + quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8; + + if constexpr (std::is_same_v || std::is_same_v || + std::is_same_v) { + gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4; + set_kernel_impl = true; + } + } +#endif + if (!set_kernel_impl) { + GGML_ABORT("no kernel implementation found for the block type"); + } - const size_t batch_feature = ne12 * ne13; - [[maybe_unused]] const size_t batch_weight = ne02 * ne03; - const size_t gemm_m = ne11; - const size_t gemm_k = ne10; - const size_t gemm_n = ne01; + const size_t a_k_blks = spacemit_kernels::div_round_up(ne10, a_blk_len); + const size_t b_k_blks = spacemit_kernels::div_round_up(ne10, b_blk_len); - GGML_ASSERT(batch_weight == 1); + const size_t nbw1 = a_k_blks * block_stride_a; + const size_t nbw2 = ne11 * nbw1; + const size_t nbw3 = nbw2 * ne12; + const size_t gemm_workspace_size = GGML_PAD(nbw3, alignof(int64_t)); - const size_t block_count_k = div_round_up(gemm_k, QK4_0); - const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0); - const size_t per_gemm_workspace_stride = - div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t); - const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride; - const size_t desired_wsize = gemm_workspace_size + alignof(uint64_t) - 1; + const uintptr_t ws_ptr = reinterpret_cast(params->wdata); + auto * quant_a_buffer = reinterpret_cast(ws_ptr); - if (ith == 0 && params->wsize < desired_wsize) { - throw std::runtime_error("wsize less than desired_wsize"); + if (ne11 == 1) { + for (int64_t ii = ith; ii < ne12 * a_k_blks; ii += nth) { + int64_t i12 = ii / a_k_blks; + int64_t ak_blk_id = ii % a_k_blks; + quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12) + ak_blk_id * a_blk_len, + a_blk_len, quant_a_buffer + i12 * nbw2 + ak_blk_id * block_stride_a); + } + } else { + for (int64_t ii = ith; ii < ne12 * ne11; ii += nth) { + int64_t i12 = ii / ne11; + int64_t i11 = ii % ne11; + quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12 + i11 * nb11), ne10, + quant_a_buffer + i12 * nbw2 + i11 * nbw1); + } } - std::vector qnbitgemm_args(batch_feature); +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) *ne12 + (i1)] - for (size_t i = 0; i < batch_feature; i++) { - qnbitgemm_args[i].a_ptr = feature + gemm_m * gemm_k * i; - qnbitgemm_args[i].lda = gemm_k; - qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data; - qnbitgemm_args[i].quant_b_scale = nullptr; + int64_t * matrix_row_counts = (int64_t *) (ws_ptr + gemm_workspace_size); + int32_t * valid_ep_count = (int32_t *) (matrix_row_counts + n_as); + int32_t * valid_act_count = (int32_t *) (valid_ep_count + 1); + int64_t * valid_matrix_row_counts = (int64_t *) (valid_act_count + 1); + mmid_row_mapping * matrix_rows = (mmid_row_mapping *) (valid_matrix_row_counts + n_as); - if constexpr (std::is_same_v) { - qnbitgemm_args[i].quant_b_zp = nullptr; - } else { - qnbitgemm_args[i].quant_b_zp = w_data; + if (ith == 0) { + // initialize matrix_row_counts + memset(matrix_row_counts, 0, n_as * sizeof(int64_t)); + + // group rows by src0 matrix + for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int32_t id = 0; id < n_ids; ++id) { + const int32_t i02 = + *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 }; + matrix_row_counts[i02] += 1; + } } - qnbitgemm_args[i].bias = nullptr; - qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i; - qnbitgemm_args[i].ldc = gemm_n; + int32_t valid_ep_count_t = 0; + int32_t valid_act_count_t = 0; + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) { + continue; + } + valid_matrix_row_counts[valid_ep_count_t] = cur_a; + valid_act_count_t += cne1; + valid_ep_count_t += 1; + } + valid_ep_count[0] = valid_ep_count_t; + valid_act_count[0] = valid_act_count_t; } - const uintptr_t ws_ptr = reinterpret_cast(params->wdata); - void * ws = reinterpret_cast((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1))); - const size_t quant_a_stride = block_count_k * q8_blk_size(QK4_0); + const int64_t barrier_idx = static_cast(ith / 2); - { - constexpr size_t block_size_m = 4; - size_t per_gemm_block_count_m = div_round_up(gemm_m, block_size_m); - int32_t task_count = batch_feature * per_gemm_block_count_m; - int32_t task_per_thread = (task_count + nth - 1) / nth; - int32_t start = ith * task_per_thread; - int32_t end = std::min((ith + 1) * task_per_thread, task_count); - for (int32_t compute_idx = start; compute_idx < end; compute_idx++) { - int32_t gemm_idx = compute_idx / per_gemm_block_count_m; - int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m; - int32_t m_idx = block_idx_in_gemm * block_size_m; - const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx]; - int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx); - - if (rows_tobe_handled == block_size_m) { - const float * a_row_ptr = data.a_ptr + m_idx * data.lda; - std::byte * quant_a_row_ptr = - static_cast(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; - sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); - } else { - while (rows_tobe_handled) { - const float * a_row_ptr = data.a_ptr + m_idx * data.lda; - std::byte * quant_a_row_ptr = static_cast(ws) + - gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; - sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); - rows_tobe_handled -= 1; - m_idx += 1; + GGML_ASSERT(global_spine_env_info.init_barrier != nullptr); + GGML_ASSERT(barrier_idx < spine_init_barrier_count); + spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx]; + + ggml_barrier(params->threadpool); + + const size_t row_stride_b = b_k_blks * get_repacked_block_type_size(); + const size_t expert_b_stride = ne01 * row_stride_b; + const size_t per_nb_cols_wsize = NB_COLS * row_stride_b; + + std::array src_workspaces; + std::array dst_workspaces; + + auto * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer; + const auto tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size; + + const auto valid_ep_count_t = valid_ep_count[0]; + const auto valid_act_count_t = valid_act_count[0]; + + int nth_es = 1; + int nth_n = nth; + + int ith_es = ith % nth_es; + int ith_n = (ith / nth_es) % nth_n; + + if (valid_ep_count_t % nth == 0 && tcm_buffer != nullptr && valid_ep_count_t == n_as && + valid_act_count_t == n_as && per_nb_cols_wsize <= tcm_buffer_size) { + for (int64_t valid_id = ith; valid_id < valid_ep_count_t; valid_id += nth) { + const int64_t cur_a = valid_matrix_row_counts[valid_id]; + + auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride; + + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, 0); + const int id = row_mapping.i1; + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; + const int64_t i1 = id; + const int64_t i2 = i12; + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + float * c_blk = (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)); + + uint8_t * a_row = src1_col; + uint8_t * b_col = reinterpret_cast(tcm_buffer); + if ((nbw1 + per_nb_cols_wsize) <= tcm_buffer_size) { + a_row = (uint8_t *) tcm_buffer; + b_col = reinterpret_cast(tcm_buffer) + nbw1; + } + uint8_t * b_col_zp = block_type_has_zp() ? b_col : nullptr; + + if (ith % 2 == 0) { + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(src0_cur), per_nb_cols_wsize); + + if (a_row != src1_col) { + spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1); + } + } + + spine_barrier_wait(cur_barrier); + + if (ith % 2 != 0) { + if (a_row != src1_col) { + spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1); + } + + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast(src0_cur), per_nb_cols_wsize); + } + + int64_t nb_real = std::min(ne01, NB_COLS); + for (int64_t ni = 0; ni < ne01; ni += NB_COLS) { + if (ith % 2 != 0) { + spine_barrier_wait(cur_barrier); + } + + gemm_kernel(b_blk_len, a_row, b_col, b_col_zp, c_blk + ni, 1, nb_real, b_k_blks, ne01); + + if (ith % 2 == 0) { + spine_barrier_wait(cur_barrier); + } + + const int64_t next_ni = ni + NB_COLS; + if (next_ni < ne01) { + nb_real = std::min(ne01 - next_ni, NB_COLS); + spacemit_kernels::rvv::memcpy1d( + b_col, reinterpret_cast(src0_cur) + next_ni * row_stride_b, per_nb_cols_wsize); } } } - } + } else { + for (int64_t valid_id = ith_es; valid_id < valid_ep_count_t; valid_id += nth_es) { + const int64_t cur_a = valid_matrix_row_counts[valid_id]; + const int64_t cne1 = matrix_row_counts[cur_a]; - ggml_barrier(params->threadpool); + int64_t src1_cur_start = 0; + int64_t src1_cur_end = cne1; - if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) { - return; - } - nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores }); - - size_t threads_per_gemm = nth / batch_feature; - constexpr size_t gemm_m_stride = 128; - size_t nc = gemm_n; - const size_t gemm_m_blocked = div_round_up(gemm_m, gemm_m_stride); - const size_t max_nc = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm); - if (max_nc < nc) { - nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN); - } - const size_t gemm_n_stride = nc; - const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride); - const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride); - threads_per_gemm = thread_count_m * thread_count_n; + int64_t src0_cur_start = (ith_n * ne01) / nth_n; + int64_t src0_cur_end = MIN(((ith_n + 1) * ne01) / nth_n, ne01); - { - int task_count = batch_feature * threads_per_gemm; - int task_per_thread = (task_count + nth - 1) / nth; - int start = ith * task_per_thread; - int end = std::min((ith + 1) * task_per_thread, task_count); - for (int compute_idx = start; compute_idx < end; compute_idx++) { - const auto gemm_i = compute_idx / threads_per_gemm; - const auto blk_i = compute_idx % threads_per_gemm; - const auto * data = &qnbitgemm_args[gemm_i]; + if (src1_cur_start >= src1_cur_end || src0_cur_start >= src0_cur_end) { + continue; + } + + src0_cur_start = + (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start; + src0_cur_end = + (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end; + + auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride + src0_cur_start * row_stride_b; + uint8_t * b_col_zp = block_type_has_zp() ? src0_cur : nullptr; + + size_t extra_tcm_buffer_size = tcm_buffer_size; + void * extra_tcm_buffer = tcm_buffer; + if (tcm_buffer != nullptr && (src1_cur_end - src1_cur_start) >= 4 && + (src0_cur_end - src0_cur_start) * row_stride_b <= tcm_buffer_size) { + spacemit_kernels::rvv::memcpy1d(tcm_buffer, src0_cur, + (src0_cur_end - src0_cur_start) * row_stride_b); + src0_cur = reinterpret_cast(tcm_buffer); + b_col_zp = block_type_has_zp() ? src0_cur : nullptr; + extra_tcm_buffer_size -= (src0_cur_end - src0_cur_start) * row_stride_b; + extra_tcm_buffer = reinterpret_cast(reinterpret_cast(tcm_buffer) + + (src0_cur_end - src0_cur_start) * row_stride_b); + } - const auto tid_n = blk_i / thread_count_m; - const auto tid_m = blk_i % thread_count_m; + int ir1 = src1_cur_start; - const size_t m_start = tid_m * gemm_m_stride; - const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride); + if (extra_tcm_buffer_size >= nbw1 && extra_tcm_buffer != nullptr) { + int64_t quant_a_tile_size = extra_tcm_buffer_size / nbw1; + do { + quant_a_tile_size = MIN(quant_a_tile_size, src1_cur_end - ir1); - const size_t n_start = tid_n * gemm_n_stride; - const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride); + uint8_t * quant_a_tile_buffer = reinterpret_cast(extra_tcm_buffer); - void * per_gemm_ws = reinterpret_cast(ws) + gemm_i * per_gemm_workspace_stride; + int iir1 = ir1; + for (; iir1 < (ir1 + quant_a_tile_size); ++iir1) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, iir1); - sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + spacemit_kernels::rvv::memcpy1d(quant_a_tile_buffer, src1_col, nbw1); + quant_a_tile_buffer = quant_a_tile_buffer + nbw1; + } + + quant_a_tile_buffer = reinterpret_cast(extra_tcm_buffer); + iir1 = ir1; + + if (moe_gemm_kernel_m2 != nullptr) { + for (; iir1 < (ir1 + quant_a_tile_size - 1); iir1 += 2, quant_a_tile_buffer += 2 * nbw1) { + mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1); + mmid_row_mapping row_mapping_1 = MMID_MATRIX_ROW(cur_a, iir1 + 1); + + src_workspaces[0] = quant_a_tile_buffer; + src_workspaces[1] = quant_a_tile_buffer + nbw1; + + dst_workspaces[0] = + (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) + + src0_cur_start; + dst_workspaces[1] = (float *) ((char *) dst->data + + ((row_mapping_1.i1) * nb1 + (row_mapping_1.i2) * nb2)) + + src0_cur_start; + moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp, + dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, + ne01); + } + } + + for (; iir1 < (ir1 + quant_a_tile_size); iir1++, quant_a_tile_buffer += nbw1) { + mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1); + + gemm_kernel( + b_blk_len, quant_a_tile_buffer, src0_cur, b_col_zp, + (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) + + src0_cur_start, + 1, src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + + ir1 += quant_a_tile_size; + } while (ir1 < src1_cur_end); + } else { + if (moe_gemm_kernel_m2 != nullptr) { + for (; ir1 < src1_cur_end - 1; ir1 += 2) { + for (int iir1 = 0; iir1 < 2; ++iir1) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1 + iir1); + + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + src_workspaces[iir1] = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + + dst_workspaces[iir1] = + (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start; + } + + moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp, + dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + } + + for (; ir1 < src1_cur_end; ir1++) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); + + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + + gemm_kernel(b_blk_len, src1_col, src0_cur, b_col_zp, + (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, 1, + src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + } } } +#undef MMID_MATRIX_ROW } - int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + int repack(ggml_tensor * t, const void * data, size_t data_size) override { GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), (int) NB_COLS, (int) INTER_SIZE); return ggml::cpu::riscv64_spacemit::repack(t, data, data_size); @@ -563,309 +936,464 @@ template class tensor_ }; class tensor_traits_common : public tensor_traits_base { - bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + bool work_size(int n_threads, const ggml_tensor * op, size_t & size) override { switch (op->op) { - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - size = 0; + case GGML_OP_FLASH_ATTN_EXT: + { + const int n_tasks = n_threads; + const int64_t neq2 = op->src[0]->ne[2]; // number of query heads + const int64_t DK = op->src[1]->ne[0]; + const int64_t DV = op->src[2]->ne[0]; // DV + + // Tiled flash attention scratch (tile sizes defined in common.h) + // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding + size_t prefill = sizeof(float) * + (GGML_FA_TILE_Q * DK + 2 * GGML_FA_TILE_Q * GGML_FA_TILE_KV + GGML_FA_TILE_Q * DV + + GGML_FA_TILE_KV * DV + GGML_FA_TILE_KV * DK) * + n_tasks; + + // Decode path: n_kv_chunks = n_tasks (one chunk per thread) + // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ + size_t n_chunks = n_tasks; + size_t decode = sizeof(float) * (neq2 * n_chunks * (2 + DV) + n_tasks * (DK + 2 * DV)); + + size = MAX(prefill, decode); + } return true; default: - // GGML_ABORT("fatal error"); break; } return false; } - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override { switch (op->op) { case GGML_OP_NORM: - forward_norm_f32(params, op); - return true; + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_norm_f32(params, op); + return true; + default: + GGML_ABORT("fatal error"); + } case GGML_OP_RMS_NORM: - forward_rms_norm_f32(params, op); + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_rms_norm_f32(params, op); + return true; + default: + GGML_ABORT("fatal error"); + } + case GGML_OP_ADD: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_add(params, op); + return true; + } + case GGML_OP_SUB: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_sub(params, op); + return true; + } + case GGML_OP_MUL: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_mul(params, op); + return true; + } + case GGML_OP_DIV: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary(params, op); + return true; + default: + ggml_compute_forward_div(params, op); + return true; + } + case GGML_OP_FLASH_ATTN_EXT: + forward_flash_attn_ext_f16(params, op); + return true; + case GGML_OP_CONT: + { + const ggml_tensor * src0 = op->src[0]; + if (op->type == src0->type && op->nb[0] != src0->nb[0] && op->nb[0] == src0->nb[1] && + op->ne[3] * op->ne[2] * op->nb[2] == src0->ne[3] * src0->ne[2] * src0->nb[2]) { + spacemit_kernels::rvv::forward_cont_with_permute(params, op); + } else { + ggml_compute_forward_cont(params, op); + } + return true; + } + case GGML_OP_CPY: + { + const ggml_tensor * src0 = op->src[0]; + if (op->type == src0->type && op->nb[0] == src0->nb[1] && src0->nb[0] != src0->nb[1] && + ggml_nelements(src0) == ggml_nelements(op)) { + spacemit_kernels::rvv::forward_cpy_with_permute(params, op); + } else { + ggml_compute_forward_cpy(params, op); + } + return true; + } + case GGML_OP_REPEAT: + { + const bool rows_equal = ggml_nrows(op->src[0]) == ggml_nrows(op); + const bool broadcast_or_equal = op->src[0]->ne[0] == 1 || op->src[0]->ne[0] == op->ne[0]; + + if (rows_equal && broadcast_or_equal) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_repeat_nrows(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_repeat_nrows(params, op); + return true; + default: + break; + } + } + + if (op->src[0]->ne[1] == 1 && op->src[0]->ne[0] == op->ne[0]) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_repeat_dim1(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_repeat_dim1(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_repeat(params, op); + } + return true; + case GGML_OP_SUM_ROWS: + { + if (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) { + spacemit_kernels::rvv::forward_sum_rows(params, op); + } else { + ggml_compute_forward_sum_rows(params, op); + } + } + return true; + case GGML_OP_GET_ROWS: + { + if (op->src[0]->type == op->type) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_get_rows(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_get_rows(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_get_rows(params, op); + } return true; + case GGML_OP_CONCAT: + { + const int32_t dim = ggml_get_op_params_i32(op, 0); + if (dim == 0 && op->type == op->src[0]->type) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_concat(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_concat(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_concat(params, op); + } + return true; + // TODO For GGML_OP_GATED_DELTA_NET + // case GGML_OP_GATED_DELTA_NET: + // return true; default: - // GGML_ABORT("fatal error"); break; } return false; } - void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { - const ggml_tensor * src0 = op->src[0]; - ggml_tensor * dst = op; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); + void forward_flash_attn_ext_f16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + + const bool supported_prec = (dst->op_params[3] == GGML_PREC_F32 || dst->op_params[3] == GGML_PREC_DEFAULT); + const bool supported_types = (q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16); + const bool supported_shape = (DK > 0 && DK <= 128 && DV > 0 && DV <= 128); + const bool supported_vlen = (__riscv_vlenb() == 128); + + if (!(supported_prec && supported_types && supported_shape && supported_vlen)) { + ggml_compute_forward_flash_attn_ext(params, dst); + return; + } + + // total rows in q + const int64_t nr = neq1 * neq2 * neq3; + // rows per thread const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS + static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; + const bool use_tiled = !params->use_ref && (neq1 >= Q_TILE_SZ); - float epsilon; - memcpy(&epsilon, dst->op_params, sizeof(float)); + // 4x chunks per thread + // int nth_scaled = nth * 4; + // int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + // int64_t nchunk = (nr + chunk_size - 1) / chunk_size; - GGML_ASSERT(epsilon > 0.0f); + // if (nth == 1 || nchunk < nth) { + // nchunk = nth; + // } - auto * input = (float *) src0->data; - auto * output = (float *) dst->data; + int64_t nchunk = nth; - const auto hidden_size = ne00; - const auto task_count = ne01 * ne02 * ne03; - const auto task_per_thread = (task_count + nth - 1) / nth; - - const auto task_begin = ith * task_per_thread; - const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + if (ith == 0) { + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + ggml_threadpool_chunk_set(params->threadpool, nth); + } - for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { - auto offset = task_idx * hidden_size; - auto * p_input = const_cast(input + offset); + ggml_barrier(params->threadpool); - auto * p_output = output + offset; - auto * p_temp_output = p_output; - auto * p_gamma_data = (const float *) nullptr; - auto * p_beta_data = (const float *) nullptr; - size_t gvl = __riscv_vsetvlmax_e32m4(); - vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); - vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); - int64_t length = hidden_size; - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - // load data - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + // The number of elements in each chunk + const int64_t dr = (nr + nchunk - 1) / nchunk; - sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); - sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + // The first chunk comes from our thread_id, the rest will get auto-assigned. + int current_chunk = ith; - __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); - p_input += gvl; - p_temp_output += gvl; - length -= gvl; + if (use_tiled) { + spacemit_kernels::rvv::forward_flash_attn_ext_f16_tiled_vlen1024_vf16( + params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer, + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size); + } else { + spacemit_kernels::rvv::forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16( + params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer, + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size); } - gvl = __riscv_vsetvlmax_e32m1(); - - float mean = 0.f; - vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); - vfloat32m1_t mean_v = - __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); - mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); - mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); - mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); - mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); - mean /= hidden_size; - - vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), - __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); - mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); - - float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); - mean_square /= hidden_size; - mean_square = sqrt(mean_square - mean * mean + epsilon); - - mean_square = 1.0f / mean_square; - length = hidden_size; - p_temp_output = p_output; - - if (p_gamma_data == nullptr && p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - length -= gvl; - } - } else if (p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } else if (p_gamma_data != nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); - src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); - p_beta_data += gvl; - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); } } - void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { - const ggml_tensor * src0 = op->src[0]; - ggml_tensor * dst = op; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - float epsilon; - memcpy(&epsilon, dst->op_params, sizeof(float)); - - GGML_ASSERT(epsilon > 0.0f); - - auto * input = (float *) src0->data; - auto * output = (float *) dst->data; - - const auto hidden_size = ne00; - const auto task_count = ne01 * ne02 * ne03; - const auto task_per_thread = (task_count + nth - 1) / nth; - - const auto task_begin = ith * task_per_thread; - const auto task_end = std::min((ith + 1) * task_per_thread, task_count); - - for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { - auto offset = task_idx * hidden_size; - auto * p_input = const_cast(input + offset); - auto * p_output = output + offset; - auto * p_temp_output = p_output; - auto * p_gamma_data = (const float *) nullptr; - auto * p_beta_data = (const float *) nullptr; - - size_t gvl = __riscv_vsetvlmax_e32m4(); - // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); - vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); - int64_t length = hidden_size; - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - // load data - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + int repack(ggml_tensor * t, const void * data, size_t data_size) override { + memcpy(t->data, data, data_size); + return 0; + } +}; - sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); +// Impl By IME1 +static const tensor_traits q4_0_16x32_q8_0; +static const tensor_traits q4_1_16x32_q8_0; +static const tensor_traits q4_k_16x32_q8_0; +// Impl By IME2 +static const tensor_traits q2_k_32x256_q8_0; +static const tensor_traits q3_k_32x256_q8_0; +static const tensor_traits q4_0_32x32_q8_0; +static const tensor_traits q4_1_32x32_q8_0; +static const tensor_traits q4_0_32x256_q8_0; +static const tensor_traits q4_1_32x256_q8_0; +static const tensor_traits q4_k_32x32_q8_0; +static const tensor_traits q6_k_32x32_q8_0; +static const tensor_traits q8_0_32x32_q8_0; +static const tensor_traits mxfp4_32x32_q8_0; +static const tensor_traits q5_k_32x32_q8_0; +static const tensor_traits q5_1_32x32_q8_0; +static const tensor_traits q5_0_32x32_q8_0; +// Impl By RVV +static const tensor_traits_common rvv_impl; - __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); +} // namespace ggml::cpu::riscv64_spacemit - p_input += gvl; - p_temp_output += gvl; - length -= gvl; +static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const ggml_tensor * cur) { + switch (cur->type) { + case GGML_TYPE_Q2_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q2_k_32x256_q8_0; + } +#endif } + break; + case GGML_TYPE_Q3_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q3_k_32x256_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 && + (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_0_32x256_q8_0; + } - gvl = __riscv_vsetvlmax_e32m1(); - - // float mean = 0.f; - vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); - - vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), - __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); - mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); - - float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); - mean_square /= hidden_size; + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_0_32x32_q8_0; + } +#endif - mean_square = sqrt(mean_square + epsilon); +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_0_16x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_1: + { +#if defined(RISCV64_SPACEMIT_IME2) + // TODO + // if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 && + // (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + // return &ggml::cpu::riscv64_spacemit::q4_1_32x256_q8_0; + // } + + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_1_32x32_q8_0; + } +#endif - mean_square = 1.0f / mean_square; - length = hidden_size; - p_temp_output = p_output; +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_1_16x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_k_32x32_q8_0; + } +#endif - if (p_gamma_data == nullptr && p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - length -= gvl; +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_k_16x32_q8_0; } - } else if (p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; +#endif + } + break; + case GGML_TYPE_Q6_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q6_k_32x32_q8_0; } - } else if (p_gamma_data != nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); - src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); - p_beta_data += gvl; - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; +#endif + } + break; + case GGML_TYPE_Q8_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q8_0_32x32_q8_0; } +#endif } - } - } - - int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { - memcpy(t->data, data, data_size); - return 0; - } -}; - -static const tensor_traits q4_0_16x8_q8_0; -static const tensor_traits q4_1_16x8_q8_0; -static const tensor_traits q4_k_16x8_q8_0; -static const tensor_traits_common rvv_impl; - -} // namespace ggml::cpu::riscv64_spacemit - -static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) { - if (cur->type == GGML_TYPE_Q4_0) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_Q4_1) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_Q4_K) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_F32) { - return &ggml::cpu::riscv64_spacemit::rvv_impl; + break; + case GGML_TYPE_MXFP4: + { +#if defined(RISCV64_SPACEMIT_IME2) + // TODO + // if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + // return &ggml::cpu::riscv64_spacemit::mxfp4_32x32_q8_0; + // } +#endif + } + break; + case GGML_TYPE_Q5_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_k_32x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q5_1: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_1_32x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q5_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_0_32x32_q8_0; + } +#endif + } + break; + default: + break; } return nullptr; } static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer, - struct ggml_tensor * tensor) { + ggml_tensor * tensor) { tensor->extra = (void *) const_cast(ggml_riscv64_spacemit_get_optimal_repack_type(tensor)); @@ -874,8 +1402,46 @@ static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_ba return GGML_STATUS_SUCCESS; } +static void ggml_backend_riscv64_spacemit_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + if (base == nullptr) { + return; + } + + ggml::cpu::riscv64_spacemit::spine_mem_pool_free(base); +} + +static void * ggml_backend_riscv64_spacemit_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + GGML_ASSERT(base != nullptr); + return base; +} + +static void ggml_backend_riscv64_spacemit_buffer_memset_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + uint8_t value, + size_t offset, + size_t size) { + GGML_ASSERT(tensor); + memset((char *) tensor->data + offset, value, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_riscv64_spacemit_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + GGML_ASSERT(base != nullptr); + memset(base, value, buffer->size); +} + static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer, - struct ggml_tensor * tensor, + ggml_tensor * tensor, const void * data, size_t offset, size_t size) { @@ -891,6 +1457,20 @@ static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_ GGML_UNUSED(buffer); } +static const ggml_backend_buffer_i ggml_backend_riscv64_spacemit_buffer_i = { + /* .free_buffer = */ ggml_backend_riscv64_spacemit_buffer_free_buffer, + /* .get_base = */ ggml_backend_riscv64_spacemit_buffer_get_base, + /* .init_tensor = */ ggml_backend_riscv64_spacemit_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_riscv64_spacemit_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_riscv64_spacemit_buffer_set_tensor, + /* .get_tensor = */ nullptr, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, + /* .cpy_tensor = */ nullptr, + /* .clear = */ ggml_backend_riscv64_spacemit_buffer_clear, + /* .reset = */ nullptr, +}; + static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) { return "CPU_RISCV64_SPACEMIT"; @@ -899,18 +1479,12 @@ static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_ static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); - - if (buffer == nullptr) { + void * base = ggml::cpu::riscv64_spacemit::spine_mem_pool_alloc(size, 64); + if (base == nullptr) { return nullptr; } - buffer->buft = buft; - buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor; - buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor; - buffer->iface.get_tensor = nullptr; - buffer->iface.cpy_tensor = nullptr; - return buffer; + return ggml_backend_buffer_init(buft, ggml_backend_riscv64_spacemit_buffer_i, base, size); } static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { @@ -919,44 +1493,91 @@ static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_b GGML_UNUSED(buft); } -static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, - const struct ggml_tensor * tensor) { +static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { for (int i = 0; i < GGML_MAX_DIMS; ++i) { if (tensor->ne[i] <= 0) { return 0; } } - size_t nbytes; + GGML_UNUSED(buft); + + const auto plain_nbytes = [&]() { + size_t total = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + total += (tensor->ne[i] - 1) * tensor->nb[i]; + } + return total; + }; + const size_t blck_size = ggml_blck_size(tensor->type); if (blck_size == 1) { - nbytes = ggml_type_size(tensor->type); - for (int i = 0; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + return plain_nbytes(); + } + + const size_t row_nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; + + const auto add_strided_nbytes = [&](size_t total, size_t src_block_size, size_t dst_block_size) { + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + total += (tensor->ne[i] - 1) * (tensor->nb[i] / src_block_size) * dst_block_size; } - } else { - nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; - if (tensor->type == GGML_TYPE_Q4_K) { - GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0); - nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; - } - } else { - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; - } + return total; + }; + + const auto remap_block_nbytes = [&](size_t src_block_size, size_t dst_block_size, int64_t padded_rows = 0) { + GGML_ASSERT(row_nbytes % src_block_size == 0); + + size_t total = + add_strided_nbytes((row_nbytes / src_block_size) * dst_block_size, src_block_size, dst_block_size); + + if (padded_rows > 0 && tensor->ne[1] % padded_rows != 0) { + total += (padded_rows - tensor->ne[1] % padded_rows) * (tensor->nb[1] / src_block_size) * dst_block_size; } + + return total; + }; + + size_t nbytes = row_nbytes; + switch (tensor->type) { + case GGML_TYPE_Q4_K: + nbytes = remap_block_nbytes(sizeof(block_q4_K), sizeof(block_q4_1) * 8); + break; + case GGML_TYPE_Q6_K: + nbytes = remap_block_nbytes(sizeof(block_q6_K), sizeof(block_q8_0) * 8, 32); + break; + case GGML_TYPE_Q8_0: + nbytes = remap_block_nbytes(sizeof(block_q8_0), sizeof(block_q8_0), 32); + break; + case GGML_TYPE_Q2_K: + nbytes = remap_block_nbytes(sizeof(block_q2_K), sizeof(spacemit_kernels::nrow_block_q2_k<1>)); + break; + case GGML_TYPE_Q3_K: + nbytes = remap_block_nbytes(sizeof(block_q3_K), sizeof(spacemit_kernels::nrow_block_q3_k<1>)); + break; + case GGML_TYPE_MXFP4: + nbytes = remap_block_nbytes(sizeof(block_mxfp4), sizeof(spacemit_kernels::nrow_block_mxfp4<1>)); + break; + case GGML_TYPE_Q5_K: + nbytes = remap_block_nbytes(sizeof(block_q5_K), sizeof(spacemit_kernels::nrow_block_q5_1<1>) * 8); + break; + case GGML_TYPE_Q5_1: + nbytes = remap_block_nbytes(sizeof(block_q5_1), sizeof(spacemit_kernels::nrow_block_q5_1<1>)); + break; + case GGML_TYPE_Q5_0: + nbytes = remap_block_nbytes(sizeof(block_q5_0), sizeof(spacemit_kernels::nrow_block_q5_0<1>)); + break; + default: + nbytes = add_strided_nbytes(row_nbytes, 1, 1); + break; } - GGML_UNUSED(buft); return nbytes; } namespace ggml::cpu::riscv64_spacemit { class extra_buffer_type : ggml::cpu::extra_buffer_type { - bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + bool supports_op(ggml_backend_dev_t, const ggml_tensor * op) override { switch (op->op) { case GGML_OP_MUL_MAT: if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && @@ -970,10 +1591,16 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { } } break; - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - if (op->src[0]->type == GGML_TYPE_F32) { - return true; + case GGML_OP_MUL_MAT_ID: + if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 3) && + op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() && + ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } } break; default: @@ -983,15 +1610,28 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { return false; } - ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + ggml::cpu::tensor_traits * get_tensor_traits(const ggml_tensor * op) override { switch (op->op) { case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_FLASH_ATTN_EXT: + case GGML_OP_CONT: + case GGML_OP_CPY: + case GGML_OP_REPEAT: + case GGML_OP_SUM_ROWS: + case GGML_OP_GET_ROWS: + case GGML_OP_CONCAT: + // case GGML_OP_GATED_DELTA_NET: return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl); default: // GGML_ABORT("fatal error"); @@ -1005,7 +1645,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { } // namespace ggml::cpu::riscv64_spacemit ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { - static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { + static ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { /* .iface = */ { /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name, @@ -1023,3 +1663,78 @@ ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { return &ggml_backend_cpu_buffer_type_riscv64_spacemit; } + +extern "C" { +static int bind_ai_thread() { + int fd, bytes; + char str[32]; + + fd = open("/proc/set_ai_thread", O_WRONLY); + if (fd < 0) { + GGML_LOG_ERROR("try open /proc/set_ai_thread failed\n"); + return -1; + } + + snprintf(str, 16, "%d", 0); + bytes = write(fd, str, strlen(str)); + if (bytes < 0) { + GGML_LOG_ERROR("try write /proc/set_ai_thread failed\n"); + close(fd); + return -1; + } + + close(fd); + return 0; +} + +void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n) { + int cpu_id = sched_getcpu(); + if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2 && + !((1 << cpu_id) & ggml::cpu::riscv64_spacemit::global_spine_env_info.cpu_mask)) { + GGML_PRINT_DEBUG("bind_ai_thread for thread %d, pid %d\n", thread_n, getpid()); + bind_ai_thread(); + } + + if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_tcm && + ggml::cpu::riscv64_spacemit::tls_context.cpu_id == -1) { + CPU_ZERO(&(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + pthread_t main_thread = pthread_self(); + const auto & perfer_core_ids = ggml::cpu::riscv64_spacemit::global_spine_env_info.perfer_core_ids; + if (thread_n < 0 || static_cast(thread_n) >= perfer_core_ids.size()) { + GGML_ABORT("thread_n %d exceeds perfer_core_ids size %zu\n", thread_n, perfer_core_ids.size()); + } + auto perfer_cpu_id = perfer_core_ids[static_cast(thread_n)]; + CPU_SET(perfer_cpu_id, &(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + int s = + pthread_setaffinity_np(main_thread, sizeof(cpu_set_t), &(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + if (s != 0) { + GGML_ABORT("set thread affinity error for thread_n %d, cpu_id %d\n", thread_n, perfer_cpu_id); + } + + int ai_cpu_id = perfer_cpu_id - ggml::cpu::riscv64_spacemit::global_spine_env_info.aicpu_id_offset; + ggml::cpu::riscv64_spacemit::tls_context.cpu_id = ai_cpu_id; + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer = + ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_get(ai_cpu_id); + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size = + ggml::cpu::riscv64_spacemit::global_spine_env_info.tcm_blk_size; + } + + if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) { + void * rt = + ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_wait(ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + if (rt == nullptr) { + GGML_ABORT("wait tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + } + } +} + +void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n) { + if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) { + auto rt = ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_release( + ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + if (rt != 0) { + GGML_ABORT("release tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + } + } +} +} diff --git a/ggml/src/ggml-cpu/spacemit/ime.h b/ggml/src/ggml-cpu/spacemit/ime.h index 800d91acdae..6849dd95e05 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.h +++ b/ggml/src/ggml-cpu/spacemit/ime.h @@ -8,6 +8,14 @@ extern "C" { ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); +void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n); + +void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n); + +void * ggml_backend_cpu_riscv64_spacemit_alloc_shared(size_t size, size_t alignment); + +void ggml_backend_cpu_riscv64_spacemit_free_shared(void * ptr); + #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp index cbbb6cd9160..6acc6819dfb 100644 --- a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp @@ -1,8 +1,26 @@ +#include "ggml-impl.h" #include "ggml.h" #include "ime_kernels.h" +#include "rvv_kernels.h" #include #include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) +#else +# error "RISCV64_SPACEMIT_IME1 not defined" +#endif // clang-format off #if defined(__GNUC__) @@ -11,7 +29,7 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #endif // clang-format on -namespace sqnbitgemm_spacemit_ime { +namespace spacemit_kernels { #define QUANTIZEM4ROW_KERNEL \ "vmv.s.x v16, zero \n\t" \ @@ -76,1093 +94,208 @@ namespace sqnbitgemm_spacemit_ime { "vse8.v v31, (s1) \n\t" namespace ime1 { -void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { +void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) { constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); const float fone = 1.0f; - if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float * SRC = A + row_index * CountK; + uint8_t * DST = QuantA + row_index * sizeof(float); - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, %[BlkLen], e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "sub t2, t2, t0 \n\t" - "slli t1, t0, 2 \n\t" - "add %[SRC], %[SRC], t1 \n\t" - "add s1, a1, %[OFFSET] \n\t" - - QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE - - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "add s1, a1, %[OFFSET] \n\t" - - QUANTIZEM4ROW_KERNEL - - "addi t3, %[BlkLen], 0 \n\t" - "addi s2, s1, 0 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "SET_ZERO%=: \n\t" - "vse8.v v8, (s2) \n\t" - "addi s2, s2, 32 \n\t" - "addi t3, t3, -8 \n\t" - "bnez t3, SET_ZERO%= \n\t" - - QUANTIZEM4ROW_STORE - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); - } - } else if (BlkLen == 128) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); - - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "li t6, 32 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "addi t2, t2, -128 \n\t" - - "QUANTIZE%=: \n\t" - "add s1, a1, %[OFFSET] \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vfmax.vv v16, v24, v16 \n\t" - "vfredmax.vs v24, v16, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e64, m4 \n\t" - "vsse64.v v16, (s1), t6 \n\t" - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "sub t2, t2, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "sub t2, t2, t2 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "jal x0, QUANTIZE%= \n\t" - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); - } - } else if (BlkLen == 256) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "li t6, 32 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], -768 \n\t" - "addi t2, t2, -256 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v24, v24, v16 \n\t" - "vfmax.vv v8, v8, v24 \n\t" - "vfredmax.vs v24, v8, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - - "QUANTIZE%=: \n\t" - "add s1, a1, %[OFFSET] \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfmul.vf v8, v8, f11 \n\t" - "vfmul.vf v16, v16, f11 \n\t" - "vfmul.vf v24, v24, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vfcvt.x.f.v v8, v8 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vnclip.wx v8, v16, zero \n\t" - "vnclip.wx v12, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vsetvli t0, zero, e64, m8 \n\t" - "vsse64.v v0, (s1), t6 \n\t" - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t1, t2, 0 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], -768 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v24, v16, v24 \n\t" - "vfmax.vv v8, v8, v24 \n\t" - "vfredmax.vs v24, v8, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e64, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsse64.v v0, (s1), t6 \n\t" - - "TAIL_LOOP%=: \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t0, t2, e32, m1 \n\t" - "sub t2, t2, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 32 \n\t" - "vfmul.vf v1, v0, f11 \n\t" - "vfcvt.x.f.v v2, v1 \n\t" - "vsetvli t0, zero, e16, mf2 \n\t" - "vnclip.wx v3, v2, zero \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vnclip.wx v3, v3, zero \n\t" - "vse8.v v3, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "bnez t2, TAIL_LOOP%= \n\t" - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); - } + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, %[BlkLen], e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "sub t2, t2, t0 \n\t" + "slli t1, t0, 2 \n\t" + "add %[SRC], %[SRC], t1 \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE + + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL + + "addi t3, %[BlkLen], 0 \n\t" + "addi s2, s1, 0 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "SET_ZERO%=: \n\t" + "vse8.v v8, (s2) \n\t" + "addi s2, s2, 32 \n\t" + "addi t3, t3, -8 \n\t" + "bnez t3, SET_ZERO%= \n\t" + + QUANTIZEM4ROW_STORE + + "QUIT%=: \n\t" + : [SRC] "+r"(SRC) + : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), [CountK] "r"(CountK), + [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); } } -void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { +void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) { const float * SRC = A; - std::byte * DST = QuantA; + uint8_t * DST = QuantA; constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); const float fone = 1.0f; - std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); + uint8_t * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK; - if (CountK <= BlkLen) { - float max_abs_A = 0.0f; - for (size_t k = 0; k < CountK; k++) { - max_abs_A = std::max(max_abs_A, fabsf(A[k])); - } - float scale_A = max_abs_A * range_max_reciprocal; - - ((float *) QuantA)[0] = scale_A; - - auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float)); - - for (size_t k = 0; k < CountK; k++) { - QuantAData_offset[k] = - (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits::lowest(), - (float) std::numeric_limits::max()); - } - for (size_t k = CountK; k < BlkLen; k++) { - QuantAData_offset[k] = 0; - } - - return; - } - - if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) { - __asm__ volatile( - "vsetvli t0, zero, e8, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "LOOP%=: \n\t" - "vsetvli t0, %[CNT], e8, m8 \n\t" - "vse8.v v24, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "sub %[CNT], %[CNT], t0 \n\t" - "bnez %[CNT], LOOP%= \n\t" - : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset) - : - : "cc", "t0"); - } - if (BlkLen == 16) { - float buffer[64] = { 0.0f }; - __asm__ volatile( - "addi t3, zero, 16*8 \n\t" - "addi t2, zero, 16 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m2 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v2, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v4, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v6, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v10, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v12, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v14, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "addi a1, %[BUFFER], 0 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v18, v2 \n\t" - "vfabs.v v20, v4 \n\t" - "vfabs.v v22, v6 \n\t" - "vfabs.v v24, v8 \n\t" - "vfabs.v v26, v10 \n\t" - "vfabs.v v28, v12 \n\t" - "vfabs.v v30, v14 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v18, v18, v19 \n\t" - "vfmax.vv v20, v20, v21 \n\t" - "vfmax.vv v22, v22, v23 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfmax.vv v26, v26, v27 \n\t" - "vfmax.vv v28, v28, v29 \n\t" - "vfmax.vv v30, v30, v31 \n\t" - "vse32.v v16, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v18, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v20, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v22, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v24, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v26, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v28, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v30, (a1) \n\t" - "addi a1, %[BUFFER], 0 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f11, f3, f7 \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fsw f11, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f12, f3, f7 \n\t" - "fmul.s f12, f12, %[RMAXREC] \n\t" - "fsw f12, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f12, %[FONE], f12 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f13, f3, f7 \n\t" - "fmul.s f13, f13, %[RMAXREC] \n\t" - "fsw f13, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f13, %[FONE], f13 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f14, f3, f7 \n\t" - "fmul.s f14, f14, %[RMAXREC] \n\t" - "fsw f14, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f14, %[FONE], f14 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f15, f3, f7 \n\t" - "fmul.s f15, f15, %[RMAXREC] \n\t" - "fsw f15, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f15, %[FONE], f15 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f16, f3, f7 \n\t" - "fmul.s f16, f16, %[RMAXREC] \n\t" - "fsw f16, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f16, %[FONE], f16 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f17, f3, f7 \n\t" - "fmul.s f17, f17, %[RMAXREC] \n\t" - "fsw f17, (%[DST]) \n\t" - "addi %[DST], %[DST], -136 \n\t" - "fdiv.s f17, %[FONE], f17 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v18, v2, f11 \n\t" - "vfmul.vf v20, v4, f12 \n\t" - "vfmul.vf v22, v6, f13 \n\t" - "vfmul.vf v24, v8, f14 \n\t" - "vfmul.vf v26, v10, f15 \n\t" - "vfmul.vf v28, v12, f16 \n\t" - "vfmul.vf v30, v14, f17 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v18, v18 \n\t" - "vfcvt.x.f.v v20, v20 \n\t" - "vfcvt.x.f.v v22, v22 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vfcvt.x.f.v v26, v26 \n\t" - "vfcvt.x.f.v v28, v28 \n\t" - "vfcvt.x.f.v v30, v30 \n\t" - "vsetvli t0, zero, e16, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v18, v18, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v22, v22, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v26, v26, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vnclip.wx v30, v30, zero \n\t" - "vsetvli t0, t1, e8, mf2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v18, v18, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v22, v22, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v26, v26, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vnclip.wx v30, v30, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v18, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v20, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v22, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v24, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v26, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v28, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v30, (%[DST]) \n\t" - "addi %[DST], %[DST], 16 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vse32.v v16, (%[BUFFER]) \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, t1, e8, mf2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 16 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m2 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer) - : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12", - "f13", "f14", "f15", "f16", "f17"); - } else if (BlkLen == 32) { - __asm__ volatile( - "addi t3, zero, 32*4 \n\t" - "addi t2, zero, 32 \n\t" - - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 128 \n\t" - "addi a3, %[SRC], 256 \n\t" - "addi a4, %[SRC], 384 \n\t" - - "addi s1, %[DST], 0 \n\t" - "addi s2, %[DST], 36 \n\t" - "addi s3, %[DST], 72 \n\t" - "addi s4, %[DST], 108 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m4 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v4, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "vle32.v v8, (a3) \n\t" - "addi a3, a3, 512 \n\t" - "vle32.v v12, (a4) \n\t" - "addi a4, a4, 512 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v20, v4 \n\t" - "vfabs.v v24, v8 \n\t" - "vfabs.v v28, v12 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vfmax.vv v20, v20, v22 \n\t" - "vfmax.vv v24, v24, v26 \n\t" - "vfmax.vv v28, v28, v30 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v20, v20, v21 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfmax.vv v28, v28, v29 \n\t" - - "vfredmax.vs v17, v16, v17 \n\t" - "vfredmax.vs v21, v20, v21 \n\t" - "vfredmax.vs v25, v24, v25 \n\t" - "vfredmax.vs v29, v28, v29 \n\t" - "vfmv.f.s f10, v17 \n\t" - "vfmv.f.s f11, v21 \n\t" - "vfmv.f.s f12, v25 \n\t" - "vfmv.f.s f13, v29 \n\t" - - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fmul.s f12, f12, %[RMAXREC] \n\t" - "fmul.s f13, f13, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - - "fsw f11, (s2) \n\t" - "addi s2, s2, 4 \n\t" - "fsw f12, (s3) \n\t" - "addi s3, s3, 4 \n\t" - "fsw f13, (s4) \n\t" - "addi s4, s4, 4 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "fdiv.s f12, %[FONE], f12 \n\t" - "fdiv.s f13, %[FONE], f13 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v20, v4, f11 \n\t" - "vfmul.vf v24, v8, f12 \n\t" - "vfmul.vf v28, v12, f13 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v20, v20 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vfcvt.x.f.v v28, v28 \n\t" - "vsetvli t0, zero, e16, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vsetvli t0, t1, e8, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 140 \n\t" - "vse8.v v20, (s2) \n\t" - "addi s2, s2, 140 \n\t" - "vse8.v v24, (s3) \n\t" - "addi s3, s3, 140 \n\t" - "vse8.v v28, (s4) \n\t" - "addi s4, s4, 140 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m4 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 128 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfmv.f.s f10, v17 \n\t" - - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m4 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) - : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); - } else if (BlkLen == 64) { - __asm__ volatile( - "addi t3, zero, 64*2 \n\t" - "addi t2, zero, 64 \n\t" - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 256 \n\t" - "addi s1, %[DST], 0 \n\t" - "addi s2, %[DST], 68 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v8, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v16, v16, v20 \n\t" - "vfmax.vv v24, v24, v28 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vfmax.vv v24, v24, v26 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfredmax.vs v25, v24, v25 \n\t" - "vfmv.f.s f10, v17 \n\t" - "vfmv.f.s f11, v25 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fsw f11, (s2) \n\t" - "addi s2, s2, 4 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vsetvli t0, t1, e8, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 132 \n\t" - "vse8.v v24, (s2) \n\t" - "addi s2, s2, 132 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 256 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v16, v16, v20 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfmv.f.s f10, v17 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e8, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [K] "+r"(CountK) - : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11"); - } else if (BlkLen == 128) { - __asm__ volatile( - "addi t2, zero, 128 \n\t" - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 256 \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v8, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "sub %[K], %[K], t2 \n\t" - "QUANT%=: \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vfmax.vv v24, v16, v24 \n\t" - "vsetvli t1, zero, e32, m4 \n\t" - "vfmax.vv v28, v24, v28 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v30, v28, v30 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v30, v30, v31 \n\t" - "vfredmax.vs v31, v30, v31 \n\t" - "vfmv.f.s f10, v31 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vsetvli t0, %[K], e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "sub %[K], %[K], t0 \n\t" - "vsetvli t0, %[K], e32, m8 \n\t" - "vle32.v v8, (a2) \n\t" - "sub %[K], %[K], t0 \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "jal x0, QUANT%= \n\t" - "END%=: \n\t" - - : [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC) - : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11"); - } else { - float buffer[8] = { 0.0f }; - size_t cnt = BlkLen / 256; - - __asm__ volatile( - "slli t3, %[BLK], 2 \n\t" - "blt %[K], %[BLK], LOOP_TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vxor.vv v31, v31, v31 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "addi t6, %[CNT], 0 \n\t" - "LOOP_CMP%=: \n\t" - "addi t6, t6, -1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v16, v16, v24 \n\t" - "vfmax.vv v0, v0, v16 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v0, v0, v4 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v0, v0, v2 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v0, v0, v1 \n\t" - "vle32.v v30, (%[BUFFER]) \n\t" - "vfmax.vv v31, v30, v0 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "bnez t6, LOOP_CMP%= \n\t" - "sub %[SRC], %[SRC], t3 \n\t" - "addi t6, %[CNT], 0 \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "addi t6, %[CNT], 0 \n\t" - "LOOP_QUANT%=: \n\t" - "addi t6, t6, -1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfmul.vf v8, v8, f11 \n\t" - "vfmul.vf v16, v16, f11 \n\t" - "vfmul.vf v24, v24, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vfcvt.x.f.v v8, v8 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vnclip.wx v8, v16, zero \n\t" - "vnclip.wx v12, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vse8.v v0, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "vse8.v v4, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "bnez t6, LOOP_QUANT%= \n\t" - "sub %[K], %[K], %[BLK] \n\t" - "bge %[K], %[BLK], LOOP_MAIN%= \n\t" - "blez %[K], END%= \n\t" - "LOOP_TAIL%=: \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vxor.vv v31, v31, v31 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "addi t6, %[K], 0 \n\t" - "addi s1, %[SRC], 0 \n\t" - "TAIL_CMP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t0, t6, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "sub t6, t6, t0 \n\t" - "vfabs.v v0, v0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v0, v0, v4 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v0, v0, v2 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v0, v0, v1 \n\t" - "vle32.v v30, (%[BUFFER]) \n\t" - "vfmax.vv v31, v30, v0 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "bnez t6, TAIL_CMP%= \n\t" - "addi t6, %[K], 0 \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "addi t6, %[K], 0 \n\t" - "TAIL_QUANT%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t1, t6, e32, m8 \n\t" - "vle32.v v0, (s1) \n\t" - "addi s1, s1, 256 \n\t" - "sub t6, t6, t1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vsetvli t0, t1, e8, m2 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vse8.v v0, (%[DST]) \n\t" - "addi %[DST], %[DST], 64 \n\t" - "bnez t6, TAIL_QUANT%= \n\t" - "END%=: \n\t" - : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer), - [CNT] "r"(cnt) - : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6"); - } + __asm__ volatile( + "addi t3, zero, 32*4 \n\t" + "addi t2, zero, 32 \n\t" + + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 128 \n\t" + "addi a3, %[SRC], 256 \n\t" + "addi a4, %[SRC], 384 \n\t" + + "addi s1, %[DST], 0 \n\t" + "addi s2, %[DST], 36 \n\t" + "addi s3, %[DST], 72 \n\t" + "addi s4, %[DST], 108 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m4 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v4, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "vle32.v v8, (a3) \n\t" + "addi a3, a3, 512 \n\t" + "vle32.v v12, (a4) \n\t" + "addi a4, a4, 512 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v20, v4 \n\t" + "vfabs.v v24, v8 \n\t" + "vfabs.v v28, v12 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vfmax.vv v20, v20, v22 \n\t" + "vfmax.vv v24, v24, v26 \n\t" + "vfmax.vv v28, v28, v30 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v20, v20, v21 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfmax.vv v28, v28, v29 \n\t" + + "vfredmax.vs v17, v16, v17 \n\t" + "vfredmax.vs v21, v20, v21 \n\t" + "vfredmax.vs v25, v24, v25 \n\t" + "vfredmax.vs v29, v28, v29 \n\t" + "vfmv.f.s f10, v17 \n\t" + "vfmv.f.s f11, v21 \n\t" + "vfmv.f.s f12, v25 \n\t" + "vfmv.f.s f13, v29 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fmul.s f12, f12, %[RMAXREC] \n\t" + "fmul.s f13, f13, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + + "fsw f11, (s2) \n\t" + "addi s2, s2, 4 \n\t" + "fsw f12, (s3) \n\t" + "addi s3, s3, 4 \n\t" + "fsw f13, (s4) \n\t" + "addi s4, s4, 4 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "fdiv.s f12, %[FONE], f12 \n\t" + "fdiv.s f13, %[FONE], f13 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v20, v4, f11 \n\t" + "vfmul.vf v24, v8, f12 \n\t" + "vfmul.vf v28, v12, f13 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v20, v20 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vfcvt.x.f.v v28, v28 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vsetvli t0, t1, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 140 \n\t" + "vse8.v v20, (s2) \n\t" + "addi s2, s2, 140 \n\t" + "vse8.v v24, (s3) \n\t" + "addi s3, s3, 140 \n\t" + "vse8.v v28, (s4) \n\t" + "addi s4, s4, 140 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m4 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 128 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfmv.f.s f10, v17 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m4 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) + : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); } } // namespace ime1 @@ -1451,1746 +584,444 @@ namespace { "vadd.vi v1, v1, -12 \n\t" template -void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias, - const size_t ldc) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); +void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const uint8_t * QuantA, + const uint8_t * QuantBData, + float * C, + size_t CountN, + size_t BlockCountK, + const size_t ldc) { size_t LDC = ldc * sizeof(float); const size_t INNER = BlkLen / 16; float tmp[4 * 16]; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(_Float16); // scale + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; LDC = 16 * sizeof(float); } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - - "addi t3, %[BlockCountK], 0 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 32 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 32 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } - } - } else { - for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(_Float16); // scale - float * CPtr = C + n; - if (NBLKS < 16) { - CPtr = tmp; - LDC = 16 * sizeof(float); - } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 32 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 32 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } - } - } - if (CountN % 16 != 0) { - // stroe output from tmp to C when NBLKS less than 16. - float * CPtr = C + CountN / 16 * 16; - const size_t N = CountN % 16; - LDC = ldc * sizeof(float); - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi s2, %[SRC], 64 \n\t" - "addi s3, %[SRC], 64*2 \n\t" - "addi s4, %[SRC], 64*3 \n\t" - "vle32.v v2, (s2) \n\t" - "vle32.v v4, (s3) \n\t" - "vle32.v v6, (s4) \n\t" - "add t2, %[DST], %[LDC] \n\t" - "add t3, t2, %[LDC] \n\t" - "add t4, t3, %[LDC] \n\t" - "vse32.v v0, (%[DST]) \n\t" - "vse32.v v2, (t2) \n\t" - "vse32.v v4, (t3) \n\t" - "vse32.v v6, (t4) \n\t" - : - : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) - : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); - } -} -template -void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias, - const size_t ldc) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); - size_t LDC = ldc * sizeof(float); - const size_t INNER = BlkLen / 16; - float tmp[4 * 16]; - - if constexpr (HasZeroPoint) { - for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(float); // scale - float * CPtr = C + n; - if (NBLKS < 16) { - CPtr = tmp; - LDC = 16 * sizeof(float); - } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - - __asm__ volatile(LOAD_BIAS - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 64 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 64 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 32 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", + "s5", "s6"); } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(float); // scale + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; LDC = 16 * sizeof(float); } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 64 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 64 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } + + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 32 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", + "s5", "s6"); } } - if (CountN % 16 != 0) { - // stroe output from tmp to C when NBLKS less than 16. - float * CPtr = C + CountN / 16 * 16; - const size_t N = CountN % 16; - LDC = ldc * sizeof(float); - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi s2, %[SRC], 64 \n\t" - "addi s3, %[SRC], 64*2 \n\t" - "addi s4, %[SRC], 64*3 \n\t" - "vle32.v v2, (s2) \n\t" - "vle32.v v4, (s3) \n\t" - "vle32.v v6, (s4) \n\t" - "add t2, %[DST], %[LDC] \n\t" - "add t3, t2, %[LDC] \n\t" - "add t4, t3, %[LDC] \n\t" - "vse32.v v0, (%[DST]) \n\t" - "vse32.v v2, (t2) \n\t" - "vse32.v v4, (t3) \n\t" - "vse32.v v6, (t4) \n\t" - : - : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) - : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); - } } template -void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); +void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const uint8_t * QuantA, + const uint8_t * QuantBData, + float * C, + size_t CountN, + size_t BlockCountK, + const size_t ldc) { + GGML_UNUSED(ldc); size_t INNER = BlkLen / 16; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(_Float16); // scale - float * CPtr = C + n; - size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - // zp offset - "addi s7, %[B], 32 \n\t" - // a offset - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 48 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 72 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 120 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - - "vsetvli t0, zero, e32, mf2 \n\t" - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - "addi s7, s1, 32 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - - "addi s7, %[B], 32 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 48 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 72 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 120 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - "addi s7, s1, 32 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } - } - } else { - for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(_Float16); // scale + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 56 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 80 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 104 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - - "vsetvli t0, zero, e32, mf2 \n\t" - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 56 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 80 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 104 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } - } - } -} -template -void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); - const size_t INNER = BlkLen / 16; - if constexpr (HasZeroPoint) { - for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(float); // scale - float * CPtr = C + n; - size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0 - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - // zp offset - "addi s7, %[B], 64 \n\t" - // a offset - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - - // load scale - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 80 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 96 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 112 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 128 \n\t" - - // load a scale - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - - // a scale * b scale - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - "addi s7, s1, 64 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - - "addi s7, %[B], 64 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 80 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 96 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 112 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 128 \n\t" - - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - "addi s7, s1, 64 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s7, %[B], 32 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 48 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 72 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 120 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + "addi s7, s1, 32 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(float); // scale + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 80 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 112 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 80 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 112 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } - } - } -} - -template -inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t BlockStrideQuantB, - const float * Bias, - const size_t ldc, - const size_t scalestride) { - if (scalestride == 4) { - SQ4BitGemmM4Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, - CountN, BlockStrideQuantB, Bias, ldc); - - } else if (scalestride == 2) { - SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl( - BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc); - } -} -template -inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t BlockStrideQuantB, - const float * Bias, - const size_t ldc, - const size_t scalestride) { - if (scalestride == 4) { - SQ4BitGemmM1Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, - CountN, BlockStrideQuantB, Bias); - } else if (scalestride == 2) { - SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias); + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 56 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 80 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 104 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } } } - } // namespace namespace ime1 { -size_t gemm_kernel_i8i4(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - size_t ldc, - const float * Bias, - const size_t ScaleStride) { - GGML_UNUSED(CountM); - GGML_UNUSED(CountK); - GGML_UNUSED(ldc); - if (CountM >= 4) { - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, - C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { + if (quant_b_zp != nullptr) { + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks, + ldc); } else { - SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, - ldc, ScaleStride); + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, + k_blks, ldc); } return 4; } else { - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, - C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); + if (quant_b_zp != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks, + ldc); } else { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, - ldc, ScaleStride); + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, + k_blks, ldc); } return 1; } } } // namespace ime1 -} // namespace sqnbitgemm_spacemit_ime +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp new file mode 100644 index 00000000000..0c7a036a92a --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp @@ -0,0 +1,5768 @@ +#include "ggml-impl.h" +#include "ggml.h" +#include "ime_kernels.h" +#include "rvv_kernels.h" +#include "string.h" + +#include +#include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME2) +#else +# error "RISCV64_SPACEMIT_IME2 not defined" +#endif + +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Woverlength-strings" +# pragma GCC diagnostic ignored "-Wcast-qual" +# pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +namespace spacemit_kernels { +namespace ime2 { + +template +void gemm_kernel_i8i2k_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q2_k; + constexpr float refactor_scale = 16.0f; + constexpr float factor_scale = 1.0f / refactor_scale; + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = sizeof(blk_type); + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + blk_type * quant_b_blk_data = (blk_type *) (quant_b_data); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + uint8_t * b_data = quant_b_blk_data->qs; + uint8_t * scales = quant_b_blk_data->scales; + uint8_t * scales16 = (uint8_t *) (quant_b_blk_data->scales16); + uint8_t * zeros16 = (uint8_t *) (quant_b_blk_data->zeros16); + + _Float16 * scales_fp16 = (_Float16 *) scales16; + _Float16 * zeros_fp16 = (_Float16 *) zeros16; + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS * 16); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS * 16); + + memset(output_f16, 0, sizeof(output_f16)); + + uint8_t * scales_temp = scales; + uint8_t * zps_temp = scales; + for (size_t kii = 0; kii < 16; kii++, scales_temp += NB_COLS, zps_temp++) { + size_t b_shift = (kii % 4) * 2; + + uint8_t * b_data_col = b_data + (kii / 4) * NB_COLS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + int16_t a_sum = a_sum_row[mi * 16 + kii]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 acc_0 = 0.0; + + uint8_t b_zp = zps_temp[ci * 16] >> 4; + uint8_t b_scale = scales_temp[ci] & 0x0F; + for (size_t bi = 0; bi < 16; bi++) { + int8_t a0 = a_data[mi * 256 + bi + kii * 16]; + uint8_t b0 = b_data_col[ci * 16 + bi]; + acc_0 += static_cast(a0) * static_cast((b0 >> b_shift) & 0x03); + } + + _Float16 scale_item = + static_cast<_Float16>(b_scale) * static_cast<_Float16>(factor_scale) * scales_fp16[ci]; + + output_f16[ci + mi * NB_COLS] += acc_0 * scale_item; + output[ci + mi * NB_COLS] += b_zp * a_sum * a_scale_row[mi] * zeros_fp16[ci]; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + auto a_scale = a_scale_row[mi] * refactor_scale; + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += output_f16[ci + mi * NB_COLS] * a_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i3k_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q2_k; + constexpr float refactor_scale = 16.0f; + constexpr float factor_scale = 1.0f / refactor_scale; + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = sizeof(blk_type); + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + blk_type * quant_b_blk_data = (blk_type *) (quant_b_data); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + uint8_t * b_data = quant_b_blk_data->qs; + uint8_t * b_hmask = quant_b_blk_data->hmask; + int8_t * scales = quant_b_blk_data->scales; + uint8_t * scales16 = (uint8_t *) (quant_b_blk_data->scales16); + + _Float16 * scales_fp16 = (_Float16 *) scales16; + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS * 16); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS * 16); + + memset(output_f16, 0, sizeof(output_f16)); + + int8_t * scales_temp = scales; + uint16_t * b_mask_col = (uint16_t *) b_hmask; + + float acc_0_max = 0.0f; + for (size_t kii = 0; kii < 16; kii++, scales_temp += NB_COLS, b_mask_col += NB_COLS) { + size_t b_shift = (kii % 4) * 2; + + uint8_t * b_data_col = b_data + (kii / 4) * NB_COLS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 acc_0 = 0; + // blk 2 * kii + 0 + uint16_t b_shift_mask = 1; + for (size_t bi = 0; bi < 16; bi++, b_shift_mask <<= 1) { + int8_t a0 = a_data[mi * 256 + bi + kii * 16]; + int8_t b0 = static_cast((b_data_col[ci * 16 + bi] >> b_shift) & 0x03); + b0 -= b_mask_col[ci] & b_shift_mask ? 0 : 4; + acc_0 += static_cast(a0) * static_cast(b0); + } + + _Float16 scale_item = static_cast<_Float16>(scales_temp[ci]) * scales_fp16[ci] * + static_cast<_Float16>(factor_scale); + + output_f16[ci + mi * NB_COLS] += acc_0 * scale_item; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + auto a_scale = a_scale_row[mi] * refactor_scale; + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += output_f16[ci + mi * NB_COLS] * a_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i4_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t kblks_per_blk = 16; + GGML_ASSERT(k_blks % kblks_per_blk == 0); + + int64_t b_blk_stride = (sizeof(_Float16) + (blk_len / 2) + (quant_b_zp ? sizeof(uint8_t) : 0)); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(_Float16); + if (quant_b_zp) { + b_data += NB_COLS * sizeof(uint8_t); + } + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0.0f; + output_f16[ci + mi * NB_COLS] = static_cast<_Float16>(0.0f); + } + } + + size_t kii = 0; + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride, b_data += b_ncol_block_stride) { + _Float16 * b_scale_fp16 = (_Float16 *) (b_data - NB_COLS * sizeof(_Float16)); + uint8_t * b_zp = nullptr; + if (quant_b_zp) { + b_scale_fp16 = (_Float16 *) (b_data - NB_COLS * sizeof(_Float16) - NB_COLS * sizeof(uint8_t)); + b_zp = (uint8_t *) (b_data - NB_COLS * sizeof(uint8_t)); + } + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + _Float16 a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 b_scale = b_scale_fp16[ci]; + int32_t acc = 0; + if (b_zp) { + acc += a_sum * b_zp[ci]; + } else { + acc += a_sum * 8; + } + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t b = b_data[ci * blk_len / 2 + bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output_f16[ci + mi * NB_COLS] += + static_cast(acc) * static_cast(a_scale) * static_cast(b_scale); + } + } + + if (kii == kblks_per_blk - 1) { + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += static_cast(output_f16[ci + mi * NB_COLS]); + output_f16[ci + mi * NB_COLS] = 0.0f; + } + } + kii = 0; + } else { + kii++; + } + } + + if (kii == kblks_per_blk - 1) { + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += static_cast(output_f16[ci + mi * NB_COLS]); + output_f16[ci + mi * NB_COLS] = 0.0f; + } + } + kii = 0; + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i4_hp_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t k_subblks_per_superblk = 8; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * k_subblks_per_superblk + + (quant_b_zp ? NB_COLS * k_subblks_per_superblk * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + + const size_t a_nrow_block_stride = q8_hp_blk_size(blk_len, true, true) * MB_ROWS; + const size_t a_subblk_stride = q8_hp_blk_size(32, false, false) * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const uint8_t * b_tile_base = quant_b_data + (ni / NB_COLS) * b_tile_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0.0f; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride) { + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + const uint8_t * b_superblk_ptr = b_tile_base + ki * b_superblk_stride; + const block_q4_0x32_layout * b_blocks = reinterpret_cast(b_superblk_ptr); + const uint8_t * b_zps = + quant_b_zp ? b_superblk_ptr + sizeof(block_q4_0x32_layout) * k_subblks_per_superblk : nullptr; + + _Float16 * a_sum_row = (_Float16 *) (a_data + a_subblk_stride * k_subblks_per_superblk); + _Float16 * a_scale_avg_row = (_Float16 *) (a_data + a_nrow_block_stride - sizeof(_Float16) * MB_ROWS); + _Float16 scale_factor = a_scale_avg_row[0]; + + for (size_t ksi = 0; ksi < k_subblks_per_superblk; ++ksi) { + const _Float16 * a_scale_row = reinterpret_cast(a_data + a_subblk_stride * ksi); + int8_t * a_subblk = a_data + a_subblk_stride * ksi + MB_ROWS * sizeof(_Float16); + const _Float16 a_scale = a_scale_row[0]; + const block_q4_0x32_layout & b_block = b_blocks[ksi]; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + const uint8_t * b_qs = b_block.qs + ci * 16; + _Float16 b_scale = b_block.d[ci] * a_scale; + + int16_t acc = 0; + for (size_t bi = 0; bi < 16; bi++) { + uint8_t b = b_qs[bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + + acc += static_cast(a_subblk[mi * 32 + 2 * bi]) * static_cast(b0) + + static_cast(a_subblk[mi * 32 + 2 * bi + 1]) * static_cast(b1); + } + + const _Float16 scaled_acc = static_cast<_Float16>(acc) * b_scale; + output_f16[ci + mi * NB_COLS] += scaled_acc; + } + } + } + + for (size_t ksi = 0; ksi < k_subblks_per_superblk; ++ksi) { + const _Float16 * a_scale_row = reinterpret_cast(a_data + a_subblk_stride * ksi); + const block_q4_0x32_layout & b_block = b_blocks[ksi]; + const uint8_t * b_zp_row = b_zps ? b_zps + ksi * NB_COLS : nullptr; + const _Float16 a_scale = a_scale_row[0]; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + const _Float16 a_sum = a_sum_row[mi * k_subblks_per_superblk + ksi]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 b_scale = b_block.d[ci] * a_scale; + _Float16 a_sum_bzp = a_sum; + if (b_zp_row) { + a_sum_bzp = a_sum * static_cast<_Float16>(0.125f) * static_cast<_Float16>(b_zp_row[ci]); + } + + const _Float16 scaled_acc = a_sum_bzp * b_scale; + output[ci + mi * NB_COLS] += scaled_acc * scale_factor; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + auto val = static_cast(output_f16[ci + mi * NB_COLS]) * static_cast(scale_factor); + output[ci + mi * NB_COLS] += val; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void moe_gemm_kernel_i8i4_mrow_ref(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_blk_stride = (sizeof(ggml_fp16_t) + (blk_len / 2) + (quant_b_zp ? sizeof(uint8_t) : 0)); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + std::array a_data; + std::array c_data; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + c_data[mi] = c_ptr[mi]; + } + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(ggml_fp16_t); + if (quant_b_zp) { + b_data += NB_COLS * sizeof(uint8_t); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, b_data += b_ncol_block_stride) { + ggml_fp16_t * b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t)); + uint8_t * b_zp = nullptr; + if (quant_b_zp) { + b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t) - NB_COLS * sizeof(uint8_t)); + b_zp = (uint8_t *) (b_data - NB_COLS * sizeof(uint8_t)); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(b_scale_fp16[ci]); + int32_t acc = 0; + if (b_zp) { + acc += a_sum * b_zp[ci]; + } else { + acc += a_sum * 8; + } + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = (a_data[mi])[2 * bi]; + int8_t a1 = (a_data[mi])[2 * bi + 1]; + uint8_t b = b_data[ci * blk_len / 2 + bi]; + int8_t b0 = static_cast(b & 0x0F); + int8_t b1 = static_cast((b & 0xF0) >> 4); + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + (c_data[mi])[ci] = output[mi * NB_COLS + ci]; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + c_data[mi] += NB_COLS; + } + } +} + +template +void moe_gemm_kernel_i8i5_mrow_ref(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + GGML_UNUSED(count_m); + GGML_UNUSED(ldc); + + // blk_len is expected to be 32 for Q5 types. + int64_t a_blk_stride = q8_blk_size(blk_len, true); + + float output[MB_ROWS * NB_COLS] = { 0 }; + std::array a_data; + std::array c_data; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + c_data[mi] = c_ptr[mi]; + } + + if (quant_b_zp) { + using blk_type = nrow_block_q5_1; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data + (ni / NB_COLS) * k_blks; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < NB_COLS; ++ci) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ++ki, ++quant_b_blk_data) { + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ++ci) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + uint8_t b_zp_val = quant_b_blk_data->zp[ci]; + int32_t acc = a_sum * static_cast(b_zp_val); + + for (size_t bi = 0; bi < blk_len / 2; ++bi) { + int8_t a0 = a_data[mi][2 * bi]; + int8_t a1 = a_data[mi][2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < nb_real; ++ci) { + c_data[mi][ci] = output[mi * NB_COLS + ci]; + } + c_data[mi] += NB_COLS; + } + } + } else { + using blk_type = nrow_block_q5_0; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data + (ni / NB_COLS) * k_blks; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < NB_COLS; ++ci) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ++ki, ++quant_b_blk_data) { + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ++ci) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + int32_t acc = a_sum * 16; + + for (size_t bi = 0; bi < blk_len / 2; ++bi) { + int8_t a0 = a_data[mi][2 * bi]; + int8_t a1 = a_data[mi][2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < nb_real; ++ci) { + c_data[mi][ci] = output[mi * NB_COLS + ci]; + } + c_data[mi] += NB_COLS; + } + } + } +} + +template +void gemm_kernel_i8i8_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_blk_stride = (sizeof(ggml_fp16_t) + blk_len); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + int8_t * b_data = (int8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(ggml_fp16_t); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride, b_data += b_ncol_block_stride) { + ggml_fp16_t * b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t)); + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(b_scale_fp16[ci]); + int32_t acc = 0; + for (size_t bi = 0; bi < blk_len; bi++) { + int8_t a0 = a_data[mi * blk_len + bi]; + int8_t b0 = b_data[ci * blk_len + bi]; + acc += static_cast(a0) * static_cast(b0); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template +void gemm_kernel_i8i5_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // blk_len is expected to be 32 for Q5 types + // quant_b_zp != nullptr => nrow_block_q5_1 (has zp) + // quant_b_zp == nullptr => nrow_block_q5_0 (no zp) + + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + if (quant_b_zp) { + // nrow_block_q5_1: scales16[NB_COLS] + zp[NB_COLS] + qh[4*NB_COLS] + qs[16*NB_COLS] + using blk_type = nrow_block_q5_1; + int64_t b_ncol_block_stride = sizeof(blk_type); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + uint8_t b_zp_val = quant_b_blk_data->zp[ci]; + int32_t acc = a_sum * static_cast(b_zp_val); + + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract high bits from qh + // qh is packed as 4 bytes per column (32 bits for 32 elements) + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } + } else { + // nrow_block_q5_0: scales16[NB_COLS] + qh[4*NB_COLS] + qs[16*NB_COLS] + using blk_type = nrow_block_q5_0; + int64_t b_ncol_block_stride = sizeof(blk_type); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + // Q5_0 has no zp, use default offset 16 (midpoint of 5-bit unsigned range) + int32_t acc = a_sum * 16; + + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast(qs_byte & 0x0F); + int8_t b1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract high bits from qh + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } + } +} + +template +void gemm_kernel_i8mxfp4_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // blk_len is expected to be 32 (QK_MXFP4) + // quant_b_zp is unused for MXFP4 (symmetric quantization) + GGML_UNUSED(quant_b_zp); + + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + using blk_type = nrow_block_mxfp4; + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = GGML_E8M0_TO_FP32_HALF(quant_b_blk_data->e[ci]); + + // Read 32 sign bits for this column + uint32_t sign_bits; + memcpy(&sign_bits, &quant_b_blk_data->qh[ci * 4], 4); + + int32_t acc = 0; + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + + // qs[ci*16 + bi] stores abs(vals[bi*2]) in low 4 bits + // and abs(vals[bi*2+1]) in high 4 bits + uint8_t qs_byte = quant_b_blk_data->qs[ci * 16 + bi]; + int8_t b_abs0 = static_cast(qs_byte & 0x0F); + int8_t b_abs1 = static_cast((qs_byte >> 4) & 0x0F); + + // Extract sign bits: bit (2*bi) for vals[2*bi], bit (2*bi+1) for vals[2*bi+1] + int8_t b0 = (sign_bits >> (2 * bi)) & 1 ? -b_abs0 : b_abs0; + int8_t b1 = (sign_bits >> (2 * bi + 1)) & 1 ? -b_abs1 : b_abs1; + + acc += static_cast(a0) * static_cast(b0) + + static_cast(a1) * static_cast(b1); + } + output[ci + mi * NB_COLS] += static_cast(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +void gemm_kernel_i8i2k_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + using blk_type = nrow_block_q2_k; + + int64_t b_ncol_block_stride = sizeof(blk_type) * k_blks; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_ncol_block_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = (float *) c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv s1, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "addi %[A], %[A], 4 \n\t" + + "li t1, 4 \n\t" + "addi t2, %[B], 512 \n\t" // B data addr + "addi t3, %[A], 32 \n\t" // A data addr + "addi s3, %[B], 0 \n\t" + "vxor.vv v30, v29, v29 \n\t" // tmp result + + "INNER_K_LOOP%=: \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vxor.vv v2, v2, v2 \n\t" + "vxor.vv v3, v3, v3 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + + // load scale B + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (%[B]) \n\t" + "addi %[B], %[B], 128 \n\t" + + // A data, 1x64@i8 + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v2, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v4, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v5, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v6, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetvli t0, x0, e64, mf2 \n\t" + "vslideup.vi v3, v4, 2 \n\t" + "vslideup.vi v28, v5, 4 \n\t" + "vslideup.vi v29, v6, 6 \n\t" + + // init the accumu to zero + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v20, v18, v18 \n\t" + "vxor.vv v22, v18, v18 \n\t" + "vxor.vv v24, v18, v18 \n\t" + "vxor.vv v26, v18, v18 \n\t" + + // B data, 32x64@i2 + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (t2) \n\t" + "addi t2, t2, 512 \n\t" + "vand.vi v8, v4, 0x3 \n\t" // 0-15 + "vsrl.vi v9, v4, 2 \n\t" + "vsrl.vi v10, v4, 4 \n\t" + "vsrl.vi v11, v4, 6 \n\t" // 48-63 + "vand.vi v9, v9, 0x3 \n\t" // 16-31 + "vand.vi v10, v10, 0x3 \n\t" // 32-47 + + "vand.vi v12, v5, 0x3 \n\t" // 0-15 + "vsrl.vi v13, v5, 2 \n\t" + "vsrl.vi v14, v5, 4 \n\t" + "vsrl.vi v15, v5, 6 \n\t" // 48-63 + "vand.vi v13, v13, 0x3 \n\t" // 16-31 + "vand.vi v14, v14, 0x3 \n\t" // 32-47 + + "vand.vi v16, v6, 0x3 \n\t" // 0-15 + "vsrl.vi v17, v6, 2 \n\t" + "vsrl.vi v18, v6, 4 \n\t" + "vsrl.vi v19, v6, 6 \n\t" // 48-63 + "vand.vi v17, v17, 0x3 \n\t" // 16-31 + "vand.vi v18, v18, 0x3 \n\t" // 32-47 + + "vand.vi v4, v7, 0x3 \n\t" // 0-15 + "vsrl.vi v5, v7, 2 \n\t" + "vsrl.vi v6, v7, 4 \n\t" + "vsrl.vi v7, v7, 6 \n\t" // 48-63 + "vand.vi v5, v5, 0x3 \n\t" // 16-31 + "vand.vi v6, v6, 0x3 \n\t" // 32-47 + + // i2 * i8 vmadot + "vsetvli t0, x0, e8, m1 \n\t" + "vmadotsu v20, v2, v8, i8 \n\t" + "vmadotsu v22, v2, v12, i8 \n\t" + "vmadotsu v24, v2, v16, i8 \n\t" + "vmadotsu v26, v2, v4, i8 \n\t" + + "vmadotsu v20, v3, v9, i8 \n\t" + "vmadotsu v22, v3, v13, i8 \n\t" + "vmadotsu v24, v3, v17, i8 \n\t" + "vmadotsu v26, v3, v5, i8 \n\t" + + "vmadotsu v20, v28, v10, i8 \n\t" + "vmadotsu v22, v28, v14, i8 \n\t" + "vmadotsu v24, v28, v18, i8 \n\t" + "vmadotsu v26, v28, v6, i8 \n\t" + + "vmadotsu v20, v29, v11, i8 \n\t" + "vmadotsu v22, v29, v15, i8 \n\t" + "vmadotsu v24, v29, v19, i8 \n\t" + "vmadotsu v26, v29, v7, i8 \n\t" + + "vand.vi v10, v0, 0xf \n\t" // scale + "vwadd.vx v12, v10, x0 \n\t" + "vsetvli t0, x0, e16, m2 \n\t" + "vwadd.vx v16, v12, x0 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v4, v24, v26, 2 \n\t" + "vpack.vv v6, v2, v4, 3 \n\t" // 0,1 + "vpack.vv v8, v3, v5, 3 \n\t" // 2,3 + + // mul scale + "vmacc.vv v30, v6, v16 \n\t" + "vmacc.vv v30, v7, v17 \n\t" + "vmacc.vv v30, v8, v18 \n\t" + "vmacc.vv v30, v9, v19 \n\t" + + "addi t1, t1, -1 \n\t" + "bgtz t1, INNER_K_LOOP%= \n\t" + + // load zp B + "vsetvli t0, x0, e8, m4 \n\t" + "vle8.v v4, (s3) \n\t" + "vsrl.vi v8, v4, 4 \n\t" // zp + + // asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + + "vsetvli t0, x0, e16, mf4 \n\t" + "vle16.v v2, (%[A]) \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vnsrl.wi v12, v2, 0 \n\t" // low 8 + "vnsra.wi v13, v2, 8 \n\t" // high 8 + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v20, v13, v8, i8 \n\t" + "vmadotsu v22, v13, v9, i8 \n\t" + "vmadotsu v24, v13, v10, i8 \n\t" + "vmadotsu v26, v13, v11, i8 \n\t" + + "vsll.vi v20, v20, 8 \n\t" + "vsll.vi v22, v22, 8 \n\t" + "vsll.vi v24, v24, 8 \n\t" + "vsll.vi v26, v26, 8 \n\t" + + "vmadotu v20, v12, v8, i8 \n\t" + "vmadotu v22, v12, v9, i8 \n\t" + "vmadotu v24, v12, v10, i8 \n\t" + "vmadotu v26, v12, v11, i8 \n\t" + + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v4, v24, v26, 2 \n\t" + "vpack.vv v28, v2, v4, 3 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v0, (t2) \n\t" // scale16 + "addi t2, t2, 64 \n\t" + "vle16.v v1, (t2) \n\t" // zero16 + "vfwcvt.f.f.v v2, v0 \n\t" + "vfwcvt.f.f.v v4, v1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v30, v30 \n\t" + "vfcvt.f.x.v v28, v28 \n\t" + "addi %[B], t2, 64 \n\t" + "mv %[A], t3 \n\t" + + "vfmul.vv v30, v30, v2 \n\t" // mul scale16 + "vfmacc.vv v30, v28, v4 \n\t" // + mul zero16 + "vfmacc.vf v31, fa0, v30 \n\t" + "addi s1, s1, -1 \n\t" + "bgtz s1, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v31, (%[DST]) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "t4", "t5", "t6", "s1", "s2", "s3"); + } +} + +void gemm_kernel_i8i2k_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + using blk_type = nrow_block_q2_k; + + int64_t b_ncol_block_stride = sizeof(blk_type) * k_blks; + _Float16 scale = 0.0625f; + _Float16 scale_1 = 16.0f; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_ncol_block_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = (float *) c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v31, v31 \n\t" // init result + "vxor.vv v29, v31, v31 \n\t" + "vxor.vv v30, v31, v31 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv s1, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + "li t1, 4 \n\t" + "addi t2, %[B], 512 \n\t" // B data addr + "addi t3, %[A], 128 \n\t" // A data addr + "addi s4, t2, 1024 \n\t" // scale16 addr + "addi s4, s4, 1024 \n\t" // TODO + "addi s3, %[B], 0 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s4) \n\t" // load scale16 + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v22, v1, v1, 3 \n\t" + + "addi s4, t3, 256 \n\t" // addr 1 + "addi s5, t3, 512 \n\t" // addr 2 + "addi s6, t3, 768 \n\t" // addr 3 + + // init the accu to 0 + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v25, v25, v25 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v27, v27, v27 \n\t" + + "INNER_K_LOOP%=: \n\t" + // load scale B + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (%[B]) \n\t" + "addi %[B], %[B], 128 \n\t" + "vand.vi v1, v1, 0xf \n\t" + + "vfwcvt.f.x.v v20, v1 \n\t" // f16 scale B + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vv v0, v20, v22 \n\t" // mul scale16 + "vfmul.vv v1, v21, v22 \n\t" // mul scale16 + "vfmul.vf v0, v0, %[SCALE] \n\t" // mul magic + "vfmul.vf v1, v1, %[SCALE] \n\t" // mul magic + + // A data, 4x64@i8 + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (t3) \n\t" + "addi t3, t3, 64 \n\t" + "vle8.v v3, (s4) \n\t" + "addi s4, s4, 64 \n\t" + "vle8.v v4, (s5) \n\t" + "addi s5, s5, 64 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 64 \n\t" + + // 4x64 => 4x16x4 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v6, v2, v3, 1 \n\t" + "vpack.vv v8, v4, v5, 1 \n\t" + "vpack.vv v2, v6, v8, 2 \n\t" // 0, 2 + + "vpack.vv v20, v2, v2, 3 \n\t" // 1 + "vor.vv v23, v21, v21 \n\t" + "vpack.vv v20, v3, v3, 3 \n\t" // 3 + + // B data, 32x64@i2 + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (t2) \n\t" + "addi t2, t2, 512 \n\t" + "vand.vi v8, v4, 0x3 \n\t" // 0-15 + "vsrl.vi v9, v4, 2 \n\t" + "vsrl.vi v10, v4, 4 \n\t" + "vsrl.vi v11, v4, 6 \n\t" // 48-63 + "vand.vi v9, v9, 0x3 \n\t" // 16-31 + "vand.vi v10, v10, 0x3 \n\t" // 32-47 + + "vand.vi v12, v5, 0x3 \n\t" // 0-15 + "vsrl.vi v13, v5, 2 \n\t" + "vsrl.vi v14, v5, 4 \n\t" + "vsrl.vi v15, v5, 6 \n\t" // 48-63 + "vand.vi v13, v13, 0x3 \n\t" // 16-31 + "vand.vi v14, v14, 0x3 \n\t" // 32-47 + + "vand.vi v16, v6, 0x3 \n\t" // 0-15 + "vsrl.vi v17, v6, 2 \n\t" + "vsrl.vi v18, v6, 4 \n\t" + "vsrl.vi v19, v6, 6 \n\t" // 48-63 + "vand.vi v17, v17, 0x3 \n\t" // 16-31 + "vand.vi v18, v18, 0x3 \n\t" // 32-47 + + "vand.vi v4, v7, 0x3 \n\t" // 0-15 + "vsrl.vi v5, v7, 2 \n\t" + "vsrl.vi v6, v7, 4 \n\t" + "vsrl.vi v7, v7, 6 \n\t" // 48-63 + "vand.vi v5, v5, 0x3 \n\t" // 16-31 + "vand.vi v6, v6, 0x3 \n\t" // 32-47 + + // i2 * i8 vmadot + "vsetvli t0, x0, e8, m1 \n\t" + "vmadotsu.hp v24, v2, v8, v0, 0, i8 \n\t" + "vmadotsu.hp v25, v2, v12, v0, 1, i8 \n\t" + "vmadotsu.hp v26, v2, v16, v0, 2, i8 \n\t" + "vmadotsu.hp v27, v2, v4, v0, 3, i8 \n\t" + + "vmadotsu.hp v24, v23, v9, v0, 4, i8 \n\t" + "vmadotsu.hp v25, v23, v13, v0, 5, i8\n\t" + "vmadotsu.hp v26, v23, v17, v0, 6, i8\n\t" + "vmadotsu.hp v27, v23, v5, v0, 7, i8 \n\t" + + "vmadotsu.hp v24, v3, v10, v1, 0, i8 \n\t" + "vmadotsu.hp v25, v3, v14, v1, 1, i8 \n\t" + "vmadotsu.hp v26, v3, v18, v1, 2, i8 \n\t" + "vmadotsu.hp v27, v3, v6, v1, 3, i8 \n\t" + + "vmadotsu.hp v24, v21, v11, v1, 4, i8\n\t" + "vmadotsu.hp v25, v21, v15, v1, 5, i8\n\t" + "vmadotsu.hp v26, v21, v19, v1, 6, i8\n\t" + "vmadotsu.hp v27, v21, v7, v1, 7, i8 \n\t" + + "addi t1, t1, -1 \n\t" + "bgtz t1, INNER_K_LOOP%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v2, v24, v25, 1 \n\t" + "vpack.vv v4, v26, v27, 1 \n\t" + "vpack.vv v6, v2, v4, 2 \n\t" // 0,1,2,3 + + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vxor.vv v24, v24, v24 \n\t" + // load zp B, 16x8x4@int4 + "vsetvli t0, x0, e8, m4 \n\t" + "vle8.v v0, (s3) \n\t" + "vsrl.vi v0, v0, 4 \n\t" // zp + + // 4x16@int16 + "vsetvli t0, x0, e16, m1 \n\t" // a sum + "vle16.v v12, (%[A]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnsrl.wi v10, v12, 0 \n\t" // low 8 + "vnsra.wi v11, v12, 8 \n\t" // high 8 + + // asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v18, v11, v0, i8 \n\t" + "vmadotsu v20, v11, v1, i8 \n\t" + "vmadotsu v22, v11, v2, i8 \n\t" + "vmadotsu v24, v11, v3, i8 \n\t" + "vsll.vi v18, v18, 8 \n\t" + "vsll.vi v20, v20, 8 \n\t" + "vsll.vi v22, v22, 8 \n\t" + "vsll.vi v24, v24, 8 \n\t" + "vmadotu v18, v10, v0, i8 \n\t" + "vmadotu v20, v10, v1, i8 \n\t" + "vmadotu v22, v10, v2, i8 \n\t" + "vmadotu v24, v10, v3, i8 \n\t" + + "vpack.vv v10, v18, v20, 2 \n\t" + "vpack.vv v12, v22, v24, 2 \n\t" + "vpack.vv v14, v10, v12, 3 \n\t" + "vpack.vv v16, v11, v13, 3 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "addi t2, t2, 64 \n\t" + "vle16.v v20, (t2) \n\t" // zero16 + "vfwcvt.f.f.v v22, v20 \n\t" + + // mul 1/magic + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmul.vf v0, v6, %[SCALE_1] \n\t" + "vfwmul.vf v2, v7, %[SCALE_1] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v14, v14 \n\t" + "vfcvt.f.x.v v15, v15 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + + "addi %[B], t2, 64 \n\t" + "mv %[A], s6 \n\t" + + "vfmacc.vv v0, v14, v22 \n\t" // + mul zero16 + "vfmacc.vv v1, v15, v22 \n\t" + "vfmacc.vv v2, v16, v22 \n\t" + "vfmacc.vv v3, v17, v22 \n\t" + + "vfmacc.vf v28, fa0, v0 \n\t" // mul a scale + "vfmacc.vf v29, fa1, v1 \n\t" + "vfmacc.vf v30, fa2, v2 \n\t" + "vfmacc.vf v31, fa3, v3 \n\t" + + "addi s1, s1, -1 \n\t" + "bgtz s1, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "vse32.v v29, (t1) \n\t" + "add t1, t1, %[LDC] \n\t" + "vse32.v v30, (t1) \n\t" + "add t1, t1, %[LDC] \n\t" + "vse32.v v31, (t1) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks), [LDC] "r"(ldc * 4), [SCALE] "f"(scale), [SCALE_1] "f"(scale_1) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "s5", "s6"); + } +} + +void gemm_kernel_i8i3k_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; //only support 32 in ASM + using blk_type = nrow_block_q3_k; + + const blk_type * b_base = reinterpret_cast(quant_b_data); + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride; + int64_t b_ncol_block_stride = sizeof(blk_type); + + // Constants used by q3_k scaling in HP branch: + // - k_q3k_scale_step: per-nibble scale factor (1/16). + // - k_a_scale_post_mul: A_scale needs an extra *16 at the end (pairs with 1/16 above). + const _Float16 k_q3k_scale_step = (_Float16) 0.0625f; // 1 / 16 + const float k_a_scale_post_mul = 16.0f; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const blk_type * quant_b_blk_data = b_base + (ni / NB_COLS) * k_blks; +#if 0 + //------------------------------------------------------------------------------ + // A format + // Ascale fp32 * 1 32bit + // Asum int16 * 16 256bit + // A M1K256 int8 2048bit + //------------------------------------------------------------------------------ + // B format + // B_scl uint8*N32*16 4096bit + // B_Hmask N32K16*16 1bit 8192bit + // B_Qs N32K16*16 2bit 16384bit + // B scl16 fp16 * N32 512bit; + //------------------------------------------------------------------------------ + //bias always be nullptr + __asm__ volatile( + // t2 = k_blks (each is K256 superblock) + "mv t2, %[KBLKS] \n\t" + // t3 = 256/64 = 4 (K64 iterations per superblock) + "li t3, 4 \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+32 \n\t" // s3 = pAData, (pA+AScl+ASum) + + // B block layout for nrow_block_q3_k<32>: + // scales: 512B, hmask: 1024B, qs: 2048B, scales16: 64B + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs + "mv s7, %[pB] \n\t" // s7 = pB_base + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v31, v0, v0 \n\t" // clear acc + "vxor.vv v30, v0, v0 \n\t" // clear acc of K256 + + // ordinary vmadot: vle*10 vecIns*78 vmadot*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + "K64_LPST%=: \n\t" + + // K0-15 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v2, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // load B qs chunk (128B per K16, 16 times => 2048B) + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (s3) \n\t" + "addi s3, s3, 64 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" // N0-N31 in v16 + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + //K16-31 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 4 \n\t" + "vsll.vi v9, v5, 4 \n\t" + "vsll.vi v10, v6, 4 \n\t" + "vsll.vi v11, v7, 4 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v8, 6 \n\t" + "vsrl.vi v13, v9, 6 \n\t" + "vsrl.vi v14, v10, 6 \n\t" + "vsrl.vi v15, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" // N0-N31 in v16 + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + //K32-47 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 2 \n\t" + "vsll.vi v9, v5, 2 \n\t" + "vsll.vi v10, v6, 2 \n\t" + "vsll.vi v11, v7, 2 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v8, 6 \n\t" + "vsrl.vi v13, v9, 6 \n\t" + "vsrl.vi v14, v10, 6 \n\t" + "vsrl.vi v15, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + // K48-63 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v12, v4, 6 \n\t" + "vsrl.vi v13, v5, 6 \n\t" + "vsrl.vi v14, v6, 6 \n\t" + "vsrl.vi v15, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // load A scale (fp32) and advance A to next superblock + "flw f0, (s2) \n\t" + "addi s2, s2, 4+32+256 \n\t" + "add t4, s7, %[B_STR] \n\t" // t4 = next B blk base + "addi s3, s2, 4+32 \n\t" + + // load B scales16[32] (fp16) at end of qs region + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v2, (s6) \n\t" + + // pointer modify + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s7, t4, 0 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v2 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v30 \n\t" + "vfmul.vf v1, v24, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v31, v1, v26 \n\t" + + // next K-superblock + "addi t2, t2, -1 \n\t" + "vxor.vv v30, v0, v0 \n\t" // clear acc of K256 + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v31, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "s2", "s3", "s4", "s5", "s6", "s7"); +#else + + __asm__ volatile( + // ========================= + // Kernel overview (M1 x N32) + // ========================= + // Process one output row (M=1) and 32 columns (N=32) per call. + // + // Loop structure: + // - Outer loop: K superblocks of size K=256 (k_blks times) + // - Each K256 superblock is broken into 4 x K64 + // - Each K64 is processed as 4 x K16 "sub-blocks" (via unpack+dot) + // + // Data layout (high level): + // A (q8k K=256, per superblock): + // [ fp32 a_scale ][ int16 a_sum[16] ][ int8 a_qs[256] ] + // B (nrow_block_q3_k<32>, per superblock): + // [ int8 scales[32*16] ][ hmask[1024] ][ qs[2048] ][ fp16 scales16[32] ] + // + // Registers/pointers: + // s2: pA (points at A superblock header; used to load fp32 a_scale) + // s3: pA_qs (points at A int8 data within the current superblock) + // s4: pB_scales (points at B int8 per-K16 scales) + // s5: pB_hmask (points at B sign mask area) + // s6: pB_qs (points at B 2-bit packed qs area) + // s8: pB_scales16 (points at B fp16 scales16[32] at the end of block) + // s7: pB_base (base pointer to current B block; used for block-to-block stride) + + // t2 = number of K256 superblocks + "mv t2, %[KBLKS] \n\t" + // t3 = number of K64 chunks per K256 superblock (256 / 64) + "li t3, 4 \n\t" + + // A pointers + "mv s2, %[pA] \n\t" // s2 = pA_superblock (a_scale at +0) + "addi s3, %[pA], 4+32 \n\t" // s3 = pA_qs (skip a_scale + a_sum[16]) + + // B pointers for nrow_block_q3_k<32> + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask (skip scales[32*16]) + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs (skip hmask) + // scales16 is at the end of the block: qs(2048) after hmask + "addi s8, s6, 1024 \n\t" + "addi s8, s8, 1024 \n\t" // s8 = pB_scales16 (fp16 scales16[32]) + "mv s7, %[pB] \n\t" // s7 = pB_base (for next-block address calc) + + // v31: final FP32 accumulator for N=32 + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v31, v0, v0 \n\t" + + // ---- Preload B scales16[32] and build FP16 scale vector used by vmadot.hp ---- + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s8) \n\t" // load fp16 scales16[32] + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v26, v1, v1, 3 \n\t" // broadcast/pack to match lanes + "vmv.v.v v17, v26 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vf v30, v17, %[q3_step] \n\t" // v30 = scales16 * (1/16) + + // v24-v27: fp16 partial accumulators for a K64 chunk (vmadot.hp outputs) + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v25, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v27, v16, v16 \n\t" + + // HP vmadot: vle*10 vecIns*38 vmadot.hp*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" // loop over K256 superblocks + "K64_LPST%=: \n\t" // loop over 4 x K64 chunks + + // ------------------------------------------------------------ + // K0-15: load B scales + {hmask, qs} + A data; unpack and dot + // ------------------------------------------------------------ + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v2, (s4) \n\t" // B int8 scales for this K16 + "addi s4, s4, 128 \n\t" + + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" // B hmask for this K16 + "addi s5, s5, 64 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v3, (s3) \n\t" // A int8 data for this K16 + "addi s3, s3, 64 \n\t" + + // Convert B int8 scales to FP16 and apply scales16*(1/16) + "vsetvli t0, x0, e8, m1 \n\t" + "vfwcvt.f.x.v v28, v2 \n\t" // int8 -> fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vv v1, v28, v30 \n\t" // v1: FP16 scale vector for vmadot.hp + "vfmul.vv v29, v29, v30 \n\t" + + // Unpack B 2-bit qs + hmask -> signed int8 in v12..v15 + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + // (Next K16 unpack path uses a fresh hmask load) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // Prepare another group from packed qs (bit shifts) + apply sign from hmask + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 4 \n\t" + "vsll.vi v9, v5, 4 \n\t" + "vsll.vi v10, v6, 4 \n\t" + "vsll.vi v11, v7, 4 \n\t" + "vsrl.vi v16, v8, 6 \n\t" + "vsrl.vi v17, v9, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v18, v10, 6 \n\t" + "vsrl.vi v19, v11, 6 \n\t" + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v16, v16, -4, v0.t \n\t" + + // A shift for the second dot within this K64 + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v2, v3, 2 \n\t" + + // Dot products with FP16 scaling (accumulate into v24..v27) + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot.hp v24, v3, v12, v1, 0, i8 \n\t" + "vmadot.hp v25, v3, v13, v1, 1, i8 \n\t" + "vmadot.hp v26, v3, v14, v1, 2, i8 \n\t" + "vmadot.hp v27, v3, v15, v1, 3, i8 \n\t" + "vmadot.hp v24, v2, v16, v1, 4, i8 \n\t" + "vmadot.hp v25, v2, v17, v1, 5, i8 \n\t" + "vmadot.hp v26, v2, v18, v1, 6, i8 \n\t" + "vmadot.hp v27, v2, v19, v1, 7, i8 \n\t" + + // (K32-47 / K48-63 blocks continue unchanged...) + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vmv.v.v v1, v29 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v3, v3, 4 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 2 \n\t" + "vsll.vi v9, v5, 2 \n\t" + "vsll.vi v10, v6, 2 \n\t" + "vsll.vi v11, v7, 2 \n\t" + + "vsrl.vi v20, v8, 6 \n\t" + "vsrl.vi v21, v9, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v22, v10, 6 \n\t" + "vsrl.vi v23, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v20, v20, -4, v0.t \n\t" + + // K48-63 + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v8, v4, 6 \n\t" + "vsrl.vi v9, v5, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v10, v6, 6 \n\t" + "vsrl.vi v11, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v8, v8, -4, v0.t \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v2, v3, 2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot.hp v24, v3, v20, v1, 0, i8 \n\t" + "vmadot.hp v25, v3, v21, v1, 1, i8 \n\t" + "vmadot.hp v26, v3, v22, v1, 2, i8 \n\t" + "vmadot.hp v27, v3, v23, v1, 3, i8 \n\t" + "vmadot.hp v24, v2, v8, v1, 4, i8 \n\t" + "vmadot.hp v25, v2, v9, v1, 5, i8 \n\t" + "vmadot.hp v26, v2, v10, v1, 6, i8 \n\t" + "vmadot.hp v27, v2, v11, v1, 7, i8 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // ---- End of K64 chunk: reduce fp16 accumulators -> fp32 and scale by A ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v24, v25, 1 \n\t" + "vpack.vv v14, v26, v27, 1 \n\t" + "vpack.vv v16, v12, v14, 2 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v26, v16 \n\t" // fp16 -> fp32 vector (qsum * b_scales) + + // Load A scale and advance A pointer to next K256 superblock + "flw f0, (s2) \n\t" + "addi s2, s2, 4+32+256 \n\t" + "add t4, s7, %[B_STR] \n\t" // next B block base + "addi s3, s2, 4+32 \n\t" // reset A data pointer for next block + + // Advance B pointers to next K256 superblock + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s8, s6, 1024 \n\t" + "addi s8, s8, 1024 \n\t" + "addi s7, t4, 0 \n\t" + "addi t2, t2, -1 \n\t" + + // Final per-block scaling: a_scale * 16.0f + "fmul.s f0, f0, %[a_post_mul] \n\t" + // acc += (qsum * b_scales) * (a_scale*16) + "vsetvli t0, x0, e32, m1 \n\t" + "vfmacc.vf v31, f0, v26 \n\t" + + "beqz t2, BLK_LPND%= \n\t" + + // Preload next block's scales16 and rebuild v30 for vmadot.hp + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s8) \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v26, v1, v1, 3 \n\t" + "vmv.v.v v17, v26 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vf v30, v17, %[q3_step] \n\t" + + // Reset fp16 partial accumulators for next K64 loop(s) + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v25, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v27, v16, v16 \n\t" + + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v31, (%[pC]) \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride), [q3_step] "f"(k_q3k_scale_step), + [a_post_mul] "f"(k_a_scale_post_mul) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "f1", "s2", "s3", "s4", "s5", "s6", "s7", "s8"); +#endif + } +} + +void gemm_kernel_i8i3k_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q3_k<32>; + constexpr size_t NB_COLS = 32; //only support 32 in ASM + + const blk_type * b_base = reinterpret_cast(quant_b_data); + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * 4; + int64_t b_ncol_block_stride = sizeof(blk_type); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + const blk_type * quant_b_blk_data = b_base + (ni / NB_COLS) * k_blks; + + //------------------------------------------------------------------------------ + // A format + // Ascale fp32 * 1* 4row 128bit + // Asum int16 * 16 4row 1024bit + // A M1K256 int8 4row 8192bit + //------------------------------------------------------------------------------ + // B format + // B_scl uint8*N32*16 4096bit + // B_Hmask N32K16*16 1bit 8192bit + // B_Qs N32K16*16 2bit 16384bit + // B scl16 fp16 * N32 512bit; + //------------------------------------------------------------------------------ + //bias always be nullptr + __asm__ volatile( + // t2 = k_blks (each is K256 superblock) + "mv t2, %[KBLKS] \n\t" + // t3 = 256/64 = 4 (K64 iterations per superblock) + "li t3, 4 \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 16+128 \n\t" // s3 = pAData, (pA+AScl+ASum) + + // B block layout for nrow_block_q3_k<32>: + // scales: 512B, hmask: 1024B, qs: 2048B, scales16: 64B + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask (skip scales) + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs (skip hmask) + "mv s7, %[pB] \n\t" // s7 = pB_base + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v0, v0 \n\t" // v24-v27: K256 temp accumulator + "vxor.vv v25, v0, v0 \n\t" + "vxor.vv v26, v0, v0 \n\t" + "vxor.vv v27, v0, v0 \n\t" + "vxor.vv v28, v0, v0 \n\t" // v28-v31: final accumulator + "vxor.vv v29, v0, v0 \n\t" + "vxor.vv v30, v0, v0 \n\t" + "vxor.vv v31, v0, v0 \n\t" + + // ordinary vmadot: vle*13 vecIns*96 vmadot*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + "K64_LPST%=: \n\t" + + // ========== K0-15: First K16 sub-block ========== + // Load B INT8 scale factors (32 cols × 16 K16 blocks) + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v8, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // Load B quantized data (32 cols × 16 elements × 2bit, stored in 4 groups) + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + // Load B hmask (32 cols × 16bit sign mask) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // Load A data (4 rows × 16 elements × INT8) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v12, (s3) \n\t" + "addi s3, s3, 256 \n\t" // Jump to next row + "vle8.v v13, (s3) \n\t" + "addi s3, s3, 256 \n\t" + "vle8.v v14, (s3) \n\t" + "addi s3, s3, 256 \n\t" + "vle8.v v15, (s3) \n\t" + "addi s3, s3, -768+64 \n\t" // Back to first row, advance 16 elements + + // Pack A data: merge 4 rows into 2 vectors + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v13, 1 \n\t" + "vpack.vv v18, v14, v15, 1 \n\t" + "vpack.vv v2, v16, v18, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v2, v12, i8 \n\t" // 4 rows × cols 0-7 + "vmadot v18, v2, v13, i8 \n\t" // 4 rows × cols 8-15 + "vmadot v20, v2, v14, i8 \n\t" // 4 rows × cols 16-23 + "vmadot v22, v2, v15, i8 \n\t" // 4 rows × cols 24-31 + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" // Merge cols 0-15 + "vpack.vv v14, v20, v22, 2 \n\t" // Merge cols 16-31 + "vpack.vv v16, v12, v14, 3 \n\t" // Inter-row results (INT16) + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // INT8 → INT16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // INT16 → INT32 + + // Accumulate to K256 accumulator: qsum * b_scale + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" // Row 0 + "vmacc.vv v25, v17, v23 \n\t" // Row 1 + "vmacc.vv v26, v18, v23 \n\t" // Row 2 + "vmacc.vv v27, v19, v23 \n\t" + + // ========== K16-31, K32-47, K48-63: Similar processing ========== + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 8 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v12, v4, 4 \n\t" + "vsll.vi v13, v5, 4 \n\t" + "vsll.vi v14, v6, 4 \n\t" + "vsll.vi v15, v7, 4 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v12, 6 \n\t" + "vsrl.vi v13, v13, 6 \n\t" + "vsrl.vi v14, v14, 6 \n\t" + "vsrl.vi v15, v15, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v2, v12, i8 \n\t" + "vmadot v18, v2, v13, i8 \n\t" + "vmadot v20, v2, v14, i8 \n\t" + "vmadot v22, v2, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + //K32-47 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v12, v4, 2 \n\t" + "vsll.vi v13, v5, 2 \n\t" + "vsll.vi v14, v6, 2 \n\t" + "vsll.vi v15, v7, 2 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v12, 6 \n\t" + "vsrl.vi v13, v13, 6 \n\t" + "vsrl.vi v14, v14, 6 \n\t" + "vsrl.vi v15, v15, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v12, i8 \n\t" + "vmadot v18, v3, v13, i8 \n\t" + "vmadot v20, v3, v14, i8 \n\t" + "vmadot v22, v3, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + // K48-63 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v3, v3, 8 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v12, v4, 6 \n\t" + "vsrl.vi v13, v5, 6 \n\t" + "vsrl.vi v14, v6, 6 \n\t" + "vsrl.vi v15, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v12, i8 \n\t" + "vmadot v18, v3, v13, i8 \n\t" + "vmadot v20, v3, v14, i8 \n\t" + "vmadot v22, v3, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // ========== K256 superblock complete, apply scale factors ========== + // Load A's 4 row scale factors (FP32) + "flw f0, (s2) \n\t" + "flw f1, 4(s2) \n\t" + "flw f2, 8(s2) \n\t" + "flw f3, 12(s2) \n\t" + "add s2, s2, %[A_STR] \n\t" // Advance to next superblock + "add t4, s7, %[B_STR] \n\t" // t4 = next B block address + "addi s3, s2, (4+32)*4 \n\t" + + // Load B FP16 global scale factors (32 cols) + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v8, (s6) \n\t" + + // Update B pointers to next block + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s7, t4, 0 \n\t" + + // ========== Type conversion and final scaling ========== + // FP16 → FP32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v9, v8 \n\t" + + // INT32 → FP32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + + // Compute a_scale * b_scale (4 rows) + "vfmul.vf v12, v9, f0 \n\t" + "vfmul.vf v13, v9, f1 \n\t" + "vfmul.vf v14, v9, f2 \n\t" + "vfmul.vf v15, v9, f3 \n\t" + + // Final accumulation: result += qsum * a_scale * b_scale + "vsetvli t0, x0, e32, m1 \n\t" + "vfmacc.vv v28, v12, v24 \n\t" + "vfmacc.vv v29, v13, v25 \n\t" + "vfmacc.vv v30, v14, v26 \n\t" + "vfmacc.vv v31, v15, v27 \n\t" + + // Prepare for next K superblock + "addi t2, t2, -1 \n\t" + "vxor.vv v24, v0, v0 \n\t" // Clear K256 accumulator + "vxor.vv v25, v0, v0 \n\t" + "vxor.vv v26, v0, v0 \n\t" + "vxor.vv v27, v0, v0 \n\t" + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + + // ========== Store results (4 rows × 32 cols) ========== + "mv t5, %[pC] \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v28, (%[pC]) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v29, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v30, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v31, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "FUNC_END%=: \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride), [A_STR] "r"(a_nrow_block_stride), [LDC] "r"(ldc * 4) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "f1", "f2", "f3", "s2", "s3", "s4", "s5", "s6", "s7"); + } +} + +void gemm_kernel_i8i4_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // b data + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 1024bit + // 4VRF, N32K32, 4096bit + // Bscale fp16 * N32 512bit; + // || scl*32..(fp16) | blk0 blk1 ... blk31 || scl*32..(fp16) | blk0 blk1 ... blk31 || ... + // || Element || Element || ... +#if 0 + //bias always be nullptr + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*21 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu v16, v10, v4, i4 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadotsu v18, v10, v5, i4 \n\t" // M0 N8 - N15 + "vmadotsu v20, v10, v6, i4 \n\t" // M0 N16 - N23 + "vmadotsu v22, v10, v7, i4 \n\t" // M0 N24 - N31 + + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + + "vmadotu v16, v8, v4, i4 \n\t" + "vmadotu v18, v8, v5, i4 \n\t" + "vmadotu v20, v8, v6, i4 \n\t" + "vmadotu v22, v8, v7, i4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v28, 8 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + "vwmul.vx v24, v28, t2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v24 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#else + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsll.vi v1, v0, 4 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + + // vmadot hp: vle*7 flw*1 vecIns*14 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v30, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v28, 8 \n\t" // Bzp u8 -> u16 + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmul.vx v26, v28, t2 \n\t" // asum*zp i16*i16 + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vfcvt.f.x.v v16, v26 \n\t" // zp i16 -> fp16 + "vadd.vi v18, v16, 0 \n\t" + "vadd.vi v20, v16, 0 \n\t" + "vadd.vi v22, v16, 0 \n\t" + + "vmadotsu.hp v16, v10, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v18, v10, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v10, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v22, v10, v7, v1, 0, i4 \n\t" + "vmadotu.hp v16, v8, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v18, v8, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v8, v6, v0, 0, i4 \n\t" + "vmadotu.hp v22, v8, v7, v0, 0, i4 \n\t" + + "vpack.vv v24, v16, v18, 1 \n\t" + "vpack.vv v26, v20, v22, 1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + // mac result * b_scale; f16*f16->f32 + "vfwmul.vv v31, v30, v16 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum * b_scale) * a_scale; + "vfmacc.vf v2, f0, v31 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); + +#endif + } + } else { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // b data + n * k_blks * sizeof(uint8_t) + // b zp + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 1024bit + // 4VRF, N32K32, 4096bit + // Bscale fp16 * N32 512bit; + // Bzp uint8_t * N32 256bit; + // || scl*32..(fp16) | zp*32(uint8) | blk0 blk1 ... blk31 || scl*32..(fp16) ... + // || Element || Element ... + + //bias always be nullptr +#if 0 + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBdata, (pB+BScl+Bzp) + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*21 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+96 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu v16, v10, v4, i4 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadotsu v18, v10, v5, i4 \n\t" // M0 N8 - N15 + "vmadotsu v20, v10, v6, i4 \n\t" // M0 N16 - N23 + "vmadotsu v22, v10, v7, i4 \n\t" // M0 N24 - N31 + + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (s4) \n\t" // Bzp + "addi s4, s4, 32+128*4 \n\t" + + "vmadotu v16, v8, v4, i4 \n\t" + "vmadotu v18, v8, v5, i4 \n\t" + "vmadotu v20, v8, v6, i4 \n\t" + "vmadotu v22, v8, v7, i4 \n\t" + + "vwaddu.vx v28, v1, x0 \n\t" // uint8 -> uint16 + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v24, v28, t2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v24 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#else + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBdata, (pB+BScl+Bzp) + "mv s6, %[pC] \n\t" + + "vsll.vi v1, v0, 4 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + + // vmadot hp: vle*6 flw*1 vecIns*14 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+96 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v30, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v31, (s4) \n\t" // B zp 32Row*uint8 = 256bit + "addi s4, s4, 32+128*4 \n\t" + + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu.hp v16, v10, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v18, v10, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v10, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v22, v10, v7, v1, 0, i4 \n\t" + "vmadotu.hp v16, v8, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v18, v8, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v8, v6, v0, 0, i4 \n\t" + "vmadotu.hp v22, v8, v7, v0, 0, i4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v31, x0 \n\t" // Bzp u8 -> u16 + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v24, v16, v18, 1 \n\t" + "vpack.vv v26, v20, v22, 1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vmul.vx v26, v28, t2 \n\t" // asum*zp i16*i16 + "vfwcvt.f.f.v v22, v30 \n\t" // b_scale fp16 -> fp32 + "vfcvt.f.x.v v18, v26 \n\t" // zp i16 -> fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vfwadd.vv v20, v18, v16 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // mac result * b_scale; f32*f32->f32 + "vfmul.vv v31, v22, v20 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum * b_scale) * a_scale; + "vfmacc.vf v2, f0, v31 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#endif + } + } +} + +void gemm_kernel_i8i4_hp_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t k_subblks_per_superblk = 8; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * k_subblks_per_superblk + + (quant_b_zp ? NB_COLS * k_subblks_per_superblk * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" // init acc to zero + "mv t4, %[BK] \n\t" + "li t0, 0x4c00 \n\t" // 16 in fp16 + "fmv.h.x fa0, t0 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + "li t5, 8 \n\t" + "addi t6, %[A], 288 \n\t" // point to blk scale + "flh ft1, (t6) \n\t" + "addi t6, %[A], 272 \n\t" // point to asum + + // init the acc fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v16, v18, v18 \n\t" + "vxor.vv v17, v18, v18 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v19, v18, v18 \n\t" + + "INNER_BLK_LOOP%=: \n\t" + // load a sum and scale + "flh fa1, (t6) \n\t" + "addi t6, t6, 2 \n\t" + "flh ft0, (%[A]) \n\t" + "addi %[A], %[A], 2 \n\t" + // load A + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (%[A]) \n\t" // 1x32@i8 + "addi %[A], %[A], 32 \n\t" + + // load scale B and B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v8, (%[B]) \n\t" // b_scale fp16 + "addi %[B], %[B], 64 \n\t" + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vfmul.vf v8, v8, ft0 \n\t" // scale b * scale a + "vfmul.vf v9, v8, fa0 \n\t" + "vfmul.vf v10, v8, fa1 \n\t" // scale b * scale a * asm + "vfwmacc.vf v31, ft1, v10 \n\t" // asum * scale a * scale b * blk scale + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v0, v8, v9, 3 \n\t" + "vsrl.vi v28, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v2, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v3, v28, v28, 3 \n\t" // hi4 of A + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v16, v3, v4, v0, 4, i4 \n\t" // high 4 + "vmadotsu.hp v17, v3, v5, v0, 5, i4 \n\t" + "vmadotsu.hp v18, v3, v6, v0, 6, i4 \n\t" + "vmadotsu.hp v19, v3, v7, v0, 7, i4 \n\t" + "vmadotu.hp v16, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v17, v2, v5, v0, 1, i4 \n\t" + "vmadotu.hp v18, v2, v6, v0, 2, i4 \n\t" + "vmadotu.hp v19, v2, v7, v0, 3, i4 \n\t" + + "addi t5, t5, -1 \n\t" + "bgtz t5, INNER_BLK_LOOP%= \n\t" + + "vpack.vv v8, v16, v17, 1 \n\t" + "vpack.vv v12, v18, v19, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "addi t4, t4, -1 \n\t" + "vfwmacc.vf v31, ft1, v20 \n\t" + //"vsetvli t0, x0, e32, m1 \n\t" + //"vfmul.vf v31, v31, ft1 \n\t" // blk scale + + // update A ptr + "addi %[A], t6, 2 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v31, (%[DST]) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "t5", "t6", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "ft0", "ft1"); + } + } else { + // TODO: support quant_b_zp for i8i4 hp kernel + GGML_ABORT("gemm_kernel_i8i4_hp_m1 with quant_b_zp is not supported yet"); + } +} + +void gemm_kernel_i8i4_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_data_stride = + k_blks * (sizeof(ggml_fp16_t) + 16 * sizeof(int8_t) + (quant_b_zp != NULL ? sizeof(int8_t) : 0)); + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; +#if 0 + asm volatile( + "li t1, 8 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + "vsetivli t0, 4, e16, mf2 \n\t" + "vle16.v v8, (%[A]) \n\t" // asum + "addi %[A], %[A], 8 \n\t" + "vwmul.vx v10, v8, t1 \n\t" // 8*asum + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v16, v3, v4, i4 \n\t" // high 4 + "vmadotsu v18, v3, v5, i4 \n\t" + "vmadotsu v20, v3, v6, i4 \n\t" + "vmadotsu v22, v3, v7, i4 \n\t" + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + "vmadotu v16, v2, v4, i4 \n\t" // low 4 + "vmadotu v18, v2, v5, i4 \n\t" + "vmadotu v20, v2, v6, i4 \n\t" + "vmadotu v22, v2, v7, i4 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vrgather.vi v0, v10, 0 \n\t" + "vrgather.vi v1, v10, 1 \n\t" + "vrgather.vi v2, v10, 2 \n\t" + "vrgather.vi v3, v10, 3 \n\t" + + "vadd.vv v16, v16, v0 \n\t" + "vadd.vv v17, v17, v1 \n\t" + "vadd.vv v18, v18, v2 \n\t" + "vadd.vv v19, v19, v3 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc*4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); +#else + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + "vsetivli t0, 4, e16, mf2 \n\t" + "vle16.v v8, (%[A]) \n\t" // asum + "addi %[A], %[A], 8 \n\t" + "vsll.vi v8, v8, 3 \n\t" // asum * 8 + "vfcvt.f.x.v v9, v8 \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vrgather.vi v10, v9, 0 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v16, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v17, v16, 4 \n\t" + "vnpack4.vv v12, v16, v17, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v16, v10, v10,0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v20, v16, v16,0 \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vpack.vv v18, v20, v20, 0 \n\t" + "vor.vv v20, v18, v18 \n\t" + "vor.vv v21, v18, v18 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vfwmul.vv v16, v20, v14 \n\t" + "vfwmul.vv v18, v21, v14 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); +#endif + } + } else { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "li t1, 8 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2\n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + // load zp + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (%[B]) \n\t" + "addi %[B], %[B], 32 \n\t" + "vwaddu.vx v10, v8, x0 \n\t" + + // load a sum + "lh s1, (%[A]) \n\t" + "lh s2, 2(%[A]) \n\t" + "lh s3, 4(%[A]) \n\t" + "lh s4, 6(%[A]) \n\t" + "addi %[A], %[A], 8 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v16, v3, v4, i4 \n\t" // high 4 + "vmadotsu v18, v3, v5, i4 \n\t" + "vmadotsu v20, v3, v6, i4 \n\t" + "vmadotsu v22, v3, v7, i4 \n\t" + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + "vmadotu v16, v2, v4, i4 \n\t" // low 4 + "vmadotu v18, v2, v5, i4 \n\t" + "vmadotu v20, v2, v6, i4 \n\t" + "vmadotu v22, v2, v7, i4 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v0, v10, s1 \n\t" + "vwmul.vx v2, v10, s2 \n\t" + "vwmul.vx v4, v10, s3 \n\t" + "vwmul.vx v6, v10, s4 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v0 \n\t" + "vadd.vv v17, v17, v2 \n\t" + "vadd.vv v18, v18, v4 \n\t" + "vadd.vv v19, v19, v6 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST]\n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3", "s1", "s2", "s3", "s4"); + } + } +} + +void gemm_kernel_i8i4_hp_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_SUBBLKS_PER_SUPERBLK = 8; + constexpr size_t K_SUBBLK_LEN = 32; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + GGML_ASSERT(count_m >= 4); + + // Contract: + // - computes a 4-row x 32-col tile per inner invocation + // - A is q8 HP packed in m4 layout, one logical K256 block at a time + // - B is q4 HP packed in N32 tiles, optionally with a separate zp area + // - tail-N is currently not handled here; the caller must provide full N32 tiles + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * K_SUBBLKS_PER_SUPERBLK + + (quant_b_zp ? NB_COLS * K_SUBBLKS_PER_SUPERBLK * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + const size_t a_nrow_block_stride = q8_hp_blk_size(blk_len, true, true) * 4; + const size_t a_subblk_stride = q8_hp_blk_size(K_SUBBLK_LEN, false, false) * 4; + + if (quant_b_zp != nullptr) { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + const size_t nb_real = std::min(NB_COLS, count_n - ni); + if (nb_real != NB_COLS) { + break; + } + + uint8_t * b_tile_base = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_block = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + // Data layout summary for the with-zp path. + // + // A: M4 x K256 q8 HP block + // - split into 8 x K32 subblocks + // - each K32 subblock is 136B: + // 8B = 4 x fp16 row scales + // 128B = 4 x int8[32] row payloads + // - trailer after 8 subblocks is 72B: + // 4 rows x fp16[8] a_sum values, indexed as [row][ksi] + // 4 rows x fp16 scale_avg tail + // + // B: N32 x K256 q4 HP block with explicit zp area + // - each K32 subblock is 576B: + // 64B = fp16 scale[32] + // 512B = packed q4 payload for 32 columns x 32 k-elements + // - zp is stored separately, not interleaved with the 576B payload block + // - one K256 superblock is laid out as: + // 8 x (scale + qs) blocks = 4608B + // 8 x zp[32] = 256B + // + // C: 4 rows x 32 fp32 outputs + // + // ASM pointer convention: + // - t6: current A K32 subblock base + // - t2: current A a_sum base for this ksi + // row1/row2/row3 are at +16/+32/+48 bytes + // - s5: current B (scale + qs) K32 subblock base + // - s6: current B zp[32] base for this ksi + // + // Loop progression: + // - per ksi: A += 136, a_sum += 2, B_data += 576, B_zp += 32 + // - per ki : skip the 72B A trailer and advance B to the next 4864B superblock + + const _Float16 hp_scale_16 = (_Float16) 16.0f; + const _Float16 hp_scale_1 = (_Float16) 1.0f; + const _Float16 hp_scale_0125 = (_Float16) 0.125f; + + // VPR grouping used below: + // - v4-v7 : B q4 payload for N32 split as 4 x N8 groups + // - v8/v10 : zp u8 / widened fp16 + // - v12 : B fp16 scale[32] + // - v14-v15 : packed (Bscale * Ascale) for rows [0,1] / [2,3] + // - v16-v19 : temporary per-row scaled B scales + // - v28-v31 : final fp32 accumulators for rows 0..3 + + asm volatile( + "mv t5, %[BK] \n\t" + "mv t6, %[A] \n\t" + "mv s5, %[B] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "li t4, 8 \n\t" + "li t1, 4608 \n\t" + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "add s6, s5, t1 \n\t" // 8 * 576B B(scale+qs), zp area starts here + + ".align 4 \n\t" + "_BLK_LPST%=: \n\t" + "flh fa1, 64(t2) \n\t" // a_scale_avg_row[0] + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v18, v30, v30 \n\t" + "vxor.vv v19, v31, v31 \n\t" + "vxor.vv v20, v30, v30 \n\t" + "vxor.vv v21, v31, v31 \n\t" + "_KsubBLK_LPST%=: \n\t" + // load first subblock scales for 4 rows + "flh fa0, 0(t6) \n\t" // ascale_fp16 + + // load B fp16 scales[32] + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (s5) \n\t" + + // load Bzp[32] for the current ksi from the dedicated zp area + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (s6) \n\t" + + "fmul.h fa2, fa0, %[HP16] \n\t" + "vfwcvt.f.xu.v v10, v8 \n\t" // uint8 -> fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfmul.vf v16, v12, fa0 \n\t" // row0: Bscale * Ascale + "vfmul.vf v17, v12, fa2 \n\t" + + // load a_sum[row][ksi] from the trailer; t2 points to row0[ksi] + "flh ft1, 0(t2) \n\t" + "flh ft2, 16(t2) \n\t" + "flh ft3, 32(t2) \n\t" + "flh ft4, 48(t2) \n\t" + + "fmul.h ft1, ft1, %[HP0125] \n\t" + "fmul.h ft2, ft2, %[HP0125] \n\t" + "fmul.h ft3, ft3, %[HP0125] \n\t" + "fmul.h ft4, ft4, %[HP0125] \n\t" + + // load A payload from current K32 subblock and B q4 payload from current 576B block + "addi t3, t6, 8 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (t3) \n\t" //A + "addi t3, s5, 64 \n\t" + "vl4r.v v4, (t3) \n\t" //B + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" + "vpack.vv v0, v17, v16, 3 \n\t" + "vupack.vv v2, v12, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> mf2 + "vfmul.vv v10, v10, v16 \n\t" // zp * ascale * bscale; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> m1 + "vfmul.vf v12, v10, ft1 \n\t" // zp(1:n)* abscale * asum_m0; fp16*fp16 + "vfmul.vf v13, v10, ft2 \n\t" // zp(1:n)* abscale * asum_m1; fp16*fp16 + "vfmul.vf v24, v10, ft3 \n\t" // zp(1:n)* abscale * asum_m2; fp16*fp16 + "vfmul.vf v25, v10, ft4 \n\t" // zp(1:n)* abscale * asum_m3; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v28, fa1, v12 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v29, fa1, v13 \n\t" + "vfwmacc.vf v30, fa1, v24 \n\t" + "vfwmacc.vf v31, fa1, v25 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v0, 0, i4 \n\t" //lo4;n0n7 + "vmadotsu.hp v19, v3, v5, v0, 1, i4 \n\t" //lo4;n8n15 + "vmadotsu.hp v20, v3, v6, v0, 2, i4 \n\t" //lo4;n16n23 + "vmadotsu.hp v21, v3, v7, v0, 3, i4 \n\t" //lo4;n24n31 + "vmadotu.hp v18, v2, v4, v0, 4, i4 \n\t" //hi4;n0n7 + "vmadotu.hp v19, v2, v5, v0, 5, i4 \n\t" //hi4;n8n15 + "vmadotu.hp v20, v2, v6, v0, 6, i4 \n\t" //hi4;n16n23 + "vmadotu.hp v21, v2, v7, v0, 7, i4 \n\t" //hi4;n24n31 + + "addi t4, t4, -1 \n\t" + "addi t6, t6, 8+128 \n\t" // next A K32 subblock + "addi t2, t2, 2 \n\t" // next ksi entry in each a_sum row + "addi s5, s5, 64+512 \n\t" // next B (scale + qs) K32 block + "addi s6, s6, 32 \n\t" // next zp[32] + "bgtz t4, _KsubBLK_LPST%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v8, v18, v19, 1 \n\t" // 128(16*8)->256(16*16) + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v26, v8, v12, 2 \n\t" // 256(16*16)->512(16*32) + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmacc.vf v28, fa1, v26 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v30, fa1, v27 \n\t" + + "li t4, 8 \n\t" + "addi t5, t5, -1 \n\t" + "addi t6, t6, 72 \n\t" // skip A trailer after 8 subblocks and scale_avg tail + "mv s5, s6 \n\t" // s6 already points to next B superblock base + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "add s6, s5, t1 \n\t" // 8 * 576B B(scale+qs), zp area starts here + "bgtz t5, _BLK_LPST%= \n\t" + + "_BLK_LPND%=: \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_block), [B] "+r"(b_tile_base) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks), [HP16] "f"(hp_scale_16), + [HP1] "f"(hp_scale_1), [HP0125] "f"(hp_scale_0125) + : "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s5", "s6", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v10", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "ft1", "ft2", "ft3", "ft4", + "memory"); + } + return; + } else { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + const size_t nb_real = std::min(NB_COLS, count_n - ni); + if (nb_real != NB_COLS) { + break; + } + + uint8_t * b_tile_base = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_block = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + // Data layout summary for the no-zp path. + // + // A layout is identical to the with-zp branch. + // + // B: N32 x K256 q4 HP block without explicit zp storage + // - each K32 subblock is still 576B: + // 64B = fp16 scale[32] + // 512B = packed q4 payload + // - zp is implicit and treated as a constant value 8 in the kernel + // - one K256 superblock therefore contains only: + // 8 x (scale + qs) blocks = 4608B + // + // C: 4 rows x 32 fp32 outputs + // + // ASM pointer convention: + // - t6: current A K32 subblock base + // - t2: current A a_sum base for this ksi + // - s5: current B (scale + qs) K32 subblock base + // + // Loop progression: + // - per ksi: A += 136, a_sum += 2, B_data += 576 + // - per ki : skip the 72B A trailer and advance B to the next 4608B superblock + + const _Float16 hp_scale_16 = (_Float16) 16.0f; + const _Float16 hp_scale_1 = (_Float16) 1.0f; + + // VPR grouping used below matches the with-zp path: + // - v4-v7 : B q4 payload for N32 split as 4 x N8 groups + // - v8/v10 : implicit zp lane / widened fp16 + // - v12 : B fp16 scale[32] + // - v14-v15 : packed (Bscale * Ascale) for rows [0,1] / [2,3] + // - v16-v19 : temporary per-row scaled B scales + // - v28-v31 : final fp32 accumulators for rows 0..3 + + asm volatile( + "mv t5, %[BK] \n\t" + "mv t6, %[A] \n\t" + "mv s5, %[B] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "li t4, 8 \n\t" + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + + ".align 4 \n\t" + "_BLK_LPST%=: \n\t" + "flh fa1, 64(t2) \n\t" // a_scale_avg_row[0] + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v18, v30, v30 \n\t" + "vxor.vv v19, v31, v31 \n\t" + "vxor.vv v20, v30, v30 \n\t" + "vxor.vv v21, v31, v31 \n\t" + "_KsubBLK_LPST%=: \n\t" + // load first subblock scales for 4 rows + "flh fa0, 0(t6) \n\t" // ascale_fp16 + + // load B fp16 scales[32] + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (s5) \n\t" + + "fmul.h fa2, fa0, %[HP16] \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfmul.vf v16, v12, fa0 \n\t" // row0: Bscale * Ascale + "vfmul.vf v17, v12, fa2 \n\t" + + // load a_sum[row][ksi] from the trailer; t2 points to row0[ksi] + "flh ft1, 0(t2) \n\t" + "flh ft2, 16(t2) \n\t" + "flh ft3, 32(t2) \n\t" + "flh ft4, 48(t2) \n\t" + + // load A payload from current K32 subblock and B q4 payload from current 576B block + "addi t3, t6, 8 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (t3) \n\t" //A + "addi t3, s5, 64 \n\t" + "vl4r.v v4, (t3) \n\t" //B + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" + "vpack.vv v0, v17, v16, 3 \n\t" + "vupack.vv v2, v12, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> m1 + "vfmul.vf v12, v16, ft1 \n\t" // zp(1:n)* abscale * asum_m0; fp16*fp16 + "vfmul.vf v13, v16, ft2 \n\t" // zp(1:n)* abscale * asum_m1; fp16*fp16 + "vfmul.vf v24, v16, ft3 \n\t" // zp(1:n)* abscale * asum_m2; fp16*fp16 + "vfmul.vf v25, v16, ft4 \n\t" // zp(1:n)* abscale * asum_m3; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v28, fa1, v12 \n\t" + "vfwmacc.vf v29, fa1, v13 \n\t" + "vfwmacc.vf v30, fa1, v24 \n\t" + "vfwmacc.vf v31, fa1, v25 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v0, 0, i4 \n\t" //lo4;n0n7 + "vmadotsu.hp v19, v3, v5, v0, 1, i4 \n\t" //lo4;n8n15 + "vmadotsu.hp v20, v3, v6, v0, 2, i4 \n\t" //lo4;n16n23 + "vmadotsu.hp v21, v3, v7, v0, 3, i4 \n\t" //lo4;n24n31 + "vmadotu.hp v18, v2, v4, v0, 4, i4 \n\t" //hi4;n0n7 + "vmadotu.hp v19, v2, v5, v0, 5, i4 \n\t" //hi4;n8n15 + "vmadotu.hp v20, v2, v6, v0, 6, i4 \n\t" //hi4;n16n23 + "vmadotu.hp v21, v2, v7, v0, 7, i4 \n\t" //hi4;n24n31 + + "addi t4, t4, -1 \n\t" + + "addi t6, t6, 8+128 \n\t" // next A K32 subblock + "addi t2, t2, 2 \n\t" // next ksi entry in each a_sum row + "addi s5, s5, 64+512 \n\t" // next B (scale + qs) K32 block + "bgtz t4, _KsubBLK_LPST%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" //N32in1register + "vpack.vv v8, v18, v19, 1 \n\t" // 128(16*8)->256(16*16) + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v26, v8, v12, 2 \n\t" // 256(16*16)->512(16*32) + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmacc.vf v28, fa1, v26 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v30, fa1, v27 \n\t" + + "li t4, 8 \n\t" + "addi t5, t5, -1 \n\t" + "addi t6, t6, 72 \n\t" // skip A trailer after 8 subblocks and scale_avg tail + // s5 already points to next B superblock base + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "bgtz t5, _BLK_LPST%= \n\t" + + "_BLK_LPND%=: \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_block), [B] "+r"(b_tile_base) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks), [HP16] "f"(hp_scale_16), [HP1] "f"(hp_scale_1) + : "t0", "t2", "t3", "t4", "t5", "t6", "s5", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v10", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "ft1", "ft2", "ft3", "ft4", "memory"); + } + return; + } +} + +void gemm_kernel_i8mxfp4_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_TILE = 32; + using blk_type = nrow_block_mxfp4; + + GGML_ASSERT(blk_len == K_TILE); + GGML_ASSERT(count_m == 1); + GGML_UNUSED(quant_b_zp); + + const size_t a_blk_stride = q8_blk_size(blk_len, true); + const size_t b_blk_stride = sizeof(blk_type); + const size_t b_tile_stride = k_blks * b_blk_stride; + + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // MXFP4 no-zp: per column per k-block stride = scale_e8m0(1B) + qs(16B) + qh(4B) = 21B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * (blk_len / 8) + // qh sign/high-bit mask: n×k_blks×4 + n * k_blks * blk_len / 2 + // qs packed 4-bit magnitudes: n×k_blks×16 + n * k_blks * sizeof(uint8_t); // scale: n×k_blks×1 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (q8 block with per-block scale and stored sum field): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (u8×32) + // s5 = pB qh (sign/high-bit mask, 128B) + // s6 = pB qs (packed 4-bit magnitudes, 512B) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales u8 (loaded as bytes; later widened) + // v0 = qh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) + // v8..v15 / v16..v23 = qs unpack/pack temporaries (build signed vmadot lanes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32 \n\t" // s5 = pBh (pB + 32B scale) + "addi s6, %[pB], 32+128 \n\t" // s6 = pBs (pB + 32 + 128 = pB+192) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 672B = 32(scl) + 512(Bs) + 128(Bh) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load qs (512B = 4 VRF) from s6, advance s6 by 672 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = qs N32K32 packed 4-bit magnitudes + "addi s6, s6, 128*4+128+32 \n\t" // s6 += 672 (512+128+32) + + // ---- load B scale (32B = 32×u8) from s4, advance s4 by 672 ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_u8 × 32 + "addi s4, s4, 32+128*4+128 \n\t" // s4 += 672 (32+512+128) + + // ---- load qh (128B = 1 VRF) from s5, advance s5 by 672 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = qh N32K32 sign/high-bit packed + "addi s5, s5, 128+32+128*4 \n\t" // s5 += 672 (128+32+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + // ---- Decode packed MXFP4 payload into a vmadot-friendly signed-lane layout ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vrsub.vi v16, v16, 0, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + // init the accumu to 0 + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- int8 dot products over the decoded MXFP4 lane groups ---- + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "lui t1, 0x00200 \n\t" + "vmv.v.x v30, t1 \n\t" + // b_scale e8m0 -> fp32 + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v2, x0 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v2, v28, x0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vmsle.vi v0, v2, 1 \n\t" + "vadd.vi v28, v2, -1 \n\t" + "vsll.vi v28, v28, 23 \n\t" + "vsll.vv v28, v30, v2, v0.t \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } +} + +void gemm_kernel_i8mxfp4_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_TILE = 32; + using blk_type = nrow_block_mxfp4; + + GGML_ASSERT(blk_len == K_TILE); + GGML_ASSERT(count_m == 4); + GGML_UNUSED(quant_b_zp); + + const size_t a_blk_stride = q8_blk_size(blk_len, true); + const size_t b_blk_stride = sizeof(blk_type); + const size_t b_tile_stride = k_blks * b_blk_stride; + + if (quant_b_zp == NULL) { + // MXFP4 block layout per K32/N32 tile: + // [scale_e8m0 x 32][qh sign/high-bit mask x 128B][qs packed 4-bit magnitudes x 512B] + // There is no explicit zp stream; qh is combined with qs to reconstruct signed MXFP4 values. + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> qh bitmask stream, t2 -> qs low-nibble stream. + "addi t1, %[B], 32 \n\t" + "addi t2, %[B], 160 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "addi %[B], %[B], 672 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t1) \n\t" + "vl4r.v v8, (t2) \n\t" + + // Decode the packed MXFP4 payload once for the whole tile and expand it + // into a vmadot-friendly layout. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vrsub.vi v16, v16, 0, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "lui t1, 0x00200 \n\t" + "vmv.v.x v30, t1 \n\t" + // b_scale e8m0 -> fp32 + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v2, x0 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v26, v28, x0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vmsle.vi v0, v26, 1 \n\t" + "vadd.vi v24, v26, -1 \n\t" + "vsll.vi v18, v24, 23 \n\t" + "vsll.vv v18, v30, v26, v0.t \n\t" + + // Row 0: dot(A0, decoded MXFP4 lane groups), accumulate in int32 and + // then apply A/B scaling. + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + "vfmul.vv v24, v24, v18 \n\t" + "vfmul.vv v25, v25, v18 \n\t" + "vfmul.vv v26, v26, v18 \n\t" + "vfmul.vv v27, v27, v18 \n\t" + "vfmacc.vf v4, fa0, v24 \n\t" + "vfmacc.vf v5, fa1, v25 \n\t" + "vfmacc.vf v6, fa2, v26 \n\t" + "vfmacc.vf v7, fa3, v27 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "v0", "v1", "v2", + "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "fa0", "fa1", "fa2", "fa3"); + } + } +} + +void gemm_kernel_i8i5_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // ========================================================================= + // i8i5: 8-bit activation × 5-bit weight (4-bit low + 1-bit high mask) + // + // B layout per N32K32 k-block (no-zp): + // [0 .. 63 ] : scale_fp16 × 32 (64B) + // [64 .. 191] : Bh i1-high-bit × 32N × 32K (128B = 1 VRF) + // [192.. 703] : Bs i4-low-nibble × 32N × 32K (512B = 4 VRF) + // Total: 704B per k-block stride + // + // B layout per N32K32 k-block (with-zp): + // [0 .. 63 ] : scale_fp16 × 32 (64B) + // [64 .. 95 ] : zp_uint8 × 32 (32B) + // [96 .. 223] : Bh i1-high-bit × 32N × 32K (128B = 1 VRF) + // [224.. 735] : Bs i4-low-nibble × 32N × 32K (512B = 4 VRF) + // Total: 736B per k-block stride + // + // Bh format per N8K32 sub-block (32B): + // K rows × N cols × 1bit packed as bytes (8 cols per byte, K groups of 4B) + // Byte k gives 8 mask bits for columns N7..N0 at k-th K-element. + // + // Computation: + // B5bit_signed = (Bs | (Bh << 4)) - zp + // dot(A, B5) = dot(A, Bs_u4) + 16*dot(A, Bh_u1) - zp*asum + // No-zp: implicit zp = 16 (unsigned [0..31] centered at 16) + // With-zp: explicit zp from data + // + // ========================================================================= + + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // i8i5 no-zp: per column per k-block stride = fp16(2B) + i4(16B) + i1(4B) = 22B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * (blk_len / 8) + // Bh i1 mask: n×k_blks×4 + n * k_blks * blk_len / 2 + // Bs i4 data: n×k_blks×16 + n * k_blks * sizeof(_Float16); // scale: n×k_blks×2 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (same as i8i4): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // t2 = A asum (int16) << 4 f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (fp16×32) + // s5 = pB Bh (i1 mask, 128B) + // s6 = pB Bs (i4 packed, 512B) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales fp16 (loaded as bytes; later widened) + // v0 = Bh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) + // v8..v15 / v16..v23 = Bs unpack/pack temporaries (build b5bit bytes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBh (pB + 64B scale) + "addi s6, %[pB], 32*2+128 \n\t" // s6 = pBs (pB + 64 + 128 = pB+192) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 704B = 64(scl) + 512(Bs) + 128(Bh) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load Bs (512B = 4 VRF) from s6, advance s6 by 704 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s6, s6, 128*4+128+64 \n\t" // s6 += 704 (512+128+64) + + // ---- load B scale (64B = 32×fp16) from s4, advance s4 by 704 ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_fp16 × 32 + "addi s4, s4, 64+128*4+128 \n\t" // s4 += 704 (64+512+128) + + // ---- load Bh (128B = 1 VRF) from s5, advance s5 by 704 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = Bh N32K32 1-bit packed + "addi s5, s5, 128+64+128*4 \n\t" // s5 += 704 (128+64+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "lh t2, 4(s2) \n\t" // t2 = A asum (int16) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "slli t2, t2, 4 \n\t" // a_sum * 16; + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + //// vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vadd.vx v24, v24, t2 \n\t" + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v2 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } else { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // i8i5 with-zp: per column per k-block stride = fp16(2B)+zp(1B)+i4(16B)+i1(4B)=23B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // Bs i4: n×k_blks×16 + n * k_blks * (blk_len / 8) + // Bh i1: n×k_blks×4 + n * k_blks * sizeof(uint8_t) + // zp: n×k_blks×1 + n * k_blks * sizeof(_Float16); // scale: n×k_blks×2 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (same as i8i4): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // t2 = A asum (int16) << 4 f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (fp16×32); 每个 k-block 先 +64 指向 zp,再 +672 到下一个 block + // s5 = pB Bh (i1 mask, 128B) (offset +96) + // s6 = pB Bs (i4 packed, 512B) (offset +224) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales fp16 (loaded as bytes; later widened) + // v0 = Bh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) / later reused to hold Bzp bytes + // v8..v15 / v16..v23 = Bs unpack/pack temporaries (build b5bit bytes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBh (pB + 64B scale + 32B zp = pB+96) + "addi s6, %[pB], 32*3+128 \n\t" // s6 = pBs (pB + 96 + 128 = pB+224) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 736B = 64(scale) + 32(zp) + 128(Bh) + 512(Bs) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load Bs (512B = 4 VRF) from s6, advance s6 by 736 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s6, s6, 128*4+128+96 \n\t" // s6 += 736 (512+128+96) + + // ---- load B scale (64B = 32×fp16) from s4; then s4 points to zp[32] ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_fp16 × 32 + "addi s4, s4, 64 \n\t" // s4 += 64 (now points to zp) + + // ---- load Bh (128B = 1 VRF) from s5, advance s5 by 736 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = Bh N32K32 1-bit packed + "addi s5, s5, 128+96+128*4 \n\t" // s5 += 736 (128+96+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "lh t2, 4(s2) \n\t" // t2 = A asum (int16) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (s4) \n\t" // Bzp + "addi s4, s4, 32+128*4+128 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vwaddu.vx v28, v1, x0 \n\t" // uint8 -> uint16 + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v30, v28, t2 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v30 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } +} + +void gemm_kernel_i8i5_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + + GGML_UNUSED(count_m); + GGML_UNUSED(blk_len); + + // This kernel computes a 4x32 output tile. For each K32 block we decode the + // packed Q5 weights once and reuse the decoded vectors across the 4 A rows. + constexpr size_t B_Q50_BLK_STRIDE = sizeof(nrow_block_q5_0); + constexpr size_t B_Q51_BLK_STRIDE = sizeof(nrow_block_q5_1); + + if (quant_b_zp) { + // Q5_1 block layout per K32/N32 tile: + // [scale_fp16 x 32][zp_u8 x 32][qh high-bit mask x 128B][qs low nibbles x 512B] + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q51_BLK_STRIDE; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales/sums for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "lh s1, 16(%[A]) \n\t" + "lh s2, 18(%[A]) \n\t" + "lh s3, 20(%[A]) \n\t" + "lh s4, 22(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> zp stream, t2 -> qh bitmask stream, s5 -> qs low-nibble stream. + "addi t1, %[B], 64 \n\t" + "addi t2, %[B], 96 \n\t" + "addi s5, %[B], 224 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t2) \n\t" + "vl4r.v v8, (s5) \n\t" + "addi %[B], %[B], 736 \n\t" + + // Decode Q5 payload once for the whole tile: + // 1) split `qs` low/high nibbles, + // 2) repack into bytes, + // 3) use the `qh` mask to inject bit4 (+16) where needed, + // 4) expand into the vmadot-friendly layout reused by all 4 rows. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "li t2, 16 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t2, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + // Convert per-column fp16 scales once; the same scale vector is shared by all 4 rows. + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v18, v2 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v3, (t1) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + + // Row 0: dot(A0, decoded_q5) + a_sum0 * zp, then scale by A/B scales. + // The widen/mul correction sequence intentionally matches the proven m1 Q5_1 path. + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vwaddu.vx v28, v3, x0 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v12, v28, s1 \n\t" + "vwmul.vx v14, v28, s2 \n\t" + "vwmul.vx v20, v28, s3 \n\t" + "vwmul.vx v22, v28, s4 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v12 \n\t" + "vadd.vv v25, v25, v14 \n\t" + "vadd.vv v26, v26, v20 \n\t" + "vadd.vv v27, v27, v22 \n\t" + "vfcvt.f.x.v v12, v24 \n\t" + "vfcvt.f.x.v v14, v25 \n\t" + "vfcvt.f.x.v v20, v26 \n\t" + "vfcvt.f.x.v v22, v27 \n\t" + "vfmul.vv v12, v12, v18 \n\t" + "vfmul.vv v14, v14, v18 \n\t" + "vfmul.vv v20, v20, v18 \n\t" + "vfmul.vv v22, v22, v18 \n\t" + "vfmacc.vf v4, fa0, v12 \n\t" + "vfmacc.vf v5, fa1, v14 \n\t" + "vfmacc.vf v6, fa2, v20 \n\t" + "vfmacc.vf v7, fa3, v22 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "s5", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "fa0", "fa1", "fa2", "fa3"); + } + } else { + // Q5_0 block layout per K32/N32 tile: + // [scale_fp16 x 32][qh high-bit mask x 128B][qs low nibbles x 512B] + // There is no explicit zp stream; the implicit midpoint correction is +16. + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q50_BLK_STRIDE; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales/sums for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "lh s1, 16(%[A]) \n\t" + "lh s2, 18(%[A]) \n\t" + "lh s3, 20(%[A]) \n\t" + "lh s4, 22(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> qh bitmask stream, t2 -> qs low-nibble stream. + "addi t1, %[B], 64 \n\t" + "addi t2, %[B], 192 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t1) \n\t" + "vl4r.v v8, (t2) \n\t" + "addi %[B], %[B], 704 \n\t" + + // Decode Q5 payload once for the whole tile and expand it into the vmadot layout. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "li t2, 16 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t2, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + // Convert per-column fp16 scales once; the same scale vector is shared by all 4 rows. + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v18, v2 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + + // Row 0: dot(A0, decoded_q5) + a_sum0 * 16 (implicit Q5_0 midpoint correction). + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "slli s1, s1, 4 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "slli s2, s2, 4 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "slli s3, s3, 4 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "slli s4, s4, 4 \n\t" + "vadd.vx v24, v24, s1 \n\t" + "vadd.vx v25, v25, s2 \n\t" + "vadd.vx v26, v26, s3 \n\t" + "vadd.vx v27, v27, s4 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + "vfmul.vv v24, v24, v18 \n\t" + "vfmul.vv v25, v25, v18 \n\t" + "vfmul.vv v26, v26, v18 \n\t" + "vfmul.vv v27, v27, v18 \n\t" + "vfmacc.vf v4, fa0, v24 \n\t" + "vfmacc.vf v5, fa1, v25 \n\t" + "vfmacc.vf v6, fa2, v26 \n\t" + "vfmacc.vf v7, fa3, v27 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "v0", "v1", "v2", + "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "fa0", "fa1", "fa2", "fa3"); + } + } +} + +void gemm_kernel_i8i8_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len + // b data + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 2048bit + // 4VRF, N32K32, 8192bit + // Bscale fp16 * N32 512bit; + // || scl*32..(fp16) | blk0 blk1 ... blk31 || scl*32..(fp16) | blk0 blk1 ... blk31 || ... + // || Element || Element || ... + + //bias always be nullptr + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*64 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4 \n\t" + "vl4r.v v8, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*8 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "addi s2, s2, 6+32 \n\t" // AScale + Asum(FP32+i16) + + "vsetvli t0, zero, e32, m1 \n\t" + "vupack.vv v24, v4, v5, 1 \n\t" + "vupack.vv v26, v6, v7, 1 \n\t" + "vupack.vv v28, v8, v9, 1 \n\t" + "vupack.vv v30, v10, v11, 1 \n\t" + + "vslidedown.vi v4, v3, 4 \n\t" + + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v24, i8 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadot v18, v3, v26, i8 \n\t" // M0 N8 - N15 + "vmadot v20, v3, v28, i8 \n\t" // M0 N16 - N23 + "vmadot v22, v3, v30, i8 \n\t" // M0 N24 - N31 + + "vmadot v16, v4, v25, i8 \n\t" + "vmadot v18, v4, v27, i8 \n\t" + "vmadot v20, v4, v29, i8 \n\t" + "vmadot v22, v4, v31, i8 \n\t" + + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); + } +} + +void gemm_kernel_i8i8_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_data_stride = k_blks * sizeof(ggml_fp16_t) + k_blks * blk_len; + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16+8 \n\t" // Ascl+Asum; FP32*4+i16*4 + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i8 + "addi %[B], %[B], 512 \n\t" + "vl4r.v v8, (%[B]) \n\t" // 32*32@i8 + "addi %[B], %[B], 512 \n\t" + + "vsetvli t0, zero, e32, m1 \n\t" + "vupack.vv v2, v0, v0, 1 \n\t" + + "vupack.vv v24, v4, v5, 1 \n\t" + "vupack.vv v26, v6, v7, 1 \n\t" + "vupack.vv v4, v8, v9, 1 \n\t" + "vupack.vv v6, v10, v11, 1 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot v16, v2, v24, i8 \n\t" + "vmadot v18, v2, v26, i8 \n\t" + "vmadot v20, v2, v4, i8 \n\t" + "vmadot v22, v2, v6, i8 \n\t" + "vmadot v16, v3, v25, i8 \n\t" + "vmadot v18, v3, v27, i8 \n\t" + "vmadot v20, v3, v5, i8 \n\t" + "vmadot v22, v3, v7, i8 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz %[BK], BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } +} + +void moe_m2_gemm_kernel_i8i4_impl(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { +#if 0 + moe_gemm_kernel_i8i4_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, + ldc); +#else + int64_t b_data_stride = + k_blks * (sizeof(ggml_fp16_t) + 16 * sizeof(int8_t) + (quant_b_zp != NULL ? sizeof(int8_t) : 0)); + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t3, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A0 + "flw fa0, (%[A0]) \n\t" // A0 scale + "lh t1, 4(%[A0]) \n\t" // A0 asum + "addi %[A0], %[A0], 6 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + // load scale A1 + "flw fa1, (%[A1]) \n\t" // A1 scale + "lh t2, 4(%[A1]) \n\t" // A1 asum + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.x v10, t1 \n\t" + "vmv.v.x v11, t2 \n\t" + + "vpack.vv v18, v10, v11, 1 \n\t" + "vsll.vi v18, v18, 3 \n\t" // mul 8 + "vfcvt.f.x.v v18, v18 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" // A0 data + "vle8.v v16, (%[A0]) \n\t" + "addi %[A0], %[A0], 32 \n\t" // 1*32@i8 + "vle8.v v20, (%[A1]) \n\t" + "addi %[A1], %[A1], 32 \n\t" // 1*32@i8 + + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + + "vsrl.vi v17, v16, 4 \n\t" + "vsrl.vi v21, v20, 4 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnpack4.vv v2, v16, v20, 2 \n\t" // low u4 + "vnpack4.vv v3, v17, v21, 2 \n\t" // high s4 + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vor.vv v19, v18, v18 \n\t" + "vor.vv v20, v18, v18 \n\t" + "vor.vv v21, v18, v18 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vfwmul.vv v16, v20, v14 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "addi t3, t3, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + + "bgtz t3, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v28, (%[DST0]) \n\t" + "vse32.v v29, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } + } else { +# if 0 + moe_gemm_kernel_i8i4_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +# else + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t3, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A0 + "flw fa0, (%[A0]) \n\t" // A0 scale + "lh t1, 4(%[A0]) \n\t" // A0 asum + "addi %[A0], %[A0], 6 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + // load scale A1 + "flw fa1, (%[A1]) \n\t" // A1 scale + "lh t2, 4(%[A1]) \n\t" // A1 asum + "addi %[A1], %[A1], 6 \n\t" + + // load zp + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (%[B]) \n\t" + "addi %[B], %[B], 32 \n\t" + "vwaddu.vx v10, v8, x0 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" // A0 data + "vle8.v v16, (%[A0]) \n\t" + "addi %[A0], %[A0], 32 \n\t" // 1*32@i8 + "vle8.v v20, (%[A1]) \n\t" + "addi %[A1], %[A1], 32 \n\t" // 1*32@i8 + + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + + "vsrl.vi v17, v16, 4 \n\t" + "vsrl.vi v21, v20, 4 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnpack4.vv v2, v16, v20, 2 \n\t" // low u4 + "vnpack4.vv v3, v17, v21, 2 \n\t" // high s4 + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v19, v19, v19 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v21, v21, v21 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + // asum*zp + "vsetvli t0, x0, e16, mf2 \n\t" + "vwmul.vx v2, v10, t1 \n\t" + "vwmul.vx v4, v10, t2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "vfcvt.f.x.v v2, v2 \n\t" + "vfcvt.f.x.v v4, v4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwcvt.f.f.v v16, v20 \n\t" + + "vfwcvt.f.f.v v18, v14 \n\t" + + // +asum*zp + "vsetvli t0, x0, e32, m1 \n\t" + "vfadd.vv v16, v16, v2 \n\t" + "vfadd.vv v17, v17, v4 \n\t" + "vfmul.vv v16, v16, v18 \n\t" + "vfmul.vv v17, v17, v18 \n\t" + + "addi t3, t3, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + + "bgtz t3, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v28, (%[DST0]) \n\t" + "vse32.v v29, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } +# endif + } +#endif +} + +void moe_m2_gemm_kernel_i8i5_impl(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t B_Q50_BLK_STRIDE = sizeof(nrow_block_q5_0); + constexpr size_t B_Q51_BLK_STRIDE = sizeof(nrow_block_q5_1); + + GGML_UNUSED(blk_len); + GGML_UNUSED(count_m); + GGML_UNUSED(ldc); + + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q50_BLK_STRIDE; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "mv t4, %[BK] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" + "vxor.vv v3, v0, v0 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // ---- load B scale/Bh/Bs and advance to the next q5_0 k-block ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (%[B]) \n\t" // v1 = scale_fp16 × 32 + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (%[B]) \n\t" // v0 = Bh N32K32 1-bit packed + "addi %[B], %[B], 128 \n\t" + "vl4r.v v8, (%[B]) \n\t" // v8..v11 = Bs N32K32 i4 + "addi %[B], %[B], 512 \n\t" + + // ---- load A0/A1 header then payload, each block stride = 38B ---- + "flw f0, (%[A0]) \n\t" // f0 = A0 scale (fp32) + "lh t2, 4(%[A0]) \n\t" // t2 = A0 asum (int16) + "addi %[A0], %[A0], 6 \n\t" + "flw f1, (%[A1]) \n\t" // f1 = A1 scale (fp32) + "lh t3, 4(%[A1]) \n\t" // t3 = A1 asum (int16) + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (%[A0]) \n\t" // v4 = A0 M1K32 int8 + "addi %[A0], %[A0], 32 \n\t" + "vle8.v v5, (%[A1]) \n\t" // v5 = A1 M1K32 int8 + "addi %[A1], %[A1], 32 \n\t" + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "slli t2, t2, 4 \n\t" // a_sum * 16; + "slli t3, t3, 4 \n\t" + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vpack.vv v6, v4, v5, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vupack.vv v4, v6, v7, 1 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v4, v8, i8 \n\t" // N0..7 + "vmadot v26, v4, v10, i8 \n\t" // N8..15 + "vmadot v28, v4, v12, i8 \n\t" // N16..23 + "vmadot v30, v4, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v5, v9, i8 \n\t" // N0..7 + "vmadot v26, v5, v11, i8 \n\t" // N8..15 + "vmadot v28, v5, v13, i8 \n\t" // N16..23 + "vmadot v30, v5, v15, i8 \n\t" // N24..31 + + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vadd.vx v24, v24, t2 \n\t" + "vadd.vx v25, v25, t3 \n\t" + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v1 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfcvt.f.x.v v27, v25 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfmul.vf v31, v28, f1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v30, v26 \n\t" + "vfmacc.vv v3, v31, v27 \n\t" + + "addi t4, t4, -1 \n\t" + "bgtz t4, BLK_LOOP%= \n\t" + + "vsetvli t0, %[NR], e32, m1 \n\t" + "vse32.v v2, (%[DST0]) \n\t" + "vse32.v v3, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks), [NR] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", + "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "f0", "f1"); + } + } else { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q51_BLK_STRIDE; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "mv t4, %[BK] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" + "vxor.vv v3, v0, v0 \n\t" + "addi t5, %[B], 64 \n\t" // t5 = zp (32B) + "addi t6, %[B], 96 \n\t" // t6 = qh (128B) + "addi s1, %[B], 224 \n\t" // s1 = qs (512B) + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // ---- load B scale/zp/Bh/Bs and advance to the next q5_1 k-block ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (%[B]) \n\t" // v1 = scale_fp16 × 32 + "addi %[B], %[B], 736 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t6) \n\t" // v0 = Bh N32K32 1-bit packed + "addi t6, t6, 736 \n\t" + "vl4r.v v8, (s1) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s1, s1, 736 \n\t" + + // ---- load A0/A1 header then payload, each block stride = 38B ---- + "flw f0, (%[A0]) \n\t" // f0 = A0 scale (fp32) + "lh t2, 4(%[A0]) \n\t" // t2 = A0 asum (int16) + "addi %[A0], %[A0], 6 \n\t" + "flw f1, (%[A1]) \n\t" // f1 = A1 scale (fp32) + "lh t3, 4(%[A1]) \n\t" // t3 = A1 asum (int16) + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (%[A0]) \n\t" // v4 = A0 M1K32 int8 + "addi %[A0], %[A0], 32 \n\t" + "vle8.v v5, (%[A1]) \n\t" // v5 = A1 M1K32 int8 + "addi %[A1], %[A1], 32 \n\t" + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // q5_1 uses explicit zp, so keep a_sum unshifted here. + // [4*32]*2 + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vpack.vv v6, v4, v5, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vupack.vv v4, v6, v7, 1 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v4, v8, i8 \n\t" // N0..7 + "vmadot v26, v4, v10, i8 \n\t" // N8..15 + "vmadot v28, v4, v12, i8 \n\t" // N16..23 + "vmadot v30, v4, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v5, v9, i8 \n\t" // N0..7 + "vmadot v26, v5, v11, i8 \n\t" // N8..15 + "vmadot v28, v5, v13, i8 \n\t" // N16..23 + "vmadot v30, v5, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (t5) \n\t" // v4 = Bzp N32 uint8 + "addi t5, t5, 736 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v4, x0 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwmul.vx v30, v28, t2 \n\t" + "vwmul.vx v31, v28, t3 \n\t" + + // b_scale fp16 -> fp32 + "vfwcvt.f.f.v v28, v1 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v30 \n\t" + "vadd.vv v25, v25, v31 \n\t" + + // a_scale * b_scale; + "vfcvt.f.x.v v26, v24 \n\t" + "vfcvt.f.x.v v27, v25 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfmul.vf v31, v28, f1 \n\t" + // static_cast(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v30, v26 \n\t" + "vfmacc.vv v3, v31, v27 \n\t" + + "addi t4, t4, -1 \n\t" + "bgtz t4, BLK_LOOP%= \n\t" + + "vsetvli t0, %[NR], e32, m1 \n\t" + "vse32.v v2, (%[DST0]) \n\t" + "vse32.v v3, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks), [NR] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "f0", "f1"); + } + } +} + +size_t gemm_kernel_i8i2k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i2k_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i2k_m4(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i2k_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, + ldc); +#else + gemm_kernel_i8i2k_m1(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i3k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i3k_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i3k_m4(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i3k_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i3k_m1(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i4_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i4_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i4_hp(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i4_hp_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_hp_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i4_hp_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_hp_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8i4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + moe_m2_gemm_kernel_i8i4_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); + return 2; +} + +size_t gemm_kernel_i8i8(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i8_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i8_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i8_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i8_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 1 + gemm_kernel_i8mxfp4_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8mxfp4_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 1 + gemm_kernel_i8mxfp4_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8mxfp4_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + //moe_m2_gemm_kernel_i8mxfp4_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); + return 2; +} + +size_t gemm_kernel_i8i5(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i5_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i5_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i5_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i5_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8i5(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { +#if 0 + moe_gemm_kernel_i8i5_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + moe_m2_gemm_kernel_i8i5_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 2; +} + +} // namespace ime2 +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/ime_env.cpp b/ggml/src/ggml-cpu/spacemit/ime_env.cpp new file mode 100644 index 00000000000..a13ba391da2 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime_env.cpp @@ -0,0 +1,320 @@ +#include "ime_env.h" + +#include "ggml-impl.h" +#include "spine_mem_pool.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { +bool spine_core_info::get_spine_core_info(std::vector & result) { + static std::unordered_map spine_march_mapping_ = { + {0x8000000058000001, spine_core_arch_id::core_arch_x60 }, + { 0x8000000041000001, spine_core_arch_id::core_arch_a60 }, + { 0x8000000058000002, spine_core_arch_id::core_arch_x100}, + { 0x8000000041000002, spine_core_arch_id::core_arch_a100}, + }; + + result.clear(); + std::ifstream file("/proc/cpuinfo"); + std::string line; + + std::vector> cpu_info_list; + + uint64_t current_processor = spine_invalid_core_id; + uint64_t current_marchid = 0; + bool has_processor = false; + bool has_marchid = false; + + if (!file.is_open()) { + return false; + } + + while (std::getline(file, line)) { + if (line.substr(0, 9) == "processor") { + if (has_processor && has_marchid) { + cpu_info_list.push_back({ current_processor, current_marchid }); + } + + size_t colon_pos = line.find(':'); + if (colon_pos != std::string::npos) { + current_processor = std::stoi(line.substr(colon_pos + 1)); + has_processor = true; + } + + has_marchid = false; + } else if (line.substr(0, 7) == "marchid") { + size_t colon_pos = line.find(':'); + if (colon_pos != std::string::npos) { + std::string marchid_str = line.substr(colon_pos + 1); + marchid_str.erase(std::remove_if(marchid_str.begin(), marchid_str.end(), isspace), marchid_str.end()); + current_marchid = std::stoull(marchid_str, nullptr, 16); + has_marchid = true; + } + } + } + + if (has_processor && has_marchid) { + cpu_info_list.push_back({ current_processor, current_marchid }); + } + + if (has_processor && has_marchid) { + for (auto & cpu_info : cpu_info_list) { + if (cpu_info[0] != spine_invalid_core_id && + spine_march_mapping_.find(cpu_info[1]) != spine_march_mapping_.end()) { + auto core_info = spine_core_info(); + core_info.core_id = cpu_info[0]; + core_info.arch_id = spine_core_arch_id(spine_march_mapping_[cpu_info[1]]); + + result.push_back(core_info); + } + } + } + + return has_processor && has_marchid; +} + +namespace { +uint16_t hex_string_to_u16(const std::string & hex_str) { + try { + size_t pos = 0; + if (hex_str.substr(0, 2) == "0x" || hex_str.substr(0, 2) == "0X") { + pos = 2; + } + unsigned long result = std::stoul(hex_str.substr(pos), nullptr, 16); + if (result > std::numeric_limits::max()) { + throw std::out_of_range("Converted value is out of range for uint16_t"); + } + return static_cast(result); + } catch (const std::invalid_argument & e) { + throw std::invalid_argument("Invalid hexadecimal string"); + } catch (const std::out_of_range & e) { + throw; + } +} + +const char * spine_mem_pool_backend_to_string(spine_mem_pool_backend backend) { + switch (backend) { + case spine_mem_pool_backend::none: + return "NONE"; + case spine_mem_pool_backend::posix_memalign: + return "POSIX"; + case spine_mem_pool_backend::transparent_hugepage: + return "HPAGE"; + case spine_mem_pool_backend::hugetlb_1g: + return "HPAGE1GB"; + } + + return "unknown"; +} + +spine_mem_pool_backend parse_mem_backend(const char * mem_backend_str) { + if (mem_backend_str == nullptr || mem_backend_str[0] == '\0') { + return spine_mem_pool_backend::transparent_hugepage; + } + + std::string value(mem_backend_str); + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char ch) { return static_cast(std::tolower(ch)); }); + + if (value == "none") { + return spine_mem_pool_backend::none; + } + + if (value == "posix") { + return spine_mem_pool_backend::posix_memalign; + } + + if (value == "hpage") { + return spine_mem_pool_backend::transparent_hugepage; + } + + if (value == "hpage1gb") { + return spine_mem_pool_backend::hugetlb_1g; + } + + throw std::runtime_error("invalid SPACEMIT_MEM_BACKEND: " + value + ", expected NONE, POSIX, HPAGE or HPAGE1GB"); +} +} // namespace + +spine_env_info::spine_env_info() { + num_cores = static_cast(std::thread::hardware_concurrency()); + spine_core_info::get_spine_core_info(core_info_list); + + // special for x60 K1 + if (core_info_list.size() == 8 && core_info_list[0].arch_id == spine_core_arch_id::core_arch_x60) { + for (int i = 0; i < 4; i++) { + core_info_list[i].arch_id = spine_core_arch_id::core_arch_a60; + } + } + + // special for qemu + if (core_info_list.size() == 0) { + char * spine_core_arch_str = getenv("SPACEMIT_CORE_ARCH"); + if (spine_core_arch_str != nullptr) { + auto arch_id = hex_string_to_u16(spine_core_arch_str); + for (int i = 0; i < num_cores; i++) { + auto core_info = spine_core_info(); + core_info.core_id = i; + core_info.arch_id = spine_core_arch_id{ arch_id }; + core_info_list.push_back(core_info); + } + } + } + + if (core_info_list.size() == 0) { + throw std::runtime_error( + "Failed to get SPACEMIT_CORE_ARCH from environment or failed to parse it from /proc/cpuinfo"); + } + + char * spine_perfer_core_arch_str = getenv("SPACEMIT_PERFER_CORE_ARCH"); + if (spine_perfer_core_arch_str != nullptr && spine_perfer_core_arch_str != "") { + perfer_core_arch_id = spine_core_arch_id{ hex_string_to_u16(spine_perfer_core_arch_str) }; + } + + char * spine_perfer_core_id_str = getenv("SPACEMIT_PERFER_CORE_ID"); + std::vector perfer_core_id_vec; + if (spine_perfer_core_id_str != nullptr && spine_perfer_core_id_str != "") { + std::string perfer_core_id_str(spine_perfer_core_id_str); + size_t start = 0; + size_t end = 0; + while ((end = perfer_core_id_str.find(',', start)) != std::string::npos) { + std::string core_id_substr = perfer_core_id_str.substr(start, end - start); + perfer_core_id_vec.push_back(std::stoi(core_id_substr)); + start = end + 1; + } + std::string core_id_substr = perfer_core_id_str.substr(start); + perfer_core_id_vec.push_back(std::stoi(core_id_substr)); + } + + perfer_core_ids.reserve(num_cores); + if (perfer_core_arch_id == spine_core_arch_id::core_arch_none) { + for (auto & core_info : core_info_list) { + auto core_arch_id = core_info.arch_id; + auto core_arch_head = (uint16_t) (core_arch_id) >> 12; + if (core_arch_head == 0xA) { + num_perfer_cores++; + perfer_core_arch_id = core_arch_id; + cpu_mask |= (1ULL << core_info.core_id); + perfer_core_ids.push_back(core_info.core_id); + } + } + } else { + for (auto & core_info : core_info_list) { + auto core_arch_id = core_info.arch_id; + if (core_arch_id == perfer_core_arch_id) { + num_perfer_cores++; + cpu_mask |= (1ULL << core_info.core_id); + + auto core_arch_head = (uint16_t) (core_arch_id) >> 12; + if (core_arch_head == 0xA) { + perfer_core_ids.push_back(core_info.core_id); + } + } + } + if (num_perfer_cores == 0) { + GGML_ABORT("can not find core with arch id %x for SPACEMIT_PERFER_CORE_ARCH in core info list\n", + (uint16_t) perfer_core_arch_id); + } + } + + if (perfer_core_id_vec.size() > 0) { + perfer_core_ids.clear(); + cpu_mask = 0; + num_perfer_cores = 0; + for (int core_id : perfer_core_id_vec) { + if (core_id < 0 || core_id >= num_cores) { + GGML_ABORT("invalid core id in SPACEMIT_PERFER_CORE_ID: %d, should be between 0 and %d\n", core_id, + num_cores - 1); + } + auto core_info = core_info_list[core_id]; + auto core_arch_id = core_info.arch_id; + if (core_arch_id == perfer_core_arch_id) { + cpu_mask |= (1ULL << core_id); + perfer_core_ids.push_back(core_id); + } else { + GGML_ABORT( + "core id %d in SPACEMIT_PERFER_CORE_ID has arch id %x which does not match " + "SPACEMIT_PERFER_CORE_ARCH %x\n", + core_id, (uint16_t) core_arch_id, (uint16_t) perfer_core_arch_id); + } + } + std::string perfer_core_id_vec_str; + for (int core_id : perfer_core_id_vec) { + perfer_core_id_vec_str += std::to_string(core_id) + ","; + } + perfer_core_id_vec_str.pop_back(); + GGML_LOG_DEBUG("SPACEMIT_PERFER_CORE_ID is set, perferred core ids: %s\n", perfer_core_id_vec_str.c_str()); + num_perfer_cores = static_cast(perfer_core_id_vec.size()); + } + + use_ime1 = perfer_core_arch_id == spine_core_arch_id::core_arch_a60 || + perfer_core_arch_id == spine_core_arch_id::core_arch_x100; + + use_ime2 = perfer_core_arch_id == spine_core_arch_id::core_arch_a100; + + mem_backend = parse_mem_backend(getenv("SPACEMIT_MEM_BACKEND")); + char * spine_disable_tcm_str = getenv("SPACEMIT_DISABLE_TCM"); + auto user_disable_tcm = spine_disable_tcm_str != nullptr && strcmp(spine_disable_tcm_str, "0") != 0; + + if (!user_disable_tcm) { + spine_mem_pool_tcm_info tcm_info; + if (spine_mem_pool_tcm_init(&tcm_info)) { + use_tcm = tcm_info.available; + tcm_blk_size = tcm_info.blk_size; + GGML_LOG_DEBUG("CPU_RISCV64_SPACEMIT: tcm is available, blk_size: %zu, blk_num: %zu, is_fake_tcm: %d\n", + tcm_info.blk_size, tcm_info.blk_num, tcm_info.is_fake_tcm); + + for (auto & core_info : core_info_list) { + auto core_arch_head = (uint16_t) (core_info.arch_id) >> 12; + if (core_arch_head != 0xA) { + aicpu_id_offset++; + } else { + break; + } + } + } + } + + GGML_LOG_DEBUG( + "CPU_RISCV64_SPACEMIT: num_cores: %d, num_perfer_cores: %d, perfer_core_arch_id: %x, exclude_main_thread: %d, " + "use_ime1: %d, use_ime2: %d, mem_backend: %s, cpu_mask: %lx, aicpu_id_offset: %d\n", + num_cores, num_perfer_cores, (uint16_t) perfer_core_arch_id, exclude_main_thread, use_ime1, use_ime2, + spine_mem_pool_backend_to_string(mem_backend), cpu_mask, aicpu_id_offset); + + const size_t init_barrier_size = sizeof(spine_barrier_t) * spine_init_barrier_count; + init_barrier = + static_cast(spine_mem_pool_shared_mem_alloc(init_barrier_size, alignof(spine_barrier_t))); + if (init_barrier != nullptr) { + init_barrier_is_shared_mem = true; + } else { + GGML_LOG_WARN("CPU_RISCV64_SPACEMIT: failed to allocate init_barrier from shared mem, falling back to heap\n", + __func__); + init_barrier = new spine_barrier_t[spine_init_barrier_count]; + } + + spine_barrier_init(init_barrier, spine_init_barrier_count, 2); +} + +spine_env_info::~spine_env_info() { + if (init_barrier_is_shared_mem) { + spine_mem_pool_shared_mem_free(init_barrier); + } else { + delete[] init_barrier; + } + + init_barrier = nullptr; + init_barrier_is_shared_mem = false; +} + +spine_env_info global_spine_env_info; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/ime_env.h b/ggml/src/ggml-cpu/spacemit/ime_env.h new file mode 100644 index 00000000000..a6ca06d26a4 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime_env.h @@ -0,0 +1,55 @@ +#pragma once + +#include "spine_barrier.h" +#include "spine_mem_pool.h" + +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +constexpr uint64_t spine_invalid_core_id = 0xFFFFFFFF; +constexpr size_t spine_init_barrier_count = 16; + +enum class spine_core_arch_id : uint16_t { + core_arch_none = 0, + core_arch_x60 = 0x503C, + core_arch_x100 = 0x5064, + core_arch_x200 = 0x50C8, + core_arch_a60 = 0xA03C, + core_arch_a100 = 0xA064, + core_arch_a200 = 0xA0C8, +}; + +struct spine_core_info { + uint64_t core_id{ spine_invalid_core_id }; + spine_core_arch_id arch_id{ spine_core_arch_id::core_arch_none }; + + static bool get_spine_core_info(std::vector & result); +}; + +struct spine_env_info { + std::vector core_info_list; + std::vector perfer_core_ids; + int aicpu_id_offset{ 0 }; + int num_cores{ 0 }; + int num_perfer_cores{ 0 }; + spine_core_arch_id perfer_core_arch_id{ spine_core_arch_id::core_arch_none }; + bool exclude_main_thread{ false }; + bool use_ime2{ false }; + bool use_ime1{ false }; + bool use_tcm{ false }; + spine_mem_pool_backend mem_backend{ spine_mem_pool_backend::transparent_hugepage }; + uint64_t tcm_blk_size{ 0 }; + uint64_t cpu_mask{ 0 }; + spine_barrier_t * init_barrier{ nullptr }; + bool init_barrier_is_shared_mem{ false }; + + spine_env_info(); + ~spine_env_info(); +}; + +extern spine_env_info global_spine_env_info; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ime_kernels.h index 75706341505..0a1fafffb25 100644 --- a/ggml/src/ggml-cpu/spacemit/ime_kernels.h +++ b/ggml/src/ggml-cpu/spacemit/ime_kernels.h @@ -1,26 +1,189 @@ #pragma once +#include #include +#include + +namespace spacemit_kernels { + +#define BLOCK_QNK_LEN 256 + +template struct nrow_block_q2_k { + // [4bit scale + 4bit zp] * N * 16 + uint8_t scales[N * BLOCK_QNK_LEN / 16]; + // [b0, b16, b32, b48] [b1, b17, b33, b49] ... [b15, b31, b47, b63] + // [b64, b80, b96, b112] ...[b79, b95, b111, b127] + // [b128, b144, b160, b176] ...[b143, b159, b175, b191] + // [b192, b208, b224, b240] ...[b207, b223, b239, b255] + uint8_t qs[N * BLOCK_QNK_LEN / 4]; + uint16_t scales16[N]; + uint16_t zeros16[N]; +}; + +template struct nrow_block_q3_k { + // [8bit scale] * N * 16 + int8_t scales[N * 16]; + // [b0, b1, b2, b3, b4, b5, b6, b7] ... [b248, b249, b250, b251, b252, b253, b254, b255] + uint8_t hmask[N * BLOCK_QNK_LEN / 8]; + // [b0, b16, b32, b48] [b1, b17, b33, b49] ... [b15, b31, b47, b63] + // [b64, b80, b96, b112] ...[b79, b95, b111, b127] + // [b128, b144, b160, b176] ...[b143, b159, b175, b191] + // [b192, b208, b224, b240] ...[b207, b223, b239, b255] + uint8_t qs[N * BLOCK_QNK_LEN / 4]; + uint16_t scales16[N]; +}; + +template struct nrow_block_mxfp4 { + uint8_t e[N]; + uint8_t qh[4 * N]; + uint8_t qs[16 * N]; +}; + +template struct __attribute__((packed)) nrow_block_q5_1 { + uint16_t scales16[N]; + uint8_t zp[N]; + // n0 [bh0, bh1, bh2, bh3, bh4, bh5, bh6, bh7] .... + uint8_t qh[4 * N]; + // n0 [b0, b1], [b2, b3] .... [b30, b31] + // n1 [b0, b1], [b2, b3] .... [b30, b31] + uint8_t qs[16 * N]; +}; + +static_assert(sizeof(nrow_block_q5_1<1>) == sizeof(uint8_t) + 22, "wrong nrow_block_q5_1 block size/padding"); + +template struct __attribute__((packed)) nrow_block_q5_0 { + uint16_t scales16[N]; + // n0 [bh0, bh1, bh2, bh3, bh4, bh5, bh6, bh7] .... + uint8_t qh[4 * N]; + // n0 [b0, b1], [b2, b3] .... [b30, b31] + // n1 [b0, b1], [b2, b3] .... [b30, b31] + uint8_t qs[16 * N]; +}; + +static_assert(sizeof(nrow_block_q5_0<1>) == 22, "wrong nrow_block_q5_0 block size/padding"); + +using gemm_kernel_quantize_def = std::function< + size_t(size_t, const uint8_t *, const uint8_t *, const uint8_t *, float *, size_t, size_t, size_t, size_t)>; + +using moe_gemm_kernel_quantize_def = std::function< + size_t(size_t, const uint8_t **, const uint8_t *, const uint8_t *, float **, size_t, size_t, size_t, size_t)>; -namespace sqnbitgemm_spacemit_ime { namespace ime1 { -size_t gemm_kernel_i8i4(size_t blk_len, - const std::byte * quant_a_ptr, - const std::byte * quant_b_data, - const float * quant_b_scale, - const std::byte * quant_b_zp, - float * c_ptr, - size_t count_m, - size_t count_n, - size_t count_k, - size_t block_count_k, - size_t ldc, - const float * bias, - const size_t scale_stride); - -void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); - -void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); } // namespace ime1 -} // namespace sqnbitgemm_spacemit_ime + +namespace ime2 { +size_t gemm_kernel_i8i2k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i3k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i4_hp(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8i4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i8(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i5(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8i5(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); +} // namespace ime2 +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/repack.cpp b/ggml/src/ggml-cpu/spacemit/repack.cpp new file mode 100644 index 00000000000..3c879c4b7a0 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/repack.cpp @@ -0,0 +1,1795 @@ +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP + +#include "repack.h" + +#include "ggml-common.h" +#include "ggml-cpu.h" +#include "ggml-impl.h" +#include "ime_kernels.h" + +#include +#include +#include +#include + +// clang-format off +#if defined(__riscv) + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +#error "riscv v extension or v_intrinsic not enabled" +#else +#include +#endif + +#if !defined(__riscv_zfh) +#error "riscv zfh extension not enabled" +#endif + +#else +#error "riscv not enabled in this build" +#endif + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +// clang-format on + +template constexpr int QK_0() { + if constexpr (K == 4) { + return QK4_0; + } + if constexpr (K == 8) { + return QK8_0; + } + return -1; +} + +template struct block { + ggml_half d[N]; // deltas for N qK_0 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks +}; + +template struct block_with_zp { + ggml_half d[N]; // deltas for N qK_1 blocks + uint8_t zp[N]; // zero points for N qK_1 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks +}; + +// control size +static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); +static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), + "wrong block_with_zp<4,16> size/padding"); + +static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + +static_assert(sizeof(block<4, 32>) == 32 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<4,32> size/padding"); +static_assert(sizeof(block_with_zp<4, 32>) == 32 * sizeof(ggml_half) + QK4_0 * 16 + 32 * sizeof(uint8_t), + "wrong block_with_zp<4,32> size/padding"); + +using block_q4_0x16 = block<4, 16>; +using block_q4_1x16 = block_with_zp<4, 16>; +using block_q8_0x16 = block<8, 16>; + +using block_q4_0x32 = block<4, 32>; +using block_q4_1x32 = block_with_zp<4, 32>; +using block_q8_0x32 = block<8, 32>; + +struct block_q4_0x32x256 { + block_q4_0x32 blocks[8]; // [f16 * 32 | i4 * 32 * 32] * 8 +}; + +struct block_q4_1x32x256 { + block_q4_0x32 blocks[8]; + uint8_t zps[32 * 8]; +}; + +static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x16 out; + GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); + } + } + + return out; +} + +static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x16 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); + } + } + + return out; +} + +static int repack_q4_0_to_q4_0_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_0x16 * dst = (block_q4_0x16 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static inline void get_scale_min_k4(int j, + const uint8_t * GGML_RESTRICT q, + uint8_t * GGML_RESTRICT d, + uint8_t * GGML_RESTRICT m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +static int repack_q4_k_to_q4_1_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 16); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static block_q4_0x32 make_block_q4_0x32(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x32 out; + assert(QK4_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 32; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b1] ......... [b14 b15] + out.qs[i * QK4_0 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < 32; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b17] ......... [b30 b31] + out.qs[i * QK4_0 / 2 + QK4_0 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + return out; +} + +static block_q4_1x32 make_block_q4_1x32(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x32 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 32; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b1] ......... [b14 b15] + out.qs[i * QK4_1 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < 32; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[i * QK4_1 / 2 + QK4_1 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + return out; +} + +static block_q8_0x32 make_block_q8_0x32(block_q8_0 * in, unsigned int blck_size_interleave) { + block_q8_0x32 out; + GGML_ASSERT(QK8_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 32; i++) { + memcpy(out.qs + i * QK8_0, in[i].qs, QK8_0); + } + + return out; +} + +static int repack_q2_k_to_q2_k_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q2_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K == 256); + + constexpr int nrows_interleaved = 32; + + const block_q2_K * src = (const block_q2_K *) data; + + auto * dst = (spacemit_kernels::nrow_block_q2_k<32> *) t->data; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + uint8_t qs_aux[256] = { 0 }; + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q2_K * src_block = &src[(b + i) * nblocks + x]; + + // scale for [16, N] + for (int j = 0; j < 16; j++) { + auto zp_aux = (dst->scales[j * nrows_interleaved + i]) & 0xF0; + + dst->scales[j * nrows_interleaved + i] = (src_block->scales[j] & 0x0F) | zp_aux; + } + + // zp for [N, 16] + for (int j = 0; j < 16; j++) { + auto scale_aux = (dst->scales[16 * i + j]) & 0x0F; + + dst->scales[16 * i + j] = (src_block->scales[j] & 0xF0) | scale_aux; + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j] = (src_block->qs[j] >> (2 * k)) & 0x03; + } + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j + 128] = (src_block->qs[j + 32] >> (2 * k)) & 0x03; + } + } + + // from nrows_interleaved * [2 * 32byte] + // to 4 * [nrows_interleaved * 16byte] + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 16; j++) { + uint8_t qs0 = qs_aux[j + k * 64]; + uint8_t qs16 = qs_aux[j + 16 + k * 64]; + uint8_t qs32 = qs_aux[j + 32 + k * 64]; + uint8_t qs48 = qs_aux[j + 48 + k * 64]; + + dst->qs[(k * nrows_interleaved + i) * 16 + j] = + (qs0 & 0x03) | ((qs16 & 0x03) << 2) | ((qs32 & 0x03) << 4) | ((qs48 & 0x03) << 6); + } + } + + dst->scales16[i] = src_block->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + dst->zeros16[i] = src_block->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + dst++; + } + } + + return 0; +} + +static int repack_q3_k_to_q3_k_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q3_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K == 256); + + constexpr int nrows_interleaved = 32; + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * src = (const block_q3_K *) data; + + auto * dst = (spacemit_kernels::nrow_block_q3_k<32> *) t->data; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q3_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + uint32_t b_scale_aux[4] = { 0 }; + uint8_t qs_aux[256] = { 0 }; + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q3_K * src_block = &src[(b + i) * nblocks + x]; + + uint32_t * auxs = b_scale_aux; + int8_t * scale = (int8_t *) auxs; + memcpy(auxs, src_block->scales, 12); + + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + for (int j = 0; j < 16; j++) { + dst->scales[j * nrows_interleaved + i] = scale[j] - 32; + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j] = (src_block->qs[j] >> (2 * k)) & 0x03; + } + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j + 128] = (src_block->qs[j + 32] >> (2 * k)) & 0x03; + } + } + + // from nrows_interleaved * [2 * 32byte] + // to 4 * [nrows_interleaved * 16byte] + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 16; j++) { + uint8_t qs0 = qs_aux[j + k * 64]; + uint8_t qs16 = qs_aux[j + 16 + k * 64]; + uint8_t qs32 = qs_aux[j + 32 + k * 64]; + uint8_t qs48 = qs_aux[j + 48 + k * 64]; + + dst->qs[(k * nrows_interleaved + i) * 16 + j] = + (qs0 & 0x03) | ((qs16 & 0x03) << 2) | ((qs32 & 0x03) << 4) | ((qs48 & 0x03) << 6); + } + } + + //memcpy(dst->hmask + i * 32, src_block->hmask, 32); + + // from nrows_interleaved * [32byte] + // to 16 * [nrows_interleaved * uint16_t] + uint16_t * dst_mask = ((uint16_t *) dst->hmask) + i; + for (int j = 0; j < 16; j++, dst_mask += nrows_interleaved) { + uint8_t b_shift = j / 2; + uint8_t * b_mask_col = (uint8_t *) (src_block->hmask + (j % 2) * 16); + // b0 - b15 + uint16_t msk_out_0 = 0; + + for (int k = 0; k < 8; k++) { + msk_out_0 |= (uint16_t) ((b_mask_col[k] >> b_shift) & 0x01) << k; + } + for (int k = 8; k < 16; k++) { + msk_out_0 |= (uint16_t) ((b_mask_col[k] >> b_shift) & 0x01) << k; + } + + dst_mask[0] = msk_out_0; + } + + dst->scales16[i] = src_block->d; + } + + dst++; + } + } + + return 0; +} + +static int repack_q4_0_to_q4_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_0x32 * dst = (block_q4_0x32 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_0_256_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_0x32x256 * dst = (block_q4_0x32x256 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + GGML_ASSERT(nblocks % 8 == 0); // for 256-block interleaving + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x += 8) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + j + i * nblocks]; + } + dst->blocks[j] = make_block_q4_0x32(dst_tmp, interleave_block); + } + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_1_256_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_1x32x256 * dst = (block_q4_1x32x256 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + GGML_ASSERT(nblocks % 8 == 0); // for 256-block interleaving + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x += 8) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + j + i * nblocks]; + } + + block_q4_0x32 * dst_block = &dst->blocks[j]; + uint8_t * dst_zp = dst->zps + j * nrows_interleaved; + + for (int i = 0; i < nrows_interleaved; i++) { + float d = GGML_FP16_TO_FP32(dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + + dst_block->d[i] = GGML_FP32_TO_FP16(d); + dst_zp[i] = static_cast(mid); + } + + for (int i = 0; i < nrows_interleaved; i++) { + for (int k = 0; k < QK4_1 / 4; k++) { + dst_block->qs[i * QK4_1 / 2 + k] = + (dst_tmp[i].qs[k * 2] & 0x0F) | ((dst_tmp[i].qs[k * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < nrows_interleaved; i++) { + for (int k = 0; k < QK4_1 / 4; k++) { + dst_block->qs[i * QK4_1 / 2 + QK4_1 / 4 + k] = + ((dst_tmp[i].qs[k * 2] & 0xF0) >> 4) | (dst_tmp[i].qs[k * 2 + 1] & 0xF0); + } + } + } + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q4_0_to_q4_0_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes nibble repack. +static int repack_q4_0_to_q4_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + constexpr int qs_bytes = QK4_0 / 2; // 16 + + block_q4_0x32 * dst = (block_q4_0x32 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q4_0); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q4_0 * col_src = src + x; + + // --- 1) Gather 32 scale values (ggml_half d) with stride load --- + // d is at offset 0 of each block_q4_0, stride between rows = row_stride + { + const uint8_t * d_base = (const uint8_t *) &col_src->d; + ggml_half * d_dst = dst->d; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + vuint16m1_t vd = + __riscv_vlse16_v_u16m1((const uint16_t *) (d_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd, vl); + offset += vl; + remaining -= vl; + } + } + + // --- 2) Nibble repack qs for each of the 32 rows --- + // For each row i: + // src qs[16]: [b0|b16] [b1|b17] ... [b15|b31] (lo nibble = b_j, hi nibble = b_{j+16}) + // dst qs low 8B: (qs[2j] & 0x0F) | ((qs[2j+1] & 0x0F) << 4) for j=0..7 + // dst qs high 8B: ((qs[2j] >> 4)) | (qs[2j+1] & 0xF0) for j=0..7 + { + const size_t vl8 = __riscv_vsetvl_e8m1(8); + for (int i = 0; i < 32; i++) { + const uint8_t * sq = col_src[i * nblocks].qs; + uint8_t * dq = dst->qs + i * qs_bytes; + + // stride-2 load to separate even/odd bytes + vuint8m1_t v_even = __riscv_vlse8_v_u8m1(sq, 2, vl8); // qs[0], qs[2], ..., qs[14] + vuint8m1_t v_odd = __riscv_vlse8_v_u8m1(sq + 1, 2, vl8); // qs[1], qs[3], ..., qs[15] + + // low nibble part: (even & 0x0F) | ((odd & 0x0F) << 4) + vuint8m1_t v_even_lo = __riscv_vand_vx_u8m1(v_even, 0x0F, vl8); + vuint8m1_t v_odd_lo = __riscv_vand_vx_u8m1(v_odd, 0x0F, vl8); + vuint8m1_t v_lo = __riscv_vor_vv_u8m1(v_even_lo, __riscv_vsll_vx_u8m1(v_odd_lo, 4, vl8), vl8); + + // high nibble part: (even >> 4) | (odd & 0xF0) + vuint8m1_t v_even_hi = __riscv_vsrl_vx_u8m1(v_even, 4, vl8); + vuint8m1_t v_odd_hi = __riscv_vand_vx_u8m1(v_odd, 0xF0, vl8); + vuint8m1_t v_hi = __riscv_vor_vv_u8m1(v_even_hi, v_odd_hi, vl8); + + __riscv_vse8_v_u8m1(dq, v_lo, vl8); + __riscv_vse8_v_u8m1(dq + 8, v_hi, vl8); + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q4_1_to_q4_1_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes nibble repack + zp computation. +static int repack_q4_1_to_q4_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + constexpr int qs_bytes = QK4_1 / 2; // 16 + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q4_1); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q4_1 * col_src = src + x; + + // --- 1) Gather d and m, compute zp = clamp(nearbyint(-m/d), 0, 15) --- + // block_q4_1 layout: [d(f16), m(f16), qs[16]] + // d is at byte offset 0, m is at byte offset 2 from each block start + { + const uint8_t * dm_base = (const uint8_t *) &col_src->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + ggml_half * d_dst = dst->d; + uint8_t * zp_dst = dst->zp; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + + // stride load d (f16) from each row + vuint16m1_t vd_raw = + __riscv_vlse16_v_u16m1((const uint16_t *) (dm_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd_raw, vl); + + // stride load m (f16) from each row (offset +2 bytes from d) + vuint16m1_t vm_raw = + __riscv_vlse16_v_u16m1((const uint16_t *) (dm_base + 2 + offset * row_stride), row_stride, vl); + + // convert to f32 for zp computation: zp = nearbyint(-m / d) + vfloat16m1_t vd_f16 = __riscv_vreinterpret_v_u16m1_f16m1(vd_raw); + vfloat16m1_t vm_f16 = __riscv_vreinterpret_v_u16m1_f16m1(vm_raw); + + // -m / d in f16 directly (SpaceMIT X60 supports f16 arithmetic) + vfloat16m1_t v_neg_m = __riscv_vfneg_v_f16m1(vm_f16, vl); + vfloat16m1_t v_ratio = __riscv_vfdiv_vv_f16m1(v_neg_m, vd_f16, vl); + + // Convert to f32 for nearbyint, then clamp + vfloat32m2_t v_ratio_f32 = __riscv_vfwcvt_f_f_v_f32m2(v_ratio, vl); + + // Use integer rounding: convert f32 -> int (rounds to nearest) + vint32m2_t v_zp_i32 = __riscv_vfcvt_x_f_v_i32m2(v_ratio_f32, vl); + + // clamp to [0, 15] + v_zp_i32 = __riscv_vmax_vx_i32m2(v_zp_i32, 0, vl); + v_zp_i32 = __riscv_vmin_vx_i32m2(v_zp_i32, 15, vl); + + // narrow i32 -> u8 + vint16m1_t v_zp_i16 = __riscv_vncvt_x_x_w_i16m1(v_zp_i32, vl); + vint8mf2_t v_zp_i8 = __riscv_vncvt_x_x_w_i8mf2(v_zp_i16, vl); + vuint8mf2_t v_zp_u8 = __riscv_vreinterpret_v_i8mf2_u8mf2(v_zp_i8); + __riscv_vse8_v_u8mf2(zp_dst + offset, v_zp_u8, vl); + + offset += vl; + remaining -= vl; + } + } + + // --- 2) Nibble repack qs for each of the 32 rows --- + { + const size_t vl8 = __riscv_vsetvl_e8m1(8); + for (int i = 0; i < 32; i++) { + const uint8_t * sq = col_src[i * nblocks].qs; + uint8_t * dq = dst->qs + i * qs_bytes; + + // stride-2 load to separate even/odd bytes + vuint8m1_t v_even = __riscv_vlse8_v_u8m1(sq, 2, vl8); + vuint8m1_t v_odd = __riscv_vlse8_v_u8m1(sq + 1, 2, vl8); + + // low nibble part: (even & 0x0F) | ((odd & 0x0F) << 4) + vuint8m1_t v_even_lo = __riscv_vand_vx_u8m1(v_even, 0x0F, vl8); + vuint8m1_t v_odd_lo = __riscv_vand_vx_u8m1(v_odd, 0x0F, vl8); + vuint8m1_t v_lo = __riscv_vor_vv_u8m1(v_even_lo, __riscv_vsll_vx_u8m1(v_odd_lo, 4, vl8), vl8); + + // high nibble part: (even >> 4) | (odd & 0xF0) + vuint8m1_t v_even_hi = __riscv_vsrl_vx_u8m1(v_even, 4, vl8); + vuint8m1_t v_odd_hi = __riscv_vand_vx_u8m1(v_odd, 0xF0, vl8); + vuint8m1_t v_hi = __riscv_vor_vv_u8m1(v_even_hi, v_odd_hi, vl8); + + __riscv_vse8_v_u8m1(dq, v_lo, vl8); + __riscv_vse8_v_u8m1(dq + 8, v_hi, vl8); + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_k_to_q4_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q6_k_to_q8_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q6_K * src = (const block_q6_K *) data; + block_q8_0 dst_tmp[32]; + int8_t aux8[QK4_1]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + int64_t nrow_real = std::min((int64_t) nrow - b, (int64_t) nrows_interleaved); + for (int64_t x = 0; x < nblocks; x++) { + for (int bi = 0; bi < 8; bi++) { + int i = 0; + for (; i < nrow_real; i++) { + const uint8_t * q4 = src[x + i * nblocks].ql; + const uint8_t * qh = src[x + i * nblocks].qh; + const int8_t * scales = src[x + i * nblocks].scales; + float d = GGML_FP16_TO_FP32(src[x + i * nblocks].d); + + q4 += 64 * (bi / 4); + qh += 32 * (bi / 4); + int8_t * GGML_RESTRICT a = aux8; + + int8_t bi_idx = bi % 4; + + if (bi_idx == 0) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + } + } else if (bi_idx == 1) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + } + } else if (bi_idx == 2) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + } + } else if (bi_idx == 3) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + } + a = aux8; + + float a_max_abs = 0.0f; + float scale_0 = scales[bi * 2 + 0] * d; + float scale_1 = scales[bi * 2 + 1] * d; + for (int l = 0; l < 16; ++l) { + a_max_abs = std::max(a_max_abs, std::abs(a[l] * scale_0)); + } + + for (int l = 16; l < 32; ++l) { + a_max_abs = std::max(a_max_abs, std::abs(a[l] * scale_1)); + } + + float reflect_scale = a_max_abs / ((1 << 7) - 1); + float reflect_scale_0 = scale_0 / reflect_scale; + float reflect_scale_1 = scale_1 / reflect_scale; + + for (int l = 0; l < 16; ++l) { + float a_temp = std::clamp(std::nearbyintf(a[l] * reflect_scale_0), -128.0f, 127.0f); + a[l] = (int8_t) (a_temp); + } + + for (int l = 16; l < 32; ++l) { + float a_temp = std::clamp(std::nearbyintf(a[l] * reflect_scale_1), -128.0f, 127.0f); + a[l] = (int8_t) (a_temp); + } + + dst_tmp[i].d = GGML_FP32_TO_FP16(reflect_scale); + + memcpy(dst_tmp[i].qs, a, 32 * sizeof(int8_t)); + } + + for (; i < nrows_interleaved; i++) { + memset(&dst_tmp[i], 0, sizeof(block_q8_0)); + } + + *dst++ = make_block_q8_0x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q6_k_to_q8_0_32_bl +// Vectorizes the Q6_K dequant -> requant pipeline using RVV intrinsics. +// For each sub-block (bi), dequant 32 Q6_K values to int6 -> apply two sub-block scales -> +// find max abs -> compute reflect_scale -> requant to int8 -> gather d with stride load. +static int repack_q6_k_to_q8_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q6_K * src = (const block_q6_K *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q6_K); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int bi = 0; bi < 8; bi++) { + // --- 1) Gather 32 d values with stride load --- + // We need to compute reflect_scale per row first, so gather d later. + // Process each row: dequant Q6_K sub-block -> requant to Q8_0 + for (int i = 0; i < nrows_interleaved; i++) { + const block_q6_K * src_blk = &src[x + i * nblocks]; + const uint8_t * q4 = src_blk->ql + 64 * (bi / 4); + const uint8_t * qh = src_blk->qh + 32 * (bi / 4); + const int8_t * scales = src_blk->scales; + float d = GGML_FP16_TO_FP32(src_blk->d); + + int8_t bi_idx = bi % 4; + + // --- Dequant 32 Q6_K values to int6 (range [-32, 31]) using RVV --- + // vl = 32 for e8m2 (VLEN=256) or loop for smaller VLEN + const size_t vl16 = __riscv_vsetvl_e8m1(16); + + vint8m1_t va_lo, va_hi; // 16 elements each + + if (bi_idx == 0) { + // a[l] = (q4[l] & 0xF) | (((qh[l] >> 0) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 16, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vlo4_lo = __riscv_vand_vx_u8m1(vq4_lo, 0x0F, vl16); + vuint8m1_t vlo4_hi = __riscv_vand_vx_u8m1(vq4_hi, 0x0F, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(vqh_lo, 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(vqh_hi, 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vlo4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vlo4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else if (bi_idx == 1) { + // a[l] = (q4[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4 + 32, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 48, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vlo4_lo = __riscv_vand_vx_u8m1(vq4_lo, 0x0F, vl16); + vuint8m1_t vlo4_hi = __riscv_vand_vx_u8m1(vq4_hi, 0x0F, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 2, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 2, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vlo4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vlo4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else if (bi_idx == 2) { + // a[l] = (q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 16, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vhi4_lo = __riscv_vsrl_vx_u8m1(vq4_lo, 4, vl16); + vuint8m1_t vhi4_hi = __riscv_vsrl_vx_u8m1(vq4_hi, 4, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 4, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 4, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vhi4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vhi4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else { // bi_idx == 3 + // a[l] = (q4[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4 + 32, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 48, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vhi4_lo = __riscv_vsrl_vx_u8m1(vq4_lo, 4, vl16); + vuint8m1_t vhi4_hi = __riscv_vsrl_vx_u8m1(vq4_hi, 4, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 6, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 6, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vhi4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vhi4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } + + // --- Widen to i16 for scaled abs computation --- + float scale_0 = scales[bi * 2 + 0] * d; + float scale_1 = scales[bi * 2 + 1] * d; + + // Widen i8 -> i16 -> f32 for abs*scale computation + vint16m2_t va_lo_w = __riscv_vsext_vf2_i16m2(va_lo, vl16); + vint16m2_t va_hi_w = __riscv_vsext_vf2_i16m2(va_hi, vl16); + + // Compute |a[l] * scale_0| for lo half, |a[l] * scale_1| for hi half + vfloat32m4_t vf_lo = __riscv_vfcvt_f_x_v_f32m4(__riscv_vsext_vf2_i32m4(va_lo_w, vl16), vl16); + vfloat32m4_t vf_hi = __riscv_vfcvt_f_x_v_f32m4(__riscv_vsext_vf2_i32m4(va_hi_w, vl16), vl16); + + vfloat32m4_t vabs_lo = __riscv_vfabs_v_f32m4(__riscv_vfmul_vf_f32m4(vf_lo, scale_0, vl16), vl16); + vfloat32m4_t vabs_hi = __riscv_vfabs_v_f32m4(__riscv_vfmul_vf_f32m4(vf_hi, scale_1, vl16), vl16); + + // Find max abs across both halves + vfloat32m4_t vabs_max = __riscv_vfmax_vv_f32m4(vabs_lo, vabs_hi, vl16); + + // Reduce to scalar max + vfloat32m1_t vzero = __riscv_vfmv_v_f_f32m1(0.0f, 1); + vfloat32m1_t vmax_red = __riscv_vfredmax_vs_f32m4_f32m1(vabs_max, vzero, vl16); + float a_max_abs = __riscv_vfmv_f_s_f32m1_f32(vmax_red); + + float reflect_scale = a_max_abs / 127.0f; + float reflect_scale_0 = scale_0 / reflect_scale; + float reflect_scale_1 = scale_1 / reflect_scale; + + // --- Requant: a[l] = clamp(nearbyint(a[l] * reflect_scale_x), -128, 127) --- + vfloat32m4_t vscaled_lo = __riscv_vfmul_vf_f32m4(vf_lo, reflect_scale_0, vl16); + vfloat32m4_t vscaled_hi = __riscv_vfmul_vf_f32m4(vf_hi, reflect_scale_1, vl16); + + // fcvt.x rounds to nearest (using current rounding mode) + vint32m4_t vi_lo = __riscv_vfcvt_x_f_v_i32m4(vscaled_lo, vl16); + vint32m4_t vi_hi = __riscv_vfcvt_x_f_v_i32m4(vscaled_hi, vl16); + + // Clamp to [-128, 127] + vi_lo = __riscv_vmax_vx_i32m4(vi_lo, -128, vl16); + vi_lo = __riscv_vmin_vx_i32m4(vi_lo, 127, vl16); + vi_hi = __riscv_vmax_vx_i32m4(vi_hi, -128, vl16); + vi_hi = __riscv_vmin_vx_i32m4(vi_hi, 127, vl16); + + // Narrow i32 -> i16 -> i8 + vint16m2_t vi16_lo = __riscv_vncvt_x_x_w_i16m2(vi_lo, vl16); + vint16m2_t vi16_hi = __riscv_vncvt_x_x_w_i16m2(vi_hi, vl16); + vint8m1_t vi8_lo = __riscv_vncvt_x_x_w_i8m1(vi16_lo, vl16); + vint8m1_t vi8_hi = __riscv_vncvt_x_x_w_i8m1(vi16_hi, vl16); + + // Store d and qs directly into dst block + dst->d[i] = GGML_FP32_TO_FP16(reflect_scale); + int8_t * dq = (int8_t *) dst->qs + i * QK8_0; + __riscv_vse8_v_i8m1(dq, vi8_lo, vl16); + __riscv_vse8_v_i8m1(dq + 16, vi8_hi, vl16); + } + dst++; + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q8_0_to_q8_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + block_q8_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[0] % QK8_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + int64_t nrows_real = std::min((int64_t) nrow - b, (int64_t) nrows_interleaved); + for (int64_t x = 0; x < nblocks; x++) { + int i = 0; + for (; i < nrows_real; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + for (; i < nrows_interleaved; i++) { + memset(&dst_tmp[i], 0, sizeof(block_q8_0)); + } + *dst++ = make_block_q8_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q8_0_to_q8_0_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes scale gather + qs copy. +static int repack_q8_0_to_q8_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK8_0 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q8_0); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q8_0 * col_src = src + x; + + // --- 1) Gather 32 scale values (ggml_half d) with stride load --- + { + const uint8_t * d_base = (const uint8_t *) &col_src->d; + ggml_half * d_dst = dst->d; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + vuint16m1_t vd = + __riscv_vlse16_v_u16m1((const uint16_t *) (d_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd, vl); + offset += vl; + remaining -= vl; + } + } + + // --- 2) Copy qs for each of the 32 rows (32 bytes per row) --- + { + for (int i = 0; i < 32; i++) { + const int8_t * sq = col_src[i * nblocks].qs; + int8_t * dq = (int8_t *) dst->qs + i * QK8_0; + + size_t len = QK8_0; + size_t idx = 0; + while (len > 0) { + size_t vl = __riscv_vsetvl_e8m2(len); + vint8m2_t vs = __riscv_vle8_v_i8m2(sq + idx, vl); + __riscv_vse8_v_i8m2(dq + idx, vs, vl); + idx += vl; + len -= vl; + } + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static void convert_mxfp4_to_5bit(const block_mxfp4 & src, spacemit_kernels::nrow_block_mxfp4<1> & dst) { + dst.e[0] = src.e; + + // Decode all 32 mxfp4 values to signed integers via kvalues_mxfp4 + int8_t vals[32]; + for (int j = 0; j < QK_MXFP4 / 2; j++) { + vals[j] = kvalues_mxfp4[src.qs[j] & 0xF]; + vals[j + QK_MXFP4 / 2] = kvalues_mxfp4[src.qs[j] >> 4]; + } + + // vals [b0, b1, b2, b3, ..., b30, b31] + // Pack abs into qs with reorder: [b0,b1]..[b14,b15]..[b30,b31] + for (int j = 0; j < QK_MXFP4 / 2; j++) { + uint8_t lo0 = static_cast(std::abs(vals[j * 2])); + uint8_t lo1 = static_cast(std::abs(vals[j * 2 + 1])); + dst.qs[j] = (lo0 & 0x0F) | ((lo1 & 0x0F) << 4); + } + + // Pack sign bits into qh[4] (32 bits total, 1 bit per weight) + // reorder: [0,1,2,...,15,16,17,...,31] after the qs reorder above + uint32_t sign_bits = 0; + for (int j = 0; j < 32; j++) { + if (vals[j] < 0) { + sign_bits |= (1u << j); + } + } + memcpy(dst.qh, &sign_bits, 4); +} + +static spacemit_kernels::nrow_block_mxfp4<32> make_block_mxfp4x32(spacemit_kernels::nrow_block_mxfp4<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_mxfp4<32> out; + GGML_ASSERT(QK_MXFP4 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.e[i] = in[i].e[0]; + } + + // qs: copy per-row 16 bytes + for (int i = 0; i < 32; i++) { + memcpy(out.qs + i * 16, in[i].qs, 16); + } + + // qh: copy per-row 4 bytes + for (int i = 0; i < 32; i++) { + memcpy(out.qh + i * 4, in[i].qh, 4); + } + + return out; +} + +static int repack_mxfp4_to_mxfp4_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_mxfp4<32> * dst = (spacemit_kernels::nrow_block_mxfp4<32> *) t->data; + const block_mxfp4 * src = (const block_mxfp4 *) data; + spacemit_kernels::nrow_block_mxfp4<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_MXFP4 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + convert_mxfp4_to_5bit(src[x + i * nblocks], dst_tmp[i]); + } + *dst++ = make_block_mxfp4x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static spacemit_kernels::nrow_block_q5_1<32> make_block_q5_1x32(spacemit_kernels::nrow_block_q5_1<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_q5_1<32> out; + GGML_ASSERT(QK5_1 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.scales16[i] = in[i].scales16[0]; + out.zp[i] = in[i].zp[0]; + } + + // qs: low 4 bits, reorder from [b0,b16],[b1,b17]... to [b0,b1]...[b14,b15] and [b16,b17]...[b30,b31] + for (int i = 0; i < 32; i++) { + // low half [0..15] + for (int j = 0; j < QK5_1 / 4; j++) { + out.qs[i * QK5_1 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + // high half [16..31] + for (int j = 0; j < QK5_1 / 4; j++) { + out.qs[i * QK5_1 / 2 + QK5_1 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + // qh: 5th bit, copy directly + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 4; j++) { + out.qh[i * 4 + j] = in[i].qh[j]; + } + } + + return out; +} + +static spacemit_kernels::nrow_block_q5_0<32> make_block_q5_0x32(spacemit_kernels::nrow_block_q5_0<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_q5_0<32> out; + GGML_ASSERT(QK5_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.scales16[i] = in[i].scales16[0]; + } + + // qs: low 4 bits, reorder from [b0,b16],[b1,b17]... to [b0,b1]...[b14,b15] and [b16,b17]...[b30,b31] + for (int i = 0; i < 32; i++) { + // low half [0..15] + for (int j = 0; j < QK5_0 / 4; j++) { + out.qs[i * QK5_0 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + // high half [16..31] + for (int j = 0; j < QK5_0 / 4; j++) { + out.qs[i * QK5_0 / 2 + QK5_0 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + // qh: 5th bit, copy directly + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 4; j++) { + out.qh[i * 4 + j] = in[i].qh[j]; + } + } + + return out; +} + +static int repack_q5_0_to_q5_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_0<32> * dst = (spacemit_kernels::nrow_block_q5_0<32> *) t->data; + const block_q5_0 * src = (const block_q5_0 *) data; + spacemit_kernels::nrow_block_q5_0<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK5_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK5_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q5_0 & s = src[x + i * nblocks]; + + dst_tmp[i].scales16[0] = s.d; + memcpy(dst_tmp[i].qs, s.qs, sizeof(dst_tmp[i].qs)); + memcpy(dst_tmp[i].qh, s.qh, sizeof(dst_tmp[i].qh)); + } + *dst++ = make_block_q5_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q5_1_to_q5_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_1<32> * dst = (spacemit_kernels::nrow_block_q5_1<32> *) t->data; + const block_q5_1 * src = (const block_q5_1 *) data; + spacemit_kernels::nrow_block_q5_1<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK5_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK5_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q5_1 & s = src[x + i * nblocks]; + + float d = GGML_FP16_TO_FP32(s.GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(s.GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + + if (d == 0.0f) { + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(std::fabs(m)); + dst_tmp[i].zp[0] = m < 0.0f ? 1 : 0; + memset(dst_tmp[i].qh, 0, sizeof(dst_tmp[i].qh)); + memset(dst_tmp[i].qs, m > 0.0f ? 0x11 : 0x00, sizeof(dst_tmp[i].qs)); + continue; + } + + float mid = std::nearbyintf(-m / d); + mid = std::min(31.0f, std::max(0.0f, mid)); + + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(d); + dst_tmp[i].zp[0] = static_cast(mid); + + // qs: copy low 4 bits directly (same nibble packing) + memcpy(dst_tmp[i].qs, s.qs, QK5_1 / 2); + + // qh: copy 5th bit directly + memcpy(dst_tmp[i].qh, s.qh, 4); + } + *dst++ = make_block_q5_1x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q5_k_to_q5_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK5_1 == 8); + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_1<32> * dst = (spacemit_kernels::nrow_block_q5_1<32> *) t->data; + const block_q5_K * src = (const block_q5_K *) data; + spacemit_kernels::nrow_block_q5_1<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + + float d1 = d * sc; + float m1 = min * m; + + float mid = std::nearbyintf(m1 / d1); + mid = std::min(31.0f, std::max(0.0f, mid)); + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(d1); + dst_tmp[i].zp[0] = static_cast(mid); + + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK5_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + + // Extract the 5th bit (qh) for this sub-block + // block_q5_K.qh[32]: for sub-block j, the 5th bit is at bit position j in qh[l] + // qs was reordered: dst_qs maps to src weights [0,16,1,17,...,15,31] + // So qh must follow the same reorder to stay aligned with qs + // dst qh[4] = 32 bits for 32 weights in the reordered layout: + // byte 0: weights 0..7 (from src_qh[0..7]) + // byte 1: weights 8..15 (from src_qh[8..15]) + // byte 2: weights 16..23 (from src_qh[16..23]) + // byte 3: weights 24..31 (from src_qh[24..31]) + const uint8_t * src_qh = src[x + i * nblocks].qh; + for (int bi = 0; bi < 4; bi++) { + uint8_t qh_byte = 0; + for (int k = 0; k < 8; k++) { + int src_idx = bi * 8 + k; + qh_byte |= ((src_qh[src_idx] >> j) & 1) << k; + } + dst_tmp[i].qh[bi] = qh_byte; + } + } + *dst++ = make_block_q5_1x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +namespace ggml::cpu::riscv64_spacemit { + +template int repack(ggml_tensor *, const void *, size_t); + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_k_to_q2_k_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q3_k_to_q3_k_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 0 + return repack_q4_0_to_q4_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_0_to_q4_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q4_0_to_q4_0_256_32_bl_ref(t, 32, data, data_size); +#else + //return repack_q4_0_to_q4_0_256_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 0 + return repack_q4_1_to_q4_1_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_1_to_q4_1_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q4_0_to_q4_1_256_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_1_to_q4_1_256_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q6_k_to_q8_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q6_k_to_q8_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q8_0_to_q8_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q8_0_to_q8_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_0_to_q5_0_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_1_to_q5_1_32_bl(t, 32, data, data_size); +} + +template <> int repack(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_k_to_q5_1_32_bl(t, 32, data, data_size); +} + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/repack.h b/ggml/src/ggml-cpu/spacemit/repack.h new file mode 100644 index 00000000000..950cbde7593 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/repack.h @@ -0,0 +1,14 @@ +#pragma once + +#include "ggml-common.h" +#include "ggml.h" + +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +template +int repack(ggml_tensor * t, const void * data, size_t data_size); + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp b/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp new file mode 100644 index 00000000000..d2f89743622 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp @@ -0,0 +1,3178 @@ +#include "rvv_kernels.h" + +#include "common.h" +#include "ggml.h" +#include "ops.h" +#include "string.h" + +#include +#include +#include +#include + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Woverlength-strings" +# pragma GCC diagnostic ignored "-Wcast-qual" +# pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +namespace spacemit_kernels::rvv { + +namespace { + +auto align_up(size_t value, size_t alignment) { + return (value + alignment - 1) / alignment * alignment; +} + +static inline bool flash_attn_ext_supported_d_vlen1024_vf16(int64_t d) { + return d > 0 && d <= 128; +} + +static inline bool flash_attn_ext_supported_shape_vlen1024_vf16(int64_t DK, int64_t DV) { + return flash_attn_ext_supported_d_vlen1024_vf16(DK) && flash_attn_ext_supported_d_vlen1024_vf16(DV); +} + +static inline float reduce_sum_f32m4_vlen1024(vfloat32m4_t v, size_t vl) { + vfloat32m1_t s_v = __riscv_vfmv_v_f_f32m1(0.0f, 1); + s_v = __riscv_vfredusum_vs_f32m4_f32m1(v, s_v, vl); + return __riscv_vfmv_f_s_f32m1_f32(s_v); +} + +static inline float reduce_sum_f32m2_vlen1024(vfloat32m2_t v, size_t vl) { + vfloat32m1_t s_v = __riscv_vfmv_v_f_f32m1(0.0f, 1); + s_v = __riscv_vfredusum_vs_f32m2_f32m1(v, s_v, vl); + return __riscv_vfmv_f_s_f32m1_f32(s_v); +} + +// Adapted from ggml_v_expf_m2 in vec.h. This is accurate enough for softmax. +static inline vfloat32m2_t rvv_expf_approx_f32m2(vfloat32m2_t x, size_t vl) { + const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl); + const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl); + const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl); + const vfloat32m2_t b = + __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl), 0x1.7f7d1cp-20f, n, vl); + const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl); + const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); + const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl); + const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl); + const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2( + __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl), + __riscv_vfmacc_vv_f32m2( + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl), + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl), u, vl), + u, vl); + + if (!__riscv_vcpop_m_b16(c, vl)) { + return __riscv_vfmacc_vv_f32m2(k, j, k, vl); + } + + const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl); + const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl); + const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl)); + const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl)); + const vfloat32m2_t r1 = + __riscv_vmerge_vvm_f32m2(__riscv_vfmacc_vv_f32m2(k, k, j, vl), + __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl), c, vl); + return __riscv_vmerge_vvm_f32m2(r1, __riscv_vfmul_vv_f32m2(s1, s1, vl), + __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl), vl); +} + +static inline vfloat32m2_t rvv_tanh_approx_f32m2(vfloat32m2_t x, size_t vl) { + const vfloat32m2_t abs_x = __riscv_vfabs_v_f32m2(x, vl); + const vfloat32m2_t neg_2_abs = __riscv_vfmul_vf_f32m2(abs_x, -2.0f, vl); + const vfloat32m2_t exp_term = rvv_expf_approx_f32m2(neg_2_abs, vl); + const vfloat32m2_t numerator = __riscv_vfsub_vf_f32m2(exp_term, 1.0f, vl); + const vfloat32m2_t denominator = __riscv_vfadd_vf_f32m2(exp_term, 1.0f, vl); + const vfloat32m2_t tanh_abs = __riscv_vfneg_v_f32m2(__riscv_vfdiv_vv_f32m2(numerator, denominator, vl), vl); + const vbool16_t neg_mask = __riscv_vmflt_vf_f32m2_b16(x, 0.0f, vl); + const vfloat32m2_t tanh_neg = __riscv_vfneg_v_f32m2(tanh_abs, vl); + return __riscv_vmerge_vvm_f32m2(tanh_abs, tanh_neg, neg_mask, vl); +} + +static void rvv_softcap_tanh_inplace_f32(float * dst, int64_t dst_stride, int64_t tile_rows, int64_t n, float softcap) { + for (int tq = 0; tq < tile_rows; ++tq, dst += dst_stride) { + float * dst_row = dst; + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m2(remaining); + vfloat32m2_t v = __riscv_vle32_v_f32m2(dst_row, vl); + v = rvv_tanh_approx_f32m2(v, vl); + v = __riscv_vfmul_vf_f32m2(v, softcap, vl); + __riscv_vse32_v_f32m2(dst_row, v, vl); + dst_row += vl; + remaining -= vl; + } + } +} + +static inline float rvv_softmax_exp_inplace_f32(float * dst, int64_t n, float max_value) { + float row_sum = 0.0f; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m2(n); + vfloat32m2_t v = __riscv_vle32_v_f32m2(dst, vl); + v = __riscv_vfsub_vf_f32m2(v, max_value, vl); + v = rvv_expf_approx_f32m2(v, vl); + __riscv_vse32_v_f32m2(dst, v, vl); + row_sum += reduce_sum_f32m2_vlen1024(v, vl); + dst += vl; + n -= vl; + } + return row_sum; +} + +static inline float rvv_add_max_inplace_f32(float * dst, const float * src, int64_t n) { + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + vfloat32m4_t vdst = __riscv_vle32_v_f32m4(dst, vl); + vfloat32m4_t vsrc = __riscv_vle32_v_f32m4(src, vl); + vdst = __riscv_vfadd_vv_f32m4(vdst, vsrc, vl); + __riscv_vse32_v_f32m4(dst, vdst, vl); + + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m4_f32m1(vdst, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + + dst += vl; + src += vl; + n -= vl; + } + return max_val; +} + +static inline float rvv_softcap_add_max_inplace_f32(float * dst, const float * src, int64_t n, float softcap) { + if (softcap == 0.0f) { + return rvv_add_max_inplace_f32(dst, src, n); + } + + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m2(n); + vfloat32m2_t vdst = __riscv_vle32_v_f32m2(dst, vl); + vfloat32m2_t vsrc = __riscv_vle32_v_f32m2(src, vl); + vdst = rvv_tanh_approx_f32m2(vdst, vl); + vdst = __riscv_vfmul_vf_f32m2(vdst, softcap, vl); + vdst = __riscv_vfadd_vv_f32m2(vdst, vsrc, vl); + __riscv_vse32_v_f32m2(dst, vdst, vl); + + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m2_f32m1(vdst, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + + dst += vl; + src += vl; + n -= vl; + } + return max_val; +} + +static inline void rvv_zero_f32(float * dst, int64_t n) { + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + const vfloat32m4_t z = __riscv_vfmv_v_f_f32m4(0.0f, vl); + __riscv_vse32_v_f32m4(dst, z, vl); + dst += vl; + n -= vl; + } +} + +static inline void rvv_scale_f32(float * dst, float scale, int64_t n) { + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + vfloat32m4_t v = __riscv_vle32_v_f32m4(dst, vl); + v = __riscv_vfmul_vf_f32m4(v, scale, vl); + __riscv_vse32_v_f32m4(dst, v, vl); + dst += vl; + n -= vl; + } +} + +static inline void rvv_add_inplace_f32(float * dst, + int64_t dst_stride, + const float * src, + int64_t src_stride, + int64_t tile_rows, + int64_t n) { + for (int tq = 0; tq < tile_rows; ++tq, dst += dst_stride, src += src_stride) { + int64_t remaining = n; + float * dst_row = dst; + const float * src_row = src; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t vdst = __riscv_vle32_v_f32m4(dst_row, vl); + vfloat32m4_t vsrc = __riscv_vle32_v_f32m4(src_row, vl); + vdst = __riscv_vfadd_vv_f32m4(vdst, vsrc, vl); + __riscv_vse32_v_f32m4(dst_row, vdst, vl); + dst_row += vl; + src_row += vl; + remaining -= vl; + } + } +} + +static inline float rvv_max_f32(const float * src, int64_t n) { + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + const vfloat32m4_t v = __riscv_vle32_v_f32m4(src, vl); + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m4_f32m1(v, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + src += vl; + n -= vl; + } + return max_val; +} + +static void rvv_pack_f32_as_scaled_f16(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const float * row_ptr = (const float *) ((const char *) src + tq * src_row_stride); + _Float16 * dst_row_ptr = (_Float16 *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t v32 = __riscv_vle32_v_f32m4(row_ptr, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale, vl); + const vfloat16m2_t v16 = __riscv_vfncvt_f_f_w_f16m2(v32, vl); + __riscv_vse16_v_f16m2(dst_row_ptr, v16, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static void rvv_pack_scaled_f16_as_f32(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const _Float16 * row_ptr = (const _Float16 *) ((const char *) src + tq * src_row_stride); + float * dst_row_ptr = (float *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e16m2(remaining); + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(row_ptr, vl); + vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale, vl); + __riscv_vse32_v_f32m4(dst_row_ptr, v32, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static void rvv_pack_scaled_f32_as_f32(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float * scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const float * row_ptr = (const float *) ((const char *) src + tq * src_row_stride); + float * dst_row_ptr = (float *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t v32 = __riscv_vle32_v_f32m4(row_ptr, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale[tq], vl); + __riscv_vse32_v_f32m4(dst_row_ptr, v32, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static inline void rvv_transposed_s32_mn_to_nm(int8_t * dst, + int64_t n_dst_stride, + int8_t * src, + int64_t m_src_stride, + int64_t m, + int64_t n) { + int8_t * in = src; + int8_t * out = dst; + + __asm__ volatile( + "vsetvli t0, zero, e32, m1, tu, mu \n\t" + "mul t3, t0, %[os0] \n\t" + "srli t2, %[isz0], 3 \n\t" + "blez t2, M1%= \n\t" + + "LOOP_M8%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "add s2, %[src], %[is0] \n\t" + "add s3, s2, %[is0] \n\t" + "add s4, s3, %[is0] \n\t" + "add s5, s4, %[is0] \n\t" + "add s6, s5, %[is0] \n\t" + "add s7, s6, %[is0] \n\t" + "add s8, s7, %[is0] \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M8N%=: \n\t" + "vsetvli t0, t1, e32, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (s1) \n\t" + "sh2add s1, t0, s1 \n\t" + "vle32.v v1, (s2) \n\t" + "sh2add s2, t0, s2 \n\t" + "vle32.v v2, (s3) \n\t" + "sh2add s3, t0, s3 \n\t" + "vle32.v v3, (s4) \n\t" + "sh2add s4, t0, s4 \n\t" + "vle32.v v4, (s5) \n\t" + "sh2add s5, t0, s5 \n\t" + "vle32.v v5, (s6) \n\t" + "sh2add s6, t0, s6 \n\t" + "vle32.v v6, (s7) \n\t" + "sh2add s7, t0, s7 \n\t" + "vle32.v v7, (s8) \n\t" + "sh2add s8, t0, s8 \n\t" + "vssseg8e32.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M8N%= \n\t" + "sh3add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 32 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M8%= \n\t" + + "M1%=: \n\t" + "andi t2, %[isz0], 7 \n\t" + "blez t2, END%= \n\t" + + "LOOP_M1%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M1N%=: \n\t" + "vsetvli t0, t1, e32, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (s1) \n\t" + "sh2add s1, t0, s1 \n\t" + "vsse32.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M1N%= \n\t" + "add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 4 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M1%= \n\t" + "END%=: \n\t" + + : [src] "+r"(in), [dst] "+r"(out), [isz0] "+r"(m) + : [isz1] "r"(n), [is0] "r"(m_src_stride), [os0] "r"(n_dst_stride) + : "cc", "t0", "t1", "t2", "t3", "s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "a1"); +} + +static inline void rvv_transposed_s16_mn_to_nm(int8_t * dst, + int64_t n_dst_stride, + int8_t * src, + int64_t m_src_stride, + int64_t m, + int64_t n) { + int8_t * in = src; + int8_t * out = dst; + + __asm__ volatile( + "vsetvli t0, zero, e16, m1, tu, mu \n\t" + "mul t3, t0, %[os0] \n\t" + "srli t2, %[isz0], 3 \n\t" + "blez t2, M1%= \n\t" + + "LOOP_M8%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "add s2, %[src], %[is0] \n\t" + "add s3, s2, %[is0] \n\t" + "add s4, s3, %[is0] \n\t" + "add s5, s4, %[is0] \n\t" + "add s6, s5, %[is0] \n\t" + "add s7, s6, %[is0] \n\t" + "add s8, s7, %[is0] \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M8N%=: \n\t" + "vsetvli t0, t1, e16, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle16.v v0, (s1) \n\t" + "sh1add s1, t0, s1 \n\t" + "vle16.v v1, (s2) \n\t" + "sh1add s2, t0, s2 \n\t" + "vle16.v v2, (s3) \n\t" + "sh1add s3, t0, s3 \n\t" + "vle16.v v3, (s4) \n\t" + "sh1add s4, t0, s4 \n\t" + "vle16.v v4, (s5) \n\t" + "sh1add s5, t0, s5 \n\t" + "vle16.v v5, (s6) \n\t" + "sh1add s6, t0, s6 \n\t" + "vle16.v v6, (s7) \n\t" + "sh1add s7, t0, s7 \n\t" + "vle16.v v7, (s8) \n\t" + "sh1add s8, t0, s8 \n\t" + "vssseg8e16.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M8N%= \n\t" + "sh3add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 16 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M8%= \n\t" + + "M1%=: \n\t" + "andi t2, %[isz0], 7 \n\t" + "blez t2, END%= \n\t" + + "LOOP_M1%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M1N%=: \n\t" + "vsetvli t0, t1, e16, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle16.v v0, (s1) \n\t" + "sh1add s1, t0, s1 \n\t" + "vsse16.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M1N%= \n\t" + "add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 2 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M1%= \n\t" + "END%=: \n\t" + + : [src] "+r"(in), [dst] "+r"(out), [isz0] "+r"(m) + : [isz1] "r"(n), [is0] "r"(m_src_stride), [os0] "r"(n_dst_stride) + : "cc", "t0", "t1", "t2", "t3", "s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "a1"); +} + +static inline void rvv_qk_dot_tile_f16_x1(float * dst, + const _Float16 * q_row, + const _Float16 * k_pack, + int64_t dk, + int64_t kv_tile) { + const size_t vl = __riscv_vsetvl_e16m1(kv_tile); + vfloat32m2_t acc = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_pack + d * ggml_fa_tile_config::KV, vl); + acc = __riscv_vfwmacc_vf_f32m2(acc, q_row[d], k_vec, vl); + } + + __riscv_vse32_v_f32m2(dst, acc, vl); +} + +static inline void rvv_qk_dot_tile_f16_x4(float * dst0, + float * dst1, + float * dst2, + float * dst3, + const _Float16 * q0, + const _Float16 * q1, + const _Float16 * q2, + const _Float16 * q3, + const _Float16 * k_pack, + int64_t dk, + int64_t kv_tile) { + const size_t vl = __riscv_vsetvl_e16m1(kv_tile); + vfloat32m2_t acc0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc1 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc2 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc3 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_pack + d * ggml_fa_tile_config::KV, vl); + acc0 = __riscv_vfwmacc_vf_f32m2(acc0, q0[d], k_vec, vl); + acc1 = __riscv_vfwmacc_vf_f32m2(acc1, q1[d], k_vec, vl); + acc2 = __riscv_vfwmacc_vf_f32m2(acc2, q2[d], k_vec, vl); + acc3 = __riscv_vfwmacc_vf_f32m2(acc3, q3[d], k_vec, vl); + } + + __riscv_vse32_v_f32m2(dst0, acc0, vl); + __riscv_vse32_v_f32m2(dst1, acc1, vl); + __riscv_vse32_v_f32m2(dst2, acc2, vl); + __riscv_vse32_v_f32m2(dst3, acc3, vl); +} + +static inline void rvv_pv_accumulate_f16_x1(float * dst, + const float * prob, + const _Float16 * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e16m2(d_left); + vfloat32m4_t acc = __riscv_vle32_v_f32m4(dst + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_pack + tk * dv + d_off, vl); + const vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, prob[tk], v32, vl); + } + + __riscv_vse32_v_f32m4(dst + d_off, acc, vl); + d_left -= vl; + d_off += vl; + } +} + +static inline void rvv_pv_accumulate_f16_x4(float * dst0, + float * dst1, + float * dst2, + float * dst3, + const float * prob0, + const float * prob1, + const float * prob2, + const float * prob3, + const _Float16 * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e16m2(d_left); + vfloat32m4_t acc0 = __riscv_vle32_v_f32m4(dst0 + d_off, vl); + vfloat32m4_t acc1 = __riscv_vle32_v_f32m4(dst1 + d_off, vl); + vfloat32m4_t acc2 = __riscv_vle32_v_f32m4(dst2 + d_off, vl); + vfloat32m4_t acc3 = __riscv_vle32_v_f32m4(dst3 + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_pack + tk * dv + d_off, vl); + const vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + acc0 = __riscv_vfmacc_vf_f32m4(acc0, prob0[tk], v32, vl); + acc1 = __riscv_vfmacc_vf_f32m4(acc1, prob1[tk], v32, vl); + acc2 = __riscv_vfmacc_vf_f32m4(acc2, prob2[tk], v32, vl); + acc3 = __riscv_vfmacc_vf_f32m4(acc3, prob3[tk], v32, vl); + } + + __riscv_vse32_v_f32m4(dst0 + d_off, acc0, vl); + __riscv_vse32_v_f32m4(dst1 + d_off, acc1, vl); + __riscv_vse32_v_f32m4(dst2 + d_off, acc2, vl); + __riscv_vse32_v_f32m4(dst3 + d_off, acc3, vl); + d_left -= vl; + d_off += vl; + } +} + +static inline void rvv_qk_dot_tile(float * dst, + const float * q_row, + const float * k_pack, + int64_t dk, + int64_t kv_tile, + float scale) { + const size_t vl = __riscv_vsetvl_e32m4(kv_tile); + vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat32m4_t k_vec = __riscv_vle32_v_f32m4(k_pack + d * kv_tile, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, q_row[d] * scale, k_vec, vl); + } + + __riscv_vse32_v_f32m4(dst, acc, vl); +} + +static inline void rvv_pv_accumulate(float * dst, + const float * prob, + const float * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e32m4(d_left); + vfloat32m4_t acc = __riscv_vle32_v_f32m4(dst + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat32m4_t v_vec = __riscv_vle32_v_f32m4(v_pack + tk * dv + d_off, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, prob[tk], v_vec, vl); + } + + __riscv_vse32_v_f32m4(dst + d_off, acc, vl); + d_left -= vl; + d_off += vl; + } +} + +static void permute_transpose_impl(const ggml_tensor * src0, + ggml_tensor * dst, + int64_t batch, + int64_t m, + int64_t n, + int64_t batch_stride, + int64_t m_src_stride, + int64_t n_src_stride, + int64_t n_dst_stride, + int ith, + int nth) { + GGML_ASSERT(n_src_stride == sizeof(int32_t) || n_src_stride == sizeof(int16_t)); + + if (n_src_stride == sizeof(int32_t)) { + for (int64_t bi = ith; bi < batch; bi += nth) { + rvv_transposed_s32_mn_to_nm((int8_t *) ((char *) dst->data + bi * batch_stride), n_dst_stride, + (int8_t *) ((char *) src0->data + bi * batch_stride), m_src_stride, m, n); + } + } else if (n_src_stride == sizeof(int16_t)) { + for (int64_t bi = ith; bi < batch; bi += nth) { + rvv_transposed_s32_mn_to_nm((int8_t *) ((char *) dst->data + bi * batch_stride), n_dst_stride, + (int8_t *) ((char *) src0->data + bi * batch_stride), m_src_stride, m, n); + } + } else { + GGML_ABORT("not implemented"); + } +} + +template +static void flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow(float ** pq, + const char * k_data_row, + const char * v_data_row, + const ggml_fp16_t * mp, + float ** sinks, + float ** dst, + float scale, + float logit_softcap, + float slope, + int64_t nek1, + int64_t nbk1, + int64_t nbv1, + int64_t DV, + int64_t DK, + void * tcm_buffer, + size_t tcm_buffer_size) { + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + float S[QLEN] = { 0.0f }; // sum + float M[QLEN] = { -INFINITY }; // maximum KQ value + + _Float16 * kq16_buffer = (_Float16 *) tcm_buffer; + _Float16 * qv_buffer = kq16_buffer + QLEN * DV; + const size_t qkv_temp_buffer_size = (QLEN * DV + QLEN * DK) * sizeof(_Float16); + char * kv_tile_buffer = (char *) (qv_buffer + QLEN * DK); + + { + vfloat16m2_t VKQ16_v = __riscv_vfmv_v_f_f16m2(0.0f, DV); + for (int64_t i = 0; i < QLEN; ++i) { + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + vfloat16m2_t Q_q_v = __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pq[i], DK), DK); + __riscv_vse16_v_f16m2(qv_buffer + i * DK, Q_q_v, DK); + } + } + + const uintptr_t scratch_addr = reinterpret_cast(kv_tile_buffer); + const size_t scratch_size = tcm_buffer_size > qkv_temp_buffer_size ? tcm_buffer_size - qkv_temp_buffer_size : 0; + const uintptr_t kq_tile_addr = align_up(scratch_addr, alignof(float)); + const size_t scratch_prefix = kq_tile_addr - scratch_addr; + const size_t packed_tile_size = + QLEN * sizeof(float) + DK * sizeof(_Float16) + DV * sizeof(_Float16) + sizeof(float); + const int64_t max_ic_tile_step = ((int64_t) __riscv_vsetvlmax_e16m1()) & ~((int64_t) 7); + const int64_t max_fit_by_tcm = + scratch_size > scratch_prefix ? (int64_t) ((scratch_size - scratch_prefix) / packed_tile_size) : 0; + const int64_t ic_tile_step = std::min(max_ic_tile_step, max_fit_by_tcm) & ~((int64_t) 7); + + const uintptr_t k_tile_addr = kq_tile_addr + QLEN * ic_tile_step * sizeof(float); + const uintptr_t v_tile_addr = k_tile_addr + DK * ic_tile_step * sizeof(_Float16); + const uintptr_t mv_tile_addr = v_tile_addr + ic_tile_step * DV * sizeof(_Float16); + + if (ic_tile_step >= 8) { + float * kq_tile_buffer = reinterpret_cast(kq_tile_addr); + _Float16 * k_tile_pack = reinterpret_cast<_Float16 *>(k_tile_addr); + _Float16 * v_tile_pack = reinterpret_cast<_Float16 *>(v_tile_addr); + float * mv_tile_pack = reinterpret_cast(mv_tile_addr); + + const int64_t k_tile_byte_stride = ic_tile_step * (int64_t) sizeof(_Float16); + + int64_t ic_step = 0; + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + + if (mv != -INFINITY) { + const _Float16 * k_data = (const _Float16 *) (k_data_row + ic * nbk1); + const _Float16 * v_data = (const _Float16 *) (v_data_row + ic * nbv1); + + const vfloat16m2_t k_data_v = __riscv_vle16_v_f16m2(k_data, DK); + const vfloat16m2_t v_data_v = __riscv_vle16_v_f16m2(v_data, DV); + __riscv_vsse16_v_f16m2(k_tile_pack + ic_step, k_tile_byte_stride, k_data_v, DK); + __riscv_vse16_v_f16m2(v_tile_pack + ic_step * DV, v_data_v, DV); + mv_tile_pack[ic_step] = mv; + ic_step++; + } + + if (ic_step > 0 && (ic_step == ic_tile_step || ic == (nek1 - 1))) { + if constexpr (QLEN == 4) { + const size_t qk_vl = __riscv_vsetvl_e16m1(ic_step); + vfloat32m2_t qk_acc0 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc1 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc2 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc3 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + + for (int64_t d = 0; d < DK; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_tile_pack + d * ic_tile_step, qk_vl); + qk_acc0 = __riscv_vfwmacc_vf_f32m2(qk_acc0, qv_buffer[0 * DK + d], k_vec, qk_vl); + qk_acc1 = __riscv_vfwmacc_vf_f32m2(qk_acc1, qv_buffer[1 * DK + d], k_vec, qk_vl); + qk_acc2 = __riscv_vfwmacc_vf_f32m2(qk_acc2, qv_buffer[2 * DK + d], k_vec, qk_vl); + qk_acc3 = __riscv_vfwmacc_vf_f32m2(qk_acc3, qv_buffer[3 * DK + d], k_vec, qk_vl); + } + + qk_acc0 = __riscv_vfmul_vf_f32m2(qk_acc0, scale, qk_vl); + qk_acc1 = __riscv_vfmul_vf_f32m2(qk_acc1, scale, qk_vl); + qk_acc2 = __riscv_vfmul_vf_f32m2(qk_acc2, scale, qk_vl); + qk_acc3 = __riscv_vfmul_vf_f32m2(qk_acc3, scale, qk_vl); + + __riscv_vse32_v_f32m2(kq_tile_buffer + 0 * ic_tile_step, qk_acc0, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 1 * ic_tile_step, qk_acc1, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 2 * ic_tile_step, qk_acc2, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 3 * ic_tile_step, qk_acc3, qk_vl); + } else { + static_assert(QLEN == 2, "unsupported QLEN"); + + const size_t qk_vl = __riscv_vsetvl_e16m1(ic_step); + vfloat32m2_t qk_acc0 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc1 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + + for (int64_t d = 0; d < DK; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_tile_pack + d * ic_tile_step, qk_vl); + qk_acc0 = __riscv_vfwmacc_vf_f32m2(qk_acc0, qv_buffer[0 * DK + d], k_vec, qk_vl); + qk_acc1 = __riscv_vfwmacc_vf_f32m2(qk_acc1, qv_buffer[1 * DK + d], k_vec, qk_vl); + } + + qk_acc0 = __riscv_vfmul_vf_f32m2(qk_acc0, scale, qk_vl); + qk_acc1 = __riscv_vfmul_vf_f32m2(qk_acc1, scale, qk_vl); + + __riscv_vse32_v_f32m2(kq_tile_buffer + 0 * ic_tile_step, qk_acc0, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 1 * ic_tile_step, qk_acc1, qk_vl); + } + + for (int i = 0; i < QLEN; ++i) { + float * row_ptr = kq_tile_buffer + i * ic_tile_step; + const float tile_max = + rvv_softcap_add_max_inplace_f32(row_ptr, mv_tile_pack, ic_step, logit_softcap); + + const float Mold = M[i]; + + if (tile_max > Mold) { + const float ms = expf(Mold - tile_max); + M[i] = tile_max; + S[i] *= ms; + + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, (_Float16) ms, DV); + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + } + + S[i] += rvv_softmax_exp_inplace_f32(row_ptr, ic_step, M[i]); + } + + if constexpr (QLEN == 4) { + vfloat16m2_t pv_acc0 = __riscv_vle16_v_f16m2(kq16_buffer + 0 * DV, DV); + vfloat16m2_t pv_acc1 = __riscv_vle16_v_f16m2(kq16_buffer + 1 * DV, DV); + vfloat16m2_t pv_acc2 = __riscv_vle16_v_f16m2(kq16_buffer + 2 * DV, DV); + vfloat16m2_t pv_acc3 = __riscv_vle16_v_f16m2(kq16_buffer + 3 * DV, DV); + + for (int64_t tk = 0; tk < ic_step; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_tile_pack + tk * DV, DV); + pv_acc0 = + __riscv_vfmacc_vf_f16m2(pv_acc0, (_Float16) kq_tile_buffer[0 * ic_tile_step + tk], v16, DV); + pv_acc1 = + __riscv_vfmacc_vf_f16m2(pv_acc1, (_Float16) kq_tile_buffer[1 * ic_tile_step + tk], v16, DV); + pv_acc2 = + __riscv_vfmacc_vf_f16m2(pv_acc2, (_Float16) kq_tile_buffer[2 * ic_tile_step + tk], v16, DV); + pv_acc3 = + __riscv_vfmacc_vf_f16m2(pv_acc3, (_Float16) kq_tile_buffer[3 * ic_tile_step + tk], v16, DV); + } + + __riscv_vse16_v_f16m2(kq16_buffer + 0 * DV, pv_acc0, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 1 * DV, pv_acc1, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 2 * DV, pv_acc2, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 3 * DV, pv_acc3, DV); + } else { + static_assert(QLEN == 2, "unsupported QLEN"); + vfloat16m2_t pv_acc0 = __riscv_vle16_v_f16m2(kq16_buffer + 0 * DV, DV); + vfloat16m2_t pv_acc1 = __riscv_vle16_v_f16m2(kq16_buffer + 1 * DV, DV); + + for (int64_t tk = 0; tk < ic_step; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_tile_pack + tk * DV, DV); + pv_acc0 = + __riscv_vfmacc_vf_f16m2(pv_acc0, (_Float16) kq_tile_buffer[0 * ic_tile_step + tk], v16, DV); + pv_acc1 = + __riscv_vfmacc_vf_f16m2(pv_acc1, (_Float16) kq_tile_buffer[1 * ic_tile_step + tk], v16, DV); + } + + __riscv_vse16_v_f16m2(kq16_buffer + 0 * DV, pv_acc0, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 1 * DV, pv_acc1, DV); + } + + ic_step = 0; + } + } + } else { + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + + const char * k_data = k_data_row + ic * nbk1; + const char * v_data = v_data_row + ic * nbv1; + + vfloat16m2_t k_data_v; + vfloat16m2_t v_data_v; + + if (mv != -INFINITY) { + k_data_v = __riscv_vle16_v_f16m2((_Float16 *) k_data, DK); + v_data_v = __riscv_vle16_v_f16m2((_Float16 *) v_data, DV); + } else { + continue; + } + + for (int i = 0; i < QLEN; ++i) { + vfloat16m2_t Q_q_v = __riscv_vle16_v_f16m2(qv_buffer + i * DK, DK); + vfloat32m4_t qk_acc_v = __riscv_vfwmul_vv_f32m4(k_data_v, Q_q_v, DK); + float s = reduce_sum_f32m4_vlen1024(qk_acc_v, DK); + s = s * scale; + if (logit_softcap != 0.0f) { + s = logit_softcap * tanhf(s); + } + s += mv; + + const float Mold = M[i]; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + if (s > M[i]) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M[i] = s; + ms = expf(Mold - M[i]); + + // V = V*expf(Mold - M) + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, ms, DV); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M[i]); + } + VKQ16_v = __riscv_vfmacc_vf_f16m2(VKQ16_v, vs, v_data_v, DV); + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + S[i] = S[i] * ms + vs; // scale and increment sum with partial sum + } + } + } + + for (int i = 0; i < QLEN; ++i) { + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + vfloat32m4_t VKQ32_v = __riscv_vfwcvt_f_f_v_f32m4(VKQ16_v, DV); + + // sinks + if (sinks[i]) { + const float s = *(sinks[i]); + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[i]) { + ms = expf(M[i] - s); + M[i] = s; + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, ms, DV); + } else { + vs = expf(s - M[i]); + } + + S[i] = S[i] * ms + vs; + } + + // V /= S + const float S_inv = S[i] == 0.0f ? 0.0f : 1.0f / S[i]; + + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, S_inv, DV); + + __riscv_vse32_v_f32m4(dst[i], VKQ32_v, DV); + } +} + +static void flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_m1(const float * pq, + const char * k_data_row, + const char * v_data_row, + const ggml_fp16_t * mp, + const float * sinks, + float * dst, + float scale, + float logit_softcap, + float slope, + int64_t nek1, + int64_t nbk1, + int64_t nbv1, + int64_t DV, + int64_t DK) { + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value + + vfloat16m2_t VKQ16_v = __riscv_vfmv_v_f_f16m2(0.0f, DV); + + vfloat16m2_t Q_q_v = __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pq, DK), DK); + + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + if (mv == -INFINITY) { + continue; + } + + const char * k_data = k_data_row + ic * nbk1; + + vfloat16m2_t k_data_v = __riscv_vle16_v_f16m2((_Float16 *) k_data, DK); + + vfloat32m4_t qk_acc_v = __riscv_vfwmul_vv_f32m4(k_data_v, Q_q_v, DK); + float s = reduce_sum_f32m4_vlen1024(qk_acc_v, DK); + + s = s * scale; // scale KQ value + + if (logit_softcap != 0.0f) { + s = logit_softcap * tanhf(s); + } + + s += mv; // apply mask + + const float Mold = M; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + const char * v_data = v_data_row + ic * nbv1; + + vfloat16m2_t v_data_v = __riscv_vle16_v_f16m2((_Float16 *) v_data, DV); + + if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, ms, DV); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M); + } + + VKQ16_v = __riscv_vfmacc_vf_f16m2(VKQ16_v, vs, v_data_v, DV); + + S = S * ms + vs; // scale and increment sum with partial sum + } + + vfloat32m4_t VKQ32_v = __riscv_vfwcvt_f_f_v_f32m4(VKQ16_v, DV); + + // sinks + if (sinks) { + const float s = *sinks; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + ms = expf(M - s); + M = s; + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, ms, DV); + } else { + vs = expf(s - M); + } + + S = S * ms + vs; + } + + // V /= S + const float S_inv = S == 0.0f ? 0.0f : 1.0f / S; + + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, S_inv, DV); + + __riscv_vse32_v_f32m4(dst, VKQ32_v, DV); +} + +} // namespace + +void memcpy1d(void * dst, const void * src, int64_t size) { + size_t byte_size_all = size; + size_t vlen = __riscv_vlenb() * 8; + if (vlen == 256) { + // 1024 bytes + __asm__ volatile( + // + "srli t0, %[size], 10 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8, tu, mu \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v8, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v16, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v24, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v8, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v16, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v24, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 1023 \n\t" + "blez t1, out%= \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + "out%=: \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1"); + } else if (vlen == 1024) { + // 2048 bytes + __asm__ volatile( + // + "srli t0, %[size], 11 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8, tu, mu \n\t" + "addi t2, %[s], 1024 \n\t" + "addi t3, %[d], 1024 \n\t" + "li t5, 2048 \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t5 \n\t" + "vle8.v v8, (t2) \n\t" + "add t2, t2, t5 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t5 \n\t" + "vse8.v v8, (t3) \n\t" + "add t3, t3, t5 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 2047 \n\t" + "blez t1, out%= \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m2, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + "out%=: \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1", "t2", "t3", "t5"); + } else { + __asm__ volatile( + // + "add t1, %[size], zero \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1", "t2", "t4", "t3"); + } +} + +void memcpy2d(void * dst, int64_t dst_stride, const void * src, int64_t src_stride, int64_t tile_rows, int64_t size) { + for (int64_t i = 0; i < tile_rows; ++i) { + memcpy1d((char *) dst + i * dst_stride, (const char *) src + i * src_stride, size); + } +} + +void forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + // broadcast factors + const int64_t rk2 = neq2 / nek2; + const int64_t rk3 = neq3 / nek3; + + const int64_t rv2 = neq2 / nev2; + const int64_t rv3 = neq3 / nev3; + + // parallelize by q rows using ggml_vec_dot_f32 + + float scale = *((float *) dst->op_params + 0); + float max_bias = *((float *) dst->op_params + 1); + float logit_softcap = *((float *) dst->op_params + 2); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const int KV_row_size = DK * sizeof(_Float16) + DV * sizeof(_Float16); + + int ith = params->ith; + int ir_step = 1; + for (int ir = ir0; ir < ir1; ir += ir_step) { + // q indices + const int iq3 = ir / (neq2 * neq1); + const int iq2 = (ir - iq3 * neq2 * neq1) / neq1; + const int iq1 = (ir - iq3 * neq2 * neq1 - iq2 * neq1); + + const int iq3_1 = (ir + 1) / (neq2 * neq1); + const int iq2_1 = (ir + 1 - iq3_1 * neq2 * neq1) / neq1; + const int iq1_1 = (ir + 1 - iq3_1 * neq2 * neq1 - iq2_1 * neq1); + + const int iq3_2 = (ir + 2) / (neq2 * neq1); + const int iq2_2 = (ir + 2 - iq3_2 * neq2 * neq1) / neq1; + const int iq1_2 = (ir + 2 - iq3_2 * neq2 * neq1 - iq2_2 * neq1); + + const int iq3_3 = (ir + 3) / (neq2 * neq1); + const int iq2_3 = (ir + 3 - iq3_3 * neq2 * neq1) / neq1; + const int iq1_3 = (ir + 3 - iq3_3 * neq2 * neq1 - iq2_3 * neq1); + + const uint32_t h = iq2; // head index + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + const ggml_fp16_t * mp = + mask ? (ggml_fp16_t *) ((char *) mask->data + iq1 * mask->nb[1] + (iq2 % mask->ne[2]) * mask->nb[2] + + (iq3 % mask->ne[3]) * mask->nb[3]) : + NULL; + + const bool mp_equal_2 = iq1_1 == iq1 && (iq2 % mask->ne[2]) == (iq2_1 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_1 % mask->ne[3]); + + const bool mp_equal_4 = mp_equal_2 && iq1_2 == iq1 && (iq2 % mask->ne[2]) == (iq2_2 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_2 % mask->ne[3]) && iq1_3 == iq1 && + (iq2 % mask->ne[2]) == (iq2_3 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_3 % mask->ne[3]); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + const int ik3_1 = iq3_1 / rk3; + const int ik2_1 = iq2_1 / rk2; + + const int ik3_2 = iq3_2 / rk3; + const int ik2_2 = iq2_2 / rk2; + + const int ik3_3 = iq3_3 / rk3; + const int ik2_3 = iq2_3 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const int iv3_1 = iq3_1 / rv3; + const int iv2_1 = iq2_1 / rv2; + + const int iv3_2 = iq3_2 / rv3; + const int iv2_2 = iq2_2 / rv2; + + const int iv3_3 = iq3_3 / rv3; + const int iv2_3 = iq2_3 / rv2; + + const float * pq = (const float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + + std::array pq_buffer; + std::array sinks_buffer; + std::array dst_buffer; + + if (tcm_buffer != nullptr && 4 * KV_row_size < tcm_buffer_size && ir < (ir1 - 3) && mp_equal_4 && + ik3_3 == ik3 && ik2_3 == ik2 && iv3_3 == iv3 && iv2_3 == iv2 && ik3_2 == ik3 && ik2_2 == ik2 && + iv3_2 == iv3 && iv2_2 == iv2 && ik3_1 == ik3 && ik2_1 == ik2 && iv3_1 == iv3 && iv2_1 == iv2) { + ir_step = 4; + + pq_buffer[0] = (float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + pq_buffer[1] = (float *) ((char *) q->data + (iq1_1 * nbq1 + iq2_1 * nbq2 + iq3_1 * nbq3)); + pq_buffer[2] = (float *) ((char *) q->data + (iq1_2 * nbq1 + iq2_2 * nbq2 + iq3_2 * nbq3)); + pq_buffer[3] = (float *) ((char *) q->data + (iq1_3 * nbq1 + iq2_3 * nbq2 + iq3_3 * nbq3)); + + sinks_buffer[0] = sinks ? ((float *) ((char *) sinks->data)) + iq2 : nullptr; + sinks_buffer[1] = sinks ? ((float *) ((char *) sinks->data)) + iq2_1 : nullptr; + sinks_buffer[2] = sinks ? ((float *) ((char *) sinks->data)) + iq2_2 : nullptr; + sinks_buffer[3] = sinks ? ((float *) ((char *) sinks->data)) + iq2_3 : nullptr; + + dst_buffer[0] = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1); + dst_buffer[1] = (float *) ((char *) dst->data + (iq3_1 * ne2 * ne1 + iq2_1 + iq1_1 * ne1) * nb1); + dst_buffer[2] = (float *) ((char *) dst->data + (iq3_2 * ne2 * ne1 + iq2_2 + iq1_2 * ne1) * nb1); + dst_buffer[3] = (float *) ((char *) dst->data + (iq3_3 * ne2 * ne1 + iq2_3 + iq1_3 * ne1) * nb1); + + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow<4>( // + pq_buffer.data(), // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks_buffer.data(), // + dst_buffer.data(), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK, tcm_buffer, tcm_buffer_size); + } else if (tcm_buffer != nullptr && 2 * KV_row_size < tcm_buffer_size && ir < (ir1 - 1) && mp_equal_2 && + ik3_1 == ik3 && ik2_1 == ik2 && iv3_1 == iv3 && iv2_1 == iv2) { + ir_step = 2; + + pq_buffer[0] = (float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + pq_buffer[1] = (float *) ((char *) q->data + (iq1_1 * nbq1 + iq2_1 * nbq2 + iq3_1 * nbq3)); + + sinks_buffer[0] = sinks ? ((float *) ((char *) sinks->data)) + iq2 : nullptr; + sinks_buffer[1] = sinks ? ((float *) ((char *) sinks->data)) + iq2_1 : nullptr; + + dst_buffer[0] = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1); + dst_buffer[1] = (float *) ((char *) dst->data + (iq3_1 * ne2 * ne1 + iq2_1 + iq1_1 * ne1) * nb1); + + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow<2>( // + pq_buffer.data(), // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks_buffer.data(), // + dst_buffer.data(), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK, tcm_buffer, tcm_buffer_size); + } else { + ir_step = 1; + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_m1( // + pq, // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks ? ((float *) ((char *) sinks->data)) + h : nullptr, // + (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK); + } + } +} + +void forward_flash_attn_ext_f16_tiled_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + GGML_ASSERT(ne0 == DV); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == DK); + GGML_ASSERT(nek0 == DK); + GGML_ASSERT(nev0 == DV); + + GGML_ASSERT(neq1 == N); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(k->type == v->type); + const ggml_type kv_type = k->type; + + // broadcast factors + const int64_t rk2 = neq2 / nek2; + const int64_t rk3 = neq3 / nek3; + + const int64_t rv2 = neq2 / nev2; + const int64_t rv3 = neq3 / nev3; + + float * param_list = (float *) dst->op_params; + float scale = param_list[0]; + float max_bias = param_list[1]; + float logit_softcap = param_list[2]; + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + int ith = params->ith; + + static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; + static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; + + // Per-thread scratch layout: + // Q_f32: Q_TILE_SZ * DK + // KQ: Q_TILE_SZ * KV_TILE_SZ + // mask32: Q_TILE_SZ * KV_TILE_SZ + // VKQ32: Q_TILE_SZ * DV + // V32: KV_TILE_SZ * DV + // K_f32: DK * KV_TILE_SZ (transposed K tile) + float * base = (float *) params->wdata + ith * (Q_TILE_SZ * DK + 2 * Q_TILE_SZ * KV_TILE_SZ + Q_TILE_SZ * DV + + KV_TILE_SZ * DV + KV_TILE_SZ * DK + CACHE_LINE_SIZE_F32); + const size_t base_size = + (Q_TILE_SZ * DK + 2 * Q_TILE_SZ * KV_TILE_SZ + Q_TILE_SZ * DV + KV_TILE_SZ * DV + KV_TILE_SZ * DK) * + sizeof(float) + + CACHE_LINE_SIZE_F32; + + if (base_size <= tcm_buffer_size && tcm_buffer != nullptr) { + base = (float *) tcm_buffer; + } + + float S_M_Buf[Q_TILE_SZ * 2]; // buffer to hold S, M, bias for one tile to reduce register pressure in main loop + float * S = S_M_Buf; + float * M = S_M_Buf + Q_TILE_SZ; + + int ir = ir0; + while (ir < ir1) { + // q indices for the start of this tile + const int iq3 = ir / (neq2 * neq1); + const int iq2 = (ir - iq3 * neq2 * neq1) / neq1; + const int iq1 = (ir - iq3 * neq2 * neq1 - iq2 * neq1); + + // Number of valid rows in this tile: + // - limited by tile size (Q_TILE_SZ) + // - limited by chunk boundary (ir1 - ir) + // - limited by head boundary (neq1 - iq1) to avoid crossing into next head + const int tile_rows = MIN(Q_TILE_SZ, MIN((int) (ir1 - ir), (int) (neq1 - iq1))); + GGML_ASSERT(tile_rows > 0); + + const uint32_t h = iq2; // head index + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + for (int i = 0; i < Q_TILE_SZ; ++i) { + S[i] = 0.; + M[i] = -INFINITY; + } + + float * Q_f32 = base; + float * KQ = (float *) ((char *) base + Q_TILE_SZ * DK * sizeof(float)); + float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ; + float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ; + float * V32 = VKQ32 + Q_TILE_SZ * DV; + float * K_f32 = V32 + KV_TILE_SZ * DV; + _Float16 * Q_f16 = (_Float16 *) Q_f32; + _Float16 * V_f16 = (_Float16 *) V32; + _Float16 * K_f16 = (_Float16 *) K_f32; + + rvv_zero_f32(VKQ32, Q_TILE_SZ * DV); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const float * pq = (const float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + if (kv_type == GGML_TYPE_F16) { + rvv_pack_f32_as_scaled_f16((uint8_t *) Q_f16, DK * sizeof(_Float16), (uint8_t *) pq, nbq1, tile_rows, DK, + scale); + } else { + memcpy2d(Q_f32, DK * sizeof(float), pq, nbq1, tile_rows, DK * sizeof(float)); + } + + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { + const int kv_tile = (int) std::min((int64_t) KV_TILE_SZ, nek1 - ic); + + rvv_zero_f32(K_f32, DK * KV_TILE_SZ); + rvv_zero_f32(V32, KV_TILE_SZ * DV); + + // skip the tile entirely if all the masks are -inf + if (mask) { + bool can_skip = true; + const ggml_fp16_t * mp_row = + (const ggml_fp16_t *) ((const char *) mask->data + iq1 * mask->nb[1] + + (iq2 % mask->ne[2]) * mask->nb[2] + (iq3 % mask->ne[3]) * mask->nb[3]); + rvv_pack_scaled_f16_as_f32(mask32, KV_TILE_SZ * sizeof(float), mp_row + ic, mask->nb[1], tile_rows, + kv_tile, slope); + + for (int tq = 0; tq < tile_rows; tq++) { + for (int tk = 0; tk < kv_tile; tk++) { + if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) { + can_skip = false; + } + } + // Pad remaining mask entries with -inf + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + mask32[tq * KV_TILE_SZ + tk] = -INFINITY; + } + } + + if (can_skip) { + continue; + } + } + + if (kv_type == GGML_TYPE_F16) { + rvv_transposed_s16_mn_to_nm((int8_t *) K_f16, KV_TILE_SZ * sizeof(_Float16), + (int8_t *) k->data + ic * nbk1 + ik2 * nbk2 + ik3 * nbk3, nbk1, kv_tile, + DK); + + int tq = 0; + for (; tq + 3 < tile_rows; tq += 4) { + rvv_qk_dot_tile_f16_x4(KQ + (tq + 0) * KV_TILE_SZ, KQ + (tq + 1) * KV_TILE_SZ, + KQ + (tq + 2) * KV_TILE_SZ, KQ + (tq + 3) * KV_TILE_SZ, + Q_f16 + (tq + 0) * DK, Q_f16 + (tq + 1) * DK, Q_f16 + (tq + 2) * DK, + Q_f16 + (tq + 3) * DK, K_f16, DK, kv_tile); + } + for (; tq < tile_rows; ++tq) { + rvv_qk_dot_tile_f16_x1(KQ + tq * KV_TILE_SZ, Q_f16 + tq * DK, K_f16, DK, kv_tile); + } + } else { + for (int tk = 0; tk < kv_tile; tk++) { + const char * k_data = (const char *) k->data + (ic + tk) * nbk1 + ik2 * nbk2 + ik3 * nbk3; + float * k_col = K_f32 + tk; + const float * k_src = (const float *) k_data; + for (int64_t dk = 0; dk < DK; ++dk) { + k_col[dk * KV_TILE_SZ] = k_src[dk]; + } + } + + for (int tq = 0; tq < tile_rows; ++tq) { + rvv_qk_dot_tile(KQ + tq * KV_TILE_SZ, Q_f32 + tq * DK, K_f32, DK, KV_TILE_SZ, scale); + } + } + + // Set padded KQ entries to -inf so softmax gives them zero weight + if (kv_tile < KV_TILE_SZ) { + for (int tq = 0; tq < tile_rows; tq++) { + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + KQ[tq * KV_TILE_SZ + tk] = -INFINITY; + } + } + } + + if (logit_softcap != 0.0f) { + rvv_softcap_tanh_inplace_f32(KQ, KV_TILE_SZ, tile_rows, KV_TILE_SZ, logit_softcap); + } + + if (mask) { + rvv_add_inplace_f32(KQ, KV_TILE_SZ, mask32, KV_TILE_SZ, tile_rows, KV_TILE_SZ); + } + + bool skip[Q_TILE_SZ] = {}; + + for (int tq = 0; tq < tile_rows; tq++) { + float * kq_row = KQ + tq * KV_TILE_SZ; + + const float tile_max = rvv_max_f32(kq_row, KV_TILE_SZ); + + if (tile_max == -INFINITY) { + skip[tq] = true; + continue; + } + + const float Mold = M[tq]; + const float Mnew = fmaxf(Mold, tile_max); + + if (Mnew > Mold) { + const float ms = expf(Mold - Mnew); + rvv_scale_f32(VKQ32 + tq * DV, ms, DV); + S[tq] *= ms; + } + M[tq] = Mnew; + + S[tq] += rvv_softmax_exp_inplace_f32(kq_row, KV_TILE_SZ, Mnew); + } + + // Pack V as contiguous [KV_TILE_SZ][DV]. + if (kv_type == GGML_TYPE_F16) { + const char * v_data = (const char *) v->data + ic * nbv1 + iv2 * nbv2 + iv3 * nbv3; + memcpy2d(V_f16, DV * sizeof(_Float16), v_data, nbv1, kv_tile, DV * sizeof(_Float16)); + + int tq = 0; + for (; tq + 3 < tile_rows; tq += 4) { + if (skip[tq + 0] || skip[tq + 1] || skip[tq + 2] || skip[tq + 3]) { + for (int i = 0; i < 4; ++i) { + if (!skip[tq + i]) { + rvv_pv_accumulate_f16_x1(VKQ32 + (tq + i) * DV, KQ + (tq + i) * KV_TILE_SZ, V_f16, + KV_TILE_SZ, DV); + } + } + continue; + } + + rvv_pv_accumulate_f16_x4(VKQ32 + (tq + 0) * DV, VKQ32 + (tq + 1) * DV, VKQ32 + (tq + 2) * DV, + VKQ32 + (tq + 3) * DV, KQ + (tq + 0) * KV_TILE_SZ, + KQ + (tq + 1) * KV_TILE_SZ, KQ + (tq + 2) * KV_TILE_SZ, + KQ + (tq + 3) * KV_TILE_SZ, V_f16, KV_TILE_SZ, DV); + } + for (; tq < tile_rows; ++tq) { + if (!skip[tq]) { + rvv_pv_accumulate_f16_x1(VKQ32 + tq * DV, KQ + tq * KV_TILE_SZ, V_f16, KV_TILE_SZ, DV); + } + } + } else { + const char * v_data = (const char *) v->data + ic * nbv1 + iv2 * nbv2 + iv3 * nbv3; + memcpy2d(V32, DV * sizeof(float), v_data, nbv1, kv_tile, DV * sizeof(float)); + + for (int tq = 0; tq < tile_rows; ++tq) { + if (!skip[tq]) { + rvv_pv_accumulate(VKQ32 + tq * DV, KQ + tq * KV_TILE_SZ, V32, KV_TILE_SZ, DV); + } + } + } + } + + // sinks (apply only to valid rows in the tile) + if (sinks) { + const float s = ((float *) ((char *) sinks->data))[h]; + + for (int tq = 0; tq < tile_rows; tq++) { + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[tq]) { + ms = expf(M[tq] - s); + rvv_scale_f32(VKQ32 + tq * DV, ms, DV); + } else { + vs = expf(s - M[tq]); + } + + float S_temp = S[tq] * ms + vs; + S[tq] = S_temp == 0.0f ? 0.0f : 1.0f / S_temp; + } + } else { + for (int tq = 0; tq < tile_rows; tq++) { + const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq]; + S[tq] = S_inv; + } + } + + float * dst_ptr = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + (iq1) *ne1) * nb1); + rvv_pack_scaled_f32_as_f32(dst_ptr, nb1 * ne1, VKQ32, DV * sizeof(float), tile_rows, DV, S); + + ir += tile_rows; + } +} + +void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon = *((float *) dst->op_params); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (char *) src0->data; + auto * output = (char *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + int64_t i03 = task_idx / (ne02 * ne01); + int64_t i02 = (task_idx - i03 * ne02 * ne01) / ne01; + int64_t i01 = (task_idx - i03 * ne02 * ne01 - i02 * ne01); + + auto * p_input = (float *) (input + i01 * nb01 + i02 * nb02 + i03 * nb03); + auto * p_output = (float *) (output + i01 * nb1 + i02 * nb2 + i03 * nb3); + auto * p_temp_output = p_output; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + + mean_square = sqrt(mean_square + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } +} + +template +void quantize_a_nrow_i8_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS); + + for (size_t row = 0; row < MB_ROWS; row++) { + float max_abs_a = 0.0f; + for (size_t bk = 0; bk < blk_len; bk++) { + max_abs_a = std::max(max_abs_a, std::abs(a_ptr[row * count_k + k + bk])); + } + + float rep_scale_a = ((1 << 7) - 1) / max_abs_a; + scale_a_ptr[row] = 1 / rep_scale_a; + + int16_t a_sum = 0; + for (size_t bk = 0; bk < blk_len; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk] * rep_scale_a), -128.0f, 127.0f)); + quant_a_blk[row * blk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row] = -a_sum; + } + } +} + +template +void quantize_a_nrow_i8_hp_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + const size_t subblk_count = blk_len / k_subblk_len; + + GGML_ASSERT(blk_len == 256); + + float scale_temp[8] = { 0.0f }; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false) * MB_ROWS; + + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + + float scale_avg = 0.0f; + for (size_t kk = 0; kk < subblk_count; kk++) { + float max_abs_a = 0.0f; + for (size_t row = 0; row < MB_ROWS; row++) { + for (size_t bk = 0; bk < k_subblk_len; bk++) { + max_abs_a = std::max(max_abs_a, std::abs(a_ptr[row * count_k + k + bk + kk * k_subblk_len])); + } + } + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + float scale_factor = 1.0f / scale_avg; + + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * MB_ROWS); + scale_avg_ptr[0] = scale_avg; + + for (size_t kk = 0; kk < subblk_count; kk++) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * MB_ROWS); + + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + const float rep_scale_a = 1.0f / scale_temp[kk]; + + for (size_t row = 0; row < MB_ROWS; row++) { + int16_t a_sum = 0; + for (size_t bk = 0; bk < k_subblk_len; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk + kk * k_subblk_len] * rep_scale_a), + -128.0f, 127.0f)); + quant_a_blk[row * k_subblk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row * subblk_count + kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + } + } + } +} + +template +void quantize_a_nrow_i8k_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t a_sum_size = 256 / 16; + + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * a_sum_size * MB_ROWS); + + for (size_t row = 0; row < MB_ROWS; row++) { + float max_a = 0.0f; + float max_abs_a = 0.0f; + for (size_t bk = 0; bk < blk_len; bk++) { + float ax = std::abs(a_ptr[row * count_k + k + bk]); + if (ax > max_abs_a) { + max_abs_a = ax; + max_a = a_ptr[row * count_k + k + bk]; + } + } + + if (!max_abs_a) { + scale_a_ptr[row] = 0; + for (size_t bki = 0; bki < a_sum_size; bki++) { + for (size_t bk = bki * 16; bk < (bki + 1) * 16; bk++) { + quant_a_blk[row * blk_len + bk] = 0; + } + a_sum_ptr[row * a_sum_size + bki] = 0; + } + continue; + } + + float rep_scale_a = ((1 << 7) - 1) / max_abs_a; + scale_a_ptr[row] = 1 / rep_scale_a; + + for (size_t bki = 0; bki < a_sum_size; bki++) { + int16_t a_sum = 0; + for (size_t bk = bki * 16; bk < (bki + 1) * 16; bk++) { + const int8_t quantized = static_cast( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk] * rep_scale_a), -128.0f, 127.0f)); + quant_a_blk[row * blk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row * a_sum_size + bki] = -a_sum; + } + } + } +} + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 32); + int64_t a_blk_stride = q8_blk_size(blk_len, true); + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t)); + + size_t vl = __riscv_vsetvl_e32m1(blk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[0] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk, v_a_quant_i8, vl); + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t)); + + size_t vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_ptr + k, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[0] = -a_sum; + + __riscv_vse8_v_i8m1(quant_a_blk, v_a_quant_i8, vl); + } + } else { + quantize_a_nrow_i8_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 32); + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * 4; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * 4); + + for (size_t mi = 0; mi < 4; mi++) { + size_t vl = __riscv_vsetvl_e32m1(blk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + mi * blk_len, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * 4); + + for (size_t mi = 0; mi < 4; mi++) { + size_t vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_ptr + mi * count_k + k, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi] = -a_sum; + + __riscv_vse8_v_i8m1(quant_a_blk + mi * blk_len, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + GGML_ASSERT(blk_len == 256); + + constexpr size_t subblk_count = 256 / k_subblk_len; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false); + size_t vlenb = __riscv_vlenb(); + float scale_temp[subblk_count] = { 0.0f }; + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_blk_stride - sizeof(_Float16)); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_src_ptr, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16)); + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_src_ptr, vl); + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8mf4(quant_a_blk, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_blk_stride - sizeof(_Float16)); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_src_ptr, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16)); + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_src_ptr, vl); + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8m1(quant_a_blk, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_hp_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + GGML_ASSERT(blk_len == 256); + + constexpr size_t subblk_count = 256 / k_subblk_len; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_nrow_block_stride = a_blk_stride * 4; + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false) * 4; + size_t vlenb = __riscv_vlenb(); + float scale_temp[subblk_count] = { 0.0f }; + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * 4); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a0 = __riscv_vle32_v_f32m1(a_src_ptr0, vl); + vfloat32m1_t v_a1 = __riscv_vle32_v_f32m1(a_src_ptr1, vl); + vfloat32m1_t v_a2 = __riscv_vle32_v_f32m1(a_src_ptr2, vl); + vfloat32m1_t v_a3 = __riscv_vle32_v_f32m1(a_src_ptr3, vl); + vfloat32m1_t v_a0_abs = __riscv_vfabs_v_f32m1(v_a0, vl); + vfloat32m1_t v_a1_abs = __riscv_vfabs_v_f32m1(v_a1, vl); + vfloat32m1_t v_a2_abs = __riscv_vfabs_v_f32m1(v_a2, vl); + vfloat32m1_t v_a3_abs = __riscv_vfabs_v_f32m1(v_a3, vl); + + vfloat32m1_t v_max_abs = __riscv_vfmax_vv_f32m1(v_a0_abs, v_a1_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_max_abs, v_a2_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_max_abs, v_a3_abs, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * 4); + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a0 = __riscv_vle32_v_f32m1(a_src_ptr0, vl); + vfloat32m1_t v_a1 = __riscv_vle32_v_f32m1(a_src_ptr1, vl); + vfloat32m1_t v_a2 = __riscv_vle32_v_f32m1(a_src_ptr2, vl); + vfloat32m1_t v_a3 = __riscv_vle32_v_f32m1(a_src_ptr3, vl); + + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m1_t v_a0_scale = __riscv_vfmul_vf_f32m1(v_a0, rep_scale_a, vl); + vfloat32m1_t v_a1_scale = __riscv_vfmul_vf_f32m1(v_a1, rep_scale_a, vl); + vfloat32m1_t v_a2_scale = __riscv_vfmul_vf_f32m1(v_a2, rep_scale_a, vl); + vfloat32m1_t v_a3_scale = __riscv_vfmul_vf_f32m1(v_a3, rep_scale_a, vl); + vint16mf2_t v_a0_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a0_scale, vl); + vint16mf2_t v_a1_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a1_scale, vl); + vint16mf2_t v_a2_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a2_scale, vl); + vint16mf2_t v_a3_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a3_scale, vl); + vint8mf4_t v_a0_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a0_quant, vl); + vint8mf4_t v_a1_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a1_quant, vl); + vint8mf4_t v_a2_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a2_quant, vl); + vint8mf4_t v_a3_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a3_quant, vl); + + vint16m1_t tmp_sum0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum3 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a0_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a0_quant_i8, tmp_sum0, vl); + vint16m1_t v_a1_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a1_quant_i8, tmp_sum1, vl); + vint16m1_t v_a2_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a2_quant_i8, tmp_sum2, vl); + vint16m1_t v_a3_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a3_quant_i8, tmp_sum3, vl); + + a_sum_ptr[0 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a0_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[1 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a1_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[2 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a2_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[3 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a3_sum)) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8mf4(quant_a_blk + 0 * k_subblk_len, v_a0_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 1 * k_subblk_len, v_a1_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 2 * k_subblk_len, v_a2_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 3 * k_subblk_len, v_a3_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * 4); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a0 = __riscv_vle32_v_f32m4(a_src_ptr0, vl); + vfloat32m4_t v_a1 = __riscv_vle32_v_f32m4(a_src_ptr1, vl); + vfloat32m4_t v_a2 = __riscv_vle32_v_f32m4(a_src_ptr2, vl); + vfloat32m4_t v_a3 = __riscv_vle32_v_f32m4(a_src_ptr3, vl); + + vfloat32m4_t v_a0_abs = __riscv_vfabs_v_f32m4(v_a0, vl); + vfloat32m4_t v_a1_abs = __riscv_vfabs_v_f32m4(v_a1, vl); + vfloat32m4_t v_a2_abs = __riscv_vfabs_v_f32m4(v_a2, vl); + vfloat32m4_t v_a3_abs = __riscv_vfabs_v_f32m4(v_a3, vl); + + vfloat32m4_t v_max_abs = __riscv_vfmax_vv_f32m4(v_a0_abs, v_a1_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m4(v_max_abs, v_a2_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m4(v_max_abs, v_a3_abs, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast(a_subblk_base + sizeof(_Float16) * 4); + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a0 = __riscv_vle32_v_f32m4(a_src_ptr0, vl); + vfloat32m4_t v_a1 = __riscv_vle32_v_f32m4(a_src_ptr1, vl); + vfloat32m4_t v_a2 = __riscv_vle32_v_f32m4(a_src_ptr2, vl); + vfloat32m4_t v_a3 = __riscv_vle32_v_f32m4(a_src_ptr3, vl); + + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m4_t v_a0_scale = __riscv_vfmul_vf_f32m4(v_a0, rep_scale_a, vl); + vfloat32m4_t v_a1_scale = __riscv_vfmul_vf_f32m4(v_a1, rep_scale_a, vl); + vfloat32m4_t v_a2_scale = __riscv_vfmul_vf_f32m4(v_a2, rep_scale_a, vl); + vfloat32m4_t v_a3_scale = __riscv_vfmul_vf_f32m4(v_a3, rep_scale_a, vl); + vint16m2_t v_a0_quant = __riscv_vfncvt_x_f_w_i16m2(v_a0_scale, vl); + vint16m2_t v_a1_quant = __riscv_vfncvt_x_f_w_i16m2(v_a1_scale, vl); + vint16m2_t v_a2_quant = __riscv_vfncvt_x_f_w_i16m2(v_a2_scale, vl); + vint16m2_t v_a3_quant = __riscv_vfncvt_x_f_w_i16m2(v_a3_scale, vl); + vint8m1_t v_a0_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a0_quant, vl); + vint8m1_t v_a1_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a1_quant, vl); + vint8m1_t v_a2_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a2_quant, vl); + vint8m1_t v_a3_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a3_quant, vl); + + vint16m1_t tmp_sum0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum3 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a0_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a0_quant_i8, tmp_sum0, vl); + vint16m1_t v_a1_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a1_quant_i8, tmp_sum1, vl); + vint16m1_t v_a2_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a2_quant_i8, tmp_sum2, vl); + vint16m1_t v_a3_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a3_quant_i8, tmp_sum3, vl); + + a_sum_ptr[0 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a0_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[1 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a1_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[2 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a2_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[3 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a3_sum)) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8m1(quant_a_blk + 0 * k_subblk_len, v_a0_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 1 * k_subblk_len, v_a1_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 2 * k_subblk_len, v_a2_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 3 * k_subblk_len, v_a3_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_hp_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 256); + constexpr int64_t a_blk_stride = q8k_blk_size(256); + constexpr int64_t a_sum_size = 256 / 16; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + // vlen = 1024 bits, can process 32 float32 elements with m1 + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t) * a_sum_size); + + // Find max absolute value across all 256 elements + size_t vl = __riscv_vsetvl_e32m1(16); + vfloat32m1_t v_max_abs = __riscv_vfmv_v_f_f32m1(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k + bki * 16, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k + bki * 16, vl); + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[bki] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + bki * 16, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + // vlen = 256 bits, can process 8 float32 elements with m1 + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) + sizeof(int16_t) * a_sum_size); + + // Find max absolute value across all 256 elements + size_t vl = __riscv_vsetvl_e32m2(16); + vfloat32m2_t v_max_abs = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + k + bki * 16, vl); + vfloat32m2_t v_a_abs = __riscv_vfabs_v_f32m2(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m2(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m2_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + k + bki * 16, vl); + vfloat32m2_t v_a_scale = __riscv_vfmul_vf_f32m2(v_a, rep_scale_a, vl); + vint16m1_t v_a_quant = __riscv_vfncvt_x_f_w_i16m1(v_a_scale, vl); + vint8mf2_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf2(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf2_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[bki] = -a_sum; + + __riscv_vse8_v_i8mf2(quant_a_blk + bki * 16, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8k_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 256); + constexpr int64_t a_blk_stride = q8k_blk_size(256); + constexpr int64_t a_nrow_block_stride = a_blk_stride * 4; + constexpr int64_t a_sum_size = 256 / 16; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + // vlen = 1024 bits + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * a_sum_size * 4); + + for (size_t mi = 0; mi < 4; mi++) { + // Find max absolute value across all 256 elements for this row + size_t vl = __riscv_vsetvl_e32m1(16); + vfloat32m1_t v_max_abs = __riscv_vfmv_v_f_f32m1(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi * a_sum_size + bki] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + mi * blk_len + bki * 16, v_a_quant_i8, vl); + } + } + } + } else if (vlenb == 32) { + // vlen = 256 bits + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = + reinterpret_cast(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * a_sum_size * 4); + + for (size_t mi = 0; mi < 4; mi++) { + // Find max absolute value across all 256 elements for this row + size_t vl = __riscv_vsetvl_e32m2(16); + vfloat32m2_t v_max_abs = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m2_t v_a_abs = __riscv_vfabs_v_f32m2(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m2(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m2_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m2_t v_a_scale = __riscv_vfmul_vf_f32m2(v_a, rep_scale_a, vl); + vint16m1_t v_a_quant = __riscv_vfncvt_x_f_w_i16m1(v_a_scale, vl); + vint8mf2_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf2(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf2_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi * a_sum_size + bki] = -a_sum; + + __riscv_vse8_v_i8mf2(quant_a_blk + mi * blk_len + bki * 16, v_a_quant_i8, vl); + } + } + } + } else { + quantize_a_nrow_i8k_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void forward_cpy_with_permute(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + const int ith = params->ith; + const int nth = params->nth; + + // [batch, m, n] -> [batch, n, m] + int64_t batch = src0->ne[2] * src0->ne[3]; + int64_t m = src0->ne[1]; + int64_t n = src0->ne[0]; + + int64_t batch_stride = src0->nb[2]; + int64_t m_src_stride = src0->nb[0]; + int64_t n_src_stride = src0->nb[1]; + int64_t n_dst_stride = n_src_stride * m; + + permute_transpose_impl(src0, dst, batch, m, n, batch_stride, m_src_stride, n_src_stride, n_dst_stride, ith, nth); +} + +void forward_cont_with_permute(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + const int ith = params->ith; + const int nth = params->nth; + + // [batch, m, n] -> [batch, n, m] + int64_t batch = dst->ne[2] * dst->ne[3]; + int64_t n = dst->ne[1]; + int64_t m = dst->ne[0]; + + int64_t batch_stride = dst->nb[2]; + int64_t m_src_stride = src0->nb[0]; + int64_t n_src_stride = src0->nb[1]; + int64_t n_dst_stride = dst->nb[1]; + + permute_transpose_impl(src0, dst, batch, m, n, batch_stride, m_src_stride, n_src_stride, n_dst_stride, ith, nth); +} + +void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon = *((float *) dst->op_params); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (char *) src0->data; + auto * output = (char *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + int64_t i03 = task_idx / (ne02 * ne01); + int64_t i02 = (task_idx - i03 * ne02 * ne01) / ne01; + int64_t i01 = (task_idx - i03 * ne02 * ne01 - i02 * ne01); + + auto * p_input = (float *) (input + i01 * nb01 + i02 * nb02 + i03 * nb03); + auto * p_output = (float *) (output + i01 * nb1 + i02 * nb2 + i03 * nb3); + auto * p_temp_output = p_output; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); + mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); + mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); + mean /= hidden_size; + + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + mean_square = sqrt(mean_square - mean * mean + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } +} + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + auto src0_rows = ggml_nrows(src0); + auto src1_rows = ggml_nrows(src1); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(T)); + GGML_ASSERT(nb00 == sizeof(T)); + + const auto [ir0, ir1] = get_thread_range(params, src0); + + auto compute_func_vv = [&](int64_t blk_len, int64_t r, T * src0_ptr, T * src1_ptr, T * dst_ptr) { + int64_t idx = 0; + if constexpr (op_type == GGML_OP_ADD) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfadd_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfadd_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_SUB) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfsub_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfsub_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_MUL) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfmul_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfmul_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_DIV) { + if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfdiv_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfdiv_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ABORT("fatal error"); + } + }; + + if (src0_rows == src1_rows && src0_rows == 1 && ne00 == ne10) { + int64_t task_per_thread = (ne00 + nth - 1) / nth; + int64_t task_begin = ith * task_per_thread; + int64_t task_end = std::min((ith + 1) * task_per_thread, ne00); + + T * dst_ptr = ((T *) dst->data) + task_begin; + T * src0_ptr = ((T *) src0->data) + task_begin; + T * src1_ptr = ((T *) src1->data) + task_begin; + + compute_func_vv(task_end - task_begin, 0, src0_ptr, src1_ptr, dst_ptr); + } else if (ne10 > 1) { + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02 * ne01); + const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; + const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + T * dst_ptr = (T *) ((char *) dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); + T * src0_ptr = (T *) ((char *) src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); + T * src1_ptr = (T *) ((char *) src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); + + // src1 is broadcastable across src0 and dst in i1, i2, i3 + for (int64_t r = 0; r < ne00; r += ne10) { + compute_func_vv(ne10, r, src0_ptr, src1_ptr, dst_ptr); + } + } + } else { + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02 * ne01); + const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; + const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + T * dst_ptr = (T *) ((char *) dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); + T * src0_ptr = (T *) ((char *) src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); + T * src1_ptr = (T *) ((char *) src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); + + T rhs_scalar = src1_ptr[0]; + int64_t blk_len = ne00; + int64_t r = 0; + + for (size_t vl; blk_len > 0; blk_len -= vl, r += vl) { + if constexpr (op_type == GGML_OP_ADD) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfadd_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfadd_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_SUB) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfsub_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfsub_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_MUL) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfmul_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfmul_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_DIV) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfdiv_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfdiv_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ABORT("fatal error"); + } + } + } + } +} + +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == 1); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + int64_t n_task = ne01 * ne02 * ne03; + int64_t task_per_thread = (n_task + nth - 1) / nth; + int64_t ir_start = ith * task_per_thread; + int64_t ir_end = std::min(ir_start + task_per_thread, n_task); + + for (int64_t ir = ir_start; ir < ir_end; ir++) { + const int64_t i3 = ir / (ne02 * ne01); + const int64_t i2 = (ir - i3 * ne02 * ne01) / ne01; + const int64_t i1 = (ir - i3 * ne02 * ne01 - i2 * ne01); + + T * src_row = (T *) ((char *) src0->data + i1 * nb01 + i2 * nb02 + i3 * nb03); + T * dst_row = (T *) ((char *) op->data + i1 * nb1 + i2 * nb2 + i3 * nb3); + + float row_sum = 0; + + if constexpr (std::is_same_v) { + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t acc_vec = __riscv_vfmv_v_f_f32m4(0.0f, gvl); + int64_t length = ne00; + const float * p_data = src_row; + + while (length > 0) { + size_t vl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t vec = __riscv_vle32_v_f32m4(p_data, vl); + acc_vec = __riscv_vfadd_vv_f32m4(acc_vec, vec, vl); + p_data += vl; + length -= vl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.0f, gvl); + vfloat32m1_t sum_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(acc_vec, 0), + __riscv_vget_v_f32m4_f32m1(acc_vec, 1), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 2), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 3), gvl); + sum_v = __riscv_vfredusum_vs_f32m1_f32m1(sum_v, zero_v, gvl); + row_sum = __riscv_vfmv_f_s_f32m1_f32(sum_v); + } else if constexpr (std::is_same_v) { + size_t gvl = __riscv_vsetvlmax_e16m2(); + vfloat32m4_t acc_vec = __riscv_vfmv_v_f_f32m4(0.0f, gvl); + int64_t length = ne00; + const _Float16 * p_data = src_row; + + while (length > 0) { + size_t vl = __riscv_vsetvl_e16m2(length); + vfloat16m2_t vec_f16 = __riscv_vle16_v_f16m2(p_data, vl); + vfloat32m4_t vec_f32 = __riscv_vfwcvt_f_f_v_f32m4(vec_f16, vl); + acc_vec = __riscv_vfadd_vv_f32m4(acc_vec, vec_f32, vl); + p_data += vl; + length -= vl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.0f, gvl); + vfloat32m1_t sum_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(acc_vec, 0), + __riscv_vget_v_f32m4_f32m1(acc_vec, 1), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 2), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 3), gvl); + sum_v = __riscv_vfredusum_vs_f32m1_f32m1(sum_v, zero_v, gvl); + row_sum = __riscv_vfmv_f_s_f32m1_f32(sum_v); + } else { + GGML_ABORT("fatal error"); + } + + dst_row[0] = row_sum; + } +} + +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + int64_t nrows = ggml_nrows(src0); + int64_t nrows_per_thread = (nrows + nth - 1) / nth; + int64_t ir_start = ith * nrows_per_thread; + int64_t ir_end = std::min(ir_start + nrows_per_thread, nrows); + + if (src0->ne[0] == 1) { + for (int64_t ir = ir_start; ir < ir_end; ir++) { + T * src_row = (T *) ((char *) src0->data + ir * src0->nb[1]); + T * dst_row = (T *) ((char *) dst->data + ir * dst->nb[1]); + + T src_scalar = src_row[0]; + + int64_t length = dst->ne[0]; + int64_t idx = 0; + size_t vl = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vmv_v_x_i32m4(src_scalar, vl); + __riscv_vse32_v_i32m4(dst_row + idx, vec, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vmv_v_x_i16m4(src_scalar, vl); + __riscv_vse16_v_i16m4((dst_row + idx), vec, vl); + } else { + GGML_ABORT("fatal error"); + } + idx += vl; + length -= vl; + } + } + } else if (src0->ne[0] == dst->ne[0]) { + for (int64_t ir = ir_start; ir < ir_end; ir++) { + T * src_row = (T *) ((char *) src0->data + ir * src0->nb[1]); + T * dst_row = (T *) ((char *) dst->data + ir * dst->nb[1]); + + int64_t length = dst->ne[0]; + int64_t idx = 0; + size_t vl = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vle32_v_i32m4(src_row + idx, vl); + __riscv_vse32_v_i32m4(dst_row + idx, vec, vl); + } else if constexpr (std::is_same_v) { + vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vle16_v_i16m4((src_row + idx), vl); + __riscv_vse16_v_i16m4((dst_row + idx), vec, vl); + } else { + GGML_ABORT("fatal error"); + } + idx += vl; + length -= vl; + } + } + } else { + GGML_ABORT("fatal error"); + } +} + +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const int64_t total_batches = ne2 * ne3; + const int64_t batches_per_thread = (total_batches + nth - 1) / nth; + const int64_t batch_start = ith * batches_per_thread; + const int64_t batch_end = std::min(batch_start + batches_per_thread, total_batches); + + for (int64_t b = batch_start; b < batch_end; b++) { + const int64_t i3 = b / ne2; + const int64_t i2 = b % ne2; + + T * src_base = (T *) ((char *) src0->data + i2 * src0->nb[2] + i3 * src0->nb[3]); + T * dst_batch = (T *) ((char *) dst->data + i2 * dst->nb[2] + i3 * dst->nb[3]); + + for (int64_t i1 = 0; i1 < ne1; i1++) { + T * dst_ptr = (T *) ((char *) dst_batch + i1 * dst->nb[1]); + int64_t length = ne0; + int64_t idx = 0; + + while (length > 0) { + if constexpr (std::is_same_v) { + size_t vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vle32_v_i32m4(src_base + idx, vl); + __riscv_vse32_v_i32m4(dst_ptr + idx, vec, vl); + idx += vl; + length -= vl; + } else if constexpr (std::is_same_v) { + size_t vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vle16_v_i16m4((src_base + idx), vl); + __riscv_vse16_v_i16m4((dst_ptr + idx), vec, vl); + idx += vl; + length -= vl; + } else { + GGML_ABORT("fatal error"); + } + } + } + } +} + +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(float)); + assert(ggml_nrows(op) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + int rows_nth = nth; + int cols_nth = 1; + + if (nr == 1) { + rows_nth = 1; + cols_nth = nth; + } + + // rows per thread + const int dr = (nr + rows_nth - 1) / rows_nth; + const int dc = (nc + cols_nth - 1) / cols_nth; + + int rows_ith = ith % rows_nth; + int cols_ith = ith % cols_nth; + + // row range for this thread + const int ir0 = dr * rows_ith; + const int ir1 = MIN(ir0 + dr, nr); + + const int cr0 = dc * cols_ith; + const int cr1 = MIN(cr0 + dc, nc); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i / (ne11 * ne10); + const int64_t i11 = (i - i12 * ne11 * ne10) / ne10; + const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); + + GGML_ASSERT(i01 >= 0 && i01 < ne01); + + memcpy1d(((char *) dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3) + cr0 * sizeof(T), + ((char *) src0->data + i01 * nb01 + i11 * nb02 + i12 * nb03) + cr0 * sizeof(T), + (cr1 - cr0) * sizeof(T)); + } +} + +template void forward_concat(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float)); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim == 0 && nb0 == sizeof(float) && nb1 == sizeof(float) * (ne00 + ne10)); + + const int64_t nr = ggml_nrows(dst); + const int64_t nc = ne0; + + const int ith = params->ith; + const int nth = params->nth; + + int rows_nth = nth; + int cols_nth = 1; + + if (nr == 1) { + rows_nth = 1; + cols_nth = nth; + } + + const int dr = (nr + rows_nth - 1) / rows_nth; + const int dc = (nc + cols_nth - 1) / cols_nth; + + int rows_ith = ith % rows_nth; + int cols_ith = ith % cols_nth; + + // row range for this thread + const int ir0 = dr * rows_ith; + const int ir1 = MIN(ir0 + dr, nr); + + const int cr0 = dc * cols_ith; + const int cr1 = MIN(cr0 + dc, nc); + + int64_t o[4] = { 0, 0, 0, 0 }; + o[dim] = src0->ne[dim]; + const float * x; + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i3 = i / (ne02 * ne01); + const int64_t i2 = (i - i3 * ne02 * ne01) / ne01; + const int64_t i1 = (i - i3 * ne02 * ne01 - i2 * ne01); + + for (int i0 = cr0; i0 < cr1; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *) ((const char *) src0->data + (i0) *nb00 + (i1) *nb01 + (i2) *nb02 + (i3) *nb03); + } else { + x = (const float *) ((const char *) src1->data + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13); + } + + float * y = (float *) ((char *) dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); + + *y = *x; + } + } +} + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op); +template void forward_sum_rows<_Float16>(const ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); + +} // namespace spacemit_kernels::rvv diff --git a/ggml/src/ggml-cpu/spacemit/rvv_kernels.h b/ggml/src/ggml-cpu/spacemit/rvv_kernels.h new file mode 100644 index 00000000000..edddf957c21 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/rvv_kernels.h @@ -0,0 +1,95 @@ +#pragma once + +#include "ggml-cpu-impl.h" + +#include +#include +#include +#include + +namespace spacemit_kernels { + +constexpr auto div_round_up(auto up, auto down) { + return (up + down - 1) / down; +} + +// Q8 Blk [f32] [s16] [int8 * blk_len] +// Q8 Blk N [f32 * N] [s16 * N] [int8 * blk_len * N] +constexpr size_t q8_blk_size(size_t blk_len, bool with_blk_sum = false) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t) + (with_blk_sum ? sizeof(int16_t) : 0); + return blk_size; +} + +// Q8 HP row block: K is split into K32 subblocks. +// Each subblock stores [f32 scale] [int8 * 32], with an optional fp16 sum trailer per subblock. +constexpr size_t q8_hp_blk_size(size_t blk_len, bool with_blk_sum = false, bool with_blk_scale = false) { + const size_t subblk_count = div_round_up(blk_len, size_t(32)); + const size_t blk_size = blk_len * sizeof(int8_t) + subblk_count * sizeof(_Float16) + + (with_blk_sum ? subblk_count * sizeof(_Float16) : 0) + + (with_blk_scale ? sizeof(_Float16) : 0); + return blk_size; +} + +// Q8K Blk [f32] [s16 * (blk_len / 16)] [int8 * blk_len] +// Q8K Blk N [f32 * N] [s16 * (blk_len / 16) * N] [int8 * blk_len * N] +constexpr size_t q8k_blk_size(size_t blk_len) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t) + sizeof(int16_t) * blk_len / 16; + return blk_size; +} + +using quantize_a_row_def = std::function; + +namespace rvv { +void memcpy1d(void * dst, const void * src, int64_t size); + +void memcpy2d(void * dst, int64_t dst_stride, const void * src, int64_t src_stride, int64_t tile_rows, int64_t size); + +void forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size); + +void forward_flash_attn_ext_f16_tiled_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size); + +void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op); + +void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op); + +void forward_cont_with_permute(ggml_compute_params * params, ggml_tensor * op); + +void forward_cpy_with_permute(ggml_compute_params * params, ggml_tensor * op); + +template void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); + +template void forward_concat(ggml_compute_params * params, ggml_tensor * op); + +template void forward_binary(ggml_compute_params * params, ggml_tensor * op); + +template void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op); + +template void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); + +template void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +} // namespace rvv + +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/spine_barrier.h b/ggml/src/ggml-cpu/spacemit/spine_barrier.h new file mode 100644 index 00000000000..f897dad4b8a --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_barrier.h @@ -0,0 +1,34 @@ +#pragma once + +#include +#include + +#define SPINE_CACHE_LINE 64 +#define SPINE_CACHE_ALIGN __attribute__((aligned(SPINE_CACHE_LINE))) + +struct spine_barrier_t { + SPINE_CACHE_ALIGN std::atomic pending_; + SPINE_CACHE_ALIGN std::atomic rounds_; + SPINE_CACHE_ALIGN int64_t total_; +}; + +inline void spine_barrier_wait(spine_barrier_t * b) { + auto cur_round = b->rounds_.load(std::memory_order_acquire); + auto cnt = --b->pending_; + if (cnt == 0) { + b->pending_.store(b->total_); + b->rounds_.store(cur_round + 1); + } else { + while (cur_round == b->rounds_.load(std::memory_order_relaxed)) { + __asm__ volatile("pause " ::: "memory"); + } + } +} + +inline void spine_barrier_init(spine_barrier_t * b, int num_barriers, uint64_t thread_count) { + for (int i = 0; i < num_barriers; i++) { + b[i].total_ = thread_count; + b[i].pending_.store(thread_count); + b[i].rounds_.store(0); + } +} diff --git a/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp new file mode 100644 index 00000000000..1409423b145 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp @@ -0,0 +1,760 @@ +#include "spine_mem_pool.h" + +#include "common.h" +#include "ime_env.h" +#include "spine_tcm.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ggml::cpu::riscv64_spacemit { +namespace { + +constexpr size_t SPINE_MEM_POOL_CHUNK_SIZE = 512ull * 1024ull * 1024ull; +constexpr size_t SPINE_SHARE_MEM_POOL_CHUNK_SIZE = 512ull * 1024ull; +constexpr size_t SPINE_MEM_POOL_1G_REGION_SIZE = 1ull << 30; +constexpr uint64_t HUGETLB_1G_FLAG_REQUIRE_PUD = 1ull << 0; +constexpr char SPINE_MEM_POOL_HUGETLB_1G_DEV[] = "/dev/hugetlb_1g"; +constexpr char SPINE_MEM_POOL_TCM_SYNC_MEM_DEV[] = "/dev/tcm_sync_mem"; + +struct hugetlb_1g_region { + uint64_t size{ 0 }; + uint64_t dma_addr{ 0 }; + uint64_t flags{ 0 }; + uint64_t reserved{ 0 }; +}; + +#define HUGETLB_1G_IOC_MAGIC 'M' +#define HUGETLB_1G_IOC_ALLOC _IOWR(HUGETLB_1G_IOC_MAGIC, 0x00, struct hugetlb_1g_region) +#define HUGETLB_1G_IOC_FREE _IO(HUGETLB_1G_IOC_MAGIC, 0x01) + +struct free_block { + size_t offset{ 0 }; + size_t size{ 0 }; +}; + +struct pool_chunk { + uint8_t * base{ nullptr }; + size_t size{ 0 }; + int fd{ -1 }; + std::vector free_blocks; +}; + +struct pool_allocation { + void * chunk_base{ nullptr }; + size_t chunk_size{ 0 }; + void * base{ nullptr }; + size_t size{ 0 }; +}; + +bool is_power_of_two(size_t value) { + return value != 0 && (value & (value - 1)) == 0; +} + +bool align_up(size_t value, size_t alignment, size_t * aligned_value) { + if (aligned_value == nullptr || alignment == 0) { + return false; + } + + const size_t remainder = value % alignment; + if (remainder == 0) { + *aligned_value = value; + return true; + } + + const size_t padding = alignment - remainder; + if (value > std::numeric_limits::max() - padding) { + return false; + } + + *aligned_value = value + padding; + return true; +} + +bool align_up_uintptr(uintptr_t value, size_t alignment, uintptr_t * aligned_value) { + if (aligned_value == nullptr || alignment == 0) { + return false; + } + + const uintptr_t remainder = value % alignment; + if (remainder == 0) { + *aligned_value = value; + return true; + } + + const uintptr_t padding = alignment - remainder; + if (value > std::numeric_limits::max() - padding) { + return false; + } + + *aligned_value = value + padding; + return true; +} + +class spine_mem_pool_manager { + public: + explicit spine_mem_pool_manager(size_t default_chunk_size) : default_chunk_size_(default_chunk_size) {} + + virtual ~spine_mem_pool_manager() = default; + + void * alloc(size_t size, size_t alignment) { + if (size == 0 || !is_power_of_two(alignment)) { + return nullptr; + } + + size_t aligned_size = 0; + if (!align_up(size, alignment, &aligned_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: align_up failed for size %zu alignment %zu\n", __func__, size, + alignment); + return nullptr; + } + + pool_allocation allocation; + + std::lock_guard lock(mutex_); + + if (!try_alloc_locked(aligned_size, alignment, &allocation)) { + if (!add_chunk_locked(aligned_size, alignment)) { + return nullptr; + } + + if (!try_alloc_locked(aligned_size, alignment, &allocation)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation retry failed for size %zu alignment %zu\n", + __func__, aligned_size, alignment); + return nullptr; + } + } + + try { + const auto [allocation_it, inserted] = allocations_.emplace(allocation.base, allocation); + if (!inserted) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: duplicate allocation key %p\n", __func__, allocation.base); + rollback_allocation_locked(allocation); + return nullptr; + } + } catch (const std::bad_alloc &) { + rollback_allocation_locked(allocation); + throw; + } + + return allocation.base; + } + + void free(void * base) { + if (base == nullptr) { + return; + } + + std::lock_guard lock(mutex_); + + auto allocation_it = allocations_.find(base); + if (allocation_it == allocations_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: unknown allocation %p\n", __func__, base); + return; + } + + pool_allocation allocation = allocation_it->second; + allocations_.erase(allocation_it); + + auto chunk_it = find_chunk_locked(allocation); + if (chunk_it == chunks_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: unknown chunk for allocation %p size %zu\n", __func__, + allocation.base, allocation.size); + return; + } + + auto * chunk_base = chunk_it->base; + auto * alloc_base = static_cast(allocation.base); + if (alloc_base < chunk_base || alloc_base >= chunk_base + chunk_it->size) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation %p out of chunk range %p..%p\n", __func__, + allocation.base, chunk_base, chunk_base + chunk_it->size); + return; + } + + const size_t offset = static_cast(alloc_base - chunk_base); + if (offset > chunk_it->size || allocation.size > chunk_it->size - offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation %p size %zu exceeds chunk size %zu\n", __func__, + allocation.base, allocation.size, chunk_it->size); + return; + } + + insert_free_block_locked(*chunk_it, { offset, allocation.size }); + maybe_release_empty_chunk_locked(chunk_it); + } + + protected: + void release_chunks() { + std::lock_guard lock(mutex_); + + allocations_.clear(); + for (auto & chunk : chunks_) { + dealloc_chunk(&chunk); + } + chunks_.clear(); + } + + size_t default_chunk_size() const { return default_chunk_size_; } + + static void clear_chunk(pool_chunk * chunk) { + chunk->base = nullptr; + chunk->size = 0; + chunk->fd = -1; + chunk->free_blocks.clear(); + } + + virtual bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) = 0; + virtual void dealloc_chunk(pool_chunk * chunk) = 0; + + private: + struct alloc_candidate { + size_t chunk_index{ 0 }; + size_t block_index{ 0 }; + size_t aligned_offset{ 0 }; + uintptr_t address{ std::numeric_limits::max() }; + bool valid{ false }; + }; + + std::vector::iterator find_chunk_locked(const pool_allocation & allocation) { + return std::find_if(chunks_.begin(), chunks_.end(), [&](const pool_chunk & chunk) { + return chunk.base == allocation.chunk_base && chunk.size == allocation.chunk_size; + }); + } + + bool add_chunk_locked(size_t min_size, size_t alignment) { + pool_chunk chunk; + const size_t chunk_request = default_chunk_size_ == 0 ? min_size : std::max(min_size, default_chunk_size_); + void * hint_addr = nullptr; + + for (const auto & existing_chunk : chunks_) { + auto * chunk_end = existing_chunk.base + existing_chunk.size; + if (hint_addr == nullptr || chunk_end > hint_addr) { + hint_addr = chunk_end; + } + } + + if (!alloc_chunk(chunk_request, alignment, hint_addr, &chunk)) { + return false; + } + + if (chunk.base == nullptr || chunk.size < min_size) { + GGML_LOG_ERROR( + "CPU_RISCV64_SPACEMIT: %s: invalid chunk returned for request size %zu, chunk_base=%p chunk_size=%zu\n", + __func__, min_size, chunk.base, chunk.size); + dealloc_chunk(&chunk); + return false; + } + + try { + chunk.free_blocks.push_back({ 0, chunk.size }); + chunks_.push_back(std::move(chunk)); + } catch (const std::bad_alloc &) { + dealloc_chunk(&chunk); + throw; + } + + return true; + } + + void rollback_allocation_locked(const pool_allocation & allocation) { + auto chunk_it = find_chunk_locked(allocation); + if (chunk_it == chunks_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p, owning chunk not found\n", + __func__, allocation.base); + return; + } + + auto * chunk_base = chunk_it->base; + auto * alloc_base = static_cast(allocation.base); + if (alloc_base < chunk_base || alloc_base >= chunk_base + chunk_it->size) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p, chunk range is invalid\n", + __func__, allocation.base); + return; + } + + const size_t offset = static_cast(alloc_base - chunk_base); + if (offset > chunk_it->size || allocation.size > chunk_it->size - offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p size %zu\n", __func__, + allocation.base, allocation.size); + return; + } + + insert_free_block_locked(*chunk_it, { offset, allocation.size }); + maybe_release_empty_chunk_locked(chunk_it); + } + + bool try_alloc_locked(size_t size, size_t alignment, pool_allocation * allocation) { + alloc_candidate best; + + for (size_t chunk_index = 0; chunk_index < chunks_.size(); ++chunk_index) { + const auto & chunk = chunks_[chunk_index]; + for (size_t block_index = 0; block_index < chunk.free_blocks.size(); ++block_index) { + const auto & block = chunk.free_blocks[block_index]; + + uintptr_t aligned_addr = 0; + const auto block_addr = reinterpret_cast(chunk.base + block.offset); + if (!align_up_uintptr(block_addr, alignment, &aligned_addr)) { + continue; + } + + if (aligned_addr < block_addr) { + continue; + } + + const size_t aligned_offset = block.offset + static_cast(aligned_addr - block_addr); + const size_t padding = aligned_offset - block.offset; + if (padding > block.size || size > block.size - padding) { + continue; + } + + if (!best.valid || aligned_addr < best.address) { + best.chunk_index = chunk_index; + best.block_index = block_index; + best.aligned_offset = aligned_offset; + best.address = aligned_addr; + best.valid = true; + } + } + } + + if (!best.valid) { + return false; + } + + auto & chunk = chunks_[best.chunk_index]; + const free_block block = chunk.free_blocks[best.block_index]; + const size_t padding = best.aligned_offset - block.offset; + const size_t alloc_end = best.aligned_offset + size; + const size_t block_end = block.offset + block.size; + + chunk.free_blocks.erase(chunk.free_blocks.begin() + best.block_index); + auto insert_it = chunk.free_blocks.begin() + best.block_index; + if (padding != 0) { + insert_it = chunk.free_blocks.insert(insert_it, { block.offset, padding }); + ++insert_it; + } + if (alloc_end < block_end) { + chunk.free_blocks.insert(insert_it, { alloc_end, block_end - alloc_end }); + } + + allocation->chunk_base = chunk.base; + allocation->chunk_size = chunk.size; + allocation->base = chunk.base + best.aligned_offset; + allocation->size = size; + return true; + } + + void maybe_release_empty_chunk_locked(std::vector::iterator chunk_it) { + if (chunk_it->free_blocks.size() != 1) { + return; + } + + const auto & block = chunk_it->free_blocks.front(); + if (block.offset != 0 || block.size != chunk_it->size) { + return; + } + + dealloc_chunk(&*chunk_it); + chunks_.erase(chunk_it); + } + + void insert_free_block_locked(pool_chunk & chunk, free_block block) { + auto it = chunk.free_blocks.begin(); + while (it != chunk.free_blocks.end() && it->offset < block.offset) { + ++it; + } + + if (it != chunk.free_blocks.begin()) { + const auto & prev = *(it - 1); + if (prev.offset + prev.size > block.offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: overlapping free block at offset %zu size %zu\n", __func__, + block.offset, block.size); + return; + } + } + + if (it != chunk.free_blocks.end() && block.offset + block.size > it->offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: overlapping next free block at offset %zu size %zu\n", __func__, + block.offset, block.size); + return; + } + + it = chunk.free_blocks.insert(it, block); + + if (it != chunk.free_blocks.begin()) { + auto prev = it - 1; + if (prev->offset + prev->size == it->offset) { + it->offset = prev->offset; + it->size += prev->size; + it = chunk.free_blocks.erase(prev); + } + } + + if (it + 1 != chunk.free_blocks.end() && it->offset + it->size == (it + 1)->offset) { + it->size += (it + 1)->size; + chunk.free_blocks.erase(it + 1); + } + } + + std::mutex mutex_; + std::vector chunks_; + std::unordered_map allocations_; + size_t default_chunk_size_{ 0 }; +}; + +class spine_mem_pool_posix final : public spine_mem_pool_manager { + public: + spine_mem_pool_posix() : spine_mem_pool_manager(0) {} + + ~spine_mem_pool_posix() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) hint_addr; + + const size_t alloc_alignment = std::max(alignment, sizeof(void *)); + void * base = nullptr; + const int rc = posix_memalign(&base, alloc_alignment, min_size); + if (rc != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: posix_memalign failed for size %zu alignment %zu, rc=%d\n", + __func__, min_size, alloc_alignment, rc); + return false; + } + + chunk->base = static_cast(base); + chunk->size = min_size; + chunk->fd = -1; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + std::free(chunk->base); + clear_chunk(chunk); + } +}; + +class spine_mem_pool_transparent_hugepage final : public spine_mem_pool_manager { + public: + spine_mem_pool_transparent_hugepage() : spine_mem_pool_manager(SPINE_MEM_POOL_CHUNK_SIZE) {} + + ~spine_mem_pool_transparent_hugepage() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + + size_t chunk_size = 0; + if (!align_up(min_size, default_chunk_size(), &chunk_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to round chunk size for %zu\n", __func__, min_size); + return false; + } + + void * map_addr = mmap(hint_addr, chunk_size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for chunk size %zu, errno=%d\n", __func__, chunk_size, + errno); + return false; + } + + if (madvise(map_addr, chunk_size, MADV_HUGEPAGE) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: madvise(MADV_HUGEPAGE) failed for chunk size %zu, errno=%d\n", + __func__, chunk_size, errno); + munmap(map_addr, chunk_size); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = chunk_size; + chunk->fd = -1; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for chunk %p size %zu, errno=%d\n", __func__, + chunk->base, chunk->size, errno); + } + + clear_chunk(chunk); + } +}; + +class spine_mem_pool_hugetlb_1g final : public spine_mem_pool_manager { + public: + spine_mem_pool_hugetlb_1g() : spine_mem_pool_manager(SPINE_MEM_POOL_1G_REGION_SIZE) {} + + ~spine_mem_pool_hugetlb_1g() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + (void) hint_addr; + + size_t region_size = 0; + if (!align_up(min_size, SPINE_MEM_POOL_1G_REGION_SIZE, ®ion_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to round hugetlb_1g size for %zu\n", __func__, min_size); + return false; + } + + const int fd = open(SPINE_MEM_POOL_HUGETLB_1G_DEV, O_RDWR); + if (fd < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: open(%s) failed, errno=%d\n", __func__, + SPINE_MEM_POOL_HUGETLB_1G_DEV, errno); + return false; + } + + hugetlb_1g_region region; + region.size = region_size; + region.flags = HUGETLB_1G_FLAG_REQUIRE_PUD; + if (ioctl(fd, HUGETLB_1G_IOC_ALLOC, ®ion) < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: HUGETLB_1G_IOC_ALLOC failed for size %zu, errno=%d\n", __func__, + region_size, errno); + close(fd); + return false; + } + + void * map_addr = mmap(nullptr, region.size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for hugetlb_1g size %llu, errno=%d\n", __func__, + static_cast(region.size), errno); + ioctl(fd, HUGETLB_1G_IOC_FREE); + close(fd); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = region.size; + chunk->fd = fd; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for hugetlb_1g chunk %p size %zu, errno=%d\n", + __func__, chunk->base, chunk->size, errno); + } + + if (chunk->fd >= 0) { + if (ioctl(chunk->fd, HUGETLB_1G_IOC_FREE) < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: HUGETLB_1G_IOC_FREE failed for chunk %p, errno=%d\n", + __func__, chunk->base, errno); + } + + close(chunk->fd); + } + + clear_chunk(chunk); + } +}; + +class spine_mem_pool_shared_mem final : public spine_mem_pool_manager { + public: + spine_mem_pool_shared_mem() : spine_mem_pool_manager(SPINE_SHARE_MEM_POOL_CHUNK_SIZE) {} + + ~spine_mem_pool_shared_mem() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + + if (hint_addr != nullptr) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: shared_mem does not support multiple active chunks\n", __func__); + return false; + } + + if (min_size > default_chunk_size()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: shared_mem request %zu exceeds chunk size %zu\n", __func__, + min_size, default_chunk_size()); + return false; + } + + const int fd = open(SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, O_RDWR | O_SYNC); + if (fd < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: open(%s) failed, errno=%d\n", __func__, + SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, errno); + return false; + } + + void * map_addr = mmap(nullptr, default_chunk_size(), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for %s size %zu, errno=%d\n", __func__, + SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, default_chunk_size(), errno); + close(fd); + return false; + } + + chunk->base = static_cast(map_addr); + chunk->size = default_chunk_size(); + chunk->fd = fd; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for shared_mem chunk %p size %zu, errno=%d\n", + __func__, chunk->base, chunk->size, errno); + } + + if (chunk->fd >= 0) { + close(chunk->fd); + } + + clear_chunk(chunk); + } +}; + +spine_mem_pool_manager & get_spine_mem_pool_manager() { + static std::once_flag pool_once; + static std::unique_ptr selected_pool; + static spine_mem_pool_backend selected_backend = spine_mem_pool_backend::none; + + spine_mem_pool_backend backend = global_spine_env_info.mem_backend; + if (backend == spine_mem_pool_backend::none) { + backend = spine_mem_pool_backend::transparent_hugepage; + } + + std::call_once(pool_once, [&]() { + selected_backend = backend; + + switch (selected_backend) { + case spine_mem_pool_backend::posix_memalign: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::transparent_hugepage: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::hugetlb_1g: + selected_pool = std::make_unique(); + break; + case spine_mem_pool_backend::none: + selected_backend = spine_mem_pool_backend::transparent_hugepage; + selected_pool = std::make_unique(); + break; + } + }); + + if (backend != selected_backend) { + GGML_LOG_ERROR( + "CPU_RISCV64_SPACEMIT: %s: mem pool backend is process-global and mutually exclusive, requested=%d but " + "selected=%d\n", + __func__, static_cast(backend), static_cast(selected_backend)); + } + + if (selected_pool) { + return *selected_pool; + } + + throw std::bad_alloc(); +} + +spine_mem_pool_manager & get_spine_mem_pool_shared_mem_manager() { + static std::once_flag shared_mem_pool_once; + static std::unique_ptr shared_mem_pool; + + std::call_once(shared_mem_pool_once, [&]() { shared_mem_pool = std::make_unique(); }); + + if (shared_mem_pool) { + return *shared_mem_pool; + } + + throw std::bad_alloc(); +} + +} // namespace + +bool spine_mem_pool_tcm_init(spine_mem_pool_tcm_info * info) noexcept { + if (info == nullptr) { + return false; + } + + *info = {}; + + if (spine_tcm_open_handle(NULL) != 0 || !spine_tcm_is_available()) { + return false; + } + + spine_tcm_mem_info_t mem_info; + if (spine_tcm_mem_info(&mem_info) != 0) { + return false; + } + + info->available = true; + info->blk_size = mem_info.blk_size; + info->blk_num = mem_info.blk_num; + info->is_fake_tcm = mem_info.is_fake_tcm != 0; + return true; +} + +void * spine_mem_pool_tcm_mem_get(int cpu_id) noexcept { + return spine_tcm_mem_get(cpu_id); +} + +void * spine_mem_pool_tcm_mem_wait(int cpu_id) noexcept { + return spine_tcm_mem_try_wait(cpu_id, 1000 * 1000); +} + +int spine_mem_pool_tcm_mem_release(int cpu_id) noexcept { + return spine_tcm_mem_release(cpu_id); +} + +void * spine_mem_pool_alloc(size_t size, size_t alignment) noexcept { + try { + return get_spine_mem_pool_manager().alloc(size, alignment); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while allocating size %zu\n", __func__, size); + return nullptr; + } +} + +void * spine_mem_pool_shared_mem_alloc(size_t size, size_t alignment) noexcept { + try { + return get_spine_mem_pool_shared_mem_manager().alloc(size, alignment); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while allocating shared memory size %zu\n", __func__, size); + return nullptr; + } +} + +void spine_mem_pool_free(void * base) noexcept { + try { + get_spine_mem_pool_manager().free(base); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while freeing allocation %p\n", __func__, base); + } +} + +void spine_mem_pool_shared_mem_free(void * base) noexcept { + try { + get_spine_mem_pool_shared_mem_manager().free(base); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while freeing shared allocation %p\n", __func__, base); + } +} + +} // namespace ggml::cpu::riscv64_spacemit + +extern "C" { +void * ggml_backend_cpu_riscv64_spacemit_alloc_shared(size_t size, size_t alignment) { + void * result = ggml::cpu::riscv64_spacemit::spine_mem_pool_shared_mem_alloc(size, alignment); + if (result == nullptr) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to allocate shared memory size %zu alignment %zu\n", __func__, + size, alignment); + } + return result; +} + +void ggml_backend_cpu_riscv64_spacemit_free_shared(void * ptr) { + ggml::cpu::riscv64_spacemit::spine_mem_pool_shared_mem_free(ptr); +} +} diff --git a/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h new file mode 100644 index 00000000000..8740d2c99ef --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h @@ -0,0 +1,32 @@ +#pragma once + +#include +#include + +namespace ggml::cpu::riscv64_spacemit { + +enum class spine_mem_pool_backend : uint8_t { + none, + posix_memalign, + transparent_hugepage, + hugetlb_1g, +}; + +struct spine_mem_pool_tcm_info { + bool available{ false }; + size_t blk_size{ 0 }; + size_t blk_num{ 0 }; + bool is_fake_tcm{ false }; +}; + +bool spine_mem_pool_tcm_init(spine_mem_pool_tcm_info * info) noexcept; +void * spine_mem_pool_tcm_mem_get(int cpu_id) noexcept; +void * spine_mem_pool_tcm_mem_wait(int cpu_id) noexcept; +int spine_mem_pool_tcm_mem_release(int cpu_id) noexcept; + +void * spine_mem_pool_alloc(size_t size, size_t alignment) noexcept; +void * spine_mem_pool_shared_mem_alloc(size_t size, size_t alignment) noexcept; +void spine_mem_pool_free(void * base) noexcept; +void spine_mem_pool_shared_mem_free(void * base) noexcept; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/spine_tcm.h b/ggml/src/ggml-cpu/spacemit/spine_tcm.h new file mode 100644 index 00000000000..f300d7d5c04 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_tcm.h @@ -0,0 +1,409 @@ +#ifndef SPINE_TCM_PUBLIC_H_ +#define SPINE_TCM_PUBLIC_H_ + +/* + * spine_tcm public API + * + * Usage: + * 1. Direct link mode + * Define SPINE_TCM_DIRECT_LINK and link against libspine_tcm.so. + * + * if (spine_tcm_is_available()) { + * void *buffer = spine_tcm_mem_get(0); + * spine_tcm_mem_free(0); + * } + * + * 2. Header-only loader mode + * Include this header without linking libspine_tcm.so. The loader first + * tries to reuse a process-global spine_tcm instance and falls back to + * dlopen("libspine_tcm.so") when needed. + * + * spine_tcm_open_handle(NULL); // optional pre-bind + * if (spine_tcm_is_available()) { + * void *buffer = spine_tcm_mem_get(0); + * spine_tcm_mem_free(0); + * } + */ + +#include +#include +#include + +#if !defined(SPINE_TCM_BUILD_SHARED) && !defined(SPINE_TCM_DIRECT_LINK) +# include +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(_WIN32) +# if defined(SPINE_TCM_BUILD_SHARED) +# define SPINE_TCM_API __declspec(dllexport) +# else +# define SPINE_TCM_API __declspec(dllimport) +# endif +#else +# define SPINE_TCM_API __attribute__((visibility("default"))) +#endif + +typedef struct spine_tcm_mem_info { + size_t blk_size; + size_t blk_num; + int is_fake_tcm; +} spine_tcm_mem_info_t; + +typedef struct spine_tcm_block_info { + int id; + void * va; + size_t size; + uint64_t phys_addr; + uint64_t cpu_affinity_mask; + int owner_tid; + int is_acquired; +} spine_tcm_block_info_t; + +/* Shared-library runtime ABI exported by libspine_tcm.so. */ +SPINE_TCM_API const char * spine_tcm_runtime_version(void); +SPINE_TCM_API int spine_tcm_runtime_is_available(void); +SPINE_TCM_API int spine_tcm_runtime_layout_info(spine_tcm_mem_info_t * info); +SPINE_TCM_API int spine_tcm_runtime_mem_info(int id, spine_tcm_block_info_t * info); +SPINE_TCM_API void * spine_tcm_runtime_mem_get(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_free(int id); +SPINE_TCM_API void * spine_tcm_runtime_mem_try_wait(int id, size_t timeout_us); +SPINE_TCM_API int spine_tcm_runtime_mem_release(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_force_release(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_query(int id); + +#if defined(SPINE_TCM_DIRECT_LINK) +/* Optional no-op in direct-link mode. */ +static inline int spine_tcm_open_handle(const char * so_path) { + (void) so_path; + return 0; +} + +static inline const char * spine_tcm_version(void) { + return spine_tcm_runtime_version(); +} + +/* Returns 1 when the runtime driver is available, otherwise 0. */ +static inline int spine_tcm_is_available(void) { + return spine_tcm_runtime_is_available(); +} + +/* Returns runtime memory geometry and whether the current backend is fake TCM. */ +static inline int spine_tcm_mem_info(spine_tcm_mem_info_t * info) { + return spine_tcm_runtime_layout_info(info); +} + +/* Returns per-block runtime metadata for the given TCM id. */ +static inline int spine_tcm_block_info(int id, spine_tcm_block_info_t * info) { + return spine_tcm_runtime_mem_info(id, info); +} + +/* Returns a cached buffer for the given TCM id, or NULL on failure. */ +static inline void * spine_tcm_mem_get(int id) { + return spine_tcm_runtime_mem_get(id); +} + +/* Releases one reference acquired by spine_tcm_mem_get(id). */ +static inline int spine_tcm_mem_free(int id) { + return spine_tcm_runtime_mem_free(id); +} + +/* Waits for a TCM block handoff and returns the driver-owned buffer when available. */ +static inline void * spine_tcm_mem_try_wait(int id, size_t over_time) { + return spine_tcm_runtime_mem_try_wait(id, over_time); +} + +/* Releases a buffer acquired by spine_tcm_mem_try_wait(id, over_time). */ +static inline int spine_tcm_mem_release(int id) { + return spine_tcm_runtime_mem_release(id); +} + +/* Forces a release for the given TCM id when the backend supports it. */ +static inline int spine_tcm_mem_force_release(int id) { + return spine_tcm_runtime_mem_force_release(id); +} + +/* Returns whether the given TCM id is currently acquired. */ +static inline int spine_tcm_mem_query(int id) { + return spine_tcm_runtime_mem_query(id); +} +#elif !defined(SPINE_TCM_BUILD_SHARED) +typedef struct spine_tcm_handle { + void * module_handle; + int use_global_scope; + int owns_module_handle; + const char * (*runtime_version)(void); + int (*runtime_is_available)(void); + int (*runtime_layout_info)(spine_tcm_mem_info_t * info); + int (*runtime_mem_info)(int id, spine_tcm_block_info_t * info); + void * (*runtime_mem_get)(int id); + int (*runtime_mem_free)(int id); + void * (*runtime_mem_try_wait)(int id, size_t over_time); + int (*runtime_mem_release)(int id); + int (*runtime_mem_force_release)(int id); + int (*runtime_mem_query)(int id); +} spine_tcm_handle_t; + +static inline spine_tcm_handle_t * spine_tcm_default_handle(void) { + static spine_tcm_handle_t handle = { 0 }; + return &handle; +} + +static inline void spine_tcm_handle_reset(spine_tcm_handle_t * handle) { + if (handle != NULL) { + memset(handle, 0, sizeof(*handle)); + } +} + +static inline int spine_tcm_handle_bind(spine_tcm_handle_t * handle) { + void * symbol_scope = handle->use_global_scope ? RTLD_DEFAULT : handle->module_handle; + + handle->runtime_version = (const char * (*) (void) ) dlsym(symbol_scope, "spine_tcm_runtime_version"); + handle->runtime_is_available = (int (*)(void)) dlsym(symbol_scope, "spine_tcm_runtime_is_available"); + handle->runtime_layout_info = + (int (*)(spine_tcm_mem_info_t *)) dlsym(symbol_scope, "spine_tcm_runtime_layout_info"); + handle->runtime_mem_info = + (int (*)(int, spine_tcm_block_info_t *)) dlsym(symbol_scope, "spine_tcm_runtime_mem_info"); + handle->runtime_mem_get = (void * (*) (int) ) dlsym(symbol_scope, "spine_tcm_runtime_mem_get"); + handle->runtime_mem_free = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_free"); + handle->runtime_mem_try_wait = (void * (*) (int, size_t)) dlsym(symbol_scope, "spine_tcm_runtime_mem_try_wait"); + handle->runtime_mem_release = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_release"); + handle->runtime_mem_force_release = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_force_release"); + handle->runtime_mem_query = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_query"); + + return handle->runtime_version != NULL && handle->runtime_is_available != NULL && + handle->runtime_layout_info != NULL && handle->runtime_mem_info != NULL && + handle->runtime_mem_get != NULL && handle->runtime_mem_free != NULL && + handle->runtime_mem_try_wait != NULL && handle->runtime_mem_release != NULL && + handle->runtime_mem_force_release != NULL && handle->runtime_mem_query != NULL ? + 0 : + -1; +} + +/* + * Try to bind against an already-loaded process-global spine_tcm instance. + * The shared library exports spine_tcm_runtime_marker only for this probe. + */ +static inline int spine_tcm_try_bind_global(spine_tcm_handle_t * handle) { + if (dlsym(RTLD_DEFAULT, "spine_tcm_runtime_marker") == NULL) { + return -1; + } + + handle->use_global_scope = 1; + return spine_tcm_handle_bind(handle); +} + +/* + * Optional pre-bind entry point. + * + * Behavior: + * - Reuses an already-loaded global spine_tcm instance when available. + * - Otherwise loads the shared library from so_path or the default soname. + * - Repeated calls are safe and return 0 after the first successful bind. + */ +static inline int spine_tcm_open_handle(const char * so_path) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + const char * library = (so_path != NULL && so_path[0] != '\0') ? so_path : "libspine_tcm.so"; + + if (resolved->module_handle != NULL || resolved->use_global_scope) { + return 0; + } + + if (spine_tcm_try_bind_global(resolved) == 0) { + return 0; + } + + spine_tcm_handle_reset(resolved); + + resolved->module_handle = dlopen(library, RTLD_LAZY | RTLD_GLOBAL); + resolved->owns_module_handle = resolved->module_handle != NULL ? 1 : 0; + + if (resolved->module_handle == NULL) { + spine_tcm_handle_reset(resolved); + return -1; + } + + if (spine_tcm_handle_bind(resolved) != 0) { + if (resolved->owns_module_handle) { + dlclose(resolved->module_handle); + } + spine_tcm_handle_reset(resolved); + return -1; + } + + return 0; +} + +/* Returns 1 when the runtime driver is available, otherwise 0. */ +static inline int spine_tcm_is_available(void) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_is_available == NULL) { + return 0; + } + + return resolved->runtime_is_available(); +} + +/* Returns runtime memory geometry and whether the current backend is fake TCM. */ +static inline int spine_tcm_mem_info(spine_tcm_mem_info_t * info) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_layout_info == NULL) { + return -1; + } + + return resolved->runtime_layout_info(info); +} + +static inline const char * spine_tcm_version(void) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_version == NULL) { + return "unknown"; + } + + return resolved->runtime_version(); +} + +/* Returns per-block runtime metadata for the given TCM id. */ +static inline int spine_tcm_block_info(int id, spine_tcm_block_info_t * info) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_info == NULL) { + return -1; + } + + return resolved->runtime_mem_info(id, info); +} + +/* Returns a cached buffer for the given TCM id, or NULL on failure. */ +static inline void * spine_tcm_mem_get(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + return NULL; + } + + if (resolved->runtime_mem_get == NULL) { + return NULL; + } + + return resolved->runtime_mem_get(id); +} + +/* Releases one reference acquired by spine_tcm_mem_get(id). */ +static inline int spine_tcm_mem_free(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_free == NULL) { + return -1; + } + + return resolved->runtime_mem_free(id); +} + +/* Waits for a TCM block handoff and returns the driver-owned buffer when available. */ +static inline void * spine_tcm_mem_try_wait(int id, size_t over_time) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + return NULL; + } + + if (resolved->runtime_mem_try_wait == NULL) { + return NULL; + } + + return resolved->runtime_mem_try_wait(id, over_time); +} + +/* Releases a buffer acquired by spine_tcm_mem_try_wait(id, over_time). */ +static inline int spine_tcm_mem_release(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_release == NULL) { + return -1; + } + + return resolved->runtime_mem_release(id); +} + +/* Forces a release for the given TCM id when the backend supports it. */ +static inline int spine_tcm_mem_force_release(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || + resolved->runtime_mem_force_release == NULL) { + return -1; + } + + return resolved->runtime_mem_force_release(id); +} + +/* Returns whether the given TCM id is currently acquired. */ +static inline int spine_tcm_mem_query(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_query == NULL) { + return -1; + } + + return resolved->runtime_mem_query(id); +} +#else +static inline const char * spine_tcm_version(void) { + return spine_tcm_runtime_version(); +} +#endif + +#define SPINE_TCM_VERSION (spine_tcm_version()) + +#ifdef __cplusplus +} +#endif + +#endif From 592a8cd15d028f8d9a709e777641a9736a213565 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 14 May 2026 13:05:52 +0300 Subject: [PATCH 633/831] logs : reduce (llama/23021) * logs : reduce * args : fix envs * server : fix build * common : print verbosity level at start * server : clean-up logs * server : print prompt processing timings + sampling params * minor : whitespaces --- ggml/src/ggml-metal/ggml-metal-device.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index fab7891c008..780dfe81bb3 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -672,7 +672,7 @@ ggml_metal_device_t ggml_metal_device_init(int device) { ![[dev->mtl_device name] containsString:@"M6"] && ![[dev->mtl_device name] containsString:@"A19"] && ![[dev->mtl_device name] containsString:@"A20"]) { - GGML_LOG_WARN("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__); + GGML_LOG_INFO("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__); dev->props.has_tensor = false; } From 13133ab299e94a413fed015841a424adec149b1c Mon Sep 17 00:00:00 2001 From: Zheyuan Chen Date: Thu, 14 May 2026 09:31:36 -0700 Subject: [PATCH 634/831] ggml-webgpu: makes the flash attn vec path subgroup-aware (llama/23040) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ggml-webgpu: makes the flash attn vec path compile and size its split/reduce work from the device’s reported subgroup range instead of assuming 32 subgroup size. * ggml-webgpu: remove the extra max_wg_size >= max_subgroup_size guard. Remove hardcoded 32 when determine the value of reduce_wg_size and vec_nwg_cap --- ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp | 13 +++++++++---- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 12 +++++++----- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 62a523365b9..4c4eda1cbe5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -770,9 +770,14 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && (context.src0->ne[0] % 32 == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(K->type); + const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 && + context.src2->ne[0] % kv_vec_head_align == 0; + // Compile with enough invocations to cover the largest reported subgroup. + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && + kv_vec_head_dims_aligned && kv_vec_type_supported && + (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (context.src2->type == K->type); const bool tile_can_dispatch_all_q_rows = context.max_subgroup_size > 0 && @@ -808,7 +813,7 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( decisions.q_tile = 1u; decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; - decisions.wg_size = std::max(1u, std::min(32u, context.max_subgroup_size)); + decisions.wg_size = context.max_subgroup_size; if (decisions.kv_direct) { decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 401c75c1230..78cb02be06d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1832,7 +1832,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_nblk1 = 0; uint32_t blk_batch_count = 0; - const uint32_t vec_nwg_cap = std::max(1u, std::min(32u, ctx->global_ctx->capabilities.max_subgroup_size)); + const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size; uint32_t nwg = 1u; const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { @@ -1953,8 +1953,11 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector reduce_params; std::vector reduce_entries; if (use_vec_reduce) { - const uint32_t reduce_wg_size = std::max( - 32u, std::min(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + const uint32_t reduce_sg_size = ctx->global_ctx->capabilities.max_subgroup_size; + const uint32_t reduce_wg_size = + std::max(reduce_sg_size, (uint32_t) std::min( + (uint64_t) nwg * reduce_sg_size, + ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; reduce_shader_ctx.max_wg_size = reduce_wg_size; reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); @@ -3542,8 +3545,7 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { const uint32_t kv_tile = decisions.kv_tile; - const uint32_t vec_nwg_cap = std::max( - 1u, std::min(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size)); + const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; uint32_t nwg = 1u; const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { From e62d5893f4153226e761c2bc0ab80b28a63cb055 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 14 May 2026 22:58:58 +0200 Subject: [PATCH 635/831] HIP: RDNA3 mma FA, faster AMD transpose, tune AMD (llama/22880) Adds RDNA3 support to the CUDA mma FA kernel. To make the RDNA3 tensor cores work with the FP16 accumulation for VKQ the tiles they need to be 32 logical units long in direction of the attention head; for head sizes 80 and 112 that are not exactly divided by 32 the regular length of 16 with FP32 accumulation is used instead. The longer tiles also enable more efficient transposition for a warp size of 32 which is why it's also used for RDNA4. However, this scrambles the data layout of the accumulators along the attention head dimension. To prevent accidental misuse I added another entry to ggml_cuda_mma::data_layout. I also tuned the kernel parameters for RDNA3, RDNA4, and CDNA1 in general, during which I discovered that the kernel can be made to work for head sizes up to 256 for CDNA. For RDNA3/4 I was not able to get better performance that the tile kernel for head sizes > 128. --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 319 ++++++++++++++++++++------- ggml/src/ggml-cuda/fattn.cu | 57 ++--- ggml/src/ggml-cuda/mma.cuh | 149 +++++++++++-- 3 files changed, 398 insertions(+), 127 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 43e22c5e5ee..a25e912c4d2 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -125,61 +125,107 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) { - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 64, 160, 128, 64, 2, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 64, 160, 128, 64, 2, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 64, 2, 32, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 64, 2, 32, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 128, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 64, 2, 32, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 64, 2, 32, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 64, 2, 32, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 64, 2, 32, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 1, true); - // TODO tune specifically for RDNA - return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 64, 2, 32, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 64, 2, 32, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 2, 32, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 2, 32, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 64, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 64, 96, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 32, 160, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 128, 2, 32, 128, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 128, 2, 32, 160, 128, 128, 1, true); + + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) { - // Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async). - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true); - - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 1, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 4, 64, 32, 32, 32, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40, 40, 40, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48, 48, 48, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56, 56, 56, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64, 64, 64, 1, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true); - GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 512, 1, 64, 64, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 512, 1, 64, 128, 128, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 256, 1, 64, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 64, 160, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 64, 128, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 256, 1, 64, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 64, 160, 128, 128, 1, true); - // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy - // compile-time static_asserts even though the kernel guard prevents runtime execution. - // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility. - return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false); + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); } static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { @@ -510,7 +556,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int jt, const int kb0, const int k_VKQ_sup) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = T_B_KQ::I; @@ -712,6 +758,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) { const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J; + + // The mask is stored as 16 bit half values, loading them as 32 bit half2 values is preferred in terms of speed. + // However, this is not possible for RDNA3 where 2 consecutive l indices are not consecutive in the mask memory layout. +#ifdef RDNA3 +#pragma unroll + for (int l = 0; l < T_C_KQ::ne; ++l) { + const int i = i0 + T_C_KQ::get_j(l); + const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l)) / ncols2; + + KQ_C[i00/(np*T_C_KQ::J)].x[l] += __half2float(tile_mask[j*(nbatch_fa + 8) + i]); + } +#else #pragma unroll for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) { const int i = (i0 + T_C_KQ::get_j(l0)) / 2; @@ -721,6 +779,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x; KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y; } +#endif // RDNA3 } } @@ -827,13 +886,23 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) - const half2 KQ_max_scale_h2 = make_half2( - KQ_max_scale[0], KQ_max_scale[0]); + if constexpr (std::is_same_v) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); #pragma unroll - for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l = 0; l < T_C_VKQ::ne; ++l) { - VKQ_C[i].x[l] *= KQ_max_scale_h2; + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { + static_assert(std::is_same_v, "bad VKQ type"); +#pragma unroll + for (int i = 0; i < DV/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale[0]; + } } } #else // Volta @@ -901,9 +970,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2; #if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) - constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; #pragma unroll - for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += T_A_VKQ::I) { static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size"); #pragma unroll for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) { @@ -912,15 +980,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); if constexpr (T_B_KQ::I == 8) { - mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]); } else { // Wide version of VKQ_C is column-major. #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) // AMD matrix C is column-major. - mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]); #else // swap A and B for CUDA. - mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], B[k00/(np*T_A_VKQ::J)], A); #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } @@ -953,11 +1021,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } #if defined(TURING_MMA_AVAILABLE) -template struct mma_tile_sizes { +template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile<16, 8, half2>; // column-major using T_C_KQ = tile<16, 16, float>; // column-major @@ -965,7 +1033,7 @@ template struct mma_tile_sizes { using T_B_VKQ = tile<16, 8, half2>; // column-major using T_C_VKQ = tile<16, 8, half2>; // column-major }; -template<> struct mma_tile_sizes<8> { +template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile< 8, 8, half2>; // column-major using T_C_KQ = tile<16, 8, float>; // row-major @@ -973,8 +1041,60 @@ template<> struct mma_tile_sizes<8> { using T_B_VKQ = tile< 8, 8, half2>; // column-major using T_C_VKQ = tile<16, 4, half2>; // row-major }; -#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) -template struct mma_tile_sizes { +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR>; // column-major +}; +template struct mma_tile_sizes<80, ncols> { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major +}; +template struct mma_tile_sizes<112, ncols> { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major +}; +#else +template struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED>; // column-major +}; +template struct mma_tile_sizes<80, ncols> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +template struct mma_tile_sizes<112, ncols> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) +template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile<16, 8, half2>; // column-major using T_C_KQ = tile<16, 16, float>; // column-major @@ -983,7 +1103,7 @@ template struct mma_tile_sizes { using T_C_VKQ = tile<16, 8, half2>; // column-major }; #else // Volta -template struct mma_tile_sizes { +template struct mma_tile_sizes { using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major @@ -1018,17 +1138,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int zt_gqa, const int kb0_start, const int kb0_stop) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; - using T_A_KQ = typename mma_tile_sizes::T_A_KQ; - using T_B_KQ = typename mma_tile_sizes::T_B_KQ; - using T_C_KQ = typename mma_tile_sizes::T_C_KQ; - using T_A_VKQ = typename mma_tile_sizes::T_A_VKQ; - using T_B_VKQ = typename mma_tile_sizes::T_B_VKQ; - using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; + using T_A_KQ = typename mma_tile_sizes::T_A_KQ; + using T_B_KQ = typename mma_tile_sizes::T_B_KQ; + using T_C_KQ = typename mma_tile_sizes::T_C_KQ; + using T_A_VKQ = typename mma_tile_sizes::T_A_VKQ; + using T_B_VKQ = typename mma_tile_sizes::T_B_VKQ; + using T_C_VKQ = typename mma_tile_sizes::T_C_VKQ; constexpr int cols_per_warp = T_B_KQ::I; constexpr int cols_per_thread = get_cols_per_thread(); @@ -1061,6 +1181,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; #if defined(TURING_MMA_AVAILABLE) T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; +#elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + T_C_VKQ VKQ_C[DV % 32 != 0 ? DV/T_C_VKQ::J : DV/(2*T_C_VKQ::J)]; #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; #else // Volta @@ -1269,12 +1391,23 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); + if constexpr (std::is_same_v) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); #pragma unroll - for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { #pragma unroll - for (int l = 0; l < T_C_VKQ::ne; ++l) { - VKQ_C[i].x[l] *= KQ_max_scale_h2; + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { + static_assert(std::is_same_v, "bad VKQ type"); +#pragma unroll + for (int i = 0; i < DV/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale[0]; + } } } #else // Volta @@ -1433,6 +1566,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { if constexpr (cols_per_warp == 8) { + static_assert(std::is_same_v, "bad VKQ type"); const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data #pragma unroll for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) { @@ -1447,14 +1581,45 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } else { const int j0 = threadIdx.y*cols_per_warp; + if constexpr (std::is_same_v) { + if constexpr (T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR) { #pragma unroll - for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { #pragma unroll - for (int l = 0; l < T_C_VKQ::ne; ++l) { - const int j = j0 + T_C_VKQ::get_i(l); - const int k = k1 + T_C_VKQ::get_j(l); + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = k1 + T_C_VKQ::get_j(l); - tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; + tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; + } + } + } else { + static_assert(T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR_SCRAMBLED, "bad T_C_VKQ data layout"); + using T_C_VKQ_us = tile; // us == unscrambled +#pragma unroll + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { + const T_C_VKQ_us VKQ_C_us = unscramble(VKQ_C[(k00 + k1)/T_C_VKQ::J]); +#pragma unroll + for (int l = 0; l < T_C_VKQ_us::ne; ++l) { + const int j = j0 + T_C_VKQ_us::get_i(l); + const int k = k1 + T_C_VKQ_us::get_j(l); + + tile_Q[j*tile_stride + k] = VKQ_C_us.x[l]; + } + } + } + } else { + static_assert(std::is_same_v, "bad VKQ type"); + half * tile_Q_h = (half *) tile_Q; +#pragma unroll + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J/2) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = 2*k1 + T_C_VKQ::get_j(l); + + tile_Q_h[j*(2*tile_stride) + k] = VKQ_C[(k00 + k1)/(T_C_VKQ::J/2)].x[l]; + } } } } @@ -1532,7 +1697,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } template @@ -1559,7 +1724,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) { @@ -1585,14 +1750,14 @@ static __global__ void flash_attn_ext_f16( #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING #if defined(AMD_WMMA_AVAILABLE) - if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) { + if (ncols1*ncols2 < 16 || ncols2 == 1 || DKQ > 128) { NO_DEVICE_CODE; return; } #endif // defined(AMD_WMMA_AVAILABLE) #if defined(AMD_MFMA_AVAILABLE) - if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) { + if (ncols1*ncols2 < 16 || DKQ > 256) { NO_DEVICE_CODE; return; } @@ -1715,7 +1880,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) } template diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index e045b04f727..1c7777e8a71 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -19,13 +19,14 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con } if constexpr (ncols2 <= 16) { - if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) { + if (Q->ne[1] <= 16/ncols2) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } } - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) { + if (Q->ne[1] <= 32/ncols2 || (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) || + (GGML_CUDA_CC_IS_AMD(cc) && DKQ > 256)) { ggml_cuda_flash_attn_ext_mma_f16_case(ctx, dst); return; } @@ -477,12 +478,13 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_MMA_F16; } + const int ncols2_max = Q->ne[0] == 320 ? 32 : ((Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8); + int gqa_ratio_eff = 1; + while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { + gqa_ratio_eff *= 2; + } + if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { - int gqa_ratio_eff = 1; - const int ncols2_max = (Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8; - while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { - gqa_ratio_eff *= 2; - } if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) { return BEST_FATTN_KERNEL_VEC; } @@ -500,41 +502,22 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_WMMA_F16; } - if (amd_wmma_available(cc) && GGML_CUDA_CC_IS_RDNA4(cc) && gqa_opt_applies && Q->ne[0] <= 128 && Q->ne[0] != 40 && Q->ne[0] != 72) { - if (can_use_vector_kernel) { - if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { - if (Q->ne[1] == 1) { - if (!gqa_opt_applies) { - return BEST_FATTN_KERNEL_VEC; - } - } - } else { - if (Q->ne[1] <= 2) { - return BEST_FATTN_KERNEL_VEC; - } - } + // AMD MFMA needs a certain minimum batch size to outscale the tile kernel for large head sizes. + if ((amd_mfma_available(cc) && Q->ne[0] <= 256) && Q->ne[0] != 40 && Q->ne[0] != 72) { + if ((Q->ne[0] <= 64 && Q->ne[1] * gqa_ratio_eff > 8)) { + return BEST_FATTN_KERNEL_MMA_F16; } - int gqa_ratio_eff = 1; - const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; - while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { - gqa_ratio_eff *= 2; + if ((Q->ne[0] <= 128 && Q->ne[1] * gqa_ratio_eff > 16)) { + return BEST_FATTN_KERNEL_MMA_F16; } - if (Q->ne[1] * gqa_ratio_eff <= 8) { - return BEST_FATTN_KERNEL_TILE; // AMD WMMA is only faster if the full tile width of 16 can be utilized. + if ((Q->ne[0] <= 256 && Q->ne[1] * gqa_ratio_eff > 64)) { + return BEST_FATTN_KERNEL_MMA_F16; } - return BEST_FATTN_KERNEL_MMA_F16; } - // Use MFMA flash attention for CDNA (MI100+): - if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) { - const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); - // MMA vs tile crossover benchmarked on MI300X @ d32768: - // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) - // hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%) - if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) { - return BEST_FATTN_KERNEL_MMA_F16; - } - // Fall through to tile kernel for small effective batch sizes. + // AMD WMMA is always faster than the tile kernel if the full tile width of 16 can be utilized. + if ((amd_wmma_available(cc) && gqa_opt_applies && Q->ne[0] <= 128) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[1] * gqa_ratio_eff > 8) { + return BEST_FATTN_KERNEL_MMA_F16; } // If there are no tensor cores available, use the generic tile kernel: diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 79bb2934c5f..8d7c69dc3e8 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -80,6 +80,7 @@ namespace ggml_cuda_mma { DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3. DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3. DATA_LAYOUT_J_MAJOR_MIRRORED = 30, + DATA_LAYOUT_I_MAJOR_SCRAMBLED = 40, // Scrambled matrix C for faster transposition (RDNA4/CDNA), convert to float to unscramble. }; // Implemented mma combinations are: // - (I_MAJOR, I_MAJOR) -> I_MAJOR @@ -312,13 +313,19 @@ namespace ggml_cuda_mma { half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 16 && J == 8) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 8) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; + } else if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x % 16) * 2 + l / (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -327,7 +334,15 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { - return ne * (threadIdx.x / 16) + l; + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 16 && J == 16) { +#ifdef RDNA3 + return l*2 + (threadIdx.x / 16); +#else + return (threadIdx.x / 16) * ne + l; +#endif // RDNA3 + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x / 16) * (ne/2) + l % (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -338,13 +353,19 @@ namespace ggml_cuda_mma { half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 16 && J == 8) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 8) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; + } else if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x % 16) * 2 + l / (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -353,7 +374,11 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { - return ne * (threadIdx.x / 16) + l; + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 16 && J == 16) { + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x / 16) * (ne/2) + l % (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -516,12 +541,15 @@ namespace ggml_cuda_mma { if (I == 16 && J == 16) return true; if (I == 16 && J == 8) return true; if (I == 16 && J == 4) return true; + if (I == 32 && J == 8) return true; return false; } - static __device__ __forceinline__ int get_i(const int /*l*/) { - if constexpr (supported()) { + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16) { return threadIdx.x % 16; + } else if constexpr (I == 32) { + return (threadIdx.x % 16) * 2 + l / (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -529,8 +557,10 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr (supported()) { + if constexpr (I == 16) { return l; + } else if constexpr (I == 32) { + return l % (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -644,6 +674,40 @@ namespace ggml_cuda_mma { } }; + template + struct tile { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_SCRAMBLED; + + static constexpr int ne = I * J / ggml_cuda_get_physical_warp_size(); + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 16) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile::get_i(l); + } + }; + + static __device__ __forceinline__ tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> unscramble(const tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED> & t) { +#if defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> ret; +#pragma unroll + for (int l0 = 0; l0 < t.ne/2; ++l0) { + ret.x[2*l0 + 0] = __lows2half2(t.x[l0], t.x[l0 + t.ne/2]); + ret.x[2*l0 + 1] = __highs2half2(t.x[l0], t.x[l0 + t.ne/2]); + } + return ret; +#else + NO_DEVICE_CODE; + GGML_UNUSED(t); +#endif // defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + } + #if defined(TURING_MMA_AVAILABLE) template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { @@ -660,6 +724,21 @@ namespace ggml_cuda_mma { ret.x[0] = ggml_cuda_movmatrix(t.x[0]); ret.x[1] = ggml_cuda_movmatrix(t.x[1]); + return ret; + } +#elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + static __device__ __forceinline__ tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> get_half2( + const tile<16, 16, float, DATA_LAYOUT_I_MAJOR> & tile_float) { + tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> ret; +#pragma unroll + for (int l = 0; l < tile_float.ne; ++l) { + float tmp[2]; + int i = threadIdx.x / 16; + tmp[i] = tile_float.x[l]; + i ^= 1; + tmp[i] = __shfl_xor_sync(0xFFFFFFFF, tile_float.x[l], 16, WARP_SIZE); + ret.x[l] = make_half2(tmp[0], tmp[1]); + } return ret; } #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) @@ -802,21 +881,35 @@ namespace ggml_cuda_mma { #endif // defined(VOLTA_MMA_AVAILABLE) } - template + template static __device__ __forceinline__ void load_ldmatrix_trans( - tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { + tile & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE + static_assert(I == 16, "bad tile width"); + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); #elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - half * xh = (half *) t.x; + static_assert(dl == DATA_LAYOUT_I_MAJOR || dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + if constexpr (I == 32) { #pragma unroll - for (int l = 0; l < t.ne; ++l) { - xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; - xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + for (int l0 = 0; l0 < t.ne/2; ++l0) { + const half2 tmp0 = xs0[(2*t.get_j(l0) + 0)*stride + t.get_i(l0)/2]; + const half2 tmp1 = xs0[(2*t.get_j(l0) + 1)*stride + t.get_i(l0)/2]; + + t.x[l0] = __lows2half2(tmp0, tmp1); + t.x[l0 + t.ne/2] = __highs2half2(tmp0, tmp1); + } + } else { + half * xh = (half *) t.x; +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; + xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + } } #else GGML_UNUSED_VARS(t, xs0, stride); @@ -972,6 +1065,20 @@ namespace ggml_cuda_mma { #endif // TURING_MMA_AVAILABLE } + static __device__ __forceinline__ void mma( + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED> & D, const tile<32, 8, half2, DATA_LAYOUT_I_MAJOR> & A, + const tile<16, 8, half2, DATA_LAYOUT_I_MAJOR> & B) { +#if defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + tile<16, 8, half2> * D16 = (tile<16, 8, half2> *) &D; + const tile<16, 8, half2> * A16 = (const tile<16, 8, half2> *) &A; + mma(D16[0], A16[0], B); + mma(D16[1], A16[1], B); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) && defined(RDNA4) + } + template static __device__ __forceinline__ void mma( tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) { @@ -1296,6 +1403,22 @@ namespace ggml_cuda_mma { #endif // defined(VOLTA_MMA_AVAILABLE) } + static __device__ __forceinline__ void mma( + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> & D, const tile<32, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & A, + const tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { +#if defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + using halfx16_t = __attribute__((ext_vector_type(16))) _Float16; + halfx16_t * xD = (halfx16_t *) D.x; + const halfx16_t * xA = (const halfx16_t *) A.x; + const halfx16_t * xB = (const halfx16_t *) B.x; + xD[0] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(xA[0], xB[0], xD[0], /*opsel =*/ 0); + xD[0] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(xA[1], xB[0], xD[0], /*opsel =*/ 1); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // TURING_MMA_AVAILABLE + } + template static __device__ __forceinline__ void mma( tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) { From 18a61f44b63f34bdc05f7c88724b174b706ab149 Mon Sep 17 00:00:00 2001 From: Pranav Dhinakar Date: Thu, 14 May 2026 16:55:54 -0700 Subject: [PATCH 636/831] ggml-hexagon: cpy: add contiguous fast-path in reshape copy (llama/23076) --- ggml/src/ggml-hexagon/htp/cpy-ops.c | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c index e5b9d350fd7..5c040a32224 100644 --- a/ggml/src/ggml-hexagon/htp/cpy-ops.c +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -88,6 +88,29 @@ static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + // Fast path: when both src0 and dst are contiguous in memory + // Replace the element-by-element loop with a single bulk HVX copy per (i03, i02) slice. + const bool src0_contig = (nb00 == ct->src0_type_size) && + (nb01 == ne00 * nb00) && + (nb02 == ne01 * nb01) && + (nb03 == ne02 * nb02); + const bool dst_contig = (nb0 == ct->dst_type_size) && + (nb1 == ne0 * nb0) && + (nb2 == ne1 * nb1) && + (nb3 == ne2 * nb2); + + if (src0_contig && dst_contig) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01; + uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00; + uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ct->src0_type_size; + hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ct->src0_type_size); + } + } + return; + } + // dst counters int64_t k10 = 0; int64_t i11 = 0; From 23f956de336846ea28d7c2bc4c6d370216527203 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 16 May 2026 20:06:23 +0800 Subject: [PATCH 637/831] llama + spec: MTP Support (llama/22673) * spec: support MTP * fix batch size * rename files * cont : simplify (llama/7) * MTP: clean-up (llama/9) * MTP: clean-up * review: use llama_context_type instead of llama_graph_type * review: remove llama_model_has_mtp * review: fix convert issues * convert: fix pycheck * review: formatting * use `mtp-` for identifying mtp models * convert: fix mtp conversion * mtp -> draft-mtp * remove unused llama_arch * add need_embd in speculative * llama: allow partial seq_rm for GDN models for speculative decoding Currently speculative checkpoint needs to restart from a checkpoint after some draft tokens are not accepted, this leads to some wastage in running the target again. This PR adds the ability to rollback upto `draft_max` by storing the GDN intermediates. * fix pending state * vulkan: add GDN partial rollback * meta: extend check to axis 1 * metal: add GDN partial rollback Extend the gated delta net kernel to store intermediate states for partial rollback support on the Metal backend. - Add K (snapshot slot count) as a function constant - Read input state from slot 0 of the 3D state tensor - Write intermediate states to different slots during token loop - For K=1, maintain backward-compatible single-slot behavior Ref: https://github.com/ggml-org/llama.cpp/commit/8c05923630110223669f069af2000e9cf10c02bc Assisted-by: llama.cpp:local pi * delta_net_base: use ggml_pad instead of new_tensor * review: add need_rs_seq * review: rename part_bounded to n_rs * review: deslop comments * review: rename, add asserts * server : adjust checkpoint logic (llama/11) * server : adjust checkpoint logic * cont : rm asserts * server-context: fix early exit * spec : fix compatibility with n-gram and add TODOs (llama/13) * metal : cleanup * llama : fix faulty bitwise check in recurrent memory * server : disable RS-based MTP in combination with other spec types * spec : add TODOs * cont : fix comment * cont : update comment * common : fix logic for ngram + mtp compat * llama-memory: enable checkpointing with partial rollback * cont: add test-case for loading into a dirty ctx * llama-memory-recurrent: clear rs_idx in clear * download: fix mtp path * llama-arch: fix enorm op * docs: update docs * conversion: fix type annotations --------- Co-authored-by: Georgi Gerganov --- ggml/include/ggml.h | 5 ++ ggml/src/ggml-backend-meta.cpp | 5 +- ggml/src/ggml-cpu/ggml-cpu.c | 4 +- ggml/src/ggml-cpu/ops.cpp | 43 +++++++-- ggml/src/ggml-cuda/gated_delta_net.cu | 88 +++++++++++++------ ggml/src/ggml-metal/ggml-metal-device.cpp | 5 +- ggml/src/ggml-metal/ggml-metal.metal | 46 ++++++++-- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 +- .../vulkan-shaders/gated_delta_net.comp | 29 +++++- ggml/src/ggml.c | 12 +-- 10 files changed, 188 insertions(+), 57 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3357a0d9985..41566d41aef 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2541,6 +2541,11 @@ extern "C" { // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 + // + // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs): + // K == 1: output carries the final state only. + // K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K) + // per-token snapshots into the trailing slots GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index c0ffd9a048b..df0f405ed9f 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -753,7 +753,9 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); - GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2); + // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, + // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). + GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; }; @@ -2140,4 +2142,3 @@ ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, siz const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; return backend_ctx->backend_configs[index].backend; } - diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 7b05edf6b75..cd5c61a8187 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2943,7 +2943,9 @@ struct ggml_cplan ggml_graph_plan( case GGML_OP_GATED_DELTA_NET: { const int64_t S_v = node->src[2]->ne[0]; - cur = S_v * sizeof(float) * n_tasks; + const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs) + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); + cur = per_thread * sizeof(float) * n_tasks; } break; case GGML_OP_COUNT: { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 6bc8dc150ce..7485ba4fc86 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10513,19 +10513,30 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const bool kda = (neg0 == S_v); - // scratch layout per thread: [delta(S_v)] - const int64_t scratch_per_thread = S_v; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int64_t K = src_state->ne[1]; + GGML_ASSERT(K >= 1); + // per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride) + const int64_t state_seq_stride = src_state->nb[2] / sizeof(float); + + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); const int ith = params->ith; - float * delta = (float *)params->wdata + ith * scratch_per_thread + CACHE_LINE_SIZE_F32; + float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32; + float * state_work = K > 1 ? (delta + S_v) : nullptr; // output layout: [attn_scores | new_states] - // attn_scores: S_v * H * n_tokens * n_seqs floats - // new_states: S_v * S_v * H * n_seqs floats - const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + // attn_scores: S_v * H * n_tokens * n_seqs floats + // new_states: S_v * S_v * H * n_seqs * K floats (K snapshot slots; last min(n_tokens, K)) + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H * n_seqs; float * attn_out_base = (float *)dst->data; float * state_out_base = (float *)dst->data + attn_score_elems; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int64_t shift = n_tokens - K; + const float * state_in_base = (const float *)src_state->data; //const int64_t rq1 = nev1 / neq1; @@ -10545,10 +10556,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const int64_t iq3 = iv3 / rq3; const int64_t ik3 = iv3 / rk3; - float * s_out = state_out_base + (iv3 * H + iv1) * S_v * S_v; + // For K=1, write directly to the single output slot to avoid an extra memcpy at the end. + // For K>1, work in scratch and copy out per-token when the slot is in range. + float * s_out = (K > 1) + ? state_work + : state_out_base + (iv3 * H + iv1) * S_v * S_v; - // copy input state into output buffer and operate in-place - const float * s_in = state_in_base + (iv3 * H + iv1) * S_v * S_v; + // copy input state into the working buffer and operate in-place + // state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride. + const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v; memcpy(s_out, s_in, S_v * S_v * sizeof(float)); // attn output pointer for first token of this (head, seq) @@ -10598,6 +10614,15 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( } attn_data += S_v * H; // advance to next token + + if (K > 1) { + const int64_t target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state_o = state_out_base + target_slot * state_size_per_snap + + (iv3 * H + iv1) * S_v * S_v; + memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float)); + } + } } } } diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 6b44bec7317..b4c9845e7a7 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,6 +1,6 @@ #include "gated_delta_net.cuh" -template +template __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) gated_delta_net_cuda(const float * q, const float * k, @@ -23,7 +23,8 @@ gated_delta_net_cuda(const float * q, int64_t sb3, const uint3 neqk1_magic, const uint3 rq3_magic, - float scale) { + float scale, + int K) { const uint32_t h_idx = blockIdx.x; const uint32_t sequence = blockIdx.y; // each warp owns one column, using warp-level primitives to reduce across rows @@ -37,9 +38,13 @@ gated_delta_net_cuda(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; - state += state_offset; - curr_state += state_offset + col * S_v; + // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. + const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + state += state_out_offset; + curr_state += state_in_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; @@ -54,6 +59,10 @@ gated_delta_net_cuda(const float * q, s_shard[r] = curr_state[i]; } + // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots + // are written; earlier slots are left untouched (caller-owned). + const int shift = (int) n_tokens - K; + for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -135,17 +144,30 @@ gated_delta_net_cuda(const float * q, } attn_data += S_v * H; + + if constexpr (keep_rs_t) { + const int target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + curr_state[col * S_v + i] = s_shard[r]; + } + } + } } - // Write state back to global memory (transposed layout) + if constexpr (!keep_rs_t) { #pragma unroll - for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - state[col * S_v + i] = s_shard[r]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[col * S_v + i] = s_shard[r]; + } } } -template +template static void launch_gated_delta_net( const float * q_d, const float * k_d, const float * v_d, const float * g_d, const float * b_d, const float * s_d, @@ -155,7 +177,7 @@ static void launch_gated_delta_net( int64_t sv1, int64_t sv2, int64_t sv3, int64_t sb1, int64_t sb2, int64_t sb3, int64_t neqk1, int64_t rq3, - float scale, cudaStream_t stream) { + float scale, int K, cudaStream_t stream) { //TODO: Add chunked kernel for even faster pre-fill const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const int num_warps = 4; @@ -169,29 +191,29 @@ static void launch_gated_delta_net( switch (S_v) { case 16: - gated_delta_net_cuda<16, KDA><<>>( + gated_delta_net_cuda<16, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 32: - gated_delta_net_cuda<32, KDA><<>>( + gated_delta_net_cuda<32, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 64: { - gated_delta_net_cuda<64, KDA><<>>( + gated_delta_net_cuda<64, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } case 128: { - gated_delta_net_cuda<128, KDA><<>>( + gated_delta_net_cuda<128, KDA, keep_rs_t><<>>( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } default: @@ -261,13 +283,29 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int K = (int) src_state->ne[1]; + const bool keep_rs = K > 1; + if (kda) { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } else { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index f0147af84c1..e288a27f992 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -590,6 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( const int ne20 = op->src[2]->ne[0]; // S_v const int ne21 = op->src[2]->ne[1]; // H const int ne30 = op->src[3]->ne[0]; // G + // state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int K = op->src[5]->ne[1]; const int nsg = op->src[2]->ne[0]/32; @@ -598,7 +600,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( GGML_ASSERT(ne20 % 32 == 0); snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg); - snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30); + snprintf(name, 256, "%s_ne20=%d_ne30=%d_K=%d", base, ne20, ne30, K); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -606,6 +608,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0); ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1); + ggml_metal_cv_set_int16(cv, K, FC_GATED_DELTA_NET + 2); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2d45de8cce2..f6ffb2b3a1c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2531,6 +2531,7 @@ kernel void kernel_rwkv_wkv7_f32( constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; +constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]]; #if 1 template @@ -2548,21 +2549,24 @@ kernel void kernel_gated_delta_net_impl( uint3 ntg[[threads_per_threadgroup]]) { #define S_v FC_gated_delta_net_ne20 #define G FC_gated_delta_net_ne30 +#define K FC_gated_delta_net_K const uint tx = tpitg.x; const uint ty = tpitg.y; - const uint i23 = tgpig.z; // B - const uint i21 = tgpig.y; // H - const uint i20 = tgpig.x*NSG + ty; + const uint i23 = tgpig.z; // B (n_seqs) + const uint i21 = tgpig.y; // H (head) + const uint i20 = tgpig.x*NSG + ty; // row within S_v const uint i01 = i21 % args.ne01; const uint i11 = i21 % args.ne11; const float scale = 1.0f / sqrt((float)S_v); + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous - device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v; + device const float * s_ptr = (device const float *) (s) + state_in_base; float ls[NSG]; @@ -2580,6 +2584,17 @@ kernel void kernel_gated_delta_net_impl( device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int shift = (int)args.ne22 - (int)K; + + // output state base offset: after attention scores + const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23; + // output state per-slot size: S_v * S_v * H * n_seqs + const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23; + // per-(seq,head) offset within a slot + const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + for (short t = 0; t < args.ne22; t++) { float s_k = 0.0f; @@ -2627,17 +2642,30 @@ kernel void kernel_gated_delta_net_impl( b_ptr += args.ne21; g_ptr += args.ne21*G; - } - device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + if (K > 1u) { + const int target_slot = (int)t - shift; + if (target_slot >= 0 && target_slot < (int)K) { + device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } + } + } + } - FOR_UNROLL (short j = 0; j < NSG; j++) { - const short is = tx*NSG + j; - dst_state[is] = ls[j]; + if (K == 1u) { + device float * dst_state = (device float *) (dst) + attn_size + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } } #undef S_v #undef G +#undef K } typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 8c4cf9ef1db..d29a4bab2e2 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1506,6 +1506,7 @@ struct vk_op_gated_delta_net_push_constants { uint32_t sb1, sb2, sb3; uint32_t neq1, rq3; float scale; + uint32_t K; }; struct vk_op_ssm_scan_push_constants { @@ -10767,6 +10768,7 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const ggml_tensor * src_q = dst->src[0]; const ggml_tensor * src_v = dst->src[2]; const ggml_tensor * src_beta = dst->src[4]; + const ggml_tensor * src_state = dst->src[5]; GGML_ASSERT(dst->buffer != nullptr); @@ -10775,6 +10777,9 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const uint32_t n_tokens = (uint32_t)src_v->ne[2]; const uint32_t n_seqs = (uint32_t)src_v->ne[3]; + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const uint32_t K = (uint32_t)src_state->ne[1]; + const uint32_t s_off = S_v * H * n_tokens * n_seqs; vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); @@ -10808,7 +10813,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s sv1, sv2, sv3, sb1, sb2, sb3, neq1, rq3, - scale + scale, + K }; ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index 5e9f8308c1d..33c3202dbb7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -31,6 +31,7 @@ layout(push_constant) uniform Parameters { uint sb1, sb2, sb3; uint neq1, rq3; float scale; + uint K; }; layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; }; @@ -101,13 +102,21 @@ void main() { const uint iq3 = seq_id / rq3; const uint state_size = S_V * S_V; - const uint state_base = (seq_id * H + head_id) * state_size; + // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. + const uint state_in_base = (seq_id * K * H + head_id) * state_size; + // output state layout per slot: same per-(seq,head) offset as the single-slot case. + const uint state_out_base = (seq_id * H + head_id) * state_size; + const uint state_size_per_snap = state_size * H * n_seqs; FLOAT_TYPE s_shard[ROWS_PER_LANE]; [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - s_shard[r] = FLOAT_TYPE(data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]); + s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]); } + // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last + // n_tokens slots are written; earlier slots are left untouched (caller-owned). + const int shift = int(n_tokens) - int(K); + uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; for (uint t = 0; t < n_tokens; t++) { @@ -161,9 +170,21 @@ void main() { } attn_off += S_V * H; + + if (K > 1u) { + const int target_slot = int(t) - shift; + if (target_slot >= 0 && target_slot < int(K)) { + const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } + } + } } - [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { - data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + if (K == 1u) { + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_out_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 191cf2fa106..476c3079795 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6210,11 +6210,13 @@ struct ggml_tensor * ggml_gated_delta_net( GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); GGML_ASSERT(beta->ne[0] == 1); - GGML_ASSERT(ggml_nelements(state) == S_v * S_v * H * n_seqs); - - // concat output and new_state into a single tensor - // output: S_v * H * n_tokens * n_seqs, state: S_v * S_v * H * n_seqs - const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + S_v * n_seqs, 1, 1 }; + // state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count. + GGML_ASSERT(state->ne[0] == S_v * S_v * H); + GGML_ASSERT(state->ne[2] == n_seqs); + GGML_ASSERT(state->ne[3] == 1); + const int64_t K = state->ne[1]; + const int64_t state_rows = K * S_v * n_seqs; + const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); result->op = GGML_OP_GATED_DELTA_NET; From 587dca0eda5168b4dcf77585b15ab33d179b3b27 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 16 May 2026 15:59:09 +0300 Subject: [PATCH 638/831] ggml : bump version to 0.12.0 (ggml/1494) --- ggml/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index bdeca34bf9f..4aac5094d1c 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,8 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 11) -set(GGML_VERSION_PATCH 1) +set(GGML_VERSION_MINOR 12) +set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 3583e35e0d0d07993b96a07788e655acadc249a5 Mon Sep 17 00:00:00 2001 From: Dev-X25874 <283057883+Dev-X25874@users.noreply.github.com> Date: Thu, 21 May 2026 17:28:08 +0530 Subject: [PATCH 639/831] ggml-alloc: fix out-of-bounds read in ggml_dyn_tallocr_remove_block (ggml/1492) --- ggml/src/ggml-alloc.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index a4b01ccf8a1..3bda9abbe03 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -150,7 +150,7 @@ static void ggml_dyn_tallocr_insert_block(struct tallocr_chunk * chunk, size_t o static void ggml_dyn_tallocr_remove_block(struct tallocr_chunk * chunk, int idx) { // shift all elements after idx by 1 to the left, overwriting the element at idx - for (int i = idx; i < chunk->n_free_blocks; i++) { + for (int i = idx; i < chunk->n_free_blocks - 1; i++) { chunk->free_blocks[i] = chunk->free_blocks[i+1]; } chunk->n_free_blocks--; From e78e69301721c8f804397f4cb356bc1a217b39f2 Mon Sep 17 00:00:00 2001 From: Ori Pekelman Date: Thu, 21 May 2026 12:00:16 +0000 Subject: [PATCH 640/831] ggml.h: correct ggml_silu_back arg docstring (a=dy, b=x) (ggml/1500) --- ggml/include/ggml.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 41566d41aef..f6725265504 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1189,8 +1189,8 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // a - x - // b - dy + // a - dy + // b - x GGML_API struct ggml_tensor * ggml_silu_back( struct ggml_context * ctx, struct ggml_tensor * a, From ef5ddecff9c9ece7946048aa4b193825bb916cb7 Mon Sep 17 00:00:00 2001 From: Winston Ma Date: Sun, 17 May 2026 01:57:35 +0800 Subject: [PATCH 641/831] vulkan: removed duplicate #include in headers (llama/23144) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d29a4bab2e2..a296d0ab446 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -49,7 +49,6 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include #include #include -#include #include #include #include From c7dd64c6062adff284b04679612ffb3261eeffba Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 17 May 2026 03:25:50 -0500 Subject: [PATCH 642/831] vulkan: fuse SSM_CONV + BIAS + SILU (llama/22653) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 126 ++++++++++++++++-- .../ggml-vulkan/vulkan-shaders/ssm_conv.comp | 12 +- 2 files changed, 129 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a296d0ab446..d76d4819026 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -854,6 +854,8 @@ struct vk_device_struct { vk_pipeline pipeline_ssm_scan_f32_d128; vk_pipeline pipeline_ssm_scan_f32_d256; vk_pipeline pipeline_ssm_conv_f32; + vk_pipeline pipeline_ssm_conv_silu_f32; + vk_pipeline pipeline_ssm_conv_bias_silu_f32; vk_pipeline pipeline_opt_step_adamw_f32; vk_pipeline pipeline_opt_step_sgd_f32; std::map pipeline_conv2d_f32[CONV_SHAPE_COUNT]; @@ -4900,7 +4902,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); } - ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_silu_f32, "ssm_conv_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_bias_silu_f32, "ssm_conv_bias_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 1, 1}, 1); ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -9936,7 +9940,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_SSM_CONV: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_ssm_conv_f32; + switch (ctx->num_additional_fused_ops) { + case 0: return ctx->device->pipeline_ssm_conv_f32; + case 1: return ctx->device->pipeline_ssm_conv_silu_f32; + case 2: return ctx->device->pipeline_ssm_conv_bias_silu_f32; + default: return nullptr; + } } return nullptr; case GGML_OP_OPT_STEP_ADAMW: @@ -10877,11 +10886,28 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, pc, elements); } -static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; +static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { + ggml_tensor * conv = cgraph->nodes[node_idx]; + const ggml_tensor * src0 = conv->src[0]; + const ggml_tensor * src1 = conv->src[1]; + + // Pick the destination tensor (last node in the fused chain) and the optional bias. + // Fusion modes: 0 = ssm_conv, 1 = ssm_conv+silu, 2 = ssm_conv+add(bias)+silu. + ggml_tensor * dst = conv; + const ggml_tensor * bias = nullptr; + + if (ctx->num_additional_fused_ops == 1) { + dst = cgraph->nodes[node_idx + 1]; // silu + } else if (ctx->num_additional_fused_ops == 2) { + ggml_tensor * add = cgraph->nodes[node_idx + 1]; + bias = (add->src[0] == conv) ? add->src[1] : add->src[0]; + dst = cgraph->nodes[node_idx + 2]; // silu + } - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, { + // The shader always declares 4 bindings; bind src0 as a dummy when bias isn't fused. + const ggml_tensor * src2 = bias ? bias : src0; + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SSM_CONV, { (uint32_t)src0->nb[1], (uint32_t)src0->nb[2], (uint32_t)src1->nb[1], (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2], @@ -13556,7 +13582,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_SSM_CONV: - ggml_vk_ssm_conv(ctx, compute_ctx, node); + ggml_vk_ssm_conv(ctx, compute_ctx, cgraph, node_idx); break; @@ -14453,6 +14479,62 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g return true; } +// Match SSM_CONV + UNARY(SILU) or SSM_CONV + ADD + UNARY(SILU). num_extra is 1 or 2. +static bool ggml_vk_can_fuse_ssm_conv(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, + int node_idx, int num_extra) { + const ggml_tensor * conv = cgraph->nodes[node_idx]; + if (conv->op != GGML_OP_SSM_CONV) { + return false; + } + + const ggml_tensor * silu = nullptr; + const ggml_tensor * bias = nullptr; + + if (num_extra == 1) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_UNARY })) { + return false; + } + silu = cgraph->nodes[node_idx + 1]; + } else if (num_extra == 2) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY })) { + return false; + } + const ggml_tensor * add = cgraph->nodes[node_idx + 1]; + silu = cgraph->nodes[node_idx + 2]; + bias = (add->src[0] == conv) ? add->src[1] : add->src[0]; + + if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) { + return false; + } + // bias must be channel-wise (one element per channel of the conv output) + if (ggml_nelements(bias) != conv->ne[0] || bias->ne[0] != conv->ne[0]) { + return false; + } + if (add->type != GGML_TYPE_F32) { + return false; + } + // The shader doesn't apply per-tensor offsets, so reject misaligned bias. + if (get_misalign_bytes(ctx, bias) != 0) { + return false; + } + } else { + return false; + } + + if (ggml_get_unary_op(silu) != GGML_UNARY_OP_SILU) { + return false; + } + if (conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + // The shader writes to the fused dst using its own strides, but the push constants don't + // carry a per-tensor offset, so the binding must be naturally aligned. + if (get_misalign_bytes(ctx, silu) != 0) { + return false; + } + return true; +} + static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, topk_moe_mode mode) { @@ -14869,6 +14951,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg // they are overwritten, and one workgroup per row. So close enough. op_srcs_fused_elementwise[0] = true; op_srcs_fused_elementwise[1] = true; + } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 2)) { + ctx->num_additional_fused_ops = 2; + fusion_string = "SSM_CONV_BIAS_SILU"; + // ssm_conv reads multiple input tokens per output, so it's not elementwise w.r.t. its srcs. + // The downstream add and silu are elementwise on the conv output. + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; + } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 1)) { + ctx->num_additional_fused_ops = 1; + fusion_string = "SSM_CONV_SILU"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { @@ -15200,7 +15295,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) && - !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) { + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_ADD) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_UNARY)) { ok = false; break; } @@ -15283,6 +15380,19 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } } } + // SSM_CONV + ADD + UNARY: pull the consuming UNARY forward + if (j > 0 && + graph->nodes[j]->op == GGML_OP_ADD && + graph->nodes[j-1]->op == GGML_OP_SSM_CONV) { + for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) { + if (graph->nodes[k]->op == GGML_OP_UNARY && + graph->nodes[k]->src[0] == graph->nodes[j]) { + current_set.push_back(k); + used[k] = true; + break; + } + } + } } } // Second pass grabs view nodes. diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp index 6802b1fc955..4cd9b8da359 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp @@ -6,12 +6,15 @@ layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(constant_id = 1) const uint TOKENS_PER_WG = 16; +layout(constant_id = 2) const bool APPLY_BIAS = false; +layout(constant_id = 3) const bool APPLY_SILU = false; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in; layout(binding = 0) readonly buffer Src0 { float src0[]; }; layout(binding = 1) readonly buffer Src1 { float src1[]; }; -layout(binding = 2) buffer Dst { float dst[]; }; +layout(binding = 2) readonly buffer Bias { float bias[]; }; +layout(binding = 3) buffer Dst { float dst[]; }; layout(push_constant) uniform PushConstants { uint nb01; uint nb02; @@ -45,6 +48,13 @@ void main() { } } + if (APPLY_BIAS) { + sum += bias[i1]; + } + if (APPLY_SILU) { + sum = sum / (1.0f + exp(-sum)); + } + const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1; dst[dst_idx] = sum; } From e417ce7aebd9305ce99d8ba89230b9825dc409c2 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sun, 17 May 2026 04:30:16 -0500 Subject: [PATCH 643/831] vulkan: Support unaligned tensors for ROPE (llama/22637) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 17 +++++++++++++++++ .../ggml-vulkan/vulkan-shaders/rope_funcs.glsl | 7 +++++-- .../ggml-vulkan/vulkan-shaders/rope_params.glsl | 3 +++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d76d4819026..14eab8ea4de 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1354,6 +1354,8 @@ struct vk_op_rope_push_constants { uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t a_offset; + uint32_t d_offset; }; static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128"); @@ -10126,6 +10128,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src3); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_rope_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + p.a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + p.d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(src3); +} + template static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) { VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; @@ -11270,6 +11281,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor * (uint32_t)src0->ne[2], nb01, nb02, nb03, nb11, nb12, nb13, + 0, 0, // a_offset, d_offset filled in by init_pushconst_tensor_offsets }; return rope; @@ -11365,6 +11377,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, GGML_ASSERT(buf[i] != nullptr); } + // a_offset is unused (the fused path reads from shared memory), but the rope/set_rows dst can be misaligned. + // Round the binding offset down to the storage buffer alignment; the in-element shift goes in pc.rope.d_offset. + pc.rope.d_offset = get_misalign_bytes(ctx, tensors[5]) / ggml_type_size(tensors[5]->type); + offset[5] &= ~(size_t(ctx->device->properties.limits.minStorageBufferOffsetAlignment) - 1); + std::array elements; elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] }; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl index 2e53459909d..03358793140 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl @@ -9,7 +9,7 @@ uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, // Per-row offset in shared memory const uint ix = i0; #else - const uint ix = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0; + const uint ix = p.a_offset + i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0; #endif return ix; } @@ -48,6 +48,7 @@ void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_ idst = i1*p.nb11 + i0; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]); @@ -84,6 +85,7 @@ void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_ idst = i1*p.nb11 + i0/2; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); @@ -121,6 +123,7 @@ void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope idst = i1*p.nb11 + i0/2; idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); @@ -176,7 +179,7 @@ void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rop return; } - const uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint idst = p.d_offset + i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); const int sect_dims = p.sections[0] + p.sections[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl index 2e2a7e14c66..3602485b943 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -26,6 +26,9 @@ struct rope_params { uint nb11; uint nb12; uint nb13; + + uint a_offset; + uint d_offset; }; #endif // !defined(GGML_ROPE_PARAMS) From 50482cbd229dda298f753e0ebcb6f73a6c219920 Mon Sep 17 00:00:00 2001 From: Pascal Date: Sun, 17 May 2026 11:31:20 +0200 Subject: [PATCH 644/831] vulkan: add cpy bf16 -> f32 pipelines (llama/22677) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 14 ++++++++++++-- .../ggml-vulkan/vulkan-shaders/contig_copy.comp | 8 ++++++-- ggml/src/ggml-vulkan/vulkan-shaders/copy.comp | 4 +++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 ++ 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 14eab8ea4de..d3fb19048d9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -759,8 +759,8 @@ struct vk_device_struct { vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; - vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; - vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32; @@ -4572,6 +4572,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_bf16_f32,"cpy_bf16_f32",cpy_bf16_f32_len,cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -4580,6 +4581,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_bf16_f32,"contig_cpy_bf16_f32",contig_cpy_bf16_f32_len,contig_cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -7544,6 +7546,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f32_bf16; } } + if (src->type == GGML_TYPE_BF16 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_bf16_f32; + } else { + return ctx->device->pipeline_cpy_bf16_f32; + } + } if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) { if (contig) { return ctx->device->pipeline_contig_cpy_f32_i32; @@ -15974,6 +15983,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (src1_type == GGML_TYPE_F32) { switch (src0_type) { case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp index ca1a3ac25bd..b3b182fb084 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -19,7 +19,9 @@ void main() { if (idx + (num_iter-1)*num_threads < p.ne) { [[unroll]] for (uint i = 0; i < num_iter; ++i) { -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + idx]); data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) @@ -35,7 +37,9 @@ void main() { continue; } -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + idx]); data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp index 9f8bfd3c182..d55e13253a8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -12,7 +12,9 @@ void main() { return; } -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + src0_idx(idx)]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + src0_idx(idx)]); data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d99b2b5d802..e3a9d61a558 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -731,6 +731,7 @@ void process_shaders() { string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("cpy_bf16_f32","copy.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); @@ -738,6 +739,7 @@ void process_shaders() { string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("contig_cpy_bf16_f32","contig_copy.comp",{{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}}); string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); From 9e96e0eaf1847b874dd9fd1734b0c9a469d6548c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Ekstr=C3=B6m?= Date: Sun, 17 May 2026 14:12:11 +0300 Subject: [PATCH 645/831] ggml-vulkan/CMakeLists: add a check for SPIRV-Headers (llama/22009) * ci/run: set explicit SPIR-V Headers search path for macOS vulkan CI For whatever reason, the files are under additional sub-path `vulkan/` under the cmake directory, which does not match either current LunarG macOS Vulkan SDK structure (`lib/cmake/SPIRV-Headers`), nor what gets installed when you run the cmake build+install for SPIRV-Headers itself on at least Linux (`share/cmake/SPIRV-Headers`). This allows for SPIRV-Headers to be found, as currently the CI runner's setup does not seem to include the relevant path in list of search locations. * ggml-vulkan/CMakeLists: add a check for SPIRV-Headers This is installed by the project if it is built and installed. Receiving an error during the configuration step is generally preferred to receiving an error in the middle of a build. --- ggml/src/ggml-vulkan/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 715a263a6d0..6dbcea065b3 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -8,6 +8,8 @@ endif() find_package(Vulkan COMPONENTS glslc REQUIRED) +find_package(SPIRV-Headers REQUIRED) + if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") # Parallel build object files add_definitions(/MP) From 53736a3f0e91e132901a55e562252a3519b669dc Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Sun, 17 May 2026 18:00:10 +0200 Subject: [PATCH 646/831] CUDA: Continue directly including cuda/iterator (llama/23102) Cont of #22936, forgot to update one site --- ggml/src/ggml-cuda/top-k.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu index 59ce36fb1c9..db1d39e2dc7 100644 --- a/ggml/src/ggml-cuda/top-k.cu +++ b/ggml/src/ggml-cuda/top-k.cu @@ -5,6 +5,7 @@ # include # if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) # define CUB_TOP_K_AVAILABLE +# include using namespace cub; # endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2 #endif // GGML_CUDA_USE_CUB From 4fb3ccabd38cf5c1103c545ed1d06cc552a44915 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Sun, 17 May 2026 15:05:11 -0600 Subject: [PATCH 647/831] feat: Support d_conv=15 for ssm-conv.cu (llama/23017) Branch: ModalityConditionalAdapters AI-usage: none Signed-off-by: Gabe Goodhart --- ggml/src/ggml-cuda/ssm-conv.cu | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 4841389fbc8..4c4daf85dc6 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -140,11 +140,12 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa }; switch (nc) { - case 3: launch_kernel(std::integral_constant{}); break; - case 4: launch_kernel(std::integral_constant{}); break; - case 5: launch_kernel(std::integral_constant{}); break; - case 9: launch_kernel(std::integral_constant{}); break; - default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9 right now."); + case 3: launch_kernel(std::integral_constant{}); break; + case 4: launch_kernel(std::integral_constant{}); break; + case 5: launch_kernel(std::integral_constant{}); break; + case 9: launch_kernel(std::integral_constant{}); break; + case 15: launch_kernel(std::integral_constant{}); break; + default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9, 15 right now."); } } From 619262ad247dcaa319e06bd505f691ae1920b019 Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Sun, 17 May 2026 22:11:51 -0700 Subject: [PATCH 648/831] sycl: route small f32 matmuls to oneMKL, bypass oneDNN (llama/22150) Signed-off-by: Chun Tao Co-authored-by: Chun Tao --- ggml/src/ggml-sycl/ggml-sycl.cpp | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index f5d10b56de0..ebe7c5b351c 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -2385,21 +2385,25 @@ inline void ggml_sycl_op_mul_mat_sycl( const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get(); const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get(); + { + const int64_t gemm_flops = (int64_t)row_diff * src1_ncols * ne10; + const bool use_mkl_direct = gemm_flops < 256 * 256 * 256; #if GGML_SYCL_DNNL - if (!g_ggml_sycl_disable_dnn) { - DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, - DnnlGemmWrapper::to_dt(), src1_ddf1_i, DnnlGemmWrapper::to_dt(), - dst_dd_i, DnnlGemmWrapper::to_dt(), stream); - } - else + if (!g_ggml_sycl_disable_dnn && !use_mkl_direct) { + DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, + DnnlGemmWrapper::to_dt(), src1_ddf1_i, DnnlGemmWrapper::to_dt(), + dst_dd_i, DnnlGemmWrapper::to_dt(), stream); + } + else #endif - { - const float alpha = 1.0f; - const float beta = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( - *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, - src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, - dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + { + const float alpha = 1.0f; + const float beta = 0.0f; + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( + *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, + src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, + dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + } } } GGML_UNUSED(dst); From c65b082c947effa60f679f78065a82929e7e52b1 Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Sun, 17 May 2026 22:12:21 -0700 Subject: [PATCH 649/831] sycl: scalar SWAR byte-subtract in Q6_K MMVQ dot product (llama/22156) Signed-off-by: Chun Tao Co-authored-by: Chun Tao --- ggml/src/ggml-sycl/vecdotq.hpp | 99 ++++++++++++++++------------------ 1 file changed, 46 insertions(+), 53 deletions(-) diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index d7770047424..16b2d65d271 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -85,6 +85,32 @@ static __dpct_inline__ int get_int_from_uint8_aligned( (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment } +static __dpct_inline__ int byte_sub_4(const int a, const int b) { + const uint32_t ua = static_cast(a); + const uint32_t ub = static_cast(b); + return static_cast(((ua | 0x80808080u) - ub) ^ 0x80808080u); +} + +static __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq_scalar( + const int vl, const int vh, const int u0, const int u1, const int8_t sc0, + const int8_t sc1, const float d, const float d80, const float d81) { + static_assert(QR6_K == 2, "q6_K MMVQ scalar fast path assumes QR6_K == 2"); + + const int vil0 = (vl >> 0) & 0x0F0F0F0F; + const int vih0 = ((vh >> 0) << 4) & 0x30303030; + const int vi0 = byte_sub_4(vil0 | vih0, 0x20202020); + + const int vil1 = (vl >> 4) & 0x0F0F0F0F; + const int vih1 = ((vh >> 4) << 4) & 0x30303030; + const int vi1 = byte_sub_4(vil1 | vih1, 0x20202020); + + const float sumf = + d80 * (dpct::dp4a(vi0, u0, 0) * sc0) + + d81 * (dpct::dp4a(vi1, u1, 0) * sc1); + + return d * sumf; +} + static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4, const uint8_t *values, int &val1, int &val2) { @@ -279,24 +305,8 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh, const int *__restrict__ u, const int8_t *__restrict__ scales, const float &d, const float *__restrict__ d8) { - - float sumf = 0.0f; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - const int sc = scales[4*i]; - - const int vil = (vl >> (4*i)) & 0x0F0F0F0F; - - const int vih = ((vh >> (4*i)) << 4) & 0x30303030; - - const int vi = dpct::vectorized_binary( - (vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32 - - sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product - } - - return d*sumf; + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]); } // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called @@ -542,23 +552,8 @@ template <> struct reorder_vec_dot_q_sycl { __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u, const int8_t * __restrict__ scales, const float d, const float * __restrict__ d8) { - float sumf = 0.0f; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - const int sc = scales[4 * i]; - - const int vil = (vl >> (4 * i)) & 0x0F0F0F0F; - - const int vih = ((vh >> (4 * i)) << 4) & 0x30303030; - - const int vi = dpct::vectorized_binary((vil | vih), 0x20202020, - dpct::sub_sat()); // vi = (vil | vih) - 32 - - sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product - } - - return d * sumf; + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]); } __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, @@ -579,16 +574,15 @@ template <> struct reorder_vec_dot_q_sycl { const int8_t * scs = scales + scale_offset; - int u[QR6_K]; - float d8[QR6_K]; + const int u0 = get_int_from_int8_aligned( + q8_1_quant_ptr + bq8_offset * QK8_1, iqs % QI8_1); + const int u1 = get_int_from_int8_aligned( + q8_1_quant_ptr + (bq8_offset + 2) * QK8_1, iqs % QI8_1); + const float d80 = (*(q8_1_ds + bq8_offset + 0))[0]; + const float d81 = (*(q8_1_ds + bq8_offset + 2))[0]; -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1); - const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i); - d8[i] = ds_values[0]; - } - return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8); + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u0, u1, scs[0], scs[4], *d, d80, d81); } }; #define VDR_Q4_0_Q8_1_MMVQ 2 @@ -1167,16 +1161,15 @@ vec_dot_q6_K_q8_1(const void *__restrict__ vbq, const int8_t * scales = bq6_K->scales + scale_offset; - int u[QR6_K]; - float d8[QR6_K]; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); - d8[i] = bq8_1[bq8_offset + 2 * i].ds[0]; - } + const int u0 = get_int_from_int8_aligned( + bq8_1[bq8_offset + 0].qs, iqs % QI8_1); + const int u1 = get_int_from_int8_aligned( + bq8_1[bq8_offset + 2].qs, iqs % QI8_1); + const float d80 = bq8_1[bq8_offset + 0].ds[0]; + const float d81 = bq8_1[bq8_offset + 2].ds[0]; - return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u0, u1, scales[0], scales[4], bq6_K->d, d80, d81); } From 0a11c9fe835b21d484e9161524d2e9dc2288fb12 Mon Sep 17 00:00:00 2001 From: Pranav Dhinakar Date: Mon, 18 May 2026 13:39:36 -0700 Subject: [PATCH 650/831] ggml-hexagon: add PAD op HVX kernel (llama/23078) * ggml-hexagon: add PAD op HVX kernel Implements GGML_OP_PAD on the Hexagon HTP backend using HVX vectorized kernels. Supports zero-padding and circular padding across all 4 tensor dimensions. * hex-ggml: remove duplicate op cases (merge conflict) * hex-pad: fix editorconfig checks and macro alignment --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 18 + ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + ggml/src/ggml-hexagon/htp/pad-ops.c | 545 +++++++++++++++++++++++ 6 files changed, 569 insertions(+) create mode 100644 ggml/src/ggml-hexagon/htp/pad-ops.c diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 3d1c9da8329..c24a2305e4c 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2744,6 +2744,18 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * return true; } +static bool ggml_hexagon_supported_pad(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * src0 = op->src[0]; const struct ggml_tensor * dst = op; @@ -2857,6 +2869,8 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_FILL: return HTP_OP_FILL; case GGML_OP_DIAG: return HTP_OP_DIAG; case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; + case GGML_OP_PAD: return HTP_OP_PAD; + case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; @@ -3416,6 +3430,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_solve_tri(sess, op); break; + case GGML_OP_PAD: + supp = ggml_hexagon_supported_pad(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index bcadac11f95..36f923243cd 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -38,6 +38,7 @@ add_library(${HTP_LIB} SHARED diag-ops.c solve-tri-ops.c gated-delta-net-ops.c + pad-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 92f02eac6e3..e500ce46212 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -107,5 +107,6 @@ int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); int op_solve_tri(struct htp_ops_context * octx); int op_gated_delta_net(struct htp_ops_context * octx); +int op_pad(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 98db864dd42..985ded6f299 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -86,6 +86,7 @@ enum htp_op_code { HTP_OP_SOLVE_TRI, HTP_OP_L2_NORM, HTP_OP_GATED_DELTA_NET, + HTP_OP_PAD, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 883a31d6163..85569f07289 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -595,6 +595,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_SOLVE_TRI: return op_solve_tri(octx); + case HTP_OP_PAD: + return op_pad(octx); + case HTP_OP_GATED_DELTA_NET: return op_gated_delta_net(octx); diff --git a/ggml/src/ggml-hexagon/htp/pad-ops.c b/ggml/src/ggml-hexagon/htp/pad-ops.c new file mode 100644 index 00000000000..3abc3c2ead1 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/pad-ops.c @@ -0,0 +1,545 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" + +/* Circular wrap: maps any integer x into [0, n) */ +static inline uint32_t wrap_around(int32_t x, uint32_t n) { + return (uint32_t)(((x % (int32_t)n) + (int32_t)n) % (int32_t)n); +} + +/* Decompose a flat dst row index into (i1, i2, i3) */ +static inline void pad_decompose_row(uint32_t ir, uint32_t ne1, uint32_t ne2, + uint32_t *i1, uint32_t *i2, uint32_t *i3) { + *i1 = ir % ne1; + *i2 = (ir / ne1) % ne2; + *i3 = ir / (ne1 * ne2); +} + +/* Return non-zero if row (i1,i2,i3) falls in the non-padded interior */ +static inline int pad_is_interior(uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t rp1, uint32_t ne1, + int32_t lp2, int32_t rp2, uint32_t ne2, + int32_t lp3, int32_t rp3, uint32_t ne3) { + return ((int32_t)i1 >= lp1 && (int32_t)i1 < (int32_t)ne1 - rp1) && + ((int32_t)i2 >= lp2 && (int32_t)i2 < (int32_t)ne2 - rp2) && + ((int32_t)i3 >= lp3 && (int32_t)i3 < (int32_t)ne3 - rp3); +} + +/* Compute the DDR src row pointer for a zero-pad interior row */ +static inline const uint8_t * pad_src_row_ptr(const struct htp_tensor * src, + uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t lp2, int32_t lp3) { + return (const uint8_t *) src->data + + (i1 - (uint32_t)lp1) * src->nb[1] + + (i2 - (uint32_t)lp2) * src->nb[2] + + (i3 - (uint32_t)lp3) * src->nb[3]; +} + +/* Compute the DDR src row pointer for a circular row (wrap-around indexing) */ +static inline const uint8_t * pad_circ_src_row_ptr(const struct htp_tensor * src, + uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t lp2, int32_t lp3) { + return (const uint8_t *) src->data + + wrap_around((int32_t)i1 - lp1, src->ne[1]) * src->nb[1] + + wrap_around((int32_t)i2 - lp2, src->ne[2]) * src->nb[2] + + wrap_around((int32_t)i3 - lp3, src->ne[3]) * src->nb[3]; +} + +struct htp_pad_context { + struct htp_ops_context * octx; + + int32_t lp0, rp0; + int32_t lp1, rp1; + int32_t lp2, rp2; + int32_t lp3, rp3; + + uint32_t nrows_per_thread; + uint32_t total_dst_rows; + + size_t type_size; + + // Row sizes for DMA kernel (populated when VTCM is available) + size_t src_row_size; + size_t src_row_size_aligned; + size_t dst_row_size; + size_t dst_row_size_aligned; +}; + +#define htp_pad_preamble \ + const struct htp_tensor * src = octx->src[0]; \ + const struct htp_tensor * dst = octx->dst; \ + \ + const uint32_t ne00 = src->ne[0]; \ + const uint32_t nb00 = src->nb[0]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ + const int32_t lp0 = pctx->lp0, rp0 = pctx->rp0; \ + const int32_t lp1 = pctx->lp1, rp1 = pctx->rp1; \ + const int32_t lp2 = pctx->lp2, rp2 = pctx->rp2; \ + const int32_t lp3 = pctx->lp3, rp3 = pctx->rp3; \ + \ + const size_t type_size = pctx->type_size; \ + \ + const uint32_t row_start = pctx->nrows_per_thread * ith; \ + const uint32_t row_end = MIN(row_start + pctx->nrows_per_thread, pctx->total_dst_rows); + + +#define htp_pad_dma_preamble \ + const size_t src_row_size = pctx->src_row_size; \ + const size_t src_row_size_aligned = pctx->src_row_size_aligned; \ + const size_t dst_row_size = pctx->dst_row_size; \ + const size_t dst_row_size_aligned = pctx->dst_row_size_aligned; \ + \ + uint8_t * src_spad_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; \ + uint8_t * dst_spad_base = octx->dst_spad.data + ith * octx->dst_spad.size_per_thread; \ + \ + dma_queue * dma = octx->ctx->dma[ith]; + +// --------------------------------------------------------------------------- +// HVX vectorized PAD kernel +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + uint32_t i1, i2, i3; + pad_decompose_row(dst_row, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + if (!interior) { + hvx_splat_f32_u(dst_ptr, 0.0f, ne0); + } else { + const uint8_t * src_ptr = pad_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3); + + if (lp0 > 0) { + hvx_splat_f32_u(dst_ptr, 0.0f, (uint32_t)lp0); + } + + uint8_t * dst_row_start = dst_ptr + (size_t)lp0 * type_size; + if (nb00 == type_size) { + hvx_copy_f32_uu(dst_row_start, src_ptr, ne00); + } else { + for (uint32_t i = 0; i < ne00; i++) { + memcpy(dst_row_start + i * type_size, + src_ptr + (size_t)i * nb00, + type_size); + } + } + + if (rp0 > 0) { + hvx_splat_f32_u(dst_ptr + ((size_t)lp0 + ne00) * type_size, 0.0f, (uint32_t)rp0); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX + DMA PAD kernel — aligned, double-buffered +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_dma(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + htp_pad_dma_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + // ----------------------------------------------------------------------- + // Priming phase: push 2 pairs of (dummy_dst_DMA, src_DMA) to seed the + // double-buffer pipeline before the main loop begins. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start, spad_idx = 0; ir < row_end && spad_idx < 2; ir++, spad_idx++) { + uint8_t * src_spad_cur = src_spad_base + spad_idx * src_row_size_aligned; + uint8_t * dst_spad_cur = dst_spad_base + spad_idx * dst_row_size_aligned; + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr((uint8_t *)dst->data, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 0); + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + const uint8_t * src_ptr = interior + ? pad_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3) : NULL; + + // Interior row: real DMA (1 row) from DDR to VTCM. + // Border row: null DMA (nrows=0) + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + src_ptr ? src_ptr : (const uint8_t *)src_spad_cur), + src_row_size_aligned, src_row_size, src_ptr ? 1 : 0); + } + + // ----------------------------------------------------------------------- + // Main loop: pop completed DMAs, compute in VTCM with aligned HVX ops, + // push dst DMA and prefetch src for the next+1 row. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start; ir < row_end; ir++) { + uint8_t * dst_spad_cur = (uint8_t *) dma_queue_pop(dma).src; + uint8_t * src_spad_cur = (uint8_t *) dma_queue_pop(dma).dst; + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + if (!interior) { + hvx_splat_f32_a(dst_spad_cur, 0.0f, ne0); + } else { + hvx_splat_f32_a(dst_spad_cur, 0.0f, ne0); + + uint8_t * dst_interior = dst_spad_cur + (size_t)lp0 * type_size; + + if ((uintptr_t)dst_interior % VLEN == 0) { + hvx_copy_f32_aa(dst_interior, src_spad_cur, ne00); + } else { + hvx_copy_f32_ua(dst_interior, src_spad_cur, ne00); + } + } + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr(dst_ptr, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 1); + + const uint32_t next_row = ir + 2; + if (next_row < row_end) { + uint32_t ni1, ni2, ni3; + pad_decompose_row(next_row, ne1, ne2, &ni1, &ni2, &ni3); + const int next_interior = pad_is_interior(ni1, ni2, ni3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + const uint8_t * next_src_ptr = next_interior + ? pad_src_row_ptr(src, ni1, ni2, ni3, lp1, lp2, lp3) : NULL; + + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + next_src_ptr ? next_src_ptr : (const uint8_t *)src_spad_cur), + src_row_size_aligned, src_row_size, next_src_ptr ? 1 : 0); + } + } + + dma_queue_flush(dma); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-dma %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX circular PAD kernel +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_circular(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + uint32_t i1, i2, i3; + pad_decompose_row(dst_row, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + const uint8_t * src_row = pad_circ_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3); + + if (nb00 == type_size) { + + if (lp0 > 0) { + if ((uint32_t)lp0 < 32) { + memcpy(dst_ptr, + src_row + (size_t)(ne00 - (uint32_t)lp0) * type_size, + (size_t)lp0 * type_size); + } else { + hvx_copy_f32_uu(dst_ptr, + src_row + (size_t)(ne00 - (uint32_t)lp0) * type_size, + (uint32_t)lp0); + } + } + hvx_copy_f32_uu(dst_ptr + (size_t)lp0 * type_size, src_row, ne00); + if (rp0 > 0) { + if ((uint32_t)rp0 < 32) { + memcpy(dst_ptr + ((size_t)lp0 + ne00) * type_size, + src_row, + (size_t)rp0 * type_size); + } else { + hvx_copy_f32_uu(dst_ptr + ((size_t)lp0 + ne00) * type_size, + src_row, + (uint32_t)rp0); + } + } + } else { + for (uint32_t i = 0; i < (uint32_t)lp0; i++) { + *(float *)(dst_ptr + i * type_size) = + *(const float *)(src_row + (size_t)(ne00 - (uint32_t)lp0 + i) * nb00); + } + for (uint32_t i = 0; i < ne00; i++) { + *(float *)(dst_ptr + ((size_t)lp0 + i) * type_size) = + *(const float *)(src_row + (size_t)i * nb00); + } + for (uint32_t i = 0; i < (uint32_t)rp0; i++) { + *(float *)(dst_ptr + ((size_t)lp0 + ne00 + i) * type_size) = + *(const float *)(src_row + (size_t)i * nb00); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-circ %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX + DMA circular PAD kernel — aligned, double-buffered +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_circular_dma(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + htp_pad_dma_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + // ----------------------------------------------------------------------- + // Priming phase: push 2 pairs of (dummy_dst_DMA, src_DMA) to seed the + // double-buffer pipeline. Every row is a real src DMA (no null DMAs). + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start, spad_idx = 0; ir < row_end && spad_idx < 2; ir++, spad_idx++) { + uint8_t * src_spad_cur = src_spad_base + spad_idx * src_row_size_aligned; + uint8_t * dst_spad_cur = dst_spad_base + spad_idx * dst_row_size_aligned; + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr((uint8_t *)dst->data, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 0); + + uint32_t pi1, pi2, pi3; + pad_decompose_row(ir, ne1, ne2, &pi1, &pi2, &pi3); + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, pad_circ_src_row_ptr(src, pi1, pi2, pi3, lp1, lp2, lp3)), + src_row_size_aligned, src_row_size, 1); + } + + // ----------------------------------------------------------------------- + // Main loop: pop completed DMAs, assemble circular row in VTCM with + // aligned HVX ops, push dst DMA and prefetch src for the next+1 row. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start; ir < row_end; ir++) { + uint8_t * dst_spad_cur = (uint8_t *) dma_queue_pop(dma).src; + uint8_t * src_spad_cur = (uint8_t *) dma_queue_pop(dma).dst; + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + + if (lp0 > 0) { + uint8_t * dst_left = dst_spad_cur; + const uint8_t * src_left = src_spad_cur + (size_t)(ne00 - (uint32_t)lp0) * type_size; + if ((uint32_t)lp0 < 32) { + memcpy(dst_left, src_left, (size_t)lp0 * type_size); + } else { + hvx_copy_f32_uu(dst_left, src_left, (uint32_t)lp0); + } + } + + { + uint8_t * dst_mid = dst_spad_cur + (size_t)lp0 * type_size; + if ((uintptr_t)dst_mid % VLEN == 0) { + hvx_copy_f32_aa(dst_mid, src_spad_cur, ne00); + } else { + hvx_copy_f32_ua(dst_mid, src_spad_cur, ne00); + } + } + + if (rp0 > 0) { + uint8_t * dst_right = dst_spad_cur + ((size_t)lp0 + ne00) * type_size; + if ((uint32_t)rp0 < 32) { + memcpy(dst_right, src_spad_cur, (size_t)rp0 * type_size); + } else { + if ((uintptr_t)dst_right % VLEN == 0) { + hvx_copy_f32_aa(dst_right, src_spad_cur, (uint32_t)rp0); + } else { + hvx_copy_f32_ua(dst_right, src_spad_cur, (uint32_t)rp0); + } + } + } + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr(dst_ptr, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 1); + + const uint32_t next_row = ir + 2; + if (next_row < row_end) { + uint32_t nri1, nri2, nri3; + pad_decompose_row(next_row, ne1, ne2, &nri1, &nri2, &nri3); + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + pad_circ_src_row_ptr(src, nri1, nri2, nri3, lp1, lp2, lp3)), + src_row_size_aligned, src_row_size, 1); + } + } + + dma_queue_flush(dma); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-circ-dma %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_pad(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + // Only F32 supported + size_t type_size; + switch (src0->type) { + case HTP_TYPE_F32: type_size = 4; break; + default: + FARF(ERROR, "pad-hvx: unsupported type %u\n", src0->type); + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const int32_t lp0 = octx->op_params[0]; + const int32_t rp0 = octx->op_params[1]; + const int32_t lp1 = octx->op_params[2]; + const int32_t rp1 = octx->op_params[3]; + const int32_t lp2 = octx->op_params[4]; + const int32_t rp2 = octx->op_params[5]; + const int32_t lp3 = octx->op_params[6]; + const int32_t rp3 = octx->op_params[7]; + const int32_t circular = octx->op_params[8]; + + const uint32_t ne0 = dst->ne[0]; + const uint32_t ne00 = src0->ne[0]; + + const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows > 0 ? total_dst_rows : 1); + + const size_t src_row_size = (size_t)ne00 * type_size; + const size_t dst_row_size = (size_t)ne0 * type_size; + const size_t src_row_size_aligned = hex_round_up(src_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // Total VTCM needed: 2 buffers (ping+pong) for src and dst, per thread + const size_t vtcm_needed = (size_t)n_threads * 2 * (src_row_size_aligned + dst_row_size_aligned); + + const int use_dma = (src0->nb[0] == (uint32_t)type_size) && + (ne00 >= 512) && + (octx->ctx->vtcm_base != NULL) && + (octx->ctx->vtcm_size >= vtcm_needed); + + if (use_dma) { + octx->src0_spad.size_per_thread = 2 * src_row_size_aligned; + octx->dst_spad.size_per_thread = 2 * dst_row_size_aligned; + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + } + + struct htp_pad_context pctx = { + .octx = octx, + .lp0 = lp0, .rp0 = rp0, + .lp1 = lp1, .rp1 = rp1, + .lp2 = lp2, .rp2 = rp2, + .lp3 = lp3, .rp3 = rp3, + .nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads, + .total_dst_rows = total_dst_rows, + .type_size = type_size, + .src_row_size = src_row_size, + .src_row_size_aligned = src_row_size_aligned, + .dst_row_size = dst_row_size, + .dst_row_size_aligned = dst_row_size_aligned, + }; + + FARF(HIGH, "pad-hvx%s%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) pads=(%d,%d,%d,%d,%d,%d,%d,%d)\n", + circular ? "-circ" : "", + use_dma ? "-dma" : "", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + + if (circular && use_dma) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_circular_dma, &pctx, n_threads); } + else if (circular) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_circular, &pctx, n_threads); } + else if (use_dma) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_dma, &pctx, n_threads); } + else { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx, &pctx, n_threads); } + + return HTP_STATUS_OK; +} + From eb558f23cb6d2ea78d5e1401d62bebfa770446c9 Mon Sep 17 00:00:00 2001 From: Pranav Dhinakar Date: Mon, 18 May 2026 14:04:57 -0700 Subject: [PATCH 651/831] hexagon: add support for TRI op (llama/22822) * Hexagon: TRI HVX Kernel addition to ggml hexagon HTP ops and context * addressed PR review comments for TRI op * hexagon: clang format * hex-unary: remove merge conflict markers * hex-ggml: remove duplicate op cases (merge conflict) * hex-ggml: fix editor config errors --------- Co-authored-by: Todor Boinovski Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 20 +++++ ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 3 + ggml/src/ggml-hexagon/htp/unary-ops.c | 113 ++++++++++++++++++++++++- 5 files changed, 137 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index c24a2305e4c..2f75e97ac66 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2828,6 +2828,21 @@ static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session * return true; } +static bool ggml_hexagon_supported_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32) { return false; } + if (dst->type != GGML_TYPE_F32) { return false; } + if (!ggml_are_same_shape(src0, dst)) { return false; } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { return false; } + + return true; + + GGML_UNUSED(sess); +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->c_name(); @@ -2869,6 +2884,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_FILL: return HTP_OP_FILL; case GGML_OP_DIAG: return HTP_OP_DIAG; case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; + case GGML_OP_TRI: return HTP_OP_TRI; case GGML_OP_PAD: return HTP_OP_PAD; case GGML_OP_UNARY: @@ -3430,6 +3446,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_solve_tri(sess, op); break; + case GGML_OP_TRI: + supp = ggml_hexagon_supported_tri(sess, op); + break; + case GGML_OP_PAD: supp = ggml_hexagon_supported_pad(sess, op); break; diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index e500ce46212..6fe3e6c7d85 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -107,6 +107,7 @@ int op_fill(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); int op_solve_tri(struct htp_ops_context * octx); int op_gated_delta_net(struct htp_ops_context * octx); +int op_tri(struct htp_ops_context * octx); int op_pad(struct htp_ops_context * octx); #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 985ded6f299..676e948a439 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -86,6 +86,7 @@ enum htp_op_code { HTP_OP_SOLVE_TRI, HTP_OP_L2_NORM, HTP_OP_GATED_DELTA_NET, + HTP_OP_TRI, HTP_OP_PAD, HTP_OP_INVALID diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 85569f07289..12003c1fd8a 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -601,6 +601,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_GATED_DELTA_NET: return op_gated_delta_net(octx); + case HTP_OP_TRI: + return op_tri(octx); + case HTP_OP_INVALID: break; diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index d4ae89ee6f0..1ce881353ec 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -17,7 +17,6 @@ #include "ggml-common.h" #include "htp-ctx.h" #include "htp-ops.h" -#include "htp-ops.h" struct htp_unary_context { struct htp_ops_context * octx; @@ -277,6 +276,95 @@ static void sigmoid_f32(const float * restrict src, } } +static void tri_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + const uint32_t ir, + const struct htp_unary_context * uctx) { + + const int32_t ttype = op_params[0]; + const HVX_Vector zero = hvx_vec_splat_f32(0.0f); + const uint32_t nvec = row_elems / VLEN_FP32; + const uint32_t nloe = row_elems % VLEN_FP32; + + const uint32_t ne01 = uctx->octx->src[0]->ne[1]; + + for (uint32_t b = 0; b < num_rows; b++) { + const uint32_t abs_row = ir + b; + const uint32_t i01 = abs_row % ne01; + + const HVX_Vector * restrict v_src = (const HVX_Vector *) ((const uint8_t *) src + b * row_size); + HVX_Vector * restrict v_dst = (HVX_Vector *) ((uint8_t *) dst + b * row_size); + + uint32_t boundary; + int keep_left; + switch (ttype) { + case 0: boundary = i01; keep_left = 0; break; // keep col >= row + case 1: boundary = i01 + 1; keep_left = 0; break; // keep col > row + case 2: boundary = i01 + 1; keep_left = 1; break; // keep col <= row + case 3: boundary = i01; keep_left = 1; break; // keep col < row + default: boundary = 0; keep_left = 0; break; + } + if (boundary > row_elems) boundary = row_elems; + + // Full HVX vectors — each starts at a 128-byte aligned offset + for (uint32_t i = 0; i < nvec; i++) { + const uint32_t vec_start = i * VLEN_FP32; + const uint32_t vec_end = vec_start + VLEN_FP32; + if (keep_left) { + if (vec_end <= boundary) { + v_dst[i] = v_src[i]; + } else if (vec_start >= boundary) { + v_dst[i] = zero; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + v_dst[i] = Q6_V_vmux_QVV(mask, v_src[i], zero); + } + } else { + if (vec_end <= boundary) { + v_dst[i] = zero; + } else if (vec_start >= boundary) { + v_dst[i] = v_src[i]; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + v_dst[i] = Q6_V_vmux_QVV(mask, zero, v_src[i]); + } + } + } + + // Tail elements (row_elems not a multiple of VLEN_FP32) + if (nloe > 0) { + const uint32_t vec_start = nvec * VLEN_FP32; + const uint32_t vec_end = vec_start + nloe; + HVX_Vector tail_val; + if (keep_left) { + if (vec_end <= boundary) { + tail_val = v_src[nvec]; + } else if (vec_start >= boundary) { + tail_val = zero; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + tail_val = Q6_V_vmux_QVV(mask, v_src[nvec], zero); + } + } else { + if (vec_end <= boundary) { + tail_val = zero; + } else if (vec_start >= boundary) { + tail_val = v_src[nvec]; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + tail_val = Q6_V_vmux_QVV(mask, zero, v_src[nvec]); + } + } + hvx_vec_store_a(&v_dst[nvec], nloe * sizeof(float), tail_val); + } + } +} + static void softplus_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -498,6 +586,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_L2_NORM: l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_TRI: + tri_f32(src0_spad, dst_spad, NULL, block_size, ne00, src0_row_size_aligned, op_params, ir, uctx); + break; default: break; } @@ -571,6 +662,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_L2_NORM: op_type = "l2norm-f32"; break; + case HTP_OP_TRI: + op_type = "tri-f32"; + break; + default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); return HTP_STATUS_NO_SUPPORT; @@ -640,6 +735,22 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { return err; } +int op_tri(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + switch (octx->src[0]->type) { + case HTP_TYPE_F32: + err = execute_op_unary_f32(octx); + break; + + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} + int op_unary(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; From 3477fdb2e3c5b8828aa06da944c644b32bb7b6de Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Tue, 19 May 2026 09:42:36 +0300 Subject: [PATCH 652/831] rpc : keep last_graph_uid in the device context (llama/23273) With the introduction of MTP we can have multiple compute contexts for the same RPC device. In this case last_graph_uid is not updated properly when contexts are being switched. This patch fixes this by moving last_graph_uid to the device context, making sure it is always updated. closes: #23242 --- ggml/src/ggml-rpc/ggml-rpc.cpp | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 1cb8f563d85..d3805772183 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -199,6 +199,14 @@ static ggml_guid_t ggml_backend_rpc_guid() { return &guid; } +struct ggml_backend_rpc_device_context { + std::string endpoint; + uint32_t device; + std::string name; + std::string description; + uint64_t last_graph_uid; +}; + struct ggml_backend_rpc_buffer_type_context { std::string endpoint; uint32_t device; @@ -211,7 +219,6 @@ struct ggml_backend_rpc_context { std::string endpoint; uint32_t device; std::string name; - uint64_t last_graph_uid; }; struct ggml_backend_rpc_buffer_context { @@ -691,9 +698,11 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + ggml_backend_dev_t rpc_dev = ggml_backend_get_device(backend); + ggml_backend_rpc_device_context * rpc_dev_ctx = (ggml_backend_rpc_device_context *)rpc_dev->context; GGML_ASSERT(cgraph->n_nodes > 0); - bool reuse = cgraph->uid != 0 && rpc_ctx->last_graph_uid == cgraph->uid; + bool reuse = cgraph->uid != 0 && rpc_dev_ctx->last_graph_uid == cgraph->uid; if (reuse) { rpc_msg_graph_recompute_req request; request.device = rpc_ctx->device; @@ -701,7 +710,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request)); RPC_STATUS_ASSERT(status); } else { - rpc_ctx->last_graph_uid = cgraph->uid; + rpc_dev_ctx->last_graph_uid = cgraph->uid; std::vector input; serialize_graph(rpc_ctx->device, cgraph, input); auto sock = get_socket(rpc_ctx->endpoint); @@ -770,7 +779,6 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) { /* .endpoint = */ endpoint, /* .device = */ device, /* .name = */ dev_name, - /* .last_graph_uid = */ 0, }; auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_t backend = new ggml_backend { @@ -1757,15 +1765,6 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir } } -// device interface - -struct ggml_backend_rpc_device_context { - std::string endpoint; - uint32_t device; - std::string name; - std::string description; -}; - static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; @@ -1947,10 +1946,11 @@ ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) { std::string dev_name = "RPC" + std::to_string(dev_id); std::string dev_desc = std::string(endpoint); ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context { - /* .endpoint = */ endpoint, - /* .device = */ ind, - /* .name = */ dev_name, - /* .description = */ dev_desc + /* .endpoint = */ endpoint, + /* .device = */ ind, + /* .name = */ dev_name, + /* .description = */ dev_desc, + /* .last_graph_uid = */ 0, }; ggml_backend_dev_t dev = new ggml_backend_device { From 28edd0cb36581e9b4c4eb0a253c8cc9c9b13f9db Mon Sep 17 00:00:00 2001 From: Intel AI Get-to Market Customer Success and Solutions Date: Mon, 18 May 2026 23:44:02 -0700 Subject: [PATCH 653/831] sycl: add GGML_SYCL_USE_ASYNC_MEM_OP env toggle (llama/22153) * sycl: add GGML_SYCL_USE_ASYNC_MEM_OP env toggle Signed-off-by: Chun Tao * Use async mem ops for correctness when SYCL graphs are explicitly on. Signed-off-by: Tao, Chun --------- Signed-off-by: Chun Tao Signed-off-by: Tao, Chun Co-authored-by: Chun Tao --- ggml/src/ggml-sycl/ggml-sycl.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index ebe7c5b351c..2ea47f7153a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -72,6 +72,7 @@ int g_ggml_sycl_disable_graph = 0; int g_ggml_sycl_disable_dnn = 0; int g_ggml_sycl_prioritize_dmmv = 0; int g_ggml_sycl_use_async_mem_op = 0; +int g_ggml_sycl_use_async_mem_op_requested = 1; int g_ggml_sycl_enable_level_zero = 0; int g_ggml_sycl_enable_flash_attention = 1; @@ -304,6 +305,8 @@ static void ggml_check_sycl() try { GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n"); #endif GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); + g_ggml_sycl_use_async_mem_op_requested = get_sycl_env("GGML_SYCL_USE_ASYNC_MEM_OP", 1); + GGML_LOG_INFO(" GGML_SYCL_USE_ASYNC_MEM_OP: %d\n", g_ggml_sycl_use_async_mem_op_requested); #ifdef SYCL_FLASH_ATTN GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention); @@ -319,11 +322,11 @@ static void ggml_check_sycl() try { fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__); #endif */ - // Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be - // properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in - // other places. + // Async USM allocation/free is also useful outside the graph path: it avoids the host waits in the reorder + // staging path while preserving queue ordering semantics. Graph support still depends on the extension being + // available, but it no longer needs to control the non-graph fast path. #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC - g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph; + g_ggml_sycl_use_async_mem_op = g_ggml_sycl_use_async_mem_op_requested || !g_ggml_sycl_disable_graph; if (g_ggml_sycl_use_async_mem_op) { for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) { if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) { From 6090f39f36056a9eb673c7d77bb90a4126d71fa9 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Mon, 18 May 2026 23:45:41 -0700 Subject: [PATCH 654/831] ggml-webgpu : extend GDN for K>1 (llama/23299) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 ++ .../wgsl-shaders/gated_delta_net.wgsl | 24 +++++++++++++++---- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 78cb02be06d..921c12b41ac 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1234,6 +1234,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, const uint32_t h = (uint32_t) src2->ne[1]; const uint32_t n_tokens = (uint32_t) src2->ne[2]; const uint32_t n_seqs = (uint32_t) src2->ne[3]; + const uint32_t K = (uint32_t) src5->ne[1]; const float scale = 1.0f / sqrtf((float) s_v); uint32_t scale_u32; memcpy(&scale_u32, &scale, sizeof(scale_u32)); @@ -1258,6 +1259,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, (uint32_t) src0->ne[1], (uint32_t) (src2->ne[3] / src0->ne[3]), + K, scale_u32, }; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl index f9d98fda40b..d68520f8282 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl @@ -39,6 +39,7 @@ struct Params { neq1: u32, rq3: u32, + K: u32, scale: f32, }; @@ -62,11 +63,14 @@ fn main( let iq3 = seq_id / params.rq3; let state_size = S_V * S_V; - let state_base = (seq_id * params.h + head_id) * state_size; + let state_in_base = (seq_id * params.K * params.h + head_id) * state_size; + let state_out_base = (seq_id * params.h + head_id) * state_size; + let state_size_per_snap = state_size * params.h * params.n_seqs; + let shift = i32(params.n_tokens) - i32(params.K); var state: array; for (var i = 0u; i < S_V; i++) { - state[i] = src_state[state_base + col * S_V + i]; + state[i] = src_state[state_in_base + col * S_V + i]; } var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V; @@ -123,10 +127,22 @@ fn main( dst[attn_off + col] = attn_col * params.scale; attn_off += S_V * params.h; + if (params.K > 1u) { + let target_slot = i32(t) - shift; + if (target_slot >= 0 && target_slot < i32(params.K)) { + let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base; + for (var i = 0u; i < S_V; i++) { + dst[slot_base + col * S_V + i] = state[i]; + } + } + } + workgroupBarrier(); } - for (var i = 0u; i < S_V; i++) { - dst[params.s_off + state_base + col * S_V + i] = state[i]; + if (params.K == 1u) { + for (var i = 0u; i < S_V; i++) { + dst[params.s_off + state_out_base + col * S_V + i] = state[i]; + } } } From 459ff0707b51b0262536977d34e8ad2631063e68 Mon Sep 17 00:00:00 2001 From: Aparna M P Date: Tue, 19 May 2026 22:18:21 +0530 Subject: [PATCH 655/831] hexagon: enable support for NORM op (llama/23319) --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 5 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 1 + ggml/src/ggml-hexagon/htp/unary-ops.c | 97 ++++++++++++++++++++++++++ 4 files changed, 101 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 2f75e97ac66..ebeef3bdbaf 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2870,6 +2870,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; + case GGML_OP_NORM: return HTP_OP_NORM; case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; case GGML_OP_SCALE: return HTP_OP_SCALE; @@ -3338,10 +3339,8 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_add_id(sess, op); break; + case GGML_OP_NORM: case GGML_OP_L2_NORM: - supp = ggml_hexagon_supported_unary(sess, op); - break; - case GGML_OP_RMS_NORM: case GGML_OP_SCALE: supp = ggml_hexagon_supported_unary(sess, op); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 676e948a439..9d905a30133 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -88,6 +88,7 @@ enum htp_op_code { HTP_OP_GATED_DELTA_NET, HTP_OP_TRI, HTP_OP_PAD, + HTP_OP_NORM, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 12003c1fd8a..8e54536f619 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -534,6 +534,7 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_ADD_ID: return op_binary(octx); + case HTP_OP_NORM: case HTP_OP_RMS_NORM: case HTP_OP_SCALE: case HTP_OP_SQR: diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 1ce881353ec..40d2d60153a 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -158,6 +158,79 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } +static void hvx_fast_norm_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict pad, + const int num_elems, + float epsilon) { + (void)pad; + + const HVX_Vector * restrict v_src = (HVX_Vector *) src; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares and sum of values for full vectors + HVX_Vector sum_sq_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector sum_x_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2); + sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero())); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2); + sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero())); + } + + // Reduce HVX sums + sum_sq_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_sq_v)); + sum_x_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_x_v)); + + HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); + HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); + HVX_Vector mean_sq_v = Q6_Vqf32_vmpy_VsfVsf(sum_sq_v, denom_v); + HVX_Vector mean_x_v = Q6_Vqf32_vmpy_VsfVsf(sum_x_v, denom_v); + HVX_Vector mean_x_sq_v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(mean_x_v), Q6_Vsf_equals_Vqf32(mean_x_v)); + HVX_Vector var_v = Q6_Vqf32_vsub_Vqf32Vqf32(mean_sq_v, mean_x_sq_v); + HVX_Vector var_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(var_v, epsilon_v); + + // scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction + HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v)); + HVX_Vector mean_x_b = hvx_vec_splat_f32(hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(mean_x_v))); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b); + HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v); + v_dst[i] = Q6_Vsf_equals_Vqf32(v3); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b); + HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v); + HVX_Vector result = Q6_Vsf_equals_Vqf32(v3); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, result); + } +} + static void scale_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -196,6 +269,24 @@ static void rms_norm_f32(const float * restrict src, } } +static void norm_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_fast_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); + } +} + static void sqr_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -556,6 +647,9 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * // Process block in VTCM switch (htp_op) { + case HTP_OP_NORM: + norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; case HTP_OP_RMS_NORM: rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; @@ -632,6 +726,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const char * op_type = NULL; switch (octx->op) { + case HTP_OP_NORM: + op_type = "norm-f32"; + break; case HTP_OP_RMS_NORM: op_type = "rmsnorm-f32"; break; From aca63e76386b239cbd1692629ddece4e7e26f03b Mon Sep 17 00:00:00 2001 From: Aparna M P Date: Wed, 20 May 2026 02:40:13 +0530 Subject: [PATCH 656/831] hexagon: add MROPE and IMROPE support in HTP rope op (llama/23317) --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 2 +- ggml/src/ggml-hexagon/htp/rope-ops.c | 115 +++++++++++++++++++++---- 2 files changed, 98 insertions(+), 19 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index ebeef3bdbaf..080fb7f47e3 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2661,7 +2661,7 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess int mode = op_params[2]; - if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) { + if (mode == GGML_ROPE_TYPE_VISION) { return false; } if (mode & 1) { diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 1d8b0796bc9..9901453e91e 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -18,9 +18,11 @@ #include "htp-ops.h" #include "htp-ops.h" -// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h +// Redefined the rope type constants as we can't include ggml.h #define HTP_ROPE_TYPE_NORMAL 0 #define HTP_ROPE_TYPE_NEOX 2 +#define HTP_ROPE_TYPE_MROPE 8 +#define HTP_ROPE_TYPE_IMROPE 40 #define HTP_ROPE_SPAD_NROWS 16 #define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2) @@ -82,6 +84,29 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { return (1 - MIN(1, MAX(0, y))); } +// Compute one (cos, sin) pair into cache[i0], cache[i0+1] applying YaRN scaling. +static inline void rope_yarn_one(float theta, float freq_scale, float * corr_dims, + uint32_t i0, float ext_factor, float mscale, + float * cache) { + float theta_extrap = theta; + + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta_final = theta_interp; + float mscale_final = mscale; + + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + + cache[i0 + 0] = cosf(theta_final) * mscale_final; + cache[i0 + 1] = sinf(theta_final) * mscale_final; +} + static void rope_cache_init(const float theta_base, const float freq_scale, const float * freq_factors, @@ -96,26 +121,62 @@ static void rope_cache_init(const float theta_base, for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); - float theta_extrap = theta / ff; - - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta_final = theta_interp; - float mscale_final = mscale; + theta *= theta_scale; + } +} - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; - theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; +// pos_t/h/w/e: the four position ids for this sequence step (t=time, h=height, w=width, e=extra). +// sections[4]: number of head dims assigned to each position component. +static void mrope_cache_init(const float pos_t, + const float pos_h, + const float pos_w, + const float pos_e, + const int32_t sections[4], + const bool is_imrope, + const float freq_scale, + const float * freq_factors, + float * corr_dims, + const uint32_t ne0, + const float ext_factor, + const float mscale, + float * cache, + const float theta_scale) { + const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + const int sec_w = sections[0] + sections[1]; + const int sec_e = sec_w + sections[2]; + + float theta_t = pos_t; + float theta_h = pos_h; + float theta_w = pos_w; + float theta_e = pos_e; - // Get n-d magnitude scaling corrected for interpolation - mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale); + for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + const int sector = (i0 / 2) % sect_dims; + + float theta; + if (is_imrope) { + // Interleaved: sector mod 3 selects component + if (sector % 3 == 0 && sector < 3 * sections[0]) { theta = theta_t; } + else if (sector % 3 == 1 && sector < 3 * sections[1]) { theta = theta_h; } + else if (sector % 3 == 2 && sector < 3 * sections[2]) { theta = theta_w; } + else { theta = theta_e; } + } else { + // Contiguous sections + if (sector < sections[0]) { theta = theta_t; } + else if (sector < sec_w) { theta = theta_h; } + else if (sector < sec_e) { theta = theta_w; } + else { theta = theta_e; } } - cache[i0 + 0] = cosf(theta_final) * mscale_final; - cache[i0 + 1] = sinf(theta_final) * mscale_final; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); - theta *= theta_scale; + theta_t *= theta_scale; + theta_h *= theta_scale; + theta_w *= theta_scale; + theta_e *= theta_scale; } } @@ -274,7 +335,8 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { uint64_t tt = HAP_perf_get_qtimer_count(); const int32_t mode = rctx->mode; - const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; + // MROPE and IMROPE use NEOX-style pairing for the rotation + const bool is_neox = (mode & HTP_ROPE_TYPE_NEOX) || (mode & HTP_ROPE_TYPE_MROPE); // VTCM setup uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); @@ -326,8 +388,25 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { if (i2 != prev_i2) { prev_i2 = i2; - const int32_t p = pos[i2]; - rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale); + const bool is_mrope = (rctx->mode & HTP_ROPE_TYPE_MROPE) != 0; + if (is_mrope) { + // src1 holds four position arrays stacked along ne0: + // pos[i2], pos[i2+ne2], pos[i2+ne2*2], pos[i2+ne2*3] + const bool is_imrope = (rctx->mode == HTP_ROPE_TYPE_IMROPE); + mrope_cache_init( + (float) pos[i2], + (float) pos[i2 + ne2], + (float) pos[i2 + ne2 * 2], + (float) pos[i2 + ne2 * 3], + rctx->sections, is_imrope, + rctx->freq_scale, freq_factors, rctx->corr_dims, + ne0, rctx->ext_factor, rctx->attn_factor, + theta_cache, rctx->theta_scale); + } else { + rope_cache_init(pos[i2], rctx->freq_scale, freq_factors, rctx->corr_dims, + ne0, rctx->ext_factor, rctx->attn_factor, + theta_cache, rctx->theta_scale); + } // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache, // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); From 37f17208c25aae5c93177599c05de6efc3b57446 Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Tue, 19 May 2026 14:29:00 -0700 Subject: [PATCH 657/831] opencl: add MoE support for q4_k, q5_k, q6_k on Adreno (llama/23303) * opencl: add q4_k moe support * opencl: add q5_k moe support * opencl: add q6_k moe support * opencl: adjust format --------- Co-authored-by: Li He --- ggml/src/ggml-opencl/CMakeLists.txt | 6 + ggml/src/ggml-opencl/ggml-opencl.cpp | 948 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 385 +++++++ .../kernels/gemm_moe_q4_k_f32_ns.cl | 279 ++++++ .../kernels/gemm_moe_q5_k_f32_ns.cl | 284 ++++++ .../kernels/gemm_moe_q6_k_f32_ns.cl | 263 +++++ .../kernels/gemv_moe_q4_k_f32_ns.cl | 151 +++ .../kernels/gemv_moe_q5_k_f32_ns.cl | 156 +++ .../kernels/gemv_moe_q6_k_f32_ns.cl | 137 +++ 9 files changed, 2601 insertions(+), 8 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index c6aba608736..f75d089b574 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -110,6 +110,12 @@ set(GGML_OPENCL_KERNELS gemv_moe_q5_0_f32_ns gemm_moe_q5_1_f32_ns gemv_moe_q5_1_f32_ns + gemm_moe_q4_k_f32_ns + gemv_moe_q4_k_f32_ns + gemm_moe_q5_k_f32_ns + gemv_moe_q5_k_f32_ns + gemm_moe_q6_k_f32_ns + gemv_moe_q6_k_f32_ns gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 gemm_moe_mxfp4_f32_ns diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0e511592d53..a3af8c2da41 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -558,6 +558,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns; cl_kernel kernel_convert_block_q5_0_trans4_ns, kernel_restore_block_q5_0_trans4_ns; cl_kernel kernel_convert_block_q5_1_trans4_ns, kernel_restore_block_q5_1_trans4_ns; + cl_kernel kernel_convert_block_q4_k_trans4_ns, kernel_restore_block_q4_k_trans4_ns; + cl_kernel kernel_convert_block_q5_k_trans4_ns, kernel_restore_block_q5_k_trans4_ns; + cl_kernel kernel_convert_block_q6_k_trans4_ns, kernel_restore_block_q6_k_trans4_ns; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; @@ -619,6 +622,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns; cl_kernel kernel_gemv_moe_q5_0_f32_ns, kernel_gemm_moe_q5_0_f32_ns; cl_kernel kernel_gemv_moe_q5_1_f32_ns, kernel_gemm_moe_q5_1_f32_ns; + cl_kernel kernel_gemv_moe_q4_k_f32_ns, kernel_gemm_moe_q4_k_f32_ns; + cl_kernel kernel_gemv_moe_q5_k_f32_ns, kernel_gemm_moe_q5_k_f32_ns; + cl_kernel kernel_gemv_moe_q6_k_f32_ns, kernel_gemm_moe_q6_k_f32_ns; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; cl_kernel kernel_moe_reorder_b; @@ -981,6 +987,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q6_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q6_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_k_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err)); @@ -3071,6 +3083,108 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemv_moe_q4_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q4_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q4_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q4_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q4_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q4_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q4_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q4_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q4_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q4_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_moe_q5_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q5_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q5_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q5_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q5_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q5_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q5_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q5_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q5_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q5_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_moe_q6_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q6_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q6_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_q6_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q6_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_q6_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q6_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q6_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q6_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q6_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemv_moe_mxfp4_f32_ns { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -4148,6 +4262,8 @@ struct ggml_tensor_extra_cl_iq4_nl { struct ggml_tensor_extra_cl_q4_K { // Quantized values cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; // Scales for each super block. cl_mem s = nullptr; // Scales @@ -4176,12 +4292,18 @@ struct ggml_tensor_extra_cl_q4_K { CL_CHECK(clReleaseMemObject(dm)); dm = nullptr; } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } } }; struct ggml_tensor_extra_cl_q5_K { // Lower 4 bits of quantized weights. cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; // Upper 1 bit of quantized weights. cl_mem qh = nullptr; // Scales for each block. @@ -4222,6 +4344,10 @@ struct ggml_tensor_extra_cl_q5_K { CL_CHECK(clReleaseMemObject(dm)); dm = nullptr; } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } size_q = 0; size_qh = 0; @@ -4234,6 +4360,8 @@ struct ggml_tensor_extra_cl_q5_K { struct ggml_tensor_extra_cl_q6_K { // Lower 4 bits of quantized weights. cl_mem ql = nullptr; + // Lower 4 bits as image1d_buffer_t + cl_mem ql_img = nullptr; // Upper 2 bits of quantized weights. cl_mem qh = nullptr; // Scales for each block. @@ -4267,6 +4395,10 @@ struct ggml_tensor_extra_cl_q6_K { CL_CHECK(clReleaseMemObject(d)); d = nullptr; } + if (ql_img != nullptr) { + CL_CHECK(clReleaseMemObject(ql_img)); + ql_img = nullptr; + } size_ql = 0; size_qh = 0; @@ -4700,7 +4832,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te // the quantizations here currently do not - they are only supported by Adreno with certain shapes if (op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_Q5_0 || - op->src[0]->type == GGML_TYPE_Q5_1) { + op->src[0]->type == GGML_TYPE_Q5_1 || + op->src[0]->type == GGML_TYPE_Q4_K || + op->src[0]->type == GGML_TYPE_Q5_K || + op->src[0]->type == GGML_TYPE_Q6_K) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (op->src[1]->type == GGML_TYPE_F32) { return use_adreno_moe_kernels(backend_ctx, op->src[0]) @@ -6047,14 +6182,57 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + CL_CHECK(err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle; } - #else +#else cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; - #endif +#endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_uchar mask_0F = 0x0F; cl_uchar mask_F0 = 0xF0; @@ -6157,14 +6335,58 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); CL_CHECK(err); - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + CL_CHECK(err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; if (use_adreno_kernels(backend_ctx, tensor)) { kernel = backend_ctx->kernel_convert_block_q5_K_noshuffle; } - #else +#else cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; - #endif +#endif cl_uchar mask_0F = 0x0F; cl_uchar mask_F0 = 0xF0; @@ -6232,6 +6454,79 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_buffer_region region; + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno MoE Q6_K kernel needs special transposed layout + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + size_t moe_size_ql = (size_t)(ggml_nelements(tensor) / 8) * sizeof(uint32_t); // 4 bits per element + size_t moe_size_qh = (size_t)(ggml_nelements(tensor) / 16) * sizeof(uint32_t); // 2 bits per element + size_t moe_size_s = size_s; + size_t moe_size_d = size_d; + + // Subbuffer for ql + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = moe_size_ql; + CL_CHECK((extra->ql = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + auto previous_origin = region.origin; + + // Subbuffer for qh + region.origin = align_to(previous_origin + moe_size_ql, backend_ctx->alignment); + region.size = moe_size_qh; + CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Subbuffer for scales + region.origin = align_to(previous_origin + moe_size_qh, backend_ctx->alignment); + region.size = moe_size_s; + CL_CHECK((extra->s = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Subbuffer for d + region.origin = align_to(previous_origin + moe_size_s, backend_ctx->alignment); + region.size = moe_size_d; + CL_CHECK((extra->d = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q6_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for ql + cl_image_format img_format_ql = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_ql = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->ql } + }; + extra->ql_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_ql, &img_desc_ql, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + // Subbuffer for ql region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); region.size = size_ql; @@ -6825,6 +7120,40 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, cl_uchar mask_F0 = 0xF0; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -6901,6 +7230,40 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, cl_uchar mask_F0 = 0xF0; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -6974,7 +7337,44 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q6_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast(((ne01 + 63) / 64) * 64), static_cast(ne00 / 256), static_cast(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } if (use_adreno_kernels(backend_ctx, tensor)) { static ggml_cl_buffer buf_trans_ql; static ggml_cl_buffer buf_trans_qh; @@ -13733,6 +14133,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra; + ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; #endif @@ -13741,6 +14144,9 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, (void)extra0_q4_1; (void)extra0_q5_0; (void)extra0_q5_1; + (void)extra0_q4_K; + (void)extra0_q5_K; + (void)extra0_q6_K; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; @@ -14612,6 +15018,532 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, #endif // GGML_OPENCL_SOA_Q break; } + case GGML_TYPE_Q4_K: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q4_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q4_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q5_K: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q5_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q5_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q6_K: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q6_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast(ne01); + global_size[1] = 4; + global_size[2] = static_cast(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q6_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast(((ne00 / 4) + 255) / 256 * 256), static_cast(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->ql_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast((ne01 + 63) / 64); + global_size[2] = static_cast(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } case GGML_TYPE_MXFP4: { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, src0)) { diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 8f06d570587..312366984b6 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -664,6 +664,391 @@ kernel void kernel_restore_block_q5_1_trans4_ns( ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; } +kernel void kernel_convert_block_q4_k_trans4_ns( + __global struct block_q4_K * src0, + __global uint * dst_q, + __global half * dst_d, + __global half * dst_dm, + __global uchar * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_K; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q4_K * b = src0 + src_blk_offset; + + dst_d [dst_blk_offset] = b->d; + dst_dm[dst_blk_offset] = b->dm; + + uint4 qv[8]; + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->q[i*32 + 2*j]; + uchar x1 = b->q[i*32 + 2*j + 1]; + + qv_bytes[i*32 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qv_bytes[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qv[p]; + dst_q[base + (p * 4 + 0) * ne01] = v.x; + dst_q[base + (p * 4 + 1) * ne01] = v.y; + dst_q[base + (p * 4 + 2) * ne01] = v.z; + dst_q[base + (p * 4 + 3) * ne01] = v.w; + } + + __global uchar * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + #pragma unroll + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s_dst[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q4_k_trans4_ns( + __global uint * src_q, + __global half * src_d, + __global half * src_dm, + __global uchar * src_s, + __global struct block_q4_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q4_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + b->dm = src_dm[src_blk_offset]; + + __global uchar * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s_src[i]; + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + uint4 qv[8]; + for (int p = 0; p < 8; ++p) { + qv[p].x = src_q[base + (p * 4 + 0) * ne01]; + qv[p].y = src_q[base + (p * 4 + 1) * ne01]; + qv[p].z = src_q[base + (p * 4 + 2) * ne01]; + qv[p].w = src_q[base + (p * 4 + 3) * ne01]; + } + + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = qv_bytes[i*32 + j]; + uchar hi = qv_bytes[i*32 + j + 16]; + b->q[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->q[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } +} + +kernel void kernel_convert_block_q5_k_trans4_ns( + __global struct block_q5_K * src0, + __global uint * dst_qs, + __global uint * dst_qh, + __global half * dst_d, + __global half * dst_dm, + __global uchar * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_K; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q5_K * b = src0 + src_blk_offset; + + dst_d [dst_blk_offset] = b->d; + dst_dm[dst_blk_offset] = b->dm; + + for (int k = 0; k < 8; k++) { + uchar b0 = 0, b1 = 0, b2 = 0, b3 = 0; + for (int bit = 0; bit < 8; bit++) { + b0 |= (uchar)(((b->qh[bit] >> k) & 1) << bit); + b1 |= (uchar)(((b->qh[8 + bit] >> k) & 1) << bit); + b2 |= (uchar)(((b->qh[16 + bit] >> k) & 1) << bit); + b3 |= (uchar)(((b->qh[24 + bit] >> k) & 1) << bit); + } + uint packed = (uint)b0 | ((uint)b1 << 8) | ((uint)b2 << 16) | ((uint)b3 << 24); + dst_qh[i01 + (i00 * 8 + k) * ne01 + i02 * ne00_blk * 8 * ne01] = packed; + } + + uint4 qv[8]; + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->qs[i*32 + 2*j]; + uchar x1 = b->qs[i*32 + 2*j + 1]; + + qv_bytes[i*32 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qv_bytes[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qv[p]; + dst_qs[base + (p * 4 + 0) * ne01] = v.x; + dst_qs[base + (p * 4 + 1) * ne01] = v.y; + dst_qs[base + (p * 4 + 2) * ne01] = v.z; + dst_qs[base + (p * 4 + 3) * ne01] = v.w; + } + + __global uchar * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + #pragma unroll + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s_dst[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q5_k_trans4_ns( + __global uint * src_qs, + __global uint * src_qh, + __global half * src_d, + __global half * src_dm, + __global uchar * src_s, + __global struct block_q5_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q5_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + b->dm = src_dm[src_blk_offset]; + + for (int j = 0; j < 32; j++) b->qh[j] = 0; + for (int k = 0; k < 8; k++) { + uint packed = src_qh[i01 + (i00 * 8 + k) * ne01 + i02 * ne00_blk * 8 * ne01]; + uchar b0 = (uchar)(packed & 0xFF); + uchar b1 = (uchar)((packed >> 8) & 0xFF); + uchar b2 = (uchar)((packed >> 16) & 0xFF); + uchar b3 = (uchar)((packed >> 24) & 0xFF); + for (int bit = 0; bit < 8; bit++) { + b->qh[bit] |= (uchar)(((b0 >> bit) & 1) << k); + b->qh[8 + bit] |= (uchar)(((b1 >> bit) & 1) << k); + b->qh[16 + bit] |= (uchar)(((b2 >> bit) & 1) << k); + b->qh[24 + bit] |= (uchar)(((b3 >> bit) & 1) << k); + } + } + + __global uchar * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s_src[i]; + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + uint4 qv[8]; + for (int p = 0; p < 8; ++p) { + qv[p].x = src_qs[base + (p * 4 + 0) * ne01]; + qv[p].y = src_qs[base + (p * 4 + 1) * ne01]; + qv[p].z = src_qs[base + (p * 4 + 2) * ne01]; + qv[p].w = src_qs[base + (p * 4 + 3) * ne01]; + } + + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = qv_bytes[i*32 + j]; + uchar hi = qv_bytes[i*32 + j + 16]; + b->qs[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->qs[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } +} + +kernel void kernel_convert_block_q6_k_trans4_ns( + __global struct block_q6_K * src0, + __global uint * dst_ql, + __global uint * dst_qh, + __global half * dst_d, + __global char * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q6_K * b = src0 + src_blk_offset; + + dst_d[dst_blk_offset] = b->d; + + uint4 qlv[8]; + uchar * qlv_bytes = (uchar *)qlv; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->ql[i*64 + 2*j]; + uchar x1 = b->ql[i*64 + 2*j + 1]; + uchar x2 = b->ql[i*64 + 32 + 2*j]; + uchar x3 = b->ql[i*64 + 32 + 2*j + 1]; + qlv_bytes[i*64 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qlv_bytes[i*64 + j + 16] = convert_uchar(x2 & mask_0F) | convert_uchar((x3 & mask_0F) << 4); + qlv_bytes[i*64 + j + 32] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + qlv_bytes[i*64 + j + 48] = convert_uchar((x2 & mask_F0) >> 4) | convert_uchar(x3 & mask_F0); + } + } + + uint ql_base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qlv[p]; + dst_ql[ql_base + (p * 4 + 0) * ne01] = v.x; + dst_ql[ql_base + (p * 4 + 1) * ne01] = v.y; + dst_ql[ql_base + (p * 4 + 2) * ne01] = v.z; + dst_ql[ql_base + (p * 4 + 3) * ne01] = v.w; + } + + uint qhv[16] = {0}; + + for (int n = 0; n < 2; ++n) { + for (int l = 0; l < 32; ++l) { + uchar h = b->qh[n*32 + l]; + int u = l / 16; + int bit_pos = (l % 16) * 2; + qhv[(n*4 + 0)*2 + u] |= ((uint)((h >> 0) & 0x03)) << bit_pos; + qhv[(n*4 + 1)*2 + u] |= ((uint)((h >> 2) & 0x03)) << bit_pos; + qhv[(n*4 + 2)*2 + u] |= ((uint)((h >> 4) & 0x03)) << bit_pos; + qhv[(n*4 + 3)*2 + u] |= ((uint)((h >> 6) & 0x03)) << bit_pos; + } + } + + uint qh_base = i02 * ne00_blk * ne01 * 16 + i00 * ne01 * 16 + i01; + + for (int p = 0; p < 16; ++p) { + dst_qh[qh_base + p * ne01] = qhv[p]; + } + + __global char * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * 16 + i00 * 16; + #pragma unroll + for (int i = 0; i < 16; ++i) { + s_dst[i] = b->scales[i]; + } +} + +kernel void kernel_restore_block_q6_k_trans4_ns( + __global uint * src_ql, + __global uint * src_qh, + __global half * src_d, + __global char * src_s, + __global struct block_q6_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q6_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + + uint ql_base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + uint4 qlv[8]; + for (int p = 0; p < 8; ++p) { + qlv[p].x = src_ql[ql_base + (p * 4 + 0) * ne01]; + qlv[p].y = src_ql[ql_base + (p * 4 + 1) * ne01]; + qlv[p].z = src_ql[ql_base + (p * 4 + 2) * ne01]; + qlv[p].w = src_ql[ql_base + (p * 4 + 3) * ne01]; + } + + uchar * qlv_bytes = (uchar *)qlv; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo_02 = qlv_bytes[i*64 + j]; + uchar lo_13 = qlv_bytes[i*64 + j + 16]; + uchar hi_02 = qlv_bytes[i*64 + j + 32]; + uchar hi_13 = qlv_bytes[i*64 + j + 48]; + b->ql[i*64 + 2*j] = convert_uchar((lo_02 & mask_0F) | ((hi_02 & mask_0F) << 4)); + b->ql[i*64 + 2*j + 1] = convert_uchar(((lo_02 & mask_F0) >> 4) | (hi_02 & mask_F0)); + b->ql[i*64 + 32 + 2*j] = convert_uchar((lo_13 & mask_0F) | ((hi_13 & mask_0F) << 4)); + b->ql[i*64 + 32 + 2*j + 1] = convert_uchar(((lo_13 & mask_F0) >> 4) | (hi_13 & mask_F0)); + } + } + + uint qh_base = i02 * ne00_blk * ne01 * 16 + i00 * ne01 * 16 + i01; + uint qhv[16]; + for (int p = 0; p < 16; ++p) { + qhv[p] = src_qh[qh_base + p * ne01]; + } + + for (int n = 0; n < 2; ++n) { + for (int l = 0; l < 32; ++l) { + int u = l / 16; + int bit_pos = (l % 16) * 2; + uchar v0 = (uchar)((qhv[(n*4 + 0)*2 + u] >> bit_pos) & 0x03); + uchar v1 = (uchar)((qhv[(n*4 + 1)*2 + u] >> bit_pos) & 0x03); + uchar v2 = (uchar)((qhv[(n*4 + 2)*2 + u] >> bit_pos) & 0x03); + uchar v3 = (uchar)((qhv[(n*4 + 3)*2 + u] >> bit_pos) & 0x03); + b->qh[n*32 + l] = v0 | (v1 << 2) | (v2 << 4) | (v3 << 6); + } + } + + __global char * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * 16 + i00 * 16; + for (int i = 0; i < 16; ++i) { + b->scales[i] = s_src[i]; + } +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl new file mode 100644 index 00000000000..9d24aff6a20 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl @@ -0,0 +1,279 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +#define dequantize_q4_k(q4, a_f16, scale, minv) \ + a_f16.s0 = (half)((float)(q4.s0 & 0x000F) * scale - minv); \ + a_f16.s1 = (half)((float)((q4.s0 & 0x00F0) >> 4) * scale - minv); \ + a_f16.s2 = (half)((float)((q4.s0 & 0x0F00) >> 8) * scale - minv); \ + a_f16.s3 = (half)((float)((q4.s0 & 0xF000) >> 12) * scale - minv); \ + a_f16.s4 = (half)((float)(q4.s1 & 0x000F) * scale - minv); \ + a_f16.s5 = (half)((float)((q4.s1 & 0x00F0) >> 4) * scale - minv); \ + a_f16.s6 = (half)((float)((q4.s1 & 0x0F00) >> 8) * scale - minv); \ + a_f16.s7 = (half)((float)((q4.s1 & 0xF000) >> 12) * scale - minv); \ + a_f16.s8 = (half)((float)(q4.s2 & 0x000F) * scale - minv); \ + a_f16.s9 = (half)((float)((q4.s2 & 0x00F0) >> 4) * scale - minv); \ + a_f16.sa = (half)((float)((q4.s2 & 0x0F00) >> 8) * scale - minv); \ + a_f16.sb = (half)((float)((q4.s2 & 0xF000) >> 12) * scale - minv); \ + a_f16.sc = (half)((float)(q4.s3 & 0x000F) * scale - minv); \ + a_f16.sd = (half)((float)((q4.s3 & 0x00F0) >> 4) * scale - minv); \ + a_f16.se = (half)((float)((q4.s3 & 0x0F00) >> 8) * scale - minv); \ + a_f16.sf = (half)((float)((q4.s3 & 0xF000) >> 12) * scale - minv); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q4_k_f32_ns( + __read_only image1d_buffer_t src0_q, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * K_SCALE_SIZE; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; + uint sb = sub / 8; + uint j = sub % 8; + + // Load d and dm for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + half dm_val = src0_dm[d_offset]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = (float)dm_val * (float)mn; + + // First sub-block (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_k(as_ushort4(q4x16), reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half (next 16 elements, same sub-block scale) + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q4_k(as_ushort4(q4x16), reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl new file mode 100644 index 00000000000..808a0c7db6a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl @@ -0,0 +1,284 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +#define dequantize_q5_k(qs5x16, qh5x16, a_f16, scale, m) \ + a_f16.s0 = (half)((float)(( qs5x16.s0 & 0x000F) | (( qh5x16.s0 & 0x01) << 4)) * scale + m); \ + a_f16.s1 = (half)((float)((((qs5x16.s0 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 1) & 0x01) << 4)) * scale + m)); \ + a_f16.s2 = (half)((float)((((qs5x16.s0 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 2) & 0x01) << 4)) * scale + m)); \ + a_f16.s3 = (half)((float)((((qs5x16.s0 & 0xF000) >> 12) | (((qh5x16.s0 >> 3) & 0x01) << 4)) * scale + m)); \ + a_f16.s4 = (half)((float)((( qs5x16.s1 & 0x000F) | (((qh5x16.s0 >> 4) & 0x01) << 4)) * scale + m)); \ + a_f16.s5 = (half)((float)((((qs5x16.s1 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 5) & 0x01) << 4)) * scale + m)); \ + a_f16.s6 = (half)((float)(((qs5x16.s1 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 6) & 0x01) << 4)) * scale + m); \ + a_f16.s7 = (half)((float)((((qs5x16.s1 & 0xF000) >> 12) | (((qh5x16.s0 >> 7) & 0x01) << 4)) * scale + m)); \ + a_f16.s8 = (half)((float)((( qs5x16.s2 & 0x000F) | (( qh5x16.s1 & 0x01) << 4)) * scale + m)); \ + a_f16.s9 = (half)((float)((((qs5x16.s2 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 1) & 0x01) << 4)) * scale + m)); \ + a_f16.sa = (half)((float)((((qs5x16.s2 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 2) & 0x01) << 4)) * scale + m)); \ + a_f16.sb = (half)((float)((((qs5x16.s2 & 0xF000) >> 12) | (((qh5x16.s1 >> 3) & 0x01) << 4)) * scale + m)); \ + a_f16.sc = (half)((float)((( qs5x16.s3 & 0x000F) | (((qh5x16.s1 >> 4) & 0x01) << 4)) * scale + m)); \ + a_f16.sd = (half)((float)((((qs5x16.s3 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 5) & 0x01) << 4)) * scale + m)); \ + a_f16.se = (half)((float)((((qs5x16.s3 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 6) & 0x01) << 4)) * scale + m)); \ + a_f16.sf = (half)((float)((((qs5x16.s3 & 0xF000) >> 12) | (((qh5x16.s1 >> 7) & 0x01) << 4)) * scale + m)); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q5_k_f32_ns( + __read_only image1d_buffer_t src0_q, + __global uint * src0_qh, + __global uchar * src0_s, + __global half * src0_d, + __global half * src0_dm, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * K_SCALE_SIZE; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; + uint sb = sub / 8; + uint j = sub % 8; + + // Load d and dm for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + half dm_val = src0_dm[d_offset]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = -(float)dm_val * (float)mn; + + // qh is stored at sub-block granularity + uint qh_offset = row + sub * ne01 + expert_id * num_superblocks * 8 * ne01 + get_global_id(0); + uchar4 qhx32 = as_uchar4(src0_qh[qh_offset]); + + // First sub-block (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_k(as_ushort4(q4x16), qhx32.lo, reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q5_k(as_ushort4(q4x16), qhx32.hi, reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl new file mode 100644 index 00000000000..a040335adfa --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl @@ -0,0 +1,263 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 + +#define dequantize_q6_k(qs16, qh16, a_f16, scale) \ + a_f16.s0 = (half)(((float)(( qs16.s0 & 0x000F) | ((uint)(( qh16 ) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s1 = (half)(((float)((( qs16.s0 >> 4) & 0x000F) | ((uint)(( qh16 >> 2) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s2 = (half)(((float)((( qs16.s0 >> 8) & 0x000F) | ((uint)(( qh16 >> 4) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s3 = (half)(((float)((( qs16.s0 >>12) & 0x000F) | ((uint)(( qh16 >> 6) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s4 = (half)(((float)(( qs16.s1 & 0x000F) | ((uint)(( qh16 >> 8) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s5 = (half)(((float)((( qs16.s1 >> 4) & 0x000F) | ((uint)(( qh16 >> 10) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s6 = (half)(((float)((( qs16.s1 >> 8) & 0x000F) | ((uint)(( qh16 >> 12) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s7 = (half)(((float)((( qs16.s1 >>12) & 0x000F) | ((uint)(( qh16 >> 14) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s8 = (half)(((float)(( qs16.s2 & 0x000F) | ((uint)(( qh16 >> 16) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s9 = (half)(((float)((( qs16.s2 >> 4) & 0x000F) | ((uint)(( qh16 >> 18) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sa = (half)(((float)((( qs16.s2 >> 8) & 0x000F) | ((uint)(( qh16 >> 20) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sb = (half)(((float)((( qs16.s2 >>12) & 0x000F) | ((uint)(( qh16 >> 22) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sc = (half)(((float)(( qs16.s3 & 0x000F) | ((uint)(( qh16 >> 24) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sd = (half)(((float)((( qs16.s3 >> 4) & 0x000F) | ((uint)(( qh16 >> 26) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.se = (half)(((float)((( qs16.s3 >> 8) & 0x000F) | ((uint)(( qh16 >> 28) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sf = (half)(((float)((( qs16.s3 >>12) & 0x000F) | ((uint)(( qh16 >> 30) & 0x3) << 4)) - 32.f) * scale); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q6_k_f32_ns( + __read_only image1d_buffer_t src0_ql, + __global uint * src0_qh, + __global char * src0_s, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * 16; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; // 32-element group index + uint sb = sub / 8; // super-block index + uint j = sub % 8; // group within super-block + + // Load d for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + + // Load sub-block scales + global const char * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * 16; + float scale0 = (float)d_val * (float)sc[j * 2]; + float scale1 = (float)d_val * (float)sc[j * 2 + 1]; + + uint qh_base = row + (sub * 2) * ne01 + expert_id * (num_superblocks * 16) * ne01 + get_global_id(0); + uint qh_first16 = src0_qh[qh_base]; + uint qh_second16 = src0_qh[qh_base + ne01]; + + // First half (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 ql nibbles (2 uints) from image + uint2 q4x16; + q4x16.x = read_imageui(src0_ql, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_ql, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantize first 16 elements (scale0) + dequantize_q6_k(as_ushort4(q4x16), qh_first16, reg_a, scale0); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_ql, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_ql, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q6_k(as_ushort4(q4x16), qh_second16, reg_a, scale1); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl new file mode 100644 index 00000000000..13d79f2526f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl @@ -0,0 +1,151 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define K_SCALE_SIZE 12 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +static inline float8 q4_k_to_fp32_packed8(ushort2 q4x8, float scale, float minv) { + float8 fp32x8; + fp32x8.s0 = (q4x8.s0 & 0x000F) * scale - minv; + fp32x8.s1 = ((q4x8.s0 & 0x00F0) >> 4) * scale - minv; + fp32x8.s2 = ((q4x8.s0 & 0x0F00) >> 8) * scale - minv; + fp32x8.s3 = ((q4x8.s0 & 0xF000) >> 12) * scale - minv; + fp32x8.s4 = (q4x8.s1 & 0x000F) * scale - minv; + fp32x8.s5 = ((q4x8.s1 & 0x00F0) >> 4) * scale - minv; + fp32x8.s6 = ((q4x8.s1 & 0x0F00) >> 8) * scale - minv; + fp32x8.s7 = ((q4x8.s1 & 0xF000) >> 12) * scale - minv; + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q4_k_f32_ns( + __global uint * src0_q, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; + int scales_per_row = num_superblocks * K_SCALE_SIZE; + + // Expert offsets in the transposed noshuffle layout + uint expert_q_offset = expert_id * (ne00 / 8) * ne01; + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; + uint j = ib % 8; + + // Load d and dmin for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + half dm_val = src0_dm[expert_d_offset + sb * ne01 + i01]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = (float)dm_val * (float)mn; + + // Load 4 uints of quants (32 nibbles = 32 elements) + uint q_base = expert_q_offset + ib * ne01 * 4 + i01; + + uint4 regQ; + regQ.s0 = src0_q[q_base]; + regQ.s1 = src0_q[q_base + ne01]; + regQ.s2 = src0_q[q_base + ne01 * 2]; + regQ.s3 = src0_q[q_base + ne01 * 3]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s0), scale, minv); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s1), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s2), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s3), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl new file mode 100644 index 00000000000..f128d44340a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl @@ -0,0 +1,156 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define K_SCALE_SIZE 12 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +static inline float8 q5_k_to_fp32_packed8(ushort2 qs5x8, uchar qh5x8, half s, half m) { + float8 fp32x8; + fp32x8.s0 = (float)((( qs5x8.s0 & 0x000F) | (( qh5x8 & 0x01) << 4)) * s + m); + fp32x8.s1 = (float)((((qs5x8.s0 & 0x00F0) >> 4 ) | (((qh5x8 >> 1) & 0x01) << 4)) * s + m); + fp32x8.s2 = (float)((((qs5x8.s0 & 0x0F00) >> 8 ) | (((qh5x8 >> 2) & 0x01) << 4)) * s + m); + fp32x8.s3 = (float)((((qs5x8.s0 & 0xF000) >> 12) | (((qh5x8 >> 3) & 0x01) << 4)) * s + m); + fp32x8.s4 = (float)((( qs5x8.s1 & 0x000F) | (((qh5x8 >> 4) & 0x01) << 4)) * s + m); + fp32x8.s5 = (float)((((qs5x8.s1 & 0x00F0) >> 4 ) | (((qh5x8 >> 5) & 0x01) << 4)) * s + m); + fp32x8.s6 = (float)((((qs5x8.s1 & 0x0F00) >> 8 ) | (((qh5x8 >> 6) & 0x01) << 4)) * s + m); + fp32x8.s7 = (float)((((qs5x8.s1 & 0xF000) >> 12) | (((qh5x8 >> 7) & 0x01) << 4)) * s + m); + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q5_k_f32_ns( + __global uint * src0_q, + __global uint * src0_qh, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; + int scales_per_row = num_superblocks * K_SCALE_SIZE; + + // Expert offsets in the transposed noshuffle layout + uint expert_q_offset = expert_id * (ne00 / 8) * ne01; + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; + uint j = ib % 8; + + // Load d and dmin for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + half dm_val = src0_dm[expert_d_offset + sb * ne01 + i01]; + + // sub_block index = sb * 8 + j + uint expert_qh_offset = expert_id * num_superblocks * 8 * ne01; + uchar4 regQh = as_uchar4(src0_qh[expert_qh_offset + (sb * 8 + j) * ne01 + i01]); + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = -(float)dm_val * (float)mn; + + // Load 4 uints of quants (32 nibbles = 32 elements) + uint q_base = expert_q_offset + ib * ne01 * 4 + i01; + + uint4 regQ; + regQ.s0 = src0_q[q_base]; + regQ.s1 = src0_q[q_base + ne01]; + regQ.s2 = src0_q[q_base + ne01 * 2]; + regQ.s3 = src0_q[q_base + ne01 * 3]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s0), regQh.s0, scale, minv); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s1), regQh.s1, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s2), regQh.s2, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s3), regQh.s3, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl new file mode 100644 index 00000000000..526e609dc3a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl @@ -0,0 +1,137 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q6_k_to_fp32_packed8(ushort2 ql8, ushort qh8, float d_scale) { + float8 fp32x8; + fp32x8.s0 = ((float)(( ql8.s0 & 0x000F) | ((uint)((qh8 ) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s1 = ((float)((( ql8.s0 >> 4) & 0x000F) | ((uint)((qh8 >> 2) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s2 = ((float)((( ql8.s0 >> 8) & 0x000F) | ((uint)((qh8 >> 4) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s3 = ((float)((( ql8.s0 >> 12)& 0x000F) | ((uint)((qh8 >> 6) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s4 = ((float)(( ql8.s1 & 0x000F) | ((uint)((qh8 >> 8) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s5 = ((float)((( ql8.s1 >> 4) & 0x000F) | ((uint)((qh8 >>10) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s6 = ((float)((( ql8.s1 >> 8) & 0x000F) | ((uint)((qh8 >>12) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s7 = ((float)((( ql8.s1 >> 12)& 0x000F) | ((uint)((qh8 >>14) & 0x3) << 4)) - 32.f) * d_scale; + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q6_k_f32_ns( + __global uint * src0_ql, + __global uint * src0_qh, + __global char * src0_s, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; // 8 sub-blocks of 32 per super-block + int scales_per_row = num_superblocks * 16; + + // Expert offsets in the transposed noshuffle layout + uint expert_ql_offset = expert_id * (ne00 / 8) * ne01; // 32 uints per super-block + uint expert_qh_offset = expert_id * (ne00 / 16) * ne01; // 16 uints per super-block + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; // super-block index + uint j = ib % 8; // 32-element group within super-block + + // Load d for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + + // Load 2 sub-block scales + global const char * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * 16; + float scale0 = (float)d_val * (float)sc[j * 2]; + float scale1 = (float)d_val * (float)sc[j * 2 + 1]; + + // Load 4 uints of ql + uint ql_base = expert_ql_offset + (ib * 4) * ne01 + i01; + uint4 regQL; + regQL.s0 = src0_ql[ql_base]; + regQL.s1 = src0_ql[ql_base + ne01]; + regQL.s2 = src0_ql[ql_base + ne01 * 2]; + regQL.s3 = src0_ql[ql_base + ne01 * 3]; + + // Load 2 uints of qh + uint qh_base = expert_qh_offset + (ib * 2) * ne01 + i01; + uint2 regQH; + regQH.s0 = src0_qh[qh_base]; + regQH.s1 = src0_qh[qh_base + ne01]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s0), (ushort)(regQH.s0 & 0xFFFF), scale0); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s1), (ushort)(regQH.s0 >> 16), scale0); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s2), (ushort)(regQH.s1 & 0xFFFF), scale1); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s3), (ushort)(regQH.s1 >> 16), scale1); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} From 0a0a34287e84641c34679c1a38fed299a56e9d4b Mon Sep 17 00:00:00 2001 From: ravel7524 <58877666+ravel7524@users.noreply.github.com> Date: Wed, 20 May 2026 03:52:21 +0200 Subject: [PATCH 658/831] ggml-cuda: tune RDNA3 Q6_K MMVQ nwarps (llama/23349) --- ggml/src/ggml-cuda/mmvq.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index da48f313a38..73a0991e206 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -359,7 +359,9 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_Q4_K: + return 8; case GGML_TYPE_Q6_K: + return 2; case GGML_TYPE_IQ4_NL: return 8; default: From c58fc465dfed99c3e51a32e27a76d82f19f6481c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 20 May 2026 09:42:00 +0300 Subject: [PATCH 659/831] metal : optimize pad + cpy (llama/23354) * metal : optimize pad * metal : optinmize cpy * cont : better row packing in threadgroup --- ggml/src/ggml-metal/ggml-metal-device.cpp | 8 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 17 ++- ggml/src/ggml-metal/ggml-metal.metal | 128 ++++++++++++---------- 3 files changed, 90 insertions(+), 63 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index e288a27f992..ba006d9b31a 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1897,7 +1897,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l char base[256]; char name[256]; - snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type)); + // note: this is slower + //const bool is_c4 = op->src[0]->ne[0] % 4 == 0 && op->ne[0] % 4 == 0; + const bool is_c4 = false; + + snprintf(base, 256, "kernel_pad_%s%s", ggml_type_name(op->src[0]->type), is_c4 ? "_4" : ""); snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); @@ -1907,6 +1911,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + res.c4 = is_c4; + return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index a114391c2e8..8506000b6c0 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -816,9 +816,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); } else { const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - const int nth = MIN(args.ne00, nth_max); - const int nk0 = (args.ne00 + nth - 1)/nth; ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1); @@ -1863,7 +1861,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { nk0 = ne00/ggml_blck_size(op->type); } - int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + int nth = std::min(nk0*ne01, 256); // when rows are small, we can batch them together in a single threadgroup int nrptg = 1; @@ -1874,7 +1872,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { nrptg = (nth + nk0 - 1)/nk0; nth = nk0; - if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + if (nrptg*nth > 256) { nrptg--; } } @@ -4039,14 +4037,21 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op); - const int nth = std::min(1024, ne0); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + + const int nth_max = MIN(64, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + const int nth = MIN(args.ne0, nth_max); + const int nk0 = (args.ne0 + 1024 - 1)/1024; // note: 1024 is hardcoded in the kernel! ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne1, ne2, ne3, nth, 1, 1); return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f6ffb2b3a1c..4cf9dbea946 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2643,7 +2643,7 @@ kernel void kernel_gated_delta_net_impl( b_ptr += args.ne21; g_ptr += args.ne21*G; - if (K > 1u) { + if (K > 1) { const int target_slot = (int)t - shift; if (target_slot >= 0 && target_slot < (int)K) { device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; @@ -2655,7 +2655,7 @@ kernel void kernel_gated_delta_net_impl( } } - if (K == 1u) { + if (K == 1) { device float * dst_state = (device float *) (dst) + attn_size + state_out_base; FOR_UNROLL (short j = 0; j < NSG; j++) { const short is = tx*NSG + j; @@ -5104,7 +5104,7 @@ kernel void kernel_upscale_bilinear_f32( for (int64_t sx = x_min; sx < x_max; ++sx) { const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0); const float w = wx * wy; - const device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00); + device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00); sum += (*src_ptr) * w; wsum += w; } @@ -5286,7 +5286,7 @@ kernel void kernel_upscale_bicubic_f32( const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx)); const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3; - const device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00); + device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00); sum += (*src_ptr) * wx * wy; } } @@ -5329,42 +5329,46 @@ kernel void kernel_roll_f32( } } -kernel void kernel_pad_f32( +template +kernel void kernel_pad_impl( constant ggml_metal_kargs_pad & args, device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { + const int32_t i3 = tgpig.z; + const int32_t i2 = tgpig.y; + const int32_t k0 = tgpig.x/args.ne1; + const int32_t i1 = tgpig.x - k0*args.ne1; - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; + const int32_t i03 = i3; + const int32_t i02 = i2; + const int32_t i01 = i1; - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1; - - device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); - device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); + device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); - if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - if (i0 < args.ne00) { - dst_ptr[i0] = src0_ptr[i0]; - } else { - dst_ptr[i0] = 0.0f; - } + for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) { + const int32_t i0 = k0*1024 + tpitg.x + l0; + if (i0 >= args.ne0) { + break; } - return; - } - - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - dst_ptr[i0] = 0.0f; + if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + dst_ptr[i0] = src0_ptr[i0]; + } else { + dst_ptr[i0] = 0.0f; + } } } +typedef decltype(kernel_pad_impl) kernel_pad_t; + +template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl; +template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl; + +// TODO: this is slow - optimize kernel void kernel_pad_reflect_1d_f32( constant ggml_metal_kargs_pad_reflect_1d & args, device const char * src0, @@ -7328,23 +7332,27 @@ kernel void kernel_cpy_t_t( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + const int32_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; break; @@ -7376,23 +7384,27 @@ kernel void kernel_cpy_f32_q( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; + const int32_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) { device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); quantize_func(src, dst_data[i00]); @@ -7417,24 +7429,28 @@ kernel void kernel_cpy_q_f32( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + const int32_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) { T4x4 temp; dequantize_func(src_data + i00/nl, i00%nl, temp); dst_data[i00] = temp; From 3fa19558f223461fa30164384a65f592d63577ca Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger <47689530+aendk@users.noreply.github.com> Date: Wed, 20 May 2026 13:59:02 +0200 Subject: [PATCH 660/831] Programmatic Dependent Launch (PDL) for more performance on newer NVIDIA GPUs (Hopper+) (llama/22522) * Adds initial PDL setup. * Adds PDL barriers based on simple heuristic: place "sync" before first input pointer access, and "launch" after last write, e.g. to tensors like dst. * Further optimization pass of the first half of kernels * Optimized PDL barriers for the second batch of kernels * Further refinements after rebase. * Moves pdl logic to separate function, removes some whitespace * Strips post-hoc PDL logic * Adds stream capture PDL setup. Enrolls quantize_q8_1 to leverage pdl to overlap execution with previous kernels * Enrolls mul_mat_vec_q, rms_norm_f32 and k_bin_bcast (partly) into PDL * Enrolls mmvf, rope, set-rows and topk kernels for gpt-oss into PDL * Introduce ggml_cuda_kernel_launch, to abstract away cudaLaunchKernelEx, to enable hip/musa compatibility * Enrolls cpy_scalar_contiguous, k_get_rows_float and rms_norm_f32 * Enrolls flash_attn_combine_results * Fix: Drops needless and broken check of CUDA arch for PDL. PDL either works or is without effect. * Enrolls flash-attention kernels to pdl * Fix: inlines ggml_cuda_kernel_launch, and uses perfect forwarding for kernels args. This fixes PDL. * Perf: Enrolls k_bin_bcast variadic template invocation into PDL, via and template alias and template expansion * Enrolls all remaining kernels for qwen3-coder-next into PDL * Remove all PDL LC calls to create a baseline * Added LC according to internal guidance and tested kernel performance. * Enrols missing qwen3-5 kernels passively into PDL. * Kernel optimizations (LC signals) for qwen3.5 * Enrolls ssm-scan kernels into PDL * Adds GGML_CUDA_PDL command line option to toggle PDL. * Fix: Ada and lower compilation by guarding PDL calls correctly * Cleanup: Removes commented out GGML_CUDA_PDL_LC * Cleanup: Removes experimental comments * Adds 90-virtual to build script so that Hopper GPUs can leverage PDL. * Adds stricter checks to enable PDL, adds env-check to disable it, and removes now superfluous compile option to enable PDL. * Fix: Correct PDL en/disablement based on device-side arch check. Host side check is UB. Required moving from macros to inlined functions * Fix: default-disable PDL. Enable by setting GGML_CUDA_ENABLE_PDL=1 * Enable PDL by default for Hopper+ devices * Enrolls softcap_f32 and two flash_attn kernels into PDL. * Improves flash attn PDL barrier placement * Fix: Perf regression on ada; excludes ada and below from PDL launches * Improves some sync barrier placements * Drops superfluous constructor * Adds #endif guard comments * Reverts experimental change to top-k-moe.cu, which moved expensive allocations in front of the PDL barrier. It did not have a meaningful impact. * Exchanges GGML_CUDA_DISABLE_PDL with GGML_CUDA_PDL. IFF GGML_CUDA_PDL=0 PDL is disabled * Revert "Drops superfluous constructor". Adds const to remaining arguments This reverts commit 12b1d250da0089ae02a9bb71bbb3fd6d70f6f2f1. * Cleanup: Removes and fixes some comments and whitespace * Clarifies comment of sync-barrier position * Relocates and refactors PDL launch functions and accessories * Adds error checking to the regular kernel launch path * Drops "auto" in favor of "ggml_cuda_kernel_params" * Adds "const" to ggml_cuda_kernel_launch_params * [Whitespace] Adds final newline to common.cuh to make editorconfig CI job happy --- ggml/src/ggml-cuda/CMakeLists.txt | 3 +- ggml/src/ggml-cuda/binbcast.cu | 32 +++++----- ggml/src/ggml-cuda/common.cuh | 86 +++++++++++++++++++++++++++ ggml/src/ggml-cuda/concat.cu | 5 +- ggml/src/ggml-cuda/cpy.cu | 20 +++++-- ggml/src/ggml-cuda/fattn-common.cuh | 28 +++++---- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 1 + ggml/src/ggml-cuda/fattn-tile.cuh | 2 + ggml/src/ggml-cuda/fattn-vec.cuh | 3 + ggml/src/ggml-cuda/fattn-wmma-f16.cu | 1 + ggml/src/ggml-cuda/gated_delta_net.cu | 11 ++-- ggml/src/ggml-cuda/getrows.cu | 7 ++- ggml/src/ggml-cuda/mean.cu | 6 +- ggml/src/ggml-cuda/mmvf.cu | 12 ++-- ggml/src/ggml-cuda/mmvq.cu | 11 ++-- ggml/src/ggml-cuda/norm.cu | 46 ++++++++++---- ggml/src/ggml-cuda/quantize.cu | 7 ++- ggml/src/ggml-cuda/reduce_rows.cuh | 2 + ggml/src/ggml-cuda/rope.cu | 15 +++-- ggml/src/ggml-cuda/scale.cu | 5 +- ggml/src/ggml-cuda/set-rows.cu | 11 +++- ggml/src/ggml-cuda/softcap.cu | 5 +- ggml/src/ggml-cuda/ssm-conv.cu | 8 ++- ggml/src/ggml-cuda/ssm-scan.cu | 27 +++++---- ggml/src/ggml-cuda/sumrows.cu | 12 ++-- ggml/src/ggml-cuda/topk-moe.cu | 47 ++++++++------- ggml/src/ggml-cuda/unary.cu | 10 +++- 27 files changed, 310 insertions(+), 113 deletions(-) diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index b54d4a6b107..d3953eee962 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND) # 80 == Ampere, asynchronous data loading, faster tensor core instructions # 86 == RTX 3000, needs CUDA v11.1 # 89 == RTX 4000, needs CUDA v11.8 + # 90 == Hopper H100/200, needs CUDA v11.8 # 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores # # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run @@ -33,7 +34,7 @@ if (CUDAToolkit_FOUND) list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8") - list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real) + list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real 90-virtual) endif() if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8") diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index adb4d5f0cb9..c25f42b32bb 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -2,6 +2,9 @@ #include #include +template +using type_for_index = T; + static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; GGML_UNUSED(a); @@ -52,6 +55,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const int s12, const int s13, src1_ptrs... src1s) { + ggml_cuda_pdl_lc(); const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x; const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y); const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3); @@ -72,6 +76,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; + ggml_cuda_pdl_sync(); for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { const uint32_t i10 = fastmodulo(i0, ne10); @@ -141,6 +146,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const int i10 = fastmodulo(i0, ne10); + ggml_cuda_pdl_sync(); float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); @@ -282,35 +288,24 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1); const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2); - if constexpr (sizeof...(I) > 0) { - k_bin_bcast_unravel<<>>( + { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)block_num, block_size, 0, stream); + ggml_cuda_kernel_launch(k_bin_bcast_unravel...>, launch_params, src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, /*s0,*/ s1, s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); - } else { - k_bin_bcast_unravel - <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, - ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /*s0,*/ s1, s2, s3, - s00, s01, s02, s03, - s10, s11, s12, s13); } } else { const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); - if constexpr (sizeof...(I) > 0) { - k_bin_bcast<<>>( + { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(k_bin_bcast...>, launch_params, src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, /*s0,*/ s1, s2, s3, - s00 ,s01, s02, s03, - s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); - } else { - k_bin_bcast<<>>( - src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /*s0,*/ s1, s2, s3, s00, s01, s02, s03, - s10, s11, s12, s13); + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } } } @@ -333,6 +328,7 @@ static __global__ void k_repeat_back( } T sum = 0; + ggml_cuda_pdl_sync(); for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) { for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 10817505d9f..9c73fe7e6fa 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -5,6 +5,7 @@ #include "ggml-cuda.h" #include +#include #include #if defined(GGML_USE_HIP) @@ -27,6 +28,7 @@ #include #include #include +#include #include #if defined(GGML_USE_HIP) @@ -50,6 +52,7 @@ #define GGML_CUDA_CC_TURING 750 #define GGML_CUDA_CC_AMPERE 800 #define GGML_CUDA_CC_ADA_LOVELACE 890 +#define GGML_CUDA_CC_HOPPER 900 // While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms #define GGML_CUDA_CC_BLACKWELL 1200 @@ -107,6 +110,24 @@ # define GGML_CUDA_USE_CUB #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 +// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8 and excludes HIP/MUSA. +// __CUDA_ARCH__ is undefined in host passes; GPU arch check happens in device-side code. +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080 +# define GGML_CUDA_USE_PDL +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080 + +static __device__ __forceinline__ void ggml_cuda_pdl_sync() { +#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER + cudaGridDependencySynchronize(); +#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER +} + +static __device__ __forceinline__ void ggml_cuda_pdl_lc() { +#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER + cudaTriggerProgrammaticLaunchCompletion(); +#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER +} + #ifdef __CUDA_ARCH_LIST__ constexpr bool ggml_cuda_has_arch_impl(int) { return false; @@ -165,6 +186,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString) + #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA) static const char * cublas_get_error_str(const cublasStatus_t err) { return cublasGetStatusString(err); @@ -1487,3 +1509,67 @@ struct ggml_cuda_mm_fusion_args_device { const void * gate_bias = nullptr; ggml_glu_op glu_op; }; + +struct ggml_cuda_kernel_launch_params { + dim3 block_nums; + dim3 block_dims; + size_t shmem; + cudaStream_t stream; + + // size_t shmem + ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const size_t shmem_, const cudaStream_t stream_) + : block_nums(block_nums_), block_dims(block_dims_), shmem(shmem_), stream(stream_) {} + + // Some call sites pass ints instead of the required size_t. This 2nd constructor casts int->size_t to avoid these -Wnarrowing warnings. + ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const int shmem_, const cudaStream_t stream_) + : block_nums(block_nums_), block_dims(block_dims_), shmem((size_t)shmem_), stream(stream_) {} +}; + +#if defined(GGML_CUDA_USE_PDL) +struct ggml_cuda_pdl_config { + cudaLaunchAttribute attr; + cudaLaunchConfig_t cfg; + + ggml_cuda_pdl_config(const ggml_cuda_kernel_launch_params & params) { + attr.id = cudaLaunchAttributeProgrammaticStreamSerialization; + attr.val.programmaticStreamSerializationAllowed = 1; + + cfg = {}; + cfg.gridDim = params.block_nums; + cfg.blockDim = params.block_dims; + cfg.dynamicSmemBytes = params.shmem; + cfg.stream = params.stream; + cfg.attrs = &attr; + cfg.numAttrs = 1; + } + + // Delete due to &attr + ggml_cuda_pdl_config(const ggml_cuda_pdl_config&) = delete; + ggml_cuda_pdl_config& operator=(const ggml_cuda_pdl_config&) = delete; + ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete; + +}; +#endif //defined(GGML_CUDA_USE_PDL) + + +template +static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) { +#if defined(GGML_CUDA_USE_PDL) + + static const bool env_pdl_enabled = []() { + const char * env = getenv("GGML_CUDA_PDL"); + return env == nullptr || std::atoi(env) != 0; + }(); + + if (env_pdl_enabled && ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_HOPPER) { + auto pdl_cfg = ggml_cuda_pdl_config(launch_params); + + CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward(args)... )); + return; + } +#endif //defined(GGML_CUDA_USE_PDL) + + kernel<<>>(std::forward(args)... ); + CUDA_CHECK(cudaGetLastError()); +} + diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index 102f944f924..adba4d522a4 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -15,6 +15,7 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont const int64_t n = ne0 * ne1 * ne2; + ggml_cuda_pdl_sync(); for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t) blockDim.x * gridDim.x) { if constexpr (dim == 0) { const int64_t row = i / ne0; @@ -64,8 +65,8 @@ static void concat_f32_cuda(const float * x, const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; if (dim == 0) { - concat_f32_cont<0> - <<>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(concat_f32_cont<0>, launch_params,x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } if (dim == 1) { diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index d208acf2d5f..121472ec228 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -16,6 +16,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13) { + ggml_cuda_pdl_lc(); const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { @@ -36,6 +37,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13; + ggml_cuda_pdl_sync(); cpy_1(cx + x_offset, cdst + dst_offset); } @@ -59,6 +61,7 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const __shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1]; int cur_tile_buf = 0; + ggml_cuda_pdl_sync(); #pragma unroll for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) { @@ -142,6 +145,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne, const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + ggml_cuda_pdl_sync(); cpy_blck(cx + x_offset, cdst + dst_offset); } @@ -168,6 +172,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne, const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + ggml_cuda_pdl_sync(); cpy_blck(cx + x_offset, cdst + dst_offset); } @@ -182,6 +187,7 @@ static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const const src_t * x = (const src_t *) cx; dst_t * dst = (dst_t *) cdst; + ggml_cuda_pdl_sync(); dst[i] = ggml_cuda_cast(x[i]); } @@ -192,8 +198,8 @@ cudaStream_t stream) { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; GGML_ASSERT(num_blocks < UINT_MAX); - cpy_scalar_contiguous<<>> - (cx, cdst, ne); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar_contiguous, launch_params, cx, cdst, ne); } template @@ -223,13 +229,15 @@ static void ggml_cpy_scalar_cuda( GGML_ASSERT(grid_z < USHRT_MAX); dim3 dimGrid(grid_x, grid_y, grid_z); dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); - cpy_scalar_transpose<<>> - (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar_transpose, launch_params, + cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; GGML_ASSERT(num_blocks < UINT_MAX); - cpy_scalar><<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar>, launch_params, + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } } diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index beeb5238946..debcb6e5447 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -636,6 +636,7 @@ static __global__ void flash_attn_mask_to_KV_max( if (tid < WARP_SIZE) { buf_iw[tid] = 1; } + ggml_cuda_pdl_sync(); __syncthreads(); int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE; @@ -687,6 +688,7 @@ static __global__ void flash_attn_stream_k_fixup_uniform( const uint3 fd_iter_j_z, const uint3 fd_iter_j) { constexpr int ncols = ncols1*ncols2; + ggml_cuda_pdl_lc(); const int tile_idx = blockIdx.x; // One block per output tile. const int j = blockIdx.y; @@ -718,6 +720,7 @@ static __global__ void flash_attn_stream_k_fixup_uniform( dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; + ggml_cuda_pdl_sync(); // Load the partial result that needs a fixup float dst_val = *dst; float max_val; @@ -809,6 +812,7 @@ static __global__ void flash_attn_stream_k_fixup_general( float dst_val = 0.0f; float max_val = 0.0f; float rowsum = 0.0f; + ggml_cuda_pdl_sync(); { dst_val = *dst; @@ -867,6 +871,7 @@ static __global__ void flash_attn_combine_results( const float2 * __restrict__ VKQ_meta, float * __restrict__ dst, const int parallel_blocks) { + ggml_cuda_pdl_lc(); // Dimension 0: threadIdx.x // Dimension 1: blockIdx.x // Dimension 2: blockIdx.y @@ -890,6 +895,7 @@ static __global__ void flash_attn_combine_results( __builtin_assume(tid < D); extern __shared__ float2 meta[]; + ggml_cuda_pdl_sync(); for (int i = tid; i < 2*parallel_blocks; i += D) { ((float *) meta)[i] = ((const float *)VKQ_meta) [i]; } @@ -1146,7 +1152,9 @@ void launch_fattn( const uint3 ne01 = init_fastdiv_values(Q->ne[1]); GGML_ASSERT(block_dim.x % warp_size == 0); - fattn_kernel<<>>( + + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream); + ggml_cuda_kernel_launch(fattn_kernel, launch_params, (const char *) Q->data, K_data, V_data, @@ -1176,9 +1184,9 @@ void launch_fattn( const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2}; - flash_attn_stream_k_fixup_uniform - <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, 0, main_stream); + ggml_cuda_kernel_launch(flash_attn_stream_k_fixup_uniform, launch_params, + (float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk, gqa_ratio, bpt, fd0, fd1, fd2); } else if (ntiles_dst % blocks_num.x != 0) { @@ -1193,9 +1201,9 @@ void launch_fattn( const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup_general - <<>> - ((float *) KQV->data, dst_tmp_meta.ptr, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, 0, main_stream); + ggml_cuda_kernel_launch(flash_attn_stream_k_fixup_general, launch_params, + (float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], gqa_ratio, total_work, fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k); } @@ -1204,9 +1212,9 @@ void launch_fattn( const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]); const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); - flash_attn_combine_results - <<>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream); + ggml_cuda_kernel_launch(flash_attn_combine_results, launch_params, + dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); } CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index a25e912c4d2..4871b90df86 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1724,6 +1724,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { + ggml_cuda_pdl_sync(); // TODO optimize placement #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 7b0a5e5cf49..fac76f13593 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -894,6 +894,8 @@ static __global__ void flash_attn_tile( } float KQ_sum[cpw] = {0.0f}; + ggml_cuda_pdl_sync(); + // Load Q data, convert to FP16 if fast: #pragma unroll for (int jc0 = 0; jc0 < cpw; ++jc0) { diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index f0bd42a5761..b0a6cf67f1a 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -40,6 +40,7 @@ static __global__ void flash_attn_ext_vec( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { + ggml_cuda_pdl_lc(); #ifdef FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: @@ -136,6 +137,8 @@ static __global__ void flash_attn_ext_vec( #endif // V_DOT2_F32_F16_AVAILABLE int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; + + ggml_cuda_pdl_sync(); if constexpr (Q_q8_1) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index f19defbff93..4b6f6501094 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -86,6 +86,7 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + ggml_cuda_pdl_sync(); const int sequence = blockIdx.z / ne02; const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index b4c9845e7a7..018d5d37d47 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -1,4 +1,5 @@ #include "gated_delta_net.cuh" +#include "ggml-cuda/common.cuh" template __global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) @@ -53,6 +54,7 @@ gated_delta_net_cuda(const float * q, float s_shard[rows_per_lane]; // state is stored transposed: M[col][i] = S[i][col], row col is contiguous + ggml_cuda_pdl_sync(); #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; @@ -189,28 +191,29 @@ static void launch_gated_delta_net( int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); switch (S_v) { case 16: - gated_delta_net_cuda<16, KDA, keep_rs_t><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<16, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 32: - gated_delta_net_cuda<32, KDA, keep_rs_t><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<32, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; case 64: { - gated_delta_net_cuda<64, KDA, keep_rs_t><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<64, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); break; } case 128: { - gated_delta_net_cuda<128, KDA, keep_rs_t><<>>( + ggml_cuda_kernel_launch(gated_delta_net_cuda<128, KDA, keep_rs_t>, launch_params, q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 36b840e8148..457b695eb2a 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -11,6 +11,7 @@ static __global__ void k_get_rows( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { + ggml_cuda_pdl_sync(); for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. @@ -48,6 +49,8 @@ static __global__ void k_get_rows_float( /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { + ggml_cuda_pdl_lc(); + ggml_cuda_pdl_sync(); for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. @@ -83,6 +86,7 @@ static __global__ void k_get_rows_back_float( float sum = 0.0f; + ggml_cuda_pdl_sync(); for (int64_t i = 0; i < nrows_grad; ++i) { if (rows[i] != dst_row) { continue; @@ -156,7 +160,8 @@ static void get_rows_cuda_float( GGML_ASSERT(ne11 <= std::numeric_limits::max() / ne12); const uint3 ne12_fdv = init_fastdiv_values(ne12); - k_get_rows_float<<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{block_nums, block_dims, 0, stream}; + ggml_cuda_kernel_launch(k_get_rows_float, launch_params, src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ /*ne10,*/ ne11, ne12_fdv, /*ne13,*/ diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu index 49af5389957..a8f6046e46d 100644 --- a/ggml/src/ggml-cuda/mean.cu +++ b/ggml/src/ggml-cuda/mean.cu @@ -67,9 +67,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132 if ((nrows / nsm) < 2) { const dim3 block_dims(512, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } else { const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } } diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index d9147202429..09d95f309b4 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -21,6 +21,7 @@ static __global__ void mul_mat_vec_f( int channel_y; int sample_dst; + ggml_cuda_pdl_sync(); if constexpr (is_multi_token_id) { // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case token_idx = blockIdx.z; @@ -298,6 +299,7 @@ static __global__ void mul_mat_vec_f( static_assert(std::is_same_v, "unsupported type"); } + ggml_cuda_pdl_lc(); #pragma unroll for (int j = 0; j < ncols_dst; ++j) { sumf[j] = warp_reduce_sum(sumf[j]); @@ -382,11 +384,13 @@ static void mul_mat_vec_f_switch_fusion( const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) { + const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, nbytes_shared, stream}; + const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, + ggml_cuda_kernel_launch(mul_mat_vec_f, launch_params, + x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; @@ -395,8 +399,8 @@ static void mul_mat_vec_f_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_f<<>> - (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, + ggml_cuda_kernel_launch(mul_mat_vec_f, launch_params, + x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 73a0991e206..13b8b855282 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -424,6 +424,7 @@ static __global__ void mul_mat_vec_q( uint32_t channel_y; uint32_t sample_dst; + ggml_cuda_pdl_sync(); channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; sample_dst = blockIdx.z; @@ -683,8 +684,9 @@ static void mul_mat_vec_q_switch_fusion( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<<>> - (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream); + ggml_cuda_kernel_launch(mul_mat_vec_q, launch_params, + vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; @@ -693,8 +695,9 @@ static void mul_mat_vec_q_switch_fusion( GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<<>> - (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream); + ggml_cuda_kernel_launch(mul_mat_vec_q, launch_params, + vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index ef98f675aa7..09d9f3a7d62 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -18,6 +18,7 @@ static __global__ void norm_f32( float2 mean_var = make_float2(0.0f, 0.0f); + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; mean_var.x += xi; @@ -46,6 +47,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int j = start; j < end; j += block_size) { tmp += x[j]; } @@ -95,6 +97,7 @@ static __global__ void rms_norm_f32(const float * x, const uint3 add_nrows_packed = make_uint3(0, 0, 0), const uint3 add_nchannels_packed = make_uint3(0, 0, 0), const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) { + ggml_cuda_pdl_lc(); const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -124,6 +127,7 @@ static __global__ void rms_norm_f32(const float * x, float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; tmp += xi * xi; @@ -163,6 +167,7 @@ static __global__ void rms_norm_back_f32( float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xfi = xf[col]; sum_xx += xfi * xfi; @@ -253,6 +258,7 @@ static __global__ void l2_norm_f32( float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; tmp += xi * xi; @@ -261,6 +267,7 @@ static __global__ void l2_norm_f32( // sum up partial sums extern __shared__ float s_sum[]; tmp = block_reduce(tmp, s_sum); + ggml_cuda_pdl_lc(); // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html const float scale = rsqrtf(fmaxf(tmp, eps * eps)); @@ -300,10 +307,19 @@ static void rms_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, false><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = {blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, false>, launch_params, + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, false><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, false>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } } @@ -346,14 +362,20 @@ static void rms_norm_mul_f32_cuda(const float * x, const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, - mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, - mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } } else { const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); @@ -367,14 +389,16 @@ static void rms_norm_mul_f32_cuda(const float * x, const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims,block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, true, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, add_nchannels_packed, add_nsamples_packed); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true, true><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, true, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, @@ -399,10 +423,12 @@ static void l2_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - l2_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, 0, stream}; + ggml_cuda_kernel_launch(l2_norm_f32, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - l2_norm_f32<1024><< WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(l2_norm_f32<1024>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 52f664719ae..49516965cad 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -6,6 +6,7 @@ static __global__ void quantize_q8_1( const float * __restrict__ x, void * __restrict__ vy, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, const int64_t ne0, const uint32_t ne1, const uint3 ne2) { + ggml_cuda_pdl_lc(); const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= ne0) { @@ -28,6 +29,7 @@ static __global__ void quantize_q8_1( const int64_t ib = i_cont / QK8_1; // block index const int64_t iqs = i_cont % QK8_1; // quant index + ggml_cuda_pdl_sync(); const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f; float amax = fabsf(xi); float sum = xi; @@ -196,6 +198,7 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x, const int64_t i2 = blockIdx.z % ne2; const int64_t i3 = blockIdx.z / ne2; + ggml_cuda_pdl_sync(); const int64_t i01 = ids ? ids[i1] : i1; const int64_t i02 = i2; const int64_t i03 = i3; @@ -288,6 +291,7 @@ static __global__ void quantize_mmq_q8_1( const int64_t i3 = blockIdx.z / ne2; const int64_t i00 = i0; + ggml_cuda_pdl_sync(); const int64_t i01 = ids ? ids[i1] : i1; const int64_t i02 = i2; const int64_t i03 = i3; @@ -378,7 +382,8 @@ void quantize_row_q8_1_cuda( const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, block_size, 0, stream); + ggml_cuda_kernel_launch(quantize_q8_1, launch_params, x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); GGML_UNUSED(type_src0); } diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh index de240fd4413..5895d3bf8e5 100644 --- a/ggml/src/ggml-cuda/reduce_rows.cuh +++ b/ggml/src/ggml-cuda/reduce_rows.cuh @@ -10,6 +10,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r const int num_unroll = 8; float temp[num_unroll]; float sum_temp[num_unroll] = { 0.0f }; + + ggml_cuda_pdl_sync(); for (int i = col; i < ncols;) { for (int j = 0; j < num_unroll; ++j) { if (i < ncols) { diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 45a49a5dc2a..e20a5cb6bed 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -134,6 +134,7 @@ static __global__ void rope_neox(const T * x, const float * freq_factors, const int64_t * row_indices, const int set_rows_stride) { + ggml_cuda_pdl_lc(); const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); if (i0 >= ne00) { @@ -148,6 +149,7 @@ static __global__ void rope_neox(const T * x, int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. @@ -216,6 +218,7 @@ static __global__ void rope_multi(const T * x, int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); if (i0 >= n_dims) { dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; @@ -300,6 +303,7 @@ static __global__ void rope_vision(const T * x, int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); const int sect_dims = sections.v[0] + sections.v[1]; const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; @@ -399,13 +403,14 @@ static void rope_neox_cuda(const T * x, const dim3 block_nums(nr, n_blocks_x, 1); const float theta_scale = powf(freq_base, -2.0f / n_dims); + const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, 0, stream}; if (freq_factors == nullptr) { - rope_neox<<>>( + ggml_cuda_kernel_launch(rope_neox, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { - rope_neox<<>>( + ggml_cuda_kernel_launch(rope_neox, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } @@ -443,11 +448,13 @@ static void rope_multi_cuda(const T * x, const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { - rope_multi<<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(rope_multi, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } else { - rope_multi<<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(rope_multi, launch_params, x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu index 0ddeff6a175..7b2e59a4383 100644 --- a/ggml/src/ggml-cuda/scale.cu +++ b/ggml/src/ggml-cuda/scale.cu @@ -3,9 +3,11 @@ #define MAX_GRIDDIM_X 0x7FFFFFFF static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) { + ggml_cuda_pdl_lc(); int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x; int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x; + ggml_cuda_pdl_sync(); for (int64_t i = tid; i < nelements; i += stride) { dst[i] = scale * x[i] + bias; } @@ -13,7 +15,8 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) { const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<>>(x, dst, scale, bias, nelements); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(scale_f32, launch_params, x, dst, scale, bias, nelements); } void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 631de7e8fa5..e14f96b824c 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -53,6 +53,7 @@ static __global__ void k_set_rows_quant(const float * __restrict__ src0, const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); const int64_t i10 = i01; + ggml_cuda_pdl_sync(); const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; @@ -157,7 +158,9 @@ static __global__ void k_set_rows(const src_t * __restrict__ src0, const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); const int64_t i10 = i01; + ggml_cuda_pdl_sync(); const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + ggml_cuda_pdl_lc(); const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3; @@ -203,9 +206,11 @@ static void set_rows_cuda( const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11); const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12); - k_set_rows<<>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, - s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd, - ne11_fd, ne12_fd); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_size, block_size, 0, stream); + ggml_cuda_kernel_launch(k_set_rows, launch_params, + src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, + s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd, + ne11_fd, ne12_fd); } } diff --git a/ggml/src/ggml-cuda/softcap.cu b/ggml/src/ggml-cuda/softcap.cu index 40dfe45d65c..9f0fa1051cf 100644 --- a/ggml/src/ggml-cuda/softcap.cu +++ b/ggml/src/ggml-cuda/softcap.cu @@ -1,18 +1,21 @@ #include "softcap.cuh" static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) { + ggml_cuda_pdl_lc(); const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; } + ggml_cuda_pdl_sync(); dst[i] = tanhf(scale * x[i]) * softcap; } static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE; - softcap_f32<<>>(x, dst, scale, softcap, k); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(softcap_f32, launch_params, x, dst, scale, softcap, k); } // fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 4c4daf85dc6..48787b4b890 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -1,3 +1,4 @@ +#include "common.cuh" #include "ssm-conv.cuh" #include "unary.cuh" @@ -7,6 +8,7 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { + ggml_cuda_pdl_lc(); GGML_UNUSED(src0_nb0); const int tid = threadIdx.x; const int bidx = blockIdx.x; @@ -23,6 +25,7 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float float x[d_conv] = { 0.0f }; float w[d_conv] = { 0.0f }; + ggml_cuda_pdl_sync(); #pragma unroll for (size_t j = 0; j < d_conv; j++) { w[j] = w_block[tid * stride_w + j]; @@ -128,8 +131,9 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa constexpr int kNC = decltype(NC)::value; if (n_t <= 32) { const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); - ssm_conv_f32<<>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, - dst, dst_nb0, dst_nb1, dst_nb2, n_t); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_conv_f32, launch_params, src0, src1, bias, src0_nb0, src0_nb1, + src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { const int64_t split_n_t = 32; dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index c1d4e2bc8df..412980376ac 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -26,6 +26,7 @@ __global__ void __launch_bounds__(splitD, 1) const int64_t s_off, const int64_t d_inner, const int64_t L_param) { const size_t L = L_template == 0 ? L_param : L_template; + ggml_cuda_pdl_sync(); const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2); const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float)); const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float)); @@ -135,6 +136,7 @@ __global__ void __launch_bounds__(d_state, 1) const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); + ggml_cuda_pdl_sync(); // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); @@ -206,7 +208,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa constexpr int num_warps = threads/WARP_SIZE; const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); - ssm_scan_f32_group<128/WARP_SIZE, 128><<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_scan_f32_group<128/WARP_SIZE, 128>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -215,7 +218,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa constexpr int num_warps = threads/WARP_SIZE; const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); - ssm_scan_f32_group<256/WARP_SIZE, 256><<>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_scan_f32_group<256/WARP_SIZE, 256>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -231,58 +235,59 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); if (d_state == 16) { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, smem_size, stream); switch (n_tok) { case 1: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 2: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 3: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 4: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 5: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 6: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 7: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 8: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; default: - ssm_scan_f32<<>>( + ggml_cuda_kernel_launch(ssm_scan_f32, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu index 4025771aadb..0003658ca95 100644 --- a/ggml/src/ggml-cuda/sumrows.cu +++ b/ggml/src/ggml-cuda/sumrows.cu @@ -7,10 +7,12 @@ void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int const dim3 block_nums(nrows, 1, 1); if ((nrows / nsm) < 2) { const dim3 block_dims(512, 1, 1); - reduce_rows_f32<<>>(x, dst, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, x, dst, ncols); } else { const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32<<>>(x, dst, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, x, dst, ncols); } } @@ -34,10 +36,12 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if ((nrows / nsm) < 2) { // Increase num threads to 512 for small nrows to better hide the latency const dim3 block_dims(512, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } else { // Enough active SMs to hide latency, use smaller blocks to allow better scheduling const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32<<>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32, launch_params, src0_d, dst_d, ncols); } } diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 3020e5c7433..da20c9aab7c 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -105,6 +105,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * wt[i] = -INFINITY; } + ggml_cuda_pdl_sync(); #pragma unroll for (int i = 0; i < n_experts; i += WARP_SIZE) { const int expert = i + threadIdx.x; @@ -161,6 +162,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * output_weights[i] = 0.f; } + ggml_cuda_pdl_lc(); for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; int max_expert = threadIdx.x; @@ -271,51 +273,52 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); dim3 block_dims(WARP_SIZE, rows_per_block, 1); cudaStream_t stream = ctx.stream(); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); switch (n_expert) { case 1: - topk_moe_cuda<1, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<1, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 2: - topk_moe_cuda<2, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<2, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 4: - topk_moe_cuda<4, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<4, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 8: - topk_moe_cuda<8, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<8, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 16: - topk_moe_cuda<16, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<16, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 32: - topk_moe_cuda<32, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<32, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 64: - topk_moe_cuda<64, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<64, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 128: - topk_moe_cuda<128, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<128, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 256: - topk_moe_cuda<256, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<256, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 512: - topk_moe_cuda<512, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<512, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 576: - topk_moe_cuda<576, has_bias><<>>(logits, weights, ids, bias, n_rows, n_expert_used, - clamp_val, scale_val, config); + ggml_cuda_kernel_launch(topk_moe_cuda<576, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; default: GGML_ASSERT(false && "fatal error"); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 2aeba26f414..4cb805fa601 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -116,19 +116,22 @@ static __device__ __forceinline__ float op_trunc(float x) { template static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { + ggml_cuda_pdl_lc(); const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; } + ggml_cuda_pdl_sync(); dst[i] = (T)op((float)x[i]); } template static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE; - unary_op_kernel<<>>(x, dst, k); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(unary_op_kernel, launch_params, x, dst, k); } template @@ -258,6 +261,7 @@ void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { template static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) { + ggml_cuda_pdl_lc(); const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; if (i >= k) { @@ -268,13 +272,15 @@ static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t j0 = (i / n) * o0 + (i % n); const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + ggml_cuda_pdl_sync(); dst[i] = (T)(op((float)x[j0]) * (float)g[j1]); } template static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) { const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; - unary_gated_op_kernel<<>>(x, g, dst, k, n, o0, o1); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(unary_gated_op_kernel, launch_params, x, g, dst, k, n, o0, o1); } template From b93a5ba605580e6dc05d4ee78f5894bfa6021ffc Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 20 May 2026 07:39:01 -0700 Subject: [PATCH 661/831] hexagon: HMX quantized matmul rework (llama/23368) * hmx-mm: update debug logging in hmx-mm * hmx-mm: update dequant logic to use HVX_vector_x2/4 * hmx-mm: remove non-pipelined version of the quantize matmul It seems that we don't reall need non-pipelined version * hmx-mm: use activation depth mode and update naming Co-authored-by: Kim-Chyan Gan * hex-mm: minor hmx matmul naming updates * hmx-mm: remove unused vars * snapdragon: scripts bump default ubatch-size to 1K * hexagon: combine HMX and power and clock settings into a single set_power call * hmx-mm: remove leftover of the scale repl helper * hexagon: fix editconf error --------- Co-authored-by: Kim-Chyan Gan --- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 642 ++++++--------------- ggml/src/ggml-hexagon/htp/hmx-ops.h | 13 +- ggml/src/ggml-hexagon/htp/main.c | 38 +- ggml/src/ggml-hexagon/htp/matmul-ops.c | 10 +- 4 files changed, 196 insertions(+), 507 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index e05ccfd5fc7..3ef0bcdb26d 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -201,11 +201,10 @@ static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32 // Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using // full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls. -// Output: out[0..3] each hold 32 FP16 values in the first 64 bytes. -static inline void dequantize_x4x2_q4_0_x4groups_hvx( +// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes. +static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( const uint8_t *packed_128, bool upper_nibbles, - const __fp16 *scales_4, const HVX_Vector vlut_cvt, - HVX_Vector out[4]) { + const __fp16 *scales_4, const HVX_Vector vlut_cvt) { // Load all 128 packed bytes (4 contiguous 32-byte groups) HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); @@ -221,8 +220,7 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16] // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b - volatile HVX_Vector vscale = hvx_vmemu(scales_4); - + HVX_Vector vscale = hvx_vmemu(scales_4); HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); @@ -230,8 +228,9 @@ static inline void dequantize_x4x2_q4_0_x4groups_hvx( v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); // Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter - out[0] = v_lo; // group0 already in [0:63] - out[1] = v_hi; // group2 already in [0:63] + HVX_Vector_x2 r = { v_lo,/* group1 already in [0:63] */ + v_hi /* group2 already in [0:63] */ }; + return r; } // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. @@ -292,12 +291,11 @@ static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed } // Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes). -static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, +static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, bool upper_nibbles, int sub_blk_base, const HVX_Vector vlut_cvt, - mxfp4_scales_t scales, - HVX_Vector out[4]) { + mxfp4_scales_t scales) { HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; @@ -318,10 +316,8 @@ static inline void dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_12 v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - out[0] = v_lo; - out[1] = Q6_V_vror_VR(v_lo, 64); - out[2] = v_hi; - out[3] = Q6_V_vror_VR(v_hi, 64); + HVX_Vector_x4 r = { v_lo, Q6_V_vror_VR(v_lo, 64), v_hi, Q6_V_vror_VR(v_hi, 64) }; + return r; } // Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. @@ -372,18 +368,18 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - HVX_Vector v0[2]; const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - r0 = vtcm_src + row_offset; row_offset += row_stride; - dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt, v0); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, v0[1]); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } @@ -415,21 +411,21 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); - HVX_Vector v0[4], v1[4]; - dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8, v0); + HVX_Vector_x4 dv0, dv1; + dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8); if (row1 < n_cols) { mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); - dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8, v1); + dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8); } else { - v1[0] = v1[1] = v1[2] = v1[3] = Q6_V_vzero(); + dv1.v[0] = dv1.v[1] = dv1.v[2] = dv1.v[3] = Q6_V_vzero(); } for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v0[g]); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[g]); } v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, v1[g]); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[g]); } v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } @@ -612,11 +608,13 @@ static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS; const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - for (int k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; + for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HMX_FP16_TILE_N_ELMS; + col_tiles += k_block * HMX_FP16_TILE_N_ELMS; } __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS; @@ -832,10 +830,6 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 * worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); } -// - -#define FALLBACK_TO_STANDARD 1 - // C += AB static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, @@ -861,314 +855,80 @@ static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, co Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); } - for (int k = 0; k < n_dot_tiles; ++k) { - Q6_activation_hf_mxmem_RR((unsigned int)row_tiles, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, 2047); - row_tiles += HMX_FP16_TILE_N_ELMS; - col_tiles += HMX_FP16_TILE_N_ELMS; - } - Q6_mxmem_AR_after_hf(accum_tile, 0); - } - } -} - -static __attribute__((noinline)) int mat_mul_qk_0_d16a32_out_stationary(struct htp_context *ctx, - float *restrict out, const float *restrict x, const uint8_t *restrict w, - int m, int k, int n, int weight_type) { - // assume k % 32 == 0 && n % 32 == 0 - const size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; - } - - const size_t vtcm_budget = ctx->vtcm_size; - - const size_t K_BLOCK_SIZE = 1024; - - // Fallback: if k doesn't need K-blocking, out-stationary has no advantage - const size_t k_iters_check = (k + K_BLOCK_SIZE - 1) / K_BLOCK_SIZE; - if (k_iters_check <= 1) { - FARF(HIGH, "%s: K_BLK=%zu >= k=%d, fallback to standard path", __func__, K_BLOCK_SIZE, k); - return FALLBACK_TO_STANDARD; - } - - // Dynamic M,N search via hmx_compute_chunks - const size_t sub_row_stride_alloc = get_x4x2_row_stride(weight_type, K_BLOCK_SIZE); - const size_t per_m = K_BLOCK_SIZE * sizeof(float) // scratch1: M×K×4 (act DMA staging F32) - + K_BLOCK_SIZE * sizeof(__fp16); // activation: M×K×2 (F16 tiles) - const size_t per_n = sub_row_stride_alloc // scratch0: N×sub_row(K) (packed quant) - + K_BLOCK_SIZE * sizeof(__fp16); // weight: N×K×2 (F16 tiles) - const size_t per_mn = sizeof(__fp16); // output: M×N×2 (out-stationary) - - // Alignment margin: hex_align_up can add up to 2047 bytes per buffer; - // scratch1 (mc×6144) is naturally 2048-aligned, remaining 4 buffers need margin - const size_t align_margin = 4 * HMX_FP16_TILE_SIZE; - const size_t overhead = HMX_FP16_TILE_SIZE + 256 + align_margin; // eye_tile + scales + alignment - - size_t M_BLOCK_SIZE, N_BLOCK_SIZE, vtcm_used; - // Cost-based search: minimize ceil(m/mc)*m_block_cost + ceil(n/nc)*n_block_cost. - // From profiling: wt_dequant per element ≈ 1.5× activation load per element. - // m_block_cost = n*3: each extra M-block re-dequants all N×K weight (expensive). - // n_block_cost = m*2: each extra N-block re-loads all M×K activation (cheaper). - const size_t m_block_cost = (size_t) n * 3; - const size_t n_block_cost = (size_t) m * 2; - if (hmx_compute_chunks(vtcm_budget, overhead, per_n, per_m, per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - m_block_cost, n_block_cost, &M_BLOCK_SIZE, - &N_BLOCK_SIZE, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; - } - - // Compute precise buffer sizes from searched M,N and fixed K - const size_t weight_size = hex_align_up(N_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t act_size = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t out_size = hex_align_up(M_BLOCK_SIZE * N_BLOCK_SIZE * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch0_sz = hex_align_up(N_BLOCK_SIZE * sub_row_stride_alloc, HMX_FP16_TILE_SIZE); - const size_t scratch1_sz = hex_align_up(M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof(float), HMX_FP16_TILE_SIZE); - - const size_t total_vtcm = weight_size + act_size + out_size + scratch0_sz + scratch1_sz + HMX_FP16_TILE_SIZE + 256; - if (total_vtcm > vtcm_budget) { - FARF(HIGH, "%s: VTCM overflow after search: need %zu have %zu (M=%zu N=%zu K=%zu)", __func__, total_vtcm, - vtcm_budget, M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE); - return -1; - } - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, out_size); - uint8_t *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_sz); - uint8_t *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_sz); - __fp16 *vtcm_eye_tile = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, HMX_FP16_TILE_SIZE); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - assert((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) <= vtcm_budget); - - FARF(HIGH, "hmx-mm: m=%d k=%d n=%d wtype=%d block M=%zu N=%zu K=%zu vtcm=%zu/%zu", m, k, n, weight_type, - M_BLOCK_SIZE, N_BLOCK_SIZE, K_BLOCK_SIZE, (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); - - // initialize eye tile (32x32 identity matrix) - { - HVX_Vector v; - v = Q6_V_vzero(); - v = Q6_Vw_vinsert_VwR(v, 0x3c000000); - v = Q6_V_vror_VR(v, VLEN - 4); - v = Q6_Vw_vinsert_VwR(v, 0x00003c00); - for (int i = 0; i < 16; ++i) { - ((HVX_Vector *) vtcm_eye_tile)[i] = v; - v = Q6_V_vror_VR(v, VLEN - 8); - } - } - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - TIMER_DEFINE(fetch); - TIMER_DEFINE(act_load); - TIMER_DEFINE(wt_dequant); - TIMER_DEFINE(core); - - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (size_t mr = 0; mr < m; mr += M_BLOCK_SIZE) { - size_t m_blk_sz = hex_smin(m - mr, M_BLOCK_SIZE); - for (size_t nc = 0; nc < n; nc += N_BLOCK_SIZE) { - size_t n_blk_sz = hex_smin(n - nc, N_BLOCK_SIZE); - - const int n_row_tiles = hmx_ceil_div(m_blk_sz, HMX_FP16_TILE_N_ROWS); - const int n_col_tiles = hmx_ceil_div(n_blk_sz, HMX_FP16_TILE_N_COLS); - - for (size_t kk = 0; kk < k; kk += K_BLOCK_SIZE) { - const size_t k_blk_sz = hex_smin(k - kk, K_BLOCK_SIZE); - - TIMER_START(fetch); - // fetch activation block into VTCM - { - const float *activation_block = x + mr * k + kk; - - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_scratch1, activation_block), - k_blk_sz * sizeof(float), - k * sizeof(float), - k_blk_sz * sizeof(float), - m_blk_sz); - } - - // fetch weight block into VTCM (x4x2 sub-block: quants + scales) - const size_t sub_row_stride = get_x4x2_row_stride(weight_type, k_blk_sz); - { - const int blk_start = kk / QK_Q4_0x4x2; - const int nb_sub = (k_blk_sz + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; - const int full_qrow = (weight_type == HTP_TYPE_Q8_0) ? k : (k / 2); - const int scale_blk_size = (weight_type == HTP_TYPE_MXFP4) ? HMX_X4X2_MXFP4_EBLK_SIZE : HMX_X4X2_DBLK_SIZE; - uint8_t *dst = vtcm_scratch0; - const uint8_t *src = w + nc * row_stride; - const size_t n_rows = n_blk_sz; - const size_t src_stride = row_stride; - const size_t dst_stride = sub_row_stride; - const size_t quant_off = (weight_type == HTP_TYPE_Q8_0) ? (blk_start * QK_Q8_0x4x2) : (blk_start * (QK_Q4_0x4x2 / 2)); - const size_t quant_width = (weight_type == HTP_TYPE_Q8_0) ? (nb_sub * QK_Q8_0x4x2) : (nb_sub * (QK_Q4_0x4x2 / 2)); - const size_t scale_off = full_qrow + blk_start * scale_blk_size; - const size_t scale_width = nb_sub * scale_blk_size; - - // 2D DMA: quants sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(dst, src + quant_off), dst_stride, src_stride, quant_width, n_rows); - // 2D DMA: scales sub-range - dma_queue_push(ctx->dma[0], dma_make_ptr(dst + quant_width, src + scale_off), dst_stride, src_stride, scale_width, n_rows); - } - TIMER_STOP(fetch); - - TIMER_START(act_load); - // load activation block - { - dma_queue_pop(ctx->dma[0]); // wait for act DNA - transfer_activation_chunk_threaded(ctx, vtcm_activation, (float *) vtcm_scratch1, m_blk_sz, k_blk_sz, k_blk_sz); - } - TIMER_STOP(act_load); - - TIMER_START(wt_dequant); - // dequantize weight block - { - dma_queue_pop(ctx->dma[0]); - dma_queue_pop(ctx->dma[0]); - // vtcm_scratch0 is used to store the qweight chunk - // worker_pool_run_func already returned, so fetch is done - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, vtcm_scratch0, - n_blk_sz, k_blk_sz, sub_row_stride, weight_type); - } - TIMER_STOP(wt_dequant); - - // core mma - TIMER_START(core); - { - core_mma_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, vtcm_eye_tile, n_row_tiles, - n_col_tiles, k_blk_sz / HMX_FP16_TILE_N_COLS, kk == 0); - } - TIMER_STOP(core); + for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HMX_FP16_TILE_N_ELMS; + col_tiles += k_block * HMX_FP16_TILE_N_ELMS; } - // store output block - { - float *output_block = out + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_block, vtcm_output, m_blk_sz, n_blk_sz, n); - } + Q6_mxmem_AR_after_hf(accum_tile, 0); } } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - -#if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "fetch: %lld us, act_load: %lld us, wt_dequant: %lld us, core: %lld us", - TIMER_US(fetch), TIMER_US(act_load), TIMER_US(wt_dequant), TIMER_US(core)); -#endif - return 0; } -int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const uint8_t *restrict permuted_weight, int m, int k, int n, int weight_type) { - if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } if (k % 32 != 0 || n % 32 != 0) { return -1; } if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { return -1; } - // for large m, k (e.g. prefill FFN Down), use out-stationary version - if (m >= 128 && k > n && n > 1024) { - int rc = mat_mul_qk_0_d16a32_out_stationary(ctx, dst, activation, permuted_weight, m, k, n, weight_type); - if (rc != FALLBACK_TO_STANDARD) { - return rc; // 0 success, -1 error - } - FARF(HIGH, "hmx_matmul_qk: out-stationary fallback to standard m=%d k=%d n=%d", m, k, n); - // fall through to standard path - } - size_t row_stride = get_x4x2_row_stride(weight_type, k); if (row_stride == 0) { return -1; } - FARF(HIGH, "hmx_matmul_qk: STANDARD path m=%d k=%d n=%d type=%d", m, k, n, weight_type); - // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. - // Only pays off when the chunker yields >=2 n-chunks, so the main loop can - // overlap HMX (C) with HVX (B/D); with a single n-chunk the extra VTCM for - // double-buffered output and the worker-dispatch overhead are pure loss. - // Try pipeline costs first; fall back to sequential if the layout collapses - // to one n-chunk. m >= 128 floor keeps HMX utilization reasonable. - const size_t pipe_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) - const size_t pipe_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) - const size_t seq_per_n = vec_dot_size + 2 * row_stride; // W + S0 + S1 (x4x2 DMA bufs) - const size_t seq_per_mn = sizeof(__fp16); // O x 1 - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - bool use_pipeline = false; - - if (m >= 128) { - size_t mc = 0, nc = 0, used = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, pipe_per_n, /*per_m=*/vec_dot_size, pipe_per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &mc, &nc, &used) == 0 && - hmx_ceil_div((size_t) n, nc) >= 2) { - m_chunk_n_rows = mc; - n_chunk_n_cols = nc; - vtcm_used = used; - use_pipeline = true; - } - } + const size_t size_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) + const size_t size_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) - if (!use_pipeline) { - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, seq_per_n, /*per_m=*/vec_dot_size, seq_per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; - } + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, + hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(HIGH, "hmx-mm-q: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); + return -1; } - // Compute precise buffer sizes per execution path - const size_t weight_area_size = hex_align_up( - n_chunk_n_cols * (use_pipeline ? row_stride : vec_dot_size), HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up( - m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); size_t scratch0_size, scratch1_size, scratch2_size; - if (use_pipeline) { - scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = scratch0_size; // dequant buf 1 - scratch2_size = output_area_size; // output buf 1 - } else { - scratch0_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); // x4x2 DMA buf 0 - scratch1_size = scratch0_size; // x4x2 DMA buf 1 - scratch2_size = 0; // unused - } + scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 + scratch1_size = scratch0_size; // dequant buf 1 + scratch2_size = output_area_size; // output buf 1 uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { - FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-q: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); return -1; } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(HIGH, "%s: m=%d k=%d n=%d wtype=%d pipe=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, weight_type, use_pipeline, - m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + FARF(HIGH, "hmx-mm-q: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", + m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); TIMER_DEFINE(activation_load); TIMER_DEFINE(weight_load); @@ -1178,184 +938,115 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds TIMER_DEFINE(total); TIMER_START(total); - FARF(HIGH, "hmx_matmul_qk: %s mc=%zu nc=%zu vtcm=%zu/%zu", - use_pipeline ? "PIPELINE" : "SEQUENTIAL", m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) + // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - if (!use_pipeline) { - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - // transfer activation matrix chunk into VTCM - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + // A --> B: vtcm_qweight, 1 buffer + // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers + // C --> D: vtcm_output0/vtcm_output1, 2 buffers - TIMER_START(activation_load); - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } - TIMER_STOP(activation_load); + // Async timeline (C overlaps B+D): + // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] + // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; - - { - const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), row_stride, row_stride, row_stride, n_cols_first); - } - - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); // wait until current weight chunk become ready - - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < n) { - const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); - - const uint8_t *next_weight_chunk = permuted_weight + nc_next * row_stride; - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), row_stride, row_stride, row_stride, n_cols_next); - } + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - // Dequant + vscatter writes directly to [K, N] transposed tiles. - // HMX computes C = A x B, where A=[M,K] activation, B=[K,N] weight. - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight, buf_curr, n_cols, k, row_stride, weight_type); + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - hex_swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); + void *vtcm_qweight = vtcm_weight; + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - TIMER_START(hmx_core); - { - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); - } - TIMER_STOP(hmx_core); - - TIMER_START(output_store); - { - float *output = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); - } - TIMER_STOP(output_store); - } + // prologue: A0 + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + { + const uint8_t *qweight_chunk_A0 = permuted_weight; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); } - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - } else { - // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) - // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - - // A --> B: vtcm_qweight, 1 buffer - // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers - // C --> D: vtcm_output0/vtcm_output1, 2 buffers - - // Async timeline (C overlaps B+D): - // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] - // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] - - int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - - void *vtcm_qweight = vtcm_weight; - void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; - void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + { + const float *activation_chunk = activation + mr * k; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); + } - // prologue: A0 - const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); - { - // Use 2D DMA (n_cols rows x row_stride) to avoid 16-bit roiwidth overflow. - const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); + // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) + { + // B0: wait for DMA, dequant weight chunk 0 + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); + + // A1: issue DMA for weight chunk 1 + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (1 < n_chunk_cnt) { + const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); } - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } + // submit C0 (non-blocking — HMX worker executes in parallel) + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) - { - // B0: wait for DMA, dequant weight chunk 0 + // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) + if (1 < n_chunk_cnt) { dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); - - // A1: issue DMA for weight chunk 1 - const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); - if (1 < n_chunk_cnt) { - const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); - } - - // submit C0 (non-blocking — HMX worker executes in parallel) - hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, - (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - - // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) - if (1 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); - } + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); } + } - // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) - for (int i = 0; i < n_chunk_cnt; ++i) { - const size_t nc = i * n_chunk_n_cols; - const size_t nc_p1 = nc + 1 * n_chunk_n_cols; - const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); - const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); - // issue A_{i+2}: DMA push (non-blocking) - if (i + 2 < n_chunk_cnt) { - const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); - } + // issue A_{i+2}: DMA push (non-blocking) + if (i + 2 < n_chunk_cnt) { + const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); + } - // wait C_i: block until prologue/previous C completes - hmx_queue_pop(ctx->hmx_queue); - - // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) - // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's - // counterpart — and (i+1)%2 was last used by C_{i-1} which completed - // before C_i was submitted. - if (i + 1 < n_chunk_cnt) { - hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], - (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], - vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); - } + // wait C_i: block until prologue/previous C completes + hmx_queue_pop(ctx->hmx_queue); + + // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) + // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's + // counterpart — and (i+1)%2 was last used by C_{i-1} which completed + // before C_i was submitted. + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } - // D_i: store output (multi-thread HVX, parallel with C_{i+1}) - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); + // D_i: store output (multi-thread HVX, parallel with C_{i+1}) + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); - // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) - if (i + 2 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); - } + // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) + if (i + 2 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); } } - - hmx_queue_suspend(ctx->hmx_queue); } + hmx_queue_suspend(ctx->hmx_queue); + TIMER_STOP(total); #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d pipeline=%d", __func__, TIMER_US(total), m, k, n, use_pipeline); + FARF(HIGH, "hex-mm-q: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); if (!use_pipeline) { FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); @@ -1370,15 +1061,15 @@ int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, float *restrict ds // -static inline int hmx_matmul_batch_r2(const hmx_matmul_w16a32_batched_params_t *params) { +static inline int hmx_matmul_batch_r2(const hmx_matmul_f16_f32_batched_params_t *params) { return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; } -static inline int hmx_matmul_batch_r3(const hmx_matmul_w16a32_batched_params_t *params) { +static inline int hmx_matmul_batch_r3(const hmx_matmul_f16_f32_batched_params_t *params) { return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; } -static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { const int r2 = hmx_matmul_batch_r2(params); const int r3 = hmx_matmul_batch_r3(params); @@ -1387,37 +1078,36 @@ static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_w16a32_ (size_t) (dst_b3 / r3) * params->src0_nb3); } -static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { return (const float *) ((const uint8_t *) params->activation + (size_t) dst_b2 * params->src1_nb2 + (size_t) dst_b3 * params->src1_nb3); } -static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_w16a32_batched_params_t *params, +static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, int dst_b2, int dst_b3) { return (float *) ((uint8_t *) params->dst + (size_t) dst_b2 * params->dst_nb2 + (size_t) dst_b3 * params->dst_nb3); } -static int hmx_mat_mul_permuted_w16a32_batched_legacy(struct htp_context *ctx, - const hmx_matmul_w16a32_batched_params_t *params) { +static int hmx_matmul_f16_f32_batched_legacy(struct htp_context *ctx, + const hmx_matmul_f16_f32_batched_params_t *params) { int ret = 0; for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { - ret = hmx_mat_mul_permuted_w16a32(ctx, - hmx_matmul_dst_batch_ptr(params, b2, b3), - hmx_matmul_activation_batch_ptr(params, b2, b3), - hmx_matmul_weight_batch_ptr(params, b2, b3), - params->m, params->k, params->n, - params->act_stride, params->weight_stride); + ret = hmx_matmul_f16_f32(ctx, hmx_matmul_dst_batch_ptr(params, b2, b3), + hmx_matmul_activation_batch_ptr(params, b2, b3), + hmx_matmul_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride); } } return ret; } -int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmul_w16a32_batched_params_t *params) { +int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params) { if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } if (!params->m || !params->k || !params->n) { return -1; } if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } @@ -1435,7 +1125,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu if (group_size <= 1) { FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } // Grouped path: reuse interleaved weight across all q_heads sharing a @@ -1464,7 +1154,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu /*m_block_cost=*/(size_t) params->n, /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads @@ -1486,7 +1176,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); - return hmx_mat_mul_permuted_w16a32_batched_legacy(ctx, params); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 @@ -1614,7 +1304,7 @@ int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, const hmx_matmu // -int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const __fp16 *restrict permuted_weight, int m, int k, int n, int act_stride, int weight_stride) { if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index 1c78ffadd1c..f114edb822f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -33,14 +33,14 @@ typedef struct { size_t src1_nb3; size_t dst_nb2; size_t dst_nb3; -} hmx_matmul_w16a32_batched_params_t; +} hmx_matmul_f16_f32_batched_params_t; // HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output // act_stride: activation row stride in elements (= k for contiguous, or // nb[1]/sizeof(float) for permuted tensors like attention Q). // weight_stride: weight row stride in elements (= k for compact weights, or // nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK). -int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, +int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const __fp16 *permuted_weight, @@ -48,13 +48,12 @@ int hmx_mat_mul_permuted_w16a32(struct htp_context *ctx, int act_stride, int weight_stride); -// Batched F16 wrapper over hmx_mat_mul_permuted_w16a32. +// Batched F16 wrapper over hmx_mat_mul_f16_f32. // Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. -int hmx_mat_mul_permuted_w16a32_batched(struct htp_context *ctx, - const hmx_matmul_w16a32_batched_params_t *params); +int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params); -// HMX matrix multiplication — tile-permuted quantised weights (Q4_0/Q8_0/IQ4_NL) -int hmx_mat_mul_permuted_qk_0_d16a32(struct htp_context *ctx, +// HMX matrix multiplication — quantised weights (Q4_0/Q8_0/IQ4_NL/MXFP4) +int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const uint8_t *permuted_weight, diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 8e54536f619..e8619388478 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -87,35 +87,37 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { } } +#if __HVX_ARCH__ >= 75 { - // Power on HMX + // Power on HMX and set HMX clock HAP_power_request_t request; memset(&request, 0, sizeof(HAP_power_request_t)); - request.type = HAP_power_set_HMX; - request.hmx.power_up = TRUE; - FARF(ALWAYS, "Powering HMX on\n"); - err = HAP_power_set((void *) &ctx, &request); + request.type = HAP_power_set_HMX_v2; + request.hmx_v2.set_power = TRUE; + request.hmx_v2.power_up = TRUE; + request.hmx_v2.set_clock = TRUE; + request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH; + FARF(ALWAYS, "Setting HMX clock\n"); + err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error powering on HMX."); + FARF(ERROR, "Error setting HMX clock."); return err; } } - -#if __HVX_ARCH__ >= 75 +#else { - // Set HMX clock + // Power on HMX HAP_power_request_t request; memset(&request, 0, sizeof(HAP_power_request_t)); - request.type = HAP_power_set_HMX_v2; - request.hmx_v2.set_clock = TRUE; - request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX; - request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH; - FARF(ALWAYS, "Setting HMX clock\n"); - err = HAP_power_set((void *) &ctx, &request); + request.type = HAP_power_set_HMX; + request.hmx.power_up = TRUE; + FARF(ALWAYS, "Powering HMX on\n"); + err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error setting HMX clock."); + FARF(ERROR, "Error powering on HMX."); return err; } } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 2461ae617fa..46fc5602dc9 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -2995,7 +2995,6 @@ int op_matmul(struct htp_ops_context * octx) { // is handled by HMX itself; when M < 32 fall back to HVX. const int m_total = (int) src1->ne[1]; const int m_hmx = m_total & ~31; // 0 when M < 32 - if (m_hmx == 0) { return op_matmul_hvx(octx); } @@ -3020,7 +3019,7 @@ int op_matmul(struct htp_ops_context * octx) { if (src0->type == HTP_TYPE_F16) { if (is_batched) { - hmx_matmul_w16a32_batched_params_t batch_params = { + hmx_matmul_f16_f32_batched_params_t batch_params = { .dst = (float *) dst->data, .activation = (float *) src1->data, .permuted_weight = (const __fp16 *) src0->data, @@ -3041,15 +3040,14 @@ int op_matmul(struct htp_ops_context * octx) { .dst_nb2 = dst->nb[2], .dst_nb3 = dst->nb[3], }; - ret = hmx_mat_mul_permuted_w16a32_batched(octx->ctx, &batch_params); + ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params); } else { - ret = hmx_mat_mul_permuted_w16a32(octx->ctx, + ret = hmx_matmul_f16_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, m_total, k, n, act_stride, wgt_stride); } } else { - ret = hmx_mat_mul_permuted_qk_0_d16a32(octx->ctx, - (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + ret = hmx_matmul_q_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, m_total, k, n, (int) src0->type); } From ad717a6de0727b64345d7141094e7b87c43952a1 Mon Sep 17 00:00:00 2001 From: Daniele <57776841+daniandtheweb@users.noreply.github.com> Date: Wed, 20 May 2026 17:15:13 +0200 Subject: [PATCH 662/831] vulkan: optimize operations in the IM2COL shader (llama/22685) * vulkan: optimize operations in the IM2COL shader * Add comments and improve the code formatting --- .../ggml-vulkan/vulkan-shaders/im2col.comp | 73 +++++++++++++++---- 1 file changed, 59 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index ba4c2103f0c..f4130d223b1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -44,36 +44,81 @@ void im2col(const uint ow, const uint z_idx) { const uint KHKW = p.KH * p.KW; + // Precompute base input coordinates + const int base_iw = int(ow * p.s0) - p.p0; + const int base_ih = int(oh * p.s1) - p.p1; + + // Precompute step deltas + const uint delta_ic = BLOCK_SIZE / KHKW; + const uint delta_rem = BLOCK_SIZE % KHKW; + + const uint delta_ky = delta_rem / p.KW; + const uint delta_kx = delta_rem % p.KW; + + const uint delta_ic_offset = delta_ic * p.offset_delta; + + // If using BDA mode, precompute the base pointer and step size +#if BDA + const BDA_STORAGE_T base_dst_addr = p.dst_addr + D_SIZE * dst_row; + const uint bda_step = D_SIZE * BLOCK_SIZE; +#endif + uint wg_x = gl_WorkGroupID.x; do { const uint wg_offset = wg_x * 512; - [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { - const uint chw_idx = wg_offset + gidx + i * BLOCK_SIZE; + uint chw_idx = wg_offset + gidx; + + uint ic = chw_idx / KHKW; + uint rem = chw_idx % KHKW; + + uint ky = rem / p.KW; + uint kx = rem % p.KW; + uint ic_offset = src_batch + ic * p.offset_delta; + + // Initialize running pointer/index for the destination buffer +#if BDA + BDA_STORAGE_T current_dst_addr = base_dst_addr + D_SIZE * chw_idx; +#else + uint current_dst_idx = dst_row + chw_idx; +#endif + + [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { if (chw_idx >= p.CHW) { return; } - const uint ic = chw_idx / KHKW; - const uint rem = chw_idx - ic * KHKW; - const uint ky = rem / p.KW; - const uint kx = rem - ky * p.KW; - - const uint iiw = ow * p.s0 + kx * p.d0 - p.p0; - const uint iih = oh * p.s1 + ky * p.d1 - p.p1; + const int iiw = base_iw + int(kx * p.d0); + const int iih = base_ih + int(ky * p.d1); A_TYPE val = A_TYPE(0); - if (iih < p.IH && iiw < p.IW) { - val = data_a[src_batch + ic * p.offset_delta + iih * p.IW + iiw]; + if (uint(iih) < p.IH && uint(iiw) < p.IW) { + val = data_a[ic_offset + uint(iih) * p.IW + uint(iiw)]; } #if BDA - D_ptr out_ptr = D_ptr(p.dst_addr + D_SIZE * (dst_row + chw_idx)); - out_ptr.d = D_TYPE(val); + D_ptr(current_dst_addr).d = D_TYPE(val); + current_dst_addr += bda_step; #else - data_d[dst_row + chw_idx] = D_TYPE(val); + data_d[current_dst_idx] = D_TYPE(val); + current_dst_idx += BLOCK_SIZE; #endif + + chw_idx += BLOCK_SIZE; + ic_offset += delta_ic_offset; + kx += delta_kx; + ky += delta_ky; + + // Handle X axis wrap + uint kx_wrap = uint(kx >= p.KW); + kx -= kx_wrap * p.KW; + ky += kx_wrap; + + // Handle Y axis wrap + uint ky_wrap = uint(ky >= p.KH); + ky -= ky_wrap * p.KH; + ic_offset += ky_wrap * p.offset_delta; } wg_x += gl_NumWorkGroups.x; From 896718eacf0fd975368784a917d2e4d1856a6e70 Mon Sep 17 00:00:00 2001 From: lhez Date: Wed, 20 May 2026 09:57:36 -0700 Subject: [PATCH 663/831] opencl: refactor backend initilization (llama/23318) * opencl: refactor initialization * opencl: refactor GPU identification * opencl: rename for consistency * opencl: cache global mem size in dev_ctx * opencl: adjust log level * opencl: load argsort and flash_attn kernels in supports_op * argsort kernel must be built for supports_op for querying the max workgroups * flash_attn kernel has many variants, only load them when needed --- ggml/src/ggml-opencl/ggml-opencl.cpp | 429 ++++++++++++++++----------- 1 file changed, 254 insertions(+), 175 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index a3af8c2da41..5fc46f789ec 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -375,6 +375,11 @@ struct ggml_backend_opencl_device_context { ggml_backend_buffer_type buffer_type; cl_context context = nullptr; + + GPU_FAMILY gpu_family = GPU_FAMILY::UNKNOWN; + ADRENO_GPU_GEN adreno_gen = ADRENO_GPU_GEN::ADRENO_UNKNOWN; + + size_t global_mem_size = 0; }; // backend context @@ -384,6 +389,18 @@ struct ggml_backend_opencl_context { cl_device_id device; std::string device_name; + ggml_cl_version platform_version; + ggml_cl_version opencl_c_version; + + // argsort is loaded in supports_op because its availability depends on how + // many workgroups are allowed, which requires kernel compilation. + bool kernels_loaded_argsort = false; + // flash attn is loaded in supports_op because it contains multiple variants + // and takes time to compile, so we want to only compile it when needed. + bool kernels_loaded_flash_attn = false; + // rest of the kernels are currently always loaded in alloc_buffer. + bool kernels_loaded = false; + std::string driver_version; GPU_FAMILY gpu_family; @@ -781,6 +798,8 @@ struct ggml_backend_opencl_context { #endif // GGML_OPENCL_USE_ADRENO_KERNELS void free() { + clFinish(queue); + ref_count--; if (ref_count == 0) { #ifdef GGML_OPENCL_PROFILING @@ -793,6 +812,9 @@ struct ggml_backend_opencl_context { // All registered devices with a default device in the front. static std::vector g_ggml_backend_opencl_devices; +// All device contexts associated with the devices above. +// The devices live as long as the process, so do the contexts. +static std::vector> g_ggml_backend_opencl_dev_ctxs; inline std::string read_file(const std::string &path) { std::ifstream ifs(path); @@ -836,12 +858,120 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co return p; } -static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_version opencl_c_version) { +static void load_cl_kernels_argsort(ggml_backend_opencl_context *backend_ctx) { + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + // argsort + if (!backend_ctx->kernels_loaded_argsort) { + cl_int err; +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "argsort.cl.h" + }; +#else + const std::string kernel_src = read_file("argsort.cl"); +#endif + backend_ctx->program_argsort_f32_i32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, "kernel_argsort_f32_i32", &err), err)); + backend_ctx->kernels_loaded_argsort = true; + } +} + +static void load_cl_kernels_flash_attn(ggml_backend_opencl_context *backend_ctx) { + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + // flash_attn + if (!backend_ctx->kernels_loaded_flash_attn) { + cl_int err; + + #ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_f16 { + #include "flash_attn_f16.cl.h" + }; + const std::string kernel_src_f32 { + #include "flash_attn_f32.cl.h" + }; + const std::string kernel_src_f32_f16 { + #include "flash_attn_f32_f16.cl.h" + }; + #else + const std::string kernel_src_f16 = read_file("flash_attn_f16.cl"); + const std::string kernel_src_f32 = read_file("flash_attn_f32.cl"); + const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl"); + #endif + + if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { + const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { + { 40, 40, 32, 32}, { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, + {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, + {192, 192, 16, 16}, {256, 256, 16, 16}, + }; + + for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) { + const int dk = fa_dims[i].dk; + const int dv = fa_dims[i].dv; + const int bm = fa_dims[i].bm; + const int bn = fa_dims[i].bn; + std::string OPTS = compile_opts + + " -D DK=" + std::to_string(dk) + + " -D DV=" + std::to_string(dv) + + " -D BLOCK_M=" + std::to_string(bm) + + " -D BLOCK_N=" + std::to_string(bn); + + cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS); + cl_kernel k_f16, k_f16_q1; + CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); + CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); + backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; + backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; + CL_CHECK(clReleaseProgram(prog_f16)); + + cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS); + cl_kernel k_f32, k_f32_q1; + CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); + CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); + backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; + backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; + CL_CHECK(clReleaseProgram(prog_f32)); + + cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS); + cl_kernel k_f32_f16, k_f32_f16_q1; + CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); + CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); + backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; + backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; + CL_CHECK(clReleaseProgram(prog_f32_f16)); + + backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; + backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn; + } + backend_ctx->kernels_loaded_flash_attn = true; + } + } +} + +static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { + if (backend_ctx->kernels_loaded) { + return; + } + cl_int err; // compiler options for general kernels auto opencl_c_std = - std::string("CL") + std::to_string(opencl_c_version.major) + "." + std::to_string(opencl_c_version.minor); + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); std::string compile_opts = std::string("-cl-std=") + opencl_c_std + " -cl-mad-enable -cl-unsafe-math-optimizations" " -cl-finite-math-only -cl-fast-relaxed-math"; @@ -1986,89 +2116,6 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } - // flash_attn - { - #ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_f16 { - #include "flash_attn_f16.cl.h" - }; - const std::string kernel_src_f32 { - #include "flash_attn_f32.cl.h" - }; - const std::string kernel_src_f32_f16 { - #include "flash_attn_f32_f16.cl.h" - }; - #else - const std::string kernel_src_f16 = read_file("flash_attn_f16.cl"); - const std::string kernel_src_f32 = read_file("flash_attn_f32.cl"); - const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl"); - #endif - - if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { - const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { - { 40, 40, 32, 32}, { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, - {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, - {192, 192, 16, 16}, {256, 256, 16, 16}, - }; - - for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) { - const int dk = fa_dims[i].dk; - const int dv = fa_dims[i].dv; - const int bm = fa_dims[i].bm; - const int bn = fa_dims[i].bn; - std::string OPTS = compile_opts + - " -D DK=" + std::to_string(dk) + - " -D DV=" + std::to_string(dv) + - " -D BLOCK_M=" + std::to_string(bm) + - " -D BLOCK_N=" + std::to_string(bn); - - cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS); - cl_kernel k_f16, k_f16_q1; - CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); - CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); - backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; - backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; - CL_CHECK(clReleaseProgram(prog_f16)); - - cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS); - cl_kernel k_f32, k_f32_q1; - CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); - CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); - backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; - backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; - CL_CHECK(clReleaseProgram(prog_f32)); - - cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS); - cl_kernel k_f32_f16, k_f32_f16_q1; - CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); - CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); - backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; - backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; - CL_CHECK(clReleaseProgram(prog_f32_f16)); - - backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; - backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn; - } - GGML_LOG_CONT("."); - } - } - - // argsort - { -#ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src { - #include "argsort.cl.h" - }; -#else - const std::string kernel_src = read_file("argsort.cl"); -#endif - backend_ctx->program_argsort_f32_i32 = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - - CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, "kernel_argsort_f32_i32", &err), err)); - GGML_LOG_CONT("."); - } - // div { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -3335,13 +3382,15 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve } #endif // GGML_OPENCL_USE_ADRENO_KERNELS GGML_LOG_CONT("\n"); + backend_ctx->kernels_loaded = true; } // XXX static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { // XXX static bool initialized = false; // XXX static ggml_backend_opencl_context *backend_ctx = nullptr; -static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev); +static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev); +static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev); namespace /* anonymous */ { extern struct ggml_backend_device_i ggml_backend_opencl_device_i; @@ -3554,13 +3603,13 @@ static std::vector ggml_opencl_probe_devices(ggml_backend_r /* .context = */ dev_ctx.get(), }); - if (!ggml_cl2_init(&found_devices.back())) { + if (!ggml_opencl_is_device_supported(&found_devices.back())) { found_devices.pop_back(); - GGML_LOG_INFO("ggml_opencl: drop unsupported device.\n"); + GGML_LOG_WARN("ggml_opencl: drop unsupported device '%s'.\n", dev->name); continue; } - dev_ctx.release(); + g_ggml_backend_opencl_dev_ctxs.push_back(std::move(dev_ctx)); } if (found_devices.size()) { @@ -3577,8 +3626,79 @@ static std::vector ggml_opencl_probe_devices(ggml_backend_r return found_devices; } +// check if device should be accepted +static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev) { + GGML_ASSERT(dev); + GGML_ASSERT(dev->context); + + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + GGML_ASSERT(dev_ctx->platform); + GGML_ASSERT(dev_ctx->device); + + if (strstr(dev_ctx->device_name.c_str(), "Adreno") || + strstr(dev_ctx->device_name.c_str(), "Qualcomm") || + strstr(dev_ctx->device_version.c_str(), "Adreno")) { + dev_ctx->gpu_family = GPU_FAMILY::ADRENO; + + // Usually device version contains the detailed device name + dev_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str()); + if (dev_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { + dev_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str()); + } + } else if (strstr(dev_ctx->device_name.c_str(), "Intel")) { + dev_ctx->gpu_family = GPU_FAMILY::INTEL; + } else { + GGML_LOG_WARN("ggml_opencl: unsupported GPU '%s'.\n", dev_ctx->device_name.c_str()); + dev_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + return false; + } + + ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform); + + // Check device OpenCL version, OpenCL 2.0 or above is required + ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, dev_ctx->device); + if (opencl_c_version.major < 2) { + GGML_LOG_WARN("ggml_opencl: OpenCL 2.0 or above is required\n"); + return false; + } + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (dev_ctx->gpu_family != GPU_FAMILY::ADRENO) { + GGML_LOG_WARN("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " + "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); + return false; + } +#endif + + size_t ext_str_size; + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); + + char *ext_buffer = (char *)alloca(ext_str_size + 1); + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); + ext_buffer[ext_str_size] = '\0'; + + // Check if ext_buffer contains cl_khr_fp16 + bool fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; + if (!fp16_support) { + GGML_LOG_WARN("ggml_opencl: device does not support FP16\n"); + return false; + } + + // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes + // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x) + if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") == NULL && + strstr(ext_buffer, "cl_intel_subgroups") == NULL) { + GGML_LOG_WARN("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " + "(note that subgroups is an optional feature in OpenCL 3.0)\n"); + return false; + } + + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(size_t), &dev_ctx->global_mem_size, NULL); + return true; +} + // Initialize device if it is supported (returns nullptr if it is not). -static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { +static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { GGML_ASSERT(dev); GGML_ASSERT(dev->context); @@ -3600,33 +3720,12 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { // when the associated device is initialized backend_ctx->ref_count = 0; - if (strstr(dev_ctx->device_name.c_str(), "Adreno") || - strstr(dev_ctx->device_name.c_str(), "Qualcomm") || - strstr(dev_ctx->device_version.c_str(), "Adreno")) { - backend_ctx->gpu_family = GPU_FAMILY::ADRENO; - // Usually device version contains the detailed device name - backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str()); - if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { - backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str()); - } - + backend_ctx->gpu_family = dev_ctx->gpu_family; + backend_ctx->adreno_gen = dev_ctx->adreno_gen; + if (backend_ctx->gpu_family == GPU_FAMILY::ADRENO) { // Use wave size of 64 for all Adreno GPUs. backend_ctx->adreno_wave_size = 64; - } else if (strstr(dev_ctx->device_name.c_str(), "Intel")) { - backend_ctx->gpu_family = GPU_FAMILY::INTEL; - } else { - GGML_LOG_ERROR("Unsupported GPU: %s\n", dev_ctx->device_name.c_str()); - backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; - return nullptr; - } - -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { - GGML_LOG_ERROR("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " - "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); - return nullptr; } -#endif // Populate backend device name backend_ctx->device_name = dev_ctx->device_name; @@ -3635,13 +3734,10 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { cl_device_id device = backend_ctx->device; ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform); - - // Check device OpenCL version, OpenCL 2.0 or above is required ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, device); - if (opencl_c_version.major < 2) { - GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n"); - return nullptr; - } + + backend_ctx->platform_version = platform_version; + backend_ctx->opencl_c_version = opencl_c_version; // Check driver version size_t driver_version_str_size; @@ -3664,34 +3760,21 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { char *ext_buffer = (char *)alloca(ext_str_size + 1); clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated + // Check if ext_buffer contains cl_khr_fp16 backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); + // check Adreno large buffer support backend_ctx->adreno_has_large_buffer = strstr(ext_buffer, "cl_qcom_large_buffer") != NULL; - // fp16 is required - if (!backend_ctx->fp16_support) { - GGML_LOG_ERROR("ggml_opencl: device does not support FP16\n"); - return nullptr; - } - - // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes - // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x) - if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") == NULL && - strstr(ext_buffer, "cl_intel_subgroups") == NULL) { - GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " - "(note that subgroups is an optional feature in OpenCL 3.0)\n"); - return nullptr; - } - cl_uint base_align_in_bits; CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &base_align_in_bits, NULL)); GGML_ASSERT(base_align_in_bits % 8u == 0); backend_ctx->alignment = base_align_in_bits / 8u; GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment); - clGetDeviceInfo(device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(size_t), &backend_ctx->global_mem_size, NULL); + backend_ctx->global_mem_size = dev_ctx->global_mem_size; GGML_LOG_INFO("ggml_opencl: global mem size: %zu MB\n", backend_ctx->global_mem_size/1024/1024); clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); @@ -3779,8 +3862,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { #endif CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err)); - // Load kernels - load_cl_kernels(backend_ctx.get(), opencl_c_version); + // delay kernel loading until the first buffer is created + // load_cl_kernels(backend_ctx.get()); #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Allocate intermediate buffers and images @@ -3822,22 +3905,9 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { return dev_ctx->backend_ctx; } -static void ggml_cl2_free(ggml_backend_t backend) { +static void ggml_cl_free(ggml_backend_t backend) { ggml_backend_opencl_context * ctx = (ggml_backend_opencl_context *) backend->context; ctx->free(); - - // The CL context is shared by all backends, release it if all backends have been released - bool should_release_opencl = true; - for (auto device : g_ggml_backend_opencl_devices) { - ggml_backend_opencl_device_context * ctx_dev = (ggml_backend_opencl_device_context *) device.context; - if (ctx_dev->backend_ctx->ref_count > 0) { - should_release_opencl = false; - } - } - - if (should_release_opencl) { - CL_CHECK(clReleaseContext(ctx->context)); - } } #ifdef GGML_OPENCL_USE_ADRENO_KERNELS @@ -4421,7 +4491,7 @@ static const char * ggml_backend_opencl_name(ggml_backend_t backend) { } static void ggml_backend_opencl_free(ggml_backend_t backend) { - ggml_cl2_free(backend); + ggml_cl_free(backend); } static void ggml_backend_opencl_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { @@ -4460,14 +4530,17 @@ static void ggml_backend_opencl_synchronize(ggml_backend_t backend) { // enqueued to it won't start until commands in the other devices have // completed. static void sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) { - if (g_ggml_backend_opencl_devices.size() < 2) - return; // No other devices to synchronize with. + if (g_ggml_backend_opencl_devices.size() < 2) { + return; // No other devices to synchronize with. + } std::vector events; events.reserve(g_ggml_backend_opencl_devices.size()); for (ggml_backend_device & backend_dev : g_ggml_backend_opencl_devices) { - auto * other_backend_ctx = ggml_cl2_init(&backend_dev); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) backend_dev.context; + auto * other_backend_ctx = dev_ctx->backend_ctx; + if (backend_ctx != other_backend_ctx) { cl_event ev; CL_CHECK(clEnqueueMarkerWithWaitList(other_backend_ctx->queue, 0, nullptr, &ev)); @@ -4880,6 +4953,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_IM2COL: return true; case GGML_OP_ARGSORT: { + load_cl_kernels_argsort(backend_ctx); + cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32; int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); @@ -4897,6 +4972,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_FLASH_ATTN_EXT: { + load_cl_kernels_flash_attn(backend_ctx); + const ggml_tensor * q = op->src[0]; const ggml_tensor * k = op->src[1]; const ggml_tensor * v = op->src[2]; @@ -4964,7 +5041,7 @@ static ggml_backend_i ggml_backend_opencl_i = { ggml_backend_t ggml_backend_opencl_init(void) { ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_opencl_reg(), 0); - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + ggml_backend_opencl_context *backend_ctx = ggml_cl_init(dev); ggml_backend_t backend = new ggml_backend { /* .guid = */ ggml_backend_opencl_guid(), @@ -5343,15 +5420,13 @@ static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) } static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer->buft->device); - return (void *) (uintptr_t) backend_ctx->alignment; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + return (void *) (uintptr_t) dev_ctx->backend_ctx->alignment; } static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - ggml_cl2_init(buffer->buft->device); - if (tensor->view_src != nullptr) { GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); @@ -5391,7 +5466,8 @@ static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buff } static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; cl_context context = backend_ctx->context; cl_command_queue queue = backend_ctx->queue; @@ -6626,7 +6702,8 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { GGML_ASSERT(tensor->extra); - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context *backend_ctx = dev_ctx->backend_ctx; cl_context context = backend_ctx->context; cl_command_queue queue = backend_ctx->queue; @@ -7470,8 +7547,9 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, } static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - ggml_backend_dev_t dev = buffer->buft->device; - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; + cl_command_queue queue = backend_ctx->queue; ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; @@ -7511,7 +7589,8 @@ static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer } static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) { - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer_type->device); + ggml_backend_opencl_context *backend_ctx = ggml_cl_init(buffer_type->device); + load_cl_kernels(backend_ctx); // clCreateBuffer returns -61 for size 0 size = std::max(size, (size_t)1); @@ -7534,15 +7613,15 @@ static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_b } static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); - return backend_ctx->alignment; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer_type->device->context; + return dev_ctx->backend_ctx->alignment; } static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { static size_t max_size = -1; if (max_size == (size_t)-1) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); - max_size = backend_ctx->max_alloc_size; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer_type->device->context; + max_size = dev_ctx->backend_ctx->max_alloc_size; } return max_size; } @@ -7579,14 +7658,13 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_ static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; - ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *) dev_ctx->backend_ctx; static const size_t opencl_extra_margin = 1024ull*1024ull*1024ull; // OpenCL does not provide reliable currently-free device memory. // Use total/global memory as a best-effort upper bound. // Improved safety: Reduce by a 1GiB extra margin for common --fit - *total = backend_ctx->global_mem_size; + *total = dev_ctx->global_mem_size; *free = *total > opencl_extra_margin ? *total - opencl_extra_margin : 0; } @@ -7610,7 +7688,7 @@ static void ggml_backend_opencl_device_get_props(ggml_backend_dev_t dev, struct } static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, const char * params) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(dev); + ggml_backend_opencl_context * backend_ctx = ggml_cl_init(dev); // Getting a new reference to the backend, increase ref_count backend_ctx->ref_count++; @@ -7647,6 +7725,7 @@ static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_bac } static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + ggml_cl_init(dev); return ggml_opencl_supports_op(dev, op); } @@ -7659,8 +7738,8 @@ static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggm // Check cl_context is the same. clEnqueue* commands may not use // buffers from another cl_context. - ggml_backend_opencl_context * backend_ctx0 = ggml_cl2_init(dev); - ggml_backend_opencl_context * backend_ctx1 = ggml_cl2_init(buft->device); + ggml_backend_opencl_context * backend_ctx0 = ggml_cl_init(dev); + ggml_backend_opencl_context * backend_ctx1 = ggml_cl_init(buft->device); return backend_ctx0->context == backend_ctx1->context; } From 6d1d66de407f67ece396467da59124fe8073bd69 Mon Sep 17 00:00:00 2001 From: Todor Boinovski Date: Wed, 20 May 2026 22:14:13 -0700 Subject: [PATCH 664/831] hexagon: ssm-conv fix for large prompts (llama/23307) * hexagon: remove gathers and better handling of vtcm in ssm-conv * hexagon: relax ssm-conv gating requirements * hexagon: add new prefill ssm-conv backend test * hexagon: remove trailing white space * hex-rope: uninline rope_cache_init, otherwise it breaks after rebaseing with SSM_CONV changes --------- Co-authored-by: Max Krasnyansky --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 7 +- ggml/src/ggml-hexagon/htp/rope-ops.c | 4 +- ggml/src/ggml-hexagon/htp/ssm-conv.c | 388 +++++++++++++++---------- 3 files changed, 246 insertions(+), 153 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 080fb7f47e3..9db99cb0f3a 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2735,9 +2735,10 @@ static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) { return false; } - - // TODO: add support for non-contiguous tensors - if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + if (src0->nb[0] != sizeof(float) || src1->nb[0] != sizeof(float) || dst->nb[0] != sizeof(float)) { + return false; + } + if (src0->nb[1] != src0->ne[0] * sizeof(float) || src1->nb[1] != src1->ne[0] * sizeof(float)) { return false; } diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 9901453e91e..b398e19f06e 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -107,7 +107,7 @@ static inline void rope_yarn_one(float theta, float freq_scale, float * corr_dim cache[i0 + 1] = sinf(theta_final) * mscale_final; } -static void rope_cache_init(const float theta_base, +static __attribute__((noinline)) void rope_cache_init(const float theta_base, const float freq_scale, const float * freq_factors, float * corr_dims, @@ -129,7 +129,7 @@ static void rope_cache_init(const float theta_base, // pos_t/h/w/e: the four position ids for this sequence step (t=time, h=height, w=width, e=extra). // sections[4]: number of head dims assigned to each position component. -static void mrope_cache_init(const float pos_t, +static __attribute__((noinline)) void mrope_cache_init(const float pos_t, const float pos_h, const float pos_w, const float pos_e, diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index a28fd03e978..d574da2e2bc 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -20,55 +20,56 @@ #include "htp-ops.h" #include "hvx-utils.h" -#define htp_ssm_conv_tensors_preamble \ - const struct htp_tensor * restrict src0 = octx->src[0]; \ - const struct htp_tensor * restrict src1 = octx->src[1]; \ - const struct htp_tensor * restrict dst = octx->dst; \ - struct htp_spad * restrict src0_spad = &octx->src0_spad; \ - struct htp_spad * restrict src1_spad = &octx->src1_spad; \ - struct htp_spad * restrict dst_spad = &octx->dst_spad; \ - \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne10 = src1->ne[0]; \ - const uint32_t ne11 = src1->ne[1]; \ - const uint32_t ne12 = src1->ne[2]; \ - const uint32_t ne13 = src1->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb10 = src1->nb[0]; \ - const uint32_t nb11 = src1->nb[1]; \ - const uint32_t nb12 = src1->nb[2]; \ - const uint32_t nb13 = src1->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#define htp_ssm_conv_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + struct htp_spad * restrict src0_spad = &octx->src0_spad; \ + struct htp_spad * restrict src1_spad = &octx->src1_spad; \ + struct htp_spad * restrict dst_spad = &octx->dst_spad; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; struct htp_ssm_conv_context { struct htp_ops_context * octx; uint32_t nrows_per_thread; + uint32_t d_inner_tile; uint64_t t_start; }; -#define htp_ssm_conv_preamble \ +#define htp_ssm_conv_preamble \ struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \ - struct htp_ops_context * octx = scctx->octx; \ - htp_ssm_conv_tensors_preamble; \ - dma_queue * dma_queue = octx->ctx->dma[ith]; + struct htp_ops_context * octx = scctx->octx; \ + htp_ssm_conv_tensors_preamble; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; // Scalar FP32 SSM_CONV implementation static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { @@ -128,118 +129,211 @@ static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *da dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension -static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { - htp_ssm_conv_preamble; - - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - const int nc = src1->ne[0]; // d_conv - const int ncs = src0->ne[0]; // d_conv - 1 + n_t +// In-register 32x32 fp32 transpose using std 5-stage HVX vshuff butterfly. +static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) { + HVX_Vector tmp[32]; - const uint32_t d_conv = src1->ne[0]; - const uint32_t d_inner = src0->ne[1]; - const uint32_t n_t = dst->ne[1]; - const uint32_t n_s = dst->ne[2]; + // Stage 0 (R = -4): pair (2i, 2i+1) for i = 0..15. m -> tmp. + for (int i = 0; i < 16; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[2*i + 1], m[2*i], -4); + tmp[2*i + 0] = Q6_V_lo_W(p); + tmp[2*i + 1] = Q6_V_hi_W(p); + } - const float * src0_data = (const float *) src0->data; - const float * src1_data = (const float *) src1->data; - float * dst_data = (float *) dst->data; + // Stage 1 (R = -8): per block of 4, pair (b+0, b+2) and (b+1, b+3). tmp -> m. + for (int b = 0; b < 32; b += 4) { + HVX_VectorPair p0 = Q6_W_vshuff_VVR(tmp[b + 2], tmp[b + 0], -8); + HVX_VectorPair p1 = Q6_W_vshuff_VVR(tmp[b + 3], tmp[b + 1], -8); + m[b + 0] = Q6_V_lo_W(p0); m[b + 1] = Q6_V_hi_W(p0); + m[b + 2] = Q6_V_lo_W(p1); m[b + 3] = Q6_V_hi_W(p1); + } - // Calculate row range for this thread - const int dr = scctx->nrows_per_thread; - const uint32_t ir0 = dr * ith; - const uint32_t ir1 = MIN(ir0 + dr, d_inner); - const uint32_t ir = ir1 - ir0; + // Stage 2 (R = -16): per block of 8, pair (b+i, b+i+4) for i = 0..3. m -> tmp. + for (int b = 0; b < 32; b += 8) { + for (int i = 0; i < 4; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[b + i + 4], m[b + i], -16); + tmp[b + 2*i + 0] = Q6_V_lo_W(p); + tmp[b + 2*i + 1] = Q6_V_hi_W(p); + } + } - if (ir0 >= ir1) { - return; // No work for this thread + // Stage 3 (R = -32): per block of 16, pair (b+i, b+i+8) for i = 0..7. tmp -> m. + for (int b = 0; b < 32; b += 16) { + for (int i = 0; i < 8; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(tmp[b + i + 8], tmp[b + i], -32); + m[b + 2*i + 0] = Q6_V_lo_W(p); + m[b + 2*i + 1] = Q6_V_hi_W(p); + } } - // src0 and src1 gather offsets - uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 }; - uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 }; + // Stage 4 (R = -64): pair (i, i+16) for i = 0..15. m -> tmp -> m. + for (int i = 0; i < 16; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[i + 16], m[i], -64); + tmp[2 * i + 0] = Q6_V_lo_W(p); + tmp[2 * i + 1] = Q6_V_hi_W(p); + } - for (uint32_t i = 0; i < VLEN_FP32; ++i) { - src0_offsets[i] = i * (ncs) * sizeof(float); - src1_offsets[i] = i * (d_conv) * sizeof(float); + for (int i = 0; i < 32; ++i) { + m[i] = tmp[i]; } +} - const uint32_t src0_gather_len = VLEN * ncs; - const uint32_t src1_gather_len = VLEN * d_conv; +// HVX FP32 SSM_CONV implementation - channel-vectorized HVX kernel with src0/src1 +// transposed into VTCM. +// +// VTCM layouts (per thread): +// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small). +// src0_T : {d_inner_tile, ncs} — staged per d_inner-tile. +// +// d_inner_tile is chosen so that per-thread VTCM stays under the budget. +// Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially. +#define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread + +// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM) +static inline void transpose_src1(const float * src1_data, + uint32_t src1_stride_inner, + uint32_t i1_off, + uint32_t d_inner_per_thread, + uint32_t d_conv, + float * src1_T) { + for (uint32_t i = 0; i < d_inner_per_thread; ++i) { + const float * src_row = src1_data + (i1_off + i) * src1_stride_inner; + for (uint32_t j = 0; j < d_conv; ++j) { + src1_T[j * d_inner_per_thread + i] = src_row[j]; + } + } +} - // gather scratchpads - HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + 0); - HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + VLEN); +// HVX 32x32 src0 transpose: src0 {ncs, d_inner} (DDR) -> src0_T {d_inner_tile, ncs} (VTCM) +static inline void transpose_src0_block(const float * src0_block, + uint32_t ncs, + uint32_t cb_n, + uint32_t d_inner_tile, + float * src0_T_block_dst, + uint32_t cb /* dst column offset */) { + const uint32_t T_TILE = VLEN_FP32; + + HVX_Vector __attribute__((aligned(VLEN))) sub[32]; + + for (uint32_t t0 = 0; t0 < ncs; t0 += T_TILE) { + const uint32_t t_n = MIN(T_TILE, ncs - t0); + + // Load 32 rows (channels) of T_TILE samples; pad missing channels with zeros. + for (uint32_t r = 0; r < cb_n; ++r) { + const float * src_row = src0_block + r * ncs + t0; + if (t_n == T_TILE) { + sub[r] = *(const HVX_UVector *) src_row; + } else { + HVX_Vector v = hvx_vec_splat_f32(0.0f); + hvx_vec_store_u(&v, t_n * sizeof(float), hvx_vec_splat_f32(0.0f)); + + float __attribute__((aligned(VLEN))) tmp[VLEN_FP32] = { 0 }; + for (uint32_t k = 0; k < t_n; ++k) tmp[k] = src_row[k]; + v = *(const HVX_Vector *) tmp; + sub[r] = v; + } + } + for (uint32_t r = cb_n; r < T_TILE; ++r) { + sub[r] = hvx_vec_splat_f32(0.0f); + } - float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]); - float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]); + hvx_transpose_32x32_f32(sub); - uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; - uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + // Store transposed sub-tile to src0_T at offsets (t0 + j) * d_inner_tile + cb. + // Only write the valid t_n rows of the transposed result. + for (uint32_t r = 0; r < t_n; ++r) { + float * dst = src0_T_block_dst + (t0 + r) * d_inner_tile + cb; + if (cb_n == T_TILE) { + *(HVX_UVector *) dst = sub[r]; + } else { + hvx_vec_store_u(dst, cb_n * sizeof(float), sub[r]); + } + } + } +} - // copy src1 workload to VTCM - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir); +static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { + htp_ssm_conv_preamble; - // FARF(HIGH, "ssm-conv-src1-fetch %d: ir0 %u size %u\n", ith, ir0, nb11 * ir); + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); - for (uint32_t i3 = 0; i3 < n_s; ++i3) { - float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2])); + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; + const uint32_t n_s = dst->ne[2]; + const uint32_t ncs = src0->ne[0]; - // copy src0 workload to VTCM - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir); + const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); + const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float); + const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); + const uint32_t dst_stride_token = dst->nb[1] / sizeof(float); + const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float); - // FARF(HIGH, "ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\n", ith, ir0, i3, nb01 * ir); + const uint32_t dr = scctx->nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = MIN(ir0 + dr, d_inner); - dma_queue_flush(dma_queue); + if (ir0 >= ir1) { + return; + } - for (uint32_t i2 = 0; i2 < n_t; ++i2) { - float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2])); + const uint32_t d_inner_per_thread = ir1 - ir0; + const uint32_t d_inner_tile = scctx->d_inner_tile; - const uint32_t nvec = ir / VLEN_FP32; - const uint32_t nloe = ir % VLEN_FP32; - uint32_t i1 = 0; + const float * src0_data = (const float *) src0->data; + const float * src1_data = (const float *) src1->data; + float * dst_data = (float *) dst->data; - for (uint32_t vi1 = 0; vi1 < nvec; vi1++) { - HVX_Vector acc_vec = Q6_V_vsplat_R(0); + // Per-thread VTCM regions. + float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread); + float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread); - for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]); - uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float); - Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets)); - Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + // Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout. + transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T); - HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); - acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); - } + const uint32_t C_TILE = VLEN_FP32; - *(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec); - i1 += VLEN_FP32; - } + for (uint32_t i3 = 0; i3 < n_s; ++i3) { + for (uint32_t tile_off = 0; tile_off < d_inner_per_thread; tile_off += d_inner_tile) { + const uint32_t tile_n = MIN(d_inner_tile, d_inner_per_thread - tile_off); - if (nloe) { - HVX_Vector acc_vec = Q6_V_vsplat_R(0); + // Place src0 chunk into VTCM in {d_inner_tile, ncs} layout. + const float * src0_block = src0_data + i3 * src0_stride_seq + (ir0 + tile_off) * src0_stride_inner; - for (uint32_t i0 = 0; i0 < d_conv; ++i0) { - uint32_t src0_base = (uint32_t) spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0]); - uint32_t src1_base = (uint32_t) spad_src1 + (i0 + i1 * nc) * sizeof(float); - Q6_vgather_ARMVw(src0_vec, src0_base, src0_gather_len, (*(const HVX_Vector *) src0_offsets)); - Q6_vgather_ARMVw(src1_vec, src1_base, src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + for (uint32_t cb = 0; cb < tile_n; cb += C_TILE) { + const uint32_t cb_n = MIN(C_TILE, tile_n - cb); + transpose_src0_block(src0_block + cb * src0_stride_inner, ncs, cb_n, d_inner_tile, src0_T, cb); + } - HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); - acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); + for (uint32_t t = 0; t < n_t; ++t) { + for (uint32_t cb = 0; cb < tile_n; cb += C_TILE) { + const uint32_t cb_n = MIN(C_TILE, tile_n - cb); + + HVX_Vector acc = hvx_vec_splat_f32(0.0f); + for (uint32_t j = 0; j < d_conv; ++j) { + HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb); + HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb); + acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w)); + } + HVX_Vector res = Q6_Vsf_equals_Vqf32(acc); + + float * dst_ptr = dst_data + i3 * dst_stride_seq + t * dst_stride_token + (ir0 + tile_off + cb); + if (cb_n == C_TILE) { + *(HVX_UVector *) dst_ptr = res; + } else { + hvx_vec_store_u(dst_ptr, cb_n * sizeof(float), res); + } } - - hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec)); } } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", - ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) tile=%u * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, d_inner_tile, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } @@ -264,46 +358,44 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) { if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { uint32_t use_hvx = 0; - if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) { - int is_aligned = hex_is_aligned((void *) src0->data, VLEN) && - hex_is_aligned((void *) src1->data, VLEN) && - hex_is_aligned((void *) dst->data, VLEN); - - if (is_aligned) { - use_hvx = 1; - } + if (d_inner >= VLEN_FP32 && n_t >= VLEN_FP32) { + use_hvx = 1; } - if (use_hvx) { - scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; // d_inner chunks per thread - scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); // round up to even + scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; + scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); - octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256); - octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256); - octx->dst_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256); + const uint32_t d_inner_per_thread = scctx.nrows_per_thread; + const uint32_t ncs = src0->ne[0]; - octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads; + const uint32_t src1_T_size = hex_round_up(d_conv * d_inner_per_thread * sizeof(float), 256); + const uint32_t src0_T_max = HTP_SSM_CONV_VTCM_BUDGET > src1_T_size ? HTP_SSM_CONV_VTCM_BUDGET - src1_T_size : 0; - // Compute gather scratchpad size for src0 and src1 - const size_t gather_spad_size = n_threads * VLEN * 2; + uint32_t d_inner_tile = (src0_T_max / sizeof(float)) / ncs; + d_inner_tile -= (d_inner_tile % VLEN_FP32); + if (d_inner_tile == 0) { + FARF(HIGH, "ssm_conv-f32: inner tile rounds to 0 (ncs=%u), falling back to scalar\n", ncs); + use_hvx = 0; + } else { + scctx.d_inner_tile = d_inner_tile; - octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; octx->src0_spad.src = NULL; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; + octx->src0_spad.size_per_thread = hex_round_up(d_inner_tile * ncs * sizeof(float), 256); + octx->src1_spad.size_per_thread = src1_T_size; + octx->dst_spad.size_per_thread = 0; - FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n", - gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, - octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, - octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data); + octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; + octx->dst_spad.size = 0; - const size_t total_spad_size = - gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; - if (total_spad_size > octx->ctx->vtcm_size) { - FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu", total_spad_size, - octx->ctx->vtcm_size); + const size_t total_spad = octx->src0_spad.size + octx->src1_spad.size; + if (total_spad > octx->ctx->vtcm_size) { + FARF(HIGH, "ssm_conv-f32: scratchpad %zu exceeds VTCM %zu, falling back to scalar\n", + total_spad, octx->ctx->vtcm_size); use_hvx = 0; } } From 03da9f17f47e416d88deef27096291d656d7892e Mon Sep 17 00:00:00 2001 From: Matt Corallo <649246+TheBlueMatt@users.noreply.github.com> Date: Thu, 21 May 2026 06:24:40 +0000 Subject: [PATCH 665/831] ggml : Check the right iface method before using the fallback 2d get (llama/23306) Probably no backends implement only one of 2d get/set, but this might be annoying for some future backend developer trying to add 2d get/set. --- ggml/src/ggml-backend.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 4e36909f45e..5c0e5b1b9e2 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -379,7 +379,7 @@ void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf != NULL && "tensor buffer not set"); - if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) { + if (n_copies <= 1 || buf->iface.get_tensor_2d == NULL) { for (size_t i = 0; i < n_copies; i++) { ggml_backend_tensor_get(tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); } From 158d93c8365745395da5d3f914253947d4a39e22 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 21 May 2026 13:34:08 +0300 Subject: [PATCH 666/831] metal : optimize concat kernel and fix set kernel threads (llama/23411) * metal : fix GGML_OP_SET kernel threads * tests : extend test_cpy to support different src/dst shapes Extend test_cpy to support different source and destination tensor shapes for CPY operations (reshaping), where the total number of elements must match. - Renamed ne -> ne_src, added ne_dst parameter (default: use src shape) - Added 50 new reshaping test cases covering 1D<->2D<->3D<->4D conversions - Tests exercise 1024 boundary, small shapes, and large dimensionality changes - Fixed dangling reference bug (storing & to temporary std::array) - Updated all existing test calls with permute/transpose args for compatibility Assisted-by: llama.cpp:local pi * metal : optimize concat kernel with row batching for small widths When ne0 < 256, batch multiple rows into a single threadgroup to improve occupancy. This avoids underutilizing the GPU when processing narrow tensors. - Dispatch nth = min(256, ne0) threads per group - Calculate nrptg (rows per threadgroup) to fill up to 256 threads - Update kernel index calculation to handle the row batching - Add boundary check for i1 >= ne1 Assisted-by: llama.cpp:local pi * tests : clean-up * tests : refactor CPY shape tests to use dimension permutations Replace 75 hardcoded test cases with a loop over permutations of {3, 5, 7, 32} (total elements: 3360). Each src permutation is tested against canonical sorted and reverse dst, skipping identical shapes. Covers F32, F16, and Q4_0 (when both src and dst ne0 == 32). Assisted-by: llama.cpp:local pi --- ggml/src/ggml-metal/ggml-metal-ops.cpp | 19 +++++++++++++++---- ggml/src/ggml-metal/ggml-metal.metal | 6 +++++- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 8506000b6c0..206af227a2c 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -564,9 +564,20 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - const int nth = std::min(1024, ne0); + int nth = std::min(256, ne0); - ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + // when rows are small, we can batch them together in a single threadgroup + int nrptg = 1; + if (nth < 256) { + nrptg = std::min((256 + nth - 1) / nth, ne1); + if (nrptg * nth > 256) { + nrptg = 256 / nth; + } + } + + const int nw0 = (ne1 + nrptg - 1) / nrptg; + + ggml_metal_encoder_dispatch_threadgroups(enc, nw0, ne2, ne3, nth, nrptg, 1); return 1; } @@ -1786,7 +1797,7 @@ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { nk0 = ne10/ggml_blck_size(op->type); } - int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + int nth = std::min(nk0*ne11, 256); // when rows are small, we can batch them together in a single threadgroup int nrptg = 1; @@ -1797,7 +1808,7 @@ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { nrptg = (nth + nk0 - 1)/nk0; nth = nk0; - if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + if (nrptg*nth > 256) { nrptg--; } } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4cf9dbea946..e772664ba91 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -7486,7 +7486,11 @@ kernel void kernel_concat( const int i3 = tgpig.z; const int i2 = tgpig.y; - const int i1 = tgpig.x; + const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y; + + if (i1 >= args.ne1) { + return; + } int o[4] = {0, 0, 0, 0}; o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03)); From c436f1419f8128e6d0f9274bafe804d4b95fad96 Mon Sep 17 00:00:00 2001 From: Chen Yuan Date: Thu, 21 May 2026 10:58:49 -0400 Subject: [PATCH 667/831] fix(flash-attn): replace f32 with kv_type and q_type (llama/23372) --- .../wgsl-shaders/flash_attn_tile.wgsl | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index ae8036b9ac5..4133f0ab564 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -122,9 +122,9 @@ const V_CHUNKS: u32 = HEAD_DIM_V / 4u; const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; -var q_shmem: array; -var kv_shmem: array; -var p_shmem: array; +var q_shmem: array; +var kv_shmem: array; +var p_shmem: array; @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @@ -169,10 +169,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let head = f32(head_idx); let slope = select(1.0, - select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), - pow(params.m0, head + 1.0), - head < params.n_head_log2), - params.max_bias > 0.0); + select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), + pow(params.m0, head + 1.0), + head < params.n_head_log2), + params.max_bias > 0.0); for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { let q_tile_row = elem_idx / HEAD_DIM_QK; @@ -181,7 +181,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; q_shmem[elem_idx] = select( 0.0, - f32(Q[global_q_row_offset + q_col]) * params.scale, + Q_TYPE(Q[global_q_row_offset + q_col]) * params.scale, head_q_row < params.seq_len_q); } @@ -213,10 +213,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; let k4 = K[k_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = f32(k4.x); - kv_shmem[kv_off + 1u] = f32(k4.y); - kv_shmem[kv_off + 2u] = f32(k4.z); - kv_shmem[kv_off + 3u] = f32(k4.w); + kv_shmem[kv_off + 0u] = KV_TYPE(k4.x); + kv_shmem[kv_off + 1u] = KV_TYPE(k4.y); + kv_shmem[kv_off + 2u] = KV_TYPE(k4.z); + kv_shmem[kv_off + 3u] = KV_TYPE(k4.w); } workgroupBarrier(); @@ -233,18 +233,18 @@ fn main(@builtin(workgroup_id) wg_id: vec3, var dot_val = 0.0; for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) { let q_off = q_base + chunk * 4u; - let qv = vec4( + let qv = vec4( q_shmem[q_off + 0u], q_shmem[q_off + 1u], q_shmem[q_off + 2u], q_shmem[q_off + 3u]); let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let kv = vec4( + let kv = vec4( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], kv_shmem[kv_off + 3u]); - dot_val += dot(qv, kv); + dot_val += dot(vec4(qv), vec4(kv)); } #ifdef LOGIT_SOFTCAP dot_val = params.logit_softcap * tanh(dot_val); @@ -271,7 +271,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let kv_local = sg_inv_id + slot * subgroup_size; if (row_active && kv_local < kv_count) { let p = exp(local_scores[slot] - new_max); - p_shmem[subgroup_p_offset + kv_local] = p; + p_shmem[subgroup_p_offset + kv_local] = KV_TYPE(p); local_sum += p; } } @@ -285,10 +285,10 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; let v4 = V[v_vec_index]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = f32(v4.x); - kv_shmem[kv_off + 1u] = f32(v4.y); - kv_shmem[kv_off + 2u] = f32(v4.z); - kv_shmem[kv_off + 3u] = f32(v4.w); + kv_shmem[kv_off + 0u] = KV_TYPE(v4.x); + kv_shmem[kv_off + 1u] = KV_TYPE(v4.y); + kv_shmem[kv_off + 2u] = KV_TYPE(v4.z); + kv_shmem[kv_off + 3u] = KV_TYPE(v4.w); } workgroupBarrier(); @@ -308,12 +308,12 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) { let p = p_shmem[subgroup_p_offset + kv_local]; let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let v4 = vec4( + let v4 = vec4( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], kv_shmem[kv_off + 3u]); - acc += p * v4; + acc += f32(p) * vec4(v4); } out_regs[reg_idx] = acc; } From 8402c36039c1ad14fb137bd6eecc99f2961a5a35 Mon Sep 17 00:00:00 2001 From: Pascal Date: Thu, 21 May 2026 19:39:42 +0200 Subject: [PATCH 668/831] vulkan: fuse snake activation (mul, sin, sqr, mul, add) (llama/22855) * vulkan: fuse snake activation (mul, sin, sqr, mul, add) Add snake.comp shader with F32 / F16 / BF16 pipelines and ggml_vk_snake_dispatch_fused. The matcher recognizes the naive 5 op decomposition emitted by audio decoders (BigVGAN, Vocos) for snake activation y = x + sin(a*x)^2 * inv_b and rewrites it to a single elementwise kernel. test_snake_fuse from the CUDA PR now also compares CPU naive vs Vulkan fused across F32 / F16 / BF16. * vulkan: address jeffbolznv review for fused snake activation Rename T / C to ne0 / ne1 in the shader and push constants to match the standard naming convention used across the Vulkan backend. Tighten ggml_vk_can_fuse_snake: require x and dst to be contiguous (the shader uses idx = i0 + i1 * ne0) and require a / inv_b to be tightly packed on the broadcast dim (the shader reads data_a[i1]). * vulkan: tighten snake fusion type checks for all operands (address jeffbolznv review) * vulkan: reject snake fusion when ne[2] or ne[3] > 1 (address jeffbolznv review) * vulkan: address 0cc4m review for fused snake activation snake.comp is renamed to follow the ggml DATA_A_* / A_TYPE convention. A_TYPE now applies to the activation tensor data_a instead of the broadcast multiplier, and the bindings become data_a (A_TYPE), data_b (float), data_c (float) and data_d (D_TYPE). A header at the top of the shader maps each buffer to its role in y = x + sin(b * x)^2 * c. On the C++ side, ggml_vk_can_fuse_snake reuses the existing snake_pattern constant instead of duplicating the op list, sin_node is extracted as a named local alongside the other chain nodes, and the broadcast operands a and inv_b are now required to be GGML_TYPE_F32 to match the hardcoded float bindings on data_b and data_c (the previous a->type == x->type would silently reject any future BF16 or F16 chain once the supports_op gate for SIN / SQR is lifted). ggml_vk_snake_dispatch_fused gets an explicit GGML_TYPE_F32 case and GGML_ABORT on default in place of the silent f32 fallback, and a stale comment about data_a[i1] / data_inv_b[i1] is refreshed to match the new binding names. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 136 +++++++++++++++++- .../src/ggml-vulkan/vulkan-shaders/snake.comp | 49 +++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 + 3 files changed, 187 insertions(+), 2 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/snake.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d3fb19048d9..aa289220a90 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -499,6 +499,12 @@ static constexpr std::initializer_list topk_moe_late_softmax { GGM GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; +// Snake activation: y = x + sin(a*x)^2 * inv_b. Used by the optimize_graph reorder +// pass so it keeps the chain contiguous and by the dispatcher to detect the fusion. +static constexpr std::initializer_list snake_pattern { GGML_OP_MUL, GGML_OP_SIN, + GGML_OP_SQR, GGML_OP_MUL, + GGML_OP_ADD }; + //node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ] //node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] //node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] @@ -846,6 +852,9 @@ struct vk_device_struct { vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_conv_transpose_1d_f32; + vk_pipeline pipeline_snake_f32; + vk_pipeline pipeline_snake_f16; + vk_pipeline pipeline_snake_bf16; vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; @@ -1475,6 +1484,11 @@ struct vk_op_conv_transpose_1d_push_constants { int32_t s0; }; +struct vk_op_snake_push_constants { + uint32_t ne0; + uint32_t ne1; +}; + struct vk_op_pool2d_push_constants { uint32_t IW; uint32_t IH; uint32_t OW; uint32_t OH; @@ -4845,6 +4859,10 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_f32, "snake_f32", snake_f32_len, snake_f32_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_f16, "snake_f16", snake_f16_len, snake_f16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_bf16, "snake_bf16", snake_bf16_len, snake_bf16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); @@ -12110,6 +12128,45 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p)); } +// Dispatch the fused snake activation: y = x + sin^2(a * x) * inv_b. +// Match the naive mul -> sin -> sqr -> mul -> add chain and run the +// dedicated kernel directly. The pattern is validated by +// ggml_vk_can_fuse_snake before this call. +static void ggml_vk_snake_dispatch_fused(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { + const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0]; + const ggml_tensor * sqr = cgraph->nodes[node_idx + 2]; + const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3]; + ggml_tensor * add = cgraph->nodes[node_idx + 4]; + + // x carries the full activation shape, a is the broadcast operand + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + // mul1 reads sqr and inv_b in either operand order + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + + vk_pipeline pipeline = nullptr; + switch (x->type) { + case GGML_TYPE_F32: pipeline = ctx->device->pipeline_snake_f32; break; + case GGML_TYPE_F16: pipeline = ctx->device->pipeline_snake_f16; break; + case GGML_TYPE_BF16: pipeline = ctx->device->pipeline_snake_bf16; break; + default: GGML_ABORT("unsupported type"); + } + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x); + vk_subbuffer a_buf = ggml_vk_tensor_subbuffer(ctx, a); + vk_subbuffer inv_b_buf = ggml_vk_tensor_subbuffer(ctx, inv_b); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, add); + + vk_op_snake_push_constants pc{}; + pc.ne0 = static_cast(x->ne[0]); + pc.ne1 = static_cast(x->ne[1]); + + std::array elements = { pc.ne0, pc.ne1, 1 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { x_buf, a_buf, inv_b_buf, dst_buf }, pc, elements); +} + static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { uint32_t op = static_cast(dst->op_params[0]); const int32_t k1 = dst->op_params[1]; @@ -13318,7 +13375,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_MUL: - ggml_vk_mul(ctx, compute_ctx, src0, src1, node); + if (ctx->num_additional_fused_ops) { + ggml_vk_snake_dispatch_fused(ctx, compute_ctx, cgraph, node_idx); + } else { + ggml_vk_mul(ctx, compute_ctx, src0, src1, node); + } break; case GGML_OP_DIV: @@ -14691,6 +14752,65 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const return true; } +// Pattern check for the 5-op Snake fusion: mul -> sin -> sqr -> mul -> add. +// Verifies the chain shape, the closure x_in_add == x_in_mul0, and that +// the broadcast operands a and inv_b share a [1, C] layout. +static bool ggml_vk_can_fuse_snake(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { + GGML_UNUSED(ctx); + if (!ggml_can_fuse(cgraph, node_idx, snake_pattern)) { + return false; + } + + const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0]; + const ggml_tensor * sin_node = cgraph->nodes[node_idx + 1]; + const ggml_tensor * sqr = cgraph->nodes[node_idx + 2]; + const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3]; + const ggml_tensor * add = cgraph->nodes[node_idx + 4]; + + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; + + if (x_in_add != x) { + return false; + } + if (x->type != GGML_TYPE_F32 && x->type != GGML_TYPE_F16 && x->type != GGML_TYPE_BF16) { + return false; + } + // Shader bindings: data_a is A_TYPE so it follows x's precision, while + // data_b and data_c are hardcoded float, so the broadcast operands must + // be F32 regardless of x's type. + if (a->type != GGML_TYPE_F32) return false; + if (inv_b->type != GGML_TYPE_F32) return false; + // Chain intermediates and output share x's precision (single A_TYPE / D_TYPE pipeline). + if (mul0->type != x->type) return false; + if (sin_node->type != x->type) return false; + if (sqr->type != x->type) return false; + if (mul1->type != x->type) return false; + if (add->type != x->type) return false; + if (!ggml_are_same_shape(a, inv_b)) { + return false; + } + if (a->ne[0] != 1 || a->ne[1] != x->ne[1]) { + return false; + } + // Dispatch is 2D over (ne0, ne1), so x and add must be 2D and a / inv_b + // must collapse to [1, C, 1, 1]. Higher dims are not handled by the shader. + if (x->ne[2] != 1 || x->ne[3] != 1) return false; + if (add->ne[2] != 1 || add->ne[3] != 1) return false; + if (a->ne[2] != 1 || a->ne[3] != 1) return false; + if (inv_b->ne[2] != 1 || inv_b->ne[3] != 1) return false; + // Shader uses idx = i0 + i1 * ne0 and reads data_b[i1] / data_c[i1], + // so every operand must be contiguous. + if (!ggml_is_contiguous(x) || !ggml_is_contiguous(add) || + !ggml_is_contiguous(a) || !ggml_is_contiguous(inv_b)) { + return false; + } + return true; +} + // Check whether the tensors overlap in memory. // Fusions can potentially overwrite src tensors in ways that are not prevented // by ggml-alloc. If the fusion src is being applied in a way that's elementwise @@ -14998,6 +15118,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg op_srcs_fused_elementwise[0] = false; op_srcs_fused_elementwise[1] = false; op_srcs_fused_elementwise[2] = false; + } else if (ggml_vk_can_fuse_snake(ctx, cgraph, i)) { + ctx->num_additional_fused_ops = 4; + fusion_string = "SNAKE"; + // elementwise=true: snake.comp is safe under exact aliasing because each + // thread reads data_x[idx] into a register before writing data_d[idx] + // with a data dependency on that register. The overlap check still + // rejects partial overlaps (different base or size). + std::fill_n(op_srcs_fused_elementwise, 5, true); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { @@ -15288,6 +15416,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (keep_pattern(topk_moe_late_softmax)) { continue; } + if (keep_pattern(snake_pattern)) { + continue; + } // First, grab the next unused node. current_set.push_back(first_unused); @@ -15310,7 +15441,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (match_pattern(topk_moe_early_softmax_norm, j) || match_pattern(topk_moe_sigmoid_norm_bias, j) || match_pattern(topk_moe_early_softmax, j) || - match_pattern(topk_moe_late_softmax, j)) { + match_pattern(topk_moe_late_softmax, j) || + match_pattern(snake_pattern, j)) { continue; } bool ok = true; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp b/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp new file mode 100644 index 00000000000..8585538cbb0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "types.glsl" + +// Fused snake activation: y = x + sin(b * x)^2 * c +// data_a [ne0, ne1] per element activation x (A_TYPE) +// data_b [1, ne1] per channel multiplier (float) +// data_c [1, ne1] per channel inverse scale (float, precomputed as 1 / freq) +// data_d [ne0, ne1] output y (D_TYPE) +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {float data_b[];}; +layout (binding = 2) readonly buffer C {float data_c[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (push_constant) uniform parameter { + uint32_t ne0; + uint32_t ne1; +} p; + +// Load A_TYPE to float +float load_val(uint32_t idx) { +#if defined(DATA_A_BF16) + return bf16_to_fp32(uint32_t(data_a[idx])); +#else + return float(data_a[idx]); +#endif +} + +// Store float as D_TYPE +void store_val(uint32_t idx, float v) { +#if defined(DATA_D_BF16) + data_d[idx] = D_TYPE(fp32_to_bf16(v)); +#else + data_d[idx] = D_TYPE(v); +#endif +} + +void main() { + const uint32_t i0 = gl_GlobalInvocationID.x; + const uint32_t i1 = gl_GlobalInvocationID.y; + if (i0 >= p.ne0 || i1 >= p.ne1) return; + + const uint32_t idx = i0 + i1 * p.ne0; + const float xi = load_val(idx); + const float s = sin(data_b[i1] * xi); + store_val(idx, xi + s * s * data_c[i1]); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index e3a9d61a558..a1d735150fd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -952,6 +952,10 @@ void process_shaders() { string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("snake_f32", "snake.comp", {{"DATA_A_F32", "1"}, {"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("snake_f16", "snake.comp", {{"DATA_A_F16", "1"}, {"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("snake_bf16", "snake.comp", {{"DATA_A_BF16", "1"}, {"DATA_D_BF16", "1"}, {"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); + string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); From ec183556c6939d30272f928ac281d5390501bc51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Thu, 21 May 2026 23:35:29 +0200 Subject: [PATCH 669/831] CUDA: fix PDL CC check for JIT compilation (llama/23471) --- ggml/src/ggml-cuda/common.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 9c73fe7e6fa..e54ecb29308 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1561,7 +1561,8 @@ static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_ke return env == nullptr || std::atoi(env) != 0; }(); - if (env_pdl_enabled && ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_HOPPER) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + if (env_pdl_enabled && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_HOPPER) { auto pdl_cfg = ggml_cuda_pdl_config(launch_params); CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward(args)... )); From 2d629533a5a4130ee0b91fc29637b8ffae710f6c Mon Sep 17 00:00:00 2001 From: Sachin Sharma Date: Fri, 22 May 2026 16:46:55 +0530 Subject: [PATCH 670/831] ggml-zendnn : add Q8_0 quantization support (llama/23414) * ggml-zendnn : add Q8_0 quantization support * ggml-zendnn : sync with latest ZenDNN * ggml-zendnn : address review comments for Q8_0 --- ggml/src/ggml-zendnn/CMakeLists.txt | 2 +- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 56 ++++++++++++++++++++++------ 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt index f1e4f991fae..e4ba9cfbd0f 100644 --- a/ggml/src/ggml-zendnn/CMakeLists.txt +++ b/ggml/src/ggml-zendnn/CMakeLists.txt @@ -28,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") ExternalProject_Add( zendnn GIT_REPOSITORY https://github.com/amd/ZenDNN.git - GIT_TAG ac9e580d9434b7b98985f2627a7ebfb5eba4bb0d # ZenDNN-2026-WW17 + GIT_TAG 253b94ce0d7e9284c265fefb485714944caff9d3 # ZenDNN-2026-WW19 PREFIX ${ZENDNN_PREFIX} SOURCE_DIR ${ZENDNN_SOURCE_DIR} BINARY_DIR ${ZENDNN_BUILD_DIR} diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 6a83bb6b1ec..6051d082003 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -2,6 +2,10 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" + +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" + #include "zendnnl.hpp" #include @@ -19,6 +23,8 @@ zendnnl::common::data_type_t ggml_to_zendnn_type() { return zendnnl::common::data_type_t::f32; } else if constexpr (std::is_same_v) { return zendnnl::common::data_type_t::bf16; + } else if constexpr (std::is_same_v) { + return zendnnl::common::data_type_t::s8; } else { return zendnnl::common::data_type_t::none; } @@ -48,6 +54,17 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int params.num_threads = ctx->n_threads; zendnnl::lowoha::matmul::matmul_batch_params_t batch_params; + + if constexpr (std::is_same_v) { + params.dtypes.compute = zendnnl::common::data_type_t::s8; + const int64_t num_groups = k / QK8_0; + params.dynamic_quant = true; + params.quant_params.src_scale.buff = nullptr; + params.quant_params.src_scale.dt = zendnnl::common::data_type_t::bf16; + params.quant_params.src_scale.dims = {n, num_groups}; + params.packing.pack_format_b = 1; + } + zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct( 'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major) n, // M: rows of B and C @@ -108,6 +125,14 @@ static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int6 (const ggml_bf16_t *)B, ldb, (float *)C, ldc); return false; + case GGML_TYPE_Q8_0: + if (Btype != GGML_TYPE_F32 || Ctype != GGML_TYPE_F32) + return false; + return ggml_zendnn_matmul( + ctx, m, n, k, + (const block_q8_0 *)A, lda, + (const float *)B, ldb, + (float *)C, ldc); default: return false; // unsupported type } @@ -145,7 +170,9 @@ static void ggml_zendnn_compute_forward_mul_mat( const int64_t r3 = ne13/ne03; void * work_data = ctx->work_data.get(); - if (src1->type != vec_dot_type) { + + // ZenDNN requires FP32 for dynamic quantization, so conversion is skipped + if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) { const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); const size_t nbw2 = nbw1 * ne11; const size_t nbw3 = nbw2 * ne12; @@ -171,7 +198,7 @@ static void ggml_zendnn_compute_forward_mul_mat( for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i12 = 0; i12 < ne12; i12++) { - const void* wdata = src1->type == vec_dot_type ? src1->data : work_data; + const void* wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data; const size_t row_size = ggml_row_size(vec_dot_type, ne10); if (!ggml_zendnn_sgemm(ctx, ne01, // m @@ -184,7 +211,7 @@ static void ggml_zendnn_compute_forward_mul_mat( static_cast(dst->data) + i12*nb2 + i13*nb3, ne01, // ldc src0->type, - vec_dot_type, + src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); } @@ -261,10 +288,15 @@ static void ggml_zendnn_compute_forward_mul_mat_id( const size_t nbw1 = row_size; const size_t nbw2 = nbw1 * ne11; const size_t nbw3 = nbw2 * ne12; - const size_t src1_conv_size = (src1->type != vec_dot_type) ? ne13 * nbw3 : 0; + const size_t src1_conv_size = (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) ? ne13 * nbw3 : 0; + + // For Q8_0, src1 is always F32; the gather buffer must hold F32 rows (ne10*4 bytes), + // not Q8_0-encoded rows (row_size ≈ ne10/32*34 bytes) — they differ by ~4x. + const size_t f32_row_size = (size_t)ne10 * sizeof(float); + const size_t gather_row_size = (src0->type == GGML_TYPE_Q8_0) ? f32_row_size : row_size; // size for MoE gather/scatter buffers - const size_t wdata_cur_size = max_rows * row_size; + const size_t wdata_cur_size = max_rows * gather_row_size; const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01); // allocate single buffer for all needs @@ -279,7 +311,8 @@ static void ggml_zendnn_compute_forward_mul_mat_id( char * wdata_cur = work_data + src1_conv_size; char * dst_cur = wdata_cur + wdata_cur_size; - if (src1->type != vec_dot_type) { + // ZenDNN requires FP32 for dynamic quantization, so conversion is skipped + if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) { GGML_ASSERT(src1->type == GGML_TYPE_F32); #pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static) @@ -294,7 +327,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( } } - const void * wdata = src1->type == vec_dot_type ? src1->data : work_data; + const void * wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data; // process each expert with gather -> gemm -> scatter pattern for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) { @@ -315,9 +348,9 @@ static void ggml_zendnn_compute_forward_mul_mat_id( const int64_t i12 = row_mapping.i2; std::memcpy( - wdata_cur + ir1 * row_size, - (const char *) wdata + (i11 + i12*ne11) * row_size, - row_size + wdata_cur + ir1 * gather_row_size, + (const char *) wdata + (i11 + i12*ne11) * gather_row_size, + gather_row_size ); } @@ -333,7 +366,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( dst_cur, ne01, // ldc src0->type, - vec_dot_type, + src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) { GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); } @@ -577,6 +610,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const switch (weights->type) { case GGML_TYPE_F32: case GGML_TYPE_BF16: + case GGML_TYPE_Q8_0: return true; default: return false; From 6fb7f1af2c66ce3512df5c9edc90caea7f6e2a5f Mon Sep 17 00:00:00 2001 From: Katostrofik Date: Fri, 22 May 2026 08:48:24 -0400 Subject: [PATCH 671/831] SYCL: add BF16 to DMMV kernel path (~4x tg speedup on Intel Arc) (llama/21580) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * SYCL: add BF16 to DMMV kernel path for ~4x token generation speedup BF16 models had no dedicated token generation kernel — they fell through to the generic full-GEMM path, resulting in ~14% memory bandwidth utilization on Intel Arc GPUs. This adds BF16 support to the DMMV (dequantize mul-mat-vec) path, matching the existing F16 implementation. Fixes #20478 * SYCL: fix BF16 DMMV out-of-bounds when ncols % 64 != 0 The qk=1 kernel (used for F16 and BF16) iterates with stride 2*GGML_SYCL_DMMV_X (= 64 on Intel targets where WARP_SIZE=16). When ncols is a multiple of DMMV_X (32) but not of 2*DMMV_X (64), the last warp iteration accesses elements at col >= ncols, producing NaN for the final row and wrong values for interior rows. Fix: tighten can_use_dequantize_mul_mat_vec to require ne[0] % (2*DMMV_X) == 0 for F16/BF16 types, and update the ASSERT in the BF16 launcher to match. Quantized types use block-structured kernels with different access patterns and keep the existing DMMV_X check. Verified: test-backend-ops MUL_MAT passes 913/913 on Intel Arc Pro B70. Previously failing: m=128/129 n=1 k=1056 cases (NaN and ERR > 0.0005). Co-Authored-By: Claude Sonnet 4.6 --------- Co-authored-by: Claude Sonnet 4.6 --- ggml/src/ggml-sycl/dmmv.cpp | 47 +++++++++++++++++++++++++++++++- ggml/src/ggml-sycl/ggml-sycl.cpp | 8 +++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 5577bf73b28..4ae431a962e 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -3,6 +3,13 @@ #include "dequantize.hpp" #include "presets.hpp" +#if defined(__INTEL_LLVM_COMPILER) + #if __has_include() + #include + #define GGML_SYCL_DMMV_HAS_BF16 + #endif +#endif + static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const sycl::half *x = (const sycl::half *)vx; @@ -11,6 +18,16 @@ static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat v.y() = x[ib + iqs + 1]; } +#ifdef GGML_SYCL_DMMV_HAS_BF16 +static void convert_bf16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ + const sycl::ext::oneapi::bfloat16 *x = (const sycl::ext::oneapi::bfloat16 *)vx; + + // automatic bfloat16 -> float type cast if dfloat == float + v.x() = x[ib + iqs + 0]; + v.y() = x[ib + iqs + 1]; +} +#endif + static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const float * x = (const float *) vx; @@ -217,6 +234,28 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, } } +#ifdef GGML_SYCL_DMMV_HAS_BF16 +static void convert_mul_mat_vec_bf16_sycl(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + // The qk=1 kernel iterates with stride 2*GGML_SYCL_DMMV_X, so ncols must be a + // multiple of that — not just GGML_SYCL_DMMV_X — to avoid out-of-bounds reads. + GGML_ASSERT(ncols % (2*GGML_SYCL_DMMV_X) == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec<1, 1, convert_bf16>(vx, y, dst, ncols, + nrows, item_ct1); + }); + } +} +#endif + /* DPCT1110:4: The total declared local variable size in device function dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register @@ -1497,7 +1536,8 @@ void ggml_sycl_op_dequantize_mul_mat_vec( bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; + src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16 || + src0->type == GGML_TYPE_BF16; if (src1_convert_f16) { scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, @@ -1565,6 +1605,11 @@ void ggml_sycl_op_dequantize_mul_mat_vec( case GGML_TYPE_F16: convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); break; +#ifdef GGML_SYCL_DMMV_HAS_BF16 + case GGML_TYPE_BF16: + convert_mul_mat_vec_bf16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; +#endif default: printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type); GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 2ea47f7153a..bba37a6f884 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3455,6 +3455,7 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) { case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_F16: + case GGML_TYPE_BF16: return true; default: return false; @@ -3818,8 +3819,13 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + // The F16/BF16 qk=1 kernel iterates with stride 2*DMMV_X, requiring ne[0] to be + // a multiple of 2*DMMV_X. Quantized types use block-structured kernels that only + // need ne[0] % DMMV_X == 0. + const int64_t dmmv_x_required = (src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F16) ? + 2*GGML_SYCL_DMMV_X : GGML_SYCL_DMMV_X; return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && - src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; + src0->ne[0] % dmmv_x_required == 0 && src1->ne[1] == 1; } static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { From 0416feecfc84203d37c959493622787a419b63b0 Mon Sep 17 00:00:00 2001 From: karavayev <192749314+karavayev@users.noreply.github.com> Date: Fri, 22 May 2026 08:48:56 -0400 Subject: [PATCH 672/831] SYCL : gated_delta_net K>1 (llama/23174) * sycl_gated_delta_net K>1 * editor_config --- ggml/src/ggml-sycl/gated_delta_net.cpp | 91 +++++++++++++++++++------- 1 file changed, 66 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp index ebc587524bf..9c2449aba0c 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.cpp +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -6,7 +6,7 @@ #include -template +template void gated_delta_net_sycl(const float * q, const float * k, const float * v, @@ -28,7 +28,8 @@ void gated_delta_net_sycl(const float * q, int64_t sb3, const sycl::uint3 neqk1_magic, const sycl::uint3 rq3_magic, - float scale) { + float scale, + int K) { auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); const uint32_t h_idx = item_ct1.get_group(2); const uint32_t sequence = item_ct1.get_group(1); @@ -43,9 +44,13 @@ void gated_delta_net_sycl(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - const int64_t state_offset = (sequence * H + h_idx) * S_v * S_v; - state += state_offset; - curr_state += state_offset; + // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. + const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + state += state_out_offset; + curr_state += state_in_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v; @@ -55,9 +60,13 @@ void gated_delta_net_sycl(const float * q, #pragma unroll for (int r = 0; r < rows_per_lane; r++) { const int i = r * warp_size + lane; - s_shard[r] = curr_state[col * S_v + i]; + s_shard[r] = curr_state[i]; } + // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots + // are written; earlier slots are left untouched (caller-owned). + const int shift = (int) n_tokens - K; + for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -131,17 +140,32 @@ void gated_delta_net_sycl(const float * q, } attn_data += S_v * H; - } + // Write state back to global memory + if constexpr (keep_rs_t) { + const int target_slot = t - shift; + if (target_slot >= 0 && target_slot < K) { + float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; #pragma unroll - for (int r = 0; r < rows_per_lane; r++) { - const int i = r * warp_size + lane; - state[col * S_v + i] = s_shard[r]; + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + curr_state[col * S_v + i] = s_shard[r]; + } + } + } + } + + if constexpr (!keep_rs_t) { +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[col * S_v + i] = s_shard[r]; + } } } -template +template static void launch_gated_delta_net(const float * q_d, const float * k_d, const float * v_d, @@ -165,6 +189,7 @@ static void launch_gated_delta_net(const float * q_d, int64_t neqk1, int64_t rq3, float scale, + int K, dpct::queue_ptr stream) { //TODO: Add chunked kernel for even faster pre-fill const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size; @@ -182,9 +207,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 16; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, + gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, - sb3, neqk1_magic, rq3_magic, scale); + sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -193,9 +218,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 32; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, + gated_delta_net_sycl(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, - sb3, neqk1_magic, rq3_magic, scale); + sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -204,9 +229,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 64; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl( + gated_delta_net_sycl( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, - sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -216,9 +241,9 @@ static void launch_gated_delta_net(const float * q_d, constexpr int sv = 128; stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - gated_delta_net_sycl( + gated_delta_net_sycl( q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, - sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale); + sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); }); } break; @@ -290,14 +315,30 @@ void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dpct::queue_ptr stream = ctx.stream(); + // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. + const int K = (int) src_state->ne[1]; + const bool keep_rs = K > 1; + if (kda) { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } else { - launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, - S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, - sb1, sb2, sb3, neqk1, rq3, scale, stream); + if (keep_rs) { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } } } From 21c65a78a3472275db7ccdc40078fa4a3858724a Mon Sep 17 00:00:00 2001 From: Alexey Kopytko Date: Fri, 22 May 2026 21:49:45 +0900 Subject: [PATCH 673/831] sycl : Level Zero detection in ggml_sycl_init (llama/23097) * [SYCL] Centralize Level Zero detection in ggml_sycl_init * use the same wording * get back the warning --- ggml/src/ggml-sycl/common.hpp | 2 ++ ggml/src/ggml-sycl/ggml-sycl.cpp | 26 ++++++++------------------ 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 96bc1c98bd9..6d19538215e 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -238,6 +238,8 @@ struct ggml_sycl_device_info { std::array default_tensor_split = {}; int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0}; + + bool ext_oneapi_level_zero = true; // sycl::backend::ext_oneapi_level_zero used by all enumerated GPU devices }; const ggml_sycl_device_info & ggml_sycl_info(); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index bba37a6f884..46795f43602 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -98,7 +98,7 @@ static ggml_sycl_device_info ggml_sycl_init() { for (int i = 0; i < info.device_count; ++i) { info.devices[i].vmm = 0; dpct::device_info prop; - sycl::device device = dpct::dev_mgr::instance().get_device(i); + auto & device = dpct::dev_mgr::instance().get_device(i); SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( prop, device))); @@ -117,6 +117,12 @@ static ggml_sycl_device_info ggml_sycl_init() { info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units(); info.devices[i].hw_info = get_device_hw_info(&device); + // Only check GPU devices; CPU devices use OpenCL and would otherwise + // disable Level Zero for the GPUs on systems without ONEAPI_DEVICE_SELECTOR set. + if (device.is_gpu() && device.default_queue().get_backend() != sycl::backend::ext_oneapi_level_zero) { + GGML_LOG_WARN("SYCL GPU device %d does not use Level Zero backend, disabling Level Zero memory API\n", i); + info.ext_oneapi_level_zero = false; + } } for (int id = 0; id < info.device_count; ++id) { @@ -230,26 +236,10 @@ static void ggml_check_sycl() try { g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); #ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO - g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", 1); + g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", ggml_sycl_info().ext_oneapi_level_zero); #else g_ggml_sycl_enable_level_zero = 0; #endif - if (g_ggml_sycl_enable_level_zero) { - // Verify all GPU devices use the Level Zero backend before enabling L0 APIs. - // Only check GPU devices; CPU devices use OpenCL and would otherwise - // disable Level Zero for the GPUs on systems without ONEAPI_DEVICE_SELECTOR set. - for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); i++) { - auto & q = dpct::dev_mgr::instance().get_device(i).default_queue(); - if (!q.get_device().is_gpu()) { - continue; - } - if (q.get_backend() != sycl::backend::ext_oneapi_level_zero) { - GGML_LOG_WARN("SYCL GPU device %d does not use Level Zero backend, disabling Level Zero memory API\n", i); - g_ggml_sycl_enable_level_zero = 0; - break; - } - } - } #ifdef SYCL_FLASH_ATTN g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1); From b0c9f9005926301e8a8306bd300e86dbd09db41d Mon Sep 17 00:00:00 2001 From: Alexey Kopytko Date: Fri, 22 May 2026 21:50:17 +0900 Subject: [PATCH 674/831] SYCL: improve MoE prefill throughput (llama/23142) - change `k_copy_src1_to_contiguous` so that uses a precomputed contiguous mapping where all rows "owned" by an expert are in one slice with a know starts and ends - switch the `O(n_as * n_routed_rows)` contraption to a counting sort-based procedure with `O(n_as + n_routed_rows)` complexity --- ggml/src/ggml-sycl/ggml-sycl.cpp | 195 +++++++++++++++++-------------- 1 file changed, 105 insertions(+), 90 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 46795f43602..b3fbb621196 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3919,35 +3919,17 @@ struct mmid_row_mapping { __dpct_inline__ static void k_copy_src1_to_contiguous( const char *__restrict__ src1_original, char *__restrict__ src1_contiguous, - int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping, - const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, + const mmid_row_mapping *__restrict__ row_mapping, int64_t ne11, int64_t ne10, size_t nb11, size_t nb12, - const sycl::nd_item<3> &item_ct1, int &src1_row) { - int32_t iid1 = item_ct1.get_group(2); - int32_t id = item_ct1.get_group(1); - - const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); + const sycl::nd_item<3> &item_ct1) { + const int32_t src1_row = item_ct1.get_group(2); - if (row_id_i != i02) { - return; - } + const int32_t iid1 = row_mapping[src1_row].i2; + const int32_t id = row_mapping[src1_row].i1; const int64_t i11 = id % ne11; const int64_t i12 = iid1; - if (item_ct1.get_local_id(2) == 0) { - src1_row = - dpct::atomic_fetch_add( - cur_src1_row, 1); - row_mapping[src1_row] = {id, iid1}; - } - /* - DPCT1065:194: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. - */ - item_ct1.barrier(); - const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); @@ -4022,6 +4004,47 @@ static bool ggml_sycl_mul_mat_id_mmvq_fused( src1_row_stride, stream); } +// counting sort of the routed rows by expert id (row_id_i, as chosen by the router): +// builds a projection of a memory layout where each expert's slice is contiguous +static void mmid_counting_sort_rows( + const ggml_tensor * ids, const char * ids_host, + int64_t n_ids, int64_t n_as, int64_t n_routed_rows, + std::vector & expert_counts, + std::vector & expert_row_offsets, + std::vector & routed_row_src) { + + // frequencies: how many routed rows each expert "owns" + expert_counts.assign(n_as, 0); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + expert_counts[row_id_i]++; + } + } + + // where each expert's slice starts (row indices) and the previous ends + expert_row_offsets.assign(n_as + 1, 0); + for (int64_t i02 = 0; i02 < n_as; i02++) { + expert_row_offsets[i02 + 1] = expert_row_offsets[i02] + expert_counts[i02]; + } + + std::vector expert_row_next = expert_row_offsets; + routed_row_src.resize(n_routed_rows); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + + // find and validate the next free row for a given expert (row_id_i) + const int64_t routed_row = expert_row_next[row_id_i]++; + GGML_ASSERT(routed_row >= expert_row_offsets[row_id_i]); + GGML_ASSERT(routed_row < expert_row_offsets[row_id_i + 1]); + routed_row_src[routed_row] = {(int32_t) id, (int32_t) iid1}; + } + } +} + static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, ggml_tensor *dst) try { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); @@ -4100,99 +4123,91 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, src1_row.data = src1_contiguous.get(); dst_row.data = dst_contiguous.get(); - for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = 0; - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + // how many "owned" routed rows to pass to each expert + std::vector expert_row_counts; + // where each expert's slice starts and the previous ends (row indices, right-exclusive) + std::vector expert_row_offsets; + // the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous + std::vector routed_row_src; - GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + mmid_counting_sort_rows(ids, ids_host.data(), n_ids, n_as, n_routed_rows, + expert_row_counts, expert_row_offsets, routed_row_src); - if (row_id_i != i02) { - continue; - } + ggml_sycl_pool_alloc dev_row_mapping(ctx.pool(), n_routed_rows); + SYCL_CHECK(CHECK_TRY_ERROR( + stream->memcpy(dev_row_mapping.get(), routed_row_src.data(), n_routed_rows*sizeof(mmid_row_mapping)))); - num_src1_rows++; - } - } + const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; + assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size)); + sycl::range<3> grid_dims(1, 1, n_routed_rows); + stream->submit([&](sycl::handler &cgh) { + char *__restrict src1_contiguous_get = + src1_contiguous.get(); + mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_src1_to_contiguous( + src1_original, src1_contiguous_get, + dev_row_mapping_get, + ne11, ne10, nb11, nb12, + item_ct1); + }); + }); + } + + for (int64_t i02 = 0; i02 < n_as; i02++) { + const int64_t num_src1_rows = expert_row_counts[i02]; if (num_src1_rows == 0) { continue; } - - ggml_sycl_pool_alloc dev_cur_src1_row(ctx.pool(), 1); - ggml_sycl_pool_alloc dev_row_mapping(ctx.pool(), num_src1_rows); - SYCL_CHECK(CHECK_TRY_ERROR( - stream->memset(dev_cur_src1_row.get(), 0, sizeof(int)))); - - const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; - assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); - - { - sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size)); - sycl::range<3> grid_dims(1, n_ids, ids->ne[1]); - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor src1_row_acc(cgh); - - char *__restrict src1_contiguous_get = - src1_contiguous.get(); - int *__restrict dev_cur_src1_row_get = - dev_cur_src1_row.get(); - mmid_row_mapping *__restrict dev_row_mapping_get = - dev_row_mapping.get(); - size_t ids_nb_ct6 = ids->nb[1]; - size_t ids_nb_ct7 = ids->nb[0]; - - cgh.parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_copy_src1_to_contiguous( - src1_original, src1_contiguous_get, - dev_cur_src1_row_get, - dev_row_mapping_get, ids_dev, i02, - ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12, - item_ct1, src1_row_acc); - }); - }); - } + const int64_t expert_row_offset = expert_row_offsets[i02]; src0_row.data = src0_original + i02*nb02; GGML_ASSERT(nb11 == sizeof(float)*ne10); GGML_ASSERT(nb1 == sizeof(float)*ne0); + src1_row.data = src1_contiguous.get() + expert_row_offset*nb11; src1_row.ne[1] = num_src1_rows; src1_row.nb[1] = nb11; src1_row.nb[2] = num_src1_rows*nb11; src1_row.nb[3] = num_src1_rows*nb11; + dst_row.data = dst_contiguous.get() + expert_row_offset*nb1; dst_row.ne[1] = num_src1_rows; dst_row.nb[1] = nb1; dst_row.nb[2] = num_src1_rows*nb1; dst_row.nb[3] = num_src1_rows*nb1; ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + } - { - sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size)); - sycl::range<3> grid_dims(1, 1, num_src1_rows); - stream->submit([&](sycl::handler &cgh) { - const char *__restrict dst_contiguous_get = - dst_contiguous.get(); - const mmid_row_mapping *__restrict dev_row_mapping_get = - dev_row_mapping.get(); - - cgh.parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_copy_dst_from_contiguous(dst_original, - dst_contiguous_get, - dev_row_mapping_get, - ne0, nb1, nb2, item_ct1); - }); - }); - } + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size)); + sycl::range<3> grid_dims(1, 1, n_routed_rows); + stream->submit([&](sycl::handler &cgh) { + const char *__restrict dst_contiguous_get = + dst_contiguous.get(); + const mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_dst_from_contiguous(dst_original, + dst_contiguous_get, + dev_row_mapping_get, + ne0, nb1, nb2, item_ct1); + }); + }); } } } From aefffa1fa5a7aadeb138106645bb5b8e6d8370d3 Mon Sep 17 00:00:00 2001 From: Shawn Gu Date: Fri, 22 May 2026 17:08:41 -0700 Subject: [PATCH 675/831] opencl: generalize Adreno MoE kernels on M (llama/23449) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 18 +++--- ggml/src/ggml-opencl/kernels/cvt.cl | 64 +++++++++++++++++++ .../kernels/gemm_moe_mxfp4_f32_ns.cl | 6 +- .../kernels/gemm_moe_q4_0_f32_ns.cl | 6 +- .../kernels/gemm_moe_q4_1_f32_ns.cl | 6 +- .../kernels/gemm_moe_q4_k_f32_ns.cl | 6 +- .../kernels/gemm_moe_q5_0_f32_ns.cl | 6 +- .../kernels/gemm_moe_q5_1_f32_ns.cl | 6 +- .../kernels/gemm_moe_q5_k_f32_ns.cl | 6 +- .../kernels/gemm_moe_q6_k_f32_ns.cl | 6 +- .../kernels/gemv_moe_mxfp4_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q4_0_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q4_1_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q4_k_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q5_0_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q5_1_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q5_k_f32_ns.cl | 4 ++ .../kernels/gemv_moe_q6_k_f32_ns.cl | 4 ++ 18 files changed, 145 insertions(+), 17 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 5fc46f789ec..ea0b44feea2 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4693,7 +4693,7 @@ inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, c inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { GGML_UNUSED(backend_ctx); int ne01 = tensor->ne[1]; - return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); + return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 32 == 0); } inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { @@ -14297,7 +14297,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -14513,7 +14513,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -14689,7 +14689,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -14865,7 +14865,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -15118,7 +15118,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -15291,7 +15291,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -15469,7 +15469,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; @@ -15644,7 +15644,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast(ne01); + global_size[0] = static_cast(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast(ne20); local_size[1] = 4; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 312366984b6..c25eabdd72b 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -220,6 +220,10 @@ kernel void kernel_convert_block_q4_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_0; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -263,6 +267,10 @@ kernel void kernel_restore_block_q4_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_0; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -401,6 +409,10 @@ kernel void kernel_convert_block_q4_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_1; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -446,6 +458,10 @@ kernel void kernel_restore_block_q4_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK4_1; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_dm_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -491,6 +507,10 @@ kernel void kernel_convert_block_q5_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_0; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -536,6 +556,10 @@ kernel void kernel_restore_block_q5_0_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_0; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -583,6 +607,10 @@ kernel void kernel_convert_block_q5_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_1; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -630,6 +658,10 @@ kernel void kernel_restore_block_q5_1_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK5_1; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -679,6 +711,10 @@ kernel void kernel_convert_block_q4_k_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -732,6 +768,10 @@ kernel void kernel_restore_block_q4_k_trans4_ns( uint i01 = get_global_id(0); // row index uint i02 = get_global_id(2); // batch index + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -784,6 +824,10 @@ kernel void kernel_convert_block_q5_k_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -850,6 +894,10 @@ kernel void kernel_restore_block_q5_k_trans4_ns( uint i01 = get_global_id(0); // row index uint i02 = get_global_id(2); // batch index + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -916,6 +964,10 @@ kernel void kernel_convert_block_q6_k_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; @@ -993,6 +1045,10 @@ kernel void kernel_restore_block_q6_k_trans4_ns( uint i01 = get_global_id(0); // row index uint i02 = get_global_id(2); // batch index + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_K; uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -1147,6 +1203,10 @@ kernel void kernel_convert_block_mxfp4_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_MXFP4; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; @@ -1190,6 +1250,10 @@ kernel void kernel_restore_block_mxfp4_trans4_ns( uint i01 = get_global_id(0); uint i02 = get_global_id(2); + if (i01 >= ne01) { + return; + } + uint ne00_blk = ne00 / QK_MXFP4; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl index e404f392bdd..02cdbdd9fb1 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl @@ -163,7 +163,7 @@ kernel void kernel_gemm_moe_mxfp4_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -248,6 +248,10 @@ kernel void kernel_gemm_moe_mxfp4_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl index 02290c17eb1..d403ed0cab1 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl @@ -115,7 +115,7 @@ kernel void kernel_gemm_moe_q4_0_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -198,6 +198,10 @@ kernel void kernel_gemm_moe_q4_0_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl index e2574ae0187..b2bddf3f73a 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl @@ -116,7 +116,7 @@ kernel void kernel_gemm_moe_q4_1_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -200,6 +200,10 @@ kernel void kernel_gemm_moe_q4_1_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl index 9d24aff6a20..ab8228d18ca 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl @@ -133,7 +133,7 @@ kernel void kernel_gemm_moe_q4_k_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -225,6 +225,10 @@ kernel void kernel_gemm_moe_q4_k_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load post router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl index 3524cb1bdbd..d1a35d58bb2 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl @@ -116,7 +116,7 @@ kernel void kernel_gemm_moe_q5_0_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -202,6 +202,10 @@ kernel void kernel_gemm_moe_q5_0_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl index 5fc2a523234..90d345ecf51 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl @@ -117,7 +117,7 @@ kernel void kernel_gemm_moe_q5_1_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -204,6 +204,10 @@ kernel void kernel_gemm_moe_q5_1_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load poster router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl index 808a0c7db6a..13c26f6f3b6 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl @@ -134,7 +134,7 @@ kernel void kernel_gemm_moe_q5_k_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -230,6 +230,10 @@ kernel void kernel_gemm_moe_q5_k_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load post router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl index a040335adfa..85ccebec78c 100644 --- a/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl @@ -117,7 +117,7 @@ kernel void kernel_gemm_moe_q6_k_f32_ns( uint block_id_n = get_global_id(2); // n_tile // Boundary check - if (((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) || (block_id_n >= total_tiles[0])) { + if (block_id_n >= total_tiles[0]) { return; } @@ -209,6 +209,10 @@ kernel void kernel_gemm_moe_q6_k_f32_ns( dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); } + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + // Load post router and share in LM __local uint out_idx[TILESIZE_N]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl index e4b44c1a56a..75129e20c65 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl @@ -82,6 +82,10 @@ __kernel void kernel_gemv_moe_mxfp4_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl index 6f4d3f53216..2d28db63ec5 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl @@ -37,6 +37,10 @@ __kernel void kernel_gemv_moe_q4_0_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl index 3739a215705..b98bdc0f12e 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl @@ -38,6 +38,10 @@ __kernel void kernel_gemv_moe_q4_1_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl index 13d79f2526f..12464e9826e 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl @@ -54,6 +54,10 @@ __kernel void kernel_gemv_moe_q4_k_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl index 938054cf982..b43613638a8 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl @@ -38,6 +38,10 @@ __kernel void kernel_gemv_moe_q5_0_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl index f33a4ef2757..7a666006e68 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl @@ -39,6 +39,10 @@ __kernel void kernel_gemv_moe_q5_1_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl index f128d44340a..7d868d7abd9 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl @@ -55,6 +55,10 @@ __kernel void kernel_gemv_moe_q5_k_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl index 526e609dc3a..c166bad5ba5 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl @@ -38,6 +38,10 @@ __kernel void kernel_gemv_moe_q6_k_f32_ns( uint sgid = get_local_id(1); uint slid = get_sub_group_local_id(); + if (i01 >= ne01) { + return; + } + uint i11 = i20 % ne11; uint expert_id = src2[i20]; From 6b85d73b33f497e2c37daf2ff0da15a0d7fdfcaf Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Sat, 23 May 2026 02:44:46 -0500 Subject: [PATCH 676/831] vulkan: fix windows find_package of SPIRV-Headers (llama/23215) * vulkan: fix windows find_package of SPIRV-Headers * not windows-only --- ggml/src/ggml-vulkan/CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 6dbcea065b3..65785ae4566 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -8,7 +8,10 @@ endif() find_package(Vulkan COMPONENTS glslc REQUIRED) -find_package(SPIRV-Headers REQUIRED) +if (DEFINED ENV{VULKAN_SDK}) + list(APPEND CMAKE_PREFIX_PATH "$ENV{VULKAN_SDK}") +endif() +find_package(SPIRV-Headers CONFIG REQUIRED) if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") # Parallel build object files From 511f8602b14fc0c5b32b5163f4f820aac8e6ee8e Mon Sep 17 00:00:00 2001 From: dskwe Date: Sat, 23 May 2026 18:49:24 +0800 Subject: [PATCH 677/831] ggml : Check the right iface method before using the fallback 2d get (llama/23514) --- ggml/src/ggml-backend.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 5c0e5b1b9e2..87615921c09 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -306,7 +306,7 @@ void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_ GGML_ASSERT(tensor); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - if (n_copies <= 1 || backend->iface.set_tensor_2d_async == NULL) { + if (n_copies <= 1 || backend->iface.get_tensor_2d_async == NULL) { for (size_t i = 0; i < n_copies; i++) { ggml_backend_tensor_get_async(backend, tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); } @@ -317,7 +317,7 @@ void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_ } GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); - GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); backend->iface.get_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data); } From b84d03487c0c1cda50cfc903d281ae4a06675ae4 Mon Sep 17 00:00:00 2001 From: Yiwei Shao <44545837+njsyw1997@users.noreply.github.com> Date: Sat, 23 May 2026 19:56:59 -0700 Subject: [PATCH 678/831] hexagon: apply repl optimization in flash attn softmax as #22993 (llama/23455) --- ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index 4a4ff0b331d..9e1b778b01f 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -852,9 +852,10 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { v_s_rowmax1 = hvx_vec_reduce_max_f16(v_s_rowmax1); // Splat m_prev[r], m_prev[r+1] from the per-row accumulator. - // vror brings the target lane to lane 0, then extract + re-splat. - HVX_Vector v_m_prev0 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2))); - HVX_Vector v_m_prev1 = hvx_vec_splat_f16(hvx_vec_get_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2))); + // vror brings the target lane to lane 0, then vdelta replicates it + // across all lanes — stays in the vector domain (no store/reload). + HVX_Vector v_m_prev0 = hvx_vec_repl_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2)); + HVX_Vector v_m_prev1 = hvx_vec_repl_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2)); // HVX max — both operands are splats, so result is splat of m_new. HVX_Vector v_dup_m0 = Q6_Vhf_vmax_VhfVhf(v_m_prev0, v_s_rowmax0); From 1435988ab360285c975bbc812f2b85d0f9a5b5dd Mon Sep 17 00:00:00 2001 From: shaofeiqi Date: Sat, 23 May 2026 23:11:43 -0700 Subject: [PATCH 679/831] opencl: batch profiling to improve speed and prevent memory leaks (llama/23495) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 36 +++++++++++++++++++++------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index ea0b44feea2..42286435bc6 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -661,11 +661,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_iq4_nl_f32_l4_lm; std::vector profiling_info; + std::vector profiling_results; - void write_profiling_info() { - FILE * fperf = fopen("cl_profiling.csv", "w"); - if (!fperf) { - GGML_LOG_ERROR("Failed to open cl_profiling.csv\n"); + void flush_profiling_batch() { + if (profiling_info.empty()) { return; } @@ -689,6 +688,7 @@ struct ggml_backend_opencl_context { CL_CHECK(clGetEventProfilingInfo( info.evt, CL_PROFILING_COMMAND_COMPLETE, sizeof(cl_ulong), &cmd_complete, NULL)); CL_CHECK(clReleaseEvent(info.evt)); + info.evt = nullptr; char kernel_name[512]; CL_CHECK(clGetKernelInfo(info.kernel, CL_KERNEL_FUNCTION_NAME, @@ -706,10 +706,26 @@ struct ggml_backend_opencl_context { info.cmd_complete_duration_ns = cmd_complete - cmd_end; info.cmd_total_duration_ns = cmd_complete - cmd_queued; } + profiling_results.insert(profiling_results.end(), + std::make_move_iterator(profiling_info.begin()), + std::make_move_iterator(profiling_info.end())); + profiling_info.clear(); + } + + void write_profiling_info() { + if (profiling_results.empty()) { + return; + } // Dump a csv + FILE * fperf = fopen("cl_profiling.csv", "w"); + if (!fperf) { + GGML_LOG_ERROR("Failed to open cl_profiling.csv\n"); + return; + } + fprintf(fperf, "op name, kernel name, exec duration (ms), global size, local size, output size\n"); - for (const ProfilingInfo & info : profiling_info) { + for (const ProfilingInfo & info : profiling_results) { fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n", info.op_name.c_str(), info.kernel_name.c_str(), info.cmd_duration_ns/1.e6f, @@ -720,14 +736,14 @@ struct ggml_backend_opencl_context { fclose(fperf); // Dump a simple chrome trace - FILE* ftrace = fopen("cl_trace.json", "w"); + FILE * ftrace = fopen("cl_trace.json", "w"); if (!ftrace) { GGML_LOG_ERROR("Failed to open cl_trace.json\n"); return; } fprintf(ftrace, "[\n"); - for (const ProfilingInfo & info : profiling_info) { + for (const ProfilingInfo & info : profiling_results) { fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n", info.kernel_name.c_str(), info.cmd_queued/1000); fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n", @@ -738,6 +754,7 @@ struct ggml_backend_opencl_context { fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Device\"},\n", info.kernel_name.c_str(), info.cmd_end/1000); } + fprintf(ftrace, "]\n"); fclose(ftrace); } @@ -758,6 +775,9 @@ struct ggml_backend_opencl_context { profiling_info.emplace_back(); populateProfilingInfo(profiling_info.back(), evt, kernel, work_dim, global_work_size, local_work_size, tensor); + if (profiling_info.size() >= 2048) { + flush_profiling_batch(); + } #else GGML_UNUSED(tensor); CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, work_dim, NULL, global_work_size, local_work_size, 0, NULL, NULL)); @@ -804,7 +824,7 @@ struct ggml_backend_opencl_context { if (ref_count == 0) { #ifdef GGML_OPENCL_PROFILING write_profiling_info(); - profiling_info.clear(); + profiling_results.clear(); #endif } } From 3306af62b1e3ad2a32f461700f7cccfd531fe08f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 24 May 2026 08:19:33 +0200 Subject: [PATCH 680/831] TP: fix entirely zero-sized slices per device (llama/23525) --- ggml/include/ggml-alloc.h | 1 + ggml/src/ggml-backend-meta.cpp | 36 ++++++++++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h index 78aa059dde3..a7926a21a9a 100644 --- a/ggml/include/ggml-alloc.h +++ b/ggml/include/ggml-alloc.h @@ -76,6 +76,7 @@ GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_i // Utils // Create a buffer and allocate all the tensors in a ggml_context // ggml_backend_alloc_ctx_tensors_from_buft_size returns the size of the buffer that would be allocated by ggml_backend_alloc_ctx_tensors_from_buft +// ggml_backend_alloc_ctx_tensors_from_buft returns NULL on failure or if all tensors in ctx are already allocated or zero-sized GGML_API size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend); diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index df0f405ed9f..5f9ae9c1bc5 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -1275,6 +1275,9 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg for (size_t j = 0; j < n_bufs; j++) { ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } const size_t simple_offset = i_start * chunk_size_j; ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; @@ -1382,6 +1385,9 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co for (size_t j = 0; j < n_bufs; j++){ const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } const size_t simple_offset = i_start * chunk_size_j; ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; @@ -1445,6 +1451,7 @@ static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_bac buf_ctx->buf_configs.reserve(n_simple_bufts); for (size_t i = 0; i < n_simple_bufts; i++) { ggml_backend_buffer_t simple_buf = ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size); + GGML_ASSERT(simple_buf != nullptr); max_size = std::max(max_size, ggml_backend_buffer_get_size(simple_buf)); buf_ctx->buf_configs.emplace_back(ggml_init(params), simple_buf); } @@ -1474,8 +1481,27 @@ struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struc t->data = (void *) 0x2000000000000000; // FIXME } for (size_t i = 0; i < n_simple_bufts; i++) { - meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft( - meta_buf_ctx->buf_configs[i].ctx, ggml_backend_meta_buft_simple_buft(buft, i)); + ggml_context * ctx = meta_buf_ctx->buf_configs[i].ctx; + ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, i); + + // If a ggml_context only has zero-sized tensors, ggml_backend_alloc_ctx_tensors_from_buft returns NULL. + // For those edge cases, allocate a dummy buffer instead. + bool any_nonzero_slice = false; + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + if (ggml_nelements(t) != 0) { + any_nonzero_slice = true; + break; + } + } + if (any_nonzero_slice) { + meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft); + } else { + meta_buf_ctx->buf_configs[i].buf = ggml_backend_buft_alloc_buffer(simple_buft, 0); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + t->buffer = meta_buf_ctx->buf_configs[i].buf; + } + } + GGML_ASSERT(meta_buf_ctx->buf_configs[i].buf != nullptr); meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->buf_configs[i].buf)); } return meta_buf; @@ -1605,6 +1631,9 @@ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tens ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; @@ -1646,6 +1675,9 @@ static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggm ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); offset_j += chunk_size_j; From a369b3949c2f4f624bb8a0324cea870661581194 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Mon, 25 May 2026 02:15:46 -0500 Subject: [PATCH 681/831] ggml : Parallelize quant LUT init (llama/23595) - Use OpenMP to parallelize iq2xs_init_impl and iq3xs_init_impl. - Move the OpenMP detection from ggml-cpu to ggml-base. - Update OpenMP dependencies in ggml-config.cmake.in. --- ggml/cmake/ggml-config.cmake.in | 14 +- ggml/src/CMakeLists.txt | 17 ++ ggml/src/ggml-cpu/CMakeLists.txt | 14 +- ggml/src/ggml-quants.c | 328 ++++++++++++++++++++----------- 4 files changed, 246 insertions(+), 127 deletions(-) diff --git a/ggml/cmake/ggml-config.cmake.in b/ggml/cmake/ggml-config.cmake.in index 91c9d5cd343..23a3066f56d 100644 --- a/ggml/cmake/ggml-config.cmake.in +++ b/ggml/cmake/ggml-config.cmake.in @@ -6,6 +6,7 @@ include(CMakeFindDependencyMacro) find_dependency(Threads) if (NOT GGML_SHARED_LIB) + set(GGML_BASE_INTERFACE_LINK_LIBRARIES "") set(GGML_CPU_INTERFACE_LINK_LIBRARIES "") set(GGML_CPU_INTERFACE_LINK_OPTIONS "") @@ -20,7 +21,15 @@ if (NOT GGML_SHARED_LIB) if (GGML_OPENMP_ENABLED) find_dependency(OpenMP) - list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX) + set(GGML_OPENMP_INTERFACE_LINK_LIBRARIES "") + if (TARGET OpenMP::OpenMP_C) + list(APPEND GGML_OPENMP_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C) + endif() + if (TARGET OpenMP::OpenMP_CXX) + list(APPEND GGML_OPENMP_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_CXX) + endif() + list(APPEND GGML_BASE_INTERFACE_LINK_LIBRARIES ${GGML_OPENMP_INTERFACE_LINK_LIBRARIES}) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${GGML_OPENMP_INTERFACE_LINK_LIBRARIES}) endif() if (GGML_CPU_HBM) @@ -122,7 +131,8 @@ if(NOT TARGET ggml::ggml) add_library(ggml::ggml-base UNKNOWN IMPORTED) set_target_properties(ggml::ggml-base PROPERTIES - IMPORTED_LOCATION "${GGML_BASE_LIBRARY}") + IMPORTED_LOCATION "${GGML_BASE_LIBRARY}" + INTERFACE_LINK_LIBRARIES "${GGML_BASE_INTERFACE_LINK_LIBRARIES}") set(_ggml_all_targets "") if (NOT GGML_BACKEND_DL) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 3e48860bfc8..c26c3f1470d 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -222,6 +222,23 @@ if (GGML_SCHED_NO_REALLOC) target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC) endif() +if (GGML_OPENMP) + find_package(OpenMP) + if (OpenMP_FOUND) + set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "") + else() + set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") + message(WARNING "OpenMP not found") + endif() +else() + set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") +endif() + +if (GGML_OPENMP_ENABLED) + target_compile_definitions(ggml-base PRIVATE GGML_USE_OPENMP) + target_link_libraries(ggml-base PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) +endif() + add_library(ggml ggml-backend-dl.cpp ggml-backend-reg.cpp) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index f3eccff7d72..8c735a045b3 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -72,17 +72,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() endif() - if (GGML_OPENMP) - find_package(OpenMP) - if (OpenMP_FOUND) - set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "") - target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP) - - target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) - else() - set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") - message(WARNING "OpenMP not found") - endif() + if (GGML_OPENMP_ENABLED) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP) + target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) endif() if (GGML_LLAMAFILE) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 15443aa554a..15d231f70c0 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -13,6 +13,10 @@ #include // for qsort #include // for GGML_ASSERT +#ifdef GGML_USE_OPENMP +#include +#endif + #define GROUP_MAX_EPS 1e-15f #define GROUP_MAX_EPS_IQ3_XXS 1e-8f #define GROUP_MAX_EPS_IQ2_S 1e-8f @@ -3064,70 +3068,121 @@ void iq2xs_init_impl(enum ggml_type type) { } kmap_q2xs[index] = i; } - int8_t pos[8]; - int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + // The neighbour search runs in three passes: + // 1. Parallel: for each i, qsort and count its neighbours into n_per_i, + // and reduce the totals (num_neighbors, num_not_in_map). + // 2. Serial: prefix-sum n_per_i into offsets[], so each i has a + // pre-assigned slice of kneighbors_q2xs to write into. + // 3. Parallel: redo the qsort and write each i's neighbour list at + // offsets[i]. + int * n_per_i = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(n_per_i); int num_neighbors = 0, num_not_in_map = 0; - for (int i = 0; i < kmap_size; ++i) { - if (kmap_q2xs[i] >= 0) continue; - ++num_not_in_map; - for (int k = 0; k < 8; ++k) { - int l = (i >> 2*k) & 0x3; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); - int d2 = 0; - for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); - int n = 0; int d2 = dist2[0]; - int nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - ++n; - } - num_neighbors += n; +#ifdef GGML_USE_OPENMP + #pragma omp parallel reduction(+:num_neighbors,num_not_in_map) +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[8]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) { + n_per_i[i] = 0; + continue; + } + ++num_not_in_map; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + n_per_i[i] = n; + num_neighbors += n; + } + free(dist2); } //printf("%s: %d neighbours in total\n", __func__, num_neighbors); kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); iq2_data[gindex].neighbours = kneighbors_q2xs; + + int * offsets = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(offsets); int counter = 0; for (int i = 0; i < kmap_size; ++i) { - if (kmap_q2xs[i] >= 0) continue; - for (int k = 0; k < 8; ++k) { - int l = (i >> 2*k) & 0x3; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); - int d2 = 0; - for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); - kmap_q2xs[i] = -(counter + 1); - int d2 = dist2[0]; - uint16_t * start = &kneighbors_q2xs[counter++]; - int n = 0, nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - kneighbors_q2xs[counter++] = dist2[2*j+1]; - ++n; - } - *start = n; - } - free(dist2); + if (kmap_q2xs[i] >= 0) { + offsets[i] = -1; + continue; + } + offsets[i] = counter; + counter += 1 + n_per_i[i]; + } + +#ifdef GGML_USE_OPENMP + #pragma omp parallel +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[8]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) continue; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + int local_counter = offsets[i]; + kmap_q2xs[i] = -(local_counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q2xs[local_counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q2xs[local_counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); + } + free(offsets); + free(n_per_i); } void iq2xs_free_impl(enum ggml_type type) { @@ -3663,70 +3718,115 @@ void iq3xs_init_impl(int grid_size) { } kmap_q3xs[index] = i; } - int8_t pos[4]; - int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + // See explanation of parallelism in iq2xs_init_impl + int * n_per_i = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(n_per_i); int num_neighbors = 0, num_not_in_map = 0; - for (int i = 0; i < kmap_size; ++i) { - if (kmap_q3xs[i] >= 0) continue; - ++num_not_in_map; - for (int k = 0; k < 4; ++k) { - int l = (i >> 3*k) & 0x7; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); - int d2 = 0; - for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); - int n = 0; int d2 = dist2[0]; - int nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - ++n; - } - num_neighbors += n; +#ifdef GGML_USE_OPENMP + #pragma omp parallel reduction(+:num_neighbors,num_not_in_map) +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[4]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) { + n_per_i[i] = 0; + continue; + } + ++num_not_in_map; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + n_per_i[i] = n; + num_neighbors += n; + } + free(dist2); } //printf("%s: %d neighbours in total\n", __func__, num_neighbors); kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); iq3_data[gindex].neighbours = kneighbors_q3xs; + + int * offsets = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(offsets); int counter = 0; for (int i = 0; i < kmap_size; ++i) { - if (kmap_q3xs[i] >= 0) continue; - for (int k = 0; k < 4; ++k) { - int l = (i >> 3*k) & 0x7; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); - int d2 = 0; - for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); - kmap_q3xs[i] = -(counter + 1); - int d2 = dist2[0]; - uint16_t * start = &kneighbors_q3xs[counter++]; - int n = 0, nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - kneighbors_q3xs[counter++] = dist2[2*j+1]; - ++n; - } - *start = n; - } - free(dist2); + if (kmap_q3xs[i] >= 0) { + offsets[i] = -1; + continue; + } + offsets[i] = counter; + counter += 1 + n_per_i[i]; + } + +#ifdef GGML_USE_OPENMP + #pragma omp parallel +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[4]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int local_counter = offsets[i]; + kmap_q3xs[i] = -(local_counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q3xs[local_counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q3xs[local_counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); + } + free(offsets); + free(n_per_i); } void iq3xs_free_impl(int grid_size) { From 946d6813b9d999008c22a3c3b11332bac6eece1f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 12:13:21 +0300 Subject: [PATCH 682/831] ggml : bump version to 0.12.1 (ggml/1508) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 4aac5094d1c..03020888f97 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 12) -set(GGML_VERSION_PATCH 0) +set(GGML_VERSION_PATCH 1) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 0a62a579ccdb649a321fd7a04c2f874d8cd5257a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 12:14:40 +0300 Subject: [PATCH 683/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 5a605ba344e..2c680ce9f5d 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -57ea0bc119d722d74594196cc5b494a34dd87be4 +0a37c2167fc5b81830a32d9b1691610180ed86d6 From 865ec171aa83625a388bce0b43f091bb3054f56b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 12:18:31 +0300 Subject: [PATCH 684/831] talk-llama : sync llama.cpp --- examples/talk-llama/llama-arch.cpp | 27 +- examples/talk-llama/llama-arch.h | 1 + examples/talk-llama/llama-chat.cpp | 8 +- examples/talk-llama/llama-chat.h | 2 +- examples/talk-llama/llama-context.cpp | 196 ++++++++- examples/talk-llama/llama-context.h | 9 + examples/talk-llama/llama-cparams.h | 4 + examples/talk-llama/llama-ext.h | 16 + examples/talk-llama/llama-graph.cpp | 43 +- examples/talk-llama/llama-graph.h | 6 +- examples/talk-llama/llama-hparams.cpp | 6 + examples/talk-llama/llama-hparams.h | 2 + .../talk-llama/llama-memory-hybrid-iswa.cpp | 14 +- .../talk-llama/llama-memory-hybrid-iswa.h | 1 + examples/talk-llama/llama-memory-hybrid.cpp | 14 +- examples/talk-llama/llama-memory-hybrid.h | 1 + .../talk-llama/llama-memory-recurrent.cpp | 131 +++++- examples/talk-llama/llama-memory-recurrent.h | 9 + examples/talk-llama/llama-memory.h | 3 + examples/talk-llama/llama-model-loader.cpp | 13 +- examples/talk-llama/llama-model-loader.h | 2 +- examples/talk-llama/llama-model-saver.cpp | 2 + examples/talk-llama/llama-model.cpp | 57 ++- examples/talk-llama/llama-model.h | 21 +- examples/talk-llama/llama-vocab.cpp | 125 +++++- examples/talk-llama/llama.h | 11 +- examples/talk-llama/models/afmoe.cpp | 2 +- examples/talk-llama/models/apertus.cpp | 2 +- examples/talk-llama/models/arcee.cpp | 2 +- examples/talk-llama/models/arctic.cpp | 2 +- examples/talk-llama/models/arwkv7.cpp | 2 +- examples/talk-llama/models/baichuan.cpp | 2 +- examples/talk-llama/models/bailingmoe.cpp | 2 +- examples/talk-llama/models/bailingmoe2.cpp | 2 +- examples/talk-llama/models/bloom.cpp | 2 +- examples/talk-llama/models/chameleon.cpp | 2 +- examples/talk-llama/models/chatglm.cpp | 2 +- examples/talk-llama/models/codeshell.cpp | 2 +- examples/talk-llama/models/cogvlm.cpp | 2 +- examples/talk-llama/models/cohere2.cpp | 2 +- examples/talk-llama/models/command-r.cpp | 2 +- examples/talk-llama/models/dbrx.cpp | 2 +- examples/talk-llama/models/deci.cpp | 2 +- examples/talk-llama/models/deepseek.cpp | 2 +- examples/talk-llama/models/delta-net-base.cpp | 164 +++++++- examples/talk-llama/models/dots1.cpp | 2 +- examples/talk-llama/models/dream.cpp | 2 +- examples/talk-llama/models/ernie4-5-moe.cpp | 2 +- examples/talk-llama/models/ernie4-5.cpp | 2 +- examples/talk-llama/models/exaone-moe.cpp | 2 +- examples/talk-llama/models/exaone.cpp | 2 +- examples/talk-llama/models/exaone4.cpp | 2 +- examples/talk-llama/models/falcon-h1.cpp | 2 +- examples/talk-llama/models/falcon.cpp | 2 +- examples/talk-llama/models/gemma.cpp | 2 +- examples/talk-llama/models/gemma2.cpp | 2 +- examples/talk-llama/models/gemma3.cpp | 2 +- examples/talk-llama/models/gemma3n.cpp | 2 +- examples/talk-llama/models/gemma4.cpp | 2 +- examples/talk-llama/models/glm4-moe.cpp | 2 +- examples/talk-llama/models/glm4.cpp | 2 +- examples/talk-llama/models/gpt2.cpp | 2 +- examples/talk-llama/models/gptneox.cpp | 2 +- examples/talk-llama/models/granite-hybrid.cpp | 2 +- examples/talk-llama/models/granite.cpp | 2 +- examples/talk-llama/models/grok.cpp | 2 +- examples/talk-llama/models/grovemoe.cpp | 2 +- examples/talk-llama/models/hunyuan-moe.cpp | 2 +- examples/talk-llama/models/hunyuan-vl.cpp | 2 +- examples/talk-llama/models/internlm2.cpp | 2 +- examples/talk-llama/models/jais.cpp | 2 +- examples/talk-llama/models/jais2.cpp | 2 +- examples/talk-llama/models/jamba.cpp | 2 +- examples/talk-llama/models/lfm2.cpp | 2 +- examples/talk-llama/models/llada-moe.cpp | 2 +- examples/talk-llama/models/llada.cpp | 2 +- examples/talk-llama/models/llama.cpp | 2 +- examples/talk-llama/models/llama4.cpp | 2 +- examples/talk-llama/models/maincoder.cpp | 2 +- examples/talk-llama/models/mamba.cpp | 2 +- examples/talk-llama/models/mimo2.cpp | 2 +- examples/talk-llama/models/minicpm3.cpp | 2 +- examples/talk-llama/models/minimax-m2.cpp | 2 +- examples/talk-llama/models/mistral3.cpp | 2 +- examples/talk-llama/models/models.h | 33 +- examples/talk-llama/models/mpt.cpp | 2 +- examples/talk-llama/models/nemotron-h.cpp | 2 +- examples/talk-llama/models/nemotron.cpp | 2 +- examples/talk-llama/models/olmo.cpp | 2 +- examples/talk-llama/models/olmo2.cpp | 2 +- examples/talk-llama/models/olmoe.cpp | 2 +- examples/talk-llama/models/openai-moe.cpp | 2 +- examples/talk-llama/models/openelm.cpp | 2 +- examples/talk-llama/models/orion.cpp | 2 +- examples/talk-llama/models/paddleocr.cpp | 2 +- examples/talk-llama/models/pangu-embed.cpp | 2 +- examples/talk-llama/models/phi2.cpp | 2 +- examples/talk-llama/models/phi3.cpp | 2 +- examples/talk-llama/models/plamo.cpp | 2 +- examples/talk-llama/models/plamo2.cpp | 2 +- examples/talk-llama/models/plamo3.cpp | 2 +- examples/talk-llama/models/plm.cpp | 2 +- examples/talk-llama/models/qwen.cpp | 2 +- examples/talk-llama/models/qwen2.cpp | 2 +- examples/talk-llama/models/qwen2moe.cpp | 2 +- examples/talk-llama/models/qwen2vl.cpp | 2 +- examples/talk-llama/models/qwen3.cpp | 2 +- examples/talk-llama/models/qwen35.cpp | 321 +++++++++++---- examples/talk-llama/models/qwen35moe.cpp | 373 ++++++++++++++---- examples/talk-llama/models/qwen3moe.cpp | 2 +- examples/talk-llama/models/qwen3next.cpp | 46 +-- examples/talk-llama/models/qwen3vl.cpp | 2 +- examples/talk-llama/models/qwen3vlmoe.cpp | 2 +- examples/talk-llama/models/refact.cpp | 2 +- examples/talk-llama/models/rnd1.cpp | 2 +- examples/talk-llama/models/rwkv6.cpp | 2 +- examples/talk-llama/models/rwkv6qwen2.cpp | 2 +- examples/talk-llama/models/rwkv7.cpp | 2 +- examples/talk-llama/models/seed-oss.cpp | 2 +- examples/talk-llama/models/smallthinker.cpp | 2 +- examples/talk-llama/models/smollm3.cpp | 2 +- examples/talk-llama/models/stablelm.cpp | 2 +- examples/talk-llama/models/starcoder.cpp | 2 +- examples/talk-llama/models/starcoder2.cpp | 2 +- examples/talk-llama/models/step35.cpp | 2 +- examples/talk-llama/models/t5.cpp | 2 +- .../talk-llama/models/wavtokenizer-dec.cpp | 2 +- examples/talk-llama/models/xverse.cpp | 2 +- examples/talk-llama/unicode.cpp | 133 +++++++ 129 files changed, 1593 insertions(+), 395 deletions(-) diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 59dde99e362..c9eead18aa3 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -757,14 +757,15 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - // NextN/MTP tensors are currently ignored (reserved for future MTP support) - // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the + // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so + // the model loader doesn't fault on the block index. + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // Nemotron 3 Super {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -877,6 +878,16 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { } } +bool llm_arch_supports_rs_rollback(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: + return true; + default: + return false; + } +} + bool llm_arch_supports_sm_tensor(const llm_arch & arch) { switch (arch) { case LLM_ARCH_GROK: diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index e37d548c98e..89cf16cc37c 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -637,3 +637,4 @@ bool llm_arch_is_recurrent (const llm_arch & arch); bool llm_arch_is_hybrid (const llm_arch & arch); bool llm_arch_is_diffusion (const llm_arch & arch); bool llm_arch_supports_sm_tensor(const llm_arch & arch); +bool llm_arch_supports_rs_rollback(const llm_arch & arch); diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index 6554a89b28a..f10397747b0 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -73,7 +73,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE }, { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, - { "hunyuan-ocr", LLM_CHAT_TEMPLATE_HUNYUAN_OCR }, + { "hunyuan-vl", LLM_CHAT_TEMPLATE_HUNYUAN_VL }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, @@ -218,7 +218,7 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) { return LLM_CHAT_TEMPLATE_OPENAI_MOE; } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_begin▁of▁sentence|>")) { - return LLM_CHAT_TEMPLATE_HUNYUAN_OCR; + return LLM_CHAT_TEMPLATE_HUNYUAN_VL; } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { @@ -825,8 +825,8 @@ int32_t llm_chat_apply_template( ss << "<|hy_User|>" << chat[i]->content << "<|hy_Assistant|>"; } } - } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_OCR) { - // tencent/HunyuanOCR + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_VL) { + // tencent/HunyuanOCR & tencent/HunyuanVL ss << "<|hy_begin▁of▁sentence|>"; for (size_t i = 0; i < chat.size(); i++) { std::string role(chat[i]->role); diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index 13f936a946c..ea6540c0be7 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -53,7 +53,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_HUNYUAN_MOE, LLM_CHAT_TEMPLATE_OPENAI_MOE, LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, - LLM_CHAT_TEMPLATE_HUNYUAN_OCR, + LLM_CHAT_TEMPLATE_HUNYUAN_VL, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_GROK_2, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 3d9714ab166..ad36c06667d 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -2,6 +2,7 @@ #include "ggml.h" #include "llama-arch.h" +#include "llama-graph.h" #include "llama-impl.h" #include "llama-batch.h" #include "llama-io.h" @@ -21,6 +22,14 @@ // llama_context // +static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) { + switch (ctx_type) { + case LLAMA_CONTEXT_TYPE_DEFAULT: return LLM_GRAPH_TYPE_DEFAULT; + case LLAMA_CONTEXT_TYPE_MTP : return LLM_GRAPH_TYPE_DECODER_MTP; + } + throw std::runtime_error("Unsupported ctx type"); +} + llama_context::llama_context( const llama_model & model, llama_context_params params) : @@ -42,13 +51,22 @@ llama_context::llama_context( throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); } + cparams.n_rs_seq = params.n_rs_seq; + if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback(model.arch)) { + LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n", + __func__, cparams.n_rs_seq); + cparams.n_rs_seq = 0; + } + cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; - cparams.embeddings = params.embeddings; + cparams.embeddings = params.embeddings; + cparams.embeddings_pre_norm = false; + cparams.embeddings_pre_norm_masked = false; cparams.offload_kqv = params.offload_kqv; cparams.no_perf = params.no_perf; cparams.pooling_type = params.pooling_type; @@ -65,6 +83,8 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + cparams.ctx_type = params.ctx_type; + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. @@ -206,6 +226,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq); if (cparams.n_ctx_seq < hparams.n_ctx_train) { LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", @@ -278,6 +299,7 @@ llama_context::llama_context( /*.type_k =*/ params.type_k, /*.type_v =*/ params.type_v, /*.swa_full =*/ params.swa_full, + /*.ctx_type= */ cparams.ctx_type, }; memory.reset(model.create_memory(params_mem, cparams)); @@ -860,6 +882,42 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +float * llama_context::get_embeddings_pre_norm() { + output_reorder(); + + return embd_pre_norm.data; +} + +float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { + output_reorder(); + + try { + if (embd_pre_norm.data == nullptr) { + throw std::runtime_error("no pre-norm embeddings"); + } + + const uint32_t n_embd = model.hparams.n_embd; + + if (!cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm rows are stored densely, indexed by raw token position. + if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) { + throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd)); + } + return embd_pre_norm.data + (size_t) i * n_embd; + } + + const int64_t j = output_resolve_row(i); + return embd_pre_norm.data + j*n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); @@ -1040,6 +1098,13 @@ void llama_context::set_embeddings(bool value) { //sched_need_reserve = true; } +void llama_context::set_embeddings_pre_norm(bool value, bool masked) { + LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked); + + cparams.embeddings_pre_norm = value; + cparams.embeddings_pre_norm_masked = masked; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1072,6 +1137,19 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); + if (sampler && model.split_mode() == LLAMA_SPLIT_MODE_TENSOR) { + static bool warned = false; + if (!warned) { + LLAMA_LOG_WARN("%s: backend sampling not supported with SPLIT_MODE_TENSOR; using CPU\n", __func__); + warned = true; + } + if (sampling.samplers.count(seq_id) > 0) { + sched_need_reserve = true; + } + sampling.samplers.erase(seq_id); + return false; + } + const bool can_offload = sampler && sampler->iface->backend_init && @@ -1241,7 +1319,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } int llama_context::encode(const llama_batch & batch_inp) { - GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row), + // so accept either present rather than requiring exactly one. + GGML_ASSERT(batch_inp.token || batch_inp.embd); if (batch_inp.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -1312,8 +1392,9 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - auto * t_logits = res->get_logits(); - auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_logits = res->get_logits(); + auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; // extract logits if (logits.data && t_logits) { @@ -1379,6 +1460,16 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float)); + } + // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { //cross.t_embd = t_embd; @@ -1531,7 +1622,9 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::mapget_ubatch(); @@ -1689,7 +1783,8 @@ int llama_context::decode(const llama_batch & batch_inp) { } ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); + + const auto * res = process_ubatch(ubatch, ctx_type_to_graph_type(cparams.ctx_type), mctx.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module @@ -1727,8 +1822,9 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - auto * t_logits = res->get_logits(); - auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1809,6 +1905,25 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + // extract pre-norm embeddings (hidden state before the final output norm) + // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. + { + const bool masked = cparams.embeddings_pre_norm_masked; + const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens; + const int64_t offset = masked ? n_outputs_prev : n_tokens_prev; + + if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd; + float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd; + + GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size); + ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float)); + } + } + // Copy backend sampling output if this ubatch produced any sampling tensors. if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); @@ -1823,6 +1938,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } n_outputs_prev += n_outputs; + n_tokens_prev += ubatch.n_tokens; } while (mctx->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith @@ -1893,10 +2009,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); + const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); - bool has_logits = true; - bool has_embd = cparams.embeddings; + bool has_logits = true; + bool has_embd = cparams.embeddings; + bool has_embd_pre_norm = cparams.embeddings_pre_norm; // TODO: hacky enc-dec support if (model.arch == LLM_ARCH_T5) { @@ -1908,8 +2026,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { size_t backend_float_count = 0; size_t backend_token_count = 0; - logits.size = has_logits ? n_vocab*n_outputs_max : 0; - embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; + + if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) { + // unmasked: pre-norm row exists for every token in the batch, not just + // those flagged via batch.logits[i] -> size by token count instead. + embd_pre_norm.size = (size_t) n_embd * n_batch; + } // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); @@ -1925,8 +2050,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits.size + embd.size + backend_float_count) * sizeof(float) + - ( backend_token_count) * sizeof(llama_token); + (logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1942,6 +2067,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { buf_output = nullptr; logits.data = nullptr; embd.data = nullptr; + embd_pre_norm.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1970,6 +2096,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; offset += embd.size * sizeof(float); + embd_pre_norm = has_embd_pre_norm ? buffer_view{(float *) (base + offset), embd_pre_norm.size} : buffer_view{nullptr, 0}; + offset += embd_pre_norm.size * sizeof(float); + if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; offset += sampling.logits.size * sizeof(float); @@ -2034,6 +2163,12 @@ void llama_context::output_reorder() { } } + if (embd_pre_norm.size > 0) { + for (uint64_t k = 0; k < n_embd; k++) { + std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]); + } + } + if (!sampling.samplers.empty()) { assert(sampling.logits.size > 0); assert(sampling.probs.size > 0); @@ -2121,7 +2256,7 @@ ggml_cgraph * llama_context::graph_reserve( auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx, ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3100,7 +3235,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx.get(), ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -3201,8 +3336,10 @@ llama_context_params llama_context_default_params() { /*.n_batch =*/ 2048, /*.n_ubatch =*/ 512, /*.n_seq_max =*/ 1, + /*.n_rs_seq =*/ 0, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, + /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, @@ -3306,6 +3443,13 @@ llama_context * llama_init_from_model( model->hparams.pooling_type, params.pooling_type); } + if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + model->hparams.nextn_predict_layers == 0) { + LLAMA_LOG_WARN("%s: context type MTP requested but model doesn't contain MTP layers\n", __func__); + return nullptr; + } + + try { auto * ctx = new llama_context(*model, params); return ctx; @@ -3347,6 +3491,10 @@ uint32_t llama_n_seq_max(const llama_context * ctx) { return ctx->n_seq_max(); } +uint32_t llama_n_rs_seq(const llama_context * ctx) { + return ctx->get_cparams().n_rs_seq; +} + const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } @@ -3436,6 +3584,22 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) { + ctx->set_embeddings_pre_norm(value, masked); +} + +float * llama_get_embeddings_pre_norm(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm(); +} + +float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_embeddings_pre_norm_ith(i); +} + bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { return ctx->set_sampler(seq_id, smpl); } diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 92d1b0cf95a..d03f681d4a1 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -84,6 +84,9 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + float * get_embeddings_pre_norm(); + float * get_embeddings_pre_norm_ith(int32_t i); + llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -107,6 +110,7 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); + void set_embeddings_pre_norm(bool value, bool masked); void set_causal_attn(bool value); void set_warmup(bool value); @@ -278,6 +282,11 @@ struct llama_context { // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE buffer_view embd = {nullptr, 0}; + // hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd]) + // populated only when cparams.embeddings_pre_norm is enabled and the model graph + // sets llm_graph_result::t_h_pre_norm + buffer_view embd_pre_norm = {nullptr, 0}; + struct sampling_info { // !samplers.empty() to check if any samplers are active std::map samplers; diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 9d359474132..20ec59fe335 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -12,6 +12,7 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -27,6 +28,8 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; + bool embeddings_pre_norm; // also extract the hidden state before the final output norm + bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0 bool causal_attn; bool offload_kqv; bool flash_attn; @@ -40,6 +43,7 @@ struct llama_cparams { bool kv_unified; bool pipeline_parallel; + enum llama_context_type ctx_type; enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h index 8ce29d217cb..edfa71c207c 100644 --- a/examples/talk-llama/llama-ext.h +++ b/examples/talk-llama/llama-ext.h @@ -88,3 +88,19 @@ LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); + +// +// pre-norm embeddings (hidden state before the final output norm) +// + +// Set whether the context outputs pre-norm embeddings or not +// If masked == true, output the embeddings only for the tokens with batch.logits != 0 +// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits +LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked); + +// mirrors: +// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx); + +// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); +LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index fe155c92dea..fc027de8b39 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -500,15 +500,21 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { } void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { - mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); - mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); + // base tensors may not be allocated if there are no non-SWA attention layers + if (self_k_idxs && self_k_idxs->buffer) { + mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); + mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } - mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); - mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); + // swa tensors may not be allocated if there are no SWA attention layers + if (self_k_idxs_swa && self_k_idxs_swa->buffer) { + mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); + mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + } if (self_k_rot) { mctx->get_base()->set_input_k_rot(self_k_rot); @@ -534,14 +540,21 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { bool res = true; - res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; - //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + // base tensors may not be allocated if there are no non-SWA attention layers + if (self_k_idxs && self_k_idxs->buffer) { + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + } - res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; - //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + // swa tensors may not be allocated if there are no SWA attention layers + if (self_k_idxs_swa && self_k_idxs_swa->buffer) { + res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); - res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + } return res; } @@ -848,6 +861,9 @@ void llm_graph_result::set_outputs() { if (t_embd_pooled != nullptr) { ggml_set_output(t_embd_pooled); } + if (t_h_pre_norm != nullptr) { + ggml_set_output(t_h_pre_norm); + } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { ggml_set_output(t); @@ -2528,7 +2544,8 @@ ggml_tensor * llm_graph_context::build_rs( int32_t rs_zero, const llm_graph_get_rows_fn & get_state_rows) const { - ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size); + GGML_UNUSED(rs_size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, s->ne[1]); // Clear a single state which will then be copied to the other cleared states. // Note that this is a no-op when the view is zero-sized. diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 5cb1756c6a9..bf6778237e6 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -32,6 +32,7 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DECODER_MTP, }; enum llm_ffn_op_type { @@ -580,7 +581,8 @@ struct llm_graph_params { ubatch.n_seqs_unq == other.ubatch.n_seqs_unq && ( (!ubatch.token && !other.ubatch.token) || - (!ubatch.embd && !other.ubatch.embd) + (!ubatch.embd && !other.ubatch.embd) || + (ubatch.token && other.ubatch.token && ubatch.embd && other.ubatch.embd) ); // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same @@ -644,6 +646,7 @@ class llm_graph_result { ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -672,6 +675,7 @@ class llm_graph_result { ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm std::map t_sampled_logits; std::map t_candidates; diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 002d15d415f..2239309c8fb 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -229,6 +229,12 @@ uint32_t llama_hparams::n_embd_head_v_mla() const { } bool llama_hparams::has_kv(uint32_t il) const { + if (kv_only_nextn) { + // MTP head: only the trailing nextn_predict_layers blocks own a KV cache; + // the leading trunk blocks are not executed in this graph. + return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers); + } + if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { return true; diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 0160a89caa2..e2d051edc6c 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -92,6 +92,8 @@ struct llama_hparams { uint32_t moe_latent_size = 0; uint32_t nextn_predict_layers = 0; + bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches) + float f_norm_eps; float f_norm_rms_eps; float f_norm_group_eps; diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.cpp b/examples/talk-llama/llama-memory-hybrid-iswa.cpp index 10e6b459797..72f5c2fea72 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.cpp +++ b/examples/talk-llama/llama-memory-hybrid-iswa.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? [&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr @@ -73,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) - const bool unified = (mem_attn->get_base()->get_n_stream() == 1); - ubatch = balloc.split_equal(n_ubatch, !unified); + if (mem_recr->n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_base()->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); + } } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.h b/examples/talk-llama/llama-memory-hybrid-iswa.h index 807c8aac96c..c9d3f9f57c5 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.h +++ b/examples/talk-llama/llama-memory-hybrid-iswa.h @@ -34,6 +34,7 @@ class llama_memory_hybrid_iswa : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index 4ce1af592c1..33b3b395e0c 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid::llama_memory_hybrid( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -54,6 +55,7 @@ llama_memory_hybrid::llama_memory_hybrid( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? [&](int32_t il) { return hparams.is_recurrent(il); } : filter_recr @@ -73,9 +75,15 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) - const bool unified = (mem_attn->get_n_stream() == 1); - ubatch = balloc.split_equal(n_ubatch, !unified); + if (mem_recr->n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); + } } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-hybrid.h b/examples/talk-llama/llama-memory-hybrid.h index 558cafdf984..484eafb7499 100644 --- a/examples/talk-llama/llama-memory-hybrid.h +++ b/examples/talk-llama/llama-memory-hybrid.h @@ -34,6 +34,7 @@ class llama_memory_hybrid : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index c07f1d969cb..ec5dc5835dd 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -24,6 +24,7 @@ llama_memory_recurrent::llama_memory_recurrent( bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; @@ -31,6 +32,9 @@ llama_memory_recurrent::llama_memory_recurrent( size = mem_size; used = 0; + this->n_rs_seq = n_rs_seq; + rs_idx.assign(n_seq_max, 0); + cells.clear(); cells.resize(mem_size); @@ -92,8 +96,9 @@ llama_memory_recurrent::llama_memory_recurrent( throw std::runtime_error("failed to create ggml context for rs cache"); } - ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), mem_size); - ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), mem_size); + const uint32_t n_rows = mem_size * (1 + n_rs_seq); + ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows); + ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows); ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(s, "cache_s_l%d", i); r_l[i] = r; @@ -115,8 +120,8 @@ llama_memory_recurrent::llama_memory_recurrent( const size_t memory_size_r = size_r_bytes(); const size_t memory_size_s = size_s_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, - (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs %2u rs_seq), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, n_rs_seq, ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f)); } @@ -138,10 +143,11 @@ void llama_memory_recurrent::clear(bool data) { ggml_backend_buffer_clear(buf.get(), 0); } } + + std::fill(rs_idx.begin(), rs_idx.end(), 0); } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1); uint32_t new_head = size; if (p0 < 0) { @@ -152,6 +158,15 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } + const bool rm_all = p0 == 0 && p1 == std::numeric_limits::max(); + if (rm_all) { + if (seq_id >= 0) { + set_rs_idx(seq_id, 0); + } else { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } + } + // models like Mamba or RWKV can't have a state partially erased at the end // of the sequence because their state isn't preserved for previous tokens if (seq_id >= (int64_t) size) { @@ -161,10 +176,16 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (0 <= seq_id) { int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { - const auto & cell = cells[tail_id]; - // partial intersection is invalid if it includes the final pos + auto & cell = cells[tail_id]; + + // partial rollback via per-token snapshot index (bounded by n_rs_seq) if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); + const llama_pos rollback = cell.pos - (p0 - 1); + if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) { + set_rs_idx(seq_id, (uint32_t) rollback); + cell.pos = p0 - 1; + return true; + } return false; } // invalidate tails which will be cleared @@ -368,6 +389,13 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } +void llama_memory_recurrent::set_rs_idx(llama_seq_id seq_id, uint32_t idx) { + if (seq_id < 0 || (size_t) seq_id >= rs_idx.size()) { + return; + } + rs_idx[seq_id] = (idx > n_rs_seq) ? n_rs_seq : idx; +} + std::map llama_memory_recurrent::memory_breakdown() const { std::map ret; for (const auto & [_, buf] : ctxs_bufs) { @@ -388,9 +416,15 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + if (n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); + } } if (ubatch.n_tokens == 0) { @@ -703,6 +737,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq GGML_UNUSED(flags); std::vector> cell_ranges; // ranges, from inclusive, to exclusive + std::vector> cell_ranges_data; // logical source row ranges uint32_t cell_count = 0; // Count the number of cells with the specified seq_id @@ -712,6 +747,35 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq const auto & cell = cells[i]; if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { ++cell_count; + uint32_t rs_idx_cur = 0; + + if (n_rs_seq != 0) { + if (seq_id != -1) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < rs_idx.size()); + rs_idx_cur = rs_idx[seq_id]; + } else { + bool has_rs_idx = false; + for (const llama_seq_id cell_seq_id : cell.seq_id) { + GGML_ASSERT(cell_seq_id >= 0 && (size_t) cell_seq_id < rs_idx.size()); + + const uint32_t seq_rs_idx = rs_idx[cell_seq_id]; + if (!has_rs_idx) { + rs_idx_cur = seq_rs_idx; + has_rs_idx = true; + } else if (rs_idx_cur != seq_rs_idx) { + GGML_ABORT("cannot write shared recurrent state with different rollback indices"); + } + } + } + } + + const uint32_t cell_id = rs_idx_cur * size + (cell.src >= 0 ? cell.src : (int32_t) i); + if (cell_ranges_data.empty() || cell_ranges_data.back().second != cell_id) { + cell_ranges_data.emplace_back(cell_id, cell_id + 1); + } else { + cell_ranges_data.back().second++; + } + if (cell_range_begin == size) { cell_range_begin = i; } @@ -726,7 +790,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq cell_ranges.emplace_back(cell_range_begin, size); } - if (flags % LLAMA_STATE_SEQ_FLAGS_ON_DEVICE && cell_ranges.size() > 1) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) && cell_ranges.size() > 1) { GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n"); } @@ -737,10 +801,16 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq } GGML_ASSERT(cell_count == cell_count_check); + cell_count_check = 0; + for (const auto & range : cell_ranges_data) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + io.write(&cell_count, sizeof(cell_count)); state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); + state_write_data(io, cell_ranges_data); } void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { @@ -762,6 +832,14 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i } throw std::runtime_error("failed to restore kv cache"); } + + if (n_rs_seq != 0) { + if (seq_id == -1) { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } else { + set_rs_idx(seq_id, 0); + } + } } void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { @@ -804,7 +882,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); io.write(&r_size_row, sizeof(r_size_row)); - // Write each range of cells of r_size_row length + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * r_size_row; @@ -825,7 +904,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); io.write(&s_size_row, sizeof(s_size_row)); - // Write each range of S tensor rows + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * s_size_row; @@ -852,9 +932,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // Write GQA embedding size io.write(&n_embd_s, sizeof(n_embd_s)); - // For each row, we get the element values of each cell + // For each row, we get the element values of each logical cell for (uint32_t j = 0; j < n_embd_s; ++j) { - // Write each range of cells of s_size_el length for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * mem_size) * s_size_el; @@ -1163,5 +1242,21 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { } int32_t llama_memory_recurrent_context::s_copy(int i) const { - return mem->cells[i + mem->head].src0; + const uint32_t cell_idx = i + mem->head; + const int32_t src0 = mem->cells[cell_idx].src0; + + if (mem->n_rs_seq == 0) { + return src0; + } + + uint32_t idx = 0; + if (!mem->cells[cell_idx].seq_id.empty()) { + const llama_seq_id seq = *mem->cells[cell_idx].seq_id.begin(); + if (seq >= 0 && (size_t) seq < mem->rs_idx.size()) { + idx = mem->rs_idx[seq]; + // reset rollback idx + mem->rs_idx[seq] = 0; + } + } + return (int32_t)(idx * mem->size) + src0; } diff --git a/examples/talk-llama/llama-memory-recurrent.h b/examples/talk-llama/llama-memory-recurrent.h index 47f01d73912..b13b7b748f5 100644 --- a/examples/talk-llama/llama-memory-recurrent.h +++ b/examples/talk-llama/llama-memory-recurrent.h @@ -23,6 +23,7 @@ class llama_memory_recurrent : public llama_memory_i { bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter); ~llama_memory_recurrent() = default; @@ -69,6 +70,14 @@ class llama_memory_recurrent : public llama_memory_i { uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) + // number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups + uint32_t n_rs_seq = 0; + + // per-seq rollback index + std::vector rs_idx; + + void set_rs_idx(llama_seq_id seq_id, uint32_t idx); + // computed before each graph build uint32_t n = 0; diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index 4a157b91fdb..4ad1612e45b 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-graph.h" #include #include @@ -20,6 +21,8 @@ struct llama_memory_params { // use full-size SWA cache bool swa_full; + + llama_context_type ctx_type; }; enum llama_memory_status { diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 4e65a45a50d..c645d0785ab 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -1312,9 +1312,16 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte return tensor; } -void llama_model_loader::done_getting_tensors() const { - if (n_created != n_tensors) { - throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); +void llama_model_loader::done_getting_tensors(bool partial) const { + if (n_created > n_tensors) { + throw std::runtime_error(format("%s: too many tensors created; expected %d, got %d", __func__, n_tensors, n_created)); + } + if (n_created < n_tensors) { + if (!partial) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } + LLAMA_LOG_INFO("%s: partial load — used %d of %d tensors in the file (rest belong to a sibling model on the same .gguf)\n", + __func__, n_created, n_tensors); } if (n_tensors_moved > 0) { LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n", diff --git a/examples/talk-llama/llama-model-loader.h b/examples/talk-llama/llama-model-loader.h index 7b3d6703c03..c476026d3e5 100644 --- a/examples/talk-llama/llama-model-loader.h +++ b/examples/talk-llama/llama-model-loader.h @@ -184,7 +184,7 @@ struct llama_model_loader { struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list & ne, size_t offset, bool required = true); - void done_getting_tensors() const; + void done_getting_tensors(bool partial = false) const; void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr); diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index e83056557bf..528e4c9c069 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -393,6 +393,8 @@ void llama_model_saver::add_tensors_from_model() { add_tensor(model->output); add_tensor(model->output_b); add_tensor(model->output_norm_enc); + add_tensor(model->output_s); + add_tensor(model->output_in_s); add_tensor(model->cls); add_tensor(model->cls_b); add_tensor(model->cls_out); diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index ff30a2ae7a6..0d21b2a53c5 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -1334,6 +1334,12 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { if (!layer.ssm_beta_s && layer.ssm_beta) { layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); } + if (!layer.nextn.eh_proj_s && layer.nextn.eh_proj) { + layer.nextn.eh_proj_s = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.nextn.shared_head_head_s && layer.nextn.shared_head_head) { + layer.nextn.shared_head_head_s = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } // input scales if (!layer.wq_in_s && layer.wq) { @@ -1393,11 +1399,30 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { if (!layer.ssm_beta_in_s && layer.ssm_beta) { layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); } + if (!layer.nextn.eh_proj_in_s && layer.nextn.eh_proj) { + layer.nextn.eh_proj_in_s = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.nextn.shared_head_head_in_s && layer.nextn.shared_head_head) { + layer.nextn.shared_head_head_in_s = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + } + // output scales + if (output && output->type == GGML_TYPE_NVFP4) { + // weight scale + if (!output_s) { + output_s = create_tensor(tn(LLM_TENSOR_OUTPUT, "scale"), {1}, TENSOR_NOT_REQUIRED); + } + // input scale + if (!output_in_s) { + output_in_s = create_tensor(tn(LLM_TENSOR_OUTPUT, "input_scale"), {1}, TENSOR_NOT_REQUIRED); + } } } - ml.done_getting_tensors(); + GGML_ASSERT(!(output && tok_embd && + strcmp(output->name, tok_embd->name) == 0 && + output->type == GGML_TYPE_NVFP4)); // populate tensors_by_name for (auto & [_, ctx_ptr] : ml.ctx_map) { for (auto * cur = ggml_get_first_tensor(ctx_ptr.get()); cur != NULL; cur = ggml_get_next_tensor(ctx_ptr.get(), cur)) { @@ -1934,6 +1959,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, // checks default: { + // The MTP head is dense-attention only on hybrid Qwen3.5/3.6, so use a plain + // attention KV cache for the MTP context instead of the hybrid wrapper. + const bool mtp_on_hybrid_qwen35 = + params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE); + if (llm_arch_is_recurrent(arch)) { res = new llama_memory_recurrent( *this, @@ -1942,8 +1973,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.offload_kqv, std::max((uint32_t) 1, cparams.n_seq_max), cparams.n_seq_max, + cparams.n_rs_seq, nullptr); - } else if (llm_arch_is_hybrid(arch)) { + } else if (llm_arch_is_hybrid(arch) && !mtp_on_hybrid_qwen35) { // The main difference between hybrid architectures is the // layer filters, so pick the right one here llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; @@ -1958,6 +1990,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, filter_recr = [&](int32_t il) { return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; }; + } else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter_attn = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && !hparams.is_recurrent(il); + }; + filter_recr = [&, n_main](int32_t il) { + return (uint32_t)il < n_main && hparams.is_recurrent(il); + }; } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { @@ -1975,6 +2015,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_s */ GGML_TYPE_F32, /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), @@ -1993,6 +2034,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* recurrent_type_v */ GGML_TYPE_F32, /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), @@ -2000,6 +2042,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } } else { llama_memory_i::layer_reuse_cb reuse = nullptr; + llama_kv_cache::layer_filter_cb filter = nullptr; if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { reuse = [&](int32_t il) { @@ -2011,6 +2054,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, }; } + if (mtp_on_hybrid_qwen35) { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + filter = [n_main](int32_t il) { return (uint32_t)il >= n_main; }; + } + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); @@ -2026,7 +2074,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, cparams.n_ubatch, 1, - nullptr, + filter, reuse); } else { GGML_ASSERT(!hparams.is_swa_any()); @@ -2043,7 +2091,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, hparams.n_swa, hparams.swa_type, - nullptr, + filter, nullptr); } } @@ -2146,6 +2194,7 @@ int32_t llama_model_n_swa(const llama_model * model) { return model->hparams.n_swa; } + uint32_t llama_model_n_cls_out(const struct llama_model * model) { return model->hparams.n_cls_out; } diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index d63c689185a..398a0aa725c 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -202,12 +202,16 @@ struct llama_layer_shortconv { }; struct llama_layer_nextn { - struct ggml_tensor * eh_proj = nullptr; - struct ggml_tensor * embed_tokens = nullptr; - struct ggml_tensor * enorm = nullptr; - struct ggml_tensor * hnorm = nullptr; - struct ggml_tensor * shared_head_head = nullptr; - struct ggml_tensor * shared_head_norm = nullptr; + struct ggml_tensor * eh_proj = nullptr; + struct ggml_tensor * eh_proj_s = nullptr; + struct ggml_tensor * eh_proj_in_s = nullptr; + struct ggml_tensor * embed_tokens = nullptr; + struct ggml_tensor * enorm = nullptr; + struct ggml_tensor * hnorm = nullptr; + struct ggml_tensor * shared_head_head = nullptr; + struct ggml_tensor * shared_head_head_s = nullptr; + struct ggml_tensor * shared_head_head_in_s = nullptr; + struct ggml_tensor * shared_head_norm = nullptr; }; struct llama_layer { @@ -533,6 +537,11 @@ struct llama_model { struct ggml_tensor * output_b = nullptr; struct ggml_tensor * output_norm_enc = nullptr; + + // NVFP4 per-tensor scale2, input_scale for LM head + struct ggml_tensor * output_s = nullptr; + struct ggml_tensor * output_in_s = nullptr; + // classifier struct ggml_tensor * cls = nullptr; struct ggml_tensor * cls_b = nullptr; diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index f43cf546ca0..a5cf148b268 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -530,6 +530,8 @@ struct llm_tokenizer_bpe : llm_tokenizer { struct llm_tokenizer_bpe_session { llm_tokenizer_bpe_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + virtual ~llm_tokenizer_bpe_session() = default; + static void append(const llama_token token_id, std::vector & output) { output.push_back(token_id); } @@ -567,7 +569,7 @@ struct llm_tokenizer_bpe_session { } } - void tokenize(const std::string & text, std::vector & output) { + virtual void tokenize(const std::string & text, std::vector & output) { int final_prev_index = -1; const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs, tokenizer.byte_encode); @@ -1579,6 +1581,88 @@ struct llm_tokenizer_plamo2_session { const llm_tokenizer_plamo2 & tokenizer; }; +// reserved suffix (U+E000) that keeps DNA k-mers distinct from identical +// base-vocab BPE tokens (e.g. CCCCCC) in token_to_id; erased from id_to_token +// text at load +static const std::string dna_kmer_marker = "\xee\x80\x80"; + +struct llm_tokenizer_hybriddna_session : llm_tokenizer_bpe_session { + llm_tokenizer_hybriddna_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : llm_tokenizer_bpe_session{vocab, tokenizer}, vocab{vocab} {} + + void tokenize(const std::string & text, std::vector & output) override { + static const std::string open_tag = ""; + static const std::string close_tag = ""; + + const auto dna_begin_id = vocab.text_to_token(open_tag); + const auto dna_end_id = vocab.text_to_token(close_tag); + const auto dna_oov_id = vocab.text_to_token(""); + + // Fall back to plain BPE if the DNA pieces aren't in the vocab. + if (dna_begin_id == LLAMA_TOKEN_NULL || dna_end_id == LLAMA_TOKEN_NULL || dna_oov_id == LLAMA_TOKEN_NULL) { + llm_tokenizer_bpe_session::tokenize(text, output); + return; + } + + const size_t k = 6; + size_t pos = 0; + + while (pos < text.size()) { + const size_t start = text.find(open_tag, pos); + if (start == std::string::npos) { + if (pos < text.size()) { + llm_tokenizer_bpe_session::tokenize(text.substr(pos), output); + } + break; + } + if (start > pos) { + llm_tokenizer_bpe_session::tokenize(text.substr(pos, start - pos), output); + } + output.push_back(dna_begin_id); + + const size_t content_start = start + open_tag.size(); + const size_t end = text.find(close_tag, content_start); + const size_t content_end = (end == std::string::npos) ? text.size() : end; + + emit_dna_kmers(text.substr(content_start, content_end - content_start), k, dna_oov_id, output); + + if (end == std::string::npos) { + break; + } + output.push_back(dna_end_id); + pos = end + close_tag.size(); + } + } + +private: + void emit_dna_kmers(const std::string & raw, size_t k, llama_token oov_id, std::vector & output) { + std::string seq = raw; + for (char & c : seq) { + if (c >= 'a' && c <= 'z') { + c = char(c - 32); + } + } + + // k-mers carry the reserved marker suffix; a non-ACGT k-mer simply + // isn't in the vocab and falls back to + auto kmer_token = [&](const std::string & kmer) { + const auto tok = vocab.text_to_token(kmer + dna_kmer_marker); + return tok != LLAMA_TOKEN_NULL ? tok : oov_id; + }; + + size_t i = 0; + for (; i + k <= seq.size(); i += k) { + output.push_back(kmer_token(seq.substr(i, k))); + } + if (i < seq.size()) { + std::string kmer = seq.substr(i); + kmer.append(k - kmer.size(), 'A'); + output.push_back(kmer_token(kmer)); + } + } + + const llama_vocab & vocab; +}; + // // impl // @@ -1808,7 +1892,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_mask_id = 103; add_sep = true; - } else if (tokenizer_model == "gpt2") { + } else if (tokenizer_model == "gpt2" || tokenizer_model == "hybriddna") { type = LLAMA_VOCAB_TYPE_BPE; // read bpe merges and populate bpe ranks @@ -2266,6 +2350,23 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } GGML_ASSERT(id_to_token.size() == token_to_id.size()); + // hybriddna: the marker suffix kept k-mer ids distinct in token_to_id; erase + // it from id_to_token so the k-mers detokenize to the bare DNA sequence. The + // k-mers are the block right after , so only scan from there. + if (tokenizer_model == "hybriddna") { + const auto idx = token_to_id.find(""); + if (idx != token_to_id.end()) { + auto it = id_to_token.begin() + idx->second + 1; + for (; it != id_to_token.end(); ++it) { + std::string & text = it->text; + if (text.size() > dna_kmer_marker.size() + && text.compare(text.size() - dna_kmer_marker.size(), dna_kmer_marker.size(), dna_kmer_marker) == 0) { + text.erase(text.size() - dna_kmer_marker.size()); + } + } + } + } + init_tokenizer(type); // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' @@ -3144,11 +3245,19 @@ std::vector llama_vocab::impl::tokenize( } break; case LLAMA_VOCAB_TYPE_BPE: { - llm_tokenizer_bpe_session session(vocab, *static_cast(tokenizer.get())); // it calls some other methods that are not exist in llm_tokenizer, // here just cast it to bpe tokenizer object + const llm_tokenizer_bpe * tok_bpe = static_cast(tokenizer.get()); + + std::unique_ptr session; + if (vocab.get_tokenizer_model() == "hybriddna") { + session = std::make_unique(vocab, *tok_bpe); + } else { + session = std::make_unique(vocab, *tok_bpe); + } + if (add_special) { - session.append_bos(output); + session->append_bos(output); } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -3161,15 +3270,15 @@ std::vector llama_vocab::impl::tokenize( #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); #endif - session.tokenize(text, output); + session->tokenize(text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - session.append(fragment.token, output); + session->append(fragment.token, output); } } if (add_special) { - session.append_eos(output); - session.check_double_bos_eos(output); + session->append_eos(output); + session->check_double_bos_eos(output); } } break; case LLAMA_VOCAB_TYPE_WPM: diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 308e8ba9dbd..e8374c53b70 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -198,6 +198,11 @@ extern "C" { LLAMA_SPLIT_MODE_TENSOR = 3, }; + enum llama_context_type { + LLAMA_CONTEXT_TYPE_DEFAULT = 0, + LLAMA_CONTEXT_TYPE_MTP = 1, + }; + // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) typedef struct llama_token_data { llama_token id; // token id @@ -333,9 +338,11 @@ extern "C" { uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode uint32_t n_ubatch; // physical maximum batch size uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing + enum llama_context_type ctx_type; // set the context type (e.g. MTP) enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings @@ -530,6 +537,7 @@ extern "C" { LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx); DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); @@ -866,7 +874,8 @@ extern "C" { // work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba) #define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 -// keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load) +// Keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load). +// Getting the state for a seq_id with this flag invalidates all prior states gotten for that seq_id with this flag. #define LLAMA_STATE_SEQ_FLAGS_ON_DEVICE 2 typedef uint32_t llama_state_seq_flags; diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index 602e3176afd..a7c77ee5d28 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -277,7 +277,7 @@ llama_model_afmoe::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index 136ff702957..bec7136521c 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -160,7 +160,7 @@ llama_model_apertus::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index 70e86d41130..d086c4717ff 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -148,7 +148,7 @@ llama_model_arcee::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index d8653a44639..27deadffeb7 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -171,7 +171,7 @@ llama_model_arctic::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arwkv7.cpp b/examples/talk-llama/models/arwkv7.cpp index 79aa8c90899..9bd04127b25 100644 --- a/examples/talk-llama/models/arwkv7.cpp +++ b/examples/talk-llama/models/arwkv7.cpp @@ -193,7 +193,7 @@ llama_model_arwkv7::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index 4e55290e4e5..4d26081cd5d 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -146,7 +146,7 @@ llama_model_baichuan::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index 030dd4f42a4..fe1ae10864b 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -171,7 +171,7 @@ llama_model_bailingmoe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index e7fe3d5b45a..2f0d44a6259 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -210,7 +210,7 @@ llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index b600fb0c954..30b0f3d07d0 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -142,7 +142,7 @@ llama_model_bloom::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 8510b9e29f8..4bceaefd63b 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -181,7 +181,7 @@ llama_model_chameleon::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output_with_img_logits", -1); // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs. diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index e898eff7939..6766fa71c15 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -151,7 +151,7 @@ llama_model_chatglm::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index e9e85d96713..274dd3342a7 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -143,7 +143,7 @@ llama_model_codeshell::graph::graph(const llama_model & model, const llm_graph_p cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index 79236121bd5..2e231bb3f93 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -150,7 +150,7 @@ llama_model_cogvlm::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; ggml_build_forward_expand(gf, cur); diff --git a/examples/talk-llama/models/cohere2.cpp b/examples/talk-llama/models/cohere2.cpp index 12edbae1094..a514cf88fc6 100644 --- a/examples/talk-llama/models/cohere2.cpp +++ b/examples/talk-llama/models/cohere2.cpp @@ -146,7 +146,7 @@ llama_model_cohere2::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index decb89f547b..adf7fcaa20f 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -131,7 +131,7 @@ llama_model_command_r::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index bce6b04bcf9..af71c775365 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -145,7 +145,7 @@ llama_model_dbrx::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index 9f1a959c32c..567e3535276 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -181,7 +181,7 @@ llama_model_deci::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/deepseek.cpp b/examples/talk-llama/models/deepseek.cpp index c7946059662..f52ec9518b6 100644 --- a/examples/talk-llama/models/deepseek.cpp +++ b/examples/talk-llama/models/deepseek.cpp @@ -185,7 +185,7 @@ llama_model_deepseek::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/delta-net-base.cpp b/examples/talk-llama/models/delta-net-base.cpp index 6bc989c9509..4f4c7cac7a8 100644 --- a/examples/talk-llama/models/delta-net-base.cpp +++ b/examples/talk-llama/models/delta-net-base.cpp @@ -1,6 +1,7 @@ #include "models.h" #include "llama-impl.h" +#include "llama-memory-recurrent.h" // utility to get one slice from the third dimension // input dim: [x, y, c, b] @@ -397,7 +398,9 @@ std::pair llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s); + // K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net. + ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs); + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d); if (n_tokens == 1) { cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); } else { @@ -443,3 +446,162 @@ std::pair llm_build_delta_net_base::build_delta_ne return build_delta_net_chunking(q, k, v, g, b, s, il); } + +ggml_tensor * llm_build_delta_net_base::build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il) { + const auto * mctx_cur = inp->mctx; + + const auto kv_head = mctx_cur->get_head(); + const auto mem_size = mctx_cur->get_size(); + + const int64_t n_seqs = ubatch.n_seqs; + + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + const int64_t row_count = (conv_kernel_size - 1) * conv_channels; + + const size_t row_size = ggml_row_size(conv_states_all->type, row_count); + + if (cparams.n_rs_seq == 0) { + const int64_t s_idx = conv_input->ne[0] - conv_states->ne[0]; + const int64_t s_slot = 0; + + ggml_tensor * conv_state_last = + ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + ggml_row_size(conv_input->type, s_idx)); + cb(conv_state_last, "conv_state_last", il); + + ggml_tensor * conv_state_update = + ggml_view_2d(ctx0, conv_states_all, + row_count, n_seqs, conv_states_all->nb[1], + (s_slot * mem_size + kv_head) * row_size); + cb(conv_state_update, "conv_state_update", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); + } else { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: this logic incorrectly assumes that the last (n_rs_seq + 1) tokens of a sequence in a batch are + // inside the same ubatch. currently with `split_equal()` this is not correct + + const int64_t K = (int64_t) cparams.n_rs_seq + 1; + + for (int64_t t = 1; t <= K; ++t) { + const int64_t s_idx = std::max(0, conv_input->ne[0] - conv_states->ne[0] - K + t); + const int64_t s_slot = K - t; + + ggml_tensor * conv_state_last = + ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + ggml_row_size(conv_input->type, s_idx)); + + ggml_tensor * conv_state_update = + ggml_view_2d(ctx0, + conv_states_all, row_count, n_seqs, + conv_states_all->nb[1], + (s_slot * mem_size + kv_head) * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); + } + } + + return conv_input; +} + +ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const auto * mctx_cur = inp->mctx; + const auto kv_head = mctx_cur->get_head(); + const uint32_t mem_size = mctx_cur->get_size(); + + const int64_t S_v = s->ne[0]; + const int64_t H_v = s->ne[2]; + const int64_t n_seqs = s->ne[3]; + const int64_t n_seq_tokens = q->ne[2]; + + const bool keep = cparams.n_rs_seq > 0; + + if (!keep) { + auto attn_out = build_delta_net(q, k, v, g, b, s, il); + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + return output; + } + + const int64_t D = S_v * S_v * H_v; + const int64_t K = cparams.n_rs_seq + 1; + + // TODO: remove pad + simplify + ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs); + ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0); + + ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad); + if (n_seq_tokens > 1) { + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); + } else { + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_AR, il); + } + + const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs; + + ggml_tensor * output = ggml_view_4d(ctx0, gdn_out, + S_v, H_v, n_seq_tokens, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * H_v), + ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens), + 0); + cb(output, "attn_output", il); + + const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all); + for (int64_t k_i = 0; k_i < K; ++k_i) { + const uint32_t cache_slot = (uint32_t) (K - 1 - k_i); + ggml_tensor * src = ggml_view_4d(ctx0, gdn_out, + S_v, S_v, H_v, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * S_v), + ggml_row_size(gdn_out->type, S_v * S_v * H_v), + ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap)); + + ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all, + hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + ((size_t) cache_slot * mem_size + kv_head) * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); + } + + return output; +} diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 93cbcf9d931..435d27281c6 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -183,7 +183,7 @@ llama_model_dots1::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 60a3f0ec285..12ac6f1ce88 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -128,7 +128,7 @@ llama_model_dream::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/ernie4-5-moe.cpp b/examples/talk-llama/models/ernie4-5-moe.cpp index 2bd01a2c512..8d9ff138676 100644 --- a/examples/talk-llama/models/ernie4-5-moe.cpp +++ b/examples/talk-llama/models/ernie4-5-moe.cpp @@ -124,7 +124,7 @@ llama_model_ernie4_5_moe::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index fa989fe92cd..9b39c605e35 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -155,7 +155,7 @@ llama_model_ernie4_5::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp index 54bb3ca86b3..76d91982fc5 100644 --- a/examples/talk-llama/models/exaone-moe.cpp +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -237,7 +237,7 @@ llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index 75d5f60631c..c7e9960d718 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -127,7 +127,7 @@ llama_model_exaone::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 5506e76424d..499e22dde81 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -163,7 +163,7 @@ llama_model_exaone4::graph::graph(const llama_model & model, const llm_gra res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index d353befdb8e..94b65a3c7c9 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -200,7 +200,7 @@ llama_model_falcon_h1::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index 75f2cfef560..ad546ef2db5 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -152,7 +152,7 @@ llama_model_falcon::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 06731670007..1519682fdf6 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -130,7 +130,7 @@ llama_model_gemma::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gemma2.cpp b/examples/talk-llama/models/gemma2.cpp index 6255bf740fc..ae3f9ffb530 100644 --- a/examples/talk-llama/models/gemma2.cpp +++ b/examples/talk-llama/models/gemma2.cpp @@ -163,7 +163,7 @@ llama_model_gemma2::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // final logit soft-capping cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index ee510fe38b0..63a2b380e71 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -207,7 +207,7 @@ llama_model_gemma3::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (hparams.f_final_logit_softcapping) { cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/gemma3n.cpp b/examples/talk-llama/models/gemma3n.cpp index 881499b0ca7..6ec3a006081 100644 --- a/examples/talk-llama/models/gemma3n.cpp +++ b/examples/talk-llama/models/gemma3n.cpp @@ -296,7 +296,7 @@ llama_model_gemma3n::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); { // final logit soft-capping diff --git a/examples/talk-llama/models/gemma4.cpp b/examples/talk-llama/models/gemma4.cpp index f45ae4cad59..4f9d8b18bc7 100644 --- a/examples/talk-llama/models/gemma4.cpp +++ b/examples/talk-llama/models/gemma4.cpp @@ -380,7 +380,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (hparams.f_final_logit_softcapping) { cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 45886b51ac1..27654b8cba3 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -275,7 +275,7 @@ llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index d6ef76e26d6..7c242fed298 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -185,7 +185,7 @@ llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // Output projection - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index ba49c31b56b..e2dcc8b1521 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -138,7 +138,7 @@ llama_model_gpt2::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 33ebe2d8800..443e35addf2 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -209,7 +209,7 @@ llama_model_gptneox::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index 12e4790ae24..27f6706ea10 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -186,7 +186,7 @@ llama_model_granite_hybrid::graph::graph(const llama_model & model, const llm_gr res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // For Granite architectures - scale logits if (hparams.f_logit_scale) { diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index 5e7c7b68181..cda4aa231fa 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -145,7 +145,7 @@ llama_model_granite::graph::graph( res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // For Granite architectures - scale logits cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index 0bc49d00206..7c46ec1c0f2 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -206,7 +206,7 @@ llama_model_grok::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cur = ggml_scale(ctx0, cur, hparams.f_logit_scale); diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index feef815165b..1cab75adc7f 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -184,7 +184,7 @@ llama_model_grovemoe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index 44af42412f7..deb3c9671f3 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -179,7 +179,7 @@ llama_model_hunyuan_moe::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/hunyuan-vl.cpp b/examples/talk-llama/models/hunyuan-vl.cpp index 5fb9154bec0..da9bb74de7e 100644 --- a/examples/talk-llama/models/hunyuan-vl.cpp +++ b/examples/talk-llama/models/hunyuan-vl.cpp @@ -181,7 +181,7 @@ llama_model_hunyuan_vl::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index f0c5580a6f4..f9ee37a24b6 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -129,7 +129,7 @@ llama_model_internlm2::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index a6451dca095..2ba162605f1 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -123,7 +123,7 @@ llama_model_jais::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp index ad59b953e8d..8966131441c 100644 --- a/examples/talk-llama/models/jais2.cpp +++ b/examples/talk-llama/models/jais2.cpp @@ -152,7 +152,7 @@ llama_model_jais2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // Output projection - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index e1b8d137e38..84ea63c3136 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -189,7 +189,7 @@ llama_model_jamba::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index df6a8028736..29081344b24 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -262,7 +262,7 @@ llama_model_lfm2::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index b60f67f6c4b..9722dde9f17 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -153,7 +153,7 @@ llama_model_llada_moe::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index fa21c5fe32c..58b2c466e17 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -147,7 +147,7 @@ llama_model_llada::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index 8ddb5936820..cef66d054b0 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -235,7 +235,7 @@ llama_model_llama::graph::graph(const llama_model & model, const llm_grap if constexpr (!embed) { // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llama4.cpp b/examples/talk-llama/models/llama4.cpp index 899611d53f6..0ff5376d571 100644 --- a/examples/talk-llama/models/llama4.cpp +++ b/examples/talk-llama/models/llama4.cpp @@ -260,7 +260,7 @@ llama_model_llama4::graph::graph(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index 3dbd82fd362..84cfe399027 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -141,7 +141,7 @@ llama_model_maincoder::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mamba.cpp b/examples/talk-llama/models/mamba.cpp index b7708d7fdd1..887a1fa509a 100644 --- a/examples/talk-llama/models/mamba.cpp +++ b/examples/talk-llama/models/mamba.cpp @@ -128,7 +128,7 @@ llama_model_mamba::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mimo2.cpp b/examples/talk-llama/models/mimo2.cpp index 71996616611..d0295ec116f 100644 --- a/examples/talk-llama/models/mimo2.cpp +++ b/examples/talk-llama/models/mimo2.cpp @@ -231,7 +231,7 @@ llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index ff5eb6ffa5f..1ffc54fa7c6 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -251,7 +251,7 @@ llama_model_minicpm3::graph::graph(const llama_model & model, const llm_graph_pa cb(cur, "lmhead_scaling", -1); // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index 0dee8934692..22e291d73a3 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -158,7 +158,7 @@ llama_model_minimax_m2::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 708da49af1f..4e6ebef82cb 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -222,7 +222,7 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index 6d5f18a8e20..7e551eb965b 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * s, int il); - // use the ggml_gated_delta_net fused operator + // use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs)) std::pair build_delta_net_fused( ggml_tensor * q, ggml_tensor * k, @@ -65,6 +65,29 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * b, ggml_tensor * s, int il); + + // read conv state from cache, concat with qkv_mixed, write back (single slot or per-token) + // qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs) + ggml_tensor * build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il); + + // run delta-net attention and write the new recurrent state(s) back to ssm_states_all + // s: (head_v_dim, head_v_dim, num_v_heads, n_seqs); returns output: (head_v_dim, num_v_heads, n_seq_tokens, n_seqs) + ggml_tensor * build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); }; struct llm_build_rwkv6_base : public llm_graph_context { @@ -1739,6 +1762,10 @@ struct llama_model_qwen35 : public llama_model_base { const llama_model & model; }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; @@ -1781,6 +1808,10 @@ struct llama_model_qwen35moe : public llama_model_base { const llama_model & model; }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; }; diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index cfc60e8de29..0229d20ed36 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -161,7 +161,7 @@ llama_model_mpt::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index 865461f61db..a82f9c170b4 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -174,7 +174,7 @@ llama_model_nemotron_h::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index 0c72ed297aa..5d4a3b5c69e 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -140,7 +140,7 @@ llama_model_nemotron::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index 161035e72bc..cfcf17bcb03 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -133,7 +133,7 @@ llama_model_olmo::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 9633f269965..7cc262f5504 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -198,7 +198,7 @@ llama_model_olmo2::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index 4bb9013054c..7976ae44a51 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -164,7 +164,7 @@ llama_model_olmoe::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/openai-moe.cpp b/examples/talk-llama/models/openai-moe.cpp index 13a590ce646..15b6c8c1205 100644 --- a/examples/talk-llama/models/openai-moe.cpp +++ b/examples/talk-llama/models/openai-moe.cpp @@ -160,7 +160,7 @@ llama_model_openai_moe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index b4128e116e7..9f76350fd4d 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -162,7 +162,7 @@ llama_model_openelm::graph::graph(const llama_model & model, const llm_graph_par cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index 7ace0a5139d..bcb4bbba4b1 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -132,7 +132,7 @@ llama_model_orion::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/paddleocr.cpp b/examples/talk-llama/models/paddleocr.cpp index 1c0eadefa98..d39220bd778 100644 --- a/examples/talk-llama/models/paddleocr.cpp +++ b/examples/talk-llama/models/paddleocr.cpp @@ -98,7 +98,7 @@ llama_model_paddleocr::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/pangu-embed.cpp b/examples/talk-llama/models/pangu-embed.cpp index 41b7e2ac23e..7593f879b24 100644 --- a/examples/talk-llama/models/pangu-embed.cpp +++ b/examples/talk-llama/models/pangu-embed.cpp @@ -148,7 +148,7 @@ llama_model_pangu_embed::graph::graph(const llama_model & model, const llm_graph res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index a333602c72d..8f3ed5f7b7d 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -130,7 +130,7 @@ llama_model_phi2::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output_no_bias", -1); cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index 0a65e91fefa..f8a4a4d5aa5 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -179,7 +179,7 @@ llama_model_phi3::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cb(cur, "result_output_no_bias", -1); diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index 4c16c20a0d4..c7ed1211c31 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -127,7 +127,7 @@ llama_model_plamo::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index 29c8702606a..b713889fe72 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -185,7 +185,7 @@ llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); // Explicitly mark as output tensor to ensure proper backend assignment diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 849f1579e63..29f3e803d68 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -186,7 +186,7 @@ llama_model_plamo3::graph::graph(const llama_model & model, const llm_grap cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); res->t_logits = cur; ggml_build_forward_expand(gf, cur); diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index 57f5995103b..ce050919e6a 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -204,7 +204,7 @@ llama_model_plm::graph::graph(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index cdc076cdf77..00467dbad7d 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -131,7 +131,7 @@ llama_model_qwen::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index 6320458a13b..a5147460bae 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -141,7 +141,7 @@ llama_model_qwen2::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 7587c802c68..7cb03859deb 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -184,7 +184,7 @@ llama_model_qwen2moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen2vl.cpp b/examples/talk-llama/models/qwen2vl.cpp index 1a40fa89be4..d79db682cd4 100644 --- a/examples/talk-llama/models/qwen2vl.cpp +++ b/examples/talk-llama/models/qwen2vl.cpp @@ -134,7 +134,7 @@ llama_model_qwen2vl::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index fa656c84ea0..41b97fed956 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -147,7 +147,7 @@ llama_model_qwen3::graph::graph(const llama_model & model, const llm_graph_param res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index f276be61ba8..04ecc18fcdc 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -12,16 +12,22 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; case 64: type = LLM_TYPE_27B; break; @@ -29,9 +35,14 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -43,50 +54,85 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + // MTP block looks like a full-attention Qwen3.5 decoder block. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -111,7 +157,9 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -128,7 +176,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -160,6 +208,13 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); @@ -167,7 +222,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -297,8 +352,6 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -328,41 +381,14 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -413,7 +439,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -424,18 +450,7 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -471,3 +486,151 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, cons return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 dense series +llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35 MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35 MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // hparams.n_layer includes both main model layers and MTP layers. The MTP + // layer is stored immediately after the main layers in model.layers[]. + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat, layer.nextn.eh_proj_s); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + cur = build_ffn(cur, + layer.ffn_up, nullptr, layer.ffn_up_s, + layer.ffn_gate, nullptr, layer.ffn_gate_s, + layer.ffn_down, nullptr, layer.ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + ggml_tensor * head_s = layer.nextn.shared_head_head ? layer.nextn.shared_head_head_s : model.output_s; + GGML_ASSERT(head_w && "QWEN35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur, head_s); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index cf05dc9d61c..dc24f6ed537 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -15,16 +15,22 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - // Mark recurrent layers (linear attention layers) + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. { + const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer - hparams.nextn_predict_layers) { case 40: type = LLM_TYPE_35B_A3B; break; case 48: type = LLM_TYPE_122B_A10B; break; case 60: type = LLM_TYPE_397B_A17B; break; @@ -32,9 +38,14 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { } } -void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { +void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const uint32_t n_main = n_layer - hparams.nextn_predict_layers; + const bool mtp_only = (hparams.nextn_predict_layers > 0) && + (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); // output @@ -46,60 +57,105 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); } - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recurrent(il)) { // Attention layers - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); } else { // Linear attention (gated delta net) specific tensors // Create tensors with calculated dimensions - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); } - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, flags); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, flags); // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, flags); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < (int) n_main; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = (int) n_main; i < n_layer; ++i) { + load_block_mtp(i); } } std::unique_ptr llama_model_qwen35moe::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique(*this, params); + } return std::make_unique(*this, params); } @@ -124,7 +180,9 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -141,7 +199,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -173,6 +231,13 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p } cur = inpL; + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + // Final norm cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); @@ -180,7 +245,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -310,8 +375,6 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -341,41 +404,14 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -426,7 +462,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes - // note: need explicit repeat only if we are not using the fused GDN + // note: need explicit repeat only if we are not using the fused GDN. if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { GGML_ASSERT(num_v_heads % num_k_heads == 0); q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -437,18 +473,7 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); @@ -525,3 +550,183 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, c return cur; } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 MoE +llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE MTP requires nextn_predict_layers > 0"); + GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + auto inp = std::make_unique(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + auto * inp_attn = build_attn_inp_kv(); + + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat, layer.nextn.eh_proj_s); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + // MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe). + ggml_tensor * moe_out = + build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, layer.ffn_gate_up_exps, + layer.ffn_up_exps_s, + layer.ffn_gate_exps_s, + layer.ffn_down_exps_s); + cb(moe_out, "mtp_ffn_moe_out", il); + + if (layer.ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s, + layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s, + layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "mtp_ffn_shexp", il); + + ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur); + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il); + + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "mtp_ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + } + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_pre_norm", -1); + res->t_h_pre_norm = cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + ggml_tensor * head_s = layer.nextn.shared_head_head ? layer.nextn.shared_head_head_s : model.output_s; + GGML_ASSERT(head_w && "QWEN35MOE MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur, head_s); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index 4440b83aa45..a4f8e1379c9 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -168,7 +168,7 @@ llama_model_qwen3moe::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index cb1b4814caf..1d873427db5 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -176,7 +176,7 @@ llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -378,8 +378,6 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -429,41 +427,14 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); - // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_transpose(ctx0, qkv_mixed); - cb(qkv_mixed, "qkv_mixed_transposed", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); - - ggml_tensor * state_update_target = - ggml_view_2d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels, n_seqs, conv_states_all->nb[1], - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -540,18 +511,7 @@ ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - auto attn_out = build_delta_net(q_conv, k_conv, v_conv, gate, beta, state, il); - - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index 7871f8f7952..5defd893944 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -163,7 +163,7 @@ llama_model_qwen3vl::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3vlmoe.cpp b/examples/talk-llama/models/qwen3vlmoe.cpp index b99143c8908..5b77df57122 100644 --- a/examples/talk-llama/models/qwen3vlmoe.cpp +++ b/examples/talk-llama/models/qwen3vlmoe.cpp @@ -180,7 +180,7 @@ llama_model_qwen3vlmoe::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index f14f10917ff..bf3949a9092 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -150,7 +150,7 @@ llama_model_refact::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index 325ee73ba5c..ca8e009615e 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -167,7 +167,7 @@ llama_model_rnd1::graph::graph(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv6.cpp b/examples/talk-llama/models/rwkv6.cpp index 2944711acec..ba2a9dfa0db 100644 --- a/examples/talk-llama/models/rwkv6.cpp +++ b/examples/talk-llama/models/rwkv6.cpp @@ -176,7 +176,7 @@ llama_model_rwkv6::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv6qwen2.cpp b/examples/talk-llama/models/rwkv6qwen2.cpp index 6f7d1f5722f..566b8cdcb54 100644 --- a/examples/talk-llama/models/rwkv6qwen2.cpp +++ b/examples/talk-llama/models/rwkv6qwen2.cpp @@ -158,7 +158,7 @@ llama_model_rwkv6qwen2::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv7.cpp b/examples/talk-llama/models/rwkv7.cpp index b205e3935e1..7574b252621 100644 --- a/examples/talk-llama/models/rwkv7.cpp +++ b/examples/talk-llama/models/rwkv7.cpp @@ -202,7 +202,7 @@ llama_model_rwkv7::graph::graph(const llama_model & model, const llm_graph_param cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index 83e114740b6..806cba574be 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -141,7 +141,7 @@ llama_model_seed_oss::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index 3214e7cbad3..4231cccc666 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -178,7 +178,7 @@ llama_model_smallthinker::graph::graph(const llama_model & model, const ll res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index 7adaf34c534..90e7d473eaf 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -143,7 +143,7 @@ llama_model_smollm3::graph::graph(const llama_model & model, const llm_graph_par res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index 8f613e55947..4da7f7aefcf 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -163,7 +163,7 @@ llama_model_stablelm::graph::graph(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index 58cf0ac0edc..e131af058bc 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -135,7 +135,7 @@ llama_model_starcoder::graph::graph(const llama_model & model, const llm_graph_p cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index 45dae0602d4..9c207c02885 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -148,7 +148,7 @@ llama_model_starcoder2::graph::graph(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/step35.cpp b/examples/talk-llama/models/step35.cpp index c4789752d21..3b68e68707a 100644 --- a/examples/talk-llama/models/step35.cpp +++ b/examples/talk-llama/models/step35.cpp @@ -261,7 +261,7 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/t5.cpp b/examples/talk-llama/models/t5.cpp index 27a0711ba41..73e32741406 100644 --- a/examples/talk-llama/models/t5.cpp +++ b/examples/talk-llama/models/t5.cpp @@ -265,7 +265,7 @@ llama_model_t5::graph::graph(const llama_model & model, const llm_graph_p res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/wavtokenizer-dec.cpp b/examples/talk-llama/models/wavtokenizer-dec.cpp index a873e5d2e8f..214fed99bad 100644 --- a/examples/talk-llama/models/wavtokenizer-dec.cpp +++ b/examples/talk-llama/models/wavtokenizer-dec.cpp @@ -253,7 +253,7 @@ llama_model_wavtokenizer_dec::graph::graph(const llama_model & model, const llm_ LLM_NORM, -1); // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index e4d111e622a..d6d1c7a2e5d 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -126,7 +126,7 @@ llama_model_xverse::graph::graph(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index dc13e53f09f..b02ecdc930f 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -605,6 +605,136 @@ static std::vector unicode_regex_split_custom_qwen2(const std::string & return bpe_offsets; } +// Qwen3.5 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" +// Compared to Qwen2, letter-runs also consume Unicode combining marks (\p{M}): [\p{L}\p{M}]+ instead of \p{L}+ +static std::vector unicode_regex_split_custom_qwen35(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&] (const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive + if (cpt == '\'' && pos+1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos+2); + continue; + } + if (pos+2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos+3); + continue; + } + } + } + + // regex: [^\r\n\p{L}\p{N}]?[\p{L}\p{M}]+ + if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) { + if (flags.is_letter || flags.is_accent_mark || _get_flags(pos + 1).is_accent_mark || _get_flags(pos+1).is_letter) { + pos++; + while (_get_flags(pos).is_letter || _get_flags(pos).is_accent_mark) { + pos++; + } + _add_token(pos); + continue; + } + } + + // regex: \p{N} + if (flags.is_number) { + pos++; + _add_token(pos); + continue; + } + + // regex: ?[^\s\p{L}\p{M}\p{N}]+[\r\n]* + auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags); + if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_accent_mark | flags2.is_number) && flags.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_accent_mark | flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos+num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos+num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // regex: \s*[\r\n]+ + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // regex: \s+(?!\S) + if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // regex: \s+ + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // no matches + _add_token(++pos); + } + } + + return bpe_offsets; +} + template static std::vector unicode_regex_split_stl(const std::basic_string & text, const std::basic_string & regex, const std::vector & offsets) { using BidirIt = typename std::basic_string::const_iterator; @@ -929,6 +1059,9 @@ static std::vector unicode_regex_split_custom(const std::string & text, } else if ( regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { bpe_offsets = unicode_regex_split_custom_qwen2(text, offsets); + } else if ( + regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { + bpe_offsets = unicode_regex_split_custom_qwen35(text, offsets); } else if (regex_expr == "\\p{Han}+") { // K2's first pattern - handle all K2 patterns together bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets); From 44a50ca41a574d596eef4b2d0a4ffcc1575d6000 Mon Sep 17 00:00:00 2001 From: Kaihui-AMD Date: Mon, 25 May 2026 17:27:42 +0800 Subject: [PATCH 685/831] readme : add AMD ROCm/HIP GPU build instructions (#3823) Signed-off-by: Kaihui-AMD --- README.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/README.md b/README.md index 474a1301da7..050a35be21c 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp - [Vulkan support](#vulkan-gpu-support) - Support for CPU-only inference - [Efficient GPU support for NVIDIA](#nvidia-gpu-support) +- [AMD ROCm GPU support](#amd-rocm-gpu-support) - [OpenVINO Support](#openvino-support) - [Ascend NPU Support](#ascend-npu-support) - [Moore Threads GPU Support](#moore-threads-gpu-support) @@ -340,6 +341,27 @@ cmake -B build -DGGML_VULKAN=1 cmake --build build -j --config Release ``` +## AMD ROCm GPU support + +With AMD GPUs the processing can be accelerated via HIP/ROCm. +First, make sure you have installed [ROCm](https://rocm.docs.amd.com/en/latest/). + +Now build `whisper.cpp` with HIP support: + +``` +cmake -B build -DGGML_HIP=1 -DAMDGPU_TARGETS="gfx1201" +cmake --build build -j --config Release +``` + +Replace `gfx1201` with your GPU architecture. You can find it with: + +``` +rocminfo | grep "gfx" +``` + +Common architectures: `gfx1100` (RX 7900 XTX), `gfx1101` (RX 7800 XT), `gfx1201` (RX 9070 XT). +For multiple GPUs with different architectures: `-DAMDGPU_TARGETS="gfx1100;gfx1201"`. + ## BLAS CPU support via OpenBLAS Encoder processing can be accelerated on the CPU via OpenBLAS. From 2979e5f95fa95c4f30fd36eb5e4766448da9da44 Mon Sep 17 00:00:00 2001 From: Gilad S <7817232+giladgd@users.noreply.github.com> Date: Mon, 25 May 2026 11:33:29 +0200 Subject: [PATCH 686/831] ggml: `gguf_init_from_callback` and `gguf_init_from_buffer` (llama/22341) * ggml: implement `gguf_init_from_buffer` * test: `gguf_init_from_buffer` * fix: memory breakdown for a model loaded with `no_alloc` from a file is consistent with being loaded from a buffer * fix: use `GGML_UNUSED` Co-authored-by: Copilot * fix: remove `total_size` from `gguf_reader` * fix: file offset calculation, rename `offset` to `data_offset` Co-authored-by: Copilot * refactor: extract model loader bug fixes to another PR * feat: add `gguf_init_from_callback` * fix: always require a max expected size * fix: change `gguf_reader_callback_t`'s `output` type to `void *`, change `max_expected_size` and offsets to `uint64_t` * fix: harden against offset overflow in buffer read * fix: remove seek behavior from the callback * feat: `max_chunk_read == 0` means `SIZE_MAX` * fix: seeking in a gguf file with no tensors --------- Co-authored-by: Copilot --- ggml/include/gguf.h | 10 ++- ggml/src/gguf.cpp | 178 ++++++++++++++++++++++++++++++++++++++------ 2 files changed, 163 insertions(+), 25 deletions(-) diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h index 02d5f221c03..67851ba6f16 100644 --- a/ggml/include/gguf.h +++ b/ggml/include/gguf.h @@ -76,10 +76,16 @@ extern "C" { struct ggml_context ** ctx; }; + // callback to simulate or wrap a FILE pointer - read up to `len` bytes at `offset` into `output` and return the number of bytes read + typedef size_t (*gguf_reader_callback_t)(void * userdata, void * output, uint64_t offset, size_t len); + GGML_API struct gguf_context * gguf_init_empty(void); GGML_API struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params); GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); - //GGML_API struct gguf_context * gguf_init_from_buffer(..); + GGML_API struct gguf_context * gguf_init_from_buffer(const void * data, size_t size, struct gguf_init_params params); + + // max_chunk_read is the maximum number of bytes that the GGUF code will read at once from the callback, a value of 0 means no limit + GGML_API struct gguf_context * gguf_init_from_callback(gguf_reader_callback_t callback, void * userdata, size_t max_chunk_read, uint64_t max_expected_size, struct gguf_init_params params); GGML_API void gguf_free(struct gguf_context * ctx); @@ -87,7 +93,7 @@ extern "C" { GGML_API uint32_t gguf_get_version (const struct gguf_context * ctx); GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx); - GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); + GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); // padded to gguf_get_alignment if and only if the gguf_context contains at least one tensor GGML_API int64_t gguf_get_n_kv(const struct gguf_context * ctx); GGML_API int64_t gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index ab3cc974867..5e198618251 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -228,9 +228,18 @@ struct gguf_context { }; struct gguf_reader { - gguf_reader(FILE * file) : file(file) { - // read the remaining bytes once and update on each read - nbytes_remain = file_remain(file); + gguf_reader( + gguf_reader_callback_t callback, + void * userdata, + size_t max_chunk_read, + uint64_t data_offset = 0, + uint64_t nbytes_remain = 0) + : callback(callback), + userdata(userdata), + max_chunk_read(max_chunk_read), + data_offset(data_offset), + nbytes_remain(nbytes_remain) { + GGML_ASSERT(max_chunk_read > 0); } // helper for remaining bytes in a file @@ -257,12 +266,10 @@ struct gguf_reader { template bool read(T & dst) const { const size_t size = sizeof(dst); - if (nbytes_remain < size) { + if (size > nbytes_remain) { return false; } - const size_t nread = fread(&dst, 1, size, file); - nbytes_remain -= nread; - return nread == size; + return read_raw(&dst, size) == size; } template @@ -344,24 +351,71 @@ struct gguf_reader { return false; } dst.resize(static_cast(size)); - const size_t nread = fread(dst.data(), 1, size, file); - nbytes_remain -= nread; - return nread == size; + return read_raw(dst.data(), static_cast(size)) == size; } bool read(void * dst, const size_t size) const { if (size > nbytes_remain) { return false; } - const size_t nread = fread(dst, 1, size, file); - nbytes_remain -= nread; - return nread == size; + return read_raw(dst, size) == size; + } + + uint64_t tell() const { + return data_offset; + } + + bool seek(uint64_t absolute_offset) const { + const uint64_t end_offset = uint64_t(data_offset) + nbytes_remain; + if (absolute_offset > end_offset) { + return false; + } + + data_offset = absolute_offset; + nbytes_remain = end_offset - absolute_offset; + + return true; } private: - FILE * file; + size_t read_raw(void * dst, size_t size) const { + if (callback == nullptr || size == 0) { + return 0; + } + + uint8_t * data = static_cast(dst); + size_t total_nread = 0; + bool reached_eof = false; - mutable uint64_t nbytes_remain; + while (total_nread < size) { + const size_t chunk_size = std::min(max_chunk_read, size - total_nread); + if (data_offset + total_nread < data_offset) { + break; + } + const size_t nread = callback(userdata, static_cast(data + total_nread), data_offset + total_nread, chunk_size); + total_nread += nread; + if (nread != chunk_size) { + reached_eof = true; + break; + } + } + + data_offset += total_nread; + GGML_ASSERT(total_nread <= nbytes_remain); + nbytes_remain -= total_nread; + + if (reached_eof) { + nbytes_remain = 0; + } + + return total_nread; + } + + gguf_reader_callback_t callback = nullptr; + void * userdata = nullptr; + size_t max_chunk_read = 0; + mutable uint64_t data_offset = 0; + mutable uint64_t nbytes_remain = 0; }; struct gguf_context * gguf_init_empty(void) { @@ -394,12 +448,7 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vectorinfo.size()) == n_tensors); // we require the data section to be aligned, so take into account any padding - if (gguf_fseek(file, GGML_PAD(gguf_ftell(file), ctx->alignment), SEEK_SET) != 0) { + if (n_tensors > 0 && !gr.seek(GGML_PAD(gr.tell(), ctx->alignment))) { GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__); gguf_free(ctx); return nullptr; } // store the current file offset - this is where the data section starts - ctx->offset = gguf_ftell(file); + ctx->offset = gr.tell(); // compute the total size of the data section, taking into account the alignment { @@ -844,6 +893,89 @@ struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_para return ctx; } +struct gguf_context * gguf_init_from_callback(gguf_reader_callback_t callback, void * userdata, size_t max_chunk_read, uint64_t max_expected_size, struct gguf_init_params params) { + if (callback == nullptr) { + return nullptr; + } + + const struct gguf_reader gr(callback, userdata, max_chunk_read == 0 ? SIZE_MAX : max_chunk_read, 0, max_expected_size); + return gguf_init_from_reader(gr, params); +} + +struct gguf_file_reader { + FILE * file; + uint64_t offset; +}; + +static size_t gguf_file_reader_callback(void * userdata, void * output, uint64_t offset, size_t len) { + GGML_ASSERT(len > 0); + + gguf_file_reader & reader = *static_cast(userdata); + + if (reader.offset != offset) { + if (offset > INT64_MAX || gguf_fseek(reader.file, static_cast(offset), SEEK_SET) != 0) { + return 0; + } + + reader.offset = offset; + } + + const size_t nread = fread(static_cast(output), 1, len, reader.file); + reader.offset += nread; + return nread; +} + +struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params) { + if (!file) { + return nullptr; + } + + const int64_t cur = gguf_ftell(file); + if (cur < 0) { + return nullptr; + } + + gguf_file_reader reader = { + /*.file = */ file, + /*.offset = */ static_cast(cur), + }; + const struct gguf_reader gr(gguf_file_reader_callback, &reader, SIZE_MAX, reader.offset, gguf_reader::file_remain(file)); + return gguf_init_from_reader(gr, params); +} + +struct gguf_buffer_reader { + const uint8_t * data; + size_t size; +}; + +static size_t gguf_buffer_reader_callback(void * userdata, void * output, uint64_t offset, size_t len) { + GGML_ASSERT(len > 0); + + const gguf_buffer_reader & reader = *static_cast(userdata); + + if (offset > reader.size || len > reader.size - offset) { + return 0; + } + + const size_t data_offset = static_cast(offset); + const size_t nread = std::min(len, reader.size - data_offset); + memcpy(static_cast(output), reader.data + data_offset, nread); + return nread; +} + +struct gguf_context * gguf_init_from_buffer(const void * data, size_t size, struct gguf_init_params params) { + if (data == nullptr || size == 0) { + return nullptr; + } + + gguf_buffer_reader reader = { + /*.data = */ static_cast(data), + /*.size = */ size, + }; + const struct gguf_reader gr(gguf_buffer_reader_callback, &reader, SIZE_MAX, 0, size); + return gguf_init_from_reader(gr, params); +} + struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { FILE * file = ggml_fopen(fname, "rb"); From bcff51515008baef985289b013e2aca876cdbff5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 25 May 2026 11:37:25 +0200 Subject: [PATCH 687/831] TP: fix ggml context size calculation (llama/22616) * TP: fix ggml context size calculation, memory leak * move split state cache back into the context * revert to constant ggml context size for cgraphs * increase headroom for statically allocated tensors * remove obsolete include --- ggml/src/ggml-backend-meta.cpp | 194 +++++++++++++++++++++++---------- 1 file changed, 137 insertions(+), 57 deletions(-) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 5f9ae9c1bc5..d0d64523b4a 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -392,64 +393,100 @@ static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type( // meta backend buffer // +// Container to hold the tensor slices per simple ggml backend buffer. +struct ggml_backend_meta_simple_tensor_container { + std::vector ctxs; + std::map> simple_tensors; + + ggml_backend_meta_simple_tensor_container(const ggml_init_params & params, const int n_simple) { + ctxs.reserve(n_simple); + for (int i = 0; i < n_simple; i++) { + ctxs.emplace_back(ggml_init(params)); + } + } + ggml_backend_meta_simple_tensor_container() {} +}; + struct ggml_backend_meta_buffer_context { + // FIXME + // Most tensors can simply be stored statically in their own buffer. + // Externally created views however also need a mapping to simple tensors but they use the buffer of the view source. + // If external views are simply using that buffer they will slowly deplete its memory. + // Current solution: rotating set of 2 "compute" containers to hold external views, works correctly for llama.cpp. + // Long-term: tie the lifetime of external views to the meta backend executing the graph instead, + // currently not possible due to graph-external operations in the backend scheduler. + ggml_backend_meta_simple_tensor_container stc_static; + ggml_backend_meta_simple_tensor_container stc_compute[2]; + int stc_compute_index = 0; + int stc_compute_index_next = 0; + std::vector bufs; + + // FIXME + // The size of the split state cache is unbounded and can theoretically grow infinitely large. + // However, it is also expensive to build and clearing it on every rebuild in ggml_backend_meta_graph_compute is too expensive. static constexpr size_t nbtc = GGML_TENSOR_SIZE - sizeof(ggml_tensor::padding); - std::map, std::pair> split_state_cache; - std::map< const ggml_tensor *, std::vector> simple_tensors; - - struct buffer_config { - ggml_context * ctx; - ggml_backend_buffer_t buf; - - buffer_config(ggml_context * ctx, ggml_backend_buffer_t buf) : ctx(ctx), buf(buf) {} - }; - std::vector buf_configs; int debug; - ggml_backend_meta_buffer_context() { + ggml_backend_meta_buffer_context( + ggml_backend_meta_simple_tensor_container & stc_static, + ggml_backend_meta_simple_tensor_container & stc_compute_0, + ggml_backend_meta_simple_tensor_container & stc_compute_1, + const std::vector & bufs) + : stc_static(std::move(stc_static)), stc_compute{std::move(stc_compute_0), std::move(stc_compute_1)} { + this->bufs.reserve(bufs.size()); + for (ggml_backend_buffer_t buf : bufs) { + this->bufs.emplace_back(buf); + } const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG"); debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0; } + + ggml_backend_meta_simple_tensor_container & get_simple_tensor_container(const ggml_tensor * tensor) { + if (stc_static.simple_tensors.find(tensor) != stc_static.simple_tensors.end()) { + return stc_static; + } + return stc_compute[stc_compute_index]; + } }; static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) { GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; - for (auto & [ctx, buf] : buf_ctx->buf_configs) { - ggml_backend_buffer_free(buf); - ggml_free(ctx); - } delete buf_ctx; } static size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf) { GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; - return buf_ctx->buf_configs.size(); + return buf_ctx->bufs.size(); } static ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index) { GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; - GGML_ASSERT(index < buf_ctx->buf_configs.size()); - return buf_ctx->buf_configs[index].buf; + GGML_ASSERT(index < buf_ctx->bufs.size()); + return buf_ctx->bufs[index].get(); } static struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index) { GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; - GGML_ASSERT(index < buf_ctx->buf_configs.size()); + GGML_ASSERT(index < buf_ctx->bufs.size()); - auto it = buf_ctx->simple_tensors.find(tensor); - if (it == buf_ctx->simple_tensors.end()) { + ggml_backend_meta_simple_tensor_container & stc = buf_ctx->get_simple_tensor_container(tensor); + auto it = stc.simple_tensors.find(tensor); + if (it == stc.simple_tensors.end()) { return nullptr; } return it->second[index]; } -static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) { +static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync); + +static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( + ggml_backend_meta_simple_tensor_container & stc, const struct ggml_tensor * tensor, bool assume_sync) { const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; @@ -785,7 +822,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; continue; } - src_ss[i] = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true); + src_ss[i] = ggml_backend_meta_get_split_state(stc, tensor->src[i], /*assume_sync =*/ true); GGML_ASSERT(src_ss[i].axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); } @@ -1079,17 +1116,23 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co return ret; } +static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) { + GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + return ggml_backend_meta_get_split_state(buf_ctx->get_simple_tensor_container(tensor), tensor, assume_sync); +} + static void * ggml_backend_meta_buffer_get_base(ggml_backend_buffer_t buffer) { GGML_UNUSED(buffer); return (void *) 0x1000000000000000; // FIXME } -static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { - GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); - ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; - const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(buffer); +static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_meta_simple_tensor_container & stc, ggml_tensor * tensor) { + GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); - const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ true); + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(stc, tensor, /*assume_sync =*/ true); GGML_ASSERT(ggml_nelements(tensor) == 0 || split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); GGML_ASSERT(split_state.n_segments <= 16); @@ -1104,8 +1147,8 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer std::vector simple_tensors; simple_tensors.reserve(n_simple_bufs); for (size_t j = 0; j < n_simple_bufs; j++) { - ggml_context * simple_ctx = buf_ctx->buf_configs[j].ctx; - ggml_backend_buffer_t simple_buf = buf_ctx->buf_configs[j].buf; + ggml_context * simple_ctx = stc.ctxs[j].get(); + ggml_backend_buffer_t simple_buf = buf_ctx->bufs[j].get(); if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) { // TODO: the following assert fails for llama-parallel even though the results are correct: @@ -1158,7 +1201,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer t_ij->data = (char *) t_ij->view_src->data + t_ij->view_offs; } else if (simple_buf != nullptr) { t_ij->data = (char *) ggml_backend_buffer_get_base(simple_buf) - + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(buffer)); + + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(tensor->buffer)); } t_ij->extra = tensor->extra; for (int i = 0; i < GGML_MAX_SRC; i++) { @@ -1194,11 +1237,18 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer } } - buf_ctx->simple_tensors[tensor] = simple_tensors; + stc.simple_tensors[tensor] = simple_tensors; return GGML_STATUS_SUCCESS; } +static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + buf_ctx->stc_compute_index = buf_ctx->stc_compute_index_next; + return ggml_backend_meta_buffer_init_tensor_impl(buf_ctx->get_simple_tensor_container(tensor), tensor); +} + static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer); GGML_ASSERT(ggml_is_contiguous(tensor)); @@ -1413,8 +1463,9 @@ static void ggml_backend_meta_buffer_clear(ggml_backend_buffer_t buffer, uint8_t } static void ggml_backend_meta_buffer_reset(ggml_backend_buffer_t buffer) { - const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer); - for (size_t i = 0; i < n_buffers; i++) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + for (size_t i = 0; i < buf_ctx->bufs.size(); i++) { ggml_backend_buffer_reset(ggml_backend_meta_buffer_simple_buffer(buffer, i)); } } @@ -1440,21 +1491,24 @@ bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf) { static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); - ggml_init_params params = { - /*.mem_size =*/ 1024*1024*1024, // FIXME + const ggml_init_params params = { + /*.mem_size =*/ 1024*1024*ggml_tensor_overhead(), // FIXME /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; + ggml_backend_meta_simple_tensor_container stc_static; + ggml_backend_meta_simple_tensor_container stc_compute_0(params, n_simple_bufts); + ggml_backend_meta_simple_tensor_container stc_compute_1(params, n_simple_bufts); - ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context(); size_t max_size = 0; - buf_ctx->buf_configs.reserve(n_simple_bufts); + std::vector bufs; + bufs.reserve(n_simple_bufts); for (size_t i = 0; i < n_simple_bufts; i++) { - ggml_backend_buffer_t simple_buf = ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size); - GGML_ASSERT(simple_buf != nullptr); - max_size = std::max(max_size, ggml_backend_buffer_get_size(simple_buf)); - buf_ctx->buf_configs.emplace_back(ggml_init(params), simple_buf); + bufs.push_back(ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size)); + GGML_ASSERT(bufs.back() != nullptr); + max_size = std::max(max_size, ggml_backend_buffer_get_size(bufs.back())); } + ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs); return ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, buf_ctx, max_size); } @@ -1462,26 +1516,32 @@ static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_bac struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); - ggml_init_params params = { - /*.mem_size =*/ 1024*1024*1024, // FIXME + constexpr size_t compute_headroom = 16; // Maximum number of views per statically allocated tensor that can be created between evals. + const ggml_init_params params_static = { + /*.mem_size =*/ ggml_get_mem_size(ctx), /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; + const ggml_init_params params_compute = { + /*.mem_size =*/ compute_headroom*ggml_get_mem_size(ctx), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ggml_backend_meta_simple_tensor_container stc_static (params_static, n_simple_bufts); + ggml_backend_meta_simple_tensor_container stc_compute_0(params_compute, n_simple_bufts); + ggml_backend_meta_simple_tensor_container stc_compute_1(params_compute, n_simple_bufts); - ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(); - meta_buf_ctx->buf_configs.reserve(n_simple_bufts); - for (size_t i = 0; i < n_simple_bufts; i++) { - meta_buf_ctx->buf_configs.emplace_back(ggml_init(params), nullptr); - } + std::vector bufs(n_simple_bufts, nullptr); + ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs); ggml_backend_buffer_t meta_buf = ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_buf_ctx, 0); for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { t->buffer = meta_buf; - ggml_backend_meta_buffer_init_tensor(meta_buf, t); + ggml_backend_meta_buffer_init_tensor_impl(meta_buf_ctx->stc_static, t); t->data = (void *) 0x2000000000000000; // FIXME } for (size_t i = 0; i < n_simple_bufts; i++) { - ggml_context * ctx = meta_buf_ctx->buf_configs[i].ctx; + ggml_context * ctx = meta_buf_ctx->stc_static.ctxs[i].get(); ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, i); // If a ggml_context only has zero-sized tensors, ggml_backend_alloc_ctx_tensors_from_buft returns NULL. @@ -1494,15 +1554,15 @@ struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struc } } if (any_nonzero_slice) { - meta_buf_ctx->buf_configs[i].buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft); + meta_buf_ctx->bufs[i].reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft)); } else { - meta_buf_ctx->buf_configs[i].buf = ggml_backend_buft_alloc_buffer(simple_buft, 0); + meta_buf_ctx->bufs[i].reset(ggml_backend_buft_alloc_buffer(simple_buft, 0)); for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { - t->buffer = meta_buf_ctx->buf_configs[i].buf; + t->buffer = meta_buf_ctx->bufs[i].get(); } } - GGML_ASSERT(meta_buf_ctx->buf_configs[i].buf != nullptr); - meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->buf_configs[i].buf)); + GGML_ASSERT(meta_buf_ctx->bufs[i]); + meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->bufs[i].get())); } return meta_buf; } @@ -1724,6 +1784,26 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, } if (needs_rebuild) { + std::set used_buffers; + for (int i = 0; i < cgraph->n_leafs; i++) { + if (ggml_backend_buffer_is_meta(cgraph->leafs[i]->buffer)) { + used_buffers.emplace(cgraph->leafs[i]->buffer); + } + } + for (int i = 0; i < cgraph->n_nodes; i++) { + if (ggml_backend_buffer_is_meta(cgraph->nodes[i]->buffer)) { + used_buffers.emplace(cgraph->nodes[i]->buffer); + } + } + for (ggml_backend_buffer_t buf : used_buffers) { + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buf->context; + buf_ctx->stc_compute_index_next = buf_ctx->stc_compute_index ^ 1; + ggml_backend_meta_simple_tensor_container & stc = buf_ctx->stc_compute[buf_ctx->stc_compute_index_next]; + for (ggml_context_ptr & ctx : stc.ctxs) { + ggml_reset(ctx.get()); + } + stc.simple_tensors.clear(); + } size_t n_subgraphs = 0; size_t max_tmp_size = 0; @@ -1909,7 +1989,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); - ggml_init_params params = { + const ggml_init_params params = { /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, From 1cf8e3a9039e2b9cdcb98f1c6b09359607b711de Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 12:40:17 +0300 Subject: [PATCH 688/831] ggml : bump version to 0.13.0 (ggml/1510) --- ggml/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 03020888f97..f542f18b6d4 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,8 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 12) -set(GGML_VERSION_PATCH 1) +set(GGML_VERSION_MINOR 13) +set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From f14ae77f4082afaf10a98c17dde8282e12744d7c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 12:44:07 +0300 Subject: [PATCH 689/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 2c680ce9f5d..a4f87b2b9ae 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -0a37c2167fc5b81830a32d9b1691610180ed86d6 +e705c5fed490514458bdd2eaddc43bd098fcce9b From c245b3ec23239d359ce18f3be3ee0ae92525074e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 13:05:30 +0300 Subject: [PATCH 690/831] benches : update --- scripts/bench-all-gg.txt | 237 ++++++++++++++++++++++----------------- 1 file changed, 137 insertions(+), 100 deletions(-) diff --git a/scripts/bench-all-gg.txt b/scripts/bench-all-gg.txt index 220bd4c98b8..1b65fc7d778 100644 --- a/scripts/bench-all-gg.txt +++ b/scripts/bench-all-gg.txt @@ -111,61 +111,61 @@ make -j && ./scripts/bench-all.sh 1 1 0 | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M2 ULTRA | METAL | tiny | 1 | 0 | 8.57 | 1.12 | 0.27 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | tiny-q5_0 | 1 | 0 | 9.17 | 1.10 | 0.28 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | tiny-q5_1 | 1 | 0 | 9.16 | 1.09 | 0.28 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | tiny-q8_0 | 1 | 0 | 8.81 | 1.12 | 0.27 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | base | 1 | 0 | 15.60 | 1.61 | 0.41 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | base-q5_0 | 1 | 0 | 16.75 | 1.54 | 0.42 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | base-q5_1 | 1 | 0 | 16.64 | 1.54 | 0.43 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | base-q8_0 | 1 | 0 | 16.09 | 1.55 | 0.41 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | small | 1 | 0 | 46.74 | 3.13 | 0.89 | 0.05 | f5b477ab | -| M2 ULTRA | METAL | small-q5_0 | 1 | 0 | 51.57 | 3.03 | 0.91 | 0.06 | f5b477ab | -| M2 ULTRA | METAL | small-q5_1 | 1 | 0 | 51.85 | 3.03 | 0.92 | 0.06 | f5b477ab | -| M2 ULTRA | METAL | small-q8_0 | 1 | 0 | 48.34 | 3.01 | 0.89 | 0.06 | f5b477ab | -| M2 ULTRA | METAL | medium | 1 | 0 | 125.82 | 6.46 | 2.01 | 0.12 | f5b477ab | -| M2 ULTRA | METAL | medium-q5_0 | 1 | 0 | 143.44 | 5.97 | 2.07 | 0.14 | f5b477ab | -| M2 ULTRA | METAL | medium-q5_1 | 1 | 0 | 143.41 | 5.97 | 2.09 | 0.14 | f5b477ab | -| M2 ULTRA | METAL | medium-q8_0 | 1 | 0 | 131.23 | 6.30 | 2.01 | 0.13 | f5b477ab | -| M2 ULTRA | METAL | medium-dis | 1 | 0 | 114.07 | 0.90 | 0.25 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | large-v2 | 1 | 0 | 240.73 | 9.46 | 3.21 | 0.21 | f5b477ab | -| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 0 | 276.56 | 8.62 | 3.16 | 0.25 | f5b477ab | -| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 0 | 275.90 | 8.98 | 3.16 | 0.25 | f5b477ab | -| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 0 | 251.00 | 9.10 | 3.02 | 0.22 | f5b477ab | -| M2 ULTRA | METAL | large-v2-dis | 1 | 0 | 217.43 | 1.01 | 0.28 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | large-v3-turbo | 1 | 0 | 218.39 | 1.55 | 0.47 | 0.03 | f5b477ab | -| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 0 | 249.41 | 1.39 | 0.47 | 0.04 | f5b477ab | -| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 0 | 227.54 | 1.43 | 0.45 | 0.03 | f5b477ab | +| M2 ULTRA | METAL | tiny | 1 | 0 | 8.10 | 1.03 | 0.25 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q5_0 | 1 | 0 | 8.53 | 1.02 | 0.26 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q5_1 | 1 | 0 | 8.67 | 1.00 | 0.26 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q8_0 | 1 | 0 | 9.32 | 1.02 | 0.26 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | base | 1 | 0 | 15.50 | 1.51 | 0.40 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q5_0 | 1 | 0 | 16.63 | 1.45 | 0.40 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q5_1 | 1 | 0 | 16.76 | 1.44 | 0.39 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q8_0 | 1 | 0 | 15.73 | 1.43 | 0.38 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | small | 1 | 0 | 45.43 | 2.93 | 0.83 | 0.05 | f14ae77f | +| M2 ULTRA | METAL | small-q5_0 | 1 | 0 | 49.78 | 2.85 | 0.84 | 0.06 | f14ae77f | +| M2 ULTRA | METAL | small-q5_1 | 1 | 0 | 50.22 | 2.85 | 0.84 | 0.06 | f14ae77f | +| M2 ULTRA | METAL | small-q8_0 | 1 | 0 | 47.08 | 2.78 | 0.83 | 0.05 | f14ae77f | +| M2 ULTRA | METAL | medium | 1 | 0 | 125.19 | 6.10 | 1.88 | 0.12 | f14ae77f | +| M2 ULTRA | METAL | medium-q5_0 | 1 | 0 | 142.49 | 5.59 | 1.90 | 0.14 | f14ae77f | +| M2 ULTRA | METAL | medium-q5_1 | 1 | 0 | 142.63 | 5.68 | 1.92 | 0.14 | f14ae77f | +| M2 ULTRA | METAL | medium-q8_0 | 1 | 0 | 130.98 | 5.83 | 1.87 | 0.13 | f14ae77f | +| M2 ULTRA | METAL | medium-dis | 1 | 0 | 113.95 | 0.88 | 0.24 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | large-v2 | 1 | 0 | 239.27 | 8.97 | 2.92 | 0.21 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 0 | 275.07 | 8.56 | 2.92 | 0.24 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 0 | 274.28 | 8.62 | 2.93 | 0.24 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 0 | 248.90 | 8.32 | 2.81 | 0.22 | f14ae77f | +| M2 ULTRA | METAL | large-v2-dis | 1 | 0 | 214.26 | 0.97 | 0.27 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo | 1 | 0 | 222.47 | 1.49 | 0.45 | 0.03 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 0 | 250.56 | 1.35 | 0.45 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 0 | 228.57 | 1.33 | 0.43 | 0.03 | f14ae77f | make -j && ./scripts/bench-all.sh 1 1 1 | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M2 ULTRA | METAL | tiny | 1 | 1 | 6.06 | 0.96 | 0.22 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | tiny-q5_0 | 1 | 1 | 6.51 | 0.93 | 0.22 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | tiny-q5_1 | 1 | 1 | 6.47 | 0.93 | 0.23 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | tiny-q8_0 | 1 | 1 | 6.16 | 0.94 | 0.21 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | base | 1 | 1 | 10.63 | 1.37 | 0.32 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | base-q5_0 | 1 | 1 | 11.75 | 1.27 | 0.33 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | base-q5_1 | 1 | 1 | 11.73 | 1.25 | 0.33 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | base-q8_0 | 1 | 1 | 11.17 | 1.28 | 0.32 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | small | 1 | 1 | 31.74 | 2.55 | 0.67 | 0.04 | f5b477ab | -| M2 ULTRA | METAL | small-q5_0 | 1 | 1 | 36.21 | 2.47 | 0.69 | 0.04 | f5b477ab | -| M2 ULTRA | METAL | small-q5_1 | 1 | 1 | 36.22 | 2.47 | 0.70 | 0.04 | f5b477ab | -| M2 ULTRA | METAL | small-q8_0 | 1 | 1 | 32.73 | 2.45 | 0.66 | 0.04 | f5b477ab | -| M2 ULTRA | METAL | medium | 1 | 1 | 86.94 | 5.21 | 1.49 | 0.09 | f5b477ab | -| M2 ULTRA | METAL | medium-q5_0 | 1 | 1 | 104.31 | 4.93 | 1.51 | 0.10 | f5b477ab | -| M2 ULTRA | METAL | medium-q5_1 | 1 | 1 | 104.09 | 4.98 | 1.51 | 0.10 | f5b477ab | -| M2 ULTRA | METAL | medium-q8_0 | 1 | 1 | 92.13 | 5.06 | 1.45 | 0.09 | f5b477ab | -| M2 ULTRA | METAL | medium-dis | 1 | 1 | 76.67 | 0.81 | 0.20 | 0.01 | f5b477ab | -| M2 ULTRA | METAL | large-v2 | 1 | 1 | 167.66 | 7.56 | 2.25 | 0.16 | f5b477ab | -| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 1 | 203.09 | 7.13 | 2.29 | 0.20 | f5b477ab | -| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 1 | 202.53 | 7.12 | 2.29 | 0.20 | f5b477ab | -| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 1 | 177.48 | 6.94 | 2.18 | 0.17 | f5b477ab | -| M2 ULTRA | METAL | large-v2-dis | 1 | 1 | 145.61 | 0.91 | 0.23 | 0.02 | f5b477ab | -| M2 ULTRA | METAL | large-v3-turbo | 1 | 1 | 146.95 | 1.33 | 0.36 | 0.03 | f5b477ab | -| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 1 | 178.57 | 1.17 | 0.36 | 0.03 | f5b477ab | -| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 1 | 156.19 | 1.21 | 0.34 | 0.03 | f5b477ab | +| M2 ULTRA | METAL | tiny | 1 | 1 | 6.03 | 0.86 | 0.20 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q5_0 | 1 | 1 | 6.46 | 0.84 | 0.21 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q5_1 | 1 | 1 | 6.46 | 0.85 | 0.21 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q8_0 | 1 | 1 | 6.14 | 0.88 | 0.20 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | base | 1 | 1 | 10.87 | 1.24 | 0.31 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | base-q5_0 | 1 | 1 | 11.98 | 1.18 | 0.31 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q5_1 | 1 | 1 | 12.07 | 1.18 | 0.31 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q8_0 | 1 | 1 | 11.13 | 1.19 | 0.30 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | small | 1 | 1 | 31.46 | 2.37 | 0.63 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | small-q5_0 | 1 | 1 | 36.16 | 2.31 | 0.65 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | small-q5_1 | 1 | 1 | 36.57 | 2.31 | 0.65 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | small-q8_0 | 1 | 1 | 32.94 | 2.27 | 0.63 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | medium | 1 | 1 | 89.86 | 4.92 | 1.41 | 0.09 | f14ae77f | +| M2 ULTRA | METAL | medium-q5_0 | 1 | 1 | 107.12 | 4.72 | 1.42 | 0.10 | f14ae77f | +| M2 ULTRA | METAL | medium-q5_1 | 1 | 1 | 107.00 | 4.70 | 1.42 | 0.10 | f14ae77f | +| M2 ULTRA | METAL | medium-q8_0 | 1 | 1 | 94.93 | 4.56 | 1.37 | 0.09 | f14ae77f | +| M2 ULTRA | METAL | medium-dis | 1 | 1 | 79.66 | 0.78 | 0.20 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | large-v2 | 1 | 1 | 170.06 | 7.13 | 2.15 | 0.16 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 1 | 205.16 | 6.80 | 2.18 | 0.20 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 1 | 204.22 | 6.69 | 2.16 | 0.20 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 1 | 179.78 | 6.35 | 2.13 | 0.18 | f14ae77f | +| M2 ULTRA | METAL | large-v2-dis | 1 | 1 | 148.11 | 0.89 | 0.22 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo | 1 | 1 | 149.23 | 1.29 | 0.34 | 0.03 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 1 | 180.77 | 1.13 | 0.35 | 0.03 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 1 | 158.66 | 1.10 | 0.33 | 0.03 | f14ae77f | ## M4 Max @@ -233,20 +233,6 @@ make -j && ./scripts/bench-all.sh 1 1 0 make -j && ./scripts/bench-all.sh 1 1 1 -| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M4 Max | METAL | tiny | 1 | 1 | 8.23 | 0.71 | 0.16 | 0.01 | 47fcd7da | -| M4 Max | METAL | tiny-q8_0 | 1 | 1 | 8.47 | 0.67 | 0.16 | 0.01 | 47fcd7da | -| M4 Max | METAL | base | 1 | 1 | 15.47 | 1.12 | 0.26 | 0.02 | 47fcd7da | -| M4 Max | METAL | base-q8_0 | 1 | 1 | 15.70 | 1.05 | 0.27 | 0.02 | 47fcd7da | -| M4 Max | METAL | small | 1 | 1 | 49.82 | 2.37 | 0.53 | 0.05 | 47fcd7da | -| M4 Max | METAL | small-q8_0 | 1 | 1 | 51.76 | 1.99 | 0.53 | 0.05 | 47fcd7da | -| M4 Max | METAL | medium | 1 | 1 | 147.76 | 5.52 | 1.27 | 0.12 | 47fcd7da | -| M4 Max | METAL | medium-q8_0 | 1 | 1 | 153.98 | 4.59 | 1.24 | 0.13 | 47fcd7da | -| M4 Max | METAL | large-v2 | 1 | 1 | 282.89 | 9.06 | 2.11 | 0.22 | 47fcd7da | -| M4 Max | METAL | large-v2-q8_0 | 1 | 1 | 296.43 | 7.44 | 2.09 | 0.23 | 47fcd7da | -| M4 Max | METAL | large-v3-turbo | 1 | 1 | 249.91 | 1.65 | 0.38 | 0.04 | 47fcd7da | - | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | | M4 Max | METAL | tiny | 1 | 1 | 8.23 | 0.72 | 0.16 | 0.01 | 47af2fb7 | @@ -262,41 +248,77 @@ make -j && ./scripts/bench-all.sh 1 1 1 | M4 Max | METAL | large-v3-turbo | 1 | 1 | 256.23 | 1.61 | 0.38 | 0.04 | 47af2fb7 | +## M5 Max + +make -j && ./scripts/bench-all.sh 1 1 0 + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| M5 Max | METAL | tiny | 1 | 0 | 4.88 | 0.65 | 0.17 | 0.01 | f14ae77f | +| M5 Max | METAL | tiny-q8_0 | 1 | 0 | 4.84 | 0.63 | 0.17 | 0.01 | f14ae77f | +| M5 Max | METAL | base | 1 | 0 | 8.95 | 1.02 | 0.24 | 0.01 | f14ae77f | +| M5 Max | METAL | base-q8_0 | 1 | 0 | 9.12 | 0.94 | 0.24 | 0.01 | f14ae77f | +| M5 Max | METAL | small | 1 | 0 | 25.61 | 2.15 | 0.52 | 0.03 | f14ae77f | +| M5 Max | METAL | small-q8_0 | 1 | 0 | 25.77 | 1.93 | 0.50 | 0.03 | f14ae77f | +| M5 Max | METAL | medium | 1 | 0 | 73.96 | 4.61 | 1.16 | 0.08 | f14ae77f | +| M5 Max | METAL | medium-q8_0 | 1 | 0 | 74.89 | 3.94 | 1.12 | 0.08 | f14ae77f | +| M5 Max | METAL | large-v2 | 1 | 0 | 132.06 | 6.91 | 1.86 | 0.13 | f14ae77f | +| M5 Max | METAL | large-v2-q8_0 | 1 | 0 | 132.56 | 6.00 | 1.76 | 0.13 | f14ae77f | +| M5 Max | METAL | large-v3-turbo | 1 | 0 | 119.34 | 1.30 | 0.32 | 0.02 | f14ae77f | + + +make -j && ./scripts/bench-all.sh 1 1 1 + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| M5 Max | METAL | tiny | 1 | 1 | 4.31 | 0.59 | 0.13 | 0.01 | f14ae77f | +| M5 Max | METAL | tiny-q8_0 | 1 | 1 | 4.51 | 0.55 | 0.12 | 0.01 | f14ae77f | +| M5 Max | METAL | base | 1 | 1 | 7.77 | 0.91 | 0.20 | 0.01 | f14ae77f | +| M5 Max | METAL | base-q8_0 | 1 | 1 | 7.67 | 0.78 | 0.19 | 0.01 | f14ae77f | +| M5 Max | METAL | small | 1 | 1 | 20.90 | 1.76 | 0.40 | 0.03 | f14ae77f | +| M5 Max | METAL | small-q8_0 | 1 | 1 | 21.32 | 1.62 | 0.38 | 0.03 | f14ae77f | +| M5 Max | METAL | medium | 1 | 1 | 60.40 | 3.98 | 0.89 | 0.07 | f14ae77f | +| M5 Max | METAL | medium-q8_0 | 1 | 1 | 60.72 | 3.35 | 0.86 | 0.07 | f14ae77f | +| M5 Max | METAL | large-v2 | 1 | 1 | 110.57 | 6.06 | 1.41 | 0.12 | f14ae77f | +| M5 Max | METAL | large-v2-q8_0 | 1 | 1 | 110.92 | 5.00 | 1.31 | 0.12 | f14ae77f | +| M5 Max | METAL | large-v3-turbo | 1 | 1 | 98.36 | 1.19 | 0.27 | 0.02 | f14ae77f | + + # RTX 5090 make -j && ./scripts/bench-all.sh 1 1 0 | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| RTX 5090 | CUDA | tiny | 1 | 0 | 2.20 | 0.51 | 0.13 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | tiny-q8_0 | 1 | 0 | 2.35 | 0.52 | 0.14 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | base | 1 | 0 | 3.97 | 0.77 | 0.20 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | base-q8_0 | 1 | 0 | 4.20 | 0.73 | 0.20 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | small | 1 | 0 | 11.87 | 1.48 | 0.40 | 0.02 | f5b477ab | -| RTX 5090 | CUDA | small-q8_0 | 1 | 0 | 12.40 | 1.59 | 0.42 | 0.02 | f5b477ab | -| RTX 5090 | CUDA | medium | 1 | 0 | 32.63 | 3.11 | 0.82 | 0.04 | f5b477ab | -| RTX 5090 | CUDA | medium-q8_0 | 1 | 0 | 31.80 | 3.23 | 0.84 | 0.05 | f5b477ab | -| RTX 5090 | CUDA | large-v2 | 1 | 0 | 52.22 | 4.66 | 1.18 | 0.06 | f5b477ab | -| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 0 | 51.11 | 4.37 | 1.15 | 0.07 | f5b477ab | -| RTX 5090 | CUDA | large-v3-turbo | 1 | 0 | 48.72 | 0.70 | 0.18 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 47.81 | 0.64 | 0.16 | 0.01 | f5b477ab | +| RTX 5090 | CUDA | tiny | 1 | 0 | 2.17 | 0.38 | 0.10 | 0.00 | f14ae77f | +| RTX 5090 | CUDA | tiny-q8_0 | 1 | 0 | 2.31 | 0.37 | 0.10 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | base | 1 | 0 | 3.94 | 0.56 | 0.17 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | base-q8_0 | 1 | 0 | 4.13 | 0.53 | 0.14 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | small | 1 | 0 | 12.06 | 1.09 | 0.34 | 0.02 | f14ae77f | +| RTX 5090 | CUDA | small-q8_0 | 1 | 0 | 12.50 | 1.11 | 0.30 | 0.02 | f14ae77f | +| RTX 5090 | CUDA | medium | 1 | 0 | 33.08 | 2.38 | 0.70 | 0.04 | f14ae77f | +| RTX 5090 | CUDA | medium-q8_0 | 1 | 0 | 32.57 | 2.26 | 0.62 | 0.04 | f14ae77f | +| RTX 5090 | CUDA | large-v2 | 1 | 0 | 54.27 | 3.68 | 1.03 | 0.06 | f14ae77f | +| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 0 | 53.11 | 3.22 | 0.89 | 0.06 | f14ae77f | +| RTX 5090 | CUDA | large-v3-turbo | 1 | 0 | 50.56 | 0.58 | 0.15 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 49.39 | 0.49 | 0.13 | 0.01 | f14ae77f | make -j && ./scripts/bench-all.sh 1 1 1 | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| RTX 5090 | CUDA | tiny | 1 | 1 | 1.37 | 0.44 | 0.11 | 0.00 | f5b477ab | -| RTX 5090 | CUDA | tiny-q8_0 | 1 | 1 | 1.48 | 0.44 | 0.12 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | base | 1 | 1 | 2.34 | 0.66 | 0.16 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | base-q8_0 | 1 | 1 | 2.51 | 0.62 | 0.17 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | small | 1 | 1 | 5.53 | 1.23 | 0.32 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | small-q8_0 | 1 | 1 | 5.88 | 1.35 | 0.33 | 0.02 | f5b477ab | -| RTX 5090 | CUDA | medium | 1 | 1 | 15.09 | 2.55 | 0.65 | 0.03 | f5b477ab | -| RTX 5090 | CUDA | medium-q8_0 | 1 | 1 | 14.06 | 2.72 | 0.67 | 0.03 | f5b477ab | -| RTX 5090 | CUDA | large-v2 | 1 | 1 | 23.24 | 3.94 | 0.97 | 0.04 | f5b477ab | -| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 1 | 22.00 | 3.68 | 0.93 | 0.05 | f5b477ab | -| RTX 5090 | CUDA | large-v3-turbo | 1 | 1 | 19.81 | 0.62 | 0.15 | 0.01 | f5b477ab | -| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 18.62 | 0.56 | 0.14 | 0.01 | f5b477ab | +| RTX 5090 | CUDA | tiny | 1 | 1 | 1.29 | 0.31 | 0.07 | 0.00 | f14ae77f | +| RTX 5090 | CUDA | tiny-q8_0 | 1 | 1 | 1.45 | 0.31 | 0.07 | 0.00 | f14ae77f | +| RTX 5090 | CUDA | base | 1 | 1 | 2.15 | 0.44 | 0.13 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | base-q8_0 | 1 | 1 | 2.27 | 0.43 | 0.10 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | small | 1 | 1 | 5.54 | 0.83 | 0.26 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | small-q8_0 | 1 | 1 | 5.95 | 0.84 | 0.22 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | medium | 1 | 1 | 15.43 | 1.81 | 0.53 | 0.02 | f14ae77f | +| RTX 5090 | CUDA | medium-q8_0 | 1 | 1 | 14.71 | 1.66 | 0.46 | 0.03 | f14ae77f | +| RTX 5090 | CUDA | large-v2 | 1 | 1 | 24.73 | 2.92 | 0.81 | 0.04 | f14ae77f | +| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 1 | 23.35 | 2.43 | 0.67 | 0.04 | f14ae77f | +| RTX 5090 | CUDA | large-v3-turbo | 1 | 1 | 21.36 | 0.49 | 0.13 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 20.07 | 0.39 | 0.10 | 0.01 | f14ae77f | # DGX Spark @@ -318,22 +340,37 @@ make -j && ./scripts/bench-all.sh 1 1 0 | DGX Spk. | CUDA | large-v3-turbo | 1 | 0 | 264.90 | 2.03 | 0.37 | 0.03 | f5b477ab | | DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 253.56 | 1.48 | 0.27 | 0.03 | f5b477ab | +| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| DGX Spk. | CUDA | tiny | 1 | 0 | 9.79 | 0.65 | 0.14 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | tiny-q8_0 | 1 | 0 | 8.97 | 0.56 | 0.12 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | base | 1 | 0 | 18.58 | 1.04 | 0.22 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | base-q8_0 | 1 | 0 | 17.36 | 0.88 | 0.18 | 0.02 | f14ae77f | +| DGX Spk. | CUDA | small | 1 | 0 | 56.78 | 2.33 | 0.51 | 0.04 | f14ae77f | +| DGX Spk. | CUDA | small-q8_0 | 1 | 0 | 55.47 | 1.99 | 0.43 | 0.04 | f14ae77f | +| DGX Spk. | CUDA | medium | 1 | 0 | 158.21 | 5.71 | 1.23 | 0.11 | f14ae77f | +| DGX Spk. | CUDA | medium-q8_0 | 1 | 0 | 151.17 | 4.54 | 0.97 | 0.11 | f14ae77f | +| DGX Spk. | CUDA | large-v2 | 1 | 0 | 269.59 | 10.48 | 2.13 | 0.20 | f14ae77f | +| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 0 | 262.82 | 7.43 | 1.61 | 0.20 | f14ae77f | +| DGX Spk. | CUDA | large-v3-turbo | 1 | 0 | 263.91 | 1.80 | 0.37 | 0.03 | f14ae77f | +| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 252.89 | 1.23 | 0.26 | 0.03 | f14ae77f | + make -j && ./scripts/bench-all.sh 1 1 1 | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| DGX Spk. | CUDA | tiny | 1 | 1 | 2.63 | 0.76 | 0.13 | 0.01 | f5b477ab | -| DGX Spk. | CUDA | tiny-q8_0 | 1 | 1 | 2.46 | 0.73 | 0.11 | 0.01 | f5b477ab | -| DGX Spk. | CUDA | base | 1 | 1 | 4.96 | 1.24 | 0.20 | 0.01 | f5b477ab | -| DGX Spk. | CUDA | base-q8_0 | 1 | 1 | 4.23 | 1.08 | 0.17 | 0.01 | f5b477ab | -| DGX Spk. | CUDA | small | 1 | 1 | 16.26 | 2.73 | 0.47 | 0.02 | f5b477ab | -| DGX Spk. | CUDA | small-q8_0 | 1 | 1 | 14.94 | 2.38 | 0.39 | 0.02 | f5b477ab | -| DGX Spk. | CUDA | medium | 1 | 1 | 51.81 | 6.94 | 1.22 | 0.05 | f5b477ab | -| DGX Spk. | CUDA | medium-q8_0 | 1 | 1 | 41.51 | 5.44 | 0.93 | 0.05 | f5b477ab | -| DGX Spk. | CUDA | large-v2 | 1 | 1 | 98.54 | 11.53 | 2.05 | 0.08 | f5b477ab | -| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 1 | 91.61 | 8.49 | 1.55 | 0.08 | f5b477ab | -| DGX Spk. | CUDA | large-v3-turbo | 1 | 1 | 87.20 | 1.94 | 0.36 | 0.02 | f5b477ab | -| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 80.28 | 1.38 | 0.26 | 0.01 | f5b477ab | +| DGX Spk. | CUDA | tiny | 1 | 1 | 2.72 | 0.56 | 0.13 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | tiny-q8_0 | 1 | 1 | 2.55 | 0.47 | 0.11 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | base | 1 | 1 | 5.08 | 0.90 | 0.20 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | base-q8_0 | 1 | 1 | 4.38 | 0.72 | 0.16 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | small | 1 | 1 | 16.95 | 2.00 | 0.47 | 0.02 | f14ae77f | +| DGX Spk. | CUDA | small-q8_0 | 1 | 1 | 15.67 | 1.67 | 0.39 | 0.02 | f14ae77f | +| DGX Spk. | CUDA | medium | 1 | 1 | 53.12 | 5.10 | 1.24 | 0.06 | f14ae77f | +| DGX Spk. | CUDA | medium-q8_0 | 1 | 1 | 43.64 | 3.87 | 0.91 | 0.05 | f14ae77f | +| DGX Spk. | CUDA | large-v2 | 1 | 1 | 102.15 | 9.58 | 2.02 | 0.08 | f14ae77f | +| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 1 | 93.86 | 6.54 | 1.49 | 0.08 | f14ae77f | +| DGX Spk. | CUDA | large-v3-turbo | 1 | 1 | 90.29 | 1.69 | 0.36 | 0.02 | f14ae77f | +| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 82.79 | 1.13 | 0.25 | 0.01 | f14ae77f | # V100 From e0fd1f6787a5bd4a4957dd97c5b64df882ee7b0c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 25 May 2026 13:06:33 +0300 Subject: [PATCH 691/831] release : v1.8.5 --- CMakeLists.txt | 2 +- bindings/javascript/package.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a0f74041321..2200673d0a3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.8.4) +project("whisper.cpp" VERSION 1.8.5) include(CheckIncludeFileCXX) set(SOVERSION 1) diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index 074dfdda307..caf12b6dd2d 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.8.4", + "version": "1.8.5", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From 27101c01dcac1676e2b6422256233cd0f1f9ae28 Mon Sep 17 00:00:00 2001 From: texasich <101962694+texasich@users.noreply.github.com> Date: Mon, 25 May 2026 23:23:41 -0500 Subject: [PATCH 692/831] cli : merge tokens split across UTF-8 boundaries in JSON output (#3751) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cli : merge tokens split across UTF-8 boundaries in JSON output When a multi-byte UTF-8 codepoint (most commonly a CJK character, 3 bytes) is split across multiple whisper tokens, the -ojf/--output-json-full writer emitted each token's partial bytes as its own JSON string, producing invalid UTF-8 that chokes downstream parsers. Merge adjacent tokens in output_json whenever the accumulated text still ends on an incomplete UTF-8 sequence. The merged entry keeps the first token's id/p/t_dtw and extends t1 to the last absorbed token, which matches how segment text is assembled elsewhere. Refs #1798 * fix: address review — add braces for consistency, use full issue URL - Add braces to if/else chain for codebase consistency - Use full URL for issue #1798 reference Review: @danbev --------- Co-authored-by: texasich Co-authored-by: texasich --- examples/cli/cli.cpp | 80 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 71 insertions(+), 9 deletions(-) diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 4e84c1b2750..55cd71b4e55 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -31,6 +31,39 @@ static void replace_all(std::string & s, const std::string & search, const std:: } } +// Returns the number of trailing continuation bytes still needed for `s` to end +// on a complete UTF-8 codepoint. Returns 0 if the tail of `s` is already a +// complete codepoint (or if the tail looks malformed and we should stop merging). +// Used to merge whisper tokens whose bytes split a multi-byte UTF-8 character +// (e.g. CJK), so the JSON output stays valid UTF-8. See https://github.com/ggml-org/whisper.cpp/issues/1798. +static int utf8_trailing_bytes_needed(const std::string & s) { + const int n = (int) s.size(); + int i = n - 1; + // walk back past continuation bytes (10xxxxxx) + while (i >= 0 && ((unsigned char) s[i] & 0xC0) == 0x80) { + --i; + } + if (i < 0) { + // all continuation bytes, or empty — nothing we can do + return 0; + } + const unsigned char c = (unsigned char) s[i]; + int expected; + if ((c & 0x80) == 0x00) { + expected = 1; // ASCII + } else if ((c & 0xE0) == 0xC0) { + expected = 2; + } else if ((c & 0xF0) == 0xE0) { + expected = 3; + } else if ((c & 0xF8) == 0xF0) { + expected = 4; + } else { + return 0; // malformed lead, give up + } + const int have = n - i; + return have >= expected ? 0 : (expected - have); +} + // command-line parameters struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); @@ -738,18 +771,47 @@ static void output_json( if (full) { start_arr("tokens"); const int n = whisper_full_n_tokens(ctx, i); - for (int j = 0; j < n; ++j) { - auto token = whisper_full_get_token_data(ctx, i, j); + + // Merge adjacent tokens whose bytes together form a + // single UTF-8 codepoint. Multi-byte characters (CJK + // in particular) can end up split across whisper + // tokens, which used to produce invalid UTF-8 in the + // JSON string. Refs issue #1798. + struct merged_token { + std::string text; + whisper_token_data data; + int64_t t1; + }; + std::vector merged; + merged.reserve(n); + for (int j = 0; j < n; ) { + auto tok = whisper_full_get_token_data(ctx, i, j); + merged_token m{ whisper_token_to_str(ctx, tok.id), tok, tok.t1 }; + ++j; + while (j < n && utf8_trailing_bytes_needed(m.text) > 0) { + auto tok_next = whisper_full_get_token_data(ctx, i, j); + m.text += whisper_token_to_str(ctx, tok_next.id); + if (tok_next.t1 > -1) { + m.t1 = tok_next.t1; + } + ++j; + } + merged.push_back(std::move(m)); + } + + const int nm = (int) merged.size(); + for (int j = 0; j < nm; ++j) { + const auto & mt = merged[j]; start_obj(nullptr); - value_s("text", whisper_token_to_str(ctx, token.id), false); - if(token.t0 > -1 && token.t1 > -1) { + value_s("text", mt.text.c_str(), false); + if (mt.data.t0 > -1 && mt.t1 > -1) { // If we have per-token timestamps, write them out - times_o(token.t0, token.t1, false); + times_o(mt.data.t0, mt.t1, false); } - value_i("id", token.id, false); - value_f("p", token.p, false); - value_f("t_dtw", token.t_dtw, true); - end_obj(j == (n - 1)); + value_i("id", mt.data.id, false); + value_f("p", mt.data.p, false); + value_f("t_dtw", mt.data.t_dtw, true); + end_obj(j == (nm - 1)); } end_arr(!params.diarize && !params.tinydiarize); } From ee540bf0be55d2a5176872adfb519e70f4fe0e9a Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 27 May 2026 06:22:38 +0200 Subject: [PATCH 693/831] docs : add AGENTS.md and CONTRIBUTING.md [no ci] (#3826) * docs : add AGENTS.md and CONTRIBUTING.md [no ci] This commit add AGENTS.md and CONTRIBUTING.md which are based on the same files in llama.cpp. They have been modified slightly to fit with whisper.cpp. The motivation for this is to clarify the contribution policy in whisper.cpp so that contributers can have a better understanding of the expectations and requirements for contributing to the project. --- AGENTS.md | 102 +++++++++++++++++++++++++++ CONTRIBUTING.md | 176 +++++++++++++++++++++++++++++++++++++++++++++++ media/matmul.png | Bin 0 -> 265705 bytes 3 files changed, 278 insertions(+) create mode 100644 AGENTS.md create mode 100644 CONTRIBUTING.md create mode 100644 media/matmul.png diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000000..f34f3249977 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,102 @@ +# Instructions for whisper.cpp + +> [!IMPORTANT] +> This project does **not** accept pull requests that are fully or predominantly AI-generated. AI tools may be utilized solely in an assistive capacity. +> +> Read more: [CONTRIBUTING.md](CONTRIBUTING.md) + +AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (see examples below). + +--- + +## Guidelines for Contributors Using AI + +whisper.cpp is built by humans, for humans. Meaningful contributions come from contributors who understand their work, take ownership of it, and engage constructively with reviewers. + +Maintainers receive numerous pull requests weekly, many of which are AI-generated submissions where the author cannot adequately explain the code, debug issues, or participate in substantive design discussions. Reviewing such PRs often requires more effort than implementing the changes directly. + +**A pull request represents a long-term commitment.** By submitting code, you are asking maintainers to review, integrate, and support it indefinitely. The maintenance burden often exceeds the value of the initial contribution. + +Most maintainers already have access to AI tools. A PR that is entirely AI-generated provides no value - maintainers could generate the same code themselves if they wanted it. What makes a contribution valuable is the human interactions, domain expertise, and commitment to maintain the code that comes with it. + +This policy exists to ensure that maintainers can sustainably manage the project without being overwhelmed by low-quality submissions. + +--- + +## Guidelines for Contributors + +Contributors are expected to: + +1. **Demonstrate full understanding of their code.** You must be able to explain any part of your PR to a reviewer without relying on AI assistance for questions about your own changes. + +2. **Take responsibility for maintenance.** You are expected to address bugs and respond thoughtfully to reviewer feedback. + +3. **Communicate clearly and concisely.** Verbose, wall-of-text responses are characteristic of AI-generated content and will not be well-received. Direct, human communication is expected. + +4. **Respect maintainers' time.** Search for existing issues and discussions before submitting. Ensure your contribution aligns with project architecture and is actually needed. + +Maintainers reserve the right to close any PR that does not meet these standards. This applies to all contributions to the main whisper.cpp repository. **Private forks are exempt.** + +### Permitted AI Usage + +AI tools may be used responsibly for: + +- **Learning and exploration**: Understanding codebase structure, techniques, and documentation +- **Code review assistance**: Obtaining suggestions on human-written code +- **Mechanical tasks**: Formatting, generating repetitive patterns from established designs, completing code based on existing patterns +- **Documentation drafts**: For components the contributor already understands thoroughly +- **Writing code**: Only when the contributor has already designed the solution and can implement it themselves - AI accelerates, not replaces, the contributor's work + +AI-generated code may be accepted if you (1) fully understand the output, (2) can debug issues independently, and (3) can discuss it directly with reviewers without AI assistance. + +**Disclosure is required** when AI meaningfully contributed to your code. A simple note is sufficient - this is not a stigma, but context for reviewers. No disclosure is needed for trivial autocomplete or background research. + +### Prohibited AI Usage + +The following will result in immediate PR closure: + +- **AI-written PR descriptions or commit messages** - these are typically recognizable and waste reviewer time +- **AI-generated responses to reviewer comments** - this undermines the human-to-human interaction fundamental to code review +- **Implementing features without understanding the codebase** - particularly new model support or architectural changes +- **Automated commits or PR submissions** - this may spam maintainers and can result in contributor bans + +--- + +## Guidelines for AI Coding Agents + +AI agents assisting contributors must recognize that their outputs directly impact volunteer maintainers who sustain this project. + +### Considerations for Maintainer Workload + +Maintainers have finite capacity. Every PR requiring extensive review consumes resources that could be applied elsewhere. Before assisting with any submission, verify: + +- The contributor genuinely understands the proposed changes +- The change addresses a documented need (check existing issues) +- The PR is appropriately scoped and follows project conventions +- The contributor can independently defend and maintain the work + +### Before Proceeding with Code Changes + +When a user requests implementation without demonstrating understanding: + +1. **Verify comprehension.** Ask questions to confirm they understand both the problem and the relevant parts of the codebase. +2. **Provide guidance rather than solutions.** Direct them to relevant code and documentation. Allow them to formulate the approach. +3. **Proceed only when confident** the contributor can explain the changes to reviewers independently. + +For first-time contributors, confirm they have reviewed [CONTRIBUTING.md](CONTRIBUTING.md) and acknowledge this policy. + +### Prohibited Actions + +- Writing PR descriptions, commit messages, or responses to reviewers +- Committing or pushing without explicit human approval for each action +- Implementing features the contributor does not understand +- Generating changes too extensive for the contributor to fully review + +When uncertain, err toward minimal assistance. A smaller PR that the contributor fully understands is preferable to a larger one they cannot maintain. + +### Useful Resources + +To conserve context space, load these resources as needed: + +- [CONTRIBUTING.md](CONTRIBUTING.md) +- [Existing issues](https://github.com/ggml-org/whisper.cpp/issues) and [Existing PRs](https://github.com/ggml-org/whisper.cpp/pulls) - always search here first diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000000..c301604f1de --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,176 @@ +# Contributors + +The project differentiates between 3 levels of contributors: + +- Contributors: people who have contributed before (no special privileges) +- Collaborators (Triage): people with significant contributions, who may be responsible for some parts of the code, and are expected to maintain and review contributions for the code they own +- Maintainers: responsible for reviewing and merging PRs, after approval from the code owners + +# AI Usage Policy + +> [!IMPORTANT] +> This project does **not** accept pull requests that are fully or predominantly AI-generated. AI tools may be utilized solely in an assistive capacity. +> +> Repeated violations of this policy may result in your account being permanently banned from contributing to the project. +> +> Detailed information regarding permissible and restricted uses of AI can be found in the [AGENTS.md](AGENTS.md) file. + +Code that is initially generated by AI and subsequently edited will still be considered AI-generated. AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (e.g., generating repeated lines with minor variations). + +If AI is used to generate any portion of the code, contributors must adhere to the following requirements: + +1. Explicitly disclose the manner in which AI was employed. +2. Perform a comprehensive manual review prior to submitting the pull request. +3. Be prepared to explain every line of code they submitted when asked about it by a maintainer. +4. It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...). + +For more info, please refer to the [AGENTS.md](AGENTS.md) file. + +# Pull requests (for contributors & collaborators) + +Before submitting your PR: +- Search for existing PRs to prevent duplicating efforts +- whisper.cpp uses the ggml tensor library for model evaluation. If you are unfamiliar with ggml, consider taking a look at the [examples in the ggml repository](https://github.com/ggml-org/ggml/tree/master/examples/). [simple](https://github.com/ggml-org/ggml/tree/master/examples/simple) shows the bare minimum for using ggml. [gpt-2](https://github.com/ggml-org/ggml/tree/master/examples/gpt-2) has minimal implementations for language model inference using GPT-2. [mnist](https://github.com/ggml-org/ggml/tree/master/examples/mnist) demonstrates how to train and evaluate a simple image classifier +- Test your changes: + - Execute [the full CI locally on your machine](ci/README.md) before publishing +- Create separate PRs for each feature or fix: + - Avoid combining unrelated changes in a single PR + - For intricate features, consider opening a feature request first to discuss and align expectations +- If you are a new contributor + - Limit your open PRs to 1 + - Do not submit trivial fixes (e.g. typos, formatting changes) + +After submitting your PR: +- Expect requests for modifications to ensure the code meets whisper.cpp's standards for quality and long-term maintainability +- Maintainers will rely on your insights and approval when making a final decision to approve and merge a PR +- If your PR becomes stale, rebase it on top of latest `master` to get maintainers attention + +# Pull requests (for maintainers) + +- Squash-merge PRs +- Use the following format for the squashed commit title: ` : (#)`. For example: `utils : fix typo in utils.py (#1234)` +- Optionally pick a `` from here: https://github.com/ggml-org/llama.cpp/wiki/Modules +- Let other maintainers merge their own PRs +- When merging a PR, make sure you have a good understanding of the changes +- Be mindful of maintenance: most of the work going into a feature happens after the PR is merged. If the PR author is not committed to contribute long-term, someone else needs to take responsibility (you) + +Maintainers reserve the right to decline review or close pull requests for any reason, without any questions, particularly under any of the following conditions: +- The proposed change is already mentioned in the roadmap or an existing issue, and it has been assigned to someone. +- The pull request duplicates an existing one. +- The contributor fails to adhere to this contributing guide or the AI policy. + +# Coding guidelines + +- Avoid adding third-party dependencies, extra files, extra headers, etc. +- Always consider cross-compatibility with other operating systems and architectures +- Avoid fancy-looking modern STL constructs, use basic `for` loops, avoid templates, keep it simple +- Vertical alignment makes things more readable and easier to batch edit +- Clean-up any trailing whitespaces, use 4 spaces for indentation, brackets on the same line, `void * ptr`, `int & a` +- Use sized integer types such as `int32_t` in the public API, e.g. `size_t` may also be appropriate for allocation sizes or byte offsets +- Declare structs with `struct foo {}` instead of `typedef struct foo {} foo` + - In C++ code omit optional `struct` and `enum` keyword whenever they are not necessary + ```cpp + // OK + llama_context * ctx; + const llama_rope_type rope_type; + + // not OK + struct llama_context * ctx; + const enum llama_rope_type rope_type; + ``` + + _(NOTE: this guideline is yet to be applied to the `whisper.cpp` codebase. New code should follow this guideline.)_ + +- Try to follow the existing patterns in the code (indentation, spaces, etc.). In case of doubt use `clang-format` (from clang-tools v15+) to format the added code +- For anything not covered in the current guidelines, refer to the [C++ Core Guidelines](https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines) +- Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices +- Matrix multiplication is unconventional: [`C = ggml_mul_mat(ctx, A, B)`](https://github.com/ggml-org/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means $C^T = A B^T \Leftrightarrow C = B A^T.$ + +![matmul](media/matmul.png) + +# Naming guidelines + +- Use `snake_case` for function, variable and type names +- Naming usually optimizes for longest common prefix (see https://github.com/ggml-org/ggml/pull/302#discussion_r1243240963) + + ```cpp + // not OK + int small_number; + int big_number; + + // OK + int number_small; + int number_big; + ``` + +- Enum values are always in upper case and prefixed with the enum name + + ```cpp + enum llama_vocab_type { + LLAMA_VOCAB_TYPE_NONE = 0, + LLAMA_VOCAB_TYPE_SPM = 1, + LLAMA_VOCAB_TYPE_BPE = 2, + LLAMA_VOCAB_TYPE_WPM = 3, + LLAMA_VOCAB_TYPE_UGM = 4, + LLAMA_VOCAB_TYPE_RWKV = 5, + }; + ``` + +- The general naming pattern is `_`, with `` being `_` + + ```cpp + llama_model_init(); // class: "llama_model", method: "init" + llama_sampler_chain_remove(); // class: "llama_sampler_chain", method: "remove" + llama_sampler_get_seed(); // class: "llama_sampler", method: "get_seed" + llama_set_embeddings(); // class: "llama_context", method: "set_embeddings" + llama_n_threads(); // class: "llama_context", method: "n_threads" + llama_adapter_lora_free(); // class: "llama_adapter_lora", method: "free" + ``` + + - The `get` `` can be omitted + - The `` can be omitted if not necessary + - The `_context` suffix of the `` is optional. Use it to disambiguate symbols when needed + - Use `init`/`free` for constructor/destructor `` + +- Use the `_t` suffix when a type is supposed to be opaque to the user - it's not relevant to them if it is a struct or anything else + + ```cpp + typedef struct llama_context * llama_context_t; + + enum llama_pooling_type llama_pooling_type(const llama_context_t ctx); + ``` + + _(NOTE: this guideline is yet to be applied to the `whisper.cpp` codebase. New code should follow this guideline)_ + +- C/C++ filenames are all lowercase with dashes. Headers use the `.h` extension. Source files use the `.c` or `.cpp` extension +- Python filenames are all lowercase with underscores + +- _(TODO: abbreviations usage)_ + +# Preprocessor directives + +- _(TODO: add guidelines with examples and apply them to the codebase)_ + + ```cpp + #ifdef FOO + #endif // FOO + ``` + +# Code maintenance + +- New code should follow the guidelines (coding, naming, etc.) outlined in this document. Exceptions are allowed in isolated, backend-specific parts of the code that do not interface directly with the `ggml` interfaces. + _(NOTE: for legacy reasons, existing code is not required to follow this guideline)_ + +- For changes in server, please make sure to refer to the [server development documentation](./tools/server/README-dev.md) + +# Documentation + +- Documentation is a community effort +- When you need to look into the source code to figure out how to use an API consider adding a short summary to the header file for future reference +- When you notice incorrect or outdated documentation, please update it + +# Resources + +The Github issues, PRs and discussions contain a lot of information that can be useful to get familiar with the codebase. For convenience, some of the more important information is referenced from Github projects: + +https://github.com/ggml-org/whisper.cpp/projects diff --git a/media/matmul.png b/media/matmul.png new file mode 100644 index 0000000000000000000000000000000000000000..786a20492c02b4ee83fcb2a2bcefa0699ee7a55c GIT binary patch literal 265705 zcmeFZXHZqy7B0GNTWu4x0TmQ5f{K7BNCs^Y6crQ+l9ebRAUOxyZV@Dy06|iNNRTW! z8VE|1oI%MM$r|{X@?p z*t+)LFYxyZ@do<;e);+0$MXMvbyj@4CAsaNe@;D1IQ{Scr%;|e{VQ&vB?)q<+Yotx+_`fKA3%Uq7*$0a+tso%aP!1{k~i6zP4khY#*y^*Y4CvN^;;=|rN zH{TcL<=N8G!Y%c44R-u5r%^^Ka|3?G+5eQqK+PT=ec9Ly&vcup%h&3~Otp**5#kO1V=^_o^ zH|O;JH?KOH+m|iP`@Y`pox1d#rQi~4plpxQZk2oIgeJe9&r)GKal-w>gAGOQ^rtN| zjGUNmjUFKl^}Sle>(__d>?g`ps_EV=^ZxCJ zBs|?uR|uK;4o>&3nO1A;XCl`g-Hg}jU|yLmDz8^mR2(ZkUGaj!+%vL8M?c^9ey>iM z|FC^vU?5{eu%qn}7V402&Kzx*wlG2oW{(%Mg|81hRNNk1@#TwauETmT{n~N_CBrqU-&_?` zpXsYSL?iU${U)i^8nSb7Y7b`SH`Ur_+xJg+U7UKJ$Q!9Zr5>dm+o`7R_fMTtOgaly z)dXwlZ;i67u4AwIrxK67yq?O&IVq)07as2_2;es>&TdaLZf`KoTawI4GiuF5SuT{< z|7Z1v8xjv~+O)|sccdojhKgmQwszg_ojb)$OicV{O6F?{9zBv4T{8ErCo5;PdN*=7 z%iP9oLgVA(2M=~+4#u%fbdDV?7Omy~sB@gGEah+_ijZ$0GAtC=xNRGk{p^oOg?Kf; zDZ8d%X1;s3{AQLdoO0E2Y`YyZ4^b@C<&S$yhTHvy-;6XupTN@788#&!o6sCNc(7b_ zWpTDdeOabw>F>?E3l(z5y=6uXTQhk?mS;O}rruRNp_p_yE^{FI&exE0auU?CbQXA< zp3e71I+&?7%gSqLD5;`j=<7e{Puxrhsj*{MPIZ@iEAKtG_;Bl9jY4<&M6EB+8nSK7 zGFUIJy*5Lg;|Uc_iJB7o%)g#Xv>a}d8MPXyIVLjK^;{{4pe71MzixhNuroL@F(xMF z4o-ONbTF^6+CG!~ZQ5;lKR#~mq;;A-&o1@rSr)bh1zr5YElysvuu~7pw{2UyeoJJ5 z9wo$5_{aOTCgWdNM;HB1>q=!xkVd0_4u6vGR30RY%gM<}j#fjm-i_GulW0k%qs_q^ zzNvS=wj&oY;EC16kyN>U$bk25!%0&!GtHHy8LtO40!Ia`M)tL4T09ewC$@?YT{L z?8c{4Zhv}cQhY2zR+*DrDCJ_>24?${I;AfYo2Nd0{HU~9aQE)rMd>D;D9#+anLD`B z*wSClwmu1@_QzY&WN&BZDx8?m>D=GOWgf?oyV1?_*v(C&_{A~pHc{+F817zfWEFJz z-5PF6F&T_a^RXR$E-ez3!cSWD?BYa4N%rWC-D_s8sHJ-tFY}HM%_z)qXiT&Os?l1tiNDnWYpEwl~2mKyV!f|!v^O3$DFx)ICEz;cJAKo zn@0b=&UmLK1N+HJiOLj%df_;nbdA2R|0Le24Bb?gWj)sMQqaoCZzgbn ziml(bbu{V*w{}J_`n-6NKh*fnz}^RYV%mNqL6r8a9u<#ACSjbh#xZi=`O0KK3LV?$ zfkhIoGiuA`e@M$5l}$dvhHoQy zBcxcTJissgOux7`HnFqkm9doU!qB^lLR9&Vife5;`slf#C~nt!YEd{}nNnYwQs**m z%N9$aYaEP^!D#le8GUrf^AQcBK+WtHN{ukJ;A9)~U3cuLP0_DCjxqYwx^>jPSI~9f zjA7Hx`0R1X&?eNk7*4$3jF;&0ET;G6PAFiQ;I=J z_J7x*jJJK?US6XsD(W|H-h5E~!YSJ9uO%BwF?EET^}Z4;A}q2rSuK|0-57ZnGe0zw zj!ph>wB6EFeaU~8Gt+3i>&sY?=yKv86(=x$<C2Q>Ha40R((B4vF9Gg}S+FU@U6ED5%0Ax3-tkh?^HMU+>V1bkd-iPh1Fhrl zQL@n=E6V(N-iqrS=iuO&*}-5pGZHwvM4?RoT-!@k2@{nP%#3L0F+b)niE0=(ryAPD z96Ns8?bwYsFKnFNJFmGOua=Ylj*1k{zOh*xE9zXDX^*dYM^0UjUhd>b)@WgGu#;Mk zLvZGAXKhTYgvn)C)T53yI`Umz^*7M$yBWpwcsHk%tZ#5|UXlO*+#wP;Ge$B8J6g;t zTIc59sxRp?M3A?NK7alC^%m0-c5lbrHy_Kp9R@VSy4%lFZnf_xgZgYZhBQa2QLBGs z>Zx-bz*~j>JOc)AVW=xA#f(CY>~wLU|M9BfK6W*%-%lb+dwS3b+bz@;G; zqmp5oHi!x7ioF>cn?paZ6y0E%hEPkR+jxA0rOIfP%*j6EQWY@Y#*XMzv zI2tbnEcE?m7N&>y0y)(J+UB@WH{08Xd#D9)44DA$Y?r>Eo6$b3loW)#okq z?BZ10huU@n5can9FUVs6lFOc|Q?A6Vk$KnE+glXwtHwYo{Ob_)!mhn?d;GYhB>BSC&e^&S-U>$yk@Osmb{G2_F%_(y(_QIu)n$Dw|(8ZbzCML`T0xA zsdwq6Mdq}k3?#(34eCzdpaom!Bn_ptJIbl8!V3Lj_OzDmU`Or4^WQ!^2$K#L$x?ai z@jMW}tJu9Xhj}gAZ$p_DArh4L$N_9)q_X4myWj#UtN{CtjLt8~=zdhgsZL?}NHm-@hZVz-=8K;`n8uX>4 zq`5PKt?sP?r<ziX%-%i3p{!I5ATIq{75-pr~4ied+U8p zi(_(Xqz3>kW`bDd);zdilsnfWLx89Rla1CDf9|Lj3#*}qZKRxl78ObYDA(jxDUx$$ zceS*r2d(I0g1oK6uz~vH_Qf9FL{MJqZf}KOSGBSI2bXbx3Yv}ca+;^m!V0~d7CTP` z4`S|wPC?#K>~ExmWKv<{wF6{QqrofnVE-Ukpk5#Ib6s2C^XLEJ?u?t)Fo{6ck zBa97;`(o~{FNf;(soyQv?|=RqQ9r0J5u2j+Rw3S#Ek64cVEAq}`IyL*pK3tVNJeka z&1~y2qk3tfvHbLJe{L>zIHT8)k&uue^KQ-AtDhy&mCR|eYJi)P;h`3Tb%d-4*Euc? zs(Xnz%oFMp1G0s)7#10M6cS0C5zv#pnC!1}XL|g@%S&<3==Qe|#r6Nz4EPR;bOo(}Q60P+LmKU0R_ZXBYdW zotoxsQ74taAFC3|(3f?)RepC`o(Q(ea-8e-EsDGzwmr5oNxNhgJnZmOQk;HITG^$=-a@EJEQvN69)fq#E znN}!ZJ zP9;^Z;)74%E*dzwtN7si)P}7LJ0PambWpw)}jVc3i z4GOOb;a=q;tGI~V_zo$)e?AWCV(Rnv20Tx91jV7F5+NXwGQ+giU0`ETQ@K(&1vjXt z;K{|R1hbXpr9-T&5i-J_BJ+LuDb*3u@v^H&J(7)Z)!E>cx!~l8)3M&KB@8j#5@IYS-)&Y9mH?lpP@;4x~Mqnc%D0?AU| zwiV@9ROqp{B(@RsxHL#uFn2m7UM`cCBD-N#Zadl^RSrnNb?d#eM7QWMyc`kl={7Ep2FBGly;i0J{}*; zJ*0w->sCWsq;A{`1JrpUhL>21+*ZMgKtW!UVK6bpa5d_yJaP5mP$hHf=v14xf*L)T z?>%ojzD9cIOff2MH}#-D&2n9}v(SUg6S)DoYuBEnqvw0Y%B-^GRVEhY@`M|to^)&d z{-@kSLLlu8nHB~qlfOC_Q`N^1l}n$>S%d9O@pB1P4&X5KDCz4R^Qgb9svPRo?BC(R zD0n6a^)552>M&JXYZ`4PCZ?vQw$J+8eY+So4h~O(b}ma(B3A{|;$(E@{54%`)~=V0 zsu!MVd7e(jm`THJ7rCf^;TpPG;Y{AjlFTR%xx~!=H~PcAljgl-`WvWIA|nGa?2&hN7PN9{?@Tzc&Eae6}nd_(!8WfIqw~ ztHi~5^q*arw$~G>mk(N`UYZ5`8YaPjgpsnnvy&5c%(`2{y>;Hl5Uz-QbG(QDG?nT)&kbUb8S5`u`TrYEE97bc@F6`^=O_#UOI@$r9L|SME^*yvQ z0jzk%-S=GYSpt_8tB+GPcx@#mCDqY{-K+h)h4MsfJTDPH+n}30*Yj$u<1r^o9J8>k zIl)l5b7jVb2VC}TUl`4EA|R|TPIWJt)ffl`uP;3(oGKC8%cim_;38{?0-d3cIima0 zs7;4LZ}8>W5t9$=c2o679vd4Qp1PG!w4zs37V>vU=L3Qn+}%YvD*a!*u*w+1I2UmT zPU~DD93>HpZSJ)b*ru!ec?k7oZy6`xv|2CE^>(s~&RvSRnHbvA1ZrRO`RSp~zjjDH zg2YEOl<);2vhHG6+KCC|g?X3{DTy9X@6zJpOzuohIuZ5~KG8fz%MrJ<17xQ*ljUFiAXy)K zJ2_a-Nfg#T0P?P`S#>$;g;FTFU5r`)(+((LU|AK{jcux z|HksJuFwCr0nVgao%(MUfb##f8UFvjfx0Ra)zCpcPlQ=impo32?_AuCzN ziB5Bp_T}t}azVS8LEu3Mq#?Uleu>N^A}%_zy3y4K>$iqtoTqS;N;%ug>o7lQ-5Y2{ zRP`~Kq=lgu9)li4w&?l;DhCQ!xP^HN9JierVI-;ydPrSWzMmm_&=|B9Z{-_A?gP{` z&^zcHt(zOEs8hTT zVFr74tD0XY{GUi8`ik#3l#;YAB^%Tmd0CYBa<)gT1UP4EgJ;;;?}#L9O{2g~$^W+y z4adna<^pKTgMsmI)0PM|VqnE;Ps_d2DQn6Q&AavCPrMv_CzG)P>PSAw%bTWD7xQ>| zIxVjN&aZ;eRYylhsA)+)WR-{^9GAB|Z)#93Fc@RqAMUHKcQEA5Dnq)MMm%{k2JWD3 z5|-NZCpffOPpRwoPc9PR7o7ud<2@n3$~8ApPtP^qtJu48K&ZZ|@XOE$g==@2dp>%3 zN?URZqX0Mf@oq>Bc6ZydGn(qw!w(`HIiX@61GR$?AJFV0m|{;!GxeBi!v7K*wGx^x z5!DQ0RR4h{JWyuU_}uMDWT!uuC@(z1J2GmzNd@uBke+%ByPLKp+q<2frawh|;_-)cr@> z4s}#r=Je?3=twR_H5O1`NSYp4LdE*~--*Bs=3d(gkG{Wl1JUBEg;(9Tnn>oQJ_+I~ zy7S4+xv~a5Vv9=;m6!Ix)r=F3szFx>BD6`Pwr>F_ca?Cl)M(C(nu`5c!5P^#lt_zu zRV!5@#tG&Ts_@3{Ffr%=WfNwn;++CVxGafym5y^%Eez_gadr4cZl%5mgf=+=E+9U( z@rj8roZjm=1ha4pG%yAe9TujKmrY}{3t;%{saG+(>Rq>n;_ODo<$nOET?WzGh*EJp z_UZHIJK$G8ef4TDd^hUU#sYauyXNBhs2PAA(n5)ngOEPnQEzQ9_GaS2a>3E;2AuW3 zgNL#G_MLm%A5}{UvS-5VzUt527^{P88Edl}E$I64JS-$+KZ{hr&F|0sq0BHBDe4B& zxiHz@`CFHoxp2P1Z8P4L4n>1)Q3d;_0WH-8jYc9^#32-V4Y|G60UvhdU8A`RM$n6n zL1^B4TBkJfS}?S%{3E_9+ws*=3kM42)Ol)v!r9*v`cNSP2gfH%A%;#7$U_<#Av1g^ zJISpA5@KEO`rqGQ+nL&I^myYAmVCqlh?)Et)Qe>#&1Rx!&zZY7|JM0=ktwK#bYvrl z&E^{X1v-U6e9jp%;W_gbO;yawc!<4p>cMAXYZ*HXDvv=jZXX*fVlF}$qLU;Oh($8O z59pu1vb?aPel6wNLXr%fuqmL^z!Os854X@iOV7m-9wT6kNHN`Qgea5H0+BSurbtcz zo$mq^_UIPDYGOU(prsXM!mW`2L)E{HMo18k`#%1p7Jtl!f_TTNx9C@B3H!}j?tR*7 zf+}K(8^y1Y8YLHfe8v$0lup6+IjN2Z*I@i=jCB;CAe6h?{CbR?7L)_5+HAV(h+pd7c1SN+?NH{|aBIM)ldtzht#{~$>L5S~zo2@@2#+#eAv zCA1o|(+Uk=be9khy_g zH#=XTHA-y ze?_()bjo@lDk6GGtk=n^(}PzW^|!rN1bPmSz-6&~oLI||?rTpipRmwN=5z*oUjxu#LH}AnrFj%2Y$uLJrt$6l zeKHF5mzbU(hr=_|X2%%>Ez=r#2~83KAvT3Rj$SBky%L1A4}Qy*eS9aCQt$ft*ceoU zVRZbxTkTO=n=Kn+0Gc?VS@uGA2%Lg4Q3-iq5BHsGTNt>HjulyoixG6=rJ(M>(}tt{ zQrOOt;B?<9qCXOs94_1iTvbQRyK*;dg~H&>O@FwU`IwQ6VE0tEoNY6q0Z}EL_3 z#%?Ga#Mu>t9v%w#NW30mw2~kLkBIKomm&_f#26xhHu8R@*zUAjc1A`c9HQ0I!c#SU zIJW*QZMEpK8c>YKU~(5C3=xW)Owzk|VZ^V;#-?{YyWG)a5Px;%CBcVB*r9e1-B~Op zt<_SQbQK<<>tckRKD*4}s*dQLU5vwUO1(zc=clz*X(Y!5$7&DsxhsCBj!>mEk-K7I zFv3ESFCzlE_-I=Wb*zEbXDjK1m{MCfQg0j!y(_t+h}{&COC>8wP}ruN2;%V;kkpx3 zy9HU0zs{SPnH>@pb?SATK~{i~2y7{iUYD?Q?64O*;HtU7vN%b$30EpO+S}Ddy#mZ; z6J!;rEAM`~{1JrA2<^6}X9y$=)I^N|iclpATs6sYsOpoDRv)qtq#IHvA`WF3w$8PV z6D0&6*czs|CY*3C*c>-8IyrdQm6C~?PykG_2k{3__*bkTW^ae>PL#lAW}-TeC_ZXN z16Z5va3b^@)d2|dJ2ARlKhX%D@Dvz`B$zSa-p7q5Fcc-A^>|b$w&Fx=Bjtg_dBp!+uZ3K+^L?wtpcpBF=P`9Nm&sO;}K`O#1{~g^ECX2 zjdG^y`jBt*kN7J|J3M~^ps8dL>XIw^k&~AN7=st-h&Xccn{PN|9H@zwd#ifpe(^Y2 zCO`jjZ9MMZUQ7|oYwg?@W zfUp!28FAB6j}uN{hv;&rXa^cRC-2>PQY=5df9gaPM1M7|gQ$K1LOFRGlP0~0BaS1% zr1ur=GrN+(gH~3%@+DouZx|;ofewgJ3u_y(pYdX z(?WXka1#?6^OODEb;Q7t6S<{>v*=2r=?5qXDJiDKztdBYLj;b)snMB%S~dg#M2GB! zk!Jxb5$u3+9`#i(f}E#+&bt(XQ5SEW%SFN9jjd#@^X}y>O;d3h@+>M*=P!}Ix961B z#S9*Rk{OQCSQ6JA!9t|tw3RHKK5>-Nbwad;rJcUe{FQ=eNM}+JY z=JoGzOq(1RkAeKR3QfG#IihYrwrfsGc~XuIT>?;Wo!7+9xLat2<#6y5pweTY7=X=*T8` z_dH_4zA#Vd2w;x?lW8OI!Rlep#3jvax&~ zV3(`Q^87LXOzBLa?n_-{$fV4Ku<=Iqa}t=zmYH?vgaJKzz&;Q8(M~lLeYC^F!xe!< zbLkiu1}ZDq;C<9#axKeCXlR5^&31Yz-@kvKsoRQgUmfVt_wm8+fC*A&wDhOC#20bE zW1&>qEzWQK?YD#=yAgi2W5>RP+d8lV`_qefLe{?Nqpi4`Q!ROZX4IH}AB|U#F93id z(h-ZwjALVC@^9s1g&fs9r?S z)?&dV0#}UkiIRgV*UV&`Yd;tB^pJ!S@+3y}-2m zTN*ehu*Z+jBvc|r$wxa-*siWGgWqX6yISS%O?1U4B>=WglqAK#8NCUx_LA=R-@jE> zyoGjB2O_;9Uj!M!JziCN@+qRHgXsBIbNqp#PPx@B>adMHW+A3114&J%w##L->~Z#p2Z5^>8k<{8k~SlwWUIxrkxL$3^oA9d(; zt#MzHR8}080iR>`@tf|cohj;cEMTc9mH3~M>#;G=Uw}u+C z95`@} zG9%)N9%81vxx04njsT5qu5!x8nqc~Lqzo$iX!pdwKM0E`tA!^R19wL}w5kRvPid z?nD{X#|>15xzOz6`(~o*#JSnwM3 zvsV4Dx?1q1Y=;ty?v@?9g!Hf%@y8^%}Zwu=sJ(a43y`^`F$mitJ=!8tK+wE?0i_6gzt! zk%c<&q?ThPpON){jmcf@bME3|1zN<;y?fsxq0sCzfrSD{em*_g9*>?K4b1B1<(0;s zch~{G+B}esVVRzy4{*v7PR ziW*d-|B-71s2;7JmkZs!H-!gf9gf4>YB0CP>9=t!f^eJzG0$1U(!s!OGM>G9i=W-a zp>k*ED<7XX@86$|FahhnhB~xQGU&lw)r4Q(y@cUHR=>W!geFccCmIRMh9F*C8;+Pf z0JjCUta;*c72YWp@7p}RkbLj^AF%Wb4YpQ*ObO$WCjXPS|5O;aBldC5vl+~6@`7km zT`7T?8)%tgP+Z54A1|JLiMC-e+SWR2ujlgO#m=eWW(8b@JPx%(>WDhBg_mHqRiVY@ zsz~#43~2x&_SUjZL9V;wOAWGS{>mP8Wt8<*j*mj z%-4&dS(9UD6+EP5*Z1`R4|2ND)D;Tc=yGz^EWgEmA4A$Z5-AJ;%b}OCSs#}xZ%7M= z^Kt44e-++J^?6Mx_bhP5THx+e%526u|PSqE_GR}s{Rij&L(Y~dF?Z;imxDa`|uWtEN`s&qMBnXoYo6oFUHK#`pE+($izt^+9Hd5tV1e?N^cpXa(KH_FC+v*2Qbal zu8E$I2@3#{Oc*k;$Vb^Z#5Vt(wmvXk(E3LR%axZQ5CiC$Poo(H)8HjBC4>>>b{deI zs{)&>O*2*}Z?zju9Xo6QJYM{q)yvQ(LOR$;nPdOyjXzcojq0=K&r=*0EF(+~95{g6 z^rSsL9+#ZV-3rJk;JBDxQ7?7z;x#mM^X_80;CPQxGy`^}Z0p3ll9M`$K5C$aPQFiv zAxD&3JM)T&s3Gd^=F66u0!?KQ%J<}4CQ#8p|HRdiXK5VG^|Qt?cZF;M%-J6tFi% zu&`)VCoAijwY{u(;p6O^}2R_m_aKxm?#w1XeHhI~_sR1HO~GQl@!SNs;#n zMH#x|vezI90cC~kV3q1aP=RCg)DN7TK4hz_!5hArZoS8`fB#v8S>!Mq<=$MO?=JPb z0k~HQ#KVItQMPo9z4p(a2&9Y1Maf*id9OpiOwlgch2JR)*Qvs!yxY31Tj!7S=WBtP zFgK#{^n}}^M_*BXC>Rn|6aV)D)dE2462W)y(8# zU(R^!?w9ZP`XI{-pl-=HmkC19Y!E=E5|8hFH{D7$okGD`;C6!+7v#$}=0PHj0;XM$ zN{Rp_6+siPKRtNi1UQav1Nj+uD$k8pCbfdQ;C*7NP&zC9FLOg~ms9+Ab@?Yw#j zh9;7cQgT)c^Ry<#K*5_$LF)VW?*ewyG05h|V&o09WSEKb%gD&!(N3AEp~mV?+fypx*eB|d}7*L&fY(>^$giTh&`F+8)Qn7CSgt}fI`HSSdfrU(AO`|SycrJ zzu$|L^gd47>89NrF3}z@lk<)-@T1 z*xoWnttN^y5s#kgPFf#PT(f3P-FaGJEyKp7ZXf=c;bvW7@Kt%wef;hd>o;srwwoR@ zQ5r>W(tp7zA|moh2$;{*zY}=YLO;oF?%yoH++4OFT9Kft0kR=d6X>mVvvb+Y+xG4? z01ut=;YXZK7b{)Xon0>qqN8#3<;f3#o_Tu9xPBxC+miOtxLnVL>hh_7;vPC&nts{S zLoOsnrHqDAesTpds#u^#hF(bOG1dMK$3XfxM0sh2C!YXzL@D>Q+Ho#Zi5JtiQP-jE7*!DN05eV~vZo#zfgfs&Ummz*3$Kspu@KD}rKbZsJ< z1k`2O&zVMKxb0#e;Fh+x&qh2=XEA;4+O@HWAmtM|lih5j@m&$aDXm9z`}W-p;5~de zA3*L&wQcK?!*Ztt4o+?8h0O*(zI~{J%wc2?>9(SOo-#CnC(>5>t;_%imN_}r`1$@+c`3?`t6bdU|l1E311p(xXwZfQ{ zGNU^tldl&9d?iqJ+N49wH>l(lKkB}+Q}kcnXqsUR9>zI{qnTEk;k4q&#==rKoO{#U zJUIp8oN`Yd`svjBwX}PB_wwBfsqVdvS!GnqKu@m+5LqsXClRzquu+6CAdr^fK7mA4 zn_gLNl5Pc)Z?G>^bfr%&v}xP7Z#VL@LJ7#<|7Bf^V<3EhO7wzn>=VrB$7JM~yZn*tFu!H!h^h;}I&`0L8WW93m;0VXkhfI*T&K9Jm(r_$q1>`|NlrlRtbhqsm ziBdiE_%v8mRTYDll+p4&Qa0+@6w>z$!W)cp=iXc||N8Y>L`Dm}XKAo#9vVOd%!42!yDdMNo>f+T$&CfDA8=s_xq~3F&O`aN{#WP-CQE)K;iiJy0u!@Rc zIw#A8b<6j$;UT!fIXne2%+30#T&E)Z)As~C#jrMvp#ntM`{?zL-OD%8?cCSbR99Df z@CxbGL_V#yAK4-yE+L_e#&^vh}B1D0IUrI>YG&!OuhJ{mW{jNIWpgEuD01*~90T`Vy zJNci@>q^MxbDHd-n8k*LBrB z#?U+yJ(gI+M6yEe8Y_{|5fH_VxgHa*b_lu`w___lf`Kp}(|*f%bU|qKPtn5kcdkGb zp>JoYrURyvQ&9=XT^zp*%|v3&n&j7584BgIBznHaPOGe$UKCmsFa>>qJn#x$O9qtJ zV?R1rX0eu;V4=#XyTD`VcTTL+#F#IqWW+)D-F6ZTT|7|GQDAQ z7~rbK69JhBz;cfdmXrZQ=EfwZmg6#_b_303Yo1t}`_cOd zRz!;)8WXqMX~x;&7-X=c5FQouFyq_j548>o(a1OhG(KEyP$hFaatr>@>`Dy2RUABY zXbOv+!JVywVMKH(KF6VkL=K+YzpPz51yg{&z@NKVWej$}QFP(%POycDzP>vVCU{WK zG~z^_K}&ifvJyR2TnKS-F@*TO#$$ncc-n!`uX1%fp~PcQSE?8A#h?pk%7S6X&YfB< z+Li~=RcF*zTHATN)Z&JaPBc3+Z(u#0v_+Oc$tBrfjQ+- z0H$*dh-G~2%d8s4CJLBGR#tR*U5QV;m`Cg9#lTY>z6h8v14k&Z&~Hn@tI1x3vWKfQ z=ifr}Xg^x65dReIw(w}&wSWEqWDS17D$V3uk5ua4V=O1ImGY?X5(6!WaxCOjyY1#Q zGEjTZ{F7~(;M9kuL-DcA*gkrS@q{BnSG8y-Y4uT1i{M}#3u^$*%F3Z}oIyA_^6dRz zIGF2!qr(VEAyT3VzZ1kDj}EZU9@7zJU`x;Sv`jy-8<1;(%UVW%_=Ocil@0xxe*$v+GO1k zPgn*@zWAuma&|ACm^RA3oo|8QRt=+7%k7;DnDCK6=^XXi3hlLY7z_#kA*}*u4j%m2 z=U*_~5L^M3R?YhRI$y81xT{K<)~eC0$FFF*%$2V^}807_Ou#=GB}Rpy$5xM0sx5Ht^*T$MtYYIhEJqBGxA0JNj zkulpwIsAL;vW;KcDS{`|%8Y$+$3&M3+n2U5GMuK7q&XG)a&h0~7f1iWAhBW3+xT23 zT2L@#(YrM(Fx>!MK|Jdv;;3>@pFU-Z2MiYpV{C&gnrzxF=~e&yH7>EnXY~=s{br{d z-McJD$}XR3dL;IY$+qu+uhqs}KQy-5c=dKY6|5eS^T=f(MbpZZ%iaHY*U@Wdf!)WH z90gn9^A%cc%=V6H3bvYIc&s&TcDF9ZuetYPmKrO|tEbjY!P1v5J=IoZ>P~pv&%Xpi zdiu_TpuvH<*tbCX0v-To!IgUr0yfgn$io;GE?=v~vSpRz|F&fdw_6sYv6n?Fe$>4K zpjoropI`^K)3IBfdM{0UeZ6t-u3bITMw@aST^w*Gh&@%!W4t>kTj~mfNo8fFQE%Fy zMbfPgzbN{6e=ov^&U98XVBppu4rv%we?xI8Kn=zSz;spT&jpz{O+S(oXgBdpF)#o7 zz4RQx!LqNxGa|+_sz6Nx(;=_p7bakX{@|0)X-ug%YI6*<7R+;n9?m~pk!3gYWe}e= z^iuC(;MMY1_#gqIsu&u~j|VuyQ*2J++V#Yb3gw(nd@?_K3N~1Yw{-;;II!^7b?a=B z*AA)k-1+*xZo-G| z-br9suf3n0poyCisZlS0v3C;OqN(l<^?6;2pE3wQK%q*OC42RjjEqmO)n#&QEkJmS zQ%UOg&dxNmUa97~Qlq}CWv?$%9gTGHJ1Sm>hwDKUN;Mk#e(brKEeN(@PGL`pnLvH> z6FM8mc9PuR|NJDQWiI4!$O1u-lmtD{G2Q<7JNFsNn_2cAe z=n!xR7FGX*)A0{Ek<-~2+LL|qfo*ML^Om-Oe9s{u;)D?S67I(W=dGn=L zbdt2GX;NkDM-2^)yQ*ojQ8Exgf;m*upWNQJWHVD~wG>?XH8>CHO5*-2n-MgHtsg(? zf`wP?_w?PkaU*#)89q=L1ca8pieR{X?cuHW5D*5*#fWe1odnZArVyth#jEv?C~eEW zpFQ;Br%xUoQUQG5F$JK!SzkWIU@2lV{#xe74MWD_^V2)%>H9~wzj*PYYqDd_BIl7K z7hGtW>+lIdM)pJPdFn_tSA)|>gK9YW_Pr3T=r2d1>6^VlQ)lX>-ZLd_gTnj-^sgVL zJkx?E9e-WV`!b^D`RV&LEO;_VM`Q@mnsAY~4 zLj(w3Rk!Xvyb@R)PrKAx;#)CnG3x-K85Op0^m^!Dm8vnG%%YCXG=i@p@yO7y0hYJ1 zKVbqx-%x#b3fOl-o|#9^bP6u20R%LGd*7T-l#E*`8l4u7uzM>ydbgp` z(@bE=M(e`QMsrj#>Gky%bod+f34ALIG z6GBt~6I2E2HV)sZxJBk8*$NNz9`Lk#4 zp01A`-OEZ<=308tkStEQ$HK*Roik_hnaIkb*%4OO1-zuTi&-YG{)Sue$dX2*PL?@{ zWqEejtxS5?yoCuolk`r&VxWsV4hKd65eAQnrN$DjFMhCY;FYY~u%S6u#CfO}h=&2l z5MdaXX4?(D%3p?-Y^3cb@Gzh}3Q2fG$2=Y-;%Qj0sR z`DSSjQL&}|iRN(MPk3-unz(@2qCBwi-cgzCLh}c7#sSq=P2aqJeH6OI@O^FDiOQE#qF8gJQ$Gy#2q$>Y->gkOZXW-tNerqGN4Xj=}0VKB(yZal5+1@Zo0B{<3=t4V%>1mca;+9cc(8HW^?`e zCk_E34oZM2M3v5YCW_A`B)AY%Rbf2qPcs4Pg?*L;h>oKQmzUw5K!$_cQ3fEof=$)q z?8Pa>>u$-}N6CfMwVp6H>v~NEsBvD4s zkL-LBxRpuhCZy|$qY~IhoQ!b%7lr?#@#x2Kg}9jGa?z1c`<^+r>wT`a5Gvqjq~4Ae z4|!)8nbqm|Kp7&IibVERNCgQ+A?-t7pz~%q`x2&=*>UJo(z)!LMMF-iWV&sUhKiNG_}IKZQFrM4B)2FKf>;&X-}I8Rq{b@lnOA>b7Sn{;{n3HR%xfxe62U`7t$+-QtL;jxjwCCxQ^`eox59?CS%= z+pR***gSajWv*V1g&$n0Xk_*Ix%WpfCr)Q(DgN=tpSbiv!~*U@s1P@S$h6!X;&;hd zZCsr*0`5y{{2(X#MeYnT4||K`qh&84OskgGW$_OCDjO*dIXC;JVs=`4O3X z^vf>+D-lLYO8!CkKo(}%$t+B9hOI9z$>61o=ckCeQKW{z2n1cLs8siTqK++*e$aZa zHrGw2;b-Twe%-I!m}&7A&inv(9Edy7z}}T|j02DoM+Sf;HP+_a`!!+8quSv=yV8$9 z*e?tPf&6CT2)qHOtV}?7=_frm5gVP=Uw_q>A7&Hme_xeo_CvU5-lmqrY*VIv6>D=- zY^U_Y(h)+`N5;7rUV|^&xQNTA+vZa0S(dS}R)6AChWxIp$O}ZQwbWQml}LdRSpa|U zE#QE2pqo}=hc7%dk@2vvj&DbTQM_mcPbYAxmH6Sompwc;t{P-9;xV);vZHSO91DmE zKdg+s7I^96#UFewjEg2x03z9|Z$jP0cI1X1SJHfbHVNQCo(Wcqy9i6Jj!FjiGBDVd z3~pvwdWIa?5;7JXwZ2b>8cF~RXMHX?iSY5=BqnkZ`Jlf>M=K>O5=PQ1Mb5xubknVm zlL1CDd!I1=ob&y7ZqdbL2uwyZyE7>ZY3+DaEdy2agHe;yKSPu?AcqKkVk)8A`RZrW z!;>OHV!;_GP@Pjl-}-;=#NysH6kmI9ge3Je(Fu@3Kjc*p&w#6`v0?X>idS$>W%c!= ziw46ZA_ie;rCqP9PEksVydB zcp8DJH{o%7ML&o^4W!5}ZVqIoKsyt)ukL?)dv#}JKq}l_I^bWBFa|y!ww)~{08T4( zo0FbSh*uiTO}4#7^(6Hj0c;&I0_Y7_D>FIUKOgHoNO$3h>n$qr&9xV#cJhZk>d z-~o!AO0GRR8I5qsWl#CVA3zz8pba}ybT+vF=P`0 z3VDoLZbF4)ir3u$p&>VO2dnc6SARt>bUdGwD^_Qw|IQl8y!_$A2dPdpLGq-4O?De1 z)(BvAaK|Aep4#$#`)9YGjcp3j5HV6$$5&m}Xc5ld$d4{qGl?)Lsth=RiG3GrkX7nm=c(aIX>L=JyV1P3*m1 z9IBaJW5dJ50wwZb*gi`DZCYlLmOCyU%6PsjK+U%QNLT|PTxQNHF*}P%spdGR-8ly)Y;h# zV=S%S%z3u48hJum5octf`y?0cV=hpVSG>N8Il8CXQ*;cS%K}1p1#4bgO zf+!^xL`4x1K~aPxDqF=G5EZdR0V#qV5J3|=C@7$SfY=ZhMN~i(Dd#uW!tD1u?>D~j zjq{zq&Nz;>DKhOB*;E~OdWH;}1W2<4D3UOn2ZZ~Y^G|83ju z^AZ0YooF7ug||7^7&|+E5B(%?DvS@8R&6i4Bz#14O7_p}I(EaWCn($1@7BQ6riaFy zAPy%U$U8m25R}QlyY}p;&}Og4v8RVzKX&ZcA|jtl>eWlq9=_{<`Nh~z>1QeHX(u(>wr6pa$$R&9ymQbnZ%;LgO4`&3XER?z z7=3wmdZlydegq!i;+$L7ITyC0)$*Ru5*gh5@iq(mI(&mqbI_R;Yj|^Jb7swe2 z;-np&e3Oi8E7=(;7qx64_(yVA-IBh#+Qm3siI66PrQa|0w5|O!*2RV5gKPa83~vVO z$4b%gw(~fTzf^{Fr_a(XHRjdoZJ$1ES_Wj{BO-(YGj^%J=Ot5%cHRXT=R$Mp*xeeH zHB;pFVq4leLYz8!07j`9uSQse&kX209fuyC{(ksc#2V({6Wc1+{z;R6(h78JW9=R` zyMKB>YfXb0z^!g^!xY`LtfKdt=QGvw4F;tL>rbcGq7(9mcj(U>zrH@Ixo|AF#>$9@ zb==A-5H!5325(uO8MY0%1$~pt<;%!qmSh~ewQUCkmD)sKj#?GnQvOQoNv6S@>S z@g=FR?uz0rSXFa=WP&40JLtC9ub4Dh=ob&51o}VizP47mOWpSS>|t*Il%?ks4+ zrv?rkY615+^j^$9A8b|noO3CdO?T@7jF2Po@$s#KS9Z{`y(@gSSckE~n_w?@{ZMcV zIpskijcbg;isp;Y_tqEJp=XY!U|`s$c4}yL@ZiB1mO>O!SIYElZbme{HXQ7SYovKT z+HKWNUH;<=7~RY=Qxdp0XGbeswrtxb&ODj%lz~fV&R>t~R=`YVy9fn8>r$~~T>+`H z-IkiG-we9=OI<}4i5szTpf{4TPOm=S#kYbK$hS&!VZ}fxe^(6f%`5ze;Pfj0JC(TA z1vc5F*yP&C)Eg|MIb@5H>I$%Dc|S|Cjvt|b=u*_?tFP|KE^fOuoW>1+`aL)*^M@@> zhMNNhIF8;)wLJF{toqFJA%Px+@Z4qjl+(~)Bw@!+Z9{6dvG@H$ZNL6HgN5tVb*xv_ zB58|(ji(LPp~M&F9Brh<7r!zB`9e_h)w9{YWj8+rPGHtP9!|g=B|V_dW7fr!Y-2VD z86caj1Y$(AzyK^H+OyWcWyJVnFrNF38T069XL>a;aH#H%NQQvbdGIg34HbutZDCi#zm&Dq2rIsx&$o(_yquPEdy%(n?80a)L$Ya`)$5#dJ2#O8He z#vQ761ixBBJR;l&!V~F|R6=}Q!~^PFvT2k?HVA->Dog$Pl~QU5iv+hK;X@OA?4`D( zEeOG0MQmn&lxmvR^J9<3Hok^52?N3D!bc4XC*xz^%Ysq8yt<;$0XB&h*U^BvMh=U* z!Rlj8MK-_!4A#*QV{eu3N762dRKet=|bVnG)*LN`bJ4bQ@V3G#G* zf0RI355MSo#brI-ac}PrML{~V<#sPGRPPdnC=ZI)2hpY5%-Y`DBi+T-w#nQD=| zsrV2`PWM)Sd#AmofxGKs>*>mx8^b)xX32_YIqHsa;s47H9o74h)Ty>Hd6u${pz6H| z)0T(}00K3iCmB%rIyN@h5DttIzUZ?dcvFb!UBY~ZNuVv=c&DWeZ`6NeY%s` z-aJO#p!^Y8Ro<+?mD#Z?fn2YJm)h(9a#;~u7X9_*ONiO4(tR2-3mxCzC-fIrI^Se@ zWf3)x0L#I*FuIySZ43A;oqc5Kuh%+zR%wX^`tHl+-G z;W}a%Q|prc=rMJ#$Abdicm4qHAi~;y%{5s3`L-}I*1b)vVF~nyjA-w0Llbc+#BJ~A ze`p?rW5?#BC8u22JE*~pA6b7wWkEwa_0aOiZzcoNX0k#8Xi4r^xviVwL#L!b3L)B@ zCPa}px4iV*algd+{xBk(oi(&ACeQ?`p1oe7uTNN@`nwh7{CUfkPWw$F8h~D2lGuVE z89x7`n%0Wm{OhfRW6`&?nN2hG%iKea8cpdXUA6F1>Lp76pqM=aZ2sgnkmzvp_kqci z9{%DFahE)&3M6ICveuS#jk0MZ2r)QsWd|Cb5c9LBp_i?7QKThtT!`e{)^5eO{*Wn= zfwnx9SDxM9Ip+z?hXEq0$z&grRMy`rJtbVuWark_xWMbR z++BqR-fPv~97{p9sU|PFHH)G@AkmPN7xBD$@Y+!+Ws-YU3vYq!z}~Juv7H?AqKFpA zV%9ot6JwQ!C+GhD>SCqooXQrEX=uniO!iI`LYK44^n8}9N~81Og!7y_(Ju1}R9csvHACg)M&AIIoA$TW}+ zz@#1BPLS?inCU>yR#jUXD@+XiWFbw3>-lZiFyfBT=u7#&PSUUWY)Ftc@g z+;rN+RR%$4RaX9T+9<(pjWK7l*ismFx(*rQtR+d0d2HtMYLEWBA!(5g^ccBY4W>1MJ`Hn!X%>3y zSp0#!4ovM6WEd|+1vhDX*9hLog%Hv(TS;w%p~EPr9feZ!vDJd%ZBk zmyi_BPH+>UQ8*bvr?$BeUyaG(t$6&b^ULN4A{kp@JcfQUd1x5*yGRaAThd#2}@q5)Za0;7-xoT}@eP7{dcT{J?xm&qSrtu%Pu=EslN zXF0)U?vCyY3Q0G+4JsIKx+HzCYIR{ z?jjklZp&IUzx>go;?W7VL1!qhv+2-m$~NxVv!`55(s(Wwo2mcVQUqyIEZ`GxA}cHF zlK0qNy(X|lnnO_&t6G?Ujn`{rDzjFGUa6z(%x13(8U3>^(D%g^rVbPqK#oKG2hBcv zTwr6FRp1|%Fg6FSQsj10?V-Ny_BbkO>((O5U-jncL^2*M?x(U2&FSc}=Y)$Mf-`$` zDvZmVfo3uw*9b~T8Ql-28}D?BzyA#Klhav0=F`IrBPSJAhXFMA0jU<~Tc~8#&y1X* zoZk97McfX;p5_SGapFHiGUQH2mF_YLfd!DY6+nA|SLM$XbtrwMlcYtR%P!c@?65EF z&h@D)6Jz~C8cR6mgU0{mZ7VNSeQfg`ySX(eq2a^(n&aMGRv}Tow{`UNX41P86^oR& z8#Zj{QvU(wm=?5R*}q7#Mn7D0KaY?k{S%R}0>WBTsQ&i*?=X!bNg*PsUFW0Ac4Fy?Ll zkt1!SW`$X0=+JglqxtOF4-cg&VQ5Ip1IZD2^}l2yLh+)Nw7~PH231;l zvy2@_7J=l_i5ffF>)6hs9DC>4X^c=XWE~eeDo)z)=6J;8(BdBM*QWu&d`>=**2Si% zu#YOv`p)?Zc^pKZ8Id!zeS)6u^7Zu{@MeTDabbYbvY6(NJw=NK@v?jktNH}ldCtM? zIV2v-mnJvA>{;hIR-MPIE~fr}Hx;Vs#Q$)af_T|eBowND3r11Pl=Zz?sD{rK`P*SC zJ+1g&ak(eyr8)V<@^$tUdSk6ry$6M2FP)PGV^yUg7l@xr649X3OflMmOkJ`mo>2Wy zhfUL=&aSzl`!{MW9Y2l>rh;WjdYV3?d{}4JByK{DdqrH^aa9dp!L)n#lb&MT#*n7- zCS9K3>u`?~00Wq*8#iv$X6m3E&5|_lvzBc}7$!rY9CKW4c}ll9+L3uGZUnH3c}qF~ zM<%Q};Z)+8(>we|2YA#iXEuM`e$zkS0Pf<2q9%V#GN97%PtU3t_wJ-e`D}q{T)I#9 z>2`B=^Ex~A(~Ywdi3CmMrhntkKXn~Ji_E6gr-8Qrwtf3hFq#jE|5&V^4uo0|iH8CF z4&QAm#G8-5{*7&T7EAQ4CuW!Ge+#;hdEO;-Jgpe2$njUT4f#!QJ-w4G)b*wOxar@` z;lh!|AuIy+MT5>4zIQ%d_@p%CB~vTqoR=C;?N6Ur&_xn#im$Bj>D_&TnqU6rw@;6C zxm<@hMwGcHQc>ay?37O*JH3%^X|*zQvcPk1i_s00ZiJ!i`PV;{?$x(%|K7c?N#t6mMQ#1{*MDwPN4v+-KG&K{?q)znOmbH0@7c=}4sM{l3xLc>>a!DIc)>))@BsR{L} zD60P=#jtJgorSuZA~HDq9`mBePPz;CrYRI7?m=LT=FGGlclCrEm+%1J5hOI~-vT=P ziE1V|O_5?aJyWew_-E9x^waq7b8^=IWbO{QmFF%W^F18}>F;9FhwQbJITgd)-u_YV z?$2^ZbxSX36?`X*^MZ`Vym<2D2tAKZ{rZjF@~~3>}+gVsgU z>cJ^gcgBqMG^BsReutm&Da_`E&X_Ugz*rx9uzY?7(E7IOxxMA&{!7w7JK3UPnP3{O zp9y}zFf1{+Qr^1z7+dU66Q0kP9uD!lb^E}!A`BuU^S*UMGUXfH*4{QV|KJKhNFitY z_7vG+DG=VKEh#$~yIF7{b(?ko-+5TEp1{FnvB8xSc!^t=w~`oXhyZ5F)dy*u8_HKaLDZE=>_z{wOKWy@hX&*V%^Zkx(2~NnqIY zzV@NvxYzWbXfrhb9}DiHyR<_zx_5VYUr8GFdlO!vP-tumZo18;i}4@-)7SN%==RSa z?)mSLL1FaYBZzPB{~l>B@BZ%*>6fYhJ#wAw^}jir{!eaiO>i4|rpJ$%KIZ_Vp=)Tv z&~3_70|EnkckkYv`Glq0bl#6O3<+4!GmtZ6=IUSwH=X*q|GHP;u0f+cnUx=EYXNYN zzfSo0;lsW(k8fD^_o#XT*|q4y^e1)h)TvWSYN{orR#8PoELD*ENa%*j1#vrr zg2DmAK@%izBu}y=OwbaY6<=ImS0}1y=7wxw$;ruKXPR9;4r&uY3AJ%-TO>)dRPp6x zoV0alP9cS=rEMOtYuAz6w~g1VUVSq=+nNu4{pFWt^xc5%kI}>1Z|d<4rmGv%4SDj# zix*!&MEAtTk-|aHYNNE`{-`ev@&?vXGjB|1&-by0UNlr!w@9hfc5rZDc-p`E0DMw% zz02}pA!R>Eu$E`&-CHRsJ0-!;p+ldSmk*Rb{`j#!m=m>goy}*5v(!`J7f-~E_t^-~ zkrqkcfSh~KyJD|3Ym}rJ)|A&cs2(Bm-@0<;%CF~MJ!7?rYb2lt$6&Ug@!$UA%A};E zRT!6PQB-iw$qs>uTJE_Jl%H=anIQumC??3f+3&p@I0WE0SqFvHRPiL>F}uzvS>LAn zm}p=bZ5MgAOS zcqB`{&45*dcp0-#&D#jdbTlVt77eMh6AnRqtpD4PBRxIets4TWW-87%VoHQhstkAU z+B#;)#7UDPDO?4Yp(c2F*h2AYgF4&u*mdYoEa0IuR;2%mJ9oq-QwksmH^Y!D0~t<- z#w%KC`kHf#mt+SqB_Zp~;NX*8h*x9jH(2f9E-W{jP zhY1f2h1xF4H~LA8as(eBp2}$jf*s4219hYX9ldORn$882U0h-~P^TR>hm!^5TRMDB zM6#*9{XtG;iiS#EipPT`D33@Gaq5c#1@gccTlKH26A(Ox+ndnZ*#G%WvxaPjP>2rR zfcKI1Nf)}Ue%uTc#h9wqBYIVLMs>g8+Vh-$Iv?0ED$Aj-uC66e-aXdBFxssckWAv$ zS-86oyAo%8l0LcA_MrL+%z2~wRpQSB&LiMo*@t=m-numtx>dEbYtVruHc2a(rJH?NRsS|#hG3DQo$1K8rDsGgDUmz6fOwnqhB3~RATV8^RW;Im5 zC?+4BJUM;W$nd3Y0EVLiSp$<2F9JyY=f}l16<8zwX4O;w^p#c#@=N=DEkYWt30OjF zXRD0#tat<=>@4@cicprxV~KqST%YsCzo$(^6}lC}z7-;jqz~lSx!)>%s4FvJ5+TNj zrN-T}`0fXv`ig^vc!LxqpO2aYpoMfKinS|-0WoXe0)BXbmkMAd-YFi9jqMaQYF>)3 zg^{uG4C;nzj!PR)pkDcM_P%}l1X=~?KMaT+_;TRmd=unQOv)p85Q<-J8#{^8KpN*6 z>S2IQiEU0N?k9YLy_qr-r9-OOFhImB;L=F*Ohx1Va<)|*kx0ycgebP2?v!6lAUt|X zmea-ROz9UJ-la8!TmUu+bbc|8XL9jj`MLnQA|tL-6ZH824vVDrG-L&z#ta3>N)5F8%k`~`c7e62hsZGiBY5ZHupK>DMK6G4b& zcGJ3zSpb?Q&@ZjW^4c^f>-lGLhoNp2vK(A&w}%(3R=j;aRfV;|n_9(x50CMNLL#ku zP!{YWxOc0;$k73loatnm+Ql6{JWiUtP|WC_lxLkg@e5czT!i~5#-C-a!XQZC0AW-B z%A4AC@ja)`P!3O!cLuNTCqPWApMLuF)b{Q^b>=h)??RjM*Bg;ZP@(VNMOe}iEpr;3 zp@4v~CP*GPzjfMzGr3i4f9yUPFE2Y=S?!Vj7`-KLcGKTZwQ+(uYDJi0kqZd%$w!zc># z#T3b5D&+xHLmSz}q(3CgT7!uhhk_;;gA0q|+Mt!Pm8jmCf2$Dc^X11#EVxZm5kxCB z{GhqrXOvGskW<~+viEH5W|#;Uc%>hLASs%E?nbErrSQ-~9!D5>fgvHw3|jimRp*>q zdq8>R2KRJ1eKvQgECg(;*JNt9!QFlK#0e8#(t-`M)S1g0#sHv{KtDX>>9@JLxk8Mg zELdhxG*m2KL?-~TL_P?gSw6@Xzz1c+Wi6G^H5n43PSAbbo)b1pa!wru(KZRUeE~C0 zeDp-gZ*c8jjIJTp<=XoC+3qR5NuIJGMy9qp_An3pSiCocp5dP2h^6y4Fd*-|G|All zmIhoTPk|kZpE4j-{NMl~-AaF2sj$AGMn|J~DM1{G(2c=Pb5Ch6{Sz0@E0sVb?ohmD z*WNAT9`AA9pQFK*7<`@uuLPAz;cq3#hy(3_mBIf`T3T>nCSNsk;leOwK!0gt>0ud{ z+3DJz;NU34(tmdsvnXXR_5Q43NxCT1AHH_KX@Y3}sy0q+w2n)`~F9ug!L~ky; ziu?7Mf>lGpu%0w&lTsrwhg*OGTJn9Y6CEufpBOu7q!QW1I`QxB(pLQES;v#e2vt;i zNnw^T9II>!pm@@;>eR~E}jlh6_1qMZt$qlv6tve09>OScdq?cDbDyww=e;zePi*i`{PduYZxJI7j z98e1PrfK~_JE%m`5JVIFX6v1Wne%rTzWBZ+5>yyLi-?e6x?2VUx=W#}^dFjf^X6TW z=Hl;tB7Ux|d{@E3%*u${s(zi9&tcRtiR_inl=CG*w=i}J$dxVtz0#lr!{Kb8tloPB z(76Vow?z*MABgF3p2{ML1sJBAQYKU+dWj}Z8hPyDLb%H8q-gbviUWSR{kE>k14`|AB8@IS4rQ7dU*wVZ> zPmv1F(Vz)~cA*!Yd`}g%j$V~+8Z@oUPN6bqg>UMp)8oSqnh){MbKm*QA?+rCS!i(d zssw|eWSh0@FVfyOaSF&013#t*ILPA#O+my$xV(@4c4d1?!a*fWcbX$zPRTQs{ua3) zfkHepI67nL@L*g#|0GmrYKOkXxg;^Gby;J1VgYEA}uE=fd-Y>1k;X@F%f- z(dTue2fks399_V@^(Nz$6qkKYsi81c7xTWU9DgAmNjlzB21C0bZr^dAw3%b?7tv|g zS#!9iJA2~;${860?p9i844?2v5&tk_d+*6vy-C*sn2&kCo$=FfoUzmafS=`qE4qtJ~YI+bVh2`T@qpwr_5d9iN!8Q z=VO{rBL^pbB>|Pu?M3NLpc)opzL1SRlgT2JL7Zik zP4Q;plyi-g8nA}Hn?*vdj;kK0=%xmBz5q$@Jc@?7T!VRJzqVXYVlDxMX^`!{$!;A8 z6Z!TqV4dC`D0HsfxY40uijk3Shzdz9ddOD<1_d$hw`GdZsNSEZb)MGU5^Su9iQ>_( zqv!sSh0<^A*h2S7`v}>llySNt=1VyybidK1)R#WOg%%-GNVCgR)%CmUX3Azc|B-9F zprxoEVm1JVn{?rntrFKjOEKxIVMH}2K8u;1JSv6ONxDh|DMGy=Nv^R|&PL*rYVT=k z3F%J!yh?F-(`uD>j<=8hW%HJaBDzhg45#rd9Hnh&HhF-hw^D!$fM#q4~@%R zx?k66yV8H&ia5hYJl2%5nKvQ-kN?-wtR&N&@E$yhWHL{z)(HhJrq%+KXHFl~CmA`QRZTu|nboJ}IydS(K6t?*__e}+~ z^-Cm+(rwtCcu;9mze@VZY{2}Aq;?Xf?S4xC{tuCRtZ%IMz8K-lXN)n>8a^KPg5ZJhkUb4=9Px6U8o#!DAHvs0T>q}%vM{_TFfz0Dd2wyBnPPUKrzy}B5# zsA~c=cON6ivZ?*y(P?2hDL3>0tuPX2sgNgwt-N3DK9`r$k3AdsjdfH7|5-KFH(iOZ zOm50D=f{^8DR!T8$Wu*{zvdD&@H94YvfVZ6(9zk|=WM*E2L#SMG)nPNt-bQ?r7bVz zxA8XB)i&o+?W5~zqq!-bMTi1y)we48rD)RxOi#@~kHCu!nCWn6W@6bjZi+M`C7tE-#LUY3omKhi4xT>VP48z=|27jJ z2G2~t9sSU8kd?PLzwO)qk!|^_^n!(FlzfntKjK?;O4+M^P2z8(&Xkm%;WlPfR|ooC zDp-EsY59GfwJU0}n{NB#)98$@Q69NlXF7CW`>55yQJ<1<#Z(JPDnIv2q2WQotMpHe z|F|-)Uw6!}PjKLzuCR7np1(@Bdzu?O?|!ig;Z`V|-D)a_-n!tlkZ)M? z>cy-#xt00N;tiIo`#8~m6&4oq<^}5pWK~~kPIWc^d!MOmJUptsrdqpa)@(?M|5*Cj zV|!M87=>K>hS4KNmcydud#2HADe((UaJd4tj# z_Z@Mdz0B7)!!+g9UGALkeV2Ee9%=V(-u-M&f!DKJTYipd2JYv@C{PO9Wq!qsGHjZ& zG797ag9_-&vuvKfatgFE?CDWs=x}P@!~(a}fnmBj&v@;Hg${rOgW2PL{paQj&*wLN zRZ(^aoW2bHnmc;*=m8&Vqc8R?`>Vh~eRQ*;EQ-h{?I_kM*ypYtJ9ZE>5n*D`91=Ty3_ zx*?ATQ%Q%li+TBA&hk~Iaei#gN0m+O+4zp0V=Paomj5m+uQVnxEu+Kwlf`#V zy3MgksC549bze2~qycGYSjgjVK|#)z6&?%9uSqE6J?u&D{WvD^)lk00s?gG^qQ7uY zw0N{nDX^`7;Gk_<{rexf!l@A^fN&DL>G$thuRq=MQKi+GMvvm+S81`WFCU8jI9=7? zGR?*UjW>S}pq$3X-n950p5Cvps~;BfI*8M>`J?{Bsr*3D1sD8obEs48PV@g%0v zL5)!P&KF(*BqlNtPw>w@$#%2c3eTi~%mkahHv&F=TV1Wk@1!#XIL#I&i~|67cH~#E z^;0JH8aJ|un)EY-CRLB8Rbm(9HOdW_Fu!mr;NU0r<+PG?yWMb;jQi%;0RzlHj-A=b zN1Q<}Kuabo4C!CCMjfkWO{a_$S1)$S4(;1JDz{GQ1tj?J?Y7=UUw@Q7l>jHQn-^A6 zMw}>W`NXIvzjtr`DF+(&FD6=Cmxfk!fnXk+IMwxDzI=H-$7$(@%(Vu6q)Qq|hEP&? z=~7<7yvJYrdBIe0Pc$rqXCCLBYNWoA5syeFO#0>oMyh_eUx5?2;T>Fxc_W{ubE#@~3!xW?e=cv%FXP#O zmMO@3fc)Lp%eI{_1<1B8#Y4xxXc|0({$2^!%G2a52k>An7wqTfXYXA-lgqs*DhdTq zpf=hC09ryq%_c183}uRqox#x@W^LBVx#RiMA2c<~t>R{U_TKqZT}BcYO{zO^sxMiD|U#~xhtpX|HE zXhTCi3RKQ62R1rKqhBdfcBTLKrC4*(D&)S5rN@MliyQxrNqFbLNK4d$qy9$ny(cg`MV}w=8a7 zn$|hYlqZOVGvkQb@L(90^#(_?#Jh!dVdC?H;Hjq}+AcNtOfhYlpz7az)4y)l?%neY zRJ6Yz#TQe@b9`(spwC4BW!&2d&I=H^HqW^i{TWr$47O;Hk9QcK#XtXIG$gwFqL}21 zqcdA_vWgGS#-^K?VXdF^XUbh_9sdbGT*Fy&Zz!$r3=C zM@y1t8qS*V#+MR4ailU}=wfu2sY!PVrQJrm2Mob?=tzXTfUw{A;EM)uC~M+n`pQdm zaju5y^o*DpwFiIP0kTPJJHAD|_mUi*CD23IYA^wLGPI*R_?t2dedj_F2r5=kjKsRO zW5whKjT+@#RQ!9H z?OQ@YLgi`^N8a}r|1~+Kq`*DX1Y8lM-H0Mk>}8F^bL4c=YPv)`@gDFkOEDJxMD5W3 zfrSJ_`7@T`d@z#LQcEL$)s@!zdrePm6;Tx14BzI+bkeCz+7N{o+Pw}A>y9|bBwTk- znJLaeqOIfDW-rD7^i)@xhmRgTlFN~!)Us-98=7}igg&cnC*!B+`btESuv(TGJR2k$ z3+am+T{6jQ6e~_?$B|$<&|Nr1?@=fOnc5zfWNIhy=g3$~{y82Ao;7%aNqEkEAc0VD zMcEIzS)sHW6*RR?6xkfJ5`zWJEUzz5!LXEOxp>wZTyqiA8|i_8?T1-5SfJ@22E-T_ z|L^BxkzZ@HYd2rncP>~XnuX`KVB=Y)OI zY;vD4%d#}s9H7~!%37xr9_6HjP&WHgD%x!RK882F&otpSo%1Ns0dW~V<3gWaIwHN7 z|8WV&2eB&~Ug7@LVR!X9Gg}G*%3(p7;=h}zy^+(7|D+%;;i0AJ(XCrIIZ=dQz91sF zmF&o{4kDNXbcw=9QaP_iT5rmLpBcbR1XPY6x?D=76F9!&V&!(B3K*4tU>^Y%PaBAd zl==~4LfJHH5wW|gCZ)C{?ZnF{WMW+RI!_6g{w}(Si21;u?E=H8dG|n6oF1x|ozC+sYDR?w*TNn^NYY@nD;0F=v$KW|*b4ccfvk z=YrPX=EC8oepb=Q*7vbLw`Ar(n!@5H!$#$QegvbU$a)<4hv#06-2+4DT}SNI9%9@?Mn!Gmdy>09K0 z7(h#LL)%qt^3^xrWRhBm*0^WOeaz}Coa!_|0o#duWR^bfYEsO{dfrvkqeNS2HfUHwsHwe|KyNu!gya0j7Moy|+DA7P>E&P)g0eVdx;is?PzReih_+h$G z(r|(!5p3gd28EVb@o{<#>c=1bVx4A^uS?H}PTCAnGDUYn#K6AMDbfGR*B38dtm4%4 zCBWcvF5!XG#op`INyAP2e(}?BeeY!jFY5%J{)>TQrD1Zw?JHAqO4S3PTp8&3d;LFn zgJTBAY0!F{O7jP#_>b`}W0Q*Z%&B7KlKhYNVxui8FJD%pXKNyJltX>w?VzD>zuB2t zh}0$1o&E?;aq+eVFV^)Yjopt^x6rQeK^XP?-EL%CFQri>?h+&&{n4<#a*oGz$aNXV zgB7l;+=goK@lHMFMmQUODHVtuTJTRD5Dm>sB+U6rK*ctO~bP!Ignx5Fx!VY2I*j^2`0sF{eIuChTUD>AT|?Bq=V>>*4U3Z|So{ z3_3XbnvLjZCfYsdnL?#N9lb|cYR3`~`VG}XEC97LS|ZZ~lkEGP);AFXfh8&#@!la* z3Q(4cxtutPh{y}g0z&AyZAK1uUj=k=S6tbUpEwhE7%uC5kEY!rY2fwU6U74J7fDWV zSUENS)LLAf%peV%g~B4NJ$COIuCA{8l3%pe&p#}*1o&`wiQB^#-2B$B>w_ki>2{8p z{7fiJKy%q>*$zt;Mp7Nx;6<__r7#JfJADz85QCYV0KQdoYr%Lp_@cKLAGKPXgGJQD zUJ+&Npi9YwLhdfvolG4J{er*gwB=iHmTW@@vUX;Wi3oP-yADH7GqiEJa(n0DOOLY? zh)hQ^24-D~`$+ojhv0{Wvi+AJNB82~Zt*D`^}}bb`nGKlBg&r~Y>z)P1d31J8iJ)x$sW#iwW?YFwOJU)k~{3L;oGDl~v z@IYZC46Li?RllvNv4&(Rge{gNcF(+DgXbPDo~!QYWs83y2gn_w#Up+)>~=CBVlPI& zT6jZ%?9*y1mlaY7SClYtdzvX~~NH`p`!L2sVB(Wo1iYdUU@_&LFBe@g-x5$jMBMqGM-^mSzo2 z$8OAZx%g(1_+XHOUnIjGV%(mD@n6*Kpb) zH9g(h#YHD>!Kj;3;gXzAoIKg#o%1J%CP1IpFVss&175kY(k{RtClapvuS08aSa-YT z`OEX@j1}v=y`@k@JygU+Z{6Aj(~^+gDj-3h+CkmQT#fw}1 zjTn31)TPGXB`ih#*=Yd)Lxdz|vq`1n=|$po$9 z#cg0DC%X%@Zp|EK6TQS@cpMiioxCs~gD`w*e6OZ2Cy zeRz=?apF(Ljl|mTIA#c+I-v?hUGM*0t;+7E*He@VwMJO;baVlk#Kx9z7!;@9xaR2; z(Tmc;5l;U5d_-tH1RGfFGH}$WIdr%P|AD=viK1rOZ%WNC^r=`w$p`WMQxrCHfhWaDuF#AFO%%)&d6^HcvUZmgHZGLFuWOHFdK_T!-KyN2gMh2+c#bPUuZr z>yq^IU4Mt1iUUkE*Ak5})Xy`2ka|FXM6xcM)nY~-+HMBd2nY)D%oVBmpX|r6vfH71 z7iKi~35f8^;e0HVLtR#8<%|d!H&cnn#5r*Mu)p zMHwLX0m~L*$zsT{wXj`#O$98H<4Ch;RJ1tM%r-Dw9j#*q!K_0=Doz_3Kl~8JJnOLb z?uTj2NAQzyFb8E=uuBGILlk$G8hCJ!vEN~u$T zFb#MkAzKQ>r|p}zhA_Btzi(~ea0?@<7>0I^_#d)h&p{Q*T{?VJ)`cC6jg*|yaH9K7 zDQo%Z!VHyA<>bOJae*fSLbf!;#~T96Fir21uWsu;HhIGDY1TL&Tv~)ls7QgBRJa=~ z^0TPf4OJUwpovmAqcagB%^&1R$&SS1L{bvcDk=SiYZE@NhEoNj7)EU-IQ*d5n_Gm$ zKEG76j?d;)PMn421`d^^zf6+AMO?^(^Q_=!NgMJ0?`_8%r}g&mh=y;XS~@mvpBPOD z(hFxzPE6RT$2UuTNdjUmu+lr z1lECt$%Gb{g@O3i-Hhvg)=%iZl(tL2)lrv}?bNf~%oU*z%Fgjg1z=MQV@6dvaK)te zPq(h?*)*Da&Vkziacd12U@R=;@j9kjT3TXY8ph%}4t1f;sI4BonX;;nQTOstE+mPj{|T#1C&``lk>Q!;^$66Z6m=dC%U`h0Iu!^39%wxy(LD!xdV5}-hzu3o6Wq73QxB7$489L}}fu!JR zauKSz2riOTTM|{|#<_St_EO>q4W01KLm(&xQHTQS5CoHf3puT$$Zl%Re*X~{A88H( zB|f}aOK`Mg3bRn{VqEO8F-c%q6=i)^*60C4MI2QTIJoXB+@2mNWH`hZcllX3`!K*4 zG;xealwiu4lqSGXQwg4`$+k4-mj(6I^ij+{FAPco{`h15v8Gw|eh0SBDj5x3mR(O$ zssEOMMv6Y?$|LVPTcGZU8y_~osdFn_*XU7;S&E(jfeJ|2X*OGtEs0U*uD`NLTav0W zjNt6sB|NOeU_#}-%8TioFmp6%f?&LMo&`l`jm#9Yx~ofyEq4~4k&vXaHNiB|U^1J7 zH85pTPlQlF+(l2_HVLSYSB+5Fl^Re~O=JpC;+lky$Y%|UvyFI0yKfPx8fG+hjFhh6 zpG}zO0FpB~G=3OvTAA8G?^!2*8JsDc_UM zbcAOOwq(ax!~<7ow^p1{J?Q1eQ=WxS%%ez+iS+JpPL=?{I>a)zqe@3 zG@G?0XZ5U!+EF`<%(Y{NIj*o7)nkg?FRds1`lH=X6J~t*Rqri7AMdmCu7!`!`<`ue zrY`KW@0RDn)t6$nmhQe?vo5{(fT3GLO@4mMll|Bm1%eV1ZKnCx!b2-1U<;06GI_;u zrXom=>Nzz-$H~GX^%Aqr;O$Rc*_|&y2e+HI`<)L-P9`zN@br_BkdsBakUMCKSm*M8 zFKG_zX*SWgs4$-vE|tA=l`wyOnIHJ+@UAqV$>SE;y-M@QTi%#Qr0l)X6YsK#P6ZE_}=vSXV+ zkWgaKD{T+=fF7E-5K-A@!kAxZpz1R#6r=O`kaU~11`SGg5BRwV)-|g}N54M4mkx(_ zi&Ye9np~CvlwfZ({EH6Rw0K1xwr{)nF_wwI8FHet=C!#EAh@dY95s(lAruQMgS~e# zUpIb{9QCd+J4PB)N}wO6d8H2!jBJyELZF?tSz@SaD&IJHi@P2jcGwH*xNQC`%&?<5 z&mol7>>c!-1f(5fYHzx?3hqimD2fnav}q~KdPwZxWmo!80_@waTIuf_#mkL+Eq-F_ ze!OC7Yk}IBOx8ZdsIuwbT-LeY{Q2|s7T)fdIi%}TUSbBPwA#9ze(fW*4*?nTkEPrn zXqk#-&W($Y0=+!}88gil9mlEOo$t>gfCQBrta(L8II^ICCO5QpNvXg)37eneD2h?G z0cFSDyP^N1C>5`J9&JG1zLR;R<0>-nw9TA+VZ%((<2uY$IrqWjhQjxI*1lLf?P)|R zSm54^(Ny?~`bwNpiMhI(*A|kGm_Y<6=(qAGv3*m`y17Ae$|!B2p_)f0P{&DIMOOCq zd1v6Z^YRK$w0zB3CzIq3=G295DU~n*I_z~-(M|F0NUrd;#{J54C_6uKKW#(ReDa@{ zK(Ig<`_gJ-Xepav7{J{A?dD6xeSVJV)G=fpVA~4~CY?eAbD>4dFd6R&d&yM7=4<3p#5_KKWd>xN81z<|L$;J7)pbqPz9nZ+m;N-EVr*4vu^d+> zJb7hD#k5X$@Q+X^zE&Nb>Nt05ZN1(~GKdTw++^~pJhIKQ^Oj!dIXL@ydcytZ8%sI( zDTWryUhY)1wd|F0ix)|?T=YynPLc?Vo;d&QK6N4C>nv(%(Fk)p%culY8la9I5;thL z6}Tz2Glxuj&gG;}ZF1veDmoK6a8rKDf%9I%WWRATPXY&#n^o5X56%^d&SLc< zD^7lbFyN?)=RQORgf4$G;#4)Fp?<2BJuekWhOmIRjpflY4zqjGo^gO%qD5?&w zdC`gQ>%|gb0EaU5nB))0FSzr~_@7Qx(V7_Y@6y zIz;p6Rhba7=L}hT9)APLZ}>JjM67`muJ&)A8$@cZ5}NscAB=HujBUa^1Bnm(%g;ya zTHy9`nNJ==53y*G!KORl67r7sZ-4TwO8E4phYg47sbi+a?xvtMn>=5u?-pe)Q;673A+&DRv*F?=U-(b^tq+E|qgjM6_ z8|7lCT{KB%a^A3lS7^#bcQk**RhmC?3c~*FP3e>V9wMx4!TjP@(wq zEutjh-81G>O-S z>);Vi2AfO>2+jt-jZZM@>5R{{h-7ZK?v*Y(bSOerBzdYF4=#^FnqxT6V1%Mn9kh22 z0%amK9Ikog2uf@zhV$Jaq;s^4BnA9bZjCs%cl%s98Y=%!?YwTtuudtzv?7hQp^wfN z91*PnJ(Iovj9xVv>%Q%eU0<0+jmETEynhM7{Xk{dY`AQ{gKKn0mnRn&6)mQI4x6vg z`eK+0M7}AiXH^#1%5P=$(+qjr*fB6Nt}*Kid#_p0tc<0f#U!6+Flm>FE(nucI!*In zQx`b?prpFkR`a+r-vUXB`P0#-!h!+=i{36*7=wo~OzfgHXvX5Je9^;ckd@tXZ{=5c z)n`Yue1G+{`k-3fzge#1L?<0Pv+Asm$q+Q4@3Sh!JV={kp!rPdojY>sva;jEw^>e) z+HGy(5V|JQq7e5-xeun7;8#SN=Mm>psJ4T*n;nxk#QTun1IkoWRLO+_$buS@qvPp) zg@LhU4pEF~iiK}k4Z+(VYe;EIHScUN=?L}aQiI`K74{ zhyZ?aSSYH0@tv$=ivu_GpH6GL;iC9mW)3+pkC2;^lEMaJcChYt;#<#U`JRebO#`s| z7$8N0`}`$Sm*t+m=DwkMJ@Zf0A0{_kJ6k7J?3TV-=646K2kJC8_~8>5pyDHYylMj- z!fz^?oMmh?WxY)1bIIqKPiPCX-CKoY@+RgA#y7Mys@LeoS19tEp0(Fhk`s~tYF=4v zdv4c9f@N6G5rI*UJg3T8eLUVuNkHXJDL_n)*N)|rm&TB zNag6r=Z*A^hH9!9aV{ed+7DuXTA91}^uhHNcIT{=${Is53{eX!JLn^e!yXnf#^->b;Gaf-(<-)vGt|?3^P9 zpGRiDj7;cyO+6qMB8B);|4H&^uAauKWbenRDK~Ep*8D4imiF`Av9WLQM33Cq+xkw? zlo+pVqnnn%<;H~eB~$r9R=+#0_ieX7-;(`u%f7DVR`HV_TTN1@#ZxUn5O|%)$clHX zey!cVe=2irO|5rOd`(zA}7cF^YE4T_jfusO?UN+JCEhj)DdXs z1oL!_SBaet5^i24X~%dKs-y4Y{iBoY`y{yBxt&>ISX0o|`eKbNmF=IV5G$Ro?>Jxgp5-RH>SSw9R{&?0Uxn+h z{OqtgGin5B)CBKhw;l$PBV#}*HADFB9ZjD_*4XWa4zt+cw@gkV&2S9=YuJ*kn@7X@ z$G_EWT87!rZ=Fjanh-;^rYBcGWa0dZeaem--#$0!mwvHR%(}v6BQ`m&uare|ILl_h z4-~mX_wd*6-dT~Rk-F~$?ZpygIWGsMg*Xb3rOF9VjPQyg=gOKIPi66+y=WJVzA&Zf zYl~k3Auv;)rDS{Hmv9_@AWt?Ca+96cdq*O-P+uT~DYpUftv&!36h9@Y#Sn1I>D z%D?YB2V(03O*xTlBVG=H8a6D;-T!B?ED>@wyBZ9;+dqE)_=lz{Y8`n&#>6|u@*0Y)0d0m*e0#kQE1BYuiwVPc^gw(6~ z*OUAw`^C^n2#VXEn8|E{O-Txv(yp}f)$7|FynfgHbH9~-AKM&$Zu>EBx`gvk7oq(2 z9Fa91Yxo~FPMe-s!5Eq0Aw%vZJ|&S3my)@$XV0D;1@;mVFyJe!J|K7WMx^e~f6x^aw-5Qs_JfP6|GB_;A_f*%}x^cD@E z9U8t|$!C3=^4YtHOlTY3CcEf_eT4GM;YzF_Zd^3DO=Cn|hxYkl>$PR4AI<~WrWdwH z*>QL$R3P9A&W!HrwQHRL`^c+&3(yG1l1sd3q(Aj(AF<(0)EsZtjCAiJRzHl?t-V9$ z16PxfkH^h*Drqiny^oN-IR{D7N%$xiCbnBa_hGWh&j(?W!$Tx)n0>8IBwha?M~`JV zCqO$X2wQ<1DR2I(wba=8kt@~$JY0Ye!W13?gDTEz6CYMVT|>Ic19Z!bZ~gdz8;lbc z;=o30*>0xD5W5zlFiWyTXVf2G;2^{!Gr{>bd`Xvz$%pm;Z$q{Wh%JmQz>K zTzOR9d2G_fi}-_&j$2>wps?YcJ|wM7_g0BwVk%bq%8u7Wgg|eH7?)`0^CIbDe%$CG zz>%s(Z&UWS`b*E~WrcEJn(+O&Qy{Jbd2RLf0w~Q{A6rR%tCG0Q*O)`+pqe8*nelOO zqx#`MLjg>l7z_!@(V&H|#q;FCK7zix$~im+~n0XbQacRmBXdK=nw z((h9uk#HHAaAn@9m)OEu!Z}iEUCRIqO7`=^GyG+2!wNQS*szEbTxeSj)(p*&;}FIAWk}ig5W%qnPjqb7 zaDfGO!SP?~SyAex+xzW9=hz6MJC$)L6|~qm6r5Ik+MXT&{|e;AZ_DK*Mi{5e{NY=# zQ9%8mA~zq+DdIfSyFfE!zPpIxO)%0jsD_c2;oEto4-j+7>*tlu8t{XVNnL6xw0(C! z>oQVS#zwYtl}{Bt6VUSg+vf)o-G%f2;M$C-7g`v4jYFHmj(#GzxC<452k^ zhF)dR#;BNHKzhcLPu&u$-aQUpb^0t%VFyzV3H5NE!FKU;!=!lRwHfhe<+M(gw*kdI zDs`bp5&Mv{O_dH?i030YGa_?{DqR})FAiai0+6&4L|XG&AG~`WNMOIAb&5#2SV)AJ zR+2t0;5$V3Lb1Jk$MH&I_-{O-HaI7F;A{`Dw3jI`Iumc!u(25-02;;lmHoh)|iO071s zj7p&V>x}*CD928TxM0ro%w%(+|I7Jk2#;`Z9HJR+_QTeR;u4x?;xJh$NRn_cVR3YL z{{bVr8642!cd!x&QI0!63lfbPXOHX_Uwk3j3a`31H}0cGj9YiVkIV(d!P_h=U;!(P zJ>*1?%oSUt%g43}O5(Q|^_>*TvPJ9aX1X8vnXPCszjQIQc|VElNepO2_r5IW zH#U&}1*c{W-MX}|I`Dr&JRY4}2C0A-UZ-AqHZgFQHFfNH%umDOZNFPun7#(;9a3oqwnbmU0iB*@7-b%X9 zv%%F!ts}33IT1JkqB`9mmOIUc&Evz#?93kkZ|D`IvR@a-Hb6Y!RIu5VT5Z zs^#;Tq?4#v3VJ(O^O`o~!n42p)!gLJm;(0zJkzggarT>Im*0#cnfye+SX-CQo2keI zIm1_U-}dGmJOufNdMu}oG4~=4RN@0C6M;>+dfxQTg2D=Aen{pbnG%ZEz2t0wty`sk zYD{y6>Yl?dhXyVCQYV6ZQMe>hVjQje36}XeY#UU$zDBr~WthLLg=)?#=idA~>rDd} zdbd|!=h1OOw$S5M!oCy}QGc*})1Nh1e6eLYK&RMf(K+0*s*d5T<|&B(R`lsspIUj9 zP65(`EYkTc^&#V!TM}O2$)3*{>C#@Ee4or+7FQ5P=`N;JVMnLbdOxjV2^c1SI0+L} z?6f5MW}5e^Kj=g+QOCyC^Q_oN0fgZzr>8lQ&;SZ-4yJ@YXuiRAOC%9<szZ*Vn^$xOPs7E_-ke{cB83sFngo$5qj z#BQdYybmUE!h|!2csp{`mAELqee-6B=8XvXI5dh#UKn@3TVfqTCZXjpY42A*X(zT_ zwomu9{qoDZ;8gUvKS_|?otyB=+C3$E_a@U{ba%GBK4 z+}@8WAkA2lRtR-GafcoBBgx_Ejs%XF3@@!42ZIU4;_BKpL(}UsrMuv3xxvrRnOG`a zFlsaZqL*acfS~e!T~f{Vv(OK-xE8~ooD{yG3g(v(DyYBa4j8|<3szY|_(3kwDb);h zzBq*tI;w7N{_cwi_*%H)Rk5)nc}}0TUBrvD)%9vGjkPT#qPORmZx|#Nql#&B3ClLv zezVgKwgow^rcHrbrxtJk#;tU*$nE8zyX$UpBpTVOI9;A<6?#JCL`iGoMAa4jB;&@7 z$vzwjCsv=G>kBUpLx$|k{F4)aq}CoV86uH5=1Uv>`p|$rY|bn-J0NS$qu+9>D?@pY z-J*&-wo|GiTbJH_(uWtpZ3=Yp&%mSh(Er8TnFr*Ye{a7Tvl?T*V_&kbA;wZs!i-(Y zlFHIz52a|crehjT5~m=8z=VFWa`Yn-Zmdn7TYd*gJ^* z46EronMW)Lt{f7Wxed;A^j=LcW(Nqxl%}LxkS_>$uFQ$B6EHu^nXex+RGfMs9?5Tt zd7;0HB}_)K`Vzzr&D0;~>xGBN6w00vSI;qBa1Nm<|E-nMQ)d~7va3a<){zZ8;Mz#| z5>+E4yXWiKLmlLFE`Bwwq`V0~n=)j#HYOJ+!o$gxuth>?AFX9-)nB)NZvtm~gtkc( z?C)GxLkD?liO;5s$XE&&f^TCdij%Z>8s46BfA|c#=t!6tG#(d+Oa^WRrWLDouHI}T zNkuu~2?@XLK@&C+XJnR0Y*${qcrjmZgn1$Bu7>W{*i}_fAYNeV6}fgjYgbfdIjF~t zpAJP(+{M6A$8EQtqbh`#@7cTe0lBH91XB71mBhWg^hl}>x|?ee%reqH8WLadODrcD zM^Rc(i45?`HkIekX53U3Ywp8RlaEGpOlAg%l1BPwbO;4UUW*WBjQNUYH{?Y3%d%!1 zFbJ9wKf!S-nwb6Eg48MpK-D8oNi>_)m`8Zbf^<}Tu`wz#QVbSp8IOIbv{g9yj_oE4 z2eBp>DPQBl&pK1F!+y~~GKNEP)cfM^%?Qq)zr-R`Xd{%Or(cB~8gpcu_g zXUF)}K@HWfuZY6s^Oa`vU5K|f>{p5VGGPO=HPA=Qx)?PvWALEq_0hyUkA01U&A|y} z6+G~&5n*Q(5ktylt5K@zY2IeiS;Q3e*1NT!FYWgH{=4XkifWK(TF*W{s%=g-i4xH_ zOU8C6@k?V%e2uBdCcl0A3ClWB2&z1gHM{~c!;MK-IiVy?6{;cMBl4LP3LBCNd{4@( zH>8doO?B4NCgZU>JS-p=z6TW{oOdF;CjSY2{>l_V0I~I8;Yn3B1n99({nl;U%HabD zcTk3T^l4(oIDOPT3l5bzm>A-3KLq*dk)TNWvBedY@RqRyju36PKK}rt*zl5X&Q<`! zYO^LoaCieYA`>T4h+%V6`FBE~(%n6wCTofKu-tg-vI#Vak2=QrCYc=bo4=HJvCXK8$dF#OT>QFq^Gz(E&duRmNKA)!(9KgCWczMxPOF zb^>&XBq4|JJz-$bk)iWq!#T_Js)yk=(3KFr5ANqYAwg281sY9z)dARjJ|81AZ_Xf% z2@@=7@w+q{zNeH4CIHCKALyyCzWBtiB@mypFyxf{c12e?jR_wMEjoYu?J1adVk+sj zIl&qtoh)$5i`!*ADL1nHnj5UI&aGaWKE8S{UeVVZ3a3w<+I6-94I3OLIUuvS8AHRb zoaQdY_HeGOf5dH9W@PPS@F$aTJITG^iWt-0^Er{*$dMDm5TwvK6#YMpx4CM0lEPT1 zo&djC#8MSwC43N)7|zSrrO$|gk)gsGGy)Asgwmk;7NW*m1{{G4h&w|4q0LQ)Nu=3GH50ZqUBAa-Ke%c_N$XalW&;4 z%RQgbEWu=X`l)N$E(y@YdTtF5d)?nBuN(SNg-sO?xDG5W-a^ONIozA%#q!8+rhS!- zO8j~mxm9)v-G63kIf&y_=w_B%iaJ6kDFtRjglsO$5sLpgZW5&U92d)oiZSa$Y184V8gS~6G>e7 zBDBI50RzqFiy*V`Sci`o5zcS&9m36}8vFT4znjky5aH#7u%kTABBRYxpUrRGXW2** z>mL}H0!YenC}!b&M;HUc)c-zDT3<+^X@GNrt#dxrV)WAMKY8z{gw!stpTp4vmK~>Y zY~R}BD3AY`nnsqo7#2mQI@>vuuk6^f=Z*|(L6=F*g)7eAcZ^M^ogz}w0e1`Cn`a65 z@z_a58A!=FAlB(7YYyNvBa>Ru?R$-C8aw=_iRkgnvm3W6`S3ixH4`W5jCpxkz|5_E zX3Hh9RGOby!xbR=KFZYdp=SRaGJy^}Jr3(&p0cz&C}Pf#t&i2Lg~WJ~u|GL`q0uh@ z2LM^=&(IgpsKSSiE`?;UVq=)lOn+9$mUse=?UV45^&@h3oTi3e3znX#b_*v4?aPcA z5o#Xmp~NKB@(xqtvM@LQao{7g0!Rwyk^oYV`M-PrzF6nbnK$q1>O@GKH5qW`a;aKU_@;v7F0K@uiHz_)M}a7~ zu?I_OYN6$|0DrcQJ$m?9`C(3RX`9=UZhNsK4L}*LQ)~do&inep<}jXOr&xiAnM;XX zJ1LuI=}dyr_dxW*Ved;OHZS5Hol%<*XhlloH*enjkj?pH!;9Hk78JO9d`U}mP*a=X zAw*$PuFmg{lW>X3C773hYLW^{UI5d|`|W-PY^Asn{x~^j+dJ`?{mCRNMZCtvnf257 zK52AFPD?8EXuB^*U=U}JwvOJCVV*zixj*1rfR7U=PV~)|8iz!~v=*{!20A=vDH=MG z3dT9p4_}e1;`gbIoktLcSYq+mns$qvz-J4~pRy}5WkJ@InRvggWD|=LlK#8LKezQ> zbJYcpFklqI@QPEh_QFmVC_rRkELuf|8cy;QfgB@Zj%hZ+z*FFZYBPVGH}8#rVm!y%G8TvVy4sa^XAa?xqJ z^}-{SZ^OQ9x#O9r{zOniaqoa6ZyUC!b1}e-F{@mJ%Y_&hFwJ^qoI}xWL3*T$pm8B) zD4{(Oc@33QDfSkDhl!LJjh}BWzX#ZPK5Xc@h%Mv|rU$wssctx~f6n24_d;N^P-_cXJ;4T*zu<4>XOfUZoXFBNapTDA+yB7I_0mH#pCXX3NMr7 z(^!Tkzn4Ab8jT4=k`*oQ(agRJAt2_G*+h&c-mXQ>RSq!TqVe;^Y&wK5!XyPOu12fi zhH19#!wcUNtm+aIAu)(|NXTuv6B)(oGQ`8WH-7HL`Mid2(5@;iRU)jE^M(v#s7KC$ z<0<)tQ-CHG z5f2Xj&H$lwn;-oK4H&Q}|3<-TYsx9$vioIm-OHa|dXVI;WwY9>?eFieB(z+!`?r+a zGFTECrpJs=PuDl|V6o^9m`=GRuEPu)-wUWAo?jAA)aE+B=6j+A!x^qOAbX#!Mp+iT zgi2Sym8H(9#EA3Si&foOwY=LLG;bK}SYlBtCwyoybsrq-pJVg}B>KU2`VDha~BAWFSR~?6z*to+^ll zZ0H>s={_c%V~UP(9mezLpYZ5M*DFjtChmTnwu-#iQ#v=2!%I4Hb^8`Mz@RzH*qq_l zk!>G&g;x|F@VEv=th=ms>G5>c-!(hRJ!<>a@7%>MTe*uduL*x+vRJ{v= z$Kuu_R}22=!8lfsjNEV8%c)0Jy4?6`_niAHw?C`x&d!X8D_E4826ySn zMU&d}t-VZ`75=|8z{=`od26Pokt7VsDj+|&lv{D)(4l#u35q@w?{+_E)le~aOECn_px{nqZMxfMjn!Z5Q>di1FN8$VLH*_;tll_3dtwQsF- zE1;#G6aD=ARazTu3cuzwPRY*BUU#c`wZYF$jH)(G)C-x=ivZU*pjETzz6bKC)`Q+B0>y|;z*%Q49gd$T5Eg&9%YAvhXj(xUY4|$E9KXmu;FD#sW z86|w^8n?@K{F+&k&2i0X8ge@wpCMq5grIwMaTG0U;GO_6N5^Gz?b(`7h0I*Vrq0g?4R1&5m^P7I{cy$pqkTi-vZ$|p*ezAUY zL$zr(9#iYUAWgHuCi}Q~MQLjN^HjUbBZs_wS{tobNBB=j|IR~2pyM%kL$i3epm#zT zS@`SML*hPh1`qkzYI1D1~h|P&b z5B-m|9EYog=^`dyspTN@HqM6Y)|DAflwi&j*xH(hL_ACQ76HS=&mdDv;t=3Q|2qw8 z2fN!4S4;MI-z>OQzU24;<$x!x@m@@cnSo;1TK3}QVrxEFQ}g`aTGSM?=FBlWK5A9d zt0?yG!u&*#Bq@wjx6u9CbGNVfCTmxjEe59z4g2RG4a=Z()?b+yQ|g^;>@Q#9!Iu-) zQkxsePsF#X>ulA*aPm787cbhlvQKpO{z_pGvrrluzC>=KHP4*&Vj(&()1 zO!cy9QUbg#Yev6(+H=ziBkR@wR98n0O*yZkeyXWG45eOl>D9hnyXDc>-z5&}hyCQE zF@~Fg_gzXN2BPcyWSi###nQYqJ8tlAn#LNo8khfU zzi+m^bi^+CUP)(K3|+E`W{_M@uYQhQi& zYwc%gD=5F0T8sJ%+XCR3f~!_KI7hGER* znpbzdue<8(+E%4Yw{HQ2!;o^@wdPIsZxx>zGhjNz7gEfMf=BI@{plsbT+6|B|J=Ecoy&c9{E5s0xj{ta+Q7r(d&Qd1gq-Vf9O6^a_tX zGc8duSZMB>ezN-tqdE!(l-dEDgvpTYV_?mdjGH$jW}l_{R@w}bEzoS(gNER>Y1*&JMfj%=scCPpV7kTtH)9%mc0XZrN%&3$%h zLH4>Z^k)U;?NW3(^S`&ZBIV*w)F!Sp;?Z#!@#~b;=;?%zvrRPuN-&f$B(Yobu9cyM z^I@kx1e%z(OUkM_#6&mjgx3wvfzRx(*4+xzA3b_>4FPwtg|&TDH4|Sidro^z%9L}2 zXP~UAc&e`oslY?SVzC0x|8@2zKEoiy26%6Mwsd16;{vGLU0x~yVIq(faVZh=0i+zp zYV8`mPLli#8gzTi+k@N-l-_)m%ha%lg~QerUQI~dwpF)fk3M#Q(F63l z+&N!Fp~#PPE3Obo+dfLU$Evt>>(&r0t;~Q%90t}o|MV*fA+LBv#*Pmc#MG$*<0^~e zUN!J<=u>hR(!QNtWv?AHwiI4P`fuUtMU<*j#wb7^=TE;@NITpb)B5?R-k`iSWXVkba6_ULA+t7Rzgd z?J^a&Q%XW?noTKNAK-WMufLMg)5DXV2?}o9zM4s|i+GMNFS>l>ziU^t?qEvYDvN7w zD4AFWuMn`UBDa}(wHtZ?Bhxz2MRFH3h_m9$ zWF-TF)#gUd6|**9z~?X94-}$FaRLX9pi)98Vclh8jZR?{@p1NB>N`T<)XJkgwTaDj z=~J)Xc_rg{lx3l-Z601YMXj+tBb}-?6DDNH?L?=c;A=Ph zwN!(NWKf@uS)aRrg#5zVmtf9rJxdCrqwz2K^+ zR$Aat?!D8&mhw@$bHF*(gQi%c@Ir5g9BTo1yfeW)q<(zUyJx@dcuSeHL?2kVdL4W_ z;WjdRjCI3w4<|$`Bt{8V$gHG|vmg6xet6@E2};XIHps7ar=Ebh-3-@t~kZOd{E;BS9(h@>^ zgD5wr8tD$7MS?K78!7iL?eWa*y7X|AW40Rv8=Cm0OP8`lmkjG9&yakvSfAbqA9fX3 zO~l#b_s4epP$nJKrV9^V$b8OGYVaUNlJnQrZ^NRNNV7axF*D59=RE~t^QQrJZ2u`B z?qL4U=utBFAgjEf(Kxcl%~j(;9kxC{ zH2p+mEvVvO4^$H3Y9fkr5oi!3E_=;>MkXKtKG2w~OFCM^8^|3Wv4t@@!Lx5T()ucA zt%{bd`qh|0N=83dIULT!95IX%{2*hV`A+~~607O67M`k~Kx!?j%9Cq9Kp5eyYSE05 zC3{a@mN|J?$1Yu7T+j#B494Dz*L1UZg$VB;2|@!zZ`dyCHz9@UgE_T*l=0t+Gv+ah zjI7}@xirY0H)>_D!7l!#_O7b(VI;qcG0m1SNN-TO=xlbROlE%W-vU%@-VH7;=#)6m z6_u2fylSkXHRdPdPI$}vPmu?!Z- z1bD-uCvN|J@$^J(z?=vuYj*ztu=z^)ZAsbw|a_O>2iwIGI84+G* zzgi9z#_7deQ7VyFFKtLB`G!#5U8|vyF2Pn*j={bJqi%a=<51-`aW@u~E+`i?|3oW_ zLQa~0oG$P&#k8AXWj7@1@qLZS zCEG*PfmeY&H{)|s0j{5@fBv=GLPU^^z`i;qKf$`w8bRz zVE+RlHSMyVK*j82dvr7{^}vdFGM~VTp)+|(WshRQ-{VU@lhQSt+*T}yRDCuo4(YQo z>m-txfz5igM4YkC&8MZxPTcIuPYNzoT3Q-M>cgfFBveW5S^T98pL0~Y%4x;P_TSDQ z%A^Ar+7XFLBysp*V^Awd)O)m!w9~;#Fe%)m=Y;KJc_IXTd1xFD^2yY~LDE-bZbLED z^!5vQk|s!D@n}Ft<%$u$;(ny*L?J*nEhjUDg7^N4x5RyEFo-xAj200m!?ci~BC=Ps ze2!=?F(rvfcPpri3RsGISKGoUL3;3(zEPVXb(ySIF)lTDpD@rB02t(SWOMvucXu^u z?y?k2cm2{q=LWkQ?dY-6ScO_7f=wcuNk$GW8d3IHBE5ej-BpCzL~2XscKS=p7_3o{ z1CC<3SNUN;(YC#II)tc}*A1D@gsL4YVrdL!lJ`H{9!kuFe0G*5yZ)!_JQSyDW9Z2&;rDN&Eod-DO1joh>9Mi z=wY*t^fZD>#yLEn(KI>{*SvVYI>0_rPF5K}Y$?{C!ZPE7ZFP4V{6_DPSi)-w!${5$@O{@?!ertg4*mg&`he!g1$hU~q;4h*a2PaeKN}Dr zynT6a;x3GR4XLgfn7)grrPYz{zhM19lST%6oym@ex%o-kg`STTLIqn&dVsXOAWxnT z?)P$Gxh^xUf~>FQG(bYZ#xrpa4&B0vqfJ^atP^@w!cQ04(KMZ>%J=MkHy3wqA%n1E zSCvh2Zx#byX?n4iOd^)wd;oV$IM8gem=)*CFT1?-64!45v4O7BUf0fqm;q%3*hh=Z zNtBA=Jk966x?WA##|{-!by?7)44#SUz^-p-=af;`Ux_U)TmBxW zFF+uJQ=)CpC(j9-<9F#%t;=c0~sAki*zY5F18%hl11RUyX^X zhsu>nH(sCgj5BrXC9yn*q#YTYw3g z@41Ey%IuzUQ^KVJ|Fp|)w(!ZPJ`vG38TX+(a5w8@b#q)qG41M87(5)(Y6d4O>MVJ` zWOQI0kyB2{nX-@e(FpVBtRU_-YqmWA(N*rcr zW;PelHo+6UDUJ&hM)ufqWwfc=CiosfN{?B~QaJH{6WrzC+htkBzmXsucwSnn0S#l9 zWl9-KOoF;#-Le%Xhl^V~P1~$1P2>7QET1kzHxxDMTQ$)ZV}*1$mab4Q!xkdU86lJl7D_rs=6&`<52c zaC`eIZ!0m)_(SnbV!lQ88YXtI3}#?3nzfOtkTcj+h_DD)?{j{BmS)onxM$ReC-3cN zd7k4MBqHD2f1({(s#LQO@tAZNowbn!NZ2p{9d|m~CqiNM5~3V%&JZLg7TzW+p}#oP z_4S=WSAP7~y}weOZ{?p%%!;I5itSmDAYoT4M*TYMqLCjc{85F)?C!H}+ERp{DUKreW`*%ajKj_kg1h7l}Od!X!YDXgiY_iB~#32zg2}{9+iW=57HbGMh8@ zY8|AU^-9;CgZtrI(Q=+IvNhPGH>)~3a@@7h&^!e5~v zvwN~gju12!aCXUWh1Rhe7PFv7>u!FzMkI1s$};#NmqE%*#%)Am`JL|Oudj7eDzP2x zLqmQ(+j_6B-8MaGVbNsR=R4-*V48p?z>tzxW@t1oM^I0i%aqxpupj}F6lqq4z+6!$ zjv7-wLRt;zu!--w&@Dvu%jfk7ZEr(f$vv5f4OH%h-kl}gkY_7W0~#@)BR`AI7Y`|w zXFJ<>{2~aqo_g1HnNtBmh`!tVLIOP*Q(3oA6$}IOmUHBanqQbzB0e*fav)Kl$?+aw zY!?AXm*<6?ZI)AOwV#zYy1^{5xhY;S&eBwG4dZ|{Nc!V9wob}vpi}=jU%n#~)Pk;$ z57-eDIAI)>12~QLaabZ`2aa*c;4zYsG*K!cvMRpJhUS|L?X*}d_wIaoM`IjQQ^TsW z65rUuj-g6(i8VG;ja@E&d+-P}8Y1E1;!iBM#j;HKVJGq2)QE|Oa?cvR<)78S!Firg zrXho*$?GY01Wb`>WK{J2qg=arGKD8|@+_VVjZIoE6$=$ZMLR9jiP~!?_)CCTSX}e_ z+gGG+HVkYAjIIT^s##s&%QKRx{g9?`D8%`20i$Z2OMt z>bCrfc)epH)B`Fgzmi$s4$I^NCapD)laL#6gBQO%>(|F zoCvAk_t#O#Uv7KrRA#)BURp3M$m3~A7oqg&*4)Wu6csln-_pIax6wEEsVrHe7I{E< zD3;2Rs+i#JXMn@u2!TBVMn|sX=8W_%cZ2;(J?@p3)Q+Q-v9Yl)AD2n{lq840^q}Nd z`lNZPb1R%%+6wI^_IT%Z@O@FJ(Jgws4(_|KK1C}=W6^7phBcgY@+ItGnIHB!G=ThS zQ}WL<&bno7N4{COrn&<%gfu8Ktb7rJ0m3>HQ7vP&=e<>nzc1M(&u>$5>Vxhqf>Niz z`4D(Y8O9swb4ZhX$0d1A28r&YpgHQXY)JvZGP)k`#G%8se#An--A%`6(-4JIrJ7o( z4Q^>75L>!;*d8Dt&U0Q!F7TnC+8JGQMuDXZUd^LL8g2%+G1}*l*lpg2d&bMo1-Gku zU_UdT^GUHPTTN=KRMfquDe7}vtan&`5d$gm=>|eUe)qUli=z65kBD3uI6rWak4UE#tF`O#*D!~MiX_LhOXi^d zB75L-NCVF3d%!C`=8-s|0P1IS?Rtq^RWay1I&wY1P+?H`QSV(!XNJdCBy}=l<&}0} z4`G_*or^!nYZO&z0_Sp$8BY77rF+{5Fp{9oFIehONI|6)N6}2qCs9Jm=(?$BUAij| zMrxVTz45rX?*o)rg#jlymf}S1H`#oB>Ikc=d8L-As?KEjtLkdUpdS-BmDef*`4F;I zkTW^Ms43wTjrQ4SCiw$V8?Vw-a#X4Fp6}36Mht|(Bu3gM;Jx+2Rn)rG9H3urqAe# z#Nnt%6UawnV=4zHAO?3C`99q)^3O7x37VRz{ze&Bua~9waCcR$wE&ua8njYyf<@sa z$?4A?v1RUOGyS#!v+QWBG2`ksvA$B=*8HfYB6BrRqBfjTq508v*&3S;(0DZ#fkNUr};HQQS0OvMN6G5akh1yXUK$F zj=ECbS-@yJ8)(&tcZ=OE?#Kb{E0avJ8wl`DG# z3puu}+qTvAITQxAkNHx*)6U6EO{}5S#C4@4ci*40s_5$!a4L3V3jf%Dz<_Z#%lmxv z_%nQ!7*i>1aM-j-$4AjYh;6s+$m$AI^y&0zaW4b~@z^Qd`)GNp_p+i(6%Iw)TD|&4 zxol^6fKCo`ZEf#@isx;oAL7_H=Ol=CZM!0~WcW-d`$fs=77mT|%I1T+E8uK7JV(B*nBZ|Zbv2=QxLILJR9U0s#BR-v zG5EqqlsQf$C6`pkdhFhRTdbdTGjefRWt#sSi;qJG4@!CLm9}AU$b@ZKEs{{ZOnZKI zX%{4I6zs9=jC&X|=0>JpJwZXU?xWAl+$uQ$97}Z`e!TRdfU?(`OOpLgA3aOCblhl0 z=k)x&kawXkrSeu5)#OAKD-j0w<_$HC7=UkE9riB+@UeO=Pr)EDZr1w}6L-_U9CdEs zr@b<>eVc_X%3yBKcQ*D9Dd@+0To|3=8@l;`scmFT-i8B{>{M0;Gn=o>N;|@1BjVk0 zx}|B}7`wzdYOE;U{&eGiTt?E@yqX@9R)*i-n|Go3{5#yp)lT$K7K#053av}*(f=59 z;B`s5z3#J4VNH2obr`%}p5rrFkVn0}y-W8lx%?i>^~TtEe4ICo<lLMlz--a8_7o~W2RS3Iaa9W+iWWHyK~$_&80KehuYIJ zNMT+cz~?Js#G&Cs!@u0N;lsmg6l{$jxRqSL%G56E^N_F*1g_74q01f*E-#~Gi>TE2kTz^?mP$dDLKa2%$WDNNI0N6PogXncFS5bnm))%CRg z!Y!%F&3(>x4zIn(^j!1!@gHDezW@}Pxj!B@jrvEX8Bb#L@Rr}*^vEungyij!@wLWtqgQ;zt> z4=wEO?zXVDaGeA^hdf%u_|f+fSrUV9qT|XO?djB({-P(u2hdSX4zi;Y-F6^JK|uGAHqIXPV{?Ju8|bOZ#!MAmgz7sx*24E%->r604ZLh`((dB zlrUa7rsEluT>O5$ufPD<-cP#F2eudi{xFEp={`q zHFt+m2ht06<;ujTORibrc^)=ME(qTrOFmyfS1u_#+Q!VF*Ob098L)|Z^JUVad6zdV zL#U;|=~j=9!_5~8%AowNm^WS{z=#dnzN>%#-pfesU@P!dx zq(P!`m=QAPm$~Et7A9flIT}WWm5g+>mauap?8Gu-QljnBkRbJvnWHy|7#~*%9>XX4 zqwL<9UHFVUc>3ua29X8p&?W<+z>73K;Z?EDX2Klb@zNV18F7d%@f~3`0gB7bf8|1M zr3?-3?x2CmxE-x&+tNAr3cp)=O--Bz#F-q(S%+MbBFfLW&F9ts}|C2`#s zmA8j@t*kv<4)KZt-$#8m>8&bv>%`onbRwh-0pnS4Zf+l*h&Ge(fn^Ck^gZ=5H8Qc^ zG#cH(ZP1+Y|NN5i_1yzqM*e_&ywuT@jV@k&9A!OV?VHW=KF265*MAA0iLv<_nIZ<} zb8tPA~c!A#Tl0Ml~95nmxH3-u)({*l};Ex!eEHpP`yZs_Jq{N|C9%t-OF-6(W0RQ!}i-SgV7_wl7xIY_*sq}sq8uQS^U?`*O zB5j(Gia0sVQ=m5pR$vkXMEyL&r)>LV#=vD#R(_D;j%EjD@B#RqPNJDYX1Yr5C8O8! zR=pjSN5H`_P888nFp7=Pp=T?7Yo}?R9R|;x#EjT{{=qPCWiD<|%4ME33V1(o{E_zv z42dlTkAKmrAAS|18%E5NoMr_WLZ*;~(^5oUKE^bF_}%cp&vS~GPpnhwdWHp>Ef@N=$Oa6Rm)-$M|?GgB9@`;%bi%+nYDo6Dv2w8SlK!G#A`DmcHl5pSBi@ z71Q6_4VWFcav_@nN4wuweq{dW-j<`7!L$L>#N{)ZK#e5meSv&ngZKAVq+aHz-@lmjN2R;E>_i=`7O= z{Lae`_E(2lUI8PmkrCEpr&*U|+FDwWKYsk_?7+NC!fZIJOt@YnO*Mi|CLQb} zrzUg+Ff|u||H4$z454Onb}Yb@J1}Sa-E(4;%A1aSGjNsf_$|^yJ4Y#p={8R18$I1> z%!$cjA>DDT$)(j8p0n+OnJ!!Se&ySC17^wSh!7lr{lEJ~xyrh} z22)x5_Lma6so->miii{VFrO7A9ztGeGt9V@iyNO!3*35S`80Vp^a?IL|1w=taxs;b z89}Gr`LVIFr4Voa26VF59bmx{7sB7}&Liwv%uZsb@OEpMc@IKu0;GB#@kImy_|Z{5 z%QwOO!AZD$*tS0%A?~o2e=ztaaJ7l_8UgSP!s-YpX3_-Ui~%XmhiV$!T%Xu%%z`F8f6_)K)H2vP3rLtVG<2lTOY@jX{vY{Wjkzb?W#Cw+VDk;9WJ9@z3EM9-iE3S% zaUd#$V+AQuA2}!ZRq35_|HH?<*gl!I$i>yE-3%JJ2~kiQE&q1z4$z^{^fa+UXg~Q zF}qZEs`RQwZ-wA>-004?KY#v>DVRy}(Gm6+pWX68FlsJ&*S3oe7En3$;2iJ=2CK@O?Kkt=`y=jZ==M*%{bJ+1`**Y@I4COUOG?JgcXM^3REhswE z&))~+0Tu|gL2&Su4;y4zgLVzrneTYN)i>i$ism7_@#x*1G`*f>kKmAjqBoH4H0Su( z+u`3_^KwXEpa>%ISc@OyQ-WpA7WIsA4ykO=g{T7l4Nxy4d+MC28$^DLeVE>SiRN|LnB<5xCb1g2oeJ)i-gzNlIP2>kbspr7a_jBVIz zD*3xGYK8JdXz}c~s~6RMIgAo9BFaQqbC=aO?vBdrY|b4ygq>ij1#Wok?CUp!nt-l#K`fg4*^@o|LOH4 zidIP(DN;it;M=aOwK+V*WEPXw87xYF^o4_G9O~n9J89h^7Lk_x03YEcsfNdoi9?r$rTp-kz|7 zW&V``&)fTNhCP8toyl0XbMf7-IClO(Kl^Uh&fhRWNuqkIgPwl3phj!FWFh&8I?7qK};o)`#9lwe3rP znDq-sTzT~q<77@v4O?94;l#e3@e5N>`J^uX424ti?R*pfvl z3yB7@WADkjqIyT;uCmz69xbLzSyzbLLRtliF6G-qPEIroaw8sUjVH!Ki}gg!B#_cS zh;w_$+>?_$6RkjrGGSne02PjdnHb6lXv8gS(2A|@aamr7dNImwnRFespudXy2UJQ4 zMZl{vg3OU1wyX4#1#9E@4kK?!x^ZK&cvVB@klYLEqW9H~f%Aa@z(i)t%O^>vix~=6 za4`1ua4N(I0i5$NSSHU2m+R$27DYZ!EwLmm65^4irDcHq&7kM$epV_CDCu$Y`+_Bl zCJoU+EZBjP>2b|`Z=G?Gg*C(??jmOmWtBg9<2fd0Pb@e%lQT^Ou0;@!0R(pN_tK9A zzqJpmd-(97+j&VHYCW}3Em-^BR2hccZHEG;$k=Pb+guyabZO+vU_S(-DO^7ve;8H~ zkowS4zl2$vHf<7NVgebdz!727Fl@T3)pw(U2w|e@CK^vDo=SSbfuiE3ibWT@dZ?C` z2|Qsri4oBIND=}*=swBqc?f}Jho<)&={nSCK!TRrv$=HozVN~93F0ITFwZyX8`f{N zLGvS5-Tly)JY_RFl9^0gidq&F08Z`t)RG;<1ULO!=30`;p`qp_)a4=TBnn*`_61jV zR(<9^S{dZrZ^(9)sa9cc20kAm+SA`YO#~D3h|)( z*Pm-^zk5Ls?!>oBfcIbjqBiVz<#GP+zx)4d8Q#0I(Bc(Gx;Pvdsb-(m9(SsntpZ*F z*9r)BNbVGSH><;&xu6mlP;-B?F>>_3uU-DzK#98k`>)*xjBFWZ`S91*Jd`Ethri%I zmd^a(V*TN-FQ~{uAO3>>SlI4=@hCm9R&hyjTOF0t<}klw{C|C2r2DBQmq#g$#{VdY z`h4lv|L^a1kAM8v9JM6Zu_oOwe5OoB|DTto;aA%7Gn3NTAbtco$7(-_nQatEoc!a# z2_HVJcjwoyUdak;skqhDDhGcU(r)=AWPZN+UpxYS{QvK2`fz>!t2ZQnIQcou@}#sJLg?r&6MLjpcHTbVbLEqa$E$@itCbT4TQFc zE8D8myZY+6SEw2gioPLnMG7teM9JF7TWAw$7Xa2GQ z!tX_#c16@%dVQP{&=H_$oTv;RLFx(sVd6ilqtLyi0tK*exSXpcd=xsRlGq6;kl_Um zzH#|hO8|+tY~6|?lL53uVV5mHN$0bZQN0;ufIM%ubm`+%EhaMlc@$O$5w;+*n4Yfh zjZEb`b$>Pm!(Zr^q(ZLZ-}t@jX@f!DfrLP zZp}LgFETJovpl{iyEx^j@O0%5MQj*#sf_bTujNOFl1&h@nB~Lyb-kR9Otdgdk=+)!8GRU41Y=`wE_I(-3TUubdhw*W;)^G*m%94_V|omr{Qgm%o=iwn2|Y(k znu{K)t*w0{u@<;T4o+d@qXIdUvq#EaA&b(UuAW&RKbDVW4t0c^n_DA@XV)fID8Ql` z`nIYtoL=`V57D@V*d~lSIakkkSo{Dnox^z$ze9TW^pbaA={N>1L=uYhV-1=F$>dsu zYmvw2dqUVFZHIlXW^r6iCDOo@k|kYc$CS{g!;ie|=2 z1iw6>)&>dO&6{oWW5HDV{;;?E$Q4*6HeK5FT|+nxaWsV%z!fDy4G~HX0M1;=1hMqt zuZ9}!xVneou@r`zphfUl7VLB@6cPl}6)#bKhB1HSO}bYV6Bu!us?xj?_drsUI#EI4 zhqT(J`6{?l!E0+nNjV^_JFB!C7mQNtd`LCnRdV$3VbO{Hp8QOPmOv7Y27O=hMn=I= ztb_j%Fze{>^FoJ*=YG*R!e+*d8DcyZLP1c*us}9DdbHMy6NvIGfVUuW*OFzCQ|pWD zgwf5gXCyR&HOs^bXv%lSKc~ewlIjYAptbJ7-{N~nFNZ0GoKsoOaG-6NLy(Bg3^${k(Z`g-6*W5pXi|ml(eBxI)=Q5v1 z6gUTjF6Il9!`j0bX2Kzv>>7&RGGmkb{NDq-AcWaYBEEoyKRGOla3?HemQFT%SKZ-j zYsjxsYZL0fXR>I+!_Rw>YQI5h;0_anX~l?h}q2=L7++&tY7l{XnImI(~%uJl7T0|yQ)<27Zq zpEF*DP>KW!Aiox-Fth(b(CmvJ{+@#9!(6ylGQa{;a{w@gmW-wWJJe88_q zKQ}uo+z>#0o+oLkI1zh)?H^zoOWC;Vi8~OF6sb6B{F2~|o#+&!-?wh9(DbvG5NXj>KgGZC zlADtMDQ|UQpbFw#h#;nfC9$IwioZ`k;<0=adcAPK=fcqwD^_{u25mOg*ViK(A#CRZOULoDt#4 z&=BSzi@={&u9wlq%N+GVtUz&n-rMF_a||bfO!k82q5S+4Xyyfxb7DKN@<#bMIKI;v zJU9tobU9#1wPODIyc_1rL9|;3X7~7{Rc}|-t~kgCB|v}m8oTjy>0vh!MIoTB`2B0k z>Cl-)I6zb9aU5hxB0~B@jZWkGpl!>1bGC5B0`ss#4wY9ZuN*&c-lTL`$H2LRHXlha zkw?a{rRAfnYuwFObLhy=46C`7z>0*bF5Mn&epzG#cvgE8ie9^jz%hSqsi4w5{Z2Ib z1>Z3E+Sv+*lSC7O_DXPbkhxISu>}hSlwZ`AS@Z#HOfRF4mUdAF&PtRDOa%7?+As_} zTMp4tPA}nH1wNFF759h`a+_mDGqdTgG5680$HX0Z^ciG|QS=^pg8ZTl#t0%+O9$Qn zs<*HaatxU+ogo<@Y`^b5+~jiNi&%TXZsj5MCBdT$lZo|#^b&;JPLjD6P%XgtzR26M zoE1)TyYuCk15q|(^_3o80->#xK#Q_fWdi1 zT%T{eC*Q&8Tn_w6^XJdiQ9%&Sg=|@RSWiTD#o!=_9stYALPpp%ogKd4ij;y)>>pWC z{WLWV$%tf7Rel^2860ZlEl2Q76>HUf&k{x<1;D@aSc7>N@Fp@S{CWDROkRgGjCL&TGF}NCBp0LOqx~g+<(02wQV6a}##SsC7t;=iPB zAVkl$et468U4!MrqSyFWIIl?Uh4>jRo4O64i)+dpaP3U62<*#CuP zH)x+G69BYe)x)W72Z%*VBHU-rE#*u`WYMiVf2KJtui95kBk7pY9Z%>xYgl;ndztH= zB8mwumFeY^Y#<(4PEgyozftoDp$tIeGi0;kN-V&wK2QUr|jaP z?qv9-54x{_4f{a$Dc?;m_4$Ipwaqc)6*6~Wf8MX~n-`{L6oioOoc{(}FZZXUY`C^H5QWOIb zs?+tW88c_j(H`|6-3_f_D2tIa{E? zLp0J;cbiP^q@eNq>$0)MdD(8z5s=h*3YP7mWwKhK((|Yvkpm-y&PCmz^a&I(B0d+m z&q$1;VCkdKA|li*Q(nP|jBip@gt#kk8JLD=bzLIU)6M)nf}^+F=+%3P!d>WQmi{su z9l8wSqkIc7Po~0!go9p(lDv9JTw7?03bjNvJk(&o@Q~sJxiOs5Rtq(8e#L>AH3da{ zrv)|Z%^;(5J^3DEzr(K!ucDF4sJXQDDV)Y;elx!EIlUnP#@UTY7a4HT=->Z&vDDKl z4qG${P)5S}S4OPDRl0p{2<4$&s~C@qlw=?Z_Xu5Q0|Me6c=bYCsPvaM*n(#%#x5SF zGNoiVNG++4@{E3Wls3?i0RwKrliJRBqtV7$HPrDIPk(vT9Ve%!9cmq!iw1_z;|FB! zfp!m1v=1xa>3AQY%PNrz4%(2b#@gE}cJ8W|T_L`nRrjBBBFA(Sc@ z!P-7c&C*6CA*Vm2k;)Cq#6ScmfI)<5&smT~%`HrOX2KOVpIy}))Rf1q-K?22?{P%K zt_y2ws<$U6wU@w8p@fX-NO)A&-&qmsQ0-^yQwi!{E{y)F$lq$57)D9;T~V`Mg0^w;7^{C z=E4itUfWwb-PBLuz*39>-0C5;YZ%YMFeEG>lA-0ZaS7olg9=b+go%k0?JYlX_(&O2 z^H$chKN!Ubp|fX!xkF}26H#u_R!33IA6~a^G^}}UixaV*nME~4xwemT_o@x*K(PuJ zUM2z<&r-17Z!s<4>zi4-6t zFyZ}^0Aw(S3h&>PQF$oJmuaQ|T-oYAyJYU1V33ul!)4ImF(K%%*5A2R-C*|4L97Z8Zz{^ z-o2mBZxiY1{-3j7{1!idz{vmfzEky$+pBMlR}VaX$947V?dF}ci|*!?T(v4OpFTXH zM88V$>R*inv(@Qw-6r&k*%n4fkQ&A?yQ8?J(UP*p4)bzJ>Z_9S$^$(@J~H3G+*%%% z$62U#oWd3XD>kPB9L=6SIE$wD8_o^y{ZF(^6DmE#;0C{WE%~6*F9t}uP5HN6cW3s( zI$tI5?yj|H%f#?2q1){UNejOJ{>RcTmWRB(ZAu4zCqom8Ynon>o}HiUmaX=x&29zZ zX@-Y1aX}clN-a~@?MII2_1gB6NqBsJaF`oTpM=)|MTz=9#<(~2TLwD(e){R-M7t@= ziy9}n74-8AjW;{)=U0_ICn>7T85@e4@?t$MT5bHUmiv?SzyGm`ss6-};@2I9`W4DM zoc?M^h+U^8b-u4Yaho_%xoS3xM*t=nn7n2)895~=0 zhXASB%U_nLGQ{R*RrRZg+n)N<L>kiW^oyWhV9uDZ&Atku>m%w9V z6Um{rJ@;*bb?U@?>^J>YRZa-=61x&+jb>>+j@u^QWd6k;HFRzHd`U6;Z8l-QDO6_$ z|5+FwZ<0P@{P;XB9jB);Jx(n&O&B!(=3jN9&=U7#XZPfqqMMMif`s=<&#~uhRSU@a ze|*zzk+~QB3UkL{AR+b%HJ-cLeblr^wQ54+ON+$Vn;Y0uZnv|ePB;D3;F2W$sx9Pl z_XJr)vwc*@@orQEqSH*YPdKQ`mi;5jCio6oRHL55C`fSa&x02Ilk1qD&<`z^ZzjB3 zdi*Qjqnb+s@OCyrMOW}2!=oUG`;gzJ_i_RMxA;@z$Z(mIu_rZhqcM@#s;-d`?YWZ?j$Fxe2Itj2&sIw z#ls~15Kmi{3)WR^c6p%WcA)Y@w&#O2!q_}HTk+XF+U|lkQR@GTd-KVbC4(?6&z)0i zB+uP^(`a7Gk%I@9>h3h7dHhhWIepvhn?J1D)Y{i=$qTgOol&^SeKhi;g?CBj#Vf|Z z4S5;E>GJm7XAPpd_FF0syBzI4tPND={HVEh{S zTK~zH1NZoh*zx`2S=X!Wzw#N;(>^B3@DoJ@n0pWW-z0H|YsSbcj5GXbnpXU!sb=w| z*GB=6zPVq1pQmT_Vpzm+CGBwO ze99FfR=1&5-_}<<$2@u|Ra0q~0k2DJ3L$c4d^_te=i2=|;Gtv3zWeX5PyFWcZq;(` z&q;|dC>G=$+wPjoaXK2CU;c~5JhI1)p{TkHOE{E|t!c)$+}o{tJ)M6u&7{SNWH3PdBy7&P-#lP@PD0M{1yBQ&-GXP z)m1qk`0cmv3g*RMZmgqqZt-?2{j{mRx5h&G4=?_>LAC6Lcb%~P-KqnNUUw!vCn_y) zdv~K3R8jWE?M01odg^V#7W16i8~~g6v8dNnDa&(SHpL8B_TIOO{P_2Y6=3|nEWjg2 zkD5In-Sdz#053HKwDR}!FcU>GV<}s@_>W2mEYZxi*yO{d0(;JftTX%etIHu_akDVP z<{`;k1~k@)(k?wW@m#w%?})m1x`@*~*)AsJU=hZ;fdMKXb_WqC3FM2wsYGq?jn#LzI zyo{qps*$jFEZ6%|X)R;H|y%AEUg& z4Vko?tezUjZ=c5nYH|G;BTqPtRi6+QSQad}!~VzZLMno`&Vqjyl<#`}X@IxgsDe*( zP>%K(X8c`VqMKXHEdWysc8le#RxJVDJix8FiPqqG-UDrnG4z@;Mw&bDM(Pdkt%p4b z_^=YyE>dtsjS9CLCshhb?BNGUUxmt%U;P%sF$}Omr#3SaJL$z|Btl4ja;&e zD;=0Ia{Ow>(iA@3lRfHFXKXI(`?%-TREg4ae+lpKh2Cgz7~$2Sr}^-ui=W1STvnqV zgcbwOd=)GH&7RXscHKBdLxL}hq!9H>Ug9Oa3z_0tG2?yow%_z7o>)4%qvD#c=~mK- z6^)1en84FO>9HwqneNUMm_5#q*2f&*|KwXCz+h=*czZ(*sc4n)7K8bq{(f=0cxKPr z|C*E5AGukmd;u%W@58%^iCY~-yi+B8Q`^9QlDUH(XUy4AHGj3A!)1APx4Zm3_{=E& z7TiFbJ^7|iTsExXdro^D;(2VYVQcDdb+=z2A@}y+qd+bx`PE@|2S5Ad6H%hDG?L_t zKV~p}N>AfrS8o6)+Hm@#bfA<{kk4z^;2v@~@dBU6f26m0_AFh+?)zSMCmw1ct zE3bE}x7}4bc-33!uTf41thfyNx7aYgf=lbFI`ZOeKhL8v7vhs8;VBA z>n`nbbYt|xSO8Es2sI9>X2RuXP~41xl)11dNJgcyzrp*!t|vGNs8``q zIBdFrA?eyTy8t67(41+<-@VrFCHkjsI={rX+ggGTJj;{`tHv-sUs?z` zybq_f_03N{Zuml(HqCO{xFX$5dk|-UbQC~%Gv+1Vw2Kkv+d+eo1*ouc z_RkY?A89m)z)|T2(A-F1r2?go@he-?#8}!o0n%6eAM)Nis>-wL`rXDPCPq!M7c_~Q zsMrfl5MzzKfTBpTqKG2W1O*{QO)MA{1w_FHh$yJ2NKr^ouz)Bah=8C7ND&YOK?J_v z+Cbjt{k}8K8Rv{|oN>lD{P8@7viH63`?{{R)|_+A<$iCe2gV?g7+#+LHCfU4TXRz6 zcaHoL^W9Z-lNNJZDH)a;7~35#8q`;&h4FBL*wVvTGR zjOCBSKIdWd7YYhALK&6Z7qa)jKV+V)x{9AZ7M*?*v_|RZWUuY_WBV8zT(O{g!e+T3 z@0_hnh7}*0hO6~|vm?M}Umwln!tq};EhenLg^^#@`1wJ?y4ay6jh(wWj>}IVG-L-&{~8r~$-wx?11= z%<~wd;qB3a43r5Y0;p8 z%KnNZ-|{WG=gZ-Lo^7w+s0>~U&y&!JAK3PHZ?Eln=Oi0y(7*`BnX-hV7JARAJxWp_ zQd;kBLrwwrRH&vhb&_H!tEGb zeAMklSJ^?Z_mFcNwC=4D`OoUR-!#10?=n{9wkmO1tR96@aa%Zr3t_dKCXl8qh&lzs zD7q!iGY0j|M(-pPD1%MoY`b_IzOtehF`KR?eXEuO)_p44md5H3m!gQ0U~c;#ltQw0 zq&22!kjW^*D`)sK>ur}W9@kuJ=GVfd)6C0gxpi8*<2epx zsGi)$qCrzcP?yK+dijOITFHkAkJp9P#avFZIHry zSWv1Q$*QP4BQA#A24qNF%)#xiZ$>gad6uPv<5OwUBJs!J_MUAQy!d3FNqd2q&BVVs z$2)Ky-yNOZQOqg$BV)VAip|lU@86)kxc}VB(@mU2gl(li zpvLP(>}4@dW!J9unl@7SE+h)r8_wb`7r)SHB1$iP;qzt3j*nT1)ztX))jBfu1>&-p z1NYi3$7$21iPJo@ZbF@c582PCuR?usHkYU~vMKBvS$n=>H01Ijub<2ngv{KRf4$(S zg*Z3UQvGPj&fQ{V43}a`bIujA4t`EtdyB<=ZT9ImGsi%nplvY*{T+GaB+FeMiu4fv zxi5S7d}IS6Zx>PAqv=!Km~D#~EHuHRXW1*%0Aj5hwdwHp)^Z&R#q|fx27EHv3hm(D ze9>zG#fvG|s-8m_OmkVp=WwgNRTUlK(k6eYa|Q<`Y2KJNT7N9t9Ee*0QBiPR>LcqH z_BVK$W%zCR94tE-t}a`ciuY0UtAtx z-X-!YC@cdjE2|wBdZRqVUREXpZ_inDU4xMIUM5fimm5(*hdjJWbUY_Uv@Uh`O^3`G zvo}Xvb0JfgK7Z%LZlhy1n2%pX^?mlD-Gx)+=OU<=Z)BqLe+JbPIn*=|&f%7J0{llehgBr()11&TK9EC4+Rq5ai4fK9*HfRcyHUUFNpN zYnhf+#X9|vOj2?vJ0$QpA9Cek`g)<-Xio-JE@rZtSS_ZfQFNvK_T_}L+%6gYEgezm zY3Bw-HsX+;?pmi6M&uTQ8eFq#&JI$5~Bjc)^H-tc<24K8FWH>{dMp#%cPfc>YV zwuL<+ey(8uv*bgQ+-j?Z>Aqs=s4@0Y8cg6xL*pa%N^OoAw4iE4bU&=s9&akE16Iq} zaP7W{_tw-On%AhdD#20GZ^k_Y$?R>mKwt0iu=McHspz??k(`|4CHi`2rhfvBFP%o* zDhcAu-B{5#b6yAfe<(FWIOJb3%RHTNhQl*+g2a1P`ti5s`)opP;lt=oX}6oFQ|%5) z$7P4$Je(|(n7PsGkPS@L>gF*Wfbj^JWe;ox z#B%rJ!aa_)h>fAKh!iC{eA)T=c!7Ljhie_<_r6DION_z$c zMgCM}ou}CUMW;?_T!(bp&p;IH1F;|-k;gSQPeLB5KXOHS_ zi@Wc4lXptX^T9K$!a=I4vl#CoIz|Ep!()F$iH=$wJZRjpXJMkgA%Z`m`A&6)mQ(FT z-p4jKrj2;#Gt)$zKt;jKjBJMbAaoNVvQxVz3zNl=!UFLCA^lL)R8~h?D}osj6$9s@ zA9T#n@jZ6TC{Vd?I3i@Y2@iUdN;8V3FhLtBW0A!W6S3+=#_!t68)MXte!-(=$VRnH zzlO_n?F{B}Sth5yo~(%?ZxQEXGTqw)X6Rz=>v!WG@j-R``1)j%HF%(4Tx-o!`~LfX zthVOB)QinEmb^ca%jiKu-DFtmYow9X@(a0z41HT2mMM3UhbgWb_)y_VIovq(J|fvb zv0X<_EAau2e^z<}0y&T1$k^xCOCti#oY6n`!*ps%QTU~yIwbQUXn4DQTs!vW0?BZ5 z?LZ7gECthTrZv;=8y*e8@+_F4|7m?Kma%w)shm!J?lZ!K>20#frIS{q7J1DF%!ywS z|48>#bHR0ZE+0`4PL323&_I9}nO6WtVCz0}AUBvQll}rSRcl=Cg`VaZi?QtzQ~2i4 z(B@avO(U$5DOUZ_Dm+yTd3f8|v8M2s%4Qx2N2>h@*ky`hYu!nG`wLP8sq&RgvyEtrn#8? zukR@6TWT)%{_}fBrZn%n|NP#pQ%Z>9KfmQVu=R8Q*LPzM4zh0{l>YtZ@)^w$?%!|l zkBb9#H8<1T8^O*JU7t#9Ikchwooz`EmVFT@niq|Q3HristIleoJ8K~>N{oFMt zv}@(2{#-z=9*lZ`hsoBvOrM|liFkW9Lw$-`FZ9gC%vK72yOmHWQ)2#}BGkyy3KHGUhJZK?HHOuq`m|G@vxI=Nn4tRSwZ4C@htdLqXiKC1 ze4uuEvF9WRW2$zmIy4T8n9$a0XUl$@Z?4hut(M=(?fIa+&xlhc zV_aX%Z#hqfbcHFDZ?(Q@e%^X$xkuGKU0b*4+28LitoX1@-u+AIjHd20*{ZXp%!uQ+ zck6Uc7lR|vfv?_o&^Kw`idGp&R}UXPyhiI&g~OEX+iwb9A!UP6)k9u9MD8?0i@auq zgl)|9r}*sA{8EGNGA2!_0pRa$d5u~7wy80vDTT#*f83YGzQ5dS0gr|=O3@=N#B>v{ zzoeJ3Z=2fo&9^3I3ms8Y309#+JrubM0C>f}b9w=vF*oPmUF1o5q^(AYT>qA|uNa_A zW-s5*_G)Iuxf?EB5anMAXKPCt`bzxo#nVZK7LejpKlz|YcoEP%x#vW3)vez`z=Xwk zoMkW=Q7Yh4&AP8!ENc~t#+O9vzFt2L9XcB!nKISb@MNS2mHStHbg@(OrC;9w%ysX< z1DPqNy%JV6f>Q2&b#2o1>l!k;80+qZvTjvWUjO}j^8qcaX1_CQepOGa z+g%Mj-4whc5+nMY9>1QnHbTAa;odFB`-z2h^p3Dx{ppI29P3k-*t&k)bCC7g(mhmy z`e69JfNf)3_p}~Uhj*&(cwoEf-z6JrG}zXkUpV;hYTS|B{&>-|786>9>lXLL%8Lqa z=D!t++8a6DwNh?057*%>%j=y_j>-0JdCoh z@d5KEwS=CEWzq^R%}uEAl&*ArT7NxNOpQV4i#QTwh7GwZIH7bmi?GhE-Iwokni>O2 zOE*B08IW#{Q)%bAza^Z{H>Fm%gDiQc2-K206HY|H;veIJuO^@hh|Aj3I3! zFq4weV$~bOXP8BdFgS3ft}jYDVc4qsGy;gul}d~K@Zsa;`{X)PjDWaFS408%!jl-c zO(!1j2a&4ca2I}S6eVeEf+DeuY}OAModw9}KsA-R3sBI-{hDfbH3HR1JDwCk=ir;R z4a*TYgqJEI0U78P{wA*XjcXH_Pj26mW}KB~@S~Vvu`u=5pR}1;3oIqT^ujMtB~=`e zI0Y`|(Fktn>GrYmK76=^J=!TziBqv`4t!YmPHavYb^YpGP;fGcMCltR>F?lG7Oy`! zMiHa+s`-FBJ60^`;lKp_m}b+ZTQ{+!*joF1m&j$9ymg|I3{IcU3+|45#Uu^oXcz8} z*1+&WsZ)ZZDrZX3Mn_)0&j{iNssa(n-_z-?IPv0P^Nig2gGG}_nK#bll~Zx*VCi^( zQTcXcj2y z+9-y!Z`mNxj5%8>L}UaHMw}r|FBKUrVgn0OziBqjFU%9~$j(l8X{t93Nfc}L5^w|S z63-7GoJ2ApR2M5E-*2@-v8iR5$(W=4nMo}-irPVr^M8k&DV|-#m;VM# zpQTD?Dlu1AMVMPSer`OF^5iGb+~*)|#f@pvk|nRwv=j=9X6{OnyG5pvQB4;!RdfHN zJ~GV!*|vCC^-gTv>aObII_>xC3-C0-|8+l9{3iXHRJg_Bc(H1~jz85$K>S)Wvj~0t z8!Q+ZMuAE3rCD`t6lv3%PluRIqW=QkOyfYeMYDQ(sg#xgbK)-WAGc`^mKY*OI_73U z8Yj~?TM7y3+d@k|1PHfC8>oL$pl>v^Vlh@dh@M|oYT{|0xPlKVPqbcU z%t###xR%T=LW~dqz9b_rpfZMtu&3X`zPwEc4x zux22^d_AeQ--GIMm!@TyI1fFj&LRoRw)e)r=G9gQh3mYQgWf5D+x`|WG56X!S%Xk` z44N^Ofg|G34r-t}Vnj%) zmfwSAm?_ET#(*=7N1=1%9U`_=9zw9Vut@h$${PU2$83Ti!m9XS3=gRS=+jiO!PZc; z5p(0Sna?cEZ`JsC9@esiY@v0mbWOR~WN>`iM|`S?dG};abM5Ys#_4^U&IHjhJ)|fm z9~=^}>Mb@4&X>Cz@**6@MQCzVd-gw1i zBf~Frr=JHJ1)10Sx$!c#mej=F{Mlts`hzKGkGt)m-A|CetF4jIb@yiat);(LoHAuH z$oR7BlR<#&3||eMq1*8GkhDG-Z@%3%f^G+9VEt*mhNba{up}2Gw7B1-wKHT0g|xk~ zM+LZ%1TAG1BkJ4_9C&V<#ff2bZ+TeeIC_Tz^5`#K@KOhUkUogW)WMOEL6X(X@=_=Y zTB^hC1;LSJ$v}C4yTPNAF3E5X8BYV0>E+hMVhbYAQ|5Rwifwe_W>m-0NQO>ZhNQhp z9>s!6hm>Ma-TW8gT4(_PLMz=6z2OEF{V2qK22dX8X3q8-J$kgXqK4$RZXW$(=;F+2 zd!aej_nxg2wU-d492MdE6OE3wOnOOyCOquzz06+|6hvTL>E8ibSG6nUM6<=HWTDIm zpA4sF=h#5gyp~ERh?=rstp}`rBm*5TUcC6tesPZ!rbmFn{MiacXiIvftLO85VGikI zngE=9qDZ?e>2U>;5p8vUO=-x0y}tM9 z1u+#q)bGJNB(7S=1Ndy0GmS^+$rLyAjx$vNCVGl7(CzJMQI0m!zc}?-SjGlP4&)Yc zPS+H+iP_bX9u9g$BcGVrNE3kHn-(u3Thze&3#5M@Y^)&JKX8Eb<oPoplkmW-i+XX%0u5<(Cd~xRrLMaZLJxwVTnk zN~et?W=zYYa))qJIu)gC=5|J~6a|Y5d8t81*v>Eel_gKR09!!{l)v#yJo`QQB zSZmsFyGkGg4Jg88QU>XMe)6;-t&4G>O;2(2YXu4ER_0@wSTP9a}PyW70i3tMsYWqnySX zb?Az}gER$!CnZbvyCiFqwIWW4QqBa0A9H-ZqdPnKBr5h^tdy8dLN#(cN4||eJfhhR zZ{lrf&DCGBWJ$3n@~(&vzsFscREtm4@O^pYWlM07ERzs6*4%gvRr^kG#?r8C? z0^AvK;{nB*XLGLJB-9_9J`i3~gbmgEK-IE(V=ET^2Wojy))8q2&mStR8w1XPJZmZGK-N*Ct@{YUJQ8yfO3q(tl zY0*}3zD06Xu?nsA(UB{DEkjAPQvFjhqjT;0I!q)4Z_!F8;LjXtznz&S^pl$h*cfp{ zi_2~*mErfka?ikdBhJ4%r-8-ER?0V!>Hz^It(#)`BvJtZW=XkG;QHSw+*VWi2+sE$ zgl%e(5fssoOHb-hJXjB&2ev9W0qK30I+Egzy+*D4oz&2@xQ{SoM9cFK^FJRH1e(Mu z0B&{BeZRfY`7;k(R?V3ObopDh(&vTxk1;|r%Lg7t|4)^##QlR%97ujRw=k0v<5Icr z_$?3oRu-X2lV-l#-Lz)OxT7m{%y>kMiYCRK%L8lvT0V&_17Z-uN{W0!6hgU~4AX~f zD%bYr!J5K%r5UFok#7~f;q*u&(cUx6_Q)B@H-$69oFpZj3S>)>FJt1}eSJ-fYst5W z*qX=Ry`wxl=0>F3yv;=6x~T=p{E#Y107=|yk2Y)+aL{d>LL!JUw?MC4H!Zv|oX`z338^z7m)uRXd z9({k+c)Q}4ws|9Ub)%n5|3`+G6@5TDBM9ZJr8_(HJcR;VLnVxY*`q)S26BZbPW`@f zrzCd6_lu4sWfX+T(3%F`SwqJjE7yVc9*sAI(<}0%9)Qi(OAeX0-8Sy|cX zxLrFBb4!J5c}r#XVLpaGJI=4VocrW(d{yi9%s((>y-Op8xm_>M-8{T3solpP4;nDw zOzI_$LOl?{OpEIaOZxz3)kJY~Tj0jplO#Fqmlg^t^_N}i)l_@}|JC)f+oGDn{{8!N z1d^;2)1{Ib{juuS9P z(%fKDa0jQkt(WZdi)!M})EdI%?GG_V&o!%wV!zb(14_7TlT~Qw=f*QD`HAucTR4)5 zUf~fl-jj_j^}i4DQTrGGiVp0*W`*c2T-d?C`OcPEt#O$^jmK&Fc6}!Ut?`>OuO+^A zr1hlOwgYAUe8<*=X^g?{QvjH_v6=SSEQ4;LW}!FKpu|szqS9lb@xro z|Lf?{V%yoI!E(q|_aT*5wM&qeHin8_*DmB>3P8$0x*MoBowRy*(&vhbjV-yxJwJG` zx$%rx@D^vFPtLS1rkEya%JKWPkn-r6`Ex5rk#!V#BDuS+okPlc+^q=l$$b`|Q@<9B z%esghH)Ry6;54H7qOz+p#g}qv9`3&*-Wjz%sM^s6HHBdm?m)Z2Sr($vSQWQ-!Q#b> z4fsmFF_?U~^GeueO zo$G|mL8C`sP7N3HN2x23(A<)vRw+5T(%FtB%K`o3=?ntsN3D0@Plk4J*GrS12yRX1 zeyT|CaQ*!zUVLhsJ{0OoIVdBlfV>_8pG7sEZ5a@&P&iEmGi5+)_kcWMPjnlK=Fw`E z!og`vfL|F;VzZ9JNW#uI<_J)bOeYAZh{r|o2WWiJCLb$skTjSnf#^U=F2DQYi!X#I z6>$J5#AdA|3@Y*0EVJ#Sh;h+q;j3zI3PH9{{n>nb`)eteyu7@KJMK7%VoWeuOn^8I zD0CJxE?g*P6b=y|Eo@zYxtYz5Dq}t3NNyHW5-$4r=Rc%YML8%1HtX(QV&Y?JLMgjO z=U>MOcN}~M9+@XoH`5!NildXri-e+<&R80IWFnYH+GdSDbSPVUjs8gC)V=i{2Q3Q= z772np^Wpst`TB+szv0y#NBs*^HwF=`1#K0aqjHqs&Umb9AC*=s^=#WiU>YUKwUVTR znOcU>X5NF*K;h2Zldy~y6rgrjWLd;4r9ieBsbu}?4^h_iICsQ9M0_W+naq;E?}eS_d)!m?kwBl;B!d>~FU>xk_9wlz(zvqx z?~lNC9_QKcLi?LzhHLzkY17uzNTTdD`i~c)o=X_2RA^VHgTyoDEP8dr^P(P zl+ALtq%@>^$vUJbqC!z2f}l?hFZ@VRJG|vQ?<-bWOLFj6LYn>#=siVlmEt7h~AK5+=c(VYnEkBufE#e_WY@%L7MBQ}V z_B@{(X9F4V&T^CyxwJalS~MhIw^qs4$+`U5=bvvr+JEfW)Dl3#<}bfJNWMrnW~%5N zDd_j0=8xZg`z_lil*k1dc6Fz61d#N;6<6KNMAB!ST9!;9yZ1FTRkB zkYF54xX33lg&9=s34kZwEz&yyOb|5sLdiL#JOSy?4oPNy@4;e9wsl&BCi1>Y4OA}T zog(5h?w3eB&Y14mPlU!ER>snyV445&3x#WUsQ|Gkx$$-@43{*mcs6tF4k;iV;xZsa zig0!$d4A-nm`mX@GmXV0HaqT~p05snp9A(RV1dm06|)FZI(D;=)#8W8V!Hk8W5tOj zEh~(iYoHiY=u{<)@h~Am&f#Dm)$+DHtW#Momj4F)mRqaPSF|jk)^G6t|4$?hP5riQ z>AS|JA%63Rap@h#B>dB3_1$M3_LW}l9A`OX*20YX9g_xET2y^;e7F0;R{@C!^;dtf zbYaGQ!#jrSdc^+GGxX>Rv$FF2Iiq$ZWR<^rkhL+PYO|Y7!S?9#^7<t+9c>U#%09#Ff65j1Yv-gRrVnU#whDTzt%hb^nWs zAosw~vf-^C?BD;lo)r2oW!?JC{Zsyb@x3mSfUTuB;j)WZKV=VWO_8j=*tDg4>&>vb zL65{b;&(;}O>$XF{&gAFT=l=&@cwVVb1z?)u|9pIT9Ek(gxS;ha14Qnq?#lnU5F>W zxN1{e58-qRv>SkX9V!1K>e5ppqaz_9HRLeVCuqTiY!sV5NDYrL^`?JwLdy{hPB`*- zec~GIROt6bnr?WNIEJs&I?{OMkx!SLyPL~&YKT7Qwx95-9Q=l z2HoLZ#==0zC?rjvuq8w9B_!CsI!Y81a7a$=sNTM@tX7ffeyt ziU5|<;59tTwZses?U|lXj$GvIaR4hAW=Up6Jizg8DS3M0772XQ&@he#Eh758eA>*W zxZaR+9Pu6BDD5EXKLD@Ce9)AAEvyCn7d=G6q8>B&R-qVlMVxksQGp^7zhZu9#p%$u zO(QlpRK$KM<7&wL$A~w}@UvYMxwFTPZ#(fjFwC*S+!43#*tm&dAOr?$KB0Q^)@c%z zaKWF>xSBgQDy%tiQhbbm7DQ6MCQcTv2hdbZC6$qRP-=^nTWDk+$dDd5nAlxLY)BMo zJJ{kiiyrrMjiVeEAiqc^7qET(bX>c|RQBydA$TWB0|2%amB&_5%8O7IL4JA6J%u88 z8plzWl@R4=RBICgXMN0 zc7F6-9XHrg<`YPJ|9A7Q@t;Jt_f09kNdiGdun=A2@MT*FtHO$~e)=3N2iIPV+_i4y zNaoXp!3`F}OxWM)9!T5mIRv;={YyKY?V!l*q`Z?-WC*FyDQP@twKLqff8OcSr^S@A zBJL$%O9{B%%E@$&YKY%dDn_Hgv-h(7low;!8M<#F#phe7?Qq+9+o}o@#W}Li%>rr- zDP}as{>T$HVmX}YZsPFwb4Optf}VmcJ*Sj@FkSAFh-_0K8IB)up(#QpMyxO*Q^U6H z1kWiT^%98F4LIVtcR9z}dyvS`%xxrk7_z6x`^v|c^@G%RK>p^Wr^AiaIYPt)^v$ZEM7tlr#JXv`qU{rRs1M+)gk2#gKFq-j~ zRwq-`$D2ClcZFQlu;U7>wvC5x5Kl<@xbC#xi2nVY-yT47ae>9?5u(GE5SJ43y3os) z-@ZIFZwQBixdlrvGi6MYc3$fCozYE=%OrtMscfq_7JGmFRl{uKk5Nj@VAlA4 z1V@Wl{*jm6$F5RpS<+t_;;d%@{56;2@WGob`cnjEK&|}G$o<$lNpr7on2ehf`>?CX z3n=%V=&PWsNRa{Nw>h1&Wa<$)qsvX7DVJ!vlIbtMENY|3?ZLy)hqbLKfl-mk;rqI_ zQOwuCXKJbl;=w>Y3j1{1GcZZdi)CU58j0X3apX1Zg@s z?M9!JgHJJ$en%1-0ol-cm*j||e~o^ZA$XaxsAdYR42{q+YiG6XZ~D$&;YyQLKoF6aKp=^QtZ5pjKzoM`}&wuFv0SzM$H zib$oWcKhPXK2-g2 zeVr@TTc@582bRhKWOOC(jzvs5Qu3uc6m7;_aNI3WZvG+LtYTuXQhD(-*jS{;RTkm2guNh zG4tqb6$Q#p|27lb{1iIIH8Yt7_ZIjdf9rM{N0~#ruBJrN7|z(IL{}j&1y2E*m&|~e zukSy*C(C6X4DEcLyNpu`NOguwBYe*kmWGpAHVd-tvkq@w&|g0vI7(`JQK`d6UJbwN zc?eEH!noQh*Z_Z4Mj(fXS{|K+vZ-P_J`!S|efnWeJXFmY-n|4-xLSzto#Qa1^&WL_ zEZ|RUd;<;oMtt0n37;q?uGq|Fz^V;RuM%#JG$z!PW&jHi5(#kIn;NzN7Yzmhz^IlP zAZNe|B_$HcIu9cv&u#p%h3_bDuQ26HrZ-`LlyuXy^39tCfZ{B*GZHR|&+Bt;Hq3eF zWM&j;KSySdCF=Ix_DHFIIx$ivuESc{Xm^Kag8R+3-2f0JoT%hUGSUN4lL(;XUNRBo z@CZ{Aq0>0SWq>FU-@WJK)~;D&P9iBDn?n1^931WeVtNBU6d2GMG#~<%(DS$ey2oZ? z@ZCfbJ85`|{aeK>|7wOK;a`f-6J(DN*XKHOI%cqy zAo6zfNrHiud;npoluSTs18Yd; z+uSc(i|w=ljTI_1 zhdk!2e(StDEG(>A8Z*fWQdU$qlHsP200!YYUZlmGEpuliV&};<>@Uc``4o%V_wRef z4PAQq)(RI%Q5CsGv~?UTsXxuYf~B+xu1~yk=OO@L4zbNRvgYEYJ-=cOUPNM}P3srkKomsvEK`i;C;&zm_{mJD@#uLI$Fe`krh zdGamWo6NZCQ4qwV^I89iBG+Ez3f@ij)22>cdLgZ;e9yb$$M=Js-vFbO&?k9;mY{&x zVq=;luYxJrb9OVjNJff9mU-~t)_4j&<6ng&?&(Nhod%NQcb7j?v_HwBJ_iyAmvjiu zXL{rhgIAlGt;cd*05KDX=bx-uz4~Ri8_7-(0$#z4fEZS%Uxm`!it31=U~`m~pTQZp ztgc8P?yry738sB-<)wyscgfZ^IS$C-|2o>}f`&^$5U54RkhE|p*Sc)EKqu~P)1j-< zd(Y=B;zjP+4KbKY;i1Hr;DDt^tV2J|wfC{%lr~2>P{3(MJB_}zPlo0I56k*IsaT#p#>)th+Ab=8V&k`^#)cAiioblFJV|f|RsK z6f@!qT>EjX3q}+|G{~q4?$gnTIOaZ}-~ju(x*?7ck8?!}B~!Bv=9nQ(dpt7EC+Tx8 z!;h!pJxYvDSsb*8P_sx> zkG56yd@6BEx8GFIs7L=U8`2-)OXh~C3GiMtQ*9`?yksc!`p6shoM;@fmScV(N{cPg zoOt`vCo4Q7c{Tr%zs^9-^p4wr3?TTJLea$x?M?*Dxmo%T#kxx`EEcVm9Aj7|m)>lo zT$PgG%IdoZopcJ+>6+pg)5nm>-N(nW_$caM=}*-3sj9A)kyvA49<|ezNXL%HGPH}> zT%bCg{I##kht0wZ*qeGI7?ArQ-Yc9ym5Zj+fF#PU!PvU&X2Re@JS`Y8>YG<5d9D;3 zRGl|Jz3@^#4~oIjdqdZ{+++@kc${>|qYqyZ6D9p=oYZyq4xrZt8J>wd$la92_q4K; z5Oo&YHXaH@r?l@%S@C3VJ^gbcrtlucE?b0VVNOUkOIfDRn$70QZo2FzH#tFWQduPO zH=gM0hyqqwYII}u2>Ki}&$~(aB#uK6DQ<=P@hkM~KXhn<)km?qqTv`>;Hwvun7VRo zV4bhbz}-Tl#+Us@m_VOMw%@B=JS|+8xaU ztwyhuW&MF+Q;%3%VoNa7`ayjn3!>-t;m^yBt0YnYWs8^_tiAogcP;{iPze7uA+)TPHG7*%^UP zm<_QG51&q%Ki?3q5}WM#Iyi|wuLerLylu`HZ%?Z9`;B#U-4mg&^z4Vegh;z9eKTCd zuA8G+N(F+4`uFn^$r3q3;)>pZ>Jp|CJ3j6$s!Y)X4?WTpZFiHqaTVIZTW$LB&GVG3 z^vF$O@KF=3wC%ncr1`Y3*cYI%H6Kz#^myPibK|8YZ$Y{x7yS#lgZCTr48&6#q~5=PNk7w?~*VleD>p&;aG~Nz^G*PN%9Ieu$rpdRuN+(DM_c51*7b+ z5fk~ap7%;Zo=isL?vq`Ojk?BxU>{<@Q5)+CFGeh3x0XJe-4XDgr@b)_587hV1p9od zXty!YgR$`yCR8QT&q%g*Y;in46?RGQ61m7D8HC!V?WgNVW_Iq6Qqn`$)5sP@lNEcN zldc%xXN#c|z+ID>V}emAB= z&(XI6+ujm;f7VI1^#>ZyM@FS3=)~)K^}HPC++NY8gFNitWG)05%hjHs&`E$=&jRgM z({YJJJ2JnP%B*7u(EF+E5=2yYM;34fXLNiot1K zZtN^`-x|i@eg4G3Ia_W#6b+j}D_ki6n`kl=?JI)Sgz67`&f{t)E2fd(L0Mf6Tmmgb z&25NiZOi&O>egw6AIr$oKFh61kGWfS%d)^t$XpNl0!qH2giA+;T z1~bk}x=dRO2x78iL-`4N;E~M1I@DG@@X_%O;4@4_CFCP#COFmM;l_sD_DtoDzGq28 zY>y`8PaGLf2_mymUOEhiWI!hB7HBB_@H%g9brz|M40ORDP3qAcr>cEikC{WXEy7o5 zq5|^W&*~EhOe$qDa(`ZpHu4dz*{87T&?^Z)=#}(&z(t~aR`JF+yEmFOJ#^VkRByN8t-h2&PyJ+*uN&MxN3Nn-XJI))%Yr* zvU$#{&iOF#sbgR5B^(YR+(+{10Z1!>R?J4Rc9KjW%gLyJD_KfjW=G z2!F|j4nL|V9UXUHA5ej>7)=>JDEXac*=s74#exoibV(G7)p_IV(x@44RtUzbXL`Px zXikXUKFbXdR+RBd;{$O4prqlC@L~>1zx9PfnP|q7{J=_b;UXEUKuZPC73y9skLPxLi~Y%9~&$ekO*7 zhBYPZi&%@w_v`G(q#JKCvNltpdfhonI+vIC+jILA9u8?0mm+uz_3GzuI32B>k|$#~ z0Vul3!5R3SdGRln^*X*O6)1%x_`yh7YIL-1?4gWGR#B4Jym@hZ=NwhqG!yJrR_(LS z^vxlt1-{!(A#^j*jzO1Vi6hz7&TuM^-myKlj11M?Wag_Al0IrLTcZB=-+zBaEthol zQC$RU)SNK8vyO*#IPVSMoLnt3hqmw-{MD^%isiS&W6eLcYTWXHFUs0fq@LjUtz$fe*ucJICtM|Y14)%8v3ZsFIY`$P86 zf?Q*?*CVd$;DuA98)`+CC>OBbj@2+2Q>FQn$-6>x__dffGIfGd|CxR%P_bnX<}G>W z{hmk7p{lBSzu+o;^rU1*0}W*iXso9|E_fqkge1QGTEoVn_v92sOcmk$#6wOV@t`QG z1rG=kv}ZGDbXXtE=ng%%1ALv2xl*7ZJhDGewPzeC#*g)D=b4YTFYxa211`$B4(zmX zDhQDm03hUyEaP>;`i!xOMwU9BlQCU?7*x8Xr7N-k4d>+ZC-ySoN%ZgaG5{sp_#4HE zzs?dwq3)M!Jxkc`Fn3$k!SHw9g54U=LIvC5-6;@Q!O&AQjd%R;J?K;?&_+Ge+KAy0 zJ$&l0o;Jj(KG9B?%jotj=tp*f)@1f$Ybrmayz6&vRxe8$AL{$dh&63uUjOaTr<)?n zA7paV`(=NnOBlnG)+_-y5izl!x0|v=668kdH+hvpvsWCGKN4&w%-ET7%H~k!s`Zva zA_SzbDA3wg43mEG3r;*3z0wGOqyjrfX*w%z(8Lg1BCv%(%(lbm|J^xiELy+zp&(*e zSJobIcYhhT9yoOdO0xX?{54t=1`N10@`r#CY15`phrwtZ9aV#z#Qt?sJ5>F|G__Ua z(vnXfeNZ&Gw65H`|8$qR1k4r8<*l9UWl9rEMYIra(U_!yI1n1peW)=!pft*e}96i2k#k6TTzq_B=yLYd*Njz-J z``W19(m=u9za06)GVGW_UDT?2Y#Fov;{89prkyIuxjeRN?YbXKx0i4f$5AxDex^OP zYLoNaewKYTt#4ICkY7m8g8k!Fx7Ot;3C>)HnzE$&+S)tu@$qJ%rMmA523EEGNgT`~ zi$~>=1a?aDGJWvjSj^eFWh-5Pb+>ss?tY=`RTtLi`%Vwd3kmhuC`*E%efp}GaSe6M z^6l|#o<^weNkDAJdmR;4$D{T=_}DdbBC_9EX_^G|S!s9g?(6K85OwLU#i2_&-kxQa z5f4iG*FtUVMb1`w^M?_gPv9u^mKr6~`FQ`4e=D7@jFJ}BR|#st-~ajJ&79R-j_a@e zM=%-k`Gb1D-rnbrXC%dtF=q9J`<}m!co)9O9VO?%X-=pq=$WG=jz4-bhYSmt-*2Lf+&a0v7S+7Nk0x4&?C;MNs56UsvF+TGMm7R%0puaK4gUE zJ>vOn4ikhzQ1hiq8QXe z6bnDvJkHD6=b3S9tj+1Bd15oaHxNyVu^n7XT$vUCk65lpuIDNF{Jm&@p0XC{GTi5~ zZuwwRp^mFa?KTzI{L1k+W)zs!moEo%AItQPJuaR7WgCi#XCi^H+`?hFa};q&a0#MT zp9p(~7S<%0J?ryl;8GA9| z6AFHIA;B{EYy9knSb~B2mt81*{yH&prQCw+#$JwJE+z36EbX~PI^9OkHLLC?j64WP zmU%q=8L~}o2!Xe}urL7jf?qNSTsW|P?b_^s4Ce53+B`6vq#~udc}RHO8kkv_Pk306`^zu&Aq@~a@Z*uE$e2dsgWqo z$m?I8IwzN>_~f|PU}vx;0fpkV|H<}}d3->;K(+FFaa(Voq{(4VGxtAkVMq%=-B zo`S~xK6&S}7ZwMReHNr@L5iKEu`bpuf$UBi zHc~Y87)_(|%PBZxf18gKS-X^rnuLokiym=_fU`wc*|xVwnvpyCSApE!EBN@c8ABGI ze=^M5)2z6b2Qnv}m~n>Fv}YXgGtQQ`&NT#L8@pH8%zyZB{;$LnzQTT^cB-qd?;RJu z(&d>RYd1p)`3Wxx`rkXw%ltAaOfU%b3FT&x(ISC%4vvnRA115;k3l6G+A|OAq`2yH zg=@W({==6SY%U>?q*JWt{#DJEVk&&yGBVR3Pz6ohD-NbspQJLZkLsdx z?ujLhAFAf1A6x)q8*MPF$9OfY z?V*mUQ#jb~EaGn8?%C<{%>gJvd&QmS?!4V*HpaKJez4I}DrNPHBrG1@)YJiMdy4m|Jv67k2HdlCeN&%~}9(r?oxaA$7>HPRPt$MkO;g z${vyc^~aT-osn~rjrS-JsdwPb82d;c1S2{1DVIjM@>a$TX`%#gC<{djT@sdQWH9QMKFo+B9S_; zFq-Qt2V&AA`p{!Nsep9A)MU_Om6kXYdj8{lRhPNqDG*FOFe;3sMFxg1rCw%C(t)Sm>(5>hU&ok(bAUpid+9rBGY_=e& z*2t3k$MF(-9e_jZ+_dO=EXsn~Nf+fU%*fRZ84D>D!=Il{*7=2{B&;HjKip^k?%iU7 zW}O`!{fJjFGjQgVxXNfdF_ zhb%btyLilS2B5H#5j#S#3o=1*cQkmF)JXdeM>2R#6}RWZ)SI>&;kQH~VDwV(YZl^L zfP3%VIY?xYKowPJ!DNYby{Yu2Cu#r}P?iXE?r0d#Nh703tD|nOd<2xJq_B;c5Bre* z>q`Ay1EhO=;7s>BRn-#IxcvdrPAH?Jx<=#Rz%5L?i(G`S_7O5n4>n&!cPlo4JcvOX zu`=Ss=y!umZ$=3tgH|eLWDVDnPJ--Pc$HP@UTDzWbM+|4LiE{oRRZ{MUKZ$D%YU8yOeN$@GiRP@a-4R7LOh=VuUmt(`Pl;cr7%X_UrbF_C3gV zV|Ktco~8{|t06AyXZ%os(q3KpK=Xbw-EeDq7@Ya}YYE)u`_j?D7@{)T-F-xv`v)%K zjks_+(}h`CQ}ELixh_S#Biz}z?H?Srcsw5^03}Dg)-V?~RW(slYZxPxtfJs7U(q#T zwWX&73%zznMFa#wZD+V$;c6pcb2tr^~}xp#j2l*`a_hv*Z&qMOd1o3{+2<1 zue>B$`3Rnc%8^h63hx10?Jy>PVrLBCI}XpgU20=xig4dIKXAamOeAh!rSX%*B{H&$ zTw%<{e^y?4#cA^x$I0vof8ElXjq@vwE>=q>n3|f34hX!{YKbxt63q@K`dM3}Uy{#) zV=ZRsZC)bnfTBAf+q2^rT}vv4_f=$_*_3lr0s!5ZeU7HDymXeV+W&`M_3%C53e(y+Cw&{ANh`irDYc$UmjO2E<~Jx_{z32Cf)4mf`qf*i&51#nSU z(B6sE1x9Px&6_=!qOlRb4N>gzJ^^6Jdg+2VNl_vSJhTRp7S}1Y#5o@?p`}9i%__4X5&23*9}4{xe)Ed-2qB4NTRA?E`SqBW2ttyK zy9&UNdb`E#<;D#9S&T&l{*^XbK4jJ5AR?C-2*6SbI|5St^ruA;52_ogG(KaPTzXWDZN1xfUgKm%7x3~6x6XX~d)sDh|n&ky|Td$K{5$SnfOyD7O> zsjo_qYNWDeS8km{f^#pY$2>*Fd@2}yGWtZ%sM#BOfR6#_Ta1yC);|AjW#vNfan_AE zUUO%paUxk%O3(}NrrFkWIGxGShjAljarYPU&J;{$=@Cs5*`!rfNiWa?BMcU2sln{o z33E0RqS`ZM2^<}Dn`R03y+8V83XwO3)`(nXOd(Qn;GJ$NzNZX=Qtg1oS&B_aMu#5i!Cza`C!wX@z zmJRJDydylm7~e?DU=PT+!?LRz%~4vAz?hBsVMMg4^5LCDzc%~fc|22O#n(ix^wYT|i07(+D@HG^G7AkoC*chO)0704 zlX7MHQ+2kDw=8|q#`Tm&GDXggXKhSw;;%vQkmroPO^@tEJOaz~KU1A z4tGQ<^fEtxAO{wf#@B5dxbH(4+l7FN1$rbbY}?HrS6wx zX@9n-0a&i!rxb#($_^q2evF<&sfZgmP%Ld48ykgfA>$&;D=4eXZ6gwN^fVkSOGVIy zoF@y6l~VsBnGKB|hL^$y>O-H#PoDk?i8YxTeC9my48OR$+pyxYgOLJWF$`Q=p2Ql#^dgY&BFD3uVfqvF0N^e%l>sxFnkK2d8{#vNIli37()-ip z>#v_zZtt!L>rO~j#0&#rlR3z*v%>W#+zh~X@#zaBXR0Xp3dx3saE4;TC(1|U?28%u z#b!&rwsD#NR#|ik?R$7Vo#b&B;jeKH`(XJoXD{h{^5D)lIB$;$p@>{$J_hchKh56u zkha6Jjfy8bsJaz@o$lS+Cq$KOTGowNF+(&tZgt`&xsK0zQsKl1&1l|w0`hXW} z`fvt+9jvbHtFuk5ehE~r7`v}g|7=usPcFJkdT%Ugqdc9@lctMAiQ7sJ{i9%kmnJDb zX(O`CZw)lZY3KYAm;})G?cKoVb<<`yn?E}2KGuutS$O6%J>NEpiN74wi>0- zzC%nmxoNI*rgcn8Z+@$f|NSRdWO=QZ&!$_ce5UzdJz4K?^J`w<y>*jd<{zm9pmj`9|rc*tc_7vx{)_w80q<&)L^7~1;!4x?NC8DmFftWfAxCRaGA z$aT2}+?~aBam`kdGX4cK37>|oockgG?ybr#nmD&~WAioe%a_8MKVbV2N00xaJ#)s4 zaP!KJiY}iDAx!`bLZVxs`lX)lFTYv<-S7VMPjV1N?pLfuS$RzFJszUXd(EFOu=EM6 ze54f3HvUv`U7scqITLa6E7WW!udorZAuD)`@)aTt9K$j+>iJ($vbYsLFJwALto5h{HsSSUY@AqYf|!u9LI)!2TFS(ycHBhF*|5m8eW(L%dLpT%SdVbu}Qh4Svn;3?o@ zEqEXw8?t_yizi~YQ~c=kF7nRr_Uve9&L`E-BOS=_Z8!hbF(xkk$R5KU#Z{^y*xWMDISVL1vY2>UGgPBAZUja}+1 zC?u<8kurcBnE|Cirt`2Aye3@({@k}u#8P%t=#e37wxVdhjnNN!leeYmDBGV;{*GJq zr7if77*E4P*qh#eYsOAA7agp$k1|oM%);wghb|_w!Ic#|E0bR(D1Rfcbfc}=)@Hie zxb@!^t83J>#@y)HvfI29l>fzv^08U(Ef{Hr}=C=O!qSk)_&z?HgtNDG}pKJbCj38nP z*JW8xny-8QL*q@h6X@q4&$=Zagk>Bb@$6L}vL8}|(PHJ3yg$GXh8YWM3tNlGb*VkxAP8ySY?0Q z%cbmk-p?O=%ShziJQHhVnd< zUz+%xEq2g&dasSPd~)%-l(sB3y&=Hq&?OU$AXQ*UzCkj zD2D5k?jjP3i0lwE7?n&A-7|8IS)D}l5lHS*HoWbiw^8sZkF%#L9@L;0E@>_z6^iak z4=dWy`2&sHdh>MrfNviv-*BC@9BDgTlzpSPxQuZHGO{?5E$#XR*uWehEnB%r87;=; z@NZsD3q>jt*S@@5m#s>~{ofnV{145Opc()D%XPXS9RK~xb$X@!mNu@D7?A{_zvL4r zautH_ARFO=vM{?{LTXbsJnvd}@Kfk7Z#TqbGnsX2dG4&$AA9y76s?>SU+eO$4tAzX z=gu}Mm%g96`z(?@exc$zT@jqba=0s8Q`}L-QjZVXjs` znO`)gYu8TmwBP(b0nwtjo7nlve|ZlX(nqvl9OzlL=i%CkY1VpEcfGw{2C8A{MY-lI}SCE9S`O9^gLu6#@OSzUv?E$-!xcztwf8G3}J3SBm-+jl{*Hrvp zT-pE2zb8GoJ-qPjppnC!j*Bfhix84!EjdsndNOn;X_T;(wv1UAI2gnoVqcKiUw{2o z*~~&UfUkH2b{#KEN=nKue-Q=ZC;nXG1B_K#7e=Eo+RM-1{~U;-2+ffy=;KK#bp+!P zq;o&pxMt&XFhlhQU1ma@mqz@6fVEA!rq?zsm0zPgW9&k(GUIP~1Aqs40~8gT>XZUv zrrjG*mV`F>hR_TUdH>V-3etBAGUE)F9HR)^8GxWd1T@zFfX}C-;^bkXYy$R)j4b_V z1P%jqtfX_v67R^iWOxgI#0w`i+Jr0(^@lYVJfhKOivnO0e7gut8Wo!hQ`_~sroW$1 zvFn`ewo;Xm>-X+{();Y7BRvj12^naxYVyDu``3RCy!~X~s7Fs$?KygO$xk2u{i%JI zTU$GiGdXSRt`@CZxx@F91EUK|b!>c+3_l#W`h8Wxc@O7;QpdIBKWE+8)*V&k&$DMU zQLNX7jT;wSS@-Dtw%!H?29(b2($dmSowhJBGU@|q>U!tSo%7qgX@g2%|12XzGpOMH zVbveC)W`}Bh(>WINc!rAcT|Y8*P9GiQ|rbDIk;L-`%br=@njOUvEd+?*Px zPdT@s%P0KeQk%o{k}R!S)2Fbbs-~u9NyJ*Ws!p9cS=rjYys@v%nlz^k8-`LVF1NJo zv~1b3w5YEy7W{c<{*%(#$v8&DvFJh@$3uNA?-lD%Q`!>;V0E~few>K4IAnSUxR_qI zT0!O4g+|Gr$jcXiJ+~eB*cW7L1v!m|eQ6KQpMCD`?%ECYRL%dSq;v+W+8Gw6?yYZ) zFtaDM&TVv2x#i4ZYWgr{UemLma22W)s9$~cl`CsHzbGg3>C;oNG3zp~4MSplyL)lvw?l_^ zASflAW6IE5?S^e@*RFNgy7dKz+1ZV=Ug^xLnLdC1S8RPNfNexPZH)?(QRET;re=gyYZQ{ zl-AL{(&fh=1JWz=-^If)GW;|AGLR$37-LhrNP z_))+hDSie^A8;QPl|>ioNb!b%5-vo1GS&B#J2KS6lp%vHEiKQbIZ3w~S$#fG*Qt36 z>ZpKB?~@DW{OLAz^+=FOmkR|3=a+n^b>tuv*?BBc0#OoM>aGH6_3OO};K5xKazZqn zOTPhpagT9g?yREEbabaKaEJDZK2nZo?<Ka0CoE|*) z4xb!q2Gjcnq-7_)x=dYw)ao#t@?gr1g;xvsKo|D*R-YTbrQGM0o<5{Sh2c%mXRdOJ zmcYdxaYv~L|8$!&<)AsOf3xy#!S_{u7!TU^?&*z0n&UeDrB=?0bEjy9L)RZ%(Xg6b zbr|R2!RaBU-JK1SXU(0vSkPXm{iT<#T^s(;F6`R1#X?DLz}|Vr+_~MH*C#`G9EG{+ z;zv?+-myQY{X)T?9ky*#J#ys8ClO`r+2qGgEDw|WJ^S<-s@3hl-wY6ab@iw&zI3ma z>^?Q<9#qa@$e}^JziNesDJAq8D9f49aIaDhnvktsrr7mEmA#vl^GooylaqG1v%!Y& z-uKdPgaw8pzZq?rrQrv9`blsdXT@NS2KN_B-Am=Rp)*zGyjtK_5DCO4YJ11V586tK8T<#)p%Kt&zd&l*>|NsB? zUKth2C@M3hL^drE%3c*E6heto#>r@jl+i1*a5Ri$YglQJgd!s&DqG6v2*2z7bjIA%gUJ zDIXXmYbwbRK(HGOoYHS#cx>3w!U0s#rqS&AOnXJ$H(BLp+a9hp?jU^lHx^(*k_V$= zxZ=Z1S?OGaLCyw!xgqtj){G<%PUbt)w7u)^a~Cdjryr|UP}!5;l5ZQ54Id&R;2U_b zZ0eKi`~8(v7#R@q3hQF+=;&B5j*pdDnRRK4boR~5ohmivR=&(?HJ7xw-j5D5zz@s+lw+UF_r-Y%f{-rBMq_v9EggqW7Cng8y{1pR zZ=CBg&(PSg-@qSLRo#I5^?A_$DNMeudU{r61G2#=SZ3o~C(U~NA$_0daY&qs+!U9b zNc)8-wAOr%kG@XZS^{a?w1HrY@Gwv!JXD9v)2 zr{x{#Z?!DD$D@9Ht)aEu61?xOLEeTGKGw3u=3Q(Wr65c32Pk|wC#NYL3!}0`XHY9^ zaRalFUO2brck1vh+xPAbrG^|sO%OJ?Im17Ei*6FFXNZwko;=aSuNi;uw-bqjD=6j4 zu?3ybEZLv2jY8j0Ow%S7_%pU$y}>}F+aV7ZKtcV={rlfG&u~-;8$v%S21~ljXhBPa ze)au3^yt-V^S;F2rIcwb37QO`qYPSI{a}BFwT(?W32F#&`)rT{xamlF5}y0g{rhSp zRM${TYO*?+4oegDJ~aRV$N?!$Ru77tTTl(^v79TmubfIw_?^nsCu)#$<&j1wbX>J+RbE(tm7$MfU;VX|#V=jGS{umF75WfOymr%@cGfJVJ5+CA=wu*> zb7;c>ATVUIXP@|($h^~s?K_?Ve9xKZT7Yw|Hz0J!%yUa>GVjB%zDtPr{55rxN+ix# z)c;gn<#BV*Ks5@_*TKVms(d}cU!yc#@taEG)Cx15GMEgU)Tff1lA_YOweyK$jX^70 z$fdBC#96aX6wfzbwoLB64kO}Jl{?3fMo{8N=?~x6H*={>4QIGEl&55j zw!-1tTHGScE3q*0;Ga9Q?$NAb&&}cU7HJJrJYTn@)U*znfzDLcj%29z>ej6xCFhNQ z{;7-ka=GX9C*ww8f@AtR5a)*YENuIob$~YT)C-=mdF~_h2`<@~PGCHAz8xZ3RsrsA z#^{CoO7D?W!rtCOOH2LB12bn|oTs#rK9%bX8uVChMxY7J|5E*qzH9M6e2B9AE=tRf zA3t6KOgE;(jCyH5zqZHKpYVKb=bv`!yQp|caEGhtcRnkH#qCRN1*&RuiA!QS3`{C% zZ#3Vr@?%SsJk3ttzENGxNBhjaQU@eWN37((NDQP}iM zg6p(pz_d9>I9DSMOVBZh-Jl9~k+R9-4sf>HBag^us^osff|~p0|%0LpFb8 zQoJe?Fz?ili)=7Q$Z$(mtT4&4Z?EOk7gJq%YDi?9UZYfwTfyGd69n&V4KU0jzy&V- zSTuO{fRNXK3FcD)5=ox}0{}Jy!00P+!~AAW#O#dUMiyO%K`T4;w6w6OC(Z6tuKA6G zHGM+kI*pB*#54)?mgk2OPeSUbLu*G>;UCbX6syA1WKg7BU&3hB11A+kAk#2sc2D9( z8AA;5`Vzv_VN@nR(Ku_&pGYMst@o6Zmi&+F?GC5Vb5mj5m@4$j%Ca!-DJax?@l%=; zA;;g=pq9Sw;>Cd&EIp(8lF>AnzHeXsU%zMO)7RH_&c6#-(w6{9=i7@T(Rx!)ItOmu zU?mKyru4F|n39VFA)?w10(W+Mae0Hxw5d@PAtEgE5gd{GI8%9BNJuf-tOo3Xy*c7W_xq$_pyihq%54N_}R-VRCKi# zIVSZTHSR`uEK)$~rk;Aja54_@n&;}^hOA(oI=i-@+3X5C9eXYAk@@Rrom zFZ_BDhJzrybHbS3x#w$T)qtAh!5>E}B!?m5v`%{SOg0P1rziIZg>^Ia`^yLu-ZkgU ztlWJ=GZ47+99o3_B35>!BJdP_E|KoQQJpN)tM>)#SotQuv7EhX)M^1mtND-587zK$ zuCr;7R5zOtq+ijob=baZR{*5ux2Yw8%)0%dS5@9M9;B@qim&UG?)#n1a2qbKz<{G; z)Mw6_qeUQmj^y0=^P|d}3`(l`1$w7ke$1FARC@dg%^R<22CH}@{|m-u z2k}_9dtN^QJ7Yn+w^y%Rfht-X>cbc=zv}bLh6o7svdI~@cTiheB!-1;UImu@c9NNi z504eE1z-0BG_^ci`-*3mYfDB((7+AF?>!?=!VU$8epV3H%>ngy#l1n(=PZh{ho^f0 zIO<`V?e^hmTJv$^#>FgN;53E;7s1-5+xa?Aow3B<$um5sRhS*l%*`G5rrbqT4mI4R z?(SK3`||upKTFbyEuI$_Yujj_P!R1ud-ZAqi#?DVlCug&>%y$gavJiduU{=Bl+rs+ z9M{>`&(A!g6D!=X-O*dQwSOB~c)l&LiTjaF(!XF|`y;PdW1l{MzCvbfX6cq|&_O@@m>oRxY$(ch#^C*eC-E|~G}y)1 zh|kCHwbp#qmDiJ*Z_&Pe57R~}4X?S(+qPF#cS0NU%fvf6Um7(97+}r^B-hNdohX_u z!eX{dJDN)LPH`k4^gtYf!If|CH>c*Kruk~K=!cOgj81A9J6*O3y6Sas8qqPydE_Yu zuJ`ov3<8d<#Y=CX*FJi1Q8nq9I$ED)>Ty(vETPeF-L_53G`l@&IW2)Owb2|oiBO{q zM4e}=2l_tWKXpcjO+P0)!*zg9r5+MC`o3Sz?(RA3LTk#YEi|?7OW{RAnV*>_J~zj_ zE!d=5TDopcYR7CJhsg8VFR5gAG7WkIP6XXsHyC39q1%qKkE798TlrUIY?)r%j1(VZ zWN-p7gsjQ%X-euOzyE>pB9a=tKuwJ5q}ohWhH+I$uCu>IlQ0GJZ=w^4Di zd3=+u_&2QOUPiYqC!>)?b~e2}1G>37*S-YYw<{lK)zpI+#mA2`=KS(>nYvlTI1o%D zxrMPoQNr}ojR;CwYmRIqbr>P2<;T@^KMBl64$bmxlHWw5e+$-DYt7=LJi`{jqm>N* z2?ow+LIp`{I)2SOBXc{Z;_{6fJ>JeEvW(Ui*Ryf{Qnj75RG%m14hc8C7_Fy1#J4{X z#%`uTu4M7$P&%z1sfWD=a-)bW&OD01ignuCZF~;0R>Vc!S5pY;1jtA9!aR!F;1vLl zac$B9_1wu?|G*mrhK9wkN~;;gjr;cPV;%$9M7$dE6TM8WL+OnKGX}16#u}ex;@0Dn zUVi&0) zz?uhgHMXN`sRgE5Fu<`J{eh>By6B@Oqdna|IXA*@afk1JhKVYG-Lb_#;%?S!IBiO! zoAEe0ZL&7ckeL!nBPp$Ay0JT~ti1Rs`ZhOL{Q2jfF$m$Fqx;DKmVvHzDH#rj#Q^We zme%Vi|C>{kyG=RkH_6qtyWBi@mRB3|;cV#^9Xj+jwLV59eD(hQDV|Lx)u~(exe}k} zyTT>s>A9sNJtJz}$jlr@25wyOx!ofr zg3*i_hLgFC_RpYIU&nYx%%moBT@J^H-FY!7iR0UHj%J_5ye7b>_iUn-?GLRe8bfSl zW79^n>wWnyj>G23mbs--?x~lF4cJ-Xl>3~}JBW5_oe4zZHegSi`b$wkr9%WfIe427F>c5B$ek-ZK&;P*$`q%G|S zl$PduKxml0Y-VB(vy9B0WJU?-Ya5> zBeyrJa%r?w_|1M-d=>BV>scQha=6>0Dj$#_Xtrp zsMWl~pcb0`0U-f7x4H~IY9tE$P01-(A~Prf2zwy#{V@B89XslQV90s3I*&{PEfIZF zi@lVqMung63)$Wg42L3K^JvNHJTeJU^)^{VVs1Q>I;N3jiBHitVdt9^;gWefEe zS@l?mOXU3W^2Cv(VvPGBIh+37$QNTM-QIuPyRVGd`9s@XN^xc{SW&GR3I=o&6D4Cy zvT*;MwHMy_Q+c@xvq2Uo`ie3deA*1;VwrlCAuc*6RgW&+!9s~VXk1+<3}KC(W!<ejFC_MzTq?tV>KveP?3Bg!ogG0gmSJ?K%yVUB+%Z*z{*O!yOxg8P^ph2c0Qz zSK9jcn0{PD3~hDx)a`e>!@|~4^$;AJXN2#B)^!PEEjYlYOwaoyTrUm!VxA$j1}uBo zfa*|JrtjdNe>#EZRrnnKBq0yM&-@P`y+{fb_I2vjz)`HnH3ea;2Uv|7l~*0y5TGSh zrToY9$a-L)MV~)*)N|^gZa$57tN9^u6RrJn zI_g4;U4LOF`*YU3j9EROcJ6@fHi1!OPbpNTl^b$zUjP5wQP z;@!&Kw|>yHvy1nXsI9BHcK`lPHpU96tze;|osP^Q4T;vqex0C9qq{08i9D5@kBgs& zir~`j4@qBU`R3g?xV?QA)8+dme+pkKf#R%~7C3yI{^S}DV@XP{m0N^h&R$Tb)rYY5!eSV6 zmlN+-Ed`Fr_i(lcBDXJYW1e?&kmcg}3u^9V^0knD#P{KhRLCY;VMk7qVIGSgaF{?z zhJ41Eg_rr9S;qKul82>~)A9_^-faB1=ELbB&S)x~lL(MGYuNCzMEp#Mwlcx}>Qai; ze}OgOEnKNBvIep_p{a}MA91JbMpJ&d;A&WZYDKg}mFVzb`fdAF_A9DWub%6i@dVv&vauPhC|nM#IgBZ=PiDnxFf*oHwr<4#2GF>=Re;Y7!Wd=donA6 zu_FGH1|Ke$rJaohFtZ%LeX`ed;`#eF;R8P9)*3W@!I;@Ot-=hxS5>`!v#Jn%h$SQw zyxoT4WOYhbQOA6m*Q#CHIO%N@nsBF-=Cv?t4=1YWnf~!cRK~7dyT&Stp(fjzzB9KJ zd+NFV>>Z~Z$eiuj`rN#XB~GhOl8Hy?MwT@&t85f%GDn^c-Q^af60E{I4l8P~ll{vY z#67yOya(w*Z&QmJxO&x-~sZTG#)x#g3* z@BWqQx43rWyZaS0KwqC*{uYblo~}^SHZ^TWmw_pGL;Fv8$^I2bt!dv?K%7DL?Ji}&nsxqxKFLzP~ z*CRqkRP z{0fqeIWzzVZpG08esuBCO}Aj7VDKH71}sxW-~% zNA&+1ON%ZeGps}V#dq)AIk@J-C%oL~-yuVWO0+J4vca;08_d6d2D)gz^J(+IvboE0 z#;D9K&d;hHsWTC{Yybnn>-vBC_DxN*v*pRs8^?(Uo{>&QJ0;srow}g746T8)2Td99_Z~vvWh{&09p%?x&>w4w%3S}a+ zF2J6ys{YR#9jnT}mpko771uLFW9e=+QNhDSFzL%#32H>x42O!Wdh5N%Y@V#T{|kQZw4uWKNeV zw^b9Z!Gkr~^$bYZ-j9o%7wd#>-n@PFNs}~|_d)kDor#dL8%IX%@gAyg!5PSvO-EOC z{c$7M=a_e7u1)UK*&|1e4BNF!kLCNbV`RpaKekmBm0%;!ftM^(ueTKXZERHFw2C={ z$XGe>CE7{d3~|!Rc-L08mz5^jBguoAC~Hu!;_2^SYLt9cHMM6z9kmm?;O zXkt;K(3N-S!8?4Z-WGW?*pI`=te+T7Xu)8~@mTDzE*`U5N8Hr6-h{TrnFTssyVfO? zR>kc8YW{K2wFb?atv;%2uFs$y+8ghj6WrP3%pWKsrjxr#VW;*SDNvQT%u7jqA2F(+ zBHVcAbmt*MBa>pE11aViM}NFO>tm_q!LB~PS4DiOcFsAZ8ksw;WMSRy)HaD5F;~3? zJFFP?yVMeAaoAigb?j388u*52?c}hVh4+>Kb<9&KDigb{tsZi7c~_5q#+fszRY^c| z-^1E)i!mJ&tI?ywzZNF9b58kCL)OV=2|ne;eJko_{@6MMPKIY4|7~BYTbmDuTgwDt zj7I)H9&K8#K2Qy*LCfD)wKu3-y_&~zp6>^NrAP#>v&W3fg%UUI)y}`;S=_`3WGKc^BAF)%PM;5r0Qo{-kbzE|CZ!tF6+>-CF$P6>ysbG`r z8nH-a!MdW~s%^}}yasKT`TOaczba>*v-2}tb+nUnu<B@3SLe8y|soqBp1F_0e zIoZ{%_i1?h)`cstD%31_oV*`ZDhmROes`EVc<3~bSo4C4;*`z(dH&#{+$}uHm|!k& zcu8*8w*%hzsi;rO9arEpK0KjRPX?gi?O12m${__cK8Kq4quN1m z?iFN~*@NsAUANi+MN`ynkqDCll&K$|of91MO)<{K#vCOAEK;3NljyRa7jF18Kk)5s zvrk19t*0Fy$a=C}o%!(H!baa_o*Z)U{*Tk8%Xw%weO{vMXLZ0S9FX!+6G}7wdz(m8 zaWqUmXr~i$@3#X~yHB)YA%&M~%spD)ILjy0u)tq&Gbd;E+__z`1jBz9Z-#jjQM|e6 z_`snK>CaEAUA=zY!m|z=Q7fMZ*u8geZ___615qQGIVu}frHOSxg_FDlPm|LsvR%p+ zV;GILv>Cswep0g_S??rZl~{jS(9@#!ynp|m8SG!be*J&u;K9R(SMJ=g^9;Fvq9mPJ zovxvwaj@;9z%hG@GTPJ3?3U3!#bL#?+}ze576AldA?-*oKA@ViuOovJsVbLz_s)z< zPW#y7%dc711#fRD)HX5c*Bu)bCntDpXLmW~1YgCN^itNN3tZ;4Z{PmOajVIYGOVNSS364k+Ey0hgBD#jcDtt9O)98d!fu$uO~vMVlsq%Ro0*ow!MN$>iTv$jW|m{73Q*2jt|BhgqVLU-keaX%WP-*tOh%# z2da=g-u3Ogg#X8205R#s1TnJ+#H(mt-j79L4YN)&83Cu;Kc?2s99UHL^I^q>;Hz%> z8x!vD3Am>$!Q9Nbb?n$Nj8w!wck*3rIB$ag-~Ui9GP93swt$9eq1kx`3p{i8(fzEP z=kMP2#4gb7HTz7`Ib6^PTZ%H~Yh@2)ZhA~SQUk;qR_A(lb|U~QB&9QOqsrUt+rTgK zzATQROJsp$37PMXBDcLyX9vD~_U>I1juYd1htU>a3GYMxn6U4sYuNJLE0!IvuJVM%i__a7KapGp6x z<(}d`*V@cm*DWdOrvA6s1N4XP+*=47caCHh3_kJv;Ua$*lfBj za6A~ywsO-i6>YW-jB3YrVI7Y@Hgf<^X;E@v+W>xd)LYB0{y%SZ7QzB=g<4XZn>nvE zv;a4vwm4Eomfn-+dXqaIIX8TtXliJTPh0J$yoeh;o_FuxKc^XFy76+RX;DPaX)WO; zw${Ar=$yJ{K7#%%IBmC-+QVSBE+70{&m)Hxp31B3Sh^Z;#kJYphpDdq|NF8c<@$s$LoOYZ~SEOcHLDrS$p-^?lxH*F?^ z`Q%t!uOBO>@+3gSNAQ1m7=ds^kBKxSL(2R8yubd&y`>d!!M4CH2nsE?NbokG4(m^j z6BIhT1;zLEoP>a3b{@t{yv}<34miYY^iTpr5197t+n2_hIZ>Y-5HsF-{8)p8kJ!l4=csqeUbV0t%ilkgeZTc*vIL_iw^VRYz@1l(Y}f)Li0c1| z6{%vhcV6uEwX%}zbOY%QHw(WopjHi*r!EHT6k>*+>ot&jv3RICpfSnKh`(HP3*EHS zDt+wDrXRYWI|XFl>lGvJ7FW0r`?b{mxUs{L@tvhtK}2lWx33e1$>sIO_6Dy*+50s< z9bfZvv}c`sPM99?*V%hO20DFdlvyP@2?Q|KTZ})xejP~6MyV=>+Rml?{7Hn%3l2zZ zN_FF)ea!)B&K=zc4H`5a^cxAX?vLg5*Mg1-4BdZ$u+A!FPW?~0W@kD~+46?N8%f;w ziQC4aVE*rz)Gf87(qgOJ0Ly4(>$CQvDMFCu9$~WXm;}Gt1OfxdD5aWE%#CvstjBPO zK%^#nK&tNEy&D@5W}o3*n0q5Vy)Mu^hW{7_@x%*Nx zjp2**CL2C=JEB(EaRcU}`+VrEth6uaHb!H9xJ(c}At@=T_msO`?$kaS5`6v@JNjpr zL{3!xaJ`D^Pc-c=Lh(sdw`V&!X>&c$eVt6N%$PQHYT1XbroaJarEVR?D<8Izg}bI= z?%IC0Vu7w65xcJ*hc0_dNQb@Y-7GjT0Y5N;A8p0_Y+PvbV-*HZ#! z*($*Q#87UQR32je@_f&Mh!4DeGjJdfW@X`Uq%i`rQmJPxZm0iv_44K6O!gR9Ij!g% z2XVMicDcn`F84a_wHN2?FiX~80}MC|l9{gno)k@DBR3!wn!RXIFG?z>?C=q0X5O7w zB~?fFA&q?pYs*`hTft$kkUAr)%R zibDeLQ7kywZE}M8Kb;g&sS2E=8#rcC;+%BPinhW<&q2`D5MLs@yfdYI}Gr^3B&rO~d8akf$Z!R`ajb_Q9S-CvQtBEG#Hc1+2h}x`1R+ zcvr2YxU4t*6s6oa0Owa|+%rvWh?Qpbxqh+OA>PEtLs6dYoJ z%aHJ8((urmcw}pt`uWe`cixs%t$?*^RG~51eIJJ1<`2#`=Wn30)N?B9rHENFN{sHO z(q?zRy=%9PqP&|F9^MKSR9x%2NstUZ+y{!Fh+4e*O^$7uV6u%{W&rA;e{+Japxa{O zjg+$9yy$!djF{)$i53T^pmS)(Aem+8^*OrPgx0~`oBr>2-`qE}BOZe>p3Y4@%Uc{k zvO-aqsCAdgHm-0Ss7tls?wIfa5)a#F7QJgyeRu7VO;v^*z7c}l0F5mVPtSW3i*^HI zvzTC$K*hHpogoNnbN62FZDYrdRb)qnBzVDn}!!T?Gwf4s4m7UgL z*vEA!O6hVn;aByIS93;}^q&sJIB7I@*=go|M2TWEG2$1)F) zvtPNB70#p_8zz>6n6<_9`CeYE7kV->Gjce_heO58Of+)0@^OaA`0Oql*~c6nLE(BK zfINlULnN2G1E`Z~zARu)&{e^(EE^B}I45^!%}Y}p{0^Nj@<*X)ztXL?wgQy*ckTQ0 zL4LCfd20&xU(r*Tf9#th$E@ig5W;WlIELei(V|v)kCZ*C?Jo21SP2|Qn*>ylu)-PdYCMlX+QyWRzy*=91b1^ z$3#$XV$-{plbc&GY8p}xpT?a^+9d{z0j5-JG|-`*00S~aWS^(MnU|-c?8nCpqas^R z8NGg9h^0M(8UfE&07N0@{e=eANq4U!XDl}$K9E*@IJr+%ObXJ31Ez~4ESQ-woasg{ zDSrHcTHu_tWz9rINiR{EN9cK<5fz1TR-nf1J>uxz3AkY}BgbpHGRkr_IlUwhpe z`+`|%s|WY&)yv#6YbW;L0NM=9^k`IfZ=ZIPmSXyqE4AA9TX=h6&T*QR9s^s=6>%n= z_TOtb((qHvs>1@1$D?Sjw0nBqF7EDIkP5JUvnpe^AY1a7*16zf*(*MSimj2;BxRfYYMiU_kF`-0GxW}$x&sxL7~u-$5BHIT$#Q!$(8^1C03t5IdRgk_ zC7X@1&pN2#?FGh0=k$zge)k;xfJFJv(IL~2GqL$N-`qI}fa)hsH^=kL1M9dkK0clu z{uuBc>DOc zmJE5fsF7Ol5L##p=4^O#vrX9PuWP^9Kx^=qBQNkToG{%Axw_s=JYb8<3u57~V&V_R z?sT9g9HI<5Y#qJ4ycE+NqtUj1kTh?L&WfGh`D_8BNh47{QRgf_Mn;w(6CVvpm)&Iz zfhCcKW=@)(jdKw4B!^<_>JTCXpe!E{YU>E$>@xTjqof~m@YN^SIX1`Y{EB0aqme6O zbZM5|HM-dHdoV&BL0}6n6&2Y7Jm}_-I*7vlp{(tf0*^Qfm!UICWgQDQ}>mq+ly4^Zgotz*|0 zv9p%kY#8qazD;}d=+O%wo@&TmZKY{0C&8ru4G6GJa>+S?NKDS+zGKZQqcPagJomJr z((2My33EPGzoGXI`DAb)=}I#@e)HtLCjCuKU(kh7%-H_T z#&Ako>PD)XuZ~jU-K6aD9Z-nVe=%*Gi@P;(jG1L;t(wZAr`%j4qQ{hGVDloD{TRTj zR-+85&T*fhMGa79;T)7EB_*GmE=3(^1i~#Ylh0)aBW?&Op|uF4gG&?-$Wk@{(E>P7 zNC4uoB8)-UI046OBGg394FMDTKs^94H!-{XFqMl2oGTWy@C0lPa7D1?~({hd!d|@Ip&#)_Ly=9muRFSHsUS zx75|))BLaDxY!U+m-fHr%js;aiatNV-Y&N~b`bq+S!F2(jKkQ|V{ z`=RcD?j)?4D=b6+FhQZ)gq=fUHLdrVG1*yUJc)pNVpA6am}pPSR1YRE2%>|cA~+kH zW9KJu3pG%ee}spJGxdED{TG}~T*Je>%T++EYL9-Uj$)0%npn?@1wbk+DIsbCs|T+( z909Y2K#js&+XS)=r6+iS_6ixplm&^HB-YG&tjiqsBDl=Mpt>HUtTI`+a3KiM0G{~; z-Ag?bq{{U4C4u4ey`IbOjv3y%_e^!*!B@mzTYBf9Ull|Pa`cysehd zjeGS6wSyQ?C!xVh{Wjg8Wp(COR@`9ykXYA14 zPS9$qsvD)gH$N2H;*Psg=7UOZ(75qP zxHRias9{~EiX?#M=li}~mq=LJ?$Fm-r^cP9OhGl~HXHK5fRMR_APt%}HOF6XD1qX2 ziP&C@)044|Qfo^IoE)`h-I~e6XP$O#gY(u(UtvKIOcdEbG&z5+=tkVQdx&2^k+~KUG5z>x7n%FiIX02^)iN@gnEhm4px$qS zCy5n$+9#_E1-jTdC-=sM9|;Pu?&|>(S$nvn%jPEi%kO>LjZDJA@Q{bf2G;oH-R~qv z6owzE!sKXo>0+jsban8cK`q*~3*x&0O#dIRS52HP4(qo5A4? zn`y{K<8`=>6(#3Opb7c_C+plw6X1*mZO5#9(D=)iFHh$$>;0}H?)id*OUI-L4VYF4 zR$O6$Zy57|qabV2SiJMIcjx0LPo26Z>;s6f%pI%(Ssj34mLSmGzW7Y^lrhp=|8#%Ww4UGihx)~)&C_{~I*rxa_2ZS^meu^SFv?}cU zgj?usI?kTmxeJD^E=lqh5ELRsY*kN1$^^+JZugF?V^G2sXJe-(YK6j-!5KD;_a?+g z!cfU%iv^X{H8R(+cJ105dbb!4k)F9JIC#B^KuNn#pBALGss&}PixzDriF?K-(wlK6 zVQZ7?JCvZv5vyjAb#Hs>S)LQzuRglZ^tUgQCEw@-tBls18ZCDy$98ILOyCo(Sg?l1m5L z&_|H=y}U&?o^?{pv8j#BhXe>- zu*pv7^wDZL^4h$G?36^H?>MuY&QG7+@{QF6Ls3*zWHG>nO>hF}79PE!!`vwc(O!nM z)%ZInS%sV&eDw5bV+vD2LZNHlZ_KxN^W;E(|Jp3H^^67Lctv^{%)gEVmH6|>$Bkno z;n2V1ljLz2kNu2TU0~jn%r|?@q(^jf_kdQMLf$Xy{f?sv2@>3P#xPLUHGY%$+%7o0-D zHwxF0bwuAL;V%O`!YBBIIF4HM>MfZ1NkccPocOECC_WYrkuNVk{-r+Ga9aK z0Rjjvxa94mrmCu{fcm?D1lvp?*7o4=Z0c0N1g`^B6U<=*$|_%+=8MeaCouKeJQG<> zpG@mqOzUOrQ;W830|DKhq=0-ccwZvHS-;GA3klH^L6^)-Tdm1&Cjb2AW=kx%^x37F z;EV@CIx=lVe0U-I*><*=`A96`6r2l%~sX<}}5U)@V&D|3iJMzk|@hEo8(zVFH$ zGzy@Ms)2O6#QIzRu9i+XAJamAUeEaBhYz73e>$KBttGH|0aiUuThoz@-i)}Zv8OgGe!1_j zs#2WiRZ>|4$4fQBXLfs`%(IBDw75KkLSb&v zD2l=nXben8)Ol!EocH}bq<;PS7{oSk+S1to@G|ezdeY-{Bm_JOX(3~F6uUxo-+sS> zovnI!>>f}{4R#C5U`@3F9bPmkewW1-4wyi`AiU1UsebJ-)Q93O)|1T$(-{n!Dv9T< z4OaHZfP{{NbOT+m|H@CDL{fbEgvnS!;5*H+jh)$Q0tx^4+F^XEZ&iX~AM}v8_?F}f ztr(~bRzM;FYb`(hf~YBOsQ}mFjCYHW2C1X~LE(fuozlD?U^Nyc?`v2^f#Ax?iYQb} zdcOi}){pOy(d4&1VF}A2Yp!WC-qH*fB7qY7`0b-5!SAGtXA|+~YirhC0sc2%b-}@nB?hA8yQ@0tvp1qo22DB+(O}tGf>s&3cTgZ;_|ArgEwOv> zELF&Rz)S@o(R7*JgaJomgGxsaT-XM7qQ@1}Er@lTxd4{zp%N##KSMQ_q`A0TOnvI>Mud1IVXcDvG?n>WXUaKLu#$B651 zHt`*{!vqg=jS8aJMZhnZAaxeeIfO}|xR!$gf&$P9{0Am>Hx4dX4B$J6t%sbop{~T6E7pe7Lo9 z(`S(jCE(j|{J1_G2bn{bR`w|Ns|{fCY|y0^gOm^91%wz0kDvPi8JV_7&^N83FttW1Y~+$<4bDS zswMs&i0aIr&5-ngaBUP{LWu7E?B2s~t~nemrN=ZsK3D$GA!rlV}G?phwf=D&+!+J|WF$8yOkhFHT0p1v9bNxPTJOqf4=;bCIIJ zvzQ5!sa31+#ELU#&eR(_*39RUZfuvq z6XtuZwxu{Yp4rbc7!!K*XfJBFo3Q$Up>vk)-m^!E`q-G2c>44PsEnLfBSnrqNsK#g zIq%G(8F{v2g7)m0Rp4)A*qgDKG-cd@$Jl~RG-?q8Fz3@+T6fxvJo>KZj8?m&qS{kE zl)*+tM{rwcHD1B{w}2-O)lx607TZ=9GH4Hx)}cOq?vEuksbU?#c)_<)tI>(ccbmQO z!Go!y#sp)Z_YpOw@bC!*f2f}{@yY&>u9(vL)geXh%#b?63g&c)GO>tyKfij-?;j4bM zYlg0&jK9z=mA>z^Z~bOI2o3nGgj!8-^U`66|M}Ur?to&?To?QXkzb6odt!EJtXlOW zL1(662TEO>wA~PJL`%)|m*BCY$$}6!*ef*)C*V`enBXjU9WN*2sIb3+E&p?4GqdX%RZ@CU{%ouH|L=QCJ*0o6^}kO;ThDj*KEcTvH{)HSylcH{&X{AQJ3Pj~pu@Np z0}WLBB$~`=-J`qL@PxG+u5B|@F+bRCc~Nuka`%>|N#CaTeVq38^T=bxg&w&xx47ln zSKOTLdP@!-dqy`U&aZ=tN?`bo9Xh79!j{lXOFr9K>rS9R$W;RL+RLQ2jVr7ompq4< zkn27{;r{b;Q0;hbPVG0ha?_~u#QCZ6%3<>f@!M#w<4vr-jWMs^DZA;}Ebafs%CPsy zegpqqp(?MG$SZE~zaRMT?_oV<1%oL3Z!E)>(#HIEn4D*`kZcz>$$v*Uyr%M9;e^w< z+-oTz>_+nUzk9M~W(T_wDPN<-XaDmXCf1hIMb-rdSm|tB=+&DOX-S|b{urdC)lwJn zGmHk^i-vQZoB#}*dD=qdH7Irs1ONp4-t*Jq;DhsyxRi3OIXH>GhpttJ4!8c?tn=$0*NyEXF>Hsyt5M_;z=5hD=2p5n_ z=+O3U9H?CIk*0u4E}rle~m+3*(CqphgR^49isO9 zySur4Nbr9*>(W>@3A6Lx-JNIt&&{SOH_3lbE1M)cGxPhSH6H&NATue8eEiaI(u4^+ zdj8K8ANc7pci#WlF}>Z9TCbPR{Xce3F80^Fq`v>Nt6KlRe;jhKfA5*~hwRz^)*ZiZ ztM=c!TdB$>_@8x$xjOeh?|!n6n&w(PlmB^ryQioiQyi7rr9U2Tt~QxgL*+y*T=wMI64frC!I$r zagBqX*tLF@zkJK)&2?)Ch6k+WFsj*tGVXt1bqanoe`SSF&Z7qp_7h7-7psJY5_vu& zeU(t17%sm2`oo9qX!59Pjyp?035|O*xNQ`Re!VX`j4V)e@$1R3{a`W$Kw7z$kulOX zV#>D)AYso@Oyq#6=O8ZD#8(lueqvV3I!cI6pGUaX+q%EWxrJH?swBT2&78}@o{p5# zNOvzi-o52Pldc{gI-%JE&Hu!yQ@+ceETU?jaTl=jsVNmWnerNI!3L%(_MHWAv3q

^T4<4*3k@KXu*48_FZ>H)U6l!T^8n$iQXV+tfTvA6z2d}@{cuLl)_E)`u zyd!2K--QH3KnL7Jl}d^^18L6vZYsthV@Vp~Rumv`H$9E>qkPO`3#)_xC5nw?C)Bkh z8->R=?bj@1mIa)fgpHa0d=(%t#8B-l{*;!CJA2n zJC*p5vk-!bNDmN}Q`hTTDvK<{XXLSkzkc4&@Fp>-EUr&q^Zi=eFNxL~BM ztJmb#b%`g_>8dW_U_N;g#=5X57vk>=j(4#D$#I+5Rw8KnFYC~(2n%;};Eay=A64|c z+55Gi!QA@#aSEM$l8TX%guHyYUVwRE)EL_4kkjDT8}b{`7nMRCsU^JcoX;O;Or5%) zq!?7n5E8HFe2+9|ld(e$Vo@lZ*Vr9GJ$#((-3rZ628cNB@~vCVC`fJqkVRu?J-K>& z*4YAEh5$OJsCDfgyh>HZYt^^j5Q9!KUd6tMDn2VK%Ul&TLIQ1hUD!TS0%L9XoIWvY zGL@JbBsrqWcKpMl@6n@waem2j+^jqYewhIabIM^@Qsd<4<;%3};CZ^_s9nNp$X%ry zB11_k?vT<`vHj+cs{3xKq*90$m?;jo7jC#Fp~!*WA7OR8am2IF}RvweztXgEg@Uw4%3`h22NiZ644Fj*M~}9lQ1yzor?t?4CxlOjf^l+QZil%Dsl?s$Wws(% zntDQx@pBc4L-Nbt+snz1T4@GnI#MR6Jgy*X*dPLk)y74_+Lfq|G3;u!zb>RPjM;LY&LE#ktIF7JT$zRZbBYDky}U$JA|qg`i4vbR zDHjil!HHlC^q{P6_xRCc;xLD{PuWj@e-?8$mzdug{L94HDmGOQT60R(MUh6dEkX}u z&0NxPzjZ_2q8za1$8O4(vF&%@B|YC#&s%eG3mD3&>xTTvQ~thib^yH4>;;WW$PBV6 zH~;>s5xsv6mA->xI9KI7V%ehRh$fB*cOIe~5)>P|;J~qqEPhe*n@eI~EnK<@;r3Qm z0WY3EKY#~Fysb>G(m|$k{EI58l)=d73kpTfR^Uj7Q3u@QP{C~Kpb=Qk;33B5wyh5l zp8d{l+q!iZk(=be^1mF6*L-+9;P3ZlyXY}PcA(y7lr2qV@n@5;M@Ve?XV(-|&*l=& zTa$h^_IrZSY6JXeE~N)+k@J}lUPM-`RI7v*(XI0Ial1o7hESPtN0g+lx_>o+ih#~aG#zz49KwcQpL5@gqYmDhYD8BwBFnt9b;cbNoDZhC2>QM5^@)4D>*6CA?JF2ScR1PY9H^zK3uwD#j z(9WGZ?M}N+nzbe)yWpJ$0@5JUV?R3%i62UR9`1f{OLRHjX|yrFrR4G+;zc*Cwu*eR zk6^P04P;xk(fU0xTrJZ0_g^2)jDKT3yZ7$hn-pXH?FwgQAd>YA{89BIziVur#+3t* zf~;S+Za1(s2APAqe!AcIAfrA3`quire=dYSogj*+6m9xtF4qU>_u|8c-Z_o$8V-P` ze&?`F`-0>!&gkie_?0!T*K=0ILa)Qo7H3>a3G36JK6QZNu>JCNZ%a!{{)#@PqFaH(GMAF}k)wHy6C6+_ZA7Usfx1|Hq8Q~= zO)0)DU!Jq{#`owKVBtoh{>hJ)W~Ippuh$YEwV+F)ot2aJ_+JOIEz2J+Xyg%3s5>2gXgrsfbo)QwE^EeqeZL zI28!0^XRWOU_oC+E}|)Fhz*-Jx1_asn~?q8qq83!W~_?uO5uWY|Aw(A<+xyZ9j+(x zZOK(7QM*i+$6fxAG{Mg9S!wAe8ZdL{DGXTLRpR7>D}E>6&%avjyJSfpfR1ue4G!pb z%xKuOX*y_t+rSO!zua=n$RifCe-^BtIt&MAKnY3X?CIVi>=J~JigA+HsoRJ05%rh- zr!QmkitrkHDIKrT)JvnF`YDI_z5o?wj;-tCUNmUXnqQAQWxyQI!@uPL-K+!}0h7DW zgQX%aMT00B2L~PA;K9;UT4X`IzNNGX=jUlrt7-eLKP9^(6aH*glDluq%a_g!;%M=% z5{ppJmb~zCD6N?H4^PWcyq7~`+MwKqcvyWoDTaq`f2f4p1)!wvudylM6p2Ka!6`-q zg+T~b(HviS0$^5*<~MK)83Hc_bSPSEB9bU={$sNSqUqZVCryJAQ4AKBFmMo5%67Er zV>lWi2We{!7MAfLLqpiWd+SWVjPQZ{Fy8DtyGZAE#K&X*%duzIV8RgWXUoqEx$ z7qJcxT`xvsLO+pV178BLehkjc?j8}LWlfwEjk~&hps7_pV72o!*x`%bjIvjOPn5~IYkFm5esUR zQq7Mso|C$=ycT`WVFa5?5nPbR!F;I?T|J0cXY<2|>x%kcKR(?;8`crfk%gNz?36Zx zX+ru_?H0W|O0DhRpTk#J65K**+K3R3+y+xTob_G~Xuj=HmA*tj4!-4e>EQ~4S#E>H zOGl>C0&~F&cfTw{ggATY3q9U;yNu4>pw? z>-B-c0kHpx0k6=-Yw~sDgvH1^{Q+)U_57s0bPDq*iTfX$x;304B#7I#ATssr;;)Dj zmwbs3XZ@4Pj;I4*-P3Kg1jSS=>+m7 zv)p`q>$*^3%z=b2rCi^+Ed~GQFJ9E5HcRFuKme+0L9{)uUcLJK*|W8zUhNj%KfjxT z?4^thL2zY|=WSl?Yb%0-Xp~67RJ10yHW4{eq@ZdN{=^u&WC~VIKF)U8LShrsG3`i! zFY#!}tv$g4EPXJdr#Z~uAO#o-F%H1HTAq9+$s7)N9V_@5%MILzo~YQMdk3_s>6xo@ z7t?5NNp#Q}kV+^X?_)qI<>a$d5x$LxxGTFJguxV;YQ9AR=|!&JiZ z+e9kp)K1Ct>!p&lNp|>AlYyHyZWL`wT4?*r^h70{ebNWu`;+5F&|E_S#=ij8xbwx) zJA|16HdM+$=+#rx&_DM6bUho?;^K^nXNrPtVw@m3x%7aSBWk6FP%g#~Y24~TMYhL-$v$gqUYSpM< z?bAx^KeD7&AH#QZm=&d^1AGcZZ05v?t=azUyloO!AFbboZjo8U6%ityn-Ih#^xP+KPmVsM*?8;w9R8e4v^wrDJ*O&n9!Ky=)Wai zp@K-cJ-S%kOup2#{YI4ChI>Y&(R^vf-6iQ1s9D#@;vC-0%(5y91PMD&;7NjG!_)!o z)O}LB{AS=_+xRi<{n3UdQ|_8yyPO5?%Xw+6(-yyKeiUXU?6 zYvQQ|DTUtHK9RJW_{b)W1&YAZC8dyx!7hA1X|7WABqD9V>gxW|Y@zQtiwT>SIHxW; zOMw7UfI(-2aFtVUXWg%+=|zs%<%#NB@&pbX-g6?~GSJG8O)dvh9rC%3nE!f!a-+V@ zw6|gMMzZxcyk8c%hts;`ZX(LVh`9B(xpw`d~UWXI2`cZI-hMt^OcClU{HZ+-n>w{Fki)DJ#i1ej-e(<8f#H*y4 zw>bn6$0BG(yqhDZ=v;YvX3`9_cmz93B~qZ2G8h8~CUpz2s!wbx!?_?MN(9otU0rZX z_A7r;myUNJZ;@wvvXY)9p9Wkq6LH1&U`>%)j)*gOv79w10S_E0B(z5G{)ZlP?a~9S zUtLa|vg_itP}zZ6mQ;bYQA!ZztA~xUZ4JVvl@jQ#7bmKwl2l8EVSCb(UalnmN}UUH zxja(6;h!VNx#Zk9VYvUH2r{CIzMwP)Z{CdLW0uJkl{Rf0{JylI-QWF{$1TEJQK1F` z4gmRspddwev!cX?tE0x&gzv5Oj=QBk{|^*+AQny2t^B^3E^4 z?&3n8L0%atQUrPf9)hnwH8iNv2;0^Y(I_=ex>}0vB&SUBhXB15k?_!SHbokq&B`C$ zY7SPHFlz${w>;Zrm56@}3PNeT%#8&hLArqY5wd`xz?5L7|}* zM6vv$Hk**2Cf4yew{Ys~qN1WAa)up`GH|LVD1Qiq)UgCj#(okRND*#@AIfnVjEyk} zmg|^Kfd2U;?SjAM#apSVZ6btvP8A1?eHpp3^a|OebStB4eJ(G*EC(~?p+>_9joaJv z$=F33RH(<%fh@+4%6Hg_Wraz)D2SL_+U9^3l$e!oL{vRq$Z#{0B@`M(y0miO0dRKF zKyB}Qb5aSLEmh_}PrHnF)`@m?cxk_W31`~3Q8(0kb3f3W~R z-p*Op4M+@_pnEsl$H>{Kp>gjwA_xZ<1;6QWO%5$~<@ zm7}Fv<&25=Npc}fgG&7&uhO9$if$vn_5+Bchz)-XUa>ZSA&7|Ma2;zg6NsJTAMll0 z{&{nfDoEdq%3o-GS`)m20MFjD^(2y6tKTo?9R3^nX#KRsJ_*xr-|W3NDB6*RgGAb8 zdK(VwH4<64$XSE%O9l{XL=3iuz#JJTxWl@{l;lzQ>5Gq33sdJDLzi2xSn_D-!|smj zj~pAlVCFIWrXp==@kTDW18*#bOzzI>izGRHkgysOl|jImX4~4A1GMj?y{;UGf4Q@` z$IdANjpCvZTXq5ikux_uLq=p}X08RSp;W)JPDS|`INun4FBTW5Dm<4|5M(2##Ang0 zW&PfNoSg`+Skcj)(DeBEUYiAu0(2}W)K-u3he+DG*w3u<>!V2~DWHz2^#KyG=2&Qk z)j@!x@v}zClKwTyE$5a9y#w@@FRUfkXwufO4pqG9nAX|UiH*LsEWw4ISXgKX?T#Z- zRq+*f?~UF?$Sy-v;1lBDbyVlD+7B>z@Gt=jH{}fCLtNHTU2cH399`?I2E3|^FQP24W>{j zz@+`uKGipnD$)Yi5$~2u*beSbs&EKsa&mzqVFQe$T>t@;VyYIYqIlvU}38K!)7eS%swYm7G4&a zqS09&N0EVZU2OdRKIv3b@=#XFQ|Vz4jsN-<;gx+Nbz9E;t4&+rMbF%WL5+)$EdMmJdAao(+aNI4!y z0ucaU0fHndH7`;4%7d*mWep6MuBI(jMwyMyxdm|gPNlUMz9YP3a`T%nNwY|)MZyqJ z5XHAUFN1fAI4hQvckV+`0+cRWNYoXfb}98s4-KO(Z@}pSHU)!hZIgsi)BEI(1?1oT zSA9)jcz2N8ZW;2IRe0k|HHOjO5X%e2jA#Jgts#RY&@k|{tV?$yH~j2fdHm5E7*A|+Fzf8dKW(B8XVp7-AAZ;08h|f3 z=WEId>@e^3N|Jgdv8Pk?%(B_Z-T)d!%%9!2bH_T#0~g8rB~+CR2FIWC-s~*e zYBv26QhTE@=3#S{*V_fvyu_`Dik3%?Ge%zV=*7={(%d0%4_1mzjvZT2+35Qv2b-&) z6%{{@O;2*!tkZt7gPe9kITgm8KALHlDBj9A5R#s|c+u@x35zKZW5f#36HW)r5&=7C zEAifqb8dmw%6sfQbuzBa_^jdygJ5J*bN6<}#?|#u5vlU5F&jDB%?ZN;W9tT>Nn`mt zdC+qzHA16nR%uuyjME9lIjC0}nY7)po$~qoxYu79k^Y`#Fw8f6MR79c{Y=RiXWdg` zTKY8reh~Iy{-Hu|8aC%Rh=`BTo<9&<+60fFSarGoKx%^uMQs2tjqyjYqfddEbEL?~ zgG{4zzu0v(h)Bh(%V9_Kra}%yR2f?30Sb~0sq#TD%|(H(9J(>$%aE~Fr7Su$na4U|;I95 zb6t&`r6RTp>-qxBED`})M8YN>Ys`GYk)1rhL9tBE<8nRx+6#vwN8;dOVWZ^Vv_6Xm zyAW5!u0iW&9?F;ghgDzR$l~8iE6Z$g+L;Xm;1HftSViev@zkqoIRXXOh==4wSp~1> zOe!5iFiSHRb&~U(S$Tyh4-E$!91Oq#NWpvIvyJK7D$P3)5(y|PS_zvFoU@+X&@>hOqDBZ!l1|7i z95=I8L=A)GP>d00IW!dMwf`eb>0T6qqvq?cVD0+F*l-H+v384VK4XafFRIQ1uIKi9 zki|6_M zyM1V1$251m+Wo)EO^+8i$=l%mxos0g2?HHe zM?Q*I>*Nx`&PM;Ohba_3Mum6W z`h3TPPiuPV=R6{Xg4^B+8yAc!H!3_QEt6Na;IU>j!l=X+Y5|-_&Tc$euu0{w@hU1& zjn{p+@D{egZxiA}^T@q3DhFmd+|l(7yPZ40cybdwh!3~Q)1Ui1-l{a|&9^?SqpCuLvkep=E2 zd-jOEm2@=@&uwd~a3_l0FZvlOT^1V)I_SJp55XYi?xj-*mk=SL!;OWA+i+ zzTbkbUGjwE`jaeJQ8#ZN+93X0&2iEax}wE6)wzt z-=npE`y0My0ZbTOh;x}pH54FHfaawwUUBiPIw!dt4lU@PMO?DVdtpk zDy2+H)`uh7$_j*Z{ov=9R!&^9uJiMR)wWo01-iPv;v7qZk(H2#h=rlru$J?$zjRf9 zL)?Ro#9rLLe17}&M_3FDp>1)PyS*azu$FaNz>gyPn<1mdutj~vhjOsl&$HSE%XP%2 zue3boM@ZVun_JI>D$e2!Ig3Z`+=Z&c+*x@|&Ibl)fk)=Aiot7=9WGeh|Q?ekG;RWc-q7nN@4!A z9WcO#0<&sz3$t$?&u5f-;H~-R70BLt@g$G@SZO*aTdVpd+m+Uzb7$O%loer(Zj5hS z(dE_(e~824i}>eVcb(5uGVUATrf>E6aKCC>#^?&apRzG#=RUEFxSP^J zLqY*Sqf^)?+XoNKyPqd6T>dd>YjJ!d)fv*gXZXsC!#`Z;J?q=t5Ie3j) zwH+AK_?+}aHYegBjSyD;mC{F6$zh|VlQdqUfVA0 zNM?+c3Jpy$-0F~2;&uY{gFM{qT%8%0Lc1v&bSpSLOU$twdoP^!a3zkGE3K2=@U>iH z(9KsiR(^LiXnGy@j8+@$BvdJki@oLXrXbttO^X~%@?p{&|B@X*QTqVE%y=mM$ zE^9wqjNzoLJ)a@4xTX9~onf#3pAV*doyl@c)Moto49ZH7CH|k^B;7DIg=OXUKensl z!TH(jk1_86BNQSmv|L(&#aXxid`rb)*cTcLmFw8V4kwYV(r+hl1}ct40`P7)J{)qd zAoKqEaLRmgOU=J0>8Ny$WxpkN$EE0&f{p0lIP&wG{YP>q!o`jDusw7X)q%D})TX?q z)`6$BGUc|@5wlzKAUAzAC0-vlzQmn7ckWkM!12qyJL-pq?wga1b^`+foc~7YKU&vZ zVX7OJyRh}vWfx!J-kSD^8Op=6)arb5e_rLhLIb94Q~~lWs2G~y4bF-*#PVEZSHSONUYaXy8m~{yooR3H2@gO# z9Fb2*e?*aI?Ap2WN=nK|w$N6$i=t%$P{(N1r=g~)N%_6eIiD-929`X7b{$dvW6~rT zlIBK(B(=F!|4D9cap^C$t-m~EZgsp(T`Vcb=6g?I(ey*7z&&!k5F=HV>!mMM0v_J{ zlh4Ghu2dh6uq|TRY;QaS_a#)qx;k3mJ;6unjJ{7gS{$}I9QCwzy(cZsQ>E*00g2UB zzD7$_vgse2m~t!fbgIb}^(WiRIx|;Sk!@x(un5l!_VZC>p$A|Xvdh=~&g2#?S`@hS z!a_yf61d7kOf^^r5`Lhd#(Z)~Ib(km#igl&12luAB>(j6*;FhfFFm|CEzQs5u=sR5 zFW=t$$dMxvwjYLmU%AFJaMhk218fJYb(_{wVZGRG zmAccSN>{Z6f#~~wejbIv7S>l1de!#BeK6%C{YJwRVcW%>QEdu)jHWU7hY)%VfPmgE zZ-0en=~}XN+V_HaTvQ5EodfaoV)k2~N?tAA(Nhuo2^*Be#o~-<@I2jTsY#4`K{%3<~!a9;XI7$y?#JRL9C1hbmTsuXl(G zCw`^Ld`QW3A$WUTXp1aoE*^8`^SAZRUOqJcXqoa=?CF-h=|~7ro9lchZB}vRuiRPc z+gg9lJ-uKoTwR5P2V+T_3%wNQd-$;WAi{@JSa)DFH~ zyQHQ82HTceZ7q_J6=L0TkZzi4p)(8$^M^l;#I21<>dcvmiNqEi!na2XluqV$1ew@UI!yulC|((f9}3 zjL9&))5z?Vw0@$zQ!JGfK1j7%*_6> zcizWHu8@$|_^5#rj1w;^g(J)XWvs65OkFm}_D3&_ z@~Yq;jRWDn+3(Men8*Fq9d;u(h9Cx!-X$e$sphlf_lw0*o!$7zD?l;bHH_iJx07A4Px=&x~GAgNY$CRB#jLHs?7wiEQvI~K1DAr~zXpPOx@6Zvenhy2-5{tO;C)>HZ;M--Hp`}id%Tw&yxPFw!suTK|fbeL{z zsQHC)A6z=!Kud=M67bWF@r1&g(-xp#&%hvw!yj33DL`FR*-;^>Rb<$eaE`!nnIdof zyJymXdntOXZc!1CaAwlf`ygy)*yON0G9~`dl(S6X@)iTgJ4fryL-G&cPhV}ejB?^N z*CvFloAV$BtKS2(b4OToo6taQm~2ykW`S^d(fx|7?(JTbvcEsM1Y^O~RjRxUMR&@_udqS=asu}aItMa$G-P+n<q0b6A+O{Ff)D zJ#_`WZaE(?H0pfq6OUFgm*|ir*O)evp`b6HiuvPVY#gPACjm94>1-j2xv-7t)xAif%U%LtG=iRw{`Z zQig|brxvjA>EZ#dr+fqvkhGWi3Nv|BG&{u?_4d@8!wdYrshVX=zbh(|efblLj#+}( zk1Hd|8FBLIKC9SiHYbhx2~&&MEj=#!QzE3`vgzHZQKO$ff9}UGfj+1SA<5u;^FB%t zznd%Oe9|gB#8vm(k+a6@moN7}`M3mM>itjD7r_YXSzc8vTST40Lr3mjM^NmUwqx;o ze)PfFKw9XI7Rst-QSf;%>A6(eNCVGWMXRf*2rUlu$fH_k-AS6)$HvCa)VMdNWZ!Z+ zz%sy*1+sT`7TWEYhwhY#Utw;U42XIg0nQ?$wb_JQWkNYiEEs=lnEnnZx6L23n3r2jeJ2e`f4wYtfQZo%Ll z7i>|sSS(g~X>wVVHR+NEfc>lieHB%{2~YkS3BVG)7d>$d5ZMw5J99I~3F`sAxP%c3H+iRiOF#I)#!X+i z$O+rj7>_hd>NY(DeZz59eC8@V{7y5Th-O&dE$^q6{jF|EsUE%2}&q$jk*9Y zO9PhmDleuO$6)F&c{O{-e7!ZkteB+rs*wdT!K8 zQ_5sd8cXvgG(lTpNW-8G`U;fEwoCT7U{P;J8T_pEmTAnkmuZl^w)j%Oqt#DNn9r3D ziz<;0^+X5Cjq*@-)^uEV6tcC|_3PF(xTO*T?*faUGcjovPe^Frw7^kZj~Bs_ik^bOh7@Ug?3?}nZzWGTqFi;iFjF5nvxO%ob);w@1Tg zDGAS-2qXMq_zENP)hA;{yqdO$xPL$~6Akc_yM8SkMqCl^X?o%3U}~J{yXaSTMt<2v zu6gl1#5W46=ut>v(^164yDikS;QvMQ!p0lEWLxLpt~SlZ!4?i0>xcGMe078{+|TsP z@J&rcT#*--1K($sQ|M_@@g|%w#grLYqMK9SDgtY`;~dheN#_@lp2y629&Z4nUPXWy zvp9@wG}*tkSUYhx-KBEm=0*ELcESlbOzfdt>Qr_}Vpi52n z_SSh98o)P2<70gG>83}Bok;P5=LKE{`F@O3C~RIR*{abI@SKIIx#aNpnf|?HJ`jXx z?h+a8hy{)%H0n;I8noJ7AK6vJlf>J@+Q=BNsTIYV4Cv%)^YyoHKyVYY8-+EySw$3y ziqP)6?*0{R2aCyz9r2n%u|~xvNUAl!CvU)3u^K(x*tne`o)>sEfP_X~Fme$9=Vp+VE8+RqJJtEz1xD6^s=;#~;NADGzNZ}E$Ewu#}WL2wHUBYdI8$8&J~C*GA=JXyd4Z!Xz3v#u~W`=C60S3 z{HSq#6|*7WRjoVA9k9=bT(MtfwS+gdq-r!4v1J9p}`e~0kz z!$z#d;2pvZ4e5G~btjBo($b6egE9^82wJUK%^<=qooHTQv(eyh& z_WfjDG&UcF{UM4;{)LJDsX z;0x4z90OMuT*mP=Xwd2(l8v7r3}0lv!KZ#q#e@!@{>lAWX#2 zO1#!&nFBZ7$*EK44SLwT)uE6{62ogkzw&LAL7s5K_mlcw)o&^s0oCTsbxx&& zcg09d-OdU0}PVC-1pxPNX_SIV@t=&9`d!ep1fzB1G8Azx)lz+CeiqbF%D= z&X#s{0K}RI3z`$5vhw>>q^#6v)#XHhgs*dKkjUp@)U3>;I9`X)H*EH7ZNw^CoxM(Q zYt$7wJtbskmszA^a}fDXCwCdL(!*pP4}Tc>lgxBhjJj^Ae?08#I?Im2!esumXe8ej z;bOQ^vgEwV0_<=9h^%5}X56ML1hUs!2=}O3oF5)>I`t+bBt#r!@E`3Kr+b8)t0jgI z>n~imVAf(NqEs((sGp^!`5u}dj}UTHjVjA}5CAsE$1x45FZYXBva#k(Kc3TP5TuMd z-ZZUbhE6(RdEsTki47!6YL0|FTMQoh1Apw z@ve}eAih?OHy!<*p>$WgU#pYeig9BLi}S5c&3SZpc5-kEm9S*r;g8Dtf4=Ai)8^XK zr&E6DI$30wgH7xNtQ$P=7x`RM1LwvF0qPW%z_U9BB~OYcW;_puUU;uFv^F`m!XVAh z&3Ez`xcI>XYP+Vz@6t(EadkU+^Vo6ep+YXKDp?Mqhs`TKa3M+5w=KA^V zs!)-yg;*`|oc;$(8rPi;17Mn{_`i|PK9e+dpWe1!e>u=S#MXK!% z>mu&5>x7}YlZ$8*i*NkAXXn;Zvo-(F$GF46)UP=t{<2B(stGC4fyO01hVl2*md+S| ze$fmzDE`-tVUO2rmD>Uofuus*PY|n#koYxn(P7J%_ko2t9{Idit7K|%epozrd4iFT zJtsv_n_j`Tm1uC4q^oY}(|g(3tqCaCVMrm7v^)N0YcF`XcQ`FmL0 zQ~Ra*>NWc=c)2Fg_BH;5*XaQ*8_RLSfvPY7sbeGkH7|V4xnU^u%--5|aK;gefQ_5>{4cHU zv_eNm!C3bD_qLo$X5r`%{v>&&>S(W%6@Fs9h*&|=+?YbkT`0;b5qC2Y&~2Yd|J4v?yg*AkDsx%vhcP%xxE5LZ;uHHqVCR((%Nw=k4RrjdLc!4KZ@ z(S_$%q$6*3d3vTU#6(F8E$56saq5&Y2$(oW1HL%0%dVb@4KRZafG?##;GhhgASIi4 zG>1E~E(3GqjmRT)1Lv#eMu1z2MXZRlVJ^m;u?c4~2^hf|@$JPBniAh5AoPlsPQKti z1qFWWkwc}e<`5d_G}W)Ds6aUSu+ri4IRPTnT zul?rX8^cESykorilJ1$N#dAi5*>_ewq2gQDzPrQr*Arhc>rQA$P>`xO*lCQg}doMUGvLx5TibZ^>w ze)c-x=7;LXjU0(D-C}?n)QD-Ox=Y1ESOnVe`IrObPl}|iInOCz<$%=)-+(lSIXyLK z-n6M5O%+KaIgE*V<4XGR+=TrYbJX9z9%j-~>Mi-4-AZic=QKx|o=EoV7to1bKn4W7 zbj0!kI%F6^+4iGS#J1CR#;R=(OQjzyQgJAE(Tfy_wp#86!RBTb6%ZRd<5c5w}?V+*|EaDZRc z0Jk}??OeZp{RFK6a@+zABrdp=*`^zr$oD>@hTL2v0D6nwnnWJ$9p_-dcR_}S8tP>Z_ZjbLW^NzYt zJd%u7l(Rp6(h&KI!9ndClg=$@TeDuvNwaqohZt1Xfb`Pyt8e(JCKik34j;O{jtp@6 z@;XTcwMQM|3V}p*mBam;H*3k-BB?@BS+!%wgXTY0n%DRS%Gs$SpkDe;cX1I5y4%U_RKI<4 zhKUQER0NcnuaOuO>WLqUb810;eonKZ&a)esc`Pz=Q13ixYi^HTUAm0$I-l$aa*hjZrIIBLb^pmw<4QT-$G8TEYAGi7XJAin#fb!kqS))j$udbPMZg z2L>gSZZ_4sN4-$mh1u^62?--dD|~U`(xpqjC%d-dDq2q#Z728JjWV1_s&DBA%npk3 zl0lIP$1=a~9c=PgnpdZ}S6Yz9?#&&X3Vtl|PB5Xa-+sc#B+F??+TGbXhNPq_>{6MW zcBuJ^2VCHL5M}hj%eC z?M@@eW=ACR1m3xN5w}FHlBx5qw6s*#lYss1QBlpBH8r=MpUJKZ+ET)eShCZIlRYiL zql?aPPNq<6-0m>K*EGp`N4@h7*=cF@*#{b%OJF-EzF|pKC}V(cuJFEuQEgj95{_ey zL#{LZ)FNthBCt#YjA&x5NFQ2b8ltKuDTBt|T%@hBuaAt{Kwd7^VYFnrpnd6Nl=kAn zvT>QG$DCjE&JLG4qcmtL)&IT3@^g85ctdsDavvK2f47G3s8l4CA(_x})#;!*cIm?g*g!OJu?Qbt zS1uQ#_uZm>-12h1LDTF|!g=4Qje$9DgW<;o_ZPN&(90F2bIfcn`yD=|Zee6e8c||g zTge1IC1Me|Z7@ZNh=_PeWa(~T;C9v!dlmjRpu7eL&xu{V@u8fOU;Fc0Lkfh@@ zw!~hW$z@jYHWz!{yMI4u^{Wb^K|oMYdyOI%Kcz>Gm_x@-8{p4?;ho2hbrqbP1!0eoXNlflLZy8Y-+T)`SU7>eQ_}on%@P zt{XSBUVNV4VH)&{w>8X^3gZ-|HpUBo{_hhA%!i`C7oMkE+js@0U%VPFmu$TpIJNXXI5 z2<`PC7n3l?m2T!qR9l{~W;7|H^hY3TP)Ai+p5VEFctoXI7hX^}HiYI;N59soQ|BJ` z4Br^`bLE#_MP7P)Ot{6V2Jt2UvL~$OixZs429T&Mo|R-D7Z=BfO*PtX)ddb6YHVCx z`cPZ9u7QfI^m$n?eI7UIET7{}jqbiFSFR}0FCBv4my!}w4hr=`SK4;Kr36kO2apSX z(ZL?=%zZX*uFH3S`=}T1N6%-g?Nc`2AW;5k-3~(m;wrLpcGg*#6LJ*GMH0p}Sm3Ow zyl@5=%|R%3P)s4yQ0md+q+zbxC;f?rINT{rvFb+UtQ7jEQ&Y|h~+>kAJ)V# zGGo-y%WC**toAX24PhLU6H8vdph3BP912|UV1n8~e$&Emsf6nd#r>1tcb|_;r z1t}L5Y1TZjwvw?+D;SX}{kg=c+lKA-jSdB983Jm(?aZAz8yWrVCA>2}zxly1*G-!? z;k;6dMWr{c6CWQBkD)8hhq|sh#+J9<=jH8ZU1StWX7p?|a^iitn)WU(sdHM?{*o2- zlg>-$(A1)6h^B9!j6{Pm^0MhYV$PTR2fC=@p#)XunZgxN#+D)-&khizm=86}={t#k zzFS?*A64oP9Xhn6!w+`!HD>&B4lMgrkS+({wLN>cTvdXHRBS?@CDv8`{G57e12=PX zF|+TKU^d9uvLWDS19kNspI4H^xSa_YGje1i*^0zN;ueunoxYCTR^VJh*AkBwVx2cj zajSJt-_4tYSJ&iZ<%Xvz0kYwLt%Ng{RLktX#{QMGVg&e_R(D*mlqwwih6Uqx7Zj1S z0Ni?II?ZkcWK_~&s*W`&6Lup<`A>U=d&{wMTk74mop~=66~C(L7ur8~b$as>9=^Cy zMc=>IkBC+H&&ZJrIRxIHl)>d^kn*qkmPHg)7#N;kRDyIDn(hiR-PeWj4Be|g{ z?Vol1go#>pB>{WMi8sJF$V+)-@4EwhUPh|e^=I0TOpc+w)9I8A=XTwM zIil+F_;?j1X>C)_)Sydf>eglec5H%30=0n~-5K(wiwn(JQB_Ip1iUHe1Id&Lt+td|@=+LM zHyw7W%*09WtM>MO<=D=D{kum>&cORn7Ocku=(|<|21bTnbVZCNJ2;r4elpXZ;pt3uV*{2tlDyOjffk zrZPdro@$dOM>A$XcO-#WX{(pr%bH0L1*BYZUyoe9ssCmdrrevuBT( zfRMNAS{k96@?xR@sU%#CFZS;iaWs3*d)5^E&a=^8qiw;#-h`ZRSP()Ulyo+~4^tN* zb&XhX*@2+%x9aoS%-zGsEZs*%Xw%NHMZ*Ds3oAu5VJD~Xa1Ut-VRLsc|K8GeQ1 z#@(`?I2A^=;?Fy>JPjV1*^v!5-@JK~&WpAiceRwL_9`!LQeJ+zy?_v0ylU=;egW|S z@OM77IUqolaMU6GGbH#cj~}1Ar?58z05MMP(9ZlcOz9>Uy6%J0!cKZ44Bm{dK&x2x zdbtnVAI8v`HOz{Dl+!lm(t!hQS<(%h{G4BxVO!ECl9`dgv8_)ncy)D%(PEh;0o2>H zdGjV0x*hv#rybFnOd*{v1$6VsNeky?U)%^4MlJeG;NHEPNu>_Z2ZCBN{X`)HKD*&Q zr~J&R(1K5R1M?JRccUN6m^D{m|87|RnDCkYGaVhZyF4TxSJ``fcPMto(gIGXtgx7@ zK-#D~R5~|EzJ9${_IL`Tmb_f!{IeV{&$-Gg3++kE2_PRw#1eU5!U#wt53zG}v^)~F ztZS3MvjZ|W#5woS-&IF^qDUd4?kyx12$n_0x-F1gw5FKjkn%Tq3{|MV=R$0o_Vds2 zwp5RdA2B`toCmL1UxOJBbMayFN1LIugih*8*@+~loisNbInvp^y!g@1{>zRX`Ywi# zyj5Z5A<{`YC%yp5N^8vXUwhM_tRws#cunr#;@7UNQt|UmeQ8v240NK?7V9=K zbaZlZo#D)#CX+G`ox#+09o;}wXTke60QCbh1b<#h@!*-_donG$!I zRv~k1>vKIi#MC9RlpLI{wp(#!I&cBFP!VXjso+$w2W>}dObcCq6urJYP& zu}kkpG=TOxAsSw3;z)wge72r_+et0=nIamL)JU}eAT z{w~I$ZyfV0eC_H_%gL`NU0QMaR^jrJLQUzm#2jVMi7_A-R$DkCv<)U&PmAz4wJ&*HojdJ2c9bzYcY1s+1AJiT4@l@DRtQ z&zf~`K|TB9JE4rpAItcUVBg0p#0Vt#!L z2cqfAWrd@Ha~Dz|KuZ~K`Bc6q9(6>N@e3zJ%)4c>chJHIc$ve`mH{j@kbu|(+7gfo z{zlk#Eu#f5l}KF+SAB*I7dD}3RVC0C84D%NVw!wgw{Krh{a{Z6j$HE8%db*nFg>>P z`$yf?AISpA2##>9Zm=Tsm%_q3YCKLr4o>RV3N}-6$rjxVQwfB|6iqX+o}x zFRec%ZF`;#1ds-s@Xzv>s)+?Dxx8Aw7L*Okfe5kI&=wL}Ef=`VyMbhQmHT|y(ZGw? zeU>p-7#>JSnStKvUF2bYw>}TG*XnOX3iIlmv9V?fh(98afYmh}XU>#j9E~{B^kQCq zCHO&0S`!;P|MF*)_^M*45j+aYM|YRU-G!P%Rfqt*`YDhh~9=<_HOuo(u)V&QyF`C8lPcA z&|Yt?_SIxwL5-(x57I36{EBC;Z{sP}lm3pgF&`&A{}+`N8>A@xRY9)oOMlRg6x#Cy z3C^UKj=5s#b1N5wHHi;i)VMDTxfxb$6WT!HpQA!74_GV1%^*)~OVFAbx7f_Nhy8?C zw8XtfYdH*9=aV^>^*jL!3u;5(7eDB0Zw9z^Ao+4z+M_I!3qxhk2SZX&GnJ6hd-|wC zZy3R@8@S>U%O$t4ctj?CoUnM&qT{q(={Fl6H=3JR`ezOMeYpg34ofsb^9zG7f1*j0M}lDX#6iO)<1FJ$o}t5 z_YN}4%gYn~0@t+8sdEp3-DK=^_i2z!FHbozbJi>eo|R|vSxQTYS0!VY+%0BGrX2zu z&l8l|#p%`Ga@8d|2wU~BmXR7V6_@r=tF8V!Ips2o0@6(91z|nDRP-+s`W}qa#O2(U zw|*2;alIjf^yDZRwH~#-bEgFc|4?>H0lyg!a!oLVSTmw_isV` zvFg)HH9&_saH_vwd+{-N_hMA&`Bk~Jju08Ke$ZK7hE*so(OHI#_AUmd75GSsTRVn0 zFRgthUF+YPBs!EtO&G2<#xALu{V0@E>Pxm-`-bkJUPW*d+6qkllZ)kT|Mb2S*>WdV zS9+K&R$JX4Q;#ku4X=_S#)Kx>aJV;MZBy>)*Yg;^& zUDViRM^KO#hF&)-_hP%(^z5Oay-6*Tga5|sJPXdK_NlF<2V}ROar^$k`f}`NuIUBC z(%9zHsct=2 zhGjL$?O{j%`{v!dUB68dNA96(VGUl6)XOrust9#fCkb?Ua6jV|Jjxw@8z*05u1pFuExvuHBh zotFzKPGG?{4ju(k{SveB$6c9xT=7Go>F>kNAE@awWw1$p3+I39rsy4zT*RO(?CDty z_v7QeckIV@qiAnt=cryicpFk)-&>IgHN8Kj2J4>f#guL@61>FkktKtdvcW)2;riV#MNJ!Giw|D52_t zAR#pOG`tm9%+_BXLvi)JJA)3J=UiERb@h(@w4Bsmb*gZ_@%V95?qKu0osvbLN%_1w z3tx@IH>7Buv>|=&%6M?aZ@pcemXf0C(k*A?=CIZsXBM~npF?wr>LeU4RM^uNO`FOv zwUUR^7R{J7t9_?V@59f*Eeq4y!*=duxm)ORWGX6igQo0Owb)U`blKeXK6yh)x#}7kmFW|rNU^!0)Hc}y6nhV>2AJDn=j{!}ZgZ3WyKvYuTZVj0 zO5L40Kj1Olro>^Bi;IiDmX$0{KfR0R#kEoEQ%+N6b#Slx`i{&R_C~8ebYBUa2OO>+ z`4se`at{TDn&DmM{eibAzia>Y+EH_(&R?1sr{<@Xs4hM^I`yr~|NG8;`e?gvI&fh6 zfuOx>FBWeWFA(}6o+nl_9>TjwHE_yq$;w6XK56C1apTmjuBCkL({|YNbUz;xHQ&-^ zZQkB5`T+J>cq#pRG?%XKo-JGJ7Zhdr7yq65Rh&Vu-#hE@qHLmR)S(9fUVd6vtV+Mr*}HwRcrKnyl~5_J6dr zCgm;bEBh_d*zB`aqhKxQLGMo2RW<8mmN{(1C{w_QQ4tysXPUHfdgeH5@#3O}?>lN~ zMH!c-jH9_!_L(F>-L=1ko~gPs(wo=;mHbgSV^Mjz8^uF+|8`ddBBMT1_GK=^y#iRCkdvY>u-+l_oc z$$5S0U{dG@;^jMh`1(vIld(4WMaCMh9doNT#mvV!mk|};t`=-iPqd$R2zt+`R|7#t z+CBza9scUvttd7SY+5gBE;FaEJnnwdzN5Z~94j`CKHE%2>ph^A4AyA1;sp?90K3X7 zFY#Ax5a#L{g|w)pr&yn!>w`83^IGay1J z+VC}_?<3mRWV(^A`K7a`R4sPX6j=+%mGO%K^AlH$)@jw}$$4SOuciBiR$h7T4-%05 zd@`Ne;lvko;#FxhD7zE^6I0@(L59-RJ zMXrn^g}+^YE2GF5dJnKT8)psI3fdSy`1=8&x}I@@88Jl9YHiZrA%7pmx8$)Fl`E-_ zxl?Pj=$?m1Z`iP*JtIbzq*6?o9rijjluN=y*}ip8t{zpZ+^8^zzxAR?|BcaUwpiyu zS<~EFH+|)v4^wODjS6E6^dFZv?g?`5u8(K(x3D+jZ-~L#@x40RcTH^kKO2`tV76O(Ch2rEjZs{Nu^V@NxJb6(zU59D!J_(8 z=9UPVDC~7~!~Z_n$Ppv{q15yI2-3C|5LhR2+^*5gS4%|^E67B8c&JW zs`2ripQ{FZZ&=g}TGJN%gtQ}#xMSmQ9P9Ew2a$FM4c0ij3&a8{>F2djulELw~xT}zn3NLn}ZXgLd^v~%V>b*6s({#ujich3P#84FPYq4gWV$b?S)7eo4W@ zF~^@nOTn)3nSHSQ4u(u4sv5^BS+%*jIy)E#MzV z2N0tvINFZ+{b_)Sz_fcdCs?J+4XutHrTItCS9OyC#bjbH&HvfO515h&2e8%|6-d1t z?Jpmm^hiHJ>9T<0N2+h?tLvm#3u)T9Kcz9Rfi1nKUP4*n?I9DJ$|MV@pVi-xx`_p| z=SLqO(~|&aLNoCE2$&>NE^>Mw|1Lo(4ao1n9EnBE@d($!zsxUaUi9^U^OWz%clI1w zN=khlwZ`tx$2Qh*gI`0!@69(0=Zk`>81VRfLv)-r>YI1m|H*yTDj7JDxRiopJ?#hR zLtYT~!oeyt$>q7|` zW@7D`bEk!TC^Ere0eT{$Us(@c1eHTY%HSY4?!J7v%mW(XLc+6-I9z71^YiIy!57+`A z>+t=<6VLjAGNTj()( zWCp)Ku~FHa=SX>TyAt}zNZ$K07|RAO%~5;TKl~3UhH0YOpn(!3oM?--U6NDoDhp0> zeATa|+wfzEYq`YLRYWOJ7;T*N7I*B&+-cn|QE~smn#M}B1~!0qisfg9AtSYg_DU0c zkdq!Ch<3LbbHa4HNO+?-@}hsU&^x!`jlT=8T4fLv1^O#BWVS-SrHLPF;3+pW?F}&| zdVhlD>FS*FgCim!AgisjbYVJrC90jDg}1x;sg1Du7l#U)D($oLi!_DP0Lw;az-2m( zaP-?hJ_bI+5%)fPc+QDm5L`QndWS&Sq<<4WTNIxSNcRm}K&%)t-T*c>6*j?&sbl{8 zuUmO#UXJ*0bwl%!2?{lo&>GHlo8&tK zb_rMY8+ZxA_TXaftZ3owp@D!SVH1PFr55lAnL}DR6-Z<@6 zBPGWM;l933q23408c&slv&#yq@y0#B<;AzEyik)Jq}T%HqM!c4vZg77!|TNS9B6>q6N%4l39p++u

  • eah}idnAlmkUAny`?6^c$aXFX)`+9s5xEPSZ5%7X8oZkPmRW>k>&LlL0EcA#G&5$yu~u z0m9O19;io^WCYOH(&9O|8@JchRS~itvv7qZMtK+fRu+$w-dgF)_m5K~WRZSjB(Me$ zoV&w}yP)B!QW@gO=}z)Ja5O0sx|>X;0~}rs6<9(x$sd6(5#KIeT|w=G$^}P_3vP$W z+KZ=TAP`%PT5m1PO*p1CNuPbTY#D(gjos|owS)yj#{)|C@zbY;M;JNwf+a$S9a~B; z`NgNBy@7!mHR2&Wk=#?^ab81?K$5$s@&4+l3~}gg4(*?uWi82Pn=W0tr0tAlI=p+T z`m&@B?J~yF1}ZEX1Lu+cy3EjmDR)5X7)mG_8so8g5(XF)Y7xC;I?cI=+6KwWfV9(N zkeC4R3a5$G?~7rqBTce%jO8JR?wD|9B(p-XPuB0zL&!Y*5c^z>qb9Uf9*du0dSJyk z%tB%2cTG665oXdBH%@uUd97us^fWisIO;BuNpm3zo!G#0SAJEE-o3j8TUtqp!bHGa zA&+5j&t3$HtSKx_atsWEECa7$>*5c2q@@z{=`AFVZD4Xh_(FZNfu8{Z2+r=U8W2bN z*-dBT1^XoB;OjNAKkq>)FiVt(SANYoQQxQ<@nkymdmaH_P!Knr+nSGfM3&hbni3o# zEoe`uGxtauTHH5uZ+I$z8`UAt=)|VocGP#fggb~cmP%_O4AT4_AGK-n^Lh78#EF}< z0^;uZgdbFR)Zw>R|DdbvdHS@0{2XM|==%?s!<3R+Qyw3Km#BZFQ^Xd8rerr8$sTR! zo5?uc3`oA%P|AMt?jcYfg`7^mE+R6rEn8##`t^mMk8Ajsgc&X|HqPoorAf4aUnsvJ z9LLe4*TMd0*#`DYlabCmYQr?$!J|iSDl=q2a5bTC@;Z3kch}bs49#7f>rRk7>@qN* zmlV!mV(%O(h(^K&B~N6{-Z@<6_SM+z2Cfrr@n|`mnL@ovkg<}}V?OOnyVP;Uj2Sfu z;#~I^P5hj<%_POt{%&}oOzfRB7cLe6pTq*OOLZkE9PROc7PcuL6keO`I0pv@OKv&o zdRuqB^2>X+RY_6cspZ~bRe|z@zTJ)@T@L9rSI4^YRz#PiOm?+60?_pvPQJ

    jU?` z^1Cv2SE%o(D-zZE=2Z&a3J#t8Ohx$*+eE9EM%-mx;56(9hRb_?p*Qax8#a2>4vfVm zO3DioU&+Cv?|lUlmoZDQ&tIYQB>9ilzMqjxiv_Xf2wsw@6)|6S#=_-AogrGh9(RBB_#`VUEd6m9Vj zoG!NBAo>zk(-vrVkwO$Erep&IgiSB=PlxJYIe^iipR_N~oV)KITLn?tPv|MAd)J|G z(KgJXhx03aXSZ81wIVzS{WS3AW*)Pd64z3Sd?%Dv=7(bdc(>TX*l?pjGmfdq9>D+mp~C$px@HiJnqn zP>NF4C!4PF`#M)Cx+WlQRRxKp-w{x~{3rsJQ6G1H<`M5J6TSIZKa&}d+*nC6m63~Y zrKALWe34+wh?lPS@87q`IzOC3=^@Ae30so|%e3aAbpdaPL+#b6*w7YU{}Xb^8QhHl zr6Vjy5$oc}t!}Zy9Eo72zb@^d)_KTawI>kuBKiCimMl5-flhTW+)%4%hs!W&;wVOM z(rV}H&EON>4thbn@oQXwB2y#D`m+1^z>}I8UEQArWriHtcC4jHZ~88b-^d$V z@EM_nc#rB73@XAz*Q|lrdtq+SmOsD=M_Ho`pm;R%C+UmEsvFtJO`^XVSdo!n3b9d? zs@Nc!8r4rmwaCXO1}8EqL2cT`u%7)%n#tTHw}X`QhHApp4^(LZ{Z!E~kv(zAO;~>V zb%yxt8zGMt)igF%^+=*jH&=pQ&?6PWk<0)Dp4iqr6d>H>zUbh9WG0`Tw;6(_0<<_9 zI%~!Zv0#{{G5|?vhW(s)mPAr_>^)o0gc3jw$Jh1uzd@tbjAxJXRp;VvJEx^DuT+7F z9Y(=Zt4$rMQ8+ko3nXUsb4CNpPu#2{6(ZDYaoRN$Rby=icn z9r)7K&pQQ!WC$)YHxjJ0J)g$ugn`ZI(RGA+hEmWDDWwd0GJj))i#rFwm^OeOeum!5 z8g)JjW>ZmG(2cn{bvWY5GTHo_YH3M=D9eA!QtjHK|8d^Zv;#cW}jx(n0Id)X27 zLY93?ekEh4OgE1+j;sdXc(U%?q)gN`eATNM+X5SzFCW!w*su{J9EcHQxk(7_EVFjR z_jk>2EJjPxB!5>UbkClaJTM`>;+tX`ntE&``L!rEg-ONR=x7vCx9)@z0s(&^xMT45 z$WQPhMSVc~LKTkY1Seb0Q{fO_DFr&7Y!t9U1fls$NbH73uKL>Jd*$1Gxzb(? zR(D73Av_rg8DENvt(HH%mhqPARyey9fato97FM)2T)u`IUS^C^vRk>W&hNY|H4V=l z_NLB9@*EWi?IVB!zHle%y6X359|dT2np<2#e_a%N9Pb<8u$I{`xP66`NSJgk9(sYE zto$Y)?*w;o@alu3M~!04mj*;Ov9O37`!l0w{nmz(2tYtJk~HF7IT(IjCgARRxTS;z zD}#|I3Afm^P{qcaGGd>6q+Zy_DVk2XXKK5qgAq=@gJ}^+R$|DI`*-EhhiRCTyeiYs z5srdC>ZIWBJ76?Y<%R)&qt6VU&~1~H73}q^w{IUdZn94WOcJ)WNFJGKFiBn}HFdPS z450brQ%;T0W>{v~a!6s|&oMyD}#`=|mwgSB9&(J~Y&Jqk8j233;H` z*6-Ax=8N4KXKfI`wW<4qCww$HfRtTSdR@~mKM-x7W{V<`$HI>l5^}Ju%i-OW26CXU zT?@aiQkTm8OHt8Kg6LPi&-U3bj4cbrV*%F5^5x67Ef47U!`)&2eA{-t4?P&kg+}7` z52_hh;}OQc?v`Kb->=_mvJlOzv;F9aXniB<@zc!eEswvS_LyQB=OWk0>dNwE8-KK7E-gP}$8DvVeKr$K z(LIPaVzBzIzR7AG6hkQSXo4-?xO(-3e@NcENIMYkn!jF^>tDj{8NB?3Il*)YUSN72 zWy|S5^DWOiT&63gSy8^r$8`H>t!se5np3`ffM+G=4A;nnhM%W6|Pf9GLU<=T`^m z$xS6r{mSE$IShJDS`8LlNubSOcP$l9N#BAYGZKjuW)cih9q+bi^TEh#kE5t zx5bBh(gZ+_BAsi6l`N$&ePVA1IhHm<$KR(A6 zs(GsjV}wYP2zTfGly)(P3FV@rf$S8Rvk-be16Zf?_JKJv=gBg|bm*#qR7fyl#kbV* za=i0&(3lNaxJ|#Zrpn}#qtXE{WGu6^IYW{WNr_qoR7&2OdX7ntZ5UM5;dx{Z(Fg@e zA{o=EnH{gd-^TRaTLW*L~>+A%jktduD?4K_KT&Y&T{5Nc>Yr;UjCY zF*axFR@fk;VGA=qH8^Wi4J!dnXgf&C+si1ITu}5;Dc!DGeT3$<0PkeUNTZdMWQs7b z5|JV1&ZIW&-rXGVqKwkU{2)x`SzDM>OW3oFN$ZVf)Dwg(%O#kOv0{KbAwl!R4!=cANc?-6v|s%as)?0g_=e=;(z^3&jm; z1bJfz+omsv(&q5?^BWayyUPU}T?+lL5B8`uWFjdfb$hZ-^kn9cXJzwKESR1-g=yRD-A&*iY>|#XvO_x6Vbo zVs%JEMVJtPY|~XoX$_ZK7#}0Bn@PAcy~qv38)2`x>MRm!6Muu@6FGKRBX;%vl?p_*e}>{Q9>jg?Nx_LW!^UjHKWhK$duv|R!KlTXX|PSvc=BM-@=edKHx@?X;TFR9_zI$ zl+P?~#VFvh89sEKo{u;b;FM&RZ&@f?Q-ra^Sou8;ojve^;`g|~98(&~lh<7SEl-s~ z`SAAQE0V8LOa^}XdDv94J{kPC^hZDFVcltW9+`4>tgHiF!P)nt>T=6;a@BD|*?*O~ zl+0p#Keb??zw2Riak^y@pZ(pnsm7E2gZ7R*>g23vKv65eD`iAxbYpEjJynjVXpne< z(PUBUpbDDAYaZANG1p$ki61;*Nv0x{a}oWx>eRUPf3&@KT+jRa$DiXE$BrZ;qpa*G zqsXdsC{$*Sj55krqEL>JBoT!|QHmtVmXejMq@^;ln-VItfA_1-`FuX#e}4b`Za;r~ zKi|{){eF$-xSrQ|T#v`xhX$&#t=fvh6B&lk=Py4IXw|1z#_ZHs{WyCl-B6J?3Qrv7 zLS{2m_svm`Q*LRy{=l#C!h3>q9Xxq_C};2>I>3b7Gk;nWk`R5b-MA5!(5K*0YtpMU z%%6U3^1da>{uZK-CLk0Ikqjra?pI^}VttFQYH>?#H|yaV6Gq*^UIkAM9wP_Hkt1_L z=M*}yU8E?JTK?`9Hl`l*Zt^Z<3Pik@t#*>HlO7fJ00~fdeTf|e@oaX=v@2%?vAxj2 zn4KD@ANQ5qEX{i3S2$EX#-myfKF|}-kqJEm`oapv7JhUbKW9^E+RO}ybyGuo^Ff42 zRbS359XMJJ3brQEhuXS66SbkJAbc^(*J?sgx9yKMJ#WEakw4ANNbTrnV_OsR-J#f# z4&UDS%_-^Me@Y5448qCIbX5B2v>hvhC~9o$I!vVHCuqi;4{f_Mk>b6PGhiY$gt?0& z=k%vDI|IrgPdN|SG+35EHKf;C05A`G4&rnLC{`>h0G(i2qx^X{-c_i%;w7SdVM@x^ z@*c;T3oZ9>w7XJ68#`5%Yh$x#+pbg=qzB%$S;4NS*M%^(He7G0Xd3w++>~kjvF@OB z{ALBvkgE@h7eXJ_&sNsv=uJN|mvs1cY>EeG52XN$@LoDOWYEzuS~Ud=)}+Z(zexYU5Bk}+Xajy zd6gR4%Io*qv!jhtp`ewDOQAX+L&EW0s!TDYX}zkkV!<(4q!5@5{1$u(F*QrQ^ZQ5I;Ic2jQuC9bZ?;yX zUJZ?g?+BmGPIY_1=KRLQPdsKANjwu))uMkFNkg%(1Q{4=t})c4Vt_MA5#s>8x%pb| zg%qUe$?$%$qR4U|_6$Do4>$|5_?wi&5~B6K+qu7lA%Sx;cydsNtT1R!Y>X4MK^sV! z3ZH^4l_5fxwx_b8nGpOhyJ9)+zAH!CeGj1)px}GZ&)|!j(a6aTJuipY#HBI~E%~;O zPV*f*TwR4bfFJb($^YEVr>UNI#XaAi^S11O?HCE%eSGQI2s^{@tDRbZ2#dIBRuDeC z_<(Nqq?bRA>IWZ_yZEmXD8jlv21D@<_T0$xqj^u}p`_1UnedwxUvxuLrp}r_|M>lg zv!iTNdtc6SA2Ry-Fh#Ar1+4^7Xt2~`b2WJ|`b|SqZF(=q1ki#Ok!qtf8;$|yXgiKI z8{Y_AJLH4fPTdYu#ao*CUiUi3%xc2r(3IX6@H==_+gIyvV(J!LkR>fh^fB%a#7{zl zdr!w=N!~M^T8p<1Rn%xCK5ojTh516^?QkQ}V7IcM>KRE|{JRnIU_g}~VvvNDf6_5S zFZvC^`kMyLI9mA(cR@Qp^jXVekYI46`{?XZbA&N4bXK^l(HvpMpelEOKPG)FPT=^~ z`eC@<-d;q4&)dABT{=m zEBD8@Z)E7V@<_Wl!a>#WW=4|g7C}=MV|RMmgJ!&`kbNsJwDQ90a|E@*{vvNg_$4I7 z90^{c^#=ClOsY-g=pSE3Mu*sK?(*x{2sQOxzJ1y&j(5}O)5fz~U~dDHv>B#xBeYxY zxt^@)x^`M4^Mppz_gv9$rRl8_=z7vEG}~2wyS{Z*_18%^Efm}Tcu{oY`OVSkpZ40R zCPbx2b?&=tbj9qVIco593B9XUh3a=wW6~a!q&d+Vl&ascBT^xO)4e`_=^l~YMQ>AY zor$}o8VqTkqH2)-rIEwN>i4+20HJ_`exbs{O2REyrNE+Y5ky8!YAh! zRN%h?^+IuXIle%8Nk{j(rmt(+w!P>tAZ$c`H5fL^CMM}Eyh5Mdty)adpR_Tm({cNK zeM2){E&b`fTKawW7pcy}Iv8l>ok4Umi1nVN=?)LfM!$$MT#@$xlSC*59qhN4QUd}Q z%4~-wjxPv*g(*wE%Q8Bww>gSC4ki_2^86`fF0Gj9Rrn0r*FWSEhYv!zgSTEpHc23A zBKz1*Mc1JpiLw4VhP&|GNR5?1z8v?16vooEr}mvP0a}8K+gdd!0|w}AwX%$+3$dQ# zalrYa1AIRsRn2ohqRm-rb8Pp(_wUXR`O~CtLMkf`XbQ3&W@IJ0fT=+UBU0KXrH_i| z97@z_()=}~iZ%{u_fN&e#YvC^!TsYvgWW<#qGK2;D%|$Jt2xyh^?KE%65SsdmuDXj zot^SHAt8<&D!k+eu~x=MhMLH2*ox$K&`F4tlgqDWzxJ?UGus6;5;w8>&%RYb%(Wv*hxIM)-o5+MwFY)-<|j^7B+yKvWfv=N7uT-j)`52J z0GiZw485c*;P-lPdBCOA2RB`6^aOII<2qKFrAWV_b3nM>-g`!QF=Taru}Y zPHHk_;$LbO-&ERQYJQ2OSqexGbW_`||a{uBD9k0R?vGLXl7g8wWy4))Z?*qW+=GDi%!;iPRE z0z)@Av7@>YZ&&DF-B1ff(D)YdDPyxAr|7SCKa+jT`uNFhCT?7t9iVi1WSx8Mgl57-QxYxKzO~WdYjL5jQCMZ{>;l-s% zY6<<$0_+am-OX;yyryBTOy3PhkuGK25N(uY+s%m{ZPNGU``pl7%@t^cipTN{(7E3< z?*9GvG^6oOO3N&zFqr$L`p=W?oNl~X`pL170)VLRIg__E+4N^Dv=%1eYfY6b9IwdZ z<~DBF7SD(rOr184^&pGfL`3(RkUvDuaz0+`$z5stFYmuNT%ux=)MQrfJ|CYR$cb7n zS+Z>5%W`R&q`b=QQ%R((n?6b3DUu|wWcnBrmSdf|^oA3h!IecWHEj?Yyks^>NPSnj zKB*@F70`Myxhey_N#v9JLY#Ff8#rxxe?`;V|GfY7Rn12!KzO~oWYsJRwTgt;-0-Nf z-y)~8k}s2vD}-{6fjbde4%NGFFHJ4h4>64Rd4y52wBJ6CtXwlzH6I;cYiAdX%%lBf zwHXiuJU6P@@rf3Dt#5(z*6`Ga6NfVYC;NU*+${taNs5-a?E2)_Lz6f${W9*)X}axp zT<{LR8X$057bjYo28V>GK}3t*k&}C1@2Kd+W$zr(D_-}HA}9r3cS02O3r;ib&!fA_ z!o9cKIZieS#}=wrYG&V^i~vI6>|9?9-gt7h*q23z(E@JQu8H zX7=3-y-GXi>*eajTA9mO<yF>_#kjUIKQ7e8b2&*hF>@%q1gGa%W+wI}j>AD_j9-!27hZ5aic37d z!@NLminT5I#lRb{*VV)QuAf|_@Sq}ECAR9vk&?@%tw~MzI5!$oE3H=hQZ03f`;`E% zz}~Mmbm$D`;-Ox&shL2w*bxTt8_toTniTun+CL++d9vV(DZ;s+#=2z^o zjJWYR@iDUwN=q4Ss|)GgpUWnS%(W4{J#{B$SL>ogNCn(Hk9WunIJBrJ_~0~-^wiyF zQD2>^)PUSYYko;Zx6{O)`x_S(FZ}rNqi7WHT;s9}TI62MZfoFG?l^)ooxEe4+PjQY z^y&Q1)GTryqtV~DOFzKNwLM&}w${!HRtdX7bL}AtfF3?bG z!$LFvfMbm83-3B$_Od&zterzh+26%rBSvO0d2)-Zk7x~2dp+O9vXgH>4yb}X_=ad( zxswff+Dcb)ro zGMH4P304O3{gTSrJKX%=%8~tJ@fpk5KsFm%-?~n%pJNZ%67UNzs22BFjV~ zl;;s!`3?WRA!$$s$uQfm*I;ezS@fI65;93i0L3;GNb7AL{qn(IGudA8kl-M=wFft37z$deOPu?O!x(@SA0|I++nMPckP?%C<|{n9xPn z!5B}#O!1##hb}>X!c-sRa+eJl?^RZ5{d93-a-*zpqoVY=brPyIKYJIHJ@k7#)mm?~ z??s>IoX2%%)!%aUs(>zvd5%D@3rzvYIZJqWM70WT6Mg(lsf2xK_Do( zqWIj|oI}hi3?IyaAOG_n$SDqs(R05+N-+Ulyk-cPiyVRyzjUemw)4{NF1b~DJcIKN668&W% zp3uY_7q*`GnrYZ1)D%F#ln2WOV zi2SPS`(x+DdDSY+o?O&Pv}|7)r&Syu(zR&Bqo~aRCrCDQ#ly4rSn|N4rK>k>X*%Qn zF0m{QnL0<`;&$BwJn(zG@7Hp5SB+F<{2AYdz7@qzE+M~v+0Xh3DJ4p>m7tZSr98G^|&Kz~%1ld0c09m$fC%g_wl34Hlv4`}U#3 z;K9iimtDvVQvK79@2d9IS=}A|;Q)=ecNh0?CF#I6T@yKa;K8$rK2EVe5D{)4J7kcS zvuoS&=oVYsZU|_%^QWUGV!r3IcNsnYXF=FX=gii8N-L#Pjac z?%KXbLZP`zs_Hm}P7ikRgUrm?_PZ^nIdsS=H%k#>?#bcafJ_cIW>-fP#RUkly?R z{zF!h^K1SYOz6E20H>vM*&>#HoFB5s*gJ#hH*U8J1^{hW2Wf%1_BosGUu)b>HXr;m z2Q|9}c6OUx;(8R=^$5s`6PCDXH4W3c7GGwM%Ti0Kh`A5BC4PpIYS0*26wAWG{E^T4&|zG>3n> z?5!1QZ)T8Pd|<>J8x*c5y-ZChUe~C^#hC%J*0JvFFaU+)OLyzLz6EBBUtYWYM;9SM zE=r=bJBb1hw2~CY@oe{<)n3qBl*q%S?d#l3b%`8=u0_X?G3w%7T1gPl5b!+X^8w&W z;kit7;A@wIv8tGBj7L2@&GJC?)Ay-m;1#yvovD8%JF5)-9e_*S=~9Hy4sIkKQCB)t zY(^tP#d*Q@h*cB-epJq1R%#vd<<@~EmNEOwqBuy}lV0aScPMhFuQQM*HuzB;?H_Q% zT5lMj`Pb)0+CYA~Zk~Q(m80+IyXAlvME0`Tbc%XXlEfh_ud%+_?%=^EIm+}k)T#te zGd(e`!7mY3`}Wu1(Gc%F(XM&aH{urTts>xw77fT#f;lhukNL)mRog!I)xEkrvo(87 z5l?3D=uP$3whTFF2zpqsyGfe<0JvvFRh^IbYZnBZ{MAHPqrt586X^wC=oK$q+3nVW zf&njYxt@4nX?`K6Z>aDkLXG)PxVZgum1AawW3UKFF!*Rqz4e)CDR&}fIv7v3g(HI4_Tzn`Y%M8CxB(D>1>tOuHdlZQxrn-zSom33Sl2Y<}yM%BE48G zSuk+y?j#D3Ns2hM1=VCs-94wCW(Ne3xRPs9B&R^xi)+r}Y(;tGK+Nbj=NF8RE=ium z=YJ9PCY_y)P$|Fn_y7E6Hz+3k{C4h>L`^BV<*VGw*`-dF`WCG5zQ8gNq+ZSVCl`GW zNw*CC{Id*vBMJ;?t8NQJ;R^I8`%8&)@shh3!OSC}3IgazQfwwlFm=XZDm4+nBy9;^ zC^!$)xzjW+G8=p}c=G8yM;ecl*RfuY6FbBWBj!o_RAu<%%{)Z!l1IDvNJ`&e!iH{{ z%YwSg_`I1BMW7s2x!s@?9McmnG-=_ zMd}1Kd59l{qA|d&OQ{7+ipgL9C%QjUdt?zLnL|BVr=i>IoJ&o0@~8>b<`h^@x>WZW zK%G6Wkp(91+lyrpgmbGcepTO)FO7C-dx#39kewmk7wYlvShP}BZdfZj+iYUfmh-ka ztW};4U@3{QP13YHO)v;@lJrc&U>}R*iBAK+!xQD~(?QK#o7^j7_=ra0@`^pUsb6lM zOn>y)h9Efg^`R0<53N++?I`)xBgS&xWOcR8b0Ed8b)kBV7%{ung;Fjrwovic7|_2& zfj)?JSHu3CHS2KC87ZA;w>)Qy*t>C?f!%IgyOz5$_!h{v>E-Q$cNu$SITgb*x{DO(%)ZY>sfL(@ z?e~zBE1T(L6nGdx6|na6GB2K5?4XtU=x1XC_UH-UicNyF!EQ-TWpjB+(i2Ztr?4`R zLy~_myd~*aWxr5`uau<<7I$4rkQ<{A)odFT@$}Rmb{0+z zoxN!8yY=d?j$KL-^gg+3`sv+wc)9g`U*gay+60f5P=Df)@N1w9y6cSdqRn;ckbxFA zjYz}`XCC9Z9FCwsb(JTi>lpLFSouq0(VA@o`OwBs%SMgmK5FMd|3*!$phTv?h<2tm zM~`2|j_bXhN^{A_0DM7U6c@ku76*(4ogjwMKgF_RY8)zbT!cADv2l>+p zb~v}SAak#+jiNENtLK$VMq_-QATr5u*hhcq($fPh&*jocMh%fCItX;=&x$8=rU5rT zOo5KvZ8A$6zAtivwp4&p<>{n|s534!^Ti9o>sm$X)+<~(sJ~8`IHc`{3$G7Qp^`-U z`Y(SatT<)~#R=s!No;;YNJ)4WN#pXOlg(FcSMJ`Y&qTYh;XrK{l2_$vCsH%Gc4mlI zL$UM4u~u`CCGeeota#eFzh>L*dEjCTpBsrkrXo@Z@-POd_THmENTh}nns-2`zB!d} zl&yp1$!RYkEh-hSZ7F6hdfY8#U{YpAh5#p#H_o7`<7C;fT=-g>!REvquAt2YS_<^Z zdqBOkQ?iow9t)QxpZp_h^G^==iJy=}uiLg100az?ee>=nKO9*aM;O5?pw59ckk{QYc9||D5{w!AHa0dyeBS(S8WcIJw-eh;io#Vj?Rp24Vo%19jyD+y6*uXH!LN!qUd7J#aB%) zfFN^0A2PHkN58~Hs;<|k2Y;umVR-T95e(omhrXsPv3_(x+k3QLvwHO{vYlK#5;D#c zz_EDR2mgSFdFy-F3V`CRYjZ+IZ7uQ>sa%VY4zi>7^pMszX%Vb2#bI@rWei1>9=J4Z z6T#s3RM%XzCPubpuE?qpe4)`z8FzF30UTEY9?*@HOL70`kZk|U)2k|d-%gTufrgk>n{<4#W;m}L&Fnr^tT z_5)5Zt?+$xCPo?<7}%_==&WYG7|j5wSbujOBI`SA713R!L!kvAwOyWO%cb5~(y;`P zVjQ$rOz=?r^G&3qW%B7uVs$j)98WTJUx9F)U)5G&Wz03W^gN&oB1^T7m|x{&Ah2z_zTa>ktf2O&wJp7S{6HjZ>DoiS%=gP>^jdbt&8i8>hCF1PKc<+IdX3D?g9u9Z9tA={L2OM%~iRKV{>X&DXEPL0c|afx<1zW4Eo4C-;AaC&MAQI z>yB9Oi8L&Twxg|WqaAWm^K(L=RK1!?JCpFUh4i7&j(Z%U;=mwpcZfTsMZ#oL_TwQi zH(=u>7tFu3rQxGjG5hSKh67_t_UHCCn87+VX^~QCjy|D~@`@3sROeSd;8a&f?Rjw5 zGgD_}@qpqFDOYXa_{nGHAZ)vQk;*^pe^ShH?^YDosekys&-5aJS+T&@frU$Y)#M@P za?0l~j*hZQ0LjVI0%qCmR|ErEw&g8wK_^1mDR&RtpY)^GpUI=bwax3@!t)0Fl6r63 zNf;N=Bi}`Tis(Y(FWjY)8MaHuCM$Xbq?Ua+!eH`vzq5<;7;z*fIG%|1OTQDF7?_(8 z8MJuD=@w0z+<{I;onz8xD%n@h$@gsJAgCk#zleFp1G6}K#^VqiB9aQCs@!*3N?PO zBF}f_e7uhZsdcZi94FY-H%|PJogL!qC*OAlHf8dq>x_NiD?~9`JRMMP-&otCHQ!?@ zGb8-gYy#bOYsOj(};T@7EP? z)$t1bURoN*>YIwYpW3msV&cjlR72_7Mkp$387iK%v;He^EUh<}GyP0gw;5P|_ zih7)*z|X23guN93o@ewSC7CwS#+k%no-T8+UFINsD9H_MG2#w@?E`pis7CK&0<>ad zDvCjUD~s5p_uWR+X(mv3=$Uo1|2d2snh}*V9a>uu zOQc9C)_mBiHEn#OIAlM-9OC`mxed5)!6?RN^bLg*1~$)$j|#Hr<3a_?Wahv%pClbD zjSjG>>pU6)QN?b-?yl#PU9fl{-R&JvLm(Kr;^olLj)mfuxdwi!-bB}cprk%E^#G)B zmjiz*pf6GKMAJozoGfjF2^v#rf{O%qR_pXFT+J)RQ>Qa!sWZB!=yf1VLuFl=m2{N7 zFSS8I)!2Lg#sX|iqWRE;Q~LUGu#JmSH?keOJimL8&{n4a#3R<5qI{I9r+SuCL4U?t z3q`|z@N1eU<{7kw#f6_}tlut@31C_2bLQIEHa0jiWt)8Eo+c)iCH$slF^cK^wIZ(&lRz#qi@@-%Snk#WW$ZBSv}KGAZ&D{yJ58I|MTGmznzxMw10~hpX}%n; zZmB0rt2cq3JioD*eTT*WVgY`xq|bXM&qp3j`oC*s{mw+5KYt67Vok6fHm>f(_0pKr z5B3pp-#(naFL|_Dc~_Q5675n+rqirYg=We`<2yMunpu*LD$OuX7sM3ANGzW)tzjjXZ`NKMO+b6$RGb$)rs$Q0E;QSFwZ*OPsSza zq_=6-hVk7ZvT2BGhzELNwwsYErJs0w$sIMG{hk44lyf$R>F)T7o;JZe#| z#uNP`qMjVzR>|XCZsID-SA?EfrwMl26THvF>l-~f^qhQDz{l_UHot+I4PM47ou!K=$vItK{t}n9G zSie_Oi{dl!Pa_OxUz%KdYfYl{T;%m^*6;16_MtX&(K2D$+dschxLk;V5Jlvqy+!ma zBG%O3uG{sFg-n&}g zx&b^}AhBQ8|54V9piw#uq+uHizsGqKwm`s0W?IU7!lAB^pUL_3(IAR^PFsrH0FVzQ z?KQzbcK5R^FQz96zV`uev~4HmQmOid^B0G!O}NqUz{`65$=6AtlU585%h8+}5Sxy6 zrd21@K~1Y+l6FndFv`nHKI~Sa(*j84JD?<5P6_lUe85Az&*+F=dQj&h zZXS_R>=;PAU!W(&WOcc|mF~9beK) zMVvph{2QJ1-yk1bNb6qD$gfXR@IbaP(~UC?UQ&eM|AP#>m-M;>bqR3$cu&$TmVTj* zASKTuvo2_9L&v*Ge3jJMEfCk=0%|JbyU)n#D?&sp?XT>@50}#|hdi1w0NlRF@j>VG z2T@Jh-dov;rzWdz>mp5;*fm3xi16}l zTfbbYWWZ#K9I(u=&{{2O`-rGy0nUby5Gc(FeQ*m;L(~V6M~mF>ZQ=AQHul8y_#MKf zM#y9MpyRQ@-@ksnMII4})HTuH?sq(sx(V_v2~22Tcyu!MQn$hDYARL3?%N8-%0K#Y zyV5N;CcGUzZ#Uc*-|!sY@=mh3`}8fBQV5}nJ|JOA3Cq6*L>M4lq@ua-ciLX|%1}`O zd>fAsf7_+HIH1zX8?m^z@y`NrC~E1zuwc?WFnvZ6oo&;rW5a-tb*Ri~qxdkJG|OEM z&*B=O!n-{6^;;)dxro?;=)u4m?#2>2^SyNtBqb!vJm84H$j0~HmS*vc+~McfH`)Q3 z{gobDz36MZ5Ds!_fsz;j9Mc5a`I?s(U+xo;J@BtqZdjGSjY{sPoAi!-hHJ zm*%D3uu74ckAul>dZ;5tP>9(vg7=XO5Y5_c)texACDITJFYHd_SEgmQJ;#CZD>rDb za981z{`p|?-oZ^U*rXG`;fHqB(n@{%(-CnMQNR?;8Ar;yoSY@k0wSi166~)Jdy5N9 z_q2NnO_G<32kvXrOq1D$kV*hZX~l8niZtOi7rjdtE{wxbD6W}H^7Ps9(E)BHe~VK+ zYN&ZJ)$Lk?NEhK60C}Xv1E+0kjprq~(}M_eU zHA&lBNXlr|?MmYXs@=Ca6&GCS+oT=j2@P4G_c3K_zP=rcHHe<3ozIEII9ahK5f|3s zymuo6=rG{CXKL8YmSTr#pE%Fbq>bm`8AVf;x@>!#HC@R4JfO#p%-+@bR!x;>3@kd* zS9*z3%e{H4DL|ayQxQSpq}lPtxuQ9h2(3r?Nfy*bHj6{aAl$`qZq>TA+53uUDqA?G zTSz~51#ra)n)~8%*g5~Z&4Wo1NRy8UiAdiiK4q|zGWaKHrIeG1{w!(dZ9ng=lc)6J zv+IP?=O#i4i0B-;_D%#!`QSbK;2vGhe?!j`6n=QJe)ny3@|_^eDs%CO?4Y7(v(vw# zm{v3V&s!K5;?X_BvNi-U1ES?gZ_btFnOpI4+rTouB`hpVW&qj}afj)sQWPl*PP<93 z=3IhRjAmv*AJfq>pCF!ws6BtTDHv~=fN*4c!bQ@zt3qvWLYO0c=(0M96h82T+sQV{ zg<(ypy)Y1EdV!Xb;oP-;Axzr(goIy-06)-62m@dC{LzZ_HeC3vkd%qYT1%9RXRndA z`6i`I72JK=WWP7kEBkHcZI0BPVE?+(D@^KSYiz8`u7DL%yP~+X>04a8YimC>8M-R# z$)VZ)LM0-$m0>u)=&>?AP)*_T>rw3}9>APcF*}GARCkXxdr%w%Fqmf>Fj;S3!K0dkTNlJwG(N#ZQZr}j% zqN@U9`vD`0ii07Y)PfHKa0zoYs%XElRmHTiqnmLZS>>>E5ns7;oue2Np;$#5RBr%Q zh86pkc1^D19FkpFSXjI)i=-yON167#WY^=N!fD1KuL+auNT{#GDBxJrow%LG>Q^+& zlQx0p1kw3F1`dp<^vikMRx!qnHew0(qD!C8dGP4boxkfhs{Wxnw8;5@msdU9f0%=$ z_E7e-BUwoFC@g%Y_bz!t6#uVGEnPC7UtVoV>%6=mV+ET1n9J)G%a?zTsNlUSuNGMP z@pQJC7kZ$d@rS~2lupU0h*-QjO@=fM>RVWJppjJCiOGR-oOWiSUC{)hN9GyDRV~|8 zw}CC&noA{#&c0?Hhc*Li`D;!B1O0)-2YAj~401rIgOu{@+O^)L(6)LJTh-{`~1v5j_DxgVw<1+IXBIm+hbX9f``iHQMW zDiS?He8WO~%L=*XSlm;Hn?M!9w@c)~y4b_rhjIlClY{O0*9N?I628?bE2SCbFJ68#QUkp|#5jUk0D@9$5w{oIPY8#^F-Kw)SAbc(p#>5N2^YZpYGn8vPO@1>gK1y!UO=yKLUS z@#=38hX{@T>eXXNWxdAjiLfOlECf=Tx8u(XEgKafUZDFh!Zv4H2-a@LD9B+z-zJRMBdC`%NL)*Hfl{ zqh{WQ)}{MGAVN385?j5Z%FX?^wm}Z%_l6D0BAMa7is&M| zMM|eyz)l@%R+%ufP6f!ZA0HOY;o%D|DTV)MUMWL$y5J=z(n^X{7XeGA?*h% zmw#26+^laCHgq`hFTuIE%lG4}|B)Y2RD8Hq4)&q{3Ux_h!((sNzohzV`x^*SX!*Zt zcRZS6`z{r=Swgu|_ojlv%lMZnS?)bp@5jHAKyF6L0uPMz{YQ^Teqp`o^S@ewieCK$ zQ)e~gAMHfB_k&;m|GjuVRki<%Cusb?Z!#&Vo?DZ<{~cy89lVa6*REZwtl12}@^2eA zqoM!%?AI>*zYn_)17|zbyasqQjUOtXbC}lQwJ6lTpcpmIy zNUqAiSBm07|Gr?|LbGk5+Y&=gO-?H#5s>woYOwO^Bvb1uE%N2pQ|&SwN$j7d+3xR} z05E_@j~;F2#9*1RC4MLT>aWw#3$4h6tu+41&8!uqVnC}EDA2&)Y9plV^KkV4=U-dn zgatW-QO7Y)4Xb=v*7vtIXv(86(XKOfhb zfd84WR8tu7L;;&&-LLfD4~&$v1M9bH>-yixWUB7R|MyAiX6b((RrdRTmrC7+{+~x> z$$taHSND`||Jb7c-jBck-<9+KZCq+i*?*sg`EB3jzc;HhM*s6PVr~9s5-yxqQr&WT zP{{DlsIpoKVXSoQa?Aow-%7A1#NY_{>60sX_wl2UP;gG)MeYyGfgx2?g9+);gx4r!JojX>6|bV z)p62huqSROQOk*;&*h1UKc;WOUz95&-X;Db1Ui7^b^vn(wpP4M^wxn#_wk``8k9Fi zqIvuF?eWy8paDDF@l(Y|Tn3uq4V;{G2!tDvCIL=)n?gK}hW4(U<%+_$0@7S<)F@SX z9y%TXo(U@-W+(-u|&5T+;N!@v@*gr?`>HxBOyYH`Q ze+5UM`iruvDGA`H9W)0DB%geT=)VyXlNqZm$$zsKq9q^zaim9AL(Qi!MuwXRkh}-F zBp-FW;qYxdBYeKN@*)ZRkB_ZS#2LuL3f^?$pQTH{^vPF=W<}!&w$4lhnd_or;Eg+jI7g_H>mcCvm`lVZj94&mw#@%O5gDw8L%{x zdGd96T|?P5r@9Inyjv__*5TvWFF3Y?7cSgP^9#BqYLcNM2o?GX=tUveNV2Orf~#|v z=mMCNFj9--PR>p}9s?ua3GG!&mNX)Ifdobk+Qxq*F~9vLr;5#X zj+3~;yJrf?Kn!>ItRgp~Dd#Aivd?lPe}U`PhSiU%Vh`x}gg|bV#rI`b--FOmrZf{t zn!6J?7#CC5oHY&(u6XGXH6CW25v&>-pKS@0Ow##$mFn4Ean3naH;|qB+vN=bmehL4 zI>AD|1Ns`V<>EyBd=X^_qhSi+=|!t{UHT0E zN)rF_=o}l}m;gGf#}R?n-GTz^c37;*aDES10R^we(Z4Fr{dIl6O8}be`fOpg-p^OD zMV~)Qj0Y1|x8cK8@LfA?8Spo|?YCvsjsd#-?klxwlvg_9nvR2vlhp65;vTEZzAvJr zCoR9jX)>Z63>W?zR|67goM5Nm=jLwQZU3wQz0FDrW%%jEiLI-Sl7p-+%u$$16=>1h zOL52zL_`DKKukPgNZG{VQ*6vfgfKpYh`aR+Gyp|vkS;K#uG`G%HCPC5N-urL)+lav zPj=tE+nl*r{MGL#d}`GDBQQno{OUDrNWP;sUNHl*zs>R|W#c}QR%AqbmMi%H92|9^ zs1_o%+en-8(dLkB#3LppMJgGXdefTN2Td7M0Cp#co_xwJKcH&E3DleSY_r*xcD+YX z3j{~ML6~{WY2D=UWpR5(~3Rq^XrZ&ubs^pkSU&AD0@wXG$Cok@`gT50f0NbST! z_9k`2*O9a8BlI?8O9G`asw7l{kaJT{E4G3`yqQwmSM7P^k$M0_yT1C}*1AhVck$Ku zJ`{rI^hXzv+Hk2D!OUwfpT%{P`%9F$#0?91@K+8c%4**wN2q$$@7e!43E#Qx5wkFh zTo`9I+FVChw=aR2IrIeQNF2E9X@|;KlNL88{~sheOyw^i_no!N-QA2y=~13~ay{yt zi*S0}NlPxva-&%~Xkx;fXh(A%oR=d#kZ6rXy_!2q|BTJ^vzbG*U7%8sVJ(ZTAX5J{ zX>$bWC{yrgXJG>VrdRVDeNT@dbobe^WsBLopQnn7VVY3hY@7TztPCHr{iW>C)vjk9 zDKhow+I7rqFi!gq9_(43%Frs!oA+kP+MY0z$%4vp)>R$Hj0+#%MmAb(xs zkOJe6?Jj4}<#(SVU26&lY_spj=VWx@*!&cK6ee_II*LeK~xUHa{{mnl#CrDD*U0fl$k0$^zpWRgL&;9AbEX zK@JB?54=U1Brf#%*+Cv2shnR$Gc#R}6xs@@e$d}9mj$KAB_0isa?BM z*yqq66{Yak&!E=mO*>a(-3c^UA#D@R`cpaSIQLl>PNrb0eUHfNZujq1>sJ^xltfEn zE07_5z9r^8J8^eN;?TgI#~8|UP>qedoKr@~*IT9*o4DQo)N&F%C%EG2=Mk48BWtHa zROd=UC-$IYilc7Q-*+8d%O(Bzd{aF*wa*FsEE3+=%h>~I6f7Cpw1Y#|bStZ^%zc9v+i^51eq5*UzIa#m|mh z*ns13C-CH^oV;eAhBK}D7GL}QW#8^!EP%a>WkF2YNS2d6=|(9Z3IWgM58_)=vD*SJ z#k4%|0t~wM!0ddLQY04OOle=}*}hWFx(x=0l;mATO&(P;RpKK+4&I{k$kQ*@P!gJ5 zJVofcDPBwkRHjd_KkCaqA3ml(=zX~Lcl|S} z6Z)dDB>-+}DyqQB?vhEGHexb-heVr=D00hkb?DTelHI(wr0wa}5cyl)f<9rgkNQMf z(wU#;8Ppc(hT8xx-oAVHH}&;*BpR8Xi0n}#M~ZlIn&)`^7jaU=^mjc*ALayzTi(M% zphv53EGGx-K{`w54h$p8 zs3@PHWGk?(hnsUJnV7_JS$Ce4&6RmN{%Kel4wc$~0Z&iI_f*q_f2|3~00H!u9@_Q2 zw|gzbF+RQ#V6{QzMJnwDPQeaA{<#^a^b&6?@LLPnobZL-rYg*E^0Je?KOYt_(sPs=F-F2EB(fXQOmFJ>r+)W4cqK{v zuyS;5kF27Xeu>uSO(YvY=lo8sblSRBxdF_wkpI=I-%zE{?&5OivQO8p zFO&t_N~_W7^8II?$XvqP708rJjXlZp$jo9^yH=)OO&ug_8`V=K9B;!B*(ut>dDQ;X zVg2sn&L`3JKxIUk;6?Mrnl`Olx57OF6`8&?#wHWMXlERYp)~zzgJPQL zLLJitulg1Tve%I#pb+}HZ-u`?#FEaO>)HPZ+-L%m;C#G?UEgkfZByA`&|*UR7}WgO zqqn$4O9283cJrg|9|opUaXWBUPKO}C`gbg6AQe7XN~vq5b5CGD?0?~B8|N8-S&X<@ zUds`SZP;$zO7^YFh_&4zE89geujxY(M7gvqy#=5ErRS%V7ht28rk++5?wF6iXnqg` zfl#U5yQkl+z5au+GEldTUzQdVFF~anN0)_U@w>HmGa4ZIK%31OBZ%^ZacxP!=AH@r z3OG_F#ZFdww`brP)&!n5B_Oh+3R7IOP8s`dr2>&(zKzgWnS*8rfPsnT#ehMW<6_>z1@ZY+|<(CPnQJxkCbYNCgXLc}o z^XafL+ccZTww_YAsckbiSO=6xo{C1%Pipjm8^xragVA;|u$(k8VFrCXH;s{oRAHMR2SF|6c70q4JS;iVP1Y>r- z-7f!w6y*73OzI)PtKQKxATUhHJV5Sne&f%T48LFxsC|!p)bbJwRhk!u ze>FbYhSKZ?Tz4GiCP}8%*_cTQ1-U+%_C9bQq9-umbx{T-tP^1yJKp7_N<#A^)W2_A zgWIA)4Ic2_`1-cjH~t(@-DTs?_Wsj-;LAQ$3T?`GtjE z_JQ)7E;YY)W^MS;JjFCRo89j2>3LymOAP?rI&jracq2By`#fwkv3x5?=o~hC143+SFaOV{NSL#f3rz&|I zmo)0!%R<&}8eE=XJ&RXrFL~N&-IrIH)8jPAn502taSpfgslB&S4`uhvA6@-sGyl_} zoeUpvyaxW=l2r34?TePy>uN_;?kv}{rQ8$})V+ZHF_RIWyZ%`0hHrr&(*)S+>$(4| z3}w%r_uU8B_z$J3T}9^j>X||B4aNH-5;l(idJPP37`4;qSu;E1!eBLCK>L-~TR}m= zGUwydr}T*SDcVYx2CM2(S=k**YY2kCyU-U0*6JLp8Wz_clNywW96wNr^XvU^Z(U`Je zn#<%mLZ6O3xx6H?bBgKy%gCQ;uY8r^bF6Y64env_n=rYhUv3on@m0WjQ&|~jGkSC% zK|*jEkneR6H8KAx?>!!wTGP69f5f73mvNtjd(b%kT9C}D*M!bVSWy<88;5-X{~39u z>nFni=kicVvLEq$rwP|s5{Ym51>QyeiJhI-n*8*{4*|q>hPd=|Ig>6q&ve5D3xz8A1wklp?{)3)gDX$RTrFpwQ_0@+y)wDPn*%M30hqDQ z!$=TObJG+Dta$?IYL>nCG+t;TzU|tCQ9TqC#!NhF@wpNQIFxuR#OhTN8H3ZHE)MgI zyNWt4cj@^i)+JV^9$S9bUTLqi^pw1^OZ4?zGXW1WnKOKqj%)nUQ5$FA=BJamFrM&d z8vlcpS9jm4(CCPk&3^T41uy|&BzXO}uH*+;R;9r4{5xNeT!5{Ut{aFbIslRx7}InmrRS2G*u;?HBF zT^J*+j+zU%gH%|i({M{Ov5uuQDIM|j>M4&-cUz}*n~fh5h_71Ft8k7gs9I7+YwrbK z)XikDI&Us9ozSo6{_6bv_XiI(7M4$ps!6=j0bnmCwEHwOP_Arqi4outig-_~dCSaGOzm~WB$d1Z?b^W5%xjmt~_U=t*t$aEB z`ZaFFaq*Qr-zoBJ;j+7Q4E^v8Xw9Dd0^D0(Z(q(6?aM1)zbN0adzUmh0ORrXtXDVa zr`qb)fF}3MRap~8YTn{T10#!4+NC*k2CJ&iB??NV_*z9?EhNY~gUnP(qDys(y@&lv zxJL!*m!dg17+{^JV}_CU@az@ynm2j7805?PSI?ukhM72j@ejPEABUJjn@a+X$$n6= zwd33vr|Imibe0tB4Hq_Q_?Me+XAmH*eKq3{(PXkrk4{4cg)QmB!!ycgjI(`nGbTv+ zBp5^!8FIV@8=Z+E67;k2gY#7KzJpiHC2kn1s%|bq1-yA4Cv`~B$xNK!G-67&&Lewo zV@H4`rDXP55d)+U*?2mDn1>U zqL&2c&JrEN5WtpCGTUL{R;=~9a=m&*KJ)GAmt~K>5s79VQdu7&sT=fa>s5?wD)M%kwA}W z=V_y4Z>%|RYnaa3k8U)u+Ugj(GI~!nU!P_`DqN!lrt=49=)pbFy)8hJCroX)de08N?UcX!uXAG0LgY~e&j+@!~o&^eB z|CAnUt1CT>Cv%keJ7%|!Z`(lT%+tT|!ILaOTX+&ftQ3s49kO(K=+$P@F8trO*ILBv zFCiZ+$LKJIf@41jXyRc{ua)En%h?H5j3D-4GPrI{RrT@yn}5unP(8;a6&8Z{0Sdn> zC>ZTI;C^goxJ580sk}7HJDK(;`CUoj%XHJ%msZ)ML?cB7uibkf*a(%WWY&&oTDhv{ zxDi=`r5Qb8UCPZQg0WPR5feNeyWS{cN2+dI#d|@~DNM5rSbw^4h!Y%MReaE&?7Vj* zi-j=h0T-*lLCSex*oR7vIpeIVE11jU1j#Z#O-(;H3mXxfA+qI8BrcecOVrIrtXwNC z+n%4FuL)zKb@Cg2rw=!kTmj1_WHZuI<#-;<{~_r>dV$%@Q}G=NBD&|adXHCRUN8Wh zBKwd8!Wob*y!-&NP0xX~aRYCTwhdmq738=jo-v zk&{BuU>uLr@pbqqCEfV=_#1&6Rv|Iq+4p$m=ZlG#h~3x`ix@!9zTUA8g?=Q&-??kL zVDlJ~MNFCJ%1zy?E)VC_)01WRueoM2vav2M6Jlx(R-YNQI{)ddT{o+K%t<(= zo6bVLj=LnY$+bU;<$ZZI;=`9j5{IY20q`!8?PsItD6}6RIpl1%_rjc(dcVeylpm^k zR~;$jFNb$}^1AuQ)@Nz*7uiIz6nVc)EXOHqp)V#Imva4+@shOBaw=JxqB(LQZ|pa0 zm4BX!f|({jt@y%U`;HQr>e`0Ik|WT`&M~t!+}pYEbl^*HoVMr>uq$2=5o&>VlJ^Bn z8KS;H1c&U!&Cz7wMn2Mrdn)UWkbXW>SNQryy6NtD_!8%8Dq?Gf@OYMVI8SK%nx{NP zmhcc0sobtz4mY9Bl;?p|~J-w{y#|k(yLhvnE5k0Y|>k{rXc{swV^8-)HNui`B@N^V8 zKa$Dou^bOz_Lh;;2(j954;VI{;c$6K0rsojwxT^kEF?476994wYYkV|jhfd|);b12 zL1BZpIebyqE*eiH&-?Ch;7b5_X1Gp4ykDeCjWgT>ckMm2{xmXLop#uT?hDnFElV)w zFD@?GD_^1GM7sFInB3>knOoY~0iJH&tCEhtxu&c8R!v_D9BdO0bvm+m8M7o@Tz za!xNE8=6VPflkbFoR)rLWL<4vm?M3-oC&;6YNIZj=Ie~}i>9-RNls-oUdH66XntAB z`=0AF_>f9*$l19!p0}O5USSwf`%x=Ptzc1@S_c+WzD8lv)AyIjZalQ%?|W2HP%YuK z6pt>prus(RDU^Iqqp4!DEoFWC`@bnDoO{nRV0mmgCV7S@_=TM=Hn?~|`C7;hbN?v^ z`^|oTW<(lC@`!BTS1YWtY;?!h%Zwx5t+Uf37cvOd{)S694rypvXC~c5$qi5}aaonJEW)3UQBk>R93HXCC~u z_I}k1*2dhFv3g1L6NG>Xz#~#nc&o{B^oMdGuQ7M5C?kE?|2mL5^d9=sEC9@KKEX-l2CHRke$*pcUm7b~fTn*2NLV9}JU?UgRLkUL*4ITU* ze^AKlaYcRRT=l-V86Po;<0Ob}!JDqRGG&19TIt|3GiFOF91oqO+>l}a``$S>9Kz8J z!8tr5JQaYRQ@EAgox_c`bmYjKL)2_sjz*G=PHTnT8J3?nv0`xPAM+((8vb-yb7Y;7 z6*BW_;!EezMfZ0k@7KZrK>zvF-ea0cfZl5#t7-y5ZPj|Wn&bL>I;`3`*@`MhZ7ie` zy7oh-Z&VcOTo|);l$ZNkMrUyS3e32ifhn41pYGQh?sRRt-&X6&`TlbEv0*D-SZ^Md z;X&^4dS2tuwael!k)zEe9nT6`N<8;`iq2vst1J%+@H1DK{C@4W`05kBx?@%I$uIM* z2jMxJY4isR5tXN+Ar>y}r={d zLDTL--gDqRUmu^Nv+6$i%5}BJ(#S7>kY8XJ`2}wN>OTLx>WYBXPds>RKhJ%E+!^kZ z>1??@2?pUFjzc7}X9coZwO4kYxus2JrCskKl0Lu#$g=lPx?XCr0V?YiZ*IPp|5_J4 z$)U>}+|6b2l^h_Y%uL-9SG>|r_hnXe*CESq28?E*5JAk>ah(ZL#XyVz1GKjI@YS*m zraTB!N@UTNnrzInBRcgaTNhs*baDC3t0V9q@vRr}J1Xn>tV)JO_627Y8Rj8mQ!T0Tv)fu#2?De zQDtMNT=@Bmu}1 zf&5(N&V;z9)Sz~Tt#4XJR8~r66eqaf)i*^ioHmJhr3ZyQ;>XFFN4_n@s{7;Tn(s7L zj@G-SLAT^_0z%f2Hh74n+TCPy@MXldE1fa#6=VCvcD_M(fv9XMHf!~!K!z|maERG{bT&deX6QD_7CiT zk0{5#H<@6hob%cP`*^+DIqSJy;^9UL3ac9)*?sKd-BE|aQ!z99hG#&RmtyVEC?=?Pk){hvl=crPBjp^+$^zn16kIq@mxW)eP_y_oAZC^Qudwq_{x>ohd zE-T^Pt0M~v92>oDkH>ZzzuuM|tKmE<{6Hw}R?o;ZkMb*=Ove3UNwG0_^1-UzX0IIm zlP=yF{u||9nycI?->%JmVnaOt@X@gbBQ?5pyZy-OlIDS4nX~2(xLy-I>F7ptBy)I>yVq(#awxM>KxO)MNx@ez{lp)4TQ0;*iz=pLTl z@m%rZLD>llo@~tGI^pGdTaFOl1T|6WeVjdHOV;Y4^QiuH4^Du37Efhk537m5*X7~Q zm#T~(H_mv%*E19#%F4>fk#(FCzm2K$SlOkRgW>tQM!#L?dv)F_FdE+K6I!HbuKd6P zN&A^+^XW(NfHkYE+;?fxG8%zI)gs%g9vG)(c?eWlMqXG2tnX8Mv-<(jaB!fP0%AL zxDlTH_)wz#cI!;P$1K;I;I;ZTy*KjVbGucn`hSeVA1!(wd~&`t`la9F1Iv#+So~z8 zCARW1hICYUZczI^8%QKwmMk8%+=tR54O6*_ds#KyPQ>IWTQKBTNXeKI&%3>)Y znp0E=O-PD}@_f#$bwBrSd$)Ic-#?%2`EJ{NuQYtW*L9uec?|opANxVFp}k^(mGWnz zq3u<{hvLDmw3U_5C~Cug)0&|g^>q7AaTNhxPujdkz1Rs@-90Ch(58*aSl?#-zCxi} zo_748|?mbx*vQCR$fBtzM&H#ztm}MCU zE+w^mmAE8y!1#L>!Jj%?Zw39^jz`zrXi?ixTaW(enr|N_m`2dz`5-{`a&^DVt4_Mx zJgUEaYm}ZQ-$xBZ+#GI=WY?~Z>ojloe*XPr2hSdZ21OGBF37d*_#{*B;yFav(Jd1C z{v%ovKF@jO8dMNz6tVjCmTMXmQDV$<*tpd-<91H&+b_xGQQuL9nuvV?jG5A`q>orU z7d9XOuTzRvS{YvFs#E(z4R)RZPj(^rln^-l`OdNtTP#z_JhY!=S+zYo(-UG-zWdTB zA#uAD{eoDq%>^eo)y|9m+OoN!L2$r8ibV#cZOD;3+4?>3G zT6e);4aeLbUvF)aFx0}!9)CJr;{>f)v104c_FAAWb{bdW$CQ2f$>f-s%}a zx8fl6t|_`~|Nh7ax`PVB4eoM^tg_GVt?n$MM`rs@9z1L=ri&gee-WlNoc{f400J|n zH=kzqBCcJ_5mYn^U-@3rbrBY3!RF7?#n9%MlptDwuVRdeCI&3FOr7s7o7_~kFe|?Pz%{6b~n26rf%FU|AlrC5Oedt;Y zkR(qmrMgML$dpzb>Ff=i*-&3^0HeuH(ec!y)Z;Pet)U=yY*(skvBk~}TUJ@GlOAi9 zU*qeD>vMiz*gw(BE<&8_#$TV zUPiRwJR~43EiK`v%*4_3Ha}ex>G-{&o)Q%VnCk(p@2DbO@2&Rz`}fhX5f>=8zw(gq zOM2kIfX3Q%n+|*4RTTO~LJc#9=w_$rEC8DeAt9fcu~FK%Rj-fwO()N#SHC}%>@y~u z4JJOG>b=RO{>D=1VYZr#P_<{}reEiWi$MOtg3MkQJXzM~e89KE`| zr-%fF7?rPr!K;fi)gkME{hVq1PTI=qepYQ6JY-1FGs&=PhNLa##Vrrehym^*MTiin zJPzFhhiqCpZ0Lju*}kJ6AaY!QQUodI=ljFGPJ6!S?KPu0yt=c5o#>j&>v8J>ziu+# za7(R=*vyGi3(S7Jw|8ByX3PBpZDLCDWGc6p`Siu=8eicV$|}+?*Iv^Q$PU~=d8mZZ_ z*Sk;S$c+hZ6qhG`wC%g2`tqya@|MMHefKn{s&dC1HnQF58Q1k_)-^(|7tG^vmx|0T z{qv$Q3Wb*~cD~sA`|@I?d&ereY)S?eERWARv1}SuMsm6uclmRh-YnSu?8!|6K&DRG zP-(<$R{>d*0UAT>@a6{stl0jjDBNIXn1ceB= zZrbPk#e{H1jGG#6_jXzQhG5Pl*{s3B{tXK=&zJ!4<#*^lZ6%=Jve%CTmr+8VH27`^ z;q+(9Pn`Hj4L8Xfc1;K9yPWq<_eHv!>cXV;3MSE~>&Ho*J0-tJrb7KS45dp?eJhyo zr=6H>s(iC$8sXR+9d!i=4n_p&^S{_kem*Cues@`y3`em0C}_)I3pBZe8C_ea2Sjzh zOU3Y%vzv592uUXfCEqtT8dEF%o^r;k%uuuG$&u*UK}4 zZJ=lt51J)o43V8b>X#*YT7AbZn_PVPJzS#Vnce^R!-QBx)v{fz21l7g7UF21w)Sq# zDd2z&!1-K@E+A&i_)a?cXj?C_Rd2l`HGazKfKzB}0~G|9wreMsISm-PPthadHdq{r&g0QFgwI86#HWpbw@2vx+1E^fV!bDrp0no+R&XRar`LTM`6sEN84^QjfI_b&vc(Jyk=|_cS={vVqR)&OxaL%j3Chhr7 z;%W2xq2D5sibW0O-KC@!6qPq!C!_erdkz(XJ<7c2~P(GT3(%}so z54$a|?*|ji12#?Cfw}+>;k2mQzP+~<&bImDH3Ny>D$nhB?ATS?)oHEnB5y4)S~!B2$?o^ zJqboWLpG*jV1>imPKgA6>qqS@X|;wEsh#I2Wxv+#Rh#B4DWjSI!Du0G|M1$A8MPi= z3LponU%Wr)eyoMdo8Eanrd^Tk*g{2AX_~( z^2R*!JR8X@gfvX&^`0eIuD(ZGUQ5JxCe852rdm0Uy55js=W z{jQd2KKHs%=nS#$E{dVyJK}tVeZ808kqf>5 z-s3Ae+Y=|mSt%Pg+Y~GY#uOz#h=rqv4kdbbq@Td+S20LA3<%7#rsizbil5DXA|mAa z+e*?=?!9oSiv-(>a+Kj^i|Nn}(;d4?Y!wYo(?8t@Wm=;{h+PqThWI9N_2{Zc!V6Fq zSjVv@xf3M}&pJDGbky4dkfV++uiWKxjg2fr*kRrJD^rjT2>t9*3}?2!Xb^bA)M}T^ z=IlA$U6iz&U*kmE*hv&M5#`IUeOb!+$#~MkZ`9&+hmO&>OV5D1`HuyOtj*Hoz6tvI z<5EbU0{KhJZ)}dO86cHpKsqc+YZMXRQwVfe{{AmfK_?Ct6-37F1<{4diX}^rS(=^v zt>9?C>D%K&JRXImc^t5{dgX1=3&vVvJ?9=~y2NP_Qepd#1&Z?5r={Hb07Y&^tG+*xOA-gt72vqN}f* zvw-0oZz*NNr@VW5!H1KWL3vpnrOTrsH*X=@{?z;ABxnU&pC5R7OcccKn|G#NuJ(S% z4ekg&Roo9$MexSjHM_9!PtinJM}dGu6teGcQz@_kG#z) zFcE#M7DxFy(;DV^xnS++p3qKBb@rWk(P?*Cgo@c7ZRZx7d5M*^wNdP ziVR?#xvaOS&{CgS3MWrx@Q6|-FWQ$roFFmkCO7l>v3ec9Jq#XVO?g@xzb%I6F|-JD zzpGVAd<4*k9V`N-OQw4VsuR$DM#)`KLE*lmPuB>#fGU#r%3EW&1^WV&3UOUM3`1#s zozfXe2?;wR&>+;ZT0LDa^YNy~NA6-BCpp#RPb%$DEarp={cfl6Re2)Hg%u*`4aAxk zVUGe$4RL_& zE&-kn0r4{?PMq02GzeGy!rH1j+os`l9)y$Vk~?(^aead7H5%UrYey$o!>ut>K4?m$ z)M}U|BD4^P5lWte(faqS{4N}sWVdcS{>74g$3=rM8cvb!k)dnRA<2^B^^ImAy+7;?|roepH* zRw!nhq@HhTXjqUZwJ{(lrk%u|{#Ml=Hr7fNQz%*X6NZMM(uggcJnWzx?7jyNWcTN@ zN|`y`ltR;-zq(*A#CC-@I5l-xx_aHMiX_h{LW&&J+`(1%IEIT=ch>({U!Oyz_&o3X z%+W{p%!F0@ei_zeFqk=r&ajk=6_>=mOX@TYfewGl?ZdbBPx0e4+^6 z3epg2jiw|vM(4MrLtDxdEzB0K$3jEUp>p!DpxjMGx8+z7`3`cCG-!5dXJQosX;+ zt4l?7Sw~6f34=$>-KwY|f>g^`C=o+zpGQQ@Mpcvh^_y#zTNjrhleb`e3R}{@TfYgn z;{rR_OB2k%Rh+u6=^hB04NX{{?@qZyF%on+2@%}5EQ@U~J7zflam(P<>8%ew9HBOt zM_ei48Mm^iBk7z8JE{R*mq+&R)VQku*FqX(L6a`C|b=^$<)c5_+(lL zfjZ}X9Zb38kO6nY=u_5Qv9sRI()odidxhFRrzy|C(DBl@Z(tf7J9S!TFyQv|cb7di z5*IUd-ROutk*xJD2Rk0tBnZUD(Mxk+34el#w5${Qq}^SNAZ zbETR`2EEyWGHg}yHdor!)m2n;dv_nGDhPK7#;pB(`ZDXxa#tjcJ0sERDuD16m0>?! zJEkl%(WuM{SV9$5NL6KifCEj0@ATbGka&!9}z}5(035pj6$f14y^2}ba zZ%IaTsp*5c&ET0JtQ(2@(V~4%Y4ANu(V$0UdK!YOU>H)762T#5JoVL&C@n4%V#c2v z7DyOQW~|7(F`hL!A%{Zx1eR)(eX2msa653|G_+PS#|_ZJJWcnU?1KvwluJ8qg4sL|TAaa8}SvF{cKOVWcozGhI_LBpcOqB0m@%d6Q^1c4!n~`{*!dgRKvwig`-~ z2n7$`hjwV+Ua*|h>=rb~*W_LDLaB<;%J;zEa82_OnY$RWZ8_E}p3czGH~mFUzhc*C zUkZ_yY932|Gz9(32P{!wvGNS;AW01$yF4ko`H-x{Rugg+-s65r%rk)+VbW>^Ngbjw zJ3?u!d2)V<-xZv!KMb?c@@p;NOsBwo#*>kiY5A*fjg9J)t!)T)m4Hs1?>O@kN1ru9 z4PTtbrL&fV&i6LzD|xfNd@K~$TeolDHtKs^?WK$uD+0kcUvR=9wsWBsr_CkvLIBmx z5Lizz|C2i2CCDAi(l#lj2xkLDQIOo(qfF7rF+Sqd3u~s0t<`5M>dn(O;Xj|@r2ZKwkhQ;+NQ<9jxxoC{!^5F?bwv)YFo1pzx zWL$z}`X!Jc{I;E3Q@l6|$}Vl^FwQX>wV` zTH#Z*#nOFxn`J{!1Jn_5)Cy;GZ8txSP@DZ=LOa8x!EZ)#tOPJD0yh5cz@bAn?6u}+ z<*n<~9_aDsQE*jE+JGtri-~A>9~&~WEV=ywky^pTPE3uOuLV@!K(?nJcxIC<_(qJ{#zddx=siY}mF!VyNl8}}7hug%#_4zE9e z7t*I>uV{KTls<|ZL^VLHifwSQw1muY-1jchDY5(cl>Tat;iCVdbWr3aD32=;>=sCm zxQkbV1Qi0m;sF?t#n==Ue*({oAY{Odok?)IYD<)c^y(ENNb^WzTnu?0N*}iwwHd&l}Flby}Q!L4%xxj%r081G_v4WqC9P1et}DiDeE=8>z~{3LyZ-WEz8_7GgbPrq;QBe^wrk z@iS6Sy&+nC;@4#M9G%#%kqH0AU;9fo0N3m>*p`vpv z7Tfs9Yl+jn{^hmUna}SqbeEyxhE%e2X&ISw_pFp3b7~0*m^_dFwMODiO;U-$XLwsV_d=Zc!LH7(iI z3MWEE2Mf-l@PlXNC=weH74%M+oV&tO#8}(gu!s6&`WDI~O9W7B1>%;>?&6gD`F!YR zdt?1?XYZF@8?`CgCs|kYm3VaVKZ@e!NulLl0G;;XOw=NDu~zXH=MNtZsH_1)p$Hko zwXfaG@uKn2>WDjMJe>I~9(KLBb1jeapDLeRrp0p!O$&*3zA@LyDh&Z4^YJgobRDOW zk_y88xt*Uf$NbzVVJeRl0M1tM(u}DV4(~54*rSP0kdD+7l{4?Irem1{A$|oq5V81- z;r&IS3{q;0=`}5?9TQLdCB(py4JfYq&4?@=v_49ZF;geV8Zdq9;-a8ff`fje7lIn) zrL}$@Fa8h=TIIyBiK4ygot>)g#r6_=@`ax)zVP;t%=Y`U{2%QA5*PF} zPNZ*LtajZVG*!_`rr7Xs_HkstrxF#jS0ox^*pdiX4k0 zDC}?o^67=zxLDdD3QJIRqw-uHO-lvr&N0=KmNAhli?K;~f-Nl`4@5Ywja|xy({*jl zKu^FLg1jZPUz@_a08|M`cj8i&7Sz>!Nc*{4YzjO_V*TI+$`77A8DH!< zaiMD>jTL2n&p%9ExO>>Z9zBF=s<7LN`sZhz+GV$C(-kexK@oI>sI9xXb`cXOL?>1w z5zxAC(>J`jP_V?pWyFHYF&pWzv|95>M1VZS$BE`{^lwpkIsV3S26>|>d+;=_z5hOI zsJhUZpc6fsG*Dz`reAI#M7P0Z$AGh!aV#`5=4t!&|8+6Z%NGx&Z?;q;wE(=m~9@lVQM6T8AEnZlWA$?dOSQlav`TApw^n< z+|=&)RLp1|6lc25CyRI~N@@(wEfnszHReg=VMQ}Cx)57WBB&J(3!tvC=EyUCX#)#e zLcfm8&X2P%@A6zc4JxwWtkSm_O$4&-{T5gb`6)ZL-n)0tilK7FS8xB&qK214&+XMU zk7!>|v3eBnuC;R1^%=|hjy0R~%<*rETm84%OgzCQ3T*da?zV#fmgM!J$h=s+zVB`E z{$O9ia?#pBpfO;|rj@5A4ZmWFzGLdQw*&8`r(b+p`e-pvJ%+P}wCUPpywa^^)5M0) zpR;n6W&+z0ye~q|;mN6vZ&ej(u{h=RjidOj$hX2S=s&a@)f-Gq5%zM&%2R5?bEhQ zr3bL>&H-o$-MrC5V$*lHyxWATastz(_}fK%eq+u0wQdxbfnKk*e3uzI@oLlH&e6_f z`4Pmr+G$9S#EXP`g79p-f8deGTC&(#?$u5;}boKnO{_IW`g-Gp(B#O zLrbtec+%&|?$5s5BD{_mK8ocj1d~9X)?TnlWEQ^vgG4Du@WSb|*+r8gk3WnaVg9Mk_m8>f`G!VPhjSopuLFJsxRZ+ z?Qv9VL4EWDXTH-PgTilqNPalHT$S8coC!#<2}p+!6pwnA7JkcM^{`Kvy3afc(QGQC zK2snXBqEJ)3^#C1GT&Wm`iSp|&!*FJIuenJ7Ppe-rP|{G#eM@?9C^&;`mK0Sv1+|0 zYnG|9teb?ya|71RzNGX5SxtIk?U#7AstwaE)V6&*PoPP$S1wGZnz|q8( zuYcO)yq|X!`pjl&yHjZEua0~54C+86+e3IG;*BsK$bn zf@xqeKHT}j-%4pHSrfiLw5*gT))2miT>#V5i1ZBbSrbxFHeKE;yi}q=Y#5mQdD7Yk z(t~L|CLXJ|c8KT&`Z*63@Eqbnf*4h=z#s~N5w1U5OOl;%8D?)jzcRd}J*dI`SoXAn zmD!ZM30t>=k@asF1r5g&fde5AAjtoBRerxtb@z&>ws5L!NbX`w-h3)tyih>nsWPx6xP+4BNO zy2xO}y8vbhpt->Ks+}>${n((XhO`mqLu9ib*dQvJ0|Xxr?9Lc!HyoV^hz0_nJ!3ZL zxH-Yo0`7AO?KHjlT>O||{%KpbvdJ?&zwxW0wh0gq<6&HPCkaDiEg2sSf_YQ z1$9_m(S}Gx@Wa>g%KP4PFbs5cn;_&P1k}QootSD&QCB!h@u;{@?aEz*3SjBa2mky| zI1$c*MXpFoMH1~inLW6|n%SGzgLkjbZFx;SBmo#id<;pF*p)~a?art(;MkIMET;H| zOwF;YyDH=2Bbm2QwzEMmrM0WB*dE> zGk(sX{zp_~EFL${j}N}FuAQXB6Hz@v_AdsrA(-H1sBMYIK@MF-J0Tx}jTSA^C*CTS zW*e)F7Sf%3OfFS}gq`0ui>nqf9KEc$j`k-M(GI~x2mu1}ndf?zt=NP4d8a9IFC5Syo*AFMT>UK^ zA=pftt1)>e?8@hMnbn-spwq4h?Gbkb*KD;voYmD^2*3oz%0~3(I{IuJ5w|Snl1a<3 z_maI8Lo;M{=nAmBTA^6bXNC9niunLb;9GNZB*9H2R3eM*)g!cX*-EsQh_13~j9fFc zLtg;dH$Yv46Uf9x-?@8NeCKX;;6uy=Hp_(F`k97zofAliMbkQgHl7CaFAA`x0?d7r zd2*ikTO0`v>3xZNXtZJI{2l_!5RX<)^iD#&%ERFbYPKL^qg0@0rYrc2oUQE7iYEzf zjhM%c?OH;FdP+|00r591p#6XbG>`NKG=SqhFtJPpX)U?B_6S9Ql+=WgFtoj3b#7o& zeZn+ya`iImBR_up7%Ho;t;VPgVLUjg;${UGA3D1*9wHQ>+)aHobkwYo9e?{RoW-(6 zbV#n6^i{CI#Y3V$NM2=KXUx{QX)ZqMTkV7;0K0rQ!$+H&FBDkp`g1N?Cmg@LGbJEG zsBDgQw{3DGrr!ltGQ@#+m2?y4gp4Jg9PJ!iEKy)blLT~5@;!zcynmVwlKE`W=c`xdi@ZO~y{pDA=*Z8xxKrJ~^LbAn-Z{sS2Bz z=Pfbs;SBpPFU0pHyL2(&(ge#!?GIWOl!F1{H@>u)MZHU;hIvX?1F141C2GvUrKa+1 z6rTE5_1ahvg)x(mz|5py)6_l5h9VYHKcQ(Y@9m2}oRM90Zkl%KA#W!t1;)vaR=ujz7w2L~RO^El#qx3dcv7pf7%l(u<;!;mR?cEo`!#feF%?0T zRJyV30(*-a8g_GBdzsrk^_Pz2DdLwGV^(-#$$o6nW+U$&stF*vpl%^IK+bLE~m)c~{WjseWp)zr% z`sUZ28yX~D^c125@1I~@y2`bCG!==|rpbJlva-_PcC#cEoV5}uUMh%M9lQhy39o)c z1#AVXSBaAln<8Y%bv(LpJ%h;Fkko#~k1^Cv01>oZ^ZUPR$G(W6t*jHg_p3UarOYK% zZAz09+Ez2R(&j%8{y$Gg?$GM>|M|zZj-}Jv4gIbi8#PmWZ%wEFdTJjfCCem+eqR+V zC4RK1L3OaieMzk8V!SR*hkM5Oz#7ZzA#r&W@L?O-tCPe~sW({xMX;&ZwtQQGCvZHA zBAmac#5O$+c+&NDdp3?l?~Dh0{PoG(C$@|C=5^m$G)lKr&itv}+yaa#s~8Vevr#NmGqMXOce~F#|MQg@{OkDq z*LU~${^vvf^WBvrJOag2{_odP_y2l;|NOv&=F|WF9lriQ^X;w%Qyo@NR8*{9B3jGl z|9jET)5{(I{d!g*Ei*wx>pbBP&!_IU`u7L(Wdi^2dK&-U*#FyF?fvPMqV9F%s`by` zThFZWHrLnhvqXCNthojgn>scrFUXJA-LWyD8wI=@l?m8Z_ zimy5yPI{4^>V98&<3xlq*D78++kah$8FfWbMYRui#eP3~B=yU-CpRwN_=9lZ|M5CA z!d)X+|Fl{cWcy@gFO<9^`BoS|UbQ)y`FmI5;*^?`jnsX7e9ni5-><4QTe|c%f`(PK zJ$o3>W1O29PSdyrvce08TU(w-@D5wX-V`TWv5Rst+J^@w2~h6t2sqs zB|;K&senw21~dOlH2gL?37PSKUX80T0%%sNG$&KIs6tLAr)Y$1&sA16zy10R>uaJ_ zFEra!(U(|&t22D>fST^fXu+a78NEQrcK&)986!J(disvZuQ%VwDcX%9=+V@XJ-ihn zB6@MelRwi9$6!pUD$N?SF9C!TNinfs#d zWS{Si)eV2EFsA7isdYaTHG37c{P^*ry4uw3k^AAUUV;Bi?$CW*trsrjKv{po0Xynn z%|+ET;;aqM{rG%2os2<)2M2vUQ)-`i{9M`5_ssDjd<<5$C|wuL3{}ylw)vd{>G7s5 zTk>QV|M#WUH|}yt?TvE)OWtzR|HNKofLm&FeTRe^y%=; zu^_U4m6VuOFLieQ(raeJ6P17WXpbJoOGbk(tm*rQL!TGYFg{p z9m825lB{`0blnf@&T+9pYh1jP!(BaP)cK>$Tec_>iuuMtapO&V9@JCB3E@r^GpI#& zg(V^8RG}EpS{XkGoZmab$OGj9F`<=-mtx#=e)+5Nh0@&A!!!rb`#dw^p^9=#zHOC77uZ~g&}++9>zHvgd1CPrVe z?hF^-k2tFFD2%G*i-1|b-lZE#u$!1lK>|}jGRR?N`daU^fq^7ad&zKa7Z#dqh^4@J zJ}K#I&J*lh!8Uv26=gw)=H}owmTC2>QofFenuXhSCVJx7?zS4h$*Id1Nbpl+J*KgiTPj+UE zMo4e<6^tMx`p^A-i&2`)++v8I6k|7DynTBR<6XTg@RG69#8dy-x%1DUIlDY*Q#($i z7;`GlT2PSJj~_eluTf^$JPy8#Fl?yK0&xiSO_85&j2?P2Q;bKW6=7O(A~k1*iTAH9 zlATx&4l4v_IVcQD{8CqLpFb#HC6mF}80zCat__n_{-X~whYD7bS$6Ky^$yTlKox{AUk?7*(elBXX_U zT^7t!nWCn4hw;=DAm;*}=I~xLv)pB@Y+Wtqtza@(Mh2Z~o9KqroHuXII6s@|hnU#Q zb{CT_g=8-aIa_#uOf^hn^c^98Lf+FXm+6cT?0JygHIXIHRK5QGdB*U`x zbZ-7Mzp>e^&HPCKXthnrrPjc}&18#869s88t&-Z`MdgctR0ArfqEcGo;GCx72}|w; zrS{8?rKj&bPWM5A7ExBD3x|j)wUYtd%N&4t z6raeS%-qDdTmpr+IVx-bHw}YP?7T0+E?QhuyTO;ejujsLi@_!0-AJ432`EngDE~(N zMc>5a%ijqk*5|KW`I8~0TA2)dYY#@InDMRb@L_AkhHo`eJ$h6ZupO1Qepnu6sgj^@ zh>1%^R#u~+rg(8-vXdqm)ww;spy?4<;FQ*}^}{>egJJp#LS>qYr;a5G$;>R6INK!M zsUp4Gn59=FunSMY26MA^*`l@=oqOO}O_QOBvgFn{*L&@A+Oy-o^Nk!7_ zPPj@x|FVG&u7TWA+cDE;;WRNFxk}=n2DZ+x{_Y|)p2Z+3rVHu$AGw&Ne`C1_{d_A3 zkTw81rsiF^#F&=_@4YvrZ%?EuH&o1frNF_M89^Hp!)+gw{N!fu3K3omkdL2V?X>Ry zhFv}kx}8HeZ*J^fLZ*bk<>N9yF<-uae@s-UtIQ7$k!jg zN2q9tm8)4_#$<@;w?>`;3*xR}zC^1*{BmYUW#;C0`Ll<2>j*iymF!x@jN;z=SiF#Z z``SP9SCN9Rx@Y;gzNt^LvbrI(qlp^qi@G{59`RkKC-rpyt+=?D93`C^FokPF`M&TSED|GmmoHwr#$PZuBs=uziI5yIEvbnuoo)7RUMarQiN zWIwwP>3?M8HzBtnW_VH#ooM#Q<{N_$P*h>$=@pn5JfIAPTI_+z`%%H^iq4N2K_CJ= z&+Y>zb?(!rk5;DJ;rmaQGThFn1FeXW@4wy_*-xY`TK4sH9$6(63r-L!-z%xgDysJu z6Fw2fM-wMqK4|AY*e{l(eMwZ&(!VjeNsK^GP^yvCXotx;$R|}mVY0T={kCx}(;V^i zQhn90vw5+rhr-srdA|qejC8St7&r_lntnq}N&+-p1#oHhW+xGX+{lrUgN7)t=`7dN z-DFG8wkW9j2s1Ko-MYp6*3VTXjNhIB;>FN~B^&2{bXz<3o zGIv`%BSOppmBneUGtarUx(FYxGc!XN`aTfA?gHQqJDGW)uDUEmU9`@ih_Cy;i~E~5 zZ^Y~#Dw-ZRna)|Zt5zA<44gb)u4mHCn_?CNWQV^oRt~>r{YaJyRpiz50#h^~W*$sX zS^wg%4oo=c&qwi#zZ2P=cYureWhkSB(hxrStLbE3UY?lRouI@j{UIf__|PgIq6sXK8s zFk|>9w$S=Hgc?|pitQ@Y5_6umJX@+~ke{O3eh)>0kk_*{*GP_W9w{3+TW7%n z6`5mBw`>n58T}k}iCBs9y=Q1<1)x4)b57L+3SU%3^&(n0^xc zl^}T~i00Gl_Cu)oM(6npmkG8Y2FNEk_|J_H3~0YzZNmxYKn>5r}im)@IVqLZM5v@AetDQJt+%DOwT=4YTs=%?wWV) z>5-M0IeYu|IO$D&W7@EBVU*H0r3t4eQ9r2abq4+A43Iq78v27BAQ8YS7i0RaKo z%iA(|&2&e!R)do$lyzcot417r1PvIQD(RHiEU|(b!b}wW2192KHChMK<>0T7SNjKOHs2yPy}C$|Cu`-k?zTMvZbA`S>Yhhx5uwOZ^jxzu0~k#96t+S3tr9~8X~Ntd z+n>n!J;ao(JU*?JnMX&!OnfW+Q#Z}iG%6KvNjPc2{o4@z-731ZoYzAhM281EGsb9_4_*6!bYYZDD|HFNP`~9iXD4@P+2e zKIi<{H)=P@6#+8F{+Jtby?pmTpceroHJ-XlIOrt`KZDkw~uUlwJT7m0!2fPwdI82*L~LGHQI>me7nClbxh zv9pWSk4`inEMjAV_~4~Edduz7+pNBUQ>b1Kuf!l5f%pd}p*%Ef^ymdqH`*dD#yE>f zEMhX2N09PbNg+*beDe1mgbZRd+SJFZD_z}Z$QM6-dY%R`v#e5@v4W>KUFuSU*ph@0 zD6Ns1(^oyp=7goa{rKgB7HS%qw=HTX0o%|g7hW)({HoJ~gUq6Iq&}&H`3fF>FVt|9 zpksGQGs>S!be==kolX(Nyp`+MU;32l<0G9AuzH%uqfMVaeVTxLWK7*eq^54Fmw(}2N44Ykq!$zz*N!*Z;o{OIW2C&ijfyH`v7HQo z>nLT4A$S_~h~iFzBz-CV<>48(8gmzjQLACKa7p0rADUel?FpqqR=Nlk?}WO!KEx`0 zJsv)MIERyNvc={q^0kaz-{9mmEZ_pUA_JB*-fRO7GQfoFk6EO-=8v|(APc}($bM?f zHTQb>`0={FoiX7CSFc@Dg6zlN1%CyPU?|?x+otV^>7pF<78m`l*D?CtZks9anRBcA zXFfj?NY9DD7WTg(X$V$&i)pv^`T>h;AF<+$ykBN8S1RX;g?KK>*Fvw#DjE~sCin-4 zEqn%3SkiJgL!?tEn(qN>`Rh~M)vI@a0!N##UsCDH+q{Glq!E{VtBCQ9*@}I4Y3u0R zfx#A^`s8?>z{5d0qc@+eJ4SxG5;WUDr3oxVXe6qY+C*L$#Bc0ctRTG&Syi_c&7_3g z*kdZ2Ok8Uz?uXd`pjp=$6txuQ=Jv$mM{Tw+GS$#%pW51N6Zk=8O-3hU;SfH()(&23 zp1c=Kk!Z5ZJvB8o>!VYdocQ3u10&BZ|K#nDL@9gIytU>tio!>mxtL*%3mX7zjeke( za5^gN!=!EO?Y+pc?v$04@mxU_!yS*^d;Ama!3olmZrJV}`6Z4veQXa5UyYwv%i8+= zi!DXe>akuKKxMyu)YO-a9Ne~k$9hXkuOzNgE4TM(yURIC({@iMu^%s8%e+ns{5cy2 zuR$BapxQsVEtSAiM-IYt-ZU`}n*SUNmfxo$^5n^r!}Cj*1`*(ny!A5jp;QwgMM!9< zrh>`qr&@Xeal%oE*A}^w0#jMC&HGqt~N-S!;j zmz(7?Mgb_bK}Ic*y82jn9X{MxL9fbPembNjkH*Z!m+(+)oHb|efB|y5wA8%zn-~T+ zsJ6N~;`0xF_EtAo|HU!YvOhn(PzN`@1uwGGwr-r3$pGXJTAN$B{f|ml` zQ(#Co+|pwXW8->6$HWY2*=T8bcf90JSMr4=&ywCaBu|qk9DfOJuMx)~%zGGggr#6$ zSZwZR0Y>gT^l{RhTk=yJtN!GaJg6twOS!DAxyJM87CT_nYI7Q5axEXu{Q9DUu$S8^ zj;T(1yz1S)w!^Q!@WS=U!f>+^{2??EVR9aB;zoLoE86tM$>EJGV<+o>xDm_J*UrFU96(dGnz- zTgi+x3L6^okE&0)xpk!e;ttnlI8%&?CQqFjlhFR}00)L_uBPH-_Jl%sc}RS`kpheZ zZ~ht2T((C+Gt+XN+%@+>!jjxNMD_0-?Vak5NeGv?eZ|Sc_XpM6uav!DtuSrcAkILH zV=18-u-Jakf@#yId;9xGI-g@u@S64O?-mpk=tpqT{qtd%73>aZaR7SgN5?y1T~=9{ zg;nt8xpO;;>=cI97`h5AL}dU+OnaogeGpd`45!9Nojv3dn&WSBr{;LUatgXJh;g$K&dm>~oBu;0(GFNk=; z17aRt_w_Z8Y&(1QY&O*Ln7069)C&{!qm2}MTW`_;zzRgHx(`JHjW_@=rlsxW!auO? zf1{GH;Y32#nW`;ffCMORbis2|Q&YvJz@S1Lm{19b`sRCv;f4VJqs@n)rwA)mbe&391H za0h#N>9izz>0XL+5?{X`X$M*iis=8b<=5UkIJCJqij!~VI!p+#W(ZrjddSbC0tV~_ zvi>KjN-xi#Nmf0j)UxiA3W$hNI?8sU=hdsq90!Z{qt1AdeDr5E;>{|oh@+kY^tM&m zmZt8UDsi5ACXtIMc${~RC4>p4RKpx@wKrE?8rI2;81WJiLCj){wMdDaYymuqev2D3 zd!Lt0Dckf6JDg%@wNgk45?BL7cf6wcR!7I37iBg>2o9IviE zko_ZebJ><%!bBq$tSB9orS%EaS;^s*=1aXAG>KP(rTfONU%S@IxavndLA;OwuKfNh z&TAC6#w2DeW8&fUT+&xY1sqrnOzw$Vxkl7Ce(h1IP}z<@5~~bgqZqK)+luHe;Gu8GgCKv^2)w4AXH9B6RJCz3M0dTtWASNkuD3Xxmgi z`l13ipCREg`Cdx_NtB{!wj5#}FvUTZ_&!`~bz6*df~j+q9*wG>V&*75*H^7)^9D~n z=JH$(m!3+L^OAP0-fbIzUJuLLM;G2y_Y|ykO=;$0)LfR3qU6t2_4^Ib@!yCwDc;{w3&buw$tMXOc#Ws+;6*I z)X8oE2s_*sxhy$DW*uWAA`&rl`@RhcuFE9(9WGZPgprQ zIWME$Iy40<=R5F;P(a8X_u@i8TKZSnGgVrfIksd{m##Z8KEo68_Iau%))hz={Wbhk zn;!``8;$J-)N)%rcI;RsII`?`*cU??jL7i^LZ8DZ>%tYA6G3limlJ5(-Sl)VzeaI% zX{uRvz!G3olw3Qkn^U4gRg`Sq4Vsjt5_31in~n1uK5=40&dg*nt5EP}{$M z|ND<0AB>8MD%(uWS~>LAt!crQcUncUnoB9jq5s@R#!YqQZHZe<&e5q;(#q)61B9KF zSE*`%*A@o{@S+K=wELgOC~GbI$FceVZSEKE^^~~*DLec45DES9NBX6p-9qQk*hOGk zTZ}+6a5Z{UcGJZn$_*>ELdWK;8zdnnq63BAcwn>L#pNXfrtk33qOOnM9~K(Al1?n& z*acyR?50jV-9vNxEh?glH&zrrj5~1t(>b%F&6_@bx+5A}%POJ2#=)S&NoEU<@|*c= z=P?Li&YL^eZPQMw2VGBhGggJzv=s5cuv+ZrQu?wVd&>szsNHMYf7Z+jf_{W9Ev=S7ss|guCeVo+?S@_d>~E#Hwg^S*Y*)J-MSK>DAuX?{c(WP+AO#I=$Ma9IhXNP zsw*cOmFfn~7_59II`^}Hb_n|yy9A}AOc#V_hgPy1h$6-#<|Z8dc1LuFz>?{A;>3xg zi6h&rsQ1(fh3m%;95_&5Crmf%BPD~vy9=(*b3w7?Zi?au@{jkei`9^ z_Atm@^U@9EMHEmE=?~IAM5HS@TXW)QQLb;dv5yF z-jW>=HtDo}{d(QAQw%c7_xyTPKiqeJp{MvBKQr>=hr`KsP)ePRL}dCm7%&r{)IgA2 zNfX|kZ)7;eeF2-2w$`27bnEiv%k8e^T*etd$h)0y)wxG$rHNMK)KNSa#`qXbwP1rm zB@)G70@DGNCm+5!J!6;J>&;We=;xO^gHBOFr2MTg_&YPS1kQGB%jxFn?I`Pt62I(3 z>Xu=RNt=HE{dX~ecEca)>gr<3Ef&nWA~!G3J20@@k|j$5+wA!%jIvel6@ioIWNs3J zl@X5#v*eJoclKsagBjk_V`6^aOik4j0P0`XlWYSxHCdu;UltL7ubm4 z(GmLNFrHShu-8IgMwrMs8qP=gB6%GI@pG*Y?3K z(7N4eKWVHhU!l-yg}{N<_U+rlYJt3Nqi*2#7A#d(HgA=%C;d-6eDjq)Q^=2s#{?EA zPQ|$eub&-~I8Kp|Uh6JThym6}()?<6??|6ZhFOVxhHu92UJUB1tL=Zm67X}Fh-y}= z*viecr5ACE%O^P@T&^^kvqk8Wfy<6J<6J$;qhTe8c~XQ1{u2h`)$$m;GS}KVQa`$C z{dm%7nujmIMi62Y+0*HiA4ie=5XoSddu>=AS)LEk4ac% zV-sb3PE#{mfD_Q&KXp+{mP7LGvLX-jL8q6tnZbNF~B>YI7dVXE~B?q6VeX1UehX zEa5fl)>$7-+4X(0n1w6S!zdN11_owkk&f?3Bnqs@_23l19@B(?P*zEeE++KiLJsda zpI?Wk-!5~R5K{O8D0O>H>hX?zROAs@nHO%|dR|pqbMxjkVqpOV;FPfx?s3N6W(UlFXE9raDLWEE5wH{Ti1RaUk|ZO)U_aQSh4GmA|i^V_2IMTq) zk&1;Y|Ljri7bfW37hK^bD7I)t@vMOaoz69ZWe!LyK7H!cK_IPsrJ|xjOtlZ5va5%<1GG%4Zb76tvT37*#kt3| z|2*sAlpq8@uP`RyaV>9Ulr%+t;>5*mpYiAzefK8yc*gJ*X{md!4zne7q=&Tw+%C14 zRkpU>xbM)57bWjizJ7h4mOE=pD)mMmDg_F0>Qu6C7Hk<&r5I4r`0X59T4aW6CesCZ z6t;^|vK@pF+99#{x|KI>I#%za0AY3X=H@q&w1P)8sIhMk7Q(QId1ZP$q72|I4oRcM zmxiq&X}#knNONBq&Ee~bEZ7h*{KBbTG!b>~tzLW0az7^`C};?pl%Wl3ps{!2)~;Qv zcA1so)s_OHvQ@pjUk3@e&v)XmpKU)}V;66$YZ&;Y30~U~jKxEoUVqYQn1pT>4WB-3 zqE9jSY2>5|iSx!Crg&F|ViFZwwxIg13PaP%_gMw?^;MKI2t*PbaAU_a*U*KK<1$uO zwz+Sc1)D{hI0HO3%o4p-EnmsNcIV06+CZ8Cf5EN*F}z2Rec)By^Ba~gjz+5f%98_# zvr43k9>W^q%6((6Tjh146JfLOC&6?S*g4gn$>W9(??6Ulwsp9xnW^cSk0ZYUg-6Ak zXlWhXM4spr6)$2&VYOhNO$JCBDZQ-Vi3LaWPMq=7dT&C>=i7Gdm?P^wD_Fm?Sk``1 zhPUA7{wv{)@d)%9H1$GUc+(JJMODU1luVbmYHnW)3K9o^6=q3(o8=Bf`nVJW;^XklDB zf8nA50ofYQ@9SRvw~`F@n-zxcCM;td;QgB1KYf|JGHrp zA8%hES8EU!a{uSB=JPR6DhhduUBk#S?!;Y~7ghaC-HRpVX`seHAeNxJqK{SoImJT|Q%8$=i~U|_m@Zou!jV3B@NZaY}M2vkr#*Q9+7p7jfvpMAi zyk#zxgQJ(<_>_|X%~a6mw@p0m=y&#P^rv`1iHC?-Ydh9*m-iTq>(}t~$ka{cw_GMt z1JmD#@PQc6*mf?&x{gD@)@QPu!NthA)1W|I+w^WvyAJSN0^5FLjU6LT{~){#eoC@7 zX!`d2+x=UZd@}$#>D=wx)o6~PIp6@8%KxJ7&BM8F+qPd_bv4hWsL-g=oKz?knlzh= zqC(14NrXaOP38tvDveYKQ7F-j22o0gD3Z(xDUtO#UC;fjcfH&5{=3$;e%p3;b%&qd z_xnB1<2;Ui-;ezO3mY_{lii5--t1@HU&>}S{eAI zVWUTvdQ?D&Xb|vL?U=mqQpRG;Pv)|D?Tt&3pUvF;U_=33T0S@+;S-pOARZS@O&IO9x;(>qMwZ67#uFP z4i5Ut4tSyje2B&b<_Oe4nbJT4(Z$=OI^;DMXsl3)=r>@1eE)15a?rAu-8w`PjSKL0 z^z#=TCrz5P6f@cdr*{zg8pN)(~PEM2HpR-9%& z^duMt%U6D1gez1cCWHZ=r@1%Y96wgwu_!G+i;s2v3Ge@?HrDSEoPL#IsXBW!3Ynz)&L#NoH}BP7il>>4wth;V;u zibjo9oMo2QmN(NUPF!Bob~De=7$MZVq*q;&GXfPmQ5fxo%rLH->G*A^v0ICM$ihps z`}W0I->o6x+`VOiQ$ybGD#PFH0buLZ9nGj54rva5T_K{-!Mk@=DI}Fr&dqU?CFywU z?BnU#^^-{{5J5$=Dx&s*;J#kln`!MocqDKn9S_`xnokGO0@mN$vYLf?>GW_##eTS; z^!BW3%aWRr(kk5Tu5c6m;+nNYLFVY84k1Z zdTgwcV5dV&c|Mxt%Fmk_KOb)8sQ{9#b!~(_!fbJa>)RTI7Rd8x0Qb;GjFoX?6!NF) z>c~d>q@<*SMY=8$^bAg&A>?zc`GZ~6y$~ld{usUY(w{Y&_FcPn{f3XyM}|-TBY>iU z_-*U0u^g&76Q!FA=&>Cyi4ap!ZclX2K7sCLWpo{<^G*eaKZINV0k6m5ueCa&)a~cD zVYr}Bz0T=es0cl@zvVZ#_szE-Ki<8#Ne;5gKGv<(^nQ#rGno$&1<4O@Iyk zccbPI!pYlfBd4O`0J)ftYVQUEFGD6NU?Am*VN-iDPm)d(j_Cbu6uVp9%NZLv0*}U4 zWQC9JKRV`hXd3(FplMQK;sMi=9m=p6hrtCc^}gman_Se>Gtaf{`fDmVaW-eRUynmE zv(sx&^_|iKcZWMWI&J_;a-MUM?;)*ob}r(X_VvYjMp9Iuyh7}$5iu?GhWsBpVP{-? zd}_`8R$p)hJ&U07$XmDii}PotH8*Y5eOH%=FDtaP)Pb3ADQu8a**S)Lr&9*3t&FAg zg_!LH+1UemSc&?-Ja*XIOM1R(0O{#7yG^dLfcZ@G#ceNfqy!biX^zGsFg;KNBQnuA zam&#me8~EjL=a%_Lq0w}EB8=nAL=gs59m`dVe?T1B;`gp8~T=#!Pk(8hfUiG{G4-h zbFXeeK0`;1`dgSdyg$XB9&GJ$%y-0t!i6o*%%(5o_Q;)~4*9WF@_wss- z#^E(07NH*AG{brm)id2UJv+}?(WbwnHl$T}rbFea?TKFtxLtjwZ_Rh-5ts@lHR!UD zkx`!k1I&;5$&nSLy{3aDR^{aeIIgc*;RK7@7a9x!(nL$($I{ZnoDO4WAG`xL{1Z;u zUL`c@iZX3I0t*y%I>qJzlj5SHZ+!bB;93s*zUf!-BN39~VFCIR?bgz@LSCR-pvr`V zkjAq8v9Yl$^GUf6-fv!@sTpP<^65+OheOlAp6L9(P*YZBaihWS`0)d#yQ9Dl8VF^gZNa!eDDT;#~0x42Rb5 zI&gUV2Lfaj!WGcm6?=YUWMrh$Rn=Qs+5;5&6XArn%ThQACTD`i?;BRG#!yY`sNe}%nsNVL!gbHRwlj-XJJX{;$+tsHgua3>!B)rQ4A zg?~{WV690*&U1w*Pi0Ah7m7TNBN~a|=PzG|j~?BLZU~x+9?`wg8#B=#MqX3#vi?gj zth;yr)-t`{`bU+`cO47demCsV*pBPJT3h=cuE#c%$~U#EG|4)Jsy;*Q_tL7+!xDnkB4+hM zaIdE}Pc&|d%Z7auttQPh|2%4Z7rwd|8IdxeB!QxF>Ln$O*aYBnOJko#x_qNfyFq4^;+wGM5Ks5r)$skA!iC0E$oe&n^0~$Ak&?wpv zBDw{kmQoJ%_ZMS?1f~$an}8|y``fo~6$Yxbe?TW(pML%JGn7;zL*%(zcy2UWq`+f| zE`j|KI>-yH-{0y2uV#8tF{?8=`-c=>M`4`^q&Huo^yg*i&l`=vIevWCAht9`fc{5- zZlRMk_=>XO!2P>FxAHEWt%}c|J%;2GNgFM0I^o@d+I?iLo3HvhK0s>6RhImxdp~#uNx~YmMA(OV zQ>}00mqK^Fevxl%Y)tonf`DTO?AMOnv~i!=32ak2Dz;HLq&;EBL^~zqfhqJt&OH}ZOxzCQr5tg;c0I5@ZAIM?_Y4;^97vZ3 zk1^HW%m$=8KRk1YcIc18gI!4cFztxSOnBa`v+nwl!X^D^mjp3}MwF|rt|l*22I1QuYc7jZEs-q7q}L{$HFyhqy%{jes*p4%T=VVWUS%`u#((MugfLokt|3%#TxCt zudO4t?Wh0y53?WC{>T0NY^gyZ*++18gKoo>{evw6r|l=6N%=Hx&Z3MqiOzrj;Q##> z^TloHZreWi-~R9y*w+8_3LGWilmEQ3)S*uol8wu?*wLuhrV#$m`|EnNLEJ2ve_zFS zo$3e%I#mD&C8oleh?W%QEtv}~Xn?J9^v;qB>RQlxoW72-x z?n=+31)8SKVOuUeG`{3fkYe0#*~sW*P4A>s&mQ9j^nd94_h9#WSJz(6kJtChQNAfz zQsX2xs6hWY ze0brz5vWYwj{RSjUyIlOzAv%kAP@yM*jZa&?}7!wf4`o~ZvKB|ftvltCP}RWx%fje zdx3V%a{MNNoE8D;+#NfnroDVA%u{pd?%Y|Hx!u8G0d}V=@vlXL!P)rVpZVXvT4yBm7Hd7@pUx--z#I7s9ww9FC4N zNts0JtIIXxM8|zl<FRiMKzf`GRMZXYsYlcs-smti+(W`jbM zB2ChOI9gr{%zgu^gO1~Y3|El^r|bnnDR7n*(r%7W5{JNft*n`)X*6P!hGwvwu#wm+)%)BLcPwCRp{rhcjh+$$*!(bqD z@XE)Fr2D)+Tn?Q6t-gMmg|O0orjlaI4*{KRuy}TS=%qyFt~rrT>`$4lXrA@}ot(JA z$7H`fJ@j-D1#53ILID(a=wQ}zkMdgzjVos%vI46|-HFcM`I_oR4dx>%zvEQ<@GCccZ)oEB~tuo$)LD=#1ZvxnC1P80t(q{L zJ@mgeuMjXHs~Y9!j?4A%W^J#UCA9i{dVKgbK$$Ln!Qt@R{ASZ{MakX98^mv@C4H${ab23l3fNA7ZYu%&g~06sfDd`w-Pgat-LaMzGbs0>V!7P- zZ-S#KaW&8XKUx6IZ6EHt$ssa0aP{iwZf%!*`8fUhWTy|oHGw`R+Y9Da(-2AAs)=^! zA4M9+l9*wyz(32yQkx30n2$W!LY53$9R`YG$s}`@xC~z36R4LiQ`|W=)rW@ah~<+R z&Mif+wt-sy-@0+pIF8Ea2f;*MU0F65e&vjcB=!A~OyS&s2!noAovVLIM7H8 z*nxlFNG!!S;F6+qUu?g(p<+9h>1k{CrRT#I5=-YQCHKF2lO3Tv-Us+|psjEFtzR$4 zfLRLpAk7?CmY2*%aQy&cdRY)s5b;`d4kLa~it+DP%=lGJUol3w436X2P#*B_2hdIy zo)D|jh0+WW$p##?|Ld)VF9+r;mPTLKG18Z&=u5c6kI@e9ObqBPCwGLNehd?eNF;Bg z0!K9*>p2IM_tw@AgM{a@MV&@99p{cQETe+^d z#q|o1ul?kgK0pt#ET}7={Ma6?JNga+#Qi3#C$HSAwM*-2KQaTln4%*PWp?yNl|NlMCa zx?rf}(IdZ6NGJGqBq9~m)1#p&*Iir)m5MCdFF#`wyf*+q{>qM8p9&Il1MjcIJ-@V_ zf3v_5uKC{-$!pmb25jP7;kR`gl>l!gtZx&#Jg4^&3D|b(CyGc>N8nQ41Lt?`yz{f1 z`XZ_v2nRgz1FWlfy-pQiF1`6@jZK(qJQSrJqIl7=4h%3!Z~iNInMcU2X5JOqU)g-o z){mq&-ci{9{F@5zqgAPh6`Yx8Ze}L*!J<98;JerYQ%7=L(rf5DUv*LaZ{2Gv+a3b- zjZgxLZLi)rRLZTnw$p&|Yb5OK?AmW$qM4%4N)nn8)4>$|1n2r%mEAwa^BLqn+{2O` zvzKmSEQHQ|M$tS#)P@0bCt)wdy7zy-73H(YN1OIq=iHbVz2%VNRq^zs&r{Vf^2P(z z?In7tPSG25?`Kgr;lQGj0)U)uviZ^x=9g#nk4jR;@?k7uvIb-zhgocwDGu?Po5n{F zj|}yQmQ4+{7RSY-%;O;%nsR4AafXf^yV&#$0PSG$vn?O{bUSuTa?DBzC~9>^-%SSh z55;t6xac7wx)eUh?Q<*|5WNR@m7@J1O*y4u9pG*eA`QzDQ6G>&LJWpVDp`4Nj}vX> zorPHqra+4Eo@yfaJPcU9Bz_Jy+wDY_)#O-(!Q{Y6#{QOTdB#{9u8caE9>_s?5LLiKWKm%) zkPu;w0xcJ>?qzMrv|D5ioVYTxm*}m=!m7-(NX6GI}!F@Wvhn0(jAQy(soWoXuc@O^*}U4T;w( z;{YZc>4itw?h)*R+DOjbqg!`ZyqnP8ye)FhIe~s{%p{?wCw7QxG&yeSk^;8b>))WR z{{TCN#?hue4-%>S**@Q1&M4eK_YC;RtDK2!Yj*B@QmOj)Zo~~jwJ=Vb<@CJ+avEU! z{kY{35LaSbE%pG~8LZR-jDO2lp{N&DGXhEj&^iQE)_K?lBu1?Bi$v0EjBR<>dJ_Tj zI;=gc?k`%lOz`DHFt4I5SlA7+)$f>}XQJHS82j6DL=tWhr}UXVeL9foTqdNgWav8c zoeJJs3^H1}v^(Dh2XN@*r1c^k;fHk?o{m82h3@DAj@OeF72hbh>gx6vITC3_5n%Wt zEvm0ny@rHQWM?Fi*F4>51RtWdM~YpO_CA}>7Sr{bxVh8;WT^dNj5G+N zJ?fVho;-GJA1Qnh=k;wGvxQ!q-iZ;|OAvoyY4?hiWn&}cTjDmN;$+DT5c_}uGE$>f zXgHFX`^wGnr2*g*IXw!1KFCaO7PK<$n_fAccXs%yE`$QId$~B9bW%8%pkwTd+xk3O z=P_sZ#iRZbm4F~Z@cOqYpEmZSpm;*wGpVmVXV&!X)3*hT)#*sVlLFo>nFpG^a#C_S z`|c+Qoe-3sMZv&U@XTY9b?>cPw|=Ev2{PD=w=O?>_}-HzPsZGH{qp6D5Kg7Lz(IW? zjtU1}&e}t;9pkD8a@CsP@xmsHhh=H>6L)*3lAI;+6e4kyBn|*j zwiNh;i^c;n$VRU?Q@t~7MbgB|WI$z+&nJ=;R;y!ClU1`h_ax)QoMifByX5_S999p{u<__W`KjNXMVM##T-R%!y6L&7^b95aAyHS>=2U$co(fR3krnrovAeha zoKLuYdjo1#CRY0B=M8>A3>1`%m9-t`8v{T3sCULKVvrD9(vqatxM}NFUm|QV#>|dU zjw(7PCc{Oa1v-PIT~pM3$e};NOCLh-t)I7fG-n)#Rx)d(_VcrTNl#EJPK62FC(6y) zIaTH5o_yKDx(<`U3rJKAlRNEYofbotCv(t+bh~>K%Ef~xx`CGG8W}4KjjG?j$mT&) zjutH&*_!VwGQpnUWhsC)l6e#Gl7vuKR9`Tc+&UWYDA@sR#SUS+9DmEBjR<5S zo6f4mM^1RWhSoiVhRq#0{~}L6SqyQBeo1s+c}cU2#J2Nq_?m}s+a?@jw%{^o%3={v z2#JE7H7ZSSp;BOwRaoTTXrV7s)Ch+)LAfv5HhSFb*)nK|lKJqqbgdd;{WWVq)HcLXjjuaduNzGwWVDqw_U9h-S3y1xpWlfEd{F4b^ z!gbTu*G5p`h~Tap#&+eF=}F#aNa;_cQu8O6%ChjXdinX z%1L9`*hfZYf-o`?*>&hWe_1ZBKg+eBsO(!<|3(Vc?DbpDuHu&HHk1CjlmBMb_UqgC z0JNgH&TW00FFdSU&Uqagt?O~`LkJGziY98mMR8M@|?n%dMmwF%;VLr3@1o>+M<1ll1k+l zQKEZ-#dT3sR19iePg*F%A~hwo;McE1mJCd)>c@lQ{F|R;r9y9Y^6kJW#JHhaC&q`K zT=z09O)Sd!CMK7{_V`#3?mRPu^FPVXK;i4{>|{j5A5Jd*(^wYJp}&a1OT-=T%sG>k zsCK{Q$C)wu*H0DvD0EY!tQyG$M(PPK0UhsFF^@)4m2}lxcDCWEGiNFqei_oxkF&23 z97KHYtN!7)6N}^#4eBa$V$V+?RrCK%NwB1RCPC)!1urbXt+bVyGOk;pR$X|1l2BE~ z7vJIIdGQIo$oE435O1TT-e8jndvi>FVIsx8=MZ?7k_zVk4YUPD_Fq;~Ha{RRn` zWF3d}x!&sTq)z3tmKza%u!g&5ZHL%v?==R?Ew2@K1nE+68wKobVy`xRTq=YUE<+7OJ{T}Fy2p7Y3cox zadw%=NkS3LS7(*~QnJojsc?@qu3wxEyrz9V3R7?(uVJ{j*Y883ZofM7jb_rvk0!` zq>|R(a;=`U{$Brrf`Uiz@Y^!(KzNJ6SmN@E5Fo6sEG1;CyT_L*0#+XV)BOb}&CNWUy*kI0)NympB!*9>l6De}JcAn$IUO=8@E{X|h9FgZ*& z4+wzH<*@2vG91qr3GKIS?*2A4=Ix>Wr~}AmHgG8_$q+>(QoTTt3eIfYMP|B%f!SfX zXYFE`xF8@kTLbG+%mI|bZ5aJSXqI=1mcW%s5J3D_Saz{BO`wxWG#Dd(1xY9$WsW3D zG@O#`XnlBgy%(H|Fu<(;A-%v!XmEnx_mES||GH!xrU*vG-Dd?-&Fa4V)?Nbr*+Ldd zG`*cc=631JL2S_c{6%YQ{ymqqQPEw%MJU{4ZkQx@Ez=t~4dK~5{%?$Vxf3c}ASfYN z3Nqcy=afc86dvzbgHNpZ3Bd1*Cmq#03!(#g1DHtGlBGJr!MB#_(qoV(1mbesSINpu3dy3kAJmtA9#@G8BF6wXDGDZ~8 zoKzLl&e%C{LBbahap5PbTW6n^6__NdBi`>LC%5gDD@la=+o|rzTx9f$QfK;gjF(XgXSL}3uN#}h;po`?{Do36XfO^Zy%8G<$Ub?Cf zu92;3MO?77W=E6bDT(ta)N1KeV4BJJZHq~V;f>b~fer-sL+%yTXs11V>a1K_DVrm9 zkX`NYIwVnEJE=JLL0*j_r$kxBQLXT?=9yP|4pUPzwjC9}lsewfNNx^OKi`&$X`LWZ z7gX*|5flG}uXVYVVXHrd5a*Lf%M8Zt?mTR=X-ap2(1HzsFP{?0|@B6%>`%ZNoS-s-O?`}Bx2NLs-q^e45_B1(; zl9w+6qt&+AUH^6CN9#W?^Ye$$^C9dy!hVXj_MQ2dq?LIEB?z35sR-_=H4FG()8M^*sXGLJq4-`m@XL zNHXrnIFRwHl@>LN_9s#bXPJ7T2gw|$;gfL{#}HJuQp!7m3ywr78d3hUE+0|LD8(s$ zHXDg!W2O!Huq~#HSuRg~Gibw{XK!t4nCLh~3r2Q?3=-PhvFr1Df@r`D~yHvGVYU4<@(Njg8WqdJxeEu=t)Xw=wx z^yh;-lbZq|hYvaVuxaS<$H#*@qc>cvm(Ter9h{$8nw?d-{E>no4BZ%?W2nn?28;qN z=(?ilD+|`(vW?j5sKiPTv2kMuKv#6uTVDyk=!~h)mLu=b@D0<@IEVFZJ{ol_C z)cAlDZm8|?Y~}A~?Ag~y`jUDsx<3wp8{7Y|_`2HK*LmaL#eg=wM)>NNe!>>tVdgcr zqw#vWmKz2CT-HsGsaAMZCiU1M+XnXcIwW;|1*Sk`S+#sr%QGi#A)}!8p$Fjf%jjiP z!K+~RlGnWN-!4?<{OaA-=jmi72^kxAGbTnI`d2r3dYm{PT4t^8yiY76!7UPe>j$Z} z)oo;i;RV&WU5Y*PI^AM#eE5o&XO-VqlY|+N8&rL2*5RW%M96m z$2LAK$f+(W8ZM(qV=PBk=I_q~l?^aqqlfNX@n%bA?X5%H`~UIdqKqe=jzp~&qn*7 zQnr&vscG#mDJNh^-eU0Q)5$_=O#y09W9#ws8)7jtbuy96UAJEJ@3z^H&my|7Y~Ex( zXNO&_?YSdR{Uh?_to=cjE2g3p)RPe{hs8{^_pG8Y_=ox_?0Jdlumh8S?#&itkYz`W zVI4a5lC}O;a{DIdfvhN&7u3vngReA1_@=j`N>E1%^QELzX2)K9)-AQ{@u<7ED8XNL zKkH#jK_H6e1^2^M#pb%}pcni6FPF2JMnuzYos_KjWZAUVdPtfZ5^tXNlv}v+=wJWQ z0;pZLywZ(>Ewn_T$0zj4p5t<)w9w5%Dh!|f;=(-=<{_dc`rX6s(3(ed7rQT#1Kt!# zM^1WLT6Yr26TYf69|`L4s*zuMmLcCU*Kf+v?U{ta+Z|V7}!n>y|qje{m~R4 zVJ2?lg(ih--Tyw3jGhc2zQwBaFY|VY%m(nrUU3=%2 z%F=n7O*m(dw)YtyQ1StR*|V zP+-+jKpjabBZe2FVmsDFZ11|Y{xJsmQq;9cLKR5a7s@e@sg0Y>G<#u$R)ZD znfYquuAP67zB;e}a`0e!uop2?D;6)(8QbMkgfCWpg8hq;E`@jOI#D1@U0mqogDQ%7 zuL8&2KV%n-&t=AHk69Zx9?V*A+kMv-=cLl$#TV5Nu4MAO|3srfM3U7>EikF{>-B7J zW^RZeYA!QfMGu+ONLRiy=7$DWr&)O~szg2}bjB`_bV!IqC zhn>-mSKwa3Ru&|*un9ucteno8u7Oiob8E}7oxg!np#`4qmA3eCzqKV}r0F|?cMAF( z?}AC+e%xI6Bj3ke&9yvgY2RbUqQWeeaj*GaU+=FlQy4k}9tC^awT4ozqHsIJkOwwq z?ax}%EX5DPLX7%VjKG0iZ!9PjcL-u3Ib;?Itip#v_SY@Lzu=TjwG|D_;sR=ioi|$p zzQqj|)RqV*u%u{biYe)x?~ZtlfQRZ>kTqoLj9Y6jsB$vRQ0ai?`2kzTTP?4Vg$L>lg!kszy61=d0YpbA|RZeY03 zTWTL&$`{~T!{SF}&d?a2abhhV!q+BFov`Xs2VpP@^zk4!R~pvSpsO*h*#J~0i&IV-wO8jWE26k7|TJ>q&K+`7lkmM&$VRz z(-X7r(O4-*!KhmT=e}8-F(jE9?3WGjj|dPM1X9Gd7E&KBXM)ku+WImb7FkDls(^Hj z)mQ;#-|pb#AjARa(JlkQ8Y7t)AyODi^`TX7wpeIEWVunLk58q#i>QCb^rsbn5-sfTvF*zOh6Bd$`EeRVyn?2Uuc+z5djhaT&*?NS#FKWIqwn(KYRm=l4^tt26A+#T+RPOvYfEagjzj z)oep2ARwut=O-;agW9)mAKN6WhT0kMb5z#!NkdBg#~buzoZ?(=JyKG4EKb;BR6#N- zM%hozP$Uqj9zTBEp2(m{k}4XO06fD^OVqGVwX5*roEGVD&QJ8{ll)|aX$Zk90DpD< zo&7M0*Hna;!~@tPAC*8IFa-|(>rJyaZryrLX`>h)EQEhTMgadGrf=Fw0-cSh7zv^} zu`+#J4s=*08^d;PH@EgR_6z&J*fv%h3G409b>k;YNT$*XE^?*4NPCK+9OA?QjKJ)V zjwv?JbjS|rgP zC@g&YXPo>uyF1?M|Uy+A6q=*lYtKPTo&9- zNYE;�FgozxUOq!qrY81uyaAt};pn91AAv>8jE8zKNPJ>%1Ms@9J48qDaSZ^;Nmu zEsQ>k$an`8Z#FX%e!WXlIjO^otfm04X5GPL-xH2-Ov(UtcOH&ABNerhC#<`DLwBiZ zASNPV*~y@l!cI%fJMbRdx9@P8P;R|nVHm$vn4L^ZIpX2b2WHIQ_ZOsUa<&OWB{QS; z;r9lR82_dF0&M8{dxu);wKTdxs+t`bm*&PVgOWb%sMi0;g9Ir-Y&FNECO~v3^ljH;` zD1@qV(G+=$WkxfJ?7;=6vuZAe>?Cc4fg~f~Lz5(#q!bzEdl(RT#DIzGPE|QNvRH6E z?M-IiSL8!Q;K(!cn$3KbpzP%-=Yr4GpxusD={bZ|r`rsOA8P9X;jERc7);3(<|;#& zeu7DfCXqO93eLG>R2(X=HPpso8zwb*!(89&wM3gIWC#cO7u{J47p7XL9&5D#V9Zqh z^5WzQ(LIcW7JeYux|jo5zzOaE6+RFHl?a$!2C zt)LcSTL3Xa&PBaX@Jd@K5}l%M!!$XDB#R}-uUW|nq?J1D0H`nbiHVHj9d;ykJbC)G zf*3#AG;>DXMsss{)U^sgO~*_N5G(uNTP`v801j671`t?J`fj$dnXt#jC7JS5DaG5% zs~_!hmq_|DqX~!Mr^0*hNUcCWw$l41UgvtP=ZSfm*tX9}ZC$ZvPofZT3LH1=Hap~` za{;}$N&3ID?jE5$_Y}KRm_`e~N)s3~F$k%Mps20f``9OJ$iMg38K`>r&>=wwTvL!` z3~I1$B`Q!aGTOzdHh_S-$M*GBM_Hz+S4)eOX$kJDa3xw9?!b zqUlBpi%_W*%W7|3yf$Iy1l?;#h5NK_CBMz8L1hK@8PV3~V-&3*utC8^VA%os?}_Oy+i z;E;ZD(s|J&NoQ2L?cM`prcTwibPHQ|Cni)UE0^yMDV%eD&vy?2h=u*+pwm;58d>)N zN3=Vs5Wf?Ryyd8}%Kwp*SQV;0O0d46Kj9k$=b~*3dm1LS!p1h3tM;WQ|1;}5Gj$+G zbePHDwNGX3x+|iGt?9b(UNKo};~7*{RV5-y@S5Odnkf&FAx@R&>71Xv zTi?;iDO4d0rBH-!B}%n$75PDm=1YC&N4DZ48*EuQ>FV$^04$4b+^}K^O?C^GlFXh#xbfOu+$?;ckMzck9(_wY}eNkG`XJj3455 z5@!XB_rwyzxMiht;@&Y)lQ%`%F&rL&d6bY^O`b91r0WU{OBO2pp&SyT#GUIH&sLTp z5&*p!D-EV`tgIgT)9Bo(i@YP$K3{_Q^ zxTQ|Q7n7@j3p85Mw}1bUv)esk_xCRrJ(Jqe=w~93?zoy_nvMZXTd7x(AOgO#jb$Zt6ddz_UsLgGg@)}CO2Je zO>e#M{fC-sN-A{wk-Hg*-O6>xzffujvLI!;dS{;(pDoHqlVqc|S+!1Ge(2D-mR;(5 z*_26<^9hj&X@Q^R9NY$A0;HKPaiY4IBO`@0IQYF7M3gL$LDFl5I2Q!=CTGBv7B}48 z__(9+eh;&`0%}qW^81LKreN)t8LBZaE>}DzL)}i^5E@6B{uq0SkdxnT-n{u3ETt=R zysxi`R*!apVG+I<6iSPeR|`~@3Q~mB2H7)e6GcH}xf*V-{bmzK%EAE4)z?-I6Z7`i zH@cFAR8cU`^Ps*twxV)e(H{zGFHG%gAsC@>L|7sh9x!H=fQE%{mF{l3QAEE8?HAfr zFOd>sFcP|{fl4pL#6PYFq*fauX8Nam@b@CoR{r^HFZM<3RDVT+hH8(ZVaDt#$aV^P;{G)UPetv!*D!o(80C0=$Oc zpBCrk?*2;Q-gL7$ZwUO`KR#$jn~8zm9tfw({5$=i7Xu9Cw~NaXoS2%wZkTQt$fTkj z`yp6^g)?Gs$|j}72y79_8sB&I;p{Jh$a8FIiVc$zpP%9|`2BO=;`X9rC&c%RBD#jb zVrDlgFlLsR{09I2$d_P2YIom4egtatH!N@AX^cSw z?&~&C{Ud<4r!7fLtjg7~sMv(^n+G=TT)1@Ix~<$rl;8es1%q||?LO!#SoA?bwzxOl z+EyrJ%X-$WmAQ)Ii#%EF+AJ|55VM*tn3N}_%0FKsoWx;d=ySN)a~SRN@CCw7s;HsS zO^h?dMrCN00}?W5xByF4*Jek()cL5+L|^>g`0T9#^gW323OIGR)HH%3?IWKqn(~?Z z(tpkf9c}I2Qc|ai3zcL_+3#Y|lR!*U)YRS;4`E?91C&wLcVR zmVyC(z-`0s_h>@eAy7|}-Ec%cY6h!b0y_>)9?B>ma5i5!#+15SVhYV`X z)$CnudVIJ7F>bP(|7M5UX{69z#I?l@rdVjrRP=ct7=-cj*x7kK z+7|YYa7u#HDtbWztFhGAmM|fVs5HdVp%v;XG6r(3X(>f%>CcadGcdglTq(1biu<~X zE4lMq?;Efo`raYt@;Th^&~&9YtcZ{%gxYTtdb#)o<=!DyXe%CQ_B&lNuo1rzl|KZy zU=s$E&o@2GB3qoCi{P((NZ}|&yG7&`Atq%iH|scZLN*t?Zg6@bo0pf9Gp2g8WkCku ziMHayKF4fr&BckE8yIBw_%i2!c6j@)wGR5Lb8~wh?lJgEL&YwdE6cJ@#+*@_s;q2a zY5!DERbOK8D+)3$BqLf9-78RfPd7=0;D&~tF4~GW>jEp+TyEOBqUJ-M~Q59>lvVxmJsk7szz%fh4GhODkiLYq8 zENcjwX?v-tF5r%uH7A_V%w$!BuCO$B_yd=!m%5%XyR(_PJKR@p&>#iLwdPms0t0$b zO<(78+Q2hMA89M2GE>viS=s*1Ev(ehu-vh(1ygIf%-x!1Se2Ger{ApfVNH9VewePM zxj-+;*&al;E4WjDr5xe`k!r+=v6lR-5VA&mkBN>J zU*0%<^yr101?CH$f&MHh3FsGichx5I7jF(zOB6vlov^%%OS~8)ZrEUoSdc|@Y2`(A zPa7sPiEhq_rf191R|~Nxc>C80@-F#io3)|nrfio#6pOYpG%%9OI1FuD{x(~jvy{eT z;+8C4{E!%6Up_8MIGzEZ>uqa1-`cz|EQmcX#04t%@bb}>Q}3J=ZKn-CRKkh+m!m1DxISW@jryBf~>WW&T+@7;ypcyW!P!-k~INsOhwm zj8KsUzo$&Kus^FK;CYJY=%CE#{k_#KLr#H(E-5`B6IiR08~Dy%bm;%9q%@b=hur6# zOAAhIG#k2`+~w{HKBmI;&oeR^#@+TiV&|?lRJ2=e5)w_0=4&e316roak#y)@JT6ZY z0=4p9-Mq^0$CaN+6igg;*&}kaG&xB2%*g|>SR$iVNcr{Zgl{e8Z()-c#ErQKY$+_! zg#wXgi!HGq?y|vq4;*+kClcsIrNQq^$5+T+CfTctL36@b^y)n%WgZmJi(|nsPLL>pSZK* z=hV{Dw7~Xc^fqtWO9+D`ROFsjI9=VGu;|2{&?%5t;%?=nyqZ53E+{}jKOm{#=-Qt@ ze{P8ls{EMhJpcStC6!L2t?l@d*SzM-bt3tAOd)=q-qbHiraE^%;rCOdG9&#p3O zoownM?5NHZ7lKj>a4>nMoT1DkkltHO_m0XGV{_FGRS!G;wAwTeAD@5(`SlToN` z?{3%B`lqJ#*5V}fBA!fPvryKEUP{XXqgr@2kl#x}+6 zn*t41pfZ7*Q)w8X=oc0}CTaCl3?j8BwGhacmXBQW#n(*m2xMbNj7?S*9OGEYs@9MV z9;m)3V(!PUU$4b?b3-F?Vo--H`=OJgtT?jt^8(MDSz=WAX1bch<%}Cq6_v4bs}m%` zCJB0r13&h?{*R*A(8hk9I&~5N#2`7~lQ+K>^b9nx@E!&i&4I6Ljem0Ro7OKzZRmyD z*9yfCFJ5@MJOL5_*lvy+ds9mddV%>!X#@?)(;*)Wb?O%GS6_hk6 zE_-xfg&i3jZlni{o8FW(ot-Y4Gm^CW_bJ)RetyE;)>}XD*srr<476ZwmQ~Fbn1$uZ z5B7_QHKCXEm!&ET1s%PLx;;;uY*A73S59xImZ;LfX0)bXgwN~gYD3+|hj>(W9gy&D z*i+-;X&rM_M&>p~90v!MwogmoR8x_e)9;wY(<3>(k3D>e4jMoL#xdMo73QREPdB=?tS*xWc&J0eiObZzT0u)(6_IzM}+PR zAj$X|aXT#~rK1oSapXou8#y}86jj@aGnd>H>zBsLJzsiTd}zv~&^J#5NLTAt)n8l; z(=3#b`+ta0?Ifbow|W_f>eetQ7@3#z-Bqat1+sk8W9q(XakMUI0EX#@KTfRBIAhDS z8Taq9!#W`Eq`NvKY|UsCaFz!{NlVMC%<<6PEzmGe#TT1qZ2gMy5Fp?PzN0t}`Q9^I zbEQC!MgBrUCtRyQ6a%YRKBXLop77Y3)PN${YMM%<_f)JXN{Yz-(8m$oD>^ii}h8D7&#?AXno(q z?n}3E`ScS~MUM+!{TOn`KSBOIkXVw-=MlzJ{bL6#+H%@S|I0|B%ahaXV@~wzac&!BSdT&L z7o)EeOw9ftwx1sTy%ZVC1(&Dz!9J#Dmdvw9&lROu=H%?GHKCUs;+!C_a<=iMH$VJm zJe#O=#^2xUWpBe%Oz#z-&Vk6W{!mcrtpJNtJ&uu`K-(5rbvcea~1tDnCSLz-uQ15o};OE9p-CT!$KIc+1d&Z1? zjaC!hb-$C5ySty=^7P!&+>8?~r7Jl|$NL3^ERj4dx38bm?GbFH%dB-5&q*-LB|5L`BGzE5gL1iF5!3=frKj79Zag zRi^(e9VdI%hZk9B>5e({{Y~69rHh2b(UUy)Nc`tk5{?Fe>aP0xGdi=&aaCu}PlA?1 zU2AS&F(LkU8gTy$vr_!qJ|}#AY3W+P_&V-mHt=ZmxQxq@@&k9c~1|JuA~`Ep$Bbc;9+^q|bw6EgC!cmAxn@1{pOT#a&(%5`K;TIl;hfJymJlaFnr2 zg2*$FSB)~bxqwnD0qs3C0jewc!>L-n;eH=n8b8Es^FrrimBd`NiE=g7@3#L6Ig;of z7#PYgu_6=DTDmx~VSJ-%z^x zN?!82P@U9wuFF=Z1N;P;Z65eJ`SYfEfkP({VB!y7pcke*@@{oZ(g%PKOIY2Bav2#O zv4BFBQ~|615dAb53`iHv?tAa6f5tSsE-Sw++)&wmt8C8e3Yquqpe9Ix#vw}OZEWFoZf8O9_2GZke z|K8*~qnfbzSbIdGSjW&S!$ zcXC0-9X}~d%`YOAJ9(tj(K%SWXt6S~C_%@UN-T@AbI33VjEkq-sQBI+$+A%47*tpC z$eRbttFE0;?&<40UCY53)Eu8Crw=t>gxj$jvke}bj!Jdv%;1vvmKJ2>Nz>r#PafH| zYxV8s%!S6v;PFTpE{ey(0~3#hXbd?#Wm~T*PZ&X|z`o!mh*=JQ-MRKX$`GHf)jaufxB9!QN@cUFkVE#677+Y?~cS4LJ9Qap*V2-$lksH(0d z_SieFLyD<)WvMiGG5OQ%*TlM5q(QU> zujV{^<<20u)1NP_m47p<%bU`^I#--JUvcVt$HUVzp<21a+=p{x6t}k*J=)w~8;fxx zMre@RvAk@XM%BX7T|-%-dVQrura5=bW`)0xmzQE#`ks#u_6>HOJ$CG~@8)XYPOt!X zz7&eCo;3xx*F}!ZC8Qh{aqP*L`|_%1W|=)nO|6K_+j`Smd1i!@@N zY`woeY5@ITZf?7`HBt`t zmzK6LPKU%WWGf$z6}3yij|gSH*Y-peA)o4Kd;7%#E04P!PHVJ-gb?r z(|>iBkT^@@{s;t`d71Cn==QXhM_|V>dimMOYSk3|&UkU(0*5(lD_^(XA>-xZlgE26 z%l~Se4kAk8kE9C(5z zu;&TgI+bj`Vpl_4h>I=c5$RoD^ExJGn%%nx5ng7SKXw!ur&~+K`nY&}AU7M^f=d+c=gWtojjZQ8RcC<*+@!}Pkceql zVx_M~X5tr`8FJVgMnc6q>2>>q&osmtHs16oMmec+gJwJ50%z8}FnUGn^II_oM@P1s z@7X`v+qQA{OO1}}BnJ%&Ly7gm?lVjwu4m_=+kzI3xLbg%$y~Wi-l!?biMiwJfBl9C zo^T}k_x$IZ;A7&JI;h!~`!8%ti4b%1@-*hE)zRZy?lYs{70Gzg?EZO9zfW%ssI073 z`k;e8M`8wn+PRD=&#|#DD=V#u!%Igj@P37Gc*D*#D7;wi2kG&IY^iN8GX_pDd^=`h z{I5GK1VfHxY~VqKlden(|K&u~p#efGVcMLUF}Vf_(+0ZP$J9|V`c(h=@F8Z9lTfBY zuX>BD29fc$Ua8XyR9g4{^H2XlgDig~nYqe7Jx?!HC>d;CK{X4yA;ayDj__44S$AY- z*{*8Rm4=V|dwm&X$C>fMVU}$K-JLqfe>N4_Zq5Eg#x=X8{!Py#vq|`HQGnICJ8u2j zM91iooXB(yG3|Ls3N3Tr5~2b>58HlT=LTAjxoA%_VPgQl#*0rRUy4-XewCNZ$eR_O zce9?}$!OPg;)&VqovwlWk|>=OpPuz|+4Rctpz?2aeupJj+>?sm0q{TAHHq(0Sbs0E zq}DA$b|d*q|Nc8SSv>?1G}NE7mTn`Zymd8s@d8(j9Wt<&^xqnQ^7x%yfM@W5s=T7L z^F&&DE4b{Dk=}vpR_{=4t*UJQblK9S-mBKp(sp=ZQBLq;KdeQjy_n+`bW2gWb(!mm zyN^ESPm`9DIE5|TedL3es15V|J<>B@x-Gk(px$$Db1BO$o8vvQ*~81L?#EGy2P4o$ z1e`i`jic2H{Y^@0?vyz-9=ny_Ot(98zJewj%G@EcXM%!UjMROzn)i?+MCO~Gt~~cE zd*lTXE95=Ia8h5@>KoRB3*td++0lzd-xIx3_Bx|fmi*n=a(3*O(tZDoLOj6nlqH1F ze!U2@W#8mfl+z@jJd`D4)v-QJ(bW`fR1bo)7BC77{Oz8aWsH%Y;4= zxG;ov261BS+pOuEZ}0ZoczkkYrN#nOJ5c;17vpGiibYp8`40gz+j{<%nA0)T;g--n zn}#G>Q4)o%x_syl-iZ}lK!L39!u(2)|Es+>kLx*Y|NcMBV1`jMBbTyd#*%H6Jz5PT zgGAb-ZAemSrF|Wj;Tj@}vJ_?_lFHUzGGi2?NQrhTZ7S{4Lih7*uIu{#=JELbbwBR= zzq>zXNPRx<&-;Cz$9WvD*Xwv4^FMspwR}hsGCtXRr>+)sUjXZoy+=rCJ)LP{U!m~8 zW9x~V+QHAmUn)vVTtb)@fT*;dmcs6?jcNz&R3wfEfBspQBQd_e^}AK_lKYP4;5c?y zK2^OzG(B+c%%eQ-ZGMXMQ7rgP>!GioqKe8EkyWXEXz1x7MP#bf)2;nAwg-ZbpFFwK zb23KW&+QWLFFAG2kpg(l$zgZUbL*2xRjq9pru(g*ruE()&+|Xnf&`!-x0112Ex-5D z*{|DGq{RT6)8Ri(Sfj0-?lNF-YvKt?(`D@NJ_K&d4Iaiu*L4Rwee?_jgbx3_*e`+3ijR6Icf7$ zm>*%C=$5%G9XkcD>SL=ZCF+}g@zCc!j;MUob>}1L%X{kcL!E&%C|c57-j{BF`p-2h zF3izB_|ejBE}4UqHZEOS7mU-!o-(O3{xxgm_ulMYrtg|N0Hzj4FHLgBKUer!yL|Ar z4Ks4j^OqpI+nbIQxFBW~B|Aao$0b1ltq1e>A>&T6(f@ z)UzD6z1do~u7V!Ne8pN#hvA&i2M->+@jmE5k9SPYqeuF*(zp^`ONu*MTXo?o3MpjN z+aE3LrC?cvq`0KmeMHe7QNoOyFv0GYw*?XG?oZa`sbQ*~Z5MQ>L`Hfq3-N4*rMF?M zL-5twT3=sZz~P&Ax8FSK_u;33+P|qZ&H#bhpufcEg;t~o_2;&kzwQlk?k^A)D?^g=`#NnVuLg(nhIe_;0twy zT7PO2CJP_l7V+)2&)U+U)uQoc3jO&J`vq0RY0&5oNq{o6xlaD6Ur9x!cdeClbu4L*1ss-3{bo*M($FisUr{NNrF&-X znv3WKR^8vY#L?{8ZWoup=6Z)HVv*!dA)ScgUKnSQ6>cDjxc$mnVYkEC?}yH_fA+NB zlw!>QAB>!`R?%rn7^`scyGFb}hD5qj37lGg5SBW>+RCp;TtT3mro6EIrD zpYOr`FVeEB`!%Uqsg@l3asw20zjCY9kjz{vPMOPmd@W~M$?tdGBnjRk&%~jvX6YWf zAVpc0H(dbGBF_YJq^S0*QDTq}-H4kAxR%lAU<*fc9Qu^)T;nB%s89)Qi`*W{?TPvR z*pyJ`mHR)t9{A&XTD|J?4QoY0jl`A4ID<&7(frbSv|QIZ!&^TNX}3t-GnxYUdP0k+ z|FSkm3YzLWhL#-LCx|ozf z_i8j~;jF1s?=HyQekgtK@3{|eNUl<)ZfZ$yKcbf{G#H2(Hj%&R^9hC;sh3o%WZ)77 z>%9VRu*A+j8erw@qkR^dJsTpC#2G?OM=<%l~ic+O`!I2)%%I7;0 zmews;-O{HWw;PRElDhOOhzG5fuh$-*`RZLe*<<8(>j$Yfb_3tnwD{Y583<|!JT!Lm z^(jSFj6hFdqUnLLx0LKeFW`z34E!ouFuco~5%ht6it%nKn=L3nxLb26%T_{WC(an* zz1K$mcQ|qb$OMxwya@tD?;InWsVz7pJTH6{=_P?e2=&jWpu$gGV z9aryovNE-=In4kJL;A?0-P&VhQRwVi??xea3p8S@-0E$grN@#gX+6j#8x;YXIig8w zvW3$L9}D8PP{llJyuG9McWZWkqJruGc0^jvEXM1x8yr;;pe6y&Y!b;ZvNySVVzd|& zK7OPCs(<;d=g9|))o09D4ZoJxc`AD>@3Z(|+NrpDi9}+wFyfs1as8bO1^}k6P;i)x zGd7tnZ?VHwDGUtx)Z2ar+U=wUlycv2+(AtfX~Sh&d>)@#nEmQHHRn83JAN7i#f%-d z+Il7iKw(voM(<@Yv=yYp%W3;r%PoOFpmvxIU}yh6>cR!td-FA=ON+JL8(fYb_`zk; zgGY}Fz!mTaHK@97Y|KDj-S9Q?dr$EBZCQ(<4^Aj=DVT2daHZ4nx8LwT;rKQ;`*&ZR z58`CJQ^Db6XeBWuC&RNf@x9deUojMZ#)x3<_AwNjR*v<(f5n6{M`|Ni==LUDj)r+3C z|306amVV_I?kvZA8wr|&n&KF@$9ag~br8r&S#lB9DJM&_v$JD+dU*6(6h}&r8gcJUOvr<<39g<v~S+6oJC%8W!h{`-PicMleBdzik7upViBmk(V0+#0;)>0<1_z5~v>5o9JF>;JRb z$CZYL@8Yj5AF^Z1xA{-ISl9El#?E8_)((?<4`glMl!I$(TFVPQ4m|c!6 zPo4t5BvE6Ck$V5TyID#bnA7wtg_wuOclm^7>t)wnjp=-M__Gtw1xmHe@|hxxmZ!gr zZmrn?FQ8SjU18d^>O)~cZqw$AIqMH|a}&2Nl4yr(`kN!3xE0t4X+L*V5fD(B!(Ay7 zF0Ddz3n7PZYbpry0ev%63Hy~-eMwL9ox?BYYdWe7SZ(Qi{s$lX7Tu~j$#5dF3D(*D z>5&I>@Ec{+t&@jZd;2gG@tpIvgI5;ff0NXgqtxCm_yH-Y#K1E&(b%`I6T&jmSZg`T z*kyvq3Dp=PqaD?~48y#7pzHor0>MeVi zx%Fa!^l#tM!C2<-(OJtjhpUJ+ya2b#nj1MV9Fn$z9msGV>5H(+&HQ% z*ZDUy3@RVAsm%7S?HO}Pmx?Gyf>#P$zTT1To_Nc0M2MXiH(t~oCcO(YqLS=Oo`-(m zOlq5|`lc3!c9;G;ONi9O$FqcIix#ly{@~h>i1+>dffZw`*(e8kdVIbCU*_T52)x;t zxz+E~IqA#Pb%-E!W9rM7MMXt*?o2>T)!lnTNQmU-+S-WW zS>B|X^Vc49+o$$1ov!y6M3LNOsb(*21A|U3+i#|wdow(IV&unGV+7Pn6Tgj|n8M14 z_Tu~*0ecT@x4v^oj4zb2UhKTJcW5tN@_W)AZl$ikZDOm8ok|=yRg%T8Ch9v+X}qge zU@?S_80&Fl%?cB=cf0Cya_9b0X64oF^X|71QWlR+4x%QnzYqM@uZ_cgg3=y(Lob0m zwjA46ixMz~JV7eyr?f4cECzJ8QRAD{55)VwUxz$>(ZR=$ffpAe8;*R@um7Vk`B0op z&mM~9ilLP+C@KP9yTn|09f$^Te3M~_5kS>88D)$C=n%dedz&mL3aBq@<`T`pB=YT! z)n6Xno>+(H;vuib*rdST-PJTG3ItB~YJI;cCzt^n!+pQisgJ9h&#&};^dBuizYn#T zJ`T-E#R0V^6_+Gg`e{Hz3^44w)FE3js`gZUdg@ddBW1;3@ARulRR7_+M&!1E#hy(G zMWy9dDm-A3WAZ-9bc_aC?zTK}pwdA~bo<&zOO4K(Z(Lwy^u@eqo+ar&{B`?OygDwM z_|nQ!2THNzplf&d@Qv37caE8~9Nu`xkqdV4)Pu!>f?!V7E0yytNy9^MK_Ac`nTHd~)Z z$ZRJ1l_xt};uXI+F$6+m^<>|I8vANw;*sKyJC7OSq(f$1w@EX1-g{OW$2h9dH0ct_7>mBt z!f^q64~wk{t#$bIRl|cP=9)bBn2<~-$QFoUt?T{?V0<}$#N^;#bRPMnTv8c+JqO0D zIHG}bUQ(%_0@InduDDgWT_z;D#{zA}T+7n+rlue{`2EL`F((yY*JIr{Z)?R*$<%K- z8!t@WqjE?j{h*r-znWShhqvk7*$@1I;&c`xITtK5SVLJuABDCNM_&Ts>MBby2{q*? z$N%cV?8X}|E4B8lci@brN@-57_JUbnq(+}`Pg-`~olyJKtH)>R{ywuAqQ;P5re}}9 ztPZ0EeKM*!D?h5-Pe0b6;J|^SUa=wwWrcv@V8YYZWqThyckbK=x@IQa%SAaWt@bY9 z*{B7%XVnFwC^y&qgyP_?ci0i7x_>kIrej6p+!Lex9`0F==b@}5a)r&_6KD>J;KZ;rS|OZn9*s1ZjhK3OT9Qn{&ht<<)DOvWUGplfSoD!-mOu+Z2*44qHyKi>+Q6V1I( zy!!A~2_Ju}e%X7_>xQzlq-SGB3#OhWW4&i?u%cN|+mx4IIB5}9I;qBGBj$vWb(zQC z?!7|&bk)7wQ|yEV8cT?vpb7@W{z9CxB5gkoBSF}S;#2E-kb2iiq)UglSwA}Yi+G#^ zzz#2dp~M=m6wCn1{HZ}|+hiK!o;=wY>F>_cL68|qN+_E^qmr>YS+|4Hc}Zm|Jvt^? z^shYiSeKJ?19A@V0!zRr(cjgkUw!uh`ewY*cV|%Adv!P4r=OonyB+Z)pO({YG8TBE zR{i84!pezyod5`9w?94EBl5l-x{-;P9Q{f?4VLg|=7$?335~xTkYA9hIeO&Cd?6~( z@1yo}^H6HITtS9!k_hu1Nx_j);K%MW$n^8g(Wq1hD%D!E9R0_0K&=P|%{g<{!5gJF zvrK#p{#Cu;S@&-EYeVi(M0xB^pwou@i-%hKLvOa!N=n1ETXuiwZ{Lt}k~mO&IabJ+ zIZt&m4@C&9uPZ52ZmeBvYwLXZJS)*;vSP(vy=>DRrw>%tX=Hx^>X^WexM+FIUE8f1 z6})qFoj6iEN_Xc)BU8yz&8xnMY3TT|&W|2Ec#3TYjV9AB71LUjLWcFtxRd5xj~3=F zEz}4&BMo3b?v#RROGTyyjQ<{4F)^FWUZuV3BOD?g2*#ndBRmy zJk0R8^9@%JJ3%cwXkABtO1c>#He*g~G|)K4dC^nYIE{RCOtT( zW+J*O|I^(|U;ntCbQcV|A+V!4DW2H}u&fK>6c6K;XSRV9uXJ^liFWkPmZPRFWR z+~^88BoB;I7Yh>Y5?gp3W#3J^e@RS*&st=%=utSfTK*|{DS9jhnTxIocIe>ompSqF zE^u(qt_2z2{rsiZ>z`)?R*7xGr8iYvaO|EOhOquapJCt z(&sCCS@#88#vK1}vx(A@VW;ZdsuYT{#M92}jz_u!8dM1E>nL~T?AbU<po!`Kd z540>3m5GR;GrL&kVnK=-@HyQhkcbzOJ#bCCbKLiZCG2{cUza&7;LJIjW5-!R#{ck2 ztvaF_81fSg#xLC0dg4j}Z0_8?W@34BYNMj4>n^_+t-Tr=_mx!WW0&01(_=Q~ zCelL?uGJO3-4T$RO`BF_B3KeASXRjUW-pppI*C)V>tCB=MaALqjBdCBe9~BEYe7_I zW4|EIQl!AMk+Lu=_rYRk1!ZXbujUW$DW+R70pC#JTUKw0@!2ihB%%2qeKa*~+Wh&t z0>5uKNl>$+!#K-et57&14H;zK5okrRmuBH6M_X<&R>=%edEJpr`gyuGCR+F$n*Oo~ zH3qMz&)z^pHVTb}4gzKsTxUhI@q5$}itp2gyod)W^b%m{w(P?bF0mqTf72!H+DMQU z(w3>KkJNFgThNuUY-*kly~T18*=8!UnPooz{5+*YgLpu+^O`w4J7#3=xvIjaT^O2J z%P_G`H>0;9#tSsb00fk@b|x?i5w3s(IDGAt9;8u?Add=@g?-pfzL+_ChZ!B>w6lxP*fG&Rf%672!Umm5!XCM5#n*8PL*{E*%$qiSx((-F zs51?^dLl~jGjyn!D{FtLga4^`EmDtI)0F&+t%%lFfmFMzv{sA-2J<4FaIKASAAfH$ zQAS$5<7l~0`})rMXRF0B#ow1@dg9ozUqDIhsA5lpBj~jo2NpT^(wX=Pa+|$B$p2YM-s3@aDAIK9ODH0Fz2u z*qRwlE0~o;-6MS{E{-S}<`5E55ma(eu7+ zgsikX1K`|Z=6xY4#vIR2#9D)<5o>Mwj2WUYC#%Vi8tY}E(^OZpDR&XZs)+&5> zeMsq>q zAk`K-mYa%gs#o>;>&PF3rWm;J7XD*7F9gg*jDjyeF1&l7%$Ni%mb41E@UyW4F3&hH zT$@XcxkCcdU~GiYYz660iAD*5kJs+^t`G}*tVg_J~ootu=jbdmpu z_Q`_!tZ$eI?e&E7r?f$n9fsWahkRLyr1VSS2LRNTQ~MCu!WRGqP1Vu6$>@~~E&Sgn z;HSDv4Fka?I|(>ht%l`(efB_@r57zjkWg{TskjE(oKUUL@3#WNog0&GHkX!s_fqEr zdX{_LF;;j`=M#K^v$}!V@=r!7rJu(p>oQ6OK+h*9eu{}IRG+jAdV|CwRvi5V`nBuV z8{oZz1hm1eL^+>cPVW9o8oaH5_X_X=Sd2i`Bx{~i<#KFNvO^Okr4yMu7x-pipflbd z*|9QXCYb}^WNZBg#`<<3O*%D-OWX3-Da~R=_eD$64#(sDz6rF9%8v8E)>w~i9iey& zQu?=zfLekG0v)*(_&7KBIsJnoPOJ&@TSB@Vp?KfA1{zbGtk8}kkpbmjn&5Fqj#yMU z!7=?8kt+Uq5@FI{AW@2Xw4VO)Z`v-jB4S7X}Y{x;<`uu-(5gR+G85%g^Z_Ryb zoO)beo$$l<)Zw6WiN0}mA$gKj)!JG&O7fG!wjO@bb@N!*>0QJ^C#^_8ABNd|L!xCpFn)AeIDi&@u%8T7Ya#n;2@dS;ECMC*kqUJFQ zOfumii1irXG5V_k!Px7T3IU-&7#B2i@^%!b>w;)4-i}*NP&W!+ilT6ptW&GZg7IRW zbLn*s(*e2ZOI<%pw`RP9;v{!$g-XUi(v4bBbk#T#we7NmlTQpp4VUWH)X?hZZ~>%K zq}oQk!;cOe1!X;DGsRvc(NsssXjHGh8!IOl4J|Nx_UwmoUyq+ihdCutfTm^Kc*7{=0`dazYm- z8Xw$Ld+hznU$)*!6IeZ~(ZgHTe9WXtv3$xaXrVoeTR45<2EYl})Em!sjrwX}chvbJ zAcweYd2ekoW@|Ax=4Jx0*;%ExMBgHkvC{Ho%9&-g8w*o@s4q>q`E1~@qXN~M2#0vP zwcVfU&9f{r7#s0de@(8ifxBaoQfG~@$DDu2E0JtPhe|rQMCsH*h@BF-jXb(St6}<6u>9hcd z!ouFq#q*H)9ELHSo^QE7d(4#~rY-1MjbMJ&p5F#kA)-dHUveyclkGUjMS`VfUZVnK z1!A0X)~jJwP!}4gqm!2{)byuqEppO+!IhGJ)xP{amDacSSac>%lPnGSSRW%Z@}{MQ zw{MP>CPKc|S4*FR{XR#G2&1d}W|3#|snv@Q#IT1h0f5vP|K4t&R2IehuE+T`FXV$n z<{N^5&b=UD4`6Z^Sv`gx(r>4n7pIDT{W+wby|MQ0>`CsjO5czAq48NlJWr5i=1`Q!Or&f z-Cn8_JS-4diq!jhp)VopIThC;R7F7>ic_zEr*2^q=CcslTv6~QAkpTNrr+B=TU?$lCbksp{=zSsIi9(e+bzDZg(<^v$JoCMi#t9e6((vB(wI68fEh=)K3?~%8Z`SJQba{LmEN(EjRNi0u zlk=70#hCj#qKO%i^xMm$`P{G$>oUuv$)xb4VUUy60)qt=E28jlO6mx; z8)i(LN0=T9wp{!?3d=dsEhCcrS0e+FXI+5hiS@@4+^OWxs@S8#Fdx$=`q-E5oS2k z6bGcbj`J#ZBteO#dYq+|nTwTI|3*w{C3$bB+0|zqlo7_ z7Fip)d2L!xbc{%C1nMHdy*>wZdSLQJ&H}BcYVfVho!ELbp55Q40cI5d<9716if`}U zK>^ZHXk;NuAe|W#OXXe!-lHG0fjW2ISb7VapVC*4jk^Ft#ykMMy1j zH*}rGwQ#yUWu?TjYG0Ojk;IAB7Btp%=*DX`kL9;4*I>X-ES15QEYTeyt<^q&cHbhi zzwMB7`ohBPfc_m&`1K%w`Rvm;alopomYmqSHuZS5Po8>r-Ni8Twio}HMDbC?l^=d> zipoCZJvQXz5IUo+feSU6JnD?Gc8W+Rn6a!S zjusnJmXKh)&2i3Z{L-86MYH1SI1}BX?KiY{cNd9#7VOQ2Oz-LO+8QFxx=SLQ@+?1--;G>4^MDC)d#sqOlfbKG2|CA+GtU(mL$ez ze;GY68xgiRs7X{WVBoBcbE6QHL^xa`OKYge%#`=4`+6+m{w8N%TsuTDG$|`NxImUZ zG0|!0#Wxo-6iq1OqV3F)I5LGqNxi#tcad4k0A^UrcI9keAbQ4nIqxQBCHLLDj^8|t z*wR4fORvRr-$ht$ADFMuJ=871avMZ(sz>XucV$&AZl^Y$vOatL>b1t)oPP?f@rQ|P zUF_1DSV?m*INut0vv4zwEb^|{Fo+{fD@Z;TphLyghZk`r6A%!Aa{tSQ;)EE16djQF z`kuV##;&~PafJ%Zlf5DY3kW=ZKVjEPqjd8gRjQg)zw##QRPTFxGf^$F;y(CU z44n-*OP^h+%WQag^U~9dtCwQ#y@YOE8>(4Cx$lzlo^J8T-PsY=Cr~xzEZbx47H&gj z%)VX3q&-nKiDGtLm5_4C`@H*Bv^5ruEkC8=dWBOi3Xw-jLEnw~!i5B9Ri;3+kD0T|q8hpWQ;0S_5FI!lNfVu?NC18c7d!}$B z*ZWV=q`W&uT2fvWQkAPWHejeVkZUhqu^nVeVtj@Q-KsV739UlbPwbBZI4n*y`0S_= ztk~kMzke+D4+YLf`it&*P-pWy>o~UbpiIY(853gnmc>XeJ2Qh?T!r~!an9rjhOqt& zg)6XAh?vao@db61%5G=#+R{-WWvPupfE|ovFL}`kq)8q5zmuB-C7{F$7Kp#6Pow%* znnF7vjwI4?JES(fg!crfAi=`W$hDDHLa>Cfh~UVNuAmSWvNm2dl8-T)2Io0HJ1YlM zJ92?`H0>LGmYfKn$q42@=V_bzeVEJevtTM43DLP^}kVWhg?TBXio^WiU3RT(V-&Xyg0nts) zKvxX0V<*)#9)-uC!7tHc+VpVodu-V_xD0mgtUF=nMMWY2$L1P|X)OdoA?&7qO>FXq zLWyKl*P8FX)o(%B=G-uBvgFP(%J2d(qfYLxID54Y#YF!4s{paNHR~`;zLpTY@aQ-p zY@uJv77$6`kTpYJf}RLLDwmj#2U3&MvEQh;2e&HD8&b7tRj*$cAOy_ld~Q_U z&uzrrR%W#e(X5NCJKz|d;ulh6+?Mi^Zdqt~uuA4y)k;=>Kux<+A|gYN%H-rPLeVR|}_HxgIn{!O*2+1Aa(O4S zNK?w=@qLcr)L4rcMe$sv2n+P))LeVde+`7d>aOe<3aSDi#|{954J?mM^o`fd2YxW% z#-9M%!?#DTX=1<^ba0!z;G`6d*KF&RT-kFHUBT`Q8!^J7eOo%mh)(S5kK1kv+?r6? zD{8iBS%CG1WG1t>+`3GFfM%L2DB@^0rYgmQ6Wj6xRS=Uxu8`Ws(rp%l_94aa!UKBe zW2^wA2owT!7IA^dYPDm>Y$_XUO5uVYooGK-EtCY=Z7MJ2ByU&TSSazxBfE)zlsB2- zot&JT3=N;37?w)5wvkL)Y#r;g?}nTW1tsO(&6AuKPnog;{4RBYX3N_+(Vdr%i1ZdG z*TqYhaw;kY9{ye^q*SNpMT9-bHKIWxz4<*AOsxza6j(~ymFJOjz9?3fRG+|E&6ZVO zeA7Az^WV6X+%i_Im?#Q5k*=oSwZ?0zkV<i38|9C`qC*d(pwmEky}V~GKl|yZ zQF@4fcR4B(mjRKDpfc?pOU8&W0tZ^isQZ$YLnJ}C6J8ekn4ER+_hw;rg*y^ufLJ+P zqrGF19tf8KrjY_xq_3mvbMaIAMVmSg;aXsOr&Y$XzuMxAkeJ2ws}o5NMe<_zRT;rqjKdRZO@8GI zK%aJg`;Ob-+fRqse!?mIO%=e@@e$$|0&a@@YcvXpnM<&pxXbry1^1w!kX3S*cfG$` zfW?qAc5>rnWd8b{Ya1v~K)Scdjr;P;VDU12;6IXp19Qz%^=V(ABWKh1-~S7yilzbmBR0Eibxs6B|ibKY&l`!g9k|?9_kSbAvR)(~0IBE^ERm&~yN`u_XKm zGbXVgs7{N-1| zB+IvdXH<}YfWy~1-M$J0OTpcF=dyTX$q_X&3q{AOX3!Br*Gsi2eA64ZB{STH z03_dB9U`;nb~#rdxo^fHMSOqNw3WZz8I1|S3C@VYTQ89~tspo+K;jNSO?zi>1{hs? z>FS?JoJKy#%L6!a`Xf8-5}-jO3aDi6ZTySKyVsy=O`)7t!q{ zxo^$E-ic`MG7LT&n3po$NS;ZGy#^|*u5-ke02^l@{OcS$4Lg#np1O1?()t7W55ywLg@=<23KYIYCLZvk@<=UD$sc`;}5dZ zQJOyBJ4;~RpTxx9z5+@i!f?5~lqC64MG1Gmat9fKfrw4SgtRw}kV-sq1B`)$g)nco z4yPIO@zZ(jiDCebXrd8psmNfQoURcJ#uDB|j}zwvD%Y;Ad0ZXM01P^O0@#VL{<__5 zbBt_hxG*DV>65`1AWUFkde|K2C&jSo1U#+JQ#Z|2S2shAj+=%a^`z*vVoy@~7~0a9 z!3R!PHN4cca+A=W$-KDOu4v#hJsnch7A+x z4_rtWEwbWE2XJZd5C8L@<+rOFn+C7_f4;}PJ1q9f-@m@HK%xHcZ#jG`?6R?D@Tg(dC`CVrS#-~@9clR_3!`t z*ZgnXn*Y4U^4s5We*gXJ-$%sXzuEJp*!+M0`agdw{^!~K=birF1&TZI|Ff9>^lvJ@ zF8=M#SN@GQZ)Y9}nEu~|&v!Tfe^=@M?SC4%PfO*!*Ee}0WJuPm()liK#SeS`7YPph Axc~qF literal 0 HcmV?d00001 From 6dcdd6536456158667747f724d6bd3a2ceaa8d88 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 27 May 2026 08:46:23 +0200 Subject: [PATCH 694/831] ci : only run docker jobs when pushed to master [no ci] (#3828) --- .github/workflows/docker.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 6c0de0ece70..c5162dc8251 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,7 +1,6 @@ name: Publish Docker image on: - pull_request: push: branches: - master @@ -9,7 +8,6 @@ on: jobs: push_to_registry: name: Push Docker image to Docker Hub - if: github.event.pull_request.draft == false runs-on: ubuntu-22.04 env: From f6e617bab7843d86d94237f9791ee41d524333f9 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 28 May 2026 07:21:25 +0200 Subject: [PATCH 695/831] ci : set GGML_NATIVE=OFF for bindings-java (#3830) * ci : set GGML_NATIVE=OFF for bindings-java This commit attempts to address an issue with the bindings-java job which is currently failing. I've not been able to reproduce this locally my windows machine and I suspect that what might be happning is that windows job compiles on a runner where it has different CPU features, for example AVX512 and when this dll is used on a different runner that does not have that feature it will crash. Refs: https://github.com/ggml-org/whisper.cpp/actions/runs/26496174929/job/78059073255?pr=3829 * ci : also disable BMI2 --- .github/workflows/build.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7ace04e1207..aaaa8fe5826 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -640,6 +640,8 @@ jobs: -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DBUILD_SHARED_LIBS=ON -DWHISPER_SDL2=${{ matrix.sdl2 }} + -DGGML_NATIVE=OFF + -DGGML_BMI2=OFF - name: Build run: | From 9186e2453bdd051854b17cfb0d068f629663e114 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 28 May 2026 12:09:13 +0200 Subject: [PATCH 696/831] ci : renable arm64 docker builds (#3832) This commit re-enables the arm64 docker images builds which were removed in Commit 9366544991bfee59c927e7c23b1861c6c762e708 ("ci : fix arm builds"). It also uses ubuntu-24.04-arm as the runner which enables us to avoid QEMU. Resolves: https://github.com/ggml-org/whisper.cpp/issues/2859 --- .github/workflows/docker.yml | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index c5162dc8251..9e07f7b2292 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -9,28 +9,25 @@ jobs: push_to_registry: name: Push Docker image to Docker Hub - runs-on: ubuntu-22.04 + runs-on: ${{ matrix.config.runs_on }} env: COMMIT_SHA: ${{ github.sha }} strategy: fail-fast: false matrix: config: - - { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64" } - - { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64" } - - { tag: "main-intel", dockerfile: ".devops/main-intel.Dockerfile", platform: "linux/amd64" } - - { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" } - - { tag: "main-vulkan", dockerfile: ".devops/main-vulkan.Dockerfile", platform: "linux/amd64" } + - { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-arm64", dockerfile: ".devops/main.Dockerfile", platform: "linux/arm64", runs_on: "ubuntu-24.04-arm" } + - { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-intel", dockerfile: ".devops/main-intel.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-vulkan", dockerfile: ".devops/main-vulkan.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-vulkan-arm64", dockerfile: ".devops/main-vulkan.Dockerfile", platform: "linux/arm64", runs_on: "ubuntu-24.04-arm" } steps: - name: Check out the repo uses: actions/checkout@v6 - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - with: - image: tonistiigi/binfmt:qemu-v7.0.0-28 - - name: Set up Docker Buildx uses: docker/setup-buildx-action@v3 From f41562bdd6f3ed19ef352e93f54bdf829771e59d Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 28 May 2026 14:41:48 +0200 Subject: [PATCH 697/831] ci : add on push/pull_request paths ruby job (#3833) * ci : add on push/pull_request paths ruby job This commit adds paths to bindings-ruby to only build if changes where made to bindings/ruby or to include/whisper.h. * ci : add additional paths [no ci] --- .github/workflows/bindings-ruby.yml | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml index c3f158e26e4..0c31701a2a3 100644 --- a/.github/workflows/bindings-ruby.yml +++ b/.github/workflows/bindings-ruby.yml @@ -4,8 +4,19 @@ on: push: branches: - master + paths: + - bindings/ruby/** + - include/whisper.h + - examples/common-whisper.h + - ggml/include/ggml.h + pull_request: types: [opened, synchronize, reopened] + paths: + - bindings/ruby/** + - include/whisper.h + - examples/common-whisper.h + - ggml/include/ggml.h jobs: ubuntu-22: From e47a3eeb04176d33630a0a3042caf3b64dc644ae Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 28 May 2026 14:53:34 +0200 Subject: [PATCH 698/831] ci : fix include paths for bindings-go job [no ci] (#3835) --- .github/workflows/bindings-go.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/bindings-go.yml b/.github/workflows/bindings-go.yml index 83473e4636a..44381a4b411 100644 --- a/.github/workflows/bindings-go.yml +++ b/.github/workflows/bindings-go.yml @@ -3,11 +3,11 @@ on: push: paths: - bindings/go/** - - whisper.h + - include/whisper.h pull_request: paths: - bindings/go/** - - whisper.h + - include/whisper.h jobs: ubuntu-22: From c932729a304f7d9eb5354afa38624cfa86a780cf Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Thu, 28 May 2026 18:06:04 +0200 Subject: [PATCH 699/831] ci : add ignore for bindings/{ruby, go} in build.yml [no ci] (#3837) This commit adds an ignore for bindings-ruby and bindings-go in build.yml as these are handled by separate .yml file (separate jobs) and don't need to trigger a full CI build. --- .github/workflows/build.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index aaaa8fe5826..e855ef7cf87 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -28,6 +28,9 @@ on: pull_request: types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml workflow_dispatch: inputs: create_release: From 205ee5a1898d7a52167a6064b166bd890b06ac6e Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Mon, 25 May 2026 21:12:10 +0800 Subject: [PATCH 700/831] CUDA: add fast walsh-hadamard transform (llama/23615) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CUDA: add fast walsh-hadamard transform * review: add unrolls + change size_t -> int * warp size 64 --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fwht.cu | 108 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/fwht.cuh | 3 + ggml/src/ggml-cuda/ggml-cuda.cu | 8 +++ 3 files changed, 119 insertions(+) create mode 100644 ggml/src/ggml-cuda/fwht.cu create mode 100644 ggml/src/ggml-cuda/fwht.cuh diff --git a/ggml/src/ggml-cuda/fwht.cu b/ggml/src/ggml-cuda/fwht.cu new file mode 100644 index 00000000000..74e94d8442b --- /dev/null +++ b/ggml/src/ggml-cuda/fwht.cu @@ -0,0 +1,108 @@ +#include "common.cuh" +#include "fwht.cuh" + +template +__launch_bounds__(4*ggml_cuda_get_physical_warp_size(), 1) +__global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows, const float scale) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + const int64_t r = (int64_t) blockIdx.x * blockDim.y + threadIdx.y; + + if (r >= n_rows) { + return; + } + + src += r * N; + dst += r * N; + + static constexpr int el_w = N / warp_size; + float reg[el_w]; + const int lane = threadIdx.x; + +#pragma unroll + for (int i = 0; i < el_w; ++i) { + reg[i] = src[i * warp_size + lane] * scale; + } + +#pragma unroll + for (int h = 1; h < warp_size; h *= 2) { +#pragma unroll + for (int j = 0; j < el_w; j++) { + const float val = reg[j]; + const float val2 = __shfl_xor_sync(0xFFFFFFFF, val, h, warp_size); + + reg[j] = (lane & h) == 0 ? val + val2 : val2 - val; + } + } + +#pragma unroll + for (int h = warp_size; h < N; h *= 2) { + const int step = h / warp_size; +#pragma unroll + for (int j = 0; j < el_w; j += 2 * step) { +#pragma unroll + for (int k = 0; k < step; k++) { + const float x = reg[j + k]; + const float y = reg[j + k + step]; + + reg[j + k] = x + y; + reg[j + k + step] = x - y; + } + } + } + +#pragma unroll + for (int i = 0; i < el_w; ++i) { + dst[i * warp_size + lane] = reg[i]; + } +} + +void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src, dst)); + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(ggml_is_contiguous(dst)); + const int n = src->ne[0]; + const int64_t rows = ggml_nrows(src); + + const float * src_d = (const float *) src->data; + float * dst_d = (float *) dst->data; + + const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; + GGML_ASSERT(n % warp_size == 0); + const int rows_per_block = 4; + + const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block; + + cudaStream_t stream = ctx.stream(); + dim3 grid_dims(num_blocks, 1, 1); + dim3 block_dims(warp_size, rows_per_block, 1); + const ggml_cuda_kernel_launch_params launch_params = + ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); + + const float scale = 1 / sqrtf(n); + + switch (n) { + case 64: + { + ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale); + break; + } + case 128: + { + ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale); + break; + } + case 256: + { + ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale); + break; + } + case 512: + { + ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale); + break; + } + default: + GGML_ABORT("fatal error"); + } +} diff --git a/ggml/src/ggml-cuda/fwht.cuh b/ggml/src/ggml-cuda/fwht.cuh new file mode 100644 index 00000000000..fa4c30477a7 --- /dev/null +++ b/ggml/src/ggml-cuda/fwht.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e25be3592fd..1bb09ac80ee 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -24,6 +24,7 @@ #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/diag.cuh" #include "ggml-cuda/fattn.cuh" +#include "ggml-cuda/fwht.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmf.cuh" @@ -2594,6 +2595,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc); bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32; + const int32_t hint = ggml_get_op_params_i32(dst, 1); + if (hint == GGML_HINT_SRC0_IS_HADAMARD) { + GGML_ASSERT(!split); + ggml_cuda_op_fwht(ctx, src1, dst); + return; + } + if (!split && use_mul_mat_vec_f) { // the custom F16 vector kernel can be used over batched cuBLAS GEMM // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) From 1c477d4056c8c424d45b8ddce1814598600c0d79 Mon Sep 17 00:00:00 2001 From: forforever73 <63285796+forforever73@users.noreply.github.com> Date: Tue, 26 May 2026 02:05:16 +0800 Subject: [PATCH 701/831] metal : add apple device id (llama/23566) Co-authored-by: lvyichen --- ggml/src/ggml-metal/ggml-metal-device.h | 26 ++++++++++++++ ggml/src/ggml-metal/ggml-metal-device.m | 46 +++++++++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 1f212a92f98..4a3ebb5569d 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -215,6 +215,30 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets); // device // +enum ggml_metal_device_id { + GGML_METAL_DEVICE_GENERIC = 0, + + GGML_METAL_DEVICE_M1, + GGML_METAL_DEVICE_M1_PRO, + GGML_METAL_DEVICE_M1_MAX, + GGML_METAL_DEVICE_M1_ULTRA, + GGML_METAL_DEVICE_M2, + GGML_METAL_DEVICE_M2_PRO, + GGML_METAL_DEVICE_M2_MAX, + GGML_METAL_DEVICE_M2_ULTRA, + GGML_METAL_DEVICE_M3, + GGML_METAL_DEVICE_M3_PRO, + GGML_METAL_DEVICE_M3_MAX, + GGML_METAL_DEVICE_M3_ULTRA, + GGML_METAL_DEVICE_M4, + GGML_METAL_DEVICE_M4_PRO, + GGML_METAL_DEVICE_M4_MAX, + GGML_METAL_DEVICE_M5, + GGML_METAL_DEVICE_M5_PRO, + GGML_METAL_DEVICE_M5_MAX, + GGML_METAL_DEVICE_M5_ULTRA, +}; + struct ggml_metal_device_props { int device; char name[128]; @@ -234,6 +258,8 @@ struct ggml_metal_device_props { bool supports_gpu_family_apple7; + enum ggml_metal_device_id device_id; + int op_offload_min_batch_size; }; diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 780dfe81bb3..885344ec670 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -628,6 +628,50 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) { free(rsets); } +static enum ggml_metal_device_id ggml_metal_device_id_parse(const char * name) { + if (!name) { + return GGML_METAL_DEVICE_GENERIC; + } + + static const char prefix[] = "Apple "; + if (strncmp(name, prefix, sizeof(prefix) - 1) != 0) { + return GGML_METAL_DEVICE_GENERIC; + } + const char * suffix = name + sizeof(prefix) - 1; + + static const struct { + const char * name; + enum ggml_metal_device_id id; + } table[] = { + {"M1", GGML_METAL_DEVICE_M1}, + {"M1 Pro", GGML_METAL_DEVICE_M1_PRO}, + {"M1 Max", GGML_METAL_DEVICE_M1_MAX}, + {"M1 Ultra", GGML_METAL_DEVICE_M1_ULTRA}, + {"M2", GGML_METAL_DEVICE_M2}, + {"M2 Pro", GGML_METAL_DEVICE_M2_PRO}, + {"M2 Max", GGML_METAL_DEVICE_M2_MAX}, + {"M2 Ultra", GGML_METAL_DEVICE_M2_ULTRA}, + {"M3", GGML_METAL_DEVICE_M3}, + {"M3 Pro", GGML_METAL_DEVICE_M3_PRO}, + {"M3 Max", GGML_METAL_DEVICE_M3_MAX}, + {"M3 Ultra", GGML_METAL_DEVICE_M3_ULTRA}, + {"M4", GGML_METAL_DEVICE_M4}, + {"M4 Pro", GGML_METAL_DEVICE_M4_PRO}, + {"M4 Max", GGML_METAL_DEVICE_M4_MAX}, + {"M5", GGML_METAL_DEVICE_M5}, + {"M5 Pro", GGML_METAL_DEVICE_M5_PRO}, + {"M5 Max", GGML_METAL_DEVICE_M5_MAX}, + {"M5 Ultra", GGML_METAL_DEVICE_M5_ULTRA}, + }; + + for (size_t i = 0; i < sizeof(table)/sizeof(table[0]); ++i) { + if (strcmp(suffix, table[i].name) == 0) { + return table[i].id; + } + } + return GGML_METAL_DEVICE_GENERIC; +} + ggml_metal_device_t ggml_metal_device_init(int device) { ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device)); @@ -795,6 +839,8 @@ ggml_metal_device_t ggml_metal_device_init(int device) { dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; + dev->props.device_id = ggml_metal_device_id_parse([[dev->mtl_device name] UTF8String]); + dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; dev->props.max_buffer_size = dev->mtl_device.maxBufferLength; From 2307712d32a17becc38f6efb8154be5196aa8f87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 26 May 2026 05:05:51 +0200 Subject: [PATCH 702/831] CUDA: missing PDL sync for FWHT, better fallback (llama/23690) --- ggml/src/ggml-cuda/fwht.cu | 35 +++++++++++++-------------------- ggml/src/ggml-cuda/fwht.cuh | 3 ++- ggml/src/ggml-cuda/ggml-cuda.cu | 4 +--- 3 files changed, 17 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-cuda/fwht.cu b/ggml/src/ggml-cuda/fwht.cu index 74e94d8442b..184dc254c72 100644 --- a/ggml/src/ggml-cuda/fwht.cu +++ b/ggml/src/ggml-cuda/fwht.cu @@ -19,6 +19,7 @@ __global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows, float reg[el_w]; const int lane = threadIdx.x; + ggml_cuda_pdl_sync(); #pragma unroll for (int i = 0; i < el_w; ++i) { reg[i] = src[i * warp_size + lane] * scale; @@ -57,10 +58,11 @@ __global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows, } } -void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { +bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_shape(src, dst)); - GGML_ASSERT(ggml_is_contiguous(src)); - GGML_ASSERT(ggml_is_contiguous(dst)); + if (!ggml_is_contiguous(src) || !ggml_is_contiguous(dst)) { + return false; + } const int n = src->ne[0]; const int64_t rows = ggml_nrows(src); @@ -68,7 +70,6 @@ void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, float * dst_d = (float *) dst->data; const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; - GGML_ASSERT(n % warp_size == 0); const int rows_per_block = 4; const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block; @@ -83,26 +84,18 @@ void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, switch (n) { case 64: - { - ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale); - break; - } + ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale); + return true; case 128: - { - ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale); - break; - } + ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale); + return true; case 256: - { - ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale); - break; - } + ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale); + return true; case 512: - { - ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale); - break; - } + ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale); + return true; default: - GGML_ABORT("fatal error"); + return false; } } diff --git a/ggml/src/ggml-cuda/fwht.cuh b/ggml/src/ggml-cuda/fwht.cuh index fa4c30477a7..cf3df94cafa 100644 --- a/ggml/src/ggml-cuda/fwht.cuh +++ b/ggml/src/ggml-cuda/fwht.cuh @@ -1,3 +1,4 @@ #include "common.cuh" -void ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst); +// Returns whether the Fast Walsh-Hadamard transform could be used. +bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 1bb09ac80ee..23d1c069248 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2596,9 +2596,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32; const int32_t hint = ggml_get_op_params_i32(dst, 1); - if (hint == GGML_HINT_SRC0_IS_HADAMARD) { - GGML_ASSERT(!split); - ggml_cuda_op_fwht(ctx, src1, dst); + if (hint == GGML_HINT_SRC0_IS_HADAMARD && !split && ggml_cuda_op_fwht(ctx, src1, dst)) { return; } From bc77933c2de8c5d104bfc234e3358217265fe147 Mon Sep 17 00:00:00 2001 From: Nikhil Jain Date: Mon, 25 May 2026 20:32:49 -0700 Subject: [PATCH 703/831] Check batch_compute_passes before sending passes when not doing GPU profiling (llama/23457) * Only run webgpu CI on my fork * Add webgpu only workflow * refactor batch_compute_passes to a per-thread variable, and submit individual passes when it is set to false and no GPU profiling is enabled * restore build.yml --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 35 +++++++++++++++++----------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 921c12b41ac..1561a4e30c6 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -259,6 +259,7 @@ struct webgpu_context_struct { wgpu::Buffer set_rows_host_error_buf; wgpu::CommandEncoder active_command_encoder; wgpu::ComputePassEncoder active_compute_pass; + bool batch_compute_passes = true; size_t memset_bytes_per_thread; @@ -590,9 +591,18 @@ static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & } #else for (size_t i = 0; i < dispatches.size(); i++) { - ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline); - ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]); - ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); + if (ctx->batch_compute_passes) { + ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline); + ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]); + ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, + 1); + } else { + wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(); + pass.SetPipeline(dispatches[i].pipeline.pipeline); + pass.SetBindGroup(0, bind_groups[i]); + pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); + pass.End(); + } } #endif @@ -1956,10 +1966,10 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector reduce_entries; if (use_vec_reduce) { const uint32_t reduce_sg_size = ctx->global_ctx->capabilities.max_subgroup_size; - const uint32_t reduce_wg_size = - std::max(reduce_sg_size, (uint32_t) std::min( - (uint64_t) nwg * reduce_sg_size, - ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + const uint32_t reduce_wg_size = std::max( + reduce_sg_size, + (uint32_t) std::min((uint64_t) nwg * reduce_sg_size, + ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; reduce_shader_ctx.max_wg_size = reduce_wg_size; reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); @@ -3110,18 +3120,16 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str uint32_t num_batched_kernels = 0; uint32_t num_inflight_batches = 0; bool contains_set_rows = false; - bool batch_compute_passes = true; int num_encoded_ops = 1; int node_idx = 0; #ifdef GGML_WEBGPU_GPU_PROFILE ctx->profile_timestamp_query_count = 0; - batch_compute_passes = false; std::vector profile_pipeline_names; #endif ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); - if (batch_compute_passes) { + if (ctx->batch_compute_passes) { ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); } @@ -3148,7 +3156,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str // reset state for next batch ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); - if (batch_compute_passes) { + if (ctx->batch_compute_passes) { ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); } ctx->param_arena.reset(); @@ -3548,8 +3556,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer const uint32_t kv_tile = decisions.kv_tile; const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { nwg <<= 1; } @@ -3839,6 +3847,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); #ifdef GGML_WEBGPU_GPU_PROFILE + webgpu_ctx->batch_compute_passes = false; ggml_webgpu_create_buffer( webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_dev_buf, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, "profile_timestamp_dev_buf"); From 00a5110b1945a6144b7f5f766da98d03c84bcec4 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Tue, 26 May 2026 12:42:49 +0900 Subject: [PATCH 704/831] ggml-webgpu: Add MMVQ path for Q4/Q8/Q2_K/Q4_K and clean up legacy MUL_MAT pipeline (llama/23594) * ggml-webgpu: Add MMVQ path for Q4/Q8/Q2_K/Q4_K * Fix to editorconfig checking pass * Remove mul-mat-legacy pipeline * Fix to use vendor name as is and add dot_product/vendor to shader_lib_ctx --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 231 +++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 165 ++-- .../wgsl-shaders/common_decls.tmpl | 3 +- .../src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl | 747 ------------------ .../ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl | 22 +- .../wgsl-shaders/mul_mat_vec_acc.tmpl | 1 - .../wgsl-shaders/mul_mat_vec_q_acc.tmpl | 303 +++++++ .../ggml-webgpu/wgsl-shaders/quantize_q8.wgsl | 173 ++++ 8 files changed, 714 insertions(+), 931 deletions(-) delete mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 4c4eda1cbe5..60e98a60741 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -52,7 +52,7 @@ #define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4 #define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4 -// default size for legacy matrix multiplication +// default size for reg-tile matrix multiplication #define WEBGPU_MUL_MAT_WG_SIZE 256 // Same hash combine function as in boost @@ -93,6 +93,8 @@ struct ggml_webgpu_shader_lib_context { uint32_t sg_mat_k = 0; uint32_t min_subgroup_size = 0; uint32_t max_subgroup_size = 0; + bool supports_dot_product = false; + std::string vendor; }; struct webgpu_pipeline { @@ -850,31 +852,15 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( /** Matrix Multiplication **/ -struct ggml_webgpu_legacy_mul_mat_pipeline_key { - ggml_type src0_type; - ggml_type src1_type; - - bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const { - return src0_type == other.src0_type && src1_type == other.src1_type; - } -}; - -struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash { - size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.src0_type); - ggml_webgpu_hash_combine(seed, key.src1_type); - return seed; - } -}; - struct ggml_webgpu_mul_mat_vec_pipeline_key { ggml_type src0_type; ggml_type src1_type; int vectorized; + bool use_mmvq; bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { - return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized; + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && + use_mmvq == other.use_mmvq; } }; @@ -884,6 +870,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.use_mmvq); return seed; } }; @@ -894,6 +881,20 @@ struct ggml_webgpu_mul_mat_vec_shader_decisions { uint32_t vec_size; }; +struct ggml_webgpu_quantize_q8_pipeline_key { + ggml_type src0_type; + + bool operator==(const ggml_webgpu_quantize_q8_pipeline_key & other) const { return src0_type == other.src0_type; } +}; + +struct ggml_webgpu_quantize_q8_pipeline_key_hash { + size_t operator()(const ggml_webgpu_quantize_q8_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + return seed; + } +}; + struct ggml_webgpu_mul_mat_pipeline_key { ggml_type src0_type; ggml_type src1_type; @@ -1051,6 +1052,36 @@ struct ggml_webgpu_soft_max_pipeline_key_hash { } }; +/** MMVQ **/ + +inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0, + const ggml_tensor * src1, + bool supports_dot_product, + const std::string & vendor) { + if (src1->ne[1] == 1) { + bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia"; + if (supports_dp4a && supports_dot_product) { + switch (src1->type) { + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q4_K: + return src0->ne[0] % 4 == 0; + default: + break; + } + break; + default: + break; + } + } + } + return false; +} + class ggml_webgpu_shader_lib { wgpu::Device device; pre_wgsl::Preprocessor preprocessor; @@ -1099,14 +1130,12 @@ class ggml_webgpu_shader_lib { webgpu_pipeline, ggml_webgpu_flash_attn_blk_pipeline_key_hash> flash_attn_blk_pipelines; - std::unordered_map - mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec) std::unordered_map mul_mat_vec_pipelines; // fast mat-vec (n==1) std::unordered_map mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + std::unordered_map + quantize_q8_pipelines; std::unordered_map mul_mat_id_gather_pipelines; // key is fixed std::unordered_map mul_mat_id_pipelines; // src0_type/src1_type @@ -1631,7 +1660,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { @@ -1744,6 +1773,44 @@ class ggml_webgpu_shader_lib { return pad_pipelines[key]; } + webgpu_pipeline get_quantize_q8_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_quantize_q8_pipeline_key key = {}; + key.src0_type = context.src0->type; + + auto it = quantize_q8_pipelines.find(key); + if (it != quantize_q8_pipelines.end()) { + return it->second; + } + const char * shader_src = wgsl_quantize_q8; + std::vector defines; + std::string variant = "quantize_q8"; + + uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; + + defines.push_back("SRC1_INNER_TYPE=f32"); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + std::string src0_name = src0_traits->type_name; + std::string type_upper = src0_name; + variant += "_" + src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("Q8_1_T"); + + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; + + auto processed = preprocessor.preprocess(shader_src, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + quantize_q8_pipelines[key] = pipeline; + return quantize_q8_pipelines[key]; + } + webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_mul_mat_vec_pipeline_key key = {}; key.src0_type = context.src0->type; @@ -1752,6 +1819,8 @@ class ggml_webgpu_shader_lib { (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0; + key.use_mmvq = + ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); auto it = mul_mat_vec_pipelines.find(key); if (it != mul_mat_vec_pipelines.end()) { @@ -1788,6 +1857,19 @@ class ggml_webgpu_shader_lib { defines.push_back("U32_DEQUANT_HELPERS"); defines.push_back("SRC0_INNER_TYPE=u32"); switch (context.src0->type) { + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + if (key.use_mmvq) { + defines.push_back("LEGACY_QUANTS"); + } + break; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q4_K: + if (key.use_mmvq) { + defines.push_back("K_QUANTS"); + } + break; case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ2_S: @@ -1840,6 +1922,11 @@ class ggml_webgpu_shader_lib { outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; } + if (key.use_mmvq) { + defines.push_back("MMVQ"); + defines.push_back("Q8_1_T"); + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); @@ -2018,100 +2105,6 @@ class ggml_webgpu_shader_lib { return mul_mat_fast_pipelines[key]; } - webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_legacy_mul_mat_pipeline_key key = {}; - key.src0_type = context.src0->type; - key.src1_type = context.src1->type; - - auto it = mul_mat_legacy_pipelines.find(key); - if (it != mul_mat_legacy_pipelines.end()) { - return it->second; - } - - std::vector defines; - std::string variant = "mul_mat"; - - switch (context.src1->type) { - case GGML_TYPE_F32: - defines.push_back("SRC1_TYPE=f32"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("SRC1_TYPE=f16"); - variant += "_f16"; - break; - default: - GGML_ABORT("Unsupported src1 type for mul_mat legacy shader"); - } - - const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); - const char * src0_name = src0_traits->type_name; - - switch (context.src0->type) { - case GGML_TYPE_F32: - defines.push_back("SRC0_TYPE=f32"); - defines.push_back("FLOAT"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("SRC0_TYPE=f16"); - defines.push_back("FLOAT"); - variant += "_f16"; - break; - default: - { - std::string type_upper = src0_name; - std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); - - switch (context.src0->type) { - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q6_K: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_MXFP4: - { - // Quantized types using u32 buffers for portability. - defines.push_back("SRC0_TYPE=u32"); - defines.push_back("U32_DEQUANT_HELPERS"); - break; - } - default: - { - defines.push_back(std::string("SRC0_TYPE=") + src0_name); - } - } - - defines.push_back("BYTE_HELPERS"); - defines.push_back(type_upper + "_T"); - defines.push_back(type_upper); - defines.push_back(type_upper + "_SCALE_MIN"); - defines.push_back(type_upper + "_TABLES"); - defines.push_back(type_upper + "_GRID"); - - variant += std::string("_") + src0_name; - break; - } - } - - auto processed = preprocessor.preprocess(wgsl_mul_mat, defines); - - auto decisions = std::make_shared(); - decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE; - - webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); - pipeline.context = decisions; - mul_mat_legacy_pipelines[key] = pipeline; - return mul_mat_legacy_pipelines[key]; - } - webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) { auto it = mul_mat_id_gather_pipelines.find(1); if (it != mul_mat_id_gather_pipelines.end()) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1561a4e30c6..f113da909ce 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -181,6 +181,7 @@ struct webgpu_capabilities { wgpu::Limits limits; bool supports_subgroups = false; bool supports_subgroup_matrix = false; + bool supports_dot_product = false; uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; @@ -210,6 +211,8 @@ struct webgpu_global_context_struct { wgpu::Buffer memset_params_buf; webgpu_pipeline memset_pipeline; + std::string vendor; + // TODO: We should rework the CPU profiling time handling to make it more useful. ref: https://github.com/ggml-org/llama.cpp/pull/22050 #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) @@ -1394,6 +1397,58 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } +static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst, + std::vector & dispatches) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + + webgpu_pipeline qq8_pipeline = ctx->shader_lib->get_quantize_q8_pipeline(shader_lib_ctx); + + // quantize_q8 pipeline + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + const size_t q8_src1_align_offset = ROUNDUP_POW2( + dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t q8_src1_binding_size = + ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), + WEBGPU_STORAGE_BUF_BINDING_MULT); + + std::vector q8_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], + }; + + std::vector q8_entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), q8_src1_align_offset, q8_src1_binding_size) + }; + + auto q8_decisions = static_cast(qq8_pipeline.context.get()); + + uint32_t q8_wg_size = q8_decisions->wg_size; + uint32_t q8_wg_x = 1; + uint32_t q8_wg_y = 1; + const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size; + const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec; + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y); + + dispatches.push_back({ + qq8_pipeline, std::move(q8_params), std::move(q8_entries), { q8_wg_x, q8_wg_y } + }); +} + static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, @@ -1401,47 +1456,9 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, // Determine if this is a mat-vec operation bool is_vec = (dst->ne[1] == 1); - // Determine if we should use fast path - bool use_fast = false; - switch (src1->type) { - case GGML_TYPE_F16: - use_fast = (src0->type == GGML_TYPE_F16); - break; - case GGML_TYPE_F32: - // TODO: implement better mat-mat for k-quants, mat-vec for all k-quants except q6_K - switch (src0->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q6_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q1_0: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_MXFP4: - use_fast = true; - break; - default: - break; - } - break; - default: - break; - } + // use MMVQ path for mat-vec + bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product, + ctx->global_ctx->vendor); ggml_webgpu_shader_lib_context shader_lib_ctx = {}; @@ -1456,16 +1473,20 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + shader_lib_ctx.supports_dot_product = ctx->global_ctx->capabilities.supports_dot_product; + shader_lib_ctx.vendor = ctx->global_ctx->vendor; // Get or create pipeline - webgpu_pipeline pipeline; + webgpu_pipeline pipeline; + std::vector dispatches; - if (use_fast && is_vec) { + if (is_vec) { + if (use_mmvq) { + ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches); + } pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx); - } else if (use_fast) { - pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx); } else { - pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx); + pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx); } // Build params @@ -1489,25 +1510,31 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, }; // Build bind group entries - std::vector entries = { - ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), - }; + std::vector entries = {}; + + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); + if (use_mmvq) { + auto & mmvq_qq8_entry = dispatches[0].bind_group_entries[1]; + entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), mmvq_qq8_entry.offset, + mmvq_qq8_entry.size)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + } + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); // Calculate workgroup dimensions uint32_t wg_x = 1; uint32_t wg_y = 1; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - if (use_fast && is_vec) { + if (is_vec) { auto * decisions = static_cast(pipeline.context.get()); uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); uint32_t total_wg = output_groups * batches; compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); - } else if (use_fast) { + } else { auto * decisions = static_cast(pipeline.context.get()); // Fast-path tiled/subgroup calculations @@ -1528,15 +1555,13 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, } uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3]; compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); - - } else { // legacy - auto * decisions = static_cast(pipeline.context.get()); - uint32_t wg_size = decisions->wg_size; - uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); - compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); + dispatches.push_back({ + pipeline, std::move(params), std::move(entries), { wg_x, wg_y } + }); + + return ggml_backend_webgpu_build_multi(ctx, dispatches); } static webgpu_encoded_op ggml_webgpu_mul_mat_id_vec(webgpu_context & ctx, @@ -3590,6 +3615,22 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer } } break; + case GGML_OP_MUL_MAT: + { + const ggml_tensor * src0 = tensor->src[0]; + const ggml_tensor * src1 = tensor->src[1]; + bool use_mmvq = + ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product, + ctx->webgpu_global_ctx->vendor); + if (use_mmvq) { + const size_t q8_src1_size = + src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)); + res = ROUNDUP_POW2(res + q8_src1_size + + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, + WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + break; case GGML_OP_MUL_MAT_ID: { const ggml_tensor * src0 = tensor->src[0]; @@ -3715,12 +3756,16 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->adapter.GetInfo(&info); ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(); ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(); + ctx->webgpu_global_ctx->vendor = info.vendor; wgpu::SupportedFeatures features; ctx->webgpu_global_ctx->adapter.GetFeatures(&features); // we require f16 support GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); ctx->webgpu_global_ctx->capabilities.supports_subgroups = ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups); + // for dot4I8packed + ctx->webgpu_global_ctx->capabilities.supports_dot_product = ctx->webgpu_global_ctx->instance.HasWGSLLanguageFeature( + wgpu::WGSLLanguageFeatureName::Packed4x8IntegerDotProduct); bool valid_subgroup_matrix_config = false; #ifndef __EMSCRIPTEN__ diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 372ea79bf9d..758efa17d77 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -95,11 +95,10 @@ struct q5_1 { }; #endif - #ifdef Q8_1_T struct q8_1 { d: f16, - m: f16, + s: f16, // d * sum(qs[i]) qs: array }; #endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl deleted file mode 100644 index fcbefdeb802..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ /dev/null @@ -1,747 +0,0 @@ -enable f16; - -#define DECLARE_BYTE_LOADERS_SRC0 -#include "common_decls.tmpl" - - -#ifdef FLOAT -const BLOCK_SIZE = 1u; - -#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL) -const BLOCK_SIZE = 32u; - -#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS) -const BLOCK_SIZE = 256u; -#endif - -#ifdef FLOAT -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]); -} -#endif - -#ifdef Q4_0 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_byte_offset = block_byte_base + 2 + j * 4; - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q4_1 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q4_1 = src0[src0_idx_base + offset]; - let d = f32(block_q4_1.d); - let m = f32(block_q4_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_packed = block_q4_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32((q_byte >> 4) & 0xF) * d + m; - let q_lo = f32(q_byte & 0xF) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q5_0 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 22; // Block stride: 22 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var sum: f32 = 0.0; - let qh_packed = load_u32_at_src0(block_byte_base + 2); - for (var j: u32 = 0; j < 4; j++) { - let q_byte_offset = block_byte_base + 6 + j * 4; - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; - let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; - let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; - let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q5_1 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q5_1 = src0[src0_idx_base + offset]; - let d = f32(block_q5_1.d); - let m = f32(block_q5_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_packed = block_q5_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10; - let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m; - let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10; - let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#endif - -#ifdef Q8_0 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 34; // Block stride: 34 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 8; j++) { - let q_byte_offset = block_byte_base + 2 + j * 4; - let q_packed = load_u32_at_src0(q_byte_offset); - for (var k: u32 = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_val * f32(src1[src1_offset]); - } - } - return sum; -} -#endif - -#ifdef Q8_1 -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q8_1 = src0[src0_idx_base + offset]; - let d = f32(block_q8_1.d); - let m = f32(block_q8_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 8; j++) { - let q_packed = block_q8_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_val * f32(src1[src1_offset]); - } - } - return sum; -} -#endif - -#ifdef Q2_K -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - // 2 halves of the block (128 elements each) - for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { - // 4 groups (each group has 2 blocks of 16 elements) - for (var shift: u32 = 0; shift < 8; shift += 2) { - // 2 blocks - for (var k: u32 = 0; k < 32; k += 16) { - let sc = get_byte(block.scales[is / 4], is % 4); - is++; - let dl = d * f32(sc & 0xF); - let ml = m * f32(sc >> 4); - for (var l: u32 = 0u; l < 16; l++) { - let q_idx = q_b_idx + k + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qs_val = (q_byte >> shift) & 3; - sum += (f32(qs_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - } - } - } - return sum; -} -#endif - -#ifdef Q3_K -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes - - // Bytes 108-109: f16 scale 'd' - let d = load_f16_as_f32_at_src0(block_byte_base + 108); - - // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, - // and 2-bits from the last 4 bytes - // Bytes 96-107: 12 bytes of scales (3 u32s) - let kmask1: u32 = 0x03030303; - let kmask2: u32 = 0x0f0f0f0f; - var scale_vals: array; - scale_vals[0] = load_u32_at_src0(block_byte_base + 96); - scale_vals[1] = load_u32_at_src0(block_byte_base + 100); - scale_vals[2] = load_u32_at_src0(block_byte_base + 104); - - var tmp: u32 = scale_vals[2]; - scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); - scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - - // Bytes 0-31: 32 bytes of hmask (8 u32s) - var hmask_vals: array; - for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = load_u32_at_src0(block_byte_base + i * 4); - } - - // Bytes 32-95: 64 bytes of qs (16 u32s) - var qs_vals: array; - for (var i: u32 = 0u; i < 16; i++) { - qs_vals[i] = load_u32_at_src0(block_byte_base + 32 + i * 4); - } - - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - var m: u32 = 1; - // 2 halves of the block (128 elements each) - for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { - // 4 groups (each group has 2 blocks of 16 elements) - for (var shift: u32 = 0; shift < 8; shift += 2) { - // 2 blocks - for (var k: u32 = 0; k < 32; k += 16) { - let sc = get_byte(scale_vals[is / 4], is % 4); - is++; - let dl = d * (f32(sc) - 32.0); - for (var l: u32 = 0u; l < 16u; l++) { - let q_idx = q_b_idx + k + l; - let hm_idx = k + l; - let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); - let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); - let hm = select(4.0, 0.0, (hmask_byte & m) != 0); - let qs_val = (q_byte >> shift) & 3; - sum += ((f32(qs_val) - hm) * dl) * src1[src1_i]; - src1_i++; - } - } - m <<= 1; - } - } - return sum; -} -#endif - -#ifdef Q4_K -// 8 blocks of 32 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - // 2 blocks each iteration - for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { - for (var shift: u32 = 0; shift < 8; shift += 4) { - let scale_min = get_scale_min(is, block.scales); - is++; - let dl = d * scale_min.x; - let ml = m * scale_min.y; - for (var l: u32 = 0; l < 32; l++) { - let q_idx = q_b_idx + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qs_val = (q_byte >> shift) & 0xF; - sum += (f32(qs_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef Q5_K -// 8 blocks of 32 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - var u: u32 = 1; - // 2 blocks each iteration - for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { - for (var shift: u32 = 0; shift < 8; shift += 4) { - let scale_min = get_scale_min(is, block.scales); - is++; - let dl = d * scale_min.x; - let ml = m * scale_min.y; - for (var l: u32 = 0; l < 32; l++) { - let q_idx = q_b_idx + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qh_byte = get_byte(block.qh[l / 4], l % 4); - let qs_val = (q_byte >> shift) & 0xF; - let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); - sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - u <<= 1; - } - } - return sum; -} -#endif - -#ifdef Q6_K -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 210; // Block stride: 210 bytes - - // Bytes 208-209: f16 scale 'd' - let d = load_f16_as_f32_at_src0(block_byte_base + 208); - - // Bytes 0-127: 128 bytes of ql (32 u32s) - var ql_vals: array; - for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = load_u32_at_src0(block_byte_base + i * 4); - } - - // Bytes 128-191: 64 bytes of qh (16 u32s) - var qh_vals: array; - for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = load_u32_at_src0(block_byte_base + 128 + i * 4); - } - - // Bytes 192-207: 16 bytes of scales (4 u32s) - var scale_vals: array; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = load_u32_at_src0(block_byte_base + 192 + i * 4); - } - - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var qh_b_idx: u32 = 0; - var sc_b_idx: u32 = 0; - for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) { - for (var l: u32 = 0; l < 32; l++) { - let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4); - let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4); - let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4); - - let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0; - let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0; - let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0; - let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0; - - let is = l/16; - let is1 = sc_b_idx + is; - let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4); - let is2 = sc_b_idx + is + 2; - let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4); - let is3 = sc_b_idx + is + 4; - let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4); - let is4 = sc_b_idx + is + 6; - let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4); - - sum += d * f32(sc1) * q1 * src1[src1_i + l]; - sum += d * f32(sc2) * q2 * src1[src1_i + l + 32]; - sum += d * f32(sc3) * q3 * src1[src1_i + l + 64]; - sum += d * f32(sc4) * q4 * src1[src1_i + l + 96]; - } - src1_i += 128; - qh_b_idx += 32; - sc_b_idx += 8; - } - return sum; -} -#endif - -#ifdef IQ2_XXS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 66; // Block stride: 66 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0_offset = block_byte_base + 2 + ib * 2; - let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; - let aux0 = load_u32_at_src0(aux0_offset); - let aux1 = load_u32_at_src0(aux1_offset); - let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; - for (var l: u32 = 0; l < 4; l++) { - let ig = get_byte(aux0, l) * 8; - let is = (aux1 >> (7 * l)) & 127; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += db * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ2_XS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 74; // Block stride: 74 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - - var scale_vals = array( - load_u32_at_src0(block_byte_base + 66), - load_u32_at_src0(block_byte_base + 70) - ); - - var sum = 0.0; - for (var ib: u32 = 0; ib < 32; ib += 4) { - let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); - let db = array( - d * (0.5 + f32(s & 0xF)) * 0.25, - d * (0.5 + f32(s >> 4)) * 0.25 - ); - for (var l: u32 = 0; l < 4; l++) { - let qs_offset = block_byte_base + 2 + (ib + l) * 2; - let qs_val = load_u32_at_src0(qs_offset) & 0xFFFF; - let ig = (qs_val & 511) * 8; - let is = qs_val >> 9; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let dl = db[l/2]; - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += dl * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ2_S -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 82; // Block stride: 82 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - - var qs_vals : array; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = load_u32_at_src0(block_byte_base + 2 + i * 4); - } - - var qh_vals: array; - qh_vals[0] = load_u32_at_src0(block_byte_base + 66); - qh_vals[1] = load_u32_at_src0(block_byte_base + 70); - - var scale_vals: array; - scale_vals[0] = load_u32_at_src0(block_byte_base + 74); - scale_vals[1] = load_u32_at_src0(block_byte_base + 78); - - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib ++) { - let s = get_byte(scale_vals[ib / 4], ib % 4); - let db = array( - d * (0.5 + f32(s & 0xF)) * 0.25, - d * (0.5 + f32(s >> 4)) * 0.25 - ); - let qs_w = qs_vals[ib]; - for (var l: u32 = 0; l < 4; l++) { - let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300; - let ig = (get_byte(qs_w, l) | qh_b) * 8; - let signs = get_byte(qs_vals[ib + 8], l); - let dl = db[l/2]; - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += dl * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ3_XXS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 98; // Block stride: 98 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; - let sc_sign = load_u32_at_src0(sc_sign_offset); - let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; - for (var l: u32 = 0; l < 4; l++) { - let is = (sc_sign >> (7 * l)) & 127; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; - let ig1 = get_byte(ig_val, 0); - let ig2 = get_byte(ig_val, 1); - for (var j: u32 = 0; j < 4; j++) { - let g1 = get_byte(iq3xxs_grid[ig1], j); - let g2 = get_byte(iq3xxs_grid[ig2], j); - let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); - let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); - sum += db * f32(g1) * m1 * src1[src1_i]; - sum += db * f32(g2) * m2 * src1[src1_i + 4]; - src1_i++; - } - src1_i += 4; - } - } - return sum; -} -#endif - -#ifdef IQ3_S -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 110; // Block stride: 110 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - - var qh_vals = array( - load_u32_at_src0(block_byte_base + 66), - load_u32_at_src0(block_byte_base + 70) - ); - - var sign_vals: array; - for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = load_u32_at_src0(block_byte_base + 74 + i * 4); - } - - var scale_vals = load_u32_at_src0(block_byte_base + 106); - - var sum = 0.0; - for (var ib: u32 = 0; ib < 4; ib++) { - let s = get_byte(scale_vals, ib); - let db = array( - d * (1.0 + 2.0 * f32(s & 0xF)), - d * (1.0 + 2.0 * f32(s >> 4)) - ); - for (var k: u32 = 0; k < 2; k++) { - let dl = db[k]; - let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k); - let sign_w = sign_vals[ib * 2 + k]; - for (var l: u32 = 0; l < 4; l++) { - let signs = get_byte(sign_w, l); - let ig_val = load_u32_at_src0(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; - let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); - let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); - for (var j: u32 = 0; j < 4; j++) { - let g1 = get_byte(iq3s_grid[ig1], j); - let g2 = get_byte(iq3s_grid[ig2], j); - let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); - let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); - sum += dl * f32(g1) * m1 * src1[src1_i]; - sum += dl * f32(g2) * m2 * src1[src1_i + 4]; - src1_i++; - } - src1_i += 4; - } - } - } - return sum; -} -#endif - -#ifdef IQ1_S -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 50; // Block stride: 50 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let qh = load_u32_at_src0(block_byte_base + 34 + ib * 2) & 0xFFFF; - let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); - let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = load_u32_at_src0(block_byte_base + 2 + ib * 4); - for (var l: u32 = 0; l < 4; l++) { - let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; - for (var j: u32 = 0; j < 8; j++) { - let gw = iq1_grid[(ig + j) / 16]; - let g = (gw >> (((ig + j) % 16) * 2)) & 3; - let gs = bitcast(g << 30) >> 30; - sum += dl * (f32(gs) + delta) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - - -#ifdef IQ1_M -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - - let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000); - let d = f32(bitcast>(scale).x); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF; - let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7; - let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7; - var dl = array( - d * f32(2 * s1 + 1), - d * f32(2 * s2 + 1) - ); - - let qh = block.qh[ib / 2] >> (16 * (ib % 2)); - var idx = array( - get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700), - get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700), - get_byte(block.qs[ib], 2) | ((qh) & 0x700), - get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700) - ); - var delta = array( - select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0), - select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0), - select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0), - select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0) - ); - for (var l: u32 = 0; l < 4; l++) { - let ig = idx[l] * 8; - for (var j: u32 = 0; j < 8; j++) { - let gw = iq1_grid[(ig + j) / 16]; - let g = (gw >> (((ig + j) % 16) * 2)) & 3; - let gs = bitcast(g << 30) >> 30; - sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} -#endif - -#ifdef IQ4_NL -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_byte_base = (src0_idx_base + offset) * 18; // Block stride: 18 bytes - let d = load_f16_as_f32_at_src0(block_byte_base); - var src1_i = src1_idx_base + offset * 32; - var sum = 0.0; - var qs: array; - for (var i: u32 = 0; i < 4; i++) { - qs[i] = load_u32_at_src0(block_byte_base + 2 + i * 4); - } - for (var j: u32 = 0; j < 16; j++) { - let qsb = get_byte(qs[j / 4], j % 4); - sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; - sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; - src1_i++; - } - return sum; -} -#endif - -#ifdef IQ4_XS -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = unpack2x16float(block.d_scales_h)[0]; - let scales_h = block.d_scales_h >> 16; - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); - let dl = d * (f32(ls) - 32.0); - for (var j: u32 = 0; j < 16; j++) { - let iqs = ib * 16 + j; - let qsb = get_byte(block.qs[iqs / 4], iqs % 4); - sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; - sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; - src1_i++; - } - src1_i += 16; - } - return sum; -} -#endif - -struct MulMatParams { - offset_src0: u32, // in elements/blocks - offset_src1: u32, // in elements/blocks - offset_dst: u32, // in elements/blocks - m: u32, - n: u32, - k: u32, - // all strides are in elements/blocks - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var src0: array; // M rows, K columns -@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array; // M rows, N columns - -@group(0) @binding(3) var params: MulMatParams; - -@compute @workgroup_size(256) -fn main(@builtin(local_invocation_id) local_id: vec3, - @builtin(workgroup_id) wg_id: vec3, - @builtin(num_workgroups) num_wg: vec3) { - let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let global_idx = wg_linear * 256u + local_id.x; - - let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - if (global_idx >= total) { - return; - } - - let dst2_stride = params.m * params.n; - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - - let dst3_idx = global_idx / dst3_stride; - let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension - let src13_idx = dst3_idx; // src1 is not broadcast - let dst3_rem = global_idx % dst3_stride; - - let dst2_idx = dst3_rem / dst2_stride; - let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension - let src12_idx = dst2_idx; // src1 is not broadcast - - let dst2_rem = dst3_rem % dst2_stride; - - let row = dst2_rem / params.m; // output row - let col = dst2_rem % params.m; // output column - - let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; - - var sum = 0.0; - for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) { - sum += multiply_add(src0_idx_base, src1_idx_base, i); - } - dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index a194cf40468..f0a7fbd059a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -3,10 +3,18 @@ enable subgroups; #endif enable f16; +#ifdef MMVQ +requires packed_4x8_integer_dot_product; +#endif + #define DECLARE_BYTE_LOADERS_SRC0 #include "common_decls.tmpl" +#ifdef MMVQ +#include "mul_mat_vec_q_acc.tmpl" +#else #include "mul_mat_vec_acc.tmpl" +#endif struct MulMatParams { offset_src0: u32, @@ -28,9 +36,14 @@ struct MulMatParams { }; @group(0) @binding(0) var src0: array; + +#ifdef MMVQ +@group(0) @binding(1) var src1q: array; +#else @group(0) @binding(1) var src1: array; -@group(0) @binding(2) var dst: array; +#endif +@group(0) @binding(2) var dst: array; // "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01 @group(0) @binding(3) var params: MulMatParams; @@ -75,10 +88,15 @@ fn main( let src12_idx = dst2_idx; let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; +#ifdef MMVQ + let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u); + let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base); +#else + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); +#endif #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl index 711c7e829d8..08753b9d643 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -436,7 +436,6 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src } #endif - #ifdef MUL_ACC_Q3_K #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 110 diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl new file mode 100644 index 00000000000..3ef2f77ebe0 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl @@ -0,0 +1,303 @@ +#ifdef U32_DEQUANT_HELPERS +#define SRC0_TYPE u32 + +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} +#endif + +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE + +#ifdef LEGACY_QUANTS +#define BLOCK_SIZE 32 +#define THREADS_PER_BLOCK 4 +#elif K_QUANTS +#define BLOCK_SIZE 256 +#define THREADS_PER_BLOCK 16 +#endif + +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +#define Q8_BLOCK_SIZE 32 + +#ifdef MUL_ACC_Q4_0 +#define BLOCK_SIZE_BYTES 18 +#define B_DS_TYPE vec2 +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2 { + let qs_packed = load_u32_at_src0(block_byte_base + 2u + 4u * inner_id); + + return vec2( + qs_packed & 0x0F0F0F0Fu, + (qs_packed >> 4u) & 0x0F0F0F0Fu + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2 { + return vec2( + src1q[block].qs[inner_id], + src1q[block].qs[inner_id + 4u], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[block].d), + f32(src1q[block].s) + ); +} +fn get_dm(block_byte_base: u32) -> f32 { + return f32(load_f16_at_src0(block_byte_base)); +} +fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; +} +#endif + +#ifdef MUL_ACC_Q4_1 +#define BLOCK_SIZE_BYTES 20 +#define B_DS_TYPE vec2 +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2 { + let qs_packed = load_u32_at_src0(block_byte_base + 4u + 4u * inner_id); + + return vec2( + qs_packed & 0x0F0F0F0Fu, + (qs_packed >> 4u) & 0x0F0F0F0Fu + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2 { + return vec2( + src1q[block].qs[inner_id], + src1q[block].qs[inner_id + 4u], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[block].d), + f32(src1q[block].s) + ); +} +fn get_dm(block_byte_base: u32) -> vec2 { + return vec2( + f32(load_f16_at_src0(block_byte_base)), + f32(load_f16_at_src0(block_byte_base + 2u)) + ); +} +fn mul_q8_1(row_sum: i32, dma: vec2, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK; +} +#endif + +#ifdef MUL_ACC_Q8_0 +#define BLOCK_SIZE_BYTES 34 +#define B_DS_TYPE f32 +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2 { + return vec2( + load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u)), + load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u + 1)) + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2 { + return vec2( + src1q[block].qs[inner_id * 2u], + src1q[block].qs[inner_id * 2u + 1], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE(src1q[block].d); +} +fn get_dm(block_byte_base: u32) -> f32 { + return f32(load_f16_at_src0(block_byte_base)); +} +fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (da * b_ds); +} +#endif + +#ifdef LEGACY_QUANTS +fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2, b_ds: B_DS_TYPE) -> f32 { + var row_sum = 0; + let a_repacked = repack_a(a_byte_base, b_inner_id); + + row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]); + row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]); + + return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds); +} + +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { + var acc: array; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let b_inner_id = thread_id % THREADS_PER_BLOCK; + let b_block_idx = src1q_idx_base + block; + + let b_repacked = repack_b_qs(b_block_idx, b_inner_id); + let b_ds = repack_b_dm(b_block_idx); + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q2_K +#define BLOCK_SIZE_BYTES 84 +#define B_DS_TYPE f32 +fn repack_a(block_byte_base: u32, tid: u32) -> vec4 { + let ih2 = tid / 8u; + let phase = tid % 2u; + let iq4_idx = 2u * ih2 + phase; + let qs_byte_base = block_byte_base + 16u + 16u * iq4_idx; + let qs_shift = tid & 6u; + return vec4( + (load_u32_at_src0_aligned(qs_byte_base) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 4u) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 8u) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 12u) >> qs_shift) & 0x03030303u, + ); +} +fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4 { + let phase = tid % 2u; + return vec4( + src1q[q8_block_idx].qs[4u * phase], + src1q[q8_block_idx].qs[4u * phase + 1u], + src1q[q8_block_idx].qs[4u * phase + 2u], + src1q[q8_block_idx].qs[4u * phase + 3u], + ); +} +fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE { + return B_DS_TYPE(src1q[q8_block_idx].d); +} +fn get_dm(block_byte_base: u32) -> vec2 { + return vec2( + f32(load_f16_at_src0(block_byte_base + 80u)), + f32(load_f16_at_src0(block_byte_base + 82u)), + ); +} +fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { + let scale_byte = block_byte_base + tid; + let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u); + return vec2(f32(scale & 0xFu), f32(scale >> 4u)); +} +fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { + let a_repacked = repack_a(a_byte_base, tid); + let dm = get_dm(a_byte_base); + let scale_min = get_scale_min(a_byte_base, tid); + + let scale_q = i32(scale_min.x); + let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; + + let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) + + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; + let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) + + dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4); + + return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); +} +#endif + +#ifdef MUL_ACC_Q4_K +#define BLOCK_SIZE_BYTES 144 +#define B_DS_TYPE vec2 +fn repack_a(block_byte_base: u32, tid: u32) -> vec4 { + let iq4 = tid / 4u; + let phase = tid % 2u; + let nibble = (tid >> 1u) % 2u; + let q_qs_byte_base = block_byte_base + 16u + 32u * iq4 + 16u * phase; + let qs_shift = 4u * nibble; + return vec4( + (load_u32_at_src0_aligned(q_qs_byte_base) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 4u) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 8u) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 12u) >> qs_shift) & 0x0F0F0F0Fu, + ); +} +fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4 { + let phase = tid % 2u; + return vec4( + src1q[q8_block_idx].qs[4u * phase], + src1q[q8_block_idx].qs[4u * phase + 1u], + src1q[q8_block_idx].qs[4u * phase + 2u], + src1q[q8_block_idx].qs[4u * phase + 3u], + ); +} +fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[q8_block_idx].d), + f32(src1q[q8_block_idx].s), + ); +} +fn get_dm(block_byte_base: u32) -> vec2 { + return vec2( + f32(load_f16_at_src0(block_byte_base + 0u)), + f32(load_f16_at_src0(block_byte_base + 2u)), + ); +} +fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { + let sc_m_idx = tid / 2u; + let scales_byte_base = block_byte_base + 4u; + let scales0_3 = load_u32_at_src0_aligned(scales_byte_base); + let scales4_7 = load_u32_at_src0_aligned(scales_byte_base + 4u); + let scales8_11 = load_u32_at_src0_aligned(scales_byte_base + 8u); + + let byte_idx = sc_m_idx & 3u; + let is_high = sc_m_idx >= 4u; + + let sc_low = byte_of(scales0_3, byte_idx) & 0x3Fu; + let sc_high = (byte_of(scales8_11, byte_idx) & 0x0Fu) | ((byte_of(scales0_3, byte_idx) & 0xC0u) >> 2u); + let scale = f32(select(sc_low, sc_high, is_high)); + + let mn_low = byte_of(scales4_7, byte_idx) & 0x3Fu; + let mn_high = (byte_of(scales8_11, byte_idx) >> 4u) | ((byte_of(scales4_7, byte_idx) & 0xC0u) >> 2u); + let min_val = f32(select(mn_low, mn_high, is_high)); + + return vec2(scale, min_val); +} +fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { + let a_repacked = repack_a(a_byte_base, tid); + let dm = get_dm(a_byte_base); + let scale_min = get_scale_min(a_byte_base, tid); + + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) + + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); + + // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. + return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); +} +#endif + +#ifdef K_QUANTS +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { + var acc: array; + + let tid = thread_id % THREADS_PER_BLOCK; + + for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) { + let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; + let b_repacked = repack_b_qs(src1q_idx, tid); + let b_ds = repack_b_dm(src1q_idx); + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds); + } + } + } + + return acc; +} +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl new file mode 100644 index 00000000000..b3f1fa04b80 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl @@ -0,0 +1,173 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif +enable f16; + +requires packed_4x8_integer_dot_product; + +#include "common_decls.tmpl" + +struct Params { + offset_src1: u32, + stride_12: u32, + stride_13: u32, + ne0: u32, + ne2: u32, + ne3: u32, +}; + +#define SRC1_TYPE vec4 + +@group(0) @binding(0) var src1: array; +@group(0) @binding(1) var src1q: array; + +@group(0) @binding(2) var params: Params; + +#ifdef USE_SUBGROUP_REDUCTION +fn cluster_max_8(v: f32) -> f32 { + var r = v; + r = max(r, subgroupShuffleXor(r, 1u)); + r = max(r, subgroupShuffleXor(r, 2u)); + r = max(r, subgroupShuffleXor(r, 4u)); + return r; +} + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) +fn cluster_add_i4x8(v: i32) -> i32 { + var r= v; + r += subgroupShuffleXor(r, 1u); + r += subgroupShuffleXor(r, 2u); + r += subgroupShuffleXor(r, 4u); + return r; +} +#endif +#endif + +#ifdef USE_WORKGROUP_REDUCTION +#define CLUSTER_SIZE 8 + +var partial_amaxs: array, WG_SIZE / CLUSTER_SIZE>; +var partial_sums: array, WG_SIZE / CLUSTER_SIZE>; +#endif + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3 +) { + let thread_id = local_id.x; + let num_vec4 = params.ne0 / 4u; + + let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE; + let total_batches = wg_per_vec * params.ne2 * params.ne3; + + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + if (wg_linear >= total_batches) { + return; + } + + let src13_idx = wg_linear / (params.ne2 * wg_per_vec); + let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec; + let src11_wg_idx = wg_linear % wg_per_vec; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let src1_idx_vec4_base = src1_idx_base / 4u; + + let blocks_per_row = params.ne0 / 32u; + let blocks_per_wg = (WG_SIZE * 4u) / 32u; + let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row; + let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u; + let qs_idx = thread_id % 8u; + + // reduction + var q4 = vec4(0.0); + var q4_quants = 0u; + var thread_amax = 0.0; + + let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id; + let is_valid = src11_vec4_idx < num_vec4; + +#ifdef USE_SUBGROUP_REDUCTION + + var d = 0.0; + + if (is_valid) { + q4 = src1[src1_idx_vec4_base + src11_vec4_idx]; + let abs_q4 = abs(q4); + thread_amax = max(max(abs_q4[0u], abs_q4[1u]), max(abs_q4[2], abs_q4[3])); + } + + d = cluster_max_8(thread_amax) / 127.0; + + if (is_valid) { + let id = select(0.0, 1.0 / d, d > 0.0); + q4_quants = pack4xI8(vec4(round(q4 * id))); + if (qs_idx == 0u) { + src1q[src1q_idx].d = f16(d); + } + src1q[src1q_idx].qs[qs_idx] = q4_quants; + } + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) + let q4_quants_sum = dot4I8Packed(q4_quants, 0x01010101u); + let s = f16(d * f32(cluster_add_i4x8(q4_quants_sum))); + + if (is_valid) { + if (qs_idx == 0u) { + src1q[src1q_idx].s = s; + } + } +#endif +#endif + +#ifdef USE_WORKGROUP_REDUCTION + + var d = 0.0; + let cluster_id = thread_id / 8u; + + if (is_valid) { + q4 = src1[src1_idx_vec4_base + src11_vec4_idx]; + let abs_q4 = abs(q4); + thread_amax = max(max(abs_q4[0], abs_q4[1]), max(abs_q4[2], abs_q4[3])); + partial_amaxs[cluster_id][qs_idx] = thread_amax; + } + + workgroupBarrier(); + + if (is_valid) { + let amax = max( + max( + max(partial_amaxs[cluster_id][0], partial_amaxs[cluster_id][1]), max(partial_amaxs[cluster_id][2], partial_amaxs[cluster_id][3])), + max( + max(partial_amaxs[cluster_id][4], partial_amaxs[cluster_id][5]), max(partial_amaxs[cluster_id][6], partial_amaxs[cluster_id][7])) + ); + + d = amax / 127.0; + let id = select(0.0f, 1.0f / d, d > 0.0f); + + q4_quants = pack4xI8(vec4(round(q4 * id))); + src1q[src1q_idx].qs[qs_idx] = q4_quants; + + if (qs_idx == 0u) { + src1q[src1q_idx].d = f16(d); + } + } + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) + + partial_sums[cluster_id][qs_idx] = dot4I8Packed(q4_quants, 0x01010101u); + + workgroupBarrier(); + + if (is_valid) { + if (qs_idx == 0u) { + let s = d * f32(partial_sums[cluster_id][0] + partial_sums[cluster_id][1] + partial_sums[cluster_id][2] + partial_sums[cluster_id][3] + + partial_sums[cluster_id][4] + partial_sums[cluster_id][5] + partial_sums[cluster_id][6] + partial_sums[cluster_id][7]); + src1q[src1q_idx].s = f16(s); + } + } + +#endif +#endif + +} From 049f0af3398f67acc547844dec0d14310ab2bbb5 Mon Sep 17 00:00:00 2001 From: Alexey Kopytko Date: Tue, 26 May 2026 13:59:00 +0900 Subject: [PATCH 705/831] SYCL: implement ggml_sycl_pool_vmm (llama/22862) * SYCL: implement ggml_sycl_pool_vmm * Add an option to bypass VMM with GGML_SYCL_DISABLE_VMM * Clean up debugging logging * document GGML_SYCL_DISABLE_VMM * Multi-stream MoE optimization * Revert "Multi-stream MoE optimization" This reverts commit 938929c3f13a562ec67c59e87cc5d38595444cce. * Update common.hpp Co-authored-by: Neo Zhang * Flip GGML_SYCL_DISABLE_VMM to GGML_SYCL_ENABLE_VMM * add logging for GGML_SYCL_ENABLE_VMM when extension is not available (SYCL_EXT_ONEAPI_VIRTUAL_MEM macro) * Apply suggestions from code review Co-authored-by: Alexey Kopytko * Apply suggestion from @sanmai * Apply suggestion from @sanmai --------- Co-authored-by: Neo Zhang --- ggml/src/ggml-sycl/common.hpp | 3 + ggml/src/ggml-sycl/ggml-sycl.cpp | 171 +++++++++++++++++++++++++++++-- 2 files changed, 163 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 6d19538215e..31e26ff48e4 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -224,6 +224,7 @@ struct sycl_device_info { int max_wg_per_cu; // max work groups per compute unit - refer to // cudaOccupancyMaxActiveBlocksPerMultiprocessor bool vmm; // virtual memory support + size_t vmm_granularity; // granularity of virtual memory size_t total_vram; sycl_hw_info hw_info; optimize_feature opt_feature; @@ -244,6 +245,8 @@ struct ggml_sycl_device_info { const ggml_sycl_device_info & ggml_sycl_info(); +static constexpr size_t SYCL_BUFFER_ALIGNMENT = 128; + struct ggml_sycl_pool { virtual ~ggml_sycl_pool() = default; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index b3fbb621196..729a88b4db8 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -37,6 +38,11 @@ #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC # include #endif +#if SYCL_EXT_ONEAPI_VIRTUAL_MEM +# include +# include +# define GGML_SYCL_USE_VMM +#endif #include #include "ggml.h" @@ -70,6 +76,7 @@ int g_ggml_sycl_debug = 0; int g_ggml_sycl_disable_optimize = 0; int g_ggml_sycl_disable_graph = 0; int g_ggml_sycl_disable_dnn = 0; +int g_ggml_sycl_enable_vmm = 1; int g_ggml_sycl_prioritize_dmmv = 0; int g_ggml_sycl_use_async_mem_op = 0; int g_ggml_sycl_use_async_mem_op_requested = 1; @@ -96,13 +103,30 @@ static ggml_sycl_device_info ggml_sycl_init() { // GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__); // #endif for (int i = 0; i < info.device_count; ++i) { - info.devices[i].vmm = 0; dpct::device_info prop; auto & device = dpct::dev_mgr::instance().get_device(i); SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( prop, device))); +#if !defined(GGML_SYCL_USE_VMM) + info.devices[i].vmm = 0; +#else + info.devices[i].vmm = device.has(sycl::aspect::ext_oneapi_virtual_mem); + if (info.devices[i].vmm) { + // NB: SYCL's get_mem_granularity always returns the _minimum_ granularity, + // but the L0 API requires a larger page size for allocs above 2 MiB and + // rejects non-multiples with UR_RESULT_ERROR_INVALID_VALUE [sic]. + // Here we clamp it to 2 MiB for simplicity, but other devices may require + // calling zeVirtualMemQueryPageSize or yet unexposed public API. + const size_t physical_page = 2ull << 20; // 2 MiB + info.devices[i].vmm_granularity = std::max( + sycl::ext::oneapi::experimental::get_mem_granularity( + device, sycl::context(device)), + physical_page); + } +#endif + info.default_tensor_split[i] = total_vram; total_vram += prop.get_global_mem_size(); @@ -234,6 +258,7 @@ static void ggml_check_sycl() try { g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 0); g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); + g_ggml_sycl_enable_vmm = get_sycl_env("GGML_SYCL_ENABLE_VMM", 1); g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); #ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", ggml_sycl_info().ext_oneapi_level_zero); @@ -275,6 +300,11 @@ static void ggml_check_sycl() try { #else GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: no\n"); #endif +#if defined(GGML_SYCL_USE_VMM) + GGML_LOG_INFO(" GGML_SYCL_USE_VMM: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_USE_VMM: no\n"); +#endif GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); @@ -293,6 +323,11 @@ static void ggml_check_sycl() try { GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn); #else GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n"); +#endif +#if defined(GGML_SYCL_USE_VMM) + GGML_LOG_INFO(" GGML_SYCL_ENABLE_VMM: %d\n", g_ggml_sycl_enable_vmm); +#else + GGML_LOG_INFO(" GGML_SYCL_ENABLE_VMM: virtual memory extension is not available\n"); #endif GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); g_ggml_sycl_use_async_mem_op_requested = get_sycl_env("GGML_SYCL_USE_ASYNC_MEM_OP", 1); @@ -754,7 +789,7 @@ catch (sycl::exception const &exc) { } static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 128; + return SYCL_BUFFER_ALIGNMENT; GGML_UNUSED(buft); } @@ -1177,7 +1212,7 @@ static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(gg } static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 128; + return SYCL_BUFFER_ALIGNMENT; GGML_UNUSED(buft); } @@ -1462,6 +1497,121 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { } }; +// pool with virtual memory management +#if defined(GGML_SYCL_USE_VMM) +struct ggml_sycl_pool_vmm : public ggml_sycl_pool { + static const size_t SYCL_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB + + int device; + sycl::context ctx; + sycl::device dev; + + uintptr_t pool_addr = 0; + size_t pool_used = 0; + size_t pool_size = 0; + size_t granularity; + + // physical_mem owns the commits (unlike cuMemMap) + struct mapping { + sycl::ext::oneapi::experimental::physical_mem phys; + void * map_ptr; + }; + std::vector mappings; + + explicit ggml_sycl_pool_vmm(queue_ptr qptr_, int device_) : + device(device_), + ctx(qptr_->get_context()), + dev(qptr_->get_device()), + granularity(ggml_sycl_info().devices[device_].vmm_granularity) { + } + + ~ggml_sycl_pool_vmm() { + if (pool_addr == 0) { + return; + } + + // Per spec, unmap must (a) match the exact (ptr, size) of an earlier + // physical_mem::map() call and (b) precede destruction of the + // physical_mem objects (their dtors won't unmap). + for (auto & m : mappings) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::ext::oneapi::experimental::unmap( + m.map_ptr, m.phys.size(), ctx))); + } + SYCL_CHECK(CHECK_TRY_ERROR(sycl::ext::oneapi::experimental::free_virtual_mem( + pool_addr, SYCL_POOL_VMM_MAX_SIZE, ctx))); + } + + void * alloc(size_t size, size_t * actual_size) override { + // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types + size = GGML_PAD(size, SYCL_BUFFER_ALIGNMENT); + + size_t avail = pool_size - pool_used; + + if (size > avail) { + // round up to the next multiple of the granularity + size_t reserve_size = GGML_PAD(size - avail, granularity); + + GGML_ASSERT(pool_size + reserve_size <= SYCL_POOL_VMM_MAX_SIZE); + + // allocate more physical memory + std::optional phys; + SYCL_CHECK(CHECK_TRY_ERROR(phys.emplace(dev, ctx, reserve_size))); + + // reserve virtual address space (if not already reserved) + if (pool_addr == 0) { + SYCL_CHECK(CHECK_TRY_ERROR( + pool_addr = sycl::ext::oneapi::experimental::reserve_virtual_mem( + SYCL_POOL_VMM_MAX_SIZE, ctx))); + } + + // map at the end of the pool + void * map_ptr = nullptr; + SYCL_CHECK(CHECK_TRY_ERROR( + map_ptr = phys->map(pool_addr + pool_size, reserve_size, + sycl::ext::oneapi::experimental::address_access_mode::read_write))); + + // stash these so we could unmap this exact range in dtor + mappings.push_back({ + std::move(*phys), + map_ptr, + }); + + // add to the pool + pool_size += reserve_size; + +#ifdef DEBUG_SYCL_MALLOC + GGML_LOG_INFO("sycl pool[%d]: size increased to %llu MB (reserved %llu MB)\n", + device, (unsigned long long) (pool_size/1024/1024), + (unsigned long long) (reserve_size/1024/1024)); +#endif + } + + GGML_ASSERT(pool_addr != 0); + + void * ptr = reinterpret_cast(pool_addr + pool_used); + *actual_size = size; + pool_used += size; + +#ifdef DEBUG_SYCL_MALLOC + GGML_LOG_INFO("sycl pool[%d]: allocated %llu bytes at %p\n", device, (unsigned long long) size, ptr); +#endif + + return ptr; + } + + void free(void * ptr, size_t size) override { +#ifdef DEBUG_SYCL_MALLOC + GGML_LOG_INFO("sycl pool[%d]: freed %llu bytes at %p\n", device, (unsigned long long) size, ptr); +#endif + + pool_used -= size; + + // all deallocations must be in reverse order of the allocations + GGML_ASSERT(ptr == reinterpret_cast(pool_addr + pool_used)); + } +}; +#endif // defined(GGML_SYCL_USE_VMM) + struct ggml_sycl_pool_host : public ggml_sycl_pool { queue_ptr qptr; int device; @@ -1542,20 +1692,19 @@ std::unique_ptr ggml_backend_sycl_context::new_pool_for_host(que } std::unique_ptr ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { - // TBD: NO VMM support - // if (ggml_sycl_info().devices[device].vmm) { - // return std::unique_ptr(new ggml_sycl_pool_vmm(device)); - // } - return std::unique_ptr(new ggml_sycl_pool_leg(qptr, device)); +#if defined(GGML_SYCL_USE_VMM) + if (g_ggml_sycl_enable_vmm && ggml_sycl_info().devices[device].vmm) { + return std::unique_ptr(new ggml_sycl_pool_vmm(qptr, device)); + } +#endif // defined(GGML_SYCL_USE_VMM) + return std::unique_ptr(new ggml_sycl_pool_leg(qptr, device)); } + std::unique_ptr ggml_backend_sycl_context::new_fattn_kv_buffers(queue_ptr qptr, int device) { return std::unique_ptr(new ggml_sycl_fattn_kv_buffers(qptr, device)); } -// TBD pool with virtual memory management -// struct ggml_sycl_pool_vmm : public ggml_sycl_pool - /// kernels typedef void (*ggml_sycl_op_mul_mat_t)( ggml_backend_sycl_context & ctx, From f8df28d3319ecf97d8bbc27cdc22bcff8f1dcbe0 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Tue, 26 May 2026 06:20:05 -0700 Subject: [PATCH 706/831] hexagon: add support for CONCAT op (llama/23648) * hexagon: add support for CONCAT with optimized concat_2d_transposed qwen3.5 models are quite heavy on the CONCAT with large and transposed src1. * hex-concat: use fastdiv in generic version * hex-concat: make checks for transposed a bit more readable * hex-concat: reoder dma ops for better pipelining * hex-cont/cpy: optimize CPY and CONT ops The primary change is to avoid scalar divs in the inner loops. We were calling hvx_copy_uu(... type_size) where type_size is non a constexpr. This causes runtime divs by that value which is normally just 4 or 2 (f32/f16). * hex-get-rows: optimize GET_ROWS for large rows We now use DMA for larger rows and also split them into chunks to improve perf for Qwen3.5 and other models that do lots of GET_ROWS with huge (2MB+ rows). Also bump the DMA queue depth now that we can take advantage of it. * hex-concat: unroll the inner loops of concat_2d * hex-concat: more updates to concat_2d to improve perf a bit further * hex-cpy: fixed n_rows per thread checks in the copy ops * hmx-fa: fix alignment issues while computing dma sizes * hex-set-rows: add early returns for idle threads * hvx-rope: minor optimization to replace loops with fastdiv logic * hex-rope: replace scalar tail processing with HVX * hex-rope: optimize rope cache init with HVX Add hvx-utils sin/cos helpers that use an aprox method (similar to rsqrt, inverse, etc) Use the helpers to optimize ROPE. --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 24 ++ ggml/src/ggml-hexagon/htp/CMakeLists.txt | 1 + ggml/src/ggml-hexagon/htp/concat-ops.c | 275 ++++++++++++++++ ggml/src/ggml-hexagon/htp/cpy-ops.c | 310 ++++++++++-------- ggml/src/ggml-hexagon/htp/get-rows-ops.c | 120 ++++++- .../src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 21 +- ggml/src/ggml-hexagon/htp/htp-ctx.h | 1 + ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/hvx-sin-cos.h | 90 +++++ ggml/src/ggml-hexagon/htp/hvx-utils.h | 2 + ggml/src/ggml-hexagon/htp/main.c | 6 +- ggml/src/ggml-hexagon/htp/rope-ops.c | 240 ++++++++++---- ggml/src/ggml-hexagon/htp/set-rows-ops.c | 6 + ggml/src/ggml-hexagon/htp/unary-ops.c | 2 +- 14 files changed, 868 insertions(+), 231 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/concat-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/hvx-sin-cos.h diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 9db99cb0f3a..1c8ecc197e9 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2874,6 +2874,7 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_NORM: return HTP_OP_NORM; case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; + case GGML_OP_CONCAT: return HTP_OP_CONCAT; case GGML_OP_SCALE: return HTP_OP_SCALE; case GGML_OP_SQR: return HTP_OP_SQR; case GGML_OP_SQRT: return HTP_OP_SQRT; @@ -3286,6 +3287,25 @@ static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * se return true; } +static bool ggml_hexagon_supported_concat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + int dim = ((const int32_t *) op->op_params)[0]; + if (dim < 0 || dim >= GGML_MAX_DIMS) { + return false; + } + + for (int i = 0; i < GGML_MAX_SRC; ++i) { + const struct ggml_tensor * src = op->src[i]; + if (!src) { + continue; + } + if (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_I32 && src->type != GGML_TYPE_F16) { + return false; + } + } + + return true; +} + static bool ggml_hexagon_supported_fill(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * dst = op; @@ -3434,6 +3454,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cumsum(sess, op); break; + case GGML_OP_CONCAT: + supp = ggml_hexagon_supported_concat(sess, op); + break; + case GGML_OP_FILL: supp = ggml_hexagon_supported_fill(sess, op); break; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 36f923243cd..33d67dda9cc 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -35,6 +35,7 @@ add_library(${HTP_LIB} SHARED ssm-conv.c cumsum-ops.c fill-ops.c + concat-ops.c diag-ops.c solve-tri-ops.c gated-delta-net-ops.c diff --git a/ggml/src/ggml-hexagon/htp/concat-ops.c b/ggml/src/ggml-hexagon/htp/concat-ops.c new file mode 100644 index 00000000000..61580f2c08f --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/concat-ops.c @@ -0,0 +1,275 @@ +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hexagon_types.h" +#include "hexagon_protos.h" +#include "hvx_hexagon_protos.h" +#include "hex-dma.h" +#include "vtcm-utils.h" +#include "hvx-utils.h" +#include "hex-fastdiv.h" +#include + +struct htp_concat_context { + struct htp_ops_context * octx; + uint32_t dim; + uint32_t nrows_per_thread; + struct fastdiv_values div_ne0; + struct fastdiv_values div_ne1; + struct fastdiv_values div_ne2; +}; + +static void concat_2d_f32_transposed(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t src0_ne0 = src0->ne[0]; + const uint32_t src1_ne0 = src1->ne[0]; + const uint32_t ne1 = dst->ne[1]; + + const uint32_t start_i = ith * cctx->nrows_per_thread; + const uint32_t end_i = (start_i + cctx->nrows_per_thread < ne1) ? (start_i + cctx->nrows_per_thread) : ne1; + if (start_i >= end_i) return; + + dma_queue * q = octx->ctx->dma[ith]; + + uint8_t * spad0_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint8_t * spad1_base = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + + const uint32_t block_i = 32; + const uint32_t spad1_stride = block_i * sizeof(float); + + int32_t offsets[32] __attribute__((aligned(128))); + for(int k=0; k<32; k++) { + offsets[k] = k * spad1_stride; + } + HVX_Vector vv = *(HVX_Vector*)offsets; + const uint32_t src1_ne0_padded = hex_round_up(src1_ne0, 32); + const uint32_t spad0_row_bytes = hex_round_up((src0_ne0 + src1_ne0_padded) * sizeof(float), VLEN); + uint32_t mu = src1_ne0_padded * spad1_stride; + + for (uint32_t i = start_i; i < end_i; i += block_i) { + uint32_t current_block_i = (end_i - i < block_i) ? (end_i - i) : block_i; + + uint32_t src1_width_bytes = current_block_i * sizeof(float); + uint8_t * src1_ptr = (uint8_t *)src1->data + i * src1->nb[1]; + dma_queue_push(q, dma_make_ptr(spad1_base, src1_ptr), spad1_stride, src1->nb[0], src1_width_bytes, src1_ne0); + + uint32_t src0_row_bytes = src0_ne0 * sizeof(float); + uint8_t * src0_ptr = (uint8_t *)src0->data + i * src0->nb[1]; + dma_queue_push(q, dma_make_ptr(spad0_base, src0_ptr), spad0_row_bytes, src0->nb[1], src0_row_bytes, current_block_i); + + dma_queue_pop(q); // src1 + + HVX_Vector * vtcm_tmp = (HVX_Vector *)(spad1_base + src1_ne0_padded * spad1_stride); + + for (uint32_t j = 0; j < src1_ne0_padded; j += 32) { + #pragma unroll(4) + for (uint32_t ii = 0; ii < current_block_i; ii++) { + size_t rt = (size_t)(spad1_base + j * spad1_stride + ii * sizeof(float)); + Q6_vgather_ARMVw(&vtcm_tmp[ii], rt, mu, vv); + uint8_t * dst_ptr = spad0_base + ii * spad0_row_bytes + (src0_ne0 + j) * sizeof(float); + hvx_vmemu(dst_ptr) = vtcm_tmp[ii]; + } + } + + dma_queue_pop(q); // src0 + + uint8_t * dst_ptr = (uint8_t *)dst->data + i * dst->nb[1]; + dma_queue_push(q, dma_make_ptr(dst_ptr, spad0_base), dst->nb[1], spad0_row_bytes, (src0_ne0 + src1_ne0) * sizeof(float), current_block_i); + + dma_queue_pop(q); + } +} + +static void concat_2d_f16_transposed(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t src0_ne0 = src0->ne[0]; + const uint32_t src1_ne0 = src1->ne[0]; + const uint32_t ne1 = dst->ne[1]; + + const uint32_t start_i = ith * cctx->nrows_per_thread; + const uint32_t end_i = (start_i + cctx->nrows_per_thread < ne1) ? (start_i + cctx->nrows_per_thread) : ne1; + if (start_i >= end_i) return; + + dma_queue * q = octx->ctx->dma[ith]; + + uint8_t * spad0_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint8_t * spad1_base = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + + const uint32_t block_i = 64; + const uint32_t spad1_stride = block_i * sizeof(__fp16); + + int16_t offsets[64] __attribute__((aligned(128))); + for(int k=0; k<64; k++) { + offsets[k] = k * spad1_stride; + } + HVX_Vector vv = *(HVX_Vector*)offsets; + const uint32_t src1_ne0_padded = hex_round_up(src1_ne0, 64); + const uint32_t spad0_row_bytes = hex_round_up((src0_ne0 + src1_ne0_padded) * sizeof(__fp16), VLEN); + uint32_t mu = src1_ne0_padded * spad1_stride; + + for (uint32_t i = start_i; i < end_i; i += block_i) { + uint32_t current_block_i = (end_i - i < block_i) ? (end_i - i) : block_i; + + uint32_t src1_width_bytes = current_block_i * sizeof(__fp16); + uint8_t * src1_ptr = (uint8_t *)src1->data + i * src1->nb[1]; + dma_queue_push(q, dma_make_ptr(spad1_base, src1_ptr), spad1_stride, src1->nb[0], src1_width_bytes, src1_ne0); + + uint32_t src0_row_bytes = src0_ne0 * sizeof(__fp16); + uint8_t * src0_ptr = (uint8_t *)src0->data + i * src0->nb[1]; + dma_queue_push(q, dma_make_ptr(spad0_base, src0_ptr), spad0_row_bytes, src0->nb[1], src0_row_bytes, current_block_i); + + dma_queue_pop(q); // src1 + + HVX_Vector * vtcm_tmp = (HVX_Vector *)(spad1_base + src1_ne0_padded * spad1_stride); + + for (uint32_t j = 0; j < src1_ne0_padded; j += 64) { + #pragma unroll(4) + for (uint32_t ii = 0; ii < current_block_i; ii++) { + size_t rt = (size_t)(spad1_base + j * spad1_stride + ii * sizeof(__fp16)); + Q6_vgather_ARMVh(&vtcm_tmp[ii], rt, mu, vv); + uint8_t * dst_ptr = spad0_base + ii * spad0_row_bytes + (src0_ne0 + j) * sizeof(__fp16); + hvx_vmemu(dst_ptr) = vtcm_tmp[ii]; + } + } + + dma_queue_pop(q); // src0 + + uint8_t * dst_ptr = (uint8_t *)dst->data + i * dst->nb[1]; + dma_queue_push(q, dma_make_ptr(dst_ptr, spad0_base), dst->nb[1], spad0_row_bytes, (src0_ne0 + src1_ne0) * sizeof(__fp16), current_block_i); + + dma_queue_pop(q); + } +} + +static void concat_generic(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const int dim = cctx->dim; + const uint32_t type_size = (dst->type == HTP_TYPE_F32 || dst->type == HTP_TYPE_I32) ? 4 : 2; + + const uint32_t ne[4] = {dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]}; + const uint32_t total_elements = ne[0] * ne[1] * ne[2] * ne[3]; + const uint32_t chunk_size = (total_elements + nth - 1) / nth; + + const uint32_t start_idx = MIN(ith * chunk_size, total_elements); + const uint32_t end_idx = MIN(start_idx + chunk_size, total_elements); + + // Naive scalar element-wise copy + for (uint32_t idx = start_idx; idx < end_idx; idx++) { + uint32_t idx_div_ne0 = fastdiv(idx, &cctx->div_ne0); + uint32_t i0 = idx - idx_div_ne0 * ne[0]; + + uint32_t idx_div_ne01 = fastdiv(idx_div_ne0, &cctx->div_ne1); + uint32_t i1 = idx_div_ne0 - idx_div_ne01 * ne[1]; + + uint32_t idx_div_ne012 = fastdiv(idx_div_ne01, &cctx->div_ne2); + uint32_t i2 = idx_div_ne01 - idx_div_ne012 * ne[2]; + uint32_t i3 = idx_div_ne012; + + uint8_t * dst_ptr = (uint8_t *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2] + i1 * dst->nb[1] + i0 * dst->nb[0]; + + uint32_t idx_dim = 0; + if (dim == 0) idx_dim = i0; + else if (dim == 1) idx_dim = i1; + else if (dim == 2) idx_dim = i2; + else if (dim == 3) idx_dim = i3; + + const struct htp_tensor * src = (idx_dim < src0->ne[dim]) ? src0 : src1; + + uint32_t s0 = i0; + uint32_t s1 = i1; + uint32_t s2 = i2; + uint32_t s3 = i3; + + if (dim == 0 && src == src1) s0 -= src0->ne[0]; + if (dim == 1 && src == src1) s1 -= src0->ne[1]; + if (dim == 2 && src == src1) s2 -= src0->ne[2]; + if (dim == 3 && src == src1) s3 -= src0->ne[3]; + + uint8_t * src_ptr = (uint8_t *)src->data + s3 * src->nb[3] + s2 * src->nb[2] + s1 * src->nb[1] + s0 * src->nb[0]; + + if (type_size == 4) { + *(float*)dst_ptr = *(float*)src_ptr; + } else { + *(__fp16*)dst_ptr = *(__fp16*)src_ptr; + } + } +} + +int op_concat(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + int dim = octx->op_params[0]; + + bool is_2d = dst->ne[2] == 1 && dst->ne[3] == 1; + + const uint32_t type_size = (dst->type == HTP_TYPE_F32 || dst->type == HTP_TYPE_I32) ? 4 : 2; + bool is_src1_transposed = (src1->nb[0] > src1->nb[1]); + bool is_src0_transposed = (src0->nb[0] > src0->nb[1]); + + uint32_t n_threads = octx->n_threads; + struct htp_concat_context cctx; + cctx.octx = octx; + cctx.dim = dim; + cctx.div_ne0 = init_fastdiv_values(dst->ne[0]); + cctx.div_ne1 = init_fastdiv_values(dst->ne[1]); + cctx.div_ne2 = init_fastdiv_values(dst->ne[2]); + + void (*worker_func)(unsigned int, unsigned int, void *) = concat_generic; + + if (dim == 0 && is_2d && is_src1_transposed && !is_src0_transposed) { + n_threads = MIN(dst->ne[1], n_threads); + if (n_threads < 1) { + n_threads = 1; + } + uint32_t block_i = (type_size == 4) ? 32 : 64; + + cctx.nrows_per_thread = hmx_ceil_div(dst->ne[1], n_threads); + + // Allocate VTCM + uint32_t spad1_stride = block_i * type_size; + + uint32_t src1_ne0_padded = hex_round_up(src1->ne[0], block_i); + uint32_t spad0_row_bytes = hex_round_up((src0->ne[0] + src1_ne0_padded) * type_size, VLEN); + + octx->src0_spad.size_per_thread = block_i * spad0_row_bytes; + octx->src1_spad.size_per_thread = src1_ne0_padded * spad1_stride + block_i * VLEN; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; + + if (octx->src0_spad.size + octx->src1_spad.size > octx->ctx->vtcm_size) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + if (type_size == 4) { + worker_func = concat_2d_f32_transposed; + } else { + worker_func = concat_2d_f16_transposed; + } + } + + worker_pool_run_func(octx->ctx->worker_pool, worker_func, &cctx, n_threads); + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c index 5c040a32224..ae507effa51 100644 --- a/ggml/src/ggml-hexagon/htp/cpy-ops.c +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -28,158 +28,170 @@ struct htp_copy_context { uint32_t dst_blocks_per_row; uint32_t src0_nrows_per_thread; - - void (*copy)(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith); }; #define cpy_preamble \ const struct htp_tensor *src0 = octx->src[0]; \ const struct htp_tensor *dst = octx->dst; \ \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ - const uint32_t nb3 = dst->nb[3]; \ - \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ const uint32_t nr = ne01; -static void cpy_thread_sametype_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { - cpy_preamble; - - // parallelize by src0 rows - const uint32_t dr = ct->src0_nrows_per_thread; - const uint32_t ir0 = dr * ith; - const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; - - // copy by rows - for (uint32_t i03 = 0; i03 < ne03; i03++) { - for (uint32_t i02 = 0; i02 < ne02; i02++) { - #pragma unroll(2) - for (uint32_t i01 = ir0; i01 < ir1; i01++) { - uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; - uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; - hex_l2fetch(src0_ptr, ne00 * ct->src0_type_size, nb01, 2); - hvx_copy_uu(dst_ptr, src0_ptr, ne00, ct->src0_type_size); - } - } - } +#define DEFINE_CPY_SAMESHAPE(NAME, ELEM_TYPE, ELEM_SIZE) \ +static void cpy_thread_##NAME##_sameshape(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_copy_context * ct = (struct htp_copy_context *) data; \ + struct htp_ops_context * octx = ct->octx; \ + cpy_preamble; \ + const uint32_t dr = ct->src0_nrows_per_thread; \ + const uint32_t ir0 = dr * ith; \ + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; \ + if (ir0 >= nr) return; \ + for (uint32_t i03 = 0; i03 < ne03; i03++) { \ + for (uint32_t i02 = 0; i02 < ne02; i02++) { \ + _Pragma("unroll(4)") \ + for (uint32_t i01 = ir0; i01 < ir1; i01++) { \ + uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; \ + uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; \ + hex_l2fetch(src0_ptr, ne00 * ELEM_SIZE, nb01, 2); \ + hvx_copy_uu(dst_ptr, src0_ptr, ne00, ELEM_SIZE); \ + } \ + } \ + } \ } -static void cpy_thread_sametype_reshape(struct htp_copy_context * ct, struct htp_ops_context * octx, int nth, int ith) { - cpy_preamble; - - // parallelize by src0 rows - const uint32_t dr = ct->src0_nrows_per_thread; - const uint32_t ir0 = dr * ith; - const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; - - // Fast path: when both src0 and dst are contiguous in memory - // Replace the element-by-element loop with a single bulk HVX copy per (i03, i02) slice. - const bool src0_contig = (nb00 == ct->src0_type_size) && - (nb01 == ne00 * nb00) && - (nb02 == ne01 * nb01) && - (nb03 == ne02 * nb02); - const bool dst_contig = (nb0 == ct->dst_type_size) && - (nb1 == ne0 * nb0) && - (nb2 == ne1 * nb1) && - (nb3 == ne2 * nb2); - - if (src0_contig && dst_contig) { - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01; - uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00; - uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ct->src0_type_size; - hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ct->src0_type_size); - } - } - return; - } - - // dst counters - int64_t k10 = 0; - int64_t i11 = 0; - int64_t i12 = 0; - int64_t i13 = 0; - - // number of blocks in a row - const int64_t nk00 = ct->src0_blocks_per_row; - const int64_t nk0 = ct->dst_blocks_per_row; - - for (int64_t i03 = 0; i03 < ne03; i03++) { - for (int64_t i02 = 0; i02 < ne02; i02++) { - k10 += nk00 * ir0; - while (k10 >= nk0) { - k10 -= nk0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - for (int64_t i01 = ir0; i01 < ir1; i01++) { - for (int64_t k00 = 0; k00 < nk00; k00++) { - const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); - char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); - memcpy(dst_ptr, src0_ptr, ct->dst_type_size); - - if (++k10 == nk0) { - k10 = 0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } - k10 += nk00 * (ne01 - ir1); - while (k10 >= nk0) { - k10 -= nk0; - if (++i11 == ne1) { - i11 = 0; - if (++i12 == ne2) { - i12 = 0; - if (++i13 == ne3) { - i13 = 0; - } - } - } - } - } - } +DEFINE_CPY_SAMESHAPE(f32, float, 4) +DEFINE_CPY_SAMESHAPE(f16, __fp16, 2) + +#define DEFINE_CPY_RESHAPE(NAME, ELEM_TYPE, ELEM_SIZE) \ +static void cpy_thread_##NAME##_reshape(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_copy_context * ct = (struct htp_copy_context *) data; \ + struct htp_ops_context * octx = ct->octx; \ + cpy_preamble; \ + const uint32_t dr = ct->src0_nrows_per_thread; \ + const uint32_t ir0 = dr * ith; \ + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; \ + if (ir0 >= nr) return; \ + const bool src0_contig = (nb00 == ELEM_SIZE) && \ + (nb01 == ne00 * nb00) && \ + (nb02 == ne01 * nb01) && \ + (nb03 == ne02 * nb02); \ + const bool dst_contig = (nb0 == ELEM_SIZE) && \ + (nb1 == ne0 * nb0) && \ + (nb2 == ne1 * nb1) && \ + (nb3 == ne2 * nb2); \ + if (src0_contig && dst_contig) { \ + for (int64_t i03 = 0; i03 < ne03; i03++) { \ + for (int64_t i02 = 0; i02 < ne02; i02++) { \ + uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01; \ + uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00; \ + uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ELEM_SIZE; \ + hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ELEM_SIZE); \ + } \ + } \ + return; \ + } \ + const bool reshape_flat_fast = (ne03 == 1 && ne2 == 1 && ne3 == 1) && \ + (ne0 == ne00 * ne01) && (ne1 == ne02) && \ + (nb00 == ELEM_SIZE) && (nb0 == ELEM_SIZE); \ + if (reshape_flat_fast) { \ + for (uint32_t i02 = 0; i02 < ne02; i02++) { \ + for (uint32_t i01 = ir0; i01 < ir1; i01++) { \ + uint8_t * src0_ptr = (uint8_t *) src0->data + i01 * nb01 + i02 * nb02; \ + uint8_t * dst_ptr = (uint8_t *) dst->data + i01 * ne00 * ELEM_SIZE + i02 * nb1; \ + hvx_copy_uu(dst_ptr, src0_ptr, ne00, ELEM_SIZE); \ + } \ + } \ + return; \ + } \ + int64_t k10 = 0; \ + int64_t i11 = 0; \ + int64_t i12 = 0; \ + int64_t i13 = 0; \ + const int64_t nk00 = ct->src0_blocks_per_row; \ + const int64_t nk0 = ct->dst_blocks_per_row; \ + for (int64_t i03 = 0; i03 < ne03; i03++) { \ + for (int64_t i02 = 0; i02 < ne02; i02++) { \ + k10 += nk00 * ir0; \ + while (k10 >= nk0) { \ + k10 -= nk0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + for (int64_t i01 = ir0; i01 < ir1; i01++) { \ + for (int64_t k00 = 0; k00 < nk00; k00++) { \ + const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); \ + char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); \ + memcpy(dst_ptr, src0_ptr, ELEM_SIZE); \ + if (++k10 == nk0) { \ + k10 = 0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + } \ + } \ + k10 += nk00 * (ne01 - ir1); \ + while (k10 >= nk0) { \ + k10 -= nk0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + } \ + } \ } -static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { +DEFINE_CPY_RESHAPE(f32, float, 4) +DEFINE_CPY_RESHAPE(f16, __fp16, 2) + +static void cpy_thread_f16_f32_sameshape(unsigned int nth, unsigned int ith, void * data) { + struct htp_copy_context * ct = (struct htp_copy_context *) data; + struct htp_ops_context * octx = ct->octx; cpy_preamble; // parallelize by src0 rows const uint32_t dr = ct->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + if (ir0 >= nr) return; // copy by rows for (uint32_t i03 = 0; i03 < ne03; i03++) { @@ -195,13 +207,16 @@ static void cpy_thread_f16_f32_sameshape(struct htp_copy_context * ct, struct ht } } -static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct htp_ops_context * octx, const int nth, const int ith) { +static void cpy_thread_f32_f16_sameshape(unsigned int nth, unsigned int ith, void * data) { + struct htp_copy_context * ct = (struct htp_copy_context *) data; + struct htp_ops_context * octx = ct->octx; cpy_preamble; // parallelize by src0 rows const uint32_t dr = ct->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + if (ir0 >= nr) return; // copy by rows for (uint32_t i03 = 0; i03 < ne03; i03++) { @@ -217,11 +232,6 @@ static void cpy_thread_f32_f16_sameshape(struct htp_copy_context * ct, struct ht } } -static void cpy_work_func(unsigned int n, unsigned int i, void *data) { - struct htp_copy_context *ct = (struct htp_copy_context *) data; - ct->copy(ct, ct->octx, n, i); -} - int op_cpy(struct htp_ops_context * octx) { cpy_preamble; @@ -254,22 +264,32 @@ int op_cpy(struct htp_ops_context * octx) { ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; + worker_callback_t copy_fun; + if (sametype && sameshape) { - ct.copy = cpy_thread_sametype_sameshape; + if (src0->type == HTP_TYPE_F32) { + copy_fun = cpy_thread_f32_sameshape; + } else { + copy_fun = cpy_thread_f16_sameshape; + } } else if (sameshape) { /**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32) - ct.copy = cpy_thread_f16_f32_sameshape; + copy_fun = cpy_thread_f16_f32_sameshape; else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16) - ct.copy = cpy_thread_f32_f16_sameshape; + copy_fun = cpy_thread_f32_f16_sameshape; else return HTP_STATUS_NO_SUPPORT; } else if (sametype) { - ct.copy = cpy_thread_sametype_reshape; + if (src0->type == HTP_TYPE_F32) { + copy_fun = cpy_thread_f32_reshape; + } else { + copy_fun = cpy_thread_f16_reshape; + } } else { return HTP_STATUS_NO_SUPPORT; } - worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_threads); + worker_pool_run_func(octx->ctx->worker_pool, copy_fun, &ct, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index 5a1dc933860..bf7063e9880 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -17,9 +17,13 @@ struct get_rows_context { struct htp_ops_context * octx; - uint32_t src1_nrows_per_thread; + uint32_t tasks_per_thread; + uint32_t total_tasks; + uint32_t chunks_per_row; + uint32_t chunk_size; struct fastdiv_values get_rows_div_ne10; struct fastdiv_values get_rows_div_ne10_ne11; + struct fastdiv_values get_rows_div_chunks_per_row; }; #define get_rows_preamble \ @@ -52,20 +56,23 @@ struct get_rows_context { \ const uint32_t nr = ne10 * ne11 * ne12; -static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { +static void get_rows_thread_f32_f32_dma(unsigned int nth, unsigned int ith, void *data) { struct get_rows_context * grctx = (struct get_rows_context *)data; struct htp_ops_context * octx = grctx->octx; get_rows_preamble; uint64_t qt = HAP_perf_get_qtimer_count(); - // parallelize by src1 elements (which correspond to dst rows) - const uint32_t dr = grctx->src1_nrows_per_thread; + const uint32_t dr = grctx->tasks_per_thread; const uint32_t ir0 = dr * ith; - const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + if (ir0 >= grctx->total_tasks) { + return; + } + const uint32_t ir1 = MIN(ir0 + dr, grctx->total_tasks); const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); + dma_queue * dma_queue = octx->ctx->dma[ith]; for (uint32_t i = ir0; i < ir1; ++i) { const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11); const uint32_t rem = i - i12 * ne11 * ne10; @@ -73,28 +80,76 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da const uint32_t i10 = rem - i11 * ne10; const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; - uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; if (i01 >= ne01) { - // invalid index, skip for now to avoid crash continue; } const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03; const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3; - hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + + while (!dma_queue_push(dma_queue, dma_make_ptr((void *)dst_ptr, (const void *)src0_ptr), nb1, nb01, ne00 * sizeof(float), 1)) { + dma_queue_pop(dma_queue); + } } + dma_queue_flush(dma_queue); qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); - FARF(HIGH, "get-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + FARF(HIGH, "get-rows-f32-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } -int op_get_rows(struct htp_ops_context * octx) { +static void get_rows_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { + struct get_rows_context * grctx = (struct get_rows_context *)data; + struct htp_ops_context * octx = grctx->octx; get_rows_preamble; - const uint32_t n_threads = MIN(nr, octx->n_threads); + uint64_t qt = HAP_perf_get_qtimer_count(); + + const uint32_t dr = grctx->tasks_per_thread; + const uint32_t ir0 = dr * ith; + if (ir0 >= grctx->total_tasks) { + return; + } + const uint32_t ir1 = MIN(ir0 + dr, grctx->total_tasks); + + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); + + const uint32_t chunks_per_row = grctx->chunks_per_row; + const uint32_t chunk_size = grctx->chunk_size; + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t row_idx = fastdiv(i, &grctx->get_rows_div_chunks_per_row); + const uint32_t chunk_idx = i - row_idx * chunks_per_row; + + const uint32_t i12 = fastdiv(row_idx, &grctx->get_rows_div_ne10_ne11); + const uint32_t rem = row_idx - i12 * ne11 * ne10; + const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10); + const uint32_t i10 = rem - i11 * ne10; + + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; + uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; + + if (i01 >= ne01) { + continue; + } + + const uint32_t offset = chunk_idx * chunk_size; + if (offset < ne00) { + const uint32_t copy_size = MIN(chunk_size, ne00 - offset); + const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03 + offset * sizeof(float); + const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3 + offset * sizeof(float); + hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, copy_size); + } + } + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "get-rows-f32-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); +} + +int op_get_rows(struct htp_ops_context * octx) { + get_rows_preamble; if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; @@ -112,13 +167,52 @@ int op_get_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } + const uint32_t nb00 = octx->src[0]->nb[0]; + const uint32_t nb0 = octx->dst->nb[0]; + + const bool can_use_dma = (nb00 == sizeof(float)) && (nb0 == sizeof(float)); + const bool use_dma = can_use_dma && (ne00 >= 2048); + struct get_rows_context grctx; grctx.octx = octx; grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src[1]->ne[0]); grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src[1]->ne[0] * octx->src[1]->ne[1]); - grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads; + if (use_dma) { + grctx.chunks_per_row = 1; + grctx.chunk_size = ne00; + grctx.total_tasks = nr; + grctx.get_rows_div_chunks_per_row = init_fastdiv_values(1); + + const uint32_t n_threads = MIN(nr, octx->n_threads); + grctx.tasks_per_thread = (nr + n_threads - 1) / n_threads; + + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32_dma, &grctx, n_threads); + } else { + uint32_t chunks_per_row = 1; + uint32_t chunk_size = ne00; + uint32_t total_tasks = nr; + + if (nr < octx->n_threads) { + const uint32_t min_chunk_size = 1024; + uint32_t max_chunks = ne00 / min_chunk_size; + if (max_chunks == 0) { + max_chunks = 1; + } + chunks_per_row = MIN((octx->n_threads + nr - 1) / nr, max_chunks); + chunk_size = (ne00 + chunks_per_row - 1) / chunks_per_row; + total_tasks = nr * chunks_per_row; + } + + grctx.chunks_per_row = chunks_per_row; + grctx.chunk_size = chunk_size; + grctx.total_tasks = total_tasks; + grctx.get_rows_div_chunks_per_row = init_fastdiv_values(chunks_per_row); - worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_threads); + const uint32_t n_threads = MIN(total_tasks, octx->n_threads); + grctx.tasks_per_thread = (total_tasks + n_threads - 1) / n_threads; + + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32_hvx, &grctx, n_threads); + } return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index 9e1b778b01f..a496f6289ae 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -50,8 +50,8 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong - const size_t k_dma_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K DMA: [Bc, DK] x2 double-buf - const size_t v_dma_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V DMA: [Bc, DV] x2 double-buf + const size_t k_dma_size = hex_align_up(Bc * hex_round_up(DK * sizeof(__fp16), 128), 4096); // K DMA: [Bc, DK] x2 double-buf + const size_t v_dma_size = hex_align_up(Bc * hex_round_up(DV * sizeof(__fp16), 128), 4096); // V DMA: [Bc, DV] x2 double-buf const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc] @@ -1278,7 +1278,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { struct hmx_fa_context factx; memset(&factx, 0, sizeof(factx)); factx.octx = octx; - factx.n_threads = octx->ctx->n_threads; + factx.n_threads = n_threads; factx.DK = DK; factx.DV = DV; factx.n_kv = nek1; @@ -1328,10 +1328,15 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); // ======== VTCM allocation (GQA-aware) ======== + const size_t size_k_row = DK * sizeof(__fp16); + const size_t size_v_row = DV * sizeof(__fp16); + const size_t size_k_row_padded = hex_round_up(size_k_row, 128); + const size_t size_v_row_padded = hex_round_up(size_v_row, 128); + const size_t q_tile_bytes = hex_align_up(g_br * DK * sizeof(__fp16), 4096); const size_t o_tile_bytes = hex_align_up(g_br * DV * sizeof(__fp16), 4096); - const size_t k_dma_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); - const size_t v_dma_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); + const size_t k_dma_bytes = hex_align_up(Bc * size_k_row_padded, 4096); + const size_t v_dma_bytes = hex_align_up(Bc * size_v_row_padded, 4096); const size_t k_tile_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); const size_t v_tile_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); const size_t s_tile_bytes = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); @@ -1401,11 +1406,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { // ======== DMA setup ======== dma_queue * const dma = ctx->dma[0]; - // Padded row sizes for DMA - const size_t size_k_row = nek0 * sizeof(__fp16); - const size_t size_v_row = nev0 * sizeof(__fp16); - const size_t size_k_row_padded = hex_round_up(nek0 * sizeof(__fp16), 128); - const size_t size_v_row_padded = hex_round_up(nev0 * sizeof(__fp16), 128); + // Padded row sizes for DMA (defined in outer scope) const size_t n_row_tiles_g_br = g_br / HMX_FP16_TILE_N_ROWS; const size_t n_tiles_per_bc = Bc / HMX_FP16_TILE_N_COLS; diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 6fe3e6c7d85..51f9243ce0a 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -104,6 +104,7 @@ int op_argsort(struct htp_ops_context * octx); int op_ssm_conv(struct htp_ops_context * octx); int op_cumsum(struct htp_ops_context * octx); int op_fill(struct htp_ops_context * octx); +int op_concat(struct htp_ops_context * octx); int op_diag(struct htp_ops_context * octx); int op_solve_tri(struct htp_ops_context * octx); int op_gated_delta_net(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 9d905a30133..54cfadd9b0a 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -89,6 +89,7 @@ enum htp_op_code { HTP_OP_TRI, HTP_OP_PAD, HTP_OP_NORM, + HTP_OP_CONCAT, HTP_OP_INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h b/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h new file mode 100644 index 00000000000..c5b9a5d47c1 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h @@ -0,0 +1,90 @@ +#ifndef HVX_SIN_COS_H +#define HVX_SIN_COS_H + +#include "hvx-base.h" +#include "hvx-floor.h" + +static inline HVX_Vector hvx_vec_cos_f32(HVX_Vector x) { + HVX_Vector const_inv_pi = hvx_vec_splat_f32(0.3183098861837907f); + HVX_Vector const_half = hvx_vec_splat_f32(0.5f); + HVX_Vector const_pi = hvx_vec_splat_f32(3.141592653589793f); + HVX_Vector const_one = hvx_vec_splat_f32(1.0f); + HVX_Vector const_neg_one = hvx_vec_splat_f32(-1.0f); + + // n = floor(x * (1/pi) + 0.5) + HVX_Vector n_float = hvx_vec_floor_f32(hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(x, const_inv_pi), const_half)); + + // y = x - n * pi + HVX_Vector y = hvx_vec_sub_f32_f32(x, hvx_vec_mul_f32_f32(n_float, const_pi)); + + // Sign determination: if n is odd, sign is -1.0f, else 1.0f + // half_n = n * 0.5f + HVX_Vector half_n = hvx_vec_mul_f32_f32(n_float, const_half); + // floor_half_n = floor(half_n) + HVX_Vector floor_half_n = hvx_vec_floor_f32(half_n); + // is_odd = half_n > floor_half_n + HVX_VectorPred is_odd = Q6_Q_vcmp_gt_VsfVsf(half_n, floor_half_n); + // sign = vmux(is_odd, -1.0f, 1.0f) + HVX_Vector sign = Q6_V_vmux_QVV(is_odd, const_neg_one, const_one); + + // z = y^2 + HVX_Vector z = hvx_vec_mul_f32_f32(y, y); + + // Chebyshev approximation for cos(y) + HVX_Vector c4 = hvx_vec_splat_f32(2.3557242013849433e-05f); + HVX_Vector c3 = hvx_vec_splat_f32(-0.0013871428263450528f); + HVX_Vector c2 = hvx_vec_splat_f32(0.041665895266688284f); + HVX_Vector c1 = hvx_vec_splat_f32(-0.4999999360426369f); + HVX_Vector c0 = hvx_vec_splat_f32(0.9999999999071725f); + + HVX_Vector cos_y = hvx_vec_add_f32_f32(c3, hvx_vec_mul_f32_f32(z, c4)); + cos_y = hvx_vec_add_f32_f32(c2, hvx_vec_mul_f32_f32(z, cos_y)); + cos_y = hvx_vec_add_f32_f32(c1, hvx_vec_mul_f32_f32(z, cos_y)); + cos_y = hvx_vec_add_f32_f32(c0, hvx_vec_mul_f32_f32(z, cos_y)); + + return hvx_vec_mul_f32_f32(cos_y, sign); +} + +static inline HVX_Vector hvx_vec_sin_f32(HVX_Vector x) { + HVX_Vector const_inv_pi = hvx_vec_splat_f32(0.3183098861837907f); + HVX_Vector const_half = hvx_vec_splat_f32(0.5f); + HVX_Vector const_pi = hvx_vec_splat_f32(3.141592653589793f); + HVX_Vector const_one = hvx_vec_splat_f32(1.0f); + HVX_Vector const_neg_one = hvx_vec_splat_f32(-1.0f); + + // n = floor(x * (1/pi) + 0.5) + HVX_Vector n_float = hvx_vec_floor_f32(hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(x, const_inv_pi), const_half)); + + // y = x - n * pi + HVX_Vector y = hvx_vec_sub_f32_f32(x, hvx_vec_mul_f32_f32(n_float, const_pi)); + + // Sign determination: if n is odd, sign is -1.0f, else 1.0f + // half_n = n * 0.5f + HVX_Vector half_n = hvx_vec_mul_f32_f32(n_float, const_half); + // floor_half_n = floor(half_n) + HVX_Vector floor_half_n = hvx_vec_floor_f32(half_n); + // is_odd = half_n > floor_half_n + HVX_VectorPred is_odd = Q6_Q_vcmp_gt_VsfVsf(half_n, floor_half_n); + // sign = vmux(is_odd, -1.0f, 1.0f) + HVX_Vector sign = Q6_V_vmux_QVV(is_odd, const_neg_one, const_one); + + // z = y^2 + HVX_Vector z = hvx_vec_mul_f32_f32(y, y); + + // Chebyshev approximation for sin(y) + HVX_Vector s4 = hvx_vec_splat_f32(2.642186986152672e-06f); + HVX_Vector s3 = hvx_vec_splat_f32(-0.00019825318964070864f); + HVX_Vector s2 = hvx_vec_splat_f32(0.00833326283319605f); + HVX_Vector s1 = hvx_vec_splat_f32(-0.16666666082087775f); + HVX_Vector s0 = hvx_vec_splat_f32(0.999999999915155f); + + HVX_Vector sin_y = hvx_vec_add_f32_f32(s3, hvx_vec_mul_f32_f32(z, s4)); + sin_y = hvx_vec_add_f32_f32(s2, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_add_f32_f32(s1, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_add_f32_f32(s0, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_mul_f32_f32(y, sin_y); + + return hvx_vec_mul_f32_f32(sin_y, sign); +} + +#endif /* HVX_SIN_COS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index e0452811ec3..0a760cd344c 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -14,6 +14,8 @@ #include "hvx-sqrt.h" #include "hvx-arith.h" #include "hvx-div.h" +#include "hvx-floor.h" +#include "hvx-sin-cos.h" #include "hvx-base.h" #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index e8619388478..f3a0866c7cd 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -420,8 +420,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->n_threads = n_hvx; for (int i = 0; i < ctx->n_threads; i++) { - // see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541 - ctx->dma[i] = dma_queue_create(128); + ctx->dma[i] = dma_queue_create(256); // queue depth } // init worker pool @@ -601,6 +600,9 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_PAD: return op_pad(octx); + case HTP_OP_CONCAT: + return op_concat(octx); + case HTP_OP_GATED_DELTA_NET: return op_gated_delta_net(octx); diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index b398e19f06e..c839044b84f 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -7,6 +7,7 @@ #include #include +#include #include "hex-dma.h" #include "hvx-utils.h" @@ -75,6 +76,9 @@ struct htp_rope_context { size_t theta_cache_offset; uint32_t src0_nrows; + struct fastdiv_values div_ne2_ne1; + struct fastdiv_values div_ne1; + uint64_t t_start; }; @@ -117,13 +121,84 @@ static __attribute__((noinline)) void rope_cache_init(const float theta_base, float * cache, const float theta_scale) { // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py - float theta = theta_base; +#if __HVX_ARCH__ >= 79 + const bool is_v79_or_newer = true; +#else + const bool is_v79_or_newer = false; +#endif + + if (is_v79_or_newer && ext_factor == 0.0f) { + // Fast path: fully vectorized + // We process 32 pairs (64 elements) per iteration. + const uint32_t n_blocks = ne0 / 64; + + // Initialize theta scale powers: [1.0f, theta_scale, theta_scale^2, ..., theta_scale^31] + float __attribute__((aligned(128))) theta_powers[32]; + theta_powers[0] = 1.0f; + for (int j = 1; j < 32; j++) { + theta_powers[j] = theta_powers[j - 1] * theta_scale; + } + HVX_Vector v_theta_powers = hvx_vmem(theta_powers); - for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { - const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; - rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); + HVX_Vector v_freq_scale = hvx_vec_splat_f32(freq_scale); + HVX_Vector v_mscale = hvx_vec_splat_f32(mscale); + + // Base theta starts at theta_base + float theta_block = theta_base; + // The scale factor for the next block is theta_scale^32 + float theta_scale_32 = 1.0f; + for (int j = 0; j < 32; j++) { + theta_scale_32 *= theta_scale; + } + + for (uint32_t b = 0; b < n_blocks; b++) { + uint32_t i0 = b * 64; + HVX_Vector v_theta_base = hvx_vec_splat_f32(theta_block); + HVX_Vector v_theta = hvx_vec_mul_f32_f32(v_theta_base, v_theta_powers); + + if (freq_factors) { + // Load 32 elements of freq_factors + HVX_Vector v_ff = hvx_vmemu(freq_factors + i0 / 2); + HVX_Vector v_inv_ff = hvx_vec_inverse_f32(v_ff); + v_theta = hvx_vec_mul_f32_f32(v_theta, v_inv_ff); + } + + HVX_Vector v_theta_final = hvx_vec_mul_f32_f32(v_theta, v_freq_scale); + + HVX_Vector vcos = hvx_vec_cos_f32(v_theta_final); + HVX_Vector vsin = hvx_vec_sin_f32(v_theta_final); + + vcos = hvx_vec_mul_f32_f32(vcos, v_mscale); + vsin = hvx_vec_mul_f32_f32(vsin, v_mscale); + + HVX_VectorPair vstore = Q6_W_vshuff_VVR(vsin, vcos, -4); - theta *= theta_scale; + if (((uintptr_t)cache) % 128 == 0) { + hvx_vmem(cache + i0 + 0) = Q6_V_lo_W(vstore); + hvx_vmem(cache + i0 + 32) = Q6_V_hi_W(vstore); + } else { + hvx_vec_store_u(cache + i0 + 0, 32 * sizeof(float), Q6_V_lo_W(vstore)); + hvx_vec_store_u(cache + i0 + 32, 32 * sizeof(float), Q6_V_hi_W(vstore)); + } + + theta_block *= theta_scale_32; + } + + // Leftovers + float theta = theta_block; + for (uint32_t i0 = n_blocks * 64; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); + theta *= theta_scale; + } + } else { + // Fallback to original scalar loop + float theta = theta_base; + for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); + theta *= theta_scale; + } } } @@ -195,24 +270,18 @@ static void rope_corr_dims(int n_dims, } static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { - const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0; - const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache; - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - - uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2 + const uint32_t he = ne / 2; + const uint32_t nvec = he / 32; + const uint32_t nloe = he % 32; - uint32_t he = ne / 2; // half_dims offset in elements - uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors + for (uint32_t i = 0; i < nvec; i++) { + HVX_Vector v0 = ((const HVX_Vector *) src0)[i]; + HVX_Vector v1 = hvx_vmemu(src0 + he + i * 32); - #pragma unroll(2) - for (uint32_t i = 0; i < nvec; i += 2) { - HVX_Vector v0 = vsrc[i/2+0]; - HVX_Vector v1 = vsrc[i/2+hv]; + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[i * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[i * 2 + 1]; - HVX_Vector v2 = vtheta[i+0]; - HVX_Vector v3 = vtheta[i+1]; - - HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin)); HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin)); @@ -222,37 +291,45 @@ static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * rest HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); - vdst[i/2+0] = Q6_Vsf_equals_Vqf32(v4); - vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5); + ((HVX_Vector *) dst)[i] = Q6_Vsf_equals_Vqf32(v4); + hvx_vmemu(dst + he + i * 32) = Q6_Vsf_equals_Vqf32(v5); } - for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) { - const float cos_theta = theta_cache[i+0]; - const float sin_theta = theta_cache[i+1]; - float x0 = src0[i/2]; - float x1 = src0[i/2 + he]; - dst[i/2] = x0 * cos_theta - x1 * sin_theta; - dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta; + if (nloe > 0) { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 32); + HVX_Vector v1 = hvx_vmemu(src0 + he + nvec * 32); + + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[nvec * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[nvec * 2 + 1]; + + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + hvx_vec_store_u(dst + nvec * 32, nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v4)); + hvx_vec_store_u(dst + he + nvec * 32, nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v5)); } } static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { - const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0; - const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache; - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - - uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two + const uint32_t nvec = ne / 64; + const uint32_t nloe = ne % 64; - #pragma unroll(2) - for (uint32_t i = 0; i < nvec; i+=2) { - HVX_Vector v0 = vsrc[i+0]; - HVX_Vector v1 = vsrc[i+1]; + for (uint32_t i = 0; i < nvec; i++) { + HVX_Vector v0 = ((const HVX_Vector *) src0)[i * 2 + 0]; + HVX_Vector v1 = ((const HVX_Vector *) src0)[i * 2 + 1]; - HVX_Vector v2 = vtheta[i+0]; - HVX_Vector v3 = vtheta[i+1]; + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[i * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[i * 2 + 1]; - HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1 - HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); @@ -264,17 +341,52 @@ static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); - vdst[i+0] = Q6_V_lo_W(vstore); - vdst[i+1] = Q6_V_hi_W(vstore); + ((HVX_Vector *) dst)[i * 2 + 0] = Q6_V_lo_W(vstore); + ((HVX_Vector *) dst)[i * 2 + 1] = Q6_V_hi_W(vstore); } - for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) { - const float cos_theta = theta_cache[i+0]; - const float sin_theta = theta_cache[i+1]; - float x0 = src0[i+0]; - float x1 = src0[i+1]; - dst[i+0] = x0 * cos_theta - x1 * sin_theta; - dst[i+1] = x0 * sin_theta + x1 * cos_theta; + if (nloe > 0) { + if (nloe <= 32) { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 64); + HVX_Vector v2 = hvx_vmemu(theta_cache + nvec * 64); + + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(Q6_V_vzero(), v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(Q6_V_vzero(), v2, -4); + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); + + hvx_vec_store_u(dst + nvec * 64, nloe * sizeof(float), Q6_V_lo_W(vstore)); + } else { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 64); + HVX_Vector v1 = hvx_vmemu(src0 + nvec * 64 + 32); + + HVX_Vector v2 = hvx_vmemu(theta_cache + nvec * 64); + HVX_Vector v3 = hvx_vmemu(theta_cache + nvec * 64 + 32); + + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); + + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); + + HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); + + ((HVX_Vector *) dst)[nvec * 2 + 0] = Q6_V_lo_W(vstore); + hvx_vec_store_u(dst + nvec * 64 + 32, (nloe - 32) * sizeof(float), Q6_V_hi_W(vstore)); + } } } @@ -348,13 +460,19 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { const int32_t * pos = (const int32_t *) src1->data; const float * freq_factors = src2 ? (const float *) src2->data : NULL; - uint32_t ir = 0; + const uint32_t i3_start = fastdiv(src0_start_row, &rctx->div_ne2_ne1); + const uint32_t rem = fastmodulo(src0_start_row, ne2 * ne1, &rctx->div_ne2_ne1); + const uint32_t i2_start = fastdiv(rem, &rctx->div_ne1); + const uint32_t i1_start = fastmodulo(rem, ne1, &rctx->div_ne1); + + uint32_t ir = src0_start_row; uint32_t prev_i2 = (uint32_t) -1; - for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch - for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len - for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads - if (ir < src0_start_row) { ir++; i1++; continue; } + for (uint32_t i3 = i3_start; i3 < ne3; i3++) { // batch + const uint32_t i2_init = (i3 == i3_start) ? i2_start : 0; + for (uint32_t i2 = i2_init; i2 < ne2; i2++) { // seq-len + const uint32_t i1_init = (i3 == i3_start && i2 == i2_start) ? i1_start : 0; + for (uint32_t i1 = i1_init; i1 < ne1; ) { // attn-heads if (ir >= src0_end_row) goto done; // Rows in this block @@ -407,9 +525,6 @@ static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale); } - - // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache, - // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); } // Skip output DMA transactions from prev block (if any) @@ -489,7 +604,7 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { // Aligned row sizes for VTCM const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); - const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128); + const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 256); // Calculate spad sizes per thread size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned; @@ -546,6 +661,11 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { rctx.src0_nrows = src0_nrows; rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + if (src0_nrows > 0) { + rctx.div_ne2_ne1 = init_fastdiv_values(dst->ne[2] * dst->ne[1]); + rctx.div_ne1 = init_fastdiv_values(dst->ne[1]); + } + FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0, rctx.ext_factor, rctx.theta_scale, rctx.attn_factor); diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index 0def7b408bf..58c54967db0 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -65,6 +65,9 @@ static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da // parallelize by rows of src0 const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; + if (ir0 >= nr) { + return; + } const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); @@ -109,6 +112,9 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da // parallelize by rows of src0 const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; + if (ir0 >= nr) { + return; + } const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 40d2d60153a..7d0431d8ba8 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -207,7 +207,7 @@ static void hvx_fast_norm_f32(const uint8_t * restrict src, // scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v)); - HVX_Vector mean_x_b = hvx_vec_splat_f32(hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(mean_x_v))); + HVX_Vector mean_x_b = hvx_vec_repl_f32(Q6_Vsf_equals_Vqf32(mean_x_v)); #pragma unroll(4) for (int i = 0; i < nvec; i++) { From a0efd13f0fe9e2123a5d04f57bb353225c5f4453 Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 26 May 2026 08:48:05 -0500 Subject: [PATCH 707/831] vulkan: optimize conv2d and implement coopmat1 support (llama/22620) * vulkan: add CONV_SHAPE_64x128 for medium-K conv2d * vulkan: skip conv2d bounds checks when shapes align with tile sizes * vulkan: use WG_SIZE=128 for CONV_SHAPE_64x32 conv2d * vulkan: stage cm2 conv2d accumulator through shmem before global store * vulkan: add coopmat1 conv2d path * fallback when using too much shared memory. clean up comments * Require 16x16x16 and subgroup size 32 or 64 * check whether shared memory is sufficient before overwriting conv2d params with coopmat1 values --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 119 +++++++++++-- .../ggml-vulkan/vulkan-shaders/conv2d_mm.comp | 159 ++++++++++++++++-- .../vulkan-shaders/vulkan-shaders-gen.cpp | 12 +- 3 files changed, 264 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index aa289220a90..18d7cedad4b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -398,6 +398,7 @@ enum vk_conv_shapes { CONV_SHAPE_128x128, CONV_SHAPE_64x32, CONV_SHAPE_32x256, + CONV_SHAPE_64x128, CONV_SHAPE_COUNT, }; @@ -412,6 +413,7 @@ vk_conv_block_size vk_conv_block_sizes[CONV_SHAPE_COUNT] = { { 128, 128, 16 }, // CONV_SHAPE_128x128 { 64, 32, 32 }, // CONV_SHAPE_64x32 { 32, 256, 16 }, // CONV_SHAPE_32x256 + { 64, 128, 16 }, // CONV_SHAPE_64x128 }; enum dmmv_wg_sizes { @@ -447,14 +449,16 @@ struct vk_fa_pipeline_state { }; struct vk_conv2d_pipeline_state { - vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH) - : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {} + vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH, uint32_t aligned) + : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH), aligned(aligned) {} uint32_t s0, s1, p0, p1, d0, d1, KW, KH; + // when set, shader can skip K/CRS/NPQ bounds checks and address clamps + uint32_t aligned; bool operator<(const vk_conv2d_pipeline_state &b) const { - return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) < - std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH); + return std::tie(s0, s1, p0, p1, d0, d1, KW, KH, aligned) < + std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH, b.aligned); } }; @@ -4934,7 +4938,8 @@ static void ggml_vk_load_shaders(vk_device& device) { // conv2d, conv_transpose_2d for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { - uint32_t conv2d_WG_SIZE = 256; + // smaller WG for the small-tile fallback gives more concurrent WGs per SM + uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256; uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices. uint32_t conv2d_TS_K = (s == CONV_SHAPE_64x32) ? 4 : 8; uint32_t conv2d_SHMEM_PAD = 4; @@ -4973,18 +4978,77 @@ static void ggml_vk_load_shaders(vk_device& device) { conv2d_BS.CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used. } - uint32_t conv2d_shmem_req = - (conv2d_BS.K * (conv2d_BS.CRS + conv2d_SHMEM_PAD) + conv2d_BS.CRS * (conv2d_BS.NPQ + conv2d_SHMEM_PAD)) * sizeof(float); - if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) { + // cm1 is used only when cm2 is unavailable; capped at 64x128 (due to shared memory size). + // Requires 16x16x16 f16-acc since that's the fragment shape hard-coded in the shader. + // Subgroup size must be 32 or 64 (to keep WG_SIZE sane) and we need + // subgroup_size_control to force the driver to actually use it. + bool conv2d_use_cm1 = false; +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + conv2d_use_cm1 = !device->coopmat2 && + device->coopmat_support && device->coopmat_support_16x16x16_f16acc && + device->subgroup_size_control && + (device->subgroup_size == 32 || device->subgroup_size == 64) && + s != CONV_SHAPE_128x128; +#endif + + const uint32_t conv2d_cm1_shmem_pad = 8; + + auto shmem_req = [&](uint32_t pad, bool csh_store, bool fp16_shmem) { + const uint32_t elem_size = fp16_shmem ? (uint32_t)sizeof(uint16_t) : (uint32_t)sizeof(float); + const uint32_t csh_elems = csh_store ? conv2d_BS.K * conv2d_BS.NPQ : 0u; + return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size; + }; + + // coopmat1 needs to store the output through shared memory, so check up front + // whether it'll fit and disable it before applying coopmat1 parameters. + if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) { + conv2d_use_cm1 = false; + } + + uint32_t conv2d_WM = 16, conv2d_WN = 16; // cm1 subgroup tile, ignored otherwise + if (conv2d_use_cm1) { + conv2d_SHMEM_PAD = conv2d_cm1_shmem_pad; + // 16x16x16 fragments; pick WM/WN to keep WG_SIZE at 256 + // (i.e. 8 subgroups for sg=32, 4 subgroups for sg=64). + const bool sg64 = (device->subgroup_size == 64); + switch (s) { + case CONV_SHAPE_64x32: conv2d_WM = sg64 ? 32 : 16; conv2d_WN = 16; break; + case CONV_SHAPE_64x128: conv2d_WM = 32; conv2d_WN = sg64 ? 64 : 32; break; + case CONV_SHAPE_32x256: conv2d_WM = sg64 ? 16 : 32; conv2d_WN = sg64 ? 128 : 32; break; + default: break; + } + const uint32_t warps_M = conv2d_BS.K / conv2d_WM; + const uint32_t warps_N = conv2d_BS.NPQ / conv2d_WN; + conv2d_WG_SIZE = warps_M * warps_N * device->subgroup_size; + } + + // stage cm2 accumulator through shmem for coalesced global stores; + // skipped on 128x128 where the extra Csh footprint hurts occupancy. + // cm1 always uses the staged path. + uint32_t conv2d_csh_store = (device->coopmat2 && s != CONV_SHAPE_128x128) ? 1u : 0u; + if (conv2d_use_cm1) { + conv2d_csh_store = 1; + } + + // shmem is fp16 on cm2/cm1 (matches Csh), fp32 on scalar + const bool conv2d_use_fp16_shmem = device->coopmat2 || conv2d_use_cm1; + + // shrink CRS if the non-cm1 config still doesn't fit + if (device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_SHMEM_PAD, conv2d_csh_store, conv2d_use_fp16_shmem)) { + GGML_ASSERT(!conv2d_use_cm1); conv2d_BS.CRS = 8; if (use_collectives) { conv2d_BS.CRS = std::min(device->subgroup_size, conv2d_BS.CRS); } + conv2d_csh_store = 0; } std::array wg_denoms = { conv2d_BS.K, 1, 1 }; std::vector spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; + // cm1 needs a fixed subgroup width to match the WG_SIZE we computed + const uint32_t conv2d_required_subgroup_size = conv2d_use_cm1 ? device->subgroup_size : 0; + #define CREATE_CONV(name, type_suffix, spv_suffix) \ for (auto &c : device->pipeline_##name##type_suffix[s]) { \ const vk_conv2d_pipeline_state &state = c.first; \ @@ -4997,10 +5061,14 @@ static void ggml_vk_load_shaders(vk_device& device) { spec_constants_cpy.push_back(state.d1); \ spec_constants_cpy.push_back(state.KW); \ spec_constants_cpy.push_back(state.KH); \ + spec_constants_cpy.push_back(state.aligned); \ + spec_constants_cpy.push_back(conv2d_csh_store); \ + spec_constants_cpy.push_back(conv2d_WM); \ + spec_constants_cpy.push_back(conv2d_WN); \ ggml_vk_create_pipeline( \ device, c.second, #name #type_suffix, \ name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \ - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \ + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives || conv2d_required_subgroup_size, conv2d_required_subgroup_size); \ } #define CREATE_CONVS(spv_suffix) \ CREATE_CONV(conv2d, _f32, spv_suffix) \ @@ -5011,6 +5079,11 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->coopmat2) { CREATE_CONVS(_cm2) } else +#endif +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (conv2d_use_cm1) { + CREATE_CONVS(_cm1) + } else #endif if (conv2d_UNROLL) { CREATE_CONVS(_unroll) @@ -9473,10 +9546,23 @@ static vk_conv_shapes ggml_vk_conv_select_shape(ggml_backend_vk_context * ctx, u // so small convolutions will still choose a smaller tile. const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32; - if (K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) { + // 128x128 isn't used with cm1 due to shared memory size; fall through to a smaller tile. + bool allow_128x128 = true; +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (!ctx->device->coopmat2 && ctx->device->coopmat_support && ctx->device->coopmat_support_16x16x16_f16acc) { + allow_128x128 = false; + } +#endif + + if (allow_128x128 && K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) { return CONV_SHAPE_128x128; } else if (K <= 32 && n_tiles(CONV_SHAPE_32x256) >= shader_core_count * 2) { return CONV_SHAPE_32x256; + } else if (K <= 64 && n_tiles(CONV_SHAPE_64x128) >= shader_core_count * 2) { + return CONV_SHAPE_64x128; + } else if (!allow_128x128 && K > 64 && n_tiles(CONV_SHAPE_64x128) >= shader_core_count * 2) { + // cm1 fallback for large K when 128x128 isn't available + return CONV_SHAPE_64x128; } else { return CONV_SHAPE_64x32; } @@ -10008,7 +10094,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const uint32_t p1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 3) : 0; uint32_t d0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 4) : 1; uint32_t d1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 5) : 1; - vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH); + + // tile-aligned shapes let the shader skip bounds checks + const uint32_t Cin = (uint32_t)src1->ne[2]; + const uint32_t CRS = Cin * KW * KH; + const uint32_t BS_K = vk_conv_block_sizes[shape].K; + const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS; + const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ; + const uint32_t aligned = ((K % BS_K == 0) && + (CRS % BS_CRS == 0) && + (NPQ % BS_NPQ == 0)) ? 1u : 0u; + + vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH, aligned); std::map *pipelines = nullptr; if (op == GGML_OP_CONV_2D) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 875c012cd3b..1428ef68d81 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -7,6 +7,13 @@ #extension GL_KHR_memory_scope_semantics : enable #endif +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + #ifdef USE_COLLECTIVES # extension GL_KHR_shader_subgroup_shuffle : enable #endif @@ -77,6 +84,39 @@ layout(constant_id = 12) const uint d1 = 1; // Kernel spatial sizes layout(constant_id = 13) const uint KW = 1; layout(constant_id = 14) const uint KH = 1; +// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned) +layout(constant_id = 15) const uint aligned = 0; +// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this. +layout(constant_id = 16) const uint csh_store = 0; + +#ifdef COOPMAT +// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of +// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN == +// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size. +layout(constant_id = 17) const uint WM = 32; +layout(constant_id = 18) const uint WN = 32; +const uint TM = 16; +const uint TN = 16; +const uint TK = 16; +const uint cms_per_row = WM / TM; +const uint cms_per_col = WN / TN; +const uint warps_M = BS_K / WM; +const uint warps_N = BS_NPQ / WN; +#endif + +// without padding, H_idx/W_idx are in bounds by construction (non-TRANSPOSE only) +#ifdef TRANSPOSE +const bool hw_in_bounds = false; +#else +const bool hw_in_bounds = (p0 == 0) && (p1 == 0); +#endif + +// TRANSPOSE stride alignment is trivially satisfied for stride 1 +#ifdef TRANSPOSE +const bool stride_in_bounds = (s0 == 1) && (s1 == 1); +#else +const bool stride_in_bounds = true; +#endif uint32_t tid = gl_LocalInvocationID.x; const uint32_t WG_SIZE = gl_WorkGroupSize.x; @@ -94,7 +134,7 @@ uint32_t n_elems_out = K * NPQ; // Number of blocktiles per input uint32_t NB_CRS = splitWork(CRS, BS_CRS); -#ifdef COOPMAT2 +#if defined(COOPMAT2) || defined(COOPMAT) #define SHMEM_TYPE float16_t #else #define SHMEM_TYPE float @@ -112,6 +152,17 @@ const uint32_t Bsh_len = BS_CRS * Bsh_stride; shared SHMEM_TYPE Ash[Ash_len]; // K x CRS shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ +#if defined(COOPMAT2) || defined(COOPMAT) +// stage matC through shmem so global stores are row-major (NPQ-contiguous) +const uint32_t Csh_stride = BS_NPQ; +#ifdef COOPMAT +const uint32_t Csh_len = BS_K * Csh_stride; +#else +const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1; +#endif +shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ +#endif + // Threadtile sizes const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; @@ -161,7 +212,7 @@ ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_T uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; - if (K_idx < K && NPQ_idx < NPQ) { + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { dst_data[dst_idx] = D_TYPE(elem); } return elem; @@ -176,6 +227,13 @@ void main() { #ifdef COOPMAT2 coopmat matC; matC = coopmat(0.0); +#elif defined(COOPMAT) + coopmat sums[cms_per_row * cms_per_col]; + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0); + } + const uint warp_r = gl_SubgroupID / warps_N; + const uint warp_c = gl_SubgroupID % warps_N; #else float regC[TS_K][TS_NPQ]; for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { @@ -228,12 +286,15 @@ void main() { uint32_t B_lx = Ac; uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ #ifdef TRANSPOSE - uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1); + uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03; #else - uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1); + uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03; #endif + if (aligned == 0) { + knl_idx = min(knl_idx, K * CRS - 1); + } float val = knl_data[knl_idx]; - if (K_idx >= K || CRS_idx_a >= CRS) { + if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) { val = 0.0; } Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val); @@ -282,15 +343,27 @@ void main() { uint32_t H_idx = OH_idx * s1 + KH_idx_b * d1 - p1; uint32_t W_idx = OW_idx * s0 + KW_idx_b * d0 - p0; #endif - uint32_t src_idx = - min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); + uint32_t src_idx = W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13; + // skip clamp when address can't go OOB + if (aligned == 0 || !hw_in_bounds || !stride_in_bounds) { + src_idx = min(max(src_idx, 0), p.Cin * p.N * p.W * p.H - 1); + } float val = src_data[src_idx]; - if (CRS_idx_b >= CRS || NPQ_idx >= NPQ - || H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case) + bool oob = false; + if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) { + oob = true; + } + // also catches lower-bound underflow (idx wraps to 0x80000000+) + if (!hw_in_bounds && (H_idx >= p.H || W_idx >= p.W)) { + oob = true; + } #ifdef TRANSPOSE - || (H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0) + if (!stride_in_bounds && + ((H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0))) { + oob = true; + } #endif - ) { + if (oob) { val = 0.0; } Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); @@ -303,6 +376,23 @@ void main() { coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); matC = coopMatMulAdd(matA, matB, matC); +#elif defined(COOPMAT) + // each subgroup multiplies its grid of fragments per TK-sized CRS chunk + [[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) { + coopmat cache_a[cms_per_row]; + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK; + coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + } + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopmat cache_b; + const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN; + coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]); + } + } + } #else if (T_y * TS_K < K) { UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { @@ -325,8 +415,51 @@ void main() { barrier(); } /* Save C* */ +#if defined(COOPMAT2) || defined(COOPMAT) + // stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores +#ifdef COOPMAT + const bool use_staged_store = true; +#else + const bool use_staged_store = (csh_store != 0); +#endif + if (use_staged_store) { +#ifdef COOPMAT + // cm1: each subgroup stores its fragment grid into its Csh slot + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN; + coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + } +#else + coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor); +#endif + barrier(); + + // cooperative shmem->global: WG threads spread across BS_NPQ (the + // contiguous direction of dst), each iter covers store_rows_per_iter K-rows + const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ; + const uint32_t store_iters = BS_K / store_rows_per_iter; + const uint32_t k_thread_offset = tid / BS_NPQ; + const uint32_t npq_thread = tid % BS_NPQ; + [[unroll]] for (uint32_t i = 0; i < store_iters; i++) { + uint32_t k_local = i * store_rows_per_iter + k_thread_offset; + uint32_t K_idx = B_idx_K * BS_K + k_local; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread; + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); + uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); + uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { + dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]); + } + } + } #ifdef COOPMAT2 - coopMatPerElementNV(matC, matC, perElemOpStore); + else { + coopMatPerElementNV(matC, matC, perElemOpStore); + } +#endif #else if (T_y * TS_K < K) { for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { @@ -337,7 +470,7 @@ void main() { uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; - if (K_idx < K && NPQ_idx < NPQ) { + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { dst_data[dst_idx] = regC[T_ly][T_lx]; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index a1d735150fd..a0aac391298 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -984,8 +984,16 @@ void process_shaders() { string_to_spv(name + (unroll ? "_unroll" : ""), "conv2d_mm.comp", defines); #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (unroll) { - defines["COOPMAT2"] = "1"; - string_to_spv(name, "conv2d_mm.comp", defines, true, false, true); + auto cm2_defines = defines; + cm2_defines["COOPMAT2"] = "1"; + string_to_spv(name, "conv2d_mm.comp", cm2_defines, true, false, true); + } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (unroll) { + auto cm1_defines = defines; + cm1_defines["COOPMAT"] = "1"; + string_to_spv(name, "conv2d_mm.comp", cm1_defines, true, true, false); } #endif } From 6a249cd6400a2be44f2fdd5d38248aa2b36d5f92 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Wed, 27 May 2026 01:59:35 +0300 Subject: [PATCH 708/831] ggml-zendnn : fixed naming of matmul function (llama/20964) * ggml-zendnn: fixed naming of matmul function * ggml-zendnn: fixed naming of mul_mat_id function * ggml-zendnn: fixed print in mul_mat_id --------- Co-authored-by: plotnikov.v10 --- ggml/src/ggml-zendnn/ggml-zendnn.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 6051d082003..3c33dcb11a0 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -88,7 +88,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int return true; } -static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k, +static bool ggml_zendnn_gemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k, const void * A, int64_t lda, const void * B, int64_t ldb, void * C, int64_t ldc, int Atype, int Btype, int Ctype) { @@ -200,7 +200,7 @@ static void ggml_zendnn_compute_forward_mul_mat( for (int64_t i12 = 0; i12 < ne12; i12++) { const void* wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data; const size_t row_size = ggml_row_size(vec_dot_type, ne10); - if (!ggml_zendnn_sgemm(ctx, + if (!ggml_zendnn_gemm(ctx, ne01, // m ne11, // n ne10, // k @@ -213,7 +213,7 @@ static void ggml_zendnn_compute_forward_mul_mat( src0->type, src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) - GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); + GGML_ABORT("%s: ZenDNN gemm failed\n", __func__); } } } @@ -355,7 +355,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( } // batched gemm for all tokens in this expert - if (!ggml_zendnn_sgemm(ctx, + if (!ggml_zendnn_gemm(ctx, ne01, // m cne1, // n ne10, // k @@ -368,7 +368,7 @@ static void ggml_zendnn_compute_forward_mul_mat_id( src0->type, src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) { - GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); + GGML_ABORT("%s: ZenDNN gemm failed\n", __func__); } // scatter output rows to destination From 80e87ec453081f649903c70168cf3279fe455eff Mon Sep 17 00:00:00 2001 From: Winston Ma Date: Wed, 27 May 2026 17:48:40 +0800 Subject: [PATCH 709/831] vulkan: avoid preferring transfer queue on AMD UMA devices (llama/22455) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 18d7cedad4b..f45b9cfd1e9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5841,8 +5841,12 @@ static vk_device ggml_vk_get_device(size_t idx) { ggml_vk_load_shaders(device); - // Only use transfer queue on AMD non-GCN, when the graphics queue is not enabled - const bool prefers_transfer_queue = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !allow_graphics_queue; + // Prefer a dedicated transfer queue on AMD dGPUs (non-GCN) when graphics queue use is disabled. + const bool prefers_transfer_queue = + device->vendor_id == VK_VENDOR_ID_AMD && + device->architecture != AMD_GCN && + !device->uma && + !allow_graphics_queue; if (!device->single_queue) { const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; From 98c6722fecccfca0c6ac947b487888bf375b90a2 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Wed, 27 May 2026 14:21:04 +0200 Subject: [PATCH 710/831] CUDA: restrict PDL to CTK >= 12.3 due to MSVC issues (llama/23742) --- ggml/src/ggml-cuda/common.cuh | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index e54ecb29308..50d7763dcdd 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -110,11 +110,14 @@ # define GGML_CUDA_USE_CUB #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 -// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8 and excludes HIP/MUSA. +// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8. +// However, this has been bugged in CTK < 12.3 for MSVC builds, see +// https://github.com/ggml-org/llama.cpp/pull/22522#discussion_r3302393293 // __CUDA_ARCH__ is undefined in host passes; GPU arch check happens in device-side code. -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080 +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && \ + (CUDART_VERSION >= 12030 || (!(defined(_MSC_VER) && !defined(__clang__)) && CUDART_VERSION >= 11080)) # define GGML_CUDA_USE_PDL -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11080 +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && (CUDART_VERSION >= 12030 || (!(defined(_MSC_VER) && !defined(__clang__)) && CUDART_VERSION >= 11080)) static __device__ __forceinline__ void ggml_cuda_pdl_sync() { #if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER From c5cde8c7171dea86345338a7dee64d939b4c09cf Mon Sep 17 00:00:00 2001 From: l8bloom Date: Wed, 27 May 2026 16:59:08 +0200 Subject: [PATCH 711/831] vulkan: add REPEAT op support for f16 to f16. (llama/23298) * feat: extend repeat op for vulkan * feat: add repeat_f16 vulkan pipeline * fix: ensure same dst and src types * fix: use type_size instead of data types * fix: use int16 and int32 for repeat shader op * chore: rename repeat_f* to repeat_i* * chore: rename repeat vulkan pipelines --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 15 +++++++++++---- .../vulkan-shaders/vulkan-shaders-gen.cpp | 4 +++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f45b9cfd1e9..99b42f3bdf0 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -768,7 +768,8 @@ struct vk_device_struct { vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; - vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; + vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32; + vk_pipeline pipeline_repeat_i16; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; @@ -4708,9 +4709,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_i32, "repeat_i32", repeat_i32_len, repeat_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_i16, "repeat_i16", repeat_i16_len, repeat_i16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + #define CREATE_UNARY(name) \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -9738,7 +9741,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_REPEAT: if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { - return ctx->device->pipeline_repeat_f32; + return ctx->device->pipeline_repeat_i32; + } + if (ggml_type_size(src0->type) == 2 && ggml_type_size(dst->type) == 2) { + return ctx->device->pipeline_repeat_i16; } return nullptr; case GGML_OP_REPEAT_BACK: @@ -16253,7 +16259,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } case GGML_OP_REPEAT: - return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); + return ggml_type_size(op->type) == ggml_type_size(op->src[0]->type) && + (ggml_type_size(op->type) == sizeof(float) || ggml_type_size(op->type) == 2); case GGML_OP_REPEAT_BACK: return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ROPE: diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index a0aac391298..24b9d25f733 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -798,9 +798,11 @@ void process_shaders() { string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}}); string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}}); + string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); From 1b590bbb9ae834d31d1116e804249296bc83762c Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Wed, 27 May 2026 10:18:28 -0500 Subject: [PATCH 712/831] vulkan: use GL_NV_cooperative_matrix_decode_vector for faster matmul (llama/23541) --- ggml/src/ggml-vulkan/CMakeLists.txt | 6 + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 189 +++++- .../ggml-vulkan/vulkan-shaders/CMakeLists.txt | 4 + .../vulkan-shaders/dequant_funcs_cm2.glsl | 608 ++++++++++++++++++ .../feature-tests/coopmat2_decode_vector.comp | 7 + .../vulkan-shaders/flash_attn_cm2.comp | 42 +- .../vulkan-shaders/mul_mm_cm2.comp | 8 +- .../src/ggml-vulkan/vulkan-shaders/types.glsl | 7 + 8 files changed, 865 insertions(+), 6 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 65785ae4566..2d9e85794ad 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -79,6 +79,12 @@ if (Vulkan_FOUND) "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" ) + test_shader_extension_support( + "GL_NV_cooperative_matrix_decode_vector" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp" + "GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT" + ) + test_shader_extension_support( "GL_EXT_integer_dot_product" "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp" diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 99b42f3bdf0..fb07282ef76 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -21,6 +21,19 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include +// Fallback definitions for VK_NV_cooperative_matrix_decode_vector in case the +// installed Vulkan headers predate the extension. +#ifndef VK_NV_cooperative_matrix_decode_vector +#define VK_NV_cooperative_matrix_decode_vector 1 +#define VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME "VK_NV_cooperative_matrix_decode_vector" +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV ((VkStructureType)1000689000) +typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV { + VkStructureType sType; + void* pNext; + VkBool32 cooperativeMatrixDecodeVector; +} VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV; +#endif + // SPIR-V Headers: different SDK installations expose different include paths. // LunarG Vulkan SDK on Windows typically provides . // Linux packages, MSYS2 and MinGW often use the Khronos layout . @@ -678,6 +691,7 @@ struct vk_device_struct { uint32_t coopmat_int_k; bool coopmat2; + bool coopmat2_decode_vector; bool pipeline_executable_properties_support {}; @@ -2167,6 +2181,136 @@ static uint32_t compile_count = 0; static std::mutex compile_count_mutex; static std::condition_variable compile_count_cond; +static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367; +static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447; +static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4; + +// Remove SPV_NV_cooperative_matrix_decode_vector usage from a SPIR-V module so it +// can be loaded on drivers that only support SPV_NV_cooperative_matrix2. Drops the +// OpExtension declaration, the CooperativeMatrixDecodeVectorNV OpCapability, and the +// DecodeVectorFunc operand from any OpCooperativeMatrixLoadTensorNV instruction. +// Returns true when the input used the extension (and `out` was populated with a +// stripped copy); returns false otherwise without touching `out`. +static bool ggml_vk_strip_decode_vector(const uint32_t * code, size_t word_count, std::vector & out) { + static const char kDecodeVectorExt[] = "SPV_NV_cooperative_matrix_decode_vector"; + + if (word_count < 5) { + return false; + } + + bool uses_decode_vector = false; + for (size_t pos = 5; pos < word_count; ) { + uint32_t word = code[pos]; + uint32_t wc = word >> spv::WordCountShift; + uint32_t op = word & spv::OpCodeMask; + GGML_ASSERT(wc > 0 && pos + wc <= word_count); + if (op == spv::OpExtension && wc >= 2) { + const char * s = reinterpret_cast(&code[pos + 1]); + if (strcmp(s, kDecodeVectorExt) == 0) { + uses_decode_vector = true; + break; + } + } + pos += wc; + } + + if (!uses_decode_vector) { + return false; + } + + VK_LOG_DEBUG("ggml_vk_strip_decode_vector: stripping SPV_NV_cooperative_matrix_decode_vector"); + + // Bulk-copy unchanged runs and only break the run when an instruction needs to + // be dropped or patched. Use reserve + insert/push_back so the destination buffer + // is touched exactly once (no zero-initialization pass from resize()). + out.clear(); + out.reserve(word_count); + + size_t run_start = 0; + auto flush_run = [&](size_t up_to) { + if (up_to > run_start) { + out.insert(out.end(), code + run_start, code + up_to); + } + }; + + for (size_t pos = 5; pos < word_count; ) { + uint32_t word = code[pos]; + uint32_t wc = word >> spv::WordCountShift; + uint32_t op = word & spv::OpCodeMask; + GGML_ASSERT(wc > 0 && pos + wc <= word_count); + + if (op == spv::OpExtension && wc >= 2) { + const char * s = reinterpret_cast(&code[pos + 1]); + if (strcmp(s, kDecodeVectorExt) == 0) { + flush_run(pos); + pos += wc; + run_start = pos; + continue; + } + } + + if (op == spv::OpCapability && wc == 2 && code[pos + 1] == kSpvCapabilityCooperativeMatrixDecodeVectorNV) { + flush_run(pos); + pos += wc; + run_start = pos; + continue; + } + + if (op == kSpvOpCooperativeMatrixLoadTensorNV) { + // [opcode/wc][ResultType][Result][Pointer][Object][TensorLayout][MemOperand mask][mem extras...][TA mask][ta extras...] + GGML_ASSERT(wc >= 8); + + uint32_t mem_mask = code[pos + 6]; + size_t cur = pos + 7; + // Each of these MemoryAccess bits (when set) carries one trailing operand. + cur += (mem_mask & 0x2) ? 1 : 0; // Aligned + cur += (mem_mask & 0x8) ? 1 : 0; // MakePointerAvailable + cur += (mem_mask & 0x10) ? 1 : 0; // MakePointerVisible + cur += (mem_mask & 0x10000) ? 1 : 0; // AliasScopeINTELMask + cur += (mem_mask & 0x20000) ? 1 : 0; // NoAliasINTELMask + GGML_ASSERT(cur < pos + wc); + + uint32_t ta_mask = code[cur]; + if ((ta_mask & kSpvTensorAddressingDecodeVectorFuncBit) == 0) { + pos += wc; + continue; // leave instruction inside the current unchanged run + } + + flush_run(pos); + + // Append unchanged prefix of the instruction (header through the mem-extras). + size_t inst_start = out.size(); + size_t pre_n = cur - pos; + out.insert(out.end(), code + pos, code + pos + pre_n); + + // Emit TA mask with the DecodeVectorFunc bit cleared. + out.push_back(ta_mask & ~kSpvTensorAddressingDecodeVectorFuncBit); + + // TA extras: TensorView (0x1) and DecodeFunc (0x2) are kept verbatim; + // DecodeVectorFunc (0x4) is dropped along with its trailing id operand. + size_t keep_ta_extras = ((ta_mask & 0x1) ? 1 : 0) + ((ta_mask & 0x2) ? 1 : 0); + if (keep_ta_extras) { + out.insert(out.end(), code + cur + 1, code + cur + 1 + keep_ta_extras); + } + + GGML_ASSERT(wc == pre_n + 1 + keep_ta_extras + 1); + + // Patch the instruction header with the new (one-shorter) word count. + uint32_t new_wc = wc - 1; + out[inst_start] = (new_wc << spv::WordCountShift) | op; + + pos += wc; + run_start = pos; + continue; + } + + pos += wc; + } + + flush_run(word_count); + return true; +} + static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, std::array wg_denoms, std::vector specialization_constants, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { @@ -2238,6 +2382,18 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data()); } +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + if (device->coopmat2 && !device->coopmat2_decode_vector) { + const uint32_t * src = spirv.empty() ? reinterpret_cast(spv_data) : spirv.data(); + size_t src_n = spirv.empty() ? spv_size / sizeof(uint32_t) : spirv.size(); + std::vector stripped; + if (ggml_vk_strip_decode_vector(src, src_n, stripped)) { + spirv = std::move(stripped); + shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data()); + } + } +#endif + pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); vk::PushConstantRange pcr( @@ -5159,6 +5315,7 @@ static vk_device ggml_vk_get_device(size_t idx) { bool amd_shader_core_properties2 = false; bool pipeline_robustness = false; bool coopmat2_support = false; + bool coopmat2_decode_vector_support = false; bool pipeline_executable_properties_support = false; device->coopmat_support = false; device->integer_dot_product = false; @@ -5193,6 +5350,9 @@ static vk_device ggml_vk_get_device(size_t idx) { !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; #endif + } else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) { + coopmat2_decode_vector_support = true; #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { @@ -5470,6 +5630,14 @@ static vk_device ggml_vk_get_device(size_t idx) { } #endif + VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {}; + coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV; + if (coopmat2_decode_vector_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + device_extensions.push_back(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME); + } + #if defined(VK_KHR_shader_bfloat16) VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; bfloat16_features.pNext = nullptr; @@ -5629,6 +5797,7 @@ static vk_device ggml_vk_get_device(size_t idx) { found_fp32_128 && found_fp32_256 && coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { device->coopmat2 = true; + device->coopmat2_decode_vector = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector; } } #endif @@ -5915,6 +6084,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { bool fp16_compute = false; bool coopmat_support = false; bool coopmat2_support = false; + bool coopmat2_decode_vector_support = false; bool integer_dot_product = false; bool bfloat16_support = false; @@ -5933,6 +6103,9 @@ static void ggml_vk_print_gpu_info(size_t idx) { !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; #endif + } else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) { + coopmat2_decode_vector_support = true; #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { @@ -6017,6 +6190,13 @@ static void ggml_vk_print_gpu_info(size_t idx) { } #endif + VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {}; + coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV; + if (coopmat2_decode_vector_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + } + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); fp16 = fp16 && vk12_features.shaderFloat16; @@ -6041,7 +6221,14 @@ static void ggml_vk_print_gpu_info(size_t idx) { #endif && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture); - std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; + coopmat2_decode_vector_support = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector; +#if !defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + coopmat2_decode_vector_support = false; +#endif + + std::string matrix_cores = coopmat2_support ? (coopmat2_decode_vector_support ? "NV_coopmat2v" : "NV_coopmat2") + : coopmat_support ? "KHR_coopmat" + : "none"; std::string device_name = props2.properties.deviceName.data(); GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index e1f613fb4f6..10a9ea21025 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -11,6 +11,10 @@ if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) message(STATUS "Enabling coopmat2 glslc support") endif() +if (GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat2 decode_vector glslc support") +endif() if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) message(STATUS "Enabling dot glslc support") diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index c582aba87dc..7171cbfa559 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -1,4 +1,12 @@ +// Each format defines a scalar dequantFunc plus a V=4 dequantFunc_v +// passed as the optional vector decoder to coopMatLoadTensorNV via +// GL_NV_cooperative_matrix_decode_vector. When the driver doesn't support +// the extension, ggml-vulkan.cpp strips it from the compiled SPIR-V. +#ifdef GL_NV_cooperative_matrix_decode_vector +#extension GL_NV_cooperative_matrix_decode_vector : enable +#endif + #include "types.glsl" layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 { @@ -25,6 +33,19 @@ float16_t dequantFuncQ1_0(const in decodeBufQ1_0 bl, const in uint blockCoords[2 return bit != 0u ? d : -d; } +f16vec4 dequantFuncQ1_0_v(const in decodeBufQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t md = -d; + const uint idx = coordInBlock[1]; + const uint qs_nib = uint(bl.block.qs[idx >> 3]) >> (idx & 0x4u); + return f16vec4( + (qs_nib & 1u) != 0u ? d : md, + (qs_nib & 2u) != 0u ? d : md, + (qs_nib & 4u) != 0u ? d : md, + (qs_nib & 8u) != 0u ? d : md); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { block_q4_0_packed16 block; }; @@ -42,10 +63,28 @@ float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ4_0_v(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_i = (idx & 0xE) >> 1; // even, in {0,2,4,6} + const uint qsw = uint32_t(bl.block.qs[qs_i ]) + | (uint32_t(bl.block.qs[qs_i + 1u]) << 16); + // shift in {0,4}: per-byte mask 0x0F isolates the wanted nibble in each byte. + const uint q4 = (qsw >> shift) & 0x0F0F0F0Fu; + const u8vec4 q = unpack8(q4); + return f16vec4((vec4(q) - vec4(8.0)) * vec4(float(d))); +} + layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { block_q4_1 block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1_packed32 { + block_q4_1_packed32 block; +}; + float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -60,10 +99,27 @@ float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ4_1_v(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ4_1_packed32 bl32 = decodeBufQ4_1_packed32(bl); + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_w = (idx & 0xC) >> 2; // iqs / 4 in [0,4) + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const u8vec4 q = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + return f16vec4(vec4(q) * vec4(float(d)) + vec4(float(m))); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 { block_q5_0 block; }; +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0_packed16 { + block_q5_0_packed16 block; +}; + float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -82,10 +138,32 @@ float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ5_0_v(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_0_packed16 bl16 = decodeBufQ5_0_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_i = (idx & 0xC) >> 1; // packed16 word index, in {0,2,4,6} + const uint qsw = uint32_t(bl16.block.qs[qs_i ]) + | (uint32_t(bl16.block.qs[qs_i + 1u]) << 16); + const u8vec4 ql = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + + const uint uint_qh = uint(bl16.block.qh[1]) << 16 | uint(bl16.block.qh[0]); + const uint qh_pack = uint_qh >> idx; // bits 0..3 = element idx..idx+3 high bits + const uvec4 qh_high = (uvec4(qh_pack, qh_pack >> 1u, qh_pack >> 2u, qh_pack >> 3u) & uvec4(0x01u)) << 4u; + + return f16vec4((vec4(ql) + vec4(qh_high) - vec4(16.0)) * vec4(float(d))); +} + layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 { block_q5_1 block; }; +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1_packed32 { + block_q5_1_packed32 block; +}; + float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -105,6 +183,23 @@ float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ5_1_v(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_1_packed32 bl32 = decodeBufQ5_1_packed32(bl); + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_w = (idx & 0xC) >> 2; // iqs / 4 in [0,4) + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const u8vec4 ql = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + + const uint qh_pack = bl.block.qh >> idx; // bits 0..3 = element idx..idx+3 high bits + const uvec4 qh_high = (uvec4(qh_pack, qh_pack >> 1u, qh_pack >> 2u, qh_pack >> 3u) & uvec4(0x01u)) << 4u; + + return f16vec4((vec4(ql) + vec4(qh_high)) * vec4(float(d)) + vec4(float(m))); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { block_q8_0_packed16 block; }; @@ -121,6 +216,17 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ8_0_v(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint base = idx >> 1u; + const uint w = uint(uint16_t(bl.block.qs[base])) + | (uint(uint16_t(bl.block.qs[base + 1u])) << 16u); + const i8vec4 qi = unpack8(int32_t(w)); + return f16vec4(vec4(qi) * vec4(float(d))); +} + layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K { block_q2_K block; }; @@ -129,6 +235,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2 block_q2_K_packed16 block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K_packed32 { + block_q2_K_packed32 block; +}; + float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); @@ -147,10 +257,36 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ2_K_v(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ2_K_packed32 bl32 = decodeBufQ2_K_packed32(bl); + const f16vec2 dm = bl.block.dm; + const uint idx = coordInBlock[1]; + + const uint scalesi = idx >> 4; // 0..15 + const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6 + + // qs_i (packed16) = ((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1) is even for idx % 4 == 0, + // so qs_w (packed32) = qs_i / 2 = ((idx & 0x80) >> 4) + ((idx & 0x1Cu) >> 2). + const uint qs_w = ((idx & 0x80) >> 4) + ((idx & 0x1Cu) >> 2); + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const uint qs4 = (qsw >> qsshift) & 0x03030303u; + const u8vec4 qi = unpack8(qs4); + + const uint scales = bl.block.scales[scalesi]; + const float16_t d_sub = dm.x * float16_t(scales & 0xF); + const float16_t m_sub = dm.y * float16_t(scales >> 4); + return f16vec4(vec4(qi) * vec4(float(d_sub)) - vec4(float(m_sub))); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K { block_q3_K block; }; +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K_packed16 { + block_q3_K_packed16 block; +}; + float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const uint idx = coordInBlock[1]; @@ -179,6 +315,47 @@ float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ3_K_v(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ3_K_packed16 bl16 = decodeBufQ3_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint n = idx >> 7; // 0,1 + const uint is = idx >> 4; // 0..15 + const uint halfsplit = (idx & 0x60) >> 5; // 0,1,2,3 + const uint qsshift = halfsplit << 1; // 0,2,4,6 + const uint hbit = (n << 2) + halfsplit; // 0..7 (bit position in hmask byte) + + uint32_t scaleidx0 = (is < 8) ? is : (is - 8); + uint32_t scaleidx0shift = (is < 8) ? 0u : 4u; + uint32_t scaleidx1 = is + 8 - (is / 4) * 4; + uint32_t scaleidx1shift = (is / 4) * 2; + + const int8_t us = int8_t( + ((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | + (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4)); + const float16_t dl = bl.block.d * float16_t(int(us) - 32); + + // For idx % 4 == 0: (idx & 0x1F) == (idx & 0x1C) is a multiple of 4. + const uint qsi = (n << 5) + (idx & 0x1Cu); + const uint hmi = (idx & 0x1Cu); + + // Two adjacent uint16 packed16 reads, combined into a uint32 in registers. + // After this: byte j of qsw / hmw holds the data for element idx+j. + const uint qsw = uint32_t(bl16.block.qs[qsi >> 1]) + | (uint32_t(bl16.block.qs[(qsi >> 1) + 1u]) << 16); + const uint hmw = uint32_t(bl16.block.hmask[hmi >> 1]) + | (uint32_t(bl16.block.hmask[(hmi >> 1) + 1u]) << 16); + + // qsshift in {0,2,4,6} and hbit in {0..7}: per-byte masks isolate the wanted bits + // with no inter-byte leakage. + const uint ql4 = (qsw >> qsshift) & 0x03030303u; + const uint qh4 = (hmw >> hbit) & 0x01010101u; + + const ivec4 q = ivec4(unpack8(ql4 | (qh4 << 2))) - ivec4(4); + return f16vec4(vec4(q) * vec4(float(dl))); +} + layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { block_q4_K block; }; @@ -187,6 +364,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4 block_q4_K_packed16 block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed32 { + block_q4_K_packed32 block; +}; + layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 { block_q4_K_packed128 block; }; @@ -334,6 +515,55 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2 return float16_t(ret); } +f16vec4 dequantFuncQ4_K_v(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ4_K_packed32 bl32 = decodeBufQ4_K_packed32(bl); + decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl); + const uint idx = coordInBlock[1]; + + const uint is = idx >> 5; // 0..7 + +#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else + uvec4 v = bl128.block.q4k[0]; + const vec2 loadd = vec2(unpackFloat2x16(v.x)); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); +#endif + + // idx in [0,256); vector decode uses idx a multiple of 4. packed32 word index: + // (qs_i >> 1) == (idx >> 6) * 8 + ((idx & 0x1E) >> 2). sh is 0 or 4 only, so a + // single (w >> sh) & 0x0F0F0F0F isolates all four nibbles without inter-byte leakage. + const uint sh = (idx & 0x20u) >> 3u; + const uint w = uint32_t(bl32.block.qs[(idx >> 6) * 8u + ((idx & 0x1Eu) >> 2)]); + const u8vec4 q = unpack8((w >> sh) & 0x0F0F0F0Fu); + + return f16vec4(vec4(d) * vec4(q) - vec4(m)); +} + layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { block_q5_K block; }; @@ -346,6 +576,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5 block_q5_K_packed128 block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed32 { + block_q5_K_packed32 block; +}; + float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); @@ -399,6 +633,58 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2 return float16_t(ret); } +f16vec4 dequantFuncQ5_K_v(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_K_packed32 bl32 = decodeBufQ5_K_packed32(bl); + decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl); + const uint idx = coordInBlock[1]; + const uint is = idx >> 5; + +#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else + uvec4 v = bl128.block.q5k[0]; + + const f16vec2 loadd = unpackFloat2x16(v.x); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float16_t d = loadd.x * float16_t(sc); + const float16_t m = loadd.y * float16_t(mbyte); +#endif + + // sh is 0 or 4; mask 0x0F0F0F0F covers the four nibbles regardless (no inter-byte leakage). + const uint sh = (idx & 0x20u) >> 3u; + const uint qs_w = (idx >> 6) * 8u + ((idx & 0x1Eu) >> 2); + const uint qh_w = (idx & 0x1Eu) >> 2; + + const uint ql4 = (uint32_t(bl32.block.qs[qs_w]) >> sh) & 0x0F0F0F0Fu; + // qh stores bit `is` per element across 4 consecutive bytes; one shift+mask handles all 4. + const uint qh4 = ((uint32_t(bl32.block.qh[qh_w]) >> is) & 0x01010101u) << 4u; + + const u8vec4 qi = unpack8(ql4 | qh4); + return f16vec4(vec4(qi) * vec4(d) - vec4(m)); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { block_q6_K block; }; @@ -431,6 +717,35 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ6_K_v(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x40) >> 6; + const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6 + const uint is = idx >> 4; + const uint sh = b * 4; // 0 or 4 + + const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); + + const uint ql_i = ((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1); + const uint qh_i = ((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1); + + // Two adjacent uint16 packed16 reads, combined into a uint32 in registers. + // After this: byte j of qlw / qhw holds the data for element idx+j. + const uint qlw = uint32_t(bl16.block.ql[ql_i ]) | (uint32_t(bl16.block.ql[ql_i + 1]) << 16); + const uint qhw = uint32_t(bl16.block.qh[qh_i ]) | (uint32_t(bl16.block.qh[qh_i + 1]) << 16); + + // sh in {0,4} and qhshift in {0,2,4,6}: per-byte masks 0x0F / 0x03 keep only the + // wanted bits with no inter-byte leakage; place qh's 2 bits at nibble high position. + const uint ql4 = (qlw >> sh) & 0x0F0F0F0Fu; + const uint qh4 = ((qhw >> qhshift) & 0x03030303u) << 4u; + + const ivec4 qi = ivec4(unpack8(ql4 | qh4)); + return f16vec4((vec4(qi) - vec4(32.0f)) * vec4(float(dscale))); +} + #if defined(DATA_A_IQ1_S) layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S { block_iq1_s block; @@ -453,6 +768,29 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta)); return ret; } + +f16vec4 dequantFuncIQ1_S_v(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; + const uint ib8 = idx >> 3; + const int i8b = int(idx & 4); // 0 or 4 + + const uint qh = bl.block.qh[ib32]; + const uint qs = bl.block.qs[ib8]; + const float dl = float(d) * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000u) != 0u) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]; + + const ivec4 q = ivec4( + bitfieldExtract(int(grid), 2 * (i8b + 0), 2), + bitfieldExtract(int(grid), 2 * (i8b + 1), 2), + bitfieldExtract(int(grid), 2 * (i8b + 2), 2), + bitfieldExtract(int(grid), 2 * (i8b + 3), 2)); + return f16vec4((vec4(q) + vec4(delta)) * dl); +} #endif #if defined(DATA_A_IQ1_M) @@ -485,6 +823,33 @@ float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta)); return ret; } + +f16vec4 dequantFuncIQ1_M_v(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl); + const uint idx = coordInBlock[1]; + + uvec2 scales = unpack32(bl64.block.scales); + const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16))); + + const uint ib8 = idx >> 3; + const uint ib16 = idx >> 4; + const int i8b = int(idx & 4); // 0 or 4 -- i8 base for the V=4 group + + const uint sc = bl.block.scales[ib8 / 8]; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2.0 * float(bitfieldExtract(sc, 3 * int(ib16 & 3), 3)) + 1.0; + const float delta = ((qh & 8u) != 0u) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | ((qh & 7u) << 8)]; + + const ivec4 q = ivec4( + bitfieldExtract(int(grid), 2 * (i8b + 0), 2), + bitfieldExtract(int(grid), 2 * (i8b + 1), 2), + bitfieldExtract(int(grid), 2 * (i8b + 2), 2), + bitfieldExtract(int(grid), 2 * (i8b + 3), 2)); + return f16vec4((vec4(q) + vec4(delta)) * (float(d) * dl)); +} #endif #if defined(DATA_A_IQ2_XXS) @@ -520,6 +885,33 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); return float16_t(ret[idx & 1]); } + +f16vec4 dequantFuncIQ2_XXS_v(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; + const uint ib8 = (idx & 0x18) >> 3; + const uint iqs = 8 * ib32 + ib8; + + const uint qs = bl.block.qs[iqs]; + const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28)); + + uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7); + sign |= bitCount(sign) << 7; + const uint sb = sign >> (idx & 7u); + + const uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2]; + const u8vec4 g = unpack8(g2); + + return f16vec4( + dscale * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + dscale * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + dscale * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + dscale * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ2_XS) @@ -548,6 +940,31 @@ float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoor vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); return float16_t(ret[idx & 1]); } + +f16vec4 dequantFuncIQ2_XS_v(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + + const uint is = idx >> 5; + const uint sshift = (idx & 0x10) >> 2; + const uint iqs = idx >> 3; + + const uint16_t qs = bl.block.qs[iqs]; + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF)); + + uint sign = uint(qs >> 9); + sign |= bitCount(sign) << 7; + const uint sb = sign >> (idx & 7u); + + const uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2]; + const u8vec4 g = unpack8(g2); + + return f16vec4( + dscale * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + dscale * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + dscale * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + dscale * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ2_S) @@ -576,6 +993,32 @@ float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords const vec2 v = db * vec2(sign01) * vec2(unpack8(g2)); return float16_t(v[idx & 1]); } + +f16vec4 dequantFuncIQ2_S_v(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; + const uint ib8 = idx >> 3; + const uint qhshift = 2 * (ib8 % 4); + + const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib32]; + const uint sb = uint(bl.block.qs[QUANT_K / 8 + ib8]) >> (idx & 0x6u); + + const float d = float(bl.block.d); + const float db = d * 0.25 * (0.5 + scale); + + const uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2]; + const u8vec4 g = unpack8(g2); + + return f16vec4( + db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ3_XXS) @@ -609,6 +1052,32 @@ float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCo const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); return float16_t(v[idx & 1]); } + +f16vec4 dequantFuncIQ3_XXS_v(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint iqs = idx >> 2; + const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3); + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint signs = pack32(u16vec2(bl16.block.qs[is/2+0], bl16.block.qs[is/2+1])); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + + const uint sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sb = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6u); + + const uint grid = iq3xxs_grid[qs]; + const u8vec4 g = unpack8(grid); + + return f16vec4( + db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ3_S) @@ -635,6 +1104,30 @@ float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords return float16_t(v[idx & 1]); } + +f16vec4 dequantFuncIQ3_S_v(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + + const uint iqs = idx >> 2; + const uint iqh = idx >> 5; + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint qh = bl.block.qh[iqh]; + const uint sb = uint(bl.block.signs[iqs / 2]) >> (idx & 0x6u); + const uint scale = bl.block.scales[iqs / 16]; + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + + const uint grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const u8vec4 g = unpack8(grid); + + return f16vec4( + db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ4_XS) @@ -642,6 +1135,10 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4 block_iq4_xs block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufIQ4_XS_packed32 { + block_iq4_xs_packed32 block; +}; + float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -657,6 +1154,30 @@ float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoor float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]); return ret; } + +f16vec4 dequantFuncIQ4_XS_v(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ4_XS_packed32 bl32 = decodeBufIQ4_XS_packed32(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; // 0..7 + const uint sl = (bl32.block.scales_l >> (4 * ib32)) & 0xF; + const uint sh = (uint(bl32.block.scales_h) >> (2 * ib32)) & 0x3; + const uint qshift = (idx & 0x10) >> 2; // {0, 4} + const uint qs_w = 4 * ib32 + ((idx & 0xC) >> 2); // iqs / 4, in [0,32) + + const float16_t dl = d * float16_t(int(sl | (sh << 4)) - 32); + + const uint qsw = bl32.block.qs[qs_w]; + const u8vec4 qv = unpack8((qsw >> qshift) & 0x0F0F0F0Fu); + const vec4 ret = vec4( + float(kvalues_iq4nl[qv.x]), + float(kvalues_iq4nl[qv.y]), + float(kvalues_iq4nl[qv.z]), + float(kvalues_iq4nl[qv.w])) * float(dl); + return f16vec4(ret); +} #endif #if defined(DATA_A_IQ4_NL) @@ -664,6 +1185,10 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4 block_iq4_nl block; }; +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL_packed16 { + block_iq4_nl_packed16 block; +}; + float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -676,6 +1201,24 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; return ret; } + +f16vec4 dequantFuncIQ4_NL_v(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ4_NL_packed16 bl16 = decodeBufIQ4_NL_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_i = (idx & 0xC) >> 1; // packed16 word index, in {0,2,4,6} + const uint qsw = uint32_t(bl16.block.qs[qs_i ]) + | (uint32_t(bl16.block.qs[qs_i + 1u]) << 16); + // shift in {0,4}: per-byte mask 0x0F isolates the wanted nibble in each byte. + const u8vec4 q = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + return f16vec4( + float(d) * float(kvalues_iq4nl[q.x]), + float(d) * float(kvalues_iq4nl[q.y]), + float(d) * float(kvalues_iq4nl[q.z]), + float(d) * float(kvalues_iq4nl[q.w])); +} #endif #if defined(DATA_A_MXFP4) @@ -695,6 +1238,26 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5); return ret; } + +f16vec4 dequantFuncMXFP4_v(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float d = e8m0_to_fp32(bl.block.e); + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uvec4 qv = uvec4( + uint(bl.block.qs[iqs]), + uint(bl.block.qs[iqs + 1u]), + uint(bl.block.qs[iqs + 2u]), + uint(bl.block.qs[iqs + 3u])); + qv = (qv >> shift) & 0xFu; + const vec4 ret = vec4( + float(kvalues_mxfp4[qv.x]), + float(kvalues_mxfp4[qv.y]), + float(kvalues_mxfp4[qv.z]), + float(kvalues_mxfp4[qv.w])) * d * 0.5f; + return f16vec4(ret); +} #endif #if defined(DATA_A_NVFP4) @@ -702,6 +1265,10 @@ layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVF block_nvfp4 block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4_packed32 { + block_nvfp4_packed32 block; +}; + float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const uint idx = coordInBlock[1]; @@ -713,56 +1280,97 @@ float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords qs = (qs >> shift) & 0xF; return float16_t(kvalues_mxfp4[qs] * d * 0.5); } + +f16vec4 dequantFuncNVFP4_v(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufNVFP4_packed32 bl32 = decodeBufNVFP4_packed32(bl); + const uint idx = coordInBlock[1]; + const uint sub = idx >> 4; + const uint qs_w = ((idx & 0x30) >> 3) + ((idx & 0x4u) >> 2); // iqs / 4, in [0,8) + const uint shift = (idx & 0x8) >> 1; + const float d = ue4m3_to_fp32(bl.block.d[sub]); + + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const u8vec4 qv = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + const vec4 ret = vec4( + float(kvalues_mxfp4[qv.x]), + float(kvalues_mxfp4[qv.y]), + float(kvalues_mxfp4[qv.z]), + float(kvalues_mxfp4[qv.w])) * d * 0.5f; + return f16vec4(ret); +} #endif #if defined(DATA_A_Q1_0) #define dequantFuncA dequantFuncQ1_0 +#define dequantFuncA_v dequantFuncQ1_0_v #elif defined(DATA_A_Q4_0) #define dequantFuncA dequantFuncQ4_0 +#define dequantFuncA_v dequantFuncQ4_0_v #elif defined(DATA_A_Q4_1) #define dequantFuncA dequantFuncQ4_1 +#define dequantFuncA_v dequantFuncQ4_1_v #elif defined(DATA_A_Q5_0) #define dequantFuncA dequantFuncQ5_0 +#define dequantFuncA_v dequantFuncQ5_0_v #elif defined(DATA_A_Q5_1) #define dequantFuncA dequantFuncQ5_1 +#define dequantFuncA_v dequantFuncQ5_1_v #elif defined(DATA_A_Q8_0) #define dequantFuncA dequantFuncQ8_0 +#define dequantFuncA_v dequantFuncQ8_0_v #elif defined(DATA_A_Q2_K) #define dequantFuncA dequantFuncQ2_K +#define dequantFuncA_v dequantFuncQ2_K_v #elif defined(DATA_A_Q3_K) #define dequantFuncA dequantFuncQ3_K +#define dequantFuncA_v dequantFuncQ3_K_v #elif defined(DATA_A_Q4_K) #define dequantFuncA dequantFuncQ4_K +#define dequantFuncA_v dequantFuncQ4_K_v #define fetch_scales fetch_scalesQ4_K #define store_scales store_scalesQ4_K #elif defined(DATA_A_Q5_K) #define dequantFuncA dequantFuncQ5_K +#define dequantFuncA_v dequantFuncQ5_K_v #define fetch_scales fetch_scalesQ5_K #define store_scales store_scalesQ4_K #elif defined(DATA_A_Q6_K) #define dequantFuncA dequantFuncQ6_K +#define dequantFuncA_v dequantFuncQ6_K_v #elif defined(DATA_A_IQ1_S) #define dequantFuncA dequantFuncIQ1_S +#define dequantFuncA_v dequantFuncIQ1_S_v #elif defined(DATA_A_IQ1_M) #define dequantFuncA dequantFuncIQ1_M +#define dequantFuncA_v dequantFuncIQ1_M_v #elif defined(DATA_A_IQ2_XXS) #define dequantFuncA dequantFuncIQ2_XXS +#define dequantFuncA_v dequantFuncIQ2_XXS_v #elif defined(DATA_A_IQ2_XS) #define dequantFuncA dequantFuncIQ2_XS +#define dequantFuncA_v dequantFuncIQ2_XS_v #elif defined(DATA_A_IQ2_S) #define dequantFuncA dequantFuncIQ2_S +#define dequantFuncA_v dequantFuncIQ2_S_v #elif defined(DATA_A_IQ3_XXS) #define dequantFuncA dequantFuncIQ3_XXS +#define dequantFuncA_v dequantFuncIQ3_XXS_v #elif defined(DATA_A_IQ3_S) #define dequantFuncA dequantFuncIQ3_S +#define dequantFuncA_v dequantFuncIQ3_S_v #elif defined(DATA_A_IQ4_XS) #define dequantFuncA dequantFuncIQ4_XS +#define dequantFuncA_v dequantFuncIQ4_XS_v #elif defined(DATA_A_IQ4_NL) #define dequantFuncA dequantFuncIQ4_NL +#define dequantFuncA_v dequantFuncIQ4_NL_v #elif defined(DATA_A_MXFP4) #define dequantFuncA dequantFuncMXFP4 +#define dequantFuncA_v dequantFuncMXFP4_v #elif defined(DATA_A_NVFP4) #define dequantFuncA dequantFuncNVFP4 +#define dequantFuncA_v dequantFuncNVFP4_v #elif defined(DATA_A_F32) #define dequantFuncA dequantFuncF32 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp new file mode 100644 index 00000000000..65e9c678401 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_NV_cooperative_matrix_decode_vector : require + +void main() +{ +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 141bb870883..6d45b4931df 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -11,6 +11,9 @@ #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable #extension GL_NV_cooperative_matrix2 : enable +#ifdef GL_NV_cooperative_matrix_decode_vector +#extension GL_NV_cooperative_matrix_decode_vector : enable +#endif #extension GL_EXT_buffer_reference : enable #extension GL_KHR_shader_subgroup_ballot : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -54,6 +57,41 @@ float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const } } +// V=4 vector decode for K/V; dispatches to per-format _v decoders. +f16vec4 faDecodeKVector(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeK) { + case 0u: return f16vec4(decodeBufF32(bl_in).block); + case 2u: return dequantFuncQ4_0_v(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1_v(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0_v(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1_v(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0_v(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0_v(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return f16vec4(0); + } +} + +f16vec4 faDecodeVVector(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeV) { + case 0u: return f16vec4(decodeBufF32(bl_in).block); + case 2u: return dequantFuncQ4_0_v(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1_v(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0_v(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1_v(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0_v(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0_v(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return f16vec4(0); + } +} + +#ifdef GL_NV_cooperative_matrix_decode_vector +#define FADECODEK , faDecodeK, faDecodeKVector +#define FADECODEV , faDecodeV, faDecodeVVector +#else +#define FADECODEK , faDecodeK +#define FADECODEV , faDecodeV +#endif + layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; layout (binding = 2) readonly buffer V {uint8_t data_v[];}; @@ -259,7 +297,7 @@ void main() { // F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128. const bool k_use_decode = (bs_k > 1u); if (k_use_decode) { - coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose, faDecodeK); + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose FADECODEK); } else { coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); } @@ -325,7 +363,7 @@ void main() { uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; const bool v_use_decode = (bs_v > 1u); if (v_use_decode) { - coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad), faDecodeV); + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) FADECODEV); } else { coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 497a18ff8a7..250d708479b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -71,10 +71,12 @@ layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #if QUANT_K > 1 -#define DECODEFUNCA , dequantFuncA - #include "dequant_funcs_cm2.glsl" - +#if defined(dequantFuncA_v) && defined(GL_NV_cooperative_matrix_decode_vector) +#define DECODEFUNCA , dequantFuncA, dequantFuncA_v +#else +#define DECODEFUNCA , dequantFuncA +#endif #else #define DECODEFUNCA #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 4bcd97756fd..06eff6f219f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -1722,11 +1722,18 @@ struct block_nvfp4 uint8_t qs[QUANT_K_NVFP4 / 2]; }; +struct block_nvfp4_packed32 +{ + uint32_t d[QUANT_K_NVFP4 / 16 / 4]; + uint32_t qs[QUANT_K_NVFP4 / 2 / 4]; +}; + #if defined(DATA_A_NVFP4) #define QUANT_K QUANT_K_NVFP4 #define QUANT_R QUANT_R_NVFP4 #define QUANT_AUXF 1 #define A_TYPE block_nvfp4 +#define A_TYPE_PACKED32 block_nvfp4_packed32 #endif #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) From 8bce478ee8be5cc5f78d6d38cbacdd4f6f1ae64e Mon Sep 17 00:00:00 2001 From: Matt Corallo <649246+TheBlueMatt@users.noreply.github.com> Date: Wed, 27 May 2026 15:19:23 +0000 Subject: [PATCH 713/831] vulkan: Switch MUL_MAT_VEC to 4 K per iteration for F16/32 (llama/22887) * vulkan: Switch MUL_MAT_VEC to 4 K per iteration for F16/32 Against mesa git, this shows a 4.8% performance improvement for tg128 on Qwen3.5-9B:BF16 on Intel BMG. Note that this breaks some tests until the last commit which fixes OOB A reads. * vulkan: Use aligned loads in mul_mat_vec when available Against mesa git, this shows a 3.3% performance improvement for tg128 on Qwen3.5-9B:BF16 on Intel BMG. * Make explicit that `num_rows` is <= `NUM_ROWS` in mul_mat_vec Mesa's UUB logic can't see through conditionals, limiting its ability to understand the bounds on the `num_rows` field in the cleanup run. Making it explicit that `num_rows` is, indeed, always <= `NUM_ROWS` helps mesa make slightly better codegen. Against mesa git, this currently shows a 1% performance improvement in tg128 on Qwen3.5-9B:BF16 on Intel BMG. * vulkan: Fix OOB A reads in MUL_MAT_VEC for odd sizes There was a TODO to fix the OOB reads from the A matrix which we do here. It is within performance noise (+<0.1%) in tg128 for Qwen3.5-9B:BF16 on Intel BMG. --- .../vulkan-shaders/dequant_funcs.glsl | 39 +++++ .../vulkan-shaders/mul_mat_vec.comp | 149 ++++++++++++++---- .../src/ggml-vulkan/vulkan-shaders/types.glsl | 2 + 3 files changed, 163 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 88d07d2dfd5..e67299fdeca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -5,21 +5,60 @@ #include "types.glsl" #if defined(DATA_A_F32) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return data_a[a_offset + ib]; +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} + #endif #if defined(DATA_A_F16) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return data_a[a_offset + ib]; +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + const vec2 a = data_a_packed32[(a_offset + ib)/2]; + const vec2 b = data_a_packed32[(a_offset + ib)/2 + 1]; + return vec4(a, b); +} #endif #if defined(DATA_A_BF16) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return bf16_to_fp32(data_a[a_offset + ib]); +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1])); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(bf16_to_fp32(data_a[a_offset + ib ]), bf16_to_fp32(data_a[a_offset + ib + 1]), + bf16_to_fp32(data_a[a_offset + ib + 2]), bf16_to_fp32(data_a[a_offset + ib + 3])); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + const uint a = data_a_packed32[(a_offset + ib)/2]; + const uint b = data_a_packed32[(a_offset + ib)/2 + 1]; + return vec4(uintBitsToFloat((a & 0x0000ffff) << 16), + uintBitsToFloat( a & 0xffff0000), + uintBitsToFloat((b & 0x0000ffff) << 16), + uintBitsToFloat( b & 0xffff0000)); +} #endif #if defined(DATA_A_Q4_0) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 2271be4021b..5a9d0e778fd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -10,12 +10,38 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; #if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16) #define K_PER_ITER 8 #else -#define K_PER_ITER 2 +#define K_PER_ITER 4 #endif uint a_offset, b_offset, d_offset, y_offset; +vec4 load_b(const uint j, const uint iybs, const uint iqs, const bool lastiter, out bool OOB_y, out bool OOB_z, out bool OOB_w) { + // Check if the latter elements are OOB, and don't fetch B or accumulate it. + OOB_y = lastiter && (iybs + iqs + y_offset >= p.ncols); + OOB_z = lastiter && (iybs + iqs + y_offset*2 >= p.ncols); + OOB_w = lastiter && (iybs + iqs + y_offset*3 >= p.ncols); + + if (!OOB_w) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*3])); + } else if (!OOB_z) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]), + 0); + } else if (!OOB_y) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + 0, 0); + } else { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + 0, 0, 0); + } +} + void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { @@ -25,6 +51,8 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const #if K_PER_ITER == 8 #if QUANT_R == 2 + // Note that we end up fetching bogus elements here, but its fine as they'll be + // within an accessible block. const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]); const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); @@ -34,18 +62,11 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); #endif #else - // Check if the second of the pair of elements is OOB, and don't fetch B or - // accumulate it. We still fetch a pair of elements for A, which is fine for - // quantized formats since they'll be within the same block. We should - // probably skip fetching the second element for F16/F32, but as of now we - // still do. - const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); - - FLOAT_TYPE b0 = 0, b1 = 0; - b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]); - if (!OOB) { - b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); - } + bool OOB_y; + bool OOB_z; + bool OOB_w; + + const vec4 b = load_b(j, iybs, iqs, lastiter, OOB_y, OOB_z, OOB_w); #endif uint ibi = first_row*p.ncols; [[unroll]] for (uint n = 0; n < num_rows; ++n) { @@ -71,22 +92,60 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const temp[j][n] += rowtmp; #else - const vec2 v = dequantize(ib, iqs, a_offset); - - // matrix multiplication - temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]); - if (!OOB) { - temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]); + if (!OOB_w) { + const vec4 v = dequantize4(ib, iqs, a_offset); + temp[j][n] += dot(v, b); + } else if (!OOB_z) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const FLOAT_TYPE v1 = dequantize1(ib + 2/QUANT_R, iqs, a_offset); + const vec3 v = vec3(v0.x, v0.y, v1); + const vec3 b0 = vec3(b.x, b.y, b.z); + temp[j][n] += dot(v, b0); + } else if (!OOB_y) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const vec2 b0 = vec2(b.x, b.y); + temp[j][n] += dot(v0, b0); + } else { + const FLOAT_TYPE v = dequantize1(ib, iqs, a_offset); + temp[j][n] = fma(v, b.x, temp[j][n]); } #endif } } } +#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) +void iter_aligned_nonquant(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) +{ + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; + const uint iqs = 0; // quant index + const uint iybs = col; // y block start index + + const vec4 b = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]; + + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib = (ibi + col)/QUANT_K; // block index + ibi += p.ncols; + + const vec4 v = dequantize4_2aligned(ib, iqs, a_offset); + + // matrix multiplication + temp[j][n] += dot(v, b); + } + } +} +#endif + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; get_offsets(a_offset, b_offset, d_offset); + const bool is_aligned_nonquant = + p.batch_stride_b % 4 == 0 && b_offset % 4 == 0 && + p.ncols % 4 == 0 && BLOCK_SIZE % 4 == 0 && + K_PER_ITER == 4; y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; @@ -105,17 +164,26 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { int unroll_count = 4; uint unrolled_iters = num_iters & ~(unroll_count - 1); -#if K_PER_ITER == 2 + uint i = 0; + +#if K_PER_ITER == 4 // If the K dimension is odd, we need lastiter==true on the last iteration // so OOB is computed correctly. Skip some unrolling to make that happen. - if ((p.ncols & 1) != 0 && + if ((p.ncols & 3) != 0 && unrolled_iters == num_iters && unrolled_iters > 0) { unrolled_iters -= unroll_count; } + if (is_aligned_nonquant) { + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + } else { #endif - - uint i = 0; while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { @@ -123,18 +191,30 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { i++; } } +#if K_PER_ITER == 4 + } +#endif unroll_count = 2; unrolled_iters = num_iters & ~(unroll_count - 1); -#if K_PER_ITER == 2 - if ((p.ncols & 1) != 0 && +#if K_PER_ITER == 4 + if ((p.ncols & 3) != 0 && unrolled_iters == num_iters && unrolled_iters > 0) { unrolled_iters -= unroll_count; } -#endif + if (is_aligned_nonquant) { + while (i < unrolled_iters && is_aligned_nonquant) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + } else { +#endif while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { @@ -142,10 +222,25 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { i++; } } +#if K_PER_ITER == 4 + } +#endif + +#if K_PER_ITER == 4 + if (is_aligned_nonquant) { + while (i < num_iters) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } else { +#endif while (i < num_iters) { iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true); i++; } +#if K_PER_ITER == 4 + } +#endif reduce_result(temp, d_offset, first_row, num_rows, tid); } @@ -164,6 +259,6 @@ void main() { if (first_row >= p.stride_d) { return; } - compute_outputs(first_row, p.stride_d - first_row); + compute_outputs(first_row, min(NUM_ROWS, p.stride_d - first_row)); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 06eff6f219f..f84d6f87334 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -31,6 +31,7 @@ #else #define A_TYPE float16_t #endif +#define A_TYPE_PACKED32 f16vec2 #endif #if defined(DATA_A_BF16) @@ -44,6 +45,7 @@ #else #define A_TYPE uint16_t #endif +#define A_TYPE_PACKED32 uint32_t #endif #define QUANT_K_Q4_0 32 From a52bd385d678e152774c211dc7a8ac372650558b Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura Date: Thu, 28 May 2026 01:48:12 +0900 Subject: [PATCH 714/831] ggml-webgpu: Fix how to dispatch WG to some ops (llama/23750) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 53 +++++++++++-------- ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl | 10 ++-- .../wgsl-shaders/mul_mat_id_gather.wgsl | 43 +++++++-------- 3 files changed, 57 insertions(+), 49 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f113da909ce..f6d17a073be 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -749,8 +749,11 @@ static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x; + uint32_t wg_y; + uint32_t total_wg = CEIL_DIV(ne, decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, @@ -974,9 +977,10 @@ static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx, auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_x; + uint32_t wg_y; uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); - uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); - uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } @@ -1064,9 +1068,10 @@ static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx, auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_x; + uint32_t wg_y; uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); - uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); - uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } @@ -1689,14 +1694,11 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, gathered_count_ids_binding_size), }; - const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - - const uint32_t gather_total_wg = param_n_expert; - const uint32_t gather_wg_x = std::min(gather_total_wg, max_wg_per_dim); - const uint32_t gather_wg_y = CEIL_DIV(gather_total_wg, gather_wg_x); + // n_expert is much less than maxComputeWorkgroupsPerDimension (e.g., n_exeprt=256 at Qwen3.5-35B-A3B) + const uint32_t gather_wg_x = param_n_expert; dispatches.push_back({ - gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, gather_wg_y } + gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, 1 } }); // params for mul_mat_id.wgsl @@ -1748,7 +1750,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts; uint32_t total_wg = wg_m * max_wg_n; - compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); dispatches.push_back({ main_pipeline, std::move(main_params), std::move(main_entries), { wg_x, wg_y } @@ -2771,10 +2773,12 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * block_size, npr, nrows }; - const uint32_t total_wg_init = npr * nrows; - const uint32_t max_wg = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - const uint32_t wg_x_init = std::min(total_wg_init, max_wg); - const uint32_t wg_y_init = CEIL_DIV(total_wg_init, wg_x_init); + uint32_t wg_x_init; + uint32_t wg_y_init; + const uint32_t total_wg_init = npr * nrows; + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + compute_2d_workgroups(total_wg_init, max_wg_per_dim, wg_x_init, wg_y_init); + std::vector init_entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), init_align_offset, init_binding_size) @@ -2831,9 +2835,11 @@ static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), align_out, size_out) }; + uint32_t wg_x_merge; + uint32_t wg_y_merge; const uint32_t total_wg_merge = nm * nrows; - const uint32_t wg_x_merge = std::min(total_wg_merge, max_wg); - const uint32_t wg_y_merge = CEIL_DIV(total_wg_merge, wg_x_merge); + compute_2d_workgroups(total_wg_merge, max_wg_per_dim, wg_x_merge, wg_y_merge); + dispatches.push_back({ argsort_merge_pipeline, std::move(merge_params), std::move(merge_entries), { wg_x_merge, wg_y_merge } }); @@ -2953,9 +2959,12 @@ static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * s webgpu_pipeline pipeline = ctx->shader_lib->get_upscale_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); - uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); - uint32_t wg_x = std::min(ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, total_wg); - uint32_t wg_y = CEIL_DIV(total_wg, wg_x); + + uint32_t wg_x; + uint32_t wg_y; + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl index fa3bdf4e393..e268adfb16b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl @@ -49,12 +49,14 @@ struct Params{ var params: Params; @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { +fn main( + @builtin(global_invocation_index) gindex: u32, +) { + if (gindex >= params.ne) { return; } - var i = gid.x; + var i = gindex; let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); let i2 = i / (params.src_ne1 * params.src_ne0); @@ -62,7 +64,7 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i1 = i / params.src_ne0; let i0 = i % params.src_ne0; - var j = gid.x; + var j = gindex; let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); let j2 = j / (params.dst_ne1 * params.dst_ne0); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl index d79d5f3f282..581e922709d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl @@ -21,35 +21,32 @@ var count:atomic; @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, - @builtin(local_invocation_id) local_id: vec3, - @builtin(num_workgroups) num_wg: vec3) { + @builtin(local_invocation_id) local_id: vec3) { let thread_id = local_id.x; - let own_expert = wg_id.y * num_wg.x + wg_id.x; // the expert assigned to this workgroup + let own_expert = wg_id.x; // the expert assigned to this workgroup - if (own_expert < params.n_expert) { - if (thread_id == 0u) { - atomicStore(&count, 0); - } + if (thread_id == 0u) { + atomicStore(&count, 0); + } - workgroupBarrier(); - - for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) { - let row = i / params.n_expert_used; - let col = i % params.n_expert_used; - let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]); - if (own_expert == expert) { - let pos = atomicAdd(&count, 1u); - let gathered_id = own_expert * params.n_tokens + pos; - global_gathered_expert_used[gathered_id] = col; - global_gathered_tokens[gathered_id] = row; - } + workgroupBarrier(); + + for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) { + let row = i / params.n_expert_used; + let col = i % params.n_expert_used; + let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]); + if (own_expert == expert) { + let pos = atomicAdd(&count, 1u); + let gathered_id = own_expert * params.n_tokens + pos; + global_gathered_expert_used[gathered_id] = col; + global_gathered_tokens[gathered_id] = row; } + } - workgroupBarrier(); + workgroupBarrier(); - if (thread_id == 0u) { - gathered_count_ids[own_expert] = atomicLoad(&count); - } + if (thread_id == 0u) { + gathered_count_ids[own_expert] = atomicLoad(&count); } } From 3bbe93378cc96339d362e4dbf490df10412ad389 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Wed, 27 May 2026 10:46:11 -0700 Subject: [PATCH 715/831] hexagon: add support for Q4_1 in MUL_MAT and MUL_MAT_ID (llama/23647) * hex-mm: add support for Q4_1 matmul/matvec, hvx-only for now * hmx-mm: add support for Q4_1 * hex-mm: use Q8_1 dynamic quantization to avoid having to compute sums in the vec_dot * hexagon: fix repack scratch buffer overflow * hex-mm: fix Q4_1 repack buffer sizing * hexagon: flip the build order for mm and fa (seems to help LTO) * hex-mm: add vec_dot 4x1s and minor HMX cleanup after adding Q4_1 * hex-mm: fix fp16 vec_dot fallback to 2x1 and another issue that could cause incorrect output * hexagon: resurrect early-wake and add support for polling for op-batch completions With Q4_1 ggml-hexagon now claims pretty much the entire graphs which gives the CPU more time to chilax. This is a good thing! But it does add extra latency for the pure benchmark runs. Early wakeup helps recover the latency a bit in the normals runs and op-batch polling is just for benchmarking. --------- Co-authored-by: Todor Boinovski --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 267 ++- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 4 +- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 155 +- ggml/src/ggml-hexagon/htp/htp-ops.h | 2 + ggml/src/ggml-hexagon/htp/main.c | 7 +- ggml/src/ggml-hexagon/htp/matmul-ops.c | 1769 ++++++++++++++++++-- 6 files changed, 2004 insertions(+), 200 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 1c8ecc197e9..5e8a4a740c1 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -68,6 +68,7 @@ static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C } static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; static int opt_opbatch = 1024; // max number of ops in a batch static int opt_opqueue = 16; // max number of pending batches +static int opt_oppoll = 0; // polling for batch completions static std::regex* opt_opfilter = NULL; // regex of ops to not claim @@ -550,7 +551,7 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -611,7 +612,7 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -660,6 +661,239 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) ggml_aligned_free(buf_rp, row_size_rp); } +static void unpack_q4_1_quants(uint8_t * qs, const block_q4_1 * x, unsigned int bi) { + static const int qk = QK4_1; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const int x0 = (x->qs[i] & 0x0F); + const int x1 = (x->qs[i] >> 4); + qs[bi * qk + i + 0] = x0; + qs[bi * qk + i + qk / 2] = x1; + } +} + +static void pack_q4_1_quants(block_q4_1 * x, const uint8_t * qs, unsigned int bi) { + static const int qk = QK4_1; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const uint8_t x0 = qs[bi * qk + i + 0]; + const uint8_t x1 = qs[bi * qk + i + qk / 2]; + x->qs[i] = x0 | (x1 << 4); + } +} + +static void repack_row_q4_1x4x2(uint8_t * y, const block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers + + const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes + const int qblk_size = qk / 2; // int4 = 128 bytes + const int qrow_size = k / 2; // int4 (not padded to blocks) + + uint8_t * y_q = y + 0; // quants first + uint8_t * y_d = y + qrow_size; // then scales/offsets + + // Repack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + unpack_q4_1_quants(qs, &x[i * 8 + 0], 0); + unpack_q4_1_quants(qs, &x[i * 8 + 1], 1); + unpack_q4_1_quants(qs, &x[i * 8 + 2], 2); + unpack_q4_1_quants(qs, &x[i * 8 + 3], 3); + unpack_q4_1_quants(qs, &x[i * 8 + 4], 4); + unpack_q4_1_quants(qs, &x[i * 8 + 5], 5); + unpack_q4_1_quants(qs, &x[i * 8 + 6], 6); + unpack_q4_1_quants(qs, &x[i * 8 + 7], 7); + + bool partial = (nloe && i == nb-1); + + uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; + } + } + + // Repack the scales and offsets + for (int i = 0; i < nb; i++) { + ggml_half * d_m = (ggml_half *) (y_d + i * dblk_size); + for (int j = 0; j < 8; j++) { + d_m[j * 2 + 0] = x[i * 8 + j].d; + d_m[j * 2 + 1] = x[i * 8 + j].m; + } + } +} + +static void unpack_row_q4_1x4x2(block_q4_1 * x, const uint8_t * y, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers + + const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes + const int qblk_size = qk / 2; // int4 = 128 bytes + const int qrow_size = k / 2; // int4 (not padded to blocks) + + const uint8_t * y_q = y + 0; // quants first + const uint8_t * y_d = y + qrow_size; // then scales/offsets + + // Unpack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_0x4x2]; + bool partial = (nloe && i == nb-1); + + const uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + if (partial) { + qs[j*2+0] = q[j] & 0x0F; + qs[j*2+1] = q[j] >> 4; + } else { + qs[j+000] = q[j] & 0x0F; + qs[j+128] = q[j] >> 4; + } + } + + pack_q4_1_quants(&x[i * 8 + 0], qs, 0); + pack_q4_1_quants(&x[i * 8 + 1], qs, 1); + pack_q4_1_quants(&x[i * 8 + 2], qs, 2); + pack_q4_1_quants(&x[i * 8 + 3], qs, 3); + pack_q4_1_quants(&x[i * 8 + 4], qs, 4); + pack_q4_1_quants(&x[i * 8 + 5], qs, 5); + pack_q4_1_quants(&x[i * 8 + 6], qs, 6); + pack_q4_1_quants(&x[i * 8 + 7], qs, 7); + } + + // Unpack the scales and offsets + for (int i = 0; i < nb; i++) { + const ggml_half * d_m = (const ggml_half *) (y_d + i * dblk_size); + for (int j = 0; j < 8; j++) { + x[i * 8 + j].d = d_m[j * 2 + 0]; + x[i * 8 + j].m = d_m[j * 2 + 1]; + } + } +} + +static void init_row_q4_1x4x2(block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + memset(qs, 0, sizeof(qs)); + + for (int i = 0; i < nb; i++) { + pack_q4_1_quants(&x[i * 8 + 0], qs, 0); + pack_q4_1_quants(&x[i * 8 + 1], qs, 1); + pack_q4_1_quants(&x[i * 8 + 2], qs, 2); + pack_q4_1_quants(&x[i * 8 + 3], qs, 3); + pack_q4_1_quants(&x[i * 8 + 4], qs, 4); + pack_q4_1_quants(&x[i * 8 + 5], qs, 5); + pack_q4_1_quants(&x[i * 8 + 6], qs, 6); + pack_q4_1_quants(&x[i * 8 + 7], qs, 7); + } + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < 8; j++) { + x[i * 8 + j].d = 0; + x[i * 8 + j].m = 0; + } + } +} + +static void repack_q4_1_q4x4x2(ggml_tensor * t, const void * data, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4_1-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + memcpy(buf_pd, src, row_size); + repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + if (n_rem_bytes > 0) { + const int64_t i = n_full_rows; + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + memcpy(buf_pd, src, n_rem_bytes); + repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, n_rem_bytes); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + +static void repack_q4x4x2_q4_1(void * data, const ggml_tensor * t, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_1 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + memset(buf_rp, 0, row_size_rp); // clear-out padded buffer to make sure the tail is all zeros + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + memcpy(buf_rp, src, row_size); + unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); + memcpy(dst, buf_pd, row_size); + } + + if (n_rem_bytes > 0) { + const int64_t i = n_full_rows; + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + // We still need to read and unpack the entire source row because quantization is block-based. + memcpy(buf_rp, src, row_size); + unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); + memcpy(dst, buf_pd, n_rem_bytes); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + // ======== Q8x4x2 ==================== static void dump_block_q8_0(const block_q8_0 * b, int i) { HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2], @@ -876,7 +1110,7 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -937,7 +1171,7 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -1238,7 +1472,7 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -1299,7 +1533,7 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -1365,6 +1599,12 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, repack_q4_0_q4x4x2(tensor, data, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4_1_q4x4x2(tensor, data, size); + break; + case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1407,6 +1647,12 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, repack_q4x4x2_q4_0(data, tensor, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4x4x2_q4_1(data, tensor, size); + break; + case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1886,7 +2132,8 @@ void ggml_hexagon_session::flush_pending(bool all) { uint32_t n_dbufs; // Read response packet from queue - int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, DSPQUEUE_TIMEOUT); + const uint32_t timeo = opt_oppoll ? 0 : DSPQUEUE_TIMEOUT; + int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, timeo); if (err == AEE_EEXPIRED) { continue; } @@ -2327,6 +2574,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: @@ -2377,6 +2625,7 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: @@ -3622,6 +3871,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { // Basic sanity checks to make sure definitions match static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0, "please update hexagon_type to match ggml_type"); + static_assert((unsigned int) HTP_TYPE_Q4_1 == (unsigned int) GGML_TYPE_Q4_1, + "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0, "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4, @@ -3634,6 +3885,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); + const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL"); const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER"); const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); const char * str_etm = getenv("GGML_HEXAGON_ETM"); @@ -3671,6 +3923,7 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; + opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll; opt_profile = str_profile ? atoi(str_profile) : 0; opt_etm = str_etm ? atoi(str_etm) : 0; opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 33d67dda9cc..d7927261a85 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -59,14 +59,14 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) if (_hmx_idx GREATER_EQUAL 0) target_sources(${HTP_LIB} PRIVATE hmx-queue.c - hmx-matmul-ops.c hmx-flash-attn-ops.c + hmx-matmul-ops.c ) # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) set_source_files_properties( - hmx-matmul-ops.c hmx-flash-attn-ops.c + hmx-matmul-ops.c PROPERTIES COMPILE_OPTIONS "-mhmx" ) diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 3ef0bcdb26d..ab5fd73380b 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -34,6 +34,10 @@ static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, }; +static const __fp16 q4_1_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, +}; + // MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value // kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { @@ -62,6 +66,8 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) { case HTP_TYPE_Q4_0: case HTP_TYPE_IQ4_NL: return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb + case HTP_TYPE_Q4_1: + return (size_t) nb * (QK_Q4_0x4x2 / 2 + 32); // 160 * nb case HTP_TYPE_Q8_0: return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb case HTP_TYPE_MXFP4: @@ -233,6 +239,54 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( return r; } +static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_dm = hvx_vmemu(scale_offset); + HVX_Vector v_scales = hvx_vec_repl_f16(v_dm); + HVX_Vector v_offsets = hvx_vec_repl_f16(Q6_V_vror_VR(v_dm, 2)); + + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_hf = Q6_V_lo_W(vp); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets)); +} + +static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx( + const uint8_t *packed_128, bool upper_nibbles, + const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + v_quants = Q6_Vb_vshuff_Vb(v_quants); + + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_lo = Q6_V_lo_W(vp); + HVX_Vector v_hi = Q6_V_hi_W(vp); + + HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2); + HVX_Vector vd = Q6_V_lo_W(dm_deal); + HVX_Vector vm = Q6_V_hi_W(dm_deal); + + HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vd); + HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vd, 4)); + + HVX_Vector v_os01 = hvx_vec_repl_2x_f16(vm); + HVX_Vector v_os23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vm, 4)); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01), v_os01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23), v_os23)); + + HVX_Vector_x2 r = { v_lo, v_hi }; + return r; +} + // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { HVX_Vector vq = hvx_vmemu(quants_32); @@ -331,11 +385,13 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( int start_tile, int end_tile) { const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS; - const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_IQ4_NL); + const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_Q4_1 || weight_type == HTP_TYPE_IQ4_NL); + const bool is_q4_1 = (weight_type == HTP_TYPE_Q4_1); const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block; const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) : (weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) : + (weight_type == HTP_TYPE_Q4_1) ? hvx_vmem(q4_1_to_fp16_lut) : hvx_vmem(q4_0_to_fp16_lut); // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. @@ -356,8 +412,10 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 bool upper = (sub_blk_base >= 4); unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes - unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE - + sub_blk_base * (int)sizeof(__fp16); // 4 consecutive scales + unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE; + unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16); + unsigned scale_off = qrow_size + blk_idx * dblk_size + + sub_blk_base * scale_step; __fp16 *tile_bases[4]; for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; } @@ -367,20 +425,38 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + if (is_q4_1) { + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; - HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + HVX_Vector_x2 dv0 = dequantize_x4x2_q4_1_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector_x2 dv1 = dequantize_x4x2_q4_1_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + } else { + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + + HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } } for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } @@ -446,26 +522,43 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; bool upper = (sub_blk >= 4); unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; - unsigned scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE; + unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16); + unsigned scale_off = qrow_size + blk_idx * dblk_size + sub_blk * scale_step; HVX_Vector v_off = v_scat_base; // reset to column 0 unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; - - HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx( - r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector v1 = (row1 < n_cols) - ? dequantize_x4x2_q4_0_group_hvx( - r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) - : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + if (is_q4_1) { + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + + HVX_Vector v0 = dequantize_x4x2_q4_1_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector v1 = (row1 < n_cols) + ? dequantize_x4x2_q4_1_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) + : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + } else { + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { + const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; + const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; + + HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); + HVX_Vector v1 = (row1 < n_cols) + ? dequantize_x4x2_q4_0_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) + : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } } (void) *(volatile HVX_Vector *)(tile_base); } else if (weight_type == HTP_TYPE_MXFP4) { @@ -593,6 +686,8 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( // --- End x4x2 dequantizers --- +#pragma clang diagnostic ignored "-Wbackend-plugin" // spurios warning for hmx intrinsics + // requires external HMX lock static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales, int n_row_tiles, int n_col_tiles, int n_dot_tiles) { diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 54cfadd9b0a..aadc77235ba 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -20,6 +20,7 @@ enum htp_data_type { HTP_TYPE_F32 = 0, HTP_TYPE_F16 = 1, HTP_TYPE_Q4_0 = 2, + HTP_TYPE_Q4_1 = 3, HTP_TYPE_Q8_0 = 8, HTP_TYPE_IQ4_NL = 20, HTP_TYPE_I32 = 26, @@ -28,6 +29,7 @@ enum htp_data_type { // types used internally for repack, dyn.quant, etc HTP_TYPE_Q4_0x4x2 = 200, + HTP_TYPE_Q4_1x4x2, HTP_TYPE_Q8_0x4x2, HTP_TYPE_MXFP4x4x2, diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index f3a0866c7cd..7dd90ac7d7f 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -853,6 +853,11 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { for (uint32_t i=0; i < n_ops; i++) { struct profile_data prof; + if (i == (n_ops-1)) { + // wake up the host before starting the last op + dspqueue_write_early_wakeup_noblock(queue, 0, 0); + } + profile_start(ctx->profiler, &prof); proc_op_req(octx, tens, i, &ops[i]); @@ -869,8 +874,6 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { } } - // dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); - struct htp_opbatch_rsp rsp; rsp.id = req.id; rsp.status = HTP_STATUS_OK; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 46fc5602dc9..7036c491bc4 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -40,6 +40,11 @@ struct htp_matmul_context { const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1); + void (*vec_dot_4x1)(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0); + // Precomputed values uint32_t src0_nrows_per_thread; uint32_t src1_nrows_per_thread; @@ -155,6 +160,13 @@ static inline size_t q8x4x2_row_size(uint32_t ne) { return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128); } +static inline size_t q8_1x4x2_row_size(uint32_t ne) { + // ensures perfect alignment of quants and full row + const uint32_t qk = QK_Q8_0x4x2; + const uint32_t nb = (ne + qk - 1) / qk; + return hex_round_up(ne + nb * 8 * 2 * sizeof(__fp16), 128); +} + static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; @@ -223,6 +235,62 @@ static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, u return r; } +static inline HVX_Vector_x8 hvx_vec_load_q4_1x4x8_full(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) + HVX_Vector v2_3 = vptr[1]; // ... + HVX_Vector v4_5 = vptr[2]; // ... + HVX_Vector v6_7 = vptr[3]; // ... + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ... + HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 + HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F + HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 + HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F + HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +static HVX_Vector_x8 hvx_vec_load_q4_1x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i=0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i*2+0] = v0; + r.v[i*2+1] = v1; + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... + r.v[i*2+0] = Q6_V_lo_W(v0_1_p); + r.v[i*2+1] = Q6_V_hi_W(v0_1_p); + } + + return r; +} + static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; @@ -401,82 +469,96 @@ static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 return hvx_vec_rmpy_x8_partial(x, y, 512); } -static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { +static void vec_dot_q4_1x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size assert((unsigned long) vx0 % 128 == 0); assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes const uint32_t x_qblk_size = qk / 2; // int4 const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales/offsets const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums // Row sum (sf) HVX_Vector r0_sum = Q6_V_vzero(); - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - const uint32_t nb = n / qk; // num full blocks const uint32_t nloe = n % qk; // num leftover elemements uint32_t i = 0; for (; i < nb; i++) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); } // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); r0_ia = Q6_V_vand_QV(bmask, r0_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); } r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, +static void vec_dot_q4_1x4x2_q8x4x2_2x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size @@ -486,11 +568,11 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes const uint32_t x_qblk_size = qk / 2; // int4 const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -500,77 +582,306 @@ static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums // Row sum (sf) HVX_Vector r0_sum = Q6_V_vzero(); HVX_Vector r1_sum = Q6_V_vzero(); - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - const uint32_t nb = n / qk; // num full blocks const uint32_t nloe = n % qk; // num leftover elemements uint32_t i = 0; for (; i < nb; i++) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); } // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r1_ms = Q6_V_vand_QV(bmask, r1_ms); r0_ia = Q6_V_vand_QV(bmask, r0_ia); r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, +static void vec_dot_q4_1x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); + HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); + + HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); + HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); + + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); + + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); + + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); + HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); + + HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); + HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); + + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r1_ms = Q6_V_vand_QV(bmask, r1_ms); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r2_ms = Q6_V_vand_QV(bmask, r2_ms); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r3_ms = Q6_V_vand_QV(bmask, r3_ms); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); + + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); + + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + +static void vec_dot_q4_1x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { assert(n % 32 == 0); @@ -581,11 +892,11 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes const uint32_t x_qblk_size = qk / 2; // int4 const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) @@ -595,9 +906,9 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales/sums const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales/sums // Row sums (sf) - 4 accumulators for 2×2 tile HVX_Vector r0_c0_sum = Q6_V_vzero(); @@ -610,13 +921,13 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * uint32_t i = 0; for (; i < nb; i++) { - // Load src1 columns (reused across both src0 rows) + // Load src1 columns HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + // Load src0 rows + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); @@ -625,16 +936,38 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); // Load scales - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); + HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); + + HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); + HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); // Compute combined scales HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); // Apply scales and accumulate HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); @@ -642,40 +975,72 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); + HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); + HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); + HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); } // Process leftovers if (nloe) { HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); + HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); + + HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); + HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c0_ms = Q6_V_vand_QV(bmask, r0_c0_ms); r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r0_c1_ms = Q6_V_vand_QV(bmask, r0_c1_ms); r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c0_ms = Q6_V_vand_QV(bmask, r1_c0_ms); r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r1_c1_ms = Q6_V_vand_QV(bmask, r1_c1_ms); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); @@ -686,10 +1051,15 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); + HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); + HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); + HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); } // Reduce and store results @@ -700,26 +1070,26 @@ static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 } -static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { +static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size assert((unsigned long) vx0 % 128 == 0); assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vzero(); @@ -729,12 +1099,12 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo // Apply scale to acc and accumulate into the row sum (qf32). const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) + const uint32_t nloe = n % qk; // num leftover elemements uint32_t i = 0; for (; i < nb; i++) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); @@ -751,7 +1121,433 @@ static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const vo // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + + hvx_vec_store_u(s0, 4, r0_sum); +} + +static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elemements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_q4x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + +static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); @@ -804,10 +1600,109 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, // Row sum (qf32) HVX_Vector r0_sum = Q6_V_vzero(); HVX_Vector r1_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_q8x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (qf32) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); const uint32_t nb = n / qk; // num full blocks int32_t nloe = n % qk; // num leftover elemements (must be signed) @@ -817,58 +1712,86 @@ static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_full(r3_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); } - // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_partial(r3_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); r0_ia = Q6_V_vand_QV(bmask, r0_ia); r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); } - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); } + static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { @@ -1163,6 +2086,135 @@ static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, hvx_vec_store_u(s0, 8, rsum); } +static void vec_dot_iq4nlx4x2_q8x4x2_4x1(const int n, + float * restrict s0, + const void * restrict vx0, + const void * restrict vx1, + const void * restrict vx2, + const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, @@ -1282,37 +2334,148 @@ static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n, HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); - hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); +} + +static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_MXFP4x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + + // Zero-out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { +static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_MXFP4x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). + // Apply scale to acc and accumulate into the row sum (f32). const uint32_t nb = n / qk; // num full blocks int32_t nloe = n % qk; // num leftover elemements (must be signed) @@ -1321,11 +2484,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const for (; i < nb; i++) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 @@ -1341,23 +2507,32 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const r0_d = Q6_V_vdelta_VV(r0_d, expand); r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 @@ -1373,30 +2548,40 @@ static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const r0_d = Q6_V_vdelta_VV(r0_d, expand); r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - // Zero-out unused scales + // Zero-out unused values HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); } - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, +static void vec_dot_mxfp4x4x2_q8x4x2_4x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size assert((unsigned long) vx0 % 128 == 0); assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_MXFP4x4x2 * 4; @@ -1413,17 +2598,19 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vzero(); HVX_Vector r1_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (f32). + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); const uint32_t nb = n / qk; // num full blocks int32_t nloe = n % qk; // num leftover elemements (must be signed) @@ -1433,13 +2620,19 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_full(r3_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 @@ -1447,9 +2640,6 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, vy_d = Q6_Vsf_equals_Vqf32(vy_d); // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); r0_d = Q6_V_vdelta_VV(r0_d, expand); @@ -1458,29 +2648,46 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, r1_d = Q6_V_vdelta_VV(r1_d, expand); r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + r2_d = Q6_V_vdelta_VV(r2_d, expand); + r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); + r2_d = Q6_Vw_vasl_VwR(r2_d, 23); + r3_d = Q6_V_vdelta_VV(r3_d, expand); + r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); + r3_d = Q6_Vw_vasl_VwR(r3_d, 23); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); } - // Process leftovers if (nloe) { HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 @@ -1488,9 +2695,6 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, vy_d = Q6_Vsf_equals_Vqf32(vy_d); // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); r0_d = Q6_V_vdelta_VV(r0_d, expand); @@ -1499,28 +2703,46 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, r1_d = Q6_V_vdelta_VV(r1_d, expand); r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + r2_d = Q6_V_vdelta_VV(r2_d, expand); + r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); + r2_d = Q6_Vw_vasl_VwR(r2_d, 23); + r3_d = Q6_V_vdelta_VV(r3_d, expand); + r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); + r3_d = Q6_Vw_vasl_VwR(r3_d, 23); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); // Zero-out unused values HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); r0_ia = Q6_V_vand_QV(bmask, r0_ia); r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); } - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); } + static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { @@ -2138,7 +3360,6 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); // no work for this thread if (src0_start_row >= src0_end_row) { @@ -2168,39 +3389,89 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { const uint8_t * restrict src1_col = (const uint8_t *) src1_data; float * restrict dst_col = (float *) dst->data; - // Prefill spad with 2x src0 rows - #pragma unroll(2) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; + if (mmctx->vec_dot_4x1 != NULL) { + const uint32_t src0_end_row_x4 = src0_start_row + ((src0_end_row - src0_start_row) & ~3U); + + // Prefill spad with 4x src0 rows + #pragma unroll(4) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= MM_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 4); } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 2); - } - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col); - // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + // Prefetch next (n + spad_nrows) row + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x4) { + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 4); + } + } + + // Process leftovers + uint32_t ir0 = src0_end_row_x4; + if (ir0 + 2 <= src0_end_row) { + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 2); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + ir0 += 2; } - } + if (ir0 < src0_end_row) { + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + ir0 += 1; + } + } else { + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); - // Process the last row (if any) - if (src0_end_row != src0_end_row_x2) { - const uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + // Prefill spad with 2x src0 rows + #pragma unroll(2) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= MM_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); + } + + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + + // Prefetch next (n + spad_nrows) row + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x2) { + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); + } + } + + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + const uint32_t ir0 = src0_end_row_x2; + const uint32_t is0 = (ir0 - src0_start_row); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + } } hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); @@ -2432,6 +3703,94 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) { // *** dynamic quant +static inline void quantize_block_f32_q8_1x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_q % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + // Use reduce max fp32 to find max(abs(e)) first + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + // Load and convert into QF32 + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements + + // Convert to QF32 + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + // Combine and convert to fp16 + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + // Convert into fp16 + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + // Divide input by the scale + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + // Convert to int8 + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + *(HVX_Vector *) y_q = vx_i8; + + // --- Sum calculation --- + const HVX_Vector ones = Q6_Vb_vsplat_R(1); + HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); // sum every 4 consecutive elements + // Sum 8 elements: + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16)); + + // Copy to stack to extract sums and vmaxes + float vmax0[32] __attribute__((aligned(128))); + float vmax1[32] __attribute__((aligned(128))); + float vmax2[32] __attribute__((aligned(128))); + float vmax3[32] __attribute__((aligned(128))); + int32_t sums[32] __attribute__((aligned(128))); + + hvx_vec_store_u(vmax0, 128, vmax0_sf); + hvx_vec_store_u(vmax1, 128, vmax1_sf); + hvx_vec_store_u(vmax2, 128, vmax2_sf); + hvx_vec_store_u(vmax3, 128, vmax3_sf); + hvx_vec_store_u(sums, 128, v_sums); + + float d0 = vmax0[0] / 127.0f; + float d1 = vmax1[0] / 127.0f; + float d2 = vmax2[0] / 127.0f; + float d3 = vmax3[0] / 127.0f; + + __fp16 * y_d_half = (__fp16 *) y_d; + y_d_half[0] = d0; + y_d_half[1] = (float) sums[0] * d0; + y_d_half[2] = d1; + y_d_half[3] = (float) sums[8] * d1; + y_d_half[4] = d2; + y_d_half[5] = (float) sums[16] * d2; + y_d_half[6] = d3; + y_d_half[7] = (float) sums[24] * d3; +} + static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { assert((unsigned long) x % 128 == 0); assert((unsigned long) y_q % 128 == 0); @@ -2656,6 +4015,77 @@ static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_0x4x2; + const uint32_t nb = (k + qk - 1) / qk; + + const uint32_t qrow_size = k; // int8 + + const uint32_t dblk_size = 8 * 4; // 8x (d, s) __fp16 = 32 bytes + const uint32_t qblk_size = QK_Q8_0x4x2; // int8 + + uint8_t * restrict y_q = (y + 0); // quants first + uint8_t * restrict y_d = (y + qrow_size); // then scales/sums + + // Temp scales override input since we're working off of the aligned temp buffer in VTCM + uint8_t * restrict t_d = (uint8_t *) x; + + for (uint32_t i = 0; i < nb; i++) { + quantize_block_f32_q8_1x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_f32_q8_1x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); + } + + // now copy the scales/sums into final location + hvx_copy_f16_ua(y_d, t_d, nb * 16); +} + +static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = octx->src1_spad.data; + struct htp_spad * spad = &octx->src0_spad; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = src->nb[1]; + const size_t dst_row_size = q8_1x4x2_row_size(ne0); + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); + uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); + + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); + memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding + + for (uint32_t i = ir_first; i < ir_last; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_1x4x2((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; @@ -2751,24 +4181,35 @@ static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_t mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q4x4x2_q8x4x2_4x1; + return 0; + case HTP_TYPE_Q4_1: + mmctx->type = "q4_1x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q4_1x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q4_1x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q4_1x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q4_1x4x2_q8x4x2_4x1; return 0; case HTP_TYPE_Q8_0: mmctx->type = "q8x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q8x4x2_q8x4x2_4x1; return 0; case HTP_TYPE_IQ4_NL: mmctx->type = "iq4nlx4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_iq4nlx4x2_q8x4x2_4x1; return 0; case HTP_TYPE_MXFP4: mmctx->type = "mxfp4x4x2-f32"; mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1; mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_mxfp4x4x2_q8x4x2_4x1; return 0; default: return -1; @@ -2894,8 +4335,13 @@ static int op_matmul_hvx(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); + if (src0->type == HTP_TYPE_Q4_1) { + quant_job_func = quantize_f32_q8_1x4x2; + src1_row_size = q8_1x4x2_row_size(ne10); + } else { + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + } htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); } @@ -2962,7 +4408,7 @@ int op_matmul(struct htp_ops_context * octx) { // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. // Other types fall back to HVX. uint32_t wtype = src0->type; - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { return op_matmul_hvx(octx); } @@ -3098,8 +4544,13 @@ int op_matmul_id(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); + if (src0->type == HTP_TYPE_Q4_1) { + quant_job_func = quantize_f32_q8_1x4x2; + src1_row_size = q8_1x4x2_row_size(ne10); + } else { + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + } const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); From 8c8f213daccd54f7a913034c68a885c11b851134 Mon Sep 17 00:00:00 2001 From: Reese Levine Date: Wed, 27 May 2026 14:22:33 -0700 Subject: [PATCH 716/831] ggml-webgpu: remove legacy constants (llama/23672) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f6d17a073be..1846886db4e 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -94,14 +94,6 @@ static inline uint32_t ggml_webgpu_u32_from_f32(float value) { #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 -// For operations which process a row in parallel, this seems like a reasonable -// default -#define WEBGPU_ROW_SPLIT_WG_SIZE 64 - -// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to -// implementations so this can be removed, necessary only for get_rows right now -#define WEBGPU_MAX_WG_SIZE 288 - /* End Constants */ // This is a "fake" base pointer, since WebGPU buffers do not have pointers to @@ -631,7 +623,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, size_t size) { std::vector params = { (uint32_t) offset, (uint32_t) size, value }; std::vector entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; - size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->capabilities.memset_bytes_per_thread; + size_t bytes_per_wg = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread; uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); @@ -1366,7 +1358,7 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, shader_lib_ctx.src0 = src; shader_lib_ctx.src1 = nullptr; shader_lib_ctx.dst = dst; - shader_lib_ctx.max_wg_size = WEBGPU_MAX_WG_SIZE; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); auto * decisions = static_cast(pipeline.context.get()); @@ -3716,13 +3708,13 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) { static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { // we use the maximum workgroup size for the memset pipeline - size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; // Size the bytes_per_thread so that the largest buffer size can be handled ctx->capabilities.memset_bytes_per_thread = CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads); std::vector constants(2); constants[0].key = "wg_size"; - constants[0].value = WEBGPU_MAX_WG_SIZE; + constants[0].value = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; constants[1].key = "bytes_per_thread"; constants[1].value = ctx->capabilities.memset_bytes_per_thread; ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); From 7e843a80e1df28463a64a92ffd10410d5280ab3b Mon Sep 17 00:00:00 2001 From: ymcki <84055651+ymcki@users.noreply.github.com> Date: Thu, 28 May 2026 12:23:21 +0800 Subject: [PATCH 717/831] opencl: OP_GATED_DELTA_NET (llama/23312) * OP_GATED_DELTA_NET impl * add back lanes_per_column declaration * removed has_subgroup_arithmetic and has_subgroup_clustered_reduce * removed trailing spaces and fixes indentation. Hard coded subgroup size for Adreno and Intel. Return not supported when K>1 state snapshot * support for K>1 state snapshot * removed picky indent multiple of 4 fixes * removed return that won\'t be executed --- ggml/src/ggml-opencl/CMakeLists.txt | 1 + ggml/src/ggml-opencl/ggml-opencl.cpp | 345 ++++++++++++++++-- .../ggml-opencl/kernels/gated_delta_net.cl | 247 +++++++++++++ 3 files changed, 566 insertions(+), 27 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gated_delta_net.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index f75d089b574..446fb727996 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -164,6 +164,7 @@ set(GGML_OPENCL_KERNELS sqr sqrt ssm_conv + gated_delta_net sub sum_rows cumsum diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 42286435bc6..6d6c3e8973d 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -412,6 +412,7 @@ struct ggml_backend_opencl_context { size_t max_workgroup_size; bool fp16_support; bool has_vector_subgroup_broadcast; + bool has_qcom_subgroup_shuffle = false; // cl_qcom_subgroup_shuffle bool disable_fusion; std::regex *opfilter = nullptr; // regex of ops to not claim @@ -634,6 +635,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_conv_2d_f32; cl_kernel kernel_conv_2d_f16_f32; cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; + // [size_idx][kda][tgpp] where size_idx: 0=S_V=16, 1=32, 2=64, 3=128; kda: 0 or 1. + // tgpp 0 = TG variant (COLS_PER_LANE_GROUP=1), tgpp 1 = prefill variant (COLS_PER_LANE_GROUP=4). + cl_kernel kernel_gated_delta_net_f32[4][2][2] = {}; + cl_kernel kernel_timestep_embedding; cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns; cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns; @@ -837,16 +842,16 @@ static std::vector g_ggml_backend_opencl_devices; static std::vector> g_ggml_backend_opencl_dev_ctxs; inline std::string read_file(const std::string &path) { - std::ifstream ifs(path); - if (!ifs) { - return ""; - } - std::string text; - ifs.seekg(0, std::ios::end); - text.resize(ifs.tellg()); - ifs.seekg(0, std::ios::beg); - ifs.read(&text[0], text.size()); - return text; + std::ifstream ifs(path); + if (!ifs) { + return ""; + } + std::string text; + ifs.seekg(0, std::ios::end); + text.resize(ifs.tellg()); + ifs.seekg(0, std::ios::beg); + ifs.read(&text[0], text.size()); + return text; } static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts) { @@ -2463,12 +2468,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); CL_CHECK((backend_ctx->kernel_upscale = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale", &err), err)); if (backend_ctx->program_upscale) { - cl_int err_bilinear; - backend_ctx->kernel_upscale_bilinear = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale_bilinear", &err_bilinear); - if (err_bilinear != CL_SUCCESS) { + cl_int err_bilinear; + backend_ctx->kernel_upscale_bilinear = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale_bilinear", &err_bilinear); + if (err_bilinear != CL_SUCCESS) { GGML_LOG_WARN("ggml_opencl: kernel_upscale_bilinear not found in upscale.cl. Bilinear upscale will not be available. Error: %d\n", err_bilinear); backend_ctx->kernel_upscale_bilinear = nullptr; - } + } } else { backend_ctx->kernel_upscale_bilinear = nullptr; } @@ -2538,8 +2543,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } - // conv2d - { + // conv2d + { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "conv2d.cl.h" @@ -2597,6 +2602,86 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } + // gated_delta_net: one kernel per (S_V, KDA, tgpp) triple. + { + #ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gated_delta_net.cl.h" + }; + #else + const std::string kernel_src = read_file("gated_delta_net.cl"); + #endif + + const int gdn_sizes[4] = { 16, 32, 64, 128 }; + const int sg_size = backend_ctx->gpu_family == GPU_FAMILY::ADRENO ? 64 : backend_ctx->gpu_family == GPU_FAMILY::INTEL ? 32 : -1; + if (sg_size < 0) { + GGML_LOG_ERROR("Unsupported GPU Family: only Adreno and Intel are supported.\n"); + exit(1); + } + + for (int si = 0; si < 4; si++) { + const int S_V = gdn_sizes[si]; + + // MUST match the dispatcher heuristic in ggml_cl_gated_delta_net exactly. + int lanes_per_column; + if (S_V >= 128) { + lanes_per_column = 8; + } else { + lanes_per_column = std::min(S_V, sg_size); + } + + // Round LANES_PER_COLUMN down until it is: + // * power-of-two + // * divides both S_V and sg_size + while (lanes_per_column > 1 && + (((lanes_per_column & (lanes_per_column - 1)) != 0) || + (S_V % lanes_per_column) != 0 || + (sg_size % lanes_per_column) != 0)) { + lanes_per_column >>= 1; + } + + GGML_ASSERT(lanes_per_column >= 1); + GGML_ASSERT(((lanes_per_column & (lanes_per_column - 1)) == 0)); + GGML_ASSERT((S_V % lanes_per_column) == 0); + GGML_ASSERT((sg_size % lanes_per_column) == 0); + + const bool is_partial_reduce = (lanes_per_column != 1) && (lanes_per_column < sg_size); + int use_qcom_shuffle = 0; + if (is_partial_reduce) { + if (backend_ctx->has_qcom_subgroup_shuffle) { + use_qcom_shuffle = 1; + } + } + for (int kda = 0; kda < 2; kda++) { + for (int tgpp = 0; tgpp < 2; tgpp++) { + const int cpl = (tgpp == 0) ? 1 : 4; + const int spw = (tgpp == 0) ? 1 : 1; + + std::string opts = compile_opts; + opts += " -DS_V=" + std::to_string(S_V); + opts += " -DKDA=" + std::to_string(kda); + opts += " -DSUBGROUP_SIZE=" + std::to_string(sg_size); + opts += " -DLANES_PER_COLUMN=" + std::to_string(lanes_per_column); + opts += " -DCOLS_PER_LANE_GROUP=" + std::to_string(cpl); + opts += " -DUSE_QCOM_SUBGROUP_SHUFFLE=" + std::to_string(use_qcom_shuffle); + + // Since spw=1 is found to be optimal, SUBGROUPS_PER_WG > 1 code in + // the kernel is removed. If you want to experiment with spw > 1, + // Please remember to implement code to handle it. + opts += " -DSUBGROUPS_PER_WG=" + std::to_string(spw); + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), opts); + + CL_CHECK((backend_ctx->kernel_gated_delta_net_f32[si][kda][tgpp] = + clCreateKernel(prog, "kernel_gated_delta_net", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + } + } + } + GGML_LOG_CONT("."); + } + // mul_mv_id_q4_0_f32_8x_flat { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2827,7 +2912,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "gemm_noshuffle_q4_1_f32.cl.h" - }; + }; #else const std::string kernel_src = read_file("gemm_noshuffle_q4_1_f32.cl"); #endif @@ -2866,7 +2951,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "gemm_noshuffle_iq4_nl_f32.cl.h" - }; + }; #else const std::string kernel_src = read_file("gemm_noshuffle_iq4_nl_f32.cl"); #endif @@ -2905,7 +2990,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "gemm_noshuffle_q8_0_f32.cl.h" - }; + }; #else const std::string kernel_src = read_file("gemm_noshuffle_q8_0_f32.cl"); #endif @@ -2946,7 +3031,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { #include "gemm_noshuffle_q4_k_f32.cl.h" - }; + }; #else const std::string kernel_src = read_file("gemm_noshuffle_q4_k_f32.cl"); #endif @@ -3781,6 +3866,16 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated + // check support for qcom_subgroup_shuffle + if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") != NULL) { + GGML_LOG_INFO("ggml_opencl: cl_khr_subgroups support: true\n"); + if (strstr(ext_buffer, "cl_qcom_subgroup_shuffle") != NULL) { + backend_ctx->has_qcom_subgroup_shuffle = true; + } + } + GGML_LOG_INFO("ggml_opencl: cl_qcom_subgroup_shuffle support: %s\n", + backend_ctx->has_qcom_subgroup_shuffle ? "true" : "false"); + // Check if ext_buffer contains cl_khr_fp16 backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); @@ -4832,17 +4927,17 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_RELU: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_UNARY_OP_SIGMOID: return ggml_is_contiguous(op->src[0]); case GGML_UNARY_OP_TANH: case GGML_UNARY_OP_NEG: case GGML_UNARY_OP_EXP: - return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_UNARY_OP_EXPM1: - return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_UNARY_OP_SOFTPLUS: - return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; default: return false; } @@ -4891,6 +4986,15 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); case GGML_OP_SSM_CONV: return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); + case GGML_OP_GATED_DELTA_NET: + { + // Match the Vulkan backend: only F32 -> F32, S_v in {16, 32, 64, 128}. + if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) { + return false; + } + const int64_t S_v = op->src[2]->ne[0]; + return S_v == 16 || S_v == 32 || S_v == 64 || S_v == 128; + } case GGML_OP_CONCAT: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_TIMESTEP_EMBEDDING: @@ -10555,7 +10659,7 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t size_t local_work_size[] = { lws0, 1, 1 }; size_t * local_work_size_ptr = local_work_size; - if (d_ne0 % lws0 != 0 && !backend_ctx->non_uniform_workgroups) { + if (d_ne0 % lws0 != 0 && !backend_ctx->non_uniform_workgroups) { local_work_size_ptr = nullptr; } @@ -17052,6 +17156,185 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_cl_gated_delta_net(ggml_backend_t backend, ggml_tensor * dst) { + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const ggml_tensor * src_q = dst->src[0]; + const ggml_tensor * src_k = dst->src[1]; + const ggml_tensor * src_v = dst->src[2]; + const ggml_tensor * src_g = dst->src[3]; + const ggml_tensor * src_beta = dst->src[4]; + const ggml_tensor * src_state = dst->src[5]; + + GGML_ASSERT(src_q && src_q->extra); + GGML_ASSERT(src_k && src_k->extra); + GGML_ASSERT(src_v && src_v->extra); + GGML_ASSERT(src_g && src_g->extra); + GGML_ASSERT(src_beta && src_beta->extra); + GGML_ASSERT(src_state && src_state->extra); + + ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *) backend->context; + + const cl_uint S_v = (cl_uint) src_v->ne[0]; + const cl_uint H_v = (cl_uint) src_v->ne[1]; + const cl_uint n_tokens = (cl_uint) src_v->ne[2]; + const cl_uint n_seqs = (cl_uint) src_v->ne[3]; + const cl_uint K = (cl_uint) src_state->ne[1]; + + int si; + switch (S_v) { + case 16: si = 0; break; + case 32: si = 1; break; + case 64: si = 2; break; + case 128: si = 3; break; + default: + GGML_ASSERT(false && "ggml_cl_gated_delta_net: unsupported S_v"); + } + + const int kda = (src_g->ne[0] == (int64_t) S_v) ? 1 : 0; + + // TODO: Optimize when S_v!=128. Not necessary for now as Qwen3.5/6 are all S_v=128 + // token generation mode (tgpp=0): + // process 1 token at a time, so columns per lane (cpl) == 1 + // prompt processing mode (tgpp=1): + // cpl=4 to process 4 tokens for single-token. 4 is chosen for Adreno 750 as per + // work-item/thread has at most 128 registers. + // All Qwen3.5/6 models are S_v == 128, so LANES_PER_COLUMN == 8 + // such that ROWS_PER_LANE = 128/8 = 16 + // Variables in the kernel: + // k_reg, q_reg, g_exp are all 16 floats + // s_shard has cpl*ROWS_PER_LANE = 4*16 = 64 floats + // Total 112 registers used. + // subgroups_per_workgroup (spw) can be set to 1,2,4,8,16 for tg and 1,2,4 for pp + // for S_v=128. + // Empirically found that when spw=1, we get the best performance for both tg and pp + const int tgpp = (n_tokens == 1) ? 0 : 1; + const int cpl = (tgpp == 0) ? 1 : 4; + // spw needs adjustment when S_v != 128 + const int spw = (tgpp == 0) ? 1 : 1; + + cl_kernel kernel = backend_ctx->kernel_gated_delta_net_f32[si][kda][tgpp]; + GGML_ASSERT(kernel != nullptr); + + const cl_uint s_off = S_v * H_v * n_tokens * n_seqs; + + const cl_uint sq1 = (cl_uint)(src_q->nb[1] / sizeof(float)); + const cl_uint sq2 = (cl_uint)(src_q->nb[2] / sizeof(float)); + const cl_uint sq3 = (cl_uint)(src_q->nb[3] / sizeof(float)); + const cl_uint sv1 = (cl_uint)(src_v->nb[1] / sizeof(float)); + const cl_uint sv2 = (cl_uint)(src_v->nb[2] / sizeof(float)); + const cl_uint sv3 = (cl_uint)(src_v->nb[3] / sizeof(float)); + const cl_uint sb1 = (cl_uint)(src_beta->nb[1] / sizeof(float)); + const cl_uint sb2 = (cl_uint)(src_beta->nb[2] / sizeof(float)); + const cl_uint sb3 = (cl_uint)(src_beta->nb[3] / sizeof(float)); + + const cl_uint H_k = (cl_uint) src_q->ne[1]; + const cl_uint rq3 = (cl_uint)(src_v->ne[3] / src_q->ne[3]); + + const float scale = 1.0f / sqrtf((float) S_v); + + ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *) src_q->extra; + ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *) src_k->extra; + ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *) src_v->extra; + ggml_tensor_extra_cl * extra_g = (ggml_tensor_extra_cl *) src_g->extra; + ggml_tensor_extra_cl * extra_beta = (ggml_tensor_extra_cl *) src_beta->extra; + ggml_tensor_extra_cl * extra_state = (ggml_tensor_extra_cl *) src_state->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *) dst->extra; + + const cl_ulong off_q = extra_q->offset + src_q->view_offs; + const cl_ulong off_k = extra_k->offset + src_k->view_offs; + const cl_ulong off_v = extra_v->offset + src_v->view_offs; + const cl_ulong off_g = extra_g->offset + src_g->view_offs; + const cl_ulong off_beta = extra_beta->offset + src_beta->view_offs; + const cl_ulong off_state = extra_state->offset + src_state->view_offs; + const cl_ulong off_dst = extra_dst->offset + dst->view_offs; + + int idx = 0; + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_q->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_q)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_k->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_k)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_v->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_v)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_g->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_g)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_beta->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_beta)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_state->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_state)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H_v)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &n_tokens)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &n_seqs)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s_off)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sq1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sq2)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sq3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sv1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sv2)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sv3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sb1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sb2)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sb3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H_k)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &rq3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &K)); + + // Subgroup size is 64 for Adreno and 32 for Intel + const int sg_size = backend_ctx->gpu_family == GPU_FAMILY::ADRENO ? 64 : backend_ctx->gpu_family == GPU_FAMILY::INTEL ? 32 : -1; + if (sg_size < 0) { + GGML_LOG_ERROR("Unsupported GPU Family: only Adreno and Intel are supported.\n"); + exit(1); + } + + // For the subgroup-shuffle kernel, we can safely prefer 8 lanes/column for S_v>=128 + // For the subgroup-shuffle kernel: + // S_v >= 128 -> prefer 8 lanes/column (good occupancy & register pressure tradeoff) + // else -> min(S_v, subgroup_size) + int lanes_per_column; + if ((int)S_v >= 128) { + lanes_per_column = 8; + } else { + lanes_per_column = std::min((int)S_v, sg_size); + } + + // Max workgroup size for Adreno 750 is 1024 + const int wg_size = sg_size * spw; + + // Ensure lanes_per_column is a power-of-two and divides both S_v and subgroup_size. + // (Required for lane-group shuffle-xor reduction correctness.) + while (lanes_per_column > 1 && + (((lanes_per_column & (lanes_per_column - 1)) != 0) || + (((int)S_v % lanes_per_column) != 0) || + (sg_size % lanes_per_column) != 0)) { + lanes_per_column >>= 1; + } + GGML_ASSERT(lanes_per_column >= 1); + GGML_ASSERT(((lanes_per_column & (lanes_per_column - 1)) == 0)); + GGML_ASSERT(((int)S_v % lanes_per_column) == 0); + GGML_ASSERT((sg_size % lanes_per_column) == 0); + + const int cols_per_wg = spw * (sg_size / lanes_per_column) * cpl; + GGML_ASSERT(cols_per_wg > 0); + GGML_ASSERT(((int)S_v % cols_per_wg) == 0); + + size_t global_work_size[3]; + size_t local_work_size[3]; + + global_work_size[0] = (size_t) H_v * (size_t) wg_size; + global_work_size[1] = (size_t) n_seqs; + global_work_size[2] = (size_t) S_v / (size_t) cols_per_wg; + + local_work_size[0] = (size_t) wg_size; + local_work_size[1] = 1; + local_work_size[2] = 1; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + //------------------------------------------------------------------------------ // Op offloading //------------------------------------------------------------------------------ @@ -17267,8 +17550,8 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_group_norm; break; - case GGML_OP_REPEAT: - if (!any_on_device) { + case GGML_OP_REPEAT: + if (!any_on_device) { return false; } func = ggml_cl_repeat; @@ -17297,6 +17580,14 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_ssm_conv; break; + case GGML_OP_GATED_DELTA_NET: + if (!any_on_device) { + return false; + } + // GDN has 6 source tensors, so it cannot use the standard + // (src0, src1, dst) func signature. Dispatch directly and return. + ggml_cl_gated_delta_net(backend, tensor); + return true; case GGML_OP_CONCAT: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/gated_delta_net.cl b/ggml/src/ggml-opencl/kernels/gated_delta_net.cl new file mode 100644 index 00000000000..d11192f5802 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gated_delta_net.cl @@ -0,0 +1,247 @@ +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifndef S_V +#define S_V 128 +#endif +#ifndef KDA +#define KDA 0 +#endif +#ifndef SUBGROUP_SIZE +#define SUBGROUP_SIZE 64 +#endif +#ifndef LANES_PER_COLUMN +#define LANES_PER_COLUMN 8 +#endif +#ifndef COLS_PER_LANE_GROUP +#define COLS_PER_LANE_GROUP 1 +#endif +#ifndef SUBGROUPS_PER_WG +#define SUBGROUPS_PER_WG 1 +#endif +#ifndef USE_QCOM_SUBGROUP_SHUFFLE +#define USE_QCOM_SUBGROUP_SHUFFLE 0 +#endif + +#define WG_SIZE (SUBGROUP_SIZE * SUBGROUPS_PER_WG) +#define LANE_GROUPS_PER_SG (SUBGROUP_SIZE / LANES_PER_COLUMN) +#define COLS_PER_SG (LANE_GROUPS_PER_SG * COLS_PER_LANE_GROUP) +#define COLS_PER_WG (SUBGROUPS_PER_WG * COLS_PER_SG) +#define ROWS_PER_LANE (S_V / LANES_PER_COLUMN) + +#if USE_QCOM_SUBGROUP_SHUFFLE +#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable +#endif + +// XOR-based parallel sum +// This does a reduction across groups of LANES_PER_COLUMN +static inline float reduce_add_shmem(float partial, __local float * temp, uint lane) { +#if USE_QCOM_SUBGROUP_SHUFFLE + #pragma unroll + for (uint s = LANES_PER_COLUMN / 2u; s > 0u; s >>= 1u) { + partial += qcom_sub_group_shuffle_xor(partial, s, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, partial); + } + return partial; +#else + temp[lane] = partial; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (uint s = LANES_PER_COLUMN / 2u; s > 0u; s >>= 1u) { + float other = temp[lane ^ s]; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + temp[lane] += other; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + } + const float result = temp[lane]; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + return result; +#endif +} + +#define REDUCE_PARTIAL(partial, temp_ptr, lid) \ + ((LANES_PER_COLUMN == 1u) ? (partial) : reduce_add_shmem((partial), (temp_ptr), (lid))) + +// force compiler to optimize kernel for a specific fixed work-group size +__attribute__((reqd_work_group_size(WG_SIZE, 1, 1))) +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gated_delta_net( + global const char * q_buf, ulong off_q, + global const char * k_buf, ulong off_k, + global const char * v_buf, ulong off_v, + global const char * g_buf, ulong off_g, + global const char * beta_buf, ulong off_beta, + global const char * state_buf, ulong off_state, + global char * dst_buf, ulong off_dst, + uint H_v, + uint n_tokens, + uint n_seqs, + uint s_off, + uint sq1, uint sq2, uint sq3, + uint sv1, uint sv2, uint sv3, + uint sb1, uint sb2, uint sb3, + uint H_k, + uint rq3, + float scale, + uint K) { + + global const float * data_q = (global const float *)(q_buf + off_q); + global const float * data_k = (global const float *)(k_buf + off_k); + global const float * data_v = (global const float *)(v_buf + off_v); + global const float * data_g = (global const float *)(g_buf + off_g); + global const float * data_beta = (global const float *)(beta_buf + off_beta); + global const float * data_state = (global const float *)(state_buf + off_state); + global float * data_dst = (global float *)(dst_buf + off_dst); + + const uint head_id = get_group_id(0); + const uint seq_id = get_group_id(1); + const uint tid = (uint)get_local_id(0); + + const uint sg_id = get_sub_group_id(); // subgroup id + const uint sg_lid = get_sub_group_local_id(); // subgroup lane id + + const uint lane = sg_lid % LANES_PER_COLUMN; + const uint lane_group = sg_lid / LANES_PER_COLUMN; + const uint wg_col_base = get_group_id(2) * COLS_PER_WG; + const uint sg_col_base = wg_col_base + sg_id * COLS_PER_SG; + + const uint iq1 = head_id % H_k; // head index for Q and K + const uint iq3 = seq_id / rq3; // seq index for Q and K + + const uint state_size = S_V * S_V; + const uint state_base = (seq_id * K * H_v + head_id) * state_size; + const uint q_off_base = iq3 * sq3 + iq1 * sq1; + const uint v_off_base = seq_id * sv3 + head_id * sv1; + const uint gb_off_base = seq_id * sb3 + head_id * sb1; + const uint state_out_base = (seq_id * H_v + head_id) * state_size; + const uint state_size_per_snap = state_size * H_v * n_seqs; + + __local float reduce_temp[WG_SIZE]; + __local float * temp_ptr = reduce_temp + sg_id * SUBGROUP_SIZE; + + float s_shard[COLS_PER_LANE_GROUP][ROWS_PER_LANE]; + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + s_shard[cg][r] = data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]; + } + } + + const int shift = (int)n_tokens - (int)K; + uint attn_off = (seq_id * n_tokens * H_v + head_id) * S_V; + + for (uint t = 0; t < n_tokens; t++) { + const uint q_off = q_off_base + t * sq2; + const uint k_off = q_off; + const uint v_off = v_off_base + t * sv2; + const uint gb_off = gb_off_base + t * sb2; + const float beta_val = data_beta[gb_off]; + + float k_reg[ROWS_PER_LANE]; + float q_reg[ROWS_PER_LANE]; +#if KDA + float g_exp[ROWS_PER_LANE]; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + k_reg[r] = data_k[k_off + i]; + q_reg[r] = data_q[q_off + i]; + g_exp[r] = exp(data_g[gb_off * S_V + i]); + } +#else + const float g_val = exp(data_g[gb_off]); + + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + k_reg[r] = data_k[k_off + i]; + q_reg[r] = data_q[q_off + i]; + } +#endif + + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + float v_val = data_v[v_off + col]; + + float kv_shard = 0.0f; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { +#if KDA + float gs = g_exp[r] * s_shard[cg][r]; + kv_shard += gs * k_reg[r]; +#else + kv_shard += s_shard[cg][r] * k_reg[r]; +#endif + } + +#if !KDA + kv_shard *= g_val; // Applied once instead of ROWS_PER_LANE times +#endif + + const float kv_col = REDUCE_PARTIAL(kv_shard, temp_ptr, sg_lid); + + const float delta_col = (v_val - kv_col) * beta_val; + + float attn_partial = 0.0f; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { +#if KDA + float gs = g_exp[r] * s_shard[cg][r]; +#else + float gs = g_val * s_shard[cg][r]; +#endif + s_shard[cg][r] = gs + k_reg[r] * delta_col; + attn_partial += s_shard[cg][r] * q_reg[r]; + } + const float attn_col = REDUCE_PARTIAL(attn_partial, temp_ptr, sg_lid); + + if (lane == 0) { + data_dst[attn_off + col] = attn_col * scale; + } + } + attn_off += S_V * H_v; + + if (K > 1u) { + const int target_slot = (int)t - shift; + if (target_slot >= 0 && target_slot < (int)K) { + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + const uint slot_base = s_off + (uint)target_slot * state_size_per_snap + state_out_base; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[cg][r]; + } + } + } + } + } + + if (K == 1u) { + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[cg][r]; + } + } + } +} From d284e1c3aa307e56cc43a81152c1cebea46e29cb Mon Sep 17 00:00:00 2001 From: ymcki <84055651+ymcki@users.noreply.github.com> Date: Thu, 28 May 2026 14:05:25 +0800 Subject: [PATCH 718/831] Hexagon: OP_GATED_DELTA_NET K>1 support (llama/23531) * K>1 state snapshot support * removed picky indent multiple of 4 fixes --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 5 +-- .../ggml-hexagon/htp/gated-delta-net-ops.c | 32 +++++++++++++++---- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 5e8a4a740c1..3af7aff7028 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2537,6 +2537,7 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses const int64_t H = v->ne[1]; const int64_t n_tokens = v->ne[2]; const int64_t n_seqs = v->ne[3]; + const int64_t K = state->ne[1]; if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) { return false; @@ -2549,10 +2550,10 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) { return false; } - if (ggml_nelements(state) != S_v * S_v * H * n_seqs) { + if (ggml_nelements(state) != S_v * S_v * H * n_seqs * K) { return false; } - if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) { + if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { return false; } diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c index 2e84badc9b7..c4d08bb21c4 100644 --- a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c @@ -586,6 +586,7 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t H = v->ne[1]; const uint32_t n_tokens = v->ne[2]; const uint32_t n_seqs = v->ne[3]; + const uint32_t K = state->ne[1]; const uint32_t total_rows = H * n_seqs; if (ith >= total_rows) { @@ -606,6 +607,10 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_sums[4] __attribute__((aligned(128))); + const uint64_t state_seq_stride = state->nb[2] / sizeof(float); + const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; + const int64_t shift = (int64_t) n_tokens - (int64_t) K; + for (uint32_t ir = ith; ir < total_rows; ir += nth) { const uint32_t iv1 = ir % H; const uint32_t iv3 = ir / H; @@ -615,8 +620,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t iq3 = iv3 / rq3; const uint32_t ik3 = iv3 / rk3; - float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v; memcpy(s_out, s_in, gctx->state_bytes); float * s_work = s_out; @@ -689,6 +694,16 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo } } + if (K > 1) { + const int64_t target_slot = (int64_t) t - shift; + if (target_slot >= 0 && target_slot < (int64_t) K) { + float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + if (curr_state_o != s_work) { + memcpy(curr_state_o, s_work, gctx->state_bytes); + } + } + } + attn_data += (uint64_t) S_v * H; } } @@ -709,6 +724,7 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t S_v = v->ne[0]; const uint32_t H = v->ne[1]; const uint32_t n_seqs = v->ne[3]; + const uint32_t K = state->ne[1]; const uint32_t total_rows = H * n_seqs; if (ith >= total_rows) { @@ -736,6 +752,9 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith; } + const uint64_t state_seq_stride = state->nb[2] / sizeof(float); + const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; + for (uint32_t ir = ith; ir < total_rows; ir += nth) { const uint32_t iv1 = ir % H; const uint32_t iv3 = ir / H; @@ -745,8 +764,8 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t iq3 = iv3 / rq3; const uint32_t ik3 = iv3 / rk3; - float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v; float * s_work; if (spad) { @@ -901,6 +920,7 @@ int op_gated_delta_net(struct htp_ops_context * octx) { const uint32_t H = v->ne[1]; const uint32_t n_tokens = v->ne[2]; const uint32_t n_seqs = v->ne[3]; + const uint32_t K = state->ne[1]; if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) { return HTP_STATUS_NO_SUPPORT; @@ -913,10 +933,10 @@ int op_gated_delta_net(struct htp_ops_context * octx) { (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) { return HTP_STATUS_NO_SUPPORT; } - if (state->ne[0] * state->ne[1] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) { + if (state->ne[0] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) { return HTP_STATUS_NO_SUPPORT; } - if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs) { + if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { return HTP_STATUS_NO_SUPPORT; } From 8e403258767fa1e0946006cee470064f0375f0ba Mon Sep 17 00:00:00 2001 From: Martin Klacer Date: Thu, 28 May 2026 08:04:21 +0100 Subject: [PATCH 719/831] ggml: fixed Arm SVE usage bug in vec.h, vec.cpp (llama/22841) * Updated vec.h/vec.cpp code to accumulate to F32 rather than F16 Change-Id: I0cb789347f2bf60ffaf9047319f727e788c825f8 Signed-off-by: Martin Klacer Co-authored-by: Milos Puzovic --- ggml/src/ggml-cpu/vec.cpp | 90 +++++++++------------- ggml/src/ggml-cpu/vec.h | 158 +++++++++++++++++--------------------- 2 files changed, 107 insertions(+), 141 deletions(-) diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index d0e4001338a..67b6b05cac8 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -273,67 +273,51 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G #if defined(GGML_SIMD) #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; //get vector length - const int ggml_f16_epr = sve_register_length / 16; // running when 16 - const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers - - const int np= (n & ~(ggml_f16_step - 1)); - svfloat16_t sum1 = svdup_n_f16(0.0f); - svfloat16_t sum2 = svdup_n_f16(0.0f); - svfloat16_t sum3 = svdup_n_f16(0.0f); - svfloat16_t sum4 = svdup_n_f16(0.0f); - - svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; - svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; - for (int i = 0; i < np; i += ggml_f16_step) { - ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); - sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1); - - ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); - sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2); - - ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); - ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); - sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3); - - ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); - ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); - sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4); - - ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); - ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); - sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5); + const int ggml_f16_epr = svcnth(); + const int ggml_f16_step = 8 * ggml_f16_epr; + const int np = n - (n % ggml_f16_step); + const int np2 = n - (n % ggml_f16_epr); + + svfloat32_t sum1_lo = svdup_n_f32(0.0f); + svfloat32_t sum1_hi = svdup_n_f32(0.0f); + svfloat32_t sum2_lo = svdup_n_f32(0.0f); + svfloat32_t sum2_hi = svdup_n_f32(0.0f); + svfloat32_t sum3_lo = svdup_n_f32(0.0f); + svfloat32_t sum3_hi = svdup_n_f32(0.0f); + svfloat32_t sum4_lo = svdup_n_f32(0.0f); + svfloat32_t sum4_hi = svdup_n_f32(0.0f); - ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); - ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); - sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6); - - ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); - ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); - sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7); - - ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); - ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); - sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8); + for (int i = 0; i < np; i += ggml_f16_step) { + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0), GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0)); + ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1), GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1)); + ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2), GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2)); + ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3), GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3)); + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4), GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4)); + ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5), GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5)); + ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6), GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6)); + ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7), GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7)); } - const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8 - for (int k = np; k < np2; k += ggml_f16_epr) { - svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); - svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); - sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry); + for (int i = np; i < np2; i += ggml_f16_epr) { + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i, 0), GGML_F16x_VEC_LOAD(y + i, 0)); } if (np2 < n) { - svbool_t pg = svwhilelt_b16(np2, n); - svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); - svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svbool_t pg = svwhilelt_b16(np2, n); + const svfloat16_t rx = svld1_f16(pg, (const __fp16 *)(x + np2)); + const svfloat16_t ry = svld1_f16(pg, (const __fp16 *)(y + np2)); - sum1 = svmad_f16_x(pg, hx, hy, sum1); + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, rx, ry); } - GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4); + + sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum2_lo); + sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum2_hi); + sum3_lo = svadd_f32_m(DEFAULT_PG32, sum3_lo, sum4_lo); + sum3_hi = svadd_f32_m(DEFAULT_PG32, sum3_hi, sum4_hi); + sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum3_lo); + sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum3_hi); + + sumf = ggml_sve_sum_f32x2(sum1_lo, sum1_hi); #elif defined(__riscv_v_intrinsic) #if defined(__riscv_zvfh) int vl = __riscv_vsetvlmax_e32m2(); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index bcd68da9aa9..5de9cb5b7e0 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -14,6 +14,35 @@ // floating point type used to accumulate sums typedef double ggml_float; +#if defined(__ARM_FEATURE_SVE) +inline static void ggml_sve_f16_fma_widened( + svfloat32_t * acc_lo, + svfloat32_t * acc_hi, + svfloat16_t x, + svfloat16_t y) { +#if defined(__ARM_FEATURE_SVE2) + *acc_lo = svmlalb_f32(*acc_lo, x, y); + *acc_hi = svmlalt_f32(*acc_hi, x, y); +#else + // Plain SVE fallback path if SVE2 instructions not available + svfloat16_t x_even = svtrn1_f16(x, x); + svfloat16_t x_odd = svtrn2_f16(x, x); + + svfloat16_t y_even = svtrn1_f16(y, y); + svfloat16_t y_odd = svtrn2_f16(y, y); + + svbool_t pg = svptrue_b32(); + + *acc_lo = svmla_f32_x(pg, *acc_lo, svcvt_f32_f16_x(pg, x_even), svcvt_f32_f16_x(pg, y_even)); + *acc_hi = svmla_f32_x(pg, *acc_hi, svcvt_f32_f16_x(pg, x_odd), svcvt_f32_f16_x(pg, y_odd)); +#endif +} + +inline static ggml_float ggml_sve_sum_f32x2(svfloat32_t sum_lo, svfloat32_t sum_hi) { + return (ggml_float) (svaddv_f32(svptrue_b32(), sum_lo) + svaddv_f32(svptrue_b32(), sum_hi)); +} +#endif + #define GGML_GELU_FP16 #define GGML_GELU_QUICK_FP16 @@ -122,108 +151,61 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG #if defined(GGML_SIMD) #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; - const int ggml_f16_epr = sve_register_length / 16; // running when 16 - const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers - - int np = (n & ~(ggml_f16_step - 1)); - - svfloat16_t sum_00 = svdup_n_f16(0.0f); - svfloat16_t sum_01 = svdup_n_f16(0.0f); - svfloat16_t sum_02 = svdup_n_f16(0.0f); - svfloat16_t sum_03 = svdup_n_f16(0.0f); + const int ggml_f16_epr = svcnth(); + const int ggml_f16_step = 2 * ggml_f16_epr; + int np = n - (n % ggml_f16_step); + int np2 = n - (n % ggml_f16_epr); - svfloat16_t sum_10 = svdup_n_f16(0.0f); - svfloat16_t sum_11 = svdup_n_f16(0.0f); - svfloat16_t sum_12 = svdup_n_f16(0.0f); - svfloat16_t sum_13 = svdup_n_f16(0.0f); - - svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; - svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + svfloat32_t sum_0_0_lo = svdup_n_f32(0.0f); + svfloat32_t sum_0_0_hi = svdup_n_f32(0.0f); + svfloat32_t sum_0_1_lo = svdup_n_f32(0.0f); + svfloat32_t sum_0_1_hi = svdup_n_f32(0.0f); + svfloat32_t sum_1_0_lo = svdup_n_f32(0.0f); + svfloat32_t sum_1_0_hi = svdup_n_f32(0.0f); + svfloat32_t sum_1_1_lo = svdup_n_f32(0.0f); + svfloat32_t sum_1_1_hi = svdup_n_f32(0.0f); for (int i = 0; i < np; i += ggml_f16_step) { - ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements - - ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements - sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1 - ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements - sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1); - - ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements + const svfloat16_t ay0 = GGML_F16x_VEC_LOAD(y + i, 0); + const svfloat16_t ax00 = GGML_F16x_VEC_LOAD(x[0] + i, 0); + const svfloat16_t ax01 = GGML_F16x_VEC_LOAD(x[1] + i, 0); - ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements - sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2); - ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1); - sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax00, ay0); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax01, ay0); - ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); + const svfloat16_t ay1 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 0); + const svfloat16_t ax10 = GGML_F16x_VEC_LOAD(x[0] + i + 1 * ggml_f16_epr, 0); + const svfloat16_t ax11 = GGML_F16x_VEC_LOAD(x[1] + i + 1 * ggml_f16_epr, 0); - ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2); - sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3); - ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2); - sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3); - - ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); - - ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3); - sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4); - ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3); - sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4); - - ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); - - ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4); - - sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5); - ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4); - sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5); - - ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); - - ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5); - - sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6); - ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5); - sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6); - - ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); - - ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6); - - sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7); - ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6); - sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7); - - ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); - - ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7); - - sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8); - ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7); - sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8); + ggml_sve_f16_fma_widened(&sum_0_1_lo, &sum_0_1_hi, ax10, ay1); + ggml_sve_f16_fma_widened(&sum_1_1_lo, &sum_1_1_hi, ax11, ay1); } - const int np2 = (n & ~(ggml_f16_epr - 1)); - for (int k = np; k < np2; k += ggml_f16_epr) { - svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + for (int i = np; i < np2; i += ggml_f16_epr) { + const svfloat16_t ry = GGML_F16x_VEC_LOAD(y + i, 0); + const svfloat16_t rx0 = GGML_F16x_VEC_LOAD(x[0] + i, 0); + const svfloat16_t rx1 = GGML_F16x_VEC_LOAD(x[1] + i, 0); - svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0); - sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry); - rx = GGML_F16x_VEC_LOAD(x[1] + k, 0); - sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, rx0, ry); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, rx1, ry); } if (np2 < n) { - svbool_t pg = svwhilelt_b16(np2, n); - svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2)); - svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2)); - svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svbool_t pg = svwhilelt_b16(np2, n); + const svfloat16_t ay = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svfloat16_t ax0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2)); + const svfloat16_t ax1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2)); - sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00); - sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax0, ay); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax1, ay); } - GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03); - GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13); + + svfloat32_t sum_0_lo = svadd_f32_x(DEFAULT_PG32, sum_0_0_lo, sum_0_1_lo); + svfloat32_t sum_0_hi = svadd_f32_x(DEFAULT_PG32, sum_0_0_hi, sum_0_1_hi); + svfloat32_t sum_1_lo = svadd_f32_x(DEFAULT_PG32, sum_1_0_lo, sum_1_1_lo); + svfloat32_t sum_1_hi = svadd_f32_x(DEFAULT_PG32, sum_1_0_hi, sum_1_1_hi); + sumf[0] = ggml_sve_sum_f32x2(sum_0_lo, sum_0_hi); + sumf[1] = ggml_sve_sum_f32x2(sum_1_lo, sum_1_hi); np = n; #elif defined(__riscv_v_intrinsic) #if defined(__riscv_zvfh) From 60e420ff6ac28ae5bc5af42b4a77bc98dca760e6 Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Thu, 28 May 2026 10:55:42 +0200 Subject: [PATCH 720/831] cuda : fix KQ mask offset integer overflow in fattn MMA kernel (llama/23610) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Stanisław Szymczyk --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 4871b90df86..3c8b6eaaf24 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -472,7 +472,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const int i = 8 * (threadIdx.x % (nbatch_fa/8)); - cp_async_cg_16(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i); + cp_async_cg_16(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + int64_t(j_vram)*stride_mask + i); } } else if constexpr (oob_check) { #pragma unroll @@ -488,7 +488,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) { const int i = i0 + threadIdx.x; - tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f); + tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[int64_t(j_vram)*stride_mask + i] : half(0.0f); } } } else if constexpr (nbatch_fa < 2*warp_size) { @@ -505,7 +505,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const int i = threadIdx.x % (warp_size/cols_per_warp); - ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i); + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + int64_t(j_vram)*stride_mask + 2*i); } } else { #pragma unroll @@ -521,7 +521,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) { const int i = i0 + 2*threadIdx.x; - ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i); + ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + int64_t(j_vram)*stride_mask + i); } } } From 5db94bac041884f804c49ae98a57a738f83b9c0c Mon Sep 17 00:00:00 2001 From: Winston Ma Date: Thu, 28 May 2026 18:46:07 +0800 Subject: [PATCH 721/831] vulkan: Fix memory logger unsafe iterator access (llama/23667) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index fb07282ef76..b8ac4a9c26c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -2095,9 +2095,9 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); std::string type = device ? "device" : "host"; auto it = allocations.find(buf->buffer); - total_device -= device ? it->second : 0; - total_host -= device ? 0 : it->second; if (it != allocations.end()) { + total_device -= device ? it->second : 0; + total_host -= device ? 0 : it->second; VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); allocations.erase(it); } else { From 816c3029bc8cd046d4e2726ba2f1bbf99b8adc8f Mon Sep 17 00:00:00 2001 From: Winston Ma Date: Thu, 28 May 2026 18:48:34 +0800 Subject: [PATCH 722/831] vulkan: fix wrong index variable in inner loop (llama/23665) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index b8ac4a9c26c..238ee822397 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7233,7 +7233,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1; for (uint64_t i0 = 0; i0 < ne0; i0++) { - slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 }); + slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 }); } } } From b896e91f18ec245f1415fe5d18a77e766197985e Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Thu, 28 May 2026 06:18:43 -0500 Subject: [PATCH 723/831] vulkan: fast path for walsh-hadamard transform (llama/23687) * vulkan: fast path for walsh-hadamard transform * disable for intel due to segfault --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 91 +++++++++++++++++++ ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp | 69 ++++++++++++++ .../vulkan-shaders/vulkan-shaders-gen.cpp | 1 + 3 files changed, 161 insertions(+) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 238ee822397..c9f906d7930 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -860,6 +860,7 @@ struct vk_device_struct { vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; vk_pipeline pipeline_topk_f32[num_topk_pipelines]; vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_fwht_f32[4]; vk_pipeline pipeline_cumsum_f32; vk_pipeline pipeline_cumsum_small_f32; vk_pipeline pipeline_cumsum_multipass1_f32; @@ -1150,6 +1151,13 @@ struct vk_op_push_constants { float param4; }; +struct vk_op_fwht_push_constants { + uint32_t n_rows; + uint32_t src_offset; + uint32_t dst_offset; + float scale; +}; + struct vk_op_count_experts_push_constants { uint32_t ne00; uint32_t ne01; @@ -2055,6 +2063,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src3); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_fwht_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + p.src_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + p.dst_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(src3); +} + struct ggml_backend_vk_buffer_context { vk_device_ref device; vk_buffer dev_buffer; @@ -4982,6 +4999,16 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + // Intel Arc B390 was observed segfaulting with this shader. + if (device->subgroup_basic && device->subgroup_shuffle && device->vendor_id != VK_VENDOR_ID_INTEL) { + int idx = 0; + for (uint32_t n : {64, 128, 256, 512}) { + if (device->subgroup_size <= n) { + ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_f32", fwht_f32_len, fwht_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { device->subgroup_size, n }, 1, true, true, device->subgroup_size); + } + ++idx; + } + } const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4; ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size); @@ -8741,6 +8768,68 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 }); } +static int ggml_vk_fwht_pipeline_idx(int64_t n) { + switch (n) { + case 64: return 0; + case 128: return 1; + case 256: return 2; + case 512: return 3; + default: return -1; + } +} + +static bool ggml_vk_can_use_fwht(const ggml_backend_vk_context * ctx, const ggml_tensor * src1, const ggml_tensor * dst) { + if (ctx->num_additional_fused_ops != 0) { + return false; + } + + if (ggml_get_op_params_i32(dst, 1) != GGML_HINT_SRC0_IS_HADAMARD) { + return false; + } + + const int idx = ggml_vk_fwht_pipeline_idx(src1->ne[0]); + if (idx < 0 || ctx->device->pipeline_fwht_f32[idx] == nullptr) { + return false; + } + + if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + if (!ggml_is_contiguous(src1)) { + return false; + } + GGML_ASSERT(ggml_is_contiguous(dst)); + + return true; +} + +static void ggml_vk_fwht(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src, ggml_tensor * dst) { + const int idx = ggml_vk_fwht_pipeline_idx(src->ne[0]); + vk_pipeline pipeline = ctx->device->pipeline_fwht_f32[idx]; + + const uint32_t rows_per_workgroup = 4; + const uint32_t n_rows = (uint32_t)ggml_nrows(src); + const uint32_t max_workgroups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + const uint32_t total_workgroups = CEIL_DIV(n_rows, rows_per_workgroup); + const uint32_t workgroups_x = std::min(total_workgroups, max_workgroups_x); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + const vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src, true); + const vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true); + + vk_op_fwht_push_constants pc = { + n_rows, + 0, + 0, + 1.0f / std::sqrt((float)src->ne[0]), + }; + init_pushconst_tensor_offsets(ctx, pc, src, nullptr, nullptr, nullptr, dst); + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc, { workgroups_x, 1, 1 }); +} + static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { ggml_tensor * dst = cgraph->nodes[node_idx]; ggml_tensor * src0 = dst->src[0]; @@ -8774,6 +8863,8 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c m_offset += cur_M_size; } + } else if (ggml_vk_can_use_fwht(ctx, src1, dst)) { + ggml_vk_fwht(ctx, subctx, src1, dst); } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && // detect 0213 permutation, and batch size of 1 src0->nb[0] <= src0->nb[2] && diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp new file mode 100644 index 00000000000..72059d4afc2 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp @@ -0,0 +1,69 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_shuffle : enable + +layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; + +layout(constant_id = 0) const uint WARP_SIZE = 32; +layout(constant_id = 1) const uint N = 128; + +layout(push_constant) uniform parameter +{ + uint n_rows; + uint src_offset; + uint dst_offset; + float scale; +}; + +layout(binding = 0, std430) readonly buffer A { float data_a[]; }; +layout(binding = 1, std430) writeonly buffer D { float data_d[]; }; + +const uint EL_W = N / WARP_SIZE; + +void main() { + const uint lane = gl_SubgroupInvocationID; + for (uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID; + row < n_rows; + row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) { + const uint row_offset = row * N; + + float reg[EL_W]; + + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + reg[i] = data_a[src_offset + row_offset + i * WARP_SIZE + lane] * scale; + } + + [[unroll]] + for (uint h = 1; h < WARP_SIZE; h <<= 1) { + [[unroll]] + for (uint j = 0; j < EL_W; ++j) { + const float val = reg[j]; + const float val2 = subgroupShuffleXor(val, h); + reg[j] = (lane & h) == 0 ? val + val2 : val2 - val; + } + } + + [[unroll]] + for (uint h = WARP_SIZE; h < N; h <<= 1) { + const uint step = h / WARP_SIZE; + [[unroll]] + for (uint j = 0; j < EL_W; j += 2 * step) { + [[unroll]] + for (uint k = 0; k < step; ++k) { + const float x = reg[j + k]; + const float y = reg[j + k + step]; + reg[j + k] = x + y; + reg[j + k + step] = x - y; + } + } + } + + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + data_d[dst_offset + row_offset + i * WARP_SIZE + lane] = reg[i]; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 24b9d25f733..fa9b938e4f7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -934,6 +934,7 @@ void process_shaders() { string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("fwht_f32", "fwht.comp", {}); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); From 1b241b879c4687d7ff4b3af1a14cb8e491a70d2d Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Thu, 28 May 2026 04:49:11 -0700 Subject: [PATCH 724/831] hexagon: minor refresh for HMX FA and MM (llama/23796) * hex-fa: clean up qf32/fp32 handling and stride handling * hex-fa: fix corner case fp NAN issues that were cause bad output from gemma4 on v79 * hex-fa: vectorize leftover handling * hex-fa: avoid HVX fallback during token gen HMX has more FP16 compute capacity * hmx-mm: remove dead code * hmx-mm: use fastdiv in x4x2 dequant * hmx-mm: sandwich dequant and scatter to improve perf * hmx-mm: fixed rebase conflicts * hmx-mm: further improve weight dequant by doing early type dispatch and precomputing fastdiv * hmx-mm: an even earlier dispatch for per-type dequant * hmx-mm: dequant linear types like q4_0 and q4_1 without the LUTs This is a bit faster than LUT. * hex-cmake: one more tweak for lto --------- Co-authored-by: Trivikram Reddy --- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 3 +- ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 157 +++--- .../src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 3 - ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 493 ++++++++++-------- 4 files changed, 370 insertions(+), 286 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index d7927261a85..ff3fc0804e3 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -58,15 +58,16 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) if (_hmx_idx GREATER_EQUAL 0) target_sources(${HTP_LIB} PRIVATE - hmx-queue.c hmx-flash-attn-ops.c hmx-matmul-ops.c + hmx-queue.c ) # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) set_source_files_properties( hmx-flash-attn-ops.c hmx-matmul-ops.c + hmx-queue.c PROPERTIES COMPILE_OPTIONS "-mhmx" ) diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index d95df6ac9d5..1bd8c1407de 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -22,6 +22,16 @@ // Must be multiple of 32 #define FLASH_ATTN_BLOCK_SIZE (32 * 2) +#if __HVX_ARCH__ < 79 +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + // This is a bit of a hack because the compiler is strugling to properly inline // the default hvx_vec_f32_to_f16 with output into the local array. static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) @@ -54,8 +64,8 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf); } - HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p))); - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum))); + HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)); + rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); hvx_vec_store_u(r, 4, rsum); } @@ -105,10 +115,10 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y, rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf); } - HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p))); - HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p))); - HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p))); - HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p))); + HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)); + HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)); + HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p)); + HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p)); HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } }; return hvx_vec_reduce_sum_f32x4(rsum0123); @@ -123,7 +133,7 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y, const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors const size_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector sums; // initialize at j = 0 + HVX_Vector sums = Q6_V_vzero(); const size_t stride_x_4 = stride_x * 4; for (uint32_t j = 0; j < VLEN_FP32; j += 4) { HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe); @@ -132,8 +142,7 @@ static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y, x += stride_x_4; } - sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums); - return Q6_Vsf_equals_Vqf32(sums); + return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums); } // MAD: y (F32) += x (F16) * s (F16) @@ -268,11 +277,10 @@ static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * uint32_t i = 0; #pragma unroll(4) for (; i < nvec; ++i) { - vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs)); + vdst[i] = HVX_OP_MUL_F32(vsrc[i], vs); } if (nloe) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); - hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v)); + hvx_vec_store_a(&vdst[i], nloe * sizeof(float), HVX_OP_MUL_F32(vsrc[i], vs)); } } @@ -438,25 +446,44 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * // Process in sub-blocks of 32 (VLEN_FP32) HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32]; HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); - for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { + for (uint32_t iv = 0; ic < current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale); // 2. Softcap if (factx->logit_softcap != 0.0f) { scores = hvx_vec_tanh_f32(scores); - scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap); - scores = Q6_Vsf_equals_Vqf32(scores); + scores = HVX_OP_MUL_F32(scores, logit_cap); } // 3. Mask if (mask) { const __fp16 * mp = m_base + ic; HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp; - HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); - HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); - scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores); - scores = Q6_Vsf_equals_Vqf32(scores); + + // Multiplying -INFINITY (0xFC00) by a slope in VhfVhf instructions can incorrectly produce NaN on v79. + // Clamp -INFINITY to the max negative fp16 finite value (-65504.0f). + HVX_Vector vinf = Q6_Vh_vsplat_R(0xFC00); + HVX_Vector vmin = Q6_Vh_vsplat_R(0xFBFF); + HVX_VectorPred is_inf = Q6_Q_vcmp_eq_VhVh(m_vals_f16, vinf); + m_vals_f16 = Q6_V_vmux_QVV(is_inf, vmin, m_vals_f16); + + #if __HVX_ARCH__ >= 79 + HVX_VectorPair m_vals_f32_pair = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); + HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); + scores = Q6_Vsf_vadd_VsfVsf(add_val, scores); + #else + HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); + HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); + scores = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores)); + #endif + } + + // Mask out invalid lanes for leftover handling + uint32_t valid_lanes = current_block_size - ic; + if (valid_lanes < VLEN_FP32) { + HVX_VectorPred valid_pred = Q6_Q_vsetq_R(valid_lanes * 4); // 4 bytes per fp32 lane + scores = Q6_V_vmux_QVV(valid_pred, scores, hvx_vec_splat_f32(-INFINITY)); } sb_scores[iv] = scores; @@ -466,78 +493,55 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * { // 4. Online Softmax Update HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec); - HVX_Vector diff_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec)); + HVX_Vector diff_vec = HVX_OP_SUB_F32(M_vec, M_new_vec); HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); M_vec = M_new_vec; hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); - for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) { + for (uint32_t ic2 = 0, iv = 0; ic2 < current_block_size; ic2 += VLEN_FP32, ++iv) { HVX_Vector scores = sb_scores[iv]; - HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec); - HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); + HVX_Vector scores_shifted = HVX_OP_SUB_F32(scores, M_vec); + HVX_Vector P = hvx_vec_exp_f32(scores_shifted); - p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P)); + p_sum_vec = HVX_OP_ADD_F32(p_sum_vec, P); // 5. Accumulate V __fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16]; hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0)); + float __attribute__((aligned(128))) P_arr[VLEN_FP32]; + hvx_vec_store_a(P_arr, 128, P); + for (uint32_t j = 0; j < VLEN_FP32; j += 2) { - const uint32_t cur_ic = ic2 + j; - const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; + const uint32_t cur_ic = ic2 + j; + if (cur_ic >= current_block_size) { + break; + } + + if (cur_ic + 1 == current_block_size) { + // Odd leftover, process single row + if (P_arr[j] != 0.0f) { + const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, (p_arr + j), DV); + } + break; + } + + // Avoid NaN * 0.0 = NaN for uninitialized V cache rows. + // Check the f32 values to safely avoid strict aliasing violations. + if (P_arr[j] == 0.0f && P_arr[j + 1] == 0.0f) { + continue; + } + + const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV); } } p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec); - S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec)); - } - - if (ic < current_block_size) { - // Sync scalars for leftover/next block if needed - float M = hvx_vec_get_f32(M_vec); - float S = hvx_vec_get_f32(S_vec); - - // Leftover - for (; ic < current_block_size; ++ic) { - float s_val; - const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded; - hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale); - if (factx->logit_softcap != 0.0f) { - s_val = factx->logit_softcap * tanhf(s_val); - } - - if (mask) { - const float m_val = m_base[ic]; - s_val += slope * m_val; - } - - const float Mold = M; - __fp16 vs = 1.0f; - - if (s_val > M) { - M = s_val; - HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M); - HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); - hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); - - float ms = hvx_vec_get_f32(ms_vec); - S = S * ms + vs; - } else { - HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M); - vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); - S += vs; - } - - const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded; - - hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV); - } - - M_vec = hvx_vec_splat_f32(M); - S_vec = hvx_vec_splat_f32(S); + S_vec = HVX_OP_ADD_F32(HVX_OP_MUL_F32(S_vec, ms_vec), p_sum_vec); } // Issue DMA for next+1 block (if exists) @@ -599,8 +603,9 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const int i2 = iq2; const int i3 = iq3; - // dst is permuted - uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1; + // dst is permuted: [DV, n_heads, n_tokens, n_seq] + // head stride is nb[1], token stride is nb[2], batch stride is nb[3] + uint8_t * dst_ptr = (uint8_t *) dst->data + i2 * dst->nb[1] + i1 * dst->nb[2] + i3 * dst->nb[3]; if (dst->type == HTP_TYPE_F32) { hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV); @@ -623,8 +628,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { } #ifdef HTP_HAS_HMX - // HMX path: prefill (neq1 >= 32), head_dim multiple of 32, F16 KV - if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0 && q->ne[1] >= 32) { + // HMX path: head_dim multiple of 32, F16 KV + if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0) { int ret = hmx_flash_attn_ext(octx); if (ret == HTP_STATUS_OK) { return ret; diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index a496f6289ae..f132c08500d 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -1248,9 +1248,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { if (DK % 32 != 0 || DV % 32 != 0) { return HTP_STATUS_NO_SUPPORT; } - if (neq1 < 32) { - return HTP_STATUS_NO_SUPPORT; - } // GQA factor const uint32_t n_kv_heads = k->ne[2]; diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index ab5fd73380b..083d125882d 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -16,6 +16,7 @@ #include "ggml-common.h" #include "hex-dma.h" +#include "hex-fastdiv.h" #include "worker-pool.h" #include "hvx-utils.h" @@ -187,45 +188,44 @@ static int hmx_compute_chunks(size_t vtcm_total, // In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles // of the same 32 packed bytes. static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { + (void)vlut_cvt; HVX_Vector vq = hvx_vmemu(packed_32); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); - // q4x4x2 stores two int4 values per byte. Keep only the selected nibble. - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); - // Shuffle before LUT - v_quants = Q6_Vb_vshuff_Vb(v_quants); - // Use standard vlut16 (not _nomatch) to avoid stale-register NaN. - // _nomatch retains the previous destination-register value for colliding - // indices, but the C intrinsic doesn't model the implicit read so the - // compiler may allocate a register containing garbage/NaN. - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_hf = Q6_V_lo_W(vp); + + HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8); + HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_int8)); + HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); } // Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using -// full HVX vector width. One vmemu + one vlut16 replaces 4 separate calls. +// full HVX vector width. // Output: vector_x2 each hold 32 FP16 values in the first 64 bytes. static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( const uint8_t *packed_128, bool upper_nibbles, const __fp16 *scales_4, const HVX_Vector vlut_cvt) { - // Load all 128 packed bytes (4 contiguous 32-byte groups) + (void)vlut_cvt; HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); - // Shuffle before LUT - v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8); - // Full-width vlut16: 128 byte lookups -> 128 fp16 results in a VectorPair - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_lo = Q6_V_lo_W(vp); // [group0: 32 fp16 | group1: 32 fp16] - HVX_Vector v_hi = Q6_V_hi_W(vp); // [group2: 32 fp16 | group3: 32 fp16] + HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_int8); + HVX_Vector v_lo = Q6_V_lo_W(vp_int16); + HVX_Vector v_hi = Q6_V_hi_W(vp_int16); + + v_lo = Q6_Vhf_equals_Vh(v_lo); + v_hi = Q6_Vhf_equals_Vh(v_hi); - // Build per-group scale vectors: first 64 bytes use scale_a, last 64 use scale_b HVX_Vector vscale = hvx_vmemu(scales_4); HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); @@ -233,13 +233,12 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - // Extract individual groups: scatter uses q_mask64 so only first 64 bytes matter - HVX_Vector_x2 r = { v_lo,/* group1 already in [0:63] */ - v_hi /* group2 already in [0:63] */ }; + HVX_Vector_x2 r = { v_lo, v_hi }; return r; } static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) { + (void)vlut_cvt; HVX_Vector vq = hvx_vmemu(packed_32); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); HVX_Vector v_dm = hvx_vmemu(scale_offset); @@ -248,9 +247,9 @@ static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32 HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); - v_quants = Q6_Vb_vshuff_Vb(v_quants); - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_hf = Q6_V_lo_W(vp); + + HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_quants)); + HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets)); } @@ -258,16 +257,18 @@ static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32 static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx( const uint8_t *packed_128, bool upper_nibbles, const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) { + (void)vlut_cvt; HVX_Vector vq = hvx_vmemu(packed_128); const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); v_quants = Q6_V_vand_VV(v_quants, mask_h4); - v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_quants); + HVX_Vector v_lo = Q6_V_lo_W(vp_int16); + HVX_Vector v_hi = Q6_V_hi_W(vp_int16); - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_lo = Q6_V_lo_W(vp); - HVX_Vector v_hi = Q6_V_hi_W(vp); + v_lo = Q6_Vhf_equals_Vh(v_lo); + v_hi = Q6_Vhf_equals_Vh(v_hi); HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4); HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2); @@ -287,6 +288,45 @@ static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx( return r; } +// LUT-based dequantizers for non-linear IQ4_NL format. +static inline HVX_Vector dequantize_x4x2_iq4_nl_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_hf = Q6_V_lo_W(vp); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); +} + +static inline HVX_Vector_x2 dequantize_x4x2_iq4_nl_x4groups_hvx( + const uint8_t *packed_128, bool upper_nibbles, + const __fp16 *scales_4, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + v_quants = Q6_Vb_vshuff_Vb(v_quants); + + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_lo = Q6_V_lo_W(vp); + HVX_Vector v_hi = Q6_V_hi_W(vp); + + HVX_Vector vscale = hvx_vmemu(scales_4); + HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); + HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); + + HVX_Vector_x2 r = { v_lo, v_hi }; + return r; +} + // Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { HVX_Vector vq = hvx_vmemu(quants_32); @@ -374,122 +414,176 @@ static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * return r; } +typedef struct { + __fp16 *dst; + const uint8_t *src; + int n_cols; + int k_block; + size_t row_stride; + int weight_type; + int n_tot_tiles; + int n_tiles_per_task; + int n_tasks; + int n_k_tiles; + struct fastdiv_values n_k_tiles_div; +} x4x2_dequantize_state_t; + // Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. // Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes. // Output: vtcm_dst in tile-major FP16 layout. -static void dequantize_x4x2_weight_to_fp16_tiles_task( - __fp16 *restrict vtcm_dst, - const uint8_t *restrict vtcm_src, - int n_cols, int k_block, - size_t row_stride, int weight_type, - int start_tile, int end_tile) { - - const int n_k_tiles = (unsigned)k_block / HMX_FP16_TILE_N_COLS; - const bool is_q4 = (weight_type == HTP_TYPE_Q4_0 || weight_type == HTP_TYPE_Q4_1 || weight_type == HTP_TYPE_IQ4_NL); - const bool is_q4_1 = (weight_type == HTP_TYPE_Q4_1); - const int qrow_size = is_q4 ? ((unsigned)k_block / 2) : k_block; - - const HVX_Vector vlut_cvt = (weight_type == HTP_TYPE_IQ4_NL) ? hvx_vmem(iq4_nl_to_fp16_lut) : - (weight_type == HTP_TYPE_MXFP4) ? hvx_vmem(mxfp4_to_fp16_lut) : - (weight_type == HTP_TYPE_Q4_1) ? hvx_vmem(q4_1_to_fp16_lut) : - hvx_vmem(q4_0_to_fp16_lut); - // vscatter setup: write dequantized K-values directly to transposed [K][N] tile positions. - // Each int32 element holds a K-row-pair (2 adjacent fp16 values). word[i] at offset i*128 - // maps to K-rows 2i and 2i+1. Column offset (n*4) added per row. - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); // 4 bytes = 1 column step - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); // first 16 words (64 bytes) - - unsigned ct = (unsigned)start_tile / n_k_tiles; // column tile index - unsigned kt = (unsigned)start_tile % n_k_tiles; // K tile index - for (unsigned t = start_tile; t < end_tile; ) { - if (kt >= n_k_tiles) { kt = 0; ct++; } - - // --- Batch-4 fast path for Q4: process 4 contiguous K-tiles with one vlut16 per row --- - if (is_q4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { - unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; - unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; // 0 or 4 - bool upper = (sub_blk_base >= 4); - unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); // 128 contiguous packed bytes - unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE; - unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16); - unsigned scale_off = qrow_size + blk_idx * dblk_size - + sub_blk_base * scale_step; - - __fp16 *tile_bases[4]; - for (unsigned g = 0; g < 4; g++) { tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; } - - HVX_Vector v_off = v_scat_base; - - unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; - unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; - - if (is_q4_1) { - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; - - HVX_Vector_x2 dv0 = dequantize_x4x2_q4_1_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector_x2 dv1 = dequantize_x4x2_q4_1_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); +#define DEFINE_DEQUANTIZE_Q4_TASK(suffix, lut_name, helper_prefix, dblk_size, scale_step) \ +static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix( \ + const x4x2_dequantize_state_t *state, \ + int start_tile, int end_tile) { \ + \ + const int n_k_tiles = state->n_k_tiles; \ + const int qrow_size = (unsigned)state->k_block / 2; \ + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; \ + const HVX_Vector vlut_cvt = hvx_vmem(lut_name); \ + \ + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); \ + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); \ + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); \ + \ + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); \ + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); \ + \ + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { \ + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } \ + \ + if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { \ + unsigned blk_idx = ((kt * 32) / QK_Q4_0x4x2); \ + unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; \ + bool upper = (sub_blk_base >= 4); \ + unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); \ + unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk_base * (scale_step); \ + \ + __fp16 *tile_bases[4]; \ + for (unsigned g = 0; g < 4; g++) { \ + tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; \ + } \ + \ + HVX_Vector v_off = v_scat_base; \ + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \ + \ + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { \ + const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \ + const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \ + \ + HVX_Vector_x2 dv0 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \ + r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \ + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); \ + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + \ + HVX_Vector_x2 dv1 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \ + r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); \ + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); \ + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + } \ + \ + for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } \ + t += 4; kt += 4; \ + continue; \ + } \ + \ + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; \ + { \ + unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; \ + unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; \ + bool upper = (sub_blk >= 4); \ + unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; \ + unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk * (scale_step); \ + \ + HVX_Vector v_off = v_scat_base; \ + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \ + unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; \ + \ + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { \ + const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \ + const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \ + \ + HVX_Vector v0 = dequantize_x4x2_##helper_prefix##_group_hvx( \ + r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \ + HVX_Vector v1 = (row1 < (unsigned)state->n_cols) \ + ? dequantize_x4x2_##helper_prefix##_group_hvx( \ + r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) \ + : Q6_V_vzero(); \ + \ + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + } \ + (void) *(volatile HVX_Vector *)(tile_base); \ + } \ + ++t; ++kt; \ + } \ + \ + if (start_tile < end_tile) { \ + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); \ + } \ +} \ + \ +static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \ + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \ + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \ + int start = task_id * state->n_tiles_per_task; \ + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \ + dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \ + } \ +} - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); +DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16)) +DEFINE_DEQUANTIZE_Q4_TASK(q4_1, q4_1_to_fp16_lut, q4_1, 32, 4) +DEFINE_DEQUANTIZE_Q4_TASK(iq4_nl, iq4_nl_to_fp16_lut, iq4_nl, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16)) - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - } else { - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; +static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { - HVX_Vector_x2 dv0 = dequantize_x4x2_q4_0_x4groups_hvx(r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector_x2 dv1 = dequantize_x4x2_q4_0_x4groups_hvx(r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); + const int n_k_tiles = state->n_k_tiles; + const int qrow_size = state->k_block; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - } + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } - t += 4; kt += 4; - continue; - } + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - // --- Batch-4 fast path for MXFP4: same nibble layout but E8M0 scales --- - if (weight_type == HTP_TYPE_MXFP4 && (kt % 4 == 0) && (t + 4 <= end_tile) && ((t + 3) / n_k_tiles == ct)) { + // Batch-4 fast path for MXFP4 + if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { int blk_idx = (kt * 32) / QK_MXFP4x4x2; - int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; // 0 or 4 + int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; bool upper = (sub_blk_base >= 4); - int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); // 128 contiguous packed bytes - int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; // all 8 E8M0 scales + int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); + int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; __fp16 * tile_bases[4]; for (int g = 0; g < 4; g++) { - tile_bases[g] = vtcm_dst + (t + g) * HMX_FP16_TILE_N_ELMS; + tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; } HVX_Vector v_off = v_scat_base; for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { int row0 = ct * HMX_FP16_TILE_N_COLS + r; int row1 = row0 + 1; - const uint8_t * r0 = vtcm_src + row0 * row_stride; - const uint8_t * r1 = vtcm_src + row1 * row_stride; + const uint8_t * r0 = state->src + row0 * state->row_stride; + const uint8_t * r1 = state->src + row1 * state->row_stride; - // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); HVX_Vector_x4 dv0, dv1; dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8); - if (row1 < n_cols) { + if (row1 < state->n_cols) { mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8); } else { @@ -510,58 +604,13 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( (void) *(volatile HVX_Vector *) (tile_bases[g]); } - t += 4; + t += 4; kt += 4; continue; } - // --- Single-tile fallback --- - __fp16 *tile_base = vtcm_dst + t * HMX_FP16_TILE_N_ELMS; - - if (is_q4) { - unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; - unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; - bool upper = (sub_blk >= 4); - unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; - unsigned dblk_size = is_q4_1 ? 32 : HMX_X4X2_DBLK_SIZE; - unsigned scale_step = is_q4_1 ? 4 : (int)sizeof(__fp16); - unsigned scale_off = qrow_size + blk_idx * dblk_size + sub_blk * scale_step; - - HVX_Vector v_off = v_scat_base; // reset to column 0 - unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * row_stride; - unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; - if (is_q4_1) { - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; - - HVX_Vector v0 = dequantize_x4x2_q4_1_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector v1 = (row1 < n_cols) - ? dequantize_x4x2_q4_1_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) - : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - } else { - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { - const uint8_t *r0 = vtcm_src + row_offset; row_offset += row_stride; - const uint8_t *r1 = vtcm_src + row_offset; row_offset += row_stride; - - HVX_Vector v0 = dequantize_x4x2_q4_0_group_hvx(r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); - HVX_Vector v1 = (row1 < n_cols) - ? dequantize_x4x2_q4_0_group_hvx(r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) - : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - } - (void) *(volatile HVX_Vector *)(tile_base); - } else if (weight_type == HTP_TYPE_MXFP4) { + // Single-tile fallback + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { int blk_idx = (kt * 32) / QK_MXFP4x4x2; int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32; bool upper = (sub_blk >= 4); @@ -573,15 +622,14 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( int row0 = ct * HMX_FP16_TILE_N_COLS + r; int row1 = row0 + 1; - const uint8_t * r0 = vtcm_src + row0 * row_stride; - const uint8_t * r1 = vtcm_src + row1 * row_stride; + const uint8_t * r0 = state->src + row0 * state->row_stride; + const uint8_t * r1 = state->src + row1 * state->row_stride; - // Batch-convert all 8 E8M0 scales once per row (stays in HVX register) mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8); HVX_Vector v1; - if (row1 < n_cols) { + if (row1 < state->n_cols) { mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8); } else { @@ -594,23 +642,59 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); } (void) *(volatile HVX_Vector *) (tile_base); - } else { - // Q8_0 + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end); + } +} + +static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const int qrow_size = state->k_block; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { int blk_idx = (kt * 32) / QK_Q8_0x4x2; int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32; int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32; int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); - HVX_Vector v_off = v_scat_base; // reset to column 0 + HVX_Vector v_off = v_scat_base; for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { int row0 = ct * HMX_FP16_TILE_N_COLS + r; int row1 = row0 + 1; - const uint8_t *r0 = vtcm_src + row0 * row_stride; - const uint8_t *r1 = vtcm_src + row1 * row_stride; + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off)); - HVX_Vector v1 = (row1 < n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero(); + HVX_Vector v1 = (row1 < state->n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero(); Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); @@ -622,50 +706,31 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task( ++t; ++kt; } - // Drain HVX scatter write buffer: a vmem load on the same HW thread retires - // all pending scatter entries to VTCM. Without this, the main thread's HMX - // reads may see stale data because atomic_fetch_sub (release) only orders - // regular stores, not the HVX scatter buffer. if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(vtcm_dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); } } -typedef struct { - __fp16 *dst; - const uint8_t *src; - int n_cols; - int k_block; - size_t row_stride; - int weight_type; - int n_tot_tiles; - int n_tiles_per_task; - int n_tasks; -} x4x2_dequantize_state_t; - -static void dequantize_x4x2_worker_loop(unsigned int n, unsigned int i, void *data) { +static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) { x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { int start = task_id * state->n_tiles_per_task; int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); - - dequantize_x4x2_weight_to_fp16_tiles_task( - state->dst, state->src, state->n_cols, state->k_block, - state->row_stride, state->weight_type, start, end); + dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end); } } static void dequantize_x4x2_weight_chunk_to_fp16_tiles( struct htp_context *ctx, __fp16 *vtcm_dst, const void *vtcm_src, int n_cols, int k_block, - size_t row_stride, int weight_type) { + size_t row_stride, int weight_type, + int n_k_tiles, struct fastdiv_values n_k_tiles_div, + worker_callback_t dequant_worker_fn) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); assert(k_block % HMX_FP16_TILE_N_COLS == 0); size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; - size_t n_k_tiles = k_block / HMX_FP16_TILE_N_COLS; size_t n_tot_tiles = n_col_tiles * n_k_tiles; size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads); @@ -680,8 +745,10 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( state.k_block = k_block; state.row_stride = row_stride; state.weight_type = weight_type; + state.n_k_tiles = n_k_tiles; + state.n_k_tiles_div = n_k_tiles_div; - worker_pool_run_func(ctx->worker_pool, dequantize_x4x2_worker_loop, &state, ctx->n_threads); + worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, ctx->n_threads); } // --- End x4x2 dequantizers --- @@ -978,6 +1045,20 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * return -1; } + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; + default: + return -1; + } + + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + // --- Dynamic VTCM layout --- const size_t vec_dot_size = k * sizeof(__fp16); const size_t vtcm_budget = ctx->vtcm_size; @@ -1070,7 +1151,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * { // B0: wait for DMA, dequant weight chunk 0 dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); // A1: issue DMA for weight chunk 1 const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); @@ -1089,7 +1170,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) if (1 < n_chunk_cnt) { dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); } } @@ -1131,7 +1212,7 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) if (i + 2 < n_chunk_cnt) { dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); } } } From 04795e6272a74123486dfccfd0e62ecf816ba178 Mon Sep 17 00:00:00 2001 From: Jaden_Mach <88880593+jadenmach2@users.noreply.github.com> Date: Thu, 28 May 2026 08:50:25 -0400 Subject: [PATCH 725/831] CUDA: route batch>=4 quantized matmul to MMQ on AMD MFMA hardware (llama/23227) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CUDA: per-quant MMVQ/MMQ batch threshold on AMD MFMA hardware The dispatcher uses a single global threshold (MMVQ_MAX_BATCH_SIZE = 8) to choose between mul_mat_vec_q (per-row GEMV) and mul_mat_q (MFMA-tiled GEMM) for quantized matmul. On AMD CDNA, the optimal crossover differs substantially by quant family because the per-row GEMV cost is dominated by dequantisation, not the dot-product itself: K-quants pay a heavier super-block decode and so MMQ wins sooner; legacy and IQ quants have lean decode and stay ahead until the batch fully populates an MFMA tile. This patch introduces ggml_cuda_should_use_mmvq(type, cc, ne11) -> bool, mirroring the existing ggml_cuda_should_use_mmq, and gates per-quant thresholds on amd_mfma_available(cc): Q3_K, Q4_K, Q5_K : MMVQ <= 3 (MMQ wins from batch=4: +5% .. +76%) Q2_K, Q6_K : MMVQ <= 5 (MMQ wins from batch=6: +8% .. +35%) others : MMVQ <= 8 (legacy & IQ regress under MMQ; unchanged) Non-AMD-MFMA paths (NVIDIA, RDNA, CDNA1 without MFMA) are byte-identical to master. GGML_CUDA_FORCE_MMVQ=1 restores the original global threshold for A/B testing. Measured on MI250X (gfx90a, ROCm 7.2.1) with Llama-3.2-3B-Instruct, llama-bench pp512 across all 20 supported quants, ubatch 1..8, 10 reps. Full table in PR description. Selected pp512 throughput (tok/s, ub=8): Q4_K_S: 559 -> 940 (+68%) Q5_K_S: 503 -> 884 (+76%) Q3_K_S: 629 -> 879 (+40%) Q2_K : 615 -> 809 (+32%) Q6_K : 582 -> 776 (+33%) Selected pp512 throughput (tok/s, ub=4): Q4_K_S: 444 -> 480 (+ 8%) Q4_0 : 682 -> 685 (+ 0%) (no regression - retains MMVQ) IQ4_XS: 706 -> 698 (- 1%) (no regression - retains MMVQ) * CUDA: address review — inline MMVQ batch table, drop env hatch & doc block * tune kernel selection logic for CDNA1 --------- Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 ++ ggml/src/ggml-cuda/mmvq.cu | 47 +++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/mmvq.cuh | 2 ++ 3 files changed, 51 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 23d1c069248..dc3e8fd6265 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2570,6 +2570,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0); use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false); use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]); + use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } } else { @@ -2578,6 +2579,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0); use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false); use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]); + use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 13b8b855282..873ff05a074 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -271,6 +271,53 @@ int get_mmvq_mmid_max_batch(ggml_type type, int cc) { return MMVQ_MAX_BATCH_SIZE; } +bool ggml_cuda_should_use_mmvq(enum ggml_type type, int cc, int64_t ne11) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + if (GGML_CUDA_CC_IS_CDNA1(cc)) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return ne11 <= 7; + case GGML_TYPE_Q5_1: + return ne11 <= 7; + case GGML_TYPE_Q8_0: + return ne11 <= 6; + case GGML_TYPE_Q2_K: + return ne11 <= 4; + case GGML_TYPE_Q3_K: + return ne11 <= 3; + case GGML_TYPE_Q4_K: + return ne11 <= 2; + case GGML_TYPE_Q5_K: + return ne11 <= 3; + case GGML_TYPE_Q6_K: + return ne11 <= 4; + case GGML_TYPE_IQ1_S: + return ne11 <= 5; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + return ne11 <= 6; + default: + return ne11 <= MMVQ_MAX_BATCH_SIZE; + } + } + switch (type) { // tuned for CDNA2 + case GGML_TYPE_Q2_K: + return ne11 <= 5; + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return ne11 <= 3; + case GGML_TYPE_Q6_K: + return ne11 <= 5; + default: + return ne11 <= MMVQ_MAX_BATCH_SIZE; + } + } + return ne11 <= MMVQ_MAX_BATCH_SIZE; +} + // Device constexpr: returns the max batch size for the current arch+type at compile time. template static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() { diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 6bf0a8e8677..5605bf7a4e6 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -2,6 +2,8 @@ #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. +bool ggml_cuda_should_use_mmvq(enum ggml_type type, int cc, int64_t ne11); + // Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID, // based on the quantization type and GPU architecture (compute capability). int get_mmvq_mmid_max_batch(ggml_type type, int cc); From 4e8af441e5f5ec8b91e193a598929cce489374ed Mon Sep 17 00:00:00 2001 From: redfox <59549776+yaohengxu@users.noreply.github.com> Date: Thu, 28 May 2026 20:51:14 +0800 Subject: [PATCH 726/831] =?UTF-8?q?mmvq=20Optim:=20add=20MMVQ=5FPARAMETERS?= =?UTF-8?q?=5FTURING(mmvq=5Fparameter=5Ftable=5Fid)=20for=20=E2=80=A6=20(#?= =?UTF-8?q?23729)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * mmvq Optim: add MMVQ_PARAMETERS_TURING(mmvq_parameter_table_id) for SM75 TURING * avoid a mismatch for JIT compilation of Turing device code for Ampere or newer Co-authored-by: Johannes Gäßler --------- Co-authored-by: Copilot Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/mmvq.cu | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 873ff05a074..ecb6fdedadd 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -63,6 +63,7 @@ static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { enum mmvq_parameter_table_id { MMVQ_PARAMETERS_GENERIC = 0, + MMVQ_PARAMETERS_TURING, MMVQ_PARAMETERS_GCN, MMVQ_PARAMETERS_RDNA2, MMVQ_PARAMETERS_RDNA3_0, @@ -78,6 +79,8 @@ static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { return MMVQ_PARAMETERS_RDNA2; #elif defined(GCN) || defined(CDNA) return MMVQ_PARAMETERS_GCN; +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING && __CUDA_ARCH__ < GGML_CUDA_CC_AMPERE + return MMVQ_PARAMETERS_TURING; #else return MMVQ_PARAMETERS_GENERIC; #endif @@ -96,6 +99,9 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { return MMVQ_PARAMETERS_GCN; } + if (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING && ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_AMPERE) { + return MMVQ_PARAMETERS_TURING; + } return MMVQ_PARAMETERS_GENERIC; } @@ -417,11 +423,38 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d } return 1; } + if (table_id == MMVQ_PARAMETERS_TURING) { + if (ncols_dst == 1) { + switch (type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + return 2; + default: + return 4; + } + } + switch (ncols_dst) { + case 2: + case 3: + case 4: + return 4; + case 5: + case 6: + case 7: + case 8: + return 2; + default: + return 1; + } + } return 1; } static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) { - if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { + if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN || table_id == MMVQ_PARAMETERS_TURING) { switch (ncols_dst) { case 1: return small_k ? nwarps : 1; From e1faa7cb4d7b2c4f185fbf3fef04fc616d871fec Mon Sep 17 00:00:00 2001 From: fl0rianr <226492742+fl0rianr@users.noreply.github.com> Date: Thu, 28 May 2026 15:01:14 +0200 Subject: [PATCH 727/831] ggml: auto apply iGPU flag CUDA/HIP if integrated device (llama/23007) --- ggml/src/ggml-cuda/ggml-cuda.cu | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index dc3e8fd6265..18aaa098398 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4994,8 +4994,14 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * } static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) { - GGML_UNUSED(dev); - return GGML_BACKEND_DEVICE_TYPE_GPU; + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *) dev->context; + + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device)); + + return prop.integrated + ? GGML_BACKEND_DEVICE_TYPE_IGPU + : GGML_BACKEND_DEVICE_TYPE_GPU; } static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { From 94922ce12cd88a5449e9451c34332b667e6a1d14 Mon Sep 17 00:00:00 2001 From: lhez Date: Thu, 28 May 2026 11:05:42 -0700 Subject: [PATCH 728/831] opencl: move backend info printing into its own function (llama/23702) * opencl: move backend info print into its own function * opencl: move new log line * opencl: fix for non adreno path --- ggml/src/ggml-opencl/ggml-opencl.cpp | 155 +++++++++++++++------------ 1 file changed, 86 insertions(+), 69 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 6d6c3e8973d..751ec6116c0 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -379,6 +379,8 @@ struct ggml_backend_opencl_device_context { GPU_FAMILY gpu_family = GPU_FAMILY::UNKNOWN; ADRENO_GPU_GEN adreno_gen = ADRENO_GPU_GEN::ADRENO_UNKNOWN; + std::regex *opfilter = nullptr; // regex of ops to not claim + std::string opfilter_str; // regex string for opfilter size_t global_mem_size = 0; }; @@ -415,8 +417,6 @@ struct ggml_backend_opencl_context { bool has_qcom_subgroup_shuffle = false; // cl_qcom_subgroup_shuffle bool disable_fusion; - std::regex *opfilter = nullptr; // regex of ops to not claim - bool adreno_has_large_buffer; bool adreno_use_large_buffer; ggml_cl_compiler_version adreno_cl_compiler_version; @@ -428,6 +428,8 @@ struct ggml_backend_opencl_context { size_t image2d_max_width; size_t image2d_max_height; + cl_device_svm_capabilities svm_caps; + cl_context context; cl_command_queue queue; @@ -3731,6 +3733,68 @@ static std::vector ggml_opencl_probe_devices(ggml_backend_r return found_devices; } +static void ggml_opencl_print_backend_info(ggml_backend_opencl_device_context * dev_ctx) { + GGML_ASSERT(dev_ctx); + GGML_ASSERT(dev_ctx->backend_ctx); + + auto * backend_ctx = dev_ctx->backend_ctx; + + GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", + backend_ctx->driver_version.c_str()); + GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", + backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", + backend_ctx->fp16_support ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", + backend_ctx->alignment); + GGML_LOG_INFO("ggml_opencl: global mem size: %zu MB\n", + backend_ctx->global_mem_size/1024/1024); + GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", + backend_ctx->max_alloc_size/1024/1024); + GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", + backend_ctx->image_max_buffer_size); + GGML_LOG_INFO("ggml_opencl: device max image2d size: %lu x %lu\n", + backend_ctx->image2d_max_width, backend_ctx->image2d_max_height); + GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", + backend_ctx->max_workgroup_size); + GGML_LOG_INFO("ggml_opencl: SVM coarse grain buffer support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM fine grain buffer support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_FINE_GRAIN_BUFFER ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM fine grain system support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: cl_qcom_subgroup_shuffle support: %s\n", + backend_ctx->has_qcom_subgroup_shuffle ? "true" : "false"); + + // Print out configurations +#ifdef GGML_OPENCL_SOA_Q + GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n"); +#endif // GGML_OPENCL_SOA_Q + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); + if (backend_ctx->adreno_xmem_gemm_enabled) { + GGML_LOG_INFO("ggml_opencl: Adreno xmem F16xF32 GEMM enabled (temporary weight prepack)\n"); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + if (backend_ctx->adreno_use_large_buffer) { + if (!backend_ctx->adreno_has_large_buffer) { + GGML_LOG_INFO("ggml_opencl: Adreno large buffer requested but not supported by driver, will use regular buffer\n"); + backend_ctx->adreno_use_large_buffer = false; + } else { + GGML_LOG_INFO("ggml_opencl: Adreno large buffer enabled\n"); + } + } + + if (dev_ctx->opfilter) { + // for information only, the actual regex object is created in ggml_opencl_is_device_supported + GGML_LOG_INFO("ggml_opencl: opfilter regex = \"%s\"\n", dev_ctx->opfilter_str.c_str()); + } +} + // check if device should be accepted static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev) { GGML_ASSERT(dev); @@ -3799,6 +3863,13 @@ static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev) { } clGetDeviceInfo(dev_ctx->device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(size_t), &dev_ctx->global_mem_size, NULL); + + const char * str_opfilter = getenv("GGML_OPENCL_OPFILTER"); + if (str_opfilter) { + dev_ctx->opfilter_str = str_opfilter; + dev_ctx->opfilter = new std::regex(str_opfilter, std::regex_constants::icase); + } + return true; } @@ -3850,15 +3921,12 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { char *driver_version = (char *)alloca(driver_version_str_size + 1); clGetDeviceInfo(device, CL_DRIVER_VERSION, driver_version_str_size, driver_version, NULL); driver_version[driver_version_str_size] = '\0'; - GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", driver_version); backend_ctx->driver_version = driver_version; backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); backend_ctx->has_vector_subgroup_broadcast = (backend_ctx->adreno_cl_compiler_version.type == E031 && backend_ctx->adreno_cl_compiler_version.major >= 47) || (backend_ctx->adreno_cl_compiler_version.type == DX && backend_ctx->adreno_cl_compiler_version.major >= 17); - GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", - backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); size_t ext_str_size; clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); @@ -3867,18 +3935,12 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated // check support for qcom_subgroup_shuffle - if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") != NULL) { - GGML_LOG_INFO("ggml_opencl: cl_khr_subgroups support: true\n"); - if (strstr(ext_buffer, "cl_qcom_subgroup_shuffle") != NULL) { - backend_ctx->has_qcom_subgroup_shuffle = true; - } + if (strstr(ext_buffer, "cl_qcom_subgroup_shuffle") != NULL) { + backend_ctx->has_qcom_subgroup_shuffle = true; } - GGML_LOG_INFO("ggml_opencl: cl_qcom_subgroup_shuffle support: %s\n", - backend_ctx->has_qcom_subgroup_shuffle ? "true" : "false"); // Check if ext_buffer contains cl_khr_fp16 backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; - GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); // check Adreno large buffer support backend_ctx->adreno_has_large_buffer = strstr(ext_buffer, "cl_qcom_large_buffer") != NULL; @@ -3887,35 +3949,15 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &base_align_in_bits, NULL)); GGML_ASSERT(base_align_in_bits % 8u == 0); backend_ctx->alignment = base_align_in_bits / 8u; - GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment); backend_ctx->global_mem_size = dev_ctx->global_mem_size; - GGML_LOG_INFO("ggml_opencl: global mem size: %zu MB\n", backend_ctx->global_mem_size/1024/1024); - - clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); - GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024); - - clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL); - GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", backend_ctx->image_max_buffer_size); - clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_WIDTH, sizeof(size_t), &backend_ctx->image2d_max_width, NULL); - clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_HEIGHT, sizeof(size_t), &backend_ctx->image2d_max_height, NULL); - GGML_LOG_INFO("ggml_opencl: device max image2d size: %lu x %lu\n", backend_ctx->image2d_max_width, backend_ctx->image2d_max_height); - - clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL); - GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size); - - // Check SVM. - cl_device_svm_capabilities svm_caps; - CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0)); - GGML_LOG_INFO("ggml_opencl: SVM coarse grain buffer support: %s\n", - svm_caps & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER ? "true" : "false"); - GGML_LOG_INFO("ggml_opencl: SVM fine grain buffer support: %s\n", - svm_caps & CL_DEVICE_SVM_FINE_GRAIN_BUFFER ? "true" : "false"); - GGML_LOG_INFO("ggml_opencl: SVM fine grain system support: %s\n", - svm_caps & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM ? "true" : "false"); - GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n", - svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_WIDTH, sizeof(size_t), &backend_ctx->image2d_max_width, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_HEIGHT, sizeof(size_t), &backend_ctx->image2d_max_height, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &backend_ctx->svm_caps, 0)); if (opencl_c_version.major >= 3) { // Assume it is not available for 3.0, since it is optional in 3.0. @@ -3931,36 +3973,15 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { backend_ctx->non_uniform_workgroups = true; } - // Print out configurations -#ifdef GGML_OPENCL_SOA_Q - GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n"); -#endif // GGML_OPENCL_SOA_Q - -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS - GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); -#endif // GGML_OPENCL_USE_ADRENO_KERNELS - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // determine whether to use Adreno xmem GEMM backend_ctx->adreno_xmem_gemm_enabled = getenv("GGML_OPENCL_ADRENO_XMEM_GEMM") != nullptr && backend_ctx->gpu_family == GPU_FAMILY::ADRENO; - if (getenv("GGML_OPENCL_ADRENO_XMEM_GEMM") != nullptr) { - GGML_LOG_INFO("ggml_opencl: Adreno xmem F16xF32 GEMM %s\n", - backend_ctx->adreno_xmem_gemm_enabled ? - "enabled (temporary weight prepack)" : "requested but unsupported by this driver"); - } -#endif // GGML_OPENCL_USE_ADRENO_KERNELS +#endif // determine whether to use large buffer for Adreno backend_ctx->adreno_use_large_buffer = getenv("GGML_OPENCL_ADRENO_USE_LARGE_BUFFER") != nullptr && backend_ctx->gpu_family == GPU_FAMILY::ADRENO; - if (backend_ctx->adreno_use_large_buffer) { - if (!backend_ctx->adreno_has_large_buffer) { - GGML_LOG_INFO("ggml_opencl: Adreno large buffer requested but not supported by driver, will use regular buffer\n"); - backend_ctx->adreno_use_large_buffer = false; - } else { - GGML_LOG_INFO("ggml_opencl: Adreno large buffer enabled\n"); - } - } cl_int err; @@ -4010,12 +4031,6 @@ static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr; - const char * str_opfilter = getenv("GGML_OPENCL_OPFILTER"); - if (str_opfilter) { - backend_ctx->opfilter = new std::regex(str_opfilter, std::regex_constants::icase); - GGML_LOG_INFO("ggml_opencl: opfilter regex = \"%s\"\n", str_opfilter); - } - dev_ctx->backend_ctx = backend_ctx.release(); return dev_ctx->backend_ctx; } @@ -4825,7 +4840,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; // reject ops that match the opfilter regex - if (backend_ctx->opfilter && std::regex_match(std::string(ggml_op_desc(op)), *backend_ctx->opfilter)) { + if (dev_ctx->opfilter && std::regex_match(std::string(ggml_op_desc(op)), *dev_ctx->opfilter)) { return false; } @@ -7823,6 +7838,8 @@ static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, co /* .context = */ backend_ctx, }; + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + ggml_opencl_print_backend_info(dev_ctx); return backend; GGML_UNUSED(params); From 442be1789d750994b8afaad8533e16e46730606e Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Thu, 28 May 2026 14:05:54 -0700 Subject: [PATCH 729/831] hexagon: basic/generic op fusion support and RMS_NORM+MUL fusion (llama/23835) Updating infra to enable op fusion and using RMS_NORM+MUL as the use-case. --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 143 +++++++-------- ggml/src/ggml-hexagon/htp-opnode.h | 241 +++++++++++++++++++++++++ ggml/src/ggml-hexagon/htp/htp-ops.h | 1 + ggml/src/ggml-hexagon/htp/main.c | 1 + ggml/src/ggml-hexagon/htp/unary-ops.c | 194 +++++++++++++++++++- ggml/src/ggml-hexagon/op-desc.h | 153 ---------------- 6 files changed, 497 insertions(+), 236 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp-opnode.h delete mode 100644 ggml/src/ggml-hexagon/op-desc.h diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 3af7aff7028..48ded82e83c 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -39,7 +39,7 @@ #include "ggml-hexagon.h" #include "ggml-impl.h" #include "ggml-quants.h" -#include "op-desc.h" +#include "htp-opnode.h" #include "htp-ops.h" #include "htp_iface.h" #include "htp-drv.h" @@ -102,23 +102,23 @@ static const char * status_to_str(uint32_t status) { // ** debug helpers -static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const ggml_tensor * op, const uint32_t req_flags) { +static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const htp_opnode & node, const uint32_t req_flags) { if (!opt_verbose) return; - op_desc desc(op); + htp_opformat fmt(node); GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(), - ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags); + node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, req_flags); } static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) { if (!opt_verbose) return; - op_desc desc(op); + htp_opformat fmt(htp_opformat(htp_opnode{const_cast(op), {}, HTP_OP_INVALID})); GGML_LOG_DEBUG("ggml-hex: %s supports-op %s: %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), - ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no"); + ggml_op_desc(op), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, supp ? "yes" : "no"); } -static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op, +static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const htp_opnode & node, uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) { if (!opt_profile) return; @@ -129,15 +129,16 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_t pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]); } - op_desc desc(op); + htp_opformat fmt(node); GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(), - ggml_op_desc(op), desc.names, desc.dims, desc.types, desc.strides, op_usec, op_cycles, pmu_str); + node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pmu_str); } // ** backend sessions struct ggml_hexagon_opbatch; struct ggml_hexagon_opqueue; +struct htp_opnode; struct ggml_hexagon_session { std::string name; @@ -167,7 +168,7 @@ struct ggml_hexagon_session { void allocate(int dev_id) noexcept(false); void release() noexcept(true); - void enqueue_op(htp_op_code opcode, const ggml_tensor *op); + void enqueue_op(const htp_opnode & node); void flush(bool all = true); void flush_pending(bool all = false); @@ -1782,12 +1783,10 @@ static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interf /* .is_host = */ ggml_backend_hexagon_repack_buffer_type_is_host, }; -// Backend session implementation - struct ggml_hexagon_opbatch { ggml_hexagon_session* sess; - std::vector ops; // pointers to original ops + std::vector ops; // htp_opnode of ops std::vector h_bufs; // htp buffer descriptors std::vector h_tens; // htp tensor descriptors @@ -1919,7 +1918,7 @@ struct ggml_hexagon_opbatch { return ti; } - bool fit_op(const struct ggml_tensor *t) const { + bool fit_op(const htp_opnode & node) const { if (n_ops >= n_ops_max ) return false; // check how much extras we will need @@ -1939,10 +1938,10 @@ struct ggml_hexagon_opbatch { } }; - for (unsigned int i=0; i < HTP_OP_MAX_INPUTS && t->src[i]; i++) { - fit_tensor(t->src[i]); + for (const auto * src : node.get_inputs()) { + fit_tensor(src); } - fit_tensor(t); + fit_tensor(node.dst()); if ((extra_bufs + n_bufs) > n_bufs_max) return false; if ((extra_tens + n_tens) > n_tens_max) return false; @@ -1952,29 +1951,30 @@ struct ggml_hexagon_opbatch { } // assumes that fit_op() was called first and returned true - void add_op(htp_op_code opcode, const struct ggml_tensor * t) { + void add_op(const htp_opnode & node) { // Add new op unsigned int n = n_ops++; GGML_ASSERT(n_ops <= n_ops_max); - ops[n] = t; + ops[n] = node; htp_op_desc &o = h_ops[n]; - memcpy(&o.params, &t->op_params, sizeof(t->op_params)); - o.opcode = opcode; + memcpy(&o.params, &node.node->op_params, sizeof(node.node->op_params)); + o.opcode = node.opcode; o.flags = 0; if (!(opt_opstage & HTP_OPSTAGE_COMPUTE)) { o.flags |= HTP_OPFLAGS_SKIP_COMPUTE; } - ggml_hexagon_dump_op_exec(sess->c_name(), t, o.flags); + ggml_hexagon_dump_op_exec(sess->c_name(), node, o.flags); + auto inputs = node.get_inputs(); for (unsigned int i=0; i < HTP_OP_MAX_INPUTS; i++) { - o.src[i] = t->src[i] ? add_tensor(t->src[i]) : 0xffff; + o.src[i] = (i < inputs.size() && inputs[i]) ? add_tensor(inputs[i]) : 0xffff; } - o.dst = add_tensor(t); + o.dst = add_tensor(node.dst()); } }; @@ -1983,7 +1983,7 @@ struct ggml_hexagon_opqueue { ggml_hexagon_shared_buffer *shm_buf; size_t shm_blk_size; - using opvec = std::vector; + using opvec = std::vector; std::queue done; // completed batch ids std::vector op_cache; // per batch op cache @@ -2182,11 +2182,11 @@ void ggml_hexagon_session::flush_batch() { } } -void ggml_hexagon_session::enqueue_op(htp_op_code opcode, const ggml_tensor *op) { - if (!op_batch->fit_op(op)) { +void ggml_hexagon_session::enqueue_op(const htp_opnode & node) { + if (!op_batch->fit_op(node)) { flush_batch(); } - op_batch->add_op(opcode, op); + op_batch->add_op(node); } // Flush HTP response queue i.e wait for all outstanding requests to complete @@ -3179,10 +3179,43 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->c_name(), graph->n_nodes); + std::vector nodes; + nodes.reserve(graph->n_nodes); + + // Fusion for (int i = 0; i < graph->n_nodes; ++i) { ggml_tensor * n = graph->nodes[i]; - if (op_is_compute(n) && (opt_opstage & HTP_OPSTAGE_QUEUE)) { - sess->enqueue_op(op_remap_to_htp(n), n); + if (!op_is_compute(n)) { + continue; + } + + ggml_tensor * next_node = (i + 1 < graph->n_nodes) ? graph->nodes[i + 1] : nullptr; + + htp_opnode node = { + /*.node =*/ n, + /*.fused =*/ {}, + /*.opcode =*/ HTP_OP_INVALID + }; + + if (n->op == GGML_OP_RMS_NORM && next_node) { + if (next_node->op == GGML_OP_MUL && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + node.add_fused(next_node); + node.opcode = HTP_OP_RMS_NORM_MUL; + i++; // skip the fused MUL node + } + } + + if (node.opcode == HTP_OP_INVALID) { + node.opcode = op_remap_to_htp(n); + } + + nodes.push_back(std::move(node)); + } + + // Queue and execute + if (opt_opstage & HTP_OPSTAGE_QUEUE) { + for (const auto & node : nodes) { + sess->enqueue_op(node); } } @@ -3201,51 +3234,7 @@ static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) { sess->flush(); } -struct node_info { - ggml_tensor * node; - - std::vector fused; - - ggml_op op() const { - return node->op; - } - - const ggml_tensor * dst() const { - return fused.empty() ? node : fused.back(); - } - - const ggml_tensor * src0() const { - return node->src[0]; - } - - const ggml_tensor * src1() const { - return node->src[1]; - } - - bool is_empty() const { - return ggml_op_is_empty(node->op); - } - - void add_fused(ggml_tensor * t) { - fused.push_back(t); - } - - bool stackable() const { - switch (this->op()) { - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - return ggml_is_quantized(this->src0()->type); - default: - return false; - } - } - - bool same_input(const node_info& n) const { - return n.src1() == this->src1(); - } -}; - -static std::vector ggml_hexagon_graph_optimize_reorder(const std::vector & nodes) { +static std::vector ggml_hexagon_graph_optimize_reorder(const std::vector & nodes) { const int n = nodes.size(); std::vector res; @@ -3299,14 +3288,14 @@ static void ggml_backend_hexagon_graph_optimize(ggml_backend_t backend, ggml_cgr enum ggml_op ops[MAX_FUSE]; - std::vector nodes; + std::vector nodes; nodes.reserve(gf->n_nodes); // fuse nodes: // we don't want to make reorders that break fusing, so we first pack all fusable tensors // and perform the reorder over the fused nodes. after the reorder is done, we unfuse for (int i = 0; i < n; i++) { - node_info node = { + htp_opnode node = { /*.node =*/gf->nodes[i], /*.fused =*/{}, }; diff --git a/ggml/src/ggml-hexagon/htp-opnode.h b/ggml/src/ggml-hexagon/htp-opnode.h new file mode 100644 index 00000000000..14b232240b4 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp-opnode.h @@ -0,0 +1,241 @@ +#ifndef HTP_OPNODE_H +#define HTP_OPNODE_H + +#define GGML_COMMON_IMPL_CPP +#include "ggml-backend-impl.h" +#include "ggml-common.h" + +#include +#include +#include +#include "htp-ops.h" + +struct htp_opnode { + ggml_tensor * node = nullptr; + + std::vector fused; + + htp_op_code opcode = HTP_OP_INVALID; + + ggml_op op() const { + return node->op; + } + + const ggml_tensor * dst() const { + return fused.empty() ? node : fused.back(); + } + + const ggml_tensor * src0() const { + return node->src[0]; + } + + const ggml_tensor * src1() const { + return node->src[1]; + } + + bool is_empty() const { + return ggml_op_is_empty(node->op); + } + + void add_fused(ggml_tensor * t) { + fused.push_back(t); + } + + bool stackable() const { + switch (this->op()) { + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + return ggml_is_quantized(this->src0()->type); + default: + return false; + } + } + + bool same_input(const htp_opnode& n) const { + return n.src1() == this->src1(); + } + + std::vector get_inputs() const { + std::vector inputs; + std::vector outputs; + outputs.push_back(node); + for (const auto * f : fused) { + outputs.push_back(f); + } + + auto contains = [&](const std::vector & vec, const ggml_tensor * t) { + for (const auto * x : vec) { + if (x == t) return true; + } + return false; + }; + + auto add_input = [&](const ggml_tensor * t) { + if (t && !contains(outputs, t) && !contains(inputs, t)) { + inputs.push_back(t); + } + }; + + for (int i = 0; i < GGML_MAX_SRC && node->src[i]; i++) { + add_input(node->src[i]); + } + for (const auto * f : fused) { + for (int i = 0; i < GGML_MAX_SRC && f->src[i]; i++) { + add_input(f->src[i]); + } + } + return inputs; + } + + std::string op_name() const { + if (fused.empty()) { + return ggml_op_desc(node); + } + std::string name = ggml_op_desc(node); + for (const auto * f : fused) { + name += "+"; + name += ggml_op_desc(f); + } + return name; + } +}; + +struct htp_opformat { + char strides[64 * GGML_MAX_SRC]; + char dims[64 * GGML_MAX_SRC]; + char types[16 * GGML_MAX_SRC]; + char buffs[64 * GGML_MAX_SRC]; + char names[64 * GGML_MAX_SRC]; + + int format_tensor_dims(char * str, const struct ggml_tensor * t) { + if (t->ne[2] == 1 && t->ne[3] == 1) { + return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); + } else { + return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); + } + } + + void format_op_dims(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += format_tensor_dims(p, inputs[0]); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += format_tensor_dims(p, inputs[i]); + } + + p += sprintf(p, " -> "); + } + + char self[64]; + format_tensor_dims(self, node.dst()); + p += sprintf(p, "%s", self); + } + + int format_tensor_strides(char * str, const struct ggml_tensor * t) { + const char * c = ggml_is_contiguous(t) ? "" : "!"; + + if (t->ne[2] == 1 && t->ne[3] == 1) { + return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); + } else { + return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); + } + } + + void format_op_strides(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += format_tensor_strides(p, inputs[0]); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += format_tensor_strides(p, inputs[i]); + } + + p += sprintf(p, " -> "); + } + + char self[64]; + format_tensor_strides(self, node.dst()); + p += sprintf(p, "%s", self); + } + + void format_op_types(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += sprintf(p, "%s", ggml_type_name(inputs[0]->type)); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", ggml_type_name(inputs[i]->type)); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", ggml_type_name(node.dst()->type)); + } + + const char * tensor_buff_name(const struct ggml_tensor * t) { + if (t->buffer) { + return ggml_backend_buffer_name(t->buffer); + } + return "NONE"; + } + + void format_op_buffs(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += sprintf(p, "%s", tensor_buff_name(inputs[0])); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", tensor_buff_name(inputs[i])); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", tensor_buff_name(node.dst())); + } + + void format_op_names(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += sprintf(p, "%s", inputs[0]->name); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", inputs[i]->name); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", node.dst()->name); + } + + void format(const htp_opnode & node) { + format_op_dims(dims, node); + format_op_strides(strides, node); + format_op_types(types, node); + format_op_buffs(buffs, node); + format_op_names(names, node); + } + + htp_opformat() {} + htp_opformat(const htp_opnode & node) { format(node); } +}; + +#endif // HTP_OPNODE_H diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index aadc77235ba..fa85bf4ca0c 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -58,6 +58,7 @@ enum htp_op_code { HTP_OP_MUL_MAT, HTP_OP_MUL_MAT_ID, HTP_OP_RMS_NORM, + HTP_OP_RMS_NORM_MUL, HTP_OP_UNARY_SILU, HTP_OP_UNARY_GELU, HTP_OP_UNARY_SIGMOID, diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 7dd90ac7d7f..623008be4e2 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -537,6 +537,7 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_NORM: case HTP_OP_RMS_NORM: + case HTP_OP_RMS_NORM_MUL: case HTP_OP_SCALE: case HTP_OP_SQR: case HTP_OP_SQRT: diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 7d0431d8ba8..770a6673211 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -23,21 +23,26 @@ struct htp_unary_context { // Precomputed values const uint8_t * data_src0; + const uint8_t * data_src1; // weight/scale tensor for RMS_NORM_MUL uint8_t * data_dst; size_t src0_data_row_size; // actual data bytes per row + size_t src1_data_row_size; size_t dst_data_row_size; // actual data bytes per row size_t src0_row_size_aligned; + size_t src1_row_size_aligned; size_t dst_row_size_aligned; size_t src0_spad_half_size; + size_t src1_spad_half_size; size_t dst_spad_half_size; uint32_t block; uint32_t src0_nrows; uint32_t src0_nrows_per_thread; uint32_t nc; + bool broadcast_weight; }; // Convert flat row index to DDR byte offset using the tensor's actual strides. @@ -158,6 +163,71 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } +static void hvx_fast_rms_norm_mul_f32(const uint8_t * restrict src, + const uint8_t * restrict weight, + uint8_t * restrict dst, + const int num_elems, + float epsilon) { + const HVX_Vector * restrict v_src = (const HVX_Vector *) src; + const HVX_Vector * restrict v_weight = (const HVX_Vector *) weight; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares for full vectors + HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } + + // Reduce HVX sum + sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); + + HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); + HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); + HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v); + HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v); + + // Scale and multiply + HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v)); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + HVX_Vector v3 = Q6_Vsf_equals_Vqf32(v2); + HVX_Vector result = Q6_Vqf32_vmpy_VsfVsf(v3, v_weight[i]); + v_dst[i] = Q6_Vsf_equals_Vqf32(result); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + HVX_Vector v3 = Q6_Vsf_equals_Vqf32(v2); + HVX_Vector result = Q6_Vqf32_vmpy_VsfVsf(v3, v_weight[nvec]); + HVX_Vector res_v = Q6_Vsf_equals_Vqf32(result); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, res_v); + } +} + static void hvx_fast_norm_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict pad, @@ -269,6 +339,27 @@ static void rms_norm_f32(const float * restrict src, } } +static void rms_norm_mul_f32(const float * restrict src, + const float * restrict weight, + float * restrict dst, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + const size_t weight_row_size, + int32_t * op_params, + bool broadcast_weight) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + const uint8_t * restrict w_local = (const uint8_t *)weight + (broadcast_weight ? 0 : ir * weight_row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_fast_rms_norm_mul_f32(src_local, w_local, dst_local, row_elems, epsilon); + } +} + static void norm_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -598,12 +689,15 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * t1 = HAP_perf_get_qtimer_count(); const uint8_t * restrict data_src = uctx->data_src0; + const uint8_t * restrict data_src1 = uctx->data_src1; uint8_t * restrict data_dst = uctx->data_dst; uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); size_t src0_spad_half_size = uctx->src0_spad_half_size; + size_t src1_spad_half_size = uctx->src1_spad_half_size; size_t dst_spad_half_size = uctx->dst_spad_half_size; // Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride @@ -624,6 +718,12 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * dma_queue * dma_queue = octx->ctx->dma[ith]; + // If weight is broadcasted, load it once per thread at the beginning of execution + if (htp_op == HTP_OP_RMS_NORM_MUL && uctx->broadcast_weight) { + dma_queue_push(dma_queue, dma_make_ptr(src1_spad_data, data_src1), uctx->src1_row_size_aligned, 0, uctx->src1_data_row_size, 1); + dma_queue_flush(dma_queue); + } + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) { const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); @@ -636,6 +736,14 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * dma_queue_push(dma_queue, dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off), src0_row_size_aligned, nb01, src0_data_row_size, block_size); + + if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { + const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03); + dma_queue_push(dma_queue, + dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + src1_off), + uctx->src1_row_size_aligned, nb01, uctx->src1_data_row_size, block_size); + } + ir += block_size; } @@ -644,6 +752,10 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * float * dst_spad = (float *) dma_queue_pop(dma_queue).src; float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + float * src1_spad = NULL; + if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { + src1_spad = (float *) dma_queue_pop(dma_queue).dst; + } // Process block in VTCM switch (htp_op) { @@ -653,6 +765,12 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * case HTP_OP_RMS_NORM: rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; + case HTP_OP_RMS_NORM_MUL: + { + const float * w_ptr = uctx->broadcast_weight ? (const float *) src1_spad_data : src1_spad; + rms_norm_mul_f32(src0_spad, w_ptr, dst_spad, block_size, ne0, src0_row_size_aligned, uctx->src1_row_size_aligned, op_params, uctx->broadcast_weight); + } + break; case HTP_OP_SCALE: scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); break; @@ -700,9 +818,16 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * if (pref_ir < src0_end_row) { const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03); - dma_queue_push(dma_queue, - dma_make_ptr(src0_spad, data_src + src0_pref_off), - src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size); + dma_queue_push(dma_queue, + dma_make_ptr(src0_spad, data_src + src0_pref_off), + src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size); + + if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { + const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03); + dma_queue_push(dma_queue, + dma_make_ptr(src1_spad, data_src1 + src1_pref_off), + uctx->src1_row_size_aligned, nb01, uctx->src1_data_row_size, pref_block_size); + } } } ir += block_size; @@ -732,6 +857,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { case HTP_OP_RMS_NORM: op_type = "rmsnorm-f32"; break; + case HTP_OP_RMS_NORM_MUL: + op_type = "rmsnorm-mul-f32"; + break; case HTP_OP_SCALE: op_type = "scale-f32"; break; @@ -777,12 +905,44 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN); const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN); + size_t src1_data_row_size = 0; + size_t src1_row_size_aligned = 0; + bool broadcast_weight = false; + const struct htp_tensor * src1 = NULL; + + if (octx->op == HTP_OP_RMS_NORM_MUL) { + src1 = octx->src[1]; + src1_data_row_size = src1->ne[0] * sizeof(float); + src1_row_size_aligned = hex_round_up(src1_data_row_size, VLEN); + broadcast_weight = (src1->ne[1] * src1->ne[2] * src1->ne[3] == 1); + } + // VTCM scratchpads for all tensors // N rows per thread, padded to HVX vector size // Double buffering requires 2x size per buffer - size_t spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned); - size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row); + size_t spad_size_per_row = 0; + size_t vtcm_row_per_thread = 0; + + if (octx->op == HTP_OP_RMS_NORM_MUL) { + if (broadcast_weight) { + size_t available_vtcm = octx->ctx->vtcm_size; + size_t src1_spad_total = n_threads * src1_row_size_aligned; + if (available_vtcm > src1_spad_total) { + available_vtcm -= src1_spad_total; + } else { + available_vtcm = 0; + } + spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned); + vtcm_row_per_thread = available_vtcm / (n_threads * spad_size_per_row); + } else { + spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned + src1_row_size_aligned); + vtcm_row_per_thread = (octx->ctx->vtcm_size) / (n_threads * spad_size_per_row); + } + } else { + spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned); + vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row); + } // Make sure the reserved vtcm size is sufficient if (vtcm_row_per_thread == 0) { @@ -797,8 +957,25 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + if (octx->op == HTP_OP_RMS_NORM_MUL) { + if (broadcast_weight) { + octx->src1_spad.size_per_thread = src1_row_size_aligned; + } else { + octx->src1_spad.size_per_thread = src1_row_size_aligned * vtcm_row_per_thread * 2; + } + octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; + } else { + octx->src1_spad.size = 0; + octx->src1_spad.size_per_thread = 0; + } + octx->src0_spad.data = octx->ctx->vtcm_base; - octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + if (octx->op == HTP_OP_RMS_NORM_MUL) { + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + } else { + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + } FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], @@ -811,19 +988,24 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { .src0_nrows = src0_nrows, .data_src0 = (const uint8_t *)src0->data, + .data_src1 = (octx->op == HTP_OP_RMS_NORM_MUL) ? (const uint8_t *)src1->data : NULL, .data_dst = (uint8_t *)dst->data, .src0_data_row_size = src0_data_row_size, + .src1_data_row_size = src1_data_row_size, .dst_data_row_size = dst_data_row_size, .src0_row_size_aligned = src0_row_size_aligned, + .src1_row_size_aligned = src1_row_size_aligned, .dst_row_size_aligned = dst_row_size_aligned, .src0_spad_half_size = octx->src0_spad.size_per_thread / 2, + .src1_spad_half_size = (octx->op == HTP_OP_RMS_NORM_MUL) ? (octx->src1_spad.size_per_thread / (broadcast_weight ? 1 : 2)) : 0, .dst_spad_half_size = octx->dst_spad.size_per_thread / 2, .block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned, .nc = src0->ne[0], + .broadcast_weight = broadcast_weight, }; worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads); diff --git a/ggml/src/ggml-hexagon/op-desc.h b/ggml/src/ggml-hexagon/op-desc.h deleted file mode 100644 index a1e8ddd8b97..00000000000 --- a/ggml/src/ggml-hexagon/op-desc.h +++ /dev/null @@ -1,153 +0,0 @@ -#ifndef OP_DESC_H -#define OP_DESC_H - -#define GGML_COMMON_IMPL_CPP -#include "ggml-backend-impl.h" -#include "ggml-common.h" - -#include -#include - -struct op_desc { - char strides[64 * GGML_MAX_SRC]; - char dims[64 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - int format_tensor_dims(char * str, const struct ggml_tensor * t) { - if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); - } else { - return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); - } - } - - void format_op_dims(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += format_tensor_dims(p, t->src[0]); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += format_tensor_dims(p, t->src[i]); - } - - p += sprintf(p, " -> "); - } - - // format self dims separately for better visual alignment - char self[64]; - format_tensor_dims(self, t); - - p += sprintf(p, "%s", self); - } - - int format_tensor_strides(char * str, const struct ggml_tensor * t) { - const char * c = ggml_is_contiguous(t) ? "" : "!"; - - if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); - } else { - return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); - } - } - - void format_op_strides(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += format_tensor_strides(p, t->src[0]); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += format_tensor_strides(p, t->src[i]); - } - - p += sprintf(p, " -> "); - } - - // format self dims separately for better visual alignment - char self[64]; - format_tensor_strides(self, t); - - p += sprintf(p, "%s", self); - } - - void format_op_types(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", ggml_type_name(t->src[0]->type)); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", ggml_type_name(t->src[i]->type)); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", ggml_type_name(t->type)); - } - - const char * tensor_buff_name(const struct ggml_tensor * t) { - if (t->buffer) { - return ggml_backend_buffer_name(t->buffer); - } - return "NONE"; - } - - void format_op_buffs(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", tensor_buff_name(t->src[0])); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", tensor_buff_name(t->src[i])); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", tensor_buff_name(t)); - } - - void format_op_names(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", t->src[0]->name); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", t->src[i]->name); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", t->name); - } - - void format(const ggml_tensor * op) { - format_op_dims(dims, op); - format_op_strides(strides, op); - format_op_types(types, op); - format_op_buffs(buffs, op); - format_op_names(names, op); - } - - op_desc() {} - op_desc(const ggml_tensor * op) { format(op); } -}; - -#endif // OP_DESC_H From f1b687da28a6e28beb2a2e7ed2d74f554eb279be Mon Sep 17 00:00:00 2001 From: Matt Corallo <649246+TheBlueMatt@users.noreply.github.com> Date: Fri, 29 May 2026 03:30:24 +0000 Subject: [PATCH 730/831] meta : Add missing `buffer` set in allreduce fallback !COMPUTE clear (llama/23480) Without this at least the vulkan backend will skip the `* 0` for !COMPUTE tensors, causing corrupt output. --- ggml/src/ggml-backend-meta.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index d0d64523b4a..48b2027fac3 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -2076,6 +2076,7 @@ static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, node_zero->src[0] = node; ggml_set_op_params_f32(node_zero, 0, 0.0f); node_zero->data = node->data; + node_zero->buffer = node->buffer; node_zero->flags |= GGML_TENSOR_FLAG_COMPUTE; step_cgraphs[j] = get_cgraph_aux(); From e90501e179632071cd7bba5cf5f05ec9991e64ff Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger <47689530+aendk@users.noreply.github.com> Date: Fri, 29 May 2026 06:46:10 +0200 Subject: [PATCH 731/831] cuda : disables launch_fattn PDL enrollment due to compiler bug (llama/23825) --- ggml/src/ggml-cuda/fattn-common.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index debcb6e5447..d650b5fbd0f 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -1153,8 +1153,8 @@ void launch_fattn( GGML_ASSERT(block_dim.x % warp_size == 0); - const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream); - ggml_cuda_kernel_launch(fattn_kernel, launch_params, + // disabled PDL enrollment for now due to a compiler bug. + fattn_kernel<<>>( (const char *) Q->data, K_data, V_data, From cc65eb1816f780fd8478c58894f45b4c160e5ffc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 29 May 2026 09:43:15 +0300 Subject: [PATCH 732/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index a4f87b2b9ae..6aed494381c 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -e705c5fed490514458bdd2eaddc43bd098fcce9b +5fbba2f28a17545214650298fd729563475004ca From 5828fba79f0c00f4cd7c7c205824b72664ac79d2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 29 May 2026 09:44:28 +0300 Subject: [PATCH 733/831] talk-llama : sync llama.cpp --- examples/talk-llama/llama-arch.cpp | 6 +- examples/talk-llama/llama-arch.h | 1 + examples/talk-llama/llama-chat.cpp | 20 +++- examples/talk-llama/llama-chat.h | 1 + examples/talk-llama/llama-model.cpp | 3 + examples/talk-llama/llama-model.h | 2 +- examples/talk-llama/llama-vocab.cpp | 14 ++- examples/talk-llama/llama-vocab.h | 1 + examples/talk-llama/models/mistral3.cpp | 12 +- examples/talk-llama/models/models.h | 13 +++ examples/talk-llama/models/talkie.cpp | 149 ++++++++++++++++++++++++ 11 files changed, 213 insertions(+), 9 deletions(-) create mode 100644 examples/talk-llama/models/talkie.cpp diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index c9eead18aa3..e95ba6daac1 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -133,6 +133,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_MAINCODER, "maincoder" }, { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, + { LLM_ARCH_TALKIE, "talkie" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -767,8 +768,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // Nemotron 3 Super - {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, - {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // latent projections feed ggml_mul_mat, the buft probe must use MUL_MAT to keep them on GPU + {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 89cf16cc37c..7c1dcc4d6c2 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -137,6 +137,7 @@ enum llm_arch { LLM_ARCH_LLAMA_EMBED, LLM_ARCH_MAINCODER, LLM_ARCH_KIMI_LINEAR, + LLM_ARCH_TALKIE, LLM_ARCH_UNKNOWN, }; diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index f10397747b0..6d822ec62d6 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -62,6 +62,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, { "granite", LLM_CHAT_TEMPLATE_GRANITE_3_X }, { "granite-4.0", LLM_CHAT_TEMPLATE_GRANITE_4_0 }, + { "granite-4.1", LLM_CHAT_TEMPLATE_GRANITE_4_1 }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "yandex", LLM_CHAT_TEMPLATE_YANDEX }, @@ -194,7 +195,10 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_RWKV_WORLD; } else if (tmpl_contains("<|start_of_role|>")) { if (tmpl_contains("") || tmpl_contains("")) { - return LLM_CHAT_TEMPLATE_GRANITE_4_0; + if (tmpl_contains("g4_default_system_message")) { + return LLM_CHAT_TEMPLATE_GRANITE_4_0; + } + return LLM_CHAT_TEMPLATE_GRANITE_4_1; } return LLM_CHAT_TEMPLATE_GRANITE_3_X; } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) { @@ -651,6 +655,20 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|start_of_role|>assistant<|end_of_role|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_4_1) { + // IBM Granite 4.1 template + for (const auto & message : chat) { + std::string role(message->role); + if (role == "assistant_tool_call") { + ss << "<|start_of_role|>assistant<|end_of_role|><|tool_call|>"; + } else { + ss << "<|start_of_role|>" << role << "<|end_of_role|>"; + } + ss << message->content << "<|end_of_text|>\n"; + } + if (add_ass) { + ss << "<|start_of_role|>assistant<|end_of_role|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { // GigaChat template bool has_system = !chat.empty() && std::string(chat[0]->role) == "system"; diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index ea6540c0be7..dc37f919a96 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -41,6 +41,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_RWKV_WORLD, LLM_CHAT_TEMPLATE_GRANITE_3_X, LLM_CHAT_TEMPLATE_GRANITE_4_0, + LLM_CHAT_TEMPLATE_GRANITE_4_1, LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_YANDEX, diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 0d21b2a53c5..0c3e03a61dc 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -44,6 +44,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_llama_embed(params); case LLM_ARCH_MAINCODER: return new llama_model_maincoder(params); + case LLM_ARCH_TALKIE: + return new llama_model_talkie(params); case LLM_ARCH_DECI: return new llama_model_deci(params); case LLM_ARCH_BAICHUAN: @@ -2353,6 +2355,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_MIMO2: case LLM_ARCH_STEP35: + case LLM_ARCH_TALKIE: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 398a0aa725c..b797b8966ac 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -488,7 +488,7 @@ struct llama_layer { struct ggml_tensor * indexer_attn_k = nullptr; struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias - // gemma4 layer output scale + // gemma4 layer output scale, reused for talkie embedding skip scale struct ggml_tensor * out_scale = nullptr; struct llama_layer_posnet posnet; diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index a5cf148b268..473becade82 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -511,6 +511,14 @@ struct llm_tokenizer_bpe : llm_tokenizer { }; byte_encode = false; break; + case LLAMA_VOCAB_PRE_TYPE_MINICPM5: + regex_exprs = { + // original regex from tokenizer.json (openbmb/MiniCPM5-1B) + "\\p{N}{1,3}", + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}+| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}+| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -2039,6 +2047,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if (tokenizer_pre == "default") { pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if (tokenizer_pre == "minicpm5") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MINICPM5; + ignore_merges = true; } else if ( tokenizer_pre == "llama3" || tokenizer_pre == "llama-v3" || @@ -2196,7 +2207,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "gpt-4o" || tokenizer_pre == "llama4" || - tokenizer_pre == "kanana2") { + tokenizer_pre == "kanana2" || + tokenizer_pre == "talkie") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; } else if ( diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 8b040b912e2..8ab77594284 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -60,6 +60,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 51, + LLAMA_VOCAB_PRE_TYPE_MINICPM5 = 52, }; struct LLM_KV; diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 4e6ebef82cb..1ac5a95ccdc 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -177,9 +177,9 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa cb(cur, "ffn_norm", il); cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); @@ -200,7 +200,11 @@ llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_pa LLM_FFN_SILU, true, hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); + il, + nullptr, nullptr, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); cb(cur, "ffn_moe_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index 7e551eb965b..db228865d5d 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -186,6 +186,19 @@ struct llama_model_maincoder : public llama_model_base { }; +struct llama_model_talkie : public llama_model_base { + llama_model_talkie(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr build_arch_graph(const llm_graph_params & params) const override; +}; + + struct llama_model_deci : public llama_model_base { llama_model_deci(const struct llama_model_params & params) : llama_model_base(params) {} void load_arch_hparams(llama_model_loader & ml) override; diff --git a/examples/talk-llama/models/talkie.cpp b/examples/talk-llama/models/talkie.cpp new file mode 100644 index 00000000000..1258eeb19b6 --- /dev/null +++ b/examples/talk-llama/models/talkie.cpp @@ -0,0 +1,149 @@ +#include "models.h" + +void llama_model_talkie::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + + switch (hparams.n_layer) { + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_talkie::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // no k gain + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {1, n_head}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1}, 0); + } +} + +std::unique_ptr llama_model_talkie::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique(*this, params); +} + +llama_model_talkie::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + inpL = build_norm(inpL, nullptr, nullptr, LLM_NORM_RMS, -1); + cb(inpL, "inp_norm", -1); + + ggml_tensor * embd_skip = inpL; + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + ggml_tensor * inp_skip = embd_skip; + + cur = build_norm(inpL, nullptr, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + // reference applies qknorm after rope + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_norm", il); + + Kcur = build_norm(Kcur, nullptr, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_norm", il); + + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + inp_skip = ggml_get_rows(ctx0, inp_skip, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, nullptr, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, nullptr, + model.layers[il].ffn_gate, nullptr, nullptr, + model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + ggml_tensor * skip = ggml_mul(ctx0, inp_skip, model.layers[il].out_scale); + cb(skip, "embd_skip", il); + + cur = ggml_add(ctx0, cur, skip); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, nullptr, nullptr, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + cur = ggml_scale(ctx0, cur, hparams.f_logit_scale); + cb(cur, "result_output", -1); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} From 92fc3f2a58bb6c518aef3bc8ddbe4c84e75a79b3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 29 May 2026 09:46:12 +0300 Subject: [PATCH 734/831] ggml : bump version to 0.13.1 (ggml/1523) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index f542f18b6d4..dc8899b46ef 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 13) -set(GGML_VERSION_PATCH 0) +set(GGML_VERSION_PATCH 1) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From f24588a272ae8e23280d9c220536437164e6ed28 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 29 May 2026 09:46:42 +0300 Subject: [PATCH 735/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 6aed494381c..538ef80bc7a 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -5fbba2f28a17545214650298fd729563475004ca +1e33fed33e87c43aa4c4078e2a9c239d4c1f1bd3 From f39cc7128295ff5c67bbedb73161bed549f96e96 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 31 May 2026 15:44:07 +0300 Subject: [PATCH 736/831] common : re-implement `ffmpeg-transcode.cpp` + clarify ffmpeg usage (#3846) * examples : remove ffmpeg-transcode.cpp * examples : implement ffmpeg-transcode.cpp Assisted-by: llama.cpp:local pi * common : switch from WHISPER_FFMPEG -> WHISPER_COMMON_FFMPEG --- CMakeLists.txt | 3 +- README.md | 7 +- examples/CMakeLists.txt | 4 +- examples/common-whisper.cpp | 84 +++--- examples/ffmpeg-transcode.cpp | 553 +++++++++++++--------------------- tests/CMakeLists.txt | 2 +- 6 files changed, 271 insertions(+), 382 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2200673d0a3..35c8674725f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,7 +85,7 @@ option(WHISPER_CURL "whisper: use libcurl to download model from an URL" OFF) option(WHISPER_SDL2 "whisper: support for libSDL2" OFF) if (CMAKE_SYSTEM_NAME MATCHES "Linux") - option(WHISPER_FFMPEG "whisper: support building and linking with ffmpeg libs (avcodec, swresample, ...)" OFF) + option(WHISPER_COMMON_FFMPEG "whisper: examples link with ffmpeg libs in order to decode more audio formats" OFF) endif() option(WHISPER_COREML "whisper: enable Core ML framework" OFF) @@ -121,6 +121,7 @@ whisper_option_depr(WARNING WHISPER_RPC GGML_RPC) whisper_option_depr(WARNING WHISPER_SYCL GGML_SYCL) whisper_option_depr(WARNING WHISPER_SYCL_F16 GGML_SYCL_F16) whisper_option_depr(WARNING WHISPER_CCACHE GGML_CCACHE) +whisper_option_depr(WARNING WHISPER_FFMPEG WHISPER_COMMON_FFMPEG) if (GGML_CUDA AND NOT MSVC) #GGML_CUDA enabled, add the necessary compile options -Wno-deprecated-gpu-targets diff --git a/README.md b/README.md index 050a35be21c..d1680e99bfc 100644 --- a/README.md +++ b/README.md @@ -425,9 +425,10 @@ cmake -B build -DGGML_MUSA=1 -DMUSA_ARCHITECTURES="21" cmake --build build -j --config Release ``` -## FFmpeg support (Linux only) +## FFmpeg support (examples only) -If you want to support more audio formats (such as Opus and AAC), you can turn on the `WHISPER_FFMPEG` build flag to enable FFmpeg integration. +By default, the examples in this repo use the [miniaudio](https://github.com/mackron/miniaudio) library to decode audio files. +Some of the examples also can use FFmpeg for decoding and broader format support. To enable that, build with `WHISPER_COMMON_FFMPEG`. First, you need to install required libraries: @@ -442,7 +443,7 @@ sudo dnf install libavcodec-free-devel libavformat-free-devel libavutil-free-dev Then you can build the project as follows: ```bash -cmake -B build -D WHISPER_FFMPEG=yes +cmake -B build -D WHISPER_COMMON_FFMPEG=yes cmake --build build ``` diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index b202ca00b77..0bb54cec489 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -20,7 +20,7 @@ set(TARGET common) unset(COMMON_EXTRA_LIBS) -if (WHISPER_FFMPEG) +if (WHISPER_COMMON_FFMPEG) # As of cmake 3.27, there is no official cmake support for FindFFmpeg. # Consequnelty we added a FindFFmpeg.cmake script the cmake subfolder: # whisper.cpp does not need the full ffmpeg libs, just AVFORMAT AVCODEC AVUTIL SWRESAMPLE @@ -39,7 +39,7 @@ if (WHISPER_FFMPEG) message(STATUS "Found avformat ${AVFORMAT_VERSION}") include_directories(${FFMPEG_INCLUDE_DIRS}) - add_compile_definitions(WHISPER_FFMPEG) + add_compile_definitions(WHISPER_COMMON_FFMPEG) list(APPEND COMMON_EXTRA_LIBS ${FFMPEG_LIBRARIES}) diff --git a/examples/common-whisper.cpp b/examples/common-whisper.cpp index d29166b50d8..8cdd2320c17 100644 --- a/examples/common-whisper.cpp +++ b/examples/common-whisper.cpp @@ -34,8 +34,8 @@ #include #include -#ifdef WHISPER_FFMPEG -// as implemented in ffmpeg_trancode.cpp only embedded in common lib if whisper built with ffmpeg support +#ifdef WHISPER_COMMON_FFMPEG +// as implemented in ffmpeg-trancode.cpp only embedded in common lib if whisper built with ffmpeg support extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data); #endif @@ -75,7 +75,7 @@ static bool read_audio_from_decoder(ma_decoder & decoder, std::vector & p return true; } -bool read_audio_data(const std::string & fname, std::vector& pcmf32, std::vector>& pcmf32s, bool stereo) { +bool read_audio_data(const std::string & fname, std::vector & pcmf32, std::vector> & pcmf32s, bool stereo) { std::vector audio_data; // used for pipe input from stdin or ffmpeg decoding output ma_result result; @@ -96,53 +96,67 @@ bool read_audio_data(const std::string & fname, std::vector& pcmf32, std: decoder_config = ma_decoder_config_init(ma_format_f32, stereo ? 2 : 1, WHISPER_SAMPLE_RATE); if (fname == "-") { - #ifdef _WIN32 - _setmode(_fileno(stdin), _O_BINARY); - #endif - - uint8_t buf[1024]; - while (true) - { - const size_t n = fread(buf, 1, sizeof(buf), stdin); - if (n == 0) { - break; - } - audio_data.insert(audio_data.end(), buf, buf + n); - } - - result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); +#ifdef _WIN32 + _setmode(_fileno(stdin), _O_BINARY); +#endif + + uint8_t buf[1024]; + while (true) + { + const size_t n = fread(buf, 1, sizeof(buf), stdin); + if (n == 0) { + break; + } + audio_data.insert(audio_data.end(), buf, buf + n); + } + + result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); if (result != MA_SUCCESS) { - fprintf(stderr, "Error: failed to open audio data from stdin (%s)\n", ma_result_description(result)); - return false; - } + fprintf(stderr, "%s: failed to open audio data from stdin (%s)\n", __func__, ma_result_description(result)); + return false; + } decoder.initialized = true; - fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, audio_data.size()); - } - else { - result = ma_decoder_init_file(fname.c_str(), &decoder_config, &decoder); - if (result == MA_SUCCESS) { - decoder.initialized = true; + fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, audio_data.size()); + } else { + fprintf(stderr, "%s: reading audio data from '%s' ...\n", __func__, fname.c_str()); + + // first try miniaudio. if it fails (or skipped) - try ffmpeg + { + const char * skip = getenv("WHISPER_COMMON_MINIAUDIO_SKIP"); + if (!skip || strlen(skip) == 0 || strcmp(skip, "0") == 0) { + fprintf(stderr, "%s: trying to decode with miniaudio\n", __func__); + + result = ma_decoder_init_file(fname.c_str(), &decoder_config, &decoder); + if (result == MA_SUCCESS) { + decoder.initialized = true; + } + } else { + fprintf(stderr, "%s: skipping miniaudio\n", __func__); + } } -#if defined(WHISPER_FFMPEG) + +#if defined(WHISPER_COMMON_FFMPEG) if (!decoder.initialized) { + fprintf(stderr, "%s: trying to decode with ffmpeg\n", __func__); + if (ffmpeg_decode_audio(fname, audio_data) != 0) { - fprintf(stderr, "error: failed to ffmpeg decode '%s'\n", fname.c_str()); + fprintf(stderr, "%s: failed to ffmpeg decode\n", __func__); return false; } result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); if (result != MA_SUCCESS) { - fprintf(stderr, "error: failed to read audio data as wav (%s)\n", ma_result_description(result)); + fprintf(stderr, "%s: failed to read audio data as wav (%s)\n", __func__, ma_result_description(result)); return false; } decoder.initialized = true; } -#else - if (!decoder.initialized) { - fprintf(stderr, "error: failed to read audio data from (%s)\n", fname.c_str()); - return false; - } #endif + + if (!decoder.initialized) { + fprintf(stderr, "%s: failed to read audio data\n", __func__); + return false; + } } return read_audio_from_decoder(decoder.decoder, pcmf32, pcmf32s, stereo); diff --git a/examples/ffmpeg-transcode.cpp b/examples/ffmpeg-transcode.cpp index 1fae58a4ffa..dc57fe74596 100644 --- a/examples/ffmpeg-transcode.cpp +++ b/examples/ffmpeg-transcode.cpp @@ -1,368 +1,241 @@ -/* SPDX-License-Identifier: GPL-2.0 */ +#ifdef WHISPER_COMMON_FFMPEG -/* - * transcode.c - convert audio file to WAVE - * - * Copyright (C) 2019 Andrew Clayton - * Copyright (C) 2024 William Tambellini - */ +#include "whisper.h" -// Just for conveninent C++ API -#include #include - -// C -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include extern "C" { -#include -#include #include +#include #include } -typedef uint64_t u64; -typedef int64_t s64; -typedef uint32_t u32; -typedef int32_t s32; -typedef uint16_t u16; -typedef int16_t s16; -typedef uint8_t u8; -typedef int8_t s8; - -#define WAVE_SAMPLE_RATE 16000 -#define AVIO_CTX_BUF_SZ 4096 - -static const char* ffmpegLog = getenv("FFMPEG_LOG"); -// Todo: add __FILE__ __LINE__ -#define LOG(...) \ - do { if (ffmpegLog) fprintf(stderr, __VA_ARGS__); } while(0) // C99 - -/* - * WAVE file header based on definition from - * https://gist.github.com/Jon-Schneider/8b7c53d27a7a13346a643dac9c19d34f - * - * We must ensure this structure doesn't have any holes or - * padding so we can just map it straight to the WAVE data. - */ -struct wave_hdr { - /* RIFF Header: "RIFF" */ - char riff_header[4]; - /* size of audio data + sizeof(struct wave_hdr) - 8 */ - int wav_size; - /* "WAVE" */ - char wav_header[4]; - - /* Format Header */ - /* "fmt " (includes trailing space) */ - char fmt_header[4]; - /* Should be 16 for PCM */ - int fmt_chunk_size; - /* Should be 1 for PCM. 3 for IEEE Float */ - s16 audio_format; - s16 num_channels; - int sample_rate; - /* - * Number of bytes per second - * sample_rate * num_channels * bit_depth/8 - */ - int byte_rate; - /* num_channels * bytes per sample */ - s16 sample_alignment; - /* bits per sample */ - s16 bit_depth; - - /* Data Header */ - /* "data" */ - char data_header[4]; - /* - * size of audio - * number of samples * num_channels * bit_depth/8 - */ - int data_bytes; -} __attribute__((__packed__)); - -struct audio_buffer { - u8 *ptr; - int size; /* size left in the buffer */ -}; - -static void set_wave_hdr(wave_hdr& wh, size_t size) { - memcpy(&wh.riff_header, "RIFF", 4); - wh.wav_size = size + sizeof(struct wave_hdr) - 8; - memcpy(&wh.wav_header, "WAVE", 4); - memcpy(&wh.fmt_header, "fmt ", 4); - wh.fmt_chunk_size = 16; - wh.audio_format = 1; - wh.num_channels = 1; - wh.sample_rate = WAVE_SAMPLE_RATE; - wh.sample_alignment = 2; - wh.bit_depth = 16; - wh.byte_rate = wh.sample_rate * wh.sample_alignment; - memcpy(&wh.data_header, "data", 4); - wh.data_bytes = size; +// Write a minimal WAV header into the output buffer. +// Returns the number of bytes written (44 for a standard PCM WAV header). +static size_t wav_header_write(uint8_t * buf, int num_channels, int sample_rate, int bits_per_sample, uint32_t data_size) { + // RIFF header + memcpy(buf, "RIFF", 4); + uint32_t chunk_size = 36 + data_size; + memcpy(buf + 4, &chunk_size, 4); + memcpy(buf + 8, "WAVE", 4); + + // fmt subchunk + memcpy(buf + 12, "fmt ", 4); + uint32_t subchunk1_size = 16; + memcpy(buf + 16, &subchunk1_size, 4); + uint16_t audio_format = 1; // PCM + memcpy(buf + 20, &audio_format, 2); + memcpy(buf + 22, &num_channels, 2); + memcpy(buf + 24, &sample_rate, 4); + + int bytes_per_sample = (bits_per_sample / 8) * num_channels; + int byte_rate = sample_rate * bytes_per_sample; + memcpy(buf + 28, &byte_rate, 4); + memcpy(buf + 32, &bytes_per_sample, 2); + memcpy(buf + 34, &bits_per_sample, 2); + + // data subchunk + memcpy(buf + 36, "data", 4); + memcpy(buf + 40, &data_size, 4); + + return 44; } -static void write_wave_hdr(int fd, size_t size) { - struct wave_hdr wh; - set_wave_hdr(wh, size); - write(fd, &wh, sizeof(struct wave_hdr)); -} +bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data) { + { + const char * verbose = getenv("WHISPER_COMMON_FFMPEG_VERBOSE"); + if (verbose && strcmp(verbose, "2") == 0) { + av_log_set_level(AV_LOG_DEBUG); + } else if (verbose && strcmp(verbose, "1") == 0) { + av_log_set_level(AV_LOG_VERBOSE); + } else { + av_log_set_level(AV_LOG_WARNING); + } + } -static int map_file(int fd, u8 **ptr, size_t *size) -{ - struct stat sb; + AVFormatContext * fmt_ctx = nullptr; + if (avformat_open_input(&fmt_ctx, ifname.c_str(), nullptr, nullptr) != 0) { + fprintf(stderr, "error: failed to open input file '%s'\n", ifname.c_str()); + return true; + } - fstat(fd, &sb); - *size = sb.st_size; + if (avformat_find_stream_info(fmt_ctx, nullptr) < 0) { + fprintf(stderr, "error: failed to find stream information\n"); + avformat_close_input(&fmt_ctx); + return true; + } - *ptr = (u8*)mmap(NULL, *size, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0); - if (*ptr == MAP_FAILED) { - perror("mmap"); - return -1; - } + // Find the first audio stream + int audio_stream_idx = -1; + for (unsigned int i = 0; i < fmt_ctx->nb_streams; i++) { + if (fmt_ctx->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) { + audio_stream_idx = i; + break; + } + } - return 0; -} + if (audio_stream_idx == -1) { + fprintf(stderr, "error: failed to find an audio stream in '%s'\n", ifname.c_str()); + avformat_close_input(&fmt_ctx); + return true; + } -static int read_packet(void *opaque, u8 *buf, int buf_size) -{ - struct audio_buffer *audio_buf = (audio_buffer*)opaque; + AVStream * audio_stream = fmt_ctx->streams[audio_stream_idx]; - buf_size = FFMIN(buf_size, audio_buf->size); + // Open the decoder + const AVCodec * codec = avcodec_find_decoder(audio_stream->codecpar->codec_id); + if (!codec) { + fprintf(stderr, "error: failed to find decoder for codec id %d\n", audio_stream->codecpar->codec_id); + avformat_close_input(&fmt_ctx); + return true; + } - /* copy internal buffer data to buf */ - memcpy(buf, audio_buf->ptr, buf_size); - audio_buf->ptr += buf_size; - audio_buf->size -= buf_size; + AVCodecContext * codec_ctx = avcodec_alloc_context3(codec); + if (!codec_ctx) { + fprintf(stderr, "error: failed to allocate codec context\n"); + avformat_close_input(&fmt_ctx); + return true; + } - return buf_size; -} + if (avcodec_parameters_to_context(codec_ctx, audio_stream->codecpar) < 0) { + fprintf(stderr, "error: failed to copy codec parameters to context\n"); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; + } -static void convert_frame(struct SwrContext *swr, AVCodecContext *codec, - AVFrame *frame, s16 **data, int *size, bool flush) -{ - int nr_samples; - s64 delay; - u8 *buffer; - - delay = swr_get_delay(swr, codec->sample_rate); - nr_samples = av_rescale_rnd(delay + frame->nb_samples, - WAVE_SAMPLE_RATE, codec->sample_rate, - AV_ROUND_UP); - av_samples_alloc(&buffer, NULL, 1, nr_samples, AV_SAMPLE_FMT_S16, 0); - - /* - * !flush is used to check if we are flushing any remaining - * conversion buffers... - */ - nr_samples = swr_convert(swr, &buffer, nr_samples, - !flush ? (const u8 **)frame->data : NULL, - !flush ? frame->nb_samples : 0); - - *data = (s16*)realloc(*data, (*size + nr_samples) * sizeof(s16)); - memcpy(*data + *size, buffer, nr_samples * sizeof(s16)); - *size += nr_samples; - av_freep(&buffer); -} + if (avcodec_open2(codec_ctx, codec, nullptr) < 0) { + fprintf(stderr, "error: failed to open codec\n"); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; + } -static bool is_audio_stream(const AVStream *stream) -{ - if (stream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) - return true; + // Setup resampler: convert to 16-bit signed PCM, mono, 16000 Hz + const enum AVSampleFormat out_sample_fmt = AV_SAMPLE_FMT_S16; + const int out_sample_rate = WHISPER_SAMPLE_RATE; - return false; -} + AVChannelLayout out_ch_layout = AV_CHANNEL_LAYOUT_MONO; -// Return non zero on error, 0 on success -// audio_buffer: input memory -// data: decoded output audio data (wav file) -// size: size of output data -static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size) -{ - LOG("decode_audio: input size: %d\n", audio_buf->size); - AVFormatContext *fmt_ctx; - AVIOContext *avio_ctx; - AVStream *stream; - AVCodecContext *codec; - AVPacket *packet; - AVFrame *frame; - struct SwrContext *swr; - u8 *avio_ctx_buffer; - unsigned int i; - int stream_index = -1; - int err; - const size_t errbuffsize = 1024; - char errbuff[errbuffsize]; - - fmt_ctx = avformat_alloc_context(); - avio_ctx_buffer = (u8*)av_malloc(AVIO_CTX_BUF_SZ); - LOG("Creating an avio context: AVIO_CTX_BUF_SZ=%d\n", AVIO_CTX_BUF_SZ); - avio_ctx = avio_alloc_context(avio_ctx_buffer, AVIO_CTX_BUF_SZ, 0, audio_buf, &read_packet, NULL, NULL); - fmt_ctx->pb = avio_ctx; - - // open the input stream and read header - err = avformat_open_input(&fmt_ctx, NULL, NULL, NULL); - if (err) { - LOG("Could not read audio buffer: %d: %s\n", err, av_make_error_string(errbuff, errbuffsize, err)); - return err; - } - - err = avformat_find_stream_info(fmt_ctx, NULL); - if (err < 0) { - LOG("Could not retrieve stream info from audio buffer: %d\n", err); - return err; - } - - for (i = 0; i < fmt_ctx->nb_streams; i++) { - if (is_audio_stream(fmt_ctx->streams[i])) { - stream_index = i; - break; - } - } - - if (stream_index == -1) { - LOG("Could not retrieve audio stream from buffer\n"); - return -1; - } - - stream = fmt_ctx->streams[stream_index]; - codec = avcodec_alloc_context3( - avcodec_find_decoder(stream->codecpar->codec_id)); - avcodec_parameters_to_context(codec, stream->codecpar); - err = avcodec_open2(codec, avcodec_find_decoder(codec->codec_id), - NULL); - if (err) { - LOG("Failed to open decoder for stream #%d in audio buffer\n", stream_index); - return err; - } - - /* prepare resampler */ - swr = swr_alloc(); - -#if LIBAVCODEC_VERSION_MAJOR > 60 - AVChannelLayout in_ch_layout = codec->ch_layout; - AVChannelLayout out_ch_layout = AV_CHANNEL_LAYOUT_MONO; - - /* Set the source audio layout as-is */ - av_opt_set_chlayout(swr, "in_chlayout", &in_ch_layout, 0); - av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0); - av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0); - - /* Convert it into 16khz Mono */ - av_opt_set_chlayout(swr, "out_chlayout", &out_ch_layout, 0); - av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0); - av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0); -#else - av_opt_set_int(swr, "in_channel_count", codec->channels, 0); - av_opt_set_int(swr, "out_channel_count", 1, 0); - av_opt_set_int(swr, "in_channel_layout", codec->channel_layout, 0); - av_opt_set_int(swr, "out_channel_layout", AV_CH_LAYOUT_MONO, 0); - av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0); - av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0); - av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0); - av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0); -#endif - - swr_init(swr); - if (!swr_is_initialized(swr)) { - LOG("Resampler has not been properly initialized\n"); - return -1; - } - - packet=av_packet_alloc(); - if (!packet) { - LOG("Error allocating the packet\n"); - return -1; - } - frame = av_frame_alloc(); - if (!frame) { - LOG("Error allocating the frame\n"); - return -1; - } - - /* iterate through frames */ - *data = NULL; - *size = 0; - while (av_read_frame(fmt_ctx, packet) >= 0) { - avcodec_send_packet(codec, packet); - - err = avcodec_receive_frame(codec, frame); - if (err == AVERROR(EAGAIN)) - continue; - - convert_frame(swr, codec, frame, data, size, false); - } - /* Flush any remaining conversion buffers... */ - convert_frame(swr, codec, frame, data, size, true); - - av_packet_free(&packet); - av_frame_free(&frame); - swr_free(&swr); - //avio_context_free(); // todo? - avcodec_free_context(&codec); - avformat_close_input(&fmt_ctx); - avformat_free_context(fmt_ctx); - - if (avio_ctx) { - av_freep(&avio_ctx->buffer); - av_freep(&avio_ctx); - } - - return 0; -} + SwrContext * swr_ctx = nullptr; + if (swr_alloc_set_opts2(&swr_ctx, &out_ch_layout, out_sample_fmt, out_sample_rate, + &codec_ctx->ch_layout, codec_ctx->sample_fmt, codec_ctx->sample_rate, + 0, nullptr) < 0) { + fprintf(stderr, "error: failed to allocate swr context\n"); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; + } -// in mem decoding/conversion/resampling: -// ifname: input file path -// owav_data: in mem wav file. Can be forwarded as it to whisper/drwav -// return 0 on success -int ffmpeg_decode_audio(const std::string &ifname, std::vector& owav_data) { - LOG("ffmpeg_decode_audio: %s\n", ifname.c_str()); - int ifd = open(ifname.c_str(), O_RDONLY); - if (ifd == -1) { - fprintf(stderr, "Couldn't open input file %s\n", ifname.c_str()); - return -1; + if (swr_init(swr_ctx) < 0) { + fprintf(stderr, "error: failed to initialize swr context\n"); + swr_free(&swr_ctx); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; } - u8 *ibuf = NULL; - size_t ibuf_size; - int err = map_file(ifd, &ibuf, &ibuf_size); - if (err) { - LOG("Couldn't map input file %s\n", ifname.c_str()); - return err; + + // Decode and resample + AVPacket * packet = av_packet_alloc(); + AVFrame * frame = av_frame_alloc(); + + // Buffer to collect resampled output + std::vector pcm_data; + + // Max output samples per swr_convert call + const int max_out_samples = 16 * 1024; + std::vector out_buffer(max_out_samples); + + while (av_read_frame(fmt_ctx, packet) >= 0) { + if (packet->stream_index != audio_stream_idx) { + av_packet_unref(packet); + continue; + } + + int ret = avcodec_send_packet(codec_ctx, packet); + av_packet_unref(packet); + + if (ret < 0) { + continue; + } + + while (ret >= 0) { + ret = avcodec_receive_frame(codec_ctx, frame); + if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { + break; + } + if (ret < 0) { + break; + } + + // Resample + int out_samples = av_rescale_rnd(swr_get_delay(swr_ctx, out_sample_rate) + frame->nb_samples, + out_sample_rate, out_sample_rate, AV_ROUND_UP); + if (out_samples > (int)out_buffer.size()) { + out_buffer.resize(out_samples); + } + + const uint8_t * in_data[16] = {0}; + for (int p = 0; p < (int)codec_ctx->ch_layout.nb_channels && p < 16; p++) { + in_data[p] = frame->data[p]; + } + uint8_t * out_data[16] = {0}; + out_data[0] = (uint8_t *)out_buffer.data(); + + int got_samples = swr_convert(swr_ctx, out_data, out_samples, in_data, frame->nb_samples); + if (got_samples > 0) { + pcm_data.insert(pcm_data.end(), out_buffer.begin(), out_buffer.begin() + got_samples); + } + } } - LOG("Mapped input file: %s size: %d\n", ibuf, (int) ibuf_size); - struct audio_buffer inaudio_buf; - inaudio_buf.ptr = ibuf; - inaudio_buf.size = ibuf_size; - - s16 *odata=NULL; - int osize=0; - - err = decode_audio(&inaudio_buf, &odata, &osize); - LOG("decode_audio returned %d \n", err); - if (err != 0) { - LOG("decode_audio failed\n"); - return err; + + // Flush the decoder + avcodec_send_packet(codec_ctx, nullptr); + while (avcodec_receive_frame(codec_ctx, frame) >= 0) { + int out_samples = av_rescale_rnd(swr_get_delay(swr_ctx, out_sample_rate) + frame->nb_samples, + out_sample_rate, out_sample_rate, AV_ROUND_UP); + if (out_samples > (int)out_buffer.size()) { + out_buffer.resize(out_samples); + } + const uint8_t * in_data[16] = {0}; + for (int p = 0; p < (int)codec_ctx->ch_layout.nb_channels && p < 16; p++) { + in_data[p] = frame->data[p]; + } + uint8_t * out_data[16] = {0}; + out_data[0] = (uint8_t *)out_buffer.data(); + + int got_samples = swr_convert(swr_ctx, out_data, out_samples, in_data, frame->nb_samples); + if (got_samples > 0) { + pcm_data.insert(pcm_data.end(), out_buffer.begin(), out_buffer.begin() + got_samples); + } } - LOG("decode_audio output size: %d\n", osize); - - wave_hdr wh; - const size_t outdatasize = osize * sizeof(s16); - set_wave_hdr(wh, outdatasize); - owav_data.resize(sizeof(wave_hdr) + outdatasize); - // header: - memcpy(owav_data.data(), &wh, sizeof(wave_hdr)); - // the data: - memcpy(owav_data.data() + sizeof(wave_hdr), odata, osize* sizeof(s16)); - - return 0; + + // Flush the resampler + uint8_t * out_data[16] = {0}; + out_data[0] = (uint8_t *)out_buffer.data(); + int flush_samples = swr_convert(swr_ctx, out_data, max_out_samples, nullptr, 0); + if (flush_samples > 0) { + pcm_data.insert(pcm_data.end(), out_buffer.begin(), out_buffer.begin() + flush_samples); + } + + // Build WAV output + uint32_t data_size = pcm_data.size() * sizeof(int16_t); + wav_data.resize(44 + data_size); + + wav_header_write(wav_data.data(), 1, out_sample_rate, 16, data_size); + memcpy(wav_data.data() + 44, pcm_data.data(), data_size); + + // Cleanup + av_frame_free(&frame); + av_packet_free(&packet); + swr_free(&swr_ctx); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + + return false; // success } + +#endif // WHISPER_COMMON_FFMPEG diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 09e77ea89c2..0593b748d36 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -78,7 +78,7 @@ add_test(NAME ${TEST_TARGET} -f ${PROJECT_SOURCE_DIR}/samples/jfk.wav) set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "large") -if (WHISPER_FFMPEG) +if (WHISPER_COMMON_FFMPEG) set(TEST_TARGET test-whisper-cli-tiny-mp3) # Check with reviewers: any way to check the output transcription via ctest (diff, ...)? add_test(NAME ${TEST_TARGET} From 6c343e7a4ed01a77be70cc4be2f5001cc72521e3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 31 May 2026 15:48:05 +0300 Subject: [PATCH 737/831] common : pass sample rate to `ffmpeg_decode_audio()` --- examples/common-whisper.cpp | 2 +- examples/ffmpeg-transcode.cpp | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/common-whisper.cpp b/examples/common-whisper.cpp index 8cdd2320c17..c84e6843adc 100644 --- a/examples/common-whisper.cpp +++ b/examples/common-whisper.cpp @@ -36,7 +36,7 @@ #ifdef WHISPER_COMMON_FFMPEG // as implemented in ffmpeg-trancode.cpp only embedded in common lib if whisper built with ffmpeg support -extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data); +extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data, int out_sample_rate = WHISPER_SAMPLE_RATE); #endif // extract f32 PCM frames from an initialized decoder, downmix to mono and keep the stereo split diff --git a/examples/ffmpeg-transcode.cpp b/examples/ffmpeg-transcode.cpp index dc57fe74596..7657af69823 100644 --- a/examples/ffmpeg-transcode.cpp +++ b/examples/ffmpeg-transcode.cpp @@ -1,7 +1,5 @@ #ifdef WHISPER_COMMON_FFMPEG -#include "whisper.h" - #include #include #include @@ -44,7 +42,7 @@ static size_t wav_header_write(uint8_t * buf, int num_channels, int sample_rate, return 44; } -bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data) { +bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_data, int out_sample_rate) { { const char * verbose = getenv("WHISPER_COMMON_FFMPEG_VERBOSE"); if (verbose && strcmp(verbose, "2") == 0) { @@ -116,7 +114,6 @@ bool ffmpeg_decode_audio(const std::string & ifname, std::vector & wav_ // Setup resampler: convert to 16-bit signed PCM, mono, 16000 Hz const enum AVSampleFormat out_sample_fmt = AV_SAMPLE_FMT_S16; - const int out_sample_rate = WHISPER_SAMPLE_RATE; AVChannelLayout out_ch_layout = AV_CHANNEL_LAYOUT_MONO; From 2e045a967b802564844fa17cf19792c8cf1f04ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 31 May 2026 15:45:44 +0300 Subject: [PATCH 738/831] ci : remove obsolete self-hosted label --- .github/workflows/build.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e855ef7cf87..773122a0f0a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1490,7 +1490,7 @@ jobs: LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt ggml-ci-x64-nvidia-cuda: - runs-on: [self-hosted, Linux, mnt-root, NVIDIA] + runs-on: [self-hosted, Linux, NVIDIA] steps: - name: Clone @@ -1504,7 +1504,7 @@ jobs: GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp ggml-ci-x64-nvidia-vulkan-cm: - runs-on: [self-hosted, Linux, mnt-root, NVIDIA] + runs-on: [self-hosted, Linux, NVIDIA] steps: - name: Clone @@ -1518,7 +1518,7 @@ jobs: GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp ggml-ci-x64-nvidia-vulkan-cm2: - runs-on: [self-hosted, Linux, mnt-root, NVIDIA, COOPMAT2] + runs-on: [self-hosted, Linux, NVIDIA, COOPMAT2] steps: - name: Clone From 099af1c67d26172e2607a57d945e6c4a19a57a6f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 31 May 2026 16:04:12 +0300 Subject: [PATCH 739/831] pi : add config [no ci] --- .gitignore | 3 +++ .pi/gg/SYSTEM.md | 27 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 .pi/gg/SYSTEM.md diff --git a/.gitignore b/.gitignore index 6eb8ff45915..7a98228af3c 100644 --- a/.gitignore +++ b/.gitignore @@ -65,3 +65,6 @@ cmake-build-debug/ local.properties .log .exe + +# AGENTS +.pi/SYSTEM.md diff --git a/.pi/gg/SYSTEM.md b/.pi/gg/SYSTEM.md new file mode 100644 index 00000000000..1ae0e40674e --- /dev/null +++ b/.pi/gg/SYSTEM.md @@ -0,0 +1,27 @@ +You are a coding agent. Here are some very important rules that you must follow: + +General: +- Be very precise and concise when writing code, comments, explanations, etc. +- PR and commit titles format: ` : `. Lookup recents for examples +- Don't try to build or run the code unless you are explicitly asked to do so +- Use the `gh` CLI tool when querying PRs, issues, or other GitHub resources + +Coding: +- When in doubt, always refer to the CONTRIBUTING.md file of the project +- When referencing issues or PRs in comments, use the format: + - C/C++ code: `// ref: <url>` + - Other (CMake, etc.): `# ref: <url>` + +Pull requests (PRs): +- New branch names are prefixed with "gg/" +- Before opening a pull request, ask the user to confirm the description +- When creating a pull request, look for the repository's PR template and follow it +- For the AI usage disclosure section, write "YES. llama.cpp + pi + [MODEL]" +- Ask the user to tell you what model was used and write it in place of [MODEL] +- Always create the pull requests in draft mode + +Commits: +- On every commit that you make, include a "Assisted-by: llama.cpp:local pi" tag +- Do not explicitly set the git author in commits - rely on the default git config +- Always use `--no-gpg-sign` when committing +- Never `git push` without explicit confirmation from the user From fe69461618ffc50ba8afa65c25cc6c6e34d4537f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Sun, 31 May 2026 16:06:32 +0300 Subject: [PATCH 740/831] ci : fix self-hosted paths to mnt --- .github/workflows/build.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 773122a0f0a..878c5833eaa 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1501,7 +1501,7 @@ jobs: id: ggml-ci run: | nvidia-smi - GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp + GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp ggml-ci-x64-nvidia-vulkan-cm: runs-on: [self-hosted, Linux, NVIDIA] @@ -1515,7 +1515,7 @@ jobs: id: ggml-ci run: | vulkaninfo --summary - GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp + GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp ggml-ci-x64-nvidia-vulkan-cm2: runs-on: [self-hosted, Linux, NVIDIA, COOPMAT2] @@ -1529,7 +1529,7 @@ jobs: id: ggml-ci run: | vulkaninfo --summary - GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp + GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp #ggml-ci-x64-cpu-amx: # runs-on: [self-hosted, Linux, X64, CPU, AMX] @@ -1542,7 +1542,7 @@ jobs: # - name: Test # id: ggml-ci # run: | - # bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp + # bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp ggml-ci-mac-metal: runs-on: [self-hosted, macOS, ARM64] From 0dff27498f704b9eab8527f03c769efb7e7f051c Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Mon, 1 Jun 2026 07:20:19 +0200 Subject: [PATCH 741/831] ci : fix path to whisper.h in examples.yml [no ci] (#3842) This commit updates the include path to whisper.h and also ensures that this is only built on pushes to master. --- .github/workflows/build.yml | 5 +++-- .github/workflows/examples.yml | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 878c5833eaa..b7badd51041 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -29,8 +29,9 @@ on: pull_request: types: [opened, synchronize, reopened] paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml workflow_dispatch: inputs: create_release: diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 1c9ade5a300..df3aa832c2e 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -1,13 +1,15 @@ name: Examples Tests on: push: + branches: + - master paths: - examples/addon.node/** - - whisper.h + - include/whisper.h pull_request: paths: - examples/addon.node/** - - whisper.h + - include/whisper.h jobs: addon_node-ubuntu-22: From 23ee03506a91ac3d3f0071b40e66a430eebdfa1d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Mon, 1 Jun 2026 14:56:20 +0300 Subject: [PATCH 742/831] release : v1.8.6 --- CMakeLists.txt | 2 +- README.md | 2 +- bindings/javascript/package.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 35c8674725f..4df278c3ad8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.8.5) +project("whisper.cpp" VERSION 1.8.6) include(CheckIncludeFileCXX) set(SOVERSION 1) diff --git a/README.md b/README.md index d1680e99bfc..fe7fa74153a 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) -Stable: [v1.8.1](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.1) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) +Stable: [v1.8.6](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.6) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index caf12b6dd2d..1f2f34672ae 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.8.5", + "version": "1.8.6", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From ef24de1e5814c4fb14cc396aa6aed623032073ab Mon Sep 17 00:00:00 2001 From: Patrice Levesque <github-wayne@ptaff.ca> Date: Tue, 2 Jun 2026 03:22:16 -0400 Subject: [PATCH 743/831] cmake : do not assume /usr/lib library installation. (#3693) Current `pkgconfig` configuration file installation path and its contents assume libraries are installed under `/usr/lib` and this is not always the case, for instance `/usr/lib64` is quite possible under Gentoo Linux. Thus use the `CMAKE_INSTALL_LIBDIR` variable instead of a hardcoded `lib`. --- CMakeLists.txt | 2 +- cmake/whisper.pc.in | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4df278c3ad8..3932cf2845e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -209,7 +209,7 @@ configure_file(cmake/whisper.pc.in @ONLY) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc" - DESTINATION lib/pkgconfig) + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) # # programs, examples and tests diff --git a/cmake/whisper.pc.in b/cmake/whisper.pc.in index 00ec7912014..200179d5d11 100644 --- a/cmake/whisper.pc.in +++ b/cmake/whisper.pc.in @@ -1,6 +1,6 @@ prefix=@CMAKE_INSTALL_PREFIX@ exec_prefix=${prefix} -libdir=${exec_prefix}/lib +libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@ includedir=${prefix}/include Name: whisper From e5d44125788a69cca621c85c4d022e83162ac113 Mon Sep 17 00:00:00 2001 From: Noah Lyons <n.lyons53@gmail.com> Date: Tue, 2 Jun 2026 07:10:27 -0400 Subject: [PATCH 744/831] server : merge split utf-8 token text in verbose json (#3850) --- examples/cli/cli.cpp | 33 --------------------------------- examples/common-whisper.cpp | 28 ++++++++++++++++++++++++++++ examples/common-whisper.h | 3 +++ examples/server/server.cpp | 23 +++++++++++++++++++++-- tests/CMakeLists.txt | 8 ++++++++ tests/test-common-utf8.cpp | 34 ++++++++++++++++++++++++++++++++++ 6 files changed, 94 insertions(+), 35 deletions(-) create mode 100644 tests/test-common-utf8.cpp diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 55cd71b4e55..7ca563dc250 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -31,39 +31,6 @@ static void replace_all(std::string & s, const std::string & search, const std:: } } -// Returns the number of trailing continuation bytes still needed for `s` to end -// on a complete UTF-8 codepoint. Returns 0 if the tail of `s` is already a -// complete codepoint (or if the tail looks malformed and we should stop merging). -// Used to merge whisper tokens whose bytes split a multi-byte UTF-8 character -// (e.g. CJK), so the JSON output stays valid UTF-8. See https://github.com/ggml-org/whisper.cpp/issues/1798. -static int utf8_trailing_bytes_needed(const std::string & s) { - const int n = (int) s.size(); - int i = n - 1; - // walk back past continuation bytes (10xxxxxx) - while (i >= 0 && ((unsigned char) s[i] & 0xC0) == 0x80) { - --i; - } - if (i < 0) { - // all continuation bytes, or empty — nothing we can do - return 0; - } - const unsigned char c = (unsigned char) s[i]; - int expected; - if ((c & 0x80) == 0x00) { - expected = 1; // ASCII - } else if ((c & 0xE0) == 0xC0) { - expected = 2; - } else if ((c & 0xF0) == 0xE0) { - expected = 3; - } else if ((c & 0xF8) == 0xF0) { - expected = 4; - } else { - return 0; // malformed lead, give up - } - const int have = n - i; - return have >= expected ? 0 : (expected - have); -} - // command-line parameters struct whisper_params { int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); diff --git a/examples/common-whisper.cpp b/examples/common-whisper.cpp index c84e6843adc..b12481c013f 100644 --- a/examples/common-whisper.cpp +++ b/examples/common-whisper.cpp @@ -198,6 +198,34 @@ int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate) { return std::max(0, std::min((int) n_samples - 1, (int) ((t*whisper_sample_rate)/100))); } +int utf8_trailing_bytes_needed(const std::string & s) { + const int n = (int) s.size(); + int i = n - 1; + while (i >= 0 && ((unsigned char) s[i] & 0xC0) == 0x80) { + --i; + } + if (i < 0) { + return 0; + } + + const unsigned char c = (unsigned char) s[i]; + int expected; + if ((c & 0x80) == 0x00) { + expected = 1; + } else if ((c & 0xE0) == 0xC0) { + expected = 2; + } else if ((c & 0xF0) == 0xE0) { + expected = 3; + } else if ((c & 0xF8) == 0xF0) { + expected = 4; + } else { + return 0; + } + + const int have = n - i; + return have >= expected ? 0 : (expected - have); +} + bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id) { std::ofstream speak_file(path.c_str()); if (speak_file.fail()) { diff --git a/examples/common-whisper.h b/examples/common-whisper.h index 8714c381046..aec430d3635 100644 --- a/examples/common-whisper.h +++ b/examples/common-whisper.h @@ -28,5 +28,8 @@ std::string to_timestamp(int64_t t, bool comma = false); // given a timestamp get the sample int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate); +// Returns the number of trailing bytes still needed for s to end on a complete UTF-8 codepoint. +int utf8_trailing_bytes_needed(const std::string & s); + // write text to file, and call system("command voice_id file") bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index aae74c3d840..b87ef27375f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1107,10 +1107,29 @@ int main(int argc, char ** argv) { } segment["tokens"].push_back(token.id); - json word = json{{"word", whisper_full_get_token_text(ctx, i, j)}}; + std::string word_text = whisper_full_get_token_text(ctx, i, j); + int64_t word_t1 = token.t1; + + while (j + 1 < n_tokens && utf8_trailing_bytes_needed(word_text) > 0) { + const whisper_token_data next_token = whisper_full_get_token_data(ctx, i, j + 1); + // Keep verbose_json tokens free of EOT ids, matching the pre-merge server behavior. + if (next_token.id >= whisper_token_eot(ctx)) { + break; + } + + ++j; + segment["tokens"].push_back(next_token.id); + word_text += whisper_full_get_token_text(ctx, i, j); + if (next_token.t1 > -1) { + word_t1 = next_token.t1; + } + total_logprob += next_token.plog; + } + + json word = json{{"word", word_text}}; if (!params.no_timestamps && params.token_timestamps) { word["start"] = token.t0 * 0.01; - word["end"] = token.t1 * 0.01; + word["end"] = word_t1 * 0.01; word["t_dtw"] = token.t_dtw; } word["probability"] = token.p; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 0593b748d36..646f45f2ab7 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -88,6 +88,14 @@ if (WHISPER_COMMON_FFMPEG) set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "tiny;mp3") endif() +# UTF-8 helper unit test +set(UTF8_TEST test-common-utf8) +add_executable(${UTF8_TEST} ${UTF8_TEST}.cpp) +target_include_directories(${UTF8_TEST} PRIVATE ../examples) +target_link_libraries(${UTF8_TEST} PRIVATE common) +add_test(NAME ${UTF8_TEST} COMMAND ${UTF8_TEST}) +set_tests_properties(${UTF8_TEST} PROPERTIES LABELS "unit") + # VAD test tests VAD in isolation set(VAD_TEST test-vad) add_executable(${VAD_TEST} ${VAD_TEST}.cpp) diff --git a/tests/test-common-utf8.cpp b/tests/test-common-utf8.cpp new file mode 100644 index 00000000000..91c73a7428d --- /dev/null +++ b/tests/test-common-utf8.cpp @@ -0,0 +1,34 @@ +#include "common-whisper.h" + +#include <cstdlib> +#include <cstdio> +#include <string> + +static void expect_needed(const std::string & input, int expected) { + const int actual = utf8_trailing_bytes_needed(input); + if (actual != expected) { + fprintf(stderr, "expected %d trailing UTF-8 bytes, got %d\n", expected, actual); + std::abort(); + } +} + +int main() { + expect_needed("", 0); + expect_needed("plain ascii", 0); + + const std::string cjk = "\xE4\xBD\xA0"; // U+4F60 + expect_needed(cjk.substr(0, 1), 2); + expect_needed(cjk.substr(0, 2), 1); + expect_needed(cjk, 0); + + const std::string emoji = "\xF0\x9F\x98\x80"; // U+1F600 + expect_needed(emoji.substr(0, 1), 3); + expect_needed(emoji.substr(0, 2), 2); + expect_needed(emoji.substr(0, 3), 1); + expect_needed(emoji, 0); + + expect_needed("\x80\x80", 0); + expect_needed("\xFF", 0); + + return 0; +} From 610e664ba7cfe3af46125ed1b5a1184fccb51bcd Mon Sep 17 00:00:00 2001 From: danscMax <153344025+danscMax@users.noreply.github.com> Date: Tue, 2 Jun 2026 14:25:29 +0300 Subject: [PATCH 745/831] whisper : catch C++ exceptions in whisper_init_with_params_no_state (#3831) whisper_model_load() can throw instead of returning false: std::runtime_error from this file (failed ggml context / no compatible buffer type), or vk::SystemError / vk::OutOfDeviceMemoryError from the ggml-vulkan backend during device/buffer allocation. whisper_init_* are extern "C", so a C++ exception unwinding across that boundary aborts non-C++ callers (Rust via whisper-rs, Go via cgo) -- on Windows STATUS_STACK_BUFFER_OVERRUN (0xC0000409) -- even though the function already returns NULL on failure. Wrap whisper_model_load() in try/catch and route any throw into the existing NULL-return path. Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com> --- src/whisper.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/whisper.cpp b/src/whisper.cpp index 0fe29a4541e..5ffc70af00e 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -3720,7 +3720,21 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_ whisper_context * ctx = new whisper_context; ctx->params = params; - if (!whisper_model_load(loader, *ctx)) { + // A C++ exception escaping this extern "C" function aborts non-C++ callers + // (Rust via whisper-rs, Go via cgo, ...). whisper_model_load can throw + // (std::runtime_error here; vk::SystemError from the Vulkan backend during + // device/buffer allocation), so funnel any throw into the existing + // NULL-return failure path instead of letting it cross the C ABI. + bool model_loaded = false; + try { + model_loaded = whisper_model_load(loader, *ctx); + } catch (const std::exception & e) { + WHISPER_LOG_ERROR("%s: exception during model load: %s\n", __func__, e.what()); + } catch (...) { + WHISPER_LOG_ERROR("%s: unknown exception during model load\n", __func__); + } + + if (!model_loaded) { loader->close(loader->context); WHISPER_LOG_ERROR("%s: failed to load model\n", __func__); delete ctx; From 02d5316af5f9cef149ec20eebcba99cd6395b6b3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Thu, 4 Jun 2026 09:35:58 +0300 Subject: [PATCH 746/831] ci : refactor + optimize (#3847) * ci : add ccache clear action * ci : split self-hosted GPU jobs into build-self-hosted.yml Extract self-hosted runner jobs from build.yml into a dedicated build-self-hosted.yml following the llama.cpp pattern: - gpu-cuda (NVIDIA Linux) - gpu-vulkan-nvidia-cm (NVIDIA Linux) - gpu-vulkan-nvidia-cm2 (NVIDIA Linux + COOPMAT2) - gpu-metal (macOS ARM64) - gpu-vulkan (macOS ARM64) GitHub-hosted CPU jobs remain in build.yml. Assisted-by: llama.cpp:local pi * ci : split release jobs into release.yml Extract release-related jobs from build.yml into a dedicated release.yml following the llama.cpp pattern: - determine-tag - windows (Win32/x64, SDL2) - windows-blas (Win32/x64, OpenBLAS) - windows-cublas (x64, CUDA 11.8/12.4) - ios-xcode-build - bindings-java (depends on windows) - release (artifact aggregation + GitHub release) CoreML job stays in build.yml with its own local tag calculation. Assisted-by: llama.cpp:local pi * ci : remove bindings-java job from release.yml Assisted-by: llama.cpp:local pi * cont : add manual trigger for build.yml * cont : remove obsolete ifs * ci : extract sanitizer job to bild-sanitize.yml * ci : extract linux jobs into build-linux.yml * ci : extract macos jobs to build-macos.yml * ci : extract gcc jobs to build-gcc.yml * ci : extract clang jobs to build-clang.yml * ci : extract sycl jobs to build-sycl.yml * ci : extract windows jobs to build-windows.yml * ci : extract emscripten job to build-wasm.yml * ci : extract android jobs into build-android.yml * ci : extract quantize job to quantize.yml * ci : extract coreml job into coreml.yml * ci : extract vad job to vad.yml * ci : extract cpu jobs to build-cpu.yml * ci : make naming of yml files consistent * ci : add --fail to curl download and propagate This commit adds the --fail option to the model download scripts so that if the model download returns a server error this is picked up. This is then detected in run.sh and a error message is displayed and the script stops and returns an error. The motivation for this is that currently it is possible for the model download to fail but this script proceeds and instead of a model file the contents will be an html page probably with the error. This will then cause the model to not be able to load due to a missing magic number. I'm not sure we can do much about the downloading failing, perhaps a retry but at least this will give a clearer error message. Refs: https://github.com/danbev/whisper.cpp/actions/runs/26866349389/job/79230794512 * ci : enable command traces to see download command in use * ci : add retry functionality to download model script This commit adds curl retry options to the model download script. The motivation is that currently when CI jobs run huggingface rate limit the requests and return: ```console curl: (22) The requested URL returned error: 429 ``` This is an attempt to work around this and if it does not work then we can an authorization token. * ci : extract freebsd job to build-freebsd.yml This job has been commented out as it has been flaky in the past. I'll monitor this and if it continues to be unreliable we can disable it in the github actions GUI instead of commenting it out like we did before. * ci : add ccache to jobs (non-docker builds) The ccache will only be saved on pushed to master. * ci : bump ccache-action version to v1.2.21 The motivation for this is that the save parameter does not seem to work with the current version. * ci : add ccache to docker jobs in build-linux.yml * ci : add debug statements to linux docker build * ci : set CCACHE_DIR for build-linux.yml * ci : add ccache to the remaining docker jobs * ci : remove build-linux.yml This commit remove build-linux.yml as the same jobs are also run by build-gcc.yml, with the exception that build-gcc.yml also run ctest). So keeping build-gcc.yml and removing the redundant build-linux.yml. * ci : add linux build artifacts to release * ci : revert to hendrikmuhs/ccache-action for win job This is currently causing the following failure: ```console sccache C:\PROGRA~1\NVIDIA~1\CUDA\v\bin\nvcc.exe -forward-unknown-to-host-compiler -DGGML_BACKEND_BUILD -DGGML_BACKEND_SHARED -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_SCHED_MAX_COPIES=4 -DGGML_SHARED -D_CRT_SECURE_NO_WARNINGS -D_XOPEN_SOURCE=600 -Dggml_cuda_EXPORTS -DCMAKE_INTDIR=\"Release\" -ID:\a\whisper.cpp\whisper.cpp\ggml\src\ggml-cuda\.. -ID:\a\whisper.cpp\whisper.cpp\ggml\src\..\include -isystem "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v\include" -Xcompiler="-MD -O2 -Ob2" -DNDEBUG -std=c++17 -arch=native -use_fast_math -extended-lambda -Xcompiler /Zc:preprocessor -MD -MT ggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\allreduce.cu.obj -MF ggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\allreduce.cu.obj.d -x cu -c D:\a\whisper.cpp\whisper.cpp\ggml\src\ggml-cuda\allreduce.cu -o ggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\allreduce.cu.obj -Xcompiler=-Fdggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\,-FS sccache: encountered fatal error sccache: error: Could not parse shell line sccache: caused by: Could not parse shell line ``` Refs: https://github.com/danbev/whisper.cpp/actions/runs/26883673904/job/79290017353 * ci : make static linux artifacts * ci : make linux release artifact names consistent This commit removes the tag form the linux release artifacts to be consistent with the existing artifacts. If we want to include the tag then we can do that in a follow-up PR. * ci : fix linux zip files to have a directory * ci : add HF_TOKEN secret for HF download authorization This is to avoid the HR rate limiting when downloading model. --------- Co-authored-by: Daniel Bevenius <daniel.bevenius@gmail.com> --- .github/actions/ccache-clear/action.yml | 22 + .github/workflows/build-android.yml | 80 ++ .github/workflows/build-clang.yml | 121 ++ .github/workflows/build-coreml.yml | 65 + .github/workflows/build-cpu.yml | 173 +++ .github/workflows/build-freebsd.yml | 47 + .github/workflows/build-gcc.yml | 166 +++ .github/workflows/build-macos.yml | 72 ++ .github/workflows/build-quantize.yml | 41 + .github/workflows/build-sanitize.yml | 82 ++ .github/workflows/build-self-hosted.yml | 116 ++ .github/workflows/build-sycl.yml | 132 ++ .github/workflows/build-vad.yml | 43 + .github/workflows/build-wasm.yml | 51 + .github/workflows/build-windows.yml | 76 ++ .github/workflows/build.yml | 1573 ----------------------- .github/workflows/examples.yml | 2 + .github/workflows/release.yml | 649 ++++++++++ ci/run.sh | 7 + models/download-ggml-model.sh | 8 +- 20 files changed, 1952 insertions(+), 1574 deletions(-) create mode 100644 .github/actions/ccache-clear/action.yml create mode 100644 .github/workflows/build-android.yml create mode 100644 .github/workflows/build-clang.yml create mode 100644 .github/workflows/build-coreml.yml create mode 100644 .github/workflows/build-cpu.yml create mode 100644 .github/workflows/build-freebsd.yml create mode 100644 .github/workflows/build-gcc.yml create mode 100644 .github/workflows/build-macos.yml create mode 100644 .github/workflows/build-quantize.yml create mode 100644 .github/workflows/build-sanitize.yml create mode 100644 .github/workflows/build-self-hosted.yml create mode 100644 .github/workflows/build-sycl.yml create mode 100644 .github/workflows/build-vad.yml create mode 100644 .github/workflows/build-wasm.yml create mode 100644 .github/workflows/build-windows.yml delete mode 100644 .github/workflows/build.yml create mode 100644 .github/workflows/release.yml diff --git a/.github/actions/ccache-clear/action.yml b/.github/actions/ccache-clear/action.yml new file mode 100644 index 00000000000..d38587efaf8 --- /dev/null +++ b/.github/actions/ccache-clear/action.yml @@ -0,0 +1,22 @@ +name: "ccache-clear" +description: "Delete all GitHub Actions caches matching a key prefix" +inputs: + key: + description: "Cache key prefix to match and delete" + required: true + +runs: + using: "composite" + steps: + - name: Clear caches + shell: bash + run: | + CACHES=$(gh cache list --key "ccache-${{ inputs.key }}" --json id,key --jq '.[] | "\(.id) \(.key)"' 2>/dev/null) + if [ -z "$CACHES" ]; then + echo "No caches found with key prefix: ${{ inputs.key }}" + exit 0 + fi + while read -r id key; do + echo "Deleting cache: $id ($key)" + gh cache delete "$id" + done <<< "$CACHES" diff --git a/.github/workflows/build-android.yml b/.github/workflows/build-android.yml new file mode 100644 index 00000000000..d9af1810131 --- /dev/null +++ b/.github/workflows/build-android.yml @@ -0,0 +1,80 @@ +name: CI (android) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-android.yml', + '**/CMakeLists.txt', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.java'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + android: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + uses: actions/checkout@v6 + with: + path: whisper + + - name: Install Java + uses: actions/setup-java@v5 + with: + distribution: zulu + java-version: 21 + + - name: Setup Android SDK + uses: android-actions/setup-android@v3 + + - name: Build + run: | + cd whisper/examples/whisper.android + ./gradlew assembleRelease --no-daemon + + - name: Build with external ggml + run: | + export PATH_TO_GGML=$PWD/ggml + cd whisper/examples/whisper.android + ./gradlew assembleRelease --no-daemon + + android_java: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: set up JDK 11 + uses: actions/setup-java@v5 + with: + java-version: '11' + distribution: 'temurin' + cache: gradle + + - name: Setup Android SDK + uses: android-actions/setup-android@v3 + with: + cmdline-tools-version: 9.0 + + - name: Build + run: | + cd examples/whisper.android.java + chmod +x ./gradlew + ./gradlew assembleRelease diff --git a/.github/workflows/build-clang.yml b/.github/workflows/build-clang.yml new file mode 100644 index 00000000000..c7a36884f64 --- /dev/null +++ b/.github/workflows/build-clang.yml @@ -0,0 +1,121 @@ +name: CI (clang) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-clang.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + ubuntu_image: "ubuntu:22.04" + +jobs: + ubuntu-22-clang: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + #arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] + # TODO: arm/v7 disabled due to clang bug + # https://github.com/ggerganov/whisper.cpp/actions/runs/9657764109/job/26637633042?pr=2256#step:4:1990 + arch: [linux/amd64, linux/ppc64le] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Set CCACHE_DIR + run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: clang-${{ matrix.arch }}-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Build ${{ matrix.arch }} + run: | + docker run --platform ${{ matrix.arch }} --rm \ + -v ${{ github.workspace }}:/workspace \ + -v ${CCACHE_DIR}:${CCACHE_DIR} \ + -e CCACHE_DIR=${CCACHE_DIR} \ + -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' + set -e + export DEBIAN_FRONTEND=noninteractive + sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + + apt update + apt install -y clang build-essential cmake libsdl2-dev git ccache + cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache + make + ctest -L gh --output-on-failure' + + ubuntu-22-clang-arm64: + runs-on: ubuntu-22.04-arm + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: clang-arm64-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y clang build-essential cmake libsdl2-dev git + + - name: Build and Test + run: | + cmake . -DWHISPER_SDL2=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER=clang \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ARM_ARCH=armv8-a + make + ctest -L gh --output-on-failure diff --git a/.github/workflows/build-coreml.yml b/.github/workflows/build-coreml.yml new file mode 100644 index 00000000000..d383d9ae0a7 --- /dev/null +++ b/.github/workflows/build-coreml.yml @@ -0,0 +1,65 @@ +name: CI (coreml) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + tags: + - 'v*' + paths: ['.github/workflows/build-coreml.yml', + '**/CMakeLists.txt', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.swift', + '**/*.m', + '**/*.mm', + '**/*.metal'] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + +jobs: + coreml-base-en: + runs-on: macos-latest + + steps: + - name: Checkout with full history + uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Set environment variables + id: set_vars + run: | + BUILD_NUMBER=$(git rev-list --count HEAD) + SHORT_HASH=$(git rev-parse --short=7 HEAD) + if [[ "${{ github.ref_type }}" == "tag" ]]; then + TAG_NAME="${{ github.ref_name }}" + elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then + TAG_NAME="b${BUILD_NUMBER}" + else + SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') + TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" + fi + echo "MODEL_NAME=base.en" >> $GITHUB_ENV + echo "GEN_MODEL_NAME=whisper-${TAG_NAME}-ggml-base.en-encoder.mlmodelc" >> $GITHUB_ENV + + - name: Download model + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + ./models/download-ggml-model.sh ${{ env.MODEL_NAME }} + + - name: Generate CoreML model + run: | + python3.11 -m venv venv + source venv/bin/activate + pip install ane_transformers openai-whisper coremltools + ./models/generate-coreml-model.sh ${{ env.MODEL_NAME }} diff --git a/.github/workflows/build-cpu.yml b/.github/workflows/build-cpu.yml new file mode 100644 index 00000000000..9c8e0586fcb --- /dev/null +++ b/.github/workflows/build-cpu.yml @@ -0,0 +1,173 @@ +name: CI (cpu) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-cpu.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +# TODO: simplify the following jobs using a matrix +jobs: + ggml-ci-x64-cpu-low-perf: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-x64-cpu-low-perf + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + + ggml-ci-arm64-cpu-low-perf: + runs-on: ubuntu-22.04-arm + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-arm64-cpu-low-perf + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + + ggml-ci-x64-cpu-high-perf: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-x64-cpu-high-perf + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) bash ./ci/run.sh ./tmp/results ./tmp/mnt + + ggml-ci-arm64-cpu-high-perf: + runs-on: ubuntu-22.04-arm + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-arm64-cpu-high-perf + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_SVE=1 GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + + ggml-ci-arm64-cpu-high-perf-sve: + runs-on: ubuntu-22.04-arm + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-arm64-cpu-high-perf-sve + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt diff --git a/.github/workflows/build-freebsd.yml b/.github/workflows/build-freebsd.yml new file mode 100644 index 00000000000..847ae975e30 --- /dev/null +++ b/.github/workflows/build-freebsd.yml @@ -0,0 +1,47 @@ +name: CI (freebsd) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-freebsd.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + freeBSD-latest: + runs-on: macos-13 + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Build + uses: cross-platform-actions/action@v0.27.0 + with: + operating_system: freebsd + version: '14.2' + run: | + sudo pkg update + sudo pkg install -y gmake sdl2 cmake git + cmake -B build + cmake --build build --config Release diff --git a/.github/workflows/build-gcc.yml b/.github/workflows/build-gcc.yml new file mode 100644 index 00000000000..4528ba3d534 --- /dev/null +++ b/.github/workflows/build-gcc.yml @@ -0,0 +1,166 @@ +name: CI (gcc) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-gcc.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + ubuntu_image: "ubuntu:22.04" + +jobs: + ubuntu-22-gcc: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + arch: [linux/amd64, linux/ppc64le] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Set CCACHE_DIR + run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: gcc-${{ matrix.arch }}-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Build ${{ matrix.arch }} + run: | + docker run --platform ${{ matrix.arch }} --rm \ + -v ${{ github.workspace }}:/workspace \ + -v ${CCACHE_DIR}:${CCACHE_DIR} \ + -e CCACHE_DIR=${CCACHE_DIR} \ + -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' + set -e + export DEBIAN_FRONTEND=noninteractive + sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + + apt update + apt install -y build-essential cmake libsdl2-dev git ccache + cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache + make + ctest -L gh --output-on-failure' + + ubuntu-22-gcc-arm64: + runs-on: ubuntu-22.04-arm + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: gcc-arm64-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y build-essential cmake libsdl2-dev git + + - name: Configure CMake + run: | + cmake . \ + -DWHISPER_SDL2=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ARM_ARCH=armv8-a + + - name: Build and Test + run: | + make + ctest -L gh --output-on-failure + + ubuntu-22-gcc-arm-v7: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + arch: [linux/arm/v7] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Set CCACHE_DIR + run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: gcc-${{ matrix.arch }}-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Build ${{ matrix.arch }} + run: | + docker run --platform ${{ matrix.arch }} --rm \ + -v ${{ github.workspace }}:/workspace \ + -v ${CCACHE_DIR}:${CCACHE_DIR} \ + -e CCACHE_DIR=${CCACHE_DIR} \ + -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' + set -e + export DEBIAN_FRONTEND=noninteractive + sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + + apt update + apt install -y build-essential cmake libsdl2-dev git ccache + cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ARM_ARCH=armv7-a+fp \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache + make + ctest -L gh --output-on-failure' diff --git a/.github/workflows/build-macos.yml b/.github/workflows/build-macos.yml new file mode 100644 index 00000000000..804f8bbb642 --- /dev/null +++ b/.github/workflows/build-macos.yml @@ -0,0 +1,72 @@ +name: CI (macOS) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-macos.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.swift', + '**/*.m', + '**/*.mm', + '**/*.metal'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + macOS-latest: + runs-on: macOS-latest + + strategy: + matrix: + destination: ['generic/platform=macOS', 'generic/platform=iOS', 'generic/platform=tvOS'] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: macos-${{ matrix.destination }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + run: | + brew update + cmake --version + brew install sdl2 + + - name: Build + run: | + sysctl -a + cmake -B build -G Xcode \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DWHISPER_BUILD_EXAMPLES=OFF \ + -DWHISPER_BUILD_TESTS=OFF \ + -DWHISPER_BUILD_SERVER=OFF \ + -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) diff --git a/.github/workflows/build-quantize.yml b/.github/workflows/build-quantize.yml new file mode 100644 index 00000000000..8036a3a3450 --- /dev/null +++ b/.github/workflows/build-quantize.yml @@ -0,0 +1,41 @@ +name: CI (quantize) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-quantize.yml', + '**/CMakeLists.txt', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + quantize: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Test quantize + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + ./models/download-ggml-model.sh tiny.en + cmake -B build + cmake --build build --config Release + ./build/bin/whisper-quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0 diff --git a/.github/workflows/build-sanitize.yml b/.github/workflows/build-sanitize.yml new file mode 100644 index 00000000000..9250fe81023 --- /dev/null +++ b/.github/workflows/build-sanitize.yml @@ -0,0 +1,82 @@ +name: CI (sanitize) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-sanitize.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' + - 'bindings/go/**' + - 'examples/addon.node/**' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + ubuntu-22-gcc-sanitized: + runs-on: ubuntu-22.04 + + continue-on-error: true + + strategy: + fail-fast: false + matrix: + sanitizer: [ADDRESS, THREAD, UNDEFINED] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: sanitize-${{ matrix.sanitizer }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y build-essential cmake git + + - name: Build (undefined) + if: ${{ matrix.sanitizer == 'UNDEFINED' }} + run: | + cmake . -DCMAKE_BUILD_TYPE=Debug \ + -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DGGML_OPENMP=OFF + make + + - name: Build + if: ${{ matrix.sanitizer == 'ADDRESS' }} + run: | + cmake . -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON + make + + - name: Build (no OpenMP) + if: ${{ matrix.sanitizer == 'THREAD' }} + run: | + cmake . -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DGGML_OPENMP=OFF + make + + - name: Test + if: ${{ matrix.sanitizer != 'UNDEFINED' }} + run: | + ctest -L gh --output-on-failure diff --git a/.github/workflows/build-self-hosted.yml b/.github/workflows/build-self-hosted.yml new file mode 100644 index 00000000000..3fe131b9ba5 --- /dev/null +++ b/.github/workflows/build-self-hosted.yml @@ -0,0 +1,116 @@ +name: CI (self-hosted) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: [ + '.github/workflows/build.yml', + '**/CMakeLists.txt', + '**/.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.swift', + '**/*.m', + '**/*.mm', + '**/*.metal', + '**/*.comp' + ] + + pull_request: + types: [opened, synchronize, reopened] + paths: [ + '.github/workflows/build-self-hosted.yml', + '**/CMakeLists.txt', + '**/.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.swift', + '**/*.m', + '**/*.mm', + '**/*.metal', + '**/*.comp' + ] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + gpu-cuda: + runs-on: [self-hosted, Linux, NVIDIA] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + nvidia-smi + GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp + + gpu-vulkan-nvidia-cm: + runs-on: [self-hosted, Linux, NVIDIA] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp + + gpu-vulkan-nvidia-cm2: + runs-on: [self-hosted, Linux, NVIDIA, COOPMAT2] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp + + gpu-metal: + runs-on: [self-hosted, macOS, ARM64] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp + + gpu-vulkan: + runs-on: [self-hosted, macOS, ARM64] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp diff --git a/.github/workflows/build-sycl.yml b/.github/workflows/build-sycl.yml new file mode 100644 index 00000000000..57aa7cc4d95 --- /dev/null +++ b/.github/workflows/build-sycl.yml @@ -0,0 +1,132 @@ +name: CI (sycl) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-sycl.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + ubuntu-22-cmake-sycl: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + dwhisper_sycl: [ON] + dcmake_c_compiler: [icx] + dcmake_cxx_compiler: [icpx] + arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] + + continue-on-error: true + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: add oneAPI to apt + shell: bash + run: | + cd /tmp + wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" + + - name: install oneAPI dpcpp compiler + shell: bash + run: | + sudo apt update + sudo apt install intel-oneapi-compiler-dpcpp-cpp git + + - name: install oneAPI MKL library + shell: bash + run: | + sudo apt install intel-oneapi-mkl-devel git + + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Build + id: cmake_build + run: | + source /opt/intel/oneapi/setvars.sh + mkdir build + cd build + cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. + cmake --build . --config Release -j $(nproc) + + ubuntu-22-cmake-sycl-fp16: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + dwhisper_sycl: [ON] + dcmake_c_compiler: [icx] + dcmake_cxx_compiler: [icpx] + arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] + + continue-on-error: true + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: add oneAPI to apt + shell: bash + run: | + cd /tmp + wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" + + - name: install oneAPI dpcpp compiler + shell: bash + run: | + sudo apt update + sudo apt install intel-oneapi-compiler-dpcpp-cpp git + + - name: install oneAPI MKL library + shell: bash + run: | + sudo apt install intel-oneapi-mkl-devel + + - name: Clone + id: checkout + uses: actions/checkout@v6 + + - name: Build + id: cmake_build + run: | + source /opt/intel/oneapi/setvars.sh + mkdir build + cd build + cmake -DGGML_SYCL_F16=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. + cmake --build . --config Release -j $(nproc) diff --git a/.github/workflows/build-vad.yml b/.github/workflows/build-vad.yml new file mode 100644 index 00000000000..71e910a3fcb --- /dev/null +++ b/.github/workflows/build-vad.yml @@ -0,0 +1,43 @@ +name: CI (vad) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-vad.yml', + '**/CMakeLists.txt', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + vad: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v6 + + - name: Build + shell: bash + run: | + cmake -B build + cmake --build build --config Release + + - name: Test + shell: bash + run: | + ctest -R ^test-vad$ --test-dir build --output-on-failure -VV diff --git a/.github/workflows/build-wasm.yml b/.github/workflows/build-wasm.yml new file mode 100644 index 00000000000..42a9401af3c --- /dev/null +++ b/.github/workflows/build-wasm.yml @@ -0,0 +1,51 @@ +name: CI (wasm) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-wasm.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + emscripten: + runs-on: ubuntu-22.04 + + strategy: + matrix: + build: [Release] + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Setup emsdk + uses: mymindstorm/setup-emsdk@v14 + + - name: Verify + run: emcc -v + + - name: Build + run: | + emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }} + make diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml new file mode 100644 index 00000000000..cd1591f0132 --- /dev/null +++ b/.github/workflows/build-windows.yml @@ -0,0 +1,76 @@ +name: CI (windows) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-windows.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + windows-msys2: + runs-on: windows-latest + + strategy: + fail-fast: false + matrix: + include: + - { sys: UCRT64, env: ucrt-x86_64, build: Release } + - { sys: CLANG64, env: clang-x86_64, build: Release } + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Setup ${{ matrix.sys }} + uses: msys2/setup-msys2@v2 + with: + update: true + msystem: ${{matrix.sys}} + install: >- + base-devel + git + mingw-w64-${{matrix.env}}-toolchain + mingw-w64-${{matrix.env}}-cmake + mingw-w64-${{matrix.env}}-SDL2 + mingw-w64-${{matrix.env}}-openblas + + - name: Build using CMake + shell: msys2 {0} + run: | + cmake -B build -DWHISPER_SDL2=ON + cmake --build build --config ${{ matrix.build }} -j $(nproc) + + - name: Clean after building using CMake + shell: msys2 {0} + run: | + rm -rf build + + - name: Build using CMake w/ OpenBLAS + shell: msys2 {0} + run: | + cmake -B build -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS + cmake --build build --config ${{ matrix.build }} -j $(nproc) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml deleted file mode 100644 index b7badd51041..00000000000 --- a/.github/workflows/build.yml +++ /dev/null @@ -1,1573 +0,0 @@ -name: CI - -on: - push: - branches: - - master - tags: - - 'v*' - paths: ['.github/workflows/build.yml', - '**/CMakeLists.txt', - '**/Makefile', - '**/*.mk', - '**/*.cmake', - '**/*.in', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp', - '**/*.cu', - '**/*.cuh', - '**/*.cl', - '**/*.swift', - '**/*.m', - '**/*.mm', - '**/*.metal', - '**/*.comp', - '**/*.java'] - - pull_request: - types: [opened, synchronize, reopened] - paths-ignore: - - 'bindings/ruby/**' # handled by bindings-ruby.yml - - 'bindings/go/**' # handled by bindings-go.yml - - 'examples/addon.node/**' # handled by examples.yml - workflow_dispatch: - inputs: - create_release: - description: 'Create new release' - required: true - type: boolean - pre_release_tag: - description: 'Pre-release tag name' - required: false - type: string - run_type: - description: 'Workflow type to run' - required: true - type: choice - options: - - full-ci - - release-only - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -permissions: - contents: write # for creating release - -env: - BRANCH_NAME: ${{ github.head_ref || github.ref_name }} - ubuntu_image: "ubuntu:22.04" - VCPKG_BINARY_SOURCES: "clear;x-gha,readwrite" - -jobs: - determine-tag: - runs-on: ubuntu-latest - outputs: - tag_name: ${{ steps.tag.outputs.name }} - should_release: ${{ steps.tag.outputs.should_release }} - - steps: - - name: Checkout with full history - uses: actions/checkout@v6 - with: - fetch-depth: 0 - - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER=$(git rev-list --count HEAD) - SHORT_HASH=$(git rev-parse --short=7 HEAD) - CUSTOM_TAG="${{ github.event.inputs.pre_release_tag }}" - SHOULD_RELEASE="false" - - echo "Raw values:" - echo "BUILD_NUMBER: $BUILD_NUMBER" - echo "SHORT_HASH: $SHORT_HASH" - echo "BRANCH_NAME: ${{ env.BRANCH_NAME }}" - echo "CUSTOM_TAG: $CUSTOM_TAG" - - if [[ "${{ github.ref_type }}" == "tag" ]]; then - echo "Using pushed tag name" - TAG_NAME="${{ github.ref_name }}" - SHOULD_RELEASE="true" - elif [[ -n "$CUSTOM_TAG" ]]; then - echo "Using custom tag" - TAG_NAME="${CUSTOM_TAG}" - SHOULD_RELEASE="true" - elif [[ "${{ github.event.inputs.create_release }}" == "true" ]]; then - echo "Manual release requested" - SHOULD_RELEASE="true" - TAG_NAME="b${BUILD_NUMBER}" - elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "Using master branch format" - TAG_NAME="b${BUILD_NUMBER}" - SHOULD_RELEASE="false" - else - echo "Using non-master branch format" - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" - SHOULD_RELEASE="false" - fi - - echo "Final tag name: $TAG_NAME" - echo "Should release: $SHOULD_RELEASE" - echo "name=$TAG_NAME" >> $GITHUB_OUTPUT - echo "should_release=$SHOULD_RELEASE" >> $GITHUB_OUTPUT - - - ubuntu-22: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - arch: [linux/amd64, linux/ppc64le] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y build-essential libsdl2-dev cmake git - cmake -B build - cmake --build build --config Release -j $(nproc)' - - ubuntu-22-arm64: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04-arm - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y build-essential libsdl2-dev cmake git - - - name: Build - run: | - cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a - cmake --build build --config Release -j $(nproc) - - ubuntu-22-arm-v7: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - arch: [linux/arm/v7] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y build-essential libsdl2-dev cmake git - cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp - cmake --build build --config Release -j $(nproc)' - - macOS-latest: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: macOS-latest - - strategy: - matrix: - destination: ['generic/platform=macOS', 'generic/platform=iOS', 'generic/platform=tvOS'] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: hendrikmuhs/ccache-action@v1.2.16 - with: - key: macOS-latest-swift - evict-old-files: 1d - - - name: Dependencies - run: | - brew update - cmake --version - brew install sdl2 - - - name: Build - run: | - sysctl -a - cmake -B build -G Xcode \ - -DGGML_METAL_USE_BF16=ON \ - -DGGML_METAL_EMBED_LIBRARY=ON \ - -DWHISPER_BUILD_EXAMPLES=OFF \ - -DWHISPER_BUILD_TESTS=OFF \ - -DWHISPER_BUILD_SERVER=OFF \ - -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" - cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) - - -# freeBSD-latest: -# runs-on: macos-13 -# -# steps: -# - name: Clone -# uses: actions/checkout@v6 -# -# - name: Build -# uses: cross-platform-actions/action@v0.27.0 -# with: -# operating_system: freebsd -# version: '14.2' -# run: | -# sudo pkg update -# sudo pkg install -y gmake sdl2 cmake git -# cmake -B build -# cmake --build build --config Release - - ubuntu-22-gcc: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - arch: [linux/amd64, linux/ppc64le] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y build-essential cmake libsdl2-dev git - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} - make - ctest -L gh --output-on-failure' - - ubuntu-22-gcc-arm64: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04-arm - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y build-essential cmake libsdl2-dev git - - - name: Configure CMake - run: | - cmake . \ - -DWHISPER_SDL2=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ - -DGGML_NATIVE=OFF \ - -DGGML_CPU_ARM_ARCH=armv8-a - - - name: Build and Test - run: | - make - ctest -L gh --output-on-failure - - ubuntu-22-gcc-arm-v7: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - arch: [linux/arm/v7] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y build-essential cmake libsdl2-dev git - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp - make - ctest -L gh --output-on-failure' - - ubuntu-22-clang: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - #arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] - # TODO: arm/v7 disabled due to clang bug - # https://github.com/ggerganov/whisper.cpp/actions/runs/9657764109/job/26637633042?pr=2256#step:4:1990 - arch: [linux/amd64, linux/ppc64le] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y clang build-essential cmake libsdl2-dev git - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang - make - ctest -L gh --output-on-failure' - - ubuntu-22-clang-arm64: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04-arm - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y clang build-essential cmake libsdl2-dev git - - - name: Build and Test - run: | - cmake . -DWHISPER_SDL2=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ - -DCMAKE_CXX_COMPILER=clang++ \ - -DCMAKE_C_COMPILER=clang \ - -DGGML_NATIVE=OFF \ - -DGGML_CPU_ARM_ARCH=armv8-a - make - ctest -L gh --output-on-failure - - ubuntu-22-gcc-sanitized: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - sanitizer: [ADDRESS, THREAD, UNDEFINED] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get install -y build-essential cmake git - - - name: Build and Test - run: | - cmake . -DCMAKE_BUILD_TYPE=Debug \ - -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ - -DGGML_OPENMP=OFF - make - ctest -L gh --output-on-failure - - ubuntu-22-cmake-sycl: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - dwhisper_sycl: [ON] - dcmake_c_compiler: [icx] - dcmake_cxx_compiler: [icpx] - arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] - - continue-on-error: true - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: add oneAPI to apt - shell: bash - run: | - cd /tmp - wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" - - - name: install oneAPI dpcpp compiler - shell: bash - run: | - sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp git - - - name: install oneAPI MKL library - shell: bash - run: | - sudo apt install intel-oneapi-mkl-devel git - - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Build - id: cmake_build - run: | - source /opt/intel/oneapi/setvars.sh - mkdir build - cd build - cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. - cmake --build . --config Release -j $(nproc) - - ubuntu-22-cmake-sycl-fp16: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - dwhisper_sycl: [ON] - dcmake_c_compiler: [icx] - dcmake_cxx_compiler: [icpx] - arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] - - continue-on-error: true - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: add oneAPI to apt - shell: bash - run: | - cd /tmp - wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" - - - name: install oneAPI dpcpp compiler - shell: bash - run: | - sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp git - - - name: install oneAPI MKL library - shell: bash - run: | - sudo apt install intel-oneapi-mkl-devel - - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Build - id: cmake_build - run: | - source /opt/intel/oneapi/setvars.sh - mkdir build - cd build - cmake -DGGML_SYCL_F16=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. - cmake --build . --config Release -j $(nproc) - - windows-msys2: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: windows-latest - - strategy: - fail-fast: false - matrix: - include: - - { sys: UCRT64, env: ucrt-x86_64, build: Release } - - { sys: CLANG64, env: clang-x86_64, build: Release } - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Setup ${{ matrix.sys }} - uses: msys2/setup-msys2@v2 - with: - update: true - msystem: ${{matrix.sys}} - install: >- - base-devel - git - mingw-w64-${{matrix.env}}-toolchain - mingw-w64-${{matrix.env}}-cmake - mingw-w64-${{matrix.env}}-SDL2 - mingw-w64-${{matrix.env}}-openblas - - - name: Build using CMake - shell: msys2 {0} - run: | - cmake -B build -DWHISPER_SDL2=ON - cmake --build build --config ${{ matrix.build }} -j $(nproc) - - - name: Clean after building using CMake - shell: msys2 {0} - run: | - rm -rf build - - - name: Build using CMake w/ OpenBLAS - shell: msys2 {0} - run: | - cmake -B build -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS - cmake --build build --config ${{ matrix.build }} -j $(nproc) - - windows: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: windows-latest - needs: determine-tag - - strategy: - matrix: - build: [Release] - arch: [Win32, x64] - sdl2: [ON] - include: - - arch: Win32 - s2arc: x86 - jnaPath: win32-x86 - - arch: x64 - s2arc: x64 - jnaPath: win32-x86-64 - - sdl2: ON - s2ver: 2.28.5 - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 - - - name: Fetch SDL2 and set SDL2_DIR - if: matrix.sdl2 == 'ON' - run: | - C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip - 7z x sdl2.zip - echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV - - - name: Configure - run: > - cmake -S . -B ./build -A ${{ matrix.arch }} - -DCMAKE_BUILD_TYPE=${{ matrix.build }} - -DBUILD_SHARED_LIBS=ON - -DWHISPER_SDL2=${{ matrix.sdl2 }} - -DGGML_NATIVE=OFF - -DGGML_BMI2=OFF - - - name: Build - run: | - cd ./build - msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} - - - name: Copy SDL2.dll - if: matrix.sdl2 == 'ON' - run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} - - - name: Upload SDL2.dll - if: matrix.sdl2 == 'ON' - uses: actions/upload-artifact@v6 - with: - name: ${{ matrix.s2arc }}_SDL2.dll - path: build/bin/${{ matrix.build }}/SDL2.dll - - - name: Upload whisper dll - uses: actions/upload-artifact@v6 - with: - name: whisper_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/whisper.dll - - - name: Upload ggml dll - uses: actions/upload-artifact@v6 - with: - name: ggml_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/ggml.dll - overwrite: true - - - name: Upload ggml base dll - uses: actions/upload-artifact@v6 - with: - name: ggml_base_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/ggml-base.dll - - - name: Upload ggml cpu dll - uses: actions/upload-artifact@v6 - with: - name: ggml_cpu_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/ggml-cpu.dll - - - name: Pack bin artifacts - shell: pwsh - run: | - Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-bin-${{ matrix.arch }}.zip" - - - name: Upload binaries - if: matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 - with: - name: whisper-bin-${{ matrix.arch }}.zip - path: whisper-bin-${{ matrix.arch }}.zip - - windows-blas: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: windows-latest - - strategy: - matrix: - build: [Release] - arch: [Win32, x64] - blas: [ON] - sdl2: [ON] - blasver: [0.3.29] - include: - - arch: Win32 - s2arc: x86 - blasfile: x86 - - arch: x64 - s2arc: x64 - blasfile: x64_64 - - sdl2: ON - s2ver: 2.28.5 - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v8 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 - - - name: Install OpenBLAS and pkgconfiglite - if: matrix.blas == 'ON' - run: | - Invoke-WebRequest "https://github.com/OpenMathLib/OpenBLAS/releases/download/v${{matrix.blasver}}/OpenBLAS-${{matrix.blasver}}_${{matrix.blasfile}}.zip" -OutFile "OpenBLAS-${{matrix.blasver}}.zip" - Expand-Archive "OpenBLAS-${{matrix.blasver}}.zip" -DestinationPath "OpenBLAS-${{matrix.blasver}}" - choco install pkgconfiglite - - - name: Fetch SDL2 and set SDL2_DIR - if: matrix.sdl2 == 'ON' - run: | - C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip - 7z x sdl2.zip - echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV - - - name: Configure - run: > - cmake -S . -B ./build -A ${{ matrix.arch }} - -DCMAKE_TOOLCHAIN_FILE="$env:VCPKG_INSTALLATION_ROOT/scripts/buildsystems/vcpkg.cmake" - -DCMAKE_BUILD_TYPE=${{ matrix.build }} - -DGGML_BLAS=${{ matrix.blas }} - -DGGML_BLAS_VENDOR=OpenBLAS - -DBLAS_LIBRARIES="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/lib/libopenblas.lib" - -DBLAS_INCLUDE_DIRS="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/include" - -DWHISPER_SDL2=${{ matrix.sdl2 }} - - - name: Build - run: | - cd ./build - msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} - - - name: Copy openblas.dll - if: matrix.blas == 'ON' - run: copy "$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/bin/libopenblas.dll" build/bin/${{ matrix.build }} - - - name: Copy SDL2.dll - if: matrix.sdl2 == 'ON' - run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} - - - name: Pack bin artifacts - shell: pwsh - run: | - Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-blas-bin-${{ matrix.arch }}.zip" - - - name: Upload binaries - if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 - with: - name: whisper-blas-bin-${{ matrix.arch }}.zip - path: whisper-blas-bin-${{ matrix.arch }}.zip - - windows-cublas: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: windows-2022 - needs: determine-tag - strategy: - fail-fast: false - matrix: - build: [Release] - arch: [x64] - cublas: [ON] - sdl2: [ON] - cuda-toolkit: [12.4.0, 11.8.0] - include: - - arch: x64 - sdl2: ON - sdl2_ver: 2.28.5 - steps: - - name: Clone repository - uses: actions/checkout@v6 - - - name: Install Ninja - id: install_ninja - run: | - choco install ninja - - - name: Install ccache - uses: hendrikmuhs/ccache-action@v1.2.16 - with: - key: ${{ github.job }}-${{ matrix.cuda-toolkit }}-${{ matrix.build }} - variant: sccache - evict-old-files: 5d - - - name: Install Cuda Toolkit 11.8.0 - if: ${{ matrix.cuda-toolkit == '11.8.0' }} - run: | - $CUDA_VERSION = ${{ matrix.cuda-toolkit }} - $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" - $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" - - # Components versions - $CUDART_VER = "11.8.89" - $NVCC_VER = "11.8.89" - $NVRTC_VER = "11.8.89" - $CUBLAS_VER = "11.8.1.74" - $NVTX_VER = "11.8.86" - $VS_VER = "11.8.86" - $NVPROF_VER = "11.8.87" - $CCCL_VER = "11.8.89" - - # Create the directory where the CUDA Toolkit will be installed - mkdir -p $CUDA_TOOLKIT_DIR - - # Install unzip to extract the downloaded files - choco install unzip -y - - # Download all the required components - curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip" - - # Extract all the downloaded files to the CUDA Toolkit directory - unzip '*.zip' -d $CUDA_TOOLKIT_DIR - - # Copy all the extracted files to the main CUDA Toolkit directory - xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - - # Visual Studio integration - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y - - # Set environment variables - echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - echo "CUDA_PATH_V11_8=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - - - name: Install Cuda Toolkit 12.4.0 - if: ${{ matrix.cuda-toolkit == '12.4.0' }} - run: | - $CUDA_VERSION = ${{ matrix.cuda-toolkit }} - $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" - $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" - - # Components versions - $CUDART_VER = "12.4.127" - $NVCC_VER = "12.4.131" - $NVRTC_VER = "12.4.127" - $CUBLAS_VER = "12.4.5.8" - $NVTX_VER = "12.4.127" - $PROFILER_VER = "12.4.127" - $VS_VER = "12.4.127" - $NVPROF_VER = "12.4.128" - $CCCL_VER = "12.4.127" - - # Create the directory where the CUDA Toolkit will be installed - mkdir -p $CUDA_TOOLKIT_DIR - - # Install unzip to extract the downloaded files - choco install unzip -y - - # Download all the required components - curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip" - - # Extract all the downloaded files to the CUDA Toolkit directory - unzip -q '*.zip' -d $CUDA_TOOLKIT_DIR - - # Copy all the extracted files to the main CUDA Toolkit directory - xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - - # Visual Studio integration - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y - - # Set environment variables - echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - echo "CUDA_PATH_V12_2=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 - - - name: Install 7-Zip - run: choco install 7zip -y - - - name: Fetch SDL2 and set SDL2_DIR - if: matrix.sdl2 == 'ON' - run: | - Invoke-WebRequest -Uri https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.sdl2_ver }}/SDL2-devel-${{ matrix.sdl2_ver }}-VC.zip -OutFile sdl2.zip - 7z x sdl2.zip - echo "SDL2_DIR=${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" | Out-File -FilePath $env:GITHUB_ENV -Append - echo "${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" > SDL2_PATH.txt - - - name: Install cmake - run: choco install cmake - - - name: Build Project - shell: cmd - run: | - call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" - cmake --version - where cmake - if "${{ matrix.cuda-toolkit }}" == "11.8.0" ( - set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR - ) else ( - set CUDA_FLAGS= - ) - cmake -S . -B build -G "Ninja Multi-Config" ^ - -DCMAKE_BUILD_TYPE=${{ matrix.build }} ^ - -DGGML_CUDA=${{ matrix.cublas }} ^ - -DWHISPER_SDL2=${{ matrix.sdl2 }} ^ - -DSDL2_DIR="%SDL2_DIR%" ^ - -DCMAKE_POLICY_VERSION_MINIMUM=3.5 ^ - -DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%" - set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 - cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS% - - - name: Check sccache status after build - run: | - sccache --show-stats - - - name: Copy CUDA DLLs - run: | - Get-ChildItem "$env:CUDA_PATH\bin\" -Filter "*.dll" | - Copy-Item -Destination "build/bin/${{ matrix.build }}" - - - name: Copy SDL2.dll - if: matrix.sdl2 == 'ON' - run: copy "$env:SDL2_DIR/../lib/${{ matrix.arch }}/SDL2.dll" build/bin/${{ matrix.build }} - - - name: Pack bin artifacts - shell: pwsh - run: | - Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip" - - - name: Upload binaries - if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 - with: - name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip - path: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip - - emscripten: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - matrix: - build: [Release] - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Setup emsdk - uses: mymindstorm/setup-emsdk@v14 - - - name: Verify - run: emcc -v - - - name: Build - run: | - emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }} - make - - ios-xcode-build: - runs-on: macos-latest - needs: determine-tag - - strategy: - matrix: - build: [Release] - - steps: - - name: Checkout code - uses: actions/checkout@v6 - - - name: Configure - run: | - cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin - mkdir models/ggml-base.en-encoder.mlmodelc - - - name: Build - id: cmake_build - run: | - sysctl -a - mkdir build - cd build - cmake -G Xcode .. \ - -DGGML_METAL_USE_BF16=ON \ - -DGGML_METAL_EMBED_LIBRARY=ON \ - -DWHISPER_BUILD_EXAMPLES=OFF \ - -DWHISPER_BUILD_TESTS=OFF \ - -DWHISPER_BUILD_SERVER=OFF \ - -DCMAKE_SYSTEM_NAME=iOS \ - -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ - -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml - cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO - - - name: xcodebuild for swift package - id: xcodebuild - run: | - ./build-xcframework.sh - - - name: Build objc example - run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGN_IDENTITY="" CODE_SIGNING_REQUIRED=NO FRAMEWORK_FOLDER_PATH=./build-ios build - - - name: Build swiftui example - run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build - - - name: Pack artifacts - id: pack_artifacts - run: | - zip --symlinks -r whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip build-apple/whisper.xcframework - - - name: Upload artifacts - if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 - with: - path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip - name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip - - android: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - steps: - - name: Clone - uses: actions/checkout@v6 - with: - path: whisper - - - name: Install Java - uses: actions/setup-java@v5 - with: - distribution: zulu - java-version: 21 - - - name: Setup Android SDK - uses: android-actions/setup-android@v3 - - - name: Build - run: | - cd whisper/examples/whisper.android - ./gradlew assembleRelease --no-daemon - - - name: Build with external ggml - run: | - export PATH_TO_GGML=$PWD/ggml - cd whisper/examples/whisper.android - ./gradlew assembleRelease --no-daemon - - android_java: - runs-on: ubuntu-22.04 - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: set up JDK 11 - uses: actions/setup-java@v5 - with: - java-version: '11' - distribution: 'temurin' - cache: gradle - - - name: Setup Android SDK - uses: android-actions/setup-android@v3 - with: - cmdline-tools-version: 9.0 - - - name: Build - run: | - cd examples/whisper.android.java - chmod +x ./gradlew - ./gradlew assembleRelease - - bindings-java: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - needs: ['windows'] - runs-on: windows-latest - steps: - - uses: actions/checkout@v6 - - - name: Install Java - uses: actions/setup-java@v5 - with: - distribution: zulu - java-version: 20 - - - name: Download Whisper Windows lib - uses: actions/download-artifact@v7 - with: - name: whisper_x64.dll - - - name: Download GGML Windows lib - uses: actions/download-artifact@v7 - with: - name: ggml_x64.dll - - - name: Download GGML Base Windows lib - uses: actions/download-artifact@v7 - with: - name: ggml_base_x64.dll - - - name: Download GGML CPU Windows lib - uses: actions/download-artifact@v7 - with: - name: ggml_cpu_x64.dll - - - name: Download SDL2.dll - uses: actions/download-artifact@v7 - with: - name: x64_SDL2.dll - - - name: List downloaded files - shell: pwsh - run: | - Get-ChildItem -Path "." -Recurse -Filter "*.dll" - - - name: Move DLL to correct location - shell: pwsh - run: | - New-Item -Path "build\bin\Release" -ItemType Directory -Force - - Copy-Item -Path "whisper.dll" -Destination "build\bin\Release\whisper.dll" -Force - Write-Host "Copied whisper.dll to build\bin\Release\whisper.dll directory" - - Copy-Item -Path "ggml.dll" -Destination "build\bin\Release\ggml.dll" -Force - Write-Host "Copied ggml.dll to build\bin\Release\ggml.dll directory" - - Copy-Item -Path "ggml-base.dll" -Destination "build\bin\Release\ggml-base.dll" -Force - Write-Host "Copied ggml-base.dll to build\bin\Release\ggml-base.dll directory" - - Copy-Item -Path "ggml-cpu.dll" -Destination "build\bin\Release\ggml-cpu.dll" -Force - Write-Host "Copied ggml-cpu.dll to build\bin\Release\ggml-cpu.dll directory" - - Copy-Item -Path "SDL2.dll" -Destination "build\bin\Release\SDL2.dll" -Force - Write-Host "Copied SDL2.dll to build\bin\Release\SDL2.dll directory" - - - name: List build release files - shell: pwsh - run: | - Get-ChildItem -Path "build\Release" -Recurse -Filter "*.dll" - - - name: Build - run: | - models\download-ggml-model.cmd tiny.en models/ - cd bindings/java - chmod +x ./gradlew - ./gradlew build --info - - - name: Pack jar artifacts - shell: pwsh - run: | - Compress-Archive -Path "bindings/java/build/libs/whispercpp-*.jar" -DestinationPath "whispercpp.jar.zip" - - - name: Upload jar - uses: actions/upload-artifact@v6 - with: - name: whispercpp.jar.zip - path: whispercpp.jar.zip - -# - name: Publish package -# if: ${{ github.ref == 'refs/heads/master' }} -# uses: gradle/gradle-build-action@v2.4.2 -# with: -# arguments: publish -# build-root-directory: bindings/java -# env: -# MAVEN_USERNAME: ${{ secrets.JIRA_USER }} -# MAVEN_PASSWORD: ${{ secrets.JIRA_PASS }} -# PGP_SECRET: ${{ secrets.GPG_PRIVATE_KEY }} -# PGP_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} - - quantize: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - steps: - - name: Clone - uses: actions/checkout@v6 - - - name: Test quantize - run: | - ./models/download-ggml-model.sh tiny.en - cmake -B build - cmake --build build --config Release - ./build/bin/whisper-quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0 - - release: - if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' || startsWith(github.ref, 'refs/tags/v') }} - - runs-on: ubuntu-latest - - needs: - - determine-tag - - ios-xcode-build - - windows - - windows-blas - - windows-cublas - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - with: - fetch-depth: 0 - - - name: ccache - uses: hendrikmuhs/ccache-action@v1.2.16 - with: - key: release - evict-old-files: 1d - - # Downloads all the artifacts from the previous jobs - - name: Download artifacts - id: download-artifact - uses: actions/download-artifact@v7 - with: - path: ./artifact - - - name: Move artifacts - id: move_artifacts - run: mkdir -p ./artifact/release && mv ./artifact/*/*.zip ./artifact/release - - - name: Create release - id: create_release - uses: ggml-org/action-create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: ${{ needs.determine-tag.outputs.tag_name }} - prerelease: ${{ github.event.inputs.pre_release_tag != '' }} - draft: true - - - name: Upload release - id: upload_release - uses: actions/github-script@v3 - with: - github-token: ${{secrets.GITHUB_TOKEN}} - script: | - const path = require('path'); - const fs = require('fs'); - const release_id = '${{ steps.create_release.outputs.id }}'; - for (let file of await fs.readdirSync('./artifact/release')) { - if (path.extname(file) === '.zip') { - console.log('uploadReleaseAsset', file); - await github.repos.uploadReleaseAsset({ - owner: context.repo.owner, - repo: context.repo.repo, - release_id: release_id, - name: file, - data: await fs.readFileSync(`./artifact/release/${file}`) - }); - } - } - - coreml-base-en: - if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') || - github.event.inputs.create_release == 'true' || - github.event.inputs.pre_release_tag != '' || - startsWith(github.ref, 'refs/tags/v') }} - runs-on: macos-latest - needs: determine-tag - - steps: - - name: Checkout code - uses: actions/checkout@v6 - - - name: Set environment variables - id: set_vars - run: | - echo "MODEL_NAME=base.en" >> $GITHUB_ENV - echo "GEN_MODEL_NAME=whisper-${{ needs.determine-tag.outputs.tag_name }}-ggml-base.en-encoder.mlmodelc" >> $GITHUB_ENV - - - name: Download model - run: | - ./models/download-ggml-model.sh ${{ env.MODEL_NAME }} - - - name: Generate CoreML model - run: | - python3.11 -m venv venv - source venv/bin/activate - pip install ane_transformers openai-whisper coremltools - ./models/generate-coreml-model.sh ${{ env.MODEL_NAME }} - - vad: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-latest - - steps: - - name: Checkout - uses: actions/checkout@v6 - - - name: Build - shell: bash - run: | - cmake -B build - cmake --build build --config Release - - - name: Test - shell: bash - run: | - ctest -R ^test-vad$ --test-dir build --output-on-failure -VV - -# TODO: simplify the following workflows using a matrix - ggml-ci-x64-cpu-low-perf: - runs-on: ubuntu-22.04 - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-x64-cpu-low-perf - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-arm64-cpu-low-perf: - runs-on: ubuntu-22.04-arm - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-arm64-cpu-low-perf - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-x64-cpu-high-perf: - runs-on: ubuntu-22.04 - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-x64-cpu-high-perf - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-arm64-cpu-high-perf: - runs-on: ubuntu-22.04-arm - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-arm64-cpu-high-perf - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_SVE=1 GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-arm64-cpu-high-perf-sve: - runs-on: ubuntu-22.04-arm - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-arm64-cpu-high-perf-sve - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-x64-nvidia-cuda: - runs-on: [self-hosted, Linux, NVIDIA] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - nvidia-smi - GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - ggml-ci-x64-nvidia-vulkan-cm: - runs-on: [self-hosted, Linux, NVIDIA] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - ggml-ci-x64-nvidia-vulkan-cm2: - runs-on: [self-hosted, Linux, NVIDIA, COOPMAT2] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - #ggml-ci-x64-cpu-amx: - # runs-on: [self-hosted, Linux, X64, CPU, AMX] - - # steps: - # - name: Clone - # id: checkout - # uses: actions/checkout@v6 - - # - name: Test - # id: ggml-ci - # run: | - # bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - ggml-ci-mac-metal: - runs-on: [self-hosted, macOS, ARM64] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - ggml-ci-mac-vulkan: - runs-on: [self-hosted, macOS, ARM64] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v6 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index df3aa832c2e..eaa4fe4df61 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -42,6 +42,8 @@ jobs: run: npx cmake-js compile -T addon.node -B Release - name: Download test model + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | bash ./models/download-ggml-model.sh base.en - name: Test diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000000..2ba8b45093b --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,649 @@ +name: Release + +on: + workflow_dispatch: + inputs: + create_release: + description: 'Create new release' + required: true + type: boolean + pre_release_tag: + description: 'Pre-release tag name' + required: false + type: string + + push: + branches: + - master + tags: + - 'v*' + +env: + BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + VCPKG_BINARY_SOURCES: "clear;x-gha,readwrite" + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +permissions: + contents: write # for creating release + +jobs: + determine-tag: + runs-on: ubuntu-latest + outputs: + tag_name: ${{ steps.tag.outputs.name }} + should_release: ${{ steps.tag.outputs.should_release }} + + steps: + - name: Checkout with full history + uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Determine tag name + id: tag + shell: bash + run: | + BUILD_NUMBER=$(git rev-list --count HEAD) + SHORT_HASH=$(git rev-parse --short=7 HEAD) + CUSTOM_TAG="${{ github.event.inputs.pre_release_tag }}" + SHOULD_RELEASE="false" + + echo "Raw values:" + echo "BUILD_NUMBER: $BUILD_NUMBER" + echo "SHORT_HASH: $SHORT_HASH" + echo "BRANCH_NAME: ${{ env.BRANCH_NAME }}" + echo "CUSTOM_TAG: $CUSTOM_TAG" + + if [[ "${{ github.ref_type }}" == "tag" ]]; then + echo "Using pushed tag name" + TAG_NAME="${{ github.ref_name }}" + SHOULD_RELEASE="true" + elif [[ -n "$CUSTOM_TAG" ]]; then + echo "Using custom tag" + TAG_NAME="${CUSTOM_TAG}" + SHOULD_RELEASE="true" + elif [[ "${{ github.event.inputs.create_release }}" == "true" ]]; then + echo "Manual release requested" + SHOULD_RELEASE="true" + TAG_NAME="b${BUILD_NUMBER}" + elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then + echo "Using master branch format" + TAG_NAME="b${BUILD_NUMBER}" + SHOULD_RELEASE="false" + else + echo "Using non-master branch format" + SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') + TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" + SHOULD_RELEASE="false" + fi + + echo "Final tag name: $TAG_NAME" + echo "Should release: $SHOULD_RELEASE" + echo "name=$TAG_NAME" >> $GITHUB_OUTPUT + echo "should_release=$SHOULD_RELEASE" >> $GITHUB_OUTPUT + + ubuntu-cpu: + runs-on: ${{ matrix.os }} + needs: determine-tag + if: ${{ needs.determine-tag.outputs.should_release == 'true' }} + + strategy: + matrix: + include: + - build: x64 + os: ubuntu-22.04 + - build: arm64 + os: ubuntu-22.04-arm + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: release-${{ matrix.os }}-cpu + evict-old-files: 1d + + - name: Dependencies + run: | + sudo apt-get update + sudo apt-get install -y build-essential cmake + + - name: Build + run: | + cmake -B build \ + -DCMAKE_BUILD_TYPE=Release \ + -DBUILD_SHARED_LIBS=OFF \ + -DGGML_NATIVE=OFF \ + ${{ matrix.build == 'arm64' && '-DGGML_CPU_ARM_ARCH=armv8-a' || '' }} + cmake --build build --config Release -j $(nproc) + + - name: Pack artifacts + run: | + cp LICENSE ./build/bin/ + tar -czvf whisper-bin-ubuntu-${{ matrix.build }}.tar.gz \ + --transform "s,^\.,whisper-bin-ubuntu-${{ matrix.build }}," \ + -C ./build/bin . + + - name: Upload artifacts + uses: actions/upload-artifact@v6 + with: + path: whisper-bin-ubuntu-${{ matrix.build }}.tar.gz + name: whisper-bin-ubuntu-${{ matrix.build }}.tar.gz + + windows: + runs-on: windows-latest + needs: determine-tag + + strategy: + matrix: + build: [Release] + arch: [Win32, x64] + sdl2: [ON] + include: + - arch: Win32 + s2arc: x86 + jnaPath: win32-x86 + - arch: x64 + s2arc: x64 + jnaPath: win32-x86-64 + - sdl2: ON + s2ver: 2.28.5 + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Add msbuild to PATH + uses: microsoft/setup-msbuild@v2 + + - name: Fetch SDL2 and set SDL2_DIR + if: matrix.sdl2 == 'ON' + run: | + C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip + 7z x sdl2.zip + echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV + + - name: Configure + run: > + cmake -S . -B ./build -A ${{ matrix.arch }} + -DCMAKE_BUILD_TYPE=${{ matrix.build }} + -DBUILD_SHARED_LIBS=ON + -DWHISPER_SDL2=${{ matrix.sdl2 }} + -DGGML_NATIVE=OFF + -DGGML_BMI2=OFF + + - name: Build + run: | + cd ./build + msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} + + - name: Copy SDL2.dll + if: matrix.sdl2 == 'ON' + run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} + + - name: Upload SDL2.dll + if: matrix.sdl2 == 'ON' + uses: actions/upload-artifact@v6 + with: + name: ${{ matrix.s2arc }}_SDL2.dll + path: build/bin/${{ matrix.build }}/SDL2.dll + + - name: Upload whisper dll + uses: actions/upload-artifact@v6 + with: + name: whisper_${{ matrix.arch }}.dll + path: build/bin/${{ matrix.build }}/whisper.dll + + - name: Upload ggml dll + uses: actions/upload-artifact@v6 + with: + name: ggml_${{ matrix.arch }}.dll + path: build/bin/${{ matrix.build }}/ggml.dll + overwrite: true + + - name: Upload ggml base dll + uses: actions/upload-artifact@v6 + with: + name: ggml_base_${{ matrix.arch }}.dll + path: build/bin/${{ matrix.build }}/ggml-base.dll + + - name: Upload ggml cpu dll + uses: actions/upload-artifact@v6 + with: + name: ggml_cpu_${{ matrix.arch }}.dll + path: build/bin/${{ matrix.build }}/ggml-cpu.dll + + - name: Pack bin artifacts + shell: pwsh + run: | + Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-bin-${{ matrix.arch }}.zip" + + - name: Upload binaries + if: matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} + uses: actions/upload-artifact@v6 + with: + name: whisper-bin-${{ matrix.arch }}.zip + path: whisper-bin-${{ matrix.arch }}.zip + + windows-blas: + runs-on: windows-latest + needs: determine-tag + + strategy: + matrix: + build: [Release] + arch: [Win32, x64] + blas: [ON] + sdl2: [ON] + blasver: [0.3.29] + include: + - arch: Win32 + s2arc: x86 + blasfile: x86 + - arch: x64 + s2arc: x64 + blasfile: x64_64 + - sdl2: ON + s2ver: 2.28.5 + + steps: + - name: Clone + uses: actions/checkout@v6 + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@v8 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Add msbuild to PATH + uses: microsoft/setup-msbuild@v2 + + - name: Install OpenBLAS and pkgconfiglite + if: matrix.blas == 'ON' + run: | + Invoke-WebRequest "https://github.com/OpenMathLib/OpenBLAS/releases/download/v${{matrix.blasver}}/OpenBLAS-${{matrix.blasver}}_${{matrix.blasfile}}.zip" -OutFile "OpenBLAS-${{matrix.blasver}}.zip" + Expand-Archive "OpenBLAS-${{matrix.blasver}}.zip" -DestinationPath "OpenBLAS-${{matrix.blasver}}" + choco install pkgconfiglite + + - name: Fetch SDL2 and set SDL2_DIR + if: matrix.sdl2 == 'ON' + run: | + C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip + 7z x sdl2.zip + echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV + + - name: Configure + run: > + cmake -S . -B ./build -A ${{ matrix.arch }} + -DCMAKE_TOOLCHAIN_FILE="$env:VCPKG_INSTALLATION_ROOT/scripts/buildsystems/vcpkg.cmake" + -DCMAKE_BUILD_TYPE=${{ matrix.build }} + -DGGML_BLAS=${{ matrix.blas }} + -DGGML_BLAS_VENDOR=OpenBLAS + -DBLAS_LIBRARIES="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/lib/libopenblas.lib" + -DBLAS_INCLUDE_DIRS="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/include" + -DWHISPER_SDL2=${{ matrix.sdl2 }} + + - name: Build + run: | + cd ./build + msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} + + - name: Copy openblas.dll + if: matrix.blas == 'ON' + run: copy "$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/bin/libopenblas.dll" build/bin/${{ matrix.build }} + + - name: Copy SDL2.dll + if: matrix.sdl2 == 'ON' + run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} + + - name: Pack bin artifacts + shell: pwsh + run: | + Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-blas-bin-${{ matrix.arch }}.zip" + + - name: Upload binaries + if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} + uses: actions/upload-artifact@v6 + with: + name: whisper-blas-bin-${{ matrix.arch }}.zip + path: whisper-blas-bin-${{ matrix.arch }}.zip + + windows-cublas: + runs-on: windows-2022 + needs: determine-tag + strategy: + fail-fast: false + matrix: + build: [Release] + arch: [x64] + cublas: [ON] + sdl2: [ON] + cuda-toolkit: [12.4.0, 11.8.0] + include: + - arch: x64 + sdl2: ON + sdl2_ver: 2.28.5 + steps: + - name: Clone repository + uses: actions/checkout@v6 + + - name: Install Ninja + id: install_ninja + run: | + choco install ninja + + - name: Install ccache + uses: hendrikmuhs/ccache-action@v1.2.16 + with: + key: ${{ github.job }}-${{ matrix.cuda-toolkit }}-${{ matrix.build }} + variant: sccache + evict-old-files: 5d + + - name: Install Cuda Toolkit 11.8.0 + if: ${{ matrix.cuda-toolkit == '11.8.0' }} + run: | + $CUDA_VERSION = ${{ matrix.cuda-toolkit }} + $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" + $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" + + # Components versions + $CUDART_VER = "11.8.89" + $NVCC_VER = "11.8.89" + $NVRTC_VER = "11.8.89" + $CUBLAS_VER = "11.8.1.74" + $NVTX_VER = "11.8.86" + $VS_VER = "11.8.86" + $NVPROF_VER = "11.8.87" + $CCCL_VER = "11.8.89" + + # Create the directory where the CUDA Toolkit will be installed + mkdir -p $CUDA_TOOLKIT_DIR + + # Install unzip to extract the downloaded files + choco install unzip -y + + # Download all the required components + curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip" + + # Extract all the downloaded files to the CUDA Toolkit directory + unzip '*.zip' -d $CUDA_TOOLKIT_DIR + + # Copy all the extracted files to the main CUDA Toolkit directory + xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + + # Visual Studio integration + xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y + + # Set environment variables + echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + echo "CUDA_PATH_V11_8=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + + - name: Install Cuda Toolkit 12.4.0 + if: ${{ matrix.cuda-toolkit == '12.4.0' }} + run: | + $CUDA_VERSION = ${{ matrix.cuda-toolkit }} + $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" + $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" + + # Components versions + $CUDART_VER = "12.4.127" + $NVCC_VER = "12.4.131" + $NVRTC_VER = "12.4.127" + $CUBLAS_VER = "12.4.5.8" + $NVTX_VER = "12.4.127" + $PROFILER_VER = "12.4.127" + $VS_VER = "12.4.127" + $NVPROF_VER = "12.4.128" + $CCCL_VER = "12.4.127" + + # Create the directory where the CUDA Toolkit will be installed + mkdir -p $CUDA_TOOLKIT_DIR + + # Install unzip to extract the downloaded files + choco install unzip -y + + # Download all the required components + curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip" + + # Extract all the downloaded files to the CUDA Toolkit directory + unzip -q '*.zip' -d $CUDA_TOOLKIT_DIR + + # Copy all the extracted files to the main CUDA Toolkit directory + xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + + # Visual Studio integration + xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y + + # Set environment variables + echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + echo "CUDA_PATH_V12_2=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + + - name: Add msbuild to PATH + uses: microsoft/setup-msbuild@v2 + + - name: Install 7-Zip + run: choco install 7zip -y + + - name: Fetch SDL2 and set SDL2_DIR + if: matrix.sdl2 == 'ON' + run: | + Invoke-WebRequest -Uri https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.sdl2_ver }}/SDL2-devel-${{ matrix.sdl2_ver }}-VC.zip -OutFile sdl2.zip + 7z x sdl2.zip + echo "SDL2_DIR=${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" > SDL2_PATH.txt + + - name: Install cmake + run: choco install cmake + + - name: Build Project + shell: cmd + run: | + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" + cmake --version + where cmake + if "${{ matrix.cuda-toolkit }}" == "11.8.0" ( + set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR + ) else ( + set CUDA_FLAGS= + ) + cmake -S . -B build -G "Ninja Multi-Config" ^ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} ^ + -DGGML_CUDA=${{ matrix.cublas }} ^ + -DWHISPER_SDL2=${{ matrix.sdl2 }} ^ + -DSDL2_DIR="%SDL2_DIR%" ^ + -DCMAKE_POLICY_VERSION_MINIMUM=3.5 ^ + -DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%" + set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 + cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS% + + - name: Check sccache status after build + run: | + sccache --show-stats + + - name: Copy CUDA DLLs + run: | + Get-ChildItem "$env:CUDA_PATH\bin\" -Filter "*.dll" | + Copy-Item -Destination "build/bin/${{ matrix.build }}" + + - name: Copy SDL2.dll + if: matrix.sdl2 == 'ON' + run: copy "$env:SDL2_DIR/../lib/${{ matrix.arch }}/SDL2.dll" build/bin/${{ matrix.build }} + + - name: Pack bin artifacts + shell: pwsh + run: | + Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip" + + - name: Upload binaries + if: ${{ needs.determine-tag.outputs.should_release }} + uses: actions/upload-artifact@v6 + with: + name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip + path: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip + + ios-xcode-build: + runs-on: macos-latest + needs: determine-tag + + strategy: + matrix: + build: [Release] + + steps: + - name: Checkout code + uses: actions/checkout@v6 + + - name: Configure + run: | + cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin + mkdir models/ggml-base.en-encoder.mlmodelc + + - name: Build + id: cmake_build + run: | + sysctl -a + mkdir build + cd build + cmake -G Xcode .. \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DWHISPER_BUILD_EXAMPLES=OFF \ + -DWHISPER_BUILD_TESTS=OFF \ + -DWHISPER_BUILD_SERVER=OFF \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ + -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml + cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO + + - name: xcodebuild for swift package + id: xcodebuild + run: | + ./build-xcframework.sh + + - name: Build objc example + run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGN_IDENTITY="" CODE_SIGNING_REQUIRED=NO FRAMEWORK_FOLDER_PATH=./build-ios build + + - name: Build swiftui example + run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build + + - name: Pack artifacts + id: pack_artifacts + run: | + zip --symlinks -r whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip build-apple/whisper.xcframework + + - name: Upload artifacts + if: ${{ needs.determine-tag.outputs.should_release }} + uses: actions/upload-artifact@v6 + with: + path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip + name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip + + release: + if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' || startsWith(github.ref, 'refs/tags/v') }} + + runs-on: ubuntu-latest + + needs: + - determine-tag + - ubuntu-cpu + - ios-xcode-build + - windows + - windows-blas + - windows-cublas + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: release + evict-old-files: 1d + + # Downloads all the artifacts from the previous jobs + - name: Download artifacts + id: download-artifact + uses: actions/download-artifact@v7 + with: + path: ./artifact + + - name: Move artifacts + id: move_artifacts + run: mkdir -p ./artifact/release && mv ./artifact/*/*.zip ./artifact/release && mv ./artifact/*/*.tar.gz ./artifact/release 2>/dev/null || true + + - name: Create release + id: create_release + uses: ggml-org/action-create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ needs.determine-tag.outputs.tag_name }} + prerelease: ${{ github.event.inputs.pre_release_tag != '' }} + draft: true + + - name: Upload release + id: upload_release + uses: actions/github-script@v3 + with: + github-token: ${{secrets.GITHUB_TOKEN}} + script: | + const path = require('path'); + const fs = require('fs'); + const release_id = '${{ steps.create_release.outputs.id }}'; + for (let file of await fs.readdirSync('./artifact/release')) { + if (path.extname(file) === '.zip' || file.endsWith('.tar.gz')) { + console.log('uploadReleaseAsset', file); + await github.repos.uploadReleaseAsset({ + owner: context.repo.owner, + repo: context.repo.repo, + release_id: release_id, + name: file, + data: await fs.readFileSync(`./artifact/release/${file}`) + }); + } + } diff --git a/ci/run.sh b/ci/run.sh index b03fdf1c6b1..dca4476a0fa 100644 --- a/ci/run.sh +++ b/ci/run.sh @@ -151,8 +151,15 @@ function gg_download_model { local cwd=`pwd` mkdir -p "$MNT/models" cd "$MNT/models" + set -x bash "$cwd/models/download-ggml-model.sh" ${model_name} . + local download_status=$? + set +x cd "$cwd" + if [ $download_status -ne 0 ]; then + echo "Error: failed to download model ${model_name}" + ret=1 + fi fi } diff --git a/models/download-ggml-model.sh b/models/download-ggml-model.sh index f1394e98484..0539c8afb3d 100755 --- a/models/download-ggml-model.sh +++ b/models/download-ggml-model.sh @@ -120,7 +120,13 @@ fi if [ -x "$(command -v wget2)" ]; then wget2 --no-config --progress bar -O ggml-"$model".bin $src/$pfx-"$model".bin elif [ -x "$(command -v curl)" ]; then - curl -L --output ggml-"$model".bin $src/$pfx-"$model".bin + curl -L --fail \ + --retry 5 \ + --retry-delay 5 \ + --retry-all-errors \ + --retry-connrefused \ + ${HF_TOKEN:+--header "Authorization: Bearer $HF_TOKEN"} \ + --output ggml-"$model".bin $src/$pfx-"$model".bin elif [ -x "$(command -v wget)" ]; then wget --no-config --quiet --show-progress -O ggml-"$model".bin $src/$pfx-"$model".bin else From 12d1828837f8ca3ea2b3c94180bcab733fb76092 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Thu, 4 Jun 2026 10:30:48 +0200 Subject: [PATCH 747/831] ci : only publish/push docker images daily (#3854) This commit updates the docker workflow to be triggered on a schedule or manually. --- .github/workflows/docker.yml | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 9e07f7b2292..51724976e0a 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,9 +1,10 @@ name: Publish Docker image on: - push: - branches: - - master + workflow_dispatch: # allows manual triggering + schedule: + # Rebuild daily rather than on every push because it is expensive + - cron: '12 4 * * *' jobs: push_to_registry: @@ -57,16 +58,14 @@ jobs: id: tags run: | TAGS="ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}" - if [ "${{ github.event_name }}" == "push" ]; then - TAGS="$TAGS,ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}" - fi + TAGS="$TAGS,ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}" echo "tags=$TAGS" >> $GITHUB_OUTPUT - name: Build and push Docker image (tagged) uses: docker/build-push-action@v6 with: context: . - push: ${{ github.event_name == 'push' }} + push: true platforms: ${{ matrix.config.platform }} tags: ${{ steps.tags.outputs.tags }} file: ${{ matrix.config.dockerfile }} From 9302c060f0d8178a01aa6b36e9673032fbc9aff8 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Thu, 4 Jun 2026 11:37:22 +0200 Subject: [PATCH 748/831] ci : use ccache instead of sccache for windows-cublas [no ci] (#3855) This commit updates the Install cache step to use ggml-org/ccache-action and switched to use ccache instead of sccache. The motivation for switching to ccache is that this is what llama.cpp does and also there is an issue with later version of sscache: ```console sccache C:\PROGRA~1\NVIDIA~1\CUDA\v\bin\nvcc.exe -forward-unknown-to-host-compiler -DGGML_BACKEND_BUILD -DGGML_BACKEND_SHARED -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_SCHED_MAX_COPIES=4 -DGGML_SHARED -D_CRT_SECURE_NO_WARNINGS -D_XOPEN_SOURCE=600 -Dggml_cuda_EXPORTS -DCMAKE_INTDIR=\"Release\" -ID:\a\whisper.cpp\whisper.cpp\ggml\src\ggml-cuda\.. -ID:\a\whisper.cpp\whisper.cpp\ggml\src\..\include -isystem "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v\include" -Xcompiler="-MD -O2 -Ob2" -DNDEBUG -std=c++17 -arch=native -use_fast_math -extended-lambda -Xcompiler /Zc:preprocessor -MD -MT ggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\allreduce.cu.obj -MF ggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\allreduce.cu.obj.d -x cu -c D:\a\whisper.cpp\whisper.cpp\ggml\src\ggml-cuda\allreduce.cu -o ggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\allreduce.cu.obj -Xcompiler=-Fdggml\src\ggml-cuda\CMakeFiles\ggml-cuda.dir\Release\,-FS sccache: encountered fatal error sccache: error: Could not parse shell line sccache: caused by: Could not parse shell line ``` ``` --- .github/workflows/release.yml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 2ba8b45093b..c3ae9de4deb 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -340,10 +340,9 @@ jobs: choco install ninja - name: Install ccache - uses: hendrikmuhs/ccache-action@v1.2.16 + uses: ggml-org/ccache-action@v1.2.21 with: key: ${{ github.job }}-${{ matrix.cuda-toolkit }}-${{ matrix.build }} - variant: sccache evict-old-files: 5d - name: Install Cuda Toolkit 11.8.0 @@ -497,9 +496,9 @@ jobs: set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS% - - name: Check sccache status after build + - name: Check ccache status after build run: | - sccache --show-stats + ccache --show-stats - name: Copy CUDA DLLs run: | From 7ecb08f26359708dbc7fbeea428916684c64a76e Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Thu, 4 Jun 2026 11:38:46 +0200 Subject: [PATCH 749/831] ci : pin github actions to commit SHAs (#3856) This commit pins github actions used to the same commi SHAs that llama.cpp uses. --- .github/workflows/build-android.yml | 4 ++-- .github/workflows/build-clang.yml | 2 +- .github/workflows/build-gcc.yml | 4 ++-- .github/workflows/build-windows.yml | 2 +- .github/workflows/docker.yml | 6 +++--- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/.github/workflows/build-android.yml b/.github/workflows/build-android.yml index d9af1810131..42673166cf3 100644 --- a/.github/workflows/build-android.yml +++ b/.github/workflows/build-android.yml @@ -41,7 +41,7 @@ jobs: java-version: 21 - name: Setup Android SDK - uses: android-actions/setup-android@v3 + uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1 - name: Build run: | @@ -69,7 +69,7 @@ jobs: cache: gradle - name: Setup Android SDK - uses: android-actions/setup-android@v3 + uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1 with: cmdline-tools-version: 9.0 diff --git a/.github/workflows/build-clang.yml b/.github/workflows/build-clang.yml index c7a36884f64..5308164cc68 100644 --- a/.github/workflows/build-clang.yml +++ b/.github/workflows/build-clang.yml @@ -61,7 +61,7 @@ jobs: save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 - name: Build ${{ matrix.arch }} run: | diff --git a/.github/workflows/build-gcc.yml b/.github/workflows/build-gcc.yml index 4528ba3d534..b1b04c24034 100644 --- a/.github/workflows/build-gcc.yml +++ b/.github/workflows/build-gcc.yml @@ -58,7 +58,7 @@ jobs: save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 - name: Build ${{ matrix.arch }} run: | @@ -141,7 +141,7 @@ jobs: save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 - name: Build ${{ matrix.arch }} run: | diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml index cd1591f0132..9fd910ac0ec 100644 --- a/.github/workflows/build-windows.yml +++ b/.github/workflows/build-windows.yml @@ -46,7 +46,7 @@ jobs: uses: actions/checkout@v6 - name: Setup ${{ matrix.sys }} - uses: msys2/setup-msys2@v2 + uses: msys2/setup-msys2@cafece8e6baf9247cf9b1bf95097b0b983cc558d # v2 with: update: true msystem: ${{matrix.sys}} diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 51724976e0a..e7ca8595ddd 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -30,10 +30,10 @@ jobs: uses: actions/checkout@v6 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Log in to Docker Hub - uses: docker/login-action@v3 + uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} @@ -62,7 +62,7 @@ jobs: echo "tags=$TAGS" >> $GITHUB_OUTPUT - name: Build and push Docker image (tagged) - uses: docker/build-push-action@v6 + uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7 with: context: . push: true From ad17783d3499d54bd64f8afd19932ea7b0d5d175 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Thu, 4 Jun 2026 14:25:15 +0200 Subject: [PATCH 750/831] ci : use emscripten-core and pin version (#3857) This commit updates the setup emscripten sdk jobs to use emscripten-core instead of mymindstorm and also pins the commit sha for the version instead of using a version tag. --- .github/workflows/build-wasm.yml | 2 +- .../workflows/{examples-wasm.yml => deploy-examples-wasm.yml} | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename .github/workflows/{examples-wasm.yml => deploy-examples-wasm.yml} (96%) diff --git a/.github/workflows/build-wasm.yml b/.github/workflows/build-wasm.yml index 42a9401af3c..d2891eda90f 100644 --- a/.github/workflows/build-wasm.yml +++ b/.github/workflows/build-wasm.yml @@ -40,7 +40,7 @@ jobs: uses: actions/checkout@v6 - name: Setup emsdk - uses: mymindstorm/setup-emsdk@v14 + uses: emscripten-core/setup-emsdk@6ab9eb1bda2574c4ddb79809fc9247783eaf9021 # v14 - name: Verify run: emcc -v diff --git a/.github/workflows/examples-wasm.yml b/.github/workflows/deploy-examples-wasm.yml similarity index 96% rename from .github/workflows/examples-wasm.yml rename to .github/workflows/deploy-examples-wasm.yml index 927438cdad8..e7fdae77854 100644 --- a/.github/workflows/examples-wasm.yml +++ b/.github/workflows/deploy-examples-wasm.yml @@ -28,7 +28,7 @@ jobs: uses: actions/configure-pages@v5 - name: Setup emsdk - uses: mymindstorm/setup-emsdk@v14 + uses: emscripten-core/setup-emsdk@6ab9eb1bda2574c4ddb79809fc9247783eaf9021 # v14 - name: Build WASM Examples # Enable for real build later in whisper.cpp From 99613cb720b65036237d44b52f753b51f75c2797 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Thu, 4 Jun 2026 16:27:58 +0200 Subject: [PATCH 751/831] ci: build-windows action slimming (#3858) * ci : remove base-devel and git from msys2 job This commit removes the above packages as they might not be required and could help reduce the github cache size. * ci : try reducing the installs to only the compilers --- .github/workflows/build-windows.yml | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml index 9fd910ac0ec..156a57f74b6 100644 --- a/.github/workflows/build-windows.yml +++ b/.github/workflows/build-windows.yml @@ -38,8 +38,8 @@ jobs: fail-fast: false matrix: include: - - { sys: UCRT64, env: ucrt-x86_64, build: Release } - - { sys: CLANG64, env: clang-x86_64, build: Release } + - { sys: UCRT64, env: ucrt-x86_64, compiler: gcc, build: Release } + - { sys: CLANG64, env: clang-x86_64, compiler: clang, build: Release } steps: - name: Clone @@ -51,9 +51,7 @@ jobs: update: true msystem: ${{matrix.sys}} install: >- - base-devel - git - mingw-w64-${{matrix.env}}-toolchain + mingw-w64-${{matrix.env}}-${{matrix.compiler}} mingw-w64-${{matrix.env}}-cmake mingw-w64-${{matrix.env}}-SDL2 mingw-w64-${{matrix.env}}-openblas From 574fc0da69bcf2da3262e40d1b4009341df3d53f Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Sat, 6 Jun 2026 05:40:58 +0200 Subject: [PATCH 752/831] ci : add ccache to quantize, vad, and wasm jobs (#3860) * ci : add ccache to build-quantize * ci : add ccache to build-vad * ci : add ccache to build-wasm [no ci] --- .github/workflows/build-quantize.yml | 9 ++++++++- .github/workflows/build-vad.yml | 9 ++++++++- .github/workflows/build-wasm.yml | 18 ++++++++++++++++-- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build-quantize.yml b/.github/workflows/build-quantize.yml index 8036a3a3450..69ab2c34638 100644 --- a/.github/workflows/build-quantize.yml +++ b/.github/workflows/build-quantize.yml @@ -31,11 +31,18 @@ jobs: - name: Clone uses: actions/checkout@v6 + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: quantize-ubuntu-22 + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + - name: Test quantize env: HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | ./models/download-ggml-model.sh tiny.en - cmake -B build + cmake -B build -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache cmake --build build --config Release ./build/bin/whisper-quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0 diff --git a/.github/workflows/build-vad.yml b/.github/workflows/build-vad.yml index 71e910a3fcb..3c5ebec2026 100644 --- a/.github/workflows/build-vad.yml +++ b/.github/workflows/build-vad.yml @@ -31,10 +31,17 @@ jobs: - name: Checkout uses: actions/checkout@v6 + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: vad-ubuntu-latest + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + - name: Build shell: bash run: | - cmake -B build + cmake -B build -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache cmake --build build --config Release - name: Test diff --git a/.github/workflows/build-wasm.yml b/.github/workflows/build-wasm.yml index d2891eda90f..45c77c0be4c 100644 --- a/.github/workflows/build-wasm.yml +++ b/.github/workflows/build-wasm.yml @@ -45,7 +45,21 @@ jobs: - name: Verify run: emcc -v + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: wasm-ubuntu-22 + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + - name: Build + env: + CCACHE_SLOPPINESS: time_macros,include_file_mtime,include_file_ctime + CCACHE_COMPILERCHECK: content run: | - emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }} - make + emcmake cmake -B build -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + "-DCMAKE_C_FLAGS=-ffile-prefix-map=$EMSDK=/emsdk" \ + "-DCMAKE_CXX_FLAGS=-ffile-prefix-map=$EMSDK=/emsdk" + cmake --build build -j $(nproc) From a8ec021f2750a473ff4a8f3883bc9fdf5feafa84 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Sat, 6 Jun 2026 18:34:40 +0200 Subject: [PATCH 753/831] ci : add HF_TOKEN to docker.yml workflow [no ci] (#3861) This commit adds the HF_TOKEN secret to the docker workflows to avoid HF rate limiting which currently sometimes causes the jobs to fail. Refs: https://github.com/ggml-org/whisper.cpp/actions/runs/27053852601/job/79854251771 --- .devops/main-cuda.Dockerfile | 2 +- .devops/main-intel.Dockerfile | 3 ++- .devops/main-musa.Dockerfile | 2 +- .devops/main-vulkan.Dockerfile | 2 +- .devops/main.Dockerfile | 2 +- .github/workflows/docker.yml | 2 ++ 6 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.devops/main-cuda.Dockerfile b/.devops/main-cuda.Dockerfile index c2bf0fbd1c6..7a21fc4e3db 100644 --- a/.devops/main-cuda.Dockerfile +++ b/.devops/main-cuda.Dockerfile @@ -25,7 +25,7 @@ ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH COPY .. . # Enable cuBLAS -RUN make base.en CMAKE_ARGS="-DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES='75;80;86;90'" +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en CMAKE_ARGS="-DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES='75;80;86;90'" RUN find /app/build -name "*.o" -delete && \ find /app/build -name "*.a" -delete && \ diff --git a/.devops/main-intel.Dockerfile b/.devops/main-intel.Dockerfile index 86b901c1538..a0c04ad34ad 100644 --- a/.devops/main-intel.Dockerfile +++ b/.devops/main-intel.Dockerfile @@ -10,7 +10,8 @@ RUN apt-get update && \ COPY .. . # Enable SYCL ARG GGML_SYCL_F16=OFF -RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN \ + if [ "${GGML_SYCL_F16}" = "ON" ]; then \ echo "GGML_SYCL_F16 is set" \ && export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \ fi && \ diff --git a/.devops/main-musa.Dockerfile b/.devops/main-musa.Dockerfile index 026791e3f89..c68367830f1 100644 --- a/.devops/main-musa.Dockerfile +++ b/.devops/main-musa.Dockerfile @@ -16,7 +16,7 @@ RUN apt-get update && \ COPY .. . # Enable muBLAS -RUN make base.en CMAKE_ARGS="-DGGML_MUSA=1" +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en CMAKE_ARGS="-DGGML_MUSA=1" RUN find /app/build -name "*.o" -delete && \ find /app/build -name "*.a" -delete && \ diff --git a/.devops/main-vulkan.Dockerfile b/.devops/main-vulkan.Dockerfile index 077af4f1001..16ee19dc689 100644 --- a/.devops/main-vulkan.Dockerfile +++ b/.devops/main-vulkan.Dockerfile @@ -6,7 +6,7 @@ RUN apt-get update && \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* COPY .. . -RUN make base.en CMAKE_ARGS="-DGGML_VULKAN=1" +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en CMAKE_ARGS="-DGGML_VULKAN=1" FROM ubuntu:24.04 AS runtime WORKDIR /app diff --git a/.devops/main.Dockerfile b/.devops/main.Dockerfile index e1eb9b33700..d0e809f4e13 100644 --- a/.devops/main.Dockerfile +++ b/.devops/main.Dockerfile @@ -6,7 +6,7 @@ RUN apt-get update && \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* COPY .. . -RUN make base.en +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en FROM ubuntu:22.04 AS runtime WORKDIR /app diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index e7ca8595ddd..b4c455b92e9 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -69,3 +69,5 @@ jobs: platforms: ${{ matrix.config.platform }} tags: ${{ steps.tags.outputs.tags }} file: ${{ matrix.config.dockerfile }} + secrets: | + HF_TOKEN=${{ secrets.HF_TOKEN }} From e1da83d7736f4a170a4c8057c205df35c39fe230 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Mon, 8 Jun 2026 07:27:12 +0200 Subject: [PATCH 754/831] ci : add ccache to build-sycl [no ci] (#3859) --- .github/workflows/build-sycl.yml | 40 +++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/.github/workflows/build-sycl.yml b/.github/workflows/build-sycl.yml index 57aa7cc4d95..c76954e49cf 100644 --- a/.github/workflows/build-sycl.yml +++ b/.github/workflows/build-sycl.yml @@ -61,24 +61,33 @@ jobs: shell: bash run: | sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp git + sudo apt install intel-oneapi-compiler-dpcpp-cpp - name: install oneAPI MKL library shell: bash run: | - sudo apt install intel-oneapi-mkl-devel git + sudo apt install intel-oneapi-mkl-devel - - name: Clone - id: checkout - uses: actions/checkout@v6 + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: sycl-${{ matrix.arch }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - name: Build id: cmake_build + env: + CCACHE_SLOPPINESS: time_macros + CCACHE_NODIRECT: 1 run: | source /opt/intel/oneapi/setvars.sh + export CCACHE_COMPILERCHECK="string:$(icpx --version 2>&1 | head -1)" mkdir build cd build - cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. + cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache .. cmake --build . --config Release -j $(nproc) ubuntu-22-cmake-sycl-fp16: @@ -111,22 +120,31 @@ jobs: shell: bash run: | sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp git + sudo apt install intel-oneapi-compiler-dpcpp-cpp - name: install oneAPI MKL library shell: bash run: | sudo apt install intel-oneapi-mkl-devel - - name: Clone - id: checkout - uses: actions/checkout@v6 + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: sycl-fp16-${{ matrix.arch }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - name: Build id: cmake_build + env: + CCACHE_SLOPPINESS: time_macros + CCACHE_NODIRECT: 1 run: | source /opt/intel/oneapi/setvars.sh + export CCACHE_COMPILERCHECK="string:$(icpx --version 2>&1 | head -1)" mkdir build cd build - cmake -DGGML_SYCL_F16=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. + cmake -DGGML_SYCL_F16=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache .. cmake --build . --config Release -j $(nproc) From c50e951afdf0b1bd4d63adddbd48dc90ff92893c Mon Sep 17 00:00:00 2001 From: fairydreaming <166155368+fairydreaming@users.noreply.github.com> Date: Fri, 29 May 2026 10:15:17 +0200 Subject: [PATCH 755/831] model : support for DeepseekV32ForCausalLM with generic DeepSeek Sparse Attention (DSA) implementation (llama/23346) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * llama : support DeepSeek V3.2 model family (with DSA lightning indexer) * convert : handle DeepseekV32ForCausalLM architecture * ggml : support for f16 GGML_OP_FILL * memory : separate hparams argument in llama_kv_cache constructor * memory : add llama_kv_cache_dsa memory (KV cache + lightning indexer cache) * llama : support for LLM_ARCH_DEEPSEEK32 * model : llama_model_deepseek32 implementation * model : merge two scale operations into one in DSA lightning indexer implementation * chore : remove unused code * model : support NVFP4 in DeepSeek V3.2 Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * memory : refactoring TODO Co-authored-by: ggerganov <ggerganov@users.noreply.github.com> --------- Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> Co-authored-by: ggerganov <ggerganov@users.noreply.github.com> --- ggml/src/ggml-cpu/ops.cpp | 36 +++++++++++++++++++++++++++++++++++- ggml/src/ggml.c | 2 +- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 7485ba4fc86..dc73696ad9f 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -2235,8 +2235,42 @@ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, gg } } +static void ggml_compute_forward_fill_f16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0)); + + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const auto [ir0, ir1] = get_thread_range(params, dst); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne2*ne1); + const int64_t i02 = (ir - i03*ne2*ne1)/ne1; + const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1); + + ggml_vec_set_f16(ne0, dst_ptr, c); + } +} + void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) { - ggml_compute_forward_fill_f32(params, dst); + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_fill_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_fill_f16(params, dst); + } break; + default: + { + GGML_ABORT("unsupported type for ggml_compute_forward_fill: %s", ggml_type_name(src0->type)); + } + } } // ggml_compute_tri diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 476c3079795..8815c67d8bc 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5223,7 +5223,7 @@ static struct ggml_tensor * ggml_fill_impl( struct ggml_tensor * a, float c, bool inplace) { - GGML_ASSERT(a->type == GGML_TYPE_F32); + GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(a)); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); From f7aad4ed7e6818cebdbe87fd78a28f02f6cefedb Mon Sep 17 00:00:00 2001 From: Oliver Simons <osimons@nvidia.com> Date: Fri, 29 May 2026 12:28:18 +0200 Subject: [PATCH 756/831] CUDA: Check PTX version on host side to guard PDL dispatch (llama/23530) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * CUDA: Check PTX version on host side to guard PDL dispatch Checking on `__CUDA_ARCH_LIST__` alone is insufficient for JIT, as this variable doesn't differentiate between compiling for say sm_90, sm_90a or sm_90f (so forward-jittable PTX vs. arch/family-specific PTX). Thus, one can have a bug when compiling with `DCMAKE_CUDA_ARCHITECTURES="89;90a"`, where current code would wrongly dispatch to PDL on sm_90/sm_120 in forward-JIT mode. This PR fixes this issue by checking `cudaFuncAttributes::ptxVersion` of the incoming kernel at runtime. A check on ptxVersion alone is sufficient, as device-codes will always be >= ptxVersion (and any violation of this would be a severe bug in CUDA/nvcc), see: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#gpu-code-code-code * Implement MurmurHash3 mixer for better hash distribution Magic constants were taken from boost: https://github.com/boostorg/container_hash/blob/2698b43803c012601e6bb1a6116e83767b97986c/include/boost/container_hash/detail/hash_mix.hpp#L19-L65 * Update ggml/src/ggml-cuda/common.cuh Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * Address review comments, make seed non-zero * Apply code-formatting * Replace std::size_t -> size_t for consistency --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de> --- ggml/src/ggml-cuda/common.cuh | 60 +++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 50d7763dcdd..560fab0b17b 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -7,6 +7,7 @@ #include <cstdint> #include <cstdlib> #include <memory> +#include <mutex> #if defined(GGML_USE_HIP) #define GGML_COMMON_DECL_HIP @@ -1552,6 +1553,62 @@ struct ggml_cuda_pdl_config { ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete; }; + +static bool ggml_cuda_kernel_can_use_pdl(const void * kernel) { + const int device = ggml_cuda_get_device(); + + struct cache_key { + int device; + const void * kernel; + + bool operator==(const cache_key & other) const { return device == other.device && kernel == other.kernel; } + }; + + struct cache_key_hash { + // MurmurHash3 mixing function for better hash distribution (vs. just std::hash which in some implementations simply returns the identity) + static size_t hash_mix(size_t x) { + std::uint64_t y = x; + const std::uint64_t m = 0xe9846af9b1a615d; + + y ^= y >> 32; + y *= m; + y ^= y >> 32; + y *= m; + y ^= y >> 28; + + return static_cast<size_t>(y); + } + + size_t operator()(const cache_key & key) const { + // Use a nonzero seed to avoid mapping all-zero keys to zero + size_t h = 42; + h = hash_mix(h + key.device); + h = hash_mix(h + reinterpret_cast<size_t>(key.kernel)); + return h; + } + }; + + static std::mutex cache_mutex; + static std::unordered_map<cache_key, bool, cache_key_hash> cache; + + const cache_key key = { device, kernel }; + std::lock_guard<std::mutex> lock(cache_mutex); + const auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + + cudaFuncAttributes attr = {}; + CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel)); + + // PDL device-side primitives are emitted only for PTX versions >= 90. + // We have to guard on a loaded kernel's PTX version so a kernel forward-JIT'ed + // from pre-Hopper PTX to a Hopper-or-newer GPU does not opt into PDL. + const bool can_use_pdl = attr.ptxVersion >= 90; + cache.emplace(key, can_use_pdl); + return can_use_pdl; +} + #endif //defined(GGML_CUDA_USE_PDL) @@ -1564,8 +1621,7 @@ static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_ke return env == nullptr || std::atoi(env) != 0; }(); - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - if (env_pdl_enabled && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_HOPPER) { + if (env_pdl_enabled && ggml_cuda_kernel_can_use_pdl(reinterpret_cast<const void *>(kernel))) { auto pdl_cfg = ggml_cuda_pdl_config(launch_params); CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward<Args>(args)... )); From acd91d2c3891ac8f2538152882552657691ed6af Mon Sep 17 00:00:00 2001 From: Reese Levine <reeselevine1@gmail.com> Date: Fri, 29 May 2026 14:14:11 -0700 Subject: [PATCH 757/831] ggml-webgpu: add q4_0/q8_0 SET_ROWS (llama/23760) * Add q8_0 and q4_0 set_rows * Add fast(er) quantization set_rows path * formatting/naming * a little more naming * Remove unused constant * Don't override other override * Avoid bitcast * Narrow relaxation --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 90 ++++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 11 +- .../ggml-webgpu/wgsl-shaders/set_rows.wgsl | 5 +- .../wgsl-shaders/set_rows_quant.wgsl | 224 ++++++++++++++++++ 4 files changed, 289 insertions(+), 41 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 60e98a60741..f4c5eca0df5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -84,16 +84,16 @@ struct ggml_webgpu_shader_lib_context { ggml_tensor * src5; ggml_tensor * dst; - uint32_t max_wg_size; - size_t wg_mem_limit_bytes = 0; - bool supports_subgroups = false; - bool supports_subgroup_matrix = false; - uint32_t sg_mat_m = 0; - uint32_t sg_mat_n = 0; - uint32_t sg_mat_k = 0; - uint32_t min_subgroup_size = 0; - uint32_t max_subgroup_size = 0; - bool supports_dot_product = false; + uint32_t max_wg_size; + size_t wg_mem_limit_bytes = 0; + bool supports_subgroups = false; + bool supports_subgroup_matrix = false; + uint32_t sg_mat_m = 0; + uint32_t sg_mat_n = 0; + uint32_t sg_mat_k = 0; + uint32_t min_subgroup_size = 0; + uint32_t max_subgroup_size = 0; + bool supports_dot_product = false; std::string vendor; }; @@ -166,9 +166,11 @@ struct ggml_webgpu_set_rows_pipeline_key { int dst_type; int vec4; int i64_idx; + int pair_blocks; bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const { - return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx; + return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx && + pair_blocks == other.pair_blocks; } }; @@ -178,6 +180,7 @@ struct ggml_webgpu_set_rows_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.dst_type); ggml_webgpu_hash_combine(seed, key.vec4); ggml_webgpu_hash_combine(seed, key.i64_idx); + ggml_webgpu_hash_combine(seed, key.pair_blocks); return seed; } }; @@ -185,6 +188,7 @@ struct ggml_webgpu_set_rows_pipeline_key_hash { struct ggml_webgpu_set_rows_shader_decisions { bool vec4; bool i64_idx; + bool pair_blocks; uint32_t wg_size; }; @@ -772,31 +776,30 @@ inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); const bool kv_vec_type_supported = K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const uint32_t kv_vec_head_align = K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : - (uint32_t) ggml_blck_size(K->type); - const bool kv_vec_head_dims_aligned = context.src0->ne[0] % kv_vec_head_align == 0 && - context.src2->ne[0] % kv_vec_head_align == 0; + const uint32_t kv_vec_head_align = + K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type); + const bool kv_vec_head_dims_aligned = + context.src0->ne[0] % kv_vec_head_align == 0 && context.src2->ne[0] % kv_vec_head_align == 0; // Compile with enough invocations to cover the largest reported subgroup. - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && - kv_vec_head_dims_aligned && kv_vec_type_supported && - (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && + const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && kv_vec_head_dims_aligned && + kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (context.src2->type == K->type); const bool tile_can_dispatch_all_q_rows = context.max_subgroup_size > 0 && context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; - const bool use_subgroup_matrix = - context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 && - context.src0->ne[0] % context.sg_mat_k == 0 && context.src2->ne[0] % context.sg_mat_n == 0; + const bool use_subgroup_matrix = context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 && + context.src0->ne[0] % context.sg_mat_k == 0 && + context.src2->ne[0] % context.sg_mat_n == 0; const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && f16_vec4_aligned && (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && tile_can_dispatch_all_q_rows && !use_vec; - decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : - use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : - use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : - GGML_WEBGPU_FLASH_ATTN_PATH_NONE; + decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : + use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : + use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : + GGML_WEBGPU_FLASH_ATTN_PATH_NONE; if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { return decisions; @@ -1131,9 +1134,9 @@ class ggml_webgpu_shader_lib { ggml_webgpu_flash_attn_blk_pipeline_key_hash> flash_attn_blk_pipelines; std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash> - mul_mat_vec_pipelines; // fast mat-vec (n==1) + mul_mat_vec_pipelines; // fast mat-vec (n==1) std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash> - mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash> quantize_q8_pipelines; std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed @@ -1264,10 +1267,13 @@ class ggml_webgpu_shader_lib { } webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { - ggml_webgpu_set_rows_pipeline_key key = {}; - key.dst_type = context.dst->type; - key.vec4 = context.src0->ne[0] % 4 == 0; - key.i64_idx = context.src1->type == GGML_TYPE_I64; + const bool quantized = ggml_is_quantized(context.dst->type); + ggml_webgpu_set_rows_pipeline_key key = {}; + key.dst_type = context.dst->type; + key.vec4 = + (context.dst->type == GGML_TYPE_F32 || context.dst->type == GGML_TYPE_F16) && context.src0->ne[0] % 4 == 0; + key.i64_idx = context.src1->type == GGML_TYPE_I64; + key.pair_blocks = quantized && ((context.src0->ne[0] / ggml_blck_size(context.dst->type)) % 2 == 0); auto it = set_rows_pipelines.find(key); if (it != set_rows_pipelines.end()) { @@ -1286,6 +1292,14 @@ class ggml_webgpu_shader_lib { defines.push_back("DST_F16"); variant += "_dstf16"; break; + case GGML_TYPE_Q8_0: + defines.push_back("DST_Q8_0"); + variant += "_dstq8_0"; + break; + case GGML_TYPE_Q4_0: + defines.push_back("DST_Q4_0"); + variant += "_dstq4_0"; + break; default: GGML_ABORT("Unsupported dst type for set_rows shader"); } @@ -1298,13 +1312,19 @@ class ggml_webgpu_shader_lib { defines.push_back("I64_IDX"); variant += "_i64idx"; } + if (key.pair_blocks) { + defines.push_back("PAIR_BLOCKS"); + variant += "_pair_blocks"; + } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - auto processed = preprocessor.preprocess(wgsl_set_rows, defines); - auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>(); + const auto & shader_source = quantized ? wgsl_set_rows_quant : wgsl_set_rows; + auto processed = preprocessor.preprocess(shader_source, defines); + auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>(); decisions->vec4 = key.vec4; decisions->i64_idx = key.i64_idx; + decisions->pair_blocks = key.pair_blocks; decisions->wg_size = context.max_wg_size; set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); set_rows_pipelines[key].context = decisions; @@ -1660,7 +1680,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { @@ -1819,7 +1839,7 @@ class ggml_webgpu_shader_lib { (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0; - key.use_mmvq = + key.use_mmvq = ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); auto it = mul_mat_vec_pipelines.find(key); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1846886db4e..1a99f1cb52f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1331,7 +1331,11 @@ static std::optional<webgpu_encoded_op> ggml_webgpu_set_rows(webgpu_context & ct } uint32_t threads; - if (decisions->vec4) { + if (ggml_is_quantized(dst->type)) { + const uint32_t blocks_per_row = src->ne[0] / ggml_blck_size(dst->type); + threads = + (src->ne[1] * src->ne[2] * src->ne[3]) * (decisions->pair_blocks ? (blocks_per_row / 2) : blocks_per_row); + } else if (decisions->vec4) { threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); } else { threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; @@ -4046,8 +4050,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32); break; case GGML_OP_SET_ROWS: - supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 && - (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); + supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_Q8_0 || + op->type == GGML_TYPE_Q4_0) && + src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); break; case GGML_OP_GET_ROWS: if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) { diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl index 99e9192c71a..09f2f0eddb3 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl @@ -71,7 +71,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { return; } - // getting the row from gid let elems_per_row = params.ne0 / VEC_SIZE; var i = gid.x / elems_per_row; @@ -104,6 +103,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; - let col_idx = (gid.x % elems_per_row); - dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]); + let col_idx = gid.x % elems_per_row; + dst[i_dst_row / VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row / VEC_SIZE + col_idx]); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl new file mode 100644 index 00000000000..876e65b6ae1 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl @@ -0,0 +1,224 @@ +#ifdef DST_Q8_0 +#define BLOCK_SIZE 32u +#define BLOCK_BYTES 34u +#define QS_WORDS 8u +#elif defined(DST_Q4_0) +#define BLOCK_SIZE 32u +#define BLOCK_BYTES 18u +#define QS_WORDS 4u +#endif + +@group(0) @binding(0) +var<storage, read_write> src: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> idx: array<u32>; + +@group(0) @binding(2) +#ifdef PAIR_BLOCKS +var<storage, read_write> dst: array<u32>; +#else +var<storage, read_write> dst: array<atomic<u32>>; +#endif + +#ifdef I64_IDX +@group(0) @binding(3) +var<storage, read_write> error: atomic<u32>; +#define PARAMS_BINDING 4 +#else +#define PARAMS_BINDING 3 +#endif + +struct Params { + offset_src: u32, // in elements + offset_idx: u32, // in elements + offset_dst: u32, // in blocks + + // Strides (in elements / blocks) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_idx0: u32, + stride_idx1: u32, + stride_idx2: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of src + ne0: u32, + n_rows: u32, + ne2: u32, + ne3: u32, + + // Shape of idx + idx1: u32, + idx2: u32, +}; + +@group(0) @binding(PARAMS_BINDING) +var<uniform> params: Params; + +// if the quantization type is unaligned and there are an odd number of blocks per row, we need to store atomically +#ifndef PAIR_BLOCKS +fn merge_store_dst_word(word_idx: u32, mask: u32, bits: u32) { + loop { + let old = atomicLoad(&dst[word_idx]); + let merged = (old & ~mask) | (bits & mask); + let result = atomicCompareExchangeWeak(&dst[word_idx], old, merged); + if (result.exchanged) { + return; + } + } +} +#else +fn merge_store_dst_word(word_idx: u32, mask: u32, bits: u32) { + let old = dst[word_idx]; + dst[word_idx] = (old & ~mask) | (bits & mask); +} +#endif + +fn store_u16(dst_word_idx: u32, block_byte_offset: u32, byte_offset: u32, value: u32) { + let total_byte_offset = block_byte_offset + byte_offset; + let word_idx = dst_word_idx + total_byte_offset / 4u; + let shift = (total_byte_offset & 2u) * 8u; + let mask = 0xFFFFu << shift; + merge_store_dst_word(word_idx, mask, (value & 0xFFFFu) << shift); +} + +fn store_u32(dst_word_idx: u32, block_byte_offset: u32, byte_offset: u32, value: u32) { + let total_byte_offset = block_byte_offset + byte_offset; + let word_idx = dst_word_idx + total_byte_offset / 4u; + let shift = (total_byte_offset & 3u) * 8u; + + if (shift == 0u) { +#ifdef PAIR_BLOCKS + dst[word_idx] = value; +#else + atomicStore(&dst[word_idx], value); +#endif + return; + } + + let lo_mask = 0xFFFFFFFFu << shift; + let hi_mask = (1u << shift) - 1u; + merge_store_dst_word(word_idx, lo_mask, value << shift); + merge_store_dst_word(word_idx + 1u, hi_mask, value >> (32u - shift)); +} + +fn quantize_block_params(src_block: u32) -> vec2<f32> { +#ifdef DST_Q8_0 + var amax = 0.0; + for (var j: u32 = 0u; j < BLOCK_SIZE; j++) { + amax = max(amax, abs(src[src_block + j])); + } + + let d = amax / 127.0; + let id = select(0.0, 1.0 / d, d > 0.0); + return vec2(d, id); +#elif defined(DST_Q4_0) + var amax = 0.0; + var max_val = 0.0; + for (var j: u32 = 0u; j < BLOCK_SIZE; j++) { + let v = src[src_block + j]; + let av = abs(v); + if (amax < av) { + amax = av; + max_val = v; + } + } + + let d = max_val / -8.0; + let id = select(0.0, 1.0 / d, d != 0.0); + return vec2(d, id); +#endif +} + +fn quantize_block_word(src_block: u32, j: u32, id: f32) -> u32 { +#ifdef DST_Q8_0 + let base = src_block + j * 4u; + return (u32(i32(round(src[base + 0u] * id)) & 0xFF) << 0u) | + (u32(i32(round(src[base + 1u] * id)) & 0xFF) << 8u) | + (u32(i32(round(src[base + 2u] * id)) & 0xFF) << 16u) | + (u32(i32(round(src[base + 3u] * id)) & 0xFF) << 24u); +#elif defined(DST_Q4_0) + var packed_q = 0u; + for (var k: u32 = 0u; k < 4u; k++) { + let x0 = src[src_block + j * 4u + k] * id; + let x1 = src[src_block + 16u + j * 4u + k] * id; + let q0 = u32(clamp(i32(x0 + 8.5), 0, 15)); + let q1 = u32(clamp(i32(x1 + 8.5), 0, 15)); + packed_q |= (q0 & 0xFu) << (8u * k); + packed_q |= (q1 & 0xFu) << (8u * k + 4u); + } + return packed_q; +#endif +} + +fn quantize_block(src_block: u32, dst_word_idx: u32, block_byte_offset: u32) { + let params = quantize_block_params(src_block); + let d = params.x; + let id = params.y; + let packed_d = pack2x16float(vec2(d, 0.0)) & 0xFFFFu; + store_u16(dst_word_idx, block_byte_offset, 0u, packed_d); + + for (var j: u32 = 0u; j < QS_WORDS; j++) { + store_u32(dst_word_idx, block_byte_offset, 2u + j * 4u, quantize_block_word(src_block, j, id)); + } +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + let blocks_per_row = params.ne0 / BLOCK_SIZE; +#ifdef PAIR_BLOCKS + let blocks_per_invocation = 2u; +#else + let blocks_per_invocation = 1u; +#endif + let invocations_per_row = blocks_per_row / blocks_per_invocation; + let total_invocations = params.ne3 * params.ne2 * params.n_rows * invocations_per_row; + if (gid.x >= total_invocations) { + return; + } + + var i = gid.x / invocations_per_row; + let block_in_row = (gid.x % invocations_per_row) * blocks_per_invocation; + + let i_src3 = i / (params.ne2 * params.n_rows); + i = i % (params.ne2 * params.n_rows); + let i_src2 = i / params.n_rows; + let i_src1 = i % params.n_rows; + + let i_idx2 = i_src3 % params.idx2; + let i_idx1 = i_src2 % params.idx1; + let i_idx0 = i_src1; + +#ifdef I64_IDX + let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2u; + let idx_val = idx[idx_high]; + let idx_low_val = idx[idx_high + 1u]; + + if (idx_low_val != 0u) { + atomicStore(&error, 1u); + return; + } +#else + let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + let idx_val = idx[idx_i]; +#endif + + let dst_row_blocks = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; + let src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; + let src_block = src_row + block_in_row * BLOCK_SIZE; + let dst_block_byte = (dst_row_blocks + block_in_row) * BLOCK_BYTES; + + let dst_word_idx = dst_block_byte / 4u; +#ifdef PAIR_BLOCKS + quantize_block(src_block, dst_word_idx, 0u); + quantize_block(src_block + BLOCK_SIZE, dst_word_idx, BLOCK_BYTES); +#else + quantize_block(src_block, dst_word_idx, dst_block_byte & 3u); +#endif +} From 9147a9676b9945920088e90ee703d2ff462ded8b Mon Sep 17 00:00:00 2001 From: Reese Levine <reeselevine1@gmail.com> Date: Fri, 29 May 2026 14:16:05 -0700 Subject: [PATCH 758/831] ggml-webgpu: Check earlier for WebGPU required features (llama/23879) --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1a99f1cb52f..d577b5afa3c 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -3724,7 +3724,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { +static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { wgpu::RequestAdapterOptions options = {}; #ifndef __EMSCRIPTEN__ @@ -3762,10 +3762,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(); ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(); ctx->webgpu_global_ctx->vendor = info.vendor; - wgpu::SupportedFeatures features; - ctx->webgpu_global_ctx->adapter.GetFeatures(&features); - // we require f16 support - GGML_ASSERT(ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); ctx->webgpu_global_ctx->capabilities.supports_subgroups = ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups); // for dot4I8packed @@ -3877,7 +3873,6 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { "device_desc: %s\n", info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID, std::string(info.device).c_str(), std::string(info.description).c_str()); - return true; } static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { @@ -4507,7 +4502,12 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { UINT64_MAX); } - if (adapter != nullptr) { + // WebGPU backend requires f16 support and, on native, implicit device synchronization. + if (adapter != nullptr && adapter.HasFeature(wgpu::FeatureName::ShaderF16) +#ifndef __EMSCRIPTEN__ + && adapter.HasFeature(wgpu::FeatureName::ImplicitDeviceSynchronization) +#endif + ) { ctx->device_count = 1; } @@ -4515,8 +4515,11 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { } ggml_backend_t ggml_backend_webgpu_init(void) { - ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0); - + ggml_backend_reg_t reg = ggml_backend_webgpu_reg(); + if (ggml_backend_reg_dev_count(reg) == 0) { + return nullptr; + } + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0); return ggml_backend_webgpu_backend_init(dev, nullptr); } From 4317ddbe2b0fa7f436593658c8252469e598b36a Mon Sep 17 00:00:00 2001 From: Ruben Ortlam <rortlam@redhat.com> Date: Sat, 30 May 2026 10:39:31 +0200 Subject: [PATCH 759/831] vulkan: add Flash Attention support for BFloat16 KV cache (llama/23420) * vulkan: add flash attention bf16 kv support * vulkan: bf16 FA coopmat1 support * vulkan: bf16 FA coopmat2 support * fix FA bf16 f32 fallback * fix FA bf16 coopmat1 shader * fix FA bf16 coopmat2 shader * code cleanup * cleanup comment change * address feedback * add O_TYPE for cm2 FA * use O_TYPE for gqaStore function * reduce BFLOAT16 ifdefs --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 148 +++++++++++++----- .../vulkan-shaders/flash_attn_base.glsl | 12 +- .../vulkan-shaders/flash_attn_cm1.comp | 98 +++++++----- .../vulkan-shaders/flash_attn_cm2.comp | 36 +++-- .../vulkan-shaders/flash_attn_dequant.glsl | 8 + .../vulkan-shaders/vulkan-shaders-gen.cpp | 22 +++ 6 files changed, 235 insertions(+), 89 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c9f906d7930..2a30fb95c61 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -691,6 +691,7 @@ struct vk_device_struct { uint32_t coopmat_int_k; bool coopmat2; + bool coopmat2_bf16_support {}; bool coopmat2_decode_vector; bool pipeline_executable_properties_support {}; @@ -3139,7 +3140,7 @@ struct vk_fa_tuning_params { }; static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type); -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type = GGML_TYPE_F16); static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { @@ -3279,6 +3280,13 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + if (path == FA_COOPMAT2 && k_type == GGML_TYPE_BF16 && !device->coopmat2_bf16_support) { + path = FA_COOPMAT1; + } + if (path == FA_COOPMAT1 && k_type == GGML_TYPE_BF16 && !device->coopmat_bf16_support) { + path = FA_SCALAR; + } + if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) { // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 path = FA_SCALAR; @@ -3288,7 +3296,7 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_ bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || (!f32acc && device->coopmat_support_16x16x16_f16acc); const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); - bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); + bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc, k_type); if (!shape_ok || !shmem_ok) { path = FA_SCALAR; @@ -3334,8 +3342,8 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) { const auto fa_block_bytes = [](ggml_type t) -> uint32_t { - // decodeBufF32 uses a block of vec4s for a better memory access pattern. - return t == GGML_TYPE_F32 ? 16u : (uint32_t) ggml_type_size(t); + if (t == GGML_TYPE_F32) return 16u; + return (uint32_t) ggml_type_size(t); }; return { /* 0 WorkGroupSize */ state.workgroup_size, @@ -3849,10 +3857,16 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t fa_sgs = fa.first.subgroup_size; const bool fa_ds = fa.first.subgroup_size == 0; + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type); const void * spv_data = nullptr; size_t spv_size = 0; - if (use_mmq) { + const char *name = nullptr; + if (bf16_kv) { + spv_data = flash_attn_f32_f16_fp32_data; + spv_size = flash_attn_f32_f16_fp32_len; + name = aligned ? "flash_attn_f32_bf16_aligned" : "flash_attn_f32_bf16"; + } else if (use_mmq) { #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->fp16) { if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; } @@ -3862,6 +3876,7 @@ static void ggml_vk_load_shaders(vk_device& device) { spv_size = flash_attn_f32_f16_fp32_int8_len; } #endif + name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; } else { if (device->fp16) { if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; } @@ -3870,8 +3885,8 @@ static void ggml_vk_load_shaders(vk_device& device) { spv_data = flash_attn_f32_f16_fp32_data; spv_size = flash_attn_f32_f16_fp32_len; } + name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; } - const char *name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, @@ -3889,11 +3904,25 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t fa_sgs = fa.first.subgroup_size; const bool fa_ds = fa.first.subgroup_size == 0; + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; + const void * spv_data; size_t spv_size; - if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; } - else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; } - const char *name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1"; + const char *name; + if (bf16_kv) { +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!device->coopmat_bf16_support) continue; + spv_data = flash_attn_f32_f16_bf16_cm1_data; + spv_size = flash_attn_f32_f16_bf16_cm1_len; + name = aligned ? "flash_attn_f32_bf16_aligned_cm1" : "flash_attn_f32_bf16_cm1"; +#else + continue; +#endif + } else { + if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; } + else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; } + name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1"; + } ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, @@ -3911,10 +3940,20 @@ static void ggml_vk_load_shaders(vk_device& device) { const bool aligned = fa.first.aligned; const bool f32acc = fa.first.f32acc; + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; const void * spv_data; size_t spv_size; const char * name; - if (aligned) { + if (bf16_kv) { +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!device->coopmat2_bf16_support) continue; + spv_data = flash_attn_f32_f16_bf16_cm2_data; + spv_size = flash_attn_f32_f16_bf16_cm2_len; + name = aligned ? "flash_attn_f32_bf16_aligned_cm2" : "flash_attn_f32_bf16_cm2"; +#else + continue; +#endif + } else if (aligned) { if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; } else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; } } else { @@ -5784,46 +5823,72 @@ static vk_device ggml_vk_get_device(size_t idx) { found_fp16_256 = false, found_fp32_128 = false, found_fp32_256 = false; + bool found_bf16_128 = false, + found_bf16_256 = false; // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 // with 32x16x16 and 256 with 32x32x16. for (auto &prop : flexible_dimensions) { if (prop.saturatingAccumulation == VK_FALSE && - prop.scope == VK_SCOPE_WORKGROUP_KHR && - prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { - - if (prop.workgroupInvocations == 128 && - prop.MGranularity <= 32 && - prop.NGranularity <= 16 && - prop.KGranularity <= 16) { - if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { - found_fp16_128 = true; + prop.scope == VK_SCOPE_WORKGROUP_KHR) { + + if (prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_128 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_128 = true; + } } - if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { - found_fp32_128 = true; + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_256 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_256 = true; + } } } - if (prop.workgroupInvocations == 256 && - prop.MGranularity <= 32 && - prop.NGranularity <= 32 && - prop.KGranularity <= 16) { - if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { - found_fp16_256 = true; + +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + found_bf16_128 = true; } - if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { - found_fp32_256 = true; + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + found_bf16_256 = true; } } +#endif } } if (found_fp16_128 && found_fp16_256 && found_fp32_128 && found_fp32_256 && coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { device->coopmat2 = true; + device->coopmat2_bf16_support = found_bf16_128 && found_bf16_256; device->coopmat2_decode_vector = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector; } } @@ -9448,7 +9513,8 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con const uint32_t Br = params.block_rows; const uint32_t Bc = params.block_cols; - const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + // BF16 uses the fp32 shader (FLOAT_TYPE=float) + const uint32_t float_type_size = (device->fp16 && k_type != GGML_TYPE_BF16) ? sizeof(ggml_fp16_t) : sizeof(float); const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type); @@ -9489,7 +9555,7 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con return supported; } -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type) { // Needs to be kept up to date on shader changes const uint32_t Br = params.block_rows; const uint32_t Bc = params.block_cols; @@ -9519,8 +9585,10 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t vsh_stride = MatBc / 4 * row_split; const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4; + // BF16 PVMat accumulator is f32 (no bf16 accumulator support), so pvsh is vec4 (16 bytes) + const uint32_t pvsh_elem_size = (k_type == GGML_TYPE_BF16) ? 16u : f16vec4; const uint32_t osh_stride = params.row_split * MatBr / 4; - const uint32_t pvsh = MatBc * osh_stride * f16vec4; + const uint32_t pvsh = MatBc * osh_stride * pvsh_elem_size; const uint32_t slope = Br * acctype; @@ -9589,7 +9657,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32; + const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32 || k->type == GGML_TYPE_BF16; // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). @@ -16400,6 +16468,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm switch (t) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_Q8_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_0: @@ -16415,6 +16484,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) { return false; } + if ((op->src[1]->type == GGML_TYPE_BF16) != (op->src[2]->type == GGML_TYPE_BF16)) { + return false; + } if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) { // scalar/coopmat1 FA uses subgroupShuffle/subgroupAll return false; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 9a7957da97b..66dcf610219 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -97,8 +97,17 @@ layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; #define FA_TYPE_Q5_0 6u #define FA_TYPE_Q5_1 7u #define FA_TYPE_Q8_0 8u +#define FA_TYPE_BF16 30u #define FA_TYPE_Q1_0 41u +#if defined(BFLOAT16) +#define O_TYPE float +#define O_TYPEV4 vec4 +#else +#define O_TYPE FLOAT_TYPE +#define O_TYPEV4 FLOAT_TYPEV4 +#endif + // Number of matrix elements per buffer block, derived from the K/V type spec // constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1 // and bypasses the dequant path entirely. Quants follow their ggml block sizes. @@ -111,6 +120,7 @@ uint fa_block_elems(uint ty) { case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0); case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1); case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0); + case FA_TYPE_BF16: return 1u; case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere default: return 1u; } @@ -248,7 +258,7 @@ const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f; // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. -void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +void gqaStore(const in uint32_t r, const in uint32_t c, const in O_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) { uint32_t offset = (iq2 + r) * HSV / 4 + c; data_ov4[o_offset + offset] = D_TYPEV4(elems); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index bffcc095be3..23ae3833e52 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -6,6 +6,10 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#if defined(BFLOAT16) +#extension GL_EXT_bfloat16 : enable +#endif + #extension GL_KHR_shader_subgroup_basic : enable #extension GL_KHR_shader_subgroup_arithmetic : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -14,7 +18,9 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#if !defined(BFLOAT16) #include "flash_attn_dequant.glsl" +#endif // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd const uint32_t MatBr = 16; @@ -27,32 +33,32 @@ const uint32_t cols_per_thread = Bc / cols_per_iter; layout (binding = 0) readonly buffer Q {float data_q[];}; layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; -layout (binding = 1) readonly buffer K {float16_t data_k[];}; -layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; -layout (binding = 2) readonly buffer V {float16_t data_v[];}; -layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 1) readonly buffer K {FLOAT_TYPE data_k[];}; +layout (binding = 1) readonly buffer KV4 {FLOAT_TYPEV4 data_kv4[];}; +layout (binding = 2) readonly buffer V {FLOAT_TYPE data_v[];}; +layout (binding = 2) readonly buffer VV4 {FLOAT_TYPEV4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; shared float tmpsh[row_split]; -const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 -shared f16vec4 Qf[Br * qstride]; +const uint32_t qstride = HSK_pad / 4 + 2; +shared FLOAT_TYPEV4 Qf[Br * qstride]; const uint psh_stride = Br / 4 + 2; -shared f16vec4 Psh[Bc * psh_stride]; +shared FLOAT_TYPEV4 Psh[Bc * psh_stride]; // Avoid padding for hsk==256 to make it fit in 48KB shmem. const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4; shared ACC_TYPEV4 sfsh[Bc * sfshstride]; const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad; -const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4 +const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups const uint vsh_stride = v_cols; -shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)]; +shared FLOAT_TYPEV4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)]; const uint32_t osh_stride = row_split * MatBr / 4; -shared f16vec4 pvsh[MatBc * osh_stride]; +shared O_TYPEV4 pvsh[MatBc * osh_stride]; shared ACC_TYPE slope[Br]; @@ -76,7 +82,7 @@ void main() { if ((HSK % 16) != 0) { [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) { if (i + tid < Br * qstride) { - Qf[i + tid] = f16vec4(0); + Qf[i + tid] = FLOAT_TYPEV4(0); } } barrier(); @@ -89,15 +95,15 @@ void main() { uint32_t r = (idx + tid) / (HSK / 4); if (r < Br && d < HSK / 4 && i * Br + r < N) { - Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + Qf[r * qstride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } } barrier(); - f16vec4 Of[rows_per_thread][d_per_thread]; + O_TYPEV4 Of[rows_per_thread][d_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) { - Of[r][d] = f16vec4(0.0); + Of[r][d] = O_TYPEV4(0.0); } } @@ -222,15 +228,18 @@ void main() { uint32_t d = (idx + tid) % (HSK_pad / 4); uint32_t c = (idx + tid) / (HSK_pad / 4); if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) { - f16vec4 K_Tf = f16vec4(0); + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) { +#if !defined(BFLOAT16) if (USE_DECODE_K) { uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d; uint ib = coord / BLOCK_SIZE_K; uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); - } else { - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + } else +#endif + { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); } } @@ -244,16 +253,16 @@ void main() { // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 // This is written transposed in order to allow for N being 8 if implementations need it coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0); - coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat; - coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat; + coopmat<FLOAT_TYPE, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat; + coopmat<FLOAT_TYPE, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat; [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) { // If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem - // If not, f16 K is loaded directly from global memory if aligned, otherwise + // If not, K is loaded directly from global memory if aligned, otherwise // staged through a Bc * MatBr size staging buffer. - // If K is not type f16, then it is always staged for dequantization. + // If K is a quant type, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { - // For quants we always need to dequant into kvsh; for f16 we can load + // For quants we always need to dequant into kvsh; for f16/bf16 we can load // directly from global memory when alignment / bounds allow it. const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK; if (stage_k) { @@ -262,15 +271,18 @@ void main() { uint32_t col_vec = (idx + tid) % (MatBr / 4); uint32_t row = (idx + tid) / (MatBr / 4); if (idx + tid < Bc * MatBr / 4) { - f16vec4 K_Tf = f16vec4(0); + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { +#if !defined(BFLOAT16) if (USE_DECODE_K) { uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4; uint ib = coord / BLOCK_SIZE_K; uint iqs = (coord % BLOCK_SIZE_K); K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); - } else { - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); + } else +#endif + { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); } } @@ -357,7 +369,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local]; + Of[r][d_local] = O_TYPE(eMf[r]) * Of[r][d_local]; } } @@ -368,10 +380,10 @@ void main() { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) { const uint row = tile_row(r); if (KV_bounds_check && j * Bc + col >= KV) { - Psh[col * psh_stride + row / 4] = f16vec4(0.0f); + Psh[col * psh_stride + row / 4] = FLOAT_TYPEV4(0.0f); } else { const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]); - const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec)); + const FLOAT_TYPEV4 Pf = FLOAT_TYPEV4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec)); [[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) { Lf[r + vec_idx] += Pf[vec_idx]; } @@ -385,15 +397,18 @@ void main() { uint32_t d = (idx + tid) % (HSV_pad / 4); uint32_t c = (idx + tid) / (HSV_pad / 4); if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) { - f16vec4 V_Tf = f16vec4(0); + FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0); if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) { +#if !defined(BFLOAT16) if (USE_DECODE_V) { uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d; uint ib = coord / BLOCK_SIZE_V; uint iqs = (coord % BLOCK_SIZE_V); V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); - } else { - V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); + } else +#endif + { + V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); } } @@ -409,7 +424,7 @@ void main() { [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) { const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16; - coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0); + coopmat<O_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<O_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0); // Preload V tiles for [Bc, 16 * num subgroups] const uint v_rows = Bc; @@ -417,11 +432,11 @@ void main() { const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x; // If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem. - // If not, f16 V is loaded directly from global memory if aligned, otherwise + // If not, V is loaded directly from global memory if aligned, otherwise // staged through a Bc * MatBr size staging buffer. - // If V is not type f16, then it is always staged for dequantization. + // If V is a quant type, then it is always staged for dequantization. if (SHMEM_STAGING == 0) { - // For quants we always preload via kvsh. For f16 we only preload when + // For quants we always preload via kvsh. For f16/bf16 we only preload when // alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4). const bool stage_v = USE_DECODE_V || KV_bounds_check; if (stage_v) { @@ -438,13 +453,16 @@ void main() { const uint iqs = coord % BLOCK_SIZE_V; if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { +#if !defined(BFLOAT16) if (USE_DECODE_V) { kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); - } else { + } else +#endif + { kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; } } else { - kvsh[row * vsh_stride + col] = f16vec4(0.0f); + kvsh[row * vsh_stride + col] = FLOAT_TYPEV4(0.0f); } } } @@ -459,7 +477,7 @@ void main() { if (SHMEM_STAGING == 0) { if (!USE_DECODE_V && !KV_bounds_check) { - // F16 values can be loaded directly from global memory + // F16/BF16 values can be loaded directly from global memory const uint v_tile_row = j * Bc + bc_chunk * MatBc; const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); @@ -573,7 +591,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; - Of[r][d_local] *= float16_t(ms); + Of[r][d_local] *= O_TYPE(ms); } } else { vs = exp(sink - Mf[r]); @@ -591,7 +609,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d_local] *= float16_t(Lfrcp[r]); + Of[r][d_local] *= O_TYPE(Lfrcp[r]); #if defined(FLOAT_TYPE_MAX) Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 6d45b4931df..b9c03fe499d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -8,6 +8,10 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#if defined(BFLOAT16) +#extension GL_EXT_bfloat16 : enable +#endif + #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable #extension GL_NV_cooperative_matrix2 : enable @@ -21,7 +25,9 @@ #include "types.glsl" #include "flash_attn_base.glsl" +#if !defined(BFLOAT16) #include "dequant_funcs_cm2.glsl" +#endif // buffer_reference stride = sizeof(struct) = FaBlockBytesK/V. layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K { @@ -31,6 +37,7 @@ layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_ uint8_t raw[FaBlockBytesV]; }; +#if !defined(BFLOAT16) float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { switch (FaTypeK) { case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock); @@ -91,6 +98,7 @@ f16vec4 faDecodeVVector(const decodeBufFA_V bl_in, const uint blockCoords[2], co #define FADECODEK , faDecodeK #define FADECODEV , faDecodeV #endif +#endif layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; @@ -195,15 +203,15 @@ void main() { tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q; - coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16; + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16; uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03; coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad)); - Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q); - Qf16 *= float16_t(p.scale); + Q *= Q_TYPE(p.scale); + Qf16 = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q); - coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0); + coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0); coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M; @@ -291,16 +299,20 @@ void main() { coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); - coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T; + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; // F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128. +#if defined(BFLOAT16) + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); +#else const bool k_use_decode = (bs_k > 1u); if (k_use_decode) { coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose FADECODEK); } else { coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); } +#endif S = coopMatMulAdd(Qf16, K_T, S); if (LOGIT_SOFTCAP) { @@ -351,22 +363,26 @@ void main() { coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); } - coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P); + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P); // compute rowsum by multiplying by matrix of all ones. - coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0); + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0); rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0); rowsum = coopMatMulAdd(P_A, One, rowsum); - coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V; + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; +#if defined(BFLOAT16) + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); +#else const bool v_use_decode = (bs_v > 1u); if (v_use_decode) { coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) FADECODEV); } else { coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); } +#endif L = eM*L + rowsum; @@ -378,7 +394,7 @@ void main() { // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); - O *= coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag); + O *= coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag); O = coopMatMulAdd(P_A, V, O); } @@ -427,7 +443,7 @@ void main() { if (sink > Mr[i]) { ms = exp(Mr[i] - sink); - O[i] *= float16_t(ms); + O[i] *= O_TYPE(ms); } else { vs = exp(sink - Mr[i]); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl index 02106f33cbe..8704479d960 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl @@ -28,6 +28,9 @@ layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[]; layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0; layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0; +layout (binding = 1) readonly buffer K_PACKED_BF16 { u16vec4 data[]; } k_packed_bf16; +layout (binding = 2) readonly buffer V_PACKED_BF16 { u16vec4 data[]; } v_packed_bf16; + // Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16 // views, used by the MMQ K-side hot path for fast 4-uint loads. layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32; @@ -99,6 +102,9 @@ layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 dat return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \ } +#define FA_DEQUANT4_BF16(BUF) \ + return FLOAT_TYPEV4(bf16_to_fp32(uvec4(BUF.data[(a_offset + ib) / 4]))); + FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { switch (FaTypeK) { @@ -108,6 +114,7 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0) case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1) case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0) + case FA_TYPE_BF16: FA_DEQUANT4_BF16(k_packed_bf16) } } else { switch (FaTypeV) { @@ -117,6 +124,7 @@ FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0) case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1) case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0) + case FA_TYPE_BF16: FA_DEQUANT4_BF16(v_packed_bf16) } } return FLOAT_TYPEV4(0); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index fa9b938e4f7..de7dbec2c63 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -662,6 +662,28 @@ void process_shaders() { } } + const std::map<std::string, std::string> fa_bf16_dict = { + {"FLOAT_TYPE", "bfloat16_t"}, + {"FLOAT_TYPEV2", "bf16vec2"}, + {"FLOAT_TYPEV4", "bf16vec4"}, + {"ACC_TYPE", "float"}, + {"ACC_TYPEV2", "vec2"}, + {"ACC_TYPEV4", "vec4"}, + {"BFLOAT16", "1"}, + }; + +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16_bf16", "flash_attn_cm1.comp", + merge_maps(fa_bf16_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), + true, true, false, false); +#endif + +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16_bf16", "flash_attn_cm2.comp", + merge_maps(fa_bf16_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), + true, false, true, false); +#endif + std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}}; for (const auto& tname : type_names) { From 64b0d6b7fca11e62ef705127ec63fafdf0800a1e Mon Sep 17 00:00:00 2001 From: Jinyang He <hejinyang@loongson.cn> Date: Sat, 30 May 2026 16:53:26 +0800 Subject: [PATCH 760/831] ggml : add some lsx support (llama/23798) * loongarch : optimize LSX fp16 load/store with native intrinsics Use __lsx_vfcvtl_s_h and __lsx_vfcvt_h_s instead of scalar loops in __lsx_f16x4_load and __lsx_f16x4_store. * loongarch : add LSX implementation for q8_0 dot product * loongarch : add LSX implementation for q6_K dot product * loongarch : add LSX implementation for iq4_xs dot product * Improve reduce ops when sun int16 pairs to int32 --- ggml/src/ggml-cpu/arch/loongarch/quants.c | 151 ++++++++++++++++++++++ ggml/src/ggml-cpu/simd-mappings.h | 19 +-- 2 files changed, 154 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index 74e0c086c6d..9c43da6cf89 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -977,6 +977,35 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(acc); *s = sumf; + +#elif defined(__loongarch_sx) + + __m128 acc = (__m128)__lsx_vldi(0); + + for (; ib < nb; ++ib) { + const float d = GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d); + const __m128i qx_0 = __lsx_vld((const __m128i *)x[ib].qs, 0); + const __m128i qx_1 = __lsx_vld((const __m128i *)x[ib].qs + 1, 0); + const __m128i qy_0 = __lsx_vld((const __m128i *)y[ib].qs, 0); + const __m128i qy_1 = __lsx_vld((const __m128i *)y[ib].qs + 1, 0); + + const __m128i p16_0 = lsx_maddubs_h(qx_0, qy_0); + const __m128i p16_1 = lsx_maddubs_h(qx_1, qy_1); + + // Sum int16 pairs → int32 + const __m128i s_0 = __lsx_vaddwev_w_h(p16_0, p16_1); + const __m128i s_1 = __lsx_vaddwod_w_h(p16_0, p16_1); + + const __m128 q = __lsx_vffint_s_w(__lsx_vadd_w(s_0, s_1)); + acc = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(d), q, acc); + } + + __m128 res = lsx_hadd_s(acc, acc); + res = lsx_hadd_s(res, res); + sumf = ((v4f32)res)[0]; + + *s = sumf; + #else UNUSED(nb); UNUSED(ib); @@ -1443,6 +1472,99 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi *s = hsum_float_8(acc); +#elif defined(__loongarch_sx) + + const __m128i m32s = __lsx_vreplgr2vr_b(32); + + __m128 acc_0 = (__m128)__lsx_vldi(0); + __m128 acc_1 = (__m128)__lsx_vldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const __m128i scale_i8 = __lsx_vld(x[i].scales, 0); + const __m128i scales_lo = __lsx_vsllwil_h_b(scale_i8, 0); + const __m128i scales_hi = __lsx_vsllwil_h_b(__lsx_vbsrl_v(scale_i8, 8), 0); + + __m128i sumi_0 = __lsx_vldi(0); + __m128i sumi_1 = __lsx_vldi(0); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i q4bitsH_0 = __lsx_vld((const __m128i*)qh, 0); qh += 16; + const __m128i q4bitsH_1 = __lsx_vld((const __m128i*)qh, 0); qh += 16; + + const __m128i q4h_0 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_0, 3), 4); + const __m128i q4h_1 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_1, 3), 4); + const __m128i q4h_2 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_0, 3 << 2), 2); + const __m128i q4h_3 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_1, 3 << 2), 2); + const __m128i q4h_4 = __lsx_vandi_b(q4bitsH_0, 3 << 4); + const __m128i q4h_5 = __lsx_vandi_b(q4bitsH_1, 3 << 4); + const __m128i q4h_6 = __lsx_vsrli_b(__lsx_vandi_b(q4bitsH_0, 3 << 6), 2); + const __m128i q4h_7 = __lsx_vsrli_b(__lsx_vandi_b(q4bitsH_1, 3 << 6), 2); + + const __m128i q4bits1_0 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + const __m128i q4bits1_1 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + const __m128i q4bits2_0 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + const __m128i q4bits2_1 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + + const __m128i q4_0 = __lsx_vor_v(__lsx_vandi_b(q4bits1_0, 0xf), q4h_0); + const __m128i q4_1 = __lsx_vor_v(__lsx_vandi_b(q4bits1_1, 0xf), q4h_1); + const __m128i q4_2 = __lsx_vor_v(__lsx_vandi_b(q4bits2_0, 0xf), q4h_2); + const __m128i q4_3 = __lsx_vor_v(__lsx_vandi_b(q4bits2_1, 0xf), q4h_3); + const __m128i q4_4 = __lsx_vor_v(__lsx_vsrli_b(q4bits1_0, 4), q4h_4); + const __m128i q4_5 = __lsx_vor_v(__lsx_vsrli_b(q4bits1_1, 4), q4h_5); + const __m128i q4_6 = __lsx_vor_v(__lsx_vsrli_b(q4bits2_0, 4), q4h_6); + const __m128i q4_7 = __lsx_vor_v(__lsx_vsrli_b(q4bits2_1, 4), q4h_7); + + const __m128i q8_0 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_1 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_2 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_3 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_4 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_5 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_6 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_7 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + + __m128i p16_0 = lsx_maddubs_h(__lsx_vsub_b(q4_0, m32s), q8_0); + __m128i p16_1 = lsx_maddubs_h(__lsx_vsub_b(q4_1, m32s), q8_1); + __m128i p16_2 = lsx_maddubs_h(__lsx_vsub_b(q4_2, m32s), q8_2); + __m128i p16_3 = lsx_maddubs_h(__lsx_vsub_b(q4_3, m32s), q8_3); + __m128i p16_4 = lsx_maddubs_h(__lsx_vsub_b(q4_4, m32s), q8_4); + __m128i p16_5 = lsx_maddubs_h(__lsx_vsub_b(q4_5, m32s), q8_5); + __m128i p16_6 = lsx_maddubs_h(__lsx_vsub_b(q4_6, m32s), q8_6); + __m128i p16_7 = lsx_maddubs_h(__lsx_vsub_b(q4_7, m32s), q8_7); + + const __m128i sc_vec = j == 0 ? scales_lo : scales_hi; + + p16_0 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 0), p16_0); + p16_1 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 1), p16_1); + p16_2 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 2), p16_2); + p16_3 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 3), p16_3); + p16_4 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 4), p16_4); + p16_5 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 5), p16_5); + p16_6 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 6), p16_6); + p16_7 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 7), p16_7); + + sumi_0 = __lsx_vadd_w(sumi_0, __lsx_vadd_w(p16_0, p16_2)); + sumi_1 = __lsx_vadd_w(sumi_1, __lsx_vadd_w(p16_1, p16_3)); + sumi_0 = __lsx_vadd_w(sumi_0, __lsx_vadd_w(p16_4, p16_6)); + sumi_1 = __lsx_vadd_w(sumi_1, __lsx_vadd_w(p16_5, p16_7)); + } + + __m128 p_0 = __lsx_vfmul_s(__lsx_vreplfr2vr_s(d), __lsx_vffint_s_w(sumi_0)); + __m128 p_1 = __lsx_vfmul_s(__lsx_vreplfr2vr_s(d), __lsx_vffint_s_w(sumi_1)); + acc_0 = __lsx_vfadd_s(p_0, acc_0); + acc_1 = __lsx_vfadd_s(p_1, acc_1); + } + + *s = hsum_float_4x4(acc_0, acc_1, (__m128)__lsx_vldi(0), (__m128)__lsx_vldi(0)); + #else UNUSED(x); UNUSED(y); @@ -2149,6 +2271,35 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v *s = hsum_float_8(accum); +#elif defined(__loongarch_sx) + + const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); + + __m128 accum = (__m128)__lsx_vldi(0); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m128i sumi = __lsx_vldi(0); + for (int ib = 0; ib < QK_K/32; ++ib) { + const __m128i q4bits = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m128i q8b_0 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8b_1 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q4b_0 = __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits, 0xf)); + const __m128i q4b_1 = __lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits, 4)); + const __m128i p16_0 = lsx_maddubs_h(q4b_0, q8b_0); + const __m128i p16_1 = lsx_maddubs_h(q4b_1, q8b_1); + const int16_t ls = (((x[ibl].scales_l[ib/2] >> ((ib & 1) * 4)) & 0xf) | ((sh & 0x3) << 4)) - 32; + sh >>= 2; + sumi = __lsx_vadd_w(lsx_madd_h(p16_0, __lsx_vreplgr2vr_h(ls)), sumi); + sumi = __lsx_vadd_w(lsx_madd_h(p16_1, __lsx_vreplgr2vr_h(ls)), sumi); + } + const float ds = GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + accum = __lsx_vfadd_s(__lsx_vfmul_s(__lsx_vreplfr2vr_s(ds), __lsx_vffint_s_w(sumi)), accum); + } + + *s = ((v4f32)lsx_hadd_s(lsx_hadd_s(accum, accum), lsx_hadd_s(accum, accum)))[0]; + #else UNUSED(x); UNUSED(y); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 0deda930985..62e687201ef 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -1125,25 +1125,12 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) { #define GGML_F16_EPR 4 static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) { - float tmp[4]; - - tmp[0] = GGML_CPU_FP16_TO_FP32(x[0]); - tmp[1] = GGML_CPU_FP16_TO_FP32(x[1]); - tmp[2] = GGML_CPU_FP16_TO_FP32(x[2]); - tmp[3] = GGML_CPU_FP16_TO_FP32(x[3]); - - return (__m128)__lsx_vld(tmp, 0); + return __lsx_vfcvtl_s_h(__lsx_vld((const void *)x, 0)); } static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { - float arr[4]; - - __lsx_vst(y, arr, 0); - - x[0] = GGML_CPU_FP32_TO_FP16(arr[0]); - x[1] = GGML_CPU_FP32_TO_FP16(arr[1]); - x[2] = GGML_CPU_FP32_TO_FP16(arr[2]); - x[3] = GGML_CPU_FP32_TO_FP16(arr[3]); + __m128i a = __lsx_vfcvt_h_s(y, y); + memcpy(x, &a, sizeof(ggml_fp16_t) * 4); } #define GGML_F32Cx4 __m128 From bf74b557d2a960628ba380acb43da5396aa4cb96 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Sat, 30 May 2026 15:26:13 +0300 Subject: [PATCH 761/831] metal : restore im2col implementation for large kernels (llama/23901) --- ggml/src/ggml-metal/ggml-metal-device.cpp | 8 +- ggml/src/ggml-metal/ggml-metal-ops.cpp | 24 +++-- ggml/src/ggml-metal/ggml-metal.metal | 106 +++++++++++----------- 3 files changed, 77 insertions(+), 61 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index ba006d9b31a..5d4b10d34b9 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1732,6 +1732,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_IM2COL); + GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); @@ -1739,7 +1741,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_meta char base[256]; char name[256]; - snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); + if (ne00*ne01 <= 1024) { + snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); + } else { + snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type)); + } snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 206af227a2c..e2ce56e9e28 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -3635,16 +3635,26 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); - GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + if (KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N); - const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N); + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW); + } else { + const uint64_t n_threads = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), N); + const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW); + ggml_metal_encoder_dispatch_threadgroups(enc, quotient * CHW, OH, OW, n_threads, 1, 1); + } return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index e772664ba91..4adf4614acb 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4696,59 +4696,59 @@ kernel void kernel_im2col( template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>; template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>; -// TODO: obsolete -- remove -//typedef void (im2col_ext_t)( -// constant ggml_metal_kargs_im2col & args, -// device const float * x, -// device char * dst, -// uint3 tgpig[[threadgroup_position_in_grid]], -// uint3 tgpg[[threadgroups_per_grid]], -// uint3 tpitg[[thread_position_in_threadgroup]], -// uint3 ntg[[threads_per_threadgroup]]); -// -//template <typename T> -//kernel void kernel_im2col_ext( -// constant ggml_metal_kargs_im2col & args, -// device const float * x, -// device char * dst, -// uint3 tgpig[[threadgroup_position_in_grid]], -// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW -// uint3 tpitg[[thread_position_in_threadgroup]], -// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] -// const int64_t KHW = (int64_t)args.KHW; -// -// const int64_t d = tgpig[0] / args.CHW; -// const int64_t chw = tgpig[0] % args.CHW; -// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) -// const int64_t HW = tgpig[0] % KHW; -// -// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; -// if (tpitg_0 >= args.N) { -// return; -// } -// -// const int64_t tpitg_1 = HW / args.KW; -// const int64_t tpitg_2 = HW % args.KW; -// -// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; -// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; -// -// const int64_t offset_dst = -// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + -// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); -// -// device T * pdst = (device T *) (dst); -// -// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { -// pdst[offset_dst] = 0.0f; -// } else { -// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; -// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; -// } -//} -// -//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>; -//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>; +// TODO: optimize +typedef void (im2col_ext_t)( + constant ggml_metal_kargs_im2col & args, + device const float * x, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template <typename T> +kernel void kernel_im2col_ext( + constant ggml_metal_kargs_im2col & args, + device const float * x, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] + const int64_t KHW = (int64_t)args.KHW; + + const int64_t d = tgpig[0] / args.CHW; + const int64_t chw = tgpig[0] % args.CHW; + const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) + const int64_t HW = tgpig[0] % KHW; + + const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; + if (tpitg_0 >= args.N) { + return; + } + + const int64_t tpitg_1 = HW / args.KW; + const int64_t tpitg_2 = HW % args.KW; + + const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; + const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; + + const int64_t offset_dst = + (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + + (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { + pdst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; + pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; + } +} + +template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>; +template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>; template <typename TK> kernel void kernel_conv_2d( From 1c0d1f0f7c40a18670a460cc23b654ceb679ba9a Mon Sep 17 00:00:00 2001 From: lhez <lih@qti.qualcomm.com> Date: Sat, 30 May 2026 10:17:47 -0700 Subject: [PATCH 762/831] opencl: support bf16 by converting to f16 (llama/23839) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 81 +++++++++++++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 42 +++++++++++++++ 2 files changed, 121 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 751ec6116c0..3f3643a4cef 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -585,6 +585,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle; + cl_kernel kernel_convert_bf16_to_f16, kernel_convert_f16_to_bf16; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_restore_block_q4_0_noshuffle; @@ -1175,6 +1176,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_iq4_nl_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_bf16_to_f16 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_bf16_to_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_f16_to_bf16 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_f16_to_bf16", &err), err)); GGML_LOG_CONT("."); } @@ -5019,6 +5022,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_MUL_MAT: if (op->src[0]->type == GGML_TYPE_F16) { return true; + } else if (op->src[0]->type == GGML_TYPE_BF16) { + return true; } else if (op->src[0]->type == GGML_TYPE_F32) { return op->src[1]->type == GGML_TYPE_F32; } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || @@ -6828,6 +6833,40 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, } #endif // GGML_OPENCL_SOA_Q + // convert bf16 to f16 and store as f16 in device buffer + if (tensor->type == GGML_TYPE_BF16) { + GGML_ASSERT(offset % sizeof(ggml_fp16_t) == 0 && size % sizeof(ggml_fp16_t) == 0 + && "Offset and size must be multiples of 2 for bf16 tensors"); + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); + + cl_ulong n_elements = size / sizeof(ggml_fp16_t); + cl_ulong off_dst = (extra->offset + offset) / sizeof(ggml_fp16_t); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, + size, (void *) data, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_bf16_to_f16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->data_device)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &n_elements)); + + size_t global_work_size[] = { (size_t)CEIL_DIV(n_elements, 64)*64, 1, 1 }; + size_t local_work_size[] = { 64, 1, 1 }; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + CL_CHECK(clReleaseEvent(evt)); + + return; + } + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; GGML_ASSERT(extra); @@ -7676,6 +7715,41 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, } #endif // GGML_OPENCL_SOA_Q + if (tensor->type == GGML_TYPE_BF16) { + GGML_ASSERT(offset % sizeof(ggml_fp16_t) == 0 && size % sizeof(ggml_fp16_t) == 0 + && "Offset and size must be multiples of 2 for bf16 tensors"); + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); + + cl_ulong n_elements = size / sizeof(ggml_fp16_t); + cl_ulong off_src = (extra->offset + tensor->view_offs + offset) / sizeof(ggml_fp16_t); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_f16_to_bf16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &n_elements)); + + size_t global_work_size[] = { (size_t)CEIL_DIV(n_elements, 64)*64, 1, 1 }; + size_t local_work_size[] = { 64, 1, 1 }; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseEvent(evt)); + + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, 0, size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + + return; + } + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; CL_CHECK(clEnqueueReadBuffer( @@ -8165,6 +8239,7 @@ static void ggml_cl_copy_to_contiguous(ggml_backend_t backend, const ggml_tensor kernel = backend_ctx->kernel_cpy_f32_f32; break; case GGML_TYPE_F16: + case GGML_TYPE_BF16: // stored as f16 on device kernel = backend_ctx->kernel_cpy_f16_f16; break; default: @@ -11125,7 +11200,8 @@ static bool ggml_cl_can_use_adreno_xmem_gemm_f16_f32( if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { return false; } - if (src0->type != GGML_TYPE_F16 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_BF16) || + src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { return false; } if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { @@ -12843,7 +12919,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const enum ggml_type src0t = src0->type; + // bf16 is stored as f16 on device + const enum ggml_type src0t = (src0->type == GGML_TYPE_BF16) ? GGML_TYPE_F16 : src0->type; const enum ggml_type src1t = src1->type; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index c25eabdd72b..4f01887efb3 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -117,6 +117,48 @@ struct block_iq4_nl uint8_t qs[QK4_NL / 2]; }; +//------------------------------------------------------------------------------ +// bf16 to f16 +//------------------------------------------------------------------------------ +kernel void kernel_convert_bf16_to_f16( + global const ushort * src, + global half * dst, + ulong off_dst, + ulong n +) { + uint i = get_global_id(0); + if (i >= n) { + return; + } + + dst[i + off_dst] = (half) as_float((uint) src[i] << 16); +} + +//------------------------------------------------------------------------------ +// f16 to bf16 +//------------------------------------------------------------------------------ +kernel void kernel_convert_f16_to_bf16( + global const half * src, + ulong off_src, + global ushort * dst, + ulong n +) { + uint i = get_global_id(0); + if (i >= n) { + return; + } + + float f = (float) src[i + off_src]; + uint bits = as_uint(f); + if ((bits & 0x7fffffffu) > 0x7f800000u) { + // nan to quiet nan + dst[i] = (ushort)((bits >> 16) | 0x40u); + } else { + uint rounded = bits + 0x7fffu + ((bits >> 16) & 1u); + dst[i] = (ushort)(rounded >> 16); + } +} + //------------------------------------------------------------------------------ // kernel_convert_block_q4_0 // Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). From 687fbcb149c8e28bec2f563c1e9a081872b4e44c Mon Sep 17 00:00:00 2001 From: Neo Zhang <zhang.jianyu@outlook.com> Date: Mon, 1 Jun 2026 14:50:55 +0800 Subject: [PATCH 763/831] sycl : Optimize Q3_K mul_mat by reorder (llama/23725) --- ggml/src/ggml-sycl/convert.cpp | 25 ++++++- ggml/src/ggml-sycl/dequantize.hpp | 57 ++++++++++++++ ggml/src/ggml-sycl/dmmv.cpp | 120 +++++++++++++++++++++++++++++- ggml/src/ggml-sycl/ggml-sycl.cpp | 52 +++++++++++++ ggml/src/ggml-sycl/mmvq.cpp | 30 +++++++- ggml/src/ggml-sycl/quants.hpp | 25 +++++++ ggml/src/ggml-sycl/vecdotq.hpp | 35 +++++++++ 7 files changed, 340 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 576f19d79ae..65593402e7d 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -107,6 +107,19 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k, #endif } +template <typename dst_t> +static void dequantize_row_q3_K_sycl_reorder(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q3_K_reorder(vx, y, item_ct1, nb); + }); +} + template <typename dst_t> static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -652,7 +665,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { case GGML_TYPE_Q2_K: return dequantize_row_q2_K_sycl; case GGML_TYPE_Q3_K: - return dequantize_row_q3_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q3_K_sycl_reorder; + } else { + return dequantize_row_q3_K_sycl; + } case GGML_TYPE_Q4_K: if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { return dequantize_row_q4_K_sycl_reorder; @@ -730,7 +747,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { case GGML_TYPE_Q2_K: return dequantize_row_q2_K_sycl; case GGML_TYPE_Q3_K: - return dequantize_row_q3_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q3_K_sycl_reorder; + } else { + return dequantize_row_q3_K_sycl; + } case GGML_TYPE_Q4_K: if (dst->src[0]->extra && ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index 2324bfacd22..a723d2afbd6 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -390,6 +390,63 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri } +template<typename dst_t> +static void dequantize_block_q3_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> & item_ct1, int64_t n_blocks) { +#if QK_K == 256 + const int64_t i = item_ct1.get_group(2); + if (i >= n_blocks) { + return; + } + + const uint8_t * base = static_cast<const uint8_t *>(vx); + const size_t qs_offset = i * (QK_K / 4); + const size_t hmask_offset = n_blocks * (QK_K / 4) + i * (QK_K / 8); + const size_t scales_offset = n_blocks * (QK_K / 4) + n_blocks * (QK_K / 8) + i * 12; + const size_t d_offset = n_blocks * (QK_K / 4) + n_blocks * (QK_K / 8) + n_blocks * 12 + + i * sizeof(ggml_half); + + const uint8_t * qs = base + qs_offset; + const uint8_t * hmask = base + hmask_offset; + const uint8_t * scales = base + scales_offset; + const float d_all = static_cast<float>(*reinterpret_cast<const ggml_half *>(base + d_offset)); + + const int64_t r = item_ct1.get_local_id(2) / 4; + const int64_t tid = r / 2; + const int64_t is0 = r % 2; + const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4); + const int64_t n = tid / 4; + const int64_t j = tid - 4 * n; + const int64_t is = 8 * n + 2 * j + is0; + const int shift = 2 * j; + uint8_t m = 1 << (4 * n + j); + + uint8_t us = is < 4 + ? (scales[is - 0] & 0xF) | (((scales[is + 8] >> 0) & 3) << 4) + : is < 8 + ? (scales[is - 0] & 0xF) | (((scales[is + 4] >> 2) & 3) << 4) + : is < 12 + ? (scales[is - 8] >> 4) | (((scales[is + 0] >> 4) & 3) << 4) + : (scales[is - 8] >> 4) | (((scales[is - 4] >> 6) & 3) << 4); + + const float dl = d_all * (us - 32); + + dst_t * y = yy + i * QK_K + 128 * n + 32 * j; + const uint8_t * q = qs + 32 * n; + const uint8_t * hm = hmask; + + for (int l = l0; l < l0 + 4; ++l) { + y[l] = dl * ((int8_t) ((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); + } +#else + GGML_UNUSED(vx); + GGML_UNUSED(yy); + GGML_UNUSED(item_ct1); + GGML_UNUSED(n_blocks); + GGML_ABORT("Q3_K reorder dequantize not supported for QK_K != 256"); +#endif +} + #if QK_K == 256 static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { if (j < 4) { diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 4ae431a962e..d80b0a38219 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -501,6 +501,103 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx, } } +static void dequantize_mul_mat_vec_q3_k_reorder(const void *__restrict__ vx, + const float *__restrict__ yy, + float *__restrict__ dst, + const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + // SOA base pointers for the reordered layout: + // [qs: nb * (QK_K/4)] [hmask: nb * (QK_K/8)] [scales: nb * 12] [d: nb * sizeof(half)] + const int nb = nrows * num_blocks_per_row; + const uint8_t * qs_base = (const uint8_t *)vx; + const uint8_t * hmask_base = qs_base + (size_t)nb * (QK_K / 4); + const uint8_t * scales_base = hmask_base + (size_t)nb * (QK_K / 8); + const sycl::half * d_base = (const sycl::half *)(scales_base + (size_t)nb * 12); + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop + const int step = 16/K_QUANTS_PER_ITERATION; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0....15 or 0...7 + + const uint8_t m = 1 << (4*im); + + const int l0 = n*in; // 0...15 or 0...14 in steps of 2 + const int q_offset = 32*im + l0; + const int y_offset = 128*im + l0; + + uint16_t utmp[4]; + const int8_t * s = (const int8_t *)utmp; + + const uint16_t s_shift = 4*im; + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = qs_base + bi * (QK_K / 4) + q_offset; + const uint8_t * h = hmask_base + bi * (QK_K / 8) + l0; + + const uint16_t * a = (const uint16_t *)(scales_base + bi * 12); + utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); + utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); + utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); + utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); + + const float d = d_base[bi]; + + float sum = 0; + for (int l = 0; l < n; ++l) { + sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) + + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) + + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) + + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); + sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) + + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) + + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) + + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); + } + tmp += d * sum; + } +#else + GGML_UNUSED(vx); + GGML_UNUSED(yy); + GGML_UNUSED(ncols); + GGML_UNUSED(item_ct1); + GGML_ABORT("Q3_K reorder DMMV not supported for QK_K != 256"); +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + /* DPCT1110:6: The total declared local variable size in device function dequantize_mul_mat_vec_q4_k exceeds 128 bytes and may cause high register @@ -1440,6 +1537,22 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y, }); } +static void dequantize_mul_mat_vec_q3_K_sycl_reorder(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q3_k_reorder(vx, y, dst, ncols, nrows, item_ct1); + }); +} + static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y, float *dst, const int ncols, const int nrows, @@ -1581,7 +1694,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q3_K: - dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q3_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_Q4_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 729a88b4db8..e59f5c174d3 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3549,6 +3549,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: return true; + case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -3572,6 +3573,7 @@ inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: @@ -3791,6 +3793,54 @@ static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d return true; } +static bool reorder_qw_q3_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q3_K) == 0); + GGML_ASSERT(offset % sizeof(block_q3_K) == 0); + + const int nblocks = size / sizeof(block_q3_K); + + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr); + + sycl::event copy_event; + SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); + if (!g_ggml_sycl_use_async_mem_op) { + copy_event.wait(); + } + + auto * qs_ptr = data_device; + auto * hmask_ptr = qs_ptr + (QK_K / 4) * nblocks; + auto * scales_ptr = hmask_ptr + (QK_K / 8) * nblocks; + sycl::half * d_ptr = (sycl::half *) (scales_ptr + 12 * nblocks); + + auto reorder_event = stream->parallel_for(nblocks, [=](auto i) { + const block_q3_K * x = (const block_q3_K *) tmp_buf; + const int ib = i; + + for (int j = 0; j < QK_K / 4; ++j) { + qs_ptr[ib * (QK_K / 4) + j] = x[ib].qs[j]; + } + + for (int j = 0; j < QK_K / 8; ++j) { + hmask_ptr[ib * (QK_K / 8) + j] = x[ib].hmask[j]; + } + + for (int j = 0; j < 12; ++j) { + scales_ptr[ib * 12 + j] = x[ib].scales[j]; + } + + d_ptr[ib] = x[ib].d; + }); + if (!g_ggml_sycl_use_async_mem_op) { + reorder_event.wait_and_throw(); + } + return true; +} + static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q5_K) == 0); GGML_ASSERT(offset % sizeof(block_q5_K) == 0); @@ -3903,6 +3953,8 @@ static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { return reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); case GGML_TYPE_Q8_0: return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream); + case GGML_TYPE_Q3_K: + return reorder_qw_q3_k(data_device, size, 0, stream); case GGML_TYPE_Q4_K: return reorder_qw_q4_k(data_device, size, 0, stream); case GGML_TYPE_Q5_K: diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 49998f13ba8..abd1e49a70e 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -770,6 +770,26 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, } } +static void reorder_mul_mat_vec_q3_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} + static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -1153,7 +1173,15 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q3_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, + stream); + } else { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl\n"); + mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q4_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index 806028ef3a3..95287f17510 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -58,6 +58,31 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> { static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; +template <> struct block_q_t<GGML_TYPE_Q3_K> { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI3_K; + static constexpr uint32_t qr = QR3_K; + static constexpr uint32_t vdr_mmvq = 1; + }; + + // Reordered layout: [qs (QK_K/4 per block)] [hmask (QK_K/8 per block)] [scales] [d] + static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) { + auto qs_offset = block_index * (QK_K / 4); + auto hmask_offset = n_blocks * (QK_K / 4) + block_index * (QK_K / 8); + return { qs_offset, hmask_offset }; + } + + static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / QK_K)); + auto total_qs_bytes = nblocks * (QK_K / 4) + nblocks * (QK_K / 8); + return { total_qs_bytes + block_index * 12, + total_qs_bytes + nblocks * 12 + block_index * sizeof(ggml_half) }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; + template <> struct block_q_t<GGML_TYPE_Q4_K> { struct traits { static constexpr uint32_t qk = QK_K; diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 16b2d65d271..4b58b09ab2c 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -394,6 +394,41 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0> { } }; +template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K> { + static constexpr ggml_type gtype = GGML_TYPE_Q3_K; + + using q3_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q3_K>; + using q3_k_traits = typename q3_k_block::traits; + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset, + const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const uint8_t * base = static_cast<const uint8_t *>(vbq); + const uint8_t * qs = base + ibx_offset.first; + const uint8_t * hmask = base + ibx_offset.second; + const uint8_t * scales = base + d_offset.first; + const ggml_half d = *reinterpret_cast<const ggml_half *>(base + d_offset.second); + + const int bq8_offset = QR3_K * (iqs / (QI3_K / 2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1 / 2); + + const int vl = get_int_from_uint8(qs, iqs); + const int vh = ~get_int_from_uint8(hmask, iqs % (QI3_K / 2)) >> bq8_offset; + + int u[QR3_K]; + float d8[QR3_K]; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + const int8_t * quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1; + u[i] = get_int_from_int8_aligned(quant_base_ptr, iqs % QI8_1); + d8[i] = (*(q8_1_ds + bq8_offset + i))[0]; + } + + return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, scales, scale_offset, static_cast<float>(d), d8); + } +}; + static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales, const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { From 20323e48c4612b508e4776ec6ef93cdefbc8325c Mon Sep 17 00:00:00 2001 From: Neo Zhang <zhang.jianyu@outlook.com> Date: Mon, 1 Jun 2026 14:53:04 +0800 Subject: [PATCH 764/831] Add more types in GET_ROWS OP (llama/23710) * add to support Q1_0, NVFP4, IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ1_S, IQ1_M, IQ3_S, IQ4_NL, IQ4_XS, I32, MXFP4, Q2_K, Q3_K, Q5_K, and Q6_K in GET_ROWS OP * correct the link --- ggml/src/ggml-sycl/dequantize.hpp | 472 ++++++++++++++++++++++++++++++ ggml/src/ggml-sycl/getrows.cpp | 78 ++++- ggml/src/ggml-sycl/ggml-sycl.cpp | 18 ++ 3 files changed, 565 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index a723d2afbd6..ca8cd96c08c 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -20,6 +20,10 @@ typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs, const int iqs, dfloat2 &v); +#if QK_K == 256 +static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m); +#endif + static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -90,6 +94,474 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib, #endif // GGML_SYCL_F16 } +static __dpct_inline__ void dequantize_q4_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q4_K * x = (const block_q4_K *) vx; + const sycl::half2 dm = x[ib].dm; + const float dall = dm[0]; + const float dmin = dm[1]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int il = idx / 64; + const int in = idx % 64; + const int is = 2 * il + (in >= 32 ? 1 : 0); + const int off = in & 31; + const int qsi = 32 * il + off; + + uint8_t sc; + uint8_t m; + get_scale_min_k4(is, x[ib].scales, sc, m); + + const uint8_t q = x[ib].qs[qsi]; + const uint8_t qv = (in >= 32) ? (q >> 4) : (q & 0xF); + return sycl::fma((dfloat) qv, (dfloat) (dall * sc), (dfloat) (-dmin * m)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q4_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q2_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q2_K * x = (const block_q2_K *) vx; + const float dall = x[ib].dm[0]; + const float dmin = x[ib].dm[1]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int n = idx / 128; + const int r = idx % 128; + const int g = r / 32; + const int l = r % 32; + const int is = 8 * n + l / 16; + + const uint8_t q = x[ib].qs[32 * n + l]; + const uint8_t sc = x[ib].scales[is + 2 * g]; + const float d = dall * (sc & 0xF); + const float m = dmin * (sc >> 4); + + return sycl::fma((dfloat) ((q >> (2 * g)) & 3), (dfloat) d, (dfloat) (-m)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q2_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q3_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q3_K * x = (const block_q3_K *) vx; + const float d_all = x[ib].d; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int n = idx / 128; + const int r = idx % 128; + const int j = r / 32; + const int l = r % 32; + + const int is0 = l / 16; + const int is = 8 * n + 2 * j + is0; + const int shift = 2 * j; + const uint8_t m = 1 << (4 * n + j); + + const int8_t us = is < 4 ? (x[ib].scales[is - 0] & 0xF) | (((x[ib].scales[is + 8] >> 0) & 3) << 4) : + is < 8 ? (x[ib].scales[is - 0] & 0xF) | (((x[ib].scales[is + 4] >> 2) & 3) << 4) : + is < 12 ? (x[ib].scales[is - 8] >> 4) | (((x[ib].scales[is + 0] >> 4) & 3) << 4) : + (x[ib].scales[is - 8] >> 4) | (((x[ib].scales[is - 4] >> 6) & 3) << 4); + + const float dl = d_all * (us - 32); + const uint8_t q = x[ib].qs[32 * n + l]; + const uint8_t h = x[ib].hmask[l]; + const int8_t qv = ((q >> shift) & 3) - ((h & m) ? 0 : 4); + + return (dfloat) (dl * qv); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q3_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q5_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q5_K * x = (const block_q5_K *) vx; + const float dall = x[ib].dm[0]; + const float dmin = x[ib].dm[1]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int il = idx / 64; + const int in = idx % 64; + const int is = 2 * il + (in >= 32 ? 1 : 0); + const int ir = (in & 31) / 2; + const int iq = in & 1; + + const uint8_t q = x[ib].qs[32 * il + 2 * ir + iq]; + const uint8_t h = x[ib].qh[2 * ir + iq]; + const uint8_t qv = (in >= 32) ? (q >> 4) : (q & 0xF); + + uint8_t sc; + uint8_t m; + get_scale_min_k4(is, x[ib].scales, sc, m); + + const float d = dall * sc; + const float mn = dmin * m; + const uint8_t hm = 1 << (2 * il + (in >= 32 ? 1 : 0)); + + return sycl::fma((dfloat) (qv + ((h & hm) ? 16 : 0)), (dfloat) d, (dfloat) (-mn)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q5_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q6_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q6_K * x = (const block_q6_K *) vx; + const float d = x[ib].d; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ip = idx / 128; + const int in = idx % 128; + const int il = in & 31; + const int ig = in / 32; + const int is = 8 * ip + il / 16; + + const uint8_t ql0 = x[ib].ql[64 * ip + il]; + const uint8_t ql1 = x[ib].ql[64 * ip + il + 32]; + const uint8_t qh = x[ib].qh[32 * ip + il]; + const int8_t * sc = x[ib].scales + is; + + uint8_t qv; + int8_t scale; + if (ig == 0) { + qv = (ql0 & 0xF) | (((qh >> 0) & 3) << 4); + scale = sc[0]; + } else if (ig == 1) { + qv = (ql1 & 0xF) | (((qh >> 2) & 3) << 4); + scale = sc[2]; + } else if (ig == 2) { + qv = (ql0 >> 4) | (((qh >> 4) & 3) << 4); + scale = sc[4]; + } else { + qv = (ql1 >> 4) | (((qh >> 6) & 3) << 4); + scale = sc[6]; + } + + return (dfloat) (d * scale * ((int8_t) qv - 32)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q6_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_mxfp4(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_mxfp4 * x = (const block_mxfp4 *) vx; + const float d = ggml_sycl_e8m0_to_fp32(x[ib].e); + const uint8_t q = x[ib].qs[iqs]; + + v.x() = d * kvalues_mxfp4[q & 0xF] * 0.5f; + v.y() = d * kvalues_mxfp4[q >> 4] * 0.5f; +} + +static __dpct_inline__ void dequantize_q1_0(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_q1_0 * x = (const block_q1_0 *) vx; + const dfloat d = x[ib].d; + + const int bit_index_0 = iqs + 0; + const int bit_index_1 = iqs + 1; + + const int bit_0 = (x[ib].qs[bit_index_0 / 8] >> (bit_index_0 % 8)) & 1; + const int bit_1 = (x[ib].qs[bit_index_1 / 8] >> (bit_index_1 % 8)) & 1; + + v.x() = (2 * bit_0 - 1) * d; + v.y() = (2 * bit_1 - 1) * d; +} + +static __dpct_inline__ void dequantize_nvfp4(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_nvfp4 & xb = ((const block_nvfp4 *) vx)[ib]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int sub = idx / QK_NVFP4_SUB; + const int j = idx % QK_NVFP4_SUB; + const int jh = j % (QK_NVFP4_SUB / 2); + + const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]); + const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + jh]; + const uint8_t qv = (j < (QK_NVFP4_SUB / 2)) ? (q & 0x0F) : (q >> 4); + + return d * kvalues_mxfp4[qv]; + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +} + +static __dpct_inline__ void dequantize_iq2_xxs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq2_xxs * x = (const block_iq2_xxs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t * q2 = x[ib].qs + 4 * ib8; + const uint8_t * aux8 = (const uint8_t *) q2; + const uint8_t * grid = (const uint8_t *) (iq2xxs_grid + aux8[il]); + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = (float) x[ib].d * (0.5f + (aux32 >> 28)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> (7 * il)) & 127]; + + return d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ2_XXS dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq2_xs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq2_xs * x = (const block_iq2_xs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t * q2 = x[ib].qs + 4 * ib8; + const uint8_t * grid = (const uint8_t *) (iq2xs_grid + (q2[il] & 511)); + const float d = (float) x[ib].d * (0.5f + ((x[ib].scales[ib8] >> (4 * (il / 2))) & 0xf)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[q2[il] >> 9]; + + return d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ2_XS dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq2_s(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq2_s * x = (const block_iq2_s *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t grid_id = x[ib].qs[4 * ib8 + il] | ((x[ib].qh[ib8] << (8 - 2 * il)) & 0x300); + const uint8_t * grid = (const uint8_t *) (iq2s_grid + grid_id); + const float d = (float) x[ib].d * (0.5f + ((x[ib].scales[ib8] >> (4 * (il / 2))) & 0xf)) * 0.25f; + const uint8_t signs = x[ib].qs[QK_K / 8 + 4 * ib8 + il]; + + return d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ2_S dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq3_xxs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq3_xxs * x = (const block_iq3_xxs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint8_t * q3 = x[ib].qs + 8 * ib8; + const uint16_t * gas = (const uint16_t *) (x[ib].qs + QK_K / 4) + 2 * ib8; + const uint8_t * grid1 = (const uint8_t *) (iq3xxs_grid + q3[2 * il + 0]); + const uint8_t * grid2 = (const uint8_t *) (iq3xxs_grid + q3[2 * il + 1]); + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = (float) x[ib].d * (0.5f + (aux32 >> 28)) * 0.5f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> (7 * il)) & 127]; + + if (j < 4) { + return d * grid1[j] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + } + return d * grid2[j - 4] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ3_XXS dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq3_s(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq3_s * x = (const block_iq3_s *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint8_t * qs = x[ib].qs + 8 * ib8; + const uint16_t grid1_id = qs[2 * il + 0] | ((x[ib].qh[ib8] << (8 - 2 * il)) & 256); + const uint16_t grid2_id = qs[2 * il + 1] | ((x[ib].qh[ib8] << (7 - 2 * il)) & 256); + const uint8_t * grid1 = (const uint8_t *) (iq3s_grid + grid1_id); + const uint8_t * grid2 = (const uint8_t *) (iq3s_grid + grid2_id); + const float d = (float) x[ib].d * (1 + 2 * ((x[ib].scales[ib8 / 2] >> (4 * (ib8 % 2))) & 0xf)); + const uint8_t signs = x[ib].signs[4 * ib8 + il]; + + if (j < 4) { + return d * grid1[j] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + } + return d * grid2[j - 4] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ3_S dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq1_s(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq1_s * x = (const block_iq1_s *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const float delta = (x[ib].qh[ib8] & 0x8000) ? (-1.f - IQ1S_DELTA) : (-1.f + IQ1S_DELTA); + const float d = (float) x[ib].d * (2 * ((x[ib].qh[ib8] >> 12) & 7) + 1); + const uint16_t grid_id = x[ib].qs[4 * ib8 + il] | (((x[ib].qh[ib8] >> (3 * il)) & 7) << 8); + const uint32_t g = iq1s_grid_gpu[grid_id]; + const int8_t qv = (j < 4) ? ((g >> (8 * j)) & 0x0F) : ((g >> (8 * (j - 4) + 4)) & 0x0F); + + return d * (qv + delta); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ1_S dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq1_m(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq1_m * x = (const block_iq1_m *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t * sc = (const uint16_t *) x[ib].scales; + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + const int ib16 = 2 * ib8 + il / 2; + const float d = (float) scale.f16 * (2 * ((sc[ib16 / 4] >> (3 * (ib16 % 4))) & 0x7) + 1); + + const uint8_t qh = x[ib].qh[2 * ib8 + il / 2]; + const float delta = (qh & (0x08 << (4 * (il % 2)))) ? (-1.f - IQ1M_DELTA) : (-1.f + IQ1M_DELTA); + + const uint16_t grid_id = x[ib].qs[4 * ib8 + il] | (((qh >> (4 * (il % 2))) & 7) << 8); + const uint32_t g = iq1s_grid_gpu[grid_id]; + const int8_t qv = (j < 4) ? ((g >> (8 * j)) & 0x0F) : ((g >> (8 * (j - 4) + 4)) & 0x0F); + + return d * (qv + delta); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ1_M dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq4_nl(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_iq4_nl * x = (const block_iq4_nl *) vx; + const float d = (float) x[ib].d; + + auto dequantize_one = [&](const int idx) -> dfloat { + if (idx < 16) { + return d * kvalues_iq4nl[x[ib].qs[idx] & 0xF]; + } + return d * kvalues_iq4nl[x[ib].qs[idx - 16] >> 4]; + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +} + +static __dpct_inline__ void dequantize_iq4_xs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq4_xs * x = (const block_iq4_xs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int byte_idx = (r < 16) ? r : (r - 16); + const uint8_t q = x[ib].qs[16 * ib8 + byte_idx]; + const uint8_t qv = (r < 16) ? (q & 0x0F) : (q >> 4); + + const float d = (float) x[ib].d * ((((x[ib].scales_l[ib8 / 2] >> (4 * (ib8 % 2))) & 0xf) | + (((x[ib].scales_h >> (2 * ib8)) & 3) << 4)) - 32); + return d * kvalues_iq4nl[qv]; + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ4_XS dequantize not supported for QK_K != 256"); +#endif +} + static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q5_0 * x = (const block_q5_0 *) vx; diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp index ca457454775..298f247f84e 100644 --- a/ggml/src/ggml-sycl/getrows.cpp +++ b/ggml/src/ggml-sycl/getrows.cpp @@ -129,11 +129,11 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr GGML_UNUSED(ctx); } -template <typename src0_t> +template <typename src0_t, typename dst_t> static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const src0_t *src0_dd, const int32_t *src1_dd, - float *dst_dd, queue_ptr stream) { + dst_t *dst_dd, queue_ptr stream) { GGML_TENSOR_BINARY_OP_LOCALS @@ -170,7 +170,7 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_I32 ); GGML_ASSERT(dst->src[0]->nb[0] == ggml_type_size(dst->src[0]->type)); GGML_ASSERT(dst->src[1]->nb[0] == ggml_type_size(dst->src[1]->type)); @@ -191,6 +191,66 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_I32: + get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const int32_t *)dst->src[0]->data, + src1_i32, (int32_t *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q1_0: + get_rows_sycl<QK1_0, 1, dequantize_q1_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_MXFP4: + get_rows_sycl<QK_MXFP4, 2, dequantize_mxfp4>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_NVFP4: + get_rows_sycl<QK_NVFP4, 1, dequantize_nvfp4>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ2_XXS: + get_rows_sycl<QK_K, 1, dequantize_iq2_xxs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ2_XS: + get_rows_sycl<QK_K, 1, dequantize_iq2_xs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ2_S: + get_rows_sycl<QK_K, 1, dequantize_iq2_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ3_XXS: + get_rows_sycl<QK_K, 1, dequantize_iq3_xxs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ1_S: + get_rows_sycl<QK_K, 1, dequantize_iq1_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ1_M: + get_rows_sycl<QK_K, 1, dequantize_iq1_m>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ3_S: + get_rows_sycl<QK_K, 1, dequantize_iq3_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ4_NL: + get_rows_sycl<QK4_NL, 1, dequantize_iq4_nl>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ4_XS: + get_rows_sycl<QK_K, 1, dequantize_iq4_xs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q2_K: + get_rows_sycl<QK_K, 1, dequantize_q2_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q3_K: + get_rows_sycl<QK_K, 1, dequantize_q3_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_Q4_0: get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); @@ -199,6 +259,10 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_Q4_K: + get_rows_sycl<QK_K, 1, dequantize_q4_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_Q5_0: get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); @@ -207,6 +271,14 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_Q5_K: + get_rows_sycl<QK_K, 1, dequantize_q5_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q6_K: + get_rows_sycl<QK_K, 1, dequantize_q6_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_Q8_0: get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index e59f5c174d3..96138f57ebe 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -5301,13 +5301,31 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { + case GGML_TYPE_I32: case GGML_TYPE_F16: case GGML_TYPE_BF16: case GGML_TYPE_F32: + case GGML_TYPE_Q1_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: case GGML_TYPE_Q8_0: return true; default: From ec0c6619500e86d2c2f290d1c95a3f022397fbcf Mon Sep 17 00:00:00 2001 From: Neo Zhang <zhang.jianyu@outlook.com> Date: Mon, 1 Jun 2026 14:53:53 +0800 Subject: [PATCH 765/831] Support Q4_1, Q5_0, Q5_1 in Flash-attention (llama/23812) * support Q4_1, Q5_0, Q5_1 * update ut case --- ggml/src/ggml-sycl/common.hpp | 1 + ggml/src/ggml-sycl/fattn-common.hpp | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 31e26ff48e4..d8bb3638dfd 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -45,6 +45,7 @@ namespace syclexp = sycl::ext::oneapi::experimental; #define GGML_COMMON_IMPL_SYCL #define SYCL_FLASH_ATTN //remove it to disable FLASH_ATTENTION in building. #define SYCL_FAST_FP16 //don't change. remove it will break fattn-tile.hpp building +#define GGML_SYCL_FA_ALL_QUANTS //define it to enable all quantization types in flash attention. undefine it to only support F16, Q4_0 and Q8_0 in flash attention. /* suppress warning spam */ #pragma clang diagnostic push diff --git a/ggml/src/ggml-sycl/fattn-common.hpp b/ggml/src/ggml-sycl/fattn-common.hpp index 03f0c2623c8..c6cc13cfb00 100644 --- a/ggml/src/ggml-sycl/fattn-common.hpp +++ b/ggml/src/ggml-sycl/fattn-common.hpp @@ -1031,7 +1031,7 @@ void launch_fattn( auto KV_max_ptr_ct1 = KV_max.ptr; cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { GGML_UNUSED(item_ct1); flash_attn_mask_to_KV_max<ncols1, warp_size>( mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33, @@ -1149,7 +1149,7 @@ void launch_fattn( auto K_ne_ct6 = K->ne[2]; cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { GGML_UNUSED(item_ct1); flash_attn_stream_k_fixup<DV, ncols1, ncols2>(KQV_data_ct0, dst_tmp_meta_ptr_ct1, Q_ne_ct2, Q_ne_ct3, Q_ne_ct4, @@ -1169,7 +1169,7 @@ void launch_fattn( auto KQV_data_ct2 = (float *) KQV->data; cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { GGML_UNUSED(item_ct1); flash_attn_combine_results<DV>( dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks, From aea93ada610cf565e0585dfe2822cd4a2206a488 Mon Sep 17 00:00:00 2001 From: Winston Ma <winstonma@ymail.com> Date: Mon, 1 Jun 2026 17:46:23 +0800 Subject: [PATCH 766/831] vulkan: Removed unused functions (llama/23175) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2a30fb95c61..74104149db8 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7166,13 +7166,6 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx->s->buffer->buf.dispatch(wg0, wg1, wg2); } -static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) { - s.buffer->buf.end(); - - s.wait_semaphores = std::move(wait_semaphores); - s.signal_semaphores = std::move(signal_semaphores); -} - static void ggml_vk_ctx_end(vk_context& ctx) { VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")"); if (ctx->s == nullptr) { @@ -14510,12 +14503,6 @@ static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_ty UNUSED(buft); } -static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) { - return GGML_VK_NAME "_Host"; - - UNUSED(buffer); -} - static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); ggml_vk_host_free(vk_instance.devices[0], buffer->context); From 982533fc0c38dabc6f7fa9155b7e33e5f565e223 Mon Sep 17 00:00:00 2001 From: Matt Corallo <649246+TheBlueMatt@users.noreply.github.com> Date: Mon, 1 Jun 2026 09:46:48 +0000 Subject: [PATCH 767/831] vulkan: Block-load Q3_K/Q6_K block data and subtract on 32b ints (llama/23056) Q2_K/Q3_K/Q6_K do much better when using MMVQ on Intel BMG even though they're only 2-byte aligned, and Q3_K still wins on NVIDIA as well. mesa isn't all that great at coalescing back-to-back loads from alternating arrays, so we force it instead. Further, we can do subtraction directly on a full int32_t rather than an i8vec4 with bit twiddling because the high bit is always free to start. On Intel BMG on mesa, the switch to MMVQ provides an immediate ~57% perf increase in tg128 for unsloth/Qwen3.5-9B-GGUF:Q3_K and ~78% perf increase in tg128 for unsloth/Qwen3.5-9B-GGUF:Q6_K. The futher switch to block loads leads to a ~24% perf increase in tg128 for unsloth/Qwen3.5-9B-GGUF:Q3_K and a ~48% perf increase in tg128 for unsloth/Qwen3.5-9B-GGUF:Q6_K. Finally, Xe2 wins on MMVQ even for small k, so we take the NVIDIA override for K quants on Xe2 as well. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 19 ++- .../vulkan-shaders/mul_mat_vecq_funcs.glsl | 108 +++++++++++------- 2 files changed, 80 insertions(+), 47 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 74104149db8..3cf191f2085 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -8336,8 +8336,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return false; } - // General performance issue with q3_k and q6_k due to 2-byte alignment - if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) { + // q6_k only has 2-byte alignment which makes it somewhat problematic, + // using MMVQ is only a win on Intel. + bool mmvq_q6 = device->vendor_id == VK_VENDOR_ID_INTEL; + if (src0_type == GGML_TYPE_Q6_K && !mmvq_q6) { return false; } @@ -8349,7 +8351,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ // Quantization overhead is not worth it for small k switch (device->vendor_id) { case VK_VENDOR_ID_NVIDIA: - if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) { + if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) { return true; } @@ -8376,9 +8378,16 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return true; } case VK_VENDOR_ID_INTEL: + if (device->architecture == vk_device_architecture::INTEL_XE2) { + if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) { + return true; + } + } + if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) { - // Intel Windows proprietary driver MMVQ performance is worse than fp16, see - // https://github.com/ggml-org/llama.cpp/issues/17628 + // Intel Windows proprietary driver MMVQ performance for !Q2/Q3/Q6 is worse than fp16, + // see https://github.com/ggml-org/llama.cpp/issues/17628 and + // https://github.com/ggml-org/llama.cpp/pull/23056 return false; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index bc580aeeb83..73cf9c79955 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -212,28 +212,40 @@ i32vec4 repack4(uint ib, uint iqs) { const uint qs_shift = ((iqs_k % 32) / 8) * 2; const uint hm_shift = iqs_k / 8; + const uvec4 qs = uvec4( uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 ]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 7]) << 16)); + + const uvec4 hmask = uvec4( uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 ]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 7]) << 16)); + // bitwise OR to add 4 if hmask is set, subtract later - const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2)); - - return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)), - pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)), - pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)), - pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4))); + const uint vals0 = (( qs.x >> qs_shift) & 0x03030303) | + (((hmask.x >> hm_shift) & 0x01010101) << 2); + const uint vals1 = (( qs.y >> qs_shift) & 0x03030303) | + (((hmask.y >> hm_shift) & 0x01010101) << 2); + const uint vals2 = (( qs.z >> qs_shift) & 0x03030303) | + (((hmask.z >> hm_shift) & 0x01010101) << 2); + const uint vals3 = (( qs.w >> qs_shift) & 0x03030303) | + (((hmask.w >> hm_shift) & 0x01010101) << 2); + + // Subtract 4 by twiddling bits rather than using re-packing as mesa + // compiles repacking poorly. + return i32vec4(int32_t(((vals0 ^ 0x80808080) - 0x04040404) ^ 0x80808080), + int32_t(((vals1 ^ 0x80808080) - 0x04040404) ^ 0x80808080), + int32_t(((vals2 ^ 0x80808080) - 0x04040404) ^ 0x80808080), + int32_t(((vals3 ^ 0x80808080) - 0x04040404) ^ 0x80808080)); } float get_d_scale(uint ib, uint iqs) { @@ -343,27 +355,39 @@ i32vec4 repack4(uint ib, uint iqs) { const uint qh_idx = (iqs_k / 32) * 8 + iqs; const uint qh_shift = ((iqs_k % 32) / 8) * 2; - const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - - return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)), - pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)), - pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)), - pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y))); + const uvec4 ql = uvec4( uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 ]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 7]) << 16)); + + const uvec4 qh = uvec4( uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 ]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 7]) << 16)); + + const uint vals0 = (( ql.x >> ql_shift) & 0x0F0F0F0F) | + (((qh.x >> qh_shift) & 0x03030303) << 4); + const uint vals1 = (( ql.y >> ql_shift) & 0x0F0F0F0F) | + (((qh.y >> qh_shift) & 0x03030303) << 4); + const uint vals2 = (( ql.z >> ql_shift) & 0x0F0F0F0F) | + (((qh.z >> qh_shift) & 0x03030303) << 4); + const uint vals3 = (( ql.w >> ql_shift) & 0x0F0F0F0F) | + (((qh.w >> qh_shift) & 0x03030303) << 4); + + // Subtract 32 by twiddling bits rather than using re-packing as mesa + // compiles repacking poorly. + return i32vec4(int32_t(((vals0 ^ 0x80808080) - 0x20202020) ^ 0x80808080), + int32_t(((vals1 ^ 0x80808080) - 0x20202020) ^ 0x80808080), + int32_t(((vals2 ^ 0x80808080) - 0x20202020) ^ 0x80808080), + int32_t(((vals3 ^ 0x80808080) - 0x20202020) ^ 0x80808080)); } float get_d_scale(uint ib, uint iqs) { From e815b264eba131de2d06be039612fad9eb4330f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= <johannesg@5d6.de> Date: Mon, 1 Jun 2026 12:30:10 +0200 Subject: [PATCH 768/831] TP: quantized KV cache support (llama/23792) * TP: quantized KV cache support * fix partial view * remove overly strict assert --- ggml/include/ggml-backend.h | 10 +- ggml/src/ggml-backend-meta.cpp | 278 +++++++++++++++++---------------- 2 files changed, 149 insertions(+), 139 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index b6f73739809..2924fdbe988 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -381,11 +381,15 @@ extern "C" { // - most tensors have n_segments == 1 and a contiguous slice of the tensor data // - some tensors have an inhomogenenous data layout along the split axis, // those tensors are divided into segments which are each individually split across devices - // - ne has one entry per segment and device that add up to ggml_tensor::ne for that axis, - // the outer/inner loops are over segments/devices like [seg0_dev0, seg0_dev1, seg1_dev0, seg1_dev1], + // - ne has one entry per segment and device and that segment repeats nr times, + // in total when accounting for repetitions the segments add up to ggml_tensor::ne for that axis, + // the outer/inner loops are over segments/devices like [seg0_dev0_r0, seg0_dev1_r0, seg0_dev0_r1, seg0_dev1_r1, seg1_dev0_r0, seg1_dev1_r0], // - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments - // that each need to be split individually across devices so that each device gets a slice of Q, K, and V + // that each need to be split individually across devices so that each device gets a slice of Q, K, and V, + // the Q matrix can be larger than the K and V matrices so this can either be expressed as 3 segments or as 2 segments + // where the segment for K/V repeats twice int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES]; + uint32_t nr[16]; uint32_t n_segments; }; diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 48b2027fac3..8c44c3e44ae 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -487,6 +487,9 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(co static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( ggml_backend_meta_simple_tensor_container & stc, const struct ggml_tensor * tensor, bool assume_sync) { + // FIXME Currently this function preserves/erases the information in n_segments and nr in an inconsistent way. + // Since the operations in question are developed specifically for llama.cpp this currently does not manifest as a bug there. + // However, in a broader ggml context with arbitrary ggml graphs this can lead to unexpected results. const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; @@ -497,11 +500,11 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( for (size_t j = 0; j < n_bufs; j++) { int64_t sum_a = 0; for (size_t s = 0; s < a.n_segments; s++) { - sum_a += a.ne[s*n_bufs + j]; + sum_a += a.ne[s*n_bufs + j] * a.nr[s]; } int64_t sum_b = 0; for (size_t s = 0; s < b.n_segments; s++) { - sum_b += b.ne[s*n_bufs + j]; + sum_b += b.ne[s*n_bufs + j] * b.nr[s]; } if (sum_a != sum_b) { return false; @@ -511,7 +514,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( }; auto handle_generic = [&](const std::vector<ggml_backend_meta_split_state> & src_ss, bool scalar_only) -> ggml_backend_meta_split_state { - ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}; + ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1}; for (size_t i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { continue; @@ -519,15 +522,15 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { ret = src_ss[i]; } else if (!split_states_equal(src_ss[i], ret)) { - ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; break; } } if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { - ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } if (scalar_only && ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { - ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); return ret; @@ -571,42 +574,24 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( auto handle_mul_mat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1}; } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { ggml_backend_meta_split_state ret = src_ss[0]; ret.axis = GGML_BACKEND_SPLIT_AXIS_0; + ret.nr[0] = 1; ret.n_segments = 1; return ret; } if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { - ggml_backend_meta_split_state ret = src_ss[1]; - ret.n_segments = 1; - return ret; + return src_ss[1]; } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) { GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1])); - return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, 1}; + return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, {1}, 1}; } GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; - }; - - auto handle_cpy = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { - if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { - int64_t ne_split_src = tensor->src[0]->ne[0]; - for (int dim = 1; dim <= src_ss[0].axis; dim++) { - ne_split_src *= tensor->src[0]->ne[dim]; - } - int64_t ne_split_dst = 1; - for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { - ne_split_dst *= tensor->ne[dim]; - if (ne_split_dst == ne_split_src) { - return {ggml_backend_meta_split_axis(dim), {0}, 1}; - } - } - } - return handle_generic(src_ss, /*scalar_only =*/ false); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; }; auto handle_reshape = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { @@ -615,33 +600,25 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( case GGML_BACKEND_SPLIT_AXIS_1: case GGML_BACKEND_SPLIT_AXIS_2: case GGML_BACKEND_SPLIT_AXIS_3: { - GGML_ASSERT(!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0])); - if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1) { - return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, 1}; + GGML_ASSERT(src_ss[0].n_segments == 1); + if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1 && src_ss[0].nr[0] == 1) { + return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, {1}, 1}; } - std::vector<int64_t> base_ne_in; - base_ne_in.reserve(GGML_MAX_DIMS - src_ss[0].axis); - { - base_ne_in.push_back(1); - int dim = 0; - for (; dim <= src_ss[0].axis; dim++) { - base_ne_in[0] *= tensor->src[0]->ne[dim]; - } - for (; dim <= GGML_MAX_DIMS; dim++) { - base_ne_in.push_back(base_ne_in.back() * tensor->src[0]->ne[dim]); - } + int64_t base_ne_in = tensor->src[0]->ne[0]; + for (int dim = 1; dim <= src_ss[0].axis; dim++) { + base_ne_in *= tensor->src[0]->ne[dim]; } + base_ne_in /= src_ss[0].nr[0]; int64_t base_ne_out = 1; for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim]; - for (const int64_t & bni : base_ne_in) { - if (bni == base_ne_out_next) { - return {ggml_backend_meta_split_axis(dim), {0}, 1}; - } + if (base_ne_out_next % base_ne_in == 0) { + return {ggml_backend_meta_split_axis(dim), {0}, {uint32_t(base_ne_out_next/base_ne_in)}, 1}; } - if (base_ne_out_next > base_ne_in[0]) { - GGML_ASSERT(dim + 1 < GGML_MAX_DIMS); - return {ggml_backend_meta_split_axis(dim + 1), {0}, 1}; + if (base_ne_out_next > base_ne_in) { + GGML_ASSERT(src_ss[0].n_segments == 1); + GGML_ASSERT(src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1}; } base_ne_out = base_ne_out_next; } @@ -653,11 +630,18 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } default: { GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } } }; + auto handle_cpy = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + return handle_reshape(src_ss); + } + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) { return handle_reshape(src_ss); @@ -681,7 +665,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) { for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) { if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) { - return {ggml_backend_meta_split_axis(dim), {0}, 1}; + return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1}; } } GGML_ABORT("fatal error"); @@ -690,7 +674,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( return src_ss[0]; } GGML_ABORT("view of permuted tensor not implemented"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; }; auto handle_permute = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { @@ -699,7 +683,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( case GGML_BACKEND_SPLIT_AXIS_1: case GGML_BACKEND_SPLIT_AXIS_2: case GGML_BACKEND_SPLIT_AXIS_3: { - return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, 1}; + GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, {src_ss[0].nr[0]}, 1}; } case GGML_BACKEND_SPLIT_AXIS_MIRRORED: case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { @@ -707,7 +692,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } default: { GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } } }; @@ -716,7 +701,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( switch (src_ss[0].axis) { case GGML_BACKEND_SPLIT_AXIS_0: case GGML_BACKEND_SPLIT_AXIS_1: { - return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, 1}; + GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, {src_ss[0].nr[0]}, 1}; } case GGML_BACKEND_SPLIT_AXIS_2: case GGML_BACKEND_SPLIT_AXIS_3: @@ -726,7 +712,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } default: { GGML_ABORT("fatal error"); - //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } } }; @@ -764,16 +750,16 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( GGML_ASSERT( src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_2); GGML_ASSERT(tensor->src[4] == nullptr || src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); GGML_ASSERT(tensor->src[4] == nullptr || src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_0); - return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1}; }; auto handle_ssm_conv = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { if (src_ss[0].axis == src_ss[1].axis) { if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) { - return {GGML_BACKEND_SPLIT_AXIS_1, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1}; } if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) { - return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; } } return handle_generic(src_ss, /*scalar_only =*/ false); @@ -781,8 +767,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( auto handle_gated_delta_net = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && - src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && - src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && + src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { return src_ss[0]; } GGML_ASSERT(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1); @@ -793,12 +779,12 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); - return {GGML_BACKEND_SPLIT_AXIS_0, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; }; auto calculate_split_state = [&]() -> ggml_backend_meta_split_state { if (ggml_nelements(tensor) == 0) { - return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) { ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer)); @@ -807,19 +793,21 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) { const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1; int64_t ne_sum = 0; - for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) { - GGML_ASSERT(ret.ne[sj] % granularity == 0); - ne_sum += ret.ne[sj]; + for (size_t s = 0; s < ret.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + GGML_ASSERT(ret.ne[s*n_bufs + j] % granularity == 0); + ne_sum += ret.ne[s*n_bufs + j] * ret.nr[s]; + } } GGML_ASSERT(ne_sum == tensor->ne[ret.axis]); } return ret; } - std::vector<ggml_backend_meta_split_state> src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, 1}); + std::vector<ggml_backend_meta_split_state> src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1}); for (size_t i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { - src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; continue; } src_ss[i] = ggml_backend_meta_get_split_state(stc, tensor->src[i], /*assume_sync =*/ true); @@ -829,7 +817,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( ggml_backend_meta_split_state split_state; switch (tensor->op) { case GGML_OP_NONE: { - split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, 1}; + split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1}; } break; case GGML_OP_DUP: { split_state = handle_generic(src_ss, /*scalar_only =*/ true); @@ -1016,7 +1004,7 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( } break; default: { GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op)); - split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, 1}; + split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; } break; } if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { @@ -1034,23 +1022,25 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( split_state.ne[s*n_bufs + j] = 0; } for (size_t s = 0; s < src_ss[i].n_segments; s++) { - split_state.ne[j] += src_ss[i].ne[s*n_bufs + j]; + split_state.ne[j] += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s]; } split_state.ne[j] *= tensor->ne[split_state.axis]; if (split_state.ne[j] != 0 || tensor->src[i]->ne[src_ss[i].axis] != 0) { - GGML_ASSERT(split_state.ne[j] % tensor->src[i]->ne[src_ss[i].axis] == 0); - split_state.ne[j] /= tensor->src[i]->ne[src_ss[i].axis]; + const int64_t div = tensor->src[i]->ne[src_ss[i].axis] * split_state.nr[0]; + GGML_ASSERT(split_state.ne[j] % div == 0); + split_state.ne[j] /= div; } } } else { + GGML_ASSERT(split_state.n_segments == 1); for (size_t j = 0; j < n_bufs; j++) { + // Assert that ratio is consistent: int64_t sum = 0; for (size_t s = 0; s < src_ss[i].n_segments; s++) { - sum += src_ss[i].ne[s*n_bufs + j]; + sum += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s]; } - // Assert that ratio is consistent: - GGML_ASSERT(split_state.ne[j] * tensor->src[i]->ne[src_ss[i].axis] - == sum * tensor->ne[split_state.axis]); + GGML_ASSERT(split_state.ne[j]*split_state.nr[0] * tensor->src[i]->ne[src_ss[i].axis] + == sum * tensor->ne[split_state.axis]); } } first_src_split_by_axis = false; @@ -1080,13 +1070,14 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( srcs_info += ", "; } const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true); + GGML_ASSERT(split_state.n_segments == 1); const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis); std::string ne_info; for (size_t j = 0; j < n_bufs; j++) { if (!ne_info.empty()) { ne_info += ", "; } - ne_info += std::to_string(split_state.ne[j]); + ne_info += std::to_string(split_state.ne[j]) + "x" + std::to_string(split_state.nr[0]); } srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]"; } @@ -1095,7 +1086,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( if (!ne_info.empty()) { ne_info += ", "; } - ne_info += std::to_string(buf_ctx->split_state_cache[key].first.ne[j]); + const ggml_backend_meta_split_state & ss = buf_ctx->split_state_cache[key].first; + ne_info += std::to_string(ss.ne[j]) + "x" + std::to_string(ss.nr[0]); } GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op), ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].first.axis), ne_info.c_str()); @@ -1107,8 +1099,10 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( #ifndef NDEBUG if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { int64_t ne_ret = 0; - for (size_t sj = 0; sj < ret.n_segments*n_bufs; sj++) { - ne_ret += ret.ne[sj]; + for (size_t s = 0; s < ret.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + ne_ret += ret.ne[s*n_bufs + j] * ret.nr[s]; + } } assert(ne_ret == tensor->ne[int(ret.axis)]); } @@ -1155,7 +1149,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_m // GGML_ASSERT(ggml_is_contiguously_allocated(tensor)); ne[split_dim] = 0; for (size_t s = 0; s < split_state.n_segments; s++) { - ne[split_dim] += split_state.ne[s*n_simple_bufs + j]; + ne[split_dim] += split_state.ne[s*n_simple_bufs + j] * split_state.nr[s]; } for (int i = 0; i < GGML_MAX_DIMS; i++) { if (tensor->nb[i] > tensor->nb[split_dim]) { @@ -1229,7 +1223,7 @@ static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_m for (size_t j = 0; j < n_simple_bufs; j++) { int64_t ne_sum = 0; for (size_t s = 0; s < split_state_src.n_segments; s++) { - ne_sum += split_state_src.ne[s*n_simple_bufs + j]; + ne_sum += split_state_src.ne[s*n_simple_bufs + j] * split_state_src.nr[s]; } if (ne_sum == 0) { simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; @@ -1255,8 +1249,9 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - if (split_state.n_segments != 1) { + if (split_state.n_segments != 1 || split_state.nr[0] != 1) { GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(split_state.nr[0] != 0); GGML_ASSERT(tensor->ne[3] == 1); size_t offset_data = 0; @@ -1267,24 +1262,26 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg const size_t row_stride = tensor->nb[1]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[1]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[1]); const int64_t blck_size = ggml_blck_size(tensor->type); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); - const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; - ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes, - r_count, simple_tensor->nb[1], tensor->nb[1]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes, + row_count, simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); @@ -1292,22 +1289,24 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg const size_t row_stride = tensor->nb[2]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[2]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[2]); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; - ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes, - r_count, simple_tensor->nb[2], tensor->nb[2]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes, + row_count, simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } @@ -1365,8 +1364,9 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); - if (split_state.n_segments != 1) { + if (split_state.n_segments != 1 || split_state.nr[0] != 1) { GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(split_state.nr[0] != 0); GGML_ASSERT(tensor->ne[3] == 1); size_t offset_data = 0; @@ -1377,24 +1377,26 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co const size_t row_stride = tensor->nb[1]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[1]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[1]); const int64_t blck_size = ggml_blck_size(tensor->type); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); - const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; - ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes, - r_count, simple_tensor->nb[1], tensor->nb[1]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes, + row_count, simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); @@ -1402,22 +1404,24 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co const size_t row_stride = tensor->nb[2]; GGML_ASSERT(offset % row_stride == 0); GGML_ASSERT(size % row_stride == 0); - const int64_t r_start = offset / row_stride; - const int64_t r_count = size / row_stride; - GGML_ASSERT(r_start + r_count <= tensor->ne[2]); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[2]); for (size_t s = 0; s < split_state.n_segments; s++) { - for (size_t j = 0; j < n_bufs; j++) { - const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); - const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; - ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, - simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes, - r_count, simple_tensor->nb[2], tensor->nb[2]); - offset_data += nbytes; - simple_offsets[j] += nbytes; + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes, + row_count, simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } } } - GGML_ASSERT(offset_data*r_count == size); + GGML_ASSERT(offset_data*row_count == size); return; } @@ -1675,6 +1679,7 @@ static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tens const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); GGML_ASSERT(split_state.n_segments == 1); + GGML_ASSERT(split_state.nr[0] == 1); switch (split_state.axis) { case GGML_BACKEND_SPLIT_AXIS_0: @@ -1719,6 +1724,7 @@ static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggm const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); GGML_ASSERT(split_state.n_segments == 1); + GGML_ASSERT(split_state.nr[0] == 1); switch (split_state.axis) { case GGML_BACKEND_SPLIT_AXIS_0: From c471bcce1b2d7cdaf372d621b64b1275cb7a01d8 Mon Sep 17 00:00:00 2001 From: Winston Ma <winstonma@ymail.com> Date: Mon, 1 Jun 2026 20:03:32 +0800 Subject: [PATCH 769/831] vulkan: reduce host memory lock contention (llama/23376) * vulkan: reduces lock contention * replace unique_lock with lock_guard --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3cf191f2085..c3d4c7a7129 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -62,6 +62,7 @@ typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV { #include <map> #include <set> #include <unordered_map> +#include <shared_mutex> #include <mutex> #include <future> #include <thread> @@ -618,6 +619,7 @@ static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_vie struct vk_device_struct { std::recursive_mutex mutex; + mutable std::shared_mutex pinned_memory_mutex; vk::PhysicalDevice physical_device; vk::PhysicalDeviceProperties properties; @@ -7010,7 +7012,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) { return nullptr; } - std::lock_guard<std::recursive_mutex> guard(device->mutex); + std::lock_guard<std::shared_mutex> guard(device->pinned_memory_mutex); device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf)); return buf->ptr; @@ -7021,7 +7023,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) { return; } VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")"); - std::lock_guard<std::recursive_mutex> guard(device->mutex); + std::lock_guard<std::shared_mutex> guard(device->pinned_memory_mutex); vk_buffer buf; size_t index; @@ -7045,7 +7047,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) { } static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { - std::lock_guard<std::recursive_mutex> guard(device->mutex); + std::shared_lock<std::shared_mutex> guard(device->pinned_memory_mutex); buf = nullptr; buf_offset = 0; for (size_t i = 0; i < device->pinned_memory.size(); i++) { From 71d80aa49eb93868a8ed7e9f8abeae9e061adcfe Mon Sep 17 00:00:00 2001 From: Jeff Bolz <jbolz@nvidia.com> Date: Mon, 1 Jun 2026 07:04:01 -0500 Subject: [PATCH 770/831] vulkan: don't hold the device mutex while compiling pipelines (llama/23641) * vulkan: don't hold the device mutex while compiling pipelines We need to hold a lock while we traverse all pipelines and lazily initialize them, but we don't need to hold it while the pipeline is being compiled. And it doesn't need to be the same lock as the device mutex. We call load_shaders each time a pipeline is needed, so we only need to compile that one pipeline (and, for example, don't want to end up compiling a pipeline that another thread should be compiling). * remove 'needed' --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 144 ++++++++++++++++++--------- 1 file changed, 99 insertions(+), 45 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c3d4c7a7129..e7d04634b8a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -65,6 +65,7 @@ typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV { #include <shared_mutex> #include <mutex> #include <future> +#include <condition_variable> #include <thread> #if defined(_MSC_VER) @@ -159,8 +160,9 @@ struct vk_pipeline_struct { uint32_t align; // true if fields have been set by ggml_vk_create_pipeline bool initialized {}; - // set to true to request the pipeline is compiled - std::atomic<bool> needed {}; + // true while a compile is in flight, used to dedupe concurrent claims. + // Protected by device->compile_mutex. + bool compile_pending {}; // set to true when the shader has been compiled std::atomic<bool> compiled {}; // number of registers used, extracted from pipeline executable properties @@ -621,6 +623,13 @@ struct vk_device_struct { std::recursive_mutex mutex; mutable std::shared_mutex pinned_memory_mutex; + // Guards compile_pending, all_pipelines, and the dynamic pipeline maps + // (flash_attn, fa_mask_opt, solve_tri, conv2d, etc). The actual compile + // runs with no lock held, so different pipelines can compile in parallel. + // Lock order is device->mutex -> compile_mutex, never the reverse. + std::mutex compile_mutex; + std::condition_variable compile_cv; + vk::PhysicalDevice physical_device; vk::PhysicalDeviceProperties properties; std::string name; @@ -1729,7 +1738,7 @@ struct ggml_vk_garbage_collector { }; static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx); -static void ggml_vk_load_shaders(vk_device& device); +static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested = nullptr); static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx); static bool vk_memory_logger_enabled = false; @@ -2196,11 +2205,6 @@ static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) { ctx->device->device.resetFences({ ctx->fence }); } -// variables to track number of compiles in progress -static uint32_t compile_count = 0; -static std::mutex compile_count_mutex; -static std::condition_variable compile_count_cond; - static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367; static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447; static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4; @@ -2495,7 +2499,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin std::cerr << "ggml_vulkan: " << e.what() << std::endl; throw e; } - pipeline->compiled = true; if (vk_instance.debug_utils_support) { vk::DebugUtilsObjectNameInfoEXT duoni; @@ -2544,14 +2547,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } } - device->all_pipelines.push_back(pipeline); - { - std::lock_guard<std::mutex> guard(compile_count_mutex); - assert(compile_count > 0); - compile_count--; + std::lock_guard<std::mutex> guard(device->compile_mutex); + device->all_pipelines.push_back(pipeline); + pipeline->compiled = true; + pipeline->compile_pending = false; } - compile_count_cond.notify_all(); + device->compile_cv.notify_all(); } static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { @@ -2567,8 +2569,7 @@ static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); ctx->pipeline_descriptor_set_requirements += n; if (!pipeline->compiled) { - pipeline->needed = true; - ggml_vk_load_shaders(ctx->device); + ggml_vk_load_shaders(ctx->device, pipeline); } ggml_pipeline_allocate_descriptor_sets(ctx); } @@ -3567,10 +3568,26 @@ static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type #endif } -static void ggml_vk_load_shaders(vk_device& device) { +// load_shaders walks the pipeline list under compile_mutex and either claims +// the requested pipeline for compilation or, if another thread is already +// compiling it, drops the lock and waits on compile_cv. Compiles themselves +// run unlocked. +struct CompileTask { + vk_pipeline pipeline; + size_t spv_size; + const void * spv_data; + std::string entrypoint; + uint32_t parameter_count; + std::array<uint32_t, 3> wg_denoms; + std::vector<uint32_t> specialization_constants; + bool disable_robustness; + bool require_full_subgroups; + uint32_t required_subgroup_size; +}; + +static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); - std::lock_guard<std::recursive_mutex> guard(device->mutex); // some shaders have a minimum subgroup size const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u); const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); @@ -3600,6 +3617,15 @@ static void ggml_vk_load_shaders(vk_device& device) { l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; uint32_t l_align, m_align, s_align; + + vk_pipeline wait_pipeline; + CompileTask claimed_task {}; + bool has_claimed_task = false; + + // The rest of the walk reads and writes shared device state, so hold the + // lock until we're done deciding what to compile. + std::unique_lock<std::mutex> compile_lock(device->compile_mutex); + if (device->coopmat2) { // spec constants and tile sizes for non-quant matmul/matmul_id l_warptile = { 256, 128, 256, 64, 1 }; @@ -3785,7 +3811,6 @@ static void ggml_vk_load_shaders(vk_device& device) { device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>(); } - std::vector<std::future<void>> compiles; auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { @@ -3819,23 +3844,33 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif } - if (!pipeline->needed || pipeline->compiled) { + // We only care about the pipeline this call asked for; the rest + // (including the 64-bit indexing variant) are handled by their + // own request_descriptor_sets / load_shaders calls. + if (pipeline.get() != requested.get()) { continue; } - // TODO: We're no longer benefitting from the async compiles (shaders are - // compiled individually, as needed) and this complexity can be removed. - { - // wait until fewer than N compiles are in progress - uint32_t N = std::max(1u, std::thread::hardware_concurrency()); - std::unique_lock<std::mutex> guard(compile_count_mutex); - while (compile_count >= N) { - compile_count_cond.wait(guard); - } - compile_count++; + + if (pipeline->compiled) { + continue; } - compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, - parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); + wait_pipeline = pipeline; + + if (!pipeline->compile_pending) { + pipeline->compile_pending = true; + claimed_task.pipeline = pipeline; + claimed_task.spv_size = spv_size; + claimed_task.spv_data = spv_data; + claimed_task.entrypoint = entrypoint; + claimed_task.parameter_count = parameter_count; + claimed_task.wg_denoms = wg_denoms; + claimed_task.specialization_constants = specialization_constants; + claimed_task.disable_robustness = disable_robustness; + claimed_task.require_full_subgroups = require_full_subgroups; + claimed_task.required_subgroup_size = required_subgroup_size; + has_claimed_task = true; + } } }; @@ -5332,8 +5367,25 @@ static void ggml_vk_load_shaders(vk_device& device) { } } - for (auto &c : compiles) { - c.wait(); + // Drop compile_mutex so other threads can walk while we compile. + compile_lock.unlock(); + + // Compile what we claimed; create_pipeline_func reacquires compile_mutex + // at the end to flip compile_pending/compiled and notify waiters. + if (has_claimed_task) { + auto & task = claimed_task; + ggml_vk_create_pipeline_func(device, task.pipeline, task.spv_size, task.spv_data, + task.entrypoint, task.parameter_count, task.wg_denoms, + task.specialization_constants, task.disable_robustness, + task.require_full_subgroups, task.required_subgroup_size); + } + + // Another thread may be compiling the pipeline we need; block on it here. + if (wait_pipeline) { + std::unique_lock<std::mutex> wait_lock(device->compile_mutex); + device->compile_cv.wait(wait_lock, [&] { + return wait_pipeline->compiled.load(); + }); } } @@ -9722,7 +9774,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline pipeline = nullptr; { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16; auto it = pipelines.find(fa_pipeline_state); if (it != pipelines.end()) { @@ -9786,13 +9838,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_pipeline pipeline_fa_mask_opt = nullptr; if (use_mask_opt) { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); - auto &pipelines = ctx->device->pipeline_fa_mask_opt; - auto it = pipelines.find({Br, Bc}); - if (it != pipelines.end()) { - pipeline_fa_mask_opt = it->second; - } else { - pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>(); + { + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); + auto &pipelines = ctx->device->pipeline_fa_mask_opt; + auto it = pipelines.find({Br, Bc}); + if (it != pipelines.end()) { + pipeline_fa_mask_opt = it->second; + } else { + pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>(); + } } assert(pipeline_fa_mask_opt); ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1); @@ -10326,7 +10380,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const vk_pipeline pipeline = nullptr; { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state); if (it != ctx->device->pipeline_solve_tri_f32.end()) { pipeline = it->second; @@ -10485,7 +10539,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const vk_pipeline pipeline = nullptr; { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); auto it = pipelines->find(conv2d_pipeline_state); if (it != pipelines->end()) { pipeline = it->second; From 050b8567a0fff75392c249d9283f8ee2dfa89292 Mon Sep 17 00:00:00 2001 From: Shrivas Shankar <86219405+shrivasshankar@users.noreply.github.com> Date: Mon, 1 Jun 2026 07:40:28 -0500 Subject: [PATCH 771/831] metal: template GLU kernels to support f16/f32 (llama/23882) Drops the hardcoded f32 GLU kernels in favor of a single template. We now load/store in the native tensor type (half or float) to save memory bandwidth, but keep the actual ALU compute in float to avoid exploding math in geglu/swiglu. Also opened up the dispatch gate to allow f16 inputs. --- ggml/src/ggml-metal/ggml-metal-device.m | 2 +- ggml/src/ggml-metal/ggml-metal.metal | 96 +++++++++++++++++-------- 2 files changed, 67 insertions(+), 31 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 885344ec670..196af102643 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1107,7 +1107,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: - return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_1(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4adf4614acb..2bd310d9450 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1421,7 +1421,8 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>; template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>; -kernel void kernel_reglu_f32( +template<typename T> +kernel void kernel_reglu( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1429,19 +1430,25 @@ kernel void kernel_reglu_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; const float x1 = src1_row[i0]; - dst_row[i0] = x0*x1*(x0 > 0.0f); + dst_row[i0] = (T)(x0*x1*(x0 > 0.0f)); } } -kernel void kernel_geglu_f32( +typedef decltype(kernel_reglu<float>) kernel_reglu_t; + +template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>; +template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>; + +template<typename T> +kernel void kernel_geglu( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1449,9 +1456,9 @@ kernel void kernel_geglu_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1459,11 +1466,17 @@ kernel void kernel_geglu_f32( const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0))); - dst_row[i0] = gelu*x1; + dst_row[i0] = (T)(gelu*x1); } } -kernel void kernel_swiglu_f32( +typedef decltype(kernel_geglu<float>) kernel_geglu_t; + +template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>; +template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>; + +template<typename T> +kernel void kernel_swiglu( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1471,9 +1484,9 @@ kernel void kernel_swiglu_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1481,11 +1494,17 @@ kernel void kernel_swiglu_f32( const float silu = x0 / (1.0f + exp(-x0)); - dst_row[i0] = silu*x1; + dst_row[i0] = (T)(silu*x1); } } -kernel void kernel_swiglu_oai_f32( +typedef decltype(kernel_swiglu<float>) kernel_swiglu_t; + +template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>; +template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>; + +template<typename T> +kernel void kernel_swiglu_oai( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1493,9 +1512,9 @@ kernel void kernel_swiglu_oai_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { float x0 = src0_row[i0]; @@ -1507,11 +1526,17 @@ kernel void kernel_swiglu_oai_f32( float out_glu = x0 / (1.0f + exp(-x0 * args.alpha)); out_glu = out_glu * (1.0f + x1); - dst_row[i0] = out_glu; + dst_row[i0] = (T)out_glu; } } -kernel void kernel_geglu_erf_f32( +typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t; + +template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>; +template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>; + +template<typename T> +kernel void kernel_geglu_erf( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1519,9 +1544,9 @@ kernel void kernel_geglu_erf_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1529,11 +1554,17 @@ kernel void kernel_geglu_erf_f32( const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV)); - dst_row[i0] = gelu_erf*x1; + dst_row[i0] = (T)(gelu_erf*x1); } } -kernel void kernel_geglu_quick_f32( +typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t; + +template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>; +template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>; + +template<typename T> +kernel void kernel_geglu_quick( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1541,9 +1572,9 @@ kernel void kernel_geglu_quick_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1551,10 +1582,15 @@ kernel void kernel_geglu_quick_f32( const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0))); - dst_row[i0] = gelu_quick*x1; + dst_row[i0] = (T)(gelu_quick*x1); } } +typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t; + +template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>; +template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>; + kernel void kernel_op_sum_f32( constant ggml_metal_kargs_sum & args, device const float * src0, From e728bae15950e1786b4c9574fa56a889e772f516 Mon Sep 17 00:00:00 2001 From: shaofeiqi <shaoqi@qti.qualcomm.com> Date: Mon, 1 Jun 2026 10:06:50 -0700 Subject: [PATCH 772/831] opencl: add basic support for q5_0 and q5_1 (llama/23548) * opencl: add general q5_0 support * opencl: add general q5_1 support * opencl: support non-uniform workgrp size --------- Co-authored-by: Li He <lih@qti.qualcomm.com> --- ggml/src/ggml-opencl/CMakeLists.txt | 6 + ggml/src/ggml-opencl/ggml-opencl.cpp | 422 +++++++++++++++++- ggml/src/ggml-opencl/kernels/cvt.cl | 100 +++++ .../kernels/mul_mm_q5_0_f32_l4_lm.cl | 173 +++++++ .../kernels/mul_mm_q5_1_f32_l4_lm.cl | 175 ++++++++ .../ggml-opencl/kernels/mul_mv_q5_0_f32.cl | 241 ++++++++++ .../kernels/mul_mv_q5_0_f32_flat.cl | 243 ++++++++++ .../ggml-opencl/kernels/mul_mv_q5_1_f32.cl | 243 ++++++++++ .../kernels/mul_mv_q5_1_f32_flat.cl | 247 ++++++++++ 9 files changed, 1845 insertions(+), 5 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index 446fb727996..cd15d573238 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -87,6 +87,10 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_1_f32_flat mul_mv_q4_k_f32 mul_mv_q4_k_f32_flat + mul_mv_q5_0_f32 + mul_mv_q5_0_f32_flat + mul_mv_q5_1_f32 + mul_mv_q5_1_f32_flat mul_mv_q5_k_f32 mul_mv_q5_k_f32_flat mul_mv_q6_k_f32 @@ -126,6 +130,8 @@ set(GGML_OPENCL_KERNELS mul_mm_f16_f32_l4_lm mul_mm_q4_0_f32_l4_lm mul_mm_q4_1_f32_l4_lm + mul_mm_q5_0_f32_l4_lm + mul_mm_q5_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm mul_mm_iq4_nl_f32_l4_lm mul_mm_q4_k_f32_l4_lm diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 3f3643a4cef..7cafbe0cdc3 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -576,7 +576,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_convert_block_q4_0_trans4_ns, kernel_restore_block_q4_0_trans4_ns; cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns; + cl_kernel kernel_convert_block_q5_0, kernel_restore_block_q5_0; cl_kernel kernel_convert_block_q5_0_trans4_ns, kernel_restore_block_q5_0_trans4_ns; + cl_kernel kernel_convert_block_q5_1, kernel_restore_block_q5_1; cl_kernel kernel_convert_block_q5_1_trans4_ns, kernel_restore_block_q5_1_trans4_ns; cl_kernel kernel_convert_block_q4_k_trans4_ns, kernel_restore_block_q4_k_trans4_ns; cl_kernel kernel_convert_block_q5_k_trans4_ns, kernel_restore_block_q5_k_trans4_ns; @@ -604,6 +606,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q4_1_f32; cl_kernel kernel_mul_mv_q4_1_f32_flat; + cl_kernel kernel_mul_mv_q5_0_f32; + cl_kernel kernel_mul_mv_q5_0_f32_flat; + cl_kernel kernel_mul_mv_q5_1_f32; + cl_kernel kernel_mul_mv_q5_1_f32_flat; cl_kernel kernel_mul_mv_q4_K_f32; cl_kernel kernel_mul_mv_q4_K_f32_flat; cl_kernel kernel_mul_mv_q5_K_f32; @@ -662,6 +668,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_f16_f32_l4_lm; cl_kernel kernel_mul_mm_q4_0_f32_l4_lm; cl_kernel kernel_mul_mm_q4_1_f32_l4_lm; + cl_kernel kernel_mul_mm_q5_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q5_1_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; cl_kernel kernel_mul_mm_q4_k_f32_l4_lm; cl_kernel kernel_mul_mm_q5_k_f32_l4_lm; @@ -1141,8 +1149,12 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_k_trans4_ns", &err), err)); @@ -1485,6 +1497,74 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } + // mul_mv_q5_0_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_0_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_0_f32 = clCreateKernel(prog, "kernel_mul_mv_q5_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q5_0_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_0_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_0_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_0_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q5_0_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q5_1_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_1_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_1_f32 = clCreateKernel(prog, "kernel_mul_mv_q5_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q5_1_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q5_1_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q5_1_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q5_1_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q5_1_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_q5_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1835,6 +1915,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } + // mul_mm_q5_0_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q5_0_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q5_0_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q5_0_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q5_0_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mm_q5_1_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q5_1_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q5_1_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q5_1_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q5_1_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + // mul_mm_q8_0_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -5027,6 +5139,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } else if (op->src[0]->type == GGML_TYPE_F32) { return op->src[1]->type == GGML_TYPE_F32; } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || + op->src[0]->type == GGML_TYPE_Q5_0 || op->src[0]->type == GGML_TYPE_Q5_1 || op->src[0]->type == GGML_TYPE_MXFP4 || op->src[0]->type == GGML_TYPE_IQ4_NL || op->src[0]->type == GGML_TYPE_Q4_K || @@ -5977,7 +6090,24 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } #endif // GGML_OPENCL_USE_ADRENO_KERNELS - return; + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64) * 64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + return; } if (tensor->type == GGML_TYPE_Q5_1) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; @@ -6078,6 +6208,24 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } #endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_1; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64) * 64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; return; } if (tensor->type == GGML_TYPE_MXFP4) { @@ -7135,8 +7283,29 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, return; } #endif // GGML_OPENCL_USE_ADRENO_KERNELS - // TODO: normal q5_0 - (void) extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); return; } if (tensor->type == GGML_TYPE_Q5_1) { @@ -7177,8 +7346,29 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, return; } #endif // GGML_OPENCL_USE_ADRENO_KERNELS - // TODO: normal q5_1 - (void) extra; + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_1; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); return; } if (tensor->type == GGML_TYPE_MXFP4) { @@ -12936,6 +13126,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; + ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; + ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; @@ -13271,6 +13463,93 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q5_0: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q5_0_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q5_1: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q5_1_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } case GGML_TYPE_Q8_0: { if (ne11 < 32) { break; @@ -13807,6 +14086,137 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #endif // GGML_OPENCL_SOA_Q break; } + case GGML_TYPE_Q5_0: { +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q5_0_f32_flat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r3)); +#else + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q5_0_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } + case GGML_TYPE_Q5_1: { +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q5_1_f32_flat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r3)); +#else + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q5_1_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } case GGML_TYPE_Q8_0: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_q8_0_f32_flat; @@ -14247,6 +14657,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 || src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_Q2_K) { diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 4f01887efb3..d07f0a1a025 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -537,6 +537,53 @@ kernel void kernel_restore_block_q4_1_trans4_ns( ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; } +//------------------------------------------------------------------------------ +// kernel_convert_block_q5_0 +// Convert the block_q5_0 format to 3 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q5_0( + global struct block_q5_0 * src0, + global uchar * dst_qs, + global uint * dst_qh, + global half * dst_d, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + + global struct block_q5_0 * b = (global struct block_q5_0 *) src0 + get_global_id(0); + global uchar * qs = (global uchar *) dst_qs + (QK5_0/2)*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_0/2; ++i) { + qs[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q5_0( + global uchar * src_qs, + global uint * src_qh, + global half * src_d, + global struct block_q5_0 * dst +) { + global struct block_q5_0 * b = (global struct block_q5_0 *) dst + get_global_id(0); + global uchar * qs = (global uchar *) src_qs + (QK5_0/2)*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + *((global uint *)(b->qh)) = *qh; + for (int i = 0; i < QK5_0/2; ++i) { + b->qs[i] = qs[i]; + } +} + kernel void kernel_convert_block_q5_0_trans4_ns( __global struct block_q5_0 * src0, __global uint * dst_qs, @@ -636,6 +683,59 @@ kernel void kernel_restore_block_q5_0_trans4_ns( ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; } +//------------------------------------------------------------------------------ +// kernel_convert_block_q5_1 +// Convert the block_q5_1 format to 4 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q5_1( + global struct block_q5_1 * src0, + global uchar * dst_qs, + global uint * dst_qh, + global half * dst_d, + global half * dst_m, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + + global struct block_q5_1 * b = (global struct block_q5_1 *) src0 + get_global_id(0); + global uchar * qs = (global uchar *) dst_qs + (QK5_1/2)*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_1/2; ++i) { + qs[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q5_1( + global uchar * src_qs, + global uint * src_qh, + global half * src_d, + global half * src_m, + global struct block_q5_1 * dst +) { + global struct block_q5_1 * b = (global struct block_q5_1 *) dst + get_global_id(0); + global uchar * qs = (global uchar *) src_qs + (QK5_1/2)*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + *((global uint *)(b->qh)) = *qh; + for (int i = 0; i < QK5_1/2; ++i) { + b->qs[i] = qs[i]; + } +} + kernel void kernel_convert_block_q5_1_trans4_ns( __global struct block_q5_1 * src0, __global uint * dst_qs, diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl new file mode 100644 index 00000000000..1e980a478a8 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl @@ -0,0 +1,173 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q5_0_f32_l4_lm( + global uchar4 * src0_qs, + global uint * src0_qh, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + uint qh_val = src0_qh[ib]; + + global uchar4 * qs_ptr = src0_qs + ib*4 + iqs; + uchar4 q = *qs_ptr; + + uint qh_lo = qh_val >> (iqs * 4); + uint qh_hi = qh_val >> (iqs * 4 + 16); + + uchar4 b_lo = (uchar4)((uchar)qh_lo, (uchar)(qh_lo >> 1), (uchar)(qh_lo >> 2), (uchar)(qh_lo >> 3)) & (uchar)1; + uchar4 b_hi = (uchar4)((uchar)qh_hi, (uchar)(qh_hi >> 1), (uchar)(qh_hi >> 2), (uchar)(qh_hi >> 3)) & (uchar)1; + + float4 v1 = (convert_float4((q & (uchar)0x0F) | (b_lo << (uchar)4)) - 16.0f) * d; + float4 v2 = (convert_float4((q >> (uchar)4) | (b_hi << (uchar)4)) - 16.0f) * d; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl new file mode 100644 index 00000000000..ba06be54697 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl @@ -0,0 +1,175 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q5_1_f32_l4_lm( + global uchar4 * src0_qs, + global uint * src0_qh, + global half * src0_d, + global half * src0_m, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + float m = (float)src0_m[ib]; + uint qh_val = src0_qh[ib]; + + global uchar4 * qs = src0_qs + ib*4 + iqs; + uchar4 q = *qs; + + uint qh_lo = qh_val >> (iqs * 4); + uint qh_hi = qh_val >> (iqs * 4 + 16); + + uchar4 b_lo = (uchar4)((uchar)qh_lo, (uchar)(qh_lo >> 1), (uchar)(qh_lo >> 2), (uchar)(qh_lo >> 3)) & (uchar)1; + uchar4 b_hi = (uchar4)((uchar)qh_hi, (uchar)(qh_hi >> 1), (uchar)(qh_hi >> 2), (uchar)(qh_hi >> 3)) & (uchar)1; + + float4 v1 = convert_float4((q & (uchar)0x0F) | (b_lo << (uchar)4)) * d + m; + float4 v2 = convert_float4((q >> (uchar)4) | (b_hi << (uchar)4)) * d + m; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl new file mode 100644 index 00000000000..6d8c9e8f037 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl @@ -0,0 +1,241 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_0 32 + +struct block_q5_0 { + half d; + uchar qh[4]; + uchar qs[QK5_0 / 2]; +}; + +inline float block_q5_0_dot_y( + global const struct block_q5_0 * qb_curr, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = qb_curr->d; + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + global const ushort * qs = ((global const ushort *)((global const uchar *) qb_curr + 6 + il)); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *((global const uint *)((global const uchar *) qb_curr + 2)); + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum - 16.0f * sumy); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q5_0 * x = (global struct block_q5_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_0_dot_y(x+ib+0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_0_dot_y(x+ib+1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_0_dot_y(x+ib+2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_0_dot_y(x+ib+3*nb, sumy, yl, il, yb); + + yb += QK5_0 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_0_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl new file mode 100644 index 00000000000..34ec133d398 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl @@ -0,0 +1,243 @@ + +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_0 32 + +inline float block_q5_0_dot_y_flat( + global const uchar * x, + global const uint * qh_ptr, + global const half * dh, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = *dh; + global const ushort * qs = ((global const ushort *)(x + il)); + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *qh_ptr; + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum - 16.0f * sumy); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + ulong offset0_qs = offset0 * (QK5_0/2); + + global uchar * x = (global uchar *) src0_qs + offset0_qs; + global uint * qh = (global uint *) src0_qh + offset0; + global half * d = (global half *) src0_d + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 0*nb*(QK5_0/2), qh + ib + 0*nb, d + ib + 0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 1*nb*(QK5_0/2), qh + ib + 1*nb, d + ib + 1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 2*nb*(QK5_0/2), qh + ib + 2*nb, d + ib + 2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 3*nb*(QK5_0/2), qh + ib + 3*nb, d + ib + 3*nb, sumy, yl, il, yb); + + yb += QK5_0 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_0_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_qs, src0_qh, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl new file mode 100644 index 00000000000..1480f675038 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl @@ -0,0 +1,243 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_1 32 + +struct block_q5_1 { + half d; + half m; + uchar qh[4]; + uchar qs[QK5_1 / 2]; +}; + +inline float block_q5_1_dot_y( + global const struct block_q5_1 * qb_curr, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = qb_curr->d; + float m = qb_curr->m; + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + global const ushort * qs = ((global const ushort *)((global const uchar *) qb_curr + 8 + il)); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *((global const uint *)((global const uchar *) qb_curr + 4)); + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q5_1 * x = (global struct block_q5_1 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_1_dot_y(x+ib+0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_1_dot_y(x+ib+1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_1_dot_y(x+ib+2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_1_dot_y(x+ib+3*nb, sumy, yl, il, yb); + + yb += QK5_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_1_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl new file mode 100644 index 00000000000..57c2f140958 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl @@ -0,0 +1,247 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_1 32 + +inline float block_q5_1_dot_y_flat( + global const uchar * x, + global const uint * qh_ptr, + global const half * dh, + global const half * mh, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = *dh; + float m = *mh; + global const ushort * qs = ((global const ushort *)(x + il)); + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *qh_ptr; + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global void * src0_m, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + ulong offset0_qs = offset0 * (QK5_1/2); + + global uchar * x = (global uchar *) src0_qs + offset0_qs; + global uint * qh = (global uint *) src0_qh + offset0; + global half * d = (global half *) src0_d + offset0; + global half * ms = (global half *) src0_m + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 0*nb*(QK5_1/2), qh + ib + 0*nb, d + ib + 0*nb, ms + ib + 0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 1*nb*(QK5_1/2), qh + ib + 1*nb, d + ib + 1*nb, ms + ib + 1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 2*nb*(QK5_1/2), qh + ib + 2*nb, d + ib + 2*nb, ms + ib + 2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 3*nb*(QK5_1/2), qh + ib + 3*nb, d + ib + 3*nb, ms + ib + 3*nb, sumy, yl, il, yb); + + yb += QK5_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_1_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global void * src0_m, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_qs, src0_qh, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} From db2a39507ccb33f5e9b2aedeefa9d149990d1f0a Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura <yoshimura.masashi.frbs@gmail.com> Date: Tue, 2 Jun 2026 08:59:06 +0900 Subject: [PATCH 773/831] revert to using global_invocation_id for cpy shader (llama/23955) --- ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl index e268adfb16b..67f1dc0928f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl @@ -50,13 +50,13 @@ var<uniform> params: Params; @compute @workgroup_size(WG_SIZE) fn main( - @builtin(global_invocation_index) gindex: u32, + @builtin(global_invocation_id) gid: vec3<u32>, ) { - if (gindex >= params.ne) { + if (gid.x >= params.ne) { return; } - var i = gindex; + var i = gid.x; let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); let i2 = i / (params.src_ne1 * params.src_ne0); @@ -64,7 +64,7 @@ fn main( let i1 = i / params.src_ne0; let i0 = i % params.src_ne0; - var j = gindex; + var j = gid.x; let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); let j2 = j / (params.dst_ne1 * params.dst_ne0); @@ -80,4 +80,3 @@ fn main( dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx])); } - From 9a0265d13b890fffa18315f4da5de51147dc8ccd Mon Sep 17 00:00:00 2001 From: lhez <lih@qti.qualcomm.com> Date: Mon, 1 Jun 2026 19:15:09 -0700 Subject: [PATCH 774/831] opencl: fix compiler warnings for non-adreno path (llama/23922) * opencl: fix compiler warnings for non-adreno path * opencl: fix const cast warning --- ggml/src/ggml-opencl/ggml-opencl.cpp | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 7cafbe0cdc3..b67ea46bce8 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -380,7 +380,7 @@ struct ggml_backend_opencl_device_context { ADRENO_GPU_GEN adreno_gen = ADRENO_GPU_GEN::ADRENO_UNKNOWN; std::regex *opfilter = nullptr; // regex of ops to not claim - std::string opfilter_str; // regex string for opfilter + std::string opfilter_str = ""; // regex string for opfilter size_t global_mem_size = 0; }; @@ -6822,9 +6822,6 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_buffer_region region; - cl_uchar mask_0F = 0x0F; - cl_uchar mask_F0 = 0xF0; - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS // Adreno MoE Q6_K kernel needs special transposed layout if (use_adreno_moe_kernels(backend_ctx, tensor)) { @@ -6858,6 +6855,9 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_kernel kernel = backend_ctx->kernel_convert_block_q6_k_trans4_ns; + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; int ne02 = tensor->ne[2]; @@ -6994,7 +6994,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, - size, (void *) data, &err); + size, const_cast<void *>(data), &err); CL_CHECK(err); cl_kernel kernel = backend_ctx->kernel_convert_bf16_to_f16; @@ -7782,9 +7782,6 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, if (tensor->type == GGML_TYPE_Q6_K) { ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; - cl_uchar mask_0F = 0x0F; - cl_uchar mask_F0 = 0xF0; - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, tensor)) { cl_int err; @@ -7794,6 +7791,9 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, cl_kernel kernel = backend_ctx->kernel_restore_block_q6_k_trans4_ns; + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; int ne02 = tensor->ne[2]; @@ -14888,6 +14888,8 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const int ne1 = dst->ne[1]; const int ne2 = dst->ne[2]; + GGML_UNUSED(ne2); + const int r2 = ne12/ne02; const int r3 = ne13/ne03; const int dst_rows = ne20*ne21; // ne20 = n_used_experts, ne21 = n_rows @@ -14902,6 +14904,8 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const int n_tile_size = 32; const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02; + GGML_UNUSED(max_post_router_tile); + cl_kernel kernel; // subgroup mat vec From 79223704a1ec33e4147ab6f0d10934667cea7200 Mon Sep 17 00:00:00 2001 From: Anav Prasad <anavp@nvidia.com> Date: Mon, 1 Jun 2026 19:38:37 -0700 Subject: [PATCH 775/831] clean up unused variables warnings (llama/23975) --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 6 +++--- ggml/src/ggml-cuda/gated_delta_net.cu | 10 +++++----- ggml/src/ggml-cuda/mmf.cuh | 6 +++--- ggml/src/ggml-cuda/mmvf.cu | 13 ++++++------- ggml/src/ggml-cuda/mmvq.cu | 13 ++++--------- ggml/src/ggml-cuda/topk-moe.cu | 2 +- 6 files changed, 22 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 3c8b6eaaf24..ac5abb13367 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -568,7 +568,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); - constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4; @@ -604,9 +603,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) { const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; - const int k0_diff = k0_stop - k0_start; if constexpr (nstages <= 1) { + const int k0_diff = k0_stop - k0_start; constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check> (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup); @@ -640,6 +639,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } } else { + constexpr int stride_tile_Q = DKQ/2 + 4; #pragma unroll for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); @@ -954,9 +954,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) { static_assert(DV % (2*nbatch_V2) == 0, "bad loop size"); const int i0_stop = i0_start + 2*nbatch_V2; - const int i0_diff = i0_stop - i0_start; if constexpr (nstages <= 1) { + const int i0_diff = i0_stop - i0_start; if (!V_is_K_view || i0_stop > 2*nbatch_K2) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check> diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 018d5d37d47..7cfda652367 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -43,7 +43,6 @@ gated_delta_net_cuda(const float * q, // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; - const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output state += state_out_offset; curr_state += state_in_offset + col * S_v; attn_data += (sequence * n_tokens * H + h_idx) * S_v; @@ -61,10 +60,6 @@ gated_delta_net_cuda(const float * q, s_shard[r] = curr_state[i]; } - // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots - // are written; earlier slots are left untouched (caller-owned). - const int shift = (int) n_tokens - K; - for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -148,6 +143,11 @@ gated_delta_net_cuda(const float * q, attn_data += S_v * H; if constexpr (keep_rs_t) { + // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots + // are written; earlier slots are left untouched (caller-owned). + const int shift = (int) n_tokens - K; + + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output const int target_slot = t - shift; if (target_slot >= 0 && target_slot < K) { float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index c2a8d54c95a..d55cc1ec7b5 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -91,7 +91,7 @@ static __global__ void mul_mat_f( const int row0 = blockIdx.x * rows_per_block; int expert_idx = 0; - int col_base = 0; + [[maybe_unused]] int col_base = 0; const int channel_dst = has_ids ? 0 : blockIdx.y; @@ -122,12 +122,12 @@ static __global__ void mul_mat_f( ids += col_offset * stride_row_id; } - const float2 * y2 = (const float2 *) y; + [[maybe_unused]] const float2 * y2 = (const float2 *) y; extern __shared__ char data_mmv[]; char * shmem_base = data_mmv; - int * slot_map = (int *) shmem_base; + [[maybe_unused]] int * slot_map = (int *) shmem_base; char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base; tile_C C[ntA][ntB]; diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 09d95f309b4..3d6de64b775 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -80,9 +80,8 @@ static __global__ void mul_mat_vec_f( gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; } - const int channel_bias = ids ? channel_x : channel_dst; - if constexpr (has_fusion) { + const int channel_bias = ids ? channel_x : channel_dst; if (use_bias) { x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; } @@ -95,7 +94,7 @@ static __global__ void mul_mat_vec_f( extern __shared__ char data_mmv[]; float * buf_iw = (float *) data_mmv; - float * buf_iw_gate = nullptr; + [[maybe_unused]] float * buf_iw_gate = nullptr; if constexpr (has_fusion) { buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float)); } @@ -123,7 +122,7 @@ static __global__ void mul_mat_vec_f( if constexpr (std::is_same_v<T, float>) { const float2 * x2 = (const float2 *) x; - const float2 * gate_x2 = nullptr; + [[maybe_unused]] const float2 * gate_x2 = nullptr; if constexpr (has_fusion) { if (use_gate) { gate_x2 = (const float2 *) gate_x; @@ -155,7 +154,7 @@ static __global__ void mul_mat_vec_f( } } else if constexpr (std::is_same_v<T, half>) { const half2 * x2 = (const half2 *) x; - const half2 * gate_x2 = nullptr; + [[maybe_unused]] const half2 * gate_x2 = nullptr; if constexpr (has_fusion) { if (use_gate) { gate_x2 = (const half2 *) gate_x; @@ -266,7 +265,7 @@ static __global__ void mul_mat_vec_f( } #else const nv_bfloat162 * x2 = (const nv_bfloat162 *) x; - const nv_bfloat162 * gate_x2 = nullptr; + [[maybe_unused]] const nv_bfloat162 * gate_x2 = nullptr; if constexpr (has_fusion) { if (use_gate) { gate_x2 = (const nv_bfloat162 *) gate_x; @@ -274,7 +273,7 @@ static __global__ void mul_mat_vec_f( } for (int col2 = tid; col2 < ncols2; col2 += block_size) { const nv_bfloat162 tmpx = x2[col2]; - nv_bfloat162 tmpx_gate; + [[maybe_unused]] nv_bfloat162 tmpx_gate; if constexpr (has_fusion) { if (use_gate) { tmpx_gate = gate_x2[col2]; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index ecb6fdedadd..86b4a493019 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -515,7 +515,7 @@ static __global__ void mul_mat_vec_q( bool use_gate = false; bool use_bias = false; bool use_gate_bias = false; - const void * vgate = nullptr; + [[maybe_unused]] const void * vgate = nullptr; const float * x_bias = nullptr; const float * gate_bias = nullptr; ggml_glu_op active_glu; @@ -531,8 +531,8 @@ static __global__ void mul_mat_vec_q( } - float x_biases[ncols_dst] = { 0.0f }; - float gate_biases[ncols_dst] = { 0.0f }; + [[maybe_unused]] float x_biases[ncols_dst] = { 0.0f }; + [[maybe_unused]] float gate_biases[ncols_dst] = { 0.0f }; if constexpr (has_fusion) { const uint32_t channel_bias = ids ? channel_x : channel_dst; if (use_bias) { @@ -589,12 +589,7 @@ static __global__ void mul_mat_vec_q( } __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; - __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; - if constexpr (!has_fusion) { - (void) tmp_shared_gate; - } else if (!use_gate) { - (void) tmp_shared_gate; - } + [[maybe_unused]] __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; if (threadIdx.y > 0) { #pragma unroll diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index da20c9aab7c..c4253bfa43b 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -134,7 +134,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * // selection_wt is only needed when bias is present (selection uses wt + bias) // when no bias, we use wt directly for both selection and weight values - float selection_wt[has_bias ? experts_per_thread : 1]; + [[maybe_unused]] float selection_wt[has_bias ? experts_per_thread : 1]; if constexpr (has_bias) { #pragma unroll From 754247f28b7615704a408ccd4c6331ab26c9d402 Mon Sep 17 00:00:00 2001 From: Todor Boinovski <todorb@qti.qualcomm.com> Date: Mon, 1 Jun 2026 23:19:07 -0700 Subject: [PATCH 776/831] hexagon: add gelu_quick (llama/24007) --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 48ded82e83c..920829f6a93 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -3142,13 +3142,14 @@ static htp_op_code op_remap_to_htp(const ggml_tensor * t) { case GGML_OP_UNARY: switch (ggml_get_unary_op(t)) { - case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; - case GGML_UNARY_OP_GELU: return HTP_OP_UNARY_GELU; - case GGML_UNARY_OP_SIGMOID: return HTP_OP_UNARY_SIGMOID; - case GGML_UNARY_OP_NEG: return HTP_OP_UNARY_NEG; - case GGML_UNARY_OP_EXP: return HTP_OP_UNARY_EXP; - case GGML_UNARY_OP_SOFTPLUS: return HTP_OP_UNARY_SOFTPLUS; - case GGML_UNARY_OP_TANH: return HTP_OP_UNARY_TANH; + case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; + case GGML_UNARY_OP_GELU: return HTP_OP_UNARY_GELU; + case GGML_UNARY_OP_GELU_QUICK: return HTP_OP_UNARY_GELU; + case GGML_UNARY_OP_SIGMOID: return HTP_OP_UNARY_SIGMOID; + case GGML_UNARY_OP_NEG: return HTP_OP_UNARY_NEG; + case GGML_UNARY_OP_EXP: return HTP_OP_UNARY_EXP; + case GGML_UNARY_OP_SOFTPLUS: return HTP_OP_UNARY_SOFTPLUS; + case GGML_UNARY_OP_TANH: return HTP_OP_UNARY_TANH; default: break; } @@ -3630,6 +3631,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: supp = ggml_hexagon_supported_activations(sess, op); break; default: From 8d61a9edf0b3b429671f1a7037d67243d941b517 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky <maxk@qti.qualcomm.com> Date: Mon, 1 Jun 2026 23:40:08 -0700 Subject: [PATCH 777/831] hexagon: MUL_MAT, MUL_MAT_ID, FLASH_ATTN and GDN cleanup and optimizations for latest models (llama/23989) * hex-mm: initial support for F32 * F32 -> F32 matmuls * hex-rms-norm: fix src1 stride use in fused rms_norm_mul * hex-ops: clear spad pointers in the ops that clober it This fixes an odd case where fused rms-norm-mul was failing but only in qwen3.5-2B and only at searth op-bath sizes. * hmx-mm: add support for F32 * F32 -> F32 matmul_2d on HMX Decided to use Q4_0 * F32 -> F32 matmul for this. Q4_0 gets dequantized and tiled into F16, and here we quantize and tile F32 into F16. Super simple and pretty efficient. * hmx-mm: route f16 2D matmuls through the same kernel used for all other types * hmx-mm: re-introduce pipelined vs non-pipelined mode that we used to have but is much more generic way This update futher improves matmul performance and at the same time removes most of the redudant logic we had in different paths. * hmx-fa: slighlty improved pipeline simimar to matmul updates * hmx-mm: initial version of MAT_MUL_ID support for HMX * hmx-mm: fixed mxfp4 handling for MUL_MAT_ID * hex-gdn: optimize GATED_DELTA_NET DMA prefetch/double-buff, vectorize everything with HVX, in other words -- the usual :) * hmx-mm: missed one more case where we can use fastmod * hexagon: update DCVS settings for a slight perf bump * hmx-fa: use fastdiv in hmx-flash-attn * hmx-fa: precompute slope values to avoid disrupting the inner loop * hvx-utils/fa: new HVX helpers for powf and logf and using those to speed up FA alibi * hex-ops: fixed a bug in fusion logic that was messing up the order of the src tensors when some srcs are empty * hex-fa: correctly fallback to HVX if we have sinks or the dims are not quite right --- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 22 + ggml/src/ggml-hexagon/htp-opnode.h | 46 +- ggml/src/ggml-hexagon/htp/CMakeLists.txt | 47 +- ggml/src/ggml-hexagon/htp/argsort-ops.c | 1 + ggml/src/ggml-hexagon/htp/concat-ops.c | 2 + ggml/src/ggml-hexagon/htp/flash-attn-ops.c | 15 +- .../ggml-hexagon/htp/gated-delta-net-ops.c | 660 ++++++++----- .../src/ggml-hexagon/htp/hmx-flash-attn-ops.c | 116 ++- ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c | 885 +++++++++++++----- ggml/src/ggml-hexagon/htp/hmx-ops.c | 6 + ggml/src/ggml-hexagon/htp/hmx-ops.h | 22 +- ggml/src/ggml-hexagon/htp/htp-ctx.h | 4 + ggml/src/ggml-hexagon/htp/hvx-flash-attn.h | 47 + ggml/src/ggml-hexagon/htp/hvx-log.h | 65 ++ ggml/src/ggml-hexagon/htp/hvx-pow.h | 42 + ggml/src/ggml-hexagon/htp/hvx-utils.h | 2 + ggml/src/ggml-hexagon/htp/main.c | 26 +- ggml/src/ggml-hexagon/htp/matmul-ops.c | 390 +++++++- ggml/src/ggml-hexagon/htp/pad-ops.c | 2 + ggml/src/ggml-hexagon/htp/unary-ops.c | 17 +- 20 files changed, 1825 insertions(+), 592 deletions(-) create mode 100644 ggml/src/ggml-hexagon/htp/hmx-ops.c create mode 100644 ggml/src/ggml-hexagon/htp/hvx-flash-attn.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-log.h create mode 100644 ggml/src/ggml-hexagon/htp/hvx-pow.h diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 920829f6a93..d550841a2a5 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1927,6 +1927,7 @@ struct ggml_hexagon_opbatch { size_t extra_tens = 0; auto fit_tensor = [&](const ggml_tensor *t) { + if (!t) return; if (!t_map.count(t)) { extra_tens++; @@ -2602,6 +2603,27 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n"); return false; } + if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); + return false; + } + if (ggml_nrows(src1) > 1024) { + return false; // no huge batches (for now) + } + break; + + case GGML_TYPE_F32: + if (src1->type != GGML_TYPE_F32) { + return false; + } + if (src0->nb[1] < src0->nb[0]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F32 src0 not supported\n"); + return false; + } + if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); + return false; + } if (ggml_nrows(src1) > 1024) { return false; // no huge batches (for now) } diff --git a/ggml/src/ggml-hexagon/htp-opnode.h b/ggml/src/ggml-hexagon/htp-opnode.h index 14b232240b4..8a1228ccdc0 100644 --- a/ggml/src/ggml-hexagon/htp-opnode.h +++ b/ggml/src/ggml-hexagon/htp-opnode.h @@ -56,7 +56,7 @@ struct htp_opnode { } std::vector<const ggml_tensor *> get_inputs() const { - std::vector<const ggml_tensor *> inputs; + std::vector<const ggml_tensor *> inputs(GGML_MAX_SRC, nullptr); std::vector<const ggml_tensor *> outputs; outputs.push_back(node); for (const auto * f : fused) { @@ -70,20 +70,38 @@ struct htp_opnode { return false; }; + int count = 0; auto add_input = [&](const ggml_tensor * t) { if (t && !contains(outputs, t) && !contains(inputs, t)) { - inputs.push_back(t); + if (count < (int)inputs.size()) { + inputs[count++] = t; + } else { + inputs.push_back(t); + } } }; - for (int i = 0; i < GGML_MAX_SRC && node->src[i]; i++) { - add_input(node->src[i]); + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (fused.empty()) { + inputs[i] = node->src[i]; + } else { + if (node->src[i]) { + add_input(node->src[i]); + } + } } for (const auto * f : fused) { - for (int i = 0; i < GGML_MAX_SRC && f->src[i]; i++) { - add_input(f->src[i]); + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (f->src[i]) { + add_input(f->src[i]); + } } } + + if (!fused.empty()) { + inputs.resize(count); + } + return inputs; } @@ -108,6 +126,9 @@ struct htp_opformat { char names[64 * GGML_MAX_SRC]; int format_tensor_dims(char * str, const struct ggml_tensor * t) { + if (!t) { + return sprintf(str, "NONE"); + } if (t->ne[2] == 1 && t->ne[3] == 1) { return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); } else { @@ -136,6 +157,9 @@ struct htp_opformat { } int format_tensor_strides(char * str, const struct ggml_tensor * t) { + if (!t) { + return sprintf(str, "NONE"); + } const char * c = ggml_is_contiguous(t) ? "" : "!"; if (t->ne[2] == 1 && t->ne[3] == 1) { @@ -170,11 +194,11 @@ struct htp_opformat { auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", ggml_type_name(inputs[0]->type)); + p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"); for (size_t i = 1; i < inputs.size(); i++) { p += sprintf(p, " x "); - p += sprintf(p, "%s", ggml_type_name(inputs[i]->type)); + p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"); } p += sprintf(p, " -> "); @@ -184,7 +208,7 @@ struct htp_opformat { } const char * tensor_buff_name(const struct ggml_tensor * t) { - if (t->buffer) { + if (t && t->buffer) { return ggml_backend_buffer_name(t->buffer); } return "NONE"; @@ -213,11 +237,11 @@ struct htp_opformat { auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", inputs[0]->name); + p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE"); for (size_t i = 1; i < inputs.size(); i++) { p += sprintf(p, " x "); - p += sprintf(p, "%s", inputs[i]->name); + p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE"); } p += sprintf(p, " -> "); diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index ff3fc0804e3..f4b44fe1a65 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -19,27 +19,6 @@ add_library(${HTP_LIB} SHARED htp_iface_skel.c worker-pool.c hex-dma.c - matmul-ops.c - binary-ops.c - unary-ops.c - sum-rows-ops.c - softmax-ops.c - act-ops.c - rope-ops.c - flash-attn-ops.c - set-rows-ops.c - get-rows-ops.c - cpy-ops.c - repeat-ops.c - argsort-ops.c - ssm-conv.c - cumsum-ops.c - fill-ops.c - concat-ops.c - diag-ops.c - solve-tri-ops.c - gated-delta-net-ops.c - pad-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE @@ -58,8 +37,8 @@ list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) if (_hmx_idx GREATER_EQUAL 0) target_sources(${HTP_LIB} PRIVATE - hmx-flash-attn-ops.c hmx-matmul-ops.c + hmx-flash-attn-ops.c hmx-queue.c ) @@ -76,6 +55,30 @@ endif() build_idl(htp_iface.idl ${HTP_LIB}) +target_sources(${HTP_LIB} PRIVATE + matmul-ops.c + binary-ops.c + unary-ops.c + sum-rows-ops.c + softmax-ops.c + act-ops.c + rope-ops.c + flash-attn-ops.c + set-rows-ops.c + get-rows-ops.c + cpy-ops.c + repeat-ops.c + argsort-ops.c + ssm-conv.c + cumsum-ops.c + fill-ops.c + concat-ops.c + diag-ops.c + solve-tri-ops.c + gated-delta-net-ops.c + pad-ops.c +) + set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) install(TARGETS ${HTP_LIB}) diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index bdd0623615d..73af38a35ab 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -276,6 +276,7 @@ int op_argsort(struct htp_ops_context * octx) { octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.size = total_spad_size; octx->src0_spad.size_per_thread = spad_per_thread; + octx->src0_spad.src = NULL; FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", octx->src[0]->ne[0], octx->src[0]->ne[1], octx->src[0]->ne[2], octx->src[0]->ne[3], diff --git a/ggml/src/ggml-hexagon/htp/concat-ops.c b/ggml/src/ggml-hexagon/htp/concat-ops.c index 61580f2c08f..f2a381313c5 100644 --- a/ggml/src/ggml-hexagon/htp/concat-ops.c +++ b/ggml/src/ggml-hexagon/htp/concat-ops.c @@ -262,6 +262,8 @@ int op_concat(struct htp_ops_context * octx) { octx->src0_spad.data = octx->ctx->vtcm_base; octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; if (type_size == 4) { worker_func = concat_2d_f32_transposed; diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 1bd8c1407de..e996214691a 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -11,6 +11,7 @@ #include "hex-dma.h" #include "hvx-utils.h" #include "hvx-dump.h" +#include "hvx-flash-attn.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -245,6 +246,7 @@ struct htp_fa_context { uint32_t n_head_log2; float m0; float m1; + float slopes[512]; uint32_t n_blocks; @@ -412,7 +414,7 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * } const uint32_t h = iq2; // head index - const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f; + const float slope = factx->slopes[h]; HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); @@ -628,8 +630,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { } #ifdef HTP_HAS_HMX - // HMX path: head_dim multiple of 32, F16 KV - if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 32 == 0) { + // HMX path: head_dim multiple of 64, F16 KV, and no sinks + if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) { int ret = hmx_flash_attn_ext(octx); if (ret == HTP_STATUS_OK) { return ret; @@ -689,6 +691,13 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2); factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); + if (n_head > 512) { + return HTP_STATUS_NO_SUPPORT; + } + for (uint32_t h = 0; h < n_head; ++h) { + factx.slopes[h] = (max_bias > 0.0f) ? alibi_slope(h, factx.n_head_log2, factx.m0, factx.m1) : 1.0f; + } + // total rows in q const uint32_t neq0 = q->ne[0]; const uint32_t neq1 = q->ne[1]; diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c index c4d08bb21c4..3b092d7440d 100644 --- a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c @@ -3,6 +3,7 @@ #include <string.h> #include "hvx-utils.h" +#include "hex-fastdiv.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -14,106 +15,103 @@ #define HTP_GDN_MAX_SV 128 + struct htp_gdn_context { struct htp_ops_context * octx; uint32_t rows_per_thread; - size_t state_bytes; - bool use_vtcm; - uint8_t * vtcm_state_base; - size_t vtcm_state_per_thread; + size_t state_bytes; + uint8_t * vtcm_base; + size_t vtcm_per_thread; }; -static inline float gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, - const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); - HVX_Vector vm = hvx_vmem(mul + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); - HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } -static inline float gdn_mul_scalar_dot_f32(float * restrict dst, float mul, - const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_mul_scalar_dot_f32(float * restrict dst, float mul, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); const HVX_Vector vmul = hvx_vec_splat_f32(mul); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vd = hvx_vmemu(dst + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } -static inline float gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src, - float scale, const float * restrict dot, uint32_t n) { +static inline HVX_Vector gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src, + HVX_Vector vscale, const float * restrict dot, uint32_t n) { HVX_Vector acc = Q6_V_vzero(); - const HVX_Vector vscale = hvx_vec_splat_f32(scale); - const uint32_t epv = 128 / sizeof(float); + const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { - HVX_Vector vd = hvx_vmemu(dst + i * epv); - HVX_Vector vs = hvx_vmem(src + i * epv); + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); - HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); hvx_vmemu(dst + i * epv) = out; acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vd = hvx_vmemu(dst + off); - HVX_Vector vs = hvx_vmem(src + off); + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); - hvx_vec_store_u(dst + off, tail * sizeof(float), out); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); } - return hvx_vec_get_f32(hvx_vec_reduce_sum_f32(acc)); + return hvx_vec_reduce_sum_f32(acc); } static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1, @@ -126,7 +124,7 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -147,11 +145,11 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; - HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); @@ -159,10 +157,10 @@ static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1 HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm); HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -185,7 +183,7 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -205,10 +203,10 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); @@ -216,10 +214,10 @@ static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restri HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul); HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -246,7 +244,7 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -267,11 +265,11 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); @@ -279,10 +277,10 @@ static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restri HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2)); HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3)); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -310,7 +308,7 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vm = hvx_vmem(mul + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -343,11 +341,11 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vm = hvx_vmem(mul + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); @@ -359,14 +357,14 @@ static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1 HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vm); HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vm); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -400,7 +398,7 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -432,10 +430,10 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); @@ -447,14 +445,14 @@ static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restri HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vmul); HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vmul); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -496,7 +494,7 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri const uint32_t epv = 128 / sizeof(float); const uint32_t nvec = n / epv; - const uint32_t tail = n % epv; + const uint32_t nloe = n % epv; for (uint32_t i = 0; i < nvec; ++i) { HVX_Vector vs = hvx_vmem(src + i * epv); HVX_Vector vdot = hvx_vmem(dot + i * epv); @@ -529,11 +527,11 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); } - if (tail) { + if (nloe) { const uint32_t off = nvec * epv; HVX_Vector vs = hvx_vmem(src + off); HVX_Vector vdot = hvx_vmem(dot + off); - HVX_VectorPred mask = Q6_Q_vsetq2_R(tail * sizeof(float)); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); HVX_Vector zero = Q6_V_vzero(); HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); @@ -545,14 +543,14 @@ static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restri HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + off), hvx_vec_mul_f32_f32(vs, scale6)); HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + off), hvx_vec_mul_f32_f32(vs, scale7)); - hvx_vec_store_u(dst0 + off, tail * sizeof(float), out0); - hvx_vec_store_u(dst1 + off, tail * sizeof(float), out1); - hvx_vec_store_u(dst2 + off, tail * sizeof(float), out2); - hvx_vec_store_u(dst3 + off, tail * sizeof(float), out3); - hvx_vec_store_u(dst4 + off, tail * sizeof(float), out4); - hvx_vec_store_u(dst5 + off, tail * sizeof(float), out5); - hvx_vec_store_u(dst6 + off, tail * sizeof(float), out6); - hvx_vec_store_u(dst7 + off, tail * sizeof(float), out7); + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); @@ -605,26 +603,65 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_sums[4] __attribute__((aligned(128))); + float local_sums[32] __attribute__((aligned(128))); + + dma_queue * dma = octx->ctx->dma[ith]; + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + float * s_work[2]; + s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith); + s_work[1] = s_work[0] + state_aligned / sizeof(float); + + struct fastdiv_values fd_H = init_fastdiv_values(H); + struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]); + struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]); + struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); + struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); const uint64_t state_seq_stride = state->nb[2] / sizeof(float); const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; const int64_t shift = (int64_t) n_tokens - (int64_t) K; + uint32_t ir_prefetch = ith; + int spad_idx = 0; + + // Prefetch preamble (up to 2 steps) + for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + + // Push dummy write-back + dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), 0); + + // Push fetch + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + int curr_spad_idx = 0; for (uint32_t ir = ith; ir < total_rows; ir += nth) { - const uint32_t iv1 = ir % H; - const uint32_t iv3 = ir / H; + dma_queue_pop(dma); + dma_queue_pop(dma); - const uint32_t iq1 = iv1 % q->ne[1]; - const uint32_t ik1 = iv1 % k->ne[1]; - const uint32_t iq3 = iv3 / rq3; - const uint32_t ik3 = iv3 / rk3; + float * s_work_curr = s_work[curr_spad_idx]; - float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v; + const uint32_t iv1 = fastmodulo(ir, H, &fd_H); + const uint32_t iv3 = fastdiv(ir, &fd_H); + + const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1); + const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1); + const uint32_t iq3 = fastdiv(iv3, &fd_rq3); + const uint32_t ik3 = fastdiv(iv3, &fd_rk3); - memcpy(s_out, s_in, gctx->state_bytes); - float * s_work = s_out; + float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v; @@ -640,57 +677,117 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + (uint64_t) iv3 * beta->nb[3] + (uint64_t) t * beta->nb[2] + (uint64_t) iv1 * beta->nb[1]); - memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); - memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); + hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v); + hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v); if (kda) { hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); uint32_t j = 0; + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; + gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); + } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } else { const float gate = expf(g_t[0]); uint32_t j = 0; + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; + gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); + } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } @@ -698,17 +795,40 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const int64_t target_slot = (int64_t) t - shift; if (target_slot >= 0 && target_slot < (int64_t) K) { float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - if (curr_state_o != s_work) { - memcpy(curr_state_o, s_work, gctx->state_bytes); + if (curr_state_o != s_out) { + hvx_copy_f32_uu((uint8_t *) curr_state_o, (const uint8_t *) s_work_curr, S_v * S_v); } } } attn_data += (uint64_t) S_v * H; } + + // Push real write-back + dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + // Prefetch next block (if any) + if (ir_prefetch < total_rows) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + curr_spad_idx ^= 1; } + dma_queue_flush(dma); } + static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, void * data) { struct htp_gdn_context * gctx = (struct htp_gdn_context *) data; struct htp_ops_context * octx = gctx->octx; @@ -743,41 +863,64 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); - float local_sums[8] __attribute__((aligned(128))); + float local_sums[32] __attribute__((aligned(128))); dma_queue * dma = octx->ctx->dma[ith]; + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + float * s_work[2]; + s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith); + s_work[1] = s_work[0] + state_aligned / sizeof(float); - uint8_t * spad = NULL; - if (gctx->use_vtcm) { - spad = gctx->vtcm_state_base + gctx->vtcm_state_per_thread * ith; - } + struct fastdiv_values fd_H = init_fastdiv_values(H); + struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]); + struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]); + struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); + struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); const uint64_t state_seq_stride = state->nb[2] / sizeof(float); const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; + uint32_t ir_prefetch = ith; + int spad_idx = 0; + + // Prefetch preamble (up to 2 steps) + for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + + // Push dummy write-back + dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), 0); + + // Push fetch + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + int curr_spad_idx = 0; for (uint32_t ir = ith; ir < total_rows; ir += nth) { - const uint32_t iv1 = ir % H; - const uint32_t iv3 = ir / H; + dma_queue_pop(dma); + dma_queue_pop(dma); - const uint32_t iq1 = iv1 % q->ne[1]; - const uint32_t ik1 = iv1 % k->ne[1]; - const uint32_t iq3 = iv3 / rq3; - const uint32_t ik3 = iv3 / rk3; + float * s_work_curr = s_work[curr_spad_idx]; - float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; - const float * s_in = state_in_base + (uint64_t) iv3 * state_seq_stride + (uint64_t) iv1 * S_v * S_v; - float * s_work; + const uint32_t iv1 = fastmodulo(ir, H, &fd_H); + const uint32_t iv3 = fastdiv(ir, &fd_H); - if (spad) { - dma_queue_push(dma, dma_make_ptr(spad, s_in), - S_v * sizeof(float), S_v * sizeof(float), - S_v * sizeof(float), S_v); - dma_queue_pop(dma); - s_work = (float *) spad; - } else { - s_work = s_out; - memcpy(s_work, s_in, gctx->state_bytes); - } + const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1); + const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1); + const uint32_t iq3 = fastdiv(iv3, &fd_rq3); + const uint32_t ik3 = fastdiv(iv3, &fd_rk3); + + float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v; @@ -792,111 +935,145 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + (uint64_t) iv3 * beta->nb[3] + (uint64_t) iv1 * beta->nb[1]); - memcpy(local_q, q_t, (size_t) S_v * sizeof(float)); - memcpy(local_k, k_t, (size_t) S_v * sizeof(float)); + hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v); + hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v); if (kda) { hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); uint32_t j = 0; for (; j + 8 <= S_v; j += 8) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - float * row4 = s_work + (uint64_t) (j + 4) * S_v; - float * row5 = s_work + (uint64_t) (j + 5) * S_v; - float * row6 = s_work + (uint64_t) (j + 6) * S_v; - float * row7 = s_work + (uint64_t) (j + 7) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_gate, local_k, S_v, local_sums); - float local_delta_b[8] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 8; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 8; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } else { const float gate = expf(g_t[0]); uint32_t j = 0; for (; j + 8 <= S_v; j += 8) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; - float * row4 = s_work + (uint64_t) (j + 4) * S_v; - float * row5 = s_work + (uint64_t) (j + 5) * S_v; - float * row6 = s_work + (uint64_t) (j + 6) * S_v; - float * row7 = s_work + (uint64_t) (j + 7) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, gate, local_k, S_v, local_sums); - float local_delta_b[8] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 8; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 8; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); } for (; j + 4 <= S_v; j += 4) { - float * row0 = s_work + (uint64_t) (j + 0) * S_v; - float * row1 = s_work + (uint64_t) (j + 1) * S_v; - float * row2 = s_work + (uint64_t) (j + 2) * S_v; - float * row3 = s_work + (uint64_t) (j + 3) * S_v; + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); - float local_delta_b[4] __attribute__((aligned(128))); - for (uint32_t r = 0; r < 4; ++r) { - local_delta_b[r] = (v_t[j + r] - local_sums[r]) * beta_val; - } + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); - for (uint32_t r = 0; r < 4; ++r) { - attn_data[j + r] = local_sums[r] * scale; - } + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); for (; j < S_v; ++j) { - float * row = s_work + (uint64_t) j * S_v; - const float sum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); - const float dj = (v_t[j] - sum) * beta_val; - attn_data[j] = gdn_add_scaled_dot_f32(row, local_k, dj, local_q, S_v) * scale; + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); } } - if (spad) { - dma_queue_push(dma, dma_make_ptr(s_out, spad), + // Push real write-back + dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + // Prefetch next block (if any) + if (ir_prefetch < total_rows) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), S_v * sizeof(float), S_v * sizeof(float), S_v * sizeof(float), S_v); - dma_queue_pop(dma); + + ir_prefetch += nth; + spad_idx ^= 1; } + + curr_spad_idx ^= 1; } + dma_queue_flush(dma); } + int op_gated_delta_net(struct htp_ops_context * octx) { const struct htp_tensor * q = octx->src[0]; const struct htp_tensor * k = octx->src[1]; @@ -952,18 +1129,11 @@ int op_gated_delta_net(struct htp_ops_context * octx) { size_t state_aligned = (size_t) S_v * S_v * sizeof(float); state_aligned = (state_aligned + 127) & ~(size_t)127; - gctx.use_vtcm = false; - gctx.vtcm_state_base = NULL; - gctx.vtcm_state_per_thread = 0; + assert(octx->ctx->vtcm_base != NULL); + assert(octx->ctx->vtcm_size >= 2 * state_aligned * octx->n_threads); - if (n_tokens == 1 && octx->ctx->vtcm_base) { - size_t vtcm_total = state_aligned * octx->n_threads; - if (octx->ctx->vtcm_size >= vtcm_total) { - gctx.use_vtcm = true; - gctx.vtcm_state_base = octx->ctx->vtcm_base; - gctx.vtcm_state_per_thread = state_aligned; - } - } + gctx.vtcm_base = octx->ctx->vtcm_base; + gctx.vtcm_per_thread = 2 * state_aligned; if (n_tokens == 1) { worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_tg_thread, &gctx, octx->n_threads); diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index f132c08500d..2796564fb75 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -17,14 +17,17 @@ #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "hex-dma.h" +#include "hex-fastdiv.h" #include "hmx-profile.h" #include "hmx-queue.h" #include "hmx-utils.h" #include "htp-ctx.h" #include "htp-ops.h" #include "hvx-dump.h" +#include "hvx-copy.h" #include "hvx-reduce.h" #include "hvx-utils.h" +#include "hvx-flash-attn.h" #include "vtcm-utils.h" #include "worker-pool.h" @@ -46,7 +49,7 @@ // g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions. // Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales // Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax. -static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads) { +static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) { const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong @@ -67,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, + k_dma_size * 2 // K DMA x2 + v_dma_size * 2 // V DMA x2 + k_tile_size * 1 // K tiles - + v_tile_size * 1 // V tiles + + v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining) + s_tile_size * 2 // S + P + d_tile_size * 1 // D (diagonal matrix) + col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum @@ -144,12 +147,13 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, // See .cursor/todos/hmx-flash-attn-bc-search-space.md for the perf trade-off. const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64 const size_t fp16 = sizeof(__fp16); + const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); // Approximate per-unit VTCM costs (without per-buffer alignment padding). const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * fp16; // Q + O×2 + 4 col vectors const size_t per_gbr2 = fp16; // D diagonal matrix const size_t per_bc = - 3 * (DK + DV) * fp16 + 2 * n_threads * fp16; // K_dma×2 + V_dma×2 + K_tile + V_tile + row bufs + 3 * DK * fp16 + (can_pipeline ? 4 : 3) * DV * fp16 + 2 * n_threads * fp16; // K/V DMA x2 + tiles + row bufs const size_t per_gbr_bc = 2 * fp16; // S + P const size_t overhead = 256 * 2 + 13 * 4096; @@ -164,7 +168,6 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, // Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS. // Only relax when kv_len is too short to form enough blocks. - const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) : (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit); // Cost coefficients calibrated from profiling @@ -200,7 +203,7 @@ static int hmx_fa_find_chunk_size(size_t * Br_out, } // Exact VTCM verification (alignment padding may push over budget) - while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads) > vtcm_budget) { + while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads, can_pipeline) > vtcm_budget) { Bc -= bc_unit; } if (Bc < bc_unit) { @@ -303,6 +306,7 @@ struct hmx_fa_context { uint32_t n_kv_heads; // number of KV heads uint32_t n_heads; // number of Q heads uint32_t G; // GQA factor = n_heads / n_kv_heads + struct fastdiv_values div_G; uint32_t n_kv_blocks; uint32_t neq1; // Q token count @@ -321,7 +325,7 @@ struct hmx_fa_context { __fp16 * vtcm_k_fp16[2]; // K DMA double-buffer [Bc, D] __fp16 * vtcm_v_fp16[2]; // V DMA double-buffer [Bc, D] __fp16 * vtcm_k_tiles; // K tiles (transposed) - __fp16 * vtcm_v_tiles; // V tiles (column-major) + __fp16 * vtcm_v_tiles[2]; // V tiles (column-major, double-buffered) __fp16 * vtcm_s_tiles; // S = QK^T [g_br, Bc] __fp16 * vtcm_p_tiles; // P = softmax(S) [g_br, Bc] __fp16 * vtcm_d_tiles; // Diagonal rescale [g_br, g_br] @@ -402,7 +406,9 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) return; } - hmx_interleave_cols_to_tiles(factx->vtcm_v_tiles, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, + __fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0]; + + hmx_interleave_cols_to_tiles(v_tiles_dest, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, (int) args->src_stride, (int) args->n_col_tiles, start, end); } @@ -464,10 +470,10 @@ static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) { for (size_t r = start; r < end; r += 2) { const bool next_row_valid = (r + 1) < n_rows_g; - const size_t q_idx0 = (r + 0) / G; - const size_t h_idx0 = (r + 0) % G; - const size_t q_idx1 = (r + 1) / G; - const size_t h_idx1 = (r + 1) % G; + const size_t q_idx0 = fastdiv(r + 0, &factx->div_G); + const size_t h_idx0 = fastmodulo(r + 0, G, &factx->div_G); + const size_t q_idx1 = fastdiv(r + 1, &factx->div_G); + const size_t h_idx1 = fastmodulo(r + 1, G, &factx->div_G); const uint8_t * q_ptr0 = (const uint8_t *) q->data + (q_start + q_idx0) * q->nb[1] + (kv_head * G + h_idx0) * q->nb[2] + ib3 * q->nb[3]; @@ -567,8 +573,8 @@ static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) { const uint32_t ib3 = args->ib3; for (size_t r = start; r < end; ++r) { - const size_t q_idx = r / G; - const size_t h_idx = r % G; + const size_t q_idx = fastdiv(r, &factx->div_G); + const size_t h_idx = fastmodulo(r, G, &factx->div_G); // FIX(dst-indexing): ggml_flash_attn_ext() creates dst as permute(0,2,1,3) -> // [DV, n_heads, n_tokens, n_seq], so head stride is nb[1] and token stride is nb[2]. @@ -780,11 +786,11 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { if (args->mask_vtcm) { // Read mask from VTCM buffer (DMA'd per KV block). // GQA dedup (scheme B): skip load when qi unchanged. - const size_t qi0 = (r + 0) / G; + const size_t qi0 = fastdiv(r + 0, &factx->div_G); v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c); v_mask1 = v_neg_inf; if (r + 1 < (int) n_rows_g) { - const size_t qi1 = (r + 1) / G; + const size_t qi1 = fastdiv(r + 1, &factx->div_G); if (qi1 == qi0) { v_mask1 = v_mask0; // scheme B: reuse — same mask row } else { @@ -794,8 +800,8 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { } else { // Fallback: read mask directly from DDR (when mask->ne[2] > 1). const struct htp_tensor * mask = args->mask; - const size_t q_idx0 = args->q_start + ((r + 0) / G); - const size_t h_idx0 = args->kv_head * G + (r + 0) % G; + const size_t q_idx0 = args->q_start + fastdiv(r + 0, &factx->div_G); + const size_t h_idx0 = args->kv_head * G + fastmodulo(r + 0, G, &factx->div_G); const uint32_t im2_0 = h_idx0 % mask->ne[2]; const uint32_t im3_0 = args->ib3 % mask->ne[3]; @@ -805,12 +811,12 @@ static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { v_mask1 = v_neg_inf; if (r + 1 < (int) n_rows_g) { - const size_t q_idx1 = args->q_start + ((r + 1) / G); + const size_t q_idx1 = args->q_start + fastdiv(r + 1, &factx->div_G); if (q_idx1 == q_idx0) { // scheme B: same mask row in DDR path v_mask1 = v_mask0; } else { - const size_t h_idx1 = args->kv_head * G + (r + 1) % G; + const size_t h_idx1 = args->kv_head * G + fastmodulo(r + 1, G, &factx->div_G); const uint32_t im2_1 = h_idx1 % mask->ne[2]; const uint32_t im3_1 = args->ib3 % mask->ne[3]; const __fp16 * m1_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx1 * mask->nb[1] + @@ -1191,14 +1197,13 @@ static void hmx_fa_o_norm_worker(void * data) { // Row r in the GQA-merged block maps to Q head h = kv_head * G + r % G. // slope(h) = m0^(h+1) when h < n_head_log2, else m1^(2*(h-n_head_log2)+1). // When max_bias == 0, all slopes are 1.0 (no ALiBi). -static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sargs, +static __attribute__((noinline)) void fa_compute_slopes( const struct hmx_fa_context * factx, uint32_t kv_head, size_t n_rows_g) { + __fp16 * slopes = factx->vtcm_slopes; if (factx->max_bias == 0.0f) { - for (size_t r = 0; r < n_rows_g; ++r) { - sargs->slopes[r] = 1.0f; - } + hvx_splat_f16_a(slopes, 1.0f, n_rows_g); return; } @@ -1207,10 +1212,32 @@ static __attribute__((noinline)) void fa_compute_slopes(fa_softmax_args_t * sarg const float m0 = factx->m0; const float m1 = factx->m1; + __fp16 temp_slopes[512] __attribute__((aligned(128))); + if (G <= 32) { + // Fast path: Compute G unique slope values in vector registers + HVX_Vector v_val = hvx_alibi_slopes(kv_head, G, n_head_log2, m0, m1); + + __fp16 temp_slopes_aligned[64] __attribute__((aligned(128))); + hvx_vmem(temp_slopes_aligned) = hvx_vec_f32_to_f16(v_val, Q6_V_vzero()); + + for (uint32_t i = 0; i < G; ++i) { + temp_slopes[i] = temp_slopes_aligned[i]; + } + } else { + // Fallback path: G > 32 (rare configurations) + for (uint32_t i = 0; i < G; ++i) { + temp_slopes[i] = (__fp16)alibi_slope(kv_head * G + i, n_head_log2, m0, m1); + } + } + + // Allocate stack buffer to avoid scalar writes to VTCM (which generates L2 misses) + __fp16 local_slopes[n_rows_g] __attribute__((aligned(128))); for (size_t r = 0; r < n_rows_g; ++r) { - const uint32_t h = kv_head * G + r % G; - sargs->slopes[r] = (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); + local_slopes[r] = temp_slopes[fastmodulo(r, G, &factx->div_G)]; } + + // Copy to VTCM slopes using HVX block copy (both are aligned to 128 bytes) + hvx_copy_f16_aa((uint8_t *)slopes, (const uint8_t *)local_slopes, n_rows_g); } // ============================================================================ @@ -1254,19 +1281,22 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const uint32_t G = neq2 / n_kv_heads; // Thread count for multi-thread HVX phases - const uint32_t n_threads = octx->n_threads; + const uint32_t n_threads_init = octx->n_threads; // Compute dynamic block sizes (GQA-aware, accounting for per-thread row bufs) size_t Br, Bc; const size_t vtcm_budget = ctx->vtcm_size; - if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads) != 0) { + if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads_init) != 0) { return HTP_STATUS_VTCM_TOO_SMALL; } const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS); const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc; - const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2); + const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2); + + // Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1 + const uint32_t n_threads = use_pipeline ? n_threads_init : 1; FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu", neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget); @@ -1282,6 +1312,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.n_kv_heads = n_kv_heads; factx.n_heads = neq2; factx.G = G; + factx.div_G = init_fastdiv_values(G); factx.neq1 = neq1; factx.Br = (uint32_t) Br; factx.Bc = (uint32_t) Bc; @@ -1354,7 +1385,12 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.vtcm_v_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes); - factx.vtcm_v_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + if (use_pipeline) { + factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + } else { + factx.vtcm_v_tiles[1] = NULL; + } factx.vtcm_s_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); factx.vtcm_p_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); factx.vtcm_d_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, d_tile_bytes); @@ -1457,6 +1493,8 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { // ---- KV block loop with DMA double-buffering ---- size_t buf_idx = 0; + fa_compute_slopes(&factx, kv_head, n_rows_g); + // Prefetch first KV block if (factx.n_kv_blocks > 0) { const uint32_t kv_rows0 = hex_smin(Bc, nek1); @@ -1535,7 +1573,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { ou_job.o_curr = o_tile_curr; ou_job.o_prev = o_tile_prev; ou_job.p_tiles = factx.vtcm_p_tiles; - ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles[1 - buf_idx]; ou_job.d_tiles = factx.vtcm_d_tiles; ou_job.hmx_scales = factx.vtcm_hmx_scales_id; ou_job.n_row_tiles = n_row_tiles; @@ -1550,11 +1588,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); TIMER_STOP(k_interleave); - if (kv_blk > 0) { - hmx_queue_pop(hmx_q); - hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); - } - // ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ---- qk_job.q_tiles = factx.vtcm_q_tiles; qk_job.k_tiles = factx.vtcm_k_tiles; @@ -1574,6 +1607,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); TIMER_STOP(v_interleave); + // Pop and swap previous block's output update (deferred HMX pop) + if (kv_blk > 0) { + hmx_queue_pop(hmx_q); + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + // Pop current block's dot product job hmx_queue_pop(hmx_q); TIMER_STOP(qk_dot); @@ -1601,7 +1641,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; sargs.slopes = factx.vtcm_slopes; - fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); TIMER_START(softmax); fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); @@ -1617,7 +1656,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { ou_job.o_curr = o_tile_curr; ou_job.o_prev = o_tile_prev; ou_job.p_tiles = factx.vtcm_p_tiles; - ou_job.v_tiles = factx.vtcm_v_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles[1 - buf_idx]; ou_job.d_tiles = factx.vtcm_d_tiles; ou_job.hmx_scales = factx.vtcm_hmx_scales_id; ou_job.n_row_tiles = n_row_tiles; @@ -1712,7 +1751,6 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; sargs.slopes = factx.vtcm_slopes; - fa_compute_slopes(&sargs, &factx, kv_head, n_rows_g); TIMER_START(softmax); fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); @@ -1732,7 +1770,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const size_t DV_tiles = (size_t) (DV / 32); const __fp16 * restrict d_base = factx.vtcm_d_tiles; const __fp16 * restrict p_base = factx.vtcm_p_tiles; - const __fp16 * restrict v_base = factx.vtcm_v_tiles; + const __fp16 * restrict v_base = factx.vtcm_v_tiles[0]; const __fp16 * restrict op_base = o_tile_prev; __fp16 * restrict oc_base = o_tile_curr; __builtin_assume(n_row_tiles > 0); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c index 083d125882d..dab605210cf 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -73,6 +73,10 @@ static inline size_t get_x4x2_row_stride(int weight_type, int k) { return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb case HTP_TYPE_MXFP4: return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb + case HTP_TYPE_F16: + return (size_t) k * sizeof(__fp16); + case HTP_TYPE_F32: + return (size_t) k * sizeof(float); default: return 0; } @@ -545,7 +549,7 @@ static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4( int start_tile, int end_tile) { const int n_k_tiles = state->n_k_tiles; - const int qrow_size = state->k_block; + const int qrow_size = (unsigned)state->k_block / 2; const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut); @@ -720,12 +724,129 @@ static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, voi } } +static void convert_f16_weight_to_fp16_tiles_task( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { + int byte_off = kt * 32 * sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0 = hvx_vmemu((const __fp16 *)(r0 + byte_off)); + HVX_Vector v1 = (row1 < state->n_cols) ? hvx_vmemu((const __fp16 *)(r1 + byte_off)) : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + convert_f16_weight_to_fp16_tiles_task(state, start, end); + } +} + +static void quantize_f32_weight_to_fp16_tiles_task( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { + int byte_off = kt * 32 * sizeof(float); + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0_f32 = hvx_vmemu((const float *)(r0 + byte_off)); + HVX_Vector v1_f32 = (row1 < state->n_cols) ? hvx_vmemu((const float *)(r1 + byte_off)) : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16(v0_f32, v1_f32); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + + HVX_Vector v_out_hi = Q6_V_vror_VR(v_out, 64); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out_hi); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + quantize_f32_weight_to_fp16_tiles_task(state, start, end); + } +} + + static void dequantize_x4x2_weight_chunk_to_fp16_tiles( struct htp_context *ctx, __fp16 *vtcm_dst, const void *vtcm_src, int n_cols, int k_block, size_t row_stride, int weight_type, int n_k_tiles, struct fastdiv_values n_k_tiles_div, - worker_callback_t dequant_worker_fn) { + worker_callback_t dequant_worker_fn, int n_threads) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); assert(k_block % HMX_FP16_TILE_N_COLS == 0); @@ -733,7 +854,7 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; size_t n_tot_tiles = n_col_tiles * n_k_tiles; - size_t n_tiles_per_task = hmx_ceil_div(n_tot_tiles, ctx->n_threads); + size_t n_tiles_per_task = (n_threads == 1) ? n_tot_tiles : hmx_ceil_div(n_tot_tiles, n_threads); x4x2_dequantize_state_t state; state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; @@ -748,7 +869,11 @@ static void dequantize_x4x2_weight_chunk_to_fp16_tiles( state.n_k_tiles = n_k_tiles; state.n_k_tiles_div = n_k_tiles_div; - worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + dequant_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, n_threads); + } } // --- End x4x2 dequantizers --- @@ -876,11 +1001,11 @@ static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void } static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, - int n_rows, int n_cols, int n) { + int n_rows, int n_cols, int n, int n_threads) { assert(n_cols % HMX_FP16_TILE_N_COLS == 0); size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) output_transfer_task_state_t state; state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; @@ -891,7 +1016,11 @@ static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, state.n_cols = n_cols; state.n = n; - worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + transfer_output_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, n_threads); + } } // activations : fp32 -> fp16 @@ -973,12 +1102,12 @@ static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, } } -static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride) { +static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride, int n_threads) { assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); assert(VLEN == 32 * sizeof(float)); size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = 32; // must be multiple of 32 to ensure correct destination address + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : 32; // must be multiple of 32 to ensure correct destination address activation_transfer_task_state_t state; state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; @@ -989,7 +1118,11 @@ static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 * state.k_block = k_block; state.k_stride = k_stride; - worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, ctx->n_threads); + if (state.n_tasks == 1 || n_threads == 1) { + transfer_activation_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, n_threads); + } } // C += AB @@ -1031,9 +1164,9 @@ static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, co } } -int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, +int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const uint8_t *restrict permuted_weight, int m, int k, int n, - int weight_type) { + int act_stride, int weight_stride, int weight_type) { if (k % 32 != 0 || n % 32 != 0) { return -1; } if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { @@ -1052,6 +1185,8 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; default: return -1; } @@ -1059,21 +1194,25 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + // --- Dynamic Mode Configuration --- + const bool use_pipeline = (m > 32); + const int num_threads = (m <= 32) ? 1 : ctx->n_threads; + // --- Dynamic VTCM layout --- const size_t vec_dot_size = k * sizeof(__fp16); const size_t vtcm_budget = ctx->vtcm_size; size_t vtcm_used = 0; // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. - const size_t size_per_n = row_stride + 2 * vec_dot_size; // Q + S0 + S1 (dequant bufs) - const size_t size_per_mn = 2 * sizeof(__fp16); // O x 2 (output double buffer) + const size_t size_per_n = row_stride + (use_pipeline ? 2 * vec_dot_size : vec_dot_size); // Q + S0 + S1 (dequant bufs) + const size_t size_per_mn = (use_pipeline ? 2 : 1) * sizeof(__fp16); // O x 2 (output double buffer) size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, /*m_block_cost=*/(size_t) n * 3, /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { - FARF(HIGH, "hmx-mm-q: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); + FARF(HIGH, "hmx-mm-2d: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); return -1; } @@ -1083,27 +1222,27 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * size_t scratch0_size, scratch1_size, scratch2_size; scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = scratch0_size; // dequant buf 1 - scratch2_size = output_area_size; // output buf 1 + scratch1_size = use_pipeline ? scratch0_size : 0; // dequant buf 1 + scratch2_size = use_pipeline ? output_area_size : 0; // output buf 1 uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch1_size); + void *vtcm_scratch1 = scratch1_size ? vtcm_seq_alloc(&vtcm_ptr, scratch1_size) : NULL; void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; if (vtcm_used > vtcm_budget) { - FARF(ERROR, "hmx-mm-q: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); + FARF(ERROR, "hmx-mm-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); return -1; } hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - FARF(HIGH, "hmx-mm-q: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", + FARF(HIGH, "hmx-mm-2d: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); TIMER_DEFINE(activation_load); @@ -1114,115 +1253,137 @@ int hmx_matmul_q_f32(struct htp_context *ctx, float *restrict dst, const float * TIMER_DEFINE(total); TIMER_START(total); - // 4-stage pipeline: DMA load (A), dequantize (B), HMX matmul (C), store (D) - // HMX compute (C) runs on dedicated worker thread, overlapping with HVX stages (B, D). - - // A --> B: vtcm_qweight, 1 buffer - // B --> C: vtcm_weight0/vtcm_weight1, 2 buffers - // C --> D: vtcm_output0/vtcm_output1, 2 buffers + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - // Async timeline (C overlaps B+D): - // main+HVX: [A0][Act][B0][A1][sub C0][B1‖C0][A2][wait,sub C1][D0+B2‖C1][wait,sub C2][D1‖C2][wait][D2] - // HMX queue: [████ C0 ████████][████ C1 ████████████][████ C2 ████████] + if (use_pipeline) { + // --- Asynchronous Pipelined Loop (Current implementation) --- + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + void *vtcm_qweight = vtcm_weight; + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - void *vtcm_qweight = vtcm_weight; - void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; - void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + // prologue: A0 + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + { + const uint8_t *qweight_chunk_A0 = permuted_weight; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, weight_stride, row_stride, n_cols_A0); + } - // prologue: A0 - const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); - { - const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, row_stride, row_stride, n_cols_A0); - } + { + const float *activation_chunk = activation + mr * act_stride; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); + } - { - const float *activation_chunk = activation + mr * k; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, k); - } + // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) + { + // B0: wait for DMA, dequant weight chunk 0 + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) - { - // B0: wait for DMA, dequant weight chunk 0 - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); + // A1: issue DMA for weight chunk 1 + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (1 < n_chunk_cnt) { + const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, weight_stride, row_stride, n_cols_A1); + } - // A1: issue DMA for weight chunk 1 - const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); - if (1 < n_chunk_cnt) { - const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, row_stride, row_stride, n_cols_A1); + // submit C0 (non-blocking — HMX worker executes in parallel) + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + + // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) + if (1 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + } } - // submit C0 (non-blocking — HMX worker executes in parallel) - hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, - (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; - // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) - if (1 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); - } - } + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); - // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) - for (int i = 0; i < n_chunk_cnt; ++i) { - const size_t nc = i * n_chunk_n_cols; - const size_t nc_p1 = nc + 1 * n_chunk_n_cols; - const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + // issue A_{i+2}: DMA push (non-blocking) + if (i + 2 < n_chunk_cnt) { + const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, weight_stride, row_stride, n_cols_p2); + } - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); - const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + // wait C_i: block until prologue/previous C completes + hmx_queue_pop(ctx->hmx_queue); - // issue A_{i+2}: DMA push (non-blocking) - if (i + 2 < n_chunk_cnt) { - const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * row_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, row_stride, row_stride, n_cols_p2); - } + // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } + + // D_i: store output (multi-thread HVX, parallel with C_{i+1}) + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n, num_threads); - // wait C_i: block until prologue/previous C completes - hmx_queue_pop(ctx->hmx_queue); - - // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) - // job_slots[(i+1)%2] is safe: C_i just completed, freeing slot i%2's - // counterpart — and (i+1)%2 was last used by C_{i-1} which completed - // before C_i was submitted. - if (i + 1 < n_chunk_cnt) { - hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], - (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], - vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) + if (i + 2 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + } } + } + hmx_queue_suspend(ctx->hmx_queue); + } else { + // --- Synchronous Loop (Optimized for small/non-pipelined cases) --- + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - // D_i: store output (multi-thread HVX, parallel with C_{i+1}) - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n); + // Load Activation + const float *activation_chunk = activation + mr * act_stride; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); - // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) - if (i + 2 < n_chunk_cnt) { + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + + // A: DMA Load Weight + const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn); + + // B: Dequantize / Convert Weight + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + + // C: HMX Compute (Synchronous) + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); + + // D: Output Store + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output, n_rows, n_cols, n, num_threads); } } + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); } - hmx_queue_suspend(ctx->hmx_queue); - TIMER_STOP(total); #if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "hex-mm-q: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); + FARF(HIGH, "hex-mm-2d: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); if (!use_pipeline) { FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); @@ -1401,11 +1562,11 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 dma_queue_pop(ctx->dma[0]); transfer_activation_chunk_threaded(ctx, vtcm_act_g, vtcm_f32_act, (int) n_rows, - params->k, params->k); + params->k, params->k, ctx->n_threads); } else { transfer_activation_chunk_threaded(ctx, vtcm_act_g, activation_chunk, (int) n_rows, - params->k, params->act_stride); + params->k, params->act_stride, ctx->n_threads); } } TIMER_STOP(activation_load); @@ -1455,7 +1616,7 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 TIMER_START(output_store); { float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; - transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride); + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, ctx->n_threads); } TIMER_STOP(output_store); } @@ -1475,177 +1636,431 @@ int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32 TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); #endif - return 0; + return 0; } -// - int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, const __fp16 *restrict permuted_weight, int m, int k, int n, int act_stride, int weight_stride) { if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } - if (act_stride < k || weight_stride < k) { return -1; } - if (k % 32 != 0 || n % 32 != 0) { return -1; } + return hmx_matmul_2d_f32(ctx, dst, activation, (const uint8_t *)permuted_weight, m, k, n, + act_stride, weight_stride * (int)sizeof(__fp16), HTP_TYPE_F16); +} - if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; +struct mmid_row_mapping { + uint32_t i1; + uint32_t i2; +}; + +typedef struct { + __fp16 *dst; + const float *src; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int k_block; + const struct mmid_row_mapping *matrix_rows; + int cur_a; + int mapping_stride; + int ne11; + struct fastdiv_values ne11_div; + size_t nb11; + size_t nb12; + int start_row; + int cne1; +} activation_transfer_gathered_task_state_t; + +typedef struct { + const __fp16 *vtcm_src; + float *dst; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int n_cols; + const struct mmid_row_mapping *matrix_rows; + int cur_a; + int mapping_stride; + size_t dst_nb1; + size_t dst_nb2; + int start_row; + int cne1; +} output_transfer_scattered_task_state_t; + +static void transfer_activation_chunk_fp32_to_fp16_gathered( + __fp16 *restrict vtcm_dst, + const float *restrict src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + const struct fastdiv_values * ne11_div, + size_t nb11, + size_t nb12, + int cne1) { + const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); + const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; + + int r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + int r_idx0 = start_row + r + 0; + int r_idx1 = start_row + r + 1; + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + + int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + + const float *row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + const float *row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + + const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; + const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; + + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = *pv_in1++; + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } } - // --- Dynamic VTCM layout --- - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = k * sizeof(__fp16); + for (; r < n_rows_padded; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - // DMA-based activation gather for strided tensors (see batched path comment). - const bool use_dma_activation = (act_stride > k); - const size_t f32_scratch_per_m = use_dma_activation ? (size_t) k * sizeof(float) : 0; + const bool row0_valid = (start_row + r + 0) < cne1; + const bool row1_valid = (start_row + r + 1) < cne1; - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // FP16 weight: interleave and activation load have similar per-element cost. - if (hmx_compute_chunks(vtcm_budget, - /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, // W + S0 + S1 - /*per_m=*/vec_dot_size + f32_scratch_per_m, // A + optional F32 scratch - /*per_mn=*/sizeof(__fp16), // O - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n, - /*n_block_cost=*/(size_t) m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: VTCM too small (m=%d k=%d n=%d budget=%zu)", __func__, m, k, n, vtcm_budget); - return -1; + const float *row0_ptr = NULL; + const float *row1_ptr = NULL; + + if (row0_valid) { + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; + int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + } + if (row1_valid) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; + int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + } + + const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; + const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; + + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); + HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } } +} - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t f32_scratch_size = use_dma_activation - ? hex_align_up(m_chunk_n_rows * (size_t) k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; +static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_gathered_task_state_t *st = data; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + __fp16 *dst = st->dst + (size_t)(start_row - st->start_row) * st->k_block; + transfer_activation_chunk_fp32_to_fp16_gathered( + dst, st->src, start_row, n_rows, st->k_block, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1); + } +} - // VTCM layout: weight | activation | output | scratch0 | scratch1 | scales | [f32_scratch] - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; - if ((size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base) > vtcm_budget) { - FARF(ERROR, "%s: vtcm overflow: used=%zu limit=%zu", __func__, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); - return -1; +static void transfer_activation_chunk_gathered_threaded( + struct htp_context *ctx, + __fp16 *dst, + const float *src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + size_t nb11, + size_t nb12, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + activation_transfer_gathered_task_state_t state = { + .dst = dst, + .src = src, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .k_block = k_block, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .ne11 = ne11, + .ne11_div = init_fastdiv_values(ne11), + .nb11 = nb11, + .nb12 = nb12, + .start_row = start_row, + .cne1 = cne1, + }; + + if (actual_threads <= 1) { + transfer_activation_chunk_gathered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_gathered_worker_fn, &state, actual_threads); } +} - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 +static void transfer_output_chunk_fp16_to_fp32_scattered( + float *restrict dst, + const __fp16 *restrict vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; - FARF(HIGH, "%s: m=%d k=%d n=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, m, k, n, m_chunk_n_rows, n_chunk_n_cols, - (size_t)(vtcm_ptr - (uint8_t *)ctx->vtcm_base), vtcm_budget); + const HVX_Vector one = hvx_vec_splat_f16(1.0); - TIMER_DEFINE(activation_load); - TIMER_DEFINE(weight_load); - TIMER_DEFINE(hmx_core); - TIMER_DEFINE(output_store); + for (size_t r = 0; r < n_rows; r += 2) { + const size_t r0 = r / HMX_FP16_TILE_N_ROWS; + const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; - TIMER_DEFINE(total); - TIMER_START(total); + int r_idx0 = start_row + (int)r + 0; + int r_idx1 = start_row + (int)r + 1; - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + if (r_idx0 >= cne1) break; - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - // transfer activation matrix chunk into VTCM - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + float *output_row0 = (float *) ((uint8_t *) dst + mapping0.i1 * dst_nb1 + mapping0.i2 * dst_nb2); - TIMER_START(activation_load); - { - const float *activation_chunk = activation + mr * act_stride; - if (use_dma_activation) { - const size_t row_bytes = (size_t) k * sizeof(float); - const size_t stride_bytes = (size_t) act_stride * sizeof(float); - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_f32_act, activation_chunk), - row_bytes, stride_bytes, row_bytes, n_rows); - dma_queue_pop(ctx->dma[0]); - transfer_activation_chunk_threaded(ctx, vtcm_activation, - vtcm_f32_act, n_rows, k, k); - } else { - transfer_activation_chunk_threaded(ctx, vtcm_activation, - activation_chunk, n_rows, k, act_stride); + float *output_row1 = NULL; + if (r_idx1 < cne1) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + output_row1 = (float *) ((uint8_t *) dst + mapping1.i1 * dst_nb1 + mapping1.i2 * dst_nb2); + } + + #pragma unroll(4) + for (size_t c = 0; c < (size_t)n_cols; c += HMX_FP16_TILE_N_COLS) { + const size_t c0 = c / HMX_FP16_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row0 + c); + volatile HVX_Vector *pv_out1 = output_row1 ? (volatile HVX_Vector *) (output_row1 + c) : NULL; + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (pv_out1) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); } } - TIMER_STOP(activation_load); + } +} - const size_t fp16_row_bytes = (size_t) k * sizeof(__fp16); - const size_t weight_row_bytes = (size_t) weight_stride * sizeof(__fp16); +static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_scattered_task_state_t *st = data; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + const __fp16 *src = st->vtcm_src + (size_t)(start_row - st->start_row) * st->n_cols; + transfer_output_chunk_fp16_to_fp32_scattered( + st->dst, src, start_row, n_rows, st->n_cols, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->dst_nb1, st->dst_nb2, st->cne1); + } +} - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; +static void transfer_output_chunk_scattered_threaded( + struct htp_context *ctx, + float *dst, + const __fp16 *vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + output_transfer_scattered_task_state_t state = { + .vtcm_src = vtcm_src, + .dst = dst, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .n_cols = n_cols, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .dst_nb1 = dst_nb1, + .dst_nb2 = dst_nb2, + .start_row = start_row, + .cne1 = cne1, + }; + + if (actual_threads <= 1) { + transfer_output_chunk_scattered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_scattered_worker_fn, &state, actual_threads); + } +} - // issue async DMA for the first weight chunk - // NOTE: use 2D DMA (n_cols rows x fp16_row_bytes) to avoid 16-bit roiwidth overflow. - // The source rows can be strided (e.g. KV-cache K after ggml_permute). - { - const size_t n_cols_first = hex_smin(n, n_chunk_n_cols); +int hmx_matmul_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride) { + const int cne1 = m; + const int m_padded = hex_align_up(m, 32); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, permuted_weight), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); - } + if (k % 32 != 0 || n % 32 != 0) { return -1; } - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { + return -1; + } - TIMER_START(weight_load); - { - dma_queue_pop(ctx->dma[0]); // wait until current weight chunk is ready + size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } - // issue async DMA for the next weight chunk (double buffering) - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < n) { - const size_t n_cols_next = hex_smin(n - nc_next, n_chunk_n_cols); - const __fp16 *next_weight_chunk = permuted_weight + nc_next * weight_stride; + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; + default: + return -1; + } - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); - } + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); - // interleave row-major fp16 from scratch into tile-major in vtcm_weight - hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, k, k, 0, n_cols); + const int num_threads = ctx->n_threads; - hex_swap_ptr(&buf_curr, &buf_next); - } - TIMER_STOP(weight_load); + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; - TIMER_START(hmx_core); - { - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, k / 32); - } - TIMER_STOP(hmx_core); + const size_t size_per_n = row_stride + vec_dot_size; + const size_t size_per_mn = sizeof(__fp16); - TIMER_START(output_store); - { - float *output = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output, vtcm_output, n_rows, n_cols, n); - } - TIMER_STOP(output_store); - } + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, + m_padded, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m_padded * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(HIGH, "hmx-mm-id-2d: VTCM too small : m %d k %d n %d budget %zu", m_padded, k, n, vtcm_budget); + return -1; + } + + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + + size_t scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-id-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); + return -1; } - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); - TIMER_STOP(total); + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); -#if defined(ENABLE_PROFILE_TIMERS) - FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d", __func__, TIMER_US(total), m, k, n); - FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", - TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); - { - size_t weight_size = (size_t)k * n * sizeof(__fp16); - float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); - FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); + for (size_t mr = 0; mr < (size_t) m_padded; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m_padded - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + + transfer_activation_chunk_gathered_threaded( + ctx, vtcm_activation, activation, (int) mr, (int) n_rows, k, + matrix_rows, cur_a, mapping_stride, ne11, act_nb1, act_nb2, cne1, num_threads); + + for (size_t nc = 0; nc < (size_t) n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + + const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); + dma_queue_pop(ctx->dma[0]); + + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); + + transfer_output_chunk_scattered_threaded( + ctx, dst, vtcm_output, (int) mr, (int) n_rows, (int) n_cols, + matrix_rows, cur_a, mapping_stride, dst_nb1, dst_nb2, cne1, num_threads); + } } -#endif + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); return 0; } diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.c b/ggml/src/ggml-hexagon/htp/hmx-ops.c new file mode 100644 index 00000000000..114d8c14811 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.c @@ -0,0 +1,6 @@ +// HMX operations compiled as a single translation unit. +// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO. + +#include "hmx-queue.c" +#include "hmx-matmul-ops.c" +#include "hmx-flash-attn-ops.c" diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h index f114edb822f..a67842f3ffc 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -52,14 +52,32 @@ int hmx_matmul_f16_f32(struct htp_context *ctx, // Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params); -// HMX matrix multiplication — quantised weights (Q4_0/Q8_0/IQ4_NL/MXFP4) -int hmx_matmul_q_f32(struct htp_context *ctx, +// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4) +int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float *activation, const uint8_t *permuted_weight, int m, int k, int n, + int act_stride, + int weight_stride, int weight_type); +struct mmid_row_mapping; + +int hmx_matmul_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride); + // HMX flash attention int hmx_flash_attn_ext(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 51f9243ce0a..0f1676f077a 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -79,6 +79,10 @@ struct htp_context { uint64_t max_vmem; + // Persistent DDR scratchpad for MUL_MAT_ID mappings + void * ddr_spad_base; + size_t ddr_spad_size; + struct htp_ops_context octx; #ifdef HTP_HAS_HMX diff --git a/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h b/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h new file mode 100644 index 00000000000..f1f2e49e455 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h @@ -0,0 +1,47 @@ +#ifndef HVX_FLASH_ATTN_H +#define HVX_FLASH_ATTN_H + +#include <math.h> +#include "hvx-utils.h" + +// Scalar helper to compute a single ALiBi slope. +static inline float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) { + return (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); +} + +// Vectorized helper to compute 32 ALiBi slopes starting from (kv_head * G). +static inline HVX_Vector hvx_alibi_slopes( + uint32_t kv_head, + uint32_t G, + uint32_t n_head_log2, + float m0, + float m1 +) { + static const float ramp_32[32] __attribute__((aligned(128))) = { + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, + 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f + }; + HVX_Vector v_ramp = hvx_vmem(ramp_32); + HVX_Vector v_h_base = hvx_vec_splat_f32((float)(kv_head * G)); + HVX_Vector v_h = hvx_vec_add_f32_f32(v_h_base, v_ramp); + + // Compute exponent_m0: h + 1 + HVX_Vector v_exp_m0 = hvx_vec_add_f32_f32(v_h, hvx_vec_splat_f32(1.0f)); + + // Compute exponent_m1: 2 * (h - n_head_log2) + 1 + HVX_Vector v_n_head_log2 = hvx_vec_splat_f32((float)n_head_log2); + HVX_Vector v_h_minus = hvx_vec_sub_f32_f32(v_h, v_n_head_log2); + HVX_Vector v_exp_m1 = hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(hvx_vec_splat_f32(2.0f), v_h_minus), hvx_vec_splat_f32(1.0f)); + + // Compute powers + HVX_Vector v_pow_m0 = hvx_vec_pow_const_base_f32(m0, v_exp_m0); + HVX_Vector v_pow_m1 = hvx_vec_pow_const_base_f32(m1, v_exp_m1); + + // Select based on h < n_head_log2 + HVX_VectorPred p_cond = Q6_Q_vcmp_gt_VsfVsf(v_n_head_log2, v_h); // v_n_head_log2 > v_h <=> h < n_head_log2 + return Q6_V_vmux_QVV(p_cond, v_pow_m0, v_pow_m1); +} + +#endif /* HVX_FLASH_ATTN_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-log.h b/ggml/src/ggml-hexagon/htp/hvx-log.h new file mode 100644 index 00000000000..7013dae785a --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-log.h @@ -0,0 +1,65 @@ +#ifndef HVX_LOG_H +#define HVX_LOG_H + +#include "hvx-base.h" + +// Approximates ln(x) element-wise for float vectors. +// x must contain positive float elements. +// Uses Abramowitz & Stegun polynomial approximation 4.1.44 for ln(1+y) over [0, 1]. +static inline HVX_Vector hvx_vec_log_f32(HVX_Vector x) { + // x = m * 2^e, where m in [1, 2) + HVX_Vector biased_e = Q6_Vuw_vlsr_VuwR(x, 23); + HVX_Vector e_int = Q6_Vw_vsub_VwVw(biased_e, Q6_V_vsplat_R(127)); + HVX_Vector e_float = Q6_Vsf_equals_Vw(e_int); + + // Extract mantissa and set exponent to 127 (which represents float value in [1.0, 2.0)) + HVX_Vector mant_mask = Q6_V_vsplat_R(0x007FFFFF); + HVX_Vector exp_127 = Q6_V_vsplat_R(0x3F800000); + HVX_Vector m = Q6_V_vor_VV(Q6_V_vand_VV(x, mant_mask), exp_127); + + // y = m - 1.0f, y in [0, 1) + HVX_Vector y = hvx_vec_sub_f32_f32(m, hvx_vec_splat_f32(1.0f)); + + // Abramowitz & Stegun 4.1.44 polynomial approximation of ln(1+y) + HVX_Vector c; + HVX_Vector res; + + c = hvx_vec_splat_f32(-0.0064535442f); + res = hvx_vec_mul_f32_f32(y, c); + + c = hvx_vec_splat_f32(0.0360884937f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.0953293897f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.1676540711f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.2407338084f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.3317990258f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.4998741238f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.9999964239f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + // ln(x) = e * ln(2) + ln(1+y) + HVX_Vector ln2 = hvx_vec_splat_f32(0.69314718056f); + HVX_Vector term_e = hvx_vec_mul_f32_f32(e_float, ln2); + + return hvx_vec_add_f32_f32(term_e, res); +} + +#endif /* HVX_LOG_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-pow.h b/ggml/src/ggml-hexagon/htp/hvx-pow.h new file mode 100644 index 00000000000..48fe0e8eade --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-pow.h @@ -0,0 +1,42 @@ +#ifndef HVX_POW_H +#define HVX_POW_H + +#include <math.h> +#include "hvx-base.h" +#include "hvx-exp.h" +#include "hvx-log.h" + +// Approximates base^exponent element-wise for float vectors. +// base must be a positive constant. exponent is an HVX f32 vector. +// Uses base^x = exp(x * ln(base)). +static inline HVX_Vector hvx_vec_pow_const_base_f32(float base, HVX_Vector exponent) { + float ln_base = logf(base); + HVX_Vector ln_base_v = hvx_vec_splat_f32(ln_base); + HVX_Vector x = hvx_vec_mul_f32_f32(exponent, ln_base_v); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.7228f; + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + return hvx_vec_exp_f32_guard(x, max_exp, inf); +} + +// Approximates base^exponent element-wise for float vectors. +// base and exponent are HVX f32 vectors. base elements must be positive. +// Uses base^exponent = exp(exponent * ln(base)). +static inline HVX_Vector hvx_vec_pow_f32(HVX_Vector base, HVX_Vector exponent) { + HVX_Vector ln_base = hvx_vec_log_f32(base); + HVX_Vector x = hvx_vec_mul_f32_f32(exponent, ln_base); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.7228f; + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + return hvx_vec_exp_f32_guard(x, max_exp, inf); +} + +#endif /* HVX_POW_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 0a760cd344c..23373f73ae2 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -17,5 +17,7 @@ #include "hvx-floor.h" #include "hvx-sin-cos.h" #include "hvx-base.h" +#include "hvx-pow.h" +#include "hvx-log.h" #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 623008be4e2..3715227d2c7 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -12,6 +12,7 @@ #include <HAP_mem.h> #include <HAP_power.h> #include <HAP_ps.h> +#include <HAP_dcvs.h> #include <qurt.h> #include <qurt_thread.h> #include <qurt_memory.h> @@ -63,8 +64,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { request.type = HAP_power_set_DCVS_v3; request.dcvs_v3.set_dcvs_enable = TRUE; - request.dcvs_v3.dcvs_enable = TRUE; - request.dcvs_v3.dcvs_option = HAP_DCVS_V2_PERFORMANCE_MODE; + request.dcvs_v3.dcvs_enable = FALSE; request.dcvs_v3.set_bus_params = TRUE; request.dcvs_v3.bus_params.min_corner = HAP_DCVS_VCORNER_MAX; request.dcvs_v3.bus_params.max_corner = HAP_DCVS_VCORNER_MAX; @@ -75,6 +75,10 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { request.dcvs_v3.core_params.target_corner = HAP_DCVS_VCORNER_MAX; request.dcvs_v3.set_sleep_disable = TRUE; request.dcvs_v3.sleep_disable = TRUE; + +#if (__HEXAGON_ARCH__ >= 79) + HAP_set_dcvs_v3_protected_bus_corners(&request, 1); +#endif if ((err = HAP_power_set((void *) ctx, &request)) != 0) { return err; } @@ -103,7 +107,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { FARF(ALWAYS, "Setting HMX clock\n"); err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error setting HMX clock."); + FARF(ERROR, "ggml-hex: error setting HMX clock."); return err; } } @@ -117,7 +121,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { FARF(ALWAYS, "Powering HMX on\n"); err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error powering on HMX."); + FARF(ERROR, "ggml-hex: error powering on HMX."); return err; } } @@ -423,10 +427,18 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->dma[i] = dma_queue_create(256); // queue depth } + ctx->ddr_spad_size = 512 * 1024; // 512 KB + ctx->ddr_spad_base = memalign(128, ctx->ddr_spad_size); + // init worker pool err = worker_pool_init(&ctx->worker_pool, n_hvx); if (err != AEE_SUCCESS) { FARF(ERROR, "Unable to create worker pool"); + if (ctx->ddr_spad_base) { + free(ctx->ddr_spad_base); + ctx->ddr_spad_base = NULL; + ctx->ddr_spad_size = 0; + } return err; } @@ -474,6 +486,12 @@ AEEResult htp_iface_stop(remote_handle64 handle) { vtcm_free(ctx); + if (ctx->ddr_spad_base) { + free(ctx->ddr_spad_base); + ctx->ddr_spad_base = NULL; + ctx->ddr_spad_size = 0; + } + return AEE_SUCCESS; } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 7036c491bc4..5121c6f9bad 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -53,6 +53,11 @@ struct htp_matmul_context { struct fastdiv_values mm_div_ne1; struct fastdiv_values mm_div_r2; struct fastdiv_values mm_div_r3; + + // Fields for scattered mapping & HMX support in MUL_MAT_ID + const uint32_t * matrix_row_counts; + const struct mmid_row_mapping * matrix_rows; + bool hmx_eligible; }; // vdelta control to expand first 32 e8m0 values into 32 uint32 elements @@ -2913,6 +2918,176 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } +#if __HVX_ARCH__ < 79 +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + +static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_Vector * restrict x = (const HVX_Vector *) vx; + const HVX_Vector * restrict y = (const HVX_Vector *) vy; + + uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors + uint32_t nloe = n % VLEN_FP32; // leftover elements + + HVX_Vector rsum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]); + rsum = HVX_OP_ADD_F32(rsum, prod); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]); + HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf); + rsum = HVX_OP_ADD_F32(rsum, prod); + } + + *s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum)); +} + +static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y = (const HVX_Vector *) vy0; + + uint32_t nvec = n / VLEN_FP32; + uint32_t nloe = n % VLEN_FP32; + + HVX_Vector rsum0 = Q6_V_vzero(); + HVX_Vector rsum1 = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector y_sf = y[i]; + HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf); + HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf); + rsum0 = HVX_OP_ADD_F32(rsum0, prod0); + rsum1 = HVX_OP_ADD_F32(rsum1, prod1); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]); + HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf); + HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf); + rsum0 = HVX_OP_ADD_F32(rsum0, prod0); + rsum1 = HVX_OP_ADD_F32(rsum1, prod1); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1); + HVX_VectorAlias va; + va.v = rsum; + s0[0] = va.fp32[0]; + s0[1] = va.fp32[1]; +} + +static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0; + const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1; + + uint32_t nvec = n / VLEN_FP32; + uint32_t nloe = n % VLEN_FP32; + + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector r0_sf = x0[i]; + HVX_Vector r1_sf = x1[i]; + HVX_Vector c0_sf = y0[i]; + HVX_Vector c1_sf = y1[i]; + + r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf)); + r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf)); + r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf)); + r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf)); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + + HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]); + HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]); + + r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf)); + r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf)); + r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf)); + r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + HVX_VectorAlias va0, va1; + va0.v = r0_r1_c0_sum; + va1.v = r0_r1_c1_sum; + s0[0] = va0.fp32[0]; + s0[1] = va0.fp32[1]; + s1[0] = va1.fp32[0]; + s1[1] = va1.fp32[1]; +} + +static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; + const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; + + uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors + uint32_t nloe = n % VLEN_FP32; // leftover elements + + HVX_Vector rsum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector x_sf = vx[i]; + HVX_Vector y_sf = vy[i]; + + rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf)); + } + + if (nloe) { + HVX_Vector x_sf = vx[i]; + HVX_Vector y_sf = vy[i]; + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + x_sf = Q6_V_vand_QV(bmask, x_sf); + y_sf = Q6_V_vand_QV(bmask, y_sf); + + rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf)); + } + + rsum = hvx_vec_reduce_sum_f32(rsum); + hvx_vec_store_u(&s[0], 4, rsum); +} + static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -3331,7 +3506,7 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const int is0 = (ir0 - src0_start_row); + const int is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -3466,7 +3641,7 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { const uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -3516,11 +3691,8 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { const uint32_t n_ids = ids->ne[0]; // n_expert_used const uint32_t n_as = ne02; // n_expert - const size_t matrix_row_counts_size = n_as * sizeof(uint32_t); - const size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); - - const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0; - const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size; + const uint32_t * matrix_row_counts = mmctx->matrix_row_counts; + const struct mmid_row_mapping * matrix_rows = mmctx->matrix_rows; const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; @@ -3542,6 +3714,10 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { continue; } + if (mmctx->hmx_eligible) { + continue; + } + const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0); // Prefill spad with src0 rows @@ -3583,7 +3759,7 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -3685,7 +3861,7 @@ static void matvec_id(unsigned int nth, unsigned int ith, void * data) { // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -4086,6 +4262,47 @@ static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * dat ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + + for (uint32_t i = ir_first; i < ir_last; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { struct htp_matmul_context * mmctx = data; struct htp_ops_context * octx = mmctx->octx; @@ -4328,6 +4545,60 @@ static int op_matmul_hvx(struct htp_ops_context * octx) { mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + need_quant = false; + } + } else if (src0->type == HTP_TYPE_F32) { + // Try optimized f32-f32 path first (src1 in VTCM) + const size_t f32_src1_row_size = hex_round_up(ne10 * 4, 128); + const size_t f32_src1_spad_size = hex_round_up(f32_src1_row_size * src1_nrows, 256); + const size_t f32_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f32_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; + + const size_t f32_total_size = f32_src1_spad_size + f32_src0_spad_size + f32_dst_spad_size; + + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); + + if (!is_batched && !is_permuted && f32_total_size <= octx->ctx->vtcm_size) { + // Optimized path + quant_job_func = quantize_f32_f32; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2; + + src1_row_size = f32_src1_row_size; + + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + } else { + // Fallback to DDR / broadcasting + quant_job_func = NULL; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1; + matmul_job_func = matmul_4d; + + src1_row_size = nb11; + + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + // Init fastdiv for matmul_4d (supports broadcasting) + mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + need_quant = false; } } else { @@ -4405,20 +4676,20 @@ int op_matmul(struct htp_ops_context * octx) { return op_matmul_hvx(octx); } - // HMX supports F16, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. + // HMX supports F16, F32, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. // Other types fall back to HVX. uint32_t wtype = src0->type; - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { return op_matmul_hvx(octx); } // Quantised HMX path requires K aligned to 256 (x4x2 super-block). - // F16 HMX path requires K aligned to 32 (tile width). - if (wtype != HTP_TYPE_F16 && src0->ne[0] % 256 != 0) { + // F16 and F32 HMX paths require K aligned to 32 (tile width). + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && src0->ne[0] % 256 != 0) { return op_matmul_hvx(octx); } - if (wtype == HTP_TYPE_F16 && src0->ne[0] % 32 != 0) { + if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && src0->ne[0] % 32 != 0) { return op_matmul_hvx(octx); } @@ -4463,8 +4734,8 @@ int op_matmul(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - if (src0->type == HTP_TYPE_F16) { - if (is_batched) { + if (is_batched) { + if (src0->type == HTP_TYPE_F16) { hmx_matmul_f16_f32_batched_params_t batch_params = { .dst = (float *) dst->data, .activation = (float *) src1->data, @@ -4488,13 +4759,11 @@ int op_matmul(struct htp_ops_context * octx) { }; ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params); } else { - ret = hmx_matmul_f16_f32(octx->ctx, - (float*) dst->data, (float*) src1->data, (const __fp16 *) src0->data, - m_total, k, n, act_stride, wgt_stride); + return op_matmul_hvx(octx); } } else { - ret = hmx_matmul_q_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, - m_total, k, n, (int) src0->type); + ret = hmx_matmul_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type); } if (ret != 0) { @@ -4539,8 +4808,30 @@ int op_matmul_id(struct htp_ops_context * octx) { size_t matrix_row_counts_size = n_as * sizeof(uint32_t); size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); + const size_t total_map_size = matrix_row_counts_size + matrix_row_map_size; + + void * mapping_buf = NULL; + bool must_free_mapping = false; + + if (octx->ctx->ddr_spad_base && total_map_size <= octx->ctx->ddr_spad_size) { + mapping_buf = octx->ctx->ddr_spad_base; + } else { + mapping_buf = memalign(128, total_map_size); + if (mapping_buf) { + must_free_mapping = true; + } else { + return HTP_STATUS_INTERNAL_ERR; + } + } + + uint32_t * matrix_row_counts = (uint32_t *) mapping_buf; + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) ((uint8_t *) mapping_buf + matrix_row_counts_size); + + mmctx->matrix_row_counts = matrix_row_counts; + mmctx->matrix_rows = matrix_rows; if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_NO_SUPPORT; } @@ -4552,7 +4843,7 @@ int op_matmul_id(struct htp_ops_context * octx) { src1_row_size = q8x4x2_row_size(ne10); } - const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + const size_t src2_spad_size_per_thread = 0; // We moved the mapping to DDR! htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; @@ -4568,6 +4859,7 @@ int op_matmul_id(struct htp_ops_context * octx) { // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -4587,9 +4879,6 @@ int op_matmul_id(struct htp_ops_context * octx) { if (src1_nrows > 1) { // initialize matrix_row_counts and map - uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0; - struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size; - memset(matrix_row_counts, 0, n_as * sizeof(uint32_t)); // group rows by src0 matrix @@ -4599,14 +4888,60 @@ int op_matmul_id(struct htp_ops_context * octx) { assert(i02 >= 0 && i02 < n_as); - MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 }; + matrix_rows[i02 * n_ids * ids->ne[1] + matrix_row_counts[i02]] = (struct mmid_row_mapping) { id, iid1 }; matrix_row_counts[i02] += 1; } } } - if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_OK; + } + + bool hmx_eligible = false; +#ifdef HTP_HAS_HMX + if (octx->ctx->hmx_enabled && src1_nrows > 1) { + uint32_t wtype = src0->type; + if (ne01 % 32 == 0 && + (wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32 || wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || wtype == HTP_TYPE_MXFP4)) { + if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && ne00 % 32 == 0) { + hmx_eligible = true; + } else if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && ne00 % 256 == 0) { + hmx_eligible = true; + } + } + } +#endif + + mmctx->hmx_eligible = hmx_eligible; + + if (hmx_eligible) { + for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int32_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) continue; + + int ret = hmx_matmul_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, + (const uint8_t *) src0->data + cur_a * nb02, + cne1, ne00, ne01, + ne11, + nb11, nb12, + nb1, nb2, + (int) src0->nb[1], (int) src0->type, + matrix_rows, cur_a, n_ids * ids->ne[1]); + if (ret != 0) { + FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret); + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_NO_SUPPORT; + } + } + + // HMX has overwritten VTCM, so force dynamic quantization cache to clear + octx->src1_spad.src = NULL; + + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_OK; + } if (octx->src1_spad.src != src1) { const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); @@ -4618,5 +4953,6 @@ int op_matmul_id(struct htp_ops_context * octx) { const uint32_t n_matmul_jobs = octx->n_threads; worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/pad-ops.c b/ggml/src/ggml-hexagon/htp/pad-ops.c index 3abc3c2ead1..aaa72b31590 100644 --- a/ggml/src/ggml-hexagon/htp/pad-ops.c +++ b/ggml/src/ggml-hexagon/htp/pad-ops.c @@ -511,6 +511,8 @@ int op_pad(struct htp_ops_context * octx) { octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; octx->src0_spad.data = octx->ctx->vtcm_base; octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->dst_spad.src = NULL; } struct htp_pad_context pctx = { diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 770a6673211..71fab2cdbcb 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -692,6 +692,11 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * const uint8_t * restrict data_src1 = uctx->data_src1; uint8_t * restrict data_dst = uctx->data_dst; + const struct htp_tensor * src1 = (htp_op == HTP_OP_RMS_NORM_MUL) ? octx->src[1] : NULL; + const uint32_t nb11 = src1 ? src1->nb[1] : 0; + const uint32_t nb12 = src1 ? src1->nb[2] : 0; + const uint32_t nb13 = src1 ? src1->nb[3] : 0; + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); @@ -738,10 +743,10 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * src0_row_size_aligned, nb01, src0_data_row_size, block_size); if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { - const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03); + const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb11, nb12, nb13); dma_queue_push(dma_queue, dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + src1_off), - uctx->src1_row_size_aligned, nb01, uctx->src1_data_row_size, block_size); + uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, block_size); } ir += block_size; @@ -823,10 +828,10 @@ static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size); if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { - const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03); + const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb11, nb12, nb13); dma_queue_push(dma_queue, dma_make_ptr(src1_spad, data_src1 + src1_pref_off), - uctx->src1_row_size_aligned, nb01, uctx->src1_data_row_size, pref_block_size); + uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, pref_block_size); } } } @@ -977,6 +982,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; } + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + octx->dst_spad.src = NULL; + FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); From d31cb20b258f335df1a658d0dc2d9a271110d14a Mon Sep 17 00:00:00 2001 From: Max Krasnyansky <maxk@qti.qualcomm.com> Date: Tue, 2 Jun 2026 14:08:29 -0700 Subject: [PATCH 778/831] hexagon: profiler output fix and script updates (llama/24042) * hex-ops: fix profiler output (ie remove the redundant NONEs) * hex-prof: update profiling script to support tot.usec column --- ggml/src/ggml-hexagon/htp-opnode.h | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/ggml/src/ggml-hexagon/htp-opnode.h b/ggml/src/ggml-hexagon/htp-opnode.h index 8a1228ccdc0..52c727c6206 100644 --- a/ggml/src/ggml-hexagon/htp-opnode.h +++ b/ggml/src/ggml-hexagon/htp-opnode.h @@ -56,6 +56,20 @@ struct htp_opnode { } std::vector<const ggml_tensor *> get_inputs() const { + if (fused.empty()) { + int last_non_null = -1; + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node->src[i]) { + last_non_null = i; + } + } + std::vector<const ggml_tensor *> inputs(last_non_null + 1, nullptr); + for (int i = 0; i <= last_non_null; i++) { + inputs[i] = node->src[i]; + } + return inputs; + } + std::vector<const ggml_tensor *> inputs(GGML_MAX_SRC, nullptr); std::vector<const ggml_tensor *> outputs; outputs.push_back(node); @@ -82,12 +96,8 @@ struct htp_opnode { }; for (int i = 0; i < GGML_MAX_SRC; i++) { - if (fused.empty()) { - inputs[i] = node->src[i]; - } else { - if (node->src[i]) { - add_input(node->src[i]); - } + if (node->src[i]) { + add_input(node->src[i]); } } for (const auto * f : fused) { @@ -98,10 +108,7 @@ struct htp_opnode { } } - if (!fused.empty()) { - inputs.resize(count); - } - + inputs.resize(count); return inputs; } From f110ff540c06ba74a9f26c9964eb21c1edf846b1 Mon Sep 17 00:00:00 2001 From: lhez <lih@qti.qualcomm.com> Date: Tue, 2 Jun 2026 14:16:17 -0700 Subject: [PATCH 779/831] opencl: use flat variants of q4_K and q6_K gemv for very large M (llama/24006) --- ggml/src/ggml-opencl/ggml-opencl.cpp | 31 +++++++++++++++++++++------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index b67ea46bce8..c411e4aeaec 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -4950,6 +4950,21 @@ inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backen return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 } +static inline bool use_flat_gemv_for_large_m_q4_K(const ggml_tensor *tensor) { + // gemv_noshuffle variant perf drops for large M, use flat variant for large M. + // threshold is well above typical hidden/FFN dims, but below typical vocab sizes. + // note that this forces large M weights to use LM GEMM. + return tensor->ne[1] >= 32768 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool use_flat_gemv_for_large_m_q6_K(const ggml_tensor *tensor) { + // gemv_noshuffle variant perf drops for large M, use flat variant for large M. + // threshold is well above typical hidden/FFN dims, but below typical vocab sizes. + // q6_K flat gemv is worse for smaller K; 2048 seems to be a reasonable threshold. + // note that this forces large M weights to use LM GEMM. + return tensor->ne[1] >= 32768 && tensor->ne[0] >= 2048 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context; ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; @@ -6595,7 +6610,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, #ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle; } #else @@ -6623,7 +6638,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -6923,7 +6938,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, cl_kernel kernel; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS kernel = backend_ctx->kernel_convert_block_q6_K; - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { kernel = backend_ctx->kernel_convert_block_q6_K_noshuffle; } #else @@ -6956,7 +6971,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { cl_int M = tensor->ne[1]; // ne01 cl_int K = tensor->ne[0]; // ne00 @@ -7599,7 +7614,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { int M = tensor->ne[1]; int K = tensor->ne[0]; @@ -7820,7 +7835,7 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } - if (use_adreno_kernels(backend_ctx, tensor)) { + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { static ggml_cl_buffer buf_trans_ql; static ggml_cl_buffer buf_trans_qh; static ggml_cl_buffer buf_trans_s; @@ -13213,13 +13228,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } // q4_k x fp32 - if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32) { + if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32 && !use_flat_gemv_for_large_m_q4_K(src0)) { ggml_cl_mul_mat_q4_k_f32_adreno(backend, src0, src1, dst); return; } // q6_K x fp32 - if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32) { + if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32 && !use_flat_gemv_for_large_m_q6_K(src0)) { ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst); return; } From d5a49ebec8445172aa73d36c3cefb0586e7ce1ab Mon Sep 17 00:00:00 2001 From: Aman Gupta <amangupta052@gmail.com> Date: Wed, 3 Jun 2026 18:39:59 +0800 Subject: [PATCH 780/831] cuda: reserve space for quantize kv-cache at startup (llama/23907) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cuda: reserve space for quantize kv-cache at startup * address review comments * remove forward decl Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * remove assert in ggml-cuda.cu Co-authored-by: Johannes Gäßler <johannesg@5d6.de> --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de> --- ggml/src/ggml-cuda/fattn-common.cuh | 65 ++++++++++++++++++++++++----- ggml/src/ggml-cuda/fattn.cu | 35 ++++++++++++++++ ggml/src/ggml-cuda/fattn.cuh | 2 + ggml/src/ggml-cuda/ggml-cuda.cu | 8 ++-- 4 files changed, 96 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index d650b5fbd0f..064f753f7ef 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -44,6 +44,46 @@ typedef void (* fattn_kernel_t)( typedef float (*vec_dot_KQ_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); +struct ggml_cuda_flash_attn_ext_f16_extra_data { + uintptr_t K; + uintptr_t V; + uintptr_t end; +}; + +static inline ggml_cuda_flash_attn_ext_f16_extra_data ggml_cuda_flash_attn_ext_get_f16_extra_data( + const ggml_tensor * dst, const bool need_f16_K, const bool need_f16_V) { + GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT); + + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + + const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); + + ggml_cuda_flash_attn_ext_f16_extra_data data = {}; + data.end = (uintptr_t) dst->data + ggml_nbytes(dst); + + if (need_f16_K && K->type != GGML_TYPE_F16) { + data.end = GGML_PAD(data.end, 128); + data.K = data.end; + data.end += ggml_nelements(K)*ggml_type_size(GGML_TYPE_F16); + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V_is_K_view) { + data.V = data.K; + } else { + data.end = GGML_PAD(data.end, 128); + data.V = data.end; + data.end += ggml_nelements(V)*ggml_type_size(GGML_TYPE_F16); + } + } + + return data; +} + template <int D, int nthreads> static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { @@ -952,8 +992,9 @@ void launch_fattn( const int cc = ggml_cuda_info().devices[id].cc; const int nsm = ggml_cuda_info().devices[id].nsm; - ggml_cuda_pool_alloc<half> K_f16(pool); - ggml_cuda_pool_alloc<half> V_f16(pool); + const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra = + ggml_cuda_flash_attn_ext_get_f16_extra_data(KQV, need_f16_K, need_f16_V); + ggml_cuda_pool_alloc<int> KV_max(pool); ggml_cuda_pool_alloc<float> dst_tmp(pool); ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool); @@ -972,10 +1013,11 @@ void launch_fattn( const size_t bs = ggml_blck_size(K->type); const size_t ts = ggml_type_size(K->type); - K_f16.alloc(ggml_nelements(K)); + GGML_ASSERT(f16_extra.K != 0); + half * K_f16 = (half *) f16_extra.K; if (ggml_is_contiguously_allocated(K)) { to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); - to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); + to_fp16(K_data, K_f16, ggml_nelements(K), main_stream); nb11 = nb11*bs*sizeof(half)/ts; nb12 = nb12*bs*sizeof(half)/ts; @@ -986,13 +1028,13 @@ void launch_fattn( const int64_t s01 = nb11 / ts; const int64_t s02 = nb12 / ts; const int64_t s03 = nb13 / ts; - to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); + to_fp16(K_data, K_f16, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); nb11 = K->ne[0] * sizeof(half); nb12 = K->ne[1] * nb11; nb13 = K->ne[2] * nb12; } - K_data = (char *) K_f16.ptr; + K_data = (char *) K_f16; } if (need_f16_V && V->type != GGML_TYPE_F16) { @@ -1005,11 +1047,12 @@ void launch_fattn( const size_t bs = ggml_blck_size(V->type); const size_t ts = ggml_type_size(V->type); - V_f16.alloc(ggml_nelements(V)); + GGML_ASSERT(f16_extra.V != 0); + half * V_f16 = (half *) f16_extra.V; if (ggml_is_contiguously_allocated(V)) { to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); - to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); - V_data = (char *) V_f16.ptr; + to_fp16(V_data, V_f16, ggml_nelements(V), main_stream); + V_data = (char *) V_f16; nb21 = nb21*bs*sizeof(half)/ts; nb22 = nb22*bs*sizeof(half)/ts; @@ -1020,13 +1063,13 @@ void launch_fattn( const int64_t s01 = nb21 / ts; const int64_t s02 = nb22 / ts; const int64_t s03 = nb23 / ts; - to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + to_fp16(V_data, V_f16, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); nb21 = V->ne[0] * sizeof(half); nb22 = V->ne[1] * nb21; nb23 = V->ne[2] * nb22; } - V_data = (char *) V_f16.ptr; + V_data = (char *) V_f16; } } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 1c7777e8a71..d6c501b1d7e 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -537,6 +537,41 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_TILE; } +size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst) { + GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT); + + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + + const best_fattn_kernel kernel = ggml_cuda_get_best_fattn_kernel(device, dst); + + bool need_f16_K = false; + bool need_f16_V = false; + + switch (kernel) { + case BEST_FATTN_KERNEL_TILE: + case BEST_FATTN_KERNEL_WMMA_F16: + case BEST_FATTN_KERNEL_MMA_F16: + need_f16_K = true; + need_f16_V = true; + break; + case BEST_FATTN_KERNEL_VEC: + need_f16_K = K->type == GGML_TYPE_F32; + need_f16_V = V->type == GGML_TYPE_F32; + break; + case BEST_FATTN_KERNEL_NONE: + break; + } + + const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra = + ggml_cuda_flash_attn_ext_get_f16_extra_data(dst, need_f16_K, need_f16_V); + + return f16_extra.end - (uintptr_t) dst->data; +} + void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_set_device(ctx.device); switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) { diff --git a/ggml/src/ggml-cuda/fattn.cuh b/ggml/src/ggml-cuda/fattn.cuh index 78705d59951..f9a7e15fbd6 100644 --- a/ggml/src/ggml-cuda/fattn.cuh +++ b/ggml/src/ggml-cuda/fattn.cuh @@ -3,3 +3,5 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst); + +size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 18aaa098398..f5293ad4cbb 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -801,7 +801,11 @@ static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_ty } static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { - size_t size = ggml_nbytes(tensor); + ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *) buft->context; + + size_t size = tensor->op == GGML_OP_FLASH_ATTN_EXT + ? ggml_cuda_flash_attn_ext_get_alloc_size(buft_ctx->device, tensor) + : ggml_nbytes(tensor); int64_t ne0 = tensor->ne[0]; if (ggml_is_quantized(tensor->type)) { @@ -812,8 +816,6 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t } return size; - - GGML_UNUSED(buft); } static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = { From 750fa4ca35c4b39a487724f854bfc788a2d9db53 Mon Sep 17 00:00:00 2001 From: Charles Xu <charles.xu@arm.com> Date: Wed, 3 Jun 2026 12:45:10 +0200 Subject: [PATCH 781/831] ggml-cpu: use runtime SVE width in FWHT (llama/24059) --- ggml/src/ggml-cpu/ops.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index dc73696ad9f..3a1912ae91b 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -8955,7 +8955,12 @@ static void ggml_compute_forward_flash_attn_ext_f16( k->type == v->type && neq1 >= Q_TILE_SZ); #ifdef GGML_SIMD - use_tiled &= (DV % GGML_F32_EPR == 0); +#if defined(__ARM_FEATURE_SVE) + const int64_t f32_epr = svcntw(); +#else + const int64_t f32_epr = GGML_F32_EPR; +#endif + use_tiled &= (DV % f32_epr == 0); #endif int current_chunk = ith; @@ -11358,7 +11363,11 @@ static void ggml_compute_forward_fwht_f32(const ggml_compute_params * params, gg // Scalar passes #if defined(GGML_SIMD) +#if defined(__ARM_FEATURE_SVE) + const int step = svcntw(); +#else const int step = GGML_F32_EPR; +#endif #else const int step = n; #endif From 00a9728de303c399f069dd1b0b7a33689ab3e56a Mon Sep 17 00:00:00 2001 From: Andreas Kieslinger <47689530+aendk@users.noreply.github.com> Date: Wed, 3 Jun 2026 13:56:42 +0200 Subject: [PATCH 782/831] Avoid PDL race conditions by disabling __restrict__ when PDL is used (llama/24030) * Removes __restrict__ from PDL kernel headers due to incompatibility with PDL. Adds preprocessor directives based on arch in kernel body to add __restrict__ to retain performance on older architectures. * Simplifies new __restrict__ usage via macro * Add hopper to PDL __restrict__ fix. Co-authored-by: Oliver Simons <osimons@nvidia.com> --------- Co-authored-by: Oliver Simons <osimons@nvidia.com> --- ggml/src/ggml-cuda/common.cuh | 6 ++++++ ggml/src/ggml-cuda/fattn-common.cuh | 25 ++++++++++++++++--------- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 26 +++++++++++++++++--------- ggml/src/ggml-cuda/fattn-tile.cuh | 26 +++++++++++++++++--------- ggml/src/ggml-cuda/fattn-vec.cuh | 26 +++++++++++++++++--------- ggml/src/ggml-cuda/fattn-wmma-f16.cu | 26 +++++++++++++++++--------- ggml/src/ggml-cuda/getrows.cu | 5 ++++- ggml/src/ggml-cuda/mmvf.cu | 6 +++++- ggml/src/ggml-cuda/mmvq.cu | 6 +++++- ggml/src/ggml-cuda/quantize.cu | 4 +++- ggml/src/ggml-cuda/reduce_rows.cuh | 4 +++- ggml/src/ggml-cuda/set-rows.cu | 9 ++++++--- ggml/src/ggml-cuda/ssm-conv.cu | 10 +++++++--- ggml/src/ggml-cuda/ssm-scan.cu | 28 ++++++++++++++++++++++------ 14 files changed, 145 insertions(+), 62 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 560fab0b17b..e6e50e04119 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1611,6 +1611,12 @@ static bool ggml_cuda_kernel_can_use_pdl(const void * kernel) { #endif //defined(GGML_CUDA_USE_PDL) +// PDL and __restrict__ need to be mutually exclusive, see https://github.com/ggml-org/llama.cpp/pull/24030 +# if (defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER) +# define GGML_CUDA_RESTRICT +# else +# define GGML_CUDA_RESTRICT __restrict__ +# endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER template<typename Kernel, typename... Args> static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) { diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 064f753f7ef..8dfa51ad1e8 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -718,8 +718,8 @@ static __global__ void flash_attn_mask_to_KV_max( template<int D, int ncols1, int ncols2> // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup_uniform( - float * __restrict__ dst, - const float2 * __restrict__ dst_fixup, + float * dst_ptr, + const float2 * dst_fixup_ptr, const int ne01, const int ne02, const int ne12, const int nblocks_stream_k, const int gqa_ratio, @@ -729,6 +729,8 @@ static __global__ void flash_attn_stream_k_fixup_uniform( const uint3 fd_iter_j) { constexpr int ncols = ncols1*ncols2; ggml_cuda_pdl_lc(); + float * GGML_CUDA_RESTRICT dst = dst_ptr; + const float2 * GGML_CUDA_RESTRICT dst_fixup = dst_fixup_ptr; const int tile_idx = blockIdx.x; // One block per output tile. const int j = blockIdx.y; @@ -800,8 +802,8 @@ static __global__ void flash_attn_stream_k_fixup_uniform( template <int D, int ncols1, int ncols2> // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_stream_k_fixup_general( - float * __restrict__ dst, - const float2 * __restrict__ dst_fixup, + float * dst_ptr, + const float2 * dst_fixup_ptr, const int ne01, const int ne02, const int gqa_ratio, const int total_work, @@ -809,6 +811,8 @@ static __global__ void flash_attn_stream_k_fixup_general( const uint3 fd_iter_k_j_z, const uint3 fd_iter_k_j, const uint3 fd_iter_k) { + float * GGML_CUDA_RESTRICT dst = dst_ptr; + const float2 * GGML_CUDA_RESTRICT dst_fixup = dst_fixup_ptr; constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -907,11 +911,14 @@ static __global__ void flash_attn_stream_k_fixup_general( template<int D> // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_combine_results( - const float * __restrict__ VKQ_parts, - const float2 * __restrict__ VKQ_meta, - float * __restrict__ dst, + const float * VKQ_parts_ptr, + const float2 * VKQ_meta_ptr, + float * dst_ptr, const int parallel_blocks) { ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT VKQ_parts = VKQ_parts_ptr; + const float2 * GGML_CUDA_RESTRICT VKQ_meta = VKQ_meta_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; // Dimension 0: threadIdx.x // Dimension 1: blockIdx.x // Dimension 2: blockIdx.y @@ -1196,8 +1203,8 @@ void launch_fattn( GGML_ASSERT(block_dim.x % warp_size == 0); - // disabled PDL enrollment for now due to a compiler bug. - fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>( + ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream); + ggml_cuda_kernel_launch(fattn_kernel, launch_params, (const char *) Q->data, K_data, V_data, diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index ac5abb13367..83478a02cb6 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1703,14 +1703,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view> __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -1726,6 +1726,14 @@ static __global__ void flash_attn_ext_f16( const int32_t nb31, const int32_t nb32, const int64_t nb33) { ggml_cuda_pdl_sync(); // TODO optimize placement #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) { @@ -1871,7 +1879,7 @@ static __global__ void flash_attn_ext_f16( (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index fac76f13593..0a099810e14 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -788,14 +788,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter( template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // D == head size __launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_tile( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -810,6 +810,14 @@ static __global__ void flash_attn_tile( const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { #ifdef FLASH_ATTN_AVAILABLE + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: @@ -1126,7 +1134,7 @@ static __global__ void flash_attn_tile( } } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index b0a6cf67f1a..69dd9368624 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -19,14 +19,14 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size __launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1) static __global__ void flash_attn_ext_vec( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -42,6 +42,14 @@ static __global__ void flash_attn_ext_vec( const int32_t nb31, const int32_t nb32, const int64_t nb33) { ggml_cuda_pdl_lc(); #ifdef FLASH_ATTN_AVAILABLE + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { @@ -506,7 +514,7 @@ static __global__ void flash_attn_ext_vec( dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 4b6f6501094..6850716fc0d 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -24,14 +24,14 @@ namespace wmma = rocwmma; template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap> __launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1) static __global__ void flash_attn_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -46,6 +46,14 @@ static __global__ void flash_attn_ext_f16( const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { #if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -494,7 +502,7 @@ static __global__ void flash_attn_ext_f16( dst_meta[j_dst_unrolled] = dst_meta_val; } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 457b695eb2a..eb157b8baf2 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -42,7 +42,7 @@ static __global__ void k_get_rows( template<typename src0_t, typename dst_t> static __global__ void k_get_rows_float( - const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const src0_t * src0_ptr, const int32_t * src1_ptr, dst_t * dst_ptr, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ /*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, @@ -50,6 +50,9 @@ static __global__ void k_get_rows_float( const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { ggml_cuda_pdl_lc(); + const src0_t * GGML_CUDA_RESTRICT src0 = src0_ptr; + const int32_t * GGML_CUDA_RESTRICT src1 = src1_ptr; + dst_t * GGML_CUDA_RESTRICT dst = dst_ptr; ggml_cuda_pdl_sync(); for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 3d6de64b775..d7dbc8b9928 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -6,11 +6,15 @@ template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false> static __global__ void mul_mat_vec_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, + const T * x_ptr, const float * y_ptr, const int32_t * ids_ptr, const ggml_cuda_mm_fusion_args_device fusion, float * dst_ptr, const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const int ids_stride) { + const T * GGML_CUDA_RESTRICT x = x_ptr; + const float * GGML_CUDA_RESTRICT y = y_ptr; + const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int row = blockIdx.x; // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens) const int channel_dst = blockIdx.y; diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 86b4a493019..4b0426590ac 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -476,12 +476,16 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int template <ggml_type type, int ncols_dst, bool has_fusion, bool small_k = false> __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, + const void * vx_ptr, const void * vy_ptr, const int32_t * ids_ptr, const ggml_cuda_mm_fusion_args_device fusion, float * dst_ptr, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, const uint32_t ids_stride) { + const void * GGML_CUDA_RESTRICT vx = vx_ptr; + const void * GGML_CUDA_RESTRICT vy = vy_ptr; + const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; constexpr int qk = ggml_cuda_type_traits<type>::qk; constexpr int qi = ggml_cuda_type_traits<type>::qi; diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index 49516965cad..39a500a1704 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -3,10 +3,12 @@ __launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1) static __global__ void quantize_q8_1( - const float * __restrict__ x, void * __restrict__ vy, + const float * x_ptr, void * vy_ptr, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, const int64_t ne0, const uint32_t ne1, const uint3 ne2) { ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT x = x_ptr; + void * GGML_CUDA_RESTRICT vy = vy_ptr; const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= ne0) { diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh index 5895d3bf8e5..968c47aa20a 100644 --- a/ggml/src/ggml-cuda/reduce_rows.cuh +++ b/ggml/src/ggml-cuda/reduce_rows.cuh @@ -2,7 +2,9 @@ // Row reduction kernel template - compute sum (norm=false) or mean (norm=true) template <bool norm> -static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) { +static __global__ void reduce_rows_f32(const float * x_ptr, float * dst_ptr, const int ncols) { + const float * GGML_CUDA_RESTRICT x = x_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int row = blockIdx.x; const int col = threadIdx.x; diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index e14f96b824c..3b4f004c946 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -111,9 +111,9 @@ static void set_rows_cuda_quant( } template <typename src_t, typename idx_t, typename dst_t> -static __global__ void k_set_rows(const src_t * __restrict__ src0, - const idx_t * __restrict__ src1, - dst_t * __restrict__ dst, +static __global__ void k_set_rows(const src_t * src0_ptr, + const idx_t * src1_ptr, + dst_t * dst_ptr, const int64_t ne_total, const int64_t ne10, const int64_t ne11, @@ -133,6 +133,9 @@ static __global__ void k_set_rows(const src_t * __restrict__ src0, const uint3 ne02, const uint3 ne11_fd, const uint3 ne12_fd) { + const src_t * GGML_CUDA_RESTRICT src0 = src0_ptr; + const idx_t * GGML_CUDA_RESTRICT src1 = src1_ptr; + dst_t * GGML_CUDA_RESTRICT dst = dst_ptr; const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; if (i >= ne_total) { diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 48787b4b890..1463169cf78 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -3,12 +3,16 @@ #include "unary.cuh" template <bool apply_silu, size_t split_d_inner, size_t d_conv> -static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, - const float * __restrict__ bias, +static __global__ void ssm_conv_f32(const float * src0_ptr, const float * src1_ptr, + const float * bias_ptr, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, - float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, + float * dst_ptr, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT bias = bias_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; GGML_UNUSED(src0_nb0); const int tid = threadIdx.x; const int bidx = blockIdx.x; diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 412980376ac..2e3f97c7284 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -17,14 +17,22 @@ using namespace cub; #endif // __clang__ template <size_t splitD, size_t N, size_t L_template> __global__ void __launch_bounds__(splitD, 1) - ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2, - const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5, - const int32_t * __restrict__ src6, float * __restrict__ dst, + ssm_scan_f32(const float * src0_ptr, const float * src1_ptr, const float * src2_ptr, + const float * src3_ptr, const float * src4_ptr, const float * src5_ptr, + const int32_t * src6_ptr, float * dst_ptr, const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t d_inner, const int64_t L_param) { + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT src2 = src2_ptr; + const float * GGML_CUDA_RESTRICT src3 = src3_ptr; + const float * GGML_CUDA_RESTRICT src4 = src4_ptr; + const float * GGML_CUDA_RESTRICT src5 = src5_ptr; + const int32_t * GGML_CUDA_RESTRICT src6 = src6_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const size_t L = L_template == 0 ? L_param : L_template; ggml_cuda_pdl_sync(); const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2); @@ -118,13 +126,21 @@ __global__ void __launch_bounds__(splitD, 1) template <int c_factor, int d_state> __global__ void __launch_bounds__(d_state, 1) ssm_scan_f32_group( - const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, - const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, - const int32_t * __restrict__ src6, float * __restrict__ dst, + const float * src0_ptr, const float * src1_ptr, const float * src2_ptr, + const float * src3_ptr, const float * src4_ptr, const float * src5_ptr, + const int32_t * src6_ptr, float * dst_ptr, const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT src2 = src2_ptr; + const float * GGML_CUDA_RESTRICT src3 = src3_ptr; + const float * GGML_CUDA_RESTRICT src4 = src4_ptr; + const float * GGML_CUDA_RESTRICT src5 = src5_ptr; + const int32_t * GGML_CUDA_RESTRICT src6 = src6_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int warp = threadIdx.x / WARP_SIZE; const int lane = threadIdx.x % WARP_SIZE; From a1a31868870f0900940d27b5c4d426a9938731d4 Mon Sep 17 00:00:00 2001 From: rehan-10xengineer <rehanbackup0317@gmail.com> Date: Thu, 4 Jun 2026 10:03:40 +0500 Subject: [PATCH 783/831] ggml-cpu: extend RVV quantization vec dot to higher VLENs (llama/22754) * ggml-cpu: add rvv 512b,1024b impls for iq4_xs * ggml-cpu: refactor; add rvv 512b, 1024b impls for q6_K, i-quants * ggml-cpu: refactor; add 512 and 1024 implementations of tq3_s, iq3_xxs, iq2_s, iq2_xs, iq2_xxs improve iq2_xs impl for rvv 256 Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai> --------- Co-authored-by: taimur-10x <taimur.ahmad@10xengineers.ai> Co-authored-by: Rehan Qasim <rehan.qasim@10xengineers.ai> --- ggml/src/ggml-cpu/arch/riscv/quants.c | 3895 +++++++++++++++++++------ 1 file changed, 2969 insertions(+), 926 deletions(-) diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index ee69e5ab5e5..47e9180bf9b 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -123,7 +123,7 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in assert(k % QK_K == 0); size_t nb = k / QK_K; -#if defined __riscv_v_intrinsic +#if defined __riscv_v block_q8_K * y_blocks = (block_q8_K *)y; const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8(); @@ -578,7 +578,8 @@ void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +void ggml_vec_dot_q2_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -590,8 +591,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - float sumf = 0; uint8_t atmp[16]; @@ -686,246 +685,281 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +void ggml_vec_dot_q2_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; float sumf = 0; uint8_t atmp[16]; - const int vector_length = __riscv_vlenb() * 8; uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + uint8_t *patmp = atmp; + int vsums; + int tmp, t1, t2, t3, t4, t5, t6, t7; + __asm__ __volatile__( + "vsetivli zero, 16, e8, m1\n\t" + "vmv.v.x v8, zero\n\t" + "lb zero, 15(%[sc])\n\t" + "vle8.v v1, (%[sc])\n\t" + "vle8.v v2, (%[bsums])\n\t" + "addi %[tmp], %[bsums], 16\n\t" + "vand.vi v0, v1, 0xF\n\t" + "vsrl.vi v1, v1, 4\n\t" + "vle8.v v3, (%[tmp])\n\t" + "vse8.v v0, (%[scale])\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vzext.vf2 v0, v1\n\t" + "vwmul.vv v4, v0, v2\n\t" + "vsetivli zero, 16, e32, m4\n\t" + "vredsum.vs v8, v4, v8\n\t" + "vmv.x.s %[vsums], v8" + : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) + : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf += dmin * vsums; + int isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "lb zero, 31(%[q2])\n\t" + "addi %[tmp], %[q2], 16\n\t" + "addi %[t1], %[q8], 16\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vle8.v v0, (%[q2])\n\t" + "vle8.v v1, (%[tmp])\n\t" + "vsrl.vi v2, v0, 2\n\t" + "vsrl.vi v3, v1, 2\n\t" + "vsrl.vi v4, v0, 4\n\t" + "addi %[tmp], %[q8], 32\n\t" + "vle8.v v8, (%[q8])\n\t" + "vle8.v v9, (%[t1])\n\t" + "addi %[t1], %[t1], 32\n\t" + "vsrl.vi v5, v1, 4\n\t" + "vsrl.vi v6, v0, 6\n\t" + "vsrl.vi v7, v1, 6\n\t" + "vle8.v v10, (%[tmp])\n\t" + "vle8.v v11, (%[t1])\n\t" + "addi %[tmp], %[tmp], 32\n\t" + "addi %[t1], %[t1], 32\n\t" + "vand.vi v0, v0, 0x3\n\t" + "vand.vi v1, v1, 0x3\n\t" + "vand.vi v2, v2, 0x3\n\t" + "vle8.v v12, (%[tmp])\n\t" + "vle8.v v13, (%[t1])\n\t" + "addi %[tmp], %[tmp], 32\n\t" + "addi %[t1], %[t1], 32\n\t" + "vand.vi v3, v3, 0x3\n\t" + "vand.vi v4, v4, 0x3\n\t" + "vand.vi v5, v5, 0x3\n\t" + "vle8.v v14, (%[tmp])\n\t" + "vle8.v v15, (%[t1])\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v18, v1, v9\n\t" + "vwmul.vv v20, v2, v10\n\t" + "vwmul.vv v22, v3, v11\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vwmul.vv v26, v5, v13\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmul.vv v30, v7, v15\n\t" + "vsetivli zero, 8, e16, m1\n\t" + "vmv.v.x v0, zero\n\t" + "lbu %[tmp], 0(%[scale])\n\t" + "vwredsum.vs v8, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "lbu %[t1], 1(%[scale])\n\t" + "vwredsum.vs v10, v20, v0\n\t" + "vwredsum.vs v11, v22, v0\n\t" + "lbu %[t2], 2(%[scale])\n\t" + "vwredsum.vs v12, v24, v0\n\t" + "vwredsum.vs v13, v26, v0\n\t" + "lbu %[t3], 3(%[scale])\n\t" + "vwredsum.vs v14, v28, v0\n\t" + "vwredsum.vs v15, v30, v0\n\t" + "lbu %[t4], 4(%[scale])\n\t" + "vwredsum.vs v8, v17, v8\n\t" + "vwredsum.vs v9, v19, v9\n\t" + "lbu %[t5], 5(%[scale])\n\t" + "vwredsum.vs v10, v21, v10\n\t" + "vwredsum.vs v11, v23, v11\n\t" + "lbu %[t6], 6(%[scale])\n\t" + "vwredsum.vs v12, v25, v12\n\t" + "vwredsum.vs v13, v27, v13\n\t" + "lbu %[t7], 7(%[scale])\n\t" + "vwredsum.vs v14, v29, v14\n\t" + "vwredsum.vs v15, v31, v15\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v8, %[tmp]\n\t" + "vmul.vx v1, v9, %[t1]\n\t" + "vmacc.vx v0, %[t2], v10\n\t" + "vmacc.vx v1, %[t3], v11\n\t" + "vmacc.vx v0, %[t4], v12\n\t" + "vmacc.vx v1, %[t5], v13\n\t" + "vmacc.vx v0, %[t6], v14\n\t" + "vmacc.vx v1, %[t7], v15\n\t" + "vmv.x.s %[tmp], v0\n\t" + "vmv.x.s %[t1], v1\n\t" + "add %[isum], %[isum], %[tmp]\n\t" + "add %[isum], %[isum], %[t1]" + : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) + , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) + , [isum] "+&r" (isum) + : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q2 += 32; q8 += 128; patmp += 8; + } + + sumf += dall * isum; + } + + *s = sumf; +} + +void ggml_vec_dot_q2_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - size_t vl = 16; + const int nb = n / QK_K; - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + float sumf = 0; + uint8_t atmp[16]; - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - vl = 32; + size_t vl = 16; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - uint8_t is = 0; - int isum = 0; + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - for (int j = 0; j < QK_K / 128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); - vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); - vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); + vl = 32; - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); + uint8_t is = 0; + int isum = 0; - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + for (int j = 0; j < QK_K / 128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); - isum += __riscv_vmv_x_s_i32m1_i32(isum1); + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); - q2 += 32; - q8 += 128; - is = 8; - } + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); - sumf += dall * isum; - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - uint8_t *patmp = atmp; - int vsums; - int tmp, t1, t2, t3, t4, t5, t6, t7; - __asm__ __volatile__( - "vsetivli zero, 16, e8, m1\n\t" - "vmv.v.x v8, zero\n\t" - "lb zero, 15(%[sc])\n\t" - "vle8.v v1, (%[sc])\n\t" - "vle8.v v2, (%[bsums])\n\t" - "addi %[tmp], %[bsums], 16\n\t" - "vand.vi v0, v1, 0xF\n\t" - "vsrl.vi v1, v1, 4\n\t" - "vle8.v v3, (%[tmp])\n\t" - "vse8.v v0, (%[scale])\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vzext.vf2 v0, v1\n\t" - "vwmul.vv v4, v0, v2\n\t" - "vsetivli zero, 16, e32, m4\n\t" - "vredsum.vs v8, v4, v8\n\t" - "vmv.x.s %[vsums], v8" - : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) - : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - sumf += dmin * vsums; - int isum = 0; - - for (int j = 0; j < QK_K/128; ++j) { - __asm__ __volatile__( - "lb zero, 31(%[q2])\n\t" - "addi %[tmp], %[q2], 16\n\t" - "addi %[t1], %[q8], 16\n\t" - "vsetivli zero, 16, e8, m1\n\t" - "vle8.v v0, (%[q2])\n\t" - "vle8.v v1, (%[tmp])\n\t" - "vsrl.vi v2, v0, 2\n\t" - "vsrl.vi v3, v1, 2\n\t" - "vsrl.vi v4, v0, 4\n\t" - "addi %[tmp], %[q8], 32\n\t" - "vle8.v v8, (%[q8])\n\t" - "vle8.v v9, (%[t1])\n\t" - "addi %[t1], %[t1], 32\n\t" - "vsrl.vi v5, v1, 4\n\t" - "vsrl.vi v6, v0, 6\n\t" - "vsrl.vi v7, v1, 6\n\t" - "vle8.v v10, (%[tmp])\n\t" - "vle8.v v11, (%[t1])\n\t" - "addi %[tmp], %[tmp], 32\n\t" - "addi %[t1], %[t1], 32\n\t" - "vand.vi v0, v0, 0x3\n\t" - "vand.vi v1, v1, 0x3\n\t" - "vand.vi v2, v2, 0x3\n\t" - "vle8.v v12, (%[tmp])\n\t" - "vle8.v v13, (%[t1])\n\t" - "addi %[tmp], %[tmp], 32\n\t" - "addi %[t1], %[t1], 32\n\t" - "vand.vi v3, v3, 0x3\n\t" - "vand.vi v4, v4, 0x3\n\t" - "vand.vi v5, v5, 0x3\n\t" - "vle8.v v14, (%[tmp])\n\t" - "vle8.v v15, (%[t1])\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v18, v1, v9\n\t" - "vwmul.vv v20, v2, v10\n\t" - "vwmul.vv v22, v3, v11\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vwmul.vv v26, v5, v13\n\t" - "vwmul.vv v28, v6, v14\n\t" - "vwmul.vv v30, v7, v15\n\t" - "vsetivli zero, 8, e16, m1\n\t" - "vmv.v.x v0, zero\n\t" - "lbu %[tmp], 0(%[scale])\n\t" - "vwredsum.vs v8, v16, v0\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "lbu %[t1], 1(%[scale])\n\t" - "vwredsum.vs v10, v20, v0\n\t" - "vwredsum.vs v11, v22, v0\n\t" - "lbu %[t2], 2(%[scale])\n\t" - "vwredsum.vs v12, v24, v0\n\t" - "vwredsum.vs v13, v26, v0\n\t" - "lbu %[t3], 3(%[scale])\n\t" - "vwredsum.vs v14, v28, v0\n\t" - "vwredsum.vs v15, v30, v0\n\t" - "lbu %[t4], 4(%[scale])\n\t" - "vwredsum.vs v8, v17, v8\n\t" - "vwredsum.vs v9, v19, v9\n\t" - "lbu %[t5], 5(%[scale])\n\t" - "vwredsum.vs v10, v21, v10\n\t" - "vwredsum.vs v11, v23, v11\n\t" - "lbu %[t6], 6(%[scale])\n\t" - "vwredsum.vs v12, v25, v12\n\t" - "vwredsum.vs v13, v27, v13\n\t" - "lbu %[t7], 7(%[scale])\n\t" - "vwredsum.vs v14, v29, v14\n\t" - "vwredsum.vs v15, v31, v15\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v8, %[tmp]\n\t" - "vmul.vx v1, v9, %[t1]\n\t" - "vmacc.vx v0, %[t2], v10\n\t" - "vmacc.vx v1, %[t3], v11\n\t" - "vmacc.vx v0, %[t4], v12\n\t" - "vmacc.vx v1, %[t5], v13\n\t" - "vmacc.vx v0, %[t6], v14\n\t" - "vmacc.vx v1, %[t7], v15\n\t" - "vmv.x.s %[tmp], v0\n\t" - "vmv.x.s %[t1], v1\n\t" - "add %[isum], %[isum], %[tmp]\n\t" - "add %[isum], %[isum], %[t1]" - : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) - , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) - , [isum] "+&r" (isum) - : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - q2 += 32; q8 += 128; patmp += 8; - } + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); - sumf += dall * isum; + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum1); + + q2 += 32; + q8 += 128; + is = 8; } - break; - default: - assert(false && "Unsupported vector length"); - break; + + sumf += dall * isum; } *s = sumf; +} +#endif +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q2_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q2_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q2_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(x); - UNUSED(y); - UNUSED(nb); - ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +void ggml_vec_dot_q3_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -941,8 +975,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - uint32_t utmp[4]; float sumf = 0; @@ -1068,257 +1100,274 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +void ggml_vec_dot_q3_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; uint32_t utmp[4]; float sumf = 0; uint32_t aux[3]; - const int vector_length = __riscv_vlenb() * 8; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + int8_t * scale = (int8_t *)utmp; + int tmp, t1, t2, t3, t4, t5, t6, t7; + __asm__ __volatile__( + "vsetivli zero, 12, e8, m1\n\t" + "vle8.v v0, (%[s6b])\n\t" + "vmv1r.v v2, v0\n\t" + "vsetivli zero, 2, e64, m1\n\t" + "vmv.v.x v9, %[sh]\n\t"\ + "vslidedown.vi v1, v0, 1\n\t" + "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} + "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} + "vsetivli zero, 4, e32, m1\n\t" + "vid.v v9\n\t" + "vmv.x.s %[tmp], v1\n\t" + "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} + "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} + "vsrl.vv v4, v1, v9\n\t" + "vsrl.vv v2, v0, v8\n\t" + "vand.vx v5, v4, %[kmask1]\n\t" + "vand.vx v3, v2, %[kmask2]\n\t" + "vsll.vi v6, v5, 4\n\t" + "vor.vv v7, v6, v3\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vsub.vx v0, v7, %[c]\n\t" + "vse8.v v0, (%[scale])" + : [tmp] "=&r" (tmp) + : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) + , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + + uint8_t m = 1; + int isum = 0; + for (int j = 0; j < QK_K; j += 128) { + __asm__ __volatile__( + "lb zero, 31(%[q3])\n\t" + "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" + "vle8.v v8, (%[q3])\n\t" + "vsrl.vi v10, v8, 2\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v8, 6\n\t" + "lb zero, 64(%[q8])\n\t" + "vand.vi v8, v8, 3\n\t" + "vand.vi v10, v10, 3\n\t" + "vand.vi v12, v12, 3\n\t" + "vle8.v v2, (%[qh])\n\t" + "lb zero, 127(%[q8])\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v8, v8, -4, v0.t\n\t" + "lb zero, 0(%[q8])\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v10, v10, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v12, v12, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v14, v14, -4, v0.t\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v0, (%[q8])\n\t" + "lb %[tmp], 0(%[scale])\n\t" + "lb %[t1], 1(%[scale])\n\t" + "lb %[t2], 2(%[scale])\n\t" + "lb %[t3], 3(%[scale])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v8, v16, v0\n\t" + "lb %[t4], 4(%[scale])\n\t" + "lb %[t5], 5(%[scale])\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v10, v20, v0\n\t" + "vwredsum.vs v11, v22, v0\n\t" + "vwredsum.vs v12, v24, v0\n\t" + "lb %[t6], 6(%[scale])\n\t" + "lb %[t7], 7(%[scale])\n\t" + "vwredsum.vs v13, v26, v0\n\t" + "vwredsum.vs v14, v28, v0\n\t" + "vwredsum.vs v15, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v8, %[tmp]\n\t" + "vmul.vx v1, v9, %[t1]\n\t" + "vmacc.vx v0, %[t2], v10\n\t" + "vmacc.vx v1, %[t3], v11\n\t" + "vmacc.vx v0, %[t4], v12\n\t" + "vmacc.vx v1, %[t5], v13\n\t" + "vmacc.vx v0, %[t6], v14\n\t" + "vmacc.vx v1, %[t7], v15\n\t" + "vmv.x.s %[tmp], v0\n\t" + "vmv.x.s %[t1], v1\n\t" + "add %[isum], %[isum], %[tmp]\n\t" + "add %[isum], %[isum], %[t1]" + : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) + , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) + , [m] "+&r" (m), [isum] "+&r" (isum) + : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) + , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q3 += 32; q8 += 128; scale += 8; + } + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + sumf += d * isum; + } - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].hmask; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + *s = sumf; +} - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); +void ggml_vec_dot_q3_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= 32; + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - size_t vl = 32; - uint8_t m = 1; + const int nb = n / QK_K; + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - int sum_t = 0; + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); - for (int j = 0; j < QK_K; j += 128) { + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; - vl = 32; - // load Q3 - vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + size_t vl = 32; + uint8_t m = 1; - vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); - vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); - vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); - vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); - // compute mask for subtraction - vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); - vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); - m <<= 1; + int sum_t = 0; - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); - m <<= 1; + for (int j = 0; j < QK_K; j += 128) { - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); - m <<= 1; + vl = 32; - vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); - vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); - m <<= 1; + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); - // load Q8 and take product with Q3 - vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); - vl = 16; + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; - // retrieve lane to multiply with scale - vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); - vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); - vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); - vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); - vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); - vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); - vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); - vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; - q3 += 32; q8 += 128; scale += 8; + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - } + vl = 16; - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + // retrieve lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); - sumf += d*sum_t; + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - - int8_t * scale = (int8_t *)utmp; - int tmp, t1, t2, t3, t4, t5, t6, t7; - __asm__ __volatile__( - "vsetivli zero, 12, e8, m1\n\t" - "vle8.v v0, (%[s6b])\n\t" - "vmv1r.v v2, v0\n\t" - "vsetivli zero, 2, e64, m1\n\t" - "vmv.v.x v9, %[sh]\n\t"\ - "vslidedown.vi v1, v0, 1\n\t" - "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} - "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} - "vsetivli zero, 4, e32, m1\n\t" - "vid.v v9\n\t" - "vmv.x.s %[tmp], v1\n\t" - "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} - "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} - "vsrl.vv v4, v1, v9\n\t" - "vsrl.vv v2, v0, v8\n\t" - "vand.vx v5, v4, %[kmask1]\n\t" - "vand.vx v3, v2, %[kmask2]\n\t" - "vsll.vi v6, v5, 4\n\t" - "vor.vv v7, v6, v3\n\t" - "vsetivli zero, 16, e8, m1\n\t" - "vsub.vx v0, v7, %[c]\n\t" - "vse8.v v0, (%[scale])" - : [tmp] "=&r" (tmp) - : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) - , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - uint8_t m = 1; - int isum = 0; - for (int j = 0; j < QK_K; j += 128) { - __asm__ __volatile__( - "lb zero, 31(%[q3])\n\t" - "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" - "vle8.v v8, (%[q3])\n\t" - "vsrl.vi v10, v8, 2\n\t" - "vsrl.vi v12, v8, 4\n\t" - "vsrl.vi v14, v8, 6\n\t" - "lb zero, 64(%[q8])\n\t" - "vand.vi v8, v8, 3\n\t" - "vand.vi v10, v10, 3\n\t" - "vand.vi v12, v12, 3\n\t" - "vle8.v v2, (%[qh])\n\t" - "lb zero, 127(%[q8])\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v8, v8, -4, v0.t\n\t" - "lb zero, 0(%[q8])\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v10, v10, -4, v0.t\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v12, v12, -4, v0.t\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v14, v14, -4, v0.t\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" - "vle8.v v0, (%[q8])\n\t" - "lb %[tmp], 0(%[scale])\n\t" - "lb %[t1], 1(%[scale])\n\t" - "lb %[t2], 2(%[scale])\n\t" - "lb %[t3], 3(%[scale])\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vmv.v.x v0, zero\n\t" - "vwredsum.vs v8, v16, v0\n\t" - "lb %[t4], 4(%[scale])\n\t" - "lb %[t5], 5(%[scale])\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v10, v20, v0\n\t" - "vwredsum.vs v11, v22, v0\n\t" - "vwredsum.vs v12, v24, v0\n\t" - "lb %[t6], 6(%[scale])\n\t" - "lb %[t7], 7(%[scale])\n\t" - "vwredsum.vs v13, v26, v0\n\t" - "vwredsum.vs v14, v28, v0\n\t" - "vwredsum.vs v15, v30, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v8, %[tmp]\n\t" - "vmul.vx v1, v9, %[t1]\n\t" - "vmacc.vx v0, %[t2], v10\n\t" - "vmacc.vx v1, %[t3], v11\n\t" - "vmacc.vx v0, %[t4], v12\n\t" - "vmacc.vx v1, %[t5], v13\n\t" - "vmacc.vx v0, %[t6], v14\n\t" - "vmacc.vx v1, %[t7], v15\n\t" - "vmv.x.s %[tmp], v0\n\t" - "vmv.x.s %[t1], v1\n\t" - "add %[isum], %[isum], %[tmp]\n\t" - "add %[isum], %[isum], %[t1]" - : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) - , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) - , [m] "+&r" (m), [isum] "+&r" (isum) - : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) - , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - q3 += 32; q8 += 128; scale += 8; - } + q3 += 32; q8 += 128; scale += 8; - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - sumf += d * isum; } - break; - default: - assert(false && "Unsupported vector length"); - break; - } - *s = sumf; - -#else + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(x); - UNUSED(y); - UNUSED(nb); + sumf += d*sum_t; - ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); -#endif + } + *s = sumf; } -void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +void ggml_vec_dot_q3_K_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -1326,27 +1375,289 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi UNUSED(by); UNUSED(bs); - const block_q4_K * GGML_RESTRICT x = vx; + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; const block_q8_K * GGML_RESTRICT y = vy; const int nb = n / QK_K; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; + // mask for processing 16 elements per prod register + const vuint16m1_t va_index = __riscv_vid_v_u16m1(32); + const vbool16_t va_mask = __riscv_vmsgtu_vx_u16m1_b16(va_index, 15, 32); uint32_t utmp[4]; - -#if defined __riscv_xtheadvector - - const uint8_t * scales = (const uint8_t*)&utmp[0]; - const uint8_t * mins = (const uint8_t*)&utmp[2]; - float sumf = 0; + uint32_t aux[3]; for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(qh, vl); + + int sum_t = 0; + + vint32m2_t vaux_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_3 = __riscv_vmv_v_x_i32m2(0, vl); + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl); + + vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x03, vl)); + vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x2, vl), 0x03 , vl)); + vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x4, vl), 0x03 , vl)); + vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8mf2_t qh_m0 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_0 = __riscv_vmseq_vx_u8mf2_b16(qh_m0, 0, vl); + vint8mf2_t q3_m0 = __riscv_vsub_vx_i8mf2_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m1 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_1 = __riscv_vmseq_vx_u8mf2_b16(qh_m1, 0, vl); + vint8mf2_t q3_m1 = __riscv_vsub_vx_i8mf2_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m2 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_2 = __riscv_vmseq_vx_u8mf2_b16(qh_m2, 0, vl); + vint8mf2_t q3_m2 = __riscv_vsub_vx_i8mf2_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m3 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_3 = __riscv_vmseq_vx_u8mf2_b16(qh_m3, 0, vl); + vint8mf2_t q3_m3 = __riscv_vsub_vx_i8mf2_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product + vint16m1_t va_q_0 = __riscv_vwmul_vv_i16m1(q3_m0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t va_q_1 = __riscv_vwmul_vv_i16m1(q3_m1, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t va_q_2 = __riscv_vwmul_vv_i16m1(q3_m2, __riscv_vle8_v_i8mf2(q8+64, vl), vl); + vint16m1_t va_q_3 = __riscv_vwmul_vv_i16m1(q3_m3, __riscv_vle8_v_i8mf2(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m2(vaux_0, scale[0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m2(vaux_1, scale[2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m2(vaux_2, scale[4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m2(vaux_3, scale[6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_0, scale[1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_1, scale[3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_2, scale[5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_3, scale[7], va_q_3, vl); + + q3 += 32; q8 += 128; scale += 8; + } + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + } + + *s = sumf; +} + +void ggml_vec_dot_q3_K_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // mask for processing 16 elements per prod register + const vuint16mf2_t va_index = __riscv_vid_v_u16mf2(32); + const vbool32_t va_mask = __riscv_vmsgtu_vx_u16mf2_b32(va_index, 15, 32); + + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8mf4_t vqh = __riscv_vle8_v_u8mf4(qh, vl); + + int sum_t = 0; + + vint32m1_t vaux_0 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_1 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_2 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_3 = __riscv_vmv_v_x_i32m1(0, vl); + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8mf4_t q3_x = __riscv_vle8_v_u8mf4(q3, vl); + + vint8mf4_t q3_0 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(q3_x, 0x03, vl)); + vint8mf4_t q3_1 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x2, vl), 0x03 , vl)); + vint8mf4_t q3_2 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x4, vl), 0x03 , vl)); + vint8mf4_t q3_3 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8mf4_t qh_m0 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_0 = __riscv_vmseq_vx_u8mf4_b32(qh_m0, 0, vl); + vint8mf4_t q3_m0 = __riscv_vsub_vx_i8mf4_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8mf4_t qh_m1 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_1 = __riscv_vmseq_vx_u8mf4_b32(qh_m1, 0, vl); + vint8mf4_t q3_m1 = __riscv_vsub_vx_i8mf4_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8mf4_t qh_m2 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_2 = __riscv_vmseq_vx_u8mf4_b32(qh_m2, 0, vl); + vint8mf4_t q3_m2 = __riscv_vsub_vx_i8mf4_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8mf4_t qh_m3 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_3 = __riscv_vmseq_vx_u8mf4_b32(qh_m3, 0, vl); + vint8mf4_t q3_m3 = __riscv_vsub_vx_i8mf4_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product + vint16mf2_t va_q_0 = __riscv_vwmul_vv_i16mf2(q3_m0, __riscv_vle8_v_i8mf4(q8, vl), vl); + vint16mf2_t va_q_1 = __riscv_vwmul_vv_i16mf2(q3_m1, __riscv_vle8_v_i8mf4(q8+32, vl), vl); + vint16mf2_t va_q_2 = __riscv_vwmul_vv_i16mf2(q3_m2, __riscv_vle8_v_i8mf4(q8+64, vl), vl); + vint16mf2_t va_q_3 = __riscv_vwmul_vv_i16mf2(q3_m3, __riscv_vle8_v_i8mf4(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m1(vaux_0, scale[0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m1(vaux_1, scale[2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m1(vaux_2, scale[4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m1(vaux_3, scale[6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_0, scale[1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_1, scale[3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_2, scale[5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_3, scale[7], va_q_3, vl); + + q3 += 32; q8 += 128; scale += 8; + } + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + } + + *s = sumf; +} +#endif + +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q3_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q3_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_q3_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_q3_K_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_q3_K_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_xtheadvector +static NOINLINE void ggml_vec_dot_q4_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); int tmp, tmp2, sumi; __asm__ __volatile__( @@ -1452,277 +1763,317 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_q4_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; float sumf = 0; - const int vector_length = __riscv_vlenb() * 8; + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + + float ftmp, ft2; + const uint8_t * restrict q40; + const uint8_t * restrict q41; + const uint8_t * restrict q42; + const uint8_t * restrict q43; + const int8_t * restrict q80; + const int8_t * restrict q81; + const int8_t * restrict q82; + const int8_t * restrict q83; + int s0, s1, s2, s3; + + __asm__ __volatile__( + "li %[s1], 8\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vle32.v v1, (%[s6b])\n\t" + "vslide1down.vx v1, v1, zero\n\t" + "vmv.v.x v16, zero\n\t" + "vslidedown.vi v2, v1, 2\n\t" + "vmv1r.v v3, v2\n\t" + "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} + "vsetivli zero, 2, e32, m1, ta, ma\n\t" + "vmv.v.i v4, 4\n\t" + "vand.vx v8, v1, %[kmask1]\n\t" + "vslide1up.vx v5, v4, zero\n\t" // {0, 4} + "vsrl.vi v6, v1, 6\n\t" + "vsrl.vv v7, v2, v5\n\t" + "vsse32.v v8, (%[utmp]), %[s1]\n\t" + "vand.vx v0, v6, %[kmask3]\n\t" + "vand.vx v2, v7, %[kmask2]\n\t" + "vsll.vi v6, v0, 4\n\t" + "addi %[s0], %[utmp], 4\n\t" + "vor.vv v1, v6, v2\n\t" + "vsse32.v v1, (%[s0]), %[s1]\n\t" + "vsetivli zero, 8, e16, m1, ta, ma\n\t" + "vle32.v v2, (%[bsums])\n\t" + "vnsrl.wi v0, v2, 0\n\t" + "vnsrl.wi v1, v2, 16\n\t" + "vadd.vv v2, v0, v1\n\t" + "vle8.v v3, (%[mins])\n\t" + "vzext.vf2 v4, v3\n\t" + "vwmul.vv v6, v4, v2\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vredsum.vs v0, v6, v16\n\t" + "vredsum.vs v0, v7, v0\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfmv.f.s %[ftmp], v0\n\t" + "vsetivli zero, 16, e8, m1, ta, ma\n\t" + "vle8.v v0, (%[xs])\n\t" + "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t" + "addi %[q40], %[xs], 64\n\t" + "addi %[q41], %[xs], 16\n\t" + "addi %[q42], %[xs], 32\n\t" + "addi %[q43], %[xs], 48\n\t" + "addi %[q80], %[ys], 64\n\t" + "vle8.v v1, (%[q41])\n\t" + "vle8.v v2, (%[q42])\n\t" + "addi %[q81], %[ys], 16\n\t" + "addi %[q41], %[q41], 64\n\t" + "addi %[q82], %[ys], 32\n\t" + "vle8.v v3, (%[q43])\n\t" + "vle8.v v8, (%[ys])\n\t" + "addi %[q42], %[q42], 64\n\t" + "addi %[q83], %[ys], 48\n\t" + "addi %[q43], %[q43], 64\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vle8.v v9, (%[q81])\n\t" + "vle8.v v10, (%[q82])\n\t" + "vand.vi v0, v0, 0xF\n\t" + "addi %[q81], %[q81], 64\n\t" + "vsrl.vi v5, v1, 4\n\t" + "addi %[q82], %[q82], 64\n\t" + "vle8.v v11, (%[q83])\n\t" + "vle8.v v12, (%[q80])\n\t" + "vand.vi v1, v1, 0xF\n\t" + "addi %[q83], %[q83], 64\n\t" + "vsrl.vi v6, v2, 4\n\t" + "addi %[q80], %[q80], 64\n\t" + "vle8.v v13, (%[q81])\n\t" + "vle8.v v14, (%[q82])\n\t" + "vand.vi v2, v2, 0xF\n\t" + "addi %[q81], %[q81], 64\n\t" + "vsrl.vi v7, v3, 4\n\t" + "addi %[q82], %[q82], 64\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vle8.v v15, (%[q83])\n\t" + "vle8.v v0, (%[q40])\n\t" + "vand.vi v3, v3, 0xF\n\t" + "addi %[q83], %[q83], 64\n\t" + "vwmul.vv v24, v2, v12\n\t" + "vwmul.vv v20, v4, v10\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmacc.vv v16, v1, v9\n\t" + "vle8.v v1, (%[q41])\n\t" + "vle8.v v2, (%[q42])\n\t" + "vwmacc.vv v24, v3, v13\n\t" + "vwmacc.vv v20, v5, v11\n\t" + "vwmacc.vv v28, v7, v15\n\t" + "addi %[q40], %[q80], 64\n\t" + "addi %[q41], %[q81], 64\n\t" + "vle8.v v3, (%[q43])\n\t" + "vle8.v v8, (%[q80])\n\t" + "addi %[q42], %[q82], 64\n\t" + "addi %[q43], %[q83], 64\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vle8.v v9, (%[q81])\n\t" + "vle8.v v10, (%[q82])\n\t" + "vand.vi v0, v0, 0xF\n\t" + "vsrl.vi v5, v1, 4\n\t" + "vsrl.vi v7, v3, 4\n\t" + "vand.vi v3, v3, 0xF\n\t" + "vle8.v v11, (%[q83])\n\t" + "vle8.v v12, (%[q40])\n\t" + "vand.vi v1, v1, 0xF\n\t" + "vsrl.vi v6, v2, 4\n\t" + "vand.vi v2, v2, 0xF\n\t" + "vwmul.vv v18, v0, v8\n\t" + "vle8.v v13, (%[q41])\n\t" + "vle8.v v14, (%[q42])\n\t" + "vwmul.vv v26, v2, v12\n\t" + "vwmul.vv v22, v4, v10\n\t" + "vwmul.vv v30, v6, v14\n\t" + "vwmacc.vv v18, v1, v9\n\t" + "vle8.v v15, (%[q43])\n\t" + "vwmacc.vv v26, v3, v13\n\t" + "vwmacc.vv v22, v5, v11\n\t" + "vwmacc.vv v30, v7, v15\n\t" + "vmv.v.x v0, zero\n\t" + "vsetivli zero, 16, e16, m2, ta, ma\n\t" + "vwredsum.vs v4, v16, v0\n\t" + "lbu %[s0], 0(%[scale])\n\t" + "vwredsum.vs v5, v20, v0\n\t" + "lbu %[s1], 1(%[scale])\n\t" + "vwredsum.vs v6, v24, v0\n\t" + "lbu %[s2], 2(%[scale])\n\t" + "vwredsum.vs v7, v28, v0\n\t" + "lbu %[s3], 3(%[scale])\n\t" + "vwredsum.vs v8, v18, v0\n\t" + "lbu %[q40], 4(%[scale])\n\t" + "vwredsum.vs v9, v22, v0\n\t" + "lbu %[q41], 5(%[scale])\n\t" + "vwredsum.vs v10, v26, v0\n\t" + "lbu %[q42], 6(%[scale])\n\t" + "vwredsum.vs v11, v30, v0\n\t" + "lbu %[q43], 7(%[scale])\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vmul.vx v0, v4, %[s0]\n\t" + "vmul.vx v1, v8, %[q40]\n\t" + "vmacc.vx v0, %[s1], v5\n\t" + "vmacc.vx v1, %[q41], v9\n\t" + "vmacc.vx v0, %[s2], v6\n\t" + "vmacc.vx v1, %[q42], v10\n\t" + "vmacc.vx v0, %[s3], v7\n\t" + "vmacc.vx v1, %[q43], v11\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfcvt.f.x.v v1, v1\n\t" + "vfmv.f.s %[ft2], v0\n\t" + "vfmv.f.s %[ftmp], v1\n\t" + "fadd.s %[ft2], %[ft2], %[ftmp]\n\t" + "fmadd.s %[sumf], %[d], %[ft2], %[sumf]" + : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2) + , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3) + , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43) + , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83) + : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales) + , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) + , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin) + , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + } - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { + *s = sumf; +} - size_t vl = 8; +static NOINLINE void ggml_vec_dot_q4_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + const int nb = n / QK_K; - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + uint32_t utmp[4]; - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + float sumf = 0; + for (int i = 0; i < nb; ++i) { + size_t vl = 8; - vl = 32; + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - int32_t sum_1 = 0; - int32_t sum_2 = 0; + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; - for (int j = 0; j < QK_K/64; ++j) { - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + vl = 32; - sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + int32_t sum_1 = 0; + int32_t sum_2 = 0; - q4 += 32; q8 += 64; + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - } + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); - sumf += d*(sum_1 + sum_2); + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - - float ftmp, ft2; - const uint8_t * restrict q40; - const uint8_t * restrict q41; - const uint8_t * restrict q42; - const uint8_t * restrict q43; - const int8_t * restrict q80; - const int8_t * restrict q81; - const int8_t * restrict q82; - const int8_t * restrict q83; - int s0, s1, s2, s3; + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + + q4 += 32; q8 += 64; - __asm__ __volatile__( - "li %[s1], 8\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vle32.v v1, (%[s6b])\n\t" - "vslide1down.vx v1, v1, zero\n\t" - "vmv.v.x v16, zero\n\t" - "vslidedown.vi v2, v1, 2\n\t" - "vmv1r.v v3, v2\n\t" - "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} - "vsetivli zero, 2, e32, m1, ta, ma\n\t" - "vmv.v.i v4, 4\n\t" - "vand.vx v8, v1, %[kmask1]\n\t" - "vslide1up.vx v5, v4, zero\n\t" // {0, 4} - "vsrl.vi v6, v1, 6\n\t" - "vsrl.vv v7, v2, v5\n\t" - "vsse32.v v8, (%[utmp]), %[s1]\n\t" - "vand.vx v0, v6, %[kmask3]\n\t" - "vand.vx v2, v7, %[kmask2]\n\t" - "vsll.vi v6, v0, 4\n\t" - "addi %[s0], %[utmp], 4\n\t" - "vor.vv v1, v6, v2\n\t" - "vsse32.v v1, (%[s0]), %[s1]\n\t" - "vsetivli zero, 8, e16, m1, ta, ma\n\t" - "vle32.v v2, (%[bsums])\n\t" - "vnsrl.wi v0, v2, 0\n\t" - "vnsrl.wi v1, v2, 16\n\t" - "vadd.vv v2, v0, v1\n\t" - "vle8.v v3, (%[mins])\n\t" - "vzext.vf2 v4, v3\n\t" - "vwmul.vv v6, v4, v2\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vredsum.vs v0, v6, v16\n\t" - "vredsum.vs v0, v7, v0\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfmv.f.s %[ftmp], v0\n\t" - "vsetivli zero, 16, e8, m1, ta, ma\n\t" - "vle8.v v0, (%[xs])\n\t" - "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t" - "addi %[q40], %[xs], 64\n\t" - "addi %[q41], %[xs], 16\n\t" - "addi %[q42], %[xs], 32\n\t" - "addi %[q43], %[xs], 48\n\t" - "addi %[q80], %[ys], 64\n\t" - "vle8.v v1, (%[q41])\n\t" - "vle8.v v2, (%[q42])\n\t" - "addi %[q81], %[ys], 16\n\t" - "addi %[q41], %[q41], 64\n\t" - "addi %[q82], %[ys], 32\n\t" - "vle8.v v3, (%[q43])\n\t" - "vle8.v v8, (%[ys])\n\t" - "addi %[q42], %[q42], 64\n\t" - "addi %[q83], %[ys], 48\n\t" - "addi %[q43], %[q43], 64\n\t" - "vsrl.vi v4, v0, 4\n\t" - "vle8.v v9, (%[q81])\n\t" - "vle8.v v10, (%[q82])\n\t" - "vand.vi v0, v0, 0xF\n\t" - "addi %[q81], %[q81], 64\n\t" - "vsrl.vi v5, v1, 4\n\t" - "addi %[q82], %[q82], 64\n\t" - "vle8.v v11, (%[q83])\n\t" - "vle8.v v12, (%[q80])\n\t" - "vand.vi v1, v1, 0xF\n\t" - "addi %[q83], %[q83], 64\n\t" - "vsrl.vi v6, v2, 4\n\t" - "addi %[q80], %[q80], 64\n\t" - "vle8.v v13, (%[q81])\n\t" - "vle8.v v14, (%[q82])\n\t" - "vand.vi v2, v2, 0xF\n\t" - "addi %[q81], %[q81], 64\n\t" - "vsrl.vi v7, v3, 4\n\t" - "addi %[q82], %[q82], 64\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vle8.v v15, (%[q83])\n\t" - "vle8.v v0, (%[q40])\n\t" - "vand.vi v3, v3, 0xF\n\t" - "addi %[q83], %[q83], 64\n\t" - "vwmul.vv v24, v2, v12\n\t" - "vwmul.vv v20, v4, v10\n\t" - "vwmul.vv v28, v6, v14\n\t" - "vwmacc.vv v16, v1, v9\n\t" - "vle8.v v1, (%[q41])\n\t" - "vle8.v v2, (%[q42])\n\t" - "vwmacc.vv v24, v3, v13\n\t" - "vwmacc.vv v20, v5, v11\n\t" - "vwmacc.vv v28, v7, v15\n\t" - "addi %[q40], %[q80], 64\n\t" - "addi %[q41], %[q81], 64\n\t" - "vle8.v v3, (%[q43])\n\t" - "vle8.v v8, (%[q80])\n\t" - "addi %[q42], %[q82], 64\n\t" - "addi %[q43], %[q83], 64\n\t" - "vsrl.vi v4, v0, 4\n\t" - "vle8.v v9, (%[q81])\n\t" - "vle8.v v10, (%[q82])\n\t" - "vand.vi v0, v0, 0xF\n\t" - "vsrl.vi v5, v1, 4\n\t" - "vsrl.vi v7, v3, 4\n\t" - "vand.vi v3, v3, 0xF\n\t" - "vle8.v v11, (%[q83])\n\t" - "vle8.v v12, (%[q40])\n\t" - "vand.vi v1, v1, 0xF\n\t" - "vsrl.vi v6, v2, 4\n\t" - "vand.vi v2, v2, 0xF\n\t" - "vwmul.vv v18, v0, v8\n\t" - "vle8.v v13, (%[q41])\n\t" - "vle8.v v14, (%[q42])\n\t" - "vwmul.vv v26, v2, v12\n\t" - "vwmul.vv v22, v4, v10\n\t" - "vwmul.vv v30, v6, v14\n\t" - "vwmacc.vv v18, v1, v9\n\t" - "vle8.v v15, (%[q43])\n\t" - "vwmacc.vv v26, v3, v13\n\t" - "vwmacc.vv v22, v5, v11\n\t" - "vwmacc.vv v30, v7, v15\n\t" - "vmv.v.x v0, zero\n\t" - "vsetivli zero, 16, e16, m2, ta, ma\n\t" - "vwredsum.vs v4, v16, v0\n\t" - "lbu %[s0], 0(%[scale])\n\t" - "vwredsum.vs v5, v20, v0\n\t" - "lbu %[s1], 1(%[scale])\n\t" - "vwredsum.vs v6, v24, v0\n\t" - "lbu %[s2], 2(%[scale])\n\t" - "vwredsum.vs v7, v28, v0\n\t" - "lbu %[s3], 3(%[scale])\n\t" - "vwredsum.vs v8, v18, v0\n\t" - "lbu %[q40], 4(%[scale])\n\t" - "vwredsum.vs v9, v22, v0\n\t" - "lbu %[q41], 5(%[scale])\n\t" - "vwredsum.vs v10, v26, v0\n\t" - "lbu %[q42], 6(%[scale])\n\t" - "vwredsum.vs v11, v30, v0\n\t" - "lbu %[q43], 7(%[scale])\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vmul.vx v0, v4, %[s0]\n\t" - "vmul.vx v1, v8, %[q40]\n\t" - "vmacc.vx v0, %[s1], v5\n\t" - "vmacc.vx v1, %[q41], v9\n\t" - "vmacc.vx v0, %[s2], v6\n\t" - "vmacc.vx v1, %[q42], v10\n\t" - "vmacc.vx v0, %[s3], v7\n\t" - "vmacc.vx v1, %[q43], v11\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfcvt.f.x.v v1, v1\n\t" - "vfmv.f.s %[ft2], v0\n\t" - "vfmv.f.s %[ftmp], v1\n\t" - "fadd.s %[ft2], %[ft2], %[ftmp]\n\t" - "fmadd.s %[sumf], %[d], %[ft2], %[sumf]" - : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2) - , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3) - , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43) - , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83) - : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales) - , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) - , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin) - , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); } - break; - default: - assert(false && "Unsupported vector length"); - break; + + sumf += d*(sum_1 + sum_2); + } *s = sumf; +} +#endif +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q4_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q4_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 256 and above + ggml_vec_dot_q4_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(x); - UNUSED(y); - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(kmask3); - UNUSED(nb); - UNUSED(utmp); - ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } @@ -1823,7 +2174,6 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2); q5 += 32; q8 += 64; - } sums += aux32 * d; @@ -1846,7 +2196,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +static NOINLINE void ggml_vec_dot_q6_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -1859,8 +2210,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - float sumf = 0; for (int i = 0; i < nb; ++i) { @@ -1939,224 +2288,462 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + __builtin_prefetch(&x[i + 1].d, 0, 1); + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int q6h; + float ftmp; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "addi %[q6h], %[q6], 32\n\t" + "ld t0, 0(%[scale])\n\t" + "addi %[scale], %[scale], 8\n\t" + "slli t6, t0, 1 * 8\n\t" + "lb zero, 0(%[q6])\n\t" + "slli t5, t0, 2 * 8\n\t" + "slli t4, t0, 3 * 8\n\t" + "lb zero, 0(%[q6h])\n\t" + "slli t3, t0, 4 * 8\n\t" + "slli t2, t0, 5 * 8\n\t" + "lb zero, 0(%[qh])\n\t" + "lb zero, 31(%[q6h])\n\t" + "slli t1, t0, 6 * 8\n\t" + "srai a7, t0, 56\n\t" + "vsetvli zero, %[vl32], e8, m2\n\t" + "vle8.v v8, (%[q6])\n\t" + "srai t6, t6, 56\n\t" + "srai t5, t5, 56\n\t" + "srai t4, t4, 56\n\t" + "srai t3, t3, 56\n\t" + "vle8.v v10, (%[q6h])\n\t" + "addi %[q6], %[q6], 64\n\t" + "slli t0, t0, 7 * 8\n\t" + "srai t2, t2, 56\n\t" + "srai t1, t1, 56\n\t" + "srai t0, t0, 56\n\t" + "vle8.v v4, (%[qh])\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v10, 4\n\t" + "lb zero, 0(%[q8])\n\t" + "vand.vi v8, v8, 0xF\n\t" + "vand.vi v10, v10, 0xF\n\t" + "lb zero, 32(%[q8])\n\t" + "vsll.vi v0, v4, 4\n\t" + "vsll.vi v2, v4, 2\n\t" + "lb zero, 64(%[q8])\n\t" + "vsrl.vi v6, v4, 2\n\t" + "vand.vx v0, v0, %[mask]\n\t" + "lb zero, 96(%[q8])\n\t" + "vand.vx v2, v2, %[mask]\n\t" + "vand.vx v4, v4, %[mask]\n\t" + "vand.vx v6, v6, %[mask]\n\t" + "vor.vv v8, v8, v0\n\t" + "lb zero, 127(%[q8])\n\t" + "vor.vv v10, v10, v2\n\t" + "vor.vv v12, v12, v4\n\t" + "vor.vv v14, v14, v6\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v0, (%[q8])\n\t" + "vsub.vx v8, v8, %[vl32]\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v10, t0\n\t" + "vmul.vx v1, v9, t1\n\t" + "vmacc.vx v0, t2, v8\n\t" + "vmacc.vx v1, t3, v7\n\t" + "vmacc.vx v0, t4, v11\n\t" + "vmacc.vx v1, t5, v12\n\t" + "vmacc.vx v0, t6, v13\n\t" + "vmacc.vx v1, a7, v14\n\t" + "vadd.vv v0, v0, v1\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfmv.f.s %[ftmp], v0\n\t" + "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]" + : [q6] "+&r" (q6), [q6h] "=&r" (q6h) + , [scale] "+&r" (scale) + , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp) + : [qh] "r" (qh), [q8] "r" (q8) + , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + , [mask] "r" (0x30), [d] "f" (d) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7" + , "a6", "a5", "a4", "a3" + ); + qh += 32; q8 += 128; + } + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; float sumf = 0; - const int vector_length = __riscv_vlenb() * 8; + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + const int8_t * GGML_RESTRICT scale = x[i].scales; - const uint8_t * GGML_RESTRICT q6 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + size_t vl; - const int8_t * GGML_RESTRICT scale = x[i].scales; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - size_t vl; + int sum_t = 0; + int is = 0; - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + for (int j = 0; j < QK_K/128; ++j) { + + vl = 32; + + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - int sum_t = 0; - int is = 0; + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q6 += 64; qh += 32; q8 += 128; is=8; + + } + + sumf += d * sum_t; + + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // mask for processing 16 elements per prod register + const vuint16m1_t va_index = __riscv_vid_v_u16m1(32); + const vbool16_t va_mask = __riscv_vmsgtu_vx_u16m1_b16(va_index, 15, 32); + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const int8_t * GGML_RESTRICT scale = x[i].scales; + + size_t vl = 32; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + int sum_t = 0; + int is = 0; + + vint32m2_t vaux_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_3 = __riscv_vmv_v_x_i32m2(0, vl); + + for (int j = 0; j < QK_K/128; ++j) { + // load qh + vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl); + + // load Q6 + vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl); + vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+32, vl); + + vuint8mf2_t q6a_0 = __riscv_vand_vx_u8mf2(q6_0, 0x0F, vl); + vuint8mf2_t q6a_1 = __riscv_vand_vx_u8mf2(q6_1, 0x0F, vl); + vuint8mf2_t q6s_0 = __riscv_vsrl_vx_u8mf2(q6_0, 0x04, vl); + vuint8mf2_t q6s_1 = __riscv_vsrl_vx_u8mf2(q6_1, 0x04, vl); + + vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(qh_x, 0x03, vl); + vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x03 , vl); + vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x03 , vl); + vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x6, vl), 0x03 , vl); + + vuint8mf2_t qhi_0 = __riscv_vor_vv_u8mf2(q6a_0, __riscv_vsll_vx_u8mf2(qh_0, 0x04, vl), vl); + vuint8mf2_t qhi_1 = __riscv_vor_vv_u8mf2(q6a_1, __riscv_vsll_vx_u8mf2(qh_1, 0x04, vl), vl); + vuint8mf2_t qhi_2 = __riscv_vor_vv_u8mf2(q6s_0, __riscv_vsll_vx_u8mf2(qh_2, 0x04, vl), vl); + vuint8mf2_t qhi_3 = __riscv_vor_vv_u8mf2(q6s_1, __riscv_vsll_vx_u8mf2(qh_3, 0x04, vl), vl); + + vint8mf2_t a_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_0), 32, vl); + vint8mf2_t a_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_1), 32, vl); + vint8mf2_t a_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_2), 32, vl); + vint8mf2_t a_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_3), 32, vl); + + // load Q8 and take product + vint16m1_t va_q_0 = __riscv_vwmul_vv_i16m1(a_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t va_q_1 = __riscv_vwmul_vv_i16m1(a_1, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t va_q_2 = __riscv_vwmul_vv_i16m1(a_2, __riscv_vle8_v_i8mf2(q8+64, vl), vl); + vint16m1_t va_q_3 = __riscv_vwmul_vv_i16m1(a_3, __riscv_vle8_v_i8mf2(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m2(vaux_0, scale[is+0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m2(vaux_1, scale[is+2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m2(vaux_2, scale[is+4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m2(vaux_3, scale[is+6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_0, scale[is+1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_1, scale[is+3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_2, scale[is+5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_3, scale[is+7], va_q_3, vl); + + q6 += 64; qh += 32; q8 += 128; is=8; + } + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + sumf += d * sum_t; + + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - for (int j = 0; j < QK_K/128; ++j) { + const int nb = n / QK_K; - vl = 32; + // mask for processing 16 elements per prod register + const vuint16mf2_t va_index = __riscv_vid_v_u16mf2(32); + const vbool32_t va_mask = __riscv_vmsgtu_vx_u16mf2_b32(va_index, 15, 32); - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + float sumf = 0; - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + const int8_t * GGML_RESTRICT scale = x[i].scales; - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + size_t vl = 32; - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + int sum_t = 0; + int is = 0; - vl = 16; + vint32m1_t vaux_0 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_1 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_2 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_3 = __riscv_vmv_v_x_i32m1(0, vl); - vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + for (int j = 0; j < QK_K/128; ++j) { + // load qh + vuint8mf4_t qh_x = __riscv_vle8_v_u8mf4(qh, vl); + + // load Q6 + vuint8mf4_t q6_0 = __riscv_vle8_v_u8mf4(q6, vl); + vuint8mf4_t q6_1 = __riscv_vle8_v_u8mf4(q6+32, vl); + + vuint8mf4_t q6a_0 = __riscv_vand_vx_u8mf4(q6_0, 0x0F, vl); + vuint8mf4_t q6a_1 = __riscv_vand_vx_u8mf4(q6_1, 0x0F, vl); + vuint8mf4_t q6s_0 = __riscv_vsrl_vx_u8mf4(q6_0, 0x04, vl); + vuint8mf4_t q6s_1 = __riscv_vsrl_vx_u8mf4(q6_1, 0x04, vl); + + vuint8mf4_t qh_0 = __riscv_vand_vx_u8mf4(qh_x, 0x03, vl); + vuint8mf4_t qh_1 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x2, vl), 0x03 , vl); + vuint8mf4_t qh_2 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x4, vl), 0x03 , vl); + vuint8mf4_t qh_3 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x6, vl), 0x03 , vl); + + vuint8mf4_t qhi_0 = __riscv_vor_vv_u8mf4(q6a_0, __riscv_vsll_vx_u8mf4(qh_0, 0x04, vl), vl); + vuint8mf4_t qhi_1 = __riscv_vor_vv_u8mf4(q6a_1, __riscv_vsll_vx_u8mf4(qh_1, 0x04, vl), vl); + vuint8mf4_t qhi_2 = __riscv_vor_vv_u8mf4(q6s_0, __riscv_vsll_vx_u8mf4(qh_2, 0x04, vl), vl); + vuint8mf4_t qhi_3 = __riscv_vor_vv_u8mf4(q6s_1, __riscv_vsll_vx_u8mf4(qh_3, 0x04, vl), vl); + + vint8mf4_t a_0 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_0), 32, vl); + vint8mf4_t a_1 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_1), 32, vl); + vint8mf4_t a_2 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_2), 32, vl); + vint8mf4_t a_3 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_3), 32, vl); + + // load Q8 and take product + vint16mf2_t va_q_0 = __riscv_vwmul_vv_i16mf2(a_0, __riscv_vle8_v_i8mf4(q8, vl), vl); + vint16mf2_t va_q_1 = __riscv_vwmul_vv_i16mf2(a_1, __riscv_vle8_v_i8mf4(q8+32, vl), vl); + vint16mf2_t va_q_2 = __riscv_vwmul_vv_i16mf2(a_2, __riscv_vle8_v_i8mf4(q8+64, vl), vl); + vint16mf2_t va_q_3 = __riscv_vwmul_vv_i16mf2(a_3, __riscv_vle8_v_i8mf4(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m1(vaux_0, scale[is+0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m1(vaux_1, scale[is+2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m1(vaux_2, scale[is+4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m1(vaux_3, scale[is+6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_0, scale[is+1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_1, scale[is+3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_2, scale[is+5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_3, scale[is+7], va_q_3, vl); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); + q6 += 64; qh += 32; q8 += 128; is=8; - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + } - q6 += 64; qh += 32; q8 += 128; is=8; + vint32m1_t isum0 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_2, vaux_3, vl), isum0, vl); - } + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); - sumf += d * sum_t; + sumf += d * sum_t; - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - - __builtin_prefetch(&x[i + 1].d, 0, 1); - - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - int q6h; - float ftmp; - - for (int j = 0; j < QK_K/128; ++j) { - __asm__ __volatile__( - "addi %[q6h], %[q6], 32\n\t" - "ld t0, 0(%[scale])\n\t" - "addi %[scale], %[scale], 8\n\t" - "slli t6, t0, 1 * 8\n\t" - "lb zero, 0(%[q6])\n\t" - "slli t5, t0, 2 * 8\n\t" - "slli t4, t0, 3 * 8\n\t" - "lb zero, 0(%[q6h])\n\t" - "slli t3, t0, 4 * 8\n\t" - "slli t2, t0, 5 * 8\n\t" - "lb zero, 0(%[qh])\n\t" - "lb zero, 31(%[q6h])\n\t" - "slli t1, t0, 6 * 8\n\t" - "srai a7, t0, 56\n\t" - "vsetvli zero, %[vl32], e8, m2\n\t" - "vle8.v v8, (%[q6])\n\t" - "srai t6, t6, 56\n\t" - "srai t5, t5, 56\n\t" - "srai t4, t4, 56\n\t" - "srai t3, t3, 56\n\t" - "vle8.v v10, (%[q6h])\n\t" - "addi %[q6], %[q6], 64\n\t" - "slli t0, t0, 7 * 8\n\t" - "srai t2, t2, 56\n\t" - "srai t1, t1, 56\n\t" - "srai t0, t0, 56\n\t" - "vle8.v v4, (%[qh])\n\t" - "vsrl.vi v12, v8, 4\n\t" - "vsrl.vi v14, v10, 4\n\t" - "lb zero, 0(%[q8])\n\t" - "vand.vi v8, v8, 0xF\n\t" - "vand.vi v10, v10, 0xF\n\t" - "lb zero, 32(%[q8])\n\t" - "vsll.vi v0, v4, 4\n\t" - "vsll.vi v2, v4, 2\n\t" - "lb zero, 64(%[q8])\n\t" - "vsrl.vi v6, v4, 2\n\t" - "vand.vx v0, v0, %[mask]\n\t" - "lb zero, 96(%[q8])\n\t" - "vand.vx v2, v2, %[mask]\n\t" - "vand.vx v4, v4, %[mask]\n\t" - "vand.vx v6, v6, %[mask]\n\t" - "vor.vv v8, v8, v0\n\t" - "lb zero, 127(%[q8])\n\t" - "vor.vv v10, v10, v2\n\t" - "vor.vv v12, v12, v4\n\t" - "vor.vv v14, v14, v6\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" - "vle8.v v0, (%[q8])\n\t" - "vsub.vx v8, v8, %[vl32]\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vmv.v.x v0, zero\n\t" - "vwredsum.vs v10, v16, v0\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v8, v20, v0\n\t" - "vwredsum.vs v7, v22, v0\n\t" - "vwredsum.vs v11, v24, v0\n\t" - "vwredsum.vs v12, v26, v0\n\t" - "vwredsum.vs v13, v28, v0\n\t" - "vwredsum.vs v14, v30, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v10, t0\n\t" - "vmul.vx v1, v9, t1\n\t" - "vmacc.vx v0, t2, v8\n\t" - "vmacc.vx v1, t3, v7\n\t" - "vmacc.vx v0, t4, v11\n\t" - "vmacc.vx v1, t5, v12\n\t" - "vmacc.vx v0, t6, v13\n\t" - "vmacc.vx v1, a7, v14\n\t" - "vadd.vv v0, v0, v1\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfmv.f.s %[ftmp], v0\n\t" - "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]" - : [q6] "+&r" (q6), [q6h] "=&r" (q6h) - , [scale] "+&r" (scale) - , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp) - : [qh] "r" (qh), [q8] "r" (q8) - , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) - , [mask] "r" (0x30), [d] "f" (d) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7" - , "a6", "a5", "a4", "a3" - ); - qh += 32; q8 += 128; - } - } - break; - default: - assert(false && "Unsupported vector length"); - break; } *s = sumf; +} +#endif +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q6_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q6_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_q6_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_q6_K_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_q6_K_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(x); - UNUSED(y); - UNUSED(nb); - ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2364,10 +2951,190 @@ static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } + +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16mf4_t qh = __riscv_vle16_v_u16mf4(x[i].qh, 8); + + // Calculate ls. + vuint16mf4_t temp = __riscv_vsrl_vx_u16mf4(qh, 12, 8); + temp = __riscv_vand_vx_u16mf4(temp, 7, 8); + vint32mf2_t ls = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vwmulu_vx_u32mf2(temp, 2, 8)); + ls = __riscv_vadd_vx_i32mf2(ls, 1, 8); + + // Calculate delta. + vbool64_t mask = __riscv_vmseq_vx_u16mf4_b64(__riscv_vand_vx_u16mf4(qh, 0x8000, 8), 0, 8); + vint32mf2_t delta_neg = __riscv_vmv_v_x_i32mf2(-1, 8); + vint32mf2_t delta_pos = __riscv_vmv_v_x_i32mf2(1, 8); + vint32mf2_t delta = __riscv_vmerge_vvm_i32mf2(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8mf2_t qs = __riscv_vle8_v_u8mf2(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m1_t qh_shift = __riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(shift, 8)); + vuint16m1_t qh_gather_index = __riscv_vreinterpret_v_i16m1_u16m1( + __riscv_vdiv_vx_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(32)), 4, 32)); + vuint16m1_t qh_ext = __riscv_vlmul_ext_v_u16mf2_u16m1(__riscv_vlmul_ext_v_u16mf4_u16mf2(qh)); + vuint16m1_t qh_index = __riscv_vrgather_vv_u16m1(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m1(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m1(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m1(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m1(qh_index, __riscv_vzext_vf2_u16m1(qs, 32), 32); + vuint16m1_t index = __riscv_vsll_vx_u16m1(qh_index, 3, 32); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-8 + { + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, index, 32)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 256); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 256); + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 1), one_scalar, 32)); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 2), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 3), one_scalar, 32)); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 4), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 5), one_scalar, 32)); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 6), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 7), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + vint32mf2_t lsums = __riscv_vle32_v_i32mf2(&lsums_s[0], 8); + + // Calculate the bsums. + vint16mf2_t bsums_0 = __riscv_vle16_v_i16mf2(y[i].bsums, 16); + const vuint32mf2_t bsums_i32 = __riscv_vreinterpret_v_u16mf2_u32mf2(__riscv_vreinterpret_v_i16mf2_u16mf2(bsums_0)); + const vint16mf4_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 0, 8)); + const vint16mf4_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 16, 8)); + const vint32mf2_t bsums = __riscv_vwadd_vv_i32mf2(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32mf2_t sumi_v = __riscv_vmul_vv_i32mf2(ls, lsums, 8); + vint32mf2_t sumi1_v = __riscv_vmul_vv_i32mf2(__riscv_vmul_vv_i32mf2(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // Mask for processing 32 elements per lsum register. + vuint16m1_t l_index = __riscv_vid_v_u16m1(64); + vbool16_t l_mask = __riscv_vmsgtu_vx_u16m1_b16(l_index, 31, 64); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16mf4_t qh = __riscv_vle16_v_u16mf4(x[i].qh, 8); + + // Calculate ls. + vuint16mf4_t temp = __riscv_vsrl_vx_u16mf4(qh, 12, 8); + temp = __riscv_vand_vx_u16mf4(temp, 7, 8); + vint32mf2_t ls = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vwmulu_vx_u32mf2(temp, 2, 8)); + ls = __riscv_vadd_vx_i32mf2(ls, 1, 8); + + // Calculate delta. + vbool64_t mask = __riscv_vmseq_vx_u16mf4_b64(__riscv_vand_vx_u16mf4(qh, 0x8000, 8), 0, 8); + vint32mf2_t delta_neg = __riscv_vmv_v_x_i32mf2(-1, 8); + vint32mf2_t delta_pos = __riscv_vmv_v_x_i32mf2(1, 8); + vint32mf2_t delta = __riscv_vmerge_vvm_i32mf2(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8mf2_t qs = __riscv_vle8_v_u8mf2(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m1_t qh_shift = __riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(shift, 8)); + vuint16m1_t qh_gather_index = __riscv_vreinterpret_v_i16m1_u16m1( + __riscv_vdiv_vx_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(32)), 4, 32)); + vuint16m1_t qh_ext = __riscv_vlmul_ext_v_u16mf2_u16m1(__riscv_vlmul_ext_v_u16mf4_u16mf2(qh)); + vuint16m1_t qh_index = __riscv_vrgather_vv_u16m1(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m1(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m1(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m1(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m1(qh_index, __riscv_vzext_vf2_u16m1(qs, 32), 32); + vuint16mf2_t index = __riscv_vlmul_trunc_v_u16m1_u16mf2(__riscv_vsll_vx_u16m1(qh_index, 3, 32)); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-8 + { + vint8m2_t grid0 = __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vluxei16_v_i64m2((const int64_t*)iq1s_grid, index, 32)); + vint8m2_t q80 = __riscv_vle8_v_i8m2(y[i].qs, 256); + vint16m4_t lsum0 = __riscv_vwmul_vv_i16m4(grid0, q80, 256); + + // Reduce. + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 0), one_scalar, 64)); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 1), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 1), one_scalar, 64)); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 2), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 2), one_scalar, 64)); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 3), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 3), one_scalar, 64)); + } + __asm__ __volatile__("" ::: "memory"); + vint32mf2_t lsums = __riscv_vle32_v_i32mf2(&lsums_s[0], 8); + + // Calculate the bsums. + vint16mf2_t bsums_0 = __riscv_vle16_v_i16mf2(y[i].bsums, 16); + const vuint32mf2_t bsums_i32 = __riscv_vreinterpret_v_u16mf2_u32mf2(__riscv_vreinterpret_v_i16mf2_u16mf2(bsums_0)); + const vint16mf4_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 0, 8)); + const vint16mf4_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 16, 8)); + const vint32mf2_t bsums = __riscv_vwadd_vv_i32mf2(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32mf2_t sumi_v = __riscv_vmul_vv_i32mf2(ls, lsums, 8); + vint32mf2_t sumi1_v = __riscv_vmul_vv_i32mf2(__riscv_vmul_vv_i32mf2(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} #endif void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq1_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -2375,6 +3142,12 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq1_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_iq1_s_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; default: ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2384,7 +3157,7 @@ void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -2664,10 +3437,287 @@ static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } + +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + + // Mask for processing 16 elements per lsum register. + const vuint16m1_t l_index = __riscv_vid_v_u16m1(32); + const vbool16_t l_mask = __riscv_vmsgtu_vx_u16m1_b16(l_index, 15, 32); + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 32); + vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 32); + + // We process all the sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/256; ib++) { + // Load qh for all 16 sub-blocks. + const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 16); + const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 16); + const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 16); + const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1( + __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 16)), 32); + __asm__ __volatile__("" ::: "memory"); + + // Prepare grid indices. + const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 32), 32); + const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 16)); + vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 32), 0x700, 32), 32); + index = __riscv_vsll_vx_u16m1(index, 3, 32); + __asm__ __volatile__("" ::: "memory"); + + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, index, 32))); + + // Prepare the deltas. + const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16( + __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 16)), 32), 0, 32); + const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 32); + const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4( + __riscv_vmerge_vxm_i64m4(delta_pos, 0xffffffffffffffff, mask, 32)); + + // Load q8 for sub-blocks. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 256); + + // Calculate the lsums. + const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 256); + const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 256); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + const int16_t ls_2 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_3 = 2*((sc[0] >> 9) & 0x7) + 1; + const int16_t ls_4 = 2*((sc[1] >> 0) & 0x7) + 1; + const int16_t ls_5 = 2*((sc[1] >> 3) & 0x7) + 1; + const int16_t ls_6 = 2*((sc[1] >> 6) & 0x7) + 1; + const int16_t ls_7 = 2*((sc[1] >> 9) & 0x7) + 1; + const int16_t ls_8 = 2*((sc[2] >> 0) & 0x7) + 1; + const int16_t ls_9 = 2*((sc[2] >> 3) & 0x7) + 1; + const int16_t ls_10 = 2*((sc[2] >> 6) & 0x7) + 1; + const int16_t ls_11 = 2*((sc[2] >> 9) & 0x7) + 1; + const int16_t ls_12 = 2*((sc[3] >> 0) & 0x7) + 1; + const int16_t ls_13 = 2*((sc[3] >> 3) & 0x7) + 1; + const int16_t ls_14 = 2*((sc[3] >> 6) & 0x7) + 1; + const int16_t ls_15 = 2*((sc[3] >> 9) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_1, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_1, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_2, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_3, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_2, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_3, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_4, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_5, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_4, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_5, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_6, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_7, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_6, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_7, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_8, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_9, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_8, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_9, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_10, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_11, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_10, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_11, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_12, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_13, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_12, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_13, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_14, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_15, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_14, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_15, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 32); + + __asm__ __volatile__("" ::: "memory"); + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 32)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 32)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 64); + vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 64); + + // We process all the sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/256; ib++) { + // Load qh for all 16 sub-blocks. + const vuint8mf8_t qh_8 = __riscv_vle8_v_u8mf8(qh, 16); + const vuint16mf4_t qh_16_lo = __riscv_vzext_vf2_u16mf4(qh_8, 16); + const vuint16mf4_t qh_16_hi = __riscv_vsll_vx_u16mf4(qh_16_lo, 8, 16); + const vuint16mf2_t qhb = __riscv_vzext_vf2_u16mf2( + __riscv_vreinterpret_v_u16mf4_u8mf4(__riscv_vor_vv_u16mf4(qh_16_lo, qh_16_hi, 16)), 32); + __asm__ __volatile__("" ::: "memory"); + + // Prepare grid indices. + const vuint16mf2_t qsb = __riscv_vzext_vf2_u16mf2(__riscv_vle8_v_u8mf4(&qs[0], 32), 32); + const vuint16mf2_t shift = __riscv_vreinterpret_v_u32mf2_u16mf2(__riscv_vmv_v_x_u32mf2(0x00040008, 16)); + vuint16mf2_t index = __riscv_vor_vv_u16mf2(qsb, __riscv_vand_vx_u16mf2(__riscv_vsll_vv_u16mf2(qhb, shift, 32), 0x700, 32), 32); + index = __riscv_vsll_vx_u16mf2(index, 3, 32); + __asm__ __volatile__("" ::: "memory"); + + // Load the grid. + const vint8m2_t iq1b = __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vreinterpret_v_u64m2_i64m2( + __riscv_vluxei16_v_u64m2(iq1s_grid, index, 32))); + + // Prepare the deltas. + const vbool32_t mask = __riscv_vmsgtu_vx_u16mf2_b32( + __riscv_vand_vv_u16mf2(qhb, __riscv_vreinterpret_v_u32mf2_u16mf2(__riscv_vmv_v_x_u32mf2(0x00800008, 16)), 32), 0, 32); + const vint64m2_t delta_pos = __riscv_vmv_v_x_i64m2(0x0101010101010101, 32); + const vint8m2_t delta = __riscv_vreinterpret_v_i64m2_i8m2( + __riscv_vmerge_vxm_i64m2(delta_pos, 0xffffffffffffffff, mask, 32)); + + // Load q8 for sub-blocks. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 256); + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(iq1b, q8b, 256); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(delta, q8b, 256); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + const int16_t ls_2 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_3 = 2*((sc[0] >> 9) & 0x7) + 1; + const int16_t ls_4 = 2*((sc[1] >> 0) & 0x7) + 1; + const int16_t ls_5 = 2*((sc[1] >> 3) & 0x7) + 1; + const int16_t ls_6 = 2*((sc[1] >> 6) & 0x7) + 1; + const int16_t ls_7 = 2*((sc[1] >> 9) & 0x7) + 1; + const int16_t ls_8 = 2*((sc[2] >> 0) & 0x7) + 1; + const int16_t ls_9 = 2*((sc[2] >> 3) & 0x7) + 1; + const int16_t ls_10 = 2*((sc[2] >> 6) & 0x7) + 1; + const int16_t ls_11 = 2*((sc[2] >> 9) & 0x7) + 1; + const int16_t ls_12 = 2*((sc[3] >> 0) & 0x7) + 1; + const int16_t ls_13 = 2*((sc[3] >> 3) & 0x7) + 1; + const int16_t ls_14 = 2*((sc[3] >> 6) & 0x7) + 1; + const int16_t ls_15 = 2*((sc[3] >> 9) & 0x7) + 1; + + // Mask for processing 16 elements per lsum register. + const vuint16m1_t l_index = __riscv_vid_v_u16m1(64); + + // Accumulate in acc1 and acc2 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_4, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_4, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_8, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_8, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_12, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_12, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 16); + // + const vbool16_t l_mask_16_32 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 15, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_1, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_1, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_5, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_5, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_9, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_9, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_13, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_13, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 32); + // + const vbool16_t l_mask_32_48 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 31, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_2, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_2, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_6, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_6, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_10, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_10, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_14, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_14, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 48); + // + const vbool16_t l_mask_48_64 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 47, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_3, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_3, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_7, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_7, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_11, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_11, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_15, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_15, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 64); + + __asm__ __volatile__("" ::: "memory"); + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 64)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 64)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} #endif void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq1_m_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -2675,6 +3725,12 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq1_m_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_iq1_m_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; default: ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); break; @@ -2684,7 +3740,7 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static const uint8_t sign_gather_indices_arr[64] = { 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3, 4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7 @@ -2887,10 +3943,275 @@ static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT } *s = 0.125f * sumf; } + +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + vuint8m2_t v_ids = __riscv_vid_v_u8m2(128); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 128); + + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 128); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 128); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 128); + + uint16_t gather_qh_arr[16] = {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; + vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 16); + + uint16_t shift_qh_arr[16] = {11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 16); + + // Masks for selecting lower/upper 16 lanes within a 32-lane i16m1 register + vuint16m1_t v_ids16 = __riscv_vid_v_u16m1(32); + vbool16_t m_hi16 = __riscv_vmsgeu_vx_u16m1_b16(v_ids16, 16, 32); + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + for (int ib = 0; ib < 2; ++ib) { + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 16); + qs += 16; + + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8(qh, 4); + qh += 4; + + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 4); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 16); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 16); + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 16); + + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 16); + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 16); + + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 16); + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 16); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 16); + signs_ptr += 16; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 128); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 128); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 128); + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 128); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 128); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + vint16m1_t v0 = __riscv_vget_v_i16m4_i16m1(v_dot, 0); + vint16m1_t v1 = __riscv_vget_v_i16m4_i16m1(v_dot, 1); + vint16m1_t v2 = __riscv_vget_v_i16m4_i16m1(v_dot, 2); + vint16m1_t v3 = __riscv_vget_v_i16m4_i16m1(v_dot, 3); + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v0, v_zero, 16)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v0, v_zero, 32)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v1, v_zero, 16)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v1, v_zero, 32)); + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v2, v_zero, 16)); + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v2, v_zero, 32)); + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( v3, v_zero, 16)); + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v3, v_zero, 32)); + + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + uint8_t sc2 = scales[2]; + uint8_t sc3 = scales[3]; + scales += 4; + + sum_block += s0 * (2 * (sc0 & 0xF) + 1); + sum_block += s1 * (2 * (sc0 >> 4) + 1); + sum_block += s2 * (2 * (sc1 & 0xF) + 1); + sum_block += s3 * (2 * (sc1 >> 4) + 1); + sum_block += s4 * (2 * (sc2 & 0xF) + 1); + sum_block += s5 * (2 * (sc2 >> 4) + 1); + sum_block += s6 * (2 * (sc3 & 0xF) + 1); + sum_block += s7 * (2 * (sc3 >> 4) + 1); + } + + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + vuint8m2_t v_ids = __riscv_vid_v_u8m2(256); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 256); + + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 256); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 256); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 256); + + uint16_t gather_qh_arr[32] = { + 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, + 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7 + }; + vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 32); + + uint16_t shift_qh_arr[32] = { + 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, + 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5 + }; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 32); + + // Masks for 4 groups of 16 lanes within a 64-lane i16m4 chunk + vuint16m4_t v_ids64 = __riscv_vid_v_u16m4(64); + vbool4_t m_g0 = __riscv_vmsltu_vx_u16m4_b4(v_ids64, 16, 64); + vbool4_t m_g1 = __riscv_vmand_mm_b4( + __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 16, 64), + __riscv_vmsltu_vx_u16m4_b4(v_ids64, 32, 64), 64); + vbool4_t m_g2 = __riscv_vmand_mm_b4( + __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 32, 64), + __riscv_vmsltu_vx_u16m4_b4(v_ids64, 48, 64), 64); + vbool4_t m_g3 = __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 48, 64); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 32); + qs += 32; + + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8(qh, 8); + qh += 8; + + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 8); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 32); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 32); + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 32); + + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 32); + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 32); + + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 32); + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 32); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + //loading signs + vuint8mf2_t v_signs_raw = __riscv_vle8_v_u8mf2(signs_ptr, 32); + signs_ptr += 32; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf2_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 256); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 256); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 256); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 256); + q8 += 256; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 256); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 256); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + vint16m4_t c = v_dot; + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s8 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s9 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s10 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s11 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s12 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s13 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s14 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s15 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + int32_t sums_arr[16] = { s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15 }; + + // Load 8 scale bytes and split into 16 nibbles + vuint8mf2_t v_sc8 = __riscv_vle8_v_u8mf2(scales, 8); + scales += 8; + + vuint8mf2_t v_lo8 = __riscv_vand_vx_u8mf2(v_sc8, 0x0F, 8); + vuint8mf2_t v_hi8 = __riscv_vsrl_vx_u8mf2(v_sc8, 4, 8); + + vuint8m1_t v_idx16 = __riscv_vid_v_u8m1(16); + vuint8m1_t v_half = __riscv_vsrl_vx_u8m1(v_idx16, 1, 16); + vbool8_t m_even = __riscv_vmseq_vx_u8m1_b8(__riscv_vand_vx_u8m1(v_idx16, 1, 16), 0, 16); + + vuint8m1_t v_lo_ext = __riscv_vlmul_ext_v_u8mf2_u8m1(v_lo8); + vuint8m1_t v_hi_ext = __riscv_vlmul_ext_v_u8mf2_u8m1(v_hi8); + vuint8m1_t v_lo_g = __riscv_vrgather_vv_u8m1(v_lo_ext, v_half, 16); + vuint8m1_t v_hi_g = __riscv_vrgather_vv_u8m1(v_hi_ext, v_half, 16); + vuint8m1_t v_nib = __riscv_vmerge_vvm_u8m1(v_lo_g, v_hi_g, m_even, 16); + + static const uint8_t iq2s_scale_lut_16_local[16] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 + }; + vuint8m1_t v_lut = __riscv_vle8_v_u8m1(iq2s_scale_lut_16_local, 16); + vuint8m1_t v_sc8v = __riscv_vrgather_vv_u8m1(v_lut, v_nib, 16); + + vint32m4_t v_sums = __riscv_vle32_v_i32m4(sums_arr, 16); + vuint16m2_t v_sc16 = __riscv_vwcvtu_x_x_v_u16m2(v_sc8v, 16); + vuint32m4_t v_sc32u = __riscv_vwcvtu_x_x_v_u32m4(v_sc16, 16); + vint32m4_t v_sc32 = __riscv_vreinterpret_v_u32m4_i32m4(v_sc32u); + vint32m4_t v_prod = __riscv_vmul_vv_i32m4(v_sums, v_sc32, 16); + + vint32m1_t v_zero32 = __riscv_vmv_v_x_i32m1(0, 1); + int32_t sum_part = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(v_prod, v_zero32, 16)); + sum_block += sum_part; + + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} #endif void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -2898,8 +4219,11 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq2_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; default: - ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_iq2_s_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -2907,7 +4231,7 @@ void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static const int8_t keven_signs_q2xs[1024] = { 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, @@ -3045,59 +4369,140 @@ static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT int32_t sum_int = 0; - // Loop over 4 subblocks of 64 elements (QK_K = 256) - for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) { - // Load 8 uint16 indices (controls 64 values) - vuint16mf2_t v_qs = __riscv_vle16_v_u16mf2(qs, 8); - qs += 8; + for (int ib128 = 0; ib128 < 2; ++ib128) { + + vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 16); + qs += 16; - // Extract indices for grid (low 9 bits) and signs (high 7 bits) - // Multiply by 8 (<< 3) for byte offsets into the uint64 tables - vuint16mf2_t vidx_grid = __riscv_vsll_vx_u16mf2(__riscv_vand_vx_u16mf2(v_qs, 511, 8), 3, 8); - vuint16mf2_t vidx_sign = __riscv_vsll_vx_u16mf2(__riscv_vsrl_vx_u16mf2(v_qs, 9, 8), 3, 8); + // Prepare offsets for grid and signs + vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 16), 3, 16); + vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 16), 3, 16); - vuint64m2_t vq2_64 = __riscv_vluxei16_v_u64m2(grid64, vidx_grid, 8); - vuint64m2_t vs2_64 = __riscv_vluxei16_v_u64m2(signs64, vidx_sign, 8); + // Indexed load 128 weights (16 x 8-byte chunks) + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 16); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 16); - vint8m2_t q2u = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64)); - vint8m2_t q2s = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64)); + vint8m4_t q2u = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t q2s = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); - vint8m2_t q2_final = __riscv_vmul_vv_i8m2(q2u, q2s, 64); + // Apply signs to get dequantized IQ2 values + vint8m4_t q2_final = __riscv_vmul_vv_i8m4(q2u, q2s, 128); + asm volatile("" ::: "memory"); - vint8m2_t q8v = __riscv_vle8_v_i8m2(q8, 64); - q8 += 64; + // Load corresponding Q8 weights + vint8m4_t q8v = __riscv_vle8_v_i8m4(q8, 128); + q8 += 128; + + vint16m8_t prod = __riscv_vwmul_vv_i16m8(q2_final, q8v, 128); + asm volatile("" ::: "memory"); - vint16m4_t prod = __riscv_vwmul_vv_i16m4(q2_final, q8v, 64); + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + uint8_t sc2 = scales[2]; + uint8_t sc3 = scales[3]; + scales += 4; vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); - int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 16)); - int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 16)); - int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 16)); - int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( - __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 16)); + // 9. Reduce each 16-element chunk and apply corresponding nibble scale - const uint8_t scale_byte_1 = scales[0]; - const uint8_t scale_byte_2 = scales[1]; - scales += 2; + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), zero_vec, 16)); + sum_int += s0 * ((sc0 & 0x0F) * 2 + 1); - sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1); - sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1); - sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1); - sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), zero_vec, 16)); + sum_int += s1 * ((sc0 >> 4) * 2 + 1); + + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), zero_vec, 16)); + sum_int += s2 * ((sc1 & 0x0F) * 2 + 1); + + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), zero_vec, 16)); + sum_int += s3 * ((sc1 >> 4) * 2 + 1); + + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), zero_vec, 16)); + sum_int += s4 * ((sc2 & 0x0F) * 2 + 1); + + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), zero_vec, 16)); + sum_int += s5 * ((sc2 >> 4) * 2 + 1); + + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), zero_vec, 16)); + sum_int += s6 * ((sc3 & 0x0F) * 2 + 1); + + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), zero_vec, 16)); + sum_int += s7 * ((sc3 >> 4) * 2 + 1); } - sumf += d * sum_int; + sumf += d * (float)sum_int; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xs_grid; + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint16_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + vint8m4_t q8_all = __riscv_vle8_v_i8m4(q8, 256); + + // Load indices --- + vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 32); + + // Extract low 9 bits and multiply by 8 (shift left 3) for byte offset into uint64 table + vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 32), 3, 32); + + // Extract high 7 bits (shift right 9) and multiply by 8 (shift left 3) for byte offset + vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 32), 3, 32); + + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 32); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 32); + + vint8m4_t q2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t s2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + + vint8m4_t q2_signed = __riscv_vmul_vv_i8m4(q2_all, s2_all, 256); + vint16m8_t dot_all = __riscv_vwmul_vv_i16m8(q2_signed, q8_all, 256); + float sum = 0.0f; + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + +#pragma GCC unroll 1 + for (int j = 0; j < 8; ++j) { + uint8_t sc = scales[j]; + int16_t sc_lo = 2 * (sc & 0x0F) + 1; + int16_t sc_hi = 2 * (sc >> 4) + 1; + + vint32m1_t sum_v0 = __riscv_vwredsum_vs_i16m8_i32m1( + __riscv_vslidedown_vx_i16m8(dot_all, j * 32, 16), zero_vec, 16); + int32_t isum0 = __riscv_vmv_x_s_i32m1_i32(sum_v0); + + vint32m1_t sum_v1 = __riscv_vwredsum_vs_i16m8_i32m1( + __riscv_vslidedown_vx_i16m8(dot_all, j * 32 + 16, 16), zero_vec, 16); + int32_t isum1 = __riscv_vmv_x_s_i32m1_i32(sum_v1); + + sum += (float)isum0 * sc_lo + (float)isum1 * sc_hi; + } + + sumf += sum * combined_scale; } *s = 0.125f * sumf; } #endif void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq2_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -3105,8 +4510,8 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v case 256: ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + default: // 512 and above + ggml_vec_dot_iq2_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -3114,7 +4519,7 @@ void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -3299,24 +4704,99 @@ static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRIC } *s = 0.125f * sumf; } + +static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xxs_grid; + // Shift pattern {0,7,14,21} repeated 8 times for all 8 sub-blocks + uint8_t shift_arr[32] = { + 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, + 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21 + }; + vuint8mf2_t v_shifts = __riscv_vle8_v_u8mf2(shift_arr, 32); + + // Gather pattern to broadcast the 8 sub-block scales across the 32 lookup slots + uint8_t gather_arr[32] = { + 0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3, + 4,4,4,4, 5,5,5,5, 6,6,6,6, 7,7,7,7 + }; + vuint8mf2_t v_sign_gather_idx = __riscv_vle8_v_u8mf2(gather_arr, 32); + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + vint8m4_t q8_all = __riscv_vle8_v_i8m4(q8, 256); + + // De-interleave all 8 Index/Scale pairs for the 8x32-element sub-blocks + vuint32mf2x2_t tuple = __riscv_vlseg2e32_v_u32mf2x2((const uint32_t*)q2_ptr, 8); + vuint32mf2_t v_ind32 = __riscv_vget_v_u32mf2x2_u32mf2(tuple, 0); + vuint32mf2_t v_sc32 = __riscv_vget_v_u32mf2x2_u32mf2(tuple, 1); + + vuint8mf2_t v_raw_q2 = __riscv_vreinterpret_v_u32mf2_u8mf2(v_ind32); + vuint16m1_t vidx_q2 = __riscv_vwcvtu_x_x_v_u16m1(v_raw_q2, 32); + vidx_q2 = __riscv_vsll_vx_u16m1(vidx_q2, 3, 32); + + vuint32m2_t v_s = __riscv_vrgatherei16_vv_u32m2(__riscv_vlmul_ext_v_u32mf2_u32m2(v_sc32), __riscv_vwcvtu_x_x_v_u16m1(v_sign_gather_idx,32), 32); + v_s = __riscv_vsrl_vv_u32m2(v_s, __riscv_vwcvtu_x_x_v_u32m2(__riscv_vwcvtu_x_x_v_u16m1(v_shifts,32),32), 32); + v_s = __riscv_vand_vx_u32m2(v_s, 127, 32); + vuint16m1_t vidx_s2 = __riscv_vsll_vx_u16m1(__riscv_vncvt_x_x_w_u16m1(v_s, 32), 3, 32); + + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_q2, 32); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_s2, 32); + vint8m4_t q2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t s2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + + vint8m4_t q8s_all = __riscv_vmul_vv_i8m4(q8_all, s2_all, 256); + vint16m8_t dot_all = __riscv_vwmul_vv_i16m8(q8s_all, q2_all, 256); + + float sum = 0.0f; + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + + for (int j = 0; j < 8; ++j) { + uint32_t s_p = __riscv_vmv_x_s_u32mf2_u32(__riscv_vslidedown_vx_u32mf2(v_sc32, j, 8)); + int16_t sc = 2 * ((s_p >> 28) & 0xF) + 1; + dot_all=__riscv_vslidedown_vx_i16m8(dot_all,j*32,32); + vint32m1_t sum_v = __riscv_vwredsum_vs_i16m8_i32m1(dot_all, zero_vec, 32); + int32_t isum = __riscv_vmv_x_s_i32m1_i32(sum_v); + sum += (float)isum * sc; + } + + sumf += sum * combined_scale; + } + *s = 0.125f * sumf; +} #endif void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - default: // 256 and above + case 256: ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + default: // 512 and above + ggml_vec_dot_iq2_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; } #else ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -3506,19 +4986,108 @@ static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT } *s = sumf; } + +static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint32_t * grid32 = (const uint32_t *)iq3s_grid; + + // Generate Constants + vuint8mf2_t v_id_32 = __riscv_vid_v_u8mf2(32); + vuint8mf2_t v_qh_gather = __riscv_vsrl_vx_u8mf2(v_id_32, 3, 32); + vuint8mf2_t v_qh_shifts = __riscv_vand_vx_u8mf2(v_id_32, 7, 32); + vuint8m2_t v_id_128 = __riscv_vid_v_u8m2(128); + vuint8m2_t v_sign_gather = __riscv_vsrl_vx_u8m2(v_id_128, 3, 128); // byte index + vuint8m2_t v_sign_shift_amts = __riscv_vand_vx_u8m2(v_id_128, 7, 128); // bit shift + vuint8m2_t v_one_128 = __riscv_vmv_v_x_u8m2(1, 128); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_one_128, v_sign_shift_amts, 128); + vuint8m2_t v_scale_indices = __riscv_vsrl_vx_u8m2(v_id_128, 5, 128); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum_block = 0.0f; + for (int ib = 0; ib < 2; ++ib) { + vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 32); + qs += 32; + vuint8mf2_t v_qh_loaded = __riscv_vle8_v_u8mf2(qh, 4); + qh += 4; + vuint8mf2_t v_qh_expanded = __riscv_vrgather_vv_u8mf2(v_qh_loaded, v_qh_gather, 32); + v_qh_expanded = __riscv_vsrl_vv_u8mf2(v_qh_expanded, v_qh_shifts, 32); + v_qh_expanded = __riscv_vand_vx_u8mf2(v_qh_expanded, 1, 32); + vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 32); + v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 32); // * 4 + + vuint16m1_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qh_expanded, 32); + v_qh_u16 = __riscv_vsll_vx_u16m1(v_qh_u16, 10, 32); // * 256 * 4 + + vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_u16, 32); + vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2(grid32, v_grid_offsets, 32); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); + vuint8mf2_t v_signs_raw = __riscv_vle8_v_u8mf2(signs, 16); + signs += 16; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf2_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather, 128); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 128); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 128); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 128); + vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 128); + uint16_t sc_raw; + memcpy(&sc_raw, scales, 2); + scales += 2; // Advance 2 bytes + + uint8_t sc_unpacked[4]; + sc_unpacked[0] = (sc_raw & 0xF); + sc_unpacked[1] = (sc_raw >> 4) & 0xF; + sc_unpacked[2] = (sc_raw >> 8) & 0xF; + sc_unpacked[3] = (sc_raw >> 12) & 0xF; + + vuint8mf2_t v_sc_4 = __riscv_vle8_v_u8mf2(sc_unpacked, 4); + v_sc_4 = __riscv_vmul_vx_u8mf2(v_sc_4, 2, 4); + v_sc_4 = __riscv_vadd_vx_u8mf2(v_sc_4, 1, 4); + vuint8m2_t v_sc_4_expanded = __riscv_vlmul_ext_v_u8mf2_u8m2(v_sc_4); + vuint8m2_t v_scales_bcast = __riscv_vrgather_vv_u8m2(v_sc_4_expanded, v_scale_indices, 128); + vint16m4_t v_scales_i16 = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vwcvtu_x_x_v_u16m4(v_scales_bcast, 128)); + vint32m8_t v_weighted_sum = __riscv_vwmul_vv_i32m8(v_dot, v_scales_i16, 128); + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + int32_t s_val = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m8_i32m1(v_weighted_sum, v_zero, 128)); + + sum_block += s_val; + } + sumf += sum_block * combined_scale; + } + *s = sumf; +} #endif void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: - ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; case 256: ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + default: // 512 and above + ggml_vec_dot_iq3_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -3526,7 +5095,7 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); @@ -3712,10 +5281,181 @@ static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRIC } *s = 0.25f * sumf; } + +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + // generate constants for unpacking metadata words into sign indices + vuint32m1_t v_shifts; + { + vuint32m1_t v_base = __riscv_vid_v_u32m1(16); + vuint32m1_t v_mod4 = __riscv_vand_vx_u32m1(v_base, 3, 16); + v_shifts = __riscv_vmul_vx_u32m1(v_mod4, 7, 16); + } + + vuint16mf2_t v_gather_idx; + { + vuint16mf2_t v_idx = __riscv_vid_v_u16mf2(16); + v_gather_idx = __riscv_vsrl_vx_u16mf2(v_idx, 2, 16); + } + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float block_sum = 0.0f; + for (int ib128 = 0; ib128 < 2; ++ib128) { + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + vuint8mf2_t v_q3_idx_u8 = __riscv_vle8_v_u8mf2(q3_indices, 32); + q3_indices += 32; + + vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_u8, 4, 32); + vuint32m2_t v_q3_mag_u32 = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 32); + vint8m2_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u32m2_u8m2(v_q3_mag_u32)); + vuint32m1_t v_aux = __riscv_vreinterpret_v_u8m1_u32m1(__riscv_vle8_v_u8m1(metadata, 16)); + metadata += 4 * sizeof(uint32_t); + + vuint32m1_t v_aux_expanded = __riscv_vrgatherei16_vv_u32m1(v_aux, v_gather_idx, 16); + + vuint32m1_t v_s_raw = __riscv_vand_vx_u32m1( + __riscv_vsrl_vv_u32m1(v_aux_expanded, v_shifts, 16), 127, 16); + vuint16mf2_t sign_byte_offset = __riscv_vsll_vx_u16mf2( + __riscv_vncvt_x_x_w_u16mf2(v_s_raw, 16), 3, 16); + vuint64m2_t v_s_u64 = __riscv_vluxei16_v_u64m2(signs64, sign_byte_offset, 16); + vint8m2_t v_signs = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u64m2_u8m2(v_s_u64)); + vint8m2_t v_q3_signed = __riscv_vmul_vv_i8m2(v_q3_magnitudes, v_signs, 128); + vint16m4_t prod = __riscv_vwmul_vv_i16m4(v_q3_signed, v_q8, 128); + + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + int32_t group0_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 32)); + int32_t group1_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 32)); + int32_t group2_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 32)); + int32_t group3_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 32)); + + vuint32m1_t v_scales_raw = __riscv_vsrl_vx_u32m1(v_aux, 28, 4); + vuint32m1_t v_scales = __riscv_vadd_vx_u32m1( + __riscv_vsll_vx_u32m1(v_scales_raw, 1, 4), + 1, 4); + int32_t scale0 = (int32_t)__riscv_vmv_x_s_u32m1_u32(v_scales); + int32_t scale1 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 1, 4)); + int32_t scale2 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 2, 4)); + int32_t scale3 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 3, 4)); + + block_sum += (float)(group0_sum * scale0 + group1_sum * scale1 + + group2_sum * scale2 + group3_sum * scale3); + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + vuint32m1_t v_shifts; + { + vuint32m1_t v_id = __riscv_vid_v_u32m1(32); + vuint32m1_t v_mod4 = __riscv_vand_vx_u32m1(v_id, 3, 32); + v_shifts = __riscv_vmul_vx_u32m1(v_mod4, 7, 32); + } + vuint16mf2_t v_gather_idx; + { + vuint16mf2_t v_id_16 = __riscv_vid_v_u16mf2(32); + v_gather_idx = __riscv_vsrl_vx_u16mf2(v_id_16, 2, 32); + } + + float sumf = 0.0f; + uint32_t aux32[8]; // Buffer for block metadata + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 256); + vuint8mf2_t v_q3_idx_raw = __riscv_vle8_v_u8mf2(q3_indices, 64); + vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_raw, 4, 64); + + vuint32m2_t v_q3_grid_vals = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 64); + + vint8m2_t v_q3_mags = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u32m2_u8m2(v_q3_grid_vals)); + + memcpy(aux32, metadata, 8 * sizeof(uint32_t)); + vuint32m1_t v_aux_8 = __riscv_vle32_v_u32m1(aux32, 8); + + vuint32m1_t v_aux_32 = __riscv_vrgatherei16_vv_u32m1(v_aux_8, v_gather_idx, 32); + + vuint32m1_t v_sign_idx_raw = __riscv_vand_vx_u32m1( + __riscv_vsrl_vv_u32m1(v_aux_32, v_shifts, 32), 127, 32); + + vuint16mf2_t v_sign_offsets = __riscv_vsll_vx_u16mf2( + __riscv_vncvt_x_x_w_u16mf2(v_sign_idx_raw, 32), 3, 32); + + vuint64m2_t v_signs_u64 = __riscv_vluxei16_v_u64m2(signs64, v_sign_offsets, 32); + + vint8m2_t v_signs = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u64m2_u8m2(v_signs_u64)); + + vint8m2_t v_q3_final = __riscv_vmul_vv_i8m2(v_q3_mags, v_signs, 256); + + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_q8, v_q3_final, 256); + float block_sum = 0.0f; + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + vint16m4_t v_accum = v_dot; + + for (int j = 0; j < 8; ++j) { + float scale = (float)(2 * (aux32[j] >> 28) + 1); + + vint32m1_t v_partial_sum = __riscv_vwredsum_vs_i16m4_i32m1(v_accum, v_zero, 32); + + int32_t partial_sum_i = __riscv_vmv_x_s_i32m1_i32(v_partial_sum); + block_sum += partial_sum_i * scale; + v_accum = __riscv_vslidedown_vx_i16m4(v_accum, 32, 32); + + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} #endif void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq3_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -3723,8 +5463,11 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const case 256: ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + case 512: + ggml_vec_dot_iq3_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 1024 and above + ggml_vec_dot_iq3_xxs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -3732,7 +5475,7 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -3847,7 +5590,7 @@ static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT #endif void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -3861,7 +5604,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -4007,10 +5750,205 @@ static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } + +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); + float sumf = 0; + + // Indices for re-ordering IQ4 data. + const uint16_t index[32] = { + 0, 1, 16, 17, + 2, 3, 18, 19, + 4, 5,20, 21, + 6, 7, 22, 23, + 8, 9, 24, 25, + 10, 11, 26, 27, + 12, 13,28, 29, + 14, 15, 30, 31, + }; + const vuint16m1_t i_vec = __riscv_vle16_v_u16m1(index, 32); + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi = 0; + + #pragma GCC unroll 1 + // Process the entire super-block together. + for (int ib = 0; ib < QK_K / 256; ++ib) { + // Weights and activations. + const vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 128); + iq4 += 128; + + // Unpack the weight blocks. + const vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 128); + const vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 128); + const vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi); + const vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgatherei16_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 32)); + const vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 256); + + __asm__ __volatile__("" ::: "memory"); + + // Multiply with activations. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 256); + const vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 256); + q8 += 256; + + // Reduce separately. + const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), __riscv_vmv_v_x_i32m1(0, 1), 32)); + + + const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls1 = ((x[ibl].scales_l[0] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls3 = ((x[ibl].scales_l[1] >> 4) | ((h >> 2) & 0x30)) - 32; + h >>= 8; + const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls5 = ((x[ibl].scales_l[2] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls7 = ((x[ibl].scales_l[3] >> 4) | ((h >> 2) & 0x30)) - 32; + + sumi += acc0 * ls0; + sumi += acc1 * ls1; + sumi += acc2 * ls2; + sumi += acc3 * ls3; + sumi += acc4 * ls4; + sumi += acc5 * ls5; + sumi += acc6 * ls6; + sumi += acc7 * ls7; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); + float sumf = 0; + + // Indices for re-ordering IQ4 data. + const uint16_t index[32] = { + 0, 1, 16, 17, + 2, 3, 18, 19, + 4, 5,20, 21, + 6, 7, 22, 23, + 8, 9, 24, 25, + 10, 11, 26, 27, + 12, 13,28, 29, + 14, 15, 30, 31, + }; + const vuint16mf2_t i_vec = __riscv_vle16_v_u16mf2(index, 32); + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi = 0; + + #pragma GCC unroll 1 + // Process the entire super-block together. + for (int ib = 0; ib < QK_K / 256; ++ib) { + // Weights and activations. + const vuint8m1_t iq4_packed = __riscv_vle8_v_u8m1(iq4, 128); + iq4 += 128; + + // Unpack the weight blocks. + const vuint8m1_t iq4bits_lo = __riscv_vand_vx_u8m1(iq4_packed, 0xf, 128); + const vuint8m1_t iq4bits_hi = __riscv_vsrl_vx_u8m1(iq4_packed, 4, 128); + const vuint8m2_t iq4bits = __riscv_vcreate_v_u8m1_u8m2(iq4bits_lo, iq4bits_hi); + const vuint8m2_t iq4bits_reorder = __riscv_vreinterpret_v_u64m2_u8m2(__riscv_vrgatherei16_vv_u64m2(__riscv_vreinterpret_v_u8m2_u64m2(iq4bits), i_vec, 32)); + const vint8m2_t iq4b = __riscv_vrgather_vv_i8m2(values, iq4bits_reorder, 256); + + __asm__ __volatile__("" ::: "memory"); + + // Multiply with activations. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 256); + const vint16m4_t prod = __riscv_vwmul_vv_i16m4(iq4b, q8b, 256); + q8 += 256; + + // Mask for processing 32 elements per prod register. + const vuint16m1_t p_index = __riscv_vid_v_u16m1(64); + const vbool16_t p_mask = __riscv_vmsgtu_vx_u16m1_b16(p_index, 31, 64); + + // Reduce separately. + const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 64)); + + const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls1 = ((x[ibl].scales_l[0] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls3 = ((x[ibl].scales_l[1] >> 4) | ((h >> 2) & 0x30)) - 32; + h >>= 8; + const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls5 = ((x[ibl].scales_l[2] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls7 = ((x[ibl].scales_l[3] >> 4) | ((h >> 2) & 0x30)) - 32; + + sumi += acc0 * ls0; + sumi += acc1 * ls1; + sumi += acc2 * ls2; + sumi += acc3 * ls3; + sumi += acc4 * ls4; + sumi += acc5 * ls5; + sumi += acc6 * ls6; + sumi += acc7 * ls7; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi); + } + + *s = sumf; +} #endif void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_iq4_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -4018,6 +5956,12 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v case 256: ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; + case 512: + ggml_vec_dot_iq4_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_iq4_xs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; default: ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); break; @@ -4027,7 +5971,7 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -4230,10 +6174,112 @@ static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT *s = sumf; } + +static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + for (int i = 0; i < nb; i++) { + // First loop. + vint16m1_t suml1; + { + const int vl = 32; + vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs, vl); + + vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3, vl), 8, vl); + vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl); + vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl); + vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl); + vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl); + + vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 0, vl), vl); + vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 32, vl), vl); + vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 64, vl), vl); + vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 96, vl), vl); + vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 128, vl), vl); + + vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); + vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl); + vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl); + vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl); + vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl); + + vint16m1_t sumi0 = __riscv_vadd_vv_i16m1(sum0, sum1, vl); + vint16m1_t sumi1 = __riscv_vadd_vv_i16m1(sum2, sum3, vl); + suml1 = __riscv_vadd_vv_i16m1(sum4, __riscv_vadd_vv_i16m1(sumi0, sumi1, vl), vl); + } + + // Second loop. + vint16mf2_t suml2; + { + const int vl = 16; + vuint8mf4_t tq = __riscv_vle8_v_u8mf4(x[i].qs + 32, vl); + + vuint16mf2_t tq0 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(tq, 3 * 1, vl), 8, vl); + vuint16mf2_t tq1 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 3, vl), 3, vl), 8, vl); + vuint16mf2_t tq2 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 9, vl), 3, vl), 8, vl); + vuint16mf2_t tq3 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 27, vl), 3, vl), 8, vl); + vuint16mf2_t tq4 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 81, vl), 3, vl), 8, vl); + + vint16mf2_t q80 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 160, vl), vl); + vint16mf2_t q81 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 176, vl), vl); + vint16mf2_t q82 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 192, vl), vl); + vint16mf2_t q83 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 208, vl), vl); + vint16mf2_t q84 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 224, vl), vl); + + vint16mf2_t sum0 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq0, 1, vl)), q80, vl); + vint16mf2_t sum1 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq1, 1, vl)), q81, vl); + vint16mf2_t sum2 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq2, 1, vl)), q82, vl); + vint16mf2_t sum3 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq3, 1, vl)), q83, vl); + vint16mf2_t sum4 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq4, 1, vl)), q84, vl); + + vint16mf2_t sumi0 = __riscv_vadd_vv_i16mf2(sum0, sum1, vl); + vint16mf2_t sumi1 = __riscv_vadd_vv_i16mf2(sum2, sum3, vl); + suml2 = __riscv_vadd_vv_i16mf2(sum4, __riscv_vadd_vv_i16mf2(sumi0, sumi1, vl), vl); + } + + // Third loop. + vint16mf2_t suml3; + { + const int vl = 16; + + uint32_t qh; + memcpy(&qh, &x[i].qh[0], 4); + // Prevent fusion with vmv. + __asm__ __volatile__("" : "+r"(qh)); + vuint8mf4_t tq = __riscv_vlmul_trunc_v_u8mf2_u8mf4(__riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4))); + + vuint8mf4_t p = __riscv_vle8_v_u8mf4(pow, vl); + + vuint16mf2_t tq0 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vv_u8mf4(tq, p, vl), 3, vl), 8, vl); + + vint16mf2_t q80 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 240, vl), vl); + + suml3 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq0, 1, vl)), q80, vl); + } + + vint32m1_t sum = __riscv_vwredsum_vs_i16m1_i32m1(suml1, __riscv_vmv_v_x_i32m1(0, 1), 32); + sum = __riscv_vwredsum_vs_i16mf2_i32m1(__riscv_vadd_vv_i16mf2(suml2, suml3, 16), sum, 16); + sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + } + + *s = sumf; +} #endif void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_tq1_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); @@ -4241,8 +6287,8 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo case 256: ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + default: // 512 and above + ggml_vec_dot_tq1_0_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); break; } #else @@ -4250,7 +6296,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl128(const int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -4406,24 +6452,21 @@ static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT #endif void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_tq2_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); break; - case 256: + default: // 256 and above ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); break; - default: - ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); - break; } #else ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -#if defined __riscv_v_intrinsic +#if defined __riscv_v static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -4538,7 +6581,7 @@ static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT #endif void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { -#if defined __riscv_v_intrinsic +#if defined __riscv_v switch (__riscv_vlenb() * 8) { case 128: ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); From e9dbd0c18a1904b84c2b75b8bff81ff6ecb6c886 Mon Sep 17 00:00:00 2001 From: Reese Levine <reeselevine1@gmail.com> Date: Wed, 3 Jun 2026 22:05:04 -0700 Subject: [PATCH 784/831] ggml-webgpu: FlashAttention refactor + standardize quantization support (llama/23834) * Start work on flash_attn refactor * Refactor * Split k/v quantization * Refactor and abstract quantization logic for flash_attn and mul_mat * Add quantization support to tile path * formatting * Move to functions, add a check --- ggml/src/ggml-webgpu/CMakeLists.txt | 7 +- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 659 +++++++++--------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 416 ++++++----- ggml/src/ggml-webgpu/pre_wgsl.hpp | 44 +- .../ggml-webgpu/wgsl-shaders/flash_attn.wgsl | 271 ++----- .../flash_attn_quant_staging.tmpl | 124 ++++ .../wgsl-shaders/flash_attn_tile.wgsl | 126 ++-- .../wgsl-shaders/flash_attn_vec_split.wgsl | 247 ++----- .../wgsl-shaders/mul_mat_decls.tmpl | 20 +- .../wgsl-shaders/quant_inner_loops.tmpl | 21 + 10 files changed, 985 insertions(+), 950 deletions(-) create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl create mode 100644 ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index 3ccce58aa39..1503a1ef8ba 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -10,8 +10,11 @@ file(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR}) message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}") -# Find all WGSL files -file(GLOB WGSL_SHADER_FILES "${SHADER_DIR}/*.wgsl") +# Find all WGSL sources +file(GLOB WGSL_SHADER_FILES + "${SHADER_DIR}/*.wgsl" + "${SHADER_DIR}/*.tmpl" +) # Generate the header using a Python script add_custom_command( diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index f4c5eca0df5..a5e7de785b4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -18,6 +18,9 @@ #define GGML_WEBGPU_F32_SIZE_BYTES 4 #define GGML_WEBGPU_I32_SIZE_BYTES 4 #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u +#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN 20u +#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE 32u +#define GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE 64u #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u // Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing. #define GGML_WEBGPU_KV_SEQ_PAD 256u @@ -546,16 +549,10 @@ struct ggml_webgpu_unary_pipeline_key_hash { /** FlashAttention */ -enum ggml_webgpu_flash_attn_path : uint32_t { - GGML_WEBGPU_FLASH_ATTN_PATH_NONE = 0u, - GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX = 1u, - GGML_WEBGPU_FLASH_ATTN_PATH_TILE = 2u, - GGML_WEBGPU_FLASH_ATTN_PATH_VEC = 3u, -}; - -struct ggml_webgpu_flash_attn_pipeline_key { +struct ggml_webgpu_flash_attn_common_pipeline_key { ggml_type q_type; - ggml_type kv_type; + ggml_type k_type; + ggml_type v_type; ggml_type dst_type; uint32_t head_dim_qk; uint32_t head_dim_v; @@ -564,93 +561,224 @@ struct ggml_webgpu_flash_attn_pipeline_key { bool has_mask; bool has_sinks; bool uses_logit_softcap; - uint32_t path; + + bool operator==(const ggml_webgpu_flash_attn_common_pipeline_key & other) const { + return q_type == other.q_type && k_type == other.k_type && v_type == other.v_type && + dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && + kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && + has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap; + } +}; + +inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t & seed, + const ggml_webgpu_flash_attn_common_pipeline_key & key) { + ggml_webgpu_hash_combine(seed, key.q_type); + ggml_webgpu_hash_combine(seed, key.k_type); + ggml_webgpu_hash_combine(seed, key.v_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + ggml_webgpu_hash_combine(seed, key.head_dim_qk); + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.kv_direct); + ggml_webgpu_hash_combine(seed, key.kv_overlap); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sinks); + ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); +} + +struct ggml_webgpu_flash_attn_vec_pipeline_key { + ggml_webgpu_flash_attn_common_pipeline_key common; + + bool operator==(const ggml_webgpu_flash_attn_vec_pipeline_key & other) const { return common == other.common; } +}; + +struct ggml_webgpu_flash_attn_vec_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_vec_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common); + return seed; + } +}; + +struct ggml_webgpu_flash_attn_pipeline_key { + ggml_webgpu_flash_attn_common_pipeline_key common; + bool use_sg_matrix; bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { - return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && - head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && - kv_overlap == other.kv_overlap && has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap && path == other.path; + return common == other.common && use_sg_matrix == other.use_sg_matrix; } }; struct ggml_webgpu_flash_attn_pipeline_key_hash { size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const { size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.q_type); - ggml_webgpu_hash_combine(seed, key.kv_type); - ggml_webgpu_hash_combine(seed, key.dst_type); - ggml_webgpu_hash_combine(seed, key.head_dim_qk); - ggml_webgpu_hash_combine(seed, key.head_dim_v); - ggml_webgpu_hash_combine(seed, key.kv_direct); - ggml_webgpu_hash_combine(seed, key.kv_overlap); - ggml_webgpu_hash_combine(seed, key.has_mask); - ggml_webgpu_hash_combine(seed, key.has_sinks); - ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); - ggml_webgpu_hash_combine(seed, key.path); + ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common); + ggml_webgpu_hash_combine(seed, key.use_sg_matrix); return seed; } }; +struct ggml_webgpu_flash_attn_vec_decisions { + uint32_t kv_tile = 0; + uint32_t wg_size = 0; +}; + struct ggml_webgpu_flash_attn_decisions { - uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; - uint32_t q_tile = 0; - uint32_t kv_tile = 0; - uint32_t wg_size = 0; - bool kv_direct = false; - bool kv_overlap = false; + bool use_sg_matrix = false; + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; }; inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u; inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u; -inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) { - if (key.path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC || key.kv_type != GGML_TYPE_F16 || - key.head_dim_qk != key.head_dim_v) { - return 1u; +inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) { + constexpr uintptr_t ptr_base_addr = 0x1000u; + const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor; + return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs; +} + +inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) { + const uint32_t offset_elems = + (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / ggml_type_size(K->type)); + return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u; +} + +inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, + const ggml_tensor * V, + size_t storage_offset_alignment) { + return ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment) && + ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); +} + +inline bool ggml_webgpu_flash_attn_kv_direct( + const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, uint32_t kv_direct_align) { + return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); +} + +inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_common_pipeline_key( + const ggml_webgpu_shader_lib_context & context, + uint32_t kv_direct_align) { + ggml_webgpu_flash_attn_common_pipeline_key key = {}; + key.q_type = context.src0->type; + key.k_type = context.src1->type; + key.v_type = context.src2->type; + key.dst_type = context.dst->type; + key.head_dim_qk = (uint32_t) context.src0->ne[0]; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align); + key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); + key.has_mask = context.src3 != nullptr; + key.has_sinks = context.src4 != nullptr; + key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; + return key; +} + +inline std::vector<std::string> ggml_webgpu_flash_attn_common_defines( + const ggml_webgpu_flash_attn_common_pipeline_key & key, + std::string & variant, + uint32_t q_tile, + uint32_t kv_tile, + uint32_t wg_size) { + std::vector<std::string> defines; + + switch (key.k_type) { + case GGML_TYPE_F32: + defines.push_back("K_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("K_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("K_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("K_Q8_0"); + break; + default: + GGML_ABORT("Unsupported K type for flash attention shader"); + } + variant += std::string("_k") + ggml_type_name(key.k_type); + + switch (key.v_type) { + case GGML_TYPE_F32: + defines.push_back("V_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("V_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("V_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("V_Q8_0"); + break; + default: + GGML_ABORT("Unsupported V type for flash attention shader"); + } + variant += std::string("_v") + ggml_type_name(key.v_type); + + switch (key.q_type) { + case GGML_TYPE_F32: + defines.push_back("Q_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("Q_F16"); + break; + default: + GGML_ABORT("Unsupported Q type for flash attention shader"); } + variant += std::string("_q") + ggml_type_name(key.q_type); - switch (key.head_dim_qk) { - case 64: - case 192: - case 576: - return 2u; - case 96: - return 4u; + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + break; default: - return 1u; + GGML_ABORT("Unsupported dst type for flash attention shader"); } -} + variant += std::string("_dst") + ggml_type_name(key.dst_type); -inline ggml_webgpu_flash_attn_pipeline_key ggml_webgpu_flash_attn_make_pipeline_key( - const ggml_webgpu_shader_lib_context & context, - const ggml_webgpu_flash_attn_decisions & decisions) { - const bool has_mask = context.src3 != nullptr; - const bool has_sinks = context.src4 != nullptr; - bool kv_direct = false; - if (decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - uint32_t kv_direct_align = GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH; - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { - kv_direct_align = context.sg_mat_k; - } - kv_direct = (context.src1->type == GGML_TYPE_F16) && - (context.src0->ne[0] % std::max(1u, kv_direct_align) == 0) && - (context.src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - } - - ggml_webgpu_flash_attn_pipeline_key key = {}; - key.q_type = context.src0->type; - key.kv_type = context.src1->type; - key.dst_type = context.dst->type; - key.head_dim_qk = (uint32_t) context.src0->ne[0]; - key.head_dim_v = (uint32_t) context.src2->ne[0]; - key.kv_direct = kv_direct; - key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); - key.has_mask = has_mask; - key.has_sinks = has_sinks; - key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; - key.path = decisions.path; - return key; + if (key.has_mask) { + defines.push_back("MASK"); + variant += "_mask"; + } + if (key.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (key.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + if (key.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + if (key.kv_overlap) { + defines.push_back("KV_OVERLAP"); + variant += "_kv_overlap"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + if (ggml_is_quantized(key.k_type) || ggml_is_quantized(key.v_type)) { + defines.push_back("U32_DEQUANT_HELPERS"); + } + + return defines; } struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { @@ -688,29 +816,18 @@ struct ggml_webgpu_flash_attn_blk_pipeline_key_hash { } }; -// This is exposed because it's necessary in supports_op +// Note: this will slightly overestimate memory usage for vec path +// since row_max and exp_sum shmem are not needed. inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, uint32_t kv_tile, uint32_t head_dim_qk, uint32_t head_dim_v, bool has_mask, - bool kv_direct, - uint32_t path = GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX) { + bool kv_direct) { const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); size_t f16_elems = 0; size_t f32_elems = 0; - if (path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - f32_elems += head_dim_qk; // q_shmem - if (!kv_direct) { - f32_elems += kv_tile * max_head_dim; // kv_shmem - } - f32_elems += head_dim_v; // o_shmem - if (has_mask) { - f32_elems += kv_tile; // mask_shmem - } - f32_elems += kv_tile; // inter_shmem - return f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; - } + f32_elems += q_tile * head_dim_qk; // q_shmem if (!kv_direct) { f32_elems += kv_tile * max_head_dim; // kv_shmem @@ -725,25 +842,20 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } -inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_context & context, - const ggml_webgpu_flash_attn_pipeline_key & key) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_granularity = std::max(1u, context.sg_mat_n); - if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - q_tile = GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; - kv_granularity = 1u; - } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - q_tile = 1u; - kv_granularity = 8u; - } - const size_t base_q_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, key.head_dim_qk, key.head_dim_v, - key.has_mask, key.kv_direct, key.path); +inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(size_t limit_bytes, + uint32_t q_tile, + uint32_t kv_granularity, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask, + bool kv_direct) { + const size_t base_q_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, head_dim_qk, head_dim_v, has_mask, kv_direct); if (limit_bytes <= base_q_bytes) { return 0; } - const size_t one_kv_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, key.head_dim_qk, key.head_dim_v, - key.has_mask, key.kv_direct, key.path); + const size_t one_kv_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, head_dim_qk, head_dim_v, has_mask, kv_direct); const size_t bytes_per_kv = one_kv_bytes - base_q_bytes; if (bytes_per_kv == 0) { return 0; @@ -752,105 +864,32 @@ inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_shader_lib_ return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity); } -inline ggml_webgpu_flash_attn_decisions ggml_webgpu_flash_attn_get_decisions( - const ggml_webgpu_shader_lib_context & context, - size_t storage_offset_alignment) { - ggml_webgpu_flash_attn_decisions decisions = {}; - const size_t alignment = std::max<size_t>(1u, storage_offset_alignment); - const auto * K = context.src1; - const auto * V = context.src2; - GGML_ASSERT(K != nullptr); - GGML_ASSERT(V != nullptr); - - const auto flash_attn_tensor_offset = [](const ggml_tensor * tensor) -> size_t { - constexpr uintptr_t ptr_base_addr = 0x1000u; - const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor; - return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs; - }; - - const uint32_t k_offset_elems = - (uint32_t) ((flash_attn_tensor_offset(K) & (alignment - 1)) / ggml_type_size(K->type)); - const uint32_t v_offset_elems = - (uint32_t) ((flash_attn_tensor_offset(V) & (alignment - 1)) / ggml_type_size(V->type)); - const bool f16_vec4_aligned = (k_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u) && - (v_offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u); - const bool kv_vec_type_supported = - K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; - const uint32_t kv_vec_head_align = - K->type == GGML_TYPE_F16 ? GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : (uint32_t) ggml_blck_size(K->type); - const bool kv_vec_head_dims_aligned = - context.src0->ne[0] % kv_vec_head_align == 0 && context.src2->ne[0] % kv_vec_head_align == 0; - // Compile with enough invocations to cover the largest reported subgroup. - const bool use_vec = context.supports_subgroups && (context.src0->ne[1] < 20) && kv_vec_head_dims_aligned && - kv_vec_type_supported && (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && - (context.src2->type == K->type); - const bool tile_can_dispatch_all_q_rows = - context.max_subgroup_size > 0 && - context.max_wg_size >= GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size; - const bool use_subgroup_matrix = context.supports_subgroup_matrix && context.sg_mat_k > 0 && context.sg_mat_n > 0 && - context.src0->ne[0] % context.sg_mat_k == 0 && - context.src2->ne[0] % context.sg_mat_n == 0; - const bool use_tile = context.supports_subgroups && !use_subgroup_matrix && K->type == GGML_TYPE_F16 && - V->type == GGML_TYPE_F16 && f16_vec4_aligned && - (context.src0->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - (context.src2->ne[0] % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0) && - tile_can_dispatch_all_q_rows && !use_vec; - - decisions.path = use_vec ? GGML_WEBGPU_FLASH_ATTN_PATH_VEC : - use_tile ? GGML_WEBGPU_FLASH_ATTN_PATH_TILE : - use_subgroup_matrix ? GGML_WEBGPU_FLASH_ATTN_PATH_SUBGROUP_MATRIX : - GGML_WEBGPU_FLASH_ATTN_PATH_NONE; - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { - return decisions; - } - - const ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); - decisions.kv_direct = key.kv_direct; - const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile(context, key); - // invalidate if even the smallest kv_tile doesn't fit in shared memory - if (max_kv_tile == 0) { - decisions.path = GGML_WEBGPU_FLASH_ATTN_PATH_NONE; - return decisions; - } - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - decisions.q_tile = 1u; - decisions.kv_tile = std::max(8u, std::min(32u, max_kv_tile)); - decisions.kv_tile = (decisions.kv_tile / 8u) * 8u; - decisions.wg_size = context.max_subgroup_size; - if (decisions.kv_direct) { - decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { - decisions.kv_tile -= 8u; - } +inline uint32_t ggml_webgpu_flash_attn_get_vec_kv_tile(size_t wg_mem_limit_bytes, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask, + bool kv_direct) { + const uint32_t max_kv_tile = + ggml_webgpu_flash_attn_max_kv_tile(wg_mem_limit_bytes, 1u, 1u, head_dim_qk, head_dim_v, has_mask, kv_direct); + GGML_ASSERT(max_kv_tile > 0); + + uint32_t kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile); + if (kv_direct) { + kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= 1u; } - return decisions; } - decisions.q_tile = - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE : context.sg_mat_m; - decisions.kv_tile = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::min(64u, max_kv_tile) : - std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - decisions.wg_size = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? - std::min(std::max(1u, context.max_wg_size), - std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, - GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)) : - std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - - if (decisions.kv_tile == 0) { - return decisions; - } + return kv_tile; +} - if (decisions.kv_direct) { - GGML_ASSERT(decisions.kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { - decisions.kv_tile -= - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? context.min_subgroup_size : context.sg_mat_n; - } - } - return decisions; +inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool supports_subgroup_matrix, + uint32_t sg_mat_k, + uint32_t sg_mat_n, + const ggml_tensor * Q, + const ggml_tensor * V) { + return supports_subgroup_matrix && Q->ne[0] % sg_mat_k == 0 && V->ne[0] % sg_mat_n == 0; } /** Matrix Multiplication **/ @@ -1123,6 +1162,10 @@ class ggml_webgpu_shader_lib { concat_pipelines; // type std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash> repeat_pipelines; // type + std::unordered_map<ggml_webgpu_flash_attn_vec_pipeline_key, + webgpu_pipeline, + ggml_webgpu_flash_attn_vec_pipeline_key_hash> + flash_attn_vec_pipelines; std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash> flash_attn_pipelines; std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key, @@ -1835,10 +1878,10 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_vec_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; key.use_mmvq = ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); @@ -1971,11 +2014,11 @@ class ggml_webgpu_shader_lib { ggml_webgpu_mul_mat_pipeline_key key = {}; key.src0_type = context.src0->type; key.src1_type = context.src1->type; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && - (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; - key.use_subgroup_matrix = context.supports_subgroup_matrix; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + key.use_subgroup_matrix = context.supports_subgroup_matrix; auto it = mul_mat_fast_pipelines.find(key); if (it != mul_mat_fast_pipelines.end()) { @@ -2148,10 +2191,10 @@ class ggml_webgpu_shader_lib { key.src0_type = context.src0->type; key.src1_type = context.src1->type; key.n_experts = context.src0->ne[2]; - key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_id_pipelines.find(key); if (it != mul_mat_id_pipelines.end()) { @@ -2271,10 +2314,10 @@ class ggml_webgpu_shader_lib { key.src0_type = context.src0->type; key.src1_type = context.src1->type; key.n_experts = context.src0->ne[2]; - key.vectorized = (context.src0->ne[0] % 4 == 0 && + key.vectorized = (context.src0->ne[0] % 4 == 0 && (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? - 1 : - 0; + 1 : + 0; auto it = mul_mat_id_vec_pipelines.find(key); if (it != mul_mat_id_vec_pipelines.end()) { @@ -2664,119 +2707,62 @@ class ggml_webgpu_shader_lib { return repeat_pipelines[key]; } - webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context, - size_t storage_offset_alignment) { - const ggml_webgpu_flash_attn_decisions decisions = - ggml_webgpu_flash_attn_get_decisions(context, storage_offset_alignment); - GGML_ASSERT(decisions.path != GGML_WEBGPU_FLASH_ATTN_PATH_NONE); - ggml_webgpu_flash_attn_pipeline_key key = ggml_webgpu_flash_attn_make_pipeline_key(context, decisions); - auto it = flash_attn_pipelines.find(key); - if (it != flash_attn_pipelines.end()) { - return it->second; - } - std::vector<std::string> defines; - std::string variant = decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC ? "flash_attn_vec" : - decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE ? "flash_attn_tile" : - "flash_attn"; - - switch (key.kv_type) { - case GGML_TYPE_F32: - defines.push_back("KV_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("KV_F16"); - break; - case GGML_TYPE_Q4_0: - defines.push_back("KV_Q4_0"); - break; - case GGML_TYPE_Q8_0: - defines.push_back("KV_Q8_0"); - break; - default: - GGML_ABORT("Unsupported KV type for flash attention shader"); - } - variant += std::string("_") + ggml_type_name(key.kv_type); - - switch (key.q_type) { - case GGML_TYPE_F32: - defines.push_back("Q_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("Q_F16"); - break; - default: - GGML_ABORT("Unsupported Q type for flash attention shader"); - } - variant += std::string("_q") + ggml_type_name(key.q_type); - - switch (key.dst_type) { - case GGML_TYPE_F32: - defines.push_back("DST_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("DST_F16"); - break; - default: - GGML_ABORT("Unsupported dst type for flash attention shader"); - } - variant += std::string("_dst") + ggml_type_name(key.dst_type); - - if (key.has_mask) { - defines.push_back("MASK"); - if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - defines.push_back("BLK"); - variant += "_mask_blk"; - } else { - variant += "_mask"; + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool can_use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( + context.supports_subgroup_matrix, context.sg_mat_k, context.sg_mat_n, context.src0, context.src2); + ggml_webgpu_flash_attn_decisions decisions = {}; + decisions.use_sg_matrix = can_use_subgroup_matrix; + decisions.q_tile = decisions.use_sg_matrix ? context.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; + + ggml_webgpu_flash_attn_pipeline_key key = {}; + key.common = + ggml_webgpu_flash_attn_make_common_pipeline_key(context, decisions.use_sg_matrix ? context.sg_mat_k : 1u); + key.common.kv_direct = decisions.use_sg_matrix && key.common.kv_direct; + key.use_sg_matrix = decisions.use_sg_matrix; + + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( + context.wg_mem_limit_bytes, decisions.q_tile, decisions.use_sg_matrix ? context.sg_mat_n : 1u, + key.common.head_dim_qk, key.common.head_dim_v, key.common.has_mask, key.common.kv_direct); + GGML_ASSERT(max_kv_tile > 0); + + decisions.kv_tile = decisions.use_sg_matrix ? + std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES) : + std::min(GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE, max_kv_tile); + decisions.wg_size = + decisions.use_sg_matrix ? + std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE) : + std::min(context.max_wg_size, std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)); + + if (key.common.kv_direct) { + decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { + decisions.kv_tile -= decisions.use_sg_matrix ? context.sg_mat_n : context.min_subgroup_size; } } - if (key.has_sinks) { - defines.push_back("SINKS"); - variant += "_sinks"; - } - if (key.uses_logit_softcap) { - defines.push_back("LOGIT_SOFTCAP"); - variant += "_lgsc"; - } - if (key.kv_direct) { - defines.push_back("KV_DIRECT"); - variant += "_kvdirect"; - } - if (key.kv_overlap) { - defines.push_back("KV_OVERLAP"); - variant += "_kv_overlap"; - } - - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(key.head_dim_v); + auto it = flash_attn_pipelines.find(key); + if (it != flash_attn_pipelines.end()) { + return it->second; + } - const char * shader_src = wgsl_flash_attn; - if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - defines.push_back("KV_GRANULARITY=8"); - defines.push_back(std::string("VEC_NE=") + std::to_string(ggml_webgpu_flash_attn_pick_vec_ne(key)) + "u"); - shader_src = wgsl_flash_attn_vec_split; - } else if (key.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { + std::string variant = decisions.use_sg_matrix ? "flash_attn" : "flash_attn_tile"; + std::vector<std::string> defines = ggml_webgpu_flash_attn_common_defines(key.common, variant, decisions.q_tile, + decisions.kv_tile, decisions.wg_size); + const char * shader_src = nullptr; + if (!key.use_sg_matrix) { shader_src = wgsl_flash_attn_tile; defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u"); defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); - defines.push_back("KV_STAGE_STRIDE=" + std::to_string(std::max(key.head_dim_qk, key.head_dim_v))); variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" + std::to_string(context.max_subgroup_size); } else { + shader_src = wgsl_flash_attn; defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); } - - auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions); - pipeline_decisions->kv_overlap = key.kv_overlap; - defines.push_back(std::string("Q_TILE=") + std::to_string(decisions.q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(decisions.kv_tile)); - defines.push_back(std::string("WG_SIZE=") + std::to_string(decisions.wg_size)); - + auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions); webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); pipeline.context = pipeline_decisions; @@ -2784,6 +2770,55 @@ class ggml_webgpu_shader_lib { return flash_attn_pipelines[key]; } + webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_vec_pipeline_key key = {}; + key.common = ggml_webgpu_flash_attn_make_common_pipeline_key(context, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); + + auto it = flash_attn_vec_pipelines.find(key); + if (it != flash_attn_vec_pipelines.end()) { + return it->second; + } + + ggml_webgpu_flash_attn_vec_decisions decisions = {}; + decisions.kv_tile = + ggml_webgpu_flash_attn_get_vec_kv_tile(context.wg_mem_limit_bytes, key.common.head_dim_qk, + key.common.head_dim_v, key.common.has_mask, key.common.kv_direct); + decisions.wg_size = context.max_subgroup_size; + + std::string variant = "flash_attn_vec"; + std::vector<std::string> defines = + ggml_webgpu_flash_attn_common_defines(key.common, variant, 1u, decisions.kv_tile, decisions.wg_size); + if (key.common.has_mask) { + defines.push_back("BLK"); + variant.resize(variant.size() - (sizeof("_mask") - 1)); + variant += "_mask_blk"; + } + uint32_t vec_ne = 1u; + if (key.common.k_type == GGML_TYPE_F16 && key.common.v_type == GGML_TYPE_F16 && + key.common.head_dim_qk == key.common.head_dim_v) { + switch (key.common.head_dim_qk) { + case 64: + case 192: + case 576: + vec_ne = 2u; + break; + case 96: + vec_ne = 4u; + break; + default: + break; + } + } + defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); + + auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>(decisions); + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant); + pipeline.context = pipeline_decisions; + flash_attn_vec_pipelines[key] = pipeline; + return flash_attn_vec_pipelines[key]; + } + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) { ggml_webgpu_flash_attn_blk_pipeline_key key = {}; key.kv_tile = kv_tile; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index d577b5afa3c..c6cfb0bbbad 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1755,13 +1755,50 @@ static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, return ggml_backend_webgpu_build_multi(ctx, dispatches); } -static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst) { +struct ggml_webgpu_flash_attn_op { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + std::vector<uint32_t> params; + std::vector<wgpu::BindGroupEntry> entries; + size_t kv_bind_offset = 0; + size_t kv_bind_size = 0; + bool has_mask = false; + bool has_sinks = false; + bool kv_overlap = false; +}; + +static bool ggml_webgpu_flash_attn_use_vec_path(const webgpu_global_context & global_ctx, + const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V) { + const size_t storage_offset_alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const bool k_float_vec4_aligned = (K->type != GGML_TYPE_F16 && K->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment); + const bool v_float_vec4_aligned = (V->type != GGML_TYPE_F16 && V->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); + const bool k_vec_type_supported = + K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + const bool v_vec_type_supported = + V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q8_0; + const uint32_t k_vec_head_align = (K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(K->type); + const uint32_t v_vec_head_align = (V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(V->type); + const bool kv_vec_head_dims_aligned = Q->ne[0] % k_vec_head_align == 0 && V->ne[0] % v_vec_head_align == 0; + + return global_ctx->capabilities.supports_subgroups && (Q->ne[1] < GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN) && + kv_vec_head_dims_aligned && k_vec_type_supported && v_vec_type_supported && k_float_vec4_aligned && + v_float_vec4_aligned; +} + +static ggml_webgpu_flash_attn_op ggml_webgpu_flash_attn_prepare(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { float scale = ggml_get_op_params_f32(dst, 0); float max_bias = ggml_get_op_params_f32(dst, 1); float logit_softcap = ggml_get_op_params_f32(dst, 2); @@ -1772,47 +1809,43 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = Q; - shader_lib_ctx.src1 = K; - shader_lib_ctx.src2 = V; - shader_lib_ctx.src3 = mask; - shader_lib_ctx.src4 = sinks; - shader_lib_ctx.dst = dst; - shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; - shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; - shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; - shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; - shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; - shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; - shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; - webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline( - shader_lib_ctx, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get()); - const int has_mask = (mask != nullptr); - const int has_sinks = (sinks != nullptr); - const bool kv_overlap = decisions->kv_overlap; - - uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); - uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); - size_t kv_bind_offset = 0; - size_t kv_bind_size = 0; - if (kv_overlap) { + ggml_webgpu_flash_attn_op op = {}; + op.shader_lib_ctx.src0 = Q; + op.shader_lib_ctx.src1 = K; + op.shader_lib_ctx.src2 = V; + op.shader_lib_ctx.src3 = mask; + op.shader_lib_ctx.src4 = sinks; + op.shader_lib_ctx.dst = dst; + op.shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + op.shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + op.shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + op.shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + op.shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + op.shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + op.shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + op.shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; + op.shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + + op.has_mask = mask != nullptr; + op.has_sinks = sinks != nullptr; + op.kv_overlap = ggml_webgpu_tensor_overlap(K, V); + + uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); + uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); + if (op.kv_overlap) { const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V }); - kv_bind_offset = merged_range.offset; - kv_bind_size = merged_range.size; + op.kv_bind_offset = merged_range.offset; + op.kv_bind_size = merged_range.size; offset_k = ggml_webgpu_tensor_merged_element_offset(K, merged_range); offset_v = ggml_webgpu_tensor_merged_element_offset(V, merged_range); } - std::vector<uint32_t> params = { + op.params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), offset_k, offset_v, - has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, - has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, + op.has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, + op.has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) Q->ne[2], // number of heads (uint32_t) Q->ne[1], // sequence length (Q) @@ -1826,7 +1859,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1 (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2 (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 - has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 + op.has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) ggml_webgpu_u32_from_f32(scale), // scale (possibly adjusted for logit softcap) ggml_webgpu_u32_from_f32(max_bias), @@ -1834,32 +1867,56 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_u32_from_f32(n_head_log2), ggml_webgpu_u32_from_f32(m0), ggml_webgpu_u32_from_f32(m1) - }; - std::vector<wgpu::BindGroupEntry> entries = { + op.entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q), }; - if (kv_overlap) { - entries.push_back( - ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + if (op.kv_overlap) { + op.entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), op.kv_bind_offset, op.kv_bind_size)); } else { - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K)); - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V)); + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K)); + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V)); } - uint32_t binding_index = kv_overlap ? 2u : 3u; - if (has_mask) { - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); + uint32_t binding_index = op.kv_overlap ? 2u : 3u; + if (op.has_mask) { + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); } - if (has_sinks) { - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); + if (op.has_sinks) { + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); } - entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); - if (decisions->path != GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); - uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + return op; +} + +static uint32_t ggml_webgpu_flash_attn_vec_nwg(uint32_t vec_nwg_cap, uint32_t kv_tile, uint32_t seq_len_kv) { + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) kv_tile; + while ((2u * nwg * kv_span) < (uint64_t) seq_len_kv && nwg < vec_nwg_cap) { + nwg <<= 1; } + return std::min(nwg, vec_nwg_cap); +} + +static webgpu_encoded_op ggml_webgpu_flash_attn_direct(webgpu_context & ctx, const ggml_webgpu_flash_attn_op & op) { + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(op.shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get()); + uint32_t wg_per_head = CEIL_DIV(op.shader_lib_ctx.src0->ne[1], decisions->q_tile); + uint32_t wg_x = wg_per_head * op.shader_lib_ctx.src0->ne[2] * op.shader_lib_ctx.src0->ne[3]; + return ggml_backend_webgpu_build(ctx, pipeline, op.params, op.entries, wg_x); +} + +static webgpu_encoded_op ggml_webgpu_flash_attn_vec(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst, + ggml_webgpu_flash_attn_op op) { + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_vec_pipeline(op.shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_flash_attn_vec_decisions *>(pipeline.context.get()); wgpu::Buffer blk_buf = {}; uint64_t blk_size_bytes = 0; @@ -1868,13 +1925,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, uint32_t blk_batch_count = 0; const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size; - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile); - while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { - nwg <<= 1; - } - nwg = std::min(nwg, vec_nwg_cap); - const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, decisions->kv_tile, (uint32_t) K->ne[1]); + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; const bool use_vec_reduce = nwg > 1u; GGML_ASSERT(nrows <= UINT32_MAX); @@ -1910,7 +1962,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, webgpu_pipeline blk_pipeline; std::vector<uint32_t> blk_params; std::vector<wgpu::BindGroupEntry> blk_entries; - if (has_mask) { + if (op.has_mask) { blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); blk_nblk1 = (uint32_t) Q->ne[1]; blk_buf = ggml_webgpu_tensor_buf(dst); @@ -1918,7 +1970,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - const ggml_webgpu_shader_lib_context blk_shader_ctx = shader_lib_ctx; + const ggml_webgpu_shader_lib_context blk_shader_ctx = op.shader_lib_ctx; blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile); blk_params = { @@ -1938,8 +1990,8 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); } - std::vector<uint32_t> split_params = params; - if (has_mask) { + std::vector<uint32_t> split_params = op.params; + if (op.has_mask) { split_params.push_back(0u); // blk_base split_params.push_back(blk_nblk0); // blk_nblk0 split_params.push_back(blk_nblk1); // blk_nblk1 @@ -1952,9 +2004,9 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q), ggml_webgpu_tensor_binding_size(ctx, Q)), }; - if (kv_overlap) { + if (op.kv_overlap) { split_entries.push_back( - ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), kv_bind_offset, kv_bind_size)); + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), op.kv_bind_offset, op.kv_bind_size)); } else { split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), ggml_webgpu_tensor_align_offset(ctx, K), @@ -1963,18 +2015,18 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, ggml_webgpu_tensor_align_offset(ctx, V), ggml_webgpu_tensor_binding_size(ctx, V))); } - uint32_t split_binding_index = kv_overlap ? 2u : 3u; - if (has_mask) { + uint32_t split_binding_index = op.kv_overlap ? 2u : 3u; + if (op.has_mask) { split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask), ggml_webgpu_tensor_align_offset(ctx, mask), ggml_webgpu_tensor_binding_size(ctx, mask))); } - if (has_sinks) { + if (op.has_sinks) { split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(sinks), ggml_webgpu_tensor_align_offset(ctx, sinks), ggml_webgpu_tensor_binding_size(ctx, sinks))); } - if (has_mask) { + if (op.has_mask) { split_entries.push_back( ggml_webgpu_make_bind_group_entry(split_binding_index++, blk_buf, blk_entries[1].offset, blk_size_bytes)); } @@ -1993,7 +2045,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, reduce_sg_size, (uint32_t) std::min<uint64_t>((uint64_t) nwg * reduce_sg_size, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); - ggml_webgpu_shader_lib_context reduce_shader_ctx = shader_lib_ctx; + ggml_webgpu_shader_lib_context reduce_shader_ctx = op.shader_lib_ctx; reduce_shader_ctx.max_wg_size = reduce_wg_size; reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); @@ -2020,7 +2072,7 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, std::vector<webgpu_dispatch_desc> dispatches; - if (has_mask) { + if (op.has_mask) { dispatches.push_back({ blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count } }); @@ -2037,6 +2089,20 @@ static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, return ggml_backend_webgpu_build_multi(ctx, dispatches); } +static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { + ggml_webgpu_flash_attn_op op = ggml_webgpu_flash_attn_prepare(ctx, Q, K, V, mask, sinks, dst); + if (ggml_webgpu_flash_attn_use_vec_path(ctx->global_ctx, Q, K, V)) { + return ggml_webgpu_flash_attn_vec(ctx, Q, K, V, mask, sinks, dst, std::move(op)); + } + return ggml_webgpu_flash_attn_direct(ctx, op); +} + static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; @@ -3553,70 +3619,43 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer break; case GGML_OP_FLASH_ATTN_EXT: { - const ggml_tensor * Q = tensor->src[0]; - const ggml_tensor * K = tensor->src[1]; - const ggml_tensor * V = tensor->src[2]; - const ggml_tensor * mask = tensor->src[3]; - const ggml_tensor * sinks = tensor->src[4]; - if (Q && K && V) { - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = const_cast<ggml_tensor *>(Q); - shader_lib_ctx.src1 = const_cast<ggml_tensor *>(K); - shader_lib_ctx.src2 = const_cast<ggml_tensor *>(V); - shader_lib_ctx.src3 = const_cast<ggml_tensor *>(mask); - shader_lib_ctx.src4 = const_cast<ggml_tensor *>(sinks); - shader_lib_ctx.dst = const_cast<ggml_tensor *>(tensor); - shader_lib_ctx.max_wg_size = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.wg_mem_limit_bytes = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; - shader_lib_ctx.supports_subgroup_matrix = - ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; - shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; - shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; - shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; - shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; - shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; - - const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( - shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - const uint32_t kv_tile = decisions.kv_tile; - - const uint32_t vec_nwg_cap = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; - uint32_t nwg = 1u; - const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile); - while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) { - nwg <<= 1; - } - nwg = std::min(nwg, vec_nwg_cap); - - const size_t align = - ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; - const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; - if (nwg > 1u) { - const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; - const uint64_t tmp_stats_elems = nrows * 2u * nwg; - const size_t tmp_size_bytes = ROUNDUP_POW2( - (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); - res += tmp_size_bytes + align; - } else { - res += WEBGPU_STORAGE_BUF_BINDING_MULT + align; - } - if (mask != nullptr) { - const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); - const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); - const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); - const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; - const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; - const size_t blk_size_bytes = - ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); - res += blk_size_bytes + align; - } - res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT); + const ggml_tensor * Q = tensor->src[0]; + const ggml_tensor * K = tensor->src[1]; + const ggml_tensor * V = tensor->src[2]; + const ggml_tensor * mask = tensor->src[3]; + const auto & capabilities = ctx->webgpu_global_ctx->capabilities; + if (ggml_webgpu_flash_attn_use_vec_path(ctx->webgpu_global_ctx, Q, K, V)) { + const bool kv_direct = + ggml_webgpu_flash_attn_kv_direct(Q, K, V, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); + const uint32_t kv_tile = ggml_webgpu_flash_attn_get_vec_kv_tile( + capabilities.limits.maxComputeWorkgroupStorageSize, (uint32_t) Q->ne[0], (uint32_t) V->ne[0], + mask != nullptr, kv_direct); + + const uint32_t vec_nwg_cap = capabilities.min_subgroup_size; + uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, kv_tile, (uint32_t) K->ne[1]); + + const size_t align = capabilities.limits.minStorageBufferOffsetAlignment; + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + if (nwg > 1u) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + const size_t tmp_size_bytes = ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), + WEBGPU_STORAGE_BUF_BINDING_MULT); + res += tmp_size_bytes + align; + } else { + res += WEBGPU_STORAGE_BUF_BINDING_MULT + align; } + if (mask != nullptr) { + const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); + const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + const size_t blk_size_bytes = + ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + res += blk_size_bytes + align; + } + res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT); } } break; @@ -4139,70 +4178,63 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; case GGML_OP_FLASH_ATTN_EXT: { + // conservative support checks for whether the more resource-intensive shader paths + // can be used, to avoid cases where flash_attn is assigned to the CPU later on supports_op = src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && - src2->type == src1->type && op->type == GGML_TYPE_F32; + (src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16 || + src2->type == GGML_TYPE_Q4_0 || src2->type == GGML_TYPE_Q8_0) && + op->type == GGML_TYPE_F32; if (!supports_op) { break; } - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; - shader_lib_ctx.src0 = src0; - shader_lib_ctx.src1 = src1; - shader_lib_ctx.src2 = src2; - shader_lib_ctx.src3 = op->src[3]; - shader_lib_ctx.src4 = op->src[4]; - shader_lib_ctx.dst = const_cast<ggml_tensor *>(op); - shader_lib_ctx.supports_subgroups = ctx->webgpu_global_ctx->capabilities.supports_subgroups; - shader_lib_ctx.supports_subgroup_matrix = ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix; - shader_lib_ctx.max_wg_size = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - shader_lib_ctx.wg_mem_limit_bytes = - ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - shader_lib_ctx.sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m; - shader_lib_ctx.sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n; - shader_lib_ctx.sg_mat_k = ctx->webgpu_global_ctx->capabilities.sg_mat_k; - shader_lib_ctx.min_subgroup_size = ctx->webgpu_global_ctx->capabilities.min_subgroup_size; - shader_lib_ctx.max_subgroup_size = ctx->webgpu_global_ctx->capabilities.max_subgroup_size; - - const ggml_webgpu_flash_attn_decisions decisions = ggml_webgpu_flash_attn_get_decisions( - shader_lib_ctx, ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - const size_t limit_bytes = ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_NONE) { + if (ggml_webgpu_tensor_overlap(src1, src2) && src1->type != src2->type && + !ggml_is_quantized(src1->type) && !ggml_is_quantized(src2->type)) { supports_op = false; break; } - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_VEC) { - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, - decisions.kv_direct, decisions.path); - if (min_bytes > limit_bytes) { - supports_op = false; - } - break; - } - - if (decisions.path == GGML_WEBGPU_FLASH_ATTN_PATH_TILE) { - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, - decisions.kv_direct, decisions.path); - if (min_bytes > limit_bytes) { - supports_op = false; - } - break; - } - - if (!ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { + const auto & capabilities = ctx->webgpu_global_ctx->capabilities; + const size_t storage_offset_alignment = capabilities.limits.minStorageBufferOffsetAlignment; + + // subgroup matrix path requirements + const bool use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( + capabilities.supports_subgroup_matrix, capabilities.sg_mat_k, capabilities.sg_mat_n, src0, src2); + + // tile path requirements + const bool float_vec4_aligned = + ((src1->type != GGML_TYPE_F16 && src1->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(src1, storage_offset_alignment)) && + ((src2->type != GGML_TYPE_F16 && src2->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(src2, storage_offset_alignment)); + const uint32_t k_tile_head_align = (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(src1->type); + const uint32_t v_tile_head_align = (src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(src2->type); + const bool tile_kv_head_dims_aligned = + src0->ne[0] % k_tile_head_align == 0 && src2->ne[0] % v_tile_head_align == 0; + const bool tile_can_dispatch_all_q_rows = + capabilities.limits.maxComputeInvocationsPerWorkgroup >= + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * capabilities.max_subgroup_size; + const bool use_tile = !use_subgroup_matrix && capabilities.supports_subgroups && float_vec4_aligned && + tile_kv_head_dims_aligned && tile_can_dispatch_all_q_rows; + + if (!use_subgroup_matrix && !use_tile) { supports_op = false; break; } - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - decisions.q_tile, decisions.kv_tile, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], has_mask, - decisions.kv_direct, decisions.path); - if (min_bytes > limit_bytes) { - supports_op = false; - } + const uint32_t q_tile = + use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; + const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u; + const bool kv_direct = use_subgroup_matrix ? + ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) : + false; + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( + capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct); + supports_op = max_kv_tile > 0; break; } case GGML_OP_RMS_NORM: diff --git a/ggml/src/ggml-webgpu/pre_wgsl.hpp b/ggml/src/ggml-webgpu/pre_wgsl.hpp index 4d4359463ca..fb41a961d74 100644 --- a/ggml/src/ggml-webgpu/pre_wgsl.hpp +++ b/ggml/src/ggml-webgpu/pre_wgsl.hpp @@ -37,15 +37,33 @@ static std::string trim(const std::string & s) { } static std::string trim_value(std::istream & is) { - std::string str; - std::getline(is, str); - return trim(str); + std::ostringstream ss; + ss << is.rdbuf(); + return trim(ss.str()); } static bool isIdentChar(char c) { return std::isalnum(static_cast<unsigned char>(c)) || c == '_'; } +static bool endsWithContinuation(const std::string & line) { + size_t i = line.size(); + while (i > 0 && std::isspace((unsigned char) line[i - 1])) { + i--; + } + return i > 0 && line[i - 1] == '\\'; +} + +static void stripContinuation(std::string & line) { + size_t i = line.size(); + while (i > 0 && std::isspace((unsigned char) line[i - 1])) { + i--; + } + if (i > 0 && line[i - 1] == '\\') { + line.erase(i - 1); + } +} + static std::string expandMacrosRecursiveInternal(const std::string & line, const std::unordered_map<std::string, std::string> & macros, std::unordered_set<std::string> & visiting); @@ -595,19 +613,31 @@ class Preprocessor { std::string line; while (std::getline(in, line)) { - std::string t = trim(line); + std::string logical = line; + std::string t = trim(logical); + if (!t.empty() && t[0] == '#') { + while (endsWithContinuation(logical)) { + stripContinuation(logical); + if (!std::getline(in, line)) { + break; + } + logical += "\n"; + logical += line; + } + t = trim(logical); + } if (!t.empty() && t[0] == '#') { bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode); if (mode == DirectiveMode::IncludesOnly && !handled) { - out << line << "\n"; + out << logical << "\n"; } } else { if (mode == DirectiveMode::IncludesOnly) { - out << line << "\n"; + out << logical << "\n"; } else if (condActive(cond)) { // Expand macros in the line before outputting - std::string expanded = expandMacrosRecursive(line, macros); + std::string expanded = expandMacrosRecursive(logical, macros); out << expanded << "\n"; } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index 6d5d69fb8de..9767ca3d754 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -4,12 +4,23 @@ enable f16; enable subgroups; enable chromium_experimental_subgroup_matrix; -#ifdef KV_F32 -#define KV_TYPE f32 -#elif defined(KV_Q4_0) || defined(KV_Q8_0) -#define KV_TYPE u32 +#define BYTE_HELPERS +#include "common_decls.tmpl" + +#ifdef K_F32 +#define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 +#else +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 #else -#define KV_TYPE f16 +#define V_TYPE f16 #endif // Default values @@ -30,76 +41,6 @@ enable chromium_experimental_subgroup_matrix; // Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. #define KV_BLOCKS (KV_TILE / SG_MAT_N) -// Quantization constants/helpers -#define BLOCK_SIZE 32 -#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) -#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) -// number of quantized elements processed per thread -#if defined(KV_Q4_0) -#define NQ 16 -// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights -#define F16_PER_BLOCK 9 -#define BLOCK_SIZE_BYTES 18u -#define WEIGHTS_PER_F16 4 -#elif defined(KV_Q8_0) -#define NQ 8 -// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights -#define F16_PER_BLOCK 17 -#define BLOCK_SIZE_BYTES 34u -#define WEIGHTS_PER_F16 2 -#endif -#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) - -// Ok not to put these in a define block, compiler will remove if unused -fn get_byte(value: u32, index: u32) -> u32 { - return (value >> (index * 8)) & 0xFF; -} - -fn get_byte_i32(value: u32, index: u32) -> i32 { - return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24; -} - -#if defined(KV_Q4_0) || defined(KV_Q8_0) -fn load_k_u16_at(byte_offset: u32) -> u32 { - let word = K[byte_offset / 4u]; - let shift = (byte_offset & 2u) * 8u; - return (word >> shift) & 0xFFFFu; -} - -fn load_k_u32_at(byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4u; - let shift = (byte_offset & 3u) * 8u; - let lo = K[word_idx]; - if (shift == 0u) { - return lo; - } - let hi = K[word_idx + 1u]; - return (lo >> shift) | (hi << (32u - shift)); -} - -fn load_v_u16_at(byte_offset: u32) -> u32 { - let word = V[byte_offset / 4u]; - let shift = (byte_offset & 2u) * 8u; - return (word >> shift) & 0xFFFFu; -} - -fn load_v_u32_at(byte_offset: u32) -> u32 { - let word_idx = byte_offset / 4u; - let shift = (byte_offset & 3u) * 8u; - let lo = V[word_idx]; - if (shift == 0u) { - return lo; - } - let hi = V[word_idx + 1u]; - return (lo >> shift) | (hi << (32u - shift)); -} - -fn f16_from_u16(bits: u32) -> f16 { - let packed = unpack2x16float(bits); - return f16(packed[0]); -} -#endif - struct Params { offset_q: u32, offset_k: u32, @@ -139,11 +80,11 @@ struct Params { @group(0) @binding(0) var<storage, read_write> Q: array<f32>; #ifdef KV_OVERLAP -@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>; +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; #define V K #else -@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>; -@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>; +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>; #endif #if defined(MASK) && defined(SINKS) @@ -238,10 +179,47 @@ fn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32 return (*buf)[scalar_index >> 2u]; } -fn load_kvx4(buf: ptr<storage, array<vec4<KV_TYPE>>, read_write>, scalar_index: u32) -> vec4<KV_TYPE> { +fn load_kx4(buf: ptr<storage, array<vec4<K_TYPE>>, read_write>, scalar_index: u32) -> vec4<K_TYPE> { return (*buf)[scalar_index >> 2u]; } +#ifndef KV_DIRECT +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_staging.tmpl" + +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + kv_shmem[elem_idx] = f16(select( + 0.0, + K[global_k_row_offset + k_col], + global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK)); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + kv_shmem[elem_idx] = f16(select( + 0.0, + V[global_v_row_offset + v_col], + global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); + } +} +#endif +#endif + @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3<u32>, @builtin(local_invocation_id) local_id: vec3<u32>, @@ -311,77 +289,15 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile); // clear inter_shmem to ensure zero-initialized accumulators for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { inter_shmem[elem_idx] = 0.0; } // load k tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_k_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_k_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_k_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_k_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { - let k_row = elem_idx / HEAD_DIM_QK; - let k_col = elem_idx % HEAD_DIM_QK; - let global_k_row = kv_tile + k_row; - let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; - kv_shmem[elem_idx] = f16(select( - 0.0, - K[global_k_row_offset + k_col], - global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK)); - } +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); #endif workgroupBarrier(); @@ -520,71 +436,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } // load v tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_v_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_v_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let block_byte_base = global_block_idx * BLOCK_SIZE_BYTES; - let d = f16_from_u16(load_v_u16_at(block_byte_base)); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_byte_offset = block_byte_base + 2u + 2u * (block_offset + j); - let q_packed = load_v_u32_at(q_byte_offset); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { - let v_row = elem_idx / HEAD_DIM_V; - let v_col = elem_idx % HEAD_DIM_V; - let global_v_row = kv_tile + v_row; - let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; - kv_shmem[elem_idx] = f16(select( - 0.0, - V[global_v_row_offset + v_col], - global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); - } +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); #endif workgroupBarrier(); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl new file mode 100644 index 00000000000..8f41eb7bfdb --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl @@ -0,0 +1,124 @@ +#define BLOCK_SIZE 32 +#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) +#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) + +#if defined(K_Q4_0) +#define K_NQ 16 +#define K_BLOCK_SIZE_BYTES 18u +#define K_BYTES_PER_THREAD 8u +#define K_BYTES_PER_INNER_LOOP 4u +#elif defined(K_Q8_0) +#define K_NQ 16 +#define K_BLOCK_SIZE_BYTES 34u +#define K_BYTES_PER_THREAD 16u +#define K_BYTES_PER_INNER_LOOP 4u +#endif + +#if defined(V_Q4_0) +#define V_NQ 16 +#define V_BLOCK_SIZE_BYTES 18u +#define V_BYTES_PER_THREAD 8u +#define V_BYTES_PER_INNER_LOOP 4u +#elif defined(V_Q8_0) +#define V_NQ 16 +#define V_BLOCK_SIZE_BYTES 34u +#define V_BYTES_PER_THREAD 16u +#define V_BYTES_PER_INNER_LOOP 4u +#endif + +#if defined(K_Q4_0) || defined(K_Q8_0) +fn load_k_u16_at(byte_offset: u32) -> u32 { + let word = K[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_k_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = K[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = K[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} +#endif + +#if defined(V_Q4_0) || defined(V_Q8_0) +fn load_v_u16_at(byte_offset: u32) -> u32 { + let word = V[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_v_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = V[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = V[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} +#endif + +fn f16_from_u16(bits: u32) -> f16 { + let packed = unpack2x16float(bits); + return f16(packed[0]); +} + +#if defined(K_Q4_0) || defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x * K_NQ; elem_idx < kv_count * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_k_u16_at(block_byte_base)); + let thread_byte_offset = block_offset * K_BYTES_PER_THREAD; + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; + for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) { + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP; + let q_packed = load_k_u32_at(q_byte_offset); +#if defined(K_Q4_0) + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); +#elif defined(K_Q8_0) + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); +#endif + } + } +} +#endif + +#if defined(V_Q4_0) || defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x * V_NQ; elem_idx < kv_count * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_v_u16_at(block_byte_base)); + let thread_byte_offset = block_offset * V_BYTES_PER_THREAD; + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; + for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) { + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP; + let q_packed = load_v_u32_at(q_byte_offset); +#if defined(V_Q4_0) + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); +#elif defined(V_Q8_0) + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); +#endif + } + } +} +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl index 4133f0ab564..e68934113fc 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -1,16 +1,29 @@ enable f16; enable subgroups; +#define BYTE_HELPERS +#include "common_decls.tmpl" + #ifdef Q_F16 #define Q_TYPE f16 #else #define Q_TYPE f32 #endif -#ifdef KV_F32 -#define KV_TYPE f32 +#ifdef K_F32 +#define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 +#else +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 #else -#define KV_TYPE f16 +#define V_TYPE f16 #endif #ifdef DST_F16 @@ -21,7 +34,6 @@ enable subgroups; #define HEAD_DIM_QK 64 #define HEAD_DIM_V 64 -#define KV_STAGE_STRIDE 64 #define Q_TILE 4 #define KV_TILE 64 #define WG_SIZE 128 @@ -64,11 +76,23 @@ struct Params { @group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>; #ifdef KV_OVERLAP -@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>; +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +#else +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; +#endif #define V K #else -@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>; -@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>; +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +#else +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; +#endif +#if defined(V_Q4_0) || defined(V_Q8_0) +@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>; +#else +@group(0) @binding(2) var<storage, read_write> V: array<vec4<V_TYPE>>; +#endif #endif #if defined(MASK) && defined(SINKS) @@ -121,10 +145,50 @@ const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u; const V_CHUNKS: u32 = HEAD_DIM_V / 4u; const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); var<workgroup> q_shmem: array<Q_TYPE, Q_TILE * HEAD_DIM_QK>; -var<workgroup> kv_shmem: array<KV_TYPE, KV_TILE * KV_STAGE_STRIDE>; -var<workgroup> p_shmem: array<KV_TYPE, Q_TILE * KV_TILE>; +var<workgroup> kv_shmem: array<f16, kv_shmem_size>; +var<workgroup> p_shmem: array<f16, Q_TILE * KV_TILE>; + +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_staging.tmpl" + +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var vec_idx_local = local_x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / Q_CHUNKS; + let chunk = vec_idx_local % Q_CHUNKS; + let global_k_row = kv_tile + kv_local; + let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; + let k4 = K[k_vec_index]; + let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u; + kv_shmem[kv_off + 0u] = f16(k4.x); + kv_shmem[kv_off + 1u] = f16(k4.y); + kv_shmem[kv_off + 2u] = f16(k4.z); + kv_shmem[kv_off + 3u] = f16(k4.w); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var vec_idx_local = local_x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / V_CHUNKS; + let chunk = vec_idx_local % V_CHUNKS; + let global_v_row = kv_tile + kv_local; + let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; + let v4 = V[v_vec_index]; + let kv_off = kv_local * HEAD_DIM_V + chunk * 4u; + kv_shmem[kv_off + 0u] = f16(v4.x); + kv_shmem[kv_off + 1u] = f16(v4.y); + kv_shmem[kv_off + 2u] = f16(v4.z); + kv_shmem[kv_off + 3u] = f16(v4.w); + } +} +#endif @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3<u32>, @@ -206,18 +270,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, local_scores[slot] = FLOAT_MIN; } - for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) { - let kv_local = vec_idx_local / Q_CHUNKS; - let chunk = vec_idx_local % Q_CHUNKS; - let global_k_row = kv_tile + kv_local; - let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; - let k4 = K[k_vec_index]; - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = KV_TYPE(k4.x); - kv_shmem[kv_off + 1u] = KV_TYPE(k4.y); - kv_shmem[kv_off + 2u] = KV_TYPE(k4.z); - kv_shmem[kv_off + 3u] = KV_TYPE(k4.w); - } +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); +#endif workgroupBarrier(); @@ -238,8 +293,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, q_shmem[q_off + 1u], q_shmem[q_off + 2u], q_shmem[q_off + 3u]); - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let kv = vec4<KV_TYPE>( + let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u; + let kv = vec4<f16>( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], @@ -271,25 +326,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, let kv_local = sg_inv_id + slot * subgroup_size; if (row_active && kv_local < kv_count) { let p = exp(local_scores[slot] - new_max); - p_shmem[subgroup_p_offset + kv_local] = KV_TYPE(p); + p_shmem[subgroup_p_offset + kv_local] = f16(p); local_sum += p; } } workgroupBarrier(); - for (var vec_idx_local = local_id.x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) { - let kv_local = vec_idx_local / V_CHUNKS; - let chunk = vec_idx_local % V_CHUNKS; - let global_v_row = kv_tile + kv_local; - let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; - let v4 = V[v_vec_index]; - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - kv_shmem[kv_off + 0u] = KV_TYPE(v4.x); - kv_shmem[kv_off + 1u] = KV_TYPE(v4.y); - kv_shmem[kv_off + 2u] = KV_TYPE(v4.z); - kv_shmem[kv_off + 3u] = KV_TYPE(v4.w); - } +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); +#endif workgroupBarrier(); @@ -306,14 +352,14 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, var acc = out_regs[reg_idx]; for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) { - let p = p_shmem[subgroup_p_offset + kv_local]; - let kv_off = kv_local * KV_STAGE_STRIDE + chunk * 4u; - let v4 = vec4<KV_TYPE>( + let p = f32(p_shmem[subgroup_p_offset + kv_local]); + let kv_off = kv_local * HEAD_DIM_V + chunk * 4u; + let v4 = vec4<f16>( kv_shmem[kv_off + 0u], kv_shmem[kv_off + 1u], kv_shmem[kv_off + 2u], kv_shmem[kv_off + 3u]); - acc += f32(p) * vec4<f32>(v4); + acc += p * vec4<f32>(v4); } out_regs[reg_idx] = acc; } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl index 30ebbebe772..30ed97cca0c 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -2,10 +2,23 @@ diagnostic(off, subgroup_uniformity); enable f16; enable subgroups; -#ifdef KV_F32 -#define KV_TYPE f32 +#define BYTE_HELPERS +#include "common_decls.tmpl" + +#ifdef K_F32 +#define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 #else -#define KV_TYPE f16 +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 +#else +#define V_TYPE f16 #endif #ifdef Q_F16 @@ -32,28 +45,6 @@ enable subgroups; #define KV_BLOCKS (KV_TILE / KV_GRANULARITY) -#define BLOCK_SIZE 32 -#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) -#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) -#if defined(KV_Q4_0) -#define NQ 16 -#define F16_PER_BLOCK 9 -#define WEIGHTS_PER_F16 4 -#elif defined(KV_Q8_0) -#define NQ 8 -#define F16_PER_BLOCK 17 -#define WEIGHTS_PER_F16 2 -#endif -#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) - -fn get_byte(value: u32, index: u32) -> u32 { - return (value >> (index * 8)) & 0xFF; -} - -fn get_byte_i32(value: u32, index: u32) -> i32 { - return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24; -} - struct Params { offset_q: u32, offset_k: u32, @@ -103,22 +94,22 @@ struct Params { @group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>; #ifdef KV_OVERLAP -#if defined(KV_Q4_0) || defined(KV_Q8_0) -@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>; +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; #else -@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>; +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; #endif #define V K #else -#if defined(KV_Q4_0) || defined(KV_Q8_0) -@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>; +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; #else -@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>; +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; #endif -#if defined(KV_Q4_0) || defined(KV_Q8_0) -@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>; +#if defined(V_Q4_0) || defined(V_Q8_0) +@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>; #else -@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>; +@group(0) @binding(2) var<storage, read_write> V: array<vec4<V_TYPE>>; #endif #endif #if defined(MASK) && defined(SINKS) @@ -244,6 +235,49 @@ fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) return v; } +#ifndef KV_DIRECT +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f32 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_staging.tmpl" + +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; + let vec_idx = (global_k_row_offset + k_col) >> 2u; + let k4 = select(vec4<K_TYPE>(0.0), K[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f32(k4.x); + kv_shmem[elem_idx + 1u] = f32(k4.y); + kv_shmem[elem_idx + 2u] = f32(k4.z); + kv_shmem[elem_idx + 3u] = f32(k4.w); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; + let vec_idx = (global_v_row_offset + v_col) >> 2u; + let v4 = select(vec4<V_TYPE>(0.0), V[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f32(v4.x); + kv_shmem[elem_idx + 1u] = f32(v4.y); + kv_shmem[elem_idx + 2u] = f32(v4.z); + kv_shmem[elem_idx + 3u] = f32(v4.w); + } +} +#endif +#endif + @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3<u32>, @builtin(local_invocation_id) local_id: vec3<u32>, @@ -308,6 +342,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { + let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile); #ifdef BLK let q_blk = q_row_start; let kv_blk = kv_tile / KV_TILE; @@ -324,76 +359,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } // load k tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d); - let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) { - let k_row = elem_idx / HEAD_DIM_QK; - let k_col = elem_idx % HEAD_DIM_QK; - let global_k_row = kv_tile + k_row; - let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; - let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; - let vec_idx = (global_k_row_offset + k_col) >> 2u; - let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds); - kv_shmem[elem_idx + 0u] = f32(k4.x); - kv_shmem[elem_idx + 1u] = f32(k4.y); - kv_shmem[elem_idx + 2u] = f32(k4.z); - kv_shmem[elem_idx + 3u] = f32(k4.w); - } +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); #endif workgroupBarrier(); @@ -510,76 +477,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } // load v tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * f32(d); - let q_lo = (f32(q_byte & 0xF) - 8.0) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * f32(d); - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) { - let v_row = elem_idx / HEAD_DIM_V; - let v_col = elem_idx % HEAD_DIM_V; - let global_v_row = kv_tile + v_row; - let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; - let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; - let vec_idx = (global_v_row_offset + v_col) >> 2u; - let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds); - kv_shmem[elem_idx + 0u] = f32(v4.x); - kv_shmem[elem_idx + 1u] = f32(v4.y); - kv_shmem[elem_idx + 2u] = f32(v4.z); - kv_shmem[elem_idx + 3u] = f32(v4.w); - } +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); #endif workgroupBarrier(); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index eb2a8368f43..72991504dd0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -25,6 +25,10 @@ fn store_shmem(val: f16, idx: u32) { } #endif // SCALAR +#define QUANT_SHMEM shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" + #ifdef INIT_SRC0_SHMEM_FLOAT fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { @@ -124,14 +128,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - - for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; - } + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } } } @@ -314,12 +311,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); - for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { - let q_byte = get_byte_i32(q_packed, k); - - let q_val = f16(q_byte) * d; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val; - } + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl new file mode 100644 index 00000000000..d1da4608434 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl @@ -0,0 +1,21 @@ +#ifdef U32_DEQUANT_HELPERS +fn dequant_q4_0_packed_to_shmem(q_packed: u32, d: f16, dst_idx: u32) { + let scale = QUANT_OUT_TYPE(d); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (QUANT_OUT_TYPE((q_byte >> 4) & 0xFu) - QUANT_OUT_TYPE(8.0)) * scale; + let q_lo = (QUANT_OUT_TYPE(q_byte & 0xFu) - QUANT_OUT_TYPE(8.0)) * scale; + QUANT_SHMEM[dst_idx + k] = q_lo; + QUANT_SHMEM[dst_idx + k + 16u] = q_hi; + } +} + +fn dequant_q8_0_packed_to_shmem(q_packed: u32, d: f16, dst_idx: u32) { + let scale = QUANT_OUT_TYPE(d); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = QUANT_OUT_TYPE(q_byte) * scale; + QUANT_SHMEM[dst_idx + k] = q_val; + } +} +#endif From 9d6e561f692b9b6353a33fa63e8b8a6998a41cb1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Thu, 4 Jun 2026 08:05:32 +0300 Subject: [PATCH 785/831] metal : reduce rset heartbeat from 500ms -> 5ms (llama/24074) --- ggml/src/ggml-metal/ggml-metal-device.m | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 196af102643..05d7f43051b 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -547,6 +547,8 @@ void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder) { // number of seconds since the last graph computation // keep the residency sets wired for that amount of time to avoid being collected by the OS int keep_alive_s; + int loops_per_s; + int time_per_loop_ms; // background heartbeat thread to keep the residency sets alive atomic_bool d_stop; @@ -573,10 +575,13 @@ ggml_metal_rsets_t ggml_metal_rsets_init(void) { res->keep_alive_s = 3*60; } + res->time_per_loop_ms = 5; + res->loops_per_s = 1000/res->time_per_loop_ms; + GGML_LOG_INFO("%s: creating a residency set collection (keep_alive = %d s)\n", __func__, res->keep_alive_s); atomic_store_explicit(&res->d_stop, false, memory_order_relaxed); - atomic_store_explicit(&res->d_loop, 2*res->keep_alive_s, memory_order_relaxed); + atomic_store_explicit(&res->d_loop, res->loops_per_s*res->keep_alive_s, memory_order_relaxed); res->d_group = dispatch_group_create(); @@ -599,8 +604,7 @@ ggml_metal_rsets_t ggml_metal_rsets_init(void) { [res->lock unlock]; } - // half a second - usleep(500 * 1000); + usleep(res->time_per_loop_ms * 1000); } } #endif @@ -979,7 +983,7 @@ void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) { return; } - atomic_store_explicit(&dev->rsets->d_loop, 2*dev->rsets->keep_alive_s, memory_order_relaxed); + atomic_store_explicit(&dev->rsets->d_loop, dev->rsets->loops_per_s*dev->rsets->keep_alive_s, memory_order_relaxed); } struct ggml_metal_event { From 991b5a8b4ab652b0bc282f500a58de565e7aa0bc Mon Sep 17 00:00:00 2001 From: Kartik Sirohi <99896785+sirohikartik@users.noreply.github.com> Date: Thu, 4 Jun 2026 18:42:38 +0530 Subject: [PATCH 786/831] ggml: vectorize ggml_vec_dot_q4_1_q8_1 with WASM SIMD128 (llama/22209) * ggml: vectorize ggml_vec_dot_q4_1_q8_1 with WASM SIMD128 Optimize the inner loop of ggml_vec_dot_q4_1_q8_1_generic using WASM SIMD128 intrinsics, gated behind #ifdef __wasm_simd128__ so non-wasm builds are completely unaffected. Approach: - single wasm_v128_load covers all 32 packed 4-bit weights - nibbles unpacked via AND/SHR into two u8x16 registers - widened to i16 before multiply (WASM SIMD has no i8*i8 instruction) - 4x wasm_i32x4_dot_i16x8 calls accumulate all 32 element pairs - horizontal reduce via 4x wasm_i32x4_extract_lane Benchmark (node v25, emcc -O3 -msimd128, 64 blocks x QK8_1=32, 200k iterations): | impl | ns/call | speedup | |--------|---------|---------| | scalar | 880.7 | 1.00x | | simd | 257.8 | 3.42x | Correctness verified against scalar reference across 10 random seeds with exact output match. * ggml: move q4_1_q8_1 WASM SIMD implementation to wasm backend Relocate the SIMD128 implementation of ggml_vec_dot_q4_1_q8_1 to ggml/src/ggml-cpu/arch/wasm/quants.c to follow architecture-specific layout. Restore the generic implementation in ggml/src/ggml-cpu/quants.c. Move for loop in the else block. * ggml: use generic q4_1_q8_1 fallback in wasm backend --- ggml/src/ggml-cpu/arch/wasm/quants.c | 72 ++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c index 648c6fcaba7..0a7119b4e1f 100644 --- a/ggml/src/ggml-cpu/arch/wasm/quants.c +++ b/ggml/src/ggml-cpu/arch/wasm/quants.c @@ -355,6 +355,78 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi *s = sumf; } +void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + + float sumf = 0; + +#if defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + float summs = 0.0f; + + for (int ib = 0; ib < nb; ++ib) { + const block_q4_1 * GGML_RESTRICT x0 = &x[ib]; + const block_q8_1 * GGML_RESTRICT y0 = &y[ib]; + + summs += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s); + + const v128_t raw = wasm_v128_load(x0->qs); + const v128_t v0s = wasm_v128_and(raw, wasm_i8x16_splat(0x0F)); + const v128_t v1s = wasm_u8x16_shr(raw, 4); + + const v128_t ys_lo = wasm_v128_load(y0->qs); + const v128_t ys_hi = wasm_v128_load(y0->qs + 16); + + const v128_t v0s_l = wasm_u16x8_extend_low_u8x16(v0s); + const v128_t v0s_h = wasm_u16x8_extend_high_u8x16(v0s); + const v128_t ylo_l = wasm_i16x8_extend_low_i8x16(ys_lo); + const v128_t ylo_h = wasm_i16x8_extend_high_i8x16(ys_lo); + const v128_t v1s_l = wasm_u16x8_extend_low_u8x16(v1s); + const v128_t v1s_h = wasm_u16x8_extend_high_u8x16(v1s); + const v128_t yhi_l = wasm_i16x8_extend_low_i8x16(ys_hi); + const v128_t yhi_h = wasm_i16x8_extend_high_i8x16(ys_hi); + + const v128_t acc = wasm_i32x4_add( + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(v0s_l, ylo_l), + wasm_i32x4_dot_i16x8(v0s_h, ylo_h)), + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(v1s_l, yhi_l), + wasm_i32x4_dot_i16x8(v1s_h, yhi_h))); + + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul( + wasm_f32x4_convert_i32x4(acc), + wasm_f32x4_splat(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; + + *s = sumf; + +#else + UNUSED(nb); + UNUSED(x); + UNUSED(y); + UNUSED(sumf); + + ggml_vec_dot_q4_1_q8_1_generic( + n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; From 4ecede8c8bb6b4d899a13e27b8b672ad1bc67311 Mon Sep 17 00:00:00 2001 From: Mason Milburn <masonmilby@gmail.com> Date: Fri, 5 Jun 2026 01:10:31 -0400 Subject: [PATCH 787/831] sycl : port multi-column MMVQ from CUDA backend (llama/21845) mmvq: Port the ncols_dst optimization from ggml-cuda/mmvq.cu to SYCL. Read weights once per dispatch instead of once per column. Covers all standard quant types + reorder paths for Q4_0, Q8_0, Q3_K, Q4_K, Q5_K, Q6_K. IQ types (except IQ4_XS) excluded due to incompatible vec_dot signatures. ggml-sycl: The weight reorder was only bootstrapped on single-token mat-vec (ne[1] == 1). Speculative / MTP verify issues only multi-column mat-vec, so it never triggered the reorder and ran on the slower non-reorder kernel. Bootstrap it on small multi-column batches (ne[1] <= 8) too. --- ggml/src/ggml-sycl/ggml-sycl.cpp | 4 +- ggml/src/ggml-sycl/mmvq.cpp | 1118 +++++++++++++++++++++++++++++- 2 files changed, 1095 insertions(+), 27 deletions(-) diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 96138f57ebe..3f246e8672d 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3971,7 +3971,9 @@ static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_ten return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf. dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases. - dst->src[1]->ne[1]==1 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1; + // ne[1] <= 8 so multi-column decode (spec / MTP verify) also bootstraps the reorder; + // all reorderable types have a _switch_ncols kernel. + dst->src[1]->ne[1] <= 8 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1; } static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */, diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index abd1e49a70e..cf2b59576aa 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -56,6 +56,65 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r } } +template <typename reorder_vec_dot_q_sycl, int ncols_dst> +static void mul_mat_vec_q_reorder_ncols(const void * __restrict__ vx, const void * __restrict__ vy, + float * __restrict__ dst, const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + const sycl::nd_item<3> & nd_item) { + using block_type = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>; + using block_traits = typename block_type::traits; + + const auto sg = nd_item.get_sub_group(); + const int sg_range = sg.get_group_linear_range(); + const int workgroup_id = nd_item.get_group_linear_id(); + const int sg_id = sg.get_group_linear_id(); + const int row = workgroup_id * sg_range + sg_id; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / block_traits::qk; + constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi); + constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq; + const int nblocks = nrows * (ncols / block_traits::qk); + + static_assert(blocks_per_subgroup > 0); + static_assert(block_elements_per_subgroup > 0); + + float partial_sum[ncols_dst] = {0.0f}; + for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) { + const int ibx = row * blocks_per_row + i; + + const auto bx_offset = block_type::get_block_offset(ibx, nblocks); + const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx); + const int iby = i * block_type::block_to_q8_1_ratio(); + +#pragma unroll + for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) { + const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup); + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const char * vy_j = (const char *)vy + j * stride_col_y_bytes; + const int8_t * q8_1_quant_ptr = (const int8_t *)vy_j + iby * QK8_1; + const sycl::half2* q8_1_ds_ptr = (const sycl::half2 *)(vy_j + ncols + iby * sizeof(sycl::half2)); + + partial_sum[j] += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs); + } + } + } + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + float sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum[j], std::plus<>()); + + if (sg.leader()) { + dst[j * stride_col_dst + row] = sum; + } + } +} + template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl> static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) { @@ -100,6 +159,70 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_ } } +template <int qk, int qi, typename block_q_t, int vdr, + vec_dot_q_sycl_t vec_dot_q_sycl, int ncols_dst> +static void mul_mat_vec_q_ncols( + const void * __restrict__ vx, + const void * __restrict__ vy, + float * __restrict__ dst, + const int ncols, + const int nrows, + const int stride_col_y, + const int stride_col_dst, + const sycl::nd_item<3> & item_ct1) { + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi; + + // partial sums: one per output column + float tmp[ncols_dst] = {0.0f}; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); + i < blocks_per_row; + i += blocks_per_warp) { + + const int ibx = row * blocks_per_row + i; + const int iby = i * (qk / QK8_1); + + // read weight block once, dot against all columns + for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) { + const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr)); + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + tmp[j] += vec_dot_q_sycl(&x[ibx], &y[j * stride_col_y + iby], iqs); + } + } + } + + // reduce within subgroup +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp[j] += dpct::permute_sub_group_by_xor( + item_ct1.get_sub_group(), tmp[j], mask); + } + } + + if (item_ct1.get_local_id(2) == 0) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + dst[j * stride_col_dst + row] = tmp[j]; + } + } +} + template <int qk, int qi, typename block_q_t, int vdr> static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx, const void *__restrict__ vy, @@ -553,6 +676,45 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q4_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_0 reorder multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK4_0 == 0); @@ -571,6 +733,45 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * } } +template <int ncols_dst> +static void mul_mat_vec_q4_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK4_0, QI4_0, block_q4_0, + VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q4_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q4_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q4_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q4_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q4_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q4_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q4_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q4_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_0 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -595,6 +796,45 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q4_1_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_1 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK4_0, QI4_1, block_q4_1, + VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q4_1_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q4_1_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q4_1_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q4_1_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q4_1_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q4_1_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q4_1_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q4_1_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q4_1_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_1 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_MXFP4 == 0); @@ -613,6 +853,45 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float } } +template <int ncols_dst> +static void mul_mat_vec_mxfp4_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_MXFP4 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_MXFP4, QI_MXFP4, block_mxfp4, + VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_mxfp4_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_mxfp4_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_mxfp4_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_mxfp4_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_mxfp4_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_mxfp4_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_mxfp4_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_mxfp4_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for MXFP4 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_NVFP4 == 0); @@ -631,6 +910,45 @@ static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float } } +template <int ncols_dst> +static void mul_mat_vec_nvfp4_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_NVFP4 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_NVFP4, QI_NVFP4, block_nvfp4, + VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_nvfp4_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_nvfp4_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_nvfp4_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_nvfp4_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_nvfp4_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_nvfp4_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_nvfp4_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_nvfp4_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for NVFP4 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -655,6 +973,45 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q5_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK5_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK5_0, QI5_0, block_q5_0, + VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q5_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q5_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q5_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q5_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q5_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q5_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q5_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q5_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q5_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_0 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -679,6 +1036,45 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q5_1_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK5_1 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK5_1, QI5_1, block_q5_1, + VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q5_1_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q5_1_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q5_1_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q5_1_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q5_1_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q5_1_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q5_1_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q5_1_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q5_1_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_1 multi-col MMVQ", ncols_dst); + } +} + static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK8_0 == 0); @@ -698,6 +1094,45 @@ static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q8_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q8_0 reorder multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -722,6 +1157,45 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q8_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK8_0, QI8_0, block_q8_0, + VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q8_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q8_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q8_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q8_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q8_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q8_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q8_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q8_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q8_0 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -746,6 +1220,45 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q2_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI2_K, block_q2_K, + VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q2_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q2_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q2_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q2_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q2_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q2_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q2_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q2_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q2_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q2_K multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -790,6 +1303,85 @@ static void reorder_mul_mat_vec_q3_k_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q3_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q3_K reorder multi-col MMVQ", ncols_dst); + } +} + +template <int ncols_dst> +static void mul_mat_vec_q3_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI3_K, block_q3_K, + VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q3_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q3_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q3_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q3_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q3_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q3_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q3_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q3_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q3_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q3_K multi-col MMVQ", ncols_dst); + } +} + + static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -814,6 +1406,51 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q4_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI4_K, block_q4_K, + VDR_Q4_K_Q8_1_MMVQ, + vec_dot_q4_K_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q4_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q4_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q4_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q4_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q4_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q4_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q4_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q4_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q4_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_K multi-col MMVQ", ncols_dst); + } +} + static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); @@ -834,6 +1471,44 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q4_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_K reorder multi-col MMVQ", ncols_dst); + } +} static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, @@ -859,6 +1534,51 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q5_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI5_K, block_q5_K, + VDR_Q5_K_Q8_1_MMVQ, + vec_dot_q5_K_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q5_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q5_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q5_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q5_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q5_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q5_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q5_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q5_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q5_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_K multi-col MMVQ", ncols_dst); + } +} + static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); @@ -879,6 +1599,45 @@ static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q5_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_K reorder multi-col MMVQ", ncols_dst); + } +} + static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); @@ -897,6 +1656,46 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, }); }); } + +template <int ncols_dst> +static void reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q6_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q6_K reorder multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -921,6 +1720,51 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q6_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI6_K, block_q6_K, + VDR_Q6_K_Q8_1_MMVQ, + vec_dot_q6_K_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q6_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q6_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q6_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q6_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q6_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q6_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q6_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q6_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q6_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q6_K multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, @@ -1117,6 +1961,51 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_iq4_xs_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI4_XS/4, block_iq4_xs, + 1, + vec_dot_iq4_xs_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_iq4_xs_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for IQ4_XS multi-col MMVQ", ncols_dst); + } +} + void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, @@ -1143,42 +2032,135 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_Q4_0: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n"); - reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n"); mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } break; case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_1_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q4_1_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q5_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_1_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q5_1_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q8_0: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n"); - reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl\n"); mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } break; case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q2_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q2_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q3_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl\n"); - reorder_mul_mat_vec_q3_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, - stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q3_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q3_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl\n"); mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } @@ -1186,9 +2168,27 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_Q4_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n"); - reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q4_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl\n"); mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } @@ -1196,9 +2196,27 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_Q5_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl\n"); - reorder_mul_mat_vec_q5_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q5_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q5_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl\n"); mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } @@ -1206,9 +2224,27 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_Q6_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n"); - reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q6_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n"); mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } @@ -1238,13 +2274,43 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_MXFP4: - mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_NVFP4: - mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; default: GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type)); From 4fa1e0687e23bfaf286ba328c2c6d0d592cb3158 Mon Sep 17 00:00:00 2001 From: Oliver Simons <osimons@nvidia.com> Date: Fri, 5 Jun 2026 08:37:34 +0200 Subject: [PATCH 788/831] CUDA: enroll mul_mat_vec_q_moe into pdl (llama/24087) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Enroll mul_mat_vec_q_moe into PDL, boosting MTP performance on BW Data collected on a B4500: Before ``` (llama.cpp) ➜ llama.cpp git:(master) ✗ python mtp-bench.py code_python pred= 192 draft= 150 acc= 116 rate=0.773 tok/s=202.8 code_cpp pred= 192 draft= 147 acc= 117 rate=0.796 tok/s=212.8 explain_concept pred= 192 draft= 161 acc= 110 rate=0.683 tok/s=196.4 summarize pred= 192 draft= 138 acc= 122 rate=0.884 tok/s=226.6 qa_factual pred= 192 draft= 138 acc= 121 rate=0.877 tok/s=225.1 translation pred= 192 draft= 158 acc= 112 rate=0.709 tok/s=201.5 creative_short pred= 192 draft= 160 acc= 110 rate=0.688 tok/s=197.2 stepwise_math pred= 192 draft= 150 acc= 115 rate=0.767 tok/s=209.2 long_code_review pred= 192 draft= 148 acc= 116 rate=0.784 tok/s=208.9 ``` After ``` (llama.cpp) ➜ llama.cpp git:(master) ✗ python mtp-bench.py code_python pred= 192 draft= 150 acc= 116 rate=0.773 tok/s=211.9 code_cpp pred= 192 draft= 147 acc= 117 rate=0.796 tok/s=224.6 explain_concept pred= 192 draft= 161 acc= 110 rate=0.683 tok/s=207.8 summarize pred= 192 draft= 138 acc= 122 rate=0.884 tok/s=240.2 qa_factual pred= 192 draft= 138 acc= 121 rate=0.877 tok/s=238.5 translation pred= 192 draft= 158 acc= 112 rate=0.709 tok/s=213.4 creative_short pred= 192 draft= 160 acc= 110 rate=0.688 tok/s=208.8 stepwise_math pred= 192 draft= 150 acc= 115 rate=0.767 tok/s=221.7 long_code_review pred= 192 draft= 148 acc= 116 rate=0.784 tok/s=220.7 ``` Server launched with: ``` ➜ llama.cpp git:(osimons/enroll_mul_mat_vec_q_moe_into_PDL) ✗ ./build-x64-linux-gcc-reldbg/bin/llama-server \ -m /mnt/share/gguf/unsloth/Qwen3.6-35B-A3B-MTP-GGUF/Qwen3.6-35B-A3B-UD-Q4_K_M.gguf -dio \ --spec-type draft-mtp \ --spec-draft-n-max 2 \ -ngl all \ -fa on \ --host 0.0.0.0 \ --port 8080 -np 1 --chat-template-kwargs "{\"preserve_thinking\": true}" ``` * LC to overlap with following kernels --- ggml/src/ggml-cuda/mmvq.cu | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index 4b0426590ac..bdfbfd2d387 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -682,12 +682,16 @@ static __global__ void mul_mat_vec_q( template <ggml_type type, int c_rows_per_block> __launch_bounds__(get_mmvq_mmid_max_batch_for_device<type>()*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q_moe( - const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, - float * __restrict__ dst, + const void * vx_ptr, const void * vy_ptr, const int32_t * ids_ptr, + float * dst_ptr, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint32_t ncols_dst, const uint32_t ids_stride) { + const void * GGML_CUDA_RESTRICT vx = vx_ptr; + const void * GGML_CUDA_RESTRICT vy = vy_ptr; + const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; constexpr int qk = ggml_cuda_type_traits<type>::qk; constexpr int qi = ggml_cuda_type_traits<type>::qi; @@ -707,6 +711,7 @@ static __global__ void mul_mat_vec_q_moe( return; } + ggml_cuda_pdl_sync(); const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride]; const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y); @@ -726,6 +731,8 @@ static __global__ void mul_mat_vec_q_moe( } } + ggml_cuda_pdl_lc(); + // Warp-level reduction only - no shared memory needed #pragma unroll for (int i = 0; i < c_rows_per_block; ++i) { @@ -794,8 +801,9 @@ static void mul_mat_vec_q_moe_launch( const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block; const dim3 block_nums(nblocks_rows, nchannels_dst); const dim3 block_dims(warp_size, ncols_dst); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); - mul_mat_vec_q_moe<type, rows_per_block><<<block_nums, block_dims, 0, stream>>>( + ggml_cuda_kernel_launch(mul_mat_vec_q_moe<type, rows_per_block>, launch_params, vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x, stride_row_x, stride_col_y, stride_col_dst, stride_channel_x, stride_channel_y, stride_channel_dst, From facb02c4c3a32f07935c5da60e92b0f650f1bd40 Mon Sep 17 00:00:00 2001 From: Charles Xu <charles.xu@arm.com> Date: Fri, 5 Jun 2026 09:11:47 +0200 Subject: [PATCH 789/831] kleidiai : dynamic chunck-based scheduling for hybrid execution (llama/23819) --- ggml/src/ggml-cpu/kleidiai/kleidiai.cpp | 272 ++++++++++++------------ 1 file changed, 141 insertions(+), 131 deletions(-) diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 0ecf7ae02ac..9e54b676b93 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -38,6 +38,7 @@ #include "kleidiai.h" #include "ggml-cpu.h" +#include "ggml-cpu-impl.h" #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-threading.h" @@ -61,7 +62,8 @@ struct ggml_kleidiai_context { ggml_kleidiai_kernels * kernels_q8; int sme_thread_cap; // <= 0 means “SME disabled/unknown”; int thread_hint; // <= 0 means “no hint” -} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1 }; + int chunk_multiplier; +} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1, 4 }; static const char* cpu_feature_to_string(cpu_feature f) { if (f == CPU_FEATURE_NONE) { @@ -186,8 +188,9 @@ static void init_kleidiai_context(void) { if (!initialized) { initialized = true; - const char *env_sme = getenv("GGML_KLEIDIAI_SME"); - const char *env_threads = getenv("GGML_TOTAL_THREADS"); + const char *env_sme = getenv("GGML_KLEIDIAI_SME"); + const char *env_threads = getenv("GGML_TOTAL_THREADS"); + const char *env_chunk_mult = getenv("GGML_KLEIDIAI_CHUNK_MULTIPLIER"); const bool cpu_has_sme = ggml_cpu_has_sme(); size_t detected_smcus = 0; @@ -204,6 +207,14 @@ static void init_kleidiai_context(void) { } } + if (env_chunk_mult) { + bool ok = false; + int multiplier = parse_uint_env(env_chunk_mult, "GGML_KLEIDIAI_CHUNK_MULTIPLIER", &ok); + if (ok && multiplier > 0) { + ctx.chunk_multiplier = multiplier; + } + } + // SME policy: // - If CPU doesn't support SME: SME always off. // - Else: @@ -296,6 +307,50 @@ static inline size_t align_up(size_t value, size_t alignment) { return remainder == 0 ? value : value + (alignment - remainder); } +static inline size_t gcd_size(size_t a, size_t b) { + while (b != 0) { + const size_t t = a % b; + a = b; + b = t; + } + return a; +} + +static inline bool lcm_size(size_t a, size_t b, size_t & result) { + if (a == 0 || b == 0) { + result = 0; + return false; + } + const size_t g = gcd_size(a, b); + const size_t q = a / g; + if (q > SIZE_MAX / b) { + return false; + } + result = q * b; + return true; +} + +static inline size_t ceil_div_size(size_t a, size_t b) { + return b == 0 ? 0 : (a + b - 1) / b; +} + +struct kleidiai_block_args { + size_t lhs_bl; + size_t rhs_bl; + size_t pack_bl; +}; + +static inline kleidiai_block_args kleidiai_get_block_args(ggml_type rhs_type) { + switch (rhs_type) { + case GGML_TYPE_Q4_0: + return { QK4_0, QK4_0, QK4_0 }; + case GGML_TYPE_Q8_0: + return { 0, 0, QK8_0 }; + default: + return { 0, 0, 0 }; + } +} + static inline bool kleidiai_pack_fallback_allowed() { if (ctx.sme_thread_cap <= 0) { return false; @@ -746,8 +801,10 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t n_step; size_t lhs_packed_size; size_t lhs_offset; - size_t n_offset; - size_t n_cols; + size_t lhs_bl; + size_t rhs_bl; + size_t pack_bl; + size_t lhs_packed_offset0; int assigned_threads; int thread_begin; int thread_end; @@ -772,6 +829,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { continue; } + const kleidiai_block_args block_args = kleidiai_get_block_args(kernels->rhs_type); + runtime[runtime_count] = { slot, kernels, @@ -784,7 +843,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { kinfo->get_n_step(), 0, 0, - 0, + block_args.lhs_bl, + block_args.rhs_bl, + block_args.pack_bl, 0, 0, 0, @@ -795,45 +856,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { } if (runtime_count == 0) { - ggml_kleidiai_kernels * fallback = ggml_kleidiai_select_kernels(ctx.features, dst); - if (!fallback) { - return false; - } - kernel_info * kinfo = is_gemv ? &fallback->gemv : &fallback->gemm; - lhs_packing_info * linfo = is_gemv ? &fallback->gemv_lhs_info : &fallback->gemm_lhs_info; - rhs_packing_info * rinfo = &fallback->rhs_info; - if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || - !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset || - !rinfo || !rinfo->pack_func_ex || !rinfo->packed_size_ex) { - return false; - } - kernel_chain[0] = fallback; - runtime[0] = { - 0, - fallback, - kinfo, - linfo, - kinfo->get_mr(), - kinfo->get_nr(), - kinfo->get_kr(), - kinfo->get_sr(), - kinfo->get_n_step(), - 0, - 0, - 0, - 0, - 0, - 0, - 0, - nullptr - }; - size_t rhs_size_fallback = 0; - const uint8_t * rhs_base = weight_for_slot(0, rhs_size_fallback); - if (!rhs_base) { - rhs_base = static_cast<const uint8_t *>(src0->data); - } - runtime[0].rhs_base = rhs_base; - runtime_count = 1; + GGML_LOG_WARN("kleidiai: no runtime kernel slot available for supported op %s\n", dst->name); + return false; } const int nth_total = params->nth > 0 ? params->nth : 1; @@ -846,6 +870,13 @@ class tensor_traits : public ggml::cpu::tensor_traits { break; } } + int non_sme_slot = -1; + for (int i = 0; i < runtime_count; ++i) { + if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) != CPU_FEATURE_SME) { + non_sme_slot = i; + break; + } + } const int sme_cap_limit = ctx.sme_thread_cap; const bool use_hybrid = sme_cap_limit > 0 && @@ -864,12 +895,15 @@ class tensor_traits : public ggml::cpu::tensor_traits { if (!hybrid_enabled) { int chosen_slot = 0; if (too_small_for_hybrid && sme_slot != -1) { - chosen_slot = sme_slot; + chosen_slot = nth_total > sme_cap_limit && non_sme_slot != -1 ? non_sme_slot : sme_slot; } else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) { chosen_slot = 1; } if (chosen_slot != 0 && chosen_slot < runtime_count) { runtime[0] = runtime[chosen_slot]; + runtime[0].assigned_threads = 0; + runtime[0].thread_begin = 0; + runtime[0].thread_end = 0; } runtime_count = runtime_count > 0 ? 1 : 0; @@ -896,6 +930,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS]; int fallback_count = 0; + // The current hybrid chain is bounded to SME + one non-SME fallback slot. + GGML_ASSERT(GGML_KLEIDIAI_MAX_KERNEL_SLOTS == 2); for (int i = 0; i < runtime_count; ++i) { if (i == sme_slot) { continue; @@ -952,73 +988,67 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t cursor = 0; for (int i = 0; i < runtime_count; ++i) { - const ggml_type slot_rhs_type = runtime[i].kernels->rhs_type; - const size_t slot_pack_size_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : - slot_rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0; - runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, slot_pack_size_arg, runtime[i].mr, runtime[i].kr, runtime[i].sr); + runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, runtime[i].pack_bl, runtime[i].mr, runtime[i].kr, runtime[i].sr); cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); runtime[i].lhs_offset = cursor; + runtime[i].lhs_packed_offset0 = runtime[i].lhs_info->get_packed_offset_ex(0, k, runtime[i].lhs_bl, runtime[i].mr, runtime[i].kr, runtime[i].sr); cursor += runtime[i].lhs_packed_size; } GGML_ASSERT(cursor <= params->wsize); uint8_t * scratch = static_cast<uint8_t *>(params->wdata); - size_t assigned_cols = 0; - uint64_t weighted_total = 0; - if (runtime_count > 1 && sme_slot != -1) { - for (int i = 0; i < runtime_count; ++i) { - const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1; - weighted_total += (uint64_t)runtime[i].assigned_threads * weight; - } - } + size_t common_step = 1; for (int i = 0; i < runtime_count; ++i) { - runtime[i].n_offset = assigned_cols; if (runtime[i].assigned_threads == 0) { - runtime[i].n_cols = 0; continue; } - const size_t remaining_cols = n - assigned_cols; - if (remaining_cols == 0) { - runtime[i].n_cols = 0; - continue; - } - const size_t step = runtime[i].n_step ? runtime[i].n_step : 1; - size_t target = 0; - if (weighted_total > 0) { - const uint64_t weight = (i == sme_slot) ? (sme_cap << 1) : 1; - target = (size_t)(((uint64_t)n * runtime[i].assigned_threads * weight) / weighted_total); - } else { - target = (size_t)(((uint64_t)n * runtime[i].assigned_threads) / nth_total); - } - target = std::min(target, remaining_cols); - size_t aligned = round_down(target, step); - if (aligned == 0 && remaining_cols >= step) { - aligned = step; + size_t next_step = 0; + if (!lcm_size(common_step, runtime[i].n_step ? runtime[i].n_step : 1, next_step)) { + return false; } - runtime[i].n_cols = aligned; - assigned_cols += aligned; + common_step = next_step; } - - if (assigned_cols < n) { - for (int i = runtime_count - 1; i >= 0; --i) { - if (runtime[i].assigned_threads > 0) { - runtime[i].n_cols += n - assigned_cols; - break; - } - } + GGML_ASSERT(common_step > 0); + + const bool disable_chunking = ggml_is_numa(); + const size_t chunk_multiplier = std::max(1, ctx.chunk_multiplier); + const size_t chunk_divisor = (nth_total == 1 || disable_chunking) ? (size_t)nth_total : (size_t)nth_total * chunk_multiplier; + size_t chunk_cols = align_up(std::max<size_t>(1, ceil_div_size(n, chunk_divisor)), common_step); + if (chunk_cols == 0) { + chunk_cols = common_step; } + // If common_step is larger than n, the loop below runs one valid tail chunk + // with cols == n. + const size_t nchunk_size = std::max<size_t>(1, ceil_div_size(n, chunk_cols)); + GGML_ASSERT(nchunk_size <= (size_t)INT_MAX); + const int nchunk = (int)nchunk_size; const size_t dst_stride = dst->nb[1]; + auto run_chunk = [&](runtime_slot & slot, size_t global_start, size_t cols, uint8_t * dst_batch_base) { + const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot.rhs_bl); + const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride); + + const uint8_t * lhs_ptr = scratch + slot.lhs_offset + slot.lhs_packed_offset0; + const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset; + float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset); + + slot.kernel->run_kernel_ex(m, cols, k, slot.rhs_bl, + lhs_ptr, + rhs_ptr, + dst_ptr, + dst_stride, + sizeof(float), + -FLT_MAX, + FLT_MAX); + }; + for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) { const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2]; uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2]; if (runtime[local_slot].assigned_threads > 0) { runtime_slot & slot = runtime[local_slot]; - const ggml_type slot_rhs_type = slot.kernels->rhs_type; - const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : - slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0; const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr); int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads; max_threads = std::max<int64_t>(1, max_threads); @@ -1031,8 +1061,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t m_start = (int64_t)local_ith * num_m_per_thread0; const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; - const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr); - const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr); + const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr); + const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr); const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0; int64_t remaining = m_count; @@ -1049,7 +1079,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes; void * dst_ptr = lhs_packed + dst_off; - slot.lhs_info->pack_func_ex(take, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr); + slot.lhs_info->pack_func_ex(take, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr); cur += take; remaining -= take; @@ -1057,49 +1087,29 @@ class tensor_traits : public ggml::cpu::tensor_traits { } } + if (ith_total == 0) { + ggml_threadpool_chunk_set(params->threadpool, nth_total); + } + + // Publishes both LHS packing and the initialized dynamic chunk queue. ggml_barrier(params->threadpool); runtime_slot & slot = runtime[local_slot]; - if (slot.n_cols > 0 && slot.assigned_threads > 0) { - int64_t active_threads = slot.assigned_threads; - const int64_t max_threads = slot.n_step ? (slot.n_cols / slot.n_step) : slot.assigned_threads; - if (max_threads > 0) { - active_threads = std::min<int64_t>(active_threads, std::max<int64_t>(1, max_threads)); + int current_chunk = ith_total; + while (current_chunk < nchunk) { + const size_t global_start = (size_t)current_chunk * chunk_cols; + if (global_start >= n) { + break; } - active_threads = std::max<int64_t>(1, active_threads); - - if (local_ith < active_threads) { - const size_t step = slot.n_step ? slot.n_step : 1; - const size_t chunk0 = round_down((size_t)(slot.n_cols / active_threads), step); - const size_t chunkN = slot.n_cols - (active_threads - 1) * chunk0; - const size_t local_start = (size_t)local_ith * chunk0; - const size_t cols = (local_ith == active_threads - 1) ? chunkN : chunk0; - - if (cols > 0) { - const ggml_type slot_rhs_type = slot.kernels->rhs_type; - const size_t slot_lhs_exec_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : - slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0; - const size_t slot_rhs_block_arg = slot_rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : - slot_rhs_type == GGML_TYPE_Q8_0 ? 0 : 0; - const size_t global_start = slot.n_offset + local_start; - const size_t lhs_packed_offset = slot.lhs_info->get_packed_offset_ex(0, k, slot_lhs_exec_arg, slot.mr, slot.kr, slot.sr); - const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot_rhs_block_arg); - const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride); - - const uint8_t * lhs_ptr = scratch + slot.lhs_offset + lhs_packed_offset; - const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset; - float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset); - - slot.kernel->run_kernel_ex(m, cols, k, slot_rhs_block_arg, - lhs_ptr, - rhs_ptr, - dst_ptr, - dst_stride, - sizeof(float), - -FLT_MAX, - FLT_MAX); - } + + const size_t cols = std::min(chunk_cols, n - global_start); + if (cols > 0) { + // KleidiAI GEMM/GEMV kernels accept arbitrary final tail widths; + // only non-tail chunks are guaranteed to be n_step-aligned. + run_chunk(slot, global_start, cols, dst_batch_base); } + + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); } if (batch_idx != ne12 - 1) { From 5a1feed8ca57b70d002ca0df2abd3db9328a1daa Mon Sep 17 00:00:00 2001 From: Ruben Ortlam <rortlam@redhat.com> Date: Fri, 5 Jun 2026 19:44:40 +0200 Subject: [PATCH 790/831] vulkan: add fwht support for Intel with shmem reduction (llama/23964) * vulkan: add fwht support for Intel with shmem reduction * don't use N as workgroup size * disable subgroup shuffle on MoltenVK AMD * disable fwht shader on Intel Windows due to driver bug --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 13 ++++ ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp | 78 +++++++++++++++---- .../vulkan-shaders/vulkan-shaders-gen.cpp | 1 + 3 files changed, 76 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e7d04634b8a..df410368a79 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5084,6 +5084,14 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { } ++idx; } + } else if (device->driver_id != vk::DriverId::eIntelProprietaryWindows) { + // Disabled on Intel Windows due to a driver bug: https://github.com/ggml-org/llama.cpp/pull/23964#issuecomment-4598226147 + int idx = 0; + for (uint32_t n : {64, 128, 256, 512}) { + const uint32_t block_size = std::min(device->subgroup_size, n); + ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_shmem_f32", fwht_shmem_f32_len, fwht_shmem_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { block_size, n }, 1); + ++idx; + } } const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4; @@ -5630,6 +5638,11 @@ static vk_device ggml_vk_get_device(size_t idx) { #endif device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); +#ifdef __APPLE__ + if (device->vendor_id == VK_VENDOR_ID_AMD) { + device->subgroup_shuffle = false; + } +#endif device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp index 72059d4afc2..a2069964adb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp @@ -1,14 +1,16 @@ #version 450 #extension GL_EXT_control_flow_attributes : require +#ifndef FWHT_SHMEM #extension GL_KHR_shader_subgroup_basic : enable #extension GL_KHR_shader_subgroup_shuffle : enable +#endif -layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; - -layout(constant_id = 0) const uint WARP_SIZE = 32; +layout(constant_id = 0) const uint BLOCK_SIZE = 32; layout(constant_id = 1) const uint N = 128; +layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; + layout(push_constant) uniform parameter { uint n_rows; @@ -20,35 +22,72 @@ layout(push_constant) uniform parameter layout(binding = 0, std430) readonly buffer A { float data_a[]; }; layout(binding = 1, std430) writeonly buffer D { float data_d[]; }; -const uint EL_W = N / WARP_SIZE; +const uint EL_W = N / BLOCK_SIZE; + +#ifdef FWHT_SHMEM +shared float shmem[4 * N]; +#endif void main() { - const uint lane = gl_SubgroupInvocationID; - for (uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID; - row < n_rows; - row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) { +#ifdef FWHT_SHMEM + const uint tid = gl_LocalInvocationID.x; + const uint shmem_base = gl_LocalInvocationID.y * N; + const uint row_id = gl_LocalInvocationID.y; +#else + const uint tid = gl_SubgroupInvocationID; + const uint row_id = gl_SubgroupID; +#endif + + for (uint base_row = gl_WorkGroupID.x * gl_WorkGroupSize.y; + base_row < n_rows; + base_row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) { + const uint row = base_row + row_id; const uint row_offset = row * N; +#ifndef FWHT_SHMEM + if (row >= n_rows) { + continue; + } +#endif + float reg[EL_W]; [[unroll]] for (uint i = 0; i < EL_W; ++i) { - reg[i] = data_a[src_offset + row_offset + i * WARP_SIZE + lane] * scale; + reg[i] = row < n_rows ? data_a[src_offset + row_offset + i * BLOCK_SIZE + tid] * scale : 0.0; } +#ifdef FWHT_SHMEM + [[unroll]] + for (uint h = 1; h < BLOCK_SIZE; h <<= 1) { + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + shmem[shmem_base + i * BLOCK_SIZE + tid] = reg[i]; + } + barrier(); + [[unroll]] + for (uint j = 0; j < EL_W; ++j) { + const float val = reg[j]; + const float other = shmem[shmem_base + j * BLOCK_SIZE + (tid ^ h)]; + reg[j] = (tid & h) == 0 ? val + other : other - val; + } + barrier(); + } +#else [[unroll]] - for (uint h = 1; h < WARP_SIZE; h <<= 1) { + for (uint h = 1; h < BLOCK_SIZE; h <<= 1) { [[unroll]] for (uint j = 0; j < EL_W; ++j) { const float val = reg[j]; const float val2 = subgroupShuffleXor(val, h); - reg[j] = (lane & h) == 0 ? val + val2 : val2 - val; + reg[j] = (tid & h) == 0 ? val + val2 : val2 - val; } } +#endif [[unroll]] - for (uint h = WARP_SIZE; h < N; h <<= 1) { - const uint step = h / WARP_SIZE; + for (uint h = BLOCK_SIZE; h < N; h <<= 1) { + const uint step = h / BLOCK_SIZE; [[unroll]] for (uint j = 0; j < EL_W; j += 2 * step) { [[unroll]] @@ -61,9 +100,16 @@ void main() { } } - [[unroll]] - for (uint i = 0; i < EL_W; ++i) { - data_d[dst_offset + row_offset + i * WARP_SIZE + lane] = reg[i]; +#ifdef FWHT_SHMEM + if (row < n_rows) { +#endif + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + data_d[dst_offset + row_offset + i * BLOCK_SIZE + tid] = reg[i]; + } +#ifdef FWHT_SHMEM } + barrier(); +#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index de7dbec2c63..d65cd12b287 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -957,6 +957,7 @@ void process_shaders() { string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("fwht_f32", "fwht.comp", {}); + string_to_spv("fwht_shmem_f32", "fwht.comp", {{"FWHT_SHMEM", "1"}}); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); From a87e950a0634481140473324e346badd458ff11f Mon Sep 17 00:00:00 2001 From: lhez <lih@qti.qualcomm.com> Date: Fri, 5 Jun 2026 13:45:25 -0700 Subject: [PATCH 791/831] opencl: improve get_rows, cpy, concat and q6_k flat gemv (llama/24160) * opencl: allow multiple workgroups for large rows * opencl: improve small cpy * opencl: packed concat for small input * opencl: tweak flat q6_K gemv, increase N_DST and remap threads --- ggml/src/ggml-opencl/ggml-opencl.cpp | 71 +++++++++-- ggml/src/ggml-opencl/kernels/concat.cl | 67 +++++++++++ ggml/src/ggml-opencl/kernels/cpy.cl | 59 +++++++++ ggml/src/ggml-opencl/kernels/get_rows.cl | 24 ++-- .../kernels/mul_mv_q6_k_f32_flat.cl | 112 ++++++++---------- 5 files changed, 247 insertions(+), 86 deletions(-) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index c411e4aeaec..2a41215fd13 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -558,7 +558,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_set_rows_f32_i64, kernel_set_rows_f32_i32, kernel_set_rows_f16_i64, kernel_set_rows_f16_i32; cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16; - cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32, kernel_cpy_i32_i32; + cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32, kernel_cpy_f32_f32_pack, kernel_cpy_i32_i32; cl_kernel kernel_mul_mat_f32_f32; cl_kernel kernel_mul_mat_f16_f16; cl_kernel kernel_mul_mat_f16_f32_1row; @@ -639,7 +639,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_softplus_f16, kernel_softplus_f16_4, kernel_softplus_f16_nc; cl_kernel kernel_upscale; cl_kernel kernel_upscale_bilinear; - cl_kernel kernel_concat_f32; + cl_kernel kernel_concat_f32, kernel_concat_f32_pack; cl_kernel kernel_conv_2d_f16; cl_kernel kernel_conv_2d_f32; cl_kernel kernel_conv_2d_f16_f32; @@ -1121,6 +1121,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(prog, "kernel_cpy_f16_f32", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(prog, "kernel_cpy_f32_f16", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(prog, "kernel_cpy_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32_pack = clCreateKernel(prog, "kernel_cpy_f32_f32_pack", &err), err)); CL_CHECK((backend_ctx->kernel_cpy_i32_i32 = clCreateKernel(prog, "kernel_cpy_i32_i32", &err), err)); GGML_LOG_CONT("."); } @@ -2615,6 +2616,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); CL_CHECK((backend_ctx->kernel_concat_f32 = clCreateKernel(prog, "kernel_concat_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_concat_f32_pack = clCreateKernel(prog, "kernel_concat_f32_pack", &err), err)); CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -8552,7 +8554,14 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c nth *= 2; } - size_t global_work_size[] = {(size_t)ne10*nth, (size_t)ne11, (size_t)ne12}; + int nchunks = 1; + if (src0->type == GGML_TYPE_F32) { + const int chunk_target = nth * 4; + nchunks = (ne00 + chunk_target - 1) / chunk_target; + nchunks = MAX(1, MIN(nchunks, 64)); + } + + size_t global_work_size[] = {(size_t)ne10*nth*nchunks, (size_t)ne11, (size_t)ne12}; size_t local_work_size[] = {(size_t)nth, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); @@ -11128,7 +11137,9 @@ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, con int nth = MIN(64, ne0); - cl_kernel kernel = backend_ctx->kernel_concat_f32; + const bool concat_pack = (dim == 0 && ne0 < 32); + cl_kernel kernel = concat_pack ? backend_ctx->kernel_concat_f32_pack + : backend_ctx->kernel_concat_f32; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -11155,10 +11166,28 @@ static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, con CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_int), &dim)); - size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; - size_t local_work_size[] = {(size_t)nth, 1, 1}; + if (concat_pack) { + // packed kernel needs the dst dims to unflatten its 1-D row index. + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &ne3)); + + const int maxwg = (int)backend_ctx->get_kernel_workgroup_size(kernel); + const int base = MIN(64, maxwg); + const int tpr = MIN(ne0, base); // threads per row + const int rpw = MAX(1, base / tpr); // rows per workgroup + const int lsz = tpr * rpw; + const int nrows = ne1*ne2*ne3; + const int nwg = (nrows + rpw - 1) / rpw; + size_t global_work_size[] = {(size_t)nwg*lsz, 1, 1}; + size_t local_work_size[] = {(size_t)lsz, 1, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst); + } else { + size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } } static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { @@ -14536,7 +14565,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } else if (backend_ctx->gpu_family == ADRENO) { nth0 = 64; nth1 = 2; - ndst = 4; + ndst = 16; } else { GGML_ASSERT(false && "TODO: Unknown GPU"); } @@ -16633,7 +16662,8 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const kernel = backend_ctx->kernel_cpy_f32_f16; break; case GGML_TYPE_F32: - kernel = backend_ctx->kernel_cpy_f32_f32; + kernel = ne00 < 32 ? backend_ctx->kernel_cpy_f32_f32_pack + : backend_ctx->kernel_cpy_f32_f32; break; default: GGML_ASSERT(false && "not implemented"); @@ -16685,12 +16715,27 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); - const int nth = MIN(64, ne00); + if (kernel == backend_ctx->kernel_cpy_f32_f32_pack) { + const int maxwg = (int)backend_ctx->get_kernel_workgroup_size(kernel); + const int base = MIN(64, maxwg); + const int tpr = MIN(ne00, base); // threads per row + const int rpw = MAX(1, base / tpr); // rows per workgroup + const int lsz = tpr * rpw; // <= base <= maxwg + const int nrows = ne01*ne02*ne03; + const int nwg = (nrows + rpw - 1) / rpw; - size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; - size_t local_work_size[] = {(size_t)nth, 1, 1}; + size_t global_work_size[] = {(size_t)nwg*lsz, 1, 1}; + size_t local_work_size[] = {(size_t)lsz, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, src1); + } else { + const int nth = MIN(64, ne00); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src1); + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src1); + } } static void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { diff --git a/ggml/src/ggml-opencl/kernels/concat.cl b/ggml/src/ggml-opencl/kernels/concat.cl index 0c1b3d785ca..2fbd7851d3d 100644 --- a/ggml/src/ggml-opencl/kernels/concat.cl +++ b/ggml/src/ggml-opencl/kernels/concat.cl @@ -49,3 +49,70 @@ kernel void kernel_concat_f32( *y = *x; } } + +kernel void kernel_concat_f32_pack( + global const char * src0, + ulong offset0, + global const char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int dim, + int ne1, + int ne2, + int ne3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int lsz = get_local_size(0); + int tpr = min(ne0, lsz); // threads per row + int rpw = lsz / tpr; // rows per workgroup + int lid = get_local_id(0); + int row = get_group_id(0)*rpw + lid / tpr; + int lane = lid - (lid / tpr) * tpr; + + int nrows = ne1*ne2*ne3; + if (row >= nrows) { + return; + } + + int i1 = row % ne1; + int t = row / ne1; + int i2 = t % ne2; + int i3 = t / ne2; + + int o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); + + for (int i0 = lane; i0 < ne0; i0 += tpr) { + global const float * x; + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (global const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + } else { + x = (global const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + } + + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = *x; + } +} diff --git a/ggml/src/ggml-opencl/kernels/cpy.cl b/ggml/src/ggml-opencl/kernels/cpy.cl index 820aa538a34..adbd2e766d2 100644 --- a/ggml/src/ggml-opencl/kernels/cpy.cl +++ b/ggml/src/ggml-opencl/kernels/cpy.cl @@ -183,6 +183,65 @@ kernel void kernel_cpy_f32_f32( } } +kernel void kernel_cpy_f32_f32_pack( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int lsz = get_local_size(0); + int tpr = min(ne00, lsz); // threads per row + int rpw = lsz / tpr; // rows per workgroup + int lid = get_local_id(0); + int row = get_group_id(0)*rpw + lid / tpr; + int lane = lid - (lid / tpr) * tpr; + + int nrows = ne01*ne02*ne03; + if (row >= nrows) { + return; + } + + int i01 = row % ne01; + int t = row / ne01; + int i02 = t % ne02; + int i03 = t / ne02; + + // linear index of the first element of this row, unflattened over dst dims + long n = (long)row * ne00; + int i3 = (int)(n / ((long)ne2*ne1*ne0)); + long rm = n - (long)i3*ne2*ne1*ne0; + int i2 = (int)(rm / ((long)ne1*ne0)); + rm -= (long)i2*ne1*ne0; + int i1 = (int)(rm / ne0); + int i0 = (int)(rm - (long)i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = lane; i00 < ne00; i00 += tpr) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + kernel void kernel_cpy_i32_i32( global int * src0, ulong offset0, diff --git a/ggml/src/ggml-opencl/kernels/get_rows.cl b/ggml/src/ggml-opencl/kernels/get_rows.cl index c2962edc983..9ae4fff09fc 100644 --- a/ggml/src/ggml-opencl/kernels/get_rows.cl +++ b/ggml/src/ggml-opencl/kernels/get_rows.cl @@ -82,21 +82,27 @@ kernel void kernel_get_rows_f32( src1 = (global int*)((global char*)src1 + offset1); dst = (global float*)((global char*)dst + offsetd); - int i10 = get_group_id(0); - int i11 = get_group_id(1); - int i12 = get_group_id(2); + int nchunks = get_num_groups(0) / ne10; + int g = get_group_id(0); + int i10 = g / nchunks; + int chunk = g - i10 * nchunks; + int i11 = get_group_id(1); + int i12 = get_group_id(2); int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; int i03 = i12; - for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - if (ind >= ne00) { - return; - } - ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] = - ((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind]; + global float * dst_row = (global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1); + global float * src_row = (global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03); + + int span = (ne00 + nchunks - 1) / nchunks; + int start = chunk * span; + int end = min(start + span, ne00); + + for (int ind = start + get_local_id(0); ind < end; ind += get_local_size(0)) { + dst_row[ind] = src_row[ind]; } } diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl index 86fe09c6dd6..57b90c05ae5 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl @@ -33,13 +33,15 @@ inline float block_q_6_K_dot_y_flat( global uchar * blk_qh, global char * blk_scales, global half * blk_d, - global float * yy, int ib, int ip, int is, - int l0 + int l0, + float4 y0, + float4 y1, + float4 y2, + float4 y3 ) { - int y_offset = 128*ip + l0; int q_offset_l = 64*ip + l0; int q_offset_h = 32*ip + l0; @@ -48,36 +50,28 @@ inline float block_q_6_K_dot_y_flat( global uchar * qh = blk_qh + ib*64 + q_offset_h; global char * sc = blk_scales + ib*16 + is; - global float * y = yy + ib * QK_K + y_offset; - float dall = blk_d[ib]; - float sumf = 0; - float4 sums = {0.f, 0.f, 0.f, 0.f}; - - sums.s0 += y[0+ 0] * ((float)((q1[0] & 0xF) | ((qh[0] & Q6_K_MASK1) << 4)) - 32.f); - sums.s1 += y[0+32] * ((float)((q2[0] & 0xF) | ((qh[0] & Q6_K_MASK2) << 2)) - 32.f); - sums.s2 += y[0+64] * ((float)((q1[0] >> 4) | ((qh[0] & Q6_K_MASK3) << 0)) - 32.f); - sums.s3 += y[0+96] * ((float)((q2[0] >> 4) | ((qh[0] & Q6_K_MASK4) >> 2)) - 32.f); - - sums.s0 += y[1+ 0] * ((float)((q1[1] & 0xF) | ((qh[1] & Q6_K_MASK1) << 4)) - 32.f); - sums.s1 += y[1+32] * ((float)((q2[1] & 0xF) | ((qh[1] & Q6_K_MASK2) << 2)) - 32.f); - sums.s2 += y[1+64] * ((float)((q1[1] >> 4) | ((qh[1] & Q6_K_MASK3) << 0)) - 32.f); - sums.s3 += y[1+96] * ((float)((q2[1] >> 4) | ((qh[1] & Q6_K_MASK4) >> 2)) - 32.f); - - sums.s0 += y[2+ 0] * ((float)((q1[2] & 0xF) | ((qh[2] & Q6_K_MASK1) << 4)) - 32.f); - sums.s1 += y[2+32] * ((float)((q2[2] & 0xF) | ((qh[2] & Q6_K_MASK2) << 2)) - 32.f); - sums.s2 += y[2+64] * ((float)((q1[2] >> 4) | ((qh[2] & Q6_K_MASK3) << 0)) - 32.f); - sums.s3 += y[2+96] * ((float)((q2[2] >> 4) | ((qh[2] & Q6_K_MASK4) >> 2)) - 32.f); - - sums.s0 += y[3+ 0] * ((float)((q1[3] & 0xF) | ((qh[3] & Q6_K_MASK1) << 4)) - 32.f); - sums.s1 += y[3+32] * ((float)((q2[3] & 0xF) | ((qh[3] & Q6_K_MASK2) << 2)) - 32.f); - sums.s2 += y[3+64] * ((float)((q1[3] >> 4) | ((qh[3] & Q6_K_MASK3) << 0)) - 32.f); - sums.s3 += y[3+96] * ((float)((q2[3] >> 4) | ((qh[3] & Q6_K_MASK4) >> 2)) - 32.f); - - sumf += dall * (sums.s0 * sc[0] + sums.s1 * sc[2] + sums.s2 * sc[4] + sums.s3 * sc[6]); - - return sumf; + // Vectorized loads: 3 uchar4 weight loads instead of 12 scalar byte reads. + // q_offset_l/h are 4-aligned, so these are aligned vector loads. + uchar4 q1v = vload4(0, q1); + uchar4 q2v = vload4(0, q2); + uchar4 qhv = vload4(0, qh); + + int4 q1i = convert_int4(q1v); + int4 q2i = convert_int4(q2v); + int4 qhi = convert_int4(qhv); + + // Reconstruct the four 6-bit weight groups (low/high nibble of ql OR'd with the + // matching 2-bit plane of qh), same arithmetic as the scalar version, then dot() + // against the cached activation lanes. + float4 w0 = convert_float4((q1i & 0xF) | ((qhi & Q6_K_MASK1) << 4)) - 32.f; + float4 w1 = convert_float4((q2i & 0xF) | ((qhi & Q6_K_MASK2) << 2)) - 32.f; + float4 w2 = convert_float4((q1i >> 4) | ((qhi & Q6_K_MASK3) )) - 32.f; + float4 w3 = convert_float4((q2i >> 4) | ((qhi & Q6_K_MASK4) >> 2)) - 32.f; + + return dall * (dot(y0, w0) * sc[0] + dot(y1, w1) * sc[2] + + dot(y2, w2) * sc[4] + dot(y3, w3) * sc[6]); } #undef N_DST @@ -89,7 +83,7 @@ inline float block_q_6_K_dot_y_flat( #define N_SIMDGROUP 2 #define N_SIMDWIDTH 16 #elif defined (ADRENO_GPU) -#define N_DST 4 +#define N_DST 16 #define N_SIMDGROUP 2 #define N_SIMDWIDTH 64 #endif @@ -146,49 +140,39 @@ kernel void kernel_mul_mv_q6_K_f32_flat( global half * blk_d = (global half *) src0_d + offset_src0_d; global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; - int tid = get_sub_group_local_id()/BLOCK_STRIDE; // first block_stride groups have tid=0 - int ix = get_sub_group_local_id()%BLOCK_STRIDE; // first block is 0..block_stride-1 + int tid = get_sub_group_local_id()%(N_SIMDWIDTH/BLOCK_STRIDE); // within-super-block part, 0..15 + int ix = get_sub_group_local_id()/(N_SIMDWIDTH/BLOCK_STRIDE); // super-block selector, 0..BLOCK_STRIDE-1 int ip = tid/8; // first or second half of (super) block (0 or 1) int il = tid%8; // each half has 8 parts, one per scale int n = 4; // 4 scales at a time (and 4 sums) int l0 = n*il; // offset into half-block, 0..28 int is = 8*ip + l0/16; // 0, 1, 8, 9 - float4 sumf = 0; + float sumf[N_DST]; + for (int row = 0; row < N_DST; row++) { + sumf[row] = 0.f; + } for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { - if (first_row + 0 < ne01) { - sumf.s0 += block_q_6_K_dot_y_flat(blk_ql + 0*nb*128, blk_qh + 0*nb*64, blk_scales + 0*nb*16, blk_d + 0*nb, yy, ib, ip, is, l0); - } - if (first_row + 1 < ne01) { - sumf.s1 += block_q_6_K_dot_y_flat(blk_ql + 1*nb*128, blk_qh + 1*nb*64, blk_scales + 1*nb*16, blk_d + 1*nb, yy, ib, ip, is, l0); - } - if (first_row + 2 < ne01) { - sumf.s2 += block_q_6_K_dot_y_flat(blk_ql + 2*nb*128, blk_qh + 2*nb*64, blk_scales + 2*nb*16, blk_d + 2*nb, yy, ib, ip, is, l0); - } - if (first_row + 3 < ne01) { - sumf.s3 += block_q_6_K_dot_y_flat(blk_ql + 3*nb*128, blk_qh + 3*nb*64, blk_scales + 3*nb*16, blk_d + 3*nb, yy, ib, ip, is, l0); + global float * y = yy + ib * QK_K + 128*ip + l0; + float4 y0 = vload4(0, y + 0); + float4 y1 = vload4(0, y + 32); + float4 y2 = vload4(0, y + 64); + float4 y3 = vload4(0, y + 96); + + for (int row = 0; row < N_DST; row++) { + if (first_row + row < ne01) { + sumf[row] += block_q_6_K_dot_y_flat( + blk_ql + row*nb*128, blk_qh + row*nb*64, blk_scales + row*nb*16, blk_d + row*nb, + ib, ip, is, l0, y0, y1, y2, y3); + } } } - float4 tot = (float4)( - sub_group_reduce_add(sumf.s0), - sub_group_reduce_add(sumf.s1), - sub_group_reduce_add(sumf.s2), - sub_group_reduce_add(sumf.s3) - ); - if (get_sub_group_local_id() == 0) { - if (first_row + 0 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; - } - if (first_row + 1 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; - } - if (first_row + 2 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; - } - if (first_row + 3 < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + for (int row = 0; row < N_DST; row++) { + float tot = sub_group_reduce_add(sumf[row]); + if (get_sub_group_local_id() == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; } } } From 1777deff4c014d4bfc895eafe24354deeec80c94 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam <rortlam@redhat.com> Date: Sat, 6 Jun 2026 09:11:35 +0200 Subject: [PATCH 792/831] vulkan: check coopmat2 features before reporting support (llama/24186) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index df410368a79..fc9bc8fe376 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -6349,6 +6349,15 @@ static void ggml_vk_print_gpu_info(size_t idx) { } #endif +#if defined(VK_NV_cooperative_matrix2) + VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; + coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features; + last_struct = (VkBaseOutStructure *)&coopmat2_features; + } +#endif + VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {}; coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV; if (coopmat2_decode_vector_support) { @@ -6380,6 +6389,19 @@ static void ggml_vk_print_gpu_info(size_t idx) { #endif && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture); +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + coopmat2_support = coopmat2_support && + coopmat2_features.cooperativeMatrixWorkgroupScope && + coopmat2_features.cooperativeMatrixFlexibleDimensions && + coopmat2_features.cooperativeMatrixReductions && + coopmat2_features.cooperativeMatrixConversions && + coopmat2_features.cooperativeMatrixPerElementOperations && + coopmat2_features.cooperativeMatrixTensorAddressing && + coopmat2_features.cooperativeMatrixBlockLoads; +#else + coopmat2_support = false; +#endif + coopmat2_decode_vector_support = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector; #if !defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) coopmat2_decode_vector_support = false; From 2c139c2e5ee7a0820fdcf5e1a308322027453eac Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen <son@huggingface.co> Date: Mon, 8 Jun 2026 08:03:18 +0200 Subject: [PATCH 793/831] metal : fix im2col 1D case (audio models) (llama/24220) --- ggml/src/ggml-metal/ggml-metal-device.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 5d4b10d34b9..ce847dd8b6f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1738,10 +1738,14 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_meta GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); + const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1; + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + char base[256]; char name[256]; - if (ne00*ne01 <= 1024) { + if (KH*KW <= 1024) { snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); } else { snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type)); From 4669631d20f32342ae58c346224eeabde251fa07 Mon Sep 17 00:00:00 2001 From: Harkirat Gill <harkirat.gill@amd.com> Date: Mon, 8 Jun 2026 02:33:23 -0400 Subject: [PATCH 794/831] HIP: add gfx1152 and gfx1153 to RDNA3.5 (llama/24129) --- ggml/src/ggml-cuda/vendors/hip.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 5e0e22c7fc2..a6115cd80dc 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -219,9 +219,9 @@ #define RDNA3 #endif // defined(__GFX11__) -#if defined(__gfx1150__) || defined(__gfx1151__) +#if defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__) #define RDNA3_5 -#endif // defined(__gfx1150__) || defined(__gfx1151__) +#endif // defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__) #if defined(RDNA3) && !defined(RDNA3_5) #define RDNA3_0 From b932ec55298a6ebd2bcd2a0d2e62f9df61e6008d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Mon, 8 Jun 2026 12:52:17 +0300 Subject: [PATCH 795/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 538ef80bc7a..f42565bab0f 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -1e33fed33e87c43aa4c4078e2a9c239d4c1f1bd3 +c95cd071e1bf235ac41ef58c5a5535f73024375c From b31466b4a13d4d55d1bbd9a6055861a8bc3968de Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Mon, 8 Jun 2026 12:51:59 +0300 Subject: [PATCH 796/831] ggml : bump version to 0.14.0 (ggml/1533) --- ggml/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index dc8899b46ef..8f7cb8cdfd2 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,8 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 13) -set(GGML_VERSION_PATCH 1) +set(GGML_VERSION_MINOR 14) +set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 4df9a57df23a1eb5a47ce988d606b81e7dc0db27 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Mon, 8 Jun 2026 12:52:27 +0300 Subject: [PATCH 797/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index f42565bab0f..6e1bf3a1f4b 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -c95cd071e1bf235ac41ef58c5a5535f73024375c +7142aa6bf9fcaeec0fef8d80fcd90afe4268adf1 From 84bd03a438454a82150853dce83818013c6609d2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Mon, 8 Jun 2026 12:55:06 +0300 Subject: [PATCH 798/831] talk-llama : sync llama.cpp --- examples/talk-llama/CMakeLists.txt | 1 + examples/talk-llama/llama-adapter.cpp | 8 +- examples/talk-llama/llama-arch.cpp | 13 + examples/talk-llama/llama-arch.h | 10 + examples/talk-llama/llama-context.cpp | 224 ++++---- examples/talk-llama/llama-context.h | 17 +- examples/talk-llama/llama-cparams.h | 9 +- examples/talk-llama/llama-ext.h | 14 +- examples/talk-llama/llama-graph.cpp | 322 ++++++++--- examples/talk-llama/llama-graph.h | 109 +++- examples/talk-llama/llama-hparams.cpp | 88 +-- examples/talk-llama/llama-hparams.h | 67 ++- examples/talk-llama/llama-impl.h | 14 + examples/talk-llama/llama-kv-cache-dsa.cpp | 261 +++++++++ examples/talk-llama/llama-kv-cache-dsa.h | 138 +++++ examples/talk-llama/llama-kv-cache-iswa.cpp | 22 +- examples/talk-llama/llama-kv-cache-iswa.h | 4 +- examples/talk-llama/llama-kv-cache.cpp | 222 ++++++-- examples/talk-llama/llama-kv-cache.h | 15 +- examples/talk-llama/llama-kv-cells.h | 2 + .../talk-llama/llama-memory-hybrid-iswa.cpp | 6 +- examples/talk-llama/llama-memory-hybrid.cpp | 7 +- .../talk-llama/llama-memory-recurrent.cpp | 8 +- examples/talk-llama/llama-memory.h | 4 + examples/talk-llama/llama-model-loader.cpp | 15 +- examples/talk-llama/llama-model-saver.cpp | 16 +- examples/talk-llama/llama-model.cpp | 335 ++++++++---- examples/talk-llama/llama-model.h | 14 +- examples/talk-llama/llama-quant.cpp | 4 +- examples/talk-llama/llama-vocab.cpp | 95 +++- examples/talk-llama/llama-vocab.h | 112 ++-- examples/talk-llama/llama.cpp | 9 +- examples/talk-llama/llama.h | 11 +- examples/talk-llama/models/afmoe.cpp | 2 +- examples/talk-llama/models/apertus.cpp | 11 +- examples/talk-llama/models/arcee.cpp | 2 +- examples/talk-llama/models/arctic.cpp | 2 +- examples/talk-llama/models/arwkv7.cpp | 2 +- examples/talk-llama/models/baichuan.cpp | 2 +- examples/talk-llama/models/bailingmoe.cpp | 2 +- examples/talk-llama/models/bailingmoe2.cpp | 21 +- examples/talk-llama/models/bert.cpp | 4 +- examples/talk-llama/models/bitnet.cpp | 2 +- examples/talk-llama/models/bloom.cpp | 2 +- examples/talk-llama/models/chameleon.cpp | 2 +- examples/talk-llama/models/chatglm.cpp | 3 +- examples/talk-llama/models/codeshell.cpp | 3 +- examples/talk-llama/models/cogvlm.cpp | 3 +- examples/talk-llama/models/cohere2.cpp | 4 +- examples/talk-llama/models/command-r.cpp | 3 +- examples/talk-llama/models/dbrx.cpp | 12 +- examples/talk-llama/models/deci.cpp | 3 +- examples/talk-llama/models/deepseek2.cpp | 11 +- examples/talk-llama/models/deepseek2ocr.cpp | 2 +- examples/talk-llama/models/deepseek32.cpp | 499 ++++++++++++++++++ examples/talk-llama/models/dots1.cpp | 3 +- examples/talk-llama/models/dream.cpp | 3 +- examples/talk-llama/models/ernie4-5.cpp | 2 +- examples/talk-llama/models/eurobert.cpp | 2 +- examples/talk-llama/models/exaone-moe.cpp | 22 +- examples/talk-llama/models/exaone.cpp | 2 +- examples/talk-llama/models/exaone4.cpp | 45 +- examples/talk-llama/models/falcon-h1.cpp | 4 +- examples/talk-llama/models/falcon.cpp | 2 +- .../talk-llama/models/gemma-embedding.cpp | 2 +- examples/talk-llama/models/gemma.cpp | 2 +- examples/talk-llama/models/gemma2.cpp | 2 +- examples/talk-llama/models/gemma3.cpp | 2 +- examples/talk-llama/models/gemma3n.cpp | 6 +- .../talk-llama/models/gemma4-assistant.cpp | 200 +++++++ examples/talk-llama/models/gemma4.cpp | 59 ++- examples/talk-llama/models/glm-dsa.cpp | 17 +- examples/talk-llama/models/glm4-moe.cpp | 24 +- examples/talk-llama/models/glm4.cpp | 20 +- examples/talk-llama/models/gpt2.cpp | 3 +- examples/talk-llama/models/gptneox.cpp | 3 +- examples/talk-llama/models/granite-hybrid.cpp | 8 +- examples/talk-llama/models/granite-moe.cpp | 2 +- examples/talk-llama/models/granite.cpp | 39 +- examples/talk-llama/models/grok.cpp | 2 +- examples/talk-llama/models/grovemoe.cpp | 2 +- examples/talk-llama/models/hunyuan-moe.cpp | 2 +- examples/talk-llama/models/internlm2.cpp | 3 +- examples/talk-llama/models/jais.cpp | 2 +- examples/talk-llama/models/jais2.cpp | 2 +- examples/talk-llama/models/jamba.cpp | 6 +- examples/talk-llama/models/jina-bert-v2.cpp | 2 +- examples/talk-llama/models/jina-bert-v3.cpp | 2 +- examples/talk-llama/models/kimi-linear.cpp | 10 +- examples/talk-llama/models/lfm2.cpp | 20 +- examples/talk-llama/models/lfm2moe.cpp | 8 +- examples/talk-llama/models/llada-moe.cpp | 5 +- examples/talk-llama/models/llada.cpp | 4 +- examples/talk-llama/models/llama.cpp | 4 +- examples/talk-llama/models/llama4.cpp | 5 +- examples/talk-llama/models/maincoder.cpp | 3 +- examples/talk-llama/models/mamba.cpp | 2 +- examples/talk-llama/models/mamba2.cpp | 2 +- examples/talk-llama/models/mellum.cpp | 225 ++++++++ examples/talk-llama/models/mimo2.cpp | 23 +- examples/talk-llama/models/minicpm.cpp | 4 +- examples/talk-llama/models/minicpm3.cpp | 2 +- examples/talk-llama/models/minimax-m2.cpp | 2 +- examples/talk-llama/models/mistral3.cpp | 2 +- examples/talk-llama/models/models.h | 42 ++ examples/talk-llama/models/modern-bert.cpp | 13 +- examples/talk-llama/models/mpt.cpp | 2 +- examples/talk-llama/models/nemotron-h.cpp | 10 +- examples/talk-llama/models/nemotron.cpp | 3 +- examples/talk-llama/models/neo-bert.cpp | 2 +- examples/talk-llama/models/nomic-bert-moe.cpp | 2 +- examples/talk-llama/models/nomic-bert.cpp | 2 +- examples/talk-llama/models/olmo.cpp | 2 +- examples/talk-llama/models/olmo2.cpp | 2 +- examples/talk-llama/models/olmoe.cpp | 3 +- examples/talk-llama/models/openai-moe.cpp | 2 +- examples/talk-llama/models/openelm.cpp | 12 +- examples/talk-llama/models/orion.cpp | 2 +- examples/talk-llama/models/pangu-embed.cpp | 3 +- examples/talk-llama/models/phi2.cpp | 2 +- examples/talk-llama/models/phi3.cpp | 2 +- examples/talk-llama/models/phimoe.cpp | 2 +- examples/talk-llama/models/plamo.cpp | 2 +- examples/talk-llama/models/plamo2.cpp | 10 +- examples/talk-llama/models/plamo3.cpp | 2 +- examples/talk-llama/models/plm.cpp | 3 +- examples/talk-llama/models/qwen.cpp | 2 +- examples/talk-llama/models/qwen2.cpp | 3 +- examples/talk-llama/models/qwen2moe.cpp | 3 +- examples/talk-llama/models/qwen3.cpp | 3 +- examples/talk-llama/models/qwen35.cpp | 88 +-- examples/talk-llama/models/qwen35moe.cpp | 87 +-- examples/talk-llama/models/qwen3moe.cpp | 6 +- examples/talk-llama/models/qwen3next.cpp | 12 +- examples/talk-llama/models/qwen3vl.cpp | 3 +- examples/talk-llama/models/qwen3vlmoe.cpp | 3 +- examples/talk-llama/models/refact.cpp | 3 +- examples/talk-llama/models/rnd1.cpp | 5 +- examples/talk-llama/models/rwkv6.cpp | 2 +- examples/talk-llama/models/rwkv6qwen2.cpp | 2 +- examples/talk-llama/models/rwkv7.cpp | 2 +- examples/talk-llama/models/seed-oss.cpp | 3 +- examples/talk-llama/models/smallthinker.cpp | 4 +- examples/talk-llama/models/smollm3.cpp | 2 +- examples/talk-llama/models/stablelm.cpp | 2 +- examples/talk-llama/models/starcoder.cpp | 3 +- examples/talk-llama/models/starcoder2.cpp | 3 +- examples/talk-llama/models/step35.cpp | 314 ++++++++++- examples/talk-llama/models/t5.cpp | 4 +- examples/talk-llama/models/talkie.cpp | 2 +- examples/talk-llama/models/xverse.cpp | 3 +- 151 files changed, 3434 insertions(+), 866 deletions(-) create mode 100644 examples/talk-llama/llama-kv-cache-dsa.cpp create mode 100644 examples/talk-llama/llama-kv-cache-dsa.h create mode 100644 examples/talk-llama/models/deepseek32.cpp create mode 100644 examples/talk-llama/models/gemma4-assistant.cpp create mode 100644 examples/talk-llama/models/mellum.cpp diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt index 1adeef8f511..13b284ed0e9 100644 --- a/examples/talk-llama/CMakeLists.txt +++ b/examples/talk-llama/CMakeLists.txt @@ -20,6 +20,7 @@ if (WHISPER_SDL2) llama-io.cpp llama-kv-cache.cpp llama-kv-cache-iswa.cpp + llama-kv-cache-dsa.cpp llama-memory-recurrent.cpp llama-memory-hybrid.cpp llama-memory-hybrid-iswa.cpp diff --git a/examples/talk-llama/llama-adapter.cpp b/examples/talk-llama/llama-adapter.cpp index 4a1aaa955a8..3e0fe66afff 100644 --- a/examples/talk-llama/llama-adapter.cpp +++ b/examples/talk-llama/llama-adapter.cpp @@ -41,7 +41,7 @@ bool llama_adapter_cvec::init(const llama_model & model) { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(), + /*.mem_size =*/ hparams.n_layer()*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -61,9 +61,9 @@ bool llama_adapter_cvec::init(const llama_model & model) { }; // make tensors - tensors.reserve(hparams.n_layer); + tensors.reserve(hparams.n_layer()); tensors.push_back(nullptr); // there's never a tensor for layer 0 - for (size_t il = 1; il < hparams.n_layer; il++) { + for (size_t il = 1; il < hparams.n_layer(); il++) { ggml_backend_buffer_type_t buft = model.select_buft(il); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { @@ -121,7 +121,7 @@ bool llama_adapter_cvec::apply( layer_start = il_start; layer_end = il_end; - for (size_t il = 1; il < hparams.n_layer; il++) { + for (size_t il = 1; il < hparams.n_layer(); il++) { assert(tensors[il] != nullptr); const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index e95ba6daac1..6a5d5f8d2ac 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -57,6 +57,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, { LLM_ARCH_GEMMA4, "gemma4" }, + { LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" }, { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, @@ -75,6 +76,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_DEEPSEEK, "deepseek" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, { LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" }, + { LLM_ARCH_DEEPSEEK32, "deepseek32" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, @@ -134,6 +136,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_MAINCODER, "maincoder" }, { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, { LLM_ARCH_TALKIE, "talkie" }, + { LLM_ARCH_MELLUM, "mellum" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -194,6 +197,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_MOE_LATENT_SIZE, "%s.moe_latent_size" }, { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" }, { LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" }, + { LLM_KV_DEEPSTACK_MAPPING, "%s.deepstack_mapping" }, + { LLM_KV_HIDDEN_ACT, "%s.hidden_activation" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, @@ -244,6 +249,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, { LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" }, + { LLM_KV_ATTENTION_RECURRENT_LAYERS, "%s.attention.recurrent_layers" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" }, @@ -318,12 +324,14 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, + { LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, "tokenizer.ggml.normalizer.lowercase" }, { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, + { LLM_KV_TOKENIZER_SUPPRESS_TOKENS, "tokenizer.ggml.suppress_tokens" }, { LLM_KV_ADAPTER_TYPE, "adapter.type" }, { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, @@ -446,6 +454,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, + { LLM_TENSOR_NEXTN_PROJ_PRE, "nextn.pre_projection" }, + { LLM_TENSOR_NEXTN_PROJ_POST, "nextn.post_projection" }, { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, @@ -758,6 +768,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so // the model loader doesn't fault on the block index. @@ -904,6 +916,7 @@ bool llm_arch_supports_sm_tensor(const llm_arch & arch) { case LLM_ARCH_OLMO2: case LLM_ARCH_OLMOE: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK32: case LLM_ARCH_GLM_DSA: case LLM_ARCH_BITNET: case LLM_ARCH_T5: diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 7c1dcc4d6c2..03b1a265d67 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -61,6 +61,7 @@ enum llm_arch { LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, LLM_ARCH_GEMMA4, + LLM_ARCH_GEMMA4_ASSISTANT, LLM_ARCH_GEMMA_EMBEDDING, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, @@ -79,6 +80,7 @@ enum llm_arch { LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK2, LLM_ARCH_DEEPSEEK2OCR, + LLM_ARCH_DEEPSEEK32, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, @@ -138,6 +140,7 @@ enum llm_arch { LLM_ARCH_MAINCODER, LLM_ARCH_KIMI_LINEAR, LLM_ARCH_TALKIE, + LLM_ARCH_MELLUM, LLM_ARCH_UNKNOWN, }; @@ -198,6 +201,8 @@ enum llm_kv { LLM_KV_MOE_LATENT_SIZE, LLM_KV_NEXTN_PREDICT_LAYERS, LLM_KV_NUM_DEEPSTACK_LAYERS, + LLM_KV_DEEPSTACK_MAPPING, + LLM_KV_HIDDEN_ACT, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, @@ -248,6 +253,7 @@ enum llm_kv { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, LLM_KV_ATTENTION_INDEXER_TOP_K, LLM_KV_ATTENTION_SHARED_KV_LAYERS, + LLM_KV_ATTENTION_RECURRENT_LAYERS, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_COUNT_SWA, @@ -307,12 +313,14 @@ enum llm_kv { LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_CHAT_TEMPLATE, + LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_MID_ID, LLM_KV_TOKENIZER_FIM_PAD_ID, LLM_KV_TOKENIZER_FIM_REP_ID, LLM_KV_TOKENIZER_FIM_SEP_ID, + LLM_KV_TOKENIZER_SUPPRESS_TOKENS, LLM_KV_ADAPTER_TYPE, LLM_KV_ADAPTER_LORA_ALPHA, @@ -550,6 +558,8 @@ enum llm_tensor { LLM_TENSOR_INDEXER_PROJ, LLM_TENSOR_INDEXER_ATTN_K, LLM_TENSOR_INDEXER_ATTN_Q_B, + LLM_TENSOR_NEXTN_PROJ_PRE, + LLM_TENSOR_NEXTN_PROJ_POST, LLM_TENSOR_NEXTN_EH_PROJ, LLM_TENSOR_NEXTN_EMBED_TOKENS, LLM_TENSOR_NEXTN_ENORM, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index ad36c06667d..9a40c4366af 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -58,19 +58,21 @@ llama_context::llama_context( cparams.n_rs_seq = 0; } - cparams.n_threads = params.n_threads; - cparams.n_threads_batch = params.n_threads_batch; - cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; - cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; - cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; - cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; - cparams.embeddings = params.embeddings; - cparams.embeddings_pre_norm = false; - cparams.embeddings_pre_norm_masked = false; - cparams.offload_kqv = params.offload_kqv; - cparams.no_perf = params.no_perf; - cparams.pooling_type = params.pooling_type; - cparams.warmup = false; + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch; + cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; + cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; + cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; + cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; + cparams.embeddings = params.embeddings; + cparams.embeddings_nextn = false; + cparams.embeddings_nextn_masked = false; + cparams.offload_kqv = params.offload_kqv; + cparams.no_perf = params.no_perf; + cparams.warmup = false; + + cparams.ctx_type = params.ctx_type; + cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; @@ -83,7 +85,17 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; - cparams.ctx_type = params.ctx_type; + cparams.ctx_other = nullptr; + + // TODO: more generic + if (model.arch == LLM_ARCH_GEMMA4_ASSISTANT) { + if (params.ctx_other == nullptr) { + // TODO: change from runtime_error to llama_exception to avoid printing error message + throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this is normal during memory fitting)"); + } + + cparams.ctx_other = params.ctx_other; + } // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later @@ -182,6 +194,8 @@ llama_context::llama_context( cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + cparams.n_outputs_max = params.n_outputs_max == 0 ? cparams.n_batch : params.n_outputs_max; + cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; @@ -227,6 +241,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq); + LLAMA_LOG_INFO("%s: n_outputs_max = %u\n", __func__, cparams.n_outputs_max); if (cparams.n_ctx_seq < hparams.n_ctx_train) { LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", @@ -296,10 +311,11 @@ llama_context::llama_context( // init the memory module if (!hparams.vocab_only) { llama_memory_params params_mem = { - /*.type_k =*/ params.type_k, - /*.type_v =*/ params.type_v, - /*.swa_full =*/ params.swa_full, - /*.ctx_type= */ cparams.ctx_type, + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, + /*.swa_full =*/ params.swa_full, + /*.ctx_type =*/ cparams.ctx_type, + /*.mem_other =*/ llama_get_memory(cparams.ctx_other), }; memory.reset(model.create_memory(params_mem, cparams)); @@ -337,7 +353,7 @@ llama_context::llama_context( // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary bool pipeline_parallel = model.n_devices() > 1 && - model.n_gpu_layers() > model.hparams.n_layer && + model.n_gpu_layers() > model.hparams.n_layer_all && model.split_mode() == LLAMA_SPLIT_MODE_LAYER && cparams.offload_kqv && !model.has_tensor_overrides(); @@ -531,7 +547,7 @@ void llama_context::sched_reserve() { // note: n_outputs must match n_tokens for embedding models with mean/rank pooling, // because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies // it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens, - // the ggml_mul_mat assertion fails. this matches the pp reservation below (line ~553). + // the ggml_mul_mat assertion fails. const uint32_t n_tokens_ch = 16*n_seqs; auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true); if (!gf) { @@ -577,16 +593,18 @@ void llama_context::sched_reserve() { int n_splits_tg = -1; int n_nodes_tg = -1; + const uint32_t n_outputs_pp = std::min(n_tokens, cparams.n_outputs_max); + // reserve pp (prompt processing) graph first so that buffers are only allocated once { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get(), model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); if (!gf) { if (cparams.pipeline_parallel) { LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); cparams.pipeline_parallel = false; sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); - gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get()); } if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); @@ -614,7 +632,7 @@ void llama_context::sched_reserve() { // // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); // - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get(), model.hparams.no_alloc); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } @@ -774,7 +792,9 @@ bool llama_context::memory_update(bool optimize) { const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + const uint32_t n_outputs_max = std::min(n_tokens, cparams.n_outputs_max); + + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_max, mctx.get()); if (!gf) { LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); } @@ -882,34 +902,34 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } -float * llama_context::get_embeddings_pre_norm() { +float * llama_context::get_embeddings_nextn() { output_reorder(); - return embd_pre_norm.data; + return embd_nextn.data; } -float * llama_context::get_embeddings_pre_norm_ith(int32_t i) { +float * llama_context::get_embeddings_nextn_ith(int32_t i) { output_reorder(); try { - if (embd_pre_norm.data == nullptr) { - throw std::runtime_error("no pre-norm embeddings"); + if (embd_nextn.data == nullptr) { + throw std::runtime_error("no nextn embeddings"); } - const uint32_t n_embd = model.hparams.n_embd; + const uint32_t n_embd = model.hparams.n_embd_out(); - if (!cparams.embeddings_pre_norm_masked) { - // unmasked: pre-norm rows are stored densely, indexed by raw token position. - if (i < 0 || (size_t)(i + 1) * n_embd > embd_pre_norm.size) { - throw std::runtime_error(format("out of range [0, %zu)", embd_pre_norm.size / n_embd)); + if (!cparams.embeddings_nextn_masked) { + // unmasked: nextn rows are stored densely, indexed by raw token position. + if (i < 0 || (size_t)(i + 1) * n_embd > embd_nextn.size) { + throw std::runtime_error(format("out of range [0, %zu)", embd_nextn.size / n_embd)); } - return embd_pre_norm.data + (size_t) i * n_embd; + return embd_nextn.data + (size_t) i * n_embd; } const int64_t j = output_resolve_row(i); - return embd_pre_norm.data + j*n_embd; + return embd_nextn.data + j*n_embd; } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: invalid pre-norm embeddings id %d, reason: %s\n", __func__, i, err.what()); + LLAMA_LOG_ERROR("%s: invalid nextn embeddings id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG GGML_ABORT("fatal error"); #else @@ -1098,11 +1118,11 @@ void llama_context::set_embeddings(bool value) { //sched_need_reserve = true; } -void llama_context::set_embeddings_pre_norm(bool value, bool masked) { +void llama_context::set_embeddings_nextn(bool value, bool masked) { LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked); - cparams.embeddings_pre_norm = value; - cparams.embeddings_pre_norm_masked = masked; + cparams.embeddings_nextn = value; + cparams.embeddings_nextn_masked = masked; } void llama_context::set_causal_attn(bool value) { @@ -1319,7 +1339,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } int llama_context::encode(const llama_batch & batch_inp) { - // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row), + // MTP hook batches carry both token (next-token id) and embd (h_nextn row), // so accept either present rather than requiring exactly one. GGML_ASSERT(batch_inp.token || batch_inp.embd); @@ -1392,9 +1412,9 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - auto * t_logits = res->get_logits(); - auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); - auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr; // extract logits if (logits.data && t_logits) { @@ -1460,14 +1480,14 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - // extract pre-norm embeddings (hidden state before the final output norm) - if (embd_pre_norm.data && t_h_pre_norm && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { - ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + // extract nextn embeddings (hidden state before the final output norm) + if (embd_nextn.data && t_h_nextn && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn); GGML_ASSERT(backend_h != nullptr); - const uint32_t n_embd = hparams.n_embd; - GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_pre_norm.size); - ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm.data, 0, n_tokens*n_embd*sizeof(float)); + const uint32_t n_embd = hparams.n_embd_out(); + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size); + ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float)); } // TODO: hacky solution @@ -1622,7 +1642,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s } int llama_context::decode(const llama_batch & batch_inp) { - // MTP hook batches carry both token (next-token id) and embd (h_pre_norm row), + // MTP hook batches carry both token (next-token id) and embd (h_nextn row), // so accept either present rather than requiring exactly one. GGML_ASSERT(batch_inp.token || batch_inp.embd); @@ -1822,9 +1842,9 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - auto * t_logits = res->get_logits(); - auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; - auto * t_h_pre_norm = cparams.embeddings_pre_norm ? res->get_h_pre_norm() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr; if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); @@ -1905,22 +1925,22 @@ int llama_context::decode(const llama_batch & batch_inp) { } } - // extract pre-norm embeddings (hidden state before the final output norm) + // extract nextn embeddings before // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. { - const bool masked = cparams.embeddings_pre_norm_masked; + const bool masked = cparams.embeddings_nextn_masked; const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens; const int64_t offset = masked ? n_outputs_prev : n_tokens_prev; - if (embd_pre_norm.data && t_h_pre_norm && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { - ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_pre_norm); + if (embd_nextn.data && t_h_nextn && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn); GGML_ASSERT(backend_h != nullptr); - const uint32_t n_embd = hparams.n_embd; - float * embd_pre_norm_out = embd_pre_norm.data + offset*n_embd; + const uint32_t n_embd = hparams.n_embd_out(); + float * embd_nextn_out = embd_nextn.data + offset*n_embd; - GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_pre_norm.size); - ggml_backend_tensor_get_async(backend_h, t_h_pre_norm, embd_pre_norm_out, 0, n_rows*n_embd*sizeof(float)); + GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size); + ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn_out, 0, n_rows*n_embd*sizeof(float)); } } @@ -2009,12 +2029,11 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); - const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); - bool has_logits = true; - bool has_embd = cparams.embeddings; - bool has_embd_pre_norm = cparams.embeddings_pre_norm; + bool has_logits = true; + bool has_embd = cparams.embeddings; + bool has_embd_nextn = cparams.embeddings_nextn; // TODO: hacky enc-dec support if (model.arch == LLM_ARCH_T5) { @@ -2026,14 +2045,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { size_t backend_float_count = 0; size_t backend_token_count = 0; - logits.size = has_logits ? n_vocab*n_outputs_max : 0; - embd.size = has_embd ? n_embd_out*n_outputs_max : 0; - embd_pre_norm.size = has_embd_pre_norm ? n_embd*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + embd_nextn.size = has_embd_nextn ? n_embd_out*n_outputs_max : 0; - if (has_embd_pre_norm && !cparams.embeddings_pre_norm_masked) { - // unmasked: pre-norm row exists for every token in the batch, not just + if (has_embd_nextn && !cparams.embeddings_nextn_masked) { + // unmasked: nextn row exists for every token in the batch, not just // those flagged via batch.logits[i] -> size by token count instead. - embd_pre_norm.size = (size_t) n_embd * n_batch; + embd_nextn.size = (size_t) n_embd_out * n_batch; } // Allocate backend sampling output buffers if there are backend samplers configured. @@ -2050,7 +2069,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits.size + embd.size + embd_pre_norm.size + backend_float_count) * sizeof(float) + + (logits.size + embd.size + embd_nextn.size + backend_float_count) * sizeof(float) + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required @@ -2067,7 +2086,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { buf_output = nullptr; logits.data = nullptr; embd.data = nullptr; - embd_pre_norm.data = nullptr; + embd_nextn.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -2096,8 +2115,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0}; offset += embd.size * sizeof(float); - embd_pre_norm = has_embd_pre_norm ? buffer_view<float>{(float *) (base + offset), embd_pre_norm.size} : buffer_view<float>{nullptr, 0}; - offset += embd_pre_norm.size * sizeof(float); + embd_nextn = has_embd_nextn ? buffer_view<float>{(float *) (base + offset), embd_nextn.size} : buffer_view<float>{nullptr, 0}; + offset += embd_nextn.size * sizeof(float); if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; @@ -2140,6 +2159,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { this->n_outputs = 0; + GGML_ASSERT(n_outputs_max <= cparams.n_outputs_max); + return n_outputs_max; } @@ -2163,9 +2184,9 @@ void llama_context::output_reorder() { } } - if (embd_pre_norm.size > 0) { + if (embd_nextn.size > 0) { for (uint64_t k = 0; k < n_embd; k++) { - std::swap(embd_pre_norm.data[i0*n_embd + k], embd_pre_norm.data[i1*n_embd + k]); + std::swap(embd_nextn.data[i0*n_embd + k], embd_nextn.data[i1*n_embd + k]); } } @@ -2226,8 +2247,6 @@ ggml_cgraph * llama_context::graph_reserve( if (n_tokens % n_seqs != 0) { n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs - n_outputs = std::max(n_outputs, n_tokens); - LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); } @@ -2343,7 +2362,7 @@ llm_graph_cb llama_context::graph_get_cb() const { // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends // FIXME: fix in ggml_backend_sched - const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; + const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer_all; if (ubatch.n_tokens < 32 || full_offload) { if (il != -1 && strcmp(name, "norm") == 0) { const auto & dev_layer = model.dev_layer(il); @@ -3337,6 +3356,7 @@ llama_context_params llama_context_default_params() { /*.n_ubatch =*/ 512, /*.n_seq_max =*/ 1, /*.n_rs_seq =*/ 0, + /*.n_outputs_max =*/ 0, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, @@ -3366,6 +3386,7 @@ llama_context_params llama_context_default_params() { /*.kv_unified =*/ false, /*.sampler =*/ nullptr, /*.n_sampler =*/ 0, + /*.ctx_other =*/ nullptr, }; return result; @@ -3403,15 +3424,11 @@ llama_context * llama_init_from_model( LLAMA_LOG_ERROR("%s: SPLIT_MODE_TENSOR requires flash_attn to be enabled\n", __func__); return nullptr; } - if (ggml_is_quantized(params.type_k) || ggml_is_quantized(params.type_v)) { - LLAMA_LOG_ERROR("%s: simultaneous use of SPLIT_MODE_TENSOR and KV cache quantization not implemented\n", __func__); - return nullptr; - } } if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) { const uint32_t blck_size = ggml_blck_size(params.type_k); - for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { + for (uint32_t il = 0; il < model->hparams.n_layer(); ++il) { if (model->hparams.n_embd_head_k(il) % blck_size != 0) { LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il)); @@ -3422,7 +3439,7 @@ llama_context * llama_init_from_model( if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_v)) { const uint32_t blck_size = ggml_blck_size(params.type_v); - for (uint32_t il = 0; il < model->hparams.n_layer; ++il) { + for (uint32_t il = 0; il < model->hparams.n_layer(); ++il) { if (model->hparams.n_embd_head_v(il) % blck_size != 0) { LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n", __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il)); @@ -3444,12 +3461,11 @@ llama_context * llama_init_from_model( } if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && - model->hparams.nextn_predict_layers == 0) { + model->hparams.n_layer_nextn == 0) { LLAMA_LOG_WARN("%s: context type MTP requested but model doesn't contain MTP layers\n", __func__); return nullptr; } - try { auto * ctx = new llama_context(*model, params); return ctx; @@ -3584,20 +3600,28 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } -void llama_set_embeddings_pre_norm(llama_context * ctx, bool value, bool masked) { - ctx->set_embeddings_pre_norm(value, masked); +void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) { + ctx->set_embeddings_nextn(value, masked); } -float * llama_get_embeddings_pre_norm(llama_context * ctx) { +llama_memory_t llama_get_memory(const struct llama_context * ctx) { + if (!ctx) { + return nullptr; + } + + return ctx->get_memory(); +} + +float * llama_get_embeddings_nextn(llama_context * ctx) { ctx->synchronize(); - return ctx->get_embeddings_pre_norm(); + return ctx->get_embeddings_nextn(); } -float * llama_get_embeddings_pre_norm_ith(llama_context * ctx, int32_t i) { +float * llama_get_embeddings_nextn_ith(llama_context * ctx, int32_t i) { ctx->synchronize(); - return ctx->get_embeddings_pre_norm_ith(i); + return ctx->get_embeddings_nextn_ith(i); } bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { @@ -3651,7 +3675,7 @@ struct ggml_cgraph * llama_graph_reserve( uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs) { - auto * memory = ctx->get_memory(); + auto memory = ctx->get_memory(); llama_memory_context_ptr mctx; if (memory) { mctx = memory->init_full(); @@ -3691,10 +3715,6 @@ int32_t llama_set_adapter_cvec( // memory // -llama_memory_t llama_get_memory(const struct llama_context * ctx) { - return ctx->get_memory(); -} - void llama_memory_clear(llama_memory_t mem, bool data) { if (!mem) { return; @@ -4005,3 +4025,7 @@ void llama_opt_epoch( llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx) { return ctx->memory_breakdown(); } + +llama_context * llama_get_ctx_other(struct llama_context * ctx) { + return ctx->get_cparams().ctx_other; +} diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index d03f681d4a1..6f8f59a22a3 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -6,6 +6,7 @@ #include "llama-graph.h" #include "llama-adapter.h" #include "llama-impl.h" +#include "llama-memory.h" #include "ggml-cpp.h" #include "ggml-opt.h" @@ -84,8 +85,8 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); - float * get_embeddings_pre_norm(); - float * get_embeddings_pre_norm_ith(int32_t i); + float * get_embeddings_nextn(); + float * get_embeddings_nextn_ith(int32_t i); llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -110,7 +111,7 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); - void set_embeddings_pre_norm(bool value, bool masked); + void set_embeddings_nextn(bool value, bool masked); void set_causal_attn(bool value); void set_warmup(bool value); @@ -273,7 +274,7 @@ struct llama_context { llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably - std::unique_ptr<llama_memory_i> memory; + llama_memory_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) buffer_view<float> logits = {nullptr, 0}; @@ -282,10 +283,10 @@ struct llama_context { // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE buffer_view<float> embd = {nullptr, 0}; - // hidden state before the final output norm (2-dimensional array: [n_outputs][n_embd]) - // populated only when cparams.embeddings_pre_norm is enabled and the model graph - // sets llm_graph_result::t_h_pre_norm - buffer_view<float> embd_pre_norm = {nullptr, 0}; + // hidden state required by the nextn layers (2-dimensional array: [n_outputs][n_embd]) + // populated only when cparams.embeddings_nextn is enabled and the model graph + // sets llm_graph_result::t_h_nextn + buffer_view<float> embd_nextn = {nullptr, 0}; struct sampling_info { // !samplers.empty() to check if any samplers are active diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 20ec59fe335..8a35d389ef4 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -13,6 +13,7 @@ struct llama_cparams { uint32_t n_ubatch; uint32_t n_seq_max; uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback + uint32_t n_outputs_max; // max outputs supported by the context int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -28,8 +29,8 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; - bool embeddings_pre_norm; // also extract the hidden state before the final output norm - bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0 + bool embeddings_nextn; // also extract the hidden state before the final output norm + bool embeddings_nextn_masked; // extract for only rows where batch.logits != 0 bool causal_attn; bool offload_kqv; bool flash_attn; @@ -38,7 +39,7 @@ struct llama_cparams { bool fused_gdn_ch; // use fused gated delta net (chunked) bool auto_fgdn; bool no_perf; - bool warmup; + bool warmup; // TODO: remove [TAG_LLAMA_GRAPH_NO_WARMUP] bool op_offload; bool kv_unified; bool pipeline_parallel; @@ -48,4 +49,6 @@ struct llama_cparams { ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; + + llama_context * ctx_other; }; diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h index edfa71c207c..bd74544129b 100644 --- a/examples/talk-llama/llama-ext.h +++ b/examples/talk-llama/llama-ext.h @@ -89,18 +89,16 @@ LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * m LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); -// -// pre-norm embeddings (hidden state before the final output norm) -// - -// Set whether the context outputs pre-norm embeddings or not +// Set whether the context outputs nextn embeddings or not // If masked == true, output the embeddings only for the tokens with batch.logits != 0 // If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits -LLAMA_API void llama_set_embeddings_pre_norm(struct llama_context * ctx, bool value, bool masked); +LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked); // mirrors: // LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); -LLAMA_API float * llama_get_embeddings_pre_norm (struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); // LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); -LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx, int32_t i); +LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i); + +LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx); diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index fc027de8b39..da7a9295561 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -7,6 +7,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" +#include "llama-kv-cache-dsa.h" #include "llama-memory-hybrid.h" #include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" @@ -29,7 +30,10 @@ static ggml_tensor * build_attn_inp_kq_mask( const auto n_tokens = ubatch.n_tokens; const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + // flash attention requires an f16 mask + const auto type = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + + ggml_tensor * res = ggml_new_tensor_4d(ctx, type, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(res); ggml_set_name(res, "attn_inp_kq_mask"); @@ -102,6 +106,39 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_embd_h::set_input(const llama_ubatch * ubatch) { + const int64_t n_tokens = ubatch->n_tokens; + + if (ubatch->token) { + ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens)); + } else { + // note: mtmd embedding input goes through here + GGML_ASSERT(ubatch->embd); + GGML_ASSERT(n_embd == embd->ne[0]); + + ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h)); + } + + // TODO: extend llama_ubatch to differentiate between token embeddings and hidden states + // for now, we assume that the hidden state is always provided as an embedding + // ref: https://github.com/ggml-org/llama.cpp/pull/23643 + if (ubatch->embd) { + GGML_ASSERT(n_embd == h->ne[0]); + + ggml_backend_tensor_set(h, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h)); + } +} + +bool llm_graph_input_embd_h::can_reuse(const llm_graph_params & params) { + bool res = true; + + res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); + res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); + res &= (!params.ubatch.embd) || (h && h->ne[1] == params.ubatch.n_tokens); + + return res; +} + void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && pos) { const int64_t n_tokens = ubatch->n_tokens; @@ -348,7 +385,8 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { } } -static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { +template <typename T> +static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__); const char * swa_type_str = "unknown"; @@ -359,7 +397,7 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64 case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break; }; - LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); + LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__); LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__); @@ -372,7 +410,7 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64 for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) { LLAMA_LOG_DEBUG(" %2d ", i); for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) { - float val = data[i * n_kv + j]; + float val = llama_cast<float>(data[i * n_kv + j]); if (val == -INFINITY) { LLAMA_LOG_DEBUG(" ∞"); } else { @@ -387,7 +425,10 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_kv = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens; - const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) { + const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) { + using T = std::remove_reference_t<decltype(*data)>; + std::fill(data, data + ne, llama_cast<T>(-INFINITY)); + for (int i1 = 0; i1 < n_tokens; ++i1) { const llama_seq_id s1 = ubatch->seq_id[i1][0]; const llama_pos p1 = ubatch->pos[i1]; @@ -413,38 +454,30 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { continue; } - data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; + data[idst + i0] = llama_cast<T>(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f); } } - }; - - { - GGML_ASSERT(self_kq_mask); - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); - - float * data = (float *) self_kq_mask->data; - - std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY); - - fill_mask(data, 0, LLAMA_SWA_TYPE_NONE); if (debug) { - print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE); + print_mask(data, n_tokens, n_kv, n_swa, swa_type); } + }; + + GGML_ASSERT(self_kq_mask); + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); + if (self_kq_mask->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE); + } else { + fill_mask((float *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE); } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(self_kq_mask_swa); GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); - - float * data = (float *) self_kq_mask_swa->data; - - std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY); - - fill_mask(data, hparams.n_swa, hparams.swa_type); - - if (debug) { - print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + if (self_kq_mask_swa->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type); + } else { + fill_mask((float *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type); } } } @@ -499,23 +532,51 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_attn_k_dsa::set_input(const llama_ubatch * ubatch) { + mctx->get_mla()->set_input_k_idxs(self_k_idxs_mla, ubatch); + + mctx->get_mla()->set_input_kq_mask(self_kq_mask_mla, ubatch, cparams.causal_attn); + + mctx->get_lid()->set_input_k_idxs(self_k_idxs_lid, ubatch); + + mctx->get_lid()->set_input_kq_mask(self_kq_mask_lid, ubatch, cparams.causal_attn); + + mctx->get_lid()->set_input_k_rot(self_k_rot_lid); +} + +bool llm_graph_input_attn_k_dsa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast<const llama_kv_cache_dsa_context *>(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs_mla->ne[0] == params.ubatch.n_tokens; + res &= self_k_idxs_lid->ne[0] == params.ubatch.n_tokens; + + res &= can_reuse_kq_mask(self_kq_mask_mla, mctx->get_mla(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_lid, mctx->get_lid(), params.ubatch, params.cparams); + + return res; +} + void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { // base tensors may not be allocated if there are no non-SWA attention layers if (self_k_idxs && self_k_idxs->buffer) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); - - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + // swa tensors may not be allocated if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); - - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + if (self_k_rot) { mctx->get_base()->set_input_k_rot(self_k_rot); } @@ -544,18 +605,18 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { if (self_k_idxs && self_k_idxs->buffer) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + // swa tensors may not be allocated if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + return res; } @@ -568,23 +629,30 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer)); GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing - float * data = (float *) cross_kq_mask->data; - - for (int i = 0; i < n_tokens; ++i) { - GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first"); - for (int j = 0; j < n_enc; ++j) { - float f = -INFINITY; + const auto fill_mask = [&](auto * data) { + using T = std::remove_reference_t<decltype(*data)>; + for (int i = 0; i < n_tokens; ++i) { + GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first"); + for (int j = 0; j < n_enc; ++j) { + float f = -INFINITY; - for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { - const llama_seq_id seq_id = ubatch->seq_id[i][s]; + for (int s = 0; s < ubatch->n_seq_id[i]; ++s) { + const llama_seq_id seq_id = ubatch->seq_id[i][s]; - if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { - f = 0.0f; + if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) { + f = 0.0f; + } } - } - data[i*n_enc + j] = f; + data[i*n_enc + j] = llama_cast<T>(f); + } } + }; + + if (cross_kq_mask->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) cross_kq_mask->data); + } else { + fill_mask((float *) cross_kq_mask->data); } } @@ -688,7 +756,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + } + if (inp_attn->self_kq_mask && inp_attn->self_kq_mask->buffer) { attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); } @@ -696,7 +766,9 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch); attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch); + } + if (inp_attn->self_kq_mask_swa && inp_attn->self_kq_mask_swa->buffer) { attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); } @@ -742,18 +814,18 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); + // swa tensors may not be allocated if there are no SWA attention layers if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - - res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); } + res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; @@ -861,8 +933,8 @@ void llm_graph_result::set_outputs() { if (t_embd_pooled != nullptr) { ggml_set_output(t_embd_pooled); } - if (t_h_pre_norm != nullptr) { - ggml_set_output(t_h_pre_norm); + if (t_h_nextn != nullptr) { + ggml_set_output(t_h_nextn); } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { @@ -937,7 +1009,8 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : cparams (params.cparams), ubatch (params.ubatch), n_embd (hparams.n_embd), - n_layer (hparams.n_layer), + n_layer (hparams.n_layer()), + n_layer_nextn (hparams.n_layer_nextn), n_rot (hparams.n_rot()), n_ctx (cparams.n_ctx), n_head (hparams.n_head()), @@ -1791,7 +1864,12 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { res->t_inp_embd = cur; // For Granite architecture - if (hparams.f_embedding_scale != 0.0f) { + // NOTE: Only apply scale to token inputs. Raw embeddings are assumed to be + // multimodal inputs that should not be scaled. + if (ubatch.token && hparams.f_embedding_scale != 0.0f) { + if (!ggml_is_contiguous(cur)) { + cur = ggml_cont(ctx0, cur); + } cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale); } @@ -2088,17 +2166,20 @@ ggml_tensor * llm_graph_context::build_attn_mha( llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const { auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams); + // flash attention requires an f16 mask + const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask_swa); - inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa; } else { inp->self_kq_mask_swa = nullptr; inp->self_kq_mask_swa_cnv = nullptr; @@ -2175,7 +2256,7 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl( inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; } inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0); @@ -2282,7 +2363,7 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; } return inp; @@ -2354,6 +2435,82 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_k_dsa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * wo_s, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * sinks, + ggml_tensor * v_mla, + ggml_tensor * top_k, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + // expand k later to enable rope fusion which directly writes into k-v cache + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, v_cur); + ggml_build_forward_expand(gf, k_cur); + + const auto * mctx_cur = inp->mctx->get_mla(); + + // store to KV cache + { + const auto & k_idxs = inp->get_k_idxs_mla(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + } + + const auto & kq_mask = inp->get_kq_mask_mla(); + + // prepare new kq mask - starts filled with -INFINITY + ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask, -INFINITY); + + // reshape KQ mask into tensor with rows of size 1: + // [n_kv, n_batch, 1, n_stream] -> [1, n_kv, n_batch, n_stream] + kq_mask_all = ggml_view_4d(ctx0, kq_mask_all, 1, kq_mask_all->ne[0], kq_mask_all->ne[1], kq_mask_all->ne[3], kq_mask_all->nb[0], kq_mask_all->nb[1], kq_mask_all->nb[2], 0); + + // reshape top_k indices: [n_top_k, n_batch, 1, n_stream] -> [n_top_k, n_batch, n_stream, 1] + ggml_tensor * top_k_3d = ggml_view_4d(ctx0, top_k, top_k->ne[0], top_k->ne[1], top_k->ne[3], 1, top_k->nb[1], top_k->nb[2], top_k->ne[3]*top_k->nb[3], 0); + + // prepare zero-filled tensor with rows of size 1: [1, n_top_k, n_batch, n_stream] + // this will be our source of zero values for unmasking top k mask elements + ggml_tensor * zeros = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 1, top_k_3d->ne[0], top_k_3d->ne[1], top_k_3d->ne[2]); + zeros = ggml_fill(ctx0, zeros, 0.0f); + + // modify KQ mask by unmasking elements that are in top_k indices + // ggml_set_rows([1, n_kv, n_batch, n_stream], [1, n_top_k, n_batch, n_stream], [n_top_k, n_batch, n_stream, 1]) + ggml_tensor * kq_mask_top_k = ggml_set_rows(ctx0, kq_mask_all, zeros, top_k_3d); + + // reshape to restore the original shape of KQ mask: + // [1, n_kv, n_batch, n_stream] -> [n_kv, n_batch, 1, n_stream] + kq_mask_top_k = ggml_view_4d(ctx0, kq_mask_top_k, kq_mask_top_k->ne[1], kq_mask_top_k->ne[2], 1, kq_mask_top_k->ne[3], kq_mask_top_k->nb[2], kq_mask_top_k->nb[3], kq_mask_top_k->nb[3], 0); + + // combine with the original kq mask + kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask); + + ggml_tensor * q = q_cur; + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0); + + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_top_k, sinks, v_mla, kq_scale, il); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur, wo_s); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, @@ -2446,10 +2603,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; - inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1); + // flash attention requires an f16 mask + const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + + inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_enc, n_tokens, 1, 1); ggml_set_input(inp->cross_kq_mask); - inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; + inp->cross_kq_mask_cnv = inp->cross_kq_mask; return (llm_graph_input_attn_cross *) res->add_input(std::move(inp)); } @@ -2497,6 +2657,34 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +llm_graph_input_attn_k_dsa * llm_graph_context::build_attn_inp_k_dsa() const { + const auto * mctx_cur = static_cast<const llama_kv_cache_dsa_context *>(mctx); + + auto inp = std::make_unique<llm_graph_input_attn_k_dsa>(hparams, cparams, mctx_cur); + + { + inp->self_k_idxs_mla = mctx_cur->get_mla()->build_input_k_idxs(ctx0, ubatch); + + inp->self_kq_mask_mla = build_attn_inp_kq_mask(ctx0, mctx_cur->get_mla(), ubatch, cparams); + inp->self_kq_mask_mla_cnv = inp->self_kq_mask_mla; + } + + { + inp->self_k_idxs_lid = mctx_cur->get_lid()->build_input_k_idxs(ctx0, ubatch); + + // ensure F32 mask + auto cparams_copy = cparams; + cparams_copy.flash_attn = false; + + inp->self_kq_mask_lid = build_attn_inp_kq_mask(ctx0, mctx_cur->get_lid(), ubatch, cparams_copy); + inp->self_kq_mask_lid_cnv = inp->self_kq_mask_lid; + + inp->self_k_rot_lid = mctx_cur->get_lid()->build_input_k_rot(ctx0); + } + + return (llm_graph_input_attn_k_dsa *) res->add_input(std::move(inp)); +} + // TODO: maybe separate the inner implementation into a separate function // like with the non-sliding window equivalent // once sliding-window hybrid caches are a thing. @@ -2510,7 +2698,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; } { @@ -2520,7 +2708,7 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); - inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa; } inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0); @@ -2689,7 +2877,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); - inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; + inp_attn->self_kq_mask_cnv = inp_attn->self_kq_mask; } { @@ -2697,7 +2885,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); - inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; + inp_attn->self_kq_mask_swa_cnv = inp_attn->self_kq_mask_swa; } auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index bf6778237e6..6793846e3ea 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -22,6 +22,7 @@ struct llama_layer; struct llama_memory_context_i; class llama_kv_cache_context; +class llama_kv_cache_dsa_context; class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; @@ -35,7 +36,8 @@ enum llm_graph_type { LLM_GRAPH_TYPE_DECODER_MTP, }; -enum llm_ffn_op_type { +enum llm_ffn_op_type : int { + LLM_FFN_NONE = 0, // sentinel: unset; archs must assign before use LLM_FFN_SILU, LLM_FFN_GELU, LLM_FFN_RELU, @@ -121,6 +123,23 @@ class llm_graph_input_embd : public llm_graph_input_i { const int64_t n_embd = 0; }; +// similar to llm_graph_input_embd but with an additional hidden state input +class llm_graph_input_embd_h : public llm_graph_input_i { +public: + llm_graph_input_embd_h(int64_t n_embd) : n_embd(n_embd) {} + virtual ~llm_graph_input_embd_h() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * tokens = nullptr; // I32 [n_batch] + ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] + ggml_tensor * h = nullptr; // F32 [n_embd, n_batch] + + const int64_t n_embd = 0; +}; + class llm_graph_input_pos : public llm_graph_input_i { public: llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {} @@ -274,10 +293,10 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i { ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } // n_tokens == n_batch - ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32/F16 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] const llama_hparams hparams; const llama_cparams cparams; @@ -307,8 +326,8 @@ class llm_graph_input_attn_kv : public llm_graph_input_i { ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] // note: assumes v_rot^2 == I ggml_tensor * self_k_rot = nullptr; @@ -347,8 +366,8 @@ class llm_graph_input_attn_k : public llm_graph_input_i { ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] const llama_hparams hparams; const llama_cparams cparams; @@ -356,6 +375,44 @@ class llm_graph_input_attn_k : public llm_graph_input_i { const llama_kv_cache_context * mctx; }; +class llm_graph_input_attn_k_dsa : public llm_graph_input_i { +public: + llm_graph_input_attn_k_dsa( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_dsa_context * mctx) : + hparams(hparams), + cparams(cparams), + mctx(mctx) { + } + ~llm_graph_input_attn_k_dsa() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * get_k_idxs_mla() const { return self_k_idxs_mla; } + ggml_tensor * get_k_idxs_lid() const { return self_k_idxs_lid; } + + ggml_tensor * get_kq_mask_mla() const { return self_kq_mask_mla_cnv; } + ggml_tensor * get_kq_mask_lid() const { return self_kq_mask_lid; } + + ggml_tensor * self_k_idxs_mla = nullptr; // I64 [n_batch] + ggml_tensor * self_k_idxs_lid = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask_mla = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_mla_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_lid = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_lid_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + ggml_tensor * self_k_rot_lid = nullptr; + + const llama_hparams hparams; + const llama_cparams cparams; + + const llama_kv_cache_dsa_context * mctx; +}; + class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { public: llm_graph_input_attn_kv_iswa( @@ -385,10 +442,10 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_k_rot = nullptr; ggml_tensor * self_v_rot = nullptr; @@ -411,8 +468,8 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; } - ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] - ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] + ggml_tensor * cross_kq_mask = nullptr; // F32/F16 [n_outputs_enc, n_batch, 1, 1] + ggml_tensor * cross_kq_mask_cnv = nullptr; // F32/F16 [n_outputs_enc, n_batch, 1, 1] const llama_cross * cross = nullptr; }; @@ -646,7 +703,7 @@ class llm_graph_result { ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } - ggml_tensor * get_h_pre_norm() const { return t_h_pre_norm; } + ggml_tensor * get_h_nextn() const { return t_h_nextn; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -675,7 +732,7 @@ class llm_graph_result { ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; - ggml_tensor * t_h_pre_norm = nullptr; // [n_embd, n_outputs] hidden state before final output norm + ggml_tensor * t_h_nextn = nullptr; // [n_embd, n_outputs] hidden state before final output norm std::map<llama_seq_id, ggml_tensor*> t_sampled_logits; std::map<llama_seq_id, ggml_tensor*> t_candidates; @@ -727,6 +784,7 @@ struct llm_graph_context { const int64_t n_embd; const int64_t n_layer; + const int64_t n_layer_nextn; const int64_t n_rot; const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) const int64_t n_head; @@ -956,6 +1014,23 @@ struct llm_graph_context { float kq_scale, int il) const; + llm_graph_input_attn_k_dsa * build_attn_inp_k_dsa() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_k_dsa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * wo_s, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + ggml_tensor * top_k, // [n_indexer_top_k, n_tokens] + float kq_scale, + int il) const; + llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const; // note: if k_cur or v_cur are not provided, they will not be stored in the memory diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 2239309c8fb..2bf57687382 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -7,19 +7,39 @@ void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) { if (dense_first) { - for (uint32_t il = 0; il < n_layer; ++il) { - swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0); + for (uint32_t il = 0; il < n_layer(); ++il) { + is_swa_impl[il] = n_pattern == 0 || (il % n_pattern != 0); } } else { - for (uint32_t il = 0; il < n_layer; ++il) { - swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + for (uint32_t il = 0; il < n_layer(); ++il) { + is_swa_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); } } + + for (uint32_t il = n_layer(); il < n_layer_all; ++il) { + is_swa_impl[il] = false; + } +} + +void llama_hparams::set_recr_pattern(uint32_t n_pattern, bool dense_first) { + if (dense_first) { + for (uint32_t il = 0; il < n_layer(); ++il) { + is_recr_impl[il] = n_pattern == 0 || (il % n_pattern != 0); + } + } else { + for (uint32_t il = 0; il < n_layer(); ++il) { + is_recr_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + } + } + + for (uint32_t il = n_layer(); il < n_layer_all; ++il) { + is_recr_impl[il] = false; + } } bool llama_hparams::is_swa_any() const { - for (uint32_t il = 0; il < n_layer; ++il) { - if (swa_layers[il]) { + for (uint32_t il = 0; il < n_layer_all; ++il) { + if (is_swa_impl[il]) { return true; } } @@ -28,7 +48,7 @@ bool llama_hparams::is_swa_any() const { } uint32_t llama_hparams::n_head(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return n_head_arr[il]; } @@ -36,7 +56,7 @@ uint32_t llama_hparams::n_head(uint32_t il) const { } uint32_t llama_hparams::n_head_kv(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return n_head_kv_arr[il]; } @@ -44,7 +64,7 @@ uint32_t llama_hparams::n_head_kv(uint32_t il) const { } uint32_t llama_hparams::n_ff(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return n_ff_arr[il]; } @@ -63,7 +83,7 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const { } uint32_t llama_hparams::n_rot(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return is_swa(il) ? n_rot_swa : n_rot_full; } @@ -71,6 +91,10 @@ uint32_t llama_hparams::n_rot(uint32_t il) const { } uint32_t llama_hparams::n_embd_inp() const { + if (n_embd_inp_impl > 0) { + return n_embd_inp_impl; + } + uint32_t n_embd_inp = n_embd; if (n_deepstack_layers > 0) { @@ -85,7 +109,7 @@ uint32_t llama_hparams::n_embd_out() const { } uint32_t llama_hparams::n_embd_head_k(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return is_swa(il) ? n_embd_head_k_swa : n_embd_head_k_full; } @@ -93,7 +117,7 @@ uint32_t llama_hparams::n_embd_head_k(uint32_t il) const { } uint32_t llama_hparams::n_embd_head_v(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return is_swa(il) ? n_embd_head_v_swa : n_embd_head_v_full; } @@ -114,7 +138,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { bool llama_hparams::is_n_embd_k_gqa_variable() const { const uint32_t val = n_embd_k_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { if (val != n_embd_k_gqa(il)) { return true; } @@ -125,7 +149,7 @@ bool llama_hparams::is_n_embd_k_gqa_variable() const { bool llama_hparams::is_n_embd_v_gqa_variable() const { const uint32_t val = n_embd_v_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { if (val != n_embd_v_gqa(il)) { return true; } @@ -136,7 +160,7 @@ bool llama_hparams::is_n_embd_v_gqa_variable() const { uint32_t llama_hparams::n_embd_k_gqa_max() const { uint32_t val = n_embd_k_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { val = std::max(val, n_embd_k_gqa(il)); } @@ -145,7 +169,7 @@ uint32_t llama_hparams::n_embd_k_gqa_max() const { uint32_t llama_hparams::n_embd_v_gqa_max() const { uint32_t val = n_embd_v_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { val = std::max(val, n_embd_v_gqa(il)); } @@ -193,12 +217,12 @@ uint32_t llama_hparams::n_embd_s() const { return ssm_d_state * ssm_d_inner; } -bool llama_hparams::is_recurrent(uint32_t il) const { - if (il < n_layer) { - return recurrent_layer_arr[il]; +bool llama_hparams::is_recr(uint32_t il) const { + if (il < n_layer_all) { + return is_recr_impl[il]; } - GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer); + GGML_ABORT("%s: il (%u) out of bounds (n_layer_all: %u)\n", __func__, il, n_layer_all); } uint32_t llama_hparams::n_pos_per_embd() const { @@ -206,11 +230,11 @@ uint32_t llama_hparams::n_pos_per_embd() const { } bool llama_hparams::is_swa(uint32_t il) const { - if (il < n_layer) { - return swa_layers[il]; + if (il < n_layer_all) { + return is_swa_impl[il]; } - GGML_ABORT("fatal error"); + GGML_ABORT("%s: il (%u) out of bounds (n_layer_all: %u)\n", __func__, il, n_layer_all); } bool llama_hparams::is_mla() const { @@ -229,12 +253,6 @@ uint32_t llama_hparams::n_embd_head_v_mla() const { } bool llama_hparams::has_kv(uint32_t il) const { - if (kv_only_nextn) { - // MTP head: only the trailing nextn_predict_layers blocks own a KV cache; - // the leading trunk blocks are not executed in this graph. - return nextn_predict_layers > 0 && il >= (n_layer - nextn_predict_layers); - } - if (n_layer_kv_from_start >= 0) { if (il < (uint32_t) n_layer_kv_from_start) { return true; @@ -247,16 +265,8 @@ bool llama_hparams::has_kv(uint32_t il) const { return true; } -uint32_t llama_hparams::n_layer_kv() const { - uint32_t res = 0; - - for (uint32_t il = 0; il < n_layer; ++il) { - if (has_kv(il)) { - res++; - } - } - - return res; +uint32_t llama_hparams::n_layer() const { + return n_layer_all - n_layer_nextn; } bool llama_hparams::use_mrope() const { diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index e2d051edc6c..032944cb481 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -23,6 +23,9 @@ enum llama_swa_type { LLAMA_SWA_TYPE_SYMMETRIC = 3, }; +// forward declaration; full definition in llama-graph.h +enum llm_ffn_op_type : int; + struct llama_hparams_posnet { uint32_t n_embd; uint32_t n_layer; @@ -34,6 +37,9 @@ struct llama_hparams_convnext { }; struct llama_hparams { + // note: use the `_impl` suffix to avoid name conflict between members and getters + // for example: n_embd_out() vs n_embd_out_impl + bool vocab_only; bool no_alloc; bool rope_finetuned; @@ -42,12 +48,15 @@ struct llama_hparams { uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; - uint32_t n_layer; - int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache + uint32_t n_layer_all; + uint32_t n_layer_nextn = 0; uint32_t n_expert = 0; uint32_t n_expert_used = 0; uint32_t n_rel_attn_bkts = 0; + // TODO: this needs to be reworked + int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache + // different head size for full_attention and SWA layers uint32_t n_embd_head_k_full; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v_full; // dimension of values (d_v) aka n_embd_head @@ -90,9 +99,6 @@ struct llama_hparams { uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; uint32_t moe_every_n_layers = 0; uint32_t moe_latent_size = 0; - uint32_t nextn_predict_layers = 0; - - bool kv_only_nextn = false; // if true, only the last nextn_predict_layers blocks have a KV cache (MTP head arches) float f_norm_eps; float f_norm_rms_eps; @@ -134,11 +140,15 @@ struct llama_hparams { llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; // the size of the sliding window (0 - no SWA) uint32_t n_swa = 0; - // if swa_layers[il] == 1, then layer il is SWA - // if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA) + + // if is_swa_impl[il] == 1, then layer il is SWA + // if is_swa_impl[il] == 0, then layer il is dense (i.e. non-SWA) // by default, all layers are dense // note: using uint32_t type for compatibility reason - std::array<uint32_t, LLAMA_MAX_LAYERS> swa_layers; + std::array<uint32_t, LLAMA_MAX_LAYERS> is_swa_impl; + + // for hybrid state space models + std::array<uint32_t, LLAMA_MAX_LAYERS> is_recr_impl; // for State Space Models uint32_t ssm_d_conv = 0; @@ -150,9 +160,6 @@ struct llama_hparams { // for Kimi Linear KDA uint32_t n_embd_head_kda = 0; - // for hybrid state space models - std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr; - bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; @@ -178,6 +185,9 @@ struct llama_hparams { // for Classifiers uint32_t n_cls_out = 1; + // input embedding dimension (0 = use n_embd) + uint32_t n_embd_inp_impl = 0; + // output embedding dimension (0 = use n_embd) uint32_t n_embd_out_impl = 0; @@ -212,8 +222,19 @@ struct llama_hparams { uint32_t indexer_top_k = 0; // qwen3vl deepstack + // When parsed from GGUF, this implies the first N layers consume the first + // N deepstack embeddings. Use deepstack_mapping_arr if you need a more + // complex mapping. If using deepstack_mapping_arr, also make sure to set + // n_deepstack_layers to the number of unique deepstack layers so that + // n_embd_imp is accurate (see granite.cpp). + // TODO: can be expressed via the `new n_embd_inp_impl` and remove this param uint32_t n_deepstack_layers = 0; + // deepstack layer array (Granite4 Vision) + // -1 => no deepstack + // >=0 => input embedding index for deepstack injection + std::array<int32_t, LLAMA_MAX_LAYERS> deepstack_mapping_arr; + // gemma4 per-layer embedding uint32_t n_embd_per_layer = 0; @@ -227,6 +248,14 @@ struct llama_hparams { enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + // Resolved FFN gated activation flavor for archs that read + // `<arch>.hidden_activation` from the GGUF (e.g. ModernBert derivatives). + // Defaults to LLM_FFN_NONE (sentinel = 0); the mapping from the GGUF + // string to a real op is done at hparam-load time via + // llm_ffn_op_type_from_string() in llama-model.cpp, mirroring how + // rope_scaling_type_train is handled. + enum llm_ffn_op_type llm_ffn_op; + // Step35: optional per-layer clamps for (Swi)GLU std::array<float, LLAMA_MAX_LAYERS> swiglu_clamp_exp; // clamping for expert FFN std::array<float, LLAMA_MAX_LAYERS> swiglu_clamp_shexp; // shared expert @@ -255,6 +284,13 @@ struct llama_hparams { // return true if one of the layers is SWA bool is_swa_any() const; + bool is_swa(uint32_t il) const; + + void set_recr_pattern(uint32_t n_pattern, bool dense_first = false); + + // whether or not the given layer is recurrent (for hybrid models) + bool is_recr(uint32_t il) const; + uint32_t n_head(uint32_t il = 0) const; uint32_t n_head_kv(uint32_t il = 0) const; @@ -296,13 +332,8 @@ struct llama_hparams { // dimension of the recurrent state embeddings uint32_t n_embd_s() const; - // whether or not the given layer is recurrent (for hybrid models) - bool is_recurrent(uint32_t il) const; - uint32_t n_pos_per_embd() const; - bool is_swa(uint32_t il) const; - // note: currently only support if either all or none of the layers are MLA bool is_mla() const; @@ -311,8 +342,8 @@ struct llama_hparams { bool has_kv(uint32_t il) const; - // number of layers for which has_kv() returns true - uint32_t n_layer_kv() const; + // number of effective layers (excludes nextn layers) + uint32_t n_layer() const; // note that this function uses different SWA parameters from those in the hparams // note: inlined on purpose for performance reasons diff --git a/examples/talk-llama/llama-impl.h b/examples/talk-llama/llama-impl.h index e4f35c8e53d..7923c3f7ed5 100644 --- a/examples/talk-llama/llama-impl.h +++ b/examples/talk-llama/llama-impl.h @@ -3,6 +3,7 @@ #include "ggml.h" // for ggml_log_level #include <string> +#include <type_traits> #include <vector> #ifdef __GNUC__ @@ -40,6 +41,19 @@ struct no_init { no_init() = default; }; +template <typename dst_t, typename src_t> +static inline dst_t llama_cast(src_t v) { + if constexpr (std::is_same_v<src_t, dst_t>) { + return v; + } else if constexpr (std::is_same_v<src_t, ggml_fp16_t> && std::is_same_v<dst_t, float>) { + return ggml_fp16_to_fp32(v); + } else if constexpr (std::is_same_v<src_t, float> && std::is_same_v<dst_t, ggml_fp16_t>) { + return ggml_fp32_to_fp16(v); + } else { + static_assert(std::is_same_v<dst_t, void>, "unsupported type combination"); + } +} + struct time_meas { time_meas(int64_t & t_acc, bool disable = false); ~time_meas(); diff --git a/examples/talk-llama/llama-kv-cache-dsa.cpp b/examples/talk-llama/llama-kv-cache-dsa.cpp new file mode 100644 index 00000000000..916ab653756 --- /dev/null +++ b/examples/talk-llama/llama-kv-cache-dsa.cpp @@ -0,0 +1,261 @@ +#include "llama-kv-cache-dsa.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-model.h" + +#include <algorithm> +#include <cassert> + +// +// llama_kv_cache_dsa +// + +llama_kv_cache_dsa::llama_kv_cache_dsa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse) : + hparams_lid(model.hparams), n_stream(unified ? 1 : n_seq_max) { + + LLAMA_LOG_INFO("%s: creating main KV cache, size = %u cells\n", __func__, kv_size); + + kv_mla = std::make_unique<llama_kv_cache>( + model, model.hparams, type_k, type_v, + v_trans, offload, unified, kv_size, n_seq_max, n_pad, + n_swa, swa_type, nullptr, filter, reuse, nullptr); + + // we use llama_kv_cache for caching indexer keys + // by hand-tweaking some hparams we fool it to create + // indexer key cache tensors with correct dimensions + // https://github.com/ggml-org/llama.cpp/pull/21149#discussion_r3015940823 + + // DSA lightning indexer uses MQA with single key head + std::fill(hparams_lid.n_head_kv_arr.begin(), hparams_lid.n_head_kv_arr.end(), 1); + hparams_lid.n_embd_head_k_full = model.hparams.indexer_head_size; + hparams_lid.rope_type = LLAMA_ROPE_TYPE_NEOX; + + LLAMA_LOG_INFO("%s: creating indexer KV cache, size = %u cells\n", __func__, kv_size); + + kv_lid = std::make_unique<llama_kv_cache>( + model, hparams_lid, type_k, type_v, + v_trans, offload, unified, kv_size, n_seq_max, n_pad, + n_swa, swa_type, nullptr, filter, reuse, nullptr); +} + +void llama_kv_cache_dsa::clear(bool data) { + kv_mla->clear(data); + kv_lid->clear(data); +} + +bool llama_kv_cache_dsa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + bool res = true; + + res = res & kv_mla->seq_rm(seq_id, p0, p1); + res = res & kv_lid->seq_rm(seq_id, p0, p1); + + return res; +} + +void llama_kv_cache_dsa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_mla->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_lid->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_dsa::seq_keep(llama_seq_id seq_id) { + kv_mla->seq_keep(seq_id); + kv_lid->seq_keep(seq_id); +} + +void llama_kv_cache_dsa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_mla->seq_add(seq_id, p0, p1, shift); + kv_lid->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_dsa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_mla->seq_div(seq_id, p0, p1, d); + kv_lid->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_dsa::seq_pos_min(llama_seq_id seq_id) const { + return kv_mla->seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_dsa::seq_pos_max(llama_seq_id seq_id) const { + return kv_mla->seq_pos_max(seq_id); +} + +std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_dsa::memory_breakdown() const { + std::map<ggml_backend_buffer_type_t, size_t> mb = kv_mla->memory_breakdown(); + for (const auto & buft_size : kv_lid->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) { + GGML_UNUSED(embd_all); + + do { + balloc.split_reset(); + + std::vector<llama_ubatch> ubatches; + while (true) { + auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); + + if (ubatch.n_tokens == 0) { + break; + } + + ubatches.push_back(std::move(ubatch)); // NOLINT + } + + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + + auto sinfos_mla = kv_mla->prepare(ubatches); + if (sinfos_mla.empty()) { + break; + } + + auto sinfos_lid = kv_lid->prepare(ubatches); + if (sinfos_lid.empty()) { + break; + } + + assert(sinfos_mla.size() == sinfos_lid.size()); + + return std::make_unique<llama_kv_cache_dsa_context>( + this, std::move(sinfos_mla), std::move(sinfos_lid), std::move(ubatches)); + } while (false); + + return std::make_unique<llama_kv_cache_dsa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_full() { + return std::make_unique<llama_kv_cache_dsa_context>(this); +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique<llama_kv_cache_dsa_context>(this, lctx, optimize); +} + +bool llama_kv_cache_dsa::get_can_shift() const { + return kv_mla->get_can_shift() && + kv_lid->get_can_shift() && + kv_mla->get_size() == kv_lid->get_size(); +} + +void llama_kv_cache_dsa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + kv_mla->state_write(io, seq_id, flags); + kv_lid->state_write(io, seq_id, flags); +} + +void llama_kv_cache_dsa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + kv_mla->state_read(io, seq_id, flags); + kv_lid->state_read(io, seq_id, flags); +} + +llama_kv_cache * llama_kv_cache_dsa::get_mla() const { + return kv_mla.get(); +} + +llama_kv_cache * llama_kv_cache_dsa::get_lid() const { + return kv_lid.get(); +} + +// +// llama_kv_cache_dsa_context +// + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(llama_memory_status status) : status(status) {} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv) : + ctx_mla(kv->get_mla()->init_full()), + ctx_lid(kv->get_lid()->init_full()), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + llama_context * lctx, + bool optimize) : + ctx_mla(kv->get_mla()->init_update(lctx, optimize)), + ctx_lid(kv->get_lid()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + slot_info_vec_t sinfos_mla, + slot_info_vec_t sinfos_lid, + std::vector<llama_ubatch> ubatches) : + ubatches(std::move(ubatches)), + // note: here we copy the ubatches. not sure if this is ideal + ctx_mla(new llama_kv_cache_context(kv->get_mla(), std::move(sinfos_mla), this->ubatches)), + ctx_lid(new llama_kv_cache_context(kv->get_lid(), std::move(sinfos_lid), this->ubatches)), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context:: ~llama_kv_cache_dsa_context() = default; + +bool llama_kv_cache_dsa_context::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + ctx_mla->next(); + ctx_lid->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_dsa_context::apply() { + assert(!llama_memory_status_is_fail(status)); + + bool res = true; + + res = res & ctx_mla->apply(); + res = res & ctx_lid->apply(); + + return res; +} + +llama_memory_status llama_kv_cache_dsa_context::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_dsa_context::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +const llama_kv_cache_context * llama_kv_cache_dsa_context::get_mla() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return static_cast<const llama_kv_cache_context *>(ctx_mla.get()); +} + +const llama_kv_cache_context * llama_kv_cache_dsa_context::get_lid() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return static_cast<const llama_kv_cache_context *>(ctx_lid.get()); +} diff --git a/examples/talk-llama/llama-kv-cache-dsa.h b/examples/talk-llama/llama-kv-cache-dsa.h new file mode 100644 index 00000000000..e2b330993b8 --- /dev/null +++ b/examples/talk-llama/llama-kv-cache-dsa.h @@ -0,0 +1,138 @@ +#pragma once + +#include "llama-kv-cache.h" + +#include <vector> + +// +// llama_kv_cache_dsa +// + +// utilizes two instances of llama_kv_cache: +// - the first instance is for caching key tensors of the model, +// - the second instance is for caching lightning indexer key tensors + +class llama_kv_cache_dsa : public llama_memory_i { +public: + llama_kv_cache_dsa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse); + + ~llama_kv_cache_dsa() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + + // + // llama_kv_cache_dsa specific API + // + + llama_kv_cache * get_mla() const; + llama_kv_cache * get_lid() const; + +private: + // we keep indexer KV cache hparams instance here as llama_kv_cache stores only reference to it + llama_hparams hparams_lid; + const uint32_t n_stream = 1; + + std::unique_ptr<llama_kv_cache> kv_mla; + std::unique_ptr<llama_kv_cache> kv_lid; +}; + +class llama_kv_cache_dsa_context : public llama_memory_context_i { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + // used for errors + llama_kv_cache_dsa_context(llama_memory_status status); + + // used to create a full-cache context + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv); + + // used to create an update context + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + llama_context * lctx, + bool optimize); + + // used to create a batch processing context from a batch + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_ik, + std::vector<llama_ubatch> ubatches); + + virtual ~llama_kv_cache_dsa_context(); + + // + // llama_memory_context_i + // + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_dsa_context specific API + // + + const llama_kv_cache_context * get_mla() const; + const llama_kv_cache_context * get_lid() const; + +private: + //llama_kv_cache_dsa * kv; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector<llama_ubatch> ubatches; + + const llama_memory_context_ptr ctx_mla; + const llama_memory_context_ptr ctx_lid; + + const llama_memory_status status; +}; diff --git a/examples/talk-llama/llama-kv-cache-iswa.cpp b/examples/talk-llama/llama-kv-cache-iswa.cpp index 26e2cb4270b..aa1b1b72ebe 100644 --- a/examples/talk-llama/llama-kv-cache-iswa.cpp +++ b/examples/talk-llama/llama-kv-cache-iswa.cpp @@ -23,8 +23,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( uint32_t n_seq_max, uint32_t n_ubatch, uint32_t n_pad, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) { + const layer_reuse_cb & reuse, + const layer_share_cb & share) : hparams(model.hparams), unified(unified) { // chain filters const layer_filter_cb filter_base = [&](int32_t il) { @@ -59,17 +61,27 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); + llama_memory_t mem_other_base = nullptr; + if (mem_other) { + mem_other_base = static_cast<llama_kv_cache_iswa *>(mem_other)->get_base(); + } + + llama_memory_t mem_other_swa = nullptr; + if (mem_other) { + mem_other_swa = static_cast<llama_kv_cache_iswa *>(mem_other)->get_swa(); + } + kv_base = std::make_unique<llama_kv_cache>( - model, type_k, type_v, + model, hparams, type_k, type_v, v_trans, offload, unified, size_base, n_seq_max, n_pad, - 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse); + 0, LLAMA_SWA_TYPE_NONE, mem_other_base, filter_base, reuse, share); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); kv_swa = std::make_unique<llama_kv_cache>( - model, type_k, type_v, + model, hparams, type_k, type_v, v_trans, offload, unified, size_swa, n_seq_max, n_pad, - hparams.n_swa, hparams.swa_type, filter_swa, reuse); + hparams.n_swa, hparams.swa_type, mem_other_swa, filter_swa, reuse, share); } void llama_kv_cache_iswa::clear(bool data) { diff --git a/examples/talk-llama/llama-kv-cache-iswa.h b/examples/talk-llama/llama-kv-cache-iswa.h index 70ab22f0d60..dfafc1ef510 100644 --- a/examples/talk-llama/llama-kv-cache-iswa.h +++ b/examples/talk-llama/llama-kv-cache-iswa.h @@ -25,8 +25,10 @@ class llama_kv_cache_iswa : public llama_memory_i { uint32_t n_seq_max, uint32_t n_ubatch, uint32_t n_pad, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse); + const layer_reuse_cb & reuse, + const layer_share_cb & share); ~llama_kv_cache_iswa() = default; diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index a49a055a630..2802103bdd8 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -79,6 +79,7 @@ static ggml_tensor * ggml_mul_mat_aux( llama_kv_cache::llama_kv_cache( const llama_model & model, + const llama_hparams & hparams, ggml_type type_k, ggml_type type_v, bool v_trans, @@ -89,14 +90,30 @@ llama_kv_cache::llama_kv_cache( uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse) : - model(model), hparams(model.hparams), v_trans(v_trans), - n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + const layer_reuse_cb & reuse, + const layer_share_cb & share) : + model(model), hparams(hparams), v_trans(v_trans), + n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type), + other(static_cast<llama_kv_cache *>(mem_other)), + v_cells_impl(other ? other->v_cells_impl : std::make_shared<llama_kv_cells_vec>()), + v_cells(*v_cells_impl) { + + // shared cells view the source cache's K/V tensors, so the cell count + // follows the source allocation: a fitted target can be smaller than the + // draft default and oversized views would overflow the source tensors + if (other) { + const uint32_t size_other = other->get_size(); + if (kv_size != size_other) { + LLAMA_LOG_WARN("%s: kv_size = %u overridden to %u to match the shared source cache\n", __func__, kv_size, size_other); + kv_size = size_other; + } + } GGML_ASSERT(kv_size % n_pad == 0); - const uint32_t n_layer_kv = hparams.n_layer_kv(); + const uint32_t n_layer = hparams.n_layer_all; // define a comparator for the buft -> ctx map to ensure that the order is well-defined: struct ggml_backend_buft_comparator { @@ -111,7 +128,7 @@ llama_kv_cache::llama_kv_cache( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -159,7 +176,7 @@ llama_kv_cache::llama_kv_cache( const bool is_mla = hparams.is_mla(); - for (uint32_t il = 0; il < hparams.n_layer; il++) { + for (uint32_t il = 0; il < n_layer; il++) { if (!hparams.has_kv(il)) { LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il); continue; @@ -170,6 +187,24 @@ llama_kv_cache::llama_kv_cache( continue; } + if (share && other) { + const int32_t il_share = share(il); + + if (il_share >= 0) { + const auto & layer_share = other->layers[other->map_layer_ids[il_share]]; + + LLAMA_LOG_WARN("%s: layer %3d: sharing with layer %d. k = %p, v = %p\n", __func__, il, il_share, + layer_share.k->data, layer_share.v->data); + + map_layer_ids[il] = layers.size(); + + layers.push_back(layer_share); + layers.back().il = il; + + continue; + } + } + if (n_embd_head_k_all == 0) { n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il); } else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) { @@ -229,7 +264,7 @@ llama_kv_cache::llama_kv_cache( if (reuse) { LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__); - for (uint32_t il = 0; il < hparams.n_layer; il++) { + for (uint32_t il = 0; il < n_layer; il++) { const int32_t il_reuse = reuse(il); if (il_reuse < 0) { @@ -253,7 +288,7 @@ llama_kv_cache::llama_kv_cache( // allocate tensors and initialize the buffers to avoid NaNs in the padding for (auto & [buft, ctx] : ctx_map) { ggml_backend_buffer_t buf; - if (model.hparams.no_alloc) { + if (hparams.no_alloc) { buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) { t->buffer = buf; // set dummy buffer for KV cache so that the backend scheduler won't try to allocate it @@ -281,23 +316,37 @@ llama_kv_cache::llama_kv_cache( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } - const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE"); - const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false; - if (attn_rot_disable) { - LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); - } + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + n_embd_head_k_all = other->n_embd_head_k_all; + n_embd_head_v_all = other->n_embd_head_v_all; + + attn_rot_k = other->attn_rot_k; + attn_rot_v = other->attn_rot_v; + } else { + const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE"); + const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false; + if (attn_rot_disable) { + LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); + } + + attn_rot_k = + !attn_rot_disable && + n_embd_head_k_all > 0 && + ggml_is_quantized(type_k) && + hparams.n_embd_head_k() % 64 == 0; - attn_rot_k = - !attn_rot_disable && - n_embd_head_k_all > 0 && - ggml_is_quantized(type_k) && - hparams.n_embd_head_k() % 64 == 0; + // always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer + if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) { + attn_rot_k = true; + } - attn_rot_v = - !attn_rot_disable && - n_embd_head_v_all > 0 && - ggml_is_quantized(type_v) && - hparams.n_embd_head_v() % 64 == 0; + attn_rot_v = + !attn_rot_disable && + n_embd_head_v_all > 0 && + ggml_is_quantized(type_v) && + hparams.n_embd_head_v() % 64 == 0; + } LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all); LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all); @@ -341,6 +390,11 @@ void llama_kv_cache::clear(bool data) { } bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return true; + } + GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); if (p0 < 0) { @@ -404,6 +458,11 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { } void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size()); GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size()); @@ -491,6 +550,11 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll } void llama_kv_cache::seq_keep(llama_seq_id seq_id) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -513,6 +577,11 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) { } void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1"); @@ -558,6 +627,11 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll } void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1"); @@ -592,6 +666,11 @@ void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, in } llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return other->seq_pos_min(seq_id); + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); const auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -600,6 +679,11 @@ llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const { } llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return other->seq_pos_max(seq_id); + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); const auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -740,6 +824,11 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ } bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return true; + } + bool updated = false; auto * sched = lctx->get_sched(); @@ -1015,6 +1104,11 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, } void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; @@ -1430,8 +1524,8 @@ struct args_set_input_kq_mask { int64_t n_tps; }; -template<bool causal, bool swa, bool is_2d, bool alibi> -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template<typename T, bool causal, bool swa, bool is_2d, bool alibi> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { //const auto & hparams = args.hparams; const auto & ubatch = args.ubatch; @@ -1445,6 +1539,9 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * const int64_t n_stream = args.n_stream; const int64_t n_tps = args.n_tps; + const T mask_keep = llama_cast<T>(0.0f); + const T mask_drop = llama_cast<T>(-INFINITY); + // the min position in the batch for each sequence llama_pos seq_pos_min[LLAMA_MAX_SEQ]; std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX); @@ -1563,46 +1660,55 @@ static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * } if (alibi) { - data[idst + j] = -std::abs(p0 - p1); + data[idst + j] = llama_cast<T>(static_cast<float>(-std::abs(p0 - p1))); } else { - data[idst + j] = 0.0f; + data[idst + j] = mask_keep; } continue; skip: - data[idst + j] = -INFINITY; + data[idst + j] = mask_drop; } } } } -template<bool causal, bool swa, bool is_2d> -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template<typename T, bool causal, bool swa, bool is_2d> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { const bool alibi = args.hparams.use_alibi; if (alibi) { - set_input_kq_mask_impl<causal, swa, is_2d, true> (args, data); + set_input_kq_mask_impl<T, causal, swa, is_2d, true> (args, data); } else { - set_input_kq_mask_impl<causal, swa, is_2d, false>(args, data); + set_input_kq_mask_impl<T, causal, swa, is_2d, false>(args, data); } } -template<bool causal, bool swa> -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template<typename T, bool causal, bool swa> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { const bool is_2d = args.ubatch->is_pos_2d(); if (is_2d) { - set_input_kq_mask_impl<causal, swa, true> (args, data); + set_input_kq_mask_impl<T, causal, swa, true> (args, data); } else { - set_input_kq_mask_impl<causal, swa, false>(args, data); + set_input_kq_mask_impl<T, causal, swa, false>(args, data); } } -template<bool causal> -static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, float * data) { +template<typename T, bool causal> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE; if (swa) { - set_input_kq_mask_impl<causal, true> (args, data); + set_input_kq_mask_impl<T, causal, true> (args, data); + } else { + set_input_kq_mask_impl<T, causal, false>(args, data); + } +} + +template<typename T> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data, bool causal_attn) { + if (causal_attn) { + set_input_kq_mask_impl<T, true> (args, data); } else { - set_input_kq_mask_impl<causal, false>(args, data); + set_input_kq_mask_impl<T, false>(args, data); } } @@ -1610,7 +1716,6 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - float * data = (float *) dst->data; const int64_t n_kv = dst->ne[0]; const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch @@ -1634,10 +1739,10 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u /*.n_tps =*/ n_tps, }; - if (causal_attn) { - set_input_kq_mask_impl<true> (args, data); + if (dst->type == GGML_TYPE_F16) { + set_input_kq_mask_impl<ggml_fp16_t>(args, (ggml_fp16_t *) dst->data, causal_attn); } else { - set_input_kq_mask_impl<false>(args, data); + set_input_kq_mask_impl<float>(args, (float *) dst->data, causal_attn); } //const int64_t t_end = ggml_time_us(); @@ -1798,6 +1903,9 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { } ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + GGML_ASSERT(!other); + auto * ctx = res->get_ctx(); auto * gf = res->get_gf(); @@ -1843,6 +1951,11 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co } void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_UNUSED(flags); io.write(&n_stream, sizeof(n_stream)); @@ -1859,7 +1972,19 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla uint32_t cell_range_begin = cells.size(); for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { + bool add_cell = true; + + add_cell = add_cell && !cells.is_empty(i); + add_cell = add_cell && (seq_id == -1 || cells.seq_has(i, seq_id)); + + // check the cell is not SWA-masked + if (add_cell && seq_id != -1) { + const bool is_masked = llama_hparams::is_masked_swa(n_swa, swa_type, cells.pos_get(i), cells.seq_pos_max(seq_id)); + + add_cell = !is_masked; + } + + if (add_cell) { ++cell_count; if (cell_range_begin == cells.size()) { cell_range_begin = i; @@ -1896,6 +2021,11 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla } void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_UNUSED(flags); GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); @@ -2112,7 +2242,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 sinfo = find_slot(ubatch, false); if (sinfo.empty()) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + LLAMA_LOG_ERROR("%s: failed to find %d available cells in kv cache\n", __func__, cell_count); return false; } diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index 0b62dc7b232..3d68f98c142 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -93,8 +93,12 @@ class llama_kv_cache : public llama_memory_i { using slot_info_vec_t = std::vector<slot_info>; + // TODO: refactor the memory instances to not depend on `llama_model` + // instead pass all necessary info (e.g. hparams, dev layers, arch, etc.) directly + // likely through `struct llama_memory_params` llama_kv_cache( const llama_model & model, + const llama_hparams & hparams, ggml_type type_k, ggml_type type_v, bool v_trans, @@ -105,8 +109,10 @@ class llama_kv_cache : public llama_memory_i { uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse); + const layer_reuse_cb & reuse, + const layer_share_cb & share); ~llama_kv_cache() = default; @@ -260,7 +266,12 @@ class llama_kv_cache : public llama_memory_i { // note: this is not part of the KV state and it's only used to speed-up the find_slot() method std::vector<uint32_t> v_heads; - std::vector<llama_kv_cells> v_cells; + // TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS] + llama_kv_cache * other; + + std::shared_ptr<llama_kv_cells_vec> v_cells_impl; + + llama_kv_cells_vec & v_cells; // maps from a sequence id to a stream id std::vector<uint32_t> seq_to_stream; diff --git a/examples/talk-llama/llama-kv-cells.h b/examples/talk-llama/llama-kv-cells.h index 10063bf4272..fddd31a0b21 100644 --- a/examples/talk-llama/llama-kv-cells.h +++ b/examples/talk-llama/llama-kv-cells.h @@ -531,3 +531,5 @@ class llama_kv_cells { } } }; + +using llama_kv_cells_vec = std::vector<llama_kv_cells>; diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.cpp b/examples/talk-llama/llama-memory-hybrid-iswa.cpp index 72f5c2fea72..c7d4bcd413e 100644 --- a/examples/talk-llama/llama-memory-hybrid-iswa.cpp +++ b/examples/talk-llama/llama-memory-hybrid-iswa.cpp @@ -43,9 +43,11 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( n_seq_max, n_ubatch, n_pad, + nullptr, filter_attn == nullptr ? - [&](int32_t il) { return !hparams.is_recurrent(il); } + [&](int32_t il) { return !hparams.is_recr(il); } : filter_attn, + nullptr, nullptr )), mem_recr(new llama_memory_recurrent( @@ -57,7 +59,7 @@ llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( n_seq_max, n_rs_seq, filter_recr == nullptr ? - [&](int32_t il) { return hparams.is_recurrent(il); } + [&](int32_t il) { return hparams.is_recr(il); } : filter_recr )) {} diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index 33b3b395e0c..f2d49cbce54 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -33,6 +33,7 @@ llama_memory_hybrid::llama_memory_hybrid( hparams(model.hparams), mem_attn(new llama_kv_cache( model, + model.hparams, type_k, type_v, v_trans, @@ -43,9 +44,11 @@ llama_memory_hybrid::llama_memory_hybrid( n_pad, n_swa, swa_type, + nullptr, filter_attn == nullptr ? - [&](int32_t il) { return !hparams.is_recurrent(il); } + [&](int32_t il) { return !hparams.is_recr(il); } : filter_attn, + nullptr, nullptr )), mem_recr(new llama_memory_recurrent( @@ -57,7 +60,7 @@ llama_memory_hybrid::llama_memory_hybrid( n_seq_max, n_rs_seq, filter_recr == nullptr ? - [&](int32_t il) { return hparams.is_recurrent(il); } + [&](int32_t il) { return hparams.is_recr(il); } : filter_recr )) {} diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index ec5dc5835dd..6a4892fb471 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -26,7 +26,7 @@ llama_memory_recurrent::llama_memory_recurrent( uint32_t n_seq_max, uint32_t n_rs_seq, const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { - const int32_t n_layer = hparams.n_layer; + const int32_t n_layer = hparams.n_layer(); head = 0; size = mem_size; @@ -863,7 +863,7 @@ void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std:: void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const { const uint32_t s_trans = 0; - const uint32_t n_layer = hparams.n_layer; + const uint32_t n_layer = hparams.n_layer(); io.write(&s_trans, sizeof(s_trans)); io.write(&n_layer, sizeof(n_layer)); @@ -1047,8 +1047,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell io.read(&s_trans, sizeof(s_trans)); io.read(&n_layer, sizeof(n_layer)); - if (n_layer != hparams.n_layer) { - LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + if (n_layer != hparams.n_layer()) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer()); return false; } if (cell_count > size) { diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index 4ad1612e45b..db825396645 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -23,6 +23,8 @@ struct llama_memory_params { bool swa_full; llama_context_type ctx_type; + + llama_memory_t mem_other; }; enum llama_memory_status { @@ -76,6 +78,8 @@ struct llama_memory_i { // return negative value to indicate that the layer il should not reuse memory using layer_reuse_cb = std::function<int32_t(int32_t il)>; + using layer_share_cb = std::function<int32_t(int32_t il)>; + virtual ~llama_memory_i() = default; // split the input batch into a set of ubatches and verify that they can fit into the cache diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index c645d0785ab..0d1cf3cc33b 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -146,7 +146,7 @@ namespace GGUFMeta { const enum gguf_type arr_type = gguf_get_arr_type(ctx, k); return ArrayInfo { arr_type, - size_t(gguf_get_arr_n(ctx, k)), + gguf_get_arr_n(ctx, k), arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx, k), }; } @@ -393,6 +393,7 @@ namespace GGUFMeta { } template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required); + template bool llama_model_loader::get_arr<std::array<int32_t, 512>>(enum llm_kv kid, std::array<int32_t, 512> & result, bool required); template<typename T> bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { @@ -445,7 +446,7 @@ namespace GGUFMeta { } if (n > N_MAX) { - throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str())); + throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", n, (uint32_t) N_MAX, key.c_str())); } if (gguf_get_kv_type(metadata, kid) == GGUF_TYPE_ARRAY) { @@ -502,9 +503,9 @@ namespace GGUFMeta { } // TODO: this is not very clever - figure out something better - template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr<std::array<int, 4>> (enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required); - template bool llama_model_loader::get_key_or_arr<std::array<float, 512>>(enum llm_kv kid, std::array<float, 512> & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr<std::array<float, 512>>(enum llm_kv kid, std::array<float, 512> & result, uint32_t n, bool required); llama_model_loader::llama_model_loader( @@ -1050,10 +1051,10 @@ struct ggml_tensor * llama_model_loader::create_tensor( if (it == ctx_map.end()) { // one ggml context per buffer type int max_n_tensors = n_tensors; - max_n_tensors += 1; // duplicated output tensor - max_n_tensors += hparams.n_layer*2; // duplicated rope freq tensors + max_n_tensors += 1; // duplicated output tensor + max_n_tensors += hparams.n_layer()*2; // duplicated rope freq tensors if (files.empty()) { - max_n_tensors += hparams.n_layer*256; // this should be well above what any model actually uses + max_n_tensors += hparams.n_layer()*256; // this should be well above what any model actually uses } const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index 528e4c9c069..67d4a9df0f0 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -14,9 +14,6 @@ bool llama_model_saver_supports_arch(llm_arch arch) { switch (arch) { - case LLM_ARCH_QWEN3NEXT: - case LLM_ARCH_QWEN35: - case LLM_ARCH_QWEN35MOE: case LLM_ARCH_PLAMO3: case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: @@ -29,6 +26,7 @@ bool llama_model_saver_supports_arch(llm_arch arch) { case LLM_ARCH_APERTUS: case LLM_ARCH_MIMO2: case LLM_ARCH_STEP35: + case LLM_ARCH_MELLUM: return false; default: return true; @@ -79,7 +77,7 @@ void llama_model_saver::add_kv(const enum llm_kv key, const char value) { template <typename Container> void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) { GGML_ASSERT(model != nullptr || !per_layer); - const size_t n_values = per_layer ? size_t(model->hparams.n_layer) : value.size(); + const size_t n_values = per_layer ? size_t(model->hparams.n_layer()) : value.size(); GGML_ASSERT(n_values <= value.size()); if (n_values == 0) { @@ -106,6 +104,8 @@ void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, c gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values); } else if (std::is_same<typename Container::value_type, uint32_t>::value) { gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values); + } else if (std::is_same<typename Container::value_type, bool>::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_BOOL, value.data(), n_values); } else if (std::is_same<typename Container::value_type, int32_t>::value) { gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values); } else if (std::is_same<typename Container::value_type, float>::value) { @@ -206,7 +206,7 @@ void llama_model_saver::add_kv_from_model() { if (hparams.n_embd_out_impl > 0) { add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl); } - add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer); + add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer_all); add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true); add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); @@ -227,8 +227,9 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); add_kv(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); add_kv(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers); - add_kv(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers); + add_kv(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn); add_kv(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers); + add_kv(LLM_KV_DEEPSTACK_MAPPING, hparams.deepstack_mapping_arr); add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type)); add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id); @@ -244,7 +245,7 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); add_kv(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count); add_kv(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - // add_kv(LLM_KV_FULL_ATTENTION_INTERVAL, ???); + // add_kv(LLM_KV_FULL_ATTENTION_INTERVAL, ???); // saved as LLM_KV_ATTENTION_RECURRENT_LAYERS instead add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true); add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true); @@ -278,6 +279,7 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); add_kv(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); add_kv(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + add_kv(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, true); const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 0c3e03a61dc..4f12e0949ac 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -10,6 +10,7 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" +#include "llama-kv-cache-dsa.h" #include "llama-memory-hybrid.h" #include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" @@ -80,6 +81,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_mpt(params); case LLM_ARCH_STABLELM: return new llama_model_stablelm(params); + case LLM_ARCH_MELLUM: + return new llama_model_mellum(params); case LLM_ARCH_QWEN: return new llama_model_qwen(params); case LLM_ARCH_QWEN2: @@ -136,6 +139,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_gemma3n(params); case LLM_ARCH_GEMMA4: return new llama_model_gemma4(params); + case LLM_ARCH_GEMMA4_ASSISTANT: + return new llama_model_gemma4_assistant(params); case LLM_ARCH_GEMMA_EMBEDDING: return new llama_model_gemma_embedding(params); case LLM_ARCH_STARCODER2: @@ -172,6 +177,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_deepseek2(params); case LLM_ARCH_DEEPSEEK2OCR: return new llama_model_deepseek2ocr(params); + case LLM_ARCH_DEEPSEEK32: + return new llama_model_deepseek32(params); case LLM_ARCH_GLM_DSA: return new llama_model_glm_dsa(params); case LLM_ARCH_MISTRAL4: @@ -368,10 +375,10 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str // count only the same type of previous layers to avoid this auto get_il_eff = [&](const size_t il){ size_t ret = 0; - const bool il_is_recurrent = hparams.is_recurrent(il); - const bool il_is_swa = hparams.is_swa(il); + const bool il_is_recr = hparams.is_recr(il); + const bool il_is_swa = hparams.is_swa(il); for (size_t il_prev = 0; il_prev < il; il_prev++) { - ret += hparams.is_recurrent(il_prev) == il_is_recurrent && hparams.is_swa(il_prev) == il_is_swa; + ret += hparams.is_recr(il_prev) == il_is_recr && hparams.is_swa(il_prev) == il_is_swa; } return ret; }; @@ -393,7 +400,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str rotation = get_il_eff(il) % ud->n_devices; } else { il = 0; - rotation = hparams.n_layer % ud->n_devices; + rotation = hparams.n_layer() % ud->n_devices; } const ggml_tensor * tensor_axis_0 = suffix.empty() ? tensor : ud->model->get_tensor((prefix + suffix).c_str()); if (tensor_axis_0 == nullptr) { @@ -407,16 +414,16 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str auto get_tensor_config = [&]() -> tensor_config { // standard attention if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_kv_weight)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight"); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight", "ssm_out.weight"); } if (std::regex_match(tensor_name, pattern_q_bias) || std::regex_match(tensor_name, pattern_kv_bias)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight"); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight", "ssm_out.weight"); } if (std::regex_match(tensor_name, pattern_qkv_weight)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight", "ssm_out.weight"); } if ( std::regex_match(tensor_name, pattern_qkv_bias)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight", "ssm_out.weight"); } if (std::regex_match(tensor_name, pattern_qk_norm)) { return get_tensor_config_impl(tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight"); @@ -432,7 +439,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str } if (std::regex_match(tensor_name, pattern_attn_gate_weight)) { - return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight", "ssm_out.weight"); } if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a)) { return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ssm_out.weight"); @@ -485,7 +492,7 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); }; - auto get_split_segments = [&](int axis, uint32_t il) -> std::vector<int64_t> { + auto get_split_segments = [&](int axis, uint32_t il) -> std::vector<std::pair<int64_t, uint32_t>> { if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { const int64_t head_k_dim = hparams.ssm_d_state; const int64_t head_v_dim = hparams.ssm_d_state; @@ -500,26 +507,26 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str if (ud->model->arch == LLM_ARCH_QWEN3NEXT) { if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); - return {key_dim, key_dim, value_dim}; + return {{key_dim, 2}, {value_dim, 1}}; } } else { const int64_t head_ratio = n_v_heads / n_k_heads; if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); - return std::vector<int64_t>(2 + head_ratio, key_dim); + return {{key_dim, 2 + head_ratio}}; } if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { - return std::vector<int64_t>(head_ratio, key_dim); + return {{key_dim, head_ratio}}; } if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) || std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) { - return std::vector<int64_t>(head_ratio, n_k_heads); + return {{n_k_heads, head_ratio}}; } if (std::regex_match(tensor_name, pattern_r_cache)) { - return std::vector<int64_t>(2 + head_ratio, key_dim * (hparams.ssm_d_conv - 1)); + return {{key_dim * (hparams.ssm_d_conv - 1), 2 + head_ratio}}; } if (std::regex_match(tensor_name, pattern_s_cache)) { - return std::vector<int64_t>(head_ratio, n_k_heads * head_v_dim * head_v_dim); + return {{n_k_heads * head_v_dim * head_v_dim, head_ratio}}; } } @@ -527,9 +534,9 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { const int64_t n_ff_exp = hparams.n_ff_exp; GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); - return {n_ff_exp, n_ff_exp}; + return {{n_ff_exp, 2}}; } - return {tensor->ne[axis]}; + return {{tensor->ne[axis], 1}}; } if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { @@ -537,21 +544,23 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str const int64_t n_embd_gqa = hparams.n_embd_v_gqa(il); GGML_ASSERT(hparams.n_embd_k_gqa() == n_embd_gqa); GGML_ASSERT(tensor->ne[axis] == n_embd + 2*n_embd_gqa); - return {n_embd, n_embd_gqa, n_embd_gqa}; + return {{n_embd, 1}, {n_embd_gqa, 2}}; } if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { const int64_t n_ff_exp = hparams.n_ff_exp; GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); - return {n_ff_exp, n_ff_exp}; + return {{n_ff_exp, 2}}; } - return {tensor->ne[axis]}; + return {{tensor->ne[axis], 1}}; }; - auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector<int64_t> & segments) -> std::vector<int64_t> { - if (hparams.is_recurrent(il)) { + auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector<std::pair<int64_t, uint32_t>> & segments) -> std::vector<int64_t> { + // for better performance it may make sense to round up blck_size to a higher power of 2 so that more efficient kernels can be used + if (hparams.is_recr(il)) { // linear attention - const int64_t head_dim = hparams.ssm_d_state; - const int64_t granularity_qkv = std::lcm(blck_size, head_dim); + const int64_t head_dim = hparams.ssm_d_state; + const int64_t blck_size_perf = std::lcm(blck_size, 128); + const int64_t granularity_qkv = std::lcm(blck_size_perf, head_dim); if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { return std::vector<int64_t>(segments.size(), granularity_qkv); @@ -573,17 +582,24 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str // regular attention const uint32_t n_gqa = hparams.n_gqa(il); const uint32_t n_embd_q = n_gqa * hparams.n_embd_head_k(il); + + // to handle head sizes like 80, only increase granularity while it doesn't cause underutilization + int64_t blck_size_perf = blck_size; + while (blck_size_perf < 128 && blck_size_perf*ud->n_devices < n_embd_q) { + blck_size_perf *= 2; + } + if (std::regex_match(tensor_name, pattern_attn_sinks)) { GGML_ASSERT(segments.size() == 1); - return {std::lcm(n_embd_q, blck_size)/n_embd_q * n_gqa}; + return {std::lcm(n_embd_q, blck_size_perf)/n_embd_q * n_gqa}; } - const int64_t granularity_q = std::lcm(n_embd_q, blck_size); + const int64_t granularity_q = std::lcm(n_embd_q, blck_size_perf); if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_q_bias)) { GGML_ASSERT(segments.size() == 1); // some models have Q gate tensors, for those cases the granularity needs to be doubled: if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { - return {std::lcm(2*n_embd_q, blck_size)}; + return {std::lcm(2*n_embd_q, blck_size_perf)}; } return {granularity_q}; } @@ -600,16 +616,17 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str return {granularity_kv}; } if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { - GGML_ASSERT(segments.size() == 3); - return {granularity_q, granularity_kv, granularity_kv}; + GGML_ASSERT(segments.size() == 2); + return {granularity_q, granularity_kv}; } } // FFN if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) || std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) { - GGML_ASSERT(segments.size() <= 2); - return std::vector<int64_t>(segments.size(), blck_size); + const int64_t blck_size_perf = std::lcm(blck_size, 128); + GGML_ASSERT(segments.size() == 1); + return {blck_size_perf}; } // everything else @@ -622,7 +639,6 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str tensor_config tc = get_tensor_config(); split_state.axis = tc.axis; if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { - const int64_t ne_full = tensor->ne[split_state.axis]; const int64_t blck_size = ggml_blck_size(tc.tensor_axis_0->type); const float * tensor_split = ud->model->tensor_split(); std::vector<float> tensor_split_scan; @@ -633,12 +649,12 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str tensor_split_scan[j] += tensor_split_scan[j - 1]; } } - const std::vector<int64_t> segments = get_split_segments(split_state.axis, tc.il); + const std::vector<std::pair<int64_t, uint32_t>> segments = get_split_segments(split_state.axis, tc.il); const std::vector<int64_t> granularity = get_split_granularity(blck_size, tc.il, segments); for (size_t is = 0; is < segments.size(); is++) { - const int64_t ne_s = segments[is]; - const int64_t g_s = granularity[is]; - GGML_ASSERT(ne_full % g_s == 0); + const int64_t ne_s = segments[is].first; + const uint32_t nr_s = segments[is].second; + const int64_t g_s = granularity[is]; int64_t low = 0; size_t j = 0; for (; j < ud->n_devices - 1; j++) { @@ -651,10 +667,12 @@ struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const str low = high; } split_state.ne[is*ud->n_devices + (j + tc.rotation) % ud->n_devices] = ne_s - low; + split_state.nr[is] = nr_s; } split_state.n_segments = segments.size(); } else { memset(split_state.ne, 0, sizeof(split_state.ne)); + split_state.nr[0] = 1; split_state.n_segments = 1; } return split_state; @@ -758,6 +776,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_A13B: return "A13B"; case LLM_TYPE_7B_A1B: return "7B.A1B"; case LLM_TYPE_8B_A1B: return "8B.A1B"; + case LLM_TYPE_12B_A2_5B: return "12B.A2.5B"; case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_24B_A2B: return "24B.A2B"; @@ -779,6 +798,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; case LLM_TYPE_397B_A17B: return "397B.A17B"; + case LLM_TYPE_685B_A37B: return "685B.A37B"; case LLM_TYPE_744B_A40B: return "744B.A40B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; @@ -815,6 +835,28 @@ static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::st return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; } +// Maps the GGUF `<arch>.hidden_activation` string to the FFN op type used by the +// graph builders. Only gated activations that map cleanly to llm_ffn_op_type are +// listed; unrecognized values fall back to GeGLU, which matches the historical +// default for ModernBert-style architectures. +static const std::map<std::string, llm_ffn_op_type> LLM_FFN_OP_TYPES_FROM_STRING = { + { "gelu", LLM_FFN_GEGLU }, + { "geglu", LLM_FFN_GEGLU }, + { "silu", LLM_FFN_SWIGLU }, + { "swish", LLM_FFN_SWIGLU }, + { "swiglu", LLM_FFN_SWIGLU }, + { "relu", LLM_FFN_RELU }, + { "reglu", LLM_FFN_REGLU }, +}; + +llm_ffn_op_type llm_ffn_op_type_from_string(const std::string & name, llm_ffn_op_type fallback) { + const auto it = LLM_FFN_OP_TYPES_FROM_STRING.find(name); + if (it != LLM_FFN_OP_TYPES_FROM_STRING.end()) { + return it->second; + } + return fallback; +} + // CPU: ACCEL -> GPU host -> CPU extra -> CPU static buft_list_t make_cpu_buft_list(const std::vector<llama_device> & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; @@ -1002,7 +1044,7 @@ void llama_model_base::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl, false); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); - ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer_all); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false); @@ -1044,28 +1086,29 @@ void llama_model_base::load_hparams(llama_model_loader & ml) { std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); - std::fill( - hparams.recurrent_layer_arr.begin(), - hparams.recurrent_layer_arr.end(), - llm_arch_is_recurrent(ml.get_arch())); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); - std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + std::fill(hparams.is_swa_impl.begin(), hparams.is_swa_impl.end(), 0); + std::fill(hparams.is_recr_impl.begin(), hparams.is_recr_impl.end(), llm_arch_is_recurrent(ml.get_arch()) ? 1 : 0); std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0.0f); std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); - std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); - std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); + std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.swiglu_clamp_exp.begin(), hparams.swiglu_clamp_exp.end(), 0.0f); std::fill(hparams.swiglu_clamp_shexp.begin(), hparams.swiglu_clamp_shexp.end(), 0.0f); - ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer(), false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer(), false); + + // Populate deepstack_mapping_arr - initialized to -1 (no deepstack) + std::fill(hparams.deepstack_mapping_arr.begin(), hparams.deepstack_mapping_arr.end(), -1); // n_head_kv is optional, default to n_head hparams.n_head_kv_arr = hparams.n_head_arr; - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer(), false); bool rope_finetuned = false; ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); @@ -1164,7 +1207,7 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { const auto & use_mlock = params.use_mlock; const auto & tensor_split = params.tensor_split; - const int n_layer = hparams.n_layer; + const int n_layer_all = hparams.n_layer_all; const int n_gpu_layers = this->n_gpu_layers(); const bool use_mmap_buffer = true; @@ -1221,10 +1264,10 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { splits[i] /= split_sum; } - const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0); - const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1); + const int i_gpu_start = std::max(n_layer_all + 1 - n_gpu_layers, 0); + const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, n_layer_all + 1); auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { - const bool is_swa = il < int(hparams.n_layer) && hparams.is_swa(il); + const bool is_swa = il < n_layer_all && hparams.is_swa(il); if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa); return {cpu_dev, &pimpl->cpu_buft_list}; @@ -1240,13 +1283,13 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; // assign the repeating layers to the devices according to the splits - pimpl->dev_layer.resize(n_layer); - for (int il = 0; il < n_layer; ++il) { + pimpl->dev_layer.resize(n_layer_all); + for (int il = 0; il < n_layer_all; ++il) { pimpl->dev_layer[il] = get_layer_buft_list(il); } // assign the output layer - pimpl->dev_output = get_layer_buft_list(n_layer); + pimpl->dev_output = get_layer_buft_list(n_layer_all); const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; @@ -1262,14 +1305,14 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { throw std::runtime_error("model has expert layers but no expert layers are used"); } - layers.resize(n_layer); + layers.resize(n_layer_all); // call the per-model loading function load_arch_tensors(ml); // generic pass: load optional per-tensor/per-expert ".scale" tensors (e.g. NVFP4 scale2) // this avoids having to add scale loading to every architecture - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { auto & layer = layers[i]; // attention weight scales (per-tensor, shape {1}) @@ -1527,7 +1570,7 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { } if (llama_supports_gpu_offload()) { - const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); + const int n_gpu = std::min(n_gpu_layers, n_layer_all); int n_repeating = n_gpu; if (n_repeating > 0) { @@ -1536,8 +1579,8 @@ bool llama_model_base::load_tensors(llama_model_loader & ml) { } LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_repeating); - const int max_backend_supported_layers = hparams.n_layer + 1; - const int max_offloadable_layers = hparams.n_layer + 1; + const int max_backend_supported_layers = n_layer_all + 1; + const int max_offloadable_layers = n_layer_all + 1; LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); } @@ -1606,7 +1649,8 @@ const float * llama_model::tensor_split() const { } uint32_t llama_model::n_gpu_layers() const { - return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1; + // note: plus 1 for the "output" layer + return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer_all + 1; } llama_split_mode llama_model::split_mode() const { @@ -1639,10 +1683,10 @@ uint64_t llama_model::n_elements() const { void llama_model::print_info() const { const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train); - auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) { + auto print_f = [](const std::function<int32_t(uint32_t)> & f, uint32_t n) { bool is_var = false; - std::vector<uint32_t> v; + std::vector<int32_t> v; for (uint32_t i = 0; i < n; ++i) { v.push_back(f(i)); if (v[i] != v[0]) { @@ -1675,19 +1719,21 @@ void llama_model::print_info() const { if (!hparams.vocab_only) { LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_embd_out = %u\n", __func__, hparams.n_embd_out()); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer()); + LLAMA_LOG_INFO("%s: n_layer_all = %u\n", __func__, hparams.n_layer_all); + LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer_all).c_str()); LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full); LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full); LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full); - LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer_all).c_str()); LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); @@ -1695,7 +1741,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); LLAMA_LOG_INFO("%s: f_attn_value_scale = %.4f\n", __func__, hparams.f_attn_value_scale); - LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); + LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer_all).c_str()); LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); @@ -1716,6 +1762,14 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + if (arch == LLM_ARCH_GRANITE && + std::any_of(hparams.deepstack_mapping_arr.begin(), + hparams.deepstack_mapping_arr.end(), + [](const auto & entry) { return entry >= 0; })) { + LLAMA_LOG_INFO("%s: deepstack_mapping_arr = %s\n", __func__, + print_f([&](uint32_t il) { return hparams.deepstack_mapping_arr[il]; }, + hparams.n_layer_all).c_str()); + } // MRoPE (Multi-axis Rotary Position Embedding) sections if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); @@ -1769,7 +1823,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); } - if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_DEEPSEEK32 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); @@ -1787,7 +1841,11 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { + if (arch == LLM_ARCH_MELLUM || + arch == LLM_ARCH_QWEN3MOE || + arch == LLM_ARCH_OPENAI_MOE || + arch == LLM_ARCH_QWEN3VLMOE || + arch == LLM_ARCH_RND1) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); } @@ -1818,7 +1876,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + LLAMA_LOG_INFO("%s: n_layer_nextn = %d\n", __func__, hparams.n_layer_nextn); } if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { @@ -1957,6 +2015,23 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, { res = nullptr; } break; + case LLM_ARCH_DEEPSEEK32: + { + res = new llama_kv_cache_dsa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + 1, + hparams.n_swa, + hparams.swa_type, + nullptr, + nullptr); + } break; // Models that need standard caching should rely on recurrent/hybrid // checks default: @@ -1983,22 +2058,21 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; llama_memory_hybrid::layer_filter_cb filter_recr = nullptr; if (arch == LLM_ARCH_FALCON_H1) { - filter_attn = [&](int32_t) { return true; }; - filter_recr = [&](int32_t) { return true; }; + filter_attn = [&](uint32_t) { return true; }; + filter_recr = [&](uint32_t) { return true; }; } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { - filter_attn = [&](int32_t il) { - return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + filter_attn = [&](uint32_t il) { + return !hparams.is_recr(il) && hparams.n_ff(il) == 0; }; - filter_recr = [&](int32_t il) { - return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; + filter_recr = [&](uint32_t il) { + return hparams.is_recr(il) && hparams.n_ff(il) == 0; }; } else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) { - const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; - filter_attn = [&, n_main](int32_t il) { - return (uint32_t)il < n_main && !hparams.is_recurrent(il); + filter_attn = [&](uint32_t il) { + return il < hparams.n_layer() && !hparams.is_recr(il); }; - filter_recr = [&, n_main](int32_t il) { - return (uint32_t)il < n_main && hparams.is_recurrent(il); + filter_recr = [&](uint32_t il) { + return il < hparams.n_layer() && hparams.is_recr(il); }; } @@ -2043,13 +2117,16 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* filter_recr */ std::move(filter_recr)); } } else { - llama_memory_i::layer_reuse_cb reuse = nullptr; llama_kv_cache::layer_filter_cb filter = nullptr; + llama_memory_i::layer_reuse_cb reuse = nullptr; + llama_kv_cache::layer_share_cb share = nullptr; if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { - reuse = [&](int32_t il) { - if (il >= (int32_t) hparams.n_layer_kv_from_start) { - return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); + reuse = [&](uint32_t il) { + GGML_ASSERT(hparams.n_layer_kv_from_start >= 2); + + if (il >= (uint32_t)hparams.n_layer_kv_from_start) { + return hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); } return -1; @@ -2057,32 +2134,73 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } if (mtp_on_hybrid_qwen35) { - const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; - filter = [n_main](int32_t il) { return (uint32_t)il >= n_main; }; + filter = [&](uint32_t il) { return il >= hparams.n_layer(); }; + } + + if (arch == LLM_ARCH_STEP35 && hparams.n_layer_nextn > 0) { + if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP) { + filter = [&](uint32_t il) { return il >= hparams.n_layer(); }; + } else { + filter = [&](uint32_t il) { return il < hparams.n_layer(); }; + } } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(hparams.is_swa_any()); - res = new llama_kv_cache_iswa( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - params.swa_full, - cparams.kv_unified, - cparams.n_ctx_seq, - cparams.n_seq_max, - cparams.n_ubatch, - 1, - filter, - reuse); + if (arch == LLM_ARCH_GEMMA4_ASSISTANT) { + llama_memory_t mem_other = llama_get_memory(cparams.ctx_other); + + share = [&](int32_t il) { + const llama_model * model_other = llama_get_model(cparams.ctx_other); + + if (hparams.is_swa(il)) { + return llama_model_n_layer(model_other) - 2; + } + + return llama_model_n_layer(model_other) - 1; + }; + + res = new llama_kv_cache_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + mem_other, + filter, + reuse, + share); + } else { + res = new llama_kv_cache_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + nullptr, + filter, + reuse, + share); + } } else { GGML_ASSERT(!hparams.is_swa_any()); res = new llama_kv_cache( *this, + hparams, params.type_k, params.type_v, !cparams.flash_attn, @@ -2093,7 +2211,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, 1, hparams.n_swa, hparams.swa_type, + nullptr, filter, + nullptr, nullptr); } } @@ -2181,7 +2301,7 @@ int32_t llama_model_n_embd_out(const llama_model * model) { } int32_t llama_model_n_layer(const llama_model * model) { - return model->hparams.n_layer; + return model->hparams.n_layer(); } int32_t llama_model_n_head(const llama_model * model) { @@ -2272,6 +2392,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_DEEPSEEK: case LLM_ARCH_DEEPSEEK2: case LLM_ARCH_DEEPSEEK2OCR: + case LLM_ARCH_DEEPSEEK32: case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GRANITE: @@ -2325,6 +2446,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: case LLM_ARCH_GEMMA4: + case LLM_ARCH_GEMMA4_ASSISTANT: case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: @@ -2356,6 +2478,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MIMO2: case LLM_ARCH_STEP35: case LLM_ARCH_TALKIE: + case LLM_ARCH_MELLUM: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index b797b8966ac..992c8d9c8fd 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -116,6 +116,7 @@ enum llm_type { LLM_TYPE_A13B, LLM_TYPE_7B_A1B, LLM_TYPE_8B_A1B, // lfm2moe + LLM_TYPE_12B_A2_5B, LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_24B_A2B, // lfm2moe @@ -137,6 +138,7 @@ enum llm_type { LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 LLM_TYPE_397B_A17B, // Qwen3.5 + LLM_TYPE_685B_A37B, // DeepSeek V3.2 LLM_TYPE_744B_A40B, // GLM-5 LLM_TYPE_E2B, LLM_TYPE_E4B, @@ -144,6 +146,10 @@ enum llm_type { std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type); +// Map a GGUF activation-name string to llm_ffn_op_type. Returns `fallback` if +// the string is empty or not recognized. +llm_ffn_op_type llm_ffn_op_type_from_string(const std::string & name, llm_ffn_op_type fallback); + struct llama_layer_posnet { // resnet struct ggml_tensor * norm1 = nullptr; @@ -542,6 +548,10 @@ struct llama_model { struct ggml_tensor * output_s = nullptr; struct ggml_tensor * output_in_s = nullptr; + // NextN/MTP model-level projections + struct ggml_tensor * nextn_proj_pre = nullptr; + struct ggml_tensor * nextn_proj_post = nullptr; + // classifier struct ggml_tensor * cls = nullptr; struct ggml_tensor * cls_b = nullptr; @@ -694,7 +704,9 @@ const char * llm_type_name(llm_type type); // convenience macro for loading local variables for load_tensors() in llama_model_base // note: cast to int64_t since we will use these for the tensor dimensions #define LLAMA_LOAD_LOCALS \ - const int n_layer = hparams.n_layer; GGML_UNUSED(n_layer); \ + const int n_layer = hparams.n_layer(); GGML_UNUSED(n_layer); \ + const int n_layer_all = hparams.n_layer_all; GGML_UNUSED(n_layer_all); \ + const int n_layer_nextn = hparams.n_layer_nextn; GGML_UNUSED(n_layer_nextn); \ const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \ const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \ const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \ diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 43e05c3d56f..cf92ce4bb8b 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -847,7 +847,7 @@ static void init_quantize_state_counters(quantize_state_impl & qs, std::vector<t qs.has_tied_embeddings = false; } } - qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer; + qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer(); } // @@ -1348,7 +1348,7 @@ llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * des model->hparams.n_embd = desc->n_embd; model->hparams.n_embd_head_k_full = desc->n_embd_head_k; model->hparams.n_embd_head_v_full = desc->n_embd_head_v; - model->hparams.n_layer = desc->n_layer; + model->hparams.n_layer_all = desc->n_layer; model->hparams.n_expert = desc->n_expert; for (uint32_t i = 0; i < desc->n_layer; i++) { diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 473becade82..9a4bed49487 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -353,6 +353,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_CODESHELL: case LLAMA_VOCAB_PRE_TYPE_EXAONE: case LLAMA_VOCAB_PRE_TYPE_MINERVA: + case LLAMA_VOCAB_PRE_TYPE_MELLUM2: regex_exprs = { "\\p{N}", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", @@ -432,6 +433,15 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI: + // Same lookaheads as GPT4O but with \p{M} added so combining marks + // (diacritics) attach to their base letters. Avoids excessive + // backtracking on scripts that use them heavily (Bengali, Hindi, + // Telugu, Thai, ...). See PR #22716 for benchmarks. + regex_exprs = { + "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}\\p{M}])([^a-z]))*((?=[\\p{L}\\p{M}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}\\p{M}])([^a-z]))+((?=[\\p{L}\\p{M}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_TINY_AYA: regex_exprs = { // original regex from tokenizer.json: "\\d{1,3}(?=(?:\\d{3})*\\b)" @@ -519,6 +529,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}+| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_WHITESPACE: + // whitespace pre-tokenizer (jinaai/jina-embeddings-v2-base-zh) + regex_exprs = { + "\\S+", + }; + byte_encode = false; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -747,7 +764,7 @@ struct llm_tokenizer_wpm_session { void tokenize(const std::string & text, std::vector<llama_token> & output) { // normalize and split by whitespace - std::vector<std::string> words = preprocess(text); + std::vector<std::string> words = preprocess(text, vocab.get_normalizer_lowercase()); // bos token prepended already // find the longest tokens that form the words @@ -792,7 +809,7 @@ struct llm_tokenizer_wpm_session { } // TODO: reduce string copies by using cpts_offs array - static std::vector<std::string> preprocess(const std::string & text) { + static std::vector<std::string> preprocess(const std::string & text, bool lowercase) { const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); std::vector<std::string> words(1, ""); @@ -811,7 +828,7 @@ struct llm_tokenizer_wpm_session { continue; } - const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt)); + const std::string s = unicode_cpt_to_utf8(lowercase ? unicode_tolower(cpt) : cpt); if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) { if (words.back().size()) { // finish previous word if any words.emplace_back(); @@ -1671,6 +1688,35 @@ struct llm_tokenizer_hybriddna_session : llm_tokenizer_bpe_session { const llama_vocab & vocab; }; +struct llm_tokenizer_whitespace_session : llm_tokenizer_bpe_session { + llm_tokenizer_whitespace_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : llm_tokenizer_bpe_session{vocab, tokenizer}, vocab{vocab} {} + + void tokenize(const std::string & text, std::vector<llama_token> & output) override { + const bool lowercase = vocab.get_normalizer_lowercase(); + + std::string segment; + auto flush = [&]() { + if (!segment.empty()) { + llm_tokenizer_bpe_session::tokenize(segment, output); + segment.clear(); + } + }; + + for (uint32_t cpt : unicode_cpts_from_utf8(text)) { + // drop whitespace + if (unicode_cpt_flags_from_cpt(cpt).is_whitespace) { + flush(); + } else { + segment += unicode_cpt_to_utf8(lowercase ? unicode_tolower(cpt) : cpt); + } + } + flush(); + } + +private: + const llama_vocab & vocab; +}; + // // impl // @@ -1751,6 +1797,7 @@ struct llama_vocab::impl { bool remove_extra_whitespaces = false; bool escape_whitespaces = true; bool treat_whitespace_as_suffix = false; + bool normalizer_lowercase = true; // Lowercase normalizer (tokenizer.json) std::unordered_map<std::string, llama_token> token_to_id; std::vector<token_data> id_to_token; @@ -1768,6 +1815,8 @@ struct llama_vocab::impl { // set of all tokens that cause "end of generation" std::set<llama_token> special_eog_ids; + std::vector<llama_token> suppress_tokens; + std::unique_ptr<llm_tokenizer> tokenizer; std::vector<char> precompiled_charsmap; @@ -1900,7 +1949,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_mask_id = 103; add_sep = true; - } else if (tokenizer_model == "gpt2" || tokenizer_model == "hybriddna") { + } else if (tokenizer_model == "gpt2" || tokenizer_model == "hybriddna" || tokenizer_model == "whitespace") { type = LLAMA_VOCAB_TYPE_BPE; // read bpe merges and populate bpe ranks @@ -2105,7 +2154,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jais-2") { pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2; } else if ( - tokenizer_pre == "gemma4") { + tokenizer_pre == "gemma4" || + tokenizer_pre == "granite-embed-multi-311m") { pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4; escape_whitespaces = true; } else if ( @@ -2119,6 +2169,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "roberta-bpe") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; add_sep = true; + } else if ( + tokenizer_pre == "whitespace") { + pre_type = LLAMA_VOCAB_PRE_TYPE_WHITESPACE; + normalizer_lowercase = false; } else if ( tokenizer_pre == "refact") { pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT; @@ -2211,6 +2265,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "talkie") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; + } else if ( + tokenizer_pre == "granite-embed-multi-97m") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI; + clean_spaces = false; + ignore_merges = true; } else if ( tokenizer_pre == "tiny_aya") { pre_type = LLAMA_VOCAB_PRE_TYPE_TINY_AYA; @@ -2269,6 +2328,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "solar-open") { pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN; clean_spaces = false; + } else if ( + tokenizer_pre == "mellum2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MELLUM2; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -2470,6 +2532,19 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } } + // Lowercase normalizer flag (consulted by WPM / whitespace BPE) + ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, normalizer_lowercase, false); + + // suppress tokens + { + const int suppress_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SUPPRESS_TOKENS).c_str()); + if (suppress_idx != -1) { + const int n = gguf_get_arr_n(ctx, suppress_idx); + const int32_t * data = (const int32_t *) gguf_get_arr_data(ctx, suppress_idx); + suppress_tokens.assign(data, data + n); + } + } + // auto-detect special tokens by text // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_... // for now, we apply this workaround to find the tokens based on their text @@ -3264,6 +3339,8 @@ std::vector<llama_token> llama_vocab::impl::tokenize( std::unique_ptr<llm_tokenizer_bpe_session> session; if (vocab.get_tokenizer_model() == "hybriddna") { session = std::make_unique<llm_tokenizer_hybriddna_session>(vocab, *tok_bpe); + } else if (vocab.get_tokenizer_model() == "whitespace") { + session = std::make_unique<llm_tokenizer_whitespace_session>(vocab, *tok_bpe); } else { session = std::make_unique<llm_tokenizer_bpe_session>(vocab, *tok_bpe); } @@ -3892,6 +3969,14 @@ bool llama_vocab::get_treat_whitespace_as_suffix() const { return pimpl->treat_whitespace_as_suffix; } +bool llama_vocab::get_normalizer_lowercase() const { + return pimpl->normalizer_lowercase; +} + +const std::vector<llama_token> & llama_vocab::get_suppress_tokens() const { + return pimpl->suppress_tokens; +} + int llama_vocab::max_token_len() const { return pimpl->max_token_len; } diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 8ab77594284..2626ae36e33 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -8,59 +8,62 @@ // pre-tokenization types enum llama_vocab_pre_type { - LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, - LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, - LLAMA_VOCAB_PRE_TYPE_FALCON = 4, - LLAMA_VOCAB_PRE_TYPE_MPT = 5, - LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, - LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, - LLAMA_VOCAB_PRE_TYPE_REFACT = 8, - LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, - LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, - LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, - LLAMA_VOCAB_PRE_TYPE_OLMO = 12, - LLAMA_VOCAB_PRE_TYPE_DBRX = 13, - LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, - LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, - LLAMA_VOCAB_PRE_TYPE_VIKING = 18, - LLAMA_VOCAB_PRE_TYPE_JAIS = 19, - LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, - LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, - LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, - LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, - LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, - LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, - LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, - LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, - LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, - LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, - LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, - LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, - LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, - LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, - LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, - LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, - LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, - LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, - LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, - LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, - LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, - LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, - LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, - LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, - LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, - LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, - LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, - LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 51, - LLAMA_VOCAB_PRE_TYPE_MINICPM5 = 52, + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + LLAMA_VOCAB_PRE_TYPE_PORO = 15, + LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, + LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, + LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, + LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, + LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, + LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, + LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, + LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, + LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, + LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, + LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, + LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, + LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, + LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, + LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, + LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, + LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 51, + LLAMA_VOCAB_PRE_TYPE_MINICPM5 = 52, + LLAMA_VOCAB_PRE_TYPE_WHITESPACE = 53, + LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI = 54, + LLAMA_VOCAB_PRE_TYPE_MELLUM2 = 55, }; struct LLM_KV; @@ -138,6 +141,9 @@ struct llama_vocab { bool get_remove_extra_whitespaces () const; bool get_escape_whitespaces () const; bool get_treat_whitespace_as_suffix() const; + bool get_normalizer_lowercase () const; + + const std::vector<llama_token> & get_suppress_tokens() const; int max_token_len() const; diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index dfe30ce8f61..a67fa8039a4 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -225,7 +225,9 @@ static bool llama_prepare_model_devices(const llama_model_params & params, llama } case GGML_BACKEND_DEVICE_TYPE_IGPU: - igpus.push_back({false, dev}); + if (igpus.empty()) { + igpus.push_back({false, dev}); + } break; case GGML_BACKEND_DEVICE_TYPE_META: GGML_ABORT("fatal error"); @@ -239,8 +241,9 @@ static bool llama_prepare_model_devices(const llama_model_params & params, llama // add GPUs model->devices.insert(model->devices.end(), gpus.begin(), gpus.end()); - // add integrated GPUs only if no other devices were found - if (model->devices.empty()) { + // add integrated GPUs only if no discrete GPUs were found + // (RPC servers do not count, otherwise the local iGPU would be dropped on iGPU+RPC setups) + if (gpus.empty()) { model->devices.insert(model->devices.end(), igpus.begin(), igpus.end()); } } diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index e8374c53b70..27e48067428 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -339,6 +339,7 @@ extern "C" { uint32_t n_ubatch; // physical maximum batch size uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] + uint32_t n_outputs_max; // max outputs in a ubatch (0 = n_batch) int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -387,6 +388,10 @@ extern "C" { // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init) struct llama_sampler_seq_config * samplers; size_t n_samplers; + + // a source/target/parent context + // can be utilized in various ways, for example by sharing results or llama_memory between 2 contexts + struct llama_context * ctx_other; }; struct llama_model_tensor_override { @@ -975,7 +980,11 @@ extern "C" { // Set whether the model is in warmup mode or not // If true, all model tensors are activated during llama_decode() to load and cache their weights. - LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup); + // + // note: using this can cause extra graph reallocations because it changes the graph topology with MoE models, + // so it is generally not recommended to use in practice. will be removed in the future + DEPRECATED(LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup), + "user code should do warmup runs manually [TAG_LLAMA_GRAPH_NO_WARMUP]"); // Set abort callback LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index a7c77ee5d28..063b214256e 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -30,7 +30,7 @@ void llama_model_afmoe::load_arch_hparams(llama_model_loader & ml) { hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 56: type = LLM_TYPE_6B; break; case 32: type = LLM_TYPE_26B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index bec7136521c..6dfb8905fbe 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -2,12 +2,13 @@ void llama_model_apertus::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); - switch (hparams.n_layer) { + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer()); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer()); + ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer()); + ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer()); + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index d086c4717ff..9536e7c5d42 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -4,7 +4,7 @@ void llama_model_arcee::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); // Arcee uses the same structure as Llama - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 36: type = LLM_TYPE_4B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index 27deadffeb7..09ee0f752f0 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -4,7 +4,7 @@ void llama_model_arctic::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); if (hparams.n_expert == 128) { - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 35: type = LLM_TYPE_10B_128x3_66B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/arwkv7.cpp b/examples/talk-llama/models/arwkv7.cpp index 9bd04127b25..b38b2064785 100644 --- a/examples/talk-llama/models/arwkv7.cpp +++ b/examples/talk-llama/models/arwkv7.cpp @@ -10,7 +10,7 @@ void llama_model_arwkv7::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 12: switch (hparams.n_embd) { case 768: type = LLM_TYPE_190M; break; diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index 4d26081cd5d..585f3614174 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -2,7 +2,7 @@ void llama_model_baichuan::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_13B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index fe1ae10864b..7faf73c835b 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -8,7 +8,7 @@ void llama_model_bailingmoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 28: type = LLM_TYPE_16B; break; case 88: type = LLM_TYPE_290B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index 2f0d44a6259..5000e9c6db8 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -9,17 +9,13 @@ void llama_model_bailingmoe2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 20: type = LLM_TYPE_16B_A1B; break; - case 21: type = LLM_TYPE_16B_A1B; break; case 32: type = LLM_TYPE_100B_A6B; break; - case 33: type = LLM_TYPE_100B_A6B; break; default: type = LLM_TYPE_UNKNOWN; } } @@ -39,9 +35,9 @@ void llama_model_bailingmoe2::load_arch_tensors(llama_model_loader &) { GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2"); GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2"); - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers flags |= TENSOR_SKIP; } @@ -78,7 +74,7 @@ void llama_model_bailingmoe2::load_arch_tensors(llama_model_loader &) { } // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); @@ -112,8 +108,7 @@ llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph ggml_tensor * inp_out_ids = build_inp_out_ids(); - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // norm @@ -146,7 +141,7 @@ llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/bert.cpp b/examples/talk-llama/models/bert.cpp index 3c28f419ccf..53ce29f23ca 100644 --- a/examples/talk-llama/models/bert.cpp +++ b/examples/talk-llama/models/bert.cpp @@ -1,9 +1,9 @@ #include "models.h" void llama_model_bert::load_arch_hparams(llama_model_loader & ml) { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 3: type = LLM_TYPE_17M; break; // bge-micro case 6: diff --git a/examples/talk-llama/models/bitnet.cpp b/examples/talk-llama/models/bitnet.cpp index 7e8125deec4..c8330274580 100644 --- a/examples/talk-llama/models/bitnet.cpp +++ b/examples/talk-llama/models/bitnet.cpp @@ -3,7 +3,7 @@ void llama_model_bitnet::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 26: type = LLM_TYPE_3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index 30b0f3d07d0..609d2ddf998 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -3,7 +3,7 @@ void llama_model_bloom::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 30: switch (hparams.n_embd) { diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 4bceaefd63b..4f45acecf84 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -6,7 +6,7 @@ void llama_model_chameleon::load_arch_hparams(llama_model_loader & ml) { hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 48: type = LLM_TYPE_34B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index 6766fa71c15..7ae5b938fde 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -2,7 +2,8 @@ void llama_model_chatglm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 28: { if (hparams.n_head(0) == 16) { type = LLM_TYPE_1_5B; diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index 274dd3342a7..de53bb98184 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -2,7 +2,8 @@ void llama_model_codeshell::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 42: type = LLM_TYPE_7B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index 2e231bb3f93..750f57a394e 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -2,7 +2,8 @@ void llama_model_cogvlm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_13B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/cohere2.cpp b/examples/talk-llama/models/cohere2.cpp index a514cf88fc6..61a5945a194 100644 --- a/examples/talk-llama/models/cohere2.cpp +++ b/examples/talk-llama/models/cohere2.cpp @@ -5,6 +5,7 @@ void llama_model_cohere2::load_arch_hparams(llama_model_loader & ml) { uint32_t swa_period = 4; ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); hparams.set_swa_pattern(swa_period); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; @@ -12,7 +13,8 @@ void llama_model_cohere2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index adf7fcaa20f..94a46188bb8 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -3,7 +3,8 @@ void llama_model_command_r::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 40: type = LLM_TYPE_35B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index af71c775365..4f5ac4d06a4 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -1,14 +1,14 @@ #include "models.h" void llama_model_dbrx::load_arch_hparams(llama_model_loader & ml) { -ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); -ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); -switch (hparams.n_layer) { - case 40: type = LLM_TYPE_16x12B; break; - default: type = LLM_TYPE_UNKNOWN; + switch (hparams.n_layer()) { + case 40: type = LLM_TYPE_16x12B; break; + default: type = LLM_TYPE_UNKNOWN; + } } - } void llama_model_dbrx::load_arch_tensors(llama_model_loader &) { LLAMA_LOAD_LOCALS; diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index 567e3535276..cdfcf29e02f 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -2,7 +2,8 @@ void llama_model_deci::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 80: type = LLM_TYPE_70B; break; case 162: type = LLM_TYPE_405B; break; diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp index 1fe54adc13e..a9e8bc51403 100644 --- a/examples/talk-llama/models/deepseek2.cpp +++ b/examples/talk-llama/models/deepseek2.cpp @@ -5,7 +5,7 @@ void llama_model_deepseek2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B - const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); + const bool is_lite = (hparams.n_layer() == 27 || hparams.n_layer() == 26 || (hparams.n_layer() == 48 && n_vocab == 128256)); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); @@ -23,7 +23,7 @@ void llama_model_deepseek2::load_arch_hparams(llama_model_loader & ml) { if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { // for compatibility with existing DeepSeek V2 and V2.5 GGUFs // that have no expert_gating_func model parameter set - if ((hparams.n_layer == 47 || hparams.n_layer == 48) && n_vocab == 154880) { + if ((hparams.n_layer() == 47 || hparams.n_layer() == 48) && n_vocab == 154880) { // GLM 4.7 Lite hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; } else { @@ -43,7 +43,7 @@ void llama_model_deepseek2::load_arch_hparams(llama_model_loader & ml) { hparams.f_attn_temp_offset = 0.0f; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 27: type = LLM_TYPE_16B; break; case 47: type = LLM_TYPE_30B_A3B; break; case 60: type = LLM_TYPE_236B; break; @@ -191,8 +191,7 @@ llama_model_deepseek2::graph::graph(const llama_model & model, const llm_graph_p ggml_tensor * inp_out_ids = build_inp_out_ids(); - int effective_n_layers = hparams.n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < effective_n_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // norm @@ -366,7 +365,7 @@ llama_model_deepseek2::graph::graph(const llama_model & model, const llm_graph_p Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } } - if (il == effective_n_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/deepseek2ocr.cpp b/examples/talk-llama/models/deepseek2ocr.cpp index f9e4c98785c..65d31c31b93 100644 --- a/examples/talk-llama/models/deepseek2ocr.cpp +++ b/examples/talk-llama/models/deepseek2ocr.cpp @@ -14,7 +14,7 @@ void llama_model_deepseek2ocr::load_arch_hparams(llama_model_loader & ml) { hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 12: type = LLM_TYPE_3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/deepseek32.cpp b/examples/talk-llama/models/deepseek32.cpp new file mode 100644 index 00000000000..9a20e2ce907 --- /dev/null +++ b/examples/talk-llama/models/deepseek32.cpp @@ -0,0 +1,499 @@ +#include "models.h" + +#include "llama-kv-cache.h" +#include "llama-kv-cache-dsa.h" + +void llama_model_deepseek32::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-6; // eps for layer norm + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + + // DSA parameters + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + + // Expert gating function + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { + // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + // cancel the factor from the convert script + hparams.rope_yarn_log_mul /= 0.1f; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer"); + + switch (hparams.n_layer()) { + case 62: type = LLM_TYPE_685B_A37B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deepseek32::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const bool is_mla = hparams.is_mla(); + if (!is_mla) { + throw std::runtime_error("DEEPSEEK32 architecture requires MLA"); + } + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer_all; ++i) { + int flags = 0; + if (i >= n_layer) { + // skip all tensors in the NextN layers + // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later + flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // DSA indexer + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (i >= n_layer) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_deepseek32::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_deepseek32::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const bool is_mla = hparams.is_mla(); + GGML_ASSERT(is_mla); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v = hparams.n_embd_head_v_mla(); + GGML_UNUSED(n_embd_head_v); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope; + + const int64_t n_indexer_head = hparams.indexer_n_head; + const int64_t n_embd_indexer_head = hparams.indexer_head_size; + const int64_t n_embd_indexer_head_rope = hparams.n_rot(); + const int64_t n_embd_indexer_head_nope = n_embd_indexer_head - n_embd_indexer_head_rope; + const uint32_t n_indexer_top_k = hparams.indexer_top_k; + + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. + // See https://github.com/ggml-org/llama.cpp/discussions/7416 for detailed explanation. + // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + + // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor + GGML_ASSERT(ext_factor >= 0.0f); + const float attn_factor_org = attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale)); + + // use the original attn_factor to pre-scale the kq_scale + const float mscale = attn_factor_org * (1.0f + 0.1f * hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); + const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k)); + + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + llm_graph_input_attn_k_dsa * inp_attn_dsa = build_attn_inp_k_dsa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + ggml_tensor * qr = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(qr, "qr", il); + + qr = build_norm(qr, model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il); + cb(qr, "qr", il); + + ggml_tensor * top_k = nullptr; + + // lightning indexer + { + ggml_tensor * indexer_q = ggml_mul_mat(ctx0, model.layers[il].indexer_attn_q_b, qr); + cb(indexer_q, "indexer_q", il); + + // split into {n_embd_indexer_head_rope, n_indexer_head, n_tokens} + ggml_tensor * indexer_q_pe = + ggml_view_3d(ctx0, indexer_q, n_embd_indexer_head_rope, n_indexer_head, n_tokens, + ggml_row_size(indexer_q->type, n_embd_indexer_head), + ggml_row_size(indexer_q->type, n_embd_indexer_head) * n_indexer_head, 0); + cb(indexer_q_pe, "indexer_q_pe", il); + + // and {n_embd_indexer_head_nope, n_indexer_head, n_tokens} + ggml_tensor * indexer_q_nope = + ggml_view_3d(ctx0, indexer_q, n_embd_indexer_head_nope, n_indexer_head, n_tokens, + ggml_row_size(indexer_q->type, n_embd_indexer_head), + ggml_row_size(indexer_q->type, n_embd_indexer_head) * n_indexer_head, + ggml_row_size(indexer_q->type, n_embd_indexer_head_nope)); + cb(indexer_q_nope, "indexer_q_nope", il); + + indexer_q_pe = ggml_rope_ext(ctx0, indexer_q_pe, inp_pos, nullptr, n_rot, + LLAMA_ROPE_TYPE_NEOX, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(indexer_q_pe, "indexer_q_pe", il); + + // {n_embd_indexer_head_rope + n_embd_indexer_head_nope, n_head, n_tokens} + indexer_q = ggml_concat(ctx0, indexer_q_pe, indexer_q_nope, 0); + cb(indexer_q, "indexer_q", il); + + ggml_tensor * indexer_k = ggml_mul_mat(ctx0, model.layers[il].indexer_attn_k, cur); + cb(indexer_k, "indexer_k", il); + + indexer_k = build_norm(indexer_k, model.layers[il].indexer_k_norm, model.layers[il].indexer_k_norm_b, LLM_NORM, il); + cb(indexer_k, "indexer_k", il); + + // split into {n_embd_indexer_head_rope, 1, n_tokens} + ggml_tensor * indexer_k_pe = + ggml_view_3d(ctx0, indexer_k, n_embd_indexer_head_rope, 1, n_tokens, + ggml_row_size(indexer_k->type, n_embd_indexer_head), + ggml_row_size(indexer_k->type, n_embd_indexer_head) * 1, 0); + cb(indexer_k_pe, "indexer_k_pe", il); + + // and {n_embd_indexer_head_nope, 1, n_tokens} + ggml_tensor * indexer_k_nope = + ggml_view_3d(ctx0, indexer_k, n_embd_indexer_head_nope, 1, n_tokens, + ggml_row_size(indexer_k->type, n_embd_indexer_head), + ggml_row_size(indexer_k->type, n_embd_indexer_head) * 1, + ggml_row_size(indexer_k->type, n_embd_indexer_head_nope)); + cb(indexer_k_nope, "indexer_k_nope", il); + + indexer_k_pe = ggml_rope_ext(ctx0, indexer_k_pe, inp_pos, nullptr, n_rot, + LLAMA_ROPE_TYPE_NEOX, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(indexer_k_pe, "indexer_k_pe", il); + + // {n_embd_indexer_head_rope + n_embd_indexer_head_nope, 1, n_tokens} + indexer_k = ggml_concat(ctx0, indexer_k_pe, indexer_k_nope, 0); + cb(indexer_k, "indexer_k", il); + + // perform Hadamard transform on indexer q and k + indexer_q = ggml_mul_mat(ctx0, inp_attn_dsa->self_k_rot_lid, indexer_q); + cb(indexer_q, "indexer_q", il); + indexer_k = ggml_mul_mat(ctx0, inp_attn_dsa->self_k_rot_lid, indexer_k); + cb(indexer_k, "indexer_k", il); + + // store indexer keys to KV cache + const auto * mctx_lid = inp_attn_dsa->mctx->get_lid(); + const auto & k_idxs_lid = inp_attn_dsa->get_k_idxs_lid(); + ggml_build_forward_expand(gf, mctx_lid->cpy_k(ctx0, indexer_k, k_idxs_lid, il)); + + // prepare indexer weights + ggml_tensor * indexer_weights = ggml_mul_mat(ctx0, model.layers[il].indexer_proj, cur); + cb(indexer_weights, "indexer_weights", il); + + // get cached indexer keys + indexer_k = mctx_lid->get_k(ctx0, il); + + // split the batch into streams if needed + const auto n_stream = indexer_k->ne[3]; + indexer_q = ggml_view_4d(ctx0, indexer_q, indexer_q->ne[0], indexer_q->ne[1], indexer_q->ne[2]/n_stream, n_stream, indexer_q->nb[1], indexer_q->nb[2], indexer_q->nb[3]/n_stream, 0); + indexer_weights = ggml_view_4d(ctx0, indexer_weights, indexer_weights->ne[0], indexer_weights->ne[1]/n_stream, indexer_weights->ne[2], n_stream, indexer_weights->nb[1], indexer_weights->nb[2]/n_stream, indexer_weights->nb[3]/n_stream, 0); + + // calculate indexer kq + indexer_q = ggml_permute(ctx0, indexer_q, 0, 2, 1, 3); + cb(indexer_q, "indexer_q", il); + indexer_k = ggml_permute(ctx0, indexer_k, 0, 2, 1, 3); + cb(indexer_k, "indexer_k", il); + + ggml_tensor * indexer_kq = ggml_mul_mat(ctx0, indexer_k, indexer_q); + cb(indexer_kq, "indexer_kq", il); + + // ReLU requires contiguous tensors + indexer_kq = ggml_cont(ctx0, ggml_permute(ctx0, indexer_kq, 2, 1, 0, 3)); + cb(indexer_kq, "indexer_kq", il); + + // apply ReLU + ggml_tensor * indexer_score = ggml_relu(ctx0, indexer_kq); + cb(indexer_score, "indexer_score", il); + + // pre-scale weights to avoid scaling operations on huge indexer_score tensor + indexer_weights = ggml_scale(ctx0, indexer_weights, 1.0f / sqrtf(float(n_embd_indexer_head * n_indexer_head))); + cb(indexer_weights, "indexer_weights", il); + + // multiply scores by indexer weights + indexer_score = ggml_mul(ctx0, indexer_score, indexer_weights); + cb(indexer_score, "indexer_score", il); + + // sum by q n_indexer_head dimension + indexer_score = ggml_sum_rows(ctx0, indexer_score); + cb(indexer_score, "indexer_score", il); + + // permute result to match KQ mask + indexer_score = ggml_cont(ctx0, ggml_permute(ctx0, indexer_score, 2, 1, 0, 3)); + cb(indexer_score, "indexer_score", il); + + // mask indexer scores + ggml_tensor * indexer_kq_mask = inp_attn_dsa->get_kq_mask_lid(); + indexer_score = ggml_add(ctx0, indexer_score, indexer_kq_mask); + cb(indexer_score, "indexer_score", il); + + // get indices of top k indexer scores + uint32_t n_top_k = indexer_score->ne[0] < n_indexer_top_k ? indexer_score->ne[0] : n_indexer_top_k; + top_k = ggml_cont(ctx0, ggml_top_k(ctx0, indexer_score, n_top_k)); + cb(top_k, "top_k", il); + } + + ggml_tensor * q = ggml_mul_mat(ctx0, model.layers[il].wq_b, qr); + cb(q, "q", il); + + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * q_nope = + ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, 0); + cb(q_nope, "q_nope", il); + + // and {n_embd_head_qk_rope, n_head, n_tokens} + ggml_tensor * q_pe = ggml_view_3d( + ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_cmpr_pe, "kv_cmpr_pe", il); + + // split into {kv_lora_rank, n_tokens} + ggml_tensor * kv_cmpr = + ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cmpr, "kv_cmpr", il); + + // and {n_embd_head_qk_rope, 1, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(q_pe, "q_pe", il); + + k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(k_pe, "k_pe", il); + + kv_cmpr = build_norm(kv_cmpr, model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, il); + cb(kv_cmpr, "kv_cmpr", il); + + // MLA attention + { + // {n_embd_head_qk_nope, n_tokens, n_head} + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); + + // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head} + ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope); + cb(q_nope_absorbed, "q_nope_absorbed", il); + + // {kv_lora_rank, n_head, n_tokens} + q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3); + cb(q_nope_absorbed, "q_nope_absorbed_perm", il); + + // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} + // note: rope must go first for in-place context shifting in build_rope_shift() + ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); + cb(Qcur, "Qcur", il); + + kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); + cb(kv_cmpr, "kv_cmpr_reshape", il); + + // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} + ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0); + cb(Kcur, "Kcur", il); + + // {kv_lora_rank, 1, n_tokens} + ggml_tensor * Vcur = kv_cmpr; + cb(Vcur, "Vcur", il); + + // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group) + cur = build_attn(inp_attn_dsa, + model.layers[il].wo, NULL, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, top_k, kq_scale, il); + } + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il, + nullptr, + model.layers[il].ffn_gate_up_exps, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s, + model.layers[il].ffn_gate_shexp, NULL, model.layers[il].ffn_gate_shexp_s, + model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 435d27281c6..07d6ab1b7cd 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -8,7 +8,8 @@ void llama_model_dots1::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 62: type = LLM_TYPE_142B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 12ac6f1ce88..abe737c335a 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -2,8 +2,9 @@ void llama_model_dream::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // Dream models are primarily 7B with 28 layers - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 28: type = LLM_TYPE_7B; break; diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index 9b39c605e35..895cf690bd2 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -12,7 +12,7 @@ void llama_model_ernie4_5::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 18: type = LLM_TYPE_0_3B; break; case 28: type = LLM_TYPE_21B_A3B; break; case 54: type = LLM_TYPE_300B_A47B; break; diff --git a/examples/talk-llama/models/eurobert.cpp b/examples/talk-llama/models/eurobert.cpp index ddf13c3028f..0948d7de656 100644 --- a/examples/talk-llama/models/eurobert.cpp +++ b/examples/talk-llama/models/eurobert.cpp @@ -3,7 +3,7 @@ void llama_model_eurobert::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - if (hparams.n_layer == 12) { + if (hparams.n_layer() == 12) { type = LLM_TYPE_SMALL; // 0.2B } } diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp index 76d91982fc5..5aed9379400 100644 --- a/examples/talk-llama/models/exaone-moe.cpp +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -20,13 +20,12 @@ void llama_model_exaone_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_30B_A3B; break; - case 48: - case 49: type = LLM_TYPE_235B_A22B; break; + case 48: type = LLM_TYPE_235B_A22B; break; default: type = LLM_TYPE_UNKNOWN; } } @@ -50,9 +49,9 @@ void llama_model_exaone_moe::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers flags |= TENSOR_SKIP; } @@ -70,7 +69,7 @@ void llama_model_exaone_moe::load_arch_tensors(llama_model_loader &) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); // dense layers for first n_layer_dense_lead layers or nextn_predict_layers layers at the end - if (i < (int) hparams.n_layer_dense_lead || (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers)) { + if (i < (int) hparams.n_layer_dense_lead || (i >= n_layer)) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, flags); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); @@ -95,7 +94,7 @@ void llama_model_exaone_moe::load_arch_tensors(llama_model_loader &) { } // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); @@ -130,8 +129,7 @@ llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_ ggml_tensor * inp_out_ids = build_inp_out_ids(); - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // use RoPE for SWA layers @@ -170,7 +168,7 @@ llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index c7e9960d718..676fb37b5a6 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -3,7 +3,7 @@ void llama_model_exaone::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 499e22dde81..863268abcef 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -1,7 +1,7 @@ #include "models.h" void llama_model_exaone4::load_arch_hparams(llama_model_loader & ml) { - if (hparams.n_layer == 64) { // 32B + if (hparams.n_layer() == 64) { // 32B hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; uint32_t swa_period = 4; @@ -15,8 +15,11 @@ void llama_model_exaone4::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); - switch (hparams.n_layer) { + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer"); + + switch (hparams.n_layer()) { case 30: type = LLM_TYPE_1_2B; break; case 64: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; @@ -37,22 +40,38 @@ void llama_model_exaone4::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { + const bool is_nextn = i >= n_layer; + int flags = 0; + if (is_nextn) { + // NextN/MTP layers are preserved in GGUF but are not executed yet. + flags |= TENSOR_SKIP; + } + auto & layer = layers[i]; - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, flags); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + if (!is_nextn) { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + if (is_nextn) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + } } } diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index 94b65a3c7c9..d6ef2d51986 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -11,9 +11,9 @@ void llama_model_falcon_h1::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true); + std::fill(hparams.is_recr_impl.begin(), hparams.is_recr_impl.end(), true); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 36: type = LLM_TYPE_0_5B; break; case 24: diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index ad546ef2db5..b2ad90b3272 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -3,7 +3,7 @@ void llama_model_falcon::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 60: type = LLM_TYPE_40B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/gemma-embedding.cpp b/examples/talk-llama/models/gemma-embedding.cpp index 4e07f5f2bda..80ed3b1a460 100644 --- a/examples/talk-llama/models/gemma-embedding.cpp +++ b/examples/talk-llama/models/gemma-embedding.cpp @@ -21,7 +21,7 @@ void llama_model_gemma_embedding::load_arch_hparams(llama_model_loader & ml) { GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_0_3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 1519682fdf6..651cd7e64de 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -3,7 +3,7 @@ void llama_model_gemma::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 18: type = LLM_TYPE_2B; break; case 28: type = LLM_TYPE_7B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/gemma2.cpp b/examples/talk-llama/models/gemma2.cpp index ae3f9ffb530..2fbfb15a94a 100644 --- a/examples/talk-llama/models/gemma2.cpp +++ b/examples/talk-llama/models/gemma2.cpp @@ -16,7 +16,7 @@ void llama_model_gemma2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 26: type = LLM_TYPE_2B; break; case 42: type = LLM_TYPE_9B; break; case 46: type = LLM_TYPE_27B; break; diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index 63a2b380e71..690194529e3 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -17,7 +17,7 @@ void llama_model_gemma3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 18: type = LLM_TYPE_270M; break; case 26: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_8B; break; // Rnj-1 diff --git a/examples/talk-llama/models/gemma3n.cpp b/examples/talk-llama/models/gemma3n.cpp index 6ec3a006081..83eb8250aa9 100644 --- a/examples/talk-llama/models/gemma3n.cpp +++ b/examples/talk-llama/models/gemma3n.cpp @@ -6,14 +6,14 @@ void llama_model_gemma3n::load_arch_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(swa_period); - hparams.n_layer_kv_from_start = 20; - hparams.f_attention_scale = 1.0f; + hparams.n_layer_kv_from_start = 20; + hparams.f_attention_scale = 1.0f; ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 30: type = LLM_TYPE_E2B; break; case 35: type = LLM_TYPE_E4B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/gemma4-assistant.cpp b/examples/talk-llama/models/gemma4-assistant.cpp new file mode 100644 index 00000000000..5b7a25a5aba --- /dev/null +++ b/examples/talk-llama/models/gemma4-assistant.cpp @@ -0,0 +1,200 @@ +#include "models.h" + +void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) { + hparams.n_embd_inp_impl = hparams.n_embd_out(); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + + uint32_t n_kv_shared_layers = 0; + ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); + + hparams.f_attention_scale = 1.0f; + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn == hparams.n_layer_all && "n_layer_nextn must be == n_layer_impl"); + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); +} + +void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_embd_head_k != n_embd_head_v) { + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v"); + } + if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == n_embd_head_v_swa"); + } + if (hparams.n_embd_out() == n_embd) { + throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + + const int64_t n_embd_backbone = hparams.n_embd_inp(); + nextn_proj_post = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_POST, "weight"), { n_embd, n_embd_backbone }, 0); + + int rope_freqs_flag = 0; + + for (int i = 0; i < n_layer_nextn; ++i) { + auto & layer = layers[i]; + + const int64_t n_head = hparams.n_head(i); + const int64_t n_embd_head = hparams.n_embd_head_k(i); + const int64_t n_ff = hparams.n_ff(i); + + if (i == 0) { + nextn_proj_pre = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_PRE, "weight", i), { 2*n_embd_backbone, n_embd }, 0); + } + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), { 1u }, 0); + + if (!hparams.is_swa(i)) { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_embd_head/2 }, rope_freqs_flag); + rope_freqs_flag = TENSOR_DUPLICATED; + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const int64_t n_embd_backbone = hparams.n_embd_inp(); + + ggml_tensor * inp_tokens; + ggml_tensor * inp_h; + { + auto inp = std::make_unique<llm_graph_input_embd>(n_embd_backbone); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + inp_tokens = inp->tokens; + res->t_inp_tokens = inp->tokens; + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens); + cb(inp->embd, "inp_h", -1); + ggml_set_input(inp->embd); + inp_h = inp->embd; + res->t_inp_embd = inp->embd; + + res->add_input(std::move(inp)); + } + + GGML_ASSERT(cparams.ctx_other != nullptr); + const auto * model_other = llama_get_model(cparams.ctx_other); + + ggml_tensor * x = ggml_get_rows(ctx0, model_other->tok_embd, inp_tokens); + x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone)); + cb(x, "inp_embd_target", -1); + + ggml_tensor * xh = ggml_concat(ctx0, x, inp_h, 0); + cb(xh, "inp_xh", -1); + + ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_proj_pre, xh); + cb(cur, "pre_proj", -1); + + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer_nextn; ++il) { + const bool is_swa = hparams.is_swa(il); + + const int64_t n_embd_head = hparams.n_embd_head_k(il); + const int64_t n_head = hparams.n_head(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + const int n_rot_l = hparams.n_rot(il); + + ggml_tensor * cur_norm = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur_norm, "attn_norm", il); + + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur_norm); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * freq_factors = is_swa ? nullptr : model.layers[il].rope_freqs; + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, + freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_pos", il); + + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr, + Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + + if (il == n_layer_nextn - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL); + cb(attn_out, "attn_out", il); + + cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, nullptr, + model.layers[il].ffn_gate, nullptr, nullptr, + model.layers[il].ffn_down, nullptr, nullptr, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", il); + + cur = ggml_add(ctx0, cur, attn_out); + + cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); + cb(cur, "out_scaled", il); + + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + ggml_tensor * logits = build_lora_mm(model.output, cur); + cb(logits, "result_output", -1); + res->t_logits = logits; + + ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_proj_post, cur); + cb(h_next, "h_nextn", -1); + res->t_h_nextn = h_next; + + ggml_build_forward_expand(gf, logits); + ggml_build_forward_expand(gf, h_next); +} diff --git a/examples/talk-llama/models/gemma4.cpp b/examples/talk-llama/models/gemma4.cpp index 4f9d8b18bc7..6f7fcd645cb 100644 --- a/examples/talk-llama/models/gemma4.cpp +++ b/examples/talk-llama/models/gemma4.cpp @@ -2,12 +2,12 @@ void llama_model_gemma4::load_arch_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); uint32_t n_kv_shared_layers = 0; ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); - hparams.n_layer_kv_from_start = hparams.n_layer - (int32_t)n_kv_shared_layers; + hparams.n_layer_kv_from_start = hparams.n_layer_all - (int32_t)n_kv_shared_layers; hparams.f_attention_scale = 1.0f; // Gemma4 uses self.scaling = 1.0 (no pre-attn scaling) ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); @@ -19,7 +19,7 @@ void llama_model_gemma4::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 30: type = LLM_TYPE_26B_A4B; break; case 35: type = LLM_TYPE_E2B; break; case 42: type = LLM_TYPE_E4B; break; @@ -142,6 +142,33 @@ static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, in idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); } +// TODO @ngxson : maybe improve this in the future +class llm_graph_input_logits_bias : public llm_graph_input_i { +public: + llm_graph_input_logits_bias(const llama_vocab & vocab) { + arr.resize(vocab.n_tokens(), 0.0f); + for (llama_token id : vocab.get_suppress_tokens()) { + if (0 <= id && id < (int32_t)vocab.n_tokens()) { + arr[id] = -INFINITY; + } + } + } + virtual ~llm_graph_input_logits_bias() = default; + + void set_input(const llama_ubatch * /*ubatch*/) override { + const int64_t n_vocab = arr.size(); + ggml_backend_tensor_set(logits_bias, arr.data(), 0, n_vocab*ggml_element_size(logits_bias)); + } + + bool can_reuse(const llm_graph_params & /*params*/) override { + return true; + } + + ggml_tensor * logits_bias = nullptr; // F32 [n_vocab] + + std::vector<float> arr; +}; + llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), @@ -245,7 +272,8 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para } // TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing - if (il == n_layer - 1 && inp_out_ids) { + // keep all rows when extracting unmasked nextn embeddings (MTP target needs the hidden state for every token) + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); } @@ -345,7 +373,7 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens] // TODO @ngxson : improve this - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids); } @@ -376,6 +404,17 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para model.output_norm, nullptr, LLM_NORM_RMS, -1); + // Expose the post-output-norm hidden state (the LM-head input feature) so that + // MTP draft contexts can read it via llama_get_embeddings_nextn_ith() as the + // recurrent h input. This matches the reference (transformers/vLLM/SGLang), + // which feeds the drafter the target's post-final-norm hidden state. + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + if (!cparams.embeddings_nextn_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + cb(cur, "result_norm", -1); res->t_embd = cur; @@ -388,6 +427,16 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); } + // apply logits bias if needed (e.g. for gemma4_unified patch) + // this is to mirror the suppress_tokens patch on transformers, to avoid model from outputing <image|> and <audio|> tokens (which is a known issue related to the checkpoint) + // TODO: maybe handle this inside the sampling system in the future + if (!model.vocab.get_suppress_tokens().empty()) { + auto inp_bias = std::make_unique<llm_graph_input_logits_bias>(model.vocab); + inp_bias->logits_bias = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, inp_bias->arr.size()); + cur = ggml_add(ctx0, cur, inp_bias->logits_bias); + res->add_input(std::move(inp_bias)); + } + cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/glm-dsa.cpp b/examples/talk-llama/models/glm-dsa.cpp index af2b55ef563..11d91312def 100644 --- a/examples/talk-llama/models/glm-dsa.cpp +++ b/examples/talk-llama/models/glm-dsa.cpp @@ -33,13 +33,10 @@ void llama_model_glm_dsa::load_arch_hparams(llama_model_loader & ml) { } // NextN/MTP parameters - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 79: type = LLM_TYPE_744B_A40B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -76,9 +73,9 @@ void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; @@ -135,8 +132,8 @@ void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) { layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); } - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + // NextN/MTP tensors (preserved but unused) - conditionally load for last n_layer_nextn + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 27654b8cba3..d60e47ddf0c 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -20,16 +20,13 @@ void llama_model_glm4_moe::load_arch_hparams(llama_model_loader & ml) { } // NextN/MTP parameters - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - - switch (hparams.n_layer) { - case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + switch (hparams.n_layer()) { + case 46: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open - case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) + case 92: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 default: type = LLM_TYPE_UNKNOWN; } } @@ -54,9 +51,9 @@ void llama_model_glm4_moe::load_arch_tensors(llama_model_loader &) { // Load ALL tensors including NextN layer to satisfy total tensor count // but only PROCESS up to last layer (skipping final NextN layer) in forward pass - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers flags |= TENSOR_SKIP; } @@ -116,7 +113,7 @@ void llama_model_glm4_moe::load_arch_tensors(llama_model_loader &) { } // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); @@ -161,8 +158,7 @@ llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_pa // Only process up to last layer (skip final NextN layer) // Final layer tensors are loaded but not processed in forward pass - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // Pre-attention norm @@ -211,7 +207,7 @@ llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_pa model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index 7c242fed298..b4326c5f210 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -5,13 +5,10 @@ void llama_model_glm4::load_arch_hparams(llama_model_loader & ml) { ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); // NextN/MTP parameters (GLM-OCR) - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 17: type = LLM_TYPE_1B; break; // GLM-OCR case 40: type = LLM_TYPE_9B; break; case 61: type = LLM_TYPE_32B; break; @@ -32,9 +29,9 @@ void llama_model_glm4::load_arch_tensors(llama_model_loader &) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { // skip all tensors in the NextN layers flags |= TENSOR_SKIP; } @@ -55,7 +52,7 @@ void llama_model_glm4::load_arch_tensors(llama_model_loader &) { layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { + if (i >= n_layer) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); @@ -100,8 +97,7 @@ llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params // Only process up to last layer (skip final NextN layer) // Final layer tensors are loaded but not processed in forward pass - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // Pre-attention norm @@ -140,7 +136,7 @@ llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index e2dcc8b1521..45afbccc121 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -2,7 +2,8 @@ void llama_model_gpt2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 12: type = LLM_TYPE_SMALL; break; case 24: type = LLM_TYPE_MEDIUM; break; case 36: type = LLM_TYPE_LARGE; break; diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 443e35addf2..ed5e8c50da2 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -3,7 +3,8 @@ void llama_model_gptneox::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 6: switch (hparams.n_ff()) { case 512: type = LLM_TYPE_14M; break; diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index 27f6706ea10..eb23095aece 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -19,8 +19,8 @@ void llama_model_granite_hybrid::load_arch_hparams(llama_model_loader & ml) { hparams.rope_finetuned = rope_finetuned; // A layer is recurrent IFF the n_head_kv value is set to 0 - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -71,7 +71,7 @@ void llama_model_granite_hybrid::load_arch_tensors(llama_model_loader &) { // norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (hparams.is_recurrent(i)) { + if (hparams.is_recr(i)) { // ssm layers layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); @@ -158,7 +158,7 @@ llama_model_granite_hybrid::graph::graph(const llama_model & model, const llm_gr cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // ssm layer // cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); } else { diff --git a/examples/talk-llama/models/granite-moe.cpp b/examples/talk-llama/models/granite-moe.cpp index 0d89bc1f340..115263c418f 100644 --- a/examples/talk-llama/models/granite-moe.cpp +++ b/examples/talk-llama/models/granite-moe.cpp @@ -12,7 +12,7 @@ void llama_model_granite_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); hparams.rope_finetuned = rope_finetuned; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_3B; break; case 40: type = LLM_TYPE_3B; break; // Add additional layer/vocab/etc checks here for other model sizes diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index cda4aa231fa..4a75c5ff3cc 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include <sstream> + void llama_model_granite::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); @@ -7,12 +9,33 @@ void llama_model_granite::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, false); + // Granite4 Vision uses array deepstack_mapping + ml.get_arr(LLM_KV_DEEPSTACK_MAPPING, hparams.deepstack_mapping_arr, false); + + // Count the unique deepstack input indices + std::unordered_set<uint32_t> unique_deepstack_idxs; + for (const auto val : hparams.deepstack_mapping_arr) { + if (val >= 0) { + unique_deepstack_idxs.insert(val); + } + } + hparams.n_deepstack_layers = unique_deepstack_idxs.size(); + + // Ensure all values are valid (avoid overflow attacks) + for (const auto val : unique_deepstack_idxs) { + if (val > hparams.n_deepstack_layers) { + std::stringstream ss; + ss << "Invalid deepstack index: " << val << " > " << hparams.n_deepstack_layers; + throw std::runtime_error(ss.str()); + } + } + // Granite uses rope_finetuned as a switch for rope, so default to true bool rope_finetuned = true; ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); hparams.rope_finetuned = rope_finetuned; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_3B; break; case 40: type = LLM_TYPE_3B; break; // Add additional layer/vocab/etc checks here for other model sizes @@ -112,6 +135,20 @@ llama_model_granite::graph::graph( ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + + // Granite Vision 4.1 deepstack: inject the projector stream that + // targets decoder layer `il` before the decoder runs. + // NOTE: skip the first deepstack layer since that's inpL + const auto & deepstack_emb_idx = hparams.deepstack_mapping_arr[il]; + if (il > 0 && deepstack_emb_idx >= 0) { + ggml_tensor * ds = ggml_view_2d(ctx0, + res->t_inp_embd, n_embd, n_tokens, + res->t_inp_embd->nb[1], + deepstack_emb_idx * n_embd * sizeof(float)); + inpL = ggml_add(ctx0, inpL, ds); + cb(inpL, "deepstack_in", il); + } + ggml_tensor * inpSA = inpL; // norm diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index 7c46ec1c0f2..42f38af6724 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -26,7 +26,7 @@ void llama_model_grok::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 64: type = LLM_TYPE_314B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index 1cab75adc7f..643a448e59a 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -7,7 +7,7 @@ void llama_model_grovemoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_30B_A3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index deb3c9671f3..4d55f5e7f31 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -5,7 +5,7 @@ void llama_model_hunyuan_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_A13B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index f9ee37a24b6..f6cfdfb9458 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -2,7 +2,8 @@ void llama_model_internlm2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 48: type = LLM_TYPE_20B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index 2ba162605f1..415103ce23a 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -4,7 +4,7 @@ void llama_model_jais::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1_3B; break; case 40: type = LLM_TYPE_13B; break; /* TODO: add variants */ diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp index 8966131441c..8610fcc9f82 100644 --- a/examples/talk-llama/models/jais2.cpp +++ b/examples/talk-llama/models/jais2.cpp @@ -3,7 +3,7 @@ void llama_model_jais2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; case 68: type = LLM_TYPE_70B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index 84ea63c3136..dba160b014f 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -8,11 +8,11 @@ void llama_model_jamba::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { // TODO: Jamba layers are a bit heterogeneous, so naming this is hard. case 12: // 900M 8x???M case 32: // 51B 16x?B diff --git a/examples/talk-llama/models/jina-bert-v2.cpp b/examples/talk-llama/models/jina-bert-v2.cpp index 4f8866ece4d..86ff1c84d1a 100644 --- a/examples/talk-llama/models/jina-bert-v2.cpp +++ b/examples/talk-llama/models/jina-bert-v2.cpp @@ -4,7 +4,7 @@ void llama_model_jina_bert_v2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); hparams.f_max_alibi_bias = 8.0f; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 4: type = LLM_TYPE_33M; break; // jina-embeddings-small case 12: type = LLM_TYPE_137M; break; // jina-embeddings-base default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/jina-bert-v3.cpp b/examples/talk-llama/models/jina-bert-v3.cpp index e0527529f56..1c974a6f16c 100644 --- a/examples/talk-llama/models/jina-bert-v3.cpp +++ b/examples/talk-llama/models/jina-bert-v3.cpp @@ -3,7 +3,7 @@ void llama_model_jina_bert_v3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_558M; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp index ecffb105496..367f6990d1f 100644 --- a/examples/talk-llama/models/kimi-linear.cpp +++ b/examples/talk-llama/models/kimi-linear.cpp @@ -14,8 +14,8 @@ void llama_model_kimi_linear::load_arch_hparams(llama_model_loader & ml) { // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent } // MoE parameters - Kimi uses moe_intermediate_size = 1024 @@ -25,7 +25,7 @@ void llama_model_kimi_linear::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B default: type = LLM_TYPE_UNKNOWN; } @@ -53,7 +53,7 @@ void llama_model_kimi_linear::load_arch_tensors(llama_model_loader &) { const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; const int64_t ssm_d_conv = hparams.ssm_d_conv; - if (hparams.is_recurrent(i)) { + if (hparams.is_recr(i)) { // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); @@ -285,7 +285,7 @@ llama_model_kimi_linear::graph::graph(const llama_model & model, const llm_graph ggml_build_forward_expand(gf, cur); - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // === KDA Layer (Kimi Delta Attention) with Recurrent State === // Reference: vLLM kda.py const auto * mctx_cur = inp_rs->mctx; diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index 29081344b24..97da8a6abb8 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -5,10 +5,13 @@ void llama_model_lfm2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + + for (uint32_t il = 0; il < hparams.n_layer(); ++il) { + hparams.is_recr_impl[il] = hparams.n_head_kv(il) == 0; } - hparams.n_layer_dense_lead = hparams.n_layer; + + hparams.n_layer_dense_lead = hparams.n_layer(); + switch (hparams.n_ff()) { case 4608: type = LLM_TYPE_350M; break; case 6912: type = LLM_TYPE_700M; break; @@ -16,10 +19,11 @@ void llama_model_lfm2::load_arch_hparams(llama_model_loader & ml) { case 10752: type = LLM_TYPE_2_6B; break; default: type = LLM_TYPE_UNKNOWN; } + if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); is_swa && hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.swa_layers[il] = !hparams.recurrent_layer_arr[il]; + for (uint32_t il = 0; il < hparams.n_layer(); ++il) { + hparams.is_swa_impl[il] = !hparams.is_recr_impl[il]; } } } @@ -59,7 +63,7 @@ void llama_model_lfm2::load_arch_tensors(llama_model_loader &) { // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recr(i)) { layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); @@ -235,8 +239,8 @@ llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_ cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "model.layers.{}.operator_norm", il); - cur = hparams.is_recurrent(il) ? build_shortconv_block(cur, inp_hybrid->get_recr(), il) : - build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il); + cur = hparams.is_recr(il) ? build_shortconv_block(cur, inp_hybrid->get_recr(), il) : + build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il); if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); diff --git a/examples/talk-llama/models/lfm2moe.cpp b/examples/talk-llama/models/lfm2moe.cpp index 12a66c05c7d..490f5c223eb 100644 --- a/examples/talk-llama/models/lfm2moe.cpp +++ b/examples/talk-llama/models/lfm2moe.cpp @@ -9,11 +9,11 @@ void llama_model_lfm2moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + for (uint32_t il = 0; il < hparams.n_layer(); ++il) { + hparams.is_recr_impl[il] = hparams.n_head_kv(il) == 0; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_8B_A1B; break; case 40: type = LLM_TYPE_24B_A2B; break; default: type = LLM_TYPE_UNKNOWN; @@ -55,7 +55,7 @@ void llama_model_lfm2moe::load_arch_tensors(llama_model_loader &) { // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recr(i)) { layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index 9722dde9f17..2ae89386447 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -2,11 +2,12 @@ void llama_model_llada_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // diffusion language model uses non-causal attention hparams.causal_attn = false; - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_A1_7B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index 58b2c466e17..87d4259f9a7 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -2,14 +2,16 @@ void llama_model_llada::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // LLaDA-8B has 32 layers, similar to LLaMA but for diffusion - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8B; break; default: type = LLM_TYPE_UNKNOWN; } + // Set non-causal attention for diffusion models hparams.causal_attn = false; } diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index cef66d054b0..c0ec7e0a9ad 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -7,13 +7,13 @@ void llama_model_llama::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); if (hparams.n_expert == 8) { - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_8x7B; break; case 56: type = LLM_TYPE_8x22B; break; default: type = LLM_TYPE_UNKNOWN; } } else { - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_1B; break; // Llama 3.2 1B case 22: type = LLM_TYPE_1B; break; case 26: type = LLM_TYPE_3B; break; diff --git a/examples/talk-llama/models/llama4.cpp b/examples/talk-llama/models/llama4.cpp index 0ff5376d571..7194c72a585 100644 --- a/examples/talk-llama/models/llama4.cpp +++ b/examples/talk-llama/models/llama4.cpp @@ -8,14 +8,15 @@ void llama_model_llama4::load_arch_hparams(llama_model_loader & ml) { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); if (found_swa && hparams.n_swa == 0) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; - hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope + hparams.n_no_rope_layer_step = hparams.n_layer(); // always use rope } else { hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; hparams.n_swa = 8192; hparams.n_attn_temp_floor_scale = 8192; hparams.f_attn_temp_scale = 0.1f; hparams.f_attn_temp_offset = 1.0f; - uint32_t swa_period = 4; // pattern: 3 chunked - 1 full + + uint32_t swa_period = 4; // pattern: 3 chunked - 1 full ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); hparams.set_swa_pattern(swa_period); diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index 84cfe399027..ae56a26a1f6 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -2,7 +2,8 @@ void llama_model_maincoder::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_1B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/mamba.cpp b/examples/talk-llama/models/mamba.cpp index 887a1fa509a..0d94e98281c 100644 --- a/examples/talk-llama/models/mamba.cpp +++ b/examples/talk-llama/models/mamba.cpp @@ -9,7 +9,7 @@ void llama_model_mamba::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: switch (hparams.n_embd) { case 768: type = LLM_TYPE_SMALL; break; diff --git a/examples/talk-llama/models/mamba2.cpp b/examples/talk-llama/models/mamba2.cpp index 3277ca53ec4..c5951cf0f7f 100644 --- a/examples/talk-llama/models/mamba2.cpp +++ b/examples/talk-llama/models/mamba2.cpp @@ -9,7 +9,7 @@ void llama_model_mamba2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: switch (hparams.n_embd) { case 768: type = LLM_TYPE_SMALL; break; diff --git a/examples/talk-llama/models/mellum.cpp b/examples/talk-llama/models/mellum.cpp new file mode 100644 index 00000000000..28823018bc0 --- /dev/null +++ b/examples/talk-llama/models/mellum.cpp @@ -0,0 +1,225 @@ +#include "models.h" + +void llama_model_mellum::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + uint32_t swa_period = 4; + const auto res = ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + if (res) { + hparams.set_swa_pattern(swa_period); + } else { + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + } + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer()) { + case 28: type = LLM_TYPE_12B_A2_5B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mellum::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for Mellum"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for Mellum"); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_mellum::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique<graph<true>>(*this, params); + } + return std::make_unique<graph<false>>(*this, params); +} + +template <bool iswa> +llama_model_mellum::graph<iswa>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + const bool is_swa = hparams.is_swa(il); + + if (is_swa) { + // For sliding window layers, use regular rope with no yarn rope scaling. + // This is achieved here by setting freq_scale and attn_factor to 1. + // We also set ext_factor to 0 to avoid a few unnecessary computations. + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il, + nullptr, nullptr, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +template struct llama_model_mellum::graph<false>; +template struct llama_model_mellum::graph<true>; diff --git a/examples/talk-llama/models/mimo2.cpp b/examples/talk-llama/models/mimo2.cpp index d0295ec116f..88989160570 100644 --- a/examples/talk-llama/models/mimo2.cpp +++ b/examples/talk-llama/models/mimo2.cpp @@ -8,18 +8,18 @@ void llama_model_mimo2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); float value_scale = 0.0f; if (ml.get_key(LLM_KV_ATTENTION_VALUE_SCALE, value_scale, false) && value_scale != 1.0f) { hparams.f_attn_value_scale = value_scale; } - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); - switch (hparams.n_layer - hparams.nextn_predict_layers) { + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_310B_A15B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -34,16 +34,14 @@ void llama_model_mimo2::load_arch_tensors(llama_model_loader &) { output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - const uint32_t n_nextn = hparams.nextn_predict_layers; - - for (int i = 0; i < n_layer; ++i) { + for (int i = 0; i < n_layer_all; ++i) { auto & layer = layers[i]; uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); uint32_t n_head = hparams.n_head(i); // NextN/MTP layers (the last n_nextn blocks) are preserved but disabled pending support - const bool is_nextn = (n_nextn > 0) && (static_cast<uint32_t>(i) >= n_layer - n_nextn); + const bool is_nextn = i >= n_layer; const int skip = is_nextn ? TENSOR_SKIP : 0; create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, skip); @@ -92,10 +90,7 @@ llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_param const float v_scale = hparams.f_attn_value_scale; - // The last hparams.nextn_predict_layers blocks are MTP heads, currently inactive - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; uint32_t n_head_l = hparams.n_head(il); @@ -173,7 +168,7 @@ llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_param } } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/minicpm.cpp b/examples/talk-llama/models/minicpm.cpp index 966d3af615c..fc3e5b171d5 100644 --- a/examples/talk-llama/models/minicpm.cpp +++ b/examples/talk-llama/models/minicpm.cpp @@ -3,7 +3,7 @@ void llama_model_minicpm::load_arch_hparams(llama_model_loader & ml) { // Backward-compatible defaults for older MiniCPM GGUFs hparams.f_embedding_scale = 12.0f; - hparams.f_residual_scale = 1.4f / sqrtf(float(hparams.n_layer)); + hparams.f_residual_scale = 1.4f / sqrtf(float(hparams.n_layer())); hparams.f_logit_scale = hparams.n_embd ? (256.0f / float(hparams.n_embd)) : 1.0f; ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -16,7 +16,7 @@ void llama_model_minicpm::load_arch_hparams(llama_model_loader & ml) { // MiniCPM uses rope by default, unlike Granite which uses it as a switch hparams.rope_finetuned = true; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 52: type = LLM_TYPE_1B; break; case 40: type = LLM_TYPE_2B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index 1ffc54fa7c6..e011b1ff0a8 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -5,7 +5,7 @@ void llama_model_minicpm3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 62: type = LLM_TYPE_4B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index 22e291d73a3..b25435e4d97 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -5,7 +5,7 @@ void llama_model_minimax_m2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 62: type = LLM_TYPE_230B_A10B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 1ac5a95ccdc..9a8e3f9a50b 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -18,7 +18,7 @@ void llama_model_mistral3::load_arch_hparams(llama_model_loader & ml) { } } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 26: type = LLM_TYPE_3B; break; case 34: type = LLM_TYPE_8B; break; case 40: type = LLM_TYPE_14B; break; diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index db228865d5d..c137e32e8fd 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -411,6 +411,18 @@ struct llama_model_stablelm : public llama_model_base { std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; +struct llama_model_mellum : public llama_model_base { + llama_model_mellum(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool iswa> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; struct llama_model_qwen : public llama_model_base { llama_model_qwen(const struct llama_model_params & params) : llama_model_base(params) {} @@ -810,6 +822,19 @@ struct llama_model_gemma4 : public llama_model_base { }; +struct llama_model_gemma4_assistant : public llama_model_base { + llama_model_gemma4_assistant(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + struct llama_model_gemma_embedding : public llama_model_base { llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {} void load_arch_hparams(llama_model_loader & ml) override; @@ -1030,6 +1055,19 @@ struct llama_model_deepseek2 : public llama_model_base { }; +struct llama_model_deepseek32 : public llama_model_base { + llama_model_deepseek32(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + struct llama_model_deepseek2ocr : public llama_model_base { llama_model_deepseek2ocr(const struct llama_model_params & params) : llama_model_base(params) {} void load_arch_hparams(llama_model_loader & ml) override; @@ -1900,5 +1938,9 @@ struct llama_model_step35 : public llama_model_base { graph(const llama_model & model, const llm_graph_params & params); }; + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; diff --git a/examples/talk-llama/models/modern-bert.cpp b/examples/talk-llama/models/modern-bert.cpp index e9b79ffc6dc..f3e9407e012 100644 --- a/examples/talk-llama/models/modern-bert.cpp +++ b/examples/talk-llama/models/modern-bert.cpp @@ -14,7 +14,15 @@ void llama_model_modern_bert::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + // Some ModernBert derivatives (e.g. IBM Granite Embedding 97m R2) use + // SiLU/SwiGLU in the FFN instead of the default GELU/GeGLU. + hparams.llm_ffn_op = LLM_FFN_GEGLU; + std::string hidden_act; + if (ml.get_key(LLM_KV_HIDDEN_ACT, hidden_act, false)) { + hparams.llm_ffn_op = llm_ffn_op_type_from_string(hidden_act, LLM_FFN_GEGLU); + } + + switch (hparams.n_layer()) { case 12: type = LLM_TYPE_47M; break; // granite-embedding-small case 22: @@ -144,7 +152,8 @@ llama_model_modern_bert::graph::graph(const llama_model & model, const llm_graph NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, - LLM_FFN_GEGLU, LLM_FFN_SEQ, il); + hparams.llm_ffn_op, + LLM_FFN_SEQ, il); // attentions bypass the intermediate layer cur = ggml_add(ctx0, cur, ffn_inp); diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index 0229d20ed36..d094fd9f80b 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -5,7 +5,7 @@ void llama_model_mpt::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 48: type = LLM_TYPE_30B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index a82f9c170b4..a456269347b 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -9,8 +9,8 @@ void llama_model_nemotron_h::load_arch_hparams(llama_model_loader & ml) { // A layer is recurrent IFF the n_head_kv value is set to 0 and // the n_ff value is set to 0 - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); } ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -22,7 +22,7 @@ void llama_model_nemotron_h::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_MOE_LATENT_SIZE, hparams.moe_latent_size, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B case 56: type = LLM_TYPE_9B; break; case 88: type = LLM_TYPE_120B_A12B; break; @@ -62,7 +62,7 @@ void llama_model_nemotron_h::load_arch_tensors(llama_model_loader &) { // all blocks use the attn norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (hparams.is_recurrent(i)) { + if (hparams.is_recr(i)) { // ssm layers layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); @@ -143,7 +143,7 @@ llama_model_nemotron_h::graph::graph(const llama_model & model, const llm_graph_ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // ssm layer // cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); } else if (hparams.n_ff(il) == 0) { diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index 5d4a3b5c69e..6e2bd9a33ca 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -2,7 +2,8 @@ void llama_model_nemotron::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_4B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/neo-bert.cpp b/examples/talk-llama/models/neo-bert.cpp index f00d6eddfc9..4a08d7abd40 100644 --- a/examples/talk-llama/models/neo-bert.cpp +++ b/examples/talk-llama/models/neo-bert.cpp @@ -3,7 +3,7 @@ void llama_model_neo_bert::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - if (hparams.n_layer == 28) { + if (hparams.n_layer() == 28) { type = LLM_TYPE_250M; } } diff --git a/examples/talk-llama/models/nomic-bert-moe.cpp b/examples/talk-llama/models/nomic-bert-moe.cpp index a17abe2c269..da4b62919bb 100644 --- a/examples/talk-llama/models/nomic-bert-moe.cpp +++ b/examples/talk-llama/models/nomic-bert-moe.cpp @@ -4,7 +4,7 @@ void llama_model_nomic_bert_moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); - if (hparams.n_layer == 12 && hparams.n_embd == 768) { + if (hparams.n_layer() == 12 && hparams.n_embd == 768) { if (arch == LLM_ARCH_NOMIC_BERT) { type = LLM_TYPE_137M; } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { diff --git a/examples/talk-llama/models/nomic-bert.cpp b/examples/talk-llama/models/nomic-bert.cpp index 5a8a5584457..e7fc72286a6 100644 --- a/examples/talk-llama/models/nomic-bert.cpp +++ b/examples/talk-llama/models/nomic-bert.cpp @@ -4,7 +4,7 @@ void llama_model_nomic_bert::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); - if (hparams.n_layer == 12 && hparams.n_embd == 768) { + if (hparams.n_layer() == 12 && hparams.n_embd == 768) { if (arch == LLM_ARCH_NOMIC_BERT) { type = LLM_TYPE_137M; } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index cfcf17bcb03..9f7a2ba60ef 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -4,7 +4,7 @@ void llama_model_olmo::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 22: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_7B; break; case 80: type = LLM_TYPE_70B; break; diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 7cc262f5504..cb52cdef720 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -17,7 +17,7 @@ void llama_model_olmo2::load_arch_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_13B; break; diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index 7976ae44a51..1e2baeb207f 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -2,7 +2,8 @@ void llama_model_olmoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_A1_7B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/openai-moe.cpp b/examples/talk-llama/models/openai-moe.cpp index 15b6c8c1205..3ab15d61f08 100644 --- a/examples/talk-llama/models/openai-moe.cpp +++ b/examples/talk-llama/models/openai-moe.cpp @@ -14,7 +14,7 @@ void llama_model_openai_moe::load_arch_hparams(llama_model_loader & ml) { hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_20B; break; case 36: type = LLM_TYPE_120B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index 9f76350fd4d..13120bd3236 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -3,12 +3,12 @@ void llama_model_openelm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_270M; break; - case 20: type = LLM_TYPE_450M; break; - case 28: type = LLM_TYPE_1B; break; - case 36: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; + switch (hparams.n_layer()) { + case 16: type = LLM_TYPE_270M; break; + case 20: type = LLM_TYPE_450M; break; + case 28: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; } } diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index bcb4bbba4b1..863a2822269 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -3,7 +3,7 @@ void llama_model_orion::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 40: type = LLM_TYPE_14B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/pangu-embed.cpp b/examples/talk-llama/models/pangu-embed.cpp index 7593f879b24..90f05c088c1 100644 --- a/examples/talk-llama/models/pangu-embed.cpp +++ b/examples/talk-llama/models/pangu-embed.cpp @@ -2,7 +2,8 @@ void llama_model_pangu_embed::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1 case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1 default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index 8f3ed5f7b7d..81b1ad12cc0 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -3,7 +3,7 @@ void llama_model_phi2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_3B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index f8a4a4d5aa5..716ff814cc1 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -3,7 +3,7 @@ void llama_model_phi3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_3B; break; case 40: type = LLM_TYPE_14B; break; diff --git a/examples/talk-llama/models/phimoe.cpp b/examples/talk-llama/models/phimoe.cpp index 4575d6139cf..c332553bc7d 100644 --- a/examples/talk-llama/models/phimoe.cpp +++ b/examples/talk-llama/models/phimoe.cpp @@ -3,7 +3,7 @@ void llama_model_phimoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_16x3_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index c7ed1211c31..246144519e4 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -3,7 +3,7 @@ void llama_model_plamo::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 40: type = LLM_TYPE_13B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index b713889fe72..b93cf48bc5c 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -11,11 +11,11 @@ void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 16: type = LLM_TYPE_1B; break; case 32: if (hparams.n_embd == 2048) { @@ -54,7 +54,7 @@ void llama_model_plamo2::load_arch_tensors(llama_model_loader &) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - bool is_mamba_layer = hparams.is_recurrent(i); + bool is_mamba_layer = hparams.is_recr(i); layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -128,7 +128,7 @@ llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_para cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); // check if this layer is Mamba or Attention - const bool is_mamba_layer = hparams.is_recurrent(il); + const bool is_mamba_layer = hparams.is_recr(il); if (is_mamba_layer) { // PLaMo-2 Mamba layer diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 29f3e803d68..16d0b1dcef7 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -13,7 +13,7 @@ void llama_model_plamo3::load_arch_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_2B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index ce050919e6a..8ca325f5e2c 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -3,7 +3,8 @@ void llama_model_plm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_1_8B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index 00467dbad7d..1f5dff3843c 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -3,7 +3,7 @@ void llama_model_qwen::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_13B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index a5147460bae..e9c2ea80a6b 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -2,7 +2,8 @@ void llama_model_qwen2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; case 28: type = hparams.n_embd == 1536 ? LLM_TYPE_1_5B : LLM_TYPE_7B; break; case 32: type = LLM_TYPE_7B; break; diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 7cb03859deb..e831ed11aad 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -5,7 +5,8 @@ void llama_model_qwen2moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_A2_7B; break; case 28: type = LLM_TYPE_57B_A14B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index 41b97fed956..1d0d2fab362 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -2,7 +2,8 @@ void llama_model_qwen3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; case 40: type = LLM_TYPE_14B; break; diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index 04ecc18fcdc..4b642cff467 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -13,21 +13,20 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); // Mark recurrent layers (linear attention layers). MTP layers are dense // attention-only and must be flagged non-recurrent. - { - const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer_all, false)) { uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); + for (uint32_t i = 0; i < hparams.n_layer_all; ++i) { + hparams.is_recr_impl[i] = (i < hparams.n_layer()) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer - hparams.nextn_predict_layers) { + switch (hparams.n_layer()) { case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; case 64: type = LLM_TYPE_27B; break; @@ -38,9 +37,7 @@ void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; - const uint32_t n_main = n_layer - hparams.nextn_predict_layers; - const bool mtp_only = (hparams.nextn_predict_layers > 0) && - (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const bool mtp_only = (hparams.n_layer_nextn > 0) && (ml.get_weight("blk.0.attn_norm.weight") == nullptr); const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); @@ -69,7 +66,7 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(il)) { + if (!hparams.is_recr(il)) { // Attention layers create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); @@ -121,10 +118,10 @@ void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); }; - for (int i = 0; i < (int) n_main; ++i) { + for (int i = 0; i < n_layer; ++i) { load_block_trunk(i, trunk_flags); } - for (int i = (int) n_main; i < n_layer; ++i) { + for (int i = n_layer; i < n_layer_all; ++i) { load_block_mtp(i); } } @@ -158,8 +155,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_tensor * inp_out_ids = build_inp_out_ids(); // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. - const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -168,7 +164,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para ggml_build_forward_expand(gf, cur); // Determine layer type and build appropriate attention mechanism - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // Linear attention layer (gated delta net) cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { @@ -176,7 +172,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -208,16 +204,15 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para } cur = inpL; - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; - if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + if (!cparams.embeddings_nextn_masked && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); } - // Final norm - cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); - cb(cur, "result_norm", -1); res->t_embd = cur; @@ -490,15 +485,15 @@ ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, cons // LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 dense series llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35 MTP requires nextn_predict_layers > 0"); - GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35 MTP currently only supports a single MTP block"); + GGML_ASSERT(hparams.n_layer_nextn > 0 && "QWEN35 MTP requires n_layer_nextn > 0"); + GGML_ASSERT(hparams.n_layer_nextn == 1 && "QWEN35 MTP currently only supports a single MTP block"); const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); // hparams.n_layer includes both main model layers and MTP layers. The MTP // layer is stored immediately after the main layers in model.layers[]. - const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const int il = hparams.n_layer(); const auto & layer = model.layers[il]; GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); @@ -508,28 +503,41 @@ llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd); + // TODO: extract in a common llm_graph_context::build_inp_embd_h() + auto inp = std::make_unique<llm_graph_input_embd_h>(hparams.n_embd); inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_set_input(inp->tokens); - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd_inp(), n_tokens); ggml_set_input(inp->embd); - ggml_set_name(inp->embd, "mtp_h_input"); - ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + // TODO: make static using `ggml_build_forward_select()` + // see llm_graph_context::build_inp_embd() for reference + ggml_tensor * tok_embd; + if (ubatch.token) { + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; - ggml_tensor * h_input = inp->embd; - ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + } else { + tok_embd = inp->embd; + } cb(tok_embd, "mtp_tok_embd", il); + inp->h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->h); + ggml_set_name(inp->h, "mtp_h_input"); + + ggml_tensor * h_embd = inp->h; + res->add_input(std::move(inp)); ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - auto * inp_attn = build_attn_inp_kv(); - ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_embd, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); cb(h_norm, "mtp_hnorm", il); ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); @@ -611,18 +619,16 @@ llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "mtp_post_ffn", il); - // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. - // (In the trunk graph this is `t_h_pre_norm`; the MTP head reuses the same slot.) - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; - - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - ggml_tensor * head_norm_w = layer.nextn.shared_head_norm ? layer.nextn.shared_head_norm : model.output_norm; GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm"); cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); cb(cur, "mtp_shared_head_norm", -1); ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp index dc24f6ed537..eb5e9a406a1 100644 --- a/examples/talk-llama/models/qwen35moe.cpp +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -16,21 +16,20 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - GGML_ASSERT(hparams.nextn_predict_layers < hparams.n_layer && "nextn_predict_layers must be < n_layer"); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); // Mark recurrent layers (linear attention layers). MTP layers are dense // attention-only and must be flagged non-recurrent. - { - const uint32_t n_main = hparams.n_layer - hparams.nextn_predict_layers; + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer_all, false)) { uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (i < n_main) && ((i + 1) % full_attn_interval != 0); + for (uint32_t i = 0; i < hparams.n_layer_all; ++i) { + hparams.is_recr_impl[i] = (i < hparams.n_layer()) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer - hparams.nextn_predict_layers) { + switch (hparams.n_layer()) { case 40: type = LLM_TYPE_35B_A3B; break; case 48: type = LLM_TYPE_122B_A10B; break; case 60: type = LLM_TYPE_397B_A17B; break; @@ -41,9 +40,7 @@ void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; - const uint32_t n_main = n_layer - hparams.nextn_predict_layers; - const bool mtp_only = (hparams.nextn_predict_layers > 0) && - (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const bool mtp_only = (hparams.n_layer_nextn > 0) && (ml.get_weight("blk.0.attn_norm.weight") == nullptr); const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); @@ -75,7 +72,7 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); - if (!hparams.is_recurrent(il)) { + if (!hparams.is_recr(il)) { // Attention layers create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); @@ -144,10 +141,10 @@ void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); }; - for (int i = 0; i < (int) n_main; ++i) { + for (int i = 0; i < n_layer; ++i) { load_block_trunk(i, trunk_flags); } - for (int i = (int) n_main; i < n_layer; ++i) { + for (int i = n_layer; i < n_layer_all; ++i) { load_block_mtp(i); } } @@ -181,8 +178,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_tensor * inp_out_ids = build_inp_out_ids(); // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. - const int n_transformer_layers = n_layer - (int) hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); @@ -191,7 +187,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p ggml_build_forward_expand(gf, cur); // Determine layer type and build appropriate attention mechanism - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // Linear attention layer (gated delta net) cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { @@ -199,7 +195,7 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); } - if (il == n_transformer_layers - 1 && inp_out_ids && cparams.embeddings_pre_norm_masked) { + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -231,16 +227,16 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p } cur = inpL; - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; + // post-norm hidden state feeds both the LM head and the MTP seed below + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; - if (!cparams.embeddings_pre_norm_masked && inp_out_ids) { + if (!cparams.embeddings_nextn_masked && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); } - // Final norm - cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); - cb(cur, "result_norm", -1); res->t_embd = cur; @@ -554,13 +550,13 @@ ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, c // LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 MoE llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - GGML_ASSERT(hparams.nextn_predict_layers > 0 && "QWEN35MOE MTP requires nextn_predict_layers > 0"); - GGML_ASSERT(hparams.nextn_predict_layers == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); + GGML_ASSERT(hparams.n_layer_nextn > 0 && "QWEN35MOE MTP requires n_layer_nextn > 0"); + GGML_ASSERT(hparams.n_layer_nextn == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); const int64_t n_embd_head = hparams.n_embd_head_v(); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); - const int il = (int) hparams.n_layer - (int) hparams.nextn_predict_layers; + const int il = hparams.n_layer(); const auto & layer = model.layers[il]; GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); @@ -571,29 +567,41 @@ llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd); + // TODO: extract in a common llm_graph_context::build_inp_embd_h() + auto inp = std::make_unique<llm_graph_input_embd_h>(hparams.n_embd); inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); ggml_set_input(inp->tokens); - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd_inp(), n_tokens); ggml_set_input(inp->embd); - ggml_set_name(inp->embd, "mtp_h_input"); - ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + // TODO: make static using `ggml_build_forward_select()` + // see llm_graph_context::build_inp_embd() for reference + ggml_tensor * tok_embd; + if (ubatch.token) { + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; - ggml_tensor * h_input = inp->embd; - ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + } else { + tok_embd = inp->embd; + } cb(tok_embd, "mtp_tok_embd", il); + inp->h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->h); + ggml_set_name(inp->h, "mtp_h_input"); + + ggml_tensor * h_embd = inp->h; + res->add_input(std::move(inp)); ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - auto * inp_attn = build_attn_inp_kv(); + auto * inp_attn = build_attn_inp_kv(); - ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + ggml_tensor * h_norm = build_norm(h_embd, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); cb(h_norm, "mtp_hnorm", il); ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); @@ -708,17 +716,16 @@ llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "mtp_post_ffn", il); - // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. - cb(cur, "h_pre_norm", -1); - res->t_h_pre_norm = cur; - - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - ggml_tensor * head_norm_w = layer.nextn.shared_head_norm ? layer.nextn.shared_head_norm : model.output_norm; GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm"); cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn= cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); cb(cur, "mtp_shared_head_norm", -1); ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index a4f8e1379c9..317e668bec7 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -1,10 +1,10 @@ #include "models.h" void llama_model_qwen3moe::load_arch_hparams(llama_model_loader & ml) { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_30B_A3B; break; case 94: type = LLM_TYPE_235B_A22B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index 1d873427db5..97200a44072 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -14,15 +14,15 @@ void llama_model_qwen3next::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // Mark recurrent layers (linear attention layers) - { + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer_all, false)) { uint32_t full_attn_interval = 4; ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + for (uint32_t i = 0; i < hparams.n_layer_all; ++i) { + hparams.is_recr_impl[i] = (i < hparams.n_layer()) && ((i + 1) % full_attn_interval != 0); } } - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_80B_A3B; break; default: type = LLM_TYPE_UNKNOWN; } @@ -68,7 +68,7 @@ void llama_model_qwen3next::load_arch_tensors(llama_model_loader &) { layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - if (!hparams.is_recurrent(i)) { + if (!hparams.is_recr(i)) { // Attention layers create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); @@ -129,7 +129,7 @@ llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_p ggml_build_forward_expand(gf, cur); // Determine layer type and build appropriate attention mechanism - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // Linear attention layer (gated delta net) cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index 5defd893944..724d6140d19 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -4,7 +4,8 @@ void llama_model_qwen3vl::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 28: type = LLM_TYPE_1_7B; break; case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; case 64: type = LLM_TYPE_32B; break; diff --git a/examples/talk-llama/models/qwen3vlmoe.cpp b/examples/talk-llama/models/qwen3vlmoe.cpp index 5b77df57122..7c41592f772 100644 --- a/examples/talk-llama/models/qwen3vlmoe.cpp +++ b/examples/talk-llama/models/qwen3vlmoe.cpp @@ -5,7 +5,8 @@ void llama_model_qwen3vlmoe::load_arch_hparams(llama_model_loader & ml) { ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_30B_A3B; break; case 94: type = LLM_TYPE_235B_A22B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index bf3949a9092..a46c358fa68 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -2,7 +2,8 @@ void llama_model_refact::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_1B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index ca8e009615e..fc276ce591b 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -2,12 +2,13 @@ void llama_model_rnd1::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 48: type = LLM_TYPE_30B_A3B; break; default: type = LLM_TYPE_UNKNOWN; } + // Set non-causal attention for diffusion models hparams.causal_attn = false; } diff --git a/examples/talk-llama/models/rwkv6.cpp b/examples/talk-llama/models/rwkv6.cpp index ba2a9dfa0db..0b5013dc758 100644 --- a/examples/talk-llama/models/rwkv6.cpp +++ b/examples/talk-llama/models/rwkv6.cpp @@ -9,7 +9,7 @@ void llama_model_rwkv6::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1_6B; break; case 32: switch (hparams.n_embd) { diff --git a/examples/talk-llama/models/rwkv6qwen2.cpp b/examples/talk-llama/models/rwkv6qwen2.cpp index 566b8cdcb54..6c7db514435 100644 --- a/examples/talk-llama/models/rwkv6qwen2.cpp +++ b/examples/talk-llama/models/rwkv6qwen2.cpp @@ -9,7 +9,7 @@ void llama_model_rwkv6qwen2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1_6B; break; case 32: switch (hparams.n_embd) { diff --git a/examples/talk-llama/models/rwkv7.cpp b/examples/talk-llama/models/rwkv7.cpp index 7574b252621..67c51f5b59c 100644 --- a/examples/talk-llama/models/rwkv7.cpp +++ b/examples/talk-llama/models/rwkv7.cpp @@ -10,7 +10,7 @@ void llama_model_rwkv7::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 12: switch (hparams.n_embd) { case 768: type = LLM_TYPE_190M; break; diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index 806cba574be..57de881a091 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -2,7 +2,8 @@ void llama_model_seed_oss::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 64: type = LLM_TYPE_36B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index 4231cccc666..a8e3d957f1f 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -15,14 +15,14 @@ void llama_model_smallthinker::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; - hparams.n_no_rope_layer_step = hparams.n_layer; + hparams.n_no_rope_layer_step = hparams.n_layer(); } ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_4B; break; case 52: type = LLM_TYPE_20B; break; default: type = LLM_TYPE_UNKNOWN; diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index 90e7d473eaf..c67d967b204 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -4,7 +4,7 @@ void llama_model_smollm3::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); hparams.n_no_rope_layer_step = 4; - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 36: type = LLM_TYPE_3B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index 4da7f7aefcf..bf6087b8796 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -3,7 +3,7 @@ void llama_model_stablelm::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 32: type = LLM_TYPE_3B; break; case 40: type = LLM_TYPE_12B; break; diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index e131af058bc..f73a88fd4e9 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -2,7 +2,8 @@ void llama_model_starcoder::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 24: type = LLM_TYPE_1B; break; case 36: type = LLM_TYPE_3B; break; case 42: type = LLM_TYPE_7B; break; diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index 9c207c02885..b81b469374a 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -2,7 +2,8 @@ void llama_model_starcoder2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 30: type = LLM_TYPE_3B; break; case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_15B; break; diff --git a/examples/talk-llama/models/step35.cpp b/examples/talk-llama/models/step35.cpp index 3b68e68707a..e2218c58704 100644 --- a/examples/talk-llama/models/step35.cpp +++ b/examples/talk-llama/models/step35.cpp @@ -22,24 +22,39 @@ void llama_model_step35::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer, false); - ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer, false); - switch (hparams.n_layer) { + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer(), false); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer(), false); + + // NextN/MTP (Step3p5): extra decoder block appended beyond the main stack. + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); + + switch (hparams.n_layer()) { case 45: type = LLM_TYPE_196B_A11B; break; default: type = LLM_TYPE_UNKNOWN; } } -void llama_model_step35::load_arch_tensors(llama_model_loader &) { +void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { LLAMA_LOAD_LOCALS; + const bool mtp_only = (hparams.n_layer_nextn > 0) && (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + // Trunk-only: the GGUF declares MTP layers in metadata but the actual MTP + // tensors live in a separate file (e.g. user split target/draft). Mark + // MTP tensors NOT_REQUIRED so the trunk loads cleanly. + const std::string mtp_probe = "blk." + std::to_string(n_layer) + ".nextn.eh_proj.weight"; + const bool trunk_only = (hparams.n_layer_nextn > 0) && (ml.get_weight(mtp_probe.c_str()) == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + const int mtp_flags = trunk_only ? TENSOR_NOT_REQUIRED : 0; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, trunk_flags); // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. @@ -51,14 +66,14 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { n_rot_max = n_rot; } - for (int i = 0; i < n_layer; ++i) { + auto load_block_trunk = [&](int i, int flags) { auto & layer = layers[i]; const uint32_t n_head_l = hparams.n_head(i); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); @@ -70,13 +85,13 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); } - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, flags); // head-wise attention gate (Step35 self_attn.g_proj) layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); // dense MLP (leading dense blocks) layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); @@ -95,10 +110,86 @@ void llama_model_step35::load_arch_tensors(llama_model_loader &) { layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + }; + + auto load_block_mtp = [&](int i, bool is_first_mtp) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + // The MTP block is a full Step3p5 decoder layer (mtp_block) plus the + // NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head). + // `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only. + // + // Only the FIRST MTP block (i == n_main) is required for the + // single-block MTP runtime; trailing MTP blocks are always tolerated + // as missing so pruned GGUFs (block 0 only) load cleanly. Override + // mtp_flags to NOT_REQUIRED for those. + const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + } + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags); + + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + + // dense MLP (leading dense blocks) — present if the MTP block isn't MoE + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < n_layer; ++i) { + load_block_trunk(i, trunk_flags); + } + // Only the first MTP block (i == n_main) is required at runtime — the + // single-block-MTP graph in build_arch_graph always uses that one. + // Trailing MTP blocks are loaded if present (so an un-pruned GGUF with + // all MTP layers still works) but tolerated when absent via the pruning + // path. See scripts/prune_step35_extra_mtp.py for the pruner. + for (int i = n_layer; i < n_layer_all; ++i) { + load_block_mtp(i, /*is_first_mtp=*/ i == n_layer); } } std::unique_ptr<llm_graph_context> llama_model_step35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique<graph_mtp>(*this, params); + } return std::make_unique<graph>(*this, params); } @@ -111,6 +202,7 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para auto * inp_attn = build_attn_inp_kv_iswa(); ggml_tensor * inp_out_ids = build_inp_out_ids(); + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -198,8 +290,8 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cb(cur, "attn_proj", il); } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -257,6 +349,13 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para cur = inpL; + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + if (!cparams.embeddings_nextn_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); res->t_embd = cur; @@ -267,3 +366,192 @@ llama_model_step35::graph::graph(const llama_model & model, const llm_graph_para ggml_build_forward_expand(gf, cur); } + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Step3p5 (MoE) +llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.n_layer_nextn > 0 && "STEP35 MTP requires n_layer_nextn > 0"); + + // Single-block MTP only: always run the first trained MTP block (Qwen + // MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to + // be a much deeper refactor than this PR justifies; the trailing MTP + // blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just + // block 0) also work — see load_arch_tensors below and + // scripts/prune_step35_extra_mtp.py. + const int il = hparams.n_layer(); + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + const uint32_t n_head_l = hparams.n_head(il); + const uint32_t n_head_kv_l = hparams.n_head_kv(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + // mtp_block: full Step3p5 decoder layer (attention with optional head-wise gate, then MoE/dense FFN) + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur = build_lora_mm(layer.wq, cur, layer.wq_s); + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + cb(Qcur, "mtp_Qcur", il); + cb(Kcur, "mtp_Kcur", il); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); + + if (layer.attn_q_norm) { + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + } + if (layer.attn_k_norm) { + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + } + + const bool is_swa = hparams.is_swa(il); + ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il); + const int64_t n_rot_l = hparams.n_rot(il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "mtp_Qcur_pos", il); + cb(Kcur, "mtp_Kcur_pos", il); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k)); + ggml_tensor * attn_out = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(attn_out, "mtp_attn_out", il); + + // head-wise attention gate: sigmoid(g_proj(x)) + if (layer.wqkv_gate) { + ggml_tensor * gate = build_lora_mm(layer.wqkv_gate, cur); // [n_head_l, n_tokens] + cb(gate, "mtp_attn_gate", il); + + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "mtp_attn_gate_sigmoid", il); + + ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens); + ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens); + cb(gate_3d, "mtp_attn_gate_3d", il); + + attn_3d = ggml_mul(ctx0, attn_3d, gate_3d); + cb(attn_3d, "mtp_attn_gated_3d", il); + + attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens); + cb(attn_out, "mtp_attn_gated", il); + } + + cur = build_lora_mm(layer.wo, attn_out, layer.wo_s); + cb(cur, "mtp_attn_proj", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_inp = cur; + cur = build_norm(cur, layer.ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_ffn_norm", il); + + // FFN: dense MLP or MoE (mirrors trunk path) + if (layer.ffn_gate_inp == nullptr) { + cur = build_ffn(cur, + layer.ffn_up, layer.ffn_up_b, nullptr, + layer.ffn_gate, layer.ffn_gate_b, nullptr, + layer.ffn_down, layer.ffn_down_b, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + } else { + ggml_tensor * moe_out = build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + layer.ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "mtp_ffn_moe_out", il); + + ggml_tensor * sh_out = build_ffn(cur, + layer.ffn_up_shexp, nullptr, nullptr, + layer.ffn_gate_shexp, nullptr, nullptr, + layer.ffn_down_shexp, nullptr, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(sh_out, "mtp_ffn_shared_out", il); + + cur = ggml_add(ctx0, moe_out, sh_out); + cb(cur, "mtp_ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "STEP35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "STEP35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/t5.cpp b/examples/talk-llama/models/t5.cpp index 73e32741406..b0e3f062572 100644 --- a/examples/talk-llama/models/t5.cpp +++ b/examples/talk-llama/models/t5.cpp @@ -9,10 +9,10 @@ void llama_model_t5::load_arch_hparams(llama_model_loader & ml) { hparams.dec_start_token_id = dec_start_token_id; } - hparams.dec_n_layer = hparams.n_layer; + hparams.dec_n_layer = hparams.n_layer(); ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 6: type = LLM_TYPE_60M; break; // t5-small case 8: type = LLM_TYPE_80M; break; // flan-t5-small case 12: diff --git a/examples/talk-llama/models/talkie.cpp b/examples/talk-llama/models/talkie.cpp index 1258eeb19b6..393e8f65bf4 100644 --- a/examples/talk-llama/models/talkie.cpp +++ b/examples/talk-llama/models/talkie.cpp @@ -4,7 +4,7 @@ void llama_model_talkie::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - switch (hparams.n_layer) { + switch (hparams.n_layer()) { case 40: type = LLM_TYPE_13B; break; default: type = LLM_TYPE_UNKNOWN; } diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index d6d1c7a2e5d..3135001293a 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -2,7 +2,8 @@ void llama_model_xverse::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { + + switch (hparams.n_layer()) { case 32: type = LLM_TYPE_7B; break; case 40: type = LLM_TYPE_13B; break; case 80: type = LLM_TYPE_65B; break; From ba573929cd31ddea3c77c5dc9caae78da8117123 Mon Sep 17 00:00:00 2001 From: Christopher Albert <albert@tugraz.at> Date: Tue, 9 Jun 2026 08:34:31 +0200 Subject: [PATCH 799/831] coreml : fix --quantize crash for mlprogram format; fix --optimize-ane label (#3868) commit 8b92060 switched ct.convert() to mlprogram, but did not update the --quantize path. quantize_weights() from neural_network.quantization_utils only works with the legacy neuralnetwork format. Running with --quantize crashed with: Exception: MLModel of type mlProgram cannot be loaded just from the model spec object. It also needs the path to the weights file. Fix: pass compute_precision=ct.precision.FLOAT16 into ct.convert() when --quantize is set. This matches the original intent of nbits=16 (F16 storage) without changing the quantization scheme or model accuracy. Also fix the three boolean CLI flags (--encoder-only, --quantize, --optimize-ane) to use a _str_to_bool helper so that both --flag True and --flag False parse correctly. The type=bool form accepted "False" as True because bool("False") == True. Remove the "currently broken" label from --optimize-ane: the ANE path (WhisperANE with Conv2d attention and LayerNormANE) converts and loads correctly with both PyTorch 2.x and coremltools 9.x. --- models/convert-whisper-to-coreml.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/models/convert-whisper-to-coreml.py b/models/convert-whisper-to-coreml.py index 66827b6d420..7cf07754a89 100644 --- a/models/convert-whisper-to-coreml.py +++ b/models/convert-whisper-to-coreml.py @@ -8,10 +8,19 @@ from typing import Dict from typing import Optional from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase -from coremltools.models.neural_network.quantization_utils import quantize_weights from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions from whisper import load_model + +def _str_to_bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("true", "1", "yes"): + return True + if v.lower() in ("false", "0", "no"): + return False + raise argparse.ArgumentTypeError(f"boolean value expected, got '{v}'") + # Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues. # The Whisper implementation expects a specific behavior from # torch.nn.functional.scaled_dot_product_attention that differs between PyTorch @@ -258,11 +267,9 @@ def convert_encoder(hparams, model, quantize=False): inputs=[ct.TensorType(name="logmel_data", shape=input_shape)], outputs=[ct.TensorType(name="output")], compute_units=ct.ComputeUnit.ALL, + compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32, ) - if quantize: - model = quantize_weights(model, nbits=16) - return model def convert_decoder(hparams, model, quantize=False): @@ -283,20 +290,18 @@ def convert_decoder(hparams, model, quantize=False): ct.TensorType(name="token_data", shape=tokens_shape, dtype=int), ct.TensorType(name="audio_data", shape=audio_shape) ], + compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32, ) - if quantize: - model = quantize_weights(model, nbits=16) - return model if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3, large-v3-turbo)", required=True) - parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False) - parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False) - parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False) + parser.add_argument("--encoder-only", type=_str_to_bool, help="only convert encoder", default=False) + parser.add_argument("--quantize", type=_str_to_bool, help="quantize weights to F16", default=False) + parser.add_argument("--optimize-ane", type=_str_to_bool, help="optimize for ANE execution", default=False) args = parser.parse_args() if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large-v1", "large-v2", "large-v3", "large-v3-turbo"]: From df7638d8229a243af8a4b5a8ae557e0d74e0a0ae Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Tue, 9 Jun 2026 12:51:00 +0200 Subject: [PATCH 800/831] ci : pin github actions to commit sha's (#3865) --- .github/workflows/bindings-go.yml | 4 +- .github/workflows/bindings-ruby.yml | 4 +- .github/workflows/build-android.yml | 8 ++-- .github/workflows/build-clang.yml | 4 +- .github/workflows/build-coreml.yml | 2 +- .github/workflows/build-cpu.yml | 10 ++--- .github/workflows/build-freebsd.yml | 4 +- .github/workflows/build-gcc.yml | 6 +-- .github/workflows/build-macos.yml | 2 +- .github/workflows/build-quantize.yml | 2 +- .github/workflows/build-sanitize.yml | 2 +- .github/workflows/build-self-hosted.yml | 10 ++--- .github/workflows/build-sycl.yml | 4 +- .github/workflows/build-vad.yml | 2 +- .github/workflows/build-wasm.yml | 2 +- .github/workflows/build-windows.yml | 2 +- .github/workflows/deploy-examples-wasm.yml | 8 ++-- .github/workflows/docker.yml | 2 +- .github/workflows/examples.yml | 4 +- .github/workflows/release.yml | 46 +++++++++++----------- 20 files changed, 64 insertions(+), 64 deletions(-) diff --git a/.github/workflows/bindings-go.yml b/.github/workflows/bindings-go.yml index 44381a4b411..91f869e99cf 100644 --- a/.github/workflows/bindings-go.yml +++ b/.github/workflows/bindings-go.yml @@ -13,10 +13,10 @@ jobs: ubuntu-22: runs-on: ubuntu-22.04 steps: - - uses: actions/setup-go@v6 + - uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 with: go-version: '^1.23' - - uses: actions/checkout@v6 + - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - run: | cd bindings/go make test diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml index 0c31701a2a3..80a243e4c98 100644 --- a/.github/workflows/bindings-ruby.yml +++ b/.github/workflows/bindings-ruby.yml @@ -25,8 +25,8 @@ jobs: run: working-directory: bindings/ruby steps: - - uses: ruby/setup-ruby@v1 + - uses: ruby/setup-ruby@afeafc3d1ab54a631816aba4c914a0081c12ff2f # v1.310.0 with: ruby-version: '3.2' - - uses: actions/checkout@v6 + - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - run: rake test diff --git a/.github/workflows/build-android.yml b/.github/workflows/build-android.yml index 42673166cf3..571c35872c8 100644 --- a/.github/workflows/build-android.yml +++ b/.github/workflows/build-android.yml @@ -30,12 +30,12 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 with: path: whisper - name: Install Java - uses: actions/setup-java@v5 + uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654 # v5 with: distribution: zulu java-version: 21 @@ -59,10 +59,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: set up JDK 11 - uses: actions/setup-java@v5 + uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654 # v5 with: java-version: '11' distribution: 'temurin' diff --git a/.github/workflows/build-clang.yml b/.github/workflows/build-clang.yml index 5308164cc68..20b7fec6494 100644 --- a/.github/workflows/build-clang.yml +++ b/.github/workflows/build-clang.yml @@ -48,7 +48,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Set CCACHE_DIR run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV @@ -95,7 +95,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 diff --git a/.github/workflows/build-coreml.yml b/.github/workflows/build-coreml.yml index d383d9ae0a7..8dedd7819ed 100644 --- a/.github/workflows/build-coreml.yml +++ b/.github/workflows/build-coreml.yml @@ -31,7 +31,7 @@ jobs: steps: - name: Checkout with full history - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 with: fetch-depth: 0 diff --git a/.github/workflows/build-cpu.yml b/.github/workflows/build-cpu.yml index 9c8e0586fcb..e2b74881ea5 100644 --- a/.github/workflows/build-cpu.yml +++ b/.github/workflows/build-cpu.yml @@ -38,7 +38,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 @@ -66,7 +66,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 @@ -94,7 +94,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 @@ -122,7 +122,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 @@ -150,7 +150,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 diff --git a/.github/workflows/build-freebsd.yml b/.github/workflows/build-freebsd.yml index 847ae975e30..64e78ad62f8 100644 --- a/.github/workflows/build-freebsd.yml +++ b/.github/workflows/build-freebsd.yml @@ -33,10 +33,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Build - uses: cross-platform-actions/action@v0.27.0 + uses: cross-platform-actions/action@fe0167d8082ac584754ef3ffb567fded22642c7d # v0.27.0 with: operating_system: freebsd version: '14.2' diff --git a/.github/workflows/build-gcc.yml b/.github/workflows/build-gcc.yml index b1b04c24034..3d8b5137344 100644 --- a/.github/workflows/build-gcc.yml +++ b/.github/workflows/build-gcc.yml @@ -45,7 +45,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Set CCACHE_DIR run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV @@ -90,7 +90,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 @@ -128,7 +128,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Set CCACHE_DIR run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV diff --git a/.github/workflows/build-macos.yml b/.github/workflows/build-macos.yml index 804f8bbb642..8b209e4eec8 100644 --- a/.github/workflows/build-macos.yml +++ b/.github/workflows/build-macos.yml @@ -44,7 +44,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 diff --git a/.github/workflows/build-quantize.yml b/.github/workflows/build-quantize.yml index 69ab2c34638..1c9576af7f1 100644 --- a/.github/workflows/build-quantize.yml +++ b/.github/workflows/build-quantize.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 diff --git a/.github/workflows/build-sanitize.yml b/.github/workflows/build-sanitize.yml index 9250fe81023..e517f7bade4 100644 --- a/.github/workflows/build-sanitize.yml +++ b/.github/workflows/build-sanitize.yml @@ -39,7 +39,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 diff --git a/.github/workflows/build-self-hosted.yml b/.github/workflows/build-self-hosted.yml index 3fe131b9ba5..2286b63d6e7 100644 --- a/.github/workflows/build-self-hosted.yml +++ b/.github/workflows/build-self-hosted.yml @@ -52,7 +52,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Test id: ggml-ci @@ -66,7 +66,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Test id: ggml-ci @@ -80,7 +80,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Test id: ggml-ci @@ -94,7 +94,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Test id: ggml-ci @@ -107,7 +107,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Test id: ggml-ci diff --git a/.github/workflows/build-sycl.yml b/.github/workflows/build-sycl.yml index c76954e49cf..e5361645f1e 100644 --- a/.github/workflows/build-sycl.yml +++ b/.github/workflows/build-sycl.yml @@ -46,7 +46,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: add oneAPI to apt shell: bash @@ -105,7 +105,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: add oneAPI to apt shell: bash diff --git a/.github/workflows/build-vad.yml b/.github/workflows/build-vad.yml index 3c5ebec2026..dd0efa33efe 100644 --- a/.github/workflows/build-vad.yml +++ b/.github/workflows/build-vad.yml @@ -29,7 +29,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 diff --git a/.github/workflows/build-wasm.yml b/.github/workflows/build-wasm.yml index 45c77c0be4c..c17a44ae455 100644 --- a/.github/workflows/build-wasm.yml +++ b/.github/workflows/build-wasm.yml @@ -37,7 +37,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Setup emsdk uses: emscripten-core/setup-emsdk@6ab9eb1bda2574c4ddb79809fc9247783eaf9021 # v14 diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml index 156a57f74b6..76b7a7370ce 100644 --- a/.github/workflows/build-windows.yml +++ b/.github/workflows/build-windows.yml @@ -43,7 +43,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Setup ${{ matrix.sys }} uses: msys2/setup-msys2@cafece8e6baf9247cf9b1bf95097b0b983cc558d # v2 diff --git a/.github/workflows/deploy-examples-wasm.yml b/.github/workflows/deploy-examples-wasm.yml index e7fdae77854..55df14720b1 100644 --- a/.github/workflows/deploy-examples-wasm.yml +++ b/.github/workflows/deploy-examples-wasm.yml @@ -22,10 +22,10 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Setup Pages - uses: actions/configure-pages@v5 + uses: actions/configure-pages@983d7736d9b0ae728b81ab479565c72886d7745b # v5 - name: Setup emsdk uses: emscripten-core/setup-emsdk@6ab9eb1bda2574c4ddb79809fc9247783eaf9021 # v14 @@ -88,10 +88,10 @@ jobs: find staging -type f | sort - name: Upload artifact - uses: actions/upload-pages-artifact@v4 + uses: actions/upload-pages-artifact@7b1f4a764d45c48632c6b24a0339c27f5614fb0b # v4 with: path: ./staging - name: Deploy to GitHub Pages id: deployment - uses: actions/deploy-pages@v4 + uses: actions/deploy-pages@d6db90164ac5ed86f2b6aed7e0febac5b3c0c03e # v4 diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index b4c455b92e9..2d95e1a697f 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -27,7 +27,7 @@ jobs: steps: - name: Check out the repo - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Set up Docker Buildx uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index eaa4fe4df61..ac811712e78 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -19,7 +19,7 @@ jobs: node-version: [ 16.x, 18.x ] steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Dependencies run: | @@ -29,7 +29,7 @@ jobs: sudo apt-get install libsdl2-dev - name: Use Node.js ${{ matrix.node-version }} - uses: actions/setup-node@v6 + uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6 with: node-version: ${{ matrix.node-version }} cache: 'npm' diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c3ae9de4deb..11d47546caa 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -38,7 +38,7 @@ jobs: steps: - name: Checkout with full history - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 with: fetch-depth: 0 @@ -100,7 +100,7 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: ccache uses: ggml-org/ccache-action@v1.2.21 @@ -130,7 +130,7 @@ jobs: -C ./build/bin . - name: Upload artifacts - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: path: whisper-bin-ubuntu-${{ matrix.build }}.tar.gz name: whisper-bin-ubuntu-${{ matrix.build }}.tar.gz @@ -156,10 +156,10 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 + uses: microsoft/setup-msbuild@6fb02220983dee41ce7ae257b6f4d8f9bf5ed4ce # v2 - name: Fetch SDL2 and set SDL2_DIR if: matrix.sdl2 == 'ON' @@ -188,32 +188,32 @@ jobs: - name: Upload SDL2.dll if: matrix.sdl2 == 'ON' - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: ${{ matrix.s2arc }}_SDL2.dll path: build/bin/${{ matrix.build }}/SDL2.dll - name: Upload whisper dll - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: whisper_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/whisper.dll - name: Upload ggml dll - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: ggml_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml.dll overwrite: true - name: Upload ggml base dll - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: ggml_base_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml-base.dll - name: Upload ggml cpu dll - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: ggml_cpu_${{ matrix.arch }}.dll path: build/bin/${{ matrix.build }}/ggml-cpu.dll @@ -225,7 +225,7 @@ jobs: - name: Upload binaries if: matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: whisper-bin-${{ matrix.arch }}.zip path: whisper-bin-${{ matrix.arch }}.zip @@ -253,17 +253,17 @@ jobs: steps: - name: Clone - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v8 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 with: script: | core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 + uses: microsoft/setup-msbuild@6fb02220983dee41ce7ae257b6f4d8f9bf5ed4ce # v2 - name: Install OpenBLAS and pkgconfiglite if: matrix.blas == 'ON' @@ -310,7 +310,7 @@ jobs: - name: Upload binaries if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: whisper-blas-bin-${{ matrix.arch }}.zip path: whisper-blas-bin-${{ matrix.arch }}.zip @@ -332,7 +332,7 @@ jobs: sdl2_ver: 2.28.5 steps: - name: Clone repository - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Install Ninja id: install_ninja @@ -459,7 +459,7 @@ jobs: echo "CUDA_PATH_V12_2=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 + uses: microsoft/setup-msbuild@6fb02220983dee41ce7ae257b6f4d8f9bf5ed4ce # v2 - name: Install 7-Zip run: choco install 7zip -y @@ -516,7 +516,7 @@ jobs: - name: Upload binaries if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip path: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip @@ -531,7 +531,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Configure run: | @@ -573,7 +573,7 @@ jobs: - name: Upload artifacts if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip @@ -594,7 +594,7 @@ jobs: steps: - name: Clone id: checkout - uses: actions/checkout@v6 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 with: fetch-depth: 0 @@ -607,7 +607,7 @@ jobs: # Downloads all the artifacts from the previous jobs - name: Download artifacts id: download-artifact - uses: actions/download-artifact@v7 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7 with: path: ./artifact @@ -627,7 +627,7 @@ jobs: - name: Upload release id: upload_release - uses: actions/github-script@v3 + uses: actions/github-script@ffc2c79a5b2490bd33e0a41c1de74b877714d736 # v3 with: github-token: ${{secrets.GITHUB_TOKEN}} script: | From 782f1226c8d9c49a6c64d654bacfe15531913a6c Mon Sep 17 00:00:00 2001 From: Ruben Ortlam <rortlam@redhat.com> Date: Mon, 8 Jun 2026 10:22:44 +0200 Subject: [PATCH 801/831] cuda: reset cuda context after reading memory size (llama/23935) * cuda: reset device in get_memory function if no backend is active * also count device and host buffers * exclude hip and musa from counting and device reset * use device mutex instead of atomic * undo backend_free function move --- ggml/src/ggml-cuda/ggml-cuda.cu | 75 +++++++++++++++++++++++++++++---- 1 file changed, 66 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index f5293ad4cbb..e779a9be9e9 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -622,6 +622,18 @@ ggml_backend_cuda_context::~ggml_backend_cuda_context() { // cuda buffer +struct ggml_backend_cuda_device_context { + int device; + std::string name; + std::string description; + std::string pci_bus_id; + int op_offload_min_batch_size; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + std::mutex device_mutex; + int active_count = 0; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +}; + struct ggml_backend_cuda_buffer_context { int device; void * dev_ptr = nullptr; @@ -639,6 +651,13 @@ struct ggml_backend_cuda_buffer_context { static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + delete ctx; } @@ -791,6 +810,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr); +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size); } @@ -1490,6 +1515,12 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) { } static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + CUDA_CHECK(cudaFreeHost(buffer->context)); } @@ -1498,6 +1529,8 @@ static void * ggml_cuda_host_malloc(size_t size) { return nullptr; } + ggml_cuda_set_device(0); // cudaMallocHost can create the implicit CUDA device context, make sure that this is consistently done on device 0. + void * ptr = nullptr; cudaError_t err = cudaMallocHost((void **) &ptr, size); if (err != cudaSuccess) { @@ -1523,6 +1556,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm buffer->buft = buft; buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + return buffer; } @@ -3140,6 +3179,12 @@ static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) { static void ggml_backend_cuda_free(ggml_backend_t backend) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + delete cuda_ctx; delete backend; } @@ -4871,14 +4916,6 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) { // backend device -struct ggml_backend_cuda_device_context { - int device; - std::string name; - std::string description; - std::string pci_bus_id; - int op_offload_min_batch_size; -}; - static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; return ctx->name.c_str(); @@ -4967,6 +5004,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + std::lock_guard<std::mutex> lock(ctx->device_mutex); +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_cuda_set_device(ctx->device); CUDA_CHECK(cudaMemGetInfo(free, total)); @@ -4993,6 +5035,13 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * } #endif // defined(__linux__) +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + // If no backends or buffers are active, the cudaMemGetInfo call above lazily created a CUDA + // context that permanently consumes VRAM. Reset the device to free it. + if (ctx->active_count == 0) { + CUDA_CHECK(cudaDeviceReset()); + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) { @@ -5687,13 +5736,21 @@ ggml_backend_t ggml_backend_cuda_init(int device) { return nullptr; } + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device); + ggml_backend_t cuda_backend = new ggml_backend { /* .guid = */ ggml_backend_cuda_guid(), /* .iface = */ ggml_backend_cuda_interface, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device), + /* .device = */ dev, /* .context = */ ctx, }; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + return cuda_backend; } From fbf720dc9f3570ed98bd5e43806fbc4a53428084 Mon Sep 17 00:00:00 2001 From: Jeff Bolz <jbolz@nvidia.com> Date: Mon, 8 Jun 2026 03:40:37 -0500 Subject: [PATCH 802/831] vulkan: Use cm2 decode_vector for mul_mat_id B matrix loads (llama/23991) This allows vec4 loads of the B elements. Also increase BK to 64 when this is enabled. Neither of these alone is consistently faster, but together these give a nice speedup. In ggml-vulkan.cpp, we need to make sure the B matrix alignment and stride are multiples of 4. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 149 +++++++++++++++--- .../vulkan-shaders/mul_mm_cm2.comp | 47 +++++- .../vulkan-shaders/vulkan-shaders-gen.cpp | 25 +-- 3 files changed, 183 insertions(+), 38 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index fc9bc8fe376..2dd8cd2fbd9 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1976,6 +1976,9 @@ struct ggml_backend_vk_context { // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert. vk_pipeline_struct * prealloc_y_last_pipeline_used {}; const ggml_tensor * prealloc_y_last_tensor_used {}; + // True when prealloc_y holds the padded fp16 layout used by the coopmat2 B decode-vector callback. + // If false, then it's contiguous. + bool prealloc_y_last_decode_vector_staging {}; // Track which nodes have been used since the last sync, and whether they were written to std::vector<const ggml_tensor *> unsynced_nodes_written; @@ -3652,9 +3655,10 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { s_mmq_wg_denoms_k = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul_id - l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size }; - m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size }; - s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size }; + const uint32_t mmqid_bk = device->coopmat2_decode_vector ? 64u : 32u; + l_warptile_mmqid = { 256, 128, 128, mmqid_bk, 1, device->subgroup_size }; + m_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size }; + s_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size }; l_mmqid_wg_denoms = { 128, 128, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 }; s_mmqid_wg_denoms = { 128, 64, 1 }; @@ -8110,6 +8114,40 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& ggml_vk_sync_buffers(ctx, subctx); } +// Copy/convert tensor into a caller-defined dense layout. Destination strides +// are in output elements, not bytes. +static void ggml_vk_cpy_to_strided( + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, + const vk_subbuffer & in, const vk_subbuffer & out, + uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13) { + VK_LOG_DEBUG("ggml_vk_cpy_to_strided((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), "; + std::cerr << "dst_nb=(" << nb10 << ", " << nb11 << ", " << nb12 << ", " << nb13 << "), buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")"); + const int tensor_type_size = ggml_type_size(tensor->type); + + const uint32_t ne = ggml_nelements(tensor); + std::array<uint32_t, 3> elements; + + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + + vk_op_unary_push_constants pc = { + (uint32_t)ne, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], nb10, nb11, nb12, nb13, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }; + init_pushconst_fastdiv(pc); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); +} + static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) { switch(type) { case GGML_TYPE_Q8_1: @@ -8367,24 +8405,28 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } if (y_non_contig) { if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } @@ -8642,24 +8684,28 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } @@ -9110,12 +9156,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || !ggml_vk_dim01_contiguous(src0); - const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + // B must already be, or be convertible to, the matmul B type used by this path. + const bool y_decode_vector_supported = ctx->device->coopmat2_decode_vector && + (f16_type != GGML_TYPE_BF16 || ctx->device->coopmat2_bf16_support) && + (src1->type == GGML_TYPE_F32 || src1->type == f16_type); + // If B is copied to prealloc_y, we can choose a 4-element-aligned row stride. + const bool y_decode_vector_uses_prealloc = !ggml_vk_dim01_contiguous(src1) || src1->type != f16_type; + // Direct B reads are safe only if row starts and the original buffer offset are 4-element aligned. + const bool y_decode_vector_aligned = + (ne10 % 4 == 0) && + (y_decode_vector_uses_prealloc || get_misalign_bytes(ctx, src1) % (4 * ggml_type_size(src1->type)) == 0); + // Stage B only when decode-vector is available and direct B reads would be misaligned. + const bool y_decode_vector_staging = y_decode_vector_supported && !y_decode_vector_aligned; +#else + const bool y_decode_vector_staging = false; +#endif + const bool y_non_contig = y_decode_vector_staging || + (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || !ggml_vk_dim01_contiguous(src1); - // If src0 is BF16, try to use a BF16 x BF16 multiply - ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + const uint32_t y_staged_row_stride = y_decode_vector_staging ? (uint32_t)ggml_vk_align_size(ne10, 4) : (uint32_t)ne10; const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; @@ -9154,11 +9218,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11; const uint64_t x_ne = ggml_nelements(src0); - const uint64_t y_ne = padded_n * ne10 * ne12 * ne13; + const uint64_t y_ne = (uint64_t)y_staged_row_stride * padded_n * ne12 * ne13; const uint64_t d_ne = ggml_nelements(dst); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); - const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * ggml_nelements(src1) / ggml_blck_size(src1->type); const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); const uint64_t ids_sz = nbi2; @@ -9168,13 +9232,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& vk_pipeline to_fp16_vk_1 = nullptr; vk_pipeline to_q8_1 = nullptr; + auto make_y_staged_dst = [&]() { + ggml_tensor y_staged_dst = *src1; + y_staged_dst.type = f16_type; + y_staged_dst.nb[0] = ggml_type_size(f16_type); + y_staged_dst.nb[1] = y_staged_dst.nb[0] * y_staged_row_stride; + y_staged_dst.nb[2] = y_staged_dst.nb[1] * padded_n; + y_staged_dst.nb[3] = y_staged_dst.nb[2] * y_staged_dst.ne[2]; + return y_staged_dst; + }; + if (x_non_contig) { to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); } else { to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); } if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); + ggml_tensor y_staged_dst; + const ggml_tensor * y_staged_dst_ptr = nullptr; + if (y_decode_vector_staging) { + y_staged_dst = make_y_staged_dst(); + y_staged_dst_ptr = &y_staged_dst; + } + + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, y_staged_dst_ptr, f16_type); } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } @@ -9292,30 +9373,47 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& } if (y_non_contig) { if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging != y_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + if (y_decode_vector_staging) { + const ggml_tensor y_staged_dst = make_y_staged_dst(); + const uint32_t y_staged_dst_type_size = ggml_type_size(y_staged_dst.type); + ggml_vk_cpy_to_strided( + ctx, subctx, to_fp16_vk_1, src1, + ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), + (uint32_t)(y_staged_dst.nb[0] / y_staged_dst_type_size), + (uint32_t)(y_staged_dst.nb[1] / y_staged_dst_type_size), + (uint32_t)(y_staged_dst.nb[2] / y_staged_dst_type_size), + (uint32_t)(y_staged_dst.nb[3] / y_staged_dst_type_size)); + } else { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + } ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = y_decode_vector_staging; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } ggml_vk_sync_buffers(ctx, subctx); uint32_t stride_batch_x = ne00*ne01; - uint32_t stride_batch_y = ne10*ne11; + uint32_t stride_b_y = y_decode_vector_staging ? y_staged_row_stride : ne10; + uint32_t stride_batch_y = y_decode_vector_staging ? y_staged_row_stride * padded_n : ne10*ne11; if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); @@ -9330,7 +9428,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ctx, subctx, pipeline, { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf, - ne01, ne21, ne10, ne10, ne10, ne01, + ne01, ne21, ne10, ne10, stride_b_y, ne01, stride_batch_x, stride_batch_y, ne20*ne21, n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n ); // NOLINT @@ -9488,24 +9586,28 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } @@ -13730,7 +13832,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex ggml_vk_destroy_buffer(ctx->prealloc_y); } ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y); + ctx->prealloc_y_last_pipeline_used = nullptr; ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; } if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")"); @@ -14310,6 +14414,8 @@ static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_graph_cleanup()"); ctx->prealloc_y_last_pipeline_used = {}; + ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; ctx->unsynced_nodes_written.clear(); ctx->unsynced_nodes_read.clear(); @@ -14360,6 +14466,8 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ggml_vk_destroy_buffer(ctx->sync_staging); ctx->prealloc_y_last_pipeline_used = nullptr; + ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; ctx->prealloc_size_x = 0; ctx->prealloc_size_y = 0; @@ -15539,6 +15647,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->prealloc_y_last_pipeline_used = nullptr; ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; if (ctx->prealloc_size_add_rms_partials) { ggml_vk_preallocate_buffers(ctx, nullptr); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 250d708479b..2656fe1c3e9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -11,6 +11,9 @@ #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable #extension GL_NV_cooperative_matrix2 : enable +#ifdef GGML_VULKAN_COOPMAT2_DECODE_VECTOR +#extension GL_NV_cooperative_matrix_decode_vector : enable +#endif #extension GL_EXT_buffer_reference : enable #extension GL_KHR_shader_subgroup_ballot : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -69,10 +72,13 @@ layout (push_constant) uniform parameter layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#if defined(MUL_MAT_ID) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) +layout (binding = 1) readonly buffer B4 {B_TYPEV4 data_b_v4[];}; +#endif #if QUANT_K > 1 #include "dequant_funcs_cm2.glsl" -#if defined(dequantFuncA_v) && defined(GL_NV_cooperative_matrix_decode_vector) +#if defined(dequantFuncA_v) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) #define DECODEFUNCA , dequantFuncA, dequantFuncA_v #else #define DECODEFUNCA , dequantFuncA @@ -113,11 +119,33 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i const uint row_i = blockCoords[0]; const u16vec4 row_idx = row_ids[row_i]; - B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) + // The decode-vector path gives B a K-dimension tensor-layout block size of BK. + const uint k = blockCoords[1] * BK + coordInBlock[1]; +#else + const uint k = blockCoords[1]; +#endif + B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + k]; return ret; } +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) +B_TYPEV4 decodeFuncB_v(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint row_i = blockCoords[0]; + + const u16vec4 row_idx = row_ids[row_i]; + const uint k = blockCoords[1] * BK + coordInBlock[1]; + const uint base = row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + k; + + return data_b_v4[base >> 2]; +} +#define DECODEFUNCB , decodeFuncB, decodeFuncB_v +#else +#define DECODEFUNCB , decodeFuncB +#endif + D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic) { uint dr = ir * BM + r; @@ -287,6 +315,9 @@ void main() { tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K); #endif +#if defined(MUL_MAT_ID) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) + tensorLayoutB = setTensorLayoutBlockSizeNV(tensorLayoutB, 1, BK); +#endif // Use end_k rather than p.K as the dimension because that's what // we need to bound check against when using split_k. @@ -499,7 +530,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } else { @@ -507,7 +538,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } @@ -543,7 +574,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } else { @@ -551,7 +582,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } @@ -588,7 +619,7 @@ void main() { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose DECODEFUNCB); #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif @@ -600,7 +631,7 @@ void main() { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose DECODEFUNCB); #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index d65cd12b287..8fc00362870 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -457,6 +457,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c if (coopmat) { base_dict["COOPMAT"] = "1"; } +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + if (coopmat2) { + base_dict["GGML_VULKAN_COOPMAT2_DECODE_VECTOR"] = "1"; + } +#endif const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; @@ -523,11 +528,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); // bf16 { @@ -548,8 +553,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c if (!(coopmat || coopmat2)) #endif { - string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } } @@ -579,13 +584,13 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) From 490e50056c2a96f1ebd0cabca43b031ab865dc44 Mon Sep 17 00:00:00 2001 From: Nikhil Jain <nikhil.jain0987@gmail.com> Date: Mon, 8 Jun 2026 08:07:15 -0700 Subject: [PATCH 803/831] Implement 2D workgroups for scale, binary, and unary ops (llama/24044) * Only run webgpu CI on my fork * Add webgpu only workflow * Implement 2d workgroups for more operations * fix * Fix type * Move back to global_invocation_id --- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 32 ++++++++++++------- ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl | 13 +++++--- ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl | 10 +++--- ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl | 11 ++++--- 4 files changed, 41 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index c6cfb0bbbad..94a108dfa77 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -621,10 +621,11 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, uint32_t value, size_t offset, size_t size) { - std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value }; - std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; - size_t bytes_per_wg = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread; - uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); + std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value }; + std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; + size_t bytes_per_wg = + ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread; + uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); @@ -1362,7 +1363,7 @@ static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, shader_lib_ctx.src0 = src; shader_lib_ctx.src1 = nullptr; shader_lib_ctx.dst = dst; - shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); @@ -2169,8 +2170,10 @@ static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, @@ -2244,8 +2247,10 @@ static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, } } - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_add_id(webgpu_context & ctx, @@ -2673,8 +2678,10 @@ static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * s entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, @@ -3751,7 +3758,8 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) { static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { // we use the maximum workgroup size for the memset pipeline - size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * + ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; // Size the bytes_per_thread so that the largest buffer size can be handled ctx->capabilities.memset_bytes_per_thread = CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl index 605de7aa7be..f262c4a8f6a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl @@ -130,10 +130,13 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32) { } @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x < params.ne) { - let src0_i = params.offset_src0 + src0_index(gid.x); - let src1_i = params.offset_src1 + src1_index(gid.x); - update(params.offset_dst + gid.x, src0_i, src1_i); +fn main(@builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>) { + let threads_per_group = u32(WG_SIZE); + let i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (i < params.ne) { + let src0_i = params.offset_src0 + src0_index(i); + let src1_i = params.offset_src1 + src1_index(i); + update(params.offset_dst + i, src0_i, src1_i); } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl index 3b70a876d70..6c76ed69e45 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl @@ -43,12 +43,14 @@ struct Params { var<storage, read_write> src: array<f32>; @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x >= params.ne) { +fn main( + @builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>) { + let threads_per_group = u32(WG_SIZE); + var i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (i >= params.ne) { return; } - - var i = gid.x; let i3 = i / (params.ne2 * params.ne1 * params.ne0); i = i % (params.ne2 * params.ne1 * params.ne0); let i2 = i / (params.ne1 * params.ne0); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl index 8e34e1c9ca0..cb342c47263 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -66,11 +66,14 @@ fn erf_approx(x: TYPE) -> TYPE { } @compute @workgroup_size(WG_SIZE) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x >= params.ne) { +fn main(@builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>) { + let threads_per_group = u32(WG_SIZE); + let flat_i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (flat_i >= params.ne) { return; } - var i = gid.x; + var i = flat_i; let ne2 = params.ne2; #ifdef DIAG let ne1 = params.ne0; @@ -205,6 +208,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { #ifdef INPLACE src[params.offset_src + src_idx] = res; #else - dst[params.offset_dst + gid.x] = res; + dst[params.offset_dst + flat_i] = res; #endif } From 15e5d401d18dae5968d98ef54241317d5b8bab33 Mon Sep 17 00:00:00 2001 From: Nikhil Jain <nikhil.jain0987@gmail.com> Date: Mon, 8 Jun 2026 08:07:31 -0700 Subject: [PATCH 804/831] Handle buffer overlap / buffer aliasing for concat operator (llama/24000) * Only run webgpu CI on my fork * Add webgpu only workflow * handle buffer overlap case for concat operator * restore build-webgpu.yml Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Run clang-format * Update ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: Reese Levine <reeselevine1@gmail.com> --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 17 ++++- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 75 ++++++++++++------- ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl | 20 ++++- 3 files changed, 79 insertions(+), 33 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index a5e7de785b4..c75a98a8dd4 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -448,15 +448,19 @@ struct ggml_webgpu_upscale_pipeline_key_hash { /** Concat **/ struct ggml_webgpu_concat_pipeline_key { - int type; + int type; + bool src_overlap; - bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; } + bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { + return type == other.type && src_overlap == other.src_overlap; + } }; struct ggml_webgpu_concat_pipeline_key_hash { size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const { size_t seed = 0; ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.src_overlap); return seed; } }; @@ -2634,6 +2638,7 @@ class ggml_webgpu_shader_lib { webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_concat_pipeline_key key = {}; key.type = context.dst->type; + key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1); auto it = concat_pipelines.find(key); if (it != concat_pipelines.end()) { @@ -2656,11 +2661,17 @@ class ggml_webgpu_shader_lib { GGML_ABORT("Unsupported type for concat shader"); } + if (key.src_overlap) { + defines.push_back("SRC_OVERLAP"); + variant += "_src_overlap"; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); auto processed = preprocessor.preprocess(wgsl_concat, defines); - auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + auto decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>(); decisions->wg_size = context.max_wg_size; + decisions->src_overlap = key.src_overlap; webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); pipeline.context = decisions; concat_pipelines[key] = pipeline; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 94a108dfa77..79d5138029d 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2310,33 +2310,6 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, uint32_t ne = (uint32_t) ggml_nelements(dst); uint32_t dim = (uint32_t) dst->op_params[0]; - std::vector<uint32_t> params = { - ne, - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), - (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), - (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), - (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), - (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), - (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), - (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), - (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), - (uint32_t) dst->ne[0], - (uint32_t) dst->ne[1], - (uint32_t) dst->ne[2], - (uint32_t) dst->ne[3], - dim, - (uint32_t) src0->ne[dim] - }; - - std::vector<wgpu::BindGroupEntry> entries = { - ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), - ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), - }; - ggml_webgpu_shader_lib_context shader_lib_ctx = {}; shader_lib_ctx.src0 = src0; shader_lib_ctx.src1 = src1; @@ -2344,8 +2317,52 @@ static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); - auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); - uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + auto * decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get()); + + uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)); + uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)); + size_t merged_offset = 0; + size_t merged_size = 0; + if (decisions->src_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 }); + merged_offset = merged_range.offset; + merged_size = merged_range.size; + offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range); + offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range); + } + + std::vector<uint32_t> params = { ne, + offset_src0, + offset_src1, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + dim, + (uint32_t) src0->ne[dim] }; + + std::vector<wgpu::BindGroupEntry> entries = {}; + if (decisions->src_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); + } + + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl index a22d245d2cc..eb901bf0547 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl @@ -31,6 +31,16 @@ struct Params { #define DataType i32 #endif +#ifdef SRC_OVERLAP +@group(0) @binding(0) +var<storage, read_write> merged_src: array<DataType>; + +@group(0) @binding(1) +var<storage, read_write> dst: array<DataType>; + +@group(0) @binding(2) +var<uniform> params: Params; +#else @group(0) @binding(0) var<storage, read_write> src0: array<DataType>; @@ -42,7 +52,7 @@ var<storage, read_write> dst: array<DataType>; @group(0) @binding(3) var<uniform> params: Params; - +#endif @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { @@ -62,14 +72,22 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { ni[1] * params.stride_src0_1 + ni[2] * params.stride_src0_2 + ni[3] * params.stride_src0_3; +#ifdef SRC_OVERLAP + dst[params.offset_dst + gid.x] = merged_src[params.offset_src0 + src_i]; +#else dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i]; +#endif } else { ni[params.dim] -= params.src0_nedim; let src_i = ni[0] * params.stride_src1_0 + ni[1] * params.stride_src1_1 + ni[2] * params.stride_src1_2 + ni[3] * params.stride_src1_3; +#ifdef SRC_OVERLAP + dst[params.offset_dst + gid.x] = merged_src[params.offset_src1 + src_i]; +#else dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i]; +#endif } } } From aa42b48312f28f31a84840d51bcc783380b00d03 Mon Sep 17 00:00:00 2001 From: Masashi Yoshimura <yoshimura.masashi.frbs@gmail.com> Date: Tue, 9 Jun 2026 07:19:56 +0900 Subject: [PATCH 805/831] ggml-webgpu: Improve prefill speeds for k-quants + refactor matmul for Q4/Q5/Q8 and k-quants (llama/24225) * ggml-webgpu: Improve prefill speeds + refactor matmul for quants * Fixes for editroconfig checker --- .../wgsl-shaders/mul_mat_decls.tmpl | 810 ++++++------------ 1 file changed, 267 insertions(+), 543 deletions(-) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 72991504dd0..ed4a6b13bbf 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -98,72 +98,50 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } #endif // INIT_SRC0_SHMEM_Q1_0 -#ifdef INIT_SRC0_SHMEM_Q4_0 +#if defined(INIT_SRC0_SHMEM_Q4_0) || defined(INIT_SRC0_SHMEM_Q4_1) || defined(INIT_SRC0_SHMEM_Q5_0) || defined(INIT_SRC0_SHMEM_Q5_1) || defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) || defined(INIT_SRC0_SHMEM_MXFP4) const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 18u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; +#if defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) +const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q +#else const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q +#endif const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; + let block_idx = i / BLOCK_SIZE; let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; + let shmem_idx = block_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - let tile_m = blck_idx / BLOCKS_K; + let tile_m = block_idx / BLOCKS_K; let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; + let block_k = block_idx % BLOCKS_K; let global_block_k = k_outer / BLOCK_SIZE + block_k; if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + +#ifdef INIT_SRC0_SHMEM_Q4_0 + let block_byte_base = src0_idx * 18u; // BLOCK_SIZE_BYTES = 18u; let d = load_f16_at_src0(block_byte_base); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { - let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } - } - } -} -#endif // INIT_SRC0_SHMEM_Q4_0 +#elif INIT_SRC0_SHMEM_Q4_1 + let block_byte_base = src0_idx * 20u; // BLOCK_SIZE_BYTES = 20u; + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); -#ifdef INIT_SRC0_SHMEM_Q4_1 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 20u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at_src0(block_byte_base); - let m = load_f16_at_src0(block_byte_base + 2u); - - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { - let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); @@ -175,41 +153,13 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q4_1 - -#ifdef INIT_SRC0_SHMEM_Q5_0 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 22u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -// tile_k is defined as 32u, so blocks_k ends up being 1 always -override BLOCKS_K = TILE_K / BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; +#elif INIT_SRC0_SHMEM_Q5_0 + let block_byte_base = src0_idx * 22u; // BLOCK_SIZE_BYTES = 22u; let d = load_f16_at_src0(block_byte_base); let qh_packed = load_u32_at_src0(block_byte_base + 2u); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 6u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); @@ -226,44 +176,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q5_0 - -#ifdef INIT_SRC0_SHMEM_Q5_1 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 24u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K / BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; +#elif INIT_SRC0_SHMEM_Q5_1 + let block_byte_base = src0_idx * 24u; // BLOCK_SIZE_BYTES = 24u; - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let d = load_f16_at_src0(block_byte_base); - let m = load_f16_at_src0(block_byte_base + 2u); - let qh_packed = load_u32_at_src0(block_byte_base + 4u); + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); + let qh_packed = load_u32_at_src0_aligned(block_byte_base + 4u); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 8u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; - let q_packed = load_u32_at_src0(q_byte_offset); + let q_packed = load_u32_at_src0_aligned(q_byte_offset); for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte(q_packed, k); @@ -277,236 +201,73 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q5_1 - -#ifdef INIT_SRC0_SHMEM_Q8_0 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 34u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; +#elif INIT_SRC0_SHMEM_Q8_0 + let block_byte_base = src0_idx * 34u; // BLOCK_SIZE_BYTES = 34u; let d = load_f16_at_src0(block_byte_base); - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } - } - } -} -#endif // INIT_SRC0_SHMEM_Q8_0 +#elif INIT_SRC0_SHMEM_Q8_1 + let block_byte_base = src0_idx * 36u; // BLOCK_SIZE_BYTES = 36u; + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); -#ifdef INIT_SRC0_SHMEM_Q8_1 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 36u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at_src0(block_byte_base); - let m = load_f16_at_src0(block_byte_base + 2u); - - // store NQ(16) weights + // load NQ(16) weights for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; let q_packed = load_u32_at_src0(q_byte_offset); for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d + m; shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val; } } - } - } -} -#endif // INIT_SRC0_SHMEM_Q8_1 - -#ifdef INIT_SRC0_SHMEM_Q2_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 84u; - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - // Use standard thread layout instead of lane/row_group - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; +#elif INIT_SRC0_SHMEM_MXFP4 + let block_byte_base = src0_idx * 17u; + let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u); + let e = ldexp(1.0, i32(eu8) - 128); - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; + // load NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; + let q_packed = load_u32_at_src0(q_byte_offset); + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e; + let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo); + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi); + } + } +#endif } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let d = load_f16_at_src0(block_byte_base + 80u); - let dmin = load_f16_at_src0(block_byte_base + 82u); - - // Decode the element at position k_in_block - let block_of_32 = k_in_block / 32u; - let pos_in_32 = k_in_block % 32u; - - let q_b_idx = (block_of_32 / 4u) * 32u; - let shift = (block_of_32 % 4u) * 2u; - let k = (pos_in_32 / 16u) * 16u; - let l = pos_in_32 % 16u; - - let is = k_in_block / 16u; - - let sc_packed = load_u32_at_src0(block_byte_base + 4u * (is / 4u)); - let sc = get_byte(sc_packed, is % 4u); - - let dl = d * f16(sc & 0xFu); - let ml = dmin * f16(sc >> 4u); - - let q_idx = q_b_idx + k + l; - let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); - let q_byte = get_byte(q_packed, q_idx % 4u); - let qs_val = (q_byte >> shift) & 3u; - - let q_val = f16(qs_val) * dl - ml; - shmem[elem_idx] = q_val; } } -#endif // INIT_SRC0_SHMEM_Q2_K +#endif -#ifdef INIT_SRC0_SHMEM_Q3_K +// k-quants +#if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K) const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 110u; - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; - - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; +const NQ = 4u; - let d = load_f16_at_src0(block_byte_base + 108u); - - // Load and unpack scales - let kmask1: u32 = 0x03030303u; - let kmask2: u32 = 0x0f0f0f0fu; - - var scale_vals: array<u32, 4>; - for (var i: u32 = 0u; i < 4u; i++) { - scale_vals[i] = load_u32_at_src0(block_byte_base + 96u + 4u * i); - } - - var tmp: u32 = scale_vals[2]; - scale_vals[2] = ((scale_vals[0] >> 4u) & kmask2) | (((tmp >> 4u) & kmask1) << 4u); - scale_vals[3] = ((scale_vals[1] >> 4u) & kmask2) | (((tmp >> 6u) & kmask1) << 4u); - scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4u); - scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2u) & kmask1) << 4u); - - // Load hmask and qs arrays - var hmask_vals: array<u32, 8>; - for (var i: u32 = 0u; i < 8u; i++) { - hmask_vals[i] = load_u32_at_src0(block_byte_base + 4u * i); - } - - var qs_vals: array<u32, 16>; - for (var i: u32 = 0u; i < 16u; i++) { - qs_vals[i] = load_u32_at_src0(block_byte_base + 32u + 4u * i); - } - - let half = k_in_block / 128u; // 0 or 1 - let pos_in_half = k_in_block % 128u; // 0-127 - let shift_group = pos_in_half / 32u; // 0-3 - let pos_in_32 = pos_in_half % 32u; // 0-31 - let k_group = pos_in_32 / 16u; // 0 or 1 - let l = pos_in_32 % 16u; // 0-15 - - let q_b_idx = half * 32u; // 0 or 32 - let shift = shift_group * 2u; // 0, 2, 4, 6 - let k = k_group * 16u; // 0 or 16 - let is = k_in_block / 16u; // 0-15 - - // m increments every 32 elements across entire 256 element block - let m_shift = k_in_block / 32u; // 0-7 - let m: u32 = 1u << m_shift; // 1,2,4,8,16,32,64,128 - - let sc = get_byte(scale_vals[is / 4u], is % 4u); - let dl = d * (f16(sc) - 32.0); - - let q_idx = q_b_idx + k + l; - let hm_idx = k + l; - - let q_byte = get_byte(qs_vals[q_idx / 4u], q_idx % 4u); - let hmask_byte = get_byte(hmask_vals[hm_idx / 4u], hm_idx % 4u); - - let hm = select(4.0, 0.0, (hmask_byte & m) != 0); - let qs_val = (q_byte >> shift) & 3u; - - let q_val = (f16(qs_val) - f16(hm)) * dl; - shmem[elem_idx] = q_val; - } +fn store_shmem_kquants(val: vec4<f16>, idx: u32) { + shmem[idx] = val.x; + shmem[idx + 1] = val.y; + shmem[idx + 2] = val.z; + shmem[idx + 3] = val.w; } -#endif // INIT_SRC0_SHMEM_Q3_K - -#ifdef INIT_SRC0_SHMEM_Q4_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 144u; +fn load_byte_at_src0_aligned(byte_offset: u32) -> u32 { + return get_byte(load_u32_at_src0_aligned(byte_offset), byte_offset % 4u); +} fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * NQ) { let tile_m = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; @@ -514,224 +275,232 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let global_k = k_outer + tile_k; if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); + store_shmem_kquants(vec4<f16>(f16(0.0), f16(0.0), f16(0.0), f16(0.0)), elem_idx); continue; } - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 4 == 0; let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at_src0(block_byte_base); - let dmin = load_f16_at_src0(block_byte_base + 2u); - - // Map k_in_block to loop structure: - // Outer loop over 64-element groups (alternating q_b_idx) - // Inner loop over 2 shifts per group - let group_of_64 = k_in_block / 64u; // 0-3 (maps to q_b_idx) - let pos_in_64 = k_in_block % 64u; // 0-63 - let shift_group = pos_in_64 / 32u; // 0 or 1 - let l = pos_in_64 % 32u; // 0-31 - - let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96 - let shift = shift_group * 4u; // 0 or 4 - let is = k_in_block / 32u; // 0-7 +#ifdef INIT_SRC0_SHMEM_Q2_K + let block_byte_base = src0_idx * 84u; // BLOCK_SIZE_BYTES = 84u; + let scales_byte_base = block_byte_base; + let qs_byte_base = block_byte_base + 16u; + let dm_byte_base = block_byte_base + 80u; + + let d_packed = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(d_packed[0]); + let dmin = f16(d_packed[1]); + + let chunk = k_in_block / 128u; + let pos_in_chunk = k_in_block % 32u; + let sub_block = k_in_block / 16u; + let shift_phase = (k_in_block % 128u) / 32u; + + // whole 2 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_vec4 = vec4<f16>( + f16((qs_word >> (2u * shift_phase + 0u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 8u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 16u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 24u)) & 0x3u), + ); + + let scale = load_byte_at_src0_aligned(scales_byte_base + sub_block); + + let dl = d * f16(scale & 0xFu); + let ml = dmin * f16(scale >> 4u); + + store_shmem_kquants(qs_vec4 * dl - ml, elem_idx); +#elif INIT_SRC0_SHMEM_Q3_K + let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u; + let hmask_byte_base = block_byte_base + 0u; + let qs_byte_base = block_byte_base + 32u; + let scales_byte_base = block_byte_base + 96u; + + let d_all = load_f16_at_src0(block_byte_base + 108u); + + let chunk = k_in_block / 128u; + let pos_in_chunk = k_in_block % 32u; + let sub_block = k_in_block / 16u; + let shift_phase = (k_in_block % 128u) / 32u; + + let hmask_block = pos_in_chunk; + let hmask_shift_phase = k_in_block / 32u; + + // low 2 bits (4 elems) + let q_lo2_word = load_u32_at_src0(qs_byte_base + 32u * chunk + 1u * hmask_block); + let q_lo2_vec4 = vec4<f16>( + f16((q_lo2_word >> (2u * shift_phase + 0u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 8u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 16u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 24u)) & 3u) + ); + + // high 1 bit (4 elems) + let q_hi1_word = load_u32_at_src0(hmask_byte_base + pos_in_chunk); + let q_hi1_vec4 = vec4<f16>( + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 0u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 8u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 16u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 24u)) & 1u) == 1u)) + ); + + let q_vec4 = q_lo2_vec4 - q_hi1_vec4; + + let scale_low4 = (load_byte_at_src0_aligned(scales_byte_base + (sub_block % 8u)) >> (4u * (sub_block / 8u))) & 0xFu; + let scale_hi2 = (load_byte_at_src0_aligned(scales_byte_base + 8u + (sub_block % 4u)) >> (2u * (sub_block / 4u))) & 3u; + let dl = d_all * (f16((scale_hi2 << 4u) | scale_low4) - 32.0); + + store_shmem_kquants(dl * q_vec4, elem_idx); +#elif INIT_SRC0_SHMEM_Q4_K + let block_byte_base = src0_idx * 144u; // BLOCK_SIZE_BYTES = 144u; + let dm_byte_base = block_byte_base + 0u; + let scale_byte_base = block_byte_base + 4u; + let qs_byte_base = block_byte_base + 16u; + + let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(dm[0]); + let dmin = f16(dm[1]); + + let chunk = k_in_block / 64u; + let pos_in_chunk = (k_in_block % 64u) % 32u; + let sub_block = k_in_block / 32u; + let shift_phase = sub_block & 1u; + + // whole 4 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_vec4 = vec4<f16>( + f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu) + ); var sc: u32; var mn: u32; - let scale_base = block_byte_base + 4u; - - if (is < 4u) { - let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); - let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - sc = sc_byte & 63u; - mn = min_byte & 63u; + if (sub_block < 4u) { + let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u); + let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); - let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); - let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - - sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); - mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); } let dl = d * f16(sc); let ml = dmin * f16(mn); - let q_idx = q_b_idx + l; - let q_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (q_idx / 4u)); - - let q_byte = get_byte(q_packed, q_idx % 4u); - let qs_val = (q_byte >> shift) & 0xFu; - - let q_val = f16(qs_val) * dl - ml; - shmem[elem_idx] = q_val; - } -} -#endif // INIT_SRC0_SHMEM_Q4_K - -#ifdef INIT_SRC0_SHMEM_Q5_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 176u; - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; - - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let d = load_f16_at_src0(block_byte_base); - let dmin = load_f16_at_src0(block_byte_base + 2u); - - - // The original loop processes elements in groups of 64 - // Each group of 64: q_b_idx cycles through [0,32,64,96], shift cycles [0,4] - // But u increments EVERY 32 elements (after each l loop) - let group_of_64 = k_in_block / 64u; // 0-3 - let pos_in_64 = k_in_block % 64u; // 0-63 - let shift_group = pos_in_64 / 32u; // 0 or 1 - let l = pos_in_64 % 32u; // 0-31 - - let q_b_idx = group_of_64 * 32u; // 0, 32, 64, 96 - let shift = shift_group * 4u; // 0 or 4 - let is = k_in_block / 32u; // 0-7 - - // u increments every 32 elements (0->1, 1->2, 2->4, 3->8, 4->16, 5->32, 6->64, 7->128) - let u_shift = k_in_block / 32u; // 0-7 - let u: u32 = 1u << u_shift; + store_shmem_kquants(dl * qs_vec4 - vec4(ml, ml, ml, ml), elem_idx); +#elif INIT_SRC0_SHMEM_Q5_K + let block_byte_base = src0_idx * 176u; // BLOCK_SIZE_BYTES = 176u; + let dm_byte_base = block_byte_base + 0u; + let scale_byte_base = block_byte_base + 4u; + let qh_byte_base = block_byte_base + 16u; + let qs_byte_base = block_byte_base + 48u; + + let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(dm[0]); + let dmin = f16(dm[1]); + + let chunk = k_in_block / 64u; + let pos_in_chunk = (k_in_block % 64u) % 32u; + let sub_block = k_in_block / 32u; + let shift_phase = sub_block & 1u; + + let qh_block = k_in_block % 32u; + let qh_shift_phase = sub_block; + + // low 4 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_lo4_vec4 = vec4<f16>( + f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu) + ); + + // high 1 bit (4 elems) + let qh_word = load_u32_at_src0_aligned(qh_byte_base + qh_block); + let qh_vec4 = vec4<f16>( + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 0u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 8u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 16u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 24u)) & 1u) == 1u)) + ); var sc: u32; var mn: u32; - let scale_base = block_byte_base + 4u; - - if (is < 4u) { - let sc_byte = get_byte(load_u32_at_src0(scale_base), is % 4u); - let min_byte = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - sc = sc_byte & 63u; - mn = min_byte & 63u; + if (sub_block < 4u) { + let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u); + let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; } else { - let sc_min_lo = get_byte(load_u32_at_src0(scale_base + 8), (is + 4u) % 4u); - let sc_hi = get_byte(load_u32_at_src0(scale_base), (is - 4u) % 4u); - let min_hi = get_byte(load_u32_at_src0(scale_base + 4), is % 4u); - - sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); - mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); } let dl = d * f16(sc); let ml = dmin * f16(mn); - let q_idx = q_b_idx + l; - let q_packed = load_u32_at_src0(block_byte_base + 48u + 4u * (q_idx / 4u)); - - let q_byte = get_byte(q_packed, q_idx % 4u); - - let qh_packed = load_u32_at_src0(block_byte_base + 16u + 4u * (l / 4u)); - - let qh_byte = get_byte(qh_packed, l % 4u); - - let qs_val = (q_byte >> shift) & 0xFu; - let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); - - let q_val = (f16(qs_val) + f16(qh_val)) * dl - ml; - shmem[elem_idx] = q_val; - } -} - -#endif // INIT_SRC0_SHMEM_Q5_K - -#ifdef INIT_SRC0_SHMEM_Q6_K -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 210u; - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; - - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let half = k_in_block / 128u; - let pos_in_half = k_in_block % 128u; - let quarter = pos_in_half / 32u; - let l = pos_in_half % 32u; - - let ql_b_idx = half * 64u; - let qh_b_idx = half * 32u; - let sc_b_idx = half * 8u; - - // Load only ql13 word needed - let ql13_flat = ql_b_idx + l; - let ql13 = load_u32_at_src0(block_byte_base + ql13_flat); - let ql13_b = get_byte(ql13, 0u); - - // Load only ql24 word needed - let ql24_flat = ql_b_idx + l + 32u; - let ql24 = load_u32_at_src0(block_byte_base + ql24_flat); - let ql24_b = get_byte(ql24, 0u); - - // Load only qh word needed - let qh_flat = qh_b_idx + l; - let qh = load_u32_at_src0(block_byte_base + 128u + qh_flat); - let qh_b = get_byte(qh, 0u); - - let q1 = f16((ql13_b & 0xFu) | ((qh_b & 3u) << 4u)) - f16(32.0); - let q2 = f16((ql24_b & 0xFu) | (((qh_b >> 2u) & 3u) << 4u)) - f16(32.0); - let q3 = f16((ql13_b >> 4u) | (((qh_b >> 4u) & 3u) << 4u)) - f16(32.0); - let q4 = f16((ql24_b >> 4u) | (((qh_b >> 6u) & 3u) << 4u)) - f16(32.0); - - // Load only the scale word needed - let is = l / 16u; - let sc_idx = sc_b_idx + is + quarter * 2u; - let sc = load_u32_at_src0(block_byte_base + 192u + sc_idx); - let sc_val = get_byte_i32(sc, 0u); - - let d = load_f16_at_src0(block_byte_base + 208u); - - var q_val: f16; - if (quarter == 0u) { - q_val = q1; - } else if (quarter == 1u) { - q_val = q2; - } else if (quarter == 2u) { - q_val = q3; - } else { - q_val = q4; - } - - shmem[elem_idx] = d * f16(sc_val) * q_val; + store_shmem_kquants((qh_vec4 + qs_lo4_vec4) * dl - vec4<f16>(ml, ml, ml, ml), elem_idx); +#elif INIT_SRC0_SHMEM_Q6_K + let block_byte_base = src0_idx * 210u; // BLOCK_SIZE_BYTES = 210u; + let ql_byte_base = block_byte_base; + let qh_byte_base = block_byte_base + 128u; + let scales_byte_base = block_byte_base + 192u; + let d_byte_base = block_byte_base + 208u; + + let d = load_f16_at_src0(d_byte_base); + + let chunk = k_in_block / 128u; + let ql_pos_in_chunk = (k_in_block % 128u) % 64u; + let qh_pos_in_chunk = (k_in_block % 128u) % 32u; + let sub_block = k_in_block / 16u; + let ql_shift_phase = (k_in_block % 128u) / 64u; + let qh_shift_phase = (k_in_block % 128u) / 32u; + + // low 4 bits (4 elems) + let ql_word = load_u32_at_src0(ql_byte_base + 64u * chunk + 1u * ql_pos_in_chunk); + let ql_lo4_vec4 = vec4<u32>( + (ql_word >> (4u * ql_shift_phase + 0u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 8u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 16u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 24u)) & 0xFu + ); + + // hi 2 bits (4 elems) + let qh_word = load_u32_at_src0(qh_byte_base + 32u * chunk + 1u * qh_pos_in_chunk); + let qh_hi2_vec4 = vec4<u32>( + ((qh_word >> (2u * qh_shift_phase + 0u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 8u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 16u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 24u)) & 0x3u) << 4u, + ); + + let q_vec4 = vec4<f16>(qh_hi2_vec4 | ql_lo4_vec4) - vec4<f16>(32.0, 32.0, 32.0, 32.0); + + let scale_byte = scales_byte_base + 1u * sub_block; + let scale_word = load_u32_at_src0_aligned(scale_byte); + let scale = get_byte_i32(scale_word, scale_byte & 3u); + + store_shmem_kquants(d * q_vec4 * f16(scale), elem_idx); +#endif } } -#endif // INIT_SRC0_SHMEM_Q6_K +#endif // k-quants #ifdef INIT_SRC0_SHMEM_IQ4_NL const BLOCK_SIZE = 32u; @@ -1155,48 +924,3 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } #endif // INIT_SRC0_SHMEM_IQ3_S - -#ifdef INIT_SRC0_SHMEM_MXFP4 -const BLOCK_SIZE = 32u; -const BLOCK_SIZE_BYTES = 17u; -// the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. -override BLOCKS_K = TILE_K/BLOCK_SIZE; -const NQ = 16u; -const BYTES_PER_THREAD = 8u; // NQ(16) weights uses 8 bytes of q -const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) - -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / NQ; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - - let tile_m = blck_idx / BLOCKS_K; - let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_block_k = k_outer / BLOCK_SIZE + block_k; - - if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); - let e = ldexp(1.0, i32(eu8) - 128); - - // store NQ(16) weights - for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { - - let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; - let q_packed = load_u32_at_src0(q_byte_offset); - - for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e; - let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e; - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo); - shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi); - } - } - } - } -} -#endif // INIT_SRC0_SHMEM_MXFP4 From e69e5138fe7a1737d3d37ec52903bb050a09a0eb Mon Sep 17 00:00:00 2001 From: Reese Levine <reeselevine1@gmail.com> Date: Mon, 8 Jun 2026 20:54:24 -0700 Subject: [PATCH 806/831] ggml-webgpu: Add clang-format job (llama/24308) * Add clang-format job * try local formatting --- .../ggml-webgpu/ggml-webgpu-shader-lib.hpp | 19 +++++++++++-------- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 6 +++--- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index c75a98a8dd4..6f877f15ce9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -644,7 +644,8 @@ inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) { inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) { const uint32_t offset_elems = - (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / ggml_type_size(K->type)); + (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / + ggml_type_size(K->type)); return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u; } @@ -655,8 +656,10 @@ inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); } -inline bool ggml_webgpu_flash_attn_kv_direct( - const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, uint32_t kv_direct_align) { +inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V, + uint32_t kv_direct_align) { return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); } @@ -671,10 +674,10 @@ inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_co key.dst_type = context.dst->type; key.head_dim_qk = (uint32_t) context.src0->ne[0]; key.head_dim_v = (uint32_t) context.src2->ne[0]; - key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align); - key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); - key.has_mask = context.src3 != nullptr; - key.has_sinks = context.src4 != nullptr; + key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align); + key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); + key.has_mask = context.src3 != nullptr; + key.has_sinks = context.src4 != nullptr; key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; return key; } @@ -1727,7 +1730,7 @@ class ggml_webgpu_shader_lib { key.type = context.dst->type; key.d_state = (int) context.src0->ne[0]; key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && - ggml_webgpu_tensor_overlap(context.src1, context.src5); + ggml_webgpu_tensor_overlap(context.src1, context.src5); auto it = ssm_scan_pipelines.find(key); if (it != ssm_scan_pipelines.end()) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 79d5138029d..538e587bbe5 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -4253,9 +4253,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const const uint32_t q_tile = use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u; - const bool kv_direct = use_subgroup_matrix ? - ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) : - false; + const bool kv_direct = use_subgroup_matrix ? + ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) : + false; const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct); From 72894aa2503e3eb3fdb99592da20bb313f1a9c44 Mon Sep 17 00:00:00 2001 From: ravel7524 <58877666+ravel7524@users.noreply.github.com> Date: Tue, 9 Jun 2026 07:46:23 +0200 Subject: [PATCH 807/831] Remove case for GGML_TYPE_Q4_K in mvvq.cu (llama/23528) --- ggml/src/ggml-cuda/mmvq.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index bdfbfd2d387..fe44a58da91 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -411,7 +411,6 @@ static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_d case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: - case GGML_TYPE_Q4_K: return 8; case GGML_TYPE_Q6_K: return 2; From 2d68a3066f601e95f189dd8468b3e9fe73ac445e Mon Sep 17 00:00:00 2001 From: Yash Raj Pandey <55940078+devYRPauli@users.noreply.github.com> Date: Tue, 9 Jun 2026 03:24:27 -0400 Subject: [PATCH 808/831] ggml-cpu : fix rms_norm_back wrong output under in-place aliasing (llama/24305) * ggml-cpu : fix rms_norm_back wrong output under in-place aliasing * cont : clean-up comment --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --- ggml/src/ggml-cpu/ops.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 3a1912ae91b..becac9d6ef9 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4008,12 +4008,12 @@ static void ggml_compute_forward_rms_norm_back_f32( // dx := scale(dx, rrms) float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps) - ggml_vec_cpy_f32 (ne00, dx, x); - // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); - ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); - ggml_vec_acc_f32 (ne00, dx, dz); - ggml_vec_scale_f32(ne00, dx, rrms); + // dx[i00] = (dz + x*(-sum_xdz/sum_eps)) * rrms + // note: https://github.com/ggml-org/ggml/issues/1491 + const float scale_x = (float) (-sum_xdz) / sum_eps; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dx[i00] = (dz[i00] + x[i00] * scale_x) * rrms; + } } } } From 28c7ed3db7e24261742123d6b33a90b0ef681808 Mon Sep 17 00:00:00 2001 From: Pascal <admin@serveurperso.com> Date: Tue, 9 Jun 2026 11:01:37 +0200 Subject: [PATCH 809/831] ggml : add GGML_OP_COL2IM_1D (llama/24206) * cpu: add GGML_OP_COL2IM_1D Add the overlap-add (scatter-add) step of a 1D transposed convolution. A ConvTranspose1d factorizes as a GEMM followed by col2im: a weight pre-permuted to [IC, K*OC] is contracted against the [IC, T_in] input with mul_mat to produce a column matrix [K*OC, T_in], and col2im_1d scatters those columns back into the [T_out, OC] signal, with T_out = (T_in - 1)*s0 + K - 2*p0. Keeping the contraction as a plain mul_mat leaves the heavy work on the optimized (and quantizable) matmul kernels, so col2im_1d only does the cheap overlap-add. CPU uses a gather formulation parallelized over output channels, supporting F32, F16 and BF16 with an F32 accumulator. * tests: add backend coverage for GGML_OP_COL2IM_1D Add test_col2im_1d next to the conv_transpose_1d cases, covering F32, F16 and BF16 across eight geometries: the canonical kernel = 2*stride DAC upsampling shape, overlap, no overlap, cropping (p0 = 1 and p0 = stride/2), kernel < stride with zeroed gaps, kernel not a multiple of stride, and a single column unfold. Perf mode gets three real vocoder stage shapes reporting memory bandwidth. max_nmse_err relaxes to 5e-4 for F16 and BF16. * cpu: harden GGML_OP_COL2IM_1D ggml_col2im_1d validates s0, oc, p0 and input contiguity at graph build time, before the oc division, protecting every backend at once. The kernel asserts the contiguity its flat indexing assumes and its doc states the full output length including the crop term. The kernel parallelizes over the time axis: the split stays balanced down to OC = 1, where the previous channel split was single threaded. Values are bit identical on the three real vocoder chains, two out of three improve. * tests: extend the GGML_OP_COL2IM_1D grid The eval grid grows to eleven geometries: OC = 1 (mono output stage), K = 1 with stride > 1 (sparse scatter, every gap position zeroed) and a crop down to T_out = 2 where all the gather bounds act at once. * tests: add col2im_1d equivalence test tests/test-col2im-1d.cpp proves mul_mat + col2im_1d matches the native ggml_conv_transpose_1d on the CPU backend, F32 bit exact, F16 and BF16 through casts of the column matrix. test-backend-ops cannot cover this for a CPU only op since the CPU backend is its own reference there. * rpc: bump protocol patch version for GGML_OP_COL2IM_1D GGML_OP_COUNT goes from 96 to 97 with the new op, which trips the static_assert in ggml-rpc.h. Bump RPC_PROTO_PATCH_VERSION since the op is appended and no existing op code shifts. --- ggml/include/ggml-rpc.h | 4 +- ggml/include/ggml.h | 11 ++++++ ggml/src/ggml-cpu/ggml-cpu.c | 5 +++ ggml/src/ggml-cpu/ops.cpp | 72 ++++++++++++++++++++++++++++++++++++ ggml/src/ggml-cpu/ops.h | 1 + ggml/src/ggml.c | 41 +++++++++++++++++++- 6 files changed, 130 insertions(+), 4 deletions(-) diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 6fcf5a43393..5ad121ae57f 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -8,10 +8,10 @@ extern "C" { #define RPC_PROTO_MAJOR_VERSION 4 #define RPC_PROTO_MINOR_VERSION 0 -#define RPC_PROTO_PATCH_VERSION 0 +#define RPC_PROTO_PATCH_VERSION 1 #ifdef __cplusplus -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); #endif #define GGML_RPC_MAX_SERVERS 16 diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index f6725265504..374934aacf3 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -535,6 +535,7 @@ extern "C" { GGML_OP_IM2COL, GGML_OP_IM2COL_BACK, GGML_OP_IM2COL_3D, + GGML_OP_COL2IM_1D, GGML_OP_CONV_2D, GGML_OP_CONV_3D, GGML_OP_CONV_2D_DW, @@ -2007,6 +2008,16 @@ extern "C" { int d1, // dilation dimension 1 bool is_2D); + // col2im_1d: scatter-add GEMM columns back to 1D signal + // a: [K*OC, T_in] (columns from matmul, K = a->ne[0]/OC) + // result: [T_out, OC] where T_out = (T_in - 1)*s0 + K - 2*p0 + GGML_API struct ggml_tensor * ggml_col2im_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, // columns [K*OC, T_in] + int s0, // stride + int oc, // output channels + int p0); // padding to crop from both sides + GGML_API struct ggml_tensor * ggml_conv_1d( struct ggml_context * ctx, struct ggml_tensor * a, // convolution kernel diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index cd5c61a8187..af7827aec39 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -1912,6 +1912,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_im2col_3d(params, tensor); } break; + case GGML_OP_COL2IM_1D: + { + ggml_compute_forward_col2im_1d(params, tensor); + } break; case GGML_OP_CONV_2D: { ggml_compute_forward_conv_2d(params, tensor); @@ -2343,6 +2347,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_CONV_2D: case GGML_OP_CONV_3D: case GGML_OP_CONV_2D_DW: + case GGML_OP_COL2IM_1D: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_2D: { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index becac9d6ef9..86842e55474 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -6730,6 +6730,78 @@ static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) { return (coord + size) % size; // adding size avoids negative number weirdness } +// ggml_compute_forward_col2im_1d +// +// Scatter-add columns [K*OC, T_in] -> signal [T_out, OC] +// where T_out = (T_in - 1)*s + K - 2*p. Gather approach: each output reads ceil(K/s) inputs. +// Parallelized over the time axis so the split stays balanced whatever OC is. +// Supports F32, F16, BF16 input/output (same type), F32 accumulator. + +template <typename elem_t> +static void ggml_compute_forward_col2im_1d_impl( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src = dst->src[0]; // [K*OC, T_in] + + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t OC = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + + const int64_t K_OC = src->ne[0]; + const int64_t T_in = src->ne[1]; + const int64_t K = K_OC / OC; + const int64_t T_out = dst->ne[0]; + + const elem_t * col_data = (const elem_t *) src->data; + elem_t * dst_data = (elem_t *) dst->data; + + const int ith = params->ith; + const int nth = params->nth; + + // Parallelize over the time axis: the split stays balanced whatever OC is, + // down to OC = 1 for mono audio, and threads read disjoint column bands + const int64_t dr = (T_out + nth - 1) / nth; + const int64_t it0 = dr * ith; + const int64_t it1 = it0 + dr < T_out ? it0 + dr : T_out; + + for (int64_t oc = 0; oc < OC; oc++) { + for (int64_t t_out = it0; t_out < it1; t_out++) { + const int64_t t_abs = t_out + p0; // absolute position in uncropped signal + // Gather: find all (t_in, k) where t_in * s + k == t_abs, 0 <= k < K + int64_t t_in_min = (t_abs - K + 1 + s0 - 1) / s0; // ceil((t_abs-K+1)/s) + if (t_in_min < 0) t_in_min = 0; + int64_t t_in_max = t_abs / s0; + if (t_in_max >= T_in) t_in_max = T_in - 1; + + float sum = 0.0f; + for (int64_t t_in = t_in_min; t_in <= t_in_max; t_in++) { + int64_t k = t_abs - t_in * s0; + if (k >= 0 && k < K) { + // col layout: [K*OC, T_in], element (oc*K+k, t_in) + sum += type_conversion_table<elem_t>::to_f32(col_data[(oc * K + k) + t_in * K_OC]); + } + } + // dst layout: [T_out, OC], element (t_out, oc) + dst_data[t_out + oc * T_out] = type_conversion_table<elem_t>::from_f32(sum); + } + } +} + +void ggml_compute_forward_col2im_1d( + const ggml_compute_params * params, + ggml_tensor * dst) { + switch (dst->src[0]->type) { + case GGML_TYPE_F32: ggml_compute_forward_col2im_1d_impl<float> (params, dst); break; + case GGML_TYPE_F16: ggml_compute_forward_col2im_1d_impl<ggml_fp16_t>(params, dst); break; + case GGML_TYPE_BF16: ggml_compute_forward_col2im_1d_impl<ggml_bf16_t>(params, dst); break; + default: GGML_ABORT("col2im_1d: unsupported type %d", dst->src[0]->type); + } +} + // ggml_compute_forward_conv_2d diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 7398e561894..a8e18c716db 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -68,6 +68,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_col2im_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 8815c67d8bc..18a5ebd2ab0 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1031,6 +1031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "IM2COL", "IM2COL_BACK", "IM2COL_3D", + "COL2IM_1D", "CONV_2D", "CONV_3D", "CONV_2D_DW", @@ -1080,7 +1081,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1141,6 +1142,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "im2col(x)", "im2col_back(x)", "im2col_3d(x)", + "col2im_1d(x)", "conv_2d(x)", "conv_3d(x)", "conv_2d_dw(x)", @@ -1190,7 +1192,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4541,6 +4543,41 @@ struct ggml_tensor * ggml_conv_1d_dw_ph( return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0); } +// ggml_col2im_1d + +struct ggml_tensor * ggml_col2im_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int s0, + int oc, + int p0) { + GGML_ASSERT(ggml_is_matrix(a)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16); + GGML_ASSERT(s0 > 0); + GGML_ASSERT(oc > 0); + GGML_ASSERT(p0 >= 0); + + const int64_t K_OC = a->ne[0]; + const int64_t T_in = a->ne[1]; + const int64_t K = K_OC / oc; + const int64_t T_out = (T_in - 1) * s0 + K - 2 * p0; + + GGML_ASSERT(K_OC == K * oc); // a->ne[0] must be a whole number of oc blocks + GGML_ASSERT(K > 0 && T_out > 0); + + const int64_t ne[4] = { T_out, oc, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 2, ne); + + int32_t params[] = { s0, (int32_t)oc, (int32_t)p0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_COL2IM_1D; + result->src[0] = a; + + return result; +} + // ggml_conv_transpose_1d static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { From 686bc802d1f5df3cf4a23102eca785b62877619b Mon Sep 17 00:00:00 2001 From: Ruben Ortlam <rortlam@redhat.com> Date: Tue, 9 Jun 2026 13:27:04 +0200 Subject: [PATCH 810/831] vulkan: add `v_dot2_f32_f16` support in matrix-matrix multiplication and Flash Attention (llama/24123) * vulkan: add support for valve fp16 dot2 extension * use macro for dot2 path choice * properly check for the feature * add dot_product abstraction to reduce preprocessor branching --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 84 ++++++++++++++++--- .../vulkan-shaders/dot_product_funcs.glsl | 27 ++++++ .../vulkan-shaders/flash_attn.comp | 5 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 12 +-- .../vulkan-shaders/vulkan-shaders-gen.cpp | 46 ++++++---- 5 files changed, 139 insertions(+), 35 deletions(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2dd8cd2fbd9..c4ea0b105ce 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -113,6 +113,21 @@ typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { } VkPhysicalDeviceShaderBfloat16FeaturesKHR; #endif +#if !defined(VK_VALVE_shader_mixed_float_dot_product) +#define VK_VALVE_shader_mixed_float_dot_product 1 +#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_SPEC_VERSION 1 +#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_EXTENSION_NAME "VK_VALVE_shader_mixed_float_dot_product" +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE ((VkStructureType)1000673000) +typedef struct VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE { + VkStructureType sType; + void* pNext; + VkBool32 shaderMixedFloatDotProductFloat16AccFloat32; + VkBool32 shaderMixedFloatDotProductFloat16AccFloat16; + VkBool32 shaderMixedFloatDotProductBFloat16Acc; + VkBool32 shaderMixedFloatDotProductFloat8AccFloat32; +} VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE; +#endif + #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1)) #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } @@ -705,6 +720,8 @@ struct vk_device_struct { bool coopmat2_bf16_support {}; bool coopmat2_decode_vector; + bool dot2_f16 {}; + bool pipeline_executable_properties_support {}; size_t idx; @@ -3920,8 +3937,13 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; } else { if (device->fp16) { - if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; } - else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; } + if (device->dot2_f16) { + if (f32acc) { spv_data = flash_attn_f32_f16_dot2_data; spv_size = flash_attn_f32_f16_dot2_len; } + else { spv_data = flash_attn_f32_f16_dot2_f16acc_data; spv_size = flash_attn_f32_f16_dot2_f16acc_len; } + } else { + if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; } + else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; } + } } else { spv_data = flash_attn_f32_f16_fp32_data; spv_size = flash_attn_f32_f16_fp32_len; @@ -4215,7 +4237,23 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->fp16) { // Create 6 variants, {s,m,l}x{unaligned,aligned} + // Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + + // bf16 scalar path promotes to f32, no dot2 variant +#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ @@ -4250,7 +4288,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -4258,7 +4296,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -4298,8 +4335,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - + CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -4344,8 +4380,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); - + CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -4390,6 +4425,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { #undef CREATE_MM2 #undef CREATE_MMQ #undef CREATE_MM +#undef CREATE_MM_NODOT2 } else { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ @@ -5453,6 +5489,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->integer_dot_product = false; device->shader_64b_indexing = false; bool bfloat16_support = false; + bool dot2_f16_support = false; for (const auto& properties : ext_props) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { @@ -5495,6 +5532,9 @@ static vk_device ggml_vk_get_device(size_t idx) { !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; #endif + } else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_DOT2")) { + dot2_f16_support = true; } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) { pipeline_executable_properties_support = true; } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 && @@ -5802,6 +5842,14 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_shader_integer_dot_product"); } + VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {}; + dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE; + if (dot2_f16_support) { + last_struct->pNext = (VkBaseOutStructure *)&dot2_features; + last_struct = (VkBaseOutStructure *)&dot2_features; + device_extensions.push_back("VK_VALVE_shader_mixed_float_dot_product"); + } + VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {}; pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR; if (pipeline_executable_properties_support) { @@ -5836,6 +5884,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->bf16 = false; #endif + device->dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32; + device->pipeline_robustness = pl_robustness_features.pipelineRobustness; device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 && @@ -6250,6 +6300,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { bool coopmat2_decode_vector_support = false; bool integer_dot_product = false; bool bfloat16_support = false; + bool dot2_f16_support = false; for (auto properties : ext_props) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { @@ -6279,6 +6330,9 @@ static void ggml_vk_print_gpu_info(size_t idx) { !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; #endif + } else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_DOT2")) { + dot2_f16_support = true; } } @@ -6369,6 +6423,13 @@ static void ggml_vk_print_gpu_info(size_t idx) { last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features; } + VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {}; + dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE; + if (dot2_f16_support) { + last_struct->pNext = (VkBaseOutStructure *)&dot2_features; + last_struct = (VkBaseOutStructure *)&dot2_features; + } + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); fp16 = fp16 && vk12_features.shaderFloat16; @@ -6415,9 +6476,12 @@ static void ggml_vk_print_gpu_info(size_t idx) { : coopmat_support ? "KHR_coopmat" : "none"; + bool dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32; + const char *fp16_str = fp16 ? (dot2_f16 ? "dot2" : "1") : "0"; + std::string device_name = props2.properties.deviceName.data(); - GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", - idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size, + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %s | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16_str, bf16, subgroup_size, props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str()); if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl new file mode 100644 index 00000000000..c474bfe09ce --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl @@ -0,0 +1,27 @@ +#ifdef DOT2_F16 +#extension GL_EXT_spirv_intrinsics : require + +spirv_instruction(extensions = ["SPV_VALVE_mixed_float_dot_product"], + capabilities = [6912], id = 6916) +float v_dot2_f32_f16(f16vec2 a, f16vec2 b, float acc); + +ACC_TYPE dot_product(f16vec4 a, f16vec4 b, ACC_TYPE acc) { + return ACC_TYPE(v_dot2_f32_f16(a.zw, b.zw, v_dot2_f32_f16(a.xy, b.xy, float(acc)))); +} + +ACC_TYPE dot_product(f16vec2 a, f16vec2 b, ACC_TYPE acc) { + return ACC_TYPE(v_dot2_f32_f16(a, b, float(acc))); +} + +#else + +ACC_TYPE dot_product(FLOAT_TYPEV4 a, FLOAT_TYPEV4 b, ACC_TYPE acc) { + return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y), + fma(ACC_TYPE(a.z), ACC_TYPE(b.z), fma(ACC_TYPE(a.w), ACC_TYPE(b.w), acc)))); +} + +ACC_TYPE dot_product(FLOAT_TYPEV2 a, FLOAT_TYPEV2 b, ACC_TYPE acc) { + return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y), acc)); +} + +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 6ac095489b3..91fb07c93e7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -21,6 +21,7 @@ #extension GL_KHR_shader_subgroup_vote : enable #include "types.glsl" +#include "dot_product_funcs.glsl" #include "flash_attn_base.glsl" #include "flash_attn_dequant.glsl" @@ -318,7 +319,7 @@ void main() { K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += dot(ACC_TYPEV4(Q_cache[r]), ACC_TYPEV4(K_Tf)); + Sf[r][c] = dot_product(Q_cache[r], K_Tf, Sf[r][c]); } } } @@ -341,7 +342,7 @@ void main() { K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Sf[r][c] += dot(ACC_TYPEV4(Qf[tile_row(r) * qf_stride + d * D_split + d_tid]), ACC_TYPEV4(K_Tf)); + Sf[r][c] = dot_product(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf, Sf[r][c]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 89346e48e06..f39410d74f0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -29,6 +29,7 @@ #endif #include "types.glsl" +#include "dot_product_funcs.glsl" #ifndef LOAD_VEC_A #define LOAD_VEC_A 1 @@ -329,15 +330,8 @@ void main() { [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr] const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; - #if defined(DATA_A_F32) || defined(DATA_A_F16) - sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), - fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x)))); - sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), - fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y)))); - #else - sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x)); - sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y)); - #endif + sums[sums_idx].x = dot_product(cache_a[wsir * TM + 2 * cr ], cache_b, sums[sums_idx].x); + sums[sums_idx].y = dot_product(cache_a[wsir * TM + 2 * cr + 1], cache_b, sums[sums_idx].y); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 8fc00362870..7bcb1460814 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -336,7 +336,8 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p // disable spirv-opt for coopmat shaders for https://github.com/ggml-org/llama.cpp/issues/10734 // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344 // disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860 - if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) { + // disable spirv-opt for dot2 shaders (spirv-opt doesn't recognize SPV_VALVE_mixed_float_dot_product capability) + if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos && name.find("_dot2") == std::string::npos) { cmd.push_back("-O"); } @@ -427,10 +428,11 @@ void string_to_spv(std::string name, const std::string& source, const std::map<s generate_dep_file = false; } -void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) { +void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc, bool dot2 = false) { std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; + std::string dot2_sfx = dot2 ? "_dot2" : ""; std::map<std::string, std::string> base_dict; std::string shader_name = "matmul"; @@ -463,6 +465,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } #endif + if (dot2) { + base_dict["DOT2_F16"] = "1"; + } + const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; auto const &FLOAT_TYPE = [&](int vec, const std::string &t) -> std::string { @@ -528,11 +534,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); // bf16 { @@ -553,8 +559,10 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c if (!(coopmat || coopmat2)) #endif { - string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + if (!dot2) { + string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } } } @@ -584,18 +592,18 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - // Integer dot mmq performs better with f32 accumulators - if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) { + // Integer dot mmq performs better with f32 accumulators (different shader, skip for dot2) + if (!f16acc && !coopmat && !coopmat2 && !dot2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) { string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); } #endif @@ -613,6 +621,10 @@ void process_shaders() { matmul_shaders(true, matmul_id_type, false, false, false); matmul_shaders(true, matmul_id_type, false, false, true); + // dot2 variants (scalar fp16 only) + matmul_shaders(true, matmul_id_type, false, false, false, true); + matmul_shaders(true, matmul_id_type, false, false, true, true); + if (matmul_id_type != MatMulIdType::DEFAULT) { #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) // Coopmat, fp32acc and fp16acc @@ -660,6 +672,12 @@ void process_shaders() { string_to_spv("flash_attn_f32_f16", "flash_attn.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); + + if (fp16) { + string_to_spv("flash_attn_f32_f16_dot2", "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DOT2_F16", "1"}}), fp16, false, false, f16acc); + } + #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) string_to_spv("flash_attn_f32_f16", "flash_attn.comp", merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8"); From dc794303d86cfc650f41e2545ae8cf19a7dc5548 Mon Sep 17 00:00:00 2001 From: Jeff Bolz <jbolz@nvidia.com> Date: Tue, 9 Jun 2026 06:27:38 -0500 Subject: [PATCH 811/831] vulkan: reduce iq1 shared memory usage for mul_mm (llama/24287) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 4 +++- ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp | 1 + ggml/src/ggml-vulkan/vulkan-shaders/types.glsl | 8 +++++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index c4ea0b105ce..22405f234de 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -3394,7 +3394,9 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec switch (src0_type) { case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: - lut_size = 2*2048 + 4*2048; + // Regular matmul uses the compact uint16_t IQ1 grid; the expanded + // uint32_t grid is only enabled for the q8_1/int-dot vector path. + lut_size = 2*2048; break; case GGML_TYPE_IQ2_XXS: lut_size = 8*256; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 6fe3e2dc043..fd84c3c91d8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -4,6 +4,7 @@ #extension GL_EXT_integer_dot_product : require #define MMQ +#define NEEDS_IQ1S_GRID_GPU #define B_TYPE block_q8_1_x4 #include "mul_mat_vec_base.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index f84d6f87334..8c6b20c6889 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -598,9 +598,10 @@ const uint[1024] iq1s_grid_const = { 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 }; +#if defined(NEEDS_IQ1S_GRID_GPU) // Same content as iq1s_grid_const except each 2-bit value is expanded to 4-bit // and has 1 added to it (allows packed values to be extracted with & 0x0F0F0F0F -// and 0xF0F0F0F0). +// and 0xF0F0F0F0). This is only used by the q8_1/int-dot vector path. const uint32_t[2048] iq1s_grid_gpu_const = { 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, @@ -859,9 +860,12 @@ const uint32_t[2048] iq1s_grid_gpu_const = { 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, }; +#endif shared uint16_t iq1s_grid[2048]; +#if defined(NEEDS_IQ1S_GRID_GPU) shared uint32_t iq1s_grid_gpu[2048]; +#endif #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) @@ -875,12 +879,14 @@ void init_iq_shmem(uvec3 wgsize) iq1s_grid[2*idx+1] = g.y; } } +#if defined(NEEDS_IQ1S_GRID_GPU) [[unroll]] for (uint i = 0; i < iq1s_grid_gpu_const.length(); i += wgsize.x) { uint idx = i + gl_LocalInvocationIndex.x; if (iq1s_grid_gpu_const.length() % wgsize.x == 0 || idx < iq1s_grid_gpu_const.length()) { iq1s_grid_gpu[idx] = iq1s_grid_gpu_const[idx]; } } +#endif barrier(); } #endif From ef85b26d9f0bfeba3548ea6ceb213a5191ef4c11 Mon Sep 17 00:00:00 2001 From: Oliver Simons <osimons@nvidia.com> Date: Wed, 10 Jun 2026 14:27:08 +0200 Subject: [PATCH 812/831] CUDA: Fix ssm_scan_f32 data-races (llama/24360) * Add missing syncthreads before resuing cub_temp_storage __syncthreads() is required before being allowed to resue TempStorage smem: https://nvidia.github.io/cccl/unstable/cub/api/classcub_1_1BlockLoad.html#_CPPv4I0EN3cub9BlockLoad4LoadEv20RandomAccessIteratorRA14ItemsPerThread_1Ti * Add one more missing __syncthreads Could also double-buffer, but alternative is to simply ensure all threads have read smem* before writing to it again in the next loop iteration * Remove unused smem from ssm_scan_f32 --- ggml/src/ggml-cuda/ssm-scan.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 2e3f97c7284..3022249c77d 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -67,6 +67,7 @@ __global__ void __launch_bounds__(splitD, 1) __shared__ CubTempStorage cub_temp_storage; BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA); + __syncthreads(); BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0); #else const int stride_s0 = src0_nb2 / sizeof(float); @@ -105,6 +106,7 @@ __global__ void __launch_bounds__(splitD, 1) regs0[n] = state; } y_block[i * stride_y + threadIdx.x] = sumf; + __syncthreads(); } #ifdef USE_CUB @@ -249,9 +251,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa GGML_ASSERT(head_dim == 1); GGML_ASSERT(n_group == 1); const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); - const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); if (d_state == 16) { - const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, smem_size, stream); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); switch (n_tok) { case 1: From 1a1900f90c165a078d66b4958db46fe82d14ff27 Mon Sep 17 00:00:00 2001 From: Gaurav Garg <gaugarg@nvidia.com> Date: Wed, 10 Jun 2026 23:21:16 +0530 Subject: [PATCH 813/831] Remove padding and multiple D2D copies for MTP (llama/24086) * Make ggml_gated_delta_net take only the initial recurrent state (D, 1, n_seqs) and passes the snapshot count K as an op parameter instead of inferring it from state->ne[1]. Remove the padding hack and copy all emitted snapshots into the recurrent cache with a single strided ggml_cpy * Make GDN changes in all backends. Address review comments. * Fix CI build errors --- ggml/include/ggml.h | 17 +++++++---- ggml/src/ggml-backend-meta.cpp | 4 +-- ggml/src/ggml-cpu/ggml-cpu.c | 2 +- ggml/src/ggml-cpu/ops.cpp | 17 +++++------ ggml/src/ggml-cuda/gated_delta_net.cu | 16 +++++----- ggml/src/ggml-hexagon/ggml-hexagon.cpp | 5 ++-- .../ggml-hexagon/htp/gated-delta-net-ops.c | 29 ++++++++++--------- ggml/src/ggml-metal/ggml-metal-device.cpp | 4 +-- ggml/src/ggml-metal/ggml-metal.metal | 11 ++++--- ggml/src/ggml-opencl/ggml-opencl.cpp | 2 +- .../ggml-opencl/kernels/gated_delta_net.cl | 8 +++-- ggml/src/ggml-sycl/gated_delta_net.cpp | 15 +++++----- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 8 ++--- .../vulkan-shaders/gated_delta_net.comp | 11 ++++--- ggml/src/ggml-webgpu/ggml-webgpu.cpp | 2 +- .../wgsl-shaders/gated_delta_net.wgsl | 7 +++-- ggml/src/ggml.c | 16 ++++++---- 17 files changed, 93 insertions(+), 81 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 374934aacf3..d6807b6dd47 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2553,10 +2553,16 @@ extern "C" { // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 // - // state is a 3D tensor of shape (S_v*S_v*H, K, n_seqs): - // K == 1: output carries the final state only. - // K > 1: output carries K snapshot slots; the kernel writes the last min(n_tokens, K) - // per-token snapshots into the trailing slots + // tensor shapes (S_k == S_v, H_v % H_k == 0): + // q, k : [S_k, H_k, n_tokens, n_seqs] + // v : [S_v, H_v, n_tokens, n_seqs] + // g : [1, H_v, n_tokens, n_seqs] (scalar gate) or [S_v, H_v, n_tokens, n_seqs] (KDA) + // beta : [1, H_v, n_tokens, n_seqs] + // state : [S_v, S_v, H_v, n_seqs] -- initial recurrent state s0 + // + // the output packs the attention scores [S_v, H_v, n_tokens, n_seqs] followed by K state + // snapshots, most-recent first (slot 0 = final state, slot s = state s tokens back). K == 1 + // keeps only the final state; when n_tokens < K only slots 0..n_tokens-1 are written. GGML_API struct ggml_tensor * ggml_gated_delta_net( struct ggml_context * ctx, struct ggml_tensor * q, @@ -2564,7 +2570,8 @@ extern "C" { struct ggml_tensor * v, struct ggml_tensor * g, struct ggml_tensor * beta, - struct ggml_tensor * state); + struct ggml_tensor * state, + int64_t K); // custom operators diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp index 8c44c3e44ae..0a36f099000 100644 --- a/ggml/src/ggml-backend-meta.cpp +++ b/ggml/src/ggml-backend-meta.cpp @@ -776,8 +776,8 @@ static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); - // state shape is (S_v*S_v*H, K, n_seqs); the heads dim is nested inside axis 0, - // so a head-aligned split on the input cache reshapes to axis 0 here (not axis 2). + // state shape is [S_v, S_v, H_v, n_seqs] (s0 only); the heads dim is its own axis 2, + // so a head-aligned split on the input cache lands on axis 2 here. GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; }; diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index af7827aec39..eb8341c9aec 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2948,7 +2948,7 @@ struct ggml_cplan ggml_graph_plan( case GGML_OP_GATED_DELTA_NET: { const int64_t S_v = node->src[2]->ne[0]; - const int64_t K = node->src[5]->ne[1]; // state is (D, K, n_seqs) + const int64_t K = ggml_get_op_params_i32(node, 0); const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); cur = per_thread * sizeof(float) * n_tasks; } break; diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 86842e55474..74611dce7f1 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10624,11 +10624,11 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const bool kda = (neg0 == S_v); - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int64_t K = src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int64_t K = ggml_get_op_params_i32(dst, 0); GGML_ASSERT(K >= 1); - // per-seq stride in floats (slot 0 of seq s lives at state + s * seq_stride) - const int64_t state_seq_stride = src_state->nb[2] / sizeof(float); + // per-seq stride in floats (seq s starts at state + s * seq_stride) + const int64_t state_seq_stride = src_state->nb[3] / sizeof(float); const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); const int ith = params->ith; @@ -10644,9 +10644,8 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( float * attn_out_base = (float *)dst->data; float * state_out_base = (float *)dst->data + attn_score_elems; - // snapshot slot mapping: target_slot = t - shift. When n_tokens < K only the last - // n_tokens slots are written; earlier slots are left untouched (caller-owned). - const int64_t shift = n_tokens - K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. const float * state_in_base = (const float *)src_state->data; @@ -10674,7 +10673,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( : state_out_base + (iv3 * H + iv1) * S_v * S_v; // copy input state into the working buffer and operate in-place - // state layout (D, K, n_seqs): slot 0 of seq iv3 starts at iv3 * state_seq_stride. + // state layout [S_v, S_v, H, n_seqs]: seq iv3 starts at iv3 * state_seq_stride. const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v; memcpy(s_out, s_in, S_v * S_v * sizeof(float)); @@ -10727,7 +10726,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( attn_data += S_v * H; // advance to next token if (K > 1) { - const int64_t target_slot = t - shift; + const int64_t target_slot = n_tokens - 1 - t; if (target_slot >= 0 && target_slot < K) { float * curr_state_o = state_out_base + target_slot * state_size_per_snap + (iv3 * H + iv1) * S_v * S_v; diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu index 7cfda652367..a547360eb06 100644 --- a/ggml/src/ggml-cuda/gated_delta_net.cu +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -39,9 +39,9 @@ gated_delta_net_cuda(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // input state holds s0 only: [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v. // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. - const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v; const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; state += state_out_offset; curr_state += state_in_offset + col * S_v; @@ -143,12 +143,10 @@ gated_delta_net_cuda(const float * q, attn_data += S_v * H; if constexpr (keep_rs_t) { - // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots - // are written; earlier slots are left untouched (caller-owned). - const int shift = (int) n_tokens - K; - + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output - const int target_slot = t - shift; + const int target_slot = (int) n_tokens - 1 - t; if (target_slot >= 0 && target_slot < K) { float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; #pragma unroll @@ -286,8 +284,8 @@ void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * cudaStream_t stream = ctx.stream(); - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int K = (int) src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int K = ggml_get_op_params_i32(dst, 0); const bool keep_rs = K > 1; if (kda) { diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index d550841a2a5..49bd7e4331a 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -2538,7 +2538,7 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses const int64_t H = v->ne[1]; const int64_t n_tokens = v->ne[2]; const int64_t n_seqs = v->ne[3]; - const int64_t K = state->ne[1]; + const int64_t K = ggml_get_op_params_i32(op, 0); if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) { return false; @@ -2551,7 +2551,8 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) { return false; } - if (ggml_nelements(state) != S_v * S_v * H * n_seqs * K) { + // state holds s0 only [S_v, S_v, H, n_seqs]; K is op param 0. + if (ggml_nelements(state) != S_v * S_v * H * n_seqs) { return false; } if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c index 3b092d7440d..35518e6111c 100644 --- a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c @@ -584,7 +584,7 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t H = v->ne[1]; const uint32_t n_tokens = v->ne[2]; const uint32_t n_seqs = v->ne[3]; - const uint32_t K = state->ne[1]; + const uint32_t K = octx->op_params[0]; const uint32_t total_rows = H * n_seqs; if (ith >= total_rows) { @@ -618,9 +618,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); - const uint64_t state_seq_stride = state->nb[2] / sizeof(float); + const uint64_t state_seq_stride = state->nb[3] / sizeof(float); const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; - const int64_t shift = (int64_t) n_tokens - (int64_t) K; uint32_t ir_prefetch = ith; int spad_idx = 0; @@ -630,7 +629,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; - float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v; // Push dummy write-back dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), @@ -661,7 +661,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo const uint32_t iq3 = fastdiv(iv3, &fd_rq3); const uint32_t ik3 = fastdiv(iv3, &fd_rk3); - float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v; @@ -792,7 +793,8 @@ static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, vo } if (K > 1) { - const int64_t target_slot = (int64_t) t - shift; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + const int64_t target_slot = (int64_t) n_tokens - 1 - (int64_t) t; if (target_slot >= 0 && target_slot < (int64_t) K) { float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; if (curr_state_o != s_out) { @@ -844,7 +846,6 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t S_v = v->ne[0]; const uint32_t H = v->ne[1]; const uint32_t n_seqs = v->ne[3]; - const uint32_t K = state->ne[1]; const uint32_t total_rows = H * n_seqs; if (ith >= total_rows) { @@ -878,8 +879,7 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); - const uint64_t state_seq_stride = state->nb[2] / sizeof(float); - const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; + const uint64_t state_seq_stride = state->nb[3] / sizeof(float); uint32_t ir_prefetch = ith; int spad_idx = 0; @@ -889,7 +889,8 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; - float * ps_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v; // Push dummy write-back dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), @@ -920,7 +921,8 @@ static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, vo const uint32_t iq3 = fastdiv(iv3, &fd_rq3); const uint32_t ik3 = fastdiv(iv3, &fd_rk3); - float * s_out = state_out_base + (uint64_t) (K - 1) * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v; @@ -1097,7 +1099,7 @@ int op_gated_delta_net(struct htp_ops_context * octx) { const uint32_t H = v->ne[1]; const uint32_t n_tokens = v->ne[2]; const uint32_t n_seqs = v->ne[3]; - const uint32_t K = state->ne[1]; + const uint32_t K = octx->op_params[0]; if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) { return HTP_STATUS_NO_SUPPORT; @@ -1110,7 +1112,8 @@ int op_gated_delta_net(struct htp_ops_context * octx) { (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) { return HTP_STATUS_NO_SUPPORT; } - if (state->ne[0] * state->ne[2] * state->ne[3] != S_v * S_v * H * n_seqs) { + // state holds s0 only: [S_v, S_v, H, n_seqs] + if (state->ne[0] != S_v || state->ne[1] != S_v || state->ne[2] != H || state->ne[3] != n_seqs) { return HTP_STATUS_NO_SUPPORT; } if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index ce847dd8b6f..4f4f073cb61 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -590,8 +590,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net( const int ne20 = op->src[2]->ne[0]; // S_v const int ne21 = op->src[2]->ne[1]; // H const int ne30 = op->src[3]->ne[0]; // G - // state is src[5], 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int K = op->src[5]->ne[1]; + // state is src[5], 4D [S_v, S_v, H_v, n_seqs] (s0 only); K is op param 0. + const int K = ggml_get_op_params_i32(op, 0); const int nsg = op->src[2]->ne[0]/32; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 2bd310d9450..0aea68455fb 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2599,9 +2599,9 @@ kernel void kernel_gated_delta_net_impl( const float scale = 1.0f / sqrt((float)S_v); - // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. + // input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D. // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous - const uint state_in_base = (i23*K*args.ne21 + i21)*S_v*S_v + i20*S_v; + const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; device const float * s_ptr = (device const float *) (s) + state_in_base; float ls[NSG]; @@ -2620,9 +2620,8 @@ kernel void kernel_gated_delta_net_impl( device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; - // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last - // n_tokens slots are written; earlier slots are left untouched (caller-owned). - const int shift = (int)args.ne22 - (int)K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned. // output state base offset: after attention scores const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23; @@ -2680,7 +2679,7 @@ kernel void kernel_gated_delta_net_impl( g_ptr += args.ne21*G; if (K > 1) { - const int target_slot = (int)t - shift; + const int target_slot = (int)args.ne22 - 1 - (int)t; if (target_slot >= 0 && target_slot < (int)K) { device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; FOR_UNROLL (short j = 0; j < NSG; j++) { diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 2a41215fd13..d30579b9452 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -17750,7 +17750,7 @@ static void ggml_cl_gated_delta_net(ggml_backend_t backend, ggml_tensor * dst) { const cl_uint H_v = (cl_uint) src_v->ne[1]; const cl_uint n_tokens = (cl_uint) src_v->ne[2]; const cl_uint n_seqs = (cl_uint) src_v->ne[3]; - const cl_uint K = (cl_uint) src_state->ne[1]; + const cl_uint K = (cl_uint) ggml_get_op_params_i32(dst, 0); int si; switch (S_v) { diff --git a/ggml/src/ggml-opencl/kernels/gated_delta_net.cl b/ggml/src/ggml-opencl/kernels/gated_delta_net.cl index d11192f5802..319c9829529 100644 --- a/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +++ b/ggml/src/ggml-opencl/kernels/gated_delta_net.cl @@ -123,7 +123,8 @@ kernel void kernel_gated_delta_net( const uint iq3 = seq_id / rq3; // seq index for Q and K const uint state_size = S_V * S_V; - const uint state_base = (seq_id * K * H_v + head_id) * state_size; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + const uint state_base = (seq_id * H_v + head_id) * state_size; const uint q_off_base = iq3 * sq3 + iq1 * sq1; const uint v_off_base = seq_id * sv3 + head_id * sv1; const uint gb_off_base = seq_id * sb3 + head_id * sb1; @@ -143,7 +144,8 @@ kernel void kernel_gated_delta_net( } } - const int shift = (int)n_tokens - (int)K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. uint attn_off = (seq_id * n_tokens * H_v + head_id) * S_V; for (uint t = 0; t < n_tokens; t++) { @@ -219,7 +221,7 @@ kernel void kernel_gated_delta_net( attn_off += S_V * H_v; if (K > 1u) { - const int target_slot = (int)t - shift; + const int target_slot = (int)n_tokens - 1 - (int)t; if (target_slot >= 0 && target_slot < (int)K) { #pragma unroll for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp index 9c2449aba0c..239e00bd7e5 100644 --- a/ggml/src/ggml-sycl/gated_delta_net.cpp +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -44,9 +44,9 @@ void gated_delta_net_sycl(const float * q, float * attn_data = dst; float * state = dst + attn_score_elems; - // input state layout (D, K, n_seqs) — seq stride is K * D = K * H * S_v * S_v. + // input state holds s0 only [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v. // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. - const int64_t state_in_offset = sequence * K * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v; const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output state += state_out_offset; @@ -63,9 +63,8 @@ void gated_delta_net_sycl(const float * q, s_shard[r] = curr_state[i]; } - // slot mapping: target_slot = t - shift. When n_tokens < K only the last n_tokens slots - // are written; earlier slots are left untouched (caller-owned). - const int shift = (int) n_tokens - K; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. for (int t = 0; t < n_tokens; t++) { const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; @@ -144,7 +143,7 @@ void gated_delta_net_sycl(const float * q, // Write state back to global memory if constexpr (keep_rs_t) { - const int target_slot = t - shift; + const int target_slot = (int) n_tokens - 1 - t; if (target_slot >= 0 && target_slot < K) { float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; #pragma unroll @@ -315,8 +314,8 @@ void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dpct::queue_ptr stream = ctx.stream(); - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const int K = (int) src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int K = ggml_get_op_params_i32(dst, 0); const bool keep_rs = K > 1; if (kda) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 22405f234de..387826b6d93 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -11528,7 +11528,6 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const ggml_tensor * src_q = dst->src[0]; const ggml_tensor * src_v = dst->src[2]; const ggml_tensor * src_beta = dst->src[4]; - const ggml_tensor * src_state = dst->src[5]; GGML_ASSERT(dst->buffer != nullptr); @@ -11537,8 +11536,8 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s const uint32_t n_tokens = (uint32_t)src_v->ne[2]; const uint32_t n_seqs = (uint32_t)src_v->ne[3]; - // state is 3D (S_v*S_v*H, K, n_seqs); K is the snapshot slot count. - const uint32_t K = (uint32_t)src_state->ne[1]; + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const uint32_t K = (uint32_t)ggml_get_op_params_i32(dst, 0); const uint32_t s_off = S_v * H * n_tokens * n_seqs; @@ -17954,7 +17953,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * src_clone[4], src_clone[5], src_clone[6]); } else if (tensor->op == GGML_OP_GATED_DELTA_NET) { tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1], - src_clone[2], src_clone[3], src_clone[4], src_clone[5]); + src_clone[2], src_clone[3], src_clone[4], src_clone[5], + ggml_get_op_params_i32(tensor, 0)); } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { src_clone[0]->flags = tensor->src[0]->flags; tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp index 33c3202dbb7..0e384330b9b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -102,8 +102,8 @@ void main() { const uint iq3 = seq_id / rq3; const uint state_size = S_V * S_V; - // input state layout (D, K, n_seqs): per-seq stride is K*H*D; we read slot 0. - const uint state_in_base = (seq_id * K * H + head_id) * state_size; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + const uint state_in_base = (seq_id * H + head_id) * state_size; // output state layout per slot: same per-(seq,head) offset as the single-slot case. const uint state_out_base = (seq_id * H + head_id) * state_size; const uint state_size_per_snap = state_size * H * n_seqs; @@ -113,9 +113,8 @@ void main() { s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]); } - // snapshot slot mapping: target_slot = t - shift. When n_tokens < K, only the last - // n_tokens slots are written; earlier slots are left untouched (caller-owned). - const int shift = int(n_tokens) - int(K); + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned. uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; @@ -172,7 +171,7 @@ void main() { attn_off += S_V * H; if (K > 1u) { - const int target_slot = int(t) - shift; + const int target_slot = int(n_tokens) - 1 - int(t); if (target_slot >= 0 && target_slot < int(K)) { const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base; [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 538e587bbe5..0b605fa86ba 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1245,7 +1245,7 @@ static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, const uint32_t h = (uint32_t) src2->ne[1]; const uint32_t n_tokens = (uint32_t) src2->ne[2]; const uint32_t n_seqs = (uint32_t) src2->ne[3]; - const uint32_t K = (uint32_t) src5->ne[1]; + const uint32_t K = (uint32_t) ggml_get_op_params_i32(dst, 0); const float scale = 1.0f / sqrtf((float) s_v); uint32_t scale_u32; memcpy(&scale_u32, &scale, sizeof(scale_u32)); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl index d68520f8282..7d7b3475549 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl @@ -63,10 +63,10 @@ fn main( let iq3 = seq_id / params.rq3; let state_size = S_V * S_V; - let state_in_base = (seq_id * params.K * params.h + head_id) * state_size; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + let state_in_base = (seq_id * params.h + head_id) * state_size; let state_out_base = (seq_id * params.h + head_id) * state_size; let state_size_per_snap = state_size * params.h * params.n_seqs; - let shift = i32(params.n_tokens) - i32(params.K); var state: array<f32, S_V>; for (var i = 0u; i < S_V; i++) { @@ -128,7 +128,8 @@ fn main( attn_off += S_V * params.h; if (params.K > 1u) { - let target_slot = i32(t) - shift; + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + let target_slot = i32(params.n_tokens) - 1 - i32(t); if (target_slot >= 0 && target_slot < i32(params.K)) { let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base; for (var i = 0u; i < S_V; i++) { diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 18a5ebd2ab0..b43016c87d2 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6223,7 +6223,8 @@ struct ggml_tensor * ggml_gated_delta_net( struct ggml_tensor * v, struct ggml_tensor * g, struct ggml_tensor * beta, - struct ggml_tensor * state) { + struct ggml_tensor * state, + int64_t K) { GGML_ASSERT(ggml_is_contiguous_rows(q)); GGML_ASSERT(ggml_is_contiguous_rows(k)); GGML_ASSERT(ggml_is_contiguous_rows(v)); @@ -6247,15 +6248,18 @@ struct ggml_tensor * ggml_gated_delta_net( GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); GGML_ASSERT(beta->ne[0] == 1); - // state is a 3D tensor (S_v*S_v*H, K, n_seqs). K is the snapshot slot count. - GGML_ASSERT(state->ne[0] == S_v * S_v * H); - GGML_ASSERT(state->ne[2] == n_seqs); - GGML_ASSERT(state->ne[3] == 1); - const int64_t K = state->ne[1]; + // state holds the initial state s0 only: [S_v, S_v, H, n_seqs]. K (snapshot slot count) is an op param. + GGML_ASSERT(state->ne[0] == S_v); + GGML_ASSERT(state->ne[1] == S_v); + GGML_ASSERT(state->ne[2] == H); + GGML_ASSERT(state->ne[3] == n_seqs); + GGML_ASSERT(K >= 1); const int64_t state_rows = K * S_v * n_seqs; const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + ggml_set_op_params_i32(result, 0, (int32_t) K); + result->op = GGML_OP_GATED_DELTA_NET; result->src[0] = q; result->src[1] = k; From a512e4c5c3adf375239e83410dffb31bff8b2a7f Mon Sep 17 00:00:00 2001 From: Kevin Liu <4396kevinliu@gmail.com> Date: Thu, 11 Jun 2026 09:43:04 -0400 Subject: [PATCH 814/831] vulkan: use medium matmul tile on Asahi Linux (llama/24306) * vulkan: use medium matmul tile on Asahi Linux * vulkan: switch Apple detection to Honeykrisp driver id --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 387826b6d93..47533c2ba97 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -6202,6 +6202,17 @@ static vk_device ggml_vk_get_device(size_t idx) { break; } + // Honeykrisp driver for Asahi Linux doesn't report VK_VENDOR_ID_APPLE. + // Check for Honeykrisp driver and force same configuration as the VK_VENDOR_ID_APPLE case. + if (device->driver_id == vk::DriverId::eMesaHoneykrisp) { + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = false; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = false; + } + device->mul_mat_l_int[i] = device->mul_mat_l[i]; device->mul_mat_m_int[i] = device->mul_mat_m[i]; device->mul_mat_s_int[i] = device->mul_mat_s[i]; From 6870cfd616bd0734c3b9ebe4dbf8010e34fdeb7e Mon Sep 17 00:00:00 2001 From: Winston Ma <winstonma@ymail.com> Date: Thu, 11 Jun 2026 21:46:25 +0800 Subject: [PATCH 815/831] vulkan: add fast path for contiguous buffer transfers (llama/23973) --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 47533c2ba97..5f372404521 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -7615,8 +7615,12 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); - for (size_t i = 0; i < height; i++) { - memcpy((uint8_t *)dst->ptr + offset + i * dpitch, (const uint8_t *) src + i * spitch, width); + if (width == spitch && width == dpitch) { + memcpy((uint8_t *)dst->ptr + offset, src, width * height); + } else { + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *)dst->ptr + offset + i * dpitch, (const uint8_t *) src + i * spitch, width); + } } } else { std::lock_guard<std::recursive_mutex> guard(dst->device->mutex); @@ -7735,8 +7739,12 @@ static void ggml_vk_buffer_read_2d(vk_buffer& src, size_t offset, void * dst, si if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); - for (size_t i = 0; i < height; i++) { - memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) src->ptr + offset + i * spitch, width); + if (width == spitch && width == dpitch) { + memcpy(dst, (const uint8_t *) src->ptr + offset, width * height); + } else { + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) src->ptr + offset + i * spitch, width); + } } } else { std::lock_guard<std::recursive_mutex> guard(src->device->mutex); From b04008fcec0ac334d38ec754809bd7c2f8cc1f3d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Thu, 11 Jun 2026 19:32:38 +0300 Subject: [PATCH 816/831] ggml : bump version to 0.15.0 (ggml/1539) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 8f7cb8cdfd2..cd0e4fef978 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 14) +set(GGML_VERSION_MINOR 15) set(GGML_VERSION_PATCH 0) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") From afd559279c1a6fd484632b86ee6eee5d70ce04a2 Mon Sep 17 00:00:00 2001 From: Jeff Bolz <jbolz@nvidia.com> Date: Thu, 11 Jun 2026 13:22:17 -0500 Subject: [PATCH 817/831] vulkan: ifdef eMesaHoneykrisp (build fix) (llama/24479) Fixes build/CI after #24306. --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 5f372404521..1b1150e7731 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -6202,6 +6202,7 @@ static vk_device ggml_vk_get_device(size_t idx) { break; } +#if VK_HEADER_VERSION >= 287 // Honeykrisp driver for Asahi Linux doesn't report VK_VENDOR_ID_APPLE. // Check for Honeykrisp driver and force same configuration as the VK_VENDOR_ID_APPLE case. if (device->driver_id == vk::DriverId::eMesaHoneykrisp) { @@ -6212,6 +6213,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = false; } +#endif device->mul_mat_l_int[i] = device->mul_mat_l[i]; device->mul_mat_m_int[i] = device->mul_mat_m[i]; From 2dcfd49d59810ab7fa0672e17b96968721ed6a27 Mon Sep 17 00:00:00 2001 From: shaofeiqi <shaoqi@qti.qualcomm.com> Date: Thu, 11 Jun 2026 21:43:09 -0700 Subject: [PATCH 818/831] opencl: add q5_0/q5_1 gemm and gemv kernels for Adreno (llama/24319) * opencl: add q5_0 adreno support * opencl: add q5_1 adreno support * opencl: cosmetic fix --------- Co-authored-by: Li He <lih@qti.qualcomm.com> --- ggml/src/ggml-opencl/CMakeLists.txt | 4 + ggml/src/ggml-opencl/ggml-opencl.cpp | 729 ++++++++++++++++-- ggml/src/ggml-opencl/kernels/cvt.cl | 114 +++ .../kernels/gemm_noshuffle_q5_0_f32.cl | 131 ++++ .../kernels/gemm_noshuffle_q5_1_f32.cl | 134 ++++ .../kernels/gemv_noshuffle_q5_0_f32.cl | 291 +++++++ .../kernels/gemv_noshuffle_q5_1_f32.cl | 294 +++++++ 7 files changed, 1642 insertions(+), 55 deletions(-) create mode 100644 ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl create mode 100644 ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index cd15d573238..82ce61d72c6 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -142,6 +142,10 @@ set(GGML_OPENCL_KERNELS gemm_noshuffle_q4_0_f32 gemv_noshuffle_q4_1_f32 gemm_noshuffle_q4_1_f32 + gemv_noshuffle_q5_0_f32 + gemm_noshuffle_q5_0_f32 + gemv_noshuffle_q5_1_f32 + gemm_noshuffle_q5_1_f32 gemv_noshuffle_iq4_nl_f32 gemm_noshuffle_iq4_nl_f32 gemv_noshuffle_q8_0_f32 diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index d30579b9452..ca2002424df 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -593,6 +593,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_restore_block_q4_0_noshuffle; cl_kernel kernel_convert_block_q4_1_noshuffle; cl_kernel kernel_restore_block_q4_1_noshuffle; + cl_kernel kernel_convert_block_q5_0_noshuffle; + cl_kernel kernel_restore_block_q5_0_noshuffle; + cl_kernel kernel_convert_block_q5_1_noshuffle; + cl_kernel kernel_restore_block_q5_1_noshuffle; cl_kernel kernel_convert_block_q4_K_noshuffle; cl_kernel kernel_restore_block_q4_K_noshuffle; cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K; @@ -829,6 +833,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_gemm_noshuffle_q6_K_f32; cl_kernel kernel_gemv_noshuffle_q5_k_f32; cl_kernel kernel_gemm_noshuffle_q5_k_f32; + cl_kernel kernel_gemv_noshuffle_q5_0_f32; + cl_kernel kernel_gemm_noshuffle_q5_0_f32; + cl_kernel kernel_gemv_noshuffle_q5_1_f32; + cl_kernel kernel_gemm_noshuffle_q5_1_f32; cl_kernel kernel_gemv_noshuffle_iq4_nl_f32; cl_kernel kernel_gemm_noshuffle_iq4_nl_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS @@ -1152,6 +1160,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { CL_CHECK((backend_ctx->kernel_restore_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q5_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1", &err), err)); @@ -3065,6 +3077,80 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { GGML_LOG_CONT("."); } + // gemm_noshuffle_q5_0_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q5_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q5_0_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q5_0_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q5_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q5_0_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q5_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q5_0_f32.cl"); +#endif + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q5_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q5_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_noshuffle_q5_1_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q5_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q5_1_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q5_1_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q5_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q5_1_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q5_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q5_1_f32.cl"); +#endif + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q5_1_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q5_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gemm_noshuffle_iq4_nl_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -6107,15 +6193,16 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } #endif // GGML_OPENCL_USE_ADRENO_KERNELS - cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0; - cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0_noshuffle; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &n_blk)); - size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64) * 64, 1, 1}; + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; size_t local_work_size[] = {64, 1, 1}; cl_event evt; @@ -6124,7 +6211,39 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); tensor->extra = extra; + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + // Transpose qs as ushort + transpose_2d_as_16b(backend_ctx, extra->qs, extra->qs, size_qs, K/4, M); + // Transpose qh as uchar + transpose_2d_as_8b(backend_ctx, extra->qh, extra->qh, size_qh, K/8, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64) * 64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + return; } if (tensor->type == GGML_TYPE_Q5_1) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; @@ -6225,6 +6344,42 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } #endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_1_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->m)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + // Transpose qs as ushort + transpose_2d_as_16b(backend_ctx, extra->qs, extra->qs, size_qs, K/4, M); + // Transpose qh as uchar + transpose_2d_as_8b(backend_ctx, extra->qh, extra->qh, size_qh, K/8, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + // Transpose m as ushort + transpose_2d_as_16b(backend_ctx, extra->m, extra->m, size_m, K/32, M); + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q5_1; cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); @@ -7299,6 +7454,48 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + if (use_adreno_kernels(backend_ctx, tensor)) { + ggml_cl_buffer buf_trans_qs; + ggml_cl_buffer buf_trans_qh; + ggml_cl_buffer buf_trans_d; + ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + size_t size_qs = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2; + size_t size_qh = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(int32_t); + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + + buf_trans_qs.allocate(backend_ctx->context, size_qs); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + transpose_2d_as_16b(backend_ctx, extra->qs, buf_trans_qs.buffer, size_qs, M, K/4); + transpose_2d_as_8b(backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/8); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_0_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_qs.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_F0)); + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } #endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_int err; @@ -7362,6 +7559,54 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, CL_CHECK(clReleaseMemObject(data_device)); return; } + + if (use_adreno_kernels(backend_ctx, tensor)) { + ggml_cl_buffer buf_trans_qs; + ggml_cl_buffer buf_trans_qh; + ggml_cl_buffer buf_trans_d; + ggml_cl_buffer buf_trans_m; + ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + size_t size_qs = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2; + size_t size_qh = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(int32_t); + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + size_t size_m = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + + buf_trans_qs.allocate(backend_ctx->context, size_qs); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_trans_m.allocate(backend_ctx->context, size_m); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + // Transpose back: from col-major to row-major + transpose_2d_as_16b(backend_ctx, extra->qs, buf_trans_qs.buffer, size_qs, M, K/4); + transpose_2d_as_8b(backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/8); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); + transpose_2d_as_16b(backend_ctx, extra->m, buf_trans_m.buffer, size_m, M, K/32); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_1_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_qs.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_m.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } #endif // GGML_OPENCL_USE_ADRENO_KERNELS cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, @@ -12205,7 +12450,7 @@ static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_t #endif } -static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_mul_mat_q5_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -12218,17 +12463,17 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; + ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; - const int ne1 = dst->ne[1]; + const int ne1 = dst->ne[1]; - GGML_ASSERT(ne00 % 32 == 0); + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); cl_context context = backend_ctx->context; cl_kernel kernel; @@ -12243,17 +12488,17 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml int K = ne00; if (ne1 == 1) { - cl_mem q_img = nullptr; + cl_mem qs_img = nullptr; cl_mem b_sub_buf = nullptr; cl_mem b_img = nullptr; - // image for q - img_fmt = { CL_R, CL_UNSIGNED_INT32}; + // image for qs + img_fmt = { CL_R, CL_UNSIGNED_INT32 }; memset(&img_desc, 0, sizeof(img_desc)); img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; img_desc.image_width = M * K / 2 / 4; - img_desc.buffer = extra0_iq4_nl->q; - CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + img_desc.buffer = extra0_q5_0->qs; + CL_CHECK((qs_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); // subbuffer for activations region.origin = offset1; @@ -12268,22 +12513,23 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml img_desc.buffer = b_sub_buf; CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - kernel = backend_ctx->kernel_gemv_noshuffle_iq4_nl_f32; + kernel = backend_ctx->kernel_gemv_noshuffle_q5_0_f32; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &qs_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); size_t local_work_size[3] = {64, 4, 1}; size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(qs_img)); CL_CHECK(clReleaseMemObject(b_sub_buf)); CL_CHECK(clReleaseMemObject(b_img)); } else { @@ -12291,6 +12537,7 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml cl_mem b_sub_buf_trans = nullptr; cl_mem b_img = nullptr; cl_mem b_img_trans = nullptr; + cl_mem d_sub_buf = nullptr; // subbuffer for activations region.origin = offset1; @@ -12326,6 +12573,11 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml img_desc.buffer = b_sub_buf_trans; CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + // subbuffer for output + region.origin = extrad->offset; + region.size = M * N * sizeof(float); + CL_CHECK((d_sub_buf = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + // transpose activations int height_B = N/4; if (height_B == 0) { @@ -12346,14 +12598,14 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); // gemm - kernel = backend_ctx->kernel_gemm_noshuffle_iq4_nl_f32; + kernel = backend_ctx->kernel_gemm_noshuffle_q5_0_f32; int padded_N = N + padding; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &d_sub_buf)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &padded_N)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); @@ -12368,6 +12620,7 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); CL_CHECK(clReleaseMemObject(b_img)); CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(d_sub_buf)); } #else GGML_UNUSED(backend); @@ -12377,7 +12630,7 @@ static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml #endif } -static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_mul_mat_q5_1_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -12386,34 +12639,21 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - GGML_ASSERT(src0->type == GGML_TYPE_Q8_0); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - GGML_ASSERT(src1->view_offs == 0); - GGML_ASSERT(dst->view_offs == 0); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - - const int ne10 = src1->ne[0]; - const int ne12 = src1->ne[2]; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; + const int ne1 = dst->ne[1]; - GGML_ASSERT(ne00 == ne10); - GGML_ASSERT((ne00 % 32) == 0); - GGML_ASSERT(ne0 == ne01); + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); cl_context context = backend_ctx->context; cl_kernel kernel; @@ -12428,17 +12668,384 @@ static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_t int K = ne00; if (ne1 == 1) { - cl_mem q_img = nullptr; + cl_mem qs_img = nullptr; cl_mem b_sub_buf = nullptr; cl_mem b_img = nullptr; - // image for q - img_fmt = { CL_R, CL_UNSIGNED_INT32}; + // image for qs + img_fmt = { CL_R, CL_UNSIGNED_INT32 }; memset(&img_desc, 0, sizeof(img_desc)); img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc.image_width = M * K / 4; - img_desc.buffer = extra0_q8_0->q; - CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q5_1->qs; + CL_CHECK((qs_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q5_1_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &qs_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(qs_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + cl_mem d_sub_buf = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for output + region.origin = extrad->offset; + region.size = M * N * sizeof(float); + CL_CHECK((d_sub_buf = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q5_1_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &d_sub_buf)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(d_sub_buf)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + +static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % 32 == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_iq4_nl->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_iq4_nl_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_iq4_nl_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + +static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + GGML_ASSERT(src0->type == GGML_TYPE_Q8_0); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + GGML_ASSERT(src1->view_offs == 0); + GGML_ASSERT(dst->view_offs == 0); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + + const int ne10 = src1->ne[0]; + const int ne12 = src1->ne[2]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 == ne10); + GGML_ASSERT((ne00 % 32) == 0); + GGML_ASSERT(ne0 == ne01); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 4; + img_desc.buffer = extra0_q8_0->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); // create a sub_buffer for B region.origin = offset1; @@ -13243,6 +13850,18 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } + // q5_0 x fp32 + if (src0t == GGML_TYPE_Q5_0 && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q5_0_f32_adreno(backend, src0, src1, dst); + return; + } + + // q5_1 x fp32 + if (src0t == GGML_TYPE_Q5_1 && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q5_1_f32_adreno(backend, src0, src1, dst); + return; + } + // iq4_nl x fp32 if (src0t == GGML_TYPE_IQ4_NL && src1t == GGML_TYPE_F32) { ggml_cl_mul_mat_iq4_nl_f32_adreno(backend, src0, src1, dst); diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index d07f0a1a025..226b127ab3b 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -584,6 +584,60 @@ kernel void kernel_restore_block_q5_0( } } +kernel void kernel_convert_block_q5_0_noshuffle( + global struct block_q5_0 * src0, + global uchar * dst_q, + global uint * dst_qh, + global half * dst_d +) { + global struct block_q5_0 * b = (global struct block_q5_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK5_0/2*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_0/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + q[i + QK5_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + +#ifdef ADRENO_GPU + if (get_global_id(0) == 65536*4096) { + printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0)); + } +#endif + } +} + +kernel void kernel_restore_block_q5_0_noshuffle( + global uchar * src_q, + global uint * src_qh, + global half * src_d, + global struct block_q5_0 * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_0 * b = (global struct block_q5_0 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK5_0/2*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + *((global uint *)(b->qh)) = *qh; + + for (int i = 0; i < QK5_0/4; ++i) { + uchar x0 = q[i + 0 ]; + uchar x1 = q[i + QK5_0/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} + kernel void kernel_convert_block_q5_0_trans4_ns( __global struct block_q5_0 * src0, __global uint * dst_qs, @@ -736,6 +790,66 @@ kernel void kernel_restore_block_q5_1( } } +kernel void kernel_convert_block_q5_1_noshuffle( + global struct block_q5_1 * src0, + global uchar * dst_q, + global uint * dst_qh, + global half * dst_d, + global half * dst_m +) { + global struct block_q5_1 * b = (global struct block_q5_1 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK5_1/2*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_1/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + q[i + QK5_1/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + +#ifdef ADRENO_GPU + if (get_global_id(0) == 65536*4096) { + printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0)); + } +#endif + } +} + +kernel void kernel_restore_block_q5_1_noshuffle( + global uchar * src_q, + global uint * src_qh, + global half * src_d, + global half * src_m, + global struct block_q5_1 * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_1 * b = (global struct block_q5_1 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK5_1/2*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + *((global uint *)(b->qh)) = *qh; + + for (int i = 0; i < QK5_1/4; ++i) { + uchar x0 = q[i + 0 ]; + uchar x1 = q[i + QK5_1/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} + kernel void kernel_convert_block_q5_1_trans4_ns( __global struct block_q5_1 * src0, __global uint * dst_qs, diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl new file mode 100644 index 00000000000..1d6bd48005e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl @@ -0,0 +1,131 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_q5_0_f32( + global const ushort * src0_qs, // quantized A + global const uchar * src0_qh, // 5th bits + global const half * src0_d, // A scales + __read_only image1d_buffer_t src1, // B (1d image) + global float * dst, // C + int m, // M + int n, // N with padding + int k, // K + int n_no_padding // N without padding +) { + + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * weight_ptr = src0_qs + gx_2; + global const uchar * qh_ptr = src0_qh + gx_2; + global const half * scale_ptr = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + + B.s0123 = read_imageh(src1, gy*2 + i*n_4); + B.s4567 = read_imageh(src1, gy*2 + i*n_4 + 1); + + ushort4 bits4 = vload4(0, weight_ptr + (i >> 2)*m); + uchar4 bits1 = vload4(0, qh_ptr + (i >> 3)*m); + uchar4 qh = bits1 >> (uchar4)(i & 4); + + half4 scale = vload4(0, scale_ptr + (i >> 5)*m); + + // j=0 + dequantized_weights.s0 = (convert_half((bits4.s0 & 0x000F) | ((qh.s0 & 0x01) << 4)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half((bits4.s1 & 0x000F) | ((qh.s1 & 0x01) << 4)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half((bits4.s2 & 0x000F) | ((qh.s2 & 0x01) << 4)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half((bits4.s3 & 0x000F) | ((qh.s3 & 0x01) << 4)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*n_4 + 1); + dequantized_weights.s0 = (convert_half(((bits4.s0 & 0x00F0) >> 4) | ((qh.s0 & 0x02) << 3)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half(((bits4.s1 & 0x00F0) >> 4) | ((qh.s1 & 0x02) << 3)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half(((bits4.s2 & 0x00F0) >> 4) | ((qh.s2 & 0x02) << 3)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half(((bits4.s3 & 0x00F0) >> 4) | ((qh.s3 & 0x02) << 3)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*n_4 + 1); + dequantized_weights.s0 = (convert_half(((bits4.s0 & 0x0F00) >> 8) | ((qh.s0 & 0x04) << 2)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half(((bits4.s1 & 0x0F00) >> 8) | ((qh.s1 & 0x04) << 2)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half(((bits4.s2 & 0x0F00) >> 8) | ((qh.s2 & 0x04) << 2)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half(((bits4.s3 & 0x0F00) >> 8) | ((qh.s3 & 0x04) << 2)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*n_4 + 1); + dequantized_weights.s0 = (convert_half(((bits4.s0 & 0xF000) >> 12) | ((qh.s0 & 0x08) << 1)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half(((bits4.s1 & 0xF000) >> 12) | ((qh.s1 & 0x08) << 1)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half(((bits4.s2 & 0xF000) >> 12) | ((qh.s2 & 0x08) << 1)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half(((bits4.s3 & 0xF000) >> 12) | ((qh.s3 & 0x08) << 1)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl new file mode 100644 index 00000000000..94b4ef6cacc --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl @@ -0,0 +1,134 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_q5_1_f32( + global const ushort * src0_qs, // quantized A + global const uchar * src0_qh, // 5th bits + global const half * src0_d, // A scales + global const half * src0_m, // A mins + __read_only image1d_buffer_t src1, // B (1d image) + global float * dst, // C + int m, // M + int n, // N with padding + int k, // K + int n_no_padding // N without padding +) { + + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * weight_ptr = src0_qs + gx_2; + global const uchar * qh_ptr = src0_qh + gx_2; + global const half * scale_ptr = src0_d + gx_2; + global const half * min_ptr = src0_m + gx_2; + + for (int i = 0; i < k; i += 4) { + + B.s0123 = read_imageh(src1, gy*2 + i*n_4); + B.s4567 = read_imageh(src1, gy*2 + i*n_4 + 1); + + ushort4 bits4 = vload4(0, weight_ptr + (i >> 2)*m); + uchar4 bits1 = vload4(0, qh_ptr + (i >> 3)*m); + uchar4 qh = bits1 >> (uchar4)(i & 4); + + half4 scale = vload4(0, scale_ptr + (i >> 5)*m); + half4 minv = vload4(0, min_ptr + (i >> 5)*m); + + // j=0 + dequantized_weights.s0 = convert_half((bits4.s0 & 0x000F) | ((qh.s0 & 0x01) << 4)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half((bits4.s1 & 0x000F) | ((qh.s1 & 0x01) << 4)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half((bits4.s2 & 0x000F) | ((qh.s2 & 0x01) << 4)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half((bits4.s3 & 0x000F) | ((qh.s3 & 0x01) << 4)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*n_4 + 1); + dequantized_weights.s0 = convert_half(((bits4.s0 & 0x00F0) >> 4) | ((qh.s0 & 0x02) << 3)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half(((bits4.s1 & 0x00F0) >> 4) | ((qh.s1 & 0x02) << 3)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half(((bits4.s2 & 0x00F0) >> 4) | ((qh.s2 & 0x02) << 3)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half(((bits4.s3 & 0x00F0) >> 4) | ((qh.s3 & 0x02) << 3)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*n_4 + 1); + dequantized_weights.s0 = convert_half(((bits4.s0 & 0x0F00) >> 8) | ((qh.s0 & 0x04) << 2)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half(((bits4.s1 & 0x0F00) >> 8) | ((qh.s1 & 0x04) << 2)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half(((bits4.s2 & 0x0F00) >> 8) | ((qh.s2 & 0x04) << 2)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half(((bits4.s3 & 0x0F00) >> 8) | ((qh.s3 & 0x04) << 2)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*n_4 + 1); + dequantized_weights.s0 = convert_half(((bits4.s0 & 0xF000) >> 12) | ((qh.s0 & 0x08) << 1)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half(((bits4.s1 & 0xF000) >> 12) | ((qh.s1 & 0x08) << 1)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half(((bits4.s2 & 0xF000) >> 12) | ((qh.s2 & 0x08) << 1)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half(((bits4.s3 & 0xF000) >> 12) | ((qh.s3 & 0x08) << 1)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl new file mode 100644 index 00000000000..c228f717a94 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl @@ -0,0 +1,291 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK5_0 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_hi(total_sums, bits4, bits1, scale, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_lo(total_sums, bits4, bits1, scale, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_hi(total_sums, bits4, bits1, scale, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_lo(total_sums, bits4, bits1, scale, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +__kernel void kernel_gemv_noshuffle_q5_0_f32( + __read_only image1d_buffer_t src0_qs, // quantized A + global ushort * src0_qh, // 5th bits + global half2 * src0_d, // A scales + __read_only image1d_buffer_t src1, // B activations + global float * dst, + ulong offsetd, + int ne00, // K + int ne01) // M +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + private uint4 regA; + private half2 regS; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / QK5_0); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; + + ushort4 qh_raw; + qh_raw.s0 = src0_qh[gid + (4*k + 0) * LINE_STRIDE_A]; + qh_raw.s1 = src0_qh[gid + (4*k + 1) * LINE_STRIDE_A]; + qh_raw.s2 = src0_qh[gid + (4*k + 2) * LINE_STRIDE_A]; + qh_raw.s3 = src0_qh[gid + (4*k + 3) * LINE_STRIDE_A]; + + uchar8 raw = as_uchar8(qh_raw); + uchar8 qh_bytes = (uchar8)(raw.s0, raw.s2, raw.s4, raw.s6, + raw.s1, raw.s3, raw.s5, raw.s7); + + // Load activations + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; + +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#else + dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#else + dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl new file mode 100644 index 00000000000..daf1308ea4b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl @@ -0,0 +1,294 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK5_1 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_hi(total_sums, bits4, bits1, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_hi(total_sums, bits4, bits1, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +__kernel void kernel_gemv_noshuffle_q5_1_f32( + __read_only image1d_buffer_t src0_qs, // quantized A + global ushort * src0_qh, // 5th bits + global half2 * src0_d, // A scales + global half2 * src0_m, // A mins + __read_only image1d_buffer_t src1, // B activations + global float * dst, + ulong offsetd, + int ne00, // K + int ne01) // M +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + __private uint4 regA; + __private half2 regS; + __private half2 regM; + __private float8 regB; + + __private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / QK5_1); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; + regM = src0_m[gid + k * LINE_STRIDE_A]; + + ushort4 qh_raw; + qh_raw.s0 = src0_qh[gid + (4*k + 0) * LINE_STRIDE_A]; + qh_raw.s1 = src0_qh[gid + (4*k + 1) * LINE_STRIDE_A]; + qh_raw.s2 = src0_qh[gid + (4*k + 2) * LINE_STRIDE_A]; + qh_raw.s3 = src0_qh[gid + (4*k + 3) * LINE_STRIDE_A]; + + uchar8 raw = as_uchar8(qh_raw); + uchar8 qh_bytes = (uchar8)(raw.s0, raw.s2, raw.s4, raw.s6, + raw.s1, raw.s3, raw.s5, raw.s7); + + // Load activations + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; + +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#else + dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#else + dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} From 882736f8867b3b6703f9331e46b8f8da51ea5cee Mon Sep 17 00:00:00 2001 From: ZihaoMu <zmu@amd.com> Date: Fri, 12 Jun 2026 14:32:44 +0800 Subject: [PATCH 819/831] ggml: support concat for scalar types at cuda backend (llama/24011) * cuda: support concat for scalar types * Update concat.cu * fix metal ci issue --- ggml/src/ggml-cuda/concat.cu | 142 ++++++++++++++---------- ggml/src/ggml-cuda/ggml-cuda.cu | 10 +- ggml/src/ggml-metal/ggml-metal-device.m | 11 +- 3 files changed, 101 insertions(+), 62 deletions(-) diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index adba4d522a4..8d557092b2b 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -1,16 +1,18 @@ #include "concat.cuh" +#include <stdint.h> + // contiguous kernels -template <int dim> -static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont(const float * x, - const float * y, - float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne0, - int64_t ne1, - int64_t ne2) { +template <typename T, int dim> +static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_cont(const T * x, + const T * y, + T * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2) { static_assert(dim >= 0 && dim <= 2, "dim must be in [0, 2]"); const int64_t n = ne0 * ne1 * ne2; @@ -50,37 +52,37 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_f32_cont } } -static void concat_f32_cuda(const float * x, - const float * y, - float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - int64_t ne0, - int64_t ne1, - int64_t ne2, - int dim, - cudaStream_t stream) { +template <typename T> +static void concat_cont_cuda(const T * x, + const T * y, + T * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int dim, + cudaStream_t stream) { const int64_t n = ne0 * ne1 * ne2; const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; if (dim == 0) { const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream); - ggml_cuda_kernel_launch(concat_f32_cont<0>, launch_params,x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + ggml_cuda_kernel_launch(concat_cont<T, 0>, launch_params, x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } if (dim == 1) { - concat_f32_cont<1> - <<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + concat_cont<T, 1><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } - concat_f32_cont<2><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); + concat_cont<T, 2><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); } // non-contiguous kernel (slow) -template <int dim> +template <typename T, int dim> static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) - concat_f32_non_cont( + concat_non_cont( const char * src0, const char * src1, char * dst, @@ -107,61 +109,49 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) uint64_t nb0, uint64_t nb1, uint64_t nb2, - uint64_t nb3){ + uint64_t nb3) { static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]"); const int64_t i3 = blockIdx.z; const int64_t i2 = blockIdx.y; const int64_t i1 = blockIdx.x; - const float * x; + const T * x; for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + x = (const T *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); } else { if constexpr (dim == 0) { - x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10); + x = (const T *)(src1 + i3*nb13 + i2*nb12 + i1*nb11 + (i0 - ne00)*nb10); } else if constexpr (dim == 1) { - x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10); + x = (const T *)(src1 + i3*nb13 + i2*nb12 + (i1 - ne01)*nb11 + i0*nb10); } else if constexpr (dim == 2) { - x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10); + x = (const T *)(src1 + i3*nb13 + (i2 - ne02)*nb12 + i1*nb11 + i0*nb10); } else if constexpr (dim == 3) { - x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10); + x = (const T *)(src1 + (i3 - ne03)*nb13 + i2*nb12 + i1*nb11 + i0*nb10); } } - float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + T * y = (T *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); *y = *x; } } - -void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - - cudaStream_t stream = ctx.stream(); - - const int32_t dim = ((int32_t *) dst->op_params)[0]; - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - +template <typename T> +static void concat_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, int dim, cudaStream_t stream) { if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - - float * dst_d = (float *)dst->data; + const T * src0_d = (const T *) src0->data; + const T * src1_d = (const T *) src1->data; + T * dst_d = (T *) dst->data; if (dim != 3) { - for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_cuda( - src0_d + i3 * (src0->nb[3] / 4), - src1_d + i3 * (src1->nb[3] / 4), - dst_d + i3 * ( dst->nb[3] / 4), + for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) { + concat_cont_cuda( + src0_d + i3*(src0->nb[3] / sizeof(T)), + src1_d + i3*(src1->nb[3] / sizeof(T)), + dst_d + i3*( dst->nb[3] / sizeof(T)), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); } @@ -169,13 +159,13 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const size_t size0 = ggml_nbytes(src0); const size_t size1 = ggml_nbytes(src1); - CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync((char *) dst->data, src0->data, size0, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync((char *) dst->data + size0, src1->data, size1, cudaMemcpyDeviceToDevice, stream)); } } else { dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); auto launch_kernel = [&](auto dim) { - concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>( + concat_non_cont<T, dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>( (const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], @@ -203,3 +193,35 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } } } + +void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + cudaStream_t stream = ctx.stream(); + + const int32_t dim = ((int32_t *) dst->op_params)[0]; + + GGML_ASSERT(src0->type == src1->type); + GGML_ASSERT(dst->type == src0->type); + GGML_ASSERT(!ggml_is_quantized(src0->type)); + GGML_ASSERT(ggml_blck_size(src0->type) == 1); + + switch (ggml_type_size(src0->type)) { + case 1: + concat_cuda<uint8_t>(src0, src1, dst, dim, stream); + break; + case 2: + concat_cuda<uint16_t>(src0, src1, dst, dim, stream); + break; + case 4: + concat_cuda<uint32_t>(src0, src1, dst, dim, stream); + break; + case 8: + concat_cuda<uint64_t>(src0, src1, dst, dim, stream); + break; + default: + GGML_ABORT("Unsupported type size: %zu", ggml_type_size(src0->type)); + break; + } +} diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index e779a9be9e9..61041bdc16b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -5345,7 +5345,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; - return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + ggml_type src1_type = op->src[1]->type; + return src0_type == src1_type && + src0_type == op->type && + !ggml_is_quantized(src0_type) && + ggml_blck_size(src0_type) == 1 && + (ggml_type_size(src0_type) == 1 || + ggml_type_size(src0_type) == 2 || + ggml_type_size(src0_type) == 4 || + ggml_type_size(src0_type) == 8); } break; case GGML_OP_CONV_TRANSPOSE_1D: { diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 05d7f43051b..d583bd6efc0 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1120,8 +1120,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_VIEW: case GGML_OP_TRANSPOSE: case GGML_OP_PERMUTE: - case GGML_OP_CONCAT: return true; + case GGML_OP_CONCAT: + { + // kernel_concat copies one float-sized value per element. + // Other scalar types need a type-generic copy kernel first. + const enum ggml_type src0_type = op->src[0]->type; + const enum ggml_type src1_type = op->src[1]->type; + return src0_type == src1_type && + src0_type == op->type && + (src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_I32); + } case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: From f35f47b5d242484ef405c82ac1c40ae61e8e582c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Fri, 12 Jun 2026 15:32:00 +0300 Subject: [PATCH 820/831] ggml : bump version to 0.15.1 (ggml/1541) --- ggml/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index cd0e4fef978..249ed3da290 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 15) -set(GGML_VERSION_PATCH 0) +set(GGML_VERSION_PATCH 1) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") From 0a3fa9ca17960dc2419566f2c03ff8913edfbe17 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Mon, 15 Jun 2026 09:13:43 +0300 Subject: [PATCH 821/831] sync : ggml --- scripts/sync-ggml.last | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 6e1bf3a1f4b..87d353ef452 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -7142aa6bf9fcaeec0fef8d80fcd90afe4268adf1 +3af5f5760e19a96427f5f7a93b79cbdf3d4b265b From 0ec0845110dc934911dc48e8c5beb5ad3189b3f3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov <ggerganov@gmail.com> Date: Mon, 15 Jun 2026 09:15:48 +0300 Subject: [PATCH 822/831] talk-llama : sync llama.cpp --- examples/talk-llama/llama-arch.cpp | 90 ++--- examples/talk-llama/llama-arch.h | 11 + examples/talk-llama/llama-context.cpp | 113 +++++- examples/talk-llama/llama-context.h | 11 + examples/talk-llama/llama-cparams.h | 3 + examples/talk-llama/llama-ext.h | 16 + examples/talk-llama/llama-graph.cpp | 38 ++- examples/talk-llama/llama-graph.h | 14 +- examples/talk-llama/llama-hparams.h | 1 + examples/talk-llama/llama-model-loader.cpp | 1 + examples/talk-llama/llama-model.cpp | 19 +- examples/talk-llama/llama-model.h | 7 + examples/talk-llama/llama-vocab.cpp | 35 +- examples/talk-llama/llama-vocab.h | 8 +- examples/talk-llama/models/delta-net-base.cpp | 41 ++- examples/talk-llama/models/eagle3.cpp | 323 ++++++++++++++++++ .../talk-llama/models/gemma4-assistant.cpp | 3 + examples/talk-llama/models/gemma4.cpp | 2 + examples/talk-llama/models/llama.cpp | 2 + examples/talk-llama/models/models.h | 17 +- examples/talk-llama/models/openai-moe.cpp | 2 + examples/talk-llama/models/plamo2.cpp | 6 +- examples/talk-llama/models/qwen3.cpp | 2 + examples/talk-llama/models/qwen35.cpp | 2 +- examples/talk-llama/models/qwen3moe.cpp | 2 + 25 files changed, 672 insertions(+), 97 deletions(-) create mode 100644 examples/talk-llama/models/eagle3.cpp diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index 6a5d5f8d2ac..9f93d5bc7ce 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -3,7 +3,6 @@ #include "llama-impl.h" #include <map> -#include <set> #include <vector> static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { @@ -128,6 +127,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, + { LLM_ARCH_EAGLE3, "eagle3" }, { LLM_ARCH_MISTRAL4, "mistral4" }, { LLM_ARCH_PADDLEOCR, "paddleocr" }, { LLM_ARCH_MIMO2, "mimo2" }, @@ -292,46 +292,51 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, + { LLM_KV_TARGET_LAYERS, "%s.target_layers" }, + { LLM_KV_TARGET_HIDDEN_SIZE, "%s.target_hidden_size" }, + { LLM_KV_NORM_BEFORE_RESIDUAL, "%s.norm_before_residual" }, + { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, // sentence-transformers dense modules feature dims { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" }, - { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, - { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, - { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, - - { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, - { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, - { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, - { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, - { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, - { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, - { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, - { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, - { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, - { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, - { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, - { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, - { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, - { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, - { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, - { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, - { LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" }, - { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, - { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, - { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, - { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, - { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, - { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, - { LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, "tokenizer.ggml.normalizer.lowercase" }, - { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, - { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, - { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, - { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, - { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, - { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, - { LLM_KV_TOKENIZER_SUPPRESS_TOKENS, "tokenizer.ggml.suppress_tokens" }, + { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, + { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, + { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, + + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, + { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, + { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, + { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, + { LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" }, + { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, + { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, + { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, + { LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, "tokenizer.ggml.normalizer.lowercase" }, + { LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS, "tokenizer.ggml.normalizer.strip_accents" }, + { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, + { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, + { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, + { LLM_KV_TOKENIZER_SUPPRESS_TOKENS, "tokenizer.ggml.suppress_tokens" }, { LLM_KV_ADAPTER_TYPE, "adapter.type" }, { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, @@ -559,6 +564,10 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = { { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" }, { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" }, { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, + { LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" }, + { LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" }, + { LLM_TENSOR_FC, "fc" }, + { LLM_TENSOR_D2T, "d2t" }, }; // declare information about the model weight tensors: @@ -783,6 +792,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { // latent projections feed ggml_mul_mat, the buft probe must use MUL_MAT to keep them on GPU {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_MASKED_EMBD_CENTROIDS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}}, + {LLM_TENSOR_MASKED_EMBD_ORDERING, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}}, + // eagle3 + {LLM_TENSOR_FC, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_D2T, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 03b1a265d67..c5245fb5891 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -141,6 +141,7 @@ enum llm_arch { LLM_ARCH_KIMI_LINEAR, LLM_ARCH_TALKIE, LLM_ARCH_MELLUM, + LLM_ARCH_EAGLE3, LLM_ARCH_UNKNOWN, }; @@ -314,6 +315,7 @@ enum llm_kv { LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_CHAT_TEMPLATE, LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, + LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS, LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_MID_ID, @@ -336,6 +338,10 @@ enum llm_kv { LLM_KV_CLASSIFIER_OUTPUT_LABELS, + LLM_KV_TARGET_LAYERS, + LLM_KV_TARGET_HIDDEN_SIZE, + LLM_KV_NORM_BEFORE_RESIDUAL, + LLM_KV_SHORTCONV_L_CACHE, LLM_KV_XIELU_ALPHA_N, @@ -566,8 +572,13 @@ enum llm_tensor { LLM_TENSOR_NEXTN_HNORM, LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + LLM_TENSOR_MASKED_EMBD_CENTROIDS, + LLM_TENSOR_MASKED_EMBD_ORDERING, + LLM_TENSOR_FC, + LLM_TENSOR_D2T, }; + enum llm_tensor_layer { LLM_TENSOR_LAYER_INPUT, LLM_TENSOR_LAYER_REPEATING, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 9a40c4366af..168dbabd766 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -71,6 +71,9 @@ llama_context::llama_context( cparams.no_perf = params.no_perf; cparams.warmup = false; + cparams.embeddings_layer_inp.resize(hparams.n_layer(), false); + embd_layer_inp.resize(hparams.n_layer()); + cparams.ctx_type = params.ctx_type; cparams.pooling_type = params.pooling_type; @@ -91,12 +94,21 @@ llama_context::llama_context( if (model.arch == LLM_ARCH_GEMMA4_ASSISTANT) { if (params.ctx_other == nullptr) { // TODO: change from runtime_error to llama_exception to avoid printing error message - throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this is normal during memory fitting)"); + throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this warning is normal during memory fitting)"); } cparams.ctx_other = params.ctx_other; } + if (model.arch == LLM_ARCH_EAGLE3) { + if (model.tok_embd == nullptr || model.output == nullptr) { + if (params.ctx_other == nullptr) { + throw std::runtime_error("EAGLE3 requires ctx_other to be set (this warning is normal during memory fitting)"); + } + cparams.ctx_other = params.ctx_other; + } + } + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. @@ -194,7 +206,7 @@ llama_context::llama_context( cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); - cparams.n_outputs_max = params.n_outputs_max == 0 ? cparams.n_batch : params.n_outputs_max; + cparams.n_outputs_max = params.n_outputs_max == 0 || llama_model_has_encoder(&model) ? cparams.n_batch : params.n_outputs_max; cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; @@ -938,6 +950,14 @@ float * llama_context::get_embeddings_nextn_ith(int32_t i) { } } +float * llama_context::get_embeddings_layer_inp(uint32_t lid) { + output_reorder(); + + GGML_ASSERT(lid < embd_layer_inp.size() && embd_layer_inp[lid].has_data()); + + return embd_layer_inp[lid].data; +} + llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); @@ -1125,6 +1145,17 @@ void llama_context::set_embeddings_nextn(bool value, bool masked) { cparams.embeddings_nextn_masked = masked; } +void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) { + LLAMA_LOG_DEBUG("%s: lid = %d, enable = %d\n", __func__, lid, enable); + + GGML_ASSERT(lid < model.hparams.n_layer()); + + cparams.embeddings_layer_inp[lid] = enable; + + // note: without this reserve, the draft acceptance drops to zero. not sure why - this is unexpected + sched_need_reserve = true; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -1350,7 +1381,8 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd_inp(); + // eagle3/DFlash: features as encoder input, and non-draft paths fall back to model's input dim + const int64_t n_embd = hparams.n_embd_inp(); const int64_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 @@ -1925,6 +1957,8 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + extract_layer_inputs(res, n_tokens_prev, ubatch.n_tokens); + // extract nextn embeddings before // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. { @@ -2029,6 +2063,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); + const auto n_embd = hparams.n_embd; const auto n_embd_out = hparams.n_embd_out(); bool has_logits = true; @@ -2041,9 +2076,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { has_embd = true; } - size_t backend_float_count = 0; size_t backend_token_count = 0; + size_t embd_layer_inp_float_count = 0; logits.size = has_logits ? n_vocab*n_outputs_max : 0; embd.size = has_embd ? n_embd_out*n_outputs_max : 0; @@ -2055,6 +2090,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd_nextn.size = (size_t) n_embd_out * n_batch; } + for (bool enabled : cparams.embeddings_layer_inp) { + if (enabled) { + embd_layer_inp_float_count += (size_t) n_embd * n_batch; + } + } + // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); if (has_sampling) { @@ -2069,8 +2110,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits.size + embd.size + embd_nextn.size + backend_float_count) * sizeof(float) + - ( backend_token_count) * sizeof(llama_token); + (logits.size + embd.size + embd_nextn.size + embd_layer_inp_float_count + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -2087,6 +2128,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { logits.data = nullptr; embd.data = nullptr; embd_nextn.data = nullptr; + for (auto & layer_inp : embd_layer_inp) { + layer_inp = {nullptr, 0}; + } } auto * buft = ggml_backend_cpu_buffer_type(); @@ -2118,6 +2162,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd_nextn = has_embd_nextn ? buffer_view<float>{(float *) (base + offset), embd_nextn.size} : buffer_view<float>{nullptr, 0}; offset += embd_nextn.size * sizeof(float); + for (uint32_t il = 0; il < embd_layer_inp.size(); ++il) { + if (cparams.embeddings_layer_inp[il]) { + embd_layer_inp[il] = buffer_view<float>{(float *) (base + offset), (size_t) n_embd * n_batch}; + offset += embd_layer_inp[il].size * sizeof(float); + } else { + embd_layer_inp[il] = buffer_view<float>{nullptr, 0}; + } + } + if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; offset += sampling.logits.size * sizeof(float); @@ -2164,6 +2217,34 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { return n_outputs_max; } +void llama_context::extract_layer_inputs(const llm_graph_result * res, size_t token_offset, size_t n_tokens) { + for (uint32_t il = 0; il < cparams.embeddings_layer_inp.size(); ++il) { + if (!cparams.embeddings_layer_inp[il]) { + continue; + } + if (!embd_layer_inp[il].has_data()) { + GGML_ABORT("output layer input buffer not allocated"); + } + ggml_tensor * t = res->get_layer_inp((int) il); + if (!t) { + GGML_ABORT("layer input tensor not found"); + } + + const size_t nbytes = ggml_nbytes(t); + const size_t nfloats = nbytes / sizeof(float); + GGML_ASSERT(n_tokens > 0); + GGML_ASSERT(nfloats % n_tokens == 0); + + const size_t row_floats = nfloats / n_tokens; + const size_t dst_offset = token_offset * row_floats; + GGML_ASSERT(dst_offset + nfloats <= embd_layer_inp[il].size); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t); + GGML_ASSERT(backend != nullptr); + ggml_backend_tensor_get_async(backend, t, embd_layer_inp[il].data + dst_offset, 0, nbytes); + } +} + void llama_context::output_reorder() { const uint64_t n_vocab = model.vocab.n_tokens(); const uint64_t n_embd = model.hparams.n_embd; @@ -2190,6 +2271,16 @@ void llama_context::output_reorder() { } } + if (embd_layer_inp.size() > 0) { + for (int lid = 0; lid < (int) embd_layer_inp.size(); ++lid) { + if (embd_layer_inp[lid].size > 0) { + for (uint64_t k = 0; k < n_embd; ++k) { + std::swap(embd_layer_inp[lid].data[i0*n_embd + k], embd_layer_inp[lid].data[i1*n_embd + k]); + } + } + } + } + if (!sampling.samplers.empty()) { assert(sampling.logits.size > 0); assert(sampling.probs.size > 0); @@ -3604,6 +3695,10 @@ void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) { ctx->set_embeddings_nextn(value, masked); } +void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool value) { + ctx->set_embeddings_layer_inp(lid, value); +} + llama_memory_t llama_get_memory(const struct llama_context * ctx) { if (!ctx) { return nullptr; @@ -3624,6 +3719,12 @@ float * llama_get_embeddings_nextn_ith(llama_context * ctx, int32_t i) { return ctx->get_embeddings_nextn_ith(i); } +float * llama_get_embeddings_layer_inp(llama_context * ctx, uint32_t lid) { + ctx->synchronize(); + + return ctx->get_embeddings_layer_inp(lid); +} + bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { return ctx->set_sampler(seq_id, smpl); } diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 6f8f59a22a3..853052be2ca 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -88,6 +88,8 @@ struct llama_context { float * get_embeddings_nextn(); float * get_embeddings_nextn_ith(int32_t i); + float * get_embeddings_layer_inp(uint32_t lid); + llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -112,6 +114,7 @@ struct llama_context { void set_embeddings (bool value); void set_embeddings_nextn(bool value, bool masked); + void set_embeddings_layer_inp(uint32_t lid, bool enable); void set_causal_attn(bool value); void set_warmup(bool value); @@ -226,6 +229,10 @@ struct llama_context { // map the output row index `i` to batch index int64_t output_resolve_row(int32_t i) const; + // async-copy enabled layer-input tensors (per cparams.output_layer_inp) + // from backend into host-side embd_layer_inp buffers + void extract_layer_inputs(const llm_graph_result * res, size_t token_offset, size_t n_tokens); + // // graph // @@ -288,6 +295,10 @@ struct llama_context { // sets llm_graph_result::t_h_nextn buffer_view<float> embd_nextn = {nullptr, 0}; + // host buffers for output layer input embeddings, per layer + // populated when cparams.output_layer_inp[il] is true + std::vector<buffer_view<float>> embd_layer_inp; + struct sampling_info { // !samplers.empty() to check if any samplers are active std::map<llama_seq_id, llama_sampler *> samplers; diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 8a35d389ef4..2b109f909c0 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -3,6 +3,7 @@ #include "llama.h" #include <cstdint> +#include <vector> #define LLAMA_MAX_SEQ 256 @@ -44,6 +45,8 @@ struct llama_cparams { bool kv_unified; bool pipeline_parallel; + std::vector<bool> embeddings_layer_inp; // [n_layer()] extract input embeddings for layer + enum llama_context_type ctx_type; enum llama_pooling_type pooling_type; diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h index bd74544129b..b744af52864 100644 --- a/examples/talk-llama/llama-ext.h +++ b/examples/talk-llama/llama-ext.h @@ -101,4 +101,20 @@ LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); // LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i); +// Set whether the context outputs the input embeddings of a specific layer +LLAMA_API void llama_set_embeddings_layer_inp(struct llama_context * ctx, uint32_t lid, bool value); + +// mirrors: +// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_layer_inp(struct llama_context * ctx, uint32_t lid); + LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx); + +// +// model/context data extraction +// + +// returns pointer to the target-model layer indices +LLAMA_API const int32_t * llama_model_target_layer_ids (const struct llama_model * model); +// returns the number of extracted layers from target model +LLAMA_API uint32_t llama_model_target_layer_ids_n(const struct llama_model * model); diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index da7a9295561..7468bd9b79e 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -567,7 +567,10 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); } - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + // the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live + if (self_kq_mask && self_kq_mask->buffer) { + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } // swa tensors may not be allocated if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { @@ -575,7 +578,9 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); } - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + if (self_kq_mask_swa && self_kq_mask_swa->buffer) { + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + } if (self_k_rot) { mctx->get_base()->set_input_k_rot(self_k_rot); @@ -607,7 +612,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there } - res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + if (self_kq_mask && self_kq_mask->buffer) { + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + } // swa tensors may not be allocated if there are no SWA attention layers if (self_k_idxs_swa && self_k_idxs_swa->buffer) { @@ -615,7 +622,9 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there } - res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + if (self_kq_mask_swa && self_kq_mask_swa->buffer) { + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + } return res; } @@ -895,6 +904,10 @@ void llm_graph_result::reset() { t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + + t_layer_inp.resize(LLAMA_MAX_LAYERS); + std::fill(t_layer_inp.begin(), t_layer_inp.end(), nullptr); + t_sampled.clear(); t_sampled_probs.clear(); t_sampled_logits.clear(); @@ -923,7 +936,7 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { } } -void llm_graph_result::set_outputs() { +void llm_graph_result::set_outputs(const llm_graph_params & params) { if (t_logits != nullptr) { ggml_set_output(t_logits); } @@ -936,6 +949,15 @@ void llm_graph_result::set_outputs() { if (t_h_nextn != nullptr) { ggml_set_output(t_h_nextn); } + { + const auto & embeddings_layer_inp = params.cparams.embeddings_layer_inp; + for (size_t il = 0; il < embeddings_layer_inp.size(); ++il) { + if (embeddings_layer_inp[il]) { + GGML_ASSERT(t_layer_inp[il] != nullptr && "layer input tensor is null"); + ggml_set_output(t_layer_inp[il]); + } + } + } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { ggml_set_output(t); @@ -1864,9 +1886,9 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { res->t_inp_embd = cur; // For Granite architecture - // NOTE: Only apply scale to token inputs. Raw embeddings are assumed to be - // multimodal inputs that should not be scaled. - if (ubatch.token && hparams.f_embedding_scale != 0.0f) { + // NOTE: For deepstack models, only apply scale to token inputs (ie text-only input). + // Raw embeddings are assumed to be multimodal inputs that should not be scaled. + if (hparams.f_embedding_scale != 0.0f && (ubatch.token || hparams.n_deepstack_layers == 0)) { if (!ggml_is_contiguous(cur)) { cur = ggml_cont(ctx0, cur); } diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 6793846e3ea..cc5cfe51dcd 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -705,6 +705,8 @@ class llm_graph_result { ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } ggml_tensor * get_h_nextn() const { return t_h_nextn; } + ggml_tensor * get_layer_inp(int il) const { return t_layer_inp[il]; } + ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -713,7 +715,7 @@ class llm_graph_result { void reset(); void set_inputs(const llama_ubatch * ubatch); - void set_outputs(); + void set_outputs(const llm_graph_params & params); // try to update the existing graph result using the new graph parameters in order to reuse it // this can only be done if we determine that the resulting graph using the new graph parameters @@ -734,10 +736,12 @@ class llm_graph_result { ggml_tensor * t_embd_pooled = nullptr; ggml_tensor * t_h_nextn = nullptr; // [n_embd, n_outputs] hidden state before final output norm - std::map<llama_seq_id, ggml_tensor*> t_sampled_logits; - std::map<llama_seq_id, ggml_tensor*> t_candidates; - std::map<llama_seq_id, ggml_tensor*> t_sampled; - std::map<llama_seq_id, ggml_tensor*> t_sampled_probs; + std::vector<ggml_tensor *> t_layer_inp; + + std::map<llama_seq_id, ggml_tensor *> t_sampled_logits; + std::map<llama_seq_id, ggml_tensor *> t_candidates; + std::map<llama_seq_id, ggml_tensor *> t_sampled; + std::map<llama_seq_id, ggml_tensor *> t_sampled_probs; std::vector<llm_graph_input_ptr> inputs; diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 032944cb481..d045059a63e 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -45,6 +45,7 @@ struct llama_hparams { bool rope_finetuned; bool use_par_res; bool swin_norm; + bool norm_before_residual = false; uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index 0d1cf3cc33b..474cabdfc09 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -394,6 +394,7 @@ namespace GGUFMeta { template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required); template bool llama_model_loader::get_arr<std::array<int32_t, 512>>(enum llm_kv kid, std::array<int32_t, 512> & result, bool required); + template bool llama_model_loader::get_arr<std::vector<int32_t>>(enum llm_kv kid, std::vector<int32_t> & result, bool required); template<typename T> bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index 4f12e0949ac..7281ed79f10 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -287,6 +287,8 @@ static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params return new llama_model_qwen35moe(params); case LLM_ARCH_MISTRAL3: return new llama_model_mistral3(params); + case LLM_ARCH_EAGLE3: + return new llama_model_eagle3(params); case LLM_ARCH_MIMO2: return new llama_model_mimo2(params); case LLM_ARCH_KIMI_LINEAR: @@ -2238,7 +2240,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // TODO: move reranking logic here and generalize llm->build_dense_out(dense_2_out_layers, dense_2_out_layers_b, dense_3_out_layers); - llm->res->set_outputs(); + llm->res->set_outputs(params); return llm->res->get_gf(); } @@ -2406,6 +2408,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_EAGLE3: case LLM_ARCH_MISTRAL4: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: @@ -2600,8 +2603,9 @@ uint64_t llama_model_n_params(const llama_model * model) { bool llama_model_has_encoder(const llama_model * model) { switch (model->arch) { - case LLM_ARCH_T5: return true; - case LLM_ARCH_T5ENCODER: return true; + case LLM_ARCH_T5: + case LLM_ARCH_T5ENCODER: + case LLM_ARCH_EAGLE3: return true; default: return false; } } @@ -2687,3 +2691,12 @@ void llama_model_base::create_tensor_qkv(llama_layer & layer, int bid, layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); } } + +const int32_t * llama_model_target_layer_ids(const struct llama_model * model) { + const auto & v = model->target_layer_ids; + return v.empty() ? nullptr : v.data(); +} + +uint32_t llama_model_target_layer_ids_n(const struct llama_model * model) { + return (uint32_t) model->target_layer_ids.size(); +} diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 992c8d9c8fd..f4718f6d584 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -569,6 +569,13 @@ struct llama_model { struct ggml_tensor * per_layer_model_proj = nullptr; struct ggml_tensor * per_layer_proj_norm = nullptr; + // eagle3 + struct ggml_tensor * fc = nullptr; // feature fusion layer + struct ggml_tensor * d2t = nullptr; // draft to target vocabulary mapping + + // unified vector to store target-model extracted layer ids in eagle3, dflash, etc. + std::vector<int32_t> target_layer_ids; + std::vector<llama_layer> layers; //Dense linear projections for SentenceTransformers models like embeddinggemma diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index 9a4bed49487..8543e178dba 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -764,7 +764,7 @@ struct llm_tokenizer_wpm_session { void tokenize(const std::string & text, std::vector<llama_token> & output) { // normalize and split by whitespace - std::vector<std::string> words = preprocess(text, vocab.get_normalizer_lowercase()); + std::vector<std::string> words = preprocess(text, vocab.get_normalizer_opts()); // bos token prepended already // find the longest tokens that form the words @@ -809,11 +809,14 @@ struct llm_tokenizer_wpm_session { } // TODO: reduce string copies by using cpts_offs array - static std::vector<std::string> preprocess(const std::string & text, bool lowercase) { - const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); + static std::vector<std::string> preprocess(const std::string & text, const llama_vocab::normalizer_options & normalizer_opts) { + std::vector<uint32_t> cpts = unicode_cpts_from_utf8(text); + if (normalizer_opts.strip_accents) { + cpts = unicode_cpts_normalize_nfd(cpts); + } std::vector<std::string> words(1, ""); - for (const uint32_t cpt : cpts_nfd) { + for (const uint32_t cpt : cpts) { const auto flags = unicode_cpt_flags_from_cpt(cpt); if (flags.is_whitespace) { @@ -828,7 +831,11 @@ struct llm_tokenizer_wpm_session { continue; } - const std::string s = unicode_cpt_to_utf8(lowercase ? unicode_tolower(cpt) : cpt); + if (normalizer_opts.strip_accents && flags.is_accent_mark) { + continue; + } + + const std::string s = unicode_cpt_to_utf8(normalizer_opts.lowercase ? unicode_tolower(cpt) : cpt); if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) { if (words.back().size()) { // finish previous word if any words.emplace_back(); @@ -1692,7 +1699,7 @@ struct llm_tokenizer_whitespace_session : llm_tokenizer_bpe_session { llm_tokenizer_whitespace_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : llm_tokenizer_bpe_session{vocab, tokenizer}, vocab{vocab} {} void tokenize(const std::string & text, std::vector<llama_token> & output) override { - const bool lowercase = vocab.get_normalizer_lowercase(); + const bool lowercase = vocab.get_normalizer_opts().lowercase; std::string segment; auto flush = [&]() { @@ -1797,7 +1804,9 @@ struct llama_vocab::impl { bool remove_extra_whitespaces = false; bool escape_whitespaces = true; bool treat_whitespace_as_suffix = false; - bool normalizer_lowercase = true; // Lowercase normalizer (tokenizer.json) + + // BertNormalizer options + llama_vocab::normalizer_options normalizer_opts; std::unordered_map<std::string, llama_token> token_to_id; std::vector<token_data> id_to_token; @@ -2172,7 +2181,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "whitespace") { pre_type = LLAMA_VOCAB_PRE_TYPE_WHITESPACE; - normalizer_lowercase = false; + normalizer_opts.lowercase = false; } else if ( tokenizer_pre == "refact") { pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT; @@ -2532,8 +2541,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } } - // Lowercase normalizer flag (consulted by WPM / whitespace BPE) - ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, normalizer_lowercase, false); + // BertNormalizer options + ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, normalizer_opts.lowercase, false); + normalizer_opts.strip_accents = normalizer_opts.lowercase; + ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS, normalizer_opts.strip_accents, false); // suppress tokens { @@ -3969,8 +3980,8 @@ bool llama_vocab::get_treat_whitespace_as_suffix() const { return pimpl->treat_whitespace_as_suffix; } -bool llama_vocab::get_normalizer_lowercase() const { - return pimpl->normalizer_lowercase; +const llama_vocab::normalizer_options & llama_vocab::get_normalizer_opts() const { + return pimpl->normalizer_opts; } const std::vector<llama_token> & llama_vocab::get_suppress_tokens() const { diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 2626ae36e33..707cd4bac4b 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -76,6 +76,12 @@ struct llama_vocab { llama_token_attr attr; }; + struct normalizer_options { + bool lowercase = true; + bool strip_accents = true; + // TODO: clean_text, handle_chinese_chars + }; + llama_vocab(); ~llama_vocab(); @@ -141,7 +147,7 @@ struct llama_vocab { bool get_remove_extra_whitespaces () const; bool get_escape_whitespaces () const; bool get_treat_whitespace_as_suffix() const; - bool get_normalizer_lowercase () const; + const normalizer_options & get_normalizer_opts() const; const std::vector<llama_token> & get_suppress_tokens() const; diff --git a/examples/talk-llama/models/delta-net-base.cpp b/examples/talk-llama/models/delta-net-base.cpp index 4f4c7cac7a8..ad9ce771408 100644 --- a/examples/talk-llama/models/delta-net-base.cpp +++ b/examples/talk-llama/models/delta-net-base.cpp @@ -398,9 +398,8 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); - // K=1 (final state only): reshape to 3D (S_v*S_v*H_v, 1, n_seqs) for ggml_gated_delta_net. - ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, S_v * S_v * H_v, 1, n_seqs); - ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d); + // K=1: output carries the final state only. state s is 4D [S_v, S_v, H_v, n_seqs]. + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, /*K=*/1); if (n_tokens == 1) { cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); } else { @@ -564,11 +563,8 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( const int64_t D = S_v * S_v * H_v; const int64_t K = cparams.n_rs_seq + 1; - // TODO: remove pad + simplify - ggml_tensor * s_3d = ggml_reshape_3d(ctx0, s, D, 1, n_seqs); - ggml_tensor * s_3d_pad = ggml_pad (ctx0, s_3d, 0, K - 1, 0, 0); - - ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s_3d_pad); + // state s is 4D [S_v, S_v, H_v, n_seqs]; K snapshot slots are written into the output. + ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, K); if (n_seq_tokens > 1) { cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); } else { @@ -587,21 +583,24 @@ ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( cb(output, "attn_output", il); const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all); - for (int64_t k_i = 0; k_i < K; ++k_i) { - const uint32_t cache_slot = (uint32_t) (K - 1 - k_i); - ggml_tensor * src = ggml_view_4d(ctx0, gdn_out, - S_v, S_v, H_v, n_seqs, - ggml_row_size(gdn_out->type, S_v), - ggml_row_size(gdn_out->type, S_v * S_v), - ggml_row_size(gdn_out->type, S_v * S_v * H_v), - ggml_row_size(gdn_out->type, attn_score_elems + k_i * state_size_per_snap)); - ggml_tensor * dst = ggml_view_2d(ctx0, ssm_states_all, - hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], - ((size_t) cache_slot * mem_size + kv_head) * row_size); + // op writes the last min(n_seq_tokens, K) snapshots; trailing slots are left unwritten + const int64_t n_written = std::min<int64_t>(n_seq_tokens, K); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); - } + // write the produced snapshots into the recurrent cache (snapshot slot i -> rollback group i) + ggml_tensor * src = ggml_view_3d(ctx0, gdn_out, + D, n_seqs, n_written, + ggml_row_size(gdn_out->type, D), + ggml_row_size(gdn_out->type, state_size_per_snap), + ggml_row_size(gdn_out->type, attn_score_elems)); + + ggml_tensor * dst = ggml_view_3d(ctx0, ssm_states_all, + D, n_seqs, n_written, + ssm_states_all->nb[1], + (size_t) mem_size * row_size, + (size_t) kv_head * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); return output; } diff --git a/examples/talk-llama/models/eagle3.cpp b/examples/talk-llama/models/eagle3.cpp new file mode 100644 index 00000000000..3321b390515 --- /dev/null +++ b/examples/talk-llama/models/eagle3.cpp @@ -0,0 +1,323 @@ +#include "models.h" + +void llama_model_eagle3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (!ml.get_arr(LLM_KV_TARGET_LAYERS, target_layer_ids, false)) { + throw std::runtime_error("EAGLE3 model requires 'extract_layers' in GGUF metadata"); + } + if (target_layer_ids.size() != 3) { + throw std::runtime_error("EAGLE3 requires exactly 3 entries in 'extract_layers'"); + } + LLAMA_LOG_INFO("%s: EAGLE3 extract_layers = [%d, %d, %d]\n", __func__, + target_layer_ids[0], + target_layer_ids[1], + target_layer_ids[2]); + + uint32_t n_embd_tgt = 0; + + ml.get_key(LLM_KV_TARGET_HIDDEN_SIZE, n_embd_tgt); + LLAMA_LOG_INFO("%s: EAGLE3 n_embd_tgt = %u (draft n_embd = %u)\n", __func__, n_embd_tgt, hparams.n_embd); + + hparams.n_embd_inp_impl = (uint32_t) target_layer_ids.size() * n_embd_tgt; + + // eagle3 norm_before_residual (optional, default false) + // compatible with Readhat eagle3 speculator model + ml.get_key(LLM_KV_NORM_BEFORE_RESIDUAL, hparams.norm_before_residual, false); + if (hparams.norm_before_residual) { + LLAMA_LOG_INFO("%s: EAGLE3gnorm_before_residual = true\n", __func__); + } + + type = LLM_TYPE_UNKNOWN; +} + +void llama_model_eagle3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_inp = hparams.n_embd_inp(); + const int64_t n_embd_attn_input = 2 * n_embd; + + // Get vocab size from the d2t tensor in the GGUF file (optional - only needed if eagle3 has different vocab_size than target) + // d2t: draft to target vocabulary mapping + int64_t n_draft_vocab = n_vocab; // Default: same as target vocab + const struct ggml_tensor * d2t_meta = ml->get_tensor_meta("d2t"); + if (d2t_meta) { + n_draft_vocab = d2t_meta->ne[0]; // update draft vocab size + d2t = create_tensor(tn(LLM_TENSOR_D2T), {n_draft_vocab}, 0); + LLAMA_LOG_INFO("%s: EAGLE3 using d2t mapping (draft_vocab_size = %lld)\n", __func__, (long long)n_draft_vocab); + } else { + d2t = nullptr; // no d2t, use default vocab size + LLAMA_LOG_INFO("%s: EAGLE3 without d2t - sharing same vocab_size with target (vocab_size = %lld)\n", __func__, (long long)n_draft_vocab); + } + + // Feature fusion layer: projects 3 target layers to draft hidden size + fc = create_tensor(tn(LLM_TENSOR_FC, "weight"), {n_embd_inp, n_embd}, 0); + + // Output layer (uses draft vocab size) + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_draft_vocab}, TENSOR_NOT_REQUIRED); + + // Token embeddings (optional - Llama 3.3 70B EAGLE3 has its own) + const struct ggml_tensor * tok_embd_meta = ml->get_tensor_meta(tn(LLM_TENSOR_TOKEN_EMBD, "weight").str().c_str()); + if (tok_embd_meta) { + const int64_t n_target_vocab = tok_embd_meta->ne[1]; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_target_vocab}, 0); + LLAMA_LOG_INFO("%s: EAGLE3 using its own token_embd (vocab = %lld)\n", __func__, (long long)n_target_vocab); + } + + // Single decoder layer + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // input_layernorm: applied to token embeddings + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // eagle3 specific: hidden_norm applied to fused target features + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + + // Attention takes input_embeds_normed + fused_target_normed as input + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd_attn_input, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd_attn_input, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd_attn_input, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // rope_freqs for llama3 rope scaling (optional - only if eagle3 config has rope_scaling) + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr<llm_graph_context> llama_model_eagle3::build_arch_graph(const llm_graph_params & params) const { + switch (params.gtype) { + case LLM_GRAPH_TYPE_ENCODER: + return std::make_unique<graph<true>>(*this, params); + case LLM_GRAPH_TYPE_DEFAULT: + case LLM_GRAPH_TYPE_DECODER: + return std::make_unique<graph<false>>(*this, params); + default: + GGML_ABORT("invalid graph type"); + }; +} + +template <> +ggml_tensor * llama_model_eagle3::graph<true>::build_inp_embd_enc() const { + ggml_tensor * cur = nullptr; + + // Input: Target model features (3 layers concatenated: low, mid, high) + // Data will be provided via ubatch->embd in encode_eagle3_features() + auto inp_target = std::make_unique<llm_graph_input_embd>(hparams.n_embd_inp()); + inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32,hparams.n_embd_inp(), n_tokens); + ggml_set_input(inp_target->embd); + + cur = inp_target->embd; + cb(cur, "inp_embd", -1); + + res->add_input(std::move(inp_target)); + + return cur; +} + +// eagle3 Encoder: processes target model features through feature fusion layer +// Input: target_features e.g. [12288, n_tokens] from target model layers low, middle, high +// Output: g_embeddings e.g. [4096, n_tokens] stored in context +template <> +llama_model_eagle3::graph<true>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur = nullptr; + + cur = build_inp_embd_enc(); + + // Feature fusion layer + cur = build_lora_mm(model.fc, cur); + cb(cur, "fc_out", -1); + + // Output: g_embeddings e.g. [4096, n_tokens] + // store in t_h_nextn (same as MTP) so can be read via llama_get_embeddings_nextn(ctx_dft) + ggml_set_output(cur); + res->t_h_nextn = cur; + + ggml_build_forward_expand(gf, cur); +} + +// eagle3 Decoder: processes draft tokens using g_embeddings from encoder +// Input: draft tokens + g_embeddings from encoder +// Output: draft logits +template <> +llama_model_eagle3::graph<false>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_layer == 1); // eagle3 has only one decoder layer + + ggml_tensor * cur; + ggml_tensor * inpL; + + // eagle3 Decoder receives: + // 1. Token embeddings (e.g.from eagle3's own tok_embd for Llama 3.3 70B, or target model for Llama 3.1 8B) + // 2. g_embeddings from encoder + auto * tok_embd = model.tok_embd; + if (model.tok_embd == nullptr) { + GGML_ASSERT(cparams.ctx_other != nullptr); + const auto * model_other = llama_get_model(cparams.ctx_other); + + GGML_ASSERT(model_other->tok_embd != nullptr && "EAGLE3 decoder requires token embeddings (own or from target model)"); + tok_embd = model_other->tok_embd; + } + + auto inp = std::make_unique<llm_graph_input_embd>(n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_input(inp->embd); + + ggml_tensor * inp_embd = ggml_get_rows(ctx0, tok_embd, inp->tokens); + cb(inp_embd, "inp_embd", -1); + + ggml_tensor * inp_g = inp->embd; + cb(inp_g, "inp_g_embeddings", -1); + + res->add_input(std::move(inp)); + + inpL = inp_g; + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + + // Single decoder layer (il = 0) + const int il = 0; + { + // Apply input_layernorm to the token embeddings + ggml_tensor * embd_norm = build_norm(inp_embd, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(embd_norm, "embd_norm", il); + + // Apply hidden_norm to inp_g + ggml_tensor * g_norm = build_norm(inp_g, + model.layers[il].attn_norm_2, NULL, + LLM_NORM_RMS, -1); + cb(g_norm, "g_norm", il); + + // norm_before_residual: determines what goes into the residual connection (compatible with Readhat eagle3 speculator model) + // - false (default): use raw inp_g for residual + // - true: use normalized g_norm for residual + // inpL is the concatenated input (normalized inp_embd + normalized inp_g) + ggml_tensor * inpSA = hparams.norm_before_residual ? g_norm : inpL; + + // Concatenate normalized inp_embd and normalized inp_g + cur = ggml_concat(ctx0, embd_norm, g_norm, il); + cb(cur, "concat_embd", il); + + // Self-attention with concatenated input + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // rope freq factors, returns nullptr if not available + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + + // Add residual and update it + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // Apply FFN norm to the sum + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + // Output norm with residual + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "eagle3_prenorm", il); + + inpL = cur; + } + + cur = inpL; + + // Output prenorm state (for next token's g_embeddings in autoregressive generation) + ggml_set_output(cur); + res->t_h_nextn = cur; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // lm_head - projects to draft vocabulary + // if the draft has no own output projection, inherit the target model's lm_head + auto * output = model.output; + if (output == nullptr) { + GGML_ASSERT(cparams.ctx_other != nullptr); + const auto * model_other = llama_get_model(cparams.ctx_other); + + GGML_ASSERT(model_other->output != nullptr && "EAGLE3 decoder requires an output projection (own or from target model)"); + output = model_other->output; + } + cur = build_lora_mm(output, cur); + + if (model.d2t) { + const int64_t n_draft_vocab = cur->ne[0]; + const int64_t n_outputs = cur->ne[1]; + const int64_t n_vocab = (int64_t) model.vocab.n_tokens(); + + GGML_ASSERT(model.d2t->type == GGML_TYPE_I64); + GGML_ASSERT(model.d2t->ne[0] == n_draft_vocab); + + ggml_tensor * logits = ggml_fill(ctx0, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, n_vocab, n_outputs), -INFINITY); + cur = ggml_set_rows(ctx0, logits, + ggml_reshape_3d(ctx0, cur, 1, n_draft_vocab, n_outputs), + ggml_reshape_3d(ctx0, model.d2t, n_draft_vocab, 1, 1)); + cur = ggml_reshape_2d(ctx0, cur, n_vocab, n_outputs); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/gemma4-assistant.cpp b/examples/talk-llama/models/gemma4-assistant.cpp index 5b7a25a5aba..6378130e79e 100644 --- a/examples/talk-llama/models/gemma4-assistant.cpp +++ b/examples/talk-llama/models/gemma4-assistant.cpp @@ -39,6 +39,9 @@ void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) { output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + create_tensor(tn(LLM_TENSOR_MASKED_EMBD_CENTROIDS, "weight"), {}, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_MASKED_EMBD_ORDERING), {}, TENSOR_NOT_REQUIRED); + const int64_t n_embd_backbone = hparams.n_embd_inp(); nextn_proj_post = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_POST, "weight"), { n_embd, n_embd_backbone }, 0); diff --git a/examples/talk-llama/models/gemma4.cpp b/examples/talk-llama/models/gemma4.cpp index 6f7fcd645cb..6a96979cebd 100644 --- a/examples/talk-llama/models/gemma4.cpp +++ b/examples/talk-llama/models/gemma4.cpp @@ -210,6 +210,8 @@ llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_para const float freq_scale_l = model.get_rope_freq_scale(cparams, il); const int n_rot_l = hparams.n_rot(il); + res->t_layer_inp[il] = inpL; + // norm cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index c0ec7e0a9ad..4bfebc8843c 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -124,6 +124,8 @@ llama_model_llama::graph<embed>::graph(const llama_model & model, const llm_grap ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; // norm diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index c137e32e8fd..ee3aff07b9a 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -46,7 +46,7 @@ struct llm_build_delta_net_base : public llm_graph_context { ggml_tensor * s, int il); - // use the ggml_gated_delta_net fused operator (K=1; state has shape (D, 1, n_seqs)) + // use the ggml_gated_delta_net fused operator (K=1; state has shape [S_v, S_v, H_v, n_seqs]) std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_fused( ggml_tensor * q, ggml_tensor * k, @@ -1089,6 +1089,21 @@ struct llama_model_glm_dsa : public llama_model_base { std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; +struct llama_model_eagle3 : public llama_model_base { + llama_model_eagle3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool is_enc> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + + ggml_tensor * build_inp_embd_enc() const; + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + struct llama_model_mistral4 : public llama_model_deepseek2 { llama_model_mistral4(const struct llama_model_params & params) : llama_model_deepseek2(params) {} diff --git a/examples/talk-llama/models/openai-moe.cpp b/examples/talk-llama/models/openai-moe.cpp index 3ab15d61f08..6d74f9c7e6e 100644 --- a/examples/talk-llama/models/openai-moe.cpp +++ b/examples/talk-llama/models/openai-moe.cpp @@ -75,6 +75,8 @@ llama_model_openai_moe::graph::graph(const llama_model & model, const llm_graph_ ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + const float freq_base_l = model.get_rope_freq_base (cparams, il); const float freq_scale_l = model.get_rope_freq_scale(cparams, il); diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index b93cf48bc5c..0b81513c368 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -11,6 +11,10 @@ void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + // Load attention parameters + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false); + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; } @@ -273,7 +277,7 @@ ggml_tensor * llama_model_plamo2::graph::build_plamo2_mamba_layer(llm_graph_inpu GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - GGML_ASSERT(d_inner % n_head == 0); + GGML_ASSERT(d_inner % n_heads == 0); GGML_ASSERT(n_group == 0); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index 1d0d2fab362..f4b2a2aebe0 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -69,6 +69,8 @@ llama_model_qwen3::graph::graph(const llama_model & model, const llm_graph_param ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; // norm diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp index 4b642cff467..6783d98ec20 100644 --- a/examples/talk-llama/models/qwen35.cpp +++ b/examples/talk-llama/models/qwen35.cpp @@ -173,7 +173,7 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para } if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index 317e668bec7..6f6df5390e3 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -78,6 +78,8 @@ llama_model_qwen3moe::graph::graph(const llama_model & model, const llm_graph_pa ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; // norm From db5a84bd79926f783b199f5707af42dd99b60f2e Mon Sep 17 00:00:00 2001 From: Rum Nguyen <160252724+rumitvn@users.noreply.github.com> Date: Tue, 16 Jun 2026 13:58:09 +0700 Subject: [PATCH 823/831] cli : add --version flag (#3878) Adds a `--version` option to whisper-cli that prints the library version via `whisper_version()` and exits, plus a corresponding entry in the help output. Mirrors the existing `-h`/`--help` handling. Closes #608 --- examples/cli/cli.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 7ca563dc250..e505bf0e18d 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -151,6 +151,10 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params whisper_print_usage(argc, argv, params); exit(0); } + if (arg == "--version") { + fprintf(stdout, "whisper.cpp version: %s\n", whisper_version()); + exit(0); + } #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg)) else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(ARGV_NEXT); } @@ -234,6 +238,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " --version show version information and exit\n"); fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); From 48f628a84833905ee4a0658ee6d4a5c915ce1997 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Tue, 16 Jun 2026 12:28:23 +0200 Subject: [PATCH 824/831] release : v1.8.7 (#3881) --- CMakeLists.txt | 2 +- README.md | 2 +- bindings/javascript/package.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3932cf2845e..b2e936e7267 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.8.6) +project("whisper.cpp" VERSION 1.8.7) include(CheckIncludeFileCXX) set(SOVERSION 1) diff --git a/README.md b/README.md index fe7fa74153a..19fdc70daab 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) -Stable: [v1.8.6](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.6) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) +Stable: [v1.8.7](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.7) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index 1f2f34672ae..7c66c730c6c 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.8.6", + "version": "1.8.7", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From 3805e602d3a3f80ca13211cb96900eae5aad4d1d Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Tue, 16 Jun 2026 14:33:42 +0200 Subject: [PATCH 825/831] ci : only trigger release jobs for tags (#3883) * ci : only trigger release jobs for tags This commit removes the building of the release jobs on pushed to master. The motivation for this is that it can be confusing at the momement when releasing that the push to master also triggers the release jobs but the actual release will be skipped. With this change the release job is only run when a tag is pushed which should result in a single Release github actions job and make it easier to follow. * ci : add GGML_NATIVE=OFF for ubuntu-22-gcc --- .github/workflows/build-gcc.yml | 1 + .github/workflows/release.yml | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/build-gcc.yml b/.github/workflows/build-gcc.yml index 3d8b5137344..53c1b2d783c 100644 --- a/.github/workflows/build-gcc.yml +++ b/.github/workflows/build-gcc.yml @@ -75,6 +75,7 @@ jobs: apt update apt install -y build-essential cmake libsdl2-dev git ccache cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_NATIVE=OFF \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache make diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 11d47546caa..ef2c3083c9f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -13,8 +13,6 @@ on: type: string push: - branches: - - master tags: - 'v*' From 9efddafb9153e1fb22bdc3dd3057072c99165ed2 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Tue, 16 Jun 2026 20:44:10 +0200 Subject: [PATCH 826/831] parakeet : add support for NVIDIA Parakeet (#3735) * parakeet : add support for NVIDIA Parakeet Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> --- CMakeLists.txt | 37 + bindings/ruby/ext/extconf.rb | 2 +- cmake/parakeet-config.cmake.in | 30 + cmake/parakeet.pc.in | 10 + examples/CMakeLists.txt | 2 + examples/parakeet-cli/CMakeLists.txt | 8 + examples/parakeet-cli/README.md | 106 + examples/parakeet-cli/parakeet-cli.cpp | 243 ++ examples/parakeet-quantize/CMakeLists.txt | 7 + .../parakeet-quantize/parakeet-quantize.cpp | 230 + include/parakeet.h | 342 ++ models/convert-parakeet-to-ggml.py | 337 ++ models/for-tests-ggml-parakeet-tdt.bin | Bin 0 -> 16603 bytes models/generate-parakeet-test-model.py | 182 + models/requirements-parakeet.txt | 3 + scripts/quantize-parakeet.sh | 15 + scripts/upload-parakeet.py | 157 + src/CMakeLists.txt | 23 + src/parakeet-arch.h | 188 + src/parakeet.cpp | 3838 +++++++++++++++++ tests/CMakeLists.txt | 59 + tests/librispeech-parakeet/.gitignore | 6 + tests/librispeech-parakeet/Makefile | 15 + tests/librispeech-parakeet/README.md | 57 + tests/librispeech-parakeet/eval.mk | 39 + tests/librispeech-parakeet/eval.py | 47 + .../librispeech-parakeet/normalizers/LICENSE | 25 + .../normalizers/__init__.py | 2 + .../librispeech-parakeet/normalizers/basic.py | 80 + .../normalizers/english.json | 1741 ++++++++ .../normalizers/english.py | 550 +++ tests/parakeet-expected-diffusion-output.txt | 1 + tests/parakeet-expected-gb1-output.txt | 1 + tests/parakeet-expected-jfk-output.txt | 1 + tests/parakeet-verification.h | 110 + tests/run-tests.sh | 46 +- tests/test-parakeet-full.cpp | 101 + tests/test-parakeet.cpp | 99 + 38 files changed, 8733 insertions(+), 7 deletions(-) create mode 100644 cmake/parakeet-config.cmake.in create mode 100644 cmake/parakeet.pc.in create mode 100644 examples/parakeet-cli/CMakeLists.txt create mode 100644 examples/parakeet-cli/README.md create mode 100644 examples/parakeet-cli/parakeet-cli.cpp create mode 100644 examples/parakeet-quantize/CMakeLists.txt create mode 100644 examples/parakeet-quantize/parakeet-quantize.cpp create mode 100644 include/parakeet.h create mode 100755 models/convert-parakeet-to-ggml.py create mode 100644 models/for-tests-ggml-parakeet-tdt.bin create mode 100755 models/generate-parakeet-test-model.py create mode 100644 models/requirements-parakeet.txt create mode 100755 scripts/quantize-parakeet.sh create mode 100644 scripts/upload-parakeet.py create mode 100644 src/parakeet-arch.h create mode 100644 src/parakeet.cpp create mode 100644 tests/librispeech-parakeet/.gitignore create mode 100644 tests/librispeech-parakeet/Makefile create mode 100644 tests/librispeech-parakeet/README.md create mode 100644 tests/librispeech-parakeet/eval.mk create mode 100644 tests/librispeech-parakeet/eval.py create mode 100644 tests/librispeech-parakeet/normalizers/LICENSE create mode 100644 tests/librispeech-parakeet/normalizers/__init__.py create mode 100644 tests/librispeech-parakeet/normalizers/basic.py create mode 100644 tests/librispeech-parakeet/normalizers/english.json create mode 100644 tests/librispeech-parakeet/normalizers/english.py create mode 100644 tests/parakeet-expected-diffusion-output.txt create mode 100644 tests/parakeet-expected-gb1-output.txt create mode 100644 tests/parakeet-expected-jfk-output.txt create mode 100644 tests/parakeet-verification.h create mode 100644 tests/test-parakeet-full.cpp create mode 100644 tests/test-parakeet.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b2e936e7267..dff25f25a34 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -180,12 +180,20 @@ set(WHISPER_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location get_directory_property(WHISPER_TRANSIENT_DEFINES COMPILE_DEFINITIONS) set_target_properties(whisper PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/whisper.h) + install(TARGETS whisper LIBRARY PUBLIC_HEADER) target_compile_definitions(whisper PRIVATE WHISPER_VERSION="${PROJECT_VERSION}" ) +set_target_properties(parakeet PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/parakeet.h) +install(TARGETS parakeet LIBRARY PUBLIC_HEADER) + +target_compile_definitions(parakeet PRIVATE + PARAKEET_VERSION="${PROJECT_VERSION}" +) + configure_package_config_file( ${CMAKE_CURRENT_SOURCE_DIR}/cmake/whisper-config.cmake.in ${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake @@ -211,6 +219,35 @@ configure_file(cmake/whisper.pc.in install(FILES "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc" DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) +set(PARAKEET_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files") +set(PARAKEET_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") +set(PARAKEET_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") + +configure_package_config_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/parakeet-config.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet + PATH_VARS + PARAKEET_INCLUDE_INSTALL_DIR + PARAKEET_LIB_INSTALL_DIR + PARAKEET_BIN_INSTALL_DIR) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake + VERSION ${WHISPER_INSTALL_VERSION} + COMPATIBILITY SameMajorVersion) + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet) + +configure_file(cmake/parakeet.pc.in + "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc" + @ONLY) + +install(FILES "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc" + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) + # # programs, examples and tests # diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 4b09b6ebe13..99894f1234d 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -30,6 +30,6 @@ #{libs}: cmake-targets cmake-targets: #{"\t"}"#{cmake}" -S sources -B build #{options} - #{"\t"}"#{cmake}" --build build --config Release --target common whisper + #{"\t"}"#{cmake}" --build build --config Release --target common whisper parakeet EOF end diff --git a/cmake/parakeet-config.cmake.in b/cmake/parakeet-config.cmake.in new file mode 100644 index 00000000000..aadb55c2d19 --- /dev/null +++ b/cmake/parakeet-config.cmake.in @@ -0,0 +1,30 @@ +set(PARAKEET_VERSION @WHISPER_INSTALL_VERSION@) +set(PARAKEET_BUILD_COMMIT @WHISPER_BUILD_COMMIT@) +set(PARAKEET_BUILD_NUMBER @WHISPER_BUILD_NUMBER@) +set(PARAKEET_SHARED_LIB @BUILD_SHARED_LIBS@) + +@PACKAGE_INIT@ + +set_and_check(PARAKEET_INCLUDE_DIR "@PACKAGE_PARAKEET_INCLUDE_INSTALL_DIR@") +set_and_check(PARAKEET_LIB_DIR "@PACKAGE_PARAKEET_LIB_INSTALL_DIR@") +set_and_check(PARAKEET_BIN_DIR "@PACKAGE_PARAKEET_BIN_INSTALL_DIR@") + +find_package(ggml REQUIRED HINTS ${PARAKEET_LIB_DIR}/cmake) + +find_library(parakeet_LIBRARY parakeet + REQUIRED + HINTS ${PARAKEET_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH +) + +add_library(parakeet UNKNOWN IMPORTED) +set_target_properties(parakeet + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PARAKEET_INCLUDE_DIR}" + INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION "${parakeet_LIBRARY}" + INTERFACE_COMPILE_FEATURES cxx_std_11 + POSITION_INDEPENDENT_CODE ON) + +check_required_components(parakeet) diff --git a/cmake/parakeet.pc.in b/cmake/parakeet.pc.in new file mode 100644 index 00000000000..5a25fbb2e42 --- /dev/null +++ b/cmake/parakeet.pc.in @@ -0,0 +1,10 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@ +includedir=${prefix}/include + +Name: parakeet +Description: Port of NVIDIA's Parakeet model in C/C++ +Version: @PROJECT_VERSION@ +Libs: -L${libdir} -lggml -lggml-base -lparakeet +Cflags: -I${includedir} diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 0bb54cec489..7aedb9df683 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -107,6 +107,8 @@ else() add_subdirectory(server) add_subdirectory(quantize) add_subdirectory(vad-speech-segments) + add_subdirectory(parakeet-cli) + add_subdirectory(parakeet-quantize) if (WHISPER_SDL2) add_subdirectory(stream) add_subdirectory(command) diff --git a/examples/parakeet-cli/CMakeLists.txt b/examples/parakeet-cli/CMakeLists.txt new file mode 100644 index 00000000000..adb9aba38ef --- /dev/null +++ b/examples/parakeet-cli/CMakeLists.txt @@ -0,0 +1,8 @@ +set(TARGET parakeet-cli) +add_executable(${TARGET} parakeet-cli.cpp) + +include(DefaultTargetOptions) + +target_link_libraries(${TARGET} PRIVATE common parakeet ${FFMPEG_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) + +install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/parakeet-cli/README.md b/examples/parakeet-cli/README.md new file mode 100644 index 00000000000..ccb8404f542 --- /dev/null +++ b/examples/parakeet-cli/README.md @@ -0,0 +1,106 @@ +# whisper.cpp/examples/parakeet-cli + +This is an example of using the [Parakeet] model in whisper.cpp. + +### Download converted model +```console +$ hf download ggml-org/parakeet-GGUF parakeet-tdt-0.6b-v3-f16.bin --local-dir models +``` + +### Building +```console +$ cmake -B build -S . +$ cmake --build build --target parakeet-cli -j 12 +``` + +### Usage +```console +$ ./build/bin/parakeet-cli --help + +usage: ./build/bin/parakeet-cli [options] file0 file1 ... +supported audio formats: flac, mp3, ogg, wav + +options: + -h, --help [default] show this help message and exit + -t N, --threads N [4 ] number of threads to use during computation + -m, --model FILE [models/ggml-parakeet-tdt-0.6b-v3.bin] model path + -f, --file FILE [ ] input audio file + -ng, --no-gpu [false ] disable GPU + -dev N, --device N [0 ] GPU device to use + -ps, --print-segments [false ] print segment information +``` + +### Example +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav +Processing audio (176000 samples, 11.00 seconds) +Processing audio: total_frames=1101, chunk_size=1101 +parakeet_decode: starting decode with n_frames=138 +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. +``` + +To print segment information: +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav --print-segments +Processing audio (176000 samples, 11.00 seconds) +Processing audio: total_frames=1101, chunk_size=1101 +parakeet_decode: starting decode with n_frames=138 +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. + +Segments (1): +Segment 0: [0 -> 1101] "And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country." +Tokens [38]: + [ 0] id= 1976 frame= 3 dur_idx= 4 dur_val= 4 p=0.9996 plog=-15.6206 t0= 24 t1= 56 word_start=true "▁And" + [ 1] id= 547 frame= 7 dur_idx= 4 dur_val= 4 p=0.9999 plog=-18.7922 t0= 56 t1= 88 word_start=true "▁so" + [ 2] id= 7877 frame= 11 dur_idx= 2 dur_val= 2 p=0.8451 plog=-14.5929 t0= 88 t1= 88 word_start=false "," + [ 3] id= 1103 frame= 13 dur_idx= 3 dur_val= 3 p=0.9996 plog=-15.6127 t0= 104 t1= 128 word_start=true "▁my" + [ 4] id= 309 frame= 16 dur_idx= 1 dur_val= 1 p=0.9912 plog=-11.9635 t0= 128 t1= 136 word_start=true "▁f" + [ 5] id= 530 frame= 17 dur_idx= 2 dur_val= 2 p=1.0000 plog=-13.5239 t0= 136 t1= 152 word_start=false "ell" + [ 6] id= 596 frame= 19 dur_idx= 3 dur_val= 3 p=1.0000 plog=-16.3120 t0= 152 t1= 176 word_start=false "ow" + [ 7] id= 3213 frame= 22 dur_idx= 4 dur_val= 4 p=0.9999 plog=-10.1462 t0= 176 t1= 208 word_start=true "▁Amer" + [ 8] id= 404 frame= 26 dur_idx= 4 dur_val= 4 p=1.0000 plog=-25.0910 t0= 208 t1= 240 word_start=false "ic" + [ 9] id= 667 frame= 30 dur_idx= 4 dur_val= 4 p=1.0000 plog=-27.1707 t0= 240 t1= 272 word_start=false "ans" + [10] id= 7877 frame= 37 dur_idx= 4 dur_val= 4 p=0.9094 plog=-16.3405 t0= 272 t1= 272 word_start=false "," + [11] id= 279 frame= 41 dur_idx= 4 dur_val= 4 p=0.9980 plog=-19.7244 t0= 328 t1= 360 word_start=true "▁a" + [12] id= 583 frame= 45 dur_idx= 4 dur_val= 4 p=1.0000 plog=-24.5312 t0= 360 t1= 392 word_start=false "sk" + [13] id= 1491 frame= 53 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2991 t0= 424 t1= 456 word_start=true "▁not" + [14] id= 3470 frame= 65 dur_idx= 4 dur_val= 4 p=0.9995 plog=-16.7306 t0= 520 t1= 552 word_start=true "▁what" + [15] id= 3629 frame= 69 dur_idx= 2 dur_val= 2 p=0.8139 plog=-11.6486 t0= 552 t1= 568 word_start=true "▁your" + [16] id= 867 frame= 75 dur_idx= 1 dur_val= 1 p=0.9980 plog=-12.5265 t0= 600 t1= 608 word_start=true "▁co" + [17] id= 331 frame= 76 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.6697 t0= 608 t1= 624 word_start=false "un" + [18] id= 958 frame= 78 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.3621 t0= 624 t1= 640 word_start=false "tr" + [19] id= 7893 frame= 80 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.3245 t0= 640 t1= 656 word_start=false "y" + [20] id= 2059 frame= 82 dur_idx= 3 dur_val= 3 p=1.0000 plog=-17.7694 t0= 656 t1= 680 word_start=true "▁can" + [21] id= 458 frame= 85 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2510 t0= 680 t1= 712 word_start=true "▁do" + [22] id= 509 frame= 89 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.0688 t0= 712 t1= 744 word_start=true "▁for" + [23] id= 1180 frame= 93 dur_idx= 4 dur_val= 4 p=0.9999 plog=-25.0567 t0= 744 t1= 776 word_start=true "▁you" + [24] id= 7877 frame= 98 dur_idx= 4 dur_val= 4 p=0.8820 plog=-14.2549 t0= 776 t1= 776 word_start=false "," + [25] id= 279 frame=102 dur_idx= 3 dur_val= 3 p=0.9992 plog=-16.8176 t0= 816 t1= 840 word_start=true "▁a" + [26] id= 583 frame=105 dur_idx= 4 dur_val= 4 p=1.0000 plog=-21.0352 t0= 840 t1= 872 word_start=false "sk" + [27] id= 3470 frame=109 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.4659 t0= 872 t1= 896 word_start=true "▁what" + [28] id= 1180 frame=112 dur_idx= 4 dur_val= 4 p=0.9997 plog=-17.6392 t0= 896 t1= 928 word_start=true "▁you" + [29] id= 2059 frame=116 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.5484 t0= 928 t1= 952 word_start=true "▁can" + [30] id= 458 frame=119 dur_idx= 2 dur_val= 2 p=1.0000 plog=-15.9953 t0= 952 t1= 968 word_start=true "▁do" + [31] id= 509 frame=121 dur_idx= 3 dur_val= 3 p=1.0000 plog=-15.9605 t0= 968 t1= 992 word_start=true "▁for" + [32] id= 3629 frame=124 dur_idx= 2 dur_val= 2 p=0.9994 plog=-12.2083 t0= 992 t1=1008 word_start=true "▁your" + [33] id= 867 frame=126 dur_idx= 2 dur_val= 2 p=0.9969 plog=-9.1252 t0=1008 t1=1024 word_start=true "▁co" + [34] id= 331 frame=128 dur_idx= 1 dur_val= 1 p=0.9999 plog=-12.6911 t0=1024 t1=1032 word_start=false "un" + [35] id= 958 frame=129 dur_idx= 1 dur_val= 1 p=1.0000 plog=-8.8885 t0=1032 t1=1040 word_start=false "tr" + [36] id= 7893 frame=130 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.1441 t0=1040 t1=1056 word_start=false "y" + [37] id= 7883 frame=132 dur_idx= 4 dur_val= 4 p=0.9567 plog=-11.5227 t0=1056 t1=1056 word_start=false "." +``` + +### Model conversion +Clone the original model from Hugging Face: +```console +$ git clone https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3 +``` +Convert the model: +```console +(venv) $ python models/convert-parakeet-to-ggml.py \ + --model <path to cloned model> \ + --out-dir models \ + --out-name ggml-parakeet-tdt-0.6b-v3-f16.bin +``` + +[Parakeet]: https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3 diff --git a/examples/parakeet-cli/parakeet-cli.cpp b/examples/parakeet-cli/parakeet-cli.cpp new file mode 100644 index 00000000000..03ddc7f8b8c --- /dev/null +++ b/examples/parakeet-cli/parakeet-cli.cpp @@ -0,0 +1,243 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include <cstdio> +#include <string> +#include <thread> +#include <vector> +#include <cstring> +#include <fstream> + +// command-line parameters +struct parakeet_params { + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + + bool use_gpu = true; + int32_t gpu_device = 0; + + bool print_segments = false; + bool output_txt = false; + bool no_prints = false; + + std::string model = "models/ggml-parakeet-tdt-0.6b-v3.bin"; + std::string output_file = ""; + std::vector<std::string> fname_inp = {}; +}; + +static void parakeet_print_usage(int argc, char ** argv, const parakeet_params & params); + +static char * requires_value_error(const std::string & arg) { + fprintf(stderr, "error: argument %s requires value\n", arg.c_str()); + exit(1); +} + +static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & params) { + if (const char * env_device = std::getenv("PARAKEET_ARG_DEVICE")) { + params.gpu_device = std::stoi(env_device); + } + + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-"){ + params.fname_inp.push_back(arg); + continue; + } + + if (arg[0] != '-') { + params.fname_inp.push_back(arg); + continue; + } + + if (arg == "-h" || arg == "--help") { + parakeet_print_usage(argc, argv, params); + exit(0); + } + #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg)) + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } + else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); } + else if (arg == "-ps" || arg == "--print-segments") { params.print_segments = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-of" || arg == "--output-file") { params.output_file = ARGV_NEXT; } + else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; } + else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + parakeet_print_usage(argc, argv, params); + exit(1); + } + } + + return true; +} + +static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_params & params) { + fprintf(stderr, "\n"); + fprintf(stderr, "usage: %s [options] file0 file1 ...\n", argv[0]); + fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n"); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -m, --model FILE [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f, --file FILE [%-7s] input audio file\n", ""); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -dev N, --device N [%-7d] GPU device to use\n", params.gpu_device); + fprintf(stderr, " -ps, --print-segments [%-7s] print segment information\n", params.print_segments ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -of, --output-file FILE [%-7s] output file path (without file extension)\n", ""); + fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false"); + fprintf(stderr, "\n"); +} + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + bool * is_first = (bool *) user_data; + + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, *is_first, text_buf, sizeof(text_buf)); + printf("%s", text_buf); + fflush(stdout); + + *is_first = false; +} + +static void cb_log_disable(enum ggml_log_level , const char * , void * ) { } + +int main(int argc, char ** argv) { + ggml_backend_load_all(); + + parakeet_params params; + + if (parakeet_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.no_prints) { + parakeet_log_set(cb_log_disable, NULL); + } + + if (params.fname_inp.empty()) { + fprintf(stderr, "error: no input files specified\n"); + parakeet_print_usage(argc, argv, params); + return 1; + } + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + ctx_params.use_gpu = params.use_gpu; + ctx_params.gpu_device = params.gpu_device; + + if (!params.no_prints) { + fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str()); + } + + + struct parakeet_context * pctx = parakeet_init_from_file_with_params(params.model.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str()); + return 1; + } + + if (!params.no_prints) { + fprintf(stderr, "Successfully loaded Parakeet model\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info()); + } + + // Process each input file + for (const auto & fname : params.fname_inp) { + if (!params.no_prints) { + fprintf(stderr, "\nProcessing file: %s\n", fname.c_str()); + } + + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + if (!read_audio_data(fname.c_str(), pcmf32, pcmf32s, false)) { + fprintf(stderr, "error: failed to read audio file '%s'\n", fname.c_str()); + continue; + } + + if (pcmf32.empty()) { + fprintf(stderr, "error: no audio data in file '%s'\n", fname.c_str()); + continue; + } + + bool is_first = true; + struct parakeet_full_params full_params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + full_params.n_threads = params.n_threads; + full_params.new_token_callback = token_callback; + full_params.new_token_callback_user_data = &is_first; + + const int mel_frames = (int)(pcmf32.size() / PARAKEET_HOP_LENGTH); + int ret = parakeet_full(pctx, full_params, pcmf32.data(), pcmf32.size()); + + if (ret != 0) { + fprintf(stderr, "error: failed to process audio file '%s'\n", fname.c_str()); + continue; + } + + printf("\n"); + + if (params.output_txt) { + const std::string fname_out = (!params.output_file.empty() ? params.output_file : fname) + ".txt"; + + std::ofstream fout(fname_out); + if (fout.is_open()) { + const int n_segments = parakeet_full_n_segments(pctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = parakeet_full_get_segment_text(pctx, i); + fout << text << "\n"; + } + fout.close(); + if (!params.no_prints) { + fprintf(stderr, "Output written to: %s\n", fname_out.c_str()); + } + } else { + fprintf(stderr, "error: failed to open '%s' for writing\n", fname_out.c_str()); + } + } + + if (!params.no_prints) { + parakeet_print_timings(pctx); + } + + if (params.print_segments) { + const int n_segments = parakeet_full_n_segments(pctx); + fprintf(stderr, "\nSegments (%d):\n", n_segments); + + for (int i = 0; i < n_segments; i++) { + const char * text = parakeet_full_get_segment_text(pctx, i); + const int64_t t0 = parakeet_full_get_segment_t0(pctx, i); + const int64_t t1 = parakeet_full_get_segment_t1(pctx, i); + const int n_tokens = parakeet_full_n_tokens(pctx, i); + + fprintf(stderr, "Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text); + fprintf(stderr, "Tokens [%d]:\n", n_tokens); + + for (int j = 0; j < n_tokens; j++) { + parakeet_token_data token_data = parakeet_full_get_token_data(pctx, i, j); + const char * token_str = parakeet_token_to_str(pctx, token_data.id); + + fprintf(stderr, " [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%s \"%s\"\n", + j, + token_data.id, + token_data.frame_index, + token_data.duration_idx, + token_data.duration_value, + token_data.p, + token_data.plog, + (long long)token_data.t0, + (long long)token_data.t1, + token_data.is_word_start ? "true": "false", + token_str); + } + } + } + } + + parakeet_free(pctx); + + return 0; +} diff --git a/examples/parakeet-quantize/CMakeLists.txt b/examples/parakeet-quantize/CMakeLists.txt new file mode 100644 index 00000000000..6b46da18d27 --- /dev/null +++ b/examples/parakeet-quantize/CMakeLists.txt @@ -0,0 +1,7 @@ +set(TARGET parakeet-quantize) +add_executable(${TARGET} parakeet-quantize.cpp) + +include(DefaultTargetOptions) + +target_link_libraries(${TARGET} PRIVATE common parakeet ${CMAKE_THREAD_LIBS_INIT}) +install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/parakeet-quantize/parakeet-quantize.cpp b/examples/parakeet-quantize/parakeet-quantize.cpp new file mode 100644 index 00000000000..a5d9616420f --- /dev/null +++ b/examples/parakeet-quantize/parakeet-quantize.cpp @@ -0,0 +1,230 @@ +#include "ggml.h" +#include "ggml-backend.h" + +#include "common-ggml.h" + +#include <cassert> +#include <cstdio> +#include <cstring> +#include <fstream> +#include <string> +#include <vector> + +struct parakeet_hparams { + int32_t n_vocab = 0; + int32_t n_audio_ctx = 0; + int32_t n_audio_state = 0; + int32_t n_audio_head = 0; + int32_t n_audio_layer = 0; + int32_t n_mels = 0; + int32_t ftype = 0; + int32_t n_fft = 0; + int32_t subsampling_factor = 0; + int32_t n_subsampling_channels = 0; + int32_t n_conv_kernel = 0; + int32_t n_pred_dim = 0; + int32_t n_pred_layers = 0; + int32_t n_tdt_durations = 0; + int32_t n_max_tokens = 0; +}; + +static bool parakeet_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) { + printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); + + auto finp = std::ifstream(fname_inp, std::ios::binary); + if (!finp) { + fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str()); + return false; + } + + auto fout = std::ofstream(fname_out, std::ios::binary); + if (!fout) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str()); + return false; + } + + // magic + { + uint32_t magic; + finp.read((char *) &magic, sizeof(magic)); + if (magic != GGML_FILE_MAGIC) { + fprintf(stderr, "%s: invalid model file (bad magic)\n", __func__); + return false; + } + fout.write((char *) &magic, sizeof(magic)); + } + + // hparams + parakeet_hparams hparams; + { + finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + finp.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx)); + finp.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state)); + finp.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head)); + finp.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer)); + finp.read((char *) &hparams.n_mels, sizeof(hparams.n_mels)); + finp.read((char *) &hparams.ftype, sizeof(hparams.ftype)); + finp.read((char *) &hparams.n_fft, sizeof(hparams.n_fft)); + finp.read((char *) &hparams.subsampling_factor, sizeof(hparams.subsampling_factor)); + finp.read((char *) &hparams.n_subsampling_channels, sizeof(hparams.n_subsampling_channels)); + finp.read((char *) &hparams.n_conv_kernel, sizeof(hparams.n_conv_kernel)); + finp.read((char *) &hparams.n_pred_dim, sizeof(hparams.n_pred_dim)); + finp.read((char *) &hparams.n_pred_layers, sizeof(hparams.n_pred_layers)); + finp.read((char *) &hparams.n_tdt_durations, sizeof(hparams.n_tdt_durations)); + finp.read((char *) &hparams.n_max_tokens, sizeof(hparams.n_max_tokens)); + + const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR; + const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype; + + fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); + fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels); + fprintf(stderr, "%s: ftype (src) = %d\n", __func__, hparams.ftype); + fprintf(stderr, "%s: qntvr (src) = %d\n", __func__, qntvr_src); + fprintf(stderr, "%s: ftype (dst) = %d\n", __func__, ftype_dst); + fprintf(stderr, "%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION); + + fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + fout.write((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx)); + fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state)); + fout.write((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head)); + fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer)); + fout.write((char *) &hparams.n_mels, sizeof(hparams.n_mels)); + fout.write((char *) &ftype_dst, sizeof(ftype_dst)); + fout.write((char *) &hparams.n_fft, sizeof(hparams.n_fft)); + fout.write((char *) &hparams.subsampling_factor, sizeof(hparams.subsampling_factor)); + fout.write((char *) &hparams.n_subsampling_channels, sizeof(hparams.n_subsampling_channels)); + fout.write((char *) &hparams.n_conv_kernel, sizeof(hparams.n_conv_kernel)); + fout.write((char *) &hparams.n_pred_dim, sizeof(hparams.n_pred_dim)); + fout.write((char *) &hparams.n_pred_layers, sizeof(hparams.n_pred_layers)); + fout.write((char *) &hparams.n_tdt_durations, sizeof(hparams.n_tdt_durations)); + fout.write((char *) &hparams.n_max_tokens, sizeof(hparams.n_max_tokens)); + } + + // mel filterbank + { + int32_t n_mel, n_fb; + finp.read((char *) &n_mel, sizeof(n_mel)); + fout.write((char *) &n_mel, sizeof(n_mel)); + finp.read((char *) &n_fb, sizeof(n_fb)); + fout.write((char *) &n_fb, sizeof(n_fb)); + + const size_t n = (size_t) n_mel * n_fb; + std::vector<float> buf(n); + finp.read((char *) buf.data(), n * sizeof(float)); + fout.write((char *) buf.data(), n * sizeof(float)); + } + + // window function + { + int32_t n_window; + finp.read((char *) &n_window, sizeof(n_window)); + fout.write((char *) &n_window, sizeof(n_window)); + + std::vector<float> buf(n_window); + finp.read((char *) buf.data(), n_window * sizeof(float)); + fout.write((char *) buf.data(), n_window * sizeof(float)); + } + + // TDT durations + { + std::vector<uint32_t> buf(hparams.n_tdt_durations); + finp.read((char *) buf.data(), hparams.n_tdt_durations * sizeof(uint32_t)); + fout.write((char *) buf.data(), hparams.n_tdt_durations * sizeof(uint32_t)); + } + + // vocab + { + int32_t n_tokens; + finp.read((char *) &n_tokens, sizeof(n_tokens)); + fout.write((char *) &n_tokens, sizeof(n_tokens)); + + for (int i = 0; i < n_tokens; ++i) { + int32_t len; + finp.read((char *) &len, sizeof(len)); + fout.write((char *) &len, sizeof(len)); + + std::string token(len, '\0'); + finp.read(&token[0], len); + fout.write(&token[0], len); + } + } + + // tensors — quantize 2D weights skipping tensors that must stay F32: + // ggml_ssm_conv / ggml_conv2d_dw CUDA kernels require F32 weights. + // pos_bias_u / pos_bias_v are declared F32 in the loader. + const std::vector<std::string> to_quant = { ".*" }; + std::vector<std::string> to_skip = { + // CUDA kernel constraints (ggml_ssm_conv / ggml_conv2d_dw require F32 weights) + "encoder\\.layers\\..+\\.conv\\.depthwise_conv\\.weight", + // Declared F32 in loader (pos_bias tensors) + "encoder\\.layers\\..+\\.self_attn\\.pos_bias_u", + "encoder\\.layers\\..+\\.self_attn\\.pos_bias_v", + }; + + // Prediction/joint tensors use n_pred_dim as their inner dimension. K-quant + // types (block size 256) cannot quantize 640 evenly, so keep them F32. For + // other types (Q8_0, Q4_0, block size 32) 640 is divisible and they can be + // quantized normally. The loader mirrors this logic at load time. + { + const ggml_type qtype = ggml_ftype_to_ggml_type(ftype); + const int32_t blck = ggml_blck_size(qtype); + if (blck > 1 && hparams.n_pred_dim % blck != 0) { + to_skip.push_back("decoder\\.prediction\\.embed\\.weight"); + to_skip.push_back("decoder\\.prediction\\.dec_rnn\\.lstm\\.weight_ih_l.*"); + to_skip.push_back("decoder\\.prediction\\.dec_rnn\\.lstm\\.weight_hh_l.*"); + to_skip.push_back("joint\\.pred\\.weight"); + to_skip.push_back("joint\\.joint_net\\.2\\.weight"); + } + } + + if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, to_skip)) { + fprintf(stderr, "%s: failed to quantize tensors\n", __func__); + return false; + } + + finp.close(); + fout.close(); + + return true; +} + +int main(int argc, char ** argv) { + ggml_backend_load_all(); + + if (argc != 4) { + fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); + ggml_print_ftypes(stderr); + return 1; + } + + // initialise F16 lookup tables + { + struct ggml_init_params params = { 0, NULL, false }; + struct ggml_context * ctx = ggml_init(params); + ggml_free(ctx); + } + + const std::string fname_inp = argv[1]; + const std::string fname_out = argv[2]; + const ggml_ftype ftype = ggml_parse_ftype(argv[3]); + + if (ftype == GGML_FTYPE_UNKNOWN) { + fprintf(stderr, "%s: invalid quantization type\n", argv[0]); + ggml_print_ftypes(stderr); + return 1; + } + + const int64_t t_start_us = ggml_time_us(); + + if (!parakeet_model_quantize(fname_inp, fname_out, ftype)) { + fprintf(stderr, "%s: failed to quantize model from '%s'\n", argv[0], fname_inp.c_str()); + return 1; + } + + printf("\n%s: quantize time = %8.2f ms\n", argv[0], (ggml_time_us() - t_start_us) / 1000.0f); + printf("%s: output model = %s\n", argv[0], fname_out.c_str()); + + return 0; +} diff --git a/include/parakeet.h b/include/parakeet.h new file mode 100644 index 00000000000..d35aa870adb --- /dev/null +++ b/include/parakeet.h @@ -0,0 +1,342 @@ +#ifndef PARAKEET_H +#define PARAKEET_H + +#include "ggml.h" +#include "ggml-cpu.h" + +#include <stddef.h> +#include <stdint.h> +#include <stdbool.h> + +#ifdef __GNUC__ +# define PARAKEET_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define PARAKEET_DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define PARAKEET_DEPRECATED(func, hint) func +#endif + +#ifdef PARAKEET_SHARED +# ifdef _WIN32 +# ifdef PARAKEET_BUILD +# define PARAKEET_API __declspec(dllexport) +# else +# define PARAKEET_API __declspec(dllimport) +# endif +# else +# define PARAKEET_API __attribute__ ((visibility ("default"))) +# endif +#else +# define PARAKEET_API +#endif + +#define PARAKEET_SAMPLE_RATE 16000 +#define PARAKEET_HOP_LENGTH 160 + +#ifdef __cplusplus +extern "C" { +#endif + + struct parakeet_context; + struct parakeet_state; + struct parakeet_full_params; + + typedef int32_t parakeet_pos; + typedef int32_t parakeet_token; + typedef int32_t parakeet_seq_id; + + struct parakeet_context_params { + bool use_gpu; + int gpu_device; // CUDA device + }; + + typedef struct parakeet_token_data { + parakeet_token id; // the BPE subword ID (0-8191) + + int duration_idx; // index into the models durations array + int duration_value; // actual duration value + int frame_index; + + float p; + float plog; + + int64_t t0; + int64_t t1; + + bool is_word_start; + } parakeet_token_data; + + typedef struct parakeet_model_loader { + void * context; + + size_t (*read)(void * ctx, void * output, size_t read_size); + bool (*eof)(void * ctx); + void (*close)(void * ctx); + } parakeet_model_loader; + + PARAKEET_API const char * parakeet_version(void); + + // Various functions for loading a ggml parakeet model. + // Allocate (almost) all memory needed for the model. + // Return NULL on failure + PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params (const char * path_model, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_with_params (struct parakeet_model_loader * loader, struct parakeet_context_params params); + + // These are the same as the above, but the internal state of the context is not allocated automatically + // It is the responsibility of the caller to allocate the state using parakeet_init_state() (#523) + PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params_no_state (const char * path_model, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_with_params_no_state (struct parakeet_model_loader * loader, struct parakeet_context_params params); + + PARAKEET_API struct parakeet_state * parakeet_init_state(struct parakeet_context * ctx); + + // Frees all allocated memory + PARAKEET_API void parakeet_free (struct parakeet_context * ctx); + PARAKEET_API void parakeet_free_state(struct parakeet_state * state); + PARAKEET_API void parakeet_free_params(struct parakeet_full_params * params); + PARAKEET_API void parakeet_free_context_params(struct parakeet_context_params * params); + + // Convert RAW PCM audio to log mel spectrogram. + // The resulting spectrogram is stored inside the default state of the provided parakeet context. + // Returns 0 on success + PARAKEET_API int parakeet_pcm_to_mel( + struct parakeet_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + PARAKEET_API int parakeet_pcm_to_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * samples, + int n_samples, + int n_threads); + + // This can be used to set a custom log mel spectrogram inside the default state of the provided parakeet context. + // Use this instead of parakeet_pcm_to_mel() if you want to provide your own log mel spectrogram. + // n_mel must be 128 + // Returns 0 on success + PARAKEET_API int parakeet_set_mel( + struct parakeet_context * ctx, + const float * data, + int n_len, + int n_mel); + + PARAKEET_API int parakeet_set_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * data, + int n_len, + int n_mel); + + // Run the Parakeet encoder on the log mel spectrogram stored inside the default state in the provided parakeet context. + // Make sure to call parakeet_pcm_to_mel() or parakeet_set_mel() first. + // offset can be used to specify the offset of the first frame in the spectrogram. + // Returns 0 on success + PARAKEET_API int parakeet_encode( + struct parakeet_context * ctx, + int offset, + int n_threads); + + PARAKEET_API int parakeet_encode_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + int offset, + int n_threads); + + // Convert the provided text into tokens. + // The tokens pointer must be large enough to hold the resulting tokens. + // Returns the number of tokens on success, no more than n_max_tokens + // Returns a negative number on failure - the number of tokens that would have been returned + // TODO: not sure if correct + PARAKEET_API int parakeet_tokenize( + struct parakeet_context * ctx, + const char * text, + parakeet_token * tokens, + int n_max_tokens); + + // Return the number of tokens in the provided text + // Equivalent to: -parakeet_tokenize(ctx, text, NULL, 0) + int parakeet_token_count(struct parakeet_context * ctx, const char * text); + + PARAKEET_API int parakeet_n_len (struct parakeet_context * ctx); // mel length + PARAKEET_API int parakeet_n_len_from_state(struct parakeet_state * state); // mel length + PARAKEET_API int parakeet_n_vocab (struct parakeet_context * ctx); + PARAKEET_API int parakeet_n_audio_ctx (struct parakeet_context * ctx); + + PARAKEET_API int parakeet_model_n_vocab (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_ctx (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_state(struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_head (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_layer(struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_mels (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_ftype (struct parakeet_context * ctx); + + // Token logits obtained from the last call to parakeet_full/parakeet_chunk + // The logits for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab + PARAKEET_API float * parakeet_get_logits (struct parakeet_context * ctx); + PARAKEET_API float * parakeet_get_logits_from_state(struct parakeet_state * state); + + // Token Id -> String. Uses the vocabulary in the provided context + PARAKEET_API const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token); + + PARAKEET_API int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len); + + // Special tokens + PARAKEET_API parakeet_token parakeet_token_blank(struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_unk (struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_bos (struct parakeet_context * ctx); + + // Performance information from the default state. + struct parakeet_timings { + float sample_ms; + float encode_ms; + float decode_ms; + }; + PARAKEET_API struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx); + PARAKEET_API void parakeet_print_timings(struct parakeet_context * ctx); + PARAKEET_API void parakeet_reset_timings(struct parakeet_context * ctx); + + // Print system information + PARAKEET_API const char * parakeet_print_system_info(void); + + // Available sampling strategies + enum parakeet_sampling_strategy { + PARAKEET_SAMPLING_GREEDY, + }; + + // Token callback. + // Called for each new predicted token. + // Use the parakeet_full_...() functions to obtain the text segments + typedef void (*parakeet_new_token_callback)( + struct parakeet_context * ctx, + struct parakeet_state * state, + const parakeet_token_data * token_data, + void * user_data); + + // Text segment callback + // Called on every newly generated text segment + // Use the parakeet_full_...() functions to obtain the text segments + typedef void (*parakeet_new_segment_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int n_new, void * user_data); + + // Progress callback + typedef void (*parakeet_progress_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int progress, void * user_data); + + // Encoder begin callback + // If not NULL, called before the encoder starts + // If it returns false, the computation is aborted + typedef bool (*parakeet_encoder_begin_callback)(struct parakeet_context * ctx, struct parakeet_state * state, void * user_data); + + // Parameters for the parakeet_full() function + // If you change the order or add new parameters, make sure to update the default values in parakeet.cpp: + // parakeet_full_default_params() + struct parakeet_full_params { + enum parakeet_sampling_strategy strategy; + + int n_threads; + int offset_ms; // start offset in ms + int duration_ms; // audio duration to process in ms + + bool no_context; // do not use past transcription (if any) as context + + int audio_ctx; // overwrite the audio context size (0 = use default) + + // called for every newly generated text segment + parakeet_new_segment_callback new_segment_callback; + void * new_segment_callback_user_data; + + // called for every newly generated token + parakeet_new_token_callback new_token_callback; + void * new_token_callback_user_data; + + // called on each progress update + parakeet_progress_callback progress_callback; + void * progress_callback_user_data; + + // called each time before the encoder starts + parakeet_encoder_begin_callback encoder_begin_callback; + void * encoder_begin_callback_user_data; + + // called each time before ggml computation starts + ggml_abort_callback abort_callback; + void * abort_callback_user_data; + }; + + // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see parakeet_free_context_params() & parakeet_free_params() + PARAKEET_API struct parakeet_context_params * parakeet_context_default_params_by_ref(void); + PARAKEET_API struct parakeet_context_params parakeet_context_default_params (void); + + PARAKEET_API struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy); + PARAKEET_API struct parakeet_full_params parakeet_full_default_params (enum parakeet_sampling_strategy strategy); + + // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + // Not thread safe for same context + PARAKEET_API int parakeet_full( + struct parakeet_context * ctx, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + PARAKEET_API int parakeet_full_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + // Process a single chunk of audio data that fits within the model's audio context window. + // This is more efficient than parakeet_full() for short audio clips. + PARAKEET_API int parakeet_chunk( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + // Number of generated text segments + PARAKEET_API int parakeet_full_n_segments (struct parakeet_context * ctx); + PARAKEET_API int parakeet_full_n_segments_from_state(struct parakeet_state * state); + + // Get the start and end time of the specified segment + PARAKEET_API int64_t parakeet_full_get_segment_t0 (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment); + + PARAKEET_API int64_t parakeet_full_get_segment_t1 (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment); + + // Get the text of the specified segment + PARAKEET_API const char * parakeet_full_get_segment_text (struct parakeet_context * ctx, int i_segment); + PARAKEET_API const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment); + + // Get number of tokens in the specified segment + PARAKEET_API int parakeet_full_n_tokens (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment); + + // Get the token text of the specified token in the specified segment + PARAKEET_API const char * parakeet_full_get_token_text (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token); + + // Get the token id of the specified token in the specified segment + PARAKEET_API parakeet_token parakeet_full_get_token_id (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Get token data for the specified token in the specified segment + PARAKEET_API parakeet_token_data parakeet_full_get_token_data (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Get the probability of the specified token in the specified segment + PARAKEET_API float parakeet_full_get_token_p (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Control logging output; default behavior is to print to stderr + + PARAKEET_API void parakeet_log_set(ggml_log_callback log_callback, void * user_data); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/models/convert-parakeet-to-ggml.py b/models/convert-parakeet-to-ggml.py new file mode 100755 index 00000000000..2d6a6d01554 --- /dev/null +++ b/models/convert-parakeet-to-ggml.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +# Convert Parakeet TDT model from NeMo format to ggml format +# +# Usage: python convert-parakeet-to-ggml.py --model parakeet-model.nemo --output-dir output-dir [--use-f32] +# +# The NeMo file is a tar archive containing: +# - model_weights.ckpt (PyTorch checkpoint) +# - model_config.yaml (model configuration) +# - tokenizer files +# +# This script extracts the NeMo archive, loads the model weights and configuration, +# and saves them in ggml format compatible with whisper.cpp. +# + +import torch +import argparse +import io +import os +import sys +import struct +import tarfile +import tempfile +import shutil +import yaml +import numpy as np +from pathlib import Path +from typing import Optional + +def hz_to_mel(freq): + return 2595.0 * np.log10(1.0 + freq / 700.0) + +def mel_to_hz(mel): + return 700.0 * (10.0**(mel / 2595.0) - 1.0) + +def extract_nemo_archive(nemo_path, extract_dir): + print(f"Extracting {nemo_path} to {extract_dir}") + with tarfile.open(nemo_path, 'r') as tar: + tar.extractall(path=extract_dir) + print("Extraction complete") + +def load_model_config(config_path): + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + return config + +def load_tokenizer(extract_dir, config): + tokenizer_model_path = None + tokenizer_vocab_path = None + + for file in os.listdir(extract_dir): + if file.endswith('_tokenizer.model'): + tokenizer_model_path = os.path.join(extract_dir, file) + elif file.endswith('tokenizer.vocab'): + tokenizer_vocab_path = os.path.join(extract_dir, file) + + if not tokenizer_model_path: + raise FileNotFoundError("Tokenizer model file not found") + + if not tokenizer_vocab_path: + raise FileNotFoundError("Tokenizer vocab file not found") + + tokens = {} + with open(tokenizer_vocab_path, 'r', encoding='utf-8') as f: + for idx, line in enumerate(f): + parts = line.strip().split('\t') + if len(parts) >= 1: + token = parts[0] + tokens[token.encode('utf-8')] = idx + + print(f"Loaded {len(tokens)} tokens from {os.path.basename(tokenizer_vocab_path)}") + + if len(tokens) != 8192: + print(f"WARNING: Expected 8192 tokens, got {len(tokens)}") + + return tokens + +def write_tensor(fout, name, data, use_f16=True, force_f32=False): + if 'pre_encode.conv' in name and 'bias' in name and len(data.shape) == 1: + data = data.reshape(1, -1, 1, 1) + print(f" Reshaped conv bias {name} to {data.shape}") + + n_dims = len(data.shape) + + ftype = 1 if use_f16 and not force_f32 else 0 + if force_f32: + data = data.astype(np.float32) + elif use_f16: + if n_dims < 2 or 'bias' in name or 'norm' in name or \ + ('pre_encode.conv' in name and n_dims == 4) or \ + 'depthwise_conv.weight' in name: + data = data.astype(np.float32) + ftype = 0 + else: + data = data.astype(np.float16) + else: + data = data.astype(np.float32) + + dims_reversed = [data.shape[n_dims - 1 - i] for i in range(n_dims)] + print(f"Processing: {name} {list(data.shape)}, dtype: {data.dtype}, n_dims: {n_dims}, reversed: {dims_reversed}") + name_bytes = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(name_bytes) + + data.tofile(fout) + +def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None): + nemo_path = Path(nemo_path) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Create temporary directory for extraction + with tempfile.TemporaryDirectory() as temp_dir: + extract_nemo_archive(nemo_path, temp_dir) + + config_path = os.path.join(temp_dir, 'model_config.yaml') + config = load_model_config(config_path) + + print("Model configuration:") + print(f" Sample rate: {config['sample_rate']}") + print(f" Encoder layers: {config['encoder']['n_layers']}") + print(f" Encoder d_model: {config['encoder']['d_model']}") + print(f" Mel features: {config['preprocessor']['features']}") + + weights_path = os.path.join(temp_dir, 'model_weights.ckpt') + print(f"\nLoading model weights from {weights_path}") + checkpoint = torch.load(weights_path, map_location='cpu') + + # Extract state dict + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + print(f"Loaded {len(state_dict)} tensors") + + # Load tokenizer + print("\nLoading tokenizer...") + tokens = load_tokenizer(temp_dir, config) + print(f"Loaded {len(tokens)} tokens") + + # Prepare hyperparameters for the Parakeet ggml format. + hparams = { + 'n_audio_ctx': 5000, + 'n_audio_state': config['encoder']['d_model'], + 'n_audio_head': config['encoder']['n_heads'], + 'n_audio_layer': config['encoder']['n_layers'], + 'n_mels': config['preprocessor']['features'], + 'n_fft': config['preprocessor']['n_fft'], + 'subsampling_factor': config['encoder']['subsampling_factor'], + 'n_subsampling_channels': config['encoder']['subsampling_conv_channels'], + 'n_conv_kernel': config['encoder']['conv_kernel_size'], + + 'n_pred_dim': config['decoder']['prednet']['pred_hidden'], + 'n_pred_layers': config['decoder']['prednet']['pred_rnn_layers'], + 'n_vocab': config['decoder']['vocab_size'], + 'n_tdt_durations': config['model_defaults']['num_tdt_durations'], + 'n_max_tokens': config['decoding']['greedy']['max_symbols'], + } + + print("\nGGML hyperparameters:") + for key, value in hparams.items(): + print(f" {key}: {value}") + + # Create output file + if out_name: + fname_out = output_dir / out_name + else: + fname_out = output_dir / ("ggml-model-f32.bin" if not use_f16 else "ggml-model.bin") + print(f"\nWriting to {fname_out}") + + with open(fname_out, 'wb') as fout: + # Write magic number + fout.write(struct.pack("i", 0x67676d6c)) # 'ggml' in hex + + # Write hyperparameters + fout.write(struct.pack("i", hparams['n_vocab'])) + fout.write(struct.pack("i", hparams['n_audio_ctx'])) + fout.write(struct.pack("i", hparams['n_audio_state'])) + fout.write(struct.pack("i", hparams['n_audio_head'])) + fout.write(struct.pack("i", hparams['n_audio_layer'])) + fout.write(struct.pack("i", hparams['n_mels'])) + fout.write(struct.pack("i", 1 if use_f16 else 0)) + fout.write(struct.pack("i", hparams['n_fft'])) + fout.write(struct.pack("i", hparams['subsampling_factor'])) + fout.write(struct.pack("i", hparams['n_subsampling_channels'])) + fout.write(struct.pack("i", hparams['n_conv_kernel'])) + fout.write(struct.pack("i", hparams['n_pred_dim'])) + fout.write(struct.pack("i", hparams['n_pred_layers'])) + fout.write(struct.pack("i", hparams['n_tdt_durations'])) + fout.write(struct.pack("i", hparams['n_max_tokens'])) + + # Extract mel filterbank from model + fb_key = None + for key in state_dict.keys(): + if 'featurizer.fb' in key or 'filterbank' in key.lower(): + fb_key = key + break + + if not fb_key: + print("\nERROR: Mel filterbank not found in model!") + print("Expected tensor with 'featurizer.fb' or 'filterbank' in name") + print("\nAvailable preprocessor tensors:") + for key in sorted(state_dict.keys()): + if 'preprocessor' in key or 'featurizer' in key: + print(f" {key}: {state_dict[key].shape}") + raise ValueError("Mel filterbank tensor not found in model") + + print(f"\nUsing model's mel filterbank from: {fb_key}") + mel_filters = state_dict[fb_key].squeeze().numpy().astype(np.float32) + print(f" Filterbank shape: {mel_filters.shape}") + print(f" Filterbank min/max values: {mel_filters.min():.6f} / {mel_filters.max():.6f}") + print(f" Filterbank non-zero elements: {np.count_nonzero(mel_filters)} / {mel_filters.size}") + print(f" First row sum: {mel_filters[0].sum():.6f}") + + if len(mel_filters.shape) != 2: + raise ValueError(f"Expected 2D filterbank, got shape {mel_filters.shape}") + + n_mels, n_freqs = mel_filters.shape + fout.write(struct.pack("i", n_mels)) # n_mel + fout.write(struct.pack("i", n_freqs)) # n_fb (frequency bins) + + # Write mel filterbank + for i in range(n_mels): + for j in range(n_freqs): + fout.write(struct.pack("f", mel_filters[i, j])) + + # Extract window function from model + window_key = None + for key in state_dict.keys(): + if 'featurizer.window' in key or 'preproc' in key and 'window' in key: + window_key = key + break + + if not window_key: + print("\nERROR: Window function not found in model!") + print("Expected tensor with 'featurizer.window' in name") + raise ValueError("Window function tensor not found in model") + + print(f"\nUsing model's window function from: {window_key}") + window = state_dict[window_key].squeeze().numpy().astype(np.float32) + print(f" Window shape: {window.shape}") + print(f" Window min/max values: {window.min():.6f} / {window.max():.6f}") + print(f" Window non-zero elements: {np.count_nonzero(window)} / {window.size}") + print(f" Window sum: {window.sum():.6f}") + + if len(window.shape) != 1: + raise ValueError(f"Expected 1D window, got shape {window.shape}") + + n_window = window.shape[0] + fout.write(struct.pack("i", n_window)) + + # Write window function + for i in range(n_window): + fout.write(struct.pack("f", window[i])) + + # Write TDT durations + tdt_durations = config['model_defaults']['tdt_durations'] + if len(tdt_durations) != hparams['n_tdt_durations']: + raise ValueError(f"TDT durations count mismatch: {len(tdt_durations)} vs {hparams['n_tdt_durations']}") + + for duration in tdt_durations: + fout.write(struct.pack("I", duration)) + + fout.write(struct.pack("i", len(tokens))) + for token_bytes, idx in sorted(tokens.items(), key=lambda x: x[1]): + fout.write(struct.pack("i", len(token_bytes))) + fout.write(token_bytes) + + # Pre-collect prediction LSTM input-hidden biases so they can be + # folded into the hidden-hidden bias during the main write loop. + lstm_prefix = 'decoder.prediction.dec_rnn.lstm' + pred_bias_ih = {} + for key, t in state_dict.items(): + if f'{lstm_prefix}.bias_ih_l' in key: + layer_idx = int(key.rsplit('bias_ih_l', 1)[1]) + pred_bias_ih[layer_idx] = t.squeeze().numpy().astype(np.float32) + + print("\nConverting model weights...") + for name, tensor in state_dict.items(): + # Skip the filterbank and window - already written in preprocessing section + if name == fb_key: + continue + if name == window_key: + continue + + # bias_ih is folded into bias_hh below; skip writing it separately + if f'{lstm_prefix}.bias_ih_l' in name: + continue + + # Don't squeeze Conv2d weights - they need to preserve all 4 dimensions + if 'conv' in name and 'weight' in name and len(tensor.shape) == 4: + data = tensor.numpy() + else: + data = tensor.squeeze().numpy() + + # For prediction LSTM weights/biases: + # Fold bias_ih into bias_hh (bias_ih already skipped above). + # Reorder gates (input, forget, cell, output) from PyTorch layout + # [i, f, g, o] to [i, f, o, g] so the three sigmoid-gated outputs + # (i, f, o) are contiguous. + if name.startswith(f'{lstm_prefix}.'): + if f'{lstm_prefix}.bias_hh_l' in name: + layer_idx = int(name.rsplit('bias_hh_l', 1)[1]) + data = data.astype(np.float32) + pred_bias_ih[layer_idx] + name = name.replace('bias_hh_l', 'bias_h_l') + h = data.shape[0] // 4 + data = np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0) + + write_tensor(fout, name, data, use_f16=use_f16) + + print(f"\nConversion complete!") + print(f"Output file: {fname_out}") + print(f"File size: {fname_out.stat().st_size / (1024**2):.2f} MB") + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Convert Parakeet TDT model from NeMo format to ggml format' + ) + parser.add_argument('--model', type=str, required=True, + help='Path to Parakeet .nemo model file') + parser.add_argument('--out-dir', type=str, required=True, + help='Directory to write ggml model file') + parser.add_argument('--use-f32', action='store_true', default=False, + help='Use f32 instead of f16 (default: f16)') + parser.add_argument('--out-name', type=str, default=None, + help='Output file name (default: ggml-model.bin or ggml-model-f32.bin)') + + args = parser.parse_args() + + if not os.path.exists(args.model): + print(f"Error: {args.model} not found") + sys.exit(1) + + use_f16 = not args.use_f32 + convert_parakeet_to_ggml(args.model, args.out_dir, use_f16, args.out_name) diff --git a/models/for-tests-ggml-parakeet-tdt.bin b/models/for-tests-ggml-parakeet-tdt.bin new file mode 100644 index 0000000000000000000000000000000000000000..8b1dda1feba08f4439a0e5fd3e2edbb7c38cc41e GIT binary patch literal 16603 zcmbVz2|U$Zx4)qb2}z_vC_+&x%JAFkq>|D+P*Reph(eM^V=~VoN}^;|M26p9M-vSa z4NB5LsYE1Ess879pXc7^{om)k_r7<3KKpkL&RO>UuC>?xuC>?R<+6SI2mt{BpOFFr zV!T%&UKix`QM}G?9A4-Dei*L{|M8yxcmC&w^E&@?s=WS6IEBno4#XHG4m_>d1A!Ia zus-iRCcYCQ;UCgy{Mq12c{^(~Qn3SDrA08_IEZmu90gl9eIh5mZ^ZVym1Lc_Cn~?* z0!l}hvXM&}>bOc5O%A2NcN1lLYuPjUbV)IFS||pGCRC8ZRS!r^Tr;T}CC*8cyHBmu zgmLGZmH4nw3AFv9VDYB+?9nk+I7?;?a>6y>-9iGIBPHO;W^F90a3KQ&T}=BYThb?I z3&vl!kwvvS@WwHO@f3|C6H-(`E}$Gg=cLnuhuzdCd=k392&0TOf@0<-9Q(=(7VT(6 z$$dGLy`VvIM$5zJ(FSl&!2%S*F3~MzAL)vWjbQ#}J}q3FjuFz)Xk_P%O`*St<P#Zk z{@M&^<~C8|^=jmzh6?0tJx$$P-jnYhS~&A{8P)K(M$KfB&~xoK;=8V!sueFaahd#; zRHn9(d`oS(dg&G$QT~K3o1f0ces9DdMrCNYJPV^Oj-r0SU2>pyJ03e-MU|GF2jd+j zw4v4nq6K%O;pWx!+-wh|@p*K&vK2g5c#hS(TFBQu9>ie00`5{M!}Ke$%;PXo%r$i- z3x$hJ#){TL#yC&fFX)eAadXIhGdt4!Ef1P*t%G}|CB$!81=wpP;?<jzP;RyhPKmFk zoH>O!FG&qHMA%ci`IUI|moK&koFV%nE#Xz%Zzwxt#PQa7NNiUhMt2V}?g5!#@=kFj zUe0<*w_H-hLsLi55$?yqs$PXmb5o<sq5@FZ<}q6T9H0soHf)f^YbbCWPU_<hLT>tR z%H*EFHW@E;x?qS;=ZbM0%~qpPrX`qYY{!^`u3({XgA_QZ(B!Wg?9?+;FxI|}B=)|e z=U=9vXKXsQmd_yB6K_M~!$~A_;4^#ADu&2R&4hdBwh$?Y2HJ8m9Qp&Fk<8fZP?u&* zwpXNJ<HQ0?UiARRq&36q2Xm;Z-gTNhsseJ0!tvoq2b8%u84AR2({+zBD`GcIp)rqh z$!5)BhzlrVraheu!wrPFpTich-OsI1Mkoan?`=V&(K+<QekWk6v+zyVcN3{M{q)B5 zOuE_H2R>DaL-X_;U=5V%QmvT?#twAMs{Q0#-E3l18-kZtmobIKJBfv|D#zWV8#G<5 z@%^M&Dzmx*W9N>=9|>t>=Z2HOoHN47`=;P@7jf>JXbn2^XArIr7R6rgEm%IXiX}f^ z;-s1iOjsz!nSVP2_uGfz;KU9{{b<N}(-)7bp_SD3Wj^U&_#Cp&9ztd9gEaBC8Du53 zkf0M>+A1ba69x;gzSV>F*dM2#Sr^i&C51;_dPwdWUue;sNSwwMlU-#Sn3;oxb49Pr z#bAlEWG1O*_kSJExnzF}^22YCD^~N-$36mWy%Iqt=os1fsE;Hre}O5xT%acF9Z6J+ zz#``!+A}<exU^d``%c8u3q225U149k!BCMr-B1Z?3np`49(Sj?IGY>`J3&r3Q)=4$ z5=A~`lKbPHL(IDC_-N85@SJ^tGSd&Tb~lb;Ba=g$x1WbMnsQ)Y`5bm%PJn^Bmjn!7 zK*_b2M67=eXgoPVr!79tYz~=;w(AZ<W<vrhnVW-{(R}LC5f8m{Iq+%3dypJs45#Ma zLR1cB*9}X+pLxUZ)9xy8(w_tqI%ndMDqjpR6Jf}eWPDmT6)$5dS@0?r?lew-!x^Kw zeUFlf)v;K{<^2QF_u?Lm8##;{v3e|~FB}U6ZnyDmM<p;xm886EJ#5k%jrFTnLf*CG z5bNbiZj0<8u1+Cj=CDn8Rc0)0(N(0gmR=_ZzZ?a{BSE-XNPt?NVd3sZg5Qs<#*JDl zNSfAqu%2yfVmBchnk+X$psy9gHA%w4fiawSa;s^Hp#WEZeJ#XUeW1I#o0*NNnb0U- zi{JInVBXie7_ig=B^6!~9X%1K^9U#YkBYJKi7fYd+GZv#tcRYO-9b!>W`S0uB#}!r z1^W^wvig!J?01-rQXl8g+@p)h6sgtVmUxTKbsi2PipObhSvPC5aV?xw{EhcMpF_?S z7W+5IaoMvv__hgPO>;Co6ug8SbPk2oo-EKy3}=^_=0SUIB~xK*iK*HOIM<IwYGaT6 z%StJhE`u4S&9J7U5DQdpk<b}|cvPa7#2Jgwi5oxQiwqZ7`)WKE=856?&f}<fbP9To zc}5oAnn&%;E|7+0WyDP)3vzr5Xr4HWc=9;ZU&*5V*%2VP<T~jv7ok-#mq<<dH{f&* z<6IgW!3c^^Cd}1fAcxoEtmD~aP0NGIxGQJK1*>(~yYCKO%NfRX6N$vbjUTAKQUcZM z?+3l;a5!nP51)u!B9Xi9(@BZZROlu{1pJg?Nl_kMJncA$O^w0#rFF)OIU1Z{ySI^8 zk5cGztzhc%#=z0|9qgK;x>PG{5$;>rOssxwz&rNZQ1BrH1hXs1%*Yt3<NA&`j5tk) z`Rs-O=MPx%Fo$+F{K6@l5l6)4!xL=>cs6}5im&EkV>!b;rY3=7-i_x>$!a009BKTe z+(i!6R^tt$MEa!s4imDjkX`7$1gty@iPPArplvt<jT;?6$oD>6*|;A?IQIDCdp6{M znopezi(o|DYx;?l!`uuJ!r^g<XoCQfZ(ajG49lop=N{D0enVF!wv#(fHE5wX2|7NC z<FSqyh%z3>-D0yCXS6ipP5(8hnsS&-51K=Z%YCVh;3O)va3v`j>x<*Q%!UIQT5xgF z4f1`z6pX)qf@BaOj_c-2tgz83I=!8P(^UhoX-O&?9=FGr+$gp}Gly6nSj9~IJ|B;| zJA%d%2A&?6ME%vL;7x~-pmB&pOl-=~#4wbN)H+5SS3BZzM^kcUPC2TJ<-mi=RQR&k zmln#G6D^bR+*!hLsOOYGFBk`7;cEkM%KQy0lM_&2n<d(fb0boQ_u$&T6kH>Gl<*)J z;$D0{#^+Whp03V>Pr6Jz2%mqMO!pvs?q$b+Y~?@f`HyY<hXeoN$bW3-Q!^!A=Qloe z<97#F``z0d+;#W3JJ|l?z3y($y}DiwJGbxHyISQvJlXXCWxuCm|LZ5zeBp3R7g-7C zJiSPtUMqI@2*Zh|vM8^b4sls3IQN?xo(viR8Y||5xyu`pS=ESVD#by6TL2n<+(Wud zO3*T(jVw649Bp+nu%prdvfR8#yp9qYbIV}nsaMp}Uziao2qrFPm9ctm4_Ox_hNf2= zsH}57eAElXUryt>(`Vg-8Jd$gdym%96(9Pss{b^t4~?M?2TiGue*`LT9s|l|B2ZwC z82M`^=-CQ$Y~^Q=*OznYlASVizpoOUKlU8OEj3BohiZB}@FQ4S&!kp|AHkMI;dt6o z9$b&74`Diij~1`}H(+huY3K3Edo!-h)F<_`l0mfQ95s%4f*mg+D`ySgjw=@xP`;W} z_;PrSuQ`6_D~132ztMsHZr6RfGyYJ+;jx+Q{nbs_R(lV<>c5h<`@`|Z<jLHS4s#f} zH<b8r4X`cz0{Zo@1NoLqB;t?)=c`&H5%s-EG(TxmLxoikKU@Wl%{>Fop2M-cvxN#T z>!Uy1mO$lggvp&BXwSi?wBp_w(thC}y<O%3%1WcSJ!jY8o~7nk>NAEbba5Z%E^<P< zu4O}WQRa=C-}rO*a~}NHe?1pIN;#eppxbkR7Weq0_TruF{{@wqe;TL7G1&9O8~04u zi84Q>U{ljp8d;?dcM@zXjnB*ix5EKg^lM?I3R=;XXOnTHN)3<jPY1EWA}ZJ9hx@)= zXFqO@f|Pz4?z3pb#iqOPy=VoHf+*7Nx&ZEUN8^{Jk=VHC4UxI%%4i%I2@j@m=(8OW zWZ=Cjxi(=DC?qb2eyci!2hv>rT>f2DX6o`$(KGacEux8VGu;p6rMvL2sQeY*|ISus z{ejAjhTDvjg)793eG6lz4TH>8l<H?}!&#?|O$-f_D0cedePv;eqfZ-1I7LHKr89=T zwuT5g39RB@54n?ncNVjTW+B#>3I&;J5bMUkl97|(ue12yTD<-rvoJW6Pi&X9VbHuB zCTXh_KCDm2>kAv<%8KXE^F0sN<@k}suX3UIz(SL;*(=D4uc7E}dWV=;c`?h`^Ej$_ z6+3TnI(ta|EogUagUmo<lkr!NkP!;{OhL^G`YkL3*XNC=O9cR}Mx4RjFNC0RO$~W; zsSV{!f3S<+9)WXlL6ExKANJZ$Wd1L@)gMBoZS8tIR}z9!dkV>~8CQq!RpD6*zx{7@ zYw!mu3#v0f?QR(KI*QUsrm5ItmrkF#Z6@!#meZDQSyVpYOST9LVt0)S;=WY!t-hA} zeDx-duR=-PljtG4`ge6}Ff<Fs;VYFDzXHQC4i|Wda`_|I9-6~HJH}mh-VW{_Jjdm_ z+ug<1(ZOMxt>bQYFFW^bdVkEJXJr=-dUcr?oBcHT`n;QsbxMXS(o?xlN^cYW=heWM ztod(B=C7fN8>3KQ-&J~y<>2$)ZFpZIlRhpjfy2`~z(^pM^7lTSzsObLz5ml!{l}>P zdBD4Nx;ogo|2c5sr@fGIMTY6!{}MLIA}camgzTSKL9E;@I8B`!V2QLo!l?bAWLidp zx!>uylRIE`<}VZ9v<PCR=3sJWZ7xxbe@c#b^`l>gHk@%dh^D8mU>avPYy2yRNe$LR zU%D0??&e^Z^?kasTn;T72)nF0f}{)i;t3~1uov}ZET(=&m7?$1{CIJt>B6nFv-Sb& zW&Ea6>zgCdPI8BWxe=%>oPgq=L$UQ^EZJk73df!_)20b3+&zj-_-5WZTwN`ND;4S~ zBdW~IIh9JUtvm+%rFHNH*9LW__mPgs>zEr@O}2|RFxo%l>Di@s;HUXS?0hE7@o99V zW}fNzTQ8E?el-|JdvLIK(H)XzBn4Z|CSv@}Loi1x4l19NLVA!WJPODMLE4Syub-t- zUp8Ukv@gujC8N<?U7Jqm=mya{Iy9mA0TF)ULZhXVnb;{ej4MLkkh_}kaPN5o96NTJ zgnpEVpcS!LHc-UYmF$6ug`2>za6KvguoZ4D{6W&wli;;R92u)zLhXBVNOOcSH%a_D zNjkcd>v6If=PYk0<ww2{uiJ6BdZ_`nx)i|OhJGp(Ce4XY|52{|;Xa{#_u$)95mL9* z2X9?@$woEHu+Iweao3<9T_!-V<m@icy&=gY-;Sd1J)g4U4s0XA?Ml4O^C;5QfvD;j z4h>2VsNBZO=xp3d>PRO!&~**x{d6MJR+vD10g{kgsxTIdP~k`x`QCk!n$75^70sKW zWjzZWd2is1#RXVZ`I2Q%6cgX&3qXC^05h_EI63rcHP|$`vtgHv@V@>-+Pg>%6s9gP z7Mzh!R)$Q4xgAIkCe4Ej{Z(Y6)M<1-bp`5X2%vV;1N?O@5K}Jt;l-kLjG6&2SLK;Y zZ)Yfj`mlT=TAxP5AJr3q3MIJwPK29zqMSw@m*B=0$ucwUpQsRc7zo>kCBpoS6(oDC z0FKx`KvL92F)g^H+|}a{jK3{~eVpZ><D>(O%U)Dj<qhP~ObDd+nY1bMvGdDDoR=Yv z-t8lxDSa;{3%(>zdQRiU{*_c9@+BjDSdsI}q7c71Rl<@z<A_RN1KGCaGh<TJO44kV z=qlxMNGN$vM|DO4>6?rhxA(E54T7;{;2mlGb_F_ioo08WET%VdBJk23CtC3G4%%MY zi6@PA5SK?6K=Jl0j+?eTekd&k&Pi{Q{c=1Ve)u_=oN}0%GWIniAyZFH`!1n>{0)>g zo&slc!l`+Z0#4s3VUq7-0c$SKC3AYG;MD#U_^SiqJCDEe!T;8Q{K;3iYPaFI!A~M= z>J4E7r$Hez3X=7ez)WWx91t50QJ!I#xOEn+49&wAP8K-#*?GuOumTx-3hxv`Y1^S# zoT|T%-s-#0JieaBxG6@%o`Ea$Ia5q><PYOu<XcjI#lrY{FW|hxZZta4gx*{gLRVD0 z2I<qc*mm@$=_VrFn;*7e<i%rjMBH?|^Q;g&-e^+VkpTkq5HshH0`tD<G8=q*zwrUV zX7JrRhP!5*0GCKgptXprNkR4uI5hV&c~x<pitO<Nx99$F!P|tC@7V`hiJ`=rl}43} znV=T@nz?(bz=#G{!q{hu+`N64@ng>kT=uh!3flP5k=<NeHR}*g++u@!%#=6>Q{O_v z{RVnsN;y>?7mKHvFZgQnC)UX0C@SB~W5XWY#jB&{W3iqBd2qCXyrz#x&G0NNxN8l$ zO-u2+kQzK$rd#QyQ;6AZX|!)t1zFai2x_*Y=sKC}p!C%Trta{kgQ6S=JJt_cJ>Oyc z^iyc2=YkC~6skOfP;_w??&%t!eN!6ArpxuLWvB+uNScOa8x1gYiXEDXjON&`NhBFU z_ONNkck1FysCSM5p17>VP1#zA)>qrnb5NRlOfU&G60@L4$bg-xcMewUY9Ke)nZxun zM?9}80>i?MKr>YqM-XYq?->OL3r6D0l1L1QK7sXG7n#Q{ndIB+CFpbOsY$8SXr3~a zK`7%UDDF*%RoXsSb+i%2bcy0N8VE5ap$r-RfTnD1$K?F!^hHuNemvs^^EIx5jnGnj zpL~*-78&A}HWggF&<vI~DY3V5TR^A%EA@ML6bgK*8O19W@U*9m>i=58=7oG@8$~Tq z=ZF?u*{1=sddI?&itC^nPzKxqq@@B#EHWnI`88;w8$1rZHrK&9gK{D|-<BS!NrX)u zBe>G-0$jOD3%Hi&#T+&AAscu}%QI_frH`Qi2Y;=G@M;~BcjqPDJN!1SULeMGfA3G? z-PXgKyjysBwHDS+i9=IWFASI7$UJcz&V9&Hf^&l$2#oTfUDuSk7IH_iqV5!A-PYrt z=(<Q3j5vX-2bRIYhtqJ>gL`Dy)luLf-T}h*SJ2kuT(VT2qPJ}#TrM=lgm7sTnpaYp zII;#WnmUmeIbN{qM?Cqx<qnod{$!K;80fa^#!u0~sO7K)dKZU~?#>3%Dcl7bZ$FU@ z&YJM-kPavQlQoqvU5ct=&&Y0G@~`2s2qvfp;E{*29MSceM9Rw;cOGzpBSA0NJa1pB z)FVB_Y*cxUgWvwvc?@x;p}_AS(YK?`IgAJ`CdJRbKx2Lvd2yOcGwTIWPN9K1i26}J zRhl>??tc`>XGqPvtEhdf6Wo2u3w(2Jz|8R~UG|$xCeMq)_)*zI4sP0C#PJVs{tH7I z8s47WJ6-pB?euW4<^T96LF(U=3gwxeaHzEk7tLD>Qy15fAn#gSAyHkqAfuf0T1$ao zfIeBQ7*Cd_RziMP3OqbKhr2xJEM0U`lU`r&5=XXj8H>hLxG*LZsYx!`;o*cXLetP> z0t=(xj6mhgk3=@$6Ykp51+I(61M|v=4zrSl{=LOGb<+u0_v#!?;%zRA%YV?Ab7Sz+ zBRwz_@hAOHG~ndeQI(_YgkinMW7=6H#rb@(1|1yJ;KbMh7&FBhopYr)d!HPE=#>3% zE8>kw@Se?d!EQ@>rnj1mwl0F**@nb+?M}>;Tm!-Po#Fb4DKP9Q0fo{KYSQ3>o%1$h zvt>PAIO|Skwr8`W&0f)FO@=$}LMmSNiiEDo+UT8nh5Qh_gg#TQz|VuSRL3Wd#6BIv zDH(qirhZMKi=L0g$9d=J#p6Q2J=(~+ZMsbp?^eLex9{2g(_}eJV;QqDunl59mym>j zeW2)Li2ae)7`M5DJdimI0&SC-;0qb}Sur1X<d&e@yw})l_>$Qp7zJ&$8!OYZM{!Fm z=hK>E4f<rV2<NDOIld98p>`XZAj<wNS*$+_Am;;P({+J3%PfMUMk~<0W(!2o0(?60 z0(l`ihitu}g!><tRo>Y=5&d+6@z==9)MrIF@;At-f7lN{-uqh{AF`iq4tw_Q_z(N} zBQPjgc#a4d#G=Xq1GsE-0*ePmq0rljxZ`vdT%V9d@cw$Lx&Aj{nnk!zb7!!j`M=>w zpeT`+4`n8=9trkl4)8PgH+ow(;z>&@q9HT~-Mcd(WOEuVKHZP})iULe;r;8OZMEBL zzr&Wl68{*M`SP=P@%wacGTlvYv~#G=ny2)#LI{l8nFPxZ=ngGH^?ye!e<?Pc7UOD* zD8txkQV_G_G%+<j2=e3N(IZlTBYnzqXsmw?R{nk&{$;H0p02JtUANo1IM}($)&dCq zXvBx=uDDs@DrCh4(y7C(V0VT)S@z{I%Ln4G>->K$e4m~B`IvmTXz5OjKDB~++A?S# zF9}n9Zj(*%HQ?`RIOOAI|A7Yo|NM`#T|HfFhvwnnVY}De&feK!8~^KkTl(u5|Fy#Y z*%$rEwJvsy!AHs!v|FhVyP8#S&M_mL?Y|IW)mDJz3OVXlRzp{tOA{H1Rp@Xh1}!h` zL1lk6EcqCRJ)hOdxqC)1LW8#_G}W3sHL#;bP5DIS=bFk;>2mxQpN1vkJthhT{s5OH zxw66+N$|o%sK2C0M>mZG%fVT=VBZ&3;XoahiOX}xZE%A$%|rOv=soS(cM)5?i_u?b zCXKmagie7~M7gM)>@dq9*VpVNE56)<59g2LR$)8#$B!s*Yfb~}!OJMwnM&V3&I4~5 zArR|bL@M8HrV+l^sYi4y(VgQ<KRW8NHUnoEg#;Z;GZY)r!+&d)Lwk*f!!Ac#yS;yA zdk6iDI7|1o;y~LuQlD&!7w*e~WD#LLh|LE9uY@6?lla=<wZB!!KMUm>y2OZuWNdpB zd>51F_Qc)+buliv@iiD)C#Rrg-r6Cb@#i@CITpU&|BKE1IZRjGJ-a<@`C@H7pYD+( z60xC`o4cZjLwyLapI@Q=Qg3)#C~BgYyK*Sp|Gx;{=ew*D?^X){^%>9A_^A$Seo`#> zYR%k{`og*t#SiV0f6aXSEfD`X;(z22Y~B9wCU!jrRHjG|OA7_@LFZYxx<ni-+-^|e ziSE=isRqZsy-70Wy~0XQ4%l55CiT}8IHW!deK`%Rh3;GMJemk!jAr5$D+h8`RM<Gs zCW@3b&&1Q6CG3x|yI8wv1X0}imYjXLgZ?bLhc2TIgJ^m=QE4{B=h>0;lN=W_IQ}S@ zFOF|aPQlTVNhtAV5{#`A$No2m$>6O~G~LG*9NvpzviU8%+**h#T4Ff!RVuC!Z6Mm- zT{Pm`Sk7*j7Mv5fl3d>sh2v8P>4N*C$m@nN+%P|P;vLfqGFtCw@bEx#c7{K=v`(RQ zWn(b=oD7$*-+v!H=ReRJ7&{pnueY;jQhhP{tRRuyC5NBPMDeZu1W;O;gDz6Xz^?rg z)}GhG@sqxS*0Lx#*;dV*xVZ<n9TCONG6~Gri{qG(@#XCPxWP(|z>V~3%oQAIGz@De zJ3*X!0XB{|r4G+l;U$SSrtAAn%4Z(C3j|^6`%FxFG!+t7n4x%yEVkyJA**iwptEAO zfrXhFntLeYM^Q~IsvU(ZE``CJ0Tpn6(FAg;GMwNpe{7!<#Z-hR;B=WqD8Eb<n^p!P z$vB3KQI@*0$#^#23f}ZM;rF^(oaYnT>EPj5hOIq^bL^Vnu0<ql9G*!2Z|Lp&13l4g zxhP#`MI=Kb(NmXKoH1nugnC|M;yNdAHL4DxlFSUa3CXbNTR6y7-@#E1C&~6fU$9#6 z5PN(DD@W^xk@}xWG()x%+pEsgxVTJeab_&dhEaI9;Wd>?S7Ubuza|f!CsSXxo+?dd znb?<o*vot&bt$st`SEH@!Dmls?)!!(&JW`78DYe4_WO$Grn=-~@<%Ess}0Uot`Ky$ zpR5QCq+8oWxY-{UK(bv5)Cj5r%&^1Tho?cFg$*nTss`443zoEQr-sV_9%wi-Cx%60 zg84%-b<!AUQ!i#-G;JZ{Jp(ArW#L+^<B+fUANQ31vU30PFf7wr9^(gRnWzcdL1z}? zz0yoncM{^db_{3EO^(E>*6UPmr5320Ovagyc9DA1Bw8Rn4F>1;l4vzO99cAtZcPe= zi#u+S%(_dM@OTOYJ`6#l+G5C+^M@AaIC5T4m4uJ9q*$Z_>5o}zWavi5&as7v25GFL zN;vQ`5zldsW3$)?lzwak{vqw4>sA86lkTHzlp(~6mcZ}V?`iI}I%c)60QX$<aZ;{! z94ERgB}X4s<JEx?U_8km4jeJZ{A1B*J}A!JxcxY_Y8`|3f^MSh=XMgf^(`&E7>u3+ zlIW)r46@eaAzXd<kW2md@#DGdKkP5(urgdrj6?r~VNCWjd9pL2n>u`HA%|}-!W4fY zIBg~gc*%&{y)hAXj-1ckc3BCNi9hV{I7@%z@QPJ_hvIoRe`xOQ!x~;87kkN$-FYSv z#TLooj@f?riMt85`51$@WHX*gzeg*-*<fY#IT(3A7wxQSL1v2>yP!M{8|`$k&0rZ` z(n!aI*;g6MX{T^vs08=RwyX5lwYPXGoPz`9J8+noE!Id2aO$PEGa;r2$#A<?RCSty z7W-|;wMqVXZ{H61+LMIwr(@_rQF+XpuSqofCt=K`5GsGyjXoAI;T2X55IM!Y#OUDV zA>Q>@-}hhkxuNazuPNw1*^IbDGw+zw0*)VXqb{PSX!-WXOqJGCxD%CwR&oo6HfBwp zXXUrQ-H-l8GJhN0z2-AMJG7asOt}TuW_~ADS9Hl~Z)q&D{AFS%<@@jF%xCKUbI$yl z{$h-CileTREr{ig1$b!3a`fbdCW?|ZWcf)6aG2{toaDTqx3`G7aJh!WMJ4lU{VLe* zR#D6=mgFcqw$rcsv@zJAg^cR;AqMx0@MBRV$iL^ZKW9eZ80|zbo2CN(&1qPEpcOu! z*Mt<wDV0mY?m(aNByvFNZKbQY7HWJM1IvehAp-$>Ou|f-;JT5+N%rRmc5y*AZrE0d zZF!@>Do}$v$!axOR{e#Z*_#Q^GY4tk15H>m_YS@D<sGv+%>>3j{74OzkHX;ABlz-R z0#Tm13+4`Jg3R(hxVmRPehn%_3_F6?M{7aVr(dMUxt@_?BhhKq0#Lb@g8SCT!NQjY zoI|M<B-Tt0l-Cc^(U;q?-eNrdYULea%_+kr`<fxWIF#&kyh?QEo~Fvf_2`|C`=HZf z7#sOa6s#|s(`?-WdQGkvUu*nCw~jz6yK_79J~aaa^d;HzQ*9V=-m*?v=!9<q?vto# zH%V$vEb4RA2-D&R^SaDnrusGZ#K}+eilQpY9{WTrZ!IFJDyq;WE=mGUu7ZlPD#qrO zIgqZ+;Munf7r3>fPV^19xx=5B?R|zXD-AFrX(80=MN&_BVH~%z2z(`vVfC7&c*%1; zx%|GG<f;53Z_ReX%=H%Jc+__yTN*$Iw)^lR_GYHg)P_vGu@FR+rJz+&pLwOW0*z0E zFcE=oD!0`yBqehsIp5?&aDl-SvWZu>WOTWPnU^ZZ*&I<tMLwM)!AF*&<c}hH{M`|< zHFG^UZb>J3+G#ND@KyGgOg2@J(%}R=y+wL>H7W}$^w4{a8wy?94Ox-pbcM4Jc|6LG zm7i}6hS%4lb#5@~<vWo5PsTysyli4KcLko`tphnO1S9;1WA}MC+!%HqK1@-Ar!U9D znT%9$7n;raSlxsm!r*yU30Ku^K<Elam0d|Rw=5ZRLvqOJsx{2E#b3#h?}Bh)^LX%_ zj`TiPlv5(LhrD&@!l%o8u(qZJtpy`YhIxKtk39NBg^L6@lS4*A{3}CvWbKMytB>NQ zDtYFXrXD$++=<)I&ZUVFjwYj4+R%+gEAhMX2#g!APdD|IF{uy6!_f^PfW--PdH8#} zW6VpkZi^xWu4th~R!d;Q%>(db-8mS3rWHJDH{q(w_9kO*&NaE)RKqG5Xy6+Sie>j$ zrti`Mv{ou)_Lx~<a#a|<@7PVIg^ni=EXIJ|8wt*vfpC&;GeC}a2IJPU4AeY%8>a2t zg<n$xQBOmKtu492=8V{m9_lIo<UmyZm2>)MkM|D;@~7vUZgUEf)-A_o`%75A%nhuh z2FZ#hW%LNRjp?a`SHQM~rL*JVN`3%dRmfw_A5Y^IfGTh&tO<Y+>l3*0M;7rrd5Dxh z2mw2BNeEI`;99hs<GDadQg>kybqM}KI%Su@k!us6&1oaOvGyZ2&VEnzbsv%H*ibY~ zRt2xH9pL<4kPM7jz;#Sa$8)Yd(EeTmcMSZ*1jW;^v)<a|qH8z=)rOP#H6O@^<}#8P zF%x!G$#7o8?uSV8+gM-NMkEwtVEpGK7;tW)Zx!Pq$j1d#GA47Bcxh3=H-ea6kwP7R zeW#93>@lxy2DdA!h59J6BsOS}#+3Jik;X<7H_kbD^CXbCMy}!IV<U*=>;lMAabSXl zf01926=cfESo&l`9hoj5&Drl3LdRc!Oxk~)frj2_nEMq>+=3Oj>1M;AMM8?(u2Bh* zAvz@e!%?W}cEso_H!;aKkX%a_XPKNI^v%{hCPzt+<CVS+2G$(|&dBvJ%$5a>v!>L* zSD0I-IFG)sT>*^?_T!?Y*|cG-9v(I2f>SfV<~tfNUv(nn#wf5emo6cy#V0VkOM=^y zKaNPgK1XzKorL46x1gl22E30PMLC;dx}~d+l#fv3w0W?w!#ony!){XBaa*yaZ!zqD z9fnz&duj2g4mPyy4z!-UOmh?!xr-+W!nod<*m6pjF^RdzcBzHYYipK3$L#AUv2hqG zDouy1*-B{FS&MEOQz5!465cGG%l;hvj=RfegTB{8>>hubJe*;Vr#$0O%=9(gc*&SP z`Y;(yFUy0zeH*>qc>uclB8g(`Of<>ggj%l}@JCY|@d>+3B^$<aQznb!)NDt*S8xl? ze|d`2k6gh+vJv2vod>f^32wQsiYKd-VD_gm+(hFpTv8cH?{B>c!xUc7-qL%FkV-Vn zY82sE<;vpJ&fv-khVQ|p@hVma`GfX)KNNdjcXryPzaM&};H>BquJ_A_3d^64;o z_(UY*HFq=!Rqn^bHVo|Nu7dk%8l3VdE-ecFKzo}O!2QBi;Qd1kt(@=C@SqV8)w~A% z^%|(tp;G2?-x2tJb3D#_(19{*`p{!`8orXdfaa@}sgX`1a@%L&5p6SskXN)}Q5HCQ zYQbaG0F!iyk2LrGb83GtlS~McgVYT>@X0PYMoL78>vLy4F;JWjnhue)&;K1fNwmYF z3!BM6<!JW#j#`$geqdj3c*Z=C-2hn|3ZZeoY(>PCl^~Eg3opnmfR{4!dBu~VF!`V) zho3{@I}U#P*UaWWIl8jwclPY+kzBv!BT+v86u2cwf`01}-rDm8^AQ6>1zmquK>YXe z`25JZB^j)3oq5sFRhU>@O5|`mT!mBA^4mDv<kd8k3;1gSUV=BuZ4UoYPq%HS{ob9s zU3DE?wmNM4g9KmQwuxkpTn;;&e==?A@_3?WmPzKlQkXMY3c_Wcp~OxbI$Ny`qC2gr z{Jyv7vf?u=x~L27TO(O7ZA$$@Lcj$_pv?CeSW!Cy-i99_j|AmNgZ3;M+!zRoL7R{y z*I}=P5<E^+g{iOm$jz$zq(v!%dL=Fftz;v%*YO*jTbN3<Rts`2$+n<T+ix<q^D(+* zJ!f|fYoQa|hI2nnuO=Jfu2RPhBHT}%mSCP=3J#SNCfzC}o|z_$xL`J{U*8DT3ek`v z=7#HamXrBnuSif@JC0b{O^Wu*aGo_xg4)DN5Lj+Z4)?~hf(F;fu%?wbP__>j%AbV~ z7YabV>@eQA=LMq`F0tEw<dN#M4pJ3<1hjT!F*AIM!N&PC3AsHHF0EXL`=^ZHw1ssO zfgl4Ads>SpQ&!_Q_tW6e8bd0-e!+sb0YfVJYmlPy_bTCO$=2P~Rd<)i-aiw)wmWy& z?wVn(FM;Cci$Sloh)p}aiVU_Vz__aps6X`xzBX7)ld9Ih4C`#v^m~L`UaW>&HskSW zeGTo;>BJNDIw;@YUqN(S!D{zHY@hrPM!|HP@In{WYOaHD@OpA&l_c+ASD6Or<w0<6 z1=x&Rg|c-zta`>3#`wS?6qet{Zgb8z-ugKj+QUwQTEqo<Tz(^!mez*%_vCQmMR`~y zR?PZp??+DWUYu)K24AMB(24!^kbPJfpDM(Y)2ib*9_xcJq2n_S=4ZgQtTpIxCabdY zmmHp58ceJ#J#dccTsm|4Te4u3FJ7KI3$>dsQkT{e{IKB(J!#R)<VWA7%{?LXp;&t5 zy_P9>;Qb6x<JswgYn`;+d4SB?>`66MD~OjfFY!A^2zXH^D!acU2k)JN2_Lr7kFSov z&YT&j!DAIE5u=Fl!<Tscm?wqO3{dp=g5fvQ@$9$&e3IwO8bbwbmoI{8EBAs^+)tvU zB1fFo-6AiuZE%wNM)0zVzy;6NLBN9((Du=a%ygSgclF&R%9G}BPgI>GlQrJ4F19P! zRWD|Phe{H=^m8;mzcY#+w+><UZ|z|8?x_=r5rLp99EuGeW)Ww_bHEjC!DW4O$mVBq z$mRJYue9x0d|HQo)>@3bvN`<xBn%?HOoK`{X=r%rPhVG$hF#prsQNL0PFT1fvhLf0 z`{sjKBfJq!Qd^jJJ<Y)A_2C8I2r@&h5CiPha8p7#eX+EWHk}v4*;*fs{Y1p^o?Zn8 zpIZdRl^if@vVz+NH?USBiS&-t0oPytxXnV|L|>_kO}x_qt?Rx*n)oC1NtwtoS?Yv4 zBo9J_`+kUca|DyZ2T7(;GaK_x5^VH$kxbPjkaa%E=(bO#k1WFAVcZy)xZW77)wFTi zoj_zPWl*C_88<Fi3}c(T;Xsoz-Kbd#g4zWT&3F@^?W?ip+$~s;-vQ5_q+!kFc)WA& zJW41&huftesgQU>WwhZvT6Dt=*WKAepRZBk*iF9#YSz4ivubt7YgvoKwzT59qki!5 za2$+y!^1nKo0?W@)AVkDsqQtTDdZMKrVbtFyk~wYhCu9+B&-aa1dlL_l!|lUmZ2?p z*XUu;z&)&t6T_#c?BND?1U|^sLm{>ISY`PV2JF^A``CrNvPCb@N|%N8dIxCr#VNG( zeG^Sg)#PwiN&&~KlP0d~B2%{cVqPp49vExkq(uVo#qBt36}UkTc~>z;KW3QJ=x%5I z&&uLOwYPMxJqvryTiE~q;Njo!A3ON|V{NB5>!Bw|N?zvCrMK*H%U}g&39G@kFmtr2 zo6p%dEEM1OoI#I7AqeGOgZM=O%)uk0Vbiw`bac}}W>+vikNS-3Vm~p5L!`M^m@GUV zD2^I+Yf<X!Z6f5ELVnk6gSjajm?9Db58Qv03vWp#QJ;liuapdt+;oS&)6b@U7OPOJ zg;xR`%4Mr=XyMD_TDZRJByNa50Lz5Zn0+s8aBy%r#PY&@1F>S<S?UKfj*6q<#(KIn zSCiSqi2+d0Fv&fXgvX2}Aa>;dJ(Af2>DOY)8%9WTLkCJo{JR0TTFN7FHXpzz$PXL* zcEZ_*b8%v>1DVi$7NvenMqgzeydvrdT|bt>s)}%u(!{%S6PZd63j&T#-wfC8CBV5k zK_G285B=9p=hjVqjW7By!xFbzdU!gAcSzp=*Sgc7(7eW^Yi<G!Stv?3o!UV*+`9|s zmpvgPo*Tls?0D!YjmI=H4@&D3XtAg&w`-OdH-D8M!Y@%k>G#Z;+a@I7$~OpdJBuGG zURHi<NhS`5Oz^PeJCn(G?qZ!+F<2bY#eG6ToV&uqS<@5#pt#nWJ>2x8^0S{DLVqt) ze*Y}dzQGF~Bn0T5>56zsN0xK_dMP{g_6?Z3%A8tP1(5A=_Aql!2_$tN0SoQjD876< z%IYSwo&xRkVsSLOsK2H*y`SlvUsF-5AdMbqOaSTOOUcBU%ABGkUE<>(N7Y<{j5}8Q zljNr@r19M_&PT`f==NbF#8$2)72|I4ZqqD;!}ECi=kASYV>`NX=aXiPQAxu`>c>F+ zUKcg(495%oPNbo7I0j4i(ESS-$S;b+VK=v<dbJW^rtSyvL#N@;JZ<dm-wolcIT}^H zhn*9nF;OBLyrnf!?5P0f?v}CeaHJpZx;Pw<{8$8%>@l!>5XGMCFkoCJneq;-MM!k? zIh2{W63wsis{7-T!BZ^@i{7@;?Fm+7n$R8CyTYDul`g^qk;zcF{65Tk?8ff=rod?m z9H37O7NBtJN%Wg8$5~aq3SP+M(C;r&A#vmkkjef5_7iWS`ST!pvhE{2n>L8GSAxKq z7u4)u&w#JUSkA`gV!A1<k<Jq7Wk}^Am@q~Li;k({_f<h4{Je@hY!>3CKTW}-np_CI zRZGk!&c_$Ln=ffC>rF&9m_Y2aJajX+!e`u2BD?P{yUAfR_-r|aF;7c~a9JprbH0%k z9#SNA^L6&#m38>rBbJpQX)yOjFFW(LJ=Td1(xRMY#E3zf>peimJ{>-kiqqn84t`VN z^}nHm|HM6pPP6!qL!~Pj26xB8mWoKoIIaf!coD?sz5=GWLK2-7MsSUeabQ&IYDP2t zEHkamnO+Dq;$3M-Ad~wLYvf~SvHm48YUvO5yihIVF45*##w`Y)umv37QzN0Onu{6= zi_rOm6PU-OfWY)t<D&ax@$~m7x+JIw)p^(F_>ND8*ZGaljQQP(U+2NU8S*E4UEXU0 z?VI(<ybu*|in&V87aYRth&X+*0Xm-^1&^|3K&3(fuGKn--_#AUu)BojTjk@O;U#$2 zel<3oD`ZNQPT)4}PB!aE5L`KP9s?ZLP~)loU^eA39o9S=T|46-@Ma2C*b)xbY0rpd zZ~^_D)JOVcGRTr#KiINmBwno6WLo+UvXb*>!)QZ)Tt3YY&Ry*y0?`+6<Asy3^@kKk z!NdplWR`$lV?Lvo`w%CO{D?{?{m94Av6!psgzJ2L*xw0C@U@qhb9?ZHy3MZw$@E~% zxv&}1bO-3<G6^v941<j83=~ja3SfII^1~5nUgx*}4<NoGw}(Z*xFyz@^*Vq&nHxX` z?FL9fU@qwrUXPwNxkDKLd42kyG3Moh{|FGsYFlu&Af!e)7G4|9VO^Sw&}^+2%t>SE zm<a`>9o6x`{sstl+=7YMa;c0|Gm~~V6rUMp5xF(@;F;17s<rYq;eMUYHQGCo9kb{R zl@4Bo1G_?Cr+6mKs??`06RW79bsv6ei$cQ@`@mIPluL9L;DifzslfCXFz4C=C_gKO z%fFuo_ZSAJYx}~*PkHFSB^bEfbyW1cKfbZvOr-=ONY09%^lhCa$b%e~-&ck%r`Kd} zh&TIu@D=sfnLu2hkA#3P!klM^_wY!EH%!l7$ay9=3}#LmAnS}~!(DY>lwH(HSHuaT z*lqQp1tj$kg#Q9DA0Ix#^}I_uZXuGSR4)|V`)xq)q&mGDFck|{Mnk2N*%19#_{+B? zdGG&W#Y0b9SBJg2f994RB-LW+x-+!(kQUi$w+MC)?f_MrXj0Q42b$LYOsMZPbiW=* zcf=-}Jd}`P_geR0x8EK5c8()cE*(Yf@DLSj`bHK`2*E=8`?zxcB`R=gKCpsMAi~fb zWV$Obug8VPJNOg#CVlc;dIVg#=7g^W0-1=}jbxhb8VH-e9JgO7q#Go3xM~U$QHmpi z=<A2Z30h>t6bJTe!YJ;<d>eB9ybBuU4v=Em3^cpnXL5_Ta;MEmCX4NUGp{oIXiUa+ zoFd2KDQ`Knz93pTsbm7SNveRv*9a;r;f3j2BWUfS?PQ77c~b7$4H?W-{BbaW?U>d? ze0j&s{+a?Pz5FIh-4sH(-NEP~_?%ZXY=PrNkD<5+506*l;Ow^TSZESR-!yfi%J`pj z<D$Dnb^02Z7h-_A>Bm4%=qfI}l!}hI`mn715h$k*5^18&wjCbHy}4VE`^kS8`MIW) zS9BeYpY@M1zi)Lg%ndi#|4M?0^3I~SeG@XNn*EX8YJ42VA3QL$*yR7v<=^X(ZxQb% z&gO7cc?GJ2SINa)A|Pbe4cA{S0dMzO?8qpB-@*A*dG9kOk*1^0WfAb=YT*9^s#p(I literal 0 HcmV?d00001 diff --git a/models/generate-parakeet-test-model.py b/models/generate-parakeet-test-model.py new file mode 100755 index 00000000000..192a96ce627 --- /dev/null +++ b/models/generate-parakeet-test-model.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +import struct +import sys +import numpy as np +from pathlib import Path + +def write_tensor(fout, name, data): + n_dims = len(data.shape) + data = data.astype(np.float32) + ftype = 0 # GGML_TYPE_F32 + + name_bytes = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(name_bytes) + data.tofile(fout) + +def generate(output_path): + rng = np.random.default_rng(42) + + hparams = { + 'n_vocab': 10, + 'n_audio_ctx': 3200, + 'n_audio_state': 8, + 'n_audio_head': 2, + 'n_audio_layer': 1, + 'n_mels': 16, + 'ftype': 0, + 'n_fft': 64, + 'subsampling_factor': 8, + 'n_subsampling_channels': 4, + 'n_conv_kernel': 3, + 'n_pred_dim': 8, + 'n_pred_layers': 1, + 'n_tdt_durations': 2, + 'n_max_tokens': 5, + } + + n_vocab = hparams['n_vocab'] + n_state = hparams['n_audio_state'] + n_head = hparams['n_audio_head'] + n_layer = hparams['n_audio_layer'] + n_mels = hparams['n_mels'] + n_fft = hparams['n_fft'] + n_sub_fac = hparams['subsampling_factor'] + n_sub_ch = hparams['n_subsampling_channels'] + n_conv_ker = hparams['n_conv_kernel'] + dec_dim = hparams['n_pred_dim'] + n_pred_l = hparams['n_pred_layers'] + n_tdt = hparams['n_tdt_durations'] + + n_pre_enc = (n_mels // n_sub_fac) * n_sub_ch + n_head_dim = n_state // n_head + n_pred_embed = n_vocab + 1 + n_lstm_gates = 4 * dec_dim + n_joint_out = n_vocab + n_tdt + 1 + n_freqs = n_fft // 2 + 1 + + def f32(*shape): + return rng.standard_normal(shape).astype(np.float32) + + with open(output_path, 'wb') as fout: + fout.write(struct.pack("I", 0x67676d6c)) + + for key in ['n_vocab', + 'n_audio_ctx', + 'n_audio_state', + 'n_audio_head', + 'n_audio_layer', + 'n_mels', + 'ftype', + 'n_fft', + 'subsampling_factor', + 'n_subsampling_channels', + 'n_conv_kernel', + 'n_pred_dim', + 'n_pred_layers', + 'n_tdt_durations', + 'n_max_tokens']: + fout.write(struct.pack("i", hparams[key])) + + fout.write(struct.pack("i", n_mels)) + fout.write(struct.pack("i", n_freqs)) + f32(n_mels, n_freqs).tofile(fout) + + fout.write(struct.pack("i", n_fft)) + f32(n_fft).tofile(fout) + + for d in range(n_tdt): + fout.write(struct.pack("I", d)) + + tokens = ['<unk>', '<s>', '</s>'] + [chr(ord('a') + i) for i in range(n_vocab - 3)] + assert len(tokens) == n_vocab + fout.write(struct.pack("i", n_vocab)) + for tok in tokens: + tok_bytes = tok.encode('utf-8') + fout.write(struct.pack("i", len(tok_bytes))) + fout.write(tok_bytes) + + write_tensor(fout, "encoder.pre_encode.out.weight", f32(n_state, n_pre_enc)) + write_tensor(fout, "encoder.pre_encode.out.bias", f32(n_state)) + + write_tensor(fout, "encoder.pre_encode.conv.0.weight", f32(n_sub_ch, 1, 3, 3)) + write_tensor(fout, "encoder.pre_encode.conv.0.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.2.weight", f32(n_sub_ch, 1, 3, 3)) + write_tensor(fout, "encoder.pre_encode.conv.2.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.3.weight", f32(n_sub_ch, n_sub_ch, 1, 1)) + write_tensor(fout, "encoder.pre_encode.conv.3.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.5.weight", f32(n_sub_ch, 1, 3, 3)) + write_tensor(fout, "encoder.pre_encode.conv.5.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.6.weight", f32(n_sub_ch, n_sub_ch, 1, 1)) + write_tensor(fout, "encoder.pre_encode.conv.6.bias", f32(1, n_sub_ch, 1, 1)) + + for i in range(n_layer): + p = f"encoder.layers.{i}" + + write_tensor(fout, f"{p}.norm_feed_forward1.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_feed_forward1.bias", f32(n_state)) + write_tensor(fout, f"{p}.feed_forward1.linear1.weight", f32(4*n_state, n_state)) + write_tensor(fout, f"{p}.feed_forward1.linear2.weight", f32(n_state, 4*n_state)) + + write_tensor(fout, f"{p}.norm_conv.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_conv.bias", f32(n_state)) + write_tensor(fout, f"{p}.conv.pointwise_conv1.weight", f32(2*n_state, n_state)) + write_tensor(fout, f"{p}.conv.depthwise_conv.weight", f32(n_state, n_conv_ker)) + write_tensor(fout, f"{p}.conv.batch_norm.weight", f32(n_state)) + write_tensor(fout, f"{p}.conv.batch_norm.bias", f32(n_state)) + write_tensor(fout, f"{p}.conv.batch_norm.running_mean", f32(n_state)) + write_tensor(fout, f"{p}.conv.batch_norm.running_var", np.abs(f32(n_state))) + num_batches = np.zeros(1, dtype=np.int32) + write_tensor(fout, f"{p}.conv.batch_norm.num_batches_tracked", num_batches) + write_tensor(fout, f"{p}.conv.pointwise_conv2.weight", f32(n_state, n_state)) + + write_tensor(fout, f"{p}.norm_self_att.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_self_att.bias", f32(n_state)) + + write_tensor(fout, f"{p}.self_attn.pos_bias_u", f32(n_head, n_head_dim)) + write_tensor(fout, f"{p}.self_attn.pos_bias_v", f32(n_head, n_head_dim)) + write_tensor(fout, f"{p}.self_attn.linear_q.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_k.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_v.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_out.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_pos.weight", f32(n_state, n_state)) + + write_tensor(fout, f"{p}.norm_feed_forward2.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_feed_forward2.bias", f32(n_state)) + write_tensor(fout, f"{p}.feed_forward2.linear1.weight", f32(4*n_state, n_state)) + write_tensor(fout, f"{p}.feed_forward2.linear2.weight", f32(n_state, 4*n_state)) + + write_tensor(fout, f"{p}.norm_out.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_out.bias", f32(n_state)) + + write_tensor(fout, "decoder.prediction.embed.weight", f32(n_pred_embed, dec_dim)) + + def reorder_gates(data): + h = data.shape[0] // 4 + return np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0) + + for i in range(n_pred_l): + base = f"decoder.prediction.dec_rnn.lstm" + write_tensor(fout, f"{base}.weight_ih_l{i}", reorder_gates(f32(n_lstm_gates, dec_dim))) + write_tensor(fout, f"{base}.weight_hh_l{i}", reorder_gates(f32(n_lstm_gates, dec_dim))) + write_tensor(fout, f"{base}.bias_h_l{i}", reorder_gates(f32(n_lstm_gates) + f32(n_lstm_gates))) + + write_tensor(fout, "joint.pred.weight", f32(dec_dim, dec_dim)) + write_tensor(fout, "joint.pred.bias", f32(dec_dim)) + write_tensor(fout, "joint.enc.weight", f32(dec_dim, n_state)) + write_tensor(fout, "joint.enc.bias", f32(dec_dim)) + write_tensor(fout, "joint.joint_net.2.weight", f32(n_joint_out, dec_dim)) + write_tensor(fout, "joint.joint_net.2.bias", f32(n_joint_out)) + + size = Path(output_path).stat().st_size + print(f"Generated {output_path} ({size / 1024:.1f} KB)") + +if __name__ == '__main__': + output = sys.argv[1] if len(sys.argv) > 1 else 'models/for-tests-ggml-parakeet-tdt.bin' + generate(output) diff --git a/models/requirements-parakeet.txt b/models/requirements-parakeet.txt new file mode 100644 index 00000000000..5239ae0af5d --- /dev/null +++ b/models/requirements-parakeet.txt @@ -0,0 +1,3 @@ +torch +numpy +pyyaml diff --git a/scripts/quantize-parakeet.sh b/scripts/quantize-parakeet.sh new file mode 100755 index 00000000000..7816696bfcb --- /dev/null +++ b/scripts/quantize-parakeet.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +set -e + +build_dir=build +modelname=ggml-parakeet-tdt-0.6b-v3 +model=models/${modelname}-f32.bin +cmd=parakeet-quantize + +cmake --build ${build_dir} --target $cmd -j 12 + +${build_dir}/bin/${cmd} $model models/${modelname}-q8_0.bin q8_0 +${build_dir}/bin/${cmd} $model models/${modelname}-q4_0.bin q4_0 +${build_dir}/bin/${cmd} $model models/${modelname}-q4_k.bin q4_k +${build_dir}/bin/${cmd} $model models/${modelname}-q2_k.bin q2_k diff --git a/scripts/upload-parakeet.py b/scripts/upload-parakeet.py new file mode 100644 index 00000000000..3644bec8bd3 --- /dev/null +++ b/scripts/upload-parakeet.py @@ -0,0 +1,157 @@ +import argparse +import os +from huggingface_hub import HfApi, create_repo + +USER_NAME = "ggml-org" +REPO_ID = f"{USER_NAME}/parakeet-GGUF" + +MODELS = { + "f32": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-f32.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-f32.bin", + "description": "Full precision (F32)", + }, + "f16": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-f16.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-f16.bin", + "description": "Half precision (F16)", + }, + "q8_0": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-q8_0.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-q8_0.bin", + "description": "8-bit quantized (Q8_0)", + }, + "q4_0": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-q4_0.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-q4_0.bin", + "description": "4-bit quantized (Q4_0)", + }, + "q4_k": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-q4_k.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-q4_k.bin", + "description": "4-bit K-quantized (Q4_k)", + }, +} + +def build_model_card(uploaded_variants): + lines = [ + f"---", + f"license: mit", + f"base_model: nvidia/parakeet-tdt-0.6b-v3", + f"tags:", + f"- gguf", + f"- asr", + f"---", + f"", + f"# Parakeet TDT 0.6B v3 (GGUF)", + f"", + f"GGUF conversions of [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) for use with [whisper.cpp](https://github.com/ggml-org/whisper.cpp).", + f"", + f"## Available files", + f"", + ] + + for key, m in MODELS.items(): + if key in uploaded_variants: + lines.append(f"- `{m['remote_name']}` — {m['description']}") + + lines += [ + f"", + f"## Usage", + f"", + f"Build parakeet-cli:", + f"```console", + f"git clone https://github.com/ggml-org/whisper.cpp.git", + f"cd whisper.cpp", + f"cmake -B build -S .", + f"cmake --build build --target parakeet-cli -j $(nproc)", + f"```", + f"", + f"Download a model (e.g. Q8_0):", + f"```console", + f"hf download {REPO_ID} {MODELS['q8_0']['remote_name']} --local-dir models", + f"```", + f"", + f"Run:", + f"```console", + f"./build/bin/parakeet-cli -m models/{MODELS['q8_0']['remote_name']} -f samples/jfk.wav", + f"```", + f"", + ] + + return "\n".join(lines) + + +def upload_variant(api, key): + m = MODELS[key] + local_path = m["local_path"] + + if not os.path.exists(local_path): + print(f" Skipping {key}: {local_path} not found") + return False + + print(f" Uploading {m['remote_name']} ({m['description']})...") + api.upload_file( + path_or_fileobj=local_path, + path_in_repo=m["remote_name"], + repo_id=REPO_ID, + repo_type="model", + commit_message=f"Upload {m['remote_name']}", + ) + return True + + +def main(): + parser = argparse.ArgumentParser(description="Upload parakeet GGUF models to Hugging Face") + parser.add_argument( + "variants", + nargs="*", + default=None, + metavar="{" + ",".join(MODELS.keys()) + "}", + help="Model variants to upload (default: all)", + ) + parser.add_argument( + "--no-model-card", + action="store_true", + help="Skip updating the model card README", + ) + args = parser.parse_args() + + api = HfApi() + create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True) + + variants = args.variants if args.variants else list(MODELS.keys()) + + unknown = [v for v in variants if v not in MODELS] + if unknown: + parser.error(f"unknown variant(s): {', '.join(unknown)} (choose from {', '.join(MODELS.keys())})") + + uploaded = [] + for key in variants: + if upload_variant(api, key): + uploaded.append(key) + + if not uploaded: + print("No models were uploaded.") + return + + if not args.no_model_card: + print("Updating model card...") + existing = [k for k in MODELS if k in uploaded or + any(f.rfilename == MODELS[k]["remote_name"] + for f in api.list_repo_files(REPO_ID, repo_type="model") + if hasattr(f, "rfilename"))] + card = build_model_card(existing if existing else uploaded) + api.upload_file( + path_or_fileobj=card.encode(), + path_in_repo="README.md", + repo_id=REPO_ID, + repo_type="model", + commit_message="Update README.md", + ) + + print(f"\nDone. Repository: https://huggingface.co/{REPO_ID}") + + +if __name__ == "__main__": + main() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 095a2791de5..4e7c5b24dc3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -109,23 +109,43 @@ add_library(whisper whisper.cpp ) +add_library(parakeet + ../include/parakeet.h + parakeet-arch.h + parakeet.cpp + ) + +target_include_directories(parakeet PUBLIC . ../include) +target_compile_features (parakeet PUBLIC cxx_std_11) +target_link_libraries(parakeet PUBLIC ggml Threads::Threads) + # Set the version numbers set_target_properties(whisper PROPERTIES VERSION ${PROJECT_VERSION} SOVERSION ${SOVERSION} ) +set_target_properties(parakeet PROPERTIES + VERSION ${PROJECT_VERSION} + SOVERSION ${SOVERSION} +) + target_include_directories(whisper PUBLIC . ../include) target_compile_features (whisper PUBLIC cxx_std_11) # don't bump if (CMAKE_CXX_BYTE_ORDER STREQUAL "BIG_ENDIAN") set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_BIG_ENDIAN) + set(PARAKEET_EXTRA_FLAGS ${PARAKEET_EXTRA_FLAGS} -DPARAKEET_BIG_ENDIAN) endif() if (WHISPER_EXTRA_FLAGS) target_compile_options(whisper PRIVATE ${WHISPER_EXTRA_FLAGS}) endif() +if (PARAKEET_EXTRA_FLAGS) + target_compile_options(parakeet PRIVATE ${PARAKEET_EXTRA_FLAGS}) +endif() + find_package(Threads REQUIRED) target_link_libraries(whisper PUBLIC ggml Threads::Threads) @@ -144,4 +164,7 @@ endif() if (BUILD_SHARED_LIBS) set_target_properties(whisper PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(whisper PRIVATE WHISPER_SHARED WHISPER_BUILD) + + set_target_properties(parakeet PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(parakeet PRIVATE PARAKEET_SHARED PARAKEET_BUILD) endif() diff --git a/src/parakeet-arch.h b/src/parakeet-arch.h new file mode 100644 index 00000000000..3407a95c9c7 --- /dev/null +++ b/src/parakeet-arch.h @@ -0,0 +1,188 @@ +#pragma once + +#include "ggml.h" + +#include <map> + +enum parakeet_tensor { + // Encoder pre_encode + PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, + + // Encoder layers (per-layer) + PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, + PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, + PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, + PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_BN_BIAS, + PARAKEET_TENSOR_ENC_CONV_BN_MEAN, + PARAKEET_TENSOR_ENC_CONV_BN_VAR, + PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, + PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, + PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, + PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, + PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, + PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, + PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, + + // Prediction network + PARAKEET_TENSOR_PRED_EMBED_WEIGHT, + PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, + PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, + PARAKEET_TENSOR_PRED_LSTM_BIAS_H, + + // Joint network + PARAKEET_TENSOR_JOINT_PRED_WEIGHT, + PARAKEET_TENSOR_JOINT_PRED_BIAS, + PARAKEET_TENSOR_JOINT_ENC_WEIGHT, + PARAKEET_TENSOR_JOINT_ENC_BIAS, + PARAKEET_TENSOR_JOINT_NET_WEIGHT, + PARAKEET_TENSOR_JOINT_NET_BIAS, +}; + +static const std::map<parakeet_tensor, const char *> PARAKEET_TENSOR_NAMES = { + // Encoder pre_encode + {PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, "encoder.pre_encode.out.weight"}, + {PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, "encoder.pre_encode.out.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, "encoder.pre_encode.conv.0.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, "encoder.pre_encode.conv.0.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, "encoder.pre_encode.conv.2.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, "encoder.pre_encode.conv.2.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, "encoder.pre_encode.conv.3.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, "encoder.pre_encode.conv.3.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, "encoder.pre_encode.conv.5.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, "encoder.pre_encode.conv.5.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, "encoder.pre_encode.conv.6.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, "encoder.pre_encode.conv.6.bias"}, + + // Encoder layers (use %d for layer number) + {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, "encoder.layers.%d.norm_feed_forward1.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, "encoder.layers.%d.norm_feed_forward1.bias"}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward1.linear1.weight"}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward1.linear2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, "encoder.layers.%d.norm_conv.weight"}, + {PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, "encoder.layers.%d.norm_conv.bias"}, + {PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, "encoder.layers.%d.conv.pointwise_conv1.weight"}, + {PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, "encoder.layers.%d.conv.depthwise_conv.weight"}, + {PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, "encoder.layers.%d.conv.batch_norm.weight"}, + {PARAKEET_TENSOR_ENC_CONV_BN_BIAS, "encoder.layers.%d.conv.batch_norm.bias"}, + {PARAKEET_TENSOR_ENC_CONV_BN_MEAN, "encoder.layers.%d.conv.batch_norm.running_mean"}, + {PARAKEET_TENSOR_ENC_CONV_BN_VAR, "encoder.layers.%d.conv.batch_norm.running_var"}, + {PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, "encoder.layers.%d.conv.batch_norm.num_batches_tracked"}, + {PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, "encoder.layers.%d.conv.pointwise_conv2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, "encoder.layers.%d.norm_self_att.weight"}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, "encoder.layers.%d.norm_self_att.bias"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, "encoder.layers.%d.self_attn.pos_bias_u"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, "encoder.layers.%d.self_attn.pos_bias_v"}, + {PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, "encoder.layers.%d.self_attn.linear_q.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, "encoder.layers.%d.self_attn.linear_k.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, "encoder.layers.%d.self_attn.linear_v.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, "encoder.layers.%d.self_attn.linear_out.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, "encoder.layers.%d.self_attn.linear_pos.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, "encoder.layers.%d.norm_feed_forward2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, "encoder.layers.%d.norm_feed_forward2.bias"}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward2.linear1.weight"}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward2.linear2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, "encoder.layers.%d.norm_out.weight"}, + {PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, "encoder.layers.%d.norm_out.bias"}, + + // Prediction network + {PARAKEET_TENSOR_PRED_EMBED_WEIGHT, "decoder.prediction.embed.weight"}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, "decoder.prediction.dec_rnn.lstm.weight_ih_l%d"}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, "decoder.prediction.dec_rnn.lstm.weight_hh_l%d"}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_H, "decoder.prediction.dec_rnn.lstm.bias_h_l%d"}, + + // Joint network + {PARAKEET_TENSOR_JOINT_PRED_WEIGHT, "joint.pred.weight"}, + {PARAKEET_TENSOR_JOINT_PRED_BIAS, "joint.pred.bias"}, + {PARAKEET_TENSOR_JOINT_ENC_WEIGHT, "joint.enc.weight"}, + {PARAKEET_TENSOR_JOINT_ENC_BIAS, "joint.enc.bias"}, + {PARAKEET_TENSOR_JOINT_NET_WEIGHT, "joint.joint_net.2.weight"}, + {PARAKEET_TENSOR_JOINT_NET_BIAS, "joint.joint_net.2.bias"}, +}; + +static const std::map<parakeet_tensor, ggml_op> PARAKEET_TENSOR_INFO = { + // Encoder pre_encode + {PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, GGML_OP_ADD}, + + // Encoder layers + {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_CONV_BN_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_CONV_BN_MEAN, GGML_OP_SUB}, + {PARAKEET_TENSOR_ENC_CONV_BN_VAR, GGML_OP_DIV}, + {PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, GGML_OP_NONE}, + {PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, GGML_OP_ADD}, + + // Prediction network + {PARAKEET_TENSOR_PRED_EMBED_WEIGHT, GGML_OP_GET_ROWS}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_H, GGML_OP_ADD}, + + // Joint network + {PARAKEET_TENSOR_JOINT_PRED_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_PRED_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_JOINT_ENC_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_ENC_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_JOINT_NET_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_NET_BIAS, GGML_OP_ADD}, +}; diff --git a/src/parakeet.cpp b/src/parakeet.cpp new file mode 100644 index 00000000000..b5da73e985c --- /dev/null +++ b/src/parakeet.cpp @@ -0,0 +1,3838 @@ +#include "parakeet.h" +#include "parakeet-arch.h" + +#include "ggml.h" +#include "ggml-cpp.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +#include <atomic> +#include <algorithm> +#include <cassert> +#include <cfloat> +#define _USE_MATH_DEFINES +#include <cmath> +#include <climits> +#include <cstdarg> +#include <cstdio> +#include <cstring> +#include <fstream> +#include <functional> +#include <cctype> +#include <map> +#include <random> +#include <set> +#include <string> +#include <thread> +#include <vector> + +#ifdef _MSC_VER +#include <codecvt> +#endif + +#if defined(PARAKEET_BIG_ENDIAN) +template<typename T> +static T byteswap(T value) { + T value_swapped; + char * source = reinterpret_cast<char *>(&value); + char * target = reinterpret_cast<char *>(&value_swapped); + int size = sizeof(T); + for (int i = 0; i < size; i++) { + target[size - 1 - i] = source[i]; + } + return value_swapped; +} + +template<typename T> +static void byteswap_tensor_data(ggml_tensor * tensor) { + T * datum = reinterpret_cast<T *>(tensor->data); + for (int i = 0; i < ggml_nelements(tensor); i++) { + datum[i] = byteswap(datum[i]); + } +} + +static void byteswap_tensor(ggml_tensor * tensor) { + switch (tensor->type) { + case GGML_TYPE_I16: { + byteswap_tensor_data<int16_t>(tensor); + break; + } + case GGML_TYPE_F16: { + byteswap_tensor_data<ggml_fp16_t>(tensor); + break; + } + case GGML_TYPE_I32: { + byteswap_tensor_data<int32_t>(tensor); + break; + } + case GGML_TYPE_F32: { + byteswap_tensor_data<float>(tensor); + break; + } + default: { // GML_TYPE_I8 + break; + } + } +} + +#define BYTESWAP_VALUE(d) d = byteswap(d) +#define BYTESWAP_FILTERS(f) \ + do { \ + for (auto & datum : f.data) { \ + datum = byteswap(datum); \ + } \ + } while (0) +#define BYTESWAP_TENSOR(t) \ + do { \ + byteswap_tensor(t); \ + } while (0) +#else +#define BYTESWAP_VALUE(d) do {} while (0) +#define BYTESWAP_FILTERS(f) do {} while (0) +#define BYTESWAP_TENSOR(t) do {} while (0) +#endif + +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define PARAKEET_ATTRIBUTE_FORMAT(...) +#endif + +// +// logging +// + +PARAKEET_ATTRIBUTE_FORMAT(2, 3) +static void parakeet_log_internal (ggml_log_level level, const char * format, ...); +static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data); + +#define PARAKEET_LOG_ERROR(...) parakeet_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define PARAKEET_LOG_WARN(...) parakeet_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define PARAKEET_LOG_INFO(...) parakeet_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) + +// define this to enable verbose trace logging - useful for debugging purposes +//#define PARAKEET_DEBUG + +#if defined(PARAKEET_DEBUG) +#define PARAKEET_LOG_DEBUG(...) parakeet_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#else +#define PARAKEET_LOG_DEBUG(...) +#endif + +#define PARAKEET_ASSERT(x) \ + do { \ + if (!(x)) { \ + PARAKEET_LOG_ERROR("PARAKEET_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +#define PARAKEET_MAX_NODES 8192 + +// Threshold for when local attention should be used. +// 8192 frames x 80ms = 655 s (about 10.9 mins) +static constexpr int PARAKEET_LOCAL_ATTN_THRESHOLD = 8192; +// Window of context in each director of the current token. +// 128 frames * 80ms = 10.24 s +static constexpr int PARAKEET_LOCAL_ATTN_WINDOW = 128; + +static std::string format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector<char> buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +// +// ggml helpers +// + +static bool ggml_graph_compute_helper( + struct ggml_cgraph * graph, + int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) }; + + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); + + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data); + } + + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + ggml_backend_set_n_threads_fn(backend.get(), n_threads); + } + + return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS; +} + +static bool ggml_graph_compute_helper( + ggml_backend_sched_t sched, + struct ggml_cgraph * graph, + int n_threads, + bool sched_reset = true) { + for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i); + ggml_backend_dev_t dev = ggml_backend_get_device(backend); + ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; + + auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (fn_set_n_threads) { + fn_set_n_threads(backend, n_threads); + } + } + + const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS); + + if (!t || sched_reset) { + ggml_backend_sched_reset(sched); + } + + return t; +} + +// TODO: move these functions to ggml-base with support for ggml-backend? + + +struct parakeet_mel { + int n_len = 0; + int n_len_org = 0; + int n_mel = 0; + + std::vector<float> data; +}; + +struct parakeet_filters { + int32_t n_mel = 0; + int32_t n_fb = 0; // number of frequency bins + + std::vector<float> data; +}; + +struct parakeet_vocab { + using id = int32_t; + using token = std::string; + + int n_vocab = 8192; + size_t max_token_length = 0; + + std::map<token, id> token_to_id; + std::map<id, token> id_to_token; + + id token_unk; + id token_bos; + id token_blank; + id token_eos; +}; + +struct parakeet_segment { + int64_t t0; + int64_t t1; + + std::string text; + + std::vector<parakeet_token_data> tokens; +}; + +struct parakeet_batch { + int32_t n_tokens; + + parakeet_token * token; + int32_t * i_time; // index of the audio frame + parakeet_pos * pos; + int32_t * n_seq_id; // always 1, here for consistency with llama.cpp + parakeet_seq_id ** seq_id; // null terminated + int8_t * logits; +}; + +// ggml_backend_sched wrapper for parakeet usage +struct parakeet_sched { + ggml_backend_sched_t sched = nullptr; + + std::vector<uint8_t> meta; +}; + +// TODO: Find out is there a multiple version types. It is not yet clear to me +// at this point. +enum parakeet_arch { + PARAKEET_ARCH_UNKNOWN = 0, + PARAKEET_ARCH_TDT = 1, // NVIDIA Parakeet TDT (RNN-T) +}; + +struct parakeet_hparams { + int32_t n_vocab = 8192; + int32_t n_audio_ctx = 0; // 0 = unlimited, will be set based on input + int32_t n_audio_state = 1024; + int32_t n_audio_head = 8; + int32_t n_audio_layer = 24; + int32_t n_mels = 128; + int32_t ftype = 1; + int32_t n_fft = 512; // FFT size for mel spectrogram + float eps = 1e-5f; + int32_t subsampling_factor = 8; + int32_t n_subsampling_channels = 256; + int32_t n_conv_kernel = 9; + int32_t n_pred_dim = 640; + int32_t n_pred_layers = 2; + int32_t n_tdt_durations = 5; + int32_t n_max_tokens = 10; + + parakeet_arch arch = PARAKEET_ARCH_TDT; +}; + +struct parakeet_layer_encoder { + struct ggml_tensor * norm_ff1_w = nullptr; + struct ggml_tensor * norm_ff1_b = nullptr; + + struct ggml_tensor * ff1_linear1_w = nullptr; + struct ggml_tensor * ff1_linear2_w = nullptr; + + struct ggml_tensor * norm_conv_w = nullptr; + struct ggml_tensor * norm_conv_b = nullptr; + + struct ggml_tensor * conv_pw1_w = nullptr; // pointwise_conv1 + struct ggml_tensor * conv_dw_w = nullptr; // depthwise_conv + struct ggml_tensor * conv_bn_w = nullptr; // batch_norm weight + struct ggml_tensor * conv_bn_b = nullptr; // batch_norm bias + struct ggml_tensor * conv_bn_mean = nullptr; // batch_norm running_mean + struct ggml_tensor * conv_bn_var = nullptr; // batch_norm running_var + struct ggml_tensor * conv_bn_num_batches = nullptr; // batch_norm num_batches_tracked + struct ggml_tensor * conv_pw2_w = nullptr; // pointwise_conv2 + + struct ggml_tensor * norm_attn_w = nullptr; + struct ggml_tensor * norm_attn_b = nullptr; + + struct ggml_tensor * attn_pos_bias_u = nullptr; + struct ggml_tensor * attn_pos_bias_v = nullptr; + struct ggml_tensor * attn_q_w = nullptr; + struct ggml_tensor * attn_k_w = nullptr; + struct ggml_tensor * attn_v_w = nullptr; + struct ggml_tensor * attn_out_w = nullptr; + struct ggml_tensor * attn_pos_w = nullptr; + + struct ggml_tensor * norm_ff2_w = nullptr; + struct ggml_tensor * norm_ff2_b = nullptr; + + struct ggml_tensor * ff2_linear1_w = nullptr; + struct ggml_tensor * ff2_linear2_w = nullptr; + + struct ggml_tensor * norm_out_w = nullptr; + struct ggml_tensor * norm_out_b = nullptr; +}; + +struct parakeet_lsmt_layer { + struct ggml_tensor * ih_w = nullptr; // input-to-hidden weight + struct ggml_tensor * hh_w = nullptr; // hidden-to-hidden weight + struct ggml_tensor * b_h = nullptr; // bias (ih folded into hh at conversion time) +}; + +struct parakeet_prediction_network { + struct ggml_tensor * embed_w = nullptr; + + std::vector<parakeet_lsmt_layer> lstm_layer; +}; + +struct parakeet_joint_network { + struct ggml_tensor * pred_w = nullptr; + struct ggml_tensor * pred_b = nullptr; + struct ggml_tensor * enc_w = nullptr; + struct ggml_tensor * enc_b = nullptr; + struct ggml_tensor * net_w = nullptr; + struct ggml_tensor * net_b = nullptr; +}; + +struct parakeet_model { + parakeet_filters filters; + parakeet_hparams hparams; + + struct ggml_tensor * enc_pre_out_w = nullptr; + struct ggml_tensor * enc_pre_out_b = nullptr; + struct ggml_tensor * enc_pre_conv_0_w = nullptr; + struct ggml_tensor * enc_pre_conv_0_b = nullptr; + struct ggml_tensor * enc_pre_conv_2_w = nullptr; + struct ggml_tensor * enc_pre_conv_2_b = nullptr; + struct ggml_tensor * enc_pre_conv_3_w = nullptr; + struct ggml_tensor * enc_pre_conv_3_b = nullptr; + struct ggml_tensor * enc_pre_conv_5_w = nullptr; + struct ggml_tensor * enc_pre_conv_5_b = nullptr; + struct ggml_tensor * enc_pre_conv_6_w = nullptr; + struct ggml_tensor * enc_pre_conv_6_b = nullptr; + + std::vector<parakeet_layer_encoder> layers; + + parakeet_prediction_network prediction; + + parakeet_joint_network joint; + + std::vector<uint32_t> tdt_durations; + + std::vector<ggml_context *> ctxs; + + std::vector<ggml_backend_buffer_t> buffers; + + int n_loaded = 0; + std::map<std::string, struct ggml_tensor *> tensors; +}; + +struct parakeet_lstm_state_layer { + struct ggml_tensor * h_state = nullptr; + struct ggml_tensor * c_state = nullptr; +}; + +struct parakeet_lstm_state { + std::vector<parakeet_lstm_state_layer> layer; + + std::vector<uint8_t> ctx_buf; + + ggml_backend_buffer_t buffer = nullptr; +}; + +struct parakeet_state { + int64_t t_sample_us = 0; + int64_t t_encode_us = 0; + int64_t t_decode_us = 0; + int64_t t_predict_us = 0; + int64_t t_predict_build_us = 0; // time spent building the prediction graph + int64_t t_predict_alloc_us = 0; // time spent in ggml_backend_sched_alloc_graph + int64_t t_predict_compute_us = 0; // time spent in ggml_graph_compute_helper + int64_t t_mel_us = 0; + + int32_t n_sample = 0; // number of tokens sampled + int32_t n_encode = 0; // number of encoder calls + int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) + int32_t n_predict = 0; // number of prediction network calls + int32_t n_fail_p = 0; // number of logprob threshold failures + int32_t n_fail_h = 0; // number of entropy threshold failures + + parakeet_mel mel; + + parakeet_batch batch; + + int n_frames = 0; + + std::vector<ggml_backend_t> backends; + + parakeet_sched sched_encode; + parakeet_sched sched_decode; + + // outputs from encoder stages + struct ggml_tensor * enc_out = nullptr; + struct ggml_tensor * pred_out = nullptr; + + std::vector<uint8_t> enc_out_buf; + ggml_backend_buffer_t enc_out_buffer = nullptr; + + std::vector<uint8_t> pred_out_buf; + ggml_backend_buffer_t pred_out_buffer = nullptr; + + struct ggml_tensor * attn_mask = nullptr; + + std::vector<float> inp_mel; + std::vector<float> inp_mask; + + std::vector<float> logits; + + std::vector<parakeet_segment> result_all; + + std::vector<parakeet_token> decoded_tokens; + std::vector<parakeet_token_data> decoded_token_data; + + std::string path_model; + + int32_t n_audio_ctx = 0; + int32_t sched_encode_n_audio_ctx = 0; + + parakeet_lstm_state lstm_state; +}; + +// FFT cache for mel spectrogram computation +struct parakeet_mel_cache { + int n_fft = 0; + + // In FFT, we frequently use sine and cosine operations with the same values. + // We can use precalculated values to speed up the process. + std::vector<float> sin_vals; + std::vector<float> cos_vals; + + // Hann window (Use cosf to eliminate difference) + // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 + std::vector<float> hann_window; + + // Window function from model (Parakeet uses actual window from training) + std::vector<float> window; + + void init(int fft_size) { + n_fft = fft_size; + sin_vals.resize(n_fft); + cos_vals.resize(n_fft); + hann_window.resize(n_fft); + + fill_sin_cos_table(); + fill_hann_window(n_fft, true, hann_window.data()); + } + + void fill_sin_cos_table() { + for (int i = 0; i < n_fft; i++) { + double theta = (2 * M_PI * i) / n_fft; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } + } + + void fill_hann_window(int length, bool periodic, float * output) { + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } + } +}; + +struct parakeet_context { + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + ggml_type wtype = ggml_type::GGML_TYPE_F16; + ggml_type itype = ggml_type::GGML_TYPE_F16; + + parakeet_context_params params; + + parakeet_model model; + parakeet_vocab vocab; + + parakeet_state * state = nullptr; + + parakeet_mel_cache mel_cache; + + std::string path_model; +}; + +struct parakeet_global { + // We save the log callback globally + ggml_log_callback log_callback = parakeet_log_callback_default; + void * log_callback_user_data = nullptr; +}; + +static parakeet_global g_state; + +static const std::string PARAKEET_SPM_SPACE = "\xE2\x96\x81"; + +static inline int utf8_codepoint_len(unsigned char c) { + if ((c & 0x80) == 0x00) return 1; + if ((c & 0xE0) == 0xC0) return 2; + if ((c & 0xF0) == 0xE0) return 3; + if ((c & 0xF8) == 0xF0) return 4; + return 1; +} + +static bool is_sentencepiece_control(const std::string & piece) { + return piece == "<unk>" || piece == "<s>" || piece == "</s>" || piece == "[BLANK]"; +} + +static std::string sentencepiece_normalize(const std::string & text) { + std::string normalized; + normalized.reserve(text.size() + PARAKEET_SPM_SPACE.size()); + normalized += PARAKEET_SPM_SPACE; // SentencePiece dummy prefix + + for (unsigned char c : text) { + if (std::isspace(c)) { + normalized += PARAKEET_SPM_SPACE; + } else { + normalized += static_cast<char>(c); + } + } + + return normalized; +} + +static std::string sentencepiece_piece_to_text(const std::string & piece, bool is_first_piece) { + if (is_sentencepiece_control(piece)) { + return ""; + } + + std::string text; + text.reserve(piece.size()); + + size_t pos = 0; + while (pos < piece.size()) { + if (piece.compare(pos, PARAKEET_SPM_SPACE.size(), PARAKEET_SPM_SPACE) == 0) { + if (!is_first_piece || !text.empty()) { + text += ' '; + } + pos += PARAKEET_SPM_SPACE.size(); + continue; + } + + text += piece[pos]; + ++pos; + } + + return text; +} + + +static struct parakeet_batch parakeet_batch_init(int32_t n_tokens) { + parakeet_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, }; + + batch.token = (parakeet_token * ) malloc(sizeof(parakeet_token) * (n_tokens)); + batch.i_time = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.pos = (parakeet_pos *) malloc(sizeof(parakeet_pos) * (n_tokens)); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.seq_id = (parakeet_seq_id **) malloc(sizeof(parakeet_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) { + batch.seq_id[i] = (parakeet_seq_id *) malloc(sizeof(parakeet_seq_id)); + } + batch.seq_id[n_tokens] = nullptr; + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +static void parakeet_batch_free(struct parakeet_batch batch) { + if (batch.token) free(batch.token); + if (batch.i_time) free(batch.i_time); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i]; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} + +static void parakeet_batch_prep_legacy(parakeet_batch & batch, const parakeet_token * tokens, int n_tokens, int n_past, int seq_id) { + batch.n_tokens = n_tokens; + for (int i = 0; i < n_tokens; ++i) { + if (tokens) { + batch.token[i] = tokens[i]; + } + batch.pos [i] = n_past + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i][0] = seq_id; + batch.logits [i] = 0; + } + batch.logits[n_tokens - 1] = 1; +} + + +static size_t parakeet_sched_size(struct parakeet_sched & allocr) { + size_t size = allocr.meta.size(); + for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i); + size += ggml_backend_sched_get_buffer_size(allocr.sched, backend); + } + return size; +} + +static bool parakeet_sched_graph_init(struct parakeet_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) { + auto & sched = allocr.sched; + auto & meta = allocr.meta; + + sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), PARAKEET_MAX_NODES, false, true); + + if (!sched) { + PARAKEET_LOG_ERROR("%s: failed to create scheduler\n", __func__); + return false; + } + + meta.resize(ggml_tensor_overhead()*PARAKEET_MAX_NODES + ggml_graph_overhead()); + + if (!ggml_backend_sched_alloc_graph(sched, get_graph())) { + PARAKEET_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__); + ggml_backend_sched_free(sched); + sched = nullptr; + return false; + } + + ggml_backend_sched_reset(sched); + + return true; +} + +static void parakeet_sched_free(struct parakeet_sched & sched) { + if (sched.sched) { + ggml_backend_sched_free(sched.sched); + sched.sched = nullptr; + } + + sched.meta.clear(); +} + + +template<typename T> +static void read_safe(parakeet_model_loader * loader, T & dest) { + loader->read(loader->context, &dest, sizeof(T)); + BYTESWAP_VALUE(dest); +} + +static bool parakeet_lstm_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_layer, + int n_pred_dim) { + parakeet_lstm_state & lstm_state = pstate.lstm_state; + + lstm_state.ctx_buf.resize(ggml_tensor_overhead() * n_layer * 2); + lstm_state.layer.resize(n_layer); + + struct ggml_init_params params = { + /*.mem_size =*/ lstm_state.ctx_buf.size(), + /*.mem_buffer =*/ lstm_state.ctx_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states context\n", __func__); + return false; + } + + + for (int il = 0; il < n_layer; ++il) { + lstm_state.layer[il].h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim); + lstm_state.layer[il].c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim); + } + + lstm_state.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!lstm_state.buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states\n", __func__); + return false; + } + + ggml_backend_buffer_clear(lstm_state.buffer, 0); + + ggml_free(ctx); + + return true; +} + +static bool parakeet_pred_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_pred_dim) { + pstate.pred_out_buf.resize(ggml_tensor_overhead()); + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.pred_out_buf.size(), + /*.mem_buffer =*/ pstate.pred_out_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor context\n", __func__); + return false; + } + + pstate.pred_out = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim); + pstate.pred_out_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!pstate.pred_out_buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor\n", __func__); + ggml_free(ctx); + return false; + } + + ggml_free(ctx); + + return true; +} + +static bool parakeet_enc_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_audio_state, + int n_frames_max) { + pstate.enc_out_buf.resize(ggml_tensor_overhead()); + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.enc_out_buf.size(), + /*.mem_buffer =*/ pstate.enc_out_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for enc_out tensor context\n", __func__); + return false; + } + + pstate.enc_out = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_frames_max); + pstate.enc_out_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!pstate.enc_out_buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for enc_out tensor\n", __func__); + ggml_free(ctx); + return false; + } + + ggml_free(ctx); + + return true; +} + +static ggml_backend_t parakeet_backend_init_gpu(const parakeet_context_params & params) { + ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); + + ggml_backend_dev_t dev = nullptr; + + int cnt = 0; + if (params.use_gpu) { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i); + enum ggml_backend_dev_type dev_type = ggml_backend_dev_type(dev_cur); + const char * dev_name = ggml_backend_dev_name(dev_cur); + PARAKEET_LOG_INFO("%s: device %zu: %s (type: %d)\n", __func__, i, dev_name, dev_type); + if (dev_type == GGML_BACKEND_DEVICE_TYPE_GPU || dev_type == GGML_BACKEND_DEVICE_TYPE_IGPU) { + PARAKEET_LOG_INFO("%s: found GPU device %zu: %s (type: %d, cnt: %d)\n", __func__, i, dev_name, dev_type, cnt); + if (cnt == params.gpu_device) { + dev = dev_cur; + } + + if (++cnt > params.gpu_device) { + break; + } + } + } + } + + if (dev == nullptr) { + PARAKEET_LOG_INFO("%s: no GPU found\n", __func__); + return nullptr; + } + + PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); + ggml_backend_t result = ggml_backend_dev_init(dev, nullptr); + if (!result) { + PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); + } + + return result; +} + +static std::vector<ggml_backend_t> parakeet_backend_init(const parakeet_context_params & params) { + std::vector<ggml_backend_t> result; + + ggml_backend_t backend_gpu = parakeet_backend_init_gpu(params); + + if (backend_gpu) { + result.push_back(backend_gpu); + } + + // ACCEL backends + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + if (!backend) { + PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); + continue; + } + result.push_back(backend); + } + } + + ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + if (backend_cpu == nullptr) { + throw std::runtime_error("failed to initialize CPU backend"); + } + result.push_back(backend_cpu); + + return result; +} + +using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>; + +static buft_list_t make_buft_list(parakeet_context_params & params) { + // Prio order: GPU -> CPU Extra -> CPU + buft_list_t buft_list; + + // GPU + if (params.use_gpu) { + int cnt = 0; + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU) { + if (cnt == params.gpu_device) { + auto * buft = ggml_backend_dev_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + } + } + + if (++cnt > params.gpu_device) { + break; + } + } + } + } + + // CPU Extra + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } + } + + // CPU + buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type()); + + return buft_list; +} + +static bool weight_buft_supported(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + bool op_supported = true; + + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || + ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU || + (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) { + // GPU and default CPU backend support all operators + op_supported = true; + } else { + switch (op) { + // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS + case GGML_OP_GET_ROWS: + case GGML_OP_MUL_MAT: { + ggml_init_params params = { + /*.mem_size =*/ 2 * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error("failed to create ggml context"); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + if (op == GGML_OP_MUL_MAT) { + int64_t n_ctx = hparams.n_audio_ctx; + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } else if (op == GGML_OP_GET_ROWS) { + int64_t num_indices = 8; + ggml_tensor * indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices); + op_tensor = ggml_get_rows(ctx, w, indices); + } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + break; + } + default: { + op_supported = false; + break; + } + }; + } + + return op_supported; +} + +static ggml_backend_buffer_type_t select_weight_buft(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) { + GGML_ASSERT(!buft_list.empty()); + for (const auto & p : buft_list) { + ggml_backend_dev_t dev = p.first; + ggml_backend_buffer_type_t buft = p.second; + if (weight_buft_supported(hparams, w, op, buft, dev)) { + return buft; + } + } + + return nullptr; +} + + +// load the model from a ggml file +// + +// see the convert-parakeet-to-ggml.py script for details +// +static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_context & wctx) { + PARAKEET_LOG_INFO("%s: loading model\n", __func__); + + const int64_t t_start_us = ggml_time_us(); + + wctx.t_start_us = t_start_us; + + auto & model = wctx.model; + auto & vocab = wctx.vocab; + + // verify magic + { + uint32_t magic; + read_safe(loader, magic); + if (magic != GGML_FILE_MAGIC) { + PARAKEET_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__); + return false; + } + } + + //load hparams + parakeet_hparams hparams; + { + read_safe(loader, hparams.n_vocab); + read_safe(loader, hparams.n_audio_ctx); + read_safe(loader, hparams.n_audio_state); + read_safe(loader, hparams.n_audio_head); + read_safe(loader, hparams.n_audio_layer); + read_safe(loader, hparams.n_mels); + read_safe(loader, hparams.ftype); + read_safe(loader, hparams.n_fft); + read_safe(loader, hparams.subsampling_factor); + read_safe(loader, hparams.n_subsampling_channels); + read_safe(loader, hparams.n_conv_kernel); + read_safe(loader, hparams.n_pred_dim); + read_safe(loader, hparams.n_pred_layers); + read_safe(loader, hparams.n_tdt_durations); + read_safe(loader, hparams.n_max_tokens); + + hparams.arch = PARAKEET_ARCH_TDT; + wctx.model.hparams = hparams; + + const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR; + + hparams.ftype %= GGML_QNT_VERSION_FACTOR; + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) hparams.ftype); + if (wctx.wtype == GGML_TYPE_COUNT) { + PARAKEET_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, hparams.ftype); + return false; + } + + const char* arch_name = hparams.arch == PARAKEET_ARCH_TDT ? "Parakeet TDT" : "unknown"; + PARAKEET_LOG_INFO("%s: arch = %s\n", __func__, arch_name); + PARAKEET_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + PARAKEET_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + PARAKEET_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + PARAKEET_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + PARAKEET_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + PARAKEET_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels); + PARAKEET_LOG_INFO("%s: n_fft = %d\n", __func__, hparams.n_fft); + PARAKEET_LOG_INFO("%s: eps = %f\n", __func__, hparams.eps); + PARAKEET_LOG_INFO("%s: ftype = %d\n", __func__, hparams.ftype); + PARAKEET_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr); + PARAKEET_LOG_INFO("%s: subsampling_factor = %d\n", __func__, hparams.subsampling_factor); + PARAKEET_LOG_INFO("%s: n_subsampling_channels = %d\n", __func__, hparams.n_subsampling_channels); + PARAKEET_LOG_INFO("%s: n_conv_kernel = %d\n", __func__, hparams.n_conv_kernel); + PARAKEET_LOG_INFO("%s: n_pred_dim = %d\n", __func__, hparams.n_pred_dim); + PARAKEET_LOG_INFO("%s: n_pred_layers = %d\n", __func__, hparams.n_pred_layers); + PARAKEET_LOG_INFO("%s: n_tdt_durations = %d\n", __func__, hparams.n_tdt_durations); + PARAKEET_LOG_INFO("%s: n_max_tokens = %d\n", __func__, hparams.n_max_tokens); + } + + // load mel filters + { + auto & filters = wctx.model.filters; + + read_safe(loader, filters.n_mel); + read_safe(loader, filters.n_fb); + + filters.data.resize(filters.n_mel * filters.n_fb); + loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float)); + BYTESWAP_FILTERS(filters); + } + + // load window function + { + int32_t n_window = 0; + read_safe(loader, n_window); + + wctx.mel_cache.window.resize(n_window); + loader->read(loader->context, wctx.mel_cache.window.data(), n_window * sizeof(float)); + +#ifdef GGML_BIG_ENDIAN + for (auto & datum : wctx.mel_cache.window) { + datum = byteswap(datum); + } +#endif + + PARAKEET_LOG_INFO("%s: loaded window function with %d samples\n", __func__, n_window); + } + + // load TDT (Token and Duration Transducer) values + { + auto & tdt_durations = wctx.model.tdt_durations; + tdt_durations.resize(hparams.n_tdt_durations); + loader->read(loader->context, tdt_durations.data(), hparams.n_tdt_durations * sizeof(uint32_t)); + + PARAKEET_LOG_INFO("%s: loaded tdt_durations: [", __func__); + for (const auto value : tdt_durations) { + PARAKEET_LOG_INFO("%u ", value); + } + PARAKEET_LOG_INFO("]\n"); + } + + // load vocab + { + int32_t n_vocab = 0; + read_safe(loader, n_vocab); + + std::string word; + std::vector<char> tmp; + + tmp.reserve(128); + + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + read_safe(loader, len); + + if (len > 0) { + tmp.resize(len); + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + word.assign(&tmp[0], tmp.size()); + } else { + PARAKEET_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + word = ""; + } + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + vocab.max_token_length = std::max(vocab.max_token_length, word.size()); + } + // Blank token for transducer is at index n_vocab (8192), outside the vocabulary + int blank_id = n_vocab; + vocab.token_blank = blank_id; + vocab.id_to_token[blank_id] = "[BLANK]"; + vocab.token_to_id["[BLANK]"] = blank_id; + + // Set special token IDs by looking them up in the loaded vocabulary + // These are from the SentencePiece vocab file loaded above + if (vocab.token_to_id.find("<unk>") != vocab.token_to_id.end()) { + vocab.token_unk = vocab.token_to_id.at("<unk>"); + } else { + vocab.token_unk = 0; // Fallback + } + + if (vocab.token_to_id.find("<s>") != vocab.token_to_id.end()) { + vocab.token_bos = vocab.token_to_id.at("<s>"); + } else if (vocab.token_to_id.find("<|startoftranscript|>") != vocab.token_to_id.end()) { + vocab.token_bos = vocab.token_to_id.at("<|startoftranscript|>"); + } else { + vocab.token_bos = 0; // Fallback + } + + if (vocab.token_to_id.find("</s>") != vocab.token_to_id.end()) { + vocab.token_eos = vocab.token_to_id.at("</s>"); + } else if (vocab.token_to_id.find("<|endoftext|>") != vocab.token_to_id.end()) { + vocab.token_eos = vocab.token_to_id.at("<|endoftext|>"); + } else { + vocab.token_eos = 0; // Fallback + } + + vocab.n_vocab = model.hparams.n_vocab; + + PARAKEET_LOG_INFO("%s: loaded vocab with %d tokens (blank_id=%d, unk=%d, bos=%d, eos=%d)\n", + __func__, n_vocab, blank_id, vocab.token_unk, vocab.token_bos, vocab.token_eos); + } + + const ggml_type wtype = wctx.wtype; + + + const int n_audio_layer = hparams.n_audio_layer; + + // Calculate tensor count: pre_encode (12) + encoder layers (29 per layer) + prediction (9) + joint (6) + size_t n_tensors = 12 + (29 * n_audio_layer) + 9 + 6; + + std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map; + auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error("failed to create ggml context"); + } + + ctx_map[buft] = ctx; + wctx.model.ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + // Create a list of available bufts, in priority order + buft_list_t buft_list = make_buft_list(wctx.params); + + auto create_tensor = [&](parakeet_tensor type, ggml_tensor * meta, int layer = -1) -> ggml_tensor * { + ggml_op op = PARAKEET_TENSOR_INFO.at(type); + ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for parakeet tensor %s", + PARAKEET_TENSOR_NAMES.at(type))); + } + + ggml_context * ctx = get_ctx(buft); + ggml_tensor * tensor = ggml_dup_tensor(ctx, meta); + + std::string tensor_name; + if (layer >= 0) { + tensor_name = format(PARAKEET_TENSOR_NAMES.at(type), layer); + } else { + tensor_name = PARAKEET_TENSOR_NAMES.at(type); + } + + wctx.model.tensors[tensor_name] = tensor; + + return tensor; + }; + + // prepare tensors for the weights + + ggml_init_params params = { + /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + + const int n_audio_state = hparams.n_audio_state; + + model.layers.resize(n_audio_layer); + + // Encoder pre_encode + const int n_subsampling_channels = hparams.n_subsampling_channels; + const int n_pre_enc_features = (hparams.n_mels / hparams.subsampling_factor) * n_subsampling_channels; + model.enc_pre_out_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_pre_enc_features, n_audio_state)); + ggml_set_name(model.enc_pre_out_w, "enc_pre_out_w"); + model.enc_pre_out_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state)); + ggml_set_name(model.enc_pre_out_b, "enc_pre_out_b"); + + model.enc_pre_conv_0_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_0_w, "enc_pre_conv_0_w"); + model.enc_pre_conv_0_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_0_b, "enc_pre_conv_0_b"); + + model.enc_pre_conv_2_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_2_w, "enc_pre_conv_2_w"); + model.enc_pre_conv_2_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_2_b, "enc_pre_conv_2_b"); + + model.enc_pre_conv_3_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_3_w, "enc_pre_conv_3_w"); + model.enc_pre_conv_3_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_3_b, "enc_pre_conv_3_b"); + + model.enc_pre_conv_5_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_5_w, "enc_pre_conv_5_w"); + model.enc_pre_conv_5_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_5_b, "enc_pre_conv_5_b"); + + model.enc_pre_conv_6_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_6_w, "enc_pre_conv_6_w"); + model.enc_pre_conv_6_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_6_b, "enc_pre_conv_6_b"); + + // Encoder layers + for (int i = 0; i < n_audio_layer; ++i) { + auto & layer = model.layers[i]; + + // Feed forward 1 + layer.norm_ff1_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_ff1_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.ff1_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i); + ggml_format_name(layer.ff1_linear1_w, "enc_%d_ff1_linear1_w", i); + layer.ff1_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i); + ggml_format_name(layer.ff1_linear2_w, "enc_%d_ff1_linear2_w", i); + + // Convolution module + layer.norm_conv_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.norm_conv_w, "enc_%d_norm_conv_w", i); + layer.norm_conv_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.norm_conv_b, "enc_%d_norm_conv_b", i); + layer.conv_pw1_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 2*n_audio_state), i); + ggml_format_name(layer.conv_pw1_w, "enc_%d_conv_pw1_w", i); + layer.conv_dw_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_conv_kernel, n_audio_state), i); + ggml_format_name(layer.conv_dw_w, "enc_%d_conv_dw_w", i); + layer.conv_bn_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_w, "enc_%d_conv_bn_w", i); + layer.conv_bn_b = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_b, "enc_%d_conv_bn_b", i); + layer.conv_bn_mean = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_MEAN, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.conv_bn_var = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_VAR, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_var, "enc_%d_conv_bn_var", i); + layer.conv_bn_num_batches = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), i); + layer.conv_pw2_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + ggml_format_name(layer.conv_pw2_w, "enc_%d_conv_pw2_w", i); + + // Self attention + layer.norm_attn_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_attn_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.attn_pos_bias_u = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_audio_state / hparams.n_audio_head, hparams.n_audio_head), i); + layer.attn_pos_bias_v = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_audio_state / hparams.n_audio_head, hparams.n_audio_head), i); + layer.attn_q_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_k_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_v_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_out_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_pos_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + ggml_format_name(layer.attn_pos_w, "enc_%d_attn_pos_w", i); + + // Feed forward 2 + layer.norm_ff2_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_ff2_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.ff2_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i); + layer.ff2_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i); + + // Output norm + layer.norm_out_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_out_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + } + + // Prediction network (decoder) + const int dec_hidden = hparams.n_pred_dim; + const int n_pred_embed = hparams.n_vocab + 1; // vocab + blank token + const int n_lstm_gates = 4 * dec_hidden; // 4 LSTM gates + const int n_joint_out = hparams.n_vocab + hparams.n_tdt_durations + 1; // vocab + durations + blank + + // The prediction/joint hidden dimension is 640, which is not a multiple of the + // K-quant block size (256). For K-quant models, we keep these tensors at F32. + const int blck = ggml_blck_size(wtype); + const ggml_type pred_wtype = (blck > 1 && dec_hidden % blck != 0) ? GGML_TYPE_F32 : wtype; + const ggml_type join_wtype = pred_wtype; + + model.prediction.embed_w = create_tensor(PARAKEET_TENSOR_PRED_EMBED_WEIGHT, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_pred_embed)); + model.prediction.lstm_layer.resize(hparams.n_pred_layers); + for (int i = 0; i < hparams.n_pred_layers; ++i) { + auto & layer = model.prediction.lstm_layer[i]; + layer.ih_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_lstm_gates), i); + ggml_format_name(layer.ih_w, "pred_%d_ih_w", i); + + layer.hh_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_lstm_gates), i); + ggml_format_name(layer.hh_w, "pred_%d_hh_w", i); + + layer.b_h = create_tensor(PARAKEET_TENSOR_PRED_LSTM_BIAS_H, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_lstm_gates), i); + ggml_format_name(layer.b_h, "pred_%d_b_h", i); + } + + // Joint network + model.joint.pred_w = create_tensor(PARAKEET_TENSOR_JOINT_PRED_WEIGHT, ggml_new_tensor_2d(ctx, join_wtype, dec_hidden, dec_hidden)); + ggml_set_name(model.joint.pred_w, "pred_w"); + model.joint.pred_b = create_tensor(PARAKEET_TENSOR_JOINT_PRED_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden)); + ggml_set_name(model.joint.pred_b, "pred_b"); + model.joint.enc_w = create_tensor(PARAKEET_TENSOR_JOINT_ENC_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, dec_hidden)); + ggml_set_name(model.joint.enc_w, "enc_w"); + model.joint.enc_b = create_tensor(PARAKEET_TENSOR_JOINT_ENC_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden)); + ggml_set_name(model.joint.enc_b, "enc_b"); + model.joint.net_w = create_tensor(PARAKEET_TENSOR_JOINT_NET_WEIGHT, ggml_new_tensor_2d(ctx, join_wtype, dec_hidden, n_joint_out)); + ggml_set_name(model.joint.net_w, "net_w"); + model.joint.net_b = create_tensor(PARAKEET_TENSOR_JOINT_NET_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_joint_out)); + ggml_set_name(model.joint.net_b, "net_b"); + + ggml_free(ctx); + + // allocate tensors in the backend buffers + for (auto & p : ctx_map) { + ggml_backend_buffer_type_t buft = p.first; + ggml_context * ctx = p.second; + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (buf) { + wctx.model.buffers.emplace_back(buf); + + size_t size_main = ggml_backend_buffer_get_size(buf); + PARAKEET_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6); + } + } + + // load weights + { + size_t total_size = 0; + + auto & tensors_map = wctx.model.tensors; + int & n_loaded = wctx.model.n_loaded; + + n_loaded = 0; + + std::vector<char> read_buf; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + read_safe(loader, n_dims); + read_safe(loader, length); + read_safe(loader, ttype); + + if (loader->eof(loader->context)) { + break; + } + + int32_t nelements = 1; + int32_t ne[4] = { 1, 1, 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + read_safe(loader, ne[i]); + nelements *= ne[i]; + } + + std::string name; + std::vector<char> tmp(length); // create a buffer + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + name.assign(&tmp[0], tmp.size()); + + if (tensors_map.find(name) == tensors_map.end()) { + PARAKEET_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = tensors_map[name.data()]; + + if (ggml_nelements(tensor) != nelements) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + PARAKEET_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", + __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2] || tensor->ne[3] != ne[3]) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d, %d], expected [%d, %d, %d, %d]\n", + __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], (int) tensor->ne[3], ne[0], ne[1], ne[2], ne[3]); + return false; + } + + const size_t bpe = ggml_type_size(ggml_type(ttype)); + + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + return false; + } + + if (ggml_backend_buffer_is_host(tensor->buffer)) { + // for the CPU and Metal backend, we can read directly into the tensor + loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); + BYTESWAP_TENSOR(tensor); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(ggml_nbytes(tensor)); + + loader->read(loader->context, read_buf.data(), read_buf.size()); + + ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); + } + + total_size += ggml_nbytes(tensor); + n_loaded++; + } + + PARAKEET_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6); + + if (n_loaded == 0) { + PARAKEET_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + } else if (n_loaded != (int) tensors_map.size()) { + PARAKEET_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, tensors_map.size(), n_loaded); + return false; + } + } + + auto & buffers = wctx.model.buffers; + for (auto & buf : buffers) { + ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + } + + wctx.t_load_us = ggml_time_us() - t_start_us; + + return true; +} + +// conv subsampling + conformer encoder +static struct ggml_cgraph * parakeet_build_graph_encode(parakeet_context & pctx, parakeet_state & pstate) { + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_mel_time = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : hparams.n_audio_ctx; + const int n_mels = hparams.n_mels; + const int n_layer = hparams.n_audio_layer; + const int n_state = hparams.n_audio_state; + const float fc_factor = 0.5f; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_encode.meta.size(), + /*.mem_buffer =*/ pstate.sched_encode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + // Conv subsampling + + // [freq, time] + struct ggml_tensor * mel = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_mels, n_mel_time, 1, 1); + ggml_set_name(mel, "mel"); + ggml_set_input(mel); + + // [freq, time, channels, batch] + struct ggml_tensor * cur = ggml_conv_2d(ctx0, model.enc_pre_conv_0_w, mel, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_0_b); + ggml_set_name(cur, "pre_conv_0"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_0_relu"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_2_w, cur, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_2_b); + ggml_set_name(cur, "pre_conv_2"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d(ctx0, model.enc_pre_conv_3_w, cur, 1, 1, 0, 0, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_3_b); + ggml_set_name(cur, "pre_conv_3"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_3_relu"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_5_w, cur, 2, 2, 1, 1, 1, 1); + ggml_set_name(cur, "pre_conv_5_direct"); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_5_b); + ggml_set_name(cur, "pre_conv_5"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d(ctx0, model.enc_pre_conv_6_w, cur, 1, 1, 0, 0, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_6_b); + ggml_set_name(cur, "pre_conv_6"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_6_relu"); + + // [freq, time, chan] + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + // [freq, chan, time] + cur = ggml_cont(ctx0, cur); + + const int n_freq = cur->ne[0]; // 16 + const int n_chan = cur->ne[1]; // 256 + const int n_frames = cur->ne[2]; // time + + // [freq, time, chan, batch] -> [(freq * chan), time] + cur = ggml_reshape_2d(ctx0, cur, n_freq * n_chan, n_frames); + + cur = ggml_mul_mat(ctx0, model.enc_pre_out_w, cur); + cur = ggml_add(ctx0, cur, model.enc_pre_out_b); + + ggml_set_name(cur, "pre_enc_out"); + + // Encoder + // cur: [n_state, n_enc_time] + + const int n_time = cur->ne[1]; + const bool local_attn = n_time > PARAKEET_LOCAL_ATTN_THRESHOLD; + const int att_left = local_attn ? PARAKEET_LOCAL_ATTN_WINDOW : n_time - 1; + const int att_right = local_attn ? PARAKEET_LOCAL_ATTN_WINDOW : n_time - 1; + const int window_size = local_attn ? att_left + att_right + 1 : 2 * n_time - 1; + const int d_half = n_state / 2; + const int mask_dim = local_attn ? window_size : n_time; + + // mask [key, n_time] + struct ggml_tensor * attn_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mask_dim, n_time); + ggml_set_name(attn_mask, "attn_mask"); + ggml_set_input(attn_mask); + + struct ggml_tensor * local_mask = nullptr; + if (local_attn) { + const int chunk = att_left + att_right; + local_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, chunk + window_size - 1, chunk); + ggml_set_name(local_mask, "local_mask"); + ggml_set_input(local_mask); + } + + struct ggml_tensor * pos_freqs = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_half); + ggml_set_name(pos_freqs, "pos_freqs"); + ggml_set_input(pos_freqs); + + struct ggml_tensor * rel_positions = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, window_size); + ggml_set_name(rel_positions, "rel_positions"); + ggml_set_input(rel_positions); + + struct ggml_tensor * freqs = ggml_repeat_4d(ctx0, pos_freqs, d_half, window_size, 1, 1); + struct ggml_tensor * theta = ggml_mul(ctx0, freqs, rel_positions); + + struct ggml_tensor * sin_t = ggml_reshape_3d(ctx0, ggml_sin(ctx0, theta), 1, d_half, window_size); + struct ggml_tensor * cos_t = ggml_reshape_3d(ctx0, ggml_cos(ctx0, theta), 1, d_half, window_size); + // [n_state, window_size] + struct ggml_tensor * pos_emb = ggml_reshape_2d(ctx0, ggml_cont(ctx0, ggml_concat(ctx0, sin_t, cos_t, 0)), n_state, window_size); + ggml_set_name(pos_emb, "pos_emb"); + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers[il]; + + // FFN1 + { + struct ggml_tensor * residual = cur; + ggml_format_name(cur, "enc_%d_res", il); + + // norm + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_ff1_w), layer.norm_ff1_b); + ggml_format_name(cur, "enc_%d_ffn_norm_1", il); + + // ffn_1 + cur = ggml_mul_mat(ctx0, layer.ff1_linear1_w, cur); + cur = ggml_silu(ctx0, cur); + ggml_format_name(cur, "enc_%d_silu", il); + + cur = ggml_mul_mat(ctx0, layer.ff1_linear2_w, cur); + ggml_format_name(cur, "enc_%d_ffn_1", il); + + cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor)); + ggml_format_name(cur, "enc_%d_res_ffn", il); + } + + // self attention block using relative positional encoding computed in graph. + { + // [feat, time_frames, 1, 1] + struct ggml_tensor * residual = cur; + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_attn_w), layer.norm_attn_b); + ggml_format_name(cur, "enc_%d_attn_norm", il); + + const int n_head = hparams.n_audio_head; + const int d_head = n_state / n_head; + + // [feat, time_frames, 1, 1] + struct ggml_tensor * Q_cur = ggml_mul_mat(ctx0, layer.attn_q_w, cur); + struct ggml_tensor * K_cur = ggml_mul_mat(ctx0, layer.attn_k_w, cur); + struct ggml_tensor * V_cur = ggml_mul_mat(ctx0, layer.attn_v_w, cur); + + Q_cur = ggml_reshape_3d(ctx0, Q_cur, d_head, n_head, n_time); + K_cur = ggml_reshape_3d(ctx0, K_cur, d_head, n_head, n_time); + V_cur = ggml_reshape_3d(ctx0, V_cur, d_head, n_head, n_time); + + struct ggml_tensor * pos = ggml_mul_mat(ctx0, layer.attn_pos_w, pos_emb); + pos = ggml_reshape_3d(ctx0, pos, d_head, n_head, window_size); + pos = ggml_cont(ctx0, ggml_permute(ctx0, pos, 0, 2, 1, 3)); + + if (local_attn) { + const int chunk = att_left + att_right; + const int n_group = (n_time + chunk - 1) / chunk; + const int n_time_padded = n_group * chunk; + const int n_kv_chunk = chunk + window_size - 1; + const int n_kv_dense = n_kv_chunk * n_group; + const bool need_padding = n_time_padded > n_time; + + Q_cur = ggml_cont(ctx0, ggml_permute(ctx0, Q_cur, 0, 2, 1, 3)); + K_cur = ggml_cont(ctx0, ggml_permute(ctx0, K_cur, 0, 2, 1, 3)); + V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 0, 2, 1, 3)); + + // content bias + struct ggml_tensor * bias_u = ggml_reshape_3d(ctx0, layer.attn_pos_bias_u, d_head, 1, n_head); + struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, bias_u); + + // position bias + struct ggml_tensor * bias_v = ggml_reshape_3d(ctx0, layer.attn_pos_bias_v, d_head, 1, n_head); + struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, bias_v); + + // right pad the time_frame. + struct ggml_tensor * Q_u_padded = need_padding ? + ggml_pad_ext(ctx0, Q_u, 0, 0, 0, n_time_padded - n_time, 0, 0, 0, 0) : Q_u; + Q_u_padded = ggml_reshape_4d(ctx0, Q_u_padded, d_head, chunk, n_group, n_head); + + // Add padding to front and back (for the first timeframe and the last timeframe). + struct ggml_tensor * K_padded = ggml_pad_ext(ctx0, K_cur, 0, 0, att_left, att_right, 0, 0, 0, 0); + + // pad time axis to match n_kv_dense if needed. + if (n_kv_dense > K_padded->ne[1]) { + K_padded = ggml_pad_ext(ctx0, K_padded, 0, 0, 0, n_kv_dense - K_padded->ne[1], 0, 0, 0, 0); + } + + // Create a 4d tensor where each group spans a wide window of + // 512 keys (n_kv_chunk), but moving to the next group (nb[2]) + // only jumps forward by 256 frames (chunk * nb[1]). This creates + // a 256 frame overlap, shared keys in RAM without copies. + struct ggml_tensor * K_chunk = ggml_view_4d(ctx0, K_padded, + d_head, n_kv_chunk, n_group, n_head, + K_padded->nb[1], + (size_t) chunk * K_padded->nb[1], + K_padded->nb[2], + 0); + K_chunk = ggml_cont(ctx0, K_chunk); + + struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_chunk, Q_u_padded); + + // The above mul_mat operation, combined with K_chunk's overlapping + // frames, produces a dense matrix. But some of the results in + // this matrix were computed for keys that aren't part of that + // query's window. So we shift each row to keep only the results + // that we want. + content_scores = ggml_view_4d(ctx0, content_scores, + window_size, chunk, n_group, n_head, + (size_t) (chunk + window_size) * content_scores->nb[0], + content_scores->nb[2], + content_scores->nb[3], + 0); + content_scores = ggml_cont(ctx0, content_scores); + + // ungrouping. + content_scores = ggml_reshape_3d(ctx0, content_scores, window_size, n_time_padded, n_head); + + // remove padding if padding was applied (truncating to n_time). + if (need_padding) { + content_scores = ggml_view_3d(ctx0, content_scores, + window_size, n_time, n_head, + content_scores->nb[1], + content_scores->nb[2], + 0); + } + + struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v); + + // attention_score = content similarity + relative position scores + struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores); + + attn_scores = ggml_soft_max_ext(ctx0, attn_scores, attn_mask, 1.0f / std::sqrt(d_head), 0.0f); + + // right pad the probabilites. + struct ggml_tensor * probs_padded = need_padding ? + ggml_pad_ext(ctx0, attn_scores, 0, 0, 0, n_time_padded - n_time, 0, 0, 0, 0) : attn_scores; + + probs_padded = ggml_reshape_4d(ctx0, probs_padded, window_size, chunk, n_group, n_head); + probs_padded = ggml_pad_ext(ctx0, probs_padded, 0, chunk, 0, 0, 0, 0, 0, 0); + probs_padded = ggml_view_4d(ctx0, probs_padded, + n_kv_chunk, chunk, n_group, n_head, + (size_t) n_kv_chunk * probs_padded->nb[0], + probs_padded->nb[2], + probs_padded->nb[3], + 0); + probs_padded = ggml_cont(ctx0, probs_padded); + probs_padded = ggml_mul(ctx0, probs_padded, local_mask); + + // Add padding to front and back (for the first timeframe and the last timeframe). + struct ggml_tensor * V_padded = ggml_pad_ext(ctx0, V_cur, 0, 0, att_left, att_right, 0, 0, 0, 0); + + // pad time axis to match n_kv_dense if needed. + if (n_kv_dense > V_padded->ne[1]) { + V_padded = ggml_pad_ext(ctx0, V_padded, 0, 0, 0, n_kv_dense - V_padded->ne[1], 0, 0, 0, 0); + } + + V_padded = ggml_cont(ctx0, ggml_transpose(ctx0, V_padded)); + + struct ggml_tensor * V_chunk = ggml_view_4d(ctx0, V_padded, + n_kv_chunk, d_head, n_group, n_head, + V_padded->nb[1], + (size_t) chunk * V_padded->nb[0], + V_padded->nb[2], + 0); + V_chunk = ggml_cont(ctx0, V_chunk); + + cur = ggml_mul_mat(ctx0, V_chunk, probs_padded); + // ungroup. + cur = ggml_reshape_3d(ctx0, cur, d_head, n_time_padded, n_head); + // unpad + if (need_padding) { + cur = ggml_view_3d(ctx0, cur, d_head, n_time, n_head, cur->nb[1], cur->nb[2], 0); + } + + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3)); + cur = ggml_reshape_2d(ctx0, cur, n_state, n_time); + cur = ggml_mul_mat(ctx0, layer.attn_out_w, cur); + } else { + struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, layer.attn_pos_bias_u); + ggml_format_name(Q_u, "enc_%d_attn_q_u", il); + + struct ggml_tensor * K_prep = ggml_permute(ctx0, K_cur, 0, 2, 1, 3); + struct ggml_tensor * Q_prep = ggml_permute(ctx0, Q_u, 0, 2, 1, 3); + struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_prep, Q_prep); + ggml_format_name(content_scores, "enc_%d_attn_content_scores", il); + + struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, layer.attn_pos_bias_v); + ggml_format_name(Q_v, "enc_%d_attn_q_v", il); + + Q_v = ggml_permute(ctx0, Q_v, 0, 2, 1, 3); + Q_v = ggml_cont(ctx0, Q_v); + ggml_format_name(Q_v, "enc_%d_attn_q_v_perm", il); + + struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos", il); + + // Relative position shifting is performed in the following block. + // Some more details on the operations performed below can be found here: + // https://github.com/danbev/learning-ai/blob/main/notes/whisper/parakeet.md#relative-position-shift + { + const auto pos_window = rel_pos_scores->ne[0]; + const auto n_frame = rel_pos_scores->ne[1]; + const auto n_head_cur = rel_pos_scores->ne[2]; + + rel_pos_scores = ggml_pad(ctx0, rel_pos_scores, 1, 0, 0, 0); + rel_pos_scores = ggml_roll(ctx0, rel_pos_scores, 1, 0, 0, 0); + + rel_pos_scores = ggml_reshape_3d(ctx0, rel_pos_scores, n_frame, pos_window + 1, n_head_cur); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_reshaped", il); + + int center = pos_window / 2; + size_t offset = rel_pos_scores->nb[0] * (center+1); + + rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores, + n_frame, pos_window, n_head_cur, + (pos_window) * 4, + rel_pos_scores->nb[2], + offset); + + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted", il); + + rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores, + content_scores->ne[0], + content_scores->ne[1], + rel_pos_scores->ne[2], + rel_pos_scores->nb[1], + rel_pos_scores->nb[2], + 0); + rel_pos_scores = ggml_cont(ctx0, rel_pos_scores); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted_view", il); + } + + struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores); + ggml_format_name(attn_scores, "enc_%d_attn_scores", il); + attn_scores = ggml_scale(ctx0, attn_scores, 1.0f / std::sqrt(d_head)); + attn_scores = ggml_add(ctx0, attn_scores, attn_mask); + ggml_format_name(attn_scores, "enc_%d_attn_scores_scaled", il); + + struct ggml_tensor * probs = ggml_soft_max(ctx0, attn_scores); + ggml_format_name(probs, "enc_%d_attn_probs", il); + + V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 1, 2, 0, 3)); + ggml_format_name(V_cur, "enc_%d_attn_v_cur", il); + cur = ggml_mul_mat(ctx0, probs, V_cur); + ggml_format_name(cur, "enc_%d_attn_inp", il); + + cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); + cur = ggml_cont_2d(ctx0, cur, n_state, n_time); + cur = ggml_mul_mat(ctx0, layer.attn_out_w, cur); + } + ggml_format_name(cur, "enc_%d_attn_out", il); + + cur = ggml_add(ctx0, residual, cur); + ggml_format_name(cur, "enc_%d_attn_res", il); + } + + // Convolution + { + struct ggml_tensor * residual = cur; + ggml_format_name(cur, "enc_%d_residual_conv", il); + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_conv_w), layer.norm_conv_b); + ggml_format_name(cur, "enc_%d_norm_conv", il); + + // pointwise 1d convolution: [1024, 138] -> [2048, 138] + cur = ggml_mul_mat(ctx0, layer.conv_pw1_w, cur); + ggml_format_name(cur, "enc_%d_conv_pw1", il); + + { + int64_t d = cur->ne[0] / 2; + struct ggml_tensor * signal = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], 0); + struct ggml_tensor * gate = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], d * cur->nb[0]); + + cur = ggml_mul(ctx0, signal, ggml_sigmoid(ctx0, gate)); + ggml_format_name(cur, "enc_%d_conv_glu", il); + } + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + // use ggml_ssm_conv for f32 precision + const int dw_pad = (hparams.n_conv_kernel - 1) / 2; + cur = ggml_pad(ctx0, cur, dw_pad, 0, 0, 0); + cur = ggml_roll(ctx0, cur, dw_pad, 0, 0, 0); + cur = ggml_pad(ctx0, cur, dw_pad, 0, 0, 0); + ggml_format_name(cur, "enc_%d_conv_dw_pad", il); + + cur = ggml_ssm_conv(ctx0, cur, layer.conv_dw_w); + ggml_format_name(cur, "enc_%d_conv_1d_dw", il); + + cur = ggml_sub(ctx0, cur, layer.conv_bn_mean); + struct ggml_tensor * std = ggml_sqrt(ctx0, layer.conv_bn_var); + cur = ggml_div(ctx0, cur, std); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.conv_bn_w), layer.conv_bn_b); + ggml_format_name(cur, "enc_%d_conv_bn", il); + + cur = ggml_silu(ctx0, cur); + ggml_format_name(cur, "enc_%d_conv_silu", il); + + cur = ggml_mul_mat(ctx0, layer.conv_pw2_w, cur); + ggml_format_name(cur, "enc_%d_conv_pw2", il); + + cur = ggml_add(ctx0, residual, cur); + ggml_format_name(cur, "enc_%d_conv_res", il); + } + + // FFN2 + { + struct ggml_tensor * residual = cur; + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_ff2_w), layer.norm_ff2_b); + ggml_format_name(cur, "enc_%d_ffn_norm_2", il); + + cur = ggml_mul_mat(ctx0, layer.ff2_linear1_w, cur); + cur = ggml_silu(ctx0, cur); + cur = ggml_mul_mat(ctx0, layer.ff2_linear2_w, cur); + cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, 0.5)); + ggml_format_name(cur, "enc_%d_ffn_res", il); + } + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_out_w), layer.norm_out_b); + } + + ggml_set_name(cur, "encoder_out"); + pstate.n_frames = cur->ne[1]; + + struct ggml_tensor * enc_out_view = ggml_view_2d(ctx0, pstate.enc_out, n_state, pstate.n_frames, pstate.enc_out->nb[1], 0); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, cur, enc_out_view)); + + ggml_free(ctx0); + + return gf; +} + +static bool parakeet_encode_internal( + parakeet_context & pctx, + parakeet_state & pstate, + const int mel_offset, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + const int64_t t_start_us = ggml_time_us(); + + auto & sched = pstate.sched_encode.sched; + + ggml_cgraph * gf = parakeet_build_graph_encode(pctx, pstate); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + // set mel input + { + struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel"); + + const auto & mel_inp = pstate.mel; + const int n_ctx = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : pctx.model.hparams.n_audio_ctx; + + assert(mel->type == GGML_TYPE_F32); + assert(mel_inp.n_mel == pctx.model.hparams.n_mels); + + pstate.inp_mel.resize(ggml_nelements(mel)); + + float * dst = pstate.inp_mel.data(); + memset(dst, 0, ggml_nbytes(mel)); + + const int i0 = std::min(mel_offset, mel_inp.n_len); + const int i1 = std::min(mel_offset + n_ctx, mel_inp.n_len); + + memcpy(dst, mel_inp.data.data() + i0 * mel_inp.n_mel, (i1 - i0) * mel_inp.n_mel * sizeof(float)); + + ggml_backend_tensor_set(mel, pstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float)); + } + + // set attention mask + { + struct ggml_tensor * attn_mask = ggml_graph_get_tensor(gf, "attn_mask"); + const int n_q = attn_mask->ne[1]; + const int n_k = attn_mask->ne[0]; + + const int32_t subsampl_factor = pctx.model.hparams.subsampling_factor; + const int n_tokens_real = (pstate.mel.n_len_org + subsampl_factor - 1) / subsampl_factor; + + std::vector<float> mask_data(n_q * n_k); + const float mask_value = -1e30f; + + if (n_k == n_q) { // full attention + for (int q = 0; q < n_q; ++q) { + for (int k = 0; k < n_k; ++k) { + mask_data[q * n_k + k] = (k >= n_tokens_real) ? mask_value : 0.0f; + } + } + } else { // local attention + const int att_left = n_k / 2; + for (int q = 0; q < n_q; ++q) { + for (int k = 0; k < n_k; ++k) { + const int key = q - att_left + k; + mask_data[q * n_k + k] = (key >= 0 && key < n_tokens_real) ? 0.0f : mask_value; + } + } + } + ggml_backend_tensor_set(attn_mask, mask_data.data(), 0, mask_data.size() * sizeof(float)); + } + + // set local attention skew mask + if (struct ggml_tensor * local_mask = ggml_graph_get_tensor(gf, "local_mask")) { + const int n_k = local_mask->ne[0]; + const int n_q = local_mask->ne[1]; + + std::vector<float> mask_data(n_q * n_k); + const int window_size = n_k - n_q + 1; + for (int q = 0; q < n_q; ++q) { + for (int k = 0; k < n_k; ++k) { + const int rel = k - q; + mask_data[q * n_k + k] = (rel >= 0 && rel < window_size) ? 1.0f : 0.0f; + } + } + ggml_backend_tensor_set(local_mask, mask_data.data(), 0, mask_data.size() * sizeof(float)); + } + + // set positional frequency + { + struct ggml_tensor * pos_freqs_t = ggml_graph_get_tensor(gf, "pos_freqs"); + const int d_half = pos_freqs_t->ne[0]; + const int n_state = pctx.model.hparams.n_audio_state; + const float log_10000 = logf(10000.0f); + std::vector<float> freqs(d_half); + for (int k = 0; k < d_half; ++k) { + freqs[k] = expf(-(float(k * 2) * log_10000 / float(n_state))); + } + ggml_backend_tensor_set(pos_freqs_t, freqs.data(), 0, freqs.size() * sizeof(float)); + } + + // set relative position offsets + { + struct ggml_tensor * rel_pos_t = ggml_graph_get_tensor(gf, "rel_positions"); + const int window_size = rel_pos_t->ne[1]; + std::vector<float> pos(window_size); + if (window_size == PARAKEET_LOCAL_ATTN_WINDOW * 2 + 1) { + for (int t = 0; t < window_size; ++t) { + pos[t] = float(PARAKEET_LOCAL_ATTN_WINDOW - t); + } + } else { + const int n_time = (window_size + 1) / 2; + for (int t = 0; t < window_size; ++t) { + pos[t] = float(n_time - 1 - t); + } + } + ggml_backend_tensor_set(rel_pos_t, pos.data(), 0, pos.size() * sizeof(float)); + } + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + + pstate.t_encode_us += ggml_time_us() - t_start_us; + pstate.n_encode++; + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static bool parakeet_ensure_encode_sched( + parakeet_context & pctx, + parakeet_state & pstate, + int n_audio_ctx) { + if (pstate.sched_encode.sched && pstate.sched_encode_n_audio_ctx == n_audio_ctx) { + return true; + } + + parakeet_sched_free(pstate.sched_encode); + + const int32_t prev_n_audio_ctx = pstate.n_audio_ctx; + pstate.n_audio_ctx = n_audio_ctx; + + const int subsampl_factor = pctx.model.hparams.subsampling_factor; + const int n_frames_max = (n_audio_ctx + subsampl_factor - 1) / subsampl_factor; + if (n_frames_max > pstate.enc_out->ne[1]) { + ggml_backend_buffer_free(pstate.enc_out_buffer); + pstate.enc_out_buffer = nullptr; + pstate.enc_out = nullptr; + + if (!parakeet_enc_state_init(pstate, pstate.backends[0], pctx.model.hparams.n_audio_state, n_frames_max)) { + pstate.sched_encode_n_audio_ctx = 0; + pstate.n_audio_ctx = prev_n_audio_ctx; + return false; + } + } + + const bool ok = parakeet_sched_graph_init(pstate.sched_encode, pstate.backends, + [&]() { + return parakeet_build_graph_encode(pctx, pstate); + }); + + if (!ok) { + pstate.sched_encode_n_audio_ctx = 0; + pstate.n_audio_ctx = prev_n_audio_ctx; + return false; + } + + pstate.sched_encode_n_audio_ctx = n_audio_ctx; + return true; +} + +static struct ggml_tensor * parakeet_build_graph_lstm_layer( + struct ggml_context * ctx0, + struct ggml_cgraph * gf, + struct ggml_tensor * x_t, // the current input token embedding + struct ggml_tensor * w_ih, // input to hidden weights (4 weight tensors packed) + struct ggml_tensor * w_hh, // hidden to hidden weights (4 weight tensors packed) + struct ggml_tensor * b_h, // folded ih+hh bias (4 bias tensors packed) + struct ggml_tensor * h_state, // this layers hidden state + struct ggml_tensor * c_state, // this layers cell state + int li) { // layer index (for tensor naming) + + ggml_format_name(x_t, "lstm_layer_%d_x_t", li); + ggml_format_name(h_state, "lstm_layer_%d_h_state", li); + ggml_format_name(c_state, "lstm_layer_%d_c_state", li); + + // The 4 gates (i, f, o, c) are packed in the same weight tensor. + struct ggml_tensor * inp_gates = ggml_mul_mat(ctx0, w_ih, x_t); + + // Hidden-to-Hidden Projections are also packed in the same weight tensor. + // b_h holds the folded ih+hh bias (see parakeet_model_load), so it is + // the only bias that needs to be added here. + struct ggml_tensor * hid_gates = ggml_mul_mat(ctx0, w_hh, h_state); + hid_gates = ggml_add(ctx0, hid_gates, b_h); + + // Combine the input and hidden contributions of the gates. + struct ggml_tensor * gates = ggml_add(ctx0, inp_gates, hid_gates); + ggml_format_name(gates, "lstm_layer_%d_gates", li); + + const int h_dim = h_state->ne[0]; + const size_t row_size = ggml_row_size(gates->type, h_dim); + + // The gates are packed as [i, f, o, c] (reordered at convert time, see + // parakeet_model_load), so the three sigmoid-gated outputs (i, f, o) are + // contiguous and can be computed with a single ggml_sigmoid call. + struct ggml_tensor * ifo = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, 3 * h_dim, 0)); + ggml_format_name(ifo, "lstm_layer_%d_ifo", li); + + // 1. Input Gate at time t. + struct ggml_tensor * i_t = ggml_view_1d(ctx0, ifo, h_dim, 0 * row_size); + ggml_format_name(i_t, "lstm_layer_%d_i_t", li); + + // Forget gate. + struct ggml_tensor * f_t = ggml_view_1d(ctx0, ifo, h_dim, 1 * row_size); + ggml_format_name(f_t, "lstm_layer_%d_f_t", li); + + // Output gate. + struct ggml_tensor * o_t = ggml_view_1d(ctx0, ifo, h_dim, 2 * row_size); + ggml_format_name(o_t, "lstm_layer_%d_o_t", li); + + // Cell gate. + struct ggml_tensor * c_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, gates, h_dim, 3 * row_size)); + ggml_format_name(c_t, "lstm_layer_%d_c_t", li); + + // Calculate the new cell state. + struct ggml_tensor * c_new = ggml_add(ctx0, + ggml_mul(ctx0, f_t, c_state), // apply forget gate to cell state. + ggml_mul(ctx0, i_t, c_t)); // apply input gate to cell gate. + ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_new, c_state)); + + // Calculate the new hidden state. + struct ggml_tensor * h_new = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_new)); + ggml_set_output(h_new); + ggml_format_name(h_new, "lstm_layer_%d_h_new", li); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_new, h_state)); + + return h_new; +} + +static struct ggml_cgraph * parakeet_build_graph_prediction( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + bool worst_case) { + GGML_UNUSED(worst_case); + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_tokens = batch.n_tokens; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_decode.meta.size(), + /*.mem_buffer =*/ pstate.sched_decode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + // Prediction Network + struct ggml_tensor * token = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_name(token, "token_inp"); + ggml_set_input(token); + + struct ggml_tensor * token_embd = ggml_get_rows(ctx0, model.prediction.embed_w, token); + + struct ggml_tensor * inpL = token_embd; + + for (int il = 0; il < hparams.n_pred_layers; ++il) { + inpL = parakeet_build_graph_lstm_layer(ctx0, gf, inpL, + model.prediction.lstm_layer[il].ih_w, + model.prediction.lstm_layer[il].hh_w, + model.prediction.lstm_layer[il].b_h, + pstate.lstm_state.layer[il].h_state, + pstate.lstm_state.layer[il].c_state, + il); + } + + struct ggml_tensor * pred_out = inpL; + ggml_format_name(pred_out, "lstm_pred_out"); + + // Project the prediction network output to the joint network hidden dimension. + struct ggml_tensor * pred = ggml_mul_mat(ctx0, model.joint.pred_w, pred_out); + pred = ggml_add(ctx0, pred, model.joint.pred_b); + ggml_set_name(pred, "h_pred"); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, pred, pstate.pred_out)); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * parakeet_build_graph_joint( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + bool worst_case) { + GGML_UNUSED(worst_case); + const auto & model = pctx.model; + const auto & hparams = model.hparams; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_decode.meta.size(), + /*.mem_buffer =*/ pstate.sched_decode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + struct ggml_tensor * pred = pstate.pred_out; + ggml_format_name(pred, "pred"); + + const int t_idx = batch.i_time[0]; + struct ggml_tensor * enc_out = ggml_view_1d(ctx0, pstate.enc_out, hparams.n_audio_state, + (size_t) t_idx * pstate.enc_out->nb[1]); + ggml_format_name(enc_out, "enc_out_view"); + + // Project the encoder output to the joint network hidden dimension. + struct ggml_tensor * enc = ggml_mul_mat(ctx0, model.joint.enc_w, enc_out); + enc = ggml_add(ctx0, enc, model.joint.enc_b); + ggml_set_name(enc, "enc"); + + struct ggml_tensor * joint = ggml_add(ctx0, enc, pred); + ggml_set_name(joint, "joint"); + joint = ggml_relu(ctx0, joint); + + struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.joint.net_w, joint); + logits = ggml_add(ctx0, logits, model.joint.net_b); + ggml_set_output(logits); + ggml_set_name(logits, "logits"); + + struct ggml_tensor * probs = ggml_soft_max(ctx0, logits); + struct ggml_tensor * log_probs = ggml_log(ctx0, probs); + ggml_set_output(log_probs); + ggml_format_name(log_probs, "log_probs"); + + ggml_build_forward_expand(gf, log_probs); + + ggml_free(ctx0); + + return gf; +} + +static bool parakeet_predict( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + + const int n_tokens = batch.n_tokens; + + const int64_t t_start_us = ggml_time_us(); + + { + auto & sched = pstate.sched_decode.sched; + + const int64_t t_build_start_us = ggml_time_us(); + ggml_cgraph * gf = parakeet_build_graph_prediction(pctx, pstate, batch, false); + pstate.t_predict_build_us += ggml_time_us() - t_build_start_us; + + const int64_t t_alloc_start_us = ggml_time_us(); + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + pstate.t_predict_alloc_us += ggml_time_us() - t_alloc_start_us; + + // set the inputs + { + struct ggml_tensor * token_inp = ggml_graph_get_tensor(gf, "token_inp"); + ggml_backend_tensor_set(token_inp, batch.token, 0, n_tokens * ggml_element_size(token_inp)); + } + + const int64_t t_compute_start_us = ggml_time_us(); + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + pstate.t_predict_compute_us += ggml_time_us() - t_compute_start_us; + } + + pstate.t_predict_us += ggml_time_us() - t_start_us; + pstate.n_predict++; + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static bool parakeet_joint( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + const int64_t t_start_us = ggml_time_us(); + + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_tokens = batch.n_tokens; + + auto & logits_out = pstate.logits; + + struct ggml_tensor * logits; + + { + auto & sched = pstate.sched_decode.sched; + + ggml_cgraph * gf = parakeet_build_graph_joint(pctx, pstate, batch, false); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + logits = ggml_graph_node(gf, -1); + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + + } + + const int n_logits = hparams.n_vocab + hparams.n_tdt_durations + 1; // one for the blank token + logits_out.resize(n_tokens * n_logits); + for (int i = 0; i < n_tokens; i++) { + if (batch.logits[i] == 0) { + continue; + } + ggml_backend_tensor_get(logits, logits_out.data() + (n_logits*i), sizeof(float)*(n_logits*i), sizeof(float)*n_logits); + } + + if (batch.n_tokens == 1) { + pstate.t_decode_us += ggml_time_us() - t_start_us; + pstate.n_decode++; + } + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static bool is_word_start_token(parakeet_vocab & vocab, parakeet_token token_id) { + const std::string & token_str = vocab.id_to_token[token_id]; + // check if it starts with the SentencePiece meta-space "▁" (U+2581) or 3-byte UTF-8 character: 0xE2 0x96 0x81 + if (!token_str.empty()) { + if (token_str.find("\xE2\x96\x81") == 0 || token_str[0] == '_') { + return true; + } + } + return false; +} + +static bool is_punctuation_token(parakeet_vocab & vocab, parakeet_token token_id) { + const std::string & token_str = vocab.id_to_token[token_id]; + static const std::string punct_chars = ".,!?;:'\"-()[]{}"; + + if (token_str.empty()) { + return false; + } + + std::string clean_token = token_str; + if (clean_token.find("\xE2\x96\x81") == 0) { + clean_token = clean_token.substr(3); // Remove the 3-byte UTF-8 character + } else if (clean_token[0] == '_') { + clean_token = clean_token.substr(1); + } + + return clean_token.length() == 1 && punct_chars.find(clean_token[0]) != std::string::npos; +} + +// Collapse punctuation timestamps to match the original Parakeet model. +// Punctuations symbols like ',', '.' and others are not spoken words but the +// model will still produce a duration for these tokens. But since these are +// non-spoken we collapse the timestamps so that they don't have an time duration. +static void refine_timestamps_tdt(parakeet_vocab & vocab, std::vector<parakeet_token_data> & tokens) { + if (tokens.empty()) { + return; + } + + int64_t last_non_punct_t1 = -1; + + for (size_t i = 0; i < tokens.size(); ++i) { + if (is_punctuation_token(vocab, tokens[i].id)) { + if (last_non_punct_t1 >= 0) { + tokens[i].t0 = last_non_punct_t1; + tokens[i].t1 = last_non_punct_t1; + } + } else { + last_non_punct_t1 = tokens[i].t1; + } + } +} + +static parakeet_token_data create_token_data( + parakeet_context & pctx, + parakeet_state & pstate, + parakeet_token token_id, + int duration_idx, + int duration_value, + int frame_index, + float token_logit, + int n_vocab_logits) { + + float token_sum = 0.0f; + for (int i = 0; i < n_vocab_logits; ++i) { + token_sum += expf(pstate.logits[i]); + } + float token_p = expf(token_logit) / token_sum; + + parakeet_token_data token_data; + token_data.id = token_id; + token_data.duration_idx = duration_idx; + token_data.duration_value = duration_value; + token_data.frame_index = frame_index; + token_data.p = token_p; + token_data.plog = token_logit; + token_data.t0 = frame_index * pctx.model.hparams.subsampling_factor; + token_data.t1 = (frame_index + duration_value) * pctx.model.hparams.subsampling_factor; + token_data.is_word_start = is_word_start_token(pctx.vocab, token_id); + + return token_data; +} + +static bool parakeet_decode( + parakeet_context & pctx, + parakeet_state & pstate, + parakeet_batch & batch, + const int n_threads, + const parakeet_full_params * params = nullptr) { + const auto & hparams = pctx.model.hparams; + const auto & tdt_durations = pctx.model.tdt_durations; + + const int n_tdt_durations = hparams.n_tdt_durations; + const int n_frames = pstate.n_frames; + const int blank_id = pctx.vocab.token_blank; + const int n_vocab_logits = blank_id + 1; + const int max_tokens_per_timestep = hparams.n_max_tokens; + + // time index into the encoder frame (current time frame) + int t = 0; + // number of symbols emitted for the current time frame + int tokens_emitted = 0; + + // Start with the blank token (8192) + parakeet_token last_token = blank_id; + + PARAKEET_LOG_DEBUG("parakeet_decode: starting decode with n_frames=%d\n", n_frames); + + batch.n_tokens = 1; + batch.token[0] = last_token; + batch.logits[0] = 1; + batch.i_time[0] = 0; + + // run the prediction network for the initial blank token. This will + // initialize the LSTM state and produce an initial hidden state that can + // be used in the joint network below. + if (!parakeet_predict(pctx, pstate, batch, n_threads, + params ? params->abort_callback : nullptr, + params ? params->abort_callback_user_data : nullptr)) { + return false; + } + + // process all time frames of the encoder output + while (t < n_frames) { + batch.n_tokens = 1; + batch.i_time[0] = t; + batch.logits[0] = 1; + + // Use the current encoder frame (t) and the output of the prediction to + // generate probabilities for the next token and duration. batch.i_time + // is used in to select the correct frame from the encoder output. + // The joint network outputs logits for all the tokens in the vocabulary + // plus the blank token, and also n_duration logits for the duration + // tokens which contain information about how many frames to skip/advance forward. + if (!parakeet_joint(pctx, pstate, batch, n_threads, + params ? params->abort_callback : nullptr, + params ? params->abort_callback_user_data : nullptr)) { + return false; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + // find the best token (greedy). + // TODO: implement beam search? + int best_token = 0; + float max_logit = -1e10f; + for (int i = 0; i < n_vocab_logits; ++i) { + if (pstate.logits[i] > max_logit) { + max_logit = pstate.logits[i]; + best_token = i; + } + } + + // find the max index of the duration logits, and look up that index + // value in the tdt_durations array to get the actual duration value. + int best_duration_idx = 0; + float best_duration_logit = -1e10f; + for (int i = 0; i < n_tdt_durations; ++i) { + if (pstate.logits[n_vocab_logits + i] > best_duration_logit) { + best_duration_logit = pstate.logits[n_vocab_logits + i]; + best_duration_idx = i; + } + } + // look up that max duration index value in the tdt_durations array to + // get the actual duration value. + int duration = tdt_durations[best_duration_idx]; + + if (best_token == blank_id) { + if (duration == 0) { + duration = 1; + } + // skip forward by duration time frames. + t += duration; + // reset symbols emitted counter + tokens_emitted = 0; + // continue without predicting. + continue; + } + + // Emit non-blank token at current frame t. + pstate.decoded_tokens.push_back(best_token); + pstate.t_sample_us += ggml_time_us() - t_start_sample_us; + pstate.n_sample++; + + parakeet_token_data token_data = create_token_data( + pctx, pstate, best_token, best_duration_idx, duration, t, + max_logit, n_vocab_logits); + + pstate.decoded_token_data.push_back(token_data); + + // Call token callback if registered (for real-time streaming) + if (params && params->new_token_callback) { + params->new_token_callback(&pctx, &pstate, &token_data, params->new_token_callback_user_data); + } + + last_token = best_token; + + // advance predictor for the non-blank token. + batch.token[0] = last_token; + if (!parakeet_predict(pctx, pstate, batch, n_threads, + params ? params->abort_callback : nullptr, + params ? params->abort_callback_user_data : nullptr)) { + return false; + } + + // if duration greater than 0, continue looping over the encoder frames + // and skip to the updated time frame (t). + if (duration > 0) { + t += duration; + tokens_emitted = 0; + continue; + } + + // if duration is zero we stay on the current time frame. + tokens_emitted++; + if (tokens_emitted >= max_tokens_per_timestep) { + t += 1; // forced blank/time advance behavior + tokens_emitted = 0; + } + } + + return true; +} + +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +// naive Discrete Fourier Transform +// input is real-valued +// output is complex-valued +static void dft(const float* in, int N, float* out, const parakeet_mel_cache & cache) { + const int sin_cos_step = cache.n_fft / N; + + for (int k = 0; k < N; k++) { + float re = 0; + float im = 0; + + for (int n = 0; n < N; n++) { + int idx = (k * n * sin_cos_step) % cache.n_fft; // t = 2*M_PI*k*n/N + re += in[n]*cache.cos_vals[idx]; // cos(t) + im -= in[n]*cache.sin_vals[idx]; // sin(t) + } + + out[k*2 + 0] = re; + out[k*2 + 1] = im; + } +} + +// Cooley-Tukey FFT +// poor man's implementation - use something better +// input is real-valued +// output is complex-valued +static void fft(float* in, int N, float* out, const parakeet_mel_cache & cache) { + if (N == 1) { + out[0] = in[0]; + out[1] = 0; + return; + } + + const int half_N = N / 2; + if (N - half_N*2 == 1) { + dft(in, N, out, cache); + return; + } + + float* even = in + N; + for (int i = 0; i < half_N; ++i) { + even[i]= in[2*i]; + } + float* even_fft = out + 2 * N; + fft(even, half_N, even_fft, cache); + + float* odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i] = in[2*i + 1]; + } + float* odd_fft = even_fft + N; + fft(odd, half_N, odd_fft, cache); + + const int sin_cos_step = cache.n_fft / N; + for (int k = 0; k < half_N; k++) { + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = cache.cos_vals[idx]; // cos(t) + float im = -cache.sin_vals[idx]; // sin(t) + + float re_odd = odd_fft[2*k + 0]; + float im_odd = odd_fft[2*k + 1]; + + out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; + out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + + out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; + out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + } +} + +struct mel_worker_params { + int ith; + int window_size; + int n_samples; + int frame_size; + int frame_step; + int n_threads; +}; + +static void log_mel_spectrogram_worker_thread( + mel_worker_params params, + const float * window_func, + const std::vector<float> & samples, + const parakeet_filters & filters, + parakeet_mel & mel, + const parakeet_mel_cache & cache) { + std::vector<float> fft_in(params.frame_size * 2, 0.0); + std::vector<float> fft_out(params.frame_size * 2 * 2 * 2); + + int n_fb = filters.n_fb; // number of frequency bins + int i = params.ith; + + // make sure n_fb == 1 + (frame_size / 2), bin_0 to bin_nyquist + assert(n_fb == 1 + (params.frame_size / 2)); + + const double eps = 5.960464477539063e-08; + + // calculate FFT only when fft_in are not all zero + for (; i < std::min(params.n_samples / params.frame_step + 1, mel.n_len); i += params.n_threads) { + const int offset = i * params.frame_step; + + const int window_pad_left = (params.frame_size - params.window_size) / 2; + + // Zero-pad left + std::fill(fft_in.begin(), fft_in.begin() + window_pad_left, 0.0f); + + // Apply windowed samples in the center + const int n_to_process = std::min({params.window_size, params.n_samples - offset}); + for (int j = 0; j < n_to_process; j++) { + fft_in[window_pad_left + j] = window_func[j] * samples[offset + window_pad_left + j]; + } + + // Zero-pad right (and any samples we didn't have) + std::fill(fft_in.begin() + window_pad_left + n_to_process, fft_in.begin() + params.frame_size, 0.0f); + + // FFT + fft(fft_in.data(), params.frame_size, fft_out.data(), cache); + + // Calculate modulus^2 of complex numbers + // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. + for (int j = 0; j < n_fb; j++) { + fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); + } + + // mel spectrogram + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; + // unroll loop (suggested by GH user @lunixbochs) + int k = 0; + for (k = 0; k < n_fb - 3; k += 4) { + sum += + fft_out[k + 0] * filters.data[j * n_fb + k + 0] + + fft_out[k + 1] * filters.data[j * n_fb + k + 1] + + fft_out[k + 2] * filters.data[j * n_fb + k + 2] + + fft_out[k + 3] * filters.data[j * n_fb + k + 3]; + } + // handle n_fb remainder + for (; k < n_fb; k++) { + sum += fft_out[k] * filters.data[j * n_fb + k]; + } + + mel.data[i * mel.n_mel + j] = std::log(sum + eps); + } + } + + // Otherwise fft_out are all zero - use log(eps) for consistency + const double empty_sum = std::log(eps); + for (; i < mel.n_len; i += params.n_threads) { + for (int j = 0; j < mel.n_mel; j++) { + mel.data[i * mel.n_mel + j] = empty_sum; + } + } +} + +static bool log_mel_spectrogram( + parakeet_state & wstate, + const float * samples, + const int n_samples, + const int /*sample_rate*/, + const int frame_size, + const int frame_step, + const int n_mel, + const int n_threads, + const parakeet_filters & filters, + const bool debug, + parakeet_mel & mel, + const parakeet_mel_cache & cache) { + const int64_t t_start_us = ggml_time_us(); + + const float * window_func = cache.window.empty() ? cache.hann_window.data() : cache.window.data(); + const int window_size = cache.window.empty() ? cache.n_fft : cache.window.size(); + + std::vector<float> samples_preprocessed(samples, samples + n_samples); + + // Apply preemphasis filter (high-pass): x[i] = x[i] - 0.97 * x[i-1] + { + const float preemph = 0.97f; + for (int i = n_samples - 1; i > 0; i--) { + samples_preprocessed[i] = samples_preprocessed[i] - preemph * samples_preprocessed[i - 1]; + } + } + + // Parakeet Pytorch implementation uses centered contant padding. + const size_t pad = (size_t)(frame_size / 2); + std::vector<float> samples_padded(n_samples + 2 * pad, 0.0f); + std::copy(samples_preprocessed.begin(), samples_preprocessed.end(), samples_padded.begin() + pad); + + mel.n_mel = n_mel; + mel.n_len = (samples_padded.size() - frame_size) / frame_step + 1; + mel.n_len_org = mel.n_len; + mel.data.resize(mel.n_mel * mel.n_len); + + // Worker Threads (STFT + Mel + Natural Log) + { + std::vector<std::thread> workers(n_threads - 1); + const mel_worker_params mel_params { 0, window_size, (int)samples_padded.size(), frame_size, frame_step, n_threads }; + + for (int iw = 0; iw < n_threads - 1; ++iw) { + mel_worker_params params = mel_params; + params.ith = iw + 1; + workers[iw] = std::thread(log_mel_spectrogram_worker_thread, + params, + window_func, + std::cref(samples_padded), + std::cref(filters), + std::ref(mel), + std::cref(cache)); + } + + log_mel_spectrogram_worker_thread( + mel_params, + window_func, + samples_padded, + filters, + mel, + cache); + + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw].join(); + } + } + + { + const double eps = 1e-5; + int valid_frames = n_samples / frame_step; + + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; + double sq_diff_sum = 0.0; + + // Calculate Mean ONLY on valid audio frames + for (int i = 0; i < valid_frames; i++) { + sum += (double)mel.data[i * mel.n_mel + j]; + } + double mean = sum / valid_frames; + + // Calculate Variance ONLY on valid audio frames + for (int i = 0; i < valid_frames; i++) { + double diff = (double)mel.data[i * mel.n_mel + j] - mean; + sq_diff_sum += diff * diff; + } + + double std_dev = std::sqrt(sq_diff_sum / (valid_frames - 1.0)); + double denominator = std_dev + eps; + + // Apply to ALL frames (including the padded ones) + for (int i = 0; i < mel.n_len; i++) { + mel.data[i * mel.n_mel + j] = (float)((mel.data[i * mel.n_mel + j] - mean) / denominator); + } + } + } + + wstate.t_mel_us += ggml_time_us() - t_start_us; + + if (debug) { + std::ofstream outFile("log_mel_spectrogram.json"); + outFile << "["; + for (uint64_t i = 0; i < mel.data.size() - 1; i++) { + outFile << mel.data[i] << ", "; + } + outFile << mel.data[mel.data.size() - 1] << "]"; + outFile.close(); + } + + return true; +} + +static std::vector<parakeet_vocab::id> tokenize(const parakeet_vocab & vocab, const std::string & text) { + std::vector<parakeet_vocab::id> tokens; + const std::string normalized = sentencepiece_normalize(text); + + size_t i = 0; + while (i < normalized.size()) { + const size_t remaining = normalized.size() - i; + const size_t max_len = std::min(vocab.max_token_length, remaining); + + bool found = false; + for (size_t len = max_len; len > 0; --len) { + const auto it = vocab.token_to_id.find(normalized.substr(i, len)); + if (it != vocab.token_to_id.end() && !is_sentencepiece_control(it->first)) { + tokens.push_back(it->second); + i += len; + found = true; + break; + } + } + + if (!found) { + if (vocab.token_unk >= 0) { + tokens.push_back(vocab.token_unk); + } + + const unsigned char c = static_cast<unsigned char>(normalized[i]); + i += utf8_codepoint_len(c); + } + } + + return tokens; +} + + +// +// interface implementation +// + +struct parakeet_state * parakeet_init_state(parakeet_context * ctx) { + parakeet_state * state = new parakeet_state; + + state->backends = parakeet_backend_init(ctx->params); + if (state->backends.empty()) { + PARAKEET_LOG_ERROR("%s: parakeet_backend_init() failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + const int batch_size = ctx->model.hparams.n_audio_ctx; + + state->logits.reserve(ctx->vocab.n_vocab * batch_size); + + state->batch = parakeet_batch_init(batch_size); + + { + const int n_audio_state = ctx->model.hparams.n_audio_state; + const int subsampl_factor = ctx->model.hparams.subsampling_factor; + const int n_frames_max = (batch_size + subsampl_factor - 1) / subsampl_factor; + + if (!parakeet_enc_state_init(*state, state->backends[0], n_audio_state, n_frames_max)) { + PARAKEET_LOG_ERROR("%s: parakeet_enc_state_init() failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + const size_t mem_enc_ctx = state->enc_out_buf.size(); + const size_t mem_enc_out_buf = ggml_backend_buffer_get_size(state->enc_out_buffer); + PARAKEET_LOG_INFO("%s: enc_out state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_enc_ctx / 1024.0 / 1024.0, mem_enc_out_buf / 1024.0 / 1024.0); + } + + // conv/encoder allocator + bool ok = parakeet_sched_graph_init(state->sched_encode, state->backends, + [&]() { + return parakeet_build_graph_encode(*ctx, *state); + }); + + if (!ok) { + PARAKEET_LOG_ERROR("%s: failed to init encode allocator\n", __func__); + parakeet_free_state(state); + return nullptr; + } + state->sched_encode_n_audio_ctx = state->n_audio_ctx > 0 ? state->n_audio_ctx : ctx->model.hparams.n_audio_ctx; + + if (!parakeet_lstm_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_layers, ctx->model.hparams.n_pred_dim)) { + PARAKEET_LOG_ERROR("%s: parakeet_lstm_states_init () failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + { + const size_t mem_lstm_ctx = state->lstm_state.ctx_buf.size(); + const size_t mem_lstm_buf = ggml_backend_buffer_get_size(state->lstm_state.buffer); + PARAKEET_LOG_INFO("%s: lstm state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_lstm_ctx / 1024.0 / 1024.0, mem_lstm_buf / 1024.0 / 1024.0); + } + + if (!parakeet_pred_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_dim)) { + PARAKEET_LOG_ERROR("%s: parakeet_pred_state_init() failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + { + const size_t mem_pred_ctx = state->pred_out_buf.size(); + const size_t mem_pred_out_buf = ggml_backend_buffer_get_size(state->pred_out_buffer); + PARAKEET_LOG_INFO("%s: pred state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_pred_ctx / 1024.0 / 1024.0, mem_pred_out_buf / 1024.0 / 1024.0); + } + + PARAKEET_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_encode) / 1e6); + + { + bool ok = parakeet_sched_graph_init(state->sched_decode, state->backends, + [&]() { + const auto & hparams = ctx->model.hparams; + const int n_tokens = hparams.n_audio_ctx; // Use audio ctx for Parakeet + + parakeet_batch_prep_legacy(state->batch, nullptr, n_tokens, 0, 0); + + return parakeet_build_graph_prediction(*ctx, *state, state->batch, true); + }); + + if (!ok) { + PARAKEET_LOG_ERROR("%s: failed to init decoder allocator\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + PARAKEET_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_decode) / 1e6); + } + + return state; +} + +struct parakeet_context_params parakeet_context_default_params() { + struct parakeet_context_params result = { + /*.use_gpu =*/ true, + /*.gpu_device =*/ 0, + }; + return result; +} + +struct parakeet_context * parakeet_init_from_file_with_params_no_state(const char * path_model, struct parakeet_context_params params) { + PARAKEET_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model); +#ifdef _MSC_VER + // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues. + std::wstring_convert<std::codecvt_utf8<wchar_t>> converter; + std::wstring path_model_wide = converter.from_bytes(path_model); + auto fin = std::ifstream(path_model_wide, std::ios::binary); +#else + auto fin = std::ifstream(path_model, std::ios::binary); +#endif + if (!fin) { + PARAKEET_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model); + return nullptr; + } + + parakeet_model_loader loader = {}; + + loader.context = &fin; + + loader.read = [](void * ctx, void * output, size_t read_size) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->read((char *)output, read_size); + return read_size; + }; + + loader.eof = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + return fin->eof(); + }; + + loader.close = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->close(); + }; + + auto ctx = parakeet_init_with_params_no_state(&loader, params); + + if (ctx) { + ctx->path_model = path_model; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params) { + struct buf_context { + uint8_t* buffer; + size_t size; + size_t current_offset; + }; + + buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 }; + + PARAKEET_LOG_INFO("%s: loading model from buffer\n", __func__); + + parakeet_model_loader loader = {}; + + loader.context = &ctx; + + loader.read = [](void * ctx, void * output, size_t read_size) { + buf_context * buf = reinterpret_cast<buf_context *>(ctx); + + size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset; + + memcpy(output, buf->buffer + buf->current_offset, size_to_copy); + buf->current_offset += size_to_copy; + + return size_to_copy; + }; + + loader.eof = [](void * ctx) { + buf_context * buf = reinterpret_cast<buf_context *>(ctx); + + return buf->current_offset >= buf->size; + }; + + loader.close = [](void * /*ctx*/) { }; + + return parakeet_init_with_params_no_state(&loader, params); +} + +struct parakeet_context * parakeet_init_with_params_no_state(struct parakeet_model_loader * loader, struct parakeet_context_params params) { + ggml_time_init(); + + PARAKEET_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu); + PARAKEET_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device); + PARAKEET_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count()); + PARAKEET_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count()); + + parakeet_context * ctx = new parakeet_context; + ctx->params = params; + + bool model_loaded = false; + try { + model_loaded = parakeet_model_load(loader, *ctx); + } catch (const std::exception & e) { + PARAKEET_LOG_ERROR("%s: exception during model load: %s\n", __func__, e.what()); + } catch (...) { + PARAKEET_LOG_ERROR("%s: unknown exception during model load\n", __func__); + } + + if (!model_loaded) { + loader->close(loader->context); + PARAKEET_LOG_ERROR("%s: failed to load model\n", __func__); + delete ctx; + return nullptr; + } + + loader->close(loader->context); + + // Initialize mel cache with model's FFT size + ctx->mel_cache.init(ctx->model.hparams.n_fft); + PARAKEET_LOG_INFO("%s: initialized mel cache with n_fft = %d\n", __func__, ctx->model.hparams.n_fft); + + return ctx; +} + +struct parakeet_context * parakeet_init_from_file_with_params(const char * path_model, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_from_file_with_params_no_state(path_model, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_from_buffer_with_params_no_state(buffer, buffer_size, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_with_params(struct parakeet_model_loader * loader, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_with_params_no_state(loader, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +void parakeet_free_state(struct parakeet_state * state) { + if (state) { + ggml_backend_buffer_free(state->lstm_state.buffer); + ggml_backend_buffer_free(state->pred_out_buffer); + ggml_backend_buffer_free(state->enc_out_buffer); + + parakeet_batch_free(state->batch); + + parakeet_sched_free(state->sched_encode); + parakeet_sched_free(state->sched_decode); + + for (auto & backend : state->backends) { + ggml_backend_free(backend); + } + + delete state; + } +} + +void parakeet_free(struct parakeet_context * ctx) { + if (ctx) { + for (ggml_context * context : ctx->model.ctxs) { + ggml_free(context); + } + + for (ggml_backend_buffer_t buf : ctx->model.buffers) { + ggml_backend_buffer_free(buf); + } + + parakeet_free_state(ctx->state); + + delete ctx; + } +} + +void parakeet_free_context_params(struct parakeet_context_params * params) { + if (params) { + delete params; + } +} + +void parakeet_free_params(struct parakeet_full_params * params) { + if (params) { + delete params; + } +} + +int parakeet_pcm_to_mel_with_state(struct parakeet_context * ctx, struct parakeet_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*state, + samples, + n_samples, + PARAKEET_SAMPLE_RATE, + ctx->model.hparams.n_fft, + PARAKEET_HOP_LENGTH, + ctx->model.filters.n_mel, + n_threads, + ctx->model.filters, + false, // debug + state->mel, + ctx->mel_cache)) { + PARAKEET_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_pcm_to_mel(struct parakeet_context * ctx, const float * samples, int n_samples, int n_threads) { + return parakeet_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +int parakeet_set_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * data, + int n_len, + int n_mel) { + if (n_mel != ctx->model.filters.n_mel) { + PARAKEET_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel); + return -1; + } + + state->mel.n_len = n_len; + state->mel.n_len_org = n_len; + state->mel.n_mel = n_mel; + + state->mel.data.resize(n_len*n_mel); + memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float)); + + return 0; +} + +int parakeet_set_mel( + struct parakeet_context * ctx, + const float * data, + int n_len, + int n_mel) { + return parakeet_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel); +} + +int parakeet_encode_with_state(struct parakeet_context * ctx, struct parakeet_state * state, int offset, int n_threads) { + if (!parakeet_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) { + PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_encode(struct parakeet_context * ctx, int offset, int n_threads) { + if (!parakeet_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) { + PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_tokenize(struct parakeet_context * ctx, const char * text, parakeet_token * tokens, int n_max_tokens) { + const auto res = tokenize(ctx->vocab, text); + + if (n_max_tokens < (int) res.size()) { + PARAKEET_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); + return -(int) res.size(); + } + + for (int i = 0; i < (int) res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +int parakeet_token_count(struct parakeet_context * ctx, const char * text) { + return -parakeet_tokenize(ctx, text, NULL, 0); +} + +int parakeet_model_n_vocab(struct parakeet_context * ctx) { + return ctx->model.hparams.n_vocab; +} + +int parakeet_model_n_audio_ctx(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +int parakeet_model_n_audio_state(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_state; +} + +int parakeet_model_n_audio_head(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_head; +} + +int parakeet_model_n_audio_layer(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_layer; +} + +int parakeet_model_n_mels(struct parakeet_context * ctx) { + return ctx->model.hparams.n_mels; +} + +int parakeet_model_ftype(struct parakeet_context * ctx) { + return ctx->model.hparams.ftype; +} + +int parakeet_n_len_from_state(struct parakeet_state * state) { + return state->mel.n_len_org; +} + +int parakeet_n_len(struct parakeet_context * ctx) { + return ctx->state->mel.n_len_org; +} + +int parakeet_n_vocab(struct parakeet_context * ctx) { + return ctx->vocab.n_vocab; +} + +int parakeet_n_audio_ctx(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +float * parakeet_get_logits(struct parakeet_context * ctx) { + return ctx->state->logits.data(); +} + +float * parakeet_get_logits_from_state(struct parakeet_state * state) { + return state->logits.data(); +} + +const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token) { + return ctx->vocab.id_to_token.at(token).c_str(); +} + +int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len) { + std::string text = sentencepiece_piece_to_text(token_str, is_first); + + if (output == nullptr) { + return text.size(); + } + + int bytes_to_copy = std::min((int)text.size(), max_len - 1); + if (bytes_to_copy > 0) { + memcpy(output, text.c_str(), bytes_to_copy); + output[bytes_to_copy] = '\0'; + } else if (max_len > 0) { + output[0] = '\0'; + } + + return text.size(); +} + +parakeet_token parakeet_token_bos(struct parakeet_context * ctx) { + return ctx->vocab.token_bos; +} + +parakeet_token parakeet_token_unk(struct parakeet_context * ctx) { + return ctx->vocab.token_unk; +} + +parakeet_token parakeet_token_blank(struct parakeet_context * ctx) { + return ctx->vocab.token_blank; +} + +struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx) { + if (ctx->state == nullptr) { + return nullptr; + } + parakeet_timings * timings = new parakeet_timings; + timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample); + timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode); + timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode); + return timings; +} + +void parakeet_print_timings(struct parakeet_context * ctx) { + const int64_t t_end_us = ggml_time_us(); + + PARAKEET_LOG_INFO("\n"); + PARAKEET_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + if (ctx->state != nullptr) { + + const int32_t n_sample = std::max(1, ctx->state->n_sample); + const int32_t n_encode = std::max(1, ctx->state->n_encode); + const int32_t n_decode = std::max(1, ctx->state->n_decode); + const int32_t n_predict = std::max(1, ctx->state->n_predict); + + PARAKEET_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); + PARAKEET_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); + PARAKEET_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); + PARAKEET_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); + PARAKEET_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + PARAKEET_LOG_INFO("%s: predict time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_us, n_predict, 1e-3f * ctx->state->t_predict_us / n_predict); + PARAKEET_LOG_INFO("%s: - build = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_build_us, n_predict, 1e-3f * ctx->state->t_predict_build_us / n_predict); + PARAKEET_LOG_INFO("%s: - alloc = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_alloc_us, n_predict, 1e-3f * ctx->state->t_predict_alloc_us / n_predict); + PARAKEET_LOG_INFO("%s: - compute = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_compute_us, n_predict, 1e-3f * ctx->state->t_predict_compute_us / n_predict); + + } + PARAKEET_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); +} + +void parakeet_reset_timings(struct parakeet_context * ctx) { + ctx->t_start_us = ggml_time_us(); + if (ctx->state != nullptr) { + ctx->state->t_mel_us = 0; + ctx->state->t_sample_us = 0; + ctx->state->t_encode_us = 0; + ctx->state->t_decode_us = 0; + ctx->state->t_predict_us = 0; + ctx->state->t_predict_build_us = 0; + ctx->state->t_predict_alloc_us = 0; + ctx->state->t_predict_compute_us = 0; + + ctx->state->n_sample = 0; + ctx->state->n_encode = 0; + ctx->state->n_decode = 0; + ctx->state->n_predict = 0; + } +} + +const char * parakeet_print_system_info(void) { + static std::string s; + + s = ""; + s += "PARAKEET : "; + + for (size_t i = 0; i < ggml_backend_reg_count(); i++) { + auto * reg = ggml_backend_reg_get(i); + auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features"); + if (get_features_fn) { + ggml_backend_feature * features = get_features_fn(reg); + s += ggml_backend_reg_name(reg); + s += " : "; + for (; features->name; features++) { + s += features->name; + s += " = "; + s += features->value; + s += " | "; + } + } + } + return s.c_str(); +} + +struct parakeet_context_params * parakeet_context_default_params_by_ref(void) { + struct parakeet_context_params params = parakeet_context_default_params(); + + struct parakeet_context_params* result = new parakeet_context_params(); + *result = params; + return result; +} + +struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy) { + struct parakeet_full_params params = parakeet_full_default_params(strategy); + + struct parakeet_full_params* result = new parakeet_full_params(); + *result = params; + return result; +} + +struct parakeet_full_params parakeet_full_default_params(enum parakeet_sampling_strategy strategy) { + struct parakeet_full_params result = { + /*.strategy =*/ strategy, + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + /*.no_context =*/ true, + /*.audio_ctx =*/ 0, + /*.new_token_callback =*/ nullptr, + /*.new_token_callback_user_data =*/ nullptr, + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, + /*.abort_callback =*/ nullptr, + /*.abort_callback_user_data =*/ nullptr, + }; + + return result; +} + +static void parakeet_reset_state(struct parakeet_state * state) { + state->decoded_tokens.clear(); + state->decoded_token_data.clear(); + + if (state->lstm_state.buffer) { + ggml_backend_buffer_clear(state->lstm_state.buffer, 0); + } + +} + +// Encode and decode the mel spectrogram already in state, without recomputing it. +static int parakeet_chunk_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params) { + return parakeet_chunk(ctx, state, params, nullptr, 0); +} + +int parakeet_full_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + state->result_all.clear(); + + if (params.no_context) { + parakeet_reset_state(state); + } + + if (n_samples > 0) { + if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + + const int n_mel_total = state->mel.n_len; + const int n_audio_ctx = ctx->model.hparams.n_audio_ctx; + + if (n_mel_total <= n_audio_ctx) { + if (params.progress_callback) { + params.progress_callback(ctx, state, 0, params.progress_callback_user_data); + } + return parakeet_chunk_with_state(ctx, state, params); + } + + PARAKEET_LOG_DEBUG("%s: audio too long (%d mel > n_audio_ctx=%d), using dynamic encoder graph\n", + __func__, n_mel_total, n_audio_ctx); + + if (params.encoder_begin_callback) { + if (!params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: encoder_begin_callback returned false\n", __func__); + return -6; + } + } + + if (params.progress_callback) { + params.progress_callback(ctx, state, 0, params.progress_callback_user_data); + } + + if (!parakeet_ensure_encode_sched(*ctx, *state, n_mel_total)) { + PARAKEET_LOG_ERROR("%s: failed to allocate dynamic encoder graph for %d mel frames\n", + __func__, n_mel_total); + return -6; + } + + state->n_audio_ctx = n_mel_total; + + if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, + params.abort_callback, params.abort_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } + + if (params.progress_callback) { + params.progress_callback(ctx, state, 100, params.progress_callback_user_data); + } + + const size_t tokens_before = state->decoded_tokens.size(); + + if (!parakeet_decode(*ctx, *state, state->batch, params.n_threads, ¶ms)) { + PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__); + return -7; + } + + const size_t tokens_after = state->decoded_tokens.size(); + const size_t new_token_count = tokens_after - tokens_before; + + if (new_token_count > 0) { + std::string text; + std::vector<parakeet_token_data> result_tokens; + + for (size_t i = tokens_before; i < tokens_after; i++) { + const auto token_id = state->decoded_tokens[i]; + const char * tok_str = parakeet_token_to_str(ctx, token_id); + if (tok_str) { + const bool is_first = (tokens_before == 0) && text.empty(); + text += sentencepiece_piece_to_text(tok_str, is_first); + } + result_tokens.push_back(state->decoded_token_data[i]); + } + + refine_timestamps_tdt(ctx->vocab, result_tokens); + + if (!text.empty()) { + parakeet_segment seg; + seg.t0 = 0; + seg.t1 = state->n_frames; + seg.text = text; + seg.tokens = result_tokens; + state->result_all.push_back(std::move(seg)); + + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data); + } + } + } + + return 0; +} + +int parakeet_full( + struct parakeet_context * ctx, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + return parakeet_full_with_state(ctx, ctx->state, params, samples, n_samples); +} + +int parakeet_chunk( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + + if (params.no_context) { + parakeet_reset_state(state); + } + + if (n_samples > 0) { + if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + + if (params.audio_ctx == 0) { + const int total_len = parakeet_n_len_from_state(state); + const int model_max_ctx = parakeet_n_audio_ctx(ctx); + params.audio_ctx = std::min(total_len, model_max_ctx); + PARAKEET_LOG_DEBUG("Processing audio: total_frames=%d, chunk_size=%d\n", total_len, params.audio_ctx); + } + state->n_audio_ctx = params.audio_ctx; + + const int n_frames = parakeet_n_len_from_state(state); + + if (!parakeet_ensure_encode_sched(*ctx, *state, state->n_audio_ctx)) { + PARAKEET_LOG_ERROR("%s: failed to allocate encoder graph for %d mel frames\n", + __func__, state->n_audio_ctx); + return -6; + } + + if (params.encoder_begin_callback) { + if (!params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__); + return -6; + } + } + if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } + + const size_t tokens_before = state->decoded_tokens.size(); + + if (!parakeet_decode(*ctx, *state, state->batch, params.n_threads, ¶ms)) { + PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__); + return -7; + } + + const size_t tokens_after = state->decoded_tokens.size(); + const size_t new_token_count = tokens_after - tokens_before; + + if (new_token_count > 0) { + std::string text; + std::vector<parakeet_token_data> result_tokens; + + for (size_t i = tokens_before; i < tokens_after; i++) { + const auto token_id = state->decoded_tokens[i]; + const char * token_str = parakeet_token_to_str(ctx, token_id); + if (token_str) { + const bool is_first_piece = (tokens_before == 0) && text.empty(); + text += sentencepiece_piece_to_text(token_str, is_first_piece); + } + + // Use the stored token data from parakeet_decode + result_tokens.push_back(state->decoded_token_data[i]); + } + + refine_timestamps_tdt(ctx->vocab, result_tokens); + + if (!text.empty()) { + parakeet_segment segment; + segment.t0 = 0; // Caller tracks timing + segment.t1 = n_frames; + segment.text = text; + segment.tokens = result_tokens; + + state->result_all.push_back(std::move(segment)); + + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data); + } + } + } + + return 0; +} + +int parakeet_full_n_segments_from_state(struct parakeet_state * state) { + return state->result_all.size(); +} + +int parakeet_full_n_segments(struct parakeet_context * ctx) { + return ctx->state->result_all.size(); +} + +int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].t0; +} + +int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].t1; +} + +int64_t parakeet_full_get_segment_t0(struct parakeet_context * ctx, int i_segment) { + return parakeet_full_get_segment_t0_from_state(ctx->state, i_segment); +} + +int64_t parakeet_full_get_segment_t1(struct parakeet_context * ctx, int i_segment) { + return parakeet_full_get_segment_t1_from_state(ctx->state, i_segment); +} + +const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].text.c_str(); +} + +const char * parakeet_full_get_segment_text(struct parakeet_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].text.c_str(); +} + +int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].tokens.size(); +} + +int parakeet_full_n_tokens(struct parakeet_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].tokens.size(); +} + +const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token) { + return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +const char* parakeet_full_get_token_text(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].id; +} + +parakeet_token parakeet_full_get_token_id(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].id; +} + +struct parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token]; +} + +struct parakeet_token_data parakeet_full_get_token_data(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token]; +} + +float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].p; +} + +float parakeet_full_get_token_p(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].p; +} + +void parakeet_log_set(ggml_log_callback log_callback, void * user_data) { + g_state.log_callback = log_callback ? log_callback : parakeet_log_callback_default; + g_state.log_callback_user_data = user_data; + ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); +} + +const char * parakeet_version(void) { + return PARAKEET_VERSION; +} + +GGML_ATTRIBUTE_FORMAT(2, 3) +static void parakeet_log_internal(ggml_log_level level, const char * format, ...) { + va_list args; + va_start(args, format); + char buffer[1024]; + int len = vsnprintf(buffer, 1024, format, args); + if (len < 1024) { + g_state.log_callback(level, buffer, g_state.log_callback_user_data); + } else { + char* buffer2 = new char[len+1]; + vsnprintf(buffer2, len+1, format, args); + buffer2[len] = 0; + g_state.log_callback(level, buffer2, g_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args); +} + +static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; +#ifndef PARAKEET_DEBUG + if (level == GGML_LOG_LEVEL_DEBUG) { + return; + } +#endif + fputs(text, stderr); + fflush(stderr); +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 646f45f2ab7..74a5b142948 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -118,3 +118,62 @@ target_compile_definitions(${VAD_TEST} PRIVATE SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST}) set_tests_properties(${VAD_TEST} PROPERTIES LABELS "base;en") + +# Parakeet model loading test +set(PARAKEET_TEST test-parakeet) +add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp) +target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples) +target_link_libraries(${PARAKEET_TEST} PRIVATE parakeet common) +target_compile_definitions(${PARAKEET_TEST} PRIVATE + PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/for-tests-ggml-parakeet-tdt.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") +add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST}) +set_tests_properties(${PARAKEET_TEST} PROPERTIES LABELS "parakeet;gh") + +# The following parakeet test require a real ggml-parakeet-tdt model to have +# been converted or downloaded: +# $ hf download danbev/parakeet parakeet-tdt-0.6b-v3-f32.bin --local-dir models +# +# And also required more audio samples that are shipped by default. These can +# downloaded by running: +# $ make samples +function(add_parakeet_transcription_test TEST_TARGET TEST_SOURCE SAMPLE_PATH EXPECTED_TRANSCRIPTION_PATH) + set(TRANSCRIPTION_SIMILARITY_THRESHOLD "1.0") + if (ARGC GREATER 4) + set(TRANSCRIPTION_SIMILARITY_THRESHOLD "${ARGV4}") + endif() + + add_executable(${TEST_TARGET} ${TEST_SOURCE}) + target_include_directories(${TEST_TARGET} PRIVATE ../include ../ggml/include ../examples) + target_link_libraries(${TEST_TARGET} PRIVATE parakeet common) + target_compile_definitions(${TEST_TARGET} PRIVATE + PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/ggml-parakeet-tdt-0.6b-v3-f32.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/${SAMPLE_PATH}" + EXPECTED_TRANSCRIPTION_PATH="${PROJECT_SOURCE_DIR}/${EXPECTED_TRANSCRIPTION_PATH}" + TRANSCRIPTION_SIMILARITY_THRESHOLD=${TRANSCRIPTION_SIMILARITY_THRESHOLD}) + + add_custom_target(run-${TEST_TARGET} + COMMAND $<TARGET_FILE:${TEST_TARGET}> + DEPENDS ${TEST_TARGET} + WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) +endfunction() + +add_parakeet_transcription_test( + test-parakeet-full-jfk + test-parakeet-full.cpp + samples/jfk.wav + tests/parakeet-expected-jfk-output.txt) + +add_parakeet_transcription_test( + test-parakeet-full-gb1 + test-parakeet-full.cpp + samples/gb1.wav + tests/parakeet-expected-gb1-output.txt) + +add_parakeet_transcription_test( + test-parakeet-full-diffusion + test-parakeet-full.cpp + samples/diffusion2023-07-03.flac + tests/parakeet-expected-diffusion-output.txt + 0.95) + diff --git a/tests/librispeech-parakeet/.gitignore b/tests/librispeech-parakeet/.gitignore new file mode 100644 index 00000000000..838bfeae9db --- /dev/null +++ b/tests/librispeech-parakeet/.gitignore @@ -0,0 +1,6 @@ +__pycache__ +*.tar.gz +*.txt +eval.conf +venv +LibriSpeech diff --git a/tests/librispeech-parakeet/Makefile b/tests/librispeech-parakeet/Makefile new file mode 100644 index 00000000000..0afa2465f49 --- /dev/null +++ b/tests/librispeech-parakeet/Makefile @@ -0,0 +1,15 @@ +TAR_URL = https://www.openslr.org/resources/12/test-clean.tar.gz + +all: eval + +eval: + $(MAKE) -f eval.mk + +clean: + $(MAKE) -f eval.mk clean + +get-audio: + wget -c $(TAR_URL) + tar -xf test-clean.tar.gz + +.PHONY: all eval clean setup-venv clean-venv get-audio diff --git a/tests/librispeech-parakeet/README.md b/tests/librispeech-parakeet/README.md new file mode 100644 index 00000000000..e09cba405ef --- /dev/null +++ b/tests/librispeech-parakeet/README.md @@ -0,0 +1,57 @@ +# parakeet.cpp/tests/librispeech + +[LibriSpeech](https://www.openslr.org/12) is a standard dataset for +training and evaluating automatic speech recognition systems. + +This directory contains a set of tools to evaluate the recognition +performance of parakeet.cpp on LibriSpeech corpus. + +## Quick Start + +1. (Pre-requirement) Compile `parakeet-cli` and prepare the Parakeet + model in `ggml` format. + + ``` + $ # Execute the commands below in the project root dir. + $ cmake -B build + $ cmake --build build --config Release + ``` + +2. Download the audio files from LibriSpeech project. + + ``` + $ make get-audio + ``` + +3. Set up the environment to compute WER score. + + ``` + $ pip install -r requirements.txt + ``` + + For example, if you use `virtualenv`, you can set up it as follows: + + ``` + $ python3 -m venv venv + $ . venv/bin/activate + $ pip install -r requirements.txt + ``` + +4. Run the benchmark test. + + ``` + $ make + ``` + +## How-to guides + +### How to change the inference parameters + +Create `eval.conf` and override variables. + +``` +PARAKEET_MODEL = parakeet-tdt-0.6b-v3 +PARAKEET_FLAGS = --no-prints --threads 8 --language en --output-txt +``` + +Check out `eval.mk` for more details. diff --git a/tests/librispeech-parakeet/eval.mk b/tests/librispeech-parakeet/eval.mk new file mode 100644 index 00000000000..7d8992ec471 --- /dev/null +++ b/tests/librispeech-parakeet/eval.mk @@ -0,0 +1,39 @@ +PYTHON = python + +PARAKEET_PREFIX = ../../ +PARAKEET_MODEL = parakeet-tdt-0.6b-v3 + +PARAKEET_CLI = $(PARAKEET_PREFIX)build/bin/parakeet-cli +PARAKEET_FLAGS = --no-prints --output-txt + +# You can create eval.conf to override the PARAKEET_* variables +# defined above. +-include eval.conf + +# This follows the file structure of the LibriSpeech project. +AUDIO_SRCS = $(sort $(wildcard LibriSpeech/*/*/*/*.flac)) +TRANS_TXTS = $(addsuffix .txt, $(AUDIO_SRCS)) + +# We output the evaluation result to this file. +DONE = $(PARAKEET_MODEL).txt + +all: $(DONE) + +$(DONE): $(TRANS_TXTS) + $(PYTHON) eval.py > $@.tmp + mv $@.tmp $@ + +# Note: This task writes to a temporary file first to +# create the target file atomically. +%.flac.txt: %.flac + $(PARAKEET_CLI) $(PARAKEET_FLAGS) --model $(PARAKEET_PREFIX)models/ggml-$(PARAKEET_MODEL).bin --file $^ --output-file $^.tmp + mv $^.tmp.txt $^.txt + +archive: + tar -czf $(PARAKEET_MODEL).tar.gz --exclude="*.flac" LibriSpeech $(DONE) + +clean: + @rm -f $(TRANS_TXTS) + @rm -f $(DONE) + +.PHONY: all clean diff --git a/tests/librispeech-parakeet/eval.py b/tests/librispeech-parakeet/eval.py new file mode 100644 index 00000000000..cdaf8352fd4 --- /dev/null +++ b/tests/librispeech-parakeet/eval.py @@ -0,0 +1,47 @@ +import os +import glob +import jiwer +from normalizers import EnglishTextNormalizer + +def get_reference(): + ref = {} + for path in glob.glob('LibriSpeech/*/*/*/*.trans.txt'): + with open(path) as fp: + for line in fp: + code, text = line.strip().split(" ", maxsplit=1) + ref [code] = text + return ref + +def get_hypothesis(): + hyp = {} + for path in glob.glob('LibriSpeech/*/*/*/*.flac.txt'): + with open(path) as fp: + text = fp.read().strip() + code = os.path.basename(path).replace('.flac.txt', '') + hyp[code] = text + return hyp + +def get_codes(): + codes = [] + for path in glob.glob('LibriSpeech/*/*/*/*.flac'): + codes.append(os.path.basename(path).replace('.flac', '')) + return sorted(codes) + +def main(): + normalizer = EnglishTextNormalizer() + + ref_orig = get_reference() + hyp_orig = get_hypothesis() + + ref_clean = [] + hyp_clean = [] + + for code in get_codes(): + ref_clean.append(normalizer(ref_orig[code])) + hyp_clean.append(normalizer(hyp_orig[code])) + + wer = jiwer.wer(ref_clean, hyp_clean) + print(f"WER: {wer * 100:.2f}%") + +if __name__ == '__main__': + main() diff --git a/tests/librispeech-parakeet/normalizers/LICENSE b/tests/librispeech-parakeet/normalizers/LICENSE new file mode 100644 index 00000000000..7c8e603b0fc --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/LICENSE @@ -0,0 +1,25 @@ +Code in this directory is adapted from OpenAI Whisper project +(https://github.com/openai/whisper) and carries the following +copyright and license. + + MIT License + + Copyright (c) 2022 OpenAI + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/tests/librispeech-parakeet/normalizers/__init__.py b/tests/librispeech-parakeet/normalizers/__init__.py new file mode 100644 index 00000000000..896d5e33641 --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/__init__.py @@ -0,0 +1,2 @@ +from .basic import BasicTextNormalizer as BasicTextNormalizer +from .english import EnglishTextNormalizer as EnglishTextNormalizer diff --git a/tests/librispeech-parakeet/normalizers/basic.py b/tests/librispeech-parakeet/normalizers/basic.py new file mode 100644 index 00000000000..8690ae71c5f --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/basic.py @@ -0,0 +1,80 @@ +import re +import unicodedata + +import regex + +# non-ASCII letters that are not separated by "NFKD" normalization +ADDITIONAL_DIACRITICS = { + "œ": "oe", + "Œ": "OE", + "ø": "o", + "Ø": "O", + "æ": "ae", + "Æ": "AE", + "ß": "ss", + "ẞ": "SS", + "đ": "d", + "Đ": "D", + "ð": "d", + "Ð": "D", + "þ": "th", + "Þ": "th", + "ł": "l", + "Ł": "L", +} + + +def remove_symbols_and_diacritics(s: str, keep=""): + """ + Replace any other markers, symbols, and punctuations with a space, + and drop any diacritics (category 'Mn' and some manual mappings) + """ + return "".join( + ( + c + if c in keep + else ( + ADDITIONAL_DIACRITICS[c] + if c in ADDITIONAL_DIACRITICS + else ( + "" + if unicodedata.category(c) == "Mn" + else " " if unicodedata.category(c)[0] in "MSP" else c + ) + ) + ) + for c in unicodedata.normalize("NFKD", s) + ) + + +def remove_symbols(s: str): + """ + Replace any other markers, symbols, punctuations with a space, keeping diacritics + """ + return "".join( + " " if unicodedata.category(c)[0] in "MSP" else c + for c in unicodedata.normalize("NFKC", s) + ) + + +class BasicTextNormalizer: + def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): + self.clean = ( + remove_symbols_and_diacritics if remove_diacritics else remove_symbols + ) + self.split_letters = split_letters + + def __call__(self, s: str): + s = s.lower() + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = self.clean(s).lower() + + if self.split_letters: + s = " ".join(regex.findall(r"\X", s, regex.U)) + + s = re.sub( + r"\s+", " ", s + ) # replace any successive whitespace characters with a space + + return s diff --git a/tests/librispeech-parakeet/normalizers/english.json b/tests/librispeech-parakeet/normalizers/english.json new file mode 100644 index 00000000000..74a1c3521d9 --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/english.json @@ -0,0 +1,1741 @@ +{ + "accessorise": "accessorize", + "accessorised": "accessorized", + "accessorises": "accessorizes", + "accessorising": "accessorizing", + "acclimatisation": "acclimatization", + "acclimatise": "acclimatize", + "acclimatised": "acclimatized", + "acclimatises": "acclimatizes", + "acclimatising": "acclimatizing", + "accoutrements": "accouterments", + "aeon": "eon", + "aeons": "eons", + "aerogramme": "aerogram", + "aerogrammes": "aerograms", + "aeroplane": "airplane", + "aeroplanes": "airplanes", + "aesthete": "esthete", + "aesthetes": "esthetes", + "aesthetic": "esthetic", + "aesthetically": "esthetically", + "aesthetics": "esthetics", + "aetiology": "etiology", + "ageing": "aging", + "aggrandisement": "aggrandizement", + "agonise": "agonize", + "agonised": "agonized", + "agonises": "agonizes", + "agonising": "agonizing", + "agonisingly": "agonizingly", + "almanack": "almanac", + "almanacks": "almanacs", + "aluminium": "aluminum", + "amortisable": "amortizable", + "amortisation": "amortization", + "amortisations": "amortizations", + "amortise": "amortize", + "amortised": "amortized", + "amortises": "amortizes", + "amortising": "amortizing", + "amphitheatre": "amphitheater", + "amphitheatres": "amphitheaters", + "anaemia": "anemia", + "anaemic": "anemic", + "anaesthesia": "anesthesia", + "anaesthetic": "anesthetic", + "anaesthetics": "anesthetics", + "anaesthetise": "anesthetize", + "anaesthetised": "anesthetized", + "anaesthetises": "anesthetizes", + "anaesthetising": "anesthetizing", + "anaesthetist": "anesthetist", + "anaesthetists": "anesthetists", + "anaesthetize": "anesthetize", + "anaesthetized": "anesthetized", + "anaesthetizes": "anesthetizes", + "anaesthetizing": "anesthetizing", + "analogue": "analog", + "analogues": "analogs", + "analyse": "analyze", + "analysed": "analyzed", + "analyses": "analyzes", + "analysing": "analyzing", + "anglicise": "anglicize", + "anglicised": "anglicized", + "anglicises": "anglicizes", + "anglicising": "anglicizing", + "annualised": "annualized", + "antagonise": "antagonize", + "antagonised": "antagonized", + "antagonises": "antagonizes", + "antagonising": "antagonizing", + "apologise": "apologize", + "apologised": "apologized", + "apologises": "apologizes", + "apologising": "apologizing", + "appal": "appall", + "appals": "appalls", + "appetiser": "appetizer", + "appetisers": "appetizers", + "appetising": "appetizing", + "appetisingly": "appetizingly", + "arbour": "arbor", + "arbours": "arbors", + "archeological": "archaeological", + "archaeologically": "archeologically", + "archaeologist": "archeologist", + "archaeologists": "archeologists", + "archaeology": "archeology</span>", + "ardour": "ardor", + "armour": "armor", + "armoured": "armored", + "armourer": "armorer", + "armourers": "armorers", + "armouries": "armories", + "armoury": "armory", + "artefact": "artifact", + "artefacts": "artifacts", + "authorise": "authorize", + "authorised": "authorized", + "authorises": "authorizes", + "authorising": "authorizing", + "axe": "ax", + "backpedalled": "backpedaled", + "backpedalling": "backpedaling", + "bannister": "banister", + "bannisters": "banisters", + "baptise": "baptize", + "baptised": "baptized", + "baptises": "baptizes", + "baptising": "baptizing", + "bastardise": "bastardize", + "bastardised": "bastardized", + "bastardises": "bastardizes", + "bastardising": "bastardizing", + "battleax": "battleaxe", + "baulk": "balk", + "baulked": "balked", + "baulking": "balking", + "baulks": "balks", + "bedevilled": "bedeviled", + "bedevilling": "bedeviling", + "behaviour": "behavior", + "behavioural": "behavioral", + "behaviourism": "behaviorism", + "behaviourist": "behaviorist", + "behaviourists": "behaviorists", + "behaviours": "behaviors", + "behove": "behoove", + "behoved": "behooved", + "behoves": "behooves", + "bejewelled": "bejeweled", + "belabour": "belabor", + "belaboured": "belabored", + "belabouring": "belaboring", + "belabours": "belabors", + "bevelled": "beveled", + "bevvies": "bevies", + "bevvy": "bevy", + "biassed": "biased", + "biassing": "biasing", + "bingeing": "binging", + "bougainvillaea": "bougainvillea", + "bougainvillaeas": "bougainvilleas", + "bowdlerise": "bowdlerize", + "bowdlerised": "bowdlerized", + "bowdlerises": "bowdlerizes", + "bowdlerising": "bowdlerizing", + "breathalyse": "breathalyze", + "breathalysed": "breathalyzed", + "breathalyser": "breathalyzer", + "breathalysers": "breathalyzers", + "breathalyses": "breathalyzes", + "breathalysing": "breathalyzing", + "brutalise": "brutalize", + "brutalised": "brutalized", + "brutalises": "brutalizes", + "brutalising": "brutalizing", + "busses": "buses", + "bussing": "busing", + "caesarean": "cesarean", + "caesareans": "cesareans", + "calibre": "caliber", + "calibres": "calibers", + "calliper": "caliper", + "callipers": "calipers", + "callisthenics": "calisthenics", + "canalise": "canalize", + "canalised": "canalized", + "canalises": "canalizes", + "canalising": "canalizing", + "cancelation": "cancellation", + "cancelations": "cancellations", + "cancelled": "canceled", + "cancelling": "canceling", + "candour": "candor", + "cannibalise": "cannibalize", + "cannibalised": "cannibalized", + "cannibalises": "cannibalizes", + "cannibalising": "cannibalizing", + "canonise": "canonize", + "canonised": "canonized", + "canonises": "canonizes", + "canonising": "canonizing", + "capitalise": "capitalize", + "capitalised": "capitalized", + "capitalises": "capitalizes", + "capitalising": "capitalizing", + "caramelise": "caramelize", + "caramelised": "caramelized", + "caramelises": "caramelizes", + "caramelising": "caramelizing", + "carbonise": "carbonize", + "carbonised": "carbonized", + "carbonises": "carbonizes", + "carbonising": "carbonizing", + "carolled": "caroled", + "carolling": "caroling", + "catalogue": "catalog", + "catalogued": "cataloged", + "catalogues": "catalogs", + "cataloguing": "cataloging", + "catalyse": "catalyze", + "catalysed": "catalyzed", + "catalyses": "catalyzes", + "catalysing": "catalyzing", + "categorise": "categorize", + "categorised": "categorized", + "categorises": "categorizes", + "categorising": "categorizing", + "cauterise": "cauterize", + "cauterised": "cauterized", + "cauterises": "cauterizes", + "cauterising": "cauterizing", + "cavilled": "caviled", + "cavilling": "caviling", + "centigramme": "centigram", + "centigrammes": "centigrams", + "centilitre": "centiliter", + "centilitres": "centiliters", + "centimetre": "centimeter", + "centimetres": "centimeters", + "centralise": "centralize", + "centralised": "centralized", + "centralises": "centralizes", + "centralising": "centralizing", + "centre": "center", + "centred": "centered", + "centrefold": "centerfold", + "centrefolds": "centerfolds", + "centrepiece": "centerpiece", + "centrepieces": "centerpieces", + "centres": "centers", + "channelled": "channeled", + "channelling": "channeling", + "characterise": "characterize", + "characterised": "characterized", + "characterises": "characterizes", + "characterising": "characterizing", + "cheque": "check", + "chequebook": "checkbook", + "chequebooks": "checkbooks", + "chequered": "checkered", + "cheques": "checks", + "chilli": "chili", + "chimaera": "chimera", + "chimaeras": "chimeras", + "chiselled": "chiseled", + "chiselling": "chiseling", + "circularise": "circularize", + "circularised": "circularized", + "circularises": "circularizes", + "circularising": "circularizing", + "civilise": "civilize", + "civilised": "civilized", + "civilises": "civilizes", + "civilising": "civilizing", + "clamour": "clamor", + "clamoured": "clamored", + "clamouring": "clamoring", + "clamours": "clamors", + "clangour": "clangor", + "clarinettist": "clarinetist", + "clarinettists": "clarinetists", + "collectivise": "collectivize", + "collectivised": "collectivized", + "collectivises": "collectivizes", + "collectivising": "collectivizing", + "colonisation": "colonization", + "colonise": "colonize", + "colonised": "colonized", + "coloniser": "colonizer", + "colonisers": "colonizers", + "colonises": "colonizes", + "colonising": "colonizing", + "colour": "color", + "colourant": "colorant", + "colourants": "colorants", + "coloured": "colored", + "coloureds": "coloreds", + "colourful": "colorful", + "colourfully": "colorfully", + "colouring": "coloring", + "colourize": "colorize", + "colourized": "colorized", + "colourizes": "colorizes", + "colourizing": "colorizing", + "colourless": "colorless", + "colours": "colors", + "commercialise": "commercialize", + "commercialised": "commercialized", + "commercialises": "commercializes", + "commercialising": "commercializing", + "compartmentalise": "compartmentalize", + "compartmentalised": "compartmentalized", + "compartmentalises": "compartmentalizes", + "compartmentalising": "compartmentalizing", + "computerise": "computerize", + "computerised": "computerized", + "computerises": "computerizes", + "computerising": "computerizing", + "conceptualise": "conceptualize", + "conceptualised": "conceptualized", + "conceptualises": "conceptualizes", + "conceptualising": "conceptualizing", + "connexion": "connection", + "connexions": "connections", + "contextualise": "contextualize", + "contextualised": "contextualized", + "contextualises": "contextualizes", + "contextualising": "contextualizing", + "cosier": "cozier", + "cosies": "cozies", + "cosiest": "coziest", + "cosily": "cozily", + "cosiness": "coziness", + "cosy": "cozy", + "councillor": "councilor", + "councillors": "councilors", + "counselled": "counseled", + "counselling": "counseling", + "counsellor": "counselor", + "counsellors": "counselors", + "crenelated": "crenellated", + "criminalise": "criminalize", + "criminalised": "criminalized", + "criminalises": "criminalizes", + "criminalising": "criminalizing", + "criticise": "criticize", + "criticised": "criticized", + "criticises": "criticizes", + "criticising": "criticizing", + "crueller": "crueler", + "cruellest": "cruelest", + "crystallisation": "crystallization", + "crystallise": "crystallize", + "crystallised": "crystallized", + "crystallises": "crystallizes", + "crystallising": "crystallizing", + "cudgelled": "cudgeled", + "cudgelling": "cudgeling", + "customise": "customize", + "customised": "customized", + "customises": "customizes", + "customising": "customizing", + "cypher": "cipher", + "cyphers": "ciphers", + "decentralisation": "decentralization", + "decentralise": "decentralize", + "decentralised": "decentralized", + "decentralises": "decentralizes", + "decentralising": "decentralizing", + "decriminalisation": "decriminalization", + "decriminalise": "decriminalize", + "decriminalised": "decriminalized", + "decriminalises": "decriminalizes", + "decriminalising": "decriminalizing", + "defence": "defense", + "defenceless": "defenseless", + "defences": "defenses", + "dehumanisation": "dehumanization", + "dehumanise": "dehumanize", + "dehumanised": "dehumanized", + "dehumanises": "dehumanizes", + "dehumanising": "dehumanizing", + "demeanour": "demeanor", + "demilitarisation": "demilitarization", + "demilitarise": "demilitarize", + "demilitarised": "demilitarized", + "demilitarises": "demilitarizes", + "demilitarising": "demilitarizing", + "demobilisation": "demobilization", + "demobilise": "demobilize", + "demobilised": "demobilized", + "demobilises": "demobilizes", + "demobilising": "demobilizing", + "democratisation": "democratization", + "democratise": "democratize", + "democratised": "democratized", + "democratises": "democratizes", + "democratising": "democratizing", + "demonise": "demonize", + "demonised": "demonized", + "demonises": "demonizes", + "demonising": "demonizing", + "demoralisation": "demoralization", + "demoralise": "demoralize", + "demoralised": "demoralized", + "demoralises": "demoralizes", + "demoralising": "demoralizing", + "denationalisation": "denationalization", + "denationalise": "denationalize", + "denationalised": "denationalized", + "denationalises": "denationalizes", + "denationalising": "denationalizing", + "deodorise": "deodorize", + "deodorised": "deodorized", + "deodorises": "deodorizes", + "deodorising": "deodorizing", + "depersonalise": "depersonalize", + "depersonalised": "depersonalized", + "depersonalises": "depersonalizes", + "depersonalising": "depersonalizing", + "deputise": "deputize", + "deputised": "deputized", + "deputises": "deputizes", + "deputising": "deputizing", + "desensitisation": "desensitization", + "desensitise": "desensitize", + "desensitised": "desensitized", + "desensitises": "desensitizes", + "desensitising": "desensitizing", + "destabilisation": "destabilization", + "destabilise": "destabilize", + "destabilised": "destabilized", + "destabilises": "destabilizes", + "destabilising": "destabilizing", + "dialled": "dialed", + "dialling": "dialing", + "dialogue": "dialog", + "dialogues": "dialogs", + "diarrhoea": "diarrhea", + "digitise": "digitize", + "digitised": "digitized", + "digitises": "digitizes", + "digitising": "digitizing", + "disc": "disk", + "discolour": "discolor", + "discoloured": "discolored", + "discolouring": "discoloring", + "discolours": "discolors", + "discs": "disks", + "disembowelled": "disemboweled", + "disembowelling": "disemboweling", + "disfavour": "disfavor", + "dishevelled": "disheveled", + "dishonour": "dishonor", + "dishonourable": "dishonorable", + "dishonourably": "dishonorably", + "dishonoured": "dishonored", + "dishonouring": "dishonoring", + "dishonours": "dishonors", + "disorganisation": "disorganization", + "disorganised": "disorganized", + "distil": "distill", + "distils": "distills", + "dramatisation": "dramatization", + "dramatisations": "dramatizations", + "dramatise": "dramatize", + "dramatised": "dramatized", + "dramatises": "dramatizes", + "dramatising": "dramatizing", + "draught": "draft", + "draughtboard": "draftboard", + "draughtboards": "draftboards", + "draughtier": "draftier", + "draughtiest": "draftiest", + "draughts": "drafts", + "draughtsman": "draftsman", + "draughtsmanship": "draftsmanship", + "draughtsmen": "draftsmen", + "draughtswoman": "draftswoman", + "draughtswomen": "draftswomen", + "draughty": "drafty", + "drivelled": "driveled", + "drivelling": "driveling", + "duelled": "dueled", + "duelling": "dueling", + "economise": "economize", + "economised": "economized", + "economises": "economizes", + "economising": "economizing", + "edoema": "edema", + "editorialise": "editorialize", + "editorialised": "editorialized", + "editorialises": "editorializes", + "editorialising": "editorializing", + "empathise": "empathize", + "empathised": "empathized", + "empathises": "empathizes", + "empathising": "empathizing", + "emphasise": "emphasize", + "emphasised": "emphasized", + "emphasises": "emphasizes", + "emphasising": "emphasizing", + "enamelled": "enameled", + "enamelling": "enameling", + "enamoured": "enamored", + "encyclopaedia": "encyclopedia", + "encyclopaedias": "encyclopedias", + "encyclopaedic": "encyclopedic", + "endeavour": "endeavor", + "endeavoured": "endeavored", + "endeavouring": "endeavoring", + "endeavours": "endeavors", + "energise": "energize", + "energised": "energized", + "energises": "energizes", + "energising": "energizing", + "enrol": "enroll", + "enrols": "enrolls", + "enthral": "enthrall", + "enthrals": "enthralls", + "epaulette": "epaulet", + "epaulettes": "epaulets", + "epicentre": "epicenter", + "epicentres": "epicenters", + "epilogue": "epilog", + "epilogues": "epilogs", + "epitomise": "epitomize", + "epitomised": "epitomized", + "epitomises": "epitomizes", + "epitomising": "epitomizing", + "equalisation": "equalization", + "equalise": "equalize", + "equalised": "equalized", + "equaliser": "equalizer", + "equalisers": "equalizers", + "equalises": "equalizes", + "equalising": "equalizing", + "eulogise": "eulogize", + "eulogised": "eulogized", + "eulogises": "eulogizes", + "eulogising": "eulogizing", + "evangelise": "evangelize", + "evangelised": "evangelized", + "evangelises": "evangelizes", + "evangelising": "evangelizing", + "exorcise": "exorcize", + "exorcised": "exorcized", + "exorcises": "exorcizes", + "exorcising": "exorcizing", + "extemporisation": "extemporization", + "extemporise": "extemporize", + "extemporised": "extemporized", + "extemporises": "extemporizes", + "extemporising": "extemporizing", + "externalisation": "externalization", + "externalisations": "externalizations", + "externalise": "externalize", + "externalised": "externalized", + "externalises": "externalizes", + "externalising": "externalizing", + "factorise": "factorize", + "factorised": "factorized", + "factorises": "factorizes", + "factorising": "factorizing", + "faecal": "fecal", + "faeces": "feces", + "familiarisation": "familiarization", + "familiarise": "familiarize", + "familiarised": "familiarized", + "familiarises": "familiarizes", + "familiarising": "familiarizing", + "fantasise": "fantasize", + "fantasised": "fantasized", + "fantasises": "fantasizes", + "fantasising": "fantasizing", + "favour": "favor", + "favourable": "favorable", + "favourably": "favorably", + "favoured": "favored", + "favouring": "favoring", + "favourite": "favorite", + "favourites": "favorites", + "favouritism": "favoritism", + "favours": "favors", + "feminise": "feminize", + "feminised": "feminized", + "feminises": "feminizes", + "feminising": "feminizing", + "fertilisation": "fertilization", + "fertilise": "fertilize", + "fertilised": "fertilized", + "fertiliser": "fertilizer", + "fertilisers": "fertilizers", + "fertilises": "fertilizes", + "fertilising": "fertilizing", + "fervour": "fervor", + "fibre": "fiber", + "fibreglass": "fiberglass", + "fibres": "fibers", + "fictionalisation": "fictionalization", + "fictionalisations": "fictionalizations", + "fictionalise": "fictionalize", + "fictionalised": "fictionalized", + "fictionalises": "fictionalizes", + "fictionalising": "fictionalizing", + "fillet": "filet", + "filleted": "fileted", + "filleting": "fileting", + "fillets": "filets", + "finalisation": "finalization", + "finalise": "finalize", + "finalised": "finalized", + "finalises": "finalizes", + "finalising": "finalizing", + "flautist": "flutist", + "flautists": "flutists", + "flavour": "flavor", + "flavoured": "flavored", + "flavouring": "flavoring", + "flavourings": "flavorings", + "flavourless": "flavorless", + "flavours": "flavors", + "flavoursome": "flavorsome", + "flyer / flier": "flier / flyer", + "foetal": "fetal", + "foetid": "fetid", + "foetus": "fetus", + "foetuses": "fetuses", + "formalisation": "formalization", + "formalise": "formalize", + "formalised": "formalized", + "formalises": "formalizes", + "formalising": "formalizing", + "fossilisation": "fossilization", + "fossilise": "fossilize", + "fossilised": "fossilized", + "fossilises": "fossilizes", + "fossilising": "fossilizing", + "fraternisation": "fraternization", + "fraternise": "fraternize", + "fraternised": "fraternized", + "fraternises": "fraternizes", + "fraternising": "fraternizing", + "fulfil": "fulfill", + "fulfilment": "fulfillment", + "fulfils": "fulfills", + "funnelled": "funneled", + "funnelling": "funneling", + "galvanise": "galvanize", + "galvanised": "galvanized", + "galvanises": "galvanizes", + "galvanising": "galvanizing", + "gambolled": "gamboled", + "gambolling": "gamboling", + "gaol": "jail", + "gaolbird": "jailbird", + "gaolbirds": "jailbirds", + "gaolbreak": "jailbreak", + "gaolbreaks": "jailbreaks", + "gaoled": "jailed", + "gaoler": "jailer", + "gaolers": "jailers", + "gaoling": "jailing", + "gaols": "jails", + "gasses": "gases", + "gage": "gauge", + "gaged": "gauged", + "gages": "gauges", + "gaging": "gauging", + "generalisation": "generalization", + "generalisations": "generalizations", + "generalise": "generalize", + "generalised": "generalized", + "generalises": "generalizes", + "generalising": "generalizing", + "ghettoise": "ghettoize", + "ghettoised": "ghettoized", + "ghettoises": "ghettoizes", + "ghettoising": "ghettoizing", + "gipsies": "gypsies", + "glamorise": "glamorize", + "glamorised": "glamorized", + "glamorises": "glamorizes", + "glamorising": "glamorizing", + "glamor": "glamour", + "globalisation": "globalization", + "globalise": "globalize", + "globalised": "globalized", + "globalises": "globalizes", + "globalising": "globalizing", + "glueing": "gluing", + "goitre": "goiter", + "goitres": "goiters", + "gonorrhoea": "gonorrhea", + "gramme": "gram", + "grammes": "grams", + "gravelled": "graveled", + "grey": "gray", + "greyed": "grayed", + "greying": "graying", + "greyish": "grayish", + "greyness": "grayness", + "greys": "grays", + "grovelled": "groveled", + "grovelling": "groveling", + "groyne": "groin", + "groynes": "groins", + "gruelling": "grueling", + "gruellingly": "gruelingly", + "gryphon": "griffin", + "gryphons": "griffins", + "gynaecological": "gynecological", + "gynaecologist": "gynecologist", + "gynaecologists": "gynecologists", + "gynaecology": "gynecology", + "haematological": "hematological", + "haematologist": "hematologist", + "haematologists": "hematologists", + "haematology": "hematology", + "haemoglobin": "hemoglobin", + "haemophilia": "hemophilia", + "haemophiliac": "hemophiliac", + "haemophiliacs": "hemophiliacs", + "haemorrhage": "hemorrhage", + "haemorrhaged": "hemorrhaged", + "haemorrhages": "hemorrhages", + "haemorrhaging": "hemorrhaging", + "haemorrhoids": "hemorrhoids", + "harbour": "harbor", + "harboured": "harbored", + "harbouring": "harboring", + "harbours": "harbors", + "harmonisation": "harmonization", + "harmonise": "harmonize", + "harmonised": "harmonized", + "harmonises": "harmonizes", + "harmonising": "harmonizing", + "homoeopath": "homeopath", + "homoeopathic": "homeopathic", + "homoeopaths": "homeopaths", + "homoeopathy": "homeopathy", + "homogenise": "homogenize", + "homogenised": "homogenized", + "homogenises": "homogenizes", + "homogenising": "homogenizing", + "honour": "honor", + "honourable": "honorable", + "honourably": "honorably", + "honoured": "honored", + "honouring": "honoring", + "honours": "honors", + "hospitalisation": "hospitalization", + "hospitalise": "hospitalize", + "hospitalised": "hospitalized", + "hospitalises": "hospitalizes", + "hospitalising": "hospitalizing", + "humanise": "humanize", + "humanised": "humanized", + "humanises": "humanizes", + "humanising": "humanizing", + "humour": "humor", + "humoured": "humored", + "humouring": "humoring", + "humourless": "humorless", + "humours": "humors", + "hybridise": "hybridize", + "hybridised": "hybridized", + "hybridises": "hybridizes", + "hybridising": "hybridizing", + "hypnotise": "hypnotize", + "hypnotised": "hypnotized", + "hypnotises": "hypnotizes", + "hypnotising": "hypnotizing", + "hypothesise": "hypothesize", + "hypothesised": "hypothesized", + "hypothesises": "hypothesizes", + "hypothesising": "hypothesizing", + "idealisation": "idealization", + "idealise": "idealize", + "idealised": "idealized", + "idealises": "idealizes", + "idealising": "idealizing", + "idolise": "idolize", + "idolised": "idolized", + "idolises": "idolizes", + "idolising": "idolizing", + "immobilisation": "immobilization", + "immobilise": "immobilize", + "immobilised": "immobilized", + "immobiliser": "immobilizer", + "immobilisers": "immobilizers", + "immobilises": "immobilizes", + "immobilising": "immobilizing", + "immortalise": "immortalize", + "immortalised": "immortalized", + "immortalises": "immortalizes", + "immortalising": "immortalizing", + "immunisation": "immunization", + "immunise": "immunize", + "immunised": "immunized", + "immunises": "immunizes", + "immunising": "immunizing", + "impanelled": "impaneled", + "impanelling": "impaneling", + "imperilled": "imperiled", + "imperilling": "imperiling", + "individualise": "individualize", + "individualised": "individualized", + "individualises": "individualizes", + "individualising": "individualizing", + "industrialise": "industrialize", + "industrialised": "industrialized", + "industrialises": "industrializes", + "industrialising": "industrializing", + "inflexion": "inflection", + "inflexions": "inflections", + "initialise": "initialize", + "initialised": "initialized", + "initialises": "initializes", + "initialising": "initializing", + "initialled": "initialed", + "initialling": "initialing", + "instal": "install", + "instalment": "installment", + "instalments": "installments", + "instals": "installs", + "instil": "instill", + "instils": "instills", + "institutionalisation": "institutionalization", + "institutionalise": "institutionalize", + "institutionalised": "institutionalized", + "institutionalises": "institutionalizes", + "institutionalising": "institutionalizing", + "intellectualise": "intellectualize", + "intellectualised": "intellectualized", + "intellectualises": "intellectualizes", + "intellectualising": "intellectualizing", + "internalisation": "internalization", + "internalise": "internalize", + "internalised": "internalized", + "internalises": "internalizes", + "internalising": "internalizing", + "internationalisation": "internationalization", + "internationalise": "internationalize", + "internationalised": "internationalized", + "internationalises": "internationalizes", + "internationalising": "internationalizing", + "ionisation": "ionization", + "ionise": "ionize", + "ionised": "ionized", + "ioniser": "ionizer", + "ionisers": "ionizers", + "ionises": "ionizes", + "ionising": "ionizing", + "italicise": "italicize", + "italicised": "italicized", + "italicises": "italicizes", + "italicising": "italicizing", + "itemise": "itemize", + "itemised": "itemized", + "itemises": "itemizes", + "itemising": "itemizing", + "jeopardise": "jeopardize", + "jeopardised": "jeopardized", + "jeopardises": "jeopardizes", + "jeopardising": "jeopardizing", + "jewelled": "jeweled", + "jeweller": "jeweler", + "jewellers": "jewelers", + "jewellery": "jewelry", + "judgement": "judgment", + "kilogramme": "kilogram", + "kilogrammes": "kilograms", + "kilometre": "kilometer", + "kilometres": "kilometers", + "labelled": "labeled", + "labelling": "labeling", + "labour": "labor", + "laboured": "labored", + "labourer": "laborer", + "labourers": "laborers", + "labouring": "laboring", + "labours": "labors", + "lacklustre": "lackluster", + "legalisation": "legalization", + "legalise": "legalize", + "legalised": "legalized", + "legalises": "legalizes", + "legalising": "legalizing", + "legitimise": "legitimize", + "legitimised": "legitimized", + "legitimises": "legitimizes", + "legitimising": "legitimizing", + "leukaemia": "leukemia", + "levelled": "leveled", + "leveller": "leveler", + "levellers": "levelers", + "levelling": "leveling", + "libelled": "libeled", + "libelling": "libeling", + "libellous": "libelous", + "liberalisation": "liberalization", + "liberalise": "liberalize", + "liberalised": "liberalized", + "liberalises": "liberalizes", + "liberalising": "liberalizing", + "licence": "license", + "licenced": "licensed", + "licences": "licenses", + "licencing": "licensing", + "likeable": "likable", + "lionisation": "lionization", + "lionise": "lionize", + "lionised": "lionized", + "lionises": "lionizes", + "lionising": "lionizing", + "liquidise": "liquidize", + "liquidised": "liquidized", + "liquidiser": "liquidizer", + "liquidisers": "liquidizers", + "liquidises": "liquidizes", + "liquidising": "liquidizing", + "litre": "liter", + "litres": "liters", + "localise": "localize", + "localised": "localized", + "localises": "localizes", + "localising": "localizing", + "louvre": "louver", + "louvred": "louvered", + "louvres": "louvers", + "lustre": "luster", + "magnetise": "magnetize", + "magnetised": "magnetized", + "magnetises": "magnetizes", + "magnetising": "magnetizing", + "manoeuvrability": "maneuverability", + "manoeuvrable": "maneuverable", + "manoeuvre": "maneuver", + "manoeuvred": "maneuvered", + "manoeuvres": "maneuvers", + "manoeuvring": "maneuvering", + "manoeuvrings": "maneuverings", + "marginalisation": "marginalization", + "marginalise": "marginalize", + "marginalised": "marginalized", + "marginalises": "marginalizes", + "marginalising": "marginalizing", + "marshalled": "marshaled", + "marshalling": "marshaling", + "marvelled": "marveled", + "marvelling": "marveling", + "marvellous": "marvelous", + "marvellously": "marvelously", + "materialisation": "materialization", + "materialise": "materialize", + "materialised": "materialized", + "materialises": "materializes", + "materialising": "materializing", + "maximisation": "maximization", + "maximise": "maximize", + "maximised": "maximized", + "maximises": "maximizes", + "maximising": "maximizing", + "meagre": "meager", + "mechanisation": "mechanization", + "mechanise": "mechanize", + "mechanised": "mechanized", + "mechanises": "mechanizes", + "mechanising": "mechanizing", + "mediaeval": "medieval", + "memorialise": "memorialize", + "memorialised": "memorialized", + "memorialises": "memorializes", + "memorialising": "memorializing", + "memorise": "memorize", + "memorised": "memorized", + "memorises": "memorizes", + "memorising": "memorizing", + "mesmerise": "mesmerize", + "mesmerised": "mesmerized", + "mesmerises": "mesmerizes", + "mesmerising": "mesmerizing", + "metabolise": "metabolize", + "metabolised": "metabolized", + "metabolises": "metabolizes", + "metabolising": "metabolizing", + "metre": "meter", + "metres": "meters", + "micrometre": "micrometer", + "micrometres": "micrometers", + "militarise": "militarize", + "militarised": "militarized", + "militarises": "militarizes", + "militarising": "militarizing", + "milligramme": "milligram", + "milligrammes": "milligrams", + "millilitre": "milliliter", + "millilitres": "milliliters", + "millimetre": "millimeter", + "millimetres": "millimeters", + "miniaturisation": "miniaturization", + "miniaturise": "miniaturize", + "miniaturised": "miniaturized", + "miniaturises": "miniaturizes", + "miniaturising": "miniaturizing", + "minibusses": "minibuses", + "minimise": "minimize", + "minimised": "minimized", + "minimises": "minimizes", + "minimising": "minimizing", + "misbehaviour": "misbehavior", + "misdemeanour": "misdemeanor", + "misdemeanours": "misdemeanors", + "misspelt": "misspelled", + "mitre": "miter", + "mitres": "miters", + "mobilisation": "mobilization", + "mobilise": "mobilize", + "mobilised": "mobilized", + "mobilises": "mobilizes", + "mobilising": "mobilizing", + "modelled": "modeled", + "modeller": "modeler", + "modellers": "modelers", + "modelling": "modeling", + "modernise": "modernize", + "modernised": "modernized", + "modernises": "modernizes", + "modernising": "modernizing", + "moisturise": "moisturize", + "moisturised": "moisturized", + "moisturiser": "moisturizer", + "moisturisers": "moisturizers", + "moisturises": "moisturizes", + "moisturising": "moisturizing", + "monologue": "monolog", + "monologues": "monologs", + "monopolisation": "monopolization", + "monopolise": "monopolize", + "monopolised": "monopolized", + "monopolises": "monopolizes", + "monopolising": "monopolizing", + "moralise": "moralize", + "moralised": "moralized", + "moralises": "moralizes", + "moralising": "moralizing", + "motorised": "motorized", + "mould": "mold", + "moulded": "molded", + "moulder": "molder", + "mouldered": "moldered", + "mouldering": "moldering", + "moulders": "molders", + "mouldier": "moldier", + "mouldiest": "moldiest", + "moulding": "molding", + "mouldings": "moldings", + "moulds": "molds", + "mouldy": "moldy", + "moult": "molt", + "moulted": "molted", + "moulting": "molting", + "moults": "molts", + "moustache": "mustache", + "moustached": "mustached", + "moustaches": "mustaches", + "moustachioed": "mustachioed", + "multicoloured": "multicolored", + "nationalisation": "nationalization", + "nationalisations": "nationalizations", + "nationalise": "nationalize", + "nationalised": "nationalized", + "nationalises": "nationalizes", + "nationalising": "nationalizing", + "naturalisation": "naturalization", + "naturalise": "naturalize", + "naturalised": "naturalized", + "naturalises": "naturalizes", + "naturalising": "naturalizing", + "neighbour": "neighbor", + "neighbourhood": "neighborhood", + "neighbourhoods": "neighborhoods", + "neighbouring": "neighboring", + "neighbourliness": "neighborliness", + "neighbourly": "neighborly", + "neighbours": "neighbors", + "neutralisation": "neutralization", + "neutralise": "neutralize", + "neutralised": "neutralized", + "neutralises": "neutralizes", + "neutralising": "neutralizing", + "normalisation": "normalization", + "normalise": "normalize", + "normalised": "normalized", + "normalises": "normalizes", + "normalising": "normalizing", + "odour": "odor", + "odourless": "odorless", + "odours": "odors", + "oesophagus": "esophagus", + "oesophaguses": "esophaguses", + "oestrogen": "estrogen", + "offence": "offense", + "offences": "offenses", + "omelette": "omelet", + "omelettes": "omelets", + "optimise": "optimize", + "optimised": "optimized", + "optimises": "optimizes", + "optimising": "optimizing", + "organisation": "organization", + "organisational": "organizational", + "organisations": "organizations", + "organise": "organize", + "organised": "organized", + "organiser": "organizer", + "organisers": "organizers", + "organises": "organizes", + "organising": "organizing", + "orthopaedic": "orthopedic", + "orthopaedics": "orthopedics", + "ostracise": "ostracize", + "ostracised": "ostracized", + "ostracises": "ostracizes", + "ostracising": "ostracizing", + "outmanoeuvre": "outmaneuver", + "outmanoeuvred": "outmaneuvered", + "outmanoeuvres": "outmaneuvers", + "outmanoeuvring": "outmaneuvering", + "overemphasise": "overemphasize", + "overemphasised": "overemphasized", + "overemphasises": "overemphasizes", + "overemphasising": "overemphasizing", + "oxidisation": "oxidization", + "oxidise": "oxidize", + "oxidised": "oxidized", + "oxidises": "oxidizes", + "oxidising": "oxidizing", + "paederast": "pederast", + "paederasts": "pederasts", + "paediatric": "pediatric", + "paediatrician": "pediatrician", + "paediatricians": "pediatricians", + "paediatrics": "pediatrics", + "paedophile": "pedophile", + "paedophiles": "pedophiles", + "paedophilia": "pedophilia", + "palaeolithic": "paleolithic", + "palaeontologist": "paleontologist", + "palaeontologists": "paleontologists", + "palaeontology": "paleontology", + "panelled": "paneled", + "panelling": "paneling", + "panellist": "panelist", + "panellists": "panelists", + "paralyse": "paralyze", + "paralysed": "paralyzed", + "paralyses": "paralyzes", + "paralysing": "paralyzing", + "parcelled": "parceled", + "parcelling": "parceling", + "parlour": "parlor", + "parlours": "parlors", + "particularise": "particularize", + "particularised": "particularized", + "particularises": "particularizes", + "particularising": "particularizing", + "passivisation": "passivization", + "passivise": "passivize", + "passivised": "passivized", + "passivises": "passivizes", + "passivising": "passivizing", + "pasteurisation": "pasteurization", + "pasteurise": "pasteurize", + "pasteurised": "pasteurized", + "pasteurises": "pasteurizes", + "pasteurising": "pasteurizing", + "patronise": "patronize", + "patronised": "patronized", + "patronises": "patronizes", + "patronising": "patronizing", + "patronisingly": "patronizingly", + "pedalled": "pedaled", + "pedalling": "pedaling", + "pedestrianisation": "pedestrianization", + "pedestrianise": "pedestrianize", + "pedestrianised": "pedestrianized", + "pedestrianises": "pedestrianizes", + "pedestrianising": "pedestrianizing", + "penalise": "penalize", + "penalised": "penalized", + "penalises": "penalizes", + "penalising": "penalizing", + "pencilled": "penciled", + "pencilling": "penciling", + "personalise": "personalize", + "personalised": "personalized", + "personalises": "personalizes", + "personalising": "personalizing", + "pharmacopoeia": "pharmacopeia", + "pharmacopoeias": "pharmacopeias", + "philosophise": "philosophize", + "philosophised": "philosophized", + "philosophises": "philosophizes", + "philosophising": "philosophizing", + "philtre": "filter", + "philtres": "filters", + "phoney": "phony", + "plagiarise": "plagiarize", + "plagiarised": "plagiarized", + "plagiarises": "plagiarizes", + "plagiarising": "plagiarizing", + "plough": "plow", + "ploughed": "plowed", + "ploughing": "plowing", + "ploughman": "plowman", + "ploughmen": "plowmen", + "ploughs": "plows", + "ploughshare": "plowshare", + "ploughshares": "plowshares", + "polarisation": "polarization", + "polarise": "polarize", + "polarised": "polarized", + "polarises": "polarizes", + "polarising": "polarizing", + "politicisation": "politicization", + "politicise": "politicize", + "politicised": "politicized", + "politicises": "politicizes", + "politicising": "politicizing", + "popularisation": "popularization", + "popularise": "popularize", + "popularised": "popularized", + "popularises": "popularizes", + "popularising": "popularizing", + "pouffe": "pouf", + "pouffes": "poufs", + "practise": "practice", + "practised": "practiced", + "practises": "practices", + "practising": "practicing", + "praesidium": "presidium", + "praesidiums": "presidiums", + "pressurisation": "pressurization", + "pressurise": "pressurize", + "pressurised": "pressurized", + "pressurises": "pressurizes", + "pressurising": "pressurizing", + "pretence": "pretense", + "pretences": "pretenses", + "primaeval": "primeval", + "prioritisation": "prioritization", + "prioritise": "prioritize", + "prioritised": "prioritized", + "prioritises": "prioritizes", + "prioritising": "prioritizing", + "privatisation": "privatization", + "privatisations": "privatizations", + "privatise": "privatize", + "privatised": "privatized", + "privatises": "privatizes", + "privatising": "privatizing", + "professionalisation": "professionalization", + "professionalise": "professionalize", + "professionalised": "professionalized", + "professionalises": "professionalizes", + "professionalising": "professionalizing", + "programme": "program", + "programmes": "programs", + "prologue": "prolog", + "prologues": "prologs", + "propagandise": "propagandize", + "propagandised": "propagandized", + "propagandises": "propagandizes", + "propagandising": "propagandizing", + "proselytise": "proselytize", + "proselytised": "proselytized", + "proselytiser": "proselytizer", + "proselytisers": "proselytizers", + "proselytises": "proselytizes", + "proselytising": "proselytizing", + "psychoanalyse": "psychoanalyze", + "psychoanalysed": "psychoanalyzed", + "psychoanalyses": "psychoanalyzes", + "psychoanalysing": "psychoanalyzing", + "publicise": "publicize", + "publicised": "publicized", + "publicises": "publicizes", + "publicising": "publicizing", + "pulverisation": "pulverization", + "pulverise": "pulverize", + "pulverised": "pulverized", + "pulverises": "pulverizes", + "pulverising": "pulverizing", + "pummelled": "pummel", + "pummelling": "pummeled", + "pyjama": "pajama", + "pyjamas": "pajamas", + "pzazz": "pizzazz", + "quarrelled": "quarreled", + "quarrelling": "quarreling", + "radicalise": "radicalize", + "radicalised": "radicalized", + "radicalises": "radicalizes", + "radicalising": "radicalizing", + "rancour": "rancor", + "randomise": "randomize", + "randomised": "randomized", + "randomises": "randomizes", + "randomising": "randomizing", + "rationalisation": "rationalization", + "rationalisations": "rationalizations", + "rationalise": "rationalize", + "rationalised": "rationalized", + "rationalises": "rationalizes", + "rationalising": "rationalizing", + "ravelled": "raveled", + "ravelling": "raveling", + "realisable": "realizable", + "realisation": "realization", + "realisations": "realizations", + "realise": "realize", + "realised": "realized", + "realises": "realizes", + "realising": "realizing", + "recognisable": "recognizable", + "recognisably": "recognizably", + "recognisance": "recognizance", + "recognise": "recognize", + "recognised": "recognized", + "recognises": "recognizes", + "recognising": "recognizing", + "reconnoitre": "reconnoiter", + "reconnoitred": "reconnoitered", + "reconnoitres": "reconnoiters", + "reconnoitring": "reconnoitering", + "refuelled": "refueled", + "refuelling": "refueling", + "regularisation": "regularization", + "regularise": "regularize", + "regularised": "regularized", + "regularises": "regularizes", + "regularising": "regularizing", + "remodelled": "remodeled", + "remodelling": "remodeling", + "remould": "remold", + "remoulded": "remolded", + "remoulding": "remolding", + "remoulds": "remolds", + "reorganisation": "reorganization", + "reorganisations": "reorganizations", + "reorganise": "reorganize", + "reorganised": "reorganized", + "reorganises": "reorganizes", + "reorganising": "reorganizing", + "revelled": "reveled", + "reveller": "reveler", + "revellers": "revelers", + "revelling": "reveling", + "revitalise": "revitalize", + "revitalised": "revitalized", + "revitalises": "revitalizes", + "revitalising": "revitalizing", + "revolutionise": "revolutionize", + "revolutionised": "revolutionized", + "revolutionises": "revolutionizes", + "revolutionising": "revolutionizing", + "rhapsodise": "rhapsodize", + "rhapsodised": "rhapsodized", + "rhapsodises": "rhapsodizes", + "rhapsodising": "rhapsodizing", + "rigour": "rigor", + "rigours": "rigors", + "ritualised": "ritualized", + "rivalled": "rivaled", + "rivalling": "rivaling", + "romanticise": "romanticize", + "romanticised": "romanticized", + "romanticises": "romanticizes", + "romanticising": "romanticizing", + "rumour": "rumor", + "rumoured": "rumored", + "rumours": "rumors", + "sabre": "saber", + "sabres": "sabers", + "saltpetre": "saltpeter", + "sanitise": "sanitize", + "sanitised": "sanitized", + "sanitises": "sanitizes", + "sanitising": "sanitizing", + "satirise": "satirize", + "satirised": "satirized", + "satirises": "satirizes", + "satirising": "satirizing", + "saviour": "savior", + "saviours": "saviors", + "savour": "savor", + "savoured": "savored", + "savouries": "savories", + "savouring": "savoring", + "savours": "savors", + "savoury": "savory", + "scandalise": "scandalize", + "scandalised": "scandalized", + "scandalises": "scandalizes", + "scandalising": "scandalizing", + "sceptic": "skeptic", + "sceptical": "skeptical", + "sceptically": "skeptically", + "scepticism": "skepticism", + "sceptics": "skeptics", + "sceptre": "scepter", + "sceptres": "scepters", + "scrutinise": "scrutinize", + "scrutinised": "scrutinized", + "scrutinises": "scrutinizes", + "scrutinising": "scrutinizing", + "secularisation": "secularization", + "secularise": "secularize", + "secularised": "secularized", + "secularises": "secularizes", + "secularising": "secularizing", + "sensationalise": "sensationalize", + "sensationalised": "sensationalized", + "sensationalises": "sensationalizes", + "sensationalising": "sensationalizing", + "sensitise": "sensitize", + "sensitised": "sensitized", + "sensitises": "sensitizes", + "sensitising": "sensitizing", + "sentimentalise": "sentimentalize", + "sentimentalised": "sentimentalized", + "sentimentalises": "sentimentalizes", + "sentimentalising": "sentimentalizing", + "sepulchre": "sepulcher", + "sepulchres": "sepulchers", + "serialisation": "serialization", + "serialisations": "serializations", + "serialise": "serialize", + "serialised": "serialized", + "serialises": "serializes", + "serialising": "serializing", + "sermonise": "sermonize", + "sermonised": "sermonized", + "sermonises": "sermonizes", + "sermonising": "sermonizing", + "sheikh": "sheik", + "shovelled": "shoveled", + "shovelling": "shoveling", + "shrivelled": "shriveled", + "shrivelling": "shriveling", + "signalise": "signalize", + "signalised": "signalized", + "signalises": "signalizes", + "signalising": "signalizing", + "signalled": "signaled", + "signalling": "signaling", + "smoulder": "smolder", + "smouldered": "smoldered", + "smouldering": "smoldering", + "smoulders": "smolders", + "snivelled": "sniveled", + "snivelling": "sniveling", + "snorkelled": "snorkeled", + "snorkelling": "snorkeling", + "snowplough": "snowplow", + "snowploughs": "snowplow", + "socialisation": "socialization", + "socialise": "socialize", + "socialised": "socialized", + "socialises": "socializes", + "socialising": "socializing", + "sodomise": "sodomize", + "sodomised": "sodomized", + "sodomises": "sodomizes", + "sodomising": "sodomizing", + "solemnise": "solemnize", + "solemnised": "solemnized", + "solemnises": "solemnizes", + "solemnising": "solemnizing", + "sombre": "somber", + "specialisation": "specialization", + "specialisations": "specializations", + "specialise": "specialize", + "specialised": "specialized", + "specialises": "specializes", + "specialising": "specializing", + "spectre": "specter", + "spectres": "specters", + "spiralled": "spiraled", + "spiralling": "spiraling", + "splendour": "splendor", + "splendours": "splendors", + "squirrelled": "squirreled", + "squirrelling": "squirreling", + "stabilisation": "stabilization", + "stabilise": "stabilize", + "stabilised": "stabilized", + "stabiliser": "stabilizer", + "stabilisers": "stabilizers", + "stabilises": "stabilizes", + "stabilising": "stabilizing", + "standardisation": "standardization", + "standardise": "standardize", + "standardised": "standardized", + "standardises": "standardizes", + "standardising": "standardizing", + "stencilled": "stenciled", + "stencilling": "stenciling", + "sterilisation": "sterilization", + "sterilisations": "sterilizations", + "sterilise": "sterilize", + "sterilised": "sterilized", + "steriliser": "sterilizer", + "sterilisers": "sterilizers", + "sterilises": "sterilizes", + "sterilising": "sterilizing", + "stigmatisation": "stigmatization", + "stigmatise": "stigmatize", + "stigmatised": "stigmatized", + "stigmatises": "stigmatizes", + "stigmatising": "stigmatizing", + "storey": "story", + "storeys": "stories", + "subsidisation": "subsidization", + "subsidise": "subsidize", + "subsidised": "subsidized", + "subsidiser": "subsidizer", + "subsidisers": "subsidizers", + "subsidises": "subsidizes", + "subsidising": "subsidizing", + "succour": "succor", + "succoured": "succored", + "succouring": "succoring", + "succours": "succors", + "sulphate": "sulfate", + "sulphates": "sulfates", + "sulphide": "sulfide", + "sulphides": "sulfides", + "sulphur": "sulfur", + "sulphurous": "sulfurous", + "summarise": "summarize", + "summarised": "summarized", + "summarises": "summarizes", + "summarising": "summarizing", + "swivelled": "swiveled", + "swivelling": "swiveling", + "symbolise": "symbolize", + "symbolised": "symbolized", + "symbolises": "symbolizes", + "symbolising": "symbolizing", + "sympathise": "sympathize", + "sympathised": "sympathized", + "sympathiser": "sympathizer", + "sympathisers": "sympathizers", + "sympathises": "sympathizes", + "sympathising": "sympathizing", + "synchronisation": "synchronization", + "synchronise": "synchronize", + "synchronised": "synchronized", + "synchronises": "synchronizes", + "synchronising": "synchronizing", + "synthesise": "synthesize", + "synthesised": "synthesized", + "synthesiser": "synthesizer", + "synthesisers": "synthesizers", + "synthesises": "synthesizes", + "synthesising": "synthesizing", + "syphon": "siphon", + "syphoned": "siphoned", + "syphoning": "siphoning", + "syphons": "siphons", + "systematisation": "systematization", + "systematise": "systematize", + "systematised": "systematized", + "systematises": "systematizes", + "systematising": "systematizing", + "tantalise": "tantalize", + "tantalised": "tantalized", + "tantalises": "tantalizes", + "tantalising": "tantalizing", + "tantalisingly": "tantalizingly", + "tasselled": "tasseled", + "technicolour": "technicolor", + "temporise": "temporize", + "temporised": "temporized", + "temporises": "temporizes", + "temporising": "temporizing", + "tenderise": "tenderize", + "tenderised": "tenderized", + "tenderises": "tenderizes", + "tenderising": "tenderizing", + "terrorise": "terrorize", + "terrorised": "terrorized", + "terrorises": "terrorizes", + "terrorising": "terrorizing", + "theatre": "theater", + "theatregoer": "theatergoer", + "theatregoers": "theatergoers", + "theatres": "theaters", + "theorise": "theorize", + "theorised": "theorized", + "theorises": "theorizes", + "theorising": "theorizing", + "tonne": "ton", + "tonnes": "tons", + "towelled": "toweled", + "towelling": "toweling", + "toxaemia": "toxemia", + "tranquillise": "tranquilize", + "tranquillised": "tranquilized", + "tranquilliser": "tranquilizer", + "tranquillisers": "tranquilizers", + "tranquillises": "tranquilizes", + "tranquillising": "tranquilizing", + "tranquillity": "tranquility", + "tranquillize": "tranquilize", + "tranquillized": "tranquilized", + "tranquillizer": "tranquilizer", + "tranquillizers": "tranquilizers", + "tranquillizes": "tranquilizes", + "tranquillizing": "tranquilizing", + "tranquilly": "tranquility", + "transistorised": "transistorized", + "traumatise": "traumatize", + "traumatised": "traumatized", + "traumatises": "traumatizes", + "traumatising": "traumatizing", + "travelled": "traveled", + "traveller": "traveler", + "travellers": "travelers", + "travelling": "traveling", + "travelog": "travelogue", + "travelogs": "travelogues", + "trialled": "trialed", + "trialling": "trialing", + "tricolour": "tricolor", + "tricolours": "tricolors", + "trivialise": "trivialize", + "trivialised": "trivialized", + "trivialises": "trivializes", + "trivialising": "trivializing", + "tumour": "tumor", + "tumours": "tumors", + "tunnelled": "tunneled", + "tunnelling": "tunneling", + "tyrannise": "tyrannize", + "tyrannised": "tyrannized", + "tyrannises": "tyrannizes", + "tyrannising": "tyrannizing", + "tyre": "tire", + "tyres": "tires", + "unauthorised": "unauthorized", + "uncivilised": "uncivilized", + "underutilised": "underutilized", + "unequalled": "unequaled", + "unfavourable": "unfavorable", + "unfavourably": "unfavorably", + "unionisation": "unionization", + "unionise": "unionize", + "unionised": "unionized", + "unionises": "unionizes", + "unionising": "unionizing", + "unorganised": "unorganized", + "unravelled": "unraveled", + "unravelling": "unraveling", + "unrecognisable": "unrecognizable", + "unrecognised": "unrecognized", + "unrivalled": "unrivaled", + "unsavoury": "unsavory", + "untrammelled": "untrammeled", + "urbanisation": "urbanization", + "urbanise": "urbanize", + "urbanised": "urbanized", + "urbanises": "urbanizes", + "urbanising": "urbanizing", + "utilisable": "utilizable", + "utilisation": "utilization", + "utilise": "utilize", + "utilised": "utilized", + "utilises": "utilizes", + "utilising": "utilizing", + "valour": "valor", + "vandalise": "vandalize", + "vandalised": "vandalized", + "vandalises": "vandalizes", + "vandalising": "vandalizing", + "vaporisation": "vaporization", + "vaporise": "vaporize", + "vaporised": "vaporized", + "vaporises": "vaporizes", + "vaporising": "vaporizing", + "vapour": "vapor", + "vapours": "vapors", + "verbalise": "verbalize", + "verbalised": "verbalized", + "verbalises": "verbalizes", + "verbalising": "verbalizing", + "victimisation": "victimization", + "victimise": "victimize", + "victimised": "victimized", + "victimises": "victimizes", + "victimising": "victimizing", + "videodisc": "videodisk", + "videodiscs": "videodisks", + "vigour": "vigor", + "visualisation": "visualization", + "visualisations": "visualizations", + "visualise": "visualize", + "visualised": "visualized", + "visualises": "visualizes", + "visualising": "visualizing", + "vocalisation": "vocalization", + "vocalisations": "vocalizations", + "vocalise": "vocalize", + "vocalised": "vocalized", + "vocalises": "vocalizes", + "vocalising": "vocalizing", + "vulcanised": "vulcanized", + "vulgarisation": "vulgarization", + "vulgarise": "vulgarize", + "vulgarised": "vulgarized", + "vulgarises": "vulgarizes", + "vulgarising": "vulgarizing", + "waggon": "wagon", + "waggons": "wagons", + "watercolour": "watercolor", + "watercolours": "watercolors", + "weaselled": "weaseled", + "weaselling": "weaseling", + "westernisation": "westernization", + "westernise": "westernize", + "westernised": "westernized", + "westernises": "westernizes", + "westernising": "westernizing", + "womanise": "womanize", + "womanised": "womanized", + "womaniser": "womanizer", + "womanisers": "womanizers", + "womanises": "womanizes", + "womanising": "womanizing", + "woollen": "woolen", + "woollens": "woolens", + "woollies": "woolies", + "woolly": "wooly", + "worshipped": "worshiped", + "worshipping": "worshiping", + "worshipper": "worshiper", + "yodelled": "yodeled", + "yodelling": "yodeling", + "yoghourt": "yogurt", + "yoghourts": "yogurts", + "yoghurt": "yogurt", + "yoghurts": "yogurts", + "mhm": "hmm", + "mmm": "hmm" +} \ No newline at end of file diff --git a/tests/librispeech-parakeet/normalizers/english.py b/tests/librispeech-parakeet/normalizers/english.py new file mode 100644 index 00000000000..4932042bc5b --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/english.py @@ -0,0 +1,550 @@ +import json +import os +import re +from fractions import Fraction +from typing import Iterator, List, Match, Optional, Union + +from more_itertools import windowed + +from .basic import remove_symbols_and_diacritics + + +class EnglishNumberNormalizer: + """ + Convert any spelled-out numbers into arabic numbers, while handling: + + - remove any commas + - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. + - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars` + - spell out `one` and `ones` + - interpret successive single-digit numbers as nominal: `one oh one` -> `101` + """ + + def __init__(self): + super().__init__() + + self.zeros = {"o", "oh", "zero"} + self.ones = { + name: i + for i, name in enumerate( + [ + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + "ten", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", + ], + start=1, + ) + } + self.ones_plural = { + "sixes" if name == "six" else name + "s": (value, "s") + for name, value in self.ones.items() + } + self.ones_ordinal = { + "zeroth": (0, "th"), + "first": (1, "st"), + "second": (2, "nd"), + "third": (3, "rd"), + "fifth": (5, "th"), + "twelfth": (12, "th"), + **{ + name + ("h" if name.endswith("t") else "th"): (value, "th") + for name, value in self.ones.items() + if value > 3 and value != 5 and value != 12 + }, + } + self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} + + self.tens = { + "twenty": 20, + "thirty": 30, + "forty": 40, + "fifty": 50, + "sixty": 60, + "seventy": 70, + "eighty": 80, + "ninety": 90, + } + self.tens_plural = { + name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() + } + self.tens_ordinal = { + name.replace("y", "ieth"): (value, "th") + for name, value in self.tens.items() + } + self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} + + self.multipliers = { + "hundred": 100, + "thousand": 1_000, + "million": 1_000_000, + "billion": 1_000_000_000, + "trillion": 1_000_000_000_000, + "quadrillion": 1_000_000_000_000_000, + "quintillion": 1_000_000_000_000_000_000, + "sextillion": 1_000_000_000_000_000_000_000, + "septillion": 1_000_000_000_000_000_000_000_000, + "octillion": 1_000_000_000_000_000_000_000_000_000, + "nonillion": 1_000_000_000_000_000_000_000_000_000_000, + "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, + } + self.multipliers_plural = { + name + "s": (value, "s") for name, value in self.multipliers.items() + } + self.multipliers_ordinal = { + name + "th": (value, "th") for name, value in self.multipliers.items() + } + self.multipliers_suffixed = { + **self.multipliers_plural, + **self.multipliers_ordinal, + } + self.decimals = {*self.ones, *self.tens, *self.zeros} + + self.preceding_prefixers = { + "minus": "-", + "negative": "-", + "plus": "+", + "positive": "+", + } + self.following_prefixers = { + "pound": "£", + "pounds": "£", + "euro": "€", + "euros": "€", + "dollar": "$", + "dollars": "$", + "cent": "¢", + "cents": "¢", + } + self.prefixes = set( + list(self.preceding_prefixers.values()) + + list(self.following_prefixers.values()) + ) + self.suffixers = { + "per": {"cent": "%"}, + "percent": "%", + } + self.specials = {"and", "double", "triple", "point"} + + self.words = set( + [ + key + for mapping in [ + self.zeros, + self.ones, + self.ones_suffixed, + self.tens, + self.tens_suffixed, + self.multipliers, + self.multipliers_suffixed, + self.preceding_prefixers, + self.following_prefixers, + self.suffixers, + self.specials, + ] + for key in mapping + ] + ) + self.literal_words = {"one", "ones"} + + def process_words(self, words: List[str]) -> Iterator[str]: + prefix: Optional[str] = None + value: Optional[Union[str, int]] = None + skip = False + + def to_fraction(s: str): + try: + return Fraction(s) + except ValueError: + return None + + def output(result: Union[str, int]): + nonlocal prefix, value + result = str(result) + if prefix is not None: + result = prefix + result + value = None + prefix = None + return result + + if len(words) == 0: + return + + for prev, current, next in windowed([None] + words + [None], 3): + if skip: + skip = False + continue + + next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) + has_prefix = current[0] in self.prefixes + current_without_prefix = current[1:] if has_prefix else current + if re.match(r"^\d+(\.\d+)?$", current_without_prefix): + # arabic numbers (potentially with signs and fractions) + f = to_fraction(current_without_prefix) + assert f is not None + if value is not None: + if isinstance(value, str) and value.endswith("."): + # concatenate decimals / ip address components + value = str(value) + str(current) + continue + else: + yield output(value) + + prefix = current[0] if has_prefix else prefix + if f.denominator == 1: + value = f.numerator # store integers as int + else: + value = current_without_prefix + elif current not in self.words: + # non-numeric words + if value is not None: + yield output(value) + yield output(current) + elif current in self.zeros: + value = str(value or "") + "0" + elif current in self.ones: + ones = self.ones[current] + + if value is None: + value = ones + elif isinstance(value, str) or prev in self.ones: + if ( + prev in self.tens and ones < 10 + ): # replace the last zero with the digit + assert value[-1] == "0" + value = value[:-1] + str(ones) + else: + value = str(value) + str(ones) + elif ones < 10: + if value % 10 == 0: + value += ones + else: + value = str(value) + str(ones) + else: # eleven to nineteen + if value % 100 == 0: + value += ones + else: + value = str(value) + str(ones) + elif current in self.ones_suffixed: + # ordinal or cardinal; yield the number right away + ones, suffix = self.ones_suffixed[current] + if value is None: + yield output(str(ones) + suffix) + elif isinstance(value, str) or prev in self.ones: + if prev in self.tens and ones < 10: + assert value[-1] == "0" + yield output(value[:-1] + str(ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + elif ones < 10: + if value % 10 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + else: # eleven to nineteen + if value % 100 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + value = None + elif current in self.tens: + tens = self.tens[current] + if value is None: + value = tens + elif isinstance(value, str): + value = str(value) + str(tens) + else: + if value % 100 == 0: + value += tens + else: + value = str(value) + str(tens) + elif current in self.tens_suffixed: + # ordinal or cardinal; yield the number right away + tens, suffix = self.tens_suffixed[current] + if value is None: + yield output(str(tens) + suffix) + elif isinstance(value, str): + yield output(str(value) + str(tens) + suffix) + else: + if value % 100 == 0: + yield output(str(value + tens) + suffix) + else: + yield output(str(value) + str(tens) + suffix) + elif current in self.multipliers: + multiplier = self.multipliers[current] + if value is None: + value = multiplier + elif isinstance(value, str) or value == 0: + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + value = p.numerator + else: + yield output(value) + value = multiplier + else: + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + elif current in self.multipliers_suffixed: + multiplier, suffix = self.multipliers_suffixed[current] + if value is None: + yield output(str(multiplier) + suffix) + elif isinstance(value, str): + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + yield output(str(p.numerator) + suffix) + else: + yield output(value) + yield output(str(multiplier) + suffix) + else: # int + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + yield output(str(value) + suffix) + value = None + elif current in self.preceding_prefixers: + # apply prefix (positive, minus, etc.) if it precedes a number + if value is not None: + yield output(value) + + if next in self.words or next_is_numeric: + prefix = self.preceding_prefixers[current] + else: + yield output(current) + elif current in self.following_prefixers: + # apply prefix (dollars, cents, etc.) only after a number + if value is not None: + prefix = self.following_prefixers[current] + yield output(value) + else: + yield output(current) + elif current in self.suffixers: + # apply suffix symbols (percent -> '%') + if value is not None: + suffix = self.suffixers[current] + if isinstance(suffix, dict): + if next in suffix: + yield output(str(value) + suffix[next]) + skip = True + else: + yield output(value) + yield output(current) + else: + yield output(str(value) + suffix) + else: + yield output(current) + elif current in self.specials: + if next not in self.words and not next_is_numeric: + # apply special handling only if the next word can be numeric + if value is not None: + yield output(value) + yield output(current) + elif current == "and": + # ignore "and" after hundreds, thousands, etc. + if prev not in self.multipliers: + if value is not None: + yield output(value) + yield output(current) + elif current == "double" or current == "triple": + if next in self.ones or next in self.zeros: + repeats = 2 if current == "double" else 3 + ones = self.ones.get(next, 0) + value = str(value or "") + str(ones) * repeats + skip = True + else: + if value is not None: + yield output(value) + yield output(current) + elif current == "point": + if next in self.decimals or next_is_numeric: + value = str(value or "") + "." + else: + # should all have been covered at this point + raise ValueError(f"Unexpected token: {current}") + else: + # all should have been covered at this point + raise ValueError(f"Unexpected token: {current}") + + if value is not None: + yield output(value) + + def preprocess(self, s: str): + # replace "<number> and a half" with "<number> point five" + results = [] + + segments = re.split(r"\band\s+a\s+half\b", s) + for i, segment in enumerate(segments): + if len(segment.strip()) == 0: + continue + if i == len(segments) - 1: + results.append(segment) + else: + results.append(segment) + last_word = segment.rsplit(maxsplit=2)[-1] + if last_word in self.decimals or last_word in self.multipliers: + results.append("point five") + else: + results.append("and a half") + + s = " ".join(results) + + # put a space at number/letter boundary + s = re.sub(r"([a-z])([0-9])", r"\1 \2", s) + s = re.sub(r"([0-9])([a-z])", r"\1 \2", s) + + # but remove spaces which could be a suffix + s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s) + + return s + + def postprocess(self, s: str): + def combine_cents(m: Match): + try: + currency = m.group(1) + integer = m.group(2) + cents = int(m.group(3)) + return f"{currency}{integer}.{cents:02d}" + except ValueError: + return m.string + + def extract_cents(m: Match): + try: + return f"¢{int(m.group(1))}" + except ValueError: + return m.string + + # apply currency postprocessing; "$2 and ¢7" -> "$2.07" + s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s) + s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s) + + # write "one(s)" instead of "1(s)", just for the readability + s = re.sub(r"\b1(s?)\b", r"one\1", s) + + return s + + def __call__(self, s: str): + s = self.preprocess(s) + s = " ".join(word for word in self.process_words(s.split()) if word is not None) + s = self.postprocess(s) + + return s + + +class EnglishSpellingNormalizer: + """ + Applies British-American spelling mappings as listed in [1]. + + [1] https://www.tysto.com/uk-us-spelling-list.html + """ + + def __init__(self): + mapping_path = os.path.join(os.path.dirname(__file__), "english.json") + self.mapping = json.load(open(mapping_path)) + + def __call__(self, s: str): + return " ".join(self.mapping.get(word, word) for word in s.split()) + + +class EnglishTextNormalizer: + def __init__(self): + self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b" + self.replacers = { + # common contractions + r"\bwon't\b": "will not", + r"\bcan't\b": "can not", + r"\blet's\b": "let us", + r"\bain't\b": "aint", + r"\by'all\b": "you all", + r"\bwanna\b": "want to", + r"\bgotta\b": "got to", + r"\bgonna\b": "going to", + r"\bi'ma\b": "i am going to", + r"\bimma\b": "i am going to", + r"\bwoulda\b": "would have", + r"\bcoulda\b": "could have", + r"\bshoulda\b": "should have", + r"\bma'am\b": "madam", + # contractions in titles/prefixes + r"\bmr\b": "mister ", + r"\bmrs\b": "missus ", + r"\bst\b": "saint ", + r"\bdr\b": "doctor ", + r"\bprof\b": "professor ", + r"\bcapt\b": "captain ", + r"\bgov\b": "governor ", + r"\bald\b": "alderman ", + r"\bgen\b": "general ", + r"\bsen\b": "senator ", + r"\brep\b": "representative ", + r"\bpres\b": "president ", + r"\brev\b": "reverend ", + r"\bhon\b": "honorable ", + r"\basst\b": "assistant ", + r"\bassoc\b": "associate ", + r"\blt\b": "lieutenant ", + r"\bcol\b": "colonel ", + r"\bjr\b": "junior ", + r"\bsr\b": "senior ", + r"\besq\b": "esquire ", + # prefect tenses, ideally it should be any past participles, but it's harder.. + r"'d been\b": " had been", + r"'s been\b": " has been", + r"'d gone\b": " had gone", + r"'s gone\b": " has gone", + r"'d done\b": " had done", # "'s done" is ambiguous + r"'s got\b": " has got", + # general contractions + r"n't\b": " not", + r"'re\b": " are", + r"'s\b": " is", + r"'d\b": " would", + r"'ll\b": " will", + r"'t\b": " not", + r"'ve\b": " have", + r"'m\b": " am", + } + self.standardize_numbers = EnglishNumberNormalizer() + self.standardize_spellings = EnglishSpellingNormalizer() + + def __call__(self, s: str): + s = s.lower() + + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = re.sub(self.ignore_patterns, "", s) + s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe + + for pattern, replacement in self.replacers.items(): + s = re.sub(pattern, replacement, s) + + s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits + s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers + s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols + + s = self.standardize_numbers(s) + s = self.standardize_spellings(s) + + # now remove prefix/suffix symbols that are not preceded/followed by numbers + s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) + s = re.sub(r"([^0-9])%", r"\1 ", s) + + s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space + + return s diff --git a/tests/parakeet-expected-diffusion-output.txt b/tests/parakeet-expected-diffusion-output.txt new file mode 100644 index 00000000000..9753a86953a --- /dev/null +++ b/tests/parakeet-expected-diffusion-output.txt @@ -0,0 +1 @@ +Hello and welcome to Diffusion. Sit back and relax while we stretch your brain with weird and wonderful science. I'm Ian Wolf. On this edition, Dr. Viv Robinson rewrites cosmology. But first up, here's news of two massive galaxies that might be older than the Big Bang. Galaxies too massive. Astronomers from the Swinburne University of Technology in Melbourne, using the James Webb Space Telescope, have observed six galaxies that formed in the universe's first 700 million years appear to be up to a hundred times more massive than our best theories say can possibly exist. Astronomer Ivo Labe and his colleagues wrote in his paper, adding up the stars in those galaxies, it would exceed the total amount of mass available in the universe at that time. There's too much mass and not enough time for it to get together. The galaxies must have had much longer than the 700 million years after the Big Bang that our standard model of the universe gives them, and the universe must have had more mass available, or galaxies must have formed differently than what we think. The Big Bang is currently thought to have started everything 13.77 billion years ago. And these galaxies, we're watching them at 0.77 billion years ago because they're so far away. Galaxies are thought to accumulate gas moved together by giant clumps of dark matter in their region. Generally, only about 10% of the gas in the galaxy ignites to make a star. For galaxies in the remotest parts of the universe where the gas is thin, it takes a long time to accumulate this much gas for this many stars. These six galaxies, however, have so many stars adding up to so much mass that all of the gas in each galaxy had to have become 100% converted into stars in the 700 million years since the universe started in the Big Bang. Under our current understanding, this is impossible. It suggests something in our understanding of the cosmos is wrong. Are we wrong about how to calculate astronomical masses, galaxy formation, dark matter, and the Big Bang and the age of the universe? An astronomer from the Cosmic Dawn Centre in Denmark used the James Webb telescope to look at closer galaxies, and then used the very high resolution of that telescope to calculate the mass more precisely with a different method, and found that these galaxies are three to ten times more massive than we previously thought. Applying this more accurate technique to the six galaxies that are 13 billion light years away would increase their mass, which makes it much worse than what we thought. The paper was titled A Population of Red Candidate Massive Galaxies, approximately 600 million years after the Big Bang, and was published in the journal Nature.com. We're brought to you across Australia on the Community Radio Network and podcast over the internet on www.diffusionradio.com Challenging Physics Newton said everything is either a particle or a wave. Faraday and Maxwell added fields. Einstein added space-time. Quantum physics says everything is made of quanta, which have the properties of both waves and particles, but is neither. Quantum mechanics has no explanation for gravity, and relativity doesn't account for the quantum world. There's a contradiction between our most basic explanations of the universe. Dr. Viv Robinson was the first person to create a physical explanation of Einstein's gravity in a paper published in the Journal of Physics Communications. He's made corrections to people's extensions of Einstein's mathematics and has a different way to interpret those mathematics that gives a different picture of the age of the universe and a different way of looking at how the physics works. From the standard model of quantum physics to Big Bang cosmology. Everything, including you and me, is made of light. It's a very big and very bold claim. I spoke to Dr. Viv Robinson via Zoom and began by asking him, what is the universe made of? The whole stuff of the universe, or entity. I won't call it items because one of them is absolutely nothing. The first thing to all the mass and all the energy is made up of photons. They're little packets of electromagnetic energy, postulated by Maxwell and Planck and proven by Einstein. They come in many different sizes, shapes, and which make that they make up all the mass and energy of the universe. The volume is made up by empty space, absolutely nothing. But it's the properties of the space that are important. And it does this through two of its properties, electric permittivity and magnetic permeability. And it's those properties which then transmit all of the fields. So that's really all it is. They're just the only two stars in a call because the photos are physical things, and space is just the absence of everything, but its property, its properties are what is important about it. And that's a little bit different to what you might hear from a quantum physics class where they talk about space being full of virtual particles coming into and out of existence so that it's not totally empty, or sometimes they say it's full of fields. The fields of every force is in there and things are coming up all the time. So if you go very fast, you'll interact with the fields, all the virtual particles, and you'll get radiation. Yes, well, uh the unfortunate part is that physics is doing exceedingly well under Newtonian mechanics and exceedingly well under Maxwell's mechanics. But as things get smaller and smaller, you get to a stage where things aren't continuous. I mean, Newton's work will anything that's continuous, but eventually you get to the stage where you know a droplet of water is fine, it has surface tension, evaporates, and you're left with one molecule of water. That doesn't behave the same as bulk water. Into that molecule you go hydrogen atoms and oxygen atoms, they behave nothing like water. And then you get, well, they're made of protons, neutrons, electrons, and they have completely different properties from bulk water. So quantum mechanics, things get quantized, and you get the smallest quantity you can get, and that has very, very different properties from the bulk. And what has happened in the past is that uh the uh early on in quantum mechanics and met men like Dirac and Schrdinger, they didn't know what an the structure was an electron was. Also, all they had to know, they knew it was it had wave properties. And so all they did was they attributed it to a way a wave property to it. Now, waves have the advantage over particles, you can manipulate them almost forever with all sorts of different transforms until you get the answer you want. And that gave some confidence to quantum mechanics guys that yes, waves work, and they've been using that forever, and all I'm saying, no, no, no, no, no. Everything is particles, and the particles have specific properties, and you can't manipulate those properties, or you can to a certain extent, but they are what they are, and it's when you know what those properties are that the whole quantum mechanics becomes much simpler. You don't need any of that uh foamy sort of stuff to get to explain whatever you want to explain. I mentioned that there are many different forms of photons, and photons are electromagnetic radiation with an electric field, saying on a magnetic field perpendicular to it, and the whole lot travels in the speed of light in the third dimension. There are many, many variations of that. So that that's fine for energy radiation. But how about matter particles? Well, matter particles are nothing more than photons of the appropriate wavelength making uh appropriate energy making two revolutions per wavelength. And when they do that, what holds what allows them to do that is that they rotate around the magnetic field. And suddenly, instead of in a linear photon, magnetic fields are open. When they rotate around the magnetic field, then the magnetic field of a particle is closed. And a closed magnetic field is much more stable than an open magnetic field, and that's why most of the universe, for example, when uh less about, I think the best estimate I've seen, one percent is radiation, the other 99% is photons struggling in circles, making two revolutions per wavelength. And it's for that that gives particles all their properties. Now, I may say this is a bit hairy-fairy, but it's been known for a long, long time that you get a particle and an antiparticle, you put them together, bing, two photons. At the same time, you can get a photon and goes and hit the target, bang, a particle and an antiparticle. Now that shows a relationship between the two that somehow lots of people missed. But what's the simplest relationship you can have? The simplest relationship is that a particle is a photon making two revolutions in one direction, an antiparticle is the same particle making two revolutions in the other direction. Put them together, they unlock. Because they have mass, they have this thing called angular momentum, which is a great Newtonian property. But because mathematicians sort of didn't know what an electron was, they called it a point particle. You can't have angular momentum with a point particle, so they call it spin and they wave all sorts of different things to make it seem as if they know what they're talking about. It's really just angular momentum. And that's the relationship between mass and energy. Energy is the photon zipping along at the speed of light. Mass is the same photon making two revolutions per wavelength. That's how they can interchange so easily. And that property gives particles all of their properties, including mass. And one of the things that Einstein did work out in 1905, those little what they called uh packets of radio of electromagnetic energy, he did work out that they carried momentum or carried inertia, they had momentum, they had mass. I don't know why people want to prove Einstein wrong. Photons have mass. Now I think the reason for this is that they think oh, Einstein's special relativity corrections, anything traveling at the speed of light, will have an infinite mass. The special relativity corrections only apply to photons which are spiraling. And that's just as um the reason for that is about as complicated as uh post Thagoras' theorem. And what he was at 300 BC or something like that, not difficult. And so photons themselves always travel at the speed of light. And so the rotating photons, photons that are rotating, are rotating also at the same speed of light. Well, that's one old hell of a gyroscope. And that is what gives particles a spin, that's why E equals mc squared, and it's all straightforward. There you go. Really? Well. So if we go back a little bit there where you're saying there's no wave nature, what about the double-slit experiment and other sorts of experiments that seem to show wave properties of particles other than photons? Particles um De Royal worked out in 1925 that if if photons, if um photons behave like particles, and particles to behave like photons, I agree with him, it's completely it's completely true. The actual nature of the rotating photon generates the de Broilie wavelength, and it has all the right properties. For me, and to me, Einstein's special and general relativity theories are relatively simple, so it may I may be talking a little bit out of line here. But the deuil wavelength is automatically generated by the particle as it moves. So it's not something that they hypothesize and don't know what occurs. They they hypothesized it, they measured it, but they don't know how it occurs. Well, yeah, it's quite it's fairly straightforward, but not at uh not not not at this level. What are the implications for this difference in understanding? So are there predictions that you would make that are different to the ones that people following the standard model would make? Oh, not the numbers of them, yeah. So probably the electron tunneling. Where electrons hit a barrier. That's got a very simple mechanical analog. I mean, the electrons are held in uh what you call a very taut field. Now, if you've got something coming up, you've got everything in a tight situation, you come something up banging it at this end, you can do it with billiard balls that'll transport through, and another one will knock out. So, what they call tunneling under this model, but in reality, what they call tunneling is just really a momentum exchange. So that's a little bit like one of those Newton cradles. Where you've got the balls on all attached by a string or a chain to a fulcrum over the top, and one will hit the other one and transfer the momentum to the other one without actually transferring itself. Yeah, you don't get electrons, you know, they have they have wave properties, but yes, but you won't get an electron uh tunneling the wave, the wave is in a very fixed position with respect to the uh electron. It's equal on either side of it. If their tunneling theory were correct, then the lower the energy of the electron, the longer its wavelength, therefore the easier it would be to tunnel. However, in the energy transfer one, the higher the energy, the greater probability it'll knock another electron out the other side. Or it's a simple experiment to do. Just increase the energy of uh an electron coming up to a barrier and see which ones go come out the other end first. Is anyone set up to do that? Oh, anyone could set up to do it. Well, a lot of laboratories could do it. And the so-called tunneling effect is what they use in all of the microelectronics systems. And they wouldn't, it wouldn't, it'd be a very, very simple exercise to carry that out. They may well have done it, and the mathematicians have turned around and added another factor. Yeah, it's a standard thing they do when they don't get the right answer, just add another factor. I can't do that. It's physical reality is physical reality. End of story. I guess that's something to look up and see if someone's done those experiments and and what they did with the results. I think there is I think I'm sure it has been done, and the result is that the higher the energy of the electron, the greater the probability of it emerging on the other side of the barrier. And on the very much bigger scale, are there differences in the way the universe looks for astronomy? Yeah, not as far as astronomy is concerned. What the astronomers see is what there is. No question about it. They're great, they're brilliant, as the astronomers, and most of the experimentalists are they're doing an exceedingly good job. The problem becomes in interpreting what they've seen. And when it comes to the whole universe, for example, it's all based on Einstein's theory of gravity. Well, it should be, but it's more advanced than Newton's inverse square, but for most practical purposes, uh Newton's inverse square works quite well. The two situations where it doesn't work, when the mass is so large, like the mass of the sun or the mass of the center of uh Sagittarius A with the planet or star S2 going around it. That's one situation. The reason why a planet uh or Mercury's orbit precesses in its direction of travel is simply that gravity, when mass is strong enough, gravity actually becomes weaker than inverse square. And that's one of the things you get when you solve Einstein's gravity theory accurately. It becomes weaker than inverse square. Now, when it's weak, if it's weaker than inverse square, Mercury travels a little bit closer to the Sun and is attracted by a slightly stronger force. So it'll arrive back at its perihelion point a little later, and it it'll um process in its direction of travel. And Newton pointed that out in 1687. So I don't know why they didn't sort of work it out correctly. But gravity is weaker than inverse square, is the solution to Einstein's gravity. The other thing is that when gravity is an infinite steady state universe under Newton's theory of gravity, inverse square, will collapse. The reason being that the relative to the universe density mass increases as r cubed, gravity decreases as r squared, so eventually you get to the stage where gravity just uh dominates mass and it collapses. But if gravity is weaker than inverse square, and I just tried to show you that Mercury is precessing orbit because the sun's gravity is weaker than inverse square, well, that applies to all gravity. There's nothing special about our sun, except that it's keeping all us alive on this. When you have an infinite steady-state universe, if gravity is weaker than inverse squares, its effect gets relatively weaker over long distances. And I'm talking typically uh 10 billion light years or something like that, maybe more. But that means an infinite steady-state universe won't collapse. That's a huge, huge difference. That's the biggest thing, mind you, what difference does it make to us here on Earth if uh if Bang's web has seen galaxies, fully formed galaxies 20 billion light years away, doesn't make a scrap of difference to us. But as far as understanding how the universe works, that mistake, and the simple the simple mistake that they the um all mathematicians were uh made, Einstein introduced approximations. He couldn't solve the gravity exactly himself. I have no problem solving his uh his gravity exactly. But he he uh introduced the approximation that one over one plus x approximately equals one minus x. You know, when x is ten to the minus seven or which is or ten to the minus eight, that's a good approximation. I mean you you just read his paper, he says so. And you read the mathematics, you don't even you could read the German version, look at the mathematics, and he says so, and you just work it out, and that was the difference. So, all of their exact solutions to Einstein's gravity, they took where he used the approximation, he derived the figure from one plus one over x, the equivalent of that, and then he rather than do that, he equated it to one minus x, which is which is true. You know, one plus one millionth is nine hundred and ninety millionth. Why they did it, I have no idea. Mind you, it'd be interesting to try and find out why. Uh I think it's if a mathematician of repute says one thing, and I I I will agree that uh on my first readings of Einstein's relativity theories, you think, oh my god, really? Could he understand that then? Then you get in and you start. It's not that difficult. And I think most of them had a solution. You know, somebody came up with a solution to Einstein's group, and everybody just followed it. And nobody, and this is the big thing that I always stress to everybody, don't take somebody's word for it. Go back and check the original yourself. I've seen a few times where people have just made terrible, terrible mistakes. But this would probably be the biggest one in the whole field of cosmology, sorry. Astronomy? You guys, great. Thanks, Uncle Sam, for providing us with all this information. That was part one of my interview with Dr. Viv Robinson. You heard Viv say that matter is made of photons moving in circles. Physicists took Einstein's approximations as gospel instead of using the exact solutions available with lather mathematics. Gravity changes to be weaker over distances, and the universe isn't expanding. Listen next week for part two. If you have any questions for Dr. Robinson, he'd love to answer them on the show. So send your questions to science at diffusionradio.com. If you're in Darlinghurst this Wednesday night, the 5th of July, I will be part of the lineup of scientists speaking at Future Science Talks at the East Village. Go to www.futurescience talks.com.au to grab a ticket and come up and say hello. And if you can't make it Wednesday night, I'll keep you posted on some future talks I'll be giving. And that's all from us this week on Diffusion. Are you a scientist, artist, biohacker, or maker who'd like to be interviewed about your work? Would your company like to sponsor diffusion? Send your contributions, opinions, helpful suggestions and donations to science at diffusionradio.com. That's science at diffusionradio.com. Please subscribe to the Diffusion Science Radio channel on youtube.com slash C slash Diffusion Radio and rate the show on iTunes and tell your friends. Follow me on Twitter at IanWorf. The news music was Rhinos Theme by Kevin McLeod of Incompitech.com. I produce diffusion, which is broadcast around Australia, to 28 stations on the community radio network, including Radio Blue Mountains 89.1 FM in New South Wales, 8CCC in Alice Springs and Tennant Creek, 2 MVR in Nambucker Valley, 3 MVR in the Malleigh Border Districts of Victoria and South Australia, City Park Radio 7LTN in Launcest and Tasmania, and 2XFM in Canberra. Diffusion is narrowcast on Indigo FM88 in Northeast Victoria. Diffusion is syndicated globally on astronomy.fm. Subscribe to the podcast on the diffusion website www.diffusionradio.com. That's www.diffusionradio.com and check the website for links, photos, and videos about this week's show. If you enjoyed the show, you can explore more than a thousand previous episodes archived on diffusionradio.com where the shows are labelled by keywords so you can focus in on the stories you want to hear. Make a donation through PayPal.me slash Ian Worf. Or join my patrons at patreon.com slash Diffusion Radio. I'm Ian Worf. Join us inside your audio device of choice for more science wondering next week on Diffusion Science Radio. Science is fun. It helps you to learn, to know, and to appreciate. When you study science, you make fun feel. diff --git a/tests/parakeet-expected-gb1-output.txt b/tests/parakeet-expected-gb1-output.txt new file mode 100644 index 00000000000..312ed1ce048 --- /dev/null +++ b/tests/parakeet-expected-gb1-output.txt @@ -0,0 +1 @@ +My fellow Americans, this day has brought terrible news and great sadness to our country. At nine o'clock this morning, mission control in Houston lost contact with our space shuttle Columbia. A short time later, debris was seen falling from the skies above Texas. The Columbia's lost. There are no survivors. On board was a crew of seven. Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark, Captain David Brown, Commander William McCool, Dr. Kulpna Shavla, and Ilan Ramon, a colonel in the Israeli Air Force. These men and women assumed great risk in the service to all humanity. In an age when space flight has come to seem almost routine. It is easy to overlook the dangers of travel by rocket and the difficulties of navigating the fierce outer atmosphere of the earth. These astronauts knew the dangers, and they faced them willingly, knowing they had a high and noble purpose in life. Because of their courage and daring and idealism, we will miss them all the more. And those you loved will always have the respect and gratitude of this country. The cause in which they died will continue. Mankind is led into the darkness beyond our world by the inspiration of discovery and the longing to understand. Our journey into space will go on. In the skies today, we saw destruction and tragedy. Yet farther than we can see, there is comfort and hope. In the words of the prophet Isaiah, lift your eyes and look to the heavens. Who created all these? He who brings out the starry hosts one by one and calls them each by name. Because of his great power and mighty strength, not one of them is missing. The same creator who names the stars also knows the names of the seven souls we mourn today. The crew of the shuttle Columbia did not return safely to Earth. Yet we can pray that all are safely home. May God bless the grieving families and make out may God continue to bless America. diff --git a/tests/parakeet-expected-jfk-output.txt b/tests/parakeet-expected-jfk-output.txt new file mode 100644 index 00000000000..ece35697ae8 --- /dev/null +++ b/tests/parakeet-expected-jfk-output.txt @@ -0,0 +1 @@ +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. diff --git a/tests/parakeet-verification.h b/tests/parakeet-verification.h new file mode 100644 index 00000000000..0e95610ba26 --- /dev/null +++ b/tests/parakeet-verification.h @@ -0,0 +1,110 @@ +#pragma once + +#include <algorithm> +#include <cassert> +#include <cctype> +#include <cstdio> +#include <fstream> +#include <iterator> +#include <string> +#include <vector> + +#ifndef TRANSCRIPTION_SIMILARITY_THRESHOLD +#define TRANSCRIPTION_SIMILARITY_THRESHOLD 1.0 +#endif + +static std::string read_expected_transcription(const char * path) { + std::ifstream fin(path); + assert(fin.is_open()); + + std::string text( + (std::istreambuf_iterator<char>(fin)), + std::istreambuf_iterator<char>()); + + while (!text.empty() && (text.back() == '\n' || text.back() == '\r')) { + text.pop_back(); + } + + return text; +} + +static std::vector<std::string> transcription_words(const std::string & text) { + std::vector<std::string> words; + std::string word; + + for (unsigned char ch : text) { + if (std::isalnum(ch)) { + word.push_back((char) std::tolower(ch)); + } else if (!word.empty()) { + words.push_back(word); + word.clear(); + } + } + + if (!word.empty()) { + words.push_back(word); + } + + return words; +} + +static double transcription_lcs_similarity(const std::string & expected, const std::string & actual) { + const std::vector<std::string> expected_words = transcription_words(expected); + const std::vector<std::string> actual_words = transcription_words(actual); + + if (expected_words.empty() && actual_words.empty()) { + return 1.0; + } + + if (expected_words.empty() || actual_words.empty()) { + return 0.0; + } + + std::vector<int> prev(actual_words.size() + 1, 0); + std::vector<int> cur (actual_words.size() + 1, 0); + + for (size_t i = 0; i < expected_words.size(); ++i) { + std::fill(cur.begin(), cur.end(), 0); + + for (size_t j = 0; j < actual_words.size(); ++j) { + if (expected_words[i] == actual_words[j]) { + cur[j + 1] = prev[j] + 1; + } else { + cur[j + 1] = std::max(prev[j + 1], cur[j]); + } + } + + prev.swap(cur); + } + + const int lcs = prev[actual_words.size()]; + return (2.0 * lcs) / (expected_words.size() + actual_words.size()); +} + +static bool verify_transcription(const std::string & expected, const std::string & actual) { + const double threshold = TRANSCRIPTION_SIMILARITY_THRESHOLD; + + if (threshold >= 1.0) { + if (actual == expected) { + return true; + } + + fprintf(stderr, "\n\n"); + fprintf(stderr, "[Failed] Transcript mismatched\n"); + fprintf(stderr, "expected:\n%s\n\n", expected.c_str()); + fprintf(stderr, "actual:\n%s\n", actual.c_str()); + return false; + } + + const double similarity = transcription_lcs_similarity(expected, actual); + printf("\nTranscript similarity: %.6f (threshold %.6f)\n", similarity, threshold); + + if (similarity >= threshold) { + return true; + } + + fprintf(stderr, "\n\nTranscript similarity below threshold: %.6f < %.6f\n", similarity, threshold); + fprintf(stderr, "Expected:\n%s\n\n", expected.c_str()); + fprintf(stderr, "Actual:\n%s\n", actual.c_str()); + return false; +} diff --git a/tests/run-tests.sh b/tests/run-tests.sh index ad2b8d3ec09..bc28314a704 100755 --- a/tests/run-tests.sh +++ b/tests/run-tests.sh @@ -21,13 +21,21 @@ cd `dirname $0` # Whisper models models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" "large-v3-turbo" ) +# Parakeet model variants +parakeet_models=( "f16" "f32" "q2_k" "q4_0" "q4_k" "q8_0" ) + # list available models function list_models { printf "\n" - printf " Available models:" + printf " Available whisper models:" for model in "${models[@]}"; do printf " $model" done + printf "\n" + printf " Available parakeet models:" + for model in "${parakeet_models[@]}"; do + printf " parakeet-$model" + done printf "\n\n" } @@ -39,15 +47,37 @@ if [ $# -eq 0 ]; then fi model=$1 -main="../build/bin/whisper-cli" threads="" if [ $# -eq 2 ]; then threads="-t $2" fi -if [ ! -f ../models/ggml-$model.bin ]; then - printf "Model $model not found. Aborting\n" +# Detect parakeet model (prefix "parakeet-" or a bare variant like "f32") +is_parakeet=0 +parakeet_variant="" +if [[ $model == parakeet-* ]]; then + is_parakeet=1 + parakeet_variant="${model#parakeet-}" +fi +for v in "${parakeet_models[@]}"; do + if [[ $model == "$v" ]]; then + is_parakeet=1 + parakeet_variant="$v" + break + fi +done + +if [ $is_parakeet -eq 1 ]; then + main="../build/bin/parakeet-cli" + model_path="../models/ggml-parakeet-tdt-0.6b-v3-${parakeet_variant}.bin" +else + main="../build/bin/whisper-cli" + model_path="../models/ggml-${model}.bin" +fi + +if [ ! -f $model_path ]; then + printf "Model $model not found ($model_path). Aborting\n" list_models exit 1 fi @@ -110,7 +140,11 @@ function run_lang() { fi fi - $main -m ../models/ggml-$model.bin $threads -f $fname_dst -l $lang -otxt 2> /dev/null + if [ $is_parakeet -eq 1 ]; then + $main -m $model_path $threads -f $fname_dst -otxt 2> /dev/null + else + $main -m $model_path $threads -f $fname_dst -l $lang -otxt 2> /dev/null + fi git diff --no-index --word-diff=color --word-diff-regex=. $lang-$i-ref.txt $fname_dst.txt @@ -120,7 +154,7 @@ function run_lang() { run_lang "en" "${urls_en[@]}" -if [[ $model != *.en* ]]; then +if [ $is_parakeet -eq 0 ] && [[ $model != *.en* ]]; then run_lang "es" "${urls_es[@]}" run_lang "it" "${urls_it[@]}" run_lang "pt" "${urls_pt[@]}" diff --git a/tests/test-parakeet-full.cpp b/tests/test-parakeet-full.cpp new file mode 100644 index 00000000000..22ac4c20e31 --- /dev/null +++ b/tests/test-parakeet-full.cpp @@ -0,0 +1,101 @@ +#include "parakeet.h" +#include "common-whisper.h" +#include "parakeet-verification.h" + +#include <cstdio> +#include <string> + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include <cassert> + +struct test_state { + bool is_first = true; + std::string transcript; +}; + +void progress_callback(parakeet_context * ctx, parakeet_state * state, int progress, void * user_data) { + bool * called = static_cast<bool *>(user_data); + *called = true; +} + +bool encoder_begin_callback(parakeet_context * ctx, parakeet_state * state, void * user_data) { + bool * called = static_cast<bool *>(user_data); + *called = true; + return true; +} + +bool abort_callback(void * user_data) { + bool * called = static_cast<bool *>(user_data); + *called = true; + return false; // just continue without aborting. +} + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + test_state * tstate = static_cast<test_state *>(user_data); + + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, tstate->is_first, text_buf, sizeof(text_buf)); + + printf("%s", text_buf); + fflush(stdout); + + tstate->transcript += text_buf; + tstate->is_first = false; +} + +int main() { + std::string model_path = PARAKEET_MODEL_PATH; + std::string sample_path = SAMPLE_PATH; + + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + assert(pcmf32s.size() == 0); // no stereo vector + + printf("Loading Parakeet model from: %s\n", model_path.c_str()); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + + struct parakeet_context * pctx = parakeet_init_from_file_with_params(model_path.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "Failed to load Parakeet model\n"); + return 1; + } + printf("Successfully loaded Parakeet model\n"); + + struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + test_state tstate; + params.new_token_callback = token_callback; + params.new_token_callback_user_data = &tstate; + bool progress_callback_called = false; + params.progress_callback = progress_callback; + params.progress_callback_user_data = &progress_callback_called; + bool encoder_begin_callback_called = false; + params.encoder_begin_callback = encoder_begin_callback; + params.encoder_begin_callback_user_data = &encoder_begin_callback_called; + bool abort_callback_called = false; + params.abort_callback = abort_callback; + params.abort_callback_user_data = &abort_callback_called; + + int ret = parakeet_full(pctx, params, pcmf32.data(), pcmf32.size()); + assert(ret == 0); + assert(progress_callback_called); + assert(encoder_begin_callback_called); + assert(abort_callback_called); + + const std::string expected = read_expected_transcription(EXPECTED_TRANSCRIPTION_PATH); + const bool transcript_matches = verify_transcription(expected, tstate.transcript); + + parakeet_free(pctx); + + if (!transcript_matches) { + return 1; + } + + printf("\nTest passed: parakeet_full succeeded!\n"); + return 0; +} diff --git a/tests/test-parakeet.cpp b/tests/test-parakeet.cpp new file mode 100644 index 00000000000..83237c600ac --- /dev/null +++ b/tests/test-parakeet.cpp @@ -0,0 +1,99 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include <cstdio> +#include <string> + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include <cassert> + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + static bool is_first = true; + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf)); + + int32_t time_ms = token_data->frame_index * 10; + + printf("%s", text_buf); + fflush(stdout); + + is_first = false; +} + +void segment_callback(parakeet_context * ctx, parakeet_state * state, int n_new, void * user_data) { + const int n_segments = parakeet_full_n_segments_from_state(state); + const int s0 = n_segments - n_new; + + printf("\nSegment Callback: %d new segment(s)\n", n_new); + + for (int i = s0; i < n_segments; i++) { + const char * text = parakeet_full_get_segment_text_from_state(state, i); + const int64_t t0 = parakeet_full_get_segment_t0_from_state(state, i); + const int64_t t1 = parakeet_full_get_segment_t1_from_state(state, i); + + printf("Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text); + printf("Tokens:\n"); + + const int n_tokens = parakeet_full_n_tokens_from_state(state, i); + for (int j = 0; j < n_tokens; j++) { + parakeet_token_data token_data = parakeet_full_get_token_data_from_state(state, i, j); + const char * token_str = parakeet_token_to_str(ctx, token_data.id); + + printf(" [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%d \"%s\"\n", + j, + token_data.id, + token_data.frame_index, + token_data.duration_idx, + token_data.duration_value, + token_data.p, + token_data.plog, + (long long)token_data.t0, + (long long)token_data.t1, + token_data.is_word_start, + token_str); + } + } + printf("\n"); +} + +int main() { + std::string model_path = PARAKEET_MODEL_PATH; + std::string sample_path = SAMPLE_PATH; + + // Load the sample audio file + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + assert(pcmf32s.size() == 0); + + printf("Loading Parakeet model from: %s\n", model_path.c_str()); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + + struct parakeet_context * pctx = parakeet_init_from_file_with_params_no_state(model_path.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "Failed to load Parakeet model\n"); + return 1; + } + printf("Successfully loaded Parakeet model\n"); + + struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + params.new_token_callback = token_callback; + params.new_token_callback_user_data = nullptr; + params.new_segment_callback = segment_callback; + params.new_segment_callback_user_data = nullptr; + parakeet_state * state = parakeet_init_state(pctx); + + int ret = parakeet_chunk(pctx, state, params, pcmf32.data(), pcmf32.size()); + assert(ret == 0); + + parakeet_free_state(state); + parakeet_free(pctx); + + printf("\nTest passed: Parakeet model loaded and freed successfully\n"); + return 0; +} From 0d14756929dc9f21ddccf6102bb783397b7a8f1b Mon Sep 17 00:00:00 2001 From: KITAITI Makoto <KitaitiMakoto@gmail.com> Date: Wed, 17 Jun 2026 13:42:09 +0900 Subject: [PATCH 827/831] ruby : add support for Parakeet (#3885) * Add Whisper::Parakeet::Params * Add tests for Parakeet::Params * Remove unused variabel * Add callbacks to Parakeet::Params * Group callback and user_data params * Undefine local macros * Define GetParakeetParams * Remove unused variable * Use ITERATE_CALLBACK_PARAMS * Use ITERATE_CALLBACK_PARAMS instead of ITERATE_USER_DATA_PARAMS * Fix memsize * Remove unnecessary macros * Simplify params registration * Define Parakeet * Add hook methods to Parakeet::Params * Fix typo * Check callback container in GetParakeetParams * Reduce if * Free parakeet_full_params * Implement Parakeet::Context#initialize * Add TestParakeetContext * Add Parakeet::Segment * Prevent double-free * Add Parakeet::Context#transcribe * Add Parakeet::Context#each_segment * Define Parakeet::Segment attributes * Define Parakeet::Segment#deconstruct_keys * Add tests for Parakeet::Segment#deconstruct_keys * Run Parakeet::Context#transcribe without GVL * Make it to abort for Parakeet * Add Parakeet.log_set * Define Parakeet::Token * Define Parakeet::Segment#each_token * Implement some hooks of Parakeet::Params * Convert int to VALUE * Implement hooks for Parakeet * Implement Parakeet::Context#full * Add tests for Parakeet::Context#full * Add Parakeet to RBS * Fix ruby_whisper_parakeet_params_free * Free ruby_whisper_parakeet_context * Add tests for hooks * Add Parakeet section to README * Add more attributes of Parakeet::Context * Add tests for Parakeet::Context's attributes * Update RBS * Register parakeet-tdt-0.6b-v3 * Narrow scope of log constants * Extract activate and deactivate of log_queue * Make start_log_callback_thread private * Don't call start_log_callback_thread unncecessarilly * Early return from log_queue_enqueue when not active * Gropu log_queue members * is_active -> is_open * Fix English * Share parakeet full body function * ruby_whisper_parakeet_abort_callback_user_data -> ruby_whisper_abort_callback_user_data * NULL check for callback containers * Fix Parakeet.log_set * Omit Parakeet tests on CI * Extract Whisper::LogSettable * Join log callback thread in a log queue function * Revert Join log callback thread in a log queue function * Extract output methods to modules * Move Parakeet init functions into init_parakeet() * Add output methods to Parakeet classes * Add Parakeet's output methods to RBS * Use Whisper::Output in RBS * Add LogSettable to RBS * Fix module Token -> class Token * Add Parakeet::Model * Add test for Parakeet::Model * Add Parakeet::Model to RBS * Move position of Parakeet::Model in RBS * Parakeet -> TestBase::Parakeet * Add Parakeet::Context#model in RBS * Add Whisper::Output * Fix nil check * Define ruby_whisper_parakeet_model_memsize * Fix order of declaration in ruby_whisper_parakeet_model_get_xxx * Define Parakeet.system_info_str * Add test for Parakeet.system_info_str * Add signature of Parakeet.system_info_str * Define Parakeet::VERSION * Add test for Parakeet::VERSION * Add signature of Parakeet::VERSION * Add Parakeet::Context::Params * Make Parakeet::Context.new accept Context::Params * Add test for Parakeet::Context.new with Context::Params * Update RBS * Remove params from Parakeet::Params which are moved from whisper_parakeet_full_params * Remove tests for removed params * Make Parakeet tests follow original behavior changes * Add Parakeet model shortcuts * Alloc token data in factory instead of alloc func * Fix variable name * Update RBS * Refactor log settable module * Use log settable for Whisper * Address deadlock * Make test follow change of log queue implementation * Refactor to make abort callback use the same way to parakeet's way * Remove redundant structs * Fix test name * Fix README * Add missing parallel transcription * Fix test for parakeet info * Remove removed params * Wait for logs dequeued * Fix instance variable name * Load etc feature * Remove unnecessary comment * Remove unnecessary thread safety check * Remove outdated comment * Skip downloading model if cache exists * Change Hugging Face URI for Parakeet models * Bump required Ruby version to 3.3 * Fix English --- .github/workflows/bindings-ruby.yml | 2 +- bindings/ruby/README.md | 31 + bindings/ruby/Rakefile | 17 +- bindings/ruby/ext/ruby_whisper.c | 116 ++-- bindings/ruby/ext/ruby_whisper.h | 135 ++++- bindings/ruby/ext/ruby_whisper_context.c | 51 +- bindings/ruby/ext/ruby_whisper_log_queue.c | 180 ++++++ bindings/ruby/ext/ruby_whisper_log_settable.h | 47 ++ bindings/ruby/ext/ruby_whisper_parakeet.c | 49 ++ .../ruby/ext/ruby_whisper_parakeet_context.c | 304 ++++++++++ .../ruby_whisper_parakeet_context_params.c | 117 ++++ .../ruby/ext/ruby_whisper_parakeet_model.c | 84 +++ .../ruby/ext/ruby_whisper_parakeet_params.c | 548 ++++++++++++++++++ .../ruby/ext/ruby_whisper_parakeet_segment.c | 157 +++++ .../ruby/ext/ruby_whisper_parakeet_token.c | 188 ++++++ .../ext/ruby_whisper_parakeet_transcribe.cpp | 58 ++ bindings/ruby/ext/ruby_whisper_params.c | 117 ++-- bindings/ruby/ext/ruby_whisper_segment.c | 12 +- bindings/ruby/ext/ruby_whisper_transcribe.cpp | 62 +- bindings/ruby/lib/whisper/context.rb | 15 - bindings/ruby/lib/whisper/log_settable.rb | 36 ++ bindings/ruby/lib/whisper/model/uri.rb | 14 +- bindings/ruby/lib/whisper/output.rb | 74 +++ bindings/ruby/lib/whisper/segment.rb | 58 -- bindings/ruby/sig/whisper.rbs | 369 +++++++++++- bindings/ruby/test/helper.rb | 2 + bindings/ruby/test/test_callback.rb | 1 + bindings/ruby/test/test_parakeet.rb | 28 + bindings/ruby/test/test_parakeet_callback.rb | 107 ++++ bindings/ruby/test/test_parakeet_context.rb | 116 ++++ .../ruby/test/test_parakeet_context_params.rb | 24 + bindings/ruby/test/test_parakeet_model.rb | 21 + bindings/ruby/test/test_parakeet_params.rb | 78 +++ bindings/ruby/test/test_parakeet_segment.rb | 42 ++ bindings/ruby/test/test_parakeet_token.rb | 73 +++ bindings/ruby/test/test_vad_segment.rb | 2 +- bindings/ruby/test/test_whisper.rb | 1 + bindings/ruby/whispercpp.gemspec | 2 +- 38 files changed, 3005 insertions(+), 333 deletions(-) create mode 100644 bindings/ruby/ext/ruby_whisper_log_queue.c create mode 100644 bindings/ruby/ext/ruby_whisper_log_settable.h create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet.c create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet_context.c create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet_context_params.c create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet_model.c create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet_params.c create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet_segment.c create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet_token.c create mode 100644 bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp delete mode 100644 bindings/ruby/lib/whisper/context.rb create mode 100644 bindings/ruby/lib/whisper/log_settable.rb create mode 100644 bindings/ruby/lib/whisper/output.rb delete mode 100644 bindings/ruby/lib/whisper/segment.rb create mode 100644 bindings/ruby/test/test_parakeet.rb create mode 100644 bindings/ruby/test/test_parakeet_callback.rb create mode 100644 bindings/ruby/test/test_parakeet_context.rb create mode 100644 bindings/ruby/test/test_parakeet_context_params.rb create mode 100644 bindings/ruby/test/test_parakeet_model.rb create mode 100644 bindings/ruby/test/test_parakeet_params.rb create mode 100644 bindings/ruby/test/test_parakeet_segment.rb create mode 100644 bindings/ruby/test/test_parakeet_token.rb diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml index 80a243e4c98..8cdb7a810f7 100644 --- a/.github/workflows/bindings-ruby.yml +++ b/.github/workflows/bindings-ruby.yml @@ -27,6 +27,6 @@ jobs: steps: - uses: ruby/setup-ruby@afeafc3d1ab54a631816aba4c914a0081c12ff2f # v1.310.0 with: - ruby-version: '3.2' + ruby-version: '3.3' - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - run: rake test diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 07b81830c58..7f6b7d92c09 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -396,6 +396,37 @@ whisper .full(Whisper::Params.new, samples) ``` +### Parakeet ### + +whispercpp gem now supports NVIDIA's ASR model Parakeet. + +If you want to use Parakeet instead of Whisper, the API should feel familiar. +In most cases, replace `Whisper::Context` and `Whisper::Params` with `Whisper::Parakeet::Context` and `Whisper::Parakeet::Params`, then use `#transcribe`, `#full`, `#each_segment`, and `#each_token` in the same way. + +```ruby +require "whisper" + +# It's useful to assign Whisper::Parakeet to top-level Parakeet constant unless you use Parakeet gem. +Parakeet = Whisper::Parakeet + +parakeet = Parakeet::Context.new("path/to/model") + +params = Parakeet::Params.new( + no_context: true +) + +parakeet + .transcribe("path/to/audio.wav", params) + .each_segment do |segment| + puts "[#{segment.start_time} --> #{segment.end_time}] #{segment.text}" + end +``` + +The main differences are: + +* Namespace is `Whisper::Parakeet`. +* Parakeet also supports `on_new_token` / `new_token_callback` in addition to segment and progress callbacks. + Custom context params --------------------- diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile index 7b521b3bdfa..2327651a06a 100644 --- a/bindings/ruby/Rakefile +++ b/bindings/ruby/Rakefile @@ -84,6 +84,21 @@ else end end +TEST_PARAKEET_MODEL = "test/fixtures/for-tests-ggml-parakeet-tdt.bin" +TEST_PARAKEET_MODEL_SRC = File.expand_path(File.join(__dir__, "..", "..", "models", "for-tests-ggml-parakeet-tdt.bin")) +TEST_PARAKEET_MODEL_DIR = TEST_PARAKEET_MODEL.pathmap("%d") +directory TEST_PARAKEET_MODEL_DIR +if File.exist? TEST_PARAKEET_MODEL_SRC + file TEST_PARAKEET_MODEL => [TEST_PARAKEET_MODEL_SRC, TEST_PARAKEET_MODEL_DIR] do |t| + symlink t.source, t.name + end +else + require "open-uri" + file TEST_PARAKEET_MODEL => TEST_PARAKEET_MODEL_DIR do |t| + File.write t.name, URI("https://github.com/ggml-org/whisper.cpp/raw/refs/heads/master/models/for-tests-ggml-parakeet-tdt.bin").read + end +end + TEST_MEMORY_VIEW = "test/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}" file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t| chdir "test/jfk_reader" do @@ -93,4 +108,4 @@ file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t| end CLEAN.include TEST_MEMORY_VIEW -task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO] +task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO, TEST_PARAKEET_MODEL] diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index 56fceb1c894..7941b1a99dd 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -1,19 +1,29 @@ #include "ruby_whisper.h" VALUE mWhisper; +VALUE mLogSettable; VALUE mVAD; +VALUE mParakeet; VALUE cContext; VALUE cParams; VALUE cVADContext; VALUE cVADParams; VALUE cVADSegments; VALUE cVADSegment; +VALUE cParakeetContext; +VALUE cParakeetContextParams; +VALUE cParakeetParams; +VALUE cParakeetSegment; +VALUE cParakeetModel; VALUE eError; VALUE cSegment; VALUE cToken; VALUE cModel; +VALUE mOutputContext; +VALUE mOutputSegment; + ID id_to_s; ID id_call; ID id___method__; @@ -27,9 +37,11 @@ ID id_pre_converted_models; ID id_coreml_compiled_models; ID id_cache; ID id_n_processors; - -static bool is_log_callback_finalized = false; -static bool is_ruby_log_callback_present = false; +ID id_extended; +ID id_start_log_callback_thread; +ID id_log_callback_thread; +ID id_alive_p; +ID id_join; // High level API extern VALUE ruby_whisper_segment_allocate(VALUE klass); @@ -45,8 +57,13 @@ extern void init_ruby_whisper_vad_params(VALUE *mVAD); extern void init_ruby_whisper_vad_context(VALUE *mVAD); extern void init_ruby_whisper_vad_segment(VALUE *mVAD); extern void init_ruby_whisper_vad_segments(VALUE *mVAD); +extern void init_ruby_whisper_parakeet(VALUE *mWhisper); extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context); +static ruby_whisper_log_queue whisper_log_queue; + +LOG_SETTABLE_SETUP(whisper_log_queue, mWhisper, whisper_log_set) + /* * call-seq: * lang_max_id -> Integer @@ -102,79 +119,6 @@ static VALUE ruby_whisper_s_system_info_str(VALUE self) { return rb_str_new2(whisper_print_system_info()); } -static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) { - is_log_callback_finalized = true; - return Qnil; -} - -typedef struct { - int level; - const char * buffer; -} call_log_callbacks_args; - -static void* -call_log_callbacks(void *v_args) { - VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); - if (NIL_P(log_callback)) { - return NULL; - } - - call_log_callbacks_args *args = (call_log_callbacks_args *)v_args; - VALUE user_data = rb_iv_get(mWhisper, "user_data"); - rb_funcall(log_callback, id_call, 3, INT2NUM(args->level), rb_str_new2(args->buffer), user_data); - - return NULL; -} - -static void -ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) { - if (is_log_callback_finalized) { - return; - } - if (!is_ruby_log_callback_present) { - return; - } - - call_log_callbacks_args args = { - level, - buffer, - }; - if (ruby_thread_has_gvl_p()) { - call_log_callbacks((void *)&args); - } else { - rb_thread_call_with_gvl(call_log_callbacks, (void *)&args); - } -} - -/* - * call-seq: - * log_set ->(level, buffer, user_data) { ... }, user_data -> nil - */ -static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) { - VALUE old_callback = rb_iv_get(self, "log_callback"); - if (!NIL_P(old_callback)) { - rb_undefine_finalizer(old_callback); - } - - rb_iv_set(self, "log_callback", log_callback); - rb_iv_set(self, "user_data", user_data); - - if (!NIL_P(log_callback)) { - VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback")); - rb_define_finalizer(log_callback, finalize_log_callback); - } - - if (NIL_P(log_callback)) { - whisper_log_set(NULL, NULL); - is_ruby_log_callback_present = false; - } else { - whisper_log_set(ruby_whisper_log_callback, NULL); - is_ruby_log_callback_present = true; - } - - return Qnil; -} - void Init_whisper() { id_to_s = rb_intern("to_s"); id_call = rb_intern("call"); @@ -189,9 +133,19 @@ void Init_whisper() { id_coreml_compiled_models = rb_intern("coreml_compiled_models"); id_cache = rb_intern("cache"); id_n_processors = rb_intern("n_processors"); + id_extended = rb_intern("extended"); + id_start_log_callback_thread = rb_intern("start_log_callback_thread"); + id_log_callback_thread = rb_intern("@log_callback_thread"); + id_alive_p = rb_intern("alive?"); + id_join = rb_intern("join"); mWhisper = rb_define_module("Whisper"); + rb_require("whisper/log_settable"); + mLogSettable = rb_path2class("Whisper::LogSettable"); mVAD = rb_define_module_under(mWhisper, "VAD"); + rb_require("whisper/output"); + mOutputContext = rb_path2class("Whisper::Output::Context"); + mOutputSegment = rb_path2class("Whisper::Output::Segment"); rb_define_const(mWhisper, "VERSION", rb_str_new2(whisper_version())); rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE)); @@ -222,8 +176,8 @@ void Init_whisper() { rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1); rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1); rb_define_singleton_method(mWhisper, "system_info_str", ruby_whisper_s_system_info_str, 0); - rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2); - rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1); + + LOG_SETTABLE_INIT(whisper_log_queue, mWhisper) cContext = init_ruby_whisper_context(&mWhisper); init_ruby_whisper_context_params(&cContext); @@ -236,8 +190,10 @@ void Init_whisper() { init_ruby_whisper_vad_segment(&mVAD); init_ruby_whisper_vad_segments(&mVAD); init_ruby_whisper_vad_context(&mVAD); + init_ruby_whisper_parakeet(&mWhisper); - rb_require("whisper/context"); - rb_require("whisper/segment"); rb_require("whisper/model/uri"); + + rb_include_module(cContext, mOutputContext); + rb_include_module(cSegment, mOutputSegment); } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index ba4d8b6fbcc..10e90674953 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -5,8 +5,12 @@ #include <ruby/version.h> #include <ruby/util.h> #include <ruby/thread.h> +#include <ruby/thread_native.h> +#include <ruby/atomic.h> #include <ruby/memory_view.h> #include "whisper.h" +#include "parakeet.h" +#include "ruby_whisper_log_settable.h" #if RUBY_API_VERSION_MAJOR < 4 // Exists but not declared as public API @@ -20,13 +24,28 @@ typedef struct { VALUE callbacks; } ruby_whisper_callback_container; -typedef struct { - VALUE *context; - VALUE user_data; - VALUE callback; - VALUE callbacks; - bool is_interrupted; -} ruby_whisper_abort_callback_container; +typedef struct ruby_whisper_abort_callback_user_data { + volatile rb_atomic_t is_interrupted; + ruby_whisper_callback_container *callback_container; +} ruby_whisper_abort_callback_user_data; + +typedef struct ruby_whisper_log { + enum ggml_log_level level; + char *text; + size_t length; + size_t capacity; +} ruby_whisper_log; + +typedef struct ruby_whisper_log_queue { + rb_nativethread_lock_t lock; + rb_nativethread_cond_t cond; + bool is_open; + + size_t head; + size_t tail; + size_t size; + ruby_whisper_log *logs; +} ruby_whisper_log_queue; typedef struct { struct whisper_context *context; @@ -42,7 +61,7 @@ typedef struct { ruby_whisper_callback_container *new_segment_callback_container; ruby_whisper_callback_container *progress_callback_container; ruby_whisper_callback_container *encoder_begin_callback_container; - ruby_whisper_abort_callback_container *abort_callback_container; + ruby_whisper_callback_container *abort_callback_container; VALUE vad_params; } ruby_whisper_params; @@ -84,6 +103,63 @@ typedef struct parsed_samples_t { bool memview_exported; } parsed_samples_t; +typedef struct { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; +} ruby_whisper_full_args; + +typedef struct ruby_whisper_full_parallel_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; + int n_processors; +} ruby_whisper_full_parallel_args; + +typedef struct { + struct parakeet_full_params params; + ruby_whisper_callback_container *new_segment_callback_container; + ruby_whisper_callback_container *new_token_callback_container; + ruby_whisper_callback_container *progress_callback_container; + ruby_whisper_callback_container *encoder_begin_callback_container; + ruby_whisper_callback_container *abort_callback_container; +} ruby_whisper_parakeet_params; + +typedef struct { + struct parakeet_context_params params; +} ruby_whisper_parakeet_context_params; + +typedef struct { + struct parakeet_context *context; +} ruby_whisper_parakeet_context; + +typedef struct { + VALUE context; + int index; +} ruby_whisper_parakeet_segment; + +typedef struct { + parakeet_token_data *token_data; + VALUE text; +} ruby_whisper_parakeet_token; + +typedef struct { + VALUE context; +} ruby_whisper_parakeet_model; + +extern ID id_extended; +extern ID id_log_callback_thread; +extern ID id_start_log_callback_thread; +extern ID id_alive_p; +extern ID id_join; +extern void ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text); +extern VALUE ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue); + #define GetContext(obj, rw) do { \ TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \ if ((rw)->context == NULL) { \ @@ -120,4 +196,47 @@ typedef struct parsed_samples_t { } \ } while (0) +#define GetParakeetContextParams(obj, rwpcp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, (rwpcp)); \ +} while (0) + +#define GetParakeetContext(obj, rwpc) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, (rwpc)); \ + if ((rwpc)->context == NULL) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetParams(obj, rwpp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, (rwpp)); \ + if (!(rwpp)->new_segment_callback_container || \ + !(rwpp)->new_token_callback_container || \ + !(rwpp)->progress_callback_container || \ + !(rwpp)->encoder_begin_callback_container || \ + !(rwpp)->abort_callback_container) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetSegment(obj, rwps) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, (rwps)); \ + if (!(rwps)->context) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetToken(obj, rwpt) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, (rwpt)); \ + if (!(rwpt)->token_data) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetModel(obj, rwpm) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, (rwpm)); \ + if (NIL_P((rwpm)->context)) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + #endif diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index 26058fc07e6..9e5fc33e726 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -28,7 +28,7 @@ extern const rb_data_type_t ruby_whisper_context_params_type; extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self); extern VALUE rb_whisper_model_s_new(VALUE context); extern VALUE rb_whisper_segment_s_new(VALUE context, int index); -extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors); +extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data); ID transcribe_option_names[1]; @@ -38,21 +38,6 @@ typedef struct fill_samples_args { int n_samples; } fill_samples_args; -typedef struct full_args { - VALUE *context; - VALUE *params; - float *samples; - int n_samples; -} full_args; - -typedef struct full_parallel_args { - VALUE *context; - VALUE *params; - float *samples; - int n_samples; - int n_processors; -} full_parallel_args; - typedef struct full_without_gvl_args { struct whisper_context *context; struct whisper_full_params *params; @@ -71,7 +56,7 @@ typedef struct full_parallel_without_gvl_args { } full_parallel_without_gvl_args; typedef struct full_ubf_args { - ruby_whisper_abort_callback_container *abort_callback_container; + ruby_whisper_abort_callback_user_data *abort_callback_user_data; } full_ubf_args; static void @@ -379,7 +364,7 @@ fill_samples(VALUE rb_args) return Qnil; } -struct parsed_samples_t +parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples) { bool memview_available = rb_memory_view_available_p(*samples); @@ -480,20 +465,24 @@ full_ubf(void *rb_args) { full_ubf_args *args = (full_ubf_args *)rb_args; - args->abort_callback_container->is_interrupted = true; + RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1); } -static VALUE +VALUE full_body(VALUE rb_args) { - full_args *args = (full_args *)rb_args; + ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args; ruby_whisper *rw; ruby_whisper_params *rwp; GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context, 1); + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + prepare_transcription(rwp, args->context, 1, &abort_callback_user_data); struct full_without_gvl_args full_without_gvl_args = { rw->context, @@ -503,7 +492,7 @@ full_body(VALUE rb_args) 0, }; full_ubf_args full_ubf_args = { - rwp->abort_callback_container, + &abort_callback_user_data, }; rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args); return INT2NUM(full_without_gvl_args.result); @@ -529,7 +518,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) VALUE n_samples = argc == 2 ? Qnil : argv[2]; struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); - full_args args = { + ruby_whisper_full_args args = { &self, &argv[0], parsed.samples, @@ -552,17 +541,21 @@ full_parallel_without_gvl(void *rb_args) return NULL; } -static VALUE +VALUE full_parallel_body(VALUE rb_args) { - full_parallel_args *args = (full_parallel_args *)rb_args; + ruby_whisper_full_parallel_args *args = (ruby_whisper_full_parallel_args *)rb_args; ruby_whisper *rw; ruby_whisper_params *rwp; GetContext(*args->context, rw); TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - prepare_transcription(rwp, args->context, args->n_processors); + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + prepare_transcription(rwp, args->context, args->n_processors, &abort_callback_user_data); struct full_parallel_without_gvl_args full_parallel_without_gvl_args = { rw->context, @@ -573,7 +566,7 @@ full_parallel_body(VALUE rb_args) 0, }; full_ubf_args full_ubf_args = { - rwp->abort_callback_container, + &abort_callback_user_data, }; rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args); return INT2NUM(full_parallel_without_gvl_args.result); @@ -613,7 +606,7 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) break; } struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); - const full_parallel_args args = { + const ruby_whisper_full_parallel_args args = { &self, &argv[0], parsed.samples, diff --git a/bindings/ruby/ext/ruby_whisper_log_queue.c b/bindings/ruby/ext/ruby_whisper_log_queue.c new file mode 100644 index 00000000000..6558a339c6f --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_log_queue.c @@ -0,0 +1,180 @@ +#include "ruby_whisper.h" + +#define LOG_QUEUE_CAPACITY 256 +#define LOG_DEFAULT_CAPACITY 1024 + +void +ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_initialize(&log_queue->lock); + rb_native_cond_initialize(&log_queue->cond); + log_queue->head = 0; + log_queue->tail = 0; + log_queue->size = 0; + log_queue->is_open = true; + log_queue->logs = ALLOC_N(ruby_whisper_log, LOG_QUEUE_CAPACITY); + for (size_t i = 0; i < LOG_QUEUE_CAPACITY; i++) { + // we cannot call Ruby API like ALLOC_N because this slot may be realloced without GVL + // this doesn't be freed because log queue lives until the end of process + char *slot = malloc(sizeof(char) * LOG_QUEUE_CAPACITY); + if (!slot) { + rb_raise(rb_eRuntimeError, "Could not allocate memory for log text"); + } + ruby_whisper_log log = { + 0, + slot, + 0, + LOG_QUEUE_CAPACITY, + }; + log_queue->logs[i] = log; + } +} + +void +ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + log_queue->is_open = true; + + rb_native_cond_signal(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); +} + +void +ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + log_queue->is_open = false; + rb_native_cond_broadcast(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); +} + +static size_t +calc_enough_cap(size_t len) +{ + size_t quot = len / LOG_DEFAULT_CAPACITY; + size_t rem = len % LOG_DEFAULT_CAPACITY; + + return sizeof(char) * (rem == 0 ? quot : quot + 1) * LOG_DEFAULT_CAPACITY; +} + +void +ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + if (!log_queue->is_open) { + rb_nativethread_lock_unlock(&log_queue->lock); + return; + } + + size_t len = strlen(text); + ruby_whisper_log *log = &log_queue->logs[log_queue->head]; + if (len > log->capacity) { + size_t new_cap = calc_enough_cap(len); + // we cannot call Ruby API like REALLOC_N because this function is called without GVL + char *slot = realloc(log->text, new_cap); + if (!slot) { + rb_nativethread_lock_unlock(&log_queue->lock); + return; + } + log->text = slot; + log->capacity = new_cap; + } + // we cannot call Ruby API like MEMCPY because this function is called without GVL + memcpy(log->text, text, sizeof(char) * len); + log->length = len; + log->level = level; + log_queue->head = (log_queue->head + 1) % LOG_QUEUE_CAPACITY; + bool is_full = log_queue->size >= LOG_QUEUE_CAPACITY; + log_queue->size = is_full ? LOG_QUEUE_CAPACITY : log_queue->size + 1; + if (is_full) { + log_queue->tail = log_queue->head; + } + + rb_native_cond_signal(&log_queue->cond); + rb_nativethread_lock_unlock(&log_queue->lock); +} + +static void* +ruby_whisper_log_queue_wait(void *args) +{ + ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args; + + rb_native_cond_wait(&log_queue->cond, &log_queue->lock); + rb_nativethread_lock_unlock(&log_queue->lock); + + return NULL; +} + +static void +ruby_whisper_log_queue_wait_ubf(void *args) +{ + ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args; + + rb_native_cond_broadcast(&log_queue->cond); +} + +typedef struct { + enum ggml_log_level level; + size_t length; + char *text; +} log_snapshot; + +VALUE +ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue) +{ + log_snapshot logs[LOG_QUEUE_CAPACITY]; + + rb_nativethread_lock_lock(&log_queue->lock); + + while (log_queue->size == 0 && log_queue->is_open) { + rb_thread_call_without_gvl(ruby_whisper_log_queue_wait, (void *)log_queue, ruby_whisper_log_queue_wait_ubf, (void *)log_queue); + rb_nativethread_lock_lock(&log_queue->lock); + } + + if (log_queue->size == 0 && !log_queue->is_open) { + rb_native_cond_broadcast(&log_queue->cond); + rb_nativethread_lock_unlock(&log_queue->lock); + return Qnil; + } + + size_t size = log_queue->size; + ruby_whisper_log *log; + size_t i; + for (i = 0; i < size; i++) { + log = &log_queue->logs[(log_queue->tail + i) % LOG_QUEUE_CAPACITY]; + logs[i].level = log->level; + logs[i].length = log->length; + char *text = malloc(log->length); + if (!text) { + logs[i].text = NULL; + continue; + } + logs[i].text = text; + memcpy(logs[i].text, log->text, log->length); + } + log_queue->size = 0; + log_queue->tail = log_queue->head; + + rb_native_cond_signal(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); + + VALUE rb_logs = rb_ary_new2(size); + VALUE rb_text; + for (i = 0; i < size; i++) { + if (!logs[i].text) { + continue; + } + rb_text = rb_str_new(logs[i].text, logs[i].length); + free(logs[i].text); + rb_ary_push(rb_logs, rb_ary_new3(2, INT2NUM(logs[i].level), rb_text)); + } + + return rb_logs; +} diff --git a/bindings/ruby/ext/ruby_whisper_log_settable.h b/bindings/ruby/ext/ruby_whisper_log_settable.h new file mode 100644 index 00000000000..b98fbac826b --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_log_settable.h @@ -0,0 +1,47 @@ +#ifndef RUBY_WHISPER_LOG_SETTABLE_H +#define RUBY_WHISPER_LOG_SETTABLE_H + +#define LOG_SETTABLE_SETUP(log_queue, mod, log_set) \ + static VALUE \ + ruby_whisper_##log_queue##_s_drain_logs(VALUE self) \ + { \ + return ruby_whisper_log_queue_drain(&log_queue); \ + } \ + static void \ + ruby_whisper_##log_queue##_log_callback(enum ggml_log_level level, const char *text, void *user_data) \ + { \ + ruby_whisper_log_queue_enqueue(&log_queue, level, text); \ + } \ + static VALUE \ + ruby_whisper_##log_queue##_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) \ + { \ + rb_iv_set(self, "@log_callback", log_callback); \ + rb_iv_set(self, "@log_callback_user_data", user_data); \ + if (NIL_P(log_callback)) { \ + log_set(NULL, NULL); \ + } else { \ + ruby_whisper_log_queue_open(&log_queue); \ + rb_funcall((mod), id_start_log_callback_thread, 0); \ + log_set(ruby_whisper_##log_queue##_log_callback, NULL); \ + } \ + return Qnil; \ + } \ + static void \ + ruby_whisper_##log_queue##_end_proc(VALUE args) \ + { \ + ruby_whisper_log_queue_close(&log_queue); \ + VALUE log_callback_thread = rb_ivar_get(mod, id_log_callback_thread); \ + if (!NIL_P(log_callback_thread) && RTEST(rb_funcall(log_callback_thread, id_alive_p, 0))) { \ + rb_funcall(log_callback_thread, id_join, 0); \ + } \ + } + +#define LOG_SETTABLE_INIT(log_queue, mod) \ + ruby_whisper_log_queue_initialize(&log_queue); \ + rb_define_singleton_method(mod, "drain_logs", ruby_whisper_##log_queue##_s_drain_logs, 0); \ + rb_define_singleton_method(mod, "log_set", ruby_whisper_##log_queue##_s_log_set, 2); \ + rb_set_end_proc(ruby_whisper_##log_queue##_end_proc, Qnil); \ + rb_extend_object(mod, mLogSettable); \ + rb_funcall(mLogSettable, id_extended, 1, mod); + +#endif diff --git a/bindings/ruby/ext/ruby_whisper_parakeet.c b/bindings/ruby/ext/ruby_whisper_parakeet.c new file mode 100644 index 00000000000..d69369401d0 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet.c @@ -0,0 +1,49 @@ +#include "ruby_whisper.h" +#include <stdio.h> +#include <unistd.h> + +extern VALUE mParakeet; +extern VALUE mLogSettable; +extern VALUE cParakeetContext; +extern VALUE cParakeetSegment; +extern VALUE mOutputContext; +extern VALUE mOutputSegment; + +extern void init_ruby_whisper_parakeet_params(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_token(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_segment(VALUE *mParakeet); +extern VALUE init_ruby_whisper_parakeet_context(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext); +extern void init_ruby_whisper_parakeet_model(VALUE *mParakeet); + +static ruby_whisper_log_queue parakeet_log_queue; + +LOG_SETTABLE_SETUP(parakeet_log_queue, mParakeet, parakeet_log_set) + +static VALUE +ruby_whisper_parakeet_s_system_info_str(VALUE self) +{ + return rb_str_new2(parakeet_print_system_info()); +} + +void +init_ruby_whisper_parakeet(VALUE *mWhisper) +{ + mParakeet = rb_define_module_under(*mWhisper, "Parakeet"); + + rb_define_const(mParakeet, "VERSION", rb_str_new2(parakeet_version())); + + LOG_SETTABLE_INIT(parakeet_log_queue, mParakeet) + + rb_define_singleton_method(mParakeet, "system_info_str", ruby_whisper_parakeet_s_system_info_str, 0); + + init_ruby_whisper_parakeet_params(&mParakeet); + init_ruby_whisper_parakeet_token(&mParakeet); + init_ruby_whisper_parakeet_segment(&mParakeet); + cParakeetContext = init_ruby_whisper_parakeet_context(&mParakeet); + init_ruby_whisper_parakeet_context_params(&cParakeetContext); + init_ruby_whisper_parakeet_model(&mParakeet); + + rb_include_module(cParakeetContext, mOutputContext); + rb_include_module(cParakeetSegment, mOutputSegment); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context.c b/bindings/ruby/ext/ruby_whisper_parakeet_context.c new file mode 100644 index 00000000000..b4a2fc5c4b7 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context.c @@ -0,0 +1,304 @@ +#include "ruby_whisper.h" + +#define ITERATE_SEGMENT_ATTRS(ITERATOR) \ + ITERATOR(get_segment_t0, LONG) \ + ITERATOR(get_segment_t1, LONG) \ + ITERATOR(get_segment_text, STRING) \ + ITERATOR(n_tokens, INT) + +#define ITERATE_TOKEN_ATTRS(ITERATOR) \ + ITERATOR(get_token_text, STRING) \ + ITERATOR(get_token_id, INT) \ + ITERATOR(get_token_p, FLOAT) + +#define VAL_FROM_LONG(v) LONG2NUM(v) +#define VAL_FROM_STRING(v) rb_utf8_str_new_cstr(v) +#define VAL_FROM_INT(v) INT2NUM(v) +#define VAL_FROM_FLOAT(v) DBL2NUM(v) +#define READER(type) VAL_FROM_##type + +extern ID id_to_s; +extern ID id___method__; +extern ID id_to_enum; +extern ID id_new; + +extern VALUE cParakeetContext; +extern VALUE eError; + +extern VALUE ruby_whisper_normalize_model_path(VALUE model_path); +extern VALUE ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params); +extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index); +extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples); +extern VALUE release_samples(VALUE rb_parsed_args); +extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data); +extern rb_data_type_t ruby_whisper_parakeet_params_type; +extern rb_data_type_t ruby_whisper_parakeet_context_params_type; +extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data); +extern VALUE ruby_whisper_parakeet_model_s_new(VALUE context); + +static void +ruby_whisper_parakeet_context_free(void *p) +{ + ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p; + if (rwpc->context) { + parakeet_free(rwpc->context); + rwpc->context = NULL; + } + xfree(rwpc); +} + +static size_t +ruby_whisper_parakeet_context_memsize(const void *p) +{ + ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p; + if (!rwpc) { + return 0; + } + size_t size = sizeof(*rwpc); + return size; +} + +const rb_data_type_t ruby_whisper_parakeet_context_type = { + "ruby_whisper_parakeet_context", + {0, ruby_whisper_parakeet_context_free, ruby_whisper_parakeet_context_memsize,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_context_allocate(VALUE klass) +{ + ruby_whisper_parakeet_context *rwpc; + + VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc); + rwpc->context = NULL; + + return obj; +} + +typedef struct { + struct parakeet_context **context; + char *model_path; + struct parakeet_context_params params; +} ruby_whisper_parakeet_context_init_args; + +static void* +ruby_whisper_parakeet_context_init_without_gvl(void *args) +{ + ruby_whisper_parakeet_context_init_args *init_args = (ruby_whisper_parakeet_context_init_args *)args; + *init_args->context = parakeet_init_from_file_with_params(init_args->model_path, init_args->params); + return NULL; +} + +static VALUE +ruby_whisper_parakeet_context_initialize(int argc, VALUE *argv, VALUE self) +{ + ruby_whisper_parakeet_context *rwpc; + VALUE model_path; + VALUE context_params; + struct parakeet_context_params params; + + rb_scan_args(argc, argv, "11", &model_path, &context_params); + TypedData_Get_Struct(self, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc); + + model_path = ruby_whisper_normalize_model_path(model_path); + if (!rb_respond_to(model_path, id_to_s)) { + rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Parakeet::Context"); + } + if (NIL_P(context_params)) { + params = parakeet_context_default_params(); + } else { + ruby_whisper_parakeet_context_params *rwpcp; + GetParakeetContextParams(context_params, rwpcp); + params = rwpcp->params; + } + ruby_whisper_parakeet_context_init_args init_args = { + &rwpc->context, + StringValueCStr(model_path), + params, + }; + rb_thread_call_without_gvl(ruby_whisper_parakeet_context_init_without_gvl, (void *)&init_args, NULL, NULL); + if (rwpc->context == NULL) { + rb_raise(rb_eRuntimeError, "Failed to load model"); + } + + return Qnil; +} + +static VALUE +ruby_whisper_parakeet_context_full_n_segments(VALUE self) +{ + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + + return INT2NUM(parakeet_full_n_segments(rwpc->context)); +} + +#define DEF_SEGMENT_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment) \ + { \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(self, rwpc); \ + return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment))); \ + } + +ITERATE_SEGMENT_ATTRS(DEF_SEGMENT_ATTR) + +#define DEF_TOKEN_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment, VALUE i_token) \ + { \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(self, rwpc); \ + return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token))); \ + } + +ITERATE_TOKEN_ATTRS(DEF_TOKEN_ATTR) + +static VALUE +ruby_whisper_parakeet_context_full_get_token_data(VALUE self, VALUE i_segment, VALUE i_token) +{ + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + parakeet_token_data token_data = parakeet_full_get_token_data(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token)); + + return ruby_whisper_parakeet_token_s_from_token_data(rwpc->context, &token_data); +} + +static VALUE +ruby_whisper_parakeet_context_each_segment(VALUE self) +{ + if (!rb_block_given_p()) { + const VALUE method_name = rb_funcall(self, id___method__, 0); + return rb_funcall(self, id_to_enum, 1, method_name); + } + + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + + const int n_segments = parakeet_full_n_segments(rwpc->context); + for (int i = 0; i < n_segments; ++i) { + rb_yield(ruby_whisper_parakeet_segment_init(self, i)); + } + + return self; +} + +typedef struct { + struct parakeet_context *context; + struct parakeet_full_params *params; + float *samples; + int n_samples; + int result; +} parakeet_full_without_gvl_args; + +static void* +parakeet_full_without_gvl(void *rb_args) +{ + parakeet_full_without_gvl_args *args = (parakeet_full_without_gvl_args *)rb_args; + args->result = parakeet_full(args->context, *args->params, args->samples, args->n_samples); + + return NULL; +} + +typedef struct { + ruby_whisper_abort_callback_user_data *abort_callback_user_data; +} parakeet_full_ubf_args; + +static void +parakeet_full_ubf(void *rb_args) +{ + parakeet_full_ubf_args *args = (parakeet_full_ubf_args *)rb_args; + + RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1); +} + +VALUE +ruby_whisper_parakeet_context_full_body(VALUE rb_args) +{ + ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args; + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(*args->context, rwpc); + ruby_whisper_parakeet_params *rwpp; + GetParakeetParams(*args->params, rwpp); + + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + ruby_whisper_parakeet_prepare_transcription(rwpp, args->context, &abort_callback_user_data); + + parakeet_full_without_gvl_args full_without_gvl_args = { + rwpc->context, + &rwpp->params, + args->samples, + args->n_samples, + 0 + }; + parakeet_full_ubf_args full_ubf_args = { + &abort_callback_user_data, + }; + rb_thread_call_without_gvl(parakeet_full_without_gvl, (void *)&full_without_gvl_args, parakeet_full_ubf, (void *)&full_ubf_args); + + return INT2NUM(full_without_gvl_args.result); +} + +static VALUE +ruby_whisper_parakeet_context_full(int argc, VALUE *argv, VALUE self) +{ + if (argc < 2 || argc > 3) { + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + } + + VALUE n_samples = argc == 2 ? Qnil : argv[2]; + + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + ruby_whisper_full_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + }; + VALUE rb_result = rb_ensure(ruby_whisper_parakeet_context_full_body, (VALUE)&args, release_samples, (VALUE)&parsed); + const int result = NUM2INT(rb_result); + if (result == 0) { + return self; + } else { + rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result)); + } +} + +static VALUE +ruby_whisper_parakeet_context_get_model(VALUE self) +{ + return ruby_whisper_parakeet_model_s_new(self); +} + +VALUE +init_ruby_whisper_parakeet_context(VALUE *mParakeet) +{ + cParakeetContext = rb_define_class_under(*mParakeet, "Context", rb_cObject); + + rb_define_alloc_func(cParakeetContext, ruby_whisper_parakeet_context_allocate); + + rb_define_method(cParakeetContext, "initialize", ruby_whisper_parakeet_context_initialize, -1); + rb_define_method(cParakeetContext, "transcribe", ruby_whisper_parakeet_transcribe, 2); + rb_define_method(cParakeetContext, "full_n_segments", ruby_whisper_parakeet_context_full_n_segments, 0); + rb_define_method(cParakeetContext, "full_get_token_data", ruby_whisper_parakeet_context_full_get_token_data, 2); + rb_define_method(cParakeetContext, "model", ruby_whisper_parakeet_context_get_model, 0); + rb_define_method(cParakeetContext, "each_segment", ruby_whisper_parakeet_context_each_segment, 0); + rb_define_method(cParakeetContext, "full", ruby_whisper_parakeet_context_full, -1); + +#define REGISTER_SEGMENT_ATTR(name, type) \ + rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 1); + + ITERATE_SEGMENT_ATTRS(REGISTER_SEGMENT_ATTR) + +#define REGISTER_TOKEN_ATTR(name, type) \ + rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 2); + + ITERATE_TOKEN_ATTRS(REGISTER_TOKEN_ATTR) + + return cParakeetContext; +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c new file mode 100644 index 00000000000..38bd6d57ce1 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c @@ -0,0 +1,117 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(use_gpu, BOOL) \ + ITERATOR(gpu_device, INT) + +#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse) +#define VAL_TO_BOOL(v) (RTEST(v)) +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_TO_INT(v) (NUM2INT(v)) +#define READER(type) VAL_FROM_##type +#define WRITER(type) VAL_TO_##type + +#define DEF_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_context_params *rwpcp; \ + GetParakeetContextParams(self, rwpcp); \ + return READER(type)(rwpcp->params.name); \ + } \ + static VALUE \ + ruby_whisper_parakeet_context_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_context_params *rwpcp; \ + GetParakeetContextParams(self, rwpcp); \ + rwpcp->params.name = WRITER(type)(val); \ + return val; \ + } + +enum { +#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_CONTEXT_PARAMS_##name, + + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS +}; + +extern VALUE cParakeetContextParams; + +typedef VALUE (*param_writer_t)(VALUE, VALUE); + +static ID param_names[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS]; +static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS]; + +static size_t +ruby_whisper_parakeet_context_params_memsize(const void *p) +{ + if (!p) { + return 0; + } + return sizeof(ruby_whisper_parakeet_context_params); +} + +const rb_data_type_t ruby_whisper_parakeet_context_params_type = { + "ruby_whisper_parakeet_context_params", + {0, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_context_params_memsize,}, + 0, 0, + 0, +}; + +static VALUE +ruby_whisper_parakeet_context_params_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_context_params *rwpcp; + return TypedData_Make_Struct(klass, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp); +} + +static VALUE +ruby_whisper_parakeet_context_params_initialize(int argc, VALUE *argv, VALUE self) +{ + VALUE kw_hash; + VALUE values[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS] = {Qundef}; + VALUE value; + ruby_whisper_parakeet_context_params *rwpcp; + int i; + + TypedData_Get_Struct(self, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp); + rwpcp->params = parakeet_context_default_params(); + + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return Qnil; + } + + rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS, values); + for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS; i++) { + value = values[i]; + if (value == Qundef) { + continue; + } + param_writers[i](self, value); + } + + return Qnil; +} + +ITERATE_ATTRS(DEF_ATTR) + +void +init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext) +{ + cParakeetContextParams = rb_define_class_under(*cParakeetContext, "Params", rb_cObject); + + rb_define_alloc_func(cParakeetContextParams, ruby_whisper_parakeet_context_params_s_allocate); + + rb_define_method(cParakeetContextParams, "initialize", ruby_whisper_parakeet_context_params_initialize, -1); + + int i = 0; +#define REGISTER_ATTR(name, type) \ + param_names[i] = rb_intern(#name); \ + param_writers[i] = ruby_whisper_parakeet_context_params_set_##name; \ + rb_define_method(cParakeetContextParams, #name, ruby_whisper_parakeet_context_params_get_##name, 0); \ + rb_define_method(cParakeetContextParams, #name "=", ruby_whisper_parakeet_context_params_set_##name, 1); \ + i++; + + ITERATE_ATTRS(REGISTER_ATTR) +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_model.c b/bindings/ruby/ext/ruby_whisper_parakeet_model.c new file mode 100644 index 00000000000..dce43c688e7 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_model.c @@ -0,0 +1,84 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(n_vocab) \ + ITERATOR(n_audio_ctx) \ + ITERATOR(n_audio_state) \ + ITERATOR(n_audio_head) \ + ITERATOR(n_audio_layer) \ + ITERATOR(n_mels) \ + ITERATOR(ftype) + +extern rb_data_type_t ruby_whisper_parakeet_context_type; +extern VALUE cParakeetModel; + +static void +ruby_whisper_parakeet_model_mark(void *p) +{ + ruby_whisper_parakeet_model *rwpm = (ruby_whisper_parakeet_model *)p; + if (!NIL_P(rwpm->context)) { + rb_gc_mark(rwpm->context); + } +} + +static size_t +ruby_whisper_parakeet_model_memsize(const void *p) +{ + if (!p) { + return 0; + } + return sizeof(ruby_whisper_parakeet_model); +} + +static const rb_data_type_t ruby_whisper_parakeet_model_type = { + "ruby_whisper_parakeet_model", + {ruby_whisper_parakeet_model_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_model_memsize}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_model_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_model *rwpm; + VALUE model = TypedData_Make_Struct(klass, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm); + rwpm->context = Qnil; + + return model; +} + +VALUE +ruby_whisper_parakeet_model_s_new(VALUE context) +{ + const VALUE model = ruby_whisper_parakeet_model_s_allocate(cParakeetModel); + ruby_whisper_parakeet_model *rwpm; + TypedData_Get_Struct(model, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm); + rwpm->context = context; + return model; +} + +#define DEF_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_model_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_model *rwpm; \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetModel(self, rwpm); \ + GetParakeetContext(rwpm->context, rwpc); \ + return INT2NUM(parakeet_model_##name(rwpc->context)); \ + } + +ITERATE_ATTRS(DEF_ATTR) + +void +init_ruby_whisper_parakeet_model(VALUE *mParakeet) +{ + cParakeetModel = rb_define_class_under(*mParakeet, "Model", rb_cObject); + + rb_define_alloc_func(cParakeetModel, ruby_whisper_parakeet_model_s_allocate); + +#define REGISTER_ATTR(name) \ + rb_define_method(cParakeetModel, #name, ruby_whisper_parakeet_model_get_##name, 0); + + ITERATE_ATTRS(REGISTER_ATTR) +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_params.c new file mode 100644 index 00000000000..076e2a0cdfb --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_params.c @@ -0,0 +1,548 @@ +#include "ruby_whisper.h" + +#define ITERATE_PARAMS(ITERATOR) \ + ITERATOR(n_threads, INT) \ + ITERATOR(offset_ms, INT) \ + ITERATOR(duration_ms, INT) \ + ITERATOR(no_context, BOOL) \ + ITERATOR(audio_ctx, INT) + +#define ITERATE_NORMAL_CALLBACK_NAMES(ITERATOR, DATA) \ + ITERATOR(new_segment, DATA) \ + ITERATOR(new_token, DATA) \ + ITERATOR(progress, DATA) \ + ITERATOR(encoder_begin, DATA) + +#define ITERATE_NORMAL_CALLBACK_PARAM(name, ITERATOR) ITERATOR(name##_callback) +#define ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \ + ITERATE_NORMAL_CALLBACK_NAMES(ITERATE_NORMAL_CALLBACK_PARAM, ITERATOR) + +#define ITERATE_CALLBACK_PARAMS(ITERATOR) \ + ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \ + ITERATOR(abort_callback) + +enum { +#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_PARAM_##name, +#define DEF_IDX_CALLBACK(name) RUBY_WHISPER_PARAKEET_PARAM_##name, +#define DEF_IDX_USER_DATA(name) RUBY_WHISPER_PARAKEET_PARAM_##name##_user_data, + ITERATE_PARAMS(DEF_IDX) + ITERATE_CALLBACK_PARAMS(DEF_IDX_CALLBACK) + ITERATE_CALLBACK_PARAMS(DEF_IDX_USER_DATA) + + RUBY_WHISPER_PARAKEET_NUM_PARAMS +}; + +#define VAL_TO_INT(v) (NUM2INT(v)) +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_TO_BOOL(v) (RTEST(v)) +#define VAL_FROM_BOOL(v) (v ? Qtrue : Qfalse) + +extern VALUE cParakeetParams; +extern ID id_call; + +extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc); +extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void); +extern bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container); +extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index); +extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data); + +static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; +typedef VALUE (*param_writer_t)(VALUE, VALUE); +static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + int n_new; +} call_parakeet_new_segment_callbacks_args; + +static void* +call_parakeet_new_segment_callbacks(void *v_args) +{ + call_parakeet_new_segment_callbacks_args *args = (call_parakeet_new_segment_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->n_new), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + const int n_segments = parakeet_full_n_segments_from_state(args->state); + for (int i = args->n_new; i > 0; i--) { + int i_segment = n_segments - i; + VALUE segment = ruby_whisper_parakeet_segment_init(*container->context, i_segment); + for (int j = 0; j < n_callbacks; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + rb_funcall(cb, id_call, 1, segment); + } + } + + return NULL; +} + +static void +ruby_whisper_parakeet_new_segment_callback(struct parakeet_context *context, struct parakeet_state *state, int n_new, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_new_segment_callbacks_args args = { + container, + state, + n_new, + }; + rb_thread_call_with_gvl(call_parakeet_new_segment_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_context *context; + struct parakeet_state *state; + const parakeet_token_data *token_data; +} call_parakeet_new_token_callbacks_args; + +static void* +call_parakeet_new_token_callbacks(void *v_args) +{ + call_parakeet_new_token_callbacks_args *args = (call_parakeet_new_token_callbacks_args *)v_args; + VALUE token = Qnil; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data); + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, token, container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + if (NIL_P(token)) { + token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data); + } + for (int i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + rb_funcall(cb, id_call, 1, token); + } + + return NULL; +} + +static void +ruby_whisper_parakeet_new_token_callback(struct parakeet_context *context, struct parakeet_state *state, const parakeet_token_data *token_data, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_new_token_callbacks_args args = { + container, + context, + state, + token_data, + }; + rb_thread_call_with_gvl(call_parakeet_new_token_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + int progress; +} call_parakeet_progress_callbacks_args; + +static void* +call_parakeet_progress_callback(void *v_args) +{ + call_parakeet_progress_callbacks_args *args = (call_parakeet_progress_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->progress), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + for (long i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + rb_funcall(cb, id_call, 1, INT2NUM(args->progress)); + } + + return NULL; +} + +static void +ruby_whisper_parakeet_progress_callback(struct parakeet_context *context, struct parakeet_state *state, int progress, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_progress_callbacks_args args = { + container, + state, + progress, + }; + rb_thread_call_with_gvl(call_parakeet_progress_callback, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + bool is_continued; +} call_parakeet_encoder_begin_callbacks_args; + +static void* +call_parakeet_encoder_begin_callbacks(void *v_args) +{ + call_parakeet_encoder_begin_callbacks_args *args = (call_parakeet_encoder_begin_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; + + if (!NIL_P(container->callback)) { + result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + for (long i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + result = rb_funcall(cb, id_call, 0); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } + } + + return NULL; +} + +static bool +ruby_whisper_parakeet_encoder_begin_callback(struct parakeet_context *context, struct parakeet_state *state, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return true; + } + + call_parakeet_encoder_begin_callbacks_args args = { + container, + state, + true, + }; + rb_thread_call_with_gvl(call_parakeet_encoder_begin_callbacks, (void *)&args); + + return args.is_continued; +} + +typedef struct { + const ruby_whisper_callback_container *container; + bool is_interrupted; +} call_parakeet_abort_callbacks_args; + +static void* +call_parakeet_abort_callbacks(void *v_args) +{ + call_parakeet_abort_callbacks_args *args = (call_parakeet_abort_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; + + if (!NIL_P(container->callback)) { + result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (RTEST(result)) { + args->is_interrupted = true; + return NULL; + } + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + VALUE cb; + for (long i = 0; i < n_callbacks; i++) { + cb = rb_ary_entry(container->callbacks, i); + result = rb_funcall(cb, id_call, 0); + if (RTEST(result)) { + args->is_interrupted = true; + return NULL; + } + } + + return NULL; +} + +static bool +ruby_whisper_parakeet_abort_callback(void *user_data) +{ + ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data; + + int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted); + if (is_interrupted) { + return true; + } + + if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) { + return false; + } + + call_parakeet_abort_callbacks_args args = { + data->callback_container, + false, + }; + rb_thread_call_with_gvl(call_parakeet_abort_callbacks, (void *)&args); + + return args.is_interrupted; +} + +#define CALLBACK_CONTAINER_NAME(name) name ## _container + +void +ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) +{ +#define PARAM_NAME(name) name +#define USER_DATA_NAME(name) name##_user_data +#define REGISTER_CALLBACK(name) \ + if (ruby_whisper_callback_container_is_present(rwpp->CALLBACK_CONTAINER_NAME(name))) { \ + rwpp->CALLBACK_CONTAINER_NAME(name)->context = context; \ + rwpp->params.PARAM_NAME(name) = ruby_whisper_parakeet_##name; \ + rwpp->params.USER_DATA_NAME(name) = rwpp->CALLBACK_CONTAINER_NAME(name); \ + } + + ITERATE_NORMAL_CALLBACK_PARAMS(REGISTER_CALLBACK) + + if (ruby_whisper_callback_container_is_present(rwpp->abort_callback_container)) { + abort_callback_user_data->callback_container = rwpp->abort_callback_container; + } + rwpp->params.abort_callback = ruby_whisper_parakeet_abort_callback; + rwpp->params.abort_callback_user_data = (void *)abort_callback_user_data; +} + +static void +ruby_whisper_parakeet_params_mark(void *p) +{ + ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p; + +#define MARK_CONTAINER(name) \ + if (rwpp->name##_container) { \ + ruby_whisper_callback_container_mark(rwpp->name##_container); \ + } + + ITERATE_CALLBACK_PARAMS(MARK_CONTAINER) +} + +static void +ruby_whisper_parakeet_params_free(void *p) +{ + ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p; + +#define FREE_CONTAINER(name) \ + if (rwpp->name##_container) { \ + xfree(rwpp->name##_container); \ + } + + ITERATE_CALLBACK_PARAMS(FREE_CONTAINER) + + xfree(rwpp); +} + +static size_t +ruby_whisper_parakeet_params_memsize(const void *p) +{ + const struct ruby_whisper_parakeet_params *params = p; + if (!params) { + return 0; + } + return sizeof(ruby_whisper_parakeet_params); +} + +const rb_data_type_t ruby_whisper_parakeet_params_type = { + "ruby_whisper_parakeet_params", + {ruby_whisper_parakeet_params_mark, ruby_whisper_parakeet_params_free, ruby_whisper_parakeet_params_memsize,}, + 0, 0, + 0 +}; + +#define READER(type) VAL_FROM_##type +#define WRITER(type) VAL_TO_##type +#define DEF_PARAM_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return READER(type)(rwpp->params.name); \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->params.name = WRITER(type)(val); \ + return val; \ + } + +#define DEF_CALLBACK_PARAM_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return rwpp->CALLBACK_CONTAINER_NAME(name)->callback; \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->CALLBACK_CONTAINER_NAME(name)->callback = (val); \ + return val; \ + } + +#define DEF_USER_DATA_PARAM_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name##_user_data(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return rwpp->CALLBACK_CONTAINER_NAME(name)->user_data; \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name##_user_data(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->CALLBACK_CONTAINER_NAME(name)->user_data = val; \ + return val; \ + } + +#define DEF_HOOK(name, data) \ + static VALUE \ + ruby_whisper_parakeet_params_on_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + const VALUE blk = rb_block_proc(); \ + if (NIL_P(rwpp->name##_callback_container->callbacks)) { \ + rwpp->name##_callback_container->callbacks = rb_ary_new(); \ + } \ + rb_ary_push(rwpp->name##_callback_container->callbacks, blk); \ + return Qnil; \ + } + +ITERATE_PARAMS(DEF_PARAM_ATTR) +ITERATE_CALLBACK_PARAMS(DEF_CALLBACK_PARAM_ATTR) +ITERATE_CALLBACK_PARAMS(DEF_USER_DATA_PARAM_ATTR) +ITERATE_NORMAL_CALLBACK_NAMES(DEF_HOOK, _) + +static VALUE +ruby_whisper_parakeet_params_abort_on(VALUE self) +{ + ruby_whisper_parakeet_params *rwpp; + GetParakeetParams(self, rwpp); + const VALUE blk = rb_block_proc(); + if (NIL_P(rwpp->abort_callback_container->callbacks)) { + rwpp->abort_callback_container->callbacks = rb_ary_new(); + } + rb_ary_push(rwpp->abort_callback_container->callbacks, blk); + + return Qnil; +} + +static VALUE +ruby_whisper_parakeet_params_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_params *rwpp; + VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); + rwpp->params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + return obj; +} + +static VALUE +ruby_whisper_parakeet_params_initialize(int argc, VALUE *argv, VALUE self) +{ + VALUE kw_hash; + VALUE values[RUBY_WHISPER_PARAKEET_NUM_PARAMS] = {Qundef}; + VALUE value; + ruby_whisper_parakeet_params *rwpp; + int i; + + TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); + +#define INIT_CONTAINER(name) rwpp->name##_container = ruby_whisper_callback_container_allocate(); + + ITERATE_CALLBACK_PARAMS(INIT_CONTAINER) + + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return Qnil; + } + + rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_PARAMS, values); + + for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_PARAMS; i++) { + value = values[i]; + if (value == Qundef) { + continue; + } + param_writers[i](self, value); + } + + return Qnil; +} + +void +init_ruby_whisper_parakeet_params(VALUE *mParakeet) +{ + cParakeetParams = rb_define_class_under(*mParakeet, "Params", rb_cObject); + rb_define_alloc_func(cParakeetParams, ruby_whisper_parakeet_params_s_allocate); + + rb_define_method(cParakeetParams, "initialize", ruby_whisper_parakeet_params_initialize, -1); + + int i = 0; +#define REGISTER_PARAM(name) \ + param_names[i] = rb_intern(#name); \ + param_writers[i] = ruby_whisper_parakeet_params_set_##name; \ + rb_define_method(cParakeetParams, #name, ruby_whisper_parakeet_params_get_##name, 0); \ + rb_define_method(cParakeetParams, #name "=", ruby_whisper_parakeet_params_set_##name, 1); \ + i++; + +#define REGISTER_PARAM_ATTR(name, type) REGISTER_PARAM(name) +#define REGISTER_CALLBACK_PARAM_ATTR(name) REGISTER_PARAM(name) +#define REGISTER_USER_DATA_PARAM_ATTR(name) REGISTER_PARAM(name##_user_data) + + ITERATE_PARAMS(REGISTER_PARAM_ATTR) + ITERATE_CALLBACK_PARAMS(REGISTER_CALLBACK_PARAM_ATTR) + ITERATE_CALLBACK_PARAMS(REGISTER_USER_DATA_PARAM_ATTR) + +#define REGISTER_HOOK(name, data) \ + rb_define_method(cParakeetParams, "on_" #name, ruby_whisper_parakeet_params_on_##name, 0); + + ITERATE_NORMAL_CALLBACK_NAMES(REGISTER_HOOK, _) + + rb_define_method(cParakeetParams, "abort_on", ruby_whisper_parakeet_params_abort_on, 0); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_segment.c b/bindings/ruby/ext/ruby_whisper_parakeet_segment.c new file mode 100644 index 00000000000..b1e81ba930c --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_segment.c @@ -0,0 +1,157 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(start_time, t0, TIME) \ + ITERATOR(end_time, t1, TIME) \ + ITERATOR(text, text, STRING) + +enum { +#define DEF_IDX(name, c_name, type) RUBY_WHISPER_PARAKEET_SEGMENT_##name, + + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS, +}; + +#define VAL_FROM_TIME(v) (LONG2NUM((v) * 10)) +#define VAL_FROM_STRING(v) (rb_str_new2(v)) +#define READER(type) VAL_FROM_##type +#define DEF_ATTR(rb_name, c_name, type) \ + static VALUE \ + ruby_whisper_parakeet_get_##rb_name(VALUE self) \ + { \ + ruby_whisper_parakeet_segment *rwps; \ + GetParakeetSegment(self, rwps); \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(rwps->context, rwpc); \ + return READER(type)(parakeet_full_get_segment_##c_name(rwpc->context, rwps->index)); \ + } + +extern ID id___method__; +extern ID id_to_enum; +extern VALUE cParakeetSegment; +extern VALUE sym_start_time; +extern VALUE sym_end_time; +extern VALUE sym_text; +extern const rb_data_type_t ruby_whisper_parakeet_context_type; +extern VALUE ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token); + +static void +rb_whisper_parakeet_segment_mark(void *p) +{ + ruby_whisper_parakeet_segment *rwps = (ruby_whisper_parakeet_segment *)p; + rb_gc_mark(rwps->context); +} + +static size_t +ruby_whisper_parakeet_segment_memsize(const void *p) +{ + const ruby_whisper_parakeet_segment *rwps = (const ruby_whisper_parakeet_segment *)p; + if (!rwps) { + return 0; + } + return sizeof(*rwps); +} + +static const rb_data_type_t ruby_whisper_parakeet_segment_type = { + "ruby_whisper_parakeet_segment", + {rb_whisper_parakeet_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_segment_memsize,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_segment_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_segment *rwps; + return TypedData_Make_Struct(klass, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps); +} + +VALUE +ruby_whisper_parakeet_segment_init(VALUE context, int index) +{ + ruby_whisper_parakeet_segment *rwps; + + const VALUE segment = ruby_whisper_parakeet_segment_s_allocate(cParakeetSegment); + TypedData_Get_Struct(segment, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps); + rwps->context = context; + rwps->index = index; + + return segment; +} + +ITERATE_ATTRS(DEF_ATTR) + +static VALUE +ruby_whisper_parakeet_segment_each_token(VALUE self) +{ + if (!rb_block_given_p()) { + const VALUE method_name = rb_funcall(self, id___method__, 0); + return rb_funcall(self, id_to_enum, 1, method_name); + } + + ruby_whisper_parakeet_segment *rwps; + GetParakeetSegment(self, rwps); + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(rwps->context, rwpc); + + const int n_tokens = parakeet_full_n_tokens(rwpc->context, rwps->index); + for (int i = 0; i < n_tokens; i++) { + rb_yield(ruby_whisper_parakeet_token_s_from_index(rwpc->context, rwps->index, i)); + } + + return self; +} + +static VALUE +ruby_whisper_parakeet_segment_deconstruct_keys(VALUE self, VALUE keys) +{ + ruby_whisper_parakeet_segment *rwps; + GetParakeetSegment(self, rwps); + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(rwps->context, rwpc); + + VALUE hash = rb_hash_new(); + long n_keys; + if (NIL_P(keys)) { + keys = rb_ary_new3( + RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS, + sym_start_time, + sym_end_time, + sym_text + ); + n_keys = RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS; + } else { + n_keys = RARRAY_LEN(keys); + if (n_keys > RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS) { + return hash; + } + } + for (int i = 0; i < n_keys; i++) { + VALUE key = rb_ary_entry(keys, i); + +#define CHECK_AND_SET_KEY(rb_name, c_name, type) \ + if (key == sym_##rb_name) { \ + rb_hash_aset(hash, key, ruby_whisper_parakeet_get_##rb_name(self)); \ + } + + ITERATE_ATTRS(CHECK_AND_SET_KEY) + } + + return hash; +} + +void +init_ruby_whisper_parakeet_segment(VALUE *mParakeet) +{ + cParakeetSegment = rb_define_class_under(*mParakeet, "Segment", rb_cObject); + + rb_define_alloc_func(cParakeetSegment, ruby_whisper_parakeet_segment_s_allocate); + +#define REGISTER_ATTR(rb_name, c_name, type) \ + rb_define_method(cParakeetSegment, #rb_name, ruby_whisper_parakeet_get_##rb_name, 0); + + ITERATE_ATTRS(REGISTER_ATTR) + + rb_define_method(cParakeetSegment, "each_token", ruby_whisper_parakeet_segment_each_token, 0); + rb_define_method(cParakeetSegment, "deconstruct_keys", ruby_whisper_parakeet_segment_deconstruct_keys, 1); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_token.c b/bindings/ruby/ext/ruby_whisper_parakeet_token.c new file mode 100644 index 00000000000..a00b7ae1cbb --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_token.c @@ -0,0 +1,188 @@ +#include "ruby_whisper.h" + +#define ITERATE_MEMBERS(ITERATOR) \ + ITERATOR(id, id, id, id, INT) \ + ITERATOR(duration_idx, duration_idx, duration_idx, duration_idx, INT) \ + ITERATOR(duration_value, duration_value, duration_value, duration_value, INT) \ + ITERATOR(frame_index, frame_index, frame_index, frame_index, INT) \ + ITERATOR(probability, probability, p, p, FLOAT) \ + ITERATOR(log_probability, log_probability, plog, plog, FLOAT) \ + ITERATOR(start_time, start_time, start_time, t0, TIME) \ + ITERATOR(end_time, end_time, end_time, t1, TIME) \ + ITERATOR(word_start?, word_start, word_start_p, is_word_start, BOOL) + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(text, text, text, text, STRING) + +enum { +#define DEF_IDX(rb_name, s_key, c_name, p_name, type) RUBY_WHISPER_PARAKEET_TOKEN_##c_name, + + ITERATE_MEMBERS(DEF_IDX) + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, +}; + +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_FROM_FLOAT(v) (DBL2NUM(v)) +#define VAL_FROM_TIME(v) (LONG2NUM(v * 10)) +#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse) +#define VAL_FROM_STRING(v) (rb_str_new2(v)) + +#define READER(type) VAL_FROM_##type +#define MEMBER_NAME(name) name +#define DEF_MEMBER_ATTR(rb_name, s_key, c_name, p_name, type) \ + static VALUE \ + ruby_whisper_parakeet_token_get_##c_name(VALUE self) \ + { \ + ruby_whisper_parakeet_token *rwpt; \ + GetParakeetToken(self, rwpt); \ + return READER(type)(rwpt->token_data->MEMBER_NAME(p_name)); \ + } + +#define DEF_ATTR(rb_name, s_key, c_name, p_name, type) \ + static VALUE \ + ruby_whisper_parakeet_token_get_##c_name(VALUE self) \ + { \ + ruby_whisper_parakeet_token *rwpt; \ + GetParakeetToken(self, rwpt); \ + return rwpt->p_name; \ + } + +VALUE cParakeetToken; + +#define DEC_ATTR_SYMS(rb_name, s_key, c_name, p_name, type) static VALUE sym_##s_key; + +ITERATE_MEMBERS(DEC_ATTR_SYMS) +ITERATE_ATTRS(DEC_ATTR_SYMS) + +static void +ruby_whisper_parakeet_token_mark(void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + rb_gc_mark(rwpt->text); +} + +static void +ruby_whisper_parakeet_token_free(void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + if (rwpt->token_data) { + xfree(rwpt->token_data); + rwpt->token_data = NULL; + } + xfree(rwpt); +} + +static size_t +ruby_whisper_parakeet_token_memsize(const void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + if (!rwpt) { + return 0; + } + size_t size = sizeof(*rwpt); + if (rwpt->token_data) { + size += sizeof(*rwpt->token_data); + } + + return size; +} + +static const rb_data_type_t ruby_whisper_parakeet_token_type = { + "ruby_whisper_parakeet_token", + {ruby_whisper_parakeet_token_mark, ruby_whisper_parakeet_token_free, ruby_whisper_parakeet_token_memsize}, + 0, 0, + 0, +}; + +static VALUE +ruby_whisper_parakeet_token_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_token *rwpt; + VALUE token = TypedData_Make_Struct(klass, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt); + + rwpt->token_data = NULL; + rwpt->text = Qnil; + + return token; +} + +VALUE +ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data) +{ + const VALUE token = ruby_whisper_parakeet_token_s_allocate(cParakeetToken); + ruby_whisper_parakeet_token *rwpt; + TypedData_Get_Struct(token, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt); + + rwpt->token_data = ALLOC(parakeet_token_data); + *rwpt->token_data = *token_data; + rwpt->text = rb_utf8_str_new_cstr(parakeet_token_to_str(context, token_data->id)); + + return token; +} + +VALUE +ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token) +{ + parakeet_token_data token_data = parakeet_full_get_token_data(context, i_segment, i_token); + return ruby_whisper_parakeet_token_s_from_token_data(context, &token_data); +} + +ITERATE_MEMBERS(DEF_MEMBER_ATTR) +// Define #text using parakeet_token_to_str or parakeet_token_to_text +ITERATE_ATTRS(DEF_ATTR) + +static VALUE +ruby_whisper_parakeet_token_deconstruct_keys(VALUE self, VALUE keys) +{ + ruby_whisper_parakeet_token *rwpt; + GetParakeetToken(self, rwpt); + + VALUE hash = rb_hash_new(); + long n_keys = 0; + + if (NIL_P(keys)) { + VALUE attrs[] = { +#define LIST_SYMS(rb_name, s_key, c_name, p_name, type) sym_##s_key, + + ITERATE_MEMBERS(LIST_SYMS) + ITERATE_ATTRS(LIST_SYMS) + }; + keys = rb_ary_new_from_values(RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, attrs); + n_keys = RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS; + } else { + n_keys = RARRAY_LEN(keys); + if (n_keys > RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS) { + return hash; + } + } + for (long i = 0; i < n_keys; i++) { + VALUE key = rb_ary_entry(keys, i); + +#define CHECK_AND_SET_KEY(rb_name, s_key, c_name, p_name, type) \ + if (key == sym_##s_key) { \ + rb_hash_aset(hash, key, ruby_whisper_parakeet_token_get_##c_name(self)); \ + } + + ITERATE_MEMBERS(CHECK_AND_SET_KEY) + ITERATE_ATTRS(CHECK_AND_SET_KEY) + } + + return hash; +} + +void +init_ruby_whisper_parakeet_token(VALUE *mParakeet) +{ + cParakeetToken = rb_define_class_under(*mParakeet, "Token", rb_cObject); + rb_define_alloc_func(cParakeetToken, ruby_whisper_parakeet_token_s_allocate); + +#define REGISTER_ATTR(rb_name, s_key, c_name, p_name, type) \ + sym_##s_key = ID2SYM(rb_intern(#s_key)); \ + rb_define_method(cParakeetToken, #rb_name, ruby_whisper_parakeet_token_get_##c_name, 0); + + ITERATE_MEMBERS(REGISTER_ATTR) + ITERATE_ATTRS(REGISTER_ATTR) + + rb_define_method(cParakeetToken, "deconstruct_keys", ruby_whisper_parakeet_token_deconstruct_keys, 1); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp new file mode 100644 index 00000000000..c4deccce84a --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp @@ -0,0 +1,58 @@ +#include "ruby_whisper.h" +#include "common-whisper.h" +#include <string> +#include <vector> + +#ifdef __cplusplus +extern "C" { +#endif + +extern const rb_data_type_t ruby_whisper_parakeet_context_type; +extern const rb_data_type_t ruby_whisper_parakeet_params_type; + +extern VALUE ruby_whisper_parakeet_context_full_body(VALUE rb_args); + +extern ID id_to_path; +extern ID id_new; + +extern VALUE eError; + +VALUE +ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params) +{ + if (rb_respond_to(audio_path, id_to_path)) { + audio_path = rb_funcall(audio_path, id_to_path, 0); + } + + std::string fname = StringValueCStr(audio_path); + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + + if (!read_audio_data(fname, pcmf32, pcmf32s, false)) { + rb_raise(rb_eRuntimeError, "Failed to open %s", fname.c_str()); + return Qnil; + } + + ruby_whisper_parakeet_context *rwpc; + ruby_whisper_parakeet_params *rwpp; + GetParakeetContext(self, rwpc); + GetParakeetParams(params, rwpp); + + ruby_whisper_full_args args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + }; + VALUE rb_result = ruby_whisper_parakeet_context_full_body((VALUE)&args); + const int result = NUM2INT(rb_result); + if (result == 0) { + return self; + } else { + rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result)); + } +} + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 2aae7c12d19..f38e9bde3ea 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -76,8 +76,8 @@ static ID id_vad; static ID id_vad_model_path; static ID id_vad_params; -static void -rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) +void +ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc) { if (rwc == NULL) return; @@ -86,8 +86,8 @@ rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) rb_gc_mark(rwc->callbacks); } -static ruby_whisper_callback_container* -rb_whisper_callback_container_allocate() { +ruby_whisper_callback_container* +ruby_whisper_callback_container_allocate() { ruby_whisper_callback_container *container; container = ALLOC(ruby_whisper_callback_container); container->context = NULL; @@ -97,38 +97,11 @@ rb_whisper_callback_container_allocate() { return container; } -static void -rb_whisper_abort_callback_container_mark(ruby_whisper_abort_callback_container *rwc) -{ - if (rwc == NULL) return; - - rb_gc_mark(rwc->user_data); - rb_gc_mark(rwc->callback); - rb_gc_mark(rwc->callbacks); -} - -static ruby_whisper_abort_callback_container* -rb_whisper_abort_callback_container_allocate() { - ruby_whisper_abort_callback_container *container; - container = ALLOC(ruby_whisper_abort_callback_container); - container->context = NULL; - container->user_data = Qnil; - container->callback = Qnil; - container->callbacks = Qnil; - container->is_interrupted = false; - return container; -} - -static bool +bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) { return !NIL_P(container->callback) || !NIL_P(container->callbacks); } -static bool -ruby_whisper_abort_callback_container_is_present(const ruby_whisper_abort_callback_container *container) { - return !NIL_P(container->callback) || !NIL_P(container->callbacks); -} - typedef struct { const ruby_whisper_callback_container *container; struct whisper_state *state; @@ -283,24 +256,19 @@ static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_s } typedef struct { - const ruby_whisper_abort_callback_container *container; - struct whisper_state *state; + const ruby_whisper_callback_container *container; bool is_interrupted; } call_abort_callbacks_args; static void* call_abort_callbacks(void *v_args) { call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args; - const ruby_whisper_abort_callback_container *container = args->container; - - if (container->is_interrupted) { - args->is_interrupted = true; - return NULL; - } + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; if (!NIL_P(container->callback)) { - VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { + result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (RTEST(result)) { args->is_interrupted = true; return NULL; } @@ -308,14 +276,14 @@ call_abort_callbacks(void *v_args) { if (NIL_P(container->callbacks)) { return NULL; } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (0 == n_callbacks) { return NULL; } - for (int j = 0; j < callbacks_len; j++) { + for (int j = 0; j < n_callbacks; j++) { VALUE cb = rb_ary_entry(container->callbacks, j); - VALUE result = rb_funcall(cb, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { + VALUE result = rb_funcall(cb, id_call, 0); + if (RTEST(result)) { args->is_interrupted = true; return NULL; } @@ -325,19 +293,19 @@ call_abort_callbacks(void *v_args) { } static bool abort_callback(void * user_data) { - const ruby_whisper_abort_callback_container *container = (ruby_whisper_abort_callback_container *)user_data; + ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data; - if (container->is_interrupted) { + int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted); + if (is_interrupted) { return true; } - if (!ruby_whisper_abort_callback_container_is_present(container)) { + if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) { return false; } call_abort_callbacks_args args = { - container, - NULL, + data->callback_container, false }; rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args); @@ -352,29 +320,19 @@ check_thread_safety(ruby_whisper_params *rwp, int n_processors) return; } - if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { - rb_raise(rb_eRuntimeError, "new segment callback not supported on parallel transcription"); - } - - if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) { - rb_raise(rb_eRuntimeError, "progress callback not supported on parallel transcription"); - } + // new_segment_callback is called only after multiple threads are joined + // progress_callback is not called when parallel if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) { rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription"); } - if (ruby_whisper_abort_callback_container_is_present(rwp->abort_callback_container)) { + if (ruby_whisper_callback_container_is_present(rwp->abort_callback_container)) { rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription"); } - - VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); - if (!NIL_P(log_callback)) { - rb_raise(rb_eRuntimeError, "log callback not supported for parallel transcription"); - } } -static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { +static void register_callbacks(ruby_whisper_params * rwp, VALUE * context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) { if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { rwp->new_segment_callback_container->context = context; rwp->params.new_segment_callback = new_segment_callback; @@ -393,10 +351,10 @@ static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container; } + abort_callback_user_data->callback_container = rwp->abort_callback_container; rwp->abort_callback_container->context = context; rwp->params.abort_callback = abort_callback; - rwp->abort_callback_container->is_interrupted = false; - rwp->params.abort_callback_user_data = rwp->abort_callback_container; + rwp->params.abort_callback_user_data = (void *)abort_callback_user_data; } static void set_vad_params(ruby_whisper_params *rwp) @@ -406,14 +364,11 @@ static void set_vad_params(ruby_whisper_params *rwp) rwp->params.vad_params = rwvp->params; } -/* - TODO: Set abort callback to trap SIGINT and SIGTERM -*/ void -prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors) +prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data) { check_thread_safety(rwp, n_processors); - register_callbacks(rwp, context); + register_callbacks(rwp, context, abort_callback_user_data); set_vad_params(rwp); } @@ -421,10 +376,10 @@ void rb_whisper_params_mark(void *p) { ruby_whisper_params *rwp = (ruby_whisper_params *)p; - rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); - rb_whisper_callbcack_container_mark(rwp->progress_callback_container); - rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container); - rb_whisper_abort_callback_container_mark(rwp->abort_callback_container); + ruby_whisper_callback_container_mark(rwp->new_segment_callback_container); + ruby_whisper_callback_container_mark(rwp->progress_callback_container); + ruby_whisper_callback_container_mark(rwp->encoder_begin_callback_container); + ruby_whisper_callback_container_mark(rwp->abort_callback_container); rb_gc_mark(rwp->vad_params); } @@ -492,10 +447,10 @@ ruby_whisper_params_allocate(VALUE klass) } rwp->diarize = false; rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params); - rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); - rwp->progress_callback_container = rb_whisper_callback_container_allocate(); - rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate(); - rwp->abort_callback_container = rb_whisper_abort_callback_container_allocate(); + rwp->new_segment_callback_container = ruby_whisper_callback_container_allocate(); + rwp->progress_callback_container = ruby_whisper_callback_container_allocate(); + rwp->encoder_begin_callback_container = ruby_whisper_callback_container_allocate(); + rwp->abort_callback_container = ruby_whisper_callback_container_allocate(); return obj; } diff --git a/bindings/ruby/ext/ruby_whisper_segment.c b/bindings/ruby/ext/ruby_whisper_segment.c index ee0d66c4cc8..cf0372797d3 100644 --- a/bindings/ruby/ext/ruby_whisper_segment.c +++ b/bindings/ruby/ext/ruby_whisper_segment.c @@ -4,12 +4,12 @@ extern ID id___method__; extern ID id_to_enum; -static VALUE sym_start_time; -static VALUE sym_end_time; -static VALUE sym_text; -static VALUE sym_no_speech_prob; -static VALUE sym_speaker_turn_next; -static VALUE sym_n_tokens; +VALUE sym_start_time; +VALUE sym_end_time; +VALUE sym_text; +VALUE sym_no_speech_prob; +VALUE sym_speaker_turn_next; +VALUE sym_n_tokens; extern const rb_data_type_t ruby_whisper_type; diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index 37656af1c44..73f606ca476 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -16,6 +16,8 @@ extern ID id_to_path; extern ID transcribe_option_names[1]; extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors); +extern VALUE full_body(VALUE rb_args); +extern VALUE full_parallel_body(VALUE rb_args); typedef struct{ struct whisper_context *context; @@ -35,18 +37,6 @@ transcribe_without_gvl(void *rb_args) return NULL; } -typedef struct { - ruby_whisper_abort_callback_container *abort_callback_container; -} transcribe_ubf_args; - -static void -transcribe_ubf(void *rb_args) -{ - transcribe_ubf_args *args = (transcribe_ubf_args *)rb_args; - - args->abort_callback_container->is_interrupted = true; -} - /* * transcribe a single file * can emit to a block results @@ -91,32 +81,28 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); return self; } - // Commented out because it is work in progress - // { - // static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - - // rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - // bool is_aborted = *(bool*)user_data; - // return !is_aborted; - // }; - // rwp->params.encoder_begin_callback_user_data = &is_aborted; - // } - - prepare_transcription(rwp, &self, n_processors); - - transcribe_without_gvl_args args = { - rw->context, - &rwp->params, - pcmf32.data(), - pcmf32.size(), - n_processors, - 0, - }; - transcribe_ubf_args ubf_args = { - rwp->abort_callback_container, - }; - rb_thread_call_without_gvl(transcribe_without_gvl, (void *)&args, transcribe_ubf, (void *)&ubf_args); - if (args.result != 0) { + + VALUE rb_result; + if (n_processors == 1) { + ruby_whisper_full_args args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + }; + rb_result = full_body((VALUE)&args); + } else { + ruby_whisper_full_parallel_args parallel_args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + n_processors, + }; + rb_result = full_parallel_body((VALUE)¶llel_args); + } + const int result = NUM2INT(rb_result); + if (result != 0) { fprintf(stderr, "failed to process audio\n"); return self; } diff --git a/bindings/ruby/lib/whisper/context.rb b/bindings/ruby/lib/whisper/context.rb deleted file mode 100644 index c3a134b773d..00000000000 --- a/bindings/ruby/lib/whisper/context.rb +++ /dev/null @@ -1,15 +0,0 @@ -module Whisper - class Context - def to_srt - each_segment.with_index.reduce("") {|srt, (segment, index)| - srt << "#{index + 1}\n#{segment.to_srt_cue}\n" - } - end - - def to_webvtt - each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)| - webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n" - } - end - end -end diff --git a/bindings/ruby/lib/whisper/log_settable.rb b/bindings/ruby/lib/whisper/log_settable.rb new file mode 100644 index 00000000000..2f8218d26ee --- /dev/null +++ b/bindings/ruby/lib/whisper/log_settable.rb @@ -0,0 +1,36 @@ +require "mutex_m" + +module Whisper + module LogSettable + class << self + def extended(base) + base.extend Mutex_m + end + end + + private + + def start_log_callback_thread + return if @log_callback_thread&.alive? + + @log_callback_thread = Thread.new { + begin + while logs = drain_logs + begin + callback, user_data = synchronize {[@log_callback, @log_callback_user_data]} + next if callback.nil? + + logs.each do |(level, text)| + callback.call level, text, user_data + end + rescue => err + $stderr.puts err + end + end + rescue => err + $stderr.puts err + end + } + end + end +end diff --git a/bindings/ruby/lib/whisper/model/uri.rb b/bindings/ruby/lib/whisper/model/uri.rb index 8eb57e5e8cf..ef92eb901c4 100644 --- a/bindings/ruby/lib/whisper/model/uri.rb +++ b/bindings/ruby/lib/whisper/model/uri.rb @@ -41,6 +41,8 @@ def base_cache_dir def cache path = cache_path + return path if cache_path.exist? + headers = {} headers["if-modified-since"] = path.mtime.httpdate if path.exist? request @uri, headers @@ -216,8 +218,18 @@ def escaping(path) @pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-#{name}.bin") end + %w[ + parakeet-tdt-0.6b-v3-f16 + parakeet-tdt-0.6b-v3-f32 + parakeet-tdt-0.6b-v3-q4_0 + parakeet-tdt-0.6b-v3-q4_k + parakeet-tdt-0.6b-v3-q8_0 + ].each do |name| + @pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/parakeet-GGUF/resolve/main/ggml-#{name}.bin") + end + @coreml_compiled_models = @pre_converted_models.each_with_object({}) {|(name, uri), models| - next if name.end_with?("-tdrz") || name.start_with?("silero-") + next if name.end_with?("-tdrz") || name.start_with?("silero-") || name.start_with?("parakeet-") if matched = name.match(/\A(?<name>.*)-q\d_\d\z/) name = matched[:name] diff --git a/bindings/ruby/lib/whisper/output.rb b/bindings/ruby/lib/whisper/output.rb new file mode 100644 index 00000000000..1781af17a33 --- /dev/null +++ b/bindings/ruby/lib/whisper/output.rb @@ -0,0 +1,74 @@ +module Whisper + module Output + module Context + def to_srt + each_segment.with_index.reduce("") {|srt, (segment, index)| + srt << "#{index + 1}\n#{segment.to_srt_cue}\n" + } + end + + def to_webvtt + each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)| + webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n" + } + end + end + + module Segment + SRT_ESCAPES = { + "&" => "&", + "<" => "<", + ">" => ">", + } + SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys) + private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE + + def to_srt_cue + "#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n" + end + + def to_webvtt_cue + "#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n" + end + + private + + def time_to_a(time) + sec, decimal_part = time.divmod(1000) + min, sec = sec.divmod(60) + hour, min = min.divmod(60) + [hour, min, sec, decimal_part] + end + + def srt_time(time) + "%02d:%02d:%02d,%03d" % time_to_a(time) + end + + def srt_start_time + srt_time(start_time) + end + + def srt_end_time + srt_time(end_time) + end + + def srt_text + text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES) + end + + def webvtt_time(time) + "%02d:%02d:%02d.%03d" % time_to_a(time) + end + + def webvtt_start_time + webvtt_time(start_time) + end + + def webvtt_end_time + webvtt_time(end_time) + end + + alias webvtt_text srt_text + end + end +end diff --git a/bindings/ruby/lib/whisper/segment.rb b/bindings/ruby/lib/whisper/segment.rb deleted file mode 100644 index dc187dcac36..00000000000 --- a/bindings/ruby/lib/whisper/segment.rb +++ /dev/null @@ -1,58 +0,0 @@ -module Whisper - class Segment - SRT_ESCAPES = { - "&" => "&", - "<" => "<", - ">" => ">", - } - SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys) - private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE - - def to_srt_cue - "#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n" - end - - def to_webvtt_cue - "#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n" - end - - private - - def time_to_a(time) - sec, decimal_part = time.divmod(1000) - min, sec = sec.divmod(60) - hour, min = min.divmod(60) - [hour, min, sec, decimal_part] - end - - def srt_time(time) - "%02d:%02d:%02d,%03d" % time_to_a(time) - end - - def srt_start_time - srt_time(start_time) - end - - def srt_end_time - srt_time(end_time) - end - - def srt_text - text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES) - end - - def webvtt_time(time) - "%02d:%02d:%02d.%03d" % time_to_a(time) - end - - def webvtt_start_time - webvtt_time(start_time) - end - - def webvtt_end_time - webvtt_time(end_time) - end - - alias webvtt_text srt_text - end -end diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index cbec4803820..c12e1fe55e5 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -40,7 +40,21 @@ module Whisper def self.log_set: (log_callback?, Object? user_data) -> log_callback def self.system_info_str: () -> String + module Output + module Context + def to_srt: () -> String + def to_webvtt: () -> String + end + + module Segment + def to_srt_cue: () -> String + def to_webvtt_cue: () -> String + end + end + class Context + include Output::Context + def self.new: (String | path | ::URI::HTTP) -> instance # transcribe a single file @@ -139,17 +153,14 @@ module Whisper | (Whisper::Params, _Samples, ?Integer n_samples) -> self | (Whisper::Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self - def to_srt: () -> String - def to_webvtt: () -> String - class Params def self.new: ( - use_gpu: boolish, - flash_attn: boolish, - gpu_device: Integer, - dtw_token_timestamps: boolish, - dtw_aheads_preset: Integer, - dtw_n_top: Integer | nil, + ?use_gpu: boolish, + ?flash_attn: boolish, + ?gpu_device: Integer, + ?dtw_token_timestamps: boolish, + ?dtw_aheads_preset: Integer, + ?dtw_n_top: Integer | nil, ) -> instance def use_gpu=: (boolish) -> boolish @@ -444,6 +455,9 @@ module Whisper def abort_on: { (Object user_data) -> boolish } -> void end + module LogSettable + end + class Model def self.pre_converted_models: () -> Hash[String, Model::URI] def self.coreml_compiled_models: () -> Hash[Model::URI, Model::ZipURI] @@ -474,6 +488,8 @@ module Whisper end class Segment + include Output::Segment + type deconstructed_keys = { start_time: (Integer | nil), end_time: (Integer | nil), @@ -514,9 +530,6 @@ module Whisper # def each_token: { (Token) -> void } -> void | () -> Enumerator[Token] - def to_srt_cue: () -> String - def to_webvtt_cue: () -> String - # Possible keys: `:start_time`, `:end_time`, `:text`, `:no_speech_prob`, `:speaker_turn_next` # @@ -528,7 +541,7 @@ module Whisper def deconstruct_keys: (Array[:start_time | :end_time | :text | :no_speech_prob | :speaker_turn_next | :n_tokens] | nil) -> deconstructed_keys end - module Token + class Token type deconstructed_keys = { id: (Integer | nil), tid: (Integer | nil), @@ -598,6 +611,336 @@ module Whisper def deconstruct_keys: (Array[:id | :tid | :probability | :log_probability | :pt | :ptsum | :t_dtw | :voice_length | :start_time | :end_time | :text] | nil) -> deconstructed_keys end + module Parakeet + extend LogSettable + + VERSION: String + + # Control logging output. The default behavior is to print to stderr. + # + def self.log_set: (nil, Object? user_data) -> nil + | (^(Integer level, String message, Object user_data) -> void, Object? user_data) -> nil + def self.system_info_str: () -> String + + class Context + include Output::Context + + # Load a Parakeet model from the given file path. + # + def self.new: (String | path | ::URI::HTTP, ?Params) -> instance + + # Transcribe a single audio file. + # + def transcribe: (path audio_file_path, Whisper::Parakeet::Params) -> self + + # Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text. + # Not thread safe for the same context. + # + # The second argument `samples` must be an array of samples, respond to `:length`, + # or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. + # + def full: (Whisper::Parakeet::Params, Array[Float] samples, ?Integer n_samples) -> self + | (Whisper::Parakeet::Params, _Samples, ?Integer n_samples) -> self + + # Number of generated text segments. + # + def full_n_segments: () -> Integer + + # Start time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). + # + # full_get_segment_t0(3) # => 1668 (16680 ms) + # + def full_get_segment_t0: (Integer segment_index) -> Integer + + # End time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). + # + # full_get_segment_t1(3) # => 1668 (16680 ms) + # + def full_get_segment_t1: (Integer segment_index) -> Integer + + # Text of a segment indexed by `segment_index`. + # + # full_get_segment_text(3) # => "ask not what your country can do for you, ..." + # + def full_get_segment_text: (Integer segment_index) -> String + + # Number of tokens in the segment indexed by `segment_index`. + # + def full_n_tokens: (Integer segment_index) -> Integer + + # Text of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_text: (Integer segment_index, Integer token_index) -> String + + # Token id of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_id: (Integer segment_index, Integer token_index) -> Integer + + # Probability of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_p: (Integer segment_index, Integer token_index) -> Float + + # Token data of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_data: (Integer segment_index, Integer token_index) -> Token + + def model: () -> Model + + # Yields each Whisper::Parakeet::Segment: + # + # parakeet.transcribe("path/to/audio.wav", params) + # parakeet.each_segment do |segment| + # puts segment.text + # end + # + # Returns an `Enumerator` if no block given: + # + # parakeet.transcribe("path/to/audio.wav", params) + # enum = parakeet.each_segment + # enum.to_a # => [#<Whisper::Parakeet::Segment>, ...] + # + def each_segment: { (Segment) -> void } -> void + | () -> Enumerator[Segment] + + class Params + def self.new: (?use_gpu: boolish, ?gpu_device: Integer) -> instance + def use_gpu: () -> boolish + def use_gpu=: (boolish) -> boolish + def gpu_device: () -> Integer + def gpu_device=: (Integer) -> Integer + end + end + + class Params + def self.new: ( + ?n_threads: Integer, + ?offset_ms: Integer, + ?duration_ms: Integer, + ?no_context: boolish, + ?audio_ctx: Integer, + ?new_segment_callback: ^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void, + ?new_segment_callback_user_data: Object, + ?new_token_callback: ^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void, + ?new_token_callback_user_data: Object, + ?progress_callback: ^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void, + ?progress_callback_user_data: Object, + ?encoder_begin_callback: ^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish, + ?encoder_begin_callback_user_data: Object, + ?abort_callback: ^(Object user_data) -> boolish, + ?abort_callback_user_data: Object + ) -> instance + + # Number of threads to use. + # + def n_threads=: (Integer) -> Integer + def n_threads: () -> Integer + + # Start offset in ms. + # + def offset_ms=: (Integer) -> Integer + def offset_ms: () -> Integer + + # Audio duration to process in ms. + # + def duration_ms=: (Integer) -> Integer + def duration_ms: () -> Integer + + # If `true`, does not use past transcription (if any) as context. + # + def no_context=: (boolish) -> boolish + def no_context: () -> (true | false) + + # Overwrite the audio context size. `0` uses the default value. + # + def audio_ctx=: (Integer) -> Integer + def audio_ctx: () -> Integer + + # Sets new segment callback, called for every newly generated text segment. + # + # params.new_segment_callback = ->(context, _, n_new, user_data) { + # # ... + # } + # + def new_segment_callback=: (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) + def new_segment_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of new segment callback. + # + def new_segment_callback_user_data=: (Object?) -> Object? + def new_segment_callback_user_data: () -> Object? + + # Sets token callback, called for every newly predicted token. + # + def new_token_callback=: (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) + def new_token_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of token callback. + # + def new_token_callback_user_data=: (Object?) -> Object? + def new_token_callback_user_data: () -> Object? + + # Sets progress callback, called on each progress update. + # + # +progress+ is an Integer between 0 and 100. + # + def progress_callback=: (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) + def progress_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of progress callback. + # + def progress_callback_user_data=: (Object?) -> Object? + def progress_callback_user_data: () -> Object? + + # Sets encoder begin callback, called each time before the encoder starts. + # + # If it returns `false`, the computation is aborted. + # + def encoder_begin_callback=: (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) -> (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) + def encoder_begin_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) | nil) + + # Sets user data passed to the last argument of encoder begin callback. + # + def encoder_begin_callback_user_data=: (Object?) -> Object? + def encoder_begin_callback_user_data: () -> Object? + + # Sets abort callback, called each time before ggml computation starts. + # + def abort_callback=: (^(Object user_data) -> boolish) -> (^(Object user_data) -> boolish) + def abort_callback: () -> ((^(Object user_data) -> boolish) | nil) + + # Sets user data passed to the last argument of abort callback. + # + def abort_callback_user_data=: (Object?) -> Object? + def abort_callback_user_data: () -> Object? + + # Hook called on new segment. Yields each Whisper::Parakeet::Segment. + # + def on_new_segment: { (Segment) -> void } -> void + + # Hook called on new token. Yields each Whisper::Parakeet::Token. + # + def on_new_token: { (Token) -> void } -> void + + # Hook called on progress update. Yields each progress `Integer` between 0 and 100. + # + def on_progress: { (Integer progress) -> void } -> void + + # Hook called each time before the encoder starts. + # + def on_encoder_begin: { () -> boolish } -> void + + # Call block to determine whether abort or not. Return `true` when you want to abort. + # + def abort_on: { () -> boolish } -> void + end + + class Segment + include Output::Segment + + type deconstructed_keys = { + start_time: (Integer | nil), + end_time: (Integer | nil), + text: (String | nil) + } + + # Start time in milliseconds. + # + def start_time: () -> Integer + + # End time in milliseconds. + # + def end_time: () -> Integer + + # Text of the segment. + # + def text: () -> String + + # Yields each Whisper::Parakeet::Token: + # + # parakeet.each_segment.first.each_token do |token| + # p token + # end + # + # Returns an `Enumerator` if no block is given: + # + # parakeet.each_segment.first.each_token.to_a # => [#<Whisper::Parakeet::Token>, ...] + # + def each_token: { (Token) -> void } -> void + | () -> Enumerator[Token] + + # Possible keys: `:start_time`, `:end_time`, `:text` + # + def deconstruct_keys: (Array[:start_time | :end_time | :text] | nil) -> deconstructed_keys + end + + class Token + type deconstructed_keys = { + id: (Integer | nil), + duration_idx: (Integer | nil), + duration_value: (Integer | nil), + frame_index: (Integer | nil), + probability: (Float | nil), + log_probability: (Float | nil), + start_time: (Integer | nil), + end_time: (Integer | nil), + word_start: ((true | false) | nil), + text: (String | nil), + } + + # Token ID. + # + def id: () -> Integer + + # Index into the model's durations array. + # + def duration_idx: () -> Integer + + # Actual duration value. + # + def duration_value: () -> Integer + + # Frame index of the token. + # + def frame_index: () -> Integer + + # Probability of the token. + # + def probability: () -> Float + + # Log probability of the token. + # + def log_probability: () -> Float + + # Start time of the token in milliseconds. + # + def start_time: () -> Integer + + # End time of the token in milliseconds. + # + def end_time: () -> Integer + + # Whether this token is the start of a word. + # + def word_start?: () -> (true | false) + + # Get the token text of the token. + # + def text: () -> String + + def deconstruct_keys: (Array[:id | :duration_idx | :duration_value | :frame_index | :probability | :log_probability | :start_time | :end_time | :word_start | :text] | nil) -> deconstructed_keys + end + + class Model + def n_vocab: () -> Integer + def n_audio_ctx: () -> Integer + def n_audio_state: () -> Integer + def n_audio_head: () -> Integer + def n_audio_layer: () -> Integer + def n_mels: () -> Integer + def ftype: () -> Integer + end + end + module VAD class Params def self.new: ( diff --git a/bindings/ruby/test/helper.rb b/bindings/ruby/test/helper.rb index 56cd3849fdd..5e37ad98596 100644 --- a/bindings/ruby/test/helper.rb +++ b/bindings/ruby/test/helper.rb @@ -5,6 +5,8 @@ class TestBase < Test::Unit::TestCase AUDIO = File.join(__dir__, "fixtures", "jfk.wav") + Parakeet = Whisper::Parakeet + class << self def whisper return @whisper if @whisper diff --git a/bindings/ruby/test/test_callback.rb b/bindings/ruby/test/test_callback.rb index a7f49245ade..6490c8abb48 100644 --- a/bindings/ruby/test/test_callback.rb +++ b/bindings/ruby/test/test_callback.rb @@ -129,6 +129,7 @@ def test_encoder_begin_callback_abort return false } @whisper.transcribe(@audio, @params) + sleep 0.5 # wait for logs dequeued assert_match(/encoder_begin_callback returned false - aborting/, logs.join) Whisper.log_set ->(level, buffer, user_data) {}, nil end diff --git a/bindings/ruby/test/test_parakeet.rb b/bindings/ruby/test/test_parakeet.rb new file mode 100644 index 00000000000..bfd57076f56 --- /dev/null +++ b/bindings/ruby/test/test_parakeet.rb @@ -0,0 +1,28 @@ +require_relative "helper" +require "stringio" + +class TestParakeet < TestBase + def test_log_set + log_callback = Parakeet.instance_variable_get("@log_callback") + user_data = Parakeet.instance_variable_get("@log_callback_user_data") + + $stdout = StringIO.new + Parakeet.log_set proc {|level, message, _| puts [level, message].join(": ")}, nil + Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + sleep 0.1 + $stdout.rewind + logs = $stdout.string + assert_match /loading model from/, logs + ensure + $stdout = STDOUT + Parakeet.log_set log_callback, user_data + end + + def test_system_info_str + assert_match /\APARAKEET : /, Parakeet.system_info_str + end + + def test_version + assert_instance_of String, Parakeet::VERSION + end +end diff --git a/bindings/ruby/test/test_parakeet_callback.rb b/bindings/ruby/test/test_parakeet_callback.rb new file mode 100644 index 00000000000..1209e960f09 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_callback.rb @@ -0,0 +1,107 @@ +require_relative "helper" + +class TestParakeetCallback < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + @params = Parakeet::Params.new + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + end + + def test_new_segment_callback + @params.new_segment_callback = ->(context, state, n_new, user_data) { + assert_kind_of Integer, n_new + assert n_new > 0 + assert_same @parakeet, context + + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + start_time = context.full_get_segment_t0(i_segment) * 10 + end_time = context.full_get_segment_t1(i_segment) * 10 + text = context.full_get_segment_text(i_segment) + + assert_kind_of Integer, start_time + assert start_time >= 0 + assert_kind_of Integer, end_time + assert end_time > 0 + assert_match(/ask not what your country can do for you, ask what you can do for your/, text) if i_segment == 0 + end + } + + @parakeet.transcribe AUDIO, @params + end + + def test_on_new_segment + seg = nil + index = 0 + @params.on_new_segment do |segment| + assert_instance_of Parakeet::Segment, segment + if index == 0 + seg = segment + assert_equal 0, segment.start_time + assert_match(/ask not what your country can do for you, ask what you can do for your/, segment.text) + end + index += 1 + end + @parakeet.transcribe AUDIO, @params + assert_equal 0, seg.start_time + assert_match /ask not what your country can do for you, ask what you can do for your/, seg.text + end + + def test_on_new_token + index = 0 + @params.on_new_token do |token| + assert_instance_of Parakeet::Token, token + if index == 0 + assert_instance_of Integer, token.start_time + assert_match "▁And", token.text + end + index += 1 + end + + @parakeet.transcribe AUDIO, @params + end + + def test_on_progress + first = nil + @params.on_progress do |progress| + assert_kind_of Integer, progress + assert 0 <= progress && progress <= 100 + first = progress if first.nil? + end + + @parakeet.transcribe AUDIO, @params + + assert_equal 0, first + end + + def test_on_encoder_begin + i = 0 + @params.on_encoder_begin do + i += 1 + end + + @parakeet.transcribe AUDIO, @params + + assert i > 0 + end + + def test_abort_on + do_abort = false + @params.on_new_segment do |segment| + do_abort = true if segment.text.match?(/ask/) + end + i = 0 + @params.abort_on do + i += 1 + do_abort + end + + @parakeet.transcribe(AUDIO, @params) rescue nil + + assert i > 0 + end +end diff --git a/bindings/ruby/test/test_parakeet_context.rb b/bindings/ruby/test/test_parakeet_context.rb new file mode 100644 index 00000000000..2d039ce75f5 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_context.rb @@ -0,0 +1,116 @@ +require_relative "helper" +require "stringio" + +class TestParakeetContext < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + @params = Parakeet::Params.new + end + + def test_new + assert_instance_of Parakeet::Context, @parakeet + end + + def test_new_with_params + log_callback = Parakeet.instance_variable_get(:@log_callback) + user_data = Parakeet.instance_variable_get(:@log_callback_user_data) + begin + logs = "" + Parakeet.log_set proc {|level, message| logs << message}, nil + params = Parakeet::Context::Params.new(use_gpu: false) + parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0", params) + assert_instance_of Parakeet::Context, parakeet + assert_match /use gpu\s+=\s+0/, logs + ensure + Parakeet.log_set log_callback, user_data + end + end + + sub_test_case "full" do + def setup + super + @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15} + end + + def test_full + @parakeet.full @params, @samples, @samples.length + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, segments.first.text + end + + def test_full_without_length + @parakeet.full(@params, @samples) + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_enumerator + samples = @samples.each + @parakeet.full @params, samples, @samples.length + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_enumerator_without_length + samples = @samples.each + assert_raise ArgumentError do + @parakeet.full @params, samples + end + end + + def test_full_enumerator_with_too_large_length + samples = @samples.each.take(10).to_enum + assert_raise StopIteration do + @parakeet.full @params, samples, 11 + end + end + + def test_full_with_memory_view + samples = JFKReader.new(AUDIO) + @parakeet.full @params, samples + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_with_memroy_view_gc + samples = JFKReader.new(AUDIO) + @parakeet.full(@params, samples) + GC.start + require "fiddle" + Fiddle::MemoryView.export samples do |view| + assert_equal 176000, view.to_s.unpack("#{view.format}*").length + end + end + end + + def test_transcribe + assert_nothing_raised do + @parakeet.transcribe AUDIO, @params + end + end + + def test_transcribe_with_pathname + assert_nothing_raised do + @parakeet.transcribe Pathname(AUDIO), @params + end + end + + def test_transcribe_with_nothing + assert_raise_message(/open/) do + @parakeet.transcribe "nothing", @params + end + end +end diff --git a/bindings/ruby/test/test_parakeet_context_params.rb b/bindings/ruby/test/test_parakeet_context_params.rb new file mode 100644 index 00000000000..fcd0f2410f7 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_context_params.rb @@ -0,0 +1,24 @@ +require_relative "helper" + +class TestParakeetContextParams < TestBase + def setup + @params = Parakeet::Context::Params.new + end + + def test_new + assert_instance_of Parakeet::Context::Params, @params + end + + def test_attributes + assert_true @params.use_gpu + assert_instance_of Integer, @params.gpu_device + end + + def test_attribute_writer + @params.use_gpu = false + assert_false @params.use_gpu + + @params.gpu_device = 2 + assert_equal 2, @params.gpu_device + end +end diff --git a/bindings/ruby/test/test_parakeet_model.rb b/bindings/ruby/test/test_parakeet_model.rb new file mode 100644 index 00000000000..5343b35ed8e --- /dev/null +++ b/bindings/ruby/test/test_parakeet_model.rb @@ -0,0 +1,21 @@ +require_relative "helper" + +class TestParakeetModel < TestBase + def test_model + parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + assert_instance_of Parakeet::Model, parakeet.model + end + + def test_attributes + parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + model = parakeet.model + + assert_equal 10, model.n_vocab + assert_equal 3200, model.n_audio_ctx + assert_equal 8, model.n_audio_state + assert_equal 2, model.n_audio_head + assert_equal 1, model.n_audio_layer + assert_equal 16, model.n_mels + assert_equal 0, model.ftype + end +end diff --git a/bindings/ruby/test/test_parakeet_params.rb b/bindings/ruby/test/test_parakeet_params.rb new file mode 100644 index 00000000000..dc651f7ab12 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_params.rb @@ -0,0 +1,78 @@ +require_relative "helper" +require "etc" + +class TestParakeetParams < TestBase + PARAM_NAMES = [ + :n_threads, + :offset_ms, + :duration_ms, + :no_context, + :audio_ctx + ] + + def setup + @params = Parakeet::Params.new + end + + def test_new + assert_instance_of Parakeet::Params, @params + end + + def test_n_threads + assert_equal [4, Etc.nprocessors].min, @params.n_threads + + @params.n_threads = 1 + assert_equal 1, @params.n_threads + end + + def test_offset_ms + assert_equal 0, @params.offset_ms + + @params.offset_ms = 10_000 + assert_equal 10_000, @params.offset_ms + end + + def test_duration_ms + assert_equal 0, @params.duration_ms + + @params.duration_ms = 60_000 + assert_equal 60_000, @params.duration_ms + end + + def test_no_context + assert_equal true, @params.no_context + + @params.no_context = false + assert_equal false, @params.no_context + end + + def test_audio_ctx + assert_equal 0, @params.audio_ctx + + @params.audio_ctx = 1 + assert_equal 1, @params.audio_ctx + end + + def test_new_with_kw_args + params = Parakeet::Params.new(n_threads: 1) + assert_equal 1, params.n_threads + assert_equal 0, params.offset_ms + end + + data(PARAM_NAMES.collect {|param| [param, param]}.to_h) + def test_new_with_kw_args_default_values(param) + default_value = @params.send(param) + value = case [param, default_value] + in [*, true | false] + !default_value + in [*, Integer] + default_value + 1 + end + params = Parakeet::Params.new(param => value) + assert_equal value, params.send(param) + + PARAM_NAMES.reject {|name| name == param}.each do |name| + assert_equal @params.send(name), params.send(name) + end + end +end diff --git a/bindings/ruby/test/test_parakeet_segment.rb b/bindings/ruby/test/test_parakeet_segment.rb new file mode 100644 index 00000000000..d5b99bd5ee6 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_segment.rb @@ -0,0 +1,42 @@ +require_relative "helper" + +class TestParakeetSegment < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + @parakeet.transcribe AUDIO, Parakeet::Params.new + end + + def test_segment + whole_text = "" + @parakeet.each_segment do |segment| + assert_instance_of Parakeet::Segment, segment + assert_kind_of Integer, segment.start_time + assert segment.end_time >= segment.start_time + assert_kind_of String, segment.text + whole_text << segment.text + end + assert_match(/ask not what your country can do for you, ask what you can do for your country/, whole_text) + end + + def test_deconstruct_keys + segment = @parakeet.each_segment.first + expected = { + start_time: segment.start_time, + end_time: segment.end_time, + text: segment.text + } + assert_equal expected, segment.deconstruct_keys([:start_time, :end_time, :text]) + end + + def test_deconstruct_keys_with_nil + segment = @parakeet.each_segment.first + expected = { + start_time: segment.start_time, + end_time: segment.end_time, + text: segment.text + } + assert_equal expected, segment.deconstruct_keys(nil) + end +end diff --git a/bindings/ruby/test/test_parakeet_token.rb b/bindings/ruby/test/test_parakeet_token.rb new file mode 100644 index 00000000000..6f0b8b5a37c --- /dev/null +++ b/bindings/ruby/test/test_parakeet_token.rb @@ -0,0 +1,73 @@ +require_relative "helper" + +class TestParakeetToken < TestBase + ATTRS = %i[ + id + duration_idx + duration_value + frame_index + probability + log_probability + start_time + end_time + word_start? + text + ] + + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + + parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + params = Parakeet::Params.new + parakeet.transcribe AUDIO, params + @segment = parakeet.each_segment.first + end + + def test_each_token + i = 0 + @segment.each_token do |token| + i += 1 + assert_instance_of Parakeet::Token, token + end + assert_equal 38, i + end + + def test_each_token_without_block + assert_instance_of Enumerator, @segment.each_token + end + + def test_token + token = @segment.each_token.first + + assert_instance_of Parakeet::Token, token + assert_instance_of Integer, token.id + assert_instance_of Integer, token.duration_idx + assert_instance_of Integer, token.duration_value + assert_instance_of Integer, token.frame_index + assert_instance_of Float, token.probability + assert_instance_of Float, token.log_probability + assert_instance_of Integer, token.start_time + assert_instance_of Integer, token.end_time + assert_instance_of String, token.text + end + + def test_text + assert_equal ["▁And", "▁so", ",", "▁my", "▁f", "ell", "ow", "▁Amer", "ic", "ans", ",", "▁a", "sk", "▁not", "▁what", "▁your", "▁co", "un", "tr", "y", "▁can", "▁do", "▁for", "▁you", ",", "▁a", "sk", "▁what", "▁you", "▁can", "▁do", "▁for", "▁your", "▁co", "un", "tr", "y", "."], + @segment.each_token.collect(&:text) + end + + def test_deconstruct_keys_with_nil + token = @segment.each_token.first + expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h + assert_equal expected, token.deconstruct_keys(nil) + end + + def test_deconstruct_keys_with_keys + token = @segment.each_token.first + expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h + assert_equal expected, token.deconstruct_keys(expected.keys) + end +end diff --git a/bindings/ruby/test/test_vad_segment.rb b/bindings/ruby/test/test_vad_segment.rb index 7348562cb15..6d66c27fd32 100644 --- a/bindings/ruby/test/test_vad_segment.rb +++ b/bindings/ruby/test/test_vad_segment.rb @@ -9,7 +9,7 @@ def test_initialize end assert_raise do - segments.end_time + segment.end_time end assert_raise do diff --git a/bindings/ruby/test/test_whisper.rb b/bindings/ruby/test/test_whisper.rb index f7e25239d5d..082547e7c08 100644 --- a/bindings/ruby/test/test_whisper.rb +++ b/bindings/ruby/test/test_whisper.rb @@ -149,6 +149,7 @@ def test_log_set } Whisper.log_set log_callback, user_data Whisper::Context.new("base.en") + sleep 0.1 # wait for logs dequeued assert logs.length > 30 logs.each do |log| diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 2d952222f29..301ecfcc13d 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -23,7 +23,7 @@ Gem::Specification.new do |s| s.test_files = s.files.select {|file| file.start_with? "test/"} s.extensions << 'ext/extconf.rb' - s.required_ruby_version = '>= 3.1.0' + s.required_ruby_version = '>= 3.3.0' #### Documentation and testing. s.homepage = 'https://github.com/ggml-org/whisper.cpp' From 86c40c3bd6fc86f1187fb751d111b49e0fc18e84 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Wed, 17 Jun 2026 11:36:57 +0200 Subject: [PATCH 828/831] release : v1.9.0 (#3886) --- CMakeLists.txt | 2 +- README.md | 2 +- bindings/javascript/package.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index dff25f25a34..8527d6d9bed 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.8.7) +project("whisper.cpp" VERSION 1.9.0) include(CheckIncludeFileCXX) set(SOVERSION 1) diff --git a/README.md b/README.md index 19fdc70daab..a32d9b61382 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) -Stable: [v1.8.7](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.7) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) +Stable: [v1.9.0](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.9.0) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index 7c66c730c6c..b777591a4e3 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.8.7", + "version": "1.9.0", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From 200b1197907545a88c5a00fb15f52e2cf88af6f5 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Thu, 18 Jun 2026 14:49:08 +0200 Subject: [PATCH 829/831] ci : add GGML_NATIVE=OFF and GGML_BMI2=OFF to windows-blas (#3891) * ci : add GGML_NATIVE=OFF and build all cpu-variants This commit adds -DGGML_BACKEND_DL=ON, -DGGML_NATIVE=OFF, and -DGGML_CPU_ALL_VARIANTS=ON to the releases. The motivation for this is that currently the Windows BLAS build uses the native CPU instructions and if target systems do not support these instructions, the build will fail like the linked issue reports. Resolves: https://github.com/ggml-org/whisper.cpp/issues/3889 * ci : update ubuntu-cpu release job for all variants [no ci] This commit enables the ubuntu-cpu job to include all cpu variants and ensures that the ggml backend libraries are built into the bin directory similar to how llama.cpp does it. The following is a build on my fork with this change: https://github.com/danbev/whisper.cpp/releases/tag/untagged-fc3c71f0bf0f7bf19d19 --- .github/workflows/release.yml | 15 +++++++++++---- CMakeLists.txt | 1 + 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ef2c3083c9f..8dcfeb9827c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -115,9 +115,11 @@ jobs: run: | cmake -B build \ -DCMAKE_BUILD_TYPE=Release \ - -DBUILD_SHARED_LIBS=OFF \ + -DCMAKE_INSTALL_RPATH='$ORIGIN' \ + -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \ + -DGGML_BACKEND_DL=ON \ -DGGML_NATIVE=OFF \ - ${{ matrix.build == 'arm64' && '-DGGML_CPU_ARM_ARCH=armv8-a' || '' }} + ${{ matrix.build == 'x64' && '-DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_CPU_ARM_ARCH=armv8-a' }} cmake --build build --config Release -j $(nproc) - name: Pack artifacts @@ -173,7 +175,7 @@ jobs: -DBUILD_SHARED_LIBS=ON -DWHISPER_SDL2=${{ matrix.sdl2 }} -DGGML_NATIVE=OFF - -DGGML_BMI2=OFF + ${{ matrix.arch == 'x64' && '-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_BMI2=OFF' }} - name: Build run: | @@ -287,6 +289,8 @@ jobs: -DBLAS_LIBRARIES="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/lib/libopenblas.lib" -DBLAS_INCLUDE_DIRS="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/include" -DWHISPER_SDL2=${{ matrix.sdl2 }} + -DGGML_NATIVE=OFF + ${{ matrix.arch == 'x64' && '-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_BMI2=OFF' }} - name: Build run: | @@ -490,7 +494,10 @@ jobs: -DWHISPER_SDL2=${{ matrix.sdl2 }} ^ -DSDL2_DIR="%SDL2_DIR%" ^ -DCMAKE_POLICY_VERSION_MINIMUM=3.5 ^ - -DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%" + -DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%" ^ + -DGGML_BACKEND_DL=ON ^ + -DGGML_NATIVE=OFF ^ + -DGGML_CPU_ALL_VARIANTS=ON set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS% diff --git a/CMakeLists.txt b/CMakeLists.txt index 8527d6d9bed..1f95e175af4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,6 +19,7 @@ endif() list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) set(WHISPER_STANDALONE ON) From f049fff95a089aa9969deb009cdd4892b3e74916 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius <daniel.bevenius@gmail.com> Date: Fri, 19 Jun 2026 06:12:37 +0200 Subject: [PATCH 830/831] release : v1.9.1 (#3892) --- CMakeLists.txt | 2 +- README.md | 2 +- bindings/javascript/package.json | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1f95e175af4..26037c26538 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.9.0) +project("whisper.cpp" VERSION 1.9.1) include(CheckIncludeFileCXX) set(SOVERSION 1) diff --git a/README.md b/README.md index a32d9b61382..0e2d5f100d5 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) -Stable: [v1.9.0](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.9.0) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) +Stable: [v1.9.1](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.9.1) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index b777591a4e3..09829326605 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.9.0", + "version": "1.9.1", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { From 0fe839304770af090079369bff4f6769c2bbc733 Mon Sep 17 00:00:00 2001 From: Alcahest <xaris@gzgd.info> Date: Wed, 24 Jun 2026 09:25:20 +0200 Subject: [PATCH 831/831] ci: build CUDA for Blackwell sm_120 (RTX 50) on CUDA 12.9 --- .github/workflows/build-binaries.yml | 43 ++++++++++++++-------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/.github/workflows/build-binaries.yml b/.github/workflows/build-binaries.yml index aec894d595e..7f9c29d5324 100644 --- a/.github/workflows/build-binaries.yml +++ b/.github/workflows/build-binaries.yml @@ -15,7 +15,8 @@ permissions: contents: write env: - CUDA_ARCHITECTURES: "75;80;86;89" + # RTX 20-50 (Turing through Blackwell). sm_120 requires CUDA Toolkit >= 12.8. + CUDA_ARCHITECTURES: "75;80;86;89;120" jobs: build-macos-arm64: @@ -190,21 +191,21 @@ jobs: - name: Install Ninja run: choco install ninja -y - - name: Install CUDA Toolkit 12.4.0 + - name: Install CUDA Toolkit 12.9.1 run: | - $CUDA_VERSION = "12.4.0" + $CUDA_VERSION = "12.9.1" $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" - # Component versions for CUDA 12.4.0 - $CUDART_VER = "12.4.127" - $NVCC_VER = "12.4.131" - $NVRTC_VER = "12.4.127" - $CUBLAS_VER = "12.4.5.8" - $NVTX_VER = "12.4.127" - $PROFILER_VER = "12.4.127" - $VS_VER = "12.4.127" - $CCCL_VER = "12.4.127" + # Component versions for CUDA 12.9.1 + $CUDART_VER = "12.9.79" + $NVCC_VER = "12.9.86" + $NVRTC_VER = "12.9.86" + $CUBLAS_VER = "12.9.1.4" + $NVTX_VER = "12.9.79" + $PROFILER_VER = "12.9.79" + $VS_VER = "12.9.79" + $CCCL_VER = "12.9.27" # Create CUDA toolkit directory New-Item -ItemType Directory -Force -Path $CUDA_TOOLKIT_DIR @@ -400,7 +401,7 @@ jobs: sudo apt-get update sudo apt-get install -y build-essential cmake wget - - name: Install CUDA Toolkit 12.4 + - name: Install CUDA Toolkit 12.9 run: | # Download and install CUDA keyring wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb @@ -408,16 +409,16 @@ jobs: sudo apt-get update # Install minimal CUDA toolkit (compiler and libraries only, no driver) - sudo apt-get install -y cuda-toolkit-12-4 + sudo apt-get install -y cuda-toolkit-12-9 # Set environment variables - echo "/usr/local/cuda-12.4/bin" >> $GITHUB_PATH - echo "CUDA_PATH=/usr/local/cuda-12.4" >> $GITHUB_ENV - echo "LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH" >> $GITHUB_ENV + echo "/usr/local/cuda-12.9/bin" >> $GITHUB_PATH + echo "CUDA_PATH=/usr/local/cuda-12.9" >> $GITHUB_ENV + echo "LD_LIBRARY_PATH=/usr/local/cuda-12.9/lib64:$LD_LIBRARY_PATH" >> $GITHUB_ENV - name: Verify CUDA installation run: | - export PATH=/usr/local/cuda-12.4/bin:$PATH + export PATH=/usr/local/cuda-12.9/bin:$PATH nvcc --version - name: Setup ccache @@ -427,8 +428,8 @@ jobs: - name: Build whisper.cpp with CUDA run: | - export PATH=/usr/local/cuda-12.4/bin:$PATH - export CUDA_PATH=/usr/local/cuda-12.4 + export PATH=/usr/local/cuda-12.9/bin:$PATH + export CUDA_PATH=/usr/local/cuda-12.9 cmake -B build \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ @@ -436,7 +437,7 @@ jobs: -DBUILD_SHARED_LIBS=OFF \ -DGGML_NATIVE=OFF \ -DGGML_CUDA=ON \ - -DCMAKE_CUDA_COMPILER=/usr/local/cuda-12.4/bin/nvcc \ + -DCMAKE_CUDA_COMPILER=/usr/local/cuda-12.9/bin/nvcc \ -DCMAKE_CUDA_ARCHITECTURES="${{ env.CUDA_ARCHITECTURES }}" cmake --build build --config Release -j $(nproc)